mistralrs_core/dummy_paged_attention/scheduler.rs

1//! The Scheduler uses a BlockEngine to schedule and automatically batch sequences. The
2//! primary method `schedule` returns the batched sequences as inputs, as well as the
3//! operations to be executed on the cache by the CacheEngine.
4
// Block-number aliases used by `PagedAttentionSchedulerOutput`, grouped by operation.

// Swap in: CPU block -> GPU block.
/// CPU block number that a swap-in reads from.
type CPUBlockFrom = usize;
/// GPU block number that a swap-in writes to.
type GPUBlockTo = usize;

// Swap out: GPU block -> CPU block.
/// GPU block number that a swap-out reads from.
type GPUBlockFrom = usize;
/// CPU block number that a swap-out writes to.
type CPUBlockTo = usize;

// Copy: one source block -> several destination blocks.
/// Block number that a copy reads from.
type SrcBlockFrom = usize;
/// Block numbers that a copy writes to.
type DstBlocksTo = Vec<usize>;
11
12use std::{
13    collections::{HashMap, VecDeque},
14    sync::{atomic::Ordering, Arc, Mutex},
15};
16
17use tracing::warn;
18
19use crate::{
20    get_mut_arcmutex,
21    paged_attention::BlockEngine,
22    scheduler::{Scheduler, SchedulerOutput},
23    sequence::{Sequence, SequenceState, StopReason},
24    TERMINATE_ALL_NEXT_STEP,
25};
26
27use super::{block_engine::AllocStatus, BlockEngineSequence, BlockTables, CacheConfig};
28
29pub struct PagedAttentionSchedulerOutput {
30    /// Either ALL prompt or ALL completion.
31    pub scheduled: Vec<Arc<Mutex<Sequence>>>,
32    pub blocks_to_swap_in: HashMap<CPUBlockFrom, GPUBlockTo>,
33    pub blocks_to_swap_out: HashMap<GPUBlockFrom, CPUBlockTo>,
34    pub blocks_to_copy: HashMap<SrcBlockFrom, DstBlocksTo>,
35}
36
/// Tunable limits for `PagedAttentionScheduler`.
pub struct PagedAttentionSchedulerConfig {
    /// Size bound on the running queue.
    ///
    /// The scheduler consults it when deciding whether to admit more sequences from the
    /// waiting queue.
    pub max_num_seqs: usize,
}
40
41pub struct PagedAttentionScheduler {
42    waiting: VecDeque<Arc<Mutex<Sequence>>>,
43    running: VecDeque<Arc<Mutex<Sequence>>>,
44    swapped_out: VecDeque<Arc<Mutex<Sequence>>>,
45    config: PagedAttentionSchedulerConfig,
46    pub block_engine: BlockEngine,
47    block_size: usize,
48}
49
50impl PagedAttentionScheduler {
51    pub fn new(config: PagedAttentionSchedulerConfig, cache_config: CacheConfig) -> Self {
52        Self {
53            waiting: VecDeque::new(),
54            running: VecDeque::new(),
55            swapped_out: VecDeque::new(),
56            config,
57            block_engine: BlockEngine::new(
58                cache_config.block_size,
59                cache_config.num_gpu_blocks,
60                cache_config.num_cpu_blocks,
61            ),
62            block_size: cache_config.block_size,
63        }
64    }
65
66    pub fn schedule(&mut self) -> PagedAttentionSchedulerOutput {
67        // If there are no swapped seqs (they have higher priority), add seqs that are in the
68        // waiting queue to the running queue.
69        if self.swapped_out.is_empty() {
70            let mut scheduled = VecDeque::new();
71            let mut did_ignore = false;
72            while !self.waiting.is_empty() {
73                let seq = self.waiting.front().unwrap().clone();
74
75                // If adding this seq means we will have too many, stop as no more could be added.
76                if self.config.max_num_seqs == self.running.len() + 1 {
77                    break;
78                }
79
80                // If we cannot allocate either now or in the future, either do not continue or remove the sequence.
81                let can_allocate = self.block_engine.can_allocate(&*get_mut_arcmutex!(seq));
82                match can_allocate {
83                    AllocStatus::Later => break, // If we can only allocate later, do not bother iterating over the rest.
84                    AllocStatus::Impossible => {
85                        let id = *get_mut_arcmutex!(seq).id();
86                        let len = get_mut_arcmutex!(seq).get_toks().len();
87                        warn!(
88                            "Sequence {id} with length of {len} tokens is too long and exceeds capacity of block engine. Sequence will be ignored.",
89                        );
90                        get_mut_arcmutex!(seq).set_state(SequenceState::FinishedIgnored);
91                        did_ignore = true;
92                    }
93                    _ => {}
94                }
95
96                if !did_ignore {
97                    get_mut_arcmutex!(seq).set_state(SequenceState::RunningPrompt);
98                    let seq_handle = get_mut_arcmutex!(seq);
99                    self._allocate(&seq_handle);
100                }
101
102                let seq = self.waiting.pop_front().unwrap();
103                self.running.push_back(seq.clone());
104                if !did_ignore {
105                    scheduled.push_back(seq);
106                }
107            }
108
109            // If we did schedule, or we ignored sequences.
110            if !scheduled.is_empty() || did_ignore {
111                return PagedAttentionSchedulerOutput {
112                    scheduled: scheduled.into(),
113                    blocks_to_swap_in: HashMap::new(),
114                    blocks_to_copy: HashMap::new(),
115                    blocks_to_swap_out: HashMap::new(),
116                };
117            }
118        }
119
120        let mut blocks_to_swap_out = HashMap::new();
121        let mut blocks_to_swap_in = HashMap::new();
122        let mut blocks_to_copy = HashMap::new();
123
124        // Reserve token slots for the running sequence groups, preempting the lowest (earliest) first.
125        // Preempt lowest priority sequences that are in the running queue, forming a
126        // new running queue that has the actually running sequences. Remember the preempted
127        // sequences, which will be put into the waiting or swapped out state depending on
128        // the preemption method (recompute or swap, respectively).
129
130        // Sorts by creation time, in descending order so that earliest are latest (first come first serve).
131        self.sort_running_by_priority_fcfs();
132
133        let mut running = VecDeque::new();
134        let mut did_preempt = false;
135        while !self.running.is_empty() {
136            let seq = self.running.pop_front().unwrap();
137            let mut finished_with_break = false;
138            while !self
139                .block_engine
140                .can_append_token_to_seq(&*get_mut_arcmutex!(seq))
141            {
142                // If we cannot, now we need to preempt some seqs
143                if !self.running.is_empty() {
144                    // There is something to preempt.
145                    let seq_to_preempt = self.running.pop_back().unwrap();
146                    self._preempt(seq_to_preempt, &mut blocks_to_swap_out);
147                    did_preempt = true;
148                } else {
149                    // Nothing to preempt, preempt ourselves. Also, do not bother looking at anything else.
150                    self._preempt(seq.clone(), &mut blocks_to_swap_out);
151                    did_preempt = true;
152                    finished_with_break = true;
153                    break;
154                }
155            }
156            if !finished_with_break {
157                {
158                    // If we need to, append physical blocks for a new token. We do not need to if there is enough space.
159                    // If we just got preempted, there is no reason to allocate
160                    let seq_handle = get_mut_arcmutex!(seq);
161                    self._append_token_slot_to_seq(&seq_handle, &mut blocks_to_copy);
162                }
163                running.push_back(seq);
164            }
165        }
166        self.running = running;
167
168        // Try to swap in the swapped out sequences and add these to the
169        // running state if possible.
170
171        // Sorts by creation time, in descending order so that earliest are latest (first come first serve).
172        self.sort_swapped_out_by_priority_fcfs();
173
174        if !did_preempt {
175            while !self.swapped_out.is_empty() {
176                let seq = self.swapped_out.front().unwrap();
177
178                // If the GPU cannot handle the group being swapped in, stop
179                if !self.block_engine.can_swap_in_seq(&*get_mut_arcmutex!(seq)) {
180                    break;
181                }
182
183                let seq = self.swapped_out.pop_front().unwrap();
184                // Swap in the blocks
185                let to_swap_in = self.block_engine.swap_in(&*get_mut_arcmutex!(seq));
186                blocks_to_swap_in.extend(to_swap_in);
187                {
188                    // Reserve a new slot
189                    let seq_handle = get_mut_arcmutex!(seq);
190                    self._append_token_slot_to_seq(&seq_handle, &mut blocks_to_copy);
191                }
192                self.running.push_back(seq);
193            }
194        }
195
196        self.running
197            .iter()
198            .for_each(|seq| get_mut_arcmutex!(seq).set_state(SequenceState::RunningCompletion));
199
200        if TERMINATE_ALL_NEXT_STEP.load(Ordering::SeqCst) {
201            self.running.iter().for_each(|seq| {
202                get_mut_arcmutex!(seq).set_state(SequenceState::Done(StopReason::Canceled))
203            });
204            TERMINATE_ALL_NEXT_STEP.store(false, Ordering::SeqCst);
205        }
206
207        PagedAttentionSchedulerOutput {
208            scheduled: self.running.clone().into(), // Clone should be cheap.
209            blocks_to_swap_in,
210            blocks_to_copy,
211            blocks_to_swap_out,
212        }
213    }
214
215    pub fn free_finished_sequence_groups(&mut self) {
216        let mut to_free_ids = Vec::new();
217        self.running.retain(|seq| {
218            if get_mut_arcmutex!(seq).is_finished_paged_attn() {
219                to_free_ids.push(get_mut_arcmutex!(seq).get_id());
220                false
221            } else {
222                true
223            }
224        });
225        for id in to_free_ids {
226            self._free(id);
227        }
228    }
229}
230
231impl PagedAttentionScheduler {
232    #[allow(dead_code)]
233    fn remove_seq(&mut self, seq_id: usize) -> Arc<Mutex<Sequence>> {
234        // Remove it if it is in waiting
235        if let Some(idx) = self
236            .waiting
237            .iter()
238            .position(|other| get_mut_arcmutex!(other).get_id() == seq_id)
239        {
240            return self.waiting.remove(idx).unwrap();
241        };
242        // Remove it if it is in running
243        if let Some(idx) = self
244            .running
245            .iter()
246            .position(|other| get_mut_arcmutex!(other).get_id() == seq_id)
247        {
248            return self.running.remove(idx).unwrap();
249        };
250        // Remove it if it is in swapped out
251        if let Some(idx) = self
252            .swapped_out
253            .iter()
254            .position(|other| get_mut_arcmutex!(other).get_id() == seq_id)
255        {
256            return self.swapped_out.remove(idx).unwrap();
257        };
258        panic!("Attempted to remove sequence id {seq_id} but it is not running, waiting, or swapped out.");
259    }
260
261    fn _append_token_slot_to_seq(
262        &mut self,
263        seq: &Sequence,
264        blocks_to_copy: &mut HashMap<usize, Vec<usize>>,
265    ) {
266        let op = self.block_engine.append_token_slot_to_seq(seq);
267        if let Some((src_block, dst_block)) = op {
268            if let std::collections::hash_map::Entry::Vacant(e) = blocks_to_copy.entry(src_block) {
269                e.insert(vec![dst_block]);
270            } else {
271                blocks_to_copy.get_mut(&src_block).unwrap().push(dst_block);
272            }
273        }
274    }
275
276    fn _abort_seq(&mut self, seq_id: usize) {
277        let removed = self.remove_seq(seq_id);
278        get_mut_arcmutex!(removed).set_state(SequenceState::FinishedAborted);
279        self._free(seq_id);
280    }
281
282    /// Preempt either by recomputation (for single sequence), or by swapping (for multiple).
283    fn _preempt(
284        &mut self,
285        seq: Arc<Mutex<Sequence>>,
286        _blocks_to_swap_out: &mut HashMap<usize, usize>,
287    ) {
288        self._preempt_by_recompute(seq)
289    }
290
291    fn _preempt_by_recompute(&mut self, seq: Arc<Mutex<Sequence>>) {
292        get_mut_arcmutex!(seq).set_state(SequenceState::Waiting);
293        self._free(get_mut_arcmutex!(seq).get_id());
294        self.waiting.push_front(seq);
295    }
296
297    fn _preempt_by_swap(
298        &mut self,
299        seq: Arc<Mutex<Sequence>>,
300        blocks_to_swap_out: &mut HashMap<usize, usize>,
301    ) {
302        if !self.block_engine.can_swap_out_seq(&*get_mut_arcmutex!(seq)) {
303            // If we cannot swap it out, abort the sequence group.
304            let id = get_mut_arcmutex!(seq).get_id();
305            self._abort_seq(id);
306            return;
307        }
308        let new_to_swap = self.block_engine.swap_out(&*get_mut_arcmutex!(seq));
309        blocks_to_swap_out.extend(new_to_swap);
310        get_mut_arcmutex!(seq).set_state(SequenceState::Swapped);
311
312        self.swapped_out.push_back(seq);
313    }
314
315    fn _allocate(&mut self, seq: &Sequence) {
316        self.block_engine.allocate(seq)
317    }
318
319    fn _free(&mut self, seq_id: usize) {
320        self.block_engine.free_sequence(seq_id);
321    }
322
323    fn sort_running_by_priority_fcfs(&mut self) {
324        self.running
325            .make_contiguous()
326            .sort_by_key(|seq| get_mut_arcmutex!(seq).timestamp());
327        self.running.make_contiguous().reverse();
328    }
329
330    fn sort_swapped_out_by_priority_fcfs(&mut self) {
331        self.swapped_out
332            .make_contiguous()
333            .sort_by_key(|seq| get_mut_arcmutex!(seq).timestamp());
334        self.swapped_out.make_contiguous().reverse();
335    }
336}
337
338impl Scheduler for PagedAttentionScheduler {
339    fn add_seq(&mut self, seq: Sequence) {
340        self.waiting.push_back(Arc::new(Mutex::new(seq)));
341    }
342    fn schedule(&mut self) -> SchedulerOutput<'_> {
343        SchedulerOutput::PagedAttention {
344            output: self.schedule(),
345        }
346    }
347    fn waiting_len(&self) -> usize {
348        self.waiting.len() + self.swapped_out.len()
349    }
350    fn running_len(&self) -> usize {
351        self.running.len()
352    }
353    fn block_tables(&self) -> Option<&BlockTables> {
354        Some(&self.block_engine.block_tables)
355    }
356    fn block_size(&self) -> Option<usize> {
357        Some(self.block_size)
358    }
359    fn free_finished_sequence_groups(&mut self) {
360        self.free_finished_sequence_groups()
361    }
362    fn block_engine(&mut self) -> Option<&mut BlockEngine> {
363        Some(&mut self.block_engine)
364    }
365}