mistralrs_core/engine/mod.rs

1use crate::{
2    distributed,
3    embedding::bert::BertPipeline,
4    pipeline::{
5        llg::{constraint_from_llg_grammar, llg_grammar_from_constraint},
6        text_models_inputs_processor::PagedAttentionMeta,
7        CacheBackendMetadata, CacheInstruction,
8    },
9    prefix_cacher::PrefixCacheManagerV2,
10    response::CompletionChoice,
11    scheduler::{Scheduler, SchedulerOutput},
12    sequence::{SeqStepType, StopReason},
13    CompletionResponse, SchedulerConfig, DEBUG,
14};
15use interprocess::local_socket::{traits::Listener, ListenerOptions};
16use llguidance::toktrie::TokEnv;
17use logger::IntervalLogger;
18use once_cell::sync::Lazy;
19use rand::SeedableRng;
20use rand_isaac::Isaac64Rng;
21use std::{
22    collections::HashMap,
23    io::{BufWriter, Write},
24    ops::Deref,
25    sync::{
26        atomic::{AtomicBool, Ordering},
27        Arc,
28    },
29    time::{Duration, Instant, SystemTime, UNIX_EPOCH},
30};
31use tokio::{
32    sync::{mpsc::Receiver, Mutex},
33    task::JoinHandle,
34};
35
36use crate::{
37    get_mut_arcmutex, handle_pipeline_forward_error,
38    pipeline::Pipeline,
39    request::Request,
40    response::{ChatCompletionResponse, Choice, ResponseMessage},
41    sequence::{SequenceRecognizer, SequenceState},
42    Constraint,
43};
44
45mod add_request;
46mod logger;
47
/// Out-of-band instruction that can be posted for a specific engine through
/// `ENGINE_INSTRUCTIONS`, keyed by the engine's id.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum EngineInstruction {
    /// Stop the engine: `Engine::run` replicates `Request::Terminate` to any daemons and
    /// leaves its main loop.
    Terminate,
}
51
/// Embedding model used for ranking web search results internally.
///
/// The default is [`BertEmbeddingModel::SnowflakeArcticEmbedL`].
#[derive(Clone, Debug, Default)]
pub enum BertEmbeddingModel {
    /// Built-in default embedding model.
    #[default]
    SnowflakeArcticEmbedL,
    /// User-selected model; the string presumably names a model id or path — confirm with
    /// `BertPipeline::new`.
    Custom(String),
}
59
60const SEED: u64 = 0;
61/// Terminate all sequences on the next scheduling step. Be sure to reset this.
62pub static TERMINATE_ALL_NEXT_STEP: AtomicBool = AtomicBool::new(false);
63
64/// Engine instructions, per Engine (MistralRs) ID.
65pub static ENGINE_INSTRUCTIONS: Lazy<std::sync::Mutex<HashMap<usize, Option<EngineInstruction>>>> =
66    Lazy::new(|| std::sync::Mutex::new(HashMap::new()));
67
68pub struct Engine {
69    rx: Arc<Mutex<Receiver<Request>>>,
70    pipeline: Arc<Mutex<dyn Pipeline>>,
71    bert_pipeline: Arc<Mutex<Option<BertPipeline>>>,
72    scheduler: Arc<Mutex<dyn Scheduler>>,
73    id: Arc<Mutex<usize>>,
74    truncate_sequence: bool,
75    no_kv_cache: bool,
76    prefix_cacher: Arc<Mutex<PrefixCacheManagerV2>>,
77    is_debug: bool,
78    disable_eos_stop: bool,
79    throughput_logging_enabled: bool,
80    logger: IntervalLogger,
81    handles: Arc<Mutex<Vec<JoinHandle<()>>>>,
82}
83
84impl Drop for Engine {
85    fn drop(&mut self) {
86        for handle in &*get_mut_arcmutex!(self.handles) {
87            handle.abort();
88        }
89    }
90}
91
92impl Engine {
93    #[allow(clippy::too_many_arguments)]
94    pub fn new(
95        rx: Receiver<Request>,
96        pipeline: Arc<Mutex<dyn Pipeline>>,
97        config: SchedulerConfig,
98        truncate_sequence: bool,
99        mut no_kv_cache: bool,
100        mut no_prefix_cache: bool,
101        prefix_cache_n: usize,
102        disable_eos_stop: bool,
103        throughput_logging_enabled: bool,
104        search_embedding_model: Option<BertEmbeddingModel>,
105    ) -> anyhow::Result<Self> {
106        no_kv_cache |= get_mut_arcmutex!(pipeline).get_metadata().no_kv_cache;
107
108        no_prefix_cache = matches!(config, SchedulerConfig::PagedAttentionMeta { .. })
109            || no_prefix_cache
110            || no_kv_cache;
111
112        let bert_pipeline = match search_embedding_model {
113            Some(search_embedding_model) => Some(BertPipeline::new(
114                search_embedding_model,
115                &get_mut_arcmutex!(pipeline).device(),
116            )?),
117            None => None,
118        };
119
120        Ok(Self {
121            rx: Arc::new(Mutex::new(rx)),
122            pipeline,
123            bert_pipeline: Arc::new(Mutex::new(bert_pipeline)),
124            scheduler: config.into_scheduler(),
125            id: Arc::new(Mutex::new(0)),
126            truncate_sequence,
127            no_kv_cache,
128            prefix_cacher: Arc::new(Mutex::new(PrefixCacheManagerV2::new(
129                prefix_cache_n,
130                no_prefix_cache,
131            ))),
132            is_debug: DEBUG.load(Ordering::Relaxed),
133            disable_eos_stop,
134            throughput_logging_enabled,
135            logger: IntervalLogger::new(Duration::from_secs(5)),
136            handles: Arc::new(Mutex::new(Vec::new())),
137        })
138    }
139
140    pub async fn run(self: Arc<Self>) {
141        if self.throughput_logging_enabled {
142            self.logger.enable_logging();
143        }
144
145        let rng = Arc::new(std::sync::Mutex::new(Isaac64Rng::seed_from_u64(SEED)));
146        let mut last_completion_ids: Vec<usize> = vec![];
147        'lp: loop {
148            if matches!(
149                ENGINE_INSTRUCTIONS
150                    .lock()
151                    .expect("`ENGINE_INSTRUCTIONS` was poisioned")
152                    .get(get_mut_arcmutex!(self.id).deref()),
153                Some(Some(EngineInstruction::Terminate))
154            ) {
155                self.replicate_request_to_daemons(&Request::Terminate);
156                break 'lp;
157            }
158
159            while let Ok(request) = get_mut_arcmutex!(self.rx).try_recv() {
160                self.replicate_request_to_daemons(&request);
161                if matches!(request, Request::Terminate) {
162                    break 'lp;
163                }
164                self.clone().handle_request(request).await;
165            }
166
167            if TERMINATE_ALL_NEXT_STEP.load(Ordering::SeqCst) {
168                self.replicate_request_to_daemons(&Request::TerminateAllSeqsNextStep);
169            }
170
171            let run_start = Instant::now();
172            let mut scheduler = get_mut_arcmutex!(self.scheduler);
173            let scheduled = scheduler.schedule();
174
175            match scheduled {
176                SchedulerOutput::DefaultScheduler {
177                    output: mut scheduled,
178                } => {
179                    if !scheduled.completion.is_empty() {
180                        let current_completion_ids: Vec<usize> =
181                            scheduled.completion.iter().map(|seq| *seq.id()).collect();
182                        let res = {
183                            let mut pipeline = get_mut_arcmutex!(self.pipeline);
184                            let pre_op = if !self.no_kv_cache
185                                && last_completion_ids != current_completion_ids
186                            {
187                                CacheInstruction::In
188                            } else {
189                                CacheInstruction::Nothing
190                            };
191                            let post_op = if !self.no_kv_cache {
192                                CacheInstruction::Out
193                            } else {
194                                CacheInstruction::Reset {
195                                    load_preallocated_cache: false,
196                                    reset_non_granular: false,
197                                }
198                            };
199
200                            let return_raw_logits = scheduled.completion[0].return_raw_logits;
201                            assert!(
202                                scheduled
203                                    .completion
204                                    .iter()
205                                    .all(|seq| seq.return_raw_logits == return_raw_logits),
206                                "All sequences must either return raw logits, or not."
207                            );
208
209                            pipeline
210                                .step(
211                                    &mut scheduled.completion,
212                                    false,
213                                    return_raw_logits,
214                                    &mut *get_mut_arcmutex!(self.prefix_cacher),
215                                    self.disable_eos_stop,
216                                    rng.clone(),
217                                    CacheBackendMetadata::DefaultInstructions { pre_op, post_op },
218                                )
219                                .await
220                        };
221
222                        handle_pipeline_forward_error!(
223                            "completion step",
224                            res,
225                            &mut scheduled.completion,
226                            self.pipeline,
227                            'lp,
228                            self.prefix_cacher
229                        );
230
231                        self.logger.add_tokens_processed(scheduled.completion.len());
232
233                        last_completion_ids = current_completion_ids;
234                    }
235
236                    if !scheduled.prompt.is_empty() {
237                        let prompt_exec_time = {
238                            let mut pipeline = get_mut_arcmutex!(self.pipeline);
239
240                            // Run the prompt seqs
241                            let post_op = if !self.no_kv_cache {
242                                CacheInstruction::Out
243                            } else {
244                                CacheInstruction::Reset {
245                                    load_preallocated_cache: false,
246                                    reset_non_granular: false,
247                                }
248                            };
249
250                            let return_raw_logits = scheduled.prompt[0].return_raw_logits;
251                            assert!(
252                                scheduled
253                                    .prompt
254                                    .iter()
255                                    .all(|seq| seq.return_raw_logits == return_raw_logits),
256                                "All sequences must either return raw logits, or not."
257                            );
258
259                            // This comes from prefix caching
260                            // The invariant where all token offsets are the same is handled by the scheduler
261                            let pre_op = if scheduled.prompt[0].token_offset() != 0 {
262                                CacheInstruction::In
263                            } else {
264                                CacheInstruction::Reset {
265                                    load_preallocated_cache: true,
266                                    reset_non_granular: false,
267                                }
268                            };
269
270                            pipeline
271                                .step(
272                                    &mut scheduled.prompt,
273                                    true,
274                                    return_raw_logits,
275                                    &mut *get_mut_arcmutex!(self.prefix_cacher),
276                                    self.disable_eos_stop,
277                                    rng.clone(),
278                                    CacheBackendMetadata::DefaultInstructions { pre_op, post_op },
279                                )
280                                .await
281                        };
282
283                        let prompt_exec_time = handle_pipeline_forward_error!(
284                            "prompt step",
285                            prompt_exec_time,
286                            &mut scheduled.prompt,
287                            self.pipeline,
288                            'lp,
289                            self.prefix_cacher
290                        );
291
292                        let total_processed_tokens: usize = scheduled
293                            .prompt
294                            .iter()
295                            .map(|seq| seq.get_toks().len())
296                            .sum();
297                        self.logger.add_tokens_processed(total_processed_tokens);
298
299                        for seq in scheduled.prompt.iter_mut() {
300                            match seq.sequence_stepping_type() {
301                                SeqStepType::OneShot => {
302                                    seq.set_state(SequenceState::Done(StopReason::GeneratedImage))
303                                }
304                                SeqStepType::PromptAndDecode => {
305                                    seq.set_state(SequenceState::RunningCompletion)
306                                }
307                            }
308                            let now = SystemTime::now()
309                                .duration_since(UNIX_EPOCH)
310                                .expect("Time travel has occurred!")
311                                .as_millis();
312                            #[allow(clippy::cast_precision_loss)]
313                            let prompt_tok_per_sec =
314                                seq.len() as f32 / prompt_exec_time.as_secs_f32();
315                            seq.prompt_tok_per_sec = prompt_tok_per_sec;
316                            seq.prompt_timestamp = Some(now);
317                        }
318                        last_completion_ids = vec![];
319                    }
320
321                    if self.is_debug {
322                        let ms_from_last_run = run_start.elapsed().as_secs_f64();
323                        let total_len = scheduled.prompt.len() + scheduled.completion.len();
324                        if total_len > 0 {
325                            let prompt_lengths = scheduled
326                                .prompt
327                                .iter()
328                                .map(|seq| seq.len().to_string())
329                                .collect::<Vec<_>>()
330                                .join(", ");
331
332                            let completion_lengths = scheduled
333                                .completion
334                                .iter()
335                                .map(|seq| seq.len().to_string())
336                                .collect::<Vec<_>>()
337                                .join(", ");
338
339                            tracing::info!(
340                                "Prompt[{}] Completion[{}] - {}ms",
341                                prompt_lengths,
342                                completion_lengths,
343                                ms_from_last_run * 1000.,
344                            );
345                        }
346                    }
347                }
348                SchedulerOutput::PagedAttention { mut output } => {
349                    if !output.scheduled.is_empty() {
350                        let is_prompt = get_mut_arcmutex!(output.scheduled[0]).is_prompt();
351
352                        let mut guards = output
353                            .scheduled
354                            .iter_mut()
355                            .map(|seq| seq.lock().unwrap())
356                            .collect::<Vec<_>>();
357
358                        let mut guards_mut =
359                            guards.iter_mut().map(|seq| &mut **seq).collect::<Vec<_>>();
360
361                        let res = {
362                            let mut pipeline = get_mut_arcmutex!(self.pipeline);
363
364                            let block_size = scheduler.block_size().unwrap();
365
366                            let metadata = PagedAttentionMeta {
367                                block_size,
368                                sliding_window: pipeline.get_metadata().sliding_window,
369                                block_engine: scheduler.block_engine().unwrap(),
370                            };
371
372                            let return_raw_logits = guards_mut[0].return_raw_logits;
373                            assert!(
374                                guards_mut
375                                    .iter()
376                                    .all(|seq| seq.return_raw_logits == return_raw_logits),
377                                "All sequences must either return raw logits, or not."
378                            );
379
380                            pipeline
381                                .step(
382                                    &mut guards_mut,
383                                    is_prompt,
384                                    return_raw_logits,
385                                    &mut *get_mut_arcmutex!(self.prefix_cacher),
386                                    self.disable_eos_stop,
387                                    rng.clone(),
388                                    CacheBackendMetadata::PagedAttention {
389                                        metadata,
390                                        blocks_to_copy: output.blocks_to_copy,
391                                        blocks_to_swap_in: output.blocks_to_swap_in,
392                                        blocks_to_swap_out: output.blocks_to_swap_out,
393                                    },
394                                )
395                                .await
396                        };
397
398                        handle_pipeline_forward_error!(
399                            "step",
400                            res,
401                            &mut guards_mut,
402                            self.pipeline,
403                            'lp,
404                            self.prefix_cacher
405                        );
406
407                        let total_processed_tokens: usize = guards
408                            .iter()
409                            .map(|seq| {
410                                if seq.is_prompt() {
411                                    seq.get_toks().len()
412                                } else {
413                                    1
414                                }
415                            })
416                            .sum();
417                        self.logger.add_tokens_processed(total_processed_tokens);
418
419                        if self.is_debug {
420                            let ms_from_last_run = run_start.elapsed().as_secs_f64();
421                            let total_len = guards.len();
422                            if total_len > 0 {
423                                let lengths = guards
424                                    .iter()
425                                    .map(|seq| seq.len().to_string())
426                                    .collect::<Vec<_>>()
427                                    .join(", ");
428
429                                let (prompt_lengths, completion_lengths) = if is_prompt {
430                                    (lengths, "".to_string())
431                                } else {
432                                    ("".to_string(), lengths)
433                                };
434
435                                tracing::info!(
436                                    "Prompt[{}] Completion[{}] - {}ms",
437                                    prompt_lengths,
438                                    completion_lengths,
439                                    ms_from_last_run * 1000.,
440                                );
441                            }
442                        }
443
444                        if is_prompt {
445                            for mut seq in guards {
446                                let now = SystemTime::now()
447                                    .duration_since(UNIX_EPOCH)
448                                    .expect("Time travel has occurred!")
449                                    .as_millis();
450                                #[allow(clippy::cast_precision_loss)]
451                                let prompt_tok_per_sec =
452                                    seq.len() as f32 / (now - seq.timestamp()) as f32;
453                                seq.prompt_tok_per_sec = prompt_tok_per_sec * 1000.;
454                                seq.prompt_timestamp = Some(now);
455                            }
456                        }
457                    }
458                }
459            }
460
461            scheduler.free_finished_sequence_groups();
462        }
463    }
464
465    fn build_sequence_recognizer(
466        tok_env: &Option<TokEnv>,
467        constraint: &Constraint,
468    ) -> anyhow::Result<SequenceRecognizer> {
469        if let Some(grm) = llg_grammar_from_constraint(constraint)? {
470            let tok_env = tok_env
471                .as_ref()
472                .ok_or_else(|| anyhow::anyhow!("No token environment found."))?;
473            let llg = constraint_from_llg_grammar(tok_env.clone(), grm)?;
474            Ok(SequenceRecognizer::Llguidance(Box::new(llg)))
475        } else {
476            Ok(SequenceRecognizer::None)
477        }
478    }
479
480    fn replicate_request_to_daemons(&self, request: &Request) {
481        if !distributed::is_daemon() && mistralrs_quant::distributed::use_nccl() {
482            let name = distributed::ipc_name().unwrap();
483            let num_workers =
484                mistralrs_quant::distributed::get_global_tp_size_from_devices().unwrap() - 1;
485            let listener = ListenerOptions::new().name(name).create_sync().unwrap();
486
487            for _ in 0..num_workers {
488                let stream = listener.accept().unwrap();
489                let mut writer = BufWriter::new(stream);
490                let req = format!("{}\n", serde_json::to_string(&request).unwrap());
491                writer.write_all(req.as_bytes()).unwrap();
492            }
493        };
494    }
495}