//! mistralrs_core/request.rs

1use either::Either;
2use indexmap::IndexMap;
3use mistralrs_quant::IsqType;
4use serde::{Deserialize, Serialize};
5use serde_json::Value;
6
7use crate::{
8    response::Response,
9    sampler::SamplingParams,
10    tools::{Tool, ToolChoice},
11    CustomLogitsProcessor, DiffusionGenerationParams,
12};
13use std::{fmt::Debug, sync::Arc};
14use tokio::sync::mpsc::Sender;
15
/// Top-level grammar type from the `llguidance` crate, used by [`Constraint::Llguidance`].
pub type LlguidanceGrammar = llguidance::api::TopLevelGrammar;
17
#[derive(Clone, Serialize, Deserialize)]
/// Control the constraint with llguidance.
pub enum Constraint {
    /// Constrain generation to match a regular expression.
    Regex(String),
    /// Constrain generation with a Lark grammar.
    Lark(String),
    /// Constrain generation to conform to a JSON schema.
    JsonSchema(serde_json::Value),
    /// Constrain generation with a full llguidance top-level grammar.
    Llguidance(LlguidanceGrammar),
    /// No constraint: unrestricted generation.
    None,
}
27
#[derive(Clone, Copy, Debug, Serialize, Deserialize, PartialEq)]
#[cfg_attr(feature = "pyo3_macros", pyo3::pyclass(eq, eq_int))]
/// Image generation response format
pub enum ImageGenerationResponseFormat {
    /// Return the generated image as a URL.
    Url,
    /// Return the generated image as base64-encoded data.
    B64Json,
}
35
/// Content of a single message: either plain text, or a list of structured
/// parts where each part is an ordered map of field name to JSON value.
pub type MessageContent = Either<String, Vec<IndexMap<String, Value>>>;
37
#[derive(Clone, Debug, Serialize, Deserialize)]
/// Message or messages for a [`Request`].
pub enum RequestMessage {
    /// Chat-style request: a list of messages, each an ordered map
    /// (e.g. role/content fields).
    Chat(Vec<IndexMap<String, MessageContent>>),
    /// Raw-text completion request.
    Completion {
        /// Prompt text to complete.
        text: String,
        /// Whether to echo the prompt back in the response.
        echo_prompt: bool,
        /// If set, generate this many completions and return the best.
        best_of: Option<usize>,
    },
    /// Completion request whose prompt is already tokenized.
    CompletionTokens(Vec<u32>),
    /// Chat request that also carries images, for vision models.
    VisionChat {
        // TODO: images are skipped during (de)serialization because
        // `image::DynamicImage` has no serde support; a deserialized
        // VisionChat therefore has an empty `images` vec.
        #[serde(skip)]
        images: Vec<image::DynamicImage>,
        /// Chat messages accompanying the images.
        messages: Vec<IndexMap<String, MessageContent>>,
    },
    /// Request to generate an image with a diffusion model.
    ImageGeneration {
        /// Text prompt describing the desired image.
        prompt: String,
        /// How the generated image should be returned.
        format: ImageGenerationResponseFormat,
        /// Diffusion-specific generation parameters.
        generation_params: DiffusionGenerationParams,
    },
}
59
60fn default_responder<T>() -> Sender<T> {
61    let (sender, _) = tokio::sync::mpsc::channel(1);
62    sender
63}
64
#[cfg_attr(feature = "pyo3_macros", pyo3::pyclass(eq, eq_int))]
#[derive(Clone, Copy, Debug, Serialize, Deserialize, PartialEq, Default)]
/// Amount of context retrieved for web search; serialized as
/// "low" / "medium" / "high". Defaults to `Medium`.
pub enum SearchContextSize {
    #[serde(rename = "low")]
    Low,
    #[default]
    #[serde(rename = "medium")]
    Medium,
    #[serde(rename = "high")]
    High,
}
76
#[cfg_attr(feature = "pyo3_macros", pyo3::pyclass(eq))]
#[derive(Clone, Debug, Serialize, Deserialize, PartialEq)]
/// Approximate user location supplied with a web search request.
pub struct ApproximateUserLocation {
    pub city: String,
    pub country: String,
    pub region: String,
    pub timezone: String,
}
85
#[cfg_attr(feature = "pyo3_macros", pyo3::pyclass(eq))]
#[derive(Clone, Debug, Serialize, Deserialize, PartialEq)]
// Internally tagged: serializes as {"type": "approximate", "approximate": {...}}.
#[serde(tag = "type")]
/// User location for web search; currently only an approximate form exists.
pub enum WebSearchUserLocation {
    #[serde(rename = "approximate")]
    Approximate {
        approximate: ApproximateUserLocation,
    },
}
95
#[cfg_attr(feature = "pyo3_macros", pyo3::pyclass(eq))]
#[derive(Clone, Debug, Serialize, Deserialize, PartialEq, Default)]
/// Options controlling web search for a request; both fields are optional
/// and default to `None`.
pub struct WebSearchOptions {
    /// How much search context to retrieve.
    pub search_context_size: Option<SearchContextSize>,
    /// User location used to localize results.
    pub user_location: Option<WebSearchUserLocation>,
}
102
#[derive(Clone, Serialize, Deserialize)]
/// A normal request to the `MistralRs`.
/// - `messages`: Messages for the request
/// - `sampling_params`: Sampling parameters for generation
/// - `response`: Object to send the result through
/// - `return_logprobs`: Whether to return logprobs
/// - `is_streaming`: Control whether the request is streaming, if so chunk responses will be sent
/// - `id`: Request ID
/// - `constraint`: Constraint to use during generation
/// - `suffix`: Suffix to add
/// - `tools`: Tools available in this request
/// - `tool_choice`: Choice of tools
/// - `logits_processors`: Custom logits processors. Order of application:
///     1) Apply penalties from `sampling_params`
///     2) Apply these custom logits processors sequentially
///     3) Apply temperature and softmax
///     4) Sample the next token (topk, topp, minp, etc)
/// - `return_raw_logits`: Return raw logits.
/// - `web_search_options`: Optional web search configuration
pub struct NormalRequest {
    pub messages: RequestMessage,
    pub sampling_params: SamplingParams,
    // Never serialized; deserialization substitutes a dummy channel whose
    // receiver is already dropped (see `default_responder`).
    #[serde(default = "default_responder")]
    #[serde(skip)]
    pub response: Sender<Response>,
    pub return_logprobs: bool,
    pub is_streaming: bool,
    pub id: usize,
    pub constraint: Constraint,
    pub suffix: Option<String>,
    pub tools: Option<Vec<Tool>>,
    pub tool_choice: Option<ToolChoice>,
    // Skipped: trait objects cannot be (de)serialized; a deserialized
    // request has no custom logits processors.
    #[serde(skip)]
    pub logits_processors: Option<Vec<Arc<dyn CustomLogitsProcessor>>>,
    pub return_raw_logits: bool,
    pub web_search_options: Option<WebSearchOptions>,
}
139
140impl NormalRequest {
141    pub fn new_simple(
142        messages: RequestMessage,
143        sampling_params: SamplingParams,
144        response: Sender<Response>,
145        id: usize,
146        tools: Option<Vec<Tool>>,
147        tool_choice: Option<ToolChoice>,
148    ) -> Self {
149        Self {
150            messages,
151            sampling_params,
152            response,
153            id,
154            tools,
155            tool_choice,
156            return_logprobs: false,
157            is_streaming: false,
158            constraint: Constraint::None,
159            suffix: None,
160            logits_processors: None,
161            return_raw_logits: false,
162            web_search_options: None,
163        }
164    }
165}
166
#[derive(Clone, Serialize, Deserialize)]
/// Request to tokenize some messages or some text.
/// - `add_generation_prompt` is only applicable if chat messages are provided and not a raw string.
pub struct TokenizationRequest {
    /// Either chat messages (left) or a raw string (right) to tokenize.
    pub text: Either<Vec<IndexMap<String, MessageContent>>, String>,
    /// Tools to include when rendering the chat template.
    pub tools: Option<Vec<Tool>>,
    pub add_generation_prompt: bool,
    pub add_special_tokens: bool,
    // Never serialized; deserialization substitutes a dummy channel whose
    // receiver is already dropped (see `default_responder`).
    #[serde(default = "default_responder")]
    #[serde(skip)]
    pub response: Sender<anyhow::Result<Vec<u32>>>,
}
179
#[derive(Clone, Serialize, Deserialize)]
/// Request to detokenize some text.
pub struct DetokenizationRequest {
    /// Token ids to convert back into text.
    pub tokens: Vec<u32>,
    pub skip_special_tokens: bool,
    // Never serialized; deserialization substitutes a dummy channel whose
    // receiver is already dropped (see `default_responder`).
    #[serde(default = "default_responder")]
    #[serde(skip)]
    pub response: Sender<anyhow::Result<String>>,
}
189
#[derive(Clone, Serialize, Deserialize)]
/// A request to the Engine, encapsulating the various parameters as well as
/// the `mpsc` response `Sender` used to return the [`Response`].
pub enum Request {
    /// A generation (chat/completion/vision/image) request.
    Normal(NormalRequest),
    /// Re-apply in-situ quantization with the given ISQ type.
    ReIsq(IsqType),
    /// Tokenize messages or raw text.
    Tokenize(TokenizationRequest),
    /// Detokenize a sequence of token ids.
    Detokenize(DetokenizationRequest),
    // Sending a terminate request causes the `run` function to return to the thread created in `MistralRs::new`,
    // and then Engine will be dropped.
    Terminate,
    /// Terminate all in-flight sequences at the next scheduling step.
    TerminateAllSeqsNextStep,
}
203
204impl Debug for Request {
205    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
206        match self {
207            Request::Normal(NormalRequest {
208                messages,
209                sampling_params,
210                is_streaming,
211                id,
212                ..
213            }) => {
214                write!(
215                    f,
216                    "Request {id} {{ messages: `{messages:?}`, sampling_params: {sampling_params:?}, is_streaming: {is_streaming}}}",
217                )
218            }
219            Request::ReIsq(tp) => {
220                write!(f, "Re ISQ Request {tp:?}",)
221            }
222            Request::Tokenize(req) => {
223                write!(f, "Tokenization Request {:?}", req.text)
224            }
225            Request::Detokenize(req) => {
226                write!(f, "Tokenization Request {:?}", req.tokens)
227            }
228            Request::Terminate => write!(f, "Termination Request"),
229            Request::TerminateAllSeqsNextStep => write!(f, "Terminate All Seqs Next Step"),
230        }
231    }
232}