1use std::{collections::HashMap, fmt::Display, sync::Arc};
2
3use super::*;
4use either::Either;
5use image::DynamicImage;
6use indexmap::IndexMap;
7use serde_json::{json, Value};
8
9pub trait RequestLike {
11 fn messages_ref(&self) -> &[IndexMap<String, MessageContent>];
12 fn images_ref(&self) -> &[DynamicImage];
13 fn take_messages(&mut self) -> RequestMessage;
14 fn take_logits_processors(&mut self) -> Option<Vec<Arc<dyn CustomLogitsProcessor>>>;
15 fn take_adapters(&mut self) -> Option<Vec<String>>;
16 fn return_logprobs(&self) -> bool;
17 fn enable_search(&self) -> Option<bool>;
18 fn take_constraint(&mut self) -> Constraint;
19 fn take_tools(&mut self) -> Option<(Vec<Tool>, ToolChoice)>;
20 fn take_sampling_params(&mut self) -> SamplingParams;
21 fn take_web_search_options(&mut self) -> Option<WebSearchOptions>;
22}
23
24#[derive(Debug, Clone, PartialEq)]
25pub struct TextMessages {
31 messages: Vec<IndexMap<String, MessageContent>>,
32 enable_thinking: Option<bool>,
33}
34
35impl From<TextMessages> for Vec<IndexMap<String, MessageContent>> {
36 fn from(value: TextMessages) -> Self {
37 value.messages
38 }
39}
40
/// The author of a chat message.
///
/// The [`Display`] form is the role string stored in the message map.
#[derive(Debug, Clone, PartialEq)]
pub enum TextMessageRole {
    /// Rendered as `user`.
    User,
    /// Rendered as `assistant`.
    Assistant,
    /// Rendered as `system`.
    System,
    /// Rendered as `tool`.
    Tool,
    /// Any other role; rendered verbatim.
    Custom(String),
}

impl Display for TextMessageRole {
    /// Writes the role name as-is. Width, fill and precision flags on the
    /// outer formatter are not applied.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let name: &str = match self {
            Self::User => "user",
            Self::Assistant => "assistant",
            Self::System => "system",
            Self::Tool => "tool",
            Self::Custom(custom) => custom.as_str(),
        };
        f.write_str(name)
    }
}
62
63impl Default for TextMessages {
64 fn default() -> Self {
65 Self::new()
66 }
67}
68
69impl TextMessages {
70 pub fn new() -> Self {
71 Self {
72 messages: Vec::new(),
73 enable_thinking: None,
74 }
75 }
76
77 pub fn add_message(mut self, role: TextMessageRole, text: impl ToString) -> Self {
78 self.messages.push(IndexMap::from([
79 ("role".to_string(), Either::Left(role.to_string())),
80 ("content".to_string(), Either::Left(text.to_string())),
81 ]));
82 self
83 }
84
85 pub fn clear(mut self) -> Self {
86 self.messages.clear();
87 self
88 }
89
90 pub fn enable_thinking(mut self, enable_thinking: bool) -> Self {
91 self.enable_thinking = Some(enable_thinking);
92 self
93 }
94}
95
96impl RequestLike for TextMessages {
97 fn messages_ref(&self) -> &[IndexMap<String, MessageContent>] {
98 &self.messages
99 }
100 fn images_ref(&self) -> &[DynamicImage] {
101 &[]
102 }
103 fn take_messages(&mut self) -> RequestMessage {
104 let mut other = Vec::new();
105 std::mem::swap(&mut other, &mut self.messages);
106 RequestMessage::Chat {
107 messages: other,
108 enable_thinking: self.enable_thinking,
109 }
110 }
111 fn enable_search(&self) -> Option<bool> {
112 None
113 }
114 fn take_logits_processors(&mut self) -> Option<Vec<Arc<dyn CustomLogitsProcessor>>> {
115 None
116 }
117 fn take_adapters(&mut self) -> Option<Vec<String>> {
118 None
119 }
120 fn return_logprobs(&self) -> bool {
121 false
122 }
123 fn take_constraint(&mut self) -> Constraint {
124 Constraint::None
125 }
126 fn take_tools(&mut self) -> Option<(Vec<Tool>, ToolChoice)> {
127 None
128 }
129 fn take_sampling_params(&mut self) -> SamplingParams {
130 SamplingParams::deterministic()
131 }
132 fn take_web_search_options(&mut self) -> Option<WebSearchOptions> {
133 None
134 }
135}
136
137#[derive(Debug, Clone, PartialEq)]
138pub struct VisionMessages {
144 messages: Vec<IndexMap<String, MessageContent>>,
145 images: Vec<DynamicImage>,
146 audios: Vec<AudioInput>,
147 enable_thinking: Option<bool>,
148}
149
150impl Default for VisionMessages {
151 fn default() -> Self {
152 Self::new()
153 }
154}
155
156impl VisionMessages {
157 pub fn new() -> Self {
158 Self {
159 images: Vec::new(),
160 messages: Vec::new(),
161 audios: Vec::new(),
162 enable_thinking: None,
163 }
164 }
165
166 pub fn add_message(mut self, role: TextMessageRole, text: impl ToString) -> Self {
167 self.messages.push(IndexMap::from([
168 ("role".to_string(), Either::Left(role.to_string())),
169 ("content".to_string(), Either::Left(text.to_string())),
170 ]));
171 self
172 }
173
174 pub fn add_image_message(
175 self,
176 role: TextMessageRole,
177 text: impl ToString,
178 images: Vec<DynamicImage>,
179 model: &Model,
180 ) -> anyhow::Result<Self> {
181 self.add_multimodal_message(role, text, images, vec![], model)
182 }
183
184 pub fn add_audio_message(
185 self,
186 role: TextMessageRole,
187 text: impl ToString,
188 audios: Vec<AudioInput>,
189 model: &Model,
190 ) -> anyhow::Result<Self> {
191 self.add_multimodal_message(role, text, vec![], audios, model)
192 }
193
194 pub fn add_multimodal_message(
195 mut self,
196 role: TextMessageRole,
197 text: impl ToString,
198 images: Vec<DynamicImage>,
199 audios: Vec<AudioInput>,
200 model: &Model,
201 ) -> anyhow::Result<Self> {
202 let config = model.config().unwrap();
203 let prefixer = match &config.category {
204 ModelCategory::Vision { prefixer } => prefixer,
205 ModelCategory::Text
206 | ModelCategory::Diffusion
207 | ModelCategory::Speech
208 | ModelCategory::Audio => {
209 anyhow::bail!("`add_image_message` expects a vision model.")
210 }
211 };
212
213 let n_added_images = images.len();
215 let prefixed = prefixer.prefix_image(
216 (self.images.len()..self.images.len() + n_added_images).collect(),
217 &text.to_string(),
218 );
219 self.images.extend(images);
220
221 let n_added_audios = audios.len();
223 let prefixed = prefixer.prefix_audio(
224 (self.audios.len()..self.audios.len() + n_added_audios).collect(),
225 &prefixed,
226 );
227 self.audios.extend(audios);
228
229 if n_added_images > 0 {
230 self.messages.push(IndexMap::from([
231 ("role".to_string(), Either::Left(role.to_string())),
232 (
233 "content".to_string(),
234 Either::Right(vec![
235 IndexMap::from([("type".to_string(), Value::String("image".to_string()))]),
236 IndexMap::from([
237 ("type".to_string(), Value::String("text".to_string())),
238 ("text".to_string(), Value::String(prefixed)),
239 ]),
240 ]),
241 ),
242 ]));
243 } else {
244 self.messages.push(IndexMap::from([
245 ("role".to_string(), Either::Left(role.to_string())),
246 ("content".to_string(), Either::Left(prefixed)),
247 ]));
248 }
249 Ok(self)
250 }
251
252 pub fn clear(mut self) -> Self {
253 self.messages.clear();
254 self.images.clear();
255
256 self
257 }
258
259 pub fn enable_thinking(mut self, enable_thinking: bool) -> Self {
260 self.enable_thinking = Some(enable_thinking);
261 self
262 }
263}
264
265impl RequestLike for VisionMessages {
266 fn messages_ref(&self) -> &[IndexMap<String, MessageContent>] {
267 &self.messages
268 }
269 fn images_ref(&self) -> &[DynamicImage] {
270 &self.images
271 }
272 fn take_messages(&mut self) -> RequestMessage {
273 let mut other_messages = Vec::new();
274 std::mem::swap(&mut other_messages, &mut self.messages);
275 let mut other_images = Vec::new();
276 std::mem::swap(&mut other_images, &mut self.images);
277 let mut other_audios = Vec::new();
278 std::mem::swap(&mut other_audios, &mut self.audios);
279 RequestMessage::VisionChat {
280 images: other_images,
281 messages: other_messages,
282 audios: other_audios,
283 enable_thinking: self.enable_thinking,
284 }
285 }
286 fn enable_search(&self) -> Option<bool> {
287 None
288 }
289 fn take_logits_processors(&mut self) -> Option<Vec<Arc<dyn CustomLogitsProcessor>>> {
290 None
291 }
292 fn take_adapters(&mut self) -> Option<Vec<String>> {
293 None
294 }
295 fn return_logprobs(&self) -> bool {
296 false
297 }
298 fn take_constraint(&mut self) -> Constraint {
299 Constraint::None
300 }
301 fn take_tools(&mut self) -> Option<(Vec<Tool>, ToolChoice)> {
302 None
303 }
304 fn take_sampling_params(&mut self) -> SamplingParams {
305 SamplingParams::deterministic()
306 }
307 fn take_web_search_options(&mut self) -> Option<WebSearchOptions> {
308 None
309 }
310}
311
312#[derive(Clone)]
313pub struct RequestBuilder {
323 messages: Vec<IndexMap<String, MessageContent>>,
324 images: Vec<DynamicImage>,
325 audios: Vec<AudioInput>,
326 logits_processors: Vec<Arc<dyn CustomLogitsProcessor>>,
327 adapters: Vec<String>,
328 return_logprobs: bool,
329 constraint: Constraint,
330 tools: Vec<Tool>,
331 tool_choice: ToolChoice,
332 sampling_params: SamplingParams,
333 web_search_options: Option<WebSearchOptions>,
334 enable_thinking: Option<bool>,
335}
336
337impl Default for RequestBuilder {
338 fn default() -> Self {
339 Self::new()
340 }
341}
342
343impl From<TextMessages> for RequestBuilder {
344 fn from(value: TextMessages) -> Self {
345 Self {
346 messages: value.messages,
347 images: Vec::new(),
348 audios: Vec::new(),
349 logits_processors: Vec::new(),
350 adapters: Vec::new(),
351 return_logprobs: false,
352 constraint: Constraint::None,
353 tools: Vec::new(),
354 tool_choice: ToolChoice::Auto,
355 sampling_params: SamplingParams::deterministic(),
356 web_search_options: None,
357 enable_thinking: None,
358 }
359 }
360}
361
362impl From<VisionMessages> for RequestBuilder {
363 fn from(value: VisionMessages) -> Self {
364 Self {
365 messages: value.messages,
366 images: value.images,
367 audios: value.audios,
368 logits_processors: Vec::new(),
369 adapters: Vec::new(),
370 return_logprobs: false,
371 constraint: Constraint::None,
372 tools: Vec::new(),
373 tool_choice: ToolChoice::Auto,
374 sampling_params: SamplingParams::deterministic(),
375 web_search_options: None,
376 enable_thinking: None,
377 }
378 }
379}
380
381impl RequestBuilder {
382 pub fn new() -> Self {
383 Self {
384 messages: Vec::new(),
385 images: Vec::new(),
386 audios: Vec::new(),
387 logits_processors: Vec::new(),
388 adapters: Vec::new(),
389 return_logprobs: false,
390 constraint: Constraint::None,
391 tools: Vec::new(),
392 tool_choice: ToolChoice::Auto,
393 sampling_params: SamplingParams::deterministic(),
394 web_search_options: None,
395 enable_thinking: None,
396 }
397 }
398
399 pub fn with_web_search_options(mut self, web_search_options: WebSearchOptions) -> Self {
400 self.web_search_options = Some(web_search_options);
401 self
402 }
403
404 pub fn add_message(mut self, role: TextMessageRole, text: impl ToString) -> Self {
409 self.messages.push(IndexMap::from([
410 ("role".to_string(), Either::Left(role.to_string())),
411 ("content".to_string(), Either::Left(text.to_string())),
412 ]));
413 self
414 }
415
416 pub fn add_tool_message(mut self, tool_content: impl ToString, tool_id: impl ToString) -> Self {
418 self.messages.push(IndexMap::from([
419 (
420 "role".to_string(),
421 Either::Left(TextMessageRole::Tool.to_string()),
422 ),
423 (
424 "content".to_string(),
425 Either::Left(tool_content.to_string()),
426 ),
427 (
428 "tool_call_id".to_string(),
429 Either::Left(tool_id.to_string()),
430 ),
431 ]));
432 self
433 }
434
435 pub fn add_message_with_tool_call(
436 mut self,
437 role: TextMessageRole,
438 text: impl ToString,
439 tool_calls: Vec<ToolCallResponse>,
440 ) -> Self {
441 let tool_messages = tool_calls
442 .iter()
443 .map(|t| {
444 IndexMap::from([
445 ("id".to_string(), Value::String(t.id.clone())),
446 ("type".to_string(), Value::String(t.tp.to_string())),
447 (
448 "function".to_string(),
449 json!({
450 "name": t.function.name,
451 "arguments": t.function.arguments,
452 }),
453 ),
454 ])
455 })
456 .collect();
457 self.messages.push(IndexMap::from([
458 ("role".to_string(), Either::Left(role.to_string())),
459 ("content".to_string(), Either::Left(text.to_string())),
460 ("function".to_string(), Either::Right(tool_messages)),
461 ]));
462 self
463 }
464
465 pub fn add_image_message(
466 self,
467 role: TextMessageRole,
468 text: impl ToString,
469 images: Vec<DynamicImage>,
470 model: &Model,
471 ) -> anyhow::Result<Self> {
472 self.add_multimodal_message(role, text, images, vec![], model)
473 }
474
475 pub fn add_audio_message(
476 self,
477 role: TextMessageRole,
478 text: impl ToString,
479 audios: Vec<AudioInput>,
480 model: &Model,
481 ) -> anyhow::Result<Self> {
482 self.add_multimodal_message(role, text, vec![], audios, model)
483 }
484
485 pub fn add_multimodal_message(
486 mut self,
487 role: TextMessageRole,
488 text: impl ToString,
489 images: Vec<DynamicImage>,
490 audios: Vec<AudioInput>,
491 model: &Model,
492 ) -> anyhow::Result<Self> {
493 let config = model.config().unwrap();
494 let prefixer = match &config.category {
495 ModelCategory::Vision { prefixer } => prefixer,
496 ModelCategory::Text
497 | ModelCategory::Diffusion
498 | ModelCategory::Speech
499 | ModelCategory::Audio => {
500 anyhow::bail!("`add_image_message` expects a vision model.")
501 }
502 };
503
504 let n_added_images = images.len();
506 let prefixed = prefixer.prefix_image(
507 (self.images.len()..self.images.len() + n_added_images).collect(),
508 &text.to_string(),
509 );
510 self.images.extend(images);
511
512 let n_added_audios = audios.len();
514 let prefixed = prefixer.prefix_audio(
515 (self.audios.len()..self.audios.len() + n_added_audios).collect(),
516 &prefixed,
517 );
518 self.audios.extend(audios);
519
520 if n_added_images > 0 {
521 self.messages.push(IndexMap::from([
522 ("role".to_string(), Either::Left(role.to_string())),
523 (
524 "content".to_string(),
525 Either::Right(vec![
526 IndexMap::from([("type".to_string(), Value::String("image".to_string()))]),
527 IndexMap::from([
528 ("type".to_string(), Value::String("text".to_string())),
529 ("text".to_string(), Value::String(prefixed)),
530 ]),
531 ]),
532 ),
533 ]));
534 } else {
535 self.messages.push(IndexMap::from([
536 ("role".to_string(), Either::Left(role.to_string())),
537 ("content".to_string(), Either::Left(prefixed)),
538 ]));
539 }
540 Ok(self)
541 }
542
543 pub fn add_logits_processor(mut self, processor: Arc<dyn CustomLogitsProcessor>) -> Self {
544 self.logits_processors.push(processor);
545 self
546 }
547
548 pub fn set_adapters(mut self, adapters: Vec<String>) -> Self {
549 self.adapters = adapters;
550 self
551 }
552
553 pub fn set_tools(mut self, tools: Vec<Tool>) -> Self {
555 self.tools = tools;
556 self
557 }
558
559 pub fn set_tool_choice(mut self, tool_choice: ToolChoice) -> Self {
560 self.tool_choice = tool_choice;
561 self
562 }
563
564 pub fn return_logprobs(mut self, return_logprobs: bool) -> Self {
565 self.return_logprobs = return_logprobs;
566 self
567 }
568
569 pub fn set_constraint(mut self, constraint: Constraint) -> Self {
570 self.constraint = constraint;
571 self
572 }
573
574 pub fn set_sampling(mut self, params: SamplingParams) -> Self {
576 self.sampling_params = params;
577 self
578 }
579
580 pub fn set_deterministic_sampler(mut self) -> Self {
586 self.sampling_params = SamplingParams::deterministic();
587 self
588 }
589
590 pub fn set_sampler_temperature(mut self, temperature: f64) -> Self {
591 self.sampling_params.temperature = Some(temperature);
592 self
593 }
594
595 pub fn set_sampler_topk(mut self, topk: usize) -> Self {
596 self.sampling_params.top_k = Some(topk);
597 self
598 }
599
600 pub fn set_sampler_topp(mut self, topp: f64) -> Self {
601 self.sampling_params.top_p = Some(topp);
602 self
603 }
604
605 pub fn set_sampler_minp(mut self, minp: f64) -> Self {
606 self.sampling_params.min_p = Some(minp);
607 self
608 }
609
610 pub fn set_sampler_topn_logprobs(mut self, top_n_logprobs: usize) -> Self {
611 self.sampling_params.top_n_logprobs = top_n_logprobs;
612 self
613 }
614
615 pub fn set_sampler_frequency_penalty(mut self, frequency_penalty: f32) -> Self {
616 self.sampling_params.frequency_penalty = Some(frequency_penalty);
617 self
618 }
619
620 pub fn set_sampler_presence_penalty(mut self, presence_penalty: f32) -> Self {
621 self.sampling_params.presence_penalty = Some(presence_penalty);
622 self
623 }
624
625 pub fn set_sampler_stop_toks(mut self, stop_toks: StopTokens) -> Self {
626 self.sampling_params.stop_toks = Some(stop_toks);
627 self
628 }
629
630 pub fn set_sampler_max_len(mut self, max_len: usize) -> Self {
631 self.sampling_params.max_len = Some(max_len);
632 self
633 }
634
635 pub fn set_sampler_logits_bias(mut self, logits_bias: HashMap<u32, f32>) -> Self {
636 self.sampling_params.logits_bias = Some(logits_bias);
637 self
638 }
639
640 pub fn set_sampler_n_choices(mut self, n_choices: usize) -> Self {
641 self.sampling_params.n_choices = n_choices;
642 self
643 }
644
645 pub fn set_sampler_dry_params(mut self, dry_params: DrySamplingParams) -> Self {
646 self.sampling_params.dry_params = Some(dry_params);
647 self
648 }
649
650 pub fn enable_thinking(mut self, enable_thinking: bool) -> Self {
651 self.enable_thinking = Some(enable_thinking);
652 self
653 }
654}
655
656impl RequestLike for RequestBuilder {
657 fn messages_ref(&self) -> &[IndexMap<String, MessageContent>] {
658 &self.messages
659 }
660
661 fn images_ref(&self) -> &[DynamicImage] {
662 &self.images
663 }
664
665 fn take_messages(&mut self) -> RequestMessage {
666 if self.images.is_empty() && self.audios.is_empty() {
667 let mut other = Vec::new();
668 std::mem::swap(&mut other, &mut self.messages);
669 RequestMessage::Chat {
670 messages: other,
671 enable_thinking: self.enable_thinking,
672 }
673 } else {
674 let mut other_messages = Vec::new();
675 std::mem::swap(&mut other_messages, &mut self.messages);
676 let mut other_images = Vec::new();
677 std::mem::swap(&mut other_images, &mut self.images);
678 let mut other_audios = Vec::new();
679 std::mem::swap(&mut other_audios, &mut self.audios);
680 RequestMessage::VisionChat {
681 images: other_images,
682 messages: other_messages,
683 audios: other_audios,
684 enable_thinking: self.enable_thinking,
685 }
686 }
687 }
688
689 fn enable_search(&self) -> Option<bool> {
690 self.enable_thinking
691 }
692
693 fn take_logits_processors(&mut self) -> Option<Vec<Arc<dyn CustomLogitsProcessor>>> {
694 if self.logits_processors.is_empty() {
695 None
696 } else {
697 let mut other = Vec::new();
698 std::mem::swap(&mut other, &mut self.logits_processors);
699 Some(other)
700 }
701 }
702
703 fn take_adapters(&mut self) -> Option<Vec<String>> {
704 if self.adapters.is_empty() {
705 None
706 } else {
707 let mut other = Vec::new();
708 std::mem::swap(&mut other, &mut self.adapters);
709 Some(other)
710 }
711 }
712
713 fn return_logprobs(&self) -> bool {
714 self.return_logprobs
715 }
716
717 fn take_constraint(&mut self) -> Constraint {
718 let mut other = Constraint::None;
719 std::mem::swap(&mut other, &mut self.constraint);
720 other
721 }
722
723 fn take_tools(&mut self) -> Option<(Vec<Tool>, ToolChoice)> {
724 if self.tools.is_empty() {
725 None
726 } else {
727 let mut other_ts = Vec::new();
728 std::mem::swap(&mut other_ts, &mut self.tools);
729 let mut other_tc = ToolChoice::Auto;
730 std::mem::swap(&mut other_tc, &mut self.tool_choice);
731 Some((other_ts, other_tc))
732 }
733 }
734
735 fn take_sampling_params(&mut self) -> SamplingParams {
736 let mut other = SamplingParams::deterministic();
737 std::mem::swap(&mut other, &mut self.sampling_params);
738 other
739 }
740
741 fn take_web_search_options(&mut self) -> Option<WebSearchOptions> {
742 let mut other = None;
743 std::mem::swap(&mut other, &mut self.web_search_options);
744 other
745 }
746}