mistralrs_core/
prefix_cacher.rs

1use std::{
2    hash::{DefaultHasher, Hash, Hasher},
3    sync::Arc,
4};
5
6use candle_core::{Device, Result};
7use indexmap::IndexMap;
8use itertools::Itertools;
9use tracing::info;
10
11use crate::{
12    get_mut_arcmutex,
13    paged_attention::{BlockEngine, LogicalTokenBlock, PhysicalTokenBlock},
14    pipeline::KvCache,
15    sequence::{self, Sequence},
16};
17
18type BlockBestMatch<'a> = (
19    usize,                         // matched_len
20    &'a [LogicalTokenBlock],       // logical blocks
21    &'a [Arc<PhysicalTokenBlock>], // physical blocks
22    usize,                         // audios_match_until
23    usize,                         // images_match_until
24);
25
26fn hash_logical_blocks(logical_blocks: &[LogicalTokenBlock]) -> Vec<u64> {
27    logical_blocks
28        .iter()
29        .map(|block| {
30            let mut hasher = DefaultHasher::new();
31            block.hash(&mut hasher);
32            hasher.finish()
33        })
34        .collect::<Vec<_>>()
35}
36
/// Token-id sequence used as the key of the non-paged prefix cache.
#[derive(PartialEq, Eq, Debug, Hash)]
struct Tokens(Vec<u32>);

impl Tokens {
    /// Returns how many leading token ids `self` and `other` have in common.
    fn shared_prefix_len(&self, other: &Self) -> usize {
        let (lhs, rhs) = (self.0.as_slice(), other.0.as_slice());
        lhs.iter()
            .zip(rhs)
            // Index of the first differing pair is the shared prefix length.
            .position(|(a, b)| a != b)
            // No difference found: the shorter sequence is entirely shared.
            .unwrap_or_else(|| lhs.len().min(rhs.len()))
    }
}

