mistralrs_core/scheduler/default_scheduler.rs

use std::{
    collections::{HashMap, VecDeque},
    num::NonZeroUsize,
    sync::{atomic::Ordering, Arc},
};

use crate::{
    engine::{IntervalLogger, TERMINATE_ALL_NEXT_STEP},
    paged_attention::{BlockEngine, BlockTables},
    sequence::{Sequence, SequenceState, StopReason},
};

use super::{Scheduler, SchedulerOutput};

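/// Backing store for the waiting queue. Sequences are drained in
/// first-come-first-served (FCFS) order, hence the name.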
pub trait FcfsBacker: Default {
    fn new() -> Self;
    fn add(&mut self, item: Sequence);
    fn into_iter(self) -> impl Iterator<Item = Sequence>;
    fn len(&self) -> usize;
    fn sort_ascending_ids(&mut self);
}

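// Note: inside `new`, `Self::new()` resolves to the inherent `VecDeque::new`
// (inherent methods take precedence over trait methods), so the call is not recursive.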
impl FcfsBacker for VecDeque<Sequence> {
    fn new() -> Self {
        Self::new()
    }
    fn add(&mut self, item: Sequence) {
        self.push_back(item)
    }
    fn into_iter(self) -> impl Iterator<Item = Sequence> {
        <Self as IntoIterator>::into_iter(self)
    }
    fn sort_ascending_ids(&mut self) {
        let slice = self.make_contiguous();
        slice.sort_by_key(|seq| *seq.id());
    }
    fn len(&self) -> usize {
        VecDeque::len(self)
    }
}

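/// Sequences selected for this scheduling step: `prompt` holds sequences still
/// processing their prompt (prefill), `completion` holds sequences generating tokens.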
pub struct DefaultSchedulerOutput<'a> {
    pub completion: Box<[&'a mut Sequence]>,
    pub prompt: Box<[&'a mut Sequence]>,
}

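/// Scheduling method. `Fixed(n)` caps the number of concurrently running sequences at `n`.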
#[derive(Clone)]
pub enum DefaultSchedulerMethod {
    Fixed(NonZeroUsize),
}

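/// Result of bucketing: the sequences chosen to keep running this step, plus
/// the backing store holding everything returned to the waiting list.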
pub struct BucketedSeqs<Backer: FcfsBacker> {
    running: Vec<Sequence>,
    waiting: Backer,
}

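/// Splits the running sequences into batch-compatible buckets, keeping one
/// bucket running and returning the rest to the waiting list.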
pub trait BucketingManager<Backer: FcfsBacker>: Send + Sync {
    fn bucket_and_waitlist_seqs_waiting(
        &mut self,
        running: Vec<Sequence>,
        waiting: Backer,
        discrete: bool,
    ) -> BucketedSeqs<Backer>;
}

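/// Bucket key: (sequence length, has images and is in the prompt phase, token offset).
/// Sequences sharing a key can be scheduled in the same batch.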
type BucketKey = (usize, bool, usize);

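/// Bucketing policy used by [`DefaultSchedulerMethod::Fixed`]: when `discrete`
/// is set the shortest bucket runs, otherwise the bucket with the highest summed
/// priority runs; all other sequences go back to the waiting list with their
/// urgency increased so they are not starved.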
struct FixedBucketingManager;

impl<Backer: FcfsBacker> BucketingManager<Backer> for FixedBucketingManager {
    fn bucket_and_waitlist_seqs_waiting(
        &mut self,
        running: Vec<Sequence>,
        mut waiting: Backer,
        discrete: bool,
    ) -> BucketedSeqs<Backer> {
        // Group the running sequences by bucket key, tracking each bucket's
        // summed priority when priorities matter (`!discrete`).
        let mut seq_buckets: HashMap<BucketKey, Vec<Sequence>> = HashMap::new();
        let mut seq_priorities: HashMap<BucketKey, f64> = HashMap::new();
        for seq in running {
            let key: BucketKey = (
                seq.len(),
                seq.images().is_some() && seq.is_prompt(),
                seq.token_offset(),
            );
            match seq_buckets.get_mut(&key) {
                Some(bucket) => {
                    if !discrete {
                        *seq_priorities.get_mut(&key).unwrap() += seq.compute_priority();
                    }
                    bucket.push(seq);
                }
                None => {
                    if !discrete {
                        seq_priorities.insert(key, seq.compute_priority());
                    }
                    seq_buckets.insert(key, vec![seq]);
                }
            }
        }
        let running = if seq_buckets.len() <= 1 {
            // Only one bucket (or none): keep everything running.
            seq_buckets
                .into_iter()
                .flat_map(|(_, x)| x)
                .map(|s| s.reset_urgency())
                .collect::<Vec<_>>()
        } else {
            // Several buckets: run the shortest one (`discrete`) or the one with
            // the highest accumulated priority, and waitlist the rest with added urgency.
            let min = *seq_buckets
                .keys()
                .min_by_key(|(x, _, _)| *x)
                .expect("No sequence buckets.");
            let len = if !discrete {
                seq_priorities
                    .iter()
                    .max_by(|(_, a), (_, b)| a.partial_cmp(b).unwrap())
                    .map(|(a, b)| (a, *b))
                    .unwrap_or_else(|| (&min, seq_priorities[&min]))
                    .0
            } else {
                &min
            };
            let highest_priority_seqs = seq_buckets
                .remove(len)
                .unwrap()
                .into_iter()
                .map(|s| s.reset_urgency())
                .collect();
            for (_, seqs) in seq_buckets {
                for seq in seqs {
                    waiting.add(seq.add_urgency());
                }
            }
            highest_priority_seqs
        };
        BucketedSeqs { running, waiting }
    }
}

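/// First-come-first-served scheduler with a fixed cap on the number of running sequences.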
pub struct DefaultScheduler<Backer: FcfsBacker> {
    waiting: Backer,
    running: Vec<Sequence>,
    method: DefaultSchedulerMethod,
    bucketing_manager: Box<dyn BucketingManager<Backer>>,
}

impl<Backer: FcfsBacker> DefaultScheduler<Backer> {
    pub fn new(method: DefaultSchedulerMethod) -> Self {
        let bucketing_manager: Box<dyn BucketingManager<_>> = match method {
            DefaultSchedulerMethod::Fixed(_) => Box::new(FixedBucketingManager),
        };
        Self {
            running: Vec::new(),
            waiting: Backer::new(),
            method,
            bucketing_manager,
        }
    }

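    /// Re-bucket the running sequences, moving any displaced ones back onto the
    /// waiting list (`discrete = true`: the shortest bucket wins, priorities are ignored).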
    fn bucket_and_waitlist_seqs(&mut self, running: Vec<Sequence>) -> Vec<Sequence> {
        let waiting = std::mem::take(&mut self.waiting);
        let BucketedSeqs { running, waiting } = self
            .bucketing_manager
            .bucket_and_waitlist_seqs_waiting(running, waiting, true);
        self.waiting = waiting;
        running
    }

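    /// Schedule sequences for the next step. Fast paths handle the empty,
    /// waiting-only, and running-only cases; otherwise waiting sequences are
    /// admitted in FCFS order while capacity allows and the result is re-bucketed.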
    pub fn schedule(&mut self, logger: &IntervalLogger) -> DefaultSchedulerOutput {
        let running = std::mem::take(&mut self.running);
        let mut waiting = std::mem::take(&mut self.waiting);
        let mut running = running
            .into_iter()
            .filter(|seq| seq.is_running())
            .collect::<Vec<_>>();

        match (waiting.len(), running.len()) {
            // Nothing to schedule.
            (0, 0) => {
                self.running = running;
                logger.set_num_running(self.running.len());
                logger.set_num_waiting(self.waiting.len());
                return DefaultSchedulerOutput {
                    prompt: vec![].into(),
                    completion: vec![].into(),
                };
            }
            // Only waiting sequences: start them all in the prompt phase, then bucket.
            (_, 0) => {
                for seq in waiting.into_iter() {
                    seq.set_state(SequenceState::RunningPrompt);
                    self.running.push(seq);
                }
                self.waiting = Backer::new();
                let running = std::mem::take(&mut self.running);
                self.running = self.bucket_and_waitlist_seqs(running);
                logger.set_num_running(self.running.len());
                logger.set_num_waiting(self.waiting.len());
                return DefaultSchedulerOutput {
                    prompt: self.running.iter_mut().collect::<Vec<_>>().into(),
                    completion: vec![].into(),
                };
            }
            // Only running sequences: re-bucket and schedule them all as the
            // completion batch, honoring a pending terminate-all request.
            (0, _) => {
                self.running = self.bucket_and_waitlist_seqs(running);
                if TERMINATE_ALL_NEXT_STEP.load(Ordering::SeqCst) {
                    self.running
                        .iter_mut()
                        .for_each(|seq| seq.set_state(SequenceState::Done(StopReason::Canceled)));
                    TERMINATE_ALL_NEXT_STEP.store(false, Ordering::SeqCst);
                }
                logger.set_num_running(self.running.len());
                logger.set_num_waiting(self.waiting.len());
                return DefaultSchedulerOutput {
                    prompt: vec![].into(),
                    completion: self.running.iter_mut().collect::<Vec<_>>().into(),
                };
            }
            _ => {}
        }

        // Mixed case: admit waiting sequences in FCFS order while capacity allows.
        waiting.sort_ascending_ids();

        let mut new_waiting = Backer::new();
        for seq in waiting.into_iter() {
            if self.sequence_fits(&running, &seq) {
                if seq.is_waiting() {
                    seq.set_state(SequenceState::RunningPrompt);
                }
                running.push(seq);
            } else {
                new_waiting.add(seq);
            }
        }

        // Re-bucket with priorities enabled (`discrete = false`).
        let BucketedSeqs {
            running,
            waiting: new_waiting,
        } = self
            .bucketing_manager
            .bucket_and_waitlist_seqs_waiting(running, new_waiting, false);

        self.running = running;
        self.waiting = new_waiting;

        logger.set_num_running(self.running.len());
        logger.set_num_waiting(self.waiting.len());

        // Split the running set into prompt and completion batches.
        let mut completion = Vec::new();
        let mut prompt = Vec::new();
        for seq in &mut self.running {
            if seq.is_completion() {
                completion.push(seq);
            } else {
                prompt.push(seq);
            }
        }

        DefaultSchedulerOutput {
            completion: completion.into(),
            prompt: prompt.into(),
        }
    }

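    /// Whether one more sequence can be added to `running` under the configured
    /// method (`Fixed(n)` allows at most `n` running sequences).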
    fn sequence_fits(&self, running: &[Sequence], _seq: &Sequence) -> bool {
        match &self.method {
            DefaultSchedulerMethod::Fixed(n) => (running.len() + 1) <= (*n).into(),
        }
    }
}

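// The PagedAttention-specific accessors return `None`: this scheduler does not
// manage KV-cache block tables.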
impl Scheduler for DefaultScheduler<VecDeque<Sequence>> {
    fn schedule(&mut self, logger: &IntervalLogger) -> SchedulerOutput<'_> {
        SchedulerOutput::DefaultScheduler {
            output: self.schedule(logger),
        }
    }
    fn waiting_len(&self) -> usize {
        self.waiting.len()
    }
    fn running_len(&self) -> usize {
        self.running.len()
    }
    fn add_seq(&mut self, seq: Sequence) {
        if seq.is_running() {
            self.running.push(seq);
        } else {
            self.waiting.add(seq);
        }
    }
    fn block_tables(&self) -> Option<BlockTables> {
        None
    }
    fn block_size(&self) -> Option<usize> {
        None
    }
    fn free_finished_sequence_groups(&mut self) {}
    fn block_engine(&self) -> Option<Arc<tokio::sync::Mutex<BlockEngine>>> {
        None
    }
}