#![deny(clippy::cast_possible_truncation, clippy::cast_precision_loss)]
use candle_core::Device;
use engine::Engine;
pub use engine::{
    get_engine_terminate_flag, reset_engine_terminate_flag, should_terminate_engine_sequences,
    BertEmbeddingModel, EngineInstruction, ENGINE_INSTRUCTIONS, TERMINATE_ALL_NEXT_STEP,
};
use hf_hub::Cache;
pub use lora::Ordering;
pub use pipeline::ModelCategory;
pub use pipeline::Pipeline;
#[cfg(feature = "pyo3_macros")]
use pyo3::exceptions::PyValueError;
use std::collections::HashMap;
use std::sync::OnceLock;
use std::time::Instant;
use std::{
    cell::RefCell,
    error::Error,
    fs::OpenOptions,
    io::Write,
    sync::{atomic::AtomicBool, Arc, Mutex, RwLock},
    thread::{self, JoinHandle},
    time::{SystemTime, UNIX_EPOCH},
};
use tokio::sync::mpsc::{channel, Sender};
use tracing::info;
use tracing::warn;

mod cuda;
mod device_map;
mod engine;
mod lora;
mod model_loader;
mod ops;
pub use model_loader::{
    get_auto_device_map_params, get_model_dtype, get_tgt_non_granular_index, LoaderBuilder,
};
mod kv_cache;
mod search;

mod model_selected;
pub use model_selected::ModelSelected;
pub use toml_selector::{get_toml_selected_model_device_map_params, get_toml_selected_model_dtype};

mod amoe;
#[cfg(not(any(all(feature = "cuda", target_family = "unix"), feature = "metal")))]
mod dummy_paged_attention;
mod embedding;
mod gguf;
pub mod layers;
mod layers_masker;
mod layers_utils;
pub mod matformer;
mod models;
#[cfg(any(all(feature = "cuda", target_family = "unix"), feature = "metal"))]
mod paged_attention;
#[cfg(not(any(all(feature = "cuda", target_family = "unix"), feature = "metal")))]
use dummy_paged_attention as paged_attention;
mod attention;
mod diffusion_models;
pub mod distributed;
mod pipeline;
mod prefix_cacher;
mod request;
mod response;
mod sampler;
mod scheduler;
mod sequence;
mod speech_models;
mod toml_selector;
mod tools;
mod topology;
mod utils;
mod vision_models;
mod xlora_models;

pub use amoe::{AnyMoeConfig, AnyMoeExpertType};
pub use device_map::{
    DeviceLayerMapMetadata, DeviceMapMetadata, DeviceMapSetting, LayerDeviceMapper,
};
pub use gguf::{GGUFArchitecture, GGUF_MULTI_FILE_DELIMITER};
pub use mistralrs_audio::AudioInput;
pub use mistralrs_mcp::{
    CalledFunction, Function, Tool, ToolCallback, ToolCallbackWithTool, ToolType,
};
pub use mistralrs_mcp::{
    McpClient, McpClientConfig, McpServerConfig, McpServerSource, McpToolInfo,
};
pub use mistralrs_quant::{IsqType, MULTI_LORA_DELIMITER};
pub use paged_attention::{MemoryGpuConfig, PagedAttentionConfig, PagedCacheType};
pub use pipeline::{
    chat_template::ChatTemplate, parse_isq_value, AdapterPaths, AnyMoeLoader, AnyMoePipeline,
    AutoDeviceMapParams, AutoLoader, AutoLoaderBuilder, DiffusionGenerationParams, DiffusionLoader,
    DiffusionLoaderBuilder, DiffusionLoaderType, GGMLLoader, GGMLLoaderBuilder, GGMLSpecificConfig,
    GGUFLoader, GGUFLoaderBuilder, GGUFSpecificConfig, GemmaLoader, Idefics2Loader,
    IsqOrganization, LLaVALoader, LLaVANextLoader, LlamaLoader, Loader, LocalModelPaths,
    LoraAdapterPaths, MistralLoader, MixtralLoader, Modalities, ModelKind, ModelPaths,
    MultimodalPromptPrefixer, NormalLoader, NormalLoaderBuilder, NormalLoaderType,
    NormalSpecificConfig, Phi2Loader, Phi3Loader, Phi3VLoader, Qwen2Loader, SpeculativeConfig,
    SpeculativeLoader, SpeculativePipeline, SpeechLoader, SpeechPipeline, Starcoder2Loader,
    SupportedModality, TokenSource, VisionLoader, VisionLoaderBuilder, VisionLoaderType,
    VisionSpecificConfig, UQFF_MULTI_FILE_DELIMITER,
};
pub use request::{
    ApproximateUserLocation, Constraint, DetokenizationRequest, ImageGenerationResponseFormat,
    LlguidanceGrammar, MessageContent, NormalRequest, Request, RequestMessage, SearchContextSize,
    TokenizationRequest, WebSearchOptions, WebSearchUserLocation,
};
pub use response::*;
pub use sampler::{
    CustomLogitsProcessor, DrySamplingParams, SamplingParams, StopTokens, TopLogprob,
};
pub use scheduler::{DefaultSchedulerMethod, SchedulerConfig};
pub use search::{SearchCallback, SearchFunctionParameters, SearchResult};
use serde::Serialize;
pub use speech_models::{utils as speech_utils, SpeechGenerationConfig, SpeechLoaderType};
use tokio::runtime::Runtime;
use toml_selector::{TomlLoaderArgs, TomlSelector};
pub use tools::{ToolCallResponse, ToolCallType, ToolCallbacks, ToolChoice};
pub use topology::{LayerTopology, Topology};
pub use utils::debug::initialize_logging;
pub use utils::memory_usage::MemoryUsage;
pub use utils::normal::{ModelDType, TryIntoDType};
pub use utils::{paged_attn_supported, using_flash_attn};

pub use llguidance;

pub(crate) static DEBUG: AtomicBool = AtomicBool::new(false);
pub static GLOBAL_HF_CACHE: OnceLock<Cache> = OnceLock::new();

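/// Runtime options applied to a single engine.
///
/// A minimal construction sketch; the overrides shown are illustrative only,
/// and everything else falls back to the `Default` impl below:
///
/// ```ignore
/// let engine_config = EngineConfig {
///     prefix_cache_n: 32,
///     throughput_logging_enabled: false,
///     ..Default::default()
/// };
/// ```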
#[derive(Clone)]
pub struct EngineConfig {
    pub truncate_sequence: bool,
    pub no_kv_cache: bool,
    pub no_prefix_cache: bool,
    pub prefix_cache_n: usize,
    pub disable_eos_stop: bool,
    pub throughput_logging_enabled: bool,
    pub search_embedding_model: Option<BertEmbeddingModel>,
    pub search_callback: Option<Arc<SearchCallback>>,
    pub tool_callbacks: tools::ToolCallbacks,
    pub tool_callbacks_with_tools: tools::ToolCallbacksWithTools,
}

impl Default for EngineConfig {
    fn default() -> Self {
        Self {
            truncate_sequence: false,
            no_kv_cache: false,
            no_prefix_cache: false,
            prefix_cache_n: 16,
            disable_eos_stop: false,
            throughput_logging_enabled: true,
            search_embedding_model: None,
            search_callback: None,
            tool_callbacks: HashMap::new(),
            tool_callbacks_with_tools: HashMap::new(),
        }
    }
}

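/// Options used when registering an additional model via [`MistralRs::add_model`].
///
/// A minimal sketch, assuming `engine_config` and `mcp_config` were built elsewhere:
///
/// ```ignore
/// let add_config = AddModelConfig::new(engine_config).with_mcp_config(mcp_config);
/// ```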
#[derive(Clone)]
pub struct AddModelConfig {
    pub engine_config: EngineConfig,
    pub mcp_client_config: Option<McpClientConfig>,
}

impl AddModelConfig {
    pub fn new(engine_config: EngineConfig) -> Self {
        Self {
            engine_config,
            mcp_client_config: None,
        }
    }

    pub fn with_mcp_config(mut self, mcp_config: McpClientConfig) -> Self {
        self.mcp_client_config = Some(mcp_config);
        self
    }
}

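/// Static properties of a loaded model, captured when its engine is created.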
#[derive(Clone)]
pub struct MistralRsConfig {
    pub kind: ModelKind,
    pub device: Device,
    pub category: ModelCategory,
    pub modalities: Modalities,
}

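/// A running engine: its request channel, the thread driving it, and the state
/// required to reboot it if that thread dies.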
struct EngineInstance {
    sender: Sender<Request>,
    engine_handler: JoinHandle<()>,
    reboot_state: RebootState,
    config: MistralRsConfig,
    category: ModelCategory,
}

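/// The top-level handle over one or more engines, addressed by model id.
///
/// A hedged usage sketch, assuming `mistralrs: Arc<MistralRs>` came from
/// [`MistralRsBuilder::build`] and `request` is a fully populated [`Request`]:
///
/// ```ignore
/// let sender = mistralrs.get_sender(None)?; // `None` selects the default engine
/// sender.blocking_send(request)?;
/// ```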
pub struct MistralRs {
    engines: RwLock<HashMap<String, EngineInstance>>,
    default_engine_id: RwLock<Option<String>>,
    log: Option<String>,
    id: String,
    creation_time: u64,
    next_request_id: Mutex<RefCell<usize>>,
}

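/// Everything needed to rebuild an engine after its thread has terminated.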
#[derive(Clone)]
struct RebootState {
    pipeline: Arc<tokio::sync::Mutex<dyn Pipeline>>,
    method: SchedulerConfig,
    truncate_sequence: bool,
    no_kv_cache: bool,
    no_prefix_cache: bool,
    prefix_cache_n: usize,
    disable_eos_stop: bool,
    throughput_logging_enabled: bool,
    search_embedding_model: Option<BertEmbeddingModel>,
    search_callback: Option<Arc<search::SearchCallback>>,
    tool_callbacks: tools::ToolCallbacks,
    tool_callbacks_with_tools: tools::ToolCallbacksWithTools,
    mcp_client_config: Option<McpClientConfig>,
}

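/// Errors returned when an engine or one of the internal locks is unusable.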
#[derive(Debug)]
pub enum MistralRsError {
    EnginePoisoned,
    SenderPoisoned,
}

impl std::fmt::Display for MistralRsError {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        write!(f, "{:?}", &self)
    }
}

impl std::error::Error for MistralRsError {}

#[cfg(feature = "pyo3_macros")]
impl From<MistralRsError> for pyo3::PyErr {
    fn from(value: MistralRsError) -> Self {
        PyValueError::new_err(format!("{value:?}"))
    }
}

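/// Builder for [`MistralRs`].
///
/// A minimal sketch, assuming `pipeline` (an `Arc<tokio::sync::Mutex<dyn Pipeline>>`)
/// and `scheduler_config` (a [`SchedulerConfig`]) were constructed elsewhere:
///
/// ```ignore
/// let mistralrs = MistralRsBuilder::new(pipeline, scheduler_config, false, None)
///     .with_no_kv_cache(false)
///     .with_prefix_cache_n(16)
///     .build()
///     .await;
/// ```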
pub struct MistralRsBuilder {
    pipeline: Arc<tokio::sync::Mutex<dyn Pipeline>>,
    method: SchedulerConfig,
    log: Option<String>,
    truncate_sequence: Option<bool>,
    no_kv_cache: Option<bool>,
    no_prefix_cache: Option<bool>,
    prefix_cache_n: Option<usize>,
    disable_eos_stop: Option<bool>,
    throughput_logging_enabled: bool,
    search_embedding_model: Option<BertEmbeddingModel>,
    search_callback: Option<Arc<SearchCallback>>,
    tool_callbacks: tools::ToolCallbacks,
    tool_callbacks_with_tools: tools::ToolCallbacksWithTools,
    mcp_client_config: Option<McpClientConfig>,
}

impl MistralRsBuilder {
    pub fn new(
        pipeline: Arc<tokio::sync::Mutex<dyn Pipeline>>,
        method: SchedulerConfig,
        throughput_logging: bool,
        search_embedding_model: Option<BertEmbeddingModel>,
    ) -> Self {
        Self {
            pipeline,
            method,
            log: None,
            truncate_sequence: None,
            no_kv_cache: None,
            no_prefix_cache: None,
            prefix_cache_n: None,
            disable_eos_stop: None,
            throughput_logging_enabled: throughput_logging,
            search_embedding_model,
            search_callback: None,
            tool_callbacks: HashMap::new(),
            tool_callbacks_with_tools: HashMap::new(),
            mcp_client_config: None,
        }
    }
    pub fn with_log(mut self, log: String) -> Self {
        self.log = Some(log);
        self
    }
    pub fn with_opt_log(mut self, log: Option<String>) -> Self {
        self.log = log;
        self
    }
    pub fn with_truncate_sequence(mut self, truncate_sequence: bool) -> Self {
        self.truncate_sequence = Some(truncate_sequence);
        self
    }
    pub fn with_no_kv_cache(mut self, no_kv_cache: bool) -> Self {
        self.no_kv_cache = Some(no_kv_cache);
        self
    }
    pub fn with_no_prefix_cache(mut self, no_prefix_cache: bool) -> Self {
        self.no_prefix_cache = Some(no_prefix_cache);
        self
    }
    pub fn with_prefix_cache_n(mut self, prefix_cache_n: usize) -> Self {
        self.prefix_cache_n = Some(prefix_cache_n);
        self
    }
    pub fn with_disable_eos_stop(mut self, disable_eos_stop: bool) -> Self {
        self.disable_eos_stop = Some(disable_eos_stop);
        self
    }

    pub fn with_search_callback(mut self, search_callback: Arc<SearchCallback>) -> Self {
        self.search_callback = Some(search_callback);
        self
    }

    pub fn with_tool_callback(
        mut self,
        name: impl Into<String>,
        tool_callback: Arc<ToolCallback>,
    ) -> Self {
        self.tool_callbacks.insert(name.into(), tool_callback);
        self
    }

    pub fn with_tool_callback_and_tool(
        mut self,
        name: impl Into<String>,
        tool_callback: Arc<ToolCallback>,
        tool: Tool,
    ) -> Self {
        let name = name.into();
        self.tool_callbacks_with_tools.insert(
            name,
            ToolCallbackWithTool {
                callback: tool_callback,
                tool,
            },
        );
        self
    }

    pub fn with_mcp_client(mut self, config: McpClientConfig) -> Self {
        self.mcp_client_config = Some(config);
        self
    }

    pub async fn build(self) -> Arc<MistralRs> {
        MistralRs::new(self).await
    }
}

impl Drop for MistralRs {
    fn drop(&mut self) {
        if let Ok(engines) = self.engines.read() {
            for (_, engine) in engines.iter() {
                let _ = engine.sender.try_send(Request::Terminate);
            }
        }
    }
}

impl MistralRs {
    fn create_engine_instance(
        pipeline: Arc<tokio::sync::Mutex<dyn Pipeline>>,
        method: SchedulerConfig,
        config: EngineConfig,
        reboot_state: RebootState,
    ) -> Result<EngineInstance, String> {
        let (tx, rx) = channel(10_000);

        let category = pipeline.try_lock().unwrap().category();
        let kind = pipeline.try_lock().unwrap().get_metadata().kind.clone();
        let device = pipeline.try_lock().unwrap().device();
        let modalities = pipeline
            .try_lock()
            .unwrap()
            .get_metadata()
            .modalities
            .clone();

        info!("Pipeline input modalities are {:?}", &modalities.input);
        info!("Pipeline output modalities are {:?}", &modalities.output);

        let mistralrs_config = MistralRsConfig {
            kind,
            device,
            category: category.clone(),
            modalities,
        };

        let engine_handler = thread::spawn(move || {
            #[cfg(feature = "metal")]
            objc::rc::autoreleasepool(move || {
                let rt = Runtime::new().unwrap();
                rt.block_on(async move {
                    let engine = Engine::new(
                        rx,
                        pipeline,
                        method,
                        config.truncate_sequence,
                        config.no_kv_cache,
                        config.no_prefix_cache,
                        config.prefix_cache_n,
                        config.disable_eos_stop,
                        config.throughput_logging_enabled,
                        config.search_embedding_model,
                        config.search_callback.clone(),
                        config.tool_callbacks.clone(),
                        config.tool_callbacks_with_tools.clone(),
                    )
                    .expect("Engine creation failed.");
                    Arc::new(engine).run().await;
                })
            });

            #[cfg(not(feature = "metal"))]
            {
                let rt = Runtime::new().unwrap();
                rt.block_on(async move {
                    let engine = Engine::new(
                        rx,
                        pipeline,
                        method,
                        config.truncate_sequence,
                        config.no_kv_cache,
                        config.no_prefix_cache,
                        config.prefix_cache_n,
                        config.disable_eos_stop,
                        config.throughput_logging_enabled,
                        config.search_embedding_model,
                        config.search_callback.clone(),
                        config.tool_callbacks.clone(),
                        config.tool_callbacks_with_tools.clone(),
                    )
                    .expect("Engine creation failed.");
                    Arc::new(engine).run().await;
                })
            }
        });

        Ok(EngineInstance {
            sender: tx,
            engine_handler,
            reboot_state,
            config: mistralrs_config,
            category,
        })
    }

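    /// Builds the initial engine from the builder, merges in any MCP-provided
    /// tool callbacks, and, on a multi-threaded Tokio runtime, performs a short
    /// warm-up ("dummy run") completion before returning the handle.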
    async fn new(config: MistralRsBuilder) -> Arc<Self> {
        let MistralRsBuilder {
            pipeline,
            method,
            log,
            truncate_sequence,
            no_kv_cache,
            no_prefix_cache,
            prefix_cache_n,
            disable_eos_stop,
            throughput_logging_enabled,
            search_embedding_model,
            search_callback,
            tool_callbacks,
            mut tool_callbacks_with_tools,
            mcp_client_config,
        } = config;

        mistralrs_quant::cublaslt::maybe_init_cublas_lt_wrapper(
            get_mut_arcmutex!(pipeline).device(),
        );

        let truncate_sequence = truncate_sequence.unwrap_or(false);
        let no_kv_cache = no_kv_cache.unwrap_or(false);
        let no_prefix_cache = no_prefix_cache.unwrap_or(false);
        let prefix_cache_n = prefix_cache_n.unwrap_or(16);
        let disable_eos_stop = disable_eos_stop.unwrap_or(false);

        if let Some(config) = &mcp_client_config {
            let mut mcp_client = McpClient::new(config.clone());
            let total_servers = config.servers.len();

            match mcp_client.initialize().await {
                Ok(()) => {
                    let mcp_callbacks_with_tools = mcp_client.get_tool_callbacks_with_tools();
                    let tools_count = mcp_callbacks_with_tools.len();

                    for (name, callback_with_tool) in mcp_callbacks_with_tools {
                        tool_callbacks_with_tools.insert(name.clone(), callback_with_tool.clone());
                    }

                    if tools_count == 0 {
                        warn!(
                            "MCP client initialized but no tools were registered from {} servers",
                            total_servers
                        );
                    } else {
                        info!(
                            "MCP client initialized successfully with {} tools from {} servers",
                            tools_count, total_servers
                        );
                    }
                }
                Err(e) => {
                    warn!(
                        "Failed to initialize MCP client with {} configured servers: {}",
                        total_servers, e
                    );
                    warn!("Continuing without MCP functionality. Check your MCP configuration and server availability.");
                }
            }
        }

        let reboot_state = RebootState {
            pipeline: pipeline.clone(),
            method: method.clone(),
            truncate_sequence,
            no_kv_cache,
            no_prefix_cache,
            prefix_cache_n,
            disable_eos_stop,
            throughput_logging_enabled,
            search_embedding_model: search_embedding_model.clone(),
            search_callback: search_callback.clone(),
            tool_callbacks: tool_callbacks.clone(),
            tool_callbacks_with_tools: tool_callbacks_with_tools.clone(),
            mcp_client_config: mcp_client_config.clone(),
        };

        let engine_config = EngineConfig {
            truncate_sequence,
            no_kv_cache,
            no_prefix_cache,
            prefix_cache_n,
            disable_eos_stop,
            throughput_logging_enabled,
            search_embedding_model,
            search_callback,
            tool_callbacks,
            tool_callbacks_with_tools,
        };

        let engine_instance =
            Self::create_engine_instance(pipeline.clone(), method, engine_config, reboot_state)
                .expect("Failed to create engine instance");

        let id = pipeline.try_lock().unwrap().name();

        if distributed::is_daemon() {
            let request_sender = engine_instance.sender.clone();

            if cfg!(feature = "ring") {
                distributed::ring_daemon_replicator(request_sender);
            } else {
                distributed::nccl_daemon_replicator(request_sender);
            }

            #[allow(clippy::empty_loop)]
            loop {}
        }

        let is_multi_threaded = tokio::runtime::Handle::try_current()
            .is_ok_and(|h| h.runtime_flavor() != tokio::runtime::RuntimeFlavor::CurrentThread);

        if !distributed::is_daemon()
            && is_multi_threaded
            && matches!(
                engine_instance.category,
                ModelCategory::Text | ModelCategory::Vision { .. }
            )
        {
            let clone_sender = engine_instance.sender.clone();
            tokio::task::block_in_place(|| {
                let (tx, mut rx) = channel(1);
                let req = Request::Normal(Box::new(NormalRequest {
                    id: 0,
                    messages: RequestMessage::Completion {
                        text: "hello".to_string(),
                        echo_prompt: false,
                        best_of: None,
                    },
                    sampling_params: SamplingParams {
                        max_len: Some(1),
                        ..SamplingParams::deterministic()
                    },
                    response: tx,
                    return_logprobs: false,
                    is_streaming: false,
                    constraint: Constraint::None,
                    suffix: None,
                    tool_choice: None,
                    tools: None,
                    logits_processors: None,
                    return_raw_logits: false,
                    web_search_options: None,
                    model_id: None,
                }));
                info!("Beginning dummy run.");
                let start = Instant::now();
                clone_sender.blocking_send(req).unwrap();

                let mut received_any = false;
                while let Some(_resp) = rx.blocking_recv() {
                    received_any = true;
                }

                if received_any {
                    let end = Instant::now();
                    info!(
                        "Dummy run completed in {}s.",
                        end.duration_since(start).as_secs_f64()
                    );
                } else {
                    warn!("Dummy run failed!");
                }
            });
        }

        let mut engines = HashMap::new();
        engines.insert(id.clone(), engine_instance);

        Arc::new(Self {
            engines: RwLock::new(engines),
            default_engine_id: RwLock::new(Some(id.clone())),
            log,
            id,
            creation_time: SystemTime::now()
                .duration_since(UNIX_EPOCH)
                .expect("Time travel has occurred!")
                .as_secs(),
            next_request_id: Mutex::new(RefCell::new(1)),
        })
    }

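    /// Attempts to restart the engine registered under `model_id` from its saved
    /// [`RebootState`]; an engine whose thread is still running is left untouched.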
    fn reboot_engine(&self, model_id: &str) -> Result<(), MistralRsError> {
        let mut engines = self.engines.write().map_err(|_| {
            tracing::warn!("Couldn't get write lock on engines during reboot attempt");
            MistralRsError::EnginePoisoned
        })?;

        if let Some(engine_instance) = engines.get(model_id) {
            if !engine_instance.engine_handler.is_finished() {
                tracing::info!("Engine {} already running, returning ok", model_id);
                return Ok(());
            }

            let reboot_state = engine_instance.reboot_state.clone();
            let engine_config = EngineConfig {
                truncate_sequence: reboot_state.truncate_sequence,
                no_kv_cache: reboot_state.no_kv_cache,
                no_prefix_cache: reboot_state.no_prefix_cache,
                prefix_cache_n: reboot_state.prefix_cache_n,
                disable_eos_stop: reboot_state.disable_eos_stop,
                throughput_logging_enabled: reboot_state.throughput_logging_enabled,
                search_embedding_model: reboot_state.search_embedding_model.clone(),
                search_callback: reboot_state.search_callback.clone(),
                tool_callbacks: reboot_state.tool_callbacks.clone(),
                tool_callbacks_with_tools: reboot_state.tool_callbacks_with_tools.clone(),
            };
            let new_engine_instance = Self::create_engine_instance(
                reboot_state.pipeline.clone(),
                reboot_state.method.clone(),
                engine_config,
                reboot_state,
            )
            .map_err(|e| {
                tracing::error!("Failed to create new engine instance: {}", e);
                MistralRsError::EnginePoisoned
            })?;

            engines.insert(model_id.to_string(), new_engine_instance);
            tracing::info!("Successfully rebooted engine {}", model_id);
            Ok(())
        } else {
            Err(MistralRsError::EnginePoisoned)
        }
    }

    fn engine_dead(&self, model_id: &str) -> Result<bool, MistralRsError> {
        let engines = self.engines.read().map_err(|_| {
            tracing::warn!("Couldn't get read lock on engines!");
            MistralRsError::EnginePoisoned
        })?;

        if let Some(engine_instance) = engines.get(model_id) {
            Ok(engine_instance.engine_handler.is_finished())
        } else {
            Err(MistralRsError::EnginePoisoned)
        }
    }

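    /// Returns the request sender for `model_id`, or for the default engine when
    /// `None` is passed, rebooting the engine first if its thread has died.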
    pub fn get_sender(&self, model_id: Option<&str>) -> Result<Sender<Request>, MistralRsError> {
        let resolved_model_id = match model_id {
            Some(id) => id.to_string(),
            None => {
                let default_lock = self
                    .default_engine_id
                    .read()
                    .map_err(|_| MistralRsError::SenderPoisoned)?;
                default_lock
                    .as_ref()
                    .ok_or(MistralRsError::EnginePoisoned)?
                    .clone()
            }
        };

        if self.engine_dead(&resolved_model_id)? {
            tracing::warn!("Engine {} is dead, rebooting", resolved_model_id);
            self.reboot_engine(&resolved_model_id)?
        }

        let engines = self
            .engines
            .read()
            .map_err(|_| MistralRsError::SenderPoisoned)?;
        if let Some(engine_instance) = engines.get(&resolved_model_id) {
            Ok(engine_instance.sender.clone())
        } else {
            Err(MistralRsError::EnginePoisoned)
        }
    }

    pub fn get_id(&self) -> String {
        self.id.clone()
    }

    pub fn get_creation_time(&self) -> u64 {
        self.creation_time
    }

    pub fn get_model_category(
        &self,
        model_id: Option<&str>,
    ) -> Result<ModelCategory, MistralRsError> {
        let resolved_model_id = match model_id {
            Some(id) => id.to_string(),
            None => {
                let default_lock = self
                    .default_engine_id
                    .read()
                    .map_err(|_| MistralRsError::SenderPoisoned)?;
                default_lock
                    .as_ref()
                    .ok_or(MistralRsError::EnginePoisoned)?
                    .clone()
            }
        };

        let engines = self
            .engines
            .read()
            .map_err(|_| MistralRsError::SenderPoisoned)?;
        if let Some(engine_instance) = engines.get(&resolved_model_id) {
            Ok(engine_instance.category.clone())
        } else {
            Err(MistralRsError::EnginePoisoned)
        }
    }

    pub fn next_request_id(&self) -> usize {
        let l = self.next_request_id.lock().unwrap();
        let last = &mut *l.borrow_mut();
        let last_v = *last;
        *last += 1;
        last_v
    }

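    /// Registers an additional model under `model_id` and spawns an engine for it.
    ///
    /// A hedged sketch, assuming `mistralrs`, `pipeline`, `scheduler_config`, and
    /// `engine_config` were built elsewhere (the model id is illustrative):
    ///
    /// ```ignore
    /// mistralrs
    ///     .add_model(
    ///         "second-model".to_string(),
    ///         pipeline,
    ///         scheduler_config,
    ///         AddModelConfig::new(engine_config),
    ///     )
    ///     .await?;
    /// ```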
    pub async fn add_model(
        &self,
        model_id: String,
        pipeline: Arc<tokio::sync::Mutex<dyn Pipeline>>,
        method: SchedulerConfig,
        config: AddModelConfig,
    ) -> Result<(), String> {
        let reboot_state = RebootState {
            pipeline: pipeline.clone(),
            method: method.clone(),
            truncate_sequence: config.engine_config.truncate_sequence,
            no_kv_cache: config.engine_config.no_kv_cache,
            no_prefix_cache: config.engine_config.no_prefix_cache,
            prefix_cache_n: config.engine_config.prefix_cache_n,
            disable_eos_stop: config.engine_config.disable_eos_stop,
            throughput_logging_enabled: config.engine_config.throughput_logging_enabled,
            search_embedding_model: config.engine_config.search_embedding_model.clone(),
            search_callback: config.engine_config.search_callback.clone(),
            tool_callbacks: config.engine_config.tool_callbacks.clone(),
            tool_callbacks_with_tools: config.engine_config.tool_callbacks_with_tools.clone(),
            mcp_client_config: config.mcp_client_config.clone(),
        };

        let engine_instance =
            Self::create_engine_instance(pipeline, method, config.engine_config, reboot_state)?;

        let mut engines = self
            .engines
            .write()
            .map_err(|_| "Failed to acquire write lock on engines")?;
        engines.insert(model_id.clone(), engine_instance);

        if engines.len() == 1 {
            let mut default_lock = self
                .default_engine_id
                .write()
                .map_err(|_| "Failed to acquire write lock on default_engine_id")?;
            *default_lock = Some(model_id.clone());
        }

        Ok(())
    }

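    /// Removes the model registered under `model_id` and asks its engine to
    /// terminate. The last remaining model cannot be removed; if the default
    /// model is removed, another registered model becomes the default.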
    pub fn remove_model(&self, model_id: &str) -> Result<(), String> {
        let mut engines = self
            .engines
            .write()
            .map_err(|_| "Failed to acquire write lock on engines")?;

        if engines.len() <= 1 {
            return Err("Cannot remove the last model from MistralRs".to_string());
        }

        if let Some(engine_instance) = engines.remove(model_id) {
            let _ = engine_instance.sender.blocking_send(Request::Terminate);

            let mut default_lock = self
                .default_engine_id
                .write()
                .map_err(|_| "Failed to acquire write lock on default_engine_id")?;
            if let Some(ref default_id) = *default_lock {
                if default_id == model_id {
                    *default_lock = engines.keys().next().cloned();
                }
            }

            Ok(())
        } else {
            Err(format!("Model {model_id} not found"))
        }
    }

    pub fn list_models(&self) -> Result<Vec<String>, String> {
        let engines = self
            .engines
            .read()
            .map_err(|_| "Failed to acquire read lock on engines")?;
        Ok(engines.keys().cloned().collect())
    }

    pub fn get_default_model_id(&self) -> Result<Option<String>, String> {
        let default_lock = self
            .default_engine_id
            .read()
            .map_err(|_| "Failed to acquire read lock on default_engine_id")?;
        Ok(default_lock.clone())
    }

    pub fn set_default_model_id(&self, model_id: &str) -> Result<(), String> {
        let engines = self
            .engines
            .read()
            .map_err(|_| "Failed to acquire read lock on engines")?;
        if !engines.contains_key(model_id) {
            return Err(format!("Model {model_id} not found"));
        }
        drop(engines);

        let mut default_lock = self
            .default_engine_id
            .write()
            .map_err(|_| "Failed to acquire write lock on default_engine_id")?;
        *default_lock = Some(model_id.to_string());

        Ok(())
    }

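    /// Routes `request` to the engine named inside it (for [`Request::Normal`])
    /// or to the default engine, blocking until the request is accepted.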
    pub fn send_request(&self, mut request: Request) -> Result<(), MistralRsError> {
        let model_id = match &mut request {
            Request::Normal(normal_req) => normal_req.model_id.as_deref(),
            _ => None,
        };

        let sender = self.get_sender(model_id)?;
        sender
            .blocking_send(request)
            .map_err(|_| MistralRsError::SenderPoisoned)
    }

    pub fn maybe_log_request(this: Arc<Self>, repr: String) {
        if let Some(file) = &this.log {
            let mut f = OpenOptions::new()
                .append(true)
                .create(true)
                .open(file)
                .expect("Unable to open file");
            let time = chrono::offset::Local::now();
            f.write_all(format!("Request at {time}: {repr}\n\n").as_bytes())
                .expect("Unable to write data");
        }
    }

    pub fn maybe_log_response<T: Serialize>(this: Arc<Self>, resp: &T) {
        if let Some(file) = &this.log {
            let mut f = OpenOptions::new()
                .append(true)
                .create(true)
                .open(file)
                .expect("Unable to open file");
            let time = chrono::offset::Local::now();
            let repr = serde_json::to_string(resp).expect("Serialization of response failed.");
            f.write_all(format!("Response at {time}: {repr}\n\n").as_bytes())
                .expect("Unable to write data");
        }
    }

    pub fn maybe_log_error(this: Arc<Self>, err: &dyn Error) {
        if let Some(file) = &this.log {
            let mut f = OpenOptions::new()
                .append(true)
                .create(true)
                .open(file)
                .expect("Unable to open file");
            let time = chrono::offset::Local::now();
            f.write_all(format!("Error response at {time}: {err}\n\n").as_bytes())
                .expect("Unable to write data");
        }
    }

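    /// Returns how many tool callbacks with attached tool definitions are
    /// registered for the given model, or for the default model when `None`.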
    pub fn get_tools_count(&self, model_id: Option<&str>) -> Result<usize, String> {
        let resolved_model_id = match model_id {
            Some(id) => id.to_string(),
            None => {
                let default_lock = self
                    .default_engine_id
                    .read()
                    .map_err(|_| "Failed to acquire read lock")?;
                default_lock
                    .as_ref()
                    .ok_or("No default engine set")?
                    .clone()
            }
        };

        let engines = self
            .engines
            .read()
            .map_err(|_| "Failed to acquire read lock on engines")?;
        if let Some(engine_instance) = engines.get(&resolved_model_id) {
            Ok(engine_instance.reboot_state.tool_callbacks_with_tools.len())
        } else {
            Err(format!("Model {resolved_model_id} not found"))
        }
    }

    pub fn has_mcp_client(&self, model_id: Option<&str>) -> Result<bool, String> {
        let resolved_model_id = match model_id {
            Some(id) => id.to_string(),
            None => {
                let default_lock = self
                    .default_engine_id
                    .read()
                    .map_err(|_| "Failed to acquire read lock")?;
                default_lock
                    .as_ref()
                    .ok_or("No default engine set")?
                    .clone()
            }
        };

        let engines = self
            .engines
            .read()
            .map_err(|_| "Failed to acquire read lock on engines")?;
        if let Some(engine_instance) = engines.get(&resolved_model_id) {
            Ok(engine_instance.reboot_state.mcp_client_config.is_some())
        } else {
            Err(format!("Model {resolved_model_id} not found"))
        }
    }

    pub fn config(&self, model_id: Option<&str>) -> Result<MistralRsConfig, String> {
        let resolved_model_id = match model_id {
            Some(id) => id.to_string(),
            None => {
                let default_lock = self
                    .default_engine_id
                    .read()
                    .map_err(|_| "Failed to acquire read lock")?;
                default_lock
                    .as_ref()
                    .ok_or("No default engine set")?
                    .clone()
            }
        };

        let engines = self
            .engines
            .read()
            .map_err(|_| "Failed to acquire read lock on engines")?;
        if let Some(engine_instance) = engines.get(&resolved_model_id) {
            Ok(engine_instance.config.clone())
        } else {
            Err(format!("Model {resolved_model_id} not found"))
        }
    }
}