use crate::{
    distributed,
    embedding::bert::BertPipeline,
    pipeline::{
        llg::{constraint_from_llg_grammar, llg_grammar_from_constraint},
        text_models_inputs_processor::PagedAttentionMeta,
        CacheBackendMetadata, CacheInstruction,
    },
    prefix_cacher::PrefixCacheManagerV2,
    response::CompletionChoice,
    scheduler::{Scheduler, SchedulerOutput},
    sequence::{SeqStepType, StopReason},
    CompletionResponse, SchedulerConfig, DEBUG,
};
use interprocess::local_socket::{traits::Listener, ListenerOptions};
use llguidance::toktrie::TokEnv;
use logger::IntervalLogger;
use once_cell::sync::Lazy;
use rand::SeedableRng;
use rand_isaac::Isaac64Rng;
use std::{
    collections::HashMap,
    io::{BufWriter, Write},
    ops::Deref,
    sync::{
        atomic::{AtomicBool, Ordering},
        Arc,
    },
    time::{Duration, Instant, SystemTime, UNIX_EPOCH},
};
use tokio::{
    sync::{mpsc::Receiver, Mutex},
    task::JoinHandle,
};

use crate::{
    get_mut_arcmutex, handle_pipeline_forward_error,
    pipeline::Pipeline,
    request::Request,
    response::{ChatCompletionResponse, Choice, ResponseMessage},
    sequence::{SequenceRecognizer, SequenceState},
    Constraint,
};

mod add_request;
mod logger;

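/// Control instruction that can be issued to a running engine through
/// [`ENGINE_INSTRUCTIONS`], keyed by the engine's ID.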
pub enum EngineInstruction {
    Terminate,
}

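/// BERT embedding model used by the engine's search feature. Defaults to
/// Snowflake Arctic Embed L; `Custom` selects a model by its string identifier.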
52#[derive(Debug, Default, Clone)]
53pub enum BertEmbeddingModel {
55 #[default]
56 SnowflakeArcticEmbedL,
57 Custom(String),
58}
59
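/// Fixed seed for the engine's sampling RNG, keeping runs reproducible.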
const SEED: u64 = 0;
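/// When set, all running sequences are terminated on the next scheduling step.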
pub static TERMINATE_ALL_NEXT_STEP: AtomicBool = AtomicBool::new(false);

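/// Pending [`EngineInstruction`]s, keyed by engine ID.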
pub static ENGINE_INSTRUCTIONS: Lazy<std::sync::Mutex<HashMap<usize, Option<EngineInstruction>>>> =
    Lazy::new(|| std::sync::Mutex::new(HashMap::new()));

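/// The engine drives the main inference loop: it receives [`Request`]s over
/// `rx`, schedules the resulting sequences, and steps the pipeline to
/// generate tokens.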
pub struct Engine {
    rx: Arc<Mutex<Receiver<Request>>>,
    pipeline: Arc<Mutex<dyn Pipeline>>,
    bert_pipeline: Arc<Mutex<Option<BertPipeline>>>,
    scheduler: Arc<Mutex<dyn Scheduler>>,
    id: Arc<Mutex<usize>>,
    truncate_sequence: bool,
    no_kv_cache: bool,
    prefix_cacher: Arc<Mutex<PrefixCacheManagerV2>>,
    is_debug: bool,
    disable_eos_stop: bool,
    throughput_logging_enabled: bool,
    logger: IntervalLogger,
    handles: Arc<Mutex<Vec<JoinHandle<()>>>>,
}

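// Abort any background tasks spawned by this engine so they do not outlive it.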
impl Drop for Engine {
    fn drop(&mut self) {
        for handle in &*get_mut_arcmutex!(self.handles) {
            handle.abort();
        }
    }
}

impl Engine {
    #[allow(clippy::too_many_arguments)]
    pub fn new(
        rx: Receiver<Request>,
        pipeline: Arc<Mutex<dyn Pipeline>>,
        config: SchedulerConfig,
        truncate_sequence: bool,
        mut no_kv_cache: bool,
        mut no_prefix_cache: bool,
        prefix_cache_n: usize,
        disable_eos_stop: bool,
        throughput_logging_enabled: bool,
        search_embedding_model: Option<BertEmbeddingModel>,
    ) -> anyhow::Result<Self> {
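        // The pipeline may itself require running without a KV cache.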
        no_kv_cache |= get_mut_arcmutex!(pipeline).get_metadata().no_kv_cache;

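        // Prefix caching is unsupported with PagedAttention scheduling or when
        // the KV cache is disabled.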
        no_prefix_cache = matches!(config, SchedulerConfig::PagedAttentionMeta { .. })
            || no_prefix_cache
            || no_kv_cache;

        let bert_pipeline = match search_embedding_model {
            Some(search_embedding_model) => Some(BertPipeline::new(
                search_embedding_model,
                &get_mut_arcmutex!(pipeline).device(),
            )?),
            None => None,
        };

        Ok(Self {
            rx: Arc::new(Mutex::new(rx)),
            pipeline,
            bert_pipeline: Arc::new(Mutex::new(bert_pipeline)),
            scheduler: config.into_scheduler(),
            id: Arc::new(Mutex::new(0)),
            truncate_sequence,
            no_kv_cache,
            prefix_cacher: Arc::new(Mutex::new(PrefixCacheManagerV2::new(
                prefix_cache_n,
                no_prefix_cache,
            ))),
            is_debug: DEBUG.load(Ordering::Relaxed),
            disable_eos_stop,
            throughput_logging_enabled,
            logger: IntervalLogger::new(Duration::from_secs(5)),
            handles: Arc::new(Mutex::new(Vec::new())),
        })
    }

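    /// Main event loop: polls for engine instructions and incoming requests,
    /// then schedules and steps the pipeline until a terminate request or
    /// instruction arrives.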
    pub async fn run(self: Arc<Self>) {
        if self.throughput_logging_enabled {
            self.logger.enable_logging();
        }

        let rng = Arc::new(std::sync::Mutex::new(Isaac64Rng::seed_from_u64(SEED)));
        let mut last_completion_ids: Vec<usize> = vec![];
        'lp: loop {
            if matches!(
                ENGINE_INSTRUCTIONS
                    .lock()
                    .expect("`ENGINE_INSTRUCTIONS` was poisoned")
                    .get(get_mut_arcmutex!(self.id).deref()),
                Some(Some(EngineInstruction::Terminate))
            ) {
                self.replicate_request_to_daemons(&Request::Terminate);
                break 'lp;
            }

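            // Drain every request already queued, replicating each one to any
            // daemon workers before handling it.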
            while let Ok(request) = get_mut_arcmutex!(self.rx).try_recv() {
                self.replicate_request_to_daemons(&request);
                if matches!(request, Request::Terminate) {
                    break 'lp;
                }
                self.clone().handle_request(request).await;
            }

            if TERMINATE_ALL_NEXT_STEP.load(Ordering::SeqCst) {
                self.replicate_request_to_daemons(&Request::TerminateAllSeqsNextStep);
            }

            let run_start = Instant::now();
            let mut scheduler = get_mut_arcmutex!(self.scheduler);
            let scheduled = scheduler.schedule();

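            // Dispatch on the scheduler backend: the default scheduler splits
            // prompt and completion batches, while the PagedAttention scheduler
            // works on block-allocated sequence groups.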
            match scheduled {
                SchedulerOutput::DefaultScheduler {
                    output: mut scheduled,
                } => {
                    if !scheduled.completion.is_empty() {
                        let current_completion_ids: Vec<usize> =
                            scheduled.completion.iter().map(|seq| *seq.id()).collect();
                        let res = {
                            let mut pipeline = get_mut_arcmutex!(self.pipeline);
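                            // Load the KV cache only when the completion batch
                            // changed since the last step; otherwise it is
                            // already live on the device.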
                            let pre_op = if !self.no_kv_cache
                                && last_completion_ids != current_completion_ids
                            {
                                CacheInstruction::In
                            } else {
                                CacheInstruction::Nothing
                            };
                            let post_op = if !self.no_kv_cache {
                                CacheInstruction::Out
                            } else {
                                CacheInstruction::Reset {
                                    load_preallocated_cache: false,
                                    reset_non_granular: false,
                                }
                            };

                            let return_raw_logits = scheduled.completion[0].return_raw_logits;
                            assert!(
                                scheduled
                                    .completion
                                    .iter()
                                    .all(|seq| seq.return_raw_logits == return_raw_logits),
                                "All sequences must either return raw logits, or not."
                            );

                            pipeline
                                .step(
                                    &mut scheduled.completion,
                                    false,
                                    return_raw_logits,
                                    &mut *get_mut_arcmutex!(self.prefix_cacher),
                                    self.disable_eos_stop,
                                    rng.clone(),
                                    CacheBackendMetadata::DefaultInstructions { pre_op, post_op },
                                )
                                .await
                        };

                        handle_pipeline_forward_error!(
                            "completion step",
                            res,
                            &mut scheduled.completion,
                            self.pipeline,
                            'lp,
                            self.prefix_cacher
                        );

                        self.logger.add_tokens_processed(scheduled.completion.len());

                        last_completion_ids = current_completion_ids;
                    }

                    if !scheduled.prompt.is_empty() {
                        let prompt_exec_time = {
                            let mut pipeline = get_mut_arcmutex!(self.pipeline);

                            let post_op = if !self.no_kv_cache {
                                CacheInstruction::Out
                            } else {
                                CacheInstruction::Reset {
                                    load_preallocated_cache: false,
                                    reset_non_granular: false,
                                }
                            };

                            let return_raw_logits = scheduled.prompt[0].return_raw_logits;
                            assert!(
                                scheduled
                                    .prompt
                                    .iter()
                                    .all(|seq| seq.return_raw_logits == return_raw_logits),
                                "All sequences must either return raw logits, or not."
                            );

                            // A nonzero token offset means part of this prompt is
                            // already cached (e.g. restored by the prefix cacher),
                            // so load the existing cache instead of resetting to
                            // the preallocated one.
                            let pre_op = if scheduled.prompt[0].token_offset() != 0 {
                                CacheInstruction::In
                            } else {
                                CacheInstruction::Reset {
                                    load_preallocated_cache: true,
                                    reset_non_granular: false,
                                }
                            };

                            pipeline
                                .step(
                                    &mut scheduled.prompt,
                                    true,
                                    return_raw_logits,
                                    &mut *get_mut_arcmutex!(self.prefix_cacher),
                                    self.disable_eos_stop,
                                    rng.clone(),
                                    CacheBackendMetadata::DefaultInstructions { pre_op, post_op },
                                )
                                .await
                        };

                        let prompt_exec_time = handle_pipeline_forward_error!(
                            "prompt step",
                            prompt_exec_time,
                            &mut scheduled.prompt,
                            self.pipeline,
                            'lp,
                            self.prefix_cacher
                        );

                        let total_processed_tokens: usize = scheduled
                            .prompt
                            .iter()
                            .map(|seq| seq.get_toks().len())
                            .sum();
                        self.logger.add_tokens_processed(total_processed_tokens);

                        for seq in scheduled.prompt.iter_mut() {
                            match seq.sequence_stepping_type() {
                                SeqStepType::OneShot => {
                                    seq.set_state(SequenceState::Done(StopReason::GeneratedImage))
                                }
                                SeqStepType::PromptAndDecode => {
                                    seq.set_state(SequenceState::RunningCompletion)
                                }
                            }
                            let now = SystemTime::now()
                                .duration_since(UNIX_EPOCH)
                                .expect("Time travel has occurred!")
                                .as_millis();
                            #[allow(clippy::cast_precision_loss)]
                            let prompt_tok_per_sec =
                                seq.len() as f32 / prompt_exec_time.as_secs_f32();
                            seq.prompt_tok_per_sec = prompt_tok_per_sec;
                            seq.prompt_timestamp = Some(now);
                        }
                        last_completion_ids = vec![];
                    }

                    if self.is_debug {
                        let secs_from_last_run = run_start.elapsed().as_secs_f64();
                        let total_len = scheduled.prompt.len() + scheduled.completion.len();
                        if total_len > 0 {
                            let prompt_lengths = scheduled
                                .prompt
                                .iter()
                                .map(|seq| seq.len().to_string())
                                .collect::<Vec<_>>()
                                .join(", ");

                            let completion_lengths = scheduled
                                .completion
                                .iter()
                                .map(|seq| seq.len().to_string())
                                .collect::<Vec<_>>()
                                .join(", ");

                            tracing::info!(
                                "Prompt[{}] Completion[{}] - {}ms",
                                prompt_lengths,
                                completion_lengths,
                                secs_from_last_run * 1000.,
                            );
                        }
                    }
                }
                SchedulerOutput::PagedAttention { mut output } => {
                    if !output.scheduled.is_empty() {
                        let is_prompt = get_mut_arcmutex!(output.scheduled[0]).is_prompt();

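                        // Lock every scheduled sequence for the duration of this step.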
                        let mut guards = output
                            .scheduled
                            .iter_mut()
                            .map(|seq| seq.lock().unwrap())
                            .collect::<Vec<_>>();

                        let mut guards_mut =
                            guards.iter_mut().map(|seq| &mut **seq).collect::<Vec<_>>();

                        let res = {
                            let mut pipeline = get_mut_arcmutex!(self.pipeline);

                            let block_size = scheduler.block_size().unwrap();

                            let metadata = PagedAttentionMeta {
                                block_size,
                                sliding_window: pipeline.get_metadata().sliding_window,
                                block_engine: scheduler.block_engine().unwrap(),
                            };

                            let return_raw_logits = guards_mut[0].return_raw_logits;
                            assert!(
                                guards_mut
                                    .iter()
                                    .all(|seq| seq.return_raw_logits == return_raw_logits),
                                "All sequences must either return raw logits, or not."
                            );

                            pipeline
                                .step(
                                    &mut guards_mut,
                                    is_prompt,
                                    return_raw_logits,
                                    &mut *get_mut_arcmutex!(self.prefix_cacher),
                                    self.disable_eos_stop,
                                    rng.clone(),
                                    CacheBackendMetadata::PagedAttention {
                                        metadata,
                                        blocks_to_copy: output.blocks_to_copy,
                                        blocks_to_swap_in: output.blocks_to_swap_in,
                                        blocks_to_swap_out: output.blocks_to_swap_out,
                                    },
                                )
                                .await
                        };

                        handle_pipeline_forward_error!(
                            "step",
                            res,
                            &mut guards_mut,
                            self.pipeline,
                            'lp,
                            self.prefix_cacher
                        );

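                        // A prompt step processes every prompt token; a decode
                        // step produces exactly one new token per sequence.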
                        let total_processed_tokens: usize = guards
                            .iter()
                            .map(|seq| {
                                if seq.is_prompt() {
                                    seq.get_toks().len()
                                } else {
                                    1
                                }
                            })
                            .sum();
                        self.logger.add_tokens_processed(total_processed_tokens);

                        if self.is_debug {
                            let secs_from_last_run = run_start.elapsed().as_secs_f64();
                            let total_len = guards.len();
                            if total_len > 0 {
                                let lengths = guards
                                    .iter()
                                    .map(|seq| seq.len().to_string())
                                    .collect::<Vec<_>>()
                                    .join(", ");

                                let (prompt_lengths, completion_lengths) = if is_prompt {
                                    (lengths, "".to_string())
                                } else {
                                    ("".to_string(), lengths)
                                };

                                tracing::info!(
                                    "Prompt[{}] Completion[{}] - {}ms",
                                    prompt_lengths,
                                    completion_lengths,
                                    secs_from_last_run * 1000.,
                                );
                            }
                        }

                        if is_prompt {
                            for mut seq in guards {
                                let now = SystemTime::now()
                                    .duration_since(UNIX_EPOCH)
                                    .expect("Time travel has occurred!")
                                    .as_millis();
                                #[allow(clippy::cast_precision_loss)]
                                let prompt_tok_per_sec =
                                    seq.len() as f32 / (now - seq.timestamp()) as f32;
                                seq.prompt_tok_per_sec = prompt_tok_per_sec * 1000.;
                                seq.prompt_timestamp = Some(now);
                            }
                        }
                    }
                }
            }

            scheduler.free_finished_sequence_groups();
        }
    }

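    /// Build a token recognizer for grammar-constrained sampling, if the
    /// request carries a constraint.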
    fn build_sequence_recognizer(
        tok_env: &Option<TokEnv>,
        constraint: &Constraint,
    ) -> anyhow::Result<SequenceRecognizer> {
        if let Some(grm) = llg_grammar_from_constraint(constraint)? {
            let tok_env = tok_env
                .as_ref()
                .ok_or_else(|| anyhow::anyhow!("No token environment found."))?;
            let llg = constraint_from_llg_grammar(tok_env.clone(), grm)?;
            Ok(SequenceRecognizer::Llguidance(Box::new(llg)))
        } else {
            Ok(SequenceRecognizer::None)
        }
    }

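    /// When this process is the NCCL main process, mirror the request to every
    /// daemon worker over a local IPC socket so all ranks observe the same
    /// request stream.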
    fn replicate_request_to_daemons(&self, request: &Request) {
        if !distributed::is_daemon() && mistralrs_quant::distributed::use_nccl() {
            let name = distributed::ipc_name().unwrap();
            let num_workers =
                mistralrs_quant::distributed::get_global_tp_size_from_devices().unwrap() - 1;
            let listener = ListenerOptions::new().name(name).create_sync().unwrap();

            for _ in 0..num_workers {
                let stream = listener.accept().unwrap();
                let mut writer = BufWriter::new(stream);
                let req = format!("{}\n", serde_json::to_string(&request).unwrap());
                writer.write_all(req.as_bytes()).unwrap();
            }
        }
    }
}