use std::sync::{Arc, Mutex, MutexGuard};

use candle_core::{Result, Tensor, D};

use crate::{get_mut_arcmutex, sequence::Sequence};

use super::{CacheManagerMixin, MetadataMixin};

pub trait CacheManager<T: CacheManagerMixin + MetadataMixin + ?Sized> {
    fn clone_in_cache(
        &self,
        pipeline: &T,
        seqs: &mut [&mut crate::sequence::Sequence],
        modify_draft_cache: bool,
    );
    fn clone_out_cache(&self, pipeline: &T, seqs: &mut [&mut Sequence], modify_draft_cache: bool);
    fn set_none_cache(
        &self,
        pipeline: &T,
        seqs: &mut [&mut Sequence],
        modify_draft_cache: bool,
        load_preallocated_cache: bool,
    );
}

pub type LayerCaches = Vec<Option<(Tensor, Tensor)>>;

#[derive(Debug, Clone)]
pub enum EitherCache {
    Normal(Arc<Mutex<NormalCache>>),
    Full(Cache),
}

impl EitherCache {
    /// # Panics
    /// Panics if this is a normal (per-layer `KvCache`) cache.
    pub fn full(&self) -> &Cache {
        match self {
            Self::Full(full) => full,
            Self::Normal(_) => panic!("Got normal cache, expected full cache."),
        }
    }

    /// # Panics
    /// Panics if this is a full cache.
    pub fn normal(&self) -> MutexGuard<'_, NormalCache> {
        match self {
            Self::Normal(normal) => normal.lock().unwrap(),
            Self::Full(_) => panic!("Got full cache, expected normal cache."),
        }
    }
}

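// A minimal sketch (not part of the upstream file) showing the two access paths: `normal()`
// locks the per-layer `KvCache` state, while calling `full()` on a normal cache panics.
#[cfg(test)]
mod either_cache_sketch {
    use super::*;

    #[test]
    fn normal_access() {
        let cache = EitherCache::Normal(NormalCache::new(2, 128));
        // Two layers were allocated, each with an empty KV cache.
        assert_eq!(cache.normal().0.len(), 2);
        assert_eq!(cache.normal().0[0].current_seq_len(), 0);
    }

    #[test]
    #[should_panic(expected = "expected full cache")]
    fn full_on_normal_panics() {
        let cache = EitherCache::Normal(NormalCache::new(2, 128));
        let _ = cache.full();
    }
}
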
#[derive(Debug, Clone)]
pub struct SingleCache {
    // all_data is an Option on a Tensor. This makes it possible to only create the actual
    // tensor on the first call where data is provided, which ensures that the device of the
    // cache always matches the device of the incoming keys/values.
    pub all_data: Option<Tensor>,
    pub dim: usize,
    pub current_seq_len: usize,
    pub capacity_seq_len: usize,
    pub max_seq_len: usize,
}

impl SingleCache {
    pub fn new(dim: usize, max_seq_len: usize, capacity_seq_len: usize) -> Self {
        Self {
            all_data: None,
            dim,
            current_seq_len: 0,
            max_seq_len,
            capacity_seq_len,
        }
    }

    pub fn dim(&self) -> usize {
        self.dim
    }

    pub fn current_seq_len(&self) -> usize {
        self.current_seq_len
    }

    pub fn max_seq_len(&self) -> usize {
        self.max_seq_len
    }

    pub fn all_data(&self) -> &Option<Tensor> {
        &self.all_data
    }

    pub fn current_data(&self) -> Result<Option<Tensor>> {
        let data = match self.all_data.as_ref() {
            None => None,
            Some(d) => Some(d.narrow(self.dim, 0, self.current_seq_len)?),
        };
        Ok(data)
    }

    pub fn reset(&mut self) {
        self.current_seq_len = 0;
        self.all_data = None;
    }

    pub fn set_len(&mut self, len: usize) -> candle_core::Result<()> {
        self.current_seq_len = len;
        Ok(())
    }

    pub fn append(&mut self, src: &Tensor) -> Result<()> {
        let seq_len = src.dim(self.dim)?;
        // Lazily allocate the backing buffer on the first append so that it lands on the
        // same device and dtype as the incoming tensor.
        if self.all_data.is_none() {
            let mut shape = src.dims().to_vec();
            shape[self.dim] = self.capacity_seq_len;
            let ad = Tensor::zeros(shape, src.dtype(), src.device())?;
            self.all_data = Some(ad);
        };

        // Expand the KV cache in `CACHE_GROW_SIZE` blocks if the new tokens do not fit.
        if self.current_seq_len + seq_len > self.capacity_seq_len {
            let diff = self.current_seq_len + seq_len - self.capacity_seq_len;
            let n_blocks_needed = diff.div_ceil(NormalCache::CACHE_GROW_SIZE);
            self.capacity_seq_len += n_blocks_needed * NormalCache::CACHE_GROW_SIZE;
            if self.capacity_seq_len > self.max_seq_len {
                candle_core::bail!(
                    "kv-cache: requested capacity ({}) above max seq len ({})",
                    self.capacity_seq_len,
                    self.max_seq_len
                )
            }
            let mut shape = src.dims().to_vec();
            shape[self.dim] = self.capacity_seq_len;
            let ad = Tensor::zeros(shape, src.dtype(), src.device())?;
            ad.slice_set(self.all_data.as_ref().unwrap(), self.dim, 0)?;
            self.all_data = Some(ad);
        }

        let ad = self.all_data.as_mut().unwrap();

        ad.slice_set(src, self.dim, self.current_seq_len)?;
        self.current_seq_len += seq_len;
        Ok(())
    }
}

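// A minimal sketch (not part of the upstream file) of how `SingleCache` behaves: the buffer
// is lazily allocated on first append, grows in `CACHE_GROW_SIZE` blocks, and `current_data`
// narrows the buffer down to the tokens written so far. Runs on CPU with small shapes.
#[cfg(test)]
mod single_cache_sketch {
    use super::*;
    use candle_core::{DType, Device};

    #[test]
    fn append_then_narrow() -> Result<()> {
        // Cache over dim 2 (the sequence dim for [batch, heads, seq, head_dim] tensors).
        let mut cache = SingleCache::new(2, 1024, 512);
        assert!(cache.all_data().is_none());

        let k = Tensor::ones((1, 2, 3, 4), DType::F32, &Device::Cpu)?;
        cache.append(&k)?;
        // The backing buffer was allocated at the initial capacity...
        assert_eq!(cache.all_data().as_ref().unwrap().dims()[2], 512);
        // ...but only 3 positions are considered live.
        assert_eq!(cache.current_seq_len(), 3);
        assert_eq!(cache.current_data()?.unwrap().dims(), &[1, 2, 3, 4]);
        Ok(())
    }
}
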
#[derive(Debug, Clone)]
pub struct RotatingCache {
    pub all_data: Option<Tensor>,
    pub dim: usize,
    // `offset` is the position in the ring buffer to be written to next.
    pub offset: usize,
    // The total size of the sequence seen so far, which may exceed `max_seq_len`.
    pub current_seq_len: usize,
    // `max_seq_len` is the size of the rotating buffer; the sequence itself is allowed
    // to grow past it.
    pub max_seq_len: usize,
    pub capacity_seq_len: usize,
}

impl RotatingCache {
    pub fn new(dim: usize, max_seq_len: usize, capacity_seq_len: usize) -> Self {
        Self {
            all_data: None,
            dim,
            offset: 0,
            current_seq_len: 0,
            max_seq_len,
            capacity_seq_len,
        }
    }

    pub fn offset(&self) -> usize {
        self.offset
    }

    pub fn dim(&self) -> usize {
        self.dim
    }

    pub fn current_seq_len(&self) -> usize {
        self.current_seq_len
    }

    pub fn max_seq_len(&self) -> usize {
        self.max_seq_len
    }

    pub fn all_data(&self) -> &Option<Tensor> {
        &self.all_data
    }

    pub fn current_data(&self) -> Result<Option<Tensor>> {
        let data = match self.all_data.as_ref() {
            None => None,
            Some(d) => {
                if self.current_seq_len >= self.max_seq_len {
                    Some(d.clone())
                } else {
                    Some(d.narrow(self.dim, 0, self.current_seq_len)?)
                }
            }
        };
        Ok(data)
    }

    pub fn reset(&mut self) {
        self.offset = 0;
        self.current_seq_len = 0;
        self.all_data = None;
    }

    pub fn set_len(&mut self, len: usize) -> candle_core::Result<()> {
        // Rolling back past the window boundary is invalid: those positions have already
        // been overwritten in the ring buffer.
        if self.current_seq_len - len > self.max_seq_len {
            candle_core::bail!(
                "Rotating KV cache (usually for sliding window) tried to reset to len {len} while current is {} and max retained is {}",
                self.current_seq_len,
                self.max_seq_len
            );
        }
        self.current_seq_len = len;
        self.offset = len % self.max_seq_len;
        Ok(())
    }

    pub fn append(&mut self, src: &Tensor) -> Result<Tensor> {
        let seq_len = src.dim(self.dim)?;
        // Lazily allocate the ring buffer on the first append so that it lands on the same
        // device and dtype as the incoming tensor.
        if self.all_data.is_none() {
            let mut shape = src.dims().to_vec();
            shape[self.dim] = self.capacity_seq_len;
            let ad = Tensor::zeros(shape, src.dtype(), src.device())?;
            self.all_data = Some(ad)
        };

        // Expand the KV cache; this case is a little more complex than for `SingleCache`
        // because the capacity is clamped to `max_seq_len`.
        if (self.current_seq_len + seq_len > self.capacity_seq_len
            && self.current_seq_len + seq_len < self.max_seq_len)
            || self.current_seq_len == 0
        {
            // `saturating_sub` guards the `current_seq_len == 0` case, where the new tokens
            // may already fit within the initial capacity (a plain `-` would underflow).
            let diff = (self.current_seq_len + seq_len).saturating_sub(self.capacity_seq_len);
            let n_blocks_needed = diff.div_ceil(NormalCache::CACHE_GROW_SIZE);
            self.capacity_seq_len += n_blocks_needed * NormalCache::CACHE_GROW_SIZE;
            self.capacity_seq_len = self.capacity_seq_len.min(self.max_seq_len);
            if self.capacity_seq_len > self.max_seq_len {
                candle_core::bail!(
                    "kv-cache: requested capacity ({}) above max seq len ({})",
                    self.capacity_seq_len,
                    self.max_seq_len
                )
            }
            let mut shape = src.dims().to_vec();
            shape[self.dim] = self.capacity_seq_len;
            let ad = Tensor::zeros(shape, src.dtype(), src.device())?;
            ad.slice_set(self.all_data.as_ref().unwrap(), self.dim, 0)?;
            self.all_data = Some(ad);
        }

        let ad = self.all_data.as_mut().unwrap();

        self.current_seq_len += seq_len;
        if seq_len >= self.max_seq_len {
            // The new chunk covers the whole window: keep only its last `max_seq_len`
            // positions, and return `src` itself so that all of the past can be used.
            let to_copy = src
                .narrow(self.dim, seq_len - self.max_seq_len, self.max_seq_len)?
                .contiguous()?;
            ad.slice_set(&to_copy, self.dim, 0)?;
            self.offset = 0;
            Ok(src.clone())
        } else {
            let rem_len = self.max_seq_len - self.offset;
            if seq_len <= rem_len {
                ad.slice_set(&src.contiguous()?, self.dim, self.offset)?;
                self.offset = (self.offset + seq_len) % self.max_seq_len;
            } else {
                // We have to make two copies here as we go over the boundary of the cache.
                if rem_len > 0 {
                    let src1 = src.narrow(self.dim, 0, rem_len)?.contiguous()?;
                    ad.slice_set(&src1, self.dim, self.offset)?;
                }
                let src2 = src
                    .narrow(self.dim, rem_len, seq_len - rem_len)?
                    .contiguous()?;
                ad.slice_set(&src2, self.dim, 0)?;
                self.offset = seq_len - rem_len;
            }
            if self.current_seq_len >= self.max_seq_len {
                Ok(ad.clone())
            } else {
                Ok(ad.narrow(self.dim, 0, self.current_seq_len)?)
            }
        }
    }
}

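// A minimal sketch (not part of the upstream file) of the rotating behaviour: once the
// number of seen tokens passes `max_seq_len` (the sliding window), `append` writes into the
// fixed-size ring buffer and returns the whole buffer rather than a narrowed view.
#[cfg(test)]
mod rotating_cache_sketch {
    use super::*;
    use candle_core::{DType, Device};

    #[test]
    fn wraps_after_window() -> Result<()> {
        // Sliding window of 4 over the sequence dim (dim 2).
        let mut cache = RotatingCache::new(2, 4, 4);

        // A prompt longer than the window: only the last 4 positions are retained, but the
        // full `src` is returned so the current step can still attend to all of it.
        let prompt = Tensor::ones((1, 2, 6, 8), DType::F32, &Device::Cpu)?;
        let out = cache.append(&prompt)?;
        assert_eq!(out.dims()[2], 6);
        assert_eq!(cache.current_seq_len(), 6);
        assert_eq!(cache.offset(), 0);

        // One decoded token: it overwrites the oldest slot and the window is returned.
        let tok = Tensor::ones((1, 2, 1, 8), DType::F32, &Device::Cpu)?;
        let out = cache.append(&tok)?;
        assert_eq!(out.dims()[2], 4);
        assert_eq!(cache.offset(), 1);
        Ok(())
    }
}
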
#[derive(Debug, Clone)]
pub enum KvCache {
    Normal { k: SingleCache, v: SingleCache },
    Rotating { k: RotatingCache, v: RotatingCache },
}

impl KvCache {
    pub fn new_normal(dim: usize, max_seq_len: usize, capacity_seq_len: usize) -> Self {
        let k = SingleCache::new(dim, max_seq_len, capacity_seq_len);
        let v = SingleCache::new(dim, max_seq_len, capacity_seq_len);
        Self::Normal { k, v }
    }

    pub fn new_rotating(dim: usize, sliding_window: usize, capacity_seq_len: usize) -> Self {
        let k = RotatingCache::new(dim, sliding_window, capacity_seq_len);
        let v = RotatingCache::new(dim, sliding_window, capacity_seq_len);
        Self::Rotating { k, v }
    }

    pub fn k(&self) -> Result<Option<Tensor>> {
        match self {
            Self::Normal { k, .. } => k.current_data(),
            Self::Rotating { k, .. } => k.current_data(),
        }
    }

    pub fn v(&self) -> Result<Option<Tensor>> {
        match self {
            Self::Normal { v, .. } => v.current_data(),
            Self::Rotating { v, .. } => v.current_data(),
        }
    }

    pub fn append(&mut self, k: &Tensor, v: &Tensor) -> Result<(Tensor, Tensor)> {
        let k = k.contiguous()?;
        let v = v.contiguous()?;
        let (out_k, out_v) = match self {
            Self::Normal { k: kc, v: vc } => {
                kc.append(&k)?;
                vc.append(&v)?;
                (kc.current_data()?, vc.current_data()?)
            }
            Self::Rotating { k: kc, v: vc } => {
                let out_k = kc.append(&k)?;
                let out_v = vc.append(&v)?;
                (Some(out_k), Some(out_v))
            }
        };
        let k = match out_k {
            None => {
                let mut shape = k.dims().to_vec();
                match self {
                    Self::Normal { k, .. } => shape[k.dim] = 0,
                    Self::Rotating { k, .. } => shape[k.dim] = 0,
                }
                Tensor::zeros(shape, k.dtype(), k.device())?
            }
            Some(k) => k,
        };
        let v = match out_v {
            None => {
                let mut shape = v.dims().to_vec();
                match self {
                    Self::Normal { v, .. } => shape[v.dim] = 0,
                    Self::Rotating { v, .. } => shape[v.dim] = 0,
                }
                Tensor::zeros(shape, v.dtype(), v.device())?
            }
            Some(v) => v,
        };
        Ok((k, v))
    }

    pub fn current_seq_len(&self) -> usize {
        match self {
            Self::Normal { k, .. } => k.current_seq_len(),
            Self::Rotating { k, .. } => k.current_seq_len(),
        }
    }

    pub fn reset(&mut self) {
        match self {
            Self::Normal { k, v } => {
                k.reset();
                v.reset();
            }
            Self::Rotating { k, v } => {
                k.reset();
                v.reset();
            }
        }
    }

    pub fn set_len(&mut self, len: usize) -> candle_core::Result<()> {
        match self {
            Self::Normal { k, v } => {
                k.set_len(len)?;
                v.set_len(len)?;
                Ok(())
            }
            Self::Rotating { k, v } => {
                k.set_len(len)?;
                v.set_len(len)?;
                Ok(())
            }
        }
    }

    pub fn is_rotating(&self) -> bool {
        matches!(self, Self::Rotating { .. })
    }
}

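// A minimal sketch (not part of the upstream file) of the `KvCache` front end as an
// attention layer would use it: append the new K/V, get back the concatenated history.
#[cfg(test)]
mod kv_cache_sketch {
    use super::*;
    use candle_core::{DType, Device};

    #[test]
    fn prefill_then_decode() -> Result<()> {
        let mut cache = KvCache::new_normal(2, 1024, 512);

        // Prefill with a 5-token prompt.
        let k = Tensor::zeros((1, 2, 5, 8), DType::F32, &Device::Cpu)?;
        let v = Tensor::zeros((1, 2, 5, 8), DType::F32, &Device::Cpu)?;
        let (k_all, v_all) = cache.append(&k, &v)?;
        assert_eq!((k_all.dims()[2], v_all.dims()[2]), (5, 5));

        // Decode one token: the returned tensors now cover all 6 positions.
        let k = Tensor::zeros((1, 2, 1, 8), DType::F32, &Device::Cpu)?;
        let v = Tensor::zeros((1, 2, 1, 8), DType::F32, &Device::Cpu)?;
        let (k_all, v_all) = cache.append(&k, &v)?;
        assert_eq!((k_all.dims()[2], v_all.dims()[2]), (6, 6));
        assert_eq!(cache.current_seq_len(), 6);
        Ok(())
    }
}
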
#[derive(Debug, Clone)]
pub struct NormalCache(pub Vec<KvCache>);

#[derive(Debug)]
pub enum NormalCacheType {
    Normal { max_seq_len: usize },
    SlidingWindow { window: usize },
}

impl NormalCache {
    pub const CACHE_GROW_SIZE: usize = 512;

    pub fn new(len: usize, max_seq_len: usize) -> Arc<Mutex<Self>> {
        Arc::new(Mutex::new(Self(vec![
            KvCache::new_normal(2, max_seq_len, Self::CACHE_GROW_SIZE);
            len
        ])))
    }

    pub fn new_sliding(
        len: usize,
        max_seq_len: usize,
        sliding_window: Option<usize>,
    ) -> Arc<Mutex<Self>> {
        match sliding_window {
            Some(sliding_window) => Arc::new(Mutex::new(Self(vec![
                KvCache::new_rotating(2, sliding_window, Self::CACHE_GROW_SIZE);
                len
            ]))),
            None => Arc::new(Mutex::new(Self(vec![
                KvCache::new_normal(2, max_seq_len, Self::CACHE_GROW_SIZE);
                len
            ]))),
        }
    }

    pub fn from_types(types: Vec<NormalCacheType>) -> Arc<Mutex<Self>> {
        let mut caches = Vec::new();
        for ty in types {
            match ty {
                NormalCacheType::Normal { max_seq_len } => {
                    caches.push(KvCache::new_normal(2, max_seq_len, Self::CACHE_GROW_SIZE));
                }
                NormalCacheType::SlidingWindow { window } => {
                    caches.push(KvCache::new_rotating(2, window, Self::CACHE_GROW_SIZE));
                }
            }
        }
        Arc::new(Mutex::new(Self(caches)))
    }
}

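// A minimal sketch (not part of the upstream file) of building a per-layer cache layout for
// a model that interleaves sliding-window and full-attention layers, which is the pattern
// `from_types` exists for. The layer layout here is illustrative, not from any real config.
#[cfg(test)]
mod normal_cache_sketch {
    use super::*;

    #[test]
    fn interleaved_layout() {
        let cache = NormalCache::from_types(vec![
            NormalCacheType::Normal { max_seq_len: 4096 },
            NormalCacheType::SlidingWindow { window: 1024 },
            NormalCacheType::Normal { max_seq_len: 4096 },
        ]);
        let guard = cache.lock().unwrap();
        assert_eq!(guard.0.len(), 3);
        assert!(!guard.0[0].is_rotating());
        assert!(guard.0[1].is_rotating());
    }
}
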
pub struct NormalCacheManager;

impl<T: CacheManagerMixin + MetadataMixin + ?Sized> CacheManager<T> for NormalCacheManager {
    fn clone_in_cache(
        &self,
        pipeline: &T,
        seqs: &mut [&mut crate::sequence::Sequence],
        modify_draft_cache: bool,
    ) {
        let mut new_k_cache = Vec::new();
        let mut new_v_cache = Vec::new();

        'outer: for layer in 0..pipeline.get_metadata().num_hidden_layers {
            let mut k_vec = Vec::new();
            let mut v_vec = Vec::new();
            for seq in &mut *seqs {
                let src_cache = if modify_draft_cache {
                    seq.normal_draft_cache()
                } else {
                    seq.normal_cache()
                };
                let cache = src_cache.get(layer).unwrap();
                // If any sequence has no cache for this layer yet, propagate the hole.
                if cache.is_none() {
                    new_k_cache.push(None);
                    new_v_cache.push(None);
                    continue 'outer;
                }
                let cache = cache
                    .as_ref()
                    .expect("Not handling completions in `clone_in_cache`.");
                match cache {
                    KvCache::Normal { k, v } => {
                        k_vec.push(k.all_data.clone().unwrap());
                        v_vec.push(v.all_data.clone().unwrap());
                    }
                    KvCache::Rotating { k, v } => {
                        k_vec.push(k.all_data.clone().unwrap());
                        v_vec.push(v.all_data.clone().unwrap());
                    }
                }
            }
            new_k_cache.push(Some(if k_vec.len() > 1 {
                Tensor::cat(&k_vec, 0).unwrap()
            } else {
                k_vec[0].clone()
            }));
            new_v_cache.push(Some(if v_vec.len() > 1 {
                Tensor::cat(&v_vec, 0).unwrap()
            } else {
                v_vec[0].clone()
            }));
        }

        let seq0_cache = if modify_draft_cache {
            &*seqs[0].normal_draft_cache()
        } else {
            &*seqs[0].normal_cache()
        };

        let mut caches = Vec::new();
        for (layer_idx, (k_cache, v_cache)) in new_k_cache.into_iter().zip(new_v_cache).enumerate()
        {
            // Use the first sequence's cache as the template for the per-layer metadata.
            match seq0_cache[layer_idx].as_ref().unwrap() {
                KvCache::Normal { k: old_k, .. } => {
                    let template_cache_dim = old_k.dim;
                    let template_cache_csl = old_k.current_seq_len;
                    let template_cache_msl = old_k.max_seq_len;
                    let template_cache_capsl = old_k.capacity_seq_len;

                    caches.push(KvCache::Normal {
                        k: SingleCache {
                            all_data: k_cache.map(|x| x.contiguous().unwrap()),
                            dim: template_cache_dim,
                            current_seq_len: template_cache_csl,
                            max_seq_len: template_cache_msl,
                            capacity_seq_len: template_cache_capsl,
                        },
                        v: SingleCache {
                            all_data: v_cache.map(|x| x.contiguous().unwrap()),
                            dim: template_cache_dim,
                            current_seq_len: template_cache_csl,
                            max_seq_len: template_cache_msl,
                            capacity_seq_len: template_cache_capsl,
                        },
                    });
                }
                KvCache::Rotating { k: old_k, .. } => {
                    let template_cache_dim = old_k.dim;
                    let template_cache_csl = old_k.current_seq_len;
                    let template_cache_msl = old_k.max_seq_len;
                    let template_cache_offset = old_k.offset;
                    let template_cache_capsl = old_k.capacity_seq_len;

                    caches.push(KvCache::Rotating {
                        k: RotatingCache {
                            all_data: k_cache.map(|x| x.contiguous().unwrap()),
                            dim: template_cache_dim,
                            current_seq_len: template_cache_csl,
                            max_seq_len: template_cache_msl,
                            offset: template_cache_offset,
                            capacity_seq_len: template_cache_capsl,
                        },
                        v: RotatingCache {
                            all_data: v_cache.map(|x| x.contiguous().unwrap()),
                            dim: template_cache_dim,
                            current_seq_len: template_cache_csl,
                            max_seq_len: template_cache_msl,
                            offset: template_cache_offset,
                            capacity_seq_len: template_cache_capsl,
                        },
                    });
                }
            }
        }
        *pipeline.cache().normal() = NormalCache(caches);
    }
    fn clone_out_cache(&self, pipeline: &T, seqs: &mut [&mut Sequence], modify_draft_cache: bool) {
        let all_cache = pipeline.cache().normal();
        for layer in 0..pipeline.get_metadata().num_hidden_layers {
            let cache = all_cache.0.get(layer).unwrap();
            if cache.k().unwrap().is_none() {
                continue;
            }

            let (k_cache, v_cache) = match cache {
                KvCache::Normal { k, v } => {
                    (k.all_data.clone().unwrap(), v.all_data.clone().unwrap())
                }
                KvCache::Rotating { k, v } => {
                    (k.all_data.clone().unwrap(), v.all_data.clone().unwrap())
                }
            };

            let k_caches = k_cache.chunk(seqs.len(), 0).unwrap();
            debug_assert_eq!(k_caches.len(), seqs.len());
            let v_caches = v_cache.chunk(seqs.len(), 0).unwrap();
            debug_assert_eq!(v_caches.len(), seqs.len());

            for (seq_i, seq) in seqs.iter_mut().enumerate() {
                let output_cache = if modify_draft_cache {
                    seq.normal_draft_cache()
                } else {
                    seq.normal_cache()
                };
                let seq_cache = &mut output_cache[layer];
                let k = k_caches.get(seq_i).unwrap().clone();
                let v = v_caches.get(seq_i).unwrap().clone();

                match cache {
                    KvCache::Normal {
                        k: cache_k,
                        v: cache_v,
                    } => {
                        *seq_cache = Some(KvCache::Normal {
                            k: SingleCache {
                                all_data: Some(k),
                                dim: cache_k.dim,
                                current_seq_len: cache_k.current_seq_len,
                                max_seq_len: cache_k.max_seq_len,
                                capacity_seq_len: cache_k.capacity_seq_len,
                            },
                            v: SingleCache {
                                all_data: Some(v),
                                dim: cache_v.dim,
                                current_seq_len: cache_v.current_seq_len,
                                max_seq_len: cache_v.max_seq_len,
                                capacity_seq_len: cache_v.capacity_seq_len,
                            },
                        });
                    }
                    KvCache::Rotating {
                        k: cache_k,
                        v: cache_v,
                    } => {
                        *seq_cache = Some(KvCache::Rotating {
                            k: RotatingCache {
                                all_data: Some(k),
                                dim: cache_k.dim,
                                current_seq_len: cache_k.current_seq_len,
                                max_seq_len: cache_k.max_seq_len,
                                offset: cache_k.offset,
                                capacity_seq_len: cache_k.capacity_seq_len,
                            },
                            v: RotatingCache {
                                all_data: Some(v),
                                dim: cache_v.dim,
                                current_seq_len: cache_v.current_seq_len,
                                max_seq_len: cache_v.max_seq_len,
                                offset: cache_v.offset,
                                capacity_seq_len: cache_v.capacity_seq_len,
                            },
                        });
                    }
                }
            }
        }
    }
    fn set_none_cache(
        &self,
        pipeline: &T,
        seqs: &mut [&mut Sequence],
        _modify_draft_cache: bool,
        load_preallocated_cache: bool,
    ) {
        if seqs.iter().any(|seq| seq.preallocated_cache().is_none()) {
            for layer in pipeline.cache().normal().0.iter_mut() {
                layer.reset();
            }
            return;
        }

        let layer_devices = if let Some(device_mapper) = pipeline.device_mapper() {
            let mut layer_devices = Vec::new();
            for layer in 0..device_mapper.num_device_mapping_layers() {
                let device = device_mapper.device_for(layer, false).cloned();
                layer_devices.push(device.expect("Internal bug, layer out of range!"));
            }
            Some(layer_devices)
        } else {
            None
        };

        let old_caches = pipeline.cache().normal().0.clone();

        for (layer_idx, layer) in pipeline.cache().normal().0.iter_mut().enumerate() {
            if !load_preallocated_cache {
                layer.reset();
                continue;
            }

            let mut k_caches = Vec::new();
            let mut v_caches = Vec::new();
            for seq in seqs.iter_mut() {
                let (mut k_preallocated_cache, mut v_preallocated_cache) =
                    (*seq.preallocated_cache().as_ref().unwrap()).clone();
                if let Some(layer_devices) = &layer_devices {
                    let layer_dev = &layer_devices[layer_idx];
                    k_preallocated_cache = k_preallocated_cache
                        .to_device(layer_dev)
                        .expect("Could not prepare cache");
                    v_preallocated_cache = v_preallocated_cache
                        .to_device(layer_dev)
                        .expect("Could not prepare cache");
                }
                k_caches.push(k_preallocated_cache);
                v_caches.push(v_preallocated_cache);
            }
            let k_cache = if k_caches.len() > 1 {
                Tensor::cat(&k_caches, 0).unwrap()
            } else {
                k_caches[0].clone()
            };
            let v_cache = if v_caches.len() > 1 {
                Tensor::cat(&v_caches, 0).unwrap()
            } else {
                v_caches[0].clone()
            };

            match &old_caches[layer_idx] {
                KvCache::Normal { k, .. } => {
                    let template_cache_dim = k.dim;
                    let template_cache_msl = k.max_seq_len;

                    let cache = KvCache::Normal {
                        k: SingleCache {
                            all_data: Some(k_cache.zeros_like().unwrap()),
                            dim: template_cache_dim,
                            current_seq_len: 0,
                            max_seq_len: template_cache_msl,
                            capacity_seq_len: k_cache.dims()[template_cache_dim],
                        },
                        v: SingleCache {
                            all_data: Some(v_cache.zeros_like().unwrap()),
                            dim: template_cache_dim,
                            current_seq_len: 0,
                            max_seq_len: template_cache_msl,
                            capacity_seq_len: k_cache.dims()[template_cache_dim],
                        },
                    };
                    *layer = cache;
                }
                KvCache::Rotating { k, .. } => {
                    let template_cache_dim = k.dim;
                    let template_cache_msl = k.max_seq_len;

                    // Rotating caches do not reuse the preallocated buffer; they are
                    // reallocated lazily on the first append.
                    let cache = KvCache::Rotating {
                        k: RotatingCache {
                            all_data: None,
                            dim: template_cache_dim,
                            current_seq_len: 0,
                            max_seq_len: template_cache_msl,
                            offset: 0,
                            capacity_seq_len: 0,
                        },
                        v: RotatingCache {
                            all_data: None,
                            dim: template_cache_dim,
                            current_seq_len: 0,
                            max_seq_len: template_cache_msl,
                            offset: 0,
                            capacity_seq_len: 0,
                        },
                    };
                    *layer = cache;
                }
            }
        }
    }
}

#[derive(Debug, Clone)]
pub struct Cache {
    cache: Arc<Mutex<LayerCaches>>,
    xlora_cache: Option<Arc<Mutex<LayerCaches>>>,
    draft_cache: Arc<Mutex<LayerCaches>>,
    scalings_cache: Option<Arc<Mutex<Option<Tensor>>>>,
}

impl Cache {
    pub(crate) fn new(len: usize, is_xlora: bool) -> Self {
        Self {
            cache: Arc::new(Mutex::new(vec![None; len])),
            xlora_cache: if is_xlora {
                Some(Arc::new(Mutex::new(vec![None; len])))
            } else {
                None
            },
            draft_cache: Arc::new(Mutex::new(vec![None; len])),
            scalings_cache: if is_xlora {
                Some(Arc::new(Mutex::new(None)))
            } else {
                None
            },
        }
    }

    pub(crate) fn lock(&self) -> MutexGuard<'_, LayerCaches> {
        get_mut_arcmutex!(self.cache)
    }

    pub(crate) fn draft_lock(&self) -> MutexGuard<'_, LayerCaches> {
        get_mut_arcmutex!(self.draft_cache)
    }

    /// # Panics
    /// Panics if there is no X-LoRA cache.
    pub(crate) fn xlora_lock(&self) -> MutexGuard<'_, LayerCaches> {
        get_mut_arcmutex!(self.xlora_cache.as_ref().expect("No X-LoRA cache."))
    }

    /// # Panics
    /// Panics if there is no X-LoRA scalings cache.
    pub(crate) fn get_scalings_cache(&self) -> MutexGuard<'_, Option<Tensor>> {
        get_mut_arcmutex!(self
            .scalings_cache
            .as_ref()
            .expect("No X-LoRA scalings cache."))
    }

    pub(crate) fn is_xlora(&self) -> bool {
        self.xlora_cache.is_some()
    }

    /// Update a layer's KV cache by concatenating the new k/v along the sequence dim
    /// (dim 2), returning contiguous views of the full cache.
    pub(crate) fn update_kv_cache(
        cache: &mut Option<(Tensor, Tensor)>,
        k: Tensor,
        v: Tensor,
        slow_cat: bool,
    ) -> Result<(Tensor, Tensor)> {
        let (k, v) = match &*cache {
            None => (k, v),
            Some((k_cache, v_cache)) => {
                if !slow_cat {
                    let k = candle_nn::ops::kvconcat(k_cache, &k, 2)?.contiguous()?;
                    let v = candle_nn::ops::kvconcat(v_cache, &v, 2)?.contiguous()?;
                    (k, v)
                } else {
                    let k = Tensor::cat(&[k_cache, &k], 2)?.contiguous()?;
                    let v = Tensor::cat(&[v_cache, &v], 2)?.contiguous()?;
                    (k, v)
                }
            }
        };
        *cache = Some((k.clone(), v.clone()));
        Ok((k.contiguous()?, v.contiguous()?))
    }

    /// Like `update_kv_cache`, but additionally truncates the cache (and adjusts the
    /// attention mask) to a sliding window, if one is given.
    pub(crate) fn update_kv_cache_sliding_window(
        cache: &mut Option<(Tensor, Tensor)>,
        k: Tensor,
        v: Tensor,
        attention_mask: Option<&Tensor>,
        sliding_window: Option<usize>,
        slow_cat: bool,
    ) -> Result<(Tensor, Tensor, Option<Tensor>)> {
        let (k, v, attention_mask) = match cache.clone() {
            None => (k, v, attention_mask.cloned()),
            Some((mut prev_k, mut prev_v)) => {
                let mut mask = attention_mask.cloned();
                if let Some(sliding_window) = sliding_window {
                    let kv_seq_len = prev_k.dim(2)?;
                    if kv_seq_len > sliding_window {
                        prev_k = prev_k.narrow(
                            2,
                            kv_seq_len - (sliding_window - 1),
                            sliding_window - 1,
                        )?;
                        prev_v = prev_v.narrow(
                            2,
                            kv_seq_len - (sliding_window - 1),
                            sliding_window - 1,
                        )?;
                        if let Some(ref mut mask) = mask {
                            let mask_len = mask.dim(1)?;
                            *mask = mask.narrow(
                                1,
                                mask_len - (sliding_window - 1),
                                sliding_window - 1,
                            )?;
                            *mask = Tensor::cat(
                                &[&*mask, &mask.narrow(1, mask_len - 1, 1)?.ones_like()?],
                                D::Minus1,
                            )?;
                        }
                    }
                }
                let (k, v) = if !slow_cat {
                    let k = candle_nn::ops::kvconcat(&prev_k, &k, 2)?;
                    let v = candle_nn::ops::kvconcat(&prev_v, &v, 2)?;
                    (k, v)
                } else {
                    let k = Tensor::cat(&[prev_k, k], 2)?.contiguous()?;
                    let v = Tensor::cat(&[prev_v, v], 2)?.contiguous()?;
                    (k, v)
                };
                (k, v, mask)
            }
        };
        *cache = Some((k.clone(), v.clone()));
        Ok((k.contiguous()?, v.contiguous()?, attention_mask))
    }
}

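// A minimal sketch (not part of the upstream file) of the stateless helper used by the full
// cache path: the per-layer cache is just an `Option<(Tensor, Tensor)>` that grows by
// concatenation along the sequence dim (dim 2).
#[cfg(test)]
mod full_cache_sketch {
    use super::*;
    use candle_core::{DType, Device};

    #[test]
    fn grows_by_concat() -> Result<()> {
        let mut layer_cache: Option<(Tensor, Tensor)> = None;

        let k = Tensor::zeros((1, 2, 3, 8), DType::F32, &Device::Cpu)?;
        let v = Tensor::zeros((1, 2, 3, 8), DType::F32, &Device::Cpu)?;
        // `slow_cat = true` forces the plain `Tensor::cat` path, which is always available.
        let (k_all, _v_all) = Cache::update_kv_cache(&mut layer_cache, k, v, true)?;
        assert_eq!(k_all.dims()[2], 3);

        let k = Tensor::zeros((1, 2, 1, 8), DType::F32, &Device::Cpu)?;
        let v = Tensor::zeros((1, 2, 1, 8), DType::F32, &Device::Cpu)?;
        let (k_all, _v_all) = Cache::update_kv_cache(&mut layer_cache, k, v, true)?;
        assert_eq!(k_all.dims()[2], 4);
        Ok(())
    }
}
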
pub struct FullCacheManager;

enum SeqCache {
    Normal,
    XLora,
    Draft,
}

fn clone_in_cache(
    num_hidden_layers: usize,
    cache: &mut LayerCaches,
    seqs: &mut [&mut crate::sequence::Sequence],
    src: SeqCache,
) {
    let mut new_cache = Vec::new();
    'outer: for layer in 0..num_hidden_layers {
        let mut k_vec = Vec::new();
        let mut v_vec = Vec::new();
        for seq in &mut *seqs {
            let src_cache = match src {
                SeqCache::Normal => seq.cache(),
                SeqCache::XLora => seq.xlora_cache(),
                SeqCache::Draft => seq.draft_cache(),
            };
            let cache = src_cache.get(layer).unwrap();
            if cache.is_none() {
                new_cache.push(None);
                continue 'outer;
            }
            let cache = cache
                .as_ref()
                .expect("Not handling completions in `clone_in_cache`.");
            k_vec.push(cache.0.clone());
            v_vec.push(cache.1.clone());
        }
        new_cache.push(Some((
            if k_vec.len() > 1 {
                Tensor::cat(&k_vec, 0).unwrap()
            } else {
                k_vec[0].clone()
            },
            if v_vec.len() > 1 {
                Tensor::cat(&v_vec, 0).unwrap()
            } else {
                v_vec[0].clone()
            },
        )));
    }
    *cache = new_cache;
}

fn clone_out_cache(
    num_hidden_layers: usize,
    cache: &mut LayerCaches,
    seqs: &mut [&mut crate::sequence::Sequence],
    target: SeqCache,
) {
    for layer in 0..num_hidden_layers {
        let cache = cache.get(layer).unwrap();
        if cache.is_none() {
            continue;
        }

        let k_cache = cache.as_ref().unwrap().0.clone();
        let v_cache = cache.as_ref().unwrap().1.clone();

        let k_caches = k_cache.chunk(seqs.len(), 0).unwrap();
        debug_assert_eq!(k_caches.len(), seqs.len());
        let v_caches = v_cache.chunk(seqs.len(), 0).unwrap();
        debug_assert_eq!(v_caches.len(), seqs.len());

        for (seq_i, seq) in seqs.iter_mut().enumerate() {
            let output_cache = match target {
                SeqCache::Normal => seq.cache(),
                SeqCache::XLora => seq.xlora_cache(),
                SeqCache::Draft => seq.draft_cache(),
            };
            let seq_cache = &mut output_cache[layer];
            let k = k_caches.get(seq_i).unwrap().clone();
            let v = v_caches.get(seq_i).unwrap().clone();
            *seq_cache = Some((k, v));
        }
    }
}

impl<T: CacheManagerMixin + MetadataMixin + ?Sized> CacheManager<T> for FullCacheManager {
    fn clone_in_cache(
        &self,
        pipeline: &T,
        seqs: &mut [&mut crate::sequence::Sequence],
        modify_draft_cache: bool,
    ) {
        if modify_draft_cache {
            clone_in_cache(
                pipeline.get_metadata().num_hidden_layers,
                &mut pipeline.cache().full().lock(),
                seqs,
                SeqCache::Draft,
            );
            return;
        }
        clone_in_cache(
            pipeline.get_metadata().num_hidden_layers,
            &mut pipeline.cache().full().lock(),
            seqs,
            SeqCache::Normal,
        );
        if pipeline.get_metadata().is_xlora && !pipeline.get_metadata().no_kv_cache {
            clone_in_cache(
                pipeline.get_metadata().num_hidden_layers,
                &mut pipeline.cache().full().xlora_lock(),
                seqs,
                SeqCache::XLora,
            );
        }
        if pipeline.get_metadata().is_xlora {
            pipeline
                .cache()
                .full()
                .get_scalings_cache()
                .clone_from(seqs[0].scaling_cache());
        }
    }

    fn clone_out_cache(
        &self,
        pipeline: &T,
        seqs: &mut [&mut crate::sequence::Sequence],
        modify_draft_cache: bool,
    ) {
        if modify_draft_cache {
            clone_out_cache(
                pipeline.get_metadata().num_hidden_layers,
                &mut pipeline.cache().full().lock(),
                seqs,
                SeqCache::Draft,
            );
            return;
        }
        clone_out_cache(
            pipeline.get_metadata().num_hidden_layers,
            &mut pipeline.cache().full().lock(),
            seqs,
            SeqCache::Normal,
        );
        if pipeline.get_metadata().is_xlora && !pipeline.get_metadata().no_kv_cache {
            clone_out_cache(
                pipeline.get_metadata().num_hidden_layers,
                &mut pipeline.cache().full().xlora_lock(),
                seqs,
                SeqCache::XLora,
            );
        }
        if pipeline.get_metadata().is_xlora {
            seqs[0]
                .scaling_cache()
                .clone_from(&pipeline.cache().full().get_scalings_cache());
        }
    }

    fn set_none_cache(
        &self,
        pipeline: &T,
        _seqs: &mut [&mut Sequence],
        modify_draft_cache: bool,
        _load_preallocated_cache: bool,
    ) {
        let mut new_cache = Vec::new();
        for _ in 0..pipeline.get_metadata().num_hidden_layers {
            new_cache.push(None);
        }
        pipeline.cache().full().lock().clone_from(&new_cache);
        if modify_draft_cache {
            pipeline.cache().full().draft_lock().clone_from(&new_cache);
        }
        if pipeline.cache().full().is_xlora() {
            *pipeline.cache().full().xlora_lock() = new_cache;
        }
    }
}