mistralrs_core/engine/mod.rs

use crate::{
    distributed,
    embedding::bert::BertPipeline,
    pipeline::{
        llg::{constraint_from_llg_grammar, llg_grammar_from_constraint},
        text_models_inputs_processor::PagedAttentionMeta,
        CacheBackendMetadata, CacheInstruction,
    },
    prefix_cacher::PrefixCacheManagerV2,
    response::CompletionChoice,
    scheduler::{Scheduler, SchedulerOutput},
    search,
    sequence::{SeqStepType, StopReason},
    tools, CompletionResponse, SchedulerConfig, DEBUG,
};
use interprocess::local_socket::{traits::Listener, ListenerOptions};
use llguidance::ParserFactory;
pub use logger::IntervalLogger;
use mistralrs_quant::RingConfig;
use once_cell::sync::Lazy;
use rand::SeedableRng;
use rand_isaac::Isaac64Rng;
use std::{
    collections::HashMap,
    io::{BufWriter, Write},
    net::TcpListener,
    ops::Deref,
    sync::{
        atomic::{AtomicBool, Ordering},
        Arc,
    },
    time::{Duration, Instant, SystemTime, UNIX_EPOCH},
};
use tokio::{
    sync::{
        mpsc::{error::TryRecvError, Receiver},
        Mutex,
    },
    task::JoinHandle,
};

use crate::{
    get_mut_arcmutex, handle_pipeline_forward_error,
    pipeline::Pipeline,
    request::Request,
    response::{ChatCompletionResponse, Choice, ResponseMessage},
    sequence::{SequenceRecognizer, SequenceState},
    Constraint,
};

mod add_request;
mod logger;
mod search_request;

/// Instructions that can be sent to a running engine via [`ENGINE_INSTRUCTIONS`].
pub enum EngineInstruction {
    Terminate,
}

#[derive(Debug, Default, Clone)]
/// Embedding model used for ranking web search results internally.
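///
/// A minimal usage sketch (the custom model ID below is hypothetical):
///
/// ```ignore
/// // Default: the Snowflake Arctic Embed L embedding model.
/// let model = BertEmbeddingModel::default();
/// // Or any custom model, referenced by a model ID string:
/// let custom = BertEmbeddingModel::Custom("my-org/my-embedding-model".to_string());
/// ```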
pub enum BertEmbeddingModel {
    #[default]
    SnowflakeArcticEmbedL,
    Custom(String),
}

/// Fixed seed so the sampling RNG is reproducible across runs.
const SEED: u64 = 0;
/// Terminate all sequences on the next scheduling step. Be sure to reset this.
/// This is a global flag for terminating all engines at once (e.g., Ctrl+C).
pub static TERMINATE_ALL_NEXT_STEP: AtomicBool = AtomicBool::new(false);

/// Engine-specific termination flags, keyed by engine thread ID.
static ENGINE_TERMINATE_FLAGS: Lazy<
    std::sync::Mutex<HashMap<std::thread::ThreadId, Arc<AtomicBool>>>,
> = Lazy::new(|| std::sync::Mutex::new(HashMap::new()));

/// Get or create a termination flag for the current engine thread.
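///
/// The flag is an `Arc`, so it can be shared with another thread that later
/// requests termination; a minimal sketch:
///
/// ```ignore
/// let flag = get_engine_terminate_flag();
/// // Elsewhere, possibly on another thread:
/// flag.store(true, std::sync::atomic::Ordering::SeqCst);
/// ```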
pub fn get_engine_terminate_flag() -> Arc<AtomicBool> {
    let thread_id = std::thread::current().id();
    let mut flags = ENGINE_TERMINATE_FLAGS.lock().unwrap();
    flags
        .entry(thread_id)
        .or_insert_with(|| Arc::new(AtomicBool::new(false)))
        .clone()
}

/// Check if the current engine should terminate sequences.
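///
/// The global `TERMINATE_ALL_NEXT_STEP` flag is checked first, then the
/// per-thread flag; a cooperative check might look like:
///
/// ```ignore
/// if should_terminate_engine_sequences() {
///     // Stop this engine's in-flight sequences on the next step.
/// }
/// ```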
pub fn should_terminate_engine_sequences() -> bool {
    // Check the global flag first
    if TERMINATE_ALL_NEXT_STEP.load(Ordering::SeqCst) {
        return true;
    }
    // Then check the engine-specific flag
    let thread_id = std::thread::current().id();
    if let Ok(flags) = ENGINE_TERMINATE_FLAGS.lock() {
        if let Some(flag) = flags.get(&thread_id) {
            return flag.load(Ordering::SeqCst);
        }
    }
    false
}

/// Reset the termination flag for the current engine.
pub fn reset_engine_terminate_flag() {
    let thread_id = std::thread::current().id();
    if let Ok(flags) = ENGINE_TERMINATE_FLAGS.lock() {
        if let Some(flag) = flags.get(&thread_id) {
            flag.store(false, Ordering::SeqCst);
        }
    }
}

/// Engine instructions, keyed by engine (`MistralRs`) ID.
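///
/// A sketch of asking the engine with ID `0` to shut down, mirroring the
/// lookup done in the engine loop:
///
/// ```ignore
/// ENGINE_INSTRUCTIONS
///     .lock()
///     .expect("`ENGINE_INSTRUCTIONS` was poisoned")
///     .insert(0, Some(EngineInstruction::Terminate));
/// ```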
pub static ENGINE_INSTRUCTIONS: Lazy<std::sync::Mutex<HashMap<usize, Option<EngineInstruction>>>> =
    Lazy::new(|| std::sync::Mutex::new(HashMap::new()));
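/// The engine: receives [`Request`]s over an mpsc channel, hands them to the
/// [`Scheduler`], and drives the [`Pipeline`] one step at a time.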
pub struct Engine {
    rx: Arc<Mutex<Receiver<Request>>>,
    pipeline: Arc<Mutex<dyn Pipeline>>,
    bert_pipeline: Arc<Mutex<Option<BertPipeline>>>,
    search_callback: Option<Arc<search::SearchCallback>>,
    tool_callbacks: tools::ToolCallbacks,
    tool_callbacks_with_tools: tools::ToolCallbacksWithTools,
    scheduler: Arc<Mutex<dyn Scheduler>>,
    id: Arc<Mutex<usize>>,
    no_kv_cache: bool,
    prefix_cacher: Arc<Mutex<PrefixCacheManagerV2>>,
    is_debug: bool,
    disable_eos_stop: bool,
    throughput_logging_enabled: bool,
    logger: IntervalLogger,
    handles: Arc<Mutex<Vec<JoinHandle<()>>>>,
}

impl Drop for Engine {
    fn drop(&mut self) {
        // Abort any background tasks spawned by this engine.
        for handle in &*get_mut_arcmutex!(self.handles) {
            handle.abort();
        }
    }
}

impl Engine {
    #[allow(clippy::too_many_arguments)]
    pub fn new(
        rx: Receiver<Request>,
        pipeline: Arc<Mutex<dyn Pipeline>>,
        config: SchedulerConfig,
        mut no_kv_cache: bool,
        mut no_prefix_cache: bool,
        prefix_cache_n: usize,
        disable_eos_stop: bool,
        throughput_logging_enabled: bool,
        search_embedding_model: Option<BertEmbeddingModel>,
        search_callback: Option<Arc<search::SearchCallback>>,
        tool_callbacks: tools::ToolCallbacks,
        tool_callbacks_with_tools: tools::ToolCallbacksWithTools,
    ) -> anyhow::Result<Self> {
        // The pipeline's own metadata can force these settings off.
        no_kv_cache |= get_mut_arcmutex!(pipeline).get_metadata().no_kv_cache;

        no_prefix_cache = no_prefix_cache
            || no_kv_cache
            || get_mut_arcmutex!(pipeline).get_metadata().no_prefix_cache
            || prefix_cache_n == 0;

        let bert_pipeline = match search_embedding_model {
            Some(search_embedding_model) => Some(BertPipeline::new(
                search_embedding_model,
                &get_mut_arcmutex!(pipeline).device(),
            )?),
            None => None,
        };

        let scheduler = config.into_scheduler();
        let block_engine = get_mut_arcmutex!(scheduler).block_engine();

        Ok(Self {
            rx: Arc::new(Mutex::new(rx)),
            pipeline,
            bert_pipeline: Arc::new(Mutex::new(bert_pipeline)),
            search_callback,
            tool_callbacks,
            tool_callbacks_with_tools,
            scheduler: scheduler.clone(),
            id: Arc::new(Mutex::new(0)),
            no_kv_cache,
            prefix_cacher: Arc::new(Mutex::new(PrefixCacheManagerV2::new(
                prefix_cache_n,
                no_prefix_cache,
                block_engine,
            ))),
            is_debug: DEBUG.load(Ordering::Relaxed),
            disable_eos_stop,
            throughput_logging_enabled,
            logger: IntervalLogger::new(Duration::from_secs(5)),
            handles: Arc::new(Mutex::new(Vec::new())),
        })
    }
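    /// The main engine loop: each iteration drains pending requests, lets the
    /// scheduler pick prompt and completion sequences, and runs one pipeline
    /// step. The loop exits on a terminate instruction or when the request
    /// channel disconnects.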
    pub async fn run(self: Arc<Self>) {
        if self.throughput_logging_enabled {
            self.logger.enable_logging();
        }

        let rng = Arc::new(std::sync::Mutex::new(Isaac64Rng::seed_from_u64(SEED)));
        let mut last_completion_ids: Vec<usize> = vec![];
        'lp: loop {
            let should_terminate = || {
                matches!(
                    ENGINE_INSTRUCTIONS
                        .lock()
                        .expect("`ENGINE_INSTRUCTIONS` was poisoned")
                        .get(get_mut_arcmutex!(self.id).deref()),
                    Some(Some(EngineInstruction::Terminate))
                )
            };

            if should_terminate() {
                self.replicate_request_to_daemons(&Request::Terminate);
                break 'lp;
            }

            // Drain all pending requests without blocking the scheduling step.
            let mut channel_disconnected = false;
            loop {
                let next_request = {
                    let mut rx = self.rx.lock().await;
                    rx.try_recv()
                };

                match next_request {
                    Ok(request) => {
                        self.replicate_request_to_daemons(&request);
                        if matches!(request, Request::Terminate) {
                            break 'lp;
                        }
                        self.clone().handle_request(request).await;
                    }
                    Err(TryRecvError::Empty) => break,
                    Err(TryRecvError::Disconnected) => {
                        channel_disconnected = true;
                        break;
                    }
                }
            }

            // All senders are gone; no further requests can arrive.
            if channel_disconnected {
                break 'lp;
            }
            let scheduler_idle = {
                let scheduler = get_mut_arcmutex!(self.scheduler);
                scheduler.waiting_len() == 0 && scheduler.running_len() == 0
            };

            // Nothing is queued or running: block until the next request arrives
            // rather than spinning on an empty scheduler.
            if scheduler_idle {
                if should_terminate() {
                    self.replicate_request_to_daemons(&Request::Terminate);
                    break 'lp;
                }

                let next_request = {
                    let mut rx = self.rx.lock().await;
                    rx.recv().await
                };

                match next_request {
                    Some(request) => {
                        self.replicate_request_to_daemons(&request);
                        if matches!(request, Request::Terminate) {
                            break 'lp;
                        }
                        self.clone().handle_request(request).await;
                        continue;
                    }
                    None => break 'lp,
                }
            }

            if TERMINATE_ALL_NEXT_STEP.load(Ordering::SeqCst) {
                self.replicate_request_to_daemons(&Request::TerminateAllSeqsNextStep);
            }
            let run_start = Instant::now();
            let mut scheduler = get_mut_arcmutex!(self.scheduler);
            let scheduled = scheduler.schedule(&self.logger);

            match scheduled {
                SchedulerOutput::DefaultScheduler {
                    output: mut scheduled,
                } => {
                    if !scheduled.completion.is_empty() {
                        let current_completion_ids: Vec<usize> =
                            scheduled.completion.iter().map(|seq| *seq.id()).collect();
                        let res = {
                            let mut pipeline = get_mut_arcmutex!(self.pipeline);
                            let pre_op = if !self.no_kv_cache
                                && last_completion_ids != current_completion_ids
                            {
                                CacheInstruction::In
                            } else {
                                CacheInstruction::Nothing
                            };
                            let post_op = if !self.no_kv_cache {
                                CacheInstruction::Out
                            } else {
                                CacheInstruction::Reset {
                                    load_preallocated_cache: false,
                                    reset_non_granular: false,
                                }
                            };
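                            // As chosen above: `In` restores the KV cache before the step when
                            // the running batch changed, `Out` saves it afterwards, and `Reset`
                            // is used when no KV cache is kept.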
                            let return_raw_logits = scheduled.completion[0].return_raw_logits;
                            assert!(
                                scheduled
                                    .completion
                                    .iter()
                                    .all(|seq| seq.return_raw_logits == return_raw_logits),
                                "All sequences must either return raw logits, or not."
                            );

                            pipeline
                                .step(
                                    &mut scheduled.completion,
                                    false,
                                    return_raw_logits,
                                    &mut *get_mut_arcmutex!(self.prefix_cacher),
                                    self.disable_eos_stop,
                                    rng.clone(),
                                    CacheBackendMetadata::DefaultInstructions { pre_op, post_op },
                                )
                                .await
                        };

                        handle_pipeline_forward_error!(
                            "completion step",
                            res,
                            &mut scheduled.completion,
                            self.pipeline,
                            'lp,
                            self.prefix_cacher
                        );

                        // One token is generated per sequence in a completion step.
                        self.logger.add_tokens_processed(scheduled.completion.len());

                        last_completion_ids = current_completion_ids;
                    }

                    if !scheduled.prompt.is_empty() {
                        let prompt_exec_time = {
                            let mut pipeline = get_mut_arcmutex!(self.pipeline);

                            // Run the prompt seqs
                            let post_op = if !self.no_kv_cache {
                                CacheInstruction::Out
                            } else {
                                CacheInstruction::Reset {
                                    load_preallocated_cache: false,
                                    reset_non_granular: false,
                                }
                            };

                            let return_raw_logits = scheduled.prompt[0].return_raw_logits;
                            assert!(
                                scheduled
                                    .prompt
                                    .iter()
                                    .all(|seq| seq.return_raw_logits == return_raw_logits),
                                "All sequences must either return raw logits, or not."
                            );

                            // A nonzero token offset comes from prefix caching.
                            // The invariant that all token offsets are the same is upheld by the scheduler.
                            let pre_op = if scheduled.prompt[0].token_offset() != 0 {
                                CacheInstruction::In
                            } else {
                                CacheInstruction::Reset {
                                    load_preallocated_cache: true,
                                    reset_non_granular: false,
                                }
                            };

                            pipeline
                                .step(
                                    &mut scheduled.prompt,
                                    true,
                                    return_raw_logits,
                                    &mut *get_mut_arcmutex!(self.prefix_cacher),
                                    self.disable_eos_stop,
                                    rng.clone(),
                                    CacheBackendMetadata::DefaultInstructions { pre_op, post_op },
                                )
                                .await
                        };

                        let prompt_exec_time = handle_pipeline_forward_error!(
                            "prompt step",
                            prompt_exec_time,
                            &mut scheduled.prompt,
                            self.pipeline,
                            'lp,
                            self.prefix_cacher
                        );

                        let total_processed_tokens: usize = scheduled
                            .prompt
                            .iter()
                            .map(|seq| seq.get_toks().len())
                            .sum();
                        self.logger.add_tokens_processed(total_processed_tokens);

                        for seq in scheduled.prompt.iter_mut() {
                            match seq.sequence_stepping_type() {
                                SeqStepType::OneShot => {
                                    seq.set_state(SequenceState::Done(StopReason::GeneratedImage))
                                }
                                SeqStepType::PromptAndDecode => {
                                    seq.set_state(SequenceState::RunningCompletion)
                                }
                            }
                            let now = SystemTime::now()
                                .duration_since(UNIX_EPOCH)
                                .expect("Time travel has occurred!")
                                .as_millis();
                            #[allow(clippy::cast_precision_loss)]
                            let prompt_tok_per_sec =
                                seq.len() as f32 / prompt_exec_time.as_secs_f32();
                            seq.prompt_tok_per_sec = prompt_tok_per_sec;
                            seq.prompt_timestamp = Some(now);
                            seq.total_prompt_time = Some(prompt_exec_time.as_millis());
                        }
                        last_completion_ids = vec![];
                    }

                    if self.is_debug {
                        let secs_from_last_run = run_start.elapsed().as_secs_f64();
                        let total_len = scheduled.prompt.len() + scheduled.completion.len();
                        if total_len > 0 {
                            let prompt_lengths = scheduled
                                .prompt
                                .iter()
                                .map(|seq| seq.len().to_string())
                                .collect::<Vec<_>>()
                                .join(", ");

                            let completion_lengths = scheduled
                                .completion
                                .iter()
                                .map(|seq| seq.len().to_string())
                                .collect::<Vec<_>>()
                                .join(", ");

                            tracing::info!(
                                "Prompt[{}] Completion[{}] - {}ms",
                                prompt_lengths,
                                completion_lengths,
                                secs_from_last_run * 1000.,
                            );
                        }
                    }
                }
                SchedulerOutput::PagedAttention { mut output } => {
                    if !output.scheduled.is_empty() {
                        let is_prompt = get_mut_arcmutex!(output.scheduled[0]).is_prompt();

                        let mut guards = output
                            .scheduled
                            .iter_mut()
                            .map(|seq| seq.lock().unwrap())
                            .collect::<Vec<_>>();

                        let mut guards_mut =
                            guards.iter_mut().map(|seq| &mut **seq).collect::<Vec<_>>();

                        let res = {
                            let mut pipeline = get_mut_arcmutex!(self.pipeline);

                            let block_size = scheduler.block_size().unwrap();

                            let metadata = PagedAttentionMeta {
                                block_size,
                                sliding_window: pipeline.get_metadata().sliding_window,
                                block_engine: scheduler.block_engine().unwrap(),
                            };

                            let return_raw_logits = guards_mut[0].return_raw_logits;
                            assert!(
                                guards_mut
                                    .iter()
                                    .all(|seq| seq.return_raw_logits == return_raw_logits),
                                "All sequences must either return raw logits, or not."
                            );

                            pipeline
                                .step(
                                    &mut guards_mut,
                                    is_prompt,
                                    return_raw_logits,
                                    &mut *get_mut_arcmutex!(self.prefix_cacher),
                                    self.disable_eos_stop,
                                    rng.clone(),
                                    CacheBackendMetadata::PagedAttention {
                                        metadata,
                                        blocks_to_copy: output.blocks_to_copy,
                                    },
                                )
                                .await
                        };

                        handle_pipeline_forward_error!(
                            "step",
                            res,
                            &mut guards_mut,
                            self.pipeline,
                            'lp,
                            self.prefix_cacher
                        );

                        // Prompt sequences process all their tokens; completion sequences one each.
                        let total_processed_tokens: usize = guards
                            .iter()
                            .map(|seq| {
                                if seq.is_prompt() {
                                    seq.get_toks().len()
                                } else {
                                    1
                                }
                            })
                            .sum();
                        self.logger.add_tokens_processed(total_processed_tokens);

                        if self.is_debug {
                            let secs_from_last_run = run_start.elapsed().as_secs_f64();
                            let total_len = guards.len();
                            if total_len > 0 {
                                let lengths = guards
                                    .iter()
                                    .map(|seq| seq.len().to_string())
                                    .collect::<Vec<_>>()
                                    .join(", ");

                                let (prompt_lengths, completion_lengths) = if is_prompt {
                                    (lengths, "".to_string())
                                } else {
                                    ("".to_string(), lengths)
                                };

                                tracing::info!(
                                    "Prompt[{}] Completion[{}] - {}ms",
                                    prompt_lengths,
                                    completion_lengths,
                                    secs_from_last_run * 1000.,
                                );
                            }
                        }

                        if is_prompt {
                            for mut seq in guards {
                                let now = SystemTime::now()
                                    .duration_since(UNIX_EPOCH)
                                    .expect("Time travel has occurred!")
                                    .as_millis();
                                // `now - timestamp` is in ms, so scale tokens/ms up to tokens/s.
                                #[allow(clippy::cast_precision_loss)]
                                let prompt_tok_per_sec =
                                    seq.len() as f32 / (now - seq.timestamp()) as f32;
                                seq.prompt_tok_per_sec = prompt_tok_per_sec * 1000.;
                                seq.prompt_timestamp = Some(now);
                                seq.total_prompt_time = Some(now - seq.timestamp());
                            }
                        }
                    }
                }
            }

            scheduler.free_finished_sequence_groups();
        }
    }
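    /// Build the constrained-decoding recognizer for a sequence: an llguidance
    /// matcher when the request carries a grammar constraint, otherwise
    /// [`SequenceRecognizer::None`].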
    fn build_sequence_recognizer(
        factory: &Option<Arc<ParserFactory>>,
        constraint: &Constraint,
    ) -> anyhow::Result<SequenceRecognizer> {
        if let Some(grm) = llg_grammar_from_constraint(constraint)? {
            let factory = factory
                .as_ref()
                .ok_or_else(|| anyhow::anyhow!("No token environment (llg_factory) found."))?;
            let llg = constraint_from_llg_grammar(factory, grm)?;
            Ok(SequenceRecognizer::Llguidance(Box::new(llg)))
        } else {
            Ok(SequenceRecognizer::None)
        }
    }
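    /// On multi-device setups, forward the request to each daemon worker: over
    /// a local IPC socket when NCCL is in use, or over TCP for the ring
    /// backend. This is a no-op on daemon processes themselves.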
    fn replicate_request_to_daemons(&self, request: &Request) {
        if !distributed::is_daemon() && mistralrs_quant::distributed::use_nccl() {
            // NCCL path: serialize the request and send one copy to each worker
            // over a local IPC socket.
            let name = distributed::ipc_name().unwrap();
            let num_workers =
                mistralrs_quant::distributed::get_global_tp_size_from_devices().unwrap() - 1;
            let listener = ListenerOptions::new().name(name).create_sync().unwrap();

            for _ in 0..num_workers {
                let stream = listener.accept().unwrap();
                let mut writer = BufWriter::new(stream);
                let req = format!("{}\n", serde_json::to_string(&request).unwrap());
                writer.write_all(req.as_bytes()).unwrap();
            }
        } else if !distributed::is_daemon() && cfg!(feature = "ring") {
            // Ring backend: same replication, but over TCP on the configured master port.
            let num_workers =
                mistralrs_quant::distributed::get_global_tp_size_from_devices().unwrap() - 1;
            let master_port = RingConfig::load().master_port;
            let listener =
                TcpListener::bind(format!("0.0.0.0:{master_port}")).expect("bind replicator");

            for _ in 0..num_workers {
                let (stream, _) = listener.accept().unwrap();
                let mut writer = BufWriter::new(stream);
                let req = format!("{}\n", serde_json::to_string(&request).unwrap());
                writer.write_all(req.as_bytes()).unwrap();
            }
        }
    }
}