mistralrs/
messages.rs

1use std::{collections::HashMap, fmt::Display, sync::Arc};
2
3use super::*;
4use either::Either;
5use image::DynamicImage;
6use indexmap::IndexMap;
7use serde_json::{json, Value};
8
9/// A type which can be used as a chat request.
10pub trait RequestLike {
11    fn messages_ref(&self) -> &[IndexMap<String, MessageContent>];
12    fn images_ref(&self) -> &[DynamicImage];
13    fn take_messages(&mut self) -> RequestMessage;
14    fn take_logits_processors(&mut self) -> Option<Vec<Arc<dyn CustomLogitsProcessor>>>;
15    fn take_adapters(&mut self) -> Option<Vec<String>>;
16    fn return_logprobs(&self) -> bool;
17    fn enable_search(&self) -> Option<bool>;
18    fn take_constraint(&mut self) -> Constraint;
19    fn take_tools(&mut self) -> Option<(Vec<Tool>, ToolChoice)>;
20    fn take_sampling_params(&mut self) -> SamplingParams;
21    fn take_web_search_options(&mut self) -> Option<WebSearchOptions>;
22}
23
24#[derive(Debug, Clone, PartialEq)]
25/// Plain text (chat) messages.
26///
27/// No constraints, logits processors, logprobs, tools, or adapters.
28///
29/// Sampling is deterministic.
30pub struct TextMessages {
31    messages: Vec<IndexMap<String, MessageContent>>,
32    enable_thinking: Option<bool>,
33}
34
35impl From<TextMessages> for Vec<IndexMap<String, MessageContent>> {
36    fn from(value: TextMessages) -> Self {
37        value.messages
38    }
39}
40
/// A chat message role.
///
/// Converted to the text stored under a message's `"role"` key via its `Display`
/// implementation: built-in roles in lowercase, custom roles verbatim.
#[derive(Debug, Clone, PartialEq)]
pub enum TextMessageRole {
    /// Rendered as `user`.
    User,
    /// Rendered as `assistant`.
    Assistant,
    /// Rendered as `system`.
    System,
    /// Rendered as `tool`; used for tool-output messages.
    Tool,
    /// Any other role name, rendered exactly as given.
    Custom(String),
}
50
51impl Display for TextMessageRole {
52    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
53        match self {
54            Self::User => write!(f, "user"),
55            Self::Assistant => write!(f, "assistant"),
56            Self::System => write!(f, "system"),
57            Self::Tool => write!(f, "tool"),
58            Self::Custom(c) => write!(f, "{c}"),
59        }
60    }
61}
62
63impl Default for TextMessages {
64    fn default() -> Self {
65        Self::new()
66    }
67}
68
69impl TextMessages {
70    pub fn new() -> Self {
71        Self {
72            messages: Vec::new(),
73            enable_thinking: None,
74        }
75    }
76
77    pub fn add_message(mut self, role: TextMessageRole, text: impl ToString) -> Self {
78        self.messages.push(IndexMap::from([
79            ("role".to_string(), Either::Left(role.to_string())),
80            ("content".to_string(), Either::Left(text.to_string())),
81        ]));
82        self
83    }
84
85    pub fn clear(mut self) -> Self {
86        self.messages.clear();
87        self
88    }
89
90    pub fn enable_thinking(mut self, enable_thinking: bool) -> Self {
91        self.enable_thinking = Some(enable_thinking);
92        self
93    }
94}
95
96impl RequestLike for TextMessages {
97    fn messages_ref(&self) -> &[IndexMap<String, MessageContent>] {
98        &self.messages
99    }
100    fn images_ref(&self) -> &[DynamicImage] {
101        &[]
102    }
103    fn take_messages(&mut self) -> RequestMessage {
104        let mut other = Vec::new();
105        std::mem::swap(&mut other, &mut self.messages);
106        RequestMessage::Chat {
107            messages: other,
108            enable_thinking: self.enable_thinking,
109        }
110    }
111    fn enable_search(&self) -> Option<bool> {
112        None
113    }
114    fn take_logits_processors(&mut self) -> Option<Vec<Arc<dyn CustomLogitsProcessor>>> {
115        None
116    }
117    fn take_adapters(&mut self) -> Option<Vec<String>> {
118        None
119    }
120    fn return_logprobs(&self) -> bool {
121        false
122    }
123    fn take_constraint(&mut self) -> Constraint {
124        Constraint::None
125    }
126    fn take_tools(&mut self) -> Option<(Vec<Tool>, ToolChoice)> {
127        None
128    }
129    fn take_sampling_params(&mut self) -> SamplingParams {
130        SamplingParams::deterministic()
131    }
132    fn take_web_search_options(&mut self) -> Option<WebSearchOptions> {
133        None
134    }
135}
136
137#[derive(Debug, Clone, PartialEq)]
138/// Text (chat) messages with images and/or audios.
139///
140/// No constraints, logits processors, logprobs, tools, or adapters.
141///
142/// Sampling is deterministic.
143pub struct VisionMessages {
144    messages: Vec<IndexMap<String, MessageContent>>,
145    images: Vec<DynamicImage>,
146    audios: Vec<AudioInput>,
147    enable_thinking: Option<bool>,
148}
149
150impl Default for VisionMessages {
151    fn default() -> Self {
152        Self::new()
153    }
154}
155
156impl VisionMessages {
157    pub fn new() -> Self {
158        Self {
159            images: Vec::new(),
160            messages: Vec::new(),
161            audios: Vec::new(),
162            enable_thinking: None,
163        }
164    }
165
166    pub fn add_message(mut self, role: TextMessageRole, text: impl ToString) -> Self {
167        self.messages.push(IndexMap::from([
168            ("role".to_string(), Either::Left(role.to_string())),
169            ("content".to_string(), Either::Left(text.to_string())),
170        ]));
171        self
172    }
173
174    pub fn add_image_message(
175        self,
176        role: TextMessageRole,
177        text: impl ToString,
178        images: Vec<DynamicImage>,
179        model: &Model,
180    ) -> anyhow::Result<Self> {
181        self.add_multimodal_message(role, text, images, vec![], model)
182    }
183
184    pub fn add_audio_message(
185        self,
186        role: TextMessageRole,
187        text: impl ToString,
188        audios: Vec<AudioInput>,
189        model: &Model,
190    ) -> anyhow::Result<Self> {
191        self.add_multimodal_message(role, text, vec![], audios, model)
192    }
193
194    pub fn add_multimodal_message(
195        mut self,
196        role: TextMessageRole,
197        text: impl ToString,
198        images: Vec<DynamicImage>,
199        audios: Vec<AudioInput>,
200        model: &Model,
201    ) -> anyhow::Result<Self> {
202        let config = model.config().unwrap();
203        let prefixer = match &config.category {
204            ModelCategory::Vision { prefixer } => prefixer,
205            ModelCategory::Text
206            | ModelCategory::Diffusion
207            | ModelCategory::Speech
208            | ModelCategory::Audio => {
209                anyhow::bail!("`add_image_message` expects a vision model.")
210            }
211        };
212
213        // Images
214        let n_added_images = images.len();
215        let prefixed = prefixer.prefix_image(
216            (self.images.len()..self.images.len() + n_added_images).collect(),
217            &text.to_string(),
218        );
219        self.images.extend(images);
220
221        // Audios
222        let n_added_audios = audios.len();
223        let prefixed = prefixer.prefix_audio(
224            (self.audios.len()..self.audios.len() + n_added_audios).collect(),
225            &prefixed,
226        );
227        self.audios.extend(audios);
228
229        if n_added_images > 0 {
230            self.messages.push(IndexMap::from([
231                ("role".to_string(), Either::Left(role.to_string())),
232                (
233                    "content".to_string(),
234                    Either::Right(vec![
235                        IndexMap::from([("type".to_string(), Value::String("image".to_string()))]),
236                        IndexMap::from([
237                            ("type".to_string(), Value::String("text".to_string())),
238                            ("text".to_string(), Value::String(prefixed)),
239                        ]),
240                    ]),
241                ),
242            ]));
243        } else {
244            self.messages.push(IndexMap::from([
245                ("role".to_string(), Either::Left(role.to_string())),
246                ("content".to_string(), Either::Left(prefixed)),
247            ]));
248        }
249        Ok(self)
250    }
251
252    pub fn clear(mut self) -> Self {
253        self.messages.clear();
254        self.images.clear();
255        self.audios.clear();
256
257        self
258    }
259
260    pub fn enable_thinking(mut self, enable_thinking: bool) -> Self {
261        self.enable_thinking = Some(enable_thinking);
262        self
263    }
264}
265
266impl RequestLike for VisionMessages {
267    fn messages_ref(&self) -> &[IndexMap<String, MessageContent>] {
268        &self.messages
269    }
270    fn images_ref(&self) -> &[DynamicImage] {
271        &self.images
272    }
273    fn take_messages(&mut self) -> RequestMessage {
274        let mut other_messages = Vec::new();
275        std::mem::swap(&mut other_messages, &mut self.messages);
276        let mut other_images = Vec::new();
277        std::mem::swap(&mut other_images, &mut self.images);
278        let mut other_audios = Vec::new();
279        std::mem::swap(&mut other_audios, &mut self.audios);
280        RequestMessage::VisionChat {
281            images: other_images,
282            messages: other_messages,
283            audios: other_audios,
284            enable_thinking: self.enable_thinking,
285        }
286    }
287    fn enable_search(&self) -> Option<bool> {
288        None
289    }
290    fn take_logits_processors(&mut self) -> Option<Vec<Arc<dyn CustomLogitsProcessor>>> {
291        None
292    }
293    fn take_adapters(&mut self) -> Option<Vec<String>> {
294        None
295    }
296    fn return_logprobs(&self) -> bool {
297        false
298    }
299    fn take_constraint(&mut self) -> Constraint {
300        Constraint::None
301    }
302    fn take_tools(&mut self) -> Option<(Vec<Tool>, ToolChoice)> {
303        None
304    }
305    fn take_sampling_params(&mut self) -> SamplingParams {
306        SamplingParams::deterministic()
307    }
308    fn take_web_search_options(&mut self) -> Option<WebSearchOptions> {
309        None
310    }
311}
312
313#[derive(Clone)]
314/// A way to add messages with finer control given.
315///
316/// This includes control over:
317/// - Logits processors
318/// - Constraints
319/// - Logprobs
320/// - Tools
321/// - Sampling
322/// - Enable thinking for models that support the configuration
323pub struct RequestBuilder {
324    messages: Vec<IndexMap<String, MessageContent>>,
325    images: Vec<DynamicImage>,
326    audios: Vec<AudioInput>,
327    logits_processors: Vec<Arc<dyn CustomLogitsProcessor>>,
328    adapters: Vec<String>,
329    return_logprobs: bool,
330    constraint: Constraint,
331    tools: Vec<Tool>,
332    tool_choice: ToolChoice,
333    sampling_params: SamplingParams,
334    web_search_options: Option<WebSearchOptions>,
335    enable_thinking: Option<bool>,
336}
337
338impl Default for RequestBuilder {
339    fn default() -> Self {
340        Self::new()
341    }
342}
343
344impl From<TextMessages> for RequestBuilder {
345    fn from(value: TextMessages) -> Self {
346        Self {
347            messages: value.messages,
348            images: Vec::new(),
349            audios: Vec::new(),
350            logits_processors: Vec::new(),
351            adapters: Vec::new(),
352            return_logprobs: false,
353            constraint: Constraint::None,
354            tools: Vec::new(),
355            tool_choice: ToolChoice::Auto,
356            sampling_params: SamplingParams::deterministic(),
357            web_search_options: None,
358            enable_thinking: None,
359        }
360    }
361}
362
363impl From<VisionMessages> for RequestBuilder {
364    fn from(value: VisionMessages) -> Self {
365        Self {
366            messages: value.messages,
367            images: value.images,
368            audios: value.audios,
369            logits_processors: Vec::new(),
370            adapters: Vec::new(),
371            return_logprobs: false,
372            constraint: Constraint::None,
373            tools: Vec::new(),
374            tool_choice: ToolChoice::Auto,
375            sampling_params: SamplingParams::deterministic(),
376            web_search_options: None,
377            enable_thinking: None,
378        }
379    }
380}
381
382impl RequestBuilder {
383    pub fn new() -> Self {
384        Self {
385            messages: Vec::new(),
386            images: Vec::new(),
387            audios: Vec::new(),
388            logits_processors: Vec::new(),
389            adapters: Vec::new(),
390            return_logprobs: false,
391            constraint: Constraint::None,
392            tools: Vec::new(),
393            tool_choice: ToolChoice::Auto,
394            sampling_params: SamplingParams::deterministic(),
395            web_search_options: None,
396            enable_thinking: None,
397        }
398    }
399
400    pub fn with_web_search_options(mut self, web_search_options: WebSearchOptions) -> Self {
401        self.web_search_options = Some(web_search_options);
402        self
403    }
404
405    /// Add a message to the request.
406    ///
407    /// For messages with tool calls, use [`Self::add_message_with_tool_call`].
408    /// For messages with tool outputs, use [`Self::add_tool_message`].
409    pub fn add_message(mut self, role: TextMessageRole, text: impl ToString) -> Self {
410        self.messages.push(IndexMap::from([
411            ("role".to_string(), Either::Left(role.to_string())),
412            ("content".to_string(), Either::Left(text.to_string())),
413        ]));
414        self
415    }
416
417    /// Add a message with the output of a tool call.
418    pub fn add_tool_message(mut self, tool_content: impl ToString, tool_id: impl ToString) -> Self {
419        self.messages.push(IndexMap::from([
420            (
421                "role".to_string(),
422                Either::Left(TextMessageRole::Tool.to_string()),
423            ),
424            (
425                "content".to_string(),
426                Either::Left(tool_content.to_string()),
427            ),
428            (
429                "tool_call_id".to_string(),
430                Either::Left(tool_id.to_string()),
431            ),
432        ]));
433        self
434    }
435
436    pub fn add_message_with_tool_call(
437        mut self,
438        role: TextMessageRole,
439        text: impl ToString,
440        tool_calls: Vec<ToolCallResponse>,
441    ) -> Self {
442        let tool_messages = tool_calls
443            .iter()
444            .map(|t| {
445                IndexMap::from([
446                    ("id".to_string(), Value::String(t.id.clone())),
447                    ("type".to_string(), Value::String(t.tp.to_string())),
448                    (
449                        "function".to_string(),
450                        json!({
451                            "name": t.function.name,
452                            "arguments": t.function.arguments,
453                        }),
454                    ),
455                ])
456            })
457            .collect();
458        self.messages.push(IndexMap::from([
459            ("role".to_string(), Either::Left(role.to_string())),
460            ("content".to_string(), Either::Left(text.to_string())),
461            ("function".to_string(), Either::Right(tool_messages)),
462        ]));
463        self
464    }
465
466    pub fn add_image_message(
467        self,
468        role: TextMessageRole,
469        text: impl ToString,
470        images: Vec<DynamicImage>,
471        model: &Model,
472    ) -> anyhow::Result<Self> {
473        self.add_multimodal_message(role, text, images, vec![], model)
474    }
475
476    pub fn add_audio_message(
477        self,
478        role: TextMessageRole,
479        text: impl ToString,
480        audios: Vec<AudioInput>,
481        model: &Model,
482    ) -> anyhow::Result<Self> {
483        self.add_multimodal_message(role, text, vec![], audios, model)
484    }
485
486    pub fn add_multimodal_message(
487        mut self,
488        role: TextMessageRole,
489        text: impl ToString,
490        images: Vec<DynamicImage>,
491        audios: Vec<AudioInput>,
492        model: &Model,
493    ) -> anyhow::Result<Self> {
494        let config = model.config().unwrap();
495        let prefixer = match &config.category {
496            ModelCategory::Vision { prefixer } => prefixer,
497            ModelCategory::Text
498            | ModelCategory::Diffusion
499            | ModelCategory::Speech
500            | ModelCategory::Audio => {
501                anyhow::bail!("`add_image_message` expects a vision model.")
502            }
503        };
504
505        // Images
506        let n_added_images = images.len();
507        let prefixed = prefixer.prefix_image(
508            (self.images.len()..self.images.len() + n_added_images).collect(),
509            &text.to_string(),
510        );
511        self.images.extend(images);
512
513        // Audios
514        let n_added_audios = audios.len();
515        let prefixed = prefixer.prefix_audio(
516            (self.audios.len()..self.audios.len() + n_added_audios).collect(),
517            &prefixed,
518        );
519        self.audios.extend(audios);
520
521        if n_added_images > 0 {
522            self.messages.push(IndexMap::from([
523                ("role".to_string(), Either::Left(role.to_string())),
524                (
525                    "content".to_string(),
526                    Either::Right(vec![
527                        IndexMap::from([("type".to_string(), Value::String("image".to_string()))]),
528                        IndexMap::from([
529                            ("type".to_string(), Value::String("text".to_string())),
530                            ("text".to_string(), Value::String(prefixed)),
531                        ]),
532                    ]),
533                ),
534            ]));
535        } else {
536            self.messages.push(IndexMap::from([
537                ("role".to_string(), Either::Left(role.to_string())),
538                ("content".to_string(), Either::Left(prefixed)),
539            ]));
540        }
541        Ok(self)
542    }
543
544    pub fn add_logits_processor(mut self, processor: Arc<dyn CustomLogitsProcessor>) -> Self {
545        self.logits_processors.push(processor);
546        self
547    }
548
549    pub fn set_adapters(mut self, adapters: Vec<String>) -> Self {
550        self.adapters = adapters;
551        self
552    }
553
554    /// The default tool choice is auto.
555    pub fn set_tools(mut self, tools: Vec<Tool>) -> Self {
556        self.tools = tools;
557        self
558    }
559
560    pub fn set_tool_choice(mut self, tool_choice: ToolChoice) -> Self {
561        self.tool_choice = tool_choice;
562        self
563    }
564
565    pub fn return_logprobs(mut self, return_logprobs: bool) -> Self {
566        self.return_logprobs = return_logprobs;
567        self
568    }
569
570    pub fn set_constraint(mut self, constraint: Constraint) -> Self {
571        self.constraint = constraint;
572        self
573    }
574
575    /// Set the sampling parameters as given.
576    pub fn set_sampling(mut self, params: SamplingParams) -> Self {
577        self.sampling_params = params;
578        self
579    }
580
581    /// Set the sampling parameters for deterministic generation.
582    /// This sets up the parameters so that there is:
583    /// - No temperature, topk, topp, minp
584    /// - No penalties, stop tokens, or logit bias
585    /// - No maximum length
586    pub fn set_deterministic_sampler(mut self) -> Self {
587        self.sampling_params = SamplingParams::deterministic();
588        self
589    }
590
591    pub fn set_sampler_temperature(mut self, temperature: f64) -> Self {
592        self.sampling_params.temperature = Some(temperature);
593        self
594    }
595
596    pub fn set_sampler_topk(mut self, topk: usize) -> Self {
597        self.sampling_params.top_k = Some(topk);
598        self
599    }
600
601    pub fn set_sampler_topp(mut self, topp: f64) -> Self {
602        self.sampling_params.top_p = Some(topp);
603        self
604    }
605
606    pub fn set_sampler_minp(mut self, minp: f64) -> Self {
607        self.sampling_params.min_p = Some(minp);
608        self
609    }
610
611    pub fn set_sampler_topn_logprobs(mut self, top_n_logprobs: usize) -> Self {
612        self.sampling_params.top_n_logprobs = top_n_logprobs;
613        self
614    }
615
616    pub fn set_sampler_frequency_penalty(mut self, frequency_penalty: f32) -> Self {
617        self.sampling_params.frequency_penalty = Some(frequency_penalty);
618        self
619    }
620
621    pub fn set_sampler_presence_penalty(mut self, presence_penalty: f32) -> Self {
622        self.sampling_params.presence_penalty = Some(presence_penalty);
623        self
624    }
625
626    pub fn set_sampler_stop_toks(mut self, stop_toks: StopTokens) -> Self {
627        self.sampling_params.stop_toks = Some(stop_toks);
628        self
629    }
630
631    pub fn set_sampler_max_len(mut self, max_len: usize) -> Self {
632        self.sampling_params.max_len = Some(max_len);
633        self
634    }
635
636    pub fn set_sampler_logits_bias(mut self, logits_bias: HashMap<u32, f32>) -> Self {
637        self.sampling_params.logits_bias = Some(logits_bias);
638        self
639    }
640
641    pub fn set_sampler_n_choices(mut self, n_choices: usize) -> Self {
642        self.sampling_params.n_choices = n_choices;
643        self
644    }
645
646    pub fn set_sampler_dry_params(mut self, dry_params: DrySamplingParams) -> Self {
647        self.sampling_params.dry_params = Some(dry_params);
648        self
649    }
650
651    pub fn enable_thinking(mut self, enable_thinking: bool) -> Self {
652        self.enable_thinking = Some(enable_thinking);
653        self
654    }
655}
656
657impl RequestLike for RequestBuilder {
658    fn messages_ref(&self) -> &[IndexMap<String, MessageContent>] {
659        &self.messages
660    }
661
662    fn images_ref(&self) -> &[DynamicImage] {
663        &self.images
664    }
665
666    fn take_messages(&mut self) -> RequestMessage {
667        if self.images.is_empty() && self.audios.is_empty() {
668            let mut other = Vec::new();
669            std::mem::swap(&mut other, &mut self.messages);
670            RequestMessage::Chat {
671                messages: other,
672                enable_thinking: self.enable_thinking,
673            }
674        } else {
675            let mut other_messages = Vec::new();
676            std::mem::swap(&mut other_messages, &mut self.messages);
677            let mut other_images = Vec::new();
678            std::mem::swap(&mut other_images, &mut self.images);
679            let mut other_audios = Vec::new();
680            std::mem::swap(&mut other_audios, &mut self.audios);
681            RequestMessage::VisionChat {
682                images: other_images,
683                messages: other_messages,
684                audios: other_audios,
685                enable_thinking: self.enable_thinking,
686            }
687        }
688    }
689
690    fn enable_search(&self) -> Option<bool> {
691        self.web_search_options.as_ref().map(|_| true)
692    }
693
694    fn take_logits_processors(&mut self) -> Option<Vec<Arc<dyn CustomLogitsProcessor>>> {
695        if self.logits_processors.is_empty() {
696            None
697        } else {
698            let mut other = Vec::new();
699            std::mem::swap(&mut other, &mut self.logits_processors);
700            Some(other)
701        }
702    }
703
704    fn take_adapters(&mut self) -> Option<Vec<String>> {
705        if self.adapters.is_empty() {
706            None
707        } else {
708            let mut other = Vec::new();
709            std::mem::swap(&mut other, &mut self.adapters);
710            Some(other)
711        }
712    }
713
714    fn return_logprobs(&self) -> bool {
715        self.return_logprobs
716    }
717
718    fn take_constraint(&mut self) -> Constraint {
719        let mut other = Constraint::None;
720        std::mem::swap(&mut other, &mut self.constraint);
721        other
722    }
723
724    fn take_tools(&mut self) -> Option<(Vec<Tool>, ToolChoice)> {
725        if self.tools.is_empty() {
726            None
727        } else {
728            let mut other_ts = Vec::new();
729            std::mem::swap(&mut other_ts, &mut self.tools);
730            let mut other_tc = ToolChoice::Auto;
731            std::mem::swap(&mut other_tc, &mut self.tool_choice);
732            Some((other_ts, other_tc))
733        }
734    }
735
736    fn take_sampling_params(&mut self) -> SamplingParams {
737        let mut other = SamplingParams::deterministic();
738        std::mem::swap(&mut other, &mut self.sampling_params);
739        other
740    }
741
742    fn take_web_search_options(&mut self) -> Option<WebSearchOptions> {
743        let mut other = None;
744        std::mem::swap(&mut other, &mut self.web_search_options);
745        other
746    }
747}