1#![deny(clippy::cast_possible_truncation, clippy::cast_precision_loss)]
2use candle_core::Device;
3use engine::Engine;
4pub use engine::{
5 BertEmbeddingModel, EngineInstruction, ENGINE_INSTRUCTIONS, TERMINATE_ALL_NEXT_STEP,
6};
7use hf_hub::Cache;
8pub use lora::Ordering;
9pub use pipeline::ModelCategory;
10pub use pipeline::Pipeline;
11#[cfg(feature = "pyo3_macros")]
12use pyo3::exceptions::PyValueError;
13use std::io::BufRead;
14use std::io::BufReader;
15use std::sync::OnceLock;
16use std::time::Instant;
17use std::{
18 cell::RefCell,
19 error::Error,
20 fs::OpenOptions,
21 io::Write,
22 sync::{
23 atomic::{self, AtomicBool, AtomicUsize},
24 Arc, Mutex, RwLock,
25 },
26 thread::{self, JoinHandle},
27 time::{SystemTime, UNIX_EPOCH},
28};
29use tokio::sync::mpsc::{channel, Sender};
30use tracing::info;
31use tracing::warn;
32
33mod cuda;
34mod device_map;
35mod engine;
36mod lora;
37mod model_loader;
38mod ops;
39pub use model_loader::{
40 get_auto_device_map_params, get_model_dtype, get_tgt_non_granular_index, LoaderBuilder,
41};
42mod search;
43
44mod model_selected;
45pub use model_selected::ModelSelected;
46pub use toml_selector::{get_toml_selected_model_device_map_params, get_toml_selected_model_dtype};
47
48mod amoe;
49#[cfg(not(any(all(feature = "cuda", target_family = "unix"), feature = "metal")))]
50mod dummy_paged_attention;
51mod embedding;
52mod gguf;
53pub mod layers;
54mod layers_masker;
55mod layers_utils;
56mod models;
57#[cfg(any(all(feature = "cuda", target_family = "unix"), feature = "metal"))]
58mod paged_attention;
59#[cfg(not(any(all(feature = "cuda", target_family = "unix"), feature = "metal")))]
60use dummy_paged_attention as paged_attention;
61mod attention;
62mod diffusion_models;
63pub mod distributed;
64mod pipeline;
65mod prefix_cacher;
66mod request;
67mod response;
68mod sampler;
69mod scheduler;
70mod sequence;
71mod toml_selector;
72mod tools;
73mod topology;
74mod utils;
75mod vision_models;
76mod xlora_models;
77
78pub use amoe::{AnyMoeConfig, AnyMoeExpertType};
79pub use device_map::{
80 DeviceLayerMapMetadata, DeviceMapMetadata, DeviceMapSetting, LayerDeviceMapper,
81};
82pub use gguf::{GGUFArchitecture, GGUF_MULTI_FILE_DELIMITER};
83pub use mistralrs_quant::{IsqType, MULTI_LORA_DELIMITER};
84pub use paged_attention::{MemoryGpuConfig, PagedAttentionConfig};
85pub use pipeline::{
86 chat_template::ChatTemplate, parse_isq_value, AnyMoeLoader, AnyMoePipeline,
87 AutoDeviceMapParams, DiffusionGenerationParams, DiffusionLoader, DiffusionLoaderBuilder,
88 DiffusionLoaderType, DiffusionSpecificConfig, GGMLLoader, GGMLLoaderBuilder,
89 GGMLSpecificConfig, GGUFLoader, GGUFLoaderBuilder, GGUFSpecificConfig, GemmaLoader,
90 Idefics2Loader, IsqOrganization, LLaVALoader, LLaVANextLoader, LlamaLoader, Loader,
91 LocalModelPaths, MistralLoader, MixtralLoader, ModelKind, ModelPaths, NormalLoader,
92 NormalLoaderBuilder, NormalLoaderType, NormalSpecificConfig, Phi2Loader, Phi3Loader,
93 Phi3VLoader, Qwen2Loader, SpeculativeConfig, SpeculativeLoader, SpeculativePipeline,
94 Starcoder2Loader, TokenSource, VisionLoader, VisionLoaderBuilder, VisionLoaderType,
95 VisionPromptPrefixer, VisionSpecificConfig, UQFF_MULTI_FILE_DELIMITER,
96};
97pub use request::{
98 ApproximateUserLocation, Constraint, DetokenizationRequest, ImageGenerationResponseFormat,
99 LlguidanceGrammar, MessageContent, NormalRequest, Request, RequestMessage, TokenizationRequest,
100 WebSearchOptions, WebSearchUserLocation,
101};
102pub use response::*;
103pub use sampler::{
104 CustomLogitsProcessor, DrySamplingParams, SamplingParams, StopTokens, TopLogprob,
105};
106pub use scheduler::{DefaultSchedulerMethod, SchedulerConfig};
107use serde::Serialize;
108use tokio::runtime::Runtime;
109use toml_selector::{TomlLoaderArgs, TomlSelector};
110pub use tools::{
111 CalledFunction, Function, Tool, ToolCallResponse, ToolCallType, ToolChoice, ToolType,
112};
113pub use topology::{LayerTopology, Topology};
114pub use utils::debug::initialize_logging;
115pub use utils::memory_usage::MemoryUsage;
116pub use utils::normal::{ModelDType, TryIntoDType};
117pub use utils::{paged_attn_supported, using_flash_attn};
118
119pub use llguidance;
121
/// Crate-wide debug flag; starts out `false`.
/// NOTE(review): not written or read in the visible part of this file — presumably toggled
/// elsewhere in the crate; confirm.
pub(crate) static DEBUG: AtomicBool = AtomicBool::new(false);
/// Set-once, process-wide slot holding an `hf_hub` [`Cache`].
/// NOTE(review): not initialised in the visible part of this file — confirm who populates it.
pub static GLOBAL_HF_CACHE: OnceLock<Cache> = OnceLock::new();
/// Counter from which `MistralRs::new` draws each instance's `engine_id` (via `fetch_add`).
static ENGINE_ID: AtomicUsize = AtomicUsize::new(0);
126
/// Model facts captured from the pipeline when a [`MistralRs`] is constructed.
///
/// Returned by reference from `MistralRs::config`.
pub struct MistralRsConfig {
    /// Model kind, cloned from the pipeline's metadata.
    pub kind: ModelKind,
    /// Device reported by the pipeline.
    pub device: Device,
    /// Model category reported by the pipeline.
    pub category: ModelCategory,
}
132
/// Handle to an engine running on its own thread.
///
/// Requests are submitted through the channel returned by `get_sender`. Dropping the handle
/// records a `Terminate` instruction for this instance's `engine_id`.
pub struct MistralRs {
    /// Sending half of the request channel; swapped for a fresh one when the engine is rebooted.
    sender: RwLock<Sender<Request>>,
    /// Path of the file the `maybe_log_*` helpers append to; `None` disables that logging.
    log: Option<String>,
    /// Name reported by the pipeline at construction time.
    id: String,
    /// Construction time, in whole seconds since the Unix epoch.
    creation_time: u64,
    /// Next request id to hand out; starts at 1 and is post-incremented by `next_request_id`.
    next_request_id: Mutex<RefCell<usize>>,
    /// Everything needed to build a replacement engine in `reboot_engine`.
    reboot_state: RebootState,
    /// Join handle of the current engine thread; replaced when the engine is rebooted.
    engine_handler: RwLock<JoinHandle<()>>,
    /// Key under which this instance's instruction is stored in `ENGINE_INSTRUCTIONS` on drop.
    engine_id: usize,
    /// Model category reported by the pipeline at construction time.
    category: ModelCategory,
    /// Model facts captured at construction time.
    config: MistralRsConfig,
}
149
/// Copy of the engine construction arguments, retained so `MistralRs::reboot_engine` can build
/// a replacement engine. Every field is forwarded as-is to `Engine::new`.
#[derive(Clone)]
struct RebootState {
    /// Shared pipeline the engine drives.
    pipeline: Arc<tokio::sync::Mutex<dyn Pipeline>>,
    /// Scheduler configuration.
    method: SchedulerConfig,
    /// Resolved builder option (`false` when the builder left it unset).
    truncate_sequence: bool,
    /// Resolved builder option (`false` when the builder left it unset).
    no_kv_cache: bool,
    /// Resolved builder option (`false` when the builder left it unset).
    no_prefix_cache: bool,
    /// Resolved builder option (`16` when the builder left it unset).
    prefix_cache_n: usize,
    /// Resolved builder option (`false` when the builder left it unset).
    disable_eos_stop: bool,
    /// Throughput-logging flag passed to the builder's constructor.
    throughput_logging_enabled: bool,
    /// Optional embedding model passed to the builder's constructor.
    search_embedding_model: Option<BertEmbeddingModel>,
}
162
/// Errors surfaced by [`MistralRs`] when one of its internal locks has been poisoned.
#[derive(Debug)]
pub enum MistralRsError {
    /// The lock around the engine thread handle could not be acquired.
    EnginePoisoned,
    /// The lock around the request sender could not be acquired.
    SenderPoisoned,
}

