//! mistralrs_core/response.rs
1use std::{
2    error::Error,
3    fmt::{Debug, Display},
4    sync::Arc,
5};
6
7use candle_core::Tensor;
8#[cfg(feature = "pyo3_macros")]
9use pyo3::{pyclass, pymethods};
10use serde::Serialize;
11
12use crate::{sampler::TopLogprob, tools::ToolCallResponse};
13
/// Value reported in the `system_fingerprint` field of OpenAI-compatible
/// responses; fixed to "local" for this engine.
pub const SYSTEM_FINGERPRINT: &str = "local";
15
/// Generates a Python `__repr__` method for `$t` when the `pyo3_macros`
/// feature is enabled, delegating to the type's pretty-printed (`{:#?}`)
/// Rust `Debug` representation.
macro_rules! generate_repr {
    ($t:ident) => {
        #[cfg(feature = "pyo3_macros")]
        #[pymethods]
        impl $t {
            fn __repr__(&self) -> String {
                // Pretty Debug output doubles as the Python repr.
                format!("{self:#?}")
            }
        }
    };
}
27
#[cfg_attr(feature = "pyo3_macros", pyclass)]
#[cfg_attr(feature = "pyo3_macros", pyo3(get_all))]
#[derive(Debug, Clone, Serialize)]
/// Chat completion response message (the `message` field of a [`Choice`]).
pub struct ResponseMessage {
    /// Generated text; `None` when the model emitted only tool calls.
    pub content: Option<String>,
    /// Author role of this message (e.g. "assistant").
    pub role: String,
    /// Tool calls requested by the model, if any.
    pub tool_calls: Option<Vec<ToolCallResponse>>,
}

generate_repr!(ResponseMessage);
39
#[cfg_attr(feature = "pyo3_macros", pyclass)]
#[cfg_attr(feature = "pyo3_macros", pyo3(get_all))]
#[derive(Debug, Clone, Serialize)]
/// Delta in content for streaming response.
pub struct Delta {
    /// Incremental text added by this chunk, if any.
    pub content: Option<String>,
    /// Author role of the streamed message.
    pub role: String,
    /// Incremental tool-call data, if any.
    pub tool_calls: Option<Vec<ToolCallResponse>>,
}

generate_repr!(Delta);
51
#[cfg_attr(feature = "pyo3_macros", pyclass)]
#[cfg_attr(feature = "pyo3_macros", pyo3(get_all))]
#[derive(Debug, Clone, Serialize)]
/// A logprob with the top logprobs for this token.
pub struct ResponseLogprob {
    /// The token text.
    pub token: String,
    /// Log probability of this token.
    pub logprob: f32,
    /// Raw bytes of the token, when available (presumably UTF-8 — confirm
    /// with the producer).
    pub bytes: Option<Vec<u8>>,
    /// Most likely alternative tokens and their log probabilities.
    pub top_logprobs: Vec<TopLogprob>,
}

generate_repr!(ResponseLogprob);
64
#[cfg_attr(feature = "pyo3_macros", pyclass)]
#[cfg_attr(feature = "pyo3_macros", pyo3(get_all))]
#[derive(Debug, Clone, Serialize)]
/// Logprobs per token.
pub struct Logprobs {
    /// One [`ResponseLogprob`] per generated token; `None` when absent.
    pub content: Option<Vec<ResponseLogprob>>,
}

generate_repr!(Logprobs);
74
#[cfg_attr(feature = "pyo3_macros", pyclass)]
#[cfg_attr(feature = "pyo3_macros", pyo3(get_all))]
#[derive(Debug, Clone, Serialize)]
/// Chat completion choice.
pub struct Choice {
    /// Reason generation finished for this choice.
    pub finish_reason: String,
    /// Index of this choice within the response.
    pub index: usize,
    /// The generated message.
    pub message: ResponseMessage,
    /// Per-token log probabilities, when present.
    pub logprobs: Option<Logprobs>,
}

generate_repr!(Choice);
87
#[cfg_attr(feature = "pyo3_macros", pyclass)]
#[cfg_attr(feature = "pyo3_macros", pyo3(get_all))]
#[derive(Debug, Clone, Serialize)]
/// Chat completion streaming chunk choice.
pub struct ChunkChoice {
    /// Finish reason; `None` while the stream is still in progress.
    pub finish_reason: Option<String>,
    /// Index of this choice within the response.
    pub index: usize,
    /// Incremental content for this chunk.
    pub delta: Delta,
    /// NOTE(review): a single `ResponseLogprob` here, whereas the
    /// non-streaming [`Choice`] carries a `Logprobs` collection — confirm
    /// this asymmetry is intentional.
    pub logprobs: Option<ResponseLogprob>,
}

