mistralrs_core/
lib.rs

1#![deny(clippy::cast_possible_truncation, clippy::cast_precision_loss)]
2use candle_core::Device;
3use engine::Engine;
4pub use engine::{
5    BertEmbeddingModel, EngineInstruction, ENGINE_INSTRUCTIONS, TERMINATE_ALL_NEXT_STEP,
6};
7use hf_hub::Cache;
8pub use lora::Ordering;
9pub use pipeline::ModelCategory;
10pub use pipeline::Pipeline;
11#[cfg(feature = "pyo3_macros")]
12use pyo3::exceptions::PyValueError;
13use std::io::BufRead;
14use std::io::BufReader;
15use std::sync::OnceLock;
16use std::time::Instant;
17use std::{
18    cell::RefCell,
19    error::Error,
20    fs::OpenOptions,
21    io::Write,
22    sync::{
23        atomic::{self, AtomicBool, AtomicUsize},
24        Arc, Mutex, RwLock,
25    },
26    thread::{self, JoinHandle},
27    time::{SystemTime, UNIX_EPOCH},
28};
29use tokio::sync::mpsc::{channel, Sender};
30use tracing::info;
31use tracing::warn;
32
33mod cuda;
34mod device_map;
35mod engine;
36mod lora;
37mod model_loader;
38mod ops;
39pub use model_loader::{
40    get_auto_device_map_params, get_model_dtype, get_tgt_non_granular_index, LoaderBuilder,
41};
42mod kv_cache;
43mod search;
44
45mod model_selected;
46pub use model_selected::ModelSelected;
47pub use toml_selector::{get_toml_selected_model_device_map_params, get_toml_selected_model_dtype};
48
49mod amoe;
50#[cfg(not(any(all(feature = "cuda", target_family = "unix"), feature = "metal")))]
51mod dummy_paged_attention;
52mod embedding;
53mod gguf;
54pub mod layers;
55mod layers_masker;
56mod layers_utils;
57mod models;
58#[cfg(any(all(feature = "cuda", target_family = "unix"), feature = "metal"))]
59mod paged_attention;
60#[cfg(not(any(all(feature = "cuda", target_family = "unix"), feature = "metal")))]
61use dummy_paged_attention as paged_attention;
62mod attention;
63mod diffusion_models;
64pub mod distributed;
65mod pipeline;
66mod prefix_cacher;
67mod request;
68mod response;
69mod sampler;
70mod scheduler;
71mod sequence;
72mod speech_models;
73mod toml_selector;
74mod tools;
75mod topology;
76mod utils;
77mod vision_models;
78mod xlora_models;
79
80pub use amoe::{AnyMoeConfig, AnyMoeExpertType};
81pub use device_map::{
82    DeviceLayerMapMetadata, DeviceMapMetadata, DeviceMapSetting, LayerDeviceMapper,
83};
84pub use gguf::{GGUFArchitecture, GGUF_MULTI_FILE_DELIMITER};
85pub use mistralrs_quant::{IsqType, MULTI_LORA_DELIMITER};
86pub use paged_attention::{MemoryGpuConfig, PagedAttentionConfig};
87pub use pipeline::{
88    chat_template::ChatTemplate, parse_isq_value, AdapterPaths, AnyMoeLoader, AnyMoePipeline,
89    AutoDeviceMapParams, DiffusionGenerationParams, DiffusionLoader, DiffusionLoaderBuilder,
90    DiffusionLoaderType, GGMLLoader, GGMLLoaderBuilder, GGMLSpecificConfig, GGUFLoader,
91    GGUFLoaderBuilder, GGUFSpecificConfig, GemmaLoader, Idefics2Loader, IsqOrganization,
92    LLaVALoader, LLaVANextLoader, LlamaLoader, Loader, LocalModelPaths, LoraAdapterPaths,
93    MistralLoader, MixtralLoader, ModelKind, ModelPaths, NormalLoader, NormalLoaderBuilder,
94    NormalLoaderType, NormalSpecificConfig, Phi2Loader, Phi3Loader, Phi3VLoader, Qwen2Loader,
95    SpeculativeConfig, SpeculativeLoader, SpeculativePipeline, SpeechLoader, SpeechPipeline,
96    Starcoder2Loader, TokenSource, VisionLoader, VisionLoaderBuilder, VisionLoaderType,
97    VisionPromptPrefixer, VisionSpecificConfig, UQFF_MULTI_FILE_DELIMITER,
98};
99pub use request::{
100    ApproximateUserLocation, Constraint, DetokenizationRequest, ImageGenerationResponseFormat,
101    LlguidanceGrammar, MessageContent, NormalRequest, Request, RequestMessage, SearchContextSize,
102    TokenizationRequest, WebSearchOptions, WebSearchUserLocation,
103};
104pub use response::*;
105pub use sampler::{
106    CustomLogitsProcessor, DrySamplingParams, SamplingParams, StopTokens, TopLogprob,
107};
108pub use scheduler::{DefaultSchedulerMethod, SchedulerConfig};
109use serde::Serialize;
110pub use speech_models::{utils as speech_utils, SpeechGenerationConfig, SpeechLoaderType};
111use tokio::runtime::Runtime;
112use toml_selector::{TomlLoaderArgs, TomlSelector};
113pub use tools::{
114    CalledFunction, Function, Tool, ToolCallResponse, ToolCallType, ToolChoice, ToolType,
115};
116pub use topology::{LayerTopology, Topology};
117pub use utils::debug::initialize_logging;
118pub use utils::memory_usage::MemoryUsage;
119pub use utils::normal::{ModelDType, TryIntoDType};
120pub use utils::{paged_attn_supported, using_flash_attn};
121
122// re-export llguidance for easier LlguidanceGrammar construction
123pub use llguidance;
124
125/// `true` if `MISTRALRS_DEBUG=1`
126pub(crate) static DEBUG: AtomicBool = AtomicBool::new(false);
127pub static GLOBAL_HF_CACHE: OnceLock<Cache> = OnceLock::new();
128static ENGINE_ID: AtomicUsize = AtomicUsize::new(0);
129
130pub struct MistralRsConfig {
131    pub kind: ModelKind,
132    pub device: Device,
133    pub category: ModelCategory,
134}
135
136/// The MistralRs struct handles sending requests to the engine.
137/// It is the core multi-threaded component of mistral.rs, and uses `mpsc`
138/// `Sender` and `Receiver` primitives to send and receive requests to the
139/// engine.
140pub struct MistralRs {
141    sender: RwLock<Sender<Request>>,
142    log: Option<String>,
143    id: String,
144    creation_time: u64,
145    next_request_id: Mutex<RefCell<usize>>,
146    reboot_state: RebootState,
147    engine_handler: RwLock<JoinHandle<()>>,
148    engine_id: usize,
149    category: ModelCategory,
150    config: MistralRsConfig,
151}
152
153#[derive(Clone)]
154struct RebootState {
155    pipeline: Arc<tokio::sync::Mutex<dyn Pipeline>>,
156    method: SchedulerConfig,
157    truncate_sequence: bool,
158    no_kv_cache: bool,
159    no_prefix_cache: bool,
160    prefix_cache_n: usize,
161    disable_eos_stop: bool,
162    throughput_logging_enabled: bool,
163    search_embedding_model: Option<BertEmbeddingModel>,
164}
165
/// Failures reported by [`MistralRs`] when one of its internal locks is poisoned.
#[derive(Debug)]
pub enum MistralRsError {
    /// The lock around the engine thread handle is poisoned.
    EnginePoisoned,
    /// The lock around the request sender is poisoned.
    SenderPoisoned,
}

