mistralrs_core/
response.rs

1use std::{
2    error::Error,
3    fmt::{Debug, Display},
4};
5
6use candle_core::Tensor;
7#[cfg(feature = "pyo3_macros")]
8use pyo3::{pyclass, pymethods};
9use serde::Serialize;
10
11use crate::{sampler::TopLogprob, tools::ToolCallResponse};
12
/// Fixed fingerprint string (`"local"`).
///
/// Presumably the value written into the `system_fingerprint` field of the
/// response types in this module; confirm at the sites that build responses.
pub const SYSTEM_FINGERPRINT: &str = "local";
14
/// Generates a Python `__repr__` method for the named type.
///
/// The expansion is gated on the `pyo3_macros` feature; without that feature
/// the macro expands to nothing. The produced `__repr__` returns the type's
/// pretty-printed (`{:#?}`) `Debug` representation, so the type must
/// implement `Debug`.
macro_rules! generate_repr {
    ($ty:ident) => {
        #[cfg(feature = "pyo3_macros")]
        #[pymethods]
        impl $ty {
            fn __repr__(&self) -> String {
                format!("{:#?}", self)
            }
        }
    };
}
26
#[cfg_attr(feature = "pyo3_macros", pyclass)]
#[cfg_attr(feature = "pyo3_macros", pyo3(get_all))]
#[derive(Debug, Clone, Serialize)]
/// Chat completion response message.
///
/// Serializable with `serde`. With the `pyo3_macros` feature enabled the type
/// is additionally annotated as a Python class (`pyclass`, `get_all`).
pub struct ResponseMessage {
    /// Text content of the message, if any.
    pub content: Option<String>,
    /// Role attached to the message (the concrete values are set by the producer).
    pub role: String,
    /// Tool calls carried by the message, if any.
    pub tool_calls: Option<Vec<ToolCallResponse>>,
}

// Adds a Python `__repr__` (pretty `Debug` output) when `pyo3_macros` is enabled.
generate_repr!(ResponseMessage);
38
#[cfg_attr(feature = "pyo3_macros", pyclass)]
#[cfg_attr(feature = "pyo3_macros", pyo3(get_all))]
#[derive(Debug, Clone, Serialize)]
/// Delta in content for streaming response.
///
/// Has the same fields as `ResponseMessage`; used as the payload of a
/// `ChunkChoice`.
pub struct Delta {
    /// Text content of this delta, if any.
    pub content: Option<String>,
    /// Role attached to the delta (the concrete values are set by the producer).
    pub role: String,
    /// Tool calls carried by this delta, if any.
    pub tool_calls: Option<Vec<ToolCallResponse>>,
}

// Adds a Python `__repr__` (pretty `Debug` output) when `pyo3_macros` is enabled.
generate_repr!(Delta);
50
#[cfg_attr(feature = "pyo3_macros", pyclass)]
#[cfg_attr(feature = "pyo3_macros", pyo3(get_all))]
#[derive(Debug, Clone, Serialize)]
/// A logprob with the top logprobs for this token.
pub struct ResponseLogprob {
    /// The token, as a string.
    pub token: String,
    /// Log probability associated with `token`.
    pub logprob: f32,
    /// Raw bytes for the token, if provided by the producer.
    pub bytes: Option<Vec<u8>>,
    /// Top alternative logprobs reported alongside this token.
    pub top_logprobs: Vec<TopLogprob>,
}

// Adds a Python `__repr__` (pretty `Debug` output) when `pyo3_macros` is enabled.
generate_repr!(ResponseLogprob);
63
#[cfg_attr(feature = "pyo3_macros", pyclass)]
#[cfg_attr(feature = "pyo3_macros", pyo3(get_all))]
#[derive(Debug, Clone, Serialize)]
/// Logprobs per token.
pub struct Logprobs {
    /// One `ResponseLogprob` entry per token, or `None` when absent.
    pub content: Option<Vec<ResponseLogprob>>,
}

