1use std::{
2 hash::{DefaultHasher, Hash, Hasher},
3 sync::Arc,
4};
5
6use candle_core::{Device, Result};
7use indexmap::IndexMap;
8use itertools::Itertools;
9use tracing::info;
10
11use crate::{
12 get_mut_arcmutex,
13 paged_attention::{BlockEngine, LogicalTokenBlock, PhysicalTokenBlock},
14 pipeline::KvCache,
15 sequence::{self, Sequence},
16};
17
/// Best paged-attention prefix match found so far while scanning the block caches
/// in `search_for_matching_cache`.
type BlockBestMatch<'a> = (
    usize,                         // number of input tokens covered by the matched blocks
    &'a [LogicalTokenBlock],       // logical blocks of the matched cache entry
    &'a [Arc<PhysicalTokenBlock>], // physical blocks of the matched cache entry
    usize,                         // count of leading image hashes shared with the input
    usize,                         // count of leading audio hashes shared with the input
);
25
26fn hash_logical_blocks(logical_blocks: &[LogicalTokenBlock]) -> Vec<u64> {
27 logical_blocks
28 .iter()
29 .map(|block| {
30 let mut hasher = DefaultHasher::new();
31 block.hash(&mut hasher);
32 hasher.finish()
33 })
34 .collect::<Vec<_>>()
35}
36
/// Token ids of a cached sequence; the key type of the normal (non-paged) cache map.
#[derive(PartialEq, Eq, Debug, Hash)]
struct Tokens(Vec<u32>);

impl Tokens {
    /// Length of the longest common prefix of `self` and `other`.
    ///
    /// This is the index of the first position where the two sequences differ. If no
    /// position differs, it is the length of the shorter sequence.
    fn shared_prefix_len(&self, other: &Self) -> usize {
        let (lhs, rhs) = (&self.0, &other.0);
        lhs.iter()
            .zip(rhs.iter())
            .position(|(a, b)| a != b)
            .unwrap_or_else(|| lhs.len().min(rhs.len()))
    }
}

