mistralrs/
messages.rs

1use std::{collections::HashMap, fmt::Display, sync::Arc};
2
3use super::*;
4use either::Either;
5use image::DynamicImage;
6use indexmap::IndexMap;
7use serde_json::{json, Value};
8
/// A type which can be used as a chat request.
pub trait RequestLike {
    /// Borrow the accumulated chat messages.
    fn messages_ref(&self) -> &[IndexMap<String, MessageContent>];
    /// Borrow any images attached to the request.
    fn images_ref(&self) -> &[DynamicImage];
    /// Consume the stored messages (and media), producing the engine-facing payload.
    fn take_messages(&mut self) -> RequestMessage;
    /// Take any custom logits processors, leaving none behind.
    fn take_logits_processors(&mut self) -> Option<Vec<Arc<dyn CustomLogitsProcessor>>>;
    /// Take the adapter names to activate, if any were set.
    fn take_adapters(&mut self) -> Option<Vec<String>>;
    /// Whether log probabilities should be returned.
    fn return_logprobs(&self) -> bool;
    /// Whether web search is enabled; `None` leaves the engine default.
    fn enable_search(&self) -> Option<bool>;
    /// Take the generation constraint for this request.
    fn take_constraint(&mut self) -> Constraint;
    /// Take the tool definitions together with the tool choice, if any tools are set.
    fn take_tools(&mut self) -> Option<(Vec<Tool>, ToolChoice)>;
    /// Take the sampling parameters for this request.
    fn take_sampling_params(&mut self) -> SamplingParams;
    /// Take the web search options, if configured.
    fn take_web_search_options(&mut self) -> Option<WebSearchOptions>;
    /// Whether prompts exceeding the model's context should be truncated.
    /// Defaults to `false`.
    fn truncate_sequence(&self) -> bool {
        false
    }
}
26
#[derive(Debug, Clone, PartialEq)]
/// Plain text (chat) messages.
///
/// No constraints, logits processors, logprobs, tools, or adapters.
///
/// Sampling is deterministic.
pub struct TextMessages {
    // Each message is an ordered map, e.g. {"role": ..., "content": ...}.
    messages: Vec<IndexMap<String, MessageContent>>,
    // `None` leaves the model's default thinking behavior untouched.
    enable_thinking: Option<bool>,
}
37
impl From<TextMessages> for Vec<IndexMap<String, MessageContent>> {
    /// Unwrap the raw message list, discarding the thinking flag.
    fn from(value: TextMessages) -> Self {
        value.messages
    }
}
43
#[derive(Debug, Clone, PartialEq)]
/// A chat message role.
pub enum TextMessageRole {
    User,
    Assistant,
    System,
    Tool,
    /// Any other role string, rendered verbatim by `Display`.
    Custom(String),
}
53
54impl Display for TextMessageRole {
55    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
56        match self {
57            Self::User => write!(f, "user"),
58            Self::Assistant => write!(f, "assistant"),
59            Self::System => write!(f, "system"),
60            Self::Tool => write!(f, "tool"),
61            Self::Custom(c) => write!(f, "{c}"),
62        }
63    }
64}
65
impl Default for TextMessages {
    /// Same as [`TextMessages::new`]: an empty message list.
    fn default() -> Self {
        Self::new()
    }
}
71
72impl TextMessages {
73    pub fn new() -> Self {
74        Self {
75            messages: Vec::new(),
76            enable_thinking: None,
77        }
78    }
79
80    pub fn add_message(mut self, role: TextMessageRole, text: impl ToString) -> Self {
81        self.messages.push(IndexMap::from([
82            ("role".to_string(), Either::Left(role.to_string())),
83            ("content".to_string(), Either::Left(text.to_string())),
84        ]));
85        self
86    }
87
88    pub fn clear(mut self) -> Self {
89        self.messages.clear();
90        self
91    }
92
93    pub fn enable_thinking(mut self, enable_thinking: bool) -> Self {
94        self.enable_thinking = Some(enable_thinking);
95        self
96    }
97}
98
99impl RequestLike for TextMessages {
100    fn messages_ref(&self) -> &[IndexMap<String, MessageContent>] {
101        &self.messages
102    }
103    fn images_ref(&self) -> &[DynamicImage] {
104        &[]
105    }
106    fn take_messages(&mut self) -> RequestMessage {
107        let mut other = Vec::new();
108        std::mem::swap(&mut other, &mut self.messages);
109        RequestMessage::Chat {
110            messages: other,
111            enable_thinking: self.enable_thinking,
112            reasoning_effort: None,
113        }
114    }
115    fn enable_search(&self) -> Option<bool> {
116        None
117    }
118    fn take_logits_processors(&mut self) -> Option<Vec<Arc<dyn CustomLogitsProcessor>>> {
119        None
120    }
121    fn take_adapters(&mut self) -> Option<Vec<String>> {
122        None
123    }
124    fn return_logprobs(&self) -> bool {
125        false
126    }
127    fn take_constraint(&mut self) -> Constraint {
128        Constraint::None
129    }
130    fn take_tools(&mut self) -> Option<(Vec<Tool>, ToolChoice)> {
131        None
132    }
133    fn take_sampling_params(&mut self) -> SamplingParams {
134        SamplingParams::deterministic()
135    }
136    fn take_web_search_options(&mut self) -> Option<WebSearchOptions> {
137        None
138    }
139}
140
#[derive(Debug, Clone, PartialEq)]
/// Text (chat) messages with images and/or audios.
///
/// No constraints, logits processors, logprobs, tools, or adapters.
///
/// Sampling is deterministic.
pub struct VisionMessages {
    // Chat history; media-bearing messages store structured multi-part content.
    messages: Vec<IndexMap<String, MessageContent>>,
    // All images referenced by the messages, in insertion order.
    images: Vec<DynamicImage>,
    // All audio inputs referenced by the messages, in insertion order.
    audios: Vec<AudioInput>,
    // `None` leaves the model's default thinking behavior untouched.
    enable_thinking: Option<bool>,
}
153
impl Default for VisionMessages {
    /// Same as [`VisionMessages::new`]: an empty message set.
    fn default() -> Self {
        Self::new()
    }
}
159
160impl VisionMessages {
161    pub fn new() -> Self {
162        Self {
163            images: Vec::new(),
164            messages: Vec::new(),
165            audios: Vec::new(),
166            enable_thinking: None,
167        }
168    }
169
170    pub fn add_message(mut self, role: TextMessageRole, text: impl ToString) -> Self {
171        self.messages.push(IndexMap::from([
172            ("role".to_string(), Either::Left(role.to_string())),
173            ("content".to_string(), Either::Left(text.to_string())),
174        ]));
175        self
176    }
177
178    pub fn add_image_message(
179        self,
180        role: TextMessageRole,
181        text: impl ToString,
182        images: Vec<DynamicImage>,
183        model: &Model,
184    ) -> anyhow::Result<Self> {
185        self.add_multimodal_message(role, text, images, vec![], model)
186    }
187
188    pub fn add_audio_message(
189        self,
190        role: TextMessageRole,
191        text: impl ToString,
192        audios: Vec<AudioInput>,
193        model: &Model,
194    ) -> anyhow::Result<Self> {
195        self.add_multimodal_message(role, text, vec![], audios, model)
196    }
197
198    pub fn add_multimodal_message(
199        mut self,
200        role: TextMessageRole,
201        text: impl ToString,
202        images: Vec<DynamicImage>,
203        audios: Vec<AudioInput>,
204        model: &Model,
205    ) -> anyhow::Result<Self> {
206        let config = model.config().unwrap();
207        let prefixer = match &config.category {
208            ModelCategory::Vision { prefixer } => prefixer,
209            _ => {
210                anyhow::bail!("`add_image_message` expects a vision model.")
211            }
212        };
213
214        // Images
215        let n_added_images = images.len();
216        let image_indexes: Vec<usize> =
217            (self.images.len()..self.images.len() + n_added_images).collect();
218        self.images.extend(images);
219
220        // Audios
221        let n_added_audios = audios.len();
222        let audio_indexes: Vec<usize> =
223            (self.audios.len()..self.audios.len() + n_added_audios).collect();
224        self.audios.extend(audios);
225
226        if n_added_images > 0 || n_added_audios > 0 {
227            // Build mixed content parts
228            let mut content_vec: Vec<IndexMap<String, Value>> = Vec::new();
229            for _ in 0..n_added_images {
230                content_vec.push(IndexMap::from([(
231                    "type".to_string(),
232                    Value::String("image".to_string()),
233                )]));
234            }
235            for _ in 0..n_added_audios {
236                content_vec.push(IndexMap::from([(
237                    "type".to_string(),
238                    Value::String("audio".to_string()),
239                )]));
240            }
241            // Prefix the text with any media context
242            let mut prefixed_text = text.to_string();
243            if !image_indexes.is_empty() {
244                prefixed_text = prefixer.prefix_image(image_indexes, &prefixed_text);
245            }
246            if !audio_indexes.is_empty() {
247                prefixed_text = prefixer.prefix_audio(audio_indexes, &prefixed_text);
248            }
249            // Add the final text part
250            content_vec.push(IndexMap::from([
251                ("type".to_string(), Value::String("text".to_string())),
252                ("text".to_string(), Value::String(prefixed_text)),
253            ]));
254
255            self.messages.push(IndexMap::from([
256                ("role".to_string(), Either::Left(role.to_string())),
257                ("content".to_string(), Either::Right(content_vec)),
258            ]));
259        } else {
260            self.messages.push(IndexMap::from([
261                ("role".to_string(), Either::Left(role.to_string())),
262                ("content".to_string(), Either::Left(text.to_string())),
263            ]));
264        }
265        Ok(self)
266    }
267
268    pub fn clear(mut self) -> Self {
269        self.messages.clear();
270        self.images.clear();
271        self.audios.clear();
272
273        self
274    }
275
276    pub fn enable_thinking(mut self, enable_thinking: bool) -> Self {
277        self.enable_thinking = Some(enable_thinking);
278        self
279    }
280}
281
282impl RequestLike for VisionMessages {
283    fn messages_ref(&self) -> &[IndexMap<String, MessageContent>] {
284        &self.messages
285    }
286    fn images_ref(&self) -> &[DynamicImage] {
287        &self.images
288    }
289    fn take_messages(&mut self) -> RequestMessage {
290        let mut other_messages = Vec::new();
291        std::mem::swap(&mut other_messages, &mut self.messages);
292        let mut other_images = Vec::new();
293        std::mem::swap(&mut other_images, &mut self.images);
294        let mut other_audios = Vec::new();
295        std::mem::swap(&mut other_audios, &mut self.audios);
296        RequestMessage::VisionChat {
297            images: other_images,
298            messages: other_messages,
299            audios: other_audios,
300            enable_thinking: self.enable_thinking,
301            reasoning_effort: None,
302        }
303    }
304    fn enable_search(&self) -> Option<bool> {
305        None
306    }
307    fn take_logits_processors(&mut self) -> Option<Vec<Arc<dyn CustomLogitsProcessor>>> {
308        None
309    }
310    fn take_adapters(&mut self) -> Option<Vec<String>> {
311        None
312    }
313    fn return_logprobs(&self) -> bool {
314        false
315    }
316    fn take_constraint(&mut self) -> Constraint {
317        Constraint::None
318    }
319    fn take_tools(&mut self) -> Option<(Vec<Tool>, ToolChoice)> {
320        None
321    }
322    fn take_sampling_params(&mut self) -> SamplingParams {
323        SamplingParams::deterministic()
324    }
325    fn take_web_search_options(&mut self) -> Option<WebSearchOptions> {
326        None
327    }
328}
329
#[derive(Clone)]
/// A way to add messages with finer control given.
///
/// This includes control over:
/// - Logits processors
/// - Constraints
/// - Logprobs
/// - Tools
/// - Sampling
/// - Enable thinking for models that support the configuration
pub struct RequestBuilder {
    // Chat history; each entry maps "role"/"content" (plus optional tool keys).
    messages: Vec<IndexMap<String, MessageContent>>,
    // By convention, all images are added before all audios.
    images: Vec<DynamicImage>,
    audios: Vec<AudioInput>,
    // Custom logits processors for this request.
    logits_processors: Vec<Arc<dyn CustomLogitsProcessor>>,
    // Names of adapters to activate.
    adapters: Vec<String>,
    return_logprobs: bool,
    constraint: Constraint,
    tools: Vec<Tool>,
    // Defaults to `ToolChoice::Auto`.
    tool_choice: ToolChoice,
    // Defaults to deterministic sampling.
    sampling_params: SamplingParams,
    web_search_options: Option<WebSearchOptions>,
    // `None` leaves the model's default thinking behavior untouched.
    enable_thinking: Option<bool>,
    // Truncate prompts that exceed the model's maximum context length.
    truncate_sequence: bool,
}
355
impl Default for RequestBuilder {
    /// Same as [`RequestBuilder::new`].
    fn default() -> Self {
        Self::new()
    }
}
361
362impl From<TextMessages> for RequestBuilder {
363    fn from(value: TextMessages) -> Self {
364        Self {
365            messages: value.messages,
366            images: Vec::new(),
367            audios: Vec::new(),
368            logits_processors: Vec::new(),
369            adapters: Vec::new(),
370            return_logprobs: false,
371            constraint: Constraint::None,
372            tools: Vec::new(),
373            tool_choice: ToolChoice::Auto,
374            sampling_params: SamplingParams::deterministic(),
375            web_search_options: None,
376            enable_thinking: None,
377            truncate_sequence: false,
378        }
379    }
380}
381
382impl From<VisionMessages> for RequestBuilder {
383    fn from(value: VisionMessages) -> Self {
384        Self {
385            messages: value.messages,
386            images: value.images,
387            audios: value.audios,
388            logits_processors: Vec::new(),
389            adapters: Vec::new(),
390            return_logprobs: false,
391            constraint: Constraint::None,
392            tools: Vec::new(),
393            tool_choice: ToolChoice::Auto,
394            sampling_params: SamplingParams::deterministic(),
395            web_search_options: None,
396            enable_thinking: None,
397            truncate_sequence: false,
398        }
399    }
400}
401
402impl RequestBuilder {
403    pub fn new() -> Self {
404        Self {
405            messages: Vec::new(),
406            images: Vec::new(),
407            audios: Vec::new(),
408            logits_processors: Vec::new(),
409            adapters: Vec::new(),
410            return_logprobs: false,
411            constraint: Constraint::None,
412            tools: Vec::new(),
413            tool_choice: ToolChoice::Auto,
414            sampling_params: SamplingParams::deterministic(),
415            web_search_options: None,
416            enable_thinking: None,
417            truncate_sequence: false,
418        }
419    }
420
421    pub fn with_web_search_options(mut self, web_search_options: WebSearchOptions) -> Self {
422        self.web_search_options = Some(web_search_options);
423        self
424    }
425
426    /// Add a message to the request.
427    ///
428    /// For messages with tool calls, use [`Self::add_message_with_tool_call`].
429    /// For messages with tool outputs, use [`Self::add_tool_message`].
430    pub fn add_message(mut self, role: TextMessageRole, text: impl ToString) -> Self {
431        self.messages.push(IndexMap::from([
432            ("role".to_string(), Either::Left(role.to_string())),
433            ("content".to_string(), Either::Left(text.to_string())),
434        ]));
435        self
436    }
437
438    /// Add a message with the output of a tool call.
439    pub fn add_tool_message(mut self, tool_content: impl ToString, tool_id: impl ToString) -> Self {
440        self.messages.push(IndexMap::from([
441            (
442                "role".to_string(),
443                Either::Left(TextMessageRole::Tool.to_string()),
444            ),
445            (
446                "content".to_string(),
447                Either::Left(tool_content.to_string()),
448            ),
449            (
450                "tool_call_id".to_string(),
451                Either::Left(tool_id.to_string()),
452            ),
453        ]));
454        self
455    }
456
457    pub fn add_message_with_tool_call(
458        mut self,
459        role: TextMessageRole,
460        text: impl ToString,
461        tool_calls: Vec<ToolCallResponse>,
462    ) -> Self {
463        let tool_messages = tool_calls
464            .iter()
465            .map(|t| {
466                IndexMap::from([
467                    ("id".to_string(), Value::String(t.id.clone())),
468                    ("type".to_string(), Value::String(t.tp.to_string())),
469                    (
470                        "function".to_string(),
471                        json!({
472                            "name": t.function.name,
473                            "arguments": t.function.arguments,
474                        }),
475                    ),
476                ])
477            })
478            .collect();
479        self.messages.push(IndexMap::from([
480            ("role".to_string(), Either::Left(role.to_string())),
481            ("content".to_string(), Either::Left(text.to_string())),
482            ("function".to_string(), Either::Right(tool_messages)),
483        ]));
484        self
485    }
486
487    pub fn add_image_message(
488        self,
489        role: TextMessageRole,
490        text: impl ToString,
491        images: Vec<DynamicImage>,
492        model: &Model,
493    ) -> anyhow::Result<Self> {
494        self.add_multimodal_message(role, text, images, vec![], model)
495    }
496
497    pub fn add_audio_message(
498        self,
499        role: TextMessageRole,
500        text: impl ToString,
501        audios: Vec<AudioInput>,
502        model: &Model,
503    ) -> anyhow::Result<Self> {
504        self.add_multimodal_message(role, text, vec![], audios, model)
505    }
506
507    /// By convention, all images are added before all audios.
508    pub fn add_multimodal_message(
509        mut self,
510        role: TextMessageRole,
511        text: impl ToString,
512        images: Vec<DynamicImage>,
513        audios: Vec<AudioInput>,
514        model: &Model,
515    ) -> anyhow::Result<Self> {
516        let config = model.config().unwrap();
517        let prefixer = match &config.category {
518            ModelCategory::Vision { prefixer } => prefixer,
519            _ => {
520                anyhow::bail!("`add_image_message` expects a vision model.")
521            }
522        };
523
524        // Images
525        let n_added_images = images.len();
526        let image_indexes: Vec<usize> =
527            (self.images.len()..self.images.len() + n_added_images).collect();
528        self.images.extend(images);
529
530        // Audios
531        let n_added_audios = audios.len();
532        let audio_indexes: Vec<usize> =
533            (self.audios.len()..self.audios.len() + n_added_audios).collect();
534        self.audios.extend(audios);
535
536        if n_added_images > 0 || n_added_audios > 0 {
537            // Build mixed content parts
538            let mut content_vec: Vec<IndexMap<String, Value>> = Vec::new();
539            for _ in 0..n_added_images {
540                content_vec.push(IndexMap::from([(
541                    "type".to_string(),
542                    Value::String("image".to_string()),
543                )]));
544            }
545            for _ in 0..n_added_audios {
546                content_vec.push(IndexMap::from([(
547                    "type".to_string(),
548                    Value::String("audio".to_string()),
549                )]));
550            }
551            // Prefix the text with any media context
552            let mut prefixed_text = text.to_string();
553            if !image_indexes.is_empty() {
554                prefixed_text = prefixer.prefix_image(image_indexes, &prefixed_text);
555            }
556            if !audio_indexes.is_empty() {
557                prefixed_text = prefixer.prefix_audio(audio_indexes, &prefixed_text);
558            }
559            // Add the final text part
560            content_vec.push(IndexMap::from([
561                ("type".to_string(), Value::String("text".to_string())),
562                ("text".to_string(), Value::String(prefixed_text)),
563            ]));
564
565            self.messages.push(IndexMap::from([
566                ("role".to_string(), Either::Left(role.to_string())),
567                ("content".to_string(), Either::Right(content_vec)),
568            ]));
569        } else {
570            self.messages.push(IndexMap::from([
571                ("role".to_string(), Either::Left(role.to_string())),
572                ("content".to_string(), Either::Left(text.to_string())),
573            ]));
574        }
575        Ok(self)
576    }
577
578    pub fn add_logits_processor(mut self, processor: Arc<dyn CustomLogitsProcessor>) -> Self {
579        self.logits_processors.push(processor);
580        self
581    }
582
583    pub fn set_adapters(mut self, adapters: Vec<String>) -> Self {
584        self.adapters = adapters;
585        self
586    }
587
588    /// The default tool choice is auto.
589    pub fn set_tools(mut self, tools: Vec<Tool>) -> Self {
590        self.tools = tools;
591        self
592    }
593
594    pub fn set_tool_choice(mut self, tool_choice: ToolChoice) -> Self {
595        self.tool_choice = tool_choice;
596        self
597    }
598
599    pub fn return_logprobs(mut self, return_logprobs: bool) -> Self {
600        self.return_logprobs = return_logprobs;
601        self
602    }
603
604    pub fn set_constraint(mut self, constraint: Constraint) -> Self {
605        self.constraint = constraint;
606        self
607    }
608
609    /// Set the sampling parameters as given.
610    pub fn set_sampling(mut self, params: SamplingParams) -> Self {
611        self.sampling_params = params;
612        self
613    }
614
615    /// Set the sampling parameters for deterministic generation.
616    /// This sets up the parameters so that there is:
617    /// - No temperature, topk, topp, minp
618    /// - No penalties, stop tokens, or logit bias
619    /// - No maximum length
620    pub fn set_deterministic_sampler(mut self) -> Self {
621        self.sampling_params = SamplingParams::deterministic();
622        self
623    }
624
625    pub fn set_sampler_temperature(mut self, temperature: f64) -> Self {
626        self.sampling_params.temperature = Some(temperature);
627        self
628    }
629
630    pub fn set_sampler_topk(mut self, topk: usize) -> Self {
631        self.sampling_params.top_k = Some(topk);
632        self
633    }
634
635    pub fn set_sampler_topp(mut self, topp: f64) -> Self {
636        self.sampling_params.top_p = Some(topp);
637        self
638    }
639
640    pub fn set_sampler_minp(mut self, minp: f64) -> Self {
641        self.sampling_params.min_p = Some(minp);
642        self
643    }
644
645    pub fn set_sampler_topn_logprobs(mut self, top_n_logprobs: usize) -> Self {
646        self.sampling_params.top_n_logprobs = top_n_logprobs;
647        self
648    }
649
650    pub fn set_sampler_frequency_penalty(mut self, frequency_penalty: f32) -> Self {
651        self.sampling_params.frequency_penalty = Some(frequency_penalty);
652        self
653    }
654
655    pub fn set_sampler_presence_penalty(mut self, presence_penalty: f32) -> Self {
656        self.sampling_params.presence_penalty = Some(presence_penalty);
657        self
658    }
659
660    pub fn set_sampler_stop_toks(mut self, stop_toks: StopTokens) -> Self {
661        self.sampling_params.stop_toks = Some(stop_toks);
662        self
663    }
664
665    pub fn set_sampler_max_len(mut self, max_len: usize) -> Self {
666        self.sampling_params.max_len = Some(max_len);
667        self
668    }
669
670    pub fn set_sampler_logits_bias(mut self, logits_bias: HashMap<u32, f32>) -> Self {
671        self.sampling_params.logits_bias = Some(logits_bias);
672        self
673    }
674
675    pub fn set_sampler_n_choices(mut self, n_choices: usize) -> Self {
676        self.sampling_params.n_choices = n_choices;
677        self
678    }
679
680    pub fn set_sampler_dry_params(mut self, dry_params: DrySamplingParams) -> Self {
681        self.sampling_params.dry_params = Some(dry_params);
682        self
683    }
684
685    pub fn enable_thinking(mut self, enable_thinking: bool) -> Self {
686        self.enable_thinking = Some(enable_thinking);
687        self
688    }
689
690    /// Truncate prompts that exceed the model's maximum context length.
691    pub fn with_truncate_sequence(mut self, truncate_sequence: bool) -> Self {
692        self.truncate_sequence = truncate_sequence;
693        self
694    }
695}
696
697impl RequestLike for RequestBuilder {
698    fn messages_ref(&self) -> &[IndexMap<String, MessageContent>] {
699        &self.messages
700    }
701
702    fn images_ref(&self) -> &[DynamicImage] {
703        &self.images
704    }
705
706    fn take_messages(&mut self) -> RequestMessage {
707        if self.images.is_empty() && self.audios.is_empty() {
708            let mut other = Vec::new();
709            std::mem::swap(&mut other, &mut self.messages);
710            RequestMessage::Chat {
711                messages: other,
712                enable_thinking: self.enable_thinking,
713                reasoning_effort: None,
714            }
715        } else {
716            let mut other_messages = Vec::new();
717            std::mem::swap(&mut other_messages, &mut self.messages);
718            let mut other_images = Vec::new();
719            std::mem::swap(&mut other_images, &mut self.images);
720            let mut other_audios = Vec::new();
721            std::mem::swap(&mut other_audios, &mut self.audios);
722            RequestMessage::VisionChat {
723                images: other_images,
724                messages: other_messages,
725                audios: other_audios,
726                enable_thinking: self.enable_thinking,
727                reasoning_effort: None,
728            }
729        }
730    }
731
732    fn enable_search(&self) -> Option<bool> {
733        self.web_search_options.as_ref().map(|_| true)
734    }
735
736    fn take_logits_processors(&mut self) -> Option<Vec<Arc<dyn CustomLogitsProcessor>>> {
737        if self.logits_processors.is_empty() {
738            None
739        } else {
740            let mut other = Vec::new();
741            std::mem::swap(&mut other, &mut self.logits_processors);
742            Some(other)
743        }
744    }
745
746    fn take_adapters(&mut self) -> Option<Vec<String>> {
747        if self.adapters.is_empty() {
748            None
749        } else {
750            let mut other = Vec::new();
751            std::mem::swap(&mut other, &mut self.adapters);
752            Some(other)
753        }
754    }
755
756    fn return_logprobs(&self) -> bool {
757        self.return_logprobs
758    }
759
760    fn take_constraint(&mut self) -> Constraint {
761        let mut other = Constraint::None;
762        std::mem::swap(&mut other, &mut self.constraint);
763        other
764    }
765
766    fn take_tools(&mut self) -> Option<(Vec<Tool>, ToolChoice)> {
767        if self.tools.is_empty() {
768            None
769        } else {
770            let mut other_ts = Vec::new();
771            std::mem::swap(&mut other_ts, &mut self.tools);
772            let mut other_tc = ToolChoice::Auto;
773            std::mem::swap(&mut other_tc, &mut self.tool_choice);
774            Some((other_ts, other_tc))
775        }
776    }
777
778    fn take_sampling_params(&mut self) -> SamplingParams {
779        let mut other = SamplingParams::deterministic();
780        std::mem::swap(&mut other, &mut self.sampling_params);
781        other
782    }
783
784    fn take_web_search_options(&mut self) -> Option<WebSearchOptions> {
785        let mut other = None;
786        std::mem::swap(&mut other, &mut self.web_search_options);
787        other
788    }
789
790    fn truncate_sequence(&self) -> bool {
791        self.truncate_sequence
792    }
793}
794
#[derive(Clone, Debug)]
/// An individual embedding input.
pub enum EmbeddingRequestInput {
    /// Raw text prompt that will be tokenized.
    Prompt(String),
    /// Pre-tokenized input (token ids).
    Tokens(Vec<u32>),
}
803
impl EmbeddingRequestInput {
    /// Convert this input into the matching engine-facing request message.
    pub fn into_request_message(self) -> RequestMessage {
        match self {
            Self::Prompt(prompt) => RequestMessage::Embedding { prompt },
            Self::Tokens(prompt) => RequestMessage::EmbeddingTokens { prompt },
        }
    }
}
812
#[derive(Clone, Debug)]
/// A validated embedding request constructed via [`EmbeddingRequestBuilder`].
pub struct EmbeddingRequest {
    /// One or more prompts / token batches to embed; never empty when built
    /// via the builder.
    pub inputs: Vec<EmbeddingRequestInput>,
    /// Whether prompts longer than the model context are truncated.
    pub truncate_sequence: bool,
}
819
impl EmbeddingRequest {
    /// Create a new builder for an embedding request.
    /// Equivalent to [`EmbeddingRequestBuilder::new`].
    pub fn builder() -> EmbeddingRequestBuilder {
        EmbeddingRequestBuilder::new()
    }
}
826
/// Builder for configuring embedding requests.
#[derive(Clone, Debug, Default)]
pub struct EmbeddingRequestBuilder {
    // Accumulated inputs; `build` requires at least one.
    inputs: Vec<EmbeddingRequestInput>,
    // Whether over-length prompts are truncated (defaults to false).
    truncate_sequence: bool,
}
833
834impl EmbeddingRequestBuilder {
835    /// Create an empty builder. You must add at least one input before using it.
836    pub fn new() -> Self {
837        Self::default()
838    }
839
840    /// Add a single text prompt.
841    pub fn add_prompt(mut self, prompt: impl Into<String>) -> Self {
842        self.inputs
843            .push(EmbeddingRequestInput::Prompt(prompt.into()));
844        self
845    }
846
847    /// Add multiple text prompts at once.
848    pub fn add_prompts<I, S>(mut self, prompts: I) -> Self
849    where
850        I: IntoIterator<Item = S>,
851        S: Into<String>,
852    {
853        self.inputs.extend(
854            prompts
855                .into_iter()
856                .map(|prompt| EmbeddingRequestInput::Prompt(prompt.into())),
857        );
858        self
859    }
860
861    /// Add a single pre-tokenized prompt.
862    pub fn add_tokens(mut self, tokens: impl Into<Vec<u32>>) -> Self {
863        self.inputs
864            .push(EmbeddingRequestInput::Tokens(tokens.into()));
865        self
866    }
867
868    /// Add multiple pre-tokenized prompts.
869    pub fn add_tokens_batch<I>(mut self, batches: I) -> Self
870    where
871        I: IntoIterator<Item = Vec<u32>>,
872    {
873        self.inputs
874            .extend(batches.into_iter().map(EmbeddingRequestInput::Tokens));
875        self
876    }
877
878    /// Control whether prompts longer than the model context are truncated.
879    pub fn with_truncate_sequence(mut self, truncate: bool) -> Self {
880        self.truncate_sequence = truncate;
881        self
882    }
883
884    pub fn build(self) -> anyhow::Result<EmbeddingRequest> {
885        if self.inputs.is_empty() {
886            anyhow::bail!("Embedding request must contain at least one input.");
887        }
888
889        Ok(EmbeddingRequest {
890            inputs: self.inputs,
891            truncate_sequence: self.truncate_sequence,
892        })
893    }
894}