mistralrs_core/pipeline/
cache_manager.rs

use std::sync::{Arc, Mutex, MutexGuard};

use candle_core::{Result, Tensor, D};

use crate::{get_mut_arcmutex, sequence::Sequence};

use super::{CacheManagerMixin, MetadataMixin};

pub trait CacheManager<T: CacheManagerMixin + MetadataMixin + ?Sized> {
    fn clone_in_cache(
        &self,
        pipeline: &T,
        seqs: &mut [&mut crate::sequence::Sequence],
        modify_draft_cache: bool,
    );
    fn clone_out_cache(&self, pipeline: &T, seqs: &mut [&mut Sequence], modify_draft_cache: bool);
    fn set_none_cache(
        &self,
        pipeline: &T,
        seqs: &mut [&mut Sequence],
        modify_draft_cache: bool,
        load_preallocated_cache: bool,
    );
}

pub type LayerCaches = Vec<Option<(Tensor, Tensor)>>;

#[derive(Debug, Clone)]
pub enum EitherCache {
    Normal(Arc<Mutex<NormalCache>>),
    Full(Cache),
}

impl EitherCache {
    /// Returns the full (legacy) cache.
    ///
    /// Panics if this is a `Normal` cache.
    pub fn full(&self) -> &Cache {
        match self {
            Self::Full(full) => full,
            Self::Normal(_) => panic!("Got normal cache, expected full cache."),
        }
    }
    /// Returns a guard over the normal (per-layer `KvCache`) cache.
    ///
    /// Panics if this is a `Full` cache.
    pub fn normal(&self) -> MutexGuard<'_, NormalCache> {
        match self {
            Self::Normal(normal) => normal.lock().unwrap(),
            Self::Full(_) => panic!("Got full cache, expected normal cache."),
        }
    }
}

#[derive(Debug, Clone)]
pub struct SingleCache {
    // `all_data` is an `Option<Tensor>` so that the backing tensor can be created lazily on
    // the first call, when the batch size is easily known. This also makes it safe to clone a
    // KvCache that has been reset: the clone will not share internal state with the original.
    pub all_data: Option<Tensor>,
    pub dim: usize,
    pub current_seq_len: usize,
    pub capacity_seq_len: usize,
    pub max_seq_len: usize,
}

impl SingleCache {
    pub fn new(dim: usize, max_seq_len: usize, capacity_seq_len: usize) -> Self {
        Self {
            all_data: None,
            dim,
            current_seq_len: 0,
            max_seq_len,
            capacity_seq_len,
        }
    }

    pub fn dim(&self) -> usize {
        self.dim
    }

    pub fn current_seq_len(&self) -> usize {
        self.current_seq_len
    }

    pub fn max_seq_len(&self) -> usize {
        self.max_seq_len
    }

    pub fn all_data(&self) -> &Option<Tensor> {
        &self.all_data
    }

    pub fn current_data(&self) -> Result<Option<Tensor>> {
        let data = match self.all_data.as_ref() {
            None => None,
            Some(d) => Some(d.narrow(self.dim, 0, self.current_seq_len)?),
        };
        Ok(data)
    }

    pub fn reset(&mut self) {
        self.current_seq_len = 0;
        self.all_data = None;
    }

    pub fn set_len(&mut self, len: usize) -> candle_core::Result<()> {
        self.current_seq_len = len;
        Ok(())
    }

    pub fn append(&mut self, src: &Tensor) -> Result<()> {
        let seq_len = src.dim(self.dim)?;
        // This doesn't seem very idiomatic but because the creation can fail, it's tricky to use
        // self.all_data.get_or_insert_with.
        if self.all_data.is_none() {
            let mut shape = src.dims().to_vec();
            shape[self.dim] = self.capacity_seq_len;
            let ad = Tensor::zeros(shape, src.dtype(), src.device())?;
            self.all_data = Some(ad);
        };

        // Expand kv cache
        if self.current_seq_len + seq_len > self.capacity_seq_len {
            let diff = self.current_seq_len + seq_len - self.capacity_seq_len;
            let n_blocks_needed = diff.div_ceil(NormalCache::CACHE_GROW_SIZE);
            self.capacity_seq_len += n_blocks_needed * NormalCache::CACHE_GROW_SIZE;
            if self.capacity_seq_len > self.max_seq_len {
                candle_core::bail!(
                    "kv-cache: requested capacity ({}) above max seq len ({})",
                    self.capacity_seq_len,
                    self.max_seq_len
                )
            }
            let mut shape = src.dims().to_vec();
            shape[self.dim] = self.capacity_seq_len;
            let ad = Tensor::zeros(shape, src.dtype(), src.device())?;
            ad.slice_set(self.all_data.as_ref().unwrap(), self.dim, 0)?;
            self.all_data = Some(ad);
        }

        let ad = self.all_data.as_mut().unwrap();

        ad.slice_set(src, self.dim, self.current_seq_len)?;
        self.current_seq_len += seq_len;
        Ok(())
    }
}
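
// A minimal usage sketch, added for illustration (not part of the upstream test suite). It
// shows how `SingleCache` allocates its backing tensor lazily on the first `append`, at the
// preallocated capacity, while exposing only the written prefix through `current_data`. The
// (batch, heads, seq, head_dim) shape and dim = 2 follow the conventions of `NormalCache` below.
#[cfg(test)]
mod single_cache_usage {
    use super::*;
    use candle_core::{DType, Device};

    #[test]
    fn append_allocates_lazily_and_tracks_len() -> Result<()> {
        let mut cache = SingleCache::new(2, 4096, NormalCache::CACHE_GROW_SIZE);
        assert!(cache.all_data().is_none());

        let k = Tensor::zeros((1, 8, 16, 64), DType::F32, &Device::Cpu)?;
        cache.append(&k)?;

        // The backing tensor is allocated at the full capacity...
        assert_eq!(
            cache.all_data().as_ref().unwrap().dim(2)?,
            cache.capacity_seq_len
        );
        // ...while `current_data` is narrowed to the tokens actually written.
        assert_eq!(cache.current_data()?.unwrap().dim(2)?, 16);
        assert_eq!(cache.current_seq_len(), 16);
        Ok(())
    }
}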

#[derive(Debug, Clone)]
pub struct RotatingCache {
    pub all_data: Option<Tensor>,
    pub dim: usize,
    // `offset` is the current write index into the ring buffer.
    pub offset: usize,
    // The total length of the sequence seen so far, which may exceed `max_seq_len`.
    pub current_seq_len: usize,
    // `max_seq_len` is the size of the rotating buffer; the full sequence is allowed to grow
    // past this limit.
    pub max_seq_len: usize,
    pub capacity_seq_len: usize,
160
161impl RotatingCache {
162    pub fn new(dim: usize, max_seq_len: usize, capacity_seq_len: usize) -> Self {
163        Self {
164            all_data: None,
165            dim,
166            offset: 0,
167            current_seq_len: 0,
168            max_seq_len,
169            capacity_seq_len,
170        }
171    }
172
173    pub fn offset(&self) -> usize {
174        self.offset
175    }
176
177    pub fn dim(&self) -> usize {
178        self.dim
179    }
180
181    pub fn current_seq_len(&self) -> usize {
182        self.current_seq_len
183    }
184
185    pub fn max_seq_len(&self) -> usize {
186        self.max_seq_len
187    }
188
189    pub fn all_data(&self) -> &Option<Tensor> {
190        &self.all_data
191    }
192
193    pub fn current_data(&self) -> Result<Option<Tensor>> {
194        let data = match self.all_data.as_ref() {
195            None => None,
196            Some(d) => {
197                if self.current_seq_len >= self.max_seq_len {
198                    Some(d.clone())
199                } else {
200                    Some(d.narrow(self.dim, 0, self.current_seq_len)?)
201                }
202            }
203        };
204        Ok(data)
205    }
206
207    pub fn reset(&mut self) {
208        self.offset = 0;
209        self.current_seq_len = 0;
210        self.all_data = None;
211    }
212
213    pub fn set_len(&mut self, len: usize) -> candle_core::Result<()> {
214        // If trying to roll it back past the boundary of max_seq_len, fail early.
215        if self.current_seq_len - len > self.max_seq_len {
216            candle_core::bail!(
217                "Rotating KV cache (usually for sliding window) tried to reset to len {len} while current is {} and max retained is {}",
218                self.current_seq_len,
219                self.max_seq_len
220            );
221        }
222        self.current_seq_len = len;
223        self.offset = len % self.max_seq_len;
224        Ok(())
225    }
226
    pub fn append(&mut self, src: &Tensor) -> Result<Tensor> {
        let seq_len = src.dim(self.dim)?;
        // This doesn't seem very idiomatic but because the creation can fail, it's tricky to use
        // self.all_data.get_or_insert_with.
        if self.all_data.is_none() {
            let mut shape = src.dims().to_vec();
            shape[self.dim] = self.capacity_seq_len;
            let ad = Tensor::zeros(shape, src.dtype(), src.device())?;
            self.all_data = Some(ad)
        };

        // Expand kv cache, this case is a little more complex.
        if (self.current_seq_len + seq_len > self.capacity_seq_len
            && self.current_seq_len + seq_len < self.max_seq_len)
            || self.current_seq_len == 0
        {
            // Use `saturating_sub` here: the `current_seq_len == 0` arm above can reach this
            // point with an incoming chunk smaller than the preallocated capacity, and a plain
            // subtraction would underflow.
            let diff = (self.current_seq_len + seq_len).saturating_sub(self.capacity_seq_len);
            let n_blocks_needed = diff.div_ceil(NormalCache::CACHE_GROW_SIZE);
            self.capacity_seq_len += n_blocks_needed * NormalCache::CACHE_GROW_SIZE;
            self.capacity_seq_len = self.capacity_seq_len.min(self.max_seq_len);
            if self.capacity_seq_len > self.max_seq_len {
                candle_core::bail!(
                    "kv-cache: requested capacity ({}) above max seq len ({})",
                    self.capacity_seq_len,
                    self.max_seq_len
                )
            }
            let mut shape = src.dims().to_vec();
            shape[self.dim] = self.capacity_seq_len;
            let ad = Tensor::zeros(shape, src.dtype(), src.device())?;
            ad.slice_set(self.all_data.as_ref().unwrap(), self.dim, 0)?;
            self.all_data = Some(ad);
        }

        let ad = self.all_data.as_mut().unwrap();

        self.current_seq_len += seq_len;
        if seq_len >= self.max_seq_len {
            let to_copy = src
                .narrow(self.dim, seq_len - self.max_seq_len, self.max_seq_len)?
                .contiguous()?;
            ad.slice_set(&to_copy, self.dim, 0)?;
            self.offset = 0;
            // Here we return `src` rather than `ad` so that all the past can be used.
            Ok(src.clone())
        } else {
            let rem_len = self.max_seq_len - self.offset;
            if seq_len <= rem_len {
                ad.slice_set(&src.contiguous()?, self.dim, self.offset)?;
                self.offset = (self.offset + seq_len) % self.max_seq_len;
            } else {
                // We have to make two copies here as we go over the boundary of the cache.
                if rem_len > 0 {
                    let src1 = src.narrow(self.dim, 0, rem_len)?.contiguous()?;
                    ad.slice_set(&src1, self.dim, self.offset)?;
                }
                let src2 = src
                    .narrow(self.dim, rem_len, seq_len - rem_len)?
                    .contiguous()?;
                ad.slice_set(&src2, self.dim, 0)?;
                self.offset = seq_len - rem_len;
            }
            if self.current_seq_len >= self.max_seq_len {
                Ok(ad.clone())
            } else {
                Ok(ad.narrow(self.dim, 0, self.current_seq_len)?)
            }
        }
    }
}
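
// An illustrative sketch (not from the upstream test suite) of the ring-buffer behavior:
// once `current_seq_len` passes `max_seq_len`, writes wrap around and `current_data` returns
// the whole window. The window size of 8 is made up for the example.
#[cfg(test)]
mod rotating_cache_usage {
    use super::*;
    use candle_core::{DType, Device};

    #[test]
    fn wraps_once_past_the_window() -> Result<()> {
        let mut cache = RotatingCache::new(2, 8, 8);
        let chunk = Tensor::zeros((1, 1, 5, 4), DType::F32, &Device::Cpu)?;

        cache.append(&chunk)?; // 5 tokens written at the front; offset -> 5
        cache.append(&chunk)?; // 3 tokens fill slots 5..8, 2 wrap to the front; offset -> 2
        assert_eq!(cache.current_seq_len(), 10);
        assert_eq!(cache.offset(), 2);

        // Past the window, the full 8-slot buffer is returned.
        assert_eq!(cache.current_data()?.unwrap().dim(2)?, 8);
        Ok(())
    }
}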

#[derive(Debug, Clone)]
pub enum KvCache {
    Normal { k: SingleCache, v: SingleCache },
    Rotating { k: RotatingCache, v: RotatingCache },
}

impl KvCache {
    pub fn new_normal(dim: usize, max_seq_len: usize, capacity_seq_len: usize) -> Self {
        let k = SingleCache::new(dim, max_seq_len, capacity_seq_len);
        let v = SingleCache::new(dim, max_seq_len, capacity_seq_len);
        Self::Normal { k, v }
    }

    pub fn new_rotating(dim: usize, sliding_window: usize, capacity_seq_len: usize) -> Self {
        let k = RotatingCache::new(dim, sliding_window, capacity_seq_len);
        let v = RotatingCache::new(dim, sliding_window, capacity_seq_len);
        Self::Rotating { k, v }
    }

    pub fn k(&self) -> Result<Option<Tensor>> {
        match self {
            Self::Normal { k, .. } => k.current_data(),
            Self::Rotating { k, .. } => k.current_data(),
        }
    }

    pub fn v(&self) -> Result<Option<Tensor>> {
        match self {
            Self::Normal { v, .. } => v.current_data(),
            Self::Rotating { v, .. } => v.current_data(),
        }
    }

    pub fn append(&mut self, k: &Tensor, v: &Tensor) -> Result<(Tensor, Tensor)> {
        let k = k.contiguous()?;
        let v = v.contiguous()?;
        let (out_k, out_v) = match self {
            Self::Normal { k: kc, v: vc } => {
                kc.append(&k)?;
                vc.append(&v)?;
                (kc.current_data()?, vc.current_data()?)
            }
            Self::Rotating { k: kc, v: vc } => {
                let out_k = kc.append(&k)?;
                let out_v = vc.append(&v)?;
                (Some(out_k), Some(out_v))
            }
        };
        let k = match out_k {
            None => {
                let mut shape = k.dims().to_vec();
                match self {
                    Self::Normal { k, .. } => shape[k.dim] = 0,
                    Self::Rotating { k, .. } => shape[k.dim] = 0,
                }
                Tensor::zeros(shape, k.dtype(), k.device())?
            }
            Some(k) => k,
        };
        let v = match out_v {
            None => {
                let mut shape = v.dims().to_vec();
                match self {
                    Self::Normal { v, .. } => shape[v.dim] = 0,
                    Self::Rotating { v, .. } => shape[v.dim] = 0,
                }
                Tensor::zeros(shape, v.dtype(), v.device())?
            }
            Some(v) => v,
        };
        Ok((k, v))
    }

    pub fn current_seq_len(&self) -> usize {
        match self {
            Self::Normal { k, .. } => k.current_seq_len(),
            Self::Rotating { k, .. } => k.current_seq_len(),
        }
    }

    pub fn reset(&mut self) {
        match self {
            Self::Normal { k, v } => {
                k.reset();
                v.reset();
            }
            Self::Rotating { k, v } => {
                k.reset();
                v.reset();
            }
        }
    }

    /// Returns Ok if the length reassignment was successful, otherwise returns Err.
    pub fn set_len(&mut self, len: usize) -> candle_core::Result<()> {
        match self {
            Self::Normal { k, v } => {
                k.set_len(len)?;
                v.set_len(len)?;
                Ok(())
            }
            Self::Rotating { k, v } => {
                k.set_len(len)?;
                v.set_len(len)?;
                Ok(())
            }
        }
    }

    pub fn is_rotating(&self) -> bool {
        matches!(self, Self::Rotating { .. })
    }
}
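
// A small usage sketch (illustrative, not from the upstream test suite): `KvCache::append`
// returns the (k, v) pair to attend over, which for the `Normal` variant is narrowed to the
// tokens seen so far. Shapes follow the (batch, heads, seq, head_dim) convention with dim = 2.
#[cfg(test)]
mod kv_cache_usage {
    use super::*;
    use candle_core::{DType, Device};

    #[test]
    fn normal_append_returns_current_window() -> Result<()> {
        let mut cache = KvCache::new_normal(2, 4096, NormalCache::CACHE_GROW_SIZE);
        let k = Tensor::zeros((1, 8, 7, 64), DType::F32, &Device::Cpu)?;
        let v = k.clone();

        let (k_out, v_out) = cache.append(&k, &v)?;
        assert_eq!((k_out.dim(2)?, v_out.dim(2)?), (7, 7));
        assert_eq!(cache.current_seq_len(), 7);
        assert!(!cache.is_rotating());
        Ok(())
    }
}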

#[derive(Debug, Clone)]
pub struct NormalCache(pub Vec<KvCache>);

#[derive(Debug)]
pub enum NormalCacheType {
    Normal { max_seq_len: usize },
    SlidingWindow { window: usize },
}

impl NormalCache {
    /// The number of tokens to grow the cache by
    pub const CACHE_GROW_SIZE: usize = 512;

    pub fn new(len: usize, max_seq_len: usize) -> Arc<Mutex<Self>> {
        Arc::new(Mutex::new(Self(vec![
            KvCache::new_normal(
                2,
                max_seq_len,
                Self::CACHE_GROW_SIZE
            );
            len
        ])))
    }

    pub fn new_sliding(
        len: usize,
        max_seq_len: usize,
        sliding_window: Option<usize>,
    ) -> Arc<Mutex<Self>> {
        match sliding_window {
            Some(sliding_window) => Arc::new(Mutex::new(Self(vec![
                KvCache::new_rotating(
                    2,
                    sliding_window,
                    Self::CACHE_GROW_SIZE
                );
                len
            ]))),
            None => Arc::new(Mutex::new(Self(vec![
                KvCache::new_normal(
                    2,
                    max_seq_len,
                    Self::CACHE_GROW_SIZE
                );
                len
            ]))),
        }
    }

    pub fn from_types(types: Vec<NormalCacheType>) -> Arc<Mutex<Self>> {
        let mut caches = Vec::new();
        for ty in types {
            match ty {
                NormalCacheType::Normal { max_seq_len } => {
                    caches.push(KvCache::new_normal(2, max_seq_len, Self::CACHE_GROW_SIZE));
                }
                NormalCacheType::SlidingWindow { window } => {
                    caches.push(KvCache::new_rotating(2, window, Self::CACHE_GROW_SIZE));
                }
            }
        }
        Arc::new(Mutex::new(Self(caches)))
    }
}
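
// An illustrative sketch of building a per-layer cache layout with `NormalCache::from_types`,
// e.g. for models that interleave sliding-window and full-attention layers. The two-layer
// pattern and the window/length values are made up for the example.
#[cfg(test)]
mod normal_cache_usage {
    use super::*;

    #[test]
    fn from_types_mixes_cache_kinds() {
        let cache = NormalCache::from_types(vec![
            NormalCacheType::Normal { max_seq_len: 4096 },
            NormalCacheType::SlidingWindow { window: 1024 },
        ]);
        let cache = cache.lock().unwrap();
        assert!(!cache.0[0].is_rotating());
        assert!(cache.0[1].is_rotating());
    }
}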

pub struct NormalCacheManager;

impl<T: CacheManagerMixin + MetadataMixin + ?Sized> CacheManager<T> for NormalCacheManager {
    fn clone_in_cache(
        &self,
        pipeline: &T,
        seqs: &mut [&mut crate::sequence::Sequence],
        modify_draft_cache: bool,
    ) {
        let mut new_k_cache = Vec::new();
        let mut new_v_cache = Vec::new();

        'outer: for layer in 0..pipeline.get_metadata().num_hidden_layers {
            let mut k_vec = Vec::new();
            let mut v_vec = Vec::new();
            for seq in &mut *seqs {
                let src_cache = if modify_draft_cache {
                    seq.normal_draft_cache()
                } else {
                    seq.normal_cache()
                };
                let cache = src_cache.get(layer).unwrap();
                // This case is for Llama 3.2 Vision cross attention, where the layer cache may be `None`.
                if cache.is_none() {
                    new_k_cache.push(None);
                    new_v_cache.push(None);
                    continue 'outer;
                }
                let cache = cache
                    .as_ref()
                    .expect("Not handling completions in `clone_in_cache`.");
                match cache {
                    KvCache::Normal { k, v } => {
                        k_vec.push(k.all_data.clone().unwrap());
                        v_vec.push(v.all_data.clone().unwrap());
                    }
                    KvCache::Rotating { k, v } => {
                        k_vec.push(k.all_data.clone().unwrap());
                        v_vec.push(v.all_data.clone().unwrap());
                    }
                }
            }
            new_k_cache.push(Some(if k_vec.len() > 1 {
                Tensor::cat(&k_vec, 0).unwrap()
            } else {
                k_vec[0].clone()
            }));
            new_v_cache.push(Some(if v_vec.len() > 1 {
                Tensor::cat(&v_vec, 0).unwrap()
            } else {
                v_vec[0].clone()
            }));
        }

        let seq0_cache = if modify_draft_cache {
            &*seqs[0].normal_draft_cache()
        } else {
            &*seqs[0].normal_cache()
        };

        let mut caches = Vec::new();
        for (layer_idx, (k_cache, v_cache)) in new_k_cache.into_iter().zip(new_v_cache).enumerate()
        {
            // Use the first sequence's cache as a template for the various parameters.
            // This assumes all seqs are from one model.
            match seq0_cache[layer_idx].as_ref().unwrap() {
                KvCache::Normal { k: old_k, .. } => {
                    let template_cache_dim = old_k.dim;
                    let template_cache_csl = old_k.current_seq_len;
                    let template_cache_msl = old_k.max_seq_len;
                    let template_cache_capsl = old_k.capacity_seq_len;

                    caches.push(KvCache::Normal {
                        k: SingleCache {
                            all_data: k_cache.map(|x| x.contiguous().unwrap()),
                            dim: template_cache_dim,
                            current_seq_len: template_cache_csl,
                            max_seq_len: template_cache_msl,
                            capacity_seq_len: template_cache_capsl,
                        },
                        v: SingleCache {
                            all_data: v_cache.map(|x| x.contiguous().unwrap()),
                            dim: template_cache_dim,
                            current_seq_len: template_cache_csl,
                            max_seq_len: template_cache_msl,
                            capacity_seq_len: template_cache_capsl,
                        },
                    });
                }
                KvCache::Rotating { k: old_k, .. } => {
                    let template_cache_dim = old_k.dim;
                    let template_cache_csl = old_k.current_seq_len;
                    let template_cache_msl = old_k.max_seq_len;
                    let template_cache_offset = old_k.offset;
                    let template_cache_capsl = old_k.capacity_seq_len;

                    caches.push(KvCache::Rotating {
                        k: RotatingCache {
                            all_data: k_cache.map(|x| x.contiguous().unwrap()),
                            dim: template_cache_dim,
                            current_seq_len: template_cache_csl,
                            max_seq_len: template_cache_msl,
                            offset: template_cache_offset,
                            capacity_seq_len: template_cache_capsl,
                        },
                        v: RotatingCache {
                            all_data: v_cache.map(|x| x.contiguous().unwrap()),
                            dim: template_cache_dim,
                            current_seq_len: template_cache_csl,
                            max_seq_len: template_cache_msl,
                            offset: template_cache_offset,
                            capacity_seq_len: template_cache_capsl,
                        },
                    });
                }
            }
        }
        *pipeline.cache().normal() = NormalCache(caches);
    }
    fn clone_out_cache(&self, pipeline: &T, seqs: &mut [&mut Sequence], modify_draft_cache: bool) {
        let all_cache = pipeline.cache().normal();
        for layer in 0..pipeline.get_metadata().num_hidden_layers {
            let cache = all_cache.0.get(layer).unwrap();
            // This case is for Llama 3.2 Vision cross attention, where the layer cache may be `None`.
            if cache.k().unwrap().is_none() {
                continue;
            }

            let (k_cache, v_cache) = match cache {
                KvCache::Normal { k, v } => {
                    (k.all_data.clone().unwrap(), v.all_data.clone().unwrap())
                }
                KvCache::Rotating { k, v } => {
                    (k.all_data.clone().unwrap(), v.all_data.clone().unwrap())
                }
            };

            let k_caches = k_cache.chunk(seqs.len(), 0).unwrap();
            debug_assert_eq!(k_caches.len(), seqs.len());
            let v_caches = v_cache.chunk(seqs.len(), 0).unwrap();
            debug_assert_eq!(v_caches.len(), seqs.len());

            for (seq_i, seq) in seqs.iter_mut().enumerate() {
                let output_cache = if modify_draft_cache {
                    seq.normal_draft_cache()
                } else {
                    seq.normal_cache()
                };
                let seq_cache = &mut output_cache[layer];
                let k = k_caches.get(seq_i).unwrap().clone();
                let v = v_caches.get(seq_i).unwrap().clone();

                match cache {
                    KvCache::Normal {
                        k: cache_k,
                        v: cache_v,
                    } => {
                        *seq_cache = Some(KvCache::Normal {
                            k: SingleCache {
                                all_data: Some(k),
                                dim: cache_k.dim,
                                current_seq_len: cache_k.current_seq_len,
                                max_seq_len: cache_k.max_seq_len,
                                capacity_seq_len: cache_k.capacity_seq_len,
                            },
                            v: SingleCache {
                                all_data: Some(v),
                                dim: cache_v.dim,
                                current_seq_len: cache_v.current_seq_len,
                                max_seq_len: cache_v.max_seq_len,
                                capacity_seq_len: cache_v.capacity_seq_len,
                            },
                        });
                    }
                    KvCache::Rotating {
                        k: cache_k,
                        v: cache_v,
                    } => {
                        *seq_cache = Some(KvCache::Rotating {
                            k: RotatingCache {
                                all_data: Some(k),
                                dim: cache_k.dim,
                                current_seq_len: cache_k.current_seq_len,
                                max_seq_len: cache_k.max_seq_len,
                                offset: cache_k.offset,
                                capacity_seq_len: cache_k.capacity_seq_len,
                            },
                            v: RotatingCache {
                                all_data: Some(v),
                                dim: cache_v.dim,
                                current_seq_len: cache_v.current_seq_len,
                                max_seq_len: cache_v.max_seq_len,
                                offset: cache_v.offset,
                                capacity_seq_len: cache_v.capacity_seq_len,
                            },
                        });
                    }
                }
            }
        }
    }
    fn set_none_cache(
        &self,
        pipeline: &T,
        seqs: &mut [&mut Sequence],
        _modify_draft_cache: bool,
        load_preallocated_cache: bool,
    ) {
        if seqs.iter().any(|seq| seq.preallocated_cache().is_none()) {
            for layer in pipeline.cache().normal().0.iter_mut() {
                layer.reset();
            }
            return;
        }

        let layer_devices = if let Some(device_mapper) = pipeline.device_mapper() {
            let mut layer_devices = Vec::new();
            for layer in 0..device_mapper.num_device_mapping_layers() {
                let device = device_mapper.device_for(layer, false).cloned();
                layer_devices.push(device.expect("Internal bug, layer out of range!"));
            }
            Some(layer_devices)
        } else {
            None
        };

        let old_caches = pipeline.cache().normal().0.clone();

        for (layer_idx, layer) in pipeline.cache().normal().0.iter_mut().enumerate() {
            if !load_preallocated_cache {
                layer.reset();
                continue;
            }

            let mut k_caches = Vec::new();
            let mut v_caches = Vec::new();
            for seq in seqs.iter_mut() {
                let (mut k_preallocated_cache, mut v_preallocated_cache) =
                    (*seq.preallocated_cache().as_ref().unwrap()).clone();
                if let Some(layer_devices) = &layer_devices {
                    let layer_dev = &layer_devices[layer_idx];
                    k_preallocated_cache = k_preallocated_cache
                        .to_device(layer_dev)
                        .expect("Could not prepare cache");
                    v_preallocated_cache = v_preallocated_cache
                        .to_device(layer_dev)
                        .expect("Could not prepare cache");
                }
                k_caches.push(k_preallocated_cache);
                v_caches.push(v_preallocated_cache);
            }
            let k_cache = if k_caches.len() > 1 {
                Tensor::cat(&k_caches, 0).unwrap()
            } else {
                k_caches[0].clone()
            };
            let v_cache = if v_caches.len() > 1 {
                Tensor::cat(&v_caches, 0).unwrap()
            } else {
                v_caches[0].clone()
            };

            // Use the first sequence's cache as a template for the various parameters.
            // This assumes all seqs are from one model.
            match &old_caches[layer_idx] {
                KvCache::Normal { k, .. } => {
                    let template_cache_dim = k.dim;
                    let template_cache_msl = k.max_seq_len;

                    let cache = KvCache::Normal {
                        k: SingleCache {
                            all_data: Some(k_cache.zeros_like().unwrap()),
                            dim: template_cache_dim,
                            current_seq_len: 0,
                            max_seq_len: template_cache_msl,
                            capacity_seq_len: k_cache.dims()[template_cache_dim],
                        },
                        v: SingleCache {
                            all_data: Some(v_cache.zeros_like().unwrap()),
                            dim: template_cache_dim,
                            current_seq_len: 0,
                            max_seq_len: template_cache_msl,
                            capacity_seq_len: k_cache.dims()[template_cache_dim],
                        },
                    };
                    *layer = cache;
                }
                KvCache::Rotating { k, .. } => {
                    let template_cache_dim = k.dim;
                    let template_cache_msl = k.max_seq_len;

                    // Rotating cache is not preallocated.
                    let cache = KvCache::Rotating {
                        k: RotatingCache {
                            all_data: None,
                            dim: template_cache_dim,
                            current_seq_len: 0,
                            max_seq_len: template_cache_msl,
                            offset: 0,
                            capacity_seq_len: 0,
                        },
                        v: RotatingCache {
                            all_data: None,
                            dim: template_cache_dim,
                            current_seq_len: 0,
                            max_seq_len: template_cache_msl,
                            offset: 0,
                            capacity_seq_len: 0,
                        },
                    };
                    *layer = cache;
                }
            }
        }
    }
}

#[derive(Debug, Clone)]
pub struct Cache {
    cache: Arc<Mutex<LayerCaches>>,
    xlora_cache: Option<Arc<Mutex<LayerCaches>>>,
    draft_cache: Arc<Mutex<LayerCaches>>,
    scalings_cache: Option<Arc<Mutex<Option<Tensor>>>>,
}

impl Cache {
    pub(crate) fn new(len: usize, is_xlora: bool) -> Self {
        Self {
            cache: Arc::new(Mutex::new(vec![None; len])),
            xlora_cache: if is_xlora {
                Some(Arc::new(Mutex::new(vec![None; len])))
            } else {
                None
            },
            draft_cache: Arc::new(Mutex::new(vec![None; len])),
            scalings_cache: if is_xlora {
                Some(Arc::new(Mutex::new(None)))
            } else {
                None
            },
        }
    }

    pub(crate) fn lock(&self) -> MutexGuard<'_, LayerCaches> {
        get_mut_arcmutex!(self.cache)
    }

    pub(crate) fn draft_lock(&self) -> MutexGuard<'_, LayerCaches> {
        get_mut_arcmutex!(self.draft_cache)
    }

    /// # Panics
    /// If there is no xlora cache
    pub(crate) fn xlora_lock(&self) -> MutexGuard<'_, LayerCaches> {
        get_mut_arcmutex!(self.xlora_cache.as_ref().expect("No X-LoRA cache."))
    }

    /// # Panics
    /// If there is no X-LoRA scalings cache
    pub(crate) fn get_scalings_cache(&self) -> MutexGuard<'_, Option<Tensor>> {
        get_mut_arcmutex!(self
            .scalings_cache
            .as_ref()
            .expect("No X-LoRA scalings cache."))
    }

    pub(crate) fn is_xlora(&self) -> bool {
        self.xlora_cache.is_some()
    }

    /// Update the KV cache and return (k,v)
    pub(crate) fn update_kv_cache(
        cache: &mut Option<(Tensor, Tensor)>,
        k: Tensor,
        v: Tensor,
        slow_cat: bool,
    ) -> Result<(Tensor, Tensor)> {
        let (k, v) = match &*cache {
            None => (k, v),
            Some((k_cache, v_cache)) => {
                if !slow_cat {
                    let k = candle_nn::ops::kvconcat(k_cache, &k, 2)?.contiguous()?;
                    let v = candle_nn::ops::kvconcat(v_cache, &v, 2)?.contiguous()?;
                    (k, v)
                } else {
                    let k = Tensor::cat(&[k_cache, &k], 2)?.contiguous()?;
                    let v = Tensor::cat(&[v_cache, &v], 2)?.contiguous()?;
                    (k, v)
                }
            }
        };
        *cache = Some((k.clone(), v.clone()));
        Ok((k.contiguous()?, v.contiguous()?))
    }

    /// Update the KV cache and return (k,v,attn_mask)
    pub(crate) fn update_kv_cache_sliding_window(
        cache: &mut Option<(Tensor, Tensor)>,
        k: Tensor,
        v: Tensor,
        attention_mask: Option<&Tensor>,
        sliding_window: Option<usize>,
        slow_cat: bool,
    ) -> Result<(Tensor, Tensor, Option<Tensor>)> {
        let (k, v, attention_mask) = match cache.clone() {
            None => (k, v, attention_mask.cloned()),
            Some((mut prev_k, mut prev_v)) => {
                let mut mask = attention_mask.cloned();
                if let Some(sliding_window) = sliding_window {
                    let kv_seq_len = prev_k.dim(2)?;
                    if kv_seq_len > sliding_window {
                        prev_k = prev_k.narrow(
                            2,
                            kv_seq_len - (sliding_window - 1),
                            sliding_window - 1,
                        )?;
                        prev_v = prev_v.narrow(
                            2,
                            kv_seq_len - (sliding_window - 1),
                            sliding_window - 1,
                        )?;
                        if let Some(ref mut mask) = mask {
                            let mask_len = mask.dim(1)?;
                            *mask = mask.narrow(
                                1,
                                mask_len - (sliding_window - 1),
                                sliding_window - 1,
                            )?;
                            *mask = Tensor::cat(
                                &[&*mask, &mask.narrow(1, mask_len - 1, 1)?.ones_like()?],
                                D::Minus1,
                            )?;
                        }
                    }
                }
                let (k, v) = if !slow_cat {
                    let k = candle_nn::ops::kvconcat(&prev_k, &k, 2)?;
                    let v = candle_nn::ops::kvconcat(&prev_v, &v, 2)?;
                    (k, v)
                } else {
                    let k = Tensor::cat(&[prev_k, k], 2)?.contiguous()?;
                    let v = Tensor::cat(&[prev_v, v], 2)?.contiguous()?;
                    (k, v)
                };
                (k, v, mask)
            }
        };
        *cache = Some((k.clone(), v.clone()));
        Ok((k.contiguous()?, v.contiguous()?, attention_mask))
    }
}
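
// A minimal sketch, added for illustration (not part of the upstream test suite), of how the
// legacy `Cache::update_kv_cache` helper is used by attention layers: the first call seeds the
// cache entry, later calls concatenate along the sequence axis (dim 2). `slow_cat = true`
// forces the plain `Tensor::cat` path so the example does not depend on the fused kvconcat op.
#[cfg(test)]
mod full_cache_usage {
    use super::*;
    use candle_core::{DType, Device};

    #[test]
    fn update_concatenates_along_seq_dim() -> Result<()> {
        let mut entry: Option<(Tensor, Tensor)> = None;
        let step = || Tensor::zeros((1, 8, 1, 64), DType::F32, &Device::Cpu);

        // First token: the cache entry is seeded with (k, v) as-is.
        let (k, _v) = Cache::update_kv_cache(&mut entry, step()?, step()?, true)?;
        assert_eq!(k.dim(2)?, 1);

        // Second token: the new (k, v) are appended after the cached ones.
        let (k, v) = Cache::update_kv_cache(&mut entry, step()?, step()?, true)?;
        assert_eq!((k.dim(2)?, v.dim(2)?), (2, 2));
        Ok(())
    }
}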

pub struct FullCacheManager;

enum SeqCache {
    Normal,
    XLora,
    Draft,
}

fn clone_in_cache(
    num_hidden_layers: usize,
    cache: &mut LayerCaches,
    seqs: &mut [&mut crate::sequence::Sequence],
    src: SeqCache,
) {
    let mut new_cache = Vec::new();
    'outer: for layer in 0..num_hidden_layers {
        let mut k_vec = Vec::new();
        let mut v_vec = Vec::new();
        for seq in &mut *seqs {
            let src_cache = match src {
                SeqCache::Normal => seq.cache(),
                SeqCache::XLora => seq.xlora_cache(),
                SeqCache::Draft => seq.draft_cache(),
            };
            let cache = src_cache.get(layer).unwrap();
            // This case is for Llama 3.2 Vision cross attention, where the layer cache may be `None`.
            if cache.is_none() {
                new_cache.push(None);
                continue 'outer;
            }
            let cache = cache
                .as_ref()
                .expect("Not handling completions in `clone_in_cache`.");
            k_vec.push(cache.0.clone());
            v_vec.push(cache.1.clone());
        }
        new_cache.push(Some((
            if k_vec.len() > 1 {
                Tensor::cat(&k_vec, 0).unwrap()
            } else {
                k_vec[0].clone()
            },
            if v_vec.len() > 1 {
                Tensor::cat(&v_vec, 0).unwrap()
            } else {
                v_vec[0].clone()
            },
        )));
    }
    *cache = new_cache;
}

fn clone_out_cache(
    num_hidden_layers: usize,
    cache: &mut LayerCaches,
    seqs: &mut [&mut crate::sequence::Sequence],
    target: SeqCache,
) {
    for layer in 0..num_hidden_layers {
        let cache = cache.get(layer).unwrap();
        // This case is for Llama 3.2 Vision cross attention, where the layer cache may be `None`.
        if cache.is_none() {
            continue;
        }

        let k_cache = cache.as_ref().unwrap().0.clone();
        let v_cache = cache.as_ref().unwrap().1.clone();

        let k_caches = k_cache.chunk(seqs.len(), 0).unwrap();
        debug_assert_eq!(k_caches.len(), seqs.len());
        let v_caches = v_cache.chunk(seqs.len(), 0).unwrap();
        debug_assert_eq!(v_caches.len(), seqs.len());

        for (seq_i, seq) in seqs.iter_mut().enumerate() {
            let output_cache = match target {
                SeqCache::Normal => seq.cache(),
                SeqCache::XLora => seq.xlora_cache(),
                SeqCache::Draft => seq.draft_cache(),
            };
            let seq_cache = &mut output_cache[layer];
            let k = k_caches.get(seq_i).unwrap().clone();
            let v = v_caches.get(seq_i).unwrap().clone();
            *seq_cache = Some((k, v));
        }
    }
}

impl<T: CacheManagerMixin + MetadataMixin + ?Sized> CacheManager<T> for FullCacheManager {
    fn clone_in_cache(
        &self,
        pipeline: &T,
        seqs: &mut [&mut crate::sequence::Sequence],
        modify_draft_cache: bool,
    ) {
        if modify_draft_cache {
            clone_in_cache(
                pipeline.get_metadata().num_hidden_layers,
                &mut pipeline.cache().full().lock(),
                seqs,
                SeqCache::Draft,
            );
            return;
        }
        clone_in_cache(
            pipeline.get_metadata().num_hidden_layers,
            &mut pipeline.cache().full().lock(),
            seqs,
            SeqCache::Normal,
        );
        if pipeline.get_metadata().is_xlora && !pipeline.get_metadata().no_kv_cache {
            clone_in_cache(
                pipeline.get_metadata().num_hidden_layers,
                &mut pipeline.cache().full().xlora_lock(),
                seqs,
                SeqCache::XLora,
            );
        }
        if pipeline.get_metadata().is_xlora {
            pipeline
                .cache()
                .full()
                .get_scalings_cache()
                .clone_from(seqs[0].scaling_cache());
        }
    }

    fn clone_out_cache(
        &self,
        pipeline: &T,
        seqs: &mut [&mut crate::sequence::Sequence],
        modify_draft_cache: bool,
    ) {
        if modify_draft_cache {
            clone_out_cache(
                pipeline.get_metadata().num_hidden_layers,
                &mut pipeline.cache().full().lock(),
                seqs,
                SeqCache::Draft,
            );
            return;
        }
        clone_out_cache(
            pipeline.get_metadata().num_hidden_layers,
            &mut pipeline.cache().full().lock(),
            seqs,
            SeqCache::Normal,
        );
        if pipeline.get_metadata().is_xlora && !pipeline.get_metadata().no_kv_cache {
            clone_out_cache(
                pipeline.get_metadata().num_hidden_layers,
                &mut pipeline.cache().full().xlora_lock(),
                seqs,
                SeqCache::XLora,
            );
        }
        if pipeline.get_metadata().is_xlora {
            seqs[0]
                .scaling_cache()
                .clone_from(&pipeline.cache().full().get_scalings_cache());
        }
    }

    fn set_none_cache(
        &self,
        pipeline: &T,
        _seqs: &mut [&mut Sequence],
        modify_draft_cache: bool,
        _load_preallocated_cache: bool,
    ) {
        let mut new_cache = Vec::new();
        for _ in 0..pipeline.get_metadata().num_hidden_layers {
            new_cache.push(None);
        }
        pipeline.cache().full().lock().clone_from(&new_cache);
        if modify_draft_cache {
            pipeline.cache().full().draft_lock().clone_from(&new_cache);
        }
        if pipeline.cache().full().is_xlora() {
            *pipeline.cache().full().xlora_lock() = new_cache;
        }
    }
}