1use std::{
2 error::Error,
3 fmt::{Debug, Display},
4 sync::Arc,
5};
6
7use candle_core::Tensor;
8#[cfg(feature = "pyo3_macros")]
9use pyo3::{pyclass, pymethods};
10use serde::Serialize;
11
12use crate::{sampler::TopLogprob, tools::ToolCallResponse};
13
/// Fixed fingerprint string, `"local"`.
// NOTE(review): presumably the value written into the `system_fingerprint` fields of the
// response structs below; no use site is visible in this file, so confirm at the callers.
pub const SYSTEM_FINGERPRINT: &str = "local";
15
/// Generates a `__repr__` method for the type named by `$t`.
///
/// The emitted `#[pymethods]` impl is guarded by `#[cfg(feature = "pyo3_macros")]`, so it is
/// compiled out when that feature is disabled. `__repr__` returns the pretty-printed (`{:#?}`)
/// `Debug` rendering of the value, which means `$t` must implement `Debug`.
macro_rules! generate_repr {
    ($t:ident) => {
        #[cfg(feature = "pyo3_macros")]
        #[pymethods]
        impl $t {
            fn __repr__(&self) -> String {
                // Alternate (`#`) Debug formatting: multi-line, indented output.
                format!("{self:#?}")
            }
        }
    };
}
27
/// Message payload stored in the `message` field of [`Choice`].
///
/// Derives `Debug`, `Clone` and `Serialize`. With the `pyo3_macros` feature enabled it is also
/// a Python class with a getter for every field (`get_all`) and a generated `__repr__`.
#[cfg_attr(feature = "pyo3_macros", pyclass)]
#[cfg_attr(feature = "pyo3_macros", pyo3(get_all))]
#[derive(Debug, Clone, Serialize)]
pub struct ResponseMessage {
    /// Optional text content of the message.
    pub content: Option<String>,
    // NOTE(review): the set of valid role strings is not defined in this file.
    /// Role string attached to the message.
    pub role: String,
    /// Optional list of tool calls attached to the message.
    pub tool_calls: Option<Vec<ToolCallResponse>>,
}

generate_repr!(ResponseMessage);
39
/// Payload stored in the `delta` field of [`ChunkChoice`].
///
/// Has the same three fields as [`ResponseMessage`]. Derives `Debug`, `Clone` and `Serialize`;
/// with the `pyo3_macros` feature enabled it is also a Python class with a getter for every
/// field (`get_all`) and a generated `__repr__`.
// NOTE(review): presumably the incremental part of a streamed message — confirm with the
// code that builds chunks.
#[cfg_attr(feature = "pyo3_macros", pyclass)]
#[cfg_attr(feature = "pyo3_macros", pyo3(get_all))]
#[derive(Debug, Clone, Serialize)]
pub struct Delta {
    /// Optional text content.
    pub content: Option<String>,
    /// Role string attached to the delta.
    pub role: String,
    /// Optional list of tool calls.
    pub tool_calls: Option<Vec<ToolCallResponse>>,
}

generate_repr!(Delta);
51
/// Log-probability record for a single token.
///
/// Used as the element type of [`Logprobs::content`] and directly in the `logprobs` field of
/// [`ChunkChoice`] and [`CompletionChunkChoice`]. Derives `Debug`, `Clone` and `Serialize`;
/// with the `pyo3_macros` feature enabled it is also a Python class with a getter for every
/// field (`get_all`) and a generated `__repr__`.
#[cfg_attr(feature = "pyo3_macros", pyclass)]
#[cfg_attr(feature = "pyo3_macros", pyo3(get_all))]
#[derive(Debug, Clone, Serialize)]
pub struct ResponseLogprob {
    /// Token text.
    pub token: String,
    /// Log-probability value recorded for `token`.
    pub logprob: f32,
    // NOTE(review): presumably the raw byte encoding of `token`; confirm where this is built.
    /// Optional byte sequence associated with the token.
    pub bytes: Option<Vec<u8>>,
    // NOTE(review): presumably the most likely alternative tokens; `TopLogprob` is defined in
    // `crate::sampler`, not here.
    /// Accompanying `TopLogprob` entries.
    pub top_logprobs: Vec<TopLogprob>,
}

generate_repr!(ResponseLogprob);
64
/// Wrapper around an optional list of [`ResponseLogprob`] records; stored in the `logprobs`
/// field of [`Choice`].
///
/// Derives `Debug`, `Clone` and `Serialize`; with the `pyo3_macros` feature enabled it is also
/// a Python class with a getter for every field (`get_all`) and a generated `__repr__`.
#[cfg_attr(feature = "pyo3_macros", pyclass)]
#[cfg_attr(feature = "pyo3_macros", pyo3(get_all))]
#[derive(Debug, Clone, Serialize)]
pub struct Logprobs {
    /// Per-token records, or `None` when absent.
    pub content: Option<Vec<ResponseLogprob>>,
}

