mistralrs_server/
openai.rs

1use either::Either;
2use mistralrs_core::{
3    ImageGenerationResponseFormat, LlguidanceGrammar, Tool, ToolChoice, ToolType, WebSearchOptions,
4};
5use serde::{Deserialize, Serialize};
6use std::{collections::HashMap, ops::Deref};
7use utoipa::ToSchema;
8
9#[derive(Debug, Clone, Deserialize, Serialize, ToSchema)]
10pub struct MessageInnerContent(
11    #[serde(with = "either::serde_untagged")] pub Either<String, HashMap<String, String>>,
12);
13
14impl Deref for MessageInnerContent {
15    type Target = Either<String, HashMap<String, String>>;
16    fn deref(&self) -> &Self::Target {
17        &self.0
18    }
19}
20
21#[derive(Debug, Clone, Deserialize, Serialize, ToSchema)]
22pub struct MessageContent(
23    #[serde(with = "either::serde_untagged")]
24    Either<String, Vec<HashMap<String, MessageInnerContent>>>,
25);
26
27impl Deref for MessageContent {
28    type Target = Either<String, Vec<HashMap<String, MessageInnerContent>>>;
29    fn deref(&self) -> &Self::Target {
30        &self.0
31    }
32}
33
34#[derive(Clone, Debug, serde::Deserialize, serde::Serialize)]
35pub struct FunctionCalled {
36    pub name: String,
37    #[serde(alias = "arguments")]
38    pub parameters: String,
39}
40
41#[derive(Clone, Debug, serde::Deserialize, serde::Serialize)]
42pub struct ToolCall {
43    #[serde(rename = "type")]
44    pub tp: ToolType,
45    pub function: FunctionCalled,
46}
47
48#[derive(Debug, Clone, Deserialize, Serialize, ToSchema)]
49pub struct Message {
50    pub content: Option<MessageContent>,
51    pub role: String,
52    pub name: Option<String>,
53    pub tool_calls: Option<Vec<ToolCall>>,
54}
55
56#[derive(Debug, Clone, Deserialize, Serialize, ToSchema)]
57#[serde(untagged)]
58pub enum StopTokens {
59    Multi(Vec<String>),
60    Single(String),
61}
62
/// Serde default helper: yields `false` when a boolean field is absent.
fn default_false() -> bool {
    // `bool::default()` is `false`.
    bool::default()
}
66
/// Serde default helper: yields `1` (used for the `n` fields of the request types).
fn default_1usize() -> usize {
    const FALLBACK: usize = 1;
    FALLBACK
}
70
/// Serde default helper: yields `720`.
fn default_720usize() -> usize {
    const FALLBACK: usize = 720;
    FALLBACK
}
74
/// Serde default helper: yields `1280`.
fn default_1280usize() -> usize {
    const FALLBACK: usize = 1280;
    FALLBACK
}
78
/// Serde default helper: the model name used when a request omits `model`.
fn default_model() -> String {
    String::from("default")
}
82
83fn default_response_format() -> ImageGenerationResponseFormat {
84    ImageGenerationResponseFormat::Url
85}
86
87#[derive(Debug, Clone, Deserialize, Serialize, ToSchema)]
88#[serde(tag = "type", content = "value")]
89pub enum Grammar {
90    #[serde(rename = "regex")]
91    Regex(String),
92    #[serde(rename = "json_schema")]
93    JsonSchema(serde_json::Value),
94    #[serde(rename = "llguidance")]
95    Llguidance(LlguidanceGrammar),
96    #[serde(rename = "lark")]
97    Lark(String),
98}
99
100#[derive(Debug, Clone, Deserialize, Serialize, ToSchema)]
101pub struct JsonSchemaResponseFormat {
102    pub name: String,
103    pub schema: serde_json::Value,
104}
105
106#[derive(Debug, Clone, Deserialize, Serialize, ToSchema)]
107#[serde(tag = "type")]
108pub enum ResponseFormat {
109    #[serde(rename = "text")]
110    Text,
111    #[serde(rename = "json_schema")]
112    JsonSchema {
113        json_schema: JsonSchemaResponseFormat,
114    },
115}
116
117#[derive(Debug, Clone, Deserialize, Serialize, ToSchema)]
118pub struct ChatCompletionRequest {
119    #[schema(example = json!(vec![Message{content:"Why did the crab cross the road?".to_string(), role:"user".to_string(), name: None}]))]
120    #[serde(with = "either::serde_untagged")]
121    pub messages: Either<Vec<Message>, String>,
122    #[schema(example = "mistral")]
123    #[serde(default = "default_model")]
124    pub model: String,
125    #[schema(example = json!(Option::None::<HashMap<u32, f32>>))]
126    pub logit_bias: Option<HashMap<u32, f32>>,
127    #[serde(default = "default_false")]
128    #[schema(example = false)]
129    pub logprobs: bool,
130    #[schema(example = json!(Option::None::<usize>))]
131    pub top_logprobs: Option<usize>,
132    #[schema(example = 256)]
133    #[serde(alias = "max_completion_tokens")]
134    pub max_tokens: Option<usize>,
135    #[serde(rename = "n")]
136    #[serde(default = "default_1usize")]
137    #[schema(example = 1)]
138    pub n_choices: usize,
139    #[schema(example = json!(Option::None::<f32>))]
140    pub presence_penalty: Option<f32>,
141    #[schema(example = json!(Option::None::<f32>))]
142    pub frequency_penalty: Option<f32>,
143    #[serde(rename = "stop")]
144    #[schema(example = json!(Option::None::<StopTokens>))]
145    pub stop_seqs: Option<StopTokens>,
146    #[schema(example = 0.7)]
147    pub temperature: Option<f64>,
148    #[schema(example = json!(Option::None::<f64>))]
149    pub top_p: Option<f64>,
150    #[schema(example = true)]
151    pub stream: Option<bool>,
152    #[schema(example = json!(Option::None::<Vec<Tool>>))]
153    pub tools: Option<Vec<Tool>>,
154    #[schema(example = json!(Option::None::<ToolChoice>))]
155    pub tool_choice: Option<ToolChoice>,
156    #[schema(example = json!(Option::None::<ResponseFormat>))]
157    pub response_format: Option<ResponseFormat>,
158    #[schema(example = json!(Option::None::<WebSearchOptions>))]
159    pub web_search_options: Option<WebSearchOptions>,
160
161    // mistral.rs additional
162    #[schema(example = json!(Option::None::<usize>))]
163    pub top_k: Option<usize>,
164    #[schema(example = json!(Option::None::<Grammar>))]
165    pub grammar: Option<Grammar>,
166    #[schema(example = json!(Option::None::<f64>))]
167    pub min_p: Option<f64>,
168    #[schema(example = json!(Option::None::<f32>))]
169    pub dry_multiplier: Option<f32>,
170    #[schema(example = json!(Option::None::<f32>))]
171    pub dry_base: Option<f32>,
172    #[schema(example = json!(Option::None::<usize>))]
173    pub dry_allowed_length: Option<usize>,
174    #[schema(example = json!(Option::None::<String>))]
175    pub dry_sequence_breakers: Option<Vec<String>>,
176}
177
178#[derive(Debug, Serialize, ToSchema)]
179pub struct ModelObject {
180    pub id: String,
181    pub object: &'static str,
182    pub created: u64,
183    pub owned_by: &'static str,
184}
185
186#[derive(Debug, Serialize, ToSchema)]
187pub struct ModelObjects {
188    pub object: &'static str,
189    pub data: Vec<ModelObject>,
190}
191
192#[derive(Debug, Clone, Deserialize, Serialize, ToSchema)]
193pub struct CompletionRequest {
194    #[schema(example = "mistral")]
195    #[serde(default = "default_model")]
196    pub model: String,
197    #[schema(example = "Say this is a test.")]
198    pub prompt: String,
199    #[schema(example = 1)]
200    pub best_of: Option<usize>,
201    #[serde(rename = "echo")]
202    #[serde(default = "default_false")]
203    #[schema(example = false)]
204    pub echo_prompt: bool,
205    #[schema(example = json!(Option::None::<f32>))]
206    pub presence_penalty: Option<f32>,
207    #[schema(example = json!(Option::None::<f32>))]
208    pub frequency_penalty: Option<f32>,
209    #[schema(example = json!(Option::None::<HashMap<u32, f32>>))]
210    pub logit_bias: Option<HashMap<u32, f32>>,
211    #[schema(example = json!(Option::None::<usize>))]
212    pub logprobs: Option<usize>,
213    #[schema(example = 16)]
214    pub max_tokens: Option<usize>,
215    #[serde(rename = "n")]
216    #[serde(default = "default_1usize")]
217    #[schema(example = 1)]
218    pub n_choices: usize,
219    #[serde(rename = "stop")]
220    #[schema(example = json!(Option::None::<StopTokens>))]
221    pub stop_seqs: Option<StopTokens>,
222    pub stream: Option<bool>,
223    #[schema(example = 0.7)]
224    pub temperature: Option<f64>,
225    #[schema(example = json!(Option::None::<f64>))]
226    pub top_p: Option<f64>,
227    #[schema(example = json!(Option::None::<String>))]
228    pub suffix: Option<String>,
229    #[serde(rename = "user")]
230    pub _user: Option<String>,
231    #[schema(example = json!(Option::None::<Vec<Tool>>))]
232    pub tools: Option<Vec<Tool>>,
233    #[schema(example = json!(Option::None::<ToolChoice>))]
234    pub tool_choice: Option<ToolChoice>,
235
236    // mistral.rs additional
237    #[schema(example = json!(Option::None::<usize>))]
238    pub top_k: Option<usize>,
239    #[schema(example = json!(Option::None::<Grammar>))]
240    pub grammar: Option<Grammar>,
241    #[schema(example = json!(Option::None::<f64>))]
242    pub min_p: Option<f64>,
243    #[schema(example = json!(Option::None::<f32>))]
244    pub dry_multiplier: Option<f32>,
245    #[schema(example = json!(Option::None::<f32>))]
246    pub dry_base: Option<f32>,
247    #[schema(example = json!(Option::None::<usize>))]
248    pub dry_allowed_length: Option<usize>,
249    #[schema(example = json!(Option::None::<String>))]
250    pub dry_sequence_breakers: Option<Vec<String>>,
251}
252
253#[derive(Debug, Clone, Deserialize, Serialize, ToSchema)]
254pub struct ImageGenerationRequest {
255    #[schema(example = "mistral")]
256    #[serde(default = "default_model")]
257    pub model: String,
258    #[schema(example = "Draw a picture of a majestic, snow-covered mountain.")]
259    pub prompt: String,
260    #[serde(rename = "n")]
261    #[serde(default = "default_1usize")]
262    #[schema(example = 1)]
263    pub n_choices: usize,
264    #[serde(default = "default_response_format")]
265    pub response_format: ImageGenerationResponseFormat,
266    #[serde(default = "default_720usize")]
267    #[schema(example = 720)]
268    pub height: usize,
269    #[serde(default = "default_1280usize")]
270    #[schema(example = 1280)]
271    pub width: usize,
272}