mistralrs_core/lib.rs

#![deny(clippy::cast_possible_truncation, clippy::cast_precision_loss)]
use candle_core::Device;
use engine::Engine;
pub use engine::{
    get_engine_terminate_flag, reset_engine_terminate_flag, should_terminate_engine_sequences,
    BertEmbeddingModel, EngineInstruction, ENGINE_INSTRUCTIONS, TERMINATE_ALL_NEXT_STEP,
};
use hf_hub::Cache;
pub use lora::Ordering;
pub use pipeline::ModelCategory;
pub use pipeline::Pipeline;
#[cfg(feature = "pyo3_macros")]
use pyo3::exceptions::PyValueError;
use std::collections::HashMap;
use std::sync::OnceLock;
use std::time::Instant;
use std::{
    cell::RefCell,
    error::Error,
    fs::OpenOptions,
    io::Write,
    sync::{atomic::AtomicBool, Arc, Mutex, RwLock},
    thread::{self, JoinHandle},
    time::{SystemTime, UNIX_EPOCH},
};
use tokio::sync::mpsc::{channel, Sender};
use tracing::info;
use tracing::warn;

mod cuda;
mod device_map;
mod engine;
mod lora;
mod model_loader;
mod ops;
pub use model_loader::{
    get_auto_device_map_params, get_model_dtype, get_tgt_non_granular_index, LoaderBuilder,
};
mod kv_cache;
mod search;

mod model_selected;
pub use model_selected::ModelSelected;
pub use toml_selector::{get_toml_selected_model_device_map_params, get_toml_selected_model_dtype};

mod amoe;
#[cfg(not(any(all(feature = "cuda", target_family = "unix"), feature = "metal")))]
mod dummy_paged_attention;
mod embedding;
mod gguf;
pub mod layers;
mod layers_masker;
mod layers_utils;
mod models;
#[cfg(any(all(feature = "cuda", target_family = "unix"), feature = "metal"))]
mod paged_attention;
#[cfg(not(any(all(feature = "cuda", target_family = "unix"), feature = "metal")))]
use dummy_paged_attention as paged_attention;
mod attention;
mod diffusion_models;
pub mod distributed;
mod pipeline;
mod prefix_cacher;
mod request;
mod response;
mod sampler;
mod scheduler;
mod sequence;
mod speech_models;
mod toml_selector;
mod tools;
mod topology;
mod utils;
mod vision_models;
mod xlora_models;

pub use amoe::{AnyMoeConfig, AnyMoeExpertType};
pub use device_map::{
    DeviceLayerMapMetadata, DeviceMapMetadata, DeviceMapSetting, LayerDeviceMapper,
};
pub use gguf::{GGUFArchitecture, GGUF_MULTI_FILE_DELIMITER};
pub use mistralrs_audio::AudioInput;
pub use mistralrs_mcp::{
    CalledFunction, Function, Tool, ToolCallback, ToolCallbackWithTool, ToolType,
};
pub use mistralrs_mcp::{
    McpClient, McpClientConfig, McpServerConfig, McpServerSource, McpToolInfo,
};
pub use mistralrs_quant::{IsqType, MULTI_LORA_DELIMITER};
pub use paged_attention::{MemoryGpuConfig, PagedAttentionConfig, PagedCacheType};
pub use pipeline::{
    chat_template::ChatTemplate, parse_isq_value, AdapterPaths, AnyMoeLoader, AnyMoePipeline,
    AutoDeviceMapParams, AutoLoader, AutoLoaderBuilder, DiffusionGenerationParams, DiffusionLoader,
    DiffusionLoaderBuilder, DiffusionLoaderType, GGMLLoader, GGMLLoaderBuilder, GGMLSpecificConfig,
    GGUFLoader, GGUFLoaderBuilder, GGUFSpecificConfig, GemmaLoader, Idefics2Loader,
    IsqOrganization, LLaVALoader, LLaVANextLoader, LlamaLoader, Loader, LocalModelPaths,
    LoraAdapterPaths, MistralLoader, MixtralLoader, Modalities, ModelKind, ModelPaths,
    MultimodalPromptPrefixer, NormalLoader, NormalLoaderBuilder, NormalLoaderType,
    NormalSpecificConfig, Phi2Loader, Phi3Loader, Phi3VLoader, Qwen2Loader, SpeculativeConfig,
    SpeculativeLoader, SpeculativePipeline, SpeechLoader, SpeechPipeline, Starcoder2Loader,
    SupportedModality, TokenSource, VisionLoader, VisionLoaderBuilder, VisionLoaderType,
    VisionSpecificConfig, UQFF_MULTI_FILE_DELIMITER,
};
pub use request::{
    ApproximateUserLocation, Constraint, DetokenizationRequest, ImageGenerationResponseFormat,
    LlguidanceGrammar, MessageContent, NormalRequest, Request, RequestMessage, SearchContextSize,
    TokenizationRequest, WebSearchOptions, WebSearchUserLocation,
};
pub use response::*;
pub use sampler::{
    CustomLogitsProcessor, DrySamplingParams, SamplingParams, StopTokens, TopLogprob,
};
pub use scheduler::{DefaultSchedulerMethod, SchedulerConfig};
pub use search::{SearchCallback, SearchFunctionParameters, SearchResult};
use serde::Serialize;
pub use speech_models::{utils as speech_utils, SpeechGenerationConfig, SpeechLoaderType};
use tokio::runtime::Runtime;
use toml_selector::{TomlLoaderArgs, TomlSelector};
pub use tools::{ToolCallResponse, ToolCallType, ToolCallbacks, ToolChoice};
pub use topology::{LayerTopology, Topology};
pub use utils::debug::initialize_logging;
pub use utils::memory_usage::MemoryUsage;
pub use utils::normal::{ModelDType, TryIntoDType};
pub use utils::{paged_attn_supported, using_flash_attn};

// re-export llguidance for easier LlguidanceGrammar construction
pub use llguidance;

/// `true` if `MISTRALRS_DEBUG=1`
pub(crate) static DEBUG: AtomicBool = AtomicBool::new(false);
pub static GLOBAL_HF_CACHE: OnceLock<Cache> = OnceLock::new();

/// Configuration for creating an engine instance
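/// # Example
///
/// A minimal sketch (not compiled as a doctest): override a couple of fields and
/// keep the rest of the [`Default`] values via struct-update syntax.
///
/// ```ignore
/// let config = EngineConfig {
///     no_kv_cache: true,
///     prefix_cache_n: 32,
///     ..EngineConfig::default()
/// };
/// ```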
#[derive(Clone)]
pub struct EngineConfig {
    pub truncate_sequence: bool,
    pub no_kv_cache: bool,
    pub no_prefix_cache: bool,
    pub prefix_cache_n: usize,
    pub disable_eos_stop: bool,
    pub throughput_logging_enabled: bool,
    pub search_embedding_model: Option<BertEmbeddingModel>,
    pub search_callback: Option<Arc<SearchCallback>>,
    pub tool_callbacks: tools::ToolCallbacks,
    pub tool_callbacks_with_tools: tools::ToolCallbacksWithTools,
}

impl Default for EngineConfig {
    fn default() -> Self {
        Self {
            truncate_sequence: false,
            no_kv_cache: false,
            no_prefix_cache: false,
            prefix_cache_n: 16,
            disable_eos_stop: false,
            throughput_logging_enabled: true,
            search_embedding_model: None,
            search_callback: None,
            tool_callbacks: HashMap::new(),
            tool_callbacks_with_tools: HashMap::new(),
        }
    }
}

/// Configuration for adding a model to MistralRs
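/// # Example
///
/// Illustrative sketch (not compiled as a doctest), assuming `mcp_config` is an
/// [`McpClientConfig`] constructed elsewhere:
///
/// ```ignore
/// let add_config = AddModelConfig::new(EngineConfig::default())
///     .with_mcp_config(mcp_config);
/// ```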
#[derive(Clone)]
pub struct AddModelConfig {
    pub engine_config: EngineConfig,
    pub mcp_client_config: Option<McpClientConfig>,
}

impl AddModelConfig {
    pub fn new(engine_config: EngineConfig) -> Self {
        Self {
            engine_config,
            mcp_client_config: None,
        }
    }

    pub fn with_mcp_config(mut self, mcp_config: McpClientConfig) -> Self {
        self.mcp_client_config = Some(mcp_config);
        self
    }
}

#[derive(Clone)]
pub struct MistralRsConfig {
    pub kind: ModelKind,
    pub device: Device,
    pub category: ModelCategory,
    pub modalities: Modalities,
}

/// Internal structure to hold per-engine state
struct EngineInstance {
    sender: Sender<Request>,
    engine_handler: JoinHandle<()>,
    reboot_state: RebootState,
    config: MistralRsConfig,
    category: ModelCategory,
}

/// The MistralRs struct handles sending requests to one or more engines.
/// It is the core multi-threaded component of mistral.rs and uses `mpsc`
/// `Sender`/`Receiver` channels to route each request to the appropriate
/// engine based on its model ID.
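///
/// # Example
///
/// Illustrative sketch (not compiled as a doctest), assuming `mistralrs` is an
/// `Arc<MistralRs>` built via [`MistralRsBuilder`] and that a second model with
/// the hypothetical ID `"model-b"` has been added:
///
/// ```ignore
/// let default_sender = mistralrs.get_sender(None)?;            // default engine
/// let model_b_sender = mistralrs.get_sender(Some("model-b"))?; // specific engine
/// ```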
pub struct MistralRs {
    engines: RwLock<HashMap<String, EngineInstance>>,
    default_engine_id: RwLock<Option<String>>,
    log: Option<String>,
    id: String,
    creation_time: u64,
    next_request_id: Mutex<RefCell<usize>>,
}

#[derive(Clone)]
struct RebootState {
    pipeline: Arc<tokio::sync::Mutex<dyn Pipeline>>,
    method: SchedulerConfig,
    truncate_sequence: bool,
    no_kv_cache: bool,
    no_prefix_cache: bool,
    prefix_cache_n: usize,
    disable_eos_stop: bool,
    throughput_logging_enabled: bool,
    search_embedding_model: Option<BertEmbeddingModel>,
    search_callback: Option<Arc<search::SearchCallback>>,
    tool_callbacks: tools::ToolCallbacks,
    tool_callbacks_with_tools: tools::ToolCallbacksWithTools,
    mcp_client_config: Option<McpClientConfig>,
}

#[derive(Debug)]
pub enum MistralRsError {
    EnginePoisoned,
    SenderPoisoned,
}

impl std::fmt::Display for MistralRsError {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        write!(f, "{:?}", &self)
    }
}

impl std::error::Error for MistralRsError {}

#[cfg(feature = "pyo3_macros")]
impl From<MistralRsError> for pyo3::PyErr {
    fn from(value: MistralRsError) -> Self {
        PyValueError::new_err(format!("{value:?}"))
    }
}

/// The MistralRsBuilder takes a pipeline and a scheduler configuration and constructs
/// an Engine and a MistralRs instance. The Engine runs on a separate thread, while the
/// MistralRs instance stays on the calling thread.
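///
/// # Example
///
/// Illustrative sketch (not compiled as a doctest), assuming `pipeline` and
/// `scheduler_config` were produced elsewhere (e.g. by one of the `Loader`
/// implementations re-exported from this crate):
///
/// ```ignore
/// let mistralrs = MistralRsBuilder::new(
///     pipeline,
///     scheduler_config,
///     /* throughput_logging = */ true,
///     /* search_embedding_model = */ None,
/// )
/// .with_no_kv_cache(false)
/// .with_prefix_cache_n(16)
/// .build()
/// .await;
/// ```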
pub struct MistralRsBuilder {
    pipeline: Arc<tokio::sync::Mutex<dyn Pipeline>>,
    method: SchedulerConfig,
    log: Option<String>,
    truncate_sequence: Option<bool>,
    no_kv_cache: Option<bool>,
    no_prefix_cache: Option<bool>,
    prefix_cache_n: Option<usize>,
    disable_eos_stop: Option<bool>,
    throughput_logging_enabled: bool,
    search_embedding_model: Option<BertEmbeddingModel>,
    search_callback: Option<Arc<SearchCallback>>,
    tool_callbacks: tools::ToolCallbacks,
    tool_callbacks_with_tools: tools::ToolCallbacksWithTools,
    mcp_client_config: Option<McpClientConfig>,
}

impl MistralRsBuilder {
    /// Creates a new builder with the given pipeline, scheduler method, logging flag,
    /// and optional embedding model for web search. To override the search callback,
    /// use `.with_search_callback(...)` on the builder.
    pub fn new(
        pipeline: Arc<tokio::sync::Mutex<dyn Pipeline>>,
        method: SchedulerConfig,
        throughput_logging: bool,
        search_embedding_model: Option<BertEmbeddingModel>,
    ) -> Self {
        Self {
            pipeline,
            method,
            log: None,
            truncate_sequence: None,
            no_kv_cache: None,
            no_prefix_cache: None,
            prefix_cache_n: None,
            disable_eos_stop: None,
            throughput_logging_enabled: throughput_logging,
            search_embedding_model,
            search_callback: None,
            tool_callbacks: HashMap::new(),
            tool_callbacks_with_tools: HashMap::new(),
            mcp_client_config: None,
        }
    }
    pub fn with_log(mut self, log: String) -> Self {
        self.log = Some(log);
        self
    }
    pub fn with_opt_log(mut self, log: Option<String>) -> Self {
        self.log = log;
        self
    }
    pub fn with_truncate_sequence(mut self, truncate_sequence: bool) -> Self {
        self.truncate_sequence = Some(truncate_sequence);
        self
    }
    pub fn with_no_kv_cache(mut self, no_kv_cache: bool) -> Self {
        self.no_kv_cache = Some(no_kv_cache);
        self
    }
    pub fn with_no_prefix_cache(mut self, no_prefix_cache: bool) -> Self {
        self.no_prefix_cache = Some(no_prefix_cache);
        self
    }
    pub fn with_prefix_cache_n(mut self, prefix_cache_n: usize) -> Self {
        self.prefix_cache_n = Some(prefix_cache_n);
        self
    }
    pub fn with_disable_eos_stop(mut self, disable_eos_stop: bool) -> Self {
        self.disable_eos_stop = Some(disable_eos_stop);
        self
    }

    /// Use a custom callback to gather search results.
    pub fn with_search_callback(mut self, search_callback: Arc<SearchCallback>) -> Self {
        self.search_callback = Some(search_callback);
        self
    }

    /// Register a custom callback for the specified tool name.
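    ///
    /// Illustrative sketch (not compiled as a doctest), assuming `my_callback` is an
    /// `Arc<ToolCallback>` built elsewhere for a hypothetical `"get_weather"` tool:
    ///
    /// ```ignore
    /// let builder = builder.with_tool_callback("get_weather", my_callback);
    /// ```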
    pub fn with_tool_callback(
        mut self,
        name: impl Into<String>,
        tool_callback: Arc<ToolCallback>,
    ) -> Self {
        self.tool_callbacks.insert(name.into(), tool_callback);
        self
    }

    /// Register a custom callback with its associated Tool definition. The Tool will be
    /// automatically added to requests when tool callbacks are active.
    pub fn with_tool_callback_and_tool(
        mut self,
        name: impl Into<String>,
        tool_callback: Arc<ToolCallback>,
        tool: Tool,
    ) -> Self {
        let name = name.into();
        self.tool_callbacks_with_tools.insert(
            name,
            ToolCallbackWithTool {
                callback: tool_callback,
                tool,
            },
        );
        self
    }

    /// Configure MCP client to connect to external MCP servers.
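    ///
    /// Illustrative sketch (not compiled as a doctest), assuming `mcp_config` is an
    /// [`McpClientConfig`] describing the servers to connect to:
    ///
    /// ```ignore
    /// let builder = builder.with_mcp_client(mcp_config);
    /// ```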
    pub fn with_mcp_client(mut self, config: McpClientConfig) -> Self {
        self.mcp_client_config = Some(config);
        self
    }

    pub async fn build(self) -> Arc<MistralRs> {
        MistralRs::new(self).await
    }
}

impl Drop for MistralRs {
    fn drop(&mut self) {
        // Terminate all engines
        if let Ok(engines) = self.engines.read() {
            for (_, engine) in engines.iter() {
                // Use try_send instead of blocking_send to avoid runtime panics
                let _ = engine.sender.try_send(Request::Terminate);
            }
        }
    }
}

impl MistralRs {
    /// Create an engine instance with the given configuration
    fn create_engine_instance(
        pipeline: Arc<tokio::sync::Mutex<dyn Pipeline>>,
        method: SchedulerConfig,
        config: EngineConfig,
        reboot_state: RebootState,
    ) -> Result<EngineInstance, String> {
        let (tx, rx) = channel(10_000);

        let category = pipeline.try_lock().unwrap().category();
        let kind = pipeline.try_lock().unwrap().get_metadata().kind.clone();
        let device = pipeline.try_lock().unwrap().device();
        let modalities = pipeline
            .try_lock()
            .unwrap()
            .get_metadata()
            .modalities
            .clone();

        info!("Pipeline input modalities are {:?}", &modalities.input);
        info!("Pipeline output modalities are {:?}", &modalities.output);

        let mistralrs_config = MistralRsConfig {
            kind,
            device,
            category: category.clone(),
            modalities,
        };

        let engine_handler = thread::spawn(move || {
            #[cfg(feature = "metal")]
            objc::rc::autoreleasepool(move || {
                let rt = Runtime::new().unwrap();
                rt.block_on(async move {
                    let engine = Engine::new(
                        rx,
                        pipeline,
                        method,
                        config.truncate_sequence,
                        config.no_kv_cache,
                        config.no_prefix_cache,
                        config.prefix_cache_n,
                        config.disable_eos_stop,
                        config.throughput_logging_enabled,
                        config.search_embedding_model,
                        config.search_callback.clone(),
                        config.tool_callbacks.clone(),
                        config.tool_callbacks_with_tools.clone(),
                    )
                    .expect("Engine creation failed.");
                    Arc::new(engine).run().await;
                })
            });

            #[cfg(not(feature = "metal"))]
            {
                let rt = Runtime::new().unwrap();
                rt.block_on(async move {
                    let engine = Engine::new(
                        rx,
                        pipeline,
                        method,
                        config.truncate_sequence,
                        config.no_kv_cache,
                        config.no_prefix_cache,
                        config.prefix_cache_n,
                        config.disable_eos_stop,
                        config.throughput_logging_enabled,
                        config.search_embedding_model,
                        config.search_callback.clone(),
                        config.tool_callbacks.clone(),
                        config.tool_callbacks_with_tools.clone(),
                    )
                    .expect("Engine creation failed.");
                    Arc::new(engine).run().await;
                })
            }
        });

        Ok(EngineInstance {
            sender: tx,
            engine_handler,
            reboot_state,
            config: mistralrs_config,
            category,
        })
    }

    async fn new(config: MistralRsBuilder) -> Arc<Self> {
        let MistralRsBuilder {
            pipeline,
            method,
            log,
            truncate_sequence,
            no_kv_cache,
            no_prefix_cache,
            prefix_cache_n,
            disable_eos_stop,
            throughput_logging_enabled,
            search_embedding_model,
            search_callback,
            tool_callbacks,
            mut tool_callbacks_with_tools,
            mcp_client_config,
        } = config;

        mistralrs_quant::cublaslt::maybe_init_cublas_lt_wrapper(
            get_mut_arcmutex!(pipeline).device(),
        );

        let truncate_sequence = truncate_sequence.unwrap_or(false);
        let no_kv_cache = no_kv_cache.unwrap_or(false);
        let no_prefix_cache = no_prefix_cache.unwrap_or(false);
        let prefix_cache_n = prefix_cache_n.unwrap_or(16);
        let disable_eos_stop = disable_eos_stop.unwrap_or(false);

        // Initialize MCP client if configured
        if let Some(config) = &mcp_client_config {
            let mut mcp_client = McpClient::new(config.clone());
            let total_servers = config.servers.len();

            match mcp_client.initialize().await {
                Ok(()) => {
                    let mcp_callbacks_with_tools = mcp_client.get_tool_callbacks_with_tools();
                    let tools_count = mcp_callbacks_with_tools.len();

                    // Merge MCP tool callbacks with tools into the new collection
                    for (name, callback_with_tool) in mcp_callbacks_with_tools {
                        tool_callbacks_with_tools.insert(name.clone(), callback_with_tool.clone());
                    }

                    if tools_count == 0 {
                        warn!(
                            "MCP client initialized but no tools were registered from {} servers",
                            total_servers
                        );
                    } else {
                        info!(
                            "MCP client initialized successfully with {} tools from {} servers",
                            tools_count, total_servers
                        );
                    }
                }
                Err(e) => {
                    warn!(
                        "Failed to initialize MCP client with {} configured servers: {}",
                        total_servers, e
                    );
                    warn!("Continuing without MCP functionality. Check your MCP configuration and server availability.");
                }
            }
        }

        let reboot_state = RebootState {
            pipeline: pipeline.clone(),
            method: method.clone(),
            truncate_sequence,
            no_kv_cache,
            no_prefix_cache,
            prefix_cache_n,
            disable_eos_stop,
            throughput_logging_enabled,
            search_embedding_model: search_embedding_model.clone(),
            search_callback: search_callback.clone(),
            tool_callbacks: tool_callbacks.clone(),
            tool_callbacks_with_tools: tool_callbacks_with_tools.clone(),
            mcp_client_config: mcp_client_config.clone(),
        };

        // Create the engine configuration
        let engine_config = EngineConfig {
            truncate_sequence,
            no_kv_cache,
            no_prefix_cache,
            prefix_cache_n,
            disable_eos_stop,
            throughput_logging_enabled,
            search_embedding_model,
            search_callback,
            tool_callbacks,
            tool_callbacks_with_tools,
        };

        // Create the engine instance
        let engine_instance =
            Self::create_engine_instance(pipeline.clone(), method, engine_config, reboot_state)
                .expect("Failed to create engine instance");

        let id = pipeline.try_lock().unwrap().name();

        if distributed::is_daemon() {
            let request_sender = engine_instance.sender.clone();

            if cfg!(feature = "ring") {
                // Ring daemon replicator
                distributed::ring_daemon_replicator(request_sender);
            } else {
                // NCCL daemon replicator
                distributed::nccl_daemon_replicator(request_sender);
            }

            #[allow(clippy::empty_loop)]
            loop {}
        }

        // Determine if the current runtime is multi-threaded, as blocking operations are not allowed in single-threaded mode
        let is_multi_threaded = tokio::runtime::Handle::try_current()
            .is_ok_and(|h| h.runtime_flavor() != tokio::runtime::RuntimeFlavor::CurrentThread);

        // Do a dummy run
        if !distributed::is_daemon()
            && is_multi_threaded
            && matches!(
                engine_instance.category,
                ModelCategory::Text | ModelCategory::Vision { .. }
            )
        {
            let clone_sender = engine_instance.sender.clone();
            tokio::task::block_in_place(|| {
                let (tx, mut rx) = channel(1);
                let req = Request::Normal(Box::new(NormalRequest {
                    id: 0,
                    messages: RequestMessage::Completion {
                        text: "hello".to_string(),
                        echo_prompt: false,
                        best_of: None,
                    },
                    sampling_params: SamplingParams {
                        max_len: Some(1),
                        ..SamplingParams::deterministic()
                    },
                    response: tx,
                    return_logprobs: false,
                    is_streaming: false,
                    constraint: Constraint::None,
                    suffix: None,
                    tool_choice: None,
                    tools: None,
                    logits_processors: None,
                    return_raw_logits: false,
                    web_search_options: None,
                    model_id: None,
                }));
                info!("Beginning dummy run.");
                let start = Instant::now();
                clone_sender.blocking_send(req).unwrap();

                if let Some(_resp) = rx.blocking_recv() {
                    let end = Instant::now();
                    info!(
                        "Dummy run completed in {}s.",
                        end.duration_since(start).as_secs_f64()
                    );
                } else {
                    warn!("Dummy run failed!");
                }
            });
        }

        // Create engines map with the first engine
        let mut engines = HashMap::new();
        engines.insert(id.clone(), engine_instance);

        Arc::new(Self {
            engines: RwLock::new(engines),
            default_engine_id: RwLock::new(Some(id.clone())),
            log,
            id,
            creation_time: SystemTime::now()
                .duration_since(UNIX_EPOCH)
                .expect("Time travel has occurred!")
                .as_secs(),
            next_request_id: Mutex::new(RefCell::new(1)),
        })
    }

    /// Attempts to reboot a specific engine by model_id
    fn reboot_engine(&self, model_id: &str) -> Result<(), MistralRsError> {
        let mut engines = self.engines.write().map_err(|_| {
            tracing::warn!("Couldn't get write lock on engines during reboot attempt");
            MistralRsError::EnginePoisoned
        })?;

        if let Some(engine_instance) = engines.get(model_id) {
            if !engine_instance.engine_handler.is_finished() {
                tracing::info!("Engine {} already running, returning ok", model_id);
                return Ok(());
            }

            let reboot_state = engine_instance.reboot_state.clone();
            let engine_config = EngineConfig {
                truncate_sequence: reboot_state.truncate_sequence,
                no_kv_cache: reboot_state.no_kv_cache,
                no_prefix_cache: reboot_state.no_prefix_cache,
                prefix_cache_n: reboot_state.prefix_cache_n,
                disable_eos_stop: reboot_state.disable_eos_stop,
                throughput_logging_enabled: reboot_state.throughput_logging_enabled,
                search_embedding_model: reboot_state.search_embedding_model.clone(),
                search_callback: reboot_state.search_callback.clone(),
                tool_callbacks: reboot_state.tool_callbacks.clone(),
                tool_callbacks_with_tools: reboot_state.tool_callbacks_with_tools.clone(),
            };
            let new_engine_instance = Self::create_engine_instance(
                reboot_state.pipeline.clone(),
                reboot_state.method.clone(),
                engine_config,
                reboot_state,
            )
            .map_err(|e| {
                tracing::error!("Failed to create new engine instance: {}", e);
                MistralRsError::EnginePoisoned
            })?;

            engines.insert(model_id.to_string(), new_engine_instance);
            tracing::info!("Successfully rebooted engine {}", model_id);
            Ok(())
        } else {
            Err(MistralRsError::EnginePoisoned)
        }
    }

    fn engine_dead(&self, model_id: &str) -> Result<bool, MistralRsError> {
        let engines = self.engines.read().map_err(|_| {
            tracing::warn!("Couldn't get read lock on engines!");
            MistralRsError::EnginePoisoned
        })?;

        if let Some(engine_instance) = engines.get(model_id) {
            Ok(engine_instance.engine_handler.is_finished())
        } else {
            Err(MistralRsError::EnginePoisoned)
        }
    }

    /// Get sender for a specific model. If model_id is None, uses default engine.
    pub fn get_sender(&self, model_id: Option<&str>) -> Result<Sender<Request>, MistralRsError> {
        let resolved_model_id = match model_id {
            Some(id) => id.to_string(),
            None => {
                let default_lock = self
                    .default_engine_id
                    .read()
                    .map_err(|_| MistralRsError::SenderPoisoned)?;
                default_lock
                    .as_ref()
                    .ok_or(MistralRsError::EnginePoisoned)?
                    .clone()
            }
        };

        if self.engine_dead(&resolved_model_id)? {
            tracing::warn!("Engine {} is dead, rebooting", resolved_model_id);
            self.reboot_engine(&resolved_model_id)?
        }

        let engines = self
            .engines
            .read()
            .map_err(|_| MistralRsError::SenderPoisoned)?;
        if let Some(engine_instance) = engines.get(&resolved_model_id) {
            Ok(engine_instance.sender.clone())
        } else {
            Err(MistralRsError::EnginePoisoned)
        }
    }

    pub fn get_id(&self) -> String {
        self.id.clone()
    }

    pub fn get_creation_time(&self) -> u64 {
        self.creation_time
    }

    /// Get model category for a specific model. If model_id is None, uses default engine.
    pub fn get_model_category(
        &self,
        model_id: Option<&str>,
    ) -> Result<ModelCategory, MistralRsError> {
        let resolved_model_id = match model_id {
            Some(id) => id.to_string(),
            None => {
                let default_lock = self
                    .default_engine_id
                    .read()
                    .map_err(|_| MistralRsError::SenderPoisoned)?;
                default_lock
                    .as_ref()
                    .ok_or(MistralRsError::EnginePoisoned)?
                    .clone()
            }
        };

        let engines = self
            .engines
            .read()
            .map_err(|_| MistralRsError::SenderPoisoned)?;
        if let Some(engine_instance) = engines.get(&resolved_model_id) {
            Ok(engine_instance.category.clone())
        } else {
            Err(MistralRsError::EnginePoisoned)
        }
    }

    pub fn next_request_id(&self) -> usize {
        let l = self.next_request_id.lock().unwrap();
        let last = &mut *l.borrow_mut();
        let last_v = *last;
        *last += 1;
        last_v
    }

    /// Add a new model engine to the MistralRs instance
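    ///
    /// Illustrative sketch (not compiled as a doctest), assuming `pipeline2` and
    /// `scheduler_config` were built elsewhere and `"model-b"` is a hypothetical ID:
    ///
    /// ```ignore
    /// mistralrs
    ///     .add_model(
    ///         "model-b".to_string(),
    ///         pipeline2,
    ///         scheduler_config,
    ///         AddModelConfig::new(EngineConfig::default()),
    ///     )
    ///     .await?;
    /// mistralrs.set_default_model_id("model-b")?;
    /// ```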
    pub async fn add_model(
        &self,
        model_id: String,
        pipeline: Arc<tokio::sync::Mutex<dyn Pipeline>>,
        method: SchedulerConfig,
        config: AddModelConfig,
    ) -> Result<(), String> {
        let reboot_state = RebootState {
            pipeline: pipeline.clone(),
            method: method.clone(),
            truncate_sequence: config.engine_config.truncate_sequence,
            no_kv_cache: config.engine_config.no_kv_cache,
            no_prefix_cache: config.engine_config.no_prefix_cache,
            prefix_cache_n: config.engine_config.prefix_cache_n,
            disable_eos_stop: config.engine_config.disable_eos_stop,
            throughput_logging_enabled: config.engine_config.throughput_logging_enabled,
            search_embedding_model: config.engine_config.search_embedding_model.clone(),
            search_callback: config.engine_config.search_callback.clone(),
            tool_callbacks: config.engine_config.tool_callbacks.clone(),
            tool_callbacks_with_tools: config.engine_config.tool_callbacks_with_tools.clone(),
            mcp_client_config: config.mcp_client_config.clone(),
        };

        let engine_instance =
            Self::create_engine_instance(pipeline, method, config.engine_config, reboot_state)?;

        let mut engines = self
            .engines
            .write()
            .map_err(|_| "Failed to acquire write lock on engines")?;
        engines.insert(model_id.clone(), engine_instance);

        // If this is the first model, set it as default
        if engines.len() == 1 {
            let mut default_lock = self
                .default_engine_id
                .write()
                .map_err(|_| "Failed to acquire write lock on default_engine_id")?;
            *default_lock = Some(model_id.clone());
        }

        Ok(())
    }

    /// Remove a model engine from the MistralRs instance
    pub fn remove_model(&self, model_id: &str) -> Result<(), String> {
        let mut engines = self
            .engines
            .write()
            .map_err(|_| "Failed to acquire write lock on engines")?;

        if engines.len() <= 1 {
            return Err("Cannot remove the last model from MistralRs".to_string());
        }

        if let Some(engine_instance) = engines.remove(model_id) {
            // Send terminate signal to the engine
            let _ = engine_instance.sender.blocking_send(Request::Terminate);

            // If this was the default engine, set a new default
            let mut default_lock = self
                .default_engine_id
                .write()
                .map_err(|_| "Failed to acquire write lock on default_engine_id")?;
            if let Some(ref default_id) = *default_lock {
                if default_id == model_id {
                    // Set the first available engine as the new default
                    *default_lock = engines.keys().next().cloned();
                }
            }

            Ok(())
        } else {
            Err(format!("Model {model_id} not found"))
        }
    }

    /// List all available model IDs
    pub fn list_models(&self) -> Result<Vec<String>, String> {
        let engines = self
            .engines
            .read()
            .map_err(|_| "Failed to acquire read lock on engines")?;
        Ok(engines.keys().cloned().collect())
    }

    /// Get the current default model ID
    pub fn get_default_model_id(&self) -> Result<Option<String>, String> {
        let default_lock = self
            .default_engine_id
            .read()
            .map_err(|_| "Failed to acquire read lock on default_engine_id")?;
        Ok(default_lock.clone())
    }

    /// Set the default model ID
    pub fn set_default_model_id(&self, model_id: &str) -> Result<(), String> {
        let engines = self
            .engines
            .read()
            .map_err(|_| "Failed to acquire read lock on engines")?;
        if !engines.contains_key(model_id) {
            return Err(format!("Model {model_id} not found"));
        }
        drop(engines);

        let mut default_lock = self
            .default_engine_id
            .write()
            .map_err(|_| "Failed to acquire write lock on default_engine_id")?;
        *default_lock = Some(model_id.to_string());

        Ok(())
    }

    /// Dispatch a request to the appropriate engine based on the model_id in the request
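    ///
    /// Illustrative sketch (not compiled as a doctest): routing is driven by
    /// `NormalRequest::model_id`; `Some("model-b")` targets that engine, while
    /// `None` falls back to the default engine. `request` is assumed to be a
    /// fully populated [`Request`]:
    ///
    /// ```ignore
    /// mistralrs.send_request(request)?;
    /// ```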
    pub fn send_request(&self, mut request: Request) -> Result<(), MistralRsError> {
        let model_id = match &mut request {
            Request::Normal(normal_req) => normal_req.model_id.as_deref(),
            _ => None, // Other request types don't specify model_id
        };

        let sender = self.get_sender(model_id)?;
        sender
            .blocking_send(request)
            .map_err(|_| MistralRsError::SenderPoisoned)
    }

    pub fn maybe_log_request(this: Arc<Self>, repr: String) {
        if let Some(file) = &this.log {
            let mut f = OpenOptions::new()
                .append(true)
                .create(true) // Optionally create the file if it doesn't already exist
                .open(file)
                .expect("Unable to open file");
            let time = chrono::offset::Local::now();
            f.write_all(format!("Request at {time}: {repr}\n\n").as_bytes())
                .expect("Unable to write data");
        }
    }

    pub fn maybe_log_response<T: Serialize>(this: Arc<Self>, resp: &T) {
        if let Some(file) = &this.log {
            let mut f = OpenOptions::new()
                .append(true)
                .create(true) // Optionally create the file if it doesn't already exist
                .open(file)
                .expect("Unable to open file");
            let time = chrono::offset::Local::now();
            let repr = serde_json::to_string(resp).expect("Serialization of response failed.");
            f.write_all(format!("Response at {time}: {repr}\n\n").as_bytes())
                .expect("Unable to write data");
        }
    }

    pub fn maybe_log_error(this: Arc<Self>, err: &dyn Error) {
        if let Some(file) = &this.log {
            let mut f = OpenOptions::new()
                .append(true)
                .create(true) // Optionally create the file if it doesn't already exist
                .open(file)
                .expect("Unable to open file");
            let time = chrono::offset::Local::now();
            f.write_all(format!("Error response at {time}: {err}\n\n").as_bytes())
                .expect("Unable to write data");
        }
    }

    /// Get the number of tools available for a specific model (including MCP tools)
    pub fn get_tools_count(&self, model_id: Option<&str>) -> Result<usize, String> {
        let resolved_model_id = match model_id {
            Some(id) => id.to_string(),
            None => {
                let default_lock = self
                    .default_engine_id
                    .read()
                    .map_err(|_| "Failed to acquire read lock")?;
                default_lock
                    .as_ref()
                    .ok_or("No default engine set")?
                    .clone()
            }
        };

        let engines = self
            .engines
            .read()
            .map_err(|_| "Failed to acquire read lock on engines")?;
        if let Some(engine_instance) = engines.get(&resolved_model_id) {
            Ok(engine_instance.reboot_state.tool_callbacks_with_tools.len())
        } else {
            Err(format!("Model {resolved_model_id} not found"))
        }
    }

    /// Check if MCP client is configured for a specific model
    pub fn has_mcp_client(&self, model_id: Option<&str>) -> Result<bool, String> {
        let resolved_model_id = match model_id {
            Some(id) => id.to_string(),
            None => {
                let default_lock = self
                    .default_engine_id
                    .read()
                    .map_err(|_| "Failed to acquire read lock")?;
                default_lock
                    .as_ref()
                    .ok_or("No default engine set")?
                    .clone()
            }
        };

        let engines = self
            .engines
            .read()
            .map_err(|_| "Failed to acquire read lock on engines")?;
        if let Some(engine_instance) = engines.get(&resolved_model_id) {
            Ok(engine_instance.reboot_state.mcp_client_config.is_some())
        } else {
            Err(format!("Model {resolved_model_id} not found"))
        }
    }

    /// Get config for a specific model
    pub fn config(&self, model_id: Option<&str>) -> Result<MistralRsConfig, String> {
        let resolved_model_id = match model_id {
            Some(id) => id.to_string(),
            None => {
                let default_lock = self
                    .default_engine_id
                    .read()
                    .map_err(|_| "Failed to acquire read lock")?;
                default_lock
                    .as_ref()
                    .ok_or("No default engine set")?
                    .clone()
            }
        };

        let engines = self
            .engines
            .read()
            .map_err(|_| "Failed to acquire read lock on engines")?;
        if let Some(engine_instance) = engines.get(&resolved_model_id) {
            Ok(engine_instance.config.clone())
        } else {
            Err(format!("Model {resolved_model_id} not found"))
        }
    }
}