// Adds a Python `__repr__` (pretty `Debug` output) when `pyo3_macros` is enabled.
generate_repr!(Logprobs);
73
#[cfg_attr(feature = "pyo3_macros", pyclass)]
#[cfg_attr(feature = "pyo3_macros", pyo3(get_all))]
#[derive(Debug, Clone, Serialize)]
/// Chat completion choice.
pub struct Choice {
    /// Why generation stopped for this choice (values are set by the producer).
    pub finish_reason: String,
    /// Index of this choice within the response's `choices` list.
    pub index: usize,
    /// The generated message.
    pub message: ResponseMessage,
    /// Per-token logprobs, if present.
    pub logprobs: Option<Logprobs>,
}

// Adds a Python `__repr__` (pretty `Debug` output) when `pyo3_macros` is enabled.
generate_repr!(Choice);
86
#[cfg_attr(feature = "pyo3_macros", pyclass)]
#[cfg_attr(feature = "pyo3_macros", pyo3(get_all))]
#[derive(Debug, Clone, Serialize)]
/// Chat completion streaming chunk choice.
pub struct ChunkChoice {
    /// Why generation stopped, if set for this chunk.
    /// NOTE(review): presumably `None` until the final chunk; confirm with the producer.
    pub finish_reason: Option<String>,
    /// Index of this choice within the chunk's `choices` list.
    pub index: usize,
    /// The incremental message content for this chunk.
    pub delta: Delta,
    /// Logprob for this chunk, if present (a single `ResponseLogprob`, not a list).
    pub logprobs: Option<ResponseLogprob>,
}

// Adds a Python `__repr__` (pretty `Debug` output) when `pyo3_macros` is enabled.
generate_repr!(ChunkChoice);
99
#[cfg_attr(feature = "pyo3_macros", pyclass)]
#[cfg_attr(feature = "pyo3_macros", pyo3(get_all))]
#[derive(Debug, Clone, Serialize)]
/// Completion streaming chunk choice.
///
/// Element type of `CompletionChunkResponse::choices`; carries plain `text`
/// rather than a chat `Delta`.
pub struct CompletionChunkChoice {
    /// Text carried by this chunk.
    pub text: String,
    /// Index of this choice within the chunk's `choices` list.
    pub index: usize,
    /// Logprob for this chunk, if present.
    pub logprobs: Option<ResponseLogprob>,
    /// Why generation stopped, if set for this chunk.
    pub finish_reason: Option<String>,
}

// Adds a Python `__repr__` (pretty `Debug` output) when `pyo3_macros` is enabled.
generate_repr!(CompletionChunkChoice);
112
#[cfg_attr(feature = "pyo3_macros", pyclass)]
#[cfg_attr(feature = "pyo3_macros", pyo3(get_all))]
#[derive(Debug, Clone, Serialize)]
/// OpenAI compatible (superset) usage during a request.
///
/// The three token counts mirror the OpenAI usage object; the remaining
/// throughput and timing fields are the "superset" part. All values are
/// computed by the producer of this struct, not in this module.
pub struct Usage {
    /// Number of completion tokens.
    pub completion_tokens: usize,
    /// Number of prompt tokens.
    pub prompt_tokens: usize,
    /// Total number of tokens (presumably prompt + completion; confirm with the producer).
    pub total_tokens: usize,
    /// Average tokens per second over the request.
    pub avg_tok_per_sec: f32,
    /// Average prompt tokens per second.
    pub avg_prompt_tok_per_sec: f32,
    /// Average completion tokens per second.
    pub avg_compl_tok_per_sec: f32,
    /// Total time for the request, in seconds.
    pub total_time_sec: f32,
    /// Total time spent on the prompt, in seconds.
    pub total_prompt_time_sec: f32,
    /// Total time spent on the completion, in seconds.
    pub total_completion_time_sec: f32,
}