impl From<Vec<u32>> for Tokens {
    /// Wraps the token ids without copying them.
    fn from(ids: Vec<u32>) -> Self {
        Tokens(ids)
    }
}
56
57#[derive(Clone)]
58struct CacheElement {
59    cache: Vec<Option<KvCache>>,
60    audio_hashes: Option<Vec<u64>>,
61    image_hashes: Option<Vec<u64>>,
62}
63
64#[derive(Clone)]
65struct BlockCacheElement {
66    logical_blocks: Vec<LogicalTokenBlock>,
67    physical_blocks: Vec<Arc<PhysicalTokenBlock>>,
68    image_hashes: Option<Vec<u64>>,
69    audio_hashes: Option<Vec<u64>>,
70}
71
72pub struct PrefixCacheManagerV2 {
73    caches: IndexMap<Tokens, CacheElement>,
74    block_caches: IndexMap<Vec<u64>, BlockCacheElement>, // (hashed logical blocks) => BlockCacheElement
75    n_on_device: usize,
76    no_prefix_cache: bool,
77    block_engine: Option<Arc<tokio::sync::Mutex<BlockEngine>>>,
78}
79
80#[derive(Clone)]
81pub enum MatchingCache {
82    Normal {
83        normal: Vec<Option<KvCache>>,
84        images_to_keep: usize,
85        audios_to_keep: usize,
86        toks: Vec<u32>,
87        offset: usize,
88    },
89    Paged {
90        logical_blocks: Vec<LogicalTokenBlock>,
91        physical_blocks: Vec<Arc<PhysicalTokenBlock>>,
92        toks: Vec<u32>,
93        offset: usize,
94        images_to_keep: usize,
95        audios_to_keep: usize,
96    },
97}
98
99impl PrefixCacheManagerV2 {
100    pub fn new(
101        n_on_device: usize,
102        no_prefix_cache: bool,
103        block_engine: Option<Arc<tokio::sync::Mutex<BlockEngine>>>,
104    ) -> Self {
105        if !no_prefix_cache {
106            info!("PrefixCacherV2 is enabled. Expect higher multi-turn throughput for both text and multimodal.");
107        }
108        PrefixCacheManagerV2 {
109            caches: IndexMap::new(),
110            block_caches: IndexMap::new(),
111            n_on_device,
112            no_prefix_cache,
113            block_engine,
114        }
115    }
116
117    /// This always keeps the cache on the device.
118    pub fn add_sequence(&mut self, seq: &mut Sequence) {
119        // Do not cache if prefix caching disabled
120        if self.no_prefix_cache {
121            return;
122        }
123
124        if let Some(_block_engine) = &self.block_engine {
125            // let logical_token_blocks = seq.logical_token_blocks();
126            // let block_engine = get_mut_arcmutex!(block_engine);
127            // let block_table = &block_engine.block_tables[seq.id()];
128            // let hashed_logical_blocks = hash_logical_blocks(logical_token_blocks);
129
130            // if !self.block_caches.contains_key(&hashed_logical_blocks) {
131            //     for block in block_table {
132            //         block.deref_mut().increment_refcount();
133            //     }
134
135            //     self.block_caches.insert(
136            //         hashed_logical_blocks,
137            //         BlockCacheElement {
138            //             logical_blocks: logical_token_blocks.to_vec(),
139            //             physical_blocks: block_table.clone(),
140            //             image_hashes: seq.image_hashes().map(|x| x.to_vec()),
141            //             audio_hashes: seq.audio_hashes().map(|x| x.to_vec()),
142            //         },
143            //     );
144            // }
145        } else {
146            let cache = seq.normal_cache().to_vec();
147
148            self.caches.insert(
149                seq.get_toks().to_vec().into(),
150                CacheElement {
151                    cache,
152                    image_hashes: seq.image_hashes().map(|x| x.to_vec()),
153                    audio_hashes: seq.audio_hashes().map(|x| x.to_vec()),
154                },
155            );
156        }
157    }
158
159    /// Evict the caches. This will evict the first k seqs such that the number of sequences on device after the copy is
160    /// the maximum allowed. Returns the number of evicted sequences.
161    pub fn evict_caches(&mut self) -> Result<usize> {
162        if self.no_prefix_cache {
163            return Ok(0);
164        }
165        let mut n_on_device = 0;
166        for cache in self.caches.values() {
167            let first_non_none = cache.cache.iter().find_or_first(|x| x.is_some());
168            let Some(Some(first_non_none)) = first_non_none else {
169                continue;
170            };
171
172            let cache_device = match first_non_none {
173                KvCache::Normal { k, .. } => {
174                    k.all_data().as_ref().expect("No KV cache data").device()
175                }
176                KvCache::Rotating { k, .. } => {
177                    k.all_data().as_ref().expect("No KV cache data").device()
178                }
179            };
180
181            if !matches!(cache_device, Device::Cpu) {
182                n_on_device += 1;
183            }
184        }
185        // Count block‑caches that still reside on‑device.
186        for cache in self.block_caches.values() {
187            if !cache.physical_blocks.is_empty() {
188                n_on_device += 1;
189            }
190        }
191        let mut n_evicted = 0;
192        // Intentionally evict the first ones first, as they are the oldest
193        for cache in self.caches.values_mut() {
194            if n_on_device - n_evicted <= self.n_on_device {
195                break;
196            }
197            let first_non_none = cache.cache.iter().find_or_first(|x| x.is_some());
198            let Some(Some(first_non_none)) = first_non_none else {
199                continue;
200            };
201
202            let cache_device = match first_non_none {
203                KvCache::Normal { k, .. } => {
204                    k.all_data().as_ref().expect("No KV cache data").device()
205                }
206                KvCache::Rotating { k, .. } => {
207                    k.all_data().as_ref().expect("No KV cache data").device()
208                }
209            };
210
211            if !matches!(cache_device, Device::Cpu) {
212                cache.cache.clear();
213                n_evicted += 1;
214            }
215        }
216
217        // Now evict block‑caches if we still exceed the on‑device limit.
218        for cache in self.block_caches.values_mut() {
219            if n_on_device - n_evicted <= self.n_on_device {
220                break;
221            }
222            if !cache.physical_blocks.is_empty() {
223                // Drop our strong references and decrement ref‑counts so the
224                // BlockEngine can reclaim the KV blocks.
225                for block in &cache.physical_blocks {
226                    block.deref_mut().decrement_refcount();
227                }
228                cache.physical_blocks.clear();
229                n_evicted += 1;
230            }
231        }
232
233        self.caches.retain(|_tokens, cache| !cache.cache.is_empty());
234        self.block_caches
235            .retain(|_key, cache| !cache.physical_blocks.is_empty());
236
237        Ok(n_evicted)
238    }
239
240    /// Evict all the caches.
241    pub fn evict_all_caches(&mut self) -> Result<usize> {
242        let len = self.caches.len();
243
244        self.caches.clear();
245
246        for cache in self.block_caches.values_mut() {
247            for block in &cache.physical_blocks {
248                block.deref_mut().decrement_refcount();
249            }
250        }
251        self.block_caches.clear();
252        Ok(len)
253    }
254
255    /// Search for a matching cache given some tokens. Image-containing sequences are now cached too.
256    pub fn search_for_matching_cache(
257        &mut self,
258        toks: &[u32],
259        image_hashes: Option<&[u64]>,
260        audio_hashes: Option<&[u64]>,
261    ) -> Result<Option<MatchingCache>> {
262        // Do not search if prefix caching disabled or no tokens
263        if self.no_prefix_cache || toks.is_empty() {
264            return Ok(None);
265        }
266
267        if let Some(block_engine) = &self.block_engine {
268            let block_engine = get_mut_arcmutex!(block_engine);
269            let block_size = block_engine.block_size();
270            let mut test_logical_blocks = Vec::new();
271            for tok in toks {
272                sequence::util_append_token_to_blocks(
273                    *tok as usize,
274                    &mut test_logical_blocks,
275                    block_size,
276                );
277            }
278            let hashed_logical_blocks = hash_logical_blocks(&test_logical_blocks);
279
280            let mut best_match: Option<BlockBestMatch> = None;
281            for (logical, cache_elem) in &self.block_caches {
282                let logical_matches_until = logical
283                    .iter()
284                    .zip(&hashed_logical_blocks)
285                    .take_while(|(a, b)| **a == **b)
286                    .count();
287                let matched_len: usize = cache_elem.logical_blocks[..logical_matches_until]
288                    .iter()
289                    .map(|block| block.num_tokens())
290                    .sum();
291
292                let images_match_until = if let (Some(input_hashes), Some(cached_hashes)) =
293                    (image_hashes, cache_elem.image_hashes.as_ref())
294                {
295                    input_hashes
296                        .iter()
297                        .zip(cached_hashes)
298                        .take_while(|(a, b)| a == b)
299                        .count()
300                } else {
301                    0
302                };
303
304                let audios_match_until = if let (Some(input_hashes), Some(cached_hashes)) =
305                    (audio_hashes, cache_elem.audio_hashes.as_ref())
306                {
307                    input_hashes
308                        .iter()
309                        .zip(cached_hashes)
310                        .take_while(|(a, b)| a == b)
311                        .count()
312                } else {
313                    0
314                };
315
316                if best_match
317                    .is_some_and(|(best_match_len, _, _, _, _)| best_match_len < matched_len)
318                    || best_match.is_none()
319                {
320                    best_match = Some((
321                        matched_len,
322                        &cache_elem.logical_blocks,
323                        &cache_elem.physical_blocks,
324                        images_match_until,
325                        audios_match_until,
326                    ))
327                }
328            }
329
330            let Some((
331                match_len,
332                logical_blocks,
333                physical_blocks,
334                images_match_until,
335                audios_match_until,
336            )) = best_match
337            else {
338                return Ok(None);
339            };
340
341            // Determine how many blocks cover the matched prefix
342            let mut n_blocks = match_len.div_ceil(block_size);
343            n_blocks = n_blocks.min(logical_blocks.len());
344
345            if n_blocks == 0 {
346                return Ok(None);
347            }
348
349            // Take the first n_blocks of both logical and physical blocks
350            let mut logical_prefix = logical_blocks[..n_blocks].to_vec();
351            let physical_prefix = physical_blocks[..n_blocks].to_vec();
352            for block in &physical_prefix {
353                block.deref_mut().increment_refcount();
354            }
355
356            // If the last reused block is full, reserve an extra empty block for new tokens
357            let new_toks = toks[match_len..].to_vec();
358            logical_prefix.push(LogicalTokenBlock::new(block_size));
359            for tok in &new_toks {
360                sequence::util_append_token_to_blocks(
361                    *tok as usize,
362                    &mut logical_prefix,
363                    block_size,
364                );
365            }
366            if logical_prefix.last().is_some_and(|last| last.is_full()) {
367                logical_prefix.push(LogicalTokenBlock::new(block_size));
368            }
369            let images_to_keep = if let Some(input_hashes) = image_hashes {
370                input_hashes.len().saturating_sub(images_match_until)
371            } else {
372                0
373            };
374            let audios_to_keep = if let Some(input_hashes) = audio_hashes {
375                input_hashes.len().saturating_sub(audios_match_until)
376            } else {
377                0
378            };
379            return Ok(Some(MatchingCache::Paged {
380                logical_blocks: logical_prefix,
381                physical_blocks: physical_prefix,
382                toks: new_toks,
383                offset: match_len,
384                images_to_keep,
385                audios_to_keep,
386            }));
387        }
388
389        let toks = Tokens(toks.to_vec());
390
391        let mut best_match: Option<(usize, &CacheElement, usize, usize)> = None;
392        for (k, v) in &self.caches {
393            let match_len = toks.shared_prefix_len(k);
394            if match_len == 0 {
395                continue;
396            }
397
398            let images_match_until = match image_hashes {
399                Some(input_hashes) => match &v.image_hashes {
400                    Some(cached_hashes) => input_hashes
401                        .iter()
402                        .zip(cached_hashes)
403                        .take_while(|(a, b)| a == b)
404                        .count(),
405                    None => 0,
406                },
407                None => 0,
408            };
409
410            let audios_match_until = match audio_hashes {
411                Some(input_hashes) => match &v.audio_hashes {
412                    Some(cached_hashes) => input_hashes
413                        .iter()
414                        .zip(cached_hashes)
415                        .take_while(|(a, b)| a == b)
416                        .count(),
417                    None => 0,
418                },
419                None => 0,
420            };
421
422            if best_match
423                .as_ref()
424                .is_none_or(|(len, _, _, _)| match_len > *len)
425            {
426                best_match = Some((match_len, v, images_match_until, audios_match_until));
427            }
428        }
429
430        if let Some((match_len, cache_element, images_match_until, audios_match_until)) = best_match
431        {
432            let new_toks = toks.0[match_len..].to_vec();
433            if new_toks.is_empty() {
434                return Ok(None);
435            }
436
437            let mut cache = cache_element.clone();
438            // Count how many input images are not already cached
439            let images_to_keep = if let Some(input_hashes) = image_hashes {
440                input_hashes.len().saturating_sub(images_match_until)
441            } else {
442                0
443            };
444            let audios_to_keep = if let Some(input_hashes) = audio_hashes {
445                input_hashes.len().saturating_sub(audios_match_until)
446            } else {
447                0
448            };
449            for layer in cache.cache.iter_mut().flatten() {
450                if layer.try_set_len(match_len).is_err() {
451                    return Ok(None);
452                }
453            }
454            for layer in cache.cache.iter_mut().flatten() {
455                layer.set_len(match_len)?;
456            }
457            return Ok(Some(MatchingCache::Normal {
458                normal: cache.cache,
459                images_to_keep,
460                audios_to_keep,
461                toks: new_toks,
462                offset: match_len,
463            }));
464        }
465
466        Ok(None)
467    }
468}