1#![deny(clippy::cast_possible_truncation, clippy::cast_precision_loss)]
2use candle_core::Device;
3use engine::Engine;
4pub use engine::{
5 get_engine_terminate_flag, reset_engine_terminate_flag, should_terminate_engine_sequences,
6 EngineInstruction, SearchEmbeddingModel, ENGINE_INSTRUCTIONS, TERMINATE_ALL_NEXT_STEP,
7};
8use hf_hub::Cache;
9pub use lora::Ordering;
10pub use pipeline::ModelCategory;
11pub use pipeline::Pipeline;
12#[cfg(feature = "pyo3_macros")]
13use pyo3::exceptions::PyValueError;
14use std::collections::HashMap;
15use std::num::NonZeroUsize;
16use std::sync::OnceLock;
17use std::time::Instant;
18use std::{
19 cell::RefCell,
20 error::Error,
21 fs::OpenOptions,
22 io::Write,
23 sync::{atomic::AtomicBool, Arc, Mutex, RwLock},
24 thread::{self, JoinHandle},
25 time::{SystemTime, UNIX_EPOCH},
26};
27use tokio::sync::mpsc::{channel, Sender};
28use tracing::info;
29use tracing::warn;
30
31mod cuda;
32mod device_map;
33mod engine;
34mod lora;
35mod model_loader;
36mod moe;
37mod ops;
38pub use model_loader::{
39 get_auto_device_map_params, get_model_dtype, get_tgt_non_granular_index, LoaderBuilder,
40};
41mod embedding_models;
42mod kv_cache;
43mod search;
44
45mod model_selected;
46pub use model_selected::ModelSelected;
47pub use toml_selector::{get_toml_selected_model_device_map_params, get_toml_selected_model_dtype};
48
49mod amoe;
50mod attention;
51mod diffusion_models;
52pub mod distributed;
53mod gguf;
54pub mod harmony;
55pub mod layers;
56mod layers_masker;
57mod layers_utils;
58pub mod matformer;
59mod models;
60mod paged_attention;
61mod pipeline;
62mod prefix_cacher;
63mod request;
64mod response;
65mod sampler;
66mod scheduler;
67mod sequence;
68mod speech_models;
69pub mod think_tags;
70mod toml_selector;
71mod tools;
72mod topology;
73mod utils;
74mod vision_models;
75mod xlora_models;
76
77pub use amoe::{AnyMoeConfig, AnyMoeExpertType};
78pub use device_map::{
79 DeviceLayerMapMetadata, DeviceMapMetadata, DeviceMapSetting, LayerDeviceMapper,
80};
81pub use gguf::{GGUFArchitecture, GGUF_MULTI_FILE_DELIMITER};
82pub use mistralrs_audio::AudioInput;
83pub use mistralrs_mcp::{
84 CalledFunction, Function, Tool, ToolCallback, ToolCallbackWithTool, ToolType,
85};
86pub use mistralrs_mcp::{
87 McpClient, McpClientConfig, McpServerConfig, McpServerSource, McpToolInfo,
88};
89pub use mistralrs_quant::{IsqType, MULTI_LORA_DELIMITER};
90pub use paged_attention::{MemoryGpuConfig, PagedAttentionConfig, PagedCacheType};
91pub use pipeline::{
92 chat_template::ChatTemplate, parse_isq_value, AdapterPaths, AnyMoeLoader, AnyMoePipeline,
93 AutoDeviceMapParams, AutoLoader, AutoLoaderBuilder, DiffusionGenerationParams, DiffusionLoader,
94 DiffusionLoaderBuilder, DiffusionLoaderType, EmbeddingLoader, EmbeddingLoaderBuilder,
95 EmbeddingLoaderType, EmbeddingModelPaths, EmbeddingSpecificConfig, GGMLLoader,
96 GGMLLoaderBuilder, GGMLSpecificConfig, GGUFLoader, GGUFLoaderBuilder, GGUFSpecificConfig,
97 GemmaLoader, Idefics2Loader, IsqOrganization, LLaVALoader, LLaVANextLoader, LlamaLoader,
98 Loader, LocalModelPaths, LoraAdapterPaths, MistralLoader, MixtralLoader, Modalities, ModelKind,
99 ModelPaths, MultimodalPromptPrefixer, NormalLoader, NormalLoaderBuilder, NormalLoaderType,
100 NormalSpecificConfig, Phi2Loader, Phi3Loader, Phi3VLoader, Qwen2Loader, SpeculativeConfig,
101 SpeculativeLoader, SpeculativePipeline, SpeechLoader, SpeechPipeline, Starcoder2Loader,
102 SupportedModality, TokenSource, VisionLoader, VisionLoaderBuilder, VisionLoaderType,
103 VisionSpecificConfig, UQFF_MULTI_FILE_DELIMITER,
104};
105pub use request::{
106 ApproximateUserLocation, Constraint, DetokenizationRequest, ImageGenerationResponseFormat,
107 LlguidanceGrammar, MessageContent, NormalRequest, ReasoningEffort, Request, RequestMessage,
108 SearchContextSize, TokenizationRequest, WebSearchOptions, WebSearchUserLocation,
109};
110pub use response::*;
111pub use sampler::{
112 CustomLogitsProcessor, DrySamplingParams, SamplingParams, StopTokens, TopLogprob,
113};
114pub use scheduler::{DefaultSchedulerMethod, SchedulerConfig};
115pub use search::{SearchCallback, SearchFunctionParameters, SearchResult};
116use serde::Serialize;
117pub use speech_models::{utils as speech_utils, SpeechGenerationConfig, SpeechLoaderType};
118use tokio::runtime::Runtime;
119use toml_selector::{TomlLoaderArgs, TomlSelector};
120pub use tools::{ToolCallResponse, ToolCallType, ToolCallbacks, ToolChoice};
121pub use topology::{LayerTopology, Topology};
122pub use utils::debug::initialize_logging;
123pub use utils::memory_usage::MemoryUsage;
124pub use utils::normal::{ModelDType, TryIntoDType};
125pub use utils::{paged_attn_supported, using_flash_attn};
126
127pub use llguidance;
129
/// Crate-wide debug flag; starts out `false`.
pub(crate) static DEBUG: AtomicBool = AtomicBool::new(false);
/// Process-wide Hugging Face Hub cache handle. Starts empty (`OnceLock::new()`);
/// NOTE(review): it is presumably initialized by loader code outside this file — confirm.
pub static GLOBAL_HF_CACHE: OnceLock<Cache> = OnceLock::new();
133
134#[derive(Clone)]
136pub struct EngineConfig {
137 pub no_kv_cache: bool,
138 pub no_prefix_cache: bool,
139 pub prefix_cache_n: usize,
140 pub disable_eos_stop: bool,
141 pub throughput_logging_enabled: bool,
142 pub search_embedding_model: Option<SearchEmbeddingModel>,
143 pub search_callback: Option<Arc<SearchCallback>>,
144 pub tool_callbacks: tools::ToolCallbacks,
145 pub tool_callbacks_with_tools: tools::ToolCallbacksWithTools,
146}
147
148impl Default for EngineConfig {
149 fn default() -> Self {
150 Self {
151 no_kv_cache: false,
152 no_prefix_cache: false,
153 prefix_cache_n: 16,
154 disable_eos_stop: false,
155 throughput_logging_enabled: true,
156 search_embedding_model: None,
157 search_callback: None,
158 tool_callbacks: HashMap::new(),
159 tool_callbacks_with_tools: HashMap::new(),
160 }
161 }
162}
163
164#[derive(Clone)]
166pub struct AddModelConfig {
167 pub engine_config: EngineConfig,
168 pub mcp_client_config: Option<McpClientConfig>,
169}
170
171impl AddModelConfig {
172 pub fn new(engine_config: EngineConfig) -> Self {
173 Self {
174 engine_config,
175 mcp_client_config: None,
176 }
177 }
178
179 pub fn with_mcp_config(mut self, mcp_config: McpClientConfig) -> Self {
180 self.mcp_client_config = Some(mcp_config);
181 self
182 }
183}
184
/// Snapshot of a loaded model's properties, captured from its pipeline when the
/// model's engine instance is created.
#[derive(Clone)]
pub struct MistralRsConfig {
    /// Model kind taken from the pipeline metadata.
    pub kind: ModelKind,
    /// Device reported by the pipeline.
    pub device: Device,
    /// Category reported by the pipeline.
    pub category: ModelCategory,
    /// Input/output modalities taken from the pipeline metadata.
    pub modalities: Modalities,
    /// `None` for diffusion and speech models; otherwise the pipeline metadata's
    /// `max_seq_len`.
    pub max_seq_len: Option<usize>,
}
193
/// One running model: its request channel, engine thread, and what is needed to
/// restart it.
struct EngineInstance {
    /// Channel used to submit requests to this engine.
    sender: Sender<Request>,
    /// Handle of the OS thread running the engine; `is_finished()` on it is how a
    /// dead engine is detected.
    engine_handler: JoinHandle<()>,
    /// Everything needed to recreate the engine after its thread has exited.
    reboot_state: RebootState,
    /// Model description returned by `MistralRs::config`.
    config: MistralRsConfig,
    /// Same value as `config.category`, kept alongside for direct access.
    category: ModelCategory,
}
202
/// Handle to one or more loaded models, each served by its own engine thread.
pub struct MistralRs {
    /// Loaded models keyed by model id.
    engines: RwLock<HashMap<String, EngineInstance>>,
    /// Model used when a request or query does not name one.
    default_engine_id: RwLock<Option<String>>,
    /// Optional path of a file that requests, responses and errors are appended to.
    log: Option<String>,
    /// Name of the pipeline this instance was built with.
    id: String,
    /// Unix timestamp (seconds) at construction.
    creation_time: u64,
    /// Next request id to hand out; starts at 1.
    // NOTE(review): the `RefCell` is redundant inside a `Mutex`; an `AtomicUsize`
    // would express the same counter without locking.
    next_request_id: Mutex<RefCell<usize>>,
}
215
/// Everything required to recreate an engine after its thread has exited: the
/// shared pipeline, the scheduler, and a copy of every engine setting.
#[derive(Clone)]
struct RebootState {
    pipeline: Arc<tokio::sync::Mutex<dyn Pipeline>>,
    method: SchedulerConfig,
    no_kv_cache: bool,
    no_prefix_cache: bool,
    prefix_cache_n: usize,
    disable_eos_stop: bool,
    throughput_logging_enabled: bool,
    search_embedding_model: Option<SearchEmbeddingModel>,
    search_callback: Option<Arc<search::SearchCallback>>,
    tool_callbacks: tools::ToolCallbacks,
    // Also reported by `MistralRs::get_tools_count`.
    tool_callbacks_with_tools: tools::ToolCallbacksWithTools,
    // Only recorded here; `MistralRs::has_mcp_client` reports whether it is set.
    mcp_client_config: Option<McpClientConfig>,
}
231
/// Errors returned by [`MistralRs`] accessors.
///
/// Besides a dead engine or sender, both variants are also returned for a missing
/// model id and for poisoned internal locks.
#[derive(Debug)]
pub enum MistralRsError {
    EnginePoisoned,
    SenderPoisoned,
}

