use crate::{
    distributed,
    embedding::bert::BertPipeline,
    pipeline::{
        llg::{constraint_from_llg_grammar, llg_grammar_from_constraint},
        text_models_inputs_processor::PagedAttentionMeta,
        CacheBackendMetadata, CacheInstruction,
    },
    prefix_cacher::PrefixCacheManagerV2,
    response::CompletionChoice,
    scheduler::{Scheduler, SchedulerOutput},
    search,
    sequence::{SeqStepType, StopReason},
    tools, CompletionResponse, SchedulerConfig, DEBUG,
};
use interprocess::local_socket::{traits::Listener, ListenerOptions};
use llguidance::ParserFactory;
pub use logger::IntervalLogger;
use mistralrs_quant::RingConfig;
use once_cell::sync::Lazy;
use rand::SeedableRng;
use rand_isaac::Isaac64Rng;
use std::{
    collections::HashMap,
    io::{BufWriter, Write},
    net::TcpListener,
    ops::Deref,
    sync::{
        atomic::{AtomicBool, Ordering},
        Arc,
    },
    time::{Duration, Instant, SystemTime, UNIX_EPOCH},
};
use tokio::{
    sync::{mpsc::Receiver, Mutex},
    task::JoinHandle,
};

use crate::{
    get_mut_arcmutex, handle_pipeline_forward_error,
    pipeline::Pipeline,
    request::Request,
    response::{ChatCompletionResponse, Choice, ResponseMessage},
    sequence::{SequenceRecognizer, SequenceState},
    Constraint,
};

mod add_request;
mod logger;
mod search_request;

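/// Instruction delivered to a running [`Engine`] via [`ENGINE_INSTRUCTIONS`].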
pub enum EngineInstruction {
    Terminate,
}

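/// Embedding model used by the search pipeline: the default Snowflake Arctic
/// Embed L model, or a custom model identified by a string.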
#[derive(Debug, Default, Clone)]
pub enum BertEmbeddingModel {
    #[default]
    SnowflakeArcticEmbedL,
    Custom(String),
}

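/// Fixed seed for the sampling RNG, for reproducibility.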
const SEED: u64 = 0;
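/// When set, every engine terminates its running sequences at the next scheduling step.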
pub static TERMINATE_ALL_NEXT_STEP: AtomicBool = AtomicBool::new(false);

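/// Per-engine termination flags, keyed by the ID of the thread the engine runs on.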
static ENGINE_TERMINATE_FLAGS: Lazy<
    std::sync::Mutex<HashMap<std::thread::ThreadId, Arc<AtomicBool>>>,
> = Lazy::new(|| std::sync::Mutex::new(HashMap::new()));

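/// Get (creating it if needed) the termination flag for the engine on the current thread.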
pub fn get_engine_terminate_flag() -> Arc<AtomicBool> {
    let thread_id = std::thread::current().id();
    let mut flags = ENGINE_TERMINATE_FLAGS.lock().unwrap();
    flags
        .entry(thread_id)
        .or_insert_with(|| Arc::new(AtomicBool::new(false)))
        .clone()
}

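/// Returns `true` if the global terminate flag or the current thread's engine flag is set.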
pub fn should_terminate_engine_sequences() -> bool {
    // The global flag applies to all engines.
    if TERMINATE_ALL_NEXT_STEP.load(Ordering::SeqCst) {
        return true;
    }
    // Otherwise, check this thread's engine-specific flag.
    let thread_id = std::thread::current().id();
    if let Ok(flags) = ENGINE_TERMINATE_FLAGS.lock() {
        if let Some(flag) = flags.get(&thread_id) {
            return flag.load(Ordering::SeqCst);
        }
    }
    false
}

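/// Clear the termination flag for the engine on the current thread, if one exists.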
pub fn reset_engine_terminate_flag() {
    let thread_id = std::thread::current().id();
    if let Ok(flags) = ENGINE_TERMINATE_FLAGS.lock() {
        if let Some(flag) = flags.get(&thread_id) {
            flag.store(false, Ordering::SeqCst);
        }
    }
}

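/// Pending instructions for each engine, keyed by engine ID; polled at the top of
/// each scheduling loop iteration.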
pub static ENGINE_INSTRUCTIONS: Lazy<std::sync::Mutex<HashMap<usize, Option<EngineInstruction>>>> =
    Lazy::new(|| std::sync::Mutex::new(HashMap::new()));

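/// The request-processing engine: receives [`Request`]s over a channel, schedules
/// sequences, and steps the underlying [`Pipeline`], optionally with prefix caching
/// and web search support.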
pub struct Engine {
    rx: Arc<Mutex<Receiver<Request>>>,
    pipeline: Arc<Mutex<dyn Pipeline>>,
    bert_pipeline: Arc<Mutex<Option<BertPipeline>>>,
    search_callback: Option<Arc<search::SearchCallback>>,
    tool_callbacks: tools::ToolCallbacks,
    tool_callbacks_with_tools: tools::ToolCallbacksWithTools,
    scheduler: Arc<Mutex<dyn Scheduler>>,
    id: Arc<Mutex<usize>>,
    truncate_sequence: bool,
    no_kv_cache: bool,
    prefix_cacher: Arc<Mutex<PrefixCacheManagerV2>>,
    is_debug: bool,
    disable_eos_stop: bool,
    throughput_logging_enabled: bool,
    logger: IntervalLogger,
    handles: Arc<Mutex<Vec<JoinHandle<()>>>>,
}

impl Drop for Engine {
    fn drop(&mut self) {
        // Abort any background tasks spawned by this engine.
        for handle in &*get_mut_arcmutex!(self.handles) {
            handle.abort();
        }
    }
}

impl Engine {
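    /// Create a new engine over the given pipeline and scheduler config. KV caching
    /// and prefix caching are additionally disabled if the pipeline's metadata
    /// requires it, and a BERT pipeline is loaded when a search embedding model is given.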
    #[allow(clippy::too_many_arguments)]
    pub fn new(
        rx: Receiver<Request>,
        pipeline: Arc<Mutex<dyn Pipeline>>,
        config: SchedulerConfig,
        truncate_sequence: bool,
        mut no_kv_cache: bool,
        mut no_prefix_cache: bool,
        prefix_cache_n: usize,
        disable_eos_stop: bool,
        throughput_logging_enabled: bool,
        search_embedding_model: Option<BertEmbeddingModel>,
        search_callback: Option<Arc<search::SearchCallback>>,
        tool_callbacks: tools::ToolCallbacks,
        tool_callbacks_with_tools: tools::ToolCallbacksWithTools,
    ) -> anyhow::Result<Self> {
        no_kv_cache |= get_mut_arcmutex!(pipeline).get_metadata().no_kv_cache;

        no_prefix_cache = no_prefix_cache
            || no_kv_cache
            || get_mut_arcmutex!(pipeline).get_metadata().no_prefix_cache
            || prefix_cache_n == 0;

        let bert_pipeline = match search_embedding_model {
            Some(search_embedding_model) => Some(BertPipeline::new(
                search_embedding_model,
                &get_mut_arcmutex!(pipeline).device(),
            )?),
            None => None,
        };

        let scheduler = config.into_scheduler();
        let block_engine = get_mut_arcmutex!(scheduler).block_engine();

        Ok(Self {
            rx: Arc::new(Mutex::new(rx)),
            pipeline,
            bert_pipeline: Arc::new(Mutex::new(bert_pipeline)),
            search_callback,
            tool_callbacks,
            tool_callbacks_with_tools,
            scheduler: scheduler.clone(),
            id: Arc::new(Mutex::new(0)),
            truncate_sequence,
            no_kv_cache,
            prefix_cacher: Arc::new(Mutex::new(PrefixCacheManagerV2::new(
                prefix_cache_n,
                no_prefix_cache,
                block_engine,
            ))),
            is_debug: DEBUG.load(Ordering::Relaxed),
            disable_eos_stop,
            throughput_logging_enabled,
            logger: IntervalLogger::new(Duration::from_secs(5)),
            handles: Arc::new(Mutex::new(Vec::new())),
        })
    }

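    /// Drive the engine's main loop: drain incoming requests, schedule sequences,
    /// and step the pipeline, until a `Terminate` request or instruction arrives.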
    pub async fn run(self: Arc<Self>) {
        if self.throughput_logging_enabled {
            self.logger.enable_logging();
        }

        let rng = Arc::new(std::sync::Mutex::new(Isaac64Rng::seed_from_u64(SEED)));
        let mut last_completion_ids: Vec<usize> = vec![];
        'lp: loop {
            // Handle any instruction addressed to this engine's id.
            if matches!(
                ENGINE_INSTRUCTIONS
                    .lock()
                    .expect("`ENGINE_INSTRUCTIONS` was poisoned")
                    .get(get_mut_arcmutex!(self.id).deref()),
                Some(Some(EngineInstruction::Terminate))
            ) {
                self.replicate_request_to_daemons(&Request::Terminate);
                break 'lp;
            }

            // Drain all pending requests before scheduling the next step.
            while let Ok(request) = get_mut_arcmutex!(self.rx).try_recv() {
                self.replicate_request_to_daemons(&request);
                if matches!(request, Request::Terminate) {
                    break 'lp;
                }
                self.clone().handle_request(request).await;
            }

            if TERMINATE_ALL_NEXT_STEP.load(Ordering::SeqCst) {
                self.replicate_request_to_daemons(&Request::TerminateAllSeqsNextStep);
            }

            let run_start = Instant::now();
            let mut scheduler = get_mut_arcmutex!(self.scheduler);
            let scheduled = scheduler.schedule(&self.logger);

            match scheduled {
                SchedulerOutput::DefaultScheduler {
                    output: mut scheduled,
                } => {
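                    // Completion (decode) phase: each scheduled sequence generates one token.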
                    if !scheduled.completion.is_empty() {
                        let current_completion_ids: Vec<usize> =
                            scheduled.completion.iter().map(|seq| *seq.id()).collect();
                        let res = {
                            let mut pipeline = get_mut_arcmutex!(self.pipeline);
                            // Only load the KV cache back in if the set of running
                            // sequences changed since the last step.
                            let pre_op = if !self.no_kv_cache
                                && last_completion_ids != current_completion_ids
                            {
                                CacheInstruction::In
                            } else {
                                CacheInstruction::Nothing
                            };
                            let post_op = if !self.no_kv_cache {
                                CacheInstruction::Out
                            } else {
                                CacheInstruction::Reset {
                                    load_preallocated_cache: false,
                                    reset_non_granular: false,
                                }
                            };

                            let return_raw_logits = scheduled.completion[0].return_raw_logits;
                            assert!(
                                scheduled
                                    .completion
                                    .iter()
                                    .all(|seq| seq.return_raw_logits == return_raw_logits),
                                "All sequences must either return raw logits, or not."
                            );

                            pipeline
                                .step(
                                    &mut scheduled.completion,
                                    false,
                                    return_raw_logits,
                                    &mut *get_mut_arcmutex!(self.prefix_cacher),
                                    self.disable_eos_stop,
                                    rng.clone(),
                                    CacheBackendMetadata::DefaultInstructions { pre_op, post_op },
                                )
                                .await
                        };

                        handle_pipeline_forward_error!(
                            "completion step",
                            res,
                            &mut scheduled.completion,
                            self.pipeline,
                            'lp,
                            self.prefix_cacher
                        );

                        self.logger.add_tokens_processed(scheduled.completion.len());

                        last_completion_ids = current_completion_ids;
                    }

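                    // Prompt (prefill) phase: process full prompts and record timing stats.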
                    if !scheduled.prompt.is_empty() {
                        let prompt_exec_time = {
                            let mut pipeline = get_mut_arcmutex!(self.pipeline);

                            let post_op = if !self.no_kv_cache {
                                CacheInstruction::Out
                            } else {
                                CacheInstruction::Reset {
                                    load_preallocated_cache: false,
                                    reset_non_granular: false,
                                }
                            };

                            let return_raw_logits = scheduled.prompt[0].return_raw_logits;
                            assert!(
                                scheduled
                                    .prompt
                                    .iter()
                                    .all(|seq| seq.return_raw_logits == return_raw_logits),
                                "All sequences must either return raw logits, or not."
                            );

                            // A nonzero token offset means part of this prompt is already
                            // cached (e.g. a prefix cache hit), so load it rather than reset.
                            let pre_op = if scheduled.prompt[0].token_offset() != 0 {
                                CacheInstruction::In
                            } else {
                                CacheInstruction::Reset {
                                    load_preallocated_cache: true,
                                    reset_non_granular: false,
                                }
                            };

                            pipeline
                                .step(
                                    &mut scheduled.prompt,
                                    true,
                                    return_raw_logits,
                                    &mut *get_mut_arcmutex!(self.prefix_cacher),
                                    self.disable_eos_stop,
                                    rng.clone(),
                                    CacheBackendMetadata::DefaultInstructions { pre_op, post_op },
                                )
                                .await
                        };

                        let prompt_exec_time = handle_pipeline_forward_error!(
                            "prompt step",
                            prompt_exec_time,
                            &mut scheduled.prompt,
                            self.pipeline,
                            'lp,
                            self.prefix_cacher
                        );

                        let total_processed_tokens: usize = scheduled
                            .prompt
                            .iter()
                            .map(|seq| seq.get_toks().len())
                            .sum();
                        self.logger.add_tokens_processed(total_processed_tokens);

                        for seq in scheduled.prompt.iter_mut() {
                            match seq.sequence_stepping_type() {
                                SeqStepType::OneShot => {
                                    seq.set_state(SequenceState::Done(StopReason::GeneratedImage))
                                }
                                SeqStepType::PromptAndDecode => {
                                    seq.set_state(SequenceState::RunningCompletion)
                                }
                            }
                            let now = SystemTime::now()
                                .duration_since(UNIX_EPOCH)
                                .expect("Time travel has occurred!")
                                .as_millis();
                            #[allow(clippy::cast_precision_loss)]
                            let prompt_tok_per_sec =
                                seq.len() as f32 / prompt_exec_time.as_secs_f32();
                            seq.prompt_tok_per_sec = prompt_tok_per_sec;
                            seq.prompt_timestamp = Some(now);
                            seq.total_prompt_time = Some(prompt_exec_time.as_millis());
                        }
                        last_completion_ids = vec![];
                    }

                    if self.is_debug {
                        let secs_from_last_run = run_start.elapsed().as_secs_f64();
                        let total_len = scheduled.prompt.len() + scheduled.completion.len();
                        if total_len > 0 {
                            let prompt_lengths = scheduled
                                .prompt
                                .iter()
                                .map(|seq| seq.len().to_string())
                                .collect::<Vec<_>>()
                                .join(", ");

                            let completion_lengths = scheduled
                                .completion
                                .iter()
                                .map(|seq| seq.len().to_string())
                                .collect::<Vec<_>>()
                                .join(", ");

                            tracing::info!(
                                "Prompt[{}] Completion[{}] - {}ms",
                                prompt_lengths,
                                completion_lengths,
                                secs_from_last_run * 1000.,
                            );
                        }
                    }
                }
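                // PagedAttention scheduler: sequences share a paged KV cache managed
                // by a block engine, with block-level copy/swap instructions.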
                SchedulerOutput::PagedAttention { mut output } => {
                    if !output.scheduled.is_empty() {
                        let is_prompt = get_mut_arcmutex!(output.scheduled[0]).is_prompt();

                        // Lock every scheduled sequence for the duration of this step.
                        let mut guards = output
                            .scheduled
                            .iter_mut()
                            .map(|seq| seq.lock().unwrap())
                            .collect::<Vec<_>>();

                        let mut guards_mut =
                            guards.iter_mut().map(|seq| &mut **seq).collect::<Vec<_>>();

                        let res = {
                            let mut pipeline = get_mut_arcmutex!(self.pipeline);

                            let block_size = scheduler.block_size().unwrap();

                            let metadata = PagedAttentionMeta {
                                block_size,
                                sliding_window: pipeline.get_metadata().sliding_window,
                                block_engine: scheduler.block_engine().unwrap(),
                            };

                            let return_raw_logits = guards_mut[0].return_raw_logits;
                            assert!(
                                guards_mut
                                    .iter()
                                    .all(|seq| seq.return_raw_logits == return_raw_logits),
                                "All sequences must either return raw logits, or not."
                            );

                            pipeline
                                .step(
                                    &mut guards_mut,
                                    is_prompt,
                                    return_raw_logits,
                                    &mut *get_mut_arcmutex!(self.prefix_cacher),
                                    self.disable_eos_stop,
                                    rng.clone(),
                                    CacheBackendMetadata::PagedAttention {
                                        metadata,
                                        blocks_to_copy: output.blocks_to_copy,
                                        blocks_to_swap_in: output.blocks_to_swap_in,
                                        blocks_to_swap_out: output.blocks_to_swap_out,
                                    },
                                )
                                .await
                        };

                        handle_pipeline_forward_error!(
                            "step",
                            res,
                            &mut guards_mut,
                            self.pipeline,
                            'lp,
                            self.prefix_cacher
                        );

                        // Prompts count their full token length; decodes count one new token each.
                        let total_processed_tokens: usize = guards
                            .iter()
                            .map(|seq| {
                                if seq.is_prompt() {
                                    seq.get_toks().len()
                                } else {
                                    1
                                }
                            })
                            .sum();
                        self.logger.add_tokens_processed(total_processed_tokens);

                        if self.is_debug {
                            let secs_from_last_run = run_start.elapsed().as_secs_f64();
                            let total_len = guards.len();
                            if total_len > 0 {
                                let lengths = guards
                                    .iter()
                                    .map(|seq| seq.len().to_string())
                                    .collect::<Vec<_>>()
                                    .join(", ");

                                let (prompt_lengths, completion_lengths) = if is_prompt {
                                    (lengths, "".to_string())
                                } else {
                                    ("".to_string(), lengths)
                                };

                                tracing::info!(
                                    "Prompt[{}] Completion[{}] - {}ms",
                                    prompt_lengths,
                                    completion_lengths,
                                    secs_from_last_run * 1000.,
                                );
                            }
                        }

                        if is_prompt {
                            for mut seq in guards {
                                let now = SystemTime::now()
                                    .duration_since(UNIX_EPOCH)
                                    .expect("Time travel has occurred!")
                                    .as_millis();
                                #[allow(clippy::cast_precision_loss)]
                                let prompt_tok_per_sec =
                                    seq.len() as f32 / (now - seq.timestamp()) as f32;
                                // Elapsed time is in ms, so scale tokens/ms to tokens/s.
                                seq.prompt_tok_per_sec = prompt_tok_per_sec * 1000.;
                                seq.prompt_timestamp = Some(now);
                                seq.total_prompt_time = Some(now - seq.timestamp());
                            }
                        }
                    }
                }
            }

            scheduler.free_finished_sequence_groups();
        }
    }

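    /// Build the token-level recognizer for a constraint, using the llguidance
    /// parser factory when the constraint maps to a grammar.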
    fn build_sequence_recognizer(
        factory: &Option<Arc<ParserFactory>>,
        constraint: &Constraint,
    ) -> anyhow::Result<SequenceRecognizer> {
        if let Some(grm) = llg_grammar_from_constraint(constraint)? {
            let factory = factory
                .as_ref()
                .ok_or_else(|| anyhow::anyhow!("No token environment (llg_factory) found."))?;
            let llg = constraint_from_llg_grammar(factory, grm)?;
            Ok(SequenceRecognizer::Llguidance(Box::new(llg)))
        } else {
            Ok(SequenceRecognizer::None)
        }
    }

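    /// On the main (non-daemon) process in a multi-node setup (NCCL or ring), serialize
    /// the request and forward it to every worker daemon so all ranks stay in sync.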
    fn replicate_request_to_daemons(&self, request: &Request) {
        if !distributed::is_daemon() && mistralrs_quant::distributed::use_nccl() {
            let name = distributed::ipc_name().unwrap();
            let num_workers =
                mistralrs_quant::distributed::get_global_tp_size_from_devices().unwrap() - 1;
            let listener = ListenerOptions::new().name(name).create_sync().unwrap();

            for _ in 0..num_workers {
                let stream = listener.accept().unwrap();
                let mut writer = BufWriter::new(stream);
                let req = format!("{}\n", serde_json::to_string(&request).unwrap());
                writer.write_all(req.as_bytes()).unwrap();
            }
        } else if !distributed::is_daemon() && cfg!(feature = "ring") {
            let num_workers =
                mistralrs_quant::distributed::get_global_tp_size_from_devices().unwrap() - 1;
            let master_port = RingConfig::load().master_port;
            let listener =
                TcpListener::bind(format!("0.0.0.0:{master_port}")).expect("bind replicator");

            for _ in 0..num_workers {
                let (stream, _) = listener.accept().unwrap();
                let mut writer = BufWriter::new(stream);
                let req = format!("{}\n", serde_json::to_string(&request).unwrap());
                writer.write_all(req.as_bytes()).unwrap();
            }
        }
    }
}