// mistralrs_core/prefix_cacher.rs

1use std::collections::HashMap;
2
3use candle_core::{Device, Result};
4use either::Either;
5use itertools::Itertools;
6use tracing::info;
7
8use crate::{
9    pipeline::{KvCache, RotatingCache, SingleCache},
10    sequence::Sequence,
11};
12
/// A token sequence, used as the key under which a prefix cache is stored.
#[derive(Debug, PartialEq, Eq, Hash)]
struct Tokens(Vec<u32>);

impl Tokens {
    /// Index of the last token in the common leading run of `self` and `x`.
    ///
    /// Equivalently, the length of the shared prefix minus one. Returns `None` when the
    /// first tokens already differ or when either sequence is empty.
    fn find_max_index(&self, x: &Self) -> Option<usize> {
        let (ours, theirs) = (&self.0, &x.0);
        // Number of leading tokens the two sequences have in common.
        let shared_len = ours
            .iter()
            .zip(theirs)
            .position(|(a, b)| a != b)
            .unwrap_or_else(|| ours.len().min(theirs.len()));
        shared_len.checked_sub(1)
    }
}

impl From<Vec<u32>> for Tokens {
    fn from(toks: Vec<u32>) -> Self {
        Tokens(toks)
    }
}
34
35#[derive(Clone)]
36struct CacheElement {
37    cache: Vec<Option<KvCache>>,
38    devices: Vec<Option<Device>>,
39}
40
41pub struct PrefixCacheManagerV2 {
42    caches: HashMap<Tokens, CacheElement>,
43    n_on_device: usize,
44    no_prefix_cache: bool,
45}
46
47#[derive(Clone)]
48pub struct MatchingCache {
49    pub normal: Vec<Option<KvCache>>,
50    pub toks: Vec<u32>,
51    pub offset: usize,
52}
53
54impl PrefixCacheManagerV2 {
55    pub fn new(n_on_device: usize, no_prefix_cache: bool) -> Self {
56        if !no_prefix_cache {
57            info!("PrefixCacherV2 is enabled! Expect higher multi-turn prompt throughput.");
58        }
59        PrefixCacheManagerV2 {
60            caches: HashMap::new(),
61            n_on_device,
62            no_prefix_cache,
63        }
64    }
65
66    /// This always keeps the cache on the device.
67    pub fn add_sequence(&mut self, seq: &mut Sequence) {
68        if self.no_prefix_cache || seq.has_images() {
69            return;
70        }
71        let cache = seq.normal_cache().to_vec();
72        let devices = cache
73            .iter()
74            .map(|x| x.as_ref().map(|x| x.k().unwrap().unwrap().device().clone()))
75            .collect::<Vec<_>>();
76        self.caches.insert(
77            seq.get_toks().to_vec().into(),
78            CacheElement { cache, devices },
79        );
80    }
81
82    fn cache_to(
83        cache: &mut [Option<KvCache>],
84        devices: Either<&Device, &Vec<Option<Device>>>,
85    ) -> Result<()> {
86        for (i, layer) in cache
87            .iter_mut()
88            .enumerate()
89            .flat_map(|(i, x)| x.as_mut().map(|x| (i, x)))
90        {
91            let device = devices.left_or_else(|layers| layers[i].as_ref().unwrap());
92
93            match layer {
94                KvCache::Normal { k, v } => {
95                    *layer = KvCache::Normal {
96                        k: SingleCache {
97                            all_data: k.all_data.as_ref().map(|x| x.to_device(device).unwrap()),
98                            dim: k.dim,
99                            current_seq_len: k.current_seq_len,
100                            max_seq_len: k.max_seq_len,
101                            capacity_seq_len: k.capacity_seq_len,
102                        },
103                        v: SingleCache {
104                            all_data: v.all_data.as_ref().map(|x| x.to_device(device).unwrap()),
105                            dim: v.dim,
106                            current_seq_len: v.current_seq_len,
107                            max_seq_len: v.max_seq_len,
108                            capacity_seq_len: v.capacity_seq_len,
109                        },
110                    }
111                }
112                KvCache::Rotating { k, v } => {
113                    *layer = KvCache::Rotating {
114                        k: RotatingCache {
115                            all_data: k.all_data.as_ref().map(|x| x.to_device(device).unwrap()),
116                            dim: k.dim,
117                            current_seq_len: k.current_seq_len,
118                            max_seq_len: k.max_seq_len,
119                            offset: k.offset,
120                            capacity_seq_len: k.capacity_seq_len,
121                        },
122                        v: RotatingCache {
123                            all_data: v.all_data.as_ref().map(|x| x.to_device(device).unwrap()),
124                            dim: v.dim,
125                            current_seq_len: v.current_seq_len,
126                            max_seq_len: v.max_seq_len,
127                            offset: v.offset,
128                            capacity_seq_len: v.capacity_seq_len,
129                        },
130                    }
131                }
132            }
133        }
134        Ok(())
135    }
136
137    /// Evict the caches to CPU. This will evict the first k seqs such that the number of sequences on device after the copy is
138    /// the maximum allowed. Returns the number of evicted sequences.
139    pub fn evict_to_cpu(&mut self) -> Result<usize> {
140        if self.no_prefix_cache {
141            return Ok(0);
142        }
143        let mut n_on_device = 0;
144        for cache in self.caches.values() {
145            let first_non_none = cache.cache.iter().find_or_first(|x| x.is_some());
146            let Some(Some(first_non_none)) = first_non_none else {
147                continue;
148            };
149
150            let cache_device = match first_non_none {
151                KvCache::Normal { k, .. } => {
152                    k.all_data().as_ref().expect("No KV cache data").device()
153                }
154                KvCache::Rotating { k, .. } => {
155                    k.all_data().as_ref().expect("No KV cache data").device()
156                }
157            };
158
159            if !matches!(cache_device, Device::Cpu) {
160                n_on_device += 1;
161            }
162        }
163        let mut n_evicted = 0;
164        // Intentionally evict the first ones first, as they are the oldest
165        for cache in self.caches.values_mut() {
166            if n_on_device - n_evicted == self.n_on_device {
167                break;
168            }
169            let first_non_none = cache.cache.iter().find_or_first(|x| x.is_some());
170            let Some(Some(first_non_none)) = first_non_none else {
171                continue;
172            };
173
174            let cache_device = match first_non_none {
175                KvCache::Normal { k, .. } => {
176                    k.all_data().as_ref().expect("No KV cache data").device()
177                }
178                KvCache::Rotating { k, .. } => {
179                    k.all_data().as_ref().expect("No KV cache data").device()
180                }
181            };
182
183            if !matches!(cache_device, Device::Cpu) {
184                Self::cache_to(&mut cache.cache, Either::Left(&Device::Cpu))?;
185                n_evicted += 1;
186            }
187        }
188        Ok(self.caches.len().saturating_sub(self.n_on_device))
189    }
190
191    /// Evict all the caches to CPU.
192    pub fn evict_all_to_cpu(&mut self) -> Result<usize> {
193        if self.no_prefix_cache {
194            return Ok(0);
195        }
196        // Intentionally evict the first ones first, as they are the oldest
197        for cache in self.caches.values_mut() {
198            let first_non_none = cache.cache.iter().find_or_first(|x| x.is_some());
199            let Some(Some(first_non_none)) = first_non_none else {
200                continue;
201            };
202
203            let cache_device = match first_non_none {
204                KvCache::Normal { k, .. } => {
205                    k.all_data().as_ref().expect("No KV cache data").device()
206                }
207                KvCache::Rotating { k, .. } => {
208                    k.all_data().as_ref().expect("No KV cache data").device()
209                }
210            };
211
212            if !matches!(cache_device, Device::Cpu) {
213                Self::cache_to(&mut cache.cache, Either::Left(&Device::Cpu))?;
214            }
215        }
216        Ok(self.caches.len())
217    }
218
219    /// Search for a matching cache given some toks
220    pub fn search_for_matching_cache(
221        &mut self,
222        toks: &[u32],
223        contains_images: bool,
224    ) -> Result<Option<MatchingCache>> {
225        if self.no_prefix_cache || toks.is_empty() || contains_images {
226            return Ok(None);
227        }
228
229        let toks = Tokens(toks.to_vec());
230
231        let mut longest_match = (0, None);
232        for (k, v) in self.caches.iter() {
233            let match_len = toks.find_max_index(k);
234            if let Some(match_len) = match_len {
235                if match_len > longest_match.0 {
236                    longest_match = (match_len, Some(v));
237                }
238            }
239        }
240        if let (match_len, Some(longest_match)) = longest_match {
241            let mut cache = longest_match.clone();
242            Self::cache_to(&mut cache.cache, Either::Right(&cache.devices))?;
243            for layer in cache.cache.iter_mut().flatten() {
244                match layer.set_len(match_len) {
245                    Ok(_) => (),
246                    Err(_) => return Ok(None),
247                }
248            }
249            Ok(Some(MatchingCache {
250                normal: cache.cache,
251                toks: toks.0[match_len..].to_vec(),
252                offset: match_len,
253            }))
254        } else {
255            Ok(None)
256        }
257    }
258}