1#![deny(clippy::cast_possible_truncation, clippy::cast_precision_loss)]
2use candle_core::Device;
3use engine::Engine;
4pub use engine::{
5 get_engine_terminate_flag, reset_engine_terminate_flag, should_terminate_engine_sequences,
6 EngineInstruction, SearchEmbeddingModel, ENGINE_INSTRUCTIONS, TERMINATE_ALL_NEXT_STEP,
7};
8use hf_hub::Cache;
9pub use lora::Ordering;
10pub use pipeline::ModelCategory;
11pub use pipeline::Pipeline;
12#[cfg(feature = "pyo3_macros")]
13use pyo3::exceptions::PyValueError;
14use std::collections::HashMap;
15use std::num::NonZeroUsize;
16use std::sync::OnceLock;
17use std::time::Instant;
18use std::{
19 cell::RefCell,
20 error::Error,
21 fs::OpenOptions,
22 io::Write,
23 sync::{atomic::AtomicBool, Arc, Mutex, RwLock},
24 thread::{self, JoinHandle},
25 time::{SystemTime, UNIX_EPOCH},
26};
27use tokio::sync::mpsc::{channel, Sender};
28use tracing::info;
29use tracing::warn;
30
31mod cuda;
32mod device_map;
33mod engine;
34mod lora;
35mod model_loader;
36mod moe;
37mod ops;
38pub use model_loader::{
39 get_auto_device_map_params, get_model_dtype, get_tgt_non_granular_index, LoaderBuilder,
40};
41mod embedding_models;
42mod kv_cache;
43mod search;
44
45mod model_selected;
46pub use model_selected::ModelSelected;
47pub use toml_selector::{get_toml_selected_model_device_map_params, get_toml_selected_model_dtype};
48
49mod amoe;
50mod attention;
51mod diffusion_models;
52pub mod distributed;
53mod gguf;
54pub mod harmony;
55pub mod layers;
56mod layers_masker;
57mod layers_utils;
58pub mod matformer;
59mod models;
60mod paged_attention;
61mod pipeline;
62mod prefix_cacher;
63mod request;
64mod response;
65mod sampler;
66mod scheduler;
67mod sequence;
68mod speech_models;
69mod toml_selector;
70mod tools;
71mod topology;
72mod utils;
73mod vision_models;
74mod xlora_models;
75
76pub use amoe::{AnyMoeConfig, AnyMoeExpertType};
77pub use device_map::{
78 DeviceLayerMapMetadata, DeviceMapMetadata, DeviceMapSetting, LayerDeviceMapper,
79};
80pub use gguf::{GGUFArchitecture, GGUF_MULTI_FILE_DELIMITER};
81pub use mistralrs_audio::AudioInput;
82pub use mistralrs_mcp::{
83 CalledFunction, Function, Tool, ToolCallback, ToolCallbackWithTool, ToolType,
84};
85pub use mistralrs_mcp::{
86 McpClient, McpClientConfig, McpServerConfig, McpServerSource, McpToolInfo,
87};
88pub use mistralrs_quant::{IsqType, MULTI_LORA_DELIMITER};
89pub use paged_attention::{MemoryGpuConfig, PagedAttentionConfig, PagedCacheType};
90pub use pipeline::{
91 chat_template::ChatTemplate, parse_isq_value, AdapterPaths, AnyMoeLoader, AnyMoePipeline,
92 AutoDeviceMapParams, AutoLoader, AutoLoaderBuilder, DiffusionGenerationParams, DiffusionLoader,
93 DiffusionLoaderBuilder, DiffusionLoaderType, EmbeddingLoader, EmbeddingLoaderBuilder,
94 EmbeddingLoaderType, EmbeddingModelPaths, EmbeddingSpecificConfig, GGMLLoader,
95 GGMLLoaderBuilder, GGMLSpecificConfig, GGUFLoader, GGUFLoaderBuilder, GGUFSpecificConfig,
96 GemmaLoader, Idefics2Loader, IsqOrganization, LLaVALoader, LLaVANextLoader, LlamaLoader,
97 Loader, LocalModelPaths, LoraAdapterPaths, MistralLoader, MixtralLoader, Modalities, ModelKind,
98 ModelPaths, MultimodalPromptPrefixer, NormalLoader, NormalLoaderBuilder, NormalLoaderType,
99 NormalSpecificConfig, Phi2Loader, Phi3Loader, Phi3VLoader, Qwen2Loader, SpeculativeConfig,
100 SpeculativeLoader, SpeculativePipeline, SpeechLoader, SpeechPipeline, Starcoder2Loader,
101 SupportedModality, TokenSource, VisionLoader, VisionLoaderBuilder, VisionLoaderType,
102 VisionSpecificConfig, UQFF_MULTI_FILE_DELIMITER,
103};
104pub use request::{
105 ApproximateUserLocation, Constraint, DetokenizationRequest, ImageGenerationResponseFormat,
106 LlguidanceGrammar, MessageContent, NormalRequest, ReasoningEffort, Request, RequestMessage,
107 SearchContextSize, TokenizationRequest, WebSearchOptions, WebSearchUserLocation,
108};
109pub use response::*;
110pub use sampler::{
111 CustomLogitsProcessor, DrySamplingParams, SamplingParams, StopTokens, TopLogprob,
112};
113pub use scheduler::{DefaultSchedulerMethod, SchedulerConfig};
114pub use search::{SearchCallback, SearchFunctionParameters, SearchResult};
115use serde::Serialize;
116pub use speech_models::{utils as speech_utils, SpeechGenerationConfig, SpeechLoaderType};
117use tokio::runtime::Runtime;
118use toml_selector::{TomlLoaderArgs, TomlSelector};
119pub use tools::{ToolCallResponse, ToolCallType, ToolCallbacks, ToolChoice};
120pub use topology::{LayerTopology, Topology};
121pub use utils::debug::initialize_logging;
122pub use utils::memory_usage::MemoryUsage;
123pub use utils::normal::{ModelDType, TryIntoDType};
124pub use utils::{paged_attn_supported, using_flash_attn};
125
126pub use llguidance;
128
/// Crate-wide debug flag, initialised to `false`.
///
/// NOTE(review): the code that sets and reads this flag is not in this file; presumably
/// it is toggled during logging/debug initialisation — confirm before relying on it.
pub(crate) static DEBUG: AtomicBool = AtomicBool::new(false);
/// Process-wide Hugging Face Hub cache handle; a `OnceLock`, so it can be set at most once.
///
/// NOTE(review): which loader initialises it is not visible here — confirm.
pub static GLOBAL_HF_CACHE: OnceLock<Cache> = OnceLock::new();
132
133#[derive(Clone)]
135pub struct EngineConfig {
136 pub no_kv_cache: bool,
137 pub no_prefix_cache: bool,
138 pub prefix_cache_n: usize,
139 pub disable_eos_stop: bool,
140 pub throughput_logging_enabled: bool,
141 pub search_embedding_model: Option<SearchEmbeddingModel>,
142 pub search_callback: Option<Arc<SearchCallback>>,
143 pub tool_callbacks: tools::ToolCallbacks,
144 pub tool_callbacks_with_tools: tools::ToolCallbacksWithTools,
145}
146
147impl Default for EngineConfig {
148 fn default() -> Self {
149 Self {
150 no_kv_cache: false,
151 no_prefix_cache: false,
152 prefix_cache_n: 16,
153 disable_eos_stop: false,
154 throughput_logging_enabled: true,
155 search_embedding_model: None,
156 search_callback: None,
157 tool_callbacks: HashMap::new(),
158 tool_callbacks_with_tools: HashMap::new(),
159 }
160 }
161}
162
163#[derive(Clone)]
165pub struct AddModelConfig {
166 pub engine_config: EngineConfig,
167 pub mcp_client_config: Option<McpClientConfig>,
168}
169
170impl AddModelConfig {
171 pub fn new(engine_config: EngineConfig) -> Self {
172 Self {
173 engine_config,
174 mcp_client_config: None,
175 }
176 }
177
178 pub fn with_mcp_config(mut self, mcp_config: McpClientConfig) -> Self {
179 self.mcp_client_config = Some(mcp_config);
180 self
181 }
182}
183
/// Snapshot of a model's properties, captured from its pipeline when the engine is created.
#[derive(Clone)]
pub struct MistralRsConfig {
    /// Model kind, cloned from the pipeline metadata.
    pub kind: ModelKind,
    /// Device reported by the pipeline.
    pub device: Device,
    /// Model category reported by the pipeline.
    pub category: ModelCategory,
    /// Input/output modalities cloned from the pipeline metadata.
    pub modalities: Modalities,
    /// Maximum sequence length from the metadata; `None` for diffusion and speech models.
    pub max_seq_len: Option<usize>,
}
192
/// One running engine: its request channel, its OS thread, and what is needed to restart it.
struct EngineInstance {
    /// Request channel into the engine.
    sender: Sender<Request>,
    /// Join handle of the engine thread; `is_finished()` is used to detect a dead engine.
    engine_handler: JoinHandle<()>,
    /// Settings used to recreate the engine after its thread exits.
    reboot_state: RebootState,
    /// Model properties captured when the engine was created.
    config: MistralRsConfig,
    /// Model category (the same value stored in `config.category`).
    category: ModelCategory,
}
201
/// Multi-model front end: owns one engine per model id and routes requests to them.
pub struct MistralRs {
    /// Engines keyed by model id.
    engines: RwLock<HashMap<String, EngineInstance>>,
    /// Model id used when a request or query does not name a model.
    default_engine_id: RwLock<Option<String>>,
    /// Optional path of a file that requests, responses and errors are appended to.
    log: Option<String>,
    /// Identifier of this instance (the name of the pipeline passed to the builder).
    id: String,
    /// Creation time in whole seconds since the Unix epoch.
    creation_time: u64,
    /// Next value returned by `next_request_id`; starts at 1.
    // NOTE(review): the `RefCell` is redundant inside a `Mutex`; an `AtomicUsize` would do.
    next_request_id: Mutex<RefCell<usize>>,
}
214
/// Everything needed to recreate an engine after its thread has exited.
#[derive(Clone)]
struct RebootState {
    /// Shared pipeline the engine runs.
    pipeline: Arc<tokio::sync::Mutex<dyn Pipeline>>,
    /// Scheduler configuration (after any hybrid-model override).
    method: SchedulerConfig,
    /// Copied into the rebuilt `EngineConfig`.
    no_kv_cache: bool,
    /// Copied into the rebuilt `EngineConfig`.
    no_prefix_cache: bool,
    /// Copied into the rebuilt `EngineConfig`.
    prefix_cache_n: usize,
    /// Copied into the rebuilt `EngineConfig`.
    disable_eos_stop: bool,
    /// Copied into the rebuilt `EngineConfig`.
    throughput_logging_enabled: bool,
    /// Copied into the rebuilt `EngineConfig`.
    search_embedding_model: Option<SearchEmbeddingModel>,
    /// Copied into the rebuilt `EngineConfig`.
    search_callback: Option<Arc<search::SearchCallback>>,
    /// Copied into the rebuilt `EngineConfig`.
    tool_callbacks: tools::ToolCallbacks,
    /// Copied into the rebuilt `EngineConfig`; also used to report the tool count.
    tool_callbacks_with_tools: tools::ToolCallbacksWithTools,
    /// MCP client configuration, if any; used to report whether an MCP client is configured.
    mcp_client_config: Option<McpClientConfig>,
}
230
/// Errors reported by `MistralRs` when an engine, its lock, or its request channel is unusable.
#[derive(Debug)]
pub enum MistralRsError {
    /// The engines lock is poisoned, the requested or default engine does not exist,
    /// or an engine could not be rebooted.
    EnginePoisoned,
    /// A lock guarding engine lookup is poisoned, or the request could not be sent.
    SenderPoisoned,
}

