mistralrs/
messages.rs

1use std::{collections::HashMap, fmt::Display, sync::Arc};
2
3use super::*;
4use either::Either;
5use image::DynamicImage;
6use indexmap::IndexMap;
7use serde_json::{json, Value};
8
9/// A type which can be used as a chat request.
10pub trait RequestLike {
11    fn messages_ref(&self) -> &[IndexMap<String, MessageContent>];
12    fn images_ref(&self) -> &[DynamicImage];
13    fn take_messages(&mut self) -> RequestMessage;
14    fn take_logits_processors(&mut self) -> Option<Vec<Arc<dyn CustomLogitsProcessor>>>;
15    fn take_adapters(&mut self) -> Option<Vec<String>>;
16    fn return_logprobs(&self) -> bool;
17    fn enable_search(&self) -> Option<bool>;
18    fn take_constraint(&mut self) -> Constraint;
19    fn take_tools(&mut self) -> Option<(Vec<Tool>, ToolChoice)>;
20    fn take_sampling_params(&mut self) -> SamplingParams;
21    fn take_web_search_options(&mut self) -> Option<WebSearchOptions>;
22}
23
24#[derive(Debug, Clone, PartialEq)]
25/// Plain text (chat) messages.
26///
27/// No constraints, logits processors, logprobs, tools, or adapters.
28///
29/// Sampling is deterministic.
30pub struct TextMessages(Vec<IndexMap<String, MessageContent>>);
31
32impl From<TextMessages> for Vec<IndexMap<String, MessageContent>> {
33    fn from(value: TextMessages) -> Self {
34        value.0
35    }
36}
37
38#[derive(Debug, Clone, PartialEq)]
39/// A chat message role.
40pub enum TextMessageRole {
41    User,
42    Assistant,
43    System,
44    Tool,
45    Custom(String),
46}
47
48impl Display for TextMessageRole {
49    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
50        match self {
51            Self::User => write!(f, "user"),
52            Self::Assistant => write!(f, "assistant"),
53            Self::System => write!(f, "system"),
54            Self::Tool => write!(f, "tool"),
55            Self::Custom(c) => write!(f, "{c}"),
56        }
57    }
58}
59
60impl Default for TextMessages {
61    fn default() -> Self {
62        Self::new()
63    }
64}
65
66impl TextMessages {
67    pub fn new() -> Self {
68        Self(Vec::new())
69    }
70
71    pub fn add_message(mut self, role: TextMessageRole, text: impl ToString) -> Self {
72        self.0.push(IndexMap::from([
73            ("role".to_string(), Either::Left(role.to_string())),
74            ("content".to_string(), Either::Left(text.to_string())),
75        ]));
76        self
77    }
78
79    pub fn clear(mut self) -> Self {
80        self.0.clear();
81        self
82    }
83}
84
85impl RequestLike for TextMessages {
86    fn messages_ref(&self) -> &[IndexMap<String, MessageContent>] {
87        &self.0
88    }
89    fn images_ref(&self) -> &[DynamicImage] {
90        &[]
91    }
92    fn take_messages(&mut self) -> RequestMessage {
93        let mut other = Vec::new();
94        std::mem::swap(&mut other, &mut self.0);
95        RequestMessage::Chat {
96            messages: other,
97            enable_thinking: self.enable_search(),
98        }
99    }
100    fn enable_search(&self) -> Option<bool> {
101        None
102    }
103    fn take_logits_processors(&mut self) -> Option<Vec<Arc<dyn CustomLogitsProcessor>>> {
104        None
105    }
106    fn take_adapters(&mut self) -> Option<Vec<String>> {
107        None
108    }
109    fn return_logprobs(&self) -> bool {
110        false
111    }
112    fn take_constraint(&mut self) -> Constraint {
113        Constraint::None
114    }
115    fn take_tools(&mut self) -> Option<(Vec<Tool>, ToolChoice)> {
116        None
117    }
118    fn take_sampling_params(&mut self) -> SamplingParams {
119        SamplingParams::deterministic()
120    }
121    fn take_web_search_options(&mut self) -> Option<WebSearchOptions> {
122        None
123    }
124}
125
126#[derive(Debug, Clone, PartialEq)]
127/// Text (chat) messages with images and/or audios.
128///
129/// No constraints, logits processors, logprobs, tools, or adapters.
130///
131/// Sampling is deterministic.
132pub struct VisionMessages {
133    messages: Vec<IndexMap<String, MessageContent>>,
134    images: Vec<DynamicImage>,
135    audios: Vec<AudioInput>,
136}
137
138impl Default for VisionMessages {
139    fn default() -> Self {
140        Self::new()
141    }
142}
143
144impl VisionMessages {
145    pub fn new() -> Self {
146        Self {
147            images: Vec::new(),
148            messages: Vec::new(),
149            audios: Vec::new(),
150        }
151    }
152
153    pub fn add_message(mut self, role: TextMessageRole, text: impl ToString) -> Self {
154        self.messages.push(IndexMap::from([
155            ("role".to_string(), Either::Left(role.to_string())),
156            ("content".to_string(), Either::Left(text.to_string())),
157        ]));
158        self
159    }
160
161    pub fn add_image_message(
162        self,
163        role: TextMessageRole,
164        text: impl ToString,
165        images: Vec<DynamicImage>,
166        model: &Model,
167    ) -> anyhow::Result<Self> {
168        self.add_multimodal_message(role, text, images, vec![], model)
169    }
170
171    pub fn add_audio_message(
172        self,
173        role: TextMessageRole,
174        text: impl ToString,
175        audios: Vec<AudioInput>,
176        model: &Model,
177    ) -> anyhow::Result<Self> {
178        self.add_multimodal_message(role, text, vec![], audios, model)
179    }
180
181    pub fn add_multimodal_message(
182        mut self,
183        role: TextMessageRole,
184        text: impl ToString,
185        images: Vec<DynamicImage>,
186        audios: Vec<AudioInput>,
187        model: &Model,
188    ) -> anyhow::Result<Self> {
189        let prefixer = match &model.config().category {
190            ModelCategory::Vision { prefixer } => prefixer,
191            ModelCategory::Text
192            | ModelCategory::Diffusion
193            | ModelCategory::Speech
194            | ModelCategory::Audio => {
195                anyhow::bail!("`add_image_message` expects a vision model.")
196            }
197        };
198
199        // Images
200        let n_added_images = images.len();
201        let prefixed = prefixer.prefix_image(
202            (self.images.len()..self.images.len() + n_added_images).collect(),
203            &text.to_string(),
204        );
205        self.images.extend(images);
206
207        // Audios
208        let n_added_audios = audios.len();
209        let prefixed = prefixer.prefix_audio(
210            (self.audios.len()..self.audios.len() + n_added_audios).collect(),
211            &prefixed,
212        );
213        self.audios.extend(audios);
214
215        if n_added_images > 0 {
216            self.messages.push(IndexMap::from([
217                ("role".to_string(), Either::Left(role.to_string())),
218                (
219                    "content".to_string(),
220                    Either::Right(vec![
221                        IndexMap::from([("type".to_string(), Value::String("image".to_string()))]),
222                        IndexMap::from([
223                            ("type".to_string(), Value::String("text".to_string())),
224                            ("text".to_string(), Value::String(prefixed)),
225                        ]),
226                    ]),
227                ),
228            ]));
229        } else {
230            self.messages.push(IndexMap::from([
231                ("role".to_string(), Either::Left(role.to_string())),
232                ("content".to_string(), Either::Left(prefixed)),
233            ]));
234        }
235        Ok(self)
236    }
237
238    pub fn clear(mut self) -> Self {
239        self.messages.clear();
240        self.images.clear();
241
242        self
243    }
244}
245
246impl RequestLike for VisionMessages {
247    fn messages_ref(&self) -> &[IndexMap<String, MessageContent>] {
248        &self.messages
249    }
250    fn images_ref(&self) -> &[DynamicImage] {
251        &self.images
252    }
253    fn take_messages(&mut self) -> RequestMessage {
254        let mut other_messages = Vec::new();
255        std::mem::swap(&mut other_messages, &mut self.messages);
256        let mut other_images = Vec::new();
257        std::mem::swap(&mut other_images, &mut self.images);
258        let mut other_audios = Vec::new();
259        std::mem::swap(&mut other_audios, &mut self.audios);
260        RequestMessage::VisionChat {
261            images: other_images,
262            messages: other_messages,
263            audios: other_audios,
264            enable_thinking: self.enable_search(),
265        }
266    }
267    fn enable_search(&self) -> Option<bool> {
268        None
269    }
270    fn take_logits_processors(&mut self) -> Option<Vec<Arc<dyn CustomLogitsProcessor>>> {
271        None
272    }
273    fn take_adapters(&mut self) -> Option<Vec<String>> {
274        None
275    }
276    fn return_logprobs(&self) -> bool {
277        false
278    }
279    fn take_constraint(&mut self) -> Constraint {
280        Constraint::None
281    }
282    fn take_tools(&mut self) -> Option<(Vec<Tool>, ToolChoice)> {
283        None
284    }
285    fn take_sampling_params(&mut self) -> SamplingParams {
286        SamplingParams::deterministic()
287    }
288    fn take_web_search_options(&mut self) -> Option<WebSearchOptions> {
289        None
290    }
291}
292
293#[derive(Clone)]
294/// A way to add messages with finer control given.
295///
296/// This includes control over:
297/// - Logits processors
298/// - Constraints
299/// - Logprobs
300/// - Tools
301/// - Sampling
302/// - Enable thinking for models that support the configuration
303pub struct RequestBuilder {
304    messages: Vec<IndexMap<String, MessageContent>>,
305    images: Vec<DynamicImage>,
306    audios: Vec<AudioInput>,
307    logits_processors: Vec<Arc<dyn CustomLogitsProcessor>>,
308    adapters: Vec<String>,
309    return_logprobs: bool,
310    constraint: Constraint,
311    tools: Vec<Tool>,
312    tool_choice: ToolChoice,
313    sampling_params: SamplingParams,
314    web_search_options: Option<WebSearchOptions>,
315    enable_thinking: Option<bool>,
316}
317
318impl Default for RequestBuilder {
319    fn default() -> Self {
320        Self::new()
321    }
322}
323
324impl From<TextMessages> for RequestBuilder {
325    fn from(value: TextMessages) -> Self {
326        Self {
327            messages: value.0,
328            images: Vec::new(),
329            audios: Vec::new(),
330            logits_processors: Vec::new(),
331            adapters: Vec::new(),
332            return_logprobs: false,
333            constraint: Constraint::None,
334            tools: Vec::new(),
335            tool_choice: ToolChoice::Auto,
336            sampling_params: SamplingParams::deterministic(),
337            web_search_options: None,
338            enable_thinking: None,
339        }
340    }
341}
342
343impl From<VisionMessages> for RequestBuilder {
344    fn from(value: VisionMessages) -> Self {
345        Self {
346            messages: value.messages,
347            images: value.images,
348            audios: value.audios,
349            logits_processors: Vec::new(),
350            adapters: Vec::new(),
351            return_logprobs: false,
352            constraint: Constraint::None,
353            tools: Vec::new(),
354            tool_choice: ToolChoice::Auto,
355            sampling_params: SamplingParams::deterministic(),
356            web_search_options: None,
357            enable_thinking: None,
358        }
359    }
360}
361
362impl RequestBuilder {
363    pub fn new() -> Self {
364        Self {
365            messages: Vec::new(),
366            images: Vec::new(),
367            audios: Vec::new(),
368            logits_processors: Vec::new(),
369            adapters: Vec::new(),
370            return_logprobs: false,
371            constraint: Constraint::None,
372            tools: Vec::new(),
373            tool_choice: ToolChoice::Auto,
374            sampling_params: SamplingParams::deterministic(),
375            web_search_options: None,
376            enable_thinking: None,
377        }
378    }
379
380    pub fn with_web_search_options(mut self, web_search_options: WebSearchOptions) -> Self {
381        self.web_search_options = Some(web_search_options);
382        self
383    }
384
385    /// Add a message to the request.
386    ///
387    /// For messages with tool calls, use [`Self::add_message_with_tool_call`].
388    /// For messages with tool outputs, use [`Self::add_tool_message`].
389    pub fn add_message(mut self, role: TextMessageRole, text: impl ToString) -> Self {
390        self.messages.push(IndexMap::from([
391            ("role".to_string(), Either::Left(role.to_string())),
392            ("content".to_string(), Either::Left(text.to_string())),
393        ]));
394        self
395    }
396
397    /// Add a message with the output of a tool call.
398    pub fn add_tool_message(mut self, tool_content: impl ToString, tool_id: impl ToString) -> Self {
399        self.messages.push(IndexMap::from([
400            (
401                "role".to_string(),
402                Either::Left(TextMessageRole::Tool.to_string()),
403            ),
404            (
405                "content".to_string(),
406                Either::Left(tool_content.to_string()),
407            ),
408            (
409                "tool_call_id".to_string(),
410                Either::Left(tool_id.to_string()),
411            ),
412        ]));
413        self
414    }
415
416    pub fn add_message_with_tool_call(
417        mut self,
418        role: TextMessageRole,
419        text: impl ToString,
420        tool_calls: Vec<ToolCallResponse>,
421    ) -> Self {
422        let tool_messages = tool_calls
423            .iter()
424            .map(|t| {
425                IndexMap::from([
426                    ("id".to_string(), Value::String(t.id.clone())),
427                    ("type".to_string(), Value::String(t.tp.to_string())),
428                    (
429                        "function".to_string(),
430                        json!({
431                            "name": t.function.name,
432                            "arguments": t.function.arguments,
433                        }),
434                    ),
435                ])
436            })
437            .collect();
438        self.messages.push(IndexMap::from([
439            ("role".to_string(), Either::Left(role.to_string())),
440            ("content".to_string(), Either::Left(text.to_string())),
441            ("function".to_string(), Either::Right(tool_messages)),
442        ]));
443        self
444    }
445
446    pub fn add_image_message(
447        self,
448        role: TextMessageRole,
449        text: impl ToString,
450        images: Vec<DynamicImage>,
451        model: &Model,
452    ) -> anyhow::Result<Self> {
453        self.add_multimodal_message(role, text, images, vec![], model)
454    }
455
456    pub fn add_audio_message(
457        self,
458        role: TextMessageRole,
459        text: impl ToString,
460        audios: Vec<AudioInput>,
461        model: &Model,
462    ) -> anyhow::Result<Self> {
463        self.add_multimodal_message(role, text, vec![], audios, model)
464    }
465
466    pub fn add_multimodal_message(
467        mut self,
468        role: TextMessageRole,
469        text: impl ToString,
470        images: Vec<DynamicImage>,
471        audios: Vec<AudioInput>,
472        model: &Model,
473    ) -> anyhow::Result<Self> {
474        let prefixer = match &model.config().category {
475            ModelCategory::Vision { prefixer } => prefixer,
476            ModelCategory::Text
477            | ModelCategory::Diffusion
478            | ModelCategory::Speech
479            | ModelCategory::Audio => {
480                anyhow::bail!("`add_image_message` expects a vision model.")
481            }
482        };
483
484        // Images
485        let n_added_images = images.len();
486        let prefixed = prefixer.prefix_image(
487            (self.images.len()..self.images.len() + n_added_images).collect(),
488            &text.to_string(),
489        );
490        self.images.extend(images);
491
492        // Audios
493        let n_added_audios = audios.len();
494        let prefixed = prefixer.prefix_audio(
495            (self.audios.len()..self.audios.len() + n_added_audios).collect(),
496            &prefixed,
497        );
498        self.audios.extend(audios);
499
500        if n_added_images > 0 {
501            self.messages.push(IndexMap::from([
502                ("role".to_string(), Either::Left(role.to_string())),
503                (
504                    "content".to_string(),
505                    Either::Right(vec![
506                        IndexMap::from([("type".to_string(), Value::String("image".to_string()))]),
507                        IndexMap::from([
508                            ("type".to_string(), Value::String("text".to_string())),
509                            ("text".to_string(), Value::String(prefixed)),
510                        ]),
511                    ]),
512                ),
513            ]));
514        } else {
515            self.messages.push(IndexMap::from([
516                ("role".to_string(), Either::Left(role.to_string())),
517                ("content".to_string(), Either::Left(prefixed)),
518            ]));
519        }
520        Ok(self)
521    }
522
523    pub fn add_logits_processor(mut self, processor: Arc<dyn CustomLogitsProcessor>) -> Self {
524        self.logits_processors.push(processor);
525        self
526    }
527
528    pub fn set_adapters(mut self, adapters: Vec<String>) -> Self {
529        self.adapters = adapters;
530        self
531    }
532
533    /// The default tool choice is auto.
534    pub fn set_tools(mut self, tools: Vec<Tool>) -> Self {
535        self.tools = tools;
536        self
537    }
538
539    pub fn set_tool_choice(mut self, tool_choice: ToolChoice) -> Self {
540        self.tool_choice = tool_choice;
541        self
542    }
543
544    pub fn return_logprobs(mut self, return_logprobs: bool) -> Self {
545        self.return_logprobs = return_logprobs;
546        self
547    }
548
549    pub fn set_constraint(mut self, constraint: Constraint) -> Self {
550        self.constraint = constraint;
551        self
552    }
553
554    /// Set the sampling parameters as given.
555    pub fn set_sampling(mut self, params: SamplingParams) -> Self {
556        self.sampling_params = params;
557        self
558    }
559
560    /// Set the sampling parameters for deterministic generation.
561    /// This sets up the parameters so that there is:
562    /// - No temperature, topk, topp, minp
563    /// - No penalties, stop tokens, or logit bias
564    /// - No maximum length
565    pub fn set_deterministic_sampler(mut self) -> Self {
566        self.sampling_params = SamplingParams::deterministic();
567        self
568    }
569
570    pub fn set_sampler_temperature(mut self, temperature: f64) -> Self {
571        self.sampling_params.temperature = Some(temperature);
572        self
573    }
574
575    pub fn set_sampler_topk(mut self, topk: usize) -> Self {
576        self.sampling_params.top_k = Some(topk);
577        self
578    }
579
580    pub fn set_sampler_topp(mut self, topp: f64) -> Self {
581        self.sampling_params.top_p = Some(topp);
582        self
583    }
584
585    pub fn set_sampler_minp(mut self, minp: f64) -> Self {
586        self.sampling_params.min_p = Some(minp);
587        self
588    }
589
590    pub fn set_sampler_topn_logprobs(mut self, top_n_logprobs: usize) -> Self {
591        self.sampling_params.top_n_logprobs = top_n_logprobs;
592        self
593    }
594
595    pub fn set_sampler_frequency_penalty(mut self, frequency_penalty: f32) -> Self {
596        self.sampling_params.frequency_penalty = Some(frequency_penalty);
597        self
598    }
599
600    pub fn set_sampler_presence_penalty(mut self, presence_penalty: f32) -> Self {
601        self.sampling_params.presence_penalty = Some(presence_penalty);
602        self
603    }
604
605    pub fn set_sampler_stop_toks(mut self, stop_toks: StopTokens) -> Self {
606        self.sampling_params.stop_toks = Some(stop_toks);
607        self
608    }
609
610    pub fn set_sampler_max_len(mut self, max_len: usize) -> Self {
611        self.sampling_params.max_len = Some(max_len);
612        self
613    }
614
615    pub fn set_sampler_logits_bias(mut self, logits_bias: HashMap<u32, f32>) -> Self {
616        self.sampling_params.logits_bias = Some(logits_bias);
617        self
618    }
619
620    pub fn set_sampler_n_choices(mut self, n_choices: usize) -> Self {
621        self.sampling_params.n_choices = n_choices;
622        self
623    }
624
625    pub fn set_sampler_dry_params(mut self, dry_params: DrySamplingParams) -> Self {
626        self.sampling_params.dry_params = Some(dry_params);
627        self
628    }
629
630    pub fn enable_thinking(mut self, enable_thinking: bool) -> Self {
631        self.enable_thinking = Some(enable_thinking);
632        self
633    }
634}
635
636impl RequestLike for RequestBuilder {
637    fn messages_ref(&self) -> &[IndexMap<String, MessageContent>] {
638        &self.messages
639    }
640
641    fn images_ref(&self) -> &[DynamicImage] {
642        &self.images
643    }
644
645    fn take_messages(&mut self) -> RequestMessage {
646        if self.images.is_empty() && self.audios.is_empty() {
647            let mut other = Vec::new();
648            std::mem::swap(&mut other, &mut self.messages);
649            RequestMessage::Chat {
650                messages: other,
651                enable_thinking: self.enable_thinking,
652            }
653        } else {
654            let mut other_messages = Vec::new();
655            std::mem::swap(&mut other_messages, &mut self.messages);
656            let mut other_images = Vec::new();
657            std::mem::swap(&mut other_images, &mut self.images);
658            let mut other_audios = Vec::new();
659            std::mem::swap(&mut other_audios, &mut self.audios);
660            RequestMessage::VisionChat {
661                images: other_images,
662                messages: other_messages,
663                audios: other_audios,
664                enable_thinking: self.enable_thinking,
665            }
666        }
667    }
668
669    fn enable_search(&self) -> Option<bool> {
670        self.enable_thinking
671    }
672
673    fn take_logits_processors(&mut self) -> Option<Vec<Arc<dyn CustomLogitsProcessor>>> {
674        if self.logits_processors.is_empty() {
675            None
676        } else {
677            let mut other = Vec::new();
678            std::mem::swap(&mut other, &mut self.logits_processors);
679            Some(other)
680        }
681    }
682
683    fn take_adapters(&mut self) -> Option<Vec<String>> {
684        if self.adapters.is_empty() {
685            None
686        } else {
687            let mut other = Vec::new();
688            std::mem::swap(&mut other, &mut self.adapters);
689            Some(other)
690        }
691    }
692
693    fn return_logprobs(&self) -> bool {
694        self.return_logprobs
695    }
696
697    fn take_constraint(&mut self) -> Constraint {
698        let mut other = Constraint::None;
699        std::mem::swap(&mut other, &mut self.constraint);
700        other
701    }
702
703    fn take_tools(&mut self) -> Option<(Vec<Tool>, ToolChoice)> {
704        if self.tools.is_empty() {
705            None
706        } else {
707            let mut other_ts = Vec::new();
708            std::mem::swap(&mut other_ts, &mut self.tools);
709            let mut other_tc = ToolChoice::Auto;
710            std::mem::swap(&mut other_tc, &mut self.tool_choice);
711            Some((other_ts, other_tc))
712        }
713    }
714
715    fn take_sampling_params(&mut self) -> SamplingParams {
716        let mut other = SamplingParams::deterministic();
717        std::mem::swap(&mut other, &mut self.sampling_params);
718        other
719    }
720
721    fn take_web_search_options(&mut self) -> Option<WebSearchOptions> {
722        let mut other = None;
723        std::mem::swap(&mut other, &mut self.web_search_options);
724        other
725    }
726}