generate_repr!(Logprobs);
74
/// One entry of [`ChatCompletionResponse::choices`].
///
/// Derives `Debug`, `Clone` and `Serialize`; with the `pyo3_macros` feature enabled it is also
/// a Python class with a getter for every field (`get_all`) and a generated `__repr__`.
#[cfg_attr(feature = "pyo3_macros", pyclass)]
#[cfg_attr(feature = "pyo3_macros", pyo3(get_all))]
#[derive(Debug, Clone, Serialize)]
pub struct Choice {
    // NOTE(review): the set of possible reason strings is not defined in this file.
    /// Finish reason; always present here, unlike the optional one on [`ChunkChoice`].
    pub finish_reason: String,
    /// Index of this choice.
    pub index: usize,
    /// The message for this choice.
    pub message: ResponseMessage,
    /// Optional log-probability data.
    pub logprobs: Option<Logprobs>,
}

generate_repr!(Choice);
87
/// One entry of [`ChatCompletionChunkResponse::choices`].
///
/// Derives `Debug`, `Clone` and `Serialize`; with the `pyo3_macros` feature enabled it is also
/// a Python class with a getter for every field (`get_all`) and a generated `__repr__`.
#[cfg_attr(feature = "pyo3_macros", pyclass)]
#[cfg_attr(feature = "pyo3_macros", pyo3(get_all))]
#[derive(Debug, Clone, Serialize)]
pub struct ChunkChoice {
    // NOTE(review): presumably `None` until the final chunk — confirm with the producer.
    /// Optional finish reason.
    pub finish_reason: Option<String>,
    /// Index of this choice.
    pub index: usize,
    /// The delta payload for this chunk.
    pub delta: Delta,
    /// Optional single log-probability record (a bare [`ResponseLogprob`], not a [`Logprobs`]).
    pub logprobs: Option<ResponseLogprob>,
}

generate_repr!(ChunkChoice);
100
/// One entry of [`CompletionChunkResponse::choices`].
///
/// Derives `Debug`, `Clone` and `Serialize`; with the `pyo3_macros` feature enabled it is also
/// a Python class with a getter for every field (`get_all`) and a generated `__repr__`.
#[cfg_attr(feature = "pyo3_macros", pyclass)]
#[cfg_attr(feature = "pyo3_macros", pyo3(get_all))]
#[derive(Debug, Clone, Serialize)]
pub struct CompletionChunkChoice {
    /// Text carried by this chunk.
    pub text: String,
    /// Index of this choice.
    pub index: usize,
    /// Optional single log-probability record.
    pub logprobs: Option<ResponseLogprob>,
    /// Optional finish reason.
    pub finish_reason: Option<String>,
}

generate_repr!(CompletionChunkChoice);
113
/// Token counts and timing figures attached to a response.
///
/// Stored in the `usage` field of [`ChatCompletionResponse`] and [`CompletionResponse`], and
/// optionally in [`ChatCompletionChunkResponse`]. Derives `Debug`, `Clone` and `Serialize`;
/// with the `pyo3_macros` feature enabled it is also a Python class with a getter for every
/// field (`get_all`) and a generated `__repr__`.
// NOTE(review): the field names suggest tokens-per-second rates and durations in seconds, and
// that `total_tokens` is prompt + completion; none of that is computed or enforced in this
// file, so confirm against the code that fills this struct.
#[cfg_attr(feature = "pyo3_macros", pyclass)]
#[cfg_attr(feature = "pyo3_macros", pyo3(get_all))]
#[derive(Debug, Clone, Serialize)]
pub struct Usage {
    /// Number of completion tokens.
    pub completion_tokens: usize,
    /// Number of prompt tokens.
    pub prompt_tokens: usize,
    /// Total number of tokens.
    pub total_tokens: usize,
    /// Average token rate overall.
    pub avg_tok_per_sec: f32,
    /// Average token rate for the prompt.
    pub avg_prompt_tok_per_sec: f32,
    /// Average token rate for the completion.
    pub avg_compl_tok_per_sec: f32,
    /// Total time.
    pub total_time_sec: f32,
    /// Total prompt time.
    pub total_prompt_time_sec: f32,
    /// Total completion time.
    pub total_completion_time_sec: f32,
}

