use either::Either;
use indexmap::IndexMap;
use mistralrs_quant::IsqType;
use serde::{Deserialize, Serialize};
use serde_json::Value;

use crate::{
    response::Response,
    sampler::SamplingParams,
    tools::{Tool, ToolChoice},
    CustomLogitsProcessor, DiffusionGenerationParams,
};
use std::{fmt::Debug, sync::Arc};
use tokio::sync::mpsc::Sender;
15
16pub type LlguidanceGrammar = llguidance::api::TopLevelGrammar;
17
18#[derive(Clone, Serialize, Deserialize)]
19pub enum Constraint {
21 Regex(String),
22 Lark(String),
23 JsonSchema(serde_json::Value),
24 Llguidance(LlguidanceGrammar),
25 None,
26}
27
28#[derive(Clone, Copy, Debug, Serialize, Deserialize, PartialEq)]
29#[cfg_attr(feature = "pyo3_macros", pyo3::pyclass(eq, eq_int))]
30#[cfg_attr(feature = "utoipa", derive(utoipa::ToSchema))]
31pub enum ImageGenerationResponseFormat {
33 Url,
34 B64Json,
35}
36
37pub type MessageContent = Either<String, Vec<IndexMap<String, Value>>>;
38
39#[derive(Clone, Debug, Serialize, Deserialize)]
40pub enum RequestMessage {
42 Chat {
43 messages: Vec<IndexMap<String, MessageContent>>,
44 enable_thinking: Option<bool>,
45 },
46 Completion {
47 text: String,
48 echo_prompt: bool,
49 best_of: Option<usize>,
50 },
51 CompletionTokens(Vec<u32>),
52 VisionChat {
53 #[serde(skip)] images: Vec<image::DynamicImage>,
55 messages: Vec<IndexMap<String, MessageContent>>,
56 enable_thinking: Option<bool>,
57 },
58 ImageGeneration {
59 prompt: String,
60 format: ImageGenerationResponseFormat,
61 generation_params: DiffusionGenerationParams,
62 },
63 SpeechGeneration {
64 prompt: String,
65 },
66}
67
68fn default_responder<T>() -> Sender<T> {
69 let (sender, _) = tokio::sync::mpsc::channel(1);
70 sender
71}
72
73#[cfg_attr(feature = "pyo3_macros", pyo3::pyclass(eq, eq_int))]
74#[cfg_attr(feature = "utoipa", derive(utoipa::ToSchema))]
75#[derive(Clone, Copy, Debug, Serialize, Deserialize, PartialEq, Default)]
76pub enum SearchContextSize {
77 #[serde(rename = "low")]
78 Low,
79 #[default]
80 #[serde(rename = "medium")]
81 Medium,
82 #[serde(rename = "high")]
83 High,
84}
85
86#[cfg_attr(feature = "pyo3_macros", pyo3::pyclass(eq))]
87#[cfg_attr(feature = "utoipa", derive(utoipa::ToSchema))]
88#[derive(Clone, Debug, Serialize, Deserialize, PartialEq)]
89pub struct ApproximateUserLocation {
90 pub city: String,
91 pub country: String,
92 pub region: String,
93 pub timezone: String,
94}
95
96#[cfg_attr(feature = "pyo3_macros", pyo3::pyclass(eq))]
97#[cfg_attr(feature = "utoipa", derive(utoipa::ToSchema))]
98#[derive(Clone, Debug, Serialize, Deserialize, PartialEq)]
99#[serde(tag = "type")]
100pub enum WebSearchUserLocation {
101 #[serde(rename = "approximate")]
102 Approximate {
103 approximate: ApproximateUserLocation,
104 },
105}
106
107#[cfg_attr(feature = "pyo3_macros", pyo3::pyclass(eq))]
108#[cfg_attr(feature = "utoipa", derive(utoipa::ToSchema))]
109#[derive(Clone, Debug, Serialize, Deserialize, PartialEq, Default)]
110pub struct WebSearchOptions {
111 pub search_context_size: Option<SearchContextSize>,
112 pub user_location: Option<WebSearchUserLocation>,
113}
114
115#[derive(Clone, Serialize, Deserialize)]
116pub struct NormalRequest {
134 pub messages: RequestMessage,
135 pub sampling_params: SamplingParams,
136 #[serde(default = "default_responder")]
137 #[serde(skip)]
138 pub response: Sender<Response>,
139 pub return_logprobs: bool,
140 pub is_streaming: bool,
141 pub id: usize,
142 pub constraint: Constraint,
143 pub suffix: Option<String>,
144 pub tools: Option<Vec<Tool>>,
145 pub tool_choice: Option<ToolChoice>,
146 #[serde(skip)]
147 pub logits_processors: Option<Vec<Arc<dyn CustomLogitsProcessor>>>,
148 pub return_raw_logits: bool,
149 pub web_search_options: Option<WebSearchOptions>,
150}
151
152impl NormalRequest {
153 pub fn new_simple(
154 messages: RequestMessage,
155 sampling_params: SamplingParams,
156 response: Sender<Response>,
157 id: usize,
158 tools: Option<Vec<Tool>>,
159 tool_choice: Option<ToolChoice>,
160 ) -> Self {
161 Self {
162 messages,
163 sampling_params,
164 response,
165 id,
166 tools,
167 tool_choice,
168 return_logprobs: false,
169 is_streaming: false,
170 constraint: Constraint::None,
171 suffix: None,
172 logits_processors: None,
173 return_raw_logits: false,
174 web_search_options: None,
175 }
176 }
177}
178
179#[derive(Clone, Serialize, Deserialize)]
180pub struct TokenizationRequest {
183 pub text: Either<Vec<IndexMap<String, MessageContent>>, String>,
184 pub tools: Option<Vec<Tool>>,
185 pub add_generation_prompt: bool,
186 pub add_special_tokens: bool,
187 pub enable_thinking: Option<bool>,
188 #[serde(default = "default_responder")]
189 #[serde(skip)]
190 pub response: Sender<anyhow::Result<Vec<u32>>>,
191}
192
193#[derive(Clone, Serialize, Deserialize)]
194pub struct DetokenizationRequest {
196 pub tokens: Vec<u32>,
197 pub skip_special_tokens: bool,
198 #[serde(default = "default_responder")]
199 #[serde(skip)]
200 pub response: Sender<anyhow::Result<String>>,
201}
202
203#[derive(Clone, Serialize, Deserialize)]
204pub enum Request {
207 Normal(Box<NormalRequest>),
208 ReIsq(IsqType),
209 Tokenize(TokenizationRequest),
210 Detokenize(DetokenizationRequest),
211 Terminate,
214 TerminateAllSeqsNextStep,
215}
216
217impl Debug for Request {
218 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
219 match self {
220 Request::Normal(boxed_req) => {
221 let NormalRequest {
222 messages,
223 sampling_params,
224 is_streaming,
225 id,
226 ..
227 } = &**boxed_req;
228 write!(
229 f,
230 "Request {id} {{ messages: `{messages:?}`, sampling_params: {sampling_params:?}, is_streaming: {is_streaming}}}",
231 )
232 }
233 Request::ReIsq(tp) => {
234 write!(f, "Re ISQ Request {tp:?}",)
235 }
236 Request::Tokenize(req) => {
237 write!(f, "Tokenization Request {:?}", req.text)
238 }
239 Request::Detokenize(req) => {
240 write!(f, "Tokenization Request {:?}", req.tokens)
241 }
242 Request::Terminate => write!(f, "Termination Request"),
243 Request::TerminateAllSeqsNextStep => write!(f, "Terminate All Seqs Next Step"),
244 }
245 }
246}