mistralrs_core/request.rs

use either::Either;
use indexmap::IndexMap;
use mistralrs_audio::AudioInput;
use mistralrs_quant::IsqType;
use serde::{Deserialize, Serialize};
use serde_json::Value;

use crate::{
    response::Response, sampler::SamplingParams, tools::ToolChoice, CustomLogitsProcessor,
    DiffusionGenerationParams, Tool,
};
use std::{fmt::Debug, sync::Arc};
use tokio::sync::mpsc::Sender;

pub type LlguidanceGrammar = llguidance::api::TopLevelGrammar;

#[derive(Clone, Serialize, Deserialize)]
/// Control the constraint with llguidance.
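///
/// For example, a regex constraint can be built directly from a pattern string
/// (a minimal sketch; the pattern itself is illustrative):
/// ```ignore
/// let constraint = Constraint::Regex(r"\d{3}-\d{4}".to_string());
/// ```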
pub enum Constraint {
    Regex(String),
    Lark(String),
    JsonSchema(serde_json::Value),
    Llguidance(LlguidanceGrammar),
    None,
}

#[derive(Clone, Copy, Debug, Serialize, Deserialize, PartialEq)]
#[cfg_attr(feature = "pyo3_macros", pyo3::pyclass(eq, eq_int))]
#[cfg_attr(feature = "utoipa", derive(utoipa::ToSchema))]
/// Image generation response format
pub enum ImageGenerationResponseFormat {
    Url,
    B64Json,
}

pub type MessageContent = Either<String, Vec<IndexMap<String, Value>>>;

/// Reasoning effort level for models that support it (e.g., GPT-OSS with Harmony format).
/// Controls the depth of reasoning/analysis in the model's response.
#[derive(Clone, Copy, Debug, Serialize, Deserialize, PartialEq, Eq, Default)]
#[cfg_attr(feature = "pyo3_macros", pyo3::pyclass(eq, eq_int))]
#[cfg_attr(feature = "utoipa", derive(utoipa::ToSchema))]
#[serde(rename_all = "lowercase")]
pub enum ReasoningEffort {
    /// Minimal reasoning, faster responses
    Low,
    /// Balanced reasoning depth
    #[default]
    Medium,
    /// Deep reasoning, more thorough analysis
    High,
}

impl ReasoningEffort {
    /// Convert to string representation for chat template
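    ///
    /// A minimal usage sketch (illustrative):
    /// ```ignore
    /// assert_eq!(ReasoningEffort::High.as_str(), "high");
    /// ```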
    pub fn as_str(&self) -> &'static str {
        match self {
            Self::Low => "low",
            Self::Medium => "medium",
            Self::High => "high",
        }
    }
}

#[derive(Clone, Debug, Serialize, Deserialize)]
/// Message or messages for a [`Request`].
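///
/// A minimal sketch of a single-turn chat message (the `"role"`/`"content"`
/// key names follow the OpenAI-style convention and are illustrative here):
/// ```ignore
/// use either::Either;
/// use indexmap::IndexMap;
///
/// let mut message = IndexMap::new();
/// message.insert("role".to_string(), Either::Left("user".to_string()));
/// message.insert("content".to_string(), Either::Left("Hello!".to_string()));
/// let msg = RequestMessage::Chat {
///     messages: vec![message],
///     enable_thinking: None,
///     reasoning_effort: None,
/// };
/// ```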
pub enum RequestMessage {
    Chat {
        messages: Vec<IndexMap<String, MessageContent>>,
        enable_thinking: Option<bool>,
        /// Reasoning effort level for Harmony-format models
        reasoning_effort: Option<ReasoningEffort>,
    },
    Completion {
        text: String,
        echo_prompt: bool,
        best_of: Option<usize>,
    },
    CompletionTokens(Vec<u32>),
    VisionChat {
        #[serde(skip)] // TODO
        images: Vec<image::DynamicImage>,
        #[serde(skip)] // TODO
        audios: Vec<AudioInput>,
        messages: Vec<IndexMap<String, MessageContent>>,
        enable_thinking: Option<bool>,
        /// Reasoning effort level for Harmony-format models
        reasoning_effort: Option<ReasoningEffort>,
    },
    ImageGeneration {
        prompt: String,
        format: ImageGenerationResponseFormat,
        generation_params: DiffusionGenerationParams,
    },
    SpeechGeneration {
        prompt: String,
    },
    Embedding {
        prompt: String,
    },
    EmbeddingTokens {
        prompt: Vec<u32>,
    },
}

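/// Serde default for the skipped `response` fields: a fresh sender whose
/// receiver is dropped immediately, so sends on it simply return an error.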
fn default_responder<T>() -> Sender<T> {
    let (sender, _) = tokio::sync::mpsc::channel(1);
    sender
}

#[cfg_attr(feature = "pyo3_macros", pyo3::pyclass(eq, eq_int))]
#[cfg_attr(feature = "utoipa", derive(utoipa::ToSchema))]
#[derive(Clone, Copy, Debug, Serialize, Deserialize, PartialEq, Default)]
pub enum SearchContextSize {
    #[serde(rename = "low")]
    Low,
    #[default]
    #[serde(rename = "medium")]
    Medium,
    #[serde(rename = "high")]
    High,
}

#[cfg_attr(feature = "pyo3_macros", pyo3::pyclass(eq))]
#[cfg_attr(feature = "utoipa", derive(utoipa::ToSchema))]
#[derive(Clone, Debug, Serialize, Deserialize, PartialEq)]
pub struct ApproximateUserLocation {
    pub city: String,
    pub country: String,
    pub region: String,
    pub timezone: String,
}

#[cfg_attr(feature = "pyo3_macros", pyo3::pyclass(eq))]
#[cfg_attr(feature = "utoipa", derive(utoipa::ToSchema))]
#[derive(Clone, Debug, Serialize, Deserialize, PartialEq)]
#[serde(tag = "type")]
pub enum WebSearchUserLocation {
    #[serde(rename = "approximate")]
    Approximate {
        approximate: ApproximateUserLocation,
    },
}

#[cfg_attr(feature = "pyo3_macros", pyo3::pyclass(eq))]
#[cfg_attr(feature = "utoipa", derive(utoipa::ToSchema))]
#[derive(Clone, Debug, Serialize, Deserialize, PartialEq, Default)]
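/// Options for the web search tool.
///
/// A minimal sketch using the derived `Default` (field values are illustrative):
/// ```ignore
/// let opts = WebSearchOptions {
///     search_context_size: Some(SearchContextSize::High),
///     ..Default::default()
/// };
/// ```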
pub struct WebSearchOptions {
    pub search_context_size: Option<SearchContextSize>,
    pub user_location: Option<WebSearchUserLocation>,
    /// Override the description for the search tool.
    pub search_description: Option<String>,
    /// Override the description for the extraction tool.
    pub extract_description: Option<String>,
}

#[derive(Clone, Serialize, Deserialize)]
/// A normal request to the `MistralRs`.
/// - `messages`: Messages for the request
/// - `sampling_params`: Sampling parameters for generation
/// - `response`: Object to send the result through
/// - `return_logprobs`: Whether to return logprobs
/// - `is_streaming`: Whether the request is streaming; if so, chunk responses will be sent
/// - `id`: Request ID
/// - `constraint`: Constraint to use during generation
/// - `suffix`: Suffix to add
/// - `tools`: Tools available in this request
/// - `tool_choice`: Choice of tools
/// - `logits_processors`: Custom logits processors. Order of application:
///     1) Apply penalties from `sampling_params`
///     2) Apply these custom logits processors sequentially
///     3) Apply temperature and softmax
///     4) Sample the next token (topk, topp, minp, etc.)
/// - `return_raw_logits`: Return raw logits.
/// - `truncate_sequence`: Whether to truncate the prompt if it exceeds the model's maximum context length.
pub struct NormalRequest {
    pub messages: RequestMessage,
    pub sampling_params: SamplingParams,
    #[serde(default = "default_responder")]
    #[serde(skip)]
    pub response: Sender<Response>,
    pub return_logprobs: bool,
    pub is_streaming: bool,
    pub id: usize,
    pub constraint: Constraint,
    pub suffix: Option<String>,
    pub tools: Option<Vec<Tool>>,
    pub tool_choice: Option<ToolChoice>,
    #[serde(skip)]
    pub logits_processors: Option<Vec<Arc<dyn CustomLogitsProcessor>>>,
    pub return_raw_logits: bool,
    pub web_search_options: Option<WebSearchOptions>,
    pub model_id: Option<String>,
    #[serde(default)]
    pub truncate_sequence: bool,
}

impl NormalRequest {
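    /// Build a request where everything except the provided arguments is left
    /// at its default (no constraint, non-streaming, no logprobs, etc.).
    ///
    /// A minimal sketch, assuming `sampling_params` was constructed elsewhere
    /// (the prompt text and channel capacity are illustrative):
    /// ```ignore
    /// let (tx, mut rx) = tokio::sync::mpsc::channel(1);
    /// let request = NormalRequest::new_simple(
    ///     RequestMessage::Completion {
    ///         text: "Say hello".to_string(),
    ///         echo_prompt: false,
    ///         best_of: None,
    ///     },
    ///     sampling_params,
    ///     tx,
    ///     0,
    ///     None,
    ///     None,
    /// );
    /// ```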
    pub fn new_simple(
        messages: RequestMessage,
        sampling_params: SamplingParams,
        response: Sender<Response>,
        id: usize,
        tools: Option<Vec<Tool>>,
        tool_choice: Option<ToolChoice>,
    ) -> Self {
        Self {
            messages,
            sampling_params,
            response,
            id,
            tools,
            tool_choice,
            return_logprobs: false,
            is_streaming: false,
            constraint: Constraint::None,
            suffix: None,
            logits_processors: None,
            return_raw_logits: false,
            web_search_options: None,
            model_id: None,
            truncate_sequence: false,
        }
    }
}

#[derive(Clone, Serialize, Deserialize)]
/// Request to tokenize some messages or some text.
/// - `add_generation_prompt` is only applicable if chat messages are provided and not a raw string.
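///
/// A minimal sketch for tokenizing raw text (field values are illustrative):
/// ```ignore
/// let (tx, mut rx) = tokio::sync::mpsc::channel(1);
/// let req = TokenizationRequest {
///     text: Either::Right("Hello, world!".to_string()),
///     tools: None,
///     add_generation_prompt: false,
///     add_special_tokens: true,
///     enable_thinking: None,
///     reasoning_effort: None,
///     response: tx,
/// };
/// ```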
pub struct TokenizationRequest {
    pub text: Either<Vec<IndexMap<String, MessageContent>>, String>,
    pub tools: Option<Vec<Tool>>,
    pub add_generation_prompt: bool,
    pub add_special_tokens: bool,
    pub enable_thinking: Option<bool>,
    pub reasoning_effort: Option<ReasoningEffort>,
    #[serde(default = "default_responder")]
    #[serde(skip)]
    pub response: Sender<anyhow::Result<Vec<u32>>>,
}

#[derive(Clone, Serialize, Deserialize)]
/// Request to detokenize some tokens back into text.
pub struct DetokenizationRequest {
    pub tokens: Vec<u32>,
    pub skip_special_tokens: bool,
    #[serde(default = "default_responder")]
    #[serde(skip)]
    pub response: Sender<anyhow::Result<String>>,
}

#[derive(Clone, Serialize, Deserialize)]
/// A request to the Engine, encapsulating the various parameters as well as
/// the `mpsc` response `Sender` used to return the [`Response`].
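///
/// A minimal sketch, assuming `normal_request` is a [`NormalRequest`] built
/// earlier (e.g. via [`NormalRequest::new_simple`]):
/// ```ignore
/// let request = Request::Normal(Box::new(normal_request));
/// ```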
pub enum Request {
    Normal(Box<NormalRequest>),
    ReIsq(IsqType),
    Tokenize(TokenizationRequest),
    Detokenize(DetokenizationRequest),
    // Sending a terminate request causes the `run` function to return to the thread created in `MistralRs::new`,
    // and then Engine will be dropped.
    Terminate,
    TerminateAllSeqsNextStep,
}

impl Debug for Request {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        match self {
            Request::Normal(boxed_req) => {
                let NormalRequest {
                    messages,
                    sampling_params,
                    is_streaming,
                    id,
                    ..
                } = &**boxed_req;
                write!(
                    f,
                    "Request {id} {{ messages: `{messages:?}`, sampling_params: {sampling_params:?}, is_streaming: {is_streaming}}}",
                )
            }
            Request::ReIsq(tp) => {
                write!(f, "Re ISQ Request {tp:?}",)
            }
            Request::Tokenize(req) => {
                write!(f, "Tokenization Request {:?}", req.text)
            }
            Request::Detokenize(req) => {
                write!(f, "Detokenization Request {:?}", req.tokens)
            }
            Request::Terminate => write!(f, "Termination Request"),
            Request::TerminateAllSeqsNextStep => write!(f, "Terminate All Seqs Next Step"),
        }
    }
}