1#![deny(clippy::cast_possible_truncation, clippy::cast_precision_loss)]
2use candle_core::Device;
3use engine::Engine;
4pub use engine::{
5 get_engine_terminate_flag, reset_engine_terminate_flag, should_terminate_engine_sequences,
6 BertEmbeddingModel, EngineInstruction, ENGINE_INSTRUCTIONS, TERMINATE_ALL_NEXT_STEP,
7};
8use hf_hub::Cache;
9pub use lora::Ordering;
10pub use pipeline::ModelCategory;
11pub use pipeline::Pipeline;
12#[cfg(feature = "pyo3_macros")]
13use pyo3::exceptions::PyValueError;
14use std::collections::HashMap;
15use std::sync::OnceLock;
16use std::time::Instant;
17use std::{
18 cell::RefCell,
19 error::Error,
20 fs::OpenOptions,
21 io::Write,
22 sync::{atomic::AtomicBool, Arc, Mutex, RwLock},
23 thread::{self, JoinHandle},
24 time::{SystemTime, UNIX_EPOCH},
25};
26use tokio::sync::mpsc::{channel, Sender};
27use tracing::info;
28use tracing::warn;
29
30mod cuda;
31mod device_map;
32mod engine;
33mod lora;
34mod model_loader;
35mod ops;
36pub use model_loader::{
37 get_auto_device_map_params, get_model_dtype, get_tgt_non_granular_index, LoaderBuilder,
38};
39mod kv_cache;
40mod search;
41
42mod model_selected;
43pub use model_selected::ModelSelected;
44pub use toml_selector::{get_toml_selected_model_device_map_params, get_toml_selected_model_dtype};
45
46mod amoe;
47#[cfg(not(any(all(feature = "cuda", target_family = "unix"), feature = "metal")))]
48mod dummy_paged_attention;
49mod embedding;
50mod gguf;
51pub mod layers;
52mod layers_masker;
53mod layers_utils;
54mod models;
55#[cfg(any(all(feature = "cuda", target_family = "unix"), feature = "metal"))]
56mod paged_attention;
57#[cfg(not(any(all(feature = "cuda", target_family = "unix"), feature = "metal")))]
58use dummy_paged_attention as paged_attention;
59mod attention;
60mod diffusion_models;
61pub mod distributed;
62mod pipeline;
63mod prefix_cacher;
64mod request;
65mod response;
66mod sampler;
67mod scheduler;
68mod sequence;
69mod speech_models;
70mod toml_selector;
71mod tools;
72mod topology;
73mod utils;
74mod vision_models;
75mod xlora_models;
76
77pub use amoe::{AnyMoeConfig, AnyMoeExpertType};
78pub use device_map::{
79 DeviceLayerMapMetadata, DeviceMapMetadata, DeviceMapSetting, LayerDeviceMapper,
80};
81pub use gguf::{GGUFArchitecture, GGUF_MULTI_FILE_DELIMITER};
82pub use mistralrs_audio::AudioInput;
83pub use mistralrs_mcp::{
84 CalledFunction, Function, Tool, ToolCallback, ToolCallbackWithTool, ToolType,
85};
86pub use mistralrs_mcp::{
87 McpClient, McpClientConfig, McpServerConfig, McpServerSource, McpToolInfo,
88};
89pub use mistralrs_quant::{IsqType, MULTI_LORA_DELIMITER};
90pub use paged_attention::{MemoryGpuConfig, PagedAttentionConfig, PagedCacheType};
91pub use pipeline::{
92 chat_template::ChatTemplate, parse_isq_value, AdapterPaths, AnyMoeLoader, AnyMoePipeline,
93 AutoDeviceMapParams, AutoLoader, AutoLoaderBuilder, DiffusionGenerationParams, DiffusionLoader,
94 DiffusionLoaderBuilder, DiffusionLoaderType, GGMLLoader, GGMLLoaderBuilder, GGMLSpecificConfig,
95 GGUFLoader, GGUFLoaderBuilder, GGUFSpecificConfig, GemmaLoader, Idefics2Loader,
96 IsqOrganization, LLaVALoader, LLaVANextLoader, LlamaLoader, Loader, LocalModelPaths,
97 LoraAdapterPaths, MistralLoader, MixtralLoader, Modalities, ModelKind, ModelPaths,
98 MultimodalPromptPrefixer, NormalLoader, NormalLoaderBuilder, NormalLoaderType,
99 NormalSpecificConfig, Phi2Loader, Phi3Loader, Phi3VLoader, Qwen2Loader, SpeculativeConfig,
100 SpeculativeLoader, SpeculativePipeline, SpeechLoader, SpeechPipeline, Starcoder2Loader,
101 SupportedModality, TokenSource, VisionLoader, VisionLoaderBuilder, VisionLoaderType,
102 VisionSpecificConfig, UQFF_MULTI_FILE_DELIMITER,
103};
104pub use request::{
105 ApproximateUserLocation, Constraint, DetokenizationRequest, ImageGenerationResponseFormat,
106 LlguidanceGrammar, MessageContent, NormalRequest, Request, RequestMessage, SearchContextSize,
107 TokenizationRequest, WebSearchOptions, WebSearchUserLocation,
108};
109pub use response::*;
110pub use sampler::{
111 CustomLogitsProcessor, DrySamplingParams, SamplingParams, StopTokens, TopLogprob,
112};
113pub use scheduler::{DefaultSchedulerMethod, SchedulerConfig};
114pub use search::{SearchCallback, SearchFunctionParameters, SearchResult};
115use serde::Serialize;
116pub use speech_models::{utils as speech_utils, SpeechGenerationConfig, SpeechLoaderType};
117use tokio::runtime::Runtime;
118use toml_selector::{TomlLoaderArgs, TomlSelector};
119pub use tools::{ToolCallResponse, ToolCallType, ToolCallbacks, ToolChoice};
120pub use topology::{LayerTopology, Topology};
121pub use utils::debug::initialize_logging;
122pub use utils::memory_usage::MemoryUsage;
123pub use utils::normal::{ModelDType, TryIntoDType};
124pub use utils::{paged_attn_supported, using_flash_attn};
125
126pub use llguidance;
128
/// Crate-wide debug flag, initially `false`. NOTE(review): where it is set and read is not
/// visible in this file — confirm against the rest of the crate.
pub(crate) static DEBUG: AtomicBool = AtomicBool::new(false);
/// Process-wide Hugging Face hub cache. `OnceLock` means it can be initialized at most once;
/// it is empty until some caller sets it.
pub static GLOBAL_HF_CACHE: OnceLock<Cache> = OnceLock::new();
132
133#[derive(Clone)]
135pub struct EngineConfig {
136 pub truncate_sequence: bool,
137 pub no_kv_cache: bool,
138 pub no_prefix_cache: bool,
139 pub prefix_cache_n: usize,
140 pub disable_eos_stop: bool,
141 pub throughput_logging_enabled: bool,
142 pub search_embedding_model: Option<BertEmbeddingModel>,
143 pub search_callback: Option<Arc<SearchCallback>>,
144 pub tool_callbacks: tools::ToolCallbacks,
145 pub tool_callbacks_with_tools: tools::ToolCallbacksWithTools,
146}
147
148impl Default for EngineConfig {
149 fn default() -> Self {
150 Self {
151 truncate_sequence: false,
152 no_kv_cache: false,
153 no_prefix_cache: false,
154 prefix_cache_n: 16,
155 disable_eos_stop: false,
156 throughput_logging_enabled: true,
157 search_embedding_model: None,
158 search_callback: None,
159 tool_callbacks: HashMap::new(),
160 tool_callbacks_with_tools: HashMap::new(),
161 }
162 }
163}
164
165#[derive(Clone)]
167pub struct AddModelConfig {
168 pub engine_config: EngineConfig,
169 pub mcp_client_config: Option<McpClientConfig>,
170}
171
172impl AddModelConfig {
173 pub fn new(engine_config: EngineConfig) -> Self {
174 Self {
175 engine_config,
176 mcp_client_config: None,
177 }
178 }
179
180 pub fn with_mcp_config(mut self, mcp_config: McpClientConfig) -> Self {
181 self.mcp_client_config = Some(mcp_config);
182 self
183 }
184}
185
/// Description of a loaded model, captured from its pipeline when the engine is created.
#[derive(Clone)]
pub struct MistralRsConfig {
    /// Model kind, copied from the pipeline metadata.
    pub kind: ModelKind,
    /// Device reported by the pipeline.
    pub device: Device,
    /// Category reported by the pipeline.
    pub category: ModelCategory,
    /// Input/output modalities, copied from the pipeline metadata.
    pub modalities: Modalities,
}
193
/// A running engine plus everything needed to talk to it or restart it.
struct EngineInstance {
    /// Request queue consumed by the engine thread.
    sender: Sender<Request>,
    /// Handle of the OS thread running the engine; `is_finished()` is used to detect a dead engine.
    engine_handler: JoinHandle<()>,
    /// Settings used to recreate the engine after it has exited.
    reboot_state: RebootState,
    /// Model description captured at creation.
    config: MistralRsConfig,
    /// Same value as `config.category`, kept for direct access.
    category: ModelCategory,
}
202
/// Multi-model front end: owns one engine per model id and routes requests to them.
pub struct MistralRs {
    /// Engines keyed by model id.
    engines: RwLock<HashMap<String, EngineInstance>>,
    /// Model id used when a caller does not name one; `None` if no default is set.
    default_engine_id: RwLock<Option<String>>,
    /// Optional path of a file that request/response/error entries are appended to.
    log: Option<String>,
    /// Name of the pipeline this instance was built with.
    id: String,
    /// Creation time in whole seconds since the Unix epoch.
    creation_time: u64,
    /// Next request id to hand out (the first one is 1).
    /// NOTE(review): the `RefCell` adds nothing under the `Mutex`; a `Mutex<usize>` or an
    /// `AtomicUsize` would be simpler.
    next_request_id: Mutex<RefCell<usize>>,
}
215
/// Everything needed to recreate an engine after its thread has exited
/// (used by `MistralRs::reboot_engine`).
#[derive(Clone)]
struct RebootState {
    /// Shared pipeline; the recreated engine reuses the same instance.
    pipeline: Arc<tokio::sync::Mutex<dyn Pipeline>>,
    /// Scheduler configuration.
    method: SchedulerConfig,
    // The following fields mirror `EngineConfig` one-to-one.
    truncate_sequence: bool,
    no_kv_cache: bool,
    no_prefix_cache: bool,
    prefix_cache_n: usize,
    disable_eos_stop: bool,
    throughput_logging_enabled: bool,
    search_embedding_model: Option<BertEmbeddingModel>,
    search_callback: Option<Arc<search::SearchCallback>>,
    tool_callbacks: tools::ToolCallbacks,
    /// Includes any tools registered from MCP servers at startup.
    tool_callbacks_with_tools: tools::ToolCallbacksWithTools,
    /// MCP configuration the engine was created with, if any.
    mcp_client_config: Option<McpClientConfig>,
}
232
/// Errors reported by [`MistralRs`] when an engine or its request channel cannot be used.
#[derive(Debug)]
pub enum MistralRsError {
    /// The engine registry is unusable: a lock was poisoned, or the requested/default model
    /// is not registered.
    EnginePoisoned,
    /// A request could not be sent, or a lock consulted while looking up a sender was poisoned.
    SenderPoisoned,
}

