1#![deny(clippy::cast_possible_truncation, clippy::cast_precision_loss)]
2use candle_core::Device;
3use engine::Engine;
4pub use engine::{
5 BertEmbeddingModel, EngineInstruction, ENGINE_INSTRUCTIONS, TERMINATE_ALL_NEXT_STEP,
6};
7use hf_hub::Cache;
8pub use lora::Ordering;
9pub use pipeline::ModelCategory;
10pub use pipeline::Pipeline;
11#[cfg(feature = "pyo3_macros")]
12use pyo3::exceptions::PyValueError;
13use std::io::BufRead;
14use std::io::BufReader;
15use std::sync::OnceLock;
16use std::time::Instant;
17use std::{
18 cell::RefCell,
19 error::Error,
20 fs::OpenOptions,
21 io::Write,
22 sync::{
23 atomic::{self, AtomicBool, AtomicUsize},
24 Arc, Mutex, RwLock,
25 },
26 thread::{self, JoinHandle},
27 time::{SystemTime, UNIX_EPOCH},
28};
29use tokio::sync::mpsc::{channel, Sender};
30use tracing::info;
31use tracing::warn;
32
33mod cuda;
34mod device_map;
35mod engine;
36mod lora;
37mod model_loader;
38mod ops;
39pub use model_loader::{
40 get_auto_device_map_params, get_model_dtype, get_tgt_non_granular_index, LoaderBuilder,
41};
42mod kv_cache;
43mod search;
44
45mod model_selected;
46pub use model_selected::ModelSelected;
47pub use toml_selector::{get_toml_selected_model_device_map_params, get_toml_selected_model_dtype};
48
49mod amoe;
50#[cfg(not(any(all(feature = "cuda", target_family = "unix"), feature = "metal")))]
51mod dummy_paged_attention;
52mod embedding;
53mod gguf;
54pub mod layers;
55mod layers_masker;
56mod layers_utils;
57mod models;
58#[cfg(any(all(feature = "cuda", target_family = "unix"), feature = "metal"))]
59mod paged_attention;
60#[cfg(not(any(all(feature = "cuda", target_family = "unix"), feature = "metal")))]
61use dummy_paged_attention as paged_attention;
62mod attention;
63mod diffusion_models;
64pub mod distributed;
65mod pipeline;
66mod prefix_cacher;
67mod request;
68mod response;
69mod sampler;
70mod scheduler;
71mod sequence;
72mod speech_models;
73mod toml_selector;
74mod tools;
75mod topology;
76mod utils;
77mod vision_models;
78mod xlora_models;
79
80pub use amoe::{AnyMoeConfig, AnyMoeExpertType};
81pub use device_map::{
82 DeviceLayerMapMetadata, DeviceMapMetadata, DeviceMapSetting, LayerDeviceMapper,
83};
84pub use gguf::{GGUFArchitecture, GGUF_MULTI_FILE_DELIMITER};
85pub use mistralrs_quant::{IsqType, MULTI_LORA_DELIMITER};
86pub use paged_attention::{MemoryGpuConfig, PagedAttentionConfig};
87pub use pipeline::{
88 chat_template::ChatTemplate, parse_isq_value, AdapterPaths, AnyMoeLoader, AnyMoePipeline,
89 AutoDeviceMapParams, DiffusionGenerationParams, DiffusionLoader, DiffusionLoaderBuilder,
90 DiffusionLoaderType, GGMLLoader, GGMLLoaderBuilder, GGMLSpecificConfig, GGUFLoader,
91 GGUFLoaderBuilder, GGUFSpecificConfig, GemmaLoader, Idefics2Loader, IsqOrganization,
92 LLaVALoader, LLaVANextLoader, LlamaLoader, Loader, LocalModelPaths, LoraAdapterPaths,
93 MistralLoader, MixtralLoader, ModelKind, ModelPaths, NormalLoader, NormalLoaderBuilder,
94 NormalLoaderType, NormalSpecificConfig, Phi2Loader, Phi3Loader, Phi3VLoader, Qwen2Loader,
95 SpeculativeConfig, SpeculativeLoader, SpeculativePipeline, SpeechLoader, SpeechPipeline,
96 Starcoder2Loader, TokenSource, VisionLoader, VisionLoaderBuilder, VisionLoaderType,
97 VisionPromptPrefixer, VisionSpecificConfig, UQFF_MULTI_FILE_DELIMITER,
98};
99pub use request::{
100 ApproximateUserLocation, Constraint, DetokenizationRequest, ImageGenerationResponseFormat,
101 LlguidanceGrammar, MessageContent, NormalRequest, Request, RequestMessage, SearchContextSize,
102 TokenizationRequest, WebSearchOptions, WebSearchUserLocation,
103};
104pub use response::*;
105pub use sampler::{
106 CustomLogitsProcessor, DrySamplingParams, SamplingParams, StopTokens, TopLogprob,
107};
108pub use scheduler::{DefaultSchedulerMethod, SchedulerConfig};
109use serde::Serialize;
110pub use speech_models::{utils as speech_utils, SpeechGenerationConfig, SpeechLoaderType};
111use tokio::runtime::Runtime;
112use toml_selector::{TomlLoaderArgs, TomlSelector};
113pub use tools::{
114 CalledFunction, Function, Tool, ToolCallResponse, ToolCallType, ToolChoice, ToolType,
115};
116pub use topology::{LayerTopology, Topology};
117pub use utils::debug::initialize_logging;
118pub use utils::memory_usage::MemoryUsage;
119pub use utils::normal::{ModelDType, TryIntoDType};
120pub use utils::{paged_attn_supported, using_flash_attn};
121
122pub use llguidance;
124
/// Crate-internal debug switch; starts out as `false`.
// NOTE(review): nothing in this part of the file writes to it — presumably the
// logging initialisation flips it; confirm against `utils::debug`.
pub(crate) static DEBUG: AtomicBool = AtomicBool::new(false);
/// Process-wide, write-once slot for an `hf_hub` [`Cache`].
// NOTE(review): the code that fills this slot is not visible here — verify who
// initialises it before relying on `get()` returning `Some`.
pub static GLOBAL_HF_CACHE: OnceLock<Cache> = OnceLock::new();
/// Counter from which `MistralRs::new` draws a fresh `engine_id` (via
/// `fetch_add(1, SeqCst)`), so the first instance gets id 0.
static ENGINE_ID: AtomicUsize = AtomicUsize::new(0);
129
/// Description of the model served by a [`MistralRs`] instance.
///
/// Filled in once by `MistralRs::new` from the pipeline and handed out by
/// reference through `MistralRs::config`.
pub struct MistralRsConfig {
    /// Model kind, cloned from the pipeline's metadata.
    pub kind: ModelKind,
    /// Device reported by the pipeline.
    pub device: Device,
    /// Model category reported by the pipeline.
    pub category: ModelCategory,
}
135
/// Handle to a running engine.
///
/// Owns the sending half of the request channel, the engine thread's join
/// handle and the state needed to restart that thread once it has finished.
/// Constructed through [`MistralRsBuilder`]; dropping it records a terminate
/// instruction for its engine in `ENGINE_INSTRUCTIONS`.
pub struct MistralRs {
    /// Sending half of the request channel; swapped for a new one when the
    /// engine is rebooted.
    sender: RwLock<Sender<Request>>,
    /// Path of the log file used by the `maybe_log_*` helpers; `None` turns
    /// those helpers into no-ops.
    log: Option<String>,
    /// Name reported by the pipeline at construction time.
    id: String,
    /// Construction time in whole seconds since the Unix epoch.
    creation_time: u64,
    /// Next request id to hand out. Starts at 1; id 0 is used by the dummy
    /// warm-up request sent from `new`.
    // NOTE(review): the `RefCell` inside the `Mutex` adds nothing — the mutex
    // already provides interior mutability. Simplifying it needs matching
    // changes in `new` and `next_request_id`.
    next_request_id: Mutex<RefCell<usize>>,
    /// Engine construction arguments kept for `reboot_engine`.
    reboot_state: RebootState,
    /// Join handle of the current engine thread; `is_finished()` on it is how
    /// a dead engine is detected.
    engine_handler: RwLock<JoinHandle<()>>,
    /// Key under which `Drop` stores the terminate instruction.
    engine_id: usize,
    /// Model category reported by the pipeline.
    category: ModelCategory,
    /// Model description returned by `config()`.
    config: MistralRsConfig,
}
152
/// Copy of the arguments given to `Engine::new` (everything except the request
/// receiver), kept so `MistralRs::reboot_engine` can build a replacement
/// engine after the original thread has finished.
#[derive(Clone)]
struct RebootState {
    /// Shared pipeline; the same `Arc` the original engine was built with.
    pipeline: Arc<tokio::sync::Mutex<dyn Pipeline>>,
    /// Scheduler configuration.
    method: SchedulerConfig,
    // The remaining fields hold the builder options after defaults were
    // applied in `MistralRs::new`; they are forwarded to `Engine::new`
    // unchanged and in this order.
    truncate_sequence: bool,
    no_kv_cache: bool,
    no_prefix_cache: bool,
    prefix_cache_n: usize,
    disable_eos_stop: bool,
    throughput_logging_enabled: bool,
    search_embedding_model: Option<BertEmbeddingModel>,
}
165
/// Errors returned when one of the internal locks of `MistralRs` is poisoned.
#[derive(Debug)]
pub enum MistralRsError {
    /// The lock around the engine thread handle could not be acquired.
    EnginePoisoned,
    /// The lock around the request sender could not be acquired.
    SenderPoisoned,
}

