mistralrs_core/request.rs

use either::Either;
use indexmap::IndexMap;
use mistralrs_quant::IsqType;
use serde::{Deserialize, Serialize};
use serde_json::Value;

use crate::{
    response::Response,
    sampler::SamplingParams,
    tools::{Tool, ToolChoice},
    CustomLogitsProcessor, DiffusionGenerationParams,
};
use std::{fmt::Debug, sync::Arc};
use tokio::sync::mpsc::Sender;

pub type LlguidanceGrammar = llguidance::api::TopLevelGrammar;

#[derive(Clone, Serialize, Deserialize)]
/// Control the constraint with llguidance.
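///
/// A minimal construction sketch (illustrative only):
///
/// ```ignore
/// // Constrain output to a phone-number-like pattern.
/// let regex = Constraint::Regex(r"\d{3}-\d{4}".to_string());
/// // Constrain output to a JSON object with a required string field.
/// let schema = Constraint::JsonSchema(serde_json::json!({
///     "type": "object",
///     "properties": { "name": { "type": "string" } },
///     "required": ["name"]
/// }));
/// ```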
pub enum Constraint {
    Regex(String),
    Lark(String),
    JsonSchema(serde_json::Value),
    Llguidance(LlguidanceGrammar),
    None,
}

#[derive(Clone, Copy, Debug, Serialize, Deserialize, PartialEq)]
#[cfg_attr(feature = "pyo3_macros", pyo3::pyclass(eq, eq_int))]
#[cfg_attr(feature = "utoipa", derive(utoipa::ToSchema))]
/// Image generation response format
pub enum ImageGenerationResponseFormat {
    Url,
    B64Json,
}

pub type MessageContent = Either<String, Vec<IndexMap<String, Value>>>;

#[derive(Clone, Debug, Serialize, Deserialize)]
/// Message or messages for a [`Request`].
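///
/// A sketch of building a single-turn chat request (illustrative; the
/// `role`/`content` keys follow the OpenAI-style message convention):
///
/// ```ignore
/// use either::Either;
/// use indexmap::IndexMap;
///
/// let mut message = IndexMap::new();
/// message.insert("role".to_string(), Either::Left("user".to_string()));
/// message.insert("content".to_string(), Either::Left("Hello!".to_string()));
/// let messages = RequestMessage::Chat {
///     messages: vec![message],
///     enable_thinking: None,
/// };
/// ```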
pub enum RequestMessage {
    Chat {
        messages: Vec<IndexMap<String, MessageContent>>,
        enable_thinking: Option<bool>,
    },
    Completion {
        text: String,
        echo_prompt: bool,
        best_of: Option<usize>,
    },
    CompletionTokens(Vec<u32>),
    VisionChat {
        #[serde(skip)] // TODO: `image::DynamicImage` has no serde support, so images are skipped
        images: Vec<image::DynamicImage>,
        messages: Vec<IndexMap<String, MessageContent>>,
        enable_thinking: Option<bool>,
    },
    ImageGeneration {
        prompt: String,
        format: ImageGenerationResponseFormat,
        generation_params: DiffusionGenerationParams,
    },
    SpeechGeneration {
        prompt: String,
    },
}

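/// Produces a placeholder `Sender` used as the serde default for skipped
/// `response` fields: the receiving half is dropped immediately, so sends on a
/// deserialized request's channel will fail rather than reach a consumer.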
fn default_responder<T>() -> Sender<T> {
    let (sender, _) = tokio::sync::mpsc::channel(1);
    sender
}

#[cfg_attr(feature = "pyo3_macros", pyo3::pyclass(eq, eq_int))]
#[cfg_attr(feature = "utoipa", derive(utoipa::ToSchema))]
#[derive(Clone, Copy, Debug, Serialize, Deserialize, PartialEq, Default)]
pub enum SearchContextSize {
    #[serde(rename = "low")]
    Low,
    #[default]
    #[serde(rename = "medium")]
    Medium,
    #[serde(rename = "high")]
    High,
}

#[cfg_attr(feature = "pyo3_macros", pyo3::pyclass(eq))]
#[cfg_attr(feature = "utoipa", derive(utoipa::ToSchema))]
#[derive(Clone, Debug, Serialize, Deserialize, PartialEq)]
pub struct ApproximateUserLocation {
    pub city: String,
    pub country: String,
    pub region: String,
    pub timezone: String,
}

#[cfg_attr(feature = "pyo3_macros", pyo3::pyclass(eq))]
#[cfg_attr(feature = "utoipa", derive(utoipa::ToSchema))]
#[derive(Clone, Debug, Serialize, Deserialize, PartialEq)]
#[serde(tag = "type")]
pub enum WebSearchUserLocation {
    #[serde(rename = "approximate")]
    Approximate {
        approximate: ApproximateUserLocation,
    },
}

#[cfg_attr(feature = "pyo3_macros", pyo3::pyclass(eq))]
#[cfg_attr(feature = "utoipa", derive(utoipa::ToSchema))]
#[derive(Clone, Debug, Serialize, Deserialize, PartialEq, Default)]
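/// Options controlling web search for a request.
///
/// A construction sketch (illustrative values):
///
/// ```ignore
/// let opts = WebSearchOptions {
///     search_context_size: Some(SearchContextSize::High),
///     user_location: Some(WebSearchUserLocation::Approximate {
///         approximate: ApproximateUserLocation {
///             city: "Berlin".to_string(),
///             country: "DE".to_string(),
///             region: "BE".to_string(),
///             timezone: "Europe/Berlin".to_string(),
///         },
///     }),
/// };
/// ```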
pub struct WebSearchOptions {
    pub search_context_size: Option<SearchContextSize>,
    pub user_location: Option<WebSearchUserLocation>,
}

#[derive(Clone, Serialize, Deserialize)]
/// A normal request to the `MistralRs`.
/// - `messages`: Messages for the request
/// - `sampling_params`: Sampling parameters for generation
/// - `response`: Object to send the result through
/// - `return_logprobs`: Whether to return logprobs
/// - `is_streaming`: Whether the request is streaming; if so, chunk responses will be sent
/// - `id`: Request ID
/// - `constraint`: Constraint to use during generation
/// - `suffix`: Suffix to add
/// - `tools`: Tools available in this request
/// - `tool_choice`: Choice of tools
/// - `logits_processors`: Custom logits processors. Order of application:
///     1) Apply penalties from `sampling_params`
///     2) Apply these custom logits processors sequentially
///     3) Apply temperature and softmax
///     4) Sample the next token (topk, topp, minp, etc)
/// - `return_raw_logits`: Return raw logits.
/// - `web_search_options`: Optional web search configuration
pub struct NormalRequest {
    pub messages: RequestMessage,
    pub sampling_params: SamplingParams,
    #[serde(default = "default_responder")]
    #[serde(skip)]
    pub response: Sender<Response>,
    pub return_logprobs: bool,
    pub is_streaming: bool,
    pub id: usize,
    pub constraint: Constraint,
    pub suffix: Option<String>,
    pub tools: Option<Vec<Tool>>,
    pub tool_choice: Option<ToolChoice>,
    #[serde(skip)]
    pub logits_processors: Option<Vec<Arc<dyn CustomLogitsProcessor>>>,
    pub return_raw_logits: bool,
    pub web_search_options: Option<WebSearchOptions>,
}

impl NormalRequest {
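    /// Construct a [`NormalRequest`] with default settings: no streaming, no
    /// logprobs, no constraint, and no logits processors.
    ///
    /// A usage sketch (illustrative; assumes `SamplingParams` provides a
    /// default constructor):
    ///
    /// ```ignore
    /// let (tx, mut rx) = tokio::sync::mpsc::channel(1);
    /// let request = NormalRequest::new_simple(
    ///     RequestMessage::Completion {
    ///         text: "Say hello".to_string(),
    ///         echo_prompt: false,
    ///         best_of: None,
    ///     },
    ///     SamplingParams::default(),
    ///     tx,
    ///     0,    // request id
    ///     None, // tools
    ///     None, // tool choice
    /// );
    /// ```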
    pub fn new_simple(
        messages: RequestMessage,
        sampling_params: SamplingParams,
        response: Sender<Response>,
        id: usize,
        tools: Option<Vec<Tool>>,
        tool_choice: Option<ToolChoice>,
    ) -> Self {
        Self {
            messages,
            sampling_params,
            response,
            id,
            tools,
            tool_choice,
            return_logprobs: false,
            is_streaming: false,
            constraint: Constraint::None,
            suffix: None,
            logits_processors: None,
            return_raw_logits: false,
            web_search_options: None,
        }
    }
}

#[derive(Clone, Serialize, Deserialize)]
/// Request to tokenize some messages or some text.
/// - `add_generation_prompt` is only applicable if chat messages are provided and not a raw string.
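///
/// A sketch of tokenizing a raw string (illustrative only):
///
/// ```ignore
/// let (tx, mut rx) = tokio::sync::mpsc::channel(1);
/// let req = TokenizationRequest {
///     text: Either::Right("Hello, world!".to_string()),
///     tools: None,
///     add_generation_prompt: false,
///     add_special_tokens: true,
///     enable_thinking: None,
///     response: tx,
/// };
/// // The engine replies with an `anyhow::Result<Vec<u32>>` on `rx`.
/// ```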
pub struct TokenizationRequest {
    pub text: Either<Vec<IndexMap<String, MessageContent>>, String>,
    pub tools: Option<Vec<Tool>>,
    pub add_generation_prompt: bool,
    pub add_special_tokens: bool,
    pub enable_thinking: Option<bool>,
    #[serde(default = "default_responder")]
    #[serde(skip)]
    pub response: Sender<anyhow::Result<Vec<u32>>>,
}

#[derive(Clone, Serialize, Deserialize)]
/// Request to detokenize some text.
pub struct DetokenizationRequest {
    pub tokens: Vec<u32>,
    pub skip_special_tokens: bool,
    #[serde(default = "default_responder")]
    #[serde(skip)]
    pub response: Sender<anyhow::Result<String>>,
}

#[derive(Clone, Serialize, Deserialize)]
/// A request to the Engine, encapsulating the various parameters as well as
/// the `mpsc` response `Sender` used to return the [`Response`].
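///
/// A dispatch sketch (illustrative; `normal_request` is a hypothetical
/// previously built [`NormalRequest`]):
///
/// ```ignore
/// let request = Request::Normal(Box::new(normal_request));
/// ```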
pub enum Request {
    Normal(Box<NormalRequest>),
    ReIsq(IsqType),
    Tokenize(TokenizationRequest),
    Detokenize(DetokenizationRequest),
    // Sending a terminate request causes the `run` function to return to the thread created in
    // `MistralRs::new`, after which the `Engine` is dropped.
    Terminate,
    TerminateAllSeqsNextStep,
}

impl Debug for Request {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        match self {
            Request::Normal(boxed_req) => {
                let NormalRequest {
                    messages,
                    sampling_params,
                    is_streaming,
                    id,
                    ..
                } = &**boxed_req;
                write!(
                    f,
                    "Request {id} {{ messages: `{messages:?}`, sampling_params: {sampling_params:?}, is_streaming: {is_streaming} }}",
                )
            }
            Request::ReIsq(tp) => {
                write!(f, "Re ISQ Request {tp:?}")
            }
            Request::Tokenize(req) => {
                write!(f, "Tokenization Request {:?}", req.text)
            }
            Request::Detokenize(req) => {
                write!(f, "Detokenization Request {:?}", req.tokens)
            }
            Request::Terminate => write!(f, "Termination Request"),
            Request::TerminateAllSeqsNextStep => write!(f, "Terminate All Seqs Next Step"),
        }
    }
}