mistralrs_core/request.rs

use either::Either;
use indexmap::IndexMap;
use mistralrs_audio::AudioInput;
use mistralrs_quant::IsqType;
use serde::{Deserialize, Serialize};
use serde_json::Value;

use crate::{
    response::Response, sampler::SamplingParams, tools::ToolChoice, CustomLogitsProcessor,
    DiffusionGenerationParams, Tool,
};
use std::{fmt::Debug, sync::Arc};
use tokio::sync::mpsc::Sender;

pub type LlguidanceGrammar = llguidance::api::TopLevelGrammar;

#[derive(Clone, Serialize, Deserialize)]
/// Control the constraint with llguidance.
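///
/// A minimal sketch of constructing constraints (not compiled as a doctest;
/// the payload values are illustrative):
///
/// ```ignore
/// // Constrain generation to an ISO-8601 date.
/// let date_only = Constraint::Regex(r"\d{4}-\d{2}-\d{2}".to_string());
///
/// // Constrain generation to JSON matching a schema.
/// let structured = Constraint::JsonSchema(serde_json::json!({
///     "type": "object",
///     "properties": { "name": { "type": "string" } },
///     "required": ["name"]
/// }));
/// ```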
pub enum Constraint {
    Regex(String),
    Lark(String),
    JsonSchema(serde_json::Value),
    Llguidance(LlguidanceGrammar),
    None,
}

#[derive(Clone, Copy, Debug, Serialize, Deserialize, PartialEq)]
#[cfg_attr(feature = "pyo3_macros", pyo3::pyclass(eq, eq_int))]
#[cfg_attr(feature = "utoipa", derive(utoipa::ToSchema))]
/// Image generation response format
pub enum ImageGenerationResponseFormat {
    Url,
    B64Json,
}

pub type MessageContent = Either<String, Vec<IndexMap<String, Value>>>;

#[derive(Clone, Debug, Serialize, Deserialize)]
/// Message or messages for a [`Request`].
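///
/// A minimal sketch of building a plain-text chat message (not compiled as a
/// doctest; the `role`/`content` keys follow the OpenAI-style message shape):
///
/// ```ignore
/// use either::Either;
/// use indexmap::IndexMap;
///
/// let mut message = IndexMap::new();
/// message.insert("role".to_string(), Either::Left("user".to_string()));
/// message.insert("content".to_string(), Either::Left("Hello!".to_string()));
///
/// let messages = RequestMessage::Chat {
///     messages: vec![message],
///     enable_thinking: None,
/// };
/// ```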
pub enum RequestMessage {
    Chat {
        messages: Vec<IndexMap<String, MessageContent>>,
        enable_thinking: Option<bool>,
    },
    Completion {
        text: String,
        echo_prompt: bool,
        best_of: Option<usize>,
    },
    CompletionTokens(Vec<u32>),
    VisionChat {
        #[serde(skip)] // TODO
        images: Vec<image::DynamicImage>,
        #[serde(skip)] // TODO
        audios: Vec<AudioInput>,
        messages: Vec<IndexMap<String, MessageContent>>,
        enable_thinking: Option<bool>,
    },
    ImageGeneration {
        prompt: String,
        format: ImageGenerationResponseFormat,
        generation_params: DiffusionGenerationParams,
    },
    SpeechGeneration {
        prompt: String,
    },
    Embedding {
        prompt: String,
    },
    EmbeddingTokens {
        prompt: Vec<u32>,
    },
}

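/// Serde default for skipped `response` fields: returns a detached `Sender`
/// whose paired receiver is dropped immediately, so any send on it will
/// error; deserialized requests need a real channel installed before use.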
fn default_responder<T>() -> Sender<T> {
    let (sender, _) = tokio::sync::mpsc::channel(1);
    sender
}

#[cfg_attr(feature = "pyo3_macros", pyo3::pyclass(eq, eq_int))]
#[cfg_attr(feature = "utoipa", derive(utoipa::ToSchema))]
#[derive(Clone, Copy, Debug, Serialize, Deserialize, PartialEq, Default)]
pub enum SearchContextSize {
    #[serde(rename = "low")]
    Low,
    #[default]
    #[serde(rename = "medium")]
    Medium,
    #[serde(rename = "high")]
    High,
}

#[cfg_attr(feature = "pyo3_macros", pyo3::pyclass(eq))]
#[cfg_attr(feature = "utoipa", derive(utoipa::ToSchema))]
#[derive(Clone, Debug, Serialize, Deserialize, PartialEq)]
pub struct ApproximateUserLocation {
    pub city: String,
    pub country: String,
    pub region: String,
    pub timezone: String,
}

#[cfg_attr(feature = "pyo3_macros", pyo3::pyclass(eq))]
#[cfg_attr(feature = "utoipa", derive(utoipa::ToSchema))]
#[derive(Clone, Debug, Serialize, Deserialize, PartialEq)]
#[serde(tag = "type")]
pub enum WebSearchUserLocation {
    #[serde(rename = "approximate")]
    Approximate {
        approximate: ApproximateUserLocation,
    },
}

#[cfg_attr(feature = "pyo3_macros", pyo3::pyclass(eq))]
#[cfg_attr(feature = "utoipa", derive(utoipa::ToSchema))]
#[derive(Clone, Debug, Serialize, Deserialize, PartialEq, Default)]
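/// Per-request options for the built-in web search tooling.
///
/// A minimal sketch of enabling high-context search with an approximate
/// location (not compiled as a doctest; the field values are illustrative):
///
/// ```ignore
/// let opts = WebSearchOptions {
///     search_context_size: Some(SearchContextSize::High),
///     user_location: Some(WebSearchUserLocation::Approximate {
///         approximate: ApproximateUserLocation {
///             city: "Berlin".to_string(),
///             country: "DE".to_string(),
///             region: "Berlin".to_string(),
///             timezone: "Europe/Berlin".to_string(),
///         },
///     }),
///     ..Default::default()
/// };
/// ```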
pub struct WebSearchOptions {
    pub search_context_size: Option<SearchContextSize>,
    pub user_location: Option<WebSearchUserLocation>,
    /// Override the description for the search tool.
    pub search_description: Option<String>,
    /// Override the description for the extraction tool.
    pub extract_description: Option<String>,
}

#[derive(Clone, Serialize, Deserialize)]
/// A normal request to the `MistralRs`.
/// - `messages`: Messages for the request
/// - `sampling_params`: Sampling parameters for generation
/// - `response`: Object to send the result through
/// - `return_logprobs`: Whether to return logprobs
/// - `is_streaming`: Whether the request is streaming; if so, chunk responses will be sent
/// - `id`: Request ID
/// - `constraint`: Constraint to use during generation
/// - `suffix`: Suffix to add
/// - `tools`: Tools available in this request
/// - `tool_choice`: Choice of tools
/// - `logits_processors`: Custom logits processors. Order of application:
///     1) Apply penalties from `sampling_params`
///     2) Apply these custom logits processors sequentially
///     3) Apply temperature and softmax
///     4) Sample the next token (top-k, top-p, min-p, etc.)
/// - `return_raw_logits`: Return raw logits
/// - `web_search_options`: Options for the built-in web search tooling
/// - `model_id`: Optionally select which loaded model handles this request
/// - `truncate_sequence`: Whether to truncate the prompt if it exceeds the model's maximum context length
pub struct NormalRequest {
    pub messages: RequestMessage,
    pub sampling_params: SamplingParams,
    #[serde(default = "default_responder")]
    #[serde(skip)]
    pub response: Sender<Response>,
    pub return_logprobs: bool,
    pub is_streaming: bool,
    pub id: usize,
    pub constraint: Constraint,
    pub suffix: Option<String>,
    pub tools: Option<Vec<Tool>>,
    pub tool_choice: Option<ToolChoice>,
    #[serde(skip)]
    pub logits_processors: Option<Vec<Arc<dyn CustomLogitsProcessor>>>,
    pub return_raw_logits: bool,
    pub web_search_options: Option<WebSearchOptions>,
    pub model_id: Option<String>,
    #[serde(default)]
    pub truncate_sequence: bool,
}

impl NormalRequest {
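    /// Builds a request with defaults: non-streaming, no logprobs, no
    /// constraint, and no custom logits processors.
    ///
    /// A minimal sketch of a completion request (not compiled as a doctest;
    /// the channel capacity is an arbitrary choice, and `SamplingParams` is
    /// assumed to provide a default constructor):
    ///
    /// ```ignore
    /// let (tx, mut rx) = tokio::sync::mpsc::channel(16);
    /// let request = NormalRequest::new_simple(
    ///     RequestMessage::Completion {
    ///         text: "Say hello.".to_string(),
    ///         echo_prompt: false,
    ///         best_of: None,
    ///     },
    ///     SamplingParams::default(),
    ///     tx,
    ///     0,    // request id
    ///     None, // tools
    ///     None, // tool choice
    /// );
    /// ```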
    pub fn new_simple(
        messages: RequestMessage,
        sampling_params: SamplingParams,
        response: Sender<Response>,
        id: usize,
        tools: Option<Vec<Tool>>,
        tool_choice: Option<ToolChoice>,
    ) -> Self {
        Self {
            messages,
            sampling_params,
            response,
            id,
            tools,
            tool_choice,
            return_logprobs: false,
            is_streaming: false,
            constraint: Constraint::None,
            suffix: None,
            logits_processors: None,
            return_raw_logits: false,
            web_search_options: None,
            model_id: None,
            truncate_sequence: false,
        }
    }
}

#[derive(Clone, Serialize, Deserialize)]
/// Request to tokenize some messages or some text.
/// - `add_generation_prompt` is only applicable if chat messages are provided and not a raw string.
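///
/// A minimal sketch of tokenizing raw text (not compiled as a doctest):
///
/// ```ignore
/// use either::Either;
///
/// let (tx, mut rx) = tokio::sync::mpsc::channel(1);
/// let request = TokenizationRequest {
///     text: Either::Right("Hello, world!".to_string()),
///     tools: None,
///     add_generation_prompt: false,
///     add_special_tokens: true,
///     enable_thinking: None,
///     response: tx,
/// };
/// ```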
pub struct TokenizationRequest {
    pub text: Either<Vec<IndexMap<String, MessageContent>>, String>,
    pub tools: Option<Vec<Tool>>,
    pub add_generation_prompt: bool,
    pub add_special_tokens: bool,
    pub enable_thinking: Option<bool>,
    #[serde(default = "default_responder")]
    #[serde(skip)]
    pub response: Sender<anyhow::Result<Vec<u32>>>,
}

#[derive(Clone, Serialize, Deserialize)]
/// Request to detokenize some text.
pub struct DetokenizationRequest {
    pub tokens: Vec<u32>,
    pub skip_special_tokens: bool,
    #[serde(default = "default_responder")]
    #[serde(skip)]
    pub response: Sender<anyhow::Result<String>>,
}

#[derive(Clone, Serialize, Deserialize)]
/// A request to the Engine, encapsulating the various parameters as well as
/// the `mpsc` response `Sender` used to return the [`Response`].
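///
/// A minimal sketch of dispatching a request (not compiled as a doctest;
/// `normal_request` and `sender` are hypothetical stand-ins for a built
/// [`NormalRequest`] and the engine's request channel):
///
/// ```ignore
/// let request = Request::Normal(Box::new(normal_request));
/// sender.send(request).await?;
/// ```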
pub enum Request {
    Normal(Box<NormalRequest>),
    ReIsq(IsqType),
    Tokenize(TokenizationRequest),
    Detokenize(DetokenizationRequest),
    // Sending a terminate request causes the `run` function to return to the thread created in
    // `MistralRs::new`, and the `Engine` is then dropped.
    Terminate,
    TerminateAllSeqsNextStep,
}

impl Debug for Request {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        match self {
            Request::Normal(boxed_req) => {
                let NormalRequest {
                    messages,
                    sampling_params,
                    is_streaming,
                    id,
                    ..
                } = &**boxed_req;
                write!(
                    f,
                    "Request {id} {{ messages: `{messages:?}`, sampling_params: {sampling_params:?}, is_streaming: {is_streaming} }}",
                )
            }
            Request::ReIsq(tp) => {
                write!(f, "Re ISQ Request {tp:?}")
            }
            Request::Tokenize(req) => {
                write!(f, "Tokenization Request {:?}", req.text)
            }
            Request::Detokenize(req) => {
                write!(f, "Detokenization Request {:?}", req.tokens)
            }
            Request::Terminate => write!(f, "Termination Request"),
            Request::TerminateAllSeqsNextStep => write!(f, "Terminate All Seqs Next Step"),
        }
    }
}