1use std::{
2 error::Error,
3 fmt::{Debug, Display},
4 sync::Arc,
5};
6
7use candle_core::Tensor;
8#[cfg(feature = "pyo3_macros")]
9use pyo3::{pyclass, pymethods};
10use serde::Serialize;
11
12use crate::{sampler::TopLogprob, tools::ToolCallResponse};
13
/// Fixed fingerprint string exported by this module.
///
/// NOTE(review): presumably the value placed in the `system_fingerprint`
/// field of the response types below — confirm at the construction sites.
pub const SYSTEM_FINGERPRINT: &str = "local";
15
/// Implements a Python-facing `__repr__` for the named type.
///
/// The generated `impl` is gated on the `pyo3_macros` feature, so the macro
/// expands to nothing when that feature is disabled. The representation is
/// the type's pretty-printed (`{:#?}`) `Debug` output, which means the type
/// must implement `Debug`.
macro_rules! generate_repr {
    ($class:ident) => {
        #[cfg(feature = "pyo3_macros")]
        #[pymethods]
        impl $class {
            fn __repr__(&self) -> String {
                format!("{self:#?}")
            }
        }
    };
}
27
28#[cfg_attr(feature = "pyo3_macros", pyclass)]
29#[cfg_attr(feature = "pyo3_macros", pyo3(get_all))]
30#[derive(Debug, Clone, Serialize)]
31pub struct ResponseMessage {
33 pub content: Option<String>,
34 pub role: String,
35 pub tool_calls: Option<Vec<ToolCallResponse>>,
36 #[serde(skip_serializing_if = "Option::is_none")]
39 pub reasoning_content: Option<String>,
40}
41
42generate_repr!(ResponseMessage);
43
44#[cfg_attr(feature = "pyo3_macros", pyclass)]
45#[cfg_attr(feature = "pyo3_macros", pyo3(get_all))]
46#[derive(Debug, Clone, Serialize)]
47pub struct Delta {
49 pub content: Option<String>,
50 pub role: String,
51 pub tool_calls: Option<Vec<ToolCallResponse>>,
52 #[serde(skip_serializing_if = "Option::is_none")]
55 pub reasoning_content: Option<String>,
56}
57
58generate_repr!(Delta);
59
60#[cfg_attr(feature = "pyo3_macros", pyclass)]
61#[cfg_attr(feature = "pyo3_macros", pyo3(get_all))]
62#[derive(Debug, Clone, Serialize)]
63pub struct ResponseLogprob {
65 pub token: String,
66 pub logprob: f32,
67 pub bytes: Option<Vec<u8>>,
68 pub top_logprobs: Vec<TopLogprob>,
69}
70
71generate_repr!(ResponseLogprob);
72
73#[cfg_attr(feature = "pyo3_macros", pyclass)]
74#[cfg_attr(feature = "pyo3_macros", pyo3(get_all))]
75#[derive(Debug, Clone, Serialize)]
76pub struct Logprobs {
78 pub content: Option<Vec<ResponseLogprob>>,
79}
80
81generate_repr!(Logprobs);
82
83#[cfg_attr(feature = "pyo3_macros", pyclass)]
84#[cfg_attr(feature = "pyo3_macros", pyo3(get_all))]
85#[derive(Debug, Clone, Serialize)]
86pub struct Choice {
88 pub finish_reason: String,
89 pub index: usize,
90 pub message: ResponseMessage,
91 pub logprobs: Option<Logprobs>,
92}
93
94generate_repr!(Choice);
95
96#[cfg_attr(feature = "pyo3_macros", pyclass)]
97#[cfg_attr(feature = "pyo3_macros", pyo3(get_all))]
98#[derive(Debug, Clone, Serialize)]
99pub struct ChunkChoice {
101 pub finish_reason: Option<String>,
102 pub index: usize,
103 pub delta: Delta,
104 pub logprobs: Option<ResponseLogprob>,
105}
106
107generate_repr!(ChunkChoice);
108
109#[cfg_attr(feature = "pyo3_macros", pyclass)]
110#[cfg_attr(feature = "pyo3_macros", pyo3(get_all))]
111#[derive(Debug, Clone, Serialize)]
112pub struct CompletionChunkChoice {
114 pub text: String,
115 pub index: usize,
116 pub logprobs: Option<ResponseLogprob>,
117 pub finish_reason: Option<String>,
118}
119
120generate_repr!(CompletionChunkChoice);
121
122#[cfg_attr(feature = "pyo3_macros", pyclass)]
123#[cfg_attr(feature = "pyo3_macros", pyo3(get_all))]
124#[derive(Debug, Clone, Serialize)]
125pub struct Usage {
127 pub completion_tokens: usize,
128 pub prompt_tokens: usize,
129 pub total_tokens: usize,
130 pub avg_tok_per_sec: f32,
131 pub avg_prompt_tok_per_sec: f32,
132 pub avg_compl_tok_per_sec: f32,
133 pub total_time_sec: f32,
134 pub total_prompt_time_sec: f32,
135 pub total_completion_time_sec: f32,
136}
137
138generate_repr!(Usage);
139
140#[cfg_attr(feature = "pyo3_macros", pyclass)]
141#[cfg_attr(feature = "pyo3_macros", pyo3(get_all))]
142#[derive(Debug, Clone, Serialize)]
143pub struct ChatCompletionResponse {
145 pub id: String,
146 pub choices: Vec<Choice>,
147 pub created: u64,
148 pub model: String,
149 pub system_fingerprint: String,
150 pub object: String,
151 pub usage: Usage,
152}
153
154generate_repr!(ChatCompletionResponse);
155
156#[cfg_attr(feature = "pyo3_macros", pyclass)]
157#[cfg_attr(feature = "pyo3_macros", pyo3(get_all))]
158#[derive(Debug, Clone, Serialize)]
159pub struct ChatCompletionChunkResponse {
161 pub id: String,
162 pub choices: Vec<ChunkChoice>,
163 pub created: u128,
164 pub model: String,
165 pub system_fingerprint: String,
166 pub object: String,
167 pub usage: Option<Usage>,
168}
169
170generate_repr!(ChatCompletionChunkResponse);
171
172#[cfg_attr(feature = "pyo3_macros", pyclass)]
173#[cfg_attr(feature = "pyo3_macros", pyo3(get_all))]
174#[derive(Debug, Clone, Serialize)]
175pub struct CompletionChoice {
177 pub finish_reason: String,
178 pub index: usize,
179 pub text: String,
180 pub logprobs: Option<()>,
181}
182
183generate_repr!(CompletionChoice);
184
185#[cfg_attr(feature = "pyo3_macros", pyclass)]
186#[cfg_attr(feature = "pyo3_macros", pyo3(get_all))]
187#[derive(Debug, Clone, Serialize)]
188pub struct CompletionResponse {
190 pub id: String,
191 pub choices: Vec<CompletionChoice>,
192 pub created: u64,
193 pub model: String,
194 pub system_fingerprint: String,
195 pub object: String,
196 pub usage: Usage,
197}
198
199generate_repr!(CompletionResponse);
200
201#[cfg_attr(feature = "pyo3_macros", pyclass)]
202#[cfg_attr(feature = "pyo3_macros", pyo3(get_all))]
203#[derive(Debug, Clone, Serialize)]
204pub struct CompletionChunkResponse {
206 pub id: String,
207 pub choices: Vec<CompletionChunkChoice>,
208 pub created: u128,
209 pub model: String,
210 pub system_fingerprint: String,
211 pub object: String,
212}
213
214generate_repr!(CompletionChunkResponse);
215
216#[cfg_attr(feature = "pyo3_macros", pyclass)]
217#[cfg_attr(feature = "pyo3_macros", pyo3(get_all))]
218#[derive(Debug, Clone, Serialize)]
219pub struct ImageChoice {
220 pub url: Option<String>,
221 pub b64_json: Option<String>,
222}
223
224generate_repr!(ImageChoice);
225
226#[cfg_attr(feature = "pyo3_macros", pyclass)]
227#[cfg_attr(feature = "pyo3_macros", pyo3(get_all))]
228#[derive(Debug, Clone, Serialize)]
229pub struct ImageGenerationResponse {
230 pub created: u128,
231 pub data: Vec<ImageChoice>,
232}
233
234generate_repr!(ImageGenerationResponse);
235
/// Every kind of response this module can represent, success and failure alike.
///
/// Variant naming follows a pattern: failure variants end in `Error`,
/// completion variants start with `Completion`, and the unprefixed
/// `Done`/`Chunk`/`ModelError` variants carry chat completion payloads.
/// [`Response::as_result`] splits the value into [`ResponseOk`] and
/// [`ResponseErr`].
pub enum Response {
    /// Failure carrying an arbitrary boxed error.
    InternalError(Box<dyn Error + Send + Sync>),
    /// Failure carrying an arbitrary boxed error, reported separately from
    /// `InternalError`.
    ValidationError(Box<dyn Error + Send + Sync>),
    /// Chat failure: an error message plus the incomplete response.
    ModelError(String, ChatCompletionResponse),
    /// A full chat completion response.
    Done(ChatCompletionResponse),
    /// A chat completion chunk.
    Chunk(ChatCompletionChunkResponse),
    /// Completion failure: an error message plus the incomplete response.
    CompletionModelError(String, CompletionResponse),
    /// A full completion response.
    CompletionDone(CompletionResponse),
    /// A completion chunk.
    CompletionChunk(CompletionChunkResponse),
    /// An image generation response.
    ImageGeneration(ImageGenerationResponse),
    /// Audio output.
    Speech {
        /// Shared buffer of `f32` samples.
        pcm: Arc<Vec<f32>>,
        /// NOTE(review): presumably the sample rate — confirm the unit.
        rate: usize,
        /// Number of channels.
        channels: usize,
    },
    /// Raw model output.
    Raw {
        /// Logits, as a list of tensors.
        logits_chunks: Vec<Tensor>,
        /// Token ids.
        tokens: Vec<u32>,
    },
    /// Embedding output.
    Embeddings {
        /// The embedding values.
        embeddings: Vec<f32>,
        /// Number of prompt tokens.
        prompt_tokens: usize,
        /// Total number of tokens.
        total_tokens: usize,
    },
}
270
/// The success half of [`Response`], as produced by [`Response::as_result`].
///
/// Each variant mirrors the `Response` variant of the same name.
#[derive(Debug, Clone)]
pub enum ResponseOk {
    /// A full chat completion response.
    Done(ChatCompletionResponse),
    /// A chat completion chunk.
    Chunk(ChatCompletionChunkResponse),
    /// A full completion response.
    CompletionDone(CompletionResponse),
    /// A completion chunk.
    CompletionChunk(CompletionChunkResponse),
    /// An image generation response.
    ImageGeneration(ImageGenerationResponse),
    /// Audio output.
    Speech {
        /// Shared buffer of `f32` samples.
        pcm: Arc<Vec<f32>>,
        /// NOTE(review): presumably the sample rate — confirm the unit.
        rate: usize,
        /// Number of channels.
        channels: usize,
    },
    /// Raw model output.
    Raw {
        /// Logits, as a list of tensors.
        logits_chunks: Vec<Tensor>,
        /// Token ids.
        tokens: Vec<u32>,
    },
    /// Embedding output.
    Embeddings {
        /// The embedding values.
        embeddings: Vec<f32>,
        /// Number of prompt tokens.
        prompt_tokens: usize,
        /// Total number of tokens.
        total_tokens: usize,
    },
}
299
/// The failure half of [`Response`], as produced by [`Response::as_result`].
///
/// Each variant mirrors the `Response` variant of the same name. `Display`
/// and `Debug` are implemented by hand below rather than derived.
pub enum ResponseErr {
    /// Failure carrying an arbitrary boxed error.
    InternalError(Box<dyn Error + Send + Sync>),
    /// Failure carrying an arbitrary boxed error, reported separately from
    /// `InternalError`.
    ValidationError(Box<dyn Error + Send + Sync>),
    /// Chat failure: an error message plus the incomplete response.
    ModelError(String, ChatCompletionResponse),
    /// Completion failure: an error message plus the incomplete response.
    CompletionModelError(String, CompletionResponse),
}
306
307impl Display for ResponseErr {
308 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
309 match self {
310 Self::InternalError(e) | Self::ValidationError(e) => Display::fmt(e, f),
311 Self::ModelError(e, x) => f
312 .debug_struct("ChatModelError")
313 .field("msg", e)
314 .field("incomplete_response", x)
315 .finish(),
316 Self::CompletionModelError(e, x) => f
317 .debug_struct("CompletionModelError")
318 .field("msg", e)
319 .field("incomplete_response", x)
320 .finish(),
321 }
322 }
323}
324
325impl Debug for ResponseErr {
326 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
327 match self {
328 Self::InternalError(e) | Self::ValidationError(e) => Debug::fmt(e, f),
329 Self::ModelError(e, x) => f
330 .debug_struct("ChatModelError")
331 .field("msg", e)
332 .field("incomplete_response", x)
333 .finish(),
334 Self::CompletionModelError(e, x) => f
335 .debug_struct("CompletionModelError")
336 .field("msg", e)
337 .field("incomplete_response", x)
338 .finish(),
339 }
340 }
341}
342
343impl std::error::Error for ResponseErr {}
344
345impl Response {
346 pub fn as_result(self) -> Result<ResponseOk, Box<ResponseErr>> {
348 match self {
349 Self::Done(x) => Ok(ResponseOk::Done(x)),
350 Self::Chunk(x) => Ok(ResponseOk::Chunk(x)),
351 Self::CompletionDone(x) => Ok(ResponseOk::CompletionDone(x)),
352 Self::CompletionChunk(x) => Ok(ResponseOk::CompletionChunk(x)),
353 Self::InternalError(e) => Err(Box::new(ResponseErr::InternalError(e))),
354 Self::ValidationError(e) => Err(Box::new(ResponseErr::ValidationError(e))),
355 Self::ModelError(e, x) => Err(Box::new(ResponseErr::ModelError(e, x))),
356 Self::CompletionModelError(e, x) => {
357 Err(Box::new(ResponseErr::CompletionModelError(e, x)))
358 }
359 Self::ImageGeneration(x) => Ok(ResponseOk::ImageGeneration(x)),
360 Self::Speech {
361 pcm,
362 rate,
363 channels,
364 } => Ok(ResponseOk::Speech {
365 pcm,
366 rate,
367 channels,
368 }),
369 Self::Raw {
370 logits_chunks,
371 tokens,
372 } => Ok(ResponseOk::Raw {
373 logits_chunks,
374 tokens,
375 }),
376 Self::Embeddings {
377 embeddings,
378 prompt_tokens,
379 total_tokens,
380 } => Ok(ResponseOk::Embeddings {
381 embeddings,
382 prompt_tokens,
383 total_tokens,
384 }),
385 }
386 }
387}