mistralrs_core/
lib.rs

1#![deny(clippy::cast_possible_truncation, clippy::cast_precision_loss)]
2use candle_core::Device;
3use engine::Engine;
4pub use engine::{
5    get_engine_terminate_flag, reset_engine_terminate_flag, should_terminate_engine_sequences,
6    EngineInstruction, SearchEmbeddingModel, ENGINE_INSTRUCTIONS, TERMINATE_ALL_NEXT_STEP,
7};
8use hf_hub::Cache;
9pub use lora::Ordering;
10pub use pipeline::ModelCategory;
11pub use pipeline::Pipeline;
12#[cfg(feature = "pyo3_macros")]
13use pyo3::exceptions::PyValueError;
14use std::collections::HashMap;
15use std::num::NonZeroUsize;
16use std::sync::OnceLock;
17use std::time::Instant;
18use std::{
19    cell::RefCell,
20    error::Error,
21    fs::OpenOptions,
22    io::Write,
23    sync::{atomic::AtomicBool, Arc, Mutex, RwLock},
24    thread::{self, JoinHandle},
25    time::{SystemTime, UNIX_EPOCH},
26};
27use tokio::sync::mpsc::{channel, Sender};
28use tracing::info;
29use tracing::warn;
30
31mod cuda;
32mod device_map;
33mod engine;
34mod lora;
35mod model_loader;
36mod moe;
37mod ops;
38pub use model_loader::{
39    get_auto_device_map_params, get_model_dtype, get_tgt_non_granular_index, LoaderBuilder,
40};
41mod embedding_models;
42mod kv_cache;
43mod search;
44
45mod model_selected;
46pub use model_selected::ModelSelected;
47pub use toml_selector::{get_toml_selected_model_device_map_params, get_toml_selected_model_dtype};
48
49mod amoe;
50mod attention;
51mod diffusion_models;
52pub mod distributed;
53mod gguf;
54pub mod harmony;
55pub mod layers;
56mod layers_masker;
57mod layers_utils;
58pub mod matformer;
59mod models;
60mod paged_attention;
61mod pipeline;
62mod prefix_cacher;
63mod request;
64mod response;
65mod sampler;
66mod scheduler;
67mod sequence;
68mod speech_models;
69mod toml_selector;
70mod tools;
71mod topology;
72mod utils;
73mod vision_models;
74mod xlora_models;
75
76pub use amoe::{AnyMoeConfig, AnyMoeExpertType};
77pub use device_map::{
78    DeviceLayerMapMetadata, DeviceMapMetadata, DeviceMapSetting, LayerDeviceMapper,
79};
80pub use gguf::{GGUFArchitecture, GGUF_MULTI_FILE_DELIMITER};
81pub use mistralrs_audio::AudioInput;
82pub use mistralrs_mcp::{
83    CalledFunction, Function, Tool, ToolCallback, ToolCallbackWithTool, ToolType,
84};
85pub use mistralrs_mcp::{
86    McpClient, McpClientConfig, McpServerConfig, McpServerSource, McpToolInfo,
87};
88pub use mistralrs_quant::{IsqType, MULTI_LORA_DELIMITER};
89pub use paged_attention::{MemoryGpuConfig, PagedAttentionConfig, PagedCacheType};
90pub use pipeline::{
91    chat_template::ChatTemplate, parse_isq_value, AdapterPaths, AnyMoeLoader, AnyMoePipeline,
92    AutoDeviceMapParams, AutoLoader, AutoLoaderBuilder, DiffusionGenerationParams, DiffusionLoader,
93    DiffusionLoaderBuilder, DiffusionLoaderType, EmbeddingLoader, EmbeddingLoaderBuilder,
94    EmbeddingLoaderType, EmbeddingModelPaths, EmbeddingSpecificConfig, GGMLLoader,
95    GGMLLoaderBuilder, GGMLSpecificConfig, GGUFLoader, GGUFLoaderBuilder, GGUFSpecificConfig,
96    GemmaLoader, Idefics2Loader, IsqOrganization, LLaVALoader, LLaVANextLoader, LlamaLoader,
97    Loader, LocalModelPaths, LoraAdapterPaths, MistralLoader, MixtralLoader, Modalities, ModelKind,
98    ModelPaths, MultimodalPromptPrefixer, NormalLoader, NormalLoaderBuilder, NormalLoaderType,
99    NormalSpecificConfig, Phi2Loader, Phi3Loader, Phi3VLoader, Qwen2Loader, SpeculativeConfig,
100    SpeculativeLoader, SpeculativePipeline, SpeechLoader, SpeechPipeline, Starcoder2Loader,
101    SupportedModality, TokenSource, VisionLoader, VisionLoaderBuilder, VisionLoaderType,
102    VisionSpecificConfig, UQFF_MULTI_FILE_DELIMITER,
103};
104pub use request::{
105    ApproximateUserLocation, Constraint, DetokenizationRequest, ImageGenerationResponseFormat,
106    LlguidanceGrammar, MessageContent, NormalRequest, ReasoningEffort, Request, RequestMessage,
107    SearchContextSize, TokenizationRequest, WebSearchOptions, WebSearchUserLocation,
108};
109pub use response::*;
110pub use sampler::{
111    CustomLogitsProcessor, DrySamplingParams, SamplingParams, StopTokens, TopLogprob,
112};
113pub use scheduler::{DefaultSchedulerMethod, SchedulerConfig};
114pub use search::{SearchCallback, SearchFunctionParameters, SearchResult};
115use serde::Serialize;
116pub use speech_models::{utils as speech_utils, SpeechGenerationConfig, SpeechLoaderType};
117use tokio::runtime::Runtime;
118use toml_selector::{TomlLoaderArgs, TomlSelector};
119pub use tools::{ToolCallResponse, ToolCallType, ToolCallbacks, ToolChoice};
120pub use topology::{LayerTopology, Topology};
121pub use utils::debug::initialize_logging;
122pub use utils::memory_usage::MemoryUsage;
123pub use utils::normal::{ModelDType, TryIntoDType};
124pub use utils::{paged_attn_supported, using_flash_attn};
125
126// re-export llguidance for easier LlguidanceGrammar construction
127pub use llguidance;
128
/// `true` if `MISTRALRS_DEBUG=1`
///
/// Crate-wide debug flag; it starts out `false`.
/// NOTE(review): the code that sets it is not visible here — presumably logging setup; confirm.
pub(crate) static DEBUG: AtomicBool = AtomicBool::new(false);
/// Process-wide Hugging Face hub cache handle. Being a `OnceLock`, it can be initialised at most
/// once and is empty until then.
pub static GLOBAL_HF_CACHE: OnceLock<Cache> = OnceLock::new();
132
133/// Configuration for creating an engine instance
134#[derive(Clone)]
135pub struct EngineConfig {
136    pub no_kv_cache: bool,
137    pub no_prefix_cache: bool,
138    pub prefix_cache_n: usize,
139    pub disable_eos_stop: bool,
140    pub throughput_logging_enabled: bool,
141    pub search_embedding_model: Option<SearchEmbeddingModel>,
142    pub search_callback: Option<Arc<SearchCallback>>,
143    pub tool_callbacks: tools::ToolCallbacks,
144    pub tool_callbacks_with_tools: tools::ToolCallbacksWithTools,
145}
146
147impl Default for EngineConfig {
148    fn default() -> Self {
149        Self {
150            no_kv_cache: false,
151            no_prefix_cache: false,
152            prefix_cache_n: 16,
153            disable_eos_stop: false,
154            throughput_logging_enabled: true,
155            search_embedding_model: None,
156            search_callback: None,
157            tool_callbacks: HashMap::new(),
158            tool_callbacks_with_tools: HashMap::new(),
159        }
160    }
161}
162
163/// Configuration for adding a model to MistralRs
164#[derive(Clone)]
165pub struct AddModelConfig {
166    pub engine_config: EngineConfig,
167    pub mcp_client_config: Option<McpClientConfig>,
168}
169
170impl AddModelConfig {
171    pub fn new(engine_config: EngineConfig) -> Self {
172        Self {
173            engine_config,
174            mcp_client_config: None,
175        }
176    }
177
178    pub fn with_mcp_config(mut self, mcp_config: McpClientConfig) -> Self {
179        self.mcp_client_config = Some(mcp_config);
180        self
181    }
182}
183
184#[derive(Clone)]
185pub struct MistralRsConfig {
186    pub kind: ModelKind,
187    pub device: Device,
188    pub category: ModelCategory,
189    pub modalities: Modalities,
190    pub max_seq_len: Option<usize>,
191}
192
193/// Internal structure to hold per-engine state
194struct EngineInstance {
195    sender: Sender<Request>,
196    engine_handler: JoinHandle<()>,
197    reboot_state: RebootState,
198    config: MistralRsConfig,
199    category: ModelCategory,
200}
201
202/// The MistralRs struct handles sending requests to multiple engines.
203/// It is the core multi-threaded component of mistral.rs, and uses `mpsc`
204/// `Sender` and `Receiver` primitives to send and receive requests to the
205/// appropriate engine based on model ID.
206pub struct MistralRs {
207    engines: RwLock<HashMap<String, EngineInstance>>,
208    default_engine_id: RwLock<Option<String>>,
209    log: Option<String>,
210    id: String,
211    creation_time: u64,
212    next_request_id: Mutex<RefCell<usize>>,
213}
214
215#[derive(Clone)]
216struct RebootState {
217    pipeline: Arc<tokio::sync::Mutex<dyn Pipeline>>,
218    method: SchedulerConfig,
219    no_kv_cache: bool,
220    no_prefix_cache: bool,
221    prefix_cache_n: usize,
222    disable_eos_stop: bool,
223    throughput_logging_enabled: bool,
224    search_embedding_model: Option<SearchEmbeddingModel>,
225    search_callback: Option<Arc<search::SearchCallback>>,
226    tool_callbacks: tools::ToolCallbacks,
227    tool_callbacks_with_tools: tools::ToolCallbacksWithTools,
228    mcp_client_config: Option<McpClientConfig>,
229}
230
/// Errors raised while routing requests to an engine.
#[derive(Debug)]
pub enum MistralRsError {
    /// An engine lock was poisoned, no engine matches the requested (or default) model id,
    /// or an engine could not be restarted.
    EnginePoisoned,
    /// A lock was poisoned while resolving the default model id or reading engine data.
    SenderPoisoned,
}