// Adds a Python `__repr__` (pretty `Debug` output) when `pyo3_macros` is enabled.
generate_repr!(Usage);
130
#[cfg_attr(feature = "pyo3_macros", pyclass)]
#[cfg_attr(feature = "pyo3_macros", pyo3(get_all))]
#[derive(Debug, Clone, Serialize)]
/// An OpenAI compatible chat completion response.
pub struct ChatCompletionResponse {
    /// Identifier of the response.
    pub id: String,
    /// Generated choices.
    pub choices: Vec<Choice>,
    /// Creation timestamp.
    /// NOTE(review): unit is not visible here (presumably Unix seconds); confirm with the producer.
    pub created: u64,
    /// Name of the model that produced the response.
    pub model: String,
    /// System fingerprint string.
    pub system_fingerprint: String,
    /// Object type tag (the concrete value is set by the producer).
    pub object: String,
    /// Token usage and timing statistics for the request.
    pub usage: Usage,
}

// Adds a Python `__repr__` (pretty `Debug` output) when `pyo3_macros` is enabled.
generate_repr!(ChatCompletionResponse);
146
#[cfg_attr(feature = "pyo3_macros", pyclass)]
#[cfg_attr(feature = "pyo3_macros", pyo3(get_all))]
#[derive(Debug, Clone, Serialize)]
/// Chat completion streaming response chunk.
pub struct ChatCompletionChunkResponse {
    /// Identifier of the response this chunk belongs to.
    pub id: String,
    /// Choices carried by this chunk.
    pub choices: Vec<ChunkChoice>,
    /// Creation timestamp.
    /// NOTE(review): this is `u128` while `ChatCompletionResponse::created` is `u64`;
    /// the unit is not visible here. Confirm with the producer.
    pub created: u128,
    /// Name of the model that produced the chunk.
    pub model: String,
    /// System fingerprint string.
    pub system_fingerprint: String,
    /// Object type tag (the concrete value is set by the producer).
    pub object: String,
    /// Usage statistics, if attached to this chunk.
    pub usage: Option<Usage>,
}

// Adds a Python `__repr__` (pretty `Debug` output) when `pyo3_macros` is enabled.
generate_repr!(ChatCompletionChunkResponse);
162
#[cfg_attr(feature = "pyo3_macros", pyclass)]
#[cfg_attr(feature = "pyo3_macros", pyo3(get_all))]
#[derive(Debug, Clone, Serialize)]
/// Completion request choice.
pub struct CompletionChoice {
    /// Why generation stopped for this choice (values are set by the producer).
    pub finish_reason: String,
    /// Index of this choice within the response's `choices` list.
    pub index: usize,
    /// The generated text.
    pub text: String,
    /// Placeholder: the payload type is `()`, so no logprob data can be carried here.
    pub logprobs: Option<()>,
}

// Adds a Python `__repr__` (pretty `Debug` output) when `pyo3_macros` is enabled.
generate_repr!(CompletionChoice);
175
#[cfg_attr(feature = "pyo3_macros", pyclass)]
#[cfg_attr(feature = "pyo3_macros", pyo3(get_all))]
#[derive(Debug, Clone, Serialize)]
/// An OpenAI compatible completion response.
pub struct CompletionResponse {
    /// Identifier of the response.
    pub id: String,
    /// Generated choices.
    pub choices: Vec<CompletionChoice>,
    /// Creation timestamp.
    /// NOTE(review): unit is not visible here (presumably Unix seconds); confirm with the producer.
    pub created: u64,
    /// Name of the model that produced the response.
    pub model: String,
    /// System fingerprint string.
    pub system_fingerprint: String,
    /// Object type tag (the concrete value is set by the producer).
    pub object: String,
    /// Token usage and timing statistics for the request.
    pub usage: Usage,
}

