1use std::{
2 error::Error,
3 fmt::{Debug, Display},
4};
5
6use candle_core::Tensor;
7#[cfg(feature = "pyo3_macros")]
8use pyo3::{pyclass, pymethods};
9use serde::Serialize;
10
11use crate::{sampler::TopLogprob, tools::ToolCallResponse};
12
/// Fixed fingerprint string for this backend.
///
/// NOTE(review): presumably the value placed in the `system_fingerprint` fields of the
/// response structs in this module — confirm at the construction sites.
pub const SYSTEM_FINGERPRINT: &str = "local";
14
/// Gives a response type a Python `__repr__` returning its pretty-printed (`{:#?}`)
/// `Debug` representation.
///
/// The generated `#[pymethods]` impl is compiled only with the `pyo3_macros` feature,
/// so the macro can be invoked unconditionally after every response type. The target
/// type must implement `Debug` (and be a `#[pyclass]` when the feature is on).
macro_rules! generate_repr {
    ($name:ident) => {
        #[cfg(feature = "pyo3_macros")]
        #[pymethods]
        impl $name {
            fn __repr__(&self) -> String {
                format!("{:#?}", self)
            }
        }
    };
}
26
27#[cfg_attr(feature = "pyo3_macros", pyclass)]
28#[cfg_attr(feature = "pyo3_macros", pyo3(get_all))]
29#[derive(Debug, Clone, Serialize)]
30pub struct ResponseMessage {
32 pub content: Option<String>,
33 pub role: String,
34 pub tool_calls: Option<Vec<ToolCallResponse>>,
35}
36
37generate_repr!(ResponseMessage);
38
39#[cfg_attr(feature = "pyo3_macros", pyclass)]
40#[cfg_attr(feature = "pyo3_macros", pyo3(get_all))]
41#[derive(Debug, Clone, Serialize)]
42pub struct Delta {
44 pub content: Option<String>,
45 pub role: String,
46 pub tool_calls: Option<Vec<ToolCallResponse>>,
47}
48
49generate_repr!(Delta);
50
51#[cfg_attr(feature = "pyo3_macros", pyclass)]
52#[cfg_attr(feature = "pyo3_macros", pyo3(get_all))]
53#[derive(Debug, Clone, Serialize)]
54pub struct ResponseLogprob {
56 pub token: String,
57 pub logprob: f32,
58 pub bytes: Option<Vec<u8>>,
59 pub top_logprobs: Vec<TopLogprob>,
60}
61
62generate_repr!(ResponseLogprob);
63
64#[cfg_attr(feature = "pyo3_macros", pyclass)]
65#[cfg_attr(feature = "pyo3_macros", pyo3(get_all))]
66#[derive(Debug, Clone, Serialize)]
67pub struct Logprobs {
69 pub content: Option<Vec<ResponseLogprob>>,
70}
71
72generate_repr!(Logprobs);
73
74#[cfg_attr(feature = "pyo3_macros", pyclass)]
75#[cfg_attr(feature = "pyo3_macros", pyo3(get_all))]
76#[derive(Debug, Clone, Serialize)]
77pub struct Choice {
79 pub finish_reason: String,
80 pub index: usize,
81 pub message: ResponseMessage,
82 pub logprobs: Option<Logprobs>,
83}
84
85generate_repr!(Choice);
86
87#[cfg_attr(feature = "pyo3_macros", pyclass)]
88#[cfg_attr(feature = "pyo3_macros", pyo3(get_all))]
89#[derive(Debug, Clone, Serialize)]
90pub struct ChunkChoice {
92 pub finish_reason: Option<String>,
93 pub index: usize,
94 pub delta: Delta,
95 pub logprobs: Option<ResponseLogprob>,
96}
97
98generate_repr!(ChunkChoice);
99
100#[cfg_attr(feature = "pyo3_macros", pyclass)]
101#[cfg_attr(feature = "pyo3_macros", pyo3(get_all))]
102#[derive(Debug, Clone, Serialize)]
103pub struct CompletionChunkChoice {
105 pub text: String,
106 pub index: usize,
107 pub logprobs: Option<ResponseLogprob>,
108 pub finish_reason: Option<String>,
109}
110
111generate_repr!(CompletionChunkChoice);
112
113#[cfg_attr(feature = "pyo3_macros", pyclass)]
114#[cfg_attr(feature = "pyo3_macros", pyo3(get_all))]
115#[derive(Debug, Clone, Serialize)]
116pub struct Usage {
118 pub completion_tokens: usize,
119 pub prompt_tokens: usize,
120 pub total_tokens: usize,
121 pub avg_tok_per_sec: f32,
122 pub avg_prompt_tok_per_sec: f32,
123 pub avg_compl_tok_per_sec: f32,
124 pub total_time_sec: f32,
125 pub total_prompt_time_sec: f32,
126 pub total_completion_time_sec: f32,
127}
128
129generate_repr!(Usage);
130
131#[cfg_attr(feature = "pyo3_macros", pyclass)]
132#[cfg_attr(feature = "pyo3_macros", pyo3(get_all))]
133#[derive(Debug, Clone, Serialize)]
134pub struct ChatCompletionResponse {
136 pub id: String,
137 pub choices: Vec<Choice>,
138 pub created: u64,
139 pub model: String,
140 pub system_fingerprint: String,
141 pub object: String,
142 pub usage: Usage,
143}
144
145generate_repr!(ChatCompletionResponse);
146
147#[cfg_attr(feature = "pyo3_macros", pyclass)]
148#[cfg_attr(feature = "pyo3_macros", pyo3(get_all))]
149#[derive(Debug, Clone, Serialize)]
150pub struct ChatCompletionChunkResponse {
152 pub id: String,
153 pub choices: Vec<ChunkChoice>,
154 pub created: u128,
155 pub model: String,
156 pub system_fingerprint: String,
157 pub object: String,
158 pub usage: Option<Usage>,
159}
160
161generate_repr!(ChatCompletionChunkResponse);
162
163#[cfg_attr(feature = "pyo3_macros", pyclass)]
164#[cfg_attr(feature = "pyo3_macros", pyo3(get_all))]
165#[derive(Debug, Clone, Serialize)]
166pub struct CompletionChoice {
168 pub finish_reason: String,
169 pub index: usize,
170 pub text: String,
171 pub logprobs: Option<()>,
172}
173
174generate_repr!(CompletionChoice);
175
176#[cfg_attr(feature = "pyo3_macros", pyclass)]
177#[cfg_attr(feature = "pyo3_macros", pyo3(get_all))]
178#[derive(Debug, Clone, Serialize)]
179pub struct CompletionResponse {
181 pub id: String,
182 pub choices: Vec<CompletionChoice>,
183 pub created: u64,
184 pub model: String,
185 pub system_fingerprint: String,
186 pub object: String,
187 pub usage: Usage,
188}
189
190generate_repr!(CompletionResponse);
191
192#[cfg_attr(feature = "pyo3_macros", pyclass)]
193#[cfg_attr(feature = "pyo3_macros", pyo3(get_all))]
194#[derive(Debug, Clone, Serialize)]
195pub struct CompletionChunkResponse {
197 pub id: String,
198 pub choices: Vec<CompletionChunkChoice>,
199 pub created: u128,
200 pub model: String,
201 pub system_fingerprint: String,
202 pub object: String,
203}
204
205generate_repr!(CompletionChunkResponse);
206
207#[cfg_attr(feature = "pyo3_macros", pyclass)]
208#[cfg_attr(feature = "pyo3_macros", pyo3(get_all))]
209#[derive(Debug, Clone, Serialize)]
210pub struct ImageChoice {
211 pub url: Option<String>,
212 pub b64_json: Option<String>,
213}
214
215generate_repr!(ImageChoice);
216
217#[cfg_attr(feature = "pyo3_macros", pyclass)]
218#[cfg_attr(feature = "pyo3_macros", pyo3(get_all))]
219#[derive(Debug, Clone, Serialize)]
220pub struct ImageGenerationResponse {
221 pub created: u128,
222 pub data: Vec<ImageChoice>,
223}
224
225generate_repr!(ImageGenerationResponse);
226
/// Outcome of a request: a finished result, a streamed chunk, raw output, or an error.
///
/// Use [`Response::as_result`] to split it into [`ResponseOk`] / [`ResponseErr`].
pub enum Response {
    /// Internal failure, carried as an arbitrary boxed error.
    InternalError(Box<dyn Error + Send + Sync>),
    /// Validation failure, carried as an arbitrary boxed error.
    ValidationError(Box<dyn Error + Send + Sync>),
    /// Chat-completion failure: a message plus the response built so far
    /// (reported as `incomplete_response` once converted to [`ResponseErr`]).
    ModelError(String, ChatCompletionResponse),
    /// Finished chat completion.
    Done(ChatCompletionResponse),
    /// One streamed chat-completion chunk.
    Chunk(ChatCompletionChunkResponse),
    /// Text-completion failure: a message plus the response built so far.
    CompletionModelError(String, CompletionResponse),
    /// Finished text completion.
    CompletionDone(CompletionResponse),
    /// One streamed text-completion chunk.
    CompletionChunk(CompletionChunkResponse),
    /// Image-generation result.
    ImageGeneration(ImageGenerationResponse),
    /// Raw model output instead of a formatted response.
    Raw {
        /// Logits tensors. NOTE(review): shape and chunk granularity are not visible
        /// here — confirm at the producer.
        logits_chunks: Vec<Tensor>,
        /// Token ids; presumably the sequence the logits belong to — confirm.
        tokens: Vec<u32>,
    },
}
250
251#[derive(Debug, Clone)]
252pub enum ResponseOk {
253 Done(ChatCompletionResponse),
255 Chunk(ChatCompletionChunkResponse),
256 CompletionDone(CompletionResponse),
258 CompletionChunk(CompletionChunkResponse),
259 ImageGeneration(ImageGenerationResponse),
261 Raw {
263 logits_chunks: Vec<Tensor>,
264 tokens: Vec<u32>,
265 },
266}
267
268pub enum ResponseErr {
269 InternalError(Box<dyn Error + Send + Sync>),
270 ValidationError(Box<dyn Error + Send + Sync>),
271 ModelError(String, ChatCompletionResponse),
272 CompletionModelError(String, CompletionResponse),
273}
274
275impl Display for ResponseErr {
276 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
277 match self {
278 Self::InternalError(e) | Self::ValidationError(e) => Display::fmt(e, f),
279 Self::ModelError(e, x) => f
280 .debug_struct("ChatModelError")
281 .field("msg", e)
282 .field("incomplete_response", x)
283 .finish(),
284 Self::CompletionModelError(e, x) => f
285 .debug_struct("CompletionModelError")
286 .field("msg", e)
287 .field("incomplete_response", x)
288 .finish(),
289 }
290 }
291}
292
293impl Debug for ResponseErr {
294 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
295 match self {
296 Self::InternalError(e) | Self::ValidationError(e) => Debug::fmt(e, f),
297 Self::ModelError(e, x) => f
298 .debug_struct("ChatModelError")
299 .field("msg", e)
300 .field("incomplete_response", x)
301 .finish(),
302 Self::CompletionModelError(e, x) => f
303 .debug_struct("CompletionModelError")
304 .field("msg", e)
305 .field("incomplete_response", x)
306 .finish(),
307 }
308 }
309}
310
311impl std::error::Error for ResponseErr {}
312
313impl Response {
314 pub fn as_result(self) -> Result<ResponseOk, Box<ResponseErr>> {
316 match self {
317 Self::Done(x) => Ok(ResponseOk::Done(x)),
318 Self::Chunk(x) => Ok(ResponseOk::Chunk(x)),
319 Self::CompletionDone(x) => Ok(ResponseOk::CompletionDone(x)),
320 Self::CompletionChunk(x) => Ok(ResponseOk::CompletionChunk(x)),
321 Self::InternalError(e) => Err(Box::new(ResponseErr::InternalError(e))),
322 Self::ValidationError(e) => Err(Box::new(ResponseErr::ValidationError(e))),
323 Self::ModelError(e, x) => Err(Box::new(ResponseErr::ModelError(e, x))),
324 Self::CompletionModelError(e, x) => {
325 Err(Box::new(ResponseErr::CompletionModelError(e, x)))
326 }
327 Self::ImageGeneration(x) => Ok(ResponseOk::ImageGeneration(x)),
328 Self::Raw {
329 logits_chunks,
330 tokens,
331 } => Ok(ResponseOk::Raw {
332 logits_chunks,
333 tokens,
334 }),
335 }
336 }
337}