mistralrs_core/engine/mod.rs

use crate::{
    distributed,
    embedding::bert::BertPipeline,
    pipeline::{
        llg::{constraint_from_llg_grammar, llg_grammar_from_constraint},
        text_models_inputs_processor::PagedAttentionMeta,
        CacheBackendMetadata, CacheInstruction,
    },
    prefix_cacher::PrefixCacheManagerV2,
    response::CompletionChoice,
    scheduler::{Scheduler, SchedulerOutput},
    sequence::{SeqStepType, StopReason},
    CompletionResponse, SchedulerConfig, DEBUG,
};
use interprocess::local_socket::{traits::Listener, ListenerOptions};
use llguidance::ParserFactory;
use logger::IntervalLogger;
use once_cell::sync::Lazy;
use rand::SeedableRng;
use rand_isaac::Isaac64Rng;
use std::{
    collections::HashMap,
    io::{BufWriter, Write},
    ops::Deref,
    sync::{
        atomic::{AtomicBool, Ordering},
        Arc,
    },
    time::{Duration, Instant, SystemTime, UNIX_EPOCH},
};
use tokio::{
    sync::{mpsc::Receiver, Mutex},
    task::JoinHandle,
};

use crate::{
    get_mut_arcmutex, handle_pipeline_forward_error,
    pipeline::Pipeline,
    request::Request,
    response::{ChatCompletionResponse, Choice, ResponseMessage},
    sequence::{SequenceRecognizer, SequenceState},
    Constraint,
};

mod add_request;
mod logger;
mod search_request;

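/// Out-of-band instructions for a running engine, delivered through
/// [`ENGINE_INSTRUCTIONS`] and checked once per scheduling iteration.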
pub enum EngineInstruction {
    Terminate,
}

/// Embedding model used for ranking web search results internally.
#[derive(Debug, Default, Clone)]
pub enum BertEmbeddingModel {
    #[default]
    SnowflakeArcticEmbedL,
    Custom(String),
}
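
// A minimal usage sketch (the custom model ID below is hypothetical):
//
//     let ranker = BertEmbeddingModel::default(); // SnowflakeArcticEmbedL
//     let custom = BertEmbeddingModel::Custom("my-org/my-embedder".to_string());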

const SEED: u64 = 0;
/// Terminate all sequences on the next scheduling step. Be sure to reset this.
pub static TERMINATE_ALL_NEXT_STEP: AtomicBool = AtomicBool::new(false);
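// Intended set/reset pattern from the caller's side (sketch):
//
//     TERMINATE_ALL_NEXT_STEP.store(true, Ordering::SeqCst);
//     // ... after the next scheduling step has run ...
//     TERMINATE_ALL_NEXT_STEP.store(false, Ordering::SeqCst);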

/// Engine instructions, per Engine (MistralRs) ID.
pub static ENGINE_INSTRUCTIONS: Lazy<std::sync::Mutex<HashMap<usize, Option<EngineInstruction>>>> =
    Lazy::new(|| std::sync::Mutex::new(HashMap::new()));
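// Sketch: asking the engine with ID 0 to terminate from another thread:
//
//     ENGINE_INSTRUCTIONS
//         .lock()
//         .expect("`ENGINE_INSTRUCTIONS` was poisoned")
//         .insert(0, Some(EngineInstruction::Terminate));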

pub struct Engine {
    rx: Arc<Mutex<Receiver<Request>>>,
    pipeline: Arc<Mutex<dyn Pipeline>>,
    bert_pipeline: Arc<Mutex<Option<BertPipeline>>>,
    scheduler: Arc<Mutex<dyn Scheduler>>,
    id: Arc<Mutex<usize>>,
    truncate_sequence: bool,
    no_kv_cache: bool,
    prefix_cacher: Arc<Mutex<PrefixCacheManagerV2>>,
    is_debug: bool,
    disable_eos_stop: bool,
    throughput_logging_enabled: bool,
    logger: IntervalLogger,
    handles: Arc<Mutex<Vec<JoinHandle<()>>>>,
}

impl Drop for Engine {
    fn drop(&mut self) {
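        // Abort any background tasks spawned by this engine so they do not
        // outlive it.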
        for handle in &*get_mut_arcmutex!(self.handles) {
            handle.abort();
        }
    }
}

impl Engine {
    #[allow(clippy::too_many_arguments)]
    pub fn new(
        rx: Receiver<Request>,
        pipeline: Arc<Mutex<dyn Pipeline>>,
        config: SchedulerConfig,
        truncate_sequence: bool,
        mut no_kv_cache: bool,
        mut no_prefix_cache: bool,
        prefix_cache_n: usize,
        disable_eos_stop: bool,
        throughput_logging_enabled: bool,
        search_embedding_model: Option<BertEmbeddingModel>,
    ) -> anyhow::Result<Self> {
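        // Honor the pipeline's own metadata: some pipelines run without a KV
        // cache regardless of what the caller asked for.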
        no_kv_cache |= get_mut_arcmutex!(pipeline).get_metadata().no_kv_cache;

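        // Prefix caching needs a KV cache to restore, and is disabled outright
        // when the PagedAttention scheduler is in use.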
        no_prefix_cache = matches!(config, SchedulerConfig::PagedAttentionMeta { .. })
            || no_prefix_cache
            || no_kv_cache;

        let bert_pipeline = match search_embedding_model {
            Some(search_embedding_model) => Some(BertPipeline::new(
                search_embedding_model,
                &get_mut_arcmutex!(pipeline).device(),
            )?),
            None => None,
        };

        Ok(Self {
            rx: Arc::new(Mutex::new(rx)),
            pipeline,
            bert_pipeline: Arc::new(Mutex::new(bert_pipeline)),
            scheduler: config.into_scheduler(),
            id: Arc::new(Mutex::new(0)),
            truncate_sequence,
            no_kv_cache,
            prefix_cacher: Arc::new(Mutex::new(PrefixCacheManagerV2::new(
                prefix_cache_n,
                no_prefix_cache,
            ))),
            is_debug: DEBUG.load(Ordering::Relaxed),
            disable_eos_stop,
            throughput_logging_enabled,
            logger: IntervalLogger::new(Duration::from_secs(5)),
            handles: Arc::new(Mutex::new(Vec::new())),
        })
    }

    pub async fn run(self: Arc<Self>) {
        if self.throughput_logging_enabled {
            self.logger.enable_logging();
        }

        let rng = Arc::new(std::sync::Mutex::new(Isaac64Rng::seed_from_u64(SEED)));
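        // Tracks which completion sequences ran in the previous step, so a
        // still-resident KV cache is not needlessly reloaded.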
        let mut last_completion_ids: Vec<usize> = vec![];
        'lp: loop {
            if matches!(
                ENGINE_INSTRUCTIONS
                    .lock()
                    .expect("`ENGINE_INSTRUCTIONS` was poisoned")
                    .get(get_mut_arcmutex!(self.id).deref()),
                Some(Some(EngineInstruction::Terminate))
            ) {
                self.replicate_request_to_daemons(&Request::Terminate);
                break 'lp;
            }

            while let Ok(request) = get_mut_arcmutex!(self.rx).try_recv() {
                self.replicate_request_to_daemons(&request);
                if matches!(request, Request::Terminate) {
                    break 'lp;
                }
                self.clone().handle_request(request).await;
            }

            if TERMINATE_ALL_NEXT_STEP.load(Ordering::SeqCst) {
                self.replicate_request_to_daemons(&Request::TerminateAllSeqsNextStep);
            }

            let run_start = Instant::now();
            let mut scheduler = get_mut_arcmutex!(self.scheduler);
            let scheduled = scheduler.schedule();

            match scheduled {
                SchedulerOutput::DefaultScheduler {
                    output: mut scheduled,
                } => {
                    if !scheduled.completion.is_empty() {
                        let current_completion_ids: Vec<usize> =
                            scheduled.completion.iter().map(|seq| *seq.id()).collect();
                        let res = {
                            let mut pipeline = get_mut_arcmutex!(self.pipeline);
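                            // Only swap the KV cache back in when the batch of
                            // running sequences changed; otherwise it is
                            // already resident.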
                            let pre_op = if !self.no_kv_cache
                                && last_completion_ids != current_completion_ids
                            {
                                CacheInstruction::In
                            } else {
                                CacheInstruction::Nothing
                            };
                            let post_op = if !self.no_kv_cache {
                                CacheInstruction::Out
                            } else {
                                CacheInstruction::Reset {
                                    load_preallocated_cache: false,
                                    reset_non_granular: false,
                                }
                            };

                            let return_raw_logits = scheduled.completion[0].return_raw_logits;
                            assert!(
                                scheduled
                                    .completion
                                    .iter()
                                    .all(|seq| seq.return_raw_logits == return_raw_logits),
                                "All sequences must either return raw logits, or not."
                            );

                            pipeline
                                .step(
                                    &mut scheduled.completion,
                                    false,
                                    return_raw_logits,
                                    &mut *get_mut_arcmutex!(self.prefix_cacher),
                                    self.disable_eos_stop,
                                    rng.clone(),
                                    CacheBackendMetadata::DefaultInstructions { pre_op, post_op },
                                )
                                .await
                        };

                        handle_pipeline_forward_error!(
                            "completion step",
                            res,
                            &mut scheduled.completion,
                            self.pipeline,
                            'lp,
                            self.prefix_cacher
                        );

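                        // Each completion sequence produced exactly one token
                        // this step.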
                        self.logger.add_tokens_processed(scheduled.completion.len());

                        last_completion_ids = current_completion_ids;
                    }

                    if !scheduled.prompt.is_empty() {
                        let prompt_exec_time = {
                            let mut pipeline = get_mut_arcmutex!(self.pipeline);

                            // Run the prompt sequences
                            let post_op = if !self.no_kv_cache {
                                CacheInstruction::Out
                            } else {
                                CacheInstruction::Reset {
                                    load_preallocated_cache: false,
                                    reset_non_granular: false,
                                }
                            };

                            let return_raw_logits = scheduled.prompt[0].return_raw_logits;
                            assert!(
                                scheduled
                                    .prompt
                                    .iter()
                                    .all(|seq| seq.return_raw_logits == return_raw_logits),
                                "All sequences must either return raw logits, or not."
                            );

                            // A nonzero token offset comes from prefix caching.
                            // The scheduler upholds the invariant that all
                            // token offsets in the batch are the same.
                            let pre_op = if scheduled.prompt[0].token_offset() != 0 {
                                CacheInstruction::In
                            } else {
                                CacheInstruction::Reset {
                                    load_preallocated_cache: true,
                                    reset_non_granular: false,
                                }
                            };

                            pipeline
                                .step(
                                    &mut scheduled.prompt,
                                    true,
                                    return_raw_logits,
                                    &mut *get_mut_arcmutex!(self.prefix_cacher),
                                    self.disable_eos_stop,
                                    rng.clone(),
                                    CacheBackendMetadata::DefaultInstructions { pre_op, post_op },
                                )
                                .await
                        };

                        let prompt_exec_time = handle_pipeline_forward_error!(
                            "prompt step",
                            prompt_exec_time,
                            &mut scheduled.prompt,
                            self.pipeline,
                            'lp,
                            self.prefix_cacher
                        );

                        let total_processed_tokens: usize = scheduled
                            .prompt
                            .iter()
                            .map(|seq| seq.get_toks().len())
                            .sum();
                        self.logger.add_tokens_processed(total_processed_tokens);
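
                        // One-shot sequences (e.g. image generation) finish at
                        // the prompt pass; the rest proceed to decoding.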
                        for seq in scheduled.prompt.iter_mut() {
                            match seq.sequence_stepping_type() {
                                SeqStepType::OneShot => {
                                    seq.set_state(SequenceState::Done(StopReason::GeneratedImage))
                                }
                                SeqStepType::PromptAndDecode => {
                                    seq.set_state(SequenceState::RunningCompletion)
                                }
                            }
                            let now = SystemTime::now()
                                .duration_since(UNIX_EPOCH)
                                .expect("Time travel has occurred!")
                                .as_millis();
                            #[allow(clippy::cast_precision_loss)]
                            let prompt_tok_per_sec =
                                seq.len() as f32 / prompt_exec_time.as_secs_f32();
                            seq.prompt_tok_per_sec = prompt_tok_per_sec;
                            seq.prompt_timestamp = Some(now);
                            seq.total_prompt_time = Some(prompt_exec_time.as_millis());
                        }
                        last_completion_ids = vec![];
                    }

                    if self.is_debug {
                        let secs_from_last_run = run_start.elapsed().as_secs_f64();
                        let total_len = scheduled.prompt.len() + scheduled.completion.len();
                        if total_len > 0 {
                            let prompt_lengths = scheduled
                                .prompt
                                .iter()
                                .map(|seq| seq.len().to_string())
                                .collect::<Vec<_>>()
                                .join(", ");

                            let completion_lengths = scheduled
                                .completion
                                .iter()
                                .map(|seq| seq.len().to_string())
                                .collect::<Vec<_>>()
                                .join(", ");

                            tracing::info!(
                                "Prompt[{}] Completion[{}] - {}ms",
                                prompt_lengths,
                                completion_lengths,
                                secs_from_last_run * 1000.,
                            );
                        }
                    }
                }
                SchedulerOutput::PagedAttention { mut output } => {
                    if !output.scheduled.is_empty() {
                        let is_prompt = get_mut_arcmutex!(output.scheduled[0]).is_prompt();

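                        // Hold each scheduled sequence's mutex for the whole
                        // step and hand the pipeline plain `&mut` references.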
                        let mut guards = output
                            .scheduled
                            .iter_mut()
                            .map(|seq| seq.lock().unwrap())
                            .collect::<Vec<_>>();

                        let mut guards_mut =
                            guards.iter_mut().map(|seq| &mut **seq).collect::<Vec<_>>();

                        let res = {
                            let mut pipeline = get_mut_arcmutex!(self.pipeline);

                            let block_size = scheduler.block_size().unwrap();

                            let metadata = PagedAttentionMeta {
                                block_size,
                                sliding_window: pipeline.get_metadata().sliding_window,
                                block_engine: scheduler.block_engine().unwrap(),
                            };

                            let return_raw_logits = guards_mut[0].return_raw_logits;
                            assert!(
                                guards_mut
                                    .iter()
                                    .all(|seq| seq.return_raw_logits == return_raw_logits),
                                "All sequences must either return raw logits, or not."
                            );

                            pipeline
                                .step(
                                    &mut guards_mut,
                                    is_prompt,
                                    return_raw_logits,
                                    &mut *get_mut_arcmutex!(self.prefix_cacher),
                                    self.disable_eos_stop,
                                    rng.clone(),
                                    CacheBackendMetadata::PagedAttention {
                                        metadata,
                                        blocks_to_copy: output.blocks_to_copy,
                                        blocks_to_swap_in: output.blocks_to_swap_in,
                                        blocks_to_swap_out: output.blocks_to_swap_out,
                                    },
                                )
                                .await
                        };

                        handle_pipeline_forward_error!(
                            "step",
                            res,
                            &mut guards_mut,
                            self.pipeline,
                            'lp,
                            self.prefix_cacher
                        );

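                        // A prompt sequence processes its entire prompt this
                        // step; a decode sequence processes exactly one token.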
                        let total_processed_tokens: usize = guards
                            .iter()
                            .map(|seq| {
                                if seq.is_prompt() {
                                    seq.get_toks().len()
                                } else {
                                    1
                                }
                            })
                            .sum();
                        self.logger.add_tokens_processed(total_processed_tokens);

                        if self.is_debug {
                            let secs_from_last_run = run_start.elapsed().as_secs_f64();
                            let total_len = guards.len();
                            if total_len > 0 {
                                let lengths = guards
                                    .iter()
                                    .map(|seq| seq.len().to_string())
                                    .collect::<Vec<_>>()
                                    .join(", ");

                                let (prompt_lengths, completion_lengths) = if is_prompt {
                                    (lengths, "".to_string())
                                } else {
                                    ("".to_string(), lengths)
                                };

                                tracing::info!(
                                    "Prompt[{}] Completion[{}] - {}ms",
                                    prompt_lengths,
                                    completion_lengths,
                                    secs_from_last_run * 1000.,
                                );
                            }
                        }

                        if is_prompt {
                            for mut seq in guards {
                                let now = SystemTime::now()
                                    .duration_since(UNIX_EPOCH)
                                    .expect("Time travel has occurred!")
                                    .as_millis();
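                                // `now` and `seq.timestamp()` are in
                                // milliseconds, so tokens/ms is scaled by 1000
                                // to get tokens per second.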
                                #[allow(clippy::cast_precision_loss)]
                                let prompt_tok_per_sec =
                                    seq.len() as f32 / (now - seq.timestamp()) as f32;
                                seq.prompt_tok_per_sec = prompt_tok_per_sec * 1000.;
                                seq.prompt_timestamp = Some(now);
                                seq.total_prompt_time = Some(now - seq.timestamp());
                            }
                        }
                    }
                }
            }

            scheduler.free_finished_sequence_groups();
        }
    }

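    /// Build the recognizer that constrains sampling for a sequence: an
    /// llguidance matcher when the constraint requests a grammar, otherwise
    /// none.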
    fn build_sequence_recognizer(
        factory: &Option<Arc<ParserFactory>>,
        constraint: &Constraint,
    ) -> anyhow::Result<SequenceRecognizer> {
        if let Some(grm) = llg_grammar_from_constraint(constraint)? {
            let factory = factory
                .as_ref()
                .ok_or_else(|| anyhow::anyhow!("No token environment (llg_factory) found."))?;
            let llg = constraint_from_llg_grammar(factory, grm)?;
            Ok(SequenceRecognizer::Llguidance(Box::new(llg)))
        } else {
            Ok(SequenceRecognizer::None)
        }
    }

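    /// With multi-process NCCL tensor parallelism, the main process replays
    /// every request to the worker daemons over a local IPC socket so that all
    /// ranks observe the same request stream.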
    fn replicate_request_to_daemons(&self, request: &Request) {
        if !distributed::is_daemon() && mistralrs_quant::distributed::use_nccl() {
            let name = distributed::ipc_name().unwrap();
            let num_workers =
                mistralrs_quant::distributed::get_global_tp_size_from_devices().unwrap() - 1;
            let listener = ListenerOptions::new().name(name).create_sync().unwrap();

            for _ in 0..num_workers {
                let stream = listener.accept().unwrap();
                let mut writer = BufWriter::new(stream);
                let req = format!("{}\n", serde_json::to_string(&request).unwrap());
                writer.write_all(req.as_bytes()).unwrap();
            }
        }
    }
}