impl From<Vec<u32>> for Tokens {
    /// Wraps the token ids without copying them.
    fn from(ids: Vec<u32>) -> Self {
        Tokens(ids)
    }
}
56
/// A normal (non-paged) prefix-cache entry, stored by `add_sequence`.
#[derive(Clone)]
struct CacheElement {
    /// Copy of the sequence's KV caches; `None` slots are skipped when the cache is reused.
    cache: Vec<Option<KvCache>>,
    /// Audio hashes of the sequence, compared element-wise with a new request's hashes.
    audio_hashes: Option<Vec<u64>>,
    /// Image hashes of the sequence, compared element-wise with a new request's hashes.
    image_hashes: Option<Vec<u64>>,
}
63
/// A paged-attention prefix-cache entry.
#[derive(Clone)]
struct BlockCacheElement {
    /// Logical token blocks of the cached sequence.
    logical_blocks: Vec<LogicalTokenBlock>,
    /// Physical blocks paired index-for-index with `logical_blocks`; the search slices
    /// both to the same prefix length, and eviction decrements each refcount and clears it.
    physical_blocks: Vec<Arc<PhysicalTokenBlock>>,
    /// Image hashes of the cached sequence, if any.
    image_hashes: Option<Vec<u64>>,
    /// Audio hashes of the cached sequence, if any.
    audio_hashes: Option<Vec<u64>>,
}
71
/// Prefix cache that lets new requests reuse KV state from previously added sequences.
///
/// Two stores are kept: `caches` for the normal KV-cache path and `block_caches` for the
/// paged-attention path, which is used when `block_engine` is set.
pub struct PrefixCacheManagerV2 {
    /// Normal-path entries keyed by the full token sequence.
    caches: IndexMap<Tokens, CacheElement>,
    /// Paged-path entries keyed by the per-block hashes of their logical blocks.
    block_caches: IndexMap<Vec<u64>, BlockCacheElement>,
    /// Number of device-resident entries `evict_caches` is allowed to leave in place.
    n_on_device: usize,
    /// When `true`, nothing is cached and every lookup misses.
    no_prefix_cache: bool,
    /// Block engine for paged attention; `None` selects the normal KV-cache path.
    block_engine: Option<Arc<tokio::sync::Mutex<BlockEngine>>>,
}
79
/// Result of a successful prefix-cache lookup.
#[derive(Clone)]
pub enum MatchingCache {
    /// Hit on the normal KV-cache path.
    Normal {
        /// Copy of the cached KV state, truncated to `offset` tokens.
        normal: Vec<Option<KvCache>>,
        /// Input image hashes not covered by the cache's leading matching image hashes.
        images_to_keep: usize,
        /// Input audio hashes not covered by the cache's leading matching audio hashes.
        audios_to_keep: usize,
        /// Input tokens after the matched prefix, which still need processing.
        toks: Vec<u32>,
        /// Number of input tokens covered by the cache.
        offset: usize,
    },
    /// Hit on the paged-attention path.
    Paged {
        /// Reused logical blocks followed by fresh blocks holding `toks`.
        logical_blocks: Vec<LogicalTokenBlock>,
        /// Reused physical blocks; the lookup increments each one's refcount.
        physical_blocks: Vec<Arc<PhysicalTokenBlock>>,
        /// Input tokens after the matched prefix, which still need processing.
        toks: Vec<u32>,
        /// Number of input tokens covered by the reused blocks.
        offset: usize,
        /// Input image hashes not covered by the cache's leading matching image hashes.
        images_to_keep: usize,
        /// Input audio hashes not covered by the cache's leading matching audio hashes.
        audios_to_keep: usize,
    },
}
98
99impl PrefixCacheManagerV2 {
100 pub fn new(
101 n_on_device: usize,
102 no_prefix_cache: bool,
103 block_engine: Option<Arc<tokio::sync::Mutex<BlockEngine>>>,
104 ) -> Self {
105 if !no_prefix_cache {
106 info!("PrefixCacherV2 is enabled. Expect higher multi-turn throughput for both text and multimodal.");
107 }
108 PrefixCacheManagerV2 {
109 caches: IndexMap::new(),
110 block_caches: IndexMap::new(),
111 n_on_device,
112 no_prefix_cache,
113 block_engine,
114 }
115 }
116
117 pub fn add_sequence(&mut self, seq: &mut Sequence) {
119 if self.no_prefix_cache {
121 return;
122 }
123
124 if let Some(_block_engine) = &self.block_engine {
125 } else {
146 let cache = seq.normal_cache().to_vec();
147
148 self.caches.insert(
149 seq.get_toks().to_vec().into(),
150 CacheElement {
151 cache,
152 image_hashes: seq.image_hashes().map(|x| x.to_vec()),
153 audio_hashes: seq.audio_hashes().map(|x| x.to_vec()),
154 },
155 );
156 }
157 }
158
159 pub fn evict_caches(&mut self) -> Result<usize> {
162 if self.no_prefix_cache {
163 return Ok(0);
164 }
165 let mut n_on_device = 0;
166 for cache in self.caches.values() {
167 let first_non_none = cache.cache.iter().find_or_first(|x| x.is_some());
168 let Some(Some(first_non_none)) = first_non_none else {
169 continue;
170 };
171
172 let cache_device = match first_non_none {
173 KvCache::Normal { k, .. } => {
174 k.all_data().as_ref().expect("No KV cache data").device()
175 }
176 KvCache::Rotating { k, .. } => {
177 k.all_data().as_ref().expect("No KV cache data").device()
178 }
179 };
180
181 if !matches!(cache_device, Device::Cpu) {
182 n_on_device += 1;
183 }
184 }
185 for cache in self.block_caches.values() {
187 if !cache.physical_blocks.is_empty() {
188 n_on_device += 1;
189 }
190 }
191 let mut n_evicted = 0;
192 for cache in self.caches.values_mut() {
194 if n_on_device - n_evicted <= self.n_on_device {
195 break;
196 }
197 let first_non_none = cache.cache.iter().find_or_first(|x| x.is_some());
198 let Some(Some(first_non_none)) = first_non_none else {
199 continue;
200 };
201
202 let cache_device = match first_non_none {
203 KvCache::Normal { k, .. } => {
204 k.all_data().as_ref().expect("No KV cache data").device()
205 }
206 KvCache::Rotating { k, .. } => {
207 k.all_data().as_ref().expect("No KV cache data").device()
208 }
209 };
210
211 if !matches!(cache_device, Device::Cpu) {
212 cache.cache.clear();
213 n_evicted += 1;
214 }
215 }
216
217 for cache in self.block_caches.values_mut() {
219 if n_on_device - n_evicted <= self.n_on_device {
220 break;
221 }
222 if !cache.physical_blocks.is_empty() {
223 for block in &cache.physical_blocks {
226 block.deref_mut().decrement_refcount();
227 }
228 cache.physical_blocks.clear();
229 n_evicted += 1;
230 }
231 }
232
233 self.caches.retain(|_tokens, cache| !cache.cache.is_empty());
234 self.block_caches
235 .retain(|_key, cache| !cache.physical_blocks.is_empty());
236
237 Ok(n_evicted)
238 }
239
240 pub fn evict_all_caches(&mut self) -> Result<usize> {
242 let len = self.caches.len();
243
244 self.caches.clear();
245
246 for cache in self.block_caches.values_mut() {
247 for block in &cache.physical_blocks {
248 block.deref_mut().decrement_refcount();
249 }
250 }
251 self.block_caches.clear();
252 Ok(len)
253 }
254
255 pub fn search_for_matching_cache(
257 &mut self,
258 toks: &[u32],
259 image_hashes: Option<&[u64]>,
260 audio_hashes: Option<&[u64]>,
261 ) -> Result<Option<MatchingCache>> {
262 if self.no_prefix_cache || toks.is_empty() {
264 return Ok(None);
265 }
266
267 if let Some(block_engine) = &self.block_engine {
268 let block_engine = get_mut_arcmutex!(block_engine);
269 let block_size = block_engine.block_size();
270 let mut test_logical_blocks = Vec::new();
271 for tok in toks {
272 sequence::util_append_token_to_blocks(
273 *tok as usize,
274 &mut test_logical_blocks,
275 block_size,
276 );
277 }
278 let hashed_logical_blocks = hash_logical_blocks(&test_logical_blocks);
279
280 let mut best_match: Option<BlockBestMatch> = None;
281 for (logical, cache_elem) in &self.block_caches {
282 let logical_matches_until = logical
283 .iter()
284 .zip(&hashed_logical_blocks)
285 .take_while(|(a, b)| **a == **b)
286 .count();
287 let matched_len: usize = cache_elem.logical_blocks[..logical_matches_until]
288 .iter()
289 .map(|block| block.num_tokens())
290 .sum();
291
292 let images_match_until = if let (Some(input_hashes), Some(cached_hashes)) =
293 (image_hashes, cache_elem.image_hashes.as_ref())
294 {
295 input_hashes
296 .iter()
297 .zip(cached_hashes)
298 .take_while(|(a, b)| a == b)
299 .count()
300 } else {
301 0
302 };
303
304 let audios_match_until = if let (Some(input_hashes), Some(cached_hashes)) =
305 (audio_hashes, cache_elem.audio_hashes.as_ref())
306 {
307 input_hashes
308 .iter()
309 .zip(cached_hashes)
310 .take_while(|(a, b)| a == b)
311 .count()
312 } else {
313 0
314 };
315
316 if best_match
317 .is_some_and(|(best_match_len, _, _, _, _)| best_match_len < matched_len)
318 || best_match.is_none()
319 {
320 best_match = Some((
321 matched_len,
322 &cache_elem.logical_blocks,
323 &cache_elem.physical_blocks,
324 images_match_until,
325 audios_match_until,
326 ))
327 }
328 }
329
330 let Some((
331 match_len,
332 logical_blocks,
333 physical_blocks,
334 images_match_until,
335 audios_match_until,
336 )) = best_match
337 else {
338 return Ok(None);
339 };
340
341 let mut n_blocks = match_len.div_ceil(block_size);
343 n_blocks = n_blocks.min(logical_blocks.len());
344
345 if n_blocks == 0 {
346 return Ok(None);
347 }
348
349 let mut logical_prefix = logical_blocks[..n_blocks].to_vec();
351 let physical_prefix = physical_blocks[..n_blocks].to_vec();
352 for block in &physical_prefix {
353 block.deref_mut().increment_refcount();
354 }
355
356 let new_toks = toks[match_len..].to_vec();
358 logical_prefix.push(LogicalTokenBlock::new(block_size));
359 for tok in &new_toks {
360 sequence::util_append_token_to_blocks(
361 *tok as usize,
362 &mut logical_prefix,
363 block_size,
364 );
365 }
366 if logical_prefix.last().is_some_and(|last| last.is_full()) {
367 logical_prefix.push(LogicalTokenBlock::new(block_size));
368 }
369 let images_to_keep = if let Some(input_hashes) = image_hashes {
370 input_hashes.len().saturating_sub(images_match_until)
371 } else {
372 0
373 };
374 let audios_to_keep = if let Some(input_hashes) = audio_hashes {
375 input_hashes.len().saturating_sub(audios_match_until)
376 } else {
377 0
378 };
379 return Ok(Some(MatchingCache::Paged {
380 logical_blocks: logical_prefix,
381 physical_blocks: physical_prefix,
382 toks: new_toks,
383 offset: match_len,
384 images_to_keep,
385 audios_to_keep,
386 }));
387 }
388
389 let toks = Tokens(toks.to_vec());
390
391 let mut best_match: Option<(usize, &CacheElement, usize, usize)> = None;
392 for (k, v) in &self.caches {
393 let match_len = toks.shared_prefix_len(k);
394 if match_len == 0 {
395 continue;
396 }
397
398 let images_match_until = match image_hashes {
399 Some(input_hashes) => match &v.image_hashes {
400 Some(cached_hashes) => input_hashes
401 .iter()
402 .zip(cached_hashes)
403 .take_while(|(a, b)| a == b)
404 .count(),
405 None => 0,
406 },
407 None => 0,
408 };
409
410 let audios_match_until = match audio_hashes {
411 Some(input_hashes) => match &v.audio_hashes {
412 Some(cached_hashes) => input_hashes
413 .iter()
414 .zip(cached_hashes)
415 .take_while(|(a, b)| a == b)
416 .count(),
417 None => 0,
418 },
419 None => 0,
420 };
421
422 if best_match
423 .as_ref()
424 .is_none_or(|(len, _, _, _)| match_len > *len)
425 {
426 best_match = Some((match_len, v, images_match_until, audios_match_until));
427 }
428 }
429
430 if let Some((match_len, cache_element, images_match_until, audios_match_until)) = best_match
431 {
432 let new_toks = toks.0[match_len..].to_vec();
433 if new_toks.is_empty() {
434 return Ok(None);
435 }
436
437 let mut cache = cache_element.clone();
438 let images_to_keep = if let Some(input_hashes) = image_hashes {
440 input_hashes.len().saturating_sub(images_match_until)
441 } else {
442 0
443 };
444 let audios_to_keep = if let Some(input_hashes) = audio_hashes {
445 input_hashes.len().saturating_sub(audios_match_until)
446 } else {
447 0
448 };
449 for layer in cache.cache.iter_mut().flatten() {
450 if layer.try_set_len(match_len).is_err() {
451 return Ok(None);
452 }
453 }
454 for layer in cache.cache.iter_mut().flatten() {
455 layer.set_len(match_len)?;
456 }
457 return Ok(Some(MatchingCache::Normal {
458 normal: cache.cache,
459 images_to_keep,
460 audios_to_keep,
461 toks: new_toks,
462 offset: match_len,
463 }));
464 }
465
466 Ok(None)
467 }
468}