mistralrs/
messages.rs

1use std::{collections::HashMap, fmt::Display, sync::Arc};
2
3use super::*;
4use either::Either;
5use image::DynamicImage;
6use indexmap::IndexMap;
7use serde_json::{json, Value};
8
/// A type which can be used as a chat request.
///
/// The `take_*` methods receive `&mut self`. The implementations in this file
/// move the stored value out and leave an empty/default value behind, so a
/// second call on the same value yields that empty/default value.
pub trait RequestLike {
    /// Borrow the chat messages currently held by the request.
    fn messages_ref(&self) -> &[IndexMap<String, MessageContent>];
    /// Take the messages (and any associated images) as a [`RequestMessage`].
    fn take_messages(&mut self) -> RequestMessage;
    /// Take the custom logits processors, or `None` if there are none.
    fn take_logits_processors(&mut self) -> Option<Vec<Arc<dyn CustomLogitsProcessor>>>;
    /// Take the adapter names, or `None` if there are none.
    fn take_adapters(&mut self) -> Option<Vec<String>>;
    /// Whether the request asks for logprobs to be returned.
    fn return_logprobs(&self) -> bool;
    /// Take the generation constraint.
    fn take_constraint(&mut self) -> Constraint;
    /// Take the tools together with the tool choice, or `None` if there are no tools.
    fn take_tools(&mut self) -> Option<(Vec<Tool>, ToolChoice)>;
    /// Take the sampling parameters.
    fn take_sampling_params(&mut self) -> SamplingParams;
    /// Take the web search options, if any were set.
    fn take_web_search_options(&mut self) -> Option<WebSearchOptions>;
}
21
#[derive(Debug, Clone, PartialEq)]
/// Plain text (chat) messages.
///
/// Wraps the ordered list of messages. Messages added through `add_message`
/// hold a `"role"` and a `"content"` entry.
///
/// When used as a request there are no constraints, logits processors,
/// logprobs, tools, or adapters.
///
/// Sampling is deterministic.
pub struct TextMessages(Vec<IndexMap<String, MessageContent>>);
29
30impl From<TextMessages> for Vec<IndexMap<String, MessageContent>> {
31    fn from(value: TextMessages) -> Self {
32        value.0
33    }
34}
35
/// A chat message role.
///
/// `Custom` carries an arbitrary role name for roles not covered by the
/// predefined variants.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum TextMessageRole {
    /// The `user` role.
    User,
    /// The `assistant` role.
    Assistant,
    /// The `system` role.
    System,
    /// The `tool` role, used for tool outputs.
    Tool,
    /// Any other role, identified by the given name.
    Custom(String),
}
44
45impl Display for TextMessageRole {
46    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
47        match self {
48            Self::User => write!(f, "user"),
49            Self::Assistant => write!(f, "assistant"),
50            Self::System => write!(f, "system"),
51            Self::Tool => write!(f, "tool"),
52            Self::Custom(c) => write!(f, "{c}"),
53        }
54    }
55}
56
57impl Default for TextMessages {
58    fn default() -> Self {
59        Self::new()
60    }
61}
62
63impl TextMessages {
64    pub fn new() -> Self {
65        Self(Vec::new())
66    }
67
68    pub fn add_message(mut self, role: TextMessageRole, text: impl ToString) -> Self {
69        self.0.push(IndexMap::from([
70            ("role".to_string(), Either::Left(role.to_string())),
71            ("content".to_string(), Either::Left(text.to_string())),
72        ]));
73        self
74    }
75
76    pub fn clear(mut self) -> Self {
77        self.0.clear();
78        self
79    }
80}
81
82impl RequestLike for TextMessages {
83    fn messages_ref(&self) -> &[IndexMap<String, MessageContent>] {
84        &self.0
85    }
86    fn take_messages(&mut self) -> RequestMessage {
87        let mut other = Vec::new();
88        std::mem::swap(&mut other, &mut self.0);
89        RequestMessage::Chat(other)
90    }
91    fn take_logits_processors(&mut self) -> Option<Vec<Arc<dyn CustomLogitsProcessor>>> {
92        None
93    }
94    fn take_adapters(&mut self) -> Option<Vec<String>> {
95        None
96    }
97    fn return_logprobs(&self) -> bool {
98        false
99    }
100    fn take_constraint(&mut self) -> Constraint {
101        Constraint::None
102    }
103    fn take_tools(&mut self) -> Option<(Vec<Tool>, ToolChoice)> {
104        None
105    }
106    fn take_sampling_params(&mut self) -> SamplingParams {
107        SamplingParams::deterministic()
108    }
109    fn take_web_search_options(&mut self) -> Option<WebSearchOptions> {
110        None
111    }
112}
113
#[derive(Debug, Clone, PartialEq)]
/// Text (chat) messages with images.
///
/// When used as a request there are no constraints, logits processors,
/// logprobs, tools, or adapters.
///
/// Sampling is deterministic.
pub struct VisionMessages {
    /// The chat messages, in the order they were added.
    messages: Vec<IndexMap<String, MessageContent>>,
    /// The images, in the order they were added via `add_image_message`.
    images: Vec<DynamicImage>,
}
124
125impl Default for VisionMessages {
126    fn default() -> Self {
127        Self::new()
128    }
129}
130
131impl VisionMessages {
132    pub fn new() -> Self {
133        Self {
134            images: Vec::new(),
135            messages: Vec::new(),
136        }
137    }
138
139    pub fn add_message(mut self, role: TextMessageRole, text: impl ToString) -> Self {
140        self.messages.push(IndexMap::from([
141            ("role".to_string(), Either::Left(role.to_string())),
142            ("content".to_string(), Either::Left(text.to_string())),
143        ]));
144        self
145    }
146
147    pub fn add_image_message(
148        mut self,
149        role: TextMessageRole,
150        text: impl ToString,
151        image: DynamicImage,
152        model: &Model,
153    ) -> anyhow::Result<Self> {
154        let prefixer = match &model.config().category {
155            ModelCategory::Text | ModelCategory::Diffusion => {
156                anyhow::bail!("`add_image_message` expects a vision model.")
157            }
158            ModelCategory::Vision {
159                has_conv2d: _,
160                prefixer,
161            } => prefixer,
162        };
163        self.images.push(image);
164        self.messages.push(IndexMap::from([
165            ("role".to_string(), Either::Left(role.to_string())),
166            (
167                "content".to_string(),
168                Either::Right(vec![
169                    IndexMap::from([("type".to_string(), Value::String("image".to_string()))]),
170                    IndexMap::from([
171                        ("type".to_string(), Value::String("text".to_string())),
172                        (
173                            "text".to_string(),
174                            Value::String(
175                                prefixer.prefix_image(self.images.len() - 1, &text.to_string()),
176                            ),
177                        ),
178                    ]),
179                ]),
180            ),
181        ]));
182        Ok(self)
183    }
184
185    pub fn clear(mut self) -> Self {
186        self.messages.clear();
187        self.images.clear();
188
189        self
190    }
191}
192
193impl RequestLike for VisionMessages {
194    fn messages_ref(&self) -> &[IndexMap<String, MessageContent>] {
195        &self.messages
196    }
197    fn take_messages(&mut self) -> RequestMessage {
198        let mut other_messages = Vec::new();
199        std::mem::swap(&mut other_messages, &mut self.messages);
200        let mut other_images = Vec::new();
201        std::mem::swap(&mut other_images, &mut self.images);
202        RequestMessage::VisionChat {
203            images: other_images,
204            messages: other_messages,
205        }
206    }
207    fn take_logits_processors(&mut self) -> Option<Vec<Arc<dyn CustomLogitsProcessor>>> {
208        None
209    }
210    fn take_adapters(&mut self) -> Option<Vec<String>> {
211        None
212    }
213    fn return_logprobs(&self) -> bool {
214        false
215    }
216    fn take_constraint(&mut self) -> Constraint {
217        Constraint::None
218    }
219    fn take_tools(&mut self) -> Option<(Vec<Tool>, ToolChoice)> {
220        None
221    }
222    fn take_sampling_params(&mut self) -> SamplingParams {
223        SamplingParams::deterministic()
224    }
225    fn take_web_search_options(&mut self) -> Option<WebSearchOptions> {
226        None
227    }
228}
229
#[derive(Clone)]
/// A way to add messages with finer control given.
///
/// This includes control over:
/// - Logits processors
/// - Adapters
/// - Constraints
/// - Logprobs
/// - Tools
/// - Sampling
/// - Web search options
pub struct RequestBuilder {
    /// The chat messages, in the order they were added.
    messages: Vec<IndexMap<String, MessageContent>>,
    /// Images added via `add_image_message`; when non-empty the request is a vision chat.
    images: Vec<DynamicImage>,
    /// Custom logits processors added via `add_logits_processor`.
    logits_processors: Vec<Arc<dyn CustomLogitsProcessor>>,
    /// Adapter names set via `set_adapters`.
    adapters: Vec<String>,
    /// Whether logprobs are requested.
    return_logprobs: bool,
    /// The generation constraint (`Constraint::None` by default).
    constraint: Constraint,
    /// Tools set via `set_tools`.
    tools: Vec<Tool>,
    /// Tool choice handed out together with `tools` (`ToolChoice::Auto` by default).
    tool_choice: ToolChoice,
    /// Sampling parameters (deterministic by default).
    sampling_params: SamplingParams,
    /// Optional web search options.
    web_search_options: Option<WebSearchOptions>,
}
251
252impl Default for RequestBuilder {
253    fn default() -> Self {
254        Self::new()
255    }
256}
257
258impl From<TextMessages> for RequestBuilder {
259    fn from(value: TextMessages) -> Self {
260        Self {
261            messages: value.0,
262            images: Vec::new(),
263            logits_processors: Vec::new(),
264            adapters: Vec::new(),
265            return_logprobs: false,
266            constraint: Constraint::None,
267            tools: Vec::new(),
268            tool_choice: ToolChoice::Auto,
269            sampling_params: SamplingParams::deterministic(),
270            web_search_options: None,
271        }
272    }
273}
274
275impl From<VisionMessages> for RequestBuilder {
276    fn from(value: VisionMessages) -> Self {
277        Self {
278            messages: value.messages,
279            images: value.images,
280            logits_processors: Vec::new(),
281            adapters: Vec::new(),
282            return_logprobs: false,
283            constraint: Constraint::None,
284            tools: Vec::new(),
285            tool_choice: ToolChoice::Auto,
286            sampling_params: SamplingParams::deterministic(),
287            web_search_options: None,
288        }
289    }
290}
291
292impl RequestBuilder {
293    pub fn new() -> Self {
294        Self {
295            messages: Vec::new(),
296            images: Vec::new(),
297            logits_processors: Vec::new(),
298            adapters: Vec::new(),
299            return_logprobs: false,
300            constraint: Constraint::None,
301            tools: Vec::new(),
302            tool_choice: ToolChoice::Auto,
303            sampling_params: SamplingParams::deterministic(),
304            web_search_options: None,
305        }
306    }
307
308    pub fn with_web_search_options(mut self, web_search_options: WebSearchOptions) -> Self {
309        self.web_search_options = Some(web_search_options);
310        self
311    }
312
313    /// Add a message to the request.
314    ///
315    /// For messages with tool calls, use [`Self::add_message_with_tool_call`].
316    /// For messages with tool outputs, use [`Self::add_tool_message`].
317    pub fn add_message(mut self, role: TextMessageRole, text: impl ToString) -> Self {
318        self.messages.push(IndexMap::from([
319            ("role".to_string(), Either::Left(role.to_string())),
320            ("content".to_string(), Either::Left(text.to_string())),
321        ]));
322        self
323    }
324
325    /// Add a message with the output of a tool call.
326    pub fn add_tool_message(mut self, tool_content: impl ToString, tool_id: impl ToString) -> Self {
327        self.messages.push(IndexMap::from([
328            (
329                "role".to_string(),
330                Either::Left(TextMessageRole::Tool.to_string()),
331            ),
332            (
333                "content".to_string(),
334                Either::Left(tool_content.to_string()),
335            ),
336            (
337                "tool_call_id".to_string(),
338                Either::Left(tool_id.to_string()),
339            ),
340        ]));
341        self
342    }
343
344    pub fn add_message_with_tool_call(
345        mut self,
346        role: TextMessageRole,
347        text: impl ToString,
348        tool_calls: Vec<ToolCallResponse>,
349    ) -> Self {
350        let tool_messages = tool_calls
351            .iter()
352            .map(|t| {
353                IndexMap::from([
354                    ("id".to_string(), Value::String(t.id.clone())),
355                    ("type".to_string(), Value::String(t.tp.to_string())),
356                    (
357                        "function".to_string(),
358                        json!({
359                            "name": t.function.name,
360                            "arguments": t.function.arguments,
361                        }),
362                    ),
363                ])
364            })
365            .collect();
366        self.messages.push(IndexMap::from([
367            ("role".to_string(), Either::Left(role.to_string())),
368            ("content".to_string(), Either::Left(text.to_string())),
369            ("function".to_string(), Either::Right(tool_messages)),
370        ]));
371        self
372    }
373
374    pub fn add_image_message(
375        mut self,
376        role: TextMessageRole,
377        text: impl ToString,
378        image: DynamicImage,
379    ) -> Self {
380        self.messages.push(IndexMap::from([
381            ("role".to_string(), Either::Left(role.to_string())),
382            ("content".to_string(), Either::Left(text.to_string())),
383        ]));
384        self.images.push(image);
385        self
386    }
387
388    pub fn add_logits_processor(mut self, processor: Arc<dyn CustomLogitsProcessor>) -> Self {
389        self.logits_processors.push(processor);
390        self
391    }
392
393    pub fn set_adapters(mut self, adapters: Vec<String>) -> Self {
394        self.adapters = adapters;
395        self
396    }
397
398    /// The default tool choice is auto.
399    pub fn set_tools(mut self, tools: Vec<Tool>) -> Self {
400        self.tools = tools;
401        self
402    }
403
404    pub fn set_tool_choice(mut self, tool_choice: ToolChoice) -> Self {
405        self.tool_choice = tool_choice;
406        self
407    }
408
409    pub fn return_logprobs(mut self, return_logprobs: bool) -> Self {
410        self.return_logprobs = return_logprobs;
411        self
412    }
413
414    pub fn set_constraint(mut self, constraint: Constraint) -> Self {
415        self.constraint = constraint;
416        self
417    }
418
419    /// Set the sampling parameters as given.
420    pub fn set_sampling(mut self, params: SamplingParams) -> Self {
421        self.sampling_params = params;
422        self
423    }
424
425    /// Set the sampling parameters for deterministic generation.
426    /// This sets up the parameters so that there is:
427    /// - No temperature, topk, topp, minp
428    /// - No penalties, stop tokens, or logit bias
429    /// - No maximum length
430    pub fn set_deterministic_sampler(mut self) -> Self {
431        self.sampling_params = SamplingParams::deterministic();
432        self
433    }
434
435    pub fn set_sampler_temperature(mut self, temperature: f64) -> Self {
436        self.sampling_params.temperature = Some(temperature);
437        self
438    }
439
440    pub fn set_sampler_topk(mut self, topk: usize) -> Self {
441        self.sampling_params.top_k = Some(topk);
442        self
443    }
444
445    pub fn set_sampler_topp(mut self, topp: f64) -> Self {
446        self.sampling_params.top_p = Some(topp);
447        self
448    }
449
450    pub fn set_sampler_minp(mut self, minp: f64) -> Self {
451        self.sampling_params.min_p = Some(minp);
452        self
453    }
454
455    pub fn set_sampler_topn_logprobs(mut self, top_n_logprobs: usize) -> Self {
456        self.sampling_params.top_n_logprobs = top_n_logprobs;
457        self
458    }
459
460    pub fn set_sampler_frequency_penalty(mut self, frequency_penalty: f32) -> Self {
461        self.sampling_params.frequency_penalty = Some(frequency_penalty);
462        self
463    }
464
465    pub fn set_sampler_presence_penalty(mut self, presence_penalty: f32) -> Self {
466        self.sampling_params.presence_penalty = Some(presence_penalty);
467        self
468    }
469
470    pub fn set_sampler_stop_toks(mut self, stop_toks: StopTokens) -> Self {
471        self.sampling_params.stop_toks = Some(stop_toks);
472        self
473    }
474
475    pub fn set_sampler_max_len(mut self, max_len: usize) -> Self {
476        self.sampling_params.max_len = Some(max_len);
477        self
478    }
479
480    pub fn set_sampler_logits_bias(mut self, logits_bias: HashMap<u32, f32>) -> Self {
481        self.sampling_params.logits_bias = Some(logits_bias);
482        self
483    }
484
485    pub fn set_sampler_n_choices(mut self, n_choices: usize) -> Self {
486        self.sampling_params.n_choices = n_choices;
487        self
488    }
489
490    pub fn set_sampler_dry_params(mut self, dry_params: DrySamplingParams) -> Self {
491        self.sampling_params.dry_params = Some(dry_params);
492        self
493    }
494}
495
496impl RequestLike for RequestBuilder {
497    fn messages_ref(&self) -> &[IndexMap<String, MessageContent>] {
498        &self.messages
499    }
500
501    fn take_messages(&mut self) -> RequestMessage {
502        if self.images.is_empty() {
503            let mut other = Vec::new();
504            std::mem::swap(&mut other, &mut self.messages);
505            RequestMessage::Chat(other)
506        } else {
507            let mut other_messages = Vec::new();
508            std::mem::swap(&mut other_messages, &mut self.messages);
509            let mut other_images = Vec::new();
510            std::mem::swap(&mut other_images, &mut self.images);
511            RequestMessage::VisionChat {
512                images: other_images,
513                messages: other_messages,
514            }
515        }
516    }
517
518    fn take_logits_processors(&mut self) -> Option<Vec<Arc<dyn CustomLogitsProcessor>>> {
519        if self.logits_processors.is_empty() {
520            None
521        } else {
522            let mut other = Vec::new();
523            std::mem::swap(&mut other, &mut self.logits_processors);
524            Some(other)
525        }
526    }
527
528    fn take_adapters(&mut self) -> Option<Vec<String>> {
529        if self.adapters.is_empty() {
530            None
531        } else {
532            let mut other = Vec::new();
533            std::mem::swap(&mut other, &mut self.adapters);
534            Some(other)
535        }
536    }
537
538    fn return_logprobs(&self) -> bool {
539        self.return_logprobs
540    }
541
542    fn take_constraint(&mut self) -> Constraint {
543        let mut other = Constraint::None;
544        std::mem::swap(&mut other, &mut self.constraint);
545        other
546    }
547
548    fn take_tools(&mut self) -> Option<(Vec<Tool>, ToolChoice)> {
549        if self.tools.is_empty() {
550            None
551        } else {
552            let mut other_ts = Vec::new();
553            std::mem::swap(&mut other_ts, &mut self.tools);
554            let mut other_tc = ToolChoice::Auto;
555            std::mem::swap(&mut other_tc, &mut self.tool_choice);
556            Some((other_ts, other_tc))
557        }
558    }
559
560    fn take_sampling_params(&mut self) -> SamplingParams {
561        let mut other = SamplingParams::deterministic();
562        std::mem::swap(&mut other, &mut self.sampling_params);
563        other
564    }
565
566    fn take_web_search_options(&mut self) -> Option<WebSearchOptions> {
567        let mut other = None;
568        std::mem::swap(&mut other, &mut self.web_search_options);
569        other
570    }
571}