mistralrs/
messages.rs

1use std::{collections::HashMap, fmt::Display, sync::Arc};
2
3use super::*;
4use either::Either;
5use image::DynamicImage;
6use indexmap::IndexMap;
7use serde_json::{json, Value};
8
/// A type which can be used as a chat request.
///
/// The `*_ref` accessors and the plain getters borrow from the request, while
/// the `take_*` methods receive `&mut self` and hand back owned values.
pub trait RequestLike {
    /// Borrow the chat messages currently held by the request.
    fn messages_ref(&self) -> &[IndexMap<String, MessageContent>];
    /// Borrow the images currently attached to the request.
    fn images_ref(&self) -> &[DynamicImage];
    /// Produce the owned [`RequestMessage`] for this request.
    fn take_messages(&mut self) -> RequestMessage;
    /// Custom logits processors to apply, or `None` if there are none.
    fn take_logits_processors(&mut self) -> Option<Vec<Arc<dyn CustomLogitsProcessor>>>;
    /// Adapter names to use, or `None` if there are none.
    fn take_adapters(&mut self) -> Option<Vec<String>>;
    /// Whether log-probabilities should be returned.
    fn return_logprobs(&self) -> bool;
    /// Optional search flag.
    ///
    /// NOTE(review): how a consumer treats `None` versus `Some(false)` is not
    /// visible from this trait — confirm against the caller.
    fn enable_search(&self) -> Option<bool>;
    /// The constraint to apply to generation.
    fn take_constraint(&mut self) -> Constraint;
    /// The tools and the tool choice, or `None` if no tools are configured.
    fn take_tools(&mut self) -> Option<(Vec<Tool>, ToolChoice)>;
    /// The sampling parameters to use.
    fn take_sampling_params(&mut self) -> SamplingParams;
    /// Web search options, or `None` if none are configured.
    fn take_web_search_options(&mut self) -> Option<WebSearchOptions>;
}
23
#[derive(Debug, Clone, PartialEq)]
/// Plain text (chat) messages.
///
/// No constraints, logits processors, logprobs, tools, or adapters.
///
/// Sampling is deterministic.
pub struct TextMessages {
    /// Messages in insertion order; `add_message` stores a `"role"` and a
    /// `"content"` entry in each one.
    messages: Vec<IndexMap<String, MessageContent>>,
    /// Forwarded as `enable_thinking` when the request message is built;
    /// `None` until [`TextMessages::enable_thinking`] is called.
    enable_thinking: Option<bool>,
}
34
35impl From<TextMessages> for Vec<IndexMap<String, MessageContent>> {
36    fn from(value: TextMessages) -> Self {
37        value.messages
38    }
39}
40
/// A chat message role.
///
/// The [`Display`] implementation yields the lowercase role name, which is the
/// string stored under the `"role"` key of a message. [`Self::Custom`] is
/// written out verbatim.
#[derive(Debug, Clone, PartialEq)]
pub enum TextMessageRole {
    User,
    Assistant,
    System,
    Tool,
    /// A caller-supplied role name.
    Custom(String),
}