impl std::fmt::Display for MistralRsError {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        // The user-facing text is simply the variant name (identical to the `Debug` output).
        let text = match self {
            Self::EnginePoisoned => "EnginePoisoned",
            Self::SenderPoisoned => "SenderPoisoned",
        };
        f.write_str(text)
    }
}

impl std::error::Error for MistralRsError {}
244
#[cfg(feature = "pyo3_macros")]
impl From<MistralRsError> for pyo3::PyErr {
    /// Converts the error into a Python `ValueError` whose message is the `Debug` form.
    fn from(err: MistralRsError) -> Self {
        let message = format!("{:?}", err);
        PyValueError::new_err(message)
    }
}
251
252pub struct MistralRsBuilder {
256 pipeline: Arc<tokio::sync::Mutex<dyn Pipeline>>,
257 method: SchedulerConfig,
258 log: Option<String>,
259 no_kv_cache: Option<bool>,
260 no_prefix_cache: Option<bool>,
261 prefix_cache_n: Option<usize>,
262 disable_eos_stop: Option<bool>,
263 throughput_logging_enabled: bool,
264 search_embedding_model: Option<SearchEmbeddingModel>,
265 search_callback: Option<Arc<SearchCallback>>,
266 tool_callbacks: tools::ToolCallbacks,
267 tool_callbacks_with_tools: tools::ToolCallbacksWithTools,
268 mcp_client_config: Option<McpClientConfig>,
269}
270
271impl MistralRsBuilder {
272 pub fn new(
276 pipeline: Arc<tokio::sync::Mutex<dyn Pipeline>>,
277 method: SchedulerConfig,
278 throughput_logging: bool,
279 search_embedding_model: Option<SearchEmbeddingModel>,
280 ) -> Self {
281 Self {
282 pipeline,
283 method,
284 log: None,
285 no_kv_cache: None,
286 no_prefix_cache: None,
287 prefix_cache_n: None,
288 disable_eos_stop: None,
289 throughput_logging_enabled: throughput_logging,
290 search_embedding_model,
291 search_callback: None,
292 tool_callbacks: HashMap::new(),
293 tool_callbacks_with_tools: HashMap::new(),
294 mcp_client_config: None,
295 }
296 }
297 pub fn with_log(mut self, log: String) -> Self {
298 self.log = Some(log);
299 self
300 }
301 pub fn with_opt_log(mut self, log: Option<String>) -> Self {
302 self.log = log;
303 self
304 }
305 pub fn with_no_kv_cache(mut self, no_kv_cache: bool) -> Self {
306 self.no_kv_cache = Some(no_kv_cache);
307 self
308 }
309 pub fn with_no_prefix_cache(mut self, no_prefix_cache: bool) -> Self {
310 self.no_prefix_cache = Some(no_prefix_cache);
311 self
312 }
313 pub fn with_prefix_cache_n(mut self, prefix_cache_n: usize) -> Self {
314 self.prefix_cache_n = Some(prefix_cache_n);
315 self
316 }
317 pub fn with_disable_eos_stop(mut self, disable_eos_stop: bool) -> Self {
318 self.disable_eos_stop = Some(disable_eos_stop);
319 self
320 }
321
322 pub fn with_search_callback(mut self, search_callback: Arc<SearchCallback>) -> Self {
324 self.search_callback = Some(search_callback);
325 self
326 }
327
328 pub fn with_tool_callback(
330 mut self,
331 name: impl Into<String>,
332 tool_callback: Arc<ToolCallback>,
333 ) -> Self {
334 self.tool_callbacks.insert(name.into(), tool_callback);
335 self
336 }
337
338 pub fn with_tool_callback_and_tool(
341 mut self,
342 name: impl Into<String>,
343 tool_callback: Arc<ToolCallback>,
344 tool: Tool,
345 ) -> Self {
346 let name = name.into();
347 self.tool_callbacks_with_tools.insert(
348 name,
349 ToolCallbackWithTool {
350 callback: tool_callback,
351 tool,
352 },
353 );
354 self
355 }
356
357 pub fn with_mcp_client(mut self, config: McpClientConfig) -> Self {
359 self.mcp_client_config = Some(config);
360 self
361 }
362
363 pub async fn build(self) -> Arc<MistralRs> {
364 MistralRs::new(self).await
365 }
366}
367
368impl Drop for MistralRs {
369 fn drop(&mut self) {
370 if let Ok(engines) = self.engines.read() {
372 for (_, engine) in engines.iter() {
373 let _ = engine.sender.try_send(Request::Terminate);
375 }
376 }
377 }
378}
379
380impl MistralRs {
381 fn create_engine_instance(
383 pipeline: Arc<tokio::sync::Mutex<dyn Pipeline>>,
384 method: SchedulerConfig,
385 config: EngineConfig,
386 reboot_state: RebootState,
387 ) -> Result<EngineInstance, String> {
388 let (tx, rx) = channel(10_000);
389
390 let pipeline_guard = pipeline.try_lock().unwrap();
391 let category = pipeline_guard.category();
392 let metadata = pipeline_guard.get_metadata();
393 let kind = metadata.kind.clone();
394 let device = pipeline_guard.device();
395 let modalities = metadata.modalities.clone();
396 let max_seq_len = match &category {
397 ModelCategory::Diffusion | ModelCategory::Speech => None,
398 _ => Some(metadata.max_seq_len),
399 };
400 drop(pipeline_guard);
401
402 info!("Pipeline input modalities are {:?}", &modalities.input);
403 info!("Pipeline output modalities are {:?}", &modalities.output);
404
405 let mistralrs_config = MistralRsConfig {
406 kind,
407 device,
408 category: category.clone(),
409 modalities,
410 max_seq_len,
411 };
412
413 let tx_for_engine = tx.clone();
414 let engine_handler = thread::spawn(move || {
415 #[cfg(feature = "metal")]
416 objc::rc::autoreleasepool(move || {
417 let rt = Runtime::new().unwrap();
418 rt.block_on(async move {
419 let engine = Engine::new(
420 tx_for_engine,
421 rx,
422 pipeline,
423 method,
424 config.no_kv_cache,
425 config.no_prefix_cache,
426 config.prefix_cache_n,
427 config.disable_eos_stop,
428 config.throughput_logging_enabled,
429 config.search_embedding_model,
430 config.search_callback.clone(),
431 config.tool_callbacks.clone(),
432 config.tool_callbacks_with_tools.clone(),
433 )
434 .expect("Engine creation failed.");
435 Arc::new(engine).run().await;
436 })
437 });
438
439 #[cfg(not(feature = "metal"))]
440 {
441 let rt = Runtime::new().unwrap();
442 rt.block_on(async move {
443 let engine = Engine::new(
444 tx_for_engine,
445 rx,
446 pipeline,
447 method,
448 config.no_kv_cache,
449 config.no_prefix_cache,
450 config.prefix_cache_n,
451 config.disable_eos_stop,
452 config.throughput_logging_enabled,
453 config.search_embedding_model,
454 config.search_callback.clone(),
455 config.tool_callbacks.clone(),
456 config.tool_callbacks_with_tools.clone(),
457 )
458 .expect("Engine creation failed.");
459 Arc::new(engine).run().await;
460 })
461 }
462 });
463
464 Ok(EngineInstance {
465 sender: tx,
466 engine_handler,
467 reboot_state,
468 config: mistralrs_config,
469 category,
470 })
471 }
472
473 async fn new(config: MistralRsBuilder) -> Arc<Self> {
474 let MistralRsBuilder {
475 pipeline,
476 method,
477 log,
478 no_kv_cache,
479 no_prefix_cache,
480 prefix_cache_n,
481 disable_eos_stop,
482 throughput_logging_enabled,
483 search_embedding_model,
484 search_callback,
485 tool_callbacks,
486 mut tool_callbacks_with_tools,
487 mcp_client_config,
488 } = config;
489
490 mistralrs_quant::cublaslt::maybe_init_cublas_lt_wrapper(
491 get_mut_arcmutex!(pipeline).device(),
492 );
493
494 let method = if !get_mut_arcmutex!(pipeline).get_metadata().no_kv_cache
497 && get_mut_arcmutex!(pipeline).cache().is_hybrid()
498 {
499 info!(
500 "Hybrid model detected (Mamba-Attention), enforcing batch_size=1 for correctness"
501 );
502 SchedulerConfig::DefaultScheduler {
503 method: DefaultSchedulerMethod::Fixed(NonZeroUsize::new(1).unwrap()),
504 }
505 } else {
506 method
507 };
508
509 let no_kv_cache = no_kv_cache.unwrap_or(false);
510 let no_prefix_cache = no_prefix_cache.unwrap_or(false);
511 let prefix_cache_n = prefix_cache_n.unwrap_or(16);
512 let disable_eos_stop = disable_eos_stop.unwrap_or(false);
513
514 if let Some(config) = &mcp_client_config {
516 let mut mcp_client = McpClient::new(config.clone());
517 let total_servers = config.servers.len();
518
519 match mcp_client.initialize().await {
520 Ok(()) => {
521 let mcp_callbacks_with_tools = mcp_client.get_tool_callbacks_with_tools();
522 let tools_count = mcp_callbacks_with_tools.len();
523
524 for (name, callback_with_tool) in mcp_callbacks_with_tools {
526 tool_callbacks_with_tools.insert(name.clone(), callback_with_tool.clone());
527 }
528
529 if tools_count == 0 {
530 warn!(
531 "MCP client initialized but no tools were registered from {} servers",
532 total_servers
533 );
534 } else {
535 info!(
536 "MCP client initialized successfully with {} tools from {} servers",
537 tools_count, total_servers
538 );
539 }
540 }
541 Err(e) => {
542 warn!(
543 "Failed to initialize MCP client with {} configured servers: {}",
544 total_servers, e
545 );
546 warn!("Continuing without MCP functionality. Check your MCP configuration and server availability.");
547 }
548 }
549 }
550
551 let reboot_state = RebootState {
552 pipeline: pipeline.clone(),
553 method: method.clone(),
554 no_kv_cache,
555 no_prefix_cache,
556 prefix_cache_n,
557 disable_eos_stop,
558 throughput_logging_enabled,
559 search_embedding_model,
560 search_callback: search_callback.clone(),
561 tool_callbacks: tool_callbacks.clone(),
562 tool_callbacks_with_tools: tool_callbacks_with_tools.clone(),
563 mcp_client_config: mcp_client_config.clone(),
564 };
565
566 let engine_config = EngineConfig {
568 no_kv_cache,
569 no_prefix_cache,
570 prefix_cache_n,
571 disable_eos_stop,
572 throughput_logging_enabled,
573 search_embedding_model,
574 search_callback,
575 tool_callbacks,
576 tool_callbacks_with_tools,
577 };
578
579 let engine_instance =
581 Self::create_engine_instance(pipeline.clone(), method, engine_config, reboot_state)
582 .expect("Failed to create engine instance");
583
584 let id = pipeline.try_lock().unwrap().name();
585
586 if distributed::is_daemon() {
587 let request_sender = engine_instance.sender.clone();
588
589 if cfg!(feature = "ring") {
590 distributed::ring_daemon_replicator(request_sender);
592 } else {
593 distributed::nccl_daemon_replicator(request_sender);
595 }
596
597 #[allow(clippy::empty_loop)]
598 loop {}
599 }
600
601 let is_multi_threaded = tokio::runtime::Handle::try_current()
603 .is_ok_and(|h| h.runtime_flavor() != tokio::runtime::RuntimeFlavor::CurrentThread);
604
605 if !distributed::is_daemon()
607 && is_multi_threaded
608 && matches!(
609 engine_instance.category,
610 ModelCategory::Text | ModelCategory::Vision { .. }
611 )
612 {
613 let clone_sender = engine_instance.sender.clone();
614 tokio::task::block_in_place(|| {
615 let (tx, mut rx) = channel(1);
616 let req = Request::Normal(Box::new(NormalRequest {
617 id: 0,
618 messages: RequestMessage::Completion {
619 text: "hello".to_string(),
620 echo_prompt: false,
621 best_of: None,
622 },
623 sampling_params: SamplingParams {
624 max_len: Some(1),
625 ..SamplingParams::deterministic()
626 },
627 response: tx,
628 return_logprobs: false,
629 is_streaming: false,
630 constraint: Constraint::None,
631 suffix: None,
632 tool_choice: None,
633 tools: None,
634 logits_processors: None,
635 return_raw_logits: false,
636 web_search_options: None,
637 model_id: None,
638 truncate_sequence: false,
639 }));
640 info!("Beginning dummy run.");
641 let start = Instant::now();
642 clone_sender.blocking_send(req).unwrap();
643
644 let mut received_any = false;
646 while let Some(_resp) = rx.blocking_recv() {
647 received_any = true;
648 }
649
650 if received_any {
651 let end = Instant::now();
652 info!(
653 "Dummy run completed in {}s.",
654 end.duration_since(start).as_secs_f64()
655 );
656 } else {
657 warn!("Dummy run failed!");
658 }
659 });
660 }
661
662 let mut engines = HashMap::new();
664 engines.insert(id.clone(), engine_instance);
665
666 Arc::new(Self {
667 engines: RwLock::new(engines),
668 default_engine_id: RwLock::new(Some(id.clone())),
669 log,
670 id,
671 creation_time: SystemTime::now()
672 .duration_since(UNIX_EPOCH)
673 .expect("Time travel has occurred!")
674 .as_secs(),
675 next_request_id: Mutex::new(RefCell::new(1)),
676 })
677 }
678
679 fn reboot_engine(&self, model_id: &str) -> Result<(), MistralRsError> {
681 let mut engines = self.engines.write().map_err(|_| {
682 tracing::warn!("Couldn't get write lock on engines during reboot attempt");
683 MistralRsError::EnginePoisoned
684 })?;
685
686 if let Some(engine_instance) = engines.get(model_id) {
687 if !engine_instance.engine_handler.is_finished() {
688 tracing::info!("Engine {} already running, returning ok", model_id);
689 return Ok(());
690 }
691
692 let reboot_state = engine_instance.reboot_state.clone();
693 let engine_config = EngineConfig {
694 no_kv_cache: reboot_state.no_kv_cache,
695 no_prefix_cache: reboot_state.no_prefix_cache,
696 prefix_cache_n: reboot_state.prefix_cache_n,
697 disable_eos_stop: reboot_state.disable_eos_stop,
698 throughput_logging_enabled: reboot_state.throughput_logging_enabled,
699 search_embedding_model: reboot_state.search_embedding_model,
700 search_callback: reboot_state.search_callback.clone(),
701 tool_callbacks: reboot_state.tool_callbacks.clone(),
702 tool_callbacks_with_tools: reboot_state.tool_callbacks_with_tools.clone(),
703 };
704 let new_engine_instance = Self::create_engine_instance(
705 reboot_state.pipeline.clone(),
706 reboot_state.method.clone(),
707 engine_config,
708 reboot_state,
709 )
710 .map_err(|e| {
711 tracing::error!("Failed to create new engine instance: {}", e);
712 MistralRsError::EnginePoisoned
713 })?;
714
715 engines.insert(model_id.to_string(), new_engine_instance);
716 tracing::info!("Successfully rebooted engine {}", model_id);
717 Ok(())
718 } else {
719 Err(MistralRsError::EnginePoisoned)
720 }
721 }
722
723 fn engine_dead(&self, model_id: &str) -> Result<bool, MistralRsError> {
724 let engines = self.engines.read().map_err(|_| {
725 tracing::warn!("Couldn't get read lock on engines!");
726 MistralRsError::EnginePoisoned
727 })?;
728
729 if let Some(engine_instance) = engines.get(model_id) {
730 Ok(engine_instance.engine_handler.is_finished())
731 } else {
732 Err(MistralRsError::EnginePoisoned)
733 }
734 }
735
736 pub fn get_sender(&self, model_id: Option<&str>) -> Result<Sender<Request>, MistralRsError> {
738 let resolved_model_id = match model_id {
739 Some(id) => id.to_string(),
740 None => {
741 let default_lock = self
742 .default_engine_id
743 .read()
744 .map_err(|_| MistralRsError::SenderPoisoned)?;
745 default_lock
746 .as_ref()
747 .ok_or(MistralRsError::EnginePoisoned)?
748 .clone()
749 }
750 };
751
752 if self.engine_dead(&resolved_model_id)? {
753 tracing::warn!("Engine {} is dead, rebooting", resolved_model_id);
754 self.reboot_engine(&resolved_model_id)?
755 }
756
757 let engines = self
758 .engines
759 .read()
760 .map_err(|_| MistralRsError::SenderPoisoned)?;
761 if let Some(engine_instance) = engines.get(&resolved_model_id) {
762 Ok(engine_instance.sender.clone())
763 } else {
764 Err(MistralRsError::EnginePoisoned)
765 }
766 }
767
768 pub fn get_id(&self) -> String {
769 self.id.clone()
770 }
771
772 pub fn get_creation_time(&self) -> u64 {
773 self.creation_time
774 }
775
776 pub fn get_model_category(
778 &self,
779 model_id: Option<&str>,
780 ) -> Result<ModelCategory, MistralRsError> {
781 let resolved_model_id = match model_id {
782 Some(id) => id.to_string(),
783 None => {
784 let default_lock = self
785 .default_engine_id
786 .read()
787 .map_err(|_| MistralRsError::SenderPoisoned)?;
788 default_lock
789 .as_ref()
790 .ok_or(MistralRsError::EnginePoisoned)?
791 .clone()
792 }
793 };
794
795 let engines = self
796 .engines
797 .read()
798 .map_err(|_| MistralRsError::SenderPoisoned)?;
799 if let Some(engine_instance) = engines.get(&resolved_model_id) {
800 Ok(engine_instance.category.clone())
801 } else {
802 Err(MistralRsError::EnginePoisoned)
803 }
804 }
805
806 pub fn max_sequence_length(
808 &self,
809 model_id: Option<&str>,
810 ) -> Result<Option<usize>, MistralRsError> {
811 let resolved_model_id = match model_id {
812 Some(id) => id.to_string(),
813 None => {
814 let default_lock = self
815 .default_engine_id
816 .read()
817 .map_err(|_| MistralRsError::SenderPoisoned)?;
818 default_lock
819 .as_ref()
820 .ok_or(MistralRsError::EnginePoisoned)?
821 .clone()
822 }
823 };
824
825 let engines = self
826 .engines
827 .read()
828 .map_err(|_| MistralRsError::SenderPoisoned)?;
829 if let Some(engine_instance) = engines.get(&resolved_model_id) {
830 Ok(engine_instance.config.max_seq_len)
831 } else {
832 Err(MistralRsError::EnginePoisoned)
833 }
834 }
835
836 pub fn next_request_id(&self) -> usize {
837 let l = self.next_request_id.lock().unwrap();
838 let last = &mut *l.borrow_mut();
839 let last_v = *last;
840 *last += 1;
841 last_v
842 }
843
844 pub async fn add_model(
846 &self,
847 model_id: String,
848 pipeline: Arc<tokio::sync::Mutex<dyn Pipeline>>,
849 method: SchedulerConfig,
850 config: AddModelConfig,
851 ) -> Result<(), String> {
852 let method = {
854 let pipeline_guard = pipeline.try_lock().unwrap();
855 if !pipeline_guard.get_metadata().no_kv_cache && pipeline_guard.cache().is_hybrid() {
856 info!(
857 "Hybrid model detected (Mamba-Attention), enforcing batch_size=1 for correctness"
858 );
859 SchedulerConfig::DefaultScheduler {
860 method: DefaultSchedulerMethod::Fixed(NonZeroUsize::new(1).unwrap()),
861 }
862 } else {
863 method
864 }
865 };
866
867 let reboot_state = RebootState {
868 pipeline: pipeline.clone(),
869 method: method.clone(),
870 no_kv_cache: config.engine_config.no_kv_cache,
871 no_prefix_cache: config.engine_config.no_prefix_cache,
872 prefix_cache_n: config.engine_config.prefix_cache_n,
873 disable_eos_stop: config.engine_config.disable_eos_stop,
874 throughput_logging_enabled: config.engine_config.throughput_logging_enabled,
875 search_embedding_model: config.engine_config.search_embedding_model,
876 search_callback: config.engine_config.search_callback.clone(),
877 tool_callbacks: config.engine_config.tool_callbacks.clone(),
878 tool_callbacks_with_tools: config.engine_config.tool_callbacks_with_tools.clone(),
879 mcp_client_config: config.mcp_client_config.clone(),
880 };
881
882 let engine_instance =
883 Self::create_engine_instance(pipeline, method, config.engine_config, reboot_state)?;
884
885 let mut engines = self
886 .engines
887 .write()
888 .map_err(|_| "Failed to acquire write lock on engines")?;
889 engines.insert(model_id.clone(), engine_instance);
890
891 if engines.len() == 1 {
893 let mut default_lock = self
894 .default_engine_id
895 .write()
896 .map_err(|_| "Failed to acquire write lock on default_engine_id")?;
897 *default_lock = Some(model_id.clone());
898 }
899
900 Ok(())
901 }
902
903 pub fn remove_model(&self, model_id: &str) -> Result<(), String> {
905 let mut engines = self
906 .engines
907 .write()
908 .map_err(|_| "Failed to acquire write lock on engines")?;
909
910 if engines.len() <= 1 {
911 return Err("Cannot remove the last model from MistralRs".to_string());
912 }
913
914 if let Some(engine_instance) = engines.remove(model_id) {
915 let _ = engine_instance.sender.blocking_send(Request::Terminate);
917
918 let mut default_lock = self
920 .default_engine_id
921 .write()
922 .map_err(|_| "Failed to acquire write lock on default_engine_id")?;
923 if let Some(ref default_id) = *default_lock {
924 if default_id == model_id {
925 *default_lock = engines.keys().next().cloned();
927 }
928 }
929
930 Ok(())
931 } else {
932 Err(format!("Model {model_id} not found"))
933 }
934 }
935
936 pub fn list_models(&self) -> Result<Vec<String>, String> {
938 let engines = self
939 .engines
940 .read()
941 .map_err(|_| "Failed to acquire read lock on engines")?;
942 Ok(engines.keys().cloned().collect())
943 }
944
945 pub fn get_default_model_id(&self) -> Result<Option<String>, String> {
947 let default_lock = self
948 .default_engine_id
949 .read()
950 .map_err(|_| "Failed to acquire read lock on default_engine_id")?;
951 Ok(default_lock.clone())
952 }
953
954 pub fn set_default_model_id(&self, model_id: &str) -> Result<(), String> {
956 let engines = self
957 .engines
958 .read()
959 .map_err(|_| "Failed to acquire read lock on engines")?;
960 if !engines.contains_key(model_id) {
961 return Err(format!("Model {model_id} not found"));
962 }
963 drop(engines);
964
965 let mut default_lock = self
966 .default_engine_id
967 .write()
968 .map_err(|_| "Failed to acquire write lock on default_engine_id")?;
969 *default_lock = Some(model_id.to_string());
970
971 Ok(())
972 }
973
974 pub fn send_request(&self, mut request: Request) -> Result<(), MistralRsError> {
976 let model_id = match &mut request {
977 Request::Normal(normal_req) => normal_req.model_id.as_deref(),
978 _ => None, };
980
981 let sender = self.get_sender(model_id)?;
982 sender
983 .blocking_send(request)
984 .map_err(|_| MistralRsError::SenderPoisoned)
985 }
986
987 pub fn maybe_log_request(this: Arc<Self>, repr: String) {
988 if let Some(file) = &this.log {
989 let mut f = OpenOptions::new()
990 .append(true)
991 .create(true) .open(file)
993 .expect("Unable to open file");
994 let time = chrono::offset::Local::now();
995 f.write_all(format!("Request at {time}: {repr}\n\n").as_bytes())
996 .expect("Unable to write data");
997 }
998 }
999
1000 pub fn maybe_log_response<T: Serialize>(this: Arc<Self>, resp: &T) {
1001 if let Some(file) = &this.log {
1002 let mut f = OpenOptions::new()
1003 .append(true)
1004 .create(true) .open(file)
1006 .expect("Unable to open file");
1007 let time = chrono::offset::Local::now();
1008 let repr = serde_json::to_string(resp).expect("Serialization of response failed.");
1009 f.write_all(format!("Response at {time}: {repr}\n\n").as_bytes())
1010 .expect("Unable to write data");
1011 }
1012 }
1013
1014 pub fn maybe_log_error(this: Arc<Self>, err: &dyn Error) {
1015 if let Some(file) = &this.log {
1016 let mut f = OpenOptions::new()
1017 .append(true)
1018 .create(true) .open(file)
1020 .expect("Unable to open file");
1021 let time = chrono::offset::Local::now();
1022 f.write_all(format!("Error response at {time}: {err}\n\n").as_bytes())
1023 .expect("Unable to write data");
1024 }
1025 }
1026
1027 pub fn get_tools_count(&self, model_id: Option<&str>) -> Result<usize, String> {
1029 let resolved_model_id = match model_id {
1030 Some(id) => id.to_string(),
1031 None => {
1032 let default_lock = self
1033 .default_engine_id
1034 .read()
1035 .map_err(|_| "Failed to acquire read lock")?;
1036 default_lock
1037 .as_ref()
1038 .ok_or("No default engine set")?
1039 .clone()
1040 }
1041 };
1042
1043 let engines = self
1044 .engines
1045 .read()
1046 .map_err(|_| "Failed to acquire read lock on engines")?;
1047 if let Some(engine_instance) = engines.get(&resolved_model_id) {
1048 Ok(engine_instance.reboot_state.tool_callbacks_with_tools.len())
1049 } else {
1050 Err(format!("Model {resolved_model_id} not found"))
1051 }
1052 }
1053
1054 pub fn has_mcp_client(&self, model_id: Option<&str>) -> Result<bool, String> {
1056 let resolved_model_id = match model_id {
1057 Some(id) => id.to_string(),
1058 None => {
1059 let default_lock = self
1060 .default_engine_id
1061 .read()
1062 .map_err(|_| "Failed to acquire read lock")?;
1063 default_lock
1064 .as_ref()
1065 .ok_or("No default engine set")?
1066 .clone()
1067 }
1068 };
1069
1070 let engines = self
1071 .engines
1072 .read()
1073 .map_err(|_| "Failed to acquire read lock on engines")?;
1074 if let Some(engine_instance) = engines.get(&resolved_model_id) {
1075 Ok(engine_instance.reboot_state.mcp_client_config.is_some())
1076 } else {
1077 Err(format!("Model {resolved_model_id} not found"))
1078 }
1079 }
1080
1081 pub fn config(&self, model_id: Option<&str>) -> Result<MistralRsConfig, String> {
1083 let resolved_model_id = match model_id {
1084 Some(id) => id.to_string(),
1085 None => {
1086 let default_lock = self
1087 .default_engine_id
1088 .read()
1089 .map_err(|_| "Failed to acquire read lock")?;
1090 default_lock
1091 .as_ref()
1092 .ok_or("No default engine set")?
1093 .clone()
1094 }
1095 };
1096
1097 let engines = self
1098 .engines
1099 .read()
1100 .map_err(|_| "Failed to acquire read lock on engines")?;
1101 if let Some(engine_instance) = engines.get(&resolved_model_id) {
1102 Ok(engine_instance.config.clone())
1103 } else {
1104 Err(format!("Model {resolved_model_id} not found"))
1105 }
1106 }
1107}