1use either::Either;
2use indexmap::IndexMap;
3use mistralrs_quant::IsqType;
4use serde::{Deserialize, Serialize};
5use serde_json::Value;
6
7use crate::{
8 response::Response,
9 sampler::SamplingParams,
10 tools::{Tool, ToolChoice},
11 CustomLogitsProcessor, DiffusionGenerationParams,
12};
13use std::{fmt::Debug, sync::Arc};
14use tokio::sync::mpsc::Sender;
15
/// Alias for llguidance's top-level grammar type; carried by [`Constraint::Llguidance`].
pub type LlguidanceGrammar = llguidance::api::TopLevelGrammar;
17
/// Constraint attached to a [`NormalRequest`] (its `constraint` field).
///
/// Each variant carries the constraint in a different source format; `None`
/// means the request is unconstrained. How each format is compiled and
/// enforced is decided by the consumer of this type, not here.
#[derive(Clone, Serialize, Deserialize)]
pub enum Constraint {
    /// A regular expression, given as pattern text.
    Regex(String),
    /// A grammar in the Lark format, given as source text.
    Lark(String),
    /// A JSON Schema document.
    JsonSchema(serde_json::Value),
    /// An already-built llguidance grammar.
    Llguidance(LlguidanceGrammar),
    /// No constraint.
    None,
}
27
/// Output format requested for image generation
/// (the `format` field of `RequestMessage::ImageGeneration`).
///
/// Variants are serialized under their Rust names (`"Url"`, `"B64Json"`),
/// since no serde rename is applied.
#[derive(Clone, Copy, Debug, Serialize, Deserialize, PartialEq)]
#[cfg_attr(feature = "pyo3_macros", pyo3::pyclass(eq, eq_int))]
pub enum ImageGenerationResponseFormat {
    /// Return a URL to the generated image (presumably a saved file — confirm in the pipeline).
    Url,
    /// Return the image as base64-encoded data in JSON (per the variant name — confirm).
    B64Json,
}
35
/// Value of one chat-message field: either plain text (`Left`) or a list of
/// key/value maps (`Right`; presumably structured/multimodal content parts — confirm).
pub type MessageContent = Either<String, Vec<IndexMap<String, Value>>>;
37
/// Input payload of a [`NormalRequest`]; one variant per kind of request.
#[derive(Clone, Debug, Serialize, Deserialize)]
pub enum RequestMessage {
    /// Chat messages, each a map from field name to [`MessageContent`].
    Chat(Vec<IndexMap<String, MessageContent>>),
    /// Completion from raw prompt text.
    Completion {
        text: String,
        // NOTE(review): presumably controls whether the prompt is echoed in the output — confirm.
        echo_prompt: bool,
        // NOTE(review): presumably an OpenAI-style `best_of` count — confirm against the pipeline.
        best_of: Option<usize>,
    },
    /// Completion from an already-tokenized prompt.
    CompletionTokens(Vec<u32>),
    /// Chat messages accompanied by images.
    VisionChat {
        // Images are never serialized; deserializing yields an empty Vec (serde `skip` uses Default).
        #[serde(skip)]
        images: Vec<image::DynamicImage>,
        messages: Vec<IndexMap<String, MessageContent>>,
    },
    /// Image generation from a text prompt.
    ImageGeneration {
        prompt: String,
        format: ImageGenerationResponseFormat,
        generation_params: DiffusionGenerationParams,
    },
}
59
60fn default_responder<T>() -> Sender<T> {
61 let (sender, _) = tokio::sync::mpsc::channel(1);
62 sender
63}
64
65#[cfg_attr(feature = "pyo3_macros", pyo3::pyclass(eq, eq_int))]
66#[derive(Clone, Copy, Debug, Serialize, Deserialize, PartialEq, Default)]
67pub enum SearchContextSize {
68 #[serde(rename = "low")]
69 Low,
70 #[default]
71 #[serde(rename = "medium")]
72 Medium,
73 #[serde(rename = "high")]
74 High,
75}
76
/// Coarse user location attached to web-search options via
/// [`WebSearchUserLocation::Approximate`]. All fields are free-form strings;
/// no format is enforced here.
#[cfg_attr(feature = "pyo3_macros", pyo3::pyclass(eq))]
#[derive(Clone, Debug, Serialize, Deserialize, PartialEq)]
pub struct ApproximateUserLocation {
    pub city: String,
    pub country: String,
    pub region: String,
    // NOTE(review): presumably an IANA timezone name — not validated here; confirm.
    pub timezone: String,
}
85
86#[cfg_attr(feature = "pyo3_macros", pyo3::pyclass(eq))]
87#[derive(Clone, Debug, Serialize, Deserialize, PartialEq)]
88#[serde(tag = "type")]
89pub enum WebSearchUserLocation {
90 #[serde(rename = "approximate")]
91 Approximate {
92 approximate: ApproximateUserLocation,
93 },
94}
95
/// Web-search settings for a request; both fields are optional and the
/// `Default` leaves them as `None`.
#[cfg_attr(feature = "pyo3_macros", pyo3::pyclass(eq))]
#[derive(Clone, Debug, Serialize, Deserialize, PartialEq, Default)]
pub struct WebSearchOptions {
    pub search_context_size: Option<SearchContextSize>,
    pub user_location: Option<WebSearchUserLocation>,
}
102
/// A generation request: the input payload, sampling configuration, the
/// channel its responses are sent on, and per-request options.
#[derive(Clone, Serialize, Deserialize)]
pub struct NormalRequest {
    pub messages: RequestMessage,
    pub sampling_params: SamplingParams,
    // Not serialized. On deserialization it is filled by `default_responder`,
    // whose receiver is already dropped, so sends on it fail.
    #[serde(default = "default_responder")]
    #[serde(skip)]
    pub response: Sender<Response>,
    pub return_logprobs: bool,
    pub is_streaming: bool,
    pub id: usize,
    pub constraint: Constraint,
    pub suffix: Option<String>,
    pub tools: Option<Vec<Tool>>,
    pub tool_choice: Option<ToolChoice>,
    // Trait objects cannot be serialized; deserializing yields `None`.
    #[serde(skip)]
    pub logits_processors: Option<Vec<Arc<dyn CustomLogitsProcessor>>>,
    pub return_raw_logits: bool,
    pub web_search_options: Option<WebSearchOptions>,
}
139
140impl NormalRequest {
141 pub fn new_simple(
142 messages: RequestMessage,
143 sampling_params: SamplingParams,
144 response: Sender<Response>,
145 id: usize,
146 tools: Option<Vec<Tool>>,
147 tool_choice: Option<ToolChoice>,
148 ) -> Self {
149 Self {
150 messages,
151 sampling_params,
152 response,
153 id,
154 tools,
155 tool_choice,
156 return_logprobs: false,
157 is_streaming: false,
158 constraint: Constraint::None,
159 suffix: None,
160 logits_processors: None,
161 return_raw_logits: false,
162 web_search_options: None,
163 }
164 }
165}
166
/// Request to tokenize either chat messages (`Left`) or raw text (`Right`).
#[derive(Clone, Serialize, Deserialize)]
pub struct TokenizationRequest {
    pub text: Either<Vec<IndexMap<String, MessageContent>>, String>,
    pub tools: Option<Vec<Tool>>,
    // NOTE(review): presumably only meaningful for the chat (`Left`) form — confirm.
    pub add_generation_prompt: bool,
    pub add_special_tokens: bool,
    // Not serialized. On deserialization it is filled by `default_responder`,
    // a sender whose receiver is already dropped.
    #[serde(default = "default_responder")]
    #[serde(skip)]
    pub response: Sender<anyhow::Result<Vec<u32>>>,
}
179
/// Request to turn token ids back into text.
#[derive(Clone, Serialize, Deserialize)]
pub struct DetokenizationRequest {
    pub tokens: Vec<u32>,
    pub skip_special_tokens: bool,
    // Not serialized. On deserialization it is filled by `default_responder`,
    // a sender whose receiver is already dropped.
    #[serde(default = "default_responder")]
    #[serde(skip)]
    pub response: Sender<anyhow::Result<String>>,
}
189
/// A message submitted to the engine.
#[derive(Clone, Serialize, Deserialize)]
pub enum Request {
    /// A generation request.
    Normal(NormalRequest),
    /// Re-apply in-situ quantization with the given type (per the `IsqType` payload — confirm semantics).
    ReIsq(IsqType),
    /// Tokenize text or chat messages.
    Tokenize(TokenizationRequest),
    /// Convert token ids back to text.
    Detokenize(DetokenizationRequest),
    /// Ask the engine to terminate (handling is defined by the engine, not here).
    Terminate,
    /// Terminate all sequences at the next step (per the variant name — confirm in the engine).
    TerminateAllSeqsNextStep,
}
203
204impl Debug for Request {
205 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
206 match self {
207 Request::Normal(NormalRequest {
208 messages,
209 sampling_params,
210 is_streaming,
211 id,
212 ..
213 }) => {
214 write!(
215 f,
216 "Request {id} {{ messages: `{messages:?}`, sampling_params: {sampling_params:?}, is_streaming: {is_streaming}}}",
217 )
218 }
219 Request::ReIsq(tp) => {
220 write!(f, "Re ISQ Request {tp:?}",)
221 }
222 Request::Tokenize(req) => {
223 write!(f, "Tokenization Request {:?}", req.text)
224 }
225 Request::Detokenize(req) => {
226 write!(f, "Tokenization Request {:?}", req.tokens)
227 }
228 Request::Terminate => write!(f, "Termination Request"),
229 Request::TerminateAllSeqsNextStep => write!(f, "Terminate All Seqs Next Step"),
230 }
231 }
232}