mistralrs_core/dummy_paged_attention/scheduler.rs

//! The Scheduler uses a BlockEngine to schedule and automatically batch sequences. The
//! primary method `schedule` returns the batched sequences as inputs, as well as the
//! operations to be executed on the cache by the CacheEngine.
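//!
//! `schedule` works in phases: if no sequences are swapped out, it first admits sequences
//! from the waiting queue (for a prompt step) while the block engine can allocate them and
//! `max_num_seqs` allows; if that yields nothing, it reserves a token slot for every running
//! sequence (for a completion step), preempting sequences by recomputation when the block
//! engine runs out of space, and swaps previously swapped-out sequences back in when possible.
//!
//! A rough sketch of a driver loop (the `make_sequence()` helper is hypothetical, and
//! `config`/`cache_config` are assumed to be constructed elsewhere in the crate):
//!
//! ```ignore
//! let mut scheduler = PagedAttentionScheduler::new(config, cache_config);
//! scheduler.add_seq(make_sequence());
//! loop {
//!     let out = scheduler.schedule();
//!     // Apply out.blocks_to_swap_in / out.blocks_to_swap_out / out.blocks_to_copy to the
//!     // cache, then run the model on `out.scheduled`.
//!     scheduler.free_finished_sequence_groups();
//!     if scheduler.running_len() == 0 && scheduler.waiting_len() == 0 {
//!         break;
//!     }
//! }
//! ```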

type CPUBlockFrom = usize;
type GPUBlockFrom = usize;
type CPUBlockTo = usize;
type GPUBlockTo = usize;
type SrcBlockFrom = usize;
type DstBlocksTo = Vec<usize>;

use std::{
    collections::{HashMap, VecDeque},
    sync::{atomic::Ordering, Arc, Mutex},
};

use tracing::warn;

use crate::{
    get_mut_arcmutex,
    paged_attention::BlockEngine,
    scheduler::{Scheduler, SchedulerOutput},
    sequence::{Sequence, SequenceState, StopReason},
    TERMINATE_ALL_NEXT_STEP,
};

use super::{block_engine::AllocStatus, BlockEngineSequence, BlockTables, CacheConfig};

pub struct PagedAttentionSchedulerOutput {
    /// Either ALL prompt or ALL completion.
    pub scheduled: Vec<Arc<Mutex<Sequence>>>,
    /// Source CPU block -> destination GPU block, for sequences being swapped in.
    pub blocks_to_swap_in: HashMap<CPUBlockFrom, GPUBlockTo>,
    /// Source GPU block -> destination CPU block, for sequences being swapped out.
    pub blocks_to_swap_out: HashMap<GPUBlockFrom, CPUBlockTo>,
    /// Source block -> destination blocks, for copies requested by the block engine.
    pub blocks_to_copy: HashMap<SrcBlockFrom, DstBlocksTo>,
}

pub struct PagedAttentionSchedulerConfig {
    /// Maximum number of sequences that may be running at once.
    pub max_num_seqs: usize,
}

pub struct PagedAttentionScheduler {
    /// Sequences that have not yet been allocated in the block engine.
    waiting: VecDeque<Arc<Mutex<Sequence>>>,
    /// Sequences currently scheduled to run.
    running: VecDeque<Arc<Mutex<Sequence>>>,
    /// Sequences preempted by swapping their blocks out to the CPU.
    swapped_out: VecDeque<Arc<Mutex<Sequence>>>,
    config: PagedAttentionSchedulerConfig,
    pub block_engine: BlockEngine,
    block_size: usize,
}

impl PagedAttentionScheduler {
    pub fn new(config: PagedAttentionSchedulerConfig, cache_config: CacheConfig) -> Self {
        Self {
            waiting: VecDeque::new(),
            running: VecDeque::new(),
            swapped_out: VecDeque::new(),
            config,
            block_engine: BlockEngine::new(
                cache_config.block_size,
                cache_config.num_gpu_blocks,
                cache_config.num_cpu_blocks,
            ),
            block_size: cache_config.block_size,
        }
    }

    pub fn schedule(&mut self) -> PagedAttentionSchedulerOutput {
        // If there are no swapped seqs (they have higher priority), add seqs that are in the
        // waiting queue to the running queue.
        if self.swapped_out.is_empty() {
            let mut scheduled: VecDeque<Arc<Mutex<Sequence>>> = VecDeque::new();
            let mut for_waiting_again: VecDeque<Arc<Mutex<Sequence>>> = VecDeque::new();
            let mut did_ignore = false;
            while !self.waiting.is_empty() {
                let seq = self.waiting.front().unwrap().clone();

                // If adding this seq means we will have too many, stop, as no more could be added.
                if self.config.max_num_seqs == self.running.len() + 1 {
                    break;
                }

                // If the sequence cannot be allocated now or in the future, either stop
                // scheduling (Later) or mark the sequence as ignored (Impossible).
                let can_allocate = self.block_engine.can_allocate(&*get_mut_arcmutex!(seq));
                match can_allocate {
                    AllocStatus::Later => break, // If we can only allocate later, do not bother iterating over the rest.
                    AllocStatus::Impossible => {
                        let id = *get_mut_arcmutex!(seq).id();
                        let len = get_mut_arcmutex!(seq).get_toks().len();
                        warn!(
                            "Sequence {id} with length of {len} tokens is too long and exceeds the capacity of the block engine. Sequence will be ignored.",
                        );
                        get_mut_arcmutex!(seq).set_state(SequenceState::FinishedIgnored);
                        did_ignore = true;
                    }
                    _ => {}
                }

                let new_seq_has_images = get_mut_arcmutex!(seq).has_images();
                // Only batch this sequence if its image status matches the sequences already
                // scheduled this step (or nothing has been scheduled yet).
                if !scheduled.is_empty()
                    && get_mut_arcmutex!(scheduled[0]).has_images() != new_seq_has_images
                {
                    let seq = self.waiting.pop_front().unwrap();
                    for_waiting_again.push_back(seq.clone());
                    continue;
                }
                if !did_ignore {
                    get_mut_arcmutex!(seq).set_state(SequenceState::RunningPrompt);
                    let seq_handle = get_mut_arcmutex!(seq);
                    self._allocate(&seq_handle);
                }

                let seq = self.waiting.pop_front().unwrap();
                self.running.push_back(seq.clone());
                if !did_ignore {
                    scheduled.push_back(seq);
                }
            }
            self.waiting.extend(for_waiting_again);

            // If we scheduled anything, or ignored any sequences, return now.
            if !scheduled.is_empty() || did_ignore {
                return PagedAttentionSchedulerOutput {
                    scheduled: scheduled.into(),
                    blocks_to_swap_in: HashMap::new(),
                    blocks_to_copy: HashMap::new(),
                    blocks_to_swap_out: HashMap::new(),
                };
            }
        }

        let mut blocks_to_swap_out = HashMap::new();
        let mut blocks_to_swap_in = HashMap::new();
        let mut blocks_to_copy = HashMap::new();

        // Reserve token slots for the running sequences, preempting the lowest priority
        // (earliest) ones first. Preempted sequences are remembered and put into the waiting
        // or swapped-out state depending on the preemption method (recompute or swap,
        // respectively); the sequences that actually run form a new running queue.

        // Sort by creation time, in descending order, so that the earliest sequences end up
        // last (first come, first served).
        self.sort_running_by_priority_fcfs();

        let mut running: VecDeque<Arc<Mutex<Sequence>>> = VecDeque::new();
        let mut did_preempt = false;
        while !self.running.is_empty() {
            let seq = self.running.pop_front().unwrap();
            let mut finished_with_break = false;
            while !self
                .block_engine
                .can_append_token_to_seq(&*get_mut_arcmutex!(seq))
            {
                // If we cannot append, we need to preempt some sequences.
                if !self.running.is_empty() {
                    // There is something to preempt.
                    let seq_to_preempt = self.running.pop_back().unwrap();
                    self._preempt(seq_to_preempt, &mut blocks_to_swap_out);
                    did_preempt = true;
                } else {
                    // Nothing to preempt, so preempt ourselves and do not bother looking at
                    // anything else.
                    self._preempt(seq.clone(), &mut blocks_to_swap_out);
                    did_preempt = true;
                    finished_with_break = true;
                    break;
                }
            }
            if !finished_with_break {
                {
                    // If we need to, append physical blocks for a new token. We do not need to
                    // if there is enough space. If we just got preempted, there is no reason to
                    // allocate.
                    let seq_handle = get_mut_arcmutex!(seq);
                    self._append_token_slot_to_seq(&seq_handle, &mut blocks_to_copy);
                }
                let new_seq_has_images = get_mut_arcmutex!(seq).has_images();
                // Only batch this sequence if its image status matches the sequences already
                // kept for this step (or nothing has been kept yet).
                if running.is_empty()
                    || get_mut_arcmutex!(running[0]).has_images() == new_seq_has_images
                {
                    running.push_back(seq);
                } else {
                    self.running.push_back(seq);
                }
            }
        }
        self.running = running;

        // Try to swap in the swapped-out sequences and add them to the running state if
        // possible.

        // Sort by creation time, in descending order, so that the earliest sequences end up
        // last (first come, first served).
        self.sort_swapped_out_by_priority_fcfs();

        if !did_preempt {
            while !self.swapped_out.is_empty() {
                let seq = self.swapped_out.front().unwrap();

                // If the GPU cannot handle the sequence being swapped in, stop.
                if !self.block_engine.can_swap_in_seq(&*get_mut_arcmutex!(seq)) {
                    break;
                }

                let seq = self.swapped_out.pop_front().unwrap();
                // Swap in the blocks.
                let to_swap_in = self.block_engine.swap_in(&*get_mut_arcmutex!(seq));
                blocks_to_swap_in.extend(to_swap_in);
                {
                    // Reserve a new slot.
                    let seq_handle = get_mut_arcmutex!(seq);
                    self._append_token_slot_to_seq(&seq_handle, &mut blocks_to_copy);
                }
                self.running.push_back(seq);
            }
        }

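        // Everything still in the running queue will run a completion step this iteration.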
        self.running
            .iter()
            .for_each(|seq| get_mut_arcmutex!(seq).set_state(SequenceState::RunningCompletion));

        if TERMINATE_ALL_NEXT_STEP.load(Ordering::SeqCst) {
            self.running.iter().for_each(|seq| {
                get_mut_arcmutex!(seq).set_state(SequenceState::Done(StopReason::Canceled))
            });
            TERMINATE_ALL_NEXT_STEP.store(false, Ordering::SeqCst);
        }

        PagedAttentionSchedulerOutput {
            scheduled: self.running.clone().into(), // Clone should be cheap.
            blocks_to_swap_in,
            blocks_to_copy,
            blocks_to_swap_out,
        }
    }

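    /// Remove finished sequences from the running queue and free their cache blocks.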
    pub fn free_finished_sequence_groups(&mut self) {
        let mut to_free_ids = Vec::new();
        self.running.retain(|seq| {
            if get_mut_arcmutex!(seq).is_finished_paged_attn() {
                to_free_ids.push(get_mut_arcmutex!(seq).get_id());
                false
            } else {
                true
            }
        });

        for id in to_free_ids {
            self._free(id);
        }
    }
}

impl PagedAttentionScheduler {
    #[allow(dead_code)]
    fn remove_seq(&mut self, seq_id: usize) -> Arc<Mutex<Sequence>> {
        // Remove it if it is in waiting
        if let Some(idx) = self
            .waiting
            .iter()
            .position(|other| get_mut_arcmutex!(other).get_id() == seq_id)
        {
            return self.waiting.remove(idx).unwrap();
        };
        // Remove it if it is in running
        if let Some(idx) = self
            .running
            .iter()
            .position(|other| get_mut_arcmutex!(other).get_id() == seq_id)
        {
            return self.running.remove(idx).unwrap();
        };
        // Remove it if it is in swapped out
        if let Some(idx) = self
            .swapped_out
            .iter()
            .position(|other| get_mut_arcmutex!(other).get_id() == seq_id)
        {
            return self.swapped_out.remove(idx).unwrap();
        };
        panic!("Attempted to remove sequence id {seq_id} but it is not running, waiting, or swapped out.");
    }

    fn _append_token_slot_to_seq(
        &mut self,
        seq: &Sequence,
        blocks_to_copy: &mut HashMap<usize, Vec<usize>>,
    ) {
        let op = self.block_engine.append_token_slot_to_seq(seq);
        // If appending required a block copy, record the src -> dst mapping for the cache
        // engine to perform.
        if let Some((src_block, dst_block)) = op {
            blocks_to_copy.entry(src_block).or_default().push(dst_block);
        }
    }

    fn _abort_seq(&mut self, seq_id: usize) {
        let removed = self.remove_seq(seq_id);
        get_mut_arcmutex!(removed).set_state(SequenceState::FinishedAborted);
        self._free(seq_id);
    }

    /// Preempt a sequence: in principle either by recomputation (for a single sequence) or by
    /// swapping (for multiple). This implementation always preempts by recomputation.
    fn _preempt(
        &mut self,
        seq: Arc<Mutex<Sequence>>,
        _blocks_to_swap_out: &mut HashMap<usize, usize>,
    ) {
        self._preempt_by_recompute(seq)
    }

    fn _preempt_by_recompute(&mut self, seq: Arc<Mutex<Sequence>>) {
        get_mut_arcmutex!(seq).set_state(SequenceState::Waiting);
        self._free(get_mut_arcmutex!(seq).get_id());
        self.waiting.push_front(seq);
    }

    #[allow(dead_code)]
    fn _preempt_by_swap(
        &mut self,
        seq: Arc<Mutex<Sequence>>,
        blocks_to_swap_out: &mut HashMap<usize, usize>,
    ) {
        if !self.block_engine.can_swap_out_seq(&*get_mut_arcmutex!(seq)) {
            // If we cannot swap it out, abort the sequence.
            let id = get_mut_arcmutex!(seq).get_id();
            self._abort_seq(id);
            return;
        }
        let new_to_swap = self.block_engine.swap_out(&*get_mut_arcmutex!(seq));
        blocks_to_swap_out.extend(new_to_swap);
        get_mut_arcmutex!(seq).set_state(SequenceState::Swapped);

        self.swapped_out.push_back(seq);
    }

    fn _allocate(&mut self, seq: &Sequence) {
        self.block_engine.allocate(seq)
    }

    fn _free(&mut self, seq_id: usize) {
        self.block_engine.free_sequence(seq_id);
    }

    fn sort_running_by_priority_fcfs(&mut self) {
        self.running
            .make_contiguous()
            .sort_by_key(|seq| get_mut_arcmutex!(seq).timestamp());
        self.running.make_contiguous().reverse();
    }

    fn sort_swapped_out_by_priority_fcfs(&mut self) {
        self.swapped_out
            .make_contiguous()
            .sort_by_key(|seq| get_mut_arcmutex!(seq).timestamp());
        self.swapped_out.make_contiguous().reverse();
    }
}

impl Scheduler for PagedAttentionScheduler {
    fn add_seq(&mut self, seq: Sequence) {
        self.waiting.push_back(Arc::new(Mutex::new(seq)));
    }
    fn schedule(&mut self) -> SchedulerOutput<'_> {
        SchedulerOutput::PagedAttention {
            output: self.schedule(),
        }
    }
    fn waiting_len(&self) -> usize {
        self.waiting.len() + self.swapped_out.len()
    }
    fn running_len(&self) -> usize {
        self.running.len()
    }
    fn block_tables(&self) -> Option<&BlockTables> {
        Some(&self.block_engine.block_tables)
    }
    fn block_size(&self) -> Option<usize> {
        Some(self.block_size)
    }
    fn free_finished_sequence_groups(&mut self) {
        self.free_finished_sequence_groups()
    }
    fn block_engine(&mut self) -> Option<&mut BlockEngine> {
        Some(&mut self.block_engine)
    }
}