generate_repr!(ChunkChoice);
100
#[cfg_attr(feature = "pyo3_macros", pyclass)]
#[cfg_attr(feature = "pyo3_macros", pyo3(get_all))]
#[derive(Debug, Clone, Serialize)]
/// Completion streaming chunk choice.
pub struct CompletionChunkChoice {
    /// Incremental completion text for this chunk.
    pub text: String,
    /// Index of this choice within the response.
    pub index: usize,
    /// Per-token log probability for this chunk, when present.
    pub logprobs: Option<ResponseLogprob>,
    /// Finish reason; `None` while the stream is still in progress.
    pub finish_reason: Option<String>,
}

generate_repr!(CompletionChunkChoice);
113
#[cfg_attr(feature = "pyo3_macros", pyclass)]
#[cfg_attr(feature = "pyo3_macros", pyo3(get_all))]
#[derive(Debug, Clone, Serialize)]
/// OpenAI compatible (superset) usage during a request.
pub struct Usage {
    /// Tokens generated for the completion.
    pub completion_tokens: usize,
    /// Tokens consumed by the prompt.
    pub prompt_tokens: usize,
    /// Sum of prompt and completion tokens.
    pub total_tokens: usize,
    // The timing fields below are extensions beyond the standard OpenAI
    // usage object (hence "superset" above).
    /// Average tokens per second over the whole request.
    pub avg_tok_per_sec: f32,
    /// Average prompt-processing tokens per second.
    pub avg_prompt_tok_per_sec: f32,
    /// Average completion-generation tokens per second.
    pub avg_compl_tok_per_sec: f32,
    /// Total wall time for the request, in seconds.
    pub total_time_sec: f32,
    /// Time spent processing the prompt, in seconds.
    pub total_prompt_time_sec: f32,
    /// Time spent generating the completion, in seconds.
    pub total_completion_time_sec: f32,
}

generate_repr!(Usage);
131
#[cfg_attr(feature = "pyo3_macros", pyclass)]
#[cfg_attr(feature = "pyo3_macros", pyo3(get_all))]
#[derive(Debug, Clone, Serialize)]
/// An OpenAI compatible chat completion response.
pub struct ChatCompletionResponse {
    /// Unique identifier for this response.
    pub id: String,
    /// One entry per requested completion.
    pub choices: Vec<Choice>,
    /// Creation timestamp (presumably Unix seconds — confirm with producer).
    pub created: u64,
    /// Name of the model that produced the response.
    pub model: String,
    /// System fingerprint string (see [`SYSTEM_FINGERPRINT`]).
    pub system_fingerprint: String,
    /// OpenAI-style object type tag.
    pub object: String,
    /// Token accounting and timing for the request.
    pub usage: Usage,
}

generate_repr!(ChatCompletionResponse);
147
#[cfg_attr(feature = "pyo3_macros", pyclass)]
#[cfg_attr(feature = "pyo3_macros", pyo3(get_all))]
#[derive(Debug, Clone, Serialize)]
/// Chat completion streaming request chunk.
pub struct ChatCompletionChunkResponse {
    /// Unique identifier for this response.
    pub id: String,
    /// One entry per requested completion.
    pub choices: Vec<ChunkChoice>,
    /// Creation timestamp. NOTE(review): `u128` here but `u64` in the
    /// non-streaming [`ChatCompletionResponse`] — confirm whether this is
    /// intentional (e.g. different time resolution) or should be unified.
    pub created: u128,
    /// Name of the model that produced the response.
    pub model: String,
    /// System fingerprint string (see [`SYSTEM_FINGERPRINT`]).
    pub system_fingerprint: String,
    /// OpenAI-style object type tag.
    pub object: String,
    /// Usage statistics; optional for streaming chunks.
    pub usage: Option<Usage>,
}

generate_repr!(ChatCompletionChunkResponse);
163
#[cfg_attr(feature = "pyo3_macros", pyclass)]
#[cfg_attr(feature = "pyo3_macros", pyo3(get_all))]
#[derive(Debug, Clone, Serialize)]
/// Completion request choice.
pub struct CompletionChoice {
    /// Reason generation finished for this choice.
    pub finish_reason: String,
    /// Index of this choice within the response.
    pub index: usize,
    /// The generated completion text.
    pub text: String,
    /// Placeholder field carrying no data (unit type); present so the
    /// serialized shape matches the OpenAI completion choice.
    pub logprobs: Option<()>,
}

