mistralrs_core/paged_attention/cache_engine.rs

1use std::{
2    collections::HashMap,
3    str::FromStr,
4    sync::{Arc, Mutex, MutexGuard},
5};
6
7use candle_core::{DType, Device, Result, Tensor};
8#[cfg(any(all(feature = "cuda", target_family = "unix"), feature = "metal"))]
9use mistralrs_paged_attn::copy_blocks;
10use serde::{Deserialize, Serialize};
11
12use super::config::ModelConfigLike;
13
14#[derive(Clone, Copy, Debug, Serialize, Deserialize, PartialEq, Default)]
15#[cfg_attr(feature = "pyo3_macros", pyo3::pyclass(eq, eq_int))]
16pub enum PagedCacheType {
17    #[default]
18    Auto,
19    F8E4M3,
20}
21
22impl PagedCacheType {
23    pub fn to_dtype(&self, act_dtype: DType) -> DType {
24        match self {
25            PagedCacheType::F8E4M3 => DType::F8E4M3,
26            PagedCacheType::Auto => act_dtype,
27        }
28    }
29}
30
31impl FromStr for PagedCacheType {
32    type Err = String;
33    fn from_str(s: &str) -> std::result::Result<Self, Self::Err> {
34        match s {
35            "auto" => Ok(Self::Auto),
36            "f8e4m3" => Ok(Self::F8E4M3),
37            other => Err(format!(
38                "Unexpected `PagedCacheType`, got `{other}` but expected `auto` and `f8e4m3`."
39            )),
40        }
41    }
42}
43
44#[derive(Clone, Debug)]
45pub struct CacheConfig {
46    pub block_size: usize,
47    pub num_gpu_blocks: usize,
48    pub cache_type: PagedCacheType,
49}
50
51pub type KVCache = (Tensor, Tensor);
52
53pub struct CacheEngine {
54    gpu_cache: Arc<Mutex<Vec<KVCache>>>,
55}
56
57impl CacheEngine {
58    pub fn new(
59        model_config: &dyn ModelConfigLike,
60        cache_config: &CacheConfig,
61        dtype: DType,
62        device: &Device,
63        layer_devices: Vec<Option<Device>>,
64    ) -> Result<Self> {
65        let dtype = cache_config.cache_type.to_dtype(dtype);
66        Ok(Self {
67            gpu_cache: Arc::new(Mutex::new(Self::allocate_gpu_cache(
68                model_config,
69                cache_config,
70                dtype,
71                device,
72                layer_devices,
73            )?)),
74        })
75    }
76
77    pub fn get_kv_cache(&self) -> MutexGuard<'_, Vec<KVCache>> {
78        // Use blocking lock instead of busy-wait spin loop to avoid CPU waste
79        // and potential thread starvation issues
80        self.gpu_cache.lock().expect("KV cache mutex was poisoned")
81    }
82
83    fn allocate_gpu_cache(
84        model_config: &dyn ModelConfigLike,
85        cache_config: &CacheConfig,
86        dtype: DType,
87        device: &Device,
88        layer_devices: Vec<Option<Device>>,
89    ) -> Result<Vec<KVCache>> {
90        let key_block_shape =
91            Self::calculate_key_block_shape(model_config, dtype, cache_config.block_size);
92        let value_block_shape =
93            Self::calculate_value_block_shape(model_config, cache_config.block_size);
94        let mut gpu_cache = Vec::new();
95
96        for device in layer_devices
97            .iter()
98            .take(model_config.num_layers())
99            .map(|x| x.as_ref().unwrap_or(device))
100        {
101            #[allow(unused)]
102            let key_blocks = if let Device::Metal(dev) = &device {
103                #[cfg(feature = "metal")]
104                {
105                    use candle_core::{MetalStorage, Shape, Storage};
106
107                    let elem_count = cache_config.num_gpu_blocks
108                        * key_block_shape.0
109                        * key_block_shape.1
110                        * key_block_shape.2
111                        * key_block_shape.3;
112                    let buffer = dev.new_private_buffer(elem_count, dtype, "k_cache")?;
113                    let storage =
114                        Storage::Metal(MetalStorage::new(buffer, dev.clone(), elem_count, dtype));
115                    Tensor::from((
116                        storage,
117                        Shape::from_dims(&[
118                            cache_config.num_gpu_blocks,
119                            key_block_shape.0,
120                            key_block_shape.1,
121                            key_block_shape.2,
122                            key_block_shape.3,
123                        ]),
124                    ))
125                }
126
127                #[cfg(not(feature = "metal"))]
128                {
129                    unreachable!()
130                }
131            } else {
132                unsafe {
133                    Tensor::empty(
134                        (
135                            cache_config.num_gpu_blocks,
136                            key_block_shape.0,
137                            key_block_shape.1,
138                            key_block_shape.2,
139                            key_block_shape.3,
140                        ),
141                        dtype,
142                        device,
143                    )?
144                }
145            };
146            #[allow(unused)]
147            let value_blocks = if let Device::Metal(dev) = &device {
148                #[cfg(feature = "metal")]
149                {
150                    use candle_core::{MetalStorage, Shape, Storage};
151
152                    let elem_count = cache_config.num_gpu_blocks
153                        * value_block_shape.0
154                        * value_block_shape.1
155                        * value_block_shape.2;
156                    let buffer = dev.new_private_buffer(elem_count, dtype, "v_cache")?;
157                    let storage =
158                        Storage::Metal(MetalStorage::new(buffer, dev.clone(), elem_count, dtype));
159                    Tensor::from((
160                        storage,
161                        Shape::from_dims(&[
162                            cache_config.num_gpu_blocks,
163                            value_block_shape.0,
164                            value_block_shape.1,
165                            value_block_shape.2,
166                        ]),
167                    ))
168                }
169
170                #[cfg(not(feature = "metal"))]
171                {
172                    unreachable!()
173                }
174            } else {
175                unsafe {
176                    Tensor::empty(
177                        (
178                            cache_config.num_gpu_blocks,
179                            value_block_shape.0,
180                            value_block_shape.1,
181                            value_block_shape.2,
182                        ),
183                        dtype,
184                        device,
185                    )?
186                }
187            };
188            gpu_cache.push((key_blocks, value_blocks));
189        }
190        Ok(gpu_cache)
191    }
192
193    fn calculate_key_block_shape(
194        model_config: &dyn ModelConfigLike,
195        dtype: DType,
196        block_size: usize,
197    ) -> (usize, usize, usize, usize) {
198        let element_size = dtype.size_in_bytes();
199        let x = 16 / element_size;
200        (
201            model_config.num_kv_heads(),
202            model_config.k_head_dim() / x,
203            block_size,
204            x,
205        )
206    }
207
208    fn calculate_value_block_shape(
209        model_config: &dyn ModelConfigLike,
210        block_size: usize,
211    ) -> (usize, usize, usize) {
212        (
213            model_config.num_kv_heads(),
214            model_config.v_head_dim(),
215            block_size,
216        )
217    }
218}
219
220impl CacheEngine {
221    pub fn execute_scheduler_ops(&self, blocks_to_copy: &HashMap<usize, Vec<usize>>) -> Result<()> {
222        if !blocks_to_copy.is_empty() {
223            self.copy(blocks_to_copy)?;
224        }
225        Ok(())
226    }
227
228    pub fn copy(&self, src_to_dst: &HashMap<usize, Vec<usize>>) -> Result<()> {
229        #[cfg(any(all(feature = "cuda", target_family = "unix"), feature = "metal"))]
230        {
231            let mut gpu_cache = self.get_kv_cache();
232            #[allow(clippy::map_identity)]
233            let caches: (Vec<&mut Tensor>, Vec<&mut Tensor>) =
234                gpu_cache.iter_mut().map(|(a, b)| (a, b)).unzip();
235            let (key_caches, value_caches) = caches;
236
237            // NOTE(EricLBuehler): This may synchronize the CPU and GPU
238            copy_blocks(key_caches, value_caches, src_to_dst)?;
239            Ok(())
240        }
241        #[cfg(not(any(all(feature = "cuda", target_family = "unix"), feature = "metal")))]
242        {
243            let _ = src_to_dst;
244            candle_core::bail!("Paged attention requires the CUDA or Metal feature flags.");
245        }
246    }
247}