impl std::fmt::Display for MistralRsError {
    /// Writes the same text as the `Debug` representation (the bare variant name).
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        write!(f, "{self:?}")
    }
}

impl std::error::Error for MistralRsError {}
176
/// Converts the error into a Python `ValueError` carrying its `Debug` text.
#[cfg(feature = "pyo3_macros")]
impl From<MistralRsError> for pyo3::PyErr {
    fn from(value: MistralRsError) -> Self {
        let message = format!("{value:?}");
        PyValueError::new_err(message)
    }
}
183
/// Builder for [`MistralRs`].
///
/// Options left as `None` are resolved to their defaults when `build` is called.
pub struct MistralRsBuilder {
    /// Shared pipeline the engine will drive.
    pipeline: Arc<tokio::sync::Mutex<dyn Pipeline>>,
    /// Scheduler configuration handed to the engine.
    method: SchedulerConfig,
    /// Path of the request/response log file, if logging is wanted.
    log: Option<String>,
    /// Defaults to `false` when unset.
    truncate_sequence: Option<bool>,
    /// Defaults to `false` when unset.
    no_kv_cache: Option<bool>,
    /// Defaults to `false` when unset.
    no_prefix_cache: Option<bool>,
    /// Defaults to `16` when unset.
    prefix_cache_n: Option<usize>,
    /// Defaults to `false` when unset.
    disable_eos_stop: Option<bool>,
    /// Throughput-logging flag forwarded to the engine.
    throughput_logging_enabled: bool,
    /// Optional embedding model forwarded to the engine.
    search_embedding_model: Option<BertEmbeddingModel>,
}
199
200impl MistralRsBuilder {
201 pub fn new(
202 pipeline: Arc<tokio::sync::Mutex<dyn Pipeline>>,
203 method: SchedulerConfig,
204 throughput_logging: bool,
205 search_embedding_model: Option<BertEmbeddingModel>,
206 ) -> Self {
207 Self {
208 pipeline,
209 method,
210 log: None,
211 truncate_sequence: None,
212 no_kv_cache: None,
213 no_prefix_cache: None,
214 prefix_cache_n: None,
215 disable_eos_stop: None,
216 throughput_logging_enabled: throughput_logging,
217 search_embedding_model,
218 }
219 }
220 pub fn with_log(mut self, log: String) -> Self {
221 self.log = Some(log);
222 self
223 }
224 pub fn with_opt_log(mut self, log: Option<String>) -> Self {
225 self.log = log;
226 self
227 }
228 pub fn with_truncate_sequence(mut self, truncate_sequence: bool) -> Self {
229 self.truncate_sequence = Some(truncate_sequence);
230 self
231 }
232 pub fn with_no_kv_cache(mut self, no_kv_cache: bool) -> Self {
233 self.no_kv_cache = Some(no_kv_cache);
234 self
235 }
236 pub fn with_no_prefix_cache(mut self, no_prefix_cache: bool) -> Self {
237 self.no_prefix_cache = Some(no_prefix_cache);
238 self
239 }
240 pub fn with_prefix_cache_n(mut self, prefix_cache_n: usize) -> Self {
241 self.prefix_cache_n = Some(prefix_cache_n);
242 self
243 }
244 pub fn with_disable_eos_stop(mut self, disable_eos_stop: bool) -> Self {
245 self.disable_eos_stop = Some(disable_eos_stop);
246 self
247 }
248
249 pub fn build(self) -> Arc<MistralRs> {
250 MistralRs::new(self)
251 }
252}
253
254impl Drop for MistralRs {
255 fn drop(&mut self) {
256 ENGINE_INSTRUCTIONS
257 .lock()
258 .expect("`ENGINE_INSTRUCTIONS` was poisioned")
259 .insert(self.engine_id, Some(EngineInstruction::Terminate));
260 }
261}
262
263impl MistralRs {
264 fn new(config: MistralRsBuilder) -> Arc<Self> {
265 let MistralRsBuilder {
266 pipeline,
267 method,
268 log,
269 truncate_sequence,
270 no_kv_cache,
271 no_prefix_cache,
272 prefix_cache_n,
273 disable_eos_stop,
274 throughput_logging_enabled,
275 search_embedding_model,
276 } = config;
277
278 let category = pipeline.try_lock().unwrap().category();
279 mistralrs_quant::cublaslt::maybe_init_cublas_lt_wrapper(
280 get_mut_arcmutex!(pipeline).device(),
281 );
282
283 let truncate_sequence = truncate_sequence.unwrap_or(false);
284 let no_kv_cache = no_kv_cache.unwrap_or(false);
285 let no_prefix_cache = no_prefix_cache.unwrap_or(false);
286 let prefix_cache_n = prefix_cache_n.unwrap_or(16);
287 let disable_eos_stop = disable_eos_stop.unwrap_or(false);
288
289 let reboot_state = RebootState {
290 pipeline: pipeline.clone(),
291 method: method.clone(),
292 truncate_sequence,
293 no_kv_cache,
294 no_prefix_cache,
295 prefix_cache_n,
296 disable_eos_stop,
297 throughput_logging_enabled,
298 search_embedding_model: search_embedding_model.clone(),
299 };
300
301 let (tx, rx) = channel(10_000);
302
303 let sender = RwLock::new(tx);
304 let id = pipeline.try_lock().unwrap().name();
305
306 let kind = pipeline.try_lock().unwrap().get_metadata().kind.clone();
307 let device = pipeline.try_lock().unwrap().device();
308 let config = MistralRsConfig {
309 kind,
310 device,
311 category: category.clone(),
312 };
313
314 let engine_handler = thread::spawn(move || {
315 #[cfg(feature = "metal")]
316 objc::rc::autoreleasepool(move || {
317 let rt = Runtime::new().unwrap();
318 rt.block_on(async move {
319 let engine = Engine::new(
320 rx,
321 pipeline,
322 method,
323 truncate_sequence,
324 no_kv_cache,
325 no_prefix_cache,
326 prefix_cache_n,
327 disable_eos_stop,
328 throughput_logging_enabled,
329 search_embedding_model,
330 )
331 .expect("Engine creation failed.");
332 Arc::new(engine).run().await;
333 })
334 });
335
336 #[cfg(not(feature = "metal"))]
337 {
338 let rt = Runtime::new().unwrap();
339 rt.block_on(async move {
340 let engine = Engine::new(
341 rx,
342 pipeline,
343 method,
344 truncate_sequence,
345 no_kv_cache,
346 no_prefix_cache,
347 prefix_cache_n,
348 disable_eos_stop,
349 throughput_logging_enabled,
350 search_embedding_model,
351 )
352 .expect("Engine creation failed.");
353 Arc::new(engine).run().await;
354 })
355 }
356 });
357
358 let engine_id = ENGINE_ID.fetch_add(1, atomic::Ordering::SeqCst);
359
360 if distributed::is_daemon() {
361 let request_sender = sender.write().unwrap().clone();
362 thread::spawn(move || {
363 let rt = Runtime::new().unwrap();
364 rt.block_on(async move {
365 use interprocess::local_socket::traits::Stream;
366 use interprocess::local_socket::Stream as LocalStream;
367
368 loop {
369 let name = distributed::ipc_name().unwrap();
370 if let Ok(stream) = LocalStream::connect(name) {
371 let mut reader = BufReader::new(stream);
372 let mut buf = String::new();
373 reader.read_line(&mut buf).unwrap();
374 let mut req: Request = serde_json::from_str(&buf).unwrap();
375
376 req = match req {
377 Request::ReIsq(x) => Request::ReIsq(x),
378 Request::Terminate => Request::Terminate,
379 Request::Detokenize(mut x) => {
380 let (sender, mut receiver) = tokio::sync::mpsc::channel(1);
381 x.response = sender;
382 let req = Request::Detokenize(x);
383
384 request_sender.send(req).await.unwrap();
385 let resp = receiver.recv().await.unwrap();
386 resp.unwrap();
387 continue;
388 }
389 Request::Tokenize(mut x) => {
390 let (sender, mut receiver) = tokio::sync::mpsc::channel(1);
391 x.response = sender;
392 let req = Request::Tokenize(x);
393
394 request_sender.send(req).await.unwrap();
395 let resp = receiver.recv().await.unwrap();
396 resp.unwrap();
397 continue;
398 }
399 Request::Normal(mut x) => {
400 let (sender, mut receiver) = tokio::sync::mpsc::channel(1);
401 x.is_streaming = false;
402 x.response = sender;
403 let req = Request::Normal(x);
404
405 request_sender.send(req).await.unwrap();
406 let resp = receiver.recv().await.unwrap();
407 resp.as_result().unwrap();
408 continue;
409 }
410 Request::TerminateAllSeqsNextStep => {
411 Request::TerminateAllSeqsNextStep
412 }
413 };
414
415 request_sender.send(req).await.unwrap();
416 }
417 }
418 });
419 });
420
421 #[allow(clippy::empty_loop)]
422 loop {}
423 }
424
425 let is_multi_threaded = tokio::runtime::Handle::try_current()
427 .is_ok_and(|h| h.runtime_flavor() != tokio::runtime::RuntimeFlavor::CurrentThread);
428
429 if !distributed::is_daemon()
431 && is_multi_threaded
432 && matches!(category, ModelCategory::Text | ModelCategory::Vision { .. })
433 {
434 let clone_sender = sender.read().unwrap().clone();
435 tokio::task::block_in_place(|| {
436 let (tx, mut rx) = channel(1);
437 let req = Request::Normal(NormalRequest {
438 id: 0,
439 messages: RequestMessage::Completion {
440 text: "hello".to_string(),
441 echo_prompt: false,
442 best_of: None,
443 },
444 sampling_params: SamplingParams {
445 max_len: Some(1),
446 ..SamplingParams::deterministic()
447 },
448 response: tx,
449 return_logprobs: false,
450 is_streaming: false,
451 constraint: Constraint::None,
452 suffix: None,
453 tool_choice: None,
454 tools: None,
455 logits_processors: None,
456 return_raw_logits: false,
457 web_search_options: None,
458 });
459 info!("Beginning dummy run.");
460 let start = Instant::now();
461 clone_sender.blocking_send(req).unwrap();
462
463 if let Some(_resp) = rx.blocking_recv() {
464 let end = Instant::now();
465 info!(
466 "Dummy run completed in {}s.",
467 end.duration_since(start).as_secs_f64()
468 );
469 } else {
470 warn!("Dummy run failed!");
471 }
472 });
473 }
474
475 Arc::new(Self {
476 engine_id,
477 sender,
478 log,
479 id,
480 creation_time: SystemTime::now()
481 .duration_since(UNIX_EPOCH)
482 .expect("Time travel has occurred!")
483 .as_secs(),
484 next_request_id: Mutex::new(RefCell::new(1)),
485 reboot_state,
486 engine_handler: RwLock::new(engine_handler),
487 category,
488 config,
489 })
490 }
491
492 fn reboot_engine(&self) -> Result<(), MistralRsError> {
495 let (new_sender, rx) = channel(10_000);
496 let reboot_state = self.reboot_state.clone();
497 let mut sender_lock = self.sender.write().map_err(|_| {
498 tracing::warn!("Couldn't get write lock on the sender during reboot attempt");
499 MistralRsError::SenderPoisoned
500 })?;
501 let mut engine_lock = self.engine_handler.write().map_err(|_| {
502 tracing::warn!("Couldn't get write lock on the engine during reboot attempt");
503 MistralRsError::EnginePoisoned
504 })?;
505
506 if !engine_lock.is_finished() {
507 tracing::info!("Engine already running, returning ok");
508 Ok(())
509 } else {
510 let new_engine_handler = thread::spawn(move || {
512 let rt = Runtime::new().unwrap();
513 rt.block_on(async move {
514 let engine = Engine::new(
515 rx,
516 reboot_state.pipeline.clone(),
517 reboot_state.method,
518 reboot_state.truncate_sequence,
519 reboot_state.no_kv_cache,
520 reboot_state.no_prefix_cache,
521 reboot_state.prefix_cache_n,
522 reboot_state.disable_eos_stop,
523 reboot_state.throughput_logging_enabled,
524 reboot_state.search_embedding_model,
525 )
526 .expect("Engine creation failed");
527 Arc::new(engine).run().await;
528 });
529 });
530 *sender_lock = new_sender;
531 *engine_lock = new_engine_handler;
532 tracing::info!("Successfully rebooted engine and updated sender + engine handler");
533 Ok(())
534 }
535 }
536
537 fn engine_dead(&self) -> Result<bool, MistralRsError> {
538 match self.engine_handler.read() {
539 Ok(handler) => Ok(handler.is_finished()),
540 Err(_) => {
541 tracing::warn!("Couldn't get read lock on engine!");
542 Err(MistralRsError::EnginePoisoned)
543 }
544 }
545 }
546
547 pub fn get_sender(&self) -> Result<Sender<Request>, MistralRsError> {
548 if self.engine_dead()? {
549 tracing::warn!("Engine is dead, rebooting");
550 self.reboot_engine()?
551 }
552 match self.sender.read() {
553 Ok(sender) => Ok(sender.clone()),
554 Err(_) => Err(MistralRsError::SenderPoisoned),
555 }
556 }
557
    /// Returns a copy of the identifier, which is the name the pipeline reported at construction.
    pub fn get_id(&self) -> String {
        self.id.clone()
    }
561
    /// Returns the construction time in whole seconds since the Unix epoch.
    pub fn get_creation_time(&self) -> u64 {
        self.creation_time
    }
565
    /// Returns a copy of the model category recorded at construction.
    pub fn get_model_category(&self) -> ModelCategory {
        self.category.clone()
    }
569
570 pub fn next_request_id(&self) -> usize {
571 let l = self.next_request_id.lock().unwrap();
572 let last = &mut *l.borrow_mut();
573 let last_v = *last;
574 *last += 1;
575 last_v
576 }
577
578 pub fn maybe_log_request(this: Arc<Self>, repr: String) {
579 if let Some(file) = &this.log {
580 let mut f = OpenOptions::new()
581 .append(true)
582 .create(true) .open(file)
584 .expect("Unable to open file");
585 let time = chrono::offset::Local::now();
586 f.write_all(format!("Request at {time}: {repr}\n\n").as_bytes())
587 .expect("Unable to write data");
588 }
589 }
590
591 pub fn maybe_log_response<T: Serialize>(this: Arc<Self>, resp: &T) {
592 if let Some(file) = &this.log {
593 let mut f = OpenOptions::new()
594 .append(true)
595 .create(true) .open(file)
597 .expect("Unable to open file");
598 let time = chrono::offset::Local::now();
599 let repr = serde_json::to_string(resp).expect("Serialization of response failed.");
600 f.write_all(format!("Response at {time}: {repr}\n\n").as_bytes())
601 .expect("Unable to write data");
602 }
603 }
604
605 pub fn maybe_log_error(this: Arc<Self>, err: &dyn Error) {
606 if let Some(file) = &this.log {
607 let mut f = OpenOptions::new()
608 .append(true)
609 .create(true) .open(file)
611 .expect("Unable to open file");
612 let time = chrono::offset::Local::now();
613 f.write_all(format!("Error response at {time}: {err}\n\n").as_bytes())
614 .expect("Unable to write data");
615 }
616 }
617
    /// Returns the model facts (kind, device, category) captured at construction.
    pub fn config(&self) -> &MistralRsConfig {
        &self.config
    }
621}