generate_repr!(Usage);
131
/// Complete chat completion response.
///
/// Carried by [`Response::Done`], and by [`Response::ModelError`] alongside an error message.
/// Derives `Debug`, `Clone` and `Serialize`; with the `pyo3_macros` feature enabled it is also
/// a Python class with a getter for every field (`get_all`) and a generated `__repr__`.
#[cfg_attr(feature = "pyo3_macros", pyclass)]
#[cfg_attr(feature = "pyo3_macros", pyo3(get_all))]
#[derive(Debug, Clone, Serialize)]
pub struct ChatCompletionResponse {
    /// Response identifier.
    pub id: String,
    /// The choices of this response.
    pub choices: Vec<Choice>,
    // NOTE(review): presumably a creation timestamp; its unit is not established here. The
    // chunk response types use `u128` for the same field name.
    /// Creation value, as a `u64`.
    pub created: u64,
    /// Model name or identifier string.
    pub model: String,
    /// System fingerprint string.
    pub system_fingerprint: String,
    // NOTE(review): the expected literal value is not defined in this file.
    /// Object type string.
    pub object: String,
    /// Usage figures for this response.
    pub usage: Usage,
}

generate_repr!(ChatCompletionResponse);
147
/// One chunk of a chat completion; carried by [`Response::Chunk`].
///
/// Derives `Debug`, `Clone` and `Serialize`; with the `pyo3_macros` feature enabled it is also
/// a Python class with a getter for every field (`get_all`) and a generated `__repr__`.
#[cfg_attr(feature = "pyo3_macros", pyclass)]
#[cfg_attr(feature = "pyo3_macros", pyo3(get_all))]
#[derive(Debug, Clone, Serialize)]
pub struct ChatCompletionChunkResponse {
    /// Response identifier.
    pub id: String,
    /// The choices of this chunk.
    pub choices: Vec<ChunkChoice>,
    // NOTE(review): `u128` here versus `u64` on `ChatCompletionResponse::created`; the unit
    // is not established in this file — confirm whether the difference is intentional.
    /// Creation value, as a `u128`.
    pub created: u128,
    /// Model name or identifier string.
    pub model: String,
    /// System fingerprint string.
    pub system_fingerprint: String,
    /// Object type string.
    pub object: String,
    /// Optional usage figures.
    pub usage: Option<Usage>,
}

generate_repr!(ChatCompletionChunkResponse);
163
/// One entry of [`CompletionResponse::choices`].
///
/// Derives `Debug`, `Clone` and `Serialize`; with the `pyo3_macros` feature enabled it is also
/// a Python class with a getter for every field (`get_all`) and a generated `__repr__`.
#[cfg_attr(feature = "pyo3_macros", pyclass)]
#[cfg_attr(feature = "pyo3_macros", pyo3(get_all))]
#[derive(Debug, Clone, Serialize)]
pub struct CompletionChoice {
    /// Finish reason.
    pub finish_reason: String,
    /// Index of this choice.
    pub index: usize,
    /// Completion text.
    pub text: String,
    /// Placeholder slot: the payload type is `()`, so no log-probability data can be carried.
    pub logprobs: Option<()>,
}

generate_repr!(CompletionChoice);
176
/// Complete text completion response.
///
/// Carried by [`Response::CompletionDone`], and by [`Response::CompletionModelError`] alongside
/// an error message. Derives `Debug`, `Clone` and `Serialize`; with the `pyo3_macros` feature
/// enabled it is also a Python class with a getter for every field (`get_all`) and a generated
/// `__repr__`.
#[cfg_attr(feature = "pyo3_macros", pyclass)]
#[cfg_attr(feature = "pyo3_macros", pyo3(get_all))]
#[derive(Debug, Clone, Serialize)]
pub struct CompletionResponse {
    /// Response identifier.
    pub id: String,
    /// The choices of this response.
    pub choices: Vec<CompletionChoice>,
    // NOTE(review): presumably a creation timestamp; its unit is not established here.
    /// Creation value, as a `u64`.
    pub created: u64,
    /// Model name or identifier string.
    pub model: String,
    /// System fingerprint string.
    pub system_fingerprint: String,
    /// Object type string.
    pub object: String,
    /// Usage figures for this response.
    pub usage: Usage,
}

