1use std::{collections::HashMap, fmt::Display, sync::Arc};
2
3use super::*;
4use either::Either;
5use image::DynamicImage;
6use indexmap::IndexMap;
7use serde_json::{json, Value};
8
/// Abstraction over request payloads that can be turned into an engine request.
///
/// Implementors expose their messages/images for inspection, and surrender the
/// remaining request components via the `take_*` methods, which move the data
/// out and leave the implementor empty/defaulted.
pub trait RequestLike {
    /// Borrow the chat messages accumulated so far.
    fn messages_ref(&self) -> &[IndexMap<String, MessageContent>];
    /// Borrow any images attached to the request (empty for text-only requests).
    fn images_ref(&self) -> &[DynamicImage];
    /// Consume the messages (and any attached media), producing the final `RequestMessage`.
    fn take_messages(&mut self) -> RequestMessage;
    /// Take the custom logits processors, if any were configured.
    fn take_logits_processors(&mut self) -> Option<Vec<Arc<dyn CustomLogitsProcessor>>>;
    /// Take the adapter names to activate, if any were configured.
    fn take_adapters(&mut self) -> Option<Vec<String>>;
    /// Whether log probabilities should be returned with the response.
    fn return_logprobs(&self) -> bool;
    /// Whether web search is enabled (`None` means "not specified").
    fn enable_search(&self) -> Option<bool>;
    /// Take the generation constraint, leaving `Constraint::None` behind.
    fn take_constraint(&mut self) -> Constraint;
    /// Take the tool definitions together with the tool choice, if any were configured.
    fn take_tools(&mut self) -> Option<(Vec<Tool>, ToolChoice)>;
    /// Take the sampling parameters for this request.
    fn take_sampling_params(&mut self) -> SamplingParams;
    /// Take the web search options, if any were configured.
    fn take_web_search_options(&mut self) -> Option<WebSearchOptions>;
    /// Whether over-length sequences should be truncated (defaults to `false`).
    fn truncate_sequence(&self) -> bool {
        false
    }
}
26
/// Plain-text chat messages (role/content pairs) for text-only requests.
#[derive(Debug, Clone, PartialEq)]
pub struct TextMessages {
    // Each message is an ordered map holding at least "role" and "content".
    messages: Vec<IndexMap<String, MessageContent>>,
    // Optional per-request "thinking" toggle; passed through unchanged into
    // the resulting `RequestMessage::Chat`.
    enable_thinking: Option<bool>,
}
37
impl From<TextMessages> for Vec<IndexMap<String, MessageContent>> {
    /// Extract the raw message list, discarding the `enable_thinking` flag.
    fn from(value: TextMessages) -> Self {
        value.messages
    }
}
43
/// Role of a chat message, rendered as a lowercase string by `Display`.
#[derive(Debug, Clone, PartialEq)]
pub enum TextMessageRole {
    User,
    Assistant,
    System,
    Tool,
    /// Any role string not covered by the variants above; rendered verbatim.
    Custom(String),
}
53
54impl Display for TextMessageRole {
55 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
56 match self {
57 Self::User => write!(f, "user"),
58 Self::Assistant => write!(f, "assistant"),
59 Self::System => write!(f, "system"),
60 Self::Tool => write!(f, "tool"),
61 Self::Custom(c) => write!(f, "{c}"),
62 }
63 }
64}
65
impl Default for TextMessages {
    /// Equivalent to [`TextMessages::new`]: no messages, no thinking preference.
    fn default() -> Self {
        Self::new()
    }
}
71
72impl TextMessages {
73 pub fn new() -> Self {
74 Self {
75 messages: Vec::new(),
76 enable_thinking: None,
77 }
78 }
79
80 pub fn add_message(mut self, role: TextMessageRole, text: impl ToString) -> Self {
81 self.messages.push(IndexMap::from([
82 ("role".to_string(), Either::Left(role.to_string())),
83 ("content".to_string(), Either::Left(text.to_string())),
84 ]));
85 self
86 }
87
88 pub fn clear(mut self) -> Self {
89 self.messages.clear();
90 self
91 }
92
93 pub fn enable_thinking(mut self, enable_thinking: bool) -> Self {
94 self.enable_thinking = Some(enable_thinking);
95 self
96 }
97}
98
99impl RequestLike for TextMessages {
100 fn messages_ref(&self) -> &[IndexMap<String, MessageContent>] {
101 &self.messages
102 }
103 fn images_ref(&self) -> &[DynamicImage] {
104 &[]
105 }
106 fn take_messages(&mut self) -> RequestMessage {
107 let mut other = Vec::new();
108 std::mem::swap(&mut other, &mut self.messages);
109 RequestMessage::Chat {
110 messages: other,
111 enable_thinking: self.enable_thinking,
112 reasoning_effort: None,
113 }
114 }
115 fn enable_search(&self) -> Option<bool> {
116 None
117 }
118 fn take_logits_processors(&mut self) -> Option<Vec<Arc<dyn CustomLogitsProcessor>>> {
119 None
120 }
121 fn take_adapters(&mut self) -> Option<Vec<String>> {
122 None
123 }
124 fn return_logprobs(&self) -> bool {
125 false
126 }
127 fn take_constraint(&mut self) -> Constraint {
128 Constraint::None
129 }
130 fn take_tools(&mut self) -> Option<(Vec<Tool>, ToolChoice)> {
131 None
132 }
133 fn take_sampling_params(&mut self) -> SamplingParams {
134 SamplingParams::deterministic()
135 }
136 fn take_web_search_options(&mut self) -> Option<WebSearchOptions> {
137 None
138 }
139}
140
/// Chat messages that may carry images and/or audio alongside text.
#[derive(Debug, Clone, PartialEq)]
pub struct VisionMessages {
    // Messages reference attached media by index into `images`/`audios`.
    messages: Vec<IndexMap<String, MessageContent>>,
    // All images attached across every message, in insertion order.
    images: Vec<DynamicImage>,
    // All audio inputs attached across every message, in insertion order.
    audios: Vec<AudioInput>,
    // Optional per-request "thinking" toggle; passed through into the
    // resulting `RequestMessage::VisionChat`.
    enable_thinking: Option<bool>,
}
153
impl Default for VisionMessages {
    /// Equivalent to [`VisionMessages::new`]: no messages or media.
    fn default() -> Self {
        Self::new()
    }
}
159
impl VisionMessages {
    /// Construct an empty multimodal message list.
    pub fn new() -> Self {
        Self {
            images: Vec::new(),
            messages: Vec::new(),
            audios: Vec::new(),
            enable_thinking: None,
        }
    }

