mistralrs/
multi_model.rs

1use anyhow::Context;
2use mistralrs_core::*;
3use std::sync::Arc;
4
5use crate::Model;
6
7/// A simpler multi-model interface that wraps an existing MistralRs instance
8/// and provides methods to interact with multiple loaded models.
9pub struct MultiModel {
10    runner: Arc<MistralRs>,
11}
12
13impl MultiModel {
14    /// Create a MultiModel from an existing Model that has multiple models loaded.
15    /// This is useful when you've created a Model using regular builders and then
16    /// added more models to it using the add_model method.
17    pub fn from_model(model: Model) -> Self {
18        Self {
19            runner: model.runner,
20        }
21    }
22
23    /// Create a MultiModel directly from a MistralRs instance.
24    pub fn from_mistralrs(mistralrs: Arc<MistralRs>) -> Self {
25        Self { runner: mistralrs }
26    }
27
28    /// List all available model IDs.
29    pub fn list_models(&self) -> Result<Vec<String>, String> {
30        self.runner.list_models()
31    }
32
33    /// Get the default model ID.
34    pub fn get_default_model_id(&self) -> Result<Option<String>, String> {
35        self.runner.get_default_model_id()
36    }
37
38    /// Set the default model ID.
39    pub fn set_default_model_id(&self, model_id: &str) -> Result<(), String> {
40        self.runner.set_default_model_id(model_id)
41    }
42
43    /// Remove a model by ID.
44    pub fn remove_model(&self, model_id: &str) -> Result<(), String> {
45        self.runner.remove_model(model_id)
46    }
47
48    /// Add a new model to the multi-model instance.
49    pub async fn add_model(
50        &self,
51        model_id: String,
52        pipeline: Arc<tokio::sync::Mutex<dyn Pipeline>>,
53        method: SchedulerConfig,
54        config: mistralrs_core::AddModelConfig,
55    ) -> Result<(), String> {
56        self.runner
57            .add_model(model_id, pipeline, method, config)
58            .await
59    }
60
61    /// Send a chat request to a specific model.
62    pub async fn send_chat_request_to_model(
63        &self,
64        mut request: impl crate::RequestLike,
65        model_id: Option<&str>,
66    ) -> anyhow::Result<ChatCompletionResponse> {
67        let (tx, mut rx) = tokio::sync::mpsc::channel(1);
68
69        let (tools, tool_choice) = if let Some((a, b)) = request.take_tools() {
70            (Some(a), Some(b))
71        } else {
72            (None, None)
73        };
74        let truncate_sequence = request.truncate_sequence();
75
76        let request = Request::Normal(Box::new(NormalRequest {
77            messages: request.take_messages(),
78            sampling_params: request.take_sampling_params(),
79            response: tx,
80            return_logprobs: request.return_logprobs(),
81            is_streaming: false,
82            id: 0,
83            constraint: request.take_constraint(),
84            suffix: None,
85            tools,
86            tool_choice,
87            logits_processors: request.take_logits_processors(),
88            return_raw_logits: false,
89            web_search_options: request.take_web_search_options(),
90            model_id: model_id.map(|s| s.to_string()),
91            truncate_sequence,
92        }));
93
94        self.runner.get_sender(model_id)?.send(request).await?;
95
96        let Response::Done(response) =
97            rx.recv().await.context("Channel was erroneously closed!")?
98        else {
99            anyhow::bail!("Got unexpected response, expected `Response::Done`");
100        };
101
102        Ok(response)
103    }
104
105    /// Get the underlying MistralRs instance.
106    pub fn inner(&self) -> &MistralRs {
107        &self.runner
108    }
109
110    /// Get configuration for a specific model.
111    pub fn config(&self, model_id: Option<&str>) -> Result<MistralRsConfig, String> {
112        self.runner.config(model_id)
113    }
114}