mistralrs_server_core/
openai.rs

//! ## OpenAI compatible functionality.

use std::{collections::HashMap, ops::Deref};

use either::Either;
use mistralrs_core::{
    ImageGenerationResponseFormat, LlguidanceGrammar, Tool, ToolChoice, ToolType, WebSearchOptions,
};
use serde::{Deserialize, Serialize};
use utoipa::{
    openapi::{schema::SchemaType, ArrayBuilder, ObjectBuilder, OneOfBuilder, RefOr, Schema, Type},
    PartialSchema, ToSchema,
};

/// Inner content structure for messages that can be either a string or key-value pairs
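///
/// ### Examples
///
/// A rough sketch of the two JSON shapes this accepts (values are illustrative only):
///
/// ```ignore
/// use mistralrs_server_core::openai::MessageInnerContent;
///
/// // Either::Left - a plain string
/// let text: MessageInnerContent = serde_json::from_str(r#""hello""#).unwrap();
///
/// // Either::Right - an object with string values
/// let object: MessageInnerContent =
///     serde_json::from_str(r#"{"type": "text", "text": "hello"}"#).unwrap();
/// ```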
#[derive(Debug, Clone, Deserialize, Serialize)]
pub struct MessageInnerContent(
    #[serde(with = "either::serde_untagged")] pub Either<String, HashMap<String, String>>,
);

// The `impl Deref` prevented the derived `ToSchema` and `#[schema]` macros from
// working properly, so `ToSchema` is implemented manually.
impl PartialSchema for MessageInnerContent {
    fn schema() -> RefOr<Schema> {
        RefOr::T(message_inner_content_schema())
    }
}

impl ToSchema for MessageInnerContent {
    fn schemas(
        schemas: &mut Vec<(
            String,
            utoipa::openapi::RefOr<utoipa::openapi::schema::Schema>,
        )>,
    ) {
        schemas.push((
            MessageInnerContent::name().into(),
            MessageInnerContent::schema(),
        ));
    }
}

impl Deref for MessageInnerContent {
    type Target = Either<String, HashMap<String, String>>;
    fn deref(&self) -> &Self::Target {
        &self.0
    }
}

/// Function for MessageInnerContent Schema generation to handle `Either`
fn message_inner_content_schema() -> Schema {
    Schema::OneOf(
        OneOfBuilder::new()
            // Either::Left - simple string
            .item(Schema::Object(
                ObjectBuilder::new()
                    .schema_type(SchemaType::Type(Type::String))
                    .build(),
            ))
            // Either::Right - object with string values
            .item(Schema::Object(
                ObjectBuilder::new()
                    .schema_type(SchemaType::Type(Type::Object))
                    .additional_properties(Some(RefOr::T(Schema::Object(
                        ObjectBuilder::new()
                            .schema_type(SchemaType::Type(Type::String))
                            .build(),
                    ))))
                    .build(),
            ))
            .build(),
    )
}

/// Message content that can be either simple text or complex structured content
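///
/// ### Examples
///
/// A rough sketch of the two accepted JSON shapes (values are illustrative only):
///
/// ```ignore
/// use mistralrs_server_core::openai::MessageContent;
///
/// // Either::Left - plain text content
/// let simple: MessageContent = serde_json::from_str(r#""What is graphene?""#).unwrap();
///
/// // Either::Right - structured content parts
/// let structured: MessageContent = serde_json::from_str(
///     r#"[{"type": "text", "text": "Describe this image."}]"#,
/// ).unwrap();
/// ```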
#[derive(Debug, Clone, Deserialize, Serialize)]
pub struct MessageContent(
    #[serde(with = "either::serde_untagged")]
    Either<String, Vec<HashMap<String, MessageInnerContent>>>,
);

// The `impl Deref` prevented the derived `ToSchema` and `#[schema]` macros from
// working properly, so `ToSchema` is implemented manually.
impl PartialSchema for MessageContent {
    fn schema() -> RefOr<Schema> {
        RefOr::T(message_content_schema())
    }
}

impl ToSchema for MessageContent {
    fn schemas(
        schemas: &mut Vec<(
            String,
            utoipa::openapi::RefOr<utoipa::openapi::schema::Schema>,
        )>,
    ) {
        schemas.push((MessageContent::name().into(), MessageContent::schema()));
    }
}

impl Deref for MessageContent {
    type Target = Either<String, Vec<HashMap<String, MessageInnerContent>>>;
    fn deref(&self) -> &Self::Target {
        &self.0
    }
}

/// Function for MessageContent Schema generation to handle `Either`
fn message_content_schema() -> Schema {
    Schema::OneOf(
        OneOfBuilder::new()
            .item(Schema::Object(
                ObjectBuilder::new()
                    .schema_type(SchemaType::Type(Type::String))
                    .build(),
            ))
            .item(Schema::Array(
                ArrayBuilder::new()
                    .items(RefOr::T(Schema::Object(
                        ObjectBuilder::new()
                            .schema_type(SchemaType::Type(Type::Object))
                            .additional_properties(Some(RefOr::Ref(
                                utoipa::openapi::Ref::from_schema_name("MessageInnerContent"),
                            )))
                            .build(),
                    )))
                    .build(),
            ))
            .build(),
    )
}

/// Represents a function call made by the assistant
///
/// When using tool calling, this structure contains the details of a function
/// that the model has decided to call, including the function name and its parameters.
#[derive(Clone, Debug, serde::Deserialize, serde::Serialize, ToSchema)]
pub struct FunctionCalled {
    /// The name of the function to call
    pub name: String,
    /// The function arguments
    #[serde(alias = "arguments")]
    pub parameters: String,
}

/// Represents a tool call made by the assistant
///
/// This structure wraps a function call with its type information.
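///
/// ### Examples
///
/// A rough sketch of deserializing an OpenAI-style tool call (assumes `ToolType`
/// accepts `"function"`; values are illustrative only):
///
/// ```ignore
/// use mistralrs_server_core::openai::ToolCall;
///
/// // `arguments` is accepted as an alias for `parameters`.
/// let call: ToolCall = serde_json::from_str(
///     r#"{"type": "function", "function": {"name": "get_weather", "arguments": "{\"city\": \"Boston\"}"}}"#,
/// ).unwrap();
/// ```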
#[derive(Clone, Debug, serde::Deserialize, serde::Serialize, ToSchema)]
pub struct ToolCall {
    /// The type of tool being called
    #[serde(rename = "type")]
    pub tp: ToolType,
    /// The function call details
    pub function: FunctionCalled,
}

/// Represents a single message in a conversation
///
/// ### Examples
///
/// ```ignore
/// use either::Either;
/// use mistralrs_server_core::openai::{Message, MessageContent};
///
/// // User message
/// let user_msg = Message {
///     content: Some(MessageContent(Either::Left("What's 2+2?".to_string()))),
///     role: "user".to_string(),
///     name: None,
///     tool_calls: None,
/// };
///
/// // System message
/// let system_msg = Message {
///     content: Some(MessageContent(Either::Left("You are a helpful assistant.".to_string()))),
///     role: "system".to_string(),
///     name: None,
///     tool_calls: None,
/// };
/// ```
#[derive(Debug, Clone, Deserialize, Serialize, ToSchema)]
pub struct Message {
    /// The message content
    pub content: Option<MessageContent>,
    /// The role of the message sender ("user", "assistant", "system", "tool", etc.)
    pub role: String,
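    /// Optional name for the message participant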
    pub name: Option<String>,
    /// Optional list of tool calls
    pub tool_calls: Option<Vec<ToolCall>>,
}

/// Stop token configuration for generation
///
/// Defines when the model should stop generating text, either with a single
/// stop token or multiple possible stop sequences.
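///
/// ### Examples
///
/// A rough sketch of the two accepted JSON forms of the `stop` field (values are illustrative only):
///
/// ```ignore
/// use mistralrs_server_core::openai::StopTokens;
///
/// // Single stop sequence
/// let single: StopTokens = serde_json::from_str(r#""</s>""#).unwrap();
///
/// // Multiple stop sequences
/// let multi: StopTokens = serde_json::from_str(r#"["</s>", "\n\n"]"#).unwrap();
/// ```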
#[derive(Debug, Clone, Deserialize, Serialize, ToSchema)]
#[serde(untagged)]
pub enum StopTokens {
    /// Multiple possible stop sequences
    Multi(Vec<String>),
    /// Single stop sequence
    Single(String),
}

/// Default value helper
fn default_false() -> bool {
    false
}

/// Default value helper
fn default_1usize() -> usize {
    1
}

/// Default value helper
fn default_720usize() -> usize {
    720
}

/// Default value helper
fn default_1280usize() -> usize {
    1280
}

/// Default value helper
fn default_model() -> String {
    "default".to_string()
}

/// Default value helper
fn default_response_format() -> ImageGenerationResponseFormat {
    ImageGenerationResponseFormat::Url
}

/// Grammar specification for structured generation
///
/// Defines different types of grammars that can be used to constrain model output,
/// ensuring it follows specific formats or structures.
///
/// ### Examples
///
/// ```ignore
/// use mistralrs_server_core::openai::Grammar;
///
/// // Regex grammar for phone numbers
/// let phone_regex = Grammar::Regex(r"\d{3}-\d{3}-\d{4}".to_string());
///
/// // JSON schema for structured data
/// let json_schema = Grammar::JsonSchema(serde_json::json!({
///     "type": "object",
///     "properties": {
///         "name": {"type": "string"},
///         "age": {"type": "integer"}
///     },
///     "required": ["name", "age"]
/// }));
///
/// // Lark grammar for arithmetic expressions
/// let lark_grammar = Grammar::Lark(r#"
///     ?start: expr
///     expr: term ("+" term | "-" term)*
///     term: factor ("*" factor | "/" factor)*
///     factor: NUMBER | "(" expr ")"
///     %import common.NUMBER
/// "#.to_string());
/// ```
#[derive(Debug, Clone, Deserialize, Serialize)]
#[serde(tag = "type", content = "value")]
pub enum Grammar {
    /// Regular expression grammar
    #[serde(rename = "regex")]
    Regex(String),
    /// JSON schema grammar
    #[serde(rename = "json_schema")]
    JsonSchema(serde_json::Value),
    /// LLGuidance grammar
    #[serde(rename = "llguidance")]
    Llguidance(LlguidanceGrammar),
    /// Lark parser grammar
    #[serde(rename = "lark")]
    Lark(String),
}

// Implement ToSchema manually to handle `LlguidanceGrammar`
impl PartialSchema for Grammar {
    fn schema() -> RefOr<Schema> {
        RefOr::T(Schema::OneOf(
            OneOfBuilder::new()
                .item(create_grammar_variant_schema(
                    "regex",
                    Schema::Object(
                        ObjectBuilder::new()
                            .schema_type(SchemaType::Type(Type::String))
                            .build(),
                    ),
                ))
                .item(create_grammar_variant_schema(
                    "json_schema",
                    Schema::Object(
                        ObjectBuilder::new()
                            .schema_type(SchemaType::Type(Type::Object))
                            .build(),
                    ),
                ))
                .item(create_grammar_variant_schema(
                    "llguidance",
                    llguidance_schema(),
                ))
                .item(create_grammar_variant_schema(
                    "lark",
                    Schema::Object(
                        ObjectBuilder::new()
                            .schema_type(SchemaType::Type(Type::String))
                            .build(),
                    ),
                ))
                .build(),
        ))
    }
}

impl ToSchema for Grammar {
    fn schemas(
        schemas: &mut Vec<(
            String,
            utoipa::openapi::RefOr<utoipa::openapi::schema::Schema>,
        )>,
    ) {
        schemas.push((Grammar::name().into(), Grammar::schema()));
    }
}

/// Helper function to create a grammar variant schema
fn create_grammar_variant_schema(type_value: &str, value_schema: Schema) -> Schema {
    Schema::Object(
        ObjectBuilder::new()
            .schema_type(SchemaType::Type(Type::Object))
            .property(
                "type",
                RefOr::T(Schema::Object(
                    ObjectBuilder::new()
                        .schema_type(SchemaType::Type(Type::String))
                        .enum_values(Some(vec![serde_json::Value::String(
                            type_value.to_string(),
                        )]))
                        .build(),
                )),
            )
            .property("value", RefOr::T(value_schema))
            .required("type")
            .required("value")
            .build(),
    )
}

/// Helper function to generate LLGuidance schema
fn llguidance_schema() -> Schema {
    let grammar_with_lexer_schema = Schema::Object(
        ObjectBuilder::new()
            .schema_type(SchemaType::Type(Type::Object))
            .property(
                "name",
                RefOr::T(Schema::Object(
                    ObjectBuilder::new()
                        .schema_type(SchemaType::from_iter([Type::String, Type::Null]))
                        .description(Some(
                            "The name of this grammar, can be used in GenGrammar nodes",
                        ))
                        .build(),
                )),
            )
            .property(
                "json_schema",
                RefOr::T(Schema::Object(
                    ObjectBuilder::new()
                        .schema_type(SchemaType::from_iter([Type::Object, Type::Null]))
                        .description(Some("The JSON schema that the grammar should generate"))
                        .build(),
                )),
            )
            .property(
                "lark_grammar",
                RefOr::T(Schema::Object(
                    ObjectBuilder::new()
                        .schema_type(SchemaType::from_iter([Type::String, Type::Null]))
                        .description(Some("The Lark grammar that the grammar should generate"))
                        .build(),
                )),
            )
            .description(Some("Grammar configuration with lexer settings"))
            .build(),
    );

    Schema::Object(
        ObjectBuilder::new()
            .schema_type(SchemaType::Type(Type::Object))
            .property(
                "grammars",
                RefOr::T(Schema::Array(
                    ArrayBuilder::new()
                        .items(RefOr::T(grammar_with_lexer_schema))
                        .description(Some("List of grammar configurations"))
                        .build(),
                )),
            )
            .property(
                "max_tokens",
                RefOr::T(Schema::Object(
                    ObjectBuilder::new()
                        .schema_type(SchemaType::from_iter([Type::Integer, Type::Null]))
                        .description(Some("Maximum number of tokens to generate"))
                        .build(),
                )),
            )
            .required("grammars")
            .description(Some("Top-level grammar configuration for LLGuidance"))
            .build(),
    )
}

/// JSON Schema for structured responses
#[derive(Debug, Clone, Deserialize, Serialize, ToSchema)]
pub struct JsonSchemaResponseFormat {
    pub name: String,
    pub schema: serde_json::Value,
}

/// Response format for model output
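///
/// ### Examples
///
/// A rough sketch of the accepted JSON forms (values are illustrative only):
///
/// ```ignore
/// use mistralrs_server_core::openai::ResponseFormat;
///
/// // Free-form text
/// let text: ResponseFormat = serde_json::from_str(r#"{"type": "text"}"#).unwrap();
///
/// // Structured output constrained by a JSON schema
/// let structured: ResponseFormat = serde_json::from_str(
///     r#"{"type": "json_schema", "json_schema": {"name": "person", "schema": {"type": "object"}}}"#,
/// ).unwrap();
/// ```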
#[derive(Debug, Clone, Deserialize, Serialize, ToSchema)]
#[serde(tag = "type")]
pub enum ResponseFormat {
    /// Free-form text response
    #[serde(rename = "text")]
    Text,
    /// Structured response following a JSON schema
    #[serde(rename = "json_schema")]
    JsonSchema {
        json_schema: JsonSchemaResponseFormat,
    },
}

/// Chat completion request following OpenAI's specification
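///
/// ### Examples
///
/// A rough sketch of deserializing a minimal request body; omitted fields fall back to
/// their defaults or `None` (values are illustrative only):
///
/// ```ignore
/// use mistralrs_server_core::openai::ChatCompletionRequest;
///
/// let request: ChatCompletionRequest = serde_json::from_str(
///     r#"{
///         "model": "mistral",
///         "messages": [{"role": "user", "content": "Why did the crab cross the road?"}],
///         "max_tokens": 256,
///         "temperature": 0.7
///     }"#,
/// ).unwrap();
/// ```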
#[derive(Debug, Clone, Deserialize, Serialize, ToSchema)]
pub struct ChatCompletionRequest {
    #[schema(
        schema_with = messages_schema,
        example = json!(vec![Message{content:Some(MessageContent{0: either::Left(("Why did the crab cross the road?".to_string()))}), role:"user".to_string(), name: None, tool_calls: None}])
    )]
    #[serde(with = "either::serde_untagged")]
    pub messages: Either<Vec<Message>, String>,
    #[schema(example = "mistral")]
    #[serde(default = "default_model")]
    pub model: String,
    #[schema(example = json!(Option::None::<HashMap<u32, f32>>))]
    pub logit_bias: Option<HashMap<u32, f32>>,
    #[serde(default = "default_false")]
    #[schema(example = false)]
    pub logprobs: bool,
    #[schema(example = json!(Option::None::<usize>))]
    pub top_logprobs: Option<usize>,
    #[schema(example = 256)]
    #[serde(alias = "max_completion_tokens")]
    pub max_tokens: Option<usize>,
    #[serde(rename = "n")]
    #[serde(default = "default_1usize")]
    #[schema(example = 1)]
    pub n_choices: usize,
    #[schema(example = json!(Option::None::<f32>))]
    pub presence_penalty: Option<f32>,
    #[schema(example = json!(Option::None::<f32>))]
    pub frequency_penalty: Option<f32>,
    #[serde(rename = "stop")]
    #[schema(example = json!(Option::None::<StopTokens>))]
    pub stop_seqs: Option<StopTokens>,
    #[schema(example = 0.7)]
    pub temperature: Option<f64>,
    #[schema(example = json!(Option::None::<f64>))]
    pub top_p: Option<f64>,
    #[schema(example = true)]
    pub stream: Option<bool>,
    #[schema(example = json!(Option::None::<Vec<Tool>>))]
    pub tools: Option<Vec<Tool>>,
    #[schema(example = json!(Option::None::<ToolChoice>))]
    pub tool_choice: Option<ToolChoice>,
    #[schema(example = json!(Option::None::<ResponseFormat>))]
    pub response_format: Option<ResponseFormat>,
    #[schema(example = json!(Option::None::<WebSearchOptions>))]
    pub web_search_options: Option<WebSearchOptions>,

    // mistral.rs additional
    #[schema(example = json!(Option::None::<usize>))]
    pub top_k: Option<usize>,
    #[schema(example = json!(Option::None::<Grammar>))]
    pub grammar: Option<Grammar>,
    #[schema(example = json!(Option::None::<f64>))]
    pub min_p: Option<f64>,
    #[schema(example = json!(Option::None::<f32>))]
    pub dry_multiplier: Option<f32>,
    #[schema(example = json!(Option::None::<f32>))]
    pub dry_base: Option<f32>,
    #[schema(example = json!(Option::None::<usize>))]
    pub dry_allowed_length: Option<usize>,
    #[schema(example = json!(Option::None::<String>))]
    pub dry_sequence_breakers: Option<Vec<String>>,
    #[schema(example = json!(Option::None::<bool>))]
    pub enable_thinking: Option<bool>,
}

/// Function for ChatCompletionRequest.messages Schema generation to handle `Either`
fn messages_schema() -> Schema {
    Schema::OneOf(
        OneOfBuilder::new()
            .item(Schema::Array(
                ArrayBuilder::new()
                    .items(RefOr::Ref(utoipa::openapi::Ref::from_schema_name(
                        "Message",
                    )))
                    .build(),
            ))
            .item(Schema::Object(
                ObjectBuilder::new()
                    .schema_type(SchemaType::Type(Type::String))
                    .build(),
            ))
            .build(),
    )
}

/// Metadata about an available model
#[derive(Debug, Serialize, ToSchema)]
pub struct ModelObject {
    pub id: String,
    pub object: &'static str,
    pub created: u64,
    pub owned_by: &'static str,
}

/// Collection of available models
#[derive(Debug, Serialize, ToSchema)]
pub struct ModelObjects {
    pub object: &'static str,
    pub data: Vec<ModelObject>,
}

/// Legacy OpenAI compatible text completion request
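///
/// ### Examples
///
/// A rough sketch of deserializing a minimal request body; omitted fields fall back to
/// their defaults or `None` (values are illustrative only):
///
/// ```ignore
/// use mistralrs_server_core::openai::CompletionRequest;
///
/// let request: CompletionRequest = serde_json::from_str(
///     r#"{
///         "model": "mistral",
///         "prompt": "Say this is a test.",
///         "max_tokens": 16,
///         "temperature": 0.7
///     }"#,
/// ).unwrap();
/// ```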
#[derive(Debug, Clone, Deserialize, Serialize, ToSchema)]
pub struct CompletionRequest {
    #[schema(example = "mistral")]
    #[serde(default = "default_model")]
    pub model: String,
    #[schema(example = "Say this is a test.")]
    pub prompt: String,
    #[schema(example = 1)]
    pub best_of: Option<usize>,
    #[serde(rename = "echo")]
    #[serde(default = "default_false")]
    #[schema(example = false)]
    pub echo_prompt: bool,
    #[schema(example = json!(Option::None::<f32>))]
    pub presence_penalty: Option<f32>,
    #[schema(example = json!(Option::None::<f32>))]
    pub frequency_penalty: Option<f32>,
    #[schema(example = json!(Option::None::<HashMap<u32, f32>>))]
    pub logit_bias: Option<HashMap<u32, f32>>,
    #[schema(example = json!(Option::None::<usize>))]
    pub logprobs: Option<usize>,
    #[schema(example = 16)]
    pub max_tokens: Option<usize>,
    #[serde(rename = "n")]
    #[serde(default = "default_1usize")]
    #[schema(example = 1)]
    pub n_choices: usize,
    #[serde(rename = "stop")]
    #[schema(example = json!(Option::None::<StopTokens>))]
    pub stop_seqs: Option<StopTokens>,
    pub stream: Option<bool>,
    #[schema(example = 0.7)]
    pub temperature: Option<f64>,
    #[schema(example = json!(Option::None::<f64>))]
    pub top_p: Option<f64>,
    #[schema(example = json!(Option::None::<String>))]
    pub suffix: Option<String>,
    #[serde(rename = "user")]
    pub _user: Option<String>,
    #[schema(example = json!(Option::None::<Vec<Tool>>))]
    pub tools: Option<Vec<Tool>>,
    #[schema(example = json!(Option::None::<ToolChoice>))]
    pub tool_choice: Option<ToolChoice>,

    // mistral.rs additional
    #[schema(example = json!(Option::None::<usize>))]
    pub top_k: Option<usize>,
    #[schema(example = json!(Option::None::<Grammar>))]
    pub grammar: Option<Grammar>,
    #[schema(example = json!(Option::None::<f64>))]
    pub min_p: Option<f64>,
    #[schema(example = json!(Option::None::<f32>))]
    pub dry_multiplier: Option<f32>,
    #[schema(example = json!(Option::None::<f32>))]
    pub dry_base: Option<f32>,
    #[schema(example = json!(Option::None::<usize>))]
    pub dry_allowed_length: Option<usize>,
    #[schema(example = json!(Option::None::<String>))]
    pub dry_sequence_breakers: Option<Vec<String>>,
}

/// Image generation request
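///
/// ### Examples
///
/// A rough sketch of deserializing a request body; `response_format`, `height`, and
/// `width` fall back to their defaults when omitted (values are illustrative only):
///
/// ```ignore
/// use mistralrs_server_core::openai::ImageGenerationRequest;
///
/// let request: ImageGenerationRequest = serde_json::from_str(
///     r#"{
///         "model": "mistral",
///         "prompt": "Draw a picture of a majestic, snow-covered mountain.",
///         "n": 1
///     }"#,
/// ).unwrap();
/// ```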
#[derive(Debug, Clone, Deserialize, Serialize, ToSchema)]
pub struct ImageGenerationRequest {
    #[schema(example = "mistral")]
    #[serde(default = "default_model")]
    pub model: String,
    #[schema(example = "Draw a picture of a majestic, snow-covered mountain.")]
    pub prompt: String,
    #[serde(rename = "n")]
    #[serde(default = "default_1usize")]
    #[schema(example = 1)]
    pub n_choices: usize,
    #[serde(default = "default_response_format")]
    pub response_format: ImageGenerationResponseFormat,
    #[serde(default = "default_720usize")]
    #[schema(example = 720)]
    pub height: usize,
    #[serde(default = "default_1280usize")]
    #[schema(example = 1280)]
    pub width: usize,
}

/// Audio format options for speech generation responses.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize, Default, ToSchema)]
#[serde(rename_all = "lowercase")]
pub enum AudioResponseFormat {
    /// Widely compatible, lossy compression, good for web streaming
    #[default]
    Mp3,
    /// Good compression efficiency, ideal for real-time communication
    Opus,
    /// High-quality lossy compression, commonly used in mobile applications
    Aac,
    /// Lossless compression, larger file sizes but good audio quality
    Flac,
    /// Uncompressed, largest file sizes but maximum compatibility
    Wav,
    /// Raw audio data, requires additional format specification
    Pcm,
}

impl AudioResponseFormat {
    /// Generate the appropriate MIME content type string for this audio format.
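    ///
    /// ### Examples
    ///
    /// A rough sketch of the strings this produces (argument values are illustrative);
    /// note that the rate and channel parameters are appended for every format:
    ///
    /// ```ignore
    /// use mistralrs_server_core::openai::AudioResponseFormat;
    ///
    /// let mp3 = AudioResponseFormat::Mp3.audio_content_type(44100, 2, "s16le");
    /// assert_eq!(mp3, "audio/mpeg; rate=44100; channels=2");
    ///
    /// let pcm = AudioResponseFormat::Pcm.audio_content_type(24000, 1, "s16le");
    /// assert_eq!(pcm, "audio/pcm; codecs=1; format=s16le; rate=24000; channels=1");
    /// ```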
    pub fn audio_content_type(
        &self,
        pcm_rate: usize,
        pcm_channels: usize,
        pcm_format: &'static str,
    ) -> String {
        let content_type = match &self {
            AudioResponseFormat::Mp3 => "audio/mpeg".to_string(),
            AudioResponseFormat::Opus => "audio/ogg; codecs=opus".to_string(),
            AudioResponseFormat::Aac => "audio/aac".to_string(),
            AudioResponseFormat::Flac => "audio/flac".to_string(),
            AudioResponseFormat::Wav => "audio/wav".to_string(),
            AudioResponseFormat::Pcm => format!("audio/pcm; codecs=1; format={pcm_format}"),
        };

        format!("{content_type}; rate={pcm_rate}; channels={pcm_channels}")
    }
}

/// Speech generation request
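///
/// ### Examples
///
/// A rough sketch of deserializing a request body (values are illustrative only):
///
/// ```ignore
/// use mistralrs_server_core::openai::SpeechGenerationRequest;
///
/// let request: SpeechGenerationRequest = serde_json::from_str(
///     r#"{
///         "model": "nari-labs/Dia-1.6B",
///         "input": "[S1] Dia is an open weights text to dialogue model.",
///         "response_format": "mp3"
///     }"#,
/// ).unwrap();
/// ```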
#[derive(Debug, Clone, Deserialize, Serialize, ToSchema)]
pub struct SpeechGenerationRequest {
    /// The TTS model to use for audio generation.
    #[schema(example = "nari-labs/Dia-1.6B")]
    #[serde(default = "default_model")]
    pub model: String,
    /// The text content to convert to speech.
    #[schema(
        example = "[S1] Dia is an open weights text to dialogue model. [S2] You get full control over scripts and voices. [S1] Wow. Amazing. (laughs) [S2] Try it now on Git hub or Hugging Face."
    )]
    pub input: String,
    // `voice` and `instructions` are ignored.
    /// The desired audio format for the generated speech.
    #[schema(example = "mp3")]
    pub response_format: AudioResponseFormat,
}