generate_repr!(CompletionResponse);
192
/// One chunk of a text completion; carried by [`Response::CompletionChunk`].
///
/// Unlike [`ChatCompletionChunkResponse`], this type has no `usage` field. Derives `Debug`,
/// `Clone` and `Serialize`; with the `pyo3_macros` feature enabled it is also a Python class
/// with a getter for every field (`get_all`) and a generated `__repr__`.
#[cfg_attr(feature = "pyo3_macros", pyclass)]
#[cfg_attr(feature = "pyo3_macros", pyo3(get_all))]
#[derive(Debug, Clone, Serialize)]
pub struct CompletionChunkResponse {
    /// Response identifier.
    pub id: String,
    /// The choices of this chunk.
    pub choices: Vec<CompletionChunkChoice>,
    // NOTE(review): `u128` here versus `u64` on `CompletionResponse::created`; the unit is
    // not established in this file.
    /// Creation value, as a `u128`.
    pub created: u128,
    /// Model name or identifier string.
    pub model: String,
    /// System fingerprint string.
    pub system_fingerprint: String,
    /// Object type string.
    pub object: String,
}

generate_repr!(CompletionChunkResponse);
207
/// One entry of [`ImageGenerationResponse::data`].
///
/// Both fields are optional and nothing in this type forces exactly one of them to be set.
/// Derives `Debug`, `Clone` and `Serialize`; with the `pyo3_macros` feature enabled it is also
/// a Python class with a getter for every field (`get_all`) and a generated `__repr__`.
#[cfg_attr(feature = "pyo3_macros", pyclass)]
#[cfg_attr(feature = "pyo3_macros", pyo3(get_all))]
#[derive(Debug, Clone, Serialize)]
pub struct ImageChoice {
    /// Optional URL string.
    pub url: Option<String>,
    // NOTE(review): presumably the base64-encoded image — the encoding is not visible here.
    /// Optional `b64_json` string.
    pub b64_json: Option<String>,
}

generate_repr!(ImageChoice);
217
/// Image generation response; carried by [`Response::ImageGeneration`].
///
/// Derives `Debug`, `Clone` and `Serialize`; with the `pyo3_macros` feature enabled it is also
/// a Python class with a getter for every field (`get_all`) and a generated `__repr__`.
#[cfg_attr(feature = "pyo3_macros", pyclass)]
#[cfg_attr(feature = "pyo3_macros", pyo3(get_all))]
#[derive(Debug, Clone, Serialize)]
pub struct ImageGenerationResponse {
    // NOTE(review): presumably a creation timestamp; its unit is not established here.
    /// Creation value, as a `u128`.
    pub created: u128,
    /// The generated image entries.
    pub data: Vec<ImageChoice>,
}

