mistralrs_core/scheduler/default_scheduler.rs

1use std::{
2    collections::{HashMap, VecDeque},
3    num::NonZeroUsize,
4    sync::{atomic::Ordering, Arc},
5};
6
7use crate::{
8    engine::{IntervalLogger, TERMINATE_ALL_NEXT_STEP},
9    paged_attention::{BlockEngine, BlockTables},
10    sequence::{Sequence, SequenceState, StopReason},
11};
12
13use super::{Scheduler, SchedulerOutput};
14
15pub trait FcfsBacker: Default {
16    fn new() -> Self;
17    fn add(&mut self, item: Sequence);
18    fn into_iter(self) -> impl Iterator<Item = Sequence>;
19    fn len(&self) -> usize;
20    fn sort_ascending_ids(&mut self);
21}
22
23impl FcfsBacker for VecDeque<Sequence> {
24    fn new() -> Self {
25        Self::new()
26    }
27    fn add(&mut self, item: Sequence) {
28        self.push_back(item)
29    }
30    fn into_iter(self) -> impl Iterator<Item = Sequence> {
31        <Self as IntoIterator>::into_iter(self)
32    }
33    fn sort_ascending_ids(&mut self) {
34        let slice = self.make_contiguous();
35        slice.sort_by_key(|seq| *seq.id());
36    }
37    fn len(&self) -> usize {
38        VecDeque::len(self)
39    }
40}
41
42pub struct DefaultSchedulerOutput<'a> {
43    pub completion: Box<[&'a mut Sequence]>,
44    pub prompt: Box<[&'a mut Sequence]>,
45}
46
/// The scheduler method controls how sequences are scheduled during each step of the
/// engine.
///
/// It is only consulted when there are both running and waiting sequences. If there are
/// only running sequences, only waiting sequences, or none at all, it is not used. When it
/// is used, it decides which waiting sequences are allowed to start running.
#[derive(Clone)]
pub enum DefaultSchedulerMethod {
    /// Admit a waiting sequence only if the running set would then hold at most this many
    /// sequences.
    Fixed(NonZeroUsize),
}
55
56pub struct BucketedSeqs<Backer: FcfsBacker> {
57    running: Vec<Sequence>,
58    waiting: Backer,
59}
60
61pub trait BucketingManager<Backer: FcfsBacker>: Send + Sync {
62    /// Bucket and waitlist running input sequences, returning the newly running sequences.
63    fn bucket_and_waitlist_seqs_waiting(
64        &mut self,
65        running: Vec<Sequence>,
66        waiting: Backer,
67        discrete: bool,
68    ) -> BucketedSeqs<Backer>;
69}
70
/// Key used to bucket running sequences: `(length, has_images && is_prompt, token_offset)`.
///
/// - The first element is `Sequence::len()`.
/// - The last element is `Sequence::token_offset()`.
/// - The middle flag is only set for sequences that have images AND are still in the
///   prompt phase. Images only need to be separated while the prompt is processed, so
///   non-prompt sequences with images share buckets with text-only ones.
type BucketKey = (usize, bool, usize);
74
/// Stateless bucketing strategy selected for [`DefaultSchedulerMethod::Fixed`].
struct FixedBucketingManager;
76
77impl<Backer: FcfsBacker> BucketingManager<Backer> for FixedBucketingManager {
78    /// Move the seuqences into buckets, and run the ones with the shortest lengths.
79    /// The others are moved to the waiting list (retaining high priority due to start time),
80    /// without a state modification.
81    fn bucket_and_waitlist_seqs_waiting(
82        &mut self,
83        running: Vec<Sequence>,
84        mut waiting: Backer,
85        discrete: bool,
86    ) -> BucketedSeqs<Backer> {
87        // Now, get the sequences with the smallest sequence lengths, and allow them to catch up.
88        let mut seq_buckets: HashMap<BucketKey, Vec<Sequence>> = HashMap::new();
89        let mut seq_priorities: HashMap<BucketKey, f64> = HashMap::new();
90        for seq in running {
91            let len = seq.len();
92            match seq_buckets.get_mut(&(
93                len,
94                seq.images().is_some() && seq.is_prompt(),
95                seq.token_offset(),
96            )) {
97                Some(bucket) => {
98                    if !discrete {
99                        *seq_priorities
100                            .get_mut(&(
101                                len,
102                                seq.images().is_some() && seq.is_prompt(),
103                                seq.token_offset(),
104                            ))
105                            .unwrap() += seq.compute_priority();
106                    }
107                    bucket.push(seq);
108                }
109                None => {
110                    if !discrete {
111                        seq_priorities.insert(
112                            (
113                                len,
114                                seq.images().is_some() && seq.is_prompt(),
115                                seq.token_offset(),
116                            ),
117                            seq.compute_priority(),
118                        );
119                    }
120                    seq_buckets.insert(
121                        (
122                            len,
123                            seq.images().is_some() && seq.is_prompt(),
124                            seq.token_offset(),
125                        ),
126                        vec![seq],
127                    );
128                }
129            }
130        }
131        let running = if seq_buckets.len() <= 1 {
132            // Full steam ahead or have everything
133            seq_buckets
134                .into_iter()
135                .flat_map(|(_, x)| x)
136                .map(|s| s.reset_urgency())
137                .collect::<Vec<_>>()
138        } else {
139            // Set the min seqs to be the running ones, and the rest to be waiting (but their states are not changed!)
140            // Allow the min seqs to catch up.
141            let min = *seq_buckets
142                .keys()
143                .min_by_key(|(x, _, _)| *x)
144                .expect("No sequence buckets.");
145            let len = if !discrete {
146                seq_priorities
147                    .iter()
148                    .max_by(|(_, a), (_, b)| a.partial_cmp(b).unwrap())
149                    .map(|(a, b)| (a, *b))
150                    .unwrap_or_else(|| (&min, seq_priorities[&min]))
151                    .0
152            } else {
153                &min
154            };
155            let highest_priority_seqs = seq_buckets
156                .remove(len)
157                .unwrap()
158                .into_iter()
159                .map(|s| s.reset_urgency())
160                .collect();
161            for (_, seqs) in seq_buckets {
162                for seq in seqs {
163                    waiting.add(seq.add_urgency());
164                }
165            }
166            // Know min_seqs.len < running.len() <= max
167            highest_priority_seqs
168        };
169        BucketedSeqs { running, waiting }
170    }
171}
172
173pub struct DefaultScheduler<Backer: FcfsBacker> {
174    waiting: Backer,
175    running: Vec<Sequence>,
176    method: DefaultSchedulerMethod,
177    bucketing_manager: Box<dyn BucketingManager<Backer>>,
178}
179
180impl<Backer: FcfsBacker> DefaultScheduler<Backer> {
181    pub fn new(method: DefaultSchedulerMethod) -> Self {
182        let bucketing_manager: Box<dyn BucketingManager<_>> = match method {
183            DefaultSchedulerMethod::Fixed(_) => Box::new(FixedBucketingManager),
184        };
185        Self {
186            running: Vec::new(),
187            waiting: Backer::new(),
188            method,
189            bucketing_manager,
190        }
191    }
192
193    /// Move the seuqences into buckets, and run the ones with the shortest lengths.
194    /// The others are moved to the waiting list (retaining high priority due to start time),
195    /// without a state modification.
196    fn bucket_and_waitlist_seqs(&mut self, running: Vec<Sequence>) -> Vec<Sequence> {
197        let waiting = std::mem::take(&mut self.waiting);
198        let BucketedSeqs { running, waiting } = self
199            .bucketing_manager
200            .bucket_and_waitlist_seqs_waiting(running, waiting, true);
201        self.waiting = waiting;
202        running
203    }
204
205    /// Schedule all sequences based on their state and the available space.
206    pub fn schedule(&mut self, logger: &IntervalLogger) -> DefaultSchedulerOutput {
207        // Filter out all done sequences
208        let running = std::mem::take(&mut self.running);
209        let mut waiting = std::mem::take(&mut self.waiting);
210        let mut running = running
211            .into_iter()
212            .filter(|seq| seq.is_running())
213            .collect::<Vec<_>>();
214
215        match (waiting.len(), running.len()) {
216            (0, 0) => {
217                self.running = running;
218                logger.set_num_running(self.running.len());
219                logger.set_num_waiting(self.waiting.len());
220                return DefaultSchedulerOutput {
221                    prompt: vec![].into(),
222                    completion: vec![].into(),
223                };
224            }
225            (_, 0) => {
226                for seq in waiting.into_iter() {
227                    seq.set_state(SequenceState::RunningPrompt);
228                    self.running.push(seq);
229                }
230                self.waiting = Backer::new();
231                let running = std::mem::take(&mut self.running);
232                self.running = self.bucket_and_waitlist_seqs(running);
233                logger.set_num_running(self.running.len());
234                logger.set_num_waiting(self.waiting.len());
235                return DefaultSchedulerOutput {
236                    prompt: self.running.iter_mut().collect::<Vec<_>>().into(),
237                    completion: vec![].into(),
238                };
239            }
240            (0, _) => {
241                self.running = self.bucket_and_waitlist_seqs(running);
242                if TERMINATE_ALL_NEXT_STEP.load(Ordering::SeqCst) {
243                    self.running
244                        .iter_mut()
245                        .for_each(|seq| seq.set_state(SequenceState::Done(StopReason::Canceled)));
246                    TERMINATE_ALL_NEXT_STEP.store(false, Ordering::SeqCst);
247                }
248                logger.set_num_running(self.running.len());
249                logger.set_num_waiting(self.waiting.len());
250                return DefaultSchedulerOutput {
251                    prompt: vec![].into(),
252                    completion: self.running.iter_mut().collect::<Vec<_>>().into(),
253                };
254            }
255            _ => {}
256        }
257
258        // Sort the waiting seqs
259        waiting.sort_ascending_ids();
260
261        // If the waiting sequence will fit, add it. Otherwise remove it
262        let mut new_waiting = Backer::new();
263        for seq in waiting.into_iter() {
264            if self.sequence_fits(&running, &seq) {
265                if seq.is_waiting() {
266                    seq.set_state(SequenceState::RunningPrompt);
267                }
268                running.push(seq);
269            } else {
270                new_waiting.add(seq);
271            }
272        }
273
274        let BucketedSeqs {
275            running,
276            waiting: new_waiting,
277        } = self
278            .bucketing_manager
279            .bucket_and_waitlist_seqs_waiting(running, new_waiting, false);
280
281        self.running = running;
282        self.waiting = new_waiting;
283
284        logger.set_num_running(self.running.len());
285        logger.set_num_waiting(self.waiting.len());
286
287        let mut completion = Vec::new();
288        let mut prompt = Vec::new();
289        for seq in &mut self.running {
290            if seq.is_completion() {
291                completion.push(seq);
292            } else {
293                prompt.push(seq);
294            }
295        }
296
297        DefaultSchedulerOutput {
298            completion: completion.into(),
299            prompt: prompt.into(),
300        }
301    }
302
303    fn sequence_fits(&self, running: &[Sequence], _seq: &Sequence) -> bool {
304        match &self.method {
305            DefaultSchedulerMethod::Fixed(n) => (running.len() + 1) <= (*n).into(),
306        }
307    }
308}
309
310impl Scheduler for DefaultScheduler<VecDeque<Sequence>> {
311    fn schedule(&mut self, logger: &IntervalLogger) -> SchedulerOutput<'_> {
312        SchedulerOutput::DefaultScheduler {
313            output: self.schedule(logger),
314        }
315    }
316    fn waiting_len(&self) -> usize {
317        self.waiting.len()
318    }
319    fn running_len(&self) -> usize {
320        self.running.len()
321    }
322    fn add_seq(&mut self, seq: Sequence) {
323        if seq.is_running() {
324            // prefill case
325            self.running.push(seq);
326        } else {
327            self.waiting.add(seq);
328        }
329    }
330    fn block_tables(&self) -> Option<BlockTables> {
331        None
332    }
333    fn block_size(&self) -> Option<usize> {
334        None
335    }
336    fn free_finished_sequence_groups(&mut self) {}
337    fn block_engine(&self) -> Option<Arc<tokio::sync::Mutex<BlockEngine>>> {
338        None
339    }
340}