mistralrs_server_core/
openai.rs

1//! ## OpenAI compatible functionality.
2
3use std::{collections::HashMap, ops::Deref};
4
5use either::Either;
6use mistralrs_core::{
7    ImageGenerationResponseFormat, LlguidanceGrammar, Tool, ToolChoice, ToolType, WebSearchOptions,
8};
9use serde::{Deserialize, Serialize};
10use serde_json::Value;
11use utoipa::{
12    openapi::{schema::SchemaType, ArrayBuilder, ObjectBuilder, OneOfBuilder, RefOr, Schema, Type},
13    PartialSchema, ToSchema,
14};
15
16/// Inner content structure for messages that can be either a string or key-value pairs
17#[derive(Debug, Clone, Deserialize, Serialize)]
18pub struct MessageInnerContent(
19    #[serde(with = "either::serde_untagged")] pub Either<String, HashMap<String, String>>,
20);
21
22// The impl Deref was preventing the Derive ToSchema and #[schema] macros from
23// properly working, so manually impl ToSchema
24impl PartialSchema for MessageInnerContent {
25    fn schema() -> RefOr<Schema> {
26        RefOr::T(message_inner_content_schema())
27    }
28}
29
30impl ToSchema for MessageInnerContent {
31    fn schemas(
32        schemas: &mut Vec<(
33            String,
34            utoipa::openapi::RefOr<utoipa::openapi::schema::Schema>,
35        )>,
36    ) {
37        schemas.push((
38            MessageInnerContent::name().into(),
39            MessageInnerContent::schema(),
40        ));
41    }
42}
43
44impl Deref for MessageInnerContent {
45    type Target = Either<String, HashMap<String, String>>;
46    fn deref(&self) -> &Self::Target {
47        &self.0
48    }
49}
50
51/// Function for MessageInnerContent Schema generation to handle `Either`
52fn message_inner_content_schema() -> Schema {
53    Schema::OneOf(
54        OneOfBuilder::new()
55            // Either::Left - simple string
56            .item(Schema::Object(
57                ObjectBuilder::new()
58                    .schema_type(SchemaType::Type(Type::String))
59                    .build(),
60            ))
61            // Either::Right - object with string values
62            .item(Schema::Object(
63                ObjectBuilder::new()
64                    .schema_type(SchemaType::Type(Type::Object))
65                    .additional_properties(Some(RefOr::T(Schema::Object(
66                        ObjectBuilder::new()
67                            .schema_type(SchemaType::Type(Type::String))
68                            .build(),
69                    ))))
70                    .build(),
71            ))
72            .build(),
73    )
74}
75
76/// Message content that can be either simple text or complex structured content
77#[derive(Debug, Clone, Deserialize, Serialize)]
78pub struct MessageContent(
79    #[serde(with = "either::serde_untagged")]
80    Either<String, Vec<HashMap<String, MessageInnerContent>>>,
81);
82
83// The impl Deref was preventing the Derive ToSchema and #[schema] macros from
84// properly working, so manually impl ToSchema
85impl PartialSchema for MessageContent {
86    fn schema() -> RefOr<Schema> {
87        RefOr::T(message_content_schema())
88    }
89}
90
91impl ToSchema for MessageContent {
92    fn schemas(
93        schemas: &mut Vec<(
94            String,
95            utoipa::openapi::RefOr<utoipa::openapi::schema::Schema>,
96        )>,
97    ) {
98        schemas.push((MessageContent::name().into(), MessageContent::schema()));
99    }
100}
101
102impl MessageContent {
103    /// Create a new MessageContent from a string
104    pub fn from_text(text: String) -> Self {
105        MessageContent(Either::Left(text))
106    }
107
108    /// Extract text from MessageContent
109    pub fn to_text(&self) -> Option<String> {
110        match &self.0 {
111            Either::Left(text) => Some(text.clone()),
112            Either::Right(parts) => {
113                // For complex content, try to extract text from parts
114                let mut text_parts = Vec::new();
115                for part in parts {
116                    for (key, value) in part {
117                        if key == "text" {
118                            if let Either::Left(text) = &**value {
119                                text_parts.push(text.clone());
120                            }
121                        }
122                    }
123                }
124                if text_parts.is_empty() {
125                    None
126                } else {
127                    Some(text_parts.join(" "))
128                }
129            }
130        }
131    }
132}
133
134impl Deref for MessageContent {
135    type Target = Either<String, Vec<HashMap<String, MessageInnerContent>>>;
136    fn deref(&self) -> &Self::Target {
137        &self.0
138    }
139}
140
141/// Function for MessageContent Schema generation to handle `Either`
142fn message_content_schema() -> Schema {
143    Schema::OneOf(
144        OneOfBuilder::new()
145            .item(Schema::Object(
146                ObjectBuilder::new()
147                    .schema_type(SchemaType::Type(Type::String))
148                    .build(),
149            ))
150            .item(Schema::Array(
151                ArrayBuilder::new()
152                    .items(RefOr::T(Schema::Object(
153                        ObjectBuilder::new()
154                            .schema_type(SchemaType::Type(Type::Object))
155                            .additional_properties(Some(RefOr::Ref(
156                                utoipa::openapi::Ref::from_schema_name("MessageInnerContent"),
157                            )))
158                            .build(),
159                    )))
160                    .build(),
161            ))
162            .build(),
163    )
164}
165
166/// Represents a function call made by the assistant
167///
168/// When using tool calling, this structure contains the details of a function
169/// that the model has decided to call, including the function name and its parameters.
170#[derive(Clone, Debug, serde::Deserialize, serde::Serialize, ToSchema)]
171pub struct FunctionCalled {
172    /// The name of the function to call
173    pub name: String,
174    /// The function arguments
175    #[serde(alias = "arguments")]
176    pub parameters: String,
177}
178
179/// Represents a tool call made by the assistant
180///
181/// This structure wraps a function call with its type information.
182#[derive(Clone, Debug, serde::Deserialize, serde::Serialize, ToSchema)]
183pub struct ToolCall {
184    /// The type of tool being called
185    #[serde(rename = "type")]
186    pub tp: ToolType,
187    ///  The function call details
188    pub function: FunctionCalled,
189}
190
191/// Represents a single message in a conversation
192///
193/// ### Examples
194///
195/// ```ignore
196/// use either::Either;
197/// use mistralrs_server_core::openai::{Message, MessageContent};
198///
199/// // User message
200/// let user_msg = Message {
201///     content: Some(MessageContent(Either::Left("What's 2+2?".to_string()))),
202///     role: "user".to_string(),
203///     name: None,
204///     tool_calls: None,
205/// };
206///
207/// // System message
208/// let system_msg = Message {
209///     content: Some(MessageContent(Either::Left("You are a helpful assistant.".to_string()))),
210///     role: "system".to_string(),
211///     name: None,
212///     tool_calls: None,
213/// };
214/// ```
215#[derive(Debug, Clone, Deserialize, Serialize, ToSchema)]
216pub struct Message {
217    /// The message content
218    pub content: Option<MessageContent>,
219    /// The role of the message sender ("user", "assistant", "system", "tool", etc.)
220    pub role: String,
221    pub name: Option<String>,
222    /// Optional list of tool calls
223    pub tool_calls: Option<Vec<ToolCall>>,
224}
225
226/// Stop token configuration for generation
227///
228/// Defines when the model should stop generating text, either with a single
229/// stop token or multiple possible stop sequences.
230#[derive(Debug, Clone, Deserialize, Serialize, ToSchema)]
231#[serde(untagged)]
232pub enum StopTokens {
233    ///  Multiple possible stop sequences
234    Multi(Vec<String>),
235    /// Single stop sequence
236    Single(String),
237}
238
/// Serde default for boolean request flags that stay off unless the client sets them.
fn default_false() -> bool {
    bool::default()
}

