mistralrs_core/
lib.rs

1#![deny(clippy::cast_possible_truncation, clippy::cast_precision_loss)]
2use candle_core::Device;
3use engine::Engine;
4pub use engine::{
5    BertEmbeddingModel, EngineInstruction, ENGINE_INSTRUCTIONS, TERMINATE_ALL_NEXT_STEP,
6};
7use hf_hub::Cache;
8pub use lora::Ordering;
9pub use pipeline::ModelCategory;
10pub use pipeline::Pipeline;
11#[cfg(feature = "pyo3_macros")]
12use pyo3::exceptions::PyValueError;
13use std::io::BufRead;
14use std::io::BufReader;
15use std::sync::OnceLock;
16use std::time::Instant;
17use std::{
18    cell::RefCell,
19    error::Error,
20    fs::OpenOptions,
21    io::Write,
22    sync::{
23        atomic::{self, AtomicBool, AtomicUsize},
24        Arc, Mutex, RwLock,
25    },
26    thread::{self, JoinHandle},
27    time::{SystemTime, UNIX_EPOCH},
28};
29use tokio::sync::mpsc::{channel, Sender};
30use tracing::info;
31use tracing::warn;
32
33mod cuda;
34mod device_map;
35mod engine;
36mod lora;
37mod model_loader;
38mod ops;
39pub use model_loader::{
40    get_auto_device_map_params, get_model_dtype, get_tgt_non_granular_index, LoaderBuilder,
41};
42mod search;
43
44mod model_selected;
45pub use model_selected::ModelSelected;
46pub use toml_selector::{get_toml_selected_model_device_map_params, get_toml_selected_model_dtype};
47
48mod amoe;
49#[cfg(not(any(all(feature = "cuda", target_family = "unix"), feature = "metal")))]
50mod dummy_paged_attention;
51mod embedding;
52mod gguf;
53pub mod layers;
54mod layers_masker;
55mod layers_utils;
56mod models;
57#[cfg(any(all(feature = "cuda", target_family = "unix"), feature = "metal"))]
58mod paged_attention;
59#[cfg(not(any(all(feature = "cuda", target_family = "unix"), feature = "metal")))]
60use dummy_paged_attention as paged_attention;
61mod attention;
62mod diffusion_models;
63pub mod distributed;
64mod pipeline;
65mod prefix_cacher;
66mod request;
67mod response;
68mod sampler;
69mod scheduler;
70mod sequence;
71mod toml_selector;
72mod tools;
73mod topology;
74mod utils;
75mod vision_models;
76mod xlora_models;
77
78pub use amoe::{AnyMoeConfig, AnyMoeExpertType};
79pub use device_map::{
80    DeviceLayerMapMetadata, DeviceMapMetadata, DeviceMapSetting, LayerDeviceMapper,
81};
82pub use gguf::{GGUFArchitecture, GGUF_MULTI_FILE_DELIMITER};
83pub use mistralrs_quant::{IsqType, MULTI_LORA_DELIMITER};
84pub use paged_attention::{MemoryGpuConfig, PagedAttentionConfig};
85pub use pipeline::{
86    chat_template::ChatTemplate, parse_isq_value, AnyMoeLoader, AnyMoePipeline,
87    AutoDeviceMapParams, DiffusionGenerationParams, DiffusionLoader, DiffusionLoaderBuilder,
88    DiffusionLoaderType, DiffusionSpecificConfig, GGMLLoader, GGMLLoaderBuilder,
89    GGMLSpecificConfig, GGUFLoader, GGUFLoaderBuilder, GGUFSpecificConfig, GemmaLoader,
90    Idefics2Loader, IsqOrganization, LLaVALoader, LLaVANextLoader, LlamaLoader, Loader,
91    LocalModelPaths, MistralLoader, MixtralLoader, ModelKind, ModelPaths, NormalLoader,
92    NormalLoaderBuilder, NormalLoaderType, NormalSpecificConfig, Phi2Loader, Phi3Loader,
93    Phi3VLoader, Qwen2Loader, SpeculativeConfig, SpeculativeLoader, SpeculativePipeline,
94    Starcoder2Loader, TokenSource, VisionLoader, VisionLoaderBuilder, VisionLoaderType,
95    VisionPromptPrefixer, VisionSpecificConfig, UQFF_MULTI_FILE_DELIMITER,
96};
97pub use request::{
98    ApproximateUserLocation, Constraint, DetokenizationRequest, ImageGenerationResponseFormat,
99    LlguidanceGrammar, MessageContent, NormalRequest, Request, RequestMessage, TokenizationRequest,
100    WebSearchOptions, WebSearchUserLocation,
101};
102pub use response::*;
103pub use sampler::{
104    CustomLogitsProcessor, DrySamplingParams, SamplingParams, StopTokens, TopLogprob,
105};
106pub use scheduler::{DefaultSchedulerMethod, SchedulerConfig};
107use serde::Serialize;
108use tokio::runtime::Runtime;
109use toml_selector::{TomlLoaderArgs, TomlSelector};
110pub use tools::{
111    CalledFunction, Function, Tool, ToolCallResponse, ToolCallType, ToolChoice, ToolType,
112};
113pub use topology::{LayerTopology, Topology};
114pub use utils::debug::initialize_logging;
115pub use utils::memory_usage::MemoryUsage;
116pub use utils::normal::{ModelDType, TryIntoDType};
117pub use utils::{paged_attn_supported, using_flash_attn};
118
119// re-export llguidance for easier LlguidanceGrammar construction
120pub use llguidance;
121
122/// `true` if `MISTRALRS_DEBUG=1`
123pub(crate) static DEBUG: AtomicBool = AtomicBool::new(false);
124pub static GLOBAL_HF_CACHE: OnceLock<Cache> = OnceLock::new();
125static ENGINE_ID: AtomicUsize = AtomicUsize::new(0);
126
127pub struct MistralRsConfig {
128    pub kind: ModelKind,
129    pub device: Device,
130    pub category: ModelCategory,
131}
132
133/// The MistralRs struct handles sending requests to the engine.
134/// It is the core multi-threaded component of mistral.rs, and uses `mpsc`
135/// `Sender` and `Receiver` primitives to send and receive requests to the
136/// engine.
137pub struct MistralRs {
138    sender: RwLock<Sender<Request>>,
139    log: Option<String>,
140    id: String,
141    creation_time: u64,
142    next_request_id: Mutex<RefCell<usize>>,
143    reboot_state: RebootState,
144    engine_handler: RwLock<JoinHandle<()>>,
145    engine_id: usize,
146    category: ModelCategory,
147    config: MistralRsConfig,
148}
149
150#[derive(Clone)]
151struct RebootState {
152    pipeline: Arc<tokio::sync::Mutex<dyn Pipeline>>,
153    method: SchedulerConfig,
154    truncate_sequence: bool,
155    no_kv_cache: bool,
156    no_prefix_cache: bool,
157    prefix_cache_n: usize,
158    disable_eos_stop: bool,
159    throughput_logging_enabled: bool,
160    search_embedding_model: Option<BertEmbeddingModel>,
161}
162
/// Failures reported when the locks guarding the engine handle or the request
/// sender turn out to be poisoned.
#[derive(Debug)]
pub enum MistralRsError {
    /// The lock around the engine thread handle was poisoned.
    EnginePoisoned,
    /// The lock around the request sender was poisoned.
    SenderPoisoned,
}

