// mistralrs/multi_model.rs

1use anyhow::Context;
2use mistralrs_core::*;
3use std::sync::Arc;
4
5use crate::Model;
6
7/// A simpler multi-model interface that wraps an existing MistralRs instance
8/// and provides methods to interact with multiple loaded models.
9pub struct MultiModel {
10    runner: Arc<MistralRs>,
11}
12
13impl MultiModel {
14    /// Create a MultiModel from an existing Model that has multiple models loaded.
15    /// This is useful when you've created a Model using regular builders and then
16    /// added more models to it using the add_model method.
17    pub fn from_model(model: Model) -> Self {
18        Self {
19            runner: model.runner,
20        }
21    }
22
23    /// Create a MultiModel directly from a MistralRs instance.
24    pub fn from_mistralrs(mistralrs: Arc<MistralRs>) -> Self {
25        Self { runner: mistralrs }
26    }
27
28    /// List all available model IDs.
29    pub fn list_models(&self) -> Result<Vec<String>, String> {
30        self.runner.list_models()
31    }
32
33    /// Get the default model ID.
34    pub fn get_default_model_id(&self) -> Result<Option<String>, String> {
35        self.runner.get_default_model_id()
36    }
37
38    /// Set the default model ID.
39    pub fn set_default_model_id(&self, model_id: &str) -> Result<(), String> {
40        self.runner.set_default_model_id(model_id)
41    }
42
43    /// Remove a model by ID.
44    pub fn remove_model(&self, model_id: &str) -> Result<(), String> {
45        self.runner.remove_model(model_id)
46    }
47
48    /// Add a new model to the multi-model instance.
49    pub async fn add_model(
50        &self,
51        model_id: String,
52        pipeline: Arc<tokio::sync::Mutex<dyn Pipeline>>,
53        method: SchedulerConfig,
54        config: mistralrs_core::AddModelConfig,
55    ) -> Result<(), String> {
56        self.runner
57            .add_model(model_id, pipeline, method, config)
58            .await
59    }
60
61    /// Send a chat request to a specific model.
62    pub async fn send_chat_request_to_model(
63        &self,
64        mut request: impl crate::RequestLike,
65        model_id: Option<&str>,
66    ) -> anyhow::Result<ChatCompletionResponse> {
67        let (tx, mut rx) = tokio::sync::mpsc::channel(1);
68
69        let (tools, tool_choice) = if let Some((a, b)) = request.take_tools() {
70            (Some(a), Some(b))
71        } else {
72            (None, None)
73        };
74
75        let request = Request::Normal(Box::new(NormalRequest {
76            messages: request.take_messages(),
77            sampling_params: request.take_sampling_params(),
78            response: tx,
79            return_logprobs: request.return_logprobs(),
80            is_streaming: false,
81            id: 0,
82            constraint: request.take_constraint(),
83            suffix: None,
84            tools,
85            tool_choice,
86            logits_processors: request.take_logits_processors(),
87            return_raw_logits: false,
88            web_search_options: request.take_web_search_options(),
89            model_id: model_id.map(|s| s.to_string()),
90        }));
91
92        self.runner.get_sender(model_id)?.send(request).await?;
93
94        let Response::Done(response) =
95            rx.recv().await.context("Channel was erroneously closed!")?
96        else {
97            anyhow::bail!("Got unexpected response, expected `Response::Done`");
98        };
99
100        Ok(response)
101    }
102
103    /// Get the underlying MistralRs instance.
104    pub fn inner(&self) -> &MistralRs {
105        &self.runner
106    }
107
108    /// Get configuration for a specific model.
109    pub fn config(&self, model_id: Option<&str>) -> Result<MistralRsConfig, String> {
110        self.runner.config(model_id)
111    }
112}