// mistralrs_core/dummy_paged_attention/block_engine.rs

use std::{
    collections::{hash_map::Entry, HashMap},
    hash::Hash,
    marker::PhantomData,
    ops::Deref,
    sync::{Arc, Mutex, MutexGuard},
};

use super::block_engine_sequence::BlockEngineSequence;
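/// A logical block of `block_size` token ids belonging to one sequence. The backing
/// buffer is allocated at full capacity up front and written in place; `num_tokens`
/// tracks how many slots are currently occupied.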
#[derive(Debug, Clone)]
pub struct LogicalTokenBlock {
    tokens: Vec<usize>,
    block_size: usize,
    num_tokens: usize,
}

impl LogicalTokenBlock {
    pub fn new(block_size: usize) -> Self {
        Self {
            tokens: [0].repeat(block_size),
            block_size,
            num_tokens: 0,
        }
    }

    pub fn block_size(&self) -> usize {
        self.block_size
    }

    pub fn num_tokens(&self) -> usize {
        self.num_tokens
    }

    pub fn is_full(&self) -> bool {
        self.num_tokens == self.block_size
    }

    pub fn is_empty(&self) -> bool {
        self.num_tokens == 0
    }

    pub fn append_token_id(&mut self, token: usize) {
        assert!(!self.is_full());
        self.tokens[self.num_tokens] = token;
        self.num_tokens += 1;
    }

    pub fn pop_token(&mut self) {
        assert_ne!(self.num_tokens, 0);
        // Clear the slot in place instead of `Vec::pop`: `append_token_id` writes by
        // index, so the buffer must keep its full `block_size` length or a later
        // append into the last slot would go out of bounds.
        self.num_tokens -= 1;
        self.tokens[self.num_tokens] = 0;
    }

    /// The full fixed-size buffer; only the first `num_tokens` entries are live.
    pub fn toks(&self) -> &[usize] {
        &self.tokens
    }
}

impl Hash for LogicalTokenBlock {
    fn hash<H: std::hash::Hasher>(&self, state: &mut H) {
        self.tokens.hash(state);
    }
}

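/// A reference-counted physical KV-cache block resident on either the GPU or the CPU.
/// `refcount` counts how many sequences currently map a logical block onto it.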
#[derive(Hash, PartialEq, Eq)]
pub struct _PhysicalTokenBlock {
    pub block_id: usize,
    block_size: usize,
    refcount: usize,
    is_gpu: bool,
}

impl _PhysicalTokenBlock {
    pub fn refcount(&self) -> usize {
        self.refcount
    }
    pub fn increment_refcount(&mut self) {
        self.refcount += 1;
    }
    pub fn decrement_refcount(&mut self) {
        assert!(self.refcount >= 1);
        self.refcount -= 1;
    }
}

pub struct PhysicalTokenBlock(pub Mutex<_PhysicalTokenBlock>);

impl std::fmt::Debug for PhysicalTokenBlock {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        match self.0.lock() {
            Ok(inner) => f
                .debug_struct("PhysicalTokenBlock")
                .field("block_id", &inner.block_id)
                .field("block_size", &inner.block_size)
                .field("refcount", &inner.refcount)
                .field("is_gpu", &inner.is_gpu)
                .finish(),
            Err(_) => write!(f, "PhysicalTokenBlock(<locked>)"),
        }
    }
}

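// Note: `deref_mut` spins on `try_lock` rather than calling `lock()`, which sidesteps
// mutex poisoning at the cost of busy-waiting while another holder has the guard.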
impl PhysicalTokenBlock {
    pub fn deref_mut(&self) -> MutexGuard<'_, _PhysicalTokenBlock> {
        loop {
            if let Ok(v) = self.0.try_lock() {
                return v;
            }
        }
    }
}

impl PartialEq for PhysicalTokenBlock {
    fn eq(&self, other: &Self) -> bool {
        *self.deref_mut() == *other.deref_mut()
    }
}

impl Hash for PhysicalTokenBlock {
    fn hash<H: std::hash::Hasher>(&self, state: &mut H) {
        self.deref_mut().hash(state)
    }
}

impl Eq for PhysicalTokenBlock {}

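// A sequence's block table: the i-th entry is the physical block backing its i-th
// logical block. The zero-sized marker types below distinguish the GPU and CPU
// allocators at the type level.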
type BlockTable = Vec<Arc<PhysicalTokenBlock>>;
struct GPUAllocator;
struct CPUAllocator;

struct GPUAllocatorWrapper(usize);

impl Deref for GPUAllocatorWrapper {
    type Target = usize;

    fn deref(&self) -> &Self::Target {
        &self.0
    }
}

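// A free-list allocator over physical blocks: `allocate` pops a free block and sets
// its refcount to 1; `free_block` decrements the refcount and returns the block to
// the free list once no sequence references it.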
struct Allocator<T> {
    free_blocks: BlockTable,
    _ghost: PhantomData<T>,
}

impl<T> Allocator<T> {
    fn allocate(&mut self) -> Arc<PhysicalTokenBlock> {
        let block = self.free_blocks.pop().unwrap();
        block.deref_mut().refcount = 1;
        block
    }

    fn free_block(&mut self, block: Arc<PhysicalTokenBlock>) {
        if block.deref_mut().refcount == 0 {
            panic!(
                "PhysicalTokenBlock with id {} experienced a double free!",
                block.deref_mut().block_id
            );
        }
        block.deref_mut().refcount -= 1;
        if block.deref_mut().refcount == 0 {
            self.free_blocks.push(block);
        }
    }
}

impl Allocator<GPUAllocator> {
    fn new(block_size: usize, num_blocks: usize) -> Self {
        let mut free_blocks = Vec::new();
        for id in 0..num_blocks {
            free_blocks.push(Arc::new(PhysicalTokenBlock(Mutex::new(
                _PhysicalTokenBlock {
                    block_id: id,
                    block_size,
                    refcount: 0,
                    is_gpu: true,
                },
            ))))
        }
        Allocator {
            free_blocks,
            _ghost: PhantomData,
        }
    }

    fn get_num_free_blocks(&self) -> GPUAllocatorWrapper {
        GPUAllocatorWrapper(self.free_blocks.len())
    }
}

impl Allocator<CPUAllocator> {
    fn new(block_size: usize, num_blocks: usize) -> Self {
        let mut free_blocks = Vec::new();
        for id in 0..num_blocks {
            free_blocks.push(Arc::new(PhysicalTokenBlock(Mutex::new(
                _PhysicalTokenBlock {
                    block_id: id,
                    block_size,
                    refcount: 0,
                    is_gpu: false,
                },
            ))))
        }
        Allocator {
            free_blocks,
            _ghost: PhantomData,
        }
    }
}

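/// Outcome of a `can_allocate` check: `Ok` means the sequence fits now, `Later` means
/// it must wait for free GPU blocks (carrying its updated waitlist count), and
/// `Impossible` means it can never fit in the configured number of GPU blocks.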
#[derive(Debug)]
pub enum AllocStatus {
    Ok,
    Later { waitlisted_count: usize },
    Impossible,
}

type SeqID = usize;

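/// vLLM-style block manager: owns the GPU and CPU block allocators and maps each
/// sequence id to the block table backing that sequence's logical blocks.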
pub struct BlockEngine {
    num_gpu_blocks: usize,
    block_size: usize,
    gpu_allocator: Allocator<GPUAllocator>,
    cpu_allocator: Allocator<CPUAllocator>,
    pub block_tables: HashMap<SeqID, BlockTable>,
}

pub type BlockTables = HashMap<usize, BlockTable>;

impl BlockEngine {
    #[must_use]
    pub fn new(block_size: usize, num_gpu_blocks: usize, num_cpu_blocks: usize) -> Self {
        Self {
            num_gpu_blocks,
            block_size,
            gpu_allocator: Allocator::<GPUAllocator>::new(block_size, num_gpu_blocks),
            cpu_allocator: Allocator::<CPUAllocator>::new(block_size, num_cpu_blocks),
            block_tables: HashMap::new(),
        }
    }

    pub fn block_size(&self) -> usize {
        self.block_size
    }

    pub fn can_allocate(&self, seq: &mut impl BlockEngineSequence) -> AllocStatus {
        let num_required_blocks = seq.logical_token_blocks().len();
        let num_free_gpu_blocks = self.gpu_allocator.get_num_free_blocks();

        if self.num_gpu_blocks < num_required_blocks {
            AllocStatus::Impossible
        } else if *num_free_gpu_blocks < num_required_blocks {
            AllocStatus::Later {
                waitlisted_count: seq.increment_waitlist_count(),
            }
        } else {
            AllocStatus::Ok
        }
    }

    pub fn allocate(&mut self, seq: &mut impl BlockEngineSequence) {
        // Reuse any physical blocks already populated during prefill, allocating only
        // the extra blocks needed to back the remaining logical blocks.
        let block_table = if let Some(physical_blocks_prefill) = seq.take_physical_blocks_prefill()
        {
            let mut block_table = physical_blocks_prefill.clone();
            let n_extra_blocks = seq.logical_token_blocks().len() - block_table.len();
            for _ in 0..n_extra_blocks {
                block_table.push(self.gpu_allocator.allocate());
            }
            block_table
        } else {
            // Fresh sequence: one physical block per logical block.
            let mut block_table = Vec::new();
            for _logical_idx in 0..seq.logical_token_blocks().len() {
                block_table.push(self.gpu_allocator.allocate());
            }
            block_table
        };
        self.block_tables.insert(seq.get_id(), block_table);
    }

    pub fn can_append_token_to_seq(&self, seq: &impl BlockEngineSequence) -> bool {
        let free_blocks = self.gpu_allocator.get_num_free_blocks();
        seq.blocks_to_add_new_tok() <= *free_blocks
    }

    pub fn free_sequence(&mut self, id: usize) {
        // Return every physical block to its allocator, then drop the block table.
        if let Some(block_table) = self.block_tables.get(&id) {
            for block in block_table {
                if block.deref_mut().is_gpu {
                    self.gpu_allocator.free_block(block.clone())
                } else {
                    self.cpu_allocator.free_block(block.clone())
                }
            }
            self.block_tables.remove(&id);
        }
    }

    #[allow(dead_code)]
    pub fn can_swap_out_seq(&self, seq: &impl BlockEngineSequence) -> bool {
        let blocks_required: usize = self
            .block_tables
            .iter()
            .filter(|(id, _)| seq.get_id() == **id)
            .map(|(_, table)| table.len())
            .sum();
        blocks_required <= self.cpu_allocator.free_blocks.len()
    }

    #[allow(dead_code)]
    pub fn swap_out(&mut self, seq: &impl BlockEngineSequence) -> HashMap<usize, usize> {
        // Move each GPU block of the sequence to a CPU block, returning the mapping
        // of old (GPU) to new (CPU) block ids.
        let mut new_mapping = HashMap::new();
        let seq_id = seq.get_id();

        let mut new_block_table = Vec::new();
        let block_table = self.block_tables.get(&seq_id).unwrap();

        for gpu_block in block_table {
            let cpu_block =
                if let Entry::Vacant(e) = new_mapping.entry(gpu_block.deref_mut().block_id) {
                    // First time this GPU block is seen: allocate a CPU block for it.
                    let cpu_block = self.cpu_allocator.allocate();
                    e.insert(cpu_block.clone());
                    cpu_block
                } else {
                    // Already mapped (block shared within the table): reuse it and
                    // bump its refcount.
                    let cpu_block = new_mapping
                        .get(&gpu_block.deref_mut().block_id)
                        .unwrap()
                        .clone();
                    cpu_block.deref_mut().refcount += 1;
                    cpu_block
                };
            new_block_table.push(cpu_block);
            // Release the GPU-side block.
            self.gpu_allocator.free_block(gpu_block.clone());
        }
        self.block_tables.insert(seq_id, new_block_table);

        new_mapping
            .iter()
            .map(|(k, v)| (*k, v.deref_mut().block_id))
            .collect::<HashMap<_, _>>()
    }

    /// Append a token slot to the sequence, returning `Some((src, dst))` block ids
    /// if a copy-on-write was required because the last block is shared.
    pub fn append_token_slot_to_seq(
        &mut self,
        sequence: &impl BlockEngineSequence,
    ) -> Option<(usize, usize)> {
        let table = self.block_tables.get_mut(&sequence.get_id())?;

        match sequence.blocks_to_add_new_tok() {
            1 => {
                // The logical blocks have outgrown the physical ones: allocate a new one.
                table.push(self.gpu_allocator.allocate());
                None
            }
            0 => {
                let last_block = table.last_mut().unwrap();
                assert!(last_block.deref_mut().is_gpu);
                if last_block.deref_mut().refcount == 1 {
                    // Exclusively owned: the token can be written in place.
                    None
                } else {
                    // The last block is shared with another sequence: copy-on-write.
                    let new_block = self.gpu_allocator.allocate();
                    self.gpu_allocator.free_block(last_block.clone());
                    let old_number = last_block.deref_mut().block_id;
                    let new_number = new_block.deref_mut().block_id;
                    *last_block = new_block;
                    Some((old_number, new_number))
                }
            }
            _ => unreachable!(),
        }
    }

    pub fn can_swap_in_seq(&self, seq: &impl BlockEngineSequence) -> bool {
        let blocks_required: usize = self
            .block_tables
            .iter()
            .filter(|(id, _)| seq.get_id() == **id)
            .map(|(_, table)| table.len())
            .sum();
        blocks_required <= self.gpu_allocator.free_blocks.len()
    }

    pub fn swap_in(&mut self, seq: &impl BlockEngineSequence) -> HashMap<usize, usize> {
        // Move each CPU block of the sequence back to a GPU block, returning the
        // mapping of old (CPU) to new (GPU) block ids. Note the allocator roles are
        // the mirror image of `swap_out`: allocate on the GPU, free on the CPU.
        let mut new_mapping = HashMap::new();
        let seq_id = seq.get_id();

        let mut new_block_table = Vec::new();
        let block_table = self.block_tables.get(&seq_id).unwrap();

        for cpu_block in block_table {
            let gpu_block =
                if let Entry::Vacant(e) = new_mapping.entry(cpu_block.deref_mut().block_id) {
                    // First time this CPU block is seen: allocate a GPU block for it.
                    let gpu_block = self.gpu_allocator.allocate();
                    e.insert(gpu_block.clone());
                    gpu_block
                } else {
                    // Already mapped: reuse the GPU block and bump its refcount.
                    let gpu_block = new_mapping
                        .get(&cpu_block.deref_mut().block_id)
                        .unwrap()
                        .clone();
                    gpu_block.deref_mut().refcount += 1;
                    gpu_block
                };
            new_block_table.push(gpu_block);
            // Release the CPU-side block.
            self.cpu_allocator.free_block(cpu_block.clone());
        }
        self.block_tables.insert(seq_id, new_block_table);

        new_mapping
            .iter()
            .map(|(k, v)| (*k, v.deref_mut().block_id))
            .collect::<HashMap<_, _>>()
    }
}
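
// A minimal sanity-check sketch of the building blocks above. The block sizes, block
// counts, and token values are illustrative assumptions, not fixtures from the crate.
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn logical_block_fill_and_pop() {
        let mut block = LogicalTokenBlock::new(4);
        assert!(block.is_empty());
        for tok in [10, 11, 12, 13] {
            block.append_token_id(tok);
        }
        assert!(block.is_full());
        block.pop_token();
        assert_eq!(block.num_tokens(), 3);
        // The freed slot can be refilled without growing past `block_size`.
        block.append_token_id(14);
        assert!(block.is_full());
    }

    #[test]
    fn gpu_allocator_roundtrip() {
        let mut alloc = Allocator::<GPUAllocator>::new(16, 2);
        assert_eq!(*alloc.get_num_free_blocks(), 2);
        let block = alloc.allocate();
        assert_eq!(block.deref_mut().refcount(), 1);
        assert_eq!(*alloc.get_num_free_blocks(), 1);
        // Freeing drops the refcount to zero and returns the block to the free list.
        alloc.free_block(block);
        assert_eq!(*alloc.get_num_free_blocks(), 2);
    }
}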