impl Display for TextMessageRole {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        // Resolve the label first, then emit it with a single write.
        let label: &str = match self {
            Self::User => "user",
            Self::Assistant => "assistant",
            Self::System => "system",
            Self::Tool => "tool",
            Self::Custom(name) => name.as_str(),
        };
        f.write_str(label)
    }
}
62
63impl Default for TextMessages {
64    fn default() -> Self {
65        Self::new()
66    }
67}
68
69impl TextMessages {
70    pub fn new() -> Self {
71        Self {
72            messages: Vec::new(),
73            enable_thinking: None,
74        }
75    }
76
77    pub fn add_message(mut self, role: TextMessageRole, text: impl ToString) -> Self {
78        self.messages.push(IndexMap::from([
79            ("role".to_string(), Either::Left(role.to_string())),
80            ("content".to_string(), Either::Left(text.to_string())),
81        ]));
82        self
83    }
84
85    pub fn clear(mut self) -> Self {
86        self.messages.clear();
87        self
88    }
89
90    pub fn enable_thinking(mut self, enable_thinking: bool) -> Self {
91        self.enable_thinking = Some(enable_thinking);
92        self
93    }
94}
95
96impl RequestLike for TextMessages {
97    fn messages_ref(&self) -> &[IndexMap<String, MessageContent>] {
98        &self.messages
99    }
100    fn images_ref(&self) -> &[DynamicImage] {
101        &[]
102    }
103    fn take_messages(&mut self) -> RequestMessage {
104        let mut other = Vec::new();
105        std::mem::swap(&mut other, &mut self.messages);
106        RequestMessage::Chat {
107            messages: other,
108            enable_thinking: self.enable_thinking,
109        }
110    }
111    fn enable_search(&self) -> Option<bool> {
112        None
113    }
114    fn take_logits_processors(&mut self) -> Option<Vec<Arc<dyn CustomLogitsProcessor>>> {
115        None
116    }
117    fn take_adapters(&mut self) -> Option<Vec<String>> {
118        None
119    }
120    fn return_logprobs(&self) -> bool {
121        false
122    }
123    fn take_constraint(&mut self) -> Constraint {
124        Constraint::None
125    }
126    fn take_tools(&mut self) -> Option<(Vec<Tool>, ToolChoice)> {
127        None
128    }
129    fn take_sampling_params(&mut self) -> SamplingParams {
130        SamplingParams::deterministic()
131    }
132    fn take_web_search_options(&mut self) -> Option<WebSearchOptions> {
133        None
134    }
135}
136
#[derive(Debug, Clone, PartialEq)]
/// Text (chat) messages with images and/or audios.
///
/// No constraints, logits processors, logprobs, tools, or adapters.
///
/// Sampling is deterministic.
pub struct VisionMessages {
    /// Messages in insertion order.
    messages: Vec<IndexMap<String, MessageContent>>,
    /// All images attached so far, across every message.
    images: Vec<DynamicImage>,
    /// All audio inputs attached so far, across every message.
    audios: Vec<AudioInput>,
    /// Forwarded as `enable_thinking` when the request message is built;
    /// `None` until [`VisionMessages::enable_thinking`] is called.
    enable_thinking: Option<bool>,
}
149
150impl Default for VisionMessages {
151    fn default() -> Self {
152        Self::new()
153    }
154}
155
156impl VisionMessages {
157    pub fn new() -> Self {
158        Self {
159            images: Vec::new(),
160            messages: Vec::new(),
161            audios: Vec::new(),
162            enable_thinking: None,
163        }
164    }
165
166    pub fn add_message(mut self, role: TextMessageRole, text: impl ToString) -> Self {
167        self.messages.push(IndexMap::from([
168            ("role".to_string(), Either::Left(role.to_string())),
169            ("content".to_string(), Either::Left(text.to_string())),
170        ]));
171        self
172    }
173
174    pub fn add_image_message(
175        self,
176        role: TextMessageRole,
177        text: impl ToString,
178        images: Vec<DynamicImage>,
179        model: &Model,
180    ) -> anyhow::Result<Self> {
181        self.add_multimodal_message(role, text, images, vec![], model)
182    }
183
184    pub fn add_audio_message(
185        self,
186        role: TextMessageRole,
187        text: impl ToString,
188        audios: Vec<AudioInput>,
189        model: &Model,
190    ) -> anyhow::Result<Self> {
191        self.add_multimodal_message(role, text, vec![], audios, model)
192    }
193
194    pub fn add_multimodal_message(
195        mut self,
196        role: TextMessageRole,
197        text: impl ToString,
198        images: Vec<DynamicImage>,
199        audios: Vec<AudioInput>,
200        model: &Model,
201    ) -> anyhow::Result<Self> {
202        let config = model.config().unwrap();
203        let prefixer = match &config.category {
204            ModelCategory::Vision { prefixer } => prefixer,
205            ModelCategory::Text
206            | ModelCategory::Diffusion
207            | ModelCategory::Speech
208            | ModelCategory::Audio => {
209                anyhow::bail!("`add_image_message` expects a vision model.")
210            }
211        };
212
213        // Images
214        let n_added_images = images.len();
215        let prefixed = prefixer.prefix_image(
216            (self.images.len()..self.images.len() + n_added_images).collect(),
217            &text.to_string(),
218        );
219        self.images.extend(images);
220
221        // Audios
222        let n_added_audios = audios.len();
223        let prefixed = prefixer.prefix_audio(
224            (self.audios.len()..self.audios.len() + n_added_audios).collect(),
225            &prefixed,
226        );
227        self.audios.extend(audios);
228
229        if n_added_images > 0 {
230            self.messages.push(IndexMap::from([
231                ("role".to_string(), Either::Left(role.to_string())),
232                (
233                    "content".to_string(),
234                    Either::Right(vec![
235                        IndexMap::from([("type".to_string(), Value::String("image".to_string()))]),
236                        IndexMap::from([
237                            ("type".to_string(), Value::String("text".to_string())),
238                            ("text".to_string(), Value::String(prefixed)),
239                        ]),
240                    ]),
241                ),
242            ]));
243        } else {
244            self.messages.push(IndexMap::from([
245                ("role".to_string(), Either::Left(role.to_string())),
246                ("content".to_string(), Either::Left(prefixed)),
247            ]));
248        }
249        Ok(self)
250    }
251
252    pub fn clear(mut self) -> Self {
253        self.messages.clear();
254        self.images.clear();
255
256        self
257    }
258
259    pub fn enable_thinking(mut self, enable_thinking: bool) -> Self {
260        self.enable_thinking = Some(enable_thinking);
261        self
262    }
263}
264
265impl RequestLike for VisionMessages {
266    fn messages_ref(&self) -> &[IndexMap<String, MessageContent>] {
267        &self.messages
268    }
269    fn images_ref(&self) -> &[DynamicImage] {
270        &self.images
271    }
272    fn take_messages(&mut self) -> RequestMessage {
273        let mut other_messages = Vec::new();
274        std::mem::swap(&mut other_messages, &mut self.messages);
275        let mut other_images = Vec::new();
276        std::mem::swap(&mut other_images, &mut self.images);
277        let mut other_audios = Vec::new();
278        std::mem::swap(&mut other_audios, &mut self.audios);
279        RequestMessage::VisionChat {
280            images: other_images,
281            messages: other_messages,
282            audios: other_audios,
283            enable_thinking: self.enable_thinking,
284        }
285    }
286    fn enable_search(&self) -> Option<bool> {
287        None
288    }
289    fn take_logits_processors(&mut self) -> Option<Vec<Arc<dyn CustomLogitsProcessor>>> {
290        None
291    }
292    fn take_adapters(&mut self) -> Option<Vec<String>> {
293        None
294    }
295    fn return_logprobs(&self) -> bool {
296        false
297    }
298    fn take_constraint(&mut self) -> Constraint {
299        Constraint::None
300    }
301    fn take_tools(&mut self) -> Option<(Vec<Tool>, ToolChoice)> {
302        None
303    }
304    fn take_sampling_params(&mut self) -> SamplingParams {
305        SamplingParams::deterministic()
306    }
307    fn take_web_search_options(&mut self) -> Option<WebSearchOptions> {
308        None
309    }
310}
311
#[derive(Clone)]
/// A way to add messages with finer control given.
///
/// This includes control over:
/// - Logits processors
/// - Constraints
/// - Logprobs
/// - Tools
/// - Sampling
/// - Enable thinking for models that support the configuration
pub struct RequestBuilder {
    /// Messages in insertion order.
    messages: Vec<IndexMap<String, MessageContent>>,
    /// All images attached so far, across every message.
    images: Vec<DynamicImage>,
    /// All audio inputs attached so far, across every message.
    audios: Vec<AudioInput>,
    /// Custom logits processors, in the order they were added.
    logits_processors: Vec<Arc<dyn CustomLogitsProcessor>>,
    /// Adapter names; an empty list is reported as "no adapters".
    adapters: Vec<String>,
    /// Whether log-probabilities should be returned.
    return_logprobs: bool,
    /// Generation constraint; `Constraint::None` unless set.
    constraint: Constraint,
    /// Available tools; an empty list is reported as "no tools".
    tools: Vec<Tool>,
    /// Tool choice handed out together with `tools`; `ToolChoice::Auto` unless set.
    tool_choice: ToolChoice,
    /// Sampling parameters; deterministic unless changed.
    sampling_params: SamplingParams,
    /// Web search options, if configured.
    web_search_options: Option<WebSearchOptions>,
    /// Forwarded as `enable_thinking` when the request message is built.
    enable_thinking: Option<bool>,
}
336
337impl Default for RequestBuilder {
338    fn default() -> Self {
339        Self::new()
340    }
341}
342
343impl From<TextMessages> for RequestBuilder {
344    fn from(value: TextMessages) -> Self {
345        Self {
346            messages: value.messages,
347            images: Vec::new(),
348            audios: Vec::new(),
349            logits_processors: Vec::new(),
350            adapters: Vec::new(),
351            return_logprobs: false,
352            constraint: Constraint::None,
353            tools: Vec::new(),
354            tool_choice: ToolChoice::Auto,
355            sampling_params: SamplingParams::deterministic(),
356            web_search_options: None,
357            enable_thinking: None,
358        }
359    }
360}
361
362impl From<VisionMessages> for RequestBuilder {
363    fn from(value: VisionMessages) -> Self {
364        Self {
365            messages: value.messages,
366            images: value.images,
367            audios: value.audios,
368            logits_processors: Vec::new(),
369            adapters: Vec::new(),
370            return_logprobs: false,
371            constraint: Constraint::None,
372            tools: Vec::new(),
373            tool_choice: ToolChoice::Auto,
374            sampling_params: SamplingParams::deterministic(),
375            web_search_options: None,
376            enable_thinking: None,
377        }
378    }
379}
380
381impl RequestBuilder {
382    pub fn new() -> Self {
383        Self {
384            messages: Vec::new(),
385            images: Vec::new(),
386            audios: Vec::new(),
387            logits_processors: Vec::new(),
388            adapters: Vec::new(),
389            return_logprobs: false,
390            constraint: Constraint::None,
391            tools: Vec::new(),
392            tool_choice: ToolChoice::Auto,
393            sampling_params: SamplingParams::deterministic(),
394            web_search_options: None,
395            enable_thinking: None,
396        }
397    }
398
399    pub fn with_web_search_options(mut self, web_search_options: WebSearchOptions) -> Self {
400        self.web_search_options = Some(web_search_options);
401        self
402    }
403
404    /// Add a message to the request.
405    ///
406    /// For messages with tool calls, use [`Self::add_message_with_tool_call`].
407    /// For messages with tool outputs, use [`Self::add_tool_message`].
408    pub fn add_message(mut self, role: TextMessageRole, text: impl ToString) -> Self {
409        self.messages.push(IndexMap::from([
410            ("role".to_string(), Either::Left(role.to_string())),
411            ("content".to_string(), Either::Left(text.to_string())),
412        ]));
413        self
414    }
415
416    /// Add a message with the output of a tool call.
417    pub fn add_tool_message(mut self, tool_content: impl ToString, tool_id: impl ToString) -> Self {
418        self.messages.push(IndexMap::from([
419            (
420                "role".to_string(),
421                Either::Left(TextMessageRole::Tool.to_string()),
422            ),
423            (
424                "content".to_string(),
425                Either::Left(tool_content.to_string()),
426            ),
427            (
428                "tool_call_id".to_string(),
429                Either::Left(tool_id.to_string()),
430            ),
431        ]));
432        self
433    }
434
435    pub fn add_message_with_tool_call(
436        mut self,
437        role: TextMessageRole,
438        text: impl ToString,
439        tool_calls: Vec<ToolCallResponse>,
440    ) -> Self {
441        let tool_messages = tool_calls
442            .iter()
443            .map(|t| {
444                IndexMap::from([
445                    ("id".to_string(), Value::String(t.id.clone())),
446                    ("type".to_string(), Value::String(t.tp.to_string())),
447                    (
448                        "function".to_string(),
449                        json!({
450                            "name": t.function.name,
451                            "arguments": t.function.arguments,
452                        }),
453                    ),
454                ])
455            })
456            .collect();
457        self.messages.push(IndexMap::from([
458            ("role".to_string(), Either::Left(role.to_string())),
459            ("content".to_string(), Either::Left(text.to_string())),
460            ("function".to_string(), Either::Right(tool_messages)),
461        ]));
462        self
463    }
464
465    pub fn add_image_message(
466        self,
467        role: TextMessageRole,
468        text: impl ToString,
469        images: Vec<DynamicImage>,
470        model: &Model,
471    ) -> anyhow::Result<Self> {
472        self.add_multimodal_message(role, text, images, vec![], model)
473    }
474
475    pub fn add_audio_message(
476        self,
477        role: TextMessageRole,
478        text: impl ToString,
479        audios: Vec<AudioInput>,
480        model: &Model,
481    ) -> anyhow::Result<Self> {
482        self.add_multimodal_message(role, text, vec![], audios, model)
483    }
484
485    pub fn add_multimodal_message(
486        mut self,
487        role: TextMessageRole,
488        text: impl ToString,
489        images: Vec<DynamicImage>,
490        audios: Vec<AudioInput>,
491        model: &Model,
492    ) -> anyhow::Result<Self> {
493        let config = model.config().unwrap();
494        let prefixer = match &config.category {
495            ModelCategory::Vision { prefixer } => prefixer,
496            ModelCategory::Text
497            | ModelCategory::Diffusion
498            | ModelCategory::Speech
499            | ModelCategory::Audio => {
500                anyhow::bail!("`add_image_message` expects a vision model.")
501            }
502        };
503
504        // Images
505        let n_added_images = images.len();
506        let prefixed = prefixer.prefix_image(
507            (self.images.len()..self.images.len() + n_added_images).collect(),
508            &text.to_string(),
509        );
510        self.images.extend(images);
511
512        // Audios
513        let n_added_audios = audios.len();
514        let prefixed = prefixer.prefix_audio(
515            (self.audios.len()..self.audios.len() + n_added_audios).collect(),
516            &prefixed,
517        );
518        self.audios.extend(audios);
519
520        if n_added_images > 0 {
521            self.messages.push(IndexMap::from([
522                ("role".to_string(), Either::Left(role.to_string())),
523                (
524                    "content".to_string(),
525                    Either::Right(vec![
526                        IndexMap::from([("type".to_string(), Value::String("image".to_string()))]),
527                        IndexMap::from([
528                            ("type".to_string(), Value::String("text".to_string())),
529                            ("text".to_string(), Value::String(prefixed)),
530                        ]),
531                    ]),
532                ),
533            ]));
534        } else {
535            self.messages.push(IndexMap::from([
536                ("role".to_string(), Either::Left(role.to_string())),
537                ("content".to_string(), Either::Left(prefixed)),
538            ]));
539        }
540        Ok(self)
541    }
542
543    pub fn add_logits_processor(mut self, processor: Arc<dyn CustomLogitsProcessor>) -> Self {
544        self.logits_processors.push(processor);
545        self
546    }
547
548    pub fn set_adapters(mut self, adapters: Vec<String>) -> Self {
549        self.adapters = adapters;
550        self
551    }
552
553    /// The default tool choice is auto.
554    pub fn set_tools(mut self, tools: Vec<Tool>) -> Self {
555        self.tools = tools;
556        self
557    }
558
559    pub fn set_tool_choice(mut self, tool_choice: ToolChoice) -> Self {
560        self.tool_choice = tool_choice;
561        self
562    }
563
564    pub fn return_logprobs(mut self, return_logprobs: bool) -> Self {
565        self.return_logprobs = return_logprobs;
566        self
567    }
568
569    pub fn set_constraint(mut self, constraint: Constraint) -> Self {
570        self.constraint = constraint;
571        self
572    }
573
574    /// Set the sampling parameters as given.
575    pub fn set_sampling(mut self, params: SamplingParams) -> Self {
576        self.sampling_params = params;
577        self
578    }
579
580    /// Set the sampling parameters for deterministic generation.
581    /// This sets up the parameters so that there is:
582    /// - No temperature, topk, topp, minp
583    /// - No penalties, stop tokens, or logit bias
584    /// - No maximum length
585    pub fn set_deterministic_sampler(mut self) -> Self {
586        self.sampling_params = SamplingParams::deterministic();
587        self
588    }
589
590    pub fn set_sampler_temperature(mut self, temperature: f64) -> Self {
591        self.sampling_params.temperature = Some(temperature);
592        self
593    }
594
595    pub fn set_sampler_topk(mut self, topk: usize) -> Self {
596        self.sampling_params.top_k = Some(topk);
597        self
598    }
599
600    pub fn set_sampler_topp(mut self, topp: f64) -> Self {
601        self.sampling_params.top_p = Some(topp);
602        self
603    }
604
605    pub fn set_sampler_minp(mut self, minp: f64) -> Self {
606        self.sampling_params.min_p = Some(minp);
607        self
608    }
609
610    pub fn set_sampler_topn_logprobs(mut self, top_n_logprobs: usize) -> Self {
611        self.sampling_params.top_n_logprobs = top_n_logprobs;
612        self
613    }
614
615    pub fn set_sampler_frequency_penalty(mut self, frequency_penalty: f32) -> Self {
616        self.sampling_params.frequency_penalty = Some(frequency_penalty);
617        self
618    }
619
620    pub fn set_sampler_presence_penalty(mut self, presence_penalty: f32) -> Self {
621        self.sampling_params.presence_penalty = Some(presence_penalty);
622        self
623    }
624
625    pub fn set_sampler_stop_toks(mut self, stop_toks: StopTokens) -> Self {
626        self.sampling_params.stop_toks = Some(stop_toks);
627        self
628    }
629
630    pub fn set_sampler_max_len(mut self, max_len: usize) -> Self {
631        self.sampling_params.max_len = Some(max_len);
632        self
633    }
634
635    pub fn set_sampler_logits_bias(mut self, logits_bias: HashMap<u32, f32>) -> Self {
636        self.sampling_params.logits_bias = Some(logits_bias);
637        self
638    }
639
640    pub fn set_sampler_n_choices(mut self, n_choices: usize) -> Self {
641        self.sampling_params.n_choices = n_choices;
642        self
643    }
644
645    pub fn set_sampler_dry_params(mut self, dry_params: DrySamplingParams) -> Self {
646        self.sampling_params.dry_params = Some(dry_params);
647        self
648    }
649
650    pub fn enable_thinking(mut self, enable_thinking: bool) -> Self {
651        self.enable_thinking = Some(enable_thinking);
652        self
653    }
654}
655
656impl RequestLike for RequestBuilder {
657    fn messages_ref(&self) -> &[IndexMap<String, MessageContent>] {
658        &self.messages
659    }
660
661    fn images_ref(&self) -> &[DynamicImage] {
662        &self.images
663    }
664
665    fn take_messages(&mut self) -> RequestMessage {
666        if self.images.is_empty() && self.audios.is_empty() {
667            let mut other = Vec::new();
668            std::mem::swap(&mut other, &mut self.messages);
669            RequestMessage::Chat {
670                messages: other,
671                enable_thinking: self.enable_thinking,
672            }
673        } else {
674            let mut other_messages = Vec::new();
675            std::mem::swap(&mut other_messages, &mut self.messages);
676            let mut other_images = Vec::new();
677            std::mem::swap(&mut other_images, &mut self.images);
678            let mut other_audios = Vec::new();
679            std::mem::swap(&mut other_audios, &mut self.audios);
680            RequestMessage::VisionChat {
681                images: other_images,
682                messages: other_messages,
683                audios: other_audios,
684                enable_thinking: self.enable_thinking,
685            }
686        }
687    }
688
689    fn enable_search(&self) -> Option<bool> {
690        self.enable_thinking
691    }
692
693    fn take_logits_processors(&mut self) -> Option<Vec<Arc<dyn CustomLogitsProcessor>>> {
694        if self.logits_processors.is_empty() {
695            None
696        } else {
697            let mut other = Vec::new();
698            std::mem::swap(&mut other, &mut self.logits_processors);
699            Some(other)
700        }
701    }
702
703    fn take_adapters(&mut self) -> Option<Vec<String>> {
704        if self.adapters.is_empty() {
705            None
706        } else {
707            let mut other = Vec::new();
708            std::mem::swap(&mut other, &mut self.adapters);
709            Some(other)
710        }
711    }
712
713    fn return_logprobs(&self) -> bool {
714        self.return_logprobs
715    }
716
717    fn take_constraint(&mut self) -> Constraint {
718        let mut other = Constraint::None;
719        std::mem::swap(&mut other, &mut self.constraint);
720        other
721    }
722
723    fn take_tools(&mut self) -> Option<(Vec<Tool>, ToolChoice)> {
724        if self.tools.is_empty() {
725            None
726        } else {
727            let mut other_ts = Vec::new();
728            std::mem::swap(&mut other_ts, &mut self.tools);
729            let mut other_tc = ToolChoice::Auto;
730            std::mem::swap(&mut other_tc, &mut self.tool_choice);
731            Some((other_ts, other_tc))
732        }
733    }
734
735    fn take_sampling_params(&mut self) -> SamplingParams {
736        let mut other = SamplingParams::deterministic();
737        std::mem::swap(&mut other, &mut self.sampling_params);
738        other
739    }
740
741    fn take_web_search_options(&mut self) -> Option<WebSearchOptions> {
742        let mut other = None;
743        std::mem::swap(&mut other, &mut self.web_search_options);
744        other
745    }
746}