use std::sync::{Arc, Mutex, MutexGuard};
use candle_core::{Result, Tensor, D};
use crate::{get_mut_arcmutex, sequence::Sequence};
use super::{CacheManagerMixin, MetadataMixin};
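/// Moves KV caches between the pipeline and its sequences: `clone_in_cache`
/// loads the scheduled sequences' caches into the pipeline before a forward
/// pass, `clone_out_cache` writes them back afterwards, and `set_none_cache`
/// clears the pipeline cache (optionally reloading preallocated buffers).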
pub trait CacheManager<T: CacheManagerMixin + MetadataMixin + ?Sized> {
fn clone_in_cache(
&self,
pipeline: &T,
seqs: &mut [&mut crate::sequence::Sequence],
modify_draft_cache: bool,
);
fn clone_out_cache(&self, pipeline: &T, seqs: &mut [&mut Sequence], modify_draft_cache: bool);
fn set_none_cache(
&self,
pipeline: &T,
seqs: &mut [&mut Sequence],
modify_draft_cache: bool,
load_preallocated_cache: bool,
);
}
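/// One optional (key, value) tensor pair per hidden layer.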
pub type LayerCaches = Vec<Option<(Tensor, Tensor)>>;
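/// Either the preallocated [`NormalCache`] or the concatenation-based
/// [`Cache`]. The `normal()` and `full()` accessors panic when the variant
/// does not match, so callers must know which kind their pipeline uses.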
#[derive(Debug, Clone)]
pub enum EitherCache {
Normal(Arc<Mutex<NormalCache>>),
Full(Cache),
}
impl EitherCache {
pub fn full(&self) -> &Cache {
match self {
Self::Full(full) => full,
Self::Normal(_) => panic!("Got normal cache, expected full cache."),
}
}
pub fn normal(&self) -> MutexGuard<'_, NormalCache> {
match self {
Self::Normal(normal) => normal.lock().unwrap(),
Self::Full(_) => panic!("Got full cache, expected normal cache."),
}
}
}
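/// One half (K or V) of a layer's cache, backed by a single preallocated
/// tensor. `all_data` holds `capacity_seq_len` slots along `dim`, of which
/// the first `current_seq_len` are valid; capacity grows in
/// `NormalCache::CACHE_GROW_SIZE` blocks, never exceeding `max_seq_len`.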
#[derive(Debug, Clone)]
pub struct SingleCache {
pub all_data: Option<Tensor>,
pub dim: usize,
pub current_seq_len: usize,
pub capacity_seq_len: usize,
pub max_seq_len: usize,
}
impl SingleCache {
pub fn new(dim: usize, max_seq_len: usize, capacity_seq_len: usize) -> Self {
Self {
all_data: None,
dim,
current_seq_len: 0,
max_seq_len,
capacity_seq_len,
}
}
pub fn dim(&self) -> usize {
self.dim
}
pub fn current_seq_len(&self) -> usize {
self.current_seq_len
}
pub fn max_seq_len(&self) -> usize {
self.max_seq_len
}
pub fn all_data(&self) -> &Option<Tensor> {
&self.all_data
}
pub fn current_data(&self) -> Result<Option<Tensor>> {
let data = match self.all_data.as_ref() {
None => None,
Some(d) => Some(d.narrow(self.dim, 0, self.current_seq_len)?),
};
Ok(data)
}
pub fn reset(&mut self) {
self.current_seq_len = 0;
self.all_data = None;
}
pub fn set_len(&mut self, len: usize) {
self.current_seq_len = len;
}
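    /// Copy `src` into the cache at offset `current_seq_len` along `dim`,
    /// allocating or growing the backing tensor as needed. Fails if the
    /// grown capacity would exceed `max_seq_len`.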
    pub fn append(&mut self, src: &Tensor) -> Result<()> {
        let seq_len = src.dim(self.dim)?;
        // Lazily allocate the backing buffer at the current capacity.
        if self.all_data.is_none() {
            let mut shape = src.dims().to_vec();
            shape[self.dim] = self.capacity_seq_len;
            let ad = Tensor::zeros(shape, src.dtype(), src.device())?;
            self.all_data = Some(ad);
        }
        // Grow by whole CACHE_GROW_SIZE blocks when the incoming tokens would
        // overflow the current capacity, copying the existing data over.
        if self.current_seq_len + seq_len > self.capacity_seq_len {
            let diff = self.current_seq_len + seq_len - self.capacity_seq_len;
            let n_blocks_needed = diff.div_ceil(NormalCache::CACHE_GROW_SIZE);
            self.capacity_seq_len += n_blocks_needed * NormalCache::CACHE_GROW_SIZE;
            if self.capacity_seq_len > self.max_seq_len {
                candle_core::bail!(
                    "kv-cache: requested capacity ({}) above max seq len ({})",
                    self.capacity_seq_len,
                    self.max_seq_len
                )
            }
            let mut shape = src.dims().to_vec();
            shape[self.dim] = self.capacity_seq_len;
            let ad = Tensor::zeros(shape, src.dtype(), src.device())?;
            ad.slice_set(self.all_data.as_ref().unwrap(), self.dim, 0)?;
            self.all_data = Some(ad);
        }
        // Write the new tokens in place and advance the valid length.
        let ad = self.all_data.as_mut().unwrap();
        ad.slice_set(src, self.dim, self.current_seq_len)?;
        self.current_seq_len += seq_len;
        Ok(())
    }
}
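/// Paired key and value caches for one attention layer.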
#[derive(Debug, Clone)]
pub struct KvCache {
pub k: SingleCache,
pub v: SingleCache,
}
impl KvCache {
pub fn new(dim: usize, max_seq_len: usize, capacity_seq_len: usize) -> Self {
let k = SingleCache::new(dim, max_seq_len, capacity_seq_len);
let v = SingleCache::new(dim, max_seq_len, capacity_seq_len);
Self { k, v }
}
pub fn k_cache(&self) -> &SingleCache {
&self.k
}
pub fn v_cache(&self) -> &SingleCache {
&self.v
}
pub fn k_cache_mut(&mut self) -> &mut SingleCache {
&mut self.k
}
pub fn v_cache_mut(&mut self) -> &mut SingleCache {
&mut self.v
}
pub fn k(&self) -> Result<Option<Tensor>> {
self.k.current_data()
}
pub fn v(&self) -> Result<Option<Tensor>> {
self.v.current_data()
}
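    /// Append `k`/`v` via [`Self::append`], then, if the cached length
    /// exceeds `sliding_window`, trim the returned K/V (and `mask`, when
    /// present) to the most recent `sliding_window - 1` positions.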
pub fn append_sliding_window(
&mut self,
k: &Tensor,
v: &Tensor,
mask: Option<&Tensor>,
sliding_window: Option<usize>,
) -> Result<(Tensor, Tensor, Option<Tensor>)> {
let (mut k, mut v) = self.append(k, v)?;
if let Some(sliding_window) = sliding_window {
assert_eq!(self.k.dim, 2);
let kv_seq_len = k.dim(2)?;
if kv_seq_len > sliding_window {
k = k.narrow(2, kv_seq_len - (sliding_window - 1), sliding_window - 1)?;
v = v.narrow(2, kv_seq_len - (sliding_window - 1), sliding_window - 1)?;
if let Some(mut mask) = mask.cloned() {
let mask_len = mask.dim(1)?;
mask = mask.narrow(1, mask_len - (sliding_window - 1), sliding_window - 1)?;
return Ok((k, v, Some(mask)));
}
}
}
Ok((k, v, mask.cloned()))
}
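    /// Append `k` and `v` and return the cached contents up to the new
    /// sequence length, falling back to zero-length tensors if nothing has
    /// been cached.
    ///
    /// A minimal usage sketch (shapes, dtype, and device are illustrative
    /// assumptions, not requirements of this API):
    ///
    /// ```ignore
    /// use candle_core::{DType, Device, Tensor};
    /// // The sequence dimension is 2; allow up to 4096 tokens, starting
    /// // with one grow-block of preallocated capacity.
    /// let mut cache = KvCache::new(2, 4096, NormalCache::CACHE_GROW_SIZE);
    /// let k = Tensor::zeros((1, 8, 1, 64), DType::F32, &Device::Cpu)?;
    /// let v = Tensor::zeros((1, 8, 1, 64), DType::F32, &Device::Cpu)?;
    /// let (k_all, v_all) = cache.append(&k, &v)?;
    /// assert_eq!(k_all.dim(2)?, 1);
    /// ```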
pub fn append(&mut self, k: &Tensor, v: &Tensor) -> Result<(Tensor, Tensor)> {
let k = k.contiguous()?;
let v = v.contiguous()?;
self.k.append(&k)?;
self.v.append(&v)?;
let out_k = self.k.current_data()?;
let out_v = self.v.current_data()?;
let k = match out_k {
None => {
let mut shape = k.dims().to_vec();
shape[self.k.dim] = 0;
Tensor::zeros(shape, k.dtype(), k.device())?
}
Some(k) => k,
};
let v = match out_v {
None => {
let mut shape = v.dims().to_vec();
shape[self.k.dim] = 0;
Tensor::zeros(shape, v.dtype(), v.device())?
}
Some(v) => v,
};
Ok((k, v))
}
pub fn current_seq_len(&self) -> usize {
self.k.current_seq_len()
}
pub fn reset(&mut self) {
self.k.reset();
self.v.reset();
}
pub fn set_len(&mut self, len: usize) {
self.k.set_len(len);
self.v.set_len(len);
}
}
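/// All per-layer [`KvCache`]s for a model using the preallocated cache.
/// Buffers start at `CACHE_GROW_SIZE` tokens of capacity and grow in
/// `CACHE_GROW_SIZE` steps as sequences lengthen.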
#[derive(Debug, Clone)]
pub struct NormalCache(pub Vec<KvCache>);
impl NormalCache {
pub const CACHE_GROW_SIZE: usize = 512;
    pub fn new(len: usize, max_seq_len: usize) -> Arc<Mutex<Self>> {
        Arc::new(Mutex::new(Self(vec![
            KvCache::new(2, max_seq_len, Self::CACHE_GROW_SIZE);
            len
        ])))
    }
}
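/// A [`CacheManager`] for pipelines using [`NormalCache`]. Batching relies on
/// concatenating each sequence's cache along dim 0 on clone-in and splitting
/// the batch back apart with `chunk` on clone-out. A minimal sketch of that
/// round trip (shapes, dtype, and device are illustrative assumptions):
///
/// ```ignore
/// use candle_core::{DType, Device, Tensor};
/// let a = Tensor::zeros((1, 8, 16, 64), DType::F32, &Device::Cpu)?;
/// let b = Tensor::zeros((1, 8, 16, 64), DType::F32, &Device::Cpu)?;
/// let batched = Tensor::cat(&[&a, &b], 0)?; // clone-in: batch per-seq caches
/// let split = batched.chunk(2, 0)?; // clone-out: one chunk per sequence
/// assert_eq!(split.len(), 2);
/// ```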
pub struct NormalCacheManager;
impl<T: CacheManagerMixin + MetadataMixin + ?Sized> CacheManager<T> for NormalCacheManager {
fn clone_in_cache(
&self,
pipeline: &T,
seqs: &mut [&mut crate::sequence::Sequence],
_modify_draft_cache: bool,
) {
let mut new_k_cache = Vec::new();
let mut new_v_cache = Vec::new();
        let (template_cache_dim, template_cache_csl, template_cache_msl, template_cache_capsl) = {
            let k = &seqs[0].normal_cache()[0].as_ref().unwrap().k;
            (k.dim, k.current_seq_len, k.max_seq_len, k.capacity_seq_len)
        };
        // For each layer, gather every sequence's cache; if any sequence has
        // no cache for the layer yet, record it as empty and move on.
        'outer: for layer in 0..pipeline.get_metadata().num_hidden_layers {
let mut k_vec = Vec::new();
let mut v_vec = Vec::new();
for seq in &mut *seqs {
let src_cache = seq.normal_cache();
let cache = src_cache.get(layer).unwrap();
if cache.is_none() {
new_k_cache.push(None);
new_v_cache.push(None);
continue 'outer;
}
let cache = cache
.as_ref()
.expect("Not handling completions in `clone_in_cache`.");
k_vec.push(cache.k.all_data.clone().unwrap());
v_vec.push(cache.v.all_data.clone().unwrap());
}
            // Batch the per-sequence caches for this layer along dim 0.
            new_k_cache.push(Some(if k_vec.len() > 1 {
Tensor::cat(&k_vec, 0).unwrap()
} else {
k_vec[0].clone()
}));
new_v_cache.push(Some(if v_vec.len() > 1 {
Tensor::cat(&v_vec, 0).unwrap()
} else {
v_vec[0].clone()
}));
}
let mut caches = Vec::new();
for (k_cache, v_cache) in new_k_cache.into_iter().zip(new_v_cache) {
caches.push(KvCache {
k: SingleCache {
all_data: k_cache.map(|x| x.contiguous().unwrap()),
dim: template_cache_dim,
current_seq_len: template_cache_csl,
max_seq_len: template_cache_msl,
capacity_seq_len: template_cache_capsl,
},
v: SingleCache {
all_data: v_cache.map(|x| x.contiguous().unwrap()),
dim: template_cache_dim,
current_seq_len: template_cache_csl,
max_seq_len: template_cache_msl,
capacity_seq_len: template_cache_capsl,
},
});
}
*pipeline.cache().normal() = NormalCache(caches);
}
fn clone_out_cache(&self, pipeline: &T, seqs: &mut [&mut Sequence], _modify_draft_cache: bool) {
let all_cache = pipeline.cache().normal();
for layer in 0..pipeline.get_metadata().num_hidden_layers {
let cache = all_cache.0.get(layer).unwrap();
if cache.k().unwrap().is_none() {
continue;
}
let k_cache = cache.k.all_data.clone().unwrap();
let v_cache = cache.v.all_data.clone().unwrap();
            // Split the batched caches back into one chunk per sequence along dim 0.
            let k_caches = k_cache.chunk(seqs.len(), 0).unwrap();
debug_assert_eq!(k_caches.len(), seqs.len());
let v_caches = v_cache.chunk(seqs.len(), 0).unwrap();
debug_assert_eq!(v_caches.len(), seqs.len());
for (seq_i, seq) in seqs.iter_mut().enumerate() {
let output_cache = seq.normal_cache();
let seq_cache = &mut output_cache[layer];
let k = k_caches.get(seq_i).unwrap().clone();
let v = v_caches.get(seq_i).unwrap().clone();
*seq_cache = Some(KvCache {
k: SingleCache {
all_data: Some(k),
dim: cache.k.dim,
current_seq_len: cache.k.current_seq_len,
max_seq_len: cache.k.max_seq_len,
capacity_seq_len: cache.k.capacity_seq_len,
},
v: SingleCache {
all_data: Some(v),
dim: cache.v.dim,
current_seq_len: cache.v.current_seq_len,
max_seq_len: cache.v.max_seq_len,
capacity_seq_len: cache.v.capacity_seq_len,
},
});
}
}
}
fn set_none_cache(
&self,
pipeline: &T,
seqs: &mut [&mut Sequence],
_modify_draft_cache: bool,
load_preallocated_cache: bool,
) {
let template_cache_dim = pipeline.cache().normal().0[0].k.dim;
let template_cache_msl = pipeline.cache().normal().0[0].k.max_seq_len;
for layer in pipeline.cache().normal().0.iter_mut() {
            if !load_preallocated_cache {
                // With no preallocated buffers to restore, just drop the data.
                layer.reset();
                continue;
            }
let mut k_caches = Vec::new();
let mut v_caches = Vec::new();
for seq in seqs.iter_mut() {
let (k_preallocated_cache, v_preallocated_cache) =
(*seq.preallocated_cache().as_ref().unwrap()).clone();
k_caches.push(k_preallocated_cache);
v_caches.push(v_preallocated_cache);
}
let k_cache = if k_caches.len() > 1 {
Tensor::cat(&k_caches, 0).unwrap()
} else {
k_caches[0].clone()
};
let v_cache = if v_caches.len() > 1 {
Tensor::cat(&v_caches, 0).unwrap()
} else {
v_caches[0].clone()
};
let cache = KvCache {
k: SingleCache {
all_data: Some(k_cache.zeros_like().unwrap()),
dim: template_cache_dim,
current_seq_len: 0,
max_seq_len: template_cache_msl,
capacity_seq_len: k_cache.dims()[template_cache_dim],
},
v: SingleCache {
all_data: Some(v_cache.zeros_like().unwrap()),
dim: template_cache_dim,
current_seq_len: 0,
max_seq_len: template_cache_msl,
capacity_seq_len: k_cache.dims()[template_cache_dim],
},
};
*layer = cache;
}
}
}
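/// The "full" KV cache, grown by concatenating each step's K/V onto the
/// cached tensors. Alongside the main per-layer cache it holds a draft-model
/// cache and, for X-LoRA models, a second layer cache plus a scalings cache.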
#[derive(Debug, Clone)]
pub struct Cache {
cache: Arc<Mutex<LayerCaches>>,
xlora_cache: Option<Arc<Mutex<LayerCaches>>>,
draft_cache: Arc<Mutex<LayerCaches>>,
scalings_cache: Option<Arc<Mutex<Option<Tensor>>>>,
}
impl Cache {
pub(crate) fn new(len: usize, is_xlora: bool) -> Self {
Self {
cache: Arc::new(Mutex::new(vec![None; len])),
xlora_cache: if is_xlora {
Some(Arc::new(Mutex::new(vec![None; len])))
} else {
None
},
draft_cache: Arc::new(Mutex::new(vec![None; len])),
scalings_cache: if is_xlora {
Some(Arc::new(Mutex::new(None)))
} else {
None
},
}
}
pub(crate) fn lock(&self) -> MutexGuard<'_, LayerCaches> {
get_mut_arcmutex!(self.cache)
}
pub(crate) fn draft_lock(&self) -> MutexGuard<'_, LayerCaches> {
get_mut_arcmutex!(self.draft_cache)
}
pub(crate) fn xlora_lock(&self) -> MutexGuard<'_, LayerCaches> {
get_mut_arcmutex!(self.xlora_cache.as_ref().expect("No X-LoRA cache."))
}
pub(crate) fn get_scalings_cache(&self) -> MutexGuard<'_, Option<Tensor>> {
get_mut_arcmutex!(self
.scalings_cache
.as_ref()
.expect("No X-LoRA scalings cache."))
}
pub(crate) fn is_xlora(&self) -> bool {
self.xlora_cache.is_some()
}
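    /// Concatenate `k`/`v` onto this layer's cached tensors along dim 2 (the
    /// sequence dimension), store the result back into `cache`, and return
    /// contiguous copies. `slow_cat` selects plain `Tensor::cat` over the
    /// `candle_nn::ops::kvconcat` path.
    ///
    /// A minimal sketch of the call pattern (`k_step`/`v_step` stand in for
    /// single-step K/V projections and are assumptions of this example):
    ///
    /// ```ignore
    /// let mut layer_cache: Option<(Tensor, Tensor)> = None;
    /// // The first call passes through; later calls concatenate on dim 2.
    /// let (k, v) = Cache::update_kv_cache(&mut layer_cache, k_step, v_step, false)?;
    /// ```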
pub(crate) fn update_kv_cache(
cache: &mut Option<(Tensor, Tensor)>,
k: Tensor,
v: Tensor,
slow_cat: bool,
) -> Result<(Tensor, Tensor)> {
let (k, v) = match &*cache {
None => (k, v),
Some((k_cache, v_cache)) => {
if !slow_cat {
let k = candle_nn::ops::kvconcat(k_cache, &k, 2)?.contiguous()?;
let v = candle_nn::ops::kvconcat(v_cache, &v, 2)?.contiguous()?;
(k, v)
} else {
let k = Tensor::cat(&[k_cache, &k], 2)?.contiguous()?;
let v = Tensor::cat(&[v_cache, &v], 2)?.contiguous()?;
(k, v)
}
}
};
*cache = Some((k.clone(), v.clone()));
Ok((k.contiguous()?, v.contiguous()?))
}
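    /// Like [`Cache::update_kv_cache`], but when the cached length exceeds
    /// `sliding_window`, first trims the cached K/V to the most recent
    /// `sliding_window - 1` positions and adjusts the attention mask to
    /// match before concatenating the new step.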
pub(crate) fn update_kv_cache_sliding_window(
cache: &mut Option<(Tensor, Tensor)>,
k: Tensor,
v: Tensor,
attention_mask: Option<&Tensor>,
sliding_window: Option<usize>,
slow_cat: bool,
) -> Result<(Tensor, Tensor, Option<Tensor>)> {
let (k, v, attention_mask) = match cache.clone() {
None => (k, v, attention_mask.cloned()),
Some((mut prev_k, mut prev_v)) => {
let mut mask = attention_mask.cloned();
                if let Some(sliding_window) = sliding_window {
                    let kv_seq_len = prev_k.dim(2)?;
                    // Keep only the most recent `sliding_window - 1` cached
                    // positions before concatenating the new step.
                    if kv_seq_len > sliding_window {
prev_k = prev_k.narrow(
2,
kv_seq_len - (sliding_window - 1),
sliding_window - 1,
)?;
prev_v = prev_v.narrow(
2,
kv_seq_len - (sliding_window - 1),
sliding_window - 1,
)?;
if let Some(ref mut mask) = mask {
let mask_len = mask.dim(1)?;
*mask = mask.narrow(
1,
mask_len - (sliding_window - 1),
sliding_window - 1,
)?;
*mask = Tensor::cat(
&[&*mask, &mask.narrow(1, mask_len - 1, 1)?.ones_like()?],
D::Minus1,
)?;
}
}
}
let (k, v) = if !slow_cat {
let k = candle_nn::ops::kvconcat(&prev_k, &k, 2)?;
let v = candle_nn::ops::kvconcat(&prev_v, &v, 2)?;
(k, v)
} else {
let k = Tensor::cat(&[prev_k, k], 2)?.contiguous()?;
let v = Tensor::cat(&[prev_v, v], 2)?.contiguous()?;
(k, v)
};
(k, v, mask)
}
};
*cache = Some((k.clone(), v.clone()));
Ok((k.contiguous()?, v.contiguous()?, attention_mask))
}
}
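/// A [`CacheManager`] for pipelines using the full [`Cache`], including its
/// draft and X-LoRA variants.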
pub struct FullCacheManager;
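/// Selects which of a sequence's caches a helper reads from or writes into.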
enum SeqCache {
Normal,
XLora,
Draft,
}
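/// Gather the selected cache from every sequence, layer by layer, and
/// concatenate along dim 0 into the pipeline's `cache`.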
fn clone_in_cache(
num_hidden_layers: usize,
cache: &mut LayerCaches,
seqs: &mut [&mut crate::sequence::Sequence],
src: SeqCache,
) {
let mut new_cache = Vec::new();
'outer: for layer in 0..num_hidden_layers {
let mut k_vec = Vec::new();
let mut v_vec = Vec::new();
for seq in &mut *seqs {
let src_cache = match src {
SeqCache::Normal => seq.cache(),
SeqCache::XLora => seq.xlora_cache(),
SeqCache::Draft => seq.draft_cache(),
};
let cache = src_cache.get(layer).unwrap();
if cache.is_none() {
new_cache.push(None);
continue 'outer;
}
let cache = cache
.as_ref()
.expect("Not handling completions in `clone_in_cache`.");
k_vec.push(cache.0.clone());
v_vec.push(cache.1.clone());
}
new_cache.push(Some((
if k_vec.len() > 1 {
Tensor::cat(&k_vec, 0).unwrap()
} else {
k_vec[0].clone()
},
if v_vec.len() > 1 {
Tensor::cat(&v_vec, 0).unwrap()
} else {
v_vec[0].clone()
},
)));
}
*cache = new_cache;
}
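/// Split the pipeline's `cache` into per-sequence chunks along dim 0 and
/// write each chunk back into the selected cache of its sequence.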
fn clone_out_cache(
num_hidden_layers: usize,
cache: &mut LayerCaches,
seqs: &mut [&mut crate::sequence::Sequence],
target: SeqCache,
) {
for layer in 0..num_hidden_layers {
let cache = cache.get(layer).unwrap();
if cache.is_none() {
continue;
}
let k_cache = cache.as_ref().unwrap().0.clone();
let v_cache = cache.as_ref().unwrap().1.clone();
let k_caches = k_cache.chunk(seqs.len(), 0).unwrap();
debug_assert_eq!(k_caches.len(), seqs.len());
let v_caches = v_cache.chunk(seqs.len(), 0).unwrap();
debug_assert_eq!(v_caches.len(), seqs.len());
for (seq_i, seq) in seqs.iter_mut().enumerate() {
let output_cache = match target {
SeqCache::Normal => seq.cache(),
SeqCache::XLora => seq.xlora_cache(),
SeqCache::Draft => seq.draft_cache(),
};
let seq_cache = &mut output_cache[layer];
let k = k_caches.get(seq_i).unwrap().clone();
let v = v_caches.get(seq_i).unwrap().clone();
*seq_cache = Some((k, v));
}
}
}
impl<T: CacheManagerMixin + MetadataMixin + ?Sized> CacheManager<T> for FullCacheManager {
fn clone_in_cache(
&self,
pipeline: &T,
seqs: &mut [&mut crate::sequence::Sequence],
modify_draft_cache: bool,
) {
if modify_draft_cache {
clone_in_cache(
pipeline.get_metadata().num_hidden_layers,
&mut pipeline.cache().full().lock(),
seqs,
SeqCache::Draft,
);
return;
}
clone_in_cache(
pipeline.get_metadata().num_hidden_layers,
&mut pipeline.cache().full().lock(),
seqs,
SeqCache::Normal,
);
if pipeline.get_metadata().is_xlora && !pipeline.get_metadata().no_kv_cache {
clone_in_cache(
pipeline.get_metadata().num_hidden_layers,
&mut pipeline.cache().full().xlora_lock(),
seqs,
SeqCache::XLora,
);
}
if pipeline.get_metadata().is_xlora {
pipeline
.cache()
.full()
.get_scalings_cache()
.clone_from(seqs[0].scaling_cache());
}
}
fn clone_out_cache(
&self,
pipeline: &T,
seqs: &mut [&mut crate::sequence::Sequence],
modify_draft_cache: bool,
) {
if modify_draft_cache {
clone_out_cache(
pipeline.get_metadata().num_hidden_layers,
&mut pipeline.cache().full().lock(),
seqs,
SeqCache::Draft,
);
return;
}
clone_out_cache(
pipeline.get_metadata().num_hidden_layers,
&mut pipeline.cache().full().lock(),
seqs,
SeqCache::Normal,
);
if pipeline.get_metadata().is_xlora && !pipeline.get_metadata().no_kv_cache {
clone_out_cache(
pipeline.get_metadata().num_hidden_layers,
&mut pipeline.cache().full().xlora_lock(),
seqs,
SeqCache::XLora,
);
}
if pipeline.get_metadata().is_xlora {
seqs[0]
.scaling_cache()
.clone_from(&pipeline.cache().full().get_scalings_cache());
}
}
fn set_none_cache(
&self,
pipeline: &T,
_seqs: &mut [&mut Sequence],
modify_draft_cache: bool,
_load_preallocated_cache: bool,
) {
        let new_cache: LayerCaches = vec![None; pipeline.get_metadata().num_hidden_layers];
pipeline.cache().full().lock().clone_from(&new_cache);
if modify_draft_cache {
pipeline.cache().full().draft_lock().clone_from(&new_cache);
}
if pipeline.cache().full().is_xlora() {
*pipeline.cache().full().xlora_lock() = new_cache;
}
}
}