1use anyhow::Context;
2use candle_core::{Device, Result, Tensor};
3use either::Either;
4use mistralrs_core::*;
5use std::sync::Arc;
6use tokio::sync::mpsc::{channel, Receiver};
7
8use crate::{RequestLike, TextMessages};
9
10pub fn best_device(force_cpu: bool) -> Result<Device> {
12 if force_cpu {
13 return Ok(Device::Cpu);
14 }
15 #[cfg(not(feature = "metal"))]
16 {
17 Device::cuda_if_available(0)
18 }
19 #[cfg(feature = "metal")]
20 {
21 Device::new_metal(0)
22 }
23}
24
25pub struct Model {
47 runner: Arc<MistralRs>,
48}
49
50pub struct Stream<'a> {
51 _server: &'a Model,
52 rx: Receiver<Response>,
53}
54
55impl Stream<'_> {
56 pub async fn next(&mut self) -> Option<Response> {
57 self.rx.recv().await
58 }
59}
60
61impl Model {
62 pub fn new(runner: Arc<MistralRs>) -> Self {
63 Self { runner }
64 }
65
66 pub async fn stream_chat_request<R: RequestLike>(
68 &self,
69 mut request: R,
70 ) -> anyhow::Result<Stream> {
71 let (tx, rx) = channel(1);
72
73 let (tools, tool_choice) = if let Some((a, b)) = request.take_tools() {
74 (Some(a), Some(b))
75 } else {
76 (None, None)
77 };
78 let request = Request::Normal(NormalRequest {
79 messages: request.take_messages(),
80 sampling_params: request.take_sampling_params(),
81 response: tx,
82 return_logprobs: request.return_logprobs(),
83 is_streaming: true,
84 id: 0,
85 constraint: request.take_constraint(),
86 suffix: None,
87 tools,
88 tool_choice,
89 logits_processors: request.take_logits_processors(),
90 return_raw_logits: false,
91 web_search_options: request.take_web_search_options(),
92 });
93
94 self.runner.get_sender()?.send(request).await?;
95
96 let stream = Stream { _server: self, rx };
97
98 Ok(stream)
99 }
100
101 pub async fn send_chat_request<R: RequestLike>(
103 &self,
104 mut request: R,
105 ) -> anyhow::Result<ChatCompletionResponse> {
106 let (tx, mut rx) = channel(1);
107
108 let (tools, tool_choice) = if let Some((a, b)) = request.take_tools() {
109 (Some(a), Some(b))
110 } else {
111 (None, None)
112 };
113 let request = Request::Normal(NormalRequest {
114 messages: request.take_messages(),
115 sampling_params: request.take_sampling_params(),
116 response: tx,
117 return_logprobs: request.return_logprobs(),
118 is_streaming: false,
119 id: 0,
120 constraint: request.take_constraint(),
121 suffix: None,
122 tools,
123 tool_choice,
124 logits_processors: request.take_logits_processors(),
125 return_raw_logits: false,
126 web_search_options: request.take_web_search_options(),
127 });
128
129 self.runner.get_sender()?.send(request).await?;
130
131 let ResponseOk::Done(response) = rx
132 .recv()
133 .await
134 .context("Channel was erroneously closed!")?
135 .as_result()?
136 else {
137 anyhow::bail!("Got unexpected response type.")
138 };
139
140 Ok(response)
141 }
142
143 pub async fn send_raw_chat_request<R: RequestLike>(
147 &self,
148 mut request: R,
149 ) -> anyhow::Result<(Vec<Tensor>, Vec<u32>)> {
150 let (tx, mut rx) = channel(1);
151
152 let (tools, tool_choice) = if let Some((a, b)) = request.take_tools() {
153 (Some(a), Some(b))
154 } else {
155 (None, None)
156 };
157 let request = Request::Normal(NormalRequest {
158 messages: request.take_messages(),
159 sampling_params: request.take_sampling_params(),
160 response: tx,
161 return_logprobs: request.return_logprobs(),
162 is_streaming: false,
163 id: 0,
164 constraint: request.take_constraint(),
165 suffix: None,
166 tools,
167 tool_choice,
168 logits_processors: request.take_logits_processors(),
169 return_raw_logits: true,
170 web_search_options: request.take_web_search_options(),
171 });
172
173 self.runner.get_sender()?.send(request).await?;
174
175 let ResponseOk::Raw {
176 logits_chunks,
177 tokens,
178 } = rx
179 .recv()
180 .await
181 .context("Channel was erroneously closed!")?
182 .as_result()?
183 else {
184 anyhow::bail!("Got unexpected response type.")
185 };
186
187 Ok((logits_chunks, tokens))
188 }
189
190 pub async fn generate_image(
191 &self,
192 prompt: impl ToString,
193 response_format: ImageGenerationResponseFormat,
194 generation_params: DiffusionGenerationParams,
195 ) -> anyhow::Result<ImageGenerationResponse> {
196 let (tx, mut rx) = channel(1);
197
198 let request = Request::Normal(NormalRequest {
199 id: 0,
200 messages: RequestMessage::ImageGeneration {
201 prompt: prompt.to_string(),
202 format: response_format,
203 generation_params,
204 },
205 sampling_params: SamplingParams::deterministic(),
206 response: tx,
207 return_logprobs: false,
208 is_streaming: false,
209 suffix: None,
210 constraint: Constraint::None,
211 tool_choice: None,
212 tools: None,
213 logits_processors: None,
214 return_raw_logits: false,
215 web_search_options: None,
216 });
217
218 self.runner.get_sender()?.send(request).await?;
219
220 let ResponseOk::ImageGeneration(response) = rx
221 .recv()
222 .await
223 .context("Channel was erroneously closed!")?
224 .as_result()?
225 else {
226 anyhow::bail!("Got unexpected response type.")
227 };
228
229 Ok(response)
230 }
231
232 pub async fn re_isq_model(&self, isq_type: IsqType) -> anyhow::Result<()> {
234 let request = Request::ReIsq(isq_type);
235
236 Ok(self.runner.get_sender()?.send(request).await?)
237 }
238
239 pub async fn tokenize(
242 &self,
243 text: Either<TextMessages, String>,
244 tools: Option<Vec<Tool>>,
245 add_special_tokens: bool,
246 add_generation_prompt: bool,
247 ) -> anyhow::Result<Vec<u32>> {
248 let (tx, mut rx) = channel(1);
249 let request = Request::Tokenize(TokenizationRequest {
250 text: text.map_left(Into::into),
251 tools,
252 add_special_tokens,
253 add_generation_prompt,
254 response: tx,
255 });
256 self.runner.get_sender()?.send(request).await?;
257
258 rx.recv().await.context("Channel was erroneously closed!")?
259 }
260
261 pub async fn detokenize(
263 &self,
264 tokens: Vec<u32>,
265 skip_special_tokens: bool,
266 ) -> anyhow::Result<String> {
267 let (tx, mut rx) = channel(1);
268 let request = Request::Detokenize(DetokenizationRequest {
269 tokens,
270 skip_special_tokens,
271 response: tx,
272 });
273 self.runner.get_sender()?.send(request).await?;
274
275 rx.recv().await.context("Channel was erroneously closed!")?
276 }
277
278 pub fn config(&self) -> &MistralRsConfig {
280 self.runner.config()
281 }
282
283 pub fn inner(&self) -> &MistralRs {
284 &self.runner
285 }
286}