// mistralrs/model.rs

1use anyhow::Context;
2use candle_core::{Device, Result, Tensor};
3use either::Either;
4use futures::future::join_all;
5use mistralrs_core::*;
6use std::sync::Arc;
7use tokio::sync::mpsc::{channel, Receiver};
8
9use crate::{EmbeddingRequest, EmbeddingRequestBuilder, RequestLike, TextMessages};
10
11/// Gets the best device, cpu, cuda if compiled with CUDA, or Metal
12pub fn best_device(force_cpu: bool) -> Result<Device> {
13    if force_cpu {
14        return Ok(Device::Cpu);
15    }
16    #[cfg(not(feature = "metal"))]
17    {
18        Device::cuda_if_available(0)
19    }
20    #[cfg(feature = "metal")]
21    {
22        Device::new_metal(0)
23    }
24}
25
/// The object used to interact with the model. This can be used with many varieties of models,
/// and as such may be created with one of:
/// - [`TextModelBuilder`]
/// - [`LoraModelBuilder`]
/// - [`XLoraModelBuilder`]
/// - [`GgufModelBuilder`]
/// - [`GgufLoraModelBuilder`]
/// - [`GgufXLoraModelBuilder`]
/// - [`VisionModelBuilder`]
/// - [`AnyMoeModelBuilder`]
///
/// [`TextModelBuilder`]: crate::TextModelBuilder
/// [`LoraModelBuilder`]: crate::LoraModelBuilder
/// [`XLoraModelBuilder`]: crate::XLoraModelBuilder
/// [`GgufModelBuilder`]: crate::GgufModelBuilder
/// [`GgufLoraModelBuilder`]: crate::GgufLoraModelBuilder
/// [`GgufXLoraModelBuilder`]: crate::GgufXLoraModelBuilder
/// [`VisionModelBuilder`]: crate::VisionModelBuilder
/// [`AnyMoeModelBuilder`]: crate::AnyMoeModelBuilder
///
pub struct Model {
    // Shared handle to the running engine; per-request senders are obtained from it.
    pub(crate) runner: Arc<MistralRs>,
}
50
/// A streaming response handle returned by [`Model::stream_chat_request`].
///
/// Borrows the owning [`Model`] so the stream cannot outlive it.
pub struct Stream<'a> {
    // Held only to tie the stream's lifetime to the model; never read.
    _server: &'a Model,
    // Receiving end of the per-request response channel.
    rx: Receiver<Response>,
}
55
impl Stream<'_> {
    /// Receive the next [`Response`] chunk, or `None` once the channel is
    /// closed (generation finished or the sender was dropped).
    pub async fn next(&mut self) -> Option<Response> {
        self.rx.recv().await
    }
}
61
62impl Model {
63    pub fn new(runner: Arc<MistralRs>) -> Self {
64        Self { runner }
65    }
66
67    /// Generate with the model.
68    pub async fn stream_chat_request<R: RequestLike>(
69        &self,
70        mut request: R,
71    ) -> anyhow::Result<Stream<'_>> {
72        let (tx, rx) = channel(1);
73
74        let truncate_sequence = request.truncate_sequence();
75        let (tools, tool_choice) = if let Some((a, b)) = request.take_tools() {
76            (Some(a), Some(b))
77        } else {
78            (None, None)
79        };
80        let request = Request::Normal(Box::new(NormalRequest {
81            messages: request.take_messages(),
82            sampling_params: request.take_sampling_params(),
83            response: tx,
84            return_logprobs: request.return_logprobs(),
85            is_streaming: true,
86            id: 0,
87            constraint: request.take_constraint(),
88            suffix: None,
89            tools,
90            tool_choice,
91            logits_processors: request.take_logits_processors(),
92            return_raw_logits: false,
93            web_search_options: request.take_web_search_options(),
94            model_id: None,
95            truncate_sequence,
96        }));
97
98        self.runner.get_sender(None)?.send(request).await?;
99
100        let stream = Stream { _server: self, rx };
101
102        Ok(stream)
103    }
104
105    /// Generate with the model.
106    pub async fn send_chat_request<R: RequestLike>(
107        &self,
108        mut request: R,
109    ) -> anyhow::Result<ChatCompletionResponse> {
110        let (tx, mut rx) = channel(1);
111
112        let truncate_sequence = request.truncate_sequence();
113        let (tools, tool_choice) = if let Some((a, b)) = request.take_tools() {
114            (Some(a), Some(b))
115        } else {
116            (None, None)
117        };
118        let request = Request::Normal(Box::new(NormalRequest {
119            messages: request.take_messages(),
120            sampling_params: request.take_sampling_params(),
121            response: tx,
122            return_logprobs: request.return_logprobs(),
123            is_streaming: false,
124            id: 0,
125            constraint: request.take_constraint(),
126            suffix: None,
127            tools,
128            tool_choice,
129            logits_processors: request.take_logits_processors(),
130            return_raw_logits: false,
131            web_search_options: request.take_web_search_options(),
132            model_id: None,
133            truncate_sequence,
134        }));
135
136        self.runner.get_sender(None)?.send(request).await?;
137
138        let ResponseOk::Done(response) = rx
139            .recv()
140            .await
141            .context("Channel was erroneously closed!")?
142            .as_result()?
143        else {
144            anyhow::bail!("Got unexpected response type.")
145        };
146
147        Ok(response)
148    }
149
150    /// Generate with the model, returning raw logits of the first token generated.
151    ///
152    /// Returns the chunks of the logits (1 or more, determined by prompt batchsize) and the tokens.
153    pub async fn send_raw_chat_request<R: RequestLike>(
154        &self,
155        mut request: R,
156    ) -> anyhow::Result<(Vec<Tensor>, Vec<u32>)> {
157        let (tx, mut rx) = channel(1);
158
159        let truncate_sequence = request.truncate_sequence();
160        let (tools, tool_choice) = if let Some((a, b)) = request.take_tools() {
161            (Some(a), Some(b))
162        } else {
163            (None, None)
164        };
165        let request = Request::Normal(Box::new(NormalRequest {
166            messages: request.take_messages(),
167            sampling_params: request.take_sampling_params(),
168            response: tx,
169            return_logprobs: request.return_logprobs(),
170            is_streaming: false,
171            id: 0,
172            constraint: request.take_constraint(),
173            suffix: None,
174            tools,
175            tool_choice,
176            logits_processors: request.take_logits_processors(),
177            return_raw_logits: true,
178            web_search_options: request.take_web_search_options(),
179            model_id: None,
180            truncate_sequence,
181        }));
182
183        self.runner.get_sender(None)?.send(request).await?;
184
185        let ResponseOk::Raw {
186            logits_chunks,
187            tokens,
188        } = rx
189            .recv()
190            .await
191            .context("Channel was erroneously closed!")?
192            .as_result()?
193        else {
194            anyhow::bail!("Got unexpected response type.")
195        };
196
197        Ok((logits_chunks, tokens))
198    }
199
200    pub async fn generate_image(
201        &self,
202        prompt: impl ToString,
203        response_format: ImageGenerationResponseFormat,
204        generation_params: DiffusionGenerationParams,
205    ) -> anyhow::Result<ImageGenerationResponse> {
206        let (tx, mut rx) = channel(1);
207
208        let request = Request::Normal(Box::new(NormalRequest {
209            id: 0,
210            messages: RequestMessage::ImageGeneration {
211                prompt: prompt.to_string(),
212                format: response_format,
213                generation_params,
214            },
215            sampling_params: SamplingParams::deterministic(),
216            response: tx,
217            return_logprobs: false,
218            is_streaming: false,
219            suffix: None,
220            constraint: Constraint::None,
221            tool_choice: None,
222            tools: None,
223            logits_processors: None,
224            return_raw_logits: false,
225            web_search_options: None,
226            model_id: None,
227            truncate_sequence: false,
228        }));
229
230        self.runner.get_sender(None)?.send(request).await?;
231
232        let ResponseOk::ImageGeneration(response) = rx
233            .recv()
234            .await
235            .context("Channel was erroneously closed!")?
236            .as_result()?
237        else {
238            anyhow::bail!("Got unexpected response type.")
239        };
240
241        Ok(response)
242    }
243
244    /// Generate audio given a (model specific) prompt.
245    ///
246    /// This returns: (pcm, sampling rate, channels)
247    pub async fn generate_speech(
248        &self,
249        prompt: impl ToString,
250    ) -> anyhow::Result<(Arc<Vec<f32>>, usize, usize)> {
251        let (tx, mut rx) = channel(1);
252
253        let request = Request::Normal(Box::new(NormalRequest {
254            id: 0,
255            messages: RequestMessage::SpeechGeneration {
256                prompt: prompt.to_string(),
257            },
258            sampling_params: SamplingParams::deterministic(),
259            response: tx,
260            return_logprobs: false,
261            is_streaming: false,
262            suffix: None,
263            constraint: Constraint::None,
264            tool_choice: None,
265            tools: None,
266            logits_processors: None,
267            return_raw_logits: false,
268            web_search_options: None,
269            model_id: None,
270            truncate_sequence: false,
271        }));
272
273        self.runner.get_sender(None)?.send(request).await?;
274
275        let ResponseOk::Speech {
276            pcm,
277            rate,
278            channels,
279        } = rx
280            .recv()
281            .await
282            .context("Channel was erroneously closed!")?
283            .as_result()?
284        else {
285            anyhow::bail!("Got unexpected response type.")
286        };
287
288        Ok((pcm, rate, channels))
289    }
290
291    /// Generate embeddings for one or more inputs configured via an [`EmbeddingRequestBuilder`].
292    ///
293    /// Returns one embedding vector per input in the same order they were added.
294    pub async fn generate_embeddings(
295        &self,
296        request: EmbeddingRequestBuilder,
297    ) -> anyhow::Result<Vec<Vec<f32>>> {
298        let request = request.build()?;
299        let EmbeddingRequest {
300            inputs,
301            truncate_sequence,
302        } = request;
303
304        let runner = self.runner.clone();
305        let futures = inputs.into_iter().map(|input| {
306            let runner = runner.clone();
307            async move {
308                let message = input.into_request_message();
309                let (tx, mut rx) = channel(1);
310
311                let request = Request::Normal(Box::new(NormalRequest {
312                    id: 0,
313                    messages: message,
314                    sampling_params: SamplingParams::deterministic(),
315                    response: tx,
316                    return_logprobs: false,
317                    is_streaming: false,
318                    suffix: None,
319                    constraint: Constraint::None,
320                    tool_choice: None,
321                    tools: None,
322                    logits_processors: None,
323                    return_raw_logits: false,
324                    web_search_options: None,
325                    model_id: None,
326                    truncate_sequence,
327                }));
328
329                runner
330                    .get_sender(None)?
331                    .send(request)
332                    .await
333                    .map_err(|e| anyhow::anyhow!(e.to_string()))?;
334
335                let ResponseOk::Embeddings { embeddings, .. } = rx
336                    .recv()
337                    .await
338                    .context("Channel was erroneously closed!")?
339                    .as_result()?
340                else {
341                    anyhow::bail!("Got unexpected response type.")
342                };
343
344                Ok::<Vec<f32>, anyhow::Error>(embeddings)
345            }
346        });
347
348        let results = join_all(futures).await;
349        let mut embeddings = Vec::with_capacity(results.len());
350        for result in results {
351            embeddings.push(result?);
352        }
353        Ok(embeddings)
354    }
355
356    /// Convenience wrapper for generating a single embedding.
357    pub async fn generate_embedding(&self, prompt: impl ToString) -> anyhow::Result<Vec<f32>> {
358        let mut embeddings = self
359            .generate_embeddings(EmbeddingRequest::builder().add_prompt(prompt.to_string()))
360            .await?;
361
362        Ok(embeddings
363            .pop()
364            .expect("EmbeddingRequestBuilder should guarantee at least one input"))
365    }
366
367    /// Reapply ISQ to the model. This will be done on whatever device the model is already on.
368    pub async fn re_isq_model(&self, isq_type: IsqType) -> anyhow::Result<()> {
369        let request = Request::ReIsq(isq_type);
370
371        Ok(self.runner.get_sender(None)?.send(request).await?)
372    }
373
374    /// Tokenize some text or messages.
375    /// - `tools` is only used if messages are provided.
376    pub async fn tokenize(
377        &self,
378        text: Either<TextMessages, String>,
379        tools: Option<Vec<Tool>>,
380        add_special_tokens: bool,
381        add_generation_prompt: bool,
382        enable_thinking: Option<bool>,
383    ) -> anyhow::Result<Vec<u32>> {
384        let (tx, mut rx) = channel(1);
385        let request = Request::Tokenize(TokenizationRequest {
386            text: text.map_left(Into::into),
387            tools,
388            add_special_tokens,
389            add_generation_prompt,
390            response: tx,
391            enable_thinking,
392        });
393        self.runner.get_sender(None)?.send(request).await?;
394
395        rx.recv().await.context("Channel was erroneously closed!")?
396    }
397
398    /// Detokenize some tokens.
399    pub async fn detokenize(
400        &self,
401        tokens: Vec<u32>,
402        skip_special_tokens: bool,
403    ) -> anyhow::Result<String> {
404        let (tx, mut rx) = channel(1);
405        let request = Request::Detokenize(DetokenizationRequest {
406            tokens,
407            skip_special_tokens,
408            response: tx,
409        });
410        self.runner.get_sender(None)?.send(request).await?;
411
412        rx.recv().await.context("Channel was erroneously closed!")?
413    }
414
415    /// Retrieve some information about this model.
416    pub fn config(&self) -> std::result::Result<MistralRsConfig, String> {
417        self.runner.config(None)
418    }
419
420    pub fn inner(&self) -> &MistralRs {
421        &self.runner
422    }
423}