1#![deny(clippy::cast_possible_truncation, clippy::cast_precision_loss)]
2use candle_core::Device;
3use engine::Engine;
4pub use engine::{
5 get_engine_terminate_flag, reset_engine_terminate_flag, should_terminate_engine_sequences,
6 EngineInstruction, SearchEmbeddingModel, ENGINE_INSTRUCTIONS, TERMINATE_ALL_NEXT_STEP,
7};
8use hf_hub::Cache;
9pub use lora::Ordering;
10pub use pipeline::ModelCategory;
11pub use pipeline::Pipeline;
12#[cfg(feature = "pyo3_macros")]
13use pyo3::exceptions::PyValueError;
14use std::collections::HashMap;
15use std::sync::OnceLock;
16use std::time::Instant;
17use std::{
18 cell::RefCell,
19 error::Error,
20 fs::OpenOptions,
21 io::Write,
22 sync::{atomic::AtomicBool, Arc, Mutex, RwLock},
23 thread::{self, JoinHandle},
24 time::{SystemTime, UNIX_EPOCH},
25};
26use tokio::sync::mpsc::{channel, Sender};
27use tracing::info;
28use tracing::warn;
29
30mod cuda;
31mod device_map;
32mod engine;
33mod lora;
34mod model_loader;
35mod ops;
36pub use model_loader::{
37 get_auto_device_map_params, get_model_dtype, get_tgt_non_granular_index, LoaderBuilder,
38};
39mod embedding_models;
40mod kv_cache;
41mod search;
42
43mod model_selected;
44pub use model_selected::ModelSelected;
45pub use toml_selector::{get_toml_selected_model_device_map_params, get_toml_selected_model_dtype};
46
47mod amoe;
48mod attention;
49mod diffusion_models;
50pub mod distributed;
51mod gguf;
52pub mod layers;
53mod layers_masker;
54mod layers_utils;
55pub mod matformer;
56mod models;
57mod paged_attention;
58mod pipeline;
59mod prefix_cacher;
60mod request;
61mod response;
62mod sampler;
63mod scheduler;
64mod sequence;
65mod speech_models;
66mod toml_selector;
67mod tools;
68mod topology;
69mod utils;
70mod vision_models;
71mod xlora_models;
72
73pub use amoe::{AnyMoeConfig, AnyMoeExpertType};
74pub use device_map::{
75 DeviceLayerMapMetadata, DeviceMapMetadata, DeviceMapSetting, LayerDeviceMapper,
76};
77pub use gguf::{GGUFArchitecture, GGUF_MULTI_FILE_DELIMITER};
78pub use mistralrs_audio::AudioInput;
79pub use mistralrs_mcp::{
80 CalledFunction, Function, Tool, ToolCallback, ToolCallbackWithTool, ToolType,
81};
82pub use mistralrs_mcp::{
83 McpClient, McpClientConfig, McpServerConfig, McpServerSource, McpToolInfo,
84};
85pub use mistralrs_quant::{IsqType, MULTI_LORA_DELIMITER};
86pub use paged_attention::{MemoryGpuConfig, PagedAttentionConfig, PagedCacheType};
87pub use pipeline::{
88 chat_template::ChatTemplate, parse_isq_value, AdapterPaths, AnyMoeLoader, AnyMoePipeline,
89 AutoDeviceMapParams, AutoLoader, AutoLoaderBuilder, DiffusionGenerationParams, DiffusionLoader,
90 DiffusionLoaderBuilder, DiffusionLoaderType, EmbeddingLoader, EmbeddingLoaderBuilder,
91 EmbeddingLoaderType, EmbeddingModelPaths, EmbeddingSpecificConfig, GGMLLoader,
92 GGMLLoaderBuilder, GGMLSpecificConfig, GGUFLoader, GGUFLoaderBuilder, GGUFSpecificConfig,
93 GemmaLoader, Idefics2Loader, IsqOrganization, LLaVALoader, LLaVANextLoader, LlamaLoader,
94 Loader, LocalModelPaths, LoraAdapterPaths, MistralLoader, MixtralLoader, Modalities, ModelKind,
95 ModelPaths, MultimodalPromptPrefixer, NormalLoader, NormalLoaderBuilder, NormalLoaderType,
96 NormalSpecificConfig, Phi2Loader, Phi3Loader, Phi3VLoader, Qwen2Loader, SpeculativeConfig,
97 SpeculativeLoader, SpeculativePipeline, SpeechLoader, SpeechPipeline, Starcoder2Loader,
98 SupportedModality, TokenSource, VisionLoader, VisionLoaderBuilder, VisionLoaderType,
99 VisionSpecificConfig, UQFF_MULTI_FILE_DELIMITER,
100};
101pub use request::{
102 ApproximateUserLocation, Constraint, DetokenizationRequest, ImageGenerationResponseFormat,
103 LlguidanceGrammar, MessageContent, NormalRequest, Request, RequestMessage, SearchContextSize,
104 TokenizationRequest, WebSearchOptions, WebSearchUserLocation,
105};
106pub use response::*;
107pub use sampler::{
108 CustomLogitsProcessor, DrySamplingParams, SamplingParams, StopTokens, TopLogprob,
109};
110pub use scheduler::{DefaultSchedulerMethod, SchedulerConfig};
111pub use search::{SearchCallback, SearchFunctionParameters, SearchResult};
112use serde::Serialize;
113pub use speech_models::{utils as speech_utils, SpeechGenerationConfig, SpeechLoaderType};
114use tokio::runtime::Runtime;
115use toml_selector::{TomlLoaderArgs, TomlSelector};
116pub use tools::{ToolCallResponse, ToolCallType, ToolCallbacks, ToolChoice};
117pub use topology::{LayerTopology, Topology};
118pub use utils::debug::initialize_logging;
119pub use utils::memory_usage::MemoryUsage;
120pub use utils::normal::{ModelDType, TryIntoDType};
121pub use utils::{paged_attn_supported, using_flash_attn};
122
123pub use llguidance;
125
/// Crate-wide debug flag, initially `false`.
// NOTE(review): where this is set and read lives outside this file — confirm its exact meaning.
pub(crate) static DEBUG: AtomicBool = AtomicBool::new(false);
/// Process-wide Hugging Face hub cache. Being a `OnceLock`, it can be initialized at most once.
pub static GLOBAL_HF_CACHE: OnceLock<Cache> = OnceLock::new();
129
130#[derive(Clone)]
132pub struct EngineConfig {
133 pub no_kv_cache: bool,
134 pub no_prefix_cache: bool,
135 pub prefix_cache_n: usize,
136 pub disable_eos_stop: bool,
137 pub throughput_logging_enabled: bool,
138 pub search_embedding_model: Option<SearchEmbeddingModel>,
139 pub search_callback: Option<Arc<SearchCallback>>,
140 pub tool_callbacks: tools::ToolCallbacks,
141 pub tool_callbacks_with_tools: tools::ToolCallbacksWithTools,
142}
143
144impl Default for EngineConfig {
145 fn default() -> Self {
146 Self {
147 no_kv_cache: false,
148 no_prefix_cache: false,
149 prefix_cache_n: 16,
150 disable_eos_stop: false,
151 throughput_logging_enabled: true,
152 search_embedding_model: None,
153 search_callback: None,
154 tool_callbacks: HashMap::new(),
155 tool_callbacks_with_tools: HashMap::new(),
156 }
157 }
158}
159
160#[derive(Clone)]
162pub struct AddModelConfig {
163 pub engine_config: EngineConfig,
164 pub mcp_client_config: Option<McpClientConfig>,
165}
166
167impl AddModelConfig {
168 pub fn new(engine_config: EngineConfig) -> Self {
169 Self {
170 engine_config,
171 mcp_client_config: None,
172 }
173 }
174
175 pub fn with_mcp_config(mut self, mcp_config: McpClientConfig) -> Self {
176 self.mcp_client_config = Some(mcp_config);
177 self
178 }
179}
180
/// Snapshot of a model's properties, captured from its pipeline when the engine is created.
#[derive(Clone)]
pub struct MistralRsConfig {
    pub kind: ModelKind,
    /// Device reported by the pipeline.
    pub device: Device,
    pub category: ModelCategory,
    /// Input/output modalities reported by the pipeline metadata.
    pub modalities: Modalities,
    /// `None` for diffusion and speech models; otherwise the pipeline metadata's `max_seq_len`.
    pub max_seq_len: Option<usize>,
}
189
/// A running engine for one model, together with what is needed to restart it.
struct EngineInstance {
    /// Request channel into the engine's run loop.
    sender: Sender<Request>,
    /// Handle to the OS thread running the engine; `is_finished()` is how a dead engine is detected.
    engine_handler: JoinHandle<()>,
    /// Settings used to recreate the engine after its thread exits.
    reboot_state: RebootState,
    /// Model properties captured at creation time.
    config: MistralRsConfig,
    /// Copy of `config.category`, kept for quick access.
    category: ModelCategory,
}
198
/// Multi-model front end: routes requests to per-model engines, each running on its own thread.
pub struct MistralRs {
    /// Engines keyed by model id.
    engines: RwLock<HashMap<String, EngineInstance>>,
    /// Model used when a request or query does not name one explicitly.
    default_engine_id: RwLock<Option<String>>,
    /// Optional path of a file to which requests, responses and errors are appended.
    log: Option<String>,
    /// Name of the pipeline this instance was originally built with.
    id: String,
    /// Construction time, in whole seconds since the Unix epoch.
    creation_time: u64,
    /// Counter handed out by `next_request_id`; starts at 1 and increments by 1.
    next_request_id: Mutex<RefCell<usize>>,
}
211
/// Everything needed to recreate an engine with identical settings after its thread exits.
#[derive(Clone)]
struct RebootState {
    /// Shared pipeline; the rebooted engine reuses it.
    pipeline: Arc<tokio::sync::Mutex<dyn Pipeline>>,
    method: SchedulerConfig,
    no_kv_cache: bool,
    no_prefix_cache: bool,
    prefix_cache_n: usize,
    disable_eos_stop: bool,
    throughput_logging_enabled: bool,
    search_embedding_model: Option<SearchEmbeddingModel>,
    search_callback: Option<Arc<search::SearchCallback>>,
    tool_callbacks: tools::ToolCallbacks,
    /// Also the source for `MistralRs::get_tools_count`.
    tool_callbacks_with_tools: tools::ToolCallbacksWithTools,
    /// Reported by `MistralRs::has_mcp_client`; reboot does not re-initialize an MCP client from it.
    mcp_client_config: Option<McpClientConfig>,
}
227
/// Errors returned by `MistralRs` when a request or query cannot be routed to an engine.
///
/// `Display` prints the variant name, exactly like `Debug`.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum MistralRsError {
    /// The target engine is unknown, no default model is set, a reboot failed, or the
    /// `engines` lock was poisoned while checking or rebooting an engine.
    EnginePoisoned,
    /// A lock was poisoned while looking up a sender, or sending a request to the engine failed.
    SenderPoisoned,
}

