mistralrs_core/lib.rs

#![deny(clippy::cast_possible_truncation, clippy::cast_precision_loss)]
use candle_core::Device;
use engine::Engine;
pub use engine::{
    get_engine_terminate_flag, reset_engine_terminate_flag, should_terminate_engine_sequences,
    EngineInstruction, SearchEmbeddingModel, ENGINE_INSTRUCTIONS, TERMINATE_ALL_NEXT_STEP,
};
use hf_hub::Cache;
pub use lora::Ordering;
pub use pipeline::ModelCategory;
pub use pipeline::Pipeline;
#[cfg(feature = "pyo3_macros")]
use pyo3::exceptions::PyValueError;
use std::collections::HashMap;
use std::sync::OnceLock;
use std::time::Instant;
use std::{
    cell::RefCell,
    error::Error,
    fs::OpenOptions,
    io::Write,
    sync::{atomic::AtomicBool, Arc, Mutex, RwLock},
    thread::{self, JoinHandle},
    time::{SystemTime, UNIX_EPOCH},
};
use tokio::sync::mpsc::{channel, Sender};
use tracing::info;
use tracing::warn;

mod cuda;
mod device_map;
mod engine;
mod lora;
mod model_loader;
mod ops;
pub use model_loader::{
    get_auto_device_map_params, get_model_dtype, get_tgt_non_granular_index, LoaderBuilder,
};
mod embedding_models;
mod kv_cache;
mod search;

mod model_selected;
pub use model_selected::ModelSelected;
pub use toml_selector::{get_toml_selected_model_device_map_params, get_toml_selected_model_dtype};

mod amoe;
mod attention;
mod diffusion_models;
pub mod distributed;
mod gguf;
pub mod layers;
mod layers_masker;
mod layers_utils;
pub mod matformer;
mod models;
mod paged_attention;
mod pipeline;
mod prefix_cacher;
mod request;
mod response;
mod sampler;
mod scheduler;
mod sequence;
mod speech_models;
mod toml_selector;
mod tools;
mod topology;
mod utils;
mod vision_models;
mod xlora_models;

pub use amoe::{AnyMoeConfig, AnyMoeExpertType};
pub use device_map::{
    DeviceLayerMapMetadata, DeviceMapMetadata, DeviceMapSetting, LayerDeviceMapper,
};
pub use gguf::{GGUFArchitecture, GGUF_MULTI_FILE_DELIMITER};
pub use mistralrs_audio::AudioInput;
pub use mistralrs_mcp::{
    CalledFunction, Function, Tool, ToolCallback, ToolCallbackWithTool, ToolType,
};
pub use mistralrs_mcp::{
    McpClient, McpClientConfig, McpServerConfig, McpServerSource, McpToolInfo,
};
pub use mistralrs_quant::{IsqType, MULTI_LORA_DELIMITER};
pub use paged_attention::{MemoryGpuConfig, PagedAttentionConfig, PagedCacheType};
pub use pipeline::{
    chat_template::ChatTemplate, parse_isq_value, AdapterPaths, AnyMoeLoader, AnyMoePipeline,
    AutoDeviceMapParams, AutoLoader, AutoLoaderBuilder, DiffusionGenerationParams, DiffusionLoader,
    DiffusionLoaderBuilder, DiffusionLoaderType, EmbeddingLoader, EmbeddingLoaderBuilder,
    EmbeddingLoaderType, EmbeddingModelPaths, EmbeddingSpecificConfig, GGMLLoader,
    GGMLLoaderBuilder, GGMLSpecificConfig, GGUFLoader, GGUFLoaderBuilder, GGUFSpecificConfig,
    GemmaLoader, Idefics2Loader, IsqOrganization, LLaVALoader, LLaVANextLoader, LlamaLoader,
    Loader, LocalModelPaths, LoraAdapterPaths, MistralLoader, MixtralLoader, Modalities, ModelKind,
    ModelPaths, MultimodalPromptPrefixer, NormalLoader, NormalLoaderBuilder, NormalLoaderType,
    NormalSpecificConfig, Phi2Loader, Phi3Loader, Phi3VLoader, Qwen2Loader, SpeculativeConfig,
    SpeculativeLoader, SpeculativePipeline, SpeechLoader, SpeechPipeline, Starcoder2Loader,
    SupportedModality, TokenSource, VisionLoader, VisionLoaderBuilder, VisionLoaderType,
    VisionSpecificConfig, UQFF_MULTI_FILE_DELIMITER,
};
pub use request::{
    ApproximateUserLocation, Constraint, DetokenizationRequest, ImageGenerationResponseFormat,
    LlguidanceGrammar, MessageContent, NormalRequest, Request, RequestMessage, SearchContextSize,
    TokenizationRequest, WebSearchOptions, WebSearchUserLocation,
};
pub use response::*;
pub use sampler::{
    CustomLogitsProcessor, DrySamplingParams, SamplingParams, StopTokens, TopLogprob,
};
pub use scheduler::{DefaultSchedulerMethod, SchedulerConfig};
pub use search::{SearchCallback, SearchFunctionParameters, SearchResult};
use serde::Serialize;
pub use speech_models::{utils as speech_utils, SpeechGenerationConfig, SpeechLoaderType};
use tokio::runtime::Runtime;
use toml_selector::{TomlLoaderArgs, TomlSelector};
pub use tools::{ToolCallResponse, ToolCallType, ToolCallbacks, ToolChoice};
pub use topology::{LayerTopology, Topology};
pub use utils::debug::initialize_logging;
pub use utils::memory_usage::MemoryUsage;
pub use utils::normal::{ModelDType, TryIntoDType};
pub use utils::{paged_attn_supported, using_flash_attn};

// re-export llguidance for easier LlguidanceGrammar construction
pub use llguidance;

/// `true` if `MISTRALRS_DEBUG=1`
pub(crate) static DEBUG: AtomicBool = AtomicBool::new(false);
pub static GLOBAL_HF_CACHE: OnceLock<Cache> = OnceLock::new();

/// Configuration for creating an engine instance
#[derive(Clone)]
pub struct EngineConfig {
    pub no_kv_cache: bool,
    pub no_prefix_cache: bool,
    pub prefix_cache_n: usize,
    pub disable_eos_stop: bool,
    pub throughput_logging_enabled: bool,
    pub search_embedding_model: Option<SearchEmbeddingModel>,
    pub search_callback: Option<Arc<SearchCallback>>,
    pub tool_callbacks: tools::ToolCallbacks,
    pub tool_callbacks_with_tools: tools::ToolCallbacksWithTools,
}

impl Default for EngineConfig {
    fn default() -> Self {
        Self {
            no_kv_cache: false,
            no_prefix_cache: false,
            prefix_cache_n: 16,
            disable_eos_stop: false,
            throughput_logging_enabled: true,
            search_embedding_model: None,
            search_callback: None,
            tool_callbacks: HashMap::new(),
            tool_callbacks_with_tools: HashMap::new(),
        }
    }
}
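
// Configuration sketch (illustration only): since `EngineConfig` implements `Default`, callers
// can override just the fields they care about with struct-update syntax. The particular values
// below are arbitrary.
//
//     let engine_config = EngineConfig {
//         no_kv_cache: true,  // run without a KV cache
//         prefix_cache_n: 32, // override the default prefix cache size of 16
//         ..Default::default()
//     };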

/// Configuration for adding a model to MistralRs
#[derive(Clone)]
pub struct AddModelConfig {
    pub engine_config: EngineConfig,
    pub mcp_client_config: Option<McpClientConfig>,
}

impl AddModelConfig {
    pub fn new(engine_config: EngineConfig) -> Self {
        Self {
            engine_config,
            mcp_client_config: None,
        }
    }

    pub fn with_mcp_config(mut self, mcp_config: McpClientConfig) -> Self {
        self.mcp_client_config = Some(mcp_config);
        self
    }
}

#[derive(Clone)]
pub struct MistralRsConfig {
    pub kind: ModelKind,
    pub device: Device,
    pub category: ModelCategory,
    pub modalities: Modalities,
    pub max_seq_len: Option<usize>,
}

/// Internal structure to hold per-engine state
struct EngineInstance {
    sender: Sender<Request>,
    engine_handler: JoinHandle<()>,
    reboot_state: RebootState,
    config: MistralRsConfig,
    category: ModelCategory,
}

/// The MistralRs struct manages one or more engines and routes requests between them.
/// It is the core multi-threaded component of mistral.rs, using `mpsc` `Sender` and
/// `Receiver` channels to dispatch each request to the appropriate engine based on its
/// model ID.
pub struct MistralRs {
    engines: RwLock<HashMap<String, EngineInstance>>,
    default_engine_id: RwLock<Option<String>>,
    log: Option<String>,
    id: String,
    creation_time: u64,
    next_request_id: Mutex<RefCell<usize>>,
}

#[derive(Clone)]
struct RebootState {
    pipeline: Arc<tokio::sync::Mutex<dyn Pipeline>>,
    method: SchedulerConfig,
    no_kv_cache: bool,
    no_prefix_cache: bool,
    prefix_cache_n: usize,
    disable_eos_stop: bool,
    throughput_logging_enabled: bool,
    search_embedding_model: Option<SearchEmbeddingModel>,
    search_callback: Option<Arc<search::SearchCallback>>,
    tool_callbacks: tools::ToolCallbacks,
    tool_callbacks_with_tools: tools::ToolCallbacksWithTools,
    mcp_client_config: Option<McpClientConfig>,
}

#[derive(Debug)]
pub enum MistralRsError {
    EnginePoisoned,
    SenderPoisoned,
}

impl std::fmt::Display for MistralRsError {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        write!(f, "{:?}", &self)
    }
}

impl std::error::Error for MistralRsError {}

#[cfg(feature = "pyo3_macros")]
impl From<MistralRsError> for pyo3::PyErr {
    fn from(value: MistralRsError) -> Self {
        PyValueError::new_err(format!("{value:?}"))
    }
}

/// The MistralRsBuilder takes the pipeline and a scheduler method and constructs
/// an Engine and a MistralRs instance. The Engine runs on a separate thread, and the MistralRs
/// instance stays on the calling thread.
pub struct MistralRsBuilder {
    pipeline: Arc<tokio::sync::Mutex<dyn Pipeline>>,
    method: SchedulerConfig,
    log: Option<String>,
    no_kv_cache: Option<bool>,
    no_prefix_cache: Option<bool>,
    prefix_cache_n: Option<usize>,
    disable_eos_stop: Option<bool>,
    throughput_logging_enabled: bool,
    search_embedding_model: Option<SearchEmbeddingModel>,
    search_callback: Option<Arc<SearchCallback>>,
    tool_callbacks: tools::ToolCallbacks,
    tool_callbacks_with_tools: tools::ToolCallbacksWithTools,
    mcp_client_config: Option<McpClientConfig>,
}

impl MistralRsBuilder {
    /// Creates a new builder with the given pipeline, scheduler method, logging flag,
    /// and optional embedding model for web search. To override the search callback,
    /// use `.with_search_callback(...)` on the builder.
    pub fn new(
        pipeline: Arc<tokio::sync::Mutex<dyn Pipeline>>,
        method: SchedulerConfig,
        throughput_logging: bool,
        search_embedding_model: Option<SearchEmbeddingModel>,
    ) -> Self {
        Self {
            pipeline,
            method,
            log: None,
            no_kv_cache: None,
            no_prefix_cache: None,
            prefix_cache_n: None,
            disable_eos_stop: None,
            throughput_logging_enabled: throughput_logging,
            search_embedding_model,
            search_callback: None,
            tool_callbacks: HashMap::new(),
            tool_callbacks_with_tools: HashMap::new(),
            mcp_client_config: None,
        }
    }
    pub fn with_log(mut self, log: String) -> Self {
        self.log = Some(log);
        self
    }
    pub fn with_opt_log(mut self, log: Option<String>) -> Self {
        self.log = log;
        self
    }
    pub fn with_no_kv_cache(mut self, no_kv_cache: bool) -> Self {
        self.no_kv_cache = Some(no_kv_cache);
        self
    }
    pub fn with_no_prefix_cache(mut self, no_prefix_cache: bool) -> Self {
        self.no_prefix_cache = Some(no_prefix_cache);
        self
    }
    pub fn with_prefix_cache_n(mut self, prefix_cache_n: usize) -> Self {
        self.prefix_cache_n = Some(prefix_cache_n);
        self
    }
    pub fn with_disable_eos_stop(mut self, disable_eos_stop: bool) -> Self {
        self.disable_eos_stop = Some(disable_eos_stop);
        self
    }

    /// Use a custom callback to gather search results.
    pub fn with_search_callback(mut self, search_callback: Arc<SearchCallback>) -> Self {
        self.search_callback = Some(search_callback);
        self
    }

    /// Register a custom callback for the specified tool name.
    pub fn with_tool_callback(
        mut self,
        name: impl Into<String>,
        tool_callback: Arc<ToolCallback>,
    ) -> Self {
        self.tool_callbacks.insert(name.into(), tool_callback);
        self
    }

    /// Register a custom callback with its associated Tool definition. The Tool will be
    /// automatically added to requests when tool callbacks are active.
    pub fn with_tool_callback_and_tool(
        mut self,
        name: impl Into<String>,
        tool_callback: Arc<ToolCallback>,
        tool: Tool,
    ) -> Self {
        let name = name.into();
        self.tool_callbacks_with_tools.insert(
            name,
            ToolCallbackWithTool {
                callback: tool_callback,
                tool,
            },
        );
        self
    }

    /// Configure MCP client to connect to external MCP servers.
    pub fn with_mcp_client(mut self, config: McpClientConfig) -> Self {
        self.mcp_client_config = Some(config);
        self
    }

    pub async fn build(self) -> Arc<MistralRs> {
        MistralRs::new(self).await
    }
}
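
// Usage sketch (illustration only): building a `MistralRs` handle from an already-constructed
// pipeline and scheduler config. The `pipeline` and `method` values are assumed to come from one
// of the loaders re-exported above; only methods defined on `MistralRsBuilder` are used here.
//
//     let mistralrs: Arc<MistralRs> = MistralRsBuilder::new(
//         pipeline,                       // Arc<tokio::sync::Mutex<dyn Pipeline>>
//         method,                         // SchedulerConfig
//         /* throughput_logging */ true,
//         /* search_embedding_model */ None,
//     )
//     .with_log("mistralrs_requests.log".to_string())
//     .with_prefix_cache_n(16)
//     .build()
//     .await;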

impl Drop for MistralRs {
    fn drop(&mut self) {
        // Terminate all engines
        if let Ok(engines) = self.engines.read() {
            for (_, engine) in engines.iter() {
                // Use try_send instead of blocking_send to avoid runtime panics
                let _ = engine.sender.try_send(Request::Terminate);
            }
        }
    }
}

impl MistralRs {
    /// Create an engine instance with the given configuration
    fn create_engine_instance(
        pipeline: Arc<tokio::sync::Mutex<dyn Pipeline>>,
        method: SchedulerConfig,
        config: EngineConfig,
        reboot_state: RebootState,
    ) -> Result<EngineInstance, String> {
        let (tx, rx) = channel(10_000);

        let pipeline_guard = pipeline.try_lock().unwrap();
        let category = pipeline_guard.category();
        let metadata = pipeline_guard.get_metadata();
        let kind = metadata.kind.clone();
        let device = pipeline_guard.device();
        let modalities = metadata.modalities.clone();
        let max_seq_len = match &category {
            ModelCategory::Diffusion | ModelCategory::Speech => None,
            _ => Some(metadata.max_seq_len),
        };
        drop(pipeline_guard);

        info!("Pipeline input modalities are {:?}", &modalities.input);
        info!("Pipeline output modalities are {:?}", &modalities.output);

        let mistralrs_config = MistralRsConfig {
            kind,
            device,
            category: category.clone(),
            modalities,
            max_seq_len,
        };

        let engine_handler = thread::spawn(move || {
            #[cfg(feature = "metal")]
            objc::rc::autoreleasepool(move || {
                let rt = Runtime::new().unwrap();
                rt.block_on(async move {
                    let engine = Engine::new(
                        rx,
                        pipeline,
                        method,
                        config.no_kv_cache,
                        config.no_prefix_cache,
                        config.prefix_cache_n,
                        config.disable_eos_stop,
                        config.throughput_logging_enabled,
                        config.search_embedding_model,
                        config.search_callback.clone(),
                        config.tool_callbacks.clone(),
                        config.tool_callbacks_with_tools.clone(),
                    )
                    .expect("Engine creation failed.");
                    Arc::new(engine).run().await;
                })
            });

            #[cfg(not(feature = "metal"))]
            {
                let rt = Runtime::new().unwrap();
                rt.block_on(async move {
                    let engine = Engine::new(
                        rx,
                        pipeline,
                        method,
                        config.no_kv_cache,
                        config.no_prefix_cache,
                        config.prefix_cache_n,
                        config.disable_eos_stop,
                        config.throughput_logging_enabled,
                        config.search_embedding_model,
                        config.search_callback.clone(),
                        config.tool_callbacks.clone(),
                        config.tool_callbacks_with_tools.clone(),
                    )
                    .expect("Engine creation failed.");
                    Arc::new(engine).run().await;
                })
            }
        });

        Ok(EngineInstance {
            sender: tx,
            engine_handler,
            reboot_state,
            config: mistralrs_config,
            category,
        })
    }

    async fn new(config: MistralRsBuilder) -> Arc<Self> {
        let MistralRsBuilder {
            pipeline,
            method,
            log,
            no_kv_cache,
            no_prefix_cache,
            prefix_cache_n,
            disable_eos_stop,
            throughput_logging_enabled,
            search_embedding_model,
            search_callback,
            tool_callbacks,
            mut tool_callbacks_with_tools,
            mcp_client_config,
        } = config;

        mistralrs_quant::cublaslt::maybe_init_cublas_lt_wrapper(
            get_mut_arcmutex!(pipeline).device(),
        );

        let no_kv_cache = no_kv_cache.unwrap_or(false);
        let no_prefix_cache = no_prefix_cache.unwrap_or(false);
        let prefix_cache_n = prefix_cache_n.unwrap_or(16);
        let disable_eos_stop = disable_eos_stop.unwrap_or(false);

        // Initialize MCP client if configured
        if let Some(config) = &mcp_client_config {
            let mut mcp_client = McpClient::new(config.clone());
            let total_servers = config.servers.len();

            match mcp_client.initialize().await {
                Ok(()) => {
                    let mcp_callbacks_with_tools = mcp_client.get_tool_callbacks_with_tools();
                    let tools_count = mcp_callbacks_with_tools.len();

                    // Merge MCP tool callbacks with tools into the new collection
                    for (name, callback_with_tool) in mcp_callbacks_with_tools {
                        tool_callbacks_with_tools.insert(name.clone(), callback_with_tool.clone());
                    }

                    if tools_count == 0 {
                        warn!(
                            "MCP client initialized but no tools were registered from {} servers",
                            total_servers
                        );
                    } else {
                        info!(
                            "MCP client initialized successfully with {} tools from {} servers",
                            tools_count, total_servers
                        );
                    }
                }
                Err(e) => {
                    warn!(
                        "Failed to initialize MCP client with {} configured servers: {}",
                        total_servers, e
                    );
                    warn!("Continuing without MCP functionality. Check your MCP configuration and server availability.");
                }
            }
        }

        let reboot_state = RebootState {
            pipeline: pipeline.clone(),
            method: method.clone(),
            no_kv_cache,
            no_prefix_cache,
            prefix_cache_n,
            disable_eos_stop,
            throughput_logging_enabled,
            search_embedding_model,
            search_callback: search_callback.clone(),
            tool_callbacks: tool_callbacks.clone(),
            tool_callbacks_with_tools: tool_callbacks_with_tools.clone(),
            mcp_client_config: mcp_client_config.clone(),
        };

        // Create the engine configuration
        let engine_config = EngineConfig {
            no_kv_cache,
            no_prefix_cache,
            prefix_cache_n,
            disable_eos_stop,
            throughput_logging_enabled,
            search_embedding_model,
            search_callback,
            tool_callbacks,
            tool_callbacks_with_tools,
        };

        // Create the engine instance
        let engine_instance =
            Self::create_engine_instance(pipeline.clone(), method, engine_config, reboot_state)
                .expect("Failed to create engine instance");

        let id = pipeline.try_lock().unwrap().name();

        if distributed::is_daemon() {
            let request_sender = engine_instance.sender.clone();

            if cfg!(feature = "ring") {
                // Ring daemon replicator
                distributed::ring_daemon_replicator(request_sender);
            } else {
                // NCCL daemon replicator
                distributed::nccl_daemon_replicator(request_sender);
            }

            #[allow(clippy::empty_loop)]
            loop {}
        }

        // Determine if the current runtime is multi-threaded, as blocking operations are not allowed in single-threaded mode
        let is_multi_threaded = tokio::runtime::Handle::try_current()
            .is_ok_and(|h| h.runtime_flavor() != tokio::runtime::RuntimeFlavor::CurrentThread);

        // Do a dummy run
        if !distributed::is_daemon()
            && is_multi_threaded
            && matches!(
                engine_instance.category,
                ModelCategory::Text | ModelCategory::Vision { .. }
            )
        {
            let clone_sender = engine_instance.sender.clone();
            tokio::task::block_in_place(|| {
                let (tx, mut rx) = channel(1);
                let req = Request::Normal(Box::new(NormalRequest {
                    id: 0,
                    messages: RequestMessage::Completion {
                        text: "hello".to_string(),
                        echo_prompt: false,
                        best_of: None,
                    },
                    sampling_params: SamplingParams {
                        max_len: Some(1),
                        ..SamplingParams::deterministic()
                    },
                    response: tx,
                    return_logprobs: false,
                    is_streaming: false,
                    constraint: Constraint::None,
                    suffix: None,
                    tool_choice: None,
                    tools: None,
                    logits_processors: None,
                    return_raw_logits: false,
                    web_search_options: None,
                    model_id: None,
                    truncate_sequence: false,
                }));
                info!("Beginning dummy run.");
                let start = Instant::now();
                clone_sender.blocking_send(req).unwrap();

                // Drain all responses from the channel until it's closed
                let mut received_any = false;
                while let Some(_resp) = rx.blocking_recv() {
                    received_any = true;
                }

                if received_any {
                    let end = Instant::now();
                    info!(
                        "Dummy run completed in {}s.",
                        end.duration_since(start).as_secs_f64()
                    );
                } else {
                    warn!("Dummy run failed!");
                }
            });
        }

        // Create engines map with the first engine
        let mut engines = HashMap::new();
        engines.insert(id.clone(), engine_instance);

        Arc::new(Self {
            engines: RwLock::new(engines),
            default_engine_id: RwLock::new(Some(id.clone())),
            log,
            id,
            creation_time: SystemTime::now()
                .duration_since(UNIX_EPOCH)
                .expect("Time travel has occurred!")
                .as_secs(),
            next_request_id: Mutex::new(RefCell::new(1)),
        })
    }

    /// Attempts to reboot a specific engine by model_id
    fn reboot_engine(&self, model_id: &str) -> Result<(), MistralRsError> {
        let mut engines = self.engines.write().map_err(|_| {
            tracing::warn!("Couldn't get write lock on engines during reboot attempt");
            MistralRsError::EnginePoisoned
        })?;

        if let Some(engine_instance) = engines.get(model_id) {
            if !engine_instance.engine_handler.is_finished() {
                tracing::info!("Engine {} already running, returning ok", model_id);
                return Ok(());
            }

            let reboot_state = engine_instance.reboot_state.clone();
            let engine_config = EngineConfig {
                no_kv_cache: reboot_state.no_kv_cache,
                no_prefix_cache: reboot_state.no_prefix_cache,
                prefix_cache_n: reboot_state.prefix_cache_n,
                disable_eos_stop: reboot_state.disable_eos_stop,
                throughput_logging_enabled: reboot_state.throughput_logging_enabled,
                search_embedding_model: reboot_state.search_embedding_model,
                search_callback: reboot_state.search_callback.clone(),
                tool_callbacks: reboot_state.tool_callbacks.clone(),
                tool_callbacks_with_tools: reboot_state.tool_callbacks_with_tools.clone(),
            };
            let new_engine_instance = Self::create_engine_instance(
                reboot_state.pipeline.clone(),
                reboot_state.method.clone(),
                engine_config,
                reboot_state,
            )
            .map_err(|e| {
                tracing::error!("Failed to create new engine instance: {}", e);
                MistralRsError::EnginePoisoned
            })?;

            engines.insert(model_id.to_string(), new_engine_instance);
            tracing::info!("Successfully rebooted engine {}", model_id);
            Ok(())
        } else {
            Err(MistralRsError::EnginePoisoned)
        }
    }

    fn engine_dead(&self, model_id: &str) -> Result<bool, MistralRsError> {
        let engines = self.engines.read().map_err(|_| {
            tracing::warn!("Couldn't get read lock on engines!");
            MistralRsError::EnginePoisoned
        })?;

        if let Some(engine_instance) = engines.get(model_id) {
            Ok(engine_instance.engine_handler.is_finished())
        } else {
            Err(MistralRsError::EnginePoisoned)
        }
    }

    /// Get sender for a specific model. If model_id is None, uses default engine.
    pub fn get_sender(&self, model_id: Option<&str>) -> Result<Sender<Request>, MistralRsError> {
        let resolved_model_id = match model_id {
            Some(id) => id.to_string(),
            None => {
                let default_lock = self
                    .default_engine_id
                    .read()
                    .map_err(|_| MistralRsError::SenderPoisoned)?;
                default_lock
                    .as_ref()
                    .ok_or(MistralRsError::EnginePoisoned)?
                    .clone()
            }
        };

        if self.engine_dead(&resolved_model_id)? {
            tracing::warn!("Engine {} is dead, rebooting", resolved_model_id);
            self.reboot_engine(&resolved_model_id)?
        }

        let engines = self
            .engines
            .read()
            .map_err(|_| MistralRsError::SenderPoisoned)?;
        if let Some(engine_instance) = engines.get(&resolved_model_id) {
            Ok(engine_instance.sender.clone())
        } else {
            Err(MistralRsError::EnginePoisoned)
        }
    }
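
    // Usage sketch for `get_sender` (illustration only): build a one-shot completion request
    // mirroring the dummy run in `MistralRs::new` above, then push it onto the resolved engine's
    // channel. Assumes a blocking-capable (multi-threaded) runtime, since `blocking_send` and
    // `blocking_recv` panic on a current-thread runtime.
    //
    //     let sender = mistralrs.get_sender(None)?; // None -> default engine
    //     let (tx, mut rx) = channel(1);
    //     let req = Request::Normal(Box::new(NormalRequest {
    //         id: mistralrs.next_request_id(),
    //         messages: RequestMessage::Completion {
    //             text: "Tell me a joke.".to_string(),
    //             echo_prompt: false,
    //             best_of: None,
    //         },
    //         sampling_params: SamplingParams::deterministic(),
    //         response: tx,
    //         return_logprobs: false,
    //         is_streaming: false,
    //         constraint: Constraint::None,
    //         suffix: None,
    //         tool_choice: None,
    //         tools: None,
    //         logits_processors: None,
    //         return_raw_logits: false,
    //         web_search_options: None,
    //         model_id: None,
    //         truncate_sequence: false,
    //     }));
    //     sender.blocking_send(req).map_err(|_| MistralRsError::SenderPoisoned)?;
    //     while let Some(response) = rx.blocking_recv() {
    //         // Handle each response as it arrives.
    //     }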

    pub fn get_id(&self) -> String {
        self.id.clone()
    }

    pub fn get_creation_time(&self) -> u64 {
        self.creation_time
    }

    /// Get model category for a specific model. If model_id is None, uses default engine.
    pub fn get_model_category(
        &self,
        model_id: Option<&str>,
    ) -> Result<ModelCategory, MistralRsError> {
        let resolved_model_id = match model_id {
            Some(id) => id.to_string(),
            None => {
                let default_lock = self
                    .default_engine_id
                    .read()
                    .map_err(|_| MistralRsError::SenderPoisoned)?;
                default_lock
                    .as_ref()
                    .ok_or(MistralRsError::EnginePoisoned)?
                    .clone()
            }
        };

        let engines = self
            .engines
            .read()
            .map_err(|_| MistralRsError::SenderPoisoned)?;
        if let Some(engine_instance) = engines.get(&resolved_model_id) {
            Ok(engine_instance.category.clone())
        } else {
            Err(MistralRsError::EnginePoisoned)
        }
    }

    /// Get the maximum supported sequence length for a model, if applicable.
    pub fn max_sequence_length(
        &self,
        model_id: Option<&str>,
    ) -> Result<Option<usize>, MistralRsError> {
        let resolved_model_id = match model_id {
            Some(id) => id.to_string(),
            None => {
                let default_lock = self
                    .default_engine_id
                    .read()
                    .map_err(|_| MistralRsError::SenderPoisoned)?;
                default_lock
                    .as_ref()
                    .ok_or(MistralRsError::EnginePoisoned)?
                    .clone()
            }
        };

        let engines = self
            .engines
            .read()
            .map_err(|_| MistralRsError::SenderPoisoned)?;
        if let Some(engine_instance) = engines.get(&resolved_model_id) {
            Ok(engine_instance.config.max_seq_len)
        } else {
            Err(MistralRsError::EnginePoisoned)
        }
    }

    pub fn next_request_id(&self) -> usize {
        let l = self.next_request_id.lock().unwrap();
        let last = &mut *l.borrow_mut();
        let last_v = *last;
        *last += 1;
        last_v
    }

    /// Add a new model engine to the MistralRs instance
    pub async fn add_model(
        &self,
        model_id: String,
        pipeline: Arc<tokio::sync::Mutex<dyn Pipeline>>,
        method: SchedulerConfig,
        config: AddModelConfig,
    ) -> Result<(), String> {
        let reboot_state = RebootState {
            pipeline: pipeline.clone(),
            method: method.clone(),
            no_kv_cache: config.engine_config.no_kv_cache,
            no_prefix_cache: config.engine_config.no_prefix_cache,
            prefix_cache_n: config.engine_config.prefix_cache_n,
            disable_eos_stop: config.engine_config.disable_eos_stop,
            throughput_logging_enabled: config.engine_config.throughput_logging_enabled,
            search_embedding_model: config.engine_config.search_embedding_model,
            search_callback: config.engine_config.search_callback.clone(),
            tool_callbacks: config.engine_config.tool_callbacks.clone(),
            tool_callbacks_with_tools: config.engine_config.tool_callbacks_with_tools.clone(),
            mcp_client_config: config.mcp_client_config.clone(),
        };

        let engine_instance =
            Self::create_engine_instance(pipeline, method, config.engine_config, reboot_state)?;

        let mut engines = self
            .engines
            .write()
            .map_err(|_| "Failed to acquire write lock on engines")?;
        engines.insert(model_id.clone(), engine_instance);

        // If this is the first model, set it as default
        if engines.len() == 1 {
            let mut default_lock = self
                .default_engine_id
                .write()
                .map_err(|_| "Failed to acquire write lock on default_engine_id")?;
            *default_lock = Some(model_id.clone());
        }

        Ok(())
    }
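
    // Multi-model sketch (illustration only): registering a second engine on an existing
    // `MistralRs` instance and switching the default. The second `pipeline2`/`method2` pair and
    // the `mcp_config` value are assumed to have been constructed elsewhere.
    //
    //     let add_config = AddModelConfig::new(EngineConfig::default())
    //         .with_mcp_config(mcp_config); // optional
    //     mistralrs
    //         .add_model("my-second-model".to_string(), pipeline2, method2, add_config)
    //         .await?;
    //     mistralrs.set_default_model_id("my-second-model")?;
    //     let models: Vec<String> = mistralrs.list_models()?;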

    /// Remove a model engine from the MistralRs instance
    pub fn remove_model(&self, model_id: &str) -> Result<(), String> {
        let mut engines = self
            .engines
            .write()
            .map_err(|_| "Failed to acquire write lock on engines")?;

        if engines.len() <= 1 {
            return Err("Cannot remove the last model from MistralRs".to_string());
        }

        if let Some(engine_instance) = engines.remove(model_id) {
            // Send terminate signal to the engine
            let _ = engine_instance.sender.blocking_send(Request::Terminate);

            // If this was the default engine, set a new default
            let mut default_lock = self
                .default_engine_id
                .write()
                .map_err(|_| "Failed to acquire write lock on default_engine_id")?;
            if let Some(ref default_id) = *default_lock {
                if default_id == model_id {
                    // Set the first available engine as the new default
                    *default_lock = engines.keys().next().cloned();
                }
            }

            Ok(())
        } else {
            Err(format!("Model {model_id} not found"))
        }
    }

    /// List all available model IDs
    pub fn list_models(&self) -> Result<Vec<String>, String> {
        let engines = self
            .engines
            .read()
            .map_err(|_| "Failed to acquire read lock on engines")?;
        Ok(engines.keys().cloned().collect())
    }

    /// Get the current default model ID
    pub fn get_default_model_id(&self) -> Result<Option<String>, String> {
        let default_lock = self
            .default_engine_id
            .read()
            .map_err(|_| "Failed to acquire read lock on default_engine_id")?;
        Ok(default_lock.clone())
    }

    /// Set the default model ID
    pub fn set_default_model_id(&self, model_id: &str) -> Result<(), String> {
        let engines = self
            .engines
            .read()
            .map_err(|_| "Failed to acquire read lock on engines")?;
        if !engines.contains_key(model_id) {
            return Err(format!("Model {model_id} not found"));
        }
        drop(engines);

        let mut default_lock = self
            .default_engine_id
            .write()
            .map_err(|_| "Failed to acquire write lock on default_engine_id")?;
        *default_lock = Some(model_id.to_string());

        Ok(())
    }

    /// Dispatch a request to the appropriate engine based on the model_id in the request
    pub fn send_request(&self, mut request: Request) -> Result<(), MistralRsError> {
        let model_id = match &mut request {
            Request::Normal(normal_req) => normal_req.model_id.as_deref(),
            _ => None, // Other request types don't specify model_id
        };

        let sender = self.get_sender(model_id)?;
        sender
            .blocking_send(request)
            .map_err(|_| MistralRsError::SenderPoisoned)
    }

    pub fn maybe_log_request(this: Arc<Self>, repr: String) {
        if let Some(file) = &this.log {
            let mut f = OpenOptions::new()
                .append(true)
                .create(true) // Optionally create the file if it doesn't already exist
                .open(file)
                .expect("Unable to open file");
            let time = chrono::offset::Local::now();
            f.write_all(format!("Request at {time}: {repr}\n\n").as_bytes())
                .expect("Unable to write data");
        }
    }

    pub fn maybe_log_response<T: Serialize>(this: Arc<Self>, resp: &T) {
        if let Some(file) = &this.log {
            let mut f = OpenOptions::new()
                .append(true)
                .create(true) // Optionally create the file if it doesn't already exist
                .open(file)
                .expect("Unable to open file");
            let time = chrono::offset::Local::now();
            let repr = serde_json::to_string(resp).expect("Serialization of response failed.");
            f.write_all(format!("Response at {time}: {repr}\n\n").as_bytes())
                .expect("Unable to write data");
        }
    }

    pub fn maybe_log_error(this: Arc<Self>, err: &dyn Error) {
        if let Some(file) = &this.log {
            let mut f = OpenOptions::new()
                .append(true)
                .create(true) // Optionally create the file if it doesn't already exist
                .open(file)
                .expect("Unable to open file");
            let time = chrono::offset::Local::now();
            f.write_all(format!("Error response at {time}: {err}\n\n").as_bytes())
                .expect("Unable to write data");
        }
    }
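
    // Logging sketch (illustration only): the helpers above are no-ops unless a log file was set
    // via `MistralRsBuilder::with_log`. Here `request_repr` is assumed to be any `String`
    // describing the request, `response` any value implementing `serde::Serialize`, and `error`
    // any `&dyn std::error::Error`.
    //
    //     MistralRs::maybe_log_request(mistralrs.clone(), request_repr);
    //     MistralRs::maybe_log_response(mistralrs.clone(), &response);
    //     MistralRs::maybe_log_error(mistralrs.clone(), error);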

    /// Get the number of tools available for a specific model (including MCP tools)
    pub fn get_tools_count(&self, model_id: Option<&str>) -> Result<usize, String> {
        let resolved_model_id = match model_id {
            Some(id) => id.to_string(),
            None => {
                let default_lock = self
                    .default_engine_id
                    .read()
                    .map_err(|_| "Failed to acquire read lock")?;
                default_lock
                    .as_ref()
                    .ok_or("No default engine set")?
                    .clone()
            }
        };

        let engines = self
            .engines
            .read()
            .map_err(|_| "Failed to acquire read lock on engines")?;
        if let Some(engine_instance) = engines.get(&resolved_model_id) {
            Ok(engine_instance.reboot_state.tool_callbacks_with_tools.len())
        } else {
            Err(format!("Model {resolved_model_id} not found"))
        }
    }

    /// Check if MCP client is configured for a specific model
    pub fn has_mcp_client(&self, model_id: Option<&str>) -> Result<bool, String> {
        let resolved_model_id = match model_id {
            Some(id) => id.to_string(),
            None => {
                let default_lock = self
                    .default_engine_id
                    .read()
                    .map_err(|_| "Failed to acquire read lock")?;
                default_lock
                    .as_ref()
                    .ok_or("No default engine set")?
                    .clone()
            }
        };

        let engines = self
            .engines
            .read()
            .map_err(|_| "Failed to acquire read lock on engines")?;
        if let Some(engine_instance) = engines.get(&resolved_model_id) {
            Ok(engine_instance.reboot_state.mcp_client_config.is_some())
        } else {
            Err(format!("Model {resolved_model_id} not found"))
        }
    }

    /// Get config for a specific model
    pub fn config(&self, model_id: Option<&str>) -> Result<MistralRsConfig, String> {
        let resolved_model_id = match model_id {
            Some(id) => id.to_string(),
            None => {
                let default_lock = self
                    .default_engine_id
                    .read()
                    .map_err(|_| "Failed to acquire read lock")?;
                default_lock
                    .as_ref()
                    .ok_or("No default engine set")?
                    .clone()
            }
        };

        let engines = self
            .engines
            .read()
            .map_err(|_| "Failed to acquire read lock on engines")?;
        if let Some(engine_instance) = engines.get(&resolved_model_id) {
            Ok(engine_instance.config.clone())
        } else {
            Err(format!("Model {resolved_model_id} not found"))
        }
    }
}