mistralrs_core/response.rs

use std::{
    error::Error,
    fmt::{Debug, Display},
    sync::Arc,
};

use candle_core::Tensor;
#[cfg(feature = "pyo3_macros")]
use pyo3::{pyclass, pymethods};
use serde::Serialize;

use crate::{sampler::TopLogprob, tools::ToolCallResponse};

pub const SYSTEM_FINGERPRINT: &str = "local";
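
/// Implements a Python `__repr__` (delegating to the type's `Debug`
/// implementation) for the given type when the `pyo3_macros` feature is
/// enabled. For example, `generate_repr!(Usage)` expands to roughly:
///
/// ```ignore
/// #[pymethods]
/// impl Usage {
///     fn __repr__(&self) -> String {
///         format!("{self:#?}")
///     }
/// }
/// ```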
macro_rules! generate_repr {
    ($t:ident) => {
        #[cfg(feature = "pyo3_macros")]
        #[pymethods]
        impl $t {
            fn __repr__(&self) -> String {
                format!("{self:#?}")
            }
        }
    };
}

#[cfg_attr(feature = "pyo3_macros", pyclass)]
#[cfg_attr(feature = "pyo3_macros", pyo3(get_all))]
#[derive(Debug, Clone, Serialize)]
/// Chat completion response message.
pub struct ResponseMessage {
    pub content: Option<String>,
    pub role: String,
    pub tool_calls: Option<Vec<ToolCallResponse>>,
    /// Reasoning/analysis content from Harmony format (separate from final content).
    /// This contains chain-of-thought reasoning that is not intended for end users.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub reasoning_content: Option<String>,
}

generate_repr!(ResponseMessage);
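
// Serialization sketch: because of the `skip_serializing_if` attribute above,
// `reasoning_content` is omitted from the JSON entirely when it is `None`,
// while a `None` `content` or `tool_calls` serializes as `null`. For example,
// a hypothetical assistant message would render as:
//
//     {"content":"Hi!","role":"assistant","tool_calls":null}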

#[cfg_attr(feature = "pyo3_macros", pyclass)]
#[cfg_attr(feature = "pyo3_macros", pyo3(get_all))]
#[derive(Debug, Clone, Serialize)]
/// Delta in content for streaming response.
pub struct Delta {
    pub content: Option<String>,
    pub role: String,
    pub tool_calls: Option<Vec<ToolCallResponse>>,
    /// Reasoning/analysis content delta from Harmony format.
    /// This contains incremental chain-of-thought reasoning.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub reasoning_content: Option<String>,
}

generate_repr!(Delta);

#[cfg_attr(feature = "pyo3_macros", pyclass)]
#[cfg_attr(feature = "pyo3_macros", pyo3(get_all))]
#[derive(Debug, Clone, Serialize)]
/// A logprob with the top logprobs for this token.
pub struct ResponseLogprob {
    pub token: String,
    pub logprob: f32,
    pub bytes: Option<Vec<u8>>,
    pub top_logprobs: Vec<TopLogprob>,
}

generate_repr!(ResponseLogprob);

#[cfg_attr(feature = "pyo3_macros", pyclass)]
#[cfg_attr(feature = "pyo3_macros", pyo3(get_all))]
#[derive(Debug, Clone, Serialize)]
/// Logprobs per token.
pub struct Logprobs {
    pub content: Option<Vec<ResponseLogprob>>,
}

generate_repr!(Logprobs);

#[cfg_attr(feature = "pyo3_macros", pyclass)]
#[cfg_attr(feature = "pyo3_macros", pyo3(get_all))]
#[derive(Debug, Clone, Serialize)]
/// Chat completion choice.
pub struct Choice {
    pub finish_reason: String,
    pub index: usize,
    pub message: ResponseMessage,
    pub logprobs: Option<Logprobs>,
}

generate_repr!(Choice);

#[cfg_attr(feature = "pyo3_macros", pyclass)]
#[cfg_attr(feature = "pyo3_macros", pyo3(get_all))]
#[derive(Debug, Clone, Serialize)]
/// Chat completion streaming chunk choice.
pub struct ChunkChoice {
    pub finish_reason: Option<String>,
    pub index: usize,
    pub delta: Delta,
    pub logprobs: Option<ResponseLogprob>,
}

generate_repr!(ChunkChoice);

#[cfg_attr(feature = "pyo3_macros", pyclass)]
#[cfg_attr(feature = "pyo3_macros", pyo3(get_all))]
#[derive(Debug, Clone, Serialize)]
/// Completion streaming chunk choice.
pub struct CompletionChunkChoice {
    pub text: String,
    pub index: usize,
    pub logprobs: Option<ResponseLogprob>,
    pub finish_reason: Option<String>,
}

generate_repr!(CompletionChunkChoice);

#[cfg_attr(feature = "pyo3_macros", pyclass)]
#[cfg_attr(feature = "pyo3_macros", pyo3(get_all))]
#[derive(Debug, Clone, Serialize)]
/// OpenAI compatible (superset) usage during a request.
pub struct Usage {
    pub completion_tokens: usize,
    pub prompt_tokens: usize,
    pub total_tokens: usize,
    pub avg_tok_per_sec: f32,
    pub avg_prompt_tok_per_sec: f32,
    pub avg_compl_tok_per_sec: f32,
    pub total_time_sec: f32,
    pub total_prompt_time_sec: f32,
    pub total_completion_time_sec: f32,
}

generate_repr!(Usage);
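
// The aggregate fields of `Usage` are related by simple arithmetic. A hedged
// sketch of the invariants one would expect (illustrative only; nothing in
// this module enforces them):
//
//     assert_eq!(usage.total_tokens, usage.prompt_tokens + usage.completion_tokens);
//     // avg_prompt_tok_per_sec ~= prompt_tokens as f32 / total_prompt_time_sec
//     // avg_compl_tok_per_sec  ~= completion_tokens as f32 / total_completion_time_sec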

#[cfg_attr(feature = "pyo3_macros", pyclass)]
#[cfg_attr(feature = "pyo3_macros", pyo3(get_all))]
#[derive(Debug, Clone, Serialize)]
/// An OpenAI compatible chat completion response.
pub struct ChatCompletionResponse {
    pub id: String,
    pub choices: Vec<Choice>,
    pub created: u64,
    pub model: String,
    pub system_fingerprint: String,
    pub object: String,
    pub usage: Usage,
}

generate_repr!(ChatCompletionResponse);
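
// Since these response types derive `Serialize`, they can be rendered as
// OpenAI-style JSON directly. A minimal sketch (assumes the `serde_json`
// crate; `response` is a `ChatCompletionResponse`):
//
//     let body: String = serde_json::to_string_pretty(&response)?;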

#[cfg_attr(feature = "pyo3_macros", pyclass)]
#[cfg_attr(feature = "pyo3_macros", pyo3(get_all))]
#[derive(Debug, Clone, Serialize)]
/// Chat completion streaming response chunk.
pub struct ChatCompletionChunkResponse {
    pub id: String,
    pub choices: Vec<ChunkChoice>,
    pub created: u128,
    pub model: String,
    pub system_fingerprint: String,
    pub object: String,
    pub usage: Option<Usage>,
}

generate_repr!(ChatCompletionChunkResponse);

#[cfg_attr(feature = "pyo3_macros", pyclass)]
#[cfg_attr(feature = "pyo3_macros", pyo3(get_all))]
#[derive(Debug, Clone, Serialize)]
/// Completion response choice.
pub struct CompletionChoice {
    pub finish_reason: String,
    pub index: usize,
    pub text: String,
    pub logprobs: Option<()>,
}

generate_repr!(CompletionChoice);

#[cfg_attr(feature = "pyo3_macros", pyclass)]
#[cfg_attr(feature = "pyo3_macros", pyo3(get_all))]
#[derive(Debug, Clone, Serialize)]
/// An OpenAI compatible completion response.
pub struct CompletionResponse {
    pub id: String,
    pub choices: Vec<CompletionChoice>,
    pub created: u64,
    pub model: String,
    pub system_fingerprint: String,
    pub object: String,
    pub usage: Usage,
}

generate_repr!(CompletionResponse);

#[cfg_attr(feature = "pyo3_macros", pyclass)]
#[cfg_attr(feature = "pyo3_macros", pyo3(get_all))]
#[derive(Debug, Clone, Serialize)]
/// Completion streaming response chunk.
pub struct CompletionChunkResponse {
    pub id: String,
    pub choices: Vec<CompletionChunkChoice>,
    pub created: u128,
    pub model: String,
    pub system_fingerprint: String,
    pub object: String,
}

generate_repr!(CompletionChunkResponse);

#[cfg_attr(feature = "pyo3_macros", pyclass)]
#[cfg_attr(feature = "pyo3_macros", pyo3(get_all))]
#[derive(Debug, Clone, Serialize)]
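/// Image generation choice, carrying the image as a URL and/or base64-encoded data.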
pub struct ImageChoice {
    pub url: Option<String>,
    pub b64_json: Option<String>,
}

generate_repr!(ImageChoice);

#[cfg_attr(feature = "pyo3_macros", pyclass)]
#[cfg_attr(feature = "pyo3_macros", pyo3(get_all))]
#[derive(Debug, Clone, Serialize)]
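/// An OpenAI compatible image generation response.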
pub struct ImageGenerationResponse {
    pub created: u128,
    pub data: Vec<ImageChoice>,
}

generate_repr!(ImageGenerationResponse);

/// The response enum contains these groups of variants:
/// - Errors (`Error` suffix)
/// - Chat responses (no prefix)
/// - Completion responses (`Completion` prefix)
/// - Image generation, speech generation, raw, and embeddings responses
pub enum Response {
    InternalError(Box<dyn Error + Send + Sync>),
    ValidationError(Box<dyn Error + Send + Sync>),
    // Chat
    ModelError(String, ChatCompletionResponse),
    Done(ChatCompletionResponse),
    Chunk(ChatCompletionChunkResponse),
    // Completion
    CompletionModelError(String, CompletionResponse),
    CompletionDone(CompletionResponse),
    CompletionChunk(CompletionChunkResponse),
    // Image generation
    ImageGeneration(ImageGenerationResponse),
    // Speech generation
    Speech {
        pcm: Arc<Vec<f32>>,
        rate: usize,
        channels: usize,
    },
    // Raw
    Raw {
        logits_chunks: Vec<Tensor>,
        tokens: Vec<u32>,
    },
    // Embeddings
    Embeddings {
        embeddings: Vec<f32>,
        prompt_tokens: usize,
        total_tokens: usize,
    },
}

#[derive(Debug, Clone)]
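/// The non-error variants of [Response], as produced by [Response::as_result].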
pub enum ResponseOk {
    // Chat
    Done(ChatCompletionResponse),
    Chunk(ChatCompletionChunkResponse),
    // Completion
    CompletionDone(CompletionResponse),
    CompletionChunk(CompletionChunkResponse),
    // Image generation
    ImageGeneration(ImageGenerationResponse),
    // Speech generation
    Speech {
        pcm: Arc<Vec<f32>>,
        rate: usize,
        channels: usize,
    },
    // Raw
    Raw {
        logits_chunks: Vec<Tensor>,
        tokens: Vec<u32>,
    },
    // Embeddings
    Embeddings {
        embeddings: Vec<f32>,
        prompt_tokens: usize,
        total_tokens: usize,
    },
}
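
/// The error variants of [Response], as produced by [Response::as_result].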
pub enum ResponseErr {
    InternalError(Box<dyn Error + Send + Sync>),
    ValidationError(Box<dyn Error + Send + Sync>),
    ModelError(String, ChatCompletionResponse),
    CompletionModelError(String, CompletionResponse),
}

impl Display for ResponseErr {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        match self {
            Self::InternalError(e) | Self::ValidationError(e) => Display::fmt(e, f),
            Self::ModelError(e, x) => f
                .debug_struct("ChatModelError")
                .field("msg", e)
                .field("incomplete_response", x)
                .finish(),
            Self::CompletionModelError(e, x) => f
                .debug_struct("CompletionModelError")
                .field("msg", e)
                .field("incomplete_response", x)
                .finish(),
        }
    }
}

impl Debug for ResponseErr {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        match self {
            Self::InternalError(e) | Self::ValidationError(e) => Debug::fmt(e, f),
            Self::ModelError(e, x) => f
                .debug_struct("ChatModelError")
                .field("msg", e)
                .field("incomplete_response", x)
                .finish(),
            Self::CompletionModelError(e, x) => f
                .debug_struct("CompletionModelError")
                .field("msg", e)
                .field("incomplete_response", x)
                .finish(),
        }
    }
}

impl std::error::Error for ResponseErr {}

impl Response {
    /// Convert the response into a result form.
    pub fn as_result(self) -> Result<ResponseOk, Box<ResponseErr>> {
        match self {
            Self::Done(x) => Ok(ResponseOk::Done(x)),
            Self::Chunk(x) => Ok(ResponseOk::Chunk(x)),
            Self::CompletionDone(x) => Ok(ResponseOk::CompletionDone(x)),
            Self::CompletionChunk(x) => Ok(ResponseOk::CompletionChunk(x)),
            Self::InternalError(e) => Err(Box::new(ResponseErr::InternalError(e))),
            Self::ValidationError(e) => Err(Box::new(ResponseErr::ValidationError(e))),
            Self::ModelError(e, x) => Err(Box::new(ResponseErr::ModelError(e, x))),
            Self::CompletionModelError(e, x) => {
                Err(Box::new(ResponseErr::CompletionModelError(e, x)))
            }
            Self::ImageGeneration(x) => Ok(ResponseOk::ImageGeneration(x)),
            Self::Speech {
                pcm,
                rate,
                channels,
            } => Ok(ResponseOk::Speech {
                pcm,
                rate,
                channels,
            }),
            Self::Raw {
                logits_chunks,
                tokens,
            } => Ok(ResponseOk::Raw {
                logits_chunks,
                tokens,
            }),
            Self::Embeddings {
                embeddings,
                prompt_tokens,
                total_tokens,
            } => Ok(ResponseOk::Embeddings {
                embeddings,
                prompt_tokens,
                total_tokens,
            }),
        }
    }
}
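
// Typical consumption of a `Response`; a hedged sketch (how the response is
// obtained, e.g. off a channel, is outside this module):
//
//     match response.as_result() {
//         Ok(ResponseOk::Done(done)) => {
//             for choice in &done.choices {
//                 println!("{:?}", choice.message.content);
//             }
//         }
//         Ok(_other) => { /* chunks, images, speech, raw logits, embeddings */ }
//         Err(e) => eprintln!("request failed: {e}"),
//     }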