// mistralrs/model.rs

use anyhow::Context;
use candle_core::{Device, Result, Tensor};
use either::Either;
use mistralrs_core::*;
use std::sync::Arc;
use tokio::sync::mpsc::{channel, Receiver};

use crate::{RequestLike, TextMessages};

/// Gets the best available device: CPU if forced, otherwise CUDA when compiled with CUDA support, or Metal when compiled with the `metal` feature.
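///
/// A minimal sketch (assumes this function is re-exported at the crate root):
///
/// ```no_run
/// # fn run() -> candle_core::Result<()> {
/// // Prefer an accelerator when available; fall back to the CPU.
/// let device = mistralrs::best_device(false)?;
/// # Ok(())
/// # }
/// ```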
pub fn best_device(force_cpu: bool) -> Result<Device> {
    if force_cpu {
        return Ok(Device::Cpu);
    }
    #[cfg(not(feature = "metal"))]
    {
        Device::cuda_if_available(0)
    }
    #[cfg(feature = "metal")]
    {
        Device::new_metal(0)
    }
}

/// The object used to interact with the model. This can be used with many varieties of models,
/// and as such may be created with one of:
/// - [`TextModelBuilder`]
/// - [`LoraModelBuilder`]
/// - [`XLoraModelBuilder`]
/// - [`GgufModelBuilder`]
/// - [`GgufLoraModelBuilder`]
/// - [`GgufXLoraModelBuilder`]
/// - [`VisionModelBuilder`]
/// - [`AnyMoeModelBuilder`]
///
/// [`TextModelBuilder`]: crate::TextModelBuilder
/// [`LoraModelBuilder`]: crate::LoraModelBuilder
/// [`XLoraModelBuilder`]: crate::XLoraModelBuilder
/// [`GgufModelBuilder`]: crate::GgufModelBuilder
/// [`GgufLoraModelBuilder`]: crate::GgufLoraModelBuilder
/// [`GgufXLoraModelBuilder`]: crate::GgufXLoraModelBuilder
/// [`VisionModelBuilder`]: crate::VisionModelBuilder
/// [`AnyMoeModelBuilder`]: crate::AnyMoeModelBuilder
///
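/// A minimal sketch of building a text model (the model ID and ISQ level are illustrative):
///
/// ```no_run
/// # async fn run() -> anyhow::Result<()> {
/// use mistralrs::{IsqType, TextModelBuilder};
///
/// let model = TextModelBuilder::new("microsoft/Phi-3.5-mini-instruct")
///     .with_isq(IsqType::Q8_0) // optional in-situ quantization
///     .build()
///     .await?;
/// # Ok(())
/// # }
/// ```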
pub struct Model {
    runner: Arc<MistralRs>,
}

/// A stream of [`Response`] objects produced by an in-flight request.
pub struct Stream<'a> {
    _server: &'a Model,
    rx: Receiver<Response>,
}

impl Stream<'_> {
    /// Receive the next response, or `None` once the channel has closed.
    pub async fn next(&mut self) -> Option<Response> {
        self.rx.recv().await
    }
}

impl Model {
    pub fn new(runner: Arc<MistralRs>) -> Self {
        Self { runner }
    }

    /// Generate with the model, returning a stream of responses as they are produced.
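    ///
    /// A minimal sketch of consuming the stream (the message setup is illustrative, and the
    /// `delta.content` field access assumes the current chunk layout):
    ///
    /// ```no_run
    /// # async fn run(model: &mistralrs::Model) -> anyhow::Result<()> {
    /// use mistralrs::{Response, TextMessageRole, TextMessages};
    ///
    /// let messages = TextMessages::new().add_message(TextMessageRole::User, "Hello!");
    /// let mut stream = model.stream_chat_request(messages).await?;
    /// while let Some(response) = stream.next().await {
    ///     if let Response::Chunk(chunk) = response {
    ///         print!("{}", chunk.choices[0].delta.content);
    ///     }
    /// }
    /// # Ok(())
    /// # }
    /// ```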
    pub async fn stream_chat_request<R: RequestLike>(
        &self,
        mut request: R,
    ) -> anyhow::Result<Stream> {
        let (tx, rx) = channel(1);

        let (tools, tool_choice) = if let Some((a, b)) = request.take_tools() {
            (Some(a), Some(b))
        } else {
            (None, None)
        };
        let request = Request::Normal(NormalRequest {
            messages: request.take_messages(),
            sampling_params: request.take_sampling_params(),
            response: tx,
            return_logprobs: request.return_logprobs(),
            is_streaming: true,
            id: 0,
            constraint: request.take_constraint(),
            suffix: None,
            tools,
            tool_choice,
            logits_processors: request.take_logits_processors(),
            return_raw_logits: false,
            web_search_options: request.take_web_search_options(),
        });

        self.runner.get_sender()?.send(request).await?;

        let stream = Stream { _server: self, rx };

        Ok(stream)
    }

    /// Generate with the model, returning the complete response once generation finishes.
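    ///
    /// A minimal sketch of a blocking chat request (the message setup is illustrative):
    ///
    /// ```no_run
    /// # async fn run(model: &mistralrs::Model) -> anyhow::Result<()> {
    /// use mistralrs::{TextMessageRole, TextMessages};
    ///
    /// let messages = TextMessages::new().add_message(TextMessageRole::User, "Hello!");
    /// let response = model.send_chat_request(messages).await?;
    /// println!("{:?}", response.choices[0].message.content);
    /// # Ok(())
    /// # }
    /// ```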
    pub async fn send_chat_request<R: RequestLike>(
        &self,
        mut request: R,
    ) -> anyhow::Result<ChatCompletionResponse> {
        let (tx, mut rx) = channel(1);

        let (tools, tool_choice) = if let Some((a, b)) = request.take_tools() {
            (Some(a), Some(b))
        } else {
            (None, None)
        };
        let request = Request::Normal(NormalRequest {
            messages: request.take_messages(),
            sampling_params: request.take_sampling_params(),
            response: tx,
            return_logprobs: request.return_logprobs(),
            is_streaming: false,
            id: 0,
            constraint: request.take_constraint(),
            suffix: None,
            tools,
            tool_choice,
            logits_processors: request.take_logits_processors(),
            return_raw_logits: false,
            web_search_options: request.take_web_search_options(),
        });

        self.runner.get_sender()?.send(request).await?;

        let ResponseOk::Done(response) = rx
            .recv()
            .await
            .context("Channel was erroneously closed!")?
            .as_result()?
        else {
            anyhow::bail!("Got unexpected response type.")
        };

        Ok(response)
    }

    /// Generate with the model, returning the raw logits of the first token generated.
    ///
    /// Returns the logits chunks (one or more, determined by the prompt batch size) and the tokens.
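    ///
    /// A minimal sketch (the message setup is illustrative):
    ///
    /// ```no_run
    /// # async fn run(model: &mistralrs::Model) -> anyhow::Result<()> {
    /// use mistralrs::{TextMessageRole, TextMessages};
    ///
    /// let messages = TextMessages::new().add_message(TextMessageRole::User, "Hello!");
    /// let (logits_chunks, tokens) = model.send_raw_chat_request(messages).await?;
    /// println!("{} logits chunk(s) for {} token(s)", logits_chunks.len(), tokens.len());
    /// # Ok(())
    /// # }
    /// ```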
    pub async fn send_raw_chat_request<R: RequestLike>(
        &self,
        mut request: R,
    ) -> anyhow::Result<(Vec<Tensor>, Vec<u32>)> {
        let (tx, mut rx) = channel(1);

        let (tools, tool_choice) = if let Some((a, b)) = request.take_tools() {
            (Some(a), Some(b))
        } else {
            (None, None)
        };
        let request = Request::Normal(NormalRequest {
            messages: request.take_messages(),
            sampling_params: request.take_sampling_params(),
            response: tx,
            return_logprobs: request.return_logprobs(),
            is_streaming: false,
            id: 0,
            constraint: request.take_constraint(),
            suffix: None,
            tools,
            tool_choice,
            logits_processors: request.take_logits_processors(),
            return_raw_logits: true,
            web_search_options: request.take_web_search_options(),
        });

        self.runner.get_sender()?.send(request).await?;

        let ResponseOk::Raw {
            logits_chunks,
            tokens,
        } = rx
            .recv()
            .await
            .context("Channel was erroneously closed!")?
            .as_result()?
        else {
            anyhow::bail!("Got unexpected response type.")
        };

        Ok((logits_chunks, tokens))
    }

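    /// Generate an image with an image generation model.
    ///
    /// A minimal sketch (the prompt and parameters are illustrative):
    ///
    /// ```no_run
    /// # async fn run(model: &mistralrs::Model) -> anyhow::Result<()> {
    /// use mistralrs::{DiffusionGenerationParams, ImageGenerationResponseFormat};
    ///
    /// let _image = model
    ///     .generate_image(
    ///         "A swiftly flowing river",
    ///         ImageGenerationResponseFormat::Url,
    ///         DiffusionGenerationParams::default(),
    ///     )
    ///     .await?;
    /// # Ok(())
    /// # }
    /// ```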
    pub async fn generate_image(
        &self,
        prompt: impl ToString,
        response_format: ImageGenerationResponseFormat,
        generation_params: DiffusionGenerationParams,
    ) -> anyhow::Result<ImageGenerationResponse> {
        let (tx, mut rx) = channel(1);

        let request = Request::Normal(NormalRequest {
            id: 0,
            messages: RequestMessage::ImageGeneration {
                prompt: prompt.to_string(),
                format: response_format,
                generation_params,
            },
            sampling_params: SamplingParams::deterministic(),
            response: tx,
            return_logprobs: false,
            is_streaming: false,
            suffix: None,
            constraint: Constraint::None,
            tool_choice: None,
            tools: None,
            logits_processors: None,
            return_raw_logits: false,
            web_search_options: None,
        });

        self.runner.get_sender()?.send(request).await?;

        let ResponseOk::ImageGeneration(response) = rx
            .recv()
            .await
            .context("Channel was erroneously closed!")?
            .as_result()?
        else {
            anyhow::bail!("Got unexpected response type.")
        };

        Ok(response)
    }

    /// Reapply in-situ quantization (ISQ) to the model. This is done on whatever device the model is already on.
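    ///
    /// A minimal sketch (the quantization level is illustrative):
    ///
    /// ```no_run
    /// # async fn run(model: &mistralrs::Model) -> anyhow::Result<()> {
    /// use mistralrs::IsqType;
    ///
    /// model.re_isq_model(IsqType::Q8_0).await?;
    /// # Ok(())
    /// # }
    /// ```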
    pub async fn re_isq_model(&self, isq_type: IsqType) -> anyhow::Result<()> {
        let request = Request::ReIsq(isq_type);

        Ok(self.runner.get_sender()?.send(request).await?)
    }

    /// Tokenize some text or messages.
    /// - `tools` is only used if messages are provided.
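    ///
    /// A minimal sketch, tokenizing plain text (the flag values are illustrative):
    ///
    /// ```no_run
    /// # async fn run(model: &mistralrs::Model) -> anyhow::Result<()> {
    /// use either::Either;
    ///
    /// let tokens = model
    ///     .tokenize(Either::Right("Hello!".to_string()), None, true, false)
    ///     .await?;
    /// println!("{} token(s)", tokens.len());
    /// # Ok(())
    /// # }
    /// ```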
    pub async fn tokenize(
        &self,
        text: Either<TextMessages, String>,
        tools: Option<Vec<Tool>>,
        add_special_tokens: bool,
        add_generation_prompt: bool,
    ) -> anyhow::Result<Vec<u32>> {
        let (tx, mut rx) = channel(1);
        let request = Request::Tokenize(TokenizationRequest {
            text: text.map_left(Into::into),
            tools,
            add_special_tokens,
            add_generation_prompt,
            response: tx,
        });
        self.runner.get_sender()?.send(request).await?;

        rx.recv().await.context("Channel was erroneously closed!")?
    }

    /// Detokenize some tokens.
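    ///
    /// A minimal sketch, round-tripping through [`Model::tokenize`] (the text and flags are
    /// illustrative):
    ///
    /// ```no_run
    /// # async fn run(model: &mistralrs::Model) -> anyhow::Result<()> {
    /// use either::Either;
    ///
    /// let tokens = model
    ///     .tokenize(Either::Right("Hello!".to_string()), None, true, false)
    ///     .await?;
    /// let text = model.detokenize(tokens, true).await?;
    /// println!("{text}");
    /// # Ok(())
    /// # }
    /// ```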
    pub async fn detokenize(
        &self,
        tokens: Vec<u32>,
        skip_special_tokens: bool,
    ) -> anyhow::Result<String> {
        let (tx, mut rx) = channel(1);
        let request = Request::Detokenize(DetokenizationRequest {
            tokens,
            skip_special_tokens,
            response: tx,
        });
        self.runner.get_sender()?.send(request).await?;

        rx.recv().await.context("Channel was erroneously closed!")?
    }

    /// Retrieve configuration information about this model.
    pub fn config(&self) -> &MistralRsConfig {
        self.runner.config()
    }

    /// Get a reference to the underlying [`MistralRs`] runner.
    pub fn inner(&self) -> &MistralRs {
        &self.runner
    }
}