1use either::Either;
2use indexmap::IndexMap;
3use mistralrs_audio::AudioInput;
4use mistralrs_quant::IsqType;
5use serde::{Deserialize, Serialize};
6use serde_json::Value;
7
8use crate::{
9 response::Response, sampler::SamplingParams, tools::ToolChoice, CustomLogitsProcessor,
10 DiffusionGenerationParams, Tool,
11};
12use std::{fmt::Debug, sync::Arc};
13use tokio::sync::mpsc::Sender;
14
/// Crate-local alias for `llguidance::api::TopLevelGrammar`.
pub type LlguidanceGrammar = llguidance::api::TopLevelGrammar;
16
17#[derive(Clone, Serialize, Deserialize)]
18pub enum Constraint {
20 Regex(String),
21 Lark(String),
22 JsonSchema(serde_json::Value),
23 Llguidance(LlguidanceGrammar),
24 None,
25}
26
27#[derive(Clone, Copy, Debug, Serialize, Deserialize, PartialEq)]
28#[cfg_attr(feature = "pyo3_macros", pyo3::pyclass(eq, eq_int))]
29#[cfg_attr(feature = "utoipa", derive(utoipa::ToSchema))]
30pub enum ImageGenerationResponseFormat {
32 Url,
33 B64Json,
34}
35
/// Content of a single message entry: either a plain `String` (left) or a list of
/// string-keyed JSON maps (right).
pub type MessageContent = Either<String, Vec<IndexMap<String, Value>>>;
37
38#[derive(Clone, Debug, Serialize, Deserialize)]
39pub enum RequestMessage {
41 Chat {
42 messages: Vec<IndexMap<String, MessageContent>>,
43 enable_thinking: Option<bool>,
44 },
45 Completion {
46 text: String,
47 echo_prompt: bool,
48 best_of: Option<usize>,
49 },
50 CompletionTokens(Vec<u32>),
51 VisionChat {
52 #[serde(skip)] images: Vec<image::DynamicImage>,
54 #[serde(skip)] audios: Vec<AudioInput>,
56 messages: Vec<IndexMap<String, MessageContent>>,
57 enable_thinking: Option<bool>,
58 },
59 ImageGeneration {
60 prompt: String,
61 format: ImageGenerationResponseFormat,
62 generation_params: DiffusionGenerationParams,
63 },
64 SpeechGeneration {
65 prompt: String,
66 },
67}
68
69fn default_responder<T>() -> Sender<T> {
70 let (sender, _) = tokio::sync::mpsc::channel(1);
71 sender
72}
73
74#[cfg_attr(feature = "pyo3_macros", pyo3::pyclass(eq, eq_int))]
75#[cfg_attr(feature = "utoipa", derive(utoipa::ToSchema))]
76#[derive(Clone, Copy, Debug, Serialize, Deserialize, PartialEq, Default)]
77pub enum SearchContextSize {
78 #[serde(rename = "low")]
79 Low,
80 #[default]
81 #[serde(rename = "medium")]
82 Medium,
83 #[serde(rename = "high")]
84 High,
85}
86
87#[cfg_attr(feature = "pyo3_macros", pyo3::pyclass(eq))]
88#[cfg_attr(feature = "utoipa", derive(utoipa::ToSchema))]
89#[derive(Clone, Debug, Serialize, Deserialize, PartialEq)]
90pub struct ApproximateUserLocation {
91 pub city: String,
92 pub country: String,
93 pub region: String,
94 pub timezone: String,
95}
96
97#[cfg_attr(feature = "pyo3_macros", pyo3::pyclass(eq))]
98#[cfg_attr(feature = "utoipa", derive(utoipa::ToSchema))]
99#[derive(Clone, Debug, Serialize, Deserialize, PartialEq)]
100#[serde(tag = "type")]
101pub enum WebSearchUserLocation {
102 #[serde(rename = "approximate")]
103 Approximate {
104 approximate: ApproximateUserLocation,
105 },
106}
107
108#[cfg_attr(feature = "pyo3_macros", pyo3::pyclass(eq))]
109#[cfg_attr(feature = "utoipa", derive(utoipa::ToSchema))]
110#[derive(Clone, Debug, Serialize, Deserialize, PartialEq, Default)]
111pub struct WebSearchOptions {
112 pub search_context_size: Option<SearchContextSize>,
113 pub user_location: Option<WebSearchUserLocation>,
114 pub search_description: Option<String>,
116 pub extract_description: Option<String>,
118}
119
120#[derive(Clone, Serialize, Deserialize)]
121pub struct NormalRequest {
139 pub messages: RequestMessage,
140 pub sampling_params: SamplingParams,
141 #[serde(default = "default_responder")]
142 #[serde(skip)]
143 pub response: Sender<Response>,
144 pub return_logprobs: bool,
145 pub is_streaming: bool,
146 pub id: usize,
147 pub constraint: Constraint,
148 pub suffix: Option<String>,
149 pub tools: Option<Vec<Tool>>,
150 pub tool_choice: Option<ToolChoice>,
151 #[serde(skip)]
152 pub logits_processors: Option<Vec<Arc<dyn CustomLogitsProcessor>>>,
153 pub return_raw_logits: bool,
154 pub web_search_options: Option<WebSearchOptions>,
155 pub model_id: Option<String>,
156}
157
158impl NormalRequest {
159 pub fn new_simple(
160 messages: RequestMessage,
161 sampling_params: SamplingParams,
162 response: Sender<Response>,
163 id: usize,
164 tools: Option<Vec<Tool>>,
165 tool_choice: Option<ToolChoice>,
166 ) -> Self {
167 Self {
168 messages,
169 sampling_params,
170 response,
171 id,
172 tools,
173 tool_choice,
174 return_logprobs: false,
175 is_streaming: false,
176 constraint: Constraint::None,
177 suffix: None,
178 logits_processors: None,
179 return_raw_logits: false,
180 web_search_options: None,
181 model_id: None,
182 }
183 }
184}
185
186#[derive(Clone, Serialize, Deserialize)]
187pub struct TokenizationRequest {
190 pub text: Either<Vec<IndexMap<String, MessageContent>>, String>,
191 pub tools: Option<Vec<Tool>>,
192 pub add_generation_prompt: bool,
193 pub add_special_tokens: bool,
194 pub enable_thinking: Option<bool>,
195 #[serde(default = "default_responder")]
196 #[serde(skip)]
197 pub response: Sender<anyhow::Result<Vec<u32>>>,
198}
199
200#[derive(Clone, Serialize, Deserialize)]
201pub struct DetokenizationRequest {
203 pub tokens: Vec<u32>,
204 pub skip_special_tokens: bool,
205 #[serde(default = "default_responder")]
206 #[serde(skip)]
207 pub response: Sender<anyhow::Result<String>>,
208}
209
210#[derive(Clone, Serialize, Deserialize)]
211pub enum Request {
214 Normal(Box<NormalRequest>),
215 ReIsq(IsqType),
216 Tokenize(TokenizationRequest),
217 Detokenize(DetokenizationRequest),
218 Terminate,
221 TerminateAllSeqsNextStep,
222}
223
224impl Debug for Request {
225 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
226 match self {
227 Request::Normal(boxed_req) => {
228 let NormalRequest {
229 messages,
230 sampling_params,
231 is_streaming,
232 id,
233 ..
234 } = &**boxed_req;
235 write!(
236 f,
237 "Request {id} {{ messages: `{messages:?}`, sampling_params: {sampling_params:?}, is_streaming: {is_streaming}}}",
238 )
239 }
240 Request::ReIsq(tp) => {
241 write!(f, "Re ISQ Request {tp:?}",)
242 }
243 Request::Tokenize(req) => {
244 write!(f, "Tokenization Request {:?}", req.text)
245 }
246 Request::Detokenize(req) => {
247 write!(f, "Tokenization Request {:?}", req.tokens)
248 }
249 Request::Terminate => write!(f, "Termination Request"),
250 Request::TerminateAllSeqsNextStep => write!(f, "Terminate All Seqs Next Step"),
251 }
252 }
253}