1use std::{collections::HashMap, ops::Deref};
4
5use either::Either;
6use mistralrs_core::{
7 ImageGenerationResponseFormat, LlguidanceGrammar, Tool, ToolChoice, ToolType, WebSearchOptions,
8};
9use serde::{Deserialize, Serialize};
10use utoipa::{
11 openapi::{schema::SchemaType, ArrayBuilder, ObjectBuilder, OneOfBuilder, RefOr, Schema, Type},
12 PartialSchema, ToSchema,
13};
14
15#[derive(Debug, Clone, Deserialize, Serialize)]
17pub struct MessageInnerContent(
18 #[serde(with = "either::serde_untagged")] pub Either<String, HashMap<String, String>>,
19);
20
21impl PartialSchema for MessageInnerContent {
24 fn schema() -> RefOr<Schema> {
25 RefOr::T(message_inner_content_schema())
26 }
27}
28
29impl ToSchema for MessageInnerContent {
30 fn schemas(
31 schemas: &mut Vec<(
32 String,
33 utoipa::openapi::RefOr<utoipa::openapi::schema::Schema>,
34 )>,
35 ) {
36 schemas.push((
37 MessageInnerContent::name().into(),
38 MessageInnerContent::schema(),
39 ));
40 }
41}
42
43impl Deref for MessageInnerContent {
44 type Target = Either<String, HashMap<String, String>>;
45 fn deref(&self) -> &Self::Target {
46 &self.0
47 }
48}
49
50fn message_inner_content_schema() -> Schema {
52 Schema::OneOf(
53 OneOfBuilder::new()
54 .item(Schema::Object(
56 ObjectBuilder::new()
57 .schema_type(SchemaType::Type(Type::String))
58 .build(),
59 ))
60 .item(Schema::Object(
62 ObjectBuilder::new()
63 .schema_type(SchemaType::Type(Type::Object))
64 .additional_properties(Some(RefOr::T(Schema::Object(
65 ObjectBuilder::new()
66 .schema_type(SchemaType::Type(Type::String))
67 .build(),
68 ))))
69 .build(),
70 ))
71 .build(),
72 )
73}
74
75#[derive(Debug, Clone, Deserialize, Serialize)]
77pub struct MessageContent(
78 #[serde(with = "either::serde_untagged")]
79 Either<String, Vec<HashMap<String, MessageInnerContent>>>,
80);
81
82impl PartialSchema for MessageContent {
85 fn schema() -> RefOr<Schema> {
86 RefOr::T(message_content_schema())
87 }
88}
89
90impl ToSchema for MessageContent {
91 fn schemas(
92 schemas: &mut Vec<(
93 String,
94 utoipa::openapi::RefOr<utoipa::openapi::schema::Schema>,
95 )>,
96 ) {
97 schemas.push((MessageContent::name().into(), MessageContent::schema()));
98 }
99}
100
101impl Deref for MessageContent {
102 type Target = Either<String, Vec<HashMap<String, MessageInnerContent>>>;
103 fn deref(&self) -> &Self::Target {
104 &self.0
105 }
106}
107
108fn message_content_schema() -> Schema {
110 Schema::OneOf(
111 OneOfBuilder::new()
112 .item(Schema::Object(
113 ObjectBuilder::new()
114 .schema_type(SchemaType::Type(Type::String))
115 .build(),
116 ))
117 .item(Schema::Array(
118 ArrayBuilder::new()
119 .items(RefOr::T(Schema::Object(
120 ObjectBuilder::new()
121 .schema_type(SchemaType::Type(Type::Object))
122 .additional_properties(Some(RefOr::Ref(
123 utoipa::openapi::Ref::from_schema_name("MessageInnerContent"),
124 )))
125 .build(),
126 )))
127 .build(),
128 ))
129 .build(),
130 )
131}
132
133#[derive(Clone, Debug, serde::Deserialize, serde::Serialize, ToSchema)]
138pub struct FunctionCalled {
139 pub name: String,
141 #[serde(alias = "arguments")]
143 pub parameters: String,
144}
145
146#[derive(Clone, Debug, serde::Deserialize, serde::Serialize, ToSchema)]
150pub struct ToolCall {
151 #[serde(rename = "type")]
153 pub tp: ToolType,
154 pub function: FunctionCalled,
156}
157
158#[derive(Debug, Clone, Deserialize, Serialize, ToSchema)]
183pub struct Message {
184 pub content: Option<MessageContent>,
186 pub role: String,
188 pub name: Option<String>,
189 pub tool_calls: Option<Vec<ToolCall>>,
191}
192
193#[derive(Debug, Clone, Deserialize, Serialize, ToSchema)]
198#[serde(untagged)]
199pub enum StopTokens {
200 Multi(Vec<String>),
202 Single(String),
204}
205
/// Serde default for boolean request flags that are off unless sent.
fn default_false() -> bool {
    bool::default()
}
210
/// Serde default for the `n` (choice count) fields: one.
fn default_1usize() -> usize {
    const ONE_CHOICE: usize = 1;
    ONE_CHOICE
}
215
/// Serde default for `ImageGenerationRequest::height`.
fn default_720usize() -> usize {
    const DEFAULT_HEIGHT: usize = 720;
    DEFAULT_HEIGHT
}
220
/// Serde default for `ImageGenerationRequest::width`.
fn default_1280usize() -> usize {
    const DEFAULT_WIDTH: usize = 1280;
    DEFAULT_WIDTH
}
225
/// Serde default for every request's `model` field: the literal `"default"`.
fn default_model() -> String {
    String::from("default")
}
230
231fn default_response_format() -> ImageGenerationResponseFormat {
233 ImageGenerationResponseFormat::Url
234}
235
236#[derive(Debug, Clone, Deserialize, Serialize)]
269#[serde(tag = "type", content = "value")]
270pub enum Grammar {
271 #[serde(rename = "regex")]
273 Regex(String),
274 #[serde(rename = "json_schema")]
276 JsonSchema(serde_json::Value),
277 #[serde(rename = "llguidance")]
279 Llguidance(LlguidanceGrammar),
280 #[serde(rename = "lark")]
282 Lark(String),
283}
284
285impl PartialSchema for Grammar {
287 fn schema() -> RefOr<Schema> {
288 RefOr::T(Schema::OneOf(
289 OneOfBuilder::new()
290 .item(create_grammar_variant_schema(
291 "regex",
292 Schema::Object(
293 ObjectBuilder::new()
294 .schema_type(SchemaType::Type(Type::String))
295 .build(),
296 ),
297 ))
298 .item(create_grammar_variant_schema(
299 "json_schema",
300 Schema::Object(
301 ObjectBuilder::new()
302 .schema_type(SchemaType::Type(Type::Object))
303 .build(),
304 ),
305 ))
306 .item(create_grammar_variant_schema(
307 "llguidance",
308 llguidance_schema(),
309 ))
310 .item(create_grammar_variant_schema(
311 "lark",
312 Schema::Object(
313 ObjectBuilder::new()
314 .schema_type(SchemaType::Type(Type::String))
315 .build(),
316 ),
317 ))
318 .build(),
319 ))
320 }
321}
322
323impl ToSchema for Grammar {
324 fn schemas(
325 schemas: &mut Vec<(
326 String,
327 utoipa::openapi::RefOr<utoipa::openapi::schema::Schema>,
328 )>,
329 ) {
330 schemas.push((Grammar::name().into(), Grammar::schema()));
331 }
332}
333
334fn create_grammar_variant_schema(type_value: &str, value_schema: Schema) -> Schema {
336 Schema::Object(
337 ObjectBuilder::new()
338 .schema_type(SchemaType::Type(Type::Object))
339 .property(
340 "type",
341 RefOr::T(Schema::Object(
342 ObjectBuilder::new()
343 .schema_type(SchemaType::Type(Type::String))
344 .enum_values(Some(vec![serde_json::Value::String(
345 type_value.to_string(),
346 )]))
347 .build(),
348 )),
349 )
350 .property("value", RefOr::T(value_schema))
351 .required("type")
352 .required("value")
353 .build(),
354 )
355}
356
357fn llguidance_schema() -> Schema {
359 let grammar_with_lexer_schema = Schema::Object(
360 ObjectBuilder::new()
361 .schema_type(SchemaType::Type(Type::Object))
362 .property(
363 "name",
364 RefOr::T(Schema::Object(
365 ObjectBuilder::new()
366 .schema_type(SchemaType::from_iter([Type::String, Type::Null]))
367 .description(Some(
368 "The name of this grammar, can be used in GenGrammar nodes",
369 ))
370 .build(),
371 )),
372 )
373 .property(
374 "json_schema",
375 RefOr::T(Schema::Object(
376 ObjectBuilder::new()
377 .schema_type(SchemaType::from_iter([Type::Object, Type::Null]))
378 .description(Some("The JSON schema that the grammar should generate"))
379 .build(),
380 )),
381 )
382 .property(
383 "lark_grammar",
384 RefOr::T(Schema::Object(
385 ObjectBuilder::new()
386 .schema_type(SchemaType::from_iter([Type::String, Type::Null]))
387 .description(Some("The Lark grammar that the grammar should generate"))
388 .build(),
389 )),
390 )
391 .description(Some("Grammar configuration with lexer settings"))
392 .build(),
393 );
394
395 Schema::Object(
396 ObjectBuilder::new()
397 .schema_type(SchemaType::Type(Type::Object))
398 .property(
399 "grammars",
400 RefOr::T(Schema::Array(
401 ArrayBuilder::new()
402 .items(RefOr::T(grammar_with_lexer_schema))
403 .description(Some("List of grammar configurations"))
404 .build(),
405 )),
406 )
407 .property(
408 "max_tokens",
409 RefOr::T(Schema::Object(
410 ObjectBuilder::new()
411 .schema_type(SchemaType::from_iter([Type::Integer, Type::Null]))
412 .description(Some("Maximum number of tokens to generate"))
413 .build(),
414 )),
415 )
416 .required("grammars")
417 .description(Some("Top-level grammar configuration for LLGuidance"))
418 .build(),
419 )
420}
421
422#[derive(Debug, Clone, Deserialize, Serialize, ToSchema)]
424pub struct JsonSchemaResponseFormat {
425 pub name: String,
426 pub schema: serde_json::Value,
427}
428
429#[derive(Debug, Clone, Deserialize, Serialize, ToSchema)]
431#[serde(tag = "type")]
432pub enum ResponseFormat {
433 #[serde(rename = "text")]
435 Text,
436 #[serde(rename = "json_schema")]
438 JsonSchema {
439 json_schema: JsonSchemaResponseFormat,
440 },
441}
442
443#[derive(Debug, Clone, Deserialize, Serialize, ToSchema)]
445pub struct ChatCompletionRequest {
446 #[schema(
447 schema_with = messages_schema,
448 example = json!(vec![Message{content:Some(MessageContent{0: either::Left(("Why did the crab cross the road?".to_string()))}), role:"user".to_string(), name: None, tool_calls: None}])
449 )]
450 #[serde(with = "either::serde_untagged")]
451 pub messages: Either<Vec<Message>, String>,
452 #[schema(example = "mistral")]
453 #[serde(default = "default_model")]
454 pub model: String,
455 #[schema(example = json!(Option::None::<HashMap<u32, f32>>))]
456 pub logit_bias: Option<HashMap<u32, f32>>,
457 #[serde(default = "default_false")]
458 #[schema(example = false)]
459 pub logprobs: bool,
460 #[schema(example = json!(Option::None::<usize>))]
461 pub top_logprobs: Option<usize>,
462 #[schema(example = 256)]
463 #[serde(alias = "max_completion_tokens")]
464 pub max_tokens: Option<usize>,
465 #[serde(rename = "n")]
466 #[serde(default = "default_1usize")]
467 #[schema(example = 1)]
468 pub n_choices: usize,
469 #[schema(example = json!(Option::None::<f32>))]
470 pub presence_penalty: Option<f32>,
471 #[schema(example = json!(Option::None::<f32>))]
472 pub frequency_penalty: Option<f32>,
473 #[serde(rename = "stop")]
474 #[schema(example = json!(Option::None::<StopTokens>))]
475 pub stop_seqs: Option<StopTokens>,
476 #[schema(example = 0.7)]
477 pub temperature: Option<f64>,
478 #[schema(example = json!(Option::None::<f64>))]
479 pub top_p: Option<f64>,
480 #[schema(example = true)]
481 pub stream: Option<bool>,
482 #[schema(example = json!(Option::None::<Vec<Tool>>))]
483 pub tools: Option<Vec<Tool>>,
484 #[schema(example = json!(Option::None::<ToolChoice>))]
485 pub tool_choice: Option<ToolChoice>,
486 #[schema(example = json!(Option::None::<ResponseFormat>))]
487 pub response_format: Option<ResponseFormat>,
488 #[schema(example = json!(Option::None::<WebSearchOptions>))]
489 pub web_search_options: Option<WebSearchOptions>,
490
491 #[schema(example = json!(Option::None::<usize>))]
493 pub top_k: Option<usize>,
494 #[schema(example = json!(Option::None::<Grammar>))]
495 pub grammar: Option<Grammar>,
496 #[schema(example = json!(Option::None::<f64>))]
497 pub min_p: Option<f64>,
498 #[schema(example = json!(Option::None::<f32>))]
499 pub dry_multiplier: Option<f32>,
500 #[schema(example = json!(Option::None::<f32>))]
501 pub dry_base: Option<f32>,
502 #[schema(example = json!(Option::None::<usize>))]
503 pub dry_allowed_length: Option<usize>,
504 #[schema(example = json!(Option::None::<String>))]
505 pub dry_sequence_breakers: Option<Vec<String>>,
506 #[schema(example = json!(Option::None::<bool>))]
507 pub enable_thinking: Option<bool>,
508}
509
510fn messages_schema() -> Schema {
512 Schema::OneOf(
513 OneOfBuilder::new()
514 .item(Schema::Array(
515 ArrayBuilder::new()
516 .items(RefOr::Ref(utoipa::openapi::Ref::from_schema_name(
517 "Message",
518 )))
519 .build(),
520 ))
521 .item(Schema::Object(
522 ObjectBuilder::new()
523 .schema_type(SchemaType::Type(Type::String))
524 .build(),
525 ))
526 .build(),
527 )
528}
529
530#[derive(Debug, Serialize, ToSchema)]
532pub struct ModelObject {
533 pub id: String,
534 pub object: &'static str,
535 pub created: u64,
536 pub owned_by: &'static str,
537 #[serde(skip_serializing_if = "Option::is_none")]
539 pub tools_available: Option<bool>,
540 #[serde(skip_serializing_if = "Option::is_none")]
542 pub mcp_tools_count: Option<usize>,
543 #[serde(skip_serializing_if = "Option::is_none")]
545 pub mcp_servers_connected: Option<usize>,
546}
547
548#[derive(Debug, Serialize, ToSchema)]
550pub struct ModelObjects {
551 pub object: &'static str,
552 pub data: Vec<ModelObject>,
553}
554
555#[derive(Debug, Clone, Deserialize, Serialize, ToSchema)]
557pub struct CompletionRequest {
558 #[schema(example = "mistral")]
559 #[serde(default = "default_model")]
560 pub model: String,
561 #[schema(example = "Say this is a test.")]
562 pub prompt: String,
563 #[schema(example = 1)]
564 pub best_of: Option<usize>,
565 #[serde(rename = "echo")]
566 #[serde(default = "default_false")]
567 #[schema(example = false)]
568 pub echo_prompt: bool,
569 #[schema(example = json!(Option::None::<f32>))]
570 pub presence_penalty: Option<f32>,
571 #[schema(example = json!(Option::None::<f32>))]
572 pub frequency_penalty: Option<f32>,
573 #[schema(example = json!(Option::None::<HashMap<u32, f32>>))]
574 pub logit_bias: Option<HashMap<u32, f32>>,
575 #[schema(example = json!(Option::None::<usize>))]
576 pub logprobs: Option<usize>,
577 #[schema(example = 16)]
578 #[serde(alias = "max_completion_tokens")]
579 pub max_tokens: Option<usize>,
580 #[serde(rename = "n")]
581 #[serde(default = "default_1usize")]
582 #[schema(example = 1)]
583 pub n_choices: usize,
584 #[serde(rename = "stop")]
585 #[schema(example = json!(Option::None::<StopTokens>))]
586 pub stop_seqs: Option<StopTokens>,
587 pub stream: Option<bool>,
588 #[schema(example = 0.7)]
589 pub temperature: Option<f64>,
590 #[schema(example = json!(Option::None::<f64>))]
591 pub top_p: Option<f64>,
592 #[schema(example = json!(Option::None::<String>))]
593 pub suffix: Option<String>,
594 #[serde(rename = "user")]
595 pub _user: Option<String>,
596 #[schema(example = json!(Option::None::<Vec<Tool>>))]
597 pub tools: Option<Vec<Tool>>,
598 #[schema(example = json!(Option::None::<ToolChoice>))]
599 pub tool_choice: Option<ToolChoice>,
600
601 #[schema(example = json!(Option::None::<usize>))]
603 pub top_k: Option<usize>,
604 #[schema(example = json!(Option::None::<Grammar>))]
605 pub grammar: Option<Grammar>,
606 #[schema(example = json!(Option::None::<f64>))]
607 pub min_p: Option<f64>,
608 #[schema(example = json!(Option::None::<f32>))]
609 pub dry_multiplier: Option<f32>,
610 #[schema(example = json!(Option::None::<f32>))]
611 pub dry_base: Option<f32>,
612 #[schema(example = json!(Option::None::<usize>))]
613 pub dry_allowed_length: Option<usize>,
614 #[schema(example = json!(Option::None::<String>))]
615 pub dry_sequence_breakers: Option<Vec<String>>,
616}
617
618#[derive(Debug, Clone, Deserialize, Serialize, ToSchema)]
620pub struct ImageGenerationRequest {
621 #[schema(example = "mistral")]
622 #[serde(default = "default_model")]
623 pub model: String,
624 #[schema(example = "Draw a picture of a majestic, snow-covered mountain.")]
625 pub prompt: String,
626 #[serde(rename = "n")]
627 #[serde(default = "default_1usize")]
628 #[schema(example = 1)]
629 pub n_choices: usize,
630 #[serde(default = "default_response_format")]
631 pub response_format: ImageGenerationResponseFormat,
632 #[serde(default = "default_720usize")]
633 #[schema(example = 720)]
634 pub height: usize,
635 #[serde(default = "default_1280usize")]
636 #[schema(example = 1280)]
637 pub width: usize,
638}
639
640#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize, Default, ToSchema)]
642#[serde(rename_all = "lowercase")]
643pub enum AudioResponseFormat {
644 #[default]
646 Mp3,
647 Opus,
649 Aac,
651 Flac,
653 Wav,
655 Pcm,
657}
658
659impl AudioResponseFormat {
660 pub fn audio_content_type(
662 &self,
663 pcm_rate: usize,
664 pcm_channels: usize,
665 pcm_format: &'static str,
666 ) -> String {
667 let content_type = match &self {
668 AudioResponseFormat::Mp3 => "audio/mpeg".to_string(),
669 AudioResponseFormat::Opus => "audio/ogg; codecs=opus".to_string(),
670 AudioResponseFormat::Aac => "audio/aac".to_string(),
671 AudioResponseFormat::Flac => "audio/flac".to_string(),
672 AudioResponseFormat::Wav => "audio/wav".to_string(),
673 AudioResponseFormat::Pcm => format!("audio/pcm; codecs=1; format={pcm_format}"),
674 };
675
676 format!("{content_type}; rate={pcm_rate}; channels={pcm_channels}")
677 }
678}
679
680#[derive(Debug, Clone, Deserialize, Serialize, ToSchema)]
682pub struct SpeechGenerationRequest {
683 #[schema(example = "nari-labs/Dia-1.6B")]
685 #[serde(default = "default_model")]
686 pub model: String,
687 #[schema(
689 example = "[S1] Dia is an open weights text to dialogue model. [S2] You get full control over scripts and voices. [S1] Wow. Amazing. (laughs) [S2] Try it now on Git hub or Hugging Face."
690 )]
691 pub input: String,
692 #[schema(example = "mp3")]
695 pub response_format: AudioResponseFormat,
696}