mistralrs_core/dummy_paged_attention/block_engine.rs

use std::{
    collections::{hash_map::Entry, HashMap},
    hash::Hash,
    marker::PhantomData,
    ops::Deref,
    sync::{Arc, Mutex, MutexGuard},
};

use super::block_engine_sequence::BlockEngineSequence;

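/// A fixed-size block of logical token ids: `append_token_id` fills slots in order
/// and `num_tokens` tracks how many of the `block_size` slots are in use. A minimal
/// usage sketch (the block size `4` and token id `7` are illustrative values):
/// ```ignore
/// let mut block = LogicalTokenBlock::new(4);
/// block.append_token_id(7);
/// assert_eq!(block.num_tokens(), 1);
/// assert!(!block.is_full());
/// block.pop_token();
/// assert!(block.is_empty());
/// ```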
#[derive(Debug, Clone)]
pub struct LogicalTokenBlock {
    tokens: Vec<usize>,
    block_size: usize,
    num_tokens: usize,
}

impl LogicalTokenBlock {
    pub fn new(block_size: usize) -> Self {
        Self {
            tokens: [0].repeat(block_size),
            block_size,
            num_tokens: 0,
        }
    }

    pub fn block_size(&self) -> usize {
        self.block_size
    }

    pub fn num_tokens(&self) -> usize {
        self.num_tokens
    }

    pub fn is_full(&self) -> bool {
        self.num_tokens == self.block_size
    }

    pub fn is_empty(&self) -> bool {
        self.num_tokens == 0
    }

    pub fn append_token_id(&mut self, token: usize) {
        assert!(!self.is_full());
        self.tokens[self.num_tokens] = token;
        self.num_tokens += 1;
    }

    pub fn pop_token(&mut self) {
        assert_ne!(self.num_tokens, 0);
        self.num_tokens -= 1;
        // Reset the slot instead of shrinking the vec: `append_token_id` writes by
        // index, so the backing `tokens` buffer must stay at `block_size` length.
        self.tokens[self.num_tokens] = 0;
    }

    pub fn toks(&self) -> &[usize] {
        &self.tokens
    }
}

impl Hash for LogicalTokenBlock {
    fn hash<H: std::hash::Hasher>(&self, state: &mut H) {
        self.tokens.hash(state);
    }
}

#[derive(Hash, PartialEq, Eq)]
pub struct _PhysicalTokenBlock {
    pub block_id: usize,
    block_size: usize,
    refcount: usize,
    is_gpu: bool,
}

impl _PhysicalTokenBlock {
    pub fn refcount(&self) -> usize {
        self.refcount
    }
    pub fn increment_refcount(&mut self) {
        self.refcount += 1;
    }
    pub fn decrement_refcount(&mut self) {
        assert!(self.refcount >= 1);
        self.refcount -= 1;
    }
}

pub struct PhysicalTokenBlock(pub Mutex<_PhysicalTokenBlock>);

impl std::fmt::Debug for PhysicalTokenBlock {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        match self.0.lock() {
            Ok(inner) => f
                .debug_struct("PhysicalTokenBlock")
                .field("block_id", &inner.block_id)
                .field("block_size", &inner.block_size)
                .field("refcount", &inner.refcount)
                .field("is_gpu", &inner.is_gpu)
                .finish(),
            Err(_) => write!(f, "PhysicalTokenBlock(<locked>)"),
        }
    }
}

impl PhysicalTokenBlock {
    pub fn deref_mut(&self) -> MutexGuard<'_, _PhysicalTokenBlock> {
        // Spin until the lock is available; `try_lock` never blocks, so this
        // busy-waits instead of parking the thread like `lock()` would.
        loop {
            if let Ok(v) = self.0.try_lock() {
                return v;
            }
        }
    }
}

impl PartialEq for PhysicalTokenBlock {
    fn eq(&self, other: &Self) -> bool {
        *self.deref_mut() == *other.deref_mut()
    }
}

impl Hash for PhysicalTokenBlock {
    fn hash<H: std::hash::Hasher>(&self, state: &mut H) {
        self.deref_mut().hash(state)
    }
}

impl Eq for PhysicalTokenBlock {}

type BlockTable = Vec<Arc<PhysicalTokenBlock>>;
struct GPUAllocator;
struct CPUAllocator;

struct GPUAllocatorWrapper(usize);
// struct CPUAllocatorWrapper(usize);

impl Deref for GPUAllocatorWrapper {
    type Target = usize;

    fn deref(&self) -> &Self::Target {
        &self.0
    }
}

// impl Deref for CPUAllocatorWrapper {
//     type Target = usize;

//     fn deref(&self) -> &Self::Target {
//         &self.0
//     }
// }

struct Allocator<T> {
    free_blocks: BlockTable,
    _ghost: PhantomData<T>,
}

impl<T> Allocator<T> {
    fn allocate(&mut self) -> Arc<PhysicalTokenBlock> {
        let block = self.free_blocks.pop().unwrap();
        block.deref_mut().refcount = 1;
        block
    }

    fn free_block(&mut self, block: Arc<PhysicalTokenBlock>) {
        if block.deref_mut().refcount == 0 {
            panic!(
                "PhysicalTokenBlock with id {} experienced a double free!",
                block.deref_mut().block_id
            );
        }
        block.deref_mut().refcount -= 1;
        if block.deref_mut().refcount == 0 {
            self.free_blocks.push(block);
        }
    }
}

impl Allocator<GPUAllocator> {
    fn new(block_size: usize, num_blocks: usize) -> Self {
        let mut free_blocks = Vec::new();
        for id in 0..num_blocks {
            free_blocks.push(Arc::new(PhysicalTokenBlock(Mutex::new(
                _PhysicalTokenBlock {
                    block_id: id,
                    block_size,
                    refcount: 0,
                    is_gpu: true,
                },
            ))))
        }
        Allocator {
            free_blocks,
            _ghost: PhantomData,
        }
    }

    fn get_num_free_blocks(&self) -> GPUAllocatorWrapper {
        GPUAllocatorWrapper(self.free_blocks.len())
    }
}

impl Allocator<CPUAllocator> {
    fn new(block_size: usize, num_blocks: usize) -> Self {
        let mut free_blocks = Vec::new();
        for id in 0..num_blocks {
            free_blocks.push(Arc::new(PhysicalTokenBlock(Mutex::new(
                _PhysicalTokenBlock {
                    block_id: id,
                    block_size,
                    refcount: 0,
                    is_gpu: false,
                },
            ))))
        }
        Allocator {
            free_blocks,
            _ghost: PhantomData,
        }
    }
}

#[derive(Debug)]
pub enum AllocStatus {
    Ok,
    Later { waitlisted_count: usize },
    Impossible,
}

type SeqID = usize;

/// A BlockEngine maps each sequence (identified by its SeqID) to physical token blocks.
/// The physical token blocks may not match the logical token blocks because, during
/// scheduling, physical blocks are allocated to accommodate the newly generated tokens.
/// These new tokens are then added to the logical token blocks of each sequence.
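///
/// A minimal construction sketch (the block size and block counts are illustrative,
/// not defaults; marked `ignore` so rustdoc does not try to compile it):
/// ```ignore
/// let engine = BlockEngine::new(/* block_size */ 16, /* num_gpu_blocks */ 64, /* num_cpu_blocks */ 32);
/// assert_eq!(engine.block_size(), 16);
/// assert!(engine.block_tables.is_empty());
/// ```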
pub struct BlockEngine {
    num_gpu_blocks: usize,
    block_size: usize,
    gpu_allocator: Allocator<GPUAllocator>,
    cpu_allocator: Allocator<CPUAllocator>,
    pub block_tables: HashMap<SeqID, BlockTable>,
}

pub type BlockTables = HashMap<usize, BlockTable>;

impl BlockEngine {
    #[must_use]
    pub fn new(block_size: usize, num_gpu_blocks: usize, num_cpu_blocks: usize) -> Self {
        Self {
            num_gpu_blocks,
            block_size,
            gpu_allocator: Allocator::<GPUAllocator>::new(block_size, num_gpu_blocks),
            cpu_allocator: Allocator::<CPUAllocator>::new(block_size, num_cpu_blocks),
            block_tables: HashMap::new(),
        }
    }

    pub fn block_size(&self) -> usize {
        self.block_size
    }

    pub fn can_allocate(&self, seq: &mut impl BlockEngineSequence) -> AllocStatus {
        let num_required_blocks = seq.logical_token_blocks().len();
        let num_free_gpu_blocks = self.gpu_allocator.get_num_free_blocks();

        if self.num_gpu_blocks < num_required_blocks {
            AllocStatus::Impossible
        } else if *num_free_gpu_blocks < num_required_blocks {
            AllocStatus::Later {
                waitlisted_count: seq.increment_waitlist_count(),
            }
        } else {
            AllocStatus::Ok
        }
    }

    pub fn allocate(&mut self, seq: &mut impl BlockEngineSequence) {
        // If there are prefill physical blocks, use those here.
        if let Some(physical_blocks_prefill) = seq.take_physical_blocks_prefill() {
            let mut block_table = physical_blocks_prefill.clone();
            let n_extra_blocks = seq.logical_token_blocks().len() - block_table.len();
            for _ in 0..n_extra_blocks {
                block_table.push(self.gpu_allocator.allocate());
            }
            self.block_tables.insert(seq.get_id(), block_table.clone());
        } else {
            let mut block_table = Vec::new();
            for _logical_idx in 0..seq.logical_token_blocks().len() {
                block_table.push(self.gpu_allocator.allocate());
            }
            self.block_tables.insert(seq.get_id(), block_table.clone());
        }
    }

    pub fn can_append_token_to_seq(&self, seq: &impl BlockEngineSequence) -> bool {
        let free_blocks = self.gpu_allocator.get_num_free_blocks();
        // Physical blocks = logical blocks
        seq.blocks_to_add_new_tok() <= *free_blocks
    }

    pub fn free_sequence(&mut self, id: usize) {
        // Guard against a double free if the sequence ran out of tokens and its blocks were already released.
        if let Some(block_table) = self.block_tables.get(&id) {
            // Free from block table
            for block in block_table {
                if block.deref_mut().is_gpu {
                    self.gpu_allocator.free_block(block.clone())
                } else {
                    self.cpu_allocator.free_block(block.clone())
                }
            }

            self.block_tables.remove(&id);
        }
    }

    #[allow(dead_code)]
    pub fn can_swap_out_seq(&self, seq: &impl BlockEngineSequence) -> bool {
        let blocks_required: usize = self
            .block_tables
            .iter()
            .filter(|(id, _)| seq.get_id() == **id)
            .map(|(_, table)| table.len())
            .sum();
        blocks_required <= self.cpu_allocator.free_blocks.len()
    }

    /// Update the block table so that the sequence no longer reserves any GPU
    /// physical blocks and only holds CPU physical blocks.
    #[allow(dead_code)]
    pub fn swap_out(&mut self, seq: &impl BlockEngineSequence) -> HashMap<usize, usize> {
        // GPU block to a CPU block
        let mut new_mapping = HashMap::new();
        let seq_id = seq.get_id();

        let mut new_block_table = Vec::new();
        let block_table = self.block_tables.get(&seq_id).unwrap();

        for gpu_block in block_table {
            let cpu_block =
                if let Entry::Vacant(e) = new_mapping.entry(gpu_block.deref_mut().block_id) {
                    // Create a new block
                    let cpu_block = self.cpu_allocator.allocate();
                    e.insert(cpu_block.clone());
                    cpu_block
                } else {
                    // Reuse a block
                    let cpu_block = new_mapping
                        .get(&gpu_block.deref_mut().block_id)
                        .unwrap()
                        .clone();
                    cpu_block.deref_mut().refcount += 1;
                    cpu_block
                };
            new_block_table.push(cpu_block);
            self.gpu_allocator.free_block(gpu_block.clone());
        }
        self.block_tables.insert(seq_id, new_block_table);

        new_mapping
            .iter()
            .map(|(k, v)| (*k, v.deref_mut().block_id))
            .collect::<HashMap<_, _>>()
    }

    // Returns the COW mapping (src, dst).
    // COW is performed if there are multiple references to the last physical block.
    pub fn append_token_slot_to_seq(
        &mut self,
        sequence: &impl BlockEngineSequence,
    ) -> Option<(usize, usize)> {
        let table = self.block_tables.get_mut(&sequence.get_id())?;

        match sequence.blocks_to_add_new_tok() {
            1 => {
                table.push(self.gpu_allocator.allocate());
                None
            }
            0 => {
                let last_block = table.last_mut().unwrap();
                assert!(last_block.deref_mut().is_gpu);
                if last_block.deref_mut().refcount == 1 {
                    None
                } else {
                    // We would be writing into a shared block, so copy-on-write (COW) it.
                    let new_block = self.gpu_allocator.allocate();
                    self.gpu_allocator.free_block(last_block.clone());
                    let old_number = last_block.deref_mut().block_id;
                    let new_number = new_block.deref_mut().block_id;
                    *last_block = new_block;
                    Some((old_number, new_number))
                }
            }
            _ => {
                unreachable!()
            }
        }
    }

    pub fn can_swap_in_seq(&self, seq: &impl BlockEngineSequence) -> bool {
        let blocks_required: usize = self
            .block_tables
            .iter()
            .filter(|(id, _)| seq.get_id() == **id)
            .map(|(_, table)| table.len())
            .sum();
        blocks_required <= self.gpu_allocator.free_blocks.len()
    }

    /// Update the block table so that the sequence no longer reserves any CPU
    /// physical blocks and only holds GPU physical blocks.
    pub fn swap_in(&mut self, seq: &impl BlockEngineSequence) -> HashMap<usize, usize> {
        // CPU block to a GPU block
        let mut new_mapping = HashMap::new();
        let seq_id = seq.get_id();

        let mut new_block_table = Vec::new();
        let block_table = self.block_tables.get(&seq_id).unwrap();

        for cpu_block in block_table {
            let gpu_block =
                if let Entry::Vacant(e) = new_mapping.entry(cpu_block.deref_mut().block_id) {
                    // Create a new block
                    let gpu_block = self.gpu_allocator.allocate();
                    e.insert(gpu_block.clone());
                    gpu_block
                } else {
                    // Reuse a block
                    let gpu_block = new_mapping
                        .get(&cpu_block.deref_mut().block_id)
                        .unwrap()
                        .clone();
                    gpu_block.deref_mut().refcount += 1;
                    gpu_block
                };
            new_block_table.push(gpu_block);
            self.cpu_allocator.free_block(cpu_block.clone());
        }
        self.block_tables.insert(seq_id, new_block_table);

        new_mapping
            .iter()
            .map(|(k, v)| (*k, v.deref_mut().block_id))
            .collect::<HashMap<_, _>>()
    }
}
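
// Minimal sanity checks for the block bookkeeping above; a sketch exercising only
// items defined in this file rather than an exhaustive test suite.
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn logical_block_fill_and_pop() {
        let mut block = LogicalTokenBlock::new(2);
        assert!(block.is_empty());
        block.append_token_id(1);
        block.append_token_id(2);
        assert!(block.is_full());
        block.pop_token();
        assert_eq!(block.num_tokens(), 1);
        // The freed slot can be reused.
        block.append_token_id(3);
        assert!(block.is_full());
    }

    #[test]
    fn gpu_allocator_refcounts_blocks() {
        let mut allocator = Allocator::<GPUAllocator>::new(16, 2);
        assert_eq!(*allocator.get_num_free_blocks(), 2);

        let block = allocator.allocate();
        assert_eq!(block.deref_mut().refcount(), 1);
        assert_eq!(*allocator.get_num_free_blocks(), 1);

        // Freeing the last reference returns the block to the free list.
        allocator.free_block(block);
        assert_eq!(*allocator.get_num_free_blocks(), 2);
    }

    #[test]
    fn block_engine_starts_with_empty_tables() {
        let engine = BlockEngine::new(16, 4, 4);
        assert_eq!(engine.block_size(), 16);
        assert!(engine.block_tables.is_empty());
    }
}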