use crate::{
    distributed,
    pipeline::{
        llg::{constraint_from_llg_grammar, llg_grammar_from_constraint},
        text_models_inputs_processor::PagedAttentionMeta,
        CacheBackendMetadata, CacheInstruction,
    },
    prefix_cacher::PrefixCacheManagerV2,
    response::CompletionChoice,
    scheduler::{Scheduler, SchedulerOutput},
    search::{self, rag::SearchPipeline},
    sequence::{SeqStepType, StopReason},
    tools, CompletionResponse, SchedulerConfig, DEBUG,
};
use interprocess::local_socket::{traits::Listener, ListenerOptions};
use llguidance::ParserFactory;
pub use logger::IntervalLogger;
use mistralrs_quant::RingConfig;
use rand::SeedableRng;
use rand_isaac::Isaac64Rng;
use serde::{Deserialize, Serialize};
use std::{
    collections::HashMap,
    fmt,
    io::{BufWriter, Write},
    net::TcpListener,
    ops::Deref,
    str::FromStr,
    sync::{
        atomic::{AtomicBool, Ordering},
        Arc, LazyLock,
    },
    time::{Duration, Instant, SystemTime, UNIX_EPOCH},
};
use tokio::{
    select,
    sync::{
        mpsc::{error::TryRecvError, Receiver, Sender},
        Mutex, Notify,
    },
    task::JoinHandle,
};

use crate::{
    get_mut_arcmutex, handle_pipeline_forward_error,
    pipeline::{ModelCategory, Pipeline},
    request::Request,
    response::{ChatCompletionResponse, Choice, ResponseMessage},
    sequence::{SequenceRecognizer, SequenceState},
    Constraint,
};

mod add_request;
mod logger;
mod search_request;

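/// Out-of-band control instructions for a running engine, delivered through
/// [`ENGINE_INSTRUCTIONS`].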
pub enum EngineInstruction {
    Terminate,
}

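/// Embedding model backing the search (RAG) pipeline.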
#[derive(Debug, Default, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
#[serde(rename_all = "snake_case")]
pub enum SearchEmbeddingModel {
    #[default]
    #[serde(rename = "embedding_gemma")]
    EmbeddingGemma300M,
}

impl SearchEmbeddingModel {
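    /// The Hugging Face model ID for this embedding model.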
    pub fn hf_model_id(&self) -> &'static str {
        match self {
            Self::EmbeddingGemma300M => "google/embeddinggemma-300m",
        }
    }

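    /// Every accepted string form, as parsed by [`FromStr`] and printed by
    /// [`fmt::Display`].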
    pub fn variants() -> &'static [&'static str] {
        &["embedding_gemma"]
    }
}

impl fmt::Display for SearchEmbeddingModel {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        match self {
            Self::EmbeddingGemma300M => f.write_str("embedding_gemma"),
        }
    }
}

impl FromStr for SearchEmbeddingModel {
    type Err = String;

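    /// Parses the snake_case model name. A minimal sketch of expected use
    /// (doctest ignored):
    ///
    /// ```ignore
    /// use std::str::FromStr;
    /// assert_eq!(
    ///     SearchEmbeddingModel::from_str("embedding_gemma"),
    ///     Ok(SearchEmbeddingModel::EmbeddingGemma300M)
    /// );
    /// ```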
    fn from_str(s: &str) -> Result<Self, Self::Err> {
        match s.trim().to_ascii_lowercase().as_str() {
            "embedding_gemma" => Ok(Self::EmbeddingGemma300M),
            other => Err(format!(
                "Unknown search embedding model `{other}`. Supported values: {}",
                Self::variants().join(", ")
            )),
        }
    }
}

const SEED: u64 = 0;
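/// When set, every engine terminates all of its sequences at the next
/// scheduling step.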
pub static TERMINATE_ALL_NEXT_STEP: AtomicBool = AtomicBool::new(false);

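/// Per-engine termination flags, keyed by the thread ID of the engine's event
/// loop.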
static ENGINE_TERMINATE_FLAGS: LazyLock<
    std::sync::Mutex<HashMap<std::thread::ThreadId, Arc<AtomicBool>>>,
> = LazyLock::new(|| std::sync::Mutex::new(HashMap::new()));

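/// Get (creating it on first use) the termination flag for the engine running
/// on the current thread.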
pub fn get_engine_terminate_flag() -> Arc<AtomicBool> {
    let thread_id = std::thread::current().id();
    let mut flags = ENGINE_TERMINATE_FLAGS.lock().unwrap();
    flags
        .entry(thread_id)
        .or_insert_with(|| Arc::new(AtomicBool::new(false)))
        .clone()
}

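/// Whether the engine on the current thread should terminate its sequences:
/// true if either the global flag or this thread's flag is set.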
pub fn should_terminate_engine_sequences() -> bool {
    if TERMINATE_ALL_NEXT_STEP.load(Ordering::SeqCst) {
        return true;
    }
    let thread_id = std::thread::current().id();
    if let Ok(flags) = ENGINE_TERMINATE_FLAGS.lock() {
        if let Some(flag) = flags.get(&thread_id) {
            return flag.load(Ordering::SeqCst);
        }
    }
    false
}

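/// Clear the termination flag for the engine running on the current thread.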
pub fn reset_engine_terminate_flag() {
    let thread_id = std::thread::current().id();
    if let Ok(flags) = ENGINE_TERMINATE_FLAGS.lock() {
        if let Some(flag) = flags.get(&thread_id) {
            flag.store(false, Ordering::SeqCst);
        }
    }
}

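/// Pending control instructions, keyed by engine ID. A sketch of how a
/// terminate request is posted (doctest ignored):
///
/// ```ignore
/// ENGINE_INSTRUCTIONS
///     .lock()
///     .unwrap()
///     .insert(engine_id, Some(EngineInstruction::Terminate));
/// ```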
pub static ENGINE_INSTRUCTIONS: LazyLock<
    std::sync::Mutex<HashMap<usize, Option<EngineInstruction>>>,
> = LazyLock::new(|| std::sync::Mutex::new(HashMap::new()));

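/// Drives generation for a single model: requests arrive over `rx`, are queued
/// into the scheduler, and are stepped through the pipeline until they
/// complete.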
pub struct Engine {
    tx: Sender<Request>,
    rx: Arc<Mutex<Receiver<Request>>>,
    pipeline: Arc<Mutex<dyn Pipeline>>,
    search_pipeline: Arc<Mutex<Option<SearchPipeline>>>,
    search_callback: Option<Arc<search::SearchCallback>>,
    tool_callbacks: tools::ToolCallbacks,
    tool_callbacks_with_tools: tools::ToolCallbacksWithTools,
    scheduler: Arc<Mutex<dyn Scheduler>>,
    id: Arc<Mutex<usize>>,
    no_kv_cache: bool,
    prefix_cacher: Arc<Mutex<PrefixCacheManagerV2>>,
    is_debug: bool,
    disable_eos_stop: bool,
    throughput_logging_enabled: bool,
    logger: IntervalLogger,
    handles: Arc<Mutex<Vec<JoinHandle<()>>>>,
    pending_notify: Arc<Notify>,
}

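// Abort any spawned background tasks so they do not outlive the engine.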
impl Drop for Engine {
    fn drop(&mut self) {
        for handle in &*get_mut_arcmutex!(self.handles) {
            handle.abort();
        }
    }
}

impl Engine {
    #[allow(clippy::too_many_arguments)]
    pub fn new(
        tx: Sender<Request>,
        rx: Receiver<Request>,
        pipeline: Arc<Mutex<dyn Pipeline>>,
        config: SchedulerConfig,
        mut no_kv_cache: bool,
        mut no_prefix_cache: bool,
        prefix_cache_n: usize,
        disable_eos_stop: bool,
        throughput_logging_enabled: bool,
        search_embedding_model: Option<SearchEmbeddingModel>,
        search_callback: Option<Arc<search::SearchCallback>>,
        tool_callbacks: tools::ToolCallbacks,
        tool_callbacks_with_tools: tools::ToolCallbacksWithTools,
    ) -> anyhow::Result<Self> {
        no_kv_cache |= get_mut_arcmutex!(pipeline).get_metadata().no_kv_cache;

        no_prefix_cache = no_prefix_cache
            || no_kv_cache
            || get_mut_arcmutex!(pipeline).get_metadata().no_prefix_cache
            || prefix_cache_n == 0;

        let search_pipeline = match search_embedding_model {
            Some(search_embedding_model) => Some(SearchPipeline::new(
                search_embedding_model,
                &get_mut_arcmutex!(pipeline).device(),
            )?),
            None => None,
        };

        let scheduler = config.into_scheduler();

        get_mut_arcmutex!(scheduler).set_prefix_caching_enabled(!no_prefix_cache);

        let block_engine = get_mut_arcmutex!(scheduler).block_engine();

        Ok(Self {
            tx,
            rx: Arc::new(Mutex::new(rx)),
            pipeline,
            search_pipeline: Arc::new(Mutex::new(search_pipeline)),
            search_callback,
            tool_callbacks,
            tool_callbacks_with_tools,
            scheduler: scheduler.clone(),
            id: Arc::new(Mutex::new(0)),
            no_kv_cache,
            prefix_cacher: Arc::new(Mutex::new(PrefixCacheManagerV2::new(
                prefix_cache_n,
                no_prefix_cache,
                block_engine,
            ))),
            is_debug: DEBUG.load(Ordering::Relaxed),
            disable_eos_stop,
            throughput_logging_enabled,
            logger: IntervalLogger::new(Duration::from_secs(5)),
            handles: Arc::new(Mutex::new(Vec::new())),
            pending_notify: Arc::new(Notify::new()),
        })
    }

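    /// The maximum sequence length supported by the model, or `None` for
    /// categories where it is not meaningful (diffusion, speech).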
    #[allow(dead_code)]
    pub fn max_sequence_length(&self) -> Option<usize> {
        let pipeline = get_mut_arcmutex!(self.pipeline);
        let category = pipeline.category();

        if matches!(category, ModelCategory::Diffusion | ModelCategory::Speech) {
            None
        } else {
            Some(pipeline.get_metadata().max_seq_len)
        }
    }

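    /// The engine's main loop: drain incoming requests, schedule sequences,
    /// and step the pipeline until the request channel closes or a terminate
    /// instruction arrives.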
    pub async fn run(self: Arc<Self>) {
        if self.throughput_logging_enabled {
            self.logger.enable_logging();
        }

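        // Fixed-seed RNG so that sampling is reproducible run-to-run.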
        let rng = Arc::new(std::sync::Mutex::new(Isaac64Rng::seed_from_u64(SEED)));
        let mut last_completion_ids: Vec<usize> = vec![];
        'lp: loop {
            let should_terminate = || {
                matches!(
                    ENGINE_INSTRUCTIONS
                        .lock()
                        .expect("`ENGINE_INSTRUCTIONS` was poisoned")
                        .get(get_mut_arcmutex!(self.id).deref()),
                    Some(Some(EngineInstruction::Terminate))
                )
            };

            if should_terminate() {
                self.replicate_request_to_daemons(&Request::Terminate);
                break 'lp;
            }

            let mut channel_disconnected = false;
            loop {
                let next_request = {
                    let mut rx = self.rx.lock().await;
                    rx.try_recv()
                };

                match next_request {
                    Ok(request) => {
                        self.replicate_request_to_daemons(&request);
                        if matches!(request, Request::Terminate) {
                            break 'lp;
                        }
                        self.clone().handle_request(request).await;
                    }
                    Err(TryRecvError::Empty) => break,
                    Err(TryRecvError::Disconnected) => {
                        channel_disconnected = true;
                        break;
                    }
                }
            }

            if channel_disconnected {
                break 'lp;
            }

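            // If the scheduler is fully idle, park until a new request or a
            // wake notification arrives instead of spinning.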
            let (waiting_len, running_len) = {
                let scheduler = get_mut_arcmutex!(self.scheduler);
                (scheduler.waiting_len(), scheduler.running_len())
            };
            let scheduler_idle = waiting_len == 0 && running_len == 0;

            if scheduler_idle {
                if should_terminate() {
                    self.replicate_request_to_daemons(&Request::Terminate);
                    break 'lp;
                }
                enum WaitEvent {
                    Request(Option<Request>),
                    Wake,
                }
                let wait_for_request = async {
                    let mut rx = self.rx.lock().await;
                    rx.recv().await
                };
                tokio::pin!(wait_for_request);
                let wait_for_wake = self.pending_notify.notified();
                tokio::pin!(wait_for_wake);

                let event = select! {
                    res = &mut wait_for_request => WaitEvent::Request(res),
                    _ = &mut wait_for_wake => WaitEvent::Wake,
                };

                match event {
                    WaitEvent::Request(Some(request)) => {
                        self.replicate_request_to_daemons(&request);
                        if matches!(request, Request::Terminate) {
                            break 'lp;
                        }
                        self.clone().handle_request(request).await;
                        continue;
                    }
                    WaitEvent::Request(None) => break 'lp,
                    WaitEvent::Wake => {
                        continue;
                    }
                }
            }

            if TERMINATE_ALL_NEXT_STEP.load(Ordering::SeqCst) {
                self.replicate_request_to_daemons(&Request::TerminateAllSeqsNextStep);
            }

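            // Ask the scheduler for the next batch and run it through the
            // pipeline.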
            let run_start = Instant::now();
            let mut scheduler = get_mut_arcmutex!(self.scheduler);
            let scheduled = scheduler.schedule(&self.logger);

            match scheduled {
                SchedulerOutput::DefaultScheduler {
                    output: mut scheduled,
                } => {
                    if !scheduled.completion.is_empty() {
                        let current_completion_ids: Vec<usize> =
                            scheduled.completion.iter().map(|seq| *seq.id()).collect();
                        let res = {
                            let mut pipeline = get_mut_arcmutex!(self.pipeline);
                            let pre_op = if !self.no_kv_cache
                                && last_completion_ids != current_completion_ids
                            {
                                CacheInstruction::In
                            } else {
                                CacheInstruction::Nothing
                            };
                            let post_op = if !self.no_kv_cache {
                                CacheInstruction::Out
                            } else {
                                CacheInstruction::Reset {
                                    load_preallocated_cache: false,
                                    reset_non_granular: false,
                                }
                            };

                            let return_raw_logits = scheduled.completion[0].return_raw_logits;
                            assert!(
                                scheduled
                                    .completion
                                    .iter()
                                    .all(|seq| seq.return_raw_logits == return_raw_logits),
                                "All sequences must either return raw logits, or not."
                            );

                            pipeline
                                .step(
                                    &mut scheduled.completion,
                                    false,
                                    return_raw_logits,
                                    &mut *get_mut_arcmutex!(self.prefix_cacher),
                                    self.disable_eos_stop,
                                    rng.clone(),
                                    CacheBackendMetadata::DefaultInstructions { pre_op, post_op },
                                )
                                .await
                        };

                        handle_pipeline_forward_error!(
                            "completion step",
                            res,
                            &mut scheduled.completion,
                            self.pipeline,
                            'lp,
                            self.prefix_cacher
                        );

                        self.logger.add_tokens_processed(scheduled.completion.len());

                        last_completion_ids = current_completion_ids;
                    }

                    if !scheduled.prompt.is_empty() {
                        let prompt_exec_time = {
                            let mut pipeline = get_mut_arcmutex!(self.pipeline);

                            let post_op = if !self.no_kv_cache {
                                CacheInstruction::Out
                            } else {
                                CacheInstruction::Reset {
                                    load_preallocated_cache: false,
                                    reset_non_granular: false,
                                }
                            };

                            let return_raw_logits = scheduled.prompt[0].return_raw_logits;
                            assert!(
                                scheduled
                                    .prompt
                                    .iter()
                                    .all(|seq| seq.return_raw_logits == return_raw_logits),
                                "All sequences must either return raw logits, or not."
                            );

                            // A nonzero token offset means this sequence already
                            // has cached tokens (e.g. chunked prefill), so load
                            // the existing cache; otherwise start from a fresh,
                            // preallocated one.
                            let pre_op = if scheduled.prompt[0].token_offset() != 0 {
                                CacheInstruction::In
                            } else {
                                CacheInstruction::Reset {
                                    load_preallocated_cache: true,
                                    reset_non_granular: false,
                                }
                            };

                            pipeline
                                .step(
                                    &mut scheduled.prompt,
                                    true,
                                    return_raw_logits,
                                    &mut *get_mut_arcmutex!(self.prefix_cacher),
                                    self.disable_eos_stop,
                                    rng.clone(),
                                    CacheBackendMetadata::DefaultInstructions { pre_op, post_op },
                                )
                                .await
                        };

                        let prompt_exec_time = handle_pipeline_forward_error!(
                            "prompt step",
                            prompt_exec_time,
                            &mut scheduled.prompt,
                            self.pipeline,
                            'lp,
                            self.prefix_cacher
                        );

                        let total_processed_tokens: usize = scheduled
                            .prompt
                            .iter()
                            .map(|seq| seq.get_toks().len())
                            .sum();
                        self.logger.add_tokens_processed(total_processed_tokens);

                        for seq in scheduled.prompt.iter_mut() {
                            match seq.sequence_stepping_type() {
                                SeqStepType::OneShot => {
                                    seq.set_state(SequenceState::Done(StopReason::GeneratedImage))
                                }
                                SeqStepType::PromptAndDecode => {
                                    seq.set_state(SequenceState::RunningCompletion)
                                }
                            }
                            let now = SystemTime::now()
                                .duration_since(UNIX_EPOCH)
                                .expect("Time travel has occurred!")
                                .as_millis();
                            #[allow(clippy::cast_precision_loss)]
                            let prompt_tok_per_sec =
                                seq.len() as f32 / prompt_exec_time.as_secs_f32();
                            seq.prompt_tok_per_sec = prompt_tok_per_sec;
                            seq.prompt_timestamp = Some(now);
                            seq.total_prompt_time = Some(prompt_exec_time.as_millis());
                        }
                        last_completion_ids = vec![];
                    }

                    if self.is_debug {
                        let secs_from_last_run = run_start.elapsed().as_secs_f64();
                        let total_len = scheduled.prompt.len() + scheduled.completion.len();
                        if total_len > 0 {
                            let prompt_lengths = scheduled
                                .prompt
                                .iter()
                                .map(|seq| seq.len().to_string())
                                .collect::<Vec<_>>()
                                .join(", ");

                            let completion_lengths = scheduled
                                .completion
                                .iter()
                                .map(|seq| seq.len().to_string())
                                .collect::<Vec<_>>()
                                .join(", ");

                            tracing::info!(
                                "Prompt[{}] Completion[{}] - {}ms",
                                prompt_lengths,
                                completion_lengths,
                                secs_from_last_run * 1000.,
                            );
                        }
                    }
                }
                SchedulerOutput::PagedAttention { mut output } => {
                    if !output.scheduled.is_empty() {
                        let is_prompt = get_mut_arcmutex!(output.scheduled[0]).is_prompt();

                        let mut guards = output
                            .scheduled
                            .iter_mut()
                            .map(|seq| seq.lock().unwrap())
                            .collect::<Vec<_>>();

                        let mut guards_mut =
                            guards.iter_mut().map(|seq| &mut **seq).collect::<Vec<_>>();

                        let res = {
                            let mut pipeline = get_mut_arcmutex!(self.pipeline);

                            let block_size = scheduler.block_size().unwrap();

                            let metadata = PagedAttentionMeta {
                                block_size,
                                sliding_window: pipeline.get_metadata().sliding_window,
                                block_engine: scheduler.block_engine().unwrap(),
                            };

                            let return_raw_logits = guards_mut[0].return_raw_logits;
                            assert!(
                                guards_mut
                                    .iter()
                                    .all(|seq| seq.return_raw_logits == return_raw_logits),
                                "All sequences must either return raw logits, or not."
                            );

                            pipeline
                                .step(
                                    &mut guards_mut,
                                    is_prompt,
                                    return_raw_logits,
                                    &mut *get_mut_arcmutex!(self.prefix_cacher),
                                    self.disable_eos_stop,
                                    rng.clone(),
                                    CacheBackendMetadata::PagedAttention {
                                        metadata,
                                        blocks_to_copy: output.blocks_to_copy,
                                    },
                                )
                                .await
                        };

                        handle_pipeline_forward_error!(
                            "step",
                            res,
                            &mut guards_mut,
                            self.pipeline,
                            'lp,
                            self.prefix_cacher
                        );

                        let total_processed_tokens: usize = guards
                            .iter()
                            .map(|seq| {
                                if seq.is_prompt() {
                                    seq.get_toks().len()
                                } else {
                                    1
                                }
                            })
                            .sum();
                        self.logger.add_tokens_processed(total_processed_tokens);

                        if self.is_debug {
                            let secs_from_last_run = run_start.elapsed().as_secs_f64();
                            let total_len = guards.len();
                            if total_len > 0 {
                                let lengths = guards
                                    .iter()
                                    .map(|seq| seq.len().to_string())
                                    .collect::<Vec<_>>()
                                    .join(", ");

                                let (prompt_lengths, completion_lengths) = if is_prompt {
                                    (lengths, "".to_string())
                                } else {
                                    ("".to_string(), lengths)
                                };

                                tracing::info!(
                                    "Prompt[{}] Completion[{}] - {}ms",
                                    prompt_lengths,
                                    completion_lengths,
                                    secs_from_last_run * 1000.,
                                );
                            }
                        }

                        if is_prompt {
                            for mut seq in guards {
                                let now = SystemTime::now()
                                    .duration_since(UNIX_EPOCH)
                                    .expect("Time travel has occurred!")
                                    .as_millis();
                                #[allow(clippy::cast_precision_loss)]
                                let prompt_tok_per_sec =
                                    seq.len() as f32 / (now - seq.timestamp()) as f32;
                                seq.prompt_tok_per_sec = prompt_tok_per_sec * 1000.;
                                seq.prompt_timestamp = Some(now);
                                seq.total_prompt_time = Some(now - seq.timestamp());
                            }
                        }
                    }
                }
            }

            // Free hybrid (mamba) cache slots held by finished sequences.
            {
                let pipeline = get_mut_arcmutex!(self.pipeline);
                if !pipeline.get_metadata().no_kv_cache && pipeline.cache().is_hybrid() {
                    let mamba_indices = scheduler.get_finished_mamba_indices();
                    if !mamba_indices.is_empty() {
                        let mut hybrid_cache = pipeline.cache().hybrid();
                        for idx in mamba_indices {
                            hybrid_cache.free_seq(idx);
                        }
                    }
                }
            }
            scheduler.free_finished_sequence_groups();
        }
    }

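    /// Build the recognizer that constrains token sampling for a sequence;
    /// yields [`SequenceRecognizer::None`] when the request is unconstrained.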
    fn build_sequence_recognizer(
        factory: &Option<Arc<ParserFactory>>,
        constraint: &Constraint,
    ) -> anyhow::Result<SequenceRecognizer> {
        if let Some(grm) = llg_grammar_from_constraint(constraint)? {
            let factory = factory
                .as_ref()
                .ok_or_else(|| anyhow::anyhow!("No token environment (llg_factory) found."))?;
            let llg = constraint_from_llg_grammar(factory, grm)?;
            Ok(SequenceRecognizer::Llguidance(Box::new(llg)))
        } else {
            Ok(SequenceRecognizer::None)
        }
    }

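    /// On the head node of a distributed setup, replicate the request to each
    /// worker daemon: over a local IPC socket for NCCL, or over TCP for the
    /// `ring` backend.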
    fn replicate_request_to_daemons(&self, request: &Request) {
        if !distributed::is_daemon() && mistralrs_quant::distributed::use_nccl() {
            let name = distributed::ipc_name().unwrap();
            let num_workers =
                mistralrs_quant::distributed::get_global_tp_size_from_devices().unwrap() - 1;
            let listener = ListenerOptions::new().name(name).create_sync().unwrap();

            for _ in 0..num_workers {
                let stream = listener.accept().unwrap();
                let mut writer = BufWriter::new(stream);
                let req = format!("{}\n", serde_json::to_string(&request).unwrap());
                writer.write_all(req.as_bytes()).unwrap();
            }
        } else if !distributed::is_daemon() && cfg!(feature = "ring") {
            let num_workers =
                mistralrs_quant::distributed::get_global_tp_size_from_devices().unwrap() - 1;
            let master_port = RingConfig::load().master_port;
            let listener =
                TcpListener::bind(format!("0.0.0.0:{master_port}")).expect("bind replicator");

            for _ in 0..num_workers {
                let (stream, _) = listener.accept().unwrap();
                let mut writer = BufWriter::new(stream);
                let req = format!("{}\n", serde_json::to_string(&request).unwrap());
                writer.write_all(req.as_bytes()).unwrap();
            }
        }
    }
}