mistralrs_core/
lib.rs

1#![deny(clippy::cast_possible_truncation, clippy::cast_precision_loss)]
2use candle_core::Device;
3use engine::Engine;
4pub use engine::{
5    get_engine_terminate_flag, reset_engine_terminate_flag, should_terminate_engine_sequences,
6    EngineInstruction, SearchEmbeddingModel, ENGINE_INSTRUCTIONS, TERMINATE_ALL_NEXT_STEP,
7};
8use hf_hub::Cache;
9pub use lora::Ordering;
10pub use pipeline::ModelCategory;
11pub use pipeline::Pipeline;
12#[cfg(feature = "pyo3_macros")]
13use pyo3::exceptions::PyValueError;
14use std::collections::HashMap;
15use std::num::NonZeroUsize;
16use std::sync::OnceLock;
17use std::time::Instant;
18use std::{
19    cell::RefCell,
20    error::Error,
21    fs::OpenOptions,
22    io::Write,
23    sync::{atomic::AtomicBool, Arc, Mutex, RwLock},
24    thread::{self, JoinHandle},
25    time::{SystemTime, UNIX_EPOCH},
26};
27use tokio::sync::mpsc::{channel, Sender};
28use tracing::info;
29use tracing::warn;
30
31mod cuda;
32mod device_map;
33mod engine;
34mod lora;
35mod model_loader;
36mod moe;
37mod ops;
38pub use model_loader::{
39    get_auto_device_map_params, get_model_dtype, get_tgt_non_granular_index, LoaderBuilder,
40};
41mod embedding_models;
42mod kv_cache;
43mod search;
44
45mod model_selected;
46pub use model_selected::ModelSelected;
47pub use toml_selector::{get_toml_selected_model_device_map_params, get_toml_selected_model_dtype};
48
49mod amoe;
50mod attention;
51mod diffusion_models;
52pub mod distributed;
53mod gguf;
54pub mod harmony;
55pub mod layers;
56mod layers_masker;
57mod layers_utils;
58pub mod matformer;
59mod models;
60mod paged_attention;
61mod pipeline;
62mod prefix_cacher;
63mod request;
64mod response;
65mod sampler;
66mod scheduler;
67mod sequence;
68mod speech_models;
69pub mod think_tags;
70mod toml_selector;
71mod tools;
72mod topology;
73mod utils;
74mod vision_models;
75mod xlora_models;
76
77pub use amoe::{AnyMoeConfig, AnyMoeExpertType};
78pub use device_map::{
79    DeviceLayerMapMetadata, DeviceMapMetadata, DeviceMapSetting, LayerDeviceMapper,
80};
81pub use gguf::{GGUFArchitecture, GGUF_MULTI_FILE_DELIMITER};
82pub use mistralrs_audio::AudioInput;
83pub use mistralrs_mcp::{
84    CalledFunction, Function, Tool, ToolCallback, ToolCallbackWithTool, ToolType,
85};
86pub use mistralrs_mcp::{
87    McpClient, McpClientConfig, McpServerConfig, McpServerSource, McpToolInfo,
88};
89pub use mistralrs_quant::{IsqType, MULTI_LORA_DELIMITER};
90pub use paged_attention::{MemoryGpuConfig, PagedAttentionConfig, PagedCacheType};
91pub use pipeline::{
92    chat_template::ChatTemplate, parse_isq_value, AdapterPaths, AnyMoeLoader, AnyMoePipeline,
93    AutoDeviceMapParams, AutoLoader, AutoLoaderBuilder, DiffusionGenerationParams, DiffusionLoader,
94    DiffusionLoaderBuilder, DiffusionLoaderType, EmbeddingLoader, EmbeddingLoaderBuilder,
95    EmbeddingLoaderType, EmbeddingModelPaths, EmbeddingSpecificConfig, GGMLLoader,
96    GGMLLoaderBuilder, GGMLSpecificConfig, GGUFLoader, GGUFLoaderBuilder, GGUFSpecificConfig,
97    GemmaLoader, Idefics2Loader, IsqOrganization, LLaVALoader, LLaVANextLoader, LlamaLoader,
98    Loader, LocalModelPaths, LoraAdapterPaths, MistralLoader, MixtralLoader, Modalities, ModelKind,
99    ModelPaths, MultimodalPromptPrefixer, NormalLoader, NormalLoaderBuilder, NormalLoaderType,
100    NormalSpecificConfig, Phi2Loader, Phi3Loader, Phi3VLoader, Qwen2Loader, SpeculativeConfig,
101    SpeculativeLoader, SpeculativePipeline, SpeechLoader, SpeechPipeline, Starcoder2Loader,
102    SupportedModality, TokenSource, VisionLoader, VisionLoaderBuilder, VisionLoaderType,
103    VisionSpecificConfig, UQFF_MULTI_FILE_DELIMITER,
104};
105pub use request::{
106    ApproximateUserLocation, Constraint, DetokenizationRequest, ImageGenerationResponseFormat,
107    LlguidanceGrammar, MessageContent, NormalRequest, ReasoningEffort, Request, RequestMessage,
108    SearchContextSize, TokenizationRequest, WebSearchOptions, WebSearchUserLocation,
109};
110pub use response::*;
111pub use sampler::{
112    CustomLogitsProcessor, DrySamplingParams, SamplingParams, StopTokens, TopLogprob,
113};
114pub use scheduler::{DefaultSchedulerMethod, SchedulerConfig};
115pub use search::{SearchCallback, SearchFunctionParameters, SearchResult};
116use serde::Serialize;
117pub use speech_models::{utils as speech_utils, SpeechGenerationConfig, SpeechLoaderType};
118use tokio::runtime::Runtime;
119use toml_selector::{TomlLoaderArgs, TomlSelector};
120pub use tools::{ToolCallResponse, ToolCallType, ToolCallbacks, ToolChoice};
121pub use topology::{LayerTopology, Topology};
122pub use utils::debug::initialize_logging;
123pub use utils::memory_usage::MemoryUsage;
124pub use utils::normal::{ModelDType, TryIntoDType};
125pub use utils::{paged_attn_supported, using_flash_attn};
126
127// re-export llguidance for easier LlguidanceGrammar construction
128pub use llguidance;
129
/// `true` if `MISTRALRS_DEBUG=1`
///
/// Starts out `false`; NOTE(review): the code that flips it is not in this part of the
/// file, so where and when the environment variable is read should be confirmed there.
pub(crate) static DEBUG: AtomicBool = AtomicBool::new(false);
/// Process-wide slot for an `hf_hub` [`Cache`].
///
/// The slot starts empty and, being a [`OnceLock`], can be initialized at most once.
pub static GLOBAL_HF_CACHE: OnceLock<Cache> = OnceLock::new();
133
/// Configuration for creating an engine instance
///
/// Every field is passed to `Engine::new` when the engine thread is started. The
/// `Default` implementation documents the baseline values.
#[derive(Clone)]
pub struct EngineConfig {
    /// Engine `no_kv_cache` flag (default `false`).
    /// Presumably disables KV caching when `true` - confirm against `Engine`.
    pub no_kv_cache: bool,
    /// Engine `no_prefix_cache` flag (default `false`).
    /// Presumably disables prefix caching when `true` - confirm against `Engine`.
    pub no_prefix_cache: bool,
    /// Prefix cache size parameter handed to the engine (default `16`).
    pub prefix_cache_n: usize,
    /// Engine `disable_eos_stop` flag (default `false`).
    /// Presumably keeps generation going past EOS tokens when `true` - confirm against `Engine`.
    pub disable_eos_stop: bool,
    /// Whether the engine should log throughput (default `true`).
    pub throughput_logging_enabled: bool,
    /// Optional embedding model used for web search.
    pub search_embedding_model: Option<SearchEmbeddingModel>,
    /// Optional custom callback used to gather search results.
    pub search_callback: Option<Arc<SearchCallback>>,
    /// Custom tool callbacks, keyed by tool name.
    pub tool_callbacks: tools::ToolCallbacks,
    /// Custom tool callbacks stored together with their `Tool` definition, keyed by name.
    pub tool_callbacks_with_tools: tools::ToolCallbacksWithTools,
}
147
148impl Default for EngineConfig {
149    fn default() -> Self {
150        Self {
151            no_kv_cache: false,
152            no_prefix_cache: false,
153            prefix_cache_n: 16,
154            disable_eos_stop: false,
155            throughput_logging_enabled: true,
156            search_embedding_model: None,
157            search_callback: None,
158            tool_callbacks: HashMap::new(),
159            tool_callbacks_with_tools: HashMap::new(),
160        }
161    }
162}
163
/// Configuration for adding a model to MistralRs
#[derive(Clone)]
pub struct AddModelConfig {
    /// Settings for the engine that will serve the added model.
    pub engine_config: EngineConfig,
    /// Optional MCP client configuration for the added model; `None` unless set
    /// explicitly (for example through `with_mcp_config`).
    pub mcp_client_config: Option<McpClientConfig>,
}
170
171impl AddModelConfig {
172    pub fn new(engine_config: EngineConfig) -> Self {
173        Self {
174            engine_config,
175            mcp_client_config: None,
176        }
177    }
178
179    pub fn with_mcp_config(mut self, mcp_config: McpClientConfig) -> Self {
180        self.mcp_client_config = Some(mcp_config);
181        self
182    }
183}
184
/// Description of a loaded model, captured from its pipeline when the engine is created.
#[derive(Clone)]
pub struct MistralRsConfig {
    /// Model kind, cloned from the pipeline metadata.
    pub kind: ModelKind,
    /// Device reported by the pipeline.
    pub device: Device,
    /// Category reported by the pipeline.
    pub category: ModelCategory,
    /// Input and output modalities, cloned from the pipeline metadata.
    pub modalities: Modalities,
    /// Maximum sequence length from the pipeline metadata; `None` for the
    /// `Diffusion` and `Speech` categories.
    pub max_seq_len: Option<usize>,
}
193
/// Internal structure to hold per-engine state
struct EngineInstance {
    /// Sending half of the request channel whose receiver is owned by the engine thread.
    sender: Sender<Request>,
    /// Handle of the thread running the engine; `is_finished()` on it is how a dead
    /// engine is detected.
    engine_handler: JoinHandle<()>,
    /// Everything needed to start a replacement engine if this one dies.
    reboot_state: RebootState,
    /// Model description captured when the engine was created.
    config: MistralRsConfig,
    /// Model category captured when the engine was created (same value as `config.category`).
    category: ModelCategory,
}
202
/// The MistralRs struct handles sending requests to multiple engines.
/// It is the core multi-threaded component of mistral.rs, and uses `mpsc`
/// `Sender` and `Receiver` primitives to send and receive requests to the
/// appropriate engine based on model ID.
pub struct MistralRs {
    /// Running engines, keyed by model ID.
    engines: RwLock<HashMap<String, EngineInstance>>,
    /// Model ID used when a caller passes `None`; set to the first engine's ID at construction.
    default_engine_id: RwLock<Option<String>>,
    /// Value supplied through the builder's `with_log` / `with_opt_log`.
    /// NOTE(review): presumably a log destination - its use is not visible here, confirm.
    log: Option<String>,
    /// Name of the pipeline this instance was built with.
    id: String,
    /// Construction time, in whole seconds since the Unix epoch.
    creation_time: u64,
    /// Counter backing `next_request_id`; starts at 1 (the start-up dummy run uses ID 0).
    next_request_id: Mutex<RefCell<usize>>,
}
215
/// Snapshot of the arguments an engine was started with, kept so that a replacement
/// engine can be created from the same pipeline and settings if the engine thread exits.
#[derive(Clone)]
struct RebootState {
    /// Shared pipeline handle the engine runs on.
    pipeline: Arc<tokio::sync::Mutex<dyn Pipeline>>,
    /// Scheduler configuration the engine was started with.
    method: SchedulerConfig,
    /// Mirrors `EngineConfig::no_kv_cache`.
    no_kv_cache: bool,
    /// Mirrors `EngineConfig::no_prefix_cache`.
    no_prefix_cache: bool,
    /// Mirrors `EngineConfig::prefix_cache_n`.
    prefix_cache_n: usize,
    /// Mirrors `EngineConfig::disable_eos_stop`.
    disable_eos_stop: bool,
    /// Mirrors `EngineConfig::throughput_logging_enabled`.
    throughput_logging_enabled: bool,
    /// Mirrors `EngineConfig::search_embedding_model`.
    search_embedding_model: Option<SearchEmbeddingModel>,
    /// Mirrors `EngineConfig::search_callback`.
    search_callback: Option<Arc<search::SearchCallback>>,
    /// Mirrors `EngineConfig::tool_callbacks`.
    tool_callbacks: tools::ToolCallbacks,
    /// Mirrors `EngineConfig::tool_callbacks_with_tools`.
    tool_callbacks_with_tools: tools::ToolCallbacksWithTools,
    /// MCP client configuration the engine was set up with, if any.
    mcp_client_config: Option<McpClientConfig>,
}
231
/// Errors returned by `MistralRs` when looking up or restarting engines.
#[derive(Debug)]
pub enum MistralRsError {
    /// Returned when the requested engine is unknown, when no default engine is set,
    /// when the engines lock is poisoned during a liveness check or reboot, or when a
    /// replacement engine cannot be created.
    EnginePoisoned,
    /// Returned when a lock guarding the default engine ID or the engines map is
    /// poisoned while resolving a model.
    SenderPoisoned,
}
237
238impl std::fmt::Display for MistralRsError {
239    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
240        write!(f, "{:?}", &self)
241    }
242}
243
// Marker impl: the default `Error` methods are enough, so `MistralRsError` carries no source error.
impl std::error::Error for MistralRsError {}
245
#[cfg(feature = "pyo3_macros")]
impl From<MistralRsError> for pyo3::PyErr {
    /// Converts the error into a Python `ValueError` whose message is the error's
    /// `Debug` representation.
    fn from(value: MistralRsError) -> Self {
        let message = format!("{value:?}");
        PyValueError::new_err(message)
    }
}
252
/// The MistralRsBuilder takes the pipeline and a scheduler method and constructs
/// an Engine and a MistralRs instance. The Engine runs on a separate thread, and the MistralRs
/// instance stays on the calling thread.
pub struct MistralRsBuilder {
    /// Shared pipeline the engine will run.
    pipeline: Arc<tokio::sync::Mutex<dyn Pipeline>>,
    /// Scheduler configuration requested by the caller.
    method: SchedulerConfig,
    /// Optional log setting stored on the resulting `MistralRs`.
    log: Option<String>,
    /// `None` means "not set"; resolved to `false` when the instance is built.
    no_kv_cache: Option<bool>,
    /// `None` means "not set"; resolved to `false` when the instance is built.
    no_prefix_cache: Option<bool>,
    /// `None` means "not set"; resolved to `16` when the instance is built.
    prefix_cache_n: Option<usize>,
    /// `None` means "not set"; resolved to `false` when the instance is built.
    disable_eos_stop: Option<bool>,
    /// Whether the engine should log throughput.
    throughput_logging_enabled: bool,
    /// Optional embedding model for web search.
    search_embedding_model: Option<SearchEmbeddingModel>,
    /// Optional custom callback used to gather search results.
    search_callback: Option<Arc<SearchCallback>>,
    /// Custom tool callbacks, keyed by tool name.
    tool_callbacks: tools::ToolCallbacks,
    /// Custom tool callbacks stored with their `Tool` definition, keyed by name.
    /// Tools discovered through the MCP client are merged into this map at build time.
    tool_callbacks_with_tools: tools::ToolCallbacksWithTools,
    /// Optional configuration for connecting to external MCP servers.
    mcp_client_config: Option<McpClientConfig>,
}
271
272impl MistralRsBuilder {
273    /// Creates a new builder with the given pipeline, scheduler method, logging flag,
274    /// and optional embedding model for web search. To override the search callback,
275    /// use `.with_search_callback(...)` on the builder.
276    pub fn new(
277        pipeline: Arc<tokio::sync::Mutex<dyn Pipeline>>,
278        method: SchedulerConfig,
279        throughput_logging: bool,
280        search_embedding_model: Option<SearchEmbeddingModel>,
281    ) -> Self {
282        Self {
283            pipeline,
284            method,
285            log: None,
286            no_kv_cache: None,
287            no_prefix_cache: None,
288            prefix_cache_n: None,
289            disable_eos_stop: None,
290            throughput_logging_enabled: throughput_logging,
291            search_embedding_model,
292            search_callback: None,
293            tool_callbacks: HashMap::new(),
294            tool_callbacks_with_tools: HashMap::new(),
295            mcp_client_config: None,
296        }
297    }
298    pub fn with_log(mut self, log: String) -> Self {
299        self.log = Some(log);
300        self
301    }
302    pub fn with_opt_log(mut self, log: Option<String>) -> Self {
303        self.log = log;
304        self
305    }
306    pub fn with_no_kv_cache(mut self, no_kv_cache: bool) -> Self {
307        self.no_kv_cache = Some(no_kv_cache);
308        self
309    }
310    pub fn with_no_prefix_cache(mut self, no_prefix_cache: bool) -> Self {
311        self.no_prefix_cache = Some(no_prefix_cache);
312        self
313    }
314    pub fn with_prefix_cache_n(mut self, prefix_cache_n: usize) -> Self {
315        self.prefix_cache_n = Some(prefix_cache_n);
316        self
317    }
318    pub fn with_disable_eos_stop(mut self, disable_eos_stop: bool) -> Self {
319        self.disable_eos_stop = Some(disable_eos_stop);
320        self
321    }
322
323    /// Use a custom callback to gather search results.
324    pub fn with_search_callback(mut self, search_callback: Arc<SearchCallback>) -> Self {
325        self.search_callback = Some(search_callback);
326        self
327    }
328
329    /// Register a custom callback for the specified tool name.
330    pub fn with_tool_callback(
331        mut self,
332        name: impl Into<String>,
333        tool_callback: Arc<ToolCallback>,
334    ) -> Self {
335        self.tool_callbacks.insert(name.into(), tool_callback);
336        self
337    }
338
339    /// Register a custom callback with its associated Tool definition. The Tool will be
340    /// automatically added to requests when tool callbacks are active.
341    pub fn with_tool_callback_and_tool(
342        mut self,
343        name: impl Into<String>,
344        tool_callback: Arc<ToolCallback>,
345        tool: Tool,
346    ) -> Self {
347        let name = name.into();
348        self.tool_callbacks_with_tools.insert(
349            name,
350            ToolCallbackWithTool {
351                callback: tool_callback,
352                tool,
353            },
354        );
355        self
356    }
357
358    /// Configure MCP client to connect to external MCP servers.
359    pub fn with_mcp_client(mut self, config: McpClientConfig) -> Self {
360        self.mcp_client_config = Some(config);
361        self
362    }
363
364    pub async fn build(self) -> Arc<MistralRs> {
365        MistralRs::new(self).await
366    }
367}
368
369impl Drop for MistralRs {
370    fn drop(&mut self) {
371        // Terminate all engines
372        if let Ok(engines) = self.engines.read() {
373            for (_, engine) in engines.iter() {
374                // Use try_send instead of blocking_send to avoid runtime panics
375                let _ = engine.sender.try_send(Request::Terminate);
376            }
377        }
378    }
379}
380
381impl MistralRs {
382    /// Create an engine instance with the given configuration
383    fn create_engine_instance(
384        pipeline: Arc<tokio::sync::Mutex<dyn Pipeline>>,
385        method: SchedulerConfig,
386        config: EngineConfig,
387        reboot_state: RebootState,
388    ) -> Result<EngineInstance, String> {
389        let (tx, rx) = channel(10_000);
390
391        let pipeline_guard = pipeline.try_lock().unwrap();
392        let category = pipeline_guard.category();
393        let metadata = pipeline_guard.get_metadata();
394        let kind = metadata.kind.clone();
395        let device = pipeline_guard.device();
396        let modalities = metadata.modalities.clone();
397        let max_seq_len = match &category {
398            ModelCategory::Diffusion | ModelCategory::Speech => None,
399            _ => Some(metadata.max_seq_len),
400        };
401        drop(pipeline_guard);
402
403        info!("Pipeline input modalities are {:?}", &modalities.input);
404        info!("Pipeline output modalities are {:?}", &modalities.output);
405
406        let mistralrs_config = MistralRsConfig {
407            kind,
408            device,
409            category: category.clone(),
410            modalities,
411            max_seq_len,
412        };
413
414        let tx_for_engine = tx.clone();
415        let engine_handler = thread::spawn(move || {
416            #[cfg(feature = "metal")]
417            objc::rc::autoreleasepool(move || {
418                let rt = Runtime::new().unwrap();
419                rt.block_on(async move {
420                    let engine = Engine::new(
421                        tx_for_engine,
422                        rx,
423                        pipeline,
424                        method,
425                        config.no_kv_cache,
426                        config.no_prefix_cache,
427                        config.prefix_cache_n,
428                        config.disable_eos_stop,
429                        config.throughput_logging_enabled,
430                        config.search_embedding_model,
431                        config.search_callback.clone(),
432                        config.tool_callbacks.clone(),
433                        config.tool_callbacks_with_tools.clone(),
434                    )
435                    .expect("Engine creation failed.");
436                    Arc::new(engine).run().await;
437                })
438            });
439
440            #[cfg(not(feature = "metal"))]
441            {
442                let rt = Runtime::new().unwrap();
443                rt.block_on(async move {
444                    let engine = Engine::new(
445                        tx_for_engine,
446                        rx,
447                        pipeline,
448                        method,
449                        config.no_kv_cache,
450                        config.no_prefix_cache,
451                        config.prefix_cache_n,
452                        config.disable_eos_stop,
453                        config.throughput_logging_enabled,
454                        config.search_embedding_model,
455                        config.search_callback.clone(),
456                        config.tool_callbacks.clone(),
457                        config.tool_callbacks_with_tools.clone(),
458                    )
459                    .expect("Engine creation failed.");
460                    Arc::new(engine).run().await;
461                })
462            }
463        });
464
465        Ok(EngineInstance {
466            sender: tx,
467            engine_handler,
468            reboot_state,
469            config: mistralrs_config,
470            category,
471        })
472    }
473
474    async fn new(config: MistralRsBuilder) -> Arc<Self> {
475        let MistralRsBuilder {
476            pipeline,
477            method,
478            log,
479            no_kv_cache,
480            no_prefix_cache,
481            prefix_cache_n,
482            disable_eos_stop,
483            throughput_logging_enabled,
484            search_embedding_model,
485            search_callback,
486            tool_callbacks,
487            mut tool_callbacks_with_tools,
488            mcp_client_config,
489        } = config;
490
491        mistralrs_quant::cublaslt::maybe_init_cublas_lt_wrapper(
492            get_mut_arcmutex!(pipeline).device(),
493        );
494
495        // For hybrid models (Mamba-Attention), force batch_size=1 to prevent state bleeding
496        // Mamba's stateful nature makes batched inference complex; this ensures correctness
497        let method = if !get_mut_arcmutex!(pipeline).get_metadata().no_kv_cache
498            && get_mut_arcmutex!(pipeline).cache().is_hybrid()
499        {
500            info!(
501                "Hybrid model detected (Mamba-Attention), enforcing batch_size=1 for correctness"
502            );
503            SchedulerConfig::DefaultScheduler {
504                method: DefaultSchedulerMethod::Fixed(NonZeroUsize::new(1).unwrap()),
505            }
506        } else {
507            method
508        };
509
510        let no_kv_cache = no_kv_cache.unwrap_or(false);
511        let no_prefix_cache = no_prefix_cache.unwrap_or(false);
512        let prefix_cache_n = prefix_cache_n.unwrap_or(16);
513        let disable_eos_stop = disable_eos_stop.unwrap_or(false);
514
515        // Initialize MCP client if configured
516        if let Some(config) = &mcp_client_config {
517            let mut mcp_client = McpClient::new(config.clone());
518            let total_servers = config.servers.len();
519
520            match mcp_client.initialize().await {
521                Ok(()) => {
522                    let mcp_callbacks_with_tools = mcp_client.get_tool_callbacks_with_tools();
523                    let tools_count = mcp_callbacks_with_tools.len();
524
525                    // Merge MCP tool callbacks with tools into the new collection
526                    for (name, callback_with_tool) in mcp_callbacks_with_tools {
527                        tool_callbacks_with_tools.insert(name.clone(), callback_with_tool.clone());
528                    }
529
530                    if tools_count == 0 {
531                        warn!(
532                            "MCP client initialized but no tools were registered from {} servers",
533                            total_servers
534                        );
535                    } else {
536                        info!(
537                            "MCP client initialized successfully with {} tools from {} servers",
538                            tools_count, total_servers
539                        );
540                    }
541                }
542                Err(e) => {
543                    warn!(
544                        "Failed to initialize MCP client with {} configured servers: {}",
545                        total_servers, e
546                    );
547                    warn!("Continuing without MCP functionality. Check your MCP configuration and server availability.");
548                }
549            }
550        }
551
552        let reboot_state = RebootState {
553            pipeline: pipeline.clone(),
554            method: method.clone(),
555            no_kv_cache,
556            no_prefix_cache,
557            prefix_cache_n,
558            disable_eos_stop,
559            throughput_logging_enabled,
560            search_embedding_model,
561            search_callback: search_callback.clone(),
562            tool_callbacks: tool_callbacks.clone(),
563            tool_callbacks_with_tools: tool_callbacks_with_tools.clone(),
564            mcp_client_config: mcp_client_config.clone(),
565        };
566
567        // Create the engine configuration
568        let engine_config = EngineConfig {
569            no_kv_cache,
570            no_prefix_cache,
571            prefix_cache_n,
572            disable_eos_stop,
573            throughput_logging_enabled,
574            search_embedding_model,
575            search_callback,
576            tool_callbacks,
577            tool_callbacks_with_tools,
578        };
579
580        // Create the engine instance
581        let engine_instance =
582            Self::create_engine_instance(pipeline.clone(), method, engine_config, reboot_state)
583                .expect("Failed to create engine instance");
584
585        let id = pipeline.try_lock().unwrap().name();
586
587        if distributed::is_daemon() {
588            let request_sender = engine_instance.sender.clone();
589
590            if cfg!(feature = "ring") {
591                // Ring daemon replicator
592                distributed::ring_daemon_replicator(request_sender);
593            } else {
594                // NCCL daemon replicator
595                distributed::nccl_daemon_replicator(request_sender);
596            }
597
598            #[allow(clippy::empty_loop)]
599            loop {}
600        }
601
602        // Determine if the current runtime is multi-threaded, as blocking operations are not allowed in single-threaded mode
603        let is_multi_threaded = tokio::runtime::Handle::try_current()
604            .is_ok_and(|h| h.runtime_flavor() != tokio::runtime::RuntimeFlavor::CurrentThread);
605
606        // Do a dummy run
607        if !distributed::is_daemon()
608            && is_multi_threaded
609            && matches!(
610                engine_instance.category,
611                ModelCategory::Text | ModelCategory::Vision { .. }
612            )
613        {
614            let clone_sender = engine_instance.sender.clone();
615            tokio::task::block_in_place(|| {
616                let (tx, mut rx) = channel(1);
617                let req = Request::Normal(Box::new(NormalRequest {
618                    id: 0,
619                    messages: RequestMessage::Completion {
620                        text: "hello".to_string(),
621                        echo_prompt: false,
622                        best_of: None,
623                    },
624                    sampling_params: SamplingParams {
625                        max_len: Some(1),
626                        ..SamplingParams::deterministic()
627                    },
628                    response: tx,
629                    return_logprobs: false,
630                    is_streaming: false,
631                    constraint: Constraint::None,
632                    suffix: None,
633                    tool_choice: None,
634                    tools: None,
635                    logits_processors: None,
636                    return_raw_logits: false,
637                    web_search_options: None,
638                    model_id: None,
639                    truncate_sequence: false,
640                }));
641                info!("Beginning dummy run.");
642                let start = Instant::now();
643                clone_sender.blocking_send(req).unwrap();
644
645                // Drain all responses from the channel until it's closed
646                let mut received_any = false;
647                while let Some(_resp) = rx.blocking_recv() {
648                    received_any = true;
649                }
650
651                if received_any {
652                    let end = Instant::now();
653                    info!(
654                        "Dummy run completed in {}s.",
655                        end.duration_since(start).as_secs_f64()
656                    );
657                } else {
658                    warn!("Dummy run failed!");
659                }
660            });
661        }
662
663        // Create engines map with the first engine
664        let mut engines = HashMap::new();
665        engines.insert(id.clone(), engine_instance);
666
667        Arc::new(Self {
668            engines: RwLock::new(engines),
669            default_engine_id: RwLock::new(Some(id.clone())),
670            log,
671            id,
672            creation_time: SystemTime::now()
673                .duration_since(UNIX_EPOCH)
674                .expect("Time travel has occurred!")
675                .as_secs(),
676            next_request_id: Mutex::new(RefCell::new(1)),
677        })
678    }
679
680    /// Attempts to reboot a specific engine by model_id
681    fn reboot_engine(&self, model_id: &str) -> Result<(), MistralRsError> {
682        let mut engines = self.engines.write().map_err(|_| {
683            tracing::warn!("Couldn't get write lock on engines during reboot attempt");
684            MistralRsError::EnginePoisoned
685        })?;
686
687        if let Some(engine_instance) = engines.get(model_id) {
688            if !engine_instance.engine_handler.is_finished() {
689                tracing::info!("Engine {} already running, returning ok", model_id);
690                return Ok(());
691            }
692
693            let reboot_state = engine_instance.reboot_state.clone();
694            let engine_config = EngineConfig {
695                no_kv_cache: reboot_state.no_kv_cache,
696                no_prefix_cache: reboot_state.no_prefix_cache,
697                prefix_cache_n: reboot_state.prefix_cache_n,
698                disable_eos_stop: reboot_state.disable_eos_stop,
699                throughput_logging_enabled: reboot_state.throughput_logging_enabled,
700                search_embedding_model: reboot_state.search_embedding_model,
701                search_callback: reboot_state.search_callback.clone(),
702                tool_callbacks: reboot_state.tool_callbacks.clone(),
703                tool_callbacks_with_tools: reboot_state.tool_callbacks_with_tools.clone(),
704            };
705            let new_engine_instance = Self::create_engine_instance(
706                reboot_state.pipeline.clone(),
707                reboot_state.method.clone(),
708                engine_config,
709                reboot_state,
710            )
711            .map_err(|e| {
712                tracing::error!("Failed to create new engine instance: {}", e);
713                MistralRsError::EnginePoisoned
714            })?;
715
716            engines.insert(model_id.to_string(), new_engine_instance);
717            tracing::info!("Successfully rebooted engine {}", model_id);
718            Ok(())
719        } else {
720            Err(MistralRsError::EnginePoisoned)
721        }
722    }
723
724    fn engine_dead(&self, model_id: &str) -> Result<bool, MistralRsError> {
725        let engines = self.engines.read().map_err(|_| {
726            tracing::warn!("Couldn't get read lock on engines!");
727            MistralRsError::EnginePoisoned
728        })?;
729
730        if let Some(engine_instance) = engines.get(model_id) {
731            Ok(engine_instance.engine_handler.is_finished())
732        } else {
733            Err(MistralRsError::EnginePoisoned)
734        }
735    }
736
737    /// Get sender for a specific model. If model_id is None, uses default engine.
738    pub fn get_sender(&self, model_id: Option<&str>) -> Result<Sender<Request>, MistralRsError> {
739        let resolved_model_id = match model_id {
740            Some(id) => id.to_string(),
741            None => {
742                let default_lock = self
743                    .default_engine_id
744                    .read()
745                    .map_err(|_| MistralRsError::SenderPoisoned)?;
746                default_lock
747                    .as_ref()
748                    .ok_or(MistralRsError::EnginePoisoned)?
749                    .clone()
750            }
751        };
752
753        if self.engine_dead(&resolved_model_id)? {
754            tracing::warn!("Engine {} is dead, rebooting", resolved_model_id);
755            self.reboot_engine(&resolved_model_id)?
756        }
757
758        let engines = self
759            .engines
760            .read()
761            .map_err(|_| MistralRsError::SenderPoisoned)?;
762        if let Some(engine_instance) = engines.get(&resolved_model_id) {
763            Ok(engine_instance.sender.clone())
764        } else {
765            Err(MistralRsError::EnginePoisoned)
766        }
767    }
768
769    pub fn get_id(&self) -> String {
770        self.id.clone()
771    }
772
773    pub fn get_creation_time(&self) -> u64 {
774        self.creation_time
775    }
776
777    /// Get model category for a specific model. If model_id is None, uses default engine.
778    pub fn get_model_category(
779        &self,
780        model_id: Option<&str>,
781    ) -> Result<ModelCategory, MistralRsError> {
782        let resolved_model_id = match model_id {
783            Some(id) => id.to_string(),
784            None => {
785                let default_lock = self
786                    .default_engine_id
787                    .read()
788                    .map_err(|_| MistralRsError::SenderPoisoned)?;
789                default_lock
790                    .as_ref()
791                    .ok_or(MistralRsError::EnginePoisoned)?
792                    .clone()
793            }
794        };
795
796        let engines = self
797            .engines
798            .read()
799            .map_err(|_| MistralRsError::SenderPoisoned)?;
800        if let Some(engine_instance) = engines.get(&resolved_model_id) {
801            Ok(engine_instance.category.clone())
802        } else {
803            Err(MistralRsError::EnginePoisoned)
804        }
805    }
806
807    /// Get the maximum supported sequence length for a model, if applicable.
808    pub fn max_sequence_length(
809        &self,
810        model_id: Option<&str>,
811    ) -> Result<Option<usize>, MistralRsError> {
812        let resolved_model_id = match model_id {
813            Some(id) => id.to_string(),
814            None => {
815                let default_lock = self
816                    .default_engine_id
817                    .read()
818                    .map_err(|_| MistralRsError::SenderPoisoned)?;
819                default_lock
820                    .as_ref()
821                    .ok_or(MistralRsError::EnginePoisoned)?
822                    .clone()
823            }
824        };
825
826        let engines = self
827            .engines
828            .read()
829            .map_err(|_| MistralRsError::SenderPoisoned)?;
830        if let Some(engine_instance) = engines.get(&resolved_model_id) {
831            Ok(engine_instance.config.max_seq_len)
832        } else {
833            Err(MistralRsError::EnginePoisoned)
834        }
835    }
836
837    pub fn next_request_id(&self) -> usize {
838        let l = self.next_request_id.lock().unwrap();
839        let last = &mut *l.borrow_mut();
840        let last_v = *last;
841        *last += 1;
842        last_v
843    }
844
845    /// Add a new model engine to the MistralRs instance
846    pub async fn add_model(
847        &self,
848        model_id: String,
849        pipeline: Arc<tokio::sync::Mutex<dyn Pipeline>>,
850        method: SchedulerConfig,
851        config: AddModelConfig,
852    ) -> Result<(), String> {
853        // For hybrid models (Mamba-Attention), force batch_size=1 to prevent state bleeding
854        let method = {
855            let pipeline_guard = pipeline.try_lock().unwrap();
856            if !pipeline_guard.get_metadata().no_kv_cache && pipeline_guard.cache().is_hybrid() {
857                info!(
858                    "Hybrid model detected (Mamba-Attention), enforcing batch_size=1 for correctness"
859                );
860                SchedulerConfig::DefaultScheduler {
861                    method: DefaultSchedulerMethod::Fixed(NonZeroUsize::new(1).unwrap()),
862                }
863            } else {
864                method
865            }
866        };
867
868        let reboot_state = RebootState {
869            pipeline: pipeline.clone(),
870            method: method.clone(),
871            no_kv_cache: config.engine_config.no_kv_cache,
872            no_prefix_cache: config.engine_config.no_prefix_cache,
873            prefix_cache_n: config.engine_config.prefix_cache_n,
874            disable_eos_stop: config.engine_config.disable_eos_stop,
875            throughput_logging_enabled: config.engine_config.throughput_logging_enabled,
876            search_embedding_model: config.engine_config.search_embedding_model,
877            search_callback: config.engine_config.search_callback.clone(),
878            tool_callbacks: config.engine_config.tool_callbacks.clone(),
879            tool_callbacks_with_tools: config.engine_config.tool_callbacks_with_tools.clone(),
880            mcp_client_config: config.mcp_client_config.clone(),
881        };
882
883        let engine_instance =
884            Self::create_engine_instance(pipeline, method, config.engine_config, reboot_state)?;
885
886        let mut engines = self
887            .engines
888            .write()
889            .map_err(|_| "Failed to acquire write lock on engines")?;
890        engines.insert(model_id.clone(), engine_instance);
891
892        // If this is the first model, set it as default
893        if engines.len() == 1 {
894            let mut default_lock = self
895                .default_engine_id
896                .write()
897                .map_err(|_| "Failed to acquire write lock on default_engine_id")?;
898            *default_lock = Some(model_id.clone());
899        }
900
901        Ok(())
902    }
903
904    /// Remove a model engine from the MistralRs instance
905    pub fn remove_model(&self, model_id: &str) -> Result<(), String> {
906        let mut engines = self
907            .engines
908            .write()
909            .map_err(|_| "Failed to acquire write lock on engines")?;
910
911        if engines.len() <= 1 {
912            return Err("Cannot remove the last model from MistralRs".to_string());
913        }
914
915        if let Some(engine_instance) = engines.remove(model_id) {
916            // Send terminate signal to the engine
917            let _ = engine_instance.sender.blocking_send(Request::Terminate);
918
919            // If this was the default engine, set a new default
920            let mut default_lock = self
921                .default_engine_id
922                .write()
923                .map_err(|_| "Failed to acquire write lock on default_engine_id")?;
924            if let Some(ref default_id) = *default_lock {
925                if default_id == model_id {
926                    // Set the first available engine as the new default
927                    *default_lock = engines.keys().next().cloned();
928                }
929            }
930
931            Ok(())
932        } else {
933            Err(format!("Model {model_id} not found"))
934        }
935    }
936
937    /// List all available model IDs
938    pub fn list_models(&self) -> Result<Vec<String>, String> {
939        let engines = self
940            .engines
941            .read()
942            .map_err(|_| "Failed to acquire read lock on engines")?;
943        Ok(engines.keys().cloned().collect())
944    }
945
946    /// Get the current default model ID
947    pub fn get_default_model_id(&self) -> Result<Option<String>, String> {
948        let default_lock = self
949            .default_engine_id
950            .read()
951            .map_err(|_| "Failed to acquire read lock on default_engine_id")?;
952        Ok(default_lock.clone())
953    }
954
955    /// Set the default model ID
956    pub fn set_default_model_id(&self, model_id: &str) -> Result<(), String> {
957        let engines = self
958            .engines
959            .read()
960            .map_err(|_| "Failed to acquire read lock on engines")?;
961        if !engines.contains_key(model_id) {
962            return Err(format!("Model {model_id} not found"));
963        }
964        drop(engines);
965
966        let mut default_lock = self
967            .default_engine_id
968            .write()
969            .map_err(|_| "Failed to acquire write lock on default_engine_id")?;
970        *default_lock = Some(model_id.to_string());
971
972        Ok(())
973    }
974
975    /// Dispatch a request to the appropriate engine based on the model_id in the request
976    pub fn send_request(&self, mut request: Request) -> Result<(), MistralRsError> {
977        let model_id = match &mut request {
978            Request::Normal(normal_req) => normal_req.model_id.as_deref(),
979            _ => None, // Other request types don't specify model_id
980        };
981
982        let sender = self.get_sender(model_id)?;
983        sender
984            .blocking_send(request)
985            .map_err(|_| MistralRsError::SenderPoisoned)
986    }
987
988    pub fn maybe_log_request(this: Arc<Self>, repr: String) {
989        if let Some(file) = &this.log {
990            let mut f = OpenOptions::new()
991                .append(true)
992                .create(true) // Optionally create the file if it doesn't already exist
993                .open(file)
994                .expect("Unable to open file");
995            let time = chrono::offset::Local::now();
996            f.write_all(format!("Request at {time}: {repr}\n\n").as_bytes())
997                .expect("Unable to write data");
998        }
999    }
1000
1001    pub fn maybe_log_response<T: Serialize>(this: Arc<Self>, resp: &T) {
1002        if let Some(file) = &this.log {
1003            let mut f = OpenOptions::new()
1004                .append(true)
1005                .create(true) // Optionally create the file if it doesn't already exist
1006                .open(file)
1007                .expect("Unable to open file");
1008            let time = chrono::offset::Local::now();
1009            let repr = serde_json::to_string(resp).expect("Serialization of response failed.");
1010            f.write_all(format!("Response at {time}: {repr}\n\n").as_bytes())
1011                .expect("Unable to write data");
1012        }
1013    }
1014
1015    pub fn maybe_log_error(this: Arc<Self>, err: &dyn Error) {
1016        if let Some(file) = &this.log {
1017            let mut f = OpenOptions::new()
1018                .append(true)
1019                .create(true) // Optionally create the file if it doesn't already exist
1020                .open(file)
1021                .expect("Unable to open file");
1022            let time = chrono::offset::Local::now();
1023            f.write_all(format!("Error response at {time}: {err}\n\n").as_bytes())
1024                .expect("Unable to write data");
1025        }
1026    }
1027
1028    /// Get the number of tools available for a specific model (including MCP tools)
1029    pub fn get_tools_count(&self, model_id: Option<&str>) -> Result<usize, String> {
1030        let resolved_model_id = match model_id {
1031            Some(id) => id.to_string(),
1032            None => {
1033                let default_lock = self
1034                    .default_engine_id
1035                    .read()
1036                    .map_err(|_| "Failed to acquire read lock")?;
1037                default_lock
1038                    .as_ref()
1039                    .ok_or("No default engine set")?
1040                    .clone()
1041            }
1042        };
1043
1044        let engines = self
1045            .engines
1046            .read()
1047            .map_err(|_| "Failed to acquire read lock on engines")?;
1048        if let Some(engine_instance) = engines.get(&resolved_model_id) {
1049            Ok(engine_instance.reboot_state.tool_callbacks_with_tools.len())
1050        } else {
1051            Err(format!("Model {resolved_model_id} not found"))
1052        }
1053    }
1054
1055    /// Check if MCP client is configured for a specific model
1056    pub fn has_mcp_client(&self, model_id: Option<&str>) -> Result<bool, String> {
1057        let resolved_model_id = match model_id {
1058            Some(id) => id.to_string(),
1059            None => {
1060                let default_lock = self
1061                    .default_engine_id
1062                    .read()
1063                    .map_err(|_| "Failed to acquire read lock")?;
1064                default_lock
1065                    .as_ref()
1066                    .ok_or("No default engine set")?
1067                    .clone()
1068            }
1069        };
1070
1071        let engines = self
1072            .engines
1073            .read()
1074            .map_err(|_| "Failed to acquire read lock on engines")?;
1075        if let Some(engine_instance) = engines.get(&resolved_model_id) {
1076            Ok(engine_instance.reboot_state.mcp_client_config.is_some())
1077        } else {
1078            Err(format!("Model {resolved_model_id} not found"))
1079        }
1080    }
1081
1082    /// Get config for a specific model
1083    pub fn config(&self, model_id: Option<&str>) -> Result<MistralRsConfig, String> {
1084        let resolved_model_id = match model_id {
1085            Some(id) => id.to_string(),
1086            None => {
1087                let default_lock = self
1088                    .default_engine_id
1089                    .read()
1090                    .map_err(|_| "Failed to acquire read lock")?;
1091                default_lock
1092                    .as_ref()
1093                    .ok_or("No default engine set")?
1094                    .clone()
1095            }
1096        };
1097
1098        let engines = self
1099            .engines
1100            .read()
1101            .map_err(|_| "Failed to acquire read lock on engines")?;
1102        if let Some(engine_instance) = engines.get(&resolved_model_id) {
1103            Ok(engine_instance.config.clone())
1104        } else {
1105            Err(format!("Model {resolved_model_id} not found"))
1106        }
1107    }
1108}