//! Default FCFS scheduler (`mistralrs_core/scheduler/default_scheduler.rs`).

1use std::{
2    collections::{HashMap, VecDeque},
3    num::NonZeroUsize,
4    sync::atomic::Ordering,
5};
6
7use crate::{
8    engine::TERMINATE_ALL_NEXT_STEP,
9    paged_attention::{BlockEngine, BlockTables},
10    sequence::{Sequence, SequenceState, StopReason},
11};
12
13use super::{Scheduler, SchedulerOutput};
14
15pub trait FcfsBacker: Default {
16    fn new() -> Self;
17    fn add(&mut self, item: Sequence);
18    fn into_iter(self) -> impl Iterator<Item = Sequence>;
19    fn len(&self) -> usize;
20    fn sort_ascending_ids(&mut self);
21}
22
23impl FcfsBacker for VecDeque<Sequence> {
24    fn new() -> Self {
25        Self::new()
26    }
27    fn add(&mut self, item: Sequence) {
28        self.push_back(item)
29    }
30    fn into_iter(self) -> impl Iterator<Item = Sequence> {
31        <Self as IntoIterator>::into_iter(self)
32    }
33    fn sort_ascending_ids(&mut self) {
34        let slice = self.make_contiguous();
35        slice.sort_by_key(|seq| *seq.id());
36    }
37    fn len(&self) -> usize {
38        VecDeque::len(self)
39    }
40}
41
/// Output of one default-scheduler step: mutable handles to the sequences
/// chosen to run, split by phase.
pub struct DefaultSchedulerOutput<'a> {
    // Sequences past their prompt, generating tokens (see `is_completion` split in `schedule`).
    pub completion: Box<[&'a mut Sequence]>,
    // Sequences still processing their prompt.
    pub prompt: Box<[&'a mut Sequence]>,
}
46
/// The scheduler method controls how sequences are scheduled during each
/// step of the engine. For each scheduling step, the scheduler method is used if there
/// are not only running, only waiting sequences, or none. If it is used, then it
/// is used to allow waiting sequences to run.
#[derive(Clone)]
pub enum DefaultSchedulerMethod {
    /// Cap the number of concurrently running sequences at a fixed limit.
    Fixed(NonZeroUsize),
}
55
/// Result of a bucketing pass: the sequences selected to run this step, plus
/// the (possibly grown) waiting backer.
pub struct BucketedSeqs<Backer: FcfsBacker> {
    // Sequences chosen to run now.
    running: Vec<Sequence>,
    // Waiting list, including any sequences deferred by bucketing.
    waiting: Backer,
}
60
/// Strategy for grouping running sequences into compatible buckets each step.
pub trait BucketingManager<Backer: FcfsBacker>: Send + Sync {
    /// Bucket and waitlist running input sequences, returning the newly running sequences.
    ///
    /// When `discrete` is true, bucket selection ignores sequence priorities;
    /// otherwise accumulated bucket priority is taken into account
    /// (see `FixedBucketingManager` for the concrete behavior).
    fn bucket_and_waitlist_seqs_waiting(
        &mut self,
        running: Vec<Sequence>,
        waiting: Backer,
        discrete: bool,
    ) -> BucketedSeqs<Backer>;
}
70
// (cache length, (has_imgs && is_prompt), sequence offset)
// Bucket by that metric for images because if we are not a prompt, then this doesn't apply
type BucketKey = (usize, bool, usize);

/// Bucketing manager used by [`DefaultSchedulerMethod::Fixed`]: groups
/// sequences by `BucketKey` and runs one bucket per step.
struct FixedBucketingManager;
76
77impl<Backer: FcfsBacker> BucketingManager<Backer> for FixedBucketingManager {
78    /// Move the seuqences into buckets, and run the ones with the shortest lengths.
79    /// The others are moved to the waiting list (retaining high priority due to start time),
80    /// without a state modification.
81    fn bucket_and_waitlist_seqs_waiting(
82        &mut self,
83        running: Vec<Sequence>,
84        mut waiting: Backer,
85        discrete: bool,
86    ) -> BucketedSeqs<Backer> {
87        // Now, get the sequences with the smallest sequence lengths, and allow them to catch up.
88        let mut seq_buckets: HashMap<BucketKey, Vec<Sequence>> = HashMap::new();
89        let mut seq_priorities: HashMap<BucketKey, f64> = HashMap::new();
90        for seq in running {
91            let len = seq.len();
92            match seq_buckets.get_mut(&(
93                len,
94                seq.images().is_some() && seq.is_prompt(),
95                seq.token_offset(),
96            )) {
97                Some(bucket) => {
98                    if !discrete {
99                        *seq_priorities
100                            .get_mut(&(
101                                len,
102                                seq.images().is_some() && seq.is_prompt(),
103                                seq.token_offset(),
104                            ))
105                            .unwrap() += seq.compute_priority();
106                    }
107                    bucket.push(seq);
108                }
109                None => {
110                    if !discrete {
111                        seq_priorities.insert(
112                            (
113                                len,
114                                seq.images().is_some() && seq.is_prompt(),
115                                seq.token_offset(),
116                            ),
117                            seq.compute_priority(),
118                        );
119                    }
120                    seq_buckets.insert(
121                        (
122                            len,
123                            seq.images().is_some() && seq.is_prompt(),
124                            seq.token_offset(),
125                        ),
126                        vec![seq],
127                    );
128                }
129            }
130        }
131        let running = if seq_buckets.len() <= 1 {
132            // Full steam ahead or have everything
133            seq_buckets
134                .into_iter()
135                .flat_map(|(_, x)| x)
136                .map(|s| s.reset_urgency())
137                .collect::<Vec<_>>()
138        } else {
139            // Set the min seqs to be the running ones, and the rest to be waiting (but their states are not changed!)
140            // Allow the min seqs to catch up.
141            let min = *seq_buckets
142                .keys()
143                .min_by_key(|(x, _, _)| *x)
144                .expect("No sequence buckets.");
145            let len = if !discrete {
146                seq_priorities
147                    .iter()
148                    .max_by(|(_, a), (_, b)| a.partial_cmp(b).unwrap())
149                    .map(|(a, b)| (a, *b))
150                    .unwrap_or_else(|| (&min, seq_priorities[&min]))
151                    .0
152            } else {
153                &min
154            };
155            let highest_priority_seqs = seq_buckets
156                .remove(len)
157                .unwrap()
158                .into_iter()
159                .map(|s| s.reset_urgency())
160                .collect();
161            for (_, seqs) in seq_buckets {
162                for seq in seqs {
163                    waiting.add(seq.add_urgency());
164                }
165            }
166            // Know min_seqs.len < running.len() <= max
167            highest_priority_seqs
168        };
169        BucketedSeqs { running, waiting }
170    }
171}
172
/// First-come-first-served scheduler over a generic waiting-list backer.
pub struct DefaultScheduler<Backer: FcfsBacker> {
    // Sequences not yet scheduled, in FCFS order.
    waiting: Backer,
    // Sequences currently scheduled to run.
    running: Vec<Sequence>,
    // Capacity policy for admitting waiting sequences.
    method: DefaultSchedulerMethod,
    // Strategy for grouping running sequences into shape-compatible buckets.
    bucketing_manager: Box<dyn BucketingManager<Backer>>,
}
179
impl<Backer: FcfsBacker> DefaultScheduler<Backer> {
    /// Create a scheduler with the given method; `Fixed` selects the
    /// fixed-capacity bucketing manager.
    pub fn new(method: DefaultSchedulerMethod) -> Self {
        let bucketing_manager: Box<dyn BucketingManager<_>> = match method {
            DefaultSchedulerMethod::Fixed(_) => Box::new(FixedBucketingManager),
        };
        Self {
            running: Vec::new(),
            waiting: Backer::new(),
            method,
            bucketing_manager,
        }
    }

    /// Move the sequences into buckets, and run the ones with the shortest lengths.
    /// The others are moved to the waiting list (retaining high priority due to start time),
    /// without a state modification.
    fn bucket_and_waitlist_seqs(&mut self, running: Vec<Sequence>) -> Vec<Sequence> {
        // Temporarily take the waiting list so the bucketing manager can own it;
        // it is restored (possibly grown) afterwards.
        let waiting = std::mem::take(&mut self.waiting);
        let BucketedSeqs { running, waiting } = self
            .bucketing_manager
            .bucket_and_waitlist_seqs_waiting(running, waiting, true);
        self.waiting = waiting;
        running
    }

    /// Schedule all sequences based on their state and the available space.
    pub fn schedule(&mut self) -> DefaultSchedulerOutput {
        // Filter out all done sequences
        let running = std::mem::take(&mut self.running);
        let mut waiting = std::mem::take(&mut self.waiting);
        let mut running = running
            .into_iter()
            .filter(|seq| seq.is_running())
            .collect::<Vec<_>>();

        match (waiting.len(), running.len()) {
            // Nothing queued or running: empty step.
            (0, 0) => {
                self.running = running;
                return DefaultSchedulerOutput {
                    prompt: vec![].into(),
                    completion: vec![].into(),
                };
            }
            // Only waiting sequences: promote them all to the prompt phase.
            // NOTE(review): `set_state` is called through a shared handle —
            // presumably `Sequence` uses interior mutability; confirm.
            (_, 0) => {
                for seq in waiting.into_iter() {
                    seq.set_state(SequenceState::RunningPrompt);
                    self.running.push(seq);
                }
                self.waiting = Backer::new();
                let running = std::mem::take(&mut self.running);
                self.running = self.bucket_and_waitlist_seqs(running);
                return DefaultSchedulerOutput {
                    prompt: self.running.iter_mut().collect::<Vec<_>>().into(),
                    completion: vec![].into(),
                };
            }
            // Only running sequences: bucket them and emit as completions.
            (0, _) => {
                self.running = self.bucket_and_waitlist_seqs(running);
                // One-shot engine-wide cancellation: mark every scheduled
                // sequence done, then clear the flag.
                if TERMINATE_ALL_NEXT_STEP.load(Ordering::SeqCst) {
                    self.running
                        .iter_mut()
                        .for_each(|seq| seq.set_state(SequenceState::Done(StopReason::Canceled)));
                    TERMINATE_ALL_NEXT_STEP.store(false, Ordering::SeqCst);
                }
                return DefaultSchedulerOutput {
                    prompt: vec![].into(),
                    completion: self.running.iter_mut().collect::<Vec<_>>().into(),
                };
            }
            // Both running and waiting sequences: fall through to the mixed path.
            _ => {}
        }

        // Sort the waiting seqs
        waiting.sort_ascending_ids();

        // If the waiting sequence will fit, add it. Otherwise remove it
        let mut new_waiting = Backer::new();
        for seq in waiting.into_iter() {
            if self.sequence_fits(&running, &seq) {
                if seq.is_waiting() {
                    seq.set_state(SequenceState::RunningPrompt);
                }
                running.push(seq);
            } else {
                new_waiting.add(seq);
            }
        }

        // Non-discrete bucketing: bucket priorities influence which bucket runs.
        let BucketedSeqs {
            running,
            waiting: new_waiting,
        } = self
            .bucketing_manager
            .bucket_and_waitlist_seqs_waiting(running, new_waiting, false);

        self.running = running;
        self.waiting = new_waiting;

        // Split the scheduled sequences by phase for the output.
        let mut completion = Vec::new();
        let mut prompt = Vec::new();
        for seq in &mut self.running {
            if seq.is_completion() {
                completion.push(seq);
            } else {
                prompt.push(seq);
            }
        }

        DefaultSchedulerOutput {
            completion: completion.into(),
            prompt: prompt.into(),
        }
    }

    /// Whether one more sequence fits under the method's capacity limit.
    fn sequence_fits(&self, running: &[Sequence], _seq: &Sequence) -> bool {
        match &self.method {
            DefaultSchedulerMethod::Fixed(n) => (running.len() + 1) <= (*n).into(),
        }
    }
}
300
impl Scheduler for DefaultScheduler<VecDeque<Sequence>> {
    fn schedule(&mut self) -> SchedulerOutput<'_> {
        // Delegate to the inherent `schedule` and wrap in the enum variant.
        SchedulerOutput::DefaultScheduler {
            output: self.schedule(),
        }
    }
    fn waiting_len(&self) -> usize {
        self.waiting.len()
    }
    fn running_len(&self) -> usize {
        self.running.len()
    }
    fn add_seq(&mut self, seq: Sequence) {
        if seq.is_running() {
            // prefill case
            self.running.push(seq);
        } else {
            self.waiting.add(seq);
        }
    }
    // The default scheduler does not use paged attention, so all
    // block-related accessors report nothing.
    fn block_tables(&self) -> Option<&BlockTables> {
        None
    }
    fn block_size(&self) -> Option<usize> {
        None
    }
    // No sequence groups to free without a block engine.
    fn free_finished_sequence_groups(&mut self) {}
    fn block_engine(&mut self) -> Option<&mut BlockEngine> {
        None
    }
}