1use anyhow::Context;
2use mistralrs_core::*;
3use std::sync::Arc;
4
5use crate::Model;
6
7pub struct MultiModel {
10 runner: Arc<MistralRs>,
11}
12
13impl MultiModel {
14 pub fn from_model(model: Model) -> Self {
18 Self {
19 runner: model.runner,
20 }
21 }
22
23 pub fn from_mistralrs(mistralrs: Arc<MistralRs>) -> Self {
25 Self { runner: mistralrs }
26 }
27
28 pub fn list_models(&self) -> Result<Vec<String>, String> {
30 self.runner.list_models()
31 }
32
33 pub fn get_default_model_id(&self) -> Result<Option<String>, String> {
35 self.runner.get_default_model_id()
36 }
37
38 pub fn set_default_model_id(&self, model_id: &str) -> Result<(), String> {
40 self.runner.set_default_model_id(model_id)
41 }
42
43 pub fn remove_model(&self, model_id: &str) -> Result<(), String> {
45 self.runner.remove_model(model_id)
46 }
47
48 pub async fn add_model(
50 &self,
51 model_id: String,
52 pipeline: Arc<tokio::sync::Mutex<dyn Pipeline>>,
53 method: SchedulerConfig,
54 config: mistralrs_core::AddModelConfig,
55 ) -> Result<(), String> {
56 self.runner
57 .add_model(model_id, pipeline, method, config)
58 .await
59 }
60
61 pub async fn send_chat_request_to_model(
63 &self,
64 mut request: impl crate::RequestLike,
65 model_id: Option<&str>,
66 ) -> anyhow::Result<ChatCompletionResponse> {
67 let (tx, mut rx) = tokio::sync::mpsc::channel(1);
68
69 let (tools, tool_choice) = if let Some((a, b)) = request.take_tools() {
70 (Some(a), Some(b))
71 } else {
72 (None, None)
73 };
74
75 let request = Request::Normal(Box::new(NormalRequest {
76 messages: request.take_messages(),
77 sampling_params: request.take_sampling_params(),
78 response: tx,
79 return_logprobs: request.return_logprobs(),
80 is_streaming: false,
81 id: 0,
82 constraint: request.take_constraint(),
83 suffix: None,
84 tools,
85 tool_choice,
86 logits_processors: request.take_logits_processors(),
87 return_raw_logits: false,
88 web_search_options: request.take_web_search_options(),
89 model_id: model_id.map(|s| s.to_string()),
90 }));
91
92 self.runner.get_sender(model_id)?.send(request).await?;
93
94 let Response::Done(response) =
95 rx.recv().await.context("Channel was erroneously closed!")?
96 else {
97 anyhow::bail!("Got unexpected response, expected `Response::Done`");
98 };
99
100 Ok(response)
101 }
102
103 pub fn inner(&self) -> &MistralRs {
105 &self.runner
106 }
107
108 pub fn config(&self, model_id: Option<&str>) -> Result<MistralRsConfig, String> {
110 self.runner.config(model_id)
111 }
112}