mistralrs_core/dummy_paged_attention/block_engine.rs

use std::{
    collections::{hash_map::Entry, HashMap},
    hash::Hash,
    marker::PhantomData,
    ops::Deref,
    sync::{Arc, Mutex, MutexGuard},
};

use super::block_engine_sequence::BlockEngineSequence;

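/// A fixed-capacity logical block of token ids belonging to one sequence.
/// Logical blocks track the tokens themselves; the mapping onto physical
/// KV-cache blocks is maintained separately by the [`BlockEngine`].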
pub struct LogicalTokenBlock {
    tokens: Vec<usize>,
    block_size: usize,
    num_tokens: usize,
}

impl LogicalTokenBlock {
    pub fn new(block_size: usize) -> Self {
        Self {
            tokens: vec![0; block_size],
            block_size,
            num_tokens: 0,
        }
    }

    pub fn is_full(&self) -> bool {
        self.num_tokens == self.block_size
    }

    pub fn is_empty(&self) -> bool {
        self.num_tokens == 0
    }

    pub fn append_token_id(&mut self, token: usize) {
        assert!(!self.is_full());
        self.tokens[self.num_tokens] = token;
        self.num_tokens += 1;
    }

    pub fn pop_token(&mut self) {
        assert_ne!(self.num_tokens, 0);
        // Only move the cursor back: the backing `tokens` buffer must keep its
        // fixed `block_size` length because `append_token_id` writes by index.
        // (Calling `Vec::pop` here would shrink the buffer and make a later
        // append panic with an out-of-bounds index.)
        self.num_tokens -= 1;
    }
}

/// Inner state of a physical KV-cache block: its id within the allocator,
/// its size, its reference count, and whether it resides on the GPU.
#[derive(Hash, PartialEq, Eq)]
pub struct _PhysicalTokenBlock {
    pub block_id: usize,
    block_size: usize,
    refcount: usize,
    is_gpu: bool,
}

/// A reference-counted physical block. The state lives behind a `Mutex` so
/// that several sequences may share the same block (see the copy-on-write
/// handling in [`BlockEngine::append_token_slot_to_seq`]).
pub struct PhysicalTokenBlock(pub Mutex<_PhysicalTokenBlock>);

impl PhysicalTokenBlock {
    pub fn deref_mut(&self) -> MutexGuard<'_, _PhysicalTokenBlock> {
        // Busy-wait until the mutex is available instead of blocking on `lock`.
        loop {
            if let Ok(v) = self.0.try_lock() {
                return v;
            }
        }
    }
}

impl PartialEq for PhysicalTokenBlock {
    fn eq(&self, other: &Self) -> bool {
        *self.deref_mut() == *other.deref_mut()
    }
}

impl Hash for PhysicalTokenBlock {
    fn hash<H: std::hash::Hasher>(&self, state: &mut H) {
        self.deref_mut().hash(state)
    }
}

impl Eq for PhysicalTokenBlock {}

type BlockTable = Vec<Arc<PhysicalTokenBlock>>;
struct GPUAllocator;
struct CPUAllocator;

struct GPUAllocatorWrapper(usize);
// struct CPUAllocatorWrapper(usize);

impl Deref for GPUAllocatorWrapper {
    type Target = usize;

    fn deref(&self) -> &Self::Target {
        &self.0
    }
}

// impl Deref for CPUAllocatorWrapper {
//     type Target = usize;

//     fn deref(&self) -> &Self::Target {
//         &self.0
//     }
// }

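/// A free-list allocator over reference-counted physical blocks. The
/// `PhantomData<T>` marker distinguishes the GPU and CPU allocators at the
/// type level without storing any additional data.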
struct Allocator<T> {
    free_blocks: BlockTable,
    _ghost: PhantomData<T>,
}

impl<T> Allocator<T> {
    fn allocate(&mut self) -> Arc<PhysicalTokenBlock> {
        // Callers are expected to check availability first (e.g. via
        // `get_num_free_blocks`); popping from an empty free list panics.
        let block = self.free_blocks.pop().unwrap();
        block.deref_mut().refcount = 1;
        block
    }

    fn free_block(&mut self, block: Arc<PhysicalTokenBlock>) {
        // Decrement the refcount under a single lock acquisition, then return
        // the block to the free list once nothing references it anymore.
        let refcount = {
            let mut state = block.deref_mut();
            if state.refcount == 0 {
                panic!(
                    "PhysicalTokenBlock with id {} experienced a double free!",
                    state.block_id
                );
            }
            state.refcount -= 1;
            state.refcount
        };
        if refcount == 0 {
            self.free_blocks.push(block);
        }
    }
}

impl Allocator<GPUAllocator> {
    fn new(block_size: usize, num_blocks: usize) -> Self {
        let mut free_blocks = Vec::new();
        for id in 0..num_blocks {
            free_blocks.push(Arc::new(PhysicalTokenBlock(Mutex::new(
                _PhysicalTokenBlock {
                    block_id: id,
                    block_size,
                    refcount: 0,
                    is_gpu: true,
                },
            ))))
        }
        Allocator {
            free_blocks,
            _ghost: PhantomData,
        }
    }

    fn get_num_free_blocks(&self) -> GPUAllocatorWrapper {
        GPUAllocatorWrapper(self.free_blocks.len())
    }
}

impl Allocator<CPUAllocator> {
    fn new(block_size: usize, num_blocks: usize) -> Self {
        let mut free_blocks = Vec::new();
        for id in 0..num_blocks {
            free_blocks.push(Arc::new(PhysicalTokenBlock(Mutex::new(
                _PhysicalTokenBlock {
                    block_id: id,
                    block_size,
                    refcount: 0,
                    // CPU-resident blocks.
                    is_gpu: false,
                },
            ))))
        }
        Allocator {
            free_blocks,
            _ghost: PhantomData,
        }
    }
}

/// Whether a sequence's required blocks can be allocated right now.
pub enum AllocStatus {
    /// Enough free GPU blocks are available.
    Ok,
    /// Not enough free GPU blocks at the moment; retry later.
    Later,
    /// The sequence requires more blocks than the GPU has in total.
    Impossible,
}

type SeqID = usize;

/// A `BlockEngine` maps each sequence (identified by its `SeqID`) to physical
/// token blocks. The physical blocks may not yet match the logical token
/// blocks: physical blocks are only allocated during scheduling, to
/// accommodate new tokens as they are generated and appended to each
/// sequence's logical blocks.
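///
/// A minimal usage sketch (illustrative only; `seq` stands for any type
/// implementing `BlockEngineSequence`):
///
/// ```ignore
/// let mut engine = BlockEngine::new(16, 256, 256); // block_size, GPU blocks, CPU blocks
/// if let AllocStatus::Ok = engine.can_allocate(&seq) {
///     engine.allocate(&seq);
/// }
/// // Later, when a generated token may need a fresh block:
/// let cow = engine.append_token_slot_to_seq(&seq);
/// ```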
pub struct BlockEngine {
    num_gpu_blocks: usize,
    gpu_allocator: Allocator<GPUAllocator>,
    cpu_allocator: Allocator<CPUAllocator>,
    pub block_tables: HashMap<SeqID, BlockTable>,
}

pub type BlockTables = HashMap<usize, BlockTable>;

impl BlockEngine {
    #[must_use]
    pub fn new(block_size: usize, num_gpu_blocks: usize, num_cpu_blocks: usize) -> Self {
        Self {
            num_gpu_blocks,
            gpu_allocator: Allocator::<GPUAllocator>::new(block_size, num_gpu_blocks),
            cpu_allocator: Allocator::<CPUAllocator>::new(block_size, num_cpu_blocks),
            block_tables: HashMap::new(),
        }
    }

    pub fn can_allocate(&self, seq: &impl BlockEngineSequence) -> AllocStatus {
        let num_required_blocks = seq.get_logical_token_blocks();
        let num_free_gpu_blocks = self.gpu_allocator.get_num_free_blocks();

        if self.num_gpu_blocks > *num_free_gpu_blocks + num_required_blocks {
            AllocStatus::Later
        } else if self.num_gpu_blocks < num_required_blocks {
            AllocStatus::Impossible
        } else {
            AllocStatus::Ok
        }
    }

    pub fn allocate(&mut self, seq: &impl BlockEngineSequence) {
        let mut block_table = Vec::new();
        for _logical_idx in 0..seq.get_logical_token_blocks() {
            block_table.push(self.gpu_allocator.allocate());
        }
        self.block_tables.insert(seq.get_id(), block_table);
    }

    pub fn can_append_token_to_seq(&self, seq: &impl BlockEngineSequence) -> bool {
        let free_blocks = self.gpu_allocator.get_num_free_blocks();
        // Appending one token requires at most one new physical block, since
        // physical blocks mirror the logical blocks.
        seq.blocks_to_add_new_tok() <= *free_blocks
    }

    pub fn free_sequence(&mut self, id: usize) {
        // Skip sequences whose blocks were already released, avoiding a double free.
        if let Some(block_table) = self.block_tables.get(&id) {
            // Return each block to the allocator that owns it.
            for block in block_table {
                if block.deref_mut().is_gpu {
                    self.gpu_allocator.free_block(block.clone())
                } else {
                    self.cpu_allocator.free_block(block.clone())
                }
            }

            self.block_tables.remove(&id);
        }
    }

    #[allow(dead_code)]
    pub fn can_swap_out_seq(&self, seq: &impl BlockEngineSequence) -> bool {
        let blocks_required: usize = self
            .block_tables
            .iter()
            .filter(|(id, _)| seq.get_id() == **id)
            .map(|(_, table)| table.len())
            .sum();
        blocks_required <= self.cpu_allocator.free_blocks.len()
    }

    /// Update the block table so that the sequence no longer reserves any GPU
    /// physical blocks and holds only CPU physical blocks.
    #[allow(dead_code)]
    pub fn swap_out(&mut self, seq: &impl BlockEngineSequence) -> HashMap<usize, usize> {
        // Maps each GPU block id to the CPU block that replaces it.
        let mut new_mapping = HashMap::new();
        let seq_id = seq.get_id();

        let mut new_block_table = Vec::new();
        let block_table = self.block_tables.get(&seq_id).unwrap();

        for gpu_block in block_table {
            let cpu_block =
                if let Entry::Vacant(e) = new_mapping.entry(gpu_block.deref_mut().block_id) {
                    // First time this GPU block is seen: allocate a CPU block for it.
                    let cpu_block = self.cpu_allocator.allocate();
                    e.insert(cpu_block.clone());
                    cpu_block
                } else {
                    // The GPU block was already mapped: reuse its CPU block.
                    let cpu_block = new_mapping
                        .get(&gpu_block.deref_mut().block_id)
                        .unwrap()
                        .clone();
                    cpu_block.deref_mut().refcount += 1;
                    cpu_block
                };
            new_block_table.push(cpu_block);
            self.gpu_allocator.free_block(gpu_block.clone());
        }
        self.block_tables.insert(seq_id, new_block_table);

        new_mapping
            .iter()
            .map(|(k, v)| (*k, v.deref_mut().block_id))
            .collect::<HashMap<_, _>>()
    }

    /// Returns the copy-on-write (COW) mapping as `(src, dst)` block ids.
    /// A COW is performed when the last physical block is shared with other
    /// sequences, i.e. it has multiple references.
    pub fn append_token_slot_to_seq(
        &mut self,
        sequence: &impl BlockEngineSequence,
    ) -> Option<(usize, usize)> {
        let table = self.block_tables.get_mut(&sequence.get_id())?;

        match sequence.blocks_to_add_new_tok() {
            1 => {
                // The last block is full: grow the table by one fresh block.
                table.push(self.gpu_allocator.allocate());
                None
            }
            0 => {
                // The new token fits into the last block.
                let last_block = table.last_mut().unwrap();
                assert!(last_block.deref_mut().is_gpu);
                if last_block.deref_mut().refcount == 1 {
                    None
                } else {
                    // We would be writing into a shared block, so COW.
                    let new_block = self.gpu_allocator.allocate();
                    self.gpu_allocator.free_block(last_block.clone());
                    let old_number = last_block.deref_mut().block_id;
                    let new_number = new_block.deref_mut().block_id;
                    *last_block = new_block;
                    Some((old_number, new_number))
                }
            }
            _ => {
                unreachable!()
            }
        }
    }

    pub fn can_swap_in_seq(&self, seq: &impl BlockEngineSequence) -> bool {
        let blocks_required: usize = self
            .block_tables
            .iter()
            .filter(|(id, _)| seq.get_id() == **id)
            .map(|(_, table)| table.len())
            .sum();
        blocks_required <= self.gpu_allocator.free_blocks.len()
    }

    /// Update the block table so that the sequence no longer reserves any CPU
    /// physical blocks and holds only GPU physical blocks.
    pub fn swap_in(&mut self, seq: &impl BlockEngineSequence) -> HashMap<usize, usize> {
        // Maps each CPU block id to the GPU block that replaces it.
        let mut new_mapping = HashMap::new();
        let seq_id = seq.get_id();

        let mut new_block_table = Vec::new();
        let block_table = self.block_tables.get(&seq_id).unwrap();

        for cpu_block in block_table {
            let gpu_block =
                if let Entry::Vacant(e) = new_mapping.entry(cpu_block.deref_mut().block_id) {
                    // First time this CPU block is seen: allocate a GPU block for it.
                    let gpu_block = self.gpu_allocator.allocate();
                    e.insert(gpu_block.clone());
                    gpu_block
                } else {
                    // The CPU block was already mapped: reuse its GPU block.
                    let gpu_block = new_mapping
                        .get(&cpu_block.deref_mut().block_id)
                        .unwrap()
                        .clone();
                    gpu_block.deref_mut().refcount += 1;
                    gpu_block
                };
            new_block_table.push(gpu_block);
            self.cpu_allocator.free_block(cpu_block.clone());
        }
        self.block_tables.insert(seq_id, new_block_table);

        new_mapping
            .iter()
            .map(|(k, v)| (*k, v.deref_mut().block_id))
            .collect::<HashMap<_, _>>()
    }
}