1use std::{collections::HashMap, ops::Deref};
4
5use either::Either;
6use mistralrs_core::{
7 ImageGenerationResponseFormat, LlguidanceGrammar, Tool, ToolChoice, ToolType, WebSearchOptions,
8};
9use serde::{Deserialize, Serialize};
10use utoipa::{
11 openapi::{schema::SchemaType, ArrayBuilder, ObjectBuilder, OneOfBuilder, RefOr, Schema, Type},
12 PartialSchema, ToSchema,
13};
14
/// Value stored inside one part of a structured (multi-part) message.
///
/// Deserialized untagged via `either::serde_untagged`: a JSON string becomes
/// `Either::Left`, a JSON object with string values becomes `Either::Right`.
#[derive(Debug, Clone, Deserialize, Serialize)]
pub struct MessageInnerContent(
    #[serde(with = "either::serde_untagged")] pub Either<String, HashMap<String, String>>,
);
20
21impl PartialSchema for MessageInnerContent {
24 fn schema() -> RefOr<Schema> {
25 RefOr::T(message_inner_content_schema())
26 }
27}
28
29impl ToSchema for MessageInnerContent {
30 fn schemas(
31 schemas: &mut Vec<(
32 String,
33 utoipa::openapi::RefOr<utoipa::openapi::schema::Schema>,
34 )>,
35 ) {
36 schemas.push((
37 MessageInnerContent::name().into(),
38 MessageInnerContent::schema(),
39 ));
40 }
41}
42
43impl Deref for MessageInnerContent {
44 type Target = Either<String, HashMap<String, String>>;
45 fn deref(&self) -> &Self::Target {
46 &self.0
47 }
48}
49
50fn message_inner_content_schema() -> Schema {
52 Schema::OneOf(
53 OneOfBuilder::new()
54 .item(Schema::Object(
56 ObjectBuilder::new()
57 .schema_type(SchemaType::Type(Type::String))
58 .build(),
59 ))
60 .item(Schema::Object(
62 ObjectBuilder::new()
63 .schema_type(SchemaType::Type(Type::Object))
64 .additional_properties(Some(RefOr::T(Schema::Object(
65 ObjectBuilder::new()
66 .schema_type(SchemaType::Type(Type::String))
67 .build(),
68 ))))
69 .build(),
70 ))
71 .build(),
72 )
73}
74
/// Content of a chat message.
///
/// Deserialized untagged via `either::serde_untagged`: either a plain string
/// (`Either::Left`) or a list of parts (`Either::Right`). Each part is a map from
/// string keys to [`MessageInnerContent`] values.
#[derive(Debug, Clone, Deserialize, Serialize)]
pub struct MessageContent(
    #[serde(with = "either::serde_untagged")]
    Either<String, Vec<HashMap<String, MessageInnerContent>>>,
);
81
82impl PartialSchema for MessageContent {
85 fn schema() -> RefOr<Schema> {
86 RefOr::T(message_content_schema())
87 }
88}
89
90impl ToSchema for MessageContent {
91 fn schemas(
92 schemas: &mut Vec<(
93 String,
94 utoipa::openapi::RefOr<utoipa::openapi::schema::Schema>,
95 )>,
96 ) {
97 schemas.push((MessageContent::name().into(), MessageContent::schema()));
98 }
99}
100
101impl Deref for MessageContent {
102 type Target = Either<String, Vec<HashMap<String, MessageInnerContent>>>;
103 fn deref(&self) -> &Self::Target {
104 &self.0
105 }
106}
107
108fn message_content_schema() -> Schema {
110 Schema::OneOf(
111 OneOfBuilder::new()
112 .item(Schema::Object(
113 ObjectBuilder::new()
114 .schema_type(SchemaType::Type(Type::String))
115 .build(),
116 ))
117 .item(Schema::Array(
118 ArrayBuilder::new()
119 .items(RefOr::T(Schema::Object(
120 ObjectBuilder::new()
121 .schema_type(SchemaType::Type(Type::Object))
122 .additional_properties(Some(RefOr::Ref(
123 utoipa::openapi::Ref::from_schema_name("MessageInnerContent"),
124 )))
125 .build(),
126 )))
127 .build(),
128 ))
129 .build(),
130 )
131}
132
/// A function invocation attached to a [`ToolCall`].
#[derive(Clone, Debug, serde::Deserialize, serde::Serialize, ToSchema)]
pub struct FunctionCalled {
    /// Name of the function being called.
    pub name: String,
    /// Call arguments as a raw string (not parsed here).
    /// Accepted on input as either `parameters` or `arguments`; always serialized as `parameters`.
    #[serde(alias = "arguments")]
    pub parameters: String,
}
145
/// A tool call carried on a [`Message`].
#[derive(Clone, Debug, serde::Deserialize, serde::Serialize, ToSchema)]
pub struct ToolCall {
    /// Tool kind; appears on the wire under the key `type`.
    #[serde(rename = "type")]
    pub tp: ToolType,
    /// The function name and arguments for this call.
    pub function: FunctionCalled,
}
157
/// A single message in a chat conversation.
#[derive(Debug, Clone, Deserialize, Serialize, ToSchema)]
pub struct Message {
    /// Message content; may be absent or `null`.
    pub content: Option<MessageContent>,
    /// Role of the author, as a free-form string (not validated here).
    pub role: String,
    /// Optional author name.
    pub name: Option<String>,
    /// Optional tool calls attached to this message.
    pub tool_calls: Option<Vec<ToolCall>>,
}
192
/// Stop sequence(s) for generation.
///
/// Untagged: a JSON array of strings deserializes as `Multi`, a single JSON string as
/// `Single`. Variant order determines which shape serde tries first.
#[derive(Debug, Clone, Deserialize, Serialize, ToSchema)]
#[serde(untagged)]
pub enum StopTokens {
    /// Several stop sequences.
    Multi(Vec<String>),
    /// One stop sequence.
    Single(String),
}
205
/// Serde default for boolean request flags that are off unless explicitly set.
fn default_false() -> bool {
    bool::default()
}
210
/// Serde default for count fields such as `n`: one.
fn default_1usize() -> usize {
    1_usize
}
215
/// Serde default for `ImageGenerationRequest::height`: 720.
fn default_720usize() -> usize {
    720_usize
}
220
/// Serde default for `ImageGenerationRequest::width`: 1280.
fn default_1280usize() -> usize {
    1_280_usize
}
225
/// Serde default for request `model` fields: the literal string `"default"`.
fn default_model() -> String {
    String::from("default")
}
230
/// Serde default for `ImageGenerationRequest::response_format`: the `Url` variant.
fn default_response_format() -> ImageGenerationResponseFormat {
    ImageGenerationResponseFormat::Url
}
235
/// Grammar supplied with a request.
///
/// Adjacently tagged on the wire: `{"type": "<kind>", "value": <payload>}`, where
/// `<kind>` is one of `regex`, `json_schema`, `llguidance`, or `lark`.
/// The OpenAPI schema for this type is written by hand (see the `PartialSchema` impl),
/// so keep the two in sync when adding variants.
#[derive(Debug, Clone, Deserialize, Serialize)]
#[serde(tag = "type", content = "value")]
pub enum Grammar {
    /// `type: "regex"`; string payload (presumably a regular expression).
    #[serde(rename = "regex")]
    Regex(String),
    /// `type: "json_schema"`; arbitrary JSON payload.
    #[serde(rename = "json_schema")]
    JsonSchema(serde_json::Value),
    /// `type: "llguidance"`; payload is an `LlguidanceGrammar`.
    #[serde(rename = "llguidance")]
    Llguidance(LlguidanceGrammar),
    /// `type: "lark"`; string payload (presumably Lark grammar source).
    #[serde(rename = "lark")]
    Lark(String),
}
284
285impl PartialSchema for Grammar {
287 fn schema() -> RefOr<Schema> {
288 RefOr::T(Schema::OneOf(
289 OneOfBuilder::new()
290 .item(create_grammar_variant_schema(
291 "regex",
292 Schema::Object(
293 ObjectBuilder::new()
294 .schema_type(SchemaType::Type(Type::String))
295 .build(),
296 ),
297 ))
298 .item(create_grammar_variant_schema(
299 "json_schema",
300 Schema::Object(
301 ObjectBuilder::new()
302 .schema_type(SchemaType::Type(Type::Object))
303 .build(),
304 ),
305 ))
306 .item(create_grammar_variant_schema(
307 "llguidance",
308 llguidance_schema(),
309 ))
310 .item(create_grammar_variant_schema(
311 "lark",
312 Schema::Object(
313 ObjectBuilder::new()
314 .schema_type(SchemaType::Type(Type::String))
315 .build(),
316 ),
317 ))
318 .build(),
319 ))
320 }
321}
322
323impl ToSchema for Grammar {
324 fn schemas(
325 schemas: &mut Vec<(
326 String,
327 utoipa::openapi::RefOr<utoipa::openapi::schema::Schema>,
328 )>,
329 ) {
330 schemas.push((Grammar::name().into(), Grammar::schema()));
331 }
332}
333
334fn create_grammar_variant_schema(type_value: &str, value_schema: Schema) -> Schema {
336 Schema::Object(
337 ObjectBuilder::new()
338 .schema_type(SchemaType::Type(Type::Object))
339 .property(
340 "type",
341 RefOr::T(Schema::Object(
342 ObjectBuilder::new()
343 .schema_type(SchemaType::Type(Type::String))
344 .enum_values(Some(vec![serde_json::Value::String(
345 type_value.to_string(),
346 )]))
347 .build(),
348 )),
349 )
350 .property("value", RefOr::T(value_schema))
351 .required("type")
352 .required("value")
353 .build(),
354 )
355}
356
357fn llguidance_schema() -> Schema {
359 let grammar_with_lexer_schema = Schema::Object(
360 ObjectBuilder::new()
361 .schema_type(SchemaType::Type(Type::Object))
362 .property(
363 "name",
364 RefOr::T(Schema::Object(
365 ObjectBuilder::new()
366 .schema_type(SchemaType::from_iter([Type::String, Type::Null]))
367 .description(Some(
368 "The name of this grammar, can be used in GenGrammar nodes",
369 ))
370 .build(),
371 )),
372 )
373 .property(
374 "json_schema",
375 RefOr::T(Schema::Object(
376 ObjectBuilder::new()
377 .schema_type(SchemaType::from_iter([Type::Object, Type::Null]))
378 .description(Some("The JSON schema that the grammar should generate"))
379 .build(),
380 )),
381 )
382 .property(
383 "lark_grammar",
384 RefOr::T(Schema::Object(
385 ObjectBuilder::new()
386 .schema_type(SchemaType::from_iter([Type::String, Type::Null]))
387 .description(Some("The Lark grammar that the grammar should generate"))
388 .build(),
389 )),
390 )
391 .description(Some("Grammar configuration with lexer settings"))
392 .build(),
393 );
394
395 Schema::Object(
396 ObjectBuilder::new()
397 .schema_type(SchemaType::Type(Type::Object))
398 .property(
399 "grammars",
400 RefOr::T(Schema::Array(
401 ArrayBuilder::new()
402 .items(RefOr::T(grammar_with_lexer_schema))
403 .description(Some("List of grammar configurations"))
404 .build(),
405 )),
406 )
407 .property(
408 "max_tokens",
409 RefOr::T(Schema::Object(
410 ObjectBuilder::new()
411 .schema_type(SchemaType::from_iter([Type::Integer, Type::Null]))
412 .description(Some("Maximum number of tokens to generate"))
413 .build(),
414 )),
415 )
416 .required("grammars")
417 .description(Some("Top-level grammar configuration for LLGuidance"))
418 .build(),
419 )
420}
421
/// Payload of the `json_schema` [`ResponseFormat`] variant.
#[derive(Debug, Clone, Deserialize, Serialize, ToSchema)]
pub struct JsonSchemaResponseFormat {
    /// Name for the schema.
    pub name: String,
    /// The schema itself, kept as an arbitrary JSON value.
    pub schema: serde_json::Value,
}
428
/// Requested output format for a chat completion.
///
/// Internally tagged by `type`: `{"type": "text"}` or
/// `{"type": "json_schema", "json_schema": {"name": ..., "schema": ...}}`.
#[derive(Debug, Clone, Deserialize, Serialize, ToSchema)]
#[serde(tag = "type")]
pub enum ResponseFormat {
    /// `{"type": "text"}`.
    #[serde(rename = "text")]
    Text,
    /// `{"type": "json_schema", "json_schema": {...}}`.
    #[serde(rename = "json_schema")]
    JsonSchema {
        json_schema: JsonSchemaResponseFormat,
    },
}
442
443#[derive(Debug, Clone, Deserialize, Serialize, ToSchema)]
445pub struct ChatCompletionRequest {
446 #[schema(
447 schema_with = messages_schema,
448 example = json!(vec![Message{content:Some(MessageContent{0: either::Left(("Why did the crab cross the road?".to_string()))}), role:"user".to_string(), name: None, tool_calls: None}])
449 )]
450 #[serde(with = "either::serde_untagged")]
451 pub messages: Either<Vec<Message>, String>,
452 #[schema(example = "mistral")]
453 #[serde(default = "default_model")]
454 pub model: String,
455 #[schema(example = json!(Option::None::<HashMap<u32, f32>>))]
456 pub logit_bias: Option<HashMap<u32, f32>>,
457 #[serde(default = "default_false")]
458 #[schema(example = false)]
459 pub logprobs: bool,
460 #[schema(example = json!(Option::None::<usize>))]
461 pub top_logprobs: Option<usize>,
462 #[schema(example = 256)]
463 #[serde(alias = "max_completion_tokens")]
464 pub max_tokens: Option<usize>,
465 #[serde(rename = "n")]
466 #[serde(default = "default_1usize")]
467 #[schema(example = 1)]
468 pub n_choices: usize,
469 #[schema(example = json!(Option::None::<f32>))]
470 pub presence_penalty: Option<f32>,
471 #[schema(example = json!(Option::None::<f32>))]
472 pub frequency_penalty: Option<f32>,
473 #[serde(rename = "stop")]
474 #[schema(example = json!(Option::None::<StopTokens>))]
475 pub stop_seqs: Option<StopTokens>,
476 #[schema(example = 0.7)]
477 pub temperature: Option<f64>,
478 #[schema(example = json!(Option::None::<f64>))]
479 pub top_p: Option<f64>,
480 #[schema(example = true)]
481 pub stream: Option<bool>,
482 #[schema(example = json!(Option::None::<Vec<Tool>>))]
483 pub tools: Option<Vec<Tool>>,
484 #[schema(example = json!(Option::None::<ToolChoice>))]
485 pub tool_choice: Option<ToolChoice>,
486 #[schema(example = json!(Option::None::<ResponseFormat>))]
487 pub response_format: Option<ResponseFormat>,
488 #[schema(example = json!(Option::None::<WebSearchOptions>))]
489 pub web_search_options: Option<WebSearchOptions>,
490
491 #[schema(example = json!(Option::None::<usize>))]
493 pub top_k: Option<usize>,
494 #[schema(example = json!(Option::None::<Grammar>))]
495 pub grammar: Option<Grammar>,
496 #[schema(example = json!(Option::None::<f64>))]
497 pub min_p: Option<f64>,
498 #[schema(example = json!(Option::None::<f32>))]
499 pub dry_multiplier: Option<f32>,
500 #[schema(example = json!(Option::None::<f32>))]
501 pub dry_base: Option<f32>,
502 #[schema(example = json!(Option::None::<usize>))]
503 pub dry_allowed_length: Option<usize>,
504 #[schema(example = json!(Option::None::<String>))]
505 pub dry_sequence_breakers: Option<Vec<String>>,
506 #[schema(example = json!(Option::None::<bool>))]
507 pub enable_thinking: Option<bool>,
508}
509
510fn messages_schema() -> Schema {
512 Schema::OneOf(
513 OneOfBuilder::new()
514 .item(Schema::Array(
515 ArrayBuilder::new()
516 .items(RefOr::Ref(utoipa::openapi::Ref::from_schema_name(
517 "Message",
518 )))
519 .build(),
520 ))
521 .item(Schema::Object(
522 ObjectBuilder::new()
523 .schema_type(SchemaType::Type(Type::String))
524 .build(),
525 ))
526 .build(),
527 )
528}
529
/// One entry in [`ModelObjects::data`]. Serialize-only (no `Deserialize` derive).
#[derive(Debug, Serialize, ToSchema)]
pub struct ModelObject {
    /// Model identifier.
    pub id: String,
    /// Object-kind label; a compile-time string constant.
    pub object: &'static str,
    /// Creation time. NOTE(review): presumably a Unix timestamp in seconds — confirm at construction site.
    pub created: u64,
    /// Owner label; a compile-time string constant.
    pub owned_by: &'static str,
}
538
/// A list of [`ModelObject`] entries. Serialize-only (no `Deserialize` derive).
#[derive(Debug, Serialize, ToSchema)]
pub struct ModelObjects {
    /// Object-kind label for the list; a compile-time string constant.
    pub object: &'static str,
    /// The listed models.
    pub data: Vec<ModelObject>,
}
545
546#[derive(Debug, Clone, Deserialize, Serialize, ToSchema)]
548pub struct CompletionRequest {
549 #[schema(example = "mistral")]
550 #[serde(default = "default_model")]
551 pub model: String,
552 #[schema(example = "Say this is a test.")]
553 pub prompt: String,
554 #[schema(example = 1)]
555 pub best_of: Option<usize>,
556 #[serde(rename = "echo")]
557 #[serde(default = "default_false")]
558 #[schema(example = false)]
559 pub echo_prompt: bool,
560 #[schema(example = json!(Option::None::<f32>))]
561 pub presence_penalty: Option<f32>,
562 #[schema(example = json!(Option::None::<f32>))]
563 pub frequency_penalty: Option<f32>,
564 #[schema(example = json!(Option::None::<HashMap<u32, f32>>))]
565 pub logit_bias: Option<HashMap<u32, f32>>,
566 #[schema(example = json!(Option::None::<usize>))]
567 pub logprobs: Option<usize>,
568 #[schema(example = 16)]
569 pub max_tokens: Option<usize>,
570 #[serde(rename = "n")]
571 #[serde(default = "default_1usize")]
572 #[schema(example = 1)]
573 pub n_choices: usize,
574 #[serde(rename = "stop")]
575 #[schema(example = json!(Option::None::<StopTokens>))]
576 pub stop_seqs: Option<StopTokens>,
577 pub stream: Option<bool>,
578 #[schema(example = 0.7)]
579 pub temperature: Option<f64>,
580 #[schema(example = json!(Option::None::<f64>))]
581 pub top_p: Option<f64>,
582 #[schema(example = json!(Option::None::<String>))]
583 pub suffix: Option<String>,
584 #[serde(rename = "user")]
585 pub _user: Option<String>,
586 #[schema(example = json!(Option::None::<Vec<Tool>>))]
587 pub tools: Option<Vec<Tool>>,
588 #[schema(example = json!(Option::None::<ToolChoice>))]
589 pub tool_choice: Option<ToolChoice>,
590
591 #[schema(example = json!(Option::None::<usize>))]
593 pub top_k: Option<usize>,
594 #[schema(example = json!(Option::None::<Grammar>))]
595 pub grammar: Option<Grammar>,
596 #[schema(example = json!(Option::None::<f64>))]
597 pub min_p: Option<f64>,
598 #[schema(example = json!(Option::None::<f32>))]
599 pub dry_multiplier: Option<f32>,
600 #[schema(example = json!(Option::None::<f32>))]
601 pub dry_base: Option<f32>,
602 #[schema(example = json!(Option::None::<usize>))]
603 pub dry_allowed_length: Option<usize>,
604 #[schema(example = json!(Option::None::<String>))]
605 pub dry_sequence_breakers: Option<Vec<String>>,
606}
607
/// Request body for image generation.
#[derive(Debug, Clone, Deserialize, Serialize, ToSchema)]
pub struct ImageGenerationRequest {
    /// Model identifier; `"default"` when omitted.
    #[schema(example = "mistral")]
    #[serde(default = "default_model")]
    pub model: String,
    /// Text prompt describing the image.
    #[schema(example = "Draw a picture of a majestic, snow-covered mountain.")]
    pub prompt: String,
    /// Appears on the wire as `n`; `1` when omitted.
    #[serde(rename = "n")]
    #[serde(default = "default_1usize")]
    #[schema(example = 1)]
    pub n_choices: usize,
    /// `ImageGenerationResponseFormat::Url` when omitted.
    #[serde(default = "default_response_format")]
    pub response_format: ImageGenerationResponseFormat,
    /// `720` when omitted. NOTE(review): presumably pixels — confirm.
    #[serde(default = "default_720usize")]
    #[schema(example = 720)]
    pub height: usize,
    /// `1280` when omitted. NOTE(review): presumably pixels — confirm.
    #[serde(default = "default_1280usize")]
    #[schema(example = 1280)]
    pub width: usize,
}
629
/// Output audio format for speech generation.
///
/// Serialized in lowercase (`"mp3"`, `"opus"`, `"aac"`, `"flac"`, `"wav"`, `"pcm"`).
/// `Mp3` is the `Default`.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize, Default, ToSchema)]
#[serde(rename_all = "lowercase")]
pub enum AudioResponseFormat {
    /// MPEG audio (`audio/mpeg`); the default.
    #[default]
    Mp3,
    /// Opus in an Ogg container (`audio/ogg; codecs=opus`).
    Opus,
    /// AAC (`audio/aac`).
    Aac,
    /// FLAC (`audio/flac`).
    Flac,
    /// WAV (`audio/wav`).
    Wav,
    /// Raw PCM (`audio/pcm`).
    Pcm,
}
648
649impl AudioResponseFormat {
650 pub fn audio_content_type(
652 &self,
653 pcm_rate: usize,
654 pcm_channels: usize,
655 pcm_format: &'static str,
656 ) -> String {
657 let content_type = match &self {
658 AudioResponseFormat::Mp3 => "audio/mpeg".to_string(),
659 AudioResponseFormat::Opus => "audio/ogg; codecs=opus".to_string(),
660 AudioResponseFormat::Aac => "audio/aac".to_string(),
661 AudioResponseFormat::Flac => "audio/flac".to_string(),
662 AudioResponseFormat::Wav => "audio/wav".to_string(),
663 AudioResponseFormat::Pcm => format!("audio/pcm; codecs=1; format={pcm_format}"),
664 };
665
666 format!("{content_type}; rate={pcm_rate}; channels={pcm_channels}")
667 }
668}
669
670#[derive(Debug, Clone, Deserialize, Serialize, ToSchema)]
672pub struct SpeechGenerationRequest {
673 #[schema(example = "nari-labs/Dia-1.6B")]
675 #[serde(default = "default_model")]
676 pub model: String,
677 #[schema(
679 example = "[S1] Dia is an open weights text to dialogue model. [S2] You get full control over scripts and voices. [S1] Wow. Amazing. (laughs) [S2] Try it now on Git hub or Hugging Face."
680 )]
681 pub input: String,
682 #[schema(example = "mp3")]
685 pub response_format: AudioResponseFormat,
686}