generate_repr!(CompletionChoice);
176
#[cfg_attr(feature = "pyo3_macros", pyclass)]
#[cfg_attr(feature = "pyo3_macros", pyo3(get_all))]
#[derive(Debug, Clone, Serialize)]
/// An OpenAI compatible completion response.
pub struct CompletionResponse {
    /// Unique identifier for this response.
    pub id: String,
    /// One entry per requested completion.
    pub choices: Vec<CompletionChoice>,
    /// Creation timestamp (presumably Unix seconds — confirm with producer).
    pub created: u64,
    /// Name of the model that produced the response.
    pub model: String,
    /// System fingerprint string (see [`SYSTEM_FINGERPRINT`]).
    pub system_fingerprint: String,
    /// OpenAI-style object type tag.
    pub object: String,
    /// Token accounting and timing for the request.
    pub usage: Usage,
}

generate_repr!(CompletionResponse);
192
#[cfg_attr(feature = "pyo3_macros", pyclass)]
#[cfg_attr(feature = "pyo3_macros", pyo3(get_all))]
#[derive(Debug, Clone, Serialize)]
/// An OpenAI compatible completion streaming chunk response.
pub struct CompletionChunkResponse {
    /// Unique identifier for this response.
    pub id: String,
    /// One entry per requested completion.
    pub choices: Vec<CompletionChunkChoice>,
    /// Creation timestamp. NOTE(review): `u128` here but `u64` in the
    /// non-streaming [`CompletionResponse`] — confirm intentional.
    pub created: u128,
    /// Name of the model that produced the response.
    pub model: String,
    /// System fingerprint string (see [`SYSTEM_FINGERPRINT`]).
    pub system_fingerprint: String,
    /// OpenAI-style object type tag.
    pub object: String,
}

generate_repr!(CompletionChunkResponse);
207
#[cfg_attr(feature = "pyo3_macros", pyclass)]
#[cfg_attr(feature = "pyo3_macros", pyo3(get_all))]
#[derive(Debug, Clone, Serialize)]
/// A single generated image, delivered either by URL or inline.
pub struct ImageChoice {
    /// Location of the generated image, when stored externally.
    pub url: Option<String>,
    /// Inline image data (presumably base64-encoded, per the field name —
    /// confirm with the producer).
    pub b64_json: Option<String>,
}

generate_repr!(ImageChoice);
217
#[cfg_attr(feature = "pyo3_macros", pyclass)]
#[cfg_attr(feature = "pyo3_macros", pyo3(get_all))]
#[derive(Debug, Clone, Serialize)]
/// Response to an image generation request.
pub struct ImageGenerationResponse {
    /// Creation timestamp.
    pub created: u128,
    /// One entry per generated image.
    pub data: Vec<ImageChoice>,
}

