mistralrs_core/
lib.rs

1#![deny(clippy::cast_possible_truncation, clippy::cast_precision_loss)]
2use candle_core::Device;
3use engine::Engine;
4pub use engine::{
5    get_engine_terminate_flag, reset_engine_terminate_flag, should_terminate_engine_sequences,
6    EngineInstruction, SearchEmbeddingModel, ENGINE_INSTRUCTIONS, TERMINATE_ALL_NEXT_STEP,
7};
8use hf_hub::Cache;
9pub use lora::Ordering;
10pub use pipeline::ModelCategory;
11pub use pipeline::Pipeline;
12#[cfg(feature = "pyo3_macros")]
13use pyo3::exceptions::PyValueError;
14use std::collections::HashMap;
15use std::num::NonZeroUsize;
16use std::sync::OnceLock;
17use std::time::Instant;
18use std::{
19    cell::RefCell,
20    error::Error,
21    fs::OpenOptions,
22    io::Write,
23    sync::{atomic::AtomicBool, Arc, Mutex, RwLock},
24    thread::{self, JoinHandle},
25    time::{SystemTime, UNIX_EPOCH},
26};
27use tokio::sync::mpsc::{channel, Sender};
28use tracing::info;
29use tracing::warn;
30
31mod cuda;
32mod device_map;
33mod engine;
34mod lora;
35mod model_loader;
36mod moe;
37mod ops;
38pub use model_loader::{
39    get_auto_device_map_params, get_model_dtype, get_tgt_non_granular_index, LoaderBuilder,
40};
41mod embedding_models;
42mod kv_cache;
43mod search;
44
45mod model_selected;
46pub use model_selected::ModelSelected;
47pub use toml_selector::{get_toml_selected_model_device_map_params, get_toml_selected_model_dtype};
48
49mod amoe;
50mod attention;
51mod diffusion_models;
52pub mod distributed;
53mod gguf;
54pub mod layers;
55mod layers_masker;
56mod layers_utils;
57pub mod matformer;
58mod models;
59mod paged_attention;
60mod pipeline;
61mod prefix_cacher;
62mod request;
63mod response;
64mod sampler;
65mod scheduler;
66mod sequence;
67mod speech_models;
68mod toml_selector;
69mod tools;
70mod topology;
71mod utils;
72mod vision_models;
73mod xlora_models;
74
75pub use amoe::{AnyMoeConfig, AnyMoeExpertType};
76pub use device_map::{
77    DeviceLayerMapMetadata, DeviceMapMetadata, DeviceMapSetting, LayerDeviceMapper,
78};
79pub use gguf::{GGUFArchitecture, GGUF_MULTI_FILE_DELIMITER};
80pub use mistralrs_audio::AudioInput;
81pub use mistralrs_mcp::{
82    CalledFunction, Function, Tool, ToolCallback, ToolCallbackWithTool, ToolType,
83};
84pub use mistralrs_mcp::{
85    McpClient, McpClientConfig, McpServerConfig, McpServerSource, McpToolInfo,
86};
87pub use mistralrs_quant::{IsqType, MULTI_LORA_DELIMITER};
88pub use paged_attention::{MemoryGpuConfig, PagedAttentionConfig, PagedCacheType};
89pub use pipeline::{
90    chat_template::ChatTemplate, parse_isq_value, AdapterPaths, AnyMoeLoader, AnyMoePipeline,
91    AutoDeviceMapParams, AutoLoader, AutoLoaderBuilder, DiffusionGenerationParams, DiffusionLoader,
92    DiffusionLoaderBuilder, DiffusionLoaderType, EmbeddingLoader, EmbeddingLoaderBuilder,
93    EmbeddingLoaderType, EmbeddingModelPaths, EmbeddingSpecificConfig, GGMLLoader,
94    GGMLLoaderBuilder, GGMLSpecificConfig, GGUFLoader, GGUFLoaderBuilder, GGUFSpecificConfig,
95    GemmaLoader, Idefics2Loader, IsqOrganization, LLaVALoader, LLaVANextLoader, LlamaLoader,
96    Loader, LocalModelPaths, LoraAdapterPaths, MistralLoader, MixtralLoader, Modalities, ModelKind,
97    ModelPaths, MultimodalPromptPrefixer, NormalLoader, NormalLoaderBuilder, NormalLoaderType,
98    NormalSpecificConfig, Phi2Loader, Phi3Loader, Phi3VLoader, Qwen2Loader, SpeculativeConfig,
99    SpeculativeLoader, SpeculativePipeline, SpeechLoader, SpeechPipeline, Starcoder2Loader,
100    SupportedModality, TokenSource, VisionLoader, VisionLoaderBuilder, VisionLoaderType,
101    VisionSpecificConfig, UQFF_MULTI_FILE_DELIMITER,
102};
103pub use request::{
104    ApproximateUserLocation, Constraint, DetokenizationRequest, ImageGenerationResponseFormat,
105    LlguidanceGrammar, MessageContent, NormalRequest, Request, RequestMessage, SearchContextSize,
106    TokenizationRequest, WebSearchOptions, WebSearchUserLocation,
107};
108pub use response::*;
109pub use sampler::{
110    CustomLogitsProcessor, DrySamplingParams, SamplingParams, StopTokens, TopLogprob,
111};
112pub use scheduler::{DefaultSchedulerMethod, SchedulerConfig};
113pub use search::{SearchCallback, SearchFunctionParameters, SearchResult};
114use serde::Serialize;
115pub use speech_models::{utils as speech_utils, SpeechGenerationConfig, SpeechLoaderType};
116use tokio::runtime::Runtime;
117use toml_selector::{TomlLoaderArgs, TomlSelector};
118pub use tools::{ToolCallResponse, ToolCallType, ToolCallbacks, ToolChoice};
119pub use topology::{LayerTopology, Topology};
120pub use utils::debug::initialize_logging;
121pub use utils::memory_usage::MemoryUsage;
122pub use utils::normal::{ModelDType, TryIntoDType};
123pub use utils::{paged_attn_supported, using_flash_attn};
124
125// re-export llguidance for easier LlguidanceGrammar construction
126pub use llguidance;
127
/// Crate-wide debug switch: `true` if `MISTRALRS_DEBUG=1`.
/// It starts out `false` here. NOTE(review): the code that sets it from the
/// environment is not in this part of the file; confirm where it is written.
pub(crate) static DEBUG: AtomicBool = AtomicBool::new(false);
/// Process-wide Hugging Face Hub cache handle. Being a `OnceLock`, it can be set at
/// most once; presumably it is initialized during model loading (TODO confirm).
pub static GLOBAL_HF_CACHE: OnceLock<Cache> = OnceLock::new();
131
132/// Configuration for creating an engine instance
133#[derive(Clone)]
134pub struct EngineConfig {
135    pub no_kv_cache: bool,
136    pub no_prefix_cache: bool,
137    pub prefix_cache_n: usize,
138    pub disable_eos_stop: bool,
139    pub throughput_logging_enabled: bool,
140    pub search_embedding_model: Option<SearchEmbeddingModel>,
141    pub search_callback: Option<Arc<SearchCallback>>,
142    pub tool_callbacks: tools::ToolCallbacks,
143    pub tool_callbacks_with_tools: tools::ToolCallbacksWithTools,
144}
145
146impl Default for EngineConfig {
147    fn default() -> Self {
148        Self {
149            no_kv_cache: false,
150            no_prefix_cache: false,
151            prefix_cache_n: 16,
152            disable_eos_stop: false,
153            throughput_logging_enabled: true,
154            search_embedding_model: None,
155            search_callback: None,
156            tool_callbacks: HashMap::new(),
157            tool_callbacks_with_tools: HashMap::new(),
158        }
159    }
160}
161
162/// Configuration for adding a model to MistralRs
163#[derive(Clone)]
164pub struct AddModelConfig {
165    pub engine_config: EngineConfig,
166    pub mcp_client_config: Option<McpClientConfig>,
167}
168
169impl AddModelConfig {
170    pub fn new(engine_config: EngineConfig) -> Self {
171        Self {
172            engine_config,
173            mcp_client_config: None,
174        }
175    }
176
177    pub fn with_mcp_config(mut self, mcp_config: McpClientConfig) -> Self {
178        self.mcp_client_config = Some(mcp_config);
179        self
180    }
181}
182
183#[derive(Clone)]
184pub struct MistralRsConfig {
185    pub kind: ModelKind,
186    pub device: Device,
187    pub category: ModelCategory,
188    pub modalities: Modalities,
189    pub max_seq_len: Option<usize>,
190}
191
192/// Internal structure to hold per-engine state
193struct EngineInstance {
194    sender: Sender<Request>,
195    engine_handler: JoinHandle<()>,
196    reboot_state: RebootState,
197    config: MistralRsConfig,
198    category: ModelCategory,
199}
200
201/// The MistralRs struct handles sending requests to multiple engines.
202/// It is the core multi-threaded component of mistral.rs, and uses `mpsc`
203/// `Sender` and `Receiver` primitives to send and receive requests to the
204/// appropriate engine based on model ID.
205pub struct MistralRs {
206    engines: RwLock<HashMap<String, EngineInstance>>,
207    default_engine_id: RwLock<Option<String>>,
208    log: Option<String>,
209    id: String,
210    creation_time: u64,
211    next_request_id: Mutex<RefCell<usize>>,
212}
213
214#[derive(Clone)]
215struct RebootState {
216    pipeline: Arc<tokio::sync::Mutex<dyn Pipeline>>,
217    method: SchedulerConfig,
218    no_kv_cache: bool,
219    no_prefix_cache: bool,
220    prefix_cache_n: usize,
221    disable_eos_stop: bool,
222    throughput_logging_enabled: bool,
223    search_embedding_model: Option<SearchEmbeddingModel>,
224    search_callback: Option<Arc<search::SearchCallback>>,
225    tool_callbacks: tools::ToolCallbacks,
226    tool_callbacks_with_tools: tools::ToolCallbacksWithTools,
227    mcp_client_config: Option<McpClientConfig>,
228}
229
/// Errors reported by [`MistralRs`] when an engine cannot be reached.
#[derive(Debug)]
pub enum MistralRsError {
    /// An engine-map lock was poisoned, no matching (or default) engine exists, or an
    /// engine could not be restarted.
    EnginePoisoned,
    /// A lock was poisoned while looking up a sender or engine property.
    SenderPoisoned,
}

