1use either::Either;
2use indexmap::IndexMap;
3use mistralrs_audio::AudioInput;
4use mistralrs_quant::IsqType;
5use serde::{Deserialize, Serialize};
6use serde_json::Value;
7
8use crate::{
9 response::Response, sampler::SamplingParams, tools::ToolChoice, CustomLogitsProcessor,
10 DiffusionGenerationParams, Tool,
11};
12use std::{fmt::Debug, sync::Arc};
13use tokio::sync::mpsc::Sender;
14
/// Alias for llguidance's top-level grammar type, carried by [`Constraint::Llguidance`].
pub type LlguidanceGrammar = llguidance::api::TopLevelGrammar;

/// Optional constraint on a request's generated output.
///
/// `NormalRequest::new_simple` fills this with `Constraint::None`. How each kind is
/// enforced is decided by the code that consumes this type, not here.
// NOTE(review): variant order is part of the serialized form for index-based serde
// formats; do not reorder.
#[derive(Clone, Serialize, Deserialize)]
pub enum Constraint {
    /// A regular-expression pattern (presumably the output must match it — confirm).
    Regex(String),
    /// Grammar text, presumably in Lark syntax.
    Lark(String),
    /// A JSON Schema document (presumably the output must validate against it — confirm).
    JsonSchema(serde_json::Value),
    /// A pre-built llguidance grammar.
    Llguidance(LlguidanceGrammar),
    /// No constraint.
    None,
}
26
/// How generated images are returned to the caller.
#[derive(Clone, Copy, Debug, Serialize, Deserialize, PartialEq)]
#[cfg_attr(feature = "pyo3_macros", pyo3::pyclass(eq, eq_int))]
#[cfg_attr(feature = "utoipa", derive(utoipa::ToSchema))]
pub enum ImageGenerationResponseFormat {
    // NOTE(review): implicit discriminants follow declaration order and are what pyo3
    // `eq_int` compares against; do not reorder.
    /// Return a URL.
    // NOTE(review): presumably a URL/path to a saved image file — confirm.
    Url,
    /// Return base64-encoded image data.
    B64Json,
}
35
/// Value of a single message field: either plain text (`Left`) or a list of
/// string-keyed JSON objects (`Right`).
// NOTE(review): `Right` presumably holds structured content parts (e.g. text/image
// parts of a multimodal message) — confirm.
pub type MessageContent = Either<String, Vec<IndexMap<String, Value>>>;
37
/// The input carried by a [`NormalRequest`], one variant per kind of request.
// NOTE(review): variant and field order are part of the serialized form; do not reorder.
#[derive(Clone, Debug, Serialize, Deserialize)]
pub enum RequestMessage {
    /// Chat request: a list of messages, each a map from field name
    /// (presumably keys such as `"role"` / `"content"` — confirm) to [`MessageContent`].
    Chat {
        messages: Vec<IndexMap<String, MessageContent>>,
        /// Optional thinking toggle; `None` presumably leaves the model/template default.
        enable_thinking: Option<bool>,
    },
    /// Plain-text completion request.
    Completion {
        text: String,
        /// Presumably whether the prompt is echoed back in the output — confirm.
        echo_prompt: bool,
        /// Optional best-of count (presumably number of candidates sampled — confirm).
        best_of: Option<usize>,
    },
    /// Completion request whose prompt is already a list of token ids.
    CompletionTokens(Vec<u32>),
    /// Chat request with image and/or audio attachments.
    ///
    /// `images` and `audios` are skipped by serde, so a deserialized value has both
    /// empty.
    VisionChat {
        #[serde(skip)]
        images: Vec<image::DynamicImage>,
        #[serde(skip)]
        audios: Vec<AudioInput>,
        messages: Vec<IndexMap<String, MessageContent>>,
        enable_thinking: Option<bool>,
    },
    /// Image generation from a text prompt.
    ImageGeneration {
        prompt: String,
        /// How the generated image is returned.
        format: ImageGenerationResponseFormat,
        generation_params: DiffusionGenerationParams,
    },
    /// Speech generation from a text prompt.
    SpeechGeneration {
        prompt: String,
    },
    /// Embedding of a text prompt.
    Embedding {
        prompt: String,
    },
    /// Embedding of an already-tokenized prompt (token ids).
    EmbeddingTokens {
        prompt: Vec<u32>,
    },
}
74
75fn default_responder<T>() -> Sender<T> {
76 let (sender, _) = tokio::sync::mpsc::channel(1);
77 sender
78}
79
80#[cfg_attr(feature = "pyo3_macros", pyo3::pyclass(eq, eq_int))]
81#[cfg_attr(feature = "utoipa", derive(utoipa::ToSchema))]
82#[derive(Clone, Copy, Debug, Serialize, Deserialize, PartialEq, Default)]
83pub enum SearchContextSize {
84 #[serde(rename = "low")]
85 Low,
86 #[default]
87 #[serde(rename = "medium")]
88 Medium,
89 #[serde(rename = "high")]
90 High,
91}
92
/// Coarse user location, carried by [`WebSearchUserLocation::Approximate`].
#[cfg_attr(feature = "pyo3_macros", pyo3::pyclass(eq))]
#[cfg_attr(feature = "utoipa", derive(utoipa::ToSchema))]
#[derive(Clone, Debug, Serialize, Deserialize, PartialEq)]
pub struct ApproximateUserLocation {
    /// City.
    pub city: String,
    /// Country.
    // NOTE(review): format not fixed here (name vs. ISO code) — confirm with callers.
    pub country: String,
    /// Region.
    pub region: String,
    /// Time zone.
    // NOTE(review): presumably an IANA zone name such as "Europe/Paris" — confirm.
    pub timezone: String,
}
102
/// User location hint for web search.
///
/// Internally tagged by `"type"`, so it (de)serializes as
/// `{"type": "approximate", "approximate": { ... }}`.
#[cfg_attr(feature = "pyo3_macros", pyo3::pyclass(eq))]
#[cfg_attr(feature = "utoipa", derive(utoipa::ToSchema))]
#[derive(Clone, Debug, Serialize, Deserialize, PartialEq)]
#[serde(tag = "type")]
pub enum WebSearchUserLocation {
    /// An approximate location.
    #[serde(rename = "approximate")]
    Approximate {
        approximate: ApproximateUserLocation,
    },
}
113
/// Per-request web-search settings.
///
/// Every field is optional; the [`Default`] value leaves all of them unset.
#[cfg_attr(feature = "pyo3_macros", pyo3::pyclass(eq))]
#[cfg_attr(feature = "utoipa", derive(utoipa::ToSchema))]
#[derive(Clone, Debug, Serialize, Deserialize, PartialEq, Default)]
pub struct WebSearchOptions {
    /// Requested amount of search context.
    pub search_context_size: Option<SearchContextSize>,
    /// Optional user location hint.
    pub user_location: Option<WebSearchUserLocation>,
    /// Optional custom search description text.
    // NOTE(review): presumably overrides a built-in search-tool description — confirm.
    pub search_description: Option<String>,
    /// Optional custom extraction description text.
    // NOTE(review): presumably overrides a built-in extract-tool description — confirm.
    pub extract_description: Option<String>,
}
125
126#[derive(Clone, Serialize, Deserialize)]
127pub struct NormalRequest {
146 pub messages: RequestMessage,
147 pub sampling_params: SamplingParams,
148 #[serde(default = "default_responder")]
149 #[serde(skip)]
150 pub response: Sender<Response>,
151 pub return_logprobs: bool,
152 pub is_streaming: bool,
153 pub id: usize,
154 pub constraint: Constraint,
155 pub suffix: Option<String>,
156 pub tools: Option<Vec<Tool>>,
157 pub tool_choice: Option<ToolChoice>,
158 #[serde(skip)]
159 pub logits_processors: Option<Vec<Arc<dyn CustomLogitsProcessor>>>,
160 pub return_raw_logits: bool,
161 pub web_search_options: Option<WebSearchOptions>,
162 pub model_id: Option<String>,
163 #[serde(default)]
164 pub truncate_sequence: bool,
165}
166
167impl NormalRequest {
168 pub fn new_simple(
169 messages: RequestMessage,
170 sampling_params: SamplingParams,
171 response: Sender<Response>,
172 id: usize,
173 tools: Option<Vec<Tool>>,
174 tool_choice: Option<ToolChoice>,
175 ) -> Self {
176 Self {
177 messages,
178 sampling_params,
179 response,
180 id,
181 tools,
182 tool_choice,
183 return_logprobs: false,
184 is_streaming: false,
185 constraint: Constraint::None,
186 suffix: None,
187 logits_processors: None,
188 return_raw_logits: false,
189 web_search_options: None,
190 model_id: None,
191 truncate_sequence: false,
192 }
193 }
194}
195
196#[derive(Clone, Serialize, Deserialize)]
197pub struct TokenizationRequest {
200 pub text: Either<Vec<IndexMap<String, MessageContent>>, String>,
201 pub tools: Option<Vec<Tool>>,
202 pub add_generation_prompt: bool,
203 pub add_special_tokens: bool,
204 pub enable_thinking: Option<bool>,
205 #[serde(default = "default_responder")]
206 #[serde(skip)]
207 pub response: Sender<anyhow::Result<Vec<u32>>>,
208}
209
210#[derive(Clone, Serialize, Deserialize)]
211pub struct DetokenizationRequest {
213 pub tokens: Vec<u32>,
214 pub skip_special_tokens: bool,
215 #[serde(default = "default_responder")]
216 #[serde(skip)]
217 pub response: Sender<anyhow::Result<String>>,
218}
219
/// A request message (presumably delivered to the engine's request channel — confirm).
// NOTE(review): variant order matters for serde formats that encode enums by index;
// do not reorder.
#[derive(Clone, Serialize, Deserialize)]
pub enum Request {
    /// A regular request; see [`NormalRequest`].
    Normal(Box<NormalRequest>),
    /// Re-apply ISQ with the given type (presumably in-situ quantization — confirm).
    ReIsq(IsqType),
    /// Turn chat messages or text into token ids.
    Tokenize(TokenizationRequest),
    /// Turn token ids back into text.
    Detokenize(DetokenizationRequest),
    /// Termination request (scope not visible here — presumably the request loop).
    Terminate,
    /// Presumably terminates all running sequences at the next step — confirm.
    TerminateAllSeqsNextStep,
}
233
234impl Debug for Request {
235 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
236 match self {
237 Request::Normal(boxed_req) => {
238 let NormalRequest {
239 messages,
240 sampling_params,
241 is_streaming,
242 id,
243 ..
244 } = &**boxed_req;
245 write!(
246 f,
247 "Request {id} {{ messages: `{messages:?}`, sampling_params: {sampling_params:?}, is_streaming: {is_streaming}}}",
248 )
249 }
250 Request::ReIsq(tp) => {
251 write!(f, "Re ISQ Request {tp:?}",)
252 }
253 Request::Tokenize(req) => {
254 write!(f, "Tokenization Request {:?}", req.text)
255 }
256 Request::Detokenize(req) => {
257 write!(f, "Tokenization Request {:?}", req.tokens)
258 }
259 Request::Terminate => write!(f, "Termination Request"),
260 Request::TerminateAllSeqsNextStep => write!(f, "Terminate All Seqs Next Step"),
261 }
262 }
263}