impl std::fmt::Display for MistralRsError {
    /// Writes the bare variant name, exactly as the derived `Debug` output does.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let variant_name = match self {
            Self::EnginePoisoned => "EnginePoisoned",
            Self::SenderPoisoned => "SenderPoisoned",
        };
        f.write_str(variant_name)
    }
}

impl std::error::Error for MistralRsError {}

/// Surfaces engine errors to Python as `ValueError` carrying the `Debug` text.
#[cfg(feature = "pyo3_macros")]
impl From<MistralRsError> for pyo3::PyErr {
    fn from(value: MistralRsError) -> Self {
        PyValueError::new_err(format!("{value:?}"))
    }
}
253
254pub struct MistralRsBuilder {
258 pipeline: Arc<tokio::sync::Mutex<dyn Pipeline>>,
259 method: SchedulerConfig,
260 log: Option<String>,
261 truncate_sequence: Option<bool>,
262 no_kv_cache: Option<bool>,
263 no_prefix_cache: Option<bool>,
264 prefix_cache_n: Option<usize>,
265 disable_eos_stop: Option<bool>,
266 throughput_logging_enabled: bool,
267 search_embedding_model: Option<BertEmbeddingModel>,
268 search_callback: Option<Arc<SearchCallback>>,
269 tool_callbacks: tools::ToolCallbacks,
270 tool_callbacks_with_tools: tools::ToolCallbacksWithTools,
271 mcp_client_config: Option<McpClientConfig>,
272}
273
274impl MistralRsBuilder {
275 pub fn new(
279 pipeline: Arc<tokio::sync::Mutex<dyn Pipeline>>,
280 method: SchedulerConfig,
281 throughput_logging: bool,
282 search_embedding_model: Option<BertEmbeddingModel>,
283 ) -> Self {
284 Self {
285 pipeline,
286 method,
287 log: None,
288 truncate_sequence: None,
289 no_kv_cache: None,
290 no_prefix_cache: None,
291 prefix_cache_n: None,
292 disable_eos_stop: None,
293 throughput_logging_enabled: throughput_logging,
294 search_embedding_model,
295 search_callback: None,
296 tool_callbacks: HashMap::new(),
297 tool_callbacks_with_tools: HashMap::new(),
298 mcp_client_config: None,
299 }
300 }
301 pub fn with_log(mut self, log: String) -> Self {
302 self.log = Some(log);
303 self
304 }
305 pub fn with_opt_log(mut self, log: Option<String>) -> Self {
306 self.log = log;
307 self
308 }
309 pub fn with_truncate_sequence(mut self, truncate_sequence: bool) -> Self {
310 self.truncate_sequence = Some(truncate_sequence);
311 self
312 }
313 pub fn with_no_kv_cache(mut self, no_kv_cache: bool) -> Self {
314 self.no_kv_cache = Some(no_kv_cache);
315 self
316 }
317 pub fn with_no_prefix_cache(mut self, no_prefix_cache: bool) -> Self {
318 self.no_prefix_cache = Some(no_prefix_cache);
319 self
320 }
321 pub fn with_prefix_cache_n(mut self, prefix_cache_n: usize) -> Self {
322 self.prefix_cache_n = Some(prefix_cache_n);
323 self
324 }
325 pub fn with_disable_eos_stop(mut self, disable_eos_stop: bool) -> Self {
326 self.disable_eos_stop = Some(disable_eos_stop);
327 self
328 }
329
330 pub fn with_search_callback(mut self, search_callback: Arc<SearchCallback>) -> Self {
332 self.search_callback = Some(search_callback);
333 self
334 }
335
336 pub fn with_tool_callback(
338 mut self,
339 name: impl Into<String>,
340 tool_callback: Arc<ToolCallback>,
341 ) -> Self {
342 self.tool_callbacks.insert(name.into(), tool_callback);
343 self
344 }
345
346 pub fn with_tool_callback_and_tool(
349 mut self,
350 name: impl Into<String>,
351 tool_callback: Arc<ToolCallback>,
352 tool: Tool,
353 ) -> Self {
354 let name = name.into();
355 self.tool_callbacks_with_tools.insert(
356 name,
357 ToolCallbackWithTool {
358 callback: tool_callback,
359 tool,
360 },
361 );
362 self
363 }
364
365 pub fn with_mcp_client(mut self, config: McpClientConfig) -> Self {
367 self.mcp_client_config = Some(config);
368 self
369 }
370
371 pub async fn build(self) -> Arc<MistralRs> {
372 MistralRs::new(self).await
373 }
374}
375
376impl Drop for MistralRs {
377 fn drop(&mut self) {
378 if let Ok(engines) = self.engines.read() {
380 for (_, engine) in engines.iter() {
381 let _ = engine.sender.try_send(Request::Terminate);
383 }
384 }
385 }
386}
387
388impl MistralRs {
389 fn create_engine_instance(
391 pipeline: Arc<tokio::sync::Mutex<dyn Pipeline>>,
392 method: SchedulerConfig,
393 config: EngineConfig,
394 reboot_state: RebootState,
395 ) -> Result<EngineInstance, String> {
396 let (tx, rx) = channel(10_000);
397
398 let category = pipeline.try_lock().unwrap().category();
399 let kind = pipeline.try_lock().unwrap().get_metadata().kind.clone();
400 let device = pipeline.try_lock().unwrap().device();
401 let modalities = pipeline
402 .try_lock()
403 .unwrap()
404 .get_metadata()
405 .modalities
406 .clone();
407
408 info!("Pipeline input modalities are {:?}", &modalities.input);
409 info!("Pipeline output modalities are {:?}", &modalities.output);
410
411 let mistralrs_config = MistralRsConfig {
412 kind,
413 device,
414 category: category.clone(),
415 modalities,
416 };
417
418 let engine_handler = thread::spawn(move || {
419 #[cfg(feature = "metal")]
420 objc::rc::autoreleasepool(move || {
421 let rt = Runtime::new().unwrap();
422 rt.block_on(async move {
423 let engine = Engine::new(
424 rx,
425 pipeline,
426 method,
427 config.truncate_sequence,
428 config.no_kv_cache,
429 config.no_prefix_cache,
430 config.prefix_cache_n,
431 config.disable_eos_stop,
432 config.throughput_logging_enabled,
433 config.search_embedding_model,
434 config.search_callback.clone(),
435 config.tool_callbacks.clone(),
436 config.tool_callbacks_with_tools.clone(),
437 )
438 .expect("Engine creation failed.");
439 Arc::new(engine).run().await;
440 })
441 });
442
443 #[cfg(not(feature = "metal"))]
444 {
445 let rt = Runtime::new().unwrap();
446 rt.block_on(async move {
447 let engine = Engine::new(
448 rx,
449 pipeline,
450 method,
451 config.truncate_sequence,
452 config.no_kv_cache,
453 config.no_prefix_cache,
454 config.prefix_cache_n,
455 config.disable_eos_stop,
456 config.throughput_logging_enabled,
457 config.search_embedding_model,
458 config.search_callback.clone(),
459 config.tool_callbacks.clone(),
460 config.tool_callbacks_with_tools.clone(),
461 )
462 .expect("Engine creation failed.");
463 Arc::new(engine).run().await;
464 })
465 }
466 });
467
468 Ok(EngineInstance {
469 sender: tx,
470 engine_handler,
471 reboot_state,
472 config: mistralrs_config,
473 category,
474 })
475 }
476
477 async fn new(config: MistralRsBuilder) -> Arc<Self> {
478 let MistralRsBuilder {
479 pipeline,
480 method,
481 log,
482 truncate_sequence,
483 no_kv_cache,
484 no_prefix_cache,
485 prefix_cache_n,
486 disable_eos_stop,
487 throughput_logging_enabled,
488 search_embedding_model,
489 search_callback,
490 tool_callbacks,
491 mut tool_callbacks_with_tools,
492 mcp_client_config,
493 } = config;
494
495 mistralrs_quant::cublaslt::maybe_init_cublas_lt_wrapper(
496 get_mut_arcmutex!(pipeline).device(),
497 );
498
499 let truncate_sequence = truncate_sequence.unwrap_or(false);
500 let no_kv_cache = no_kv_cache.unwrap_or(false);
501 let no_prefix_cache = no_prefix_cache.unwrap_or(false);
502 let prefix_cache_n = prefix_cache_n.unwrap_or(16);
503 let disable_eos_stop = disable_eos_stop.unwrap_or(false);
504
505 if let Some(config) = &mcp_client_config {
507 let mut mcp_client = McpClient::new(config.clone());
508 let total_servers = config.servers.len();
509
510 match mcp_client.initialize().await {
511 Ok(()) => {
512 let mcp_callbacks_with_tools = mcp_client.get_tool_callbacks_with_tools();
513 let tools_count = mcp_callbacks_with_tools.len();
514
515 for (name, callback_with_tool) in mcp_callbacks_with_tools {
517 tool_callbacks_with_tools.insert(name.clone(), callback_with_tool.clone());
518 }
519
520 if tools_count == 0 {
521 warn!(
522 "MCP client initialized but no tools were registered from {} servers",
523 total_servers
524 );
525 } else {
526 info!(
527 "MCP client initialized successfully with {} tools from {} servers",
528 tools_count, total_servers
529 );
530 }
531 }
532 Err(e) => {
533 warn!(
534 "Failed to initialize MCP client with {} configured servers: {}",
535 total_servers, e
536 );
537 warn!("Continuing without MCP functionality. Check your MCP configuration and server availability.");
538 }
539 }
540 }
541
542 let reboot_state = RebootState {
543 pipeline: pipeline.clone(),
544 method: method.clone(),
545 truncate_sequence,
546 no_kv_cache,
547 no_prefix_cache,
548 prefix_cache_n,
549 disable_eos_stop,
550 throughput_logging_enabled,
551 search_embedding_model: search_embedding_model.clone(),
552 search_callback: search_callback.clone(),
553 tool_callbacks: tool_callbacks.clone(),
554 tool_callbacks_with_tools: tool_callbacks_with_tools.clone(),
555 mcp_client_config: mcp_client_config.clone(),
556 };
557
558 let engine_config = EngineConfig {
560 truncate_sequence,
561 no_kv_cache,
562 no_prefix_cache,
563 prefix_cache_n,
564 disable_eos_stop,
565 throughput_logging_enabled,
566 search_embedding_model,
567 search_callback,
568 tool_callbacks,
569 tool_callbacks_with_tools,
570 };
571
572 let engine_instance =
574 Self::create_engine_instance(pipeline.clone(), method, engine_config, reboot_state)
575 .expect("Failed to create engine instance");
576
577 let id = pipeline.try_lock().unwrap().name();
578
579 if distributed::is_daemon() {
580 let request_sender = engine_instance.sender.clone();
581
582 if cfg!(feature = "ring") {
583 distributed::ring_daemon_replicator(request_sender);
585 } else {
586 distributed::nccl_daemon_replicator(request_sender);
588 }
589
590 #[allow(clippy::empty_loop)]
591 loop {}
592 }
593
594 let is_multi_threaded = tokio::runtime::Handle::try_current()
596 .is_ok_and(|h| h.runtime_flavor() != tokio::runtime::RuntimeFlavor::CurrentThread);
597
598 if !distributed::is_daemon()
600 && is_multi_threaded
601 && matches!(
602 engine_instance.category,
603 ModelCategory::Text | ModelCategory::Vision { .. }
604 )
605 {
606 let clone_sender = engine_instance.sender.clone();
607 tokio::task::block_in_place(|| {
608 let (tx, mut rx) = channel(1);
609 let req = Request::Normal(Box::new(NormalRequest {
610 id: 0,
611 messages: RequestMessage::Completion {
612 text: "hello".to_string(),
613 echo_prompt: false,
614 best_of: None,
615 },
616 sampling_params: SamplingParams {
617 max_len: Some(1),
618 ..SamplingParams::deterministic()
619 },
620 response: tx,
621 return_logprobs: false,
622 is_streaming: false,
623 constraint: Constraint::None,
624 suffix: None,
625 tool_choice: None,
626 tools: None,
627 logits_processors: None,
628 return_raw_logits: false,
629 web_search_options: None,
630 model_id: None,
631 }));
632 info!("Beginning dummy run.");
633 let start = Instant::now();
634 clone_sender.blocking_send(req).unwrap();
635
636 if let Some(_resp) = rx.blocking_recv() {
637 let end = Instant::now();
638 info!(
639 "Dummy run completed in {}s.",
640 end.duration_since(start).as_secs_f64()
641 );
642 } else {
643 warn!("Dummy run failed!");
644 }
645 });
646 }
647
648 let mut engines = HashMap::new();
650 engines.insert(id.clone(), engine_instance);
651
652 Arc::new(Self {
653 engines: RwLock::new(engines),
654 default_engine_id: RwLock::new(Some(id.clone())),
655 log,
656 id,
657 creation_time: SystemTime::now()
658 .duration_since(UNIX_EPOCH)
659 .expect("Time travel has occurred!")
660 .as_secs(),
661 next_request_id: Mutex::new(RefCell::new(1)),
662 })
663 }
664
665 fn reboot_engine(&self, model_id: &str) -> Result<(), MistralRsError> {
667 let mut engines = self.engines.write().map_err(|_| {
668 tracing::warn!("Couldn't get write lock on engines during reboot attempt");
669 MistralRsError::EnginePoisoned
670 })?;
671
672 if let Some(engine_instance) = engines.get(model_id) {
673 if !engine_instance.engine_handler.is_finished() {
674 tracing::info!("Engine {} already running, returning ok", model_id);
675 return Ok(());
676 }
677
678 let reboot_state = engine_instance.reboot_state.clone();
679 let engine_config = EngineConfig {
680 truncate_sequence: reboot_state.truncate_sequence,
681 no_kv_cache: reboot_state.no_kv_cache,
682 no_prefix_cache: reboot_state.no_prefix_cache,
683 prefix_cache_n: reboot_state.prefix_cache_n,
684 disable_eos_stop: reboot_state.disable_eos_stop,
685 throughput_logging_enabled: reboot_state.throughput_logging_enabled,
686 search_embedding_model: reboot_state.search_embedding_model.clone(),
687 search_callback: reboot_state.search_callback.clone(),
688 tool_callbacks: reboot_state.tool_callbacks.clone(),
689 tool_callbacks_with_tools: reboot_state.tool_callbacks_with_tools.clone(),
690 };
691 let new_engine_instance = Self::create_engine_instance(
692 reboot_state.pipeline.clone(),
693 reboot_state.method.clone(),
694 engine_config,
695 reboot_state,
696 )
697 .map_err(|e| {
698 tracing::error!("Failed to create new engine instance: {}", e);
699 MistralRsError::EnginePoisoned
700 })?;
701
702 engines.insert(model_id.to_string(), new_engine_instance);
703 tracing::info!("Successfully rebooted engine {}", model_id);
704 Ok(())
705 } else {
706 Err(MistralRsError::EnginePoisoned)
707 }
708 }
709
710 fn engine_dead(&self, model_id: &str) -> Result<bool, MistralRsError> {
711 let engines = self.engines.read().map_err(|_| {
712 tracing::warn!("Couldn't get read lock on engines!");
713 MistralRsError::EnginePoisoned
714 })?;
715
716 if let Some(engine_instance) = engines.get(model_id) {
717 Ok(engine_instance.engine_handler.is_finished())
718 } else {
719 Err(MistralRsError::EnginePoisoned)
720 }
721 }
722
723 pub fn get_sender(&self, model_id: Option<&str>) -> Result<Sender<Request>, MistralRsError> {
725 let resolved_model_id = match model_id {
726 Some(id) => id.to_string(),
727 None => {
728 let default_lock = self
729 .default_engine_id
730 .read()
731 .map_err(|_| MistralRsError::SenderPoisoned)?;
732 default_lock
733 .as_ref()
734 .ok_or(MistralRsError::EnginePoisoned)?
735 .clone()
736 }
737 };
738
739 if self.engine_dead(&resolved_model_id)? {
740 tracing::warn!("Engine {} is dead, rebooting", resolved_model_id);
741 self.reboot_engine(&resolved_model_id)?
742 }
743
744 let engines = self
745 .engines
746 .read()
747 .map_err(|_| MistralRsError::SenderPoisoned)?;
748 if let Some(engine_instance) = engines.get(&resolved_model_id) {
749 Ok(engine_instance.sender.clone())
750 } else {
751 Err(MistralRsError::EnginePoisoned)
752 }
753 }
754
755 pub fn get_id(&self) -> String {
756 self.id.clone()
757 }
758
759 pub fn get_creation_time(&self) -> u64 {
760 self.creation_time
761 }
762
763 pub fn get_model_category(
765 &self,
766 model_id: Option<&str>,
767 ) -> Result<ModelCategory, MistralRsError> {
768 let resolved_model_id = match model_id {
769 Some(id) => id.to_string(),
770 None => {
771 let default_lock = self
772 .default_engine_id
773 .read()
774 .map_err(|_| MistralRsError::SenderPoisoned)?;
775 default_lock
776 .as_ref()
777 .ok_or(MistralRsError::EnginePoisoned)?
778 .clone()
779 }
780 };
781
782 let engines = self
783 .engines
784 .read()
785 .map_err(|_| MistralRsError::SenderPoisoned)?;
786 if let Some(engine_instance) = engines.get(&resolved_model_id) {
787 Ok(engine_instance.category.clone())
788 } else {
789 Err(MistralRsError::EnginePoisoned)
790 }
791 }
792
793 pub fn next_request_id(&self) -> usize {
794 let l = self.next_request_id.lock().unwrap();
795 let last = &mut *l.borrow_mut();
796 let last_v = *last;
797 *last += 1;
798 last_v
799 }
800
801 pub async fn add_model(
803 &self,
804 model_id: String,
805 pipeline: Arc<tokio::sync::Mutex<dyn Pipeline>>,
806 method: SchedulerConfig,
807 config: AddModelConfig,
808 ) -> Result<(), String> {
809 let reboot_state = RebootState {
810 pipeline: pipeline.clone(),
811 method: method.clone(),
812 truncate_sequence: config.engine_config.truncate_sequence,
813 no_kv_cache: config.engine_config.no_kv_cache,
814 no_prefix_cache: config.engine_config.no_prefix_cache,
815 prefix_cache_n: config.engine_config.prefix_cache_n,
816 disable_eos_stop: config.engine_config.disable_eos_stop,
817 throughput_logging_enabled: config.engine_config.throughput_logging_enabled,
818 search_embedding_model: config.engine_config.search_embedding_model.clone(),
819 search_callback: config.engine_config.search_callback.clone(),
820 tool_callbacks: config.engine_config.tool_callbacks.clone(),
821 tool_callbacks_with_tools: config.engine_config.tool_callbacks_with_tools.clone(),
822 mcp_client_config: config.mcp_client_config.clone(),
823 };
824
825 let engine_instance =
826 Self::create_engine_instance(pipeline, method, config.engine_config, reboot_state)?;
827
828 let mut engines = self
829 .engines
830 .write()
831 .map_err(|_| "Failed to acquire write lock on engines")?;
832 engines.insert(model_id.clone(), engine_instance);
833
834 if engines.len() == 1 {
836 let mut default_lock = self
837 .default_engine_id
838 .write()
839 .map_err(|_| "Failed to acquire write lock on default_engine_id")?;
840 *default_lock = Some(model_id.clone());
841 }
842
843 Ok(())
844 }
845
846 pub fn remove_model(&self, model_id: &str) -> Result<(), String> {
848 let mut engines = self
849 .engines
850 .write()
851 .map_err(|_| "Failed to acquire write lock on engines")?;
852
853 if engines.len() <= 1 {
854 return Err("Cannot remove the last model from MistralRs".to_string());
855 }
856
857 if let Some(engine_instance) = engines.remove(model_id) {
858 let _ = engine_instance.sender.blocking_send(Request::Terminate);
860
861 let mut default_lock = self
863 .default_engine_id
864 .write()
865 .map_err(|_| "Failed to acquire write lock on default_engine_id")?;
866 if let Some(ref default_id) = *default_lock {
867 if default_id == model_id {
868 *default_lock = engines.keys().next().cloned();
870 }
871 }
872
873 Ok(())
874 } else {
875 Err(format!("Model {model_id} not found"))
876 }
877 }
878
879 pub fn list_models(&self) -> Result<Vec<String>, String> {
881 let engines = self
882 .engines
883 .read()
884 .map_err(|_| "Failed to acquire read lock on engines")?;
885 Ok(engines.keys().cloned().collect())
886 }
887
888 pub fn get_default_model_id(&self) -> Result<Option<String>, String> {
890 let default_lock = self
891 .default_engine_id
892 .read()
893 .map_err(|_| "Failed to acquire read lock on default_engine_id")?;
894 Ok(default_lock.clone())
895 }
896
897 pub fn set_default_model_id(&self, model_id: &str) -> Result<(), String> {
899 let engines = self
900 .engines
901 .read()
902 .map_err(|_| "Failed to acquire read lock on engines")?;
903 if !engines.contains_key(model_id) {
904 return Err(format!("Model {model_id} not found"));
905 }
906 drop(engines);
907
908 let mut default_lock = self
909 .default_engine_id
910 .write()
911 .map_err(|_| "Failed to acquire write lock on default_engine_id")?;
912 *default_lock = Some(model_id.to_string());
913
914 Ok(())
915 }
916
917 pub fn send_request(&self, mut request: Request) -> Result<(), MistralRsError> {
919 let model_id = match &mut request {
920 Request::Normal(normal_req) => normal_req.model_id.as_deref(),
921 _ => None, };
923
924 let sender = self.get_sender(model_id)?;
925 sender
926 .blocking_send(request)
927 .map_err(|_| MistralRsError::SenderPoisoned)
928 }
929
930 pub fn maybe_log_request(this: Arc<Self>, repr: String) {
931 if let Some(file) = &this.log {
932 let mut f = OpenOptions::new()
933 .append(true)
934 .create(true) .open(file)
936 .expect("Unable to open file");
937 let time = chrono::offset::Local::now();
938 f.write_all(format!("Request at {time}: {repr}\n\n").as_bytes())
939 .expect("Unable to write data");
940 }
941 }
942
943 pub fn maybe_log_response<T: Serialize>(this: Arc<Self>, resp: &T) {
944 if let Some(file) = &this.log {
945 let mut f = OpenOptions::new()
946 .append(true)
947 .create(true) .open(file)
949 .expect("Unable to open file");
950 let time = chrono::offset::Local::now();
951 let repr = serde_json::to_string(resp).expect("Serialization of response failed.");
952 f.write_all(format!("Response at {time}: {repr}\n\n").as_bytes())
953 .expect("Unable to write data");
954 }
955 }
956
957 pub fn maybe_log_error(this: Arc<Self>, err: &dyn Error) {
958 if let Some(file) = &this.log {
959 let mut f = OpenOptions::new()
960 .append(true)
961 .create(true) .open(file)
963 .expect("Unable to open file");
964 let time = chrono::offset::Local::now();
965 f.write_all(format!("Error response at {time}: {err}\n\n").as_bytes())
966 .expect("Unable to write data");
967 }
968 }
969
970 pub fn get_tools_count(&self, model_id: Option<&str>) -> Result<usize, String> {
972 let resolved_model_id = match model_id {
973 Some(id) => id.to_string(),
974 None => {
975 let default_lock = self
976 .default_engine_id
977 .read()
978 .map_err(|_| "Failed to acquire read lock")?;
979 default_lock
980 .as_ref()
981 .ok_or("No default engine set")?
982 .clone()
983 }
984 };
985
986 let engines = self
987 .engines
988 .read()
989 .map_err(|_| "Failed to acquire read lock on engines")?;
990 if let Some(engine_instance) = engines.get(&resolved_model_id) {
991 Ok(engine_instance.reboot_state.tool_callbacks_with_tools.len())
992 } else {
993 Err(format!("Model {resolved_model_id} not found"))
994 }
995 }
996
997 pub fn has_mcp_client(&self, model_id: Option<&str>) -> Result<bool, String> {
999 let resolved_model_id = match model_id {
1000 Some(id) => id.to_string(),
1001 None => {
1002 let default_lock = self
1003 .default_engine_id
1004 .read()
1005 .map_err(|_| "Failed to acquire read lock")?;
1006 default_lock
1007 .as_ref()
1008 .ok_or("No default engine set")?
1009 .clone()
1010 }
1011 };
1012
1013 let engines = self
1014 .engines
1015 .read()
1016 .map_err(|_| "Failed to acquire read lock on engines")?;
1017 if let Some(engine_instance) = engines.get(&resolved_model_id) {
1018 Ok(engine_instance.reboot_state.mcp_client_config.is_some())
1019 } else {
1020 Err(format!("Model {resolved_model_id} not found"))
1021 }
1022 }
1023
1024 pub fn config(&self, model_id: Option<&str>) -> Result<MistralRsConfig, String> {
1026 let resolved_model_id = match model_id {
1027 Some(id) => id.to_string(),
1028 None => {
1029 let default_lock = self
1030 .default_engine_id
1031 .read()
1032 .map_err(|_| "Failed to acquire read lock")?;
1033 default_lock
1034 .as_ref()
1035 .ok_or("No default engine set")?
1036 .clone()
1037 }
1038 };
1039
1040 let engines = self
1041 .engines
1042 .read()
1043 .map_err(|_| "Failed to acquire read lock on engines")?;
1044 if let Some(engine_instance) = engines.get(&resolved_model_id) {
1045 Ok(engine_instance.config.clone())
1046 } else {
1047 Err(format!("Model {resolved_model_id} not found"))
1048 }
1049 }
1050}