mistralrs_core/paged_attention/cache_engine.rs

use std::{
    collections::HashMap,
    str::FromStr,
    sync::{Arc, Mutex, MutexGuard},
};

use candle_core::{DType, Device, Result, Tensor};
#[cfg(any(all(feature = "cuda", target_family = "unix"), feature = "metal"))]
use mistralrs_paged_attn::copy_blocks;
use serde::{Deserialize, Serialize};

use super::config::ModelConfigLike;

/// Element type used for the paged KV-cache blocks.
#[derive(Clone, Copy, Debug, Serialize, Deserialize, PartialEq, Default)]
#[cfg_attr(feature = "pyo3_macros", pyo3::pyclass(eq, eq_int))]
pub enum PagedCacheType {
    /// Use the activation dtype for the cache.
    #[default]
    Auto,
    /// Quantize the cache to 8-bit floating point (e4m3).
    F8E4M3,
}

impl PagedCacheType {
    /// Resolve to a concrete `DType`, deferring to the activation dtype for
    /// `Auto`.
    pub fn to_dtype(&self, act_dtype: DType) -> DType {
        match self {
            PagedCacheType::F8E4M3 => DType::F8E4M3,
            PagedCacheType::Auto => act_dtype,
        }
    }
}

impl FromStr for PagedCacheType {
    type Err = String;
    fn from_str(s: &str) -> std::result::Result<Self, Self::Err> {
        match s {
            "auto" => Ok(Self::Auto),
            "f8e4m3" => Ok(Self::F8E4M3),
            other => Err(format!(
                "Unexpected `PagedCacheType`, got `{other}` but expected `auto` or `f8e4m3`."
            )),
        }
    }
}
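
// A minimal round-trip sketch for `PagedCacheType` (illustrative values only):
#[cfg(test)]
mod paged_cache_type_tests {
    use super::*;

    #[test]
    fn parse_and_resolve_dtype() {
        // `FromStr` accepts the two lowercase spellings.
        let t = PagedCacheType::from_str("f8e4m3").unwrap();
        assert_eq!(t, PagedCacheType::F8E4M3);
        // `Auto` defers to the activation dtype; `F8E4M3` overrides it.
        assert_eq!(PagedCacheType::Auto.to_dtype(DType::BF16), DType::BF16);
        assert_eq!(t.to_dtype(DType::BF16), DType::F8E4M3);
    }
}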

/// Configuration for the paged KV cache.
#[derive(Clone, Debug)]
pub struct CacheConfig {
    /// Number of token slots per cache block.
    pub block_size: usize,
    /// Number of cache blocks to allocate on each device.
    pub num_gpu_blocks: usize,
    /// Element type of the cache blocks.
    pub cache_type: PagedCacheType,
}

/// Per-layer cache: one key tensor and one value tensor.
pub type KVCache = (Tensor, Tensor);

pub struct CacheEngine {
    gpu_cache: Arc<Mutex<Vec<KVCache>>>,
}

impl CacheEngine {
    pub fn new(
        model_config: &dyn ModelConfigLike,
        cache_config: &CacheConfig,
        dtype: DType,
        device: &Device,
        layer_devices: Vec<Option<Device>>,
    ) -> Result<Self> {
        let dtype = cache_config.cache_type.to_dtype(dtype);
        Ok(Self {
            gpu_cache: Arc::new(Mutex::new(Self::allocate_gpu_cache(
                model_config,
                cache_config,
                dtype,
                device,
                layer_devices,
            )?)),
        })
    }
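
    // A hedged construction sketch (hypothetical `model_cfg`, `cfg`, `dev`,
    // and `n_layers`; `layer_devices` holds one optional override per layer):
    //
    //     let engine = CacheEngine::new(&model_cfg, &cfg, DType::BF16, &dev, vec![None; n_layers])?;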

    pub fn get_kv_cache(&self) -> MutexGuard<'_, Vec<KVCache>> {
        // Busy-wait until the cache mutex is free instead of blocking on
        // `lock()`.
        loop {
            if let Ok(v) = self.gpu_cache.try_lock() {
                return v;
            }
        }
    }
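
    // Usage sketch for `get_kv_cache` (hypothetical `engine` and `layer_idx`):
    // the guard derefs to the per-layer cache, one `(key, value)` tensor pair
    // per layer.
    //
    //     let cache = engine.get_kv_cache();
    //     let (key_blocks, value_blocks) = &cache[layer_idx];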

    fn allocate_gpu_cache(
        model_config: &dyn ModelConfigLike,
        cache_config: &CacheConfig,
        dtype: DType,
        device: &Device,
        layer_devices: Vec<Option<Device>>,
    ) -> Result<Vec<KVCache>> {
        let key_block_shape =
            Self::calculate_key_block_shape(model_config, dtype, cache_config.block_size);
        let value_block_shape =
            Self::calculate_value_block_shape(model_config, cache_config.block_size);
        let mut gpu_cache = Vec::new();

        // Fall back to the default device for layers without an override.
        for device in layer_devices
            .iter()
            .take(model_config.num_layers())
            .map(|x| x.as_ref().unwrap_or(device))
        {
            #[allow(unused)]
            let key_blocks = if let Device::Metal(dev) = &device {
                #[cfg(feature = "metal")]
                {
                    use candle_core::{from_storage_no_op, MetalStorage, Shape, Storage};

                    // Allocate an uninitialized private (GPU-only) Metal buffer
                    // and wrap it in a tensor without copying.
                    let elem_count = cache_config.num_gpu_blocks
                        * key_block_shape.0
                        * key_block_shape.1
                        * key_block_shape.2
                        * key_block_shape.3;
                    let buffer = dev.new_buffer_private(elem_count, dtype, "k_cache")?;
                    let storage =
                        Storage::Metal(MetalStorage::new(buffer, dev.clone(), elem_count, dtype));
                    from_storage_no_op(
                        storage,
                        Shape::from_dims(&[
                            cache_config.num_gpu_blocks,
                            key_block_shape.0,
                            key_block_shape.1,
                            key_block_shape.2,
                            key_block_shape.3,
                        ]),
                        false,
                    )
                }

                #[cfg(not(feature = "metal"))]
                {
                    unreachable!()
                }
            } else {
                // Uninitialized allocation; cache blocks are written before use.
                unsafe {
                    Tensor::empty(
                        (
                            cache_config.num_gpu_blocks,
                            key_block_shape.0,
                            key_block_shape.1,
                            key_block_shape.2,
                            key_block_shape.3,
                        ),
                        dtype,
                        device,
                    )?
                }
            };
            #[allow(unused)]
            let value_blocks = if let Device::Metal(dev) = &device {
                #[cfg(feature = "metal")]
                {
                    use candle_core::{from_storage_no_op, MetalStorage, Shape, Storage};

                    let elem_count = cache_config.num_gpu_blocks
                        * value_block_shape.0
                        * value_block_shape.1
                        * value_block_shape.2;
                    let buffer = dev.new_buffer_private(elem_count, dtype, "v_cache")?;
                    let storage =
                        Storage::Metal(MetalStorage::new(buffer, dev.clone(), elem_count, dtype));
                    from_storage_no_op(
                        storage,
                        Shape::from_dims(&[
                            cache_config.num_gpu_blocks,
                            value_block_shape.0,
                            value_block_shape.1,
                            value_block_shape.2,
                        ]),
                        false,
                    )
                }

                #[cfg(not(feature = "metal"))]
                {
                    unreachable!()
                }
            } else {
                unsafe {
                    Tensor::empty(
                        (
                            cache_config.num_gpu_blocks,
                            value_block_shape.0,
                            value_block_shape.1,
                            value_block_shape.2,
                        ),
                        dtype,
                        device,
                    )?
                }
            };
            gpu_cache.push((key_blocks, value_blocks));
        }
        Ok(gpu_cache)
    }

    fn calculate_key_block_shape(
        model_config: &dyn ModelConfigLike,
        dtype: DType,
        block_size: usize,
    ) -> (usize, usize, usize, usize) {
        // Pack `x` elements along the last dim so each kernel access spans
        // 16 bytes regardless of the element size.
        let element_size = dtype.size_in_bytes();
        let x = 16 / element_size;
        (
            model_config.num_kv_heads(),
            model_config.k_head_dim() / x,
            block_size,
            x,
        )
    }

    fn calculate_value_block_shape(
        model_config: &dyn ModelConfigLike,
        block_size: usize,
    ) -> (usize, usize, usize) {
        (
            model_config.num_kv_heads(),
            model_config.v_head_dim(),
            block_size,
        )
    }
}
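
// A worked example of the shape math above, with made-up dimensions: 8 KV
// heads, head_dim 128, block_size 16, and an F16 cache (2-byte elements).
// Then x = 16 / 2 = 8, so:
//   key block shape   = (8, 128 / 8, 16, 8) = (8, 16, 16, 8)
//   value block shape = (8, 128, 16)
// Both hold 8 * 128 * 16 = 16384 elements per block; the key layout merely
// reorders them for 16-byte-wide loads.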

impl CacheEngine {
    pub fn execute_scheduler_ops(&self, blocks_to_copy: &HashMap<usize, Vec<usize>>) -> Result<()> {
        if !blocks_to_copy.is_empty() {
            self.copy(blocks_to_copy)?;
        }
        Ok(())
    }

    pub fn copy(&self, src_to_dst: &HashMap<usize, Vec<usize>>) -> Result<()> {
        #[cfg(any(all(feature = "cuda", target_family = "unix"), feature = "metal"))]
        {
            let mut gpu_cache = self.get_kv_cache();
            #[allow(clippy::map_identity)]
            let caches: (Vec<&mut Tensor>, Vec<&mut Tensor>) =
                gpu_cache.iter_mut().map(|(a, b)| (a, b)).unzip();
            let (key_caches, value_caches) = caches;

            // NOTE(EricLBuehler): This may synchronize the CPU and GPU
            copy_blocks(key_caches, value_caches, src_to_dst)?;
            Ok(())
        }
        #[cfg(not(any(all(feature = "cuda", target_family = "unix"), feature = "metal")))]
        {
            let _ = src_to_dst;
            candle_core::bail!("Paged attention requires the CUDA or Metal feature flags.");
        }
    }
}
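
// A minimal sketch of the scheduler-side input to `copy` (hypothetical block
// indices): the map sends each source block to one or more destination blocks.
#[cfg(test)]
mod copy_map_tests {
    use std::collections::HashMap;

    #[test]
    fn build_copy_map() {
        let mut blocks_to_copy: HashMap<usize, Vec<usize>> = HashMap::new();
        // Copy physical block 0 into blocks 4 and 7, and block 2 into block 5.
        blocks_to_copy.entry(0).or_default().extend([4, 7]);
        blocks_to_copy.entry(2).or_default().push(5);
        assert_eq!(blocks_to_copy[&0], vec![4, 7]);
        assert_eq!(blocks_to_copy[&2], vec![5]);
    }
}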