impl std::fmt::Display for MistralRsError {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        // The user-facing text is simply the variant name, identical to `Debug`.
        let name = match self {
            MistralRsError::EnginePoisoned => "EnginePoisoned",
            MistralRsError::SenderPoisoned => "SenderPoisoned",
        };
        f.write_str(name)
    }
}

impl std::error::Error for MistralRsError {}
243
#[cfg(feature = "pyo3_macros")]
impl From<MistralRsError> for pyo3::PyErr {
    /// Surfaces the error to Python as a `ValueError` whose message is the variant's
    /// debug representation.
    fn from(err: MistralRsError) -> Self {
        let message = format!("{:?}", err);
        PyValueError::new_err(message)
    }
}
250
251/// The MistralRsBuilder takes the pipeline and a scheduler method and constructs
252/// an Engine and a MistralRs instance. The Engine runs on a separate thread, and the MistralRs
253/// instance stays on the calling thread.
254pub struct MistralRsBuilder {
255    pipeline: Arc<tokio::sync::Mutex<dyn Pipeline>>,
256    method: SchedulerConfig,
257    log: Option<String>,
258    no_kv_cache: Option<bool>,
259    no_prefix_cache: Option<bool>,
260    prefix_cache_n: Option<usize>,
261    disable_eos_stop: Option<bool>,
262    throughput_logging_enabled: bool,
263    search_embedding_model: Option<SearchEmbeddingModel>,
264    search_callback: Option<Arc<SearchCallback>>,
265    tool_callbacks: tools::ToolCallbacks,
266    tool_callbacks_with_tools: tools::ToolCallbacksWithTools,
267    mcp_client_config: Option<McpClientConfig>,
268}
269
270impl MistralRsBuilder {
271    /// Creates a new builder with the given pipeline, scheduler method, logging flag,
272    /// and optional embedding model for web search. To override the search callback,
273    /// use `.with_search_callback(...)` on the builder.
274    pub fn new(
275        pipeline: Arc<tokio::sync::Mutex<dyn Pipeline>>,
276        method: SchedulerConfig,
277        throughput_logging: bool,
278        search_embedding_model: Option<SearchEmbeddingModel>,
279    ) -> Self {
280        Self {
281            pipeline,
282            method,
283            log: None,
284            no_kv_cache: None,
285            no_prefix_cache: None,
286            prefix_cache_n: None,
287            disable_eos_stop: None,
288            throughput_logging_enabled: throughput_logging,
289            search_embedding_model,
290            search_callback: None,
291            tool_callbacks: HashMap::new(),
292            tool_callbacks_with_tools: HashMap::new(),
293            mcp_client_config: None,
294        }
295    }
296    pub fn with_log(mut self, log: String) -> Self {
297        self.log = Some(log);
298        self
299    }
300    pub fn with_opt_log(mut self, log: Option<String>) -> Self {
301        self.log = log;
302        self
303    }
304    pub fn with_no_kv_cache(mut self, no_kv_cache: bool) -> Self {
305        self.no_kv_cache = Some(no_kv_cache);
306        self
307    }
308    pub fn with_no_prefix_cache(mut self, no_prefix_cache: bool) -> Self {
309        self.no_prefix_cache = Some(no_prefix_cache);
310        self
311    }
312    pub fn with_prefix_cache_n(mut self, prefix_cache_n: usize) -> Self {
313        self.prefix_cache_n = Some(prefix_cache_n);
314        self
315    }
316    pub fn with_disable_eos_stop(mut self, disable_eos_stop: bool) -> Self {
317        self.disable_eos_stop = Some(disable_eos_stop);
318        self
319    }
320
321    /// Use a custom callback to gather search results.
322    pub fn with_search_callback(mut self, search_callback: Arc<SearchCallback>) -> Self {
323        self.search_callback = Some(search_callback);
324        self
325    }
326
327    /// Register a custom callback for the specified tool name.
328    pub fn with_tool_callback(
329        mut self,
330        name: impl Into<String>,
331        tool_callback: Arc<ToolCallback>,
332    ) -> Self {
333        self.tool_callbacks.insert(name.into(), tool_callback);
334        self
335    }
336
337    /// Register a custom callback with its associated Tool definition. The Tool will be
338    /// automatically added to requests when tool callbacks are active.
339    pub fn with_tool_callback_and_tool(
340        mut self,
341        name: impl Into<String>,
342        tool_callback: Arc<ToolCallback>,
343        tool: Tool,
344    ) -> Self {
345        let name = name.into();
346        self.tool_callbacks_with_tools.insert(
347            name,
348            ToolCallbackWithTool {
349                callback: tool_callback,
350                tool,
351            },
352        );
353        self
354    }
355
356    /// Configure MCP client to connect to external MCP servers.
357    pub fn with_mcp_client(mut self, config: McpClientConfig) -> Self {
358        self.mcp_client_config = Some(config);
359        self
360    }
361
362    pub async fn build(self) -> Arc<MistralRs> {
363        MistralRs::new(self).await
364    }
365}
366
367impl Drop for MistralRs {
368    fn drop(&mut self) {
369        // Terminate all engines
370        if let Ok(engines) = self.engines.read() {
371            for (_, engine) in engines.iter() {
372                // Use try_send instead of blocking_send to avoid runtime panics
373                let _ = engine.sender.try_send(Request::Terminate);
374            }
375        }
376    }
377}
378
379impl MistralRs {
380    /// Create an engine instance with the given configuration
381    fn create_engine_instance(
382        pipeline: Arc<tokio::sync::Mutex<dyn Pipeline>>,
383        method: SchedulerConfig,
384        config: EngineConfig,
385        reboot_state: RebootState,
386    ) -> Result<EngineInstance, String> {
387        let (tx, rx) = channel(10_000);
388
389        let pipeline_guard = pipeline.try_lock().unwrap();
390        let category = pipeline_guard.category();
391        let metadata = pipeline_guard.get_metadata();
392        let kind = metadata.kind.clone();
393        let device = pipeline_guard.device();
394        let modalities = metadata.modalities.clone();
395        let max_seq_len = match &category {
396            ModelCategory::Diffusion | ModelCategory::Speech => None,
397            _ => Some(metadata.max_seq_len),
398        };
399        drop(pipeline_guard);
400
401        info!("Pipeline input modalities are {:?}", &modalities.input);
402        info!("Pipeline output modalities are {:?}", &modalities.output);
403
404        let mistralrs_config = MistralRsConfig {
405            kind,
406            device,
407            category: category.clone(),
408            modalities,
409            max_seq_len,
410        };
411
412        let tx_for_engine = tx.clone();
413        let engine_handler = thread::spawn(move || {
414            #[cfg(feature = "metal")]
415            objc::rc::autoreleasepool(move || {
416                let rt = Runtime::new().unwrap();
417                rt.block_on(async move {
418                    let engine = Engine::new(
419                        tx_for_engine,
420                        rx,
421                        pipeline,
422                        method,
423                        config.no_kv_cache,
424                        config.no_prefix_cache,
425                        config.prefix_cache_n,
426                        config.disable_eos_stop,
427                        config.throughput_logging_enabled,
428                        config.search_embedding_model,
429                        config.search_callback.clone(),
430                        config.tool_callbacks.clone(),
431                        config.tool_callbacks_with_tools.clone(),
432                    )
433                    .expect("Engine creation failed.");
434                    Arc::new(engine).run().await;
435                })
436            });
437
438            #[cfg(not(feature = "metal"))]
439            {
440                let rt = Runtime::new().unwrap();
441                rt.block_on(async move {
442                    let engine = Engine::new(
443                        tx_for_engine,
444                        rx,
445                        pipeline,
446                        method,
447                        config.no_kv_cache,
448                        config.no_prefix_cache,
449                        config.prefix_cache_n,
450                        config.disable_eos_stop,
451                        config.throughput_logging_enabled,
452                        config.search_embedding_model,
453                        config.search_callback.clone(),
454                        config.tool_callbacks.clone(),
455                        config.tool_callbacks_with_tools.clone(),
456                    )
457                    .expect("Engine creation failed.");
458                    Arc::new(engine).run().await;
459                })
460            }
461        });
462
463        Ok(EngineInstance {
464            sender: tx,
465            engine_handler,
466            reboot_state,
467            config: mistralrs_config,
468            category,
469        })
470    }
471
472    async fn new(config: MistralRsBuilder) -> Arc<Self> {
473        let MistralRsBuilder {
474            pipeline,
475            method,
476            log,
477            no_kv_cache,
478            no_prefix_cache,
479            prefix_cache_n,
480            disable_eos_stop,
481            throughput_logging_enabled,
482            search_embedding_model,
483            search_callback,
484            tool_callbacks,
485            mut tool_callbacks_with_tools,
486            mcp_client_config,
487        } = config;
488
489        mistralrs_quant::cublaslt::maybe_init_cublas_lt_wrapper(
490            get_mut_arcmutex!(pipeline).device(),
491        );
492
493        // For hybrid models (Mamba-Attention), force batch_size=1 to prevent state bleeding
494        // Mamba's stateful nature makes batched inference complex; this ensures correctness
495        let method = if !get_mut_arcmutex!(pipeline).get_metadata().no_kv_cache
496            && get_mut_arcmutex!(pipeline).cache().is_hybrid()
497        {
498            info!(
499                "Hybrid model detected (Mamba-Attention), enforcing batch_size=1 for correctness"
500            );
501            SchedulerConfig::DefaultScheduler {
502                method: DefaultSchedulerMethod::Fixed(NonZeroUsize::new(1).unwrap()),
503            }
504        } else {
505            method
506        };
507
508        let no_kv_cache = no_kv_cache.unwrap_or(false);
509        let no_prefix_cache = no_prefix_cache.unwrap_or(false);
510        let prefix_cache_n = prefix_cache_n.unwrap_or(16);
511        let disable_eos_stop = disable_eos_stop.unwrap_or(false);
512
513        // Initialize MCP client if configured
514        if let Some(config) = &mcp_client_config {
515            let mut mcp_client = McpClient::new(config.clone());
516            let total_servers = config.servers.len();
517
518            match mcp_client.initialize().await {
519                Ok(()) => {
520                    let mcp_callbacks_with_tools = mcp_client.get_tool_callbacks_with_tools();
521                    let tools_count = mcp_callbacks_with_tools.len();
522
523                    // Merge MCP tool callbacks with tools into the new collection
524                    for (name, callback_with_tool) in mcp_callbacks_with_tools {
525                        tool_callbacks_with_tools.insert(name.clone(), callback_with_tool.clone());
526                    }
527
528                    if tools_count == 0 {
529                        warn!(
530                            "MCP client initialized but no tools were registered from {} servers",
531                            total_servers
532                        );
533                    } else {
534                        info!(
535                            "MCP client initialized successfully with {} tools from {} servers",
536                            tools_count, total_servers
537                        );
538                    }
539                }
540                Err(e) => {
541                    warn!(
542                        "Failed to initialize MCP client with {} configured servers: {}",
543                        total_servers, e
544                    );
545                    warn!("Continuing without MCP functionality. Check your MCP configuration and server availability.");
546                }
547            }
548        }
549
550        let reboot_state = RebootState {
551            pipeline: pipeline.clone(),
552            method: method.clone(),
553            no_kv_cache,
554            no_prefix_cache,
555            prefix_cache_n,
556            disable_eos_stop,
557            throughput_logging_enabled,
558            search_embedding_model,
559            search_callback: search_callback.clone(),
560            tool_callbacks: tool_callbacks.clone(),
561            tool_callbacks_with_tools: tool_callbacks_with_tools.clone(),
562            mcp_client_config: mcp_client_config.clone(),
563        };
564
565        // Create the engine configuration
566        let engine_config = EngineConfig {
567            no_kv_cache,
568            no_prefix_cache,
569            prefix_cache_n,
570            disable_eos_stop,
571            throughput_logging_enabled,
572            search_embedding_model,
573            search_callback,
574            tool_callbacks,
575            tool_callbacks_with_tools,
576        };
577
578        // Create the engine instance
579        let engine_instance =
580            Self::create_engine_instance(pipeline.clone(), method, engine_config, reboot_state)
581                .expect("Failed to create engine instance");
582
583        let id = pipeline.try_lock().unwrap().name();
584
585        if distributed::is_daemon() {
586            let request_sender = engine_instance.sender.clone();
587
588            if cfg!(feature = "ring") {
589                // Ring daemon replicator
590                distributed::ring_daemon_replicator(request_sender);
591            } else {
592                // NCCL daemon replicator
593                distributed::nccl_daemon_replicator(request_sender);
594            }
595
596            #[allow(clippy::empty_loop)]
597            loop {}
598        }
599
600        // Determine if the current runtime is multi-threaded, as blocking operations are not allowed in single-threaded mode
601        let is_multi_threaded = tokio::runtime::Handle::try_current()
602            .is_ok_and(|h| h.runtime_flavor() != tokio::runtime::RuntimeFlavor::CurrentThread);
603
604        // Do a dummy run
605        if !distributed::is_daemon()
606            && is_multi_threaded
607            && matches!(
608                engine_instance.category,
609                ModelCategory::Text | ModelCategory::Vision { .. }
610            )
611        {
612            let clone_sender = engine_instance.sender.clone();
613            tokio::task::block_in_place(|| {
614                let (tx, mut rx) = channel(1);
615                let req = Request::Normal(Box::new(NormalRequest {
616                    id: 0,
617                    messages: RequestMessage::Completion {
618                        text: "hello".to_string(),
619                        echo_prompt: false,
620                        best_of: None,
621                    },
622                    sampling_params: SamplingParams {
623                        max_len: Some(1),
624                        ..SamplingParams::deterministic()
625                    },
626                    response: tx,
627                    return_logprobs: false,
628                    is_streaming: false,
629                    constraint: Constraint::None,
630                    suffix: None,
631                    tool_choice: None,
632                    tools: None,
633                    logits_processors: None,
634                    return_raw_logits: false,
635                    web_search_options: None,
636                    model_id: None,
637                    truncate_sequence: false,
638                }));
639                info!("Beginning dummy run.");
640                let start = Instant::now();
641                clone_sender.blocking_send(req).unwrap();
642
643                // Drain all responses from the channel until it's closed
644                let mut received_any = false;
645                while let Some(_resp) = rx.blocking_recv() {
646                    received_any = true;
647                }
648
649                if received_any {
650                    let end = Instant::now();
651                    info!(
652                        "Dummy run completed in {}s.",
653                        end.duration_since(start).as_secs_f64()
654                    );
655                } else {
656                    warn!("Dummy run failed!");
657                }
658            });
659        }
660
661        // Create engines map with the first engine
662        let mut engines = HashMap::new();
663        engines.insert(id.clone(), engine_instance);
664
665        Arc::new(Self {
666            engines: RwLock::new(engines),
667            default_engine_id: RwLock::new(Some(id.clone())),
668            log,
669            id,
670            creation_time: SystemTime::now()
671                .duration_since(UNIX_EPOCH)
672                .expect("Time travel has occurred!")
673                .as_secs(),
674            next_request_id: Mutex::new(RefCell::new(1)),
675        })
676    }
677
678    /// Attempts to reboot a specific engine by model_id
679    fn reboot_engine(&self, model_id: &str) -> Result<(), MistralRsError> {
680        let mut engines = self.engines.write().map_err(|_| {
681            tracing::warn!("Couldn't get write lock on engines during reboot attempt");
682            MistralRsError::EnginePoisoned
683        })?;
684
685        if let Some(engine_instance) = engines.get(model_id) {
686            if !engine_instance.engine_handler.is_finished() {
687                tracing::info!("Engine {} already running, returning ok", model_id);
688                return Ok(());
689            }
690
691            let reboot_state = engine_instance.reboot_state.clone();
692            let engine_config = EngineConfig {
693                no_kv_cache: reboot_state.no_kv_cache,
694                no_prefix_cache: reboot_state.no_prefix_cache,
695                prefix_cache_n: reboot_state.prefix_cache_n,
696                disable_eos_stop: reboot_state.disable_eos_stop,
697                throughput_logging_enabled: reboot_state.throughput_logging_enabled,
698                search_embedding_model: reboot_state.search_embedding_model,
699                search_callback: reboot_state.search_callback.clone(),
700                tool_callbacks: reboot_state.tool_callbacks.clone(),
701                tool_callbacks_with_tools: reboot_state.tool_callbacks_with_tools.clone(),
702            };
703            let new_engine_instance = Self::create_engine_instance(
704                reboot_state.pipeline.clone(),
705                reboot_state.method.clone(),
706                engine_config,
707                reboot_state,
708            )
709            .map_err(|e| {
710                tracing::error!("Failed to create new engine instance: {}", e);
711                MistralRsError::EnginePoisoned
712            })?;
713
714            engines.insert(model_id.to_string(), new_engine_instance);
715            tracing::info!("Successfully rebooted engine {}", model_id);
716            Ok(())
717        } else {
718            Err(MistralRsError::EnginePoisoned)
719        }
720    }
721
722    fn engine_dead(&self, model_id: &str) -> Result<bool, MistralRsError> {
723        let engines = self.engines.read().map_err(|_| {
724            tracing::warn!("Couldn't get read lock on engines!");
725            MistralRsError::EnginePoisoned
726        })?;
727
728        if let Some(engine_instance) = engines.get(model_id) {
729            Ok(engine_instance.engine_handler.is_finished())
730        } else {
731            Err(MistralRsError::EnginePoisoned)
732        }
733    }
734
735    /// Get sender for a specific model. If model_id is None, uses default engine.
736    pub fn get_sender(&self, model_id: Option<&str>) -> Result<Sender<Request>, MistralRsError> {
737        let resolved_model_id = match model_id {
738            Some(id) => id.to_string(),
739            None => {
740                let default_lock = self
741                    .default_engine_id
742                    .read()
743                    .map_err(|_| MistralRsError::SenderPoisoned)?;
744                default_lock
745                    .as_ref()
746                    .ok_or(MistralRsError::EnginePoisoned)?
747                    .clone()
748            }
749        };
750
751        if self.engine_dead(&resolved_model_id)? {
752            tracing::warn!("Engine {} is dead, rebooting", resolved_model_id);
753            self.reboot_engine(&resolved_model_id)?
754        }
755
756        let engines = self
757            .engines
758            .read()
759            .map_err(|_| MistralRsError::SenderPoisoned)?;
760        if let Some(engine_instance) = engines.get(&resolved_model_id) {
761            Ok(engine_instance.sender.clone())
762        } else {
763            Err(MistralRsError::EnginePoisoned)
764        }
765    }
766
767    pub fn get_id(&self) -> String {
768        self.id.clone()
769    }
770
771    pub fn get_creation_time(&self) -> u64 {
772        self.creation_time
773    }
774
775    /// Get model category for a specific model. If model_id is None, uses default engine.
776    pub fn get_model_category(
777        &self,
778        model_id: Option<&str>,
779    ) -> Result<ModelCategory, MistralRsError> {
780        let resolved_model_id = match model_id {
781            Some(id) => id.to_string(),
782            None => {
783                let default_lock = self
784                    .default_engine_id
785                    .read()
786                    .map_err(|_| MistralRsError::SenderPoisoned)?;
787                default_lock
788                    .as_ref()
789                    .ok_or(MistralRsError::EnginePoisoned)?
790                    .clone()
791            }
792        };
793
794        let engines = self
795            .engines
796            .read()
797            .map_err(|_| MistralRsError::SenderPoisoned)?;
798        if let Some(engine_instance) = engines.get(&resolved_model_id) {
799            Ok(engine_instance.category.clone())
800        } else {
801            Err(MistralRsError::EnginePoisoned)
802        }
803    }
804
805    /// Get the maximum supported sequence length for a model, if applicable.
806    pub fn max_sequence_length(
807        &self,
808        model_id: Option<&str>,
809    ) -> Result<Option<usize>, MistralRsError> {
810        let resolved_model_id = match model_id {
811            Some(id) => id.to_string(),
812            None => {
813                let default_lock = self
814                    .default_engine_id
815                    .read()
816                    .map_err(|_| MistralRsError::SenderPoisoned)?;
817                default_lock
818                    .as_ref()
819                    .ok_or(MistralRsError::EnginePoisoned)?
820                    .clone()
821            }
822        };
823
824        let engines = self
825            .engines
826            .read()
827            .map_err(|_| MistralRsError::SenderPoisoned)?;
828        if let Some(engine_instance) = engines.get(&resolved_model_id) {
829            Ok(engine_instance.config.max_seq_len)
830        } else {
831            Err(MistralRsError::EnginePoisoned)
832        }
833    }
834
835    pub fn next_request_id(&self) -> usize {
836        let l = self.next_request_id.lock().unwrap();
837        let last = &mut *l.borrow_mut();
838        let last_v = *last;
839        *last += 1;
840        last_v
841    }
842
843    /// Add a new model engine to the MistralRs instance
844    pub async fn add_model(
845        &self,
846        model_id: String,
847        pipeline: Arc<tokio::sync::Mutex<dyn Pipeline>>,
848        method: SchedulerConfig,
849        config: AddModelConfig,
850    ) -> Result<(), String> {
851        // For hybrid models (Mamba-Attention), force batch_size=1 to prevent state bleeding
852        let method = {
853            let pipeline_guard = pipeline.try_lock().unwrap();
854            if !pipeline_guard.get_metadata().no_kv_cache && pipeline_guard.cache().is_hybrid() {
855                info!(
856                    "Hybrid model detected (Mamba-Attention), enforcing batch_size=1 for correctness"
857                );
858                SchedulerConfig::DefaultScheduler {
859                    method: DefaultSchedulerMethod::Fixed(NonZeroUsize::new(1).unwrap()),
860                }
861            } else {
862                method
863            }
864        };
865
866        let reboot_state = RebootState {
867            pipeline: pipeline.clone(),
868            method: method.clone(),
869            no_kv_cache: config.engine_config.no_kv_cache,
870            no_prefix_cache: config.engine_config.no_prefix_cache,
871            prefix_cache_n: config.engine_config.prefix_cache_n,
872            disable_eos_stop: config.engine_config.disable_eos_stop,
873            throughput_logging_enabled: config.engine_config.throughput_logging_enabled,
874            search_embedding_model: config.engine_config.search_embedding_model,
875            search_callback: config.engine_config.search_callback.clone(),
876            tool_callbacks: config.engine_config.tool_callbacks.clone(),
877            tool_callbacks_with_tools: config.engine_config.tool_callbacks_with_tools.clone(),
878            mcp_client_config: config.mcp_client_config.clone(),
879        };
880
881        let engine_instance =
882            Self::create_engine_instance(pipeline, method, config.engine_config, reboot_state)?;
883
884        let mut engines = self
885            .engines
886            .write()
887            .map_err(|_| "Failed to acquire write lock on engines")?;
888        engines.insert(model_id.clone(), engine_instance);
889
890        // If this is the first model, set it as default
891        if engines.len() == 1 {
892            let mut default_lock = self
893                .default_engine_id
894                .write()
895                .map_err(|_| "Failed to acquire write lock on default_engine_id")?;
896            *default_lock = Some(model_id.clone());
897        }
898
899        Ok(())
900    }
901
902    /// Remove a model engine from the MistralRs instance
903    pub fn remove_model(&self, model_id: &str) -> Result<(), String> {
904        let mut engines = self
905            .engines
906            .write()
907            .map_err(|_| "Failed to acquire write lock on engines")?;
908
909        if engines.len() <= 1 {
910            return Err("Cannot remove the last model from MistralRs".to_string());
911        }
912
913        if let Some(engine_instance) = engines.remove(model_id) {
914            // Send terminate signal to the engine
915            let _ = engine_instance.sender.blocking_send(Request::Terminate);
916
917            // If this was the default engine, set a new default
918            let mut default_lock = self
919                .default_engine_id
920                .write()
921                .map_err(|_| "Failed to acquire write lock on default_engine_id")?;
922            if let Some(ref default_id) = *default_lock {
923                if default_id == model_id {
924                    // Set the first available engine as the new default
925                    *default_lock = engines.keys().next().cloned();
926                }
927            }
928
929            Ok(())
930        } else {
931            Err(format!("Model {model_id} not found"))
932        }
933    }
934
935    /// List all available model IDs
936    pub fn list_models(&self) -> Result<Vec<String>, String> {
937        let engines = self
938            .engines
939            .read()
940            .map_err(|_| "Failed to acquire read lock on engines")?;
941        Ok(engines.keys().cloned().collect())
942    }
943
944    /// Get the current default model ID
945    pub fn get_default_model_id(&self) -> Result<Option<String>, String> {
946        let default_lock = self
947            .default_engine_id
948            .read()
949            .map_err(|_| "Failed to acquire read lock on default_engine_id")?;
950        Ok(default_lock.clone())
951    }
952
953    /// Set the default model ID
954    pub fn set_default_model_id(&self, model_id: &str) -> Result<(), String> {
955        let engines = self
956            .engines
957            .read()
958            .map_err(|_| "Failed to acquire read lock on engines")?;
959        if !engines.contains_key(model_id) {
960            return Err(format!("Model {model_id} not found"));
961        }
962        drop(engines);
963
964        let mut default_lock = self
965            .default_engine_id
966            .write()
967            .map_err(|_| "Failed to acquire write lock on default_engine_id")?;
968        *default_lock = Some(model_id.to_string());
969
970        Ok(())
971    }
972
973    /// Dispatch a request to the appropriate engine based on the model_id in the request
974    pub fn send_request(&self, mut request: Request) -> Result<(), MistralRsError> {
975        let model_id = match &mut request {
976            Request::Normal(normal_req) => normal_req.model_id.as_deref(),
977            _ => None, // Other request types don't specify model_id
978        };
979
980        let sender = self.get_sender(model_id)?;
981        sender
982            .blocking_send(request)
983            .map_err(|_| MistralRsError::SenderPoisoned)
984    }
985
986    pub fn maybe_log_request(this: Arc<Self>, repr: String) {
987        if let Some(file) = &this.log {
988            let mut f = OpenOptions::new()
989                .append(true)
990                .create(true) // Optionally create the file if it doesn't already exist
991                .open(file)
992                .expect("Unable to open file");
993            let time = chrono::offset::Local::now();
994            f.write_all(format!("Request at {time}: {repr}\n\n").as_bytes())
995                .expect("Unable to write data");
996        }
997    }
998
999    pub fn maybe_log_response<T: Serialize>(this: Arc<Self>, resp: &T) {
1000        if let Some(file) = &this.log {
1001            let mut f = OpenOptions::new()
1002                .append(true)
1003                .create(true) // Optionally create the file if it doesn't already exist
1004                .open(file)
1005                .expect("Unable to open file");
1006            let time = chrono::offset::Local::now();
1007            let repr = serde_json::to_string(resp).expect("Serialization of response failed.");
1008            f.write_all(format!("Response at {time}: {repr}\n\n").as_bytes())
1009                .expect("Unable to write data");
1010        }
1011    }
1012
1013    pub fn maybe_log_error(this: Arc<Self>, err: &dyn Error) {
1014        if let Some(file) = &this.log {
1015            let mut f = OpenOptions::new()
1016                .append(true)
1017                .create(true) // Optionally create the file if it doesn't already exist
1018                .open(file)
1019                .expect("Unable to open file");
1020            let time = chrono::offset::Local::now();
1021            f.write_all(format!("Error response at {time}: {err}\n\n").as_bytes())
1022                .expect("Unable to write data");
1023        }
1024    }
1025
1026    /// Get the number of tools available for a specific model (including MCP tools)
1027    pub fn get_tools_count(&self, model_id: Option<&str>) -> Result<usize, String> {
1028        let resolved_model_id = match model_id {
1029            Some(id) => id.to_string(),
1030            None => {
1031                let default_lock = self
1032                    .default_engine_id
1033                    .read()
1034                    .map_err(|_| "Failed to acquire read lock")?;
1035                default_lock
1036                    .as_ref()
1037                    .ok_or("No default engine set")?
1038                    .clone()
1039            }
1040        };
1041
1042        let engines = self
1043            .engines
1044            .read()
1045            .map_err(|_| "Failed to acquire read lock on engines")?;
1046        if let Some(engine_instance) = engines.get(&resolved_model_id) {
1047            Ok(engine_instance.reboot_state.tool_callbacks_with_tools.len())
1048        } else {
1049            Err(format!("Model {resolved_model_id} not found"))
1050        }
1051    }
1052
1053    /// Check if MCP client is configured for a specific model
1054    pub fn has_mcp_client(&self, model_id: Option<&str>) -> Result<bool, String> {
1055        let resolved_model_id = match model_id {
1056            Some(id) => id.to_string(),
1057            None => {
1058                let default_lock = self
1059                    .default_engine_id
1060                    .read()
1061                    .map_err(|_| "Failed to acquire read lock")?;
1062                default_lock
1063                    .as_ref()
1064                    .ok_or("No default engine set")?
1065                    .clone()
1066            }
1067        };
1068
1069        let engines = self
1070            .engines
1071            .read()
1072            .map_err(|_| "Failed to acquire read lock on engines")?;
1073        if let Some(engine_instance) = engines.get(&resolved_model_id) {
1074            Ok(engine_instance.reboot_state.mcp_client_config.is_some())
1075        } else {
1076            Err(format!("Model {resolved_model_id} not found"))
1077        }
1078    }
1079
1080    /// Get config for a specific model
1081    pub fn config(&self, model_id: Option<&str>) -> Result<MistralRsConfig, String> {
1082        let resolved_model_id = match model_id {
1083            Some(id) => id.to_string(),
1084            None => {
1085                let default_lock = self
1086                    .default_engine_id
1087                    .read()
1088                    .map_err(|_| "Failed to acquire read lock")?;
1089                default_lock
1090                    .as_ref()
1091                    .ok_or("No default engine set")?
1092                    .clone()
1093            }
1094        };
1095
1096        let engines = self
1097            .engines
1098            .read()
1099            .map_err(|_| "Failed to acquire read lock on engines")?;
1100        if let Some(engine_instance) = engines.get(&resolved_model_id) {
1101            Ok(engine_instance.config.clone())
1102        } else {
1103            Err(format!("Model {resolved_model_id} not found"))
1104        }
1105    }
1106}