mistralrs_core/scheduler/
default_scheduler.rs1use std::{
2 collections::{HashMap, VecDeque},
3 num::NonZeroUsize,
4 sync::atomic::Ordering,
5};
6
7use crate::{
8 engine::TERMINATE_ALL_NEXT_STEP,
9 paged_attention::{BlockEngine, BlockTables},
10 sequence::{Sequence, SequenceState, StopReason},
11};
12
13use super::{Scheduler, SchedulerOutput};
14
/// Container abstraction for the scheduler's queue of waiting sequences.
///
/// The scheduler in this file only needs the handful of operations below, so any
/// collection that can provide them may back the waiting queue.
/// NOTE(review): "Fcfs" presumably stands for first-come-first-served — confirm.
pub trait FcfsBacker: Default {
    /// Creates a new backer; the scheduler uses this as its initial (empty) waiting queue.
    fn new() -> Self;
    /// Stores `item` in the backer.
    fn add(&mut self, item: Sequence);
    /// Consumes the backer and yields every sequence it held.
    fn into_iter(self) -> impl Iterator<Item = Sequence>;
    /// Returns how many sequences are currently held.
    fn len(&self) -> usize;
    /// Reorders the held sequences so that their ids are in ascending order.
    fn sort_ascending_ids(&mut self);
}
22
23impl FcfsBacker for VecDeque<Sequence> {
24 fn new() -> Self {
25 Self::new()
26 }
27 fn add(&mut self, item: Sequence) {
28 self.push_back(item)
29 }
30 fn into_iter(self) -> impl Iterator<Item = Sequence> {
31 <Self as IntoIterator>::into_iter(self)
32 }
33 fn sort_ascending_ids(&mut self) {
34 let slice = self.make_contiguous();
35 slice.sort_by_key(|seq| *seq.id());
36 }
37 fn len(&self) -> usize {
38 VecDeque::len(self)
39 }
40}
41
/// Result of one `DefaultScheduler::schedule` call: mutable borrows of the
/// sequences that should run in the next step, split by phase.
pub struct DefaultSchedulerOutput<'a> {
    /// Scheduled sequences for which `is_completion()` returned true.
    pub completion: Box<[&'a mut Sequence]>,
    /// Scheduled sequences that are still in their prompt phase.
    pub prompt: Box<[&'a mut Sequence]>,
}
46
/// Admission policy used by `DefaultScheduler` when moving waiting sequences
/// into the running set.
#[derive(Clone)]
pub enum DefaultSchedulerMethod {
    /// A waiting sequence is admitted only while the running set holds fewer
    /// than this many sequences.
    Fixed(NonZeroUsize),
}
55
/// Output of a bucketing pass: the sequences chosen to run together and the
/// (possibly extended) waiting queue.
pub struct BucketedSeqs<Backer: FcfsBacker> {
    /// Sequences selected to run in the next step.
    running: Vec<Sequence>,
    /// Waiting queue, including any sequences that were moved out of the running set.
    waiting: Backer,
}
60
/// Strategy that decides which of the candidate running sequences may actually
/// run together, pushing the rest back onto the waiting queue.
pub trait BucketingManager<Backer: FcfsBacker>: Send + Sync {
    /// Splits `running` into the sequences that stay running and those that are
    /// appended to `waiting`, returning both collections.
    ///
    /// `discrete` selects between two selection policies; in
    /// `FixedBucketingManager`, `true` keeps the bucket with the smallest sequence
    /// length and `false` keeps the bucket with the largest summed priority.
    fn bucket_and_waitlist_seqs_waiting(
        &mut self,
        running: Vec<Sequence>,
        waiting: Backer,
        discrete: bool,
    ) -> BucketedSeqs<Backer>;
}
70
/// Key identifying a bucket of sequences that can run together:
/// `(sequence length, has images and is a prompt, token offset)`.
type BucketKey = (usize, bool, usize);

/// Bucketing strategy that keeps exactly one bucket of identically-keyed
/// sequences running and waitlists every other bucket.
struct FixedBucketingManager;
76
77impl<Backer: FcfsBacker> BucketingManager<Backer> for FixedBucketingManager {
78 fn bucket_and_waitlist_seqs_waiting(
82 &mut self,
83 running: Vec<Sequence>,
84 mut waiting: Backer,
85 discrete: bool,
86 ) -> BucketedSeqs<Backer> {
87 let mut seq_buckets: HashMap<BucketKey, Vec<Sequence>> = HashMap::new();
89 let mut seq_priorities: HashMap<BucketKey, f64> = HashMap::new();
90 for seq in running {
91 let len = seq.len();
92 match seq_buckets.get_mut(&(
93 len,
94 seq.images().is_some() && seq.is_prompt(),
95 seq.token_offset(),
96 )) {
97 Some(bucket) => {
98 if !discrete {
99 *seq_priorities
100 .get_mut(&(
101 len,
102 seq.images().is_some() && seq.is_prompt(),
103 seq.token_offset(),
104 ))
105 .unwrap() += seq.compute_priority();
106 }
107 bucket.push(seq);
108 }
109 None => {
110 if !discrete {
111 seq_priorities.insert(
112 (
113 len,
114 seq.images().is_some() && seq.is_prompt(),
115 seq.token_offset(),
116 ),
117 seq.compute_priority(),
118 );
119 }
120 seq_buckets.insert(
121 (
122 len,
123 seq.images().is_some() && seq.is_prompt(),
124 seq.token_offset(),
125 ),
126 vec![seq],
127 );
128 }
129 }
130 }
131 let running = if seq_buckets.len() <= 1 {
132 seq_buckets
134 .into_iter()
135 .flat_map(|(_, x)| x)
136 .map(|s| s.reset_urgency())
137 .collect::<Vec<_>>()
138 } else {
139 let min = *seq_buckets
142 .keys()
143 .min_by_key(|(x, _, _)| *x)
144 .expect("No sequence buckets.");
145 let len = if !discrete {
146 seq_priorities
147 .iter()
148 .max_by(|(_, a), (_, b)| a.partial_cmp(b).unwrap())
149 .map(|(a, b)| (a, *b))
150 .unwrap_or_else(|| (&min, seq_priorities[&min]))
151 .0
152 } else {
153 &min
154 };
155 let highest_priority_seqs = seq_buckets
156 .remove(len)
157 .unwrap()
158 .into_iter()
159 .map(|s| s.reset_urgency())
160 .collect();
161 for (_, seqs) in seq_buckets {
162 for seq in seqs {
163 waiting.add(seq.add_urgency());
164 }
165 }
166 highest_priority_seqs
168 };
169 BucketedSeqs { running, waiting }
170 }
171}
172
/// Scheduler that keeps a waiting queue and a running set, admitting waiting
/// sequences according to `method` and delegating batching decisions to a
/// `BucketingManager`.
pub struct DefaultScheduler<Backer: FcfsBacker> {
    /// Sequences that are queued and not currently running.
    waiting: Backer,
    /// Sequences scheduled to run in the current step.
    running: Vec<Sequence>,
    /// Admission policy consulted by `sequence_fits`.
    method: DefaultSchedulerMethod,
    /// Strategy that picks which running sequences may be batched together.
    bucketing_manager: Box<dyn BucketingManager<Backer>>,
}
179
180impl<Backer: FcfsBacker> DefaultScheduler<Backer> {
181 pub fn new(method: DefaultSchedulerMethod) -> Self {
182 let bucketing_manager: Box<dyn BucketingManager<_>> = match method {
183 DefaultSchedulerMethod::Fixed(_) => Box::new(FixedBucketingManager),
184 };
185 Self {
186 running: Vec::new(),
187 waiting: Backer::new(),
188 method,
189 bucketing_manager,
190 }
191 }
192
193 fn bucket_and_waitlist_seqs(&mut self, running: Vec<Sequence>) -> Vec<Sequence> {
197 let waiting = std::mem::take(&mut self.waiting);
198 let BucketedSeqs { running, waiting } = self
199 .bucketing_manager
200 .bucket_and_waitlist_seqs_waiting(running, waiting, true);
201 self.waiting = waiting;
202 running
203 }
204
205 pub fn schedule(&mut self) -> DefaultSchedulerOutput {
207 let running = std::mem::take(&mut self.running);
209 let mut waiting = std::mem::take(&mut self.waiting);
210 let mut running = running
211 .into_iter()
212 .filter(|seq| seq.is_running())
213 .collect::<Vec<_>>();
214
215 match (waiting.len(), running.len()) {
216 (0, 0) => {
217 self.running = running;
218 return DefaultSchedulerOutput {
219 prompt: vec![].into(),
220 completion: vec![].into(),
221 };
222 }
223 (_, 0) => {
224 for seq in waiting.into_iter() {
225 seq.set_state(SequenceState::RunningPrompt);
226 self.running.push(seq);
227 }
228 self.waiting = Backer::new();
229 let running = std::mem::take(&mut self.running);
230 self.running = self.bucket_and_waitlist_seqs(running);
231 return DefaultSchedulerOutput {
232 prompt: self.running.iter_mut().collect::<Vec<_>>().into(),
233 completion: vec![].into(),
234 };
235 }
236 (0, _) => {
237 self.running = self.bucket_and_waitlist_seqs(running);
238 if TERMINATE_ALL_NEXT_STEP.load(Ordering::SeqCst) {
239 self.running
240 .iter_mut()
241 .for_each(|seq| seq.set_state(SequenceState::Done(StopReason::Canceled)));
242 TERMINATE_ALL_NEXT_STEP.store(false, Ordering::SeqCst);
243 }
244 return DefaultSchedulerOutput {
245 prompt: vec![].into(),
246 completion: self.running.iter_mut().collect::<Vec<_>>().into(),
247 };
248 }
249 _ => {}
250 }
251
252 waiting.sort_ascending_ids();
254
255 let mut new_waiting = Backer::new();
257 for seq in waiting.into_iter() {
258 if self.sequence_fits(&running, &seq) {
259 if seq.is_waiting() {
260 seq.set_state(SequenceState::RunningPrompt);
261 }
262 running.push(seq);
263 } else {
264 new_waiting.add(seq);
265 }
266 }
267
268 let BucketedSeqs {
269 running,
270 waiting: new_waiting,
271 } = self
272 .bucketing_manager
273 .bucket_and_waitlist_seqs_waiting(running, new_waiting, false);
274
275 self.running = running;
276 self.waiting = new_waiting;
277
278 let mut completion = Vec::new();
279 let mut prompt = Vec::new();
280 for seq in &mut self.running {
281 if seq.is_completion() {
282 completion.push(seq);
283 } else {
284 prompt.push(seq);
285 }
286 }
287
288 DefaultSchedulerOutput {
289 completion: completion.into(),
290 prompt: prompt.into(),
291 }
292 }
293
294 fn sequence_fits(&self, running: &[Sequence], _seq: &Sequence) -> bool {
295 match &self.method {
296 DefaultSchedulerMethod::Fixed(n) => (running.len() + 1) <= (*n).into(),
297 }
298 }
299}
300
301impl Scheduler for DefaultScheduler<VecDeque<Sequence>> {
302 fn schedule(&mut self) -> SchedulerOutput<'_> {
303 SchedulerOutput::DefaultScheduler {
304 output: self.schedule(),
305 }
306 }
307 fn waiting_len(&self) -> usize {
308 self.waiting.len()
309 }
310 fn running_len(&self) -> usize {
311 self.running.len()
312 }
313 fn add_seq(&mut self, seq: Sequence) {
314 if seq.is_running() {
315 self.running.push(seq);
317 } else {
318 self.waiting.add(seq);
319 }
320 }
321 fn block_tables(&self) -> Option<&BlockTables> {
322 None
323 }
324 fn block_size(&self) -> Option<usize> {
325 None
326 }
327 fn free_finished_sequence_groups(&mut self) {}
328 fn block_engine(&mut self) -> Option<&mut BlockEngine> {
329 None
330 }
331}