// mistralrs_core/engine/mod.rs

use crate::{
    distributed,
    embedding::bert::BertPipeline,
    pipeline::{
        llg::{constraint_from_llg_grammar, llg_grammar_from_constraint},
        text_models_inputs_processor::PagedAttentionMeta,
        CacheBackendMetadata, CacheInstruction,
    },
    prefix_cacher::PrefixCacheManagerV2,
    response::CompletionChoice,
    scheduler::{Scheduler, SchedulerOutput},
    sequence::{SeqStepType, StopReason},
    CompletionResponse, SchedulerConfig, DEBUG,
};
use interprocess::local_socket::{traits::Listener, ListenerOptions};
use llguidance::ParserFactory;
pub use logger::IntervalLogger;
use once_cell::sync::Lazy;
use rand::SeedableRng;
use rand_isaac::Isaac64Rng;
use std::{
    collections::HashMap,
    io::{BufWriter, Write},
    ops::Deref,
    sync::{
        atomic::{AtomicBool, Ordering},
        Arc,
    },
    time::{Duration, Instant, SystemTime, UNIX_EPOCH},
};
use tokio::{
    sync::{mpsc::Receiver, Mutex},
    task::JoinHandle,
};

use crate::{
    get_mut_arcmutex, handle_pipeline_forward_error,
    pipeline::Pipeline,
    request::Request,
    response::{ChatCompletionResponse, Choice, ResponseMessage},
    sequence::{SequenceRecognizer, SequenceState},
    Constraint,
};

mod add_request;
mod logger;
mod search_request;

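/// Out-of-band control instructions for a running engine, delivered through
/// the global [`ENGINE_INSTRUCTIONS`] map keyed by engine ID.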
pub enum EngineInstruction {
    Terminate,
}

/// Embedding model used for ranking web search results internally.
#[derive(Debug, Default, Clone)]
pub enum BertEmbeddingModel {
    #[default]
    SnowflakeArcticEmbedL,
    Custom(String),
}

const SEED: u64 = 0;
/// Terminate all sequences on the next scheduling step. Be sure to reset this.
pub static TERMINATE_ALL_NEXT_STEP: AtomicBool = AtomicBool::new(false);

/// Engine instructions, per Engine (MistralRs) ID.
pub static ENGINE_INSTRUCTIONS: Lazy<std::sync::Mutex<HashMap<usize, Option<EngineInstruction>>>> =
    Lazy::new(|| std::sync::Mutex::new(HashMap::new()));

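/// The request-processing core of the engine: owns the pipeline, scheduler,
/// and caches, and drives the scheduling loop in [`Engine::run`].
///
/// A minimal lifecycle sketch (the channel capacity, request plumbing, and
/// surrounding setup are illustrative, not prescribed by this module):
///
/// ```ignore
/// let (tx, rx) = tokio::sync::mpsc::channel(100);
/// let engine = Arc::new(Engine::new(
///     rx, pipeline, config,
///     /* truncate_sequence */ false,
///     /* no_kv_cache */ false,
///     /* no_prefix_cache */ false,
///     /* prefix_cache_n */ 16,
///     /* disable_eos_stop */ false,
///     /* throughput_logging_enabled */ true,
///     /* search_embedding_model */ None,
/// )?);
/// tokio::spawn(engine.clone().run());
/// // Requests sent on `tx` are picked up by the scheduling loop.
/// ```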
pub struct Engine {
    rx: Arc<Mutex<Receiver<Request>>>,
    pipeline: Arc<Mutex<dyn Pipeline>>,
    bert_pipeline: Arc<Mutex<Option<BertPipeline>>>,
    scheduler: Arc<Mutex<dyn Scheduler>>,
    id: Arc<Mutex<usize>>,
    truncate_sequence: bool,
    no_kv_cache: bool,
    prefix_cacher: Arc<Mutex<PrefixCacheManagerV2>>,
    is_debug: bool,
    disable_eos_stop: bool,
    throughput_logging_enabled: bool,
    logger: IntervalLogger,
    handles: Arc<Mutex<Vec<JoinHandle<()>>>>,
}

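// Abort any background tasks spawned on behalf of this engine so they do not
// outlive it.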
impl Drop for Engine {
    fn drop(&mut self) {
        for handle in &*get_mut_arcmutex!(self.handles) {
            handle.abort();
        }
    }
}

impl Engine {
    #[allow(clippy::too_many_arguments)]
    pub fn new(
        rx: Receiver<Request>,
        pipeline: Arc<Mutex<dyn Pipeline>>,
        config: SchedulerConfig,
        truncate_sequence: bool,
        mut no_kv_cache: bool,
        mut no_prefix_cache: bool,
        prefix_cache_n: usize,
        disable_eos_stop: bool,
        throughput_logging_enabled: bool,
        search_embedding_model: Option<BertEmbeddingModel>,
    ) -> anyhow::Result<Self> {
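        // The pipeline's metadata can force KV caching and prefix caching off;
        // honor it in addition to the caller's flags.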
        no_kv_cache |= get_mut_arcmutex!(pipeline).get_metadata().no_kv_cache;

        no_prefix_cache = no_prefix_cache
            || no_kv_cache
            || get_mut_arcmutex!(pipeline).get_metadata().no_prefix_cache;

        let bert_pipeline = match search_embedding_model {
            Some(search_embedding_model) => Some(BertPipeline::new(
                search_embedding_model,
                &get_mut_arcmutex!(pipeline).device(),
            )?),
            None => None,
        };

        let scheduler = config.into_scheduler();
        let block_engine = get_mut_arcmutex!(scheduler).block_engine();

        Ok(Self {
            rx: Arc::new(Mutex::new(rx)),
            pipeline,
            bert_pipeline: Arc::new(Mutex::new(bert_pipeline)),
            scheduler: scheduler.clone(),
            id: Arc::new(Mutex::new(0)),
            truncate_sequence,
            no_kv_cache,
            prefix_cacher: Arc::new(Mutex::new(PrefixCacheManagerV2::new(
                prefix_cache_n,
                no_prefix_cache,
                block_engine,
            ))),
            is_debug: DEBUG.load(Ordering::Relaxed),
            disable_eos_stop,
            throughput_logging_enabled,
            logger: IntervalLogger::new(Duration::from_secs(5)),
            handles: Arc::new(Mutex::new(Vec::new())),
        })
    }

    pub async fn run(self: Arc<Self>) {
        if self.throughput_logging_enabled {
            self.logger.enable_logging();
        }

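        // Sampling RNG shared by all sequences; seeded with a fixed value so
        // generation is reproducible across runs.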
        let rng = Arc::new(std::sync::Mutex::new(Isaac64Rng::seed_from_u64(SEED)));
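        // IDs of the sequences run in the previous completion step. If the
        // running set changes, KV caches must be swapped back in (see `pre_op`
        // below).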
        let mut last_completion_ids: Vec<usize> = vec![];
        'lp: loop {
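            // Handle an out-of-band Terminate instruction addressed to this
            // engine's ID, forwarding it to any distributed daemons first.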
            if matches!(
                ENGINE_INSTRUCTIONS
                    .lock()
                    .expect("`ENGINE_INSTRUCTIONS` was poisoned")
                    .get(get_mut_arcmutex!(self.id).deref()),
                Some(Some(EngineInstruction::Terminate))
            ) {
                self.replicate_request_to_daemons(&Request::Terminate);
                break 'lp;
            }

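            // Drain all queued requests without blocking; Terminate stops the
            // engine, anything else is dispatched to `handle_request`.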
            while let Ok(request) = get_mut_arcmutex!(self.rx).try_recv() {
                self.replicate_request_to_daemons(&request);
                if matches!(request, Request::Terminate) {
                    break 'lp;
                }
                self.clone().handle_request(request).await;
            }

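            // Replicate a pending "terminate all sequences" flag so daemon
            // workers act on the same step.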
            if TERMINATE_ALL_NEXT_STEP.load(Ordering::SeqCst) {
                self.replicate_request_to_daemons(&Request::TerminateAllSeqsNextStep);
            }

            let run_start = Instant::now();
            let mut scheduler = get_mut_arcmutex!(self.scheduler);
            let scheduled = scheduler.schedule(&self.logger);

            match scheduled {
                SchedulerOutput::DefaultScheduler {
                    output: mut scheduled,
                } => {
                    if !scheduled.completion.is_empty() {
                        let current_completion_ids: Vec<usize> =
                            scheduled.completion.iter().map(|seq| *seq.id()).collect();
                        let res = {
                            let mut pipeline = get_mut_arcmutex!(self.pipeline);
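                            // Swap the KV caches back in only if the running set
                            // of sequences changed since the last decode step;
                            // otherwise they are already resident.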
                            let pre_op = if !self.no_kv_cache
                                && last_completion_ids != current_completion_ids
                            {
                                CacheInstruction::In
                            } else {
                                CacheInstruction::Nothing
                            };
                            let post_op = if !self.no_kv_cache {
                                CacheInstruction::Out
                            } else {
                                CacheInstruction::Reset {
                                    load_preallocated_cache: false,
                                    reset_non_granular: false,
                                }
                            };

                            let return_raw_logits = scheduled.completion[0].return_raw_logits;
                            assert!(
                                scheduled
                                    .completion
                                    .iter()
                                    .all(|seq| seq.return_raw_logits == return_raw_logits),
                                "All sequences must either return raw logits, or not."
                            );

                            pipeline
                                .step(
                                    &mut scheduled.completion,
                                    false,
                                    return_raw_logits,
                                    &mut *get_mut_arcmutex!(self.prefix_cacher),
                                    self.disable_eos_stop,
                                    rng.clone(),
                                    CacheBackendMetadata::DefaultInstructions { pre_op, post_op },
                                )
                                .await
                        };

                        handle_pipeline_forward_error!(
                            "completion step",
                            res,
                            &mut scheduled.completion,
                            self.pipeline,
                            'lp,
                            self.prefix_cacher
                        );

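                        // Each completion step decodes exactly one token per
                        // sequence.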
                        self.logger.add_tokens_processed(scheduled.completion.len());

                        last_completion_ids = current_completion_ids;
                    }

                    if !scheduled.prompt.is_empty() {
                        let prompt_exec_time = {
                            let mut pipeline = get_mut_arcmutex!(self.pipeline);

                            // Run the prompt sequences.
                            let post_op = if !self.no_kv_cache {
                                CacheInstruction::Out
                            } else {
                                CacheInstruction::Reset {
                                    load_preallocated_cache: false,
                                    reset_non_granular: false,
                                }
                            };

                            let return_raw_logits = scheduled.prompt[0].return_raw_logits;
                            assert!(
                                scheduled
                                    .prompt
                                    .iter()
                                    .all(|seq| seq.return_raw_logits == return_raw_logits),
                                "All sequences must either return raw logits, or not."
                            );

                            // A nonzero token offset comes from prefix caching.
                            // The scheduler upholds the invariant that all token
                            // offsets in the batch are the same.
                            let pre_op = if scheduled.prompt[0].token_offset() != 0 {
                                CacheInstruction::In
                            } else {
                                CacheInstruction::Reset {
                                    load_preallocated_cache: true,
                                    reset_non_granular: false,
                                }
                            };

                            pipeline
                                .step(
                                    &mut scheduled.prompt,
                                    true,
                                    return_raw_logits,
                                    &mut *get_mut_arcmutex!(self.prefix_cacher),
                                    self.disable_eos_stop,
                                    rng.clone(),
                                    CacheBackendMetadata::DefaultInstructions { pre_op, post_op },
                                )
                                .await
                        };

                        let prompt_exec_time = handle_pipeline_forward_error!(
                            "prompt step",
                            prompt_exec_time,
                            &mut scheduled.prompt,
                            self.pipeline,
                            'lp,
                            self.prefix_cacher
                        );

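                        // Prompt steps process every token of each sequence, so
                        // count full prompt lengths toward throughput.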
                        let total_processed_tokens: usize = scheduled
                            .prompt
                            .iter()
                            .map(|seq| seq.get_toks().len())
                            .sum();
                        self.logger.add_tokens_processed(total_processed_tokens);

                        for seq in scheduled.prompt.iter_mut() {
                            match seq.sequence_stepping_type() {
                                SeqStepType::OneShot => {
                                    seq.set_state(SequenceState::Done(StopReason::GeneratedImage))
                                }
                                SeqStepType::PromptAndDecode => {
                                    seq.set_state(SequenceState::RunningCompletion)
                                }
                            }
                            let now = SystemTime::now()
                                .duration_since(UNIX_EPOCH)
                                .expect("Time travel has occurred!")
                                .as_millis();
                            #[allow(clippy::cast_precision_loss)]
                            let prompt_tok_per_sec =
                                seq.len() as f32 / prompt_exec_time.as_secs_f32();
                            seq.prompt_tok_per_sec = prompt_tok_per_sec;
                            seq.prompt_timestamp = Some(now);
                            seq.total_prompt_time = Some(prompt_exec_time.as_millis());
                        }
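                        // Clearing this forces the next completion step to swap
                        // its KV caches back in, since the prompt step may have
                        // displaced them.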
                        last_completion_ids = vec![];
                    }

                    if self.is_debug {
                        let ms_from_last_run = run_start.elapsed().as_secs_f64() * 1000.;
                        let total_len = scheduled.prompt.len() + scheduled.completion.len();
                        if total_len > 0 {
                            let prompt_lengths = scheduled
                                .prompt
                                .iter()
                                .map(|seq| seq.len().to_string())
                                .collect::<Vec<_>>()
                                .join(", ");

                            let completion_lengths = scheduled
                                .completion
                                .iter()
                                .map(|seq| seq.len().to_string())
                                .collect::<Vec<_>>()
                                .join(", ");

                            tracing::info!(
                                "Prompt[{}] Completion[{}] - {}ms",
                                prompt_lengths,
                                completion_lengths,
                                ms_from_last_run,
                            );
                        }
                    }
                }
                SchedulerOutput::PagedAttention { mut output } => {
                    if !output.scheduled.is_empty() {
                        let is_prompt = get_mut_arcmutex!(output.scheduled[0]).is_prompt();

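                        // Hold every scheduled sequence's lock for the duration
                        // of the step and hand the pipeline mutable references.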
                        let mut guards = output
                            .scheduled
                            .iter_mut()
                            .map(|seq| seq.lock().unwrap())
                            .collect::<Vec<_>>();

                        let mut guards_mut =
                            guards.iter_mut().map(|seq| &mut **seq).collect::<Vec<_>>();

                        let res = {
                            let mut pipeline = get_mut_arcmutex!(self.pipeline);

                            let block_size = scheduler.block_size().unwrap();

                            let metadata = PagedAttentionMeta {
                                block_size,
                                sliding_window: pipeline.get_metadata().sliding_window,
                                block_engine: scheduler.block_engine().unwrap(),
                            };

                            let return_raw_logits = guards_mut[0].return_raw_logits;
                            assert!(
                                guards_mut
                                    .iter()
                                    .all(|seq| seq.return_raw_logits == return_raw_logits),
                                "All sequences must either return raw logits, or not."
                            );

                            pipeline
                                .step(
                                    &mut guards_mut,
                                    is_prompt,
                                    return_raw_logits,
                                    &mut *get_mut_arcmutex!(self.prefix_cacher),
                                    self.disable_eos_stop,
                                    rng.clone(),
                                    CacheBackendMetadata::PagedAttention {
                                        metadata,
                                        blocks_to_copy: output.blocks_to_copy,
                                        blocks_to_swap_in: output.blocks_to_swap_in,
                                        blocks_to_swap_out: output.blocks_to_swap_out,
                                    },
                                )
                                .await
                        };

                        handle_pipeline_forward_error!(
                            "step",
                            res,
                            &mut guards_mut,
                            self.pipeline,
                            'lp,
                            self.prefix_cacher
                        );

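                        // Prompt sequences process their full token count in a
                        // single step; decoding sequences process one token.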
                        let total_processed_tokens: usize = guards
                            .iter()
                            .map(|seq| {
                                if seq.is_prompt() {
                                    seq.get_toks().len()
                                } else {
                                    1
                                }
                            })
                            .sum();
                        self.logger.add_tokens_processed(total_processed_tokens);

                        if self.is_debug {
                            let ms_from_last_run = run_start.elapsed().as_secs_f64() * 1000.;
                            let total_len = guards.len();
                            if total_len > 0 {
                                let lengths = guards
                                    .iter()
                                    .map(|seq| seq.len().to_string())
                                    .collect::<Vec<_>>()
                                    .join(", ");

                                let (prompt_lengths, completion_lengths) = if is_prompt {
                                    (lengths, "".to_string())
                                } else {
                                    ("".to_string(), lengths)
                                };

                                tracing::info!(
                                    "Prompt[{}] Completion[{}] - {}ms",
                                    prompt_lengths,
                                    completion_lengths,
                                    ms_from_last_run,
                                );
                            }
                        }

                        if is_prompt {
                            for mut seq in guards {
                                let now = SystemTime::now()
                                    .duration_since(UNIX_EPOCH)
                                    .expect("Time travel has occurred!")
                                    .as_millis();
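                                // `now - seq.timestamp()` is in milliseconds;
                                // the `* 1000.` below converts tokens/ms into
                                // tokens/sec.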
                                #[allow(clippy::cast_precision_loss)]
                                let prompt_tok_per_sec =
                                    seq.len() as f32 / (now - seq.timestamp()) as f32;
                                seq.prompt_tok_per_sec = prompt_tok_per_sec * 1000.;
                                seq.prompt_timestamp = Some(now);
                                seq.total_prompt_time = Some(now - seq.timestamp());
                            }
                        }
                    }
                }
            }

            scheduler.free_finished_sequence_groups();
        }
    }

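    /// Build a recognizer for the given constraint. Grammar-based constraints
    /// require the llguidance parser factory; unconstrained requests get
    /// `SequenceRecognizer::None`.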
    fn build_sequence_recognizer(
        factory: &Option<Arc<ParserFactory>>,
        constraint: &Constraint,
    ) -> anyhow::Result<SequenceRecognizer> {
        if let Some(grm) = llg_grammar_from_constraint(constraint)? {
            let factory = factory
                .as_ref()
                .ok_or_else(|| anyhow::anyhow!("No token environment (llg_factory) found."))?;
            let llg = constraint_from_llg_grammar(factory, grm)?;
            Ok(SequenceRecognizer::Llguidance(Box::new(llg)))
        } else {
            Ok(SequenceRecognizer::None)
        }
    }

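    /// On the NCCL main process, replicate the request to every daemon worker
    /// over a local IPC socket so all ranks observe the same request stream.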
    fn replicate_request_to_daemons(&self, request: &Request) {
        if !distributed::is_daemon() && mistralrs_quant::distributed::use_nccl() {
            let name = distributed::ipc_name().unwrap();
            let num_workers =
                mistralrs_quant::distributed::get_global_tp_size_from_devices().unwrap() - 1;
            let listener = ListenerOptions::new().name(name).create_sync().unwrap();

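            // Accept one connection per daemon worker and send it the
            // serialized request, newline-delimited.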
            for _ in 0..num_workers {
                let stream = listener.accept().unwrap();
                let mut writer = BufWriter::new(stream);
                let req = format!("{}\n", serde_json::to_string(&request).unwrap());
                writer.write_all(req.as_bytes()).unwrap();
            }
        }
    }
}