mistralrs_core/
lib.rs

#![deny(clippy::cast_possible_truncation, clippy::cast_precision_loss)]
use candle_core::Device;
use engine::Engine;
pub use engine::{
    get_engine_terminate_flag, reset_engine_terminate_flag, should_terminate_engine_sequences,
    BertEmbeddingModel, EngineInstruction, ENGINE_INSTRUCTIONS, TERMINATE_ALL_NEXT_STEP,
};
use hf_hub::Cache;
pub use lora::Ordering;
pub use pipeline::ModelCategory;
pub use pipeline::Pipeline;
#[cfg(feature = "pyo3_macros")]
use pyo3::exceptions::PyValueError;
use std::collections::HashMap;
use std::sync::OnceLock;
use std::time::Instant;
use std::{
    cell::RefCell,
    error::Error,
    fs::OpenOptions,
    io::Write,
    sync::{atomic::AtomicBool, Arc, Mutex, RwLock},
    thread::{self, JoinHandle},
    time::{SystemTime, UNIX_EPOCH},
};
use tokio::sync::mpsc::{channel, Sender};
use tracing::info;
use tracing::warn;

mod cuda;
mod device_map;
mod engine;
mod lora;
mod model_loader;
mod ops;
pub use model_loader::{
    get_auto_device_map_params, get_model_dtype, get_tgt_non_granular_index, LoaderBuilder,
};
mod kv_cache;
mod search;

mod model_selected;
pub use model_selected::ModelSelected;
pub use toml_selector::{get_toml_selected_model_device_map_params, get_toml_selected_model_dtype};

mod amoe;
#[cfg(not(any(all(feature = "cuda", target_family = "unix"), feature = "metal")))]
mod dummy_paged_attention;
mod embedding;
mod gguf;
pub mod layers;
mod layers_masker;
mod layers_utils;
pub mod matformer;
mod models;
#[cfg(any(all(feature = "cuda", target_family = "unix"), feature = "metal"))]
mod paged_attention;
#[cfg(not(any(all(feature = "cuda", target_family = "unix"), feature = "metal")))]
use dummy_paged_attention as paged_attention;
mod attention;
mod diffusion_models;
pub mod distributed;
mod pipeline;
mod prefix_cacher;
mod request;
mod response;
mod sampler;
mod scheduler;
mod sequence;
mod speech_models;
mod toml_selector;
mod tools;
mod topology;
mod utils;
mod vision_models;
mod xlora_models;

pub use amoe::{AnyMoeConfig, AnyMoeExpertType};
pub use device_map::{
    DeviceLayerMapMetadata, DeviceMapMetadata, DeviceMapSetting, LayerDeviceMapper,
};
pub use gguf::{GGUFArchitecture, GGUF_MULTI_FILE_DELIMITER};
pub use mistralrs_audio::AudioInput;
pub use mistralrs_mcp::{
    CalledFunction, Function, Tool, ToolCallback, ToolCallbackWithTool, ToolType,
};
pub use mistralrs_mcp::{
    McpClient, McpClientConfig, McpServerConfig, McpServerSource, McpToolInfo,
};
pub use mistralrs_quant::{IsqType, MULTI_LORA_DELIMITER};
pub use paged_attention::{MemoryGpuConfig, PagedAttentionConfig, PagedCacheType};
pub use pipeline::{
    chat_template::ChatTemplate, parse_isq_value, AdapterPaths, AnyMoeLoader, AnyMoePipeline,
    AutoDeviceMapParams, AutoLoader, AutoLoaderBuilder, DiffusionGenerationParams, DiffusionLoader,
    DiffusionLoaderBuilder, DiffusionLoaderType, GGMLLoader, GGMLLoaderBuilder, GGMLSpecificConfig,
    GGUFLoader, GGUFLoaderBuilder, GGUFSpecificConfig, GemmaLoader, Idefics2Loader,
    IsqOrganization, LLaVALoader, LLaVANextLoader, LlamaLoader, Loader, LocalModelPaths,
    LoraAdapterPaths, MistralLoader, MixtralLoader, Modalities, ModelKind, ModelPaths,
    MultimodalPromptPrefixer, NormalLoader, NormalLoaderBuilder, NormalLoaderType,
    NormalSpecificConfig, Phi2Loader, Phi3Loader, Phi3VLoader, Qwen2Loader, SpeculativeConfig,
    SpeculativeLoader, SpeculativePipeline, SpeechLoader, SpeechPipeline, Starcoder2Loader,
    SupportedModality, TokenSource, VisionLoader, VisionLoaderBuilder, VisionLoaderType,
    VisionSpecificConfig, UQFF_MULTI_FILE_DELIMITER,
};
pub use request::{
    ApproximateUserLocation, Constraint, DetokenizationRequest, ImageGenerationResponseFormat,
    LlguidanceGrammar, MessageContent, NormalRequest, Request, RequestMessage, SearchContextSize,
    TokenizationRequest, WebSearchOptions, WebSearchUserLocation,
};
pub use response::*;
pub use sampler::{
    CustomLogitsProcessor, DrySamplingParams, SamplingParams, StopTokens, TopLogprob,
};
pub use scheduler::{DefaultSchedulerMethod, SchedulerConfig};
pub use search::{SearchCallback, SearchFunctionParameters, SearchResult};
use serde::Serialize;
pub use speech_models::{utils as speech_utils, SpeechGenerationConfig, SpeechLoaderType};
use tokio::runtime::Runtime;
use toml_selector::{TomlLoaderArgs, TomlSelector};
pub use tools::{ToolCallResponse, ToolCallType, ToolCallbacks, ToolChoice};
pub use topology::{LayerTopology, Topology};
pub use utils::debug::initialize_logging;
pub use utils::memory_usage::MemoryUsage;
pub use utils::normal::{ModelDType, TryIntoDType};
pub use utils::{paged_attn_supported, using_flash_attn};

// re-export llguidance for easier LlguidanceGrammar construction
pub use llguidance;

/// `true` if `MISTRALRS_DEBUG=1`
pub(crate) static DEBUG: AtomicBool = AtomicBool::new(false);
pub static GLOBAL_HF_CACHE: OnceLock<Cache> = OnceLock::new();

/// Configuration for creating an engine instance
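///
/// A minimal sketch of overriding a single field while keeping the defaults
/// for the rest (uses only the `Default` impl defined below):
///
/// ```ignore
/// let engine_config = EngineConfig {
///     throughput_logging_enabled: false,
///     ..Default::default()
/// };
/// ```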
#[derive(Clone)]
pub struct EngineConfig {
    pub truncate_sequence: bool,
    pub no_kv_cache: bool,
    pub no_prefix_cache: bool,
    pub prefix_cache_n: usize,
    pub disable_eos_stop: bool,
    pub throughput_logging_enabled: bool,
    pub search_embedding_model: Option<BertEmbeddingModel>,
    pub search_callback: Option<Arc<SearchCallback>>,
    pub tool_callbacks: tools::ToolCallbacks,
    pub tool_callbacks_with_tools: tools::ToolCallbacksWithTools,
}

impl Default for EngineConfig {
    fn default() -> Self {
        Self {
            truncate_sequence: false,
            no_kv_cache: false,
            no_prefix_cache: false,
            prefix_cache_n: 16,
            disable_eos_stop: false,
            throughput_logging_enabled: true,
            search_embedding_model: None,
            search_callback: None,
            tool_callbacks: HashMap::new(),
            tool_callbacks_with_tools: HashMap::new(),
        }
    }
}

/// Configuration for adding a model to MistralRs
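///
/// A minimal sketch, assuming an `McpClientConfig` value named `mcp_config`
/// already exists:
///
/// ```ignore
/// let add_config = AddModelConfig::new(EngineConfig::default()).with_mcp_config(mcp_config);
/// ```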
#[derive(Clone)]
pub struct AddModelConfig {
    pub engine_config: EngineConfig,
    pub mcp_client_config: Option<McpClientConfig>,
}

impl AddModelConfig {
    pub fn new(engine_config: EngineConfig) -> Self {
        Self {
            engine_config,
            mcp_client_config: None,
        }
    }

    pub fn with_mcp_config(mut self, mcp_config: McpClientConfig) -> Self {
        self.mcp_client_config = Some(mcp_config);
        self
    }
}

#[derive(Clone)]
pub struct MistralRsConfig {
    pub kind: ModelKind,
    pub device: Device,
    pub category: ModelCategory,
    pub modalities: Modalities,
}

/// Internal structure to hold per-engine state
struct EngineInstance {
    sender: Sender<Request>,
    engine_handler: JoinHandle<()>,
    reboot_state: RebootState,
    config: MistralRsConfig,
    category: ModelCategory,
}

/// The MistralRs struct handles sending requests to multiple engines.
/// It is the core multi-threaded component of mistral.rs, and uses `mpsc`
/// `Sender` and `Receiver` primitives to dispatch each request to the
/// appropriate engine based on its model ID.
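///
/// A minimal sketch of routing a request to the default engine, assuming a
/// built `MistralRs` instance named `mistralrs` and an already constructed
/// `NormalRequest` named `req`:
///
/// ```ignore
/// let sender = mistralrs.get_sender(None)?;
/// sender.blocking_send(Request::Normal(Box::new(req)))?;
/// ```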
pub struct MistralRs {
    engines: RwLock<HashMap<String, EngineInstance>>,
    default_engine_id: RwLock<Option<String>>,
    log: Option<String>,
    id: String,
    creation_time: u64,
    next_request_id: Mutex<RefCell<usize>>,
}

#[derive(Clone)]
struct RebootState {
    pipeline: Arc<tokio::sync::Mutex<dyn Pipeline>>,
    method: SchedulerConfig,
    truncate_sequence: bool,
    no_kv_cache: bool,
    no_prefix_cache: bool,
    prefix_cache_n: usize,
    disable_eos_stop: bool,
    throughput_logging_enabled: bool,
    search_embedding_model: Option<BertEmbeddingModel>,
    search_callback: Option<Arc<search::SearchCallback>>,
    tool_callbacks: tools::ToolCallbacks,
    tool_callbacks_with_tools: tools::ToolCallbacksWithTools,
    mcp_client_config: Option<McpClientConfig>,
}

#[derive(Debug)]
pub enum MistralRsError {
    EnginePoisoned,
    SenderPoisoned,
}

impl std::fmt::Display for MistralRsError {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        write!(f, "{:?}", &self)
    }
}

impl std::error::Error for MistralRsError {}

#[cfg(feature = "pyo3_macros")]
impl From<MistralRsError> for pyo3::PyErr {
    fn from(value: MistralRsError) -> Self {
        PyValueError::new_err(format!("{value:?}"))
    }
}

/// The MistralRsBuilder takes the pipeline and a scheduler method and constructs
/// an Engine and a MistralRs instance. The Engine runs on a separate thread, and the MistralRs
/// instance stays on the calling thread.
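///
/// A minimal sketch, assuming a loaded `pipeline` and a `SchedulerConfig`
/// named `method` are already available:
///
/// ```ignore
/// let mistralrs = MistralRsBuilder::new(pipeline, method, false, None)
///     .with_log("requests.log".to_string())
///     .build()
///     .await;
/// ```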
pub struct MistralRsBuilder {
    pipeline: Arc<tokio::sync::Mutex<dyn Pipeline>>,
    method: SchedulerConfig,
    log: Option<String>,
    truncate_sequence: Option<bool>,
    no_kv_cache: Option<bool>,
    no_prefix_cache: Option<bool>,
    prefix_cache_n: Option<usize>,
    disable_eos_stop: Option<bool>,
    throughput_logging_enabled: bool,
    search_embedding_model: Option<BertEmbeddingModel>,
    search_callback: Option<Arc<SearchCallback>>,
    tool_callbacks: tools::ToolCallbacks,
    tool_callbacks_with_tools: tools::ToolCallbacksWithTools,
    mcp_client_config: Option<McpClientConfig>,
}

impl MistralRsBuilder {
    /// Creates a new builder with the given pipeline, scheduler method, logging flag,
    /// and optional embedding model for web search. To override the search callback,
    /// use `.with_search_callback(...)` on the builder.
    pub fn new(
        pipeline: Arc<tokio::sync::Mutex<dyn Pipeline>>,
        method: SchedulerConfig,
        throughput_logging: bool,
        search_embedding_model: Option<BertEmbeddingModel>,
    ) -> Self {
        Self {
            pipeline,
            method,
            log: None,
            truncate_sequence: None,
            no_kv_cache: None,
            no_prefix_cache: None,
            prefix_cache_n: None,
            disable_eos_stop: None,
            throughput_logging_enabled: throughput_logging,
            search_embedding_model,
            search_callback: None,
            tool_callbacks: HashMap::new(),
            tool_callbacks_with_tools: HashMap::new(),
            mcp_client_config: None,
        }
    }
    pub fn with_log(mut self, log: String) -> Self {
        self.log = Some(log);
        self
    }
    pub fn with_opt_log(mut self, log: Option<String>) -> Self {
        self.log = log;
        self
    }
    pub fn with_truncate_sequence(mut self, truncate_sequence: bool) -> Self {
        self.truncate_sequence = Some(truncate_sequence);
        self
    }
    pub fn with_no_kv_cache(mut self, no_kv_cache: bool) -> Self {
        self.no_kv_cache = Some(no_kv_cache);
        self
    }
    pub fn with_no_prefix_cache(mut self, no_prefix_cache: bool) -> Self {
        self.no_prefix_cache = Some(no_prefix_cache);
        self
    }
    pub fn with_prefix_cache_n(mut self, prefix_cache_n: usize) -> Self {
        self.prefix_cache_n = Some(prefix_cache_n);
        self
    }
    pub fn with_disable_eos_stop(mut self, disable_eos_stop: bool) -> Self {
        self.disable_eos_stop = Some(disable_eos_stop);
        self
    }

    /// Use a custom callback to gather search results.
    pub fn with_search_callback(mut self, search_callback: Arc<SearchCallback>) -> Self {
        self.search_callback = Some(search_callback);
        self
    }

    /// Register a custom callback for the specified tool name.
    pub fn with_tool_callback(
        mut self,
        name: impl Into<String>,
        tool_callback: Arc<ToolCallback>,
    ) -> Self {
        self.tool_callbacks.insert(name.into(), tool_callback);
        self
    }

    /// Register a custom callback with its associated Tool definition. The Tool will be
    /// automatically added to requests when tool callbacks are active.
    pub fn with_tool_callback_and_tool(
        mut self,
        name: impl Into<String>,
        tool_callback: Arc<ToolCallback>,
        tool: Tool,
    ) -> Self {
        let name = name.into();
        self.tool_callbacks_with_tools.insert(
            name,
            ToolCallbackWithTool {
                callback: tool_callback,
                tool,
            },
        );
        self
    }

    /// Configure MCP client to connect to external MCP servers.
    pub fn with_mcp_client(mut self, config: McpClientConfig) -> Self {
        self.mcp_client_config = Some(config);
        self
    }

    pub async fn build(self) -> Arc<MistralRs> {
        MistralRs::new(self).await
    }
}

impl Drop for MistralRs {
    fn drop(&mut self) {
        // Terminate all engines
        if let Ok(engines) = self.engines.read() {
            for (_, engine) in engines.iter() {
                // Use try_send instead of blocking_send to avoid runtime panics
                let _ = engine.sender.try_send(Request::Terminate);
            }
        }
    }
}

impl MistralRs {
    /// Create an engine instance with the given configuration
    fn create_engine_instance(
        pipeline: Arc<tokio::sync::Mutex<dyn Pipeline>>,
        method: SchedulerConfig,
        config: EngineConfig,
        reboot_state: RebootState,
    ) -> Result<EngineInstance, String> {
        let (tx, rx) = channel(10_000);

        let category = pipeline.try_lock().unwrap().category();
        let kind = pipeline.try_lock().unwrap().get_metadata().kind.clone();
        let device = pipeline.try_lock().unwrap().device();
        let modalities = pipeline
            .try_lock()
            .unwrap()
            .get_metadata()
            .modalities
            .clone();

        info!("Pipeline input modalities are {:?}", &modalities.input);
        info!("Pipeline output modalities are {:?}", &modalities.output);

        let mistralrs_config = MistralRsConfig {
            kind,
            device,
            category: category.clone(),
            modalities,
        };

        let engine_handler = thread::spawn(move || {
            #[cfg(feature = "metal")]
            objc::rc::autoreleasepool(move || {
                let rt = Runtime::new().unwrap();
                rt.block_on(async move {
                    let engine = Engine::new(
                        rx,
                        pipeline,
                        method,
                        config.truncate_sequence,
                        config.no_kv_cache,
                        config.no_prefix_cache,
                        config.prefix_cache_n,
                        config.disable_eos_stop,
                        config.throughput_logging_enabled,
                        config.search_embedding_model,
                        config.search_callback.clone(),
                        config.tool_callbacks.clone(),
                        config.tool_callbacks_with_tools.clone(),
                    )
                    .expect("Engine creation failed.");
                    Arc::new(engine).run().await;
                })
            });

            #[cfg(not(feature = "metal"))]
            {
                let rt = Runtime::new().unwrap();
                rt.block_on(async move {
                    let engine = Engine::new(
                        rx,
                        pipeline,
                        method,
                        config.truncate_sequence,
                        config.no_kv_cache,
                        config.no_prefix_cache,
                        config.prefix_cache_n,
                        config.disable_eos_stop,
                        config.throughput_logging_enabled,
                        config.search_embedding_model,
                        config.search_callback.clone(),
                        config.tool_callbacks.clone(),
                        config.tool_callbacks_with_tools.clone(),
                    )
                    .expect("Engine creation failed.");
                    Arc::new(engine).run().await;
                })
            }
        });

        Ok(EngineInstance {
            sender: tx,
            engine_handler,
            reboot_state,
            config: mistralrs_config,
            category,
        })
    }

    async fn new(config: MistralRsBuilder) -> Arc<Self> {
        let MistralRsBuilder {
            pipeline,
            method,
            log,
            truncate_sequence,
            no_kv_cache,
            no_prefix_cache,
            prefix_cache_n,
            disable_eos_stop,
            throughput_logging_enabled,
            search_embedding_model,
            search_callback,
            tool_callbacks,
            mut tool_callbacks_with_tools,
            mcp_client_config,
        } = config;

        mistralrs_quant::cublaslt::maybe_init_cublas_lt_wrapper(
            get_mut_arcmutex!(pipeline).device(),
        );

        let truncate_sequence = truncate_sequence.unwrap_or(false);
        let no_kv_cache = no_kv_cache.unwrap_or(false);
        let no_prefix_cache = no_prefix_cache.unwrap_or(false);
        let prefix_cache_n = prefix_cache_n.unwrap_or(16);
        let disable_eos_stop = disable_eos_stop.unwrap_or(false);

        // Initialize MCP client if configured
        if let Some(config) = &mcp_client_config {
            let mut mcp_client = McpClient::new(config.clone());
            let total_servers = config.servers.len();

            match mcp_client.initialize().await {
                Ok(()) => {
                    let mcp_callbacks_with_tools = mcp_client.get_tool_callbacks_with_tools();
                    let tools_count = mcp_callbacks_with_tools.len();

                    // Merge MCP tool callbacks with tools into the new collection
                    for (name, callback_with_tool) in mcp_callbacks_with_tools {
                        tool_callbacks_with_tools.insert(name.clone(), callback_with_tool.clone());
                    }

                    if tools_count == 0 {
                        warn!(
                            "MCP client initialized but no tools were registered from {} servers",
                            total_servers
                        );
                    } else {
                        info!(
                            "MCP client initialized successfully with {} tools from {} servers",
                            tools_count, total_servers
                        );
                    }
                }
                Err(e) => {
                    warn!(
                        "Failed to initialize MCP client with {} configured servers: {}",
                        total_servers, e
                    );
                    warn!("Continuing without MCP functionality. Check your MCP configuration and server availability.");
                }
            }
        }

        let reboot_state = RebootState {
            pipeline: pipeline.clone(),
            method: method.clone(),
            truncate_sequence,
            no_kv_cache,
            no_prefix_cache,
            prefix_cache_n,
            disable_eos_stop,
            throughput_logging_enabled,
            search_embedding_model: search_embedding_model.clone(),
            search_callback: search_callback.clone(),
            tool_callbacks: tool_callbacks.clone(),
            tool_callbacks_with_tools: tool_callbacks_with_tools.clone(),
            mcp_client_config: mcp_client_config.clone(),
        };

        // Create the engine configuration
        let engine_config = EngineConfig {
            truncate_sequence,
            no_kv_cache,
            no_prefix_cache,
            prefix_cache_n,
            disable_eos_stop,
            throughput_logging_enabled,
            search_embedding_model,
            search_callback,
            tool_callbacks,
            tool_callbacks_with_tools,
        };

        // Create the engine instance
        let engine_instance =
            Self::create_engine_instance(pipeline.clone(), method, engine_config, reboot_state)
                .expect("Failed to create engine instance");

        let id = pipeline.try_lock().unwrap().name();

        if distributed::is_daemon() {
            let request_sender = engine_instance.sender.clone();

            if cfg!(feature = "ring") {
                // Ring daemon replicator
                distributed::ring_daemon_replicator(request_sender);
            } else {
                // NCCL daemon replicator
                distributed::nccl_daemon_replicator(request_sender);
            }

            #[allow(clippy::empty_loop)]
            loop {}
        }

        // Determine if the current runtime is multi-threaded, as blocking operations are not allowed in single-threaded mode
        let is_multi_threaded = tokio::runtime::Handle::try_current()
            .is_ok_and(|h| h.runtime_flavor() != tokio::runtime::RuntimeFlavor::CurrentThread);

        // Do a dummy run
        if !distributed::is_daemon()
            && is_multi_threaded
            && matches!(
                engine_instance.category,
                ModelCategory::Text | ModelCategory::Vision { .. }
            )
        {
            let clone_sender = engine_instance.sender.clone();
            tokio::task::block_in_place(|| {
                let (tx, mut rx) = channel(1);
                let req = Request::Normal(Box::new(NormalRequest {
                    id: 0,
                    messages: RequestMessage::Completion {
                        text: "hello".to_string(),
                        echo_prompt: false,
                        best_of: None,
                    },
                    sampling_params: SamplingParams {
                        max_len: Some(1),
                        ..SamplingParams::deterministic()
                    },
                    response: tx,
                    return_logprobs: false,
                    is_streaming: false,
                    constraint: Constraint::None,
                    suffix: None,
                    tool_choice: None,
                    tools: None,
                    logits_processors: None,
                    return_raw_logits: false,
                    web_search_options: None,
                    model_id: None,
                }));
                info!("Beginning dummy run.");
                let start = Instant::now();
                clone_sender.blocking_send(req).unwrap();

                // Drain all responses from the channel until it's closed
                let mut received_any = false;
                while let Some(_resp) = rx.blocking_recv() {
                    received_any = true;
                }

                if received_any {
                    let end = Instant::now();
                    info!(
                        "Dummy run completed in {}s.",
                        end.duration_since(start).as_secs_f64()
                    );
                } else {
                    warn!("Dummy run failed!");
                }
            });
        }

        // Create engines map with the first engine
        let mut engines = HashMap::new();
        engines.insert(id.clone(), engine_instance);

        Arc::new(Self {
            engines: RwLock::new(engines),
            default_engine_id: RwLock::new(Some(id.clone())),
            log,
            id,
            creation_time: SystemTime::now()
                .duration_since(UNIX_EPOCH)
                .expect("Time travel has occurred!")
                .as_secs(),
            next_request_id: Mutex::new(RefCell::new(1)),
        })
    }

    /// Attempts to reboot a specific engine by model_id
    fn reboot_engine(&self, model_id: &str) -> Result<(), MistralRsError> {
        let mut engines = self.engines.write().map_err(|_| {
            tracing::warn!("Couldn't get write lock on engines during reboot attempt");
            MistralRsError::EnginePoisoned
        })?;

        if let Some(engine_instance) = engines.get(model_id) {
            if !engine_instance.engine_handler.is_finished() {
                tracing::info!("Engine {} already running, returning ok", model_id);
                return Ok(());
            }

            let reboot_state = engine_instance.reboot_state.clone();
            let engine_config = EngineConfig {
                truncate_sequence: reboot_state.truncate_sequence,
                no_kv_cache: reboot_state.no_kv_cache,
                no_prefix_cache: reboot_state.no_prefix_cache,
                prefix_cache_n: reboot_state.prefix_cache_n,
                disable_eos_stop: reboot_state.disable_eos_stop,
                throughput_logging_enabled: reboot_state.throughput_logging_enabled,
                search_embedding_model: reboot_state.search_embedding_model.clone(),
                search_callback: reboot_state.search_callback.clone(),
                tool_callbacks: reboot_state.tool_callbacks.clone(),
                tool_callbacks_with_tools: reboot_state.tool_callbacks_with_tools.clone(),
            };
            let new_engine_instance = Self::create_engine_instance(
                reboot_state.pipeline.clone(),
                reboot_state.method.clone(),
                engine_config,
                reboot_state,
            )
            .map_err(|e| {
                tracing::error!("Failed to create new engine instance: {}", e);
                MistralRsError::EnginePoisoned
            })?;

            engines.insert(model_id.to_string(), new_engine_instance);
            tracing::info!("Successfully rebooted engine {}", model_id);
            Ok(())
        } else {
            Err(MistralRsError::EnginePoisoned)
        }
    }

    fn engine_dead(&self, model_id: &str) -> Result<bool, MistralRsError> {
        let engines = self.engines.read().map_err(|_| {
            tracing::warn!("Couldn't get read lock on engines!");
            MistralRsError::EnginePoisoned
        })?;

        if let Some(engine_instance) = engines.get(model_id) {
            Ok(engine_instance.engine_handler.is_finished())
        } else {
            Err(MistralRsError::EnginePoisoned)
        }
    }

    /// Get sender for a specific model. If model_id is None, uses default engine.
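    ///
    /// A minimal sketch, assuming a built `MistralRs` instance named `mistralrs`
    /// and an already constructed `Request` named `req`:
    ///
    /// ```ignore
    /// let sender = mistralrs.get_sender(Some("my-model"))?;
    /// sender.blocking_send(req)?;
    /// ```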
    pub fn get_sender(&self, model_id: Option<&str>) -> Result<Sender<Request>, MistralRsError> {
        let resolved_model_id = match model_id {
            Some(id) => id.to_string(),
            None => {
                let default_lock = self
                    .default_engine_id
                    .read()
                    .map_err(|_| MistralRsError::SenderPoisoned)?;
                default_lock
                    .as_ref()
                    .ok_or(MistralRsError::EnginePoisoned)?
                    .clone()
            }
        };

        if self.engine_dead(&resolved_model_id)? {
            tracing::warn!("Engine {} is dead, rebooting", resolved_model_id);
            self.reboot_engine(&resolved_model_id)?
        }

        let engines = self
            .engines
            .read()
            .map_err(|_| MistralRsError::SenderPoisoned)?;
        if let Some(engine_instance) = engines.get(&resolved_model_id) {
            Ok(engine_instance.sender.clone())
        } else {
            Err(MistralRsError::EnginePoisoned)
        }
    }

    pub fn get_id(&self) -> String {
        self.id.clone()
    }

    pub fn get_creation_time(&self) -> u64 {
        self.creation_time
    }

    /// Get model category for a specific model. If model_id is None, uses default engine.
    pub fn get_model_category(
        &self,
        model_id: Option<&str>,
    ) -> Result<ModelCategory, MistralRsError> {
        let resolved_model_id = match model_id {
            Some(id) => id.to_string(),
            None => {
                let default_lock = self
                    .default_engine_id
                    .read()
                    .map_err(|_| MistralRsError::SenderPoisoned)?;
                default_lock
                    .as_ref()
                    .ok_or(MistralRsError::EnginePoisoned)?
                    .clone()
            }
        };

        let engines = self
            .engines
            .read()
            .map_err(|_| MistralRsError::SenderPoisoned)?;
        if let Some(engine_instance) = engines.get(&resolved_model_id) {
            Ok(engine_instance.category.clone())
        } else {
            Err(MistralRsError::EnginePoisoned)
        }
    }

    pub fn next_request_id(&self) -> usize {
        let l = self.next_request_id.lock().unwrap();
        let last = &mut *l.borrow_mut();
        let last_v = *last;
        *last += 1;
        last_v
    }

    /// Add a new model engine to the MistralRs instance
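    ///
    /// A minimal sketch, assuming a second `pipeline` and a `SchedulerConfig`
    /// named `method` have already been built:
    ///
    /// ```ignore
    /// mistralrs
    ///     .add_model(
    ///         "second-model".to_string(),
    ///         pipeline,
    ///         method,
    ///         AddModelConfig::new(EngineConfig::default()),
    ///     )
    ///     .await?;
    /// ```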
    pub async fn add_model(
        &self,
        model_id: String,
        pipeline: Arc<tokio::sync::Mutex<dyn Pipeline>>,
        method: SchedulerConfig,
        config: AddModelConfig,
    ) -> Result<(), String> {
        let reboot_state = RebootState {
            pipeline: pipeline.clone(),
            method: method.clone(),
            truncate_sequence: config.engine_config.truncate_sequence,
            no_kv_cache: config.engine_config.no_kv_cache,
            no_prefix_cache: config.engine_config.no_prefix_cache,
            prefix_cache_n: config.engine_config.prefix_cache_n,
            disable_eos_stop: config.engine_config.disable_eos_stop,
            throughput_logging_enabled: config.engine_config.throughput_logging_enabled,
            search_embedding_model: config.engine_config.search_embedding_model.clone(),
            search_callback: config.engine_config.search_callback.clone(),
            tool_callbacks: config.engine_config.tool_callbacks.clone(),
            tool_callbacks_with_tools: config.engine_config.tool_callbacks_with_tools.clone(),
            mcp_client_config: config.mcp_client_config.clone(),
        };

        let engine_instance =
            Self::create_engine_instance(pipeline, method, config.engine_config, reboot_state)?;

        let mut engines = self
            .engines
            .write()
            .map_err(|_| "Failed to acquire write lock on engines")?;
        engines.insert(model_id.clone(), engine_instance);

        // If this is the first model, set it as default
        if engines.len() == 1 {
            let mut default_lock = self
                .default_engine_id
                .write()
                .map_err(|_| "Failed to acquire write lock on default_engine_id")?;
            *default_lock = Some(model_id.clone());
        }

        Ok(())
    }

    /// Remove a model engine from the MistralRs instance
    pub fn remove_model(&self, model_id: &str) -> Result<(), String> {
        let mut engines = self
            .engines
            .write()
            .map_err(|_| "Failed to acquire write lock on engines")?;

        if engines.len() <= 1 {
            return Err("Cannot remove the last model from MistralRs".to_string());
        }

        if let Some(engine_instance) = engines.remove(model_id) {
            // Send terminate signal to the engine
            let _ = engine_instance.sender.blocking_send(Request::Terminate);

            // If this was the default engine, set a new default
            let mut default_lock = self
                .default_engine_id
                .write()
                .map_err(|_| "Failed to acquire write lock on default_engine_id")?;
            if let Some(ref default_id) = *default_lock {
                if default_id == model_id {
                    // Set the first available engine as the new default
                    *default_lock = engines.keys().next().cloned();
                }
            }

            Ok(())
        } else {
            Err(format!("Model {model_id} not found"))
        }
    }

    /// List all available model IDs
    pub fn list_models(&self) -> Result<Vec<String>, String> {
        let engines = self
            .engines
            .read()
            .map_err(|_| "Failed to acquire read lock on engines")?;
        Ok(engines.keys().cloned().collect())
    }

    /// Get the current default model ID
    pub fn get_default_model_id(&self) -> Result<Option<String>, String> {
        let default_lock = self
            .default_engine_id
            .read()
            .map_err(|_| "Failed to acquire read lock on default_engine_id")?;
        Ok(default_lock.clone())
    }

    /// Set the default model ID
    pub fn set_default_model_id(&self, model_id: &str) -> Result<(), String> {
        let engines = self
            .engines
            .read()
            .map_err(|_| "Failed to acquire read lock on engines")?;
        if !engines.contains_key(model_id) {
            return Err(format!("Model {model_id} not found"));
        }
        drop(engines);

        let mut default_lock = self
            .default_engine_id
            .write()
            .map_err(|_| "Failed to acquire write lock on default_engine_id")?;
        *default_lock = Some(model_id.to_string());

        Ok(())
    }

    /// Dispatch a request to the appropriate engine based on the model_id in the request
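    ///
    /// A minimal sketch, assuming a `NormalRequest` named `req` with its
    /// `model_id` field set to the target model (or `None` for the default):
    ///
    /// ```ignore
    /// mistralrs.send_request(Request::Normal(Box::new(req)))?;
    /// ```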
    pub fn send_request(&self, mut request: Request) -> Result<(), MistralRsError> {
        let model_id = match &mut request {
            Request::Normal(normal_req) => normal_req.model_id.as_deref(),
            _ => None, // Other request types don't specify model_id
        };

        let sender = self.get_sender(model_id)?;
        sender
            .blocking_send(request)
            .map_err(|_| MistralRsError::SenderPoisoned)
    }

    pub fn maybe_log_request(this: Arc<Self>, repr: String) {
        if let Some(file) = &this.log {
            let mut f = OpenOptions::new()
                .append(true)
                .create(true) // Optionally create the file if it doesn't already exist
                .open(file)
                .expect("Unable to open file");
            let time = chrono::offset::Local::now();
            f.write_all(format!("Request at {time}: {repr}\n\n").as_bytes())
                .expect("Unable to write data");
        }
    }

    pub fn maybe_log_response<T: Serialize>(this: Arc<Self>, resp: &T) {
        if let Some(file) = &this.log {
            let mut f = OpenOptions::new()
                .append(true)
                .create(true) // Optionally create the file if it doesn't already exist
                .open(file)
                .expect("Unable to open file");
            let time = chrono::offset::Local::now();
            let repr = serde_json::to_string(resp).expect("Serialization of response failed.");
            f.write_all(format!("Response at {time}: {repr}\n\n").as_bytes())
                .expect("Unable to write data");
        }
    }

    pub fn maybe_log_error(this: Arc<Self>, err: &dyn Error) {
        if let Some(file) = &this.log {
            let mut f = OpenOptions::new()
                .append(true)
                .create(true) // Optionally create the file if it doesn't already exist
                .open(file)
                .expect("Unable to open file");
            let time = chrono::offset::Local::now();
            f.write_all(format!("Error response at {time}: {err}\n\n").as_bytes())
                .expect("Unable to write data");
        }
    }

    /// Get the number of tools available for a specific model (including MCP tools)
    pub fn get_tools_count(&self, model_id: Option<&str>) -> Result<usize, String> {
        let resolved_model_id = match model_id {
            Some(id) => id.to_string(),
            None => {
                let default_lock = self
                    .default_engine_id
                    .read()
                    .map_err(|_| "Failed to acquire read lock")?;
                default_lock
                    .as_ref()
                    .ok_or("No default engine set")?
                    .clone()
            }
        };

        let engines = self
            .engines
            .read()
            .map_err(|_| "Failed to acquire read lock on engines")?;
        if let Some(engine_instance) = engines.get(&resolved_model_id) {
            Ok(engine_instance.reboot_state.tool_callbacks_with_tools.len())
        } else {
            Err(format!("Model {resolved_model_id} not found"))
        }
    }

    /// Check if MCP client is configured for a specific model
    pub fn has_mcp_client(&self, model_id: Option<&str>) -> Result<bool, String> {
        let resolved_model_id = match model_id {
            Some(id) => id.to_string(),
            None => {
                let default_lock = self
                    .default_engine_id
                    .read()
                    .map_err(|_| "Failed to acquire read lock")?;
                default_lock
                    .as_ref()
                    .ok_or("No default engine set")?
                    .clone()
            }
        };

        let engines = self
            .engines
            .read()
            .map_err(|_| "Failed to acquire read lock on engines")?;
        if let Some(engine_instance) = engines.get(&resolved_model_id) {
            Ok(engine_instance.reboot_state.mcp_client_config.is_some())
        } else {
            Err(format!("Model {resolved_model_id} not found"))
        }
    }

    /// Get config for a specific model
    pub fn config(&self, model_id: Option<&str>) -> Result<MistralRsConfig, String> {
        let resolved_model_id = match model_id {
            Some(id) => id.to_string(),
            None => {
                let default_lock = self
                    .default_engine_id
                    .read()
                    .map_err(|_| "Failed to acquire read lock")?;
                default_lock
                    .as_ref()
                    .ok_or("No default engine set")?
                    .clone()
            }
        };

        let engines = self
            .engines
            .read()
            .map_err(|_| "Failed to acquire read lock on engines")?;
        if let Some(engine_instance) = engines.get(&resolved_model_id) {
            Ok(engine_instance.config.clone())
        } else {
            Err(format!("Model {resolved_model_id} not found"))
        }
    }
1057}