    /// Append a plain text message (no media) with the given role.
    pub fn add_message(mut self, role: TextMessageRole, text: impl ToString) -> Self {
        self.messages.push(IndexMap::from([
            ("role".to_string(), Either::Left(role.to_string())),
            ("content".to_string(), Either::Left(text.to_string())),
        ]));
        self
    }

    /// Append a message carrying one or more images.
    ///
    /// # Errors
    /// Fails if `model` is not a vision model.
    pub fn add_image_message(
        self,
        role: TextMessageRole,
        text: impl ToString,
        images: Vec<DynamicImage>,
        model: &Model,
    ) -> anyhow::Result<Self> {
        self.add_multimodal_message(role, text, images, vec![], model)
    }

    /// Append a message carrying one or more audio inputs.
    ///
    /// # Errors
    /// Fails if `model` is not a vision model.
    pub fn add_audio_message(
        self,
        role: TextMessageRole,
        text: impl ToString,
        audios: Vec<AudioInput>,
        model: &Model,
    ) -> anyhow::Result<Self> {
        self.add_multimodal_message(role, text, vec![], audios, model)
    }

    /// Append a message carrying any mix of images and audio inputs.
    ///
    /// Media are appended to `self.images`/`self.audios` and referenced by
    /// index; the model's prefixer rewrites `text` to include the matching
    /// media placeholders.
    ///
    /// # Errors
    /// Fails if `model` is not a vision model.
    pub fn add_multimodal_message(
        mut self,
        role: TextMessageRole,
        text: impl ToString,
        images: Vec<DynamicImage>,
        audios: Vec<AudioInput>,
        model: &Model,
    ) -> anyhow::Result<Self> {
        // NOTE(review): `unwrap` panics if the config is unavailable even though
        // this function returns `Result` — consider propagating the error instead.
        let config = model.config().unwrap();
        let prefixer = match &config.category {
            ModelCategory::Vision { prefixer } => prefixer,
            _ => {
                anyhow::bail!("`add_image_message` expects a vision model.")
            }
        };

        // Indexes the new images will occupy within the full image list.
        let n_added_images = images.len();
        let image_indexes: Vec<usize> =
            (self.images.len()..self.images.len() + n_added_images).collect();
        self.images.extend(images);

        // Indexes the new audios will occupy within the full audio list.
        let n_added_audios = audios.len();
        let audio_indexes: Vec<usize> =
            (self.audios.len()..self.audios.len() + n_added_audios).collect();
        self.audios.extend(audios);

        if n_added_images > 0 || n_added_audios > 0 {
            // Structured content: one "image"/"audio" part per attached medium,
            // followed by a single "text" part holding the prefixed text.
            let mut content_vec: Vec<IndexMap<String, Value>> = Vec::new();
            for _ in 0..n_added_images {
                content_vec.push(IndexMap::from([(
                    "type".to_string(),
                    Value::String("image".to_string()),
                )]));
            }
            for _ in 0..n_added_audios {
                content_vec.push(IndexMap::from([(
                    "type".to_string(),
                    Value::String("audio".to_string()),
                )]));
            }
            // Let the model-specific prefixer inject media placeholders into the text.
            let mut prefixed_text = text.to_string();
            if !image_indexes.is_empty() {
                prefixed_text = prefixer.prefix_image(image_indexes, &prefixed_text);
            }
            if !audio_indexes.is_empty() {
                prefixed_text = prefixer.prefix_audio(audio_indexes, &prefixed_text);
            }
            content_vec.push(IndexMap::from([
                ("type".to_string(), Value::String("text".to_string())),
                ("text".to_string(), Value::String(prefixed_text)),
            ]));

            self.messages.push(IndexMap::from([
                ("role".to_string(), Either::Left(role.to_string())),
                ("content".to_string(), Either::Right(content_vec)),
            ]));
        } else {
            // No media supplied: fall back to a plain text message.
            self.messages.push(IndexMap::from([
                ("role".to_string(), Either::Left(role.to_string())),
                ("content".to_string(), Either::Left(text.to_string())),
            ]));
        }
        Ok(self)
    }