// Adds a Python `__repr__` (pretty `Debug` output) when `pyo3_macros` is enabled.
generate_repr!(CompletionResponse);
191
#[cfg_attr(feature = "pyo3_macros", pyclass)]
#[cfg_attr(feature = "pyo3_macros", pyo3(get_all))]
#[derive(Debug, Clone, Serialize)]
/// Completion streaming response chunk.
///
/// Unlike `ChatCompletionChunkResponse`, this type has no `usage` field.
pub struct CompletionChunkResponse {
    /// Identifier of the response this chunk belongs to.
    pub id: String,
    /// Choices carried by this chunk.
    pub choices: Vec<CompletionChunkChoice>,
    /// Creation timestamp.
    /// NOTE(review): this is `u128` while `CompletionResponse::created` is `u64`;
    /// the unit is not visible here. Confirm with the producer.
    pub created: u128,
    /// Name of the model that produced the chunk.
    pub model: String,
    /// System fingerprint string.
    pub system_fingerprint: String,
    /// Object type tag (the concrete value is set by the producer).
    pub object: String,
}

// Adds a Python `__repr__` (pretty `Debug` output) when `pyo3_macros` is enabled.
generate_repr!(CompletionChunkResponse);
206
#[cfg_attr(feature = "pyo3_macros", pyclass)]
#[cfg_attr(feature = "pyo3_macros", pyo3(get_all))]
#[derive(Debug, Clone, Serialize)]
/// A single generated image, as an element of `ImageGenerationResponse::data`.
///
/// Both fields are optional; which one is populated is decided by the
/// producer (presumably based on the requested response format; confirm).
pub struct ImageChoice {
    /// URL of the image, if provided.
    pub url: Option<String>,
    /// Image payload as a string, if provided (base64-encoded, judging by the field name).
    pub b64_json: Option<String>,
}

// Adds a Python `__repr__` (pretty `Debug` output) when `pyo3_macros` is enabled.
generate_repr!(ImageChoice);
216
#[cfg_attr(feature = "pyo3_macros", pyclass)]
#[cfg_attr(feature = "pyo3_macros", pyo3(get_all))]
#[derive(Debug, Clone, Serialize)]
/// Image generation response.
pub struct ImageGenerationResponse {
    /// Creation timestamp.
    /// NOTE(review): unit is not visible here; confirm with the producer.
    pub created: u128,
    /// The generated images.
    pub data: Vec<ImageChoice>,
}

