use crate::{
    distributed,
    embedding::bert::BertPipeline,
    pipeline::{
        llg::{constraint_from_llg_grammar, llg_grammar_from_constraint},
        text_models_inputs_processor::PagedAttentionMeta,
        CacheBackendMetadata, CacheInstruction,
    },
    prefix_cacher::PrefixCacheManagerV2,
    response::CompletionChoice,
    scheduler::{Scheduler, SchedulerOutput},
    sequence::{SeqStepType, StopReason},
    CompletionResponse, SchedulerConfig, DEBUG,
};
use interprocess::local_socket::{traits::Listener, ListenerOptions};
use llguidance::ParserFactory;
use logger::IntervalLogger;
use once_cell::sync::Lazy;
use rand::SeedableRng;
use rand_isaac::Isaac64Rng;
use std::{
    collections::HashMap,
    io::{BufWriter, Write},
    ops::Deref,
    sync::{
        atomic::{AtomicBool, Ordering},
        Arc,
    },
    time::{Duration, Instant, SystemTime, UNIX_EPOCH},
};
use tokio::{
    sync::{mpsc::Receiver, Mutex},
    task::JoinHandle,
};

use crate::{
    get_mut_arcmutex, handle_pipeline_forward_error,
    pipeline::Pipeline,
    request::Request,
    response::{ChatCompletionResponse, Choice, ResponseMessage},
    sequence::{SequenceRecognizer, SequenceState},
    Constraint,
};

mod add_request;
mod logger;
mod search_request;

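/// An out-of-band control instruction for a running engine, delivered via
/// `ENGINE_INSTRUCTIONS`.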
pub enum EngineInstruction {
    Terminate,
}

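/// The BERT embedding model used for search requests: either the default
/// Snowflake Arctic Embed L model, or a custom model given by its identifier.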
#[derive(Debug, Default, Clone)]
pub enum BertEmbeddingModel {
    #[default]
    SnowflakeArcticEmbedL,
    Custom(String),
}

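/// Fixed seed for the sampling RNG so engine runs are reproducible.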
const SEED: u64 = 0;
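/// When set, every running sequence is terminated at the start of the next
/// engine step.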
pub static TERMINATE_ALL_NEXT_STEP: AtomicBool = AtomicBool::new(false);

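/// Pending control instructions, keyed by engine id. `Engine::run` checks this
/// map at the top of every loop iteration.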
pub static ENGINE_INSTRUCTIONS: Lazy<std::sync::Mutex<HashMap<usize, Option<EngineInstruction>>>> =
    Lazy::new(|| std::sync::Mutex::new(HashMap::new()));

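/// The core inference loop: receives [`Request`]s over a channel, schedules
/// them, and steps the underlying [`Pipeline`] to produce responses.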
pub struct Engine {
    rx: Arc<Mutex<Receiver<Request>>>,
    pipeline: Arc<Mutex<dyn Pipeline>>,
    bert_pipeline: Arc<Mutex<Option<BertPipeline>>>,
    scheduler: Arc<Mutex<dyn Scheduler>>,
    id: Arc<Mutex<usize>>,
    truncate_sequence: bool,
    no_kv_cache: bool,
    prefix_cacher: Arc<Mutex<PrefixCacheManagerV2>>,
    is_debug: bool,
    disable_eos_stop: bool,
    throughput_logging_enabled: bool,
    logger: IntervalLogger,
    handles: Arc<Mutex<Vec<JoinHandle<()>>>>,
}

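// Abort any background tasks spawned by this engine when it is dropped.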
impl Drop for Engine {
    fn drop(&mut self) {
        for handle in &*get_mut_arcmutex!(self.handles) {
            handle.abort();
        }
    }
}

impl Engine {
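    /// Build a new engine around `pipeline`. KV caching is disabled if the
    /// pipeline's metadata requests it, and prefix caching is disabled when
    /// using PagedAttention, when KV caching is off, or when explicitly
    /// requested.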
    #[allow(clippy::too_many_arguments)]
    pub fn new(
        rx: Receiver<Request>,
        pipeline: Arc<Mutex<dyn Pipeline>>,
        config: SchedulerConfig,
        truncate_sequence: bool,
        mut no_kv_cache: bool,
        mut no_prefix_cache: bool,
        prefix_cache_n: usize,
        disable_eos_stop: bool,
        throughput_logging_enabled: bool,
        search_embedding_model: Option<BertEmbeddingModel>,
    ) -> anyhow::Result<Self> {
        no_kv_cache |= get_mut_arcmutex!(pipeline).get_metadata().no_kv_cache;

        no_prefix_cache = matches!(config, SchedulerConfig::PagedAttentionMeta { .. })
            || no_prefix_cache
            || no_kv_cache;

        let bert_pipeline = match search_embedding_model {
            Some(search_embedding_model) => Some(BertPipeline::new(
                search_embedding_model,
                &get_mut_arcmutex!(pipeline).device(),
            )?),
            None => None,
        };

        Ok(Self {
            rx: Arc::new(Mutex::new(rx)),
            pipeline,
            bert_pipeline: Arc::new(Mutex::new(bert_pipeline)),
            scheduler: config.into_scheduler(),
            id: Arc::new(Mutex::new(0)),
            truncate_sequence,
            no_kv_cache,
            prefix_cacher: Arc::new(Mutex::new(PrefixCacheManagerV2::new(
                prefix_cache_n,
                no_prefix_cache,
            ))),
            is_debug: DEBUG.load(Ordering::Relaxed),
            disable_eos_stop,
            throughput_logging_enabled,
            logger: IntervalLogger::new(Duration::from_secs(5)),
            handles: Arc::new(Mutex::new(Vec::new())),
        })
    }

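    /// Drive the engine: repeatedly drain incoming requests, schedule runnable
    /// sequences, and step the pipeline until a `Terminate` request or
    /// instruction is received.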
    pub async fn run(self: Arc<Self>) {
        if self.throughput_logging_enabled {
            self.logger.enable_logging();
        }

        let rng = Arc::new(std::sync::Mutex::new(Isaac64Rng::seed_from_u64(SEED)));
        let mut last_completion_ids: Vec<usize> = vec![];
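        // Main scheduling loop; `'lp` is broken out of on termination.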
        'lp: loop {
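            // Honor an out-of-band Terminate instruction for this engine's id.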
            if matches!(
                ENGINE_INSTRUCTIONS
                    .lock()
                    .expect("`ENGINE_INSTRUCTIONS` was poisoned")
                    .get(get_mut_arcmutex!(self.id).deref()),
                Some(Some(EngineInstruction::Terminate))
            ) {
                self.replicate_request_to_daemons(&Request::Terminate);
                break 'lp;
            }

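            // Drain all currently queued requests, replicating each one to any
            // distributed daemon workers before handling it locally.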
            while let Ok(request) = get_mut_arcmutex!(self.rx).try_recv() {
                self.replicate_request_to_daemons(&request);
                if matches!(request, Request::Terminate) {
                    break 'lp;
                }
                self.clone().handle_request(request).await;
            }

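            // Propagate a pending terminate-all to daemon workers so every rank
            // stops its sequences on the same step.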
            if TERMINATE_ALL_NEXT_STEP.load(Ordering::SeqCst) {
                self.replicate_request_to_daemons(&Request::TerminateAllSeqsNextStep);
            }

            let run_start = Instant::now();
            let mut scheduler = get_mut_arcmutex!(self.scheduler);
            let scheduled = scheduler.schedule();

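            // Run one step for whichever scheduler backend is configured.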
            match scheduled {
                SchedulerOutput::DefaultScheduler {
                    output: mut scheduled,
                } => {
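                    // Decode step for in-flight sequences. The KV cache is only
                    // copied in when the batch composition changed since the
                    // previous step; otherwise it is already resident.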
                    if !scheduled.completion.is_empty() {
                        let current_completion_ids: Vec<usize> =
                            scheduled.completion.iter().map(|seq| *seq.id()).collect();
                        let res = {
                            let mut pipeline = get_mut_arcmutex!(self.pipeline);
                            let pre_op = if !self.no_kv_cache
                                && last_completion_ids != current_completion_ids
                            {
                                CacheInstruction::In
                            } else {
                                CacheInstruction::Nothing
                            };
                            let post_op = if !self.no_kv_cache {
                                CacheInstruction::Out
                            } else {
                                CacheInstruction::Reset {
                                    load_preallocated_cache: false,
                                    reset_non_granular: false,
                                }
                            };

                            let return_raw_logits = scheduled.completion[0].return_raw_logits;
                            assert!(
                                scheduled
                                    .completion
                                    .iter()
                                    .all(|seq| seq.return_raw_logits == return_raw_logits),
                                "All sequences must either return raw logits, or not."
                            );

                            pipeline
                                .step(
                                    &mut scheduled.completion,
                                    false,
                                    return_raw_logits,
                                    &mut *get_mut_arcmutex!(self.prefix_cacher),
                                    self.disable_eos_stop,
                                    rng.clone(),
                                    CacheBackendMetadata::DefaultInstructions { pre_op, post_op },
                                )
                                .await
                        };

                        handle_pipeline_forward_error!(
                            "completion step",
                            res,
                            &mut scheduled.completion,
                            self.pipeline,
                            'lp,
                            self.prefix_cacher
                        );

                        self.logger.add_tokens_processed(scheduled.completion.len());

                        last_completion_ids = current_completion_ids;
                    }

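                    // Prompt (prefill) step. The cache is restored for
                    // sequences resuming at a nonzero token offset and reset
                    // (with preallocation) for fresh prompts.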
                    if !scheduled.prompt.is_empty() {
                        let prompt_exec_time = {
                            let mut pipeline = get_mut_arcmutex!(self.pipeline);

                            let post_op = if !self.no_kv_cache {
                                CacheInstruction::Out
                            } else {
                                CacheInstruction::Reset {
                                    load_preallocated_cache: false,
                                    reset_non_granular: false,
                                }
                            };

                            let return_raw_logits = scheduled.prompt[0].return_raw_logits;
                            assert!(
                                scheduled
                                    .prompt
                                    .iter()
                                    .all(|seq| seq.return_raw_logits == return_raw_logits),
                                "All sequences must either return raw logits, or not."
                            );

                            let pre_op = if scheduled.prompt[0].token_offset() != 0 {
                                CacheInstruction::In
                            } else {
                                CacheInstruction::Reset {
                                    load_preallocated_cache: true,
                                    reset_non_granular: false,
                                }
                            };

                            pipeline
                                .step(
                                    &mut scheduled.prompt,
                                    true,
                                    return_raw_logits,
                                    &mut *get_mut_arcmutex!(self.prefix_cacher),
                                    self.disable_eos_stop,
                                    rng.clone(),
                                    CacheBackendMetadata::DefaultInstructions { pre_op, post_op },
                                )
                                .await
                        };

                        let prompt_exec_time = handle_pipeline_forward_error!(
                            "prompt step",
                            prompt_exec_time,
                            &mut scheduled.prompt,
                            self.pipeline,
                            'lp,
                            self.prefix_cacher
                        );

                        let total_processed_tokens: usize = scheduled
                            .prompt
                            .iter()
                            .map(|seq| seq.get_toks().len())
                            .sum();
                        self.logger.add_tokens_processed(total_processed_tokens);

                        for seq in scheduled.prompt.iter_mut() {
                            match seq.sequence_stepping_type() {
                                SeqStepType::OneShot => {
                                    seq.set_state(SequenceState::Done(StopReason::GeneratedImage))
                                }
                                SeqStepType::PromptAndDecode => {
                                    seq.set_state(SequenceState::RunningCompletion)
                                }
                            }
                            let now = SystemTime::now()
                                .duration_since(UNIX_EPOCH)
                                .expect("Time travel has occurred!")
                                .as_millis();
                            #[allow(clippy::cast_precision_loss)]
                            let prompt_tok_per_sec =
                                seq.len() as f32 / prompt_exec_time.as_secs_f32();
                            seq.prompt_tok_per_sec = prompt_tok_per_sec;
                            seq.prompt_timestamp = Some(now);
                            seq.total_prompt_time = Some(prompt_exec_time.as_millis());
                        }
                        last_completion_ids = vec![];
                    }

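                    // Per-step debug logging: sequence lengths in the batch and
                    // the elapsed wall-clock time for this step.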
                    if self.is_debug {
                        let ms_from_last_run = run_start.elapsed().as_secs_f64();
                        let total_len = scheduled.prompt.len() + scheduled.completion.len();
                        if total_len > 0 {
                            let prompt_lengths = scheduled
                                .prompt
                                .iter()
                                .map(|seq| seq.len().to_string())
                                .collect::<Vec<_>>()
                                .join(", ");

                            let completion_lengths = scheduled
                                .completion
                                .iter()
                                .map(|seq| seq.len().to_string())
                                .collect::<Vec<_>>()
                                .join(", ");

                            tracing::info!(
                                "Prompt[{}] Completion[{}] - {}ms",
                                prompt_lengths,
                                completion_lengths,
                                ms_from_last_run * 1000.,
                            );
                        }
                    }
                }
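                // PagedAttention scheduler: prompt and decode batches are
                // stepped uniformly, with block-table bookkeeping delegated to
                // the scheduler's block engine.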
                SchedulerOutput::PagedAttention { mut output } => {
                    if !output.scheduled.is_empty() {
                        let is_prompt = get_mut_arcmutex!(output.scheduled[0]).is_prompt();

                        let mut guards = output
                            .scheduled
                            .iter_mut()
                            .map(|seq| seq.lock().unwrap())
                            .collect::<Vec<_>>();

                        let mut guards_mut =
                            guards.iter_mut().map(|seq| &mut **seq).collect::<Vec<_>>();

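                        // Step the pipeline with paged-attention metadata plus
                        // the block copy/swap instructions computed by the
                        // scheduler.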
                        let res = {
                            let mut pipeline = get_mut_arcmutex!(self.pipeline);

                            let block_size = scheduler.block_size().unwrap();

                            let metadata = PagedAttentionMeta {
                                block_size,
                                sliding_window: pipeline.get_metadata().sliding_window,
                                block_engine: scheduler.block_engine().unwrap(),
                            };

                            let return_raw_logits = guards_mut[0].return_raw_logits;
                            assert!(
                                guards_mut
                                    .iter()
                                    .all(|seq| seq.return_raw_logits == return_raw_logits),
                                "All sequences must either return raw logits, or not."
                            );

                            pipeline
                                .step(
                                    &mut guards_mut,
                                    is_prompt,
                                    return_raw_logits,
                                    &mut *get_mut_arcmutex!(self.prefix_cacher),
                                    self.disable_eos_stop,
                                    rng.clone(),
                                    CacheBackendMetadata::PagedAttention {
                                        metadata,
                                        blocks_to_copy: output.blocks_to_copy,
                                        blocks_to_swap_in: output.blocks_to_swap_in,
                                        blocks_to_swap_out: output.blocks_to_swap_out,
                                    },
                                )
                                .await
                        };

                        handle_pipeline_forward_error!(
                            "step",
                            res,
                            &mut guards_mut,
                            self.pipeline,
                            'lp,
                            self.prefix_cacher
                        );

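                        // Prompt sequences contribute their full token count to
                        // throughput; decode sequences contribute one token each.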
                        let total_processed_tokens: usize = guards
                            .iter()
                            .map(|seq| {
                                if seq.is_prompt() {
                                    seq.get_toks().len()
                                } else {
                                    1
                                }
                            })
                            .sum();
                        self.logger.add_tokens_processed(total_processed_tokens);

                        if self.is_debug {
                            let ms_from_last_run = run_start.elapsed().as_secs_f64();
                            let total_len = guards.len();
                            if total_len > 0 {
                                let lengths = guards
                                    .iter()
                                    .map(|seq| seq.len().to_string())
                                    .collect::<Vec<_>>()
                                    .join(", ");

                                let (prompt_lengths, completion_lengths) = if is_prompt {
                                    (lengths, "".to_string())
                                } else {
                                    ("".to_string(), lengths)
                                };

                                tracing::info!(
                                    "Prompt[{}] Completion[{}] - {}ms",
                                    prompt_lengths,
                                    completion_lengths,
                                    ms_from_last_run * 1000.,
                                );
                            }
                        }

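                        // Record prompt timing statistics now that prefill for
                        // these sequences has completed.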
                        if is_prompt {
                            for mut seq in guards {
                                let now = SystemTime::now()
                                    .duration_since(UNIX_EPOCH)
                                    .expect("Time travel has occurred!")
                                    .as_millis();
                                #[allow(clippy::cast_precision_loss)]
                                let prompt_tok_per_sec =
                                    seq.len() as f32 / (now - seq.timestamp()) as f32;
                                seq.prompt_tok_per_sec = prompt_tok_per_sec * 1000.;
                                seq.prompt_timestamp = Some(now);
                                seq.total_prompt_time = Some(now - seq.timestamp());
                            }
                        }
                    }
                }
            }

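            // Release finished sequence groups back to the scheduler.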
            scheduler.free_finished_sequence_groups();
        }
    }

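    /// Build the grammar-constrained recognizer for a sequence, if the request
    /// carries a constraint; returns `SequenceRecognizer::None` otherwise.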
    fn build_sequence_recognizer(
        factory: &Option<Arc<ParserFactory>>,
        constraint: &Constraint,
    ) -> anyhow::Result<SequenceRecognizer> {
        if let Some(grm) = llg_grammar_from_constraint(constraint)? {
            let factory = factory
                .as_ref()
                .ok_or_else(|| anyhow::anyhow!("No token environment (llg_factory) found."))?;
            let llg = constraint_from_llg_grammar(factory, grm)?;
            Ok(SequenceRecognizer::Llguidance(Box::new(llg)))
        } else {
            Ok(SequenceRecognizer::None)
        }
    }

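    /// On the main NCCL rank, forward `request` over local IPC sockets to each
    /// daemon worker so all ranks process the same request stream.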
    fn replicate_request_to_daemons(&self, request: &Request) {
        if !distributed::is_daemon() && mistralrs_quant::distributed::use_nccl() {
            let name = distributed::ipc_name().unwrap();
            let num_workers =
                mistralrs_quant::distributed::get_global_tp_size_from_devices().unwrap() - 1;
            let listener = ListenerOptions::new().name(name).create_sync().unwrap();

            for _ in 0..num_workers {
                let stream = listener.accept().unwrap();
                let mut writer = BufWriter::new(stream);
                let req = format!("{}\n", serde_json::to_string(&request).unwrap());
                writer.write_all(req.as_bytes()).unwrap();
            }
        }
    }
}