1use std::{
2 error::Error,
3 fmt::{Debug, Display},
4 sync::Arc,
5};
6
7use candle_core::Tensor;
8#[cfg(feature = "pyo3_macros")]
9use pyo3::{pyclass, pymethods};
10use serde::Serialize;
11
12use crate::{sampler::TopLogprob, tools::ToolCallResponse};
13
/// Fingerprint string intended for the `system_fingerprint` fields of the response types in
/// this module (presumably copied in by the code that builds responses — confirm at callers).
pub const SYSTEM_FINGERPRINT: &str = "local";
15
/// Adds a Python `__repr__` to a `#[pyclass]` type, rendering it with pretty-printed `Debug`
/// (`{:#?}`). Expands to nothing unless the `pyo3_macros` feature is enabled.
macro_rules! generate_repr {
    ($name:ident) => {
        #[cfg(feature = "pyo3_macros")]
        #[pymethods]
        impl $name {
            // Plain `//` comment on purpose: a `///` doc here would become the Python docstring.
            fn __repr__(&self) -> String {
                format!("{:#?}", self)
            }
        }
    };
}
27
28#[cfg_attr(feature = "pyo3_macros", pyclass)]
29#[cfg_attr(feature = "pyo3_macros", pyo3(get_all))]
30#[derive(Debug, Clone, Serialize)]
31pub struct ResponseMessage {
33 pub content: Option<String>,
34 pub role: String,
35 pub tool_calls: Option<Vec<ToolCallResponse>>,
36}
37
38generate_repr!(ResponseMessage);
39
40#[cfg_attr(feature = "pyo3_macros", pyclass)]
41#[cfg_attr(feature = "pyo3_macros", pyo3(get_all))]
42#[derive(Debug, Clone, Serialize)]
43pub struct Delta {
45 pub content: Option<String>,
46 pub role: String,
47 pub tool_calls: Option<Vec<ToolCallResponse>>,
48}
49
50generate_repr!(Delta);
51
52#[cfg_attr(feature = "pyo3_macros", pyclass)]
53#[cfg_attr(feature = "pyo3_macros", pyo3(get_all))]
54#[derive(Debug, Clone, Serialize)]
55pub struct ResponseLogprob {
57 pub token: String,
58 pub logprob: f32,
59 pub bytes: Option<Vec<u8>>,
60 pub top_logprobs: Vec<TopLogprob>,
61}
62
63generate_repr!(ResponseLogprob);
64
65#[cfg_attr(feature = "pyo3_macros", pyclass)]
66#[cfg_attr(feature = "pyo3_macros", pyo3(get_all))]
67#[derive(Debug, Clone, Serialize)]
68pub struct Logprobs {
70 pub content: Option<Vec<ResponseLogprob>>,
71}
72
73generate_repr!(Logprobs);
74
75#[cfg_attr(feature = "pyo3_macros", pyclass)]
76#[cfg_attr(feature = "pyo3_macros", pyo3(get_all))]
77#[derive(Debug, Clone, Serialize)]
78pub struct Choice {
80 pub finish_reason: String,
81 pub index: usize,
82 pub message: ResponseMessage,
83 pub logprobs: Option<Logprobs>,
84}
85
86generate_repr!(Choice);
87
88#[cfg_attr(feature = "pyo3_macros", pyclass)]
89#[cfg_attr(feature = "pyo3_macros", pyo3(get_all))]
90#[derive(Debug, Clone, Serialize)]
91pub struct ChunkChoice {
93 pub finish_reason: Option<String>,
94 pub index: usize,
95 pub delta: Delta,
96 pub logprobs: Option<ResponseLogprob>,
97}
98
99generate_repr!(ChunkChoice);
100
101#[cfg_attr(feature = "pyo3_macros", pyclass)]
102#[cfg_attr(feature = "pyo3_macros", pyo3(get_all))]
103#[derive(Debug, Clone, Serialize)]
104pub struct CompletionChunkChoice {
106 pub text: String,
107 pub index: usize,
108 pub logprobs: Option<ResponseLogprob>,
109 pub finish_reason: Option<String>,
110}
111
112generate_repr!(CompletionChunkChoice);
113
114#[cfg_attr(feature = "pyo3_macros", pyclass)]
115#[cfg_attr(feature = "pyo3_macros", pyo3(get_all))]
116#[derive(Debug, Clone, Serialize)]
117pub struct Usage {
119 pub completion_tokens: usize,
120 pub prompt_tokens: usize,
121 pub total_tokens: usize,
122 pub avg_tok_per_sec: f32,
123 pub avg_prompt_tok_per_sec: f32,
124 pub avg_compl_tok_per_sec: f32,
125 pub total_time_sec: f32,
126 pub total_prompt_time_sec: f32,
127 pub total_completion_time_sec: f32,
128}
129
130generate_repr!(Usage);
131
132#[cfg_attr(feature = "pyo3_macros", pyclass)]
133#[cfg_attr(feature = "pyo3_macros", pyo3(get_all))]
134#[derive(Debug, Clone, Serialize)]
135pub struct ChatCompletionResponse {
137 pub id: String,
138 pub choices: Vec<Choice>,
139 pub created: u64,
140 pub model: String,
141 pub system_fingerprint: String,
142 pub object: String,
143 pub usage: Usage,
144}
145
146generate_repr!(ChatCompletionResponse);
147
148#[cfg_attr(feature = "pyo3_macros", pyclass)]
149#[cfg_attr(feature = "pyo3_macros", pyo3(get_all))]
150#[derive(Debug, Clone, Serialize)]
151pub struct ChatCompletionChunkResponse {
153 pub id: String,
154 pub choices: Vec<ChunkChoice>,
155 pub created: u128,
156 pub model: String,
157 pub system_fingerprint: String,
158 pub object: String,
159 pub usage: Option<Usage>,
160}
161
162generate_repr!(ChatCompletionChunkResponse);
163
164#[cfg_attr(feature = "pyo3_macros", pyclass)]
165#[cfg_attr(feature = "pyo3_macros", pyo3(get_all))]
166#[derive(Debug, Clone, Serialize)]
167pub struct CompletionChoice {
169 pub finish_reason: String,
170 pub index: usize,
171 pub text: String,
172 pub logprobs: Option<()>,
173}
174
175generate_repr!(CompletionChoice);
176
177#[cfg_attr(feature = "pyo3_macros", pyclass)]
178#[cfg_attr(feature = "pyo3_macros", pyo3(get_all))]
179#[derive(Debug, Clone, Serialize)]
180pub struct CompletionResponse {
182 pub id: String,
183 pub choices: Vec<CompletionChoice>,
184 pub created: u64,
185 pub model: String,
186 pub system_fingerprint: String,
187 pub object: String,
188 pub usage: Usage,
189}
190
191generate_repr!(CompletionResponse);
192
193#[cfg_attr(feature = "pyo3_macros", pyclass)]
194#[cfg_attr(feature = "pyo3_macros", pyo3(get_all))]
195#[derive(Debug, Clone, Serialize)]
196pub struct CompletionChunkResponse {
198 pub id: String,
199 pub choices: Vec<CompletionChunkChoice>,
200 pub created: u128,
201 pub model: String,
202 pub system_fingerprint: String,
203 pub object: String,
204}
205
206generate_repr!(CompletionChunkResponse);
207
208#[cfg_attr(feature = "pyo3_macros", pyclass)]
209#[cfg_attr(feature = "pyo3_macros", pyo3(get_all))]
210#[derive(Debug, Clone, Serialize)]
211pub struct ImageChoice {
212 pub url: Option<String>,
213 pub b64_json: Option<String>,
214}
215
216generate_repr!(ImageChoice);
217
218#[cfg_attr(feature = "pyo3_macros", pyclass)]
219#[cfg_attr(feature = "pyo3_macros", pyo3(get_all))]
220#[derive(Debug, Clone, Serialize)]
221pub struct ImageGenerationResponse {
222 pub created: u128,
223 pub data: Vec<ImageChoice>,
224}
225
226generate_repr!(ImageGenerationResponse);
227
228pub enum Response {
233 InternalError(Box<dyn Error + Send + Sync>),
234 ValidationError(Box<dyn Error + Send + Sync>),
235 ModelError(String, ChatCompletionResponse),
237 Done(ChatCompletionResponse),
238 Chunk(ChatCompletionChunkResponse),
239 CompletionModelError(String, CompletionResponse),
241 CompletionDone(CompletionResponse),
242 CompletionChunk(CompletionChunkResponse),
243 ImageGeneration(ImageGenerationResponse),
245 Speech {
247 pcm: Arc<Vec<f32>>,
248 rate: usize,
249 channels: usize,
250 },
251 Raw {
253 logits_chunks: Vec<Tensor>,
254 tokens: Vec<u32>,
255 },
256 Embeddings {
257 embeddings: Vec<f32>,
258 prompt_tokens: usize,
259 total_tokens: usize,
260 },
261}
262
263#[derive(Debug, Clone)]
264pub enum ResponseOk {
265 Done(ChatCompletionResponse),
267 Chunk(ChatCompletionChunkResponse),
268 CompletionDone(CompletionResponse),
270 CompletionChunk(CompletionChunkResponse),
271 ImageGeneration(ImageGenerationResponse),
273 Speech {
275 pcm: Arc<Vec<f32>>,
276 rate: usize,
277 channels: usize,
278 },
279 Raw {
281 logits_chunks: Vec<Tensor>,
282 tokens: Vec<u32>,
283 },
284 Embeddings {
286 embeddings: Vec<f32>,
287 prompt_tokens: usize,
288 total_tokens: usize,
289 },
290}
291
292pub enum ResponseErr {
293 InternalError(Box<dyn Error + Send + Sync>),
294 ValidationError(Box<dyn Error + Send + Sync>),
295 ModelError(String, ChatCompletionResponse),
296 CompletionModelError(String, CompletionResponse),
297}
298
299impl Display for ResponseErr {
300 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
301 match self {
302 Self::InternalError(e) | Self::ValidationError(e) => Display::fmt(e, f),
303 Self::ModelError(e, x) => f
304 .debug_struct("ChatModelError")
305 .field("msg", e)
306 .field("incomplete_response", x)
307 .finish(),
308 Self::CompletionModelError(e, x) => f
309 .debug_struct("CompletionModelError")
310 .field("msg", e)
311 .field("incomplete_response", x)
312 .finish(),
313 }
314 }
315}
316
317impl Debug for ResponseErr {
318 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
319 match self {
320 Self::InternalError(e) | Self::ValidationError(e) => Debug::fmt(e, f),
321 Self::ModelError(e, x) => f
322 .debug_struct("ChatModelError")
323 .field("msg", e)
324 .field("incomplete_response", x)
325 .finish(),
326 Self::CompletionModelError(e, x) => f
327 .debug_struct("CompletionModelError")
328 .field("msg", e)
329 .field("incomplete_response", x)
330 .finish(),
331 }
332 }
333}
334
335impl std::error::Error for ResponseErr {}
336
337impl Response {
338 pub fn as_result(self) -> Result<ResponseOk, Box<ResponseErr>> {
340 match self {
341 Self::Done(x) => Ok(ResponseOk::Done(x)),
342 Self::Chunk(x) => Ok(ResponseOk::Chunk(x)),
343 Self::CompletionDone(x) => Ok(ResponseOk::CompletionDone(x)),
344 Self::CompletionChunk(x) => Ok(ResponseOk::CompletionChunk(x)),
345 Self::InternalError(e) => Err(Box::new(ResponseErr::InternalError(e))),
346 Self::ValidationError(e) => Err(Box::new(ResponseErr::ValidationError(e))),
347 Self::ModelError(e, x) => Err(Box::new(ResponseErr::ModelError(e, x))),
348 Self::CompletionModelError(e, x) => {
349 Err(Box::new(ResponseErr::CompletionModelError(e, x)))
350 }
351 Self::ImageGeneration(x) => Ok(ResponseOk::ImageGeneration(x)),
352 Self::Speech {
353 pcm,
354 rate,
355 channels,
356 } => Ok(ResponseOk::Speech {
357 pcm,
358 rate,
359 channels,
360 }),
361 Self::Raw {
362 logits_chunks,
363 tokens,
364 } => Ok(ResponseOk::Raw {
365 logits_chunks,
366 tokens,
367 }),
368 Self::Embeddings {
369 embeddings,
370 prompt_tokens,
371 total_tokens,
372 } => Ok(ResponseOk::Embeddings {
373 embeddings,
374 prompt_tokens,
375 total_tokens,
376 }),
377 }
378 }
379}