// mistralrs_core/response.rs

use std::{
    error::Error,
    fmt::{Debug, Display},
    sync::Arc,
};

use candle_core::Tensor;
#[cfg(feature = "pyo3_macros")]
use pyo3::{pyclass, pymethods};
use serde::Serialize;

use crate::{sampler::TopLogprob, tools::ToolCallResponse};

pub const SYSTEM_FINGERPRINT: &str = "local";

/// Implements a Python `__repr__` for the wrapped type when the `pyo3_macros`
/// feature is enabled, delegating to the type's `Debug` formatting.
macro_rules! generate_repr {
    ($t:ident) => {
        #[cfg(feature = "pyo3_macros")]
        #[pymethods]
        impl $t {
            fn __repr__(&self) -> String {
                format!("{self:#?}")
            }
        }
    };
}
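
// For reference, `generate_repr!(ResponseMessage)` expands (only when the
// `pyo3_macros` feature is active) to roughly:
//
//     #[cfg(feature = "pyo3_macros")]
//     #[pymethods]
//     impl ResponseMessage {
//         fn __repr__(&self) -> String {
//             format!("{self:#?}")
//         }
//     }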

#[cfg_attr(feature = "pyo3_macros", pyclass)]
#[cfg_attr(feature = "pyo3_macros", pyo3(get_all))]
#[derive(Debug, Clone, Serialize)]
/// Chat completion response message.
pub struct ResponseMessage {
    pub content: Option<String>,
    pub role: String,
    pub tool_calls: Option<Vec<ToolCallResponse>>,
}

generate_repr!(ResponseMessage);

#[cfg_attr(feature = "pyo3_macros", pyclass)]
#[cfg_attr(feature = "pyo3_macros", pyo3(get_all))]
#[derive(Debug, Clone, Serialize)]
/// Delta in content for streaming response.
pub struct Delta {
    pub content: Option<String>,
    pub role: String,
    pub tool_calls: Option<Vec<ToolCallResponse>>,
}

generate_repr!(Delta);

#[cfg_attr(feature = "pyo3_macros", pyclass)]
#[cfg_attr(feature = "pyo3_macros", pyo3(get_all))]
#[derive(Debug, Clone, Serialize)]
/// A logprob with the top logprobs for this token.
pub struct ResponseLogprob {
    pub token: String,
    pub logprob: f32,
    pub bytes: Option<Vec<u8>>,
    pub top_logprobs: Vec<TopLogprob>,
}

generate_repr!(ResponseLogprob);

#[cfg_attr(feature = "pyo3_macros", pyclass)]
#[cfg_attr(feature = "pyo3_macros", pyo3(get_all))]
#[derive(Debug, Clone, Serialize)]
/// Logprobs per token.
pub struct Logprobs {
    pub content: Option<Vec<ResponseLogprob>>,
}

generate_repr!(Logprobs);

#[cfg_attr(feature = "pyo3_macros", pyclass)]
#[cfg_attr(feature = "pyo3_macros", pyo3(get_all))]
#[derive(Debug, Clone, Serialize)]
/// Chat completion choice.
pub struct Choice {
    pub finish_reason: String,
    pub index: usize,
    pub message: ResponseMessage,
    pub logprobs: Option<Logprobs>,
}

generate_repr!(Choice);

#[cfg_attr(feature = "pyo3_macros", pyclass)]
#[cfg_attr(feature = "pyo3_macros", pyo3(get_all))]
#[derive(Debug, Clone, Serialize)]
/// Chat completion streaming chunk choice.
pub struct ChunkChoice {
    pub finish_reason: Option<String>,
    pub index: usize,
    pub delta: Delta,
    pub logprobs: Option<ResponseLogprob>,
}

generate_repr!(ChunkChoice);

#[cfg_attr(feature = "pyo3_macros", pyclass)]
#[cfg_attr(feature = "pyo3_macros", pyo3(get_all))]
#[derive(Debug, Clone, Serialize)]
/// Completion streaming chunk choice.
pub struct CompletionChunkChoice {
    pub text: String,
    pub index: usize,
    pub logprobs: Option<ResponseLogprob>,
    pub finish_reason: Option<String>,
}

generate_repr!(CompletionChunkChoice);

#[cfg_attr(feature = "pyo3_macros", pyclass)]
#[cfg_attr(feature = "pyo3_macros", pyo3(get_all))]
#[derive(Debug, Clone, Serialize)]
/// OpenAI compatible (superset) usage during a request. The throughput and timing
/// fields extend the standard OpenAI token counts.
pub struct Usage {
    pub completion_tokens: usize,
    pub prompt_tokens: usize,
    pub total_tokens: usize,
    pub avg_tok_per_sec: f32,
    pub avg_prompt_tok_per_sec: f32,
    pub avg_compl_tok_per_sec: f32,
    pub total_time_sec: f32,
    pub total_prompt_time_sec: f32,
    pub total_completion_time_sec: f32,
}

generate_repr!(Usage);

#[cfg_attr(feature = "pyo3_macros", pyclass)]
#[cfg_attr(feature = "pyo3_macros", pyo3(get_all))]
#[derive(Debug, Clone, Serialize)]
/// An OpenAI compatible chat completion response.
pub struct ChatCompletionResponse {
    pub id: String,
    pub choices: Vec<Choice>,
    pub created: u64,
    pub model: String,
    pub system_fingerprint: String,
    pub object: String,
    pub usage: Usage,
}

generate_repr!(ChatCompletionResponse);
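
// For orientation, a `ChatCompletionResponse` serialized with serde takes roughly the
// familiar OpenAI shape sketched below; all concrete values are illustrative only:
//
//     {
//         "id": "0",
//         "choices": [{
//             "finish_reason": "stop",
//             "index": 0,
//             "message": { "content": "Hello!", "role": "assistant", "tool_calls": null },
//             "logprobs": null
//         }],
//         "created": 1700000000,
//         "model": "mistral",
//         "system_fingerprint": "local",
//         "object": "chat.completion",
//         "usage": { "completion_tokens": 2, "prompt_tokens": 9, "total_tokens": 11, ... }
//     }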

#[cfg_attr(feature = "pyo3_macros", pyclass)]
#[cfg_attr(feature = "pyo3_macros", pyo3(get_all))]
#[derive(Debug, Clone, Serialize)]
/// Chat completion streaming request chunk.
pub struct ChatCompletionChunkResponse {
    pub id: String,
    pub choices: Vec<ChunkChoice>,
    pub created: u128,
    pub model: String,
    pub system_fingerprint: String,
    pub object: String,
    pub usage: Option<Usage>,
}

generate_repr!(ChatCompletionChunkResponse);

#[cfg_attr(feature = "pyo3_macros", pyclass)]
#[cfg_attr(feature = "pyo3_macros", pyo3(get_all))]
#[derive(Debug, Clone, Serialize)]
/// Completion request choice.
pub struct CompletionChoice {
    pub finish_reason: String,
    pub index: usize,
    pub text: String,
    pub logprobs: Option<()>,
}

generate_repr!(CompletionChoice);

#[cfg_attr(feature = "pyo3_macros", pyclass)]
#[cfg_attr(feature = "pyo3_macros", pyo3(get_all))]
#[derive(Debug, Clone, Serialize)]
/// An OpenAI compatible completion response.
pub struct CompletionResponse {
    pub id: String,
    pub choices: Vec<CompletionChoice>,
    pub created: u64,
    pub model: String,
    pub system_fingerprint: String,
    pub object: String,
    pub usage: Usage,
}

generate_repr!(CompletionResponse);

#[cfg_attr(feature = "pyo3_macros", pyclass)]
#[cfg_attr(feature = "pyo3_macros", pyo3(get_all))]
#[derive(Debug, Clone, Serialize)]
/// Completion streaming request chunk.
pub struct CompletionChunkResponse {
    pub id: String,
    pub choices: Vec<CompletionChunkChoice>,
    pub created: u128,
    pub model: String,
    pub system_fingerprint: String,
    pub object: String,
}

generate_repr!(CompletionChunkResponse);

#[cfg_attr(feature = "pyo3_macros", pyclass)]
#[cfg_attr(feature = "pyo3_macros", pyo3(get_all))]
#[derive(Debug, Clone, Serialize)]
/// A single generated image, returned as a URL or base64-encoded data.
pub struct ImageChoice {
    pub url: Option<String>,
    pub b64_json: Option<String>,
}

generate_repr!(ImageChoice);

#[cfg_attr(feature = "pyo3_macros", pyclass)]
#[cfg_attr(feature = "pyo3_macros", pyo3(get_all))]
#[derive(Debug, Clone, Serialize)]
/// An image generation response.
pub struct ImageGenerationResponse {
    pub created: u128,
    pub data: Vec<ImageChoice>,
}

generate_repr!(ImageGenerationResponse);

/// The response enum contains these classes of variants:
/// - Error (`-Error` suffix)
/// - Chat (no prefix)
/// - Completion (`Completion-` prefix)
/// - Image generation (`ImageGeneration`)
/// - Speech generation (`Speech`)
/// - Raw logits (`Raw`)
pub enum Response {
    InternalError(Box<dyn Error + Send + Sync>),
    ValidationError(Box<dyn Error + Send + Sync>),
    // Chat
    ModelError(String, ChatCompletionResponse),
    Done(ChatCompletionResponse),
    Chunk(ChatCompletionChunkResponse),
    // Completion
    CompletionModelError(String, CompletionResponse),
    CompletionDone(CompletionResponse),
    CompletionChunk(CompletionChunkResponse),
    // Image generation
    ImageGeneration(ImageGenerationResponse),
    // Speech generation
    Speech {
        pcm: Arc<Vec<f32>>,
        rate: usize,
        channels: usize,
    },
    // Raw
    Raw {
        logits_chunks: Vec<Tensor>,
        tokens: Vec<u32>,
    },
}
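
// A minimal consumer-side sketch of matching on a received `Response`; how the
// value is delivered (e.g. over a channel) is outside this module:
//
//     match response {
//         Response::Chunk(chunk) => { /* partial chat completion */ }
//         Response::Done(full) => { /* final chat completion with usage */ }
//         Response::InternalError(e) | Response::ValidationError(e) => eprintln!("{e}"),
//         _ => { /* completion, image, speech, and raw variants */ }
//     }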

/// The non-error variants of a [`Response`], as returned by [`Response::as_result`].
#[derive(Debug, Clone)]
pub enum ResponseOk {
    // Chat
    Done(ChatCompletionResponse),
    Chunk(ChatCompletionChunkResponse),
    // Completion
    CompletionDone(CompletionResponse),
    CompletionChunk(CompletionChunkResponse),
    // Image generation
    ImageGeneration(ImageGenerationResponse),
    // Speech generation
    Speech {
        pcm: Arc<Vec<f32>>,
        rate: usize,
        channels: usize,
    },
    // Raw
    Raw {
        logits_chunks: Vec<Tensor>,
        tokens: Vec<u32>,
    },
}

/// The error variants of a [`Response`], as returned by [`Response::as_result`].
pub enum ResponseErr {
    InternalError(Box<dyn Error + Send + Sync>),
    ValidationError(Box<dyn Error + Send + Sync>),
    ModelError(String, ChatCompletionResponse),
    CompletionModelError(String, CompletionResponse),
}

impl Display for ResponseErr {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        match self {
            Self::InternalError(e) | Self::ValidationError(e) => Display::fmt(e, f),
            Self::ModelError(e, x) => f
                .debug_struct("ChatModelError")
                .field("msg", e)
                .field("incomplete_response", x)
                .finish(),
            Self::CompletionModelError(e, x) => f
                .debug_struct("CompletionModelError")
                .field("msg", e)
                .field("incomplete_response", x)
                .finish(),
        }
    }
}

impl Debug for ResponseErr {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        match self {
            Self::InternalError(e) | Self::ValidationError(e) => Debug::fmt(e, f),
            Self::ModelError(e, x) => f
                .debug_struct("ChatModelError")
                .field("msg", e)
                .field("incomplete_response", x)
                .finish(),
            Self::CompletionModelError(e, x) => f
                .debug_struct("CompletionModelError")
                .field("msg", e)
                .field("incomplete_response", x)
                .finish(),
        }
    }
}

impl std::error::Error for ResponseErr {}

impl Response {
    /// Convert the response into a result form.
    pub fn as_result(self) -> Result<ResponseOk, Box<ResponseErr>> {
        match self {
            Self::Done(x) => Ok(ResponseOk::Done(x)),
            Self::Chunk(x) => Ok(ResponseOk::Chunk(x)),
            Self::CompletionDone(x) => Ok(ResponseOk::CompletionDone(x)),
            Self::CompletionChunk(x) => Ok(ResponseOk::CompletionChunk(x)),
            Self::InternalError(e) => Err(Box::new(ResponseErr::InternalError(e))),
            Self::ValidationError(e) => Err(Box::new(ResponseErr::ValidationError(e))),
            Self::ModelError(e, x) => Err(Box::new(ResponseErr::ModelError(e, x))),
            Self::CompletionModelError(e, x) => {
                Err(Box::new(ResponseErr::CompletionModelError(e, x)))
            }
            Self::ImageGeneration(x) => Ok(ResponseOk::ImageGeneration(x)),
            Self::Speech {
                pcm,
                rate,
                channels,
            } => Ok(ResponseOk::Speech {
                pcm,
                rate,
                channels,
            }),
            Self::Raw {
                logits_chunks,
                tokens,
            } => Ok(ResponseOk::Raw {
                logits_chunks,
                tokens,
            }),
        }
    }
}
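
// A minimal sketch exercising `Response::as_result`; the error text is arbitrary and
// the assertions rely only on items defined in this module.
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn error_variants_map_to_err() {
        // `&str` converts into `Box<dyn Error + Send + Sync>` via the std `From` impl.
        let resp = Response::InternalError("something went wrong".into());
        match resp.as_result() {
            // The boxed `ResponseErr` forwards `Display` to the inner error.
            Err(e) => assert_eq!(e.to_string(), "something went wrong"),
            Ok(_) => panic!("expected an error variant"),
        }
    }
}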