1use std::{collections::HashMap, fmt::Display, sync::Arc};
2
3use super::*;
4use either::Either;
5use image::DynamicImage;
6use indexmap::IndexMap;
7use serde_json::{json, Value};
8
9pub trait RequestLike {
11 fn messages_ref(&self) -> &[IndexMap<String, MessageContent>];
12 fn take_messages(&mut self) -> RequestMessage;
13 fn take_logits_processors(&mut self) -> Option<Vec<Arc<dyn CustomLogitsProcessor>>>;
14 fn take_adapters(&mut self) -> Option<Vec<String>>;
15 fn return_logprobs(&self) -> bool;
16 fn take_constraint(&mut self) -> Constraint;
17 fn take_tools(&mut self) -> Option<(Vec<Tool>, ToolChoice)>;
18 fn take_sampling_params(&mut self) -> SamplingParams;
19 fn take_web_search_options(&mut self) -> Option<WebSearchOptions>;
20}
21
22#[derive(Debug, Clone, PartialEq)]
23pub struct TextMessages(Vec<IndexMap<String, MessageContent>>);
29
30impl From<TextMessages> for Vec<IndexMap<String, MessageContent>> {
31 fn from(value: TextMessages) -> Self {
32 value.0
33 }
34}
35
/// The role attached to a chat message.
///
/// The derives make a role printable with `{:?}`, clonable and comparable, so it can be
/// reused across several messages and checked in assertions.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum TextMessageRole {
    /// The `user` role.
    User,
    /// The `assistant` role.
    Assistant,
    /// The `system` role.
    System,
    /// The `tool` role.
    Tool,
    /// Any other role; the wrapped string is used verbatim as the role name.
    Custom(String),
}
44
45impl Display for TextMessageRole {
46 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
47 match self {
48 Self::User => write!(f, "user"),
49 Self::Assistant => write!(f, "assistant"),
50 Self::System => write!(f, "system"),
51 Self::Tool => write!(f, "tool"),
52 Self::Custom(c) => write!(f, "{c}"),
53 }
54 }
55}
56
57impl Default for TextMessages {
58 fn default() -> Self {
59 Self::new()
60 }
61}
62
63impl TextMessages {
64 pub fn new() -> Self {
65 Self(Vec::new())
66 }
67
68 pub fn add_message(mut self, role: TextMessageRole, text: impl ToString) -> Self {
69 self.0.push(IndexMap::from([
70 ("role".to_string(), Either::Left(role.to_string())),
71 ("content".to_string(), Either::Left(text.to_string())),
72 ]));
73 self
74 }
75
76 pub fn clear(mut self) -> Self {
77 self.0.clear();
78 self
79 }
80}
81
82impl RequestLike for TextMessages {
83 fn messages_ref(&self) -> &[IndexMap<String, MessageContent>] {
84 &self.0
85 }
86 fn take_messages(&mut self) -> RequestMessage {
87 let mut other = Vec::new();
88 std::mem::swap(&mut other, &mut self.0);
89 RequestMessage::Chat(other)
90 }
91 fn take_logits_processors(&mut self) -> Option<Vec<Arc<dyn CustomLogitsProcessor>>> {
92 None
93 }
94 fn take_adapters(&mut self) -> Option<Vec<String>> {
95 None
96 }
97 fn return_logprobs(&self) -> bool {
98 false
99 }
100 fn take_constraint(&mut self) -> Constraint {
101 Constraint::None
102 }
103 fn take_tools(&mut self) -> Option<(Vec<Tool>, ToolChoice)> {
104 None
105 }
106 fn take_sampling_params(&mut self) -> SamplingParams {
107 SamplingParams::deterministic()
108 }
109 fn take_web_search_options(&mut self) -> Option<WebSearchOptions> {
110 None
111 }
112}
113
114#[derive(Debug, Clone, PartialEq)]
115pub struct VisionMessages {
121 messages: Vec<IndexMap<String, MessageContent>>,
122 images: Vec<DynamicImage>,
123}
124
125impl Default for VisionMessages {
126 fn default() -> Self {
127 Self::new()
128 }
129}
130
131impl VisionMessages {
132 pub fn new() -> Self {
133 Self {
134 images: Vec::new(),
135 messages: Vec::new(),
136 }
137 }
138
139 pub fn add_message(mut self, role: TextMessageRole, text: impl ToString) -> Self {
140 self.messages.push(IndexMap::from([
141 ("role".to_string(), Either::Left(role.to_string())),
142 ("content".to_string(), Either::Left(text.to_string())),
143 ]));
144 self
145 }
146
147 pub fn add_image_message(
148 mut self,
149 role: TextMessageRole,
150 text: impl ToString,
151 image: DynamicImage,
152 model: &Model,
153 ) -> anyhow::Result<Self> {
154 let prefixer = match &model.config().category {
155 ModelCategory::Text | ModelCategory::Diffusion => {
156 anyhow::bail!("`add_image_message` expects a vision model.")
157 }
158 ModelCategory::Vision {
159 has_conv2d: _,
160 prefixer,
161 } => prefixer,
162 };
163 self.images.push(image);
164 self.messages.push(IndexMap::from([
165 ("role".to_string(), Either::Left(role.to_string())),
166 (
167 "content".to_string(),
168 Either::Right(vec![
169 IndexMap::from([("type".to_string(), Value::String("image".to_string()))]),
170 IndexMap::from([
171 ("type".to_string(), Value::String("text".to_string())),
172 (
173 "text".to_string(),
174 Value::String(
175 prefixer.prefix_image(self.images.len() - 1, &text.to_string()),
176 ),
177 ),
178 ]),
179 ]),
180 ),
181 ]));
182 Ok(self)
183 }
184
185 pub fn clear(mut self) -> Self {
186 self.messages.clear();
187 self.images.clear();
188
189 self
190 }
191}
192
193impl RequestLike for VisionMessages {
194 fn messages_ref(&self) -> &[IndexMap<String, MessageContent>] {
195 &self.messages
196 }
197 fn take_messages(&mut self) -> RequestMessage {
198 let mut other_messages = Vec::new();
199 std::mem::swap(&mut other_messages, &mut self.messages);
200 let mut other_images = Vec::new();
201 std::mem::swap(&mut other_images, &mut self.images);
202 RequestMessage::VisionChat {
203 images: other_images,
204 messages: other_messages,
205 }
206 }
207 fn take_logits_processors(&mut self) -> Option<Vec<Arc<dyn CustomLogitsProcessor>>> {
208 None
209 }
210 fn take_adapters(&mut self) -> Option<Vec<String>> {
211 None
212 }
213 fn return_logprobs(&self) -> bool {
214 false
215 }
216 fn take_constraint(&mut self) -> Constraint {
217 Constraint::None
218 }
219 fn take_tools(&mut self) -> Option<(Vec<Tool>, ToolChoice)> {
220 None
221 }
222 fn take_sampling_params(&mut self) -> SamplingParams {
223 SamplingParams::deterministic()
224 }
225 fn take_web_search_options(&mut self) -> Option<WebSearchOptions> {
226 None
227 }
228}
229
230#[derive(Clone)]
231pub struct RequestBuilder {
240 messages: Vec<IndexMap<String, MessageContent>>,
241 images: Vec<DynamicImage>,
242 logits_processors: Vec<Arc<dyn CustomLogitsProcessor>>,
243 adapters: Vec<String>,
244 return_logprobs: bool,
245 constraint: Constraint,
246 tools: Vec<Tool>,
247 tool_choice: ToolChoice,
248 sampling_params: SamplingParams,
249 web_search_options: Option<WebSearchOptions>,
250}
251
252impl Default for RequestBuilder {
253 fn default() -> Self {
254 Self::new()
255 }
256}
257
258impl From<TextMessages> for RequestBuilder {
259 fn from(value: TextMessages) -> Self {
260 Self {
261 messages: value.0,
262 images: Vec::new(),
263 logits_processors: Vec::new(),
264 adapters: Vec::new(),
265 return_logprobs: false,
266 constraint: Constraint::None,
267 tools: Vec::new(),
268 tool_choice: ToolChoice::Auto,
269 sampling_params: SamplingParams::deterministic(),
270 web_search_options: None,
271 }
272 }
273}
274
275impl From<VisionMessages> for RequestBuilder {
276 fn from(value: VisionMessages) -> Self {
277 Self {
278 messages: value.messages,
279 images: value.images,
280 logits_processors: Vec::new(),
281 adapters: Vec::new(),
282 return_logprobs: false,
283 constraint: Constraint::None,
284 tools: Vec::new(),
285 tool_choice: ToolChoice::Auto,
286 sampling_params: SamplingParams::deterministic(),
287 web_search_options: None,
288 }
289 }
290}
291
292impl RequestBuilder {
293 pub fn new() -> Self {
294 Self {
295 messages: Vec::new(),
296 images: Vec::new(),
297 logits_processors: Vec::new(),
298 adapters: Vec::new(),
299 return_logprobs: false,
300 constraint: Constraint::None,
301 tools: Vec::new(),
302 tool_choice: ToolChoice::Auto,
303 sampling_params: SamplingParams::deterministic(),
304 web_search_options: None,
305 }
306 }
307
308 pub fn with_web_search_options(mut self, web_search_options: WebSearchOptions) -> Self {
309 self.web_search_options = Some(web_search_options);
310 self
311 }
312
313 pub fn add_message(mut self, role: TextMessageRole, text: impl ToString) -> Self {
318 self.messages.push(IndexMap::from([
319 ("role".to_string(), Either::Left(role.to_string())),
320 ("content".to_string(), Either::Left(text.to_string())),
321 ]));
322 self
323 }
324
325 pub fn add_tool_message(mut self, tool_content: impl ToString, tool_id: impl ToString) -> Self {
327 self.messages.push(IndexMap::from([
328 (
329 "role".to_string(),
330 Either::Left(TextMessageRole::Tool.to_string()),
331 ),
332 (
333 "content".to_string(),
334 Either::Left(tool_content.to_string()),
335 ),
336 (
337 "tool_call_id".to_string(),
338 Either::Left(tool_id.to_string()),
339 ),
340 ]));
341 self
342 }
343
344 pub fn add_message_with_tool_call(
345 mut self,
346 role: TextMessageRole,
347 text: impl ToString,
348 tool_calls: Vec<ToolCallResponse>,
349 ) -> Self {
350 let tool_messages = tool_calls
351 .iter()
352 .map(|t| {
353 IndexMap::from([
354 ("id".to_string(), Value::String(t.id.clone())),
355 ("type".to_string(), Value::String(t.tp.to_string())),
356 (
357 "function".to_string(),
358 json!({
359 "name": t.function.name,
360 "arguments": t.function.arguments,
361 }),
362 ),
363 ])
364 })
365 .collect();
366 self.messages.push(IndexMap::from([
367 ("role".to_string(), Either::Left(role.to_string())),
368 ("content".to_string(), Either::Left(text.to_string())),
369 ("function".to_string(), Either::Right(tool_messages)),
370 ]));
371 self
372 }
373
374 pub fn add_image_message(
375 mut self,
376 role: TextMessageRole,
377 text: impl ToString,
378 image: DynamicImage,
379 ) -> Self {
380 self.messages.push(IndexMap::from([
381 ("role".to_string(), Either::Left(role.to_string())),
382 ("content".to_string(), Either::Left(text.to_string())),
383 ]));
384 self.images.push(image);
385 self
386 }
387
388 pub fn add_logits_processor(mut self, processor: Arc<dyn CustomLogitsProcessor>) -> Self {
389 self.logits_processors.push(processor);
390 self
391 }
392
393 pub fn set_adapters(mut self, adapters: Vec<String>) -> Self {
394 self.adapters = adapters;
395 self
396 }
397
398 pub fn set_tools(mut self, tools: Vec<Tool>) -> Self {
400 self.tools = tools;
401 self
402 }
403
404 pub fn set_tool_choice(mut self, tool_choice: ToolChoice) -> Self {
405 self.tool_choice = tool_choice;
406 self
407 }
408
409 pub fn return_logprobs(mut self, return_logprobs: bool) -> Self {
410 self.return_logprobs = return_logprobs;
411 self
412 }
413
414 pub fn set_constraint(mut self, constraint: Constraint) -> Self {
415 self.constraint = constraint;
416 self
417 }
418
419 pub fn set_sampling(mut self, params: SamplingParams) -> Self {
421 self.sampling_params = params;
422 self
423 }
424
425 pub fn set_deterministic_sampler(mut self) -> Self {
431 self.sampling_params = SamplingParams::deterministic();
432 self
433 }
434
435 pub fn set_sampler_temperature(mut self, temperature: f64) -> Self {
436 self.sampling_params.temperature = Some(temperature);
437 self
438 }
439
440 pub fn set_sampler_topk(mut self, topk: usize) -> Self {
441 self.sampling_params.top_k = Some(topk);
442 self
443 }
444
445 pub fn set_sampler_topp(mut self, topp: f64) -> Self {
446 self.sampling_params.top_p = Some(topp);
447 self
448 }
449
450 pub fn set_sampler_minp(mut self, minp: f64) -> Self {
451 self.sampling_params.min_p = Some(minp);
452 self
453 }
454
455 pub fn set_sampler_topn_logprobs(mut self, top_n_logprobs: usize) -> Self {
456 self.sampling_params.top_n_logprobs = top_n_logprobs;
457 self
458 }
459
460 pub fn set_sampler_frequency_penalty(mut self, frequency_penalty: f32) -> Self {
461 self.sampling_params.frequency_penalty = Some(frequency_penalty);
462 self
463 }
464
465 pub fn set_sampler_presence_penalty(mut self, presence_penalty: f32) -> Self {
466 self.sampling_params.presence_penalty = Some(presence_penalty);
467 self
468 }
469
470 pub fn set_sampler_stop_toks(mut self, stop_toks: StopTokens) -> Self {
471 self.sampling_params.stop_toks = Some(stop_toks);
472 self
473 }
474
475 pub fn set_sampler_max_len(mut self, max_len: usize) -> Self {
476 self.sampling_params.max_len = Some(max_len);
477 self
478 }
479
480 pub fn set_sampler_logits_bias(mut self, logits_bias: HashMap<u32, f32>) -> Self {
481 self.sampling_params.logits_bias = Some(logits_bias);
482 self
483 }
484
485 pub fn set_sampler_n_choices(mut self, n_choices: usize) -> Self {
486 self.sampling_params.n_choices = n_choices;
487 self
488 }
489
490 pub fn set_sampler_dry_params(mut self, dry_params: DrySamplingParams) -> Self {
491 self.sampling_params.dry_params = Some(dry_params);
492 self
493 }
494}
495
496impl RequestLike for RequestBuilder {
497 fn messages_ref(&self) -> &[IndexMap<String, MessageContent>] {
498 &self.messages
499 }
500
501 fn take_messages(&mut self) -> RequestMessage {
502 if self.images.is_empty() {
503 let mut other = Vec::new();
504 std::mem::swap(&mut other, &mut self.messages);
505 RequestMessage::Chat(other)
506 } else {
507 let mut other_messages = Vec::new();
508 std::mem::swap(&mut other_messages, &mut self.messages);
509 let mut other_images = Vec::new();
510 std::mem::swap(&mut other_images, &mut self.images);
511 RequestMessage::VisionChat {
512 images: other_images,
513 messages: other_messages,
514 }
515 }
516 }
517
518 fn take_logits_processors(&mut self) -> Option<Vec<Arc<dyn CustomLogitsProcessor>>> {
519 if self.logits_processors.is_empty() {
520 None
521 } else {
522 let mut other = Vec::new();
523 std::mem::swap(&mut other, &mut self.logits_processors);
524 Some(other)
525 }
526 }
527
528 fn take_adapters(&mut self) -> Option<Vec<String>> {
529 if self.adapters.is_empty() {
530 None
531 } else {
532 let mut other = Vec::new();
533 std::mem::swap(&mut other, &mut self.adapters);
534 Some(other)
535 }
536 }
537
538 fn return_logprobs(&self) -> bool {
539 self.return_logprobs
540 }
541
542 fn take_constraint(&mut self) -> Constraint {
543 let mut other = Constraint::None;
544 std::mem::swap(&mut other, &mut self.constraint);
545 other
546 }
547
548 fn take_tools(&mut self) -> Option<(Vec<Tool>, ToolChoice)> {
549 if self.tools.is_empty() {
550 None
551 } else {
552 let mut other_ts = Vec::new();
553 std::mem::swap(&mut other_ts, &mut self.tools);
554 let mut other_tc = ToolChoice::Auto;
555 std::mem::swap(&mut other_tc, &mut self.tool_choice);
556 Some((other_ts, other_tc))
557 }
558 }
559
560 fn take_sampling_params(&mut self) -> SamplingParams {
561 let mut other = SamplingParams::deterministic();
562 std::mem::swap(&mut other, &mut self.sampling_params);
563 other
564 }
565
566 fn take_web_search_options(&mut self) -> Option<WebSearchOptions> {
567 let mut other = None;
568 std::mem::swap(&mut other, &mut self.web_search_options);
569 other
570 }
571}