mistralrs_core/dummy_paged_attention/scheduler.rs

//! The Scheduler uses a BlockEngine to schedule and automatically batch sequences. The
//! primary method `schedule` returns the batched sequences as inputs, as well as the
//! operations to be executed on the cache by the CacheEngine.

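// Aliases over raw block indices, documenting the direction of each cache
// operation: swap-in maps CPU blocks to GPU blocks, swap-out the reverse, and
// copy-on-write fans a single source block out to several destination blocks.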
type CPUBlockFrom = usize;
type GPUBlockFrom = usize;
type CPUBlockTo = usize;
type GPUBlockTo = usize;
type SrcBlockFrom = usize;
type DstBlocksTo = Vec<usize>;

use std::{
    collections::{HashMap, VecDeque},
    sync::{atomic::Ordering, Arc, Mutex},
};

use tracing::warn;

use crate::{
    engine::IntervalLogger,
    get_mut_arcmutex,
    paged_attention::BlockEngine,
    scheduler::{Scheduler, SchedulerOutput},
    sequence::{Sequence, SequenceState, StopReason},
    TERMINATE_ALL_NEXT_STEP,
};

use super::{block_engine::AllocStatus, BlockEngineSequence, BlockTables, CacheConfig};

/// Allow a sequence to wait for 64 scheduling passes before trying to free
/// space for it by preempting a running sequence.
const WAITING_TIMEOUT: usize = 64;

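/// The result of one scheduling pass. The block maps describe the cache
/// operations the CacheEngine must apply before the scheduled sequences run.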
pub struct PagedAttentionSchedulerOutput {
    /// Either ALL prompt or ALL completion.
    pub scheduled: Vec<Arc<Mutex<Sequence>>>,
    pub blocks_to_swap_in: HashMap<CPUBlockFrom, GPUBlockTo>,
    pub blocks_to_swap_out: HashMap<GPUBlockFrom, CPUBlockTo>,
    pub blocks_to_copy: HashMap<SrcBlockFrom, DstBlocksTo>,
}

pub struct PagedAttentionSchedulerConfig {
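    /// Upper bound on the number of sequences scheduled to run at once.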
    pub max_num_seqs: usize,
}

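/// Scheduler for PagedAttention. Sequences move between three queues:
/// `waiting` (not yet admitted, or preempted by recomputation), `running`
/// (KV blocks allocated on the GPU), and `swapped_out` (blocks moved to CPU).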
pub struct PagedAttentionScheduler {
    waiting: VecDeque<Arc<Mutex<Sequence>>>,
    running: VecDeque<Arc<Mutex<Sequence>>>,
    swapped_out: VecDeque<Arc<Mutex<Sequence>>>,
    config: PagedAttentionSchedulerConfig,
    pub block_engine: Arc<tokio::sync::Mutex<BlockEngine>>,
    block_size: usize,
}

impl PagedAttentionScheduler {
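    /// Create a scheduler with empty queues and a fresh [`BlockEngine`].
    ///
    /// A minimal usage sketch; the concrete values (and the exact field set of
    /// `CacheConfig`) are illustrative assumptions, not crate defaults:
    ///
    /// ```ignore
    /// let scheduler = PagedAttentionScheduler::new(
    ///     PagedAttentionSchedulerConfig { max_num_seqs: 16 },
    ///     CacheConfig {
    ///         block_size: 32,
    ///         num_gpu_blocks: 512,
    ///         num_cpu_blocks: 128,
    ///     },
    /// );
    /// ```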
    pub fn new(config: PagedAttentionSchedulerConfig, cache_config: CacheConfig) -> Self {
        Self {
            waiting: VecDeque::new(),
            running: VecDeque::new(),
            swapped_out: VecDeque::new(),
            config,
            block_engine: Arc::new(tokio::sync::Mutex::new(BlockEngine::new(
                cache_config.block_size,
                cache_config.num_gpu_blocks,
                cache_config.num_cpu_blocks,
            ))),
            block_size: cache_config.block_size,
        }
    }

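    /// Run one scheduling pass. First, if nothing is swapped out, try to admit
    /// waiting sequences as a prompt batch; if any were admitted (or ignored),
    /// return them immediately. Otherwise reserve a slot for each running
    /// sequence's next token, preempting from the back of the FCFS-sorted
    /// queue when the block engine runs out of space, and finally, if nothing
    /// was preempted, swap previously swapped-out sequences back onto the GPU.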
    pub fn schedule(&mut self, logger: &IntervalLogger) -> PagedAttentionSchedulerOutput {
        // If there are no swapped seqs (they have higher priority), add seqs that are in the
        // waiting queue to the running queue.
        if self.swapped_out.is_empty() {
            let mut scheduled: VecDeque<Arc<Mutex<Sequence>>> = VecDeque::new();
            let mut for_waiting_again: VecDeque<Arc<Mutex<Sequence>>> = VecDeque::new();
            let mut did_ignore = false;
            while !self.waiting.is_empty() {
                let seq = self.waiting.front().unwrap().clone();

                // If adding this seq means we will have too many, stop as no more could be added.
                if self.config.max_num_seqs == self.running.len() + 1 {
                    break;
                }

                // If we cannot allocate now or in the future, either keep waiting or ignore the sequence.
                let can_allocate =
                    get_mut_arcmutex!(self.block_engine).can_allocate(&mut *get_mut_arcmutex!(seq));
                match can_allocate {
                    AllocStatus::Later { waitlisted_count } => {
                        // If the sequence has waited too long, try to free space by evicting a
                        // low-priority running sequence instead of permanently ignoring it.
                        if waitlisted_count > WAITING_TIMEOUT {
                            // Attempt to preempt the least-recently created running sequence
                            // (the back of the `running` queue after the FCFS sort).
                            if let Some(seq_to_preempt) = self.running.pop_back() {
                                // Move the running sequence back to the waiting queue, freeing its KV cache.
                                self._preempt_by_recompute(seq_to_preempt);

                                // Retry allocation for the current sequence now that space is freed.
                                if !matches!(
                                    get_mut_arcmutex!(self.block_engine)
                                        .can_allocate(&mut *get_mut_arcmutex!(seq)),
                                    AllocStatus::Ok
                                ) {
                                    // Even after the eviction the sequence still does not fit,
                                    // so fall back to finishing it as ignored.
                                    let id = *get_mut_arcmutex!(seq).id();
                                    let len = get_mut_arcmutex!(seq).get_toks().len();
                                    warn!(
                                        "Sequence {id} with length of {len} tokens still exceeds KV cache size \
                                         even after evicting another sequence.",
                                    );
                                    get_mut_arcmutex!(seq)
                                        .set_state(SequenceState::FinishedIgnored);
                                    did_ignore = true;
                                }
                            } else {
                                // No running sequence is available to evict; ignore this sequence.
                                let id = *get_mut_arcmutex!(seq).id();
                                let len = get_mut_arcmutex!(seq).get_toks().len();
                                warn!(
                                    "Sequence {id} with length of {len} tokens is too long and exceeds KV cache size. \
                                     To fix, increase the maximum sequence length for the KV cache, for example with \
                                     `--max-seq-len` / `max_seq_len` in automatic device mapping parameters.",
                                );
                                get_mut_arcmutex!(seq).set_state(SequenceState::FinishedIgnored);
                                did_ignore = true;
                            }
                        } else {
                            // Keep waiting until the timeout threshold is reached.
                            break;
                        }
                    }
                    AllocStatus::Impossible => {
                        let id = *get_mut_arcmutex!(seq).id();
                        let len = get_mut_arcmutex!(seq).get_toks().len();
                        warn!(
                            "Sequence {id} with length of {len} tokens is too long and exceeds KV cache size. \
                             To fix, increase the maximum sequence length for the KV cache, for example with \
                             `--max-seq-len` / `max_seq_len` in automatic device mapping parameters.",
                        );
                        get_mut_arcmutex!(seq).set_state(SequenceState::FinishedIgnored);
                        did_ignore = true;
                    }
                    _ => {}
                }

                let new_seq_has_images = get_mut_arcmutex!(seq).has_images();
                // Only batch this sequence if its image status matches the sequences scheduled so far.
                if !scheduled.is_empty()
                    && get_mut_arcmutex!(scheduled[0]).has_images() != new_seq_has_images
                {
                    let seq = self.waiting.pop_front().unwrap();
                    for_waiting_again.push_back(seq.clone());
                    continue;
                }
                if !did_ignore {
                    get_mut_arcmutex!(seq).set_state(SequenceState::RunningPrompt);
                    let mut seq_handle = get_mut_arcmutex!(seq);
                    self._allocate(&mut seq_handle);
                }

                let seq = self.waiting.pop_front().unwrap();
                self.running.push_back(seq.clone());
                if !did_ignore {
                    scheduled.push_back(seq);
                }
            }
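            // Sequences skipped for an image-batch mismatch rejoin the waiting
            // queue here, behind any sequences that were never popped this pass.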
            self.waiting.extend(for_waiting_again);

            // If we scheduled or ignored any sequences, return the prompt batch now.
            if !scheduled.is_empty() || did_ignore {
                logger.set_num_running(self.running.len());
                logger.set_num_waiting(self.waiting.len() + self.swapped_out.len());

                return PagedAttentionSchedulerOutput {
                    scheduled: scheduled.into(),
                    blocks_to_swap_in: HashMap::new(),
                    blocks_to_copy: HashMap::new(),
                    blocks_to_swap_out: HashMap::new(),
                };
            }
        }

        let mut blocks_to_swap_out = HashMap::new();
        let mut blocks_to_swap_in = HashMap::new();
        let mut blocks_to_copy = HashMap::new();

        // Reserve a token slot for each running sequence, preempting from the back of the
        // FCFS-sorted queue when space runs out. This forms a new running queue containing
        // the sequences that will actually run; preempted sequences are returned to the
        // waiting or swapped-out state depending on the preemption method (recompute or
        // swap, respectively).

        // Sort by creation time, descending, so the earliest-created sequences end up at
        // the back of the queue (first come first served).
        self.sort_running_by_priority_fcfs();

        let mut running: VecDeque<Arc<Mutex<Sequence>>> = VecDeque::new();
        let mut did_preempt = false;
        while !self.running.is_empty() {
            let seq = self.running.pop_front().unwrap();
            let mut finished_with_break = false;
            while !get_mut_arcmutex!(self.block_engine)
                .can_append_token_to_seq(&*get_mut_arcmutex!(seq))
            {
                // We cannot append a token, so we need to preempt some seqs.
                if !self.running.is_empty() {
                    // There is something to preempt.
                    let seq_to_preempt = self.running.pop_back().unwrap();
                    self._preempt(seq_to_preempt, &mut blocks_to_swap_out);
                    did_preempt = true;
                } else {
                    // Nothing else to preempt, so preempt this sequence itself and stop looking.
                    self._preempt(seq.clone(), &mut blocks_to_swap_out);
                    did_preempt = true;
                    finished_with_break = true;
                    break;
                }
            }
            if !finished_with_break {
                {
                    // If we need to, append physical blocks for a new token. We do not need to
                    // if there is already enough space. If we just got preempted, there is no
                    // reason to allocate.
                    let seq_handle = get_mut_arcmutex!(seq);
                    self._append_token_slot_to_seq(&seq_handle, &mut blocks_to_copy);
                }
                let new_seq_has_images = get_mut_arcmutex!(seq).has_images();
                // Only batch this sequence if its image status matches the new running queue so far.
                if running.is_empty()
                    || get_mut_arcmutex!(running[0]).has_images() == new_seq_has_images
                {
                    running.push_back(seq);
                } else {
                    self.running.push_back(seq);
                }
            }
        }
        self.running = running;

        // Try to swap in the swapped-out sequences and add them back to the
        // running state, if possible.

        // Sort by creation time, descending, so the earliest-created sequences end up at
        // the back of the queue (first come first served).
        self.sort_swapped_out_by_priority_fcfs();

        if !did_preempt {
            while !self.swapped_out.is_empty() {
                let seq = self.swapped_out.front().unwrap();

                // If the GPU cannot handle the sequence being swapped in, stop.
                if !get_mut_arcmutex!(self.block_engine).can_swap_in_seq(&*get_mut_arcmutex!(seq)) {
                    break;
                }

                let seq = self.swapped_out.pop_front().unwrap();
                // Swap in the blocks.
                let to_swap_in =
                    get_mut_arcmutex!(self.block_engine).swap_in(&*get_mut_arcmutex!(seq));
                blocks_to_swap_in.extend(to_swap_in);
                {
                    // Reserve a new slot.
                    let seq_handle = get_mut_arcmutex!(seq);
                    self._append_token_slot_to_seq(&seq_handle, &mut blocks_to_copy);
                }
                self.running.push_back(seq);
            }
        }

        self.running
            .iter()
            .for_each(|seq| get_mut_arcmutex!(seq).set_state(SequenceState::RunningCompletion));

        if TERMINATE_ALL_NEXT_STEP.load(Ordering::SeqCst) {
            self.running.iter().for_each(|seq| {
                get_mut_arcmutex!(seq).set_state(SequenceState::Done(StopReason::Canceled))
            });
            TERMINATE_ALL_NEXT_STEP.store(false, Ordering::SeqCst);
        }

        logger.set_num_running(self.running.len());
        logger.set_num_waiting(self.waiting.len() + self.swapped_out.len());

        PagedAttentionSchedulerOutput {
            scheduled: self.running.clone().into(), // Clone should be cheap.
            blocks_to_swap_in,
            blocks_to_copy,
            blocks_to_swap_out,
        }
    }

    pub fn free_finished_sequence_groups(&mut self) {
        let mut to_free_ids = Vec::new();
        self.running.retain(|seq| {
            if get_mut_arcmutex!(seq).is_finished_paged_attn() {
                to_free_ids.push(get_mut_arcmutex!(seq).get_id());
                false
            } else {
                true
            }
        });

        for id in to_free_ids {
            self._free(id);
        }
    }
}

impl PagedAttentionScheduler {
    #[allow(dead_code)]
    fn remove_seq(&mut self, seq_id: usize) -> Arc<Mutex<Sequence>> {
        // Remove it if it is in waiting.
        if let Some(idx) = self
            .waiting
            .iter()
            .position(|other| get_mut_arcmutex!(other).get_id() == seq_id)
        {
            return self.waiting.remove(idx).unwrap();
        };
        // Remove it if it is in running.
        if let Some(idx) = self
            .running
            .iter()
            .position(|other| get_mut_arcmutex!(other).get_id() == seq_id)
        {
            return self.running.remove(idx).unwrap();
        };
        // Remove it if it is in swapped out.
        if let Some(idx) = self
            .swapped_out
            .iter()
            .position(|other| get_mut_arcmutex!(other).get_id() == seq_id)
        {
            return self.swapped_out.remove(idx).unwrap();
        };
        panic!("Attempted to remove sequence id {seq_id} but it is not running, waiting, or swapped out.");
    }

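    /// Reserve a slot for the sequence's next token. If the block engine
    /// returns a `(src, dst)` pair, block `src` must be copied to `dst` by the
    /// CacheEngine (copy-on-write) before the new token is written.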
    fn _append_token_slot_to_seq(
        &mut self,
        seq: &Sequence,
        blocks_to_copy: &mut HashMap<usize, Vec<usize>>,
    ) {
        let op = get_mut_arcmutex!(self.block_engine).append_token_slot_to_seq(seq);
        if let Some((src_block, dst_block)) = op {
            blocks_to_copy.entry(src_block).or_default().push(dst_block);
        }
    }

    fn _abort_seq(&mut self, seq_id: usize) {
        let removed = self.remove_seq(seq_id);
        get_mut_arcmutex!(removed).set_state(SequenceState::FinishedAborted);
        self._free(seq_id);
    }

    /// Preempt a sequence. Currently this always preempts by recomputation;
    /// `_preempt_by_swap` exists but is unused.
    fn _preempt(
        &mut self,
        seq: Arc<Mutex<Sequence>>,
        _blocks_to_swap_out: &mut HashMap<usize, usize>,
    ) {
        self._preempt_by_recompute(seq)
    }

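    /// Preempt by recomputation: free the sequence's KV blocks and push it to
    /// the front of the waiting queue, so it is rescheduled first and rebuilds
    /// its KV cache from the prompt.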
    fn _preempt_by_recompute(&mut self, seq: Arc<Mutex<Sequence>>) {
        get_mut_arcmutex!(seq).set_state(SequenceState::Waiting);
        self._free(get_mut_arcmutex!(seq).get_id());
        self.waiting.push_front(seq);
    }

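    /// Preempt by swapping: move the sequence's KV blocks to CPU memory, or
    /// abort the sequence if it cannot be swapped out. Currently unused, since
    /// `_preempt` always preempts by recomputation.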
    fn _preempt_by_swap(
        &mut self,
        seq: Arc<Mutex<Sequence>>,
        blocks_to_swap_out: &mut HashMap<usize, usize>,
    ) {
        if !get_mut_arcmutex!(self.block_engine).can_swap_out_seq(&*get_mut_arcmutex!(seq)) {
            // If we cannot swap it out, abort the sequence group.
            let id = get_mut_arcmutex!(seq).get_id();
            self._abort_seq(id);
            return;
        }
        let new_to_swap = get_mut_arcmutex!(self.block_engine).swap_out(&*get_mut_arcmutex!(seq));
        blocks_to_swap_out.extend(new_to_swap);
        get_mut_arcmutex!(seq).set_state(SequenceState::Swapped);

        self.swapped_out.push_back(seq);
    }

    fn _allocate(&mut self, seq: &mut Sequence) {
        get_mut_arcmutex!(self.block_engine).allocate(seq)
    }

    fn _free(&mut self, seq_id: usize) {
        get_mut_arcmutex!(self.block_engine).free_sequence(seq_id);
    }

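    // The two FCFS sorts below order the queues by creation timestamp, ascending,
    // and then reverse, leaving the earliest-created sequences at the back of
    // each queue (first come first served).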
    fn sort_running_by_priority_fcfs(&mut self) {
        self.running
            .make_contiguous()
            .sort_by_key(|seq| get_mut_arcmutex!(seq).timestamp());
        self.running.make_contiguous().reverse();
    }

    fn sort_swapped_out_by_priority_fcfs(&mut self) {
        self.swapped_out
            .make_contiguous()
            .sort_by_key(|seq| get_mut_arcmutex!(seq).timestamp());
        self.swapped_out.make_contiguous().reverse();
    }
}

impl Scheduler for PagedAttentionScheduler {
    fn add_seq(&mut self, seq: Sequence) {
        self.waiting.push_back(Arc::new(Mutex::new(seq)));
    }
    fn schedule(&mut self, logger: &IntervalLogger) -> SchedulerOutput<'_> {
        SchedulerOutput::PagedAttention {
            output: self.schedule(logger),
        }
    }
    fn waiting_len(&self) -> usize {
        self.waiting.len() + self.swapped_out.len()
    }
    fn running_len(&self) -> usize {
        self.running.len()
    }
    fn block_tables(&self) -> Option<BlockTables> {
        Some(get_mut_arcmutex!(self.block_engine).block_tables.clone())
    }
    fn block_size(&self) -> Option<usize> {
        Some(self.block_size)
    }
    fn free_finished_sequence_groups(&mut self) {
        self.free_finished_sequence_groups()
    }
    fn block_engine(&self) -> Option<Arc<tokio::sync::Mutex<BlockEngine>>> {
        Some(self.block_engine.clone())
    }
}