mistralrs_core/paged_attention/cache_engine.rs

use std::{
    collections::HashMap,
    str::FromStr,
    sync::{Arc, Mutex, MutexGuard},
};

use candle_core::{DType, Device, Result, Tensor};
#[cfg(any(all(feature = "cuda", target_family = "unix"), feature = "metal"))]
use mistralrs_paged_attn::copy_blocks;
use serde::{Deserialize, Serialize};

use super::config::ModelConfigLike;

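/// Element type of the paged KV cache: either mirror the model's activation
/// dtype, or store the cache in FP8 (E4M3) to cut its memory footprint roughly
/// in half relative to 16-bit activations.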
#[derive(Clone, Copy, Debug, Serialize, Deserialize, PartialEq, Default)]
#[cfg_attr(feature = "pyo3_macros", pyo3::pyclass(eq, eq_int))]
pub enum PagedCacheType {
    /// Use the activation dtype for the cache.
    #[default]
    Auto,
    /// Store the cache in FP8 E4M3.
    F8E4M3,
}

impl PagedCacheType {
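    /// Resolve to a concrete `DType`: FP8 if explicitly requested, otherwise
    /// the activation dtype `act_dtype`.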
    pub fn to_dtype(&self, act_dtype: DType) -> DType {
        match self {
            PagedCacheType::F8E4M3 => DType::F8E4M3,
            PagedCacheType::Auto => act_dtype,
        }
    }
}

impl FromStr for PagedCacheType {
    type Err = String;
    fn from_str(s: &str) -> std::result::Result<Self, Self::Err> {
        match s {
            "auto" => Ok(Self::Auto),
            "f8e4m3" => Ok(Self::F8E4M3),
            other => Err(format!(
                "Unexpected `PagedCacheType`, got `{other}` but expected `auto` or `f8e4m3`."
            )),
        }
    }
}

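/// Sizing of the paged KV cache, typically derived from the memory available
/// on the target device(s).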
#[derive(Clone, Debug)]
pub struct CacheConfig {
    /// Number of token slots per cache block.
    pub block_size: usize,
    /// Total number of cache blocks to allocate per layer.
    pub num_gpu_blocks: usize,
    /// Element type of the cached keys and values.
    pub cache_type: PagedCacheType,
}

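/// One layer's cache: the (key blocks, value blocks) tensor pair.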
pub type KVCache = (Tensor, Tensor);

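/// Owns the paged KV cache tensors for all layers and executes the block
/// operations requested by the scheduler.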
pub struct CacheEngine {
    gpu_cache: Arc<Mutex<Vec<KVCache>>>,
}

impl CacheEngine {
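    /// Allocate the entire paged cache up front. Entries of `layer_devices`
    /// override `device` per layer (e.g. when layers are split across
    /// devices); `None` falls back to `device`.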
    pub fn new(
        model_config: &dyn ModelConfigLike,
        cache_config: &CacheConfig,
        dtype: DType,
        device: &Device,
        layer_devices: Vec<Option<Device>>,
    ) -> Result<Self> {
        let dtype = cache_config.cache_type.to_dtype(dtype);
        Ok(Self {
            gpu_cache: Arc::new(Mutex::new(Self::allocate_gpu_cache(
                model_config,
                cache_config,
                dtype,
                device,
                layer_devices,
            )?)),
        })
    }

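    /// Lock and return the per-layer KV caches. This spin-waits on `try_lock`
    /// rather than blocking on `lock`.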
    pub fn get_kv_cache(&self) -> MutexGuard<'_, Vec<KVCache>> {
        loop {
            if let Ok(v) = self.gpu_cache.try_lock() {
                return v;
            }
        }
    }

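    /// Allocate one (key, value) tensor pair per layer, on that layer's
    /// device. The memory is deliberately left uninitialized (private Metal
    /// buffers, `Tensor::empty` elsewhere), on the assumption that a cache
    /// block is always written before it is read.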
    fn allocate_gpu_cache(
        model_config: &dyn ModelConfigLike,
        cache_config: &CacheConfig,
        dtype: DType,
        device: &Device,
        layer_devices: Vec<Option<Device>>,
    ) -> Result<Vec<KVCache>> {
        let key_block_shape =
            Self::calculate_key_block_shape(model_config, dtype, cache_config.block_size);
        let value_block_shape =
            Self::calculate_value_block_shape(model_config, cache_config.block_size);
        let mut gpu_cache = Vec::new();

        for device in layer_devices
            .iter()
            .take(model_config.num_layers())
            .map(|x| x.as_ref().unwrap_or(device))
        {
            #[allow(unused)]
            let key_blocks = if let Device::Metal(dev) = &device {
                #[cfg(feature = "metal")]
                {
                    use candle_core::{from_storage_no_op, MetalStorage, Shape, Storage};

                    // Allocate a private (GPU-only) Metal buffer and wrap it in
                    // a tensor without zero-initializing it.
                    let elem_count = cache_config.num_gpu_blocks
                        * key_block_shape.0
                        * key_block_shape.1
                        * key_block_shape.2
                        * key_block_shape.3;
                    let buffer = dev.new_buffer_private(elem_count, dtype, "k_cache")?;
                    let storage =
                        Storage::Metal(MetalStorage::new(buffer, dev.clone(), elem_count, dtype));
                    from_storage_no_op(
                        storage,
                        Shape::from_dims(&[
                            cache_config.num_gpu_blocks,
                            key_block_shape.0,
                            key_block_shape.1,
                            key_block_shape.2,
                            key_block_shape.3,
                        ]),
                        false,
                    )
                }

                #[cfg(not(feature = "metal"))]
                {
                    unreachable!()
                }
            } else {
                // Non-Metal devices: allocate uninitialized memory directly.
                unsafe {
                    Tensor::empty(
                        (
                            cache_config.num_gpu_blocks,
                            key_block_shape.0,
                            key_block_shape.1,
                            key_block_shape.2,
                            key_block_shape.3,
                        ),
                        dtype,
                        device,
                    )?
                }
            };
            #[allow(unused)]
            let value_blocks = if let Device::Metal(dev) = &device {
                #[cfg(feature = "metal")]
                {
                    use candle_core::{from_storage_no_op, MetalStorage, Shape, Storage};

                    // Same private-buffer allocation as for the keys.
                    let elem_count = cache_config.num_gpu_blocks
                        * value_block_shape.0
                        * value_block_shape.1
                        * value_block_shape.2;
                    let buffer = dev.new_buffer_private(elem_count, dtype, "v_cache")?;
                    let storage =
                        Storage::Metal(MetalStorage::new(buffer, dev.clone(), elem_count, dtype));
                    from_storage_no_op(
                        storage,
                        Shape::from_dims(&[
                            cache_config.num_gpu_blocks,
                            value_block_shape.0,
                            value_block_shape.1,
                            value_block_shape.2,
                        ]),
                        false,
                    )
                }

                #[cfg(not(feature = "metal"))]
                {
                    unreachable!()
                }
            } else {
                unsafe {
                    Tensor::empty(
                        (
                            cache_config.num_gpu_blocks,
                            value_block_shape.0,
                            value_block_shape.1,
                            value_block_shape.2,
                        ),
                        dtype,
                        device,
                    )?
                }
            };
            gpu_cache.push((key_blocks, value_blocks));
        }
        Ok(gpu_cache)
    }

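    /// Per-block key shape `(num_kv_heads, k_head_dim / x, block_size, x)`;
    /// the allocation prepends `num_gpu_blocks`. `x` groups elements so the
    /// innermost axis spans 16 bytes, presumably to match the vectorized
    /// access pattern of the paged-attention kernels.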
    fn calculate_key_block_shape(
        model_config: &dyn ModelConfigLike,
        dtype: DType,
        block_size: usize,
    ) -> (usize, usize, usize, usize) {
        let element_size = dtype.size_in_bytes();
        // `x` elements of the head dimension together span 16 bytes.
        let x = 16 / element_size;
        (
            model_config.num_kv_heads(),
            model_config.k_head_dim() / x,
            block_size,
            x,
        )
    }

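    /// Per-block value shape `(num_kv_heads, v_head_dim, block_size)`; values
    /// are not packed the way keys are.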
    fn calculate_value_block_shape(
        model_config: &dyn ModelConfigLike,
        block_size: usize,
    ) -> (usize, usize, usize) {
        (
            model_config.num_kv_heads(),
            model_config.v_head_dim(),
            block_size,
        )
    }
}

impl CacheEngine {
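    /// Apply the scheduler's cache operations for this step; at present the
    /// only such operation is copying blocks.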
    pub fn execute_scheduler_ops(&self, blocks_to_copy: &HashMap<usize, Vec<usize>>) -> Result<()> {
        if !blocks_to_copy.is_empty() {
            self.copy(blocks_to_copy)?;
        }
        Ok(())
    }

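    /// Copy cache blocks on-device, e.g. when a shared (copy-on-write) block
    /// must be forked. `src_to_dst` maps each source block index to one or
    /// more destination block indices.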
    pub fn copy(&self, src_to_dst: &HashMap<usize, Vec<usize>>) -> Result<()> {
        #[cfg(any(all(feature = "cuda", target_family = "unix"), feature = "metal"))]
        {
            let mut gpu_cache = self.get_kv_cache();
            // Split the per-layer (key, value) pairs into separate vectors of
            // mutable references for the kernel call.
            #[allow(clippy::map_identity)]
            let caches: (Vec<&mut Tensor>, Vec<&mut Tensor>) =
                gpu_cache.iter_mut().map(|(a, b)| (a, b)).unzip();
            let (key_caches, value_caches) = caches;

            copy_blocks(key_caches, value_caches, src_to_dst)?;

            Ok(())
        }
        #[cfg(not(any(all(feature = "cuda", target_family = "unix"), feature = "metal")))]
        {
            let _ = src_to_dst;
            candle_core::bail!("Paged attention requires the CUDA or Metal feature flags.");
        }
    }
}