use crate::{
    distributed,
    embedding::bert::BertPipeline,
    pipeline::{
        llg::{constraint_from_llg_grammar, llg_grammar_from_constraint},
        text_models_inputs_processor::PagedAttentionMeta,
        CacheBackendMetadata, CacheInstruction,
    },
    prefix_cacher::PrefixCacheManagerV2,
    response::CompletionChoice,
    scheduler::{Scheduler, SchedulerOutput},
    search,
    sequence::{SeqStepType, StopReason},
    tools, CompletionResponse, SchedulerConfig, DEBUG,
};
use interprocess::local_socket::{traits::Listener, ListenerOptions};
use llguidance::ParserFactory;
pub use logger::IntervalLogger;
use mistralrs_quant::RingConfig;
use once_cell::sync::Lazy;
use rand::SeedableRng;
use rand_isaac::Isaac64Rng;
use std::{
    collections::HashMap,
    io::{BufWriter, Write},
    net::TcpListener,
    ops::Deref,
    sync::{
        atomic::{AtomicBool, Ordering},
        Arc,
    },
    time::{Duration, Instant, SystemTime, UNIX_EPOCH},
};
use tokio::{
    sync::{
        mpsc::{error::TryRecvError, Receiver},
        Mutex,
    },
    task::JoinHandle,
};

use crate::{
    get_mut_arcmutex, handle_pipeline_forward_error,
    pipeline::Pipeline,
    request::Request,
    response::{ChatCompletionResponse, Choice, ResponseMessage},
    sequence::{SequenceRecognizer, SequenceState},
    Constraint,
};

mod add_request;
mod logger;
mod search_request;

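/// Out-of-band instructions delivered to a running engine via `ENGINE_INSTRUCTIONS`.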
pub enum EngineInstruction {
    Terminate,
}

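/// Embedding model used by the engine's web-search support. Defaults to
/// Snowflake Arctic Embed L; `Custom` takes a model ID string.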
#[derive(Debug, Default, Clone)]
pub enum BertEmbeddingModel {
    #[default]
    SnowflakeArcticEmbedL,
    Custom(String),
}

const SEED: u64 = 0;
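/// When set, every engine terminates all of its sequences on the next
/// scheduling step. Callers are responsible for resetting this flag afterwards.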
pub static TERMINATE_ALL_NEXT_STEP: AtomicBool = AtomicBool::new(false);

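/// Per-engine termination flags, keyed by the engine thread's ID.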
static ENGINE_TERMINATE_FLAGS: Lazy<
    std::sync::Mutex<HashMap<std::thread::ThreadId, Arc<AtomicBool>>>,
> = Lazy::new(|| std::sync::Mutex::new(HashMap::new()));

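/// Get or create the termination flag for the engine running on the current thread.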
pub fn get_engine_terminate_flag() -> Arc<AtomicBool> {
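    // One flag per engine thread, created lazily on first access.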
    let thread_id = std::thread::current().id();
    let mut flags = ENGINE_TERMINATE_FLAGS.lock().unwrap();
    flags
        .entry(thread_id)
        .or_insert_with(|| Arc::new(AtomicBool::new(false)))
        .clone()
}

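/// Returns `true` if the engine on the current thread should terminate its sequences.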
pub fn should_terminate_engine_sequences() -> bool {
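    // The global flag takes precedence over per-engine flags.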
    if TERMINATE_ALL_NEXT_STEP.load(Ordering::SeqCst) {
        return true;
    }
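    // Otherwise, consult this thread's engine-specific flag.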
    let thread_id = std::thread::current().id();
    if let Ok(flags) = ENGINE_TERMINATE_FLAGS.lock() {
        if let Some(flag) = flags.get(&thread_id) {
            return flag.load(Ordering::SeqCst);
        }
    }
    false
}

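/// Reset the termination flag for the engine on the current thread.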
pub fn reset_engine_terminate_flag() {
    let thread_id = std::thread::current().id();
    if let Ok(flags) = ENGINE_TERMINATE_FLAGS.lock() {
        if let Some(flag) = flags.get(&thread_id) {
            flag.store(false, Ordering::SeqCst);
        }
    }
}

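/// Pending instructions for each engine, keyed by engine ID; polled by the
/// engine's main loop.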
pub static ENGINE_INSTRUCTIONS: Lazy<std::sync::Mutex<HashMap<usize, Option<EngineInstruction>>>> =
    Lazy::new(|| std::sync::Mutex::new(HashMap::new()));

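/// The engine receives requests over a channel and drives them through the
/// scheduler and pipeline until completion.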
pub struct Engine {
    rx: Arc<Mutex<Receiver<Request>>>,
    pipeline: Arc<Mutex<dyn Pipeline>>,
    bert_pipeline: Arc<Mutex<Option<BertPipeline>>>,
    search_callback: Option<Arc<search::SearchCallback>>,
    tool_callbacks: tools::ToolCallbacks,
    tool_callbacks_with_tools: tools::ToolCallbacksWithTools,
    scheduler: Arc<Mutex<dyn Scheduler>>,
    id: Arc<Mutex<usize>>,
    no_kv_cache: bool,
    prefix_cacher: Arc<Mutex<PrefixCacheManagerV2>>,
    is_debug: bool,
    disable_eos_stop: bool,
    throughput_logging_enabled: bool,
    logger: IntervalLogger,
    handles: Arc<Mutex<Vec<JoinHandle<()>>>>,
}

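// Abort any background tasks spawned by the engine when it is dropped.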
impl Drop for Engine {
    fn drop(&mut self) {
        for handle in &*get_mut_arcmutex!(self.handles) {
            handle.abort();
        }
    }
}

impl Engine {
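    /// Create a new engine. KV-cache and prefix-cache settings are reconciled
    /// with the pipeline's metadata before construction.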
    #[allow(clippy::too_many_arguments)]
    pub fn new(
        rx: Receiver<Request>,
        pipeline: Arc<Mutex<dyn Pipeline>>,
        config: SchedulerConfig,
        mut no_kv_cache: bool,
        mut no_prefix_cache: bool,
        prefix_cache_n: usize,
        disable_eos_stop: bool,
        throughput_logging_enabled: bool,
        search_embedding_model: Option<BertEmbeddingModel>,
        search_callback: Option<Arc<search::SearchCallback>>,
        tool_callbacks: tools::ToolCallbacks,
        tool_callbacks_with_tools: tools::ToolCallbacksWithTools,
    ) -> anyhow::Result<Self> {
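        // A pipeline may itself disallow KV or prefix caching; honor that here.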
        no_kv_cache |= get_mut_arcmutex!(pipeline).get_metadata().no_kv_cache;

        no_prefix_cache = no_prefix_cache
            || no_kv_cache
            || get_mut_arcmutex!(pipeline).get_metadata().no_prefix_cache
            || prefix_cache_n == 0;

        let bert_pipeline = match search_embedding_model {
            Some(search_embedding_model) => Some(BertPipeline::new(
                search_embedding_model,
                &get_mut_arcmutex!(pipeline).device(),
            )?),
            None => None,
        };

        let scheduler = config.into_scheduler();
        let block_engine = get_mut_arcmutex!(scheduler).block_engine();

        Ok(Self {
            rx: Arc::new(Mutex::new(rx)),
            pipeline,
            bert_pipeline: Arc::new(Mutex::new(bert_pipeline)),
            search_callback,
            tool_callbacks,
            tool_callbacks_with_tools,
            scheduler: scheduler.clone(),
            id: Arc::new(Mutex::new(0)),
            no_kv_cache,
            prefix_cacher: Arc::new(Mutex::new(PrefixCacheManagerV2::new(
                prefix_cache_n,
                no_prefix_cache,
                block_engine,
            ))),
            is_debug: DEBUG.load(Ordering::Relaxed),
            disable_eos_stop,
            throughput_logging_enabled,
            logger: IntervalLogger::new(Duration::from_secs(5)),
            handles: Arc::new(Mutex::new(Vec::new())),
        })
    }

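    /// Run the engine's main loop: drain incoming requests, schedule sequences,
    /// and step the pipeline until the channel closes or a terminate
    /// instruction arrives.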
    pub async fn run(self: Arc<Self>) {
        if self.throughput_logging_enabled {
            self.logger.enable_logging();
        }

        let rng = Arc::new(std::sync::Mutex::new(Isaac64Rng::seed_from_u64(SEED)));
        let mut last_completion_ids: Vec<usize> = vec![];
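        // Main loop: one iteration per scheduling step.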
        'lp: loop {
            let should_terminate = || {
                matches!(
                    ENGINE_INSTRUCTIONS
                        .lock()
                        .expect("`ENGINE_INSTRUCTIONS` was poisoned")
                        .get(get_mut_arcmutex!(self.id).deref()),
                    Some(Some(EngineInstruction::Terminate))
                )
            };

            if should_terminate() {
                self.replicate_request_to_daemons(&Request::Terminate);
                break 'lp;
            }

            let mut channel_disconnected = false;
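            // Drain every queued request without blocking.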
            loop {
                let next_request = {
                    let mut rx = self.rx.lock().await;
                    rx.try_recv()
                };

                match next_request {
                    Ok(request) => {
                        self.replicate_request_to_daemons(&request);
                        if matches!(request, Request::Terminate) {
                            break 'lp;
                        }
                        self.clone().handle_request(request).await;
                    }
                    Err(TryRecvError::Empty) => break,
                    Err(TryRecvError::Disconnected) => {
                        channel_disconnected = true;
                        break;
                    }
                }
            }

            if channel_disconnected {
                break 'lp;
            }

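            // With nothing scheduled, block on the channel rather than spinning.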
            let scheduler_idle = {
                let scheduler = get_mut_arcmutex!(self.scheduler);
                scheduler.waiting_len() == 0 && scheduler.running_len() == 0
            };

            if scheduler_idle {
                if should_terminate() {
                    self.replicate_request_to_daemons(&Request::Terminate);
                    break 'lp;
                }

                let next_request = {
                    let mut rx = self.rx.lock().await;
                    rx.recv().await
                };

                match next_request {
                    Some(request) => {
                        self.replicate_request_to_daemons(&request);
                        if matches!(request, Request::Terminate) {
                            break 'lp;
                        }
                        self.clone().handle_request(request).await;
                        continue;
                    }
                    None => break 'lp,
                }
            }

            if TERMINATE_ALL_NEXT_STEP.load(Ordering::SeqCst) {
                self.replicate_request_to_daemons(&Request::TerminateAllSeqsNextStep);
            }

            let run_start = Instant::now();
            let mut scheduler = get_mut_arcmutex!(self.scheduler);
            let scheduled = scheduler.schedule(&self.logger);

            match scheduled {
                SchedulerOutput::DefaultScheduler {
                    output: mut scheduled,
                } => {
                    if !scheduled.completion.is_empty() {
                        let current_completion_ids: Vec<usize> =
                            scheduled.completion.iter().map(|seq| *seq.id()).collect();
                        let res = {
                            let mut pipeline = get_mut_arcmutex!(self.pipeline);
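                            // Reload the KV cache only if the running batch
                            // changed since the last step.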
                            let pre_op = if !self.no_kv_cache
                                && last_completion_ids != current_completion_ids
                            {
                                CacheInstruction::In
                            } else {
                                CacheInstruction::Nothing
                            };
                            let post_op = if !self.no_kv_cache {
                                CacheInstruction::Out
                            } else {
                                CacheInstruction::Reset {
                                    load_preallocated_cache: false,
                                    reset_non_granular: false,
                                }
                            };

                            let return_raw_logits = scheduled.completion[0].return_raw_logits;
                            assert!(
                                scheduled
                                    .completion
                                    .iter()
                                    .all(|seq| seq.return_raw_logits == return_raw_logits),
                                "All sequences must either return raw logits, or not."
                            );

                            pipeline
                                .step(
                                    &mut scheduled.completion,
                                    false,
                                    return_raw_logits,
                                    &mut *get_mut_arcmutex!(self.prefix_cacher),
                                    self.disable_eos_stop,
                                    rng.clone(),
                                    CacheBackendMetadata::DefaultInstructions { pre_op, post_op },
                                )
                                .await
                        };

                        handle_pipeline_forward_error!(
                            "completion step",
                            res,
                            &mut scheduled.completion,
                            self.pipeline,
                            'lp,
                            self.prefix_cacher
                        );

                        self.logger.add_tokens_processed(scheduled.completion.len());

                        last_completion_ids = current_completion_ids;
                    }

                    if !scheduled.prompt.is_empty() {
                        let prompt_exec_time = {
                            let mut pipeline = get_mut_arcmutex!(self.pipeline);

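                            // Save the cache after the prompt step unless KV
                            // caching is disabled.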
                            let post_op = if !self.no_kv_cache {
                                CacheInstruction::Out
                            } else {
                                CacheInstruction::Reset {
                                    load_preallocated_cache: false,
                                    reset_non_granular: false,
                                }
                            };

                            let return_raw_logits = scheduled.prompt[0].return_raw_logits;
                            assert!(
                                scheduled
                                    .prompt
                                    .iter()
                                    .all(|seq| seq.return_raw_logits == return_raw_logits),
                                "All sequences must either return raw logits, or not."
                            );

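                            // A nonzero token offset means a cache already
                            // exists for this sequence, so load it back in;
                            // otherwise reset, loading the preallocated cache.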
                            let pre_op = if scheduled.prompt[0].token_offset() != 0 {
                                CacheInstruction::In
                            } else {
                                CacheInstruction::Reset {
                                    load_preallocated_cache: true,
                                    reset_non_granular: false,
                                }
                            };

                            pipeline
                                .step(
                                    &mut scheduled.prompt,
                                    true,
                                    return_raw_logits,
                                    &mut *get_mut_arcmutex!(self.prefix_cacher),
                                    self.disable_eos_stop,
                                    rng.clone(),
                                    CacheBackendMetadata::DefaultInstructions { pre_op, post_op },
                                )
                                .await
                        };

                        let prompt_exec_time = handle_pipeline_forward_error!(
                            "prompt step",
                            prompt_exec_time,
                            &mut scheduled.prompt,
                            self.pipeline,
                            'lp,
                            self.prefix_cacher
                        );

                        let total_processed_tokens: usize = scheduled
                            .prompt
                            .iter()
                            .map(|seq| seq.get_toks().len())
                            .sum();
                        self.logger.add_tokens_processed(total_processed_tokens);

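                        // Move each sequence out of the prompt phase and record
                        // prompt timing. One-shot sequences (e.g. image
                        // generation) finish after this single step.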
                        for seq in scheduled.prompt.iter_mut() {
                            match seq.sequence_stepping_type() {
                                SeqStepType::OneShot => {
                                    seq.set_state(SequenceState::Done(StopReason::GeneratedImage))
                                }
                                SeqStepType::PromptAndDecode => {
                                    seq.set_state(SequenceState::RunningCompletion)
                                }
                            }
                            let now = SystemTime::now()
                                .duration_since(UNIX_EPOCH)
                                .expect("Time travel has occurred!")
                                .as_millis();
                            #[allow(clippy::cast_precision_loss)]
                            let prompt_tok_per_sec =
                                seq.len() as f32 / prompt_exec_time.as_secs_f32();
                            seq.prompt_tok_per_sec = prompt_tok_per_sec;
                            seq.prompt_timestamp = Some(now);
                            seq.total_prompt_time = Some(prompt_exec_time.as_millis());
                        }
                        last_completion_ids = vec![];
                    }

                    if self.is_debug {
                        let ms_from_last_run = run_start.elapsed().as_secs_f64();
                        let total_len = scheduled.prompt.len() + scheduled.completion.len();
                        if total_len > 0 {
                            let prompt_lengths = scheduled
                                .prompt
                                .iter()
                                .map(|seq| seq.len().to_string())
                                .collect::<Vec<_>>()
                                .join(", ");

                            let completion_lengths = scheduled
                                .completion
                                .iter()
                                .map(|seq| seq.len().to_string())
                                .collect::<Vec<_>>()
                                .join(", ");

                            tracing::info!(
                                "Prompt[{}] Completion[{}] - {}ms",
                                prompt_lengths,
                                completion_lengths,
                                ms_from_last_run * 1000.,
                            );
                        }
                    }
                }
                SchedulerOutput::PagedAttention { mut output } => {
                    if !output.scheduled.is_empty() {
                        let is_prompt = get_mut_arcmutex!(output.scheduled[0]).is_prompt();

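                        // Hold every sequence's lock for the duration of this step.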
                        let mut guards = output
                            .scheduled
                            .iter_mut()
                            .map(|seq| seq.lock().unwrap())
                            .collect::<Vec<_>>();

                        let mut guards_mut =
                            guards.iter_mut().map(|seq| &mut **seq).collect::<Vec<_>>();

                        let res = {
                            let mut pipeline = get_mut_arcmutex!(self.pipeline);

                            let block_size = scheduler.block_size().unwrap();

                            let metadata = PagedAttentionMeta {
                                block_size,
                                sliding_window: pipeline.get_metadata().sliding_window,
                                block_engine: scheduler.block_engine().unwrap(),
                            };

                            let return_raw_logits = guards_mut[0].return_raw_logits;
                            assert!(
                                guards_mut
                                    .iter()
                                    .all(|seq| seq.return_raw_logits == return_raw_logits),
                                "All sequences must either return raw logits, or not."
                            );

                            pipeline
                                .step(
                                    &mut guards_mut,
                                    is_prompt,
                                    return_raw_logits,
                                    &mut *get_mut_arcmutex!(self.prefix_cacher),
                                    self.disable_eos_stop,
                                    rng.clone(),
                                    CacheBackendMetadata::PagedAttention {
                                        metadata,
                                        blocks_to_copy: output.blocks_to_copy,
                                    },
                                )
                                .await
                        };

                        handle_pipeline_forward_error!(
                            "step",
                            res,
                            &mut guards_mut,
                            self.pipeline,
                            'lp,
                            self.prefix_cacher
                        );

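                        // Prompt sequences process their full prompt; decode
                        // sequences process one token each.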
                        let total_processed_tokens: usize = guards
                            .iter()
                            .map(|seq| {
                                if seq.is_prompt() {
                                    seq.get_toks().len()
                                } else {
                                    1
                                }
                            })
                            .sum();
                        self.logger.add_tokens_processed(total_processed_tokens);

                        if self.is_debug {
                            let ms_from_last_run = run_start.elapsed().as_secs_f64();
                            let total_len = guards.len();
                            if total_len > 0 {
                                let lengths = guards
                                    .iter()
                                    .map(|seq| seq.len().to_string())
                                    .collect::<Vec<_>>()
                                    .join(", ");

                                let (prompt_lengths, completion_lengths) = if is_prompt {
                                    (lengths, "".to_string())
                                } else {
                                    ("".to_string(), lengths)
                                };

                                tracing::info!(
                                    "Prompt[{}] Completion[{}] - {}ms",
                                    prompt_lengths,
                                    completion_lengths,
                                    ms_from_last_run * 1000.,
                                );
                            }
                        }

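                        // Record prompt throughput and timing now that the
                        // prompt step is done.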
                        if is_prompt {
                            for mut seq in guards {
                                let now = SystemTime::now()
                                    .duration_since(UNIX_EPOCH)
                                    .expect("Time travel has occurred!")
                                    .as_millis();
                                #[allow(clippy::cast_precision_loss)]
                                let prompt_tok_per_sec =
                                    seq.len() as f32 / (now - seq.timestamp()) as f32;
                                seq.prompt_tok_per_sec = prompt_tok_per_sec * 1000.;
                                seq.prompt_timestamp = Some(now);
                                seq.total_prompt_time = Some(now - seq.timestamp());
                            }
                        }
                    }
                }
            }

            scheduler.free_finished_sequence_groups();
        }
    }

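    /// Build a recognizer for the request's grammar constraint, if any, using
    /// the pipeline's llguidance parser factory.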
    fn build_sequence_recognizer(
        factory: &Option<Arc<ParserFactory>>,
        constraint: &Constraint,
    ) -> anyhow::Result<SequenceRecognizer> {
        if let Some(grm) = llg_grammar_from_constraint(constraint)? {
            let factory = factory
                .as_ref()
                .ok_or_else(|| anyhow::anyhow!("No token environment (llg_factory) found."))?;
            let llg = constraint_from_llg_grammar(factory, grm)?;
            Ok(SequenceRecognizer::Llguidance(Box::new(llg)))
        } else {
            Ok(SequenceRecognizer::None)
        }
    }

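    /// When running distributed and not as a daemon, forward the request to
    /// each worker: over a local IPC socket for NCCL, or over TCP for the ring
    /// backend.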
    fn replicate_request_to_daemons(&self, request: &Request) {
        if !distributed::is_daemon() && mistralrs_quant::distributed::use_nccl() {
            let name = distributed::ipc_name().unwrap();
            let num_workers =
                mistralrs_quant::distributed::get_global_tp_size_from_devices().unwrap() - 1;
            let listener = ListenerOptions::new().name(name).create_sync().unwrap();

            for _ in 0..num_workers {
                let stream = listener.accept().unwrap();
                let mut writer = BufWriter::new(stream);
                let req = format!("{}\n", serde_json::to_string(&request).unwrap());
                writer.write_all(req.as_bytes()).unwrap();
            }
        } else if !distributed::is_daemon() && cfg!(feature = "ring") {
            let num_workers =
                mistralrs_quant::distributed::get_global_tp_size_from_devices().unwrap() - 1;
            let master_port = RingConfig::load().master_port;
            let listener =
                TcpListener::bind(format!("0.0.0.0:{master_port}")).expect("bind replicator");

            for _ in 0..num_workers {
                let (stream, _) = listener.accept().unwrap();
                let mut writer = BufWriter::new(stream);
                let req = format!("{}\n", serde_json::to_string(&request).unwrap());
                writer.write_all(req.as_bytes()).unwrap();
            }
        }
    }
}