1#![deny(clippy::cast_possible_truncation, clippy::cast_precision_loss)]
2use candle_core::Device;
3use engine::Engine;
4pub use engine::{
5 get_engine_terminate_flag, reset_engine_terminate_flag, should_terminate_engine_sequences,
6 BertEmbeddingModel, EngineInstruction, ENGINE_INSTRUCTIONS, TERMINATE_ALL_NEXT_STEP,
7};
8use hf_hub::Cache;
9pub use lora::Ordering;
10pub use pipeline::ModelCategory;
11pub use pipeline::Pipeline;
12#[cfg(feature = "pyo3_macros")]
13use pyo3::exceptions::PyValueError;
14use std::collections::HashMap;
15use std::sync::OnceLock;
16use std::time::Instant;
17use std::{
18 cell::RefCell,
19 error::Error,
20 fs::OpenOptions,
21 io::Write,
22 sync::{atomic::AtomicBool, Arc, Mutex, RwLock},
23 thread::{self, JoinHandle},
24 time::{SystemTime, UNIX_EPOCH},
25};
26use tokio::sync::mpsc::{channel, Sender};
27use tracing::info;
28use tracing::warn;
29
30mod cuda;
31mod device_map;
32mod engine;
33mod lora;
34mod model_loader;
35mod ops;
36pub use model_loader::{
37 get_auto_device_map_params, get_model_dtype, get_tgt_non_granular_index, LoaderBuilder,
38};
39mod embedding_models;
40mod kv_cache;
41mod search;
42
43mod model_selected;
44pub use model_selected::ModelSelected;
45pub use toml_selector::{get_toml_selected_model_device_map_params, get_toml_selected_model_dtype};
46
47mod amoe;
48mod attention;
49mod diffusion_models;
50pub mod distributed;
51mod embedding;
52mod gguf;
53pub mod layers;
54mod layers_masker;
55mod layers_utils;
56pub mod matformer;
57mod models;
58mod paged_attention;
59mod pipeline;
60mod prefix_cacher;
61mod request;
62mod response;
63mod sampler;
64mod scheduler;
65mod sequence;
66mod speech_models;
67mod toml_selector;
68mod tools;
69mod topology;
70mod utils;
71mod vision_models;
72mod xlora_models;
73
74pub use amoe::{AnyMoeConfig, AnyMoeExpertType};
75pub use device_map::{
76 DeviceLayerMapMetadata, DeviceMapMetadata, DeviceMapSetting, LayerDeviceMapper,
77};
78pub use gguf::{GGUFArchitecture, GGUF_MULTI_FILE_DELIMITER};
79pub use mistralrs_audio::AudioInput;
80pub use mistralrs_mcp::{
81 CalledFunction, Function, Tool, ToolCallback, ToolCallbackWithTool, ToolType,
82};
83pub use mistralrs_mcp::{
84 McpClient, McpClientConfig, McpServerConfig, McpServerSource, McpToolInfo,
85};
86pub use mistralrs_quant::{IsqType, MULTI_LORA_DELIMITER};
87pub use paged_attention::{MemoryGpuConfig, PagedAttentionConfig, PagedCacheType};
88pub use pipeline::{
89 chat_template::ChatTemplate, parse_isq_value, AdapterPaths, AnyMoeLoader, AnyMoePipeline,
90 AutoDeviceMapParams, AutoLoader, AutoLoaderBuilder, DiffusionGenerationParams, DiffusionLoader,
91 DiffusionLoaderBuilder, DiffusionLoaderType, EmbeddingLoader, EmbeddingLoaderBuilder,
92 EmbeddingLoaderType, EmbeddingModelPaths, EmbeddingSpecificConfig, GGMLLoader,
93 GGMLLoaderBuilder, GGMLSpecificConfig, GGUFLoader, GGUFLoaderBuilder, GGUFSpecificConfig,
94 GemmaLoader, Idefics2Loader, IsqOrganization, LLaVALoader, LLaVANextLoader, LlamaLoader,
95 Loader, LocalModelPaths, LoraAdapterPaths, MistralLoader, MixtralLoader, Modalities, ModelKind,
96 ModelPaths, MultimodalPromptPrefixer, NormalLoader, NormalLoaderBuilder, NormalLoaderType,
97 NormalSpecificConfig, Phi2Loader, Phi3Loader, Phi3VLoader, Qwen2Loader, SpeculativeConfig,
98 SpeculativeLoader, SpeculativePipeline, SpeechLoader, SpeechPipeline, Starcoder2Loader,
99 SupportedModality, TokenSource, VisionLoader, VisionLoaderBuilder, VisionLoaderType,
100 VisionSpecificConfig, UQFF_MULTI_FILE_DELIMITER,
101};
102pub use request::{
103 ApproximateUserLocation, Constraint, DetokenizationRequest, ImageGenerationResponseFormat,
104 LlguidanceGrammar, MessageContent, NormalRequest, Request, RequestMessage, SearchContextSize,
105 TokenizationRequest, WebSearchOptions, WebSearchUserLocation,
106};
107pub use response::*;
108pub use sampler::{
109 CustomLogitsProcessor, DrySamplingParams, SamplingParams, StopTokens, TopLogprob,
110};
111pub use scheduler::{DefaultSchedulerMethod, SchedulerConfig};
112pub use search::{SearchCallback, SearchFunctionParameters, SearchResult};
113use serde::Serialize;
114pub use speech_models::{utils as speech_utils, SpeechGenerationConfig, SpeechLoaderType};
115use tokio::runtime::Runtime;
116use toml_selector::{TomlLoaderArgs, TomlSelector};
117pub use tools::{ToolCallResponse, ToolCallType, ToolCallbacks, ToolChoice};
118pub use topology::{LayerTopology, Topology};
119pub use utils::debug::initialize_logging;
120pub use utils::memory_usage::MemoryUsage;
121pub use utils::normal::{ModelDType, TryIntoDType};
122pub use utils::{paged_attn_supported, using_flash_attn};
123
124pub use llguidance;
126
/// Crate-internal debug switch; initialised to `false`.
// NOTE(review): where this flag is flipped is not visible in this file section.
pub(crate) static DEBUG: AtomicBool = AtomicBool::new(false);
/// Process-wide slot for a Hugging Face hub [`Cache`]; being a `OnceLock`, it can be set at most once.
pub static GLOBAL_HF_CACHE: OnceLock<Cache> = OnceLock::new();
130
/// Per-engine options that are forwarded, field by field, to `Engine::new` when an engine
/// thread is started.
#[derive(Clone)]
pub struct EngineConfig {
    /// Forwarded as the engine's `no_kv_cache` argument (`false` by default).
    pub no_kv_cache: bool,
    /// Forwarded as the engine's `no_prefix_cache` argument (`false` by default).
    pub no_prefix_cache: bool,
    /// Forwarded as the engine's `prefix_cache_n` argument (`16` by default).
    pub prefix_cache_n: usize,
    /// Forwarded as the engine's `disable_eos_stop` argument (`false` by default).
    pub disable_eos_stop: bool,
    /// Forwarded as the engine's throughput-logging switch (`true` by default).
    pub throughput_logging_enabled: bool,
    /// Optional embedding model handed to the engine; presumably used for web search — confirm
    /// against the engine implementation.
    pub search_embedding_model: Option<BertEmbeddingModel>,
    /// Optional user-supplied search callback handed to the engine.
    pub search_callback: Option<Arc<SearchCallback>>,
    /// Tool callbacks handed to the engine (empty by default).
    pub tool_callbacks: tools::ToolCallbacks,
    /// Tool callbacks that carry their tool definition, handed to the engine (empty by default).
    pub tool_callbacks_with_tools: tools::ToolCallbacksWithTools,
}
144
impl Default for EngineConfig {
    /// Baseline configuration: both cache opt-outs and `disable_eos_stop` are off, sixteen
    /// prefix-cache entries, throughput logging on, and no search or tool callbacks.
    fn default() -> Self {
        Self {
            no_kv_cache: false,
            no_prefix_cache: false,
            prefix_cache_n: 16,
            disable_eos_stop: false,
            throughput_logging_enabled: true,
            search_embedding_model: None,
            search_callback: None,
            tool_callbacks: HashMap::new(),
            tool_callbacks_with_tools: HashMap::new(),
        }
    }
}
160
/// Settings supplied when registering an additional model through `MistralRs::add_model`.
#[derive(Clone)]
pub struct AddModelConfig {
    /// Engine options used to start the new model's engine.
    pub engine_config: EngineConfig,
    /// Optional MCP client configuration recorded alongside the model.
    pub mcp_client_config: Option<McpClientConfig>,
}
167
168impl AddModelConfig {
169 pub fn new(engine_config: EngineConfig) -> Self {
170 Self {
171 engine_config,
172 mcp_client_config: None,
173 }
174 }
175
176 pub fn with_mcp_config(mut self, mcp_config: McpClientConfig) -> Self {
177 self.mcp_client_config = Some(mcp_config);
178 self
179 }
180}
181
/// Snapshot of pipeline properties captured when an engine is created.
#[derive(Clone)]
pub struct MistralRsConfig {
    /// Model kind taken from the pipeline metadata.
    pub kind: ModelKind,
    /// Device reported by the pipeline.
    pub device: Device,
    /// Category reported by the pipeline.
    pub category: ModelCategory,
    /// Input/output modalities taken from the pipeline metadata.
    pub modalities: Modalities,
}
189
/// Bookkeeping for one running engine thread.
struct EngineInstance {
    /// Sending half of the request channel whose receiver is owned by the engine.
    sender: Sender<Request>,
    /// Handle of the OS thread that runs the engine; checked to detect a finished engine.
    engine_handler: JoinHandle<()>,
    /// Saved construction parameters used to start a replacement engine.
    reboot_state: RebootState,
    /// Pipeline properties captured at creation time.
    config: MistralRsConfig,
    /// Pipeline category; the same value that is stored in `config.category`.
    category: ModelCategory,
}
198
/// Front-end handle that owns one engine per registered model and routes requests to them.
pub struct MistralRs {
    /// Engines keyed by model id.
    engines: RwLock<HashMap<String, EngineInstance>>,
    /// Model id used when a caller does not name one; `None` when no default is set.
    default_engine_id: RwLock<Option<String>>,
    /// Path of the file the `maybe_log_*` helpers append to; `None` disables that logging.
    log: Option<String>,
    /// Name reported by the pipeline this instance was originally built with.
    id: String,
    /// Construction time in whole seconds since the Unix epoch.
    creation_time: u64,
    /// Next value to be handed out by `next_request_id`; starts at 1.
    next_request_id: Mutex<RefCell<usize>>,
}
211
/// Everything needed to start a replacement engine for a model: the pipeline, the scheduler
/// configuration and a copy of each engine option.
#[derive(Clone)]
struct RebootState {
    /// Shared pipeline the engine runs.
    pipeline: Arc<tokio::sync::Mutex<dyn Pipeline>>,
    /// Scheduler configuration passed to the engine.
    method: SchedulerConfig,
    /// Copy of the engine's `no_kv_cache` option.
    no_kv_cache: bool,
    /// Copy of the engine's `no_prefix_cache` option.
    no_prefix_cache: bool,
    /// Copy of the engine's `prefix_cache_n` option.
    prefix_cache_n: usize,
    /// Copy of the engine's `disable_eos_stop` option.
    disable_eos_stop: bool,
    /// Copy of the engine's throughput-logging switch.
    throughput_logging_enabled: bool,
    /// Copy of the optional search embedding model.
    search_embedding_model: Option<BertEmbeddingModel>,
    /// Copy of the optional search callback.
    search_callback: Option<Arc<search::SearchCallback>>,
    /// Copy of the registered tool callbacks.
    tool_callbacks: tools::ToolCallbacks,
    /// Copy of the registered tool callbacks that carry tool definitions; its length is what
    /// `get_tools_count` reports.
    tool_callbacks_with_tools: tools::ToolCallbacksWithTools,
    /// MCP client configuration the model was registered with; its presence is what
    /// `has_mcp_client` reports.
    mcp_client_config: Option<McpClientConfig>,
}
227
/// Failures reported when an engine, or the channel leading to it, cannot be used.
#[derive(Debug)]
pub enum MistralRsError {
    EnginePoisoned,
    SenderPoisoned,
}