/// The user-facing text is the variant name, i.e. the same as the `Debug` output.
impl std::fmt::Display for MistralRsError {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        write!(f, "{self:?}")
    }
}

impl std::error::Error for MistralRsError {}
176
/// Surfaces a [`MistralRsError`] to Python as a `ValueError` whose message is
/// the error's `Debug` representation.
#[cfg(feature = "pyo3_macros")]
impl From<MistralRsError> for pyo3::PyErr {
    fn from(value: MistralRsError) -> Self {
        let message = format!("{value:?}");
        PyValueError::new_err(message)
    }
}
183
184/// The MistralRsBuilder takes the pipeline and a scheduler method and constructs
185/// an Engine and a MistralRs instance. The Engine runs on a separate thread, and the MistralRs
186/// instance stays on the calling thread.
187pub struct MistralRsBuilder {
188    pipeline: Arc<tokio::sync::Mutex<dyn Pipeline>>,
189    method: SchedulerConfig,
190    log: Option<String>,
191    truncate_sequence: Option<bool>,
192    no_kv_cache: Option<bool>,
193    no_prefix_cache: Option<bool>,
194    prefix_cache_n: Option<usize>,
195    disable_eos_stop: Option<bool>,
196    throughput_logging_enabled: bool,
197    search_embedding_model: Option<BertEmbeddingModel>,
198}
199
200impl MistralRsBuilder {
201    pub fn new(
202        pipeline: Arc<tokio::sync::Mutex<dyn Pipeline>>,
203        method: SchedulerConfig,
204        throughput_logging: bool,
205        search_embedding_model: Option<BertEmbeddingModel>,
206    ) -> Self {
207        Self {
208            pipeline,
209            method,
210            log: None,
211            truncate_sequence: None,
212            no_kv_cache: None,
213            no_prefix_cache: None,
214            prefix_cache_n: None,
215            disable_eos_stop: None,
216            throughput_logging_enabled: throughput_logging,
217            search_embedding_model,
218        }
219    }
220    pub fn with_log(mut self, log: String) -> Self {
221        self.log = Some(log);
222        self
223    }
224    pub fn with_opt_log(mut self, log: Option<String>) -> Self {
225        self.log = log;
226        self
227    }
228    pub fn with_truncate_sequence(mut self, truncate_sequence: bool) -> Self {
229        self.truncate_sequence = Some(truncate_sequence);
230        self
231    }
232    pub fn with_no_kv_cache(mut self, no_kv_cache: bool) -> Self {
233        self.no_kv_cache = Some(no_kv_cache);
234        self
235    }
236    pub fn with_no_prefix_cache(mut self, no_prefix_cache: bool) -> Self {
237        self.no_prefix_cache = Some(no_prefix_cache);
238        self
239    }
240    pub fn with_prefix_cache_n(mut self, prefix_cache_n: usize) -> Self {
241        self.prefix_cache_n = Some(prefix_cache_n);
242        self
243    }
244    pub fn with_disable_eos_stop(mut self, disable_eos_stop: bool) -> Self {
245        self.disable_eos_stop = Some(disable_eos_stop);
246        self
247    }
248
249    pub fn build(self) -> Arc<MistralRs> {
250        MistralRs::new(self)
251    }
252}
253
254impl Drop for MistralRs {
255    fn drop(&mut self) {
256        ENGINE_INSTRUCTIONS
257            .lock()
258            .expect("`ENGINE_INSTRUCTIONS` was poisioned")
259            .insert(self.engine_id, Some(EngineInstruction::Terminate));
260    }
261}
262
263impl MistralRs {
264    fn new(config: MistralRsBuilder) -> Arc<Self> {
265        let MistralRsBuilder {
266            pipeline,
267            method,
268            log,
269            truncate_sequence,
270            no_kv_cache,
271            no_prefix_cache,
272            prefix_cache_n,
273            disable_eos_stop,
274            throughput_logging_enabled,
275            search_embedding_model,
276        } = config;
277
278        let category = pipeline.try_lock().unwrap().category();
279        mistralrs_quant::cublaslt::maybe_init_cublas_lt_wrapper(
280            get_mut_arcmutex!(pipeline).device(),
281        );
282
283        let truncate_sequence = truncate_sequence.unwrap_or(false);
284        let no_kv_cache = no_kv_cache.unwrap_or(false);
285        let no_prefix_cache = no_prefix_cache.unwrap_or(false);
286        let prefix_cache_n = prefix_cache_n.unwrap_or(16);
287        let disable_eos_stop = disable_eos_stop.unwrap_or(false);
288
289        let reboot_state = RebootState {
290            pipeline: pipeline.clone(),
291            method: method.clone(),
292            truncate_sequence,
293            no_kv_cache,
294            no_prefix_cache,
295            prefix_cache_n,
296            disable_eos_stop,
297            throughput_logging_enabled,
298            search_embedding_model: search_embedding_model.clone(),
299        };
300
301        let (tx, rx) = channel(10_000);
302
303        let sender = RwLock::new(tx);
304        let id = pipeline.try_lock().unwrap().name();
305
306        let kind = pipeline.try_lock().unwrap().get_metadata().kind.clone();
307        let device = pipeline.try_lock().unwrap().device();
308        let config = MistralRsConfig {
309            kind,
310            device,
311            category: category.clone(),
312        };
313
314        let engine_handler = thread::spawn(move || {
315            #[cfg(feature = "metal")]
316            objc::rc::autoreleasepool(move || {
317                let rt = Runtime::new().unwrap();
318                rt.block_on(async move {
319                    let engine = Engine::new(
320                        rx,
321                        pipeline,
322                        method,
323                        truncate_sequence,
324                        no_kv_cache,
325                        no_prefix_cache,
326                        prefix_cache_n,
327                        disable_eos_stop,
328                        throughput_logging_enabled,
329                        search_embedding_model,
330                    )
331                    .expect("Engine creation failed.");
332                    Arc::new(engine).run().await;
333                })
334            });
335
336            #[cfg(not(feature = "metal"))]
337            {
338                let rt = Runtime::new().unwrap();
339                rt.block_on(async move {
340                    let engine = Engine::new(
341                        rx,
342                        pipeline,
343                        method,
344                        truncate_sequence,
345                        no_kv_cache,
346                        no_prefix_cache,
347                        prefix_cache_n,
348                        disable_eos_stop,
349                        throughput_logging_enabled,
350                        search_embedding_model,
351                    )
352                    .expect("Engine creation failed.");
353                    Arc::new(engine).run().await;
354                })
355            }
356        });
357
358        let engine_id = ENGINE_ID.fetch_add(1, atomic::Ordering::SeqCst);
359
360        if distributed::is_daemon() {
361            let request_sender = sender.write().unwrap().clone();
362            thread::spawn(move || {
363                let rt = Runtime::new().unwrap();
364                rt.block_on(async move {
365                    use interprocess::local_socket::traits::Stream;
366                    use interprocess::local_socket::Stream as LocalStream;
367
368                    loop {
369                        let name = distributed::ipc_name().unwrap();
370                        if let Ok(stream) = LocalStream::connect(name) {
371                            let mut reader = BufReader::new(stream);
372                            let mut buf = String::new();
373                            reader.read_line(&mut buf).unwrap();
374                            let mut req: Request = serde_json::from_str(&buf).unwrap();
375
376                            req = match req {
377                                Request::ReIsq(x) => Request::ReIsq(x),
378                                Request::Terminate => Request::Terminate,
379                                Request::Detokenize(mut x) => {
380                                    let (sender, mut receiver) = tokio::sync::mpsc::channel(1);
381                                    x.response = sender;
382                                    let req = Request::Detokenize(x);
383
384                                    request_sender.send(req).await.unwrap();
385                                    let resp = receiver.recv().await.unwrap();
386                                    resp.unwrap();
387                                    continue;
388                                }
389                                Request::Tokenize(mut x) => {
390                                    let (sender, mut receiver) = tokio::sync::mpsc::channel(1);
391                                    x.response = sender;
392                                    let req = Request::Tokenize(x);
393
394                                    request_sender.send(req).await.unwrap();
395                                    let resp = receiver.recv().await.unwrap();
396                                    resp.unwrap();
397                                    continue;
398                                }
399                                Request::Normal(mut x) => {
400                                    let (sender, mut receiver) = tokio::sync::mpsc::channel(1);
401                                    x.is_streaming = false;
402                                    x.response = sender;
403                                    let req = Request::Normal(x);
404
405                                    request_sender.send(req).await.unwrap();
406                                    let resp = receiver.recv().await.unwrap();
407                                    resp.as_result().unwrap();
408                                    continue;
409                                }
410                                Request::TerminateAllSeqsNextStep => {
411                                    Request::TerminateAllSeqsNextStep
412                                }
413                            };
414
415                            request_sender.send(req).await.unwrap();
416                        }
417                    }
418                });
419            });
420
421            #[allow(clippy::empty_loop)]
422            loop {}
423        }
424
425        // Determine if the current runtime is multi-threaded, as blocking operations are not allowed in single-threaded mode
426        let is_multi_threaded = tokio::runtime::Handle::try_current()
427            .is_ok_and(|h| h.runtime_flavor() != tokio::runtime::RuntimeFlavor::CurrentThread);
428
429        // Do a dummy run
430        if !distributed::is_daemon()
431            && is_multi_threaded
432            && matches!(category, ModelCategory::Text | ModelCategory::Vision { .. })
433        {
434            let clone_sender = sender.read().unwrap().clone();
435            tokio::task::block_in_place(|| {
436                let (tx, mut rx) = channel(1);
437                let req = Request::Normal(NormalRequest {
438                    id: 0,
439                    messages: RequestMessage::Completion {
440                        text: "hello".to_string(),
441                        echo_prompt: false,
442                        best_of: None,
443                    },
444                    sampling_params: SamplingParams {
445                        max_len: Some(1),
446                        ..SamplingParams::deterministic()
447                    },
448                    response: tx,
449                    return_logprobs: false,
450                    is_streaming: false,
451                    constraint: Constraint::None,
452                    suffix: None,
453                    tool_choice: None,
454                    tools: None,
455                    logits_processors: None,
456                    return_raw_logits: false,
457                    web_search_options: None,
458                });
459                info!("Beginning dummy run.");
460                let start = Instant::now();
461                clone_sender.blocking_send(req).unwrap();
462
463                if let Some(_resp) = rx.blocking_recv() {
464                    let end = Instant::now();
465                    info!(
466                        "Dummy run completed in {}s.",
467                        end.duration_since(start).as_secs_f64()
468                    );
469                } else {
470                    warn!("Dummy run failed!");
471                }
472            });
473        }
474
475        Arc::new(Self {
476            engine_id,
477            sender,
478            log,
479            id,
480            creation_time: SystemTime::now()
481                .duration_since(UNIX_EPOCH)
482                .expect("Time travel has occurred!")
483                .as_secs(),
484            next_request_id: Mutex::new(RefCell::new(1)),
485            reboot_state,
486            engine_handler: RwLock::new(engine_handler),
487            category,
488            config,
489        })
490    }
491
492    /// attempts to reboot the engine, if the sender (only way to communicate with
493    /// the engine) is closed
494    fn reboot_engine(&self) -> Result<(), MistralRsError> {
495        let (new_sender, rx) = channel(10_000);
496        let reboot_state = self.reboot_state.clone();
497        let mut sender_lock = self.sender.write().map_err(|_| {
498            tracing::warn!("Couldn't get write lock on the sender during reboot attempt");
499            MistralRsError::SenderPoisoned
500        })?;
501        let mut engine_lock = self.engine_handler.write().map_err(|_| {
502            tracing::warn!("Couldn't get write lock on the engine during reboot attempt");
503            MistralRsError::EnginePoisoned
504        })?;
505
506        if !engine_lock.is_finished() {
507            tracing::info!("Engine already running, returning ok");
508            Ok(())
509        } else {
510            // critical section. A panic here could lead to poisoned locks
511            let new_engine_handler = thread::spawn(move || {
512                let rt = Runtime::new().unwrap();
513                rt.block_on(async move {
514                    let engine = Engine::new(
515                        rx,
516                        reboot_state.pipeline.clone(),
517                        reboot_state.method,
518                        reboot_state.truncate_sequence,
519                        reboot_state.no_kv_cache,
520                        reboot_state.no_prefix_cache,
521                        reboot_state.prefix_cache_n,
522                        reboot_state.disable_eos_stop,
523                        reboot_state.throughput_logging_enabled,
524                        reboot_state.search_embedding_model,
525                    )
526                    .expect("Engine creation failed");
527                    Arc::new(engine).run().await;
528                });
529            });
530            *sender_lock = new_sender;
531            *engine_lock = new_engine_handler;
532            tracing::info!("Successfully rebooted engine and updated sender + engine handler");
533            Ok(())
534        }
535    }
536
537    fn engine_dead(&self) -> Result<bool, MistralRsError> {
538        match self.engine_handler.read() {
539            Ok(handler) => Ok(handler.is_finished()),
540            Err(_) => {
541                tracing::warn!("Couldn't get read lock on engine!");
542                Err(MistralRsError::EnginePoisoned)
543            }
544        }
545    }
546
547    pub fn get_sender(&self) -> Result<Sender<Request>, MistralRsError> {
548        if self.engine_dead()? {
549            tracing::warn!("Engine is dead, rebooting");
550            self.reboot_engine()?
551        }
552        match self.sender.read() {
553            Ok(sender) => Ok(sender.clone()),
554            Err(_) => Err(MistralRsError::SenderPoisoned),
555        }
556    }
557
558    pub fn get_id(&self) -> String {
559        self.id.clone()
560    }
561
562    pub fn get_creation_time(&self) -> u64 {
563        self.creation_time
564    }
565
566    pub fn get_model_category(&self) -> ModelCategory {
567        self.category.clone()
568    }
569
570    pub fn next_request_id(&self) -> usize {
571        let l = self.next_request_id.lock().unwrap();
572        let last = &mut *l.borrow_mut();
573        let last_v = *last;
574        *last += 1;
575        last_v
576    }
577
578    pub fn maybe_log_request(this: Arc<Self>, repr: String) {
579        if let Some(file) = &this.log {
580            let mut f = OpenOptions::new()
581                .append(true)
582                .create(true) // Optionally create the file if it doesn't already exist
583                .open(file)
584                .expect("Unable to open file");
585            let time = chrono::offset::Local::now();
586            f.write_all(format!("Request at {time}: {repr}\n\n").as_bytes())
587                .expect("Unable to write data");
588        }
589    }
590
591    pub fn maybe_log_response<T: Serialize>(this: Arc<Self>, resp: &T) {
592        if let Some(file) = &this.log {
593            let mut f = OpenOptions::new()
594                .append(true)
595                .create(true) // Optionally create the file if it doesn't already exist
596                .open(file)
597                .expect("Unable to open file");
598            let time = chrono::offset::Local::now();
599            let repr = serde_json::to_string(resp).expect("Serialization of response failed.");
600            f.write_all(format!("Response at {time}: {repr}\n\n").as_bytes())
601                .expect("Unable to write data");
602        }
603    }
604
605    pub fn maybe_log_error(this: Arc<Self>, err: &dyn Error) {
606        if let Some(file) = &this.log {
607            let mut f = OpenOptions::new()
608                .append(true)
609                .create(true) // Optionally create the file if it doesn't already exist
610                .open(file)
611                .expect("Unable to open file");
612            let time = chrono::offset::Local::now();
613            f.write_all(format!("Error response at {time}: {err}\n\n").as_bytes())
614                .expect("Unable to write data");
615        }
616    }
617
618    pub fn config(&self) -> &MistralRsConfig {
619        &self.config
620    }
621}