// mistralrs_core/scheduler/default_scheduler.rs
use std::{
collections::{HashMap, VecDeque},
num::NonZeroUsize,
sync::atomic::Ordering,
};
use crate::{
engine::TERMINATE_ALL_NEXT_STEP,
paged_attention::{BlockEngine, BlockTables},
sequence::{Sequence, SequenceState, StopReason},
};
use super::{Scheduler, SchedulerOutput};
/// Container for the scheduler's waitlist of sequences.
///
/// `DefaultScheduler` takes it by value each step, refills it, and drains it with
/// `into_iter`.
pub trait FcfsBacker: Default {
    /// Creates an empty container.
    fn new() -> Self;

    /// Number of sequences currently held.
    fn len(&self) -> usize;

    /// Stores a sequence (the `VecDeque` implementation appends it at the back).
    fn add(&mut self, item: Sequence);

    /// Reorders the held sequences by ascending `Sequence::id`.
    fn sort_ascending_ids(&mut self);

    /// Consumes the container and yields its sequences in their stored order.
    fn into_iter(self) -> impl Iterator<Item = Sequence>;
}
/// `VecDeque`-backed waitlist: `add` pushes to the back, iteration is front to back.
impl FcfsBacker for VecDeque<Sequence> {
    fn new() -> Self {
        // Resolves to the inherent `VecDeque::new`, not back into this trait method.
        VecDeque::new()
    }

    fn len(&self) -> usize {
        VecDeque::len(self)
    }

    fn add(&mut self, item: Sequence) {
        self.push_back(item);
    }

    fn sort_ascending_ids(&mut self) {
        // A deque's storage may wrap around; make it one slice so it can be sorted in place.
        self.make_contiguous().sort_by_key(|seq| *seq.id());
    }