    /// Remove all messages, images, and audios.
    pub fn clear(mut self) -> Self {
        self.messages.clear();
        self.images.clear();
        self.audios.clear();

        self
    }

    /// Set the per-request thinking preference.
    pub fn enable_thinking(mut self, enable_thinking: bool) -> Self {
        self.enable_thinking = Some(enable_thinking);
        self
    }
}
281
282impl RequestLike for VisionMessages {
283 fn messages_ref(&self) -> &[IndexMap<String, MessageContent>] {
284 &self.messages
285 }
286 fn images_ref(&self) -> &[DynamicImage] {
287 &self.images
288 }
289 fn take_messages(&mut self) -> RequestMessage {
290 let mut other_messages = Vec::new();
291 std::mem::swap(&mut other_messages, &mut self.messages);
292 let mut other_images = Vec::new();
293 std::mem::swap(&mut other_images, &mut self.images);
294 let mut other_audios = Vec::new();
295 std::mem::swap(&mut other_audios, &mut self.audios);
296 RequestMessage::VisionChat {
297 images: other_images,
298 messages: other_messages,
299 audios: other_audios,
300 enable_thinking: self.enable_thinking,
301 reasoning_effort: None,
302 }
303 }
304 fn enable_search(&self) -> Option<bool> {
305 None
306 }
307 fn take_logits_processors(&mut self) -> Option<Vec<Arc<dyn CustomLogitsProcessor>>> {
308 None
309 }
310 fn take_adapters(&mut self) -> Option<Vec<String>> {
311 None
312 }
313 fn return_logprobs(&self) -> bool {
314 false
315 }
316 fn take_constraint(&mut self) -> Constraint {
317 Constraint::None
318 }
319 fn take_tools(&mut self) -> Option<(Vec<Tool>, ToolChoice)> {
320 None
321 }
322 fn take_sampling_params(&mut self) -> SamplingParams {
323 SamplingParams::deterministic()
324 }
325 fn take_web_search_options(&mut self) -> Option<WebSearchOptions> {
326 None
327 }
328}
329
/// Fluent builder collecting every configurable component of a request:
/// messages, media, sampling, tools, constraints, adapters, and search options.
#[derive(Clone)]
pub struct RequestBuilder {
    messages: Vec<IndexMap<String, MessageContent>>,
    images: Vec<DynamicImage>,
    audios: Vec<AudioInput>,
    logits_processors: Vec<Arc<dyn CustomLogitsProcessor>>,
    adapters: Vec<String>,
    return_logprobs: bool,
    constraint: Constraint,
    tools: Vec<Tool>,
    tool_choice: ToolChoice,
    sampling_params: SamplingParams,
    web_search_options: Option<WebSearchOptions>,
    enable_thinking: Option<bool>,
    truncate_sequence: bool,
}
355
impl Default for RequestBuilder {
    /// Equivalent to [`RequestBuilder::new`]: empty request with deterministic sampling.
    fn default() -> Self {
        Self::new()
    }
}
361
362impl From<TextMessages> for RequestBuilder {
363 fn from(value: TextMessages) -> Self {
364 Self {
365 messages: value.messages,
366 images: Vec::new(),
367 audios: Vec::new(),
368 logits_processors: Vec::new(),
369 adapters: Vec::new(),
370 return_logprobs: false,
371 constraint: Constraint::None,
372 tools: Vec::new(),
373 tool_choice: ToolChoice::Auto,
374 sampling_params: SamplingParams::deterministic(),
375 web_search_options: None,
376 enable_thinking: None,
377 truncate_sequence: false,
378 }
379 }
380}
381
382impl From<VisionMessages> for RequestBuilder {
383 fn from(value: VisionMessages) -> Self {
384 Self {
385 messages: value.messages,
386 images: value.images,
387 audios: value.audios,
388 logits_processors: Vec::new(),
389 adapters: Vec::new(),
390 return_logprobs: false,
391 constraint: Constraint::None,
392 tools: Vec::new(),
393 tool_choice: ToolChoice::Auto,
394 sampling_params: SamplingParams::deterministic(),
395 web_search_options: None,
396 enable_thinking: None,
397 truncate_sequence: false,
398 }
399 }
400}
401
impl RequestBuilder {
    /// Create an empty builder: no messages or media, deterministic sampling,
    /// automatic tool choice, no constraint, and truncation disabled.
    pub fn new() -> Self {
        Self {
            messages: Vec::new(),
            images: Vec::new(),
            audios: Vec::new(),
            logits_processors: Vec::new(),
            adapters: Vec::new(),
            return_logprobs: false,
            constraint: Constraint::None,
            tools: Vec::new(),
            tool_choice: ToolChoice::Auto,
            sampling_params: SamplingParams::deterministic(),
            web_search_options: None,
            enable_thinking: None,
            truncate_sequence: false,
        }
    }

