mistralrs_core/lib.rs

#![deny(clippy::cast_possible_truncation, clippy::cast_precision_loss)]
use candle_core::Device;
use engine::Engine;
pub use engine::{
    get_engine_terminate_flag, reset_engine_terminate_flag, should_terminate_engine_sequences,
    BertEmbeddingModel, EngineInstruction, ENGINE_INSTRUCTIONS, TERMINATE_ALL_NEXT_STEP,
};
use hf_hub::Cache;
pub use lora::Ordering;
pub use pipeline::ModelCategory;
pub use pipeline::Pipeline;
#[cfg(feature = "pyo3_macros")]
use pyo3::exceptions::PyValueError;
use std::collections::HashMap;
use std::sync::OnceLock;
use std::time::Instant;
use std::{
    cell::RefCell,
    error::Error,
    fs::OpenOptions,
    io::Write,
    sync::{atomic::AtomicBool, Arc, Mutex, RwLock},
    thread::{self, JoinHandle},
    time::{SystemTime, UNIX_EPOCH},
};
use tokio::sync::mpsc::{channel, Sender};
use tracing::info;
use tracing::warn;

mod cuda;
mod device_map;
mod engine;
mod lora;
mod model_loader;
mod ops;
pub use model_loader::{
    get_auto_device_map_params, get_model_dtype, get_tgt_non_granular_index, LoaderBuilder,
};
mod embedding_models;
mod kv_cache;
mod search;

mod model_selected;
pub use model_selected::ModelSelected;
pub use toml_selector::{get_toml_selected_model_device_map_params, get_toml_selected_model_dtype};

mod amoe;
mod attention;
mod diffusion_models;
pub mod distributed;
mod embedding;
mod gguf;
pub mod layers;
mod layers_masker;
mod layers_utils;
pub mod matformer;
mod models;
mod paged_attention;
mod pipeline;
mod prefix_cacher;
mod request;
mod response;
mod sampler;
mod scheduler;
mod sequence;
mod speech_models;
mod toml_selector;
mod tools;
mod topology;
mod utils;
mod vision_models;
mod xlora_models;

pub use amoe::{AnyMoeConfig, AnyMoeExpertType};
pub use device_map::{
    DeviceLayerMapMetadata, DeviceMapMetadata, DeviceMapSetting, LayerDeviceMapper,
};
pub use gguf::{GGUFArchitecture, GGUF_MULTI_FILE_DELIMITER};
pub use mistralrs_audio::AudioInput;
pub use mistralrs_mcp::{
    CalledFunction, Function, Tool, ToolCallback, ToolCallbackWithTool, ToolType,
};
pub use mistralrs_mcp::{
    McpClient, McpClientConfig, McpServerConfig, McpServerSource, McpToolInfo,
};
pub use mistralrs_quant::{IsqType, MULTI_LORA_DELIMITER};
pub use paged_attention::{MemoryGpuConfig, PagedAttentionConfig, PagedCacheType};
pub use pipeline::{
    chat_template::ChatTemplate, parse_isq_value, AdapterPaths, AnyMoeLoader, AnyMoePipeline,
    AutoDeviceMapParams, AutoLoader, AutoLoaderBuilder, DiffusionGenerationParams, DiffusionLoader,
    DiffusionLoaderBuilder, DiffusionLoaderType, EmbeddingLoader, EmbeddingLoaderBuilder,
    EmbeddingLoaderType, EmbeddingModelPaths, EmbeddingSpecificConfig, GGMLLoader,
    GGMLLoaderBuilder, GGMLSpecificConfig, GGUFLoader, GGUFLoaderBuilder, GGUFSpecificConfig,
    GemmaLoader, Idefics2Loader, IsqOrganization, LLaVALoader, LLaVANextLoader, LlamaLoader,
    Loader, LocalModelPaths, LoraAdapterPaths, MistralLoader, MixtralLoader, Modalities, ModelKind,
    ModelPaths, MultimodalPromptPrefixer, NormalLoader, NormalLoaderBuilder, NormalLoaderType,
    NormalSpecificConfig, Phi2Loader, Phi3Loader, Phi3VLoader, Qwen2Loader, SpeculativeConfig,
    SpeculativeLoader, SpeculativePipeline, SpeechLoader, SpeechPipeline, Starcoder2Loader,
    SupportedModality, TokenSource, VisionLoader, VisionLoaderBuilder, VisionLoaderType,
    VisionSpecificConfig, UQFF_MULTI_FILE_DELIMITER,
};
pub use request::{
    ApproximateUserLocation, Constraint, DetokenizationRequest, ImageGenerationResponseFormat,
    LlguidanceGrammar, MessageContent, NormalRequest, Request, RequestMessage, SearchContextSize,
    TokenizationRequest, WebSearchOptions, WebSearchUserLocation,
};
pub use response::*;
pub use sampler::{
    CustomLogitsProcessor, DrySamplingParams, SamplingParams, StopTokens, TopLogprob,
};
pub use scheduler::{DefaultSchedulerMethod, SchedulerConfig};
pub use search::{SearchCallback, SearchFunctionParameters, SearchResult};
use serde::Serialize;
pub use speech_models::{utils as speech_utils, SpeechGenerationConfig, SpeechLoaderType};
use tokio::runtime::Runtime;
use toml_selector::{TomlLoaderArgs, TomlSelector};
pub use tools::{ToolCallResponse, ToolCallType, ToolCallbacks, ToolChoice};
pub use topology::{LayerTopology, Topology};
pub use utils::debug::initialize_logging;
pub use utils::memory_usage::MemoryUsage;
pub use utils::normal::{ModelDType, TryIntoDType};
pub use utils::{paged_attn_supported, using_flash_attn};

// re-export llguidance for easier LlguidanceGrammar construction
pub use llguidance;

/// `true` if `MISTRALRS_DEBUG=1`
pub(crate) static DEBUG: AtomicBool = AtomicBool::new(false);
pub static GLOBAL_HF_CACHE: OnceLock<Cache> = OnceLock::new();

/// Configuration for creating an engine instance
#[derive(Clone)]
pub struct EngineConfig {
    pub no_kv_cache: bool,
    pub no_prefix_cache: bool,
    pub prefix_cache_n: usize,
    pub disable_eos_stop: bool,
    pub throughput_logging_enabled: bool,
    pub search_embedding_model: Option<BertEmbeddingModel>,
    pub search_callback: Option<Arc<SearchCallback>>,
    pub tool_callbacks: tools::ToolCallbacks,
    pub tool_callbacks_with_tools: tools::ToolCallbacksWithTools,
}

impl Default for EngineConfig {
    fn default() -> Self {
        Self {
            no_kv_cache: false,
            no_prefix_cache: false,
            prefix_cache_n: 16,
            disable_eos_stop: false,
            throughput_logging_enabled: true,
            search_embedding_model: None,
            search_callback: None,
            tool_callbacks: HashMap::new(),
            tool_callbacks_with_tools: HashMap::new(),
        }
    }
}
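
// Illustrative sketch (not part of the original module): `EngineConfig` implements
// `Default`, so callers can override individual fields with struct-update syntax
// while keeping the defaults for the rest. The overridden values are hypothetical.
#[allow(dead_code)]
fn example_engine_config() -> EngineConfig {
    EngineConfig {
        no_kv_cache: true,
        prefix_cache_n: 32,
        ..Default::default()
    }
}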

/// Configuration for adding a model to MistralRs
#[derive(Clone)]
pub struct AddModelConfig {
    pub engine_config: EngineConfig,
    pub mcp_client_config: Option<McpClientConfig>,
}

impl AddModelConfig {
    pub fn new(engine_config: EngineConfig) -> Self {
        Self {
            engine_config,
            mcp_client_config: None,
        }
    }

    pub fn with_mcp_config(mut self, mcp_config: McpClientConfig) -> Self {
        self.mcp_client_config = Some(mcp_config);
        self
    }
}
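
// Illustrative sketch (not part of the original module): wrapping the default engine
// settings in an `AddModelConfig` and optionally attaching an MCP client configuration.
#[allow(dead_code)]
fn example_add_model_config(mcp_config: Option<McpClientConfig>) -> AddModelConfig {
    let config = AddModelConfig::new(EngineConfig::default());
    match mcp_config {
        Some(mcp) => config.with_mcp_config(mcp),
        None => config,
    }
}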

#[derive(Clone)]
pub struct MistralRsConfig {
    pub kind: ModelKind,
    pub device: Device,
    pub category: ModelCategory,
    pub modalities: Modalities,
}

/// Internal structure to hold per-engine state
struct EngineInstance {
    sender: Sender<Request>,
    engine_handler: JoinHandle<()>,
    reboot_state: RebootState,
    config: MistralRsConfig,
    category: ModelCategory,
}

/// The MistralRs struct handles sending requests to multiple engines.
/// It is the core multi-threaded component of mistral.rs, and uses `mpsc`
/// `Sender` and `Receiver` primitives to route each request to the
/// appropriate engine based on its model ID.
pub struct MistralRs {
    engines: RwLock<HashMap<String, EngineInstance>>,
    default_engine_id: RwLock<Option<String>>,
    log: Option<String>,
    id: String,
    creation_time: u64,
    next_request_id: Mutex<RefCell<usize>>,
}

#[derive(Clone)]
struct RebootState {
    pipeline: Arc<tokio::sync::Mutex<dyn Pipeline>>,
    method: SchedulerConfig,
    no_kv_cache: bool,
    no_prefix_cache: bool,
    prefix_cache_n: usize,
    disable_eos_stop: bool,
    throughput_logging_enabled: bool,
    search_embedding_model: Option<BertEmbeddingModel>,
    search_callback: Option<Arc<search::SearchCallback>>,
    tool_callbacks: tools::ToolCallbacks,
    tool_callbacks_with_tools: tools::ToolCallbacksWithTools,
    mcp_client_config: Option<McpClientConfig>,
}

#[derive(Debug)]
pub enum MistralRsError {
    EnginePoisoned,
    SenderPoisoned,
}

impl std::fmt::Display for MistralRsError {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        write!(f, "{:?}", &self)
    }
}

impl std::error::Error for MistralRsError {}

#[cfg(feature = "pyo3_macros")]
impl From<MistralRsError> for pyo3::PyErr {
    fn from(value: MistralRsError) -> Self {
        PyValueError::new_err(format!("{value:?}"))
    }
}

/// The MistralRsBuilder takes the pipeline and a scheduler method and constructs
/// an Engine and a MistralRs instance. The Engine runs on a separate thread, and the MistralRs
/// instance stays on the calling thread.
pub struct MistralRsBuilder {
    pipeline: Arc<tokio::sync::Mutex<dyn Pipeline>>,
    method: SchedulerConfig,
    log: Option<String>,
    no_kv_cache: Option<bool>,
    no_prefix_cache: Option<bool>,
    prefix_cache_n: Option<usize>,
    disable_eos_stop: Option<bool>,
    throughput_logging_enabled: bool,
    search_embedding_model: Option<BertEmbeddingModel>,
    search_callback: Option<Arc<SearchCallback>>,
    tool_callbacks: tools::ToolCallbacks,
    tool_callbacks_with_tools: tools::ToolCallbacksWithTools,
    mcp_client_config: Option<McpClientConfig>,
}

impl MistralRsBuilder {
    /// Creates a new builder with the given pipeline, scheduler method, throughput
    /// logging flag, and optional embedding model for web search. To override the
    /// search callback, use `.with_search_callback(...)` on the builder.
    pub fn new(
        pipeline: Arc<tokio::sync::Mutex<dyn Pipeline>>,
        method: SchedulerConfig,
        throughput_logging: bool,
        search_embedding_model: Option<BertEmbeddingModel>,
    ) -> Self {
        Self {
            pipeline,
            method,
            log: None,
            no_kv_cache: None,
            no_prefix_cache: None,
            prefix_cache_n: None,
            disable_eos_stop: None,
            throughput_logging_enabled: throughput_logging,
            search_embedding_model,
            search_callback: None,
            tool_callbacks: HashMap::new(),
            tool_callbacks_with_tools: HashMap::new(),
            mcp_client_config: None,
        }
    }
    pub fn with_log(mut self, log: String) -> Self {
        self.log = Some(log);
        self
    }
    pub fn with_opt_log(mut self, log: Option<String>) -> Self {
        self.log = log;
        self
    }
    pub fn with_no_kv_cache(mut self, no_kv_cache: bool) -> Self {
        self.no_kv_cache = Some(no_kv_cache);
        self
    }
    pub fn with_no_prefix_cache(mut self, no_prefix_cache: bool) -> Self {
        self.no_prefix_cache = Some(no_prefix_cache);
        self
    }
    pub fn with_prefix_cache_n(mut self, prefix_cache_n: usize) -> Self {
        self.prefix_cache_n = Some(prefix_cache_n);
        self
    }
    pub fn with_disable_eos_stop(mut self, disable_eos_stop: bool) -> Self {
        self.disable_eos_stop = Some(disable_eos_stop);
        self
    }

    /// Use a custom callback to gather search results.
    pub fn with_search_callback(mut self, search_callback: Arc<SearchCallback>) -> Self {
        self.search_callback = Some(search_callback);
        self
    }

    /// Register a custom callback for the specified tool name.
    pub fn with_tool_callback(
        mut self,
        name: impl Into<String>,
        tool_callback: Arc<ToolCallback>,
    ) -> Self {
        self.tool_callbacks.insert(name.into(), tool_callback);
        self
    }

    /// Register a custom callback with its associated Tool definition. The Tool will be
    /// automatically added to requests when tool callbacks are active.
    pub fn with_tool_callback_and_tool(
        mut self,
        name: impl Into<String>,
        tool_callback: Arc<ToolCallback>,
        tool: Tool,
    ) -> Self {
        let name = name.into();
        self.tool_callbacks_with_tools.insert(
            name,
            ToolCallbackWithTool {
                callback: tool_callback,
                tool,
            },
        );
        self
    }

    /// Configure MCP client to connect to external MCP servers.
    pub fn with_mcp_client(mut self, config: McpClientConfig) -> Self {
        self.mcp_client_config = Some(config);
        self
    }

    pub async fn build(self) -> Arc<MistralRs> {
        MistralRs::new(self).await
    }
}
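
// Illustrative sketch (not part of the original module): wiring an already-loaded
// pipeline and a scheduler method into a running `MistralRs` instance via the builder.
// `pipeline` and `method` are assumed to come from a separate model-loading step.
#[allow(dead_code)]
async fn example_build_mistralrs(
    pipeline: Arc<tokio::sync::Mutex<dyn Pipeline>>,
    method: SchedulerConfig,
) -> Arc<MistralRs> {
    MistralRsBuilder::new(pipeline, method, /* throughput_logging */ false, None)
        .with_no_kv_cache(false)
        .with_prefix_cache_n(16)
        .build()
        .await
}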

impl Drop for MistralRs {
    fn drop(&mut self) {
        // Terminate all engines
        if let Ok(engines) = self.engines.read() {
            for (_, engine) in engines.iter() {
                // Use try_send instead of blocking_send to avoid runtime panics
                let _ = engine.sender.try_send(Request::Terminate);
            }
        }
    }
}

impl MistralRs {
    /// Create an engine instance with the given configuration
    fn create_engine_instance(
        pipeline: Arc<tokio::sync::Mutex<dyn Pipeline>>,
        method: SchedulerConfig,
        config: EngineConfig,
        reboot_state: RebootState,
    ) -> Result<EngineInstance, String> {
        let (tx, rx) = channel(10_000);

        let category = pipeline.try_lock().unwrap().category();
        let kind = pipeline.try_lock().unwrap().get_metadata().kind.clone();
        let device = pipeline.try_lock().unwrap().device();
        let modalities = pipeline
            .try_lock()
            .unwrap()
            .get_metadata()
            .modalities
            .clone();

        info!("Pipeline input modalities are {:?}", &modalities.input);
        info!("Pipeline output modalities are {:?}", &modalities.output);

        let mistralrs_config = MistralRsConfig {
            kind,
            device,
            category: category.clone(),
            modalities,
        };

        let engine_handler = thread::spawn(move || {
            #[cfg(feature = "metal")]
            objc::rc::autoreleasepool(move || {
                let rt = Runtime::new().unwrap();
                rt.block_on(async move {
                    let engine = Engine::new(
                        rx,
                        pipeline,
                        method,
                        config.no_kv_cache,
                        config.no_prefix_cache,
                        config.prefix_cache_n,
                        config.disable_eos_stop,
                        config.throughput_logging_enabled,
                        config.search_embedding_model,
                        config.search_callback.clone(),
                        config.tool_callbacks.clone(),
                        config.tool_callbacks_with_tools.clone(),
                    )
                    .expect("Engine creation failed.");
                    Arc::new(engine).run().await;
                })
            });

            #[cfg(not(feature = "metal"))]
            {
                let rt = Runtime::new().unwrap();
                rt.block_on(async move {
                    let engine = Engine::new(
                        rx,
                        pipeline,
                        method,
                        config.no_kv_cache,
                        config.no_prefix_cache,
                        config.prefix_cache_n,
                        config.disable_eos_stop,
                        config.throughput_logging_enabled,
                        config.search_embedding_model,
                        config.search_callback.clone(),
                        config.tool_callbacks.clone(),
                        config.tool_callbacks_with_tools.clone(),
                    )
                    .expect("Engine creation failed.");
                    Arc::new(engine).run().await;
                })
            }
        });

        Ok(EngineInstance {
            sender: tx,
            engine_handler,
            reboot_state,
            config: mistralrs_config,
            category,
        })
    }

    async fn new(config: MistralRsBuilder) -> Arc<Self> {
        let MistralRsBuilder {
            pipeline,
            method,
            log,
            no_kv_cache,
            no_prefix_cache,
            prefix_cache_n,
            disable_eos_stop,
            throughput_logging_enabled,
            search_embedding_model,
            search_callback,
            tool_callbacks,
            mut tool_callbacks_with_tools,
            mcp_client_config,
        } = config;

        mistralrs_quant::cublaslt::maybe_init_cublas_lt_wrapper(
            get_mut_arcmutex!(pipeline).device(),
        );

        let no_kv_cache = no_kv_cache.unwrap_or(false);
        let no_prefix_cache = no_prefix_cache.unwrap_or(false);
        let prefix_cache_n = prefix_cache_n.unwrap_or(16);
        let disable_eos_stop = disable_eos_stop.unwrap_or(false);

        // Initialize MCP client if configured
        if let Some(config) = &mcp_client_config {
            let mut mcp_client = McpClient::new(config.clone());
            let total_servers = config.servers.len();

            match mcp_client.initialize().await {
                Ok(()) => {
                    let mcp_callbacks_with_tools = mcp_client.get_tool_callbacks_with_tools();
                    let tools_count = mcp_callbacks_with_tools.len();

                    // Merge MCP tool callbacks with tools into the new collection
                    for (name, callback_with_tool) in mcp_callbacks_with_tools {
                        tool_callbacks_with_tools.insert(name.clone(), callback_with_tool.clone());
                    }

                    if tools_count == 0 {
                        warn!(
                            "MCP client initialized but no tools were registered from {} servers",
                            total_servers
                        );
                    } else {
                        info!(
                            "MCP client initialized successfully with {} tools from {} servers",
                            tools_count, total_servers
                        );
                    }
                }
                Err(e) => {
                    warn!(
                        "Failed to initialize MCP client with {} configured servers: {}",
                        total_servers, e
                    );
                    warn!("Continuing without MCP functionality. Check your MCP configuration and server availability.");
                }
            }
        }

        let reboot_state = RebootState {
            pipeline: pipeline.clone(),
            method: method.clone(),
            no_kv_cache,
            no_prefix_cache,
            prefix_cache_n,
            disable_eos_stop,
            throughput_logging_enabled,
            search_embedding_model: search_embedding_model.clone(),
            search_callback: search_callback.clone(),
            tool_callbacks: tool_callbacks.clone(),
            tool_callbacks_with_tools: tool_callbacks_with_tools.clone(),
            mcp_client_config: mcp_client_config.clone(),
        };

        // Create the engine configuration
        let engine_config = EngineConfig {
            no_kv_cache,
            no_prefix_cache,
            prefix_cache_n,
            disable_eos_stop,
            throughput_logging_enabled,
            search_embedding_model,
            search_callback,
            tool_callbacks,
            tool_callbacks_with_tools,
        };

        // Create the engine instance
        let engine_instance =
            Self::create_engine_instance(pipeline.clone(), method, engine_config, reboot_state)
                .expect("Failed to create engine instance");

        let id = pipeline.try_lock().unwrap().name();

        if distributed::is_daemon() {
            let request_sender = engine_instance.sender.clone();

            if cfg!(feature = "ring") {
                // Ring daemon replicator
                distributed::ring_daemon_replicator(request_sender);
            } else {
                // NCCL daemon replicator
                distributed::nccl_daemon_replicator(request_sender);
            }

            #[allow(clippy::empty_loop)]
            loop {}
        }

        // Determine if the current runtime is multi-threaded, as blocking operations are not allowed in single-threaded mode
        let is_multi_threaded = tokio::runtime::Handle::try_current()
            .is_ok_and(|h| h.runtime_flavor() != tokio::runtime::RuntimeFlavor::CurrentThread);

        // Do a dummy run
        if !distributed::is_daemon()
            && is_multi_threaded
            && matches!(
                engine_instance.category,
                ModelCategory::Text | ModelCategory::Vision { .. }
            )
        {
            let clone_sender = engine_instance.sender.clone();
            tokio::task::block_in_place(|| {
                let (tx, mut rx) = channel(1);
                let req = Request::Normal(Box::new(NormalRequest {
                    id: 0,
                    messages: RequestMessage::Completion {
                        text: "hello".to_string(),
                        echo_prompt: false,
                        best_of: None,
                    },
                    sampling_params: SamplingParams {
                        max_len: Some(1),
                        ..SamplingParams::deterministic()
                    },
                    response: tx,
                    return_logprobs: false,
                    is_streaming: false,
                    constraint: Constraint::None,
                    suffix: None,
                    tool_choice: None,
                    tools: None,
                    logits_processors: None,
                    return_raw_logits: false,
                    web_search_options: None,
                    model_id: None,
                    truncate_sequence: false,
                }));
                info!("Beginning dummy run.");
                let start = Instant::now();
                clone_sender.blocking_send(req).unwrap();

                // Drain all responses from the channel until it's closed
                let mut received_any = false;
                while let Some(_resp) = rx.blocking_recv() {
                    received_any = true;
                }

                if received_any {
                    let end = Instant::now();
                    info!(
                        "Dummy run completed in {}s.",
                        end.duration_since(start).as_secs_f64()
                    );
                } else {
                    warn!("Dummy run failed!");
                }
            });
        }

        // Create engines map with the first engine
        let mut engines = HashMap::new();
        engines.insert(id.clone(), engine_instance);

        Arc::new(Self {
            engines: RwLock::new(engines),
            default_engine_id: RwLock::new(Some(id.clone())),
            log,
            id,
            creation_time: SystemTime::now()
                .duration_since(UNIX_EPOCH)
                .expect("Time travel has occurred!")
                .as_secs(),
            next_request_id: Mutex::new(RefCell::new(1)),
        })
    }

    /// Attempts to reboot a specific engine by model_id
    fn reboot_engine(&self, model_id: &str) -> Result<(), MistralRsError> {
        let mut engines = self.engines.write().map_err(|_| {
            tracing::warn!("Couldn't get write lock on engines during reboot attempt");
            MistralRsError::EnginePoisoned
        })?;

        if let Some(engine_instance) = engines.get(model_id) {
            if !engine_instance.engine_handler.is_finished() {
                tracing::info!("Engine {} already running, returning ok", model_id);
                return Ok(());
            }

            let reboot_state = engine_instance.reboot_state.clone();
            let engine_config = EngineConfig {
                no_kv_cache: reboot_state.no_kv_cache,
                no_prefix_cache: reboot_state.no_prefix_cache,
                prefix_cache_n: reboot_state.prefix_cache_n,
                disable_eos_stop: reboot_state.disable_eos_stop,
                throughput_logging_enabled: reboot_state.throughput_logging_enabled,
                search_embedding_model: reboot_state.search_embedding_model.clone(),
                search_callback: reboot_state.search_callback.clone(),
                tool_callbacks: reboot_state.tool_callbacks.clone(),
                tool_callbacks_with_tools: reboot_state.tool_callbacks_with_tools.clone(),
            };
            let new_engine_instance = Self::create_engine_instance(
                reboot_state.pipeline.clone(),
                reboot_state.method.clone(),
                engine_config,
                reboot_state,
            )
            .map_err(|e| {
                tracing::error!("Failed to create new engine instance: {}", e);
                MistralRsError::EnginePoisoned
            })?;

            engines.insert(model_id.to_string(), new_engine_instance);
            tracing::info!("Successfully rebooted engine {}", model_id);
            Ok(())
        } else {
            Err(MistralRsError::EnginePoisoned)
        }
    }

    fn engine_dead(&self, model_id: &str) -> Result<bool, MistralRsError> {
        let engines = self.engines.read().map_err(|_| {
            tracing::warn!("Couldn't get read lock on engines!");
            MistralRsError::EnginePoisoned
        })?;

        if let Some(engine_instance) = engines.get(model_id) {
            Ok(engine_instance.engine_handler.is_finished())
        } else {
            Err(MistralRsError::EnginePoisoned)
        }
    }

    /// Get sender for a specific model. If model_id is None, uses default engine.
    pub fn get_sender(&self, model_id: Option<&str>) -> Result<Sender<Request>, MistralRsError> {
        let resolved_model_id = match model_id {
            Some(id) => id.to_string(),
            None => {
                let default_lock = self
                    .default_engine_id
                    .read()
                    .map_err(|_| MistralRsError::SenderPoisoned)?;
                default_lock
                    .as_ref()
                    .ok_or(MistralRsError::EnginePoisoned)?
                    .clone()
            }
        };

        if self.engine_dead(&resolved_model_id)? {
            tracing::warn!("Engine {} is dead, rebooting", resolved_model_id);
            self.reboot_engine(&resolved_model_id)?
        }

        let engines = self
            .engines
            .read()
            .map_err(|_| MistralRsError::SenderPoisoned)?;
        if let Some(engine_instance) = engines.get(&resolved_model_id) {
            Ok(engine_instance.sender.clone())
        } else {
            Err(MistralRsError::EnginePoisoned)
        }
    }
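
    // Illustrative sketch (not part of the original module): resolving the sender for a
    // named engine and falling back to the default engine; "my-model" is a hypothetical ID.
    #[allow(dead_code)]
    fn example_resolve_sender(&self) -> Result<Sender<Request>, MistralRsError> {
        self.get_sender(Some("my-model"))
            .or_else(|_| self.get_sender(None))
    }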

    pub fn get_id(&self) -> String {
        self.id.clone()
    }

    pub fn get_creation_time(&self) -> u64 {
        self.creation_time
    }

    /// Get model category for a specific model. If model_id is None, uses default engine.
    pub fn get_model_category(
        &self,
        model_id: Option<&str>,
    ) -> Result<ModelCategory, MistralRsError> {
        let resolved_model_id = match model_id {
            Some(id) => id.to_string(),
            None => {
                let default_lock = self
                    .default_engine_id
                    .read()
                    .map_err(|_| MistralRsError::SenderPoisoned)?;
                default_lock
                    .as_ref()
                    .ok_or(MistralRsError::EnginePoisoned)?
                    .clone()
            }
        };

        let engines = self
            .engines
            .read()
            .map_err(|_| MistralRsError::SenderPoisoned)?;
        if let Some(engine_instance) = engines.get(&resolved_model_id) {
            Ok(engine_instance.category.clone())
        } else {
            Err(MistralRsError::EnginePoisoned)
        }
    }

    pub fn next_request_id(&self) -> usize {
        let l = self.next_request_id.lock().unwrap();
        let last = &mut *l.borrow_mut();
        let last_v = *last;
        *last += 1;
        last_v
    }

    /// Add a new model engine to the MistralRs instance
    pub async fn add_model(
        &self,
        model_id: String,
        pipeline: Arc<tokio::sync::Mutex<dyn Pipeline>>,
        method: SchedulerConfig,
        config: AddModelConfig,
    ) -> Result<(), String> {
        let reboot_state = RebootState {
            pipeline: pipeline.clone(),
            method: method.clone(),
            no_kv_cache: config.engine_config.no_kv_cache,
            no_prefix_cache: config.engine_config.no_prefix_cache,
            prefix_cache_n: config.engine_config.prefix_cache_n,
            disable_eos_stop: config.engine_config.disable_eos_stop,
            throughput_logging_enabled: config.engine_config.throughput_logging_enabled,
            search_embedding_model: config.engine_config.search_embedding_model.clone(),
            search_callback: config.engine_config.search_callback.clone(),
            tool_callbacks: config.engine_config.tool_callbacks.clone(),
            tool_callbacks_with_tools: config.engine_config.tool_callbacks_with_tools.clone(),
            mcp_client_config: config.mcp_client_config.clone(),
        };

        let engine_instance =
            Self::create_engine_instance(pipeline, method, config.engine_config, reboot_state)?;

        let mut engines = self
            .engines
            .write()
            .map_err(|_| "Failed to acquire write lock on engines")?;
        engines.insert(model_id.clone(), engine_instance);

        // If this is the first model, set it as default
        if engines.len() == 1 {
            let mut default_lock = self
                .default_engine_id
                .write()
                .map_err(|_| "Failed to acquire write lock on default_engine_id")?;
            *default_lock = Some(model_id.clone());
        }

        Ok(())
    }
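
    // Illustrative sketch (not part of the original module): registering a second model
    // under a hypothetical ID and making it the default. `pipeline` and `method` are
    // assumed to come from a separate model-loading step.
    #[allow(dead_code)]
    async fn example_add_second_model(
        &self,
        pipeline: Arc<tokio::sync::Mutex<dyn Pipeline>>,
        method: SchedulerConfig,
    ) -> Result<(), String> {
        let config = AddModelConfig::new(EngineConfig::default());
        self.add_model("secondary-model".to_string(), pipeline, method, config)
            .await?;
        self.set_default_model_id("secondary-model")?;
        Ok(())
    }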

    /// Remove a model engine from the MistralRs instance
    pub fn remove_model(&self, model_id: &str) -> Result<(), String> {
        let mut engines = self
            .engines
            .write()
            .map_err(|_| "Failed to acquire write lock on engines")?;

        if engines.len() <= 1 {
            return Err("Cannot remove the last model from MistralRs".to_string());
        }

        if let Some(engine_instance) = engines.remove(model_id) {
            // Send terminate signal to the engine
            let _ = engine_instance.sender.blocking_send(Request::Terminate);

            // If this was the default engine, set a new default
            let mut default_lock = self
                .default_engine_id
                .write()
                .map_err(|_| "Failed to acquire write lock on default_engine_id")?;
            if let Some(ref default_id) = *default_lock {
                if default_id == model_id {
                    // Set the first available engine as the new default
                    *default_lock = engines.keys().next().cloned();
                }
            }

            Ok(())
        } else {
            Err(format!("Model {model_id} not found"))
        }
    }

    /// List all available model IDs
    pub fn list_models(&self) -> Result<Vec<String>, String> {
        let engines = self
            .engines
            .read()
            .map_err(|_| "Failed to acquire read lock on engines")?;
        Ok(engines.keys().cloned().collect())
    }

    /// Get the current default model ID
    pub fn get_default_model_id(&self) -> Result<Option<String>, String> {
        let default_lock = self
            .default_engine_id
            .read()
            .map_err(|_| "Failed to acquire read lock on default_engine_id")?;
        Ok(default_lock.clone())
    }

    /// Set the default model ID
    pub fn set_default_model_id(&self, model_id: &str) -> Result<(), String> {
        let engines = self
            .engines
            .read()
            .map_err(|_| "Failed to acquire read lock on engines")?;
        if !engines.contains_key(model_id) {
            return Err(format!("Model {model_id} not found"));
        }
        drop(engines);

        let mut default_lock = self
            .default_engine_id
            .write()
            .map_err(|_| "Failed to acquire write lock on default_engine_id")?;
        *default_lock = Some(model_id.to_string());

        Ok(())
    }

    /// Dispatch a request to the appropriate engine based on the model_id in the request
    pub fn send_request(&self, mut request: Request) -> Result<(), MistralRsError> {
        let model_id = match &mut request {
            Request::Normal(normal_req) => normal_req.model_id.as_deref(),
            _ => None, // Other request types don't specify model_id
        };

        let sender = self.get_sender(model_id)?;
        sender
            .blocking_send(request)
            .map_err(|_| MistralRsError::SenderPoisoned)
    }
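
    // Illustrative sketch (not part of the original module): building a minimal completion
    // request and dispatching it. `send_request` uses `blocking_send`, so it must be called
    // from outside an async context (or from a blocking section of a multi-threaded runtime).
    #[allow(dead_code)]
    fn example_send_completion(&self) -> Result<(), MistralRsError> {
        let (tx, _rx) = channel(1);
        let request = Request::Normal(Box::new(NormalRequest {
            id: self.next_request_id(),
            messages: RequestMessage::Completion {
                text: "hello".to_string(),
                echo_prompt: false,
                best_of: None,
            },
            sampling_params: SamplingParams::deterministic(),
            response: tx,
            return_logprobs: false,
            is_streaming: false,
            constraint: Constraint::None,
            suffix: None,
            tool_choice: None,
            tools: None,
            logits_processors: None,
            return_raw_logits: false,
            web_search_options: None,
            // `None` routes the request to the default engine.
            model_id: None,
            truncate_sequence: false,
        }));
        self.send_request(request)
    }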

    pub fn maybe_log_request(this: Arc<Self>, repr: String) {
        if let Some(file) = &this.log {
            let mut f = OpenOptions::new()
                .append(true)
                .create(true) // Optionally create the file if it doesn't already exist
                .open(file)
                .expect("Unable to open file");
            let time = chrono::offset::Local::now();
            f.write_all(format!("Request at {time}: {repr}\n\n").as_bytes())
                .expect("Unable to write data");
        }
    }

    pub fn maybe_log_response<T: Serialize>(this: Arc<Self>, resp: &T) {
        if let Some(file) = &this.log {
            let mut f = OpenOptions::new()
                .append(true)
                .create(true) // Optionally create the file if it doesn't already exist
                .open(file)
                .expect("Unable to open file");
            let time = chrono::offset::Local::now();
            let repr = serde_json::to_string(resp).expect("Serialization of response failed.");
            f.write_all(format!("Response at {time}: {repr}\n\n").as_bytes())
                .expect("Unable to write data");
        }
    }

    pub fn maybe_log_error(this: Arc<Self>, err: &dyn Error) {
        if let Some(file) = &this.log {
            let mut f = OpenOptions::new()
                .append(true)
                .create(true) // Optionally create the file if it doesn't already exist
                .open(file)
                .expect("Unable to open file");
            let time = chrono::offset::Local::now();
            f.write_all(format!("Error response at {time}: {err}\n\n").as_bytes())
                .expect("Unable to write data");
        }
    }

    /// Get the number of tools available for a specific model (including MCP tools)
    pub fn get_tools_count(&self, model_id: Option<&str>) -> Result<usize, String> {
        let resolved_model_id = match model_id {
            Some(id) => id.to_string(),
            None => {
                let default_lock = self
                    .default_engine_id
                    .read()
                    .map_err(|_| "Failed to acquire read lock")?;
                default_lock
                    .as_ref()
                    .ok_or("No default engine set")?
                    .clone()
            }
        };

        let engines = self
            .engines
            .read()
            .map_err(|_| "Failed to acquire read lock on engines")?;
        if let Some(engine_instance) = engines.get(&resolved_model_id) {
            Ok(engine_instance.reboot_state.tool_callbacks_with_tools.len())
        } else {
            Err(format!("Model {resolved_model_id} not found"))
        }
    }

    /// Check if MCP client is configured for a specific model
    pub fn has_mcp_client(&self, model_id: Option<&str>) -> Result<bool, String> {
        let resolved_model_id = match model_id {
            Some(id) => id.to_string(),
            None => {
                let default_lock = self
                    .default_engine_id
                    .read()
                    .map_err(|_| "Failed to acquire read lock")?;
                default_lock
                    .as_ref()
                    .ok_or("No default engine set")?
                    .clone()
            }
        };

        let engines = self
            .engines
            .read()
            .map_err(|_| "Failed to acquire read lock on engines")?;
        if let Some(engine_instance) = engines.get(&resolved_model_id) {
            Ok(engine_instance.reboot_state.mcp_client_config.is_some())
        } else {
            Err(format!("Model {resolved_model_id} not found"))
        }
    }

    /// Get config for a specific model
    pub fn config(&self, model_id: Option<&str>) -> Result<MistralRsConfig, String> {
        let resolved_model_id = match model_id {
            Some(id) => id.to_string(),
            None => {
                let default_lock = self
                    .default_engine_id
                    .read()
                    .map_err(|_| "Failed to acquire read lock")?;
                default_lock
                    .as_ref()
                    .ok_or("No default engine set")?
                    .clone()
            }
        };

        let engines = self
            .engines
            .read()
            .map_err(|_| "Failed to acquire read lock on engines")?;
        if let Some(engine_instance) = engines.get(&resolved_model_id) {
            Ok(engine_instance.config.clone())
        } else {
            Err(format!("Model {resolved_model_id} not found"))
        }
    }
}