    fn into_iter(self) -> impl Iterator<Item = Sequence> {
        IntoIterator::into_iter(self)
    }
}
/// Batch produced by `DefaultScheduler::schedule`, mutably borrowing the scheduler's
/// running sequences for lifetime `'a`.
pub struct DefaultSchedulerOutput<'a> {
    /// Sequences the scheduler placed in the prompt batch.
    pub prompt: Box<[&'a mut Sequence]>,
    /// Sequences the scheduler placed in the completion batch.
    pub completion: Box<[&'a mut Sequence]>,
}
/// Admission policy for `DefaultScheduler`.
#[derive(Clone)]
pub enum DefaultSchedulerMethod {
    /// Cap consulted by `sequence_fits`: another sequence fits only while fewer than
    /// this many sequences are running.
    Fixed(NonZeroUsize),
}
/// Result of a bucketing pass: the sequences selected to run and the updated waitlist.
pub struct BucketedSeqs<Backer: FcfsBacker> {
    /// Sequences selected to run.
    running: Vec<Sequence>,
    /// The waitlist that was passed in, with any non-selected sequences added to it.
    waiting: Backer,
}
/// Strategy deciding which of the running sequences execute together in one step.
pub trait BucketingManager<Backer>
where
    Backer: FcfsBacker,
{
    /// Splits `running` into the sequences to run now and those to defer.
    ///
    /// Deferred sequences are added to `waiting`, which is handed back inside the
    /// returned [`BucketedSeqs`]. The meaning of `discrete` is implementation-defined
    /// (see `FixedBucketingManager`).
    fn bucket_and_waitlist_seqs_waiting(
        &mut self,
        running: Vec<Sequence>,
        waiting: Backer,
        discrete: bool,
    ) -> BucketedSeqs<Backer>;
}
/// Bucket identity: `(adapters, sequence length, has images && is prompt, token offset)`.
/// Sequences with equal keys are placed in the same bucket.
type BucketKey = (Option<Vec<String>>, usize, bool, usize);
/// Stateless bucketing strategy used for `DefaultSchedulerMethod::Fixed`.
struct FixedBucketingManager;
impl<Backer: FcfsBacker> BucketingManager<Backer> for FixedBucketingManager {
    /// Groups `running` by [`BucketKey`] and keeps exactly one bucket running.
    ///
    /// Sequences are keyed by `(get_adapters(), len(), images().is_some() && is_prompt(),
    /// token_offset())`.
    /// - With zero or one bucket, every sequence stays running.
    /// - Otherwise one bucket is selected. When `discrete` is true, it is a bucket with the
    ///   smallest length. When `discrete` is false, it is the bucket with the largest summed
    ///   `compute_priority()`. Ties are resolved by `HashMap` iteration order.
    ///
    /// Selected sequences have `reset_urgency()` applied. Every other sequence has
    /// `add_urgency()` applied and is added to `waiting`.
    fn bucket_and_waitlist_seqs_waiting(
        &mut self,
        running: Vec<Sequence>,
        mut waiting: Backer,
        discrete: bool,
    ) -> BucketedSeqs<Backer> {
        let mut seq_buckets: HashMap<BucketKey, Vec<Sequence>> = HashMap::new();
        // Only filled in non-discrete mode, where the selection is priority based.
        let mut seq_priorities: HashMap<BucketKey, f64> = HashMap::new();

        for seq in running {
            // Build the owned key once per sequence instead of once per map access.
            let key: BucketKey = (
                seq.get_adapters(),
                seq.len(),
                seq.images().is_some() && seq.is_prompt(),
                seq.token_offset(),
            );
            if !discrete {
                let priority = seq.compute_priority();
                seq_priorities
                    .entry(key.clone())
                    .and_modify(|total| *total += priority)
                    .or_insert(priority);
            }
            seq_buckets.entry(key).or_default().push(seq);
        }

        if seq_buckets.len() <= 1 {
            // Zero or one bucket: everything can run together.
            let running = seq_buckets
                .into_values()
                .flatten()
                .map(|seq| seq.reset_urgency())
                .collect();
            return BucketedSeqs { running, waiting };
        }

        let selected: Option<BucketKey> = if discrete {
            seq_buckets
                .keys()
                .min_by_key(|(_, len, _, _)| *len)
                .cloned()
        } else {
            // `total_cmp` is a total order, so a NaN priority cannot cause a panic here.
            seq_priorities
                .iter()
                .max_by(|(_, a), (_, b)| a.total_cmp(b))
                .map(|(key, _)| key.clone())
        };
        let selected = selected.expect("No sequence buckets.");

        let running = seq_buckets
            .remove(&selected)
            .expect("selected bucket key is always taken from seq_buckets' keys")
            .into_iter()
            .map(|seq| seq.reset_urgency())
            .collect();
        for seq in seq_buckets.into_values().flatten() {
            waiting.add(seq.add_urgency());
        }
        BucketedSeqs { running, waiting }
    }
}
/// Scheduler that keeps a waitlist (`waiting`) and the current batch (`running`).
///
/// Each step it uses `bucketing_manager` so that only one `BucketKey` group of running
/// sequences is scheduled.
pub struct DefaultScheduler<Backer: FcfsBacker> {
    /// Sequences not in the current batch: newly added ones and ones deferred by bucketing.
    waiting: Backer,
    /// Sequences in the current batch.
    running: Vec<Sequence>,
    /// Admission policy, consulted by `sequence_fits`.
    method: DefaultSchedulerMethod,
    /// Strategy that selects which running sequences execute together each step.
    bucketing_manager: Box<dyn BucketingManager<Backer>>,
}
impl<Backer: FcfsBacker> DefaultScheduler<Backer> {
    /// Creates an empty scheduler. The admission limit and bucketing strategy both come
    /// from `method`.
    pub fn new(method: DefaultSchedulerMethod) -> Self {
        let bucketing_manager: Box<dyn BucketingManager<_>> = match method {
            DefaultSchedulerMethod::Fixed(_) => Box::new(FixedBucketingManager),
        };
        Self {
            running: Vec::new(),
            waiting: Backer::new(),
            method,
            bucketing_manager,
        }
    }

    /// Buckets `running` in discrete mode and returns the selected sequences. Every
    /// sequence that was not selected is added to `self.waiting`.
    fn bucket_and_waitlist_seqs(&mut self, running: Vec<Sequence>) -> Vec<Sequence> {
        let waiting = std::mem::take(&mut self.waiting);
        let BucketedSeqs { running, waiting } = self
            .bucketing_manager
            .bucket_and_waitlist_seqs_waiting(running, waiting, true);
        self.waiting = waiting;
        running
    }

    /// Moves sequences from `waiting` into `running` in ascending-id order for as long as
    /// `sequence_fits` allows.
    ///
    /// Admitted sequences that are still `is_waiting()` are promoted to `RunningPrompt`;
    /// sequences in any other state keep it. Returns the sequences that did not fit,
    /// in ascending-id order.
    fn admit_waiting(&self, mut waiting: Backer, running: &mut Vec<Sequence>) -> Backer {
        waiting.sort_ascending_ids();
        let mut not_admitted = Backer::new();
        for seq in waiting.into_iter() {
            if self.sequence_fits(running, &seq) {
                if seq.is_waiting() {
                    seq.set_state(SequenceState::RunningPrompt);
                }
                running.push(seq);
            } else {
                not_admitted.add(seq);
            }
        }
        not_admitted
    }

    /// Borrows `self.running` as an output batch. Sequences with `is_completion()` go to
    /// `completion` and all others go to `prompt`; relative order is preserved.
    fn split_running(&mut self) -> DefaultSchedulerOutput<'_> {
        let (completion, prompt): (Vec<&mut Sequence>, Vec<&mut Sequence>) = self
            .running
            .iter_mut()
            .partition(|seq| seq.is_completion());
        DefaultSchedulerOutput {
            completion: completion.into(),
            prompt: prompt.into(),
        }
    }

    /// Selects the sequences to execute in the next step.
    ///
    /// First, running sequences for which `is_running()` is false are discarded. Then:
    /// - **Nothing waiting or running:** the output is empty.
    /// - **Nothing running:** waiting sequences are admitted in ascending-id order up to
    ///   the method's limit. A shortest-length bucket of them is scheduled, and each
    ///   sequence is classified as prompt or completion by its state.
    /// - **Nothing waiting:** running sequences are bucketed (shortest length wins) and
    ///   returned as completions. If `TERMINATE_ALL_NEXT_STEP` is set, they are first
    ///   marked `Done(Canceled)` and the flag is cleared.
    /// - **Otherwise:** waiting sequences are admitted in ascending-id order up to the
    ///   limit. The bucket with the highest summed priority is scheduled and split by state.
    ///
    /// Sequences that are not scheduled remain in (or are added to) the waitlist.
    pub fn schedule(&mut self) -> DefaultSchedulerOutput<'_> {
        let mut running: Vec<Sequence> = std::mem::take(&mut self.running)
            .into_iter()
            .filter(|seq| seq.is_running())
            .collect();
        let waiting = std::mem::take(&mut self.waiting);

        match (waiting.len(), running.len()) {
            (0, 0) => {
                self.running = running;
                DefaultSchedulerOutput {
                    prompt: vec![].into(),
                    completion: vec![].into(),
                }
            }
            (_, 0) => {
                // Respect the admission limit here too. Sequences that do not fit stay
                // waitlisted, and so do sequences not in the selected bucket.
                self.waiting = self.admit_waiting(waiting, &mut running);
                self.running = self.bucket_and_waitlist_seqs(running);
                self.split_running()
            }
            (0, _) => {
                self.running = self.bucket_and_waitlist_seqs(running);
                if TERMINATE_ALL_NEXT_STEP.load(Ordering::SeqCst) {
                    self.running
                        .iter_mut()
                        .for_each(|seq| seq.set_state(SequenceState::Done(StopReason::Canceled)));
                    TERMINATE_ALL_NEXT_STEP.store(false, Ordering::SeqCst);
                }
                DefaultSchedulerOutput {
                    prompt: vec![].into(),
                    completion: self.running.iter_mut().collect::<Vec<_>>().into(),
                }
            }
            _ => {
                let new_waiting = self.admit_waiting(waiting, &mut running);
                let BucketedSeqs { running, waiting } = self
                    .bucketing_manager
                    .bucket_and_waitlist_seqs_waiting(running, new_waiting, false);
                self.running = running;
                self.waiting = waiting;
                self.split_running()
            }
        }
    }

    /// Returns whether one more sequence may join `running` under `self.method`.
    fn sequence_fits(&self, running: &[Sequence], _seq: &Sequence) -> bool {
        match &self.method {
            DefaultSchedulerMethod::Fixed(n) => running.len() < n.get(),
        }
    }
}
/// `Scheduler` adapter for the `VecDeque`-backed default scheduler. This scheduler
/// does not use paged attention, so every block-related query reports `None`.
impl Scheduler for DefaultScheduler<VecDeque<Sequence>> {
    fn schedule(&mut self) -> SchedulerOutput<'_> {
        // Method-call resolution prefers the inherent `DefaultScheduler::schedule`,
        // so this does not recurse into the trait method.
        let output = self.schedule();
        SchedulerOutput::DefaultScheduler { output }
    }

    fn add_seq(&mut self, seq: Sequence) {
        if !seq.is_running() {
            self.waiting.push_back(seq);
            return;
        }
        self.running.push(seq);
    }

    fn waiting_len(&self) -> usize {
        self.waiting.len()
    }

    fn running_len(&self) -> usize {
        self.running.len()
    }

    fn free_finished_sequence_groups(&mut self) {}

    fn block_size(&self) -> Option<usize> {
        None
    }

    fn block_tables(&self) -> Option<&BlockTables> {
        None
    }

    fn block_engine(&mut self) -> Option<&mut BlockEngine> {
        None
    }
}