generate_repr!(ImageGenerationResponse);
227
/// The response enum contains these groups of variants:
/// - Errors (`-Error` suffix)
/// - Chat (no prefix)
/// - Completion (`Completion-` prefix)
/// - Image generation, speech, raw logits, and embeddings
///
/// Convert to a `Result` form with [`Response::as_result`].
pub enum Response {
    /// Unexpected internal failure during processing.
    InternalError(Box<dyn Error + Send + Sync>),
    /// The request itself was invalid.
    ValidationError(Box<dyn Error + Send + Sync>),
    // Chat
    /// Chat generation failed; carries the error message and the partial
    /// response accumulated before the failure.
    ModelError(String, ChatCompletionResponse),
    /// Completed chat response.
    Done(ChatCompletionResponse),
    /// Streaming chat chunk.
    Chunk(ChatCompletionChunkResponse),
    // Completion
    /// Completion generation failed; message plus partial response.
    CompletionModelError(String, CompletionResponse),
    /// Completed (non-chat) completion response.
    CompletionDone(CompletionResponse),
    /// Streaming completion chunk.
    CompletionChunk(CompletionChunkResponse),
    // Image generation
    ImageGeneration(ImageGenerationResponse),
    // Speech generation
    Speech {
        /// PCM samples; `Arc` so the buffer can be shared without copying.
        pcm: Arc<Vec<f32>>,
        /// Sample rate.
        rate: usize,
        /// Channel count.
        channels: usize,
    },
    // Raw
    Raw {
        /// Raw logits tensors, one per chunk.
        logits_chunks: Vec<Tensor>,
        /// Token ids the logits correspond to.
        tokens: Vec<u32>,
    },
    // Embeddings
    Embeddings {
        /// The embedding vector.
        embeddings: Vec<f32>,
        /// Tokens consumed by the prompt.
        prompt_tokens: usize,
        /// Total tokens processed.
        total_tokens: usize,
    },
}
262
#[derive(Debug, Clone)]
/// Success half of [`Response::as_result`]: the non-error [`Response`]
/// variants, mirrored one-to-one.
pub enum ResponseOk {
    // Chat
    /// Completed chat response.
    Done(ChatCompletionResponse),
    /// Streaming chat chunk.
    Chunk(ChatCompletionChunkResponse),
    // Completion
    /// Completed (non-chat) completion response.
    CompletionDone(CompletionResponse),
    /// Streaming completion chunk.
    CompletionChunk(CompletionChunkResponse),
    // Image generation
    ImageGeneration(ImageGenerationResponse),
    // Speech generation
    Speech {
        /// PCM samples; `Arc` so the buffer can be shared without copying.
        pcm: Arc<Vec<f32>>,
        /// Sample rate.
        rate: usize,
        /// Channel count.
        channels: usize,
    },
    // Raw
    Raw {
        /// Raw logits tensors, one per chunk.
        logits_chunks: Vec<Tensor>,
        /// Token ids the logits correspond to.
        tokens: Vec<u32>,
    },
    // Embeddings
    Embeddings {
        /// The embedding vector.
        embeddings: Vec<f32>,
        /// Tokens consumed by the prompt.
        prompt_tokens: usize,
        /// Total tokens processed.
        total_tokens: usize,
    },
}
291
/// Error half of [`Response::as_result`]: the error-carrying [`Response`]
/// variants, mirrored one-to-one.
pub enum ResponseErr {
    /// Unexpected internal failure during processing.
    InternalError(Box<dyn Error + Send + Sync>),
    /// The request itself was invalid.
    ValidationError(Box<dyn Error + Send + Sync>),
    /// Chat generation failed; message plus the partial response.
    ModelError(String, ChatCompletionResponse),
    /// Completion generation failed; message plus the partial response.
    CompletionModelError(String, CompletionResponse),
}
298
299impl Display for ResponseErr {
300    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
301        match self {
302            Self::InternalError(e) | Self::ValidationError(e) => Display::fmt(e, f),
303            Self::ModelError(e, x) => f
304                .debug_struct("ChatModelError")
305                .field("msg", e)
306                .field("incomplete_response", x)
307                .finish(),
308            Self::CompletionModelError(e, x) => f
309                .debug_struct("CompletionModelError")
310                .field("msg", e)
311                .field("incomplete_response", x)
312                .finish(),
313        }
314    }
315}
316
317impl Debug for ResponseErr {
318    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
319        match self {
320            Self::InternalError(e) | Self::ValidationError(e) => Debug::fmt(e, f),
321            Self::ModelError(e, x) => f
322                .debug_struct("ChatModelError")
323                .field("msg", e)
324                .field("incomplete_response", x)
325                .finish(),
326            Self::CompletionModelError(e, x) => f
327                .debug_struct("CompletionModelError")
328                .field("msg", e)
329                .field("incomplete_response", x)
330                .finish(),
331        }
332    }
333}
334
// Marker impl: `ResponseErr` already provides `Display` + `Debug`, which is
// all `std::error::Error` requires; the default `source()` (None) suffices.
impl std::error::Error for ResponseErr {}
336
337impl Response {
338    /// Convert the response into a result form.
339    pub fn as_result(self) -> Result<ResponseOk, Box<ResponseErr>> {
340        match self {
341            Self::Done(x) => Ok(ResponseOk::Done(x)),
342            Self::Chunk(x) => Ok(ResponseOk::Chunk(x)),
343            Self::CompletionDone(x) => Ok(ResponseOk::CompletionDone(x)),
344            Self::CompletionChunk(x) => Ok(ResponseOk::CompletionChunk(x)),
345            Self::InternalError(e) => Err(Box::new(ResponseErr::InternalError(e))),
346            Self::ValidationError(e) => Err(Box::new(ResponseErr::ValidationError(e))),
347            Self::ModelError(e, x) => Err(Box::new(ResponseErr::ModelError(e, x))),
348            Self::CompletionModelError(e, x) => {
349                Err(Box::new(ResponseErr::CompletionModelError(e, x)))
350            }
351            Self::ImageGeneration(x) => Ok(ResponseOk::ImageGeneration(x)),
352            Self::Speech {
353                pcm,
354                rate,
355                channels,
356            } => Ok(ResponseOk::Speech {
357                pcm,
358                rate,
359                channels,
360            }),
361            Self::Raw {
362                logits_chunks,
363                tokens,
364            } => Ok(ResponseOk::Raw {
365                logits_chunks,
366                tokens,
367            }),
368            Self::Embeddings {
369                embeddings,
370                prompt_tokens,
371                total_tokens,
372            } => Ok(ResponseOk::Embeddings {
373                embeddings,
374                prompt_tokens,
375                total_tokens,
376            }),
377        }
378    }
379}