// mistralrs/model.rs

use anyhow::Context;
use candle_core::{Device, Result, Tensor};
use either::Either;
use mistralrs_core::*;
use std::sync::Arc;
use tokio::sync::mpsc::{channel, Receiver};

use crate::{RequestLike, TextMessages};

/// Gets the best device: Metal if compiled with the `metal` feature, otherwise CUDA if available, otherwise CPU.
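///
/// A minimal usage sketch, assuming this function is re-exported at the crate root:
/// ```ignore
/// // Pass `true` to force CPU even when an accelerator is available.
/// let device = best_device(false)?;
/// println!("Running on {device:?}");
/// ```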
pub fn best_device(force_cpu: bool) -> Result<Device> {
    if force_cpu {
        return Ok(Device::Cpu);
    }
    #[cfg(not(feature = "metal"))]
    {
        Device::cuda_if_available(0)
    }
    #[cfg(feature = "metal")]
    {
        Device::new_metal(0)
    }
}

/// The object used to interact with the model. This can be used with many varieties of models,
/// and as such may be created with one of:
/// - [`TextModelBuilder`]
/// - [`LoraModelBuilder`]
/// - [`XLoraModelBuilder`]
/// - [`GgufModelBuilder`]
/// - [`GgufLoraModelBuilder`]
/// - [`GgufXLoraModelBuilder`]
/// - [`VisionModelBuilder`]
/// - [`AnyMoeModelBuilder`]
///
/// [`TextModelBuilder`]: crate::TextModelBuilder
/// [`LoraModelBuilder`]: crate::LoraModelBuilder
/// [`XLoraModelBuilder`]: crate::XLoraModelBuilder
/// [`GgufModelBuilder`]: crate::GgufModelBuilder
/// [`GgufLoraModelBuilder`]: crate::GgufLoraModelBuilder
/// [`GgufXLoraModelBuilder`]: crate::GgufXLoraModelBuilder
/// [`VisionModelBuilder`]: crate::VisionModelBuilder
/// [`AnyMoeModelBuilder`]: crate::AnyMoeModelBuilder
pub struct Model {
    pub(crate) runner: Arc<MistralRs>,
}

/// A stream of [`Response`] items produced by [`Model::stream_chat_request`].
pub struct Stream<'a> {
    _server: &'a Model,
    rx: Receiver<Response>,
}

impl Stream<'_> {
    /// Receive the next response, or `None` once the stream is finished.
    pub async fn next(&mut self) -> Option<Response> {
        self.rx.recv().await
    }
}

impl Model {
    pub fn new(runner: Arc<MistralRs>) -> Self {
        Self { runner }
    }

    /// Generate with the model, returning a [`Stream`] of responses as they are produced.
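    ///
    /// A minimal sketch of consuming the stream; `TextMessages` and
    /// `TextMessageRole` come from this crate, and the exact shape of the
    /// chunk fields may vary between versions:
    /// ```ignore
    /// let messages = TextMessages::new()
    ///     .add_message(TextMessageRole::User, "Hello!");
    /// let mut stream = model.stream_chat_request(messages).await?;
    /// while let Some(response) = stream.next().await {
    ///     if let Response::Chunk(chunk) = response {
    ///         print!("{}", chunk.choices[0].delta.content);
    ///     }
    /// }
    /// ```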
    pub async fn stream_chat_request<R: RequestLike>(
        &self,
        mut request: R,
    ) -> anyhow::Result<Stream> {
        let (tx, rx) = channel(1);

        let (tools, tool_choice) = if let Some((a, b)) = request.take_tools() {
            (Some(a), Some(b))
        } else {
            (None, None)
        };
        let request = Request::Normal(Box::new(NormalRequest {
            messages: request.take_messages(),
            sampling_params: request.take_sampling_params(),
            response: tx,
            return_logprobs: request.return_logprobs(),
            is_streaming: true,
            id: 0,
            constraint: request.take_constraint(),
            suffix: None,
            tools,
            tool_choice,
            logits_processors: request.take_logits_processors(),
            return_raw_logits: false,
            web_search_options: request.take_web_search_options(),
            model_id: None,
        }));

        self.runner.get_sender(None)?.send(request).await?;

        let stream = Stream { _server: self, rx };

        Ok(stream)
    }

    /// Generate with the model, returning the complete response once generation finishes.
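    ///
    /// A minimal sketch; the response mirrors an OpenAI-style chat completion,
    /// so the exact field names are assumptions that may vary by version:
    /// ```ignore
    /// let messages = TextMessages::new()
    ///     .add_message(TextMessageRole::User, "Why is the sky blue?");
    /// let response = model.send_chat_request(messages).await?;
    /// println!("{:?}", response.choices[0].message.content);
    /// ```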
    pub async fn send_chat_request<R: RequestLike>(
        &self,
        mut request: R,
    ) -> anyhow::Result<ChatCompletionResponse> {
        let (tx, mut rx) = channel(1);

        let (tools, tool_choice) = if let Some((a, b)) = request.take_tools() {
            (Some(a), Some(b))
        } else {
            (None, None)
        };
        let request = Request::Normal(Box::new(NormalRequest {
            messages: request.take_messages(),
            sampling_params: request.take_sampling_params(),
            response: tx,
            return_logprobs: request.return_logprobs(),
            is_streaming: false,
            id: 0,
            constraint: request.take_constraint(),
            suffix: None,
            tools,
            tool_choice,
            logits_processors: request.take_logits_processors(),
            return_raw_logits: false,
            web_search_options: request.take_web_search_options(),
            model_id: None,
        }));

        self.runner.get_sender(None)?.send(request).await?;

        let ResponseOk::Done(response) = rx
            .recv()
            .await
            .context("Channel was erroneously closed!")?
            .as_result()?
        else {
            anyhow::bail!("Got unexpected response type.")
        };

        Ok(response)
    }

    /// Generate with the model, returning the raw logits of the first token generated.
    ///
    /// Returns the logits chunks (1 or more, determined by the prompt batch size) and the tokens.
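    ///
    /// A short sketch of inspecting the result, given a request like the ones
    /// above; shapes depend on the model and the prompt batch size:
    /// ```ignore
    /// let (logits_chunks, tokens) = model.send_raw_chat_request(messages).await?;
    /// println!(
    ///     "{} logits chunk(s) for {} token(s)",
    ///     logits_chunks.len(),
    ///     tokens.len()
    /// );
    /// ```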
    pub async fn send_raw_chat_request<R: RequestLike>(
        &self,
        mut request: R,
    ) -> anyhow::Result<(Vec<Tensor>, Vec<u32>)> {
        let (tx, mut rx) = channel(1);

        let (tools, tool_choice) = if let Some((a, b)) = request.take_tools() {
            (Some(a), Some(b))
        } else {
            (None, None)
        };
        let request = Request::Normal(Box::new(NormalRequest {
            messages: request.take_messages(),
            sampling_params: request.take_sampling_params(),
            response: tx,
            return_logprobs: request.return_logprobs(),
            is_streaming: false,
            id: 0,
            constraint: request.take_constraint(),
            suffix: None,
            tools,
            tool_choice,
            logits_processors: request.take_logits_processors(),
            return_raw_logits: true,
            web_search_options: request.take_web_search_options(),
            model_id: None,
        }));

        self.runner.get_sender(None)?.send(request).await?;

        let ResponseOk::Raw {
            logits_chunks,
            tokens,
        } = rx
            .recv()
            .await
            .context("Channel was erroneously closed!")?
            .as_result()?
        else {
            anyhow::bail!("Got unexpected response type.")
        };

        Ok((logits_chunks, tokens))
    }

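    /// Generate an image with an image generation (diffusion) model.
    ///
    /// A minimal sketch; the prompt and parameter values are illustrative:
    /// ```ignore
    /// let response = model
    ///     .generate_image(
    ///         "A rusty robot holding a candle",
    ///         ImageGenerationResponseFormat::Url,
    ///         DiffusionGenerationParams::default(),
    ///     )
    ///     .await?;
    /// ```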
    pub async fn generate_image(
        &self,
        prompt: impl ToString,
        response_format: ImageGenerationResponseFormat,
        generation_params: DiffusionGenerationParams,
    ) -> anyhow::Result<ImageGenerationResponse> {
        let (tx, mut rx) = channel(1);

        let request = Request::Normal(Box::new(NormalRequest {
            id: 0,
            messages: RequestMessage::ImageGeneration {
                prompt: prompt.to_string(),
                format: response_format,
                generation_params,
            },
            sampling_params: SamplingParams::deterministic(),
            response: tx,
            return_logprobs: false,
            is_streaming: false,
            suffix: None,
            constraint: Constraint::None,
            tool_choice: None,
            tools: None,
            logits_processors: None,
            return_raw_logits: false,
            web_search_options: None,
            model_id: None,
        }));

        self.runner.get_sender(None)?.send(request).await?;

        let ResponseOk::ImageGeneration(response) = rx
            .recv()
            .await
            .context("Channel was erroneously closed!")?
            .as_result()?
        else {
            anyhow::bail!("Got unexpected response type.")
        };

        Ok(response)
    }

    /// Generate audio given a (model-specific) prompt.
    ///
    /// Returns `(pcm, sampling rate, channels)`.
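    ///
    /// A minimal sketch; writing the PCM samples to a file is left to the caller:
    /// ```ignore
    /// let (pcm, rate, channels) = model.generate_speech("Hello, world!").await?;
    /// println!("{} samples at {} Hz, {} channel(s)", pcm.len(), rate, channels);
    /// ```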
    pub async fn generate_speech(
        &self,
        prompt: impl ToString,
    ) -> anyhow::Result<(Arc<Vec<f32>>, usize, usize)> {
        let (tx, mut rx) = channel(1);

        let request = Request::Normal(Box::new(NormalRequest {
            id: 0,
            messages: RequestMessage::SpeechGeneration {
                prompt: prompt.to_string(),
            },
            sampling_params: SamplingParams::deterministic(),
            response: tx,
            return_logprobs: false,
            is_streaming: false,
            suffix: None,
            constraint: Constraint::None,
            tool_choice: None,
            tools: None,
            logits_processors: None,
            return_raw_logits: false,
            web_search_options: None,
            model_id: None,
        }));

        self.runner.get_sender(None)?.send(request).await?;

        let ResponseOk::Speech {
            pcm,
            rate,
            channels,
        } = rx
            .recv()
            .await
            .context("Channel was erroneously closed!")?
            .as_result()?
        else {
            anyhow::bail!("Got unexpected response type.")
        };

        Ok((pcm, rate, channels))
    }

    /// Reapply ISQ to the model. This will be done on whatever device the model is already on.
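    ///
    /// For example, to requantize to a 4-bit format (the variant name here is
    /// illustrative; use whichever `IsqType` variant your build provides):
    /// ```ignore
    /// model.re_isq_model(IsqType::Q4K).await?;
    /// ```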
    pub async fn re_isq_model(&self, isq_type: IsqType) -> anyhow::Result<()> {
        let request = Request::ReIsq(isq_type);

        Ok(self.runner.get_sender(None)?.send(request).await?)
    }

    /// Tokenize some text or messages.
    /// - `tools` is only used if messages are provided.
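    ///
    /// A sketch of tokenizing plain text via the `Either` right arm
    /// (the flag values are illustrative):
    /// ```ignore
    /// let tokens = model
    ///     .tokenize(Either::Right("Hello!".to_string()), None, true, false, None)
    ///     .await?;
    /// ```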
    pub async fn tokenize(
        &self,
        text: Either<TextMessages, String>,
        tools: Option<Vec<Tool>>,
        add_special_tokens: bool,
        add_generation_prompt: bool,
        enable_thinking: Option<bool>,
    ) -> anyhow::Result<Vec<u32>> {
        let (tx, mut rx) = channel(1);
        let request = Request::Tokenize(TokenizationRequest {
            text: text.map_left(Into::into),
            tools,
            add_special_tokens,
            add_generation_prompt,
            response: tx,
            enable_thinking,
        });
        self.runner.get_sender(None)?.send(request).await?;

        rx.recv().await.context("Channel was erroneously closed!")?
    }

    /// Detokenize some tokens.
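    ///
    /// A sketch of round-tripping tokens back to text:
    /// ```ignore
    /// let text = model.detokenize(tokens, true).await?;
    /// println!("{text}");
    /// ```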
    pub async fn detokenize(
        &self,
        tokens: Vec<u32>,
        skip_special_tokens: bool,
    ) -> anyhow::Result<String> {
        let (tx, mut rx) = channel(1);
        let request = Request::Detokenize(DetokenizationRequest {
            tokens,
            skip_special_tokens,
            response: tx,
        });
        self.runner.get_sender(None)?.send(request).await?;

        rx.recv().await.context("Channel was erroneously closed!")?
    }

    /// Retrieve some information about this model.
    pub fn config(&self) -> std::result::Result<MistralRsConfig, String> {
        self.runner.config(None)
    }

    /// Access the underlying [`MistralRs`] runner.
    pub fn inner(&self) -> &MistralRs {
        &self.runner
    }
}