impl std::fmt::Display for MistralRsError {
    /// Renders the bare variant name (the same text as the `Debug` output).
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let name = match self {
            MistralRsError::EnginePoisoned => "EnginePoisoned",
            MistralRsError::SenderPoisoned => "SenderPoisoned",
        };
        f.write_str(name)
    }
}

impl std::error::Error for MistralRsError {}

#[cfg(feature = "pyo3_macros")]
impl From<MistralRsError> for pyo3::PyErr {
    /// Surfaces the error to Python as a `ValueError` carrying the variant name.
    fn from(err: MistralRsError) -> Self {
        PyValueError::new_err(format!("{err:?}"))
    }
}
252
253pub struct MistralRsBuilder {
257 pipeline: Arc<tokio::sync::Mutex<dyn Pipeline>>,
258 method: SchedulerConfig,
259 log: Option<String>,
260 no_kv_cache: Option<bool>,
261 no_prefix_cache: Option<bool>,
262 prefix_cache_n: Option<usize>,
263 disable_eos_stop: Option<bool>,
264 throughput_logging_enabled: bool,
265 search_embedding_model: Option<SearchEmbeddingModel>,
266 search_callback: Option<Arc<SearchCallback>>,
267 tool_callbacks: tools::ToolCallbacks,
268 tool_callbacks_with_tools: tools::ToolCallbacksWithTools,
269 mcp_client_config: Option<McpClientConfig>,
270}
271
272impl MistralRsBuilder {
273 pub fn new(
277 pipeline: Arc<tokio::sync::Mutex<dyn Pipeline>>,
278 method: SchedulerConfig,
279 throughput_logging: bool,
280 search_embedding_model: Option<SearchEmbeddingModel>,
281 ) -> Self {
282 Self {
283 pipeline,
284 method,
285 log: None,
286 no_kv_cache: None,
287 no_prefix_cache: None,
288 prefix_cache_n: None,
289 disable_eos_stop: None,
290 throughput_logging_enabled: throughput_logging,
291 search_embedding_model,
292 search_callback: None,
293 tool_callbacks: HashMap::new(),
294 tool_callbacks_with_tools: HashMap::new(),
295 mcp_client_config: None,
296 }
297 }
298 pub fn with_log(mut self, log: String) -> Self {
299 self.log = Some(log);
300 self
301 }
302 pub fn with_opt_log(mut self, log: Option<String>) -> Self {
303 self.log = log;
304 self
305 }
306 pub fn with_no_kv_cache(mut self, no_kv_cache: bool) -> Self {
307 self.no_kv_cache = Some(no_kv_cache);
308 self
309 }
310 pub fn with_no_prefix_cache(mut self, no_prefix_cache: bool) -> Self {
311 self.no_prefix_cache = Some(no_prefix_cache);
312 self
313 }
314 pub fn with_prefix_cache_n(mut self, prefix_cache_n: usize) -> Self {
315 self.prefix_cache_n = Some(prefix_cache_n);
316 self
317 }
318 pub fn with_disable_eos_stop(mut self, disable_eos_stop: bool) -> Self {
319 self.disable_eos_stop = Some(disable_eos_stop);
320 self
321 }
322
323 pub fn with_search_callback(mut self, search_callback: Arc<SearchCallback>) -> Self {
325 self.search_callback = Some(search_callback);
326 self
327 }
328
329 pub fn with_tool_callback(
331 mut self,
332 name: impl Into<String>,
333 tool_callback: Arc<ToolCallback>,
334 ) -> Self {
335 self.tool_callbacks.insert(name.into(), tool_callback);
336 self
337 }
338
339 pub fn with_tool_callback_and_tool(
342 mut self,
343 name: impl Into<String>,
344 tool_callback: Arc<ToolCallback>,
345 tool: Tool,
346 ) -> Self {
347 let name = name.into();
348 self.tool_callbacks_with_tools.insert(
349 name,
350 ToolCallbackWithTool {
351 callback: tool_callback,
352 tool,
353 },
354 );
355 self
356 }
357
358 pub fn with_mcp_client(mut self, config: McpClientConfig) -> Self {
360 self.mcp_client_config = Some(config);
361 self
362 }
363
364 pub async fn build(self) -> Arc<MistralRs> {
365 MistralRs::new(self).await
366 }
367}
368
369impl Drop for MistralRs {
370 fn drop(&mut self) {
371 if let Ok(engines) = self.engines.read() {
373 for (_, engine) in engines.iter() {
374 let _ = engine.sender.try_send(Request::Terminate);
376 }
377 }
378 }
379}
380
381impl MistralRs {
382 fn create_engine_instance(
384 pipeline: Arc<tokio::sync::Mutex<dyn Pipeline>>,
385 method: SchedulerConfig,
386 config: EngineConfig,
387 reboot_state: RebootState,
388 ) -> Result<EngineInstance, String> {
389 let (tx, rx) = channel(10_000);
390
391 let pipeline_guard = pipeline.try_lock().unwrap();
392 let category = pipeline_guard.category();
393 let metadata = pipeline_guard.get_metadata();
394 let kind = metadata.kind.clone();
395 let device = pipeline_guard.device();
396 let modalities = metadata.modalities.clone();
397 let max_seq_len = match &category {
398 ModelCategory::Diffusion | ModelCategory::Speech => None,
399 _ => Some(metadata.max_seq_len),
400 };
401 drop(pipeline_guard);
402
403 info!("Pipeline input modalities are {:?}", &modalities.input);
404 info!("Pipeline output modalities are {:?}", &modalities.output);
405
406 let mistralrs_config = MistralRsConfig {
407 kind,
408 device,
409 category: category.clone(),
410 modalities,
411 max_seq_len,
412 };
413
414 let tx_for_engine = tx.clone();
415 let engine_handler = thread::spawn(move || {
416 #[cfg(feature = "metal")]
417 objc::rc::autoreleasepool(move || {
418 let rt = Runtime::new().unwrap();
419 rt.block_on(async move {
420 let engine = Engine::new(
421 tx_for_engine,
422 rx,
423 pipeline,
424 method,
425 config.no_kv_cache,
426 config.no_prefix_cache,
427 config.prefix_cache_n,
428 config.disable_eos_stop,
429 config.throughput_logging_enabled,
430 config.search_embedding_model,
431 config.search_callback.clone(),
432 config.tool_callbacks.clone(),
433 config.tool_callbacks_with_tools.clone(),
434 )
435 .expect("Engine creation failed.");
436 Arc::new(engine).run().await;
437 })
438 });
439
440 #[cfg(not(feature = "metal"))]
441 {
442 let rt = Runtime::new().unwrap();
443 rt.block_on(async move {
444 let engine = Engine::new(
445 tx_for_engine,
446 rx,
447 pipeline,
448 method,
449 config.no_kv_cache,
450 config.no_prefix_cache,
451 config.prefix_cache_n,
452 config.disable_eos_stop,
453 config.throughput_logging_enabled,
454 config.search_embedding_model,
455 config.search_callback.clone(),
456 config.tool_callbacks.clone(),
457 config.tool_callbacks_with_tools.clone(),
458 )
459 .expect("Engine creation failed.");
460 Arc::new(engine).run().await;
461 })
462 }
463 });
464
465 Ok(EngineInstance {
466 sender: tx,
467 engine_handler,
468 reboot_state,
469 config: mistralrs_config,
470 category,
471 })
472 }
473
474 async fn new(config: MistralRsBuilder) -> Arc<Self> {
475 let MistralRsBuilder {
476 pipeline,
477 method,
478 log,
479 no_kv_cache,
480 no_prefix_cache,
481 prefix_cache_n,
482 disable_eos_stop,
483 throughput_logging_enabled,
484 search_embedding_model,
485 search_callback,
486 tool_callbacks,
487 mut tool_callbacks_with_tools,
488 mcp_client_config,
489 } = config;
490
491 mistralrs_quant::cublaslt::maybe_init_cublas_lt_wrapper(
492 get_mut_arcmutex!(pipeline).device(),
493 );
494
495 let method = if !get_mut_arcmutex!(pipeline).get_metadata().no_kv_cache
498 && get_mut_arcmutex!(pipeline).cache().is_hybrid()
499 {
500 info!(
501 "Hybrid model detected (Mamba-Attention), enforcing batch_size=1 for correctness"
502 );
503 SchedulerConfig::DefaultScheduler {
504 method: DefaultSchedulerMethod::Fixed(NonZeroUsize::new(1).unwrap()),
505 }
506 } else {
507 method
508 };
509
510 let no_kv_cache = no_kv_cache.unwrap_or(false);
511 let no_prefix_cache = no_prefix_cache.unwrap_or(false);
512 let prefix_cache_n = prefix_cache_n.unwrap_or(16);
513 let disable_eos_stop = disable_eos_stop.unwrap_or(false);
514
515 if let Some(config) = &mcp_client_config {
517 let mut mcp_client = McpClient::new(config.clone());
518 let total_servers = config.servers.len();
519
520 match mcp_client.initialize().await {
521 Ok(()) => {
522 let mcp_callbacks_with_tools = mcp_client.get_tool_callbacks_with_tools();
523 let tools_count = mcp_callbacks_with_tools.len();
524
525 for (name, callback_with_tool) in mcp_callbacks_with_tools {
527 tool_callbacks_with_tools.insert(name.clone(), callback_with_tool.clone());
528 }
529
530 if tools_count == 0 {
531 warn!(
532 "MCP client initialized but no tools were registered from {} servers",
533 total_servers
534 );
535 } else {
536 info!(
537 "MCP client initialized successfully with {} tools from {} servers",
538 tools_count, total_servers
539 );
540 }
541 }
542 Err(e) => {
543 warn!(
544 "Failed to initialize MCP client with {} configured servers: {}",
545 total_servers, e
546 );
547 warn!("Continuing without MCP functionality. Check your MCP configuration and server availability.");
548 }
549 }
550 }
551
552 let reboot_state = RebootState {
553 pipeline: pipeline.clone(),
554 method: method.clone(),
555 no_kv_cache,
556 no_prefix_cache,
557 prefix_cache_n,
558 disable_eos_stop,
559 throughput_logging_enabled,
560 search_embedding_model,
561 search_callback: search_callback.clone(),
562 tool_callbacks: tool_callbacks.clone(),
563 tool_callbacks_with_tools: tool_callbacks_with_tools.clone(),
564 mcp_client_config: mcp_client_config.clone(),
565 };
566
567 let engine_config = EngineConfig {
569 no_kv_cache,
570 no_prefix_cache,
571 prefix_cache_n,
572 disable_eos_stop,
573 throughput_logging_enabled,
574 search_embedding_model,
575 search_callback,
576 tool_callbacks,
577 tool_callbacks_with_tools,
578 };
579
580 let engine_instance =
582 Self::create_engine_instance(pipeline.clone(), method, engine_config, reboot_state)
583 .expect("Failed to create engine instance");
584
585 let id = pipeline.try_lock().unwrap().name();
586
587 if distributed::is_daemon() {
588 let request_sender = engine_instance.sender.clone();
589
590 if cfg!(feature = "ring") {
591 distributed::ring_daemon_replicator(request_sender);
593 } else {
594 distributed::nccl_daemon_replicator(request_sender);
596 }
597
598 #[allow(clippy::empty_loop)]
599 loop {}
600 }
601
602 let is_multi_threaded = tokio::runtime::Handle::try_current()
604 .is_ok_and(|h| h.runtime_flavor() != tokio::runtime::RuntimeFlavor::CurrentThread);
605
606 if !distributed::is_daemon()
608 && is_multi_threaded
609 && matches!(
610 engine_instance.category,
611 ModelCategory::Text | ModelCategory::Vision { .. }
612 )
613 {
614 let clone_sender = engine_instance.sender.clone();
615 tokio::task::block_in_place(|| {
616 let (tx, mut rx) = channel(1);
617 let req = Request::Normal(Box::new(NormalRequest {
618 id: 0,
619 messages: RequestMessage::Completion {
620 text: "hello".to_string(),
621 echo_prompt: false,
622 best_of: None,
623 },
624 sampling_params: SamplingParams {
625 max_len: Some(1),
626 ..SamplingParams::deterministic()
627 },
628 response: tx,
629 return_logprobs: false,
630 is_streaming: false,
631 constraint: Constraint::None,
632 suffix: None,
633 tool_choice: None,
634 tools: None,
635 logits_processors: None,
636 return_raw_logits: false,
637 web_search_options: None,
638 model_id: None,
639 truncate_sequence: false,
640 }));
641 info!("Beginning dummy run.");
642 let start = Instant::now();
643 clone_sender.blocking_send(req).unwrap();
644
645 let mut received_any = false;
647 while let Some(_resp) = rx.blocking_recv() {
648 received_any = true;
649 }
650
651 if received_any {
652 let end = Instant::now();
653 info!(
654 "Dummy run completed in {}s.",
655 end.duration_since(start).as_secs_f64()
656 );
657 } else {
658 warn!("Dummy run failed!");
659 }
660 });
661 }
662
663 let mut engines = HashMap::new();
665 engines.insert(id.clone(), engine_instance);
666
667 Arc::new(Self {
668 engines: RwLock::new(engines),
669 default_engine_id: RwLock::new(Some(id.clone())),
670 log,
671 id,
672 creation_time: SystemTime::now()
673 .duration_since(UNIX_EPOCH)
674 .expect("Time travel has occurred!")
675 .as_secs(),
676 next_request_id: Mutex::new(RefCell::new(1)),
677 })
678 }
679
680 fn reboot_engine(&self, model_id: &str) -> Result<(), MistralRsError> {
682 let mut engines = self.engines.write().map_err(|_| {
683 tracing::warn!("Couldn't get write lock on engines during reboot attempt");
684 MistralRsError::EnginePoisoned
685 })?;
686
687 if let Some(engine_instance) = engines.get(model_id) {
688 if !engine_instance.engine_handler.is_finished() {
689 tracing::info!("Engine {} already running, returning ok", model_id);
690 return Ok(());
691 }
692
693 let reboot_state = engine_instance.reboot_state.clone();
694 let engine_config = EngineConfig {
695 no_kv_cache: reboot_state.no_kv_cache,
696 no_prefix_cache: reboot_state.no_prefix_cache,
697 prefix_cache_n: reboot_state.prefix_cache_n,
698 disable_eos_stop: reboot_state.disable_eos_stop,
699 throughput_logging_enabled: reboot_state.throughput_logging_enabled,
700 search_embedding_model: reboot_state.search_embedding_model,
701 search_callback: reboot_state.search_callback.clone(),
702 tool_callbacks: reboot_state.tool_callbacks.clone(),
703 tool_callbacks_with_tools: reboot_state.tool_callbacks_with_tools.clone(),
704 };
705 let new_engine_instance = Self::create_engine_instance(
706 reboot_state.pipeline.clone(),
707 reboot_state.method.clone(),
708 engine_config,
709 reboot_state,
710 )
711 .map_err(|e| {
712 tracing::error!("Failed to create new engine instance: {}", e);
713 MistralRsError::EnginePoisoned
714 })?;
715
716 engines.insert(model_id.to_string(), new_engine_instance);
717 tracing::info!("Successfully rebooted engine {}", model_id);
718 Ok(())
719 } else {
720 Err(MistralRsError::EnginePoisoned)
721 }
722 }
723
724 fn engine_dead(&self, model_id: &str) -> Result<bool, MistralRsError> {
725 let engines = self.engines.read().map_err(|_| {
726 tracing::warn!("Couldn't get read lock on engines!");
727 MistralRsError::EnginePoisoned
728 })?;
729
730 if let Some(engine_instance) = engines.get(model_id) {
731 Ok(engine_instance.engine_handler.is_finished())
732 } else {
733 Err(MistralRsError::EnginePoisoned)
734 }
735 }
736
737 pub fn get_sender(&self, model_id: Option<&str>) -> Result<Sender<Request>, MistralRsError> {
739 let resolved_model_id = match model_id {
740 Some(id) => id.to_string(),
741 None => {
742 let default_lock = self
743 .default_engine_id
744 .read()
745 .map_err(|_| MistralRsError::SenderPoisoned)?;
746 default_lock
747 .as_ref()
748 .ok_or(MistralRsError::EnginePoisoned)?
749 .clone()
750 }
751 };
752
753 if self.engine_dead(&resolved_model_id)? {
754 tracing::warn!("Engine {} is dead, rebooting", resolved_model_id);
755 self.reboot_engine(&resolved_model_id)?
756 }
757
758 let engines = self
759 .engines
760 .read()
761 .map_err(|_| MistralRsError::SenderPoisoned)?;
762 if let Some(engine_instance) = engines.get(&resolved_model_id) {
763 Ok(engine_instance.sender.clone())
764 } else {
765 Err(MistralRsError::EnginePoisoned)
766 }
767 }
768
769 pub fn get_id(&self) -> String {
770 self.id.clone()
771 }
772
773 pub fn get_creation_time(&self) -> u64 {
774 self.creation_time
775 }
776
777 pub fn get_model_category(
779 &self,
780 model_id: Option<&str>,
781 ) -> Result<ModelCategory, MistralRsError> {
782 let resolved_model_id = match model_id {
783 Some(id) => id.to_string(),
784 None => {
785 let default_lock = self
786 .default_engine_id
787 .read()
788 .map_err(|_| MistralRsError::SenderPoisoned)?;
789 default_lock
790 .as_ref()
791 .ok_or(MistralRsError::EnginePoisoned)?
792 .clone()
793 }
794 };
795
796 let engines = self
797 .engines
798 .read()
799 .map_err(|_| MistralRsError::SenderPoisoned)?;
800 if let Some(engine_instance) = engines.get(&resolved_model_id) {
801 Ok(engine_instance.category.clone())
802 } else {
803 Err(MistralRsError::EnginePoisoned)
804 }
805 }
806
807 pub fn max_sequence_length(
809 &self,
810 model_id: Option<&str>,
811 ) -> Result<Option<usize>, MistralRsError> {
812 let resolved_model_id = match model_id {
813 Some(id) => id.to_string(),
814 None => {
815 let default_lock = self
816 .default_engine_id
817 .read()
818 .map_err(|_| MistralRsError::SenderPoisoned)?;
819 default_lock
820 .as_ref()
821 .ok_or(MistralRsError::EnginePoisoned)?
822 .clone()
823 }
824 };
825
826 let engines = self
827 .engines
828 .read()
829 .map_err(|_| MistralRsError::SenderPoisoned)?;
830 if let Some(engine_instance) = engines.get(&resolved_model_id) {
831 Ok(engine_instance.config.max_seq_len)
832 } else {
833 Err(MistralRsError::EnginePoisoned)
834 }
835 }
836
837 pub fn next_request_id(&self) -> usize {
838 let l = self.next_request_id.lock().unwrap();
839 let last = &mut *l.borrow_mut();
840 let last_v = *last;
841 *last += 1;
842 last_v
843 }
844
845 pub async fn add_model(
847 &self,
848 model_id: String,
849 pipeline: Arc<tokio::sync::Mutex<dyn Pipeline>>,
850 method: SchedulerConfig,
851 config: AddModelConfig,
852 ) -> Result<(), String> {
853 let method = {
855 let pipeline_guard = pipeline.try_lock().unwrap();
856 if !pipeline_guard.get_metadata().no_kv_cache && pipeline_guard.cache().is_hybrid() {
857 info!(
858 "Hybrid model detected (Mamba-Attention), enforcing batch_size=1 for correctness"
859 );
860 SchedulerConfig::DefaultScheduler {
861 method: DefaultSchedulerMethod::Fixed(NonZeroUsize::new(1).unwrap()),
862 }
863 } else {
864 method
865 }
866 };
867
868 let reboot_state = RebootState {
869 pipeline: pipeline.clone(),
870 method: method.clone(),
871 no_kv_cache: config.engine_config.no_kv_cache,
872 no_prefix_cache: config.engine_config.no_prefix_cache,
873 prefix_cache_n: config.engine_config.prefix_cache_n,
874 disable_eos_stop: config.engine_config.disable_eos_stop,
875 throughput_logging_enabled: config.engine_config.throughput_logging_enabled,
876 search_embedding_model: config.engine_config.search_embedding_model,
877 search_callback: config.engine_config.search_callback.clone(),
878 tool_callbacks: config.engine_config.tool_callbacks.clone(),
879 tool_callbacks_with_tools: config.engine_config.tool_callbacks_with_tools.clone(),
880 mcp_client_config: config.mcp_client_config.clone(),
881 };
882
883 let engine_instance =
884 Self::create_engine_instance(pipeline, method, config.engine_config, reboot_state)?;
885
886 let mut engines = self
887 .engines
888 .write()
889 .map_err(|_| "Failed to acquire write lock on engines")?;
890 engines.insert(model_id.clone(), engine_instance);
891
892 if engines.len() == 1 {
894 let mut default_lock = self
895 .default_engine_id
896 .write()
897 .map_err(|_| "Failed to acquire write lock on default_engine_id")?;
898 *default_lock = Some(model_id.clone());
899 }
900
901 Ok(())
902 }
903
904 pub fn remove_model(&self, model_id: &str) -> Result<(), String> {
906 let mut engines = self
907 .engines
908 .write()
909 .map_err(|_| "Failed to acquire write lock on engines")?;
910
911 if engines.len() <= 1 {
912 return Err("Cannot remove the last model from MistralRs".to_string());
913 }
914
915 if let Some(engine_instance) = engines.remove(model_id) {
916 let _ = engine_instance.sender.blocking_send(Request::Terminate);
918
919 let mut default_lock = self
921 .default_engine_id
922 .write()
923 .map_err(|_| "Failed to acquire write lock on default_engine_id")?;
924 if let Some(ref default_id) = *default_lock {
925 if default_id == model_id {
926 *default_lock = engines.keys().next().cloned();
928 }
929 }
930
931 Ok(())
932 } else {
933 Err(format!("Model {model_id} not found"))
934 }
935 }
936
937 pub fn list_models(&self) -> Result<Vec<String>, String> {
939 let engines = self
940 .engines
941 .read()
942 .map_err(|_| "Failed to acquire read lock on engines")?;
943 Ok(engines.keys().cloned().collect())
944 }
945
946 pub fn get_default_model_id(&self) -> Result<Option<String>, String> {
948 let default_lock = self
949 .default_engine_id
950 .read()
951 .map_err(|_| "Failed to acquire read lock on default_engine_id")?;
952 Ok(default_lock.clone())
953 }
954
955 pub fn set_default_model_id(&self, model_id: &str) -> Result<(), String> {
957 let engines = self
958 .engines
959 .read()
960 .map_err(|_| "Failed to acquire read lock on engines")?;
961 if !engines.contains_key(model_id) {
962 return Err(format!("Model {model_id} not found"));
963 }
964 drop(engines);
965
966 let mut default_lock = self
967 .default_engine_id
968 .write()
969 .map_err(|_| "Failed to acquire write lock on default_engine_id")?;
970 *default_lock = Some(model_id.to_string());
971
972 Ok(())
973 }
974
975 pub fn send_request(&self, mut request: Request) -> Result<(), MistralRsError> {
977 let model_id = match &mut request {
978 Request::Normal(normal_req) => normal_req.model_id.as_deref(),
979 _ => None, };
981
982 let sender = self.get_sender(model_id)?;
983 sender
984 .blocking_send(request)
985 .map_err(|_| MistralRsError::SenderPoisoned)
986 }
987
988 pub fn maybe_log_request(this: Arc<Self>, repr: String) {
989 if let Some(file) = &this.log {
990 let mut f = OpenOptions::new()
991 .append(true)
992 .create(true) .open(file)
994 .expect("Unable to open file");
995 let time = chrono::offset::Local::now();
996 f.write_all(format!("Request at {time}: {repr}\n\n").as_bytes())
997 .expect("Unable to write data");
998 }
999 }
1000
1001 pub fn maybe_log_response<T: Serialize>(this: Arc<Self>, resp: &T) {
1002 if let Some(file) = &this.log {
1003 let mut f = OpenOptions::new()
1004 .append(true)
1005 .create(true) .open(file)
1007 .expect("Unable to open file");
1008 let time = chrono::offset::Local::now();
1009 let repr = serde_json::to_string(resp).expect("Serialization of response failed.");
1010 f.write_all(format!("Response at {time}: {repr}\n\n").as_bytes())
1011 .expect("Unable to write data");
1012 }
1013 }
1014
1015 pub fn maybe_log_error(this: Arc<Self>, err: &dyn Error) {
1016 if let Some(file) = &this.log {
1017 let mut f = OpenOptions::new()
1018 .append(true)
1019 .create(true) .open(file)
1021 .expect("Unable to open file");
1022 let time = chrono::offset::Local::now();
1023 f.write_all(format!("Error response at {time}: {err}\n\n").as_bytes())
1024 .expect("Unable to write data");
1025 }
1026 }
1027
1028 pub fn get_tools_count(&self, model_id: Option<&str>) -> Result<usize, String> {
1030 let resolved_model_id = match model_id {
1031 Some(id) => id.to_string(),
1032 None => {
1033 let default_lock = self
1034 .default_engine_id
1035 .read()
1036 .map_err(|_| "Failed to acquire read lock")?;
1037 default_lock
1038 .as_ref()
1039 .ok_or("No default engine set")?
1040 .clone()
1041 }
1042 };
1043
1044 let engines = self
1045 .engines
1046 .read()
1047 .map_err(|_| "Failed to acquire read lock on engines")?;
1048 if let Some(engine_instance) = engines.get(&resolved_model_id) {
1049 Ok(engine_instance.reboot_state.tool_callbacks_with_tools.len())
1050 } else {
1051 Err(format!("Model {resolved_model_id} not found"))
1052 }
1053 }
1054
1055 pub fn has_mcp_client(&self, model_id: Option<&str>) -> Result<bool, String> {
1057 let resolved_model_id = match model_id {
1058 Some(id) => id.to_string(),
1059 None => {
1060 let default_lock = self
1061 .default_engine_id
1062 .read()
1063 .map_err(|_| "Failed to acquire read lock")?;
1064 default_lock
1065 .as_ref()
1066 .ok_or("No default engine set")?
1067 .clone()
1068 }
1069 };
1070
1071 let engines = self
1072 .engines
1073 .read()
1074 .map_err(|_| "Failed to acquire read lock on engines")?;
1075 if let Some(engine_instance) = engines.get(&resolved_model_id) {
1076 Ok(engine_instance.reboot_state.mcp_client_config.is_some())
1077 } else {
1078 Err(format!("Model {resolved_model_id} not found"))
1079 }
1080 }
1081
1082 pub fn config(&self, model_id: Option<&str>) -> Result<MistralRsConfig, String> {
1084 let resolved_model_id = match model_id {
1085 Some(id) => id.to_string(),
1086 None => {
1087 let default_lock = self
1088 .default_engine_id
1089 .read()
1090 .map_err(|_| "Failed to acquire read lock")?;
1091 default_lock
1092 .as_ref()
1093 .ok_or("No default engine set")?
1094 .clone()
1095 }
1096 };
1097
1098 let engines = self
1099 .engines
1100 .read()
1101 .map_err(|_| "Failed to acquire read lock on engines")?;
1102 if let Some(engine_instance) = engines.get(&resolved_model_id) {
1103 Ok(engine_instance.config.clone())
1104 } else {
1105 Err(format!("Model {resolved_model_id} not found"))
1106 }
1107 }
1108}