1use anyhow::Context;
2use candle_core::{Device, Result, Tensor};
3use either::Either;
4use mistralrs_core::*;
5use std::sync::Arc;
6use tokio::sync::mpsc::{channel, Receiver};
7
8use crate::{RequestLike, TextMessages};
9
10pub fn best_device(force_cpu: bool) -> Result<Device> {
12 if force_cpu {
13 return Ok(Device::Cpu);
14 }
15 #[cfg(not(feature = "metal"))]
16 {
17 Device::cuda_if_available(0)
18 }
19 #[cfg(feature = "metal")]
20 {
21 Device::new_metal(0)
22 }
23}
24
/// A handle to a loaded model, wrapping the shared [`MistralRs`] runner.
///
/// All request-sending methods dispatch through the runner's channel sender.
pub struct Model {
    // Shared engine runner; cloned cheaply via `Arc` where needed.
    pub(crate) runner: Arc<MistralRs>,
}
49
/// A stream of [`Response`] chunks produced by a streaming chat request.
///
/// Borrows the owning [`Model`] so the stream cannot outlive it.
pub struct Stream<'a> {
    // Held only to tie the stream's lifetime to the model; never read.
    _server: &'a Model,
    // Receiving end of the channel the engine pushes responses into.
    rx: Receiver<Response>,
}
54
impl Stream<'_> {
    /// Receive the next response chunk, or `None` once the sending side of
    /// the channel has been dropped (end of generation).
    pub async fn next(&mut self) -> Option<Response> {
        self.rx.recv().await
    }
}
60
61impl Model {
62 pub fn new(runner: Arc<MistralRs>) -> Self {
63 Self { runner }
64 }
65
66 pub async fn stream_chat_request<R: RequestLike>(
68 &self,
69 mut request: R,
70 ) -> anyhow::Result<Stream> {
71 let (tx, rx) = channel(1);
72
73 let (tools, tool_choice) = if let Some((a, b)) = request.take_tools() {
74 (Some(a), Some(b))
75 } else {
76 (None, None)
77 };
78 let request = Request::Normal(Box::new(NormalRequest {
79 messages: request.take_messages(),
80 sampling_params: request.take_sampling_params(),
81 response: tx,
82 return_logprobs: request.return_logprobs(),
83 is_streaming: true,
84 id: 0,
85 constraint: request.take_constraint(),
86 suffix: None,
87 tools,
88 tool_choice,
89 logits_processors: request.take_logits_processors(),
90 return_raw_logits: false,
91 web_search_options: request.take_web_search_options(),
92 model_id: None,
93 }));
94
95 self.runner.get_sender(None)?.send(request).await?;
96
97 let stream = Stream { _server: self, rx };
98
99 Ok(stream)
100 }
101
102 pub async fn send_chat_request<R: RequestLike>(
104 &self,
105 mut request: R,
106 ) -> anyhow::Result<ChatCompletionResponse> {
107 let (tx, mut rx) = channel(1);
108
109 let (tools, tool_choice) = if let Some((a, b)) = request.take_tools() {
110 (Some(a), Some(b))
111 } else {
112 (None, None)
113 };
114 let request = Request::Normal(Box::new(NormalRequest {
115 messages: request.take_messages(),
116 sampling_params: request.take_sampling_params(),
117 response: tx,
118 return_logprobs: request.return_logprobs(),
119 is_streaming: false,
120 id: 0,
121 constraint: request.take_constraint(),
122 suffix: None,
123 tools,
124 tool_choice,
125 logits_processors: request.take_logits_processors(),
126 return_raw_logits: false,
127 web_search_options: request.take_web_search_options(),
128 model_id: None,
129 }));
130
131 self.runner.get_sender(None)?.send(request).await?;
132
133 let ResponseOk::Done(response) = rx
134 .recv()
135 .await
136 .context("Channel was erroneously closed!")?
137 .as_result()?
138 else {
139 anyhow::bail!("Got unexpected response type.")
140 };
141
142 Ok(response)
143 }
144
145 pub async fn send_raw_chat_request<R: RequestLike>(
149 &self,
150 mut request: R,
151 ) -> anyhow::Result<(Vec<Tensor>, Vec<u32>)> {
152 let (tx, mut rx) = channel(1);
153
154 let (tools, tool_choice) = if let Some((a, b)) = request.take_tools() {
155 (Some(a), Some(b))
156 } else {
157 (None, None)
158 };
159 let request = Request::Normal(Box::new(NormalRequest {
160 messages: request.take_messages(),
161 sampling_params: request.take_sampling_params(),
162 response: tx,
163 return_logprobs: request.return_logprobs(),
164 is_streaming: false,
165 id: 0,
166 constraint: request.take_constraint(),
167 suffix: None,
168 tools,
169 tool_choice,
170 logits_processors: request.take_logits_processors(),
171 return_raw_logits: true,
172 web_search_options: request.take_web_search_options(),
173 model_id: None,
174 }));
175
176 self.runner.get_sender(None)?.send(request).await?;
177
178 let ResponseOk::Raw {
179 logits_chunks,
180 tokens,
181 } = rx
182 .recv()
183 .await
184 .context("Channel was erroneously closed!")?
185 .as_result()?
186 else {
187 anyhow::bail!("Got unexpected response type.")
188 };
189
190 Ok((logits_chunks, tokens))
191 }
192
193 pub async fn generate_image(
194 &self,
195 prompt: impl ToString,
196 response_format: ImageGenerationResponseFormat,
197 generation_params: DiffusionGenerationParams,
198 ) -> anyhow::Result<ImageGenerationResponse> {
199 let (tx, mut rx) = channel(1);
200
201 let request = Request::Normal(Box::new(NormalRequest {
202 id: 0,
203 messages: RequestMessage::ImageGeneration {
204 prompt: prompt.to_string(),
205 format: response_format,
206 generation_params,
207 },
208 sampling_params: SamplingParams::deterministic(),
209 response: tx,
210 return_logprobs: false,
211 is_streaming: false,
212 suffix: None,
213 constraint: Constraint::None,
214 tool_choice: None,
215 tools: None,
216 logits_processors: None,
217 return_raw_logits: false,
218 web_search_options: None,
219 model_id: None,
220 }));
221
222 self.runner.get_sender(None)?.send(request).await?;
223
224 let ResponseOk::ImageGeneration(response) = rx
225 .recv()
226 .await
227 .context("Channel was erroneously closed!")?
228 .as_result()?
229 else {
230 anyhow::bail!("Got unexpected response type.")
231 };
232
233 Ok(response)
234 }
235
236 pub async fn generate_speech(
240 &self,
241 prompt: impl ToString,
242 ) -> anyhow::Result<(Arc<Vec<f32>>, usize, usize)> {
243 let (tx, mut rx) = channel(1);
244
245 let request = Request::Normal(Box::new(NormalRequest {
246 id: 0,
247 messages: RequestMessage::SpeechGeneration {
248 prompt: prompt.to_string(),
249 },
250 sampling_params: SamplingParams::deterministic(),
251 response: tx,
252 return_logprobs: false,
253 is_streaming: false,
254 suffix: None,
255 constraint: Constraint::None,
256 tool_choice: None,
257 tools: None,
258 logits_processors: None,
259 return_raw_logits: false,
260 web_search_options: None,
261 model_id: None,
262 }));
263
264 self.runner.get_sender(None)?.send(request).await?;
265
266 let ResponseOk::Speech {
267 pcm,
268 rate,
269 channels,
270 } = rx
271 .recv()
272 .await
273 .context("Channel was erroneously closed!")?
274 .as_result()?
275 else {
276 anyhow::bail!("Got unexpected response type.")
277 };
278
279 Ok((pcm, rate, channels))
280 }
281
282 pub async fn re_isq_model(&self, isq_type: IsqType) -> anyhow::Result<()> {
284 let request = Request::ReIsq(isq_type);
285
286 Ok(self.runner.get_sender(None)?.send(request).await?)
287 }
288
289 pub async fn tokenize(
292 &self,
293 text: Either<TextMessages, String>,
294 tools: Option<Vec<Tool>>,
295 add_special_tokens: bool,
296 add_generation_prompt: bool,
297 enable_thinking: Option<bool>,
298 ) -> anyhow::Result<Vec<u32>> {
299 let (tx, mut rx) = channel(1);
300 let request = Request::Tokenize(TokenizationRequest {
301 text: text.map_left(Into::into),
302 tools,
303 add_special_tokens,
304 add_generation_prompt,
305 response: tx,
306 enable_thinking,
307 });
308 self.runner.get_sender(None)?.send(request).await?;
309
310 rx.recv().await.context("Channel was erroneously closed!")?
311 }
312
313 pub async fn detokenize(
315 &self,
316 tokens: Vec<u32>,
317 skip_special_tokens: bool,
318 ) -> anyhow::Result<String> {
319 let (tx, mut rx) = channel(1);
320 let request = Request::Detokenize(DetokenizationRequest {
321 tokens,
322 skip_special_tokens,
323 response: tx,
324 });
325 self.runner.get_sender(None)?.send(request).await?;
326
327 rx.recv().await.context("Channel was erroneously closed!")?
328 }
329
330 pub fn config(&self) -> std::result::Result<MistralRsConfig, String> {
332 self.runner.config(None)
333 }
334
335 pub fn inner(&self) -> &MistralRs {
336 &self.runner
337 }
338}