// mistralrs/model.rs

use anyhow::Context;
use candle_core::{Device, Result, Tensor};
use either::Either;
use futures::future::join_all;
use mistralrs_core::*;
use std::sync::Arc;
use tokio::sync::mpsc::{channel, Receiver};

use crate::{EmbeddingRequest, EmbeddingRequestBuilder, RequestLike, TextMessages};

// Re-export for convenience
pub use mistralrs_core::{AddModelConfig, ModelStatus, Pipeline, SchedulerConfig};

/// Gets the best available device: CUDA if compiled with CUDA support, Metal if
/// compiled with the `metal` feature, otherwise CPU.
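///
/// A minimal usage sketch (hedged: assumes the crate root re-exports `best_device`;
/// `force_cpu = false` picks the accelerator if one was compiled in):
///
/// ```no_run
/// # fn run() -> anyhow::Result<()> {
/// let device = mistralrs::best_device(false)?;
/// println!("selected device: {device:?}");
/// # Ok(())
/// # }
/// ```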
pub fn best_device(force_cpu: bool) -> Result<Device> {
    if force_cpu {
        return Ok(Device::Cpu);
    }
    #[cfg(not(feature = "metal"))]
    {
        Device::cuda_if_available(0)
    }
    #[cfg(feature = "metal")]
    {
        Device::new_metal(0)
    }
}

/// The object used to interact with the model. This can be used with many varieties
/// of models, and as such may be created with one of:
/// - [`TextModelBuilder`]
/// - [`LoraModelBuilder`]
/// - [`XLoraModelBuilder`]
/// - [`GgufModelBuilder`]
/// - [`GgufLoraModelBuilder`]
/// - [`GgufXLoraModelBuilder`]
/// - [`VisionModelBuilder`]
/// - [`AnyMoeModelBuilder`]
///
/// [`TextModelBuilder`]: crate::TextModelBuilder
/// [`LoraModelBuilder`]: crate::LoraModelBuilder
/// [`XLoraModelBuilder`]: crate::XLoraModelBuilder
/// [`GgufModelBuilder`]: crate::GgufModelBuilder
/// [`GgufLoraModelBuilder`]: crate::GgufLoraModelBuilder
/// [`GgufXLoraModelBuilder`]: crate::GgufXLoraModelBuilder
/// [`VisionModelBuilder`]: crate::VisionModelBuilder
/// [`AnyMoeModelBuilder`]: crate::AnyMoeModelBuilder
pub struct Model {
    pub(crate) runner: Arc<MistralRs>,
}

/// A handle to an in-flight streaming request; yields [`Response`]s as they arrive.
pub struct Stream<'a> {
    _server: &'a Model,
    rx: Receiver<Response>,
}

impl Stream<'_> {
    /// Receive the next response in the stream, or `None` once the request completes
    /// and the channel closes.
    pub async fn next(&mut self) -> Option<Response> {
        self.rx.recv().await
    }
}

impl Model {
    pub fn new(runner: Arc<MistralRs>) -> Self {
        Self { runner }
    }

    // ========================================================================
    // Chat Request Methods
    // ========================================================================

    /// Generate with the model (streaming).
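    ///
    /// A minimal usage sketch (hedged: assumes a `Model` already built, e.g. via a
    /// [`TextModelBuilder`]; the `Response::Chunk` variant and `delta.content` field
    /// follow `mistralrs_core`'s response types and may differ by version):
    ///
    /// ```no_run
    /// # use mistralrs::{Model, Response, TextMessageRole, TextMessages};
    /// # async fn run(model: &Model) -> anyhow::Result<()> {
    /// let messages = TextMessages::new()
    ///     .add_message(TextMessageRole::User, "Tell me a short story.");
    ///
    /// let mut stream = model.stream_chat_request(messages).await?;
    /// while let Some(response) = stream.next().await {
    ///     if let Response::Chunk(chunk) = response {
    ///         // Print each delta as it arrives.
    ///         print!("{}", chunk.choices[0].delta.content.as_deref().unwrap_or(""));
    ///     }
    /// }
    /// # Ok(())
    /// # }
    /// ```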
    pub async fn stream_chat_request<R: RequestLike>(
        &self,
        request: R,
    ) -> anyhow::Result<Stream<'_>> {
        self.stream_chat_request_with_model(request, None).await
    }

    /// Generate with a specific model (streaming).
    /// If `model_id` is `None`, the request is sent to the default model.
    pub async fn stream_chat_request_with_model<R: RequestLike>(
        &self,
        mut request: R,
        model_id: Option<&str>,
    ) -> anyhow::Result<Stream<'_>> {
        let (tx, rx) = channel(1);

        let truncate_sequence = request.truncate_sequence();
        let (tools, tool_choice) = if let Some((a, b)) = request.take_tools() {
            (Some(a), Some(b))
        } else {
            (None, None)
        };
        let request = Request::Normal(Box::new(NormalRequest {
            messages: request.take_messages(),
            sampling_params: request.take_sampling_params(),
            response: tx,
            return_logprobs: request.return_logprobs(),
            is_streaming: true,
            id: 0,
            constraint: request.take_constraint(),
            suffix: None,
            tools,
            tool_choice,
            logits_processors: request.take_logits_processors(),
            return_raw_logits: false,
            web_search_options: request.take_web_search_options(),
            model_id: model_id.map(|s| s.to_string()),
            truncate_sequence,
        }));

        self.runner.get_sender(model_id)?.send(request).await?;

        let stream = Stream { _server: self, rx };

        Ok(stream)
    }

    /// Generate with the model (non-streaming).
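    ///
    /// A minimal usage sketch (hedged: assumes a `Model` already built; the
    /// `choices[0].message.content` access follows `mistralrs_core`'s response types):
    ///
    /// ```no_run
    /// # use mistralrs::{Model, TextMessageRole, TextMessages};
    /// # async fn run(model: &Model) -> anyhow::Result<()> {
    /// let messages = TextMessages::new()
    ///     .add_message(TextMessageRole::System, "You are a helpful assistant.")
    ///     .add_message(TextMessageRole::User, "What is graphene?");
    ///
    /// let response = model.send_chat_request(messages).await?;
    /// println!("{}", response.choices[0].message.content.as_deref().unwrap_or(""));
    /// # Ok(())
    /// # }
    /// ```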
    pub async fn send_chat_request<R: RequestLike>(
        &self,
        request: R,
    ) -> anyhow::Result<ChatCompletionResponse> {
        self.send_chat_request_with_model(request, None).await
    }

    /// Send a chat request to a specific model.
    /// If `model_id` is `None`, the request is sent to the default model.
    pub async fn send_chat_request_with_model<R: RequestLike>(
        &self,
        mut request: R,
        model_id: Option<&str>,
    ) -> anyhow::Result<ChatCompletionResponse> {
        let (tx, mut rx) = channel(1);

        let truncate_sequence = request.truncate_sequence();
        let (tools, tool_choice) = if let Some((a, b)) = request.take_tools() {
            (Some(a), Some(b))
        } else {
            (None, None)
        };
        let request = Request::Normal(Box::new(NormalRequest {
            messages: request.take_messages(),
            sampling_params: request.take_sampling_params(),
            response: tx,
            return_logprobs: request.return_logprobs(),
            is_streaming: false,
            id: 0,
            constraint: request.take_constraint(),
            suffix: None,
            tools,
            tool_choice,
            logits_processors: request.take_logits_processors(),
            return_raw_logits: false,
            web_search_options: request.take_web_search_options(),
            model_id: model_id.map(|s| s.to_string()),
            truncate_sequence,
        }));

        self.runner.get_sender(model_id)?.send(request).await?;

        let ResponseOk::Done(response) = rx
            .recv()
            .await
            .context("Channel was erroneously closed!")?
            .as_result()?
        else {
            anyhow::bail!("Got unexpected response type.")
        };

        Ok(response)
    }

    /// Generate with the model, returning the raw logits of the first token generated.
    ///
    /// Returns the logits chunks (1 or more, determined by the prompt batch size) and the tokens.
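    ///
    /// A minimal usage sketch (hedged: assumes a `Model` already built; the logits come
    /// back as one `Tensor` chunk per prompt batch):
    ///
    /// ```no_run
    /// # use mistralrs::{Model, TextMessageRole, TextMessages};
    /// # async fn run(model: &Model) -> anyhow::Result<()> {
    /// let messages = TextMessages::new().add_message(TextMessageRole::User, "Hi!");
    /// let (logits_chunks, tokens) = model.send_raw_chat_request(messages).await?;
    /// println!("{} logits chunk(s) for {} tokens", logits_chunks.len(), tokens.len());
    /// # Ok(())
    /// # }
    /// ```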
    pub async fn send_raw_chat_request<R: RequestLike>(
        &self,
        request: R,
    ) -> anyhow::Result<(Vec<Tensor>, Vec<u32>)> {
        self.send_raw_chat_request_with_model(request, None).await
    }

    /// Generate with a specific model, returning raw logits of the first token generated.
    /// If `model_id` is `None`, the request is sent to the default model.
    pub async fn send_raw_chat_request_with_model<R: RequestLike>(
        &self,
        mut request: R,
        model_id: Option<&str>,
    ) -> anyhow::Result<(Vec<Tensor>, Vec<u32>)> {
        let (tx, mut rx) = channel(1);

        let truncate_sequence = request.truncate_sequence();
        let (tools, tool_choice) = if let Some((a, b)) = request.take_tools() {
            (Some(a), Some(b))
        } else {
            (None, None)
        };
        let request = Request::Normal(Box::new(NormalRequest {
            messages: request.take_messages(),
            sampling_params: request.take_sampling_params(),
            response: tx,
            return_logprobs: request.return_logprobs(),
            is_streaming: false,
            id: 0,
            constraint: request.take_constraint(),
            suffix: None,
            tools,
            tool_choice,
            logits_processors: request.take_logits_processors(),
            return_raw_logits: true,
            web_search_options: request.take_web_search_options(),
            model_id: model_id.map(|s| s.to_string()),
            truncate_sequence,
        }));

        self.runner.get_sender(model_id)?.send(request).await?;

        let ResponseOk::Raw {
            logits_chunks,
            tokens,
        } = rx
            .recv()
            .await
            .context("Channel was erroneously closed!")?
            .as_result()?
        else {
            anyhow::bail!("Got unexpected response type.")
        };

        Ok((logits_chunks, tokens))
    }

    // ========================================================================
    // Image Generation Methods
    // ========================================================================

    /// Generate an image using the default model.
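    ///
    /// A minimal usage sketch (hedged: assumes a diffusion-capable `Model`; the
    /// `response.data[0].url` access follows `mistralrs_core`'s response types):
    ///
    /// ```no_run
    /// # use mistralrs::{DiffusionGenerationParams, ImageGenerationResponseFormat, Model};
    /// # async fn run(model: &Model) -> anyhow::Result<()> {
    /// let response = model
    ///     .generate_image(
    ///         "A vibrant sunset over snowy mountains, 4k, high detail.",
    ///         ImageGenerationResponseFormat::Url,
    ///         DiffusionGenerationParams::default(),
    ///     )
    ///     .await?;
    /// println!("image at: {}", response.data[0].url.as_deref().unwrap_or("<none>"));
    /// # Ok(())
    /// # }
    /// ```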
    pub async fn generate_image(
        &self,
        prompt: impl ToString,
        response_format: ImageGenerationResponseFormat,
        generation_params: DiffusionGenerationParams,
    ) -> anyhow::Result<ImageGenerationResponse> {
        self.generate_image_with_model(prompt, response_format, generation_params, None)
            .await
    }

    /// Generate an image using a specific model.
    /// If `model_id` is `None`, the request is sent to the default model.
    pub async fn generate_image_with_model(
        &self,
        prompt: impl ToString,
        response_format: ImageGenerationResponseFormat,
        generation_params: DiffusionGenerationParams,
        model_id: Option<&str>,
    ) -> anyhow::Result<ImageGenerationResponse> {
        let (tx, mut rx) = channel(1);

        let request = Request::Normal(Box::new(NormalRequest {
            id: 0,
            messages: RequestMessage::ImageGeneration {
                prompt: prompt.to_string(),
                format: response_format,
                generation_params,
            },
            sampling_params: SamplingParams::deterministic(),
            response: tx,
            return_logprobs: false,
            is_streaming: false,
            suffix: None,
            constraint: Constraint::None,
            tool_choice: None,
            tools: None,
            logits_processors: None,
            return_raw_logits: false,
            web_search_options: None,
            model_id: model_id.map(|s| s.to_string()),
            truncate_sequence: false,
        }));

        self.runner.get_sender(model_id)?.send(request).await?;

        let ResponseOk::ImageGeneration(response) = rx
            .recv()
            .await
            .context("Channel was erroneously closed!")?
            .as_result()?
        else {
            anyhow::bail!("Got unexpected response type.")
        };

        Ok(response)
    }

    // ========================================================================
    // Speech Generation Methods
    // ========================================================================

    /// Generate audio given a (model-specific) prompt.
    ///
    /// Returns `(pcm, sampling rate, channels)`.
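    ///
    /// A minimal usage sketch (hedged: assumes a speech-capable `Model`; writing the
    /// PCM samples to a WAV file is left out):
    ///
    /// ```no_run
    /// # use mistralrs::Model;
    /// # async fn run(model: &Model) -> anyhow::Result<()> {
    /// let (pcm, rate, channels) = model.generate_speech("Hello there!").await?;
    /// println!("{} samples at {} Hz, {} channel(s)", pcm.len(), rate, channels);
    /// # Ok(())
    /// # }
    /// ```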
    pub async fn generate_speech(
        &self,
        prompt: impl ToString,
    ) -> anyhow::Result<(Arc<Vec<f32>>, usize, usize)> {
        self.generate_speech_with_model(prompt, None).await
    }

    /// Generate audio given a (model-specific) prompt using a specific model.
    /// If `model_id` is `None`, the request is sent to the default model.
    ///
    /// Returns `(pcm, sampling rate, channels)`.
    pub async fn generate_speech_with_model(
        &self,
        prompt: impl ToString,
        model_id: Option<&str>,
    ) -> anyhow::Result<(Arc<Vec<f32>>, usize, usize)> {
        let (tx, mut rx) = channel(1);

        let request = Request::Normal(Box::new(NormalRequest {
            id: 0,
            messages: RequestMessage::SpeechGeneration {
                prompt: prompt.to_string(),
            },
            sampling_params: SamplingParams::deterministic(),
            response: tx,
            return_logprobs: false,
            is_streaming: false,
            suffix: None,
            constraint: Constraint::None,
            tool_choice: None,
            tools: None,
            logits_processors: None,
            return_raw_logits: false,
            web_search_options: None,
            model_id: model_id.map(|s| s.to_string()),
            truncate_sequence: false,
        }));

        self.runner.get_sender(model_id)?.send(request).await?;

        let ResponseOk::Speech {
            pcm,
            rate,
            channels,
        } = rx
            .recv()
            .await
            .context("Channel was erroneously closed!")?
            .as_result()?
        else {
            anyhow::bail!("Got unexpected response type.")
        };

        Ok((pcm, rate, channels))
    }

    // ========================================================================
    // Embedding Methods
    // ========================================================================

    /// Generate embeddings for one or more inputs configured via an [`EmbeddingRequestBuilder`].
    ///
    /// Returns one embedding vector per input in the same order they were added.
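    ///
    /// A minimal usage sketch (hedged: assumes an embedding-capable `Model`; the
    /// `add_prompt` builder method is used the same way elsewhere in this file):
    ///
    /// ```no_run
    /// # use mistralrs::{EmbeddingRequest, Model};
    /// # async fn run(model: &Model) -> anyhow::Result<()> {
    /// let request = EmbeddingRequest::builder()
    ///     .add_prompt("The quick brown fox")
    ///     .add_prompt("jumps over the lazy dog");
    /// let embeddings = model.generate_embeddings(request).await?;
    /// assert_eq!(embeddings.len(), 2); // one vector per input, in order
    /// # Ok(())
    /// # }
    /// ```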
    pub async fn generate_embeddings(
        &self,
        request: EmbeddingRequestBuilder,
    ) -> anyhow::Result<Vec<Vec<f32>>> {
        self.generate_embeddings_with_model(request, None).await
    }

    /// Generate embeddings for one or more inputs using a specific model.
    /// If `model_id` is `None`, the request is sent to the default model.
    ///
    /// Returns one embedding vector per input in the same order they were added.
    pub async fn generate_embeddings_with_model(
        &self,
        request: EmbeddingRequestBuilder,
        model_id: Option<&str>,
    ) -> anyhow::Result<Vec<Vec<f32>>> {
        let request = request.build()?;
        let EmbeddingRequest {
            inputs,
            truncate_sequence,
        } = request;

        let runner = self.runner.clone();
        let model_id_owned = model_id.map(|s| s.to_string());
        let futures = inputs.into_iter().map(|input| {
            let runner = runner.clone();
            let model_id_owned = model_id_owned.clone();
            async move {
                let message = input.into_request_message();
                let (tx, mut rx) = channel(1);

                let request = Request::Normal(Box::new(NormalRequest {
                    id: 0,
                    messages: message,
                    sampling_params: SamplingParams::deterministic(),
                    response: tx,
                    return_logprobs: false,
                    is_streaming: false,
                    suffix: None,
                    constraint: Constraint::None,
                    tool_choice: None,
                    tools: None,
                    logits_processors: None,
                    return_raw_logits: false,
                    web_search_options: None,
                    model_id: model_id_owned.clone(),
                    truncate_sequence,
                }));

                runner
                    .get_sender(model_id_owned.as_deref())?
                    .send(request)
                    .await
                    .map_err(|e| anyhow::anyhow!(e.to_string()))?;

                let ResponseOk::Embeddings { embeddings, .. } = rx
                    .recv()
                    .await
                    .context("Channel was erroneously closed!")?
                    .as_result()?
                else {
                    anyhow::bail!("Got unexpected response type.")
                };

                Ok::<Vec<f32>, anyhow::Error>(embeddings)
            }
        });

        let results = join_all(futures).await;
        let mut embeddings = Vec::with_capacity(results.len());
        for result in results {
            embeddings.push(result?);
        }
        Ok(embeddings)
    }

    /// Convenience wrapper for generating a single embedding.
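    ///
    /// A minimal usage sketch:
    ///
    /// ```no_run
    /// # use mistralrs::Model;
    /// # async fn run(model: &Model) -> anyhow::Result<()> {
    /// let embedding = model.generate_embedding("Hello, world!").await?;
    /// println!("dimension: {}", embedding.len());
    /// # Ok(())
    /// # }
    /// ```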
    pub async fn generate_embedding(&self, prompt: impl ToString) -> anyhow::Result<Vec<f32>> {
        self.generate_embedding_with_model(prompt, None).await
    }

    /// Convenience wrapper for generating a single embedding using a specific model.
    /// If `model_id` is `None`, the request is sent to the default model.
    pub async fn generate_embedding_with_model(
        &self,
        prompt: impl ToString,
        model_id: Option<&str>,
    ) -> anyhow::Result<Vec<f32>> {
        let mut embeddings = self
            .generate_embeddings_with_model(
                EmbeddingRequest::builder().add_prompt(prompt.to_string()),
                model_id,
            )
            .await?;

        Ok(embeddings
            .pop()
            .expect("EmbeddingRequestBuilder should guarantee at least one input"))
    }

    // ========================================================================
    // Model Management Methods
    // ========================================================================

    /// Reapply ISQ to the model. This will be done on whatever device the model is already on.
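    ///
    /// A minimal usage sketch (hedged: `IsqType::Q4K` is one quantization variant from
    /// `mistralrs_core`, shown purely as an example):
    ///
    /// ```no_run
    /// # use mistralrs::{IsqType, Model};
    /// # async fn run(model: &Model) -> anyhow::Result<()> {
    /// model.re_isq_model(IsqType::Q4K).await?;
    /// # Ok(())
    /// # }
    /// ```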
    pub async fn re_isq_model(&self, isq_type: IsqType) -> anyhow::Result<()> {
        self.re_isq_model_with_model(isq_type, None).await
    }

    /// Reapply ISQ to a specific model.
    /// If `model_id` is `None`, the request is sent to the default model.
    pub async fn re_isq_model_with_model(
        &self,
        isq_type: IsqType,
        model_id: Option<&str>,
    ) -> anyhow::Result<()> {
        let request = Request::ReIsq(isq_type);

        Ok(self.runner.get_sender(model_id)?.send(request).await?)
    }

    // ========================================================================
    // Tokenization Methods
    // ========================================================================

    /// Tokenize some text or messages.
    /// - `tools` is only used if messages are provided.
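    ///
    /// A minimal usage sketch (hedged: plain text goes in `Either::Right`; messages
    /// would go in `Either::Left`):
    ///
    /// ```no_run
    /// # use either::Either;
    /// # use mistralrs::Model;
    /// # async fn run(model: &Model) -> anyhow::Result<()> {
    /// let tokens = model
    ///     .tokenize(Either::Right("Hello, world!".to_string()), None, true, false, None)
    ///     .await?;
    /// println!("{} tokens", tokens.len());
    /// # Ok(())
    /// # }
    /// ```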
    pub async fn tokenize(
        &self,
        text: Either<TextMessages, String>,
        tools: Option<Vec<Tool>>,
        add_special_tokens: bool,
        add_generation_prompt: bool,
        enable_thinking: Option<bool>,
    ) -> anyhow::Result<Vec<u32>> {
        self.tokenize_with_model(
            text,
            tools,
            add_special_tokens,
            add_generation_prompt,
            enable_thinking,
            None,
        )
        .await
    }

    /// Tokenize some text or messages using a specific model.
    /// If `model_id` is `None`, the request is sent to the default model.
    /// - `tools` is only used if messages are provided.
    pub async fn tokenize_with_model(
        &self,
        text: Either<TextMessages, String>,
        tools: Option<Vec<Tool>>,
        add_special_tokens: bool,
        add_generation_prompt: bool,
        enable_thinking: Option<bool>,
        model_id: Option<&str>,
    ) -> anyhow::Result<Vec<u32>> {
        let (tx, mut rx) = channel(1);
        let request = Request::Tokenize(TokenizationRequest {
            text: text.map_left(Into::into),
            tools,
            add_special_tokens,
            add_generation_prompt,
            response: tx,
            enable_thinking,
            reasoning_effort: None,
        });
        self.runner.get_sender(model_id)?.send(request).await?;

        rx.recv().await.context("Channel was erroneously closed!")?
    }

    /// Detokenize some tokens.
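    ///
    /// A minimal round-trip sketch using [`Model::tokenize`]:
    ///
    /// ```no_run
    /// # use either::Either;
    /// # use mistralrs::Model;
    /// # async fn run(model: &Model) -> anyhow::Result<()> {
    /// let tokens = model
    ///     .tokenize(Either::Right("Hello!".to_string()), None, true, false, None)
    ///     .await?;
    /// let text = model.detokenize(tokens, true).await?;
    /// println!("{text}");
    /// # Ok(())
    /// # }
    /// ```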
    pub async fn detokenize(
        &self,
        tokens: Vec<u32>,
        skip_special_tokens: bool,
    ) -> anyhow::Result<String> {
        self.detokenize_with_model(tokens, skip_special_tokens, None)
            .await
    }

    /// Detokenize some tokens using a specific model.
    /// If `model_id` is `None`, the request is sent to the default model.
    pub async fn detokenize_with_model(
        &self,
        tokens: Vec<u32>,
        skip_special_tokens: bool,
        model_id: Option<&str>,
    ) -> anyhow::Result<String> {
        let (tx, mut rx) = channel(1);
        let request = Request::Detokenize(DetokenizationRequest {
            tokens,
            skip_special_tokens,
            response: tx,
        });
        self.runner.get_sender(model_id)?.send(request).await?;

        rx.recv().await.context("Channel was erroneously closed!")?
    }

    // ========================================================================
    // Configuration Methods
    // ========================================================================

    /// Retrieve some information about this model.
    pub fn config(&self) -> std::result::Result<MistralRsConfig, String> {
        self.config_with_model(None)
    }

    /// Retrieve some information about a specific model.
    /// If `model_id` is `None`, returns config for the default model.
    pub fn config_with_model(
        &self,
        model_id: Option<&str>,
    ) -> std::result::Result<MistralRsConfig, String> {
        self.runner.config(model_id)
    }

    /// Returns the maximum supported sequence length for this model, if applicable.
    pub fn max_sequence_length(&self) -> std::result::Result<Option<usize>, MistralRsError> {
        self.max_sequence_length_with_model(None)
    }

    /// Returns the maximum supported sequence length for a specific model, if applicable.
    /// If `model_id` is `None`, returns for the default model.
    pub fn max_sequence_length_with_model(
        &self,
        model_id: Option<&str>,
    ) -> std::result::Result<Option<usize>, MistralRsError> {
        self.runner.max_sequence_length(model_id)
    }

    // ========================================================================
    // Multi-Model Management Methods
    // ========================================================================

    /// List all available model IDs.
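    ///
    /// A minimal multi-model sketch (hedged: these accessors delegate to the underlying
    /// `MistralRs` registry, which reports errors as plain `String`s):
    ///
    /// ```no_run
    /// # use mistralrs::Model;
    /// # fn run(model: &Model) -> anyhow::Result<()> {
    /// for id in model.list_models().map_err(anyhow::Error::msg)? {
    ///     println!("available: {id}");
    /// }
    /// if let Some(default) = model.get_default_model_id().map_err(anyhow::Error::msg)? {
    ///     println!("default: {default}");
    /// }
    /// # Ok(())
    /// # }
    /// ```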
    pub fn list_models(&self) -> std::result::Result<Vec<String>, String> {
        self.runner.list_models()
    }

    /// Get the current default model ID.
    pub fn get_default_model_id(&self) -> std::result::Result<Option<String>, String> {
        self.runner.get_default_model_id()
    }

    /// Set the default model ID.
    pub fn set_default_model_id(&self, model_id: &str) -> std::result::Result<(), String> {
        self.runner.set_default_model_id(model_id)
    }

    /// Add a new model dynamically.
    pub async fn add_model(
        &self,
        model_id: String,
        pipeline: Arc<tokio::sync::Mutex<dyn Pipeline>>,
        method: SchedulerConfig,
        config: AddModelConfig,
    ) -> std::result::Result<(), String> {
        self.runner
            .add_model(model_id, pipeline, method, config)
            .await
    }

    /// Remove a model by ID.
    pub fn remove_model(&self, model_id: &str) -> std::result::Result<(), String> {
        self.runner.remove_model(model_id)
    }

    /// Unload a model from memory (can be reloaded later).
    pub fn unload_model(&self, model_id: &str) -> std::result::Result<(), MistralRsError> {
        self.runner.unload_model(model_id)
    }

    /// Reload a previously unloaded model.
    pub async fn reload_model(&self, model_id: &str) -> std::result::Result<(), MistralRsError> {
        self.runner.reload_model(model_id).await
    }

    /// Check if a model is currently loaded.
    pub fn is_model_loaded(&self, model_id: &str) -> std::result::Result<bool, MistralRsError> {
        self.runner.is_model_loaded(model_id)
    }

    /// List all models with their status (Loaded, Unloaded, Reloading).
    pub fn list_models_with_status(
        &self,
    ) -> std::result::Result<Vec<(String, ModelStatus)>, MistralRsError> {
        self.runner.list_models_with_status()
    }

    /// Get the underlying [`MistralRs`] instance.
    pub fn inner(&self) -> &MistralRs {
        &self.runner
    }
}