mistralrs_core/dummy_paged_attention/block_engine.rs

use std::{
    collections::{hash_map::Entry, HashMap},
    hash::Hash,
    marker::PhantomData,
    ops::Deref,
    sync::{Arc, Mutex, MutexGuard},
};

use super::block_engine_sequence::BlockEngineSequence;

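/// A block on the logical side of the paged KV cache: it records which token ids
/// occupy the slots of one fixed-size block, independently of which physical block
/// backs it.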
pub struct LogicalTokenBlock {
    tokens: Vec<usize>,
    block_size: usize,
    num_tokens: usize,
}
16
17impl LogicalTokenBlock {
18 pub fn new(block_size: usize) -> Self {
19 Self {
20 tokens: [0].repeat(block_size),
21 block_size,
22 num_tokens: 0,
23 }
24 }
25
26 pub fn is_full(&self) -> bool {
27 self.num_tokens == self.block_size
28 }
29
30 pub fn is_empty(&self) -> bool {
31 self.num_tokens == 0
32 }
33
34 pub fn append_token_id(&mut self, token: usize) {
35 assert!(!self.is_full());
36 self.tokens[self.num_tokens] = token;
37 self.num_tokens += 1;
38 }
39
    pub fn pop_token(&mut self) {
        assert_ne!(self.num_tokens, 0);
        // Only decrement the logical length; the backing `tokens` buffer keeps its
        // full `block_size` length so `append_token_id` can later overwrite the slot.
        self.num_tokens -= 1;
    }
}

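/// Reference-counted bookkeeping for one physical KV-cache block. `PhysicalTokenBlock`
/// wraps it in a `Mutex` so a block can be shared across block tables behind an `Arc`.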
#[derive(Hash, PartialEq, Eq)]
pub struct _PhysicalTokenBlock {
    pub block_id: usize,
    block_size: usize,
    refcount: usize,
    is_gpu: bool,
}

pub struct PhysicalTokenBlock(pub Mutex<_PhysicalTokenBlock>);

impl PhysicalTokenBlock {
    pub fn deref_mut(&self) -> MutexGuard<'_, _PhysicalTokenBlock> {
        // Spin until the lock is available; contention on a single block is brief.
        loop {
            if let Ok(v) = self.0.try_lock() {
                return v;
            }
        }
    }
}

impl PartialEq for PhysicalTokenBlock {
    fn eq(&self, other: &Self) -> bool {
        *self.deref_mut() == *other.deref_mut()
    }
}

impl Hash for PhysicalTokenBlock {
    fn hash<H: std::hash::Hasher>(&self, state: &mut H) {
        self.deref_mut().hash(state)
    }
}

impl Eq for PhysicalTokenBlock {}

type BlockTable = Vec<Arc<PhysicalTokenBlock>>;
struct GPUAllocator;
struct CPUAllocator;

struct GPUAllocatorWrapper(usize);

impl Deref for GPUAllocatorWrapper {
    type Target = usize;

    fn deref(&self) -> &Self::Target {
        &self.0
    }
}

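/// A simple free-list allocator over reference-counted physical blocks. `T` is a
/// zero-sized marker (`GPUAllocator` or `CPUAllocator`) distinguishing the two pools.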
struct Allocator<T> {
    free_blocks: BlockTable,
    _ghost: PhantomData<T>,
}

impl<T> Allocator<T> {
    fn allocate(&mut self) -> Arc<PhysicalTokenBlock> {
        let block = self.free_blocks.pop().unwrap();
        block.deref_mut().refcount = 1;
        block
    }

    fn free_block(&mut self, block: Arc<PhysicalTokenBlock>) {
        if block.deref_mut().refcount == 0 {
            panic!(
                "PhysicalTokenBlock with id {} experienced a double free!",
                block.deref_mut().block_id
            );
        }
        block.deref_mut().refcount -= 1;
        if block.deref_mut().refcount == 0 {
            self.free_blocks.push(block);
        }
    }
}

impl Allocator<GPUAllocator> {
    fn new(block_size: usize, num_blocks: usize) -> Self {
        let mut free_blocks = Vec::new();
        for id in 0..num_blocks {
            free_blocks.push(Arc::new(PhysicalTokenBlock(Mutex::new(
                _PhysicalTokenBlock {
                    block_id: id,
                    block_size,
                    refcount: 0,
                    is_gpu: true,
                },
            ))))
        }
        Allocator {
            free_blocks,
            _ghost: PhantomData,
        }
    }

    fn get_num_free_blocks(&self) -> GPUAllocatorWrapper {
        GPUAllocatorWrapper(self.free_blocks.len())
    }
}

impl Allocator<CPUAllocator> {
    fn new(block_size: usize, num_blocks: usize) -> Self {
        let mut free_blocks = Vec::new();
        for id in 0..num_blocks {
            free_blocks.push(Arc::new(PhysicalTokenBlock(Mutex::new(
                _PhysicalTokenBlock {
                    block_id: id,
                    block_size,
                    refcount: 0,
                    // CPU-resident blocks are not on the GPU.
                    is_gpu: false,
                },
            ))))
        }
        Allocator {
            free_blocks,
            _ghost: PhantomData,
        }
    }
}

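/// Whether a sequence's blocks can be allocated now (`Ok`), possibly once other
/// sequences release GPU blocks (`Later`), or never, because it needs more blocks
/// than the GPU has in total (`Impossible`).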
pub enum AllocStatus {
    Ok,
    Later,
    Impossible,
}

type SeqID = usize;

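/// Owns the GPU and CPU block free lists and the per-sequence block tables, and
/// decides when a sequence's blocks can be allocated, grown, or swapped between
/// devices.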
pub struct BlockEngine {
    num_gpu_blocks: usize,
    gpu_allocator: Allocator<GPUAllocator>,
    cpu_allocator: Allocator<CPUAllocator>,
    pub block_tables: HashMap<SeqID, BlockTable>,
}

pub type BlockTables = HashMap<usize, BlockTable>;

impl BlockEngine {
    #[must_use]
    pub fn new(block_size: usize, num_gpu_blocks: usize, num_cpu_blocks: usize) -> Self {
        Self {
            num_gpu_blocks,
            gpu_allocator: Allocator::<GPUAllocator>::new(block_size, num_gpu_blocks),
            cpu_allocator: Allocator::<CPUAllocator>::new(block_size, num_cpu_blocks),
            block_tables: HashMap::new(),
        }
    }

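    /// Decide whether this sequence's logical blocks can be backed by physical GPU
    /// blocks: `Impossible` if it needs more blocks than the GPU has in total,
    /// otherwise `Ok` or `Later` depending on how many GPU blocks are currently free.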
    pub fn can_allocate(&self, seq: &impl BlockEngineSequence) -> AllocStatus {
        let num_required_blocks = seq.get_logical_token_blocks();
        let num_free_gpu_blocks = self.gpu_allocator.get_num_free_blocks();

        if self.num_gpu_blocks > *num_free_gpu_blocks + num_required_blocks {
            AllocStatus::Later
        } else if self.num_gpu_blocks < num_required_blocks {
            AllocStatus::Impossible
        } else {
            AllocStatus::Ok
        }
    }

    pub fn allocate(&mut self, seq: &impl BlockEngineSequence) {
        let mut block_table = Vec::new();
        for _logical_idx in 0..seq.get_logical_token_blocks() {
            block_table.push(self.gpu_allocator.allocate());
        }
        self.block_tables.insert(seq.get_id(), block_table);
    }

    pub fn can_append_token_to_seq(&self, seq: &impl BlockEngineSequence) -> bool {
        let free_blocks = self.gpu_allocator.get_num_free_blocks();
        seq.blocks_to_add_new_tok() <= *free_blocks
    }

    pub fn free_sequence(&mut self, id: usize) {
        if let Some(block_table) = self.block_tables.get(&id) {
            for block in block_table {
                if block.deref_mut().is_gpu {
                    self.gpu_allocator.free_block(block.clone())
                } else {
                    self.cpu_allocator.free_block(block.clone())
                }
            }

            self.block_tables.remove(&id);
        }
    }

    #[allow(dead_code)]
    pub fn can_swap_out_seq(&self, seq: &impl BlockEngineSequence) -> bool {
        let blocks_required: usize = self
            .block_tables
            .iter()
            .filter(|(id, _)| seq.get_id() == **id)
            .map(|(_, table)| table.len())
            .sum();
        blocks_required <= self.cpu_allocator.free_blocks.len()
    }

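    /// Move this sequence's blocks from GPU to CPU memory, rebuilding its block table
    /// on the CPU side. Returns a map from old GPU block ids to new CPU block ids.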
    #[allow(dead_code)]
    pub fn swap_out(&mut self, seq: &impl BlockEngineSequence) -> HashMap<usize, usize> {
        let mut new_mapping = HashMap::new();
        let seq_id = seq.get_id();

        let mut new_block_table = Vec::new();
        let block_table = self.block_tables.get(&seq_id).unwrap();

        for gpu_block in block_table {
            let cpu_block =
                if let Entry::Vacant(e) = new_mapping.entry(gpu_block.deref_mut().block_id) {
                    // This GPU block has not been mapped yet: take a fresh CPU block.
                    let cpu_block = self.cpu_allocator.allocate();
                    e.insert(cpu_block.clone());
                    cpu_block
                } else {
                    // Another entry in the table already maps to this block: share it.
                    let cpu_block = new_mapping
                        .get(&gpu_block.deref_mut().block_id)
                        .unwrap()
                        .clone();
                    cpu_block.deref_mut().refcount += 1;
                    cpu_block
                };
            new_block_table.push(cpu_block);
            self.gpu_allocator.free_block(gpu_block.clone());
        }
        self.block_tables.insert(seq_id, new_block_table);

        new_mapping
            .iter()
            .map(|(k, v)| (*k, v.deref_mut().block_id))
            .collect::<HashMap<_, _>>()
    }

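    /// Reserve space for one new token in the sequence's last block. Returns
    /// `Some((src_block_id, dst_block_id))` when a shared last block had to be copied
    /// on write, and `None` otherwise.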
    pub fn append_token_slot_to_seq(
        &mut self,
        sequence: &impl BlockEngineSequence,
    ) -> Option<(usize, usize)> {
        let table = self.block_tables.get_mut(&sequence.get_id())?;

        match sequence.blocks_to_add_new_tok() {
            1 => {
                // The last logical block is full: grow the table with a fresh GPU block.
                table.push(self.gpu_allocator.allocate());
                None
            }
            0 => {
                let last_block = table.last_mut().unwrap();
                assert!(last_block.deref_mut().is_gpu);
                if last_block.deref_mut().refcount == 1 {
                    // The block is not shared, so the new token can be written in place.
                    None
                } else {
                    // Copy on write: the block is shared with another sequence, so give
                    // this sequence its own copy and report the (src, dst) block ids.
                    let new_block = self.gpu_allocator.allocate();
                    self.gpu_allocator.free_block(last_block.clone());
                    let old_number = last_block.deref_mut().block_id;
                    let new_number = new_block.deref_mut().block_id;
                    *last_block = new_block;
                    Some((old_number, new_number))
                }
            }
            _ => {
                unreachable!()
            }
        }
    }

    pub fn can_swap_in_seq(&self, seq: &impl BlockEngineSequence) -> bool {
        let blocks_required: usize = self
            .block_tables
            .iter()
            .filter(|(id, _)| seq.get_id() == **id)
            .map(|(_, table)| table.len())
            .sum();
        blocks_required <= self.gpu_allocator.free_blocks.len()
    }

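    /// Move this sequence's blocks from CPU back to GPU memory, rebuilding its block
    /// table on the GPU side. Returns a map from old CPU block ids to new GPU block ids.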
    pub fn swap_in(&mut self, seq: &impl BlockEngineSequence) -> HashMap<usize, usize> {
        let mut new_mapping = HashMap::new();
        let seq_id = seq.get_id();

        let mut new_block_table = Vec::new();
        let block_table = self.block_tables.get(&seq_id).unwrap();

        for cpu_block in block_table {
            let gpu_block =
                if let Entry::Vacant(e) = new_mapping.entry(cpu_block.deref_mut().block_id) {
                    // This CPU block has not been mapped yet: take a fresh GPU block.
                    let gpu_block = self.gpu_allocator.allocate();
                    e.insert(gpu_block.clone());
                    gpu_block
                } else {
                    // Another entry in the table already maps to this block: share it.
                    let gpu_block = new_mapping
                        .get(&cpu_block.deref_mut().block_id)
                        .unwrap()
                        .clone();
                    gpu_block.deref_mut().refcount += 1;
                    gpu_block
                };
            new_block_table.push(gpu_block);
            self.cpu_allocator.free_block(cpu_block.clone());
        }
        self.block_tables.insert(seq_id, new_block_table);

        new_mapping
            .iter()
            .map(|(k, v)| (*k, v.deref_mut().block_id))
            .collect::<HashMap<_, _>>()
    }
}
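
// A minimal usage sketch of `BlockEngine`, assuming `BlockEngineSequence` exposes
// exactly the three accessors used above (`get_id`, `get_logical_token_blocks`,
// `blocks_to_add_new_tok`); `MockSeq` is a hypothetical stand-in for a real sequence.
#[cfg(test)]
mod usage_sketch {
    use super::*;

    struct MockSeq {
        id: usize,
        logical_blocks: usize,
        blocks_for_next_tok: usize,
    }

    impl BlockEngineSequence for MockSeq {
        fn get_id(&self) -> usize {
            self.id
        }
        fn get_logical_token_blocks(&self) -> usize {
            self.logical_blocks
        }
        fn blocks_to_add_new_tok(&self) -> usize {
            self.blocks_for_next_tok
        }
    }

    #[test]
    fn allocate_append_and_free() {
        // 16 tokens per block, 8 GPU blocks, 8 CPU blocks.
        let mut engine = BlockEngine::new(16, 8, 8);
        let seq = MockSeq {
            id: 0,
            logical_blocks: 2,
            blocks_for_next_tok: 0,
        };

        if let AllocStatus::Ok = engine.can_allocate(&seq) {
            engine.allocate(&seq);
        }
        assert_eq!(engine.block_tables[&0].len(), 2);

        // The last block is unshared, so appending a token needs no copy-on-write.
        assert!(engine.can_append_token_to_seq(&seq));
        assert_eq!(engine.append_token_slot_to_seq(&seq), None);

        // Releasing the sequence returns its blocks to the GPU free list.
        engine.free_sequence(0);
        assert!(engine.block_tables.get(&0).is_none());
    }
}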