impl std::fmt::Display for MistralRsError {
    /// Writes the bare variant name (the same text as the `Debug` output).
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let name = match self {
            Self::EnginePoisoned => "EnginePoisoned",
            Self::SenderPoisoned => "SenderPoisoned",
        };
        f.write_str(name)
    }
}

impl std::error::Error for MistralRsError {}
179
/// Surfaces a [`MistralRsError`] to Python as a `ValueError` whose message is the
/// error's `Debug` text.
#[cfg(feature = "pyo3_macros")]
impl From<MistralRsError> for pyo3::PyErr {
    fn from(value: MistralRsError) -> Self {
        let message = format!("{value:?}");
        PyValueError::new_err(message)
    }
}
186
187/// The MistralRsBuilder takes the pipeline and a scheduler method and constructs
188/// an Engine and a MistralRs instance. The Engine runs on a separate thread, and the MistralRs
189/// instance stays on the calling thread.
190pub struct MistralRsBuilder {
191    pipeline: Arc<tokio::sync::Mutex<dyn Pipeline>>,
192    method: SchedulerConfig,
193    log: Option<String>,
194    truncate_sequence: Option<bool>,
195    no_kv_cache: Option<bool>,
196    no_prefix_cache: Option<bool>,
197    prefix_cache_n: Option<usize>,
198    disable_eos_stop: Option<bool>,
199    throughput_logging_enabled: bool,
200    search_embedding_model: Option<BertEmbeddingModel>,
201}
202
203impl MistralRsBuilder {
204    pub fn new(
205        pipeline: Arc<tokio::sync::Mutex<dyn Pipeline>>,
206        method: SchedulerConfig,
207        throughput_logging: bool,
208        search_embedding_model: Option<BertEmbeddingModel>,
209    ) -> Self {
210        Self {
211            pipeline,
212            method,
213            log: None,
214            truncate_sequence: None,
215            no_kv_cache: None,
216            no_prefix_cache: None,
217            prefix_cache_n: None,
218            disable_eos_stop: None,
219            throughput_logging_enabled: throughput_logging,
220            search_embedding_model,
221        }
222    }
223    pub fn with_log(mut self, log: String) -> Self {
224        self.log = Some(log);
225        self
226    }
227    pub fn with_opt_log(mut self, log: Option<String>) -> Self {
228        self.log = log;
229        self
230    }
231    pub fn with_truncate_sequence(mut self, truncate_sequence: bool) -> Self {
232        self.truncate_sequence = Some(truncate_sequence);
233        self
234    }
235    pub fn with_no_kv_cache(mut self, no_kv_cache: bool) -> Self {
236        self.no_kv_cache = Some(no_kv_cache);
237        self
238    }
239    pub fn with_no_prefix_cache(mut self, no_prefix_cache: bool) -> Self {
240        self.no_prefix_cache = Some(no_prefix_cache);
241        self
242    }
243    pub fn with_prefix_cache_n(mut self, prefix_cache_n: usize) -> Self {
244        self.prefix_cache_n = Some(prefix_cache_n);
245        self
246    }
247    pub fn with_disable_eos_stop(mut self, disable_eos_stop: bool) -> Self {
248        self.disable_eos_stop = Some(disable_eos_stop);
249        self
250    }
251
252    pub fn build(self) -> Arc<MistralRs> {
253        MistralRs::new(self)
254    }
255}
256
257impl Drop for MistralRs {
258    fn drop(&mut self) {
259        ENGINE_INSTRUCTIONS
260            .lock()
261            .expect("`ENGINE_INSTRUCTIONS` was poisoned")
262            .insert(self.engine_id, Some(EngineInstruction::Terminate));
263    }
264}
265
266impl MistralRs {
267    fn new(config: MistralRsBuilder) -> Arc<Self> {
268        let MistralRsBuilder {
269            pipeline,
270            method,
271            log,
272            truncate_sequence,
273            no_kv_cache,
274            no_prefix_cache,
275            prefix_cache_n,
276            disable_eos_stop,
277            throughput_logging_enabled,
278            search_embedding_model,
279        } = config;
280
281        let category = pipeline.try_lock().unwrap().category();
282        mistralrs_quant::cublaslt::maybe_init_cublas_lt_wrapper(
283            get_mut_arcmutex!(pipeline).device(),
284        );
285
286        let truncate_sequence = truncate_sequence.unwrap_or(false);
287        let no_kv_cache = no_kv_cache.unwrap_or(false);
288        let no_prefix_cache = no_prefix_cache.unwrap_or(false);
289        let prefix_cache_n = prefix_cache_n.unwrap_or(16);
290        let disable_eos_stop = disable_eos_stop.unwrap_or(false);
291
292        let reboot_state = RebootState {
293            pipeline: pipeline.clone(),
294            method: method.clone(),
295            truncate_sequence,
296            no_kv_cache,
297            no_prefix_cache,
298            prefix_cache_n,
299            disable_eos_stop,
300            throughput_logging_enabled,
301            search_embedding_model: search_embedding_model.clone(),
302        };
303
304        let (tx, rx) = channel(10_000);
305
306        let sender = RwLock::new(tx);
307        let id = pipeline.try_lock().unwrap().name();
308
309        let kind = pipeline.try_lock().unwrap().get_metadata().kind.clone();
310        let device = pipeline.try_lock().unwrap().device();
311        let config = MistralRsConfig {
312            kind,
313            device,
314            category: category.clone(),
315        };
316
317        let engine_handler = thread::spawn(move || {
318            #[cfg(feature = "metal")]
319            objc::rc::autoreleasepool(move || {
320                let rt = Runtime::new().unwrap();
321                rt.block_on(async move {
322                    let engine = Engine::new(
323                        rx,
324                        pipeline,
325                        method,
326                        truncate_sequence,
327                        no_kv_cache,
328                        no_prefix_cache,
329                        prefix_cache_n,
330                        disable_eos_stop,
331                        throughput_logging_enabled,
332                        search_embedding_model,
333                    )
334                    .expect("Engine creation failed.");
335                    Arc::new(engine).run().await;
336                })
337            });
338
339            #[cfg(not(feature = "metal"))]
340            {
341                let rt = Runtime::new().unwrap();
342                rt.block_on(async move {
343                    let engine = Engine::new(
344                        rx,
345                        pipeline,
346                        method,
347                        truncate_sequence,
348                        no_kv_cache,
349                        no_prefix_cache,
350                        prefix_cache_n,
351                        disable_eos_stop,
352                        throughput_logging_enabled,
353                        search_embedding_model,
354                    )
355                    .expect("Engine creation failed.");
356                    Arc::new(engine).run().await;
357                })
358            }
359        });
360
361        let engine_id = ENGINE_ID.fetch_add(1, atomic::Ordering::SeqCst);
362
363        if distributed::is_daemon() {
364            let request_sender = sender.write().unwrap().clone();
365            thread::spawn(move || {
366                let rt = Runtime::new().unwrap();
367                rt.block_on(async move {
368                    use interprocess::local_socket::traits::Stream;
369                    use interprocess::local_socket::Stream as LocalStream;
370
371                    loop {
372                        let name = distributed::ipc_name().unwrap();
373                        if let Ok(stream) = LocalStream::connect(name) {
374                            let mut reader = BufReader::new(stream);
375                            let mut buf = String::new();
376                            reader.read_line(&mut buf).unwrap();
377                            let mut req: Request = serde_json::from_str(&buf).unwrap();
378
379                            req = match req {
380                                Request::ReIsq(x) => Request::ReIsq(x),
381                                Request::Terminate => Request::Terminate,
382                                Request::Detokenize(mut x) => {
383                                    let (sender, mut receiver) = tokio::sync::mpsc::channel(1);
384                                    x.response = sender;
385                                    let req = Request::Detokenize(x);
386
387                                    request_sender.send(req).await.unwrap();
388                                    let resp = receiver.recv().await.unwrap();
389                                    resp.unwrap();
390                                    continue;
391                                }
392                                Request::Tokenize(mut x) => {
393                                    let (sender, mut receiver) = tokio::sync::mpsc::channel(1);
394                                    x.response = sender;
395                                    let req = Request::Tokenize(x);
396
397                                    request_sender.send(req).await.unwrap();
398                                    let resp = receiver.recv().await.unwrap();
399                                    resp.unwrap();
400                                    continue;
401                                }
402                                Request::Normal(mut x) => {
403                                    let (sender, mut receiver) = tokio::sync::mpsc::channel(1);
404                                    x.is_streaming = false;
405                                    x.response = sender;
406                                    let req = Request::Normal(x);
407
408                                    request_sender.send(req).await.unwrap();
409                                    let resp = receiver.recv().await.unwrap();
410                                    resp.as_result().unwrap();
411                                    continue;
412                                }
413                                Request::TerminateAllSeqsNextStep => {
414                                    Request::TerminateAllSeqsNextStep
415                                }
416                            };
417
418                            request_sender.send(req).await.unwrap();
419                        }
420                    }
421                });
422            });
423
424            #[allow(clippy::empty_loop)]
425            loop {}
426        }
427
428        // Determine if the current runtime is multi-threaded, as blocking operations are not allowed in single-threaded mode
429        let is_multi_threaded = tokio::runtime::Handle::try_current()
430            .is_ok_and(|h| h.runtime_flavor() != tokio::runtime::RuntimeFlavor::CurrentThread);
431
432        // Do a dummy run
433        if !distributed::is_daemon()
434            && is_multi_threaded
435            && matches!(category, ModelCategory::Text | ModelCategory::Vision { .. })
436        {
437            let clone_sender = sender.read().unwrap().clone();
438            tokio::task::block_in_place(|| {
439                let (tx, mut rx) = channel(1);
440                let req = Request::Normal(Box::new(NormalRequest {
441                    id: 0,
442                    messages: RequestMessage::Completion {
443                        text: "hello".to_string(),
444                        echo_prompt: false,
445                        best_of: None,
446                    },
447                    sampling_params: SamplingParams {
448                        max_len: Some(1),
449                        ..SamplingParams::deterministic()
450                    },
451                    response: tx,
452                    return_logprobs: false,
453                    is_streaming: false,
454                    constraint: Constraint::None,
455                    suffix: None,
456                    tool_choice: None,
457                    tools: None,
458                    logits_processors: None,
459                    return_raw_logits: false,
460                    web_search_options: None,
461                }));
462                info!("Beginning dummy run.");
463                let start = Instant::now();
464                clone_sender.blocking_send(req).unwrap();
465
466                if let Some(_resp) = rx.blocking_recv() {
467                    let end = Instant::now();
468                    info!(
469                        "Dummy run completed in {}s.",
470                        end.duration_since(start).as_secs_f64()
471                    );
472                } else {
473                    warn!("Dummy run failed!");
474                }
475            });
476        }
477
478        Arc::new(Self {
479            engine_id,
480            sender,
481            log,
482            id,
483            creation_time: SystemTime::now()
484                .duration_since(UNIX_EPOCH)
485                .expect("Time travel has occurred!")
486                .as_secs(),
487            next_request_id: Mutex::new(RefCell::new(1)),
488            reboot_state,
489            engine_handler: RwLock::new(engine_handler),
490            category,
491            config,
492        })
493    }
494
495    /// attempts to reboot the engine, if the sender (only way to communicate with
496    /// the engine) is closed
497    fn reboot_engine(&self) -> Result<(), MistralRsError> {
498        let (new_sender, rx) = channel(10_000);
499        let reboot_state = self.reboot_state.clone();
500        let mut sender_lock = self.sender.write().map_err(|_| {
501            tracing::warn!("Couldn't get write lock on the sender during reboot attempt");
502            MistralRsError::SenderPoisoned
503        })?;
504        let mut engine_lock = self.engine_handler.write().map_err(|_| {
505            tracing::warn!("Couldn't get write lock on the engine during reboot attempt");
506            MistralRsError::EnginePoisoned
507        })?;
508
509        if !engine_lock.is_finished() {
510            tracing::info!("Engine already running, returning ok");
511            Ok(())
512        } else {
513            // critical section. A panic here could lead to poisoned locks
514            let new_engine_handler = thread::spawn(move || {
515                let rt = Runtime::new().unwrap();
516                rt.block_on(async move {
517                    let engine = Engine::new(
518                        rx,
519                        reboot_state.pipeline.clone(),
520                        reboot_state.method,
521                        reboot_state.truncate_sequence,
522                        reboot_state.no_kv_cache,
523                        reboot_state.no_prefix_cache,
524                        reboot_state.prefix_cache_n,
525                        reboot_state.disable_eos_stop,
526                        reboot_state.throughput_logging_enabled,
527                        reboot_state.search_embedding_model,
528                    )
529                    .expect("Engine creation failed");
530                    Arc::new(engine).run().await;
531                });
532            });
533            *sender_lock = new_sender;
534            *engine_lock = new_engine_handler;
535            tracing::info!("Successfully rebooted engine and updated sender + engine handler");
536            Ok(())
537        }
538    }
539
540    fn engine_dead(&self) -> Result<bool, MistralRsError> {
541        match self.engine_handler.read() {
542            Ok(handler) => Ok(handler.is_finished()),
543            Err(_) => {
544                tracing::warn!("Couldn't get read lock on engine!");
545                Err(MistralRsError::EnginePoisoned)
546            }
547        }
548    }
549
550    pub fn get_sender(&self) -> Result<Sender<Request>, MistralRsError> {
551        if self.engine_dead()? {
552            tracing::warn!("Engine is dead, rebooting");
553            self.reboot_engine()?
554        }
555        match self.sender.read() {
556            Ok(sender) => Ok(sender.clone()),
557            Err(_) => Err(MistralRsError::SenderPoisoned),
558        }
559    }
560
561    pub fn get_id(&self) -> String {
562        self.id.clone()
563    }
564
565    pub fn get_creation_time(&self) -> u64 {
566        self.creation_time
567    }
568
569    pub fn get_model_category(&self) -> ModelCategory {
570        self.category.clone()
571    }
572
573    pub fn next_request_id(&self) -> usize {
574        let l = self.next_request_id.lock().unwrap();
575        let last = &mut *l.borrow_mut();
576        let last_v = *last;
577        *last += 1;
578        last_v
579    }
580
581    pub fn maybe_log_request(this: Arc<Self>, repr: String) {
582        if let Some(file) = &this.log {
583            let mut f = OpenOptions::new()
584                .append(true)
585                .create(true) // Optionally create the file if it doesn't already exist
586                .open(file)
587                .expect("Unable to open file");
588            let time = chrono::offset::Local::now();
589            f.write_all(format!("Request at {time}: {repr}\n\n").as_bytes())
590                .expect("Unable to write data");
591        }
592    }
593
594    pub fn maybe_log_response<T: Serialize>(this: Arc<Self>, resp: &T) {
595        if let Some(file) = &this.log {
596            let mut f = OpenOptions::new()
597                .append(true)
598                .create(true) // Optionally create the file if it doesn't already exist
599                .open(file)
600                .expect("Unable to open file");
601            let time = chrono::offset::Local::now();
602            let repr = serde_json::to_string(resp).expect("Serialization of response failed.");
603            f.write_all(format!("Response at {time}: {repr}\n\n").as_bytes())
604                .expect("Unable to write data");
605        }
606    }
607
608    pub fn maybe_log_error(this: Arc<Self>, err: &dyn Error) {
609        if let Some(file) = &this.log {
610            let mut f = OpenOptions::new()
611                .append(true)
612                .create(true) // Optionally create the file if it doesn't already exist
613                .open(file)
614                .expect("Unable to open file");
615            let time = chrono::offset::Local::now();
616            f.write_all(format!("Error response at {time}: {err}\n\n").as_bytes())
617                .expect("Unable to write data");
618        }
619    }
620
621    pub fn config(&self) -> &MistralRsConfig {
622        &self.config
623    }
624}