use crate::{
    distributed,
    embedding::bert::BertPipeline,
    pipeline::{
        llg::{constraint_from_llg_grammar, llg_grammar_from_constraint},
        text_models_inputs_processor::PagedAttentionMeta,
        CacheBackendMetadata, CacheInstruction,
    },
    prefix_cacher::PrefixCacheManagerV2,
    response::CompletionChoice,
    scheduler::{Scheduler, SchedulerOutput},
    sequence::{SeqStepType, StopReason},
    CompletionResponse, SchedulerConfig, DEBUG,
};
use interprocess::local_socket::{traits::Listener, ListenerOptions};
use llguidance::ParserFactory;
pub use logger::IntervalLogger;
use once_cell::sync::Lazy;
use rand::SeedableRng;
use rand_isaac::Isaac64Rng;
use std::{
    collections::HashMap,
    io::{BufWriter, Write},
    ops::Deref,
    sync::{
        atomic::{AtomicBool, Ordering},
        Arc,
    },
    time::{Duration, Instant, SystemTime, UNIX_EPOCH},
};
use tokio::{
    sync::{mpsc::Receiver, Mutex},
    task::JoinHandle,
};

use crate::{
    get_mut_arcmutex, handle_pipeline_forward_error,
    pipeline::Pipeline,
    request::Request,
    response::{ChatCompletionResponse, Choice, ResponseMessage},
    sequence::{SequenceRecognizer, SequenceState},
    Constraint,
};

mod add_request;
mod logger;
mod search_request;

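/// Instruction that can be posted to a running [`Engine`] through the global
/// `ENGINE_INSTRUCTIONS` map; the run loop polls for it once per iteration.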
pub enum EngineInstruction {
    Terminate,
}

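/// Embedding model used for search request pre-processing (see the
/// `search_request` module). Defaults to Snowflake Arctic Embed L; `Custom`
/// carries a user-provided model specifier.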
#[derive(Debug, Default, Clone)]
pub enum BertEmbeddingModel {
    #[default]
    SnowflakeArcticEmbedL,
    Custom(String),
}

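/// Fixed seed for the engine's sampling RNG, making token sampling
/// reproducible across runs.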
const SEED: u64 = 0;
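/// When set, a `TerminateAllSeqsNextStep` request is replicated to the daemons
/// on the next iteration of the run loop.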
pub static TERMINATE_ALL_NEXT_STEP: AtomicBool = AtomicBool::new(false);

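/// Global map from engine id to a pending [`EngineInstruction`], checked once
/// per iteration of the run loop.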
pub static ENGINE_INSTRUCTIONS: Lazy<std::sync::Mutex<HashMap<usize, Option<EngineInstruction>>>> =
    Lazy::new(|| std::sync::Mutex::new(HashMap::new()));

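/// The engine drives the scheduling and generation loop: it drains incoming
/// [`Request`]s, schedules sequences, and steps the [`Pipeline`] on the
/// scheduled batches.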
pub struct Engine {
    rx: Arc<Mutex<Receiver<Request>>>,
    pipeline: Arc<Mutex<dyn Pipeline>>,
    bert_pipeline: Arc<Mutex<Option<BertPipeline>>>,
    scheduler: Arc<Mutex<dyn Scheduler>>,
    id: Arc<Mutex<usize>>,
    truncate_sequence: bool,
    no_kv_cache: bool,
    prefix_cacher: Arc<Mutex<PrefixCacheManagerV2>>,
    is_debug: bool,
    disable_eos_stop: bool,
    throughput_logging_enabled: bool,
    logger: IntervalLogger,
    handles: Arc<Mutex<Vec<JoinHandle<()>>>>,
}

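// Abort any background tasks spawned by the engine so they do not outlive it.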
impl Drop for Engine {
    fn drop(&mut self) {
        for handle in &*get_mut_arcmutex!(self.handles) {
            handle.abort();
        }
    }
}

impl Engine {
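    /// Build a new `Engine`. `no_kv_cache` and `no_prefix_cache` are widened by
    /// the pipeline's own metadata, and a BERT pipeline is constructed only when
    /// a search embedding model is requested.
    ///
    /// A minimal usage sketch (not compiled here); `pipeline` and `config` are
    /// assumed to be built elsewhere:
    ///
    /// ```ignore
    /// let (tx, rx) = tokio::sync::mpsc::channel(100);
    /// let engine = Arc::new(Engine::new(
    ///     rx, pipeline, config, /* truncate_sequence */ false,
    ///     /* no_kv_cache */ false, /* no_prefix_cache */ false,
    ///     /* prefix_cache_n */ 16, /* disable_eos_stop */ false,
    ///     /* throughput_logging_enabled */ false, None,
    /// )?);
    /// tokio::spawn(engine.run()); // `tx` is then used to submit `Request`s
    /// ```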
    #[allow(clippy::too_many_arguments)]
    pub fn new(
        rx: Receiver<Request>,
        pipeline: Arc<Mutex<dyn Pipeline>>,
        config: SchedulerConfig,
        truncate_sequence: bool,
        mut no_kv_cache: bool,
        mut no_prefix_cache: bool,
        prefix_cache_n: usize,
        disable_eos_stop: bool,
        throughput_logging_enabled: bool,
        search_embedding_model: Option<BertEmbeddingModel>,
    ) -> anyhow::Result<Self> {
        no_kv_cache |= get_mut_arcmutex!(pipeline).get_metadata().no_kv_cache;

        no_prefix_cache = no_prefix_cache
            || no_kv_cache
            || get_mut_arcmutex!(pipeline).get_metadata().no_prefix_cache;

        let bert_pipeline = match search_embedding_model {
            Some(search_embedding_model) => Some(BertPipeline::new(
                search_embedding_model,
                &get_mut_arcmutex!(pipeline).device(),
            )?),
            None => None,
        };

        let scheduler = config.into_scheduler();
        let block_engine = get_mut_arcmutex!(scheduler).block_engine();

        Ok(Self {
            rx: Arc::new(Mutex::new(rx)),
            pipeline,
            bert_pipeline: Arc::new(Mutex::new(bert_pipeline)),
            scheduler: scheduler.clone(),
            id: Arc::new(Mutex::new(0)),
            truncate_sequence,
            no_kv_cache,
            prefix_cacher: Arc::new(Mutex::new(PrefixCacheManagerV2::new(
                prefix_cache_n,
                no_prefix_cache,
                block_engine,
            ))),
            is_debug: DEBUG.load(Ordering::Relaxed),
            disable_eos_stop,
            throughput_logging_enabled,
            logger: IntervalLogger::new(Duration::from_secs(5)),
            handles: Arc::new(Mutex::new(Vec::new())),
        })
    }

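    /// Main engine loop: poll `ENGINE_INSTRUCTIONS`, drain incoming requests,
    /// schedule a batch, step the pipeline, and free finished sequences.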
    pub async fn run(self: Arc<Self>) {
        if self.throughput_logging_enabled {
            self.logger.enable_logging();
        }

        let rng = Arc::new(std::sync::Mutex::new(Isaac64Rng::seed_from_u64(SEED)));
        let mut last_completion_ids: Vec<usize> = vec![];
        'lp: loop {
            if matches!(
                ENGINE_INSTRUCTIONS
                    .lock()
                    .expect("`ENGINE_INSTRUCTIONS` was poisoned")
                    .get(get_mut_arcmutex!(self.id).deref()),
                Some(Some(EngineInstruction::Terminate))
            ) {
                self.replicate_request_to_daemons(&Request::Terminate);
                break 'lp;
            }

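            // Drain all pending requests, replicating each to any NCCL daemons
            // before handling it locally.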
            while let Ok(request) = get_mut_arcmutex!(self.rx).try_recv() {
                self.replicate_request_to_daemons(&request);
                if matches!(request, Request::Terminate) {
                    break 'lp;
                }
                self.clone().handle_request(request).await;
            }

            if TERMINATE_ALL_NEXT_STEP.load(Ordering::SeqCst) {
                self.replicate_request_to_daemons(&Request::TerminateAllSeqsNextStep);
            }

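            // Run one scheduling pass and step the pipeline on its output.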
            let run_start = Instant::now();
            let mut scheduler = get_mut_arcmutex!(self.scheduler);
            let scheduled = scheduler.schedule(&self.logger);

            match scheduled {
                SchedulerOutput::DefaultScheduler {
                    output: mut scheduled,
                } => {
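                    // Decode step: run all in-flight completion sequences.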
                    if !scheduled.completion.is_empty() {
                        let current_completion_ids: Vec<usize> =
                            scheduled.completion.iter().map(|seq| *seq.id()).collect();
                        let res = {
                            let mut pipeline = get_mut_arcmutex!(self.pipeline);
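                            // Only reload the KV cache when the set of running
                            // sequences changed since the last decode step.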
                            let pre_op = if !self.no_kv_cache
                                && last_completion_ids != current_completion_ids
                            {
                                CacheInstruction::In
                            } else {
                                CacheInstruction::Nothing
                            };
                            let post_op = if !self.no_kv_cache {
                                CacheInstruction::Out
                            } else {
                                CacheInstruction::Reset {
                                    load_preallocated_cache: false,
                                    reset_non_granular: false,
                                }
                            };

                            let return_raw_logits = scheduled.completion[0].return_raw_logits;
                            assert!(
                                scheduled
                                    .completion
                                    .iter()
                                    .all(|seq| seq.return_raw_logits == return_raw_logits),
                                "All sequences must either return raw logits, or not."
                            );

                            pipeline
                                .step(
                                    &mut scheduled.completion,
                                    false,
                                    return_raw_logits,
                                    &mut *get_mut_arcmutex!(self.prefix_cacher),
                                    self.disable_eos_stop,
                                    rng.clone(),
                                    CacheBackendMetadata::DefaultInstructions { pre_op, post_op },
                                )
                                .await
                        };

                        handle_pipeline_forward_error!(
                            "completion step",
                            res,
                            &mut scheduled.completion,
                            self.pipeline,
                            'lp,
                            self.prefix_cacher
                        );

                        self.logger.add_tokens_processed(scheduled.completion.len());

                        last_completion_ids = current_completion_ids;
                    }

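                    // Prefill step: run newly scheduled prompt sequences.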
                    if !scheduled.prompt.is_empty() {
                        let prompt_exec_time = {
                            let mut pipeline = get_mut_arcmutex!(self.pipeline);

                            let post_op = if !self.no_kv_cache {
                                CacheInstruction::Out
                            } else {
                                CacheInstruction::Reset {
                                    load_preallocated_cache: false,
                                    reset_non_granular: false,
                                }
                            };

                            let return_raw_logits = scheduled.prompt[0].return_raw_logits;
                            assert!(
                                scheduled
                                    .prompt
                                    .iter()
                                    .all(|seq| seq.return_raw_logits == return_raw_logits),
                                "All sequences must either return raw logits, or not."
                            );

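                            // A nonzero token offset means the prompt continues
                            // from cached tokens, so load the cache rather than
                            // resetting it.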
                            let pre_op = if scheduled.prompt[0].token_offset() != 0 {
                                CacheInstruction::In
                            } else {
                                CacheInstruction::Reset {
                                    load_preallocated_cache: true,
                                    reset_non_granular: false,
                                }
                            };

                            pipeline
                                .step(
                                    &mut scheduled.prompt,
                                    true,
                                    return_raw_logits,
                                    &mut *get_mut_arcmutex!(self.prefix_cacher),
                                    self.disable_eos_stop,
                                    rng.clone(),
                                    CacheBackendMetadata::DefaultInstructions { pre_op, post_op },
                                )
                                .await
                        };

                        let prompt_exec_time = handle_pipeline_forward_error!(
                            "prompt step",
                            prompt_exec_time,
                            &mut scheduled.prompt,
                            self.pipeline,
                            'lp,
                            self.prefix_cacher
                        );

                        let total_processed_tokens: usize = scheduled
                            .prompt
                            .iter()
                            .map(|seq| seq.get_toks().len())
                            .sum();
                        self.logger.add_tokens_processed(total_processed_tokens);

                        for seq in scheduled.prompt.iter_mut() {
                            match seq.sequence_stepping_type() {
                                SeqStepType::OneShot => {
                                    seq.set_state(SequenceState::Done(StopReason::GeneratedImage))
                                }
                                SeqStepType::PromptAndDecode => {
                                    seq.set_state(SequenceState::RunningCompletion)
                                }
                            }
                            let now = SystemTime::now()
                                .duration_since(UNIX_EPOCH)
                                .expect("Time travel has occurred!")
                                .as_millis();
                            #[allow(clippy::cast_precision_loss)]
                            let prompt_tok_per_sec =
                                seq.len() as f32 / prompt_exec_time.as_secs_f32();
                            seq.prompt_tok_per_sec = prompt_tok_per_sec;
                            seq.prompt_timestamp = Some(now);
                            seq.total_prompt_time = Some(prompt_exec_time.as_millis());
                        }
                        last_completion_ids = vec![];
                    }

                    if self.is_debug {
                        let elapsed_ms = run_start.elapsed().as_secs_f64() * 1000.;
                        let total_len = scheduled.prompt.len() + scheduled.completion.len();
                        if total_len > 0 {
                            let prompt_lengths = scheduled
                                .prompt
                                .iter()
                                .map(|seq| seq.len().to_string())
                                .collect::<Vec<_>>()
                                .join(", ");

                            let completion_lengths = scheduled
                                .completion
                                .iter()
                                .map(|seq| seq.len().to_string())
                                .collect::<Vec<_>>()
                                .join(", ");

                            tracing::info!(
                                "Prompt[{}] Completion[{}] - {}ms",
                                prompt_lengths,
                                completion_lengths,
                                elapsed_ms,
                            );
                        }
                    }
                }
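                // PagedAttention scheduling: sequences share a paged KV cache,
                // with block swaps and copies handled by the block engine.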
                SchedulerOutput::PagedAttention { mut output } => {
                    if !output.scheduled.is_empty() {
                        let is_prompt = get_mut_arcmutex!(output.scheduled[0]).is_prompt();

                        let mut guards = output
                            .scheduled
                            .iter_mut()
                            .map(|seq| seq.lock().unwrap())
                            .collect::<Vec<_>>();

                        let mut guards_mut =
                            guards.iter_mut().map(|seq| &mut **seq).collect::<Vec<_>>();

                        let res = {
                            let mut pipeline = get_mut_arcmutex!(self.pipeline);

                            let block_size = scheduler.block_size().unwrap();

                            let metadata = PagedAttentionMeta {
                                block_size,
                                sliding_window: pipeline.get_metadata().sliding_window,
                                block_engine: scheduler.block_engine().unwrap(),
                            };

                            let return_raw_logits = guards_mut[0].return_raw_logits;
                            assert!(
                                guards_mut
                                    .iter()
                                    .all(|seq| seq.return_raw_logits == return_raw_logits),
                                "All sequences must either return raw logits, or not."
                            );

                            pipeline
                                .step(
                                    &mut guards_mut,
                                    is_prompt,
                                    return_raw_logits,
                                    &mut *get_mut_arcmutex!(self.prefix_cacher),
                                    self.disable_eos_stop,
                                    rng.clone(),
                                    CacheBackendMetadata::PagedAttention {
                                        metadata,
                                        blocks_to_copy: output.blocks_to_copy,
                                        blocks_to_swap_in: output.blocks_to_swap_in,
                                        blocks_to_swap_out: output.blocks_to_swap_out,
                                    },
                                )
                                .await
                        };

                        handle_pipeline_forward_error!(
                            "step",
                            res,
                            &mut guards_mut,
                            self.pipeline,
                            'lp,
                            self.prefix_cacher
                        );

                        let total_processed_tokens: usize = guards
                            .iter()
                            .map(|seq| {
                                if seq.is_prompt() {
                                    seq.get_toks().len()
                                } else {
                                    1
                                }
                            })
                            .sum();
                        self.logger.add_tokens_processed(total_processed_tokens);

                        if self.is_debug {
                            let elapsed_ms = run_start.elapsed().as_secs_f64() * 1000.;
                            let total_len = guards.len();
                            if total_len > 0 {
                                let lengths = guards
                                    .iter()
                                    .map(|seq| seq.len().to_string())
                                    .collect::<Vec<_>>()
                                    .join(", ");

                                let (prompt_lengths, completion_lengths) = if is_prompt {
                                    (lengths, "".to_string())
                                } else {
                                    ("".to_string(), lengths)
                                };

                                tracing::info!(
                                    "Prompt[{}] Completion[{}] - {}ms",
                                    prompt_lengths,
                                    completion_lengths,
                                    elapsed_ms,
                                );
                            }
                        }

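                        // Record prompt throughput stats now that prefill for
                        // these sequences has finished.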
                        if is_prompt {
                            for mut seq in guards {
                                let now = SystemTime::now()
                                    .duration_since(UNIX_EPOCH)
                                    .expect("Time travel has occurred!")
                                    .as_millis();
                                // `now` and `timestamp()` are in ms, so scale
                                // tokens/ms up to tokens/s.
                                #[allow(clippy::cast_precision_loss)]
                                let prompt_tok_per_sec =
                                    seq.len() as f32 / (now - seq.timestamp()) as f32;
                                seq.prompt_tok_per_sec = prompt_tok_per_sec * 1000.;
                                seq.prompt_timestamp = Some(now);
                                seq.total_prompt_time = Some(now - seq.timestamp());
                            }
                        }
                    }
                }
            }

            scheduler.free_finished_sequence_groups();
        }
    }

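    /// Build the recognizer for a sequence: if the request carries a grammar
    /// [`Constraint`], compile it through the llguidance `ParserFactory`;
    /// otherwise attach no recognizer.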
    fn build_sequence_recognizer(
        factory: &Option<Arc<ParserFactory>>,
        constraint: &Constraint,
    ) -> anyhow::Result<SequenceRecognizer> {
        if let Some(grm) = llg_grammar_from_constraint(constraint)? {
            let factory = factory
                .as_ref()
                .ok_or_else(|| anyhow::anyhow!("No token environment (llg_factory) found."))?;
            let llg = constraint_from_llg_grammar(factory, grm)?;
            Ok(SequenceRecognizer::Llguidance(Box::new(llg)))
        } else {
            Ok(SequenceRecognizer::None)
        }
    }

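    /// On the NCCL head process, forward the request as a line of JSON over a
    /// local IPC socket to each worker daemon; this is a no-op on the daemons
    /// themselves.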
    fn replicate_request_to_daemons(&self, request: &Request) {
        if !distributed::is_daemon() && mistralrs_quant::distributed::use_nccl() {
            let name = distributed::ipc_name().unwrap();
            let num_workers =
                mistralrs_quant::distributed::get_global_tp_size_from_devices().unwrap() - 1;
            let listener = ListenerOptions::new().name(name).create_sync().unwrap();

            for _ in 0..num_workers {
                let stream = listener.accept().unwrap();
                let mut writer = BufWriter::new(stream);
                let req = format!("{}\n", serde_json::to_string(&request).unwrap());
                writer.write_all(req.as_bytes()).unwrap();
            }
        }
    }
}