1use std::collections::HashMap;
2
3use candle_core::{Device, Result};
4use either::Either;
5use itertools::Itertools;
6use tracing::info;
7
8use crate::{
9 pipeline::{KvCache, RotatingCache, SingleCache},
10 sequence::Sequence,
11};
12
/// Token sequence used as the lookup key for stored prefix caches.
#[derive(Debug, PartialEq, Eq, Hash)]
struct Tokens(Vec<u32>);
15
16impl Tokens {
17 fn find_max_index(&self, x: &Self) -> Option<usize> {
19 self.0
20 .iter()
21 .zip(x.0.iter())
22 .enumerate()
23 .take_while(|(_, (a, b))| a == b)
24 .map(|(index, _)| index)
25 .last()
26 }
27}
28
29impl From<Vec<u32>> for Tokens {
30 fn from(value: Vec<u32>) -> Self {
31 Self(value)
32 }
33}
34
35#[derive(Clone)]
36struct CacheElement {
37 cache: Vec<Option<KvCache>>,
38 devices: Vec<Option<Device>>,
39}
40
41pub struct PrefixCacheManagerV2 {
42 caches: HashMap<Tokens, CacheElement>,
43 n_on_device: usize,
44 no_prefix_cache: bool,
45}
46
47#[derive(Clone)]
48pub struct MatchingCache {
49 pub normal: Vec<Option<KvCache>>,
50 pub toks: Vec<u32>,
51 pub offset: usize,
52}
53
54impl PrefixCacheManagerV2 {
55 pub fn new(n_on_device: usize, no_prefix_cache: bool) -> Self {
56 if !no_prefix_cache {
57 info!("PrefixCacherV2 is enabled! Expect higher multi-turn prompt throughput.");
58 }
59 PrefixCacheManagerV2 {
60 caches: HashMap::new(),
61 n_on_device,
62 no_prefix_cache,
63 }
64 }
65
66 pub fn add_sequence(&mut self, seq: &mut Sequence) {
68 if self.no_prefix_cache || seq.has_images() {
69 return;
70 }
71 let cache = seq.normal_cache().to_vec();
72 let devices = cache
73 .iter()
74 .map(|x| x.as_ref().map(|x| x.k().unwrap().unwrap().device().clone()))
75 .collect::<Vec<_>>();
76 self.caches.insert(
77 seq.get_toks().to_vec().into(),
78 CacheElement { cache, devices },
79 );
80 }
81
82 fn cache_to(
83 cache: &mut [Option<KvCache>],
84 devices: Either<&Device, &Vec<Option<Device>>>,
85 ) -> Result<()> {
86 for (i, layer) in cache
87 .iter_mut()
88 .enumerate()
89 .flat_map(|(i, x)| x.as_mut().map(|x| (i, x)))
90 {
91 let device = devices.left_or_else(|layers| layers[i].as_ref().unwrap());
92
93 match layer {
94 KvCache::Normal { k, v } => {
95 *layer = KvCache::Normal {
96 k: SingleCache {
97 all_data: k.all_data.as_ref().map(|x| x.to_device(device).unwrap()),
98 dim: k.dim,
99 current_seq_len: k.current_seq_len,
100 max_seq_len: k.max_seq_len,
101 capacity_seq_len: k.capacity_seq_len,
102 },
103 v: SingleCache {
104 all_data: v.all_data.as_ref().map(|x| x.to_device(device).unwrap()),
105 dim: v.dim,
106 current_seq_len: v.current_seq_len,
107 max_seq_len: v.max_seq_len,
108 capacity_seq_len: v.capacity_seq_len,
109 },
110 }
111 }
112 KvCache::Rotating { k, v } => {
113 *layer = KvCache::Rotating {
114 k: RotatingCache {
115 all_data: k.all_data.as_ref().map(|x| x.to_device(device).unwrap()),
116 dim: k.dim,
117 current_seq_len: k.current_seq_len,
118 max_seq_len: k.max_seq_len,
119 offset: k.offset,
120 capacity_seq_len: k.capacity_seq_len,
121 },
122 v: RotatingCache {
123 all_data: v.all_data.as_ref().map(|x| x.to_device(device).unwrap()),
124 dim: v.dim,
125 current_seq_len: v.current_seq_len,
126 max_seq_len: v.max_seq_len,
127 offset: v.offset,
128 capacity_seq_len: v.capacity_seq_len,
129 },
130 }
131 }
132 }
133 }
134 Ok(())
135 }
136
137 pub fn evict_to_cpu(&mut self) -> Result<usize> {
140 if self.no_prefix_cache {
141 return Ok(0);
142 }
143 let mut n_on_device = 0;
144 for cache in self.caches.values() {
145 let first_non_none = cache.cache.iter().find_or_first(|x| x.is_some());
146 let Some(Some(first_non_none)) = first_non_none else {
147 continue;
148 };
149
150 let cache_device = match first_non_none {
151 KvCache::Normal { k, .. } => {
152 k.all_data().as_ref().expect("No KV cache data").device()
153 }
154 KvCache::Rotating { k, .. } => {
155 k.all_data().as_ref().expect("No KV cache data").device()
156 }
157 };
158
159 if !matches!(cache_device, Device::Cpu) {
160 n_on_device += 1;
161 }
162 }
163 let mut n_evicted = 0;
164 for cache in self.caches.values_mut() {
166 if n_on_device - n_evicted == self.n_on_device {
167 break;
168 }
169 let first_non_none = cache.cache.iter().find_or_first(|x| x.is_some());
170 let Some(Some(first_non_none)) = first_non_none else {
171 continue;
172 };
173
174 let cache_device = match first_non_none {
175 KvCache::Normal { k, .. } => {
176 k.all_data().as_ref().expect("No KV cache data").device()
177 }
178 KvCache::Rotating { k, .. } => {
179 k.all_data().as_ref().expect("No KV cache data").device()
180 }
181 };
182
183 if !matches!(cache_device, Device::Cpu) {
184 Self::cache_to(&mut cache.cache, Either::Left(&Device::Cpu))?;
185 n_evicted += 1;
186 }
187 }
188 Ok(self.caches.len().saturating_sub(self.n_on_device))
189 }
190
191 pub fn evict_all_to_cpu(&mut self) -> Result<usize> {
193 if self.no_prefix_cache {
194 return Ok(0);
195 }
196 for cache in self.caches.values_mut() {
198 let first_non_none = cache.cache.iter().find_or_first(|x| x.is_some());
199 let Some(Some(first_non_none)) = first_non_none else {
200 continue;
201 };
202
203 let cache_device = match first_non_none {
204 KvCache::Normal { k, .. } => {
205 k.all_data().as_ref().expect("No KV cache data").device()
206 }
207 KvCache::Rotating { k, .. } => {
208 k.all_data().as_ref().expect("No KV cache data").device()
209 }
210 };
211
212 if !matches!(cache_device, Device::Cpu) {
213 Self::cache_to(&mut cache.cache, Either::Left(&Device::Cpu))?;
214 }
215 }
216 Ok(self.caches.len())
217 }
218
219 pub fn search_for_matching_cache(
221 &mut self,
222 toks: &[u32],
223 contains_images: bool,
224 ) -> Result<Option<MatchingCache>> {
225 if self.no_prefix_cache || toks.is_empty() || contains_images {
226 return Ok(None);
227 }
228
229 let toks = Tokens(toks.to_vec());
230
231 let mut longest_match = (0, None);
232 for (k, v) in self.caches.iter() {
233 let match_len = toks.find_max_index(k);
234 if let Some(match_len) = match_len {
235 if match_len > longest_match.0 {
236 longest_match = (match_len, Some(v));
237 }
238 }
239 }
240 if let (match_len, Some(longest_match)) = longest_match {
241 let mut cache = longest_match.clone();
242 Self::cache_to(&mut cache.cache, Either::Right(&cache.devices))?;
243 for layer in cache.cache.iter_mut().flatten() {
244 match layer.set_len(match_len) {
245 Ok(_) => (),
246 Err(_) => return Ok(None),
247 }
248 }
249 Ok(Some(MatchingCache {
250 normal: cache.cache,
251 toks: toks.0[match_len..].to_vec(),
252 offset: match_len,
253 }))
254 } else {
255 Ok(None)
256 }
257 }
258}