mistralrs_core/paged_attention/cache_engine.rs

use std::{
    collections::HashMap,
    str::FromStr,
    sync::{Arc, Mutex, MutexGuard},
};

use candle_core::{DType, Device, Result, Tensor};
#[cfg(any(all(feature = "cuda", target_family = "unix"), feature = "metal"))]
use mistralrs_paged_attn::copy_blocks;
use serde::{Deserialize, Serialize};

use super::config::ModelConfigLike;

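/// Element type of the paged KV cache: either follow the model's activation
/// dtype (`Auto`) or store the cache in FP8 (`F8E4M3`), halving its memory
/// footprint relative to 16-bit activations.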
#[derive(Clone, Copy, Debug, Serialize, Deserialize, PartialEq, Default)]
#[cfg_attr(feature = "pyo3_macros", pyo3::pyclass(eq, eq_int))]
pub enum PagedCacheType {
    #[default]
    Auto,
    F8E4M3,
}

impl PagedCacheType {
    pub fn to_dtype(&self, act_dtype: DType) -> DType {
        match self {
            PagedCacheType::F8E4M3 => DType::F8E4M3,
            PagedCacheType::Auto => act_dtype,
        }
    }
}

impl FromStr for PagedCacheType {
    type Err = String;
    fn from_str(s: &str) -> std::result::Result<Self, Self::Err> {
        match s {
            "auto" => Ok(Self::Auto),
            "f8e4m3" => Ok(Self::F8E4M3),
            other => Err(format!(
                "Unexpected `PagedCacheType`, got `{other}` but expected `auto` or `f8e4m3`."
            )),
        }
    }
}

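/// Sizing parameters for the paged KV cache: `block_size` is the number of
/// token slots per cache block, and `num_gpu_blocks` is how many blocks are
/// allocated per layer on the GPU.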
#[derive(Clone, Debug)]
pub struct CacheConfig {
    pub block_size: usize,
    pub num_gpu_blocks: usize,
    pub cache_type: PagedCacheType,
}

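/// A per-layer pair of (key, value) cache tensors.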
pub type KVCache = (Tensor, Tensor);

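/// Owns the per-layer GPU block pools backing paged attention.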
pub struct CacheEngine {
    gpu_cache: Arc<Mutex<Vec<KVCache>>>,
}

impl CacheEngine {
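    /// Allocates one `(key, value)` block pool per layer, placed on the
    /// matching entry of `layer_devices` (falling back to `device` when a
    /// layer has no override).
    ///
    /// A minimal usage sketch; `model_cfg` stands in for any
    /// `ModelConfigLike` implementor and the sizes are illustrative:
    ///
    /// ```ignore
    /// let cache_config = CacheConfig {
    ///     block_size: 32,
    ///     num_gpu_blocks: 512,
    ///     cache_type: PagedCacheType::Auto,
    /// };
    /// let engine = CacheEngine::new(
    ///     &model_cfg,
    ///     &cache_config,
    ///     DType::BF16,
    ///     &device,
    ///     vec![None; model_cfg.num_layers()],
    /// )?;
    /// ```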
    pub fn new(
        model_config: &dyn ModelConfigLike,
        cache_config: &CacheConfig,
        dtype: DType,
        device: &Device,
        layer_devices: Vec<Option<Device>>,
    ) -> Result<Self> {
        let dtype = cache_config.cache_type.to_dtype(dtype);
        Ok(Self {
            gpu_cache: Arc::new(Mutex::new(Self::allocate_gpu_cache(
                model_config,
                cache_config,
                dtype,
                device,
                layer_devices,
            )?)),
        })
    }

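    /// Locks the cache and returns a guard over the per-layer `(key, value)`
    /// tensors; hold the guard only for the duration of the cache access.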
    pub fn get_kv_cache(&self) -> MutexGuard<'_, Vec<KVCache>> {
        self.gpu_cache.lock().expect("KV cache mutex was poisoned")
    }

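    /// Allocates the per-layer key and value block pools. On Metal the
    /// blocks are created directly as private device buffers; on all other
    /// devices they are uninitialized tensors from `Tensor::empty`.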
    fn allocate_gpu_cache(
        model_config: &dyn ModelConfigLike,
        cache_config: &CacheConfig,
        dtype: DType,
        device: &Device,
        layer_devices: Vec<Option<Device>>,
    ) -> Result<Vec<KVCache>> {
        let key_block_shape =
            Self::calculate_key_block_shape(model_config, dtype, cache_config.block_size);
        let value_block_shape =
            Self::calculate_value_block_shape(model_config, cache_config.block_size);
        let mut gpu_cache = Vec::new();

        for device in layer_devices
            .iter()
            .take(model_config.num_layers())
            .map(|x| x.as_ref().unwrap_or(device))
        {
            #[allow(unused)]
            let key_blocks = if let Device::Metal(dev) = &device {
                #[cfg(feature = "metal")]
                {
                    use candle_core::{MetalStorage, Shape, Storage};

                    let elem_count = cache_config.num_gpu_blocks
                        * key_block_shape.0
                        * key_block_shape.1
                        * key_block_shape.2
                        * key_block_shape.3;
                    let buffer = dev.new_private_buffer(elem_count, dtype, "k_cache")?;
                    let storage =
                        Storage::Metal(MetalStorage::new(buffer, dev.clone(), elem_count, dtype));
                    Tensor::from((
                        storage,
                        Shape::from_dims(&[
                            cache_config.num_gpu_blocks,
                            key_block_shape.0,
                            key_block_shape.1,
                            key_block_shape.2,
                            key_block_shape.3,
                        ]),
                    ))
                }

                #[cfg(not(feature = "metal"))]
                {
                    unreachable!()
                }
            } else {
                unsafe {
                    Tensor::empty(
                        (
                            cache_config.num_gpu_blocks,
                            key_block_shape.0,
                            key_block_shape.1,
                            key_block_shape.2,
                            key_block_shape.3,
                        ),
                        dtype,
                        device,
                    )?
                }
            };
            #[allow(unused)]
            let value_blocks = if let Device::Metal(dev) = &device {
                #[cfg(feature = "metal")]
                {
                    use candle_core::{MetalStorage, Shape, Storage};

                    let elem_count = cache_config.num_gpu_blocks
                        * value_block_shape.0
                        * value_block_shape.1
                        * value_block_shape.2;
                    let buffer = dev.new_private_buffer(elem_count, dtype, "v_cache")?;
                    let storage =
                        Storage::Metal(MetalStorage::new(buffer, dev.clone(), elem_count, dtype));
                    Tensor::from((
                        storage,
                        Shape::from_dims(&[
                            cache_config.num_gpu_blocks,
                            value_block_shape.0,
                            value_block_shape.1,
                            value_block_shape.2,
                        ]),
                    ))
                }

                #[cfg(not(feature = "metal"))]
                {
                    unreachable!()
                }
            } else {
                unsafe {
                    Tensor::empty(
                        (
                            cache_config.num_gpu_blocks,
                            value_block_shape.0,
                            value_block_shape.1,
                            value_block_shape.2,
                        ),
                        dtype,
                        device,
                    )?
                }
            };
            gpu_cache.push((key_blocks, value_blocks));
        }
        Ok(gpu_cache)
    }

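    /// Key blocks are laid out as `(num_kv_heads, head_dim / x, block_size, x)`,
    /// where `x` is the number of elements that fit in a 16-byte vectorized
    /// load, mirroring the vLLM paged-attention key-cache layout.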
    fn calculate_key_block_shape(
        model_config: &dyn ModelConfigLike,
        dtype: DType,
        block_size: usize,
    ) -> (usize, usize, usize, usize) {
        let element_size = dtype.size_in_bytes();
        let x = 16 / element_size;
        (
            model_config.num_kv_heads(),
            model_config.k_head_dim() / x,
            block_size,
            x,
        )
    }

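    /// Value blocks use the simpler `(num_kv_heads, v_head_dim, block_size)`
    /// layout, with no vectorization factor.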
    fn calculate_value_block_shape(
        model_config: &dyn ModelConfigLike,
        block_size: usize,
    ) -> (usize, usize, usize) {
        (
            model_config.num_kv_heads(),
            model_config.v_head_dim(),
            block_size,
        )
    }
}

impl CacheEngine {
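    /// Applies the block-copy operations requested by the scheduler; a no-op
    /// when `blocks_to_copy` is empty.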
    pub fn execute_scheduler_ops(&self, blocks_to_copy: &HashMap<usize, Vec<usize>>) -> Result<()> {
        if !blocks_to_copy.is_empty() {
            self.copy(blocks_to_copy)?;
        }
        Ok(())
    }

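    /// Copies cache blocks on-device: each entry maps a source block index to
    /// the destination block indices that should receive its contents (e.g.
    /// copy-on-write when sequences are forked).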
    pub fn copy(&self, src_to_dst: &HashMap<usize, Vec<usize>>) -> Result<()> {
        #[cfg(any(all(feature = "cuda", target_family = "unix"), feature = "metal"))]
        {
            let mut gpu_cache = self.get_kv_cache();
            #[allow(clippy::map_identity)]
            let caches: (Vec<&mut Tensor>, Vec<&mut Tensor>) =
                gpu_cache.iter_mut().map(|(a, b)| (a, b)).unzip();
            let (key_caches, value_caches) = caches;

            copy_blocks(key_caches, value_caches, src_to_dst)?;
            Ok(())
        }
        #[cfg(not(any(all(feature = "cuda", target_family = "unix"), feature = "metal")))]
        {
            let _ = src_to_dst;
            candle_core::bail!("Paged attention requires the CUDA or Metal feature flags.");
        }
    }
}