mistralrs_core/dummy_paged_attention/scheduler.rs

type CPUBlockFrom = usize;
type GPUBlockFrom = usize;
type CPUBlockTo = usize;
type GPUBlockTo = usize;
type SrcBlockFrom = usize;
type DstBlocksTo = Vec<usize>;

use std::{
    collections::{HashMap, VecDeque},
    sync::{atomic::Ordering, Arc, Mutex},
};

use tracing::warn;

use crate::{
    engine::IntervalLogger,
    get_mut_arcmutex,
    paged_attention::BlockEngine,
    scheduler::{Scheduler, SchedulerOutput},
    sequence::{Sequence, SequenceState, StopReason},
    TERMINATE_ALL_NEXT_STEP,
};

use super::{block_engine::AllocStatus, BlockEngineSequence, BlockTables, CacheConfig};

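// Number of scheduling passes a sequence may spend waitlisted before the
// scheduler evicts a running sequence to make room for it.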
const WAITING_TIMEOUT: usize = 64;

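/// The result of one scheduling pass: the sequences to run this step, plus the
/// block swaps and copies the cache engine must perform beforehand.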
pub struct PagedAttentionSchedulerOutput {
    pub scheduled: Vec<Arc<Mutex<Sequence>>>,
    pub blocks_to_swap_in: HashMap<CPUBlockFrom, GPUBlockTo>,
    pub blocks_to_swap_out: HashMap<GPUBlockFrom, CPUBlockTo>,
    pub blocks_to_copy: HashMap<SrcBlockFrom, DstBlocksTo>,
}

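/// Scheduler configuration: `max_num_seqs` caps how many sequences may run
/// concurrently in one step.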
pub struct PagedAttentionSchedulerConfig {
    pub max_num_seqs: usize,
}

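/// A vLLM-style scheduler: sequences move between the `waiting`, `running`, and
/// `swapped_out` queues, while their KV-cache blocks are tracked by the shared
/// `BlockEngine`.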
pub struct PagedAttentionScheduler {
    waiting: VecDeque<Arc<Mutex<Sequence>>>,
    running: VecDeque<Arc<Mutex<Sequence>>>,
    swapped_out: VecDeque<Arc<Mutex<Sequence>>>,
    config: PagedAttentionSchedulerConfig,
    pub block_engine: Arc<tokio::sync::Mutex<BlockEngine>>,
    block_size: usize,
}

impl PagedAttentionScheduler {
    pub fn new(config: PagedAttentionSchedulerConfig, cache_config: CacheConfig) -> Self {
        Self {
            waiting: VecDeque::new(),
            running: VecDeque::new(),
            swapped_out: VecDeque::new(),
            config,
            block_engine: Arc::new(tokio::sync::Mutex::new(BlockEngine::new(
                cache_config.block_size,
                cache_config.num_gpu_blocks,
                cache_config.num_cpu_blocks,
            ))),
            block_size: cache_config.block_size,
        }
    }

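    /// Run one scheduling pass. When nothing is swapped out, waiting sequences
    /// are admitted for their prompt (prefill) phase; otherwise running
    /// sequences continue decoding, with preemption and swap-in as cache space
    /// allows.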
    pub fn schedule(&mut self, logger: &IntervalLogger) -> PagedAttentionSchedulerOutput {
        // If there are no swapped-out sequences, try to admit waiting sequences
        // for their prompt (prefill) phase.
        if self.swapped_out.is_empty() {
            let mut scheduled: VecDeque<Arc<Mutex<Sequence>>> = VecDeque::new();
            let mut for_waiting_again: VecDeque<Arc<Mutex<Sequence>>> = VecDeque::new();
            let mut did_ignore = false;
            while !self.waiting.is_empty() {
                let seq = self.waiting.front().unwrap().clone();

                // Stop admitting once the running-sequence cap would be reached.
                if self.config.max_num_seqs == self.running.len() + 1 {
                    break;
                }

                let can_allocate =
                    get_mut_arcmutex!(self.block_engine).can_allocate(&mut *get_mut_arcmutex!(seq));
                match can_allocate {
                    AllocStatus::Later { waitlisted_count } => {
                        // The sequence cannot be allocated yet. If it has been
                        // waitlisted too long, evict a running sequence and
                        // retry; otherwise leave it in the waiting queue.
                        if waitlisted_count > WAITING_TIMEOUT {
                            if let Some(seq_to_preempt) = self.running.pop_back() {
                                self._preempt_by_recompute(seq_to_preempt);

                                if !matches!(
                                    get_mut_arcmutex!(self.block_engine)
                                        .can_allocate(&mut *get_mut_arcmutex!(seq)),
                                    AllocStatus::Ok
                                ) {
                                    let id = *get_mut_arcmutex!(seq).id();
                                    let len = get_mut_arcmutex!(seq).get_toks().len();
                                    warn!(
                                        "Sequence {id} with length of {len} tokens still exceeds KV cache size \
                                        even after evicting another sequence.",
                                    );
                                    get_mut_arcmutex!(seq)
                                        .set_state(SequenceState::FinishedIgnored);
                                    did_ignore = true;
                                }
                            } else {
                                // Nothing left to evict: the sequence simply
                                // does not fit in the KV cache.
                                let id = *get_mut_arcmutex!(seq).id();
                                let len = get_mut_arcmutex!(seq).get_toks().len();
                                warn!(
                                    "Sequence {id} with length of {len} tokens is too long and exceeds KV cache size. \
                                    To fix, increase the maximum sequence length for the KV cache, for example with \
                                    `--max-seq-len` / `max_seq_len` in automatic device mapping parameters.",
                                );
                                get_mut_arcmutex!(seq)
                                    .set_state(SequenceState::FinishedIgnored);
                                did_ignore = true;
                            }
                        } else {
                            break;
                        }
                    }
                    AllocStatus::Impossible => {
                        let id = *get_mut_arcmutex!(seq).id();
                        let len = get_mut_arcmutex!(seq).get_toks().len();
                        warn!(
                            "Sequence {id} with length of {len} tokens is too long and exceeds KV cache size. \
                            To fix, increase the maximum sequence length for the KV cache, for example with \
                            `--max-seq-len` / `max_seq_len` in automatic device mapping parameters.",
                        );
                        get_mut_arcmutex!(seq).set_state(SequenceState::FinishedIgnored);
                        did_ignore = true;
                    }
                    _ => {}
                }

                // Do not batch image sequences together with text-only sequences.
                let new_seq_has_images = get_mut_arcmutex!(seq).has_images();
                if !scheduled.is_empty()
                    && get_mut_arcmutex!(scheduled[0]).has_images() != new_seq_has_images
                {
                    let seq = self.waiting.pop_front().unwrap();
                    for_waiting_again.push_back(seq.clone());
                    continue;
                }
                if !did_ignore {
                    get_mut_arcmutex!(seq).set_state(SequenceState::RunningPrompt);
                    let mut seq_handle = get_mut_arcmutex!(seq);
                    self._allocate(&mut seq_handle);
                }

                let seq = self.waiting.pop_front().unwrap();
                self.running.push_back(seq.clone());
                if !did_ignore {
                    scheduled.push_back(seq);
                }
            }
            self.waiting.extend(for_waiting_again);

            // If any sequences were scheduled (or ignored), this pass is a
            // prefill step and involves no block swaps or copies.
            if !scheduled.is_empty() || did_ignore {
                logger.set_num_running(self.running.len());
                logger.set_num_waiting(self.waiting.len() + self.swapped_out.len());

                return PagedAttentionSchedulerOutput {
                    scheduled: scheduled.into(),
                    blocks_to_swap_in: HashMap::new(),
                    blocks_to_copy: HashMap::new(),
                    blocks_to_swap_out: HashMap::new(),
                };
            }
        }

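        // Decode phase: make sure every running sequence can append its next
        // token, preempting from the back of the queue when blocks run out.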
        let mut blocks_to_swap_out = HashMap::new();
        let mut blocks_to_swap_in = HashMap::new();
        let mut blocks_to_copy = HashMap::new();

        self.sort_running_by_priority_fcfs();

        let mut running: VecDeque<Arc<Mutex<Sequence>>> = VecDeque::new();
        let mut did_preempt = false;
        while !self.running.is_empty() {
            let seq = self.running.pop_front().unwrap();
            let mut finished_with_break = false;
            while !get_mut_arcmutex!(self.block_engine)
                .can_append_token_to_seq(&*get_mut_arcmutex!(seq))
            {
                // Free up space by preempting another running sequence if
                // possible; otherwise preempt this sequence itself.
                if !self.running.is_empty() {
                    let seq_to_preempt = self.running.pop_back().unwrap();
                    self._preempt(seq_to_preempt, &mut blocks_to_swap_out);
                    did_preempt = true;
                } else {
                    self._preempt(seq.clone(), &mut blocks_to_swap_out);
                    did_preempt = true;
                    finished_with_break = true;
                    break;
                }
            }
            if !finished_with_break {
                {
                    // Reserve the slot for this sequence's next token; any
                    // required copy-on-write is recorded in `blocks_to_copy`.
                    let seq_handle = get_mut_arcmutex!(seq);
                    self._append_token_slot_to_seq(&seq_handle, &mut blocks_to_copy);
                }
                // Keep image and text-only sequences in separate batches.
                let new_seq_has_images = get_mut_arcmutex!(seq).has_images();
                if running.is_empty()
                    || get_mut_arcmutex!(running[0]).has_images() == new_seq_has_images
                {
                    running.push_back(seq);
                } else {
                    self.running.push_back(seq);
                }
            }
        }
        self.running = running;

        self.sort_swapped_out_by_priority_fcfs();

        // Swap sequences back in, but only if nothing was preempted this pass.
        if !did_preempt {
            while !self.swapped_out.is_empty() {
                let seq = self.swapped_out.front().unwrap();

                // Stop as soon as a sequence cannot be swapped back in.
                if !get_mut_arcmutex!(self.block_engine).can_swap_in_seq(&*get_mut_arcmutex!(seq)) {
                    break;
                }

                let seq = self.swapped_out.pop_front().unwrap();
                let to_swap_in =
                    get_mut_arcmutex!(self.block_engine).swap_in(&*get_mut_arcmutex!(seq));
                blocks_to_swap_in.extend(to_swap_in);
                {
                    let seq_handle = get_mut_arcmutex!(seq);
                    self._append_token_slot_to_seq(&seq_handle, &mut blocks_to_copy);
                }
                self.running.push_back(seq);
            }
        }

        self.running
            .iter()
            .for_each(|seq| get_mut_arcmutex!(seq).set_state(SequenceState::RunningCompletion));

        // Honor a pending global cancellation request.
        if TERMINATE_ALL_NEXT_STEP.load(Ordering::SeqCst) {
            self.running.iter().for_each(|seq| {
                get_mut_arcmutex!(seq).set_state(SequenceState::Done(StopReason::Canceled))
            });
            TERMINATE_ALL_NEXT_STEP.store(false, Ordering::SeqCst);
        }

        logger.set_num_running(self.running.len());
        logger.set_num_waiting(self.waiting.len() + self.swapped_out.len());

        PagedAttentionSchedulerOutput {
            scheduled: self.running.clone().into(),
            blocks_to_swap_in,
            blocks_to_copy,
            blocks_to_swap_out,
        }
    }

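    /// Remove finished sequences from the running queue and free their blocks.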
    pub fn free_finished_sequence_groups(&mut self) {
        let mut to_free_ids = Vec::new();
        self.running.retain(|seq| {
            if get_mut_arcmutex!(seq).is_finished_paged_attn() {
                to_free_ids.push(get_mut_arcmutex!(seq).get_id());
                false
            } else {
                true
            }
        });

        for id in to_free_ids {
            self._free(id);
        }
    }
}

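// Internal helpers: queue removal, preemption, and block-engine bookkeeping.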
impl PagedAttentionScheduler {
    #[allow(dead_code)]
    fn remove_seq(&mut self, seq_id: usize) -> Arc<Mutex<Sequence>> {
        if let Some(idx) = self
            .waiting
            .iter()
            .position(|other| get_mut_arcmutex!(other).get_id() == seq_id)
        {
            return self.waiting.remove(idx).unwrap();
        };
        if let Some(idx) = self
            .running
            .iter()
            .position(|other| get_mut_arcmutex!(other).get_id() == seq_id)
        {
            return self.running.remove(idx).unwrap();
        };
        if let Some(idx) = self
            .swapped_out
            .iter()
            .position(|other| get_mut_arcmutex!(other).get_id() == seq_id)
        {
            return self.swapped_out.remove(idx).unwrap();
        };
        panic!("Attempted to remove sequence id {seq_id} but it is not running, waiting, or swapped out.");
    }

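    /// Reserve the slot for a sequence's next token; if the block engine
    /// reports a copy-on-write, record it in `blocks_to_copy`.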
    fn _append_token_slot_to_seq(
        &mut self,
        seq: &Sequence,
        blocks_to_copy: &mut HashMap<usize, Vec<usize>>,
    ) {
        let op = get_mut_arcmutex!(self.block_engine).append_token_slot_to_seq(seq);
        if let Some((src_block, dst_block)) = op {
            blocks_to_copy.entry(src_block).or_default().push(dst_block);
        }
    }

    fn _abort_seq(&mut self, seq_id: usize) {
        let removed = self.remove_seq(seq_id);
        get_mut_arcmutex!(removed).set_state(SequenceState::FinishedAborted);
        self._free(seq_id);
    }

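    /// Preempt a sequence. This scheduler always preempts by recomputation:
    /// the sequence's blocks are freed and it is re-queued for prefill, so the
    /// swap-out map is left untouched.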
    fn _preempt(
        &mut self,
        seq: Arc<Mutex<Sequence>>,
        _blocks_to_swap_out: &mut HashMap<usize, usize>,
    ) {
        self._preempt_by_recompute(seq)
    }

    fn _preempt_by_recompute(&mut self, seq: Arc<Mutex<Sequence>>) {
        get_mut_arcmutex!(seq).set_state(SequenceState::Waiting);
        self._free(get_mut_arcmutex!(seq).get_id());
        self.waiting.push_front(seq);
    }

    fn _preempt_by_swap(
        &mut self,
        seq: Arc<Mutex<Sequence>>,
        blocks_to_swap_out: &mut HashMap<usize, usize>,
    ) {
        if !get_mut_arcmutex!(self.block_engine).can_swap_out_seq(&*get_mut_arcmutex!(seq)) {
            // The sequence cannot be swapped out: abort it instead.
            let id = get_mut_arcmutex!(seq).get_id();
            self._abort_seq(id);
            return;
        }
        let new_to_swap = get_mut_arcmutex!(self.block_engine).swap_out(&*get_mut_arcmutex!(seq));
        blocks_to_swap_out.extend(new_to_swap);
        get_mut_arcmutex!(seq).set_state(SequenceState::Swapped);

        self.swapped_out.push_back(seq);
    }

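    /// Allocate KV-cache blocks for a sequence that is entering `running`.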
    fn _allocate(&mut self, seq: &mut Sequence) {
        get_mut_arcmutex!(self.block_engine).allocate(seq)
    }

    /// Release all cache blocks held by the given sequence.
    fn _free(&mut self, seq_id: usize) {
        get_mut_arcmutex!(self.block_engine).free_sequence(seq_id);
    }

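    /// FCFS ordering: sort by arrival timestamp, then reverse so the most
    /// recently arrived sequence sits at the front of the queue.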
    fn sort_running_by_priority_fcfs(&mut self) {
        self.running
            .make_contiguous()
            .sort_by_key(|seq| get_mut_arcmutex!(seq).timestamp());
        self.running.make_contiguous().reverse();
    }

    fn sort_swapped_out_by_priority_fcfs(&mut self) {
        self.swapped_out
            .make_contiguous()
            .sort_by_key(|seq| get_mut_arcmutex!(seq).timestamp());
        self.swapped_out.make_contiguous().reverse();
    }
}

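// Expose the paged-attention scheduler through the engine's generic `Scheduler` trait.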
impl Scheduler for PagedAttentionScheduler {
    fn add_seq(&mut self, seq: Sequence) {
        self.waiting.push_back(Arc::new(Mutex::new(seq)));
    }
    fn schedule(&mut self, logger: &IntervalLogger) -> SchedulerOutput<'_> {
        SchedulerOutput::PagedAttention {
            output: self.schedule(logger),
        }
    }
    fn waiting_len(&self) -> usize {
        self.waiting.len() + self.swapped_out.len()
    }
    fn running_len(&self) -> usize {
        self.running.len()
    }
    fn block_tables(&self) -> Option<BlockTables> {
        Some(get_mut_arcmutex!(self.block_engine).block_tables.clone())
    }
    fn block_size(&self) -> Option<usize> {
        Some(self.block_size)
    }
    fn free_finished_sequence_groups(&mut self) {
        self.free_finished_sequence_groups()
    }
    fn block_engine(&self) -> Option<Arc<tokio::sync::Mutex<BlockEngine>>> {
        Some(self.block_engine.clone())
    }
}