1use anyhow::Context;
2use candle_core::{Device, Result, Tensor};
3use either::Either;
4use futures::future::join_all;
5use mistralrs_core::*;
6use std::sync::Arc;
7use tokio::sync::mpsc::{channel, Receiver};
8
9use crate::{EmbeddingRequest, EmbeddingRequestBuilder, RequestLike, TextMessages};
10
11pub fn best_device(force_cpu: bool) -> Result<Device> {
13 if force_cpu {
14 return Ok(Device::Cpu);
15 }
16 #[cfg(not(feature = "metal"))]
17 {
18 Device::cuda_if_available(0)
19 }
20 #[cfg(feature = "metal")]
21 {
22 Device::new_metal(0)
23 }
24}
25
/// High-level handle over a loaded [`MistralRs`] engine.
///
/// Only stores the reference-counted runner; requests are dispatched through
/// the sender obtained from it (see the methods on the `impl` block).
pub struct Model {
    // Shared engine instance; cloned (Arc bump) where concurrent dispatch is
    // needed, e.g. when embedding several inputs at once.
    pub(crate) runner: Arc<MistralRs>,
}
50
/// Handle for consuming the incremental responses of a streaming chat
/// request, created by `Model::stream_chat_request`.
pub struct Stream<'a> {
    // Borrow of the owning `Model`; presumably kept to tie the stream's
    // lifetime to the model (the field itself is never read, hence the
    // leading underscore).
    _server: &'a Model,
    // Receiving end of the per-request response channel.
    rx: Receiver<Response>,
}
55
56impl Stream<'_> {
57 pub async fn next(&mut self) -> Option<Response> {
58 self.rx.recv().await
59 }
60}
61
62impl Model {
63 pub fn new(runner: Arc<MistralRs>) -> Self {
64 Self { runner }
65 }
66
67 pub async fn stream_chat_request<R: RequestLike>(
69 &self,
70 mut request: R,
71 ) -> anyhow::Result<Stream<'_>> {
72 let (tx, rx) = channel(1);
73
74 let truncate_sequence = request.truncate_sequence();
75 let (tools, tool_choice) = if let Some((a, b)) = request.take_tools() {
76 (Some(a), Some(b))
77 } else {
78 (None, None)
79 };
80 let request = Request::Normal(Box::new(NormalRequest {
81 messages: request.take_messages(),
82 sampling_params: request.take_sampling_params(),
83 response: tx,
84 return_logprobs: request.return_logprobs(),
85 is_streaming: true,
86 id: 0,
87 constraint: request.take_constraint(),
88 suffix: None,
89 tools,
90 tool_choice,
91 logits_processors: request.take_logits_processors(),
92 return_raw_logits: false,
93 web_search_options: request.take_web_search_options(),
94 model_id: None,
95 truncate_sequence,
96 }));
97
98 self.runner.get_sender(None)?.send(request).await?;
99
100 let stream = Stream { _server: self, rx };
101
102 Ok(stream)
103 }
104
105 pub async fn send_chat_request<R: RequestLike>(
107 &self,
108 mut request: R,
109 ) -> anyhow::Result<ChatCompletionResponse> {
110 let (tx, mut rx) = channel(1);
111
112 let truncate_sequence = request.truncate_sequence();
113 let (tools, tool_choice) = if let Some((a, b)) = request.take_tools() {
114 (Some(a), Some(b))
115 } else {
116 (None, None)
117 };
118 let request = Request::Normal(Box::new(NormalRequest {
119 messages: request.take_messages(),
120 sampling_params: request.take_sampling_params(),
121 response: tx,
122 return_logprobs: request.return_logprobs(),
123 is_streaming: false,
124 id: 0,
125 constraint: request.take_constraint(),
126 suffix: None,
127 tools,
128 tool_choice,
129 logits_processors: request.take_logits_processors(),
130 return_raw_logits: false,
131 web_search_options: request.take_web_search_options(),
132 model_id: None,
133 truncate_sequence,
134 }));
135
136 self.runner.get_sender(None)?.send(request).await?;
137
138 let ResponseOk::Done(response) = rx
139 .recv()
140 .await
141 .context("Channel was erroneously closed!")?
142 .as_result()?
143 else {
144 anyhow::bail!("Got unexpected response type.")
145 };
146
147 Ok(response)
148 }
149
150 pub async fn send_raw_chat_request<R: RequestLike>(
154 &self,
155 mut request: R,
156 ) -> anyhow::Result<(Vec<Tensor>, Vec<u32>)> {
157 let (tx, mut rx) = channel(1);
158
159 let truncate_sequence = request.truncate_sequence();
160 let (tools, tool_choice) = if let Some((a, b)) = request.take_tools() {
161 (Some(a), Some(b))
162 } else {
163 (None, None)
164 };
165 let request = Request::Normal(Box::new(NormalRequest {
166 messages: request.take_messages(),
167 sampling_params: request.take_sampling_params(),
168 response: tx,
169 return_logprobs: request.return_logprobs(),
170 is_streaming: false,
171 id: 0,
172 constraint: request.take_constraint(),
173 suffix: None,
174 tools,
175 tool_choice,
176 logits_processors: request.take_logits_processors(),
177 return_raw_logits: true,
178 web_search_options: request.take_web_search_options(),
179 model_id: None,
180 truncate_sequence,
181 }));
182
183 self.runner.get_sender(None)?.send(request).await?;
184
185 let ResponseOk::Raw {
186 logits_chunks,
187 tokens,
188 } = rx
189 .recv()
190 .await
191 .context("Channel was erroneously closed!")?
192 .as_result()?
193 else {
194 anyhow::bail!("Got unexpected response type.")
195 };
196
197 Ok((logits_chunks, tokens))
198 }
199
200 pub async fn generate_image(
201 &self,
202 prompt: impl ToString,
203 response_format: ImageGenerationResponseFormat,
204 generation_params: DiffusionGenerationParams,
205 ) -> anyhow::Result<ImageGenerationResponse> {
206 let (tx, mut rx) = channel(1);
207
208 let request = Request::Normal(Box::new(NormalRequest {
209 id: 0,
210 messages: RequestMessage::ImageGeneration {
211 prompt: prompt.to_string(),
212 format: response_format,
213 generation_params,
214 },
215 sampling_params: SamplingParams::deterministic(),
216 response: tx,
217 return_logprobs: false,
218 is_streaming: false,
219 suffix: None,
220 constraint: Constraint::None,
221 tool_choice: None,
222 tools: None,
223 logits_processors: None,
224 return_raw_logits: false,
225 web_search_options: None,
226 model_id: None,
227 truncate_sequence: false,
228 }));
229
230 self.runner.get_sender(None)?.send(request).await?;
231
232 let ResponseOk::ImageGeneration(response) = rx
233 .recv()
234 .await
235 .context("Channel was erroneously closed!")?
236 .as_result()?
237 else {
238 anyhow::bail!("Got unexpected response type.")
239 };
240
241 Ok(response)
242 }
243
244 pub async fn generate_speech(
248 &self,
249 prompt: impl ToString,
250 ) -> anyhow::Result<(Arc<Vec<f32>>, usize, usize)> {
251 let (tx, mut rx) = channel(1);
252
253 let request = Request::Normal(Box::new(NormalRequest {
254 id: 0,
255 messages: RequestMessage::SpeechGeneration {
256 prompt: prompt.to_string(),
257 },
258 sampling_params: SamplingParams::deterministic(),
259 response: tx,
260 return_logprobs: false,
261 is_streaming: false,
262 suffix: None,
263 constraint: Constraint::None,
264 tool_choice: None,
265 tools: None,
266 logits_processors: None,
267 return_raw_logits: false,
268 web_search_options: None,
269 model_id: None,
270 truncate_sequence: false,
271 }));
272
273 self.runner.get_sender(None)?.send(request).await?;
274
275 let ResponseOk::Speech {
276 pcm,
277 rate,
278 channels,
279 } = rx
280 .recv()
281 .await
282 .context("Channel was erroneously closed!")?
283 .as_result()?
284 else {
285 anyhow::bail!("Got unexpected response type.")
286 };
287
288 Ok((pcm, rate, channels))
289 }
290
291 pub async fn generate_embeddings(
295 &self,
296 request: EmbeddingRequestBuilder,
297 ) -> anyhow::Result<Vec<Vec<f32>>> {
298 let request = request.build()?;
299 let EmbeddingRequest {
300 inputs,
301 truncate_sequence,
302 } = request;
303
304 let runner = self.runner.clone();
305 let futures = inputs.into_iter().map(|input| {
306 let runner = runner.clone();
307 async move {
308 let message = input.into_request_message();
309 let (tx, mut rx) = channel(1);
310
311 let request = Request::Normal(Box::new(NormalRequest {
312 id: 0,
313 messages: message,
314 sampling_params: SamplingParams::deterministic(),
315 response: tx,
316 return_logprobs: false,
317 is_streaming: false,
318 suffix: None,
319 constraint: Constraint::None,
320 tool_choice: None,
321 tools: None,
322 logits_processors: None,
323 return_raw_logits: false,
324 web_search_options: None,
325 model_id: None,
326 truncate_sequence,
327 }));
328
329 runner
330 .get_sender(None)?
331 .send(request)
332 .await
333 .map_err(|e| anyhow::anyhow!(e.to_string()))?;
334
335 let ResponseOk::Embeddings { embeddings, .. } = rx
336 .recv()
337 .await
338 .context("Channel was erroneously closed!")?
339 .as_result()?
340 else {
341 anyhow::bail!("Got unexpected response type.")
342 };
343
344 Ok::<Vec<f32>, anyhow::Error>(embeddings)
345 }
346 });
347
348 let results = join_all(futures).await;
349 let mut embeddings = Vec::with_capacity(results.len());
350 for result in results {
351 embeddings.push(result?);
352 }
353 Ok(embeddings)
354 }
355
356 pub async fn generate_embedding(&self, prompt: impl ToString) -> anyhow::Result<Vec<f32>> {
358 let mut embeddings = self
359 .generate_embeddings(EmbeddingRequest::builder().add_prompt(prompt.to_string()))
360 .await?;
361
362 Ok(embeddings
363 .pop()
364 .expect("EmbeddingRequestBuilder should guarantee at least one input"))
365 }
366
367 pub async fn re_isq_model(&self, isq_type: IsqType) -> anyhow::Result<()> {
369 let request = Request::ReIsq(isq_type);
370
371 Ok(self.runner.get_sender(None)?.send(request).await?)
372 }
373
374 pub async fn tokenize(
377 &self,
378 text: Either<TextMessages, String>,
379 tools: Option<Vec<Tool>>,
380 add_special_tokens: bool,
381 add_generation_prompt: bool,
382 enable_thinking: Option<bool>,
383 ) -> anyhow::Result<Vec<u32>> {
384 let (tx, mut rx) = channel(1);
385 let request = Request::Tokenize(TokenizationRequest {
386 text: text.map_left(Into::into),
387 tools,
388 add_special_tokens,
389 add_generation_prompt,
390 response: tx,
391 enable_thinking,
392 });
393 self.runner.get_sender(None)?.send(request).await?;
394
395 rx.recv().await.context("Channel was erroneously closed!")?
396 }
397
398 pub async fn detokenize(
400 &self,
401 tokens: Vec<u32>,
402 skip_special_tokens: bool,
403 ) -> anyhow::Result<String> {
404 let (tx, mut rx) = channel(1);
405 let request = Request::Detokenize(DetokenizationRequest {
406 tokens,
407 skip_special_tokens,
408 response: tx,
409 });
410 self.runner.get_sender(None)?.send(request).await?;
411
412 rx.recv().await.context("Channel was erroneously closed!")?
413 }
414
415 pub fn config(&self) -> std::result::Result<MistralRsConfig, String> {
417 self.runner.config(None)
418 }
419
420 pub fn max_sequence_length(&self) -> std::result::Result<Option<usize>, MistralRsError> {
422 self.runner.max_sequence_length(None)
423 }
424
425 pub fn inner(&self) -> &MistralRs {
426 &self.runner
427 }
428}