// mistralrs_core/dummy_paged_attention/scheduler.rs

type CPUBlockFrom = usize;
type GPUBlockFrom = usize;
type CPUBlockTo = usize;
type GPUBlockTo = usize;
type SrcBlockFrom = usize;
type DstBlocksTo = Vec<usize>;

use std::{
    collections::{HashMap, VecDeque},
    sync::{atomic::Ordering, Arc, Mutex},
};

use tracing::warn;

use crate::{
    get_mut_arcmutex,
    paged_attention::BlockEngine,
    scheduler::{Scheduler, SchedulerOutput},
    sequence::{Sequence, SequenceState, StopReason},
    TERMINATE_ALL_NEXT_STEP,
};

use super::{block_engine::AllocStatus, BlockEngineSequence, BlockTables, CacheConfig};
/// Result of one `PagedAttentionScheduler::schedule` step: the sequences to
/// run, plus the cache-block operations that must be performed beforehand.
pub struct PagedAttentionSchedulerOutput {
    /// Sequences selected to run in this step.
    pub scheduled: Vec<Arc<Mutex<Sequence>>>,
    /// CPU source block -> GPU destination block mappings to swap in.
    pub blocks_to_swap_in: HashMap<CPUBlockFrom, GPUBlockTo>,
    /// GPU source block -> CPU destination block mappings to swap out.
    pub blocks_to_swap_out: HashMap<GPUBlockFrom, CPUBlockTo>,
    /// Source block -> destination blocks that must be copied (block forks
    /// requested by the block engine when appending token slots).
    pub blocks_to_copy: HashMap<SrcBlockFrom, DstBlocksTo>,
}
36
/// Configuration for the paged-attention scheduler.
pub struct PagedAttentionSchedulerConfig {
    /// Upper bound on the number of concurrently running sequences.
    pub max_num_seqs: usize,
}
40
/// FCFS scheduler that manages the sequence queues (waiting / running /
/// swapped-out) and the block engine that owns their KV-cache blocks.
pub struct PagedAttentionScheduler {
    // Sequences not yet admitted; prompts still need block allocation.
    waiting: VecDeque<Arc<Mutex<Sequence>>>,
    // Sequences currently scheduled for execution.
    running: VecDeque<Arc<Mutex<Sequence>>>,
    // Sequences whose blocks were swapped out to CPU memory.
    swapped_out: VecDeque<Arc<Mutex<Sequence>>>,
    config: PagedAttentionSchedulerConfig,
    // Allocator/tracker for GPU and CPU KV-cache blocks.
    pub block_engine: BlockEngine,
    // Tokens per cache block (copied from the cache config).
    block_size: usize,
}
49
50impl PagedAttentionScheduler {
51 pub fn new(config: PagedAttentionSchedulerConfig, cache_config: CacheConfig) -> Self {
52 Self {
53 waiting: VecDeque::new(),
54 running: VecDeque::new(),
55 swapped_out: VecDeque::new(),
56 config,
57 block_engine: BlockEngine::new(
58 cache_config.block_size,
59 cache_config.num_gpu_blocks,
60 cache_config.num_cpu_blocks,
61 ),
62 block_size: cache_config.block_size,
63 }
64 }
65
    /// Run one scheduling step.
    ///
    /// Returns the set of sequences to execute next, plus the block
    /// swap-in/swap-out/copy operations the cache engine must perform first.
    pub fn schedule(&mut self) -> PagedAttentionSchedulerOutput {
        // Phase 1: if nothing is swapped out, try to admit waiting (prompt)
        // sequences. Swapped-out sequences take priority over new prompts.
        if self.swapped_out.is_empty() {
            let mut scheduled = VecDeque::new();
            let mut did_ignore = false;
            while !self.waiting.is_empty() {
                let seq = self.waiting.front().unwrap().clone();

                // Respect the configured cap on running sequences.
                // NOTE(review): this stops once `running.len() + 1` equals
                // `max_num_seqs`, i.e. it caps at `max_num_seqs - 1` running
                // sequences — confirm the off-by-one is intended.
                if self.config.max_num_seqs == self.running.len() + 1 {
                    break;
                }

                // Ask the block engine whether this prompt's blocks fit now.
                let can_allocate = self.block_engine.can_allocate(&*get_mut_arcmutex!(seq));
                match can_allocate {
                    // Not enough free blocks right now; retry on a later step.
                    AllocStatus::Later => break,
                    // The prompt can never fit in the block engine: ignore it.
                    AllocStatus::Impossible => {
                        let id = *get_mut_arcmutex!(seq).id();
                        let len = get_mut_arcmutex!(seq).get_toks().len();
                        warn!(
                            "Sequence {id} with length of {len} tokens is too long and exceeds capacity of block engine. Sequence will be ignored.",
                        );
                        get_mut_arcmutex!(seq).set_state(SequenceState::FinishedIgnored);
                        did_ignore = true;
                    }
                    _ => {}
                }

                if !did_ignore {
                    // Admit: mark as running its prompt and allocate its blocks.
                    get_mut_arcmutex!(seq).set_state(SequenceState::RunningPrompt);
                    let seq_handle = get_mut_arcmutex!(seq);
                    self._allocate(&seq_handle);
                }

                // Move the sequence from waiting to running; ignored sequences
                // also go to `running` so they are reaped by
                // `free_finished_sequence_groups`, but are not scheduled.
                // NOTE(review): `did_ignore` is never reset inside this loop,
                // so any sequence admitted after an ignored one also skips
                // allocation/scheduling — verify this latch is intended.
                let seq = self.waiting.pop_front().unwrap();
                self.running.push_back(seq.clone());
                if !did_ignore {
                    scheduled.push_back(seq);
                }
            }

            // If anything was admitted (or ignored), run a prompt step with no
            // block movement required.
            if !scheduled.is_empty() || did_ignore {
                return PagedAttentionSchedulerOutput {
                    scheduled: scheduled.into(),
                    blocks_to_swap_in: HashMap::new(),
                    blocks_to_copy: HashMap::new(),
                    blocks_to_swap_out: HashMap::new(),
                };
            }
        }

        // Phase 2: completion (decode) step for the running sequences.
        let mut blocks_to_swap_out = HashMap::new();
        let mut blocks_to_swap_in = HashMap::new();
        let mut blocks_to_copy = HashMap::new();

        // Order `running` by timestamp (see sort_running_by_priority_fcfs).
        self.sort_running_by_priority_fcfs();

        // For each running sequence, make sure a slot exists for its next
        // token, preempting other sequences if the block engine is full.
        let mut running = VecDeque::new();
        let mut did_preempt = false;
        while !self.running.is_empty() {
            let seq = self.running.pop_front().unwrap();
            let mut finished_with_break = false;
            while !self
                .block_engine
                .can_append_token_to_seq(&*get_mut_arcmutex!(seq))
            {
                if !self.running.is_empty() {
                    // Preempt the sequence at the back of the sorted queue to
                    // free blocks for `seq`.
                    let seq_to_preempt = self.running.pop_back().unwrap();
                    self._preempt(seq_to_preempt, &mut blocks_to_swap_out);
                    did_preempt = true;
                } else {
                    // Nothing left to evict: preempt `seq` itself and stop.
                    self._preempt(seq.clone(), &mut blocks_to_swap_out);
                    did_preempt = true;
                    finished_with_break = true;
                    break;
                }
            }
            if !finished_with_break {
                {
                    // Reserve the next token slot; record any block copy the
                    // engine requires.
                    let seq_handle = get_mut_arcmutex!(seq);
                    self._append_token_slot_to_seq(&seq_handle, &mut blocks_to_copy);
                }
                running.push_back(seq);
            }
        }
        self.running = running;

        self.sort_swapped_out_by_priority_fcfs();

        // Phase 3: swap sequences back in, but only if nothing was preempted
        // this step (a preemption means memory pressure).
        if !did_preempt {
            while !self.swapped_out.is_empty() {
                let seq = self.swapped_out.front().unwrap();

                // Stop as soon as one sequence no longer fits.
                if !self.block_engine.can_swap_in_seq(&*get_mut_arcmutex!(seq)) {
                    break;
                }

                let seq = self.swapped_out.pop_front().unwrap();
                let to_swap_in = self.block_engine.swap_in(&*get_mut_arcmutex!(seq));
                blocks_to_swap_in.extend(to_swap_in);
                {
                    let seq_handle = get_mut_arcmutex!(seq);
                    self._append_token_slot_to_seq(&seq_handle, &mut blocks_to_copy);
                }
                self.running.push_back(seq);
            }
        }

        // Everything still running is now in the completion (decode) phase.
        self.running
            .iter()
            .for_each(|seq| get_mut_arcmutex!(seq).set_state(SequenceState::RunningCompletion));

        // Global cancellation: mark every running sequence done, then clear
        // the flag so it only applies to this one step.
        if TERMINATE_ALL_NEXT_STEP.load(Ordering::SeqCst) {
            self.running.iter().for_each(|seq| {
                get_mut_arcmutex!(seq).set_state(SequenceState::Done(StopReason::Canceled))
            });
            TERMINATE_ALL_NEXT_STEP.store(false, Ordering::SeqCst);
        }

        PagedAttentionSchedulerOutput {
            scheduled: self.running.clone().into(),
            blocks_to_swap_in,
            blocks_to_copy,
            blocks_to_swap_out,
        }
    }
214
215 pub fn free_finished_sequence_groups(&mut self) {
216 let mut to_free_ids = Vec::new();
217 self.running.retain(|seq| {
218 if get_mut_arcmutex!(seq).is_finished_paged_attn() {
219 to_free_ids.push(get_mut_arcmutex!(seq).get_id());
220 false
221 } else {
222 true
223 }
224 });
225 for id in to_free_ids {
226 self._free(id);
227 }
228 }
229}
230
231impl PagedAttentionScheduler {
232 #[allow(dead_code)]
233 fn remove_seq(&mut self, seq_id: usize) -> Arc<Mutex<Sequence>> {
234 if let Some(idx) = self
236 .waiting
237 .iter()
238 .position(|other| get_mut_arcmutex!(other).get_id() == seq_id)
239 {
240 return self.waiting.remove(idx).unwrap();
241 };
242 if let Some(idx) = self
244 .running
245 .iter()
246 .position(|other| get_mut_arcmutex!(other).get_id() == seq_id)
247 {
248 return self.running.remove(idx).unwrap();
249 };
250 if let Some(idx) = self
252 .swapped_out
253 .iter()
254 .position(|other| get_mut_arcmutex!(other).get_id() == seq_id)
255 {
256 return self.swapped_out.remove(idx).unwrap();
257 };
258 panic!("Attempted to remove sequence id {seq_id} but it is not running, waiting, or swapped out.");
259 }
260
261 fn _append_token_slot_to_seq(
262 &mut self,
263 seq: &Sequence,
264 blocks_to_copy: &mut HashMap<usize, Vec<usize>>,
265 ) {
266 let op = self.block_engine.append_token_slot_to_seq(seq);
267 if let Some((src_block, dst_block)) = op {
268 if let std::collections::hash_map::Entry::Vacant(e) = blocks_to_copy.entry(src_block) {
269 e.insert(vec![dst_block]);
270 } else {
271 blocks_to_copy.get_mut(&src_block).unwrap().push(dst_block);
272 }
273 }
274 }
275
276 fn _abort_seq(&mut self, seq_id: usize) {
277 let removed = self.remove_seq(seq_id);
278 get_mut_arcmutex!(removed).set_state(SequenceState::FinishedAborted);
279 self._free(seq_id);
280 }
281
282 fn _preempt(
284 &mut self,
285 seq: Arc<Mutex<Sequence>>,
286 _blocks_to_swap_out: &mut HashMap<usize, usize>,
287 ) {
288 self._preempt_by_recompute(seq)
289 }
290
291 fn _preempt_by_recompute(&mut self, seq: Arc<Mutex<Sequence>>) {
292 get_mut_arcmutex!(seq).set_state(SequenceState::Waiting);
293 self._free(get_mut_arcmutex!(seq).get_id());
294 self.waiting.push_front(seq);
295 }
296
297 fn _preempt_by_swap(
298 &mut self,
299 seq: Arc<Mutex<Sequence>>,
300 blocks_to_swap_out: &mut HashMap<usize, usize>,
301 ) {
302 if !self.block_engine.can_swap_out_seq(&*get_mut_arcmutex!(seq)) {
303 let id = get_mut_arcmutex!(seq).get_id();
305 self._abort_seq(id);
306 return;
307 }
308 let new_to_swap = self.block_engine.swap_out(&*get_mut_arcmutex!(seq));
309 blocks_to_swap_out.extend(new_to_swap);
310 get_mut_arcmutex!(seq).set_state(SequenceState::Swapped);
311
312 self.swapped_out.push_back(seq);
313 }
314
315 fn _allocate(&mut self, seq: &Sequence) {
316 self.block_engine.allocate(seq)
317 }
318
319 fn _free(&mut self, seq_id: usize) {
320 self.block_engine.free_sequence(seq_id);
321 }
322
323 fn sort_running_by_priority_fcfs(&mut self) {
324 self.running
325 .make_contiguous()
326 .sort_by_key(|seq| get_mut_arcmutex!(seq).timestamp());
327 self.running.make_contiguous().reverse();
328 }
329
330 fn sort_swapped_out_by_priority_fcfs(&mut self) {
331 self.swapped_out
332 .make_contiguous()
333 .sort_by_key(|seq| get_mut_arcmutex!(seq).timestamp());
334 self.swapped_out.make_contiguous().reverse();
335 }
336}
337
338impl Scheduler for PagedAttentionScheduler {
339 fn add_seq(&mut self, seq: Sequence) {
340 self.waiting.push_back(Arc::new(Mutex::new(seq)));
341 }
342 fn schedule(&mut self) -> SchedulerOutput<'_> {
343 SchedulerOutput::PagedAttention {
344 output: self.schedule(),
345 }
346 }
347 fn waiting_len(&self) -> usize {
348 self.waiting.len() + self.swapped_out.len()
349 }
350 fn running_len(&self) -> usize {
351 self.running.len()
352 }
353 fn block_tables(&self) -> Option<&BlockTables> {
354 Some(&self.block_engine.block_tables)
355 }
356 fn block_size(&self) -> Option<usize> {
357 Some(self.block_size)
358 }
359 fn free_finished_sequence_groups(&mut self) {
360 self.free_finished_sequence_groups()
361 }
362 fn block_engine(&mut self) -> Option<&mut BlockEngine> {
363 Some(&mut self.block_engine)
364 }
365}