impl std::fmt::Display for MistralRsError {
    /// Writes the bare variant name, exactly as the derived `Debug` output does.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let variant_name = match self {
            Self::EnginePoisoned => "EnginePoisoned",
            Self::SenderPoisoned => "SenderPoisoned",
        };
        f.write_str(variant_name)
    }
}

impl Error for MistralRsError {}
244
/// Surfaces routing errors to Python as a `ValueError` whose message is the variant's `Debug` text.
#[cfg(feature = "pyo3_macros")]
impl From<MistralRsError> for pyo3::PyErr {
    fn from(err: MistralRsError) -> Self {
        let message = format!("{:?}", err);
        PyValueError::new_err(message)
    }
}
251
252/// The MistralRsBuilder takes the pipeline and a scheduler method and constructs
253/// an Engine and a MistralRs instance. The Engine runs on a separate thread, and the MistralRs
254/// instance stays on the calling thread.
255pub struct MistralRsBuilder {
256    pipeline: Arc<tokio::sync::Mutex<dyn Pipeline>>,
257    method: SchedulerConfig,
258    log: Option<String>,
259    no_kv_cache: Option<bool>,
260    no_prefix_cache: Option<bool>,
261    prefix_cache_n: Option<usize>,
262    disable_eos_stop: Option<bool>,
263    throughput_logging_enabled: bool,
264    search_embedding_model: Option<SearchEmbeddingModel>,
265    search_callback: Option<Arc<SearchCallback>>,
266    tool_callbacks: tools::ToolCallbacks,
267    tool_callbacks_with_tools: tools::ToolCallbacksWithTools,
268    mcp_client_config: Option<McpClientConfig>,
269}
270
271impl MistralRsBuilder {
272    /// Creates a new builder with the given pipeline, scheduler method, logging flag,
273    /// and optional embedding model for web search. To override the search callback,
274    /// use `.with_search_callback(...)` on the builder.
275    pub fn new(
276        pipeline: Arc<tokio::sync::Mutex<dyn Pipeline>>,
277        method: SchedulerConfig,
278        throughput_logging: bool,
279        search_embedding_model: Option<SearchEmbeddingModel>,
280    ) -> Self {
281        Self {
282            pipeline,
283            method,
284            log: None,
285            no_kv_cache: None,
286            no_prefix_cache: None,
287            prefix_cache_n: None,
288            disable_eos_stop: None,
289            throughput_logging_enabled: throughput_logging,
290            search_embedding_model,
291            search_callback: None,
292            tool_callbacks: HashMap::new(),
293            tool_callbacks_with_tools: HashMap::new(),
294            mcp_client_config: None,
295        }
296    }
297    pub fn with_log(mut self, log: String) -> Self {
298        self.log = Some(log);
299        self
300    }
301    pub fn with_opt_log(mut self, log: Option<String>) -> Self {
302        self.log = log;
303        self
304    }
305    pub fn with_no_kv_cache(mut self, no_kv_cache: bool) -> Self {
306        self.no_kv_cache = Some(no_kv_cache);
307        self
308    }
309    pub fn with_no_prefix_cache(mut self, no_prefix_cache: bool) -> Self {
310        self.no_prefix_cache = Some(no_prefix_cache);
311        self
312    }
313    pub fn with_prefix_cache_n(mut self, prefix_cache_n: usize) -> Self {
314        self.prefix_cache_n = Some(prefix_cache_n);
315        self
316    }
317    pub fn with_disable_eos_stop(mut self, disable_eos_stop: bool) -> Self {
318        self.disable_eos_stop = Some(disable_eos_stop);
319        self
320    }
321
322    /// Use a custom callback to gather search results.
323    pub fn with_search_callback(mut self, search_callback: Arc<SearchCallback>) -> Self {
324        self.search_callback = Some(search_callback);
325        self
326    }
327
328    /// Register a custom callback for the specified tool name.
329    pub fn with_tool_callback(
330        mut self,
331        name: impl Into<String>,
332        tool_callback: Arc<ToolCallback>,
333    ) -> Self {
334        self.tool_callbacks.insert(name.into(), tool_callback);
335        self
336    }
337
338    /// Register a custom callback with its associated Tool definition. The Tool will be
339    /// automatically added to requests when tool callbacks are active.
340    pub fn with_tool_callback_and_tool(
341        mut self,
342        name: impl Into<String>,
343        tool_callback: Arc<ToolCallback>,
344        tool: Tool,
345    ) -> Self {
346        let name = name.into();
347        self.tool_callbacks_with_tools.insert(
348            name,
349            ToolCallbackWithTool {
350                callback: tool_callback,
351                tool,
352            },
353        );
354        self
355    }
356
357    /// Configure MCP client to connect to external MCP servers.
358    pub fn with_mcp_client(mut self, config: McpClientConfig) -> Self {
359        self.mcp_client_config = Some(config);
360        self
361    }
362
363    pub async fn build(self) -> Arc<MistralRs> {
364        MistralRs::new(self).await
365    }
366}
367
368impl Drop for MistralRs {
369    fn drop(&mut self) {
370        // Terminate all engines
371        if let Ok(engines) = self.engines.read() {
372            for (_, engine) in engines.iter() {
373                // Use try_send instead of blocking_send to avoid runtime panics
374                let _ = engine.sender.try_send(Request::Terminate);
375            }
376        }
377    }
378}
379
380impl MistralRs {
381    /// Create an engine instance with the given configuration
382    fn create_engine_instance(
383        pipeline: Arc<tokio::sync::Mutex<dyn Pipeline>>,
384        method: SchedulerConfig,
385        config: EngineConfig,
386        reboot_state: RebootState,
387    ) -> Result<EngineInstance, String> {
388        let (tx, rx) = channel(10_000);
389
390        let pipeline_guard = pipeline.try_lock().unwrap();
391        let category = pipeline_guard.category();
392        let metadata = pipeline_guard.get_metadata();
393        let kind = metadata.kind.clone();
394        let device = pipeline_guard.device();
395        let modalities = metadata.modalities.clone();
396        let max_seq_len = match &category {
397            ModelCategory::Diffusion | ModelCategory::Speech => None,
398            _ => Some(metadata.max_seq_len),
399        };
400        drop(pipeline_guard);
401
402        info!("Pipeline input modalities are {:?}", &modalities.input);
403        info!("Pipeline output modalities are {:?}", &modalities.output);
404
405        let mistralrs_config = MistralRsConfig {
406            kind,
407            device,
408            category: category.clone(),
409            modalities,
410            max_seq_len,
411        };
412
413        let tx_for_engine = tx.clone();
414        let engine_handler = thread::spawn(move || {
415            #[cfg(feature = "metal")]
416            objc::rc::autoreleasepool(move || {
417                let rt = Runtime::new().unwrap();
418                rt.block_on(async move {
419                    let engine = Engine::new(
420                        tx_for_engine,
421                        rx,
422                        pipeline,
423                        method,
424                        config.no_kv_cache,
425                        config.no_prefix_cache,
426                        config.prefix_cache_n,
427                        config.disable_eos_stop,
428                        config.throughput_logging_enabled,
429                        config.search_embedding_model,
430                        config.search_callback.clone(),
431                        config.tool_callbacks.clone(),
432                        config.tool_callbacks_with_tools.clone(),
433                    )
434                    .expect("Engine creation failed.");
435                    Arc::new(engine).run().await;
436                })
437            });
438
439            #[cfg(not(feature = "metal"))]
440            {
441                let rt = Runtime::new().unwrap();
442                rt.block_on(async move {
443                    let engine = Engine::new(
444                        tx_for_engine,
445                        rx,
446                        pipeline,
447                        method,
448                        config.no_kv_cache,
449                        config.no_prefix_cache,
450                        config.prefix_cache_n,
451                        config.disable_eos_stop,
452                        config.throughput_logging_enabled,
453                        config.search_embedding_model,
454                        config.search_callback.clone(),
455                        config.tool_callbacks.clone(),
456                        config.tool_callbacks_with_tools.clone(),
457                    )
458                    .expect("Engine creation failed.");
459                    Arc::new(engine).run().await;
460                })
461            }
462        });
463
464        Ok(EngineInstance {
465            sender: tx,
466            engine_handler,
467            reboot_state,
468            config: mistralrs_config,
469            category,
470        })
471    }
472
473    async fn new(config: MistralRsBuilder) -> Arc<Self> {
474        let MistralRsBuilder {
475            pipeline,
476            method,
477            log,
478            no_kv_cache,
479            no_prefix_cache,
480            prefix_cache_n,
481            disable_eos_stop,
482            throughput_logging_enabled,
483            search_embedding_model,
484            search_callback,
485            tool_callbacks,
486            mut tool_callbacks_with_tools,
487            mcp_client_config,
488        } = config;
489
490        mistralrs_quant::cublaslt::maybe_init_cublas_lt_wrapper(
491            get_mut_arcmutex!(pipeline).device(),
492        );
493
494        // For hybrid models (Mamba-Attention), force batch_size=1 to prevent state bleeding
495        // Mamba's stateful nature makes batched inference complex; this ensures correctness
496        let method = if !get_mut_arcmutex!(pipeline).get_metadata().no_kv_cache
497            && get_mut_arcmutex!(pipeline).cache().is_hybrid()
498        {
499            info!(
500                "Hybrid model detected (Mamba-Attention), enforcing batch_size=1 for correctness"
501            );
502            SchedulerConfig::DefaultScheduler {
503                method: DefaultSchedulerMethod::Fixed(NonZeroUsize::new(1).unwrap()),
504            }
505        } else {
506            method
507        };
508
509        let no_kv_cache = no_kv_cache.unwrap_or(false);
510        let no_prefix_cache = no_prefix_cache.unwrap_or(false);
511        let prefix_cache_n = prefix_cache_n.unwrap_or(16);
512        let disable_eos_stop = disable_eos_stop.unwrap_or(false);
513
514        // Initialize MCP client if configured
515        if let Some(config) = &mcp_client_config {
516            let mut mcp_client = McpClient::new(config.clone());
517            let total_servers = config.servers.len();
518
519            match mcp_client.initialize().await {
520                Ok(()) => {
521                    let mcp_callbacks_with_tools = mcp_client.get_tool_callbacks_with_tools();
522                    let tools_count = mcp_callbacks_with_tools.len();
523
524                    // Merge MCP tool callbacks with tools into the new collection
525                    for (name, callback_with_tool) in mcp_callbacks_with_tools {
526                        tool_callbacks_with_tools.insert(name.clone(), callback_with_tool.clone());
527                    }
528
529                    if tools_count == 0 {
530                        warn!(
531                            "MCP client initialized but no tools were registered from {} servers",
532                            total_servers
533                        );
534                    } else {
535                        info!(
536                            "MCP client initialized successfully with {} tools from {} servers",
537                            tools_count, total_servers
538                        );
539                    }
540                }
541                Err(e) => {
542                    warn!(
543                        "Failed to initialize MCP client with {} configured servers: {}",
544                        total_servers, e
545                    );
546                    warn!("Continuing without MCP functionality. Check your MCP configuration and server availability.");
547                }
548            }
549        }
550
551        let reboot_state = RebootState {
552            pipeline: pipeline.clone(),
553            method: method.clone(),
554            no_kv_cache,
555            no_prefix_cache,
556            prefix_cache_n,
557            disable_eos_stop,
558            throughput_logging_enabled,
559            search_embedding_model,
560            search_callback: search_callback.clone(),
561            tool_callbacks: tool_callbacks.clone(),
562            tool_callbacks_with_tools: tool_callbacks_with_tools.clone(),
563            mcp_client_config: mcp_client_config.clone(),
564        };
565
566        // Create the engine configuration
567        let engine_config = EngineConfig {
568            no_kv_cache,
569            no_prefix_cache,
570            prefix_cache_n,
571            disable_eos_stop,
572            throughput_logging_enabled,
573            search_embedding_model,
574            search_callback,
575            tool_callbacks,
576            tool_callbacks_with_tools,
577        };
578
579        // Create the engine instance
580        let engine_instance =
581            Self::create_engine_instance(pipeline.clone(), method, engine_config, reboot_state)
582                .expect("Failed to create engine instance");
583
584        let id = pipeline.try_lock().unwrap().name();
585
586        if distributed::is_daemon() {
587            let request_sender = engine_instance.sender.clone();
588
589            if cfg!(feature = "ring") {
590                // Ring daemon replicator
591                distributed::ring_daemon_replicator(request_sender);
592            } else {
593                // NCCL daemon replicator
594                distributed::nccl_daemon_replicator(request_sender);
595            }
596
597            #[allow(clippy::empty_loop)]
598            loop {}
599        }
600
601        // Determine if the current runtime is multi-threaded, as blocking operations are not allowed in single-threaded mode
602        let is_multi_threaded = tokio::runtime::Handle::try_current()
603            .is_ok_and(|h| h.runtime_flavor() != tokio::runtime::RuntimeFlavor::CurrentThread);
604
605        // Do a dummy run
606        if !distributed::is_daemon()
607            && is_multi_threaded
608            && matches!(
609                engine_instance.category,
610                ModelCategory::Text | ModelCategory::Vision { .. }
611            )
612        {
613            let clone_sender = engine_instance.sender.clone();
614            tokio::task::block_in_place(|| {
615                let (tx, mut rx) = channel(1);
616                let req = Request::Normal(Box::new(NormalRequest {
617                    id: 0,
618                    messages: RequestMessage::Completion {
619                        text: "hello".to_string(),
620                        echo_prompt: false,
621                        best_of: None,
622                    },
623                    sampling_params: SamplingParams {
624                        max_len: Some(1),
625                        ..SamplingParams::deterministic()
626                    },
627                    response: tx,
628                    return_logprobs: false,
629                    is_streaming: false,
630                    constraint: Constraint::None,
631                    suffix: None,
632                    tool_choice: None,
633                    tools: None,
634                    logits_processors: None,
635                    return_raw_logits: false,
636                    web_search_options: None,
637                    model_id: None,
638                    truncate_sequence: false,
639                }));
640                info!("Beginning dummy run.");
641                let start = Instant::now();
642                clone_sender.blocking_send(req).unwrap();
643
644                // Drain all responses from the channel until it's closed
645                let mut received_any = false;
646                while let Some(_resp) = rx.blocking_recv() {
647                    received_any = true;
648                }
649
650                if received_any {
651                    let end = Instant::now();
652                    info!(
653                        "Dummy run completed in {}s.",
654                        end.duration_since(start).as_secs_f64()
655                    );
656                } else {
657                    warn!("Dummy run failed!");
658                }
659            });
660        }
661
662        // Create engines map with the first engine
663        let mut engines = HashMap::new();
664        engines.insert(id.clone(), engine_instance);
665
666        Arc::new(Self {
667            engines: RwLock::new(engines),
668            default_engine_id: RwLock::new(Some(id.clone())),
669            log,
670            id,
671            creation_time: SystemTime::now()
672                .duration_since(UNIX_EPOCH)
673                .expect("Time travel has occurred!")
674                .as_secs(),
675            next_request_id: Mutex::new(RefCell::new(1)),
676        })
677    }
678
679    /// Attempts to reboot a specific engine by model_id
680    fn reboot_engine(&self, model_id: &str) -> Result<(), MistralRsError> {
681        let mut engines = self.engines.write().map_err(|_| {
682            tracing::warn!("Couldn't get write lock on engines during reboot attempt");
683            MistralRsError::EnginePoisoned
684        })?;
685
686        if let Some(engine_instance) = engines.get(model_id) {
687            if !engine_instance.engine_handler.is_finished() {
688                tracing::info!("Engine {} already running, returning ok", model_id);
689                return Ok(());
690            }
691
692            let reboot_state = engine_instance.reboot_state.clone();
693            let engine_config = EngineConfig {
694                no_kv_cache: reboot_state.no_kv_cache,
695                no_prefix_cache: reboot_state.no_prefix_cache,
696                prefix_cache_n: reboot_state.prefix_cache_n,
697                disable_eos_stop: reboot_state.disable_eos_stop,
698                throughput_logging_enabled: reboot_state.throughput_logging_enabled,
699                search_embedding_model: reboot_state.search_embedding_model,
700                search_callback: reboot_state.search_callback.clone(),
701                tool_callbacks: reboot_state.tool_callbacks.clone(),
702                tool_callbacks_with_tools: reboot_state.tool_callbacks_with_tools.clone(),
703            };
704            let new_engine_instance = Self::create_engine_instance(
705                reboot_state.pipeline.clone(),
706                reboot_state.method.clone(),
707                engine_config,
708                reboot_state,
709            )
710            .map_err(|e| {
711                tracing::error!("Failed to create new engine instance: {}", e);
712                MistralRsError::EnginePoisoned
713            })?;
714
715            engines.insert(model_id.to_string(), new_engine_instance);
716            tracing::info!("Successfully rebooted engine {}", model_id);
717            Ok(())
718        } else {
719            Err(MistralRsError::EnginePoisoned)
720        }
721    }
722
723    fn engine_dead(&self, model_id: &str) -> Result<bool, MistralRsError> {
724        let engines = self.engines.read().map_err(|_| {
725            tracing::warn!("Couldn't get read lock on engines!");
726            MistralRsError::EnginePoisoned
727        })?;
728
729        if let Some(engine_instance) = engines.get(model_id) {
730            Ok(engine_instance.engine_handler.is_finished())
731        } else {
732            Err(MistralRsError::EnginePoisoned)
733        }
734    }
735
736    /// Get sender for a specific model. If model_id is None, uses default engine.
737    pub fn get_sender(&self, model_id: Option<&str>) -> Result<Sender<Request>, MistralRsError> {
738        let resolved_model_id = match model_id {
739            Some(id) => id.to_string(),
740            None => {
741                let default_lock = self
742                    .default_engine_id
743                    .read()
744                    .map_err(|_| MistralRsError::SenderPoisoned)?;
745                default_lock
746                    .as_ref()
747                    .ok_or(MistralRsError::EnginePoisoned)?
748                    .clone()
749            }
750        };
751
752        if self.engine_dead(&resolved_model_id)? {
753            tracing::warn!("Engine {} is dead, rebooting", resolved_model_id);
754            self.reboot_engine(&resolved_model_id)?
755        }
756
757        let engines = self
758            .engines
759            .read()
760            .map_err(|_| MistralRsError::SenderPoisoned)?;
761        if let Some(engine_instance) = engines.get(&resolved_model_id) {
762            Ok(engine_instance.sender.clone())
763        } else {
764            Err(MistralRsError::EnginePoisoned)
765        }
766    }
767
768    pub fn get_id(&self) -> String {
769        self.id.clone()
770    }
771
772    pub fn get_creation_time(&self) -> u64 {
773        self.creation_time
774    }
775
776    /// Get model category for a specific model. If model_id is None, uses default engine.
777    pub fn get_model_category(
778        &self,
779        model_id: Option<&str>,
780    ) -> Result<ModelCategory, MistralRsError> {
781        let resolved_model_id = match model_id {
782            Some(id) => id.to_string(),
783            None => {
784                let default_lock = self
785                    .default_engine_id
786                    .read()
787                    .map_err(|_| MistralRsError::SenderPoisoned)?;
788                default_lock
789                    .as_ref()
790                    .ok_or(MistralRsError::EnginePoisoned)?
791                    .clone()
792            }
793        };
794
795        let engines = self
796            .engines
797            .read()
798            .map_err(|_| MistralRsError::SenderPoisoned)?;
799        if let Some(engine_instance) = engines.get(&resolved_model_id) {
800            Ok(engine_instance.category.clone())
801        } else {
802            Err(MistralRsError::EnginePoisoned)
803        }
804    }
805
806    /// Get the maximum supported sequence length for a model, if applicable.
807    pub fn max_sequence_length(
808        &self,
809        model_id: Option<&str>,
810    ) -> Result<Option<usize>, MistralRsError> {
811        let resolved_model_id = match model_id {
812            Some(id) => id.to_string(),
813            None => {
814                let default_lock = self
815                    .default_engine_id
816                    .read()
817                    .map_err(|_| MistralRsError::SenderPoisoned)?;
818                default_lock
819                    .as_ref()
820                    .ok_or(MistralRsError::EnginePoisoned)?
821                    .clone()
822            }
823        };
824
825        let engines = self
826            .engines
827            .read()
828            .map_err(|_| MistralRsError::SenderPoisoned)?;
829        if let Some(engine_instance) = engines.get(&resolved_model_id) {
830            Ok(engine_instance.config.max_seq_len)
831        } else {
832            Err(MistralRsError::EnginePoisoned)
833        }
834    }
835
836    pub fn next_request_id(&self) -> usize {
837        let l = self.next_request_id.lock().unwrap();
838        let last = &mut *l.borrow_mut();
839        let last_v = *last;
840        *last += 1;
841        last_v
842    }
843
844    /// Add a new model engine to the MistralRs instance
845    pub async fn add_model(
846        &self,
847        model_id: String,
848        pipeline: Arc<tokio::sync::Mutex<dyn Pipeline>>,
849        method: SchedulerConfig,
850        config: AddModelConfig,
851    ) -> Result<(), String> {
852        // For hybrid models (Mamba-Attention), force batch_size=1 to prevent state bleeding
853        let method = {
854            let pipeline_guard = pipeline.try_lock().unwrap();
855            if !pipeline_guard.get_metadata().no_kv_cache && pipeline_guard.cache().is_hybrid() {
856                info!(
857                    "Hybrid model detected (Mamba-Attention), enforcing batch_size=1 for correctness"
858                );
859                SchedulerConfig::DefaultScheduler {
860                    method: DefaultSchedulerMethod::Fixed(NonZeroUsize::new(1).unwrap()),
861                }
862            } else {
863                method
864            }
865        };
866
867        let reboot_state = RebootState {
868            pipeline: pipeline.clone(),
869            method: method.clone(),
870            no_kv_cache: config.engine_config.no_kv_cache,
871            no_prefix_cache: config.engine_config.no_prefix_cache,
872            prefix_cache_n: config.engine_config.prefix_cache_n,
873            disable_eos_stop: config.engine_config.disable_eos_stop,
874            throughput_logging_enabled: config.engine_config.throughput_logging_enabled,
875            search_embedding_model: config.engine_config.search_embedding_model,
876            search_callback: config.engine_config.search_callback.clone(),
877            tool_callbacks: config.engine_config.tool_callbacks.clone(),
878            tool_callbacks_with_tools: config.engine_config.tool_callbacks_with_tools.clone(),
879            mcp_client_config: config.mcp_client_config.clone(),
880        };
881
882        let engine_instance =
883            Self::create_engine_instance(pipeline, method, config.engine_config, reboot_state)?;
884
885        let mut engines = self
886            .engines
887            .write()
888            .map_err(|_| "Failed to acquire write lock on engines")?;
889        engines.insert(model_id.clone(), engine_instance);
890
891        // If this is the first model, set it as default
892        if engines.len() == 1 {
893            let mut default_lock = self
894                .default_engine_id
895                .write()
896                .map_err(|_| "Failed to acquire write lock on default_engine_id")?;
897            *default_lock = Some(model_id.clone());
898        }
899
900        Ok(())
901    }
902
903    /// Remove a model engine from the MistralRs instance
904    pub fn remove_model(&self, model_id: &str) -> Result<(), String> {
905        let mut engines = self
906            .engines
907            .write()
908            .map_err(|_| "Failed to acquire write lock on engines")?;
909
910        if engines.len() <= 1 {
911            return Err("Cannot remove the last model from MistralRs".to_string());
912        }
913
914        if let Some(engine_instance) = engines.remove(model_id) {
915            // Send terminate signal to the engine
916            let _ = engine_instance.sender.blocking_send(Request::Terminate);
917
918            // If this was the default engine, set a new default
919            let mut default_lock = self
920                .default_engine_id
921                .write()
922                .map_err(|_| "Failed to acquire write lock on default_engine_id")?;
923            if let Some(ref default_id) = *default_lock {
924                if default_id == model_id {
925                    // Set the first available engine as the new default
926                    *default_lock = engines.keys().next().cloned();
927                }
928            }
929
930            Ok(())
931        } else {
932            Err(format!("Model {model_id} not found"))
933        }
934    }
935
936    /// List all available model IDs
937    pub fn list_models(&self) -> Result<Vec<String>, String> {
938        let engines = self
939            .engines
940            .read()
941            .map_err(|_| "Failed to acquire read lock on engines")?;
942        Ok(engines.keys().cloned().collect())
943    }
944
945    /// Get the current default model ID
946    pub fn get_default_model_id(&self) -> Result<Option<String>, String> {
947        let default_lock = self
948            .default_engine_id
949            .read()
950            .map_err(|_| "Failed to acquire read lock on default_engine_id")?;
951        Ok(default_lock.clone())
952    }
953
954    /// Set the default model ID
955    pub fn set_default_model_id(&self, model_id: &str) -> Result<(), String> {
956        let engines = self
957            .engines
958            .read()
959            .map_err(|_| "Failed to acquire read lock on engines")?;
960        if !engines.contains_key(model_id) {
961            return Err(format!("Model {model_id} not found"));
962        }
963        drop(engines);
964
965        let mut default_lock = self
966            .default_engine_id
967            .write()
968            .map_err(|_| "Failed to acquire write lock on default_engine_id")?;
969        *default_lock = Some(model_id.to_string());
970
971        Ok(())
972    }
973
974    /// Dispatch a request to the appropriate engine based on the model_id in the request
975    pub fn send_request(&self, mut request: Request) -> Result<(), MistralRsError> {
976        let model_id = match &mut request {
977            Request::Normal(normal_req) => normal_req.model_id.as_deref(),
978            _ => None, // Other request types don't specify model_id
979        };
980
981        let sender = self.get_sender(model_id)?;
982        sender
983            .blocking_send(request)
984            .map_err(|_| MistralRsError::SenderPoisoned)
985    }
986
987    pub fn maybe_log_request(this: Arc<Self>, repr: String) {
988        if let Some(file) = &this.log {
989            let mut f = OpenOptions::new()
990                .append(true)
991                .create(true) // Optionally create the file if it doesn't already exist
992                .open(file)
993                .expect("Unable to open file");
994            let time = chrono::offset::Local::now();
995            f.write_all(format!("Request at {time}: {repr}\n\n").as_bytes())
996                .expect("Unable to write data");
997        }
998    }
999
1000    pub fn maybe_log_response<T: Serialize>(this: Arc<Self>, resp: &T) {
1001        if let Some(file) = &this.log {
1002            let mut f = OpenOptions::new()
1003                .append(true)
1004                .create(true) // Optionally create the file if it doesn't already exist
1005                .open(file)
1006                .expect("Unable to open file");
1007            let time = chrono::offset::Local::now();
1008            let repr = serde_json::to_string(resp).expect("Serialization of response failed.");
1009            f.write_all(format!("Response at {time}: {repr}\n\n").as_bytes())
1010                .expect("Unable to write data");
1011        }
1012    }
1013
1014    pub fn maybe_log_error(this: Arc<Self>, err: &dyn Error) {
1015        if let Some(file) = &this.log {
1016            let mut f = OpenOptions::new()
1017                .append(true)
1018                .create(true) // Optionally create the file if it doesn't already exist
1019                .open(file)
1020                .expect("Unable to open file");
1021            let time = chrono::offset::Local::now();
1022            f.write_all(format!("Error response at {time}: {err}\n\n").as_bytes())
1023                .expect("Unable to write data");
1024        }
1025    }
1026
1027    /// Get the number of tools available for a specific model (including MCP tools)
1028    pub fn get_tools_count(&self, model_id: Option<&str>) -> Result<usize, String> {
1029        let resolved_model_id = match model_id {
1030            Some(id) => id.to_string(),
1031            None => {
1032                let default_lock = self
1033                    .default_engine_id
1034                    .read()
1035                    .map_err(|_| "Failed to acquire read lock")?;
1036                default_lock
1037                    .as_ref()
1038                    .ok_or("No default engine set")?
1039                    .clone()
1040            }
1041        };
1042
1043        let engines = self
1044            .engines
1045            .read()
1046            .map_err(|_| "Failed to acquire read lock on engines")?;
1047        if let Some(engine_instance) = engines.get(&resolved_model_id) {
1048            Ok(engine_instance.reboot_state.tool_callbacks_with_tools.len())
1049        } else {
1050            Err(format!("Model {resolved_model_id} not found"))
1051        }
1052    }
1053
1054    /// Check if MCP client is configured for a specific model
1055    pub fn has_mcp_client(&self, model_id: Option<&str>) -> Result<bool, String> {
1056        let resolved_model_id = match model_id {
1057            Some(id) => id.to_string(),
1058            None => {
1059                let default_lock = self
1060                    .default_engine_id
1061                    .read()
1062                    .map_err(|_| "Failed to acquire read lock")?;
1063                default_lock
1064                    .as_ref()
1065                    .ok_or("No default engine set")?
1066                    .clone()
1067            }
1068        };
1069
1070        let engines = self
1071            .engines
1072            .read()
1073            .map_err(|_| "Failed to acquire read lock on engines")?;
1074        if let Some(engine_instance) = engines.get(&resolved_model_id) {
1075            Ok(engine_instance.reboot_state.mcp_client_config.is_some())
1076        } else {
1077            Err(format!("Model {resolved_model_id} not found"))
1078        }
1079    }
1080
1081    /// Get config for a specific model
1082    pub fn config(&self, model_id: Option<&str>) -> Result<MistralRsConfig, String> {
1083        let resolved_model_id = match model_id {
1084            Some(id) => id.to_string(),
1085            None => {
1086                let default_lock = self
1087                    .default_engine_id
1088                    .read()
1089                    .map_err(|_| "Failed to acquire read lock")?;
1090                default_lock
1091                    .as_ref()
1092                    .ok_or("No default engine set")?
1093                    .clone()
1094            }
1095        };
1096
1097        let engines = self
1098            .engines
1099            .read()
1100            .map_err(|_| "Failed to acquire read lock on engines")?;
1101        if let Some(engine_instance) = engines.get(&resolved_model_id) {
1102            Ok(engine_instance.config.clone())
1103        } else {
1104            Err(format!("Model {resolved_model_id} not found"))
1105        }
1106    }
1107}