mistralrs_core/dummy_paged_attention/
cache_engine.rs1use std::{
2 collections::HashMap,
3 str::FromStr,
4 sync::{Arc, Mutex, MutexGuard},
5};
6
7use candle_core::{DType, Device, Result, Tensor};
8use serde::{Deserialize, Serialize};
9
10use super::config::ModelConfigLike;
11
12#[derive(Clone, Copy, Debug, Serialize, Deserialize, PartialEq, Default)]
13#[cfg_attr(feature = "pyo3_macros", pyo3::pyclass(eq, eq_int))]
14pub enum PagedCacheType {
15 #[default]
16 Auto,
17 F8E4M3,
18}
19
20impl PagedCacheType {
21 pub fn to_dtype(&self, act_dtype: DType) -> DType {
22 match self {
23 PagedCacheType::F8E4M3 => DType::F8E4M3,
24 PagedCacheType::Auto => act_dtype,
25 }
26 }
27}
28
29impl FromStr for PagedCacheType {
30 type Err = String;
31 fn from_str(s: &str) -> std::result::Result<Self, Self::Err> {
32 match s {
33 "auto" => Ok(Self::Auto),
34 "f8e4m3" => Ok(Self::F8E4M3),
35 other => Err(format!(
36 "Unexpected `PagedCacheType`, got `{other}` but expected `auto` and `f8e4m3`."
37 )),
38 }
39 }
40}
41
42#[derive(Clone, Debug)]
43pub struct CacheConfig {
44 pub block_size: usize,
45 pub num_gpu_blocks: usize,
46 pub num_cpu_blocks: usize,
47 pub cache_type: PagedCacheType,
48}
49
50pub type KVCache = (Tensor, Tensor);
51
52pub struct CacheEngine {
53 dummy_cache: Arc<Mutex<Vec<KVCache>>>,
54}
55
56impl CacheEngine {
57 pub fn new(
58 _model_config: &dyn ModelConfigLike,
59 _cache_config: &CacheConfig,
60 _dtype: DType,
61 _device: &Device,
62 _layer_devices: Vec<Option<Device>>,
63 ) -> Result<Self> {
64 Ok(Self {
65 dummy_cache: Arc::new(Mutex::new(Vec::new())),
66 })
67 }
68
69 pub fn get_kv_cache(&self) -> MutexGuard<'_, Vec<KVCache>> {
70 loop {
71 if let Ok(v) = self.dummy_cache.try_lock() {
72 return v;
73 }
74 }
75 }
76}
77
78impl CacheEngine {
79 pub fn execute_scheduler_ops(
80 &self,
81 blocks_to_swap_in: &HashMap<usize, usize>,
82 blocks_to_swap_out: &HashMap<usize, usize>,
83 blocks_to_copy: &HashMap<usize, Vec<usize>>,
84 ) -> Result<()> {
85 if !blocks_to_swap_in.is_empty() {
86 self.swap_in(blocks_to_swap_in)?;
87 }
88 if !blocks_to_swap_out.is_empty() {
89 self.swap_out(blocks_to_swap_out)?;
90 }
91 if !blocks_to_copy.is_empty() {
92 self.copy(blocks_to_copy)?;
93 }
94 Ok(())
95 }
96
97 pub fn swap_in(&self, _src_to_dst: &HashMap<usize, usize>) -> Result<()> {
98 Ok(())
99 }
100
101 pub fn swap_out(&self, _src_to_dst: &HashMap<usize, usize>) -> Result<()> {
102 Ok(())
103 }
104
105 pub fn copy(&self, _src_to_dst: &HashMap<usize, Vec<usize>>) -> Result<()> {
106 Ok(())
107 }
108}