// mistralrs_core/dummy_paged_attention/cache_engine.rs

1use std::{
2    collections::HashMap,
3    str::FromStr,
4    sync::{Arc, Mutex, MutexGuard},
5};
6
7use candle_core::{DType, Device, Result, Tensor};
8use serde::{Deserialize, Serialize};
9
10use super::config::ModelConfigLike;
11
12#[derive(Clone, Copy, Debug, Serialize, Deserialize, PartialEq, Default)]
13#[cfg_attr(feature = "pyo3_macros", pyo3::pyclass(eq, eq_int))]
14pub enum PagedCacheType {
15    #[default]
16    Auto,
17    F8E4M3,
18}
19
20impl PagedCacheType {
21    pub fn to_dtype(&self, act_dtype: DType) -> DType {
22        match self {
23            PagedCacheType::F8E4M3 => DType::F8E4M3,
24            PagedCacheType::Auto => act_dtype,
25        }
26    }
27}
28
29impl FromStr for PagedCacheType {
30    type Err = String;
31    fn from_str(s: &str) -> std::result::Result<Self, Self::Err> {
32        match s {
33            "auto" => Ok(Self::Auto),
34            "f8e4m3" => Ok(Self::F8E4M3),
35            other => Err(format!(
36                "Unexpected `PagedCacheType`, got `{other}` but expected `auto` and `f8e4m3`."
37            )),
38        }
39    }
40}
41
42#[derive(Clone, Debug)]
43pub struct CacheConfig {
44    pub block_size: usize,
45    pub num_gpu_blocks: usize,
46    pub num_cpu_blocks: usize,
47    pub cache_type: PagedCacheType,
48}
49
50pub type KVCache = (Tensor, Tensor);
51
52pub struct CacheEngine {
53    dummy_cache: Arc<Mutex<Vec<KVCache>>>,
54}
55
56impl CacheEngine {
57    pub fn new(
58        _model_config: &dyn ModelConfigLike,
59        _cache_config: &CacheConfig,
60        _dtype: DType,
61        _device: &Device,
62        _layer_devices: Vec<Option<Device>>,
63    ) -> Result<Self> {
64        Ok(Self {
65            dummy_cache: Arc::new(Mutex::new(Vec::new())),
66        })
67    }
68
69    pub fn get_kv_cache(&self) -> MutexGuard<'_, Vec<KVCache>> {
70        loop {
71            if let Ok(v) = self.dummy_cache.try_lock() {
72                return v;
73            }
74        }
75    }
76}
77
78impl CacheEngine {
79    pub fn execute_scheduler_ops(
80        &self,
81        blocks_to_swap_in: &HashMap<usize, usize>,
82        blocks_to_swap_out: &HashMap<usize, usize>,
83        blocks_to_copy: &HashMap<usize, Vec<usize>>,
84    ) -> Result<()> {
85        if !blocks_to_swap_in.is_empty() {
86            self.swap_in(blocks_to_swap_in)?;
87        }
88        if !blocks_to_swap_out.is_empty() {
89            self.swap_out(blocks_to_swap_out)?;
90        }
91        if !blocks_to_copy.is_empty() {
92            self.copy(blocks_to_copy)?;
93        }
94        Ok(())
95    }
96
97    pub fn swap_in(&self, _src_to_dst: &HashMap<usize, usize>) -> Result<()> {
98        Ok(())
99    }
100
101    pub fn swap_out(&self, _src_to_dst: &HashMap<usize, usize>) -> Result<()> {
102        Ok(())
103    }
104
105    pub fn copy(&self, _src_to_dst: &HashMap<usize, Vec<usize>>) -> Result<()> {
106        Ok(())
107    }
108}