// Adds a Python `__repr__` (pretty `Debug` output) when `pyo3_macros` is enabled.
generate_repr!(ImageGenerationResponse);
226
/// The response enum contains the following groups of variants:
/// - Error (-Error suffix)
/// - Chat (no prefix)
/// - Completion (Completion- prefix)
/// - Image generation (`ImageGeneration`)
/// - Raw (`Raw`)
///
/// Use [`Response::as_result`] to split a value into its success
/// ([`ResponseOk`]) or failure ([`ResponseErr`]) form.
pub enum Response {
    /// Internal failure, carrying the underlying error.
    InternalError(Box<dyn Error + Send + Sync>),
    /// Validation failure, carrying the underlying error.
    ValidationError(Box<dyn Error + Send + Sync>),
    // Chat
    /// Chat model failure: an error message plus the incomplete response.
    ModelError(String, ChatCompletionResponse),
    /// Finished chat completion.
    Done(ChatCompletionResponse),
    /// One streamed chat completion chunk.
    Chunk(ChatCompletionChunkResponse),
    // Completion
    /// Completion model failure: an error message plus the incomplete response.
    CompletionModelError(String, CompletionResponse),
    /// Finished completion.
    CompletionDone(CompletionResponse),
    /// One streamed completion chunk.
    CompletionChunk(CompletionChunkResponse),
    // Image generation
    /// Finished image generation.
    ImageGeneration(ImageGenerationResponse),
    // Raw
    /// Raw output: logits tensors together with token ids.
    Raw {
        /// Logits, split into chunks (chunking scheme is decided by the producer).
        logits_chunks: Vec<Tensor>,
        /// Token ids.
        tokens: Vec<u32>,
    },
}
250
/// The successful variants of [`Response`], as produced by
/// [`Response::as_result`].
#[derive(Debug, Clone)]
pub enum ResponseOk {
    // Chat
    /// Finished chat completion.
    Done(ChatCompletionResponse),
    /// One streamed chat completion chunk.
    Chunk(ChatCompletionChunkResponse),
    // Completion
    /// Finished completion.
    CompletionDone(CompletionResponse),
    /// One streamed completion chunk.
    CompletionChunk(CompletionChunkResponse),
    // Image generation
    /// Finished image generation.
    ImageGeneration(ImageGenerationResponse),
    // Raw
    /// Raw output: logits tensors together with token ids.
    Raw {
        /// Logits, split into chunks (chunking scheme is decided by the producer).
        logits_chunks: Vec<Tensor>,
        /// Token ids.
        tokens: Vec<u32>,
    },
}
267
/// The failure variants of [`Response`], as produced by
/// [`Response::as_result`].
///
/// `Display`, `Debug` and `std::error::Error` are implemented by hand below
/// rather than derived.
pub enum ResponseErr {
    /// Internal failure, carrying the underlying error.
    InternalError(Box<dyn Error + Send + Sync>),
    /// Validation failure, carrying the underlying error.
    ValidationError(Box<dyn Error + Send + Sync>),
    /// Chat model failure: an error message plus the incomplete response.
    ModelError(String, ChatCompletionResponse),
    /// Completion model failure: an error message plus the incomplete response.
    CompletionModelError(String, CompletionResponse),
}
274
275impl Display for ResponseErr {
276    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
277        match self {
278            Self::InternalError(e) | Self::ValidationError(e) => Display::fmt(e, f),
279            Self::ModelError(e, x) => f
280                .debug_struct("ChatModelError")
281                .field("msg", e)
282                .field("incomplete_response", x)
283                .finish(),
284            Self::CompletionModelError(e, x) => f
285                .debug_struct("CompletionModelError")
286                .field("msg", e)
287                .field("incomplete_response", x)
288                .finish(),
289        }
290    }
291}
292
293impl Debug for ResponseErr {
294    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
295        match self {
296            Self::InternalError(e) | Self::ValidationError(e) => Debug::fmt(e, f),
297            Self::ModelError(e, x) => f
298                .debug_struct("ChatModelError")
299                .field("msg", e)
300                .field("incomplete_response", x)
301                .finish(),
302            Self::CompletionModelError(e, x) => f
303                .debug_struct("CompletionModelError")
304                .field("msg", e)
305                .field("incomplete_response", x)
306                .finish(),
307        }
308    }
309}
310
// Marks `ResponseErr` as a standard error type. All trait methods keep their
// defaults, so `source()` is not overridden even for the variants that wrap
// another error.
impl std::error::Error for ResponseErr {}
312
313impl Response {
314    /// Convert the response into a result form.
315    pub fn as_result(self) -> Result<ResponseOk, Box<ResponseErr>> {
316        match self {
317            Self::Done(x) => Ok(ResponseOk::Done(x)),
318            Self::Chunk(x) => Ok(ResponseOk::Chunk(x)),
319            Self::CompletionDone(x) => Ok(ResponseOk::CompletionDone(x)),
320            Self::CompletionChunk(x) => Ok(ResponseOk::CompletionChunk(x)),
321            Self::InternalError(e) => Err(Box::new(ResponseErr::InternalError(e))),
322            Self::ValidationError(e) => Err(Box::new(ResponseErr::ValidationError(e))),
323            Self::ModelError(e, x) => Err(Box::new(ResponseErr::ModelError(e, x))),
324            Self::CompletionModelError(e, x) => {
325                Err(Box::new(ResponseErr::CompletionModelError(e, x)))
326            }
327            Self::ImageGeneration(x) => Ok(ResponseOk::ImageGeneration(x)),
328            Self::Raw {
329                logits_chunks,
330                tokens,
331            } => Ok(ResponseOk::Raw {
332                logits_chunks,
333                tokens,
334            }),
335        }
336    }
337}