mistralrs_core/engine/mod.rs

use crate::{
    distributed,
    embedding::bert::BertPipeline,
    pipeline::{
        llg::{constraint_from_llg_grammar, llg_grammar_from_constraint},
        text_models_inputs_processor::PagedAttentionMeta,
        CacheBackendMetadata, CacheInstruction,
    },
    prefix_cacher::PrefixCacheManagerV2,
    response::CompletionChoice,
    scheduler::{Scheduler, SchedulerOutput},
    search,
    sequence::{SeqStepType, StopReason},
    tools, CompletionResponse, SchedulerConfig, DEBUG,
};
use interprocess::local_socket::{traits::Listener, ListenerOptions};
use llguidance::ParserFactory;
pub use logger::IntervalLogger;
use mistralrs_quant::RingConfig;
use once_cell::sync::Lazy;
use rand::SeedableRng;
use rand_isaac::Isaac64Rng;
use std::{
    collections::HashMap,
    io::{BufWriter, Write},
    net::TcpListener,
    ops::Deref,
    sync::{
        atomic::{AtomicBool, Ordering},
        Arc,
    },
    time::{Duration, Instant, SystemTime, UNIX_EPOCH},
};
use tokio::{
    sync::{mpsc::Receiver, Mutex},
    task::JoinHandle,
};

use crate::{
    get_mut_arcmutex, handle_pipeline_forward_error,
    pipeline::Pipeline,
    request::Request,
    response::{ChatCompletionResponse, Choice, ResponseMessage},
    sequence::{SequenceRecognizer, SequenceState},
    Constraint,
};

mod add_request;
mod logger;
mod search_request;

pub enum EngineInstruction {
    Terminate,
}

#[derive(Debug, Default, Clone)]
/// Embedding model used for ranking web search results internally.
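///
/// A minimal selection sketch (the `Custom` string is assumed to be a model
/// ID or local path):
/// ```ignore
/// let default_model = BertEmbeddingModel::default(); // SnowflakeArcticEmbedL
/// let custom_model = BertEmbeddingModel::Custom("my-org/my-embedder".to_string());
/// ```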
pub enum BertEmbeddingModel {
    #[default]
    SnowflakeArcticEmbedL,
    Custom(String),
}

const SEED: u64 = 0;

/// Global flag to terminate all sequences in all engines on the next
/// scheduling step (e.g., on Ctrl+C). Be sure to reset it afterwards.
pub static TERMINATE_ALL_NEXT_STEP: AtomicBool = AtomicBool::new(false);

/// Engine-specific termination flags, per Engine thread ID.
static ENGINE_TERMINATE_FLAGS: Lazy<
    std::sync::Mutex<HashMap<std::thread::ThreadId, Arc<AtomicBool>>>,
> = Lazy::new(|| std::sync::Mutex::new(HashMap::new()));

/// Get or create a termination flag for the current engine thread.
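///
/// A minimal usage sketch, assuming the flag is handed to a controller
/// thread that later requests termination:
/// ```ignore
/// let flag = get_engine_terminate_flag();
/// // From the controller thread:
/// flag.store(true, std::sync::atomic::Ordering::SeqCst);
/// ```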
pub fn get_engine_terminate_flag() -> Arc<AtomicBool> {
    let thread_id = std::thread::current().id();
    let mut flags = ENGINE_TERMINATE_FLAGS.lock().unwrap();
    flags
        .entry(thread_id)
        .or_insert_with(|| Arc::new(AtomicBool::new(false)))
        .clone()
}

/// Check if the current engine should terminate sequences.
pub fn should_terminate_engine_sequences() -> bool {
    // Check global flag first
    if TERMINATE_ALL_NEXT_STEP.load(Ordering::SeqCst) {
        return true;
    }
    // Then check engine-specific flag
    let thread_id = std::thread::current().id();
    if let Ok(flags) = ENGINE_TERMINATE_FLAGS.lock() {
        if let Some(flag) = flags.get(&thread_id) {
            return flag.load(Ordering::SeqCst);
        }
    }
    false
}

/// Reset termination flags for the current engine.
pub fn reset_engine_terminate_flag() {
    let thread_id = std::thread::current().id();
    if let Ok(flags) = ENGINE_TERMINATE_FLAGS.lock() {
        if let Some(flag) = flags.get(&thread_id) {
            flag.store(false, Ordering::SeqCst);
        }
    }
}

/// Engine instructions, per Engine (MistralRs) ID.
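///
/// A hedged sketch of terminating the engine with ID 0 from another thread
/// (IDs are assumed to be assigned by the surrounding runtime):
/// ```ignore
/// ENGINE_INSTRUCTIONS
///     .lock()
///     .expect("`ENGINE_INSTRUCTIONS` was poisoned")
///     .insert(0, Some(EngineInstruction::Terminate));
/// ```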
pub static ENGINE_INSTRUCTIONS: Lazy<std::sync::Mutex<HashMap<usize, Option<EngineInstruction>>>> =
    Lazy::new(|| std::sync::Mutex::new(HashMap::new()));

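/// The engine: owns the request channel, model pipeline, and scheduler, and
/// drives them in a loop (see [`Engine::run`]).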
pub struct Engine {
    rx: Arc<Mutex<Receiver<Request>>>,
    pipeline: Arc<Mutex<dyn Pipeline>>,
    bert_pipeline: Arc<Mutex<Option<BertPipeline>>>,
    search_callback: Option<Arc<search::SearchCallback>>,
    tool_callbacks: tools::ToolCallbacks,
    tool_callbacks_with_tools: tools::ToolCallbacksWithTools,
    scheduler: Arc<Mutex<dyn Scheduler>>,
    id: Arc<Mutex<usize>>,
    truncate_sequence: bool,
    no_kv_cache: bool,
    prefix_cacher: Arc<Mutex<PrefixCacheManagerV2>>,
    is_debug: bool,
    disable_eos_stop: bool,
    throughput_logging_enabled: bool,
    logger: IntervalLogger,
    handles: Arc<Mutex<Vec<JoinHandle<()>>>>,
}

impl Drop for Engine {
    fn drop(&mut self) {
        for handle in &*get_mut_arcmutex!(self.handles) {
            handle.abort();
        }
    }
}

impl Engine {
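    /// Construct a new engine around an existing pipeline. Prefix caching is
    /// disabled when the KV cache is disabled, when the pipeline's metadata
    /// forbids it, or when `prefix_cache_n` is 0.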
    #[allow(clippy::too_many_arguments)]
    pub fn new(
        rx: Receiver<Request>,
        pipeline: Arc<Mutex<dyn Pipeline>>,
        config: SchedulerConfig,
        truncate_sequence: bool,
        mut no_kv_cache: bool,
        mut no_prefix_cache: bool,
        prefix_cache_n: usize,
        disable_eos_stop: bool,
        throughput_logging_enabled: bool,
        search_embedding_model: Option<BertEmbeddingModel>,
        search_callback: Option<Arc<search::SearchCallback>>,
        tool_callbacks: tools::ToolCallbacks,
        tool_callbacks_with_tools: tools::ToolCallbacksWithTools,
    ) -> anyhow::Result<Self> {
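        // The pipeline itself may require running without a KV cache; honor
        // whichever of the caller's flag or the pipeline metadata is set.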
        no_kv_cache |= get_mut_arcmutex!(pipeline).get_metadata().no_kv_cache;

        no_prefix_cache = no_prefix_cache
            || no_kv_cache
            || get_mut_arcmutex!(pipeline).get_metadata().no_prefix_cache
            || prefix_cache_n == 0;

        let bert_pipeline = match search_embedding_model {
            Some(search_embedding_model) => Some(BertPipeline::new(
                search_embedding_model,
                &get_mut_arcmutex!(pipeline).device(),
            )?),
            None => None,
        };

        let scheduler = config.into_scheduler();
        let block_engine = get_mut_arcmutex!(scheduler).block_engine();

        Ok(Self {
            rx: Arc::new(Mutex::new(rx)),
            pipeline,
            bert_pipeline: Arc::new(Mutex::new(bert_pipeline)),
            search_callback,
            tool_callbacks,
            tool_callbacks_with_tools,
            scheduler: scheduler.clone(),
            id: Arc::new(Mutex::new(0)),
            truncate_sequence,
            no_kv_cache,
            prefix_cacher: Arc::new(Mutex::new(PrefixCacheManagerV2::new(
                prefix_cache_n,
                no_prefix_cache,
                block_engine,
            ))),
            is_debug: DEBUG.load(Ordering::Relaxed),
            disable_eos_stop,
            throughput_logging_enabled,
            logger: IntervalLogger::new(Duration::from_secs(5)),
            handles: Arc::new(Mutex::new(Vec::new())),
        })
    }

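    /// Drive the engine: drain incoming requests, schedule sequences, and
    /// step the pipeline until a `Terminate` request or instruction arrives.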
    pub async fn run(self: Arc<Self>) {
        if self.throughput_logging_enabled {
            self.logger.enable_logging();
        }

        let rng = Arc::new(std::sync::Mutex::new(Isaac64Rng::seed_from_u64(SEED)));
        let mut last_completion_ids: Vec<usize> = vec![];
        'lp: loop {
            if matches!(
                ENGINE_INSTRUCTIONS
                    .lock()
                    .expect("`ENGINE_INSTRUCTIONS` was poisoned")
                    .get(get_mut_arcmutex!(self.id).deref()),
                Some(Some(EngineInstruction::Terminate))
            ) {
                self.replicate_request_to_daemons(&Request::Terminate);
                break 'lp;
            }

            while let Ok(request) = get_mut_arcmutex!(self.rx).try_recv() {
                self.replicate_request_to_daemons(&request);
                if matches!(request, Request::Terminate) {
                    break 'lp;
                }
                self.clone().handle_request(request).await;
            }

            if TERMINATE_ALL_NEXT_STEP.load(Ordering::SeqCst) {
                self.replicate_request_to_daemons(&Request::TerminateAllSeqsNextStep);
            }

            let run_start = Instant::now();
            let mut scheduler = get_mut_arcmutex!(self.scheduler);
            let scheduled = scheduler.schedule(&self.logger);

            match scheduled {
                SchedulerOutput::DefaultScheduler {
                    output: mut scheduled,
                } => {
                    if !scheduled.completion.is_empty() {
                        let current_completion_ids: Vec<usize> =
                            scheduled.completion.iter().map(|seq| *seq.id()).collect();
                        let res = {
                            let mut pipeline = get_mut_arcmutex!(self.pipeline);
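                            // Reload the KV cache only when the running batch
                            // changed since the last completion step.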
                            let pre_op = if !self.no_kv_cache
                                && last_completion_ids != current_completion_ids
                            {
                                CacheInstruction::In
                            } else {
                                CacheInstruction::Nothing
                            };
                            let post_op = if !self.no_kv_cache {
                                CacheInstruction::Out
                            } else {
                                CacheInstruction::Reset {
                                    load_preallocated_cache: false,
                                    reset_non_granular: false,
                                }
                            };

                            let return_raw_logits = scheduled.completion[0].return_raw_logits;
                            assert!(
                                scheduled
                                    .completion
                                    .iter()
                                    .all(|seq| seq.return_raw_logits == return_raw_logits),
                                "All sequences must either return raw logits, or not."
                            );

                            pipeline
                                .step(
                                    &mut scheduled.completion,
                                    false,
                                    return_raw_logits,
                                    &mut *get_mut_arcmutex!(self.prefix_cacher),
                                    self.disable_eos_stop,
                                    rng.clone(),
                                    CacheBackendMetadata::DefaultInstructions { pre_op, post_op },
                                )
                                .await
                        };

                        handle_pipeline_forward_error!(
                            "completion step",
                            res,
                            &mut scheduled.completion,
                            self.pipeline,
                            'lp,
                            self.prefix_cacher
                        );

                        self.logger.add_tokens_processed(scheduled.completion.len());

                        last_completion_ids = current_completion_ids;
                    }

                    if !scheduled.prompt.is_empty() {
                        let prompt_exec_time = {
                            let mut pipeline = get_mut_arcmutex!(self.pipeline);

                            // Run the prompt seqs
                            let post_op = if !self.no_kv_cache {
                                CacheInstruction::Out
                            } else {
                                CacheInstruction::Reset {
                                    load_preallocated_cache: false,
                                    reset_non_granular: false,
                                }
                            };

                            let return_raw_logits = scheduled.prompt[0].return_raw_logits;
                            assert!(
                                scheduled
                                    .prompt
                                    .iter()
                                    .all(|seq| seq.return_raw_logits == return_raw_logits),
                                "All sequences must either return raw logits, or not."
                            );

                            // A nonzero token offset comes from prefix caching:
                            // load the cached prefix back in instead of resetting.
                            // The invariant that all token offsets are the same
                            // is upheld by the scheduler.
                            let pre_op = if scheduled.prompt[0].token_offset() != 0 {
                                CacheInstruction::In
                            } else {
                                CacheInstruction::Reset {
                                    load_preallocated_cache: true,
                                    reset_non_granular: false,
                                }
                            };

                            pipeline
                                .step(
                                    &mut scheduled.prompt,
                                    true,
                                    return_raw_logits,
                                    &mut *get_mut_arcmutex!(self.prefix_cacher),
                                    self.disable_eos_stop,
                                    rng.clone(),
                                    CacheBackendMetadata::DefaultInstructions { pre_op, post_op },
                                )
                                .await
                        };

                        let prompt_exec_time = handle_pipeline_forward_error!(
                            "prompt step",
                            prompt_exec_time,
                            &mut scheduled.prompt,
                            self.pipeline,
                            'lp,
                            self.prefix_cacher
                        );

                        let total_processed_tokens: usize = scheduled
                            .prompt
                            .iter()
                            .map(|seq| seq.get_toks().len())
                            .sum();
                        self.logger.add_tokens_processed(total_processed_tokens);

                        for seq in scheduled.prompt.iter_mut() {
                            match seq.sequence_stepping_type() {
                                SeqStepType::OneShot => {
                                    seq.set_state(SequenceState::Done(StopReason::GeneratedImage))
                                }
                                SeqStepType::PromptAndDecode => {
                                    seq.set_state(SequenceState::RunningCompletion)
                                }
                            }
                            let now = SystemTime::now()
                                .duration_since(UNIX_EPOCH)
                                .expect("Time travel has occurred!")
                                .as_millis();
                            #[allow(clippy::cast_precision_loss)]
                            let prompt_tok_per_sec =
                                seq.len() as f32 / prompt_exec_time.as_secs_f32();
                            seq.prompt_tok_per_sec = prompt_tok_per_sec;
                            seq.prompt_timestamp = Some(now);
                            seq.total_prompt_time = Some(prompt_exec_time.as_millis());
                        }
                        last_completion_ids = vec![];
                    }

                    if self.is_debug {
                        let secs_from_last_run = run_start.elapsed().as_secs_f64();
                        let total_len = scheduled.prompt.len() + scheduled.completion.len();
                        if total_len > 0 {
                            let prompt_lengths = scheduled
                                .prompt
                                .iter()
                                .map(|seq| seq.len().to_string())
                                .collect::<Vec<_>>()
                                .join(", ");

                            let completion_lengths = scheduled
                                .completion
                                .iter()
                                .map(|seq| seq.len().to_string())
                                .collect::<Vec<_>>()
                                .join(", ");

                            tracing::info!(
                                "Prompt[{}] Completion[{}] - {}ms",
                                prompt_lengths,
                                completion_lengths,
                                secs_from_last_run * 1000.,
                            );
                        }
                    }
                }
                SchedulerOutput::PagedAttention { mut output } => {
                    if !output.scheduled.is_empty() {
                        let is_prompt = get_mut_arcmutex!(output.scheduled[0]).is_prompt();

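                        // Hold each scheduled sequence's lock for the whole
                        // step so the batch cannot change underneath us.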
                        let mut guards = output
                            .scheduled
                            .iter_mut()
                            .map(|seq| seq.lock().unwrap())
                            .collect::<Vec<_>>();

                        let mut guards_mut =
                            guards.iter_mut().map(|seq| &mut **seq).collect::<Vec<_>>();

                        let res = {
                            let mut pipeline = get_mut_arcmutex!(self.pipeline);

                            let block_size = scheduler.block_size().unwrap();

                            let metadata = PagedAttentionMeta {
                                block_size,
                                sliding_window: pipeline.get_metadata().sliding_window,
                                block_engine: scheduler.block_engine().unwrap(),
                            };

                            let return_raw_logits = guards_mut[0].return_raw_logits;
                            assert!(
                                guards_mut
                                    .iter()
                                    .all(|seq| seq.return_raw_logits == return_raw_logits),
                                "All sequences must either return raw logits, or not."
                            );

                            pipeline
                                .step(
                                    &mut guards_mut,
                                    is_prompt,
                                    return_raw_logits,
                                    &mut *get_mut_arcmutex!(self.prefix_cacher),
                                    self.disable_eos_stop,
                                    rng.clone(),
                                    CacheBackendMetadata::PagedAttention {
                                        metadata,
                                        blocks_to_copy: output.blocks_to_copy,
                                        blocks_to_swap_in: output.blocks_to_swap_in,
                                        blocks_to_swap_out: output.blocks_to_swap_out,
                                    },
                                )
                                .await
                        };

                        handle_pipeline_forward_error!(
                            "step",
                            res,
                            &mut guards_mut,
                            self.pipeline,
                            'lp,
                            self.prefix_cacher
                        );

                        let total_processed_tokens: usize = guards
                            .iter()
                            .map(|seq| {
                                if seq.is_prompt() {
                                    seq.get_toks().len()
                                } else {
                                    1
                                }
                            })
                            .sum();
                        self.logger.add_tokens_processed(total_processed_tokens);

                        if self.is_debug {
                            let secs_from_last_run = run_start.elapsed().as_secs_f64();
                            let total_len = guards.len();
                            if total_len > 0 {
                                let lengths = guards
                                    .iter()
                                    .map(|seq| seq.len().to_string())
                                    .collect::<Vec<_>>()
                                    .join(", ");

                                let (prompt_lengths, completion_lengths) = if is_prompt {
                                    (lengths, "".to_string())
                                } else {
                                    ("".to_string(), lengths)
                                };

                                tracing::info!(
                                    "Prompt[{}] Completion[{}] - {}ms",
                                    prompt_lengths,
                                    completion_lengths,
                                    secs_from_last_run * 1000.,
                                );
                            }
                        }

                        if is_prompt {
                            for mut seq in guards {
                                let now = SystemTime::now()
                                    .duration_since(UNIX_EPOCH)
                                    .expect("Time travel has occurred!")
                                    .as_millis();
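                                // `now` and `timestamp()` are assumed to be in
                                // ms, so this is tokens/ms; the *1000 below
                                // converts it to tokens/sec.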
                                #[allow(clippy::cast_precision_loss)]
                                let prompt_tok_per_sec =
                                    seq.len() as f32 / (now - seq.timestamp()) as f32;
                                seq.prompt_tok_per_sec = prompt_tok_per_sec * 1000.;
                                seq.prompt_timestamp = Some(now);
                                seq.total_prompt_time = Some(now - seq.timestamp());
                            }
                        }
                    }
                }
            }

            scheduler.free_finished_sequence_groups();
        }
    }

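    /// Build a recognizer for a sequence's constraint: an llguidance matcher
    /// when the constraint specifies a grammar, otherwise
    /// `SequenceRecognizer::None`.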
    fn build_sequence_recognizer(
        factory: &Option<Arc<ParserFactory>>,
        constraint: &Constraint,
    ) -> anyhow::Result<SequenceRecognizer> {
        if let Some(grm) = llg_grammar_from_constraint(constraint)? {
            let factory = factory
                .as_ref()
                .ok_or_else(|| anyhow::anyhow!("No token environment (llg_factory) found."))?;
            let llg = constraint_from_llg_grammar(factory, grm)?;
            Ok(SequenceRecognizer::Llguidance(Box::new(llg)))
        } else {
            Ok(SequenceRecognizer::None)
        }
    }

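    /// On multi-node setups, mirror the request to every worker daemon so all
    /// ranks observe the same request stream: over a local IPC socket when
    /// NCCL is in use, or over TCP when the `ring` feature is enabled.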
    fn replicate_request_to_daemons(&self, request: &Request) {
        if !distributed::is_daemon() && mistralrs_quant::distributed::use_nccl() {
            let name = distributed::ipc_name().unwrap();
            let num_workers =
                mistralrs_quant::distributed::get_global_tp_size_from_devices().unwrap() - 1;
            let listener = ListenerOptions::new().name(name).create_sync().unwrap();

            for _ in 0..num_workers {
                let stream = listener.accept().unwrap();
                let mut writer = BufWriter::new(stream);
                let req = format!("{}\n", serde_json::to_string(&request).unwrap());
                writer.write_all(req.as_bytes()).unwrap();
            }
        } else if !distributed::is_daemon() && cfg!(feature = "ring") {
            let num_workers =
                mistralrs_quant::distributed::get_global_tp_size_from_devices().unwrap() - 1;
            let master_port = RingConfig::load().master_port;
            let listener =
                TcpListener::bind(format!("0.0.0.0:{master_port}")).expect("bind replicator");

            for _ in 0..num_workers {
                let (stream, _) = listener.accept().unwrap();
                let mut writer = BufWriter::new(stream);
                let req = format!("{}\n", serde_json::to_string(&request).unwrap());
                writer.write_all(req.as_bytes()).unwrap();
            }
        }
    }
}