    /// Attach web search options; this also makes `enable_search` report `Some(true)`.
    pub fn with_web_search_options(mut self, web_search_options: WebSearchOptions) -> Self {
        self.web_search_options = Some(web_search_options);
        self
    }

    /// Append a plain text message with the given role.
    pub fn add_message(mut self, role: TextMessageRole, text: impl ToString) -> Self {
        self.messages.push(IndexMap::from([
            ("role".to_string(), Either::Left(role.to_string())),
            ("content".to_string(), Either::Left(text.to_string())),
        ]));
        self
    }

    /// Append a tool-result message, linking it to the originating call via
    /// `tool_call_id`.
    pub fn add_tool_message(mut self, tool_content: impl ToString, tool_id: impl ToString) -> Self {
        self.messages.push(IndexMap::from([
            (
                "role".to_string(),
                Either::Left(TextMessageRole::Tool.to_string()),
            ),
            (
                "content".to_string(),
                Either::Left(tool_content.to_string()),
            ),
            (
                "tool_call_id".to_string(),
                Either::Left(tool_id.to_string()),
            ),
        ]));
        self
    }

    /// Append a message that also records tool calls made by the model.
    ///
    /// Each tool call is serialized as `{id, type, function: {name, arguments}}`.
    pub fn add_message_with_tool_call(
        mut self,
        role: TextMessageRole,
        text: impl ToString,
        tool_calls: Vec<ToolCallResponse>,
    ) -> Self {
        let tool_messages = tool_calls
            .iter()
            .map(|t| {
                IndexMap::from([
                    ("id".to_string(), Value::String(t.id.clone())),
                    ("type".to_string(), Value::String(t.tp.to_string())),
                    (
                        "function".to_string(),
                        json!({
                            "name": t.function.name,
                            "arguments": t.function.arguments,
                        }),
                    ),
                ])
            })
            .collect();
        // NOTE(review): the calls are stored under a "function" key — confirm
        // downstream templates expect this (OpenAI-style payloads use "tool_calls").
        self.messages.push(IndexMap::from([
            ("role".to_string(), Either::Left(role.to_string())),
            ("content".to_string(), Either::Left(text.to_string())),
            ("function".to_string(), Either::Right(tool_messages)),
        ]));
        self
    }

    /// Append a message carrying one or more images.
    ///
    /// # Errors
    /// Fails if `model` is not a vision model.
    pub fn add_image_message(
        self,
        role: TextMessageRole,
        text: impl ToString,
        images: Vec<DynamicImage>,
        model: &Model,
    ) -> anyhow::Result<Self> {
        self.add_multimodal_message(role, text, images, vec![], model)
    }

    /// Append a message carrying one or more audio inputs.
    ///
    /// # Errors
    /// Fails if `model` is not a vision model.
    pub fn add_audio_message(
        self,
        role: TextMessageRole,
        text: impl ToString,
        audios: Vec<AudioInput>,
        model: &Model,
    ) -> anyhow::Result<Self> {
        self.add_multimodal_message(role, text, vec![], audios, model)
    }

    /// Append a message carrying any mix of images and audio inputs.
    ///
    /// Media are stored on the builder and referenced by index; the model's
    /// prefixer rewrites `text` to include the matching media placeholders.
    /// (Mirrors `VisionMessages::add_multimodal_message`.)
    ///
    /// # Errors
    /// Fails if `model` is not a vision model.
    pub fn add_multimodal_message(
        mut self,
        role: TextMessageRole,
        text: impl ToString,
        images: Vec<DynamicImage>,
        audios: Vec<AudioInput>,
        model: &Model,
    ) -> anyhow::Result<Self> {
        // NOTE(review): `unwrap` panics if the config is unavailable even though
        // this function returns `Result` — consider propagating the error instead.
        let config = model.config().unwrap();
        let prefixer = match &config.category {
            ModelCategory::Vision { prefixer } => prefixer,
            _ => {
                anyhow::bail!("`add_image_message` expects a vision model.")
            }
        };

        // Indexes the new images will occupy within the full image list.
        let n_added_images = images.len();
        let image_indexes: Vec<usize> =
            (self.images.len()..self.images.len() + n_added_images).collect();
        self.images.extend(images);

        // Indexes the new audios will occupy within the full audio list.
        let n_added_audios = audios.len();
        let audio_indexes: Vec<usize> =
            (self.audios.len()..self.audios.len() + n_added_audios).collect();
        self.audios.extend(audios);

        if n_added_images > 0 || n_added_audios > 0 {
            // Structured content: one "image"/"audio" part per attached medium,
            // followed by a single "text" part holding the prefixed text.
            let mut content_vec: Vec<IndexMap<String, Value>> = Vec::new();
            for _ in 0..n_added_images {
                content_vec.push(IndexMap::from([(
                    "type".to_string(),
                    Value::String("image".to_string()),
                )]));
            }
            for _ in 0..n_added_audios {
                content_vec.push(IndexMap::from([(
                    "type".to_string(),
                    Value::String("audio".to_string()),
                )]));
            }
            // Let the model-specific prefixer inject media placeholders into the text.
            let mut prefixed_text = text.to_string();
            if !image_indexes.is_empty() {
                prefixed_text = prefixer.prefix_image(image_indexes, &prefixed_text);
            }
            if !audio_indexes.is_empty() {
                prefixed_text = prefixer.prefix_audio(audio_indexes, &prefixed_text);
            }
            content_vec.push(IndexMap::from([
                ("type".to_string(), Value::String("text".to_string())),
                ("text".to_string(), Value::String(prefixed_text)),
            ]));

            self.messages.push(IndexMap::from([
                ("role".to_string(), Either::Left(role.to_string())),
                ("content".to_string(), Either::Right(content_vec)),
            ]));
        } else {
            // No media supplied: fall back to a plain text message.
            self.messages.push(IndexMap::from([
                ("role".to_string(), Either::Left(role.to_string())),
                ("content".to_string(), Either::Left(text.to_string())),
            ]));
        }
        Ok(self)
    }

    /// Append a custom logits processor, applied in insertion order.
    pub fn add_logits_processor(mut self, processor: Arc<dyn CustomLogitsProcessor>) -> Self {
        self.logits_processors.push(processor);
        self
    }

    /// Replace the set of adapter names to activate.
    pub fn set_adapters(mut self, adapters: Vec<String>) -> Self {
        self.adapters = adapters;
        self
    }

    /// Replace the tool definitions available to the model.
    pub fn set_tools(mut self, tools: Vec<Tool>) -> Self {
        self.tools = tools;
        self
    }

    /// Set how the model should choose among the configured tools.
    pub fn set_tool_choice(mut self, tool_choice: ToolChoice) -> Self {
        self.tool_choice = tool_choice;
        self
    }

    /// Set whether log probabilities are returned with the response.
    pub fn return_logprobs(mut self, return_logprobs: bool) -> Self {
        self.return_logprobs = return_logprobs;
        self
    }

    /// Set the generation constraint (e.g. grammar/regex), replacing any previous one.
    pub fn set_constraint(mut self, constraint: Constraint) -> Self {
        self.constraint = constraint;
        self
    }

    /// Replace the sampling parameters wholesale.
    pub fn set_sampling(mut self, params: SamplingParams) -> Self {
        self.sampling_params = params;
        self
    }

    /// Reset sampling to the deterministic (greedy) configuration.
    pub fn set_deterministic_sampler(mut self) -> Self {
        self.sampling_params = SamplingParams::deterministic();
        self
    }

    /// Set the sampling temperature.
    pub fn set_sampler_temperature(mut self, temperature: f64) -> Self {
        self.sampling_params.temperature = Some(temperature);
        self
    }

    /// Set top-k sampling.
    pub fn set_sampler_topk(mut self, topk: usize) -> Self {
        self.sampling_params.top_k = Some(topk);
        self
    }

    /// Set nucleus (top-p) sampling.
    pub fn set_sampler_topp(mut self, topp: f64) -> Self {
        self.sampling_params.top_p = Some(topp);
        self
    }

    /// Set min-p sampling.
    pub fn set_sampler_minp(mut self, minp: f64) -> Self {
        self.sampling_params.min_p = Some(minp);
        self
    }

    /// Set how many top logprobs to return per token.
    pub fn set_sampler_topn_logprobs(mut self, top_n_logprobs: usize) -> Self {
        self.sampling_params.top_n_logprobs = top_n_logprobs;
        self
    }

    /// Set the frequency penalty.
    pub fn set_sampler_frequency_penalty(mut self, frequency_penalty: f32) -> Self {
        self.sampling_params.frequency_penalty = Some(frequency_penalty);
        self
    }

    /// Set the presence penalty.
    pub fn set_sampler_presence_penalty(mut self, presence_penalty: f32) -> Self {
        self.sampling_params.presence_penalty = Some(presence_penalty);
        self
    }

    /// Set the stop tokens/sequences.
    pub fn set_sampler_stop_toks(mut self, stop_toks: StopTokens) -> Self {
        self.sampling_params.stop_toks = Some(stop_toks);
        self
    }

    /// Set the maximum generation length.
    pub fn set_sampler_max_len(mut self, max_len: usize) -> Self {
        self.sampling_params.max_len = Some(max_len);
        self
    }

    /// Set per-token logit biases, keyed by token id.
    pub fn set_sampler_logits_bias(mut self, logits_bias: HashMap<u32, f32>) -> Self {
        self.sampling_params.logits_bias = Some(logits_bias);
        self
    }

    /// Set the number of completions to generate.
    pub fn set_sampler_n_choices(mut self, n_choices: usize) -> Self {
        self.sampling_params.n_choices = n_choices;
        self
    }

    /// Set DRY (don't-repeat-yourself) sampling parameters.
    pub fn set_sampler_dry_params(mut self, dry_params: DrySamplingParams) -> Self {
        self.sampling_params.dry_params = Some(dry_params);
        self
    }

    /// Set the per-request thinking preference.
    pub fn enable_thinking(mut self, enable_thinking: bool) -> Self {
        self.enable_thinking = Some(enable_thinking);
        self
    }

    /// Set whether over-length sequences are truncated instead of rejected.
    pub fn with_truncate_sequence(mut self, truncate_sequence: bool) -> Self {
        self.truncate_sequence = truncate_sequence;
        self
    }
}
696
697impl RequestLike for RequestBuilder {
698 fn messages_ref(&self) -> &[IndexMap<String, MessageContent>] {
699 &self.messages
700 }
701
702 fn images_ref(&self) -> &[DynamicImage] {
703 &self.images
704 }
705
706 fn take_messages(&mut self) -> RequestMessage {
707 if self.images.is_empty() && self.audios.is_empty() {
708 let mut other = Vec::new();
709 std::mem::swap(&mut other, &mut self.messages);
710 RequestMessage::Chat {
711 messages: other,
712 enable_thinking: self.enable_thinking,
713 reasoning_effort: None,
714 }
715 } else {
716 let mut other_messages = Vec::new();
717 std::mem::swap(&mut other_messages, &mut self.messages);
718 let mut other_images = Vec::new();
719 std::mem::swap(&mut other_images, &mut self.images);
720 let mut other_audios = Vec::new();
721 std::mem::swap(&mut other_audios, &mut self.audios);
722 RequestMessage::VisionChat {
723 images: other_images,
724 messages: other_messages,
725 audios: other_audios,
726 enable_thinking: self.enable_thinking,
727 reasoning_effort: None,
728 }
729 }
730 }
731
732 fn enable_search(&self) -> Option<bool> {
733 self.web_search_options.as_ref().map(|_| true)
734 }
735
736 fn take_logits_processors(&mut self) -> Option<Vec<Arc<dyn CustomLogitsProcessor>>> {
737 if self.logits_processors.is_empty() {
738 None
739 } else {
740 let mut other = Vec::new();
741 std::mem::swap(&mut other, &mut self.logits_processors);
742 Some(other)
743 }
744 }
745
746 fn take_adapters(&mut self) -> Option<Vec<String>> {
747 if self.adapters.is_empty() {
748 None
749 } else {
750 let mut other = Vec::new();
751 std::mem::swap(&mut other, &mut self.adapters);
752 Some(other)
753 }
754 }
755
756 fn return_logprobs(&self) -> bool {
757 self.return_logprobs
758 }
759
760 fn take_constraint(&mut self) -> Constraint {
761 let mut other = Constraint::None;
762 std::mem::swap(&mut other, &mut self.constraint);
763 other
764 }
765
766 fn take_tools(&mut self) -> Option<(Vec<Tool>, ToolChoice)> {
767 if self.tools.is_empty() {
768 None
769 } else {
770 let mut other_ts = Vec::new();
771 std::mem::swap(&mut other_ts, &mut self.tools);
772 let mut other_tc = ToolChoice::Auto;
773 std::mem::swap(&mut other_tc, &mut self.tool_choice);
774 Some((other_ts, other_tc))
775 }
776 }
777
778 fn take_sampling_params(&mut self) -> SamplingParams {
779 let mut other = SamplingParams::deterministic();
780 std::mem::swap(&mut other, &mut self.sampling_params);
781 other
782 }
783
784 fn take_web_search_options(&mut self) -> Option<WebSearchOptions> {
785 let mut other = None;
786 std::mem::swap(&mut other, &mut self.web_search_options);
787 other
788 }
789
790 fn truncate_sequence(&self) -> bool {
791 self.truncate_sequence
792 }
793}
794
/// One input to an embedding request: either raw text or pre-tokenized ids.
#[derive(Clone, Debug)]
pub enum EmbeddingRequestInput {
    /// A text prompt to be tokenized by the engine.
    Prompt(String),
    /// Already-tokenized input ids.
    Tokens(Vec<u32>),
}
803
impl EmbeddingRequestInput {
    /// Convert this input into the corresponding engine `RequestMessage`
    /// (`Embedding` for text, `EmbeddingTokens` for pre-tokenized ids).
    pub fn into_request_message(self) -> RequestMessage {
        match self {
            Self::Prompt(prompt) => RequestMessage::Embedding { prompt },
            Self::Tokens(prompt) => RequestMessage::EmbeddingTokens { prompt },
        }
    }
}
812
/// A batch embedding request: one or more inputs plus a truncation policy.
#[derive(Clone, Debug)]
pub struct EmbeddingRequest {
    /// The inputs to embed; guaranteed non-empty when built via the builder.
    pub inputs: Vec<EmbeddingRequestInput>,
    /// Whether over-length inputs are truncated instead of rejected.
    pub truncate_sequence: bool,
}
819
820impl EmbeddingRequest {
821 pub fn builder() -> EmbeddingRequestBuilder {
823 EmbeddingRequestBuilder::new()
824 }
825}
826
/// Builder for [`EmbeddingRequest`]; rejects empty input lists at `build` time.
#[derive(Clone, Debug, Default)]
pub struct EmbeddingRequestBuilder {
    // Inputs queued so far, in order.
    inputs: Vec<EmbeddingRequestInput>,
    // Truncation policy; defaults to `false` (reject over-length inputs).
    truncate_sequence: bool,
}
833
834impl EmbeddingRequestBuilder {
835 pub fn new() -> Self {
837 Self::default()
838 }
839
840 pub fn add_prompt(mut self, prompt: impl Into<String>) -> Self {
842 self.inputs
843 .push(EmbeddingRequestInput::Prompt(prompt.into()));
844 self
845 }
846
847 pub fn add_prompts<I, S>(mut self, prompts: I) -> Self
849 where
850 I: IntoIterator<Item = S>,
851 S: Into<String>,
852 {
853 self.inputs.extend(
854 prompts
855 .into_iter()
856 .map(|prompt| EmbeddingRequestInput::Prompt(prompt.into())),
857 );
858 self
859 }
860
861 pub fn add_tokens(mut self, tokens: impl Into<Vec<u32>>) -> Self {
863 self.inputs
864 .push(EmbeddingRequestInput::Tokens(tokens.into()));
865 self
866 }
867
868 pub fn add_tokens_batch<I>(mut self, batches: I) -> Self
870 where
871 I: IntoIterator<Item = Vec<u32>>,
872 {
873 self.inputs
874 .extend(batches.into_iter().map(EmbeddingRequestInput::Tokens));
875 self
876 }
877
878 pub fn with_truncate_sequence(mut self, truncate: bool) -> Self {
880 self.truncate_sequence = truncate;
881 self
882 }
883
884 pub fn build(self) -> anyhow::Result<EmbeddingRequest> {
885 if self.inputs.is_empty() {
886 anyhow::bail!("Embedding request must contain at least one input.");
887 }
888
889 Ok(EmbeddingRequest {
890 inputs: self.inputs,
891 truncate_sequence: self.truncate_sequence,
892 })
893 }
894}