use crate::{
    distributed,
    pipeline::{
        llg::{constraint_from_llg_grammar, llg_grammar_from_constraint},
        text_models_inputs_processor::PagedAttentionMeta,
        CacheBackendMetadata, CacheInstruction,
    },
    prefix_cacher::PrefixCacheManagerV2,
    response::CompletionChoice,
    scheduler::{Scheduler, SchedulerOutput},
    search::{self, rag::SearchPipeline},
    sequence::{SeqStepType, StopReason},
    tools, CompletionResponse, SchedulerConfig, DEBUG,
};
use interprocess::local_socket::{traits::Listener, ListenerOptions};
use llguidance::ParserFactory;
pub use logger::IntervalLogger;
use mistralrs_quant::RingConfig;
use rand::SeedableRng;
use rand_isaac::Isaac64Rng;
use serde::{Deserialize, Serialize};
use std::{
    collections::HashMap,
    fmt,
    io::{BufWriter, Write},
    net::TcpListener,
    ops::Deref,
    str::FromStr,
    sync::{
        atomic::{AtomicBool, Ordering},
        Arc, LazyLock,
    },
    time::{Duration, Instant, SystemTime, UNIX_EPOCH},
};
use tokio::{
    select,
    sync::{
        mpsc::{error::TryRecvError, Receiver, Sender},
        Mutex, Notify,
    },
    task::JoinHandle,
};

use crate::{
    get_mut_arcmutex, handle_pipeline_forward_error,
    pipeline::{ModelCategory, Pipeline},
    request::Request,
    response::{ChatCompletionResponse, Choice, ResponseMessage},
    sequence::{SequenceRecognizer, SequenceState},
    Constraint,
};

mod add_request;
mod logger;
mod search_request;

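/// Out-of-band instructions that can be queued for a running engine via
/// [`ENGINE_INSTRUCTIONS`], keyed by engine id.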
pub enum EngineInstruction {
    Terminate,
}

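/// Embedding models supported for web-search RAG. Variants parse from and
/// display as their snake_case names (see the [`FromStr`] and [`fmt::Display`]
/// impls below).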
#[derive(Debug, Default, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
#[serde(rename_all = "snake_case")]
pub enum SearchEmbeddingModel {
    #[default]
    #[serde(rename = "embedding_gemma")]
    EmbeddingGemma300M,
}

impl SearchEmbeddingModel {
    pub fn hf_model_id(&self) -> &'static str {
        match self {
            Self::EmbeddingGemma300M => "google/embeddinggemma-300m",
        }
    }

    pub fn variants() -> &'static [&'static str] {
        &["embedding_gemma"]
    }
}

impl fmt::Display for SearchEmbeddingModel {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        match self {
            Self::EmbeddingGemma300M => f.write_str("embedding_gemma"),
        }
    }
}

impl FromStr for SearchEmbeddingModel {
    type Err = String;

    fn from_str(s: &str) -> Result<Self, Self::Err> {
        match s.trim().to_ascii_lowercase().as_str() {
            "embedding_gemma" => Ok(Self::EmbeddingGemma300M),
            other => Err(format!(
                "Unknown search embedding model `{other}`. Supported values: {}",
                Self::variants().join(", ")
            )),
        }
    }
}

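/// Fixed seed for the sampling RNG so engine runs are reproducible.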
const SEED: u64 = 0;
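/// When set, every engine terminates its running sequences at the next step.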
pub static TERMINATE_ALL_NEXT_STEP: AtomicBool = AtomicBool::new(false);

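/// Per-engine termination flags, keyed by the engine thread's id.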
static ENGINE_TERMINATE_FLAGS: LazyLock<
    std::sync::Mutex<HashMap<std::thread::ThreadId, Arc<AtomicBool>>>,
> = LazyLock::new(|| std::sync::Mutex::new(HashMap::new()));

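/// Get (or lazily create) the termination flag for the calling engine thread.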
pub fn get_engine_terminate_flag() -> Arc<AtomicBool> {
    let thread_id = std::thread::current().id();
    let mut flags = ENGINE_TERMINATE_FLAGS.lock().unwrap();
    flags
        .entry(thread_id)
        .or_insert_with(|| Arc::new(AtomicBool::new(false)))
        .clone()
}

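/// True if either the global [`TERMINATE_ALL_NEXT_STEP`] flag or this engine
/// thread's own termination flag is set.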
pub fn should_terminate_engine_sequences() -> bool {
    if TERMINATE_ALL_NEXT_STEP.load(Ordering::SeqCst) {
        return true;
    }
    let thread_id = std::thread::current().id();
    if let Ok(flags) = ENGINE_TERMINATE_FLAGS.lock() {
        if let Some(flag) = flags.get(&thread_id) {
            return flag.load(Ordering::SeqCst);
        }
    }
    false
}

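/// Clear the calling engine thread's termination flag, if one exists.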
pub fn reset_engine_terminate_flag() {
    let thread_id = std::thread::current().id();
    if let Ok(flags) = ENGINE_TERMINATE_FLAGS.lock() {
        if let Some(flag) = flags.get(&thread_id) {
            flag.store(false, Ordering::SeqCst);
        }
    }
}

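/// Pending [`EngineInstruction`]s, keyed by engine id; polled by the engine
/// loop in [`Engine::run`].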
pub static ENGINE_INSTRUCTIONS: LazyLock<
    std::sync::Mutex<HashMap<usize, Option<EngineInstruction>>>,
> = LazyLock::new(|| std::sync::Mutex::new(HashMap::new()));

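/// The core request-processing loop: receives [`Request`]s over a channel,
/// schedules sequences, and drives the [`Pipeline`] forward step by step.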
pub struct Engine {
    tx: Sender<Request>,
    rx: Arc<Mutex<Receiver<Request>>>,
    pipeline: Arc<Mutex<dyn Pipeline>>,
    search_pipeline: Arc<Mutex<Option<SearchPipeline>>>,
    search_callback: Option<Arc<search::SearchCallback>>,
    tool_callbacks: tools::ToolCallbacks,
    tool_callbacks_with_tools: tools::ToolCallbacksWithTools,
    scheduler: Arc<Mutex<dyn Scheduler>>,
    id: Arc<Mutex<usize>>,
    no_kv_cache: bool,
    prefix_cacher: Arc<Mutex<PrefixCacheManagerV2>>,
    is_debug: bool,
    disable_eos_stop: bool,
    throughput_logging_enabled: bool,
    logger: IntervalLogger,
    handles: Arc<Mutex<Vec<JoinHandle<()>>>>,
    pending_notify: Arc<Notify>,
}

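// Abort any spawned background tasks when the engine is dropped.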
impl Drop for Engine {
    fn drop(&mut self) {
        for handle in &*get_mut_arcmutex!(self.handles) {
            handle.abort();
        }
    }
}

impl Engine {
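    /// Construct an engine. The KV-cache and prefix-cache flags are widened to
    /// respect the pipeline's own metadata (e.g. models without a KV cache).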
    #[allow(clippy::too_many_arguments)]
    pub fn new(
        tx: Sender<Request>,
        rx: Receiver<Request>,
        pipeline: Arc<Mutex<dyn Pipeline>>,
        config: SchedulerConfig,
        mut no_kv_cache: bool,
        mut no_prefix_cache: bool,
        prefix_cache_n: usize,
        disable_eos_stop: bool,
        throughput_logging_enabled: bool,
        search_embedding_model: Option<SearchEmbeddingModel>,
        search_callback: Option<Arc<search::SearchCallback>>,
        tool_callbacks: tools::ToolCallbacks,
        tool_callbacks_with_tools: tools::ToolCallbacksWithTools,
    ) -> anyhow::Result<Self> {
        no_kv_cache |= get_mut_arcmutex!(pipeline).get_metadata().no_kv_cache;

        no_prefix_cache = no_prefix_cache
            || no_kv_cache
            || get_mut_arcmutex!(pipeline).get_metadata().no_prefix_cache
            || prefix_cache_n == 0;

        let search_pipeline = match search_embedding_model {
            Some(search_embedding_model) => Some(SearchPipeline::new(
                search_embedding_model,
                &get_mut_arcmutex!(pipeline).device(),
            )?),
            None => None,
        };

        let scheduler = config.into_scheduler();
        let block_engine = get_mut_arcmutex!(scheduler).block_engine();

        Ok(Self {
            tx,
            rx: Arc::new(Mutex::new(rx)),
            pipeline,
            search_pipeline: Arc::new(Mutex::new(search_pipeline)),
            search_callback,
            tool_callbacks,
            tool_callbacks_with_tools,
            scheduler: scheduler.clone(),
            id: Arc::new(Mutex::new(0)),
            no_kv_cache,
            prefix_cacher: Arc::new(Mutex::new(PrefixCacheManagerV2::new(
                prefix_cache_n,
                no_prefix_cache,
                block_engine,
            ))),
            is_debug: DEBUG.load(Ordering::Relaxed),
            disable_eos_stop,
            throughput_logging_enabled,
            logger: IntervalLogger::new(Duration::from_secs(5)),
            handles: Arc::new(Mutex::new(Vec::new())),
            pending_notify: Arc::new(Notify::new()),
        })
    }

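    /// Maximum sequence length of the underlying model, or `None` for model
    /// categories without a meaningful sequence length (diffusion, speech).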
    #[allow(dead_code)]
    pub fn max_sequence_length(&self) -> Option<usize> {
        let pipeline = get_mut_arcmutex!(self.pipeline);
        let category = pipeline.category();

        if matches!(category, ModelCategory::Diffusion | ModelCategory::Speech) {
            None
        } else {
            Some(pipeline.get_metadata().max_seq_len)
        }
    }

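    /// Main engine loop: drain incoming requests, sleep while idle, then run
    /// one scheduler step per iteration until terminated or the channel closes.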
    pub async fn run(self: Arc<Self>) {
        if self.throughput_logging_enabled {
            self.logger.enable_logging();
        }

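        // Seed the sampling RNG deterministically (see `SEED`).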
        let rng = Arc::new(std::sync::Mutex::new(Isaac64Rng::seed_from_u64(SEED)));
        let mut last_completion_ids: Vec<usize> = vec![];
        'lp: loop {
            let should_terminate = || {
                matches!(
                    ENGINE_INSTRUCTIONS
                        .lock()
                        .expect("`ENGINE_INSTRUCTIONS` was poisoned")
                        .get(get_mut_arcmutex!(self.id).deref()),
                    Some(Some(EngineInstruction::Terminate))
                )
            };

            if should_terminate() {
                self.replicate_request_to_daemons(&Request::Terminate);
                break 'lp;
            }

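            // Drain every request that is already queued, without blocking.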
            let mut channel_disconnected = false;
            loop {
                let next_request = {
                    let mut rx = self.rx.lock().await;
                    rx.try_recv()
                };

                match next_request {
                    Ok(request) => {
                        self.replicate_request_to_daemons(&request);
                        if matches!(request, Request::Terminate) {
                            break 'lp;
                        }
                        self.clone().handle_request(request).await;
                    }
                    Err(TryRecvError::Empty) => break,
                    Err(TryRecvError::Disconnected) => {
                        channel_disconnected = true;
                        break;
                    }
                }
            }

            if channel_disconnected {
                break 'lp;
            }

            let (waiting_len, running_len) = {
                let scheduler = get_mut_arcmutex!(self.scheduler);
                (scheduler.waiting_len(), scheduler.running_len())
            };
            let scheduler_idle = waiting_len == 0 && running_len == 0;

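            // If nothing is waiting or running, park until a new request
            // arrives or a background task signals `pending_notify`.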
            if scheduler_idle {
                if should_terminate() {
                    self.replicate_request_to_daemons(&Request::Terminate);
                    break 'lp;
                }
                enum WaitEvent {
                    Request(Option<Request>),
                    Wake,
                }
                let wait_for_request = async {
                    let mut rx = self.rx.lock().await;
                    rx.recv().await
                };
                tokio::pin!(wait_for_request);
                let wait_for_wake = self.pending_notify.notified();
                tokio::pin!(wait_for_wake);

                let event = select! {
                    res = &mut wait_for_request => WaitEvent::Request(res),
                    _ = &mut wait_for_wake => WaitEvent::Wake,
                };

                match event {
                    WaitEvent::Request(Some(request)) => {
                        self.replicate_request_to_daemons(&request);
                        if matches!(request, Request::Terminate) {
                            break 'lp;
                        }
                        self.clone().handle_request(request).await;
                        continue;
                    }
                    WaitEvent::Request(None) => break 'lp,
                    WaitEvent::Wake => {
                        continue;
                    }
                }
            }

            if TERMINATE_ALL_NEXT_STEP.load(Ordering::SeqCst) {
                self.replicate_request_to_daemons(&Request::TerminateAllSeqsNextStep);
            }

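            // Run one scheduling step and dispatch on the scheduler backend.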
            let run_start = Instant::now();
            let mut scheduler = get_mut_arcmutex!(self.scheduler);
            let scheduled = scheduler.schedule(&self.logger);

            match scheduled {
                SchedulerOutput::DefaultScheduler {
                    output: mut scheduled,
                } => {
                    if !scheduled.completion.is_empty() {
                        let current_completion_ids: Vec<usize> =
                            scheduled.completion.iter().map(|seq| *seq.id()).collect();
                        let res = {
                            let mut pipeline = get_mut_arcmutex!(self.pipeline);
                            let pre_op = if !self.no_kv_cache
                                && last_completion_ids != current_completion_ids
                            {
                                CacheInstruction::In
                            } else {
                                CacheInstruction::Nothing
                            };
                            let post_op = if !self.no_kv_cache {
                                CacheInstruction::Out
                            } else {
                                CacheInstruction::Reset {
                                    load_preallocated_cache: false,
                                    reset_non_granular: false,
                                }
                            };

                            let return_raw_logits = scheduled.completion[0].return_raw_logits;
                            assert!(
                                scheduled
                                    .completion
                                    .iter()
                                    .all(|seq| seq.return_raw_logits == return_raw_logits),
                                "All sequences must either return raw logits, or not."
                            );

                            pipeline
                                .step(
                                    &mut scheduled.completion,
                                    false,
                                    return_raw_logits,
                                    &mut *get_mut_arcmutex!(self.prefix_cacher),
                                    self.disable_eos_stop,
                                    rng.clone(),
                                    CacheBackendMetadata::DefaultInstructions { pre_op, post_op },
                                )
                                .await
                        };

                        handle_pipeline_forward_error!(
                            "completion step",
                            res,
                            &mut scheduled.completion,
                            self.pipeline,
                            'lp,
                            self.prefix_cacher
                        );

                        self.logger.add_tokens_processed(scheduled.completion.len());

                        last_completion_ids = current_completion_ids;
                    }

                    if !scheduled.prompt.is_empty() {
                        let prompt_exec_time = {
                            let mut pipeline = get_mut_arcmutex!(self.pipeline);

                            let post_op = if !self.no_kv_cache {
                                CacheInstruction::Out
                            } else {
                                CacheInstruction::Reset {
                                    load_preallocated_cache: false,
                                    reset_non_granular: false,
                                }
                            };

                            let return_raw_logits = scheduled.prompt[0].return_raw_logits;
                            assert!(
                                scheduled
                                    .prompt
                                    .iter()
                                    .all(|seq| seq.return_raw_logits == return_raw_logits),
                                "All sequences must either return raw logits, or not."
                            );

                            let pre_op = if scheduled.prompt[0].token_offset() != 0 {
                                CacheInstruction::In
                            } else {
                                CacheInstruction::Reset {
                                    load_preallocated_cache: true,
                                    reset_non_granular: false,
                                }
                            };

                            pipeline
                                .step(
                                    &mut scheduled.prompt,
                                    true,
                                    return_raw_logits,
                                    &mut *get_mut_arcmutex!(self.prefix_cacher),
                                    self.disable_eos_stop,
                                    rng.clone(),
                                    CacheBackendMetadata::DefaultInstructions { pre_op, post_op },
                                )
                                .await
                        };

                        let prompt_exec_time = handle_pipeline_forward_error!(
                            "prompt step",
                            prompt_exec_time,
                            &mut scheduled.prompt,
                            self.pipeline,
                            'lp,
                            self.prefix_cacher
                        );

                        let total_processed_tokens: usize = scheduled
                            .prompt
                            .iter()
                            .map(|seq| seq.get_toks().len())
                            .sum();
                        self.logger.add_tokens_processed(total_processed_tokens);

                        for seq in scheduled.prompt.iter_mut() {
                            match seq.sequence_stepping_type() {
                                SeqStepType::OneShot => {
                                    seq.set_state(SequenceState::Done(StopReason::GeneratedImage))
                                }
                                SeqStepType::PromptAndDecode => {
                                    seq.set_state(SequenceState::RunningCompletion)
                                }
                            }
                            let now = SystemTime::now()
                                .duration_since(UNIX_EPOCH)
                                .expect("Time travel has occurred!")
                                .as_millis();
                            #[allow(clippy::cast_precision_loss)]
                            let prompt_tok_per_sec =
                                seq.len() as f32 / prompt_exec_time.as_secs_f32();
                            seq.prompt_tok_per_sec = prompt_tok_per_sec;
                            seq.prompt_timestamp = Some(now);
                            seq.total_prompt_time = Some(prompt_exec_time.as_millis());
                        }
                        last_completion_ids = vec![];
                    }

                    if self.is_debug {
                        let ms_from_last_run = run_start.elapsed().as_secs_f64();
                        let total_len = scheduled.prompt.len() + scheduled.completion.len();
                        if total_len > 0 {
                            let prompt_lengths = scheduled
                                .prompt
                                .iter()
                                .map(|seq| seq.len().to_string())
                                .collect::<Vec<_>>()
                                .join(", ");

                            let completion_lengths = scheduled
                                .completion
                                .iter()
                                .map(|seq| seq.len().to_string())
                                .collect::<Vec<_>>()
                                .join(", ");

                            tracing::info!(
                                "Prompt[{}] Completion[{}] - {}ms",
                                prompt_lengths,
                                completion_lengths,
                                ms_from_last_run * 1000.,
                            );
                        }
                    }
                }
                SchedulerOutput::PagedAttention { mut output } => {
                    if !output.scheduled.is_empty() {
                        let is_prompt = get_mut_arcmutex!(output.scheduled[0]).is_prompt();

                        let mut guards = output
                            .scheduled
                            .iter_mut()
                            .map(|seq| seq.lock().unwrap())
                            .collect::<Vec<_>>();

                        let mut guards_mut =
                            guards.iter_mut().map(|seq| &mut **seq).collect::<Vec<_>>();

                        let res = {
                            let mut pipeline = get_mut_arcmutex!(self.pipeline);

                            let block_size = scheduler.block_size().unwrap();

                            let metadata = PagedAttentionMeta {
                                block_size,
                                sliding_window: pipeline.get_metadata().sliding_window,
                                block_engine: scheduler.block_engine().unwrap(),
                            };

                            let return_raw_logits = guards_mut[0].return_raw_logits;
                            assert!(
                                guards_mut
                                    .iter()
                                    .all(|seq| seq.return_raw_logits == return_raw_logits),
                                "All sequences must either return raw logits, or not."
                            );

                            pipeline
                                .step(
                                    &mut guards_mut,
                                    is_prompt,
                                    return_raw_logits,
                                    &mut *get_mut_arcmutex!(self.prefix_cacher),
                                    self.disable_eos_stop,
                                    rng.clone(),
                                    CacheBackendMetadata::PagedAttention {
                                        metadata,
                                        blocks_to_copy: output.blocks_to_copy,
                                    },
                                )
                                .await
                        };

                        handle_pipeline_forward_error!(
                            "step",
                            res,
                            &mut guards_mut,
                            self.pipeline,
                            'lp,
                            self.prefix_cacher
                        );

                        let total_processed_tokens: usize = guards
                            .iter()
                            .map(|seq| {
                                if seq.is_prompt() {
                                    seq.get_toks().len()
                                } else {
                                    1
                                }
                            })
                            .sum();
                        self.logger.add_tokens_processed(total_processed_tokens);

                        if self.is_debug {
                            let ms_from_last_run = run_start.elapsed().as_secs_f64();
                            let total_len = guards.len();
                            if total_len > 0 {
                                let lengths = guards
                                    .iter()
                                    .map(|seq| seq.len().to_string())
                                    .collect::<Vec<_>>()
                                    .join(", ");

                                let (prompt_lengths, completion_lengths) = if is_prompt {
                                    (lengths, "".to_string())
                                } else {
                                    ("".to_string(), lengths)
                                };

                                tracing::info!(
                                    "Prompt[{}] Completion[{}] - {}ms",
                                    prompt_lengths,
                                    completion_lengths,
                                    ms_from_last_run * 1000.,
                                );
                            }
                        }

                        if is_prompt {
                            for mut seq in guards {
                                let now = SystemTime::now()
                                    .duration_since(UNIX_EPOCH)
                                    .expect("Time travel has occurred!")
                                    .as_millis();
                                #[allow(clippy::cast_precision_loss)]
                                let prompt_tok_per_sec =
                                    seq.len() as f32 / (now - seq.timestamp()) as f32;
                                seq.prompt_tok_per_sec = prompt_tok_per_sec * 1000.;
                                seq.prompt_timestamp = Some(now);
                                seq.total_prompt_time = Some(now - seq.timestamp());
                            }
                        }
                    }
                }
            }

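            // Release hybrid (Mamba) cache slots held by finished sequences.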
            {
                let pipeline = get_mut_arcmutex!(self.pipeline);
                if pipeline.cache().is_hybrid() {
                    let mamba_indices = scheduler.get_finished_mamba_indices();
                    if !mamba_indices.is_empty() {
                        let mut hybrid_cache = pipeline.cache().hybrid();
                        for idx in mamba_indices {
                            hybrid_cache.free_seq(idx);
                        }
                    }
                }
            }
            scheduler.free_finished_sequence_groups();
        }
    }

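    /// Build the grammar-constrained recognizer for a sequence if the request's
    /// [`Constraint`] maps to an llguidance grammar; otherwise use no recognizer.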
    fn build_sequence_recognizer(
        factory: &Option<Arc<ParserFactory>>,
        constraint: &Constraint,
    ) -> anyhow::Result<SequenceRecognizer> {
        if let Some(grm) = llg_grammar_from_constraint(constraint)? {
            let factory = factory
                .as_ref()
                .ok_or_else(|| anyhow::anyhow!("No token environment (llg_factory) found."))?;
            let llg = constraint_from_llg_grammar(factory, grm)?;
            Ok(SequenceRecognizer::Llguidance(Box::new(llg)))
        } else {
            Ok(SequenceRecognizer::None)
        }
    }

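    /// On the distributed master, forward an incoming request to each worker
    /// daemon (local-socket IPC for NCCL, TCP for the ring backend) so that all
    /// ranks observe the same request stream.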
    fn replicate_request_to_daemons(&self, request: &Request) {
        if !distributed::is_daemon() && mistralrs_quant::distributed::use_nccl() {
            let name = distributed::ipc_name().unwrap();
            let num_workers =
                mistralrs_quant::distributed::get_global_tp_size_from_devices().unwrap() - 1;
            let listener = ListenerOptions::new().name(name).create_sync().unwrap();

            for _ in 0..num_workers {
                let stream = listener.accept().unwrap();
                let mut writer = BufWriter::new(stream);
                let req = format!("{}\n", serde_json::to_string(&request).unwrap());
                writer.write_all(req.as_bytes()).unwrap();
            }
        } else if !distributed::is_daemon() && cfg!(feature = "ring") {
            let num_workers =
                mistralrs_quant::distributed::get_global_tp_size_from_devices().unwrap() - 1;
            let master_port = RingConfig::load().master_port;
            let listener =
                TcpListener::bind(format!("0.0.0.0:{master_port}")).expect("bind replicator");

            for _ in 0..num_workers {
                let (stream, _) = listener.accept().unwrap();
                let mut writer = BufWriter::new(stream);
                let req = format!("{}\n", serde_json::to_string(&request).unwrap());
                writer.write_all(req.as_bytes()).unwrap();
            }
        }
    }
}