1#![deny(clippy::cast_possible_truncation, clippy::cast_precision_loss)]
2use candle_core::Device;
3use engine::Engine;
4pub use engine::{
5 get_engine_terminate_flag, reset_engine_terminate_flag, should_terminate_engine_sequences,
6 EngineInstruction, SearchEmbeddingModel, ENGINE_INSTRUCTIONS, TERMINATE_ALL_NEXT_STEP,
7};
8use hf_hub::Cache;
9pub use lora::Ordering;
10pub use pipeline::ModelCategory;
11pub use pipeline::Pipeline;
12#[cfg(feature = "pyo3_macros")]
13use pyo3::exceptions::PyValueError;
14use std::collections::HashMap;
15use std::num::NonZeroUsize;
16use std::sync::OnceLock;
17use std::time::Instant;
18use std::{
19 cell::RefCell,
20 error::Error,
21 fs::OpenOptions,
22 io::Write,
23 sync::{atomic::AtomicBool, Arc, Mutex, RwLock},
24 thread::{self, JoinHandle},
25 time::{SystemTime, UNIX_EPOCH},
26};
27use tokio::sync::mpsc::{channel, Sender};
28use tracing::info;
29use tracing::warn;
30
31mod cuda;
32mod device_map;
33mod engine;
34mod lora;
35mod model_loader;
36mod moe;
37mod ops;
38pub use model_loader::{
39 get_auto_device_map_params, get_model_dtype, get_tgt_non_granular_index, LoaderBuilder,
40};
41mod embedding_models;
42mod kv_cache;
43mod search;
44
45mod model_selected;
46pub use model_selected::ModelSelected;
47pub use toml_selector::{get_toml_selected_model_device_map_params, get_toml_selected_model_dtype};
48
49mod amoe;
50mod attention;
51mod diffusion_models;
52pub mod distributed;
53mod gguf;
54pub mod layers;
55mod layers_masker;
56mod layers_utils;
57pub mod matformer;
58mod models;
59mod paged_attention;
60mod pipeline;
61mod prefix_cacher;
62mod request;
63mod response;
64mod sampler;
65mod scheduler;
66mod sequence;
67mod speech_models;
68mod toml_selector;
69mod tools;
70mod topology;
71mod utils;
72mod vision_models;
73mod xlora_models;
74
75pub use amoe::{AnyMoeConfig, AnyMoeExpertType};
76pub use device_map::{
77 DeviceLayerMapMetadata, DeviceMapMetadata, DeviceMapSetting, LayerDeviceMapper,
78};
79pub use gguf::{GGUFArchitecture, GGUF_MULTI_FILE_DELIMITER};
80pub use mistralrs_audio::AudioInput;
81pub use mistralrs_mcp::{
82 CalledFunction, Function, Tool, ToolCallback, ToolCallbackWithTool, ToolType,
83};
84pub use mistralrs_mcp::{
85 McpClient, McpClientConfig, McpServerConfig, McpServerSource, McpToolInfo,
86};
87pub use mistralrs_quant::{IsqType, MULTI_LORA_DELIMITER};
88pub use paged_attention::{MemoryGpuConfig, PagedAttentionConfig, PagedCacheType};
89pub use pipeline::{
90 chat_template::ChatTemplate, parse_isq_value, AdapterPaths, AnyMoeLoader, AnyMoePipeline,
91 AutoDeviceMapParams, AutoLoader, AutoLoaderBuilder, DiffusionGenerationParams, DiffusionLoader,
92 DiffusionLoaderBuilder, DiffusionLoaderType, EmbeddingLoader, EmbeddingLoaderBuilder,
93 EmbeddingLoaderType, EmbeddingModelPaths, EmbeddingSpecificConfig, GGMLLoader,
94 GGMLLoaderBuilder, GGMLSpecificConfig, GGUFLoader, GGUFLoaderBuilder, GGUFSpecificConfig,
95 GemmaLoader, Idefics2Loader, IsqOrganization, LLaVALoader, LLaVANextLoader, LlamaLoader,
96 Loader, LocalModelPaths, LoraAdapterPaths, MistralLoader, MixtralLoader, Modalities, ModelKind,
97 ModelPaths, MultimodalPromptPrefixer, NormalLoader, NormalLoaderBuilder, NormalLoaderType,
98 NormalSpecificConfig, Phi2Loader, Phi3Loader, Phi3VLoader, Qwen2Loader, SpeculativeConfig,
99 SpeculativeLoader, SpeculativePipeline, SpeechLoader, SpeechPipeline, Starcoder2Loader,
100 SupportedModality, TokenSource, VisionLoader, VisionLoaderBuilder, VisionLoaderType,
101 VisionSpecificConfig, UQFF_MULTI_FILE_DELIMITER,
102};
103pub use request::{
104 ApproximateUserLocation, Constraint, DetokenizationRequest, ImageGenerationResponseFormat,
105 LlguidanceGrammar, MessageContent, NormalRequest, Request, RequestMessage, SearchContextSize,
106 TokenizationRequest, WebSearchOptions, WebSearchUserLocation,
107};
108pub use response::*;
109pub use sampler::{
110 CustomLogitsProcessor, DrySamplingParams, SamplingParams, StopTokens, TopLogprob,
111};
112pub use scheduler::{DefaultSchedulerMethod, SchedulerConfig};
113pub use search::{SearchCallback, SearchFunctionParameters, SearchResult};
114use serde::Serialize;
115pub use speech_models::{utils as speech_utils, SpeechGenerationConfig, SpeechLoaderType};
116use tokio::runtime::Runtime;
117use toml_selector::{TomlLoaderArgs, TomlSelector};
118pub use tools::{ToolCallResponse, ToolCallType, ToolCallbacks, ToolChoice};
119pub use topology::{LayerTopology, Topology};
120pub use utils::debug::initialize_logging;
121pub use utils::memory_usage::MemoryUsage;
122pub use utils::normal::{ModelDType, TryIntoDType};
123pub use utils::{paged_attn_supported, using_flash_attn};
124
125pub use llguidance;
127
/// Crate-visible atomic boolean flag, initialised to `false`.
// NOTE(review): presumably toggled by the logging/debug setup code — confirm where it is written.
pub(crate) static DEBUG: AtomicBool = AtomicBool::new(false);
/// Process-wide slot for a Hugging Face hub [`Cache`].
///
/// The slot starts out empty; because it is a [`OnceLock`] it can be initialised at most once.
pub static GLOBAL_HF_CACHE: OnceLock<Cache> = OnceLock::new();
131
/// Per-engine settings.
///
/// Every field is handed, unchanged, to `Engine::new` when the engine thread for a model
/// is started. See the [`Default`] implementation for the baseline values.
#[derive(Clone)]
pub struct EngineConfig {
    /// Passed to the engine as its `no_kv_cache` argument. Defaults to `false`.
    pub no_kv_cache: bool,
    /// Passed to the engine as its `no_prefix_cache` argument. Defaults to `false`.
    pub no_prefix_cache: bool,
    /// Passed to the engine as its `prefix_cache_n` argument. Defaults to `16`.
    pub prefix_cache_n: usize,
    /// Passed to the engine as its `disable_eos_stop` argument. Defaults to `false`.
    pub disable_eos_stop: bool,
    /// Passed to the engine as its throughput-logging switch. Defaults to `true`.
    pub throughput_logging_enabled: bool,
    /// Optional embedding model selection forwarded to the engine. Defaults to `None`.
    pub search_embedding_model: Option<SearchEmbeddingModel>,
    /// Optional search callback forwarded to the engine. Defaults to `None`.
    pub search_callback: Option<Arc<SearchCallback>>,
    /// Tool callbacks keyed by name. Defaults to an empty map.
    pub tool_callbacks: tools::ToolCallbacks,
    /// Tool callbacks bundled with their tool definitions, keyed by name.
    /// Defaults to an empty map.
    pub tool_callbacks_with_tools: tools::ToolCallbacksWithTools,
}
145
146impl Default for EngineConfig {
147 fn default() -> Self {
148 Self {
149 no_kv_cache: false,
150 no_prefix_cache: false,
151 prefix_cache_n: 16,
152 disable_eos_stop: false,
153 throughput_logging_enabled: true,
154 search_embedding_model: None,
155 search_callback: None,
156 tool_callbacks: HashMap::new(),
157 tool_callbacks_with_tools: HashMap::new(),
158 }
159 }
160}
161
/// Configuration bundle accepted by [`MistralRs::add_model`].
#[derive(Clone)]
pub struct AddModelConfig {
    /// Settings for the engine that will serve the new model.
    pub engine_config: EngineConfig,
    /// Optional MCP client configuration; `add_model` records it alongside the model's
    /// other restart settings.
    pub mcp_client_config: Option<McpClientConfig>,
}
168
169impl AddModelConfig {
170 pub fn new(engine_config: EngineConfig) -> Self {
171 Self {
172 engine_config,
173 mcp_client_config: None,
174 }
175 }
176
177 pub fn with_mcp_config(mut self, mcp_config: McpClientConfig) -> Self {
178 self.mcp_client_config = Some(mcp_config);
179 self
180 }
181}
182
/// Snapshot of a loaded model's static properties, captured from its pipeline when the
/// engine instance is created and returned by [`MistralRs::config`].
#[derive(Clone)]
pub struct MistralRsConfig {
    /// Model kind, cloned from the pipeline metadata.
    pub kind: ModelKind,
    /// Device reported by the pipeline.
    pub device: Device,
    /// Model category reported by the pipeline.
    pub category: ModelCategory,
    /// Input/output modalities, cloned from the pipeline metadata.
    pub modalities: Modalities,
    /// Maximum sequence length from the pipeline metadata; `None` for the
    /// `Diffusion` and `Speech` categories.
    pub max_seq_len: Option<usize>,
}
191
/// Bookkeeping for one running engine: how to talk to it, how to tell whether it is
/// still alive, and how to restart it.
struct EngineInstance {
    /// Request channel into the engine thread.
    sender: Sender<Request>,
    /// Handle of the OS thread running the engine; `is_finished()` is used to detect a
    /// dead engine.
    engine_handler: JoinHandle<()>,
    /// Everything needed to recreate the engine after it has died.
    reboot_state: RebootState,
    /// Static model properties captured at creation time.
    config: MistralRsConfig,
    /// Model category (duplicated from `config.category`).
    category: ModelCategory,
}
200
/// Front-end handle that owns one engine per registered model and routes requests to them.
pub struct MistralRs {
    /// Running engines keyed by model id.
    engines: RwLock<HashMap<String, EngineInstance>>,
    /// Model id used when a caller does not name one explicitly.
    default_engine_id: RwLock<Option<String>>,
    /// Optional path of a file that the `maybe_log_*` helpers append to.
    log: Option<String>,
    /// Identifier of this instance: the name of the pipeline it was built with.
    id: String,
    /// Creation time in whole seconds since the Unix epoch.
    creation_time: u64,
    /// Counter backing [`MistralRs::next_request_id`]; starts at 1.
    next_request_id: Mutex<RefCell<usize>>,
}
213
/// Everything required to recreate an engine after its thread has exited.
///
/// `reboot_engine` rebuilds an [`EngineConfig`] from these fields and spawns a fresh
/// engine on the same pipeline with the same scheduler configuration.
#[derive(Clone)]
struct RebootState {
    /// Shared pipeline the engine runs on.
    pipeline: Arc<tokio::sync::Mutex<dyn Pipeline>>,
    /// Scheduler configuration the engine was started with.
    method: SchedulerConfig,
    /// Mirrors [`EngineConfig::no_kv_cache`].
    no_kv_cache: bool,
    /// Mirrors [`EngineConfig::no_prefix_cache`].
    no_prefix_cache: bool,
    /// Mirrors [`EngineConfig::prefix_cache_n`].
    prefix_cache_n: usize,
    /// Mirrors [`EngineConfig::disable_eos_stop`].
    disable_eos_stop: bool,
    /// Mirrors [`EngineConfig::throughput_logging_enabled`].
    throughput_logging_enabled: bool,
    /// Mirrors [`EngineConfig::search_embedding_model`].
    search_embedding_model: Option<SearchEmbeddingModel>,
    /// Mirrors [`EngineConfig::search_callback`].
    search_callback: Option<Arc<search::SearchCallback>>,
    /// Mirrors [`EngineConfig::tool_callbacks`].
    tool_callbacks: tools::ToolCallbacks,
    /// Mirrors [`EngineConfig::tool_callbacks_with_tools`]; its length is what
    /// `get_tools_count` reports.
    tool_callbacks_with_tools: tools::ToolCallbacksWithTools,
    /// MCP client configuration the model was registered with, if any; its presence is
    /// what `has_mcp_client` reports.
    mcp_client_config: Option<McpClientConfig>,
}
229
/// Errors returned by the request-routing methods of [`MistralRs`].
#[derive(Debug)]
pub enum MistralRsError {
    /// The engine table could not be locked, no default model is set, or the requested
    /// model id is not registered (or its engine could not be restarted).
    EnginePoisoned,
    /// A lock needed to resolve the request sender could not be acquired, or sending the
    /// request to the engine failed.
    SenderPoisoned,
}
235
236impl std::fmt::Display for MistralRsError {
237 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
238 write!(f, "{:?}", &self)
239 }
240}
241
242impl std::error::Error for MistralRsError {}
243
#[cfg(feature = "pyo3_macros")]
impl From<MistralRsError> for pyo3::PyErr {
    /// Surfaces the error to Python as a `ValueError` whose message is the error's
    /// `Debug` representation.
    fn from(value: MistralRsError) -> Self {
        let message = format!("{:?}", value);
        PyValueError::new_err(message)
    }
}
250
/// Builder for [`MistralRs`].
///
/// Create one with [`MistralRsBuilder::new`], adjust it with the `with_*` methods and
/// finish with [`MistralRsBuilder::build`]. Optional settings left as `None` are replaced
/// by defaults when the instance is built.
pub struct MistralRsBuilder {
    /// Pipeline the first engine will run.
    pipeline: Arc<tokio::sync::Mutex<dyn Pipeline>>,
    /// Scheduler configuration for the first engine.
    method: SchedulerConfig,
    /// Optional path of the request/response log file.
    log: Option<String>,
    /// `None` is treated as `false` at build time.
    no_kv_cache: Option<bool>,
    /// `None` is treated as `false` at build time.
    no_prefix_cache: Option<bool>,
    /// `None` is treated as `16` at build time.
    prefix_cache_n: Option<usize>,
    /// `None` is treated as `false` at build time.
    disable_eos_stop: Option<bool>,
    /// Whether the engine should log throughput.
    throughput_logging_enabled: bool,
    /// Optional embedding model selection forwarded to the engine.
    search_embedding_model: Option<SearchEmbeddingModel>,
    /// Optional search callback forwarded to the engine.
    search_callback: Option<Arc<SearchCallback>>,
    /// Tool callbacks keyed by name.
    tool_callbacks: tools::ToolCallbacks,
    /// Tool callbacks bundled with their tool definitions, keyed by name.
    tool_callbacks_with_tools: tools::ToolCallbacksWithTools,
    /// Optional MCP client configuration; when present, the client is initialised during
    /// build and its tools are merged into `tool_callbacks_with_tools`.
    mcp_client_config: Option<McpClientConfig>,
}
269
270impl MistralRsBuilder {
271 pub fn new(
275 pipeline: Arc<tokio::sync::Mutex<dyn Pipeline>>,
276 method: SchedulerConfig,
277 throughput_logging: bool,
278 search_embedding_model: Option<SearchEmbeddingModel>,
279 ) -> Self {
280 Self {
281 pipeline,
282 method,
283 log: None,
284 no_kv_cache: None,
285 no_prefix_cache: None,
286 prefix_cache_n: None,
287 disable_eos_stop: None,
288 throughput_logging_enabled: throughput_logging,
289 search_embedding_model,
290 search_callback: None,
291 tool_callbacks: HashMap::new(),
292 tool_callbacks_with_tools: HashMap::new(),
293 mcp_client_config: None,
294 }
295 }
296 pub fn with_log(mut self, log: String) -> Self {
297 self.log = Some(log);
298 self
299 }
300 pub fn with_opt_log(mut self, log: Option<String>) -> Self {
301 self.log = log;
302 self
303 }
304 pub fn with_no_kv_cache(mut self, no_kv_cache: bool) -> Self {
305 self.no_kv_cache = Some(no_kv_cache);
306 self
307 }
308 pub fn with_no_prefix_cache(mut self, no_prefix_cache: bool) -> Self {
309 self.no_prefix_cache = Some(no_prefix_cache);
310 self
311 }
312 pub fn with_prefix_cache_n(mut self, prefix_cache_n: usize) -> Self {
313 self.prefix_cache_n = Some(prefix_cache_n);
314 self
315 }
316 pub fn with_disable_eos_stop(mut self, disable_eos_stop: bool) -> Self {
317 self.disable_eos_stop = Some(disable_eos_stop);
318 self
319 }
320
321 pub fn with_search_callback(mut self, search_callback: Arc<SearchCallback>) -> Self {
323 self.search_callback = Some(search_callback);
324 self
325 }
326
327 pub fn with_tool_callback(
329 mut self,
330 name: impl Into<String>,
331 tool_callback: Arc<ToolCallback>,
332 ) -> Self {
333 self.tool_callbacks.insert(name.into(), tool_callback);
334 self
335 }
336
337 pub fn with_tool_callback_and_tool(
340 mut self,
341 name: impl Into<String>,
342 tool_callback: Arc<ToolCallback>,
343 tool: Tool,
344 ) -> Self {
345 let name = name.into();
346 self.tool_callbacks_with_tools.insert(
347 name,
348 ToolCallbackWithTool {
349 callback: tool_callback,
350 tool,
351 },
352 );
353 self
354 }
355
356 pub fn with_mcp_client(mut self, config: McpClientConfig) -> Self {
358 self.mcp_client_config = Some(config);
359 self
360 }
361
362 pub async fn build(self) -> Arc<MistralRs> {
363 MistralRs::new(self).await
364 }
365}
366
367impl Drop for MistralRs {
368 fn drop(&mut self) {
369 if let Ok(engines) = self.engines.read() {
371 for (_, engine) in engines.iter() {
372 let _ = engine.sender.try_send(Request::Terminate);
374 }
375 }
376 }
377}
378
379impl MistralRs {
380 fn create_engine_instance(
382 pipeline: Arc<tokio::sync::Mutex<dyn Pipeline>>,
383 method: SchedulerConfig,
384 config: EngineConfig,
385 reboot_state: RebootState,
386 ) -> Result<EngineInstance, String> {
387 let (tx, rx) = channel(10_000);
388
389 let pipeline_guard = pipeline.try_lock().unwrap();
390 let category = pipeline_guard.category();
391 let metadata = pipeline_guard.get_metadata();
392 let kind = metadata.kind.clone();
393 let device = pipeline_guard.device();
394 let modalities = metadata.modalities.clone();
395 let max_seq_len = match &category {
396 ModelCategory::Diffusion | ModelCategory::Speech => None,
397 _ => Some(metadata.max_seq_len),
398 };
399 drop(pipeline_guard);
400
401 info!("Pipeline input modalities are {:?}", &modalities.input);
402 info!("Pipeline output modalities are {:?}", &modalities.output);
403
404 let mistralrs_config = MistralRsConfig {
405 kind,
406 device,
407 category: category.clone(),
408 modalities,
409 max_seq_len,
410 };
411
412 let tx_for_engine = tx.clone();
413 let engine_handler = thread::spawn(move || {
414 #[cfg(feature = "metal")]
415 objc::rc::autoreleasepool(move || {
416 let rt = Runtime::new().unwrap();
417 rt.block_on(async move {
418 let engine = Engine::new(
419 tx_for_engine,
420 rx,
421 pipeline,
422 method,
423 config.no_kv_cache,
424 config.no_prefix_cache,
425 config.prefix_cache_n,
426 config.disable_eos_stop,
427 config.throughput_logging_enabled,
428 config.search_embedding_model,
429 config.search_callback.clone(),
430 config.tool_callbacks.clone(),
431 config.tool_callbacks_with_tools.clone(),
432 )
433 .expect("Engine creation failed.");
434 Arc::new(engine).run().await;
435 })
436 });
437
438 #[cfg(not(feature = "metal"))]
439 {
440 let rt = Runtime::new().unwrap();
441 rt.block_on(async move {
442 let engine = Engine::new(
443 tx_for_engine,
444 rx,
445 pipeline,
446 method,
447 config.no_kv_cache,
448 config.no_prefix_cache,
449 config.prefix_cache_n,
450 config.disable_eos_stop,
451 config.throughput_logging_enabled,
452 config.search_embedding_model,
453 config.search_callback.clone(),
454 config.tool_callbacks.clone(),
455 config.tool_callbacks_with_tools.clone(),
456 )
457 .expect("Engine creation failed.");
458 Arc::new(engine).run().await;
459 })
460 }
461 });
462
463 Ok(EngineInstance {
464 sender: tx,
465 engine_handler,
466 reboot_state,
467 config: mistralrs_config,
468 category,
469 })
470 }
471
472 async fn new(config: MistralRsBuilder) -> Arc<Self> {
473 let MistralRsBuilder {
474 pipeline,
475 method,
476 log,
477 no_kv_cache,
478 no_prefix_cache,
479 prefix_cache_n,
480 disable_eos_stop,
481 throughput_logging_enabled,
482 search_embedding_model,
483 search_callback,
484 tool_callbacks,
485 mut tool_callbacks_with_tools,
486 mcp_client_config,
487 } = config;
488
489 mistralrs_quant::cublaslt::maybe_init_cublas_lt_wrapper(
490 get_mut_arcmutex!(pipeline).device(),
491 );
492
493 let method = if !get_mut_arcmutex!(pipeline).get_metadata().no_kv_cache
496 && get_mut_arcmutex!(pipeline).cache().is_hybrid()
497 {
498 info!(
499 "Hybrid model detected (Mamba-Attention), enforcing batch_size=1 for correctness"
500 );
501 SchedulerConfig::DefaultScheduler {
502 method: DefaultSchedulerMethod::Fixed(NonZeroUsize::new(1).unwrap()),
503 }
504 } else {
505 method
506 };
507
508 let no_kv_cache = no_kv_cache.unwrap_or(false);
509 let no_prefix_cache = no_prefix_cache.unwrap_or(false);
510 let prefix_cache_n = prefix_cache_n.unwrap_or(16);
511 let disable_eos_stop = disable_eos_stop.unwrap_or(false);
512
513 if let Some(config) = &mcp_client_config {
515 let mut mcp_client = McpClient::new(config.clone());
516 let total_servers = config.servers.len();
517
518 match mcp_client.initialize().await {
519 Ok(()) => {
520 let mcp_callbacks_with_tools = mcp_client.get_tool_callbacks_with_tools();
521 let tools_count = mcp_callbacks_with_tools.len();
522
523 for (name, callback_with_tool) in mcp_callbacks_with_tools {
525 tool_callbacks_with_tools.insert(name.clone(), callback_with_tool.clone());
526 }
527
528 if tools_count == 0 {
529 warn!(
530 "MCP client initialized but no tools were registered from {} servers",
531 total_servers
532 );
533 } else {
534 info!(
535 "MCP client initialized successfully with {} tools from {} servers",
536 tools_count, total_servers
537 );
538 }
539 }
540 Err(e) => {
541 warn!(
542 "Failed to initialize MCP client with {} configured servers: {}",
543 total_servers, e
544 );
545 warn!("Continuing without MCP functionality. Check your MCP configuration and server availability.");
546 }
547 }
548 }
549
550 let reboot_state = RebootState {
551 pipeline: pipeline.clone(),
552 method: method.clone(),
553 no_kv_cache,
554 no_prefix_cache,
555 prefix_cache_n,
556 disable_eos_stop,
557 throughput_logging_enabled,
558 search_embedding_model,
559 search_callback: search_callback.clone(),
560 tool_callbacks: tool_callbacks.clone(),
561 tool_callbacks_with_tools: tool_callbacks_with_tools.clone(),
562 mcp_client_config: mcp_client_config.clone(),
563 };
564
565 let engine_config = EngineConfig {
567 no_kv_cache,
568 no_prefix_cache,
569 prefix_cache_n,
570 disable_eos_stop,
571 throughput_logging_enabled,
572 search_embedding_model,
573 search_callback,
574 tool_callbacks,
575 tool_callbacks_with_tools,
576 };
577
578 let engine_instance =
580 Self::create_engine_instance(pipeline.clone(), method, engine_config, reboot_state)
581 .expect("Failed to create engine instance");
582
583 let id = pipeline.try_lock().unwrap().name();
584
585 if distributed::is_daemon() {
586 let request_sender = engine_instance.sender.clone();
587
588 if cfg!(feature = "ring") {
589 distributed::ring_daemon_replicator(request_sender);
591 } else {
592 distributed::nccl_daemon_replicator(request_sender);
594 }
595
596 #[allow(clippy::empty_loop)]
597 loop {}
598 }
599
600 let is_multi_threaded = tokio::runtime::Handle::try_current()
602 .is_ok_and(|h| h.runtime_flavor() != tokio::runtime::RuntimeFlavor::CurrentThread);
603
604 if !distributed::is_daemon()
606 && is_multi_threaded
607 && matches!(
608 engine_instance.category,
609 ModelCategory::Text | ModelCategory::Vision { .. }
610 )
611 {
612 let clone_sender = engine_instance.sender.clone();
613 tokio::task::block_in_place(|| {
614 let (tx, mut rx) = channel(1);
615 let req = Request::Normal(Box::new(NormalRequest {
616 id: 0,
617 messages: RequestMessage::Completion {
618 text: "hello".to_string(),
619 echo_prompt: false,
620 best_of: None,
621 },
622 sampling_params: SamplingParams {
623 max_len: Some(1),
624 ..SamplingParams::deterministic()
625 },
626 response: tx,
627 return_logprobs: false,
628 is_streaming: false,
629 constraint: Constraint::None,
630 suffix: None,
631 tool_choice: None,
632 tools: None,
633 logits_processors: None,
634 return_raw_logits: false,
635 web_search_options: None,
636 model_id: None,
637 truncate_sequence: false,
638 }));
639 info!("Beginning dummy run.");
640 let start = Instant::now();
641 clone_sender.blocking_send(req).unwrap();
642
643 let mut received_any = false;
645 while let Some(_resp) = rx.blocking_recv() {
646 received_any = true;
647 }
648
649 if received_any {
650 let end = Instant::now();
651 info!(
652 "Dummy run completed in {}s.",
653 end.duration_since(start).as_secs_f64()
654 );
655 } else {
656 warn!("Dummy run failed!");
657 }
658 });
659 }
660
661 let mut engines = HashMap::new();
663 engines.insert(id.clone(), engine_instance);
664
665 Arc::new(Self {
666 engines: RwLock::new(engines),
667 default_engine_id: RwLock::new(Some(id.clone())),
668 log,
669 id,
670 creation_time: SystemTime::now()
671 .duration_since(UNIX_EPOCH)
672 .expect("Time travel has occurred!")
673 .as_secs(),
674 next_request_id: Mutex::new(RefCell::new(1)),
675 })
676 }
677
678 fn reboot_engine(&self, model_id: &str) -> Result<(), MistralRsError> {
680 let mut engines = self.engines.write().map_err(|_| {
681 tracing::warn!("Couldn't get write lock on engines during reboot attempt");
682 MistralRsError::EnginePoisoned
683 })?;
684
685 if let Some(engine_instance) = engines.get(model_id) {
686 if !engine_instance.engine_handler.is_finished() {
687 tracing::info!("Engine {} already running, returning ok", model_id);
688 return Ok(());
689 }
690
691 let reboot_state = engine_instance.reboot_state.clone();
692 let engine_config = EngineConfig {
693 no_kv_cache: reboot_state.no_kv_cache,
694 no_prefix_cache: reboot_state.no_prefix_cache,
695 prefix_cache_n: reboot_state.prefix_cache_n,
696 disable_eos_stop: reboot_state.disable_eos_stop,
697 throughput_logging_enabled: reboot_state.throughput_logging_enabled,
698 search_embedding_model: reboot_state.search_embedding_model,
699 search_callback: reboot_state.search_callback.clone(),
700 tool_callbacks: reboot_state.tool_callbacks.clone(),
701 tool_callbacks_with_tools: reboot_state.tool_callbacks_with_tools.clone(),
702 };
703 let new_engine_instance = Self::create_engine_instance(
704 reboot_state.pipeline.clone(),
705 reboot_state.method.clone(),
706 engine_config,
707 reboot_state,
708 )
709 .map_err(|e| {
710 tracing::error!("Failed to create new engine instance: {}", e);
711 MistralRsError::EnginePoisoned
712 })?;
713
714 engines.insert(model_id.to_string(), new_engine_instance);
715 tracing::info!("Successfully rebooted engine {}", model_id);
716 Ok(())
717 } else {
718 Err(MistralRsError::EnginePoisoned)
719 }
720 }
721
722 fn engine_dead(&self, model_id: &str) -> Result<bool, MistralRsError> {
723 let engines = self.engines.read().map_err(|_| {
724 tracing::warn!("Couldn't get read lock on engines!");
725 MistralRsError::EnginePoisoned
726 })?;
727
728 if let Some(engine_instance) = engines.get(model_id) {
729 Ok(engine_instance.engine_handler.is_finished())
730 } else {
731 Err(MistralRsError::EnginePoisoned)
732 }
733 }
734
735 pub fn get_sender(&self, model_id: Option<&str>) -> Result<Sender<Request>, MistralRsError> {
737 let resolved_model_id = match model_id {
738 Some(id) => id.to_string(),
739 None => {
740 let default_lock = self
741 .default_engine_id
742 .read()
743 .map_err(|_| MistralRsError::SenderPoisoned)?;
744 default_lock
745 .as_ref()
746 .ok_or(MistralRsError::EnginePoisoned)?
747 .clone()
748 }
749 };
750
751 if self.engine_dead(&resolved_model_id)? {
752 tracing::warn!("Engine {} is dead, rebooting", resolved_model_id);
753 self.reboot_engine(&resolved_model_id)?
754 }
755
756 let engines = self
757 .engines
758 .read()
759 .map_err(|_| MistralRsError::SenderPoisoned)?;
760 if let Some(engine_instance) = engines.get(&resolved_model_id) {
761 Ok(engine_instance.sender.clone())
762 } else {
763 Err(MistralRsError::EnginePoisoned)
764 }
765 }
766
    /// Returns a copy of this instance's identifier.
    pub fn get_id(&self) -> String {
        self.id.clone()
    }
770
    /// Returns the creation timestamp recorded for this instance.
    pub fn get_creation_time(&self) -> u64 {
        self.creation_time
    }
774
775 pub fn get_model_category(
777 &self,
778 model_id: Option<&str>,
779 ) -> Result<ModelCategory, MistralRsError> {
780 let resolved_model_id = match model_id {
781 Some(id) => id.to_string(),
782 None => {
783 let default_lock = self
784 .default_engine_id
785 .read()
786 .map_err(|_| MistralRsError::SenderPoisoned)?;
787 default_lock
788 .as_ref()
789 .ok_or(MistralRsError::EnginePoisoned)?
790 .clone()
791 }
792 };
793
794 let engines = self
795 .engines
796 .read()
797 .map_err(|_| MistralRsError::SenderPoisoned)?;
798 if let Some(engine_instance) = engines.get(&resolved_model_id) {
799 Ok(engine_instance.category.clone())
800 } else {
801 Err(MistralRsError::EnginePoisoned)
802 }
803 }
804
805 pub fn max_sequence_length(
807 &self,
808 model_id: Option<&str>,
809 ) -> Result<Option<usize>, MistralRsError> {
810 let resolved_model_id = match model_id {
811 Some(id) => id.to_string(),
812 None => {
813 let default_lock = self
814 .default_engine_id
815 .read()
816 .map_err(|_| MistralRsError::SenderPoisoned)?;
817 default_lock
818 .as_ref()
819 .ok_or(MistralRsError::EnginePoisoned)?
820 .clone()
821 }
822 };
823
824 let engines = self
825 .engines
826 .read()
827 .map_err(|_| MistralRsError::SenderPoisoned)?;
828 if let Some(engine_instance) = engines.get(&resolved_model_id) {
829 Ok(engine_instance.config.max_seq_len)
830 } else {
831 Err(MistralRsError::EnginePoisoned)
832 }
833 }
834
835 pub fn next_request_id(&self) -> usize {
836 let l = self.next_request_id.lock().unwrap();
837 let last = &mut *l.borrow_mut();
838 let last_v = *last;
839 *last += 1;
840 last_v
841 }
842
843 pub async fn add_model(
845 &self,
846 model_id: String,
847 pipeline: Arc<tokio::sync::Mutex<dyn Pipeline>>,
848 method: SchedulerConfig,
849 config: AddModelConfig,
850 ) -> Result<(), String> {
851 let method = {
853 let pipeline_guard = pipeline.try_lock().unwrap();
854 if !pipeline_guard.get_metadata().no_kv_cache && pipeline_guard.cache().is_hybrid() {
855 info!(
856 "Hybrid model detected (Mamba-Attention), enforcing batch_size=1 for correctness"
857 );
858 SchedulerConfig::DefaultScheduler {
859 method: DefaultSchedulerMethod::Fixed(NonZeroUsize::new(1).unwrap()),
860 }
861 } else {
862 method
863 }
864 };
865
866 let reboot_state = RebootState {
867 pipeline: pipeline.clone(),
868 method: method.clone(),
869 no_kv_cache: config.engine_config.no_kv_cache,
870 no_prefix_cache: config.engine_config.no_prefix_cache,
871 prefix_cache_n: config.engine_config.prefix_cache_n,
872 disable_eos_stop: config.engine_config.disable_eos_stop,
873 throughput_logging_enabled: config.engine_config.throughput_logging_enabled,
874 search_embedding_model: config.engine_config.search_embedding_model,
875 search_callback: config.engine_config.search_callback.clone(),
876 tool_callbacks: config.engine_config.tool_callbacks.clone(),
877 tool_callbacks_with_tools: config.engine_config.tool_callbacks_with_tools.clone(),
878 mcp_client_config: config.mcp_client_config.clone(),
879 };
880
881 let engine_instance =
882 Self::create_engine_instance(pipeline, method, config.engine_config, reboot_state)?;
883
884 let mut engines = self
885 .engines
886 .write()
887 .map_err(|_| "Failed to acquire write lock on engines")?;
888 engines.insert(model_id.clone(), engine_instance);
889
890 if engines.len() == 1 {
892 let mut default_lock = self
893 .default_engine_id
894 .write()
895 .map_err(|_| "Failed to acquire write lock on default_engine_id")?;
896 *default_lock = Some(model_id.clone());
897 }
898
899 Ok(())
900 }
901
902 pub fn remove_model(&self, model_id: &str) -> Result<(), String> {
904 let mut engines = self
905 .engines
906 .write()
907 .map_err(|_| "Failed to acquire write lock on engines")?;
908
909 if engines.len() <= 1 {
910 return Err("Cannot remove the last model from MistralRs".to_string());
911 }
912
913 if let Some(engine_instance) = engines.remove(model_id) {
914 let _ = engine_instance.sender.blocking_send(Request::Terminate);
916
917 let mut default_lock = self
919 .default_engine_id
920 .write()
921 .map_err(|_| "Failed to acquire write lock on default_engine_id")?;
922 if let Some(ref default_id) = *default_lock {
923 if default_id == model_id {
924 *default_lock = engines.keys().next().cloned();
926 }
927 }
928
929 Ok(())
930 } else {
931 Err(format!("Model {model_id} not found"))
932 }
933 }
934
935 pub fn list_models(&self) -> Result<Vec<String>, String> {
937 let engines = self
938 .engines
939 .read()
940 .map_err(|_| "Failed to acquire read lock on engines")?;
941 Ok(engines.keys().cloned().collect())
942 }
943
944 pub fn get_default_model_id(&self) -> Result<Option<String>, String> {
946 let default_lock = self
947 .default_engine_id
948 .read()
949 .map_err(|_| "Failed to acquire read lock on default_engine_id")?;
950 Ok(default_lock.clone())
951 }
952
953 pub fn set_default_model_id(&self, model_id: &str) -> Result<(), String> {
955 let engines = self
956 .engines
957 .read()
958 .map_err(|_| "Failed to acquire read lock on engines")?;
959 if !engines.contains_key(model_id) {
960 return Err(format!("Model {model_id} not found"));
961 }
962 drop(engines);
963
964 let mut default_lock = self
965 .default_engine_id
966 .write()
967 .map_err(|_| "Failed to acquire write lock on default_engine_id")?;
968 *default_lock = Some(model_id.to_string());
969
970 Ok(())
971 }
972
973 pub fn send_request(&self, mut request: Request) -> Result<(), MistralRsError> {
975 let model_id = match &mut request {
976 Request::Normal(normal_req) => normal_req.model_id.as_deref(),
977 _ => None, };
979
980 let sender = self.get_sender(model_id)?;
981 sender
982 .blocking_send(request)
983 .map_err(|_| MistralRsError::SenderPoisoned)
984 }
985
986 pub fn maybe_log_request(this: Arc<Self>, repr: String) {
987 if let Some(file) = &this.log {
988 let mut f = OpenOptions::new()
989 .append(true)
990 .create(true) .open(file)
992 .expect("Unable to open file");
993 let time = chrono::offset::Local::now();
994 f.write_all(format!("Request at {time}: {repr}\n\n").as_bytes())
995 .expect("Unable to write data");
996 }
997 }
998
999 pub fn maybe_log_response<T: Serialize>(this: Arc<Self>, resp: &T) {
1000 if let Some(file) = &this.log {
1001 let mut f = OpenOptions::new()
1002 .append(true)
1003 .create(true) .open(file)
1005 .expect("Unable to open file");
1006 let time = chrono::offset::Local::now();
1007 let repr = serde_json::to_string(resp).expect("Serialization of response failed.");
1008 f.write_all(format!("Response at {time}: {repr}\n\n").as_bytes())
1009 .expect("Unable to write data");
1010 }
1011 }
1012
1013 pub fn maybe_log_error(this: Arc<Self>, err: &dyn Error) {
1014 if let Some(file) = &this.log {
1015 let mut f = OpenOptions::new()
1016 .append(true)
1017 .create(true) .open(file)
1019 .expect("Unable to open file");
1020 let time = chrono::offset::Local::now();
1021 f.write_all(format!("Error response at {time}: {err}\n\n").as_bytes())
1022 .expect("Unable to write data");
1023 }
1024 }
1025
1026 pub fn get_tools_count(&self, model_id: Option<&str>) -> Result<usize, String> {
1028 let resolved_model_id = match model_id {
1029 Some(id) => id.to_string(),
1030 None => {
1031 let default_lock = self
1032 .default_engine_id
1033 .read()
1034 .map_err(|_| "Failed to acquire read lock")?;
1035 default_lock
1036 .as_ref()
1037 .ok_or("No default engine set")?
1038 .clone()
1039 }
1040 };
1041
1042 let engines = self
1043 .engines
1044 .read()
1045 .map_err(|_| "Failed to acquire read lock on engines")?;
1046 if let Some(engine_instance) = engines.get(&resolved_model_id) {
1047 Ok(engine_instance.reboot_state.tool_callbacks_with_tools.len())
1048 } else {
1049 Err(format!("Model {resolved_model_id} not found"))
1050 }
1051 }
1052
1053 pub fn has_mcp_client(&self, model_id: Option<&str>) -> Result<bool, String> {
1055 let resolved_model_id = match model_id {
1056 Some(id) => id.to_string(),
1057 None => {
1058 let default_lock = self
1059 .default_engine_id
1060 .read()
1061 .map_err(|_| "Failed to acquire read lock")?;
1062 default_lock
1063 .as_ref()
1064 .ok_or("No default engine set")?
1065 .clone()
1066 }
1067 };
1068
1069 let engines = self
1070 .engines
1071 .read()
1072 .map_err(|_| "Failed to acquire read lock on engines")?;
1073 if let Some(engine_instance) = engines.get(&resolved_model_id) {
1074 Ok(engine_instance.reboot_state.mcp_client_config.is_some())
1075 } else {
1076 Err(format!("Model {resolved_model_id} not found"))
1077 }
1078 }
1079
1080 pub fn config(&self, model_id: Option<&str>) -> Result<MistralRsConfig, String> {
1082 let resolved_model_id = match model_id {
1083 Some(id) => id.to_string(),
1084 None => {
1085 let default_lock = self
1086 .default_engine_id
1087 .read()
1088 .map_err(|_| "Failed to acquire read lock")?;
1089 default_lock
1090 .as_ref()
1091 .ok_or("No default engine set")?
1092 .clone()
1093 }
1094 };
1095
1096 let engines = self
1097 .engines
1098 .read()
1099 .map_err(|_| "Failed to acquire read lock on engines")?;
1100 if let Some(engine_instance) = engines.get(&resolved_model_id) {
1101 Ok(engine_instance.config.clone())
1102 } else {
1103 Err(format!("Model {resolved_model_id} not found"))
1104 }
1105 }
1106}