mistralrs_core/dummy_paged_attention/scheduler.rs

type CPUBlockFrom = usize;
type GPUBlockFrom = usize;
type CPUBlockTo = usize;
type GPUBlockTo = usize;
type SrcBlockFrom = usize;
type DstBlocksTo = Vec<usize>;

use std::{
    collections::{HashMap, VecDeque},
    sync::{atomic::Ordering, Arc, Mutex},
};

use tracing::warn;

use crate::{
    get_mut_arcmutex,
    paged_attention::BlockEngine,
    scheduler::{Scheduler, SchedulerOutput},
    sequence::{Sequence, SequenceState, StopReason},
    TERMINATE_ALL_NEXT_STEP,
};

use super::{block_engine::AllocStatus, BlockEngineSequence, BlockTables, CacheConfig};

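/// Output of one scheduling step: the sequences to run, plus the block-level
/// cache operations (swaps between CPU and GPU, and block copies) that must
/// be performed before execution.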
pub struct PagedAttentionSchedulerOutput {
    pub scheduled: Vec<Arc<Mutex<Sequence>>>,
    pub blocks_to_swap_in: HashMap<CPUBlockFrom, GPUBlockTo>,
    pub blocks_to_swap_out: HashMap<GPUBlockFrom, CPUBlockTo>,
    pub blocks_to_copy: HashMap<SrcBlockFrom, DstBlocksTo>,
}

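/// `max_num_seqs` caps how many sequences may be running at once.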
pub struct PagedAttentionSchedulerConfig {
    pub max_num_seqs: usize,
}

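/// FCFS scheduler over three queues: `waiting` (not yet admitted), `running`
/// (currently scheduled), and `swapped_out` (preempted to CPU memory).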
pub struct PagedAttentionScheduler {
    waiting: VecDeque<Arc<Mutex<Sequence>>>,
    running: VecDeque<Arc<Mutex<Sequence>>>,
    swapped_out: VecDeque<Arc<Mutex<Sequence>>>,
    config: PagedAttentionSchedulerConfig,
    pub block_engine: BlockEngine,
    block_size: usize,
}

impl PagedAttentionScheduler {
    pub fn new(config: PagedAttentionSchedulerConfig, cache_config: CacheConfig) -> Self {
        Self {
            waiting: VecDeque::new(),
            running: VecDeque::new(),
            swapped_out: VecDeque::new(),
            config,
            block_engine: BlockEngine::new(
                cache_config.block_size,
                cache_config.num_gpu_blocks,
                cache_config.num_cpu_blocks,
            ),
            block_size: cache_config.block_size,
        }
    }

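    /// Run one scheduling step: admit waiting sequences when nothing is
    /// swapped out; otherwise ensure each running sequence can append its
    /// next token, preempting or swapping sequences back in as needed.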
    pub fn schedule(&mut self) -> PagedAttentionSchedulerOutput {
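        // Only admit new waiting sequences while nothing is swapped out;
        // swapped-out sequences must be resumed first.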
        if self.swapped_out.is_empty() {
            let mut scheduled: VecDeque<Arc<Mutex<Sequence>>> = VecDeque::new();
            let mut for_waiting_again: VecDeque<Arc<Mutex<Sequence>>> = VecDeque::new();
            let mut did_ignore = false;
            while !self.waiting.is_empty() {
                let seq = self.waiting.front().unwrap().clone();

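                // Respect the cap on concurrently running sequences.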
                if self.config.max_num_seqs == self.running.len() + 1 {
                    break;
                }

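                // Ask the block engine whether KV-cache blocks can be
                // allocated for this sequence: now, later, or never.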
                let can_allocate = self.block_engine.can_allocate(&*get_mut_arcmutex!(seq));
                match can_allocate {
                    AllocStatus::Later => break,
                    AllocStatus::Impossible => {
                        let id = *get_mut_arcmutex!(seq).id();
                        let len = get_mut_arcmutex!(seq).get_toks().len();
                        warn!(
                            "Sequence {id} with length of {len} tokens is too long and exceeds the capacity of the block engine. Sequence will be ignored.",
                        );
                        get_mut_arcmutex!(seq).set_state(SequenceState::FinishedIgnored);
                        did_ignore = true;
                    }
                    _ => {}
                }

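                // Only batch sequences that agree on whether they have
                // images; a mismatched sequence is deferred back to waiting.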
                let new_seq_has_images = get_mut_arcmutex!(seq).has_images();
                if !scheduled.is_empty()
                    && get_mut_arcmutex!(scheduled[0]).has_images() != new_seq_has_images
                {
                    let seq = self.waiting.pop_front().unwrap();
                    for_waiting_again.push_back(seq.clone());
                    continue;
                }
                if !did_ignore {
                    get_mut_arcmutex!(seq).set_state(SequenceState::RunningPrompt);
                    let seq_handle = get_mut_arcmutex!(seq);
                    self._allocate(&seq_handle);
                }

                let seq = self.waiting.pop_front().unwrap();
                self.running.push_back(seq.clone());
                if !did_ignore {
                    scheduled.push_back(seq);
                }
            }
            self.waiting.extend(for_waiting_again);

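            // If anything was admitted (or ignored), run the prompt batch
            // now; fresh prompts need no swaps or block copies.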
            if !scheduled.is_empty() || did_ignore {
                return PagedAttentionSchedulerOutput {
                    scheduled: scheduled.into(),
                    blocks_to_swap_in: HashMap::new(),
                    blocks_to_copy: HashMap::new(),
                    blocks_to_swap_out: HashMap::new(),
                };
            }
        }

        let mut blocks_to_swap_out = HashMap::new();
        let mut blocks_to_swap_in = HashMap::new();
        let mut blocks_to_copy = HashMap::new();

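        // Completion phase: order the running queue by FCFS priority, then
        // make sure each sequence has room to append its next token,
        // preempting other running sequences when blocks run out.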
        self.sort_running_by_priority_fcfs();

        let mut running: VecDeque<Arc<Mutex<Sequence>>> = VecDeque::new();
        let mut did_preempt = false;
        while !self.running.is_empty() {
            let seq = self.running.pop_front().unwrap();
            let mut finished_with_break = false;
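            // While there is no free slot for this sequence's next token,
            // preempt the sequence at the back of the queue, or this one if
            // it is the last sequence left.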
            while !self
                .block_engine
                .can_append_token_to_seq(&*get_mut_arcmutex!(seq))
            {
                if !self.running.is_empty() {
                    let seq_to_preempt = self.running.pop_back().unwrap();
                    self._preempt(seq_to_preempt, &mut blocks_to_swap_out);
                    did_preempt = true;
                } else {
                    self._preempt(seq.clone(), &mut blocks_to_swap_out);
                    did_preempt = true;
                    finished_with_break = true;
                    break;
                }
            }
            if !finished_with_break {
                {
                    let seq_handle = get_mut_arcmutex!(seq);
                    self._append_token_slot_to_seq(&seq_handle, &mut blocks_to_copy);
                }
                let new_seq_has_images = get_mut_arcmutex!(seq).has_images();
                if running.is_empty()
                    || get_mut_arcmutex!(running[0]).has_images() == new_seq_has_images
                {
                    running.push_back(seq);
                } else {
                    self.running.push_back(seq);
                }
            }
        }
        self.running = running;

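        // If nothing was preempted this step, swap previously swapped-out
        // sequences back in, in FCFS order, while space allows.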
        self.sort_swapped_out_by_priority_fcfs();

        if !did_preempt {
            while !self.swapped_out.is_empty() {
                let seq = self.swapped_out.front().unwrap();

                if !self.block_engine.can_swap_in_seq(&*get_mut_arcmutex!(seq)) {
                    break;
                }

                let seq = self.swapped_out.pop_front().unwrap();
                let to_swap_in = self.block_engine.swap_in(&*get_mut_arcmutex!(seq));
                blocks_to_swap_in.extend(to_swap_in);
                {
                    let seq_handle = get_mut_arcmutex!(seq);
                    self._append_token_slot_to_seq(&seq_handle, &mut blocks_to_copy);
                }
                self.running.push_back(seq);
            }
        }

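        // Everything still running proceeds to the completion (decode) stage.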
        self.running
            .iter()
            .for_each(|seq| get_mut_arcmutex!(seq).set_state(SequenceState::RunningCompletion));

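        // Global cancellation: mark every running sequence as cancelled,
        // then clear the flag.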
        if TERMINATE_ALL_NEXT_STEP.load(Ordering::SeqCst) {
            self.running.iter().for_each(|seq| {
                get_mut_arcmutex!(seq).set_state(SequenceState::Done(StopReason::Canceled))
            });
            TERMINATE_ALL_NEXT_STEP.store(false, Ordering::SeqCst);
        }

        PagedAttentionSchedulerOutput {
            scheduled: self.running.clone().into(),
            blocks_to_swap_in,
            blocks_to_copy,
            blocks_to_swap_out,
        }
    }

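    /// Drop finished sequences from the running queue and free their cache blocks.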
    pub fn free_finished_sequence_groups(&mut self) {
        let mut to_free_ids = Vec::new();
        self.running.retain(|seq| {
            if get_mut_arcmutex!(seq).is_finished_paged_attn() {
                to_free_ids.push(get_mut_arcmutex!(seq).get_id());
                false
            } else {
                true
            }
        });

        for id in to_free_ids {
            self._free(id);
        }
    }
}

impl PagedAttentionScheduler {
    #[allow(dead_code)]
    fn remove_seq(&mut self, seq_id: usize) -> Arc<Mutex<Sequence>> {
        if let Some(idx) = self
            .waiting
            .iter()
            .position(|other| get_mut_arcmutex!(other).get_id() == seq_id)
        {
            return self.waiting.remove(idx).unwrap();
        };
        if let Some(idx) = self
            .running
            .iter()
            .position(|other| get_mut_arcmutex!(other).get_id() == seq_id)
        {
            return self.running.remove(idx).unwrap();
        };
        if let Some(idx) = self
            .swapped_out
            .iter()
            .position(|other| get_mut_arcmutex!(other).get_id() == seq_id)
        {
            return self.swapped_out.remove(idx).unwrap();
        };
        panic!("Attempted to remove sequence id {seq_id} but it is not running, waiting, or swapped out.");
    }

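    // Reserve a slot for the sequence's next token; if the block engine
    // reports a copy (a shared block had to be forked), record the
    // src -> dst block copy for the cache engine to perform.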
    fn _append_token_slot_to_seq(
        &mut self,
        seq: &Sequence,
        blocks_to_copy: &mut HashMap<usize, Vec<usize>>,
    ) {
        let op = self.block_engine.append_token_slot_to_seq(seq);
        if let Some((src_block, dst_block)) = op {
            if let std::collections::hash_map::Entry::Vacant(e) = blocks_to_copy.entry(src_block) {
                e.insert(vec![dst_block]);
            } else {
                blocks_to_copy.get_mut(&src_block).unwrap().push(dst_block);
            }
        }
    }

    fn _abort_seq(&mut self, seq_id: usize) {
        let removed = self.remove_seq(seq_id);
        get_mut_arcmutex!(removed).set_state(SequenceState::FinishedAborted);
        self._free(seq_id);
    }

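    // Preemption always recomputes here; swap-based preemption
    // (`_preempt_by_swap`) is kept but unused, so the swap-out map is ignored.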
    fn _preempt(
        &mut self,
        seq: Arc<Mutex<Sequence>>,
        _blocks_to_swap_out: &mut HashMap<usize, usize>,
    ) {
        self._preempt_by_recompute(seq)
    }

    fn _preempt_by_recompute(&mut self, seq: Arc<Mutex<Sequence>>) {
        get_mut_arcmutex!(seq).set_state(SequenceState::Waiting);
        self._free(get_mut_arcmutex!(seq).get_id());
        self.waiting.push_front(seq);
    }

    fn _preempt_by_swap(
        &mut self,
        seq: Arc<Mutex<Sequence>>,
        blocks_to_swap_out: &mut HashMap<usize, usize>,
    ) {
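        // If the sequence cannot be swapped out to CPU, abort it instead.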
        if !self.block_engine.can_swap_out_seq(&*get_mut_arcmutex!(seq)) {
            let id = get_mut_arcmutex!(seq).get_id();
            self._abort_seq(id);
            return;
        }
        let new_to_swap = self.block_engine.swap_out(&*get_mut_arcmutex!(seq));
        blocks_to_swap_out.extend(new_to_swap);
        get_mut_arcmutex!(seq).set_state(SequenceState::Swapped);

        self.swapped_out.push_back(seq);
    }

    fn _allocate(&mut self, seq: &Sequence) {
        self.block_engine.allocate(seq)
    }

    fn _free(&mut self, seq_id: usize) {
        self.block_engine.free_sequence(seq_id);
    }

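    // FCFS ordering helpers: sort each queue by sequence timestamp, then
    // reverse to obtain the priority order used above.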
    fn sort_running_by_priority_fcfs(&mut self) {
        self.running
            .make_contiguous()
            .sort_by_key(|seq| get_mut_arcmutex!(seq).timestamp());
        self.running.make_contiguous().reverse();
    }

    fn sort_swapped_out_by_priority_fcfs(&mut self) {
        self.swapped_out
            .make_contiguous()
            .sort_by_key(|seq| get_mut_arcmutex!(seq).timestamp());
        self.swapped_out.make_contiguous().reverse();
    }
}

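// Expose the paged-attention scheduler through the generic `Scheduler` trait.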
impl Scheduler for PagedAttentionScheduler {
    fn add_seq(&mut self, seq: Sequence) {
        self.waiting.push_back(Arc::new(Mutex::new(seq)));
    }
    fn schedule(&mut self) -> SchedulerOutput<'_> {
        SchedulerOutput::PagedAttention {
            output: self.schedule(),
        }
    }
    fn waiting_len(&self) -> usize {
        self.waiting.len() + self.swapped_out.len()
    }
    fn running_len(&self) -> usize {
        self.running.len()
    }
    fn block_tables(&self) -> Option<&BlockTables> {
        Some(&self.block_engine.block_tables)
    }
    fn block_size(&self) -> Option<usize> {
        Some(self.block_size)
    }
    fn free_finished_sequence_groups(&mut self) {
        self.free_finished_sequence_groups()
    }
    fn block_engine(&mut self) -> Option<&mut BlockEngine> {
        Some(&mut self.block_engine)
    }
}