// mistralrs_core/engine/mod.rs

use crate::{
    distributed,
    pipeline::{
        llg::{constraint_from_llg_grammar, llg_grammar_from_constraint},
        text_models_inputs_processor::PagedAttentionMeta,
        CacheBackendMetadata, CacheInstruction,
    },
    prefix_cacher::PrefixCacheManagerV2,
    response::CompletionChoice,
    scheduler::{Scheduler, SchedulerOutput},
    search::{self, rag::SearchPipeline},
    sequence::{SeqStepType, StopReason},
    tools, CompletionResponse, SchedulerConfig, DEBUG,
};
use interprocess::local_socket::{traits::Listener, ListenerOptions};
use llguidance::ParserFactory;
pub use logger::IntervalLogger;
use mistralrs_quant::RingConfig;
use rand::SeedableRng;
use rand_isaac::Isaac64Rng;
use serde::{Deserialize, Serialize};
use std::{
    collections::HashMap,
    fmt,
    io::{BufWriter, Write},
    net::TcpListener,
    ops::Deref,
    str::FromStr,
    sync::{
        atomic::{AtomicBool, Ordering},
        Arc, LazyLock,
    },
    time::{Duration, Instant, SystemTime, UNIX_EPOCH},
};
use tokio::{
    select,
    sync::{
        mpsc::{error::TryRecvError, Receiver, Sender},
        Mutex, Notify,
    },
    task::JoinHandle,
};

use crate::{
    get_mut_arcmutex, handle_pipeline_forward_error,
    pipeline::{ModelCategory, Pipeline},
    request::Request,
    response::{ChatCompletionResponse, Choice, ResponseMessage},
    sequence::{SequenceRecognizer, SequenceState},
    Constraint,
};

mod add_request;
mod logger;
mod search_request;

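/// Out-of-band control instructions that can be delivered to a running engine
/// through [`ENGINE_INSTRUCTIONS`], keyed by engine ID.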
pub enum EngineInstruction {
    Terminate,
}

#[derive(Debug, Default, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
#[serde(rename_all = "snake_case")]
/// Embedding model used for ranking web search results internally.
pub enum SearchEmbeddingModel {
    #[default]
    #[serde(rename = "embedding_gemma")]
    EmbeddingGemma300M,
}

impl SearchEmbeddingModel {
    pub fn hf_model_id(&self) -> &'static str {
        match self {
            Self::EmbeddingGemma300M => "google/embeddinggemma-300m",
        }
    }

    pub fn variants() -> &'static [&'static str] {
        &["embedding_gemma"]
    }
}

impl fmt::Display for SearchEmbeddingModel {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        match self {
            Self::EmbeddingGemma300M => f.write_str("embedding_gemma"),
        }
    }
}

impl FromStr for SearchEmbeddingModel {
    type Err = String;

    fn from_str(s: &str) -> Result<Self, Self::Err> {
        match s.trim().to_ascii_lowercase().as_str() {
            "embedding_gemma" => Ok(Self::EmbeddingGemma300M),
            other => Err(format!(
                "Unknown search embedding model `{other}`. Supported values: {}",
                Self::variants().join(", ")
            )),
        }
    }
}
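
// Round-trip sketch for the parser above (comment-only example; `FromStr` and
// `Display` are both implemented in this module):
//     let model: SearchEmbeddingModel = "embedding_gemma".parse().unwrap();
//     assert_eq!(model.to_string(), "embedding_gemma");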
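/// Fixed seed for the sampling RNG so generation is reproducible across runs.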
const SEED: u64 = 0;
/// Global flag that terminates all sequences in every engine on the next
/// scheduling step (e.g., on Ctrl+C). Be sure to reset it afterwards.
pub static TERMINATE_ALL_NEXT_STEP: AtomicBool = AtomicBool::new(false);

/// Engine-specific termination flags, per Engine thread ID.
static ENGINE_TERMINATE_FLAGS: LazyLock<
    std::sync::Mutex<HashMap<std::thread::ThreadId, Arc<AtomicBool>>>,
> = LazyLock::new(|| std::sync::Mutex::new(HashMap::new()));

/// Get or create a termination flag for the current engine thread.
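///
/// Usage sketch (assumes the caller runs on the engine's own thread; marked
/// `ignore` because the public re-export paths may differ):
/// ```ignore
/// let flag = get_engine_terminate_flag();
/// flag.store(true, std::sync::atomic::Ordering::SeqCst);
/// assert!(should_terminate_engine_sequences());
/// reset_engine_terminate_flag();
/// assert!(!should_terminate_engine_sequences());
/// ```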
pub fn get_engine_terminate_flag() -> Arc<AtomicBool> {
    let thread_id = std::thread::current().id();
    let mut flags = ENGINE_TERMINATE_FLAGS.lock().unwrap();
    flags
        .entry(thread_id)
        .or_insert_with(|| Arc::new(AtomicBool::new(false)))
        .clone()
}

/// Check if the current engine should terminate sequences.
pub fn should_terminate_engine_sequences() -> bool {
    // Check global flag first
    if TERMINATE_ALL_NEXT_STEP.load(Ordering::SeqCst) {
        return true;
    }
    // Then check engine-specific flag
    let thread_id = std::thread::current().id();
    if let Ok(flags) = ENGINE_TERMINATE_FLAGS.lock() {
        if let Some(flag) = flags.get(&thread_id) {
            return flag.load(Ordering::SeqCst);
        }
    }
    false
}

/// Reset termination flags for the current engine.
pub fn reset_engine_terminate_flag() {
    let thread_id = std::thread::current().id();
    if let Ok(flags) = ENGINE_TERMINATE_FLAGS.lock() {
        if let Some(flag) = flags.get(&thread_id) {
            flag.store(false, Ordering::SeqCst);
        }
    }
}

/// Engine instructions, per Engine (MistralRs) ID.
pub static ENGINE_INSTRUCTIONS: LazyLock<
    std::sync::Mutex<HashMap<usize, Option<EngineInstruction>>>,
> = LazyLock::new(|| std::sync::Mutex::new(HashMap::new()));

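// Termination sketch (comment only; `id` is whatever engine ID the surrounding
// runtime assigned):
//     ENGINE_INSTRUCTIONS
//         .lock()
//         .unwrap()
//         .insert(id, Some(EngineInstruction::Terminate));

/// The per-model engine: owns the request channel, the scheduler, the model
/// pipeline, and the prefix cache, and drives the generation loop in
/// [`Engine::run`].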
pub struct Engine {
    tx: Sender<Request>,
    rx: Arc<Mutex<Receiver<Request>>>,
    pipeline: Arc<Mutex<dyn Pipeline>>,
    search_pipeline: Arc<Mutex<Option<SearchPipeline>>>,
    search_callback: Option<Arc<search::SearchCallback>>,
    tool_callbacks: tools::ToolCallbacks,
    tool_callbacks_with_tools: tools::ToolCallbacksWithTools,
    scheduler: Arc<Mutex<dyn Scheduler>>,
    id: Arc<Mutex<usize>>,
    no_kv_cache: bool,
    prefix_cacher: Arc<Mutex<PrefixCacheManagerV2>>,
    is_debug: bool,
    disable_eos_stop: bool,
    throughput_logging_enabled: bool,
    logger: IntervalLogger,
    handles: Arc<Mutex<Vec<JoinHandle<()>>>>,
    pending_notify: Arc<Notify>,
}

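// Abort any background tasks the engine spawned (tracked in `handles`) when it is dropped.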
impl Drop for Engine {
    fn drop(&mut self) {
        for handle in &*get_mut_arcmutex!(self.handles) {
            handle.abort();
        }
    }
}

impl Engine {
    #[allow(clippy::too_many_arguments)]
    pub fn new(
        tx: Sender<Request>,
        rx: Receiver<Request>,
        pipeline: Arc<Mutex<dyn Pipeline>>,
        config: SchedulerConfig,
        mut no_kv_cache: bool,
        mut no_prefix_cache: bool,
        prefix_cache_n: usize,
        disable_eos_stop: bool,
        throughput_logging_enabled: bool,
        search_embedding_model: Option<SearchEmbeddingModel>,
        search_callback: Option<Arc<search::SearchCallback>>,
        tool_callbacks: tools::ToolCallbacks,
        tool_callbacks_with_tools: tools::ToolCallbacksWithTools,
    ) -> anyhow::Result<Self> {
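        // The pipeline's own metadata can force the KV cache off regardless of the caller's flag.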
        no_kv_cache |= get_mut_arcmutex!(pipeline).get_metadata().no_kv_cache;

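        // Prefix caching needs a KV cache, and can also be disabled explicitly,
        // by the pipeline, or by a zero-entry cache budget.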
        no_prefix_cache = no_prefix_cache
            || no_kv_cache
            || get_mut_arcmutex!(pipeline).get_metadata().no_prefix_cache
            || prefix_cache_n == 0;

        let search_pipeline = match search_embedding_model {
            Some(search_embedding_model) => Some(SearchPipeline::new(
                search_embedding_model,
                &get_mut_arcmutex!(pipeline).device(),
            )?),
            None => None,
        };

        let scheduler = config.into_scheduler();
        let block_engine = get_mut_arcmutex!(scheduler).block_engine();

        Ok(Self {
            tx,
            rx: Arc::new(Mutex::new(rx)),
            pipeline,
            search_pipeline: Arc::new(Mutex::new(search_pipeline)),
            search_callback,
            tool_callbacks,
            tool_callbacks_with_tools,
            scheduler: scheduler.clone(),
            id: Arc::new(Mutex::new(0)),
            no_kv_cache,
            prefix_cacher: Arc::new(Mutex::new(PrefixCacheManagerV2::new(
                prefix_cache_n,
                no_prefix_cache,
                block_engine,
            ))),
            is_debug: DEBUG.load(Ordering::Relaxed),
            disable_eos_stop,
            throughput_logging_enabled,
            logger: IntervalLogger::new(Duration::from_secs(5)),
            handles: Arc::new(Mutex::new(Vec::new())),
            pending_notify: Arc::new(Notify::new()),
        })
    }

    /// Returns the maximum supported sequence length for the underlying model, if applicable.
    #[allow(dead_code)]
    pub fn max_sequence_length(&self) -> Option<usize> {
        let pipeline = get_mut_arcmutex!(self.pipeline);
        let category = pipeline.category();

        if matches!(category, ModelCategory::Diffusion | ModelCategory::Speech) {
            None
        } else {
            Some(pipeline.get_metadata().max_seq_len)
        }
    }

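    /// The engine's main loop: drain queued requests, run scheduled prompt and
    /// completion steps on the pipeline, and park on the request channel (or a
    /// wake notification) whenever the scheduler is idle.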
    pub async fn run(self: Arc<Self>) {
        if self.throughput_logging_enabled {
            self.logger.enable_logging();
        }

        let rng = Arc::new(std::sync::Mutex::new(Isaac64Rng::seed_from_u64(SEED)));
        let mut last_completion_ids: Vec<usize> = vec![];
        'lp: loop {
            let should_terminate = || {
                matches!(
                    ENGINE_INSTRUCTIONS
                        .lock()
                        .expect("`ENGINE_INSTRUCTIONS` was poisoned")
                        .get(get_mut_arcmutex!(self.id).deref()),
                    Some(Some(EngineInstruction::Terminate))
                )
            };

            if should_terminate() {
                self.replicate_request_to_daemons(&Request::Terminate);
                break 'lp;
            }

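            // Drain every request that is already queued, without blocking.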
            let mut channel_disconnected = false;
            loop {
                let next_request = {
                    let mut rx = self.rx.lock().await;
                    rx.try_recv()
                };

                match next_request {
                    Ok(request) => {
                        self.replicate_request_to_daemons(&request);
                        if matches!(request, Request::Terminate) {
                            break 'lp;
                        }
                        self.clone().handle_request(request).await;
                    }
                    Err(TryRecvError::Empty) => break,
                    Err(TryRecvError::Disconnected) => {
                        channel_disconnected = true;
                        break;
                    }
                }
            }

            if channel_disconnected {
                break 'lp;
            }

            let (waiting_len, running_len) = {
                let scheduler = get_mut_arcmutex!(self.scheduler);
                (scheduler.waiting_len(), scheduler.running_len())
            };
            let scheduler_idle = waiting_len == 0 && running_len == 0;

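            // Nothing waiting and nothing running: block until either a new
            // request arrives or a pending task wakes the loop.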
            if scheduler_idle {
                if should_terminate() {
                    self.replicate_request_to_daemons(&Request::Terminate);
                    break 'lp;
                }
                enum WaitEvent {
                    Request(Option<Request>),
                    Wake,
                }
                let wait_for_request = async {
                    let mut rx = self.rx.lock().await;
                    rx.recv().await
                };
                tokio::pin!(wait_for_request);
                let wait_for_wake = self.pending_notify.notified();
                tokio::pin!(wait_for_wake);

                let event = select! {
                    res = &mut wait_for_request => WaitEvent::Request(res),
                    _ = &mut wait_for_wake => WaitEvent::Wake,
                };

                match event {
                    WaitEvent::Request(Some(request)) => {
                        self.replicate_request_to_daemons(&request);
                        if matches!(request, Request::Terminate) {
                            break 'lp;
                        }
                        self.clone().handle_request(request).await;
                        continue;
                    }
                    WaitEvent::Request(None) => break 'lp,
                    WaitEvent::Wake => {
                        continue;
                    }
                }
            }

            if TERMINATE_ALL_NEXT_STEP.load(Ordering::SeqCst) {
                self.replicate_request_to_daemons(&Request::TerminateAllSeqsNextStep);
            }

            let run_start = Instant::now();
            let mut scheduler = get_mut_arcmutex!(self.scheduler);
            let scheduled = scheduler.schedule(&self.logger);

            match scheduled {
                SchedulerOutput::DefaultScheduler {
                    output: mut scheduled,
                } => {
                    if !scheduled.completion.is_empty() {
                        let current_completion_ids: Vec<usize> =
                            scheduled.completion.iter().map(|seq| *seq.id()).collect();
                        let res = {
                            let mut pipeline = get_mut_arcmutex!(self.pipeline);
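                            // Swap the KV cache back in only if the running batch
                            // changed since the last completion step.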
                            let pre_op = if !self.no_kv_cache
                                && last_completion_ids != current_completion_ids
                            {
                                CacheInstruction::In
                            } else {
                                CacheInstruction::Nothing
                            };
                            let post_op = if !self.no_kv_cache {
                                CacheInstruction::Out
                            } else {
                                CacheInstruction::Reset {
                                    load_preallocated_cache: false,
                                    reset_non_granular: false,
                                }
                            };

                            let return_raw_logits = scheduled.completion[0].return_raw_logits;
                            assert!(
                                scheduled
                                    .completion
                                    .iter()
                                    .all(|seq| seq.return_raw_logits == return_raw_logits),
                                "All sequences in a batch must agree on whether to return raw logits."
                            );

                            pipeline
                                .step(
                                    &mut scheduled.completion,
                                    false,
                                    return_raw_logits,
                                    &mut *get_mut_arcmutex!(self.prefix_cacher),
                                    self.disable_eos_stop,
                                    rng.clone(),
                                    CacheBackendMetadata::DefaultInstructions { pre_op, post_op },
                                )
                                .await
                        };

                        handle_pipeline_forward_error!(
                            "completion step",
                            res,
                            &mut scheduled.completion,
                            self.pipeline,
                            'lp,
                            self.prefix_cacher
                        );

                        self.logger.add_tokens_processed(scheduled.completion.len());

                        last_completion_ids = current_completion_ids;
                    }

                    if !scheduled.prompt.is_empty() {
                        let prompt_exec_time = {
                            let mut pipeline = get_mut_arcmutex!(self.pipeline);

                            // Run the prompt seqs
                            let post_op = if !self.no_kv_cache {
                                CacheInstruction::Out
                            } else {
                                CacheInstruction::Reset {
                                    load_preallocated_cache: false,
                                    reset_non_granular: false,
                                }
                            };

                            let return_raw_logits = scheduled.prompt[0].return_raw_logits;
                            assert!(
                                scheduled
                                    .prompt
                                    .iter()
                                    .all(|seq| seq.return_raw_logits == return_raw_logits),
                                "All sequences in a batch must agree on whether to return raw logits."
                            );

                            // A nonzero token offset means a prefix-cache hit; the
                            // scheduler guarantees all sequences in the batch share
                            // the same offset.
                            let pre_op = if scheduled.prompt[0].token_offset() != 0 {
                                CacheInstruction::In
                            } else {
                                CacheInstruction::Reset {
                                    load_preallocated_cache: true,
                                    reset_non_granular: false,
                                }
                            };

                            pipeline
                                .step(
                                    &mut scheduled.prompt,
                                    true,
                                    return_raw_logits,
                                    &mut *get_mut_arcmutex!(self.prefix_cacher),
                                    self.disable_eos_stop,
                                    rng.clone(),
                                    CacheBackendMetadata::DefaultInstructions { pre_op, post_op },
                                )
                                .await
                        };

                        let prompt_exec_time = handle_pipeline_forward_error!(
                            "prompt step",
                            prompt_exec_time,
                            &mut scheduled.prompt,
                            self.pipeline,
                            'lp,
                            self.prefix_cacher
                        );

                        let total_processed_tokens: usize = scheduled
                            .prompt
                            .iter()
                            .map(|seq| seq.get_toks().len())
                            .sum();
                        self.logger.add_tokens_processed(total_processed_tokens);

                        for seq in scheduled.prompt.iter_mut() {
                            match seq.sequence_stepping_type() {
                                SeqStepType::OneShot => {
                                    seq.set_state(SequenceState::Done(StopReason::GeneratedImage))
                                }
                                SeqStepType::PromptAndDecode => {
                                    seq.set_state(SequenceState::RunningCompletion)
                                }
                            }
                            let now = SystemTime::now()
                                .duration_since(UNIX_EPOCH)
                                .expect("Time travel has occurred!")
                                .as_millis();
                            #[allow(clippy::cast_precision_loss)]
                            let prompt_tok_per_sec =
                                seq.len() as f32 / prompt_exec_time.as_secs_f32();
                            seq.prompt_tok_per_sec = prompt_tok_per_sec;
                            seq.prompt_timestamp = Some(now);
                            seq.total_prompt_time = Some(prompt_exec_time.as_millis());
                        }
                        last_completion_ids = vec![];
                    }

                    if self.is_debug {
                        let secs_from_last_run = run_start.elapsed().as_secs_f64();
                        let total_len = scheduled.prompt.len() + scheduled.completion.len();
                        if total_len > 0 {
                            let prompt_lengths = scheduled
                                .prompt
                                .iter()
                                .map(|seq| seq.len().to_string())
                                .collect::<Vec<_>>()
                                .join(", ");

                            let completion_lengths = scheduled
                                .completion
                                .iter()
                                .map(|seq| seq.len().to_string())
                                .collect::<Vec<_>>()
                                .join(", ");

                            tracing::info!(
                                "Prompt[{}] Completion[{}] - {}ms",
                                prompt_lengths,
                                completion_lengths,
                                secs_from_last_run * 1000.,
                            );
                        }
                    }
                }
                SchedulerOutput::PagedAttention { mut output } => {
                    if !output.scheduled.is_empty() {
                        let is_prompt = get_mut_arcmutex!(output.scheduled[0]).is_prompt();

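                        // Hold every scheduled sequence's mutex for the whole step.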
                        let mut guards = output
                            .scheduled
                            .iter_mut()
                            .map(|seq| seq.lock().unwrap())
                            .collect::<Vec<_>>();

                        let mut guards_mut =
                            guards.iter_mut().map(|seq| &mut **seq).collect::<Vec<_>>();

                        let res = {
                            let mut pipeline = get_mut_arcmutex!(self.pipeline);

                            let block_size = scheduler.block_size().unwrap();

                            let metadata = PagedAttentionMeta {
                                block_size,
                                sliding_window: pipeline.get_metadata().sliding_window,
                                block_engine: scheduler.block_engine().unwrap(),
                            };

                            let return_raw_logits = guards_mut[0].return_raw_logits;
                            assert!(
                                guards_mut
                                    .iter()
                                    .all(|seq| seq.return_raw_logits == return_raw_logits),
                                "All sequences in a batch must agree on whether to return raw logits."
                            );

                            pipeline
                                .step(
                                    &mut guards_mut,
                                    is_prompt,
                                    return_raw_logits,
                                    &mut *get_mut_arcmutex!(self.prefix_cacher),
                                    self.disable_eos_stop,
                                    rng.clone(),
                                    CacheBackendMetadata::PagedAttention {
                                        metadata,
                                        blocks_to_copy: output.blocks_to_copy,
                                    },
                                )
                                .await
                        };

                        handle_pipeline_forward_error!(
                            "step",
                            res,
                            &mut guards_mut,
                            self.pipeline,
                            'lp,
                            self.prefix_cacher
                        );

                        let total_processed_tokens: usize = guards
                            .iter()
                            .map(|seq| {
                                if seq.is_prompt() {
                                    seq.get_toks().len()
                                } else {
                                    1
                                }
                            })
                            .sum();
                        self.logger.add_tokens_processed(total_processed_tokens);

                        if self.is_debug {
                            let secs_from_last_run = run_start.elapsed().as_secs_f64();
                            let total_len = guards.len();
                            if total_len > 0 {
                                let lengths = guards
                                    .iter()
                                    .map(|seq| seq.len().to_string())
                                    .collect::<Vec<_>>()
                                    .join(", ");

                                let (prompt_lengths, completion_lengths) = if is_prompt {
                                    (lengths, "".to_string())
                                } else {
                                    ("".to_string(), lengths)
                                };

                                tracing::info!(
                                    "Prompt[{}] Completion[{}] - {}ms",
                                    prompt_lengths,
                                    completion_lengths,
                                    secs_from_last_run * 1000.,
                                );
                            }
                        }

                        if is_prompt {
                            for mut seq in guards {
                                let now = SystemTime::now()
                                    .duration_since(UNIX_EPOCH)
                                    .expect("Time travel has occurred!")
                                    .as_millis();
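                                // `now` and `seq.timestamp()` are in ms, so the
                                // ratio below is tokens/ms; scale by 1000 for tokens/s.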
                                #[allow(clippy::cast_precision_loss)]
                                let prompt_tok_per_sec =
                                    seq.len() as f32 / (now - seq.timestamp()) as f32;
                                seq.prompt_tok_per_sec = prompt_tok_per_sec * 1000.;
                                seq.prompt_timestamp = Some(now);
                                seq.total_prompt_time = Some(now - seq.timestamp());
                            }
                        }
                    }
                }
            }

            // Free Mamba state pool slots for finished sequences (hybrid models)
            {
                let pipeline = get_mut_arcmutex!(self.pipeline);
                if pipeline.cache().is_hybrid() {
                    let mamba_indices = scheduler.get_finished_mamba_indices();
                    if !mamba_indices.is_empty() {
                        let mut hybrid_cache = pipeline.cache().hybrid();
                        for idx in mamba_indices {
                            hybrid_cache.free_seq(idx);
                        }
                    }
                }
            }
            scheduler.free_finished_sequence_groups();
        }
    }

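    /// Build the grammar-constrained token recognizer for a sequence, if the
    /// request carries a constraint; otherwise no recognizer is attached.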
    fn build_sequence_recognizer(
        factory: &Option<Arc<ParserFactory>>,
        constraint: &Constraint,
    ) -> anyhow::Result<SequenceRecognizer> {
        if let Some(grm) = llg_grammar_from_constraint(constraint)? {
            let factory = factory
                .as_ref()
                .ok_or_else(|| anyhow::anyhow!("No token environment (llg_factory) found."))?;
            let llg = constraint_from_llg_grammar(factory, grm)?;
            Ok(SequenceRecognizer::Llguidance(Box::new(llg)))
        } else {
            Ok(SequenceRecognizer::None)
        }
    }

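    // In multi-node launches (NCCL or the ring backend), the head process
    // re-broadcasts every request to the daemon workers, over a local socket
    // for NCCL and over TCP for ring, so all ranks see the same stream.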
    fn replicate_request_to_daemons(&self, request: &Request) {
        if !distributed::is_daemon() && mistralrs_quant::distributed::use_nccl() {
            let name = distributed::ipc_name().unwrap();
            let num_workers =
                mistralrs_quant::distributed::get_global_tp_size_from_devices().unwrap() - 1;
            let listener = ListenerOptions::new().name(name).create_sync().unwrap();

            for _ in 0..num_workers {
                let stream = listener.accept().unwrap();
                let mut writer = BufWriter::new(stream);
                let req = format!("{}\n", serde_json::to_string(&request).unwrap());
                writer.write_all(req.as_bytes()).unwrap();
            }
        } else if !distributed::is_daemon() && cfg!(feature = "ring") {
            let num_workers =
                mistralrs_quant::distributed::get_global_tp_size_from_devices().unwrap() - 1;
            let master_port = RingConfig::load().master_port;
            let listener =
                TcpListener::bind(format!("0.0.0.0:{master_port}")).expect("bind replicator");

            for _ in 0..num_workers {
                let (stream, _) = listener.accept().unwrap();
                let mut writer = BufWriter::new(stream);
                let req = format!("{}\n", serde_json::to_string(&request).unwrap());
                writer.write_all(req.as_bytes()).unwrap();
            }
        }
    }
}