impl std::fmt::Display for MistralRsError {
    /// Writes the bare variant name — the same text the `Debug` impl produces.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let variant_name = match self {
            Self::EnginePoisoned => "EnginePoisoned",
            Self::SenderPoisoned => "SenderPoisoned",
        };
        f.write_str(variant_name)
    }
}

impl std::error::Error for MistralRsError {}
179
/// Lets a [`MistralRsError`] cross into Python as a `PyValueError` whose
/// message is the error's `Debug` text.
#[cfg(feature = "pyo3_macros")]
impl From<MistralRsError> for pyo3::PyErr {
    fn from(value: MistralRsError) -> Self {
        let message = format!("{value:?}");
        PyValueError::new_err(message)
    }
}
186
/// Builder for [`MistralRs`].
///
/// Options left as `None` are given their defaults in `MistralRs::new`:
/// `false` for every boolean option and `16` for `prefix_cache_n`.
pub struct MistralRsBuilder {
    /// Pipeline the engine will drive.
    pipeline: Arc<tokio::sync::Mutex<dyn Pipeline>>,
    /// Scheduler configuration handed to the engine.
    method: SchedulerConfig,
    /// Optional path of the request/response log file.
    log: Option<String>,
    /// Optional override; defaults to `false`.
    truncate_sequence: Option<bool>,
    /// Optional override; defaults to `false`.
    no_kv_cache: Option<bool>,
    /// Optional override; defaults to `false`.
    no_prefix_cache: Option<bool>,
    /// Optional override; defaults to `16`.
    prefix_cache_n: Option<usize>,
    /// Optional override; defaults to `false`.
    disable_eos_stop: Option<bool>,
    /// Set from the `throughput_logging` argument of `new`; forwarded to the engine.
    throughput_logging_enabled: bool,
    /// Optional embedding model forwarded to the engine.
    search_embedding_model: Option<BertEmbeddingModel>,
}
202
203impl MistralRsBuilder {
204 pub fn new(
205 pipeline: Arc<tokio::sync::Mutex<dyn Pipeline>>,
206 method: SchedulerConfig,
207 throughput_logging: bool,
208 search_embedding_model: Option<BertEmbeddingModel>,
209 ) -> Self {
210 Self {
211 pipeline,
212 method,
213 log: None,
214 truncate_sequence: None,
215 no_kv_cache: None,
216 no_prefix_cache: None,
217 prefix_cache_n: None,
218 disable_eos_stop: None,
219 throughput_logging_enabled: throughput_logging,
220 search_embedding_model,
221 }
222 }
223 pub fn with_log(mut self, log: String) -> Self {
224 self.log = Some(log);
225 self
226 }
227 pub fn with_opt_log(mut self, log: Option<String>) -> Self {
228 self.log = log;
229 self
230 }
231 pub fn with_truncate_sequence(mut self, truncate_sequence: bool) -> Self {
232 self.truncate_sequence = Some(truncate_sequence);
233 self
234 }
235 pub fn with_no_kv_cache(mut self, no_kv_cache: bool) -> Self {
236 self.no_kv_cache = Some(no_kv_cache);
237 self
238 }
239 pub fn with_no_prefix_cache(mut self, no_prefix_cache: bool) -> Self {
240 self.no_prefix_cache = Some(no_prefix_cache);
241 self
242 }
243 pub fn with_prefix_cache_n(mut self, prefix_cache_n: usize) -> Self {
244 self.prefix_cache_n = Some(prefix_cache_n);
245 self
246 }
247 pub fn with_disable_eos_stop(mut self, disable_eos_stop: bool) -> Self {
248 self.disable_eos_stop = Some(disable_eos_stop);
249 self
250 }
251
252 pub fn build(self) -> Arc<MistralRs> {
253 MistralRs::new(self)
254 }
255}
256
257impl Drop for MistralRs {
258 fn drop(&mut self) {
259 ENGINE_INSTRUCTIONS
260 .lock()
261 .expect("`ENGINE_INSTRUCTIONS` was poisoned")
262 .insert(self.engine_id, Some(EngineInstruction::Terminate));
263 }
264}
265
266impl MistralRs {
267 fn new(config: MistralRsBuilder) -> Arc<Self> {
268 let MistralRsBuilder {
269 pipeline,
270 method,
271 log,
272 truncate_sequence,
273 no_kv_cache,
274 no_prefix_cache,
275 prefix_cache_n,
276 disable_eos_stop,
277 throughput_logging_enabled,
278 search_embedding_model,
279 } = config;
280
281 let category = pipeline.try_lock().unwrap().category();
282 mistralrs_quant::cublaslt::maybe_init_cublas_lt_wrapper(
283 get_mut_arcmutex!(pipeline).device(),
284 );
285
286 let truncate_sequence = truncate_sequence.unwrap_or(false);
287 let no_kv_cache = no_kv_cache.unwrap_or(false);
288 let no_prefix_cache = no_prefix_cache.unwrap_or(false);
289 let prefix_cache_n = prefix_cache_n.unwrap_or(16);
290 let disable_eos_stop = disable_eos_stop.unwrap_or(false);
291
292 let reboot_state = RebootState {
293 pipeline: pipeline.clone(),
294 method: method.clone(),
295 truncate_sequence,
296 no_kv_cache,
297 no_prefix_cache,
298 prefix_cache_n,
299 disable_eos_stop,
300 throughput_logging_enabled,
301 search_embedding_model: search_embedding_model.clone(),
302 };
303
304 let (tx, rx) = channel(10_000);
305
306 let sender = RwLock::new(tx);
307 let id = pipeline.try_lock().unwrap().name();
308
309 let kind = pipeline.try_lock().unwrap().get_metadata().kind.clone();
310 let device = pipeline.try_lock().unwrap().device();
311 let config = MistralRsConfig {
312 kind,
313 device,
314 category: category.clone(),
315 };
316
317 let engine_handler = thread::spawn(move || {
318 #[cfg(feature = "metal")]
319 objc::rc::autoreleasepool(move || {
320 let rt = Runtime::new().unwrap();
321 rt.block_on(async move {
322 let engine = Engine::new(
323 rx,
324 pipeline,
325 method,
326 truncate_sequence,
327 no_kv_cache,
328 no_prefix_cache,
329 prefix_cache_n,
330 disable_eos_stop,
331 throughput_logging_enabled,
332 search_embedding_model,
333 )
334 .expect("Engine creation failed.");
335 Arc::new(engine).run().await;
336 })
337 });
338
339 #[cfg(not(feature = "metal"))]
340 {
341 let rt = Runtime::new().unwrap();
342 rt.block_on(async move {
343 let engine = Engine::new(
344 rx,
345 pipeline,
346 method,
347 truncate_sequence,
348 no_kv_cache,
349 no_prefix_cache,
350 prefix_cache_n,
351 disable_eos_stop,
352 throughput_logging_enabled,
353 search_embedding_model,
354 )
355 .expect("Engine creation failed.");
356 Arc::new(engine).run().await;
357 })
358 }
359 });
360
361 let engine_id = ENGINE_ID.fetch_add(1, atomic::Ordering::SeqCst);
362
363 if distributed::is_daemon() {
364 let request_sender = sender.write().unwrap().clone();
365 thread::spawn(move || {
366 let rt = Runtime::new().unwrap();
367 rt.block_on(async move {
368 use interprocess::local_socket::traits::Stream;
369 use interprocess::local_socket::Stream as LocalStream;
370
371 loop {
372 let name = distributed::ipc_name().unwrap();
373 if let Ok(stream) = LocalStream::connect(name) {
374 let mut reader = BufReader::new(stream);
375 let mut buf = String::new();
376 reader.read_line(&mut buf).unwrap();
377 let mut req: Request = serde_json::from_str(&buf).unwrap();
378
379 req = match req {
380 Request::ReIsq(x) => Request::ReIsq(x),
381 Request::Terminate => Request::Terminate,
382 Request::Detokenize(mut x) => {
383 let (sender, mut receiver) = tokio::sync::mpsc::channel(1);
384 x.response = sender;
385 let req = Request::Detokenize(x);
386
387 request_sender.send(req).await.unwrap();
388 let resp = receiver.recv().await.unwrap();
389 resp.unwrap();
390 continue;
391 }
392 Request::Tokenize(mut x) => {
393 let (sender, mut receiver) = tokio::sync::mpsc::channel(1);
394 x.response = sender;
395 let req = Request::Tokenize(x);
396
397 request_sender.send(req).await.unwrap();
398 let resp = receiver.recv().await.unwrap();
399 resp.unwrap();
400 continue;
401 }
402 Request::Normal(mut x) => {
403 let (sender, mut receiver) = tokio::sync::mpsc::channel(1);
404 x.is_streaming = false;
405 x.response = sender;
406 let req = Request::Normal(x);
407
408 request_sender.send(req).await.unwrap();
409 let resp = receiver.recv().await.unwrap();
410 resp.as_result().unwrap();
411 continue;
412 }
413 Request::TerminateAllSeqsNextStep => {
414 Request::TerminateAllSeqsNextStep
415 }
416 };
417
418 request_sender.send(req).await.unwrap();
419 }
420 }
421 });
422 });
423
424 #[allow(clippy::empty_loop)]
425 loop {}
426 }
427
428 let is_multi_threaded = tokio::runtime::Handle::try_current()
430 .is_ok_and(|h| h.runtime_flavor() != tokio::runtime::RuntimeFlavor::CurrentThread);
431
432 if !distributed::is_daemon()
434 && is_multi_threaded
435 && matches!(category, ModelCategory::Text | ModelCategory::Vision { .. })
436 {
437 let clone_sender = sender.read().unwrap().clone();
438 tokio::task::block_in_place(|| {
439 let (tx, mut rx) = channel(1);
440 let req = Request::Normal(Box::new(NormalRequest {
441 id: 0,
442 messages: RequestMessage::Completion {
443 text: "hello".to_string(),
444 echo_prompt: false,
445 best_of: None,
446 },
447 sampling_params: SamplingParams {
448 max_len: Some(1),
449 ..SamplingParams::deterministic()
450 },
451 response: tx,
452 return_logprobs: false,
453 is_streaming: false,
454 constraint: Constraint::None,
455 suffix: None,
456 tool_choice: None,
457 tools: None,
458 logits_processors: None,
459 return_raw_logits: false,
460 web_search_options: None,
461 }));
462 info!("Beginning dummy run.");
463 let start = Instant::now();
464 clone_sender.blocking_send(req).unwrap();
465
466 if let Some(_resp) = rx.blocking_recv() {
467 let end = Instant::now();
468 info!(
469 "Dummy run completed in {}s.",
470 end.duration_since(start).as_secs_f64()
471 );
472 } else {
473 warn!("Dummy run failed!");
474 }
475 });
476 }
477
478 Arc::new(Self {
479 engine_id,
480 sender,
481 log,
482 id,
483 creation_time: SystemTime::now()
484 .duration_since(UNIX_EPOCH)
485 .expect("Time travel has occurred!")
486 .as_secs(),
487 next_request_id: Mutex::new(RefCell::new(1)),
488 reboot_state,
489 engine_handler: RwLock::new(engine_handler),
490 category,
491 config,
492 })
493 }
494
495 fn reboot_engine(&self) -> Result<(), MistralRsError> {
498 let (new_sender, rx) = channel(10_000);
499 let reboot_state = self.reboot_state.clone();
500 let mut sender_lock = self.sender.write().map_err(|_| {
501 tracing::warn!("Couldn't get write lock on the sender during reboot attempt");
502 MistralRsError::SenderPoisoned
503 })?;
504 let mut engine_lock = self.engine_handler.write().map_err(|_| {
505 tracing::warn!("Couldn't get write lock on the engine during reboot attempt");
506 MistralRsError::EnginePoisoned
507 })?;
508
509 if !engine_lock.is_finished() {
510 tracing::info!("Engine already running, returning ok");
511 Ok(())
512 } else {
513 let new_engine_handler = thread::spawn(move || {
515 let rt = Runtime::new().unwrap();
516 rt.block_on(async move {
517 let engine = Engine::new(
518 rx,
519 reboot_state.pipeline.clone(),
520 reboot_state.method,
521 reboot_state.truncate_sequence,
522 reboot_state.no_kv_cache,
523 reboot_state.no_prefix_cache,
524 reboot_state.prefix_cache_n,
525 reboot_state.disable_eos_stop,
526 reboot_state.throughput_logging_enabled,
527 reboot_state.search_embedding_model,
528 )
529 .expect("Engine creation failed");
530 Arc::new(engine).run().await;
531 });
532 });
533 *sender_lock = new_sender;
534 *engine_lock = new_engine_handler;
535 tracing::info!("Successfully rebooted engine and updated sender + engine handler");
536 Ok(())
537 }
538 }
539
540 fn engine_dead(&self) -> Result<bool, MistralRsError> {
541 match self.engine_handler.read() {
542 Ok(handler) => Ok(handler.is_finished()),
543 Err(_) => {
544 tracing::warn!("Couldn't get read lock on engine!");
545 Err(MistralRsError::EnginePoisoned)
546 }
547 }
548 }
549
550 pub fn get_sender(&self) -> Result<Sender<Request>, MistralRsError> {
551 if self.engine_dead()? {
552 tracing::warn!("Engine is dead, rebooting");
553 self.reboot_engine()?
554 }
555 match self.sender.read() {
556 Ok(sender) => Ok(sender.clone()),
557 Err(_) => Err(MistralRsError::SenderPoisoned),
558 }
559 }
560
561 pub fn get_id(&self) -> String {
562 self.id.clone()
563 }
564
565 pub fn get_creation_time(&self) -> u64 {
566 self.creation_time
567 }
568
569 pub fn get_model_category(&self) -> ModelCategory {
570 self.category.clone()
571 }
572
573 pub fn next_request_id(&self) -> usize {
574 let l = self.next_request_id.lock().unwrap();
575 let last = &mut *l.borrow_mut();
576 let last_v = *last;
577 *last += 1;
578 last_v
579 }
580
581 pub fn maybe_log_request(this: Arc<Self>, repr: String) {
582 if let Some(file) = &this.log {
583 let mut f = OpenOptions::new()
584 .append(true)
585 .create(true) .open(file)
587 .expect("Unable to open file");
588 let time = chrono::offset::Local::now();
589 f.write_all(format!("Request at {time}: {repr}\n\n").as_bytes())
590 .expect("Unable to write data");
591 }
592 }
593
594 pub fn maybe_log_response<T: Serialize>(this: Arc<Self>, resp: &T) {
595 if let Some(file) = &this.log {
596 let mut f = OpenOptions::new()
597 .append(true)
598 .create(true) .open(file)
600 .expect("Unable to open file");
601 let time = chrono::offset::Local::now();
602 let repr = serde_json::to_string(resp).expect("Serialization of response failed.");
603 f.write_all(format!("Response at {time}: {repr}\n\n").as_bytes())
604 .expect("Unable to write data");
605 }
606 }
607
608 pub fn maybe_log_error(this: Arc<Self>, err: &dyn Error) {
609 if let Some(file) = &this.log {
610 let mut f = OpenOptions::new()
611 .append(true)
612 .create(true) .open(file)
614 .expect("Unable to open file");
615 let time = chrono::offset::Local::now();
616 f.write_all(format!("Error response at {time}: {err}\n\n").as_bytes())
617 .expect("Unable to write data");
618 }
619 }
620
621 pub fn config(&self) -> &MistralRsConfig {
622 &self.config
623 }
624}