/// Serde default for request counts that start at one (used for the `n` field).
fn default_1usize() -> usize {
    1usize
}

/// Serde default for an image request's `height`.
fn default_720usize() -> usize {
    720usize
}

/// Serde default for an image request's `width`.
fn default_1280usize() -> usize {
    1280usize
}

/// Serde default for a request's `model` field when the client omits it.
fn default_model() -> String {
    String::from("default")
}
263
264/// Default value helper
265fn default_response_format() -> ImageGenerationResponseFormat {
266    ImageGenerationResponseFormat::Url
267}
268
269/// Grammar specification for structured generation
270///
271/// Defines different types of grammars that can be used to constrain model output,
272/// ensuring it follows specific formats or structures.
273///
274/// ### Examples
275///
276/// ```ignore
277/// use mistralrs_server_core::openai::Grammar;
278///
279/// // Regex grammar for phone numbers
280/// let phone_regex = Grammar::Regex(r"\d{3}-\d{3}-\d{4}".to_string());
281///
282/// // JSON schema for structured data
283/// let json_schema = Grammar::JsonSchema(serde_json::json!({
284///     "type": "object",
285///     "properties": {
286///         "name": {"type": "string"},
287///         "age": {"type": "integer"}
288///     },
289///     "required": ["name", "age"]
290/// }));
291///
292/// // Lark grammar for arithmetic expressions
293/// let lark_grammar = Grammar::Lark(r#"
294///     ?start: expr
295///     expr: term ("+" term | "-" term)*
296///     term: factor ("*" factor | "/" factor)*
297///     factor: NUMBER | "(" expr ")"
298///     %import common.NUMBER
299/// "#.to_string());
300/// ```
301#[derive(Debug, Clone, Deserialize, Serialize)]
302#[serde(tag = "type", content = "value")]
303pub enum Grammar {
304    /// Regular expression grammar
305    #[serde(rename = "regex")]
306    Regex(String),
307    /// JSON schema grammar
308    #[serde(rename = "json_schema")]
309    JsonSchema(serde_json::Value),
310    /// LLGuidance grammar
311    #[serde(rename = "llguidance")]
312    Llguidance(LlguidanceGrammar),
313    /// Lark parser grammar
314    #[serde(rename = "lark")]
315    Lark(String),
316}
317
318// Implement ToSchema manually to handle `LlguidanceGrammar`
319impl PartialSchema for Grammar {
320    fn schema() -> RefOr<Schema> {
321        RefOr::T(Schema::OneOf(
322            OneOfBuilder::new()
323                .item(create_grammar_variant_schema(
324                    "regex",
325                    Schema::Object(
326                        ObjectBuilder::new()
327                            .schema_type(SchemaType::Type(Type::String))
328                            .build(),
329                    ),
330                ))
331                .item(create_grammar_variant_schema(
332                    "json_schema",
333                    Schema::Object(
334                        ObjectBuilder::new()
335                            .schema_type(SchemaType::Type(Type::Object))
336                            .build(),
337                    ),
338                ))
339                .item(create_grammar_variant_schema(
340                    "llguidance",
341                    llguidance_schema(),
342                ))
343                .item(create_grammar_variant_schema(
344                    "lark",
345                    Schema::Object(
346                        ObjectBuilder::new()
347                            .schema_type(SchemaType::Type(Type::String))
348                            .build(),
349                    ),
350                ))
351                .build(),
352        ))
353    }
354}
355
356impl ToSchema for Grammar {
357    fn schemas(
358        schemas: &mut Vec<(
359            String,
360            utoipa::openapi::RefOr<utoipa::openapi::schema::Schema>,
361        )>,
362    ) {
363        schemas.push((Grammar::name().into(), Grammar::schema()));
364    }
365}
366
367/// Helper function to create a grammar variant schema
368fn create_grammar_variant_schema(type_value: &str, value_schema: Schema) -> Schema {
369    Schema::Object(
370        ObjectBuilder::new()
371            .schema_type(SchemaType::Type(Type::Object))
372            .property(
373                "type",
374                RefOr::T(Schema::Object(
375                    ObjectBuilder::new()
376                        .schema_type(SchemaType::Type(Type::String))
377                        .enum_values(Some(vec![serde_json::Value::String(
378                            type_value.to_string(),
379                        )]))
380                        .build(),
381                )),
382            )
383            .property("value", RefOr::T(value_schema))
384            .required("type")
385            .required("value")
386            .build(),
387    )
388}
389
390/// Helper function to generate LLGuidance schema
391fn llguidance_schema() -> Schema {
392    let grammar_with_lexer_schema = Schema::Object(
393        ObjectBuilder::new()
394            .schema_type(SchemaType::Type(Type::Object))
395            .property(
396                "name",
397                RefOr::T(Schema::Object(
398                    ObjectBuilder::new()
399                        .schema_type(SchemaType::from_iter([Type::String, Type::Null]))
400                        .description(Some(
401                            "The name of this grammar, can be used in GenGrammar nodes",
402                        ))
403                        .build(),
404                )),
405            )
406            .property(
407                "json_schema",
408                RefOr::T(Schema::Object(
409                    ObjectBuilder::new()
410                        .schema_type(SchemaType::from_iter([Type::Object, Type::Null]))
411                        .description(Some("The JSON schema that the grammar should generate"))
412                        .build(),
413                )),
414            )
415            .property(
416                "lark_grammar",
417                RefOr::T(Schema::Object(
418                    ObjectBuilder::new()
419                        .schema_type(SchemaType::from_iter([Type::String, Type::Null]))
420                        .description(Some("The Lark grammar that the grammar should generate"))
421                        .build(),
422                )),
423            )
424            .description(Some("Grammar configuration with lexer settings"))
425            .build(),
426    );
427
428    Schema::Object(
429        ObjectBuilder::new()
430            .schema_type(SchemaType::Type(Type::Object))
431            .property(
432                "grammars",
433                RefOr::T(Schema::Array(
434                    ArrayBuilder::new()
435                        .items(RefOr::T(grammar_with_lexer_schema))
436                        .description(Some("List of grammar configurations"))
437                        .build(),
438                )),
439            )
440            .property(
441                "max_tokens",
442                RefOr::T(Schema::Object(
443                    ObjectBuilder::new()
444                        .schema_type(SchemaType::from_iter([Type::Integer, Type::Null]))
445                        .description(Some("Maximum number of tokens to generate"))
446                        .build(),
447                )),
448            )
449            .required("grammars")
450            .description(Some("Top-level grammar configuration for LLGuidance"))
451            .build(),
452    )
453}
454
455/// JSON Schema for structured responses
456#[derive(Debug, Clone, Deserialize, Serialize, ToSchema)]
457pub struct JsonSchemaResponseFormat {
458    pub name: String,
459    pub schema: serde_json::Value,
460}
461
462/// Response format for model output
463#[derive(Debug, Clone, Deserialize, Serialize, ToSchema)]
464#[serde(tag = "type")]
465pub enum ResponseFormat {
466    /// Free-form text response
467    #[serde(rename = "text")]
468    Text,
469    /// Structured response following a JSON schema
470    #[serde(rename = "json_schema")]
471    JsonSchema {
472        json_schema: JsonSchemaResponseFormat,
473    },
474}
475
476/// Chat completion request following OpenAI's specification
477#[derive(Debug, Clone, Deserialize, Serialize, ToSchema)]
478pub struct ChatCompletionRequest {
479    #[schema(
480        schema_with = messages_schema,
481        example = json!(vec![Message{content:Some(MessageContent{0: either::Left(("Why did the crab cross the road?".to_string()))}), role:"user".to_string(), name: None, tool_calls: None}])
482    )]
483    #[serde(with = "either::serde_untagged")]
484    pub messages: Either<Vec<Message>, String>,
485    #[schema(example = "mistral")]
486    #[serde(default = "default_model")]
487    pub model: String,
488    #[schema(example = json!(Option::None::<HashMap<u32, f32>>))]
489    pub logit_bias: Option<HashMap<u32, f32>>,
490    #[serde(default = "default_false")]
491    #[schema(example = false)]
492    pub logprobs: bool,
493    #[schema(example = json!(Option::None::<usize>))]
494    pub top_logprobs: Option<usize>,
495    #[schema(example = 256)]
496    #[serde(alias = "max_completion_tokens")]
497    pub max_tokens: Option<usize>,
498    #[serde(rename = "n")]
499    #[serde(default = "default_1usize")]
500    #[schema(example = 1)]
501    pub n_choices: usize,
502    #[schema(example = json!(Option::None::<f32>))]
503    pub presence_penalty: Option<f32>,
504    #[schema(example = json!(Option::None::<f32>))]
505    pub frequency_penalty: Option<f32>,
506    #[schema(example = json!(Option::None::<f32>))]
507    pub repetition_penalty: Option<f32>,
508    #[serde(rename = "stop")]
509    #[schema(example = json!(Option::None::<StopTokens>))]
510    pub stop_seqs: Option<StopTokens>,
511    #[schema(example = 0.7)]
512    pub temperature: Option<f64>,
513    #[schema(example = json!(Option::None::<f64>))]
514    pub top_p: Option<f64>,
515    #[schema(example = true)]
516    pub stream: Option<bool>,
517    #[schema(example = json!(Option::None::<Vec<Tool>>))]
518    pub tools: Option<Vec<Tool>>,
519    #[schema(example = json!(Option::None::<ToolChoice>))]
520    pub tool_choice: Option<ToolChoice>,
521    #[schema(example = json!(Option::None::<ResponseFormat>))]
522    pub response_format: Option<ResponseFormat>,
523    #[schema(example = json!(Option::None::<WebSearchOptions>))]
524    pub web_search_options: Option<WebSearchOptions>,
525
526    // mistral.rs additional
527    #[schema(example = json!(Option::None::<usize>))]
528    pub top_k: Option<usize>,
529    #[schema(example = json!(Option::None::<Grammar>))]
530    pub grammar: Option<Grammar>,
531    #[schema(example = json!(Option::None::<f64>))]
532    pub min_p: Option<f64>,
533    #[schema(example = json!(Option::None::<f32>))]
534    pub dry_multiplier: Option<f32>,
535    #[schema(example = json!(Option::None::<f32>))]
536    pub dry_base: Option<f32>,
537    #[schema(example = json!(Option::None::<usize>))]
538    pub dry_allowed_length: Option<usize>,
539    #[schema(example = json!(Option::None::<String>))]
540    pub dry_sequence_breakers: Option<Vec<String>>,
541    #[schema(example = json!(Option::None::<bool>))]
542    pub enable_thinking: Option<bool>,
543    #[schema(example = json!(Option::None::<bool>))]
544    #[serde(default)]
545    pub truncate_sequence: Option<bool>,
546}
547
548/// Function for ChatCompletionRequest.messages Schema generation to handle `Either`
549fn messages_schema() -> Schema {
550    Schema::OneOf(
551        OneOfBuilder::new()
552            .item(Schema::Array(
553                ArrayBuilder::new()
554                    .items(RefOr::Ref(utoipa::openapi::Ref::from_schema_name(
555                        "Message",
556                    )))
557                    .build(),
558            ))
559            .item(Schema::Object(
560                ObjectBuilder::new()
561                    .schema_type(SchemaType::Type(Type::String))
562                    .build(),
563            ))
564            .build(),
565    )
566}
567
568/// Model information metadata about an available mode
569#[derive(Debug, Serialize, ToSchema)]
570pub struct ModelObject {
571    pub id: String,
572    pub object: &'static str,
573    pub created: u64,
574    pub owned_by: &'static str,
575    /// Whether tools are available through MCP or tool callbacks
576    #[serde(skip_serializing_if = "Option::is_none")]
577    pub tools_available: Option<bool>,
578    /// Number of tools available from MCP servers
579    #[serde(skip_serializing_if = "Option::is_none")]
580    pub mcp_tools_count: Option<usize>,
581    /// Number of connected MCP servers
582    #[serde(skip_serializing_if = "Option::is_none")]
583    pub mcp_servers_connected: Option<usize>,
584}
585
586/// Collection of available models
587#[derive(Debug, Serialize, ToSchema)]
588pub struct ModelObjects {
589    pub object: &'static str,
590    pub data: Vec<ModelObject>,
591}
592
593/// Legacy OpenAI compatible text completion request
594#[derive(Debug, Clone, Deserialize, Serialize, ToSchema)]
595pub struct CompletionRequest {
596    #[schema(example = "mistral")]
597    #[serde(default = "default_model")]
598    pub model: String,
599    #[schema(example = "Say this is a test.")]
600    pub prompt: String,
601    #[schema(example = 1)]
602    pub best_of: Option<usize>,
603    #[serde(rename = "echo")]
604    #[serde(default = "default_false")]
605    #[schema(example = false)]
606    pub echo_prompt: bool,
607    #[schema(example = json!(Option::None::<f32>))]
608    pub presence_penalty: Option<f32>,
609    #[schema(example = json!(Option::None::<f32>))]
610    pub frequency_penalty: Option<f32>,
611    #[schema(example = json!(Option::None::<HashMap<u32, f32>>))]
612    pub logit_bias: Option<HashMap<u32, f32>>,
613    #[schema(example = json!(Option::None::<usize>))]
614    pub logprobs: Option<usize>,
615    #[schema(example = 16)]
616    #[serde(alias = "max_completion_tokens")]
617    pub max_tokens: Option<usize>,
618    #[serde(rename = "n")]
619    #[serde(default = "default_1usize")]
620    #[schema(example = 1)]
621    pub n_choices: usize,
622    #[serde(rename = "stop")]
623    #[schema(example = json!(Option::None::<StopTokens>))]
624    pub stop_seqs: Option<StopTokens>,
625    pub stream: Option<bool>,
626    #[schema(example = 0.7)]
627    pub temperature: Option<f64>,
628    #[schema(example = json!(Option::None::<f64>))]
629    pub top_p: Option<f64>,
630    #[schema(example = json!(Option::None::<String>))]
631    pub suffix: Option<String>,
632    #[serde(rename = "user")]
633    pub _user: Option<String>,
634    #[schema(example = json!(Option::None::<Vec<Tool>>))]
635    pub tools: Option<Vec<Tool>>,
636    #[schema(example = json!(Option::None::<ToolChoice>))]
637    pub tool_choice: Option<ToolChoice>,
638
639    // mistral.rs additional
640    #[schema(example = json!(Option::None::<usize>))]
641    pub top_k: Option<usize>,
642    #[schema(example = json!(Option::None::<Grammar>))]
643    pub grammar: Option<Grammar>,
644    #[schema(example = json!(Option::None::<f64>))]
645    pub min_p: Option<f64>,
646    #[schema(example = json!(Option::None::<f32>))]
647    pub repetition_penalty: Option<f32>,
648    #[schema(example = json!(Option::None::<f32>))]
649    pub dry_multiplier: Option<f32>,
650    #[schema(example = json!(Option::None::<f32>))]
651    pub dry_base: Option<f32>,
652    #[schema(example = json!(Option::None::<usize>))]
653    pub dry_allowed_length: Option<usize>,
654    #[schema(example = json!(Option::None::<String>))]
655    pub dry_sequence_breakers: Option<Vec<String>>,
656    #[schema(example = json!(Option::None::<bool>))]
657    #[serde(default)]
658    pub truncate_sequence: Option<bool>,
659}
660
661#[derive(Debug, Clone, Deserialize, Serialize)]
662#[serde(untagged)]
663pub enum EmbeddingInput {
664    Single(String),
665    Multiple(Vec<String>),
666    Tokens(Vec<u32>),
667    TokensBatch(Vec<Vec<u32>>),
668}
669
670impl PartialSchema for EmbeddingInput {
671    fn schema() -> RefOr<Schema> {
672        RefOr::T(embedding_input_schema())
673    }
674}
675
676impl ToSchema for EmbeddingInput {
677    fn schemas(
678        schemas: &mut Vec<(
679            String,
680            utoipa::openapi::RefOr<utoipa::openapi::schema::Schema>,
681        )>,
682    ) {
683        schemas.push((EmbeddingInput::name().into(), EmbeddingInput::schema()));
684    }
685}
686
687fn embedding_input_schema() -> Schema {
688    Schema::OneOf(
689        OneOfBuilder::new()
690            .item(Schema::Object(
691                ObjectBuilder::new()
692                    .schema_type(SchemaType::Type(Type::String))
693                    .description(Some("Single input string"))
694                    .build(),
695            ))
696            .item(Schema::Array(
697                ArrayBuilder::new()
698                    .items(RefOr::T(Schema::Object(
699                        ObjectBuilder::new()
700                            .schema_type(SchemaType::Type(Type::String))
701                            .build(),
702                    )))
703                    .description(Some("Multiple input strings"))
704                    .build(),
705            ))
706            .item(Schema::Array(
707                ArrayBuilder::new()
708                    .items(RefOr::T(Schema::Object(
709                        ObjectBuilder::new()
710                            .schema_type(SchemaType::Type(Type::Integer))
711                            .build(),
712                    )))
713                    .description(Some("Single token array"))
714                    .build(),
715            ))
716            .item(Schema::Array(
717                ArrayBuilder::new()
718                    .items(RefOr::T(Schema::Array(
719                        ArrayBuilder::new()
720                            .items(RefOr::T(Schema::Object(
721                                ObjectBuilder::new()
722                                    .schema_type(SchemaType::Type(Type::Integer))
723                                    .build(),
724                            )))
725                            .build(),
726                    )))
727                    .description(Some("Multiple token arrays"))
728                    .build(),
729            ))
730            .build(),
731    )
732}
733
734#[derive(Debug, Clone, Deserialize, Serialize, ToSchema, Default)]
735#[serde(rename_all = "snake_case")]
736pub enum EmbeddingEncodingFormat {
737    #[default]
738    Float,
739    Base64,
740}
741
742#[derive(Debug, Clone, Deserialize, Serialize, ToSchema)]
743pub struct EmbeddingRequest {
744    #[schema(example = "default")]
745    #[serde(default = "default_model")]
746    pub model: String,
747    pub input: EmbeddingInput,
748    #[schema(example = "float")]
749    #[serde(default)]
750    pub encoding_format: Option<EmbeddingEncodingFormat>,
751    #[schema(example = json!(Option::None::<usize>))]
752    pub dimensions: Option<usize>,
753    #[schema(example = json!(Option::None::<String>))]
754    #[serde(rename = "user")]
755    pub _user: Option<String>,
756
757    // mistral.rs additional
758    #[schema(example = json!(Option::None::<bool>))]
759    #[serde(default)]
760    pub truncate_sequence: Option<bool>,
761}
762
763#[derive(Debug, Clone, Serialize, ToSchema)]
764pub struct EmbeddingUsage {
765    pub prompt_tokens: u32,
766    pub total_tokens: u32,
767}
768
769#[derive(Debug, Clone, Serialize)]
770#[serde(untagged)]
771pub enum EmbeddingVector {
772    Float(Vec<f32>),
773    Base64(String),
774}
775
776impl PartialSchema for EmbeddingVector {
777    fn schema() -> RefOr<Schema> {
778        RefOr::T(embedding_vector_schema())
779    }
780}
781
782impl ToSchema for EmbeddingVector {
783    fn schemas(
784        schemas: &mut Vec<(
785            String,
786            utoipa::openapi::RefOr<utoipa::openapi::schema::Schema>,
787        )>,
788    ) {
789        schemas.push((EmbeddingVector::name().into(), EmbeddingVector::schema()));
790    }
791}
792
793fn embedding_vector_schema() -> Schema {
794    Schema::OneOf(
795        OneOfBuilder::new()
796            .item(Schema::Array(
797                ArrayBuilder::new()
798                    .items(RefOr::T(Schema::Object(
799                        ObjectBuilder::new()
800                            .schema_type(SchemaType::Type(Type::Number))
801                            .build(),
802                    )))
803                    .description(Some("Embedding returned as an array of floats"))
804                    .build(),
805            ))
806            .item(Schema::Object(
807                ObjectBuilder::new()
808                    .schema_type(SchemaType::Type(Type::String))
809                    .description(Some("Embedding returned as a base64-encoded string"))
810                    .build(),
811            ))
812            .build(),
813    )
814}
815
816#[derive(Debug, Clone, Serialize, ToSchema)]
817pub struct EmbeddingData {
818    pub object: &'static str,
819    pub embedding: EmbeddingVector,
820    pub index: usize,
821}
822
823#[derive(Debug, Clone, Serialize, ToSchema)]
824pub struct EmbeddingResponse {
825    pub object: &'static str,
826    pub data: Vec<EmbeddingData>,
827    pub model: String,
828    pub usage: EmbeddingUsage,
829}
830
831/// Image generation request
832#[derive(Debug, Clone, Deserialize, Serialize, ToSchema)]
833pub struct ImageGenerationRequest {
834    #[schema(example = "mistral")]
835    #[serde(default = "default_model")]
836    pub model: String,
837    #[schema(example = "Draw a picture of a majestic, snow-covered mountain.")]
838    pub prompt: String,
839    #[serde(rename = "n")]
840    #[serde(default = "default_1usize")]
841    #[schema(example = 1)]
842    pub n_choices: usize,
843    #[serde(default = "default_response_format")]
844    pub response_format: ImageGenerationResponseFormat,
845    #[serde(default = "default_720usize")]
846    #[schema(example = 720)]
847    pub height: usize,
848    #[serde(default = "default_1280usize")]
849    #[schema(example = 1280)]
850    pub width: usize,
851}
852
853/// Audio format options for speech generation responses.
854#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize, Default, ToSchema)]
855#[serde(rename_all = "lowercase")]
856pub enum AudioResponseFormat {
857    /// Widely compatible, lossy compression, good for web streaming
858    #[default]
859    Mp3,
860    /// Good compression efficiency, ideal for real-time communication
861    Opus,
862    /// High-quality lossy compression, commonly used in mobile applications
863    Aac,
864    /// Lossless compression, larger file sizes but good audio quality
865    Flac,
866    /// Uncompressed, largest file sizes but maximum compatibility
867    Wav,
868    ///  Raw audio data, requires additional format specification
869    Pcm,
870}
871
872impl AudioResponseFormat {
873    /// Generate the appropriate MIME content type string for this audio format.
874    pub fn audio_content_type(
875        &self,
876        pcm_rate: usize,
877        pcm_channels: usize,
878        pcm_format: &'static str,
879    ) -> String {
880        let content_type = match &self {
881            AudioResponseFormat::Mp3 => "audio/mpeg".to_string(),
882            AudioResponseFormat::Opus => "audio/ogg; codecs=opus".to_string(),
883            AudioResponseFormat::Aac => "audio/aac".to_string(),
884            AudioResponseFormat::Flac => "audio/flac".to_string(),
885            AudioResponseFormat::Wav => "audio/wav".to_string(),
886            AudioResponseFormat::Pcm => format!("audio/pcm; codecs=1; format={pcm_format}"),
887        };
888
889        format!("{content_type}; rate={pcm_rate}; channels={pcm_channels}")
890    }
891}
892
893/// Speech generation request
894#[derive(Debug, Clone, Deserialize, Serialize, ToSchema)]
895pub struct SpeechGenerationRequest {
896    /// The TTS model to use for audio generation.
897    #[schema(example = "nari-labs/Dia-1.6B")]
898    #[serde(default = "default_model")]
899    pub model: String,
900    /// The text content to convert to speech.
901    #[schema(
902        example = "[S1] Dia is an open weights text to dialogue model. [S2] You get full control over scripts and voices. [S1] Wow. Amazing. (laughs) [S2] Try it now on Git hub or Hugging Face."
903    )]
904    pub input: String,
905    // `voice` and `instructions` are ignored.
906    /// The desired audio format for the generated speech.
907    #[schema(example = "mp3")]
908    pub response_format: AudioResponseFormat,
909}
910
/// Helper type for messages field in ResponsesCreateRequest
///
/// Deserialized `untagged`: serde tries the variants in declaration order, so a
/// JSON value is first attempted as a list of [`Message`]s and otherwise as a
/// plain string. Keep `Messages` before `String` to preserve that order.
#[derive(Debug, Clone, Deserialize, Serialize)]
#[serde(untagged)]
pub enum ResponsesMessages {
    /// A sequence of chat messages.
    Messages(Vec<Message>),
    /// A single raw string input.
    String(String),
}
918
919impl ResponsesMessages {
920    pub fn into_either(self) -> Either<Vec<Message>, String> {
921        match self {
922            ResponsesMessages::Messages(msgs) => Either::Left(msgs),
923            ResponsesMessages::String(s) => Either::Right(s),
924        }
925    }
926}
927
928impl PartialSchema for ResponsesMessages {
929    fn schema() -> RefOr<Schema> {
930        RefOr::T(messages_schema())
931    }
932}
933
934impl ToSchema for ResponsesMessages {
935    fn schemas(
936        schemas: &mut Vec<(
937            String,
938            utoipa::openapi::RefOr<utoipa::openapi::schema::Schema>,
939        )>,
940    ) {
941        schemas.push((
942            ResponsesMessages::name().into(),
943            ResponsesMessages::schema(),
944        ));
945    }
946}
947
948/// Response creation request
949#[derive(Debug, Clone, Deserialize, Serialize, ToSchema)]
950pub struct ResponsesCreateRequest {
951    #[schema(example = "mistral")]
952    #[serde(default = "default_model")]
953    pub model: String,
954    pub input: ResponsesMessages,
955    #[schema(example = json!(Option::None::<String>))]
956    pub instructions: Option<String>,
957    #[schema(example = json!(Option::None::<Vec<String>>))]
958    pub modalities: Option<Vec<String>>,
959    #[schema(example = json!(Option::None::<String>))]
960    pub previous_response_id: Option<String>,
961    #[schema(example = json!(Option::None::<HashMap<u32, f32>>))]
962    pub logit_bias: Option<HashMap<u32, f32>>,
963    #[serde(default = "default_false")]
964    #[schema(example = false)]
965    pub logprobs: bool,
966    #[schema(example = json!(Option::None::<usize>))]
967    pub top_logprobs: Option<usize>,
968    #[schema(example = 256)]
969    #[serde(alias = "max_completion_tokens", alias = "max_output_tokens")]
970    pub max_tokens: Option<usize>,
971    #[serde(rename = "n")]
972    #[serde(default = "default_1usize")]
973    #[schema(example = 1)]
974    pub n_choices: usize,
975    #[schema(example = json!(Option::None::<f32>))]
976    pub presence_penalty: Option<f32>,
977    #[schema(example = json!(Option::None::<f32>))]
978    pub frequency_penalty: Option<f32>,
979    #[serde(rename = "stop")]
980    #[schema(example = json!(Option::None::<StopTokens>))]
981    pub stop_seqs: Option<StopTokens>,
982    #[schema(example = 0.7)]
983    pub temperature: Option<f64>,
984    #[schema(example = json!(Option::None::<f64>))]
985    pub top_p: Option<f64>,
986    #[schema(example = false)]
987    pub stream: Option<bool>,
988    #[schema(example = json!(Option::None::<Vec<Tool>>))]
989    pub tools: Option<Vec<Tool>>,
990    #[schema(example = json!(Option::None::<ToolChoice>))]
991    pub tool_choice: Option<ToolChoice>,
992    #[schema(example = json!(Option::None::<ResponseFormat>))]
993    pub response_format: Option<ResponseFormat>,
994    #[schema(example = json!(Option::None::<WebSearchOptions>))]
995    pub web_search_options: Option<WebSearchOptions>,
996    #[schema(example = json!(Option::None::<Value>))]
997    pub metadata: Option<Value>,
998    #[schema(example = json!(Option::None::<bool>))]
999    pub output_token_details: Option<bool>,
1000    #[schema(example = json!(Option::None::<bool>))]
1001    pub parallel_tool_calls: Option<bool>,
1002    #[schema(example = json!(Option::None::<bool>))]
1003    pub store: Option<bool>,
1004    #[schema(example = json!(Option::None::<usize>))]
1005    pub max_tool_calls: Option<usize>,
1006    #[schema(example = json!(Option::None::<bool>))]
1007    pub reasoning_enabled: Option<bool>,
1008    #[schema(example = json!(Option::None::<usize>))]
1009    pub reasoning_max_tokens: Option<usize>,
1010    #[schema(example = json!(Option::None::<usize>))]
1011    pub reasoning_top_logprobs: Option<usize>,
1012    #[schema(example = json!(Option::None::<Vec<String>>))]
1013    pub truncation: Option<HashMap<String, Value>>,
1014
1015    // mistral.rs additional
1016    #[schema(example = json!(Option::None::<usize>))]
1017    pub top_k: Option<usize>,
1018    #[schema(example = json!(Option::None::<Grammar>))]
1019    pub grammar: Option<Grammar>,
1020    #[schema(example = json!(Option::None::<f64>))]
1021    pub min_p: Option<f64>,
1022    #[schema(example = json!(Option::None::<f32>))]
1023    pub repetition_penalty: Option<f32>,
1024    #[schema(example = json!(Option::None::<f32>))]
1025    pub dry_multiplier: Option<f32>,
1026    #[schema(example = json!(Option::None::<f32>))]
1027    pub dry_base: Option<f32>,
1028    #[schema(example = json!(Option::None::<usize>))]
1029    pub dry_allowed_length: Option<usize>,
1030    #[schema(example = json!(Option::None::<String>))]
1031    pub dry_sequence_breakers: Option<Vec<String>>,
1032    #[schema(example = json!(Option::None::<bool>))]
1033    pub enable_thinking: Option<bool>,
1034    #[schema(example = json!(Option::None::<bool>))]
1035    #[serde(default)]
1036    pub truncate_sequence: Option<bool>,
1037}
1038
/// Response object
///
/// `Option` fields carry no `skip_serializing_if`, so `None` values are
/// serialized as JSON `null` rather than omitted.
#[derive(Debug, Clone, Serialize, Deserialize, ToSchema)]
pub struct ResponsesObject {
    /// Identifier of this response.
    pub id: String,
    /// Object type tag (`&'static str`, so presumably a string literal set by the server).
    pub object: &'static str,
    /// Creation timestamp. NOTE(review): unit (seconds since the Unix epoch?) is not
    /// visible here — confirm against the producer.
    pub created_at: f64,
    /// Name of the model that produced the response.
    pub model: String,
    /// Free-form status string; allowed values are not constrained by this type.
    pub status: String,
    /// Output items of the response.
    pub output: Vec<ResponsesOutput>,
    /// Optional plain-text output. NOTE(review): presumably the text content of
    /// `output` flattened together — confirm.
    pub output_text: Option<String>,
    /// Token usage, when reported.
    pub usage: Option<ResponsesUsage>,
    /// Error details, when the response failed.
    pub error: Option<ResponsesError>,
    /// Arbitrary JSON metadata.
    pub metadata: Option<Value>,
    /// Instructions associated with the response, if any.
    pub instructions: Option<String>,
    /// Reason details when the response is incomplete.
    pub incomplete_details: Option<ResponsesIncompleteDetails>,
}
1055
/// Response usage information
///
/// Token counts are stored as given; `total_tokens` is a separate field rather
/// than computed here. NOTE(review): presumably `input_tokens + output_tokens` —
/// confirm at the construction site.
#[derive(Debug, Clone, Serialize, Deserialize, ToSchema)]
pub struct ResponsesUsage {
    pub input_tokens: usize,
    pub output_tokens: usize,
    pub total_tokens: usize,
    /// Optional per-modality breakdown of `input_tokens`.
    pub input_tokens_details: Option<ResponsesInputTokensDetails>,
    /// Optional breakdown of `output_tokens`.
    pub output_tokens_details: Option<ResponsesOutputTokensDetails>,
}
1065
/// Input tokens details
///
/// Every count is optional; `None` serializes as `null`.
#[derive(Debug, Clone, Serialize, Deserialize, ToSchema)]
pub struct ResponsesInputTokensDetails {
    pub audio_tokens: Option<usize>,
    pub cached_tokens: Option<usize>,
    pub image_tokens: Option<usize>,
    pub text_tokens: Option<usize>,
}
1074
/// Output tokens details
///
/// Every count is optional; `None` serializes as `null`.
#[derive(Debug, Clone, Serialize, Deserialize, ToSchema)]
pub struct ResponsesOutputTokensDetails {
    pub audio_tokens: Option<usize>,
    pub text_tokens: Option<usize>,
    pub reasoning_tokens: Option<usize>,
}
1082
/// Response error
#[derive(Debug, Clone, Serialize, Deserialize, ToSchema)]
pub struct ResponsesError {
    /// Error category; serialized under the JSON key `type`.
    #[serde(rename = "type")]
    pub error_type: String,
    /// Human-readable error message.
    pub message: String,
}
1090
/// Incomplete details for incomplete responses
#[derive(Debug, Clone, Serialize, Deserialize, ToSchema)]
pub struct ResponsesIncompleteDetails {
    /// Free-form reason string; allowed values are not constrained by this type.
    pub reason: String,
}
1096
/// Response output item
#[derive(Debug, Clone, Serialize, Deserialize, ToSchema)]
pub struct ResponsesOutput {
    pub id: String,
    /// Output item kind; serialized under the JSON key `type`.
    #[serde(rename = "type")]
    pub output_type: String,
    /// Author role of this item (free-form string).
    pub role: String,
    /// Optional status string; `None` serializes as `null`.
    pub status: Option<String>,
    /// Content parts of this output item.
    pub content: Vec<ResponsesContent>,
}
1107
/// Response content item
#[derive(Debug, Clone, Serialize, Deserialize, ToSchema)]
pub struct ResponsesContent {
    /// Content kind; serialized under the JSON key `type`.
    #[serde(rename = "type")]
    pub content_type: String,
    /// Text payload, when this content part carries text.
    pub text: Option<String>,
    /// Optional annotations attached to `text`.
    pub annotations: Option<Vec<ResponsesAnnotation>>,
}
1116
/// Response annotation
#[derive(Debug, Clone, Serialize, Deserialize, ToSchema)]
pub struct ResponsesAnnotation {
    /// Annotation kind; serialized under the JSON key `type`.
    #[serde(rename = "type")]
    pub annotation_type: String,
    pub text: String,
    // NOTE(review): presumably offsets into the parent content's `text`; whether they
    // are byte or char offsets, and whether `end_index` is exclusive, is not visible
    // here — confirm with the producer.
    pub start_index: usize,
    pub end_index: usize,
}
1126
/// Response streaming chunk
#[derive(Debug, Clone, Serialize, Deserialize, ToSchema)]
pub struct ResponsesChunk {
    pub id: String,
    /// Object type tag (`&'static str`, so presumably a string literal set by the server).
    pub object: &'static str,
    /// Creation timestamp. NOTE(review): unit not visible here — confirm.
    pub created_at: f64,
    pub model: String,
    // NOTE(review): unlike the sibling `Responses*` types, this field has no
    // `#[serde(rename = "type")]`, so it is serialized under the key `chunk_type`.
    // Confirm this is intended before clients depend on either spelling.
    pub chunk_type: String,
    /// Incremental update carried by this chunk, if any.
    pub delta: Option<ResponsesDelta>,
    /// Token usage, when reported on this chunk.
    pub usage: Option<ResponsesUsage>,
    /// Arbitrary JSON metadata.
    pub metadata: Option<Value>,
}
1139
/// Response delta for streaming
#[derive(Debug, Clone, Serialize, Deserialize, ToSchema)]
pub struct ResponsesDelta {
    /// Output items updated by this delta, if any.
    pub output: Option<Vec<ResponsesDeltaOutput>>,
    /// Optional status string; `None` serializes as `null`.
    pub status: Option<String>,
}
1146
/// Response delta output item
#[derive(Debug, Clone, Serialize, Deserialize, ToSchema)]
pub struct ResponsesDeltaOutput {
    pub id: String,
    /// Output item kind; serialized under the JSON key `type`.
    #[serde(rename = "type")]
    pub output_type: String,
    /// Content parts carried by this delta, if any.
    pub content: Option<Vec<ResponsesDeltaContent>>,
}
1155
/// Response delta content item
#[derive(Debug, Clone, Serialize, Deserialize, ToSchema)]
pub struct ResponsesDeltaContent {
    /// Content kind; serialized under the JSON key `type`.
    #[serde(rename = "type")]
    pub content_type: String,
    /// Text fragment, when this delta carries text.
    pub text: Option<String>,
}