1use either::Either;
2use indexmap::IndexMap;
3use mistralrs_audio::AudioInput;
4use mistralrs_quant::IsqType;
5use serde::{Deserialize, Serialize};
6use serde_json::Value;
7
8use crate::{
9 response::Response, sampler::SamplingParams, tools::ToolChoice, CustomLogitsProcessor,
10 DiffusionGenerationParams, Tool,
11};
12use std::{fmt::Debug, sync::Arc};
13use tokio::sync::mpsc::Sender;
14
/// Grammar type consumed by the llguidance constraint backend
/// (re-exported alias of `llguidance::api::TopLevelGrammar`).
pub type LlguidanceGrammar = llguidance::api::TopLevelGrammar;
16
/// Constraint applied to token generation for a request.
#[derive(Clone, Serialize, Deserialize)]
pub enum Constraint {
    /// Constrain output to match a regular expression (pattern source text).
    Regex(String),
    /// Constrain output with a Lark grammar (grammar source text).
    Lark(String),
    /// Constrain output to conform to a JSON schema.
    JsonSchema(serde_json::Value),
    /// Constrain output with a full llguidance grammar.
    Llguidance(LlguidanceGrammar),
    /// No constraint: free-form generation.
    None,
}
26
/// How a generated image is returned to the caller.
#[derive(Clone, Copy, Debug, Serialize, Deserialize, PartialEq)]
#[cfg_attr(feature = "pyo3_macros", pyo3::pyclass(eq, eq_int))]
#[cfg_attr(feature = "utoipa", derive(utoipa::ToSchema))]
pub enum ImageGenerationResponseFormat {
    /// Return a URL referencing the image.
    Url,
    /// Return the image data base64-encoded in the JSON response.
    B64Json,
}
35
/// Content of a single chat message: either a plain string, or a list of
/// structured content parts (each an ordered key/value map, e.g. text or
/// image parts).
pub type MessageContent = Either<String, Vec<IndexMap<String, Value>>>;
37
/// Requested reasoning-effort level for models that support it.
/// Serialized in lowercase ("low"/"medium"/"high"); defaults to `Medium`.
#[derive(Clone, Copy, Debug, Serialize, Deserialize, PartialEq, Eq, Default)]
#[cfg_attr(feature = "pyo3_macros", pyo3::pyclass(eq, eq_int))]
#[cfg_attr(feature = "utoipa", derive(utoipa::ToSchema))]
#[serde(rename_all = "lowercase")]
pub enum ReasoningEffort {
    Low,
    /// Default effort level.
    #[default]
    Medium,
    High,
}
53
54impl ReasoningEffort {
55 pub fn as_str(&self) -> &'static str {
57 match self {
58 Self::Low => "low",
59 Self::Medium => "medium",
60 Self::High => "high",
61 }
62 }
63}
64
/// The payload of a request: what the engine is being asked to produce.
#[derive(Clone, Debug, Serialize, Deserialize)]
pub enum RequestMessage {
    /// Text chat over a list of role/content message maps.
    Chat {
        messages: Vec<IndexMap<String, MessageContent>>,
        /// Enable "thinking" output where the model/template supports it.
        enable_thinking: Option<bool>,
        /// Requested reasoning effort, if any.
        reasoning_effort: Option<ReasoningEffort>,
    },
    /// Raw text completion (non-chat).
    Completion {
        text: String,
        /// NOTE(review): presumably echoes the prompt back in the output,
        /// as in OpenAI's completions API — confirm in the handler.
        echo_prompt: bool,
        best_of: Option<usize>,
    },
    /// Completion from pre-tokenized input (token ids).
    CompletionTokens(Vec<u32>),
    /// Chat with attached media for vision/audio models.
    VisionChat {
        // Decoded media are not serializable; serde skips them and they
        // default to empty vectors on deserialize.
        #[serde(skip)] images: Vec<image::DynamicImage>,
        #[serde(skip)] audios: Vec<AudioInput>,
        messages: Vec<IndexMap<String, MessageContent>>,
        /// Enable "thinking" output where the model/template supports it.
        enable_thinking: Option<bool>,
        /// Requested reasoning effort, if any.
        reasoning_effort: Option<ReasoningEffort>,
    },
    /// Image generation via a diffusion model.
    ImageGeneration {
        prompt: String,
        format: ImageGenerationResponseFormat,
        generation_params: DiffusionGenerationParams,
    },
    /// Speech (audio) generation from a text prompt.
    SpeechGeneration {
        prompt: String,
    },
    /// Embedding of a text prompt.
    Embedding {
        prompt: String,
    },
    /// Embedding of pre-tokenized input (token ids).
    EmbeddingTokens {
        prompt: Vec<u32>,
    },
}
105
106fn default_responder<T>() -> Sender<T> {
107 let (sender, _) = tokio::sync::mpsc::channel(1);
108 sender
109}
110
/// How much context the web-search machinery should gather.
/// Serialized as "low"/"medium"/"high"; defaults to `Medium`.
#[cfg_attr(feature = "pyo3_macros", pyo3::pyclass(eq, eq_int))]
#[cfg_attr(feature = "utoipa", derive(utoipa::ToSchema))]
#[derive(Clone, Copy, Debug, Serialize, Deserialize, PartialEq, Default)]
pub enum SearchContextSize {
    #[serde(rename = "low")]
    Low,
    /// Default context size.
    #[default]
    #[serde(rename = "medium")]
    Medium,
    #[serde(rename = "high")]
    High,
}
123
/// Approximate user location used to localize web-search results.
/// NOTE(review): field formats (e.g. ISO country codes, IANA timezone
/// names) are not validated here — confirm expectations with consumers.
#[cfg_attr(feature = "pyo3_macros", pyo3::pyclass(eq))]
#[cfg_attr(feature = "utoipa", derive(utoipa::ToSchema))]
#[derive(Clone, Debug, Serialize, Deserialize, PartialEq)]
pub struct ApproximateUserLocation {
    pub city: String,
    pub country: String,
    pub region: String,
    pub timezone: String,
}
133
/// User location for web search. Internally tagged: serializes as
/// `{"type": "approximate", "approximate": { ... }}`.
#[cfg_attr(feature = "pyo3_macros", pyo3::pyclass(eq))]
#[cfg_attr(feature = "utoipa", derive(utoipa::ToSchema))]
#[derive(Clone, Debug, Serialize, Deserialize, PartialEq)]
#[serde(tag = "type")]
pub enum WebSearchUserLocation {
    /// Only approximate locations are currently representable.
    #[serde(rename = "approximate")]
    Approximate {
        approximate: ApproximateUserLocation,
    },
}
144
/// Per-request options controlling web search. All fields optional;
/// `None` means "use the engine default".
#[cfg_attr(feature = "pyo3_macros", pyo3::pyclass(eq))]
#[cfg_attr(feature = "utoipa", derive(utoipa::ToSchema))]
#[derive(Clone, Debug, Serialize, Deserialize, PartialEq, Default)]
pub struct WebSearchOptions {
    /// How much search context to gather.
    pub search_context_size: Option<SearchContextSize>,
    /// Optional user location used to localize results.
    pub user_location: Option<WebSearchUserLocation>,
    /// Override description for the search tool. NOTE(review): presumably
    /// substituted into the tool schema shown to the model — confirm.
    pub search_description: Option<String>,
    /// Override description for the extraction tool. NOTE(review):
    /// presumably substituted into the tool schema shown to the model —
    /// confirm.
    pub extract_description: Option<String>,
}
156
/// A complete generation request submitted to the engine.
#[derive(Clone, Serialize, Deserialize)]
pub struct NormalRequest {
    /// What to generate (chat, completion, image, speech, embedding, ...).
    pub messages: RequestMessage,
    /// Sampling configuration for this request.
    pub sampling_params: SamplingParams,
    /// Channel on which `Response`s are sent back to the caller. Not
    /// serialized; deserialization substitutes a dummy sender whose
    /// receiver is already dropped (see `default_responder`).
    #[serde(default = "default_responder")]
    #[serde(skip)]
    pub response: Sender<Response>,
    /// Whether to return log-probabilities with the output.
    pub return_logprobs: bool,
    /// Whether responses are streamed incrementally.
    pub is_streaming: bool,
    /// Caller-assigned request id (echoed in the `Debug` rendering).
    pub id: usize,
    /// Decoding constraint; `Constraint::None` for unconstrained output.
    pub constraint: Constraint,
    /// Optional completion suffix. NOTE(review): presumably text placed
    /// after the generated completion, as in OpenAI's API — confirm.
    pub suffix: Option<String>,
    /// Tools available for tool calling, if any.
    pub tools: Option<Vec<Tool>>,
    /// Policy for selecting among `tools`.
    pub tool_choice: Option<ToolChoice>,
    /// Custom logits processors; trait objects, so skipped by serde.
    #[serde(skip)]
    pub logits_processors: Option<Vec<Arc<dyn CustomLogitsProcessor>>>,
    /// Whether to return raw logits. NOTE(review): exact interaction with
    /// normal sampling output is decided by the engine — confirm there.
    pub return_raw_logits: bool,
    /// Web-search configuration, if search is enabled for this request.
    pub web_search_options: Option<WebSearchOptions>,
    /// Target model id for multi-model serving; `None` for the default.
    pub model_id: Option<String>,
    /// Whether overlong input may be truncated; defaults to `false` when
    /// absent during deserialization.
    #[serde(default)]
    pub truncate_sequence: bool,
}
197
198impl NormalRequest {
199 pub fn new_simple(
200 messages: RequestMessage,
201 sampling_params: SamplingParams,
202 response: Sender<Response>,
203 id: usize,
204 tools: Option<Vec<Tool>>,
205 tool_choice: Option<ToolChoice>,
206 ) -> Self {
207 Self {
208 messages,
209 sampling_params,
210 response,
211 id,
212 tools,
213 tool_choice,
214 return_logprobs: false,
215 is_streaming: false,
216 constraint: Constraint::None,
217 suffix: None,
218 logits_processors: None,
219 return_raw_logits: false,
220 web_search_options: None,
221 model_id: None,
222 truncate_sequence: false,
223 }
224 }
225}
226
/// Request to tokenize input and return the resulting token ids.
#[derive(Clone, Serialize, Deserialize)]
pub struct TokenizationRequest {
    /// Input: either chat messages (`Left`) — NOTE(review): presumably
    /// rendered through the chat template first, confirm in the handler —
    /// or raw text (`Right`).
    pub text: Either<Vec<IndexMap<String, MessageContent>>, String>,
    /// Tools to include when rendering chat messages, if any.
    pub tools: Option<Vec<Tool>>,
    /// Whether to append the generation prompt (chat-template behavior).
    pub add_generation_prompt: bool,
    /// Whether to add special tokens during tokenization.
    pub add_special_tokens: bool,
    /// Enable "thinking" output where the template supports it.
    pub enable_thinking: Option<bool>,
    /// Requested reasoning effort, if any.
    pub reasoning_effort: Option<ReasoningEffort>,
    /// Channel receiving the token ids (or an error). Not serialized; a
    /// dummy sender is substituted on deserialize.
    #[serde(default = "default_responder")]
    #[serde(skip)]
    pub response: Sender<anyhow::Result<Vec<u32>>>,
}
241
/// Request to convert token ids back into text.
#[derive(Clone, Serialize, Deserialize)]
pub struct DetokenizationRequest {
    /// Token ids to decode.
    pub tokens: Vec<u32>,
    /// Whether special tokens are omitted from the decoded text.
    pub skip_special_tokens: bool,
    /// Channel receiving the decoded string (or an error). Not serialized;
    /// a dummy sender is substituted on deserialize.
    #[serde(default = "default_responder")]
    #[serde(skip)]
    pub response: Sender<anyhow::Result<String>>,
}
251
/// Top-level message accepted by the engine's request channel.
#[derive(Clone, Serialize, Deserialize)]
pub enum Request {
    /// A generation request (boxed — `NormalRequest` is a large struct).
    Normal(Box<NormalRequest>),
    /// Re-apply in-situ quantization with the given ISQ type.
    ReIsq(IsqType),
    /// Tokenize input and return token ids.
    Tokenize(TokenizationRequest),
    /// Decode token ids back into text.
    Detokenize(DetokenizationRequest),
    /// Terminate the engine.
    Terminate,
    /// Terminate all in-flight sequences at the next scheduling step.
    TerminateAllSeqsNextStep,
}
265
266impl Debug for Request {
267 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
268 match self {
269 Request::Normal(boxed_req) => {
270 let NormalRequest {
271 messages,
272 sampling_params,
273 is_streaming,
274 id,
275 ..
276 } = &**boxed_req;
277 write!(
278 f,
279 "Request {id} {{ messages: `{messages:?}`, sampling_params: {sampling_params:?}, is_streaming: {is_streaming}}}",
280 )
281 }
282 Request::ReIsq(tp) => {
283 write!(f, "Re ISQ Request {tp:?}",)
284 }
285 Request::Tokenize(req) => {
286 write!(f, "Tokenization Request {:?}", req.text)
287 }
288 Request::Detokenize(req) => {
289 write!(f, "Tokenization Request {:?}", req.tokens)
290 }
291 Request::Terminate => write!(f, "Termination Request"),
292 Request::TerminateAllSeqsNextStep => write!(f, "Terminate All Seqs Next Step"),
293 }
294 }
295}