impl std::fmt::Display for MistralRsError {
    /// Writes the bare variant name, matching the derived `Debug` output.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let label = match self {
            Self::EnginePoisoned => "EnginePoisoned",
            Self::SenderPoisoned => "SenderPoisoned",
        };
        f.write_str(label)
    }
}

impl std::error::Error for MistralRsError {}
241
#[cfg(feature = "pyo3_macros")]
impl From<MistralRsError> for pyo3::PyErr {
    /// Converts the error into a Python `ValueError` whose message is the variant's debug name.
    fn from(value: MistralRsError) -> Self {
        let message = format!("{value:?}");
        PyValueError::new_err(message)
    }
}
248
/// Builder collecting the pipeline, scheduler configuration and optional settings used to
/// construct a [`MistralRs`].
pub struct MistralRsBuilder {
    /// Pipeline the initial engine will run.
    pipeline: Arc<tokio::sync::Mutex<dyn Pipeline>>,
    /// Scheduler configuration for the initial engine.
    method: SchedulerConfig,
    /// Optional path of the request/response log file.
    log: Option<String>,
    /// `None` means "use the default" (`false`).
    no_kv_cache: Option<bool>,
    /// `None` means "use the default" (`false`).
    no_prefix_cache: Option<bool>,
    /// `None` means "use the default" (`16`).
    prefix_cache_n: Option<usize>,
    /// `None` means "use the default" (`false`).
    disable_eos_stop: Option<bool>,
    /// Whether the engine is asked to log throughput.
    throughput_logging_enabled: bool,
    /// Optional embedding model handed to the engine.
    search_embedding_model: Option<BertEmbeddingModel>,
    /// Optional search callback handed to the engine.
    search_callback: Option<Arc<SearchCallback>>,
    /// Tool callbacks registered by name.
    tool_callbacks: tools::ToolCallbacks,
    /// Tool callbacks registered by name together with their tool definition; MCP-provided
    /// tools are merged into this map during construction.
    tool_callbacks_with_tools: tools::ToolCallbacksWithTools,
    /// Optional MCP client configuration initialised during construction.
    mcp_client_config: Option<McpClientConfig>,
}
267
268impl MistralRsBuilder {
269 pub fn new(
273 pipeline: Arc<tokio::sync::Mutex<dyn Pipeline>>,
274 method: SchedulerConfig,
275 throughput_logging: bool,
276 search_embedding_model: Option<BertEmbeddingModel>,
277 ) -> Self {
278 Self {
279 pipeline,
280 method,
281 log: None,
282 no_kv_cache: None,
283 no_prefix_cache: None,
284 prefix_cache_n: None,
285 disable_eos_stop: None,
286 throughput_logging_enabled: throughput_logging,
287 search_embedding_model,
288 search_callback: None,
289 tool_callbacks: HashMap::new(),
290 tool_callbacks_with_tools: HashMap::new(),
291 mcp_client_config: None,
292 }
293 }
294 pub fn with_log(mut self, log: String) -> Self {
295 self.log = Some(log);
296 self
297 }
298 pub fn with_opt_log(mut self, log: Option<String>) -> Self {
299 self.log = log;
300 self
301 }
302 pub fn with_no_kv_cache(mut self, no_kv_cache: bool) -> Self {
303 self.no_kv_cache = Some(no_kv_cache);
304 self
305 }
306 pub fn with_no_prefix_cache(mut self, no_prefix_cache: bool) -> Self {
307 self.no_prefix_cache = Some(no_prefix_cache);
308 self
309 }
310 pub fn with_prefix_cache_n(mut self, prefix_cache_n: usize) -> Self {
311 self.prefix_cache_n = Some(prefix_cache_n);
312 self
313 }
314 pub fn with_disable_eos_stop(mut self, disable_eos_stop: bool) -> Self {
315 self.disable_eos_stop = Some(disable_eos_stop);
316 self
317 }
318
319 pub fn with_search_callback(mut self, search_callback: Arc<SearchCallback>) -> Self {
321 self.search_callback = Some(search_callback);
322 self
323 }
324
325 pub fn with_tool_callback(
327 mut self,
328 name: impl Into<String>,
329 tool_callback: Arc<ToolCallback>,
330 ) -> Self {
331 self.tool_callbacks.insert(name.into(), tool_callback);
332 self
333 }
334
335 pub fn with_tool_callback_and_tool(
338 mut self,
339 name: impl Into<String>,
340 tool_callback: Arc<ToolCallback>,
341 tool: Tool,
342 ) -> Self {
343 let name = name.into();
344 self.tool_callbacks_with_tools.insert(
345 name,
346 ToolCallbackWithTool {
347 callback: tool_callback,
348 tool,
349 },
350 );
351 self
352 }
353
354 pub fn with_mcp_client(mut self, config: McpClientConfig) -> Self {
356 self.mcp_client_config = Some(config);
357 self
358 }
359
360 pub async fn build(self) -> Arc<MistralRs> {
361 MistralRs::new(self).await
362 }
363}
364
365impl Drop for MistralRs {
366 fn drop(&mut self) {
367 if let Ok(engines) = self.engines.read() {
369 for (_, engine) in engines.iter() {
370 let _ = engine.sender.try_send(Request::Terminate);
372 }
373 }
374 }
375}
376
377impl MistralRs {
378 fn create_engine_instance(
380 pipeline: Arc<tokio::sync::Mutex<dyn Pipeline>>,
381 method: SchedulerConfig,
382 config: EngineConfig,
383 reboot_state: RebootState,
384 ) -> Result<EngineInstance, String> {
385 let (tx, rx) = channel(10_000);
386
387 let category = pipeline.try_lock().unwrap().category();
388 let kind = pipeline.try_lock().unwrap().get_metadata().kind.clone();
389 let device = pipeline.try_lock().unwrap().device();
390 let modalities = pipeline
391 .try_lock()
392 .unwrap()
393 .get_metadata()
394 .modalities
395 .clone();
396
397 info!("Pipeline input modalities are {:?}", &modalities.input);
398 info!("Pipeline output modalities are {:?}", &modalities.output);
399
400 let mistralrs_config = MistralRsConfig {
401 kind,
402 device,
403 category: category.clone(),
404 modalities,
405 };
406
407 let engine_handler = thread::spawn(move || {
408 #[cfg(feature = "metal")]
409 objc::rc::autoreleasepool(move || {
410 let rt = Runtime::new().unwrap();
411 rt.block_on(async move {
412 let engine = Engine::new(
413 rx,
414 pipeline,
415 method,
416 config.no_kv_cache,
417 config.no_prefix_cache,
418 config.prefix_cache_n,
419 config.disable_eos_stop,
420 config.throughput_logging_enabled,
421 config.search_embedding_model,
422 config.search_callback.clone(),
423 config.tool_callbacks.clone(),
424 config.tool_callbacks_with_tools.clone(),
425 )
426 .expect("Engine creation failed.");
427 Arc::new(engine).run().await;
428 })
429 });
430
431 #[cfg(not(feature = "metal"))]
432 {
433 let rt = Runtime::new().unwrap();
434 rt.block_on(async move {
435 let engine = Engine::new(
436 rx,
437 pipeline,
438 method,
439 config.no_kv_cache,
440 config.no_prefix_cache,
441 config.prefix_cache_n,
442 config.disable_eos_stop,
443 config.throughput_logging_enabled,
444 config.search_embedding_model,
445 config.search_callback.clone(),
446 config.tool_callbacks.clone(),
447 config.tool_callbacks_with_tools.clone(),
448 )
449 .expect("Engine creation failed.");
450 Arc::new(engine).run().await;
451 })
452 }
453 });
454
455 Ok(EngineInstance {
456 sender: tx,
457 engine_handler,
458 reboot_state,
459 config: mistralrs_config,
460 category,
461 })
462 }
463
464 async fn new(config: MistralRsBuilder) -> Arc<Self> {
465 let MistralRsBuilder {
466 pipeline,
467 method,
468 log,
469 no_kv_cache,
470 no_prefix_cache,
471 prefix_cache_n,
472 disable_eos_stop,
473 throughput_logging_enabled,
474 search_embedding_model,
475 search_callback,
476 tool_callbacks,
477 mut tool_callbacks_with_tools,
478 mcp_client_config,
479 } = config;
480
481 mistralrs_quant::cublaslt::maybe_init_cublas_lt_wrapper(
482 get_mut_arcmutex!(pipeline).device(),
483 );
484
485 let no_kv_cache = no_kv_cache.unwrap_or(false);
486 let no_prefix_cache = no_prefix_cache.unwrap_or(false);
487 let prefix_cache_n = prefix_cache_n.unwrap_or(16);
488 let disable_eos_stop = disable_eos_stop.unwrap_or(false);
489
490 if let Some(config) = &mcp_client_config {
492 let mut mcp_client = McpClient::new(config.clone());
493 let total_servers = config.servers.len();
494
495 match mcp_client.initialize().await {
496 Ok(()) => {
497 let mcp_callbacks_with_tools = mcp_client.get_tool_callbacks_with_tools();
498 let tools_count = mcp_callbacks_with_tools.len();
499
500 for (name, callback_with_tool) in mcp_callbacks_with_tools {
502 tool_callbacks_with_tools.insert(name.clone(), callback_with_tool.clone());
503 }
504
505 if tools_count == 0 {
506 warn!(
507 "MCP client initialized but no tools were registered from {} servers",
508 total_servers
509 );
510 } else {
511 info!(
512 "MCP client initialized successfully with {} tools from {} servers",
513 tools_count, total_servers
514 );
515 }
516 }
517 Err(e) => {
518 warn!(
519 "Failed to initialize MCP client with {} configured servers: {}",
520 total_servers, e
521 );
522 warn!("Continuing without MCP functionality. Check your MCP configuration and server availability.");
523 }
524 }
525 }
526
527 let reboot_state = RebootState {
528 pipeline: pipeline.clone(),
529 method: method.clone(),
530 no_kv_cache,
531 no_prefix_cache,
532 prefix_cache_n,
533 disable_eos_stop,
534 throughput_logging_enabled,
535 search_embedding_model: search_embedding_model.clone(),
536 search_callback: search_callback.clone(),
537 tool_callbacks: tool_callbacks.clone(),
538 tool_callbacks_with_tools: tool_callbacks_with_tools.clone(),
539 mcp_client_config: mcp_client_config.clone(),
540 };
541
542 let engine_config = EngineConfig {
544 no_kv_cache,
545 no_prefix_cache,
546 prefix_cache_n,
547 disable_eos_stop,
548 throughput_logging_enabled,
549 search_embedding_model,
550 search_callback,
551 tool_callbacks,
552 tool_callbacks_with_tools,
553 };
554
555 let engine_instance =
557 Self::create_engine_instance(pipeline.clone(), method, engine_config, reboot_state)
558 .expect("Failed to create engine instance");
559
560 let id = pipeline.try_lock().unwrap().name();
561
562 if distributed::is_daemon() {
563 let request_sender = engine_instance.sender.clone();
564
565 if cfg!(feature = "ring") {
566 distributed::ring_daemon_replicator(request_sender);
568 } else {
569 distributed::nccl_daemon_replicator(request_sender);
571 }
572
573 #[allow(clippy::empty_loop)]
574 loop {}
575 }
576
577 let is_multi_threaded = tokio::runtime::Handle::try_current()
579 .is_ok_and(|h| h.runtime_flavor() != tokio::runtime::RuntimeFlavor::CurrentThread);
580
581 if !distributed::is_daemon()
583 && is_multi_threaded
584 && matches!(
585 engine_instance.category,
586 ModelCategory::Text | ModelCategory::Vision { .. }
587 )
588 {
589 let clone_sender = engine_instance.sender.clone();
590 tokio::task::block_in_place(|| {
591 let (tx, mut rx) = channel(1);
592 let req = Request::Normal(Box::new(NormalRequest {
593 id: 0,
594 messages: RequestMessage::Completion {
595 text: "hello".to_string(),
596 echo_prompt: false,
597 best_of: None,
598 },
599 sampling_params: SamplingParams {
600 max_len: Some(1),
601 ..SamplingParams::deterministic()
602 },
603 response: tx,
604 return_logprobs: false,
605 is_streaming: false,
606 constraint: Constraint::None,
607 suffix: None,
608 tool_choice: None,
609 tools: None,
610 logits_processors: None,
611 return_raw_logits: false,
612 web_search_options: None,
613 model_id: None,
614 truncate_sequence: false,
615 }));
616 info!("Beginning dummy run.");
617 let start = Instant::now();
618 clone_sender.blocking_send(req).unwrap();
619
620 let mut received_any = false;
622 while let Some(_resp) = rx.blocking_recv() {
623 received_any = true;
624 }
625
626 if received_any {
627 let end = Instant::now();
628 info!(
629 "Dummy run completed in {}s.",
630 end.duration_since(start).as_secs_f64()
631 );
632 } else {
633 warn!("Dummy run failed!");
634 }
635 });
636 }
637
638 let mut engines = HashMap::new();
640 engines.insert(id.clone(), engine_instance);
641
642 Arc::new(Self {
643 engines: RwLock::new(engines),
644 default_engine_id: RwLock::new(Some(id.clone())),
645 log,
646 id,
647 creation_time: SystemTime::now()
648 .duration_since(UNIX_EPOCH)
649 .expect("Time travel has occurred!")
650 .as_secs(),
651 next_request_id: Mutex::new(RefCell::new(1)),
652 })
653 }
654
655 fn reboot_engine(&self, model_id: &str) -> Result<(), MistralRsError> {
657 let mut engines = self.engines.write().map_err(|_| {
658 tracing::warn!("Couldn't get write lock on engines during reboot attempt");
659 MistralRsError::EnginePoisoned
660 })?;
661
662 if let Some(engine_instance) = engines.get(model_id) {
663 if !engine_instance.engine_handler.is_finished() {
664 tracing::info!("Engine {} already running, returning ok", model_id);
665 return Ok(());
666 }
667
668 let reboot_state = engine_instance.reboot_state.clone();
669 let engine_config = EngineConfig {
670 no_kv_cache: reboot_state.no_kv_cache,
671 no_prefix_cache: reboot_state.no_prefix_cache,
672 prefix_cache_n: reboot_state.prefix_cache_n,
673 disable_eos_stop: reboot_state.disable_eos_stop,
674 throughput_logging_enabled: reboot_state.throughput_logging_enabled,
675 search_embedding_model: reboot_state.search_embedding_model.clone(),
676 search_callback: reboot_state.search_callback.clone(),
677 tool_callbacks: reboot_state.tool_callbacks.clone(),
678 tool_callbacks_with_tools: reboot_state.tool_callbacks_with_tools.clone(),
679 };
680 let new_engine_instance = Self::create_engine_instance(
681 reboot_state.pipeline.clone(),
682 reboot_state.method.clone(),
683 engine_config,
684 reboot_state,
685 )
686 .map_err(|e| {
687 tracing::error!("Failed to create new engine instance: {}", e);
688 MistralRsError::EnginePoisoned
689 })?;
690
691 engines.insert(model_id.to_string(), new_engine_instance);
692 tracing::info!("Successfully rebooted engine {}", model_id);
693 Ok(())
694 } else {
695 Err(MistralRsError::EnginePoisoned)
696 }
697 }
698
699 fn engine_dead(&self, model_id: &str) -> Result<bool, MistralRsError> {
700 let engines = self.engines.read().map_err(|_| {
701 tracing::warn!("Couldn't get read lock on engines!");
702 MistralRsError::EnginePoisoned
703 })?;
704
705 if let Some(engine_instance) = engines.get(model_id) {
706 Ok(engine_instance.engine_handler.is_finished())
707 } else {
708 Err(MistralRsError::EnginePoisoned)
709 }
710 }
711
712 pub fn get_sender(&self, model_id: Option<&str>) -> Result<Sender<Request>, MistralRsError> {
714 let resolved_model_id = match model_id {
715 Some(id) => id.to_string(),
716 None => {
717 let default_lock = self
718 .default_engine_id
719 .read()
720 .map_err(|_| MistralRsError::SenderPoisoned)?;
721 default_lock
722 .as_ref()
723 .ok_or(MistralRsError::EnginePoisoned)?
724 .clone()
725 }
726 };
727
728 if self.engine_dead(&resolved_model_id)? {
729 tracing::warn!("Engine {} is dead, rebooting", resolved_model_id);
730 self.reboot_engine(&resolved_model_id)?
731 }
732
733 let engines = self
734 .engines
735 .read()
736 .map_err(|_| MistralRsError::SenderPoisoned)?;
737 if let Some(engine_instance) = engines.get(&resolved_model_id) {
738 Ok(engine_instance.sender.clone())
739 } else {
740 Err(MistralRsError::EnginePoisoned)
741 }
742 }
743
    /// Returns a copy of this instance's id, i.e. the name of the pipeline it was built with.
    pub fn get_id(&self) -> String {
        self.id.clone()
    }
747
    /// Returns the construction time in whole seconds since the Unix epoch.
    pub fn get_creation_time(&self) -> u64 {
        self.creation_time
    }
751
752 pub fn get_model_category(
754 &self,
755 model_id: Option<&str>,
756 ) -> Result<ModelCategory, MistralRsError> {
757 let resolved_model_id = match model_id {
758 Some(id) => id.to_string(),
759 None => {
760 let default_lock = self
761 .default_engine_id
762 .read()
763 .map_err(|_| MistralRsError::SenderPoisoned)?;
764 default_lock
765 .as_ref()
766 .ok_or(MistralRsError::EnginePoisoned)?
767 .clone()
768 }
769 };
770
771 let engines = self
772 .engines
773 .read()
774 .map_err(|_| MistralRsError::SenderPoisoned)?;
775 if let Some(engine_instance) = engines.get(&resolved_model_id) {
776 Ok(engine_instance.category.clone())
777 } else {
778 Err(MistralRsError::EnginePoisoned)
779 }
780 }
781
782 pub fn next_request_id(&self) -> usize {
783 let l = self.next_request_id.lock().unwrap();
784 let last = &mut *l.borrow_mut();
785 let last_v = *last;
786 *last += 1;
787 last_v
788 }
789
790 pub async fn add_model(
792 &self,
793 model_id: String,
794 pipeline: Arc<tokio::sync::Mutex<dyn Pipeline>>,
795 method: SchedulerConfig,
796 config: AddModelConfig,
797 ) -> Result<(), String> {
798 let reboot_state = RebootState {
799 pipeline: pipeline.clone(),
800 method: method.clone(),
801 no_kv_cache: config.engine_config.no_kv_cache,
802 no_prefix_cache: config.engine_config.no_prefix_cache,
803 prefix_cache_n: config.engine_config.prefix_cache_n,
804 disable_eos_stop: config.engine_config.disable_eos_stop,
805 throughput_logging_enabled: config.engine_config.throughput_logging_enabled,
806 search_embedding_model: config.engine_config.search_embedding_model.clone(),
807 search_callback: config.engine_config.search_callback.clone(),
808 tool_callbacks: config.engine_config.tool_callbacks.clone(),
809 tool_callbacks_with_tools: config.engine_config.tool_callbacks_with_tools.clone(),
810 mcp_client_config: config.mcp_client_config.clone(),
811 };
812
813 let engine_instance =
814 Self::create_engine_instance(pipeline, method, config.engine_config, reboot_state)?;
815
816 let mut engines = self
817 .engines
818 .write()
819 .map_err(|_| "Failed to acquire write lock on engines")?;
820 engines.insert(model_id.clone(), engine_instance);
821
822 if engines.len() == 1 {
824 let mut default_lock = self
825 .default_engine_id
826 .write()
827 .map_err(|_| "Failed to acquire write lock on default_engine_id")?;
828 *default_lock = Some(model_id.clone());
829 }
830
831 Ok(())
832 }
833
834 pub fn remove_model(&self, model_id: &str) -> Result<(), String> {
836 let mut engines = self
837 .engines
838 .write()
839 .map_err(|_| "Failed to acquire write lock on engines")?;
840
841 if engines.len() <= 1 {
842 return Err("Cannot remove the last model from MistralRs".to_string());
843 }
844
845 if let Some(engine_instance) = engines.remove(model_id) {
846 let _ = engine_instance.sender.blocking_send(Request::Terminate);
848
849 let mut default_lock = self
851 .default_engine_id
852 .write()
853 .map_err(|_| "Failed to acquire write lock on default_engine_id")?;
854 if let Some(ref default_id) = *default_lock {
855 if default_id == model_id {
856 *default_lock = engines.keys().next().cloned();
858 }
859 }
860
861 Ok(())
862 } else {
863 Err(format!("Model {model_id} not found"))
864 }
865 }
866
867 pub fn list_models(&self) -> Result<Vec<String>, String> {
869 let engines = self
870 .engines
871 .read()
872 .map_err(|_| "Failed to acquire read lock on engines")?;
873 Ok(engines.keys().cloned().collect())
874 }
875
876 pub fn get_default_model_id(&self) -> Result<Option<String>, String> {
878 let default_lock = self
879 .default_engine_id
880 .read()
881 .map_err(|_| "Failed to acquire read lock on default_engine_id")?;
882 Ok(default_lock.clone())
883 }
884
885 pub fn set_default_model_id(&self, model_id: &str) -> Result<(), String> {
887 let engines = self
888 .engines
889 .read()
890 .map_err(|_| "Failed to acquire read lock on engines")?;
891 if !engines.contains_key(model_id) {
892 return Err(format!("Model {model_id} not found"));
893 }
894 drop(engines);
895
896 let mut default_lock = self
897 .default_engine_id
898 .write()
899 .map_err(|_| "Failed to acquire write lock on default_engine_id")?;
900 *default_lock = Some(model_id.to_string());
901
902 Ok(())
903 }
904
905 pub fn send_request(&self, mut request: Request) -> Result<(), MistralRsError> {
907 let model_id = match &mut request {
908 Request::Normal(normal_req) => normal_req.model_id.as_deref(),
909 _ => None, };
911
912 let sender = self.get_sender(model_id)?;
913 sender
914 .blocking_send(request)
915 .map_err(|_| MistralRsError::SenderPoisoned)
916 }
917
918 pub fn maybe_log_request(this: Arc<Self>, repr: String) {
919 if let Some(file) = &this.log {
920 let mut f = OpenOptions::new()
921 .append(true)
922 .create(true) .open(file)
924 .expect("Unable to open file");
925 let time = chrono::offset::Local::now();
926 f.write_all(format!("Request at {time}: {repr}\n\n").as_bytes())
927 .expect("Unable to write data");
928 }
929 }
930
931 pub fn maybe_log_response<T: Serialize>(this: Arc<Self>, resp: &T) {
932 if let Some(file) = &this.log {
933 let mut f = OpenOptions::new()
934 .append(true)
935 .create(true) .open(file)
937 .expect("Unable to open file");
938 let time = chrono::offset::Local::now();
939 let repr = serde_json::to_string(resp).expect("Serialization of response failed.");
940 f.write_all(format!("Response at {time}: {repr}\n\n").as_bytes())
941 .expect("Unable to write data");
942 }
943 }
944
945 pub fn maybe_log_error(this: Arc<Self>, err: &dyn Error) {
946 if let Some(file) = &this.log {
947 let mut f = OpenOptions::new()
948 .append(true)
949 .create(true) .open(file)
951 .expect("Unable to open file");
952 let time = chrono::offset::Local::now();
953 f.write_all(format!("Error response at {time}: {err}\n\n").as_bytes())
954 .expect("Unable to write data");
955 }
956 }
957
958 pub fn get_tools_count(&self, model_id: Option<&str>) -> Result<usize, String> {
960 let resolved_model_id = match model_id {
961 Some(id) => id.to_string(),
962 None => {
963 let default_lock = self
964 .default_engine_id
965 .read()
966 .map_err(|_| "Failed to acquire read lock")?;
967 default_lock
968 .as_ref()
969 .ok_or("No default engine set")?
970 .clone()
971 }
972 };
973
974 let engines = self
975 .engines
976 .read()
977 .map_err(|_| "Failed to acquire read lock on engines")?;
978 if let Some(engine_instance) = engines.get(&resolved_model_id) {
979 Ok(engine_instance.reboot_state.tool_callbacks_with_tools.len())
980 } else {
981 Err(format!("Model {resolved_model_id} not found"))
982 }
983 }
984
985 pub fn has_mcp_client(&self, model_id: Option<&str>) -> Result<bool, String> {
987 let resolved_model_id = match model_id {
988 Some(id) => id.to_string(),
989 None => {
990 let default_lock = self
991 .default_engine_id
992 .read()
993 .map_err(|_| "Failed to acquire read lock")?;
994 default_lock
995 .as_ref()
996 .ok_or("No default engine set")?
997 .clone()
998 }
999 };
1000
1001 let engines = self
1002 .engines
1003 .read()
1004 .map_err(|_| "Failed to acquire read lock on engines")?;
1005 if let Some(engine_instance) = engines.get(&resolved_model_id) {
1006 Ok(engine_instance.reboot_state.mcp_client_config.is_some())
1007 } else {
1008 Err(format!("Model {resolved_model_id} not found"))
1009 }
1010 }
1011
1012 pub fn config(&self, model_id: Option<&str>) -> Result<MistralRsConfig, String> {
1014 let resolved_model_id = match model_id {
1015 Some(id) => id.to_string(),
1016 None => {
1017 let default_lock = self
1018 .default_engine_id
1019 .read()
1020 .map_err(|_| "Failed to acquire read lock")?;
1021 default_lock
1022 .as_ref()
1023 .ok_or("No default engine set")?
1024 .clone()
1025 }
1026 };
1027
1028 let engines = self
1029 .engines
1030 .read()
1031 .map_err(|_| "Failed to acquire read lock on engines")?;
1032 if let Some(engine_instance) = engines.get(&resolved_model_id) {
1033 Ok(engine_instance.config.clone())
1034 } else {
1035 Err(format!("Model {resolved_model_id} not found"))
1036 }
1037 }
1038}