1use std::{collections::HashMap, fmt::Display, sync::Arc};
2
3use super::*;
4use either::Either;
5use image::DynamicImage;
6use indexmap::IndexMap;
7use serde_json::{json, Value};
8
9pub trait RequestLike {
11 fn messages_ref(&self) -> &[IndexMap<String, MessageContent>];
12 fn images_ref(&self) -> &[DynamicImage];
13 fn take_messages(&mut self) -> RequestMessage;
14 fn take_logits_processors(&mut self) -> Option<Vec<Arc<dyn CustomLogitsProcessor>>>;
15 fn take_adapters(&mut self) -> Option<Vec<String>>;
16 fn return_logprobs(&self) -> bool;
17 fn enable_search(&self) -> Option<bool>;
18 fn take_constraint(&mut self) -> Constraint;
19 fn take_tools(&mut self) -> Option<(Vec<Tool>, ToolChoice)>;
20 fn take_sampling_params(&mut self) -> SamplingParams;
21 fn take_web_search_options(&mut self) -> Option<WebSearchOptions>;
22}
23
24#[derive(Debug, Clone, PartialEq)]
25pub struct TextMessages {
31 messages: Vec<IndexMap<String, MessageContent>>,
32 enable_thinking: Option<bool>,
33}
34
35impl From<TextMessages> for Vec<IndexMap<String, MessageContent>> {
36 fn from(value: TextMessages) -> Self {
37 value.messages
38 }
39}
40
/// The author role attached to a chat message.
///
/// The `Display` implementation in this file turns each variant into the string stored
/// under a message's `"role"` key.
#[derive(Clone, Debug, PartialEq)]
pub enum TextMessageRole {
    /// Rendered as `user`.
    User,
    /// Rendered as `assistant`.
    Assistant,
    /// Rendered as `system`.
    System,
    /// Rendered as `tool`.
    Tool,
    /// Any other role; rendered as the wrapped string verbatim.
    Custom(String),
}
50
51impl Display for TextMessageRole {
52 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
53 match self {
54 Self::User => write!(f, "user"),
55 Self::Assistant => write!(f, "assistant"),
56 Self::System => write!(f, "system"),
57 Self::Tool => write!(f, "tool"),
58 Self::Custom(c) => write!(f, "{c}"),
59 }
60 }
61}
62
63impl Default for TextMessages {
64 fn default() -> Self {
65 Self::new()
66 }
67}
68
69impl TextMessages {
70 pub fn new() -> Self {
71 Self {
72 messages: Vec::new(),
73 enable_thinking: None,
74 }
75 }
76
77 pub fn add_message(mut self, role: TextMessageRole, text: impl ToString) -> Self {
78 self.messages.push(IndexMap::from([
79 ("role".to_string(), Either::Left(role.to_string())),
80 ("content".to_string(), Either::Left(text.to_string())),
81 ]));
82 self
83 }
84
85 pub fn clear(mut self) -> Self {
86 self.messages.clear();
87 self
88 }
89
90 pub fn enable_thinking(mut self, enable_thinking: bool) -> Self {
91 self.enable_thinking = Some(enable_thinking);
92 self
93 }
94}
95
96impl RequestLike for TextMessages {
97 fn messages_ref(&self) -> &[IndexMap<String, MessageContent>] {
98 &self.messages
99 }
100 fn images_ref(&self) -> &[DynamicImage] {
101 &[]
102 }
103 fn take_messages(&mut self) -> RequestMessage {
104 let mut other = Vec::new();
105 std::mem::swap(&mut other, &mut self.messages);
106 RequestMessage::Chat {
107 messages: other,
108 enable_thinking: self.enable_thinking,
109 }
110 }
111 fn enable_search(&self) -> Option<bool> {
112 None
113 }
114 fn take_logits_processors(&mut self) -> Option<Vec<Arc<dyn CustomLogitsProcessor>>> {
115 None
116 }
117 fn take_adapters(&mut self) -> Option<Vec<String>> {
118 None
119 }
120 fn return_logprobs(&self) -> bool {
121 false
122 }
123 fn take_constraint(&mut self) -> Constraint {
124 Constraint::None
125 }
126 fn take_tools(&mut self) -> Option<(Vec<Tool>, ToolChoice)> {
127 None
128 }
129 fn take_sampling_params(&mut self) -> SamplingParams {
130 SamplingParams::deterministic()
131 }
132 fn take_web_search_options(&mut self) -> Option<WebSearchOptions> {
133 None
134 }
135}
136
137#[derive(Debug, Clone, PartialEq)]
138pub struct VisionMessages {
144 messages: Vec<IndexMap<String, MessageContent>>,
145 images: Vec<DynamicImage>,
146 audios: Vec<AudioInput>,
147 enable_thinking: Option<bool>,
148}
149
150impl Default for VisionMessages {
151 fn default() -> Self {
152 Self::new()
153 }
154}
155
156impl VisionMessages {
157 pub fn new() -> Self {
158 Self {
159 images: Vec::new(),
160 messages: Vec::new(),
161 audios: Vec::new(),
162 enable_thinking: None,
163 }
164 }
165
166 pub fn add_message(mut self, role: TextMessageRole, text: impl ToString) -> Self {
167 self.messages.push(IndexMap::from([
168 ("role".to_string(), Either::Left(role.to_string())),
169 ("content".to_string(), Either::Left(text.to_string())),
170 ]));
171 self
172 }
173
174 pub fn add_image_message(
175 self,
176 role: TextMessageRole,
177 text: impl ToString,
178 images: Vec<DynamicImage>,
179 model: &Model,
180 ) -> anyhow::Result<Self> {
181 self.add_multimodal_message(role, text, images, vec![], model)
182 }
183
184 pub fn add_audio_message(
185 self,
186 role: TextMessageRole,
187 text: impl ToString,
188 audios: Vec<AudioInput>,
189 model: &Model,
190 ) -> anyhow::Result<Self> {
191 self.add_multimodal_message(role, text, vec![], audios, model)
192 }
193
194 pub fn add_multimodal_message(
195 mut self,
196 role: TextMessageRole,
197 text: impl ToString,
198 images: Vec<DynamicImage>,
199 audios: Vec<AudioInput>,
200 model: &Model,
201 ) -> anyhow::Result<Self> {
202 let config = model.config().unwrap();
203 let prefixer = match &config.category {
204 ModelCategory::Vision { prefixer } => prefixer,
205 ModelCategory::Text
206 | ModelCategory::Diffusion
207 | ModelCategory::Speech
208 | ModelCategory::Audio => {
209 anyhow::bail!("`add_image_message` expects a vision model.")
210 }
211 };
212
213 let n_added_images = images.len();
215 let prefixed = prefixer.prefix_image(
216 (self.images.len()..self.images.len() + n_added_images).collect(),
217 &text.to_string(),
218 );
219 self.images.extend(images);
220
221 let n_added_audios = audios.len();
223 let prefixed = prefixer.prefix_audio(
224 (self.audios.len()..self.audios.len() + n_added_audios).collect(),
225 &prefixed,
226 );
227 self.audios.extend(audios);
228
229 if n_added_images > 0 {
230 self.messages.push(IndexMap::from([
231 ("role".to_string(), Either::Left(role.to_string())),
232 (
233 "content".to_string(),
234 Either::Right(vec![
235 IndexMap::from([("type".to_string(), Value::String("image".to_string()))]),
236 IndexMap::from([
237 ("type".to_string(), Value::String("text".to_string())),
238 ("text".to_string(), Value::String(prefixed)),
239 ]),
240 ]),
241 ),
242 ]));
243 } else {
244 self.messages.push(IndexMap::from([
245 ("role".to_string(), Either::Left(role.to_string())),
246 ("content".to_string(), Either::Left(prefixed)),
247 ]));
248 }
249 Ok(self)
250 }
251
252 pub fn clear(mut self) -> Self {
253 self.messages.clear();
254 self.images.clear();
255 self.audios.clear();
256
257 self
258 }
259
260 pub fn enable_thinking(mut self, enable_thinking: bool) -> Self {
261 self.enable_thinking = Some(enable_thinking);
262 self
263 }
264}
265
266impl RequestLike for VisionMessages {
267 fn messages_ref(&self) -> &[IndexMap<String, MessageContent>] {
268 &self.messages
269 }
270 fn images_ref(&self) -> &[DynamicImage] {
271 &self.images
272 }
273 fn take_messages(&mut self) -> RequestMessage {
274 let mut other_messages = Vec::new();
275 std::mem::swap(&mut other_messages, &mut self.messages);
276 let mut other_images = Vec::new();
277 std::mem::swap(&mut other_images, &mut self.images);
278 let mut other_audios = Vec::new();
279 std::mem::swap(&mut other_audios, &mut self.audios);
280 RequestMessage::VisionChat {
281 images: other_images,
282 messages: other_messages,
283 audios: other_audios,
284 enable_thinking: self.enable_thinking,
285 }
286 }
287 fn enable_search(&self) -> Option<bool> {
288 None
289 }
290 fn take_logits_processors(&mut self) -> Option<Vec<Arc<dyn CustomLogitsProcessor>>> {
291 None
292 }
293 fn take_adapters(&mut self) -> Option<Vec<String>> {
294 None
295 }
296 fn return_logprobs(&self) -> bool {
297 false
298 }
299 fn take_constraint(&mut self) -> Constraint {
300 Constraint::None
301 }
302 fn take_tools(&mut self) -> Option<(Vec<Tool>, ToolChoice)> {
303 None
304 }
305 fn take_sampling_params(&mut self) -> SamplingParams {
306 SamplingParams::deterministic()
307 }
308 fn take_web_search_options(&mut self) -> Option<WebSearchOptions> {
309 None
310 }
311}
312
313#[derive(Clone)]
314pub struct RequestBuilder {
324 messages: Vec<IndexMap<String, MessageContent>>,
325 images: Vec<DynamicImage>,
326 audios: Vec<AudioInput>,
327 logits_processors: Vec<Arc<dyn CustomLogitsProcessor>>,
328 adapters: Vec<String>,
329 return_logprobs: bool,
330 constraint: Constraint,
331 tools: Vec<Tool>,
332 tool_choice: ToolChoice,
333 sampling_params: SamplingParams,
334 web_search_options: Option<WebSearchOptions>,
335 enable_thinking: Option<bool>,
336}
337
338impl Default for RequestBuilder {
339 fn default() -> Self {
340 Self::new()
341 }
342}
343
344impl From<TextMessages> for RequestBuilder {
345 fn from(value: TextMessages) -> Self {
346 Self {
347 messages: value.messages,
348 images: Vec::new(),
349 audios: Vec::new(),
350 logits_processors: Vec::new(),
351 adapters: Vec::new(),
352 return_logprobs: false,
353 constraint: Constraint::None,
354 tools: Vec::new(),
355 tool_choice: ToolChoice::Auto,
356 sampling_params: SamplingParams::deterministic(),
357 web_search_options: None,
358 enable_thinking: None,
359 }
360 }
361}
362
363impl From<VisionMessages> for RequestBuilder {
364 fn from(value: VisionMessages) -> Self {
365 Self {
366 messages: value.messages,
367 images: value.images,
368 audios: value.audios,
369 logits_processors: Vec::new(),
370 adapters: Vec::new(),
371 return_logprobs: false,
372 constraint: Constraint::None,
373 tools: Vec::new(),
374 tool_choice: ToolChoice::Auto,
375 sampling_params: SamplingParams::deterministic(),
376 web_search_options: None,
377 enable_thinking: None,
378 }
379 }
380}
381
382impl RequestBuilder {
383 pub fn new() -> Self {
384 Self {
385 messages: Vec::new(),
386 images: Vec::new(),
387 audios: Vec::new(),
388 logits_processors: Vec::new(),
389 adapters: Vec::new(),
390 return_logprobs: false,
391 constraint: Constraint::None,
392 tools: Vec::new(),
393 tool_choice: ToolChoice::Auto,
394 sampling_params: SamplingParams::deterministic(),
395 web_search_options: None,
396 enable_thinking: None,
397 }
398 }
399
400 pub fn with_web_search_options(mut self, web_search_options: WebSearchOptions) -> Self {
401 self.web_search_options = Some(web_search_options);
402 self
403 }
404
405 pub fn add_message(mut self, role: TextMessageRole, text: impl ToString) -> Self {
410 self.messages.push(IndexMap::from([
411 ("role".to_string(), Either::Left(role.to_string())),
412 ("content".to_string(), Either::Left(text.to_string())),
413 ]));
414 self
415 }
416
417 pub fn add_tool_message(mut self, tool_content: impl ToString, tool_id: impl ToString) -> Self {
419 self.messages.push(IndexMap::from([
420 (
421 "role".to_string(),
422 Either::Left(TextMessageRole::Tool.to_string()),
423 ),
424 (
425 "content".to_string(),
426 Either::Left(tool_content.to_string()),
427 ),
428 (
429 "tool_call_id".to_string(),
430 Either::Left(tool_id.to_string()),
431 ),
432 ]));
433 self
434 }
435
436 pub fn add_message_with_tool_call(
437 mut self,
438 role: TextMessageRole,
439 text: impl ToString,
440 tool_calls: Vec<ToolCallResponse>,
441 ) -> Self {
442 let tool_messages = tool_calls
443 .iter()
444 .map(|t| {
445 IndexMap::from([
446 ("id".to_string(), Value::String(t.id.clone())),
447 ("type".to_string(), Value::String(t.tp.to_string())),
448 (
449 "function".to_string(),
450 json!({
451 "name": t.function.name,
452 "arguments": t.function.arguments,
453 }),
454 ),
455 ])
456 })
457 .collect();
458 self.messages.push(IndexMap::from([
459 ("role".to_string(), Either::Left(role.to_string())),
460 ("content".to_string(), Either::Left(text.to_string())),
461 ("function".to_string(), Either::Right(tool_messages)),
462 ]));
463 self
464 }
465
466 pub fn add_image_message(
467 self,
468 role: TextMessageRole,
469 text: impl ToString,
470 images: Vec<DynamicImage>,
471 model: &Model,
472 ) -> anyhow::Result<Self> {
473 self.add_multimodal_message(role, text, images, vec![], model)
474 }
475
476 pub fn add_audio_message(
477 self,
478 role: TextMessageRole,
479 text: impl ToString,
480 audios: Vec<AudioInput>,
481 model: &Model,
482 ) -> anyhow::Result<Self> {
483 self.add_multimodal_message(role, text, vec![], audios, model)
484 }
485
486 pub fn add_multimodal_message(
487 mut self,
488 role: TextMessageRole,
489 text: impl ToString,
490 images: Vec<DynamicImage>,
491 audios: Vec<AudioInput>,
492 model: &Model,
493 ) -> anyhow::Result<Self> {
494 let config = model.config().unwrap();
495 let prefixer = match &config.category {
496 ModelCategory::Vision { prefixer } => prefixer,
497 ModelCategory::Text
498 | ModelCategory::Diffusion
499 | ModelCategory::Speech
500 | ModelCategory::Audio => {
501 anyhow::bail!("`add_image_message` expects a vision model.")
502 }
503 };
504
505 let n_added_images = images.len();
507 let prefixed = prefixer.prefix_image(
508 (self.images.len()..self.images.len() + n_added_images).collect(),
509 &text.to_string(),
510 );
511 self.images.extend(images);
512
513 let n_added_audios = audios.len();
515 let prefixed = prefixer.prefix_audio(
516 (self.audios.len()..self.audios.len() + n_added_audios).collect(),
517 &prefixed,
518 );
519 self.audios.extend(audios);
520
521 if n_added_images > 0 {
522 self.messages.push(IndexMap::from([
523 ("role".to_string(), Either::Left(role.to_string())),
524 (
525 "content".to_string(),
526 Either::Right(vec![
527 IndexMap::from([("type".to_string(), Value::String("image".to_string()))]),
528 IndexMap::from([
529 ("type".to_string(), Value::String("text".to_string())),
530 ("text".to_string(), Value::String(prefixed)),
531 ]),
532 ]),
533 ),
534 ]));
535 } else {
536 self.messages.push(IndexMap::from([
537 ("role".to_string(), Either::Left(role.to_string())),
538 ("content".to_string(), Either::Left(prefixed)),
539 ]));
540 }
541 Ok(self)
542 }
543
544 pub fn add_logits_processor(mut self, processor: Arc<dyn CustomLogitsProcessor>) -> Self {
545 self.logits_processors.push(processor);
546 self
547 }
548
549 pub fn set_adapters(mut self, adapters: Vec<String>) -> Self {
550 self.adapters = adapters;
551 self
552 }
553
554 pub fn set_tools(mut self, tools: Vec<Tool>) -> Self {
556 self.tools = tools;
557 self
558 }
559
560 pub fn set_tool_choice(mut self, tool_choice: ToolChoice) -> Self {
561 self.tool_choice = tool_choice;
562 self
563 }
564
565 pub fn return_logprobs(mut self, return_logprobs: bool) -> Self {
566 self.return_logprobs = return_logprobs;
567 self
568 }
569
570 pub fn set_constraint(mut self, constraint: Constraint) -> Self {
571 self.constraint = constraint;
572 self
573 }
574
575 pub fn set_sampling(mut self, params: SamplingParams) -> Self {
577 self.sampling_params = params;
578 self
579 }
580
581 pub fn set_deterministic_sampler(mut self) -> Self {
587 self.sampling_params = SamplingParams::deterministic();
588 self
589 }
590
591 pub fn set_sampler_temperature(mut self, temperature: f64) -> Self {
592 self.sampling_params.temperature = Some(temperature);
593 self
594 }
595
596 pub fn set_sampler_topk(mut self, topk: usize) -> Self {
597 self.sampling_params.top_k = Some(topk);
598 self
599 }
600
601 pub fn set_sampler_topp(mut self, topp: f64) -> Self {
602 self.sampling_params.top_p = Some(topp);
603 self
604 }
605
606 pub fn set_sampler_minp(mut self, minp: f64) -> Self {
607 self.sampling_params.min_p = Some(minp);
608 self
609 }
610
611 pub fn set_sampler_topn_logprobs(mut self, top_n_logprobs: usize) -> Self {
612 self.sampling_params.top_n_logprobs = top_n_logprobs;
613 self
614 }
615
616 pub fn set_sampler_frequency_penalty(mut self, frequency_penalty: f32) -> Self {
617 self.sampling_params.frequency_penalty = Some(frequency_penalty);
618 self
619 }
620
621 pub fn set_sampler_presence_penalty(mut self, presence_penalty: f32) -> Self {
622 self.sampling_params.presence_penalty = Some(presence_penalty);
623 self
624 }
625
626 pub fn set_sampler_stop_toks(mut self, stop_toks: StopTokens) -> Self {
627 self.sampling_params.stop_toks = Some(stop_toks);
628 self
629 }
630
631 pub fn set_sampler_max_len(mut self, max_len: usize) -> Self {
632 self.sampling_params.max_len = Some(max_len);
633 self
634 }
635
636 pub fn set_sampler_logits_bias(mut self, logits_bias: HashMap<u32, f32>) -> Self {
637 self.sampling_params.logits_bias = Some(logits_bias);
638 self
639 }
640
641 pub fn set_sampler_n_choices(mut self, n_choices: usize) -> Self {
642 self.sampling_params.n_choices = n_choices;
643 self
644 }
645
646 pub fn set_sampler_dry_params(mut self, dry_params: DrySamplingParams) -> Self {
647 self.sampling_params.dry_params = Some(dry_params);
648 self
649 }
650
651 pub fn enable_thinking(mut self, enable_thinking: bool) -> Self {
652 self.enable_thinking = Some(enable_thinking);
653 self
654 }
655}
656
657impl RequestLike for RequestBuilder {
658 fn messages_ref(&self) -> &[IndexMap<String, MessageContent>] {
659 &self.messages
660 }
661
662 fn images_ref(&self) -> &[DynamicImage] {
663 &self.images
664 }
665
666 fn take_messages(&mut self) -> RequestMessage {
667 if self.images.is_empty() && self.audios.is_empty() {
668 let mut other = Vec::new();
669 std::mem::swap(&mut other, &mut self.messages);
670 RequestMessage::Chat {
671 messages: other,
672 enable_thinking: self.enable_thinking,
673 }
674 } else {
675 let mut other_messages = Vec::new();
676 std::mem::swap(&mut other_messages, &mut self.messages);
677 let mut other_images = Vec::new();
678 std::mem::swap(&mut other_images, &mut self.images);
679 let mut other_audios = Vec::new();
680 std::mem::swap(&mut other_audios, &mut self.audios);
681 RequestMessage::VisionChat {
682 images: other_images,
683 messages: other_messages,
684 audios: other_audios,
685 enable_thinking: self.enable_thinking,
686 }
687 }
688 }
689
690 fn enable_search(&self) -> Option<bool> {
691 self.web_search_options.as_ref().map(|_| true)
692 }
693
694 fn take_logits_processors(&mut self) -> Option<Vec<Arc<dyn CustomLogitsProcessor>>> {
695 if self.logits_processors.is_empty() {
696 None
697 } else {
698 let mut other = Vec::new();
699 std::mem::swap(&mut other, &mut self.logits_processors);
700 Some(other)
701 }
702 }
703
704 fn take_adapters(&mut self) -> Option<Vec<String>> {
705 if self.adapters.is_empty() {
706 None
707 } else {
708 let mut other = Vec::new();
709 std::mem::swap(&mut other, &mut self.adapters);
710 Some(other)
711 }
712 }
713
714 fn return_logprobs(&self) -> bool {
715 self.return_logprobs
716 }
717
718 fn take_constraint(&mut self) -> Constraint {
719 let mut other = Constraint::None;
720 std::mem::swap(&mut other, &mut self.constraint);
721 other
722 }
723
724 fn take_tools(&mut self) -> Option<(Vec<Tool>, ToolChoice)> {
725 if self.tools.is_empty() {
726 None
727 } else {
728 let mut other_ts = Vec::new();
729 std::mem::swap(&mut other_ts, &mut self.tools);
730 let mut other_tc = ToolChoice::Auto;
731 std::mem::swap(&mut other_tc, &mut self.tool_choice);
732 Some((other_ts, other_tc))
733 }
734 }
735
736 fn take_sampling_params(&mut self) -> SamplingParams {
737 let mut other = SamplingParams::deterministic();
738 std::mem::swap(&mut other, &mut self.sampling_params);
739 other
740 }
741
742 fn take_web_search_options(&mut self) -> Option<WebSearchOptions> {
743 let mut other = None;
744 std::mem::swap(&mut other, &mut self.web_search_options);
745 other
746 }
747}