generate_repr!(ImageGenerationResponse);
227
/// Combined success-or-error value: every payload and every error kind in a single enum.
///
/// [`Response::as_result`] splits a value into [`ResponseOk`] (payload variants) or a boxed
/// [`ResponseErr`] (error variants). No traits are derived for this type.
pub enum Response {
    /// Boxed error; `as_result` maps it to [`ResponseErr::InternalError`].
    InternalError(Box<dyn Error + Send + Sync>),
    /// Boxed error; `as_result` maps it to [`ResponseErr::ValidationError`].
    ValidationError(Box<dyn Error + Send + Sync>),
    /// Error message plus a chat response; the `ResponseErr` formatters label the response
    /// `incomplete_response`.
    ModelError(String, ChatCompletionResponse),
    /// Complete chat completion response.
    Done(ChatCompletionResponse),
    /// Chat completion chunk.
    Chunk(ChatCompletionChunkResponse),
    /// Error message plus a completion response; the `ResponseErr` formatters label the
    /// response `incomplete_response`.
    CompletionModelError(String, CompletionResponse),
    /// Complete text completion response.
    CompletionDone(CompletionResponse),
    /// Text completion chunk.
    CompletionChunk(CompletionChunkResponse),
    /// Image generation response.
    ImageGeneration(ImageGenerationResponse),
    // NOTE(review): presumably PCM audio samples with their sample rate and channel count —
    // sample layout and units are not established in this file.
    /// Speech payload.
    Speech {
        /// Shared buffer of `f32` values.
        pcm: Arc<Vec<f32>>,
        /// Rate value.
        rate: usize,
        /// Channel count value.
        channels: usize,
    },
    // NOTE(review): tensor shapes and the relation between `logits_chunks` and `tokens` are
    // not visible here — confirm with the producer.
    /// Raw payload: logits tensors together with token ids.
    Raw {
        /// Logits tensors, one per chunk.
        logits_chunks: Vec<Tensor>,
        /// Token ids.
        tokens: Vec<u32>,
    },
}
257
/// Success half of [`Response`]: the payload variants, produced by [`Response::as_result`].
///
/// Variant names and payloads match the corresponding [`Response`] variants one-to-one.
#[derive(Debug, Clone)]
pub enum ResponseOk {
    /// Complete chat completion response.
    Done(ChatCompletionResponse),
    /// Chat completion chunk.
    Chunk(ChatCompletionChunkResponse),
    /// Complete text completion response.
    CompletionDone(CompletionResponse),
    /// Text completion chunk.
    CompletionChunk(CompletionChunkResponse),
    /// Image generation response.
    ImageGeneration(ImageGenerationResponse),
    // NOTE(review): presumably PCM audio samples with their sample rate and channel count —
    // sample layout and units are not established in this file.
    /// Speech payload.
    Speech {
        /// Shared buffer of `f32` values.
        pcm: Arc<Vec<f32>>,
        /// Rate value.
        rate: usize,
        /// Channel count value.
        channels: usize,
    },
    // NOTE(review): tensor shapes and the relation between `logits_chunks` and `tokens` are
    // not visible here.
    /// Raw payload: logits tensors together with token ids.
    Raw {
        /// Logits tensors, one per chunk.
        logits_chunks: Vec<Tensor>,
        /// Token ids.
        tokens: Vec<u32>,
    },
}
280
/// Error half of [`Response`]: the error variants, produced (boxed) by [`Response::as_result`].
///
/// `Debug` and `Display` are implemented by hand rather than derived, and the type implements
/// `std::error::Error`.
pub enum ResponseErr {
    /// Boxed error; formatted by delegating to the inner error.
    InternalError(Box<dyn Error + Send + Sync>),
    /// Boxed error; formatted by delegating to the inner error.
    ValidationError(Box<dyn Error + Send + Sync>),
    /// Error message plus a chat response, formatted as `ChatModelError { msg, incomplete_response }`.
    ModelError(String, ChatCompletionResponse),
    /// Error message plus a completion response, formatted as
    /// `CompletionModelError { msg, incomplete_response }`.
    CompletionModelError(String, CompletionResponse),
}
287
288impl Display for ResponseErr {
289 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
290 match self {
291 Self::InternalError(e) | Self::ValidationError(e) => Display::fmt(e, f),
292 Self::ModelError(e, x) => f
293 .debug_struct("ChatModelError")
294 .field("msg", e)
295 .field("incomplete_response", x)
296 .finish(),
297 Self::CompletionModelError(e, x) => f
298 .debug_struct("CompletionModelError")
299 .field("msg", e)
300 .field("incomplete_response", x)
301 .finish(),
302 }
303 }
304}
305
306impl Debug for ResponseErr {
307 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
308 match self {
309 Self::InternalError(e) | Self::ValidationError(e) => Debug::fmt(e, f),
310 Self::ModelError(e, x) => f
311 .debug_struct("ChatModelError")
312 .field("msg", e)
313 .field("incomplete_response", x)
314 .finish(),
315 Self::CompletionModelError(e, x) => f
316 .debug_struct("CompletionModelError")
317 .field("msg", e)
318 .field("incomplete_response", x)
319 .finish(),
320 }
321 }
322}
323
// Marker impl: every `Error` method keeps its default, so `source()` returns `None`. The inner
// boxed error is surfaced through the hand-written `Display`/`Debug` impls instead.
impl std::error::Error for ResponseErr {}
325
326impl Response {
327 pub fn as_result(self) -> Result<ResponseOk, Box<ResponseErr>> {
329 match self {
330 Self::Done(x) => Ok(ResponseOk::Done(x)),
331 Self::Chunk(x) => Ok(ResponseOk::Chunk(x)),
332 Self::CompletionDone(x) => Ok(ResponseOk::CompletionDone(x)),
333 Self::CompletionChunk(x) => Ok(ResponseOk::CompletionChunk(x)),
334 Self::InternalError(e) => Err(Box::new(ResponseErr::InternalError(e))),
335 Self::ValidationError(e) => Err(Box::new(ResponseErr::ValidationError(e))),
336 Self::ModelError(e, x) => Err(Box::new(ResponseErr::ModelError(e, x))),
337 Self::CompletionModelError(e, x) => {
338 Err(Box::new(ResponseErr::CompletionModelError(e, x)))
339 }
340 Self::ImageGeneration(x) => Ok(ResponseOk::ImageGeneration(x)),
341 Self::Speech {
342 pcm,
343 rate,
344 channels,
345 } => Ok(ResponseOk::Speech {
346 pcm,
347 rate,
348 channels,
349 }),
350 Self::Raw {
351 logits_chunks,
352 tokens,
353 } => Ok(ResponseOk::Raw {
354 logits_chunks,
355 tokens,
356 }),
357 }
358 }
359}