impl std::fmt::Display for MistralRsError {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let name = match self {
            Self::EnginePoisoned => "EnginePoisoned",
            Self::SenderPoisoned => "SenderPoisoned",
        };
        f.write_str(name)
    }
}

impl std::error::Error for MistralRsError {}
241
#[cfg(feature = "pyo3_macros")]
impl From<MistralRsError> for pyo3::PyErr {
    /// Surface a [`MistralRsError`] to Python as a `ValueError` whose message is the variant name.
    fn from(value: MistralRsError) -> Self {
        let message = format!("{:?}", value);
        PyValueError::new_err(message)
    }
}
248
249pub struct MistralRsBuilder {
253 pipeline: Arc<tokio::sync::Mutex<dyn Pipeline>>,
254 method: SchedulerConfig,
255 log: Option<String>,
256 no_kv_cache: Option<bool>,
257 no_prefix_cache: Option<bool>,
258 prefix_cache_n: Option<usize>,
259 disable_eos_stop: Option<bool>,
260 throughput_logging_enabled: bool,
261 search_embedding_model: Option<SearchEmbeddingModel>,
262 search_callback: Option<Arc<SearchCallback>>,
263 tool_callbacks: tools::ToolCallbacks,
264 tool_callbacks_with_tools: tools::ToolCallbacksWithTools,
265 mcp_client_config: Option<McpClientConfig>,
266}
267
268impl MistralRsBuilder {
269 pub fn new(
273 pipeline: Arc<tokio::sync::Mutex<dyn Pipeline>>,
274 method: SchedulerConfig,
275 throughput_logging: bool,
276 search_embedding_model: Option<SearchEmbeddingModel>,
277 ) -> Self {
278 Self {
279 pipeline,
280 method,
281 log: None,
282 no_kv_cache: None,
283 no_prefix_cache: None,
284 prefix_cache_n: None,
285 disable_eos_stop: None,
286 throughput_logging_enabled: throughput_logging,
287 search_embedding_model,
288 search_callback: None,
289 tool_callbacks: HashMap::new(),
290 tool_callbacks_with_tools: HashMap::new(),
291 mcp_client_config: None,
292 }
293 }
294 pub fn with_log(mut self, log: String) -> Self {
295 self.log = Some(log);
296 self
297 }
298 pub fn with_opt_log(mut self, log: Option<String>) -> Self {
299 self.log = log;
300 self
301 }
302 pub fn with_no_kv_cache(mut self, no_kv_cache: bool) -> Self {
303 self.no_kv_cache = Some(no_kv_cache);
304 self
305 }
306 pub fn with_no_prefix_cache(mut self, no_prefix_cache: bool) -> Self {
307 self.no_prefix_cache = Some(no_prefix_cache);
308 self
309 }
310 pub fn with_prefix_cache_n(mut self, prefix_cache_n: usize) -> Self {
311 self.prefix_cache_n = Some(prefix_cache_n);
312 self
313 }
314 pub fn with_disable_eos_stop(mut self, disable_eos_stop: bool) -> Self {
315 self.disable_eos_stop = Some(disable_eos_stop);
316 self
317 }
318
319 pub fn with_search_callback(mut self, search_callback: Arc<SearchCallback>) -> Self {
321 self.search_callback = Some(search_callback);
322 self
323 }
324
325 pub fn with_tool_callback(
327 mut self,
328 name: impl Into<String>,
329 tool_callback: Arc<ToolCallback>,
330 ) -> Self {
331 self.tool_callbacks.insert(name.into(), tool_callback);
332 self
333 }
334
335 pub fn with_tool_callback_and_tool(
338 mut self,
339 name: impl Into<String>,
340 tool_callback: Arc<ToolCallback>,
341 tool: Tool,
342 ) -> Self {
343 let name = name.into();
344 self.tool_callbacks_with_tools.insert(
345 name,
346 ToolCallbackWithTool {
347 callback: tool_callback,
348 tool,
349 },
350 );
351 self
352 }
353
354 pub fn with_mcp_client(mut self, config: McpClientConfig) -> Self {
356 self.mcp_client_config = Some(config);
357 self
358 }
359
360 pub async fn build(self) -> Arc<MistralRs> {
361 MistralRs::new(self).await
362 }
363}
364
365impl Drop for MistralRs {
366 fn drop(&mut self) {
367 if let Ok(engines) = self.engines.read() {
369 for (_, engine) in engines.iter() {
370 let _ = engine.sender.try_send(Request::Terminate);
372 }
373 }
374 }
375}
376
377impl MistralRs {
378 fn create_engine_instance(
380 pipeline: Arc<tokio::sync::Mutex<dyn Pipeline>>,
381 method: SchedulerConfig,
382 config: EngineConfig,
383 reboot_state: RebootState,
384 ) -> Result<EngineInstance, String> {
385 let (tx, rx) = channel(10_000);
386
387 let pipeline_guard = pipeline.try_lock().unwrap();
388 let category = pipeline_guard.category();
389 let metadata = pipeline_guard.get_metadata();
390 let kind = metadata.kind.clone();
391 let device = pipeline_guard.device();
392 let modalities = metadata.modalities.clone();
393 let max_seq_len = match &category {
394 ModelCategory::Diffusion | ModelCategory::Speech => None,
395 _ => Some(metadata.max_seq_len),
396 };
397 drop(pipeline_guard);
398
399 info!("Pipeline input modalities are {:?}", &modalities.input);
400 info!("Pipeline output modalities are {:?}", &modalities.output);
401
402 let mistralrs_config = MistralRsConfig {
403 kind,
404 device,
405 category: category.clone(),
406 modalities,
407 max_seq_len,
408 };
409
410 let engine_handler = thread::spawn(move || {
411 #[cfg(feature = "metal")]
412 objc::rc::autoreleasepool(move || {
413 let rt = Runtime::new().unwrap();
414 rt.block_on(async move {
415 let engine = Engine::new(
416 rx,
417 pipeline,
418 method,
419 config.no_kv_cache,
420 config.no_prefix_cache,
421 config.prefix_cache_n,
422 config.disable_eos_stop,
423 config.throughput_logging_enabled,
424 config.search_embedding_model,
425 config.search_callback.clone(),
426 config.tool_callbacks.clone(),
427 config.tool_callbacks_with_tools.clone(),
428 )
429 .expect("Engine creation failed.");
430 Arc::new(engine).run().await;
431 })
432 });
433
434 #[cfg(not(feature = "metal"))]
435 {
436 let rt = Runtime::new().unwrap();
437 rt.block_on(async move {
438 let engine = Engine::new(
439 rx,
440 pipeline,
441 method,
442 config.no_kv_cache,
443 config.no_prefix_cache,
444 config.prefix_cache_n,
445 config.disable_eos_stop,
446 config.throughput_logging_enabled,
447 config.search_embedding_model,
448 config.search_callback.clone(),
449 config.tool_callbacks.clone(),
450 config.tool_callbacks_with_tools.clone(),
451 )
452 .expect("Engine creation failed.");
453 Arc::new(engine).run().await;
454 })
455 }
456 });
457
458 Ok(EngineInstance {
459 sender: tx,
460 engine_handler,
461 reboot_state,
462 config: mistralrs_config,
463 category,
464 })
465 }
466
467 async fn new(config: MistralRsBuilder) -> Arc<Self> {
468 let MistralRsBuilder {
469 pipeline,
470 method,
471 log,
472 no_kv_cache,
473 no_prefix_cache,
474 prefix_cache_n,
475 disable_eos_stop,
476 throughput_logging_enabled,
477 search_embedding_model,
478 search_callback,
479 tool_callbacks,
480 mut tool_callbacks_with_tools,
481 mcp_client_config,
482 } = config;
483
484 mistralrs_quant::cublaslt::maybe_init_cublas_lt_wrapper(
485 get_mut_arcmutex!(pipeline).device(),
486 );
487
488 let no_kv_cache = no_kv_cache.unwrap_or(false);
489 let no_prefix_cache = no_prefix_cache.unwrap_or(false);
490 let prefix_cache_n = prefix_cache_n.unwrap_or(16);
491 let disable_eos_stop = disable_eos_stop.unwrap_or(false);
492
493 if let Some(config) = &mcp_client_config {
495 let mut mcp_client = McpClient::new(config.clone());
496 let total_servers = config.servers.len();
497
498 match mcp_client.initialize().await {
499 Ok(()) => {
500 let mcp_callbacks_with_tools = mcp_client.get_tool_callbacks_with_tools();
501 let tools_count = mcp_callbacks_with_tools.len();
502
503 for (name, callback_with_tool) in mcp_callbacks_with_tools {
505 tool_callbacks_with_tools.insert(name.clone(), callback_with_tool.clone());
506 }
507
508 if tools_count == 0 {
509 warn!(
510 "MCP client initialized but no tools were registered from {} servers",
511 total_servers
512 );
513 } else {
514 info!(
515 "MCP client initialized successfully with {} tools from {} servers",
516 tools_count, total_servers
517 );
518 }
519 }
520 Err(e) => {
521 warn!(
522 "Failed to initialize MCP client with {} configured servers: {}",
523 total_servers, e
524 );
525 warn!("Continuing without MCP functionality. Check your MCP configuration and server availability.");
526 }
527 }
528 }
529
530 let reboot_state = RebootState {
531 pipeline: pipeline.clone(),
532 method: method.clone(),
533 no_kv_cache,
534 no_prefix_cache,
535 prefix_cache_n,
536 disable_eos_stop,
537 throughput_logging_enabled,
538 search_embedding_model,
539 search_callback: search_callback.clone(),
540 tool_callbacks: tool_callbacks.clone(),
541 tool_callbacks_with_tools: tool_callbacks_with_tools.clone(),
542 mcp_client_config: mcp_client_config.clone(),
543 };
544
545 let engine_config = EngineConfig {
547 no_kv_cache,
548 no_prefix_cache,
549 prefix_cache_n,
550 disable_eos_stop,
551 throughput_logging_enabled,
552 search_embedding_model,
553 search_callback,
554 tool_callbacks,
555 tool_callbacks_with_tools,
556 };
557
558 let engine_instance =
560 Self::create_engine_instance(pipeline.clone(), method, engine_config, reboot_state)
561 .expect("Failed to create engine instance");
562
563 let id = pipeline.try_lock().unwrap().name();
564
565 if distributed::is_daemon() {
566 let request_sender = engine_instance.sender.clone();
567
568 if cfg!(feature = "ring") {
569 distributed::ring_daemon_replicator(request_sender);
571 } else {
572 distributed::nccl_daemon_replicator(request_sender);
574 }
575
576 #[allow(clippy::empty_loop)]
577 loop {}
578 }
579
580 let is_multi_threaded = tokio::runtime::Handle::try_current()
582 .is_ok_and(|h| h.runtime_flavor() != tokio::runtime::RuntimeFlavor::CurrentThread);
583
584 if !distributed::is_daemon()
586 && is_multi_threaded
587 && matches!(
588 engine_instance.category,
589 ModelCategory::Text | ModelCategory::Vision { .. }
590 )
591 {
592 let clone_sender = engine_instance.sender.clone();
593 tokio::task::block_in_place(|| {
594 let (tx, mut rx) = channel(1);
595 let req = Request::Normal(Box::new(NormalRequest {
596 id: 0,
597 messages: RequestMessage::Completion {
598 text: "hello".to_string(),
599 echo_prompt: false,
600 best_of: None,
601 },
602 sampling_params: SamplingParams {
603 max_len: Some(1),
604 ..SamplingParams::deterministic()
605 },
606 response: tx,
607 return_logprobs: false,
608 is_streaming: false,
609 constraint: Constraint::None,
610 suffix: None,
611 tool_choice: None,
612 tools: None,
613 logits_processors: None,
614 return_raw_logits: false,
615 web_search_options: None,
616 model_id: None,
617 truncate_sequence: false,
618 }));
619 info!("Beginning dummy run.");
620 let start = Instant::now();
621 clone_sender.blocking_send(req).unwrap();
622
623 let mut received_any = false;
625 while let Some(_resp) = rx.blocking_recv() {
626 received_any = true;
627 }
628
629 if received_any {
630 let end = Instant::now();
631 info!(
632 "Dummy run completed in {}s.",
633 end.duration_since(start).as_secs_f64()
634 );
635 } else {
636 warn!("Dummy run failed!");
637 }
638 });
639 }
640
641 let mut engines = HashMap::new();
643 engines.insert(id.clone(), engine_instance);
644
645 Arc::new(Self {
646 engines: RwLock::new(engines),
647 default_engine_id: RwLock::new(Some(id.clone())),
648 log,
649 id,
650 creation_time: SystemTime::now()
651 .duration_since(UNIX_EPOCH)
652 .expect("Time travel has occurred!")
653 .as_secs(),
654 next_request_id: Mutex::new(RefCell::new(1)),
655 })
656 }
657
658 fn reboot_engine(&self, model_id: &str) -> Result<(), MistralRsError> {
660 let mut engines = self.engines.write().map_err(|_| {
661 tracing::warn!("Couldn't get write lock on engines during reboot attempt");
662 MistralRsError::EnginePoisoned
663 })?;
664
665 if let Some(engine_instance) = engines.get(model_id) {
666 if !engine_instance.engine_handler.is_finished() {
667 tracing::info!("Engine {} already running, returning ok", model_id);
668 return Ok(());
669 }
670
671 let reboot_state = engine_instance.reboot_state.clone();
672 let engine_config = EngineConfig {
673 no_kv_cache: reboot_state.no_kv_cache,
674 no_prefix_cache: reboot_state.no_prefix_cache,
675 prefix_cache_n: reboot_state.prefix_cache_n,
676 disable_eos_stop: reboot_state.disable_eos_stop,
677 throughput_logging_enabled: reboot_state.throughput_logging_enabled,
678 search_embedding_model: reboot_state.search_embedding_model,
679 search_callback: reboot_state.search_callback.clone(),
680 tool_callbacks: reboot_state.tool_callbacks.clone(),
681 tool_callbacks_with_tools: reboot_state.tool_callbacks_with_tools.clone(),
682 };
683 let new_engine_instance = Self::create_engine_instance(
684 reboot_state.pipeline.clone(),
685 reboot_state.method.clone(),
686 engine_config,
687 reboot_state,
688 )
689 .map_err(|e| {
690 tracing::error!("Failed to create new engine instance: {}", e);
691 MistralRsError::EnginePoisoned
692 })?;
693
694 engines.insert(model_id.to_string(), new_engine_instance);
695 tracing::info!("Successfully rebooted engine {}", model_id);
696 Ok(())
697 } else {
698 Err(MistralRsError::EnginePoisoned)
699 }
700 }
701
702 fn engine_dead(&self, model_id: &str) -> Result<bool, MistralRsError> {
703 let engines = self.engines.read().map_err(|_| {
704 tracing::warn!("Couldn't get read lock on engines!");
705 MistralRsError::EnginePoisoned
706 })?;
707
708 if let Some(engine_instance) = engines.get(model_id) {
709 Ok(engine_instance.engine_handler.is_finished())
710 } else {
711 Err(MistralRsError::EnginePoisoned)
712 }
713 }
714
715 pub fn get_sender(&self, model_id: Option<&str>) -> Result<Sender<Request>, MistralRsError> {
717 let resolved_model_id = match model_id {
718 Some(id) => id.to_string(),
719 None => {
720 let default_lock = self
721 .default_engine_id
722 .read()
723 .map_err(|_| MistralRsError::SenderPoisoned)?;
724 default_lock
725 .as_ref()
726 .ok_or(MistralRsError::EnginePoisoned)?
727 .clone()
728 }
729 };
730
731 if self.engine_dead(&resolved_model_id)? {
732 tracing::warn!("Engine {} is dead, rebooting", resolved_model_id);
733 self.reboot_engine(&resolved_model_id)?
734 }
735
736 let engines = self
737 .engines
738 .read()
739 .map_err(|_| MistralRsError::SenderPoisoned)?;
740 if let Some(engine_instance) = engines.get(&resolved_model_id) {
741 Ok(engine_instance.sender.clone())
742 } else {
743 Err(MistralRsError::EnginePoisoned)
744 }
745 }
746
747 pub fn get_id(&self) -> String {
748 self.id.clone()
749 }
750
751 pub fn get_creation_time(&self) -> u64 {
752 self.creation_time
753 }
754
755 pub fn get_model_category(
757 &self,
758 model_id: Option<&str>,
759 ) -> Result<ModelCategory, MistralRsError> {
760 let resolved_model_id = match model_id {
761 Some(id) => id.to_string(),
762 None => {
763 let default_lock = self
764 .default_engine_id
765 .read()
766 .map_err(|_| MistralRsError::SenderPoisoned)?;
767 default_lock
768 .as_ref()
769 .ok_or(MistralRsError::EnginePoisoned)?
770 .clone()
771 }
772 };
773
774 let engines = self
775 .engines
776 .read()
777 .map_err(|_| MistralRsError::SenderPoisoned)?;
778 if let Some(engine_instance) = engines.get(&resolved_model_id) {
779 Ok(engine_instance.category.clone())
780 } else {
781 Err(MistralRsError::EnginePoisoned)
782 }
783 }
784
785 pub fn max_sequence_length(
787 &self,
788 model_id: Option<&str>,
789 ) -> Result<Option<usize>, MistralRsError> {
790 let resolved_model_id = match model_id {
791 Some(id) => id.to_string(),
792 None => {
793 let default_lock = self
794 .default_engine_id
795 .read()
796 .map_err(|_| MistralRsError::SenderPoisoned)?;
797 default_lock
798 .as_ref()
799 .ok_or(MistralRsError::EnginePoisoned)?
800 .clone()
801 }
802 };
803
804 let engines = self
805 .engines
806 .read()
807 .map_err(|_| MistralRsError::SenderPoisoned)?;
808 if let Some(engine_instance) = engines.get(&resolved_model_id) {
809 Ok(engine_instance.config.max_seq_len)
810 } else {
811 Err(MistralRsError::EnginePoisoned)
812 }
813 }
814
815 pub fn next_request_id(&self) -> usize {
816 let l = self.next_request_id.lock().unwrap();
817 let last = &mut *l.borrow_mut();
818 let last_v = *last;
819 *last += 1;
820 last_v
821 }
822
823 pub async fn add_model(
825 &self,
826 model_id: String,
827 pipeline: Arc<tokio::sync::Mutex<dyn Pipeline>>,
828 method: SchedulerConfig,
829 config: AddModelConfig,
830 ) -> Result<(), String> {
831 let reboot_state = RebootState {
832 pipeline: pipeline.clone(),
833 method: method.clone(),
834 no_kv_cache: config.engine_config.no_kv_cache,
835 no_prefix_cache: config.engine_config.no_prefix_cache,
836 prefix_cache_n: config.engine_config.prefix_cache_n,
837 disable_eos_stop: config.engine_config.disable_eos_stop,
838 throughput_logging_enabled: config.engine_config.throughput_logging_enabled,
839 search_embedding_model: config.engine_config.search_embedding_model,
840 search_callback: config.engine_config.search_callback.clone(),
841 tool_callbacks: config.engine_config.tool_callbacks.clone(),
842 tool_callbacks_with_tools: config.engine_config.tool_callbacks_with_tools.clone(),
843 mcp_client_config: config.mcp_client_config.clone(),
844 };
845
846 let engine_instance =
847 Self::create_engine_instance(pipeline, method, config.engine_config, reboot_state)?;
848
849 let mut engines = self
850 .engines
851 .write()
852 .map_err(|_| "Failed to acquire write lock on engines")?;
853 engines.insert(model_id.clone(), engine_instance);
854
855 if engines.len() == 1 {
857 let mut default_lock = self
858 .default_engine_id
859 .write()
860 .map_err(|_| "Failed to acquire write lock on default_engine_id")?;
861 *default_lock = Some(model_id.clone());
862 }
863
864 Ok(())
865 }
866
867 pub fn remove_model(&self, model_id: &str) -> Result<(), String> {
869 let mut engines = self
870 .engines
871 .write()
872 .map_err(|_| "Failed to acquire write lock on engines")?;
873
874 if engines.len() <= 1 {
875 return Err("Cannot remove the last model from MistralRs".to_string());
876 }
877
878 if let Some(engine_instance) = engines.remove(model_id) {
879 let _ = engine_instance.sender.blocking_send(Request::Terminate);
881
882 let mut default_lock = self
884 .default_engine_id
885 .write()
886 .map_err(|_| "Failed to acquire write lock on default_engine_id")?;
887 if let Some(ref default_id) = *default_lock {
888 if default_id == model_id {
889 *default_lock = engines.keys().next().cloned();
891 }
892 }
893
894 Ok(())
895 } else {
896 Err(format!("Model {model_id} not found"))
897 }
898 }
899
900 pub fn list_models(&self) -> Result<Vec<String>, String> {
902 let engines = self
903 .engines
904 .read()
905 .map_err(|_| "Failed to acquire read lock on engines")?;
906 Ok(engines.keys().cloned().collect())
907 }
908
909 pub fn get_default_model_id(&self) -> Result<Option<String>, String> {
911 let default_lock = self
912 .default_engine_id
913 .read()
914 .map_err(|_| "Failed to acquire read lock on default_engine_id")?;
915 Ok(default_lock.clone())
916 }
917
918 pub fn set_default_model_id(&self, model_id: &str) -> Result<(), String> {
920 let engines = self
921 .engines
922 .read()
923 .map_err(|_| "Failed to acquire read lock on engines")?;
924 if !engines.contains_key(model_id) {
925 return Err(format!("Model {model_id} not found"));
926 }
927 drop(engines);
928
929 let mut default_lock = self
930 .default_engine_id
931 .write()
932 .map_err(|_| "Failed to acquire write lock on default_engine_id")?;
933 *default_lock = Some(model_id.to_string());
934
935 Ok(())
936 }
937
938 pub fn send_request(&self, mut request: Request) -> Result<(), MistralRsError> {
940 let model_id = match &mut request {
941 Request::Normal(normal_req) => normal_req.model_id.as_deref(),
942 _ => None, };
944
945 let sender = self.get_sender(model_id)?;
946 sender
947 .blocking_send(request)
948 .map_err(|_| MistralRsError::SenderPoisoned)
949 }
950
951 pub fn maybe_log_request(this: Arc<Self>, repr: String) {
952 if let Some(file) = &this.log {
953 let mut f = OpenOptions::new()
954 .append(true)
955 .create(true) .open(file)
957 .expect("Unable to open file");
958 let time = chrono::offset::Local::now();
959 f.write_all(format!("Request at {time}: {repr}\n\n").as_bytes())
960 .expect("Unable to write data");
961 }
962 }
963
964 pub fn maybe_log_response<T: Serialize>(this: Arc<Self>, resp: &T) {
965 if let Some(file) = &this.log {
966 let mut f = OpenOptions::new()
967 .append(true)
968 .create(true) .open(file)
970 .expect("Unable to open file");
971 let time = chrono::offset::Local::now();
972 let repr = serde_json::to_string(resp).expect("Serialization of response failed.");
973 f.write_all(format!("Response at {time}: {repr}\n\n").as_bytes())
974 .expect("Unable to write data");
975 }
976 }
977
978 pub fn maybe_log_error(this: Arc<Self>, err: &dyn Error) {
979 if let Some(file) = &this.log {
980 let mut f = OpenOptions::new()
981 .append(true)
982 .create(true) .open(file)
984 .expect("Unable to open file");
985 let time = chrono::offset::Local::now();
986 f.write_all(format!("Error response at {time}: {err}\n\n").as_bytes())
987 .expect("Unable to write data");
988 }
989 }
990
991 pub fn get_tools_count(&self, model_id: Option<&str>) -> Result<usize, String> {
993 let resolved_model_id = match model_id {
994 Some(id) => id.to_string(),
995 None => {
996 let default_lock = self
997 .default_engine_id
998 .read()
999 .map_err(|_| "Failed to acquire read lock")?;
1000 default_lock
1001 .as_ref()
1002 .ok_or("No default engine set")?
1003 .clone()
1004 }
1005 };
1006
1007 let engines = self
1008 .engines
1009 .read()
1010 .map_err(|_| "Failed to acquire read lock on engines")?;
1011 if let Some(engine_instance) = engines.get(&resolved_model_id) {
1012 Ok(engine_instance.reboot_state.tool_callbacks_with_tools.len())
1013 } else {
1014 Err(format!("Model {resolved_model_id} not found"))
1015 }
1016 }
1017
1018 pub fn has_mcp_client(&self, model_id: Option<&str>) -> Result<bool, String> {
1020 let resolved_model_id = match model_id {
1021 Some(id) => id.to_string(),
1022 None => {
1023 let default_lock = self
1024 .default_engine_id
1025 .read()
1026 .map_err(|_| "Failed to acquire read lock")?;
1027 default_lock
1028 .as_ref()
1029 .ok_or("No default engine set")?
1030 .clone()
1031 }
1032 };
1033
1034 let engines = self
1035 .engines
1036 .read()
1037 .map_err(|_| "Failed to acquire read lock on engines")?;
1038 if let Some(engine_instance) = engines.get(&resolved_model_id) {
1039 Ok(engine_instance.reboot_state.mcp_client_config.is_some())
1040 } else {
1041 Err(format!("Model {resolved_model_id} not found"))
1042 }
1043 }
1044
1045 pub fn config(&self, model_id: Option<&str>) -> Result<MistralRsConfig, String> {
1047 let resolved_model_id = match model_id {
1048 Some(id) => id.to_string(),
1049 None => {
1050 let default_lock = self
1051 .default_engine_id
1052 .read()
1053 .map_err(|_| "Failed to acquire read lock")?;
1054 default_lock
1055 .as_ref()
1056 .ok_or("No default engine set")?
1057 .clone()
1058 }
1059 };
1060
1061 let engines = self
1062 .engines
1063 .read()
1064 .map_err(|_| "Failed to acquire read lock on engines")?;
1065 if let Some(engine_instance) = engines.get(&resolved_model_id) {
1066 Ok(engine_instance.config.clone())
1067 } else {
1068 Err(format!("Model {resolved_model_id} not found"))
1069 }
1070 }
1071}