1use std::{collections::HashMap, fmt::Display, sync::Arc};
2
3use super::*;
4use either::Either;
5use image::DynamicImage;
6use indexmap::IndexMap;
7use serde_json::{json, Value};
8
/// Common interface over the request types defined in this module.
///
/// Plain getters borrow from the request. The `take_*` methods receive
/// `&mut self` so that an implementation can move the value out instead of
/// cloning it.
pub trait RequestLike {
    /// Borrows the chat messages currently held by the request.
    fn messages_ref(&self) -> &[IndexMap<String, MessageContent>];
    /// Borrows the images currently attached to the request.
    fn images_ref(&self) -> &[DynamicImage];
    /// Returns the request's messages packaged as a [`RequestMessage`].
    fn take_messages(&mut self) -> RequestMessage;
    /// Returns the custom logits processors, or `None` when there are none.
    fn take_logits_processors(&mut self) -> Option<Vec<Arc<dyn CustomLogitsProcessor>>>;
    /// Returns the adapter names, or `None` when there are none.
    fn take_adapters(&mut self) -> Option<Vec<String>>;
    /// Whether log-probabilities were requested.
    fn return_logprobs(&self) -> bool;
    /// NOTE(review): despite the name, the implementations in this file use
    /// this value as the "enable thinking" flag (it is what they place in
    /// `RequestMessage::*::enable_thinking`) — confirm the intended meaning.
    fn enable_search(&self) -> Option<bool>;
    /// Returns the generation constraint.
    fn take_constraint(&mut self) -> Constraint;
    /// Returns the tools together with the tool choice, or `None`.
    fn take_tools(&mut self) -> Option<(Vec<Tool>, ToolChoice)>;
    /// Returns the sampling parameters.
    fn take_sampling_params(&mut self) -> SamplingParams;
    /// Returns the web search options, if any.
    fn take_web_search_options(&mut self) -> Option<WebSearchOptions>;
}
23
/// An ordered list of chat messages.
///
/// Each message is an [`IndexMap`] from field name to [`MessageContent`];
/// `add_message` stores a `role` and a `content` entry per message.
#[derive(Debug, Clone, PartialEq)]
pub struct TextMessages(Vec<IndexMap<String, MessageContent>>);
31
32impl From<TextMessages> for Vec<IndexMap<String, MessageContent>> {
33 fn from(value: TextMessages) -> Self {
34 value.0
35 }
36}
37
/// The role attached to a chat message.
///
/// The [`Display`] output is the string stored under the message's `role` key
/// by the builders in this module.
#[derive(Debug, Clone, PartialEq)]
pub enum TextMessageRole {
    /// Displays as `user`.
    User,
    /// Displays as `assistant`.
    Assistant,
    /// Displays as `system`.
    System,
    /// Displays as `tool`.
    Tool,
    /// Displays as the wrapped string, verbatim.
    Custom(String),
}

impl Display for TextMessageRole {
    /// Writes the role name; `Custom` writes its inner string unchanged.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let label: &str = match self {
            Self::User => "user",
            Self::Assistant => "assistant",
            Self::System => "system",
            Self::Tool => "tool",
            Self::Custom(name) => name.as_str(),
        };
        f.write_str(label)
    }
}
59
60impl Default for TextMessages {
61 fn default() -> Self {
62 Self::new()
63 }
64}
65
66impl TextMessages {
67 pub fn new() -> Self {
68 Self(Vec::new())
69 }
70
71 pub fn add_message(mut self, role: TextMessageRole, text: impl ToString) -> Self {
72 self.0.push(IndexMap::from([
73 ("role".to_string(), Either::Left(role.to_string())),
74 ("content".to_string(), Either::Left(text.to_string())),
75 ]));
76 self
77 }
78
79 pub fn clear(mut self) -> Self {
80 self.0.clear();
81 self
82 }
83}
84
85impl RequestLike for TextMessages {
86 fn messages_ref(&self) -> &[IndexMap<String, MessageContent>] {
87 &self.0
88 }
89 fn images_ref(&self) -> &[DynamicImage] {
90 &[]
91 }
92 fn take_messages(&mut self) -> RequestMessage {
93 let mut other = Vec::new();
94 std::mem::swap(&mut other, &mut self.0);
95 RequestMessage::Chat {
96 messages: other,
97 enable_thinking: self.enable_search(),
98 }
99 }
100 fn enable_search(&self) -> Option<bool> {
101 None
102 }
103 fn take_logits_processors(&mut self) -> Option<Vec<Arc<dyn CustomLogitsProcessor>>> {
104 None
105 }
106 fn take_adapters(&mut self) -> Option<Vec<String>> {
107 None
108 }
109 fn return_logprobs(&self) -> bool {
110 false
111 }
112 fn take_constraint(&mut self) -> Constraint {
113 Constraint::None
114 }
115 fn take_tools(&mut self) -> Option<(Vec<Tool>, ToolChoice)> {
116 None
117 }
118 fn take_sampling_params(&mut self) -> SamplingParams {
119 SamplingParams::deterministic()
120 }
121 fn take_web_search_options(&mut self) -> Option<WebSearchOptions> {
122 None
123 }
124}
125
/// Chat messages together with the images and audio inputs they refer to.
#[derive(Debug, Clone, PartialEq)]
pub struct VisionMessages {
    /// Messages in the order they were added.
    messages: Vec<IndexMap<String, MessageContent>>,
    /// Every image added so far, across all messages.
    images: Vec<DynamicImage>,
    /// Every audio input added so far, across all messages.
    audios: Vec<AudioInput>,
}
137
138impl Default for VisionMessages {
139 fn default() -> Self {
140 Self::new()
141 }
142}
143
144impl VisionMessages {
145 pub fn new() -> Self {
146 Self {
147 images: Vec::new(),
148 messages: Vec::new(),
149 audios: Vec::new(),
150 }
151 }
152
153 pub fn add_message(mut self, role: TextMessageRole, text: impl ToString) -> Self {
154 self.messages.push(IndexMap::from([
155 ("role".to_string(), Either::Left(role.to_string())),
156 ("content".to_string(), Either::Left(text.to_string())),
157 ]));
158 self
159 }
160
161 pub fn add_image_message(
162 self,
163 role: TextMessageRole,
164 text: impl ToString,
165 images: Vec<DynamicImage>,
166 model: &Model,
167 ) -> anyhow::Result<Self> {
168 self.add_multimodal_message(role, text, images, vec![], model)
169 }
170
171 pub fn add_audio_message(
172 self,
173 role: TextMessageRole,
174 text: impl ToString,
175 audios: Vec<AudioInput>,
176 model: &Model,
177 ) -> anyhow::Result<Self> {
178 self.add_multimodal_message(role, text, vec![], audios, model)
179 }
180
181 pub fn add_multimodal_message(
182 mut self,
183 role: TextMessageRole,
184 text: impl ToString,
185 images: Vec<DynamicImage>,
186 audios: Vec<AudioInput>,
187 model: &Model,
188 ) -> anyhow::Result<Self> {
189 let prefixer = match &model.config().category {
190 ModelCategory::Vision { prefixer } => prefixer,
191 ModelCategory::Text
192 | ModelCategory::Diffusion
193 | ModelCategory::Speech
194 | ModelCategory::Audio => {
195 anyhow::bail!("`add_image_message` expects a vision model.")
196 }
197 };
198
199 let n_added_images = images.len();
201 let prefixed = prefixer.prefix_image(
202 (self.images.len()..self.images.len() + n_added_images).collect(),
203 &text.to_string(),
204 );
205 self.images.extend(images);
206
207 let n_added_audios = audios.len();
209 let prefixed = prefixer.prefix_audio(
210 (self.audios.len()..self.audios.len() + n_added_audios).collect(),
211 &prefixed,
212 );
213 self.audios.extend(audios);
214
215 if n_added_images > 0 {
216 self.messages.push(IndexMap::from([
217 ("role".to_string(), Either::Left(role.to_string())),
218 (
219 "content".to_string(),
220 Either::Right(vec![
221 IndexMap::from([("type".to_string(), Value::String("image".to_string()))]),
222 IndexMap::from([
223 ("type".to_string(), Value::String("text".to_string())),
224 ("text".to_string(), Value::String(prefixed)),
225 ]),
226 ]),
227 ),
228 ]));
229 } else {
230 self.messages.push(IndexMap::from([
231 ("role".to_string(), Either::Left(role.to_string())),
232 ("content".to_string(), Either::Left(prefixed)),
233 ]));
234 }
235 Ok(self)
236 }
237
238 pub fn clear(mut self) -> Self {
239 self.messages.clear();
240 self.images.clear();
241
242 self
243 }
244}
245
246impl RequestLike for VisionMessages {
247 fn messages_ref(&self) -> &[IndexMap<String, MessageContent>] {
248 &self.messages
249 }
250 fn images_ref(&self) -> &[DynamicImage] {
251 &self.images
252 }
253 fn take_messages(&mut self) -> RequestMessage {
254 let mut other_messages = Vec::new();
255 std::mem::swap(&mut other_messages, &mut self.messages);
256 let mut other_images = Vec::new();
257 std::mem::swap(&mut other_images, &mut self.images);
258 let mut other_audios = Vec::new();
259 std::mem::swap(&mut other_audios, &mut self.audios);
260 RequestMessage::VisionChat {
261 images: other_images,
262 messages: other_messages,
263 audios: other_audios,
264 enable_thinking: self.enable_search(),
265 }
266 }
267 fn enable_search(&self) -> Option<bool> {
268 None
269 }
270 fn take_logits_processors(&mut self) -> Option<Vec<Arc<dyn CustomLogitsProcessor>>> {
271 None
272 }
273 fn take_adapters(&mut self) -> Option<Vec<String>> {
274 None
275 }
276 fn return_logprobs(&self) -> bool {
277 false
278 }
279 fn take_constraint(&mut self) -> Constraint {
280 Constraint::None
281 }
282 fn take_tools(&mut self) -> Option<(Vec<Tool>, ToolChoice)> {
283 None
284 }
285 fn take_sampling_params(&mut self) -> SamplingParams {
286 SamplingParams::deterministic()
287 }
288 fn take_web_search_options(&mut self) -> Option<WebSearchOptions> {
289 None
290 }
291}
292
/// Builder for a request: messages and media plus the per-request settings
/// (sampling, tools, constraint, adapters, logits processors, web search,
/// thinking flag) exposed through [`RequestLike`].
#[derive(Clone)]
pub struct RequestBuilder {
    /// Messages in the order they were added.
    messages: Vec<IndexMap<String, MessageContent>>,
    /// Images added through the image/multimodal message methods.
    images: Vec<DynamicImage>,
    /// Audio inputs added through the audio/multimodal message methods.
    audios: Vec<AudioInput>,
    /// Custom logits processors, in the order they were added.
    logits_processors: Vec<Arc<dyn CustomLogitsProcessor>>,
    /// Adapter names set via `set_adapters`.
    adapters: Vec<String>,
    /// Whether log-probabilities are requested.
    return_logprobs: bool,
    /// Generation constraint; `Constraint::None` by default.
    constraint: Constraint,
    /// Tools offered to the model; empty means `take_tools` yields `None`.
    tools: Vec<Tool>,
    /// Tool choice reported alongside `tools`; `ToolChoice::Auto` by default.
    tool_choice: ToolChoice,
    /// Sampling parameters; starts as `SamplingParams::deterministic()`.
    sampling_params: SamplingParams,
    /// Web search options, if set.
    web_search_options: Option<WebSearchOptions>,
    /// Thinking flag; `None` until `enable_thinking` is called.
    enable_thinking: Option<bool>,
}
317
318impl Default for RequestBuilder {
319 fn default() -> Self {
320 Self::new()
321 }
322}
323
324impl From<TextMessages> for RequestBuilder {
325 fn from(value: TextMessages) -> Self {
326 Self {
327 messages: value.0,
328 images: Vec::new(),
329 audios: Vec::new(),
330 logits_processors: Vec::new(),
331 adapters: Vec::new(),
332 return_logprobs: false,
333 constraint: Constraint::None,
334 tools: Vec::new(),
335 tool_choice: ToolChoice::Auto,
336 sampling_params: SamplingParams::deterministic(),
337 web_search_options: None,
338 enable_thinking: None,
339 }
340 }
341}
342
343impl From<VisionMessages> for RequestBuilder {
344 fn from(value: VisionMessages) -> Self {
345 Self {
346 messages: value.messages,
347 images: value.images,
348 audios: value.audios,
349 logits_processors: Vec::new(),
350 adapters: Vec::new(),
351 return_logprobs: false,
352 constraint: Constraint::None,
353 tools: Vec::new(),
354 tool_choice: ToolChoice::Auto,
355 sampling_params: SamplingParams::deterministic(),
356 web_search_options: None,
357 enable_thinking: None,
358 }
359 }
360}
361
362impl RequestBuilder {
363 pub fn new() -> Self {
364 Self {
365 messages: Vec::new(),
366 images: Vec::new(),
367 audios: Vec::new(),
368 logits_processors: Vec::new(),
369 adapters: Vec::new(),
370 return_logprobs: false,
371 constraint: Constraint::None,
372 tools: Vec::new(),
373 tool_choice: ToolChoice::Auto,
374 sampling_params: SamplingParams::deterministic(),
375 web_search_options: None,
376 enable_thinking: None,
377 }
378 }
379
380 pub fn with_web_search_options(mut self, web_search_options: WebSearchOptions) -> Self {
381 self.web_search_options = Some(web_search_options);
382 self
383 }
384
385 pub fn add_message(mut self, role: TextMessageRole, text: impl ToString) -> Self {
390 self.messages.push(IndexMap::from([
391 ("role".to_string(), Either::Left(role.to_string())),
392 ("content".to_string(), Either::Left(text.to_string())),
393 ]));
394 self
395 }
396
397 pub fn add_tool_message(mut self, tool_content: impl ToString, tool_id: impl ToString) -> Self {
399 self.messages.push(IndexMap::from([
400 (
401 "role".to_string(),
402 Either::Left(TextMessageRole::Tool.to_string()),
403 ),
404 (
405 "content".to_string(),
406 Either::Left(tool_content.to_string()),
407 ),
408 (
409 "tool_call_id".to_string(),
410 Either::Left(tool_id.to_string()),
411 ),
412 ]));
413 self
414 }
415
416 pub fn add_message_with_tool_call(
417 mut self,
418 role: TextMessageRole,
419 text: impl ToString,
420 tool_calls: Vec<ToolCallResponse>,
421 ) -> Self {
422 let tool_messages = tool_calls
423 .iter()
424 .map(|t| {
425 IndexMap::from([
426 ("id".to_string(), Value::String(t.id.clone())),
427 ("type".to_string(), Value::String(t.tp.to_string())),
428 (
429 "function".to_string(),
430 json!({
431 "name": t.function.name,
432 "arguments": t.function.arguments,
433 }),
434 ),
435 ])
436 })
437 .collect();
438 self.messages.push(IndexMap::from([
439 ("role".to_string(), Either::Left(role.to_string())),
440 ("content".to_string(), Either::Left(text.to_string())),
441 ("function".to_string(), Either::Right(tool_messages)),
442 ]));
443 self
444 }
445
446 pub fn add_image_message(
447 self,
448 role: TextMessageRole,
449 text: impl ToString,
450 images: Vec<DynamicImage>,
451 model: &Model,
452 ) -> anyhow::Result<Self> {
453 self.add_multimodal_message(role, text, images, vec![], model)
454 }
455
456 pub fn add_audio_message(
457 self,
458 role: TextMessageRole,
459 text: impl ToString,
460 audios: Vec<AudioInput>,
461 model: &Model,
462 ) -> anyhow::Result<Self> {
463 self.add_multimodal_message(role, text, vec![], audios, model)
464 }
465
466 pub fn add_multimodal_message(
467 mut self,
468 role: TextMessageRole,
469 text: impl ToString,
470 images: Vec<DynamicImage>,
471 audios: Vec<AudioInput>,
472 model: &Model,
473 ) -> anyhow::Result<Self> {
474 let prefixer = match &model.config().category {
475 ModelCategory::Vision { prefixer } => prefixer,
476 ModelCategory::Text
477 | ModelCategory::Diffusion
478 | ModelCategory::Speech
479 | ModelCategory::Audio => {
480 anyhow::bail!("`add_image_message` expects a vision model.")
481 }
482 };
483
484 let n_added_images = images.len();
486 let prefixed = prefixer.prefix_image(
487 (self.images.len()..self.images.len() + n_added_images).collect(),
488 &text.to_string(),
489 );
490 self.images.extend(images);
491
492 let n_added_audios = audios.len();
494 let prefixed = prefixer.prefix_audio(
495 (self.audios.len()..self.audios.len() + n_added_audios).collect(),
496 &prefixed,
497 );
498 self.audios.extend(audios);
499
500 if n_added_images > 0 {
501 self.messages.push(IndexMap::from([
502 ("role".to_string(), Either::Left(role.to_string())),
503 (
504 "content".to_string(),
505 Either::Right(vec![
506 IndexMap::from([("type".to_string(), Value::String("image".to_string()))]),
507 IndexMap::from([
508 ("type".to_string(), Value::String("text".to_string())),
509 ("text".to_string(), Value::String(prefixed)),
510 ]),
511 ]),
512 ),
513 ]));
514 } else {
515 self.messages.push(IndexMap::from([
516 ("role".to_string(), Either::Left(role.to_string())),
517 ("content".to_string(), Either::Left(prefixed)),
518 ]));
519 }
520 Ok(self)
521 }
522
523 pub fn add_logits_processor(mut self, processor: Arc<dyn CustomLogitsProcessor>) -> Self {
524 self.logits_processors.push(processor);
525 self
526 }
527
528 pub fn set_adapters(mut self, adapters: Vec<String>) -> Self {
529 self.adapters = adapters;
530 self
531 }
532
533 pub fn set_tools(mut self, tools: Vec<Tool>) -> Self {
535 self.tools = tools;
536 self
537 }
538
539 pub fn set_tool_choice(mut self, tool_choice: ToolChoice) -> Self {
540 self.tool_choice = tool_choice;
541 self
542 }
543
544 pub fn return_logprobs(mut self, return_logprobs: bool) -> Self {
545 self.return_logprobs = return_logprobs;
546 self
547 }
548
549 pub fn set_constraint(mut self, constraint: Constraint) -> Self {
550 self.constraint = constraint;
551 self
552 }
553
554 pub fn set_sampling(mut self, params: SamplingParams) -> Self {
556 self.sampling_params = params;
557 self
558 }
559
560 pub fn set_deterministic_sampler(mut self) -> Self {
566 self.sampling_params = SamplingParams::deterministic();
567 self
568 }
569
570 pub fn set_sampler_temperature(mut self, temperature: f64) -> Self {
571 self.sampling_params.temperature = Some(temperature);
572 self
573 }
574
575 pub fn set_sampler_topk(mut self, topk: usize) -> Self {
576 self.sampling_params.top_k = Some(topk);
577 self
578 }
579
580 pub fn set_sampler_topp(mut self, topp: f64) -> Self {
581 self.sampling_params.top_p = Some(topp);
582 self
583 }
584
585 pub fn set_sampler_minp(mut self, minp: f64) -> Self {
586 self.sampling_params.min_p = Some(minp);
587 self
588 }
589
590 pub fn set_sampler_topn_logprobs(mut self, top_n_logprobs: usize) -> Self {
591 self.sampling_params.top_n_logprobs = top_n_logprobs;
592 self
593 }
594
595 pub fn set_sampler_frequency_penalty(mut self, frequency_penalty: f32) -> Self {
596 self.sampling_params.frequency_penalty = Some(frequency_penalty);
597 self
598 }
599
600 pub fn set_sampler_presence_penalty(mut self, presence_penalty: f32) -> Self {
601 self.sampling_params.presence_penalty = Some(presence_penalty);
602 self
603 }
604
605 pub fn set_sampler_stop_toks(mut self, stop_toks: StopTokens) -> Self {
606 self.sampling_params.stop_toks = Some(stop_toks);
607 self
608 }
609
610 pub fn set_sampler_max_len(mut self, max_len: usize) -> Self {
611 self.sampling_params.max_len = Some(max_len);
612 self
613 }
614
615 pub fn set_sampler_logits_bias(mut self, logits_bias: HashMap<u32, f32>) -> Self {
616 self.sampling_params.logits_bias = Some(logits_bias);
617 self
618 }
619
620 pub fn set_sampler_n_choices(mut self, n_choices: usize) -> Self {
621 self.sampling_params.n_choices = n_choices;
622 self
623 }
624
625 pub fn set_sampler_dry_params(mut self, dry_params: DrySamplingParams) -> Self {
626 self.sampling_params.dry_params = Some(dry_params);
627 self
628 }
629
630 pub fn enable_thinking(mut self, enable_thinking: bool) -> Self {
631 self.enable_thinking = Some(enable_thinking);
632 self
633 }
634}
635
636impl RequestLike for RequestBuilder {
637 fn messages_ref(&self) -> &[IndexMap<String, MessageContent>] {
638 &self.messages
639 }
640
641 fn images_ref(&self) -> &[DynamicImage] {
642 &self.images
643 }
644
645 fn take_messages(&mut self) -> RequestMessage {
646 if self.images.is_empty() && self.audios.is_empty() {
647 let mut other = Vec::new();
648 std::mem::swap(&mut other, &mut self.messages);
649 RequestMessage::Chat {
650 messages: other,
651 enable_thinking: self.enable_thinking,
652 }
653 } else {
654 let mut other_messages = Vec::new();
655 std::mem::swap(&mut other_messages, &mut self.messages);
656 let mut other_images = Vec::new();
657 std::mem::swap(&mut other_images, &mut self.images);
658 let mut other_audios = Vec::new();
659 std::mem::swap(&mut other_audios, &mut self.audios);
660 RequestMessage::VisionChat {
661 images: other_images,
662 messages: other_messages,
663 audios: other_audios,
664 enable_thinking: self.enable_thinking,
665 }
666 }
667 }
668
669 fn enable_search(&self) -> Option<bool> {
670 self.enable_thinking
671 }
672
673 fn take_logits_processors(&mut self) -> Option<Vec<Arc<dyn CustomLogitsProcessor>>> {
674 if self.logits_processors.is_empty() {
675 None
676 } else {
677 let mut other = Vec::new();
678 std::mem::swap(&mut other, &mut self.logits_processors);
679 Some(other)
680 }
681 }
682
683 fn take_adapters(&mut self) -> Option<Vec<String>> {
684 if self.adapters.is_empty() {
685 None
686 } else {
687 let mut other = Vec::new();
688 std::mem::swap(&mut other, &mut self.adapters);
689 Some(other)
690 }
691 }
692
693 fn return_logprobs(&self) -> bool {
694 self.return_logprobs
695 }
696
697 fn take_constraint(&mut self) -> Constraint {
698 let mut other = Constraint::None;
699 std::mem::swap(&mut other, &mut self.constraint);
700 other
701 }
702
703 fn take_tools(&mut self) -> Option<(Vec<Tool>, ToolChoice)> {
704 if self.tools.is_empty() {
705 None
706 } else {
707 let mut other_ts = Vec::new();
708 std::mem::swap(&mut other_ts, &mut self.tools);
709 let mut other_tc = ToolChoice::Auto;
710 std::mem::swap(&mut other_tc, &mut self.tool_choice);
711 Some((other_ts, other_tc))
712 }
713 }
714
715 fn take_sampling_params(&mut self) -> SamplingParams {
716 let mut other = SamplingParams::deterministic();
717 std::mem::swap(&mut other, &mut self.sampling_params);
718 other
719 }
720
721 fn take_web_search_options(&mut self) -> Option<WebSearchOptions> {
722 let mut other = None;
723 std::mem::swap(&mut other, &mut self.web_search_options);
724 other
725 }
726}