mistralrs_core/request.rs

use either::Either;
use indexmap::IndexMap;
use mistralrs_audio::AudioInput;
use mistralrs_quant::IsqType;
use serde::{Deserialize, Serialize};
use serde_json::Value;

use crate::{
    response::Response, sampler::SamplingParams, tools::ToolChoice, CustomLogitsProcessor,
    DiffusionGenerationParams, Tool,
};
use std::{fmt::Debug, sync::Arc};
use tokio::sync::mpsc::Sender;

pub type LlguidanceGrammar = llguidance::api::TopLevelGrammar;

#[derive(Clone, Serialize, Deserialize)]
/// A constraint to apply during generation, enforced via llguidance.
pub enum Constraint {
    /// A regular expression the output must match.
    Regex(String),
    /// A grammar in Lark syntax.
    Lark(String),
    /// A JSON schema the output must conform to.
    JsonSchema(serde_json::Value),
    /// A full llguidance grammar.
    Llguidance(LlguidanceGrammar),
    /// No constraint.
    None,
}
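
// Illustrative example (hypothetical helper, not part of the public API):
// constraints are plain enum values, so callers construct them directly. The
// JSON schema variant accepts any `serde_json::Value`.
#[allow(dead_code)]
fn example_constraints() -> (Constraint, Constraint) {
    // Constrain output to an ISO-8601 date.
    let regex = Constraint::Regex(r"\d{4}-\d{2}-\d{2}".to_string());
    // Constrain output to a JSON object with a required `name` string.
    let schema = Constraint::JsonSchema(serde_json::json!({
        "type": "object",
        "properties": { "name": { "type": "string" } },
        "required": ["name"]
    }));
    (regex, schema)
}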

#[derive(Clone, Copy, Debug, Serialize, Deserialize, PartialEq)]
#[cfg_attr(feature = "pyo3_macros", pyo3::pyclass(eq, eq_int))]
#[cfg_attr(feature = "utoipa", derive(utoipa::ToSchema))]
/// Image generation response format
pub enum ImageGenerationResponseFormat {
    Url,
    B64Json,
}

pub type MessageContent = Either<String, Vec<IndexMap<String, Value>>>;

#[derive(Clone, Debug, Serialize, Deserialize)]
/// Message or messages for a [`Request`].
pub enum RequestMessage {
    Chat {
        messages: Vec<IndexMap<String, MessageContent>>,
        enable_thinking: Option<bool>,
    },
    Completion {
        text: String,
        echo_prompt: bool,
        best_of: Option<usize>,
    },
    CompletionTokens(Vec<u32>),
    VisionChat {
        #[serde(skip)] // TODO
        images: Vec<image::DynamicImage>,
        #[serde(skip)] // TODO
        audios: Vec<AudioInput>,
        messages: Vec<IndexMap<String, MessageContent>>,
        enable_thinking: Option<bool>,
    },
    ImageGeneration {
        prompt: String,
        format: ImageGenerationResponseFormat,
        generation_params: DiffusionGenerationParams,
    },
    SpeechGeneration {
        prompt: String,
    },
}
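
// Illustrative example (hypothetical helper): building a single-turn chat
// request. Each message is a map from field name to content; plain-text
// content uses the `Either::Left` side of `MessageContent`, while structured
// multi-part content uses `Either::Right`.
#[allow(dead_code)]
fn example_chat_message() -> RequestMessage {
    let mut message = IndexMap::new();
    message.insert("role".to_string(), Either::Left("user".to_string()));
    message.insert("content".to_string(), Either::Left("Hello!".to_string()));
    RequestMessage::Chat {
        messages: vec![message],
        enable_thinking: None,
    }
}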

/// Fallback responder used during deserialization: the `response` fields are
/// `#[serde(skip)]`, so a dummy channel is created and its receiver is
/// immediately dropped.
fn default_responder<T>() -> Sender<T> {
    let (sender, _) = tokio::sync::mpsc::channel(1);
    sender
}

#[cfg_attr(feature = "pyo3_macros", pyo3::pyclass(eq, eq_int))]
#[cfg_attr(feature = "utoipa", derive(utoipa::ToSchema))]
#[derive(Clone, Copy, Debug, Serialize, Deserialize, PartialEq, Default)]
pub enum SearchContextSize {
    #[serde(rename = "low")]
    Low,
    #[default]
    #[serde(rename = "medium")]
    Medium,
    #[serde(rename = "high")]
    High,
}

#[cfg_attr(feature = "pyo3_macros", pyo3::pyclass(eq))]
#[cfg_attr(feature = "utoipa", derive(utoipa::ToSchema))]
#[derive(Clone, Debug, Serialize, Deserialize, PartialEq)]
pub struct ApproximateUserLocation {
    pub city: String,
    pub country: String,
    pub region: String,
    pub timezone: String,
}

#[cfg_attr(feature = "pyo3_macros", pyo3::pyclass(eq))]
#[cfg_attr(feature = "utoipa", derive(utoipa::ToSchema))]
#[derive(Clone, Debug, Serialize, Deserialize, PartialEq)]
#[serde(tag = "type")]
pub enum WebSearchUserLocation {
    #[serde(rename = "approximate")]
    Approximate {
        approximate: ApproximateUserLocation,
    },
}

#[cfg_attr(feature = "pyo3_macros", pyo3::pyclass(eq))]
#[cfg_attr(feature = "utoipa", derive(utoipa::ToSchema))]
#[derive(Clone, Debug, Serialize, Deserialize, PartialEq, Default)]
pub struct WebSearchOptions {
    /// How much context to retrieve for each search.
    pub search_context_size: Option<SearchContextSize>,
    /// Approximate location of the user, used to localize results.
    pub user_location: Option<WebSearchUserLocation>,
    /// Override the description for the search tool.
    pub search_description: Option<String>,
    /// Override the description for the extraction tool.
    pub extract_description: Option<String>,
}
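
// Illustrative example (hypothetical helper, with made-up location values):
// `WebSearchUserLocation` is tagged with `#[serde(tag = "type")]`, so this
// serializes as `{"type":"approximate","approximate":{...}}`.
#[allow(dead_code)]
fn example_web_search_options() -> WebSearchOptions {
    WebSearchOptions {
        search_context_size: Some(SearchContextSize::High),
        user_location: Some(WebSearchUserLocation::Approximate {
            approximate: ApproximateUserLocation {
                city: "Berlin".to_string(),
                country: "DE".to_string(),
                region: "Berlin".to_string(),
                timezone: "Europe/Berlin".to_string(),
            },
        }),
        search_description: None,
        extract_description: None,
    }
}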

#[derive(Clone, Serialize, Deserialize)]
/// A normal request to the `MistralRs`.
/// - `messages`: Messages for the request
/// - `sampling_params`: Sampling parameters for generation
/// - `response`: Object to send the result through
/// - `return_logprobs`: Whether to return logprobs
/// - `is_streaming`: Control whether the request is streaming; if so, chunk responses will be sent
/// - `id`: Request ID
/// - `constraint`: Constraint to use during generation
/// - `suffix`: Suffix to add
/// - `tools`: Tools available in this request
/// - `tool_choice`: Choice of tools
/// - `logits_processors`: Custom logits processors. Order of application:
///     1) Apply penalties from `sampling_params`
///     2) Apply these custom logits processors sequentially
///     3) Apply temperature and softmax
///     4) Sample the next token (topk, topp, minp, etc)
/// - `return_raw_logits`: Return raw logits
/// - `web_search_options`: Options for web search
/// - `model_id`: Optionally target a specific model
pub struct NormalRequest {
    pub messages: RequestMessage,
    pub sampling_params: SamplingParams,
    #[serde(default = "default_responder")]
    #[serde(skip)]
    pub response: Sender<Response>,
    pub return_logprobs: bool,
    pub is_streaming: bool,
    pub id: usize,
    pub constraint: Constraint,
    pub suffix: Option<String>,
    pub tools: Option<Vec<Tool>>,
    pub tool_choice: Option<ToolChoice>,
    #[serde(skip)]
    pub logits_processors: Option<Vec<Arc<dyn CustomLogitsProcessor>>>,
    pub return_raw_logits: bool,
    pub web_search_options: Option<WebSearchOptions>,
    pub model_id: Option<String>,
}

impl NormalRequest {
    pub fn new_simple(
        messages: RequestMessage,
        sampling_params: SamplingParams,
        response: Sender<Response>,
        id: usize,
        tools: Option<Vec<Tool>>,
        tool_choice: Option<ToolChoice>,
    ) -> Self {
        Self {
            messages,
            sampling_params,
            response,
            id,
            tools,
            tool_choice,
            return_logprobs: false,
            is_streaming: false,
            constraint: Constraint::None,
            suffix: None,
            logits_processors: None,
            return_raw_logits: false,
            web_search_options: None,
            model_id: None,
        }
    }
}
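
// Illustrative example (hypothetical helper): building a simple completion
// request with `new_simple`. The caller keeps the receiver half of the
// channel and awaits the `Response` on it.
#[allow(dead_code)]
fn example_normal_request(
    sampling_params: SamplingParams,
) -> (NormalRequest, tokio::sync::mpsc::Receiver<Response>) {
    let (tx, rx) = tokio::sync::mpsc::channel(1);
    let request = NormalRequest::new_simple(
        RequestMessage::Completion {
            text: "Once upon a time".to_string(),
            echo_prompt: false,
            best_of: None,
        },
        sampling_params,
        tx,
        0,
        None,
        None,
    );
    (request, rx)
}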

#[derive(Clone, Serialize, Deserialize)]
/// Request to tokenize some messages or some text.
/// - `add_generation_prompt` only applies when chat messages are provided, not a raw string.
pub struct TokenizationRequest {
    pub text: Either<Vec<IndexMap<String, MessageContent>>, String>,
    pub tools: Option<Vec<Tool>>,
    pub add_generation_prompt: bool,
    pub add_special_tokens: bool,
    pub enable_thinking: Option<bool>,
    #[serde(default = "default_responder")]
    #[serde(skip)]
    pub response: Sender<anyhow::Result<Vec<u32>>>,
}
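
// Illustrative example (hypothetical helper): tokenizing raw text uses the
// `Either::Right` side of `text`; chat messages would use `Either::Left`.
#[allow(dead_code)]
fn example_tokenize(tx: Sender<anyhow::Result<Vec<u32>>>) -> TokenizationRequest {
    TokenizationRequest {
        text: Either::Right("Hello, world!".to_string()),
        tools: None,
        add_generation_prompt: false,
        add_special_tokens: true,
        enable_thinking: None,
        response: tx,
    }
}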

#[derive(Clone, Serialize, Deserialize)]
/// Request to detokenize some text.
pub struct DetokenizationRequest {
    pub tokens: Vec<u32>,
    pub skip_special_tokens: bool,
    #[serde(default = "default_responder")]
    #[serde(skip)]
    pub response: Sender<anyhow::Result<String>>,
}

#[derive(Clone, Serialize, Deserialize)]
/// A request to the Engine, encapsulating the various parameters as well as
/// the `mpsc` response `Sender` used to return the [`Response`].
pub enum Request {
    Normal(Box<NormalRequest>),
    ReIsq(IsqType),
    Tokenize(TokenizationRequest),
    Detokenize(DetokenizationRequest),
    // Sending a terminate request causes the `run` function to return to the thread created in `MistralRs::new`,
    // and then Engine will be dropped.
    Terminate,
    TerminateAllSeqsNextStep,
}
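
// Illustrative example (hypothetical helper): `NormalRequest` is boxed when
// wrapped into `Request`, keeping the enum itself small since the other
// variants carry little data.
#[allow(dead_code)]
fn example_wrap(req: NormalRequest) -> Request {
    Request::Normal(Box::new(req))
}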

impl Debug for Request {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        match self {
            Request::Normal(boxed_req) => {
                let NormalRequest {
                    messages,
                    sampling_params,
                    is_streaming,
                    id,
                    ..
                } = &**boxed_req;
                write!(
                    f,
                    "Request {id} {{ messages: `{messages:?}`, sampling_params: {sampling_params:?}, is_streaming: {is_streaming} }}",
                )
            }
            Request::ReIsq(tp) => {
                write!(f, "Re ISQ Request {tp:?}")
            }
            Request::Tokenize(req) => {
                write!(f, "Tokenization Request {:?}", req.text)
            }
            Request::Detokenize(req) => {
                write!(f, "Detokenization Request {:?}", req.tokens)
            }
            Request::Terminate => write!(f, "Termination Request"),
            Request::TerminateAllSeqsNextStep => write!(f, "Terminate All Seqs Next Step"),
        }
    }
}