use std::{
error::Error,
fmt::{Debug, Display},
};
use candle_core::Tensor;
#[cfg(feature = "pyo3_macros")]
use pyo3::{pyclass, pymethods};
use serde::Serialize;
use crate::{sampler::TopLogprob, tools::ToolCallResponse};
/// Fingerprint string constant; its value is the literal `"local"`.
///
/// NOTE(review): presumably used to fill the `system_fingerprint` fields of the response
/// types in this file; no use is visible here, so confirm at the call sites.
pub const SYSTEM_FINGERPRINT: &str = "local";
/// Emits a `#[pymethods]` block for the named type that provides a Python `__repr__`.
///
/// The generated `impl` is gated on the `pyo3_macros` feature, so without that feature the
/// macro expands to nothing. The type passed in must implement `Debug`, because the
/// representation is its pretty-printed (`{:#?}`) `Debug` output.
macro_rules! generate_repr {
    ($ty:ident) => {
        // Only compiled when the Python bindings are enabled.
        #[cfg(feature = "pyo3_macros")]
        #[pymethods]
        impl $ty {
            fn __repr__(&self) -> String {
                // Multi-line `Debug` rendering of the value.
                format!("{:#?}", self)
            }
        }
    };
}
/// Message payload of a completed chat choice (it is the type of `Choice::message`).
///
/// Serializable via `serde`. With the `pyo3_macros` feature it is also a Python class with
/// getters for every field (`get_all`).
#[cfg_attr(feature = "pyo3_macros", pyclass)]
#[cfg_attr(feature = "pyo3_macros", pyo3(get_all))]
#[derive(Debug, Clone, Serialize)]
pub struct ResponseMessage {
/// Text content of the message; `None` when there is none.
pub content: Option<String>,
/// Role associated with the message; this type keeps it as a free-form string.
pub role: String,
/// Tool calls attached to the message (`ToolCallResponse` comes from `crate::tools`).
pub tool_calls: Vec<ToolCallResponse>,
}
// With `pyo3_macros` enabled, adds a `__repr__` returning the pretty-printed `Debug` output.
generate_repr!(ResponseMessage);
/// Message fragment carried by a `ChunkChoice` (it is the type of `ChunkChoice::delta`).
///
/// NOTE(review): presumably the incremental piece of a streamed message; confirm at the
/// producer. Serializable via `serde`; a Python class with getters under `pyo3_macros`.
#[cfg_attr(feature = "pyo3_macros", pyclass)]
#[cfg_attr(feature = "pyo3_macros", pyo3(get_all))]
#[derive(Debug, Clone, Serialize)]
pub struct Delta {
/// Text content of this fragment.
pub content: String,
/// Role associated with the fragment; this type keeps it as a free-form string.
pub role: String,
}
// With `pyo3_macros` enabled, adds a `__repr__` returning the pretty-printed `Debug` output.
generate_repr!(Delta);
/// Log-probability record for one token.
///
/// Serializable via `serde`; a Python class with getters under `pyo3_macros`.
#[cfg_attr(feature = "pyo3_macros", pyclass)]
#[cfg_attr(feature = "pyo3_macros", pyo3(get_all))]
#[derive(Debug, Clone, Serialize)]
pub struct ResponseLogprob {
/// The token, as a string.
pub token: String,
/// Log-probability value recorded for `token`.
pub logprob: f32,
/// Optional raw bytes for the token.
/// NOTE(review): presumably the token's byte encoding; confirm at the producer.
pub bytes: Option<Vec<u8>>,
/// Further candidate entries (`TopLogprob` comes from `crate::sampler`).
pub top_logprobs: Vec<TopLogprob>,
}
// With `pyo3_macros` enabled, adds a `__repr__` returning the pretty-printed `Debug` output.
generate_repr!(ResponseLogprob);
/// Wrapper around an optional list of per-token log-probability records.
///
/// Used as the type of `Choice::logprobs`. Serializable via `serde`; a Python class with
/// getters under `pyo3_macros`.
#[cfg_attr(feature = "pyo3_macros", pyclass)]
#[cfg_attr(feature = "pyo3_macros", pyo3(get_all))]
#[derive(Debug, Clone, Serialize)]
pub struct Logprobs {
/// The log-probability records, or `None` when there are none.
pub content: Option<Vec<ResponseLogprob>>,
}
// With `pyo3_macros` enabled, adds a `__repr__` returning the pretty-printed `Debug` output.
generate_repr!(Logprobs);
/// One choice inside a `ChatCompletionResponse` (element type of its `choices` vector).
///
/// Serializable via `serde`; a Python class with getters under `pyo3_macros`.
#[cfg_attr(feature = "pyo3_macros", pyclass)]
#[cfg_attr(feature = "pyo3_macros", pyo3(get_all))]
#[derive(Debug, Clone, Serialize)]
pub struct Choice {
/// Why generation finished, as a free-form string (always present on this type).
pub finish_reason: String,
/// Index of this choice.
pub index: usize,
/// The message produced for this choice.
pub message: ResponseMessage,
/// Log-probability data, if any.
pub logprobs: Option<Logprobs>,
}
// With `pyo3_macros` enabled, adds a `__repr__` returning the pretty-printed `Debug` output.
generate_repr!(Choice);
/// One choice inside a `ChatCompletionChunkResponse` (element type of its `choices` vector).
///
/// Serializable via `serde`; a Python class with getters under `pyo3_macros`.
#[cfg_attr(feature = "pyo3_macros", pyclass)]
#[cfg_attr(feature = "pyo3_macros", pyo3(get_all))]
#[derive(Debug, Clone, Serialize)]
pub struct ChunkChoice {
/// Why generation finished, if set.
/// NOTE(review): presumably `None` until the final chunk; confirm at the producer.
pub finish_reason: Option<String>,
/// Index of this choice.
pub index: usize,
/// The message fragment carried by this chunk.
pub delta: Delta,
/// A single log-probability record, if any (not the `Logprobs` wrapper used by `Choice`).
pub logprobs: Option<ResponseLogprob>,
}
// With `pyo3_macros` enabled, adds a `__repr__` returning the pretty-printed `Debug` output.
generate_repr!(ChunkChoice);
/// One choice inside a `CompletionChunkResponse` (element type of its `choices` vector).
///
/// Serializable via `serde`; a Python class with getters under `pyo3_macros`.
#[cfg_attr(feature = "pyo3_macros", pyclass)]
#[cfg_attr(feature = "pyo3_macros", pyo3(get_all))]
#[derive(Debug, Clone, Serialize)]
pub struct CompletionChunkChoice {
/// Text carried by this chunk.
pub text: String,
/// Index of this choice.
pub index: usize,
/// A single log-probability record, if any.
pub logprobs: Option<ResponseLogprob>,
/// Why generation finished, if set.
/// NOTE(review): presumably `None` until the final chunk; confirm at the producer.
pub finish_reason: Option<String>,
}
// With `pyo3_macros` enabled, adds a `__repr__` returning the pretty-printed `Debug` output.
generate_repr!(CompletionChunkChoice);
/// Token counts and timing figures attached to the non-chunk response types in this file.
///
/// Units follow the field names (`_sec` = seconds, `_per_sec` = per second). How the values
/// are computed is not visible here. Serializable via `serde`; a Python class with getters
/// under `pyo3_macros`.
#[cfg_attr(feature = "pyo3_macros", pyclass)]
#[cfg_attr(feature = "pyo3_macros", pyo3(get_all))]
#[derive(Debug, Clone, Serialize)]
pub struct Usage {
/// Number of completion tokens.
pub completion_tokens: usize,
/// Number of prompt tokens.
pub prompt_tokens: usize,
/// Total number of tokens.
/// NOTE(review): presumably `prompt_tokens + completion_tokens`; not enforced by this type.
pub total_tokens: usize,
/// Average tokens per second overall.
pub avg_tok_per_sec: f32,
/// Average prompt tokens per second.
pub avg_prompt_tok_per_sec: f32,
/// Average completion tokens per second.
pub avg_compl_tok_per_sec: f32,
/// Total time, in seconds.
pub total_time_sec: f32,
/// Time attributed to the prompt, in seconds.
pub total_prompt_time_sec: f32,
/// Time attributed to the completion, in seconds.
pub total_completion_time_sec: f32,
}
// With `pyo3_macros` enabled, adds a `__repr__` returning the pretty-printed `Debug` output.
generate_repr!(Usage);
/// Complete (non-chunk) chat completion response.
///
/// Carried by `Response::Done` and, as the incomplete response, by `Response::ModelError`.
/// Serializable via `serde`; a Python class with getters under `pyo3_macros`.
#[cfg_attr(feature = "pyo3_macros", pyclass)]
#[cfg_attr(feature = "pyo3_macros", pyo3(get_all))]
#[derive(Debug, Clone, Serialize)]
pub struct ChatCompletionResponse {
/// Identifier of this response.
pub id: String,
/// The choices contained in the response.
pub choices: Vec<Choice>,
/// Creation timestamp (unit and epoch are not visible here). Note this is `u64`, while the
/// chunk response types in this file use `u128`.
pub created: u64,
/// Model name.
pub model: String,
/// System fingerprint string.
pub system_fingerprint: String,
/// Object-type tag; the value is chosen by whoever builds the response.
pub object: String,
/// Token counts and timing figures.
pub usage: Usage,
}
// With `pyo3_macros` enabled, adds a `__repr__` returning the pretty-printed `Debug` output.
generate_repr!(ChatCompletionResponse);
/// One chunk of a chat completion, carried by `Response::Chunk`.
///
/// Unlike `ChatCompletionResponse` it has no `usage` field. Serializable via `serde`; a
/// Python class with getters under `pyo3_macros`.
#[cfg_attr(feature = "pyo3_macros", pyclass)]
#[cfg_attr(feature = "pyo3_macros", pyo3(get_all))]
#[derive(Debug, Clone, Serialize)]
pub struct ChatCompletionChunkResponse {
/// Identifier of this response.
pub id: String,
/// The chunk choices contained in this chunk.
pub choices: Vec<ChunkChoice>,
/// Creation timestamp (unit and epoch are not visible here). Note this is `u128`, while the
/// non-chunk response types in this file use `u64`.
pub created: u128,
/// Model name.
pub model: String,
/// System fingerprint string.
pub system_fingerprint: String,
/// Object-type tag; the value is chosen by whoever builds the response.
pub object: String,
}
// With `pyo3_macros` enabled, adds a `__repr__` returning the pretty-printed `Debug` output.
generate_repr!(ChatCompletionChunkResponse);
/// One choice inside a `CompletionResponse` (element type of its `choices` vector).
///
/// Serializable via `serde`; a Python class with getters under `pyo3_macros`.
#[cfg_attr(feature = "pyo3_macros", pyclass)]
#[cfg_attr(feature = "pyo3_macros", pyo3(get_all))]
#[derive(Debug, Clone, Serialize)]
pub struct CompletionChoice {
/// Why generation finished, as a free-form string (always present on this type).
pub finish_reason: String,
/// Index of this choice.
pub index: usize,
/// The generated text.
pub text: String,
/// Placeholder for log-probability data: typed `Option<()>`, so it can carry no payload.
pub logprobs: Option<()>,
}
// With `pyo3_macros` enabled, adds a `__repr__` returning the pretty-printed `Debug` output.
generate_repr!(CompletionChoice);
/// Complete (non-chunk) text completion response.
///
/// Carried by `Response::CompletionDone` and, as the incomplete response, by
/// `Response::CompletionModelError`. Serializable via `serde`; a Python class with getters
/// under `pyo3_macros`.
#[cfg_attr(feature = "pyo3_macros", pyclass)]
#[cfg_attr(feature = "pyo3_macros", pyo3(get_all))]
#[derive(Debug, Clone, Serialize)]
pub struct CompletionResponse {
/// Identifier of this response.
pub id: String,
/// The choices contained in the response.
pub choices: Vec<CompletionChoice>,
/// Creation timestamp (unit and epoch are not visible here). Note this is `u64`, while the
/// chunk response types in this file use `u128`.
pub created: u64,
/// Model name.
pub model: String,
/// System fingerprint string.
pub system_fingerprint: String,
/// Object-type tag; the value is chosen by whoever builds the response.
pub object: String,
/// Token counts and timing figures.
pub usage: Usage,
}
// With `pyo3_macros` enabled, adds a `__repr__` returning the pretty-printed `Debug` output.
generate_repr!(CompletionResponse);
/// One chunk of a text completion, carried by `Response::CompletionChunk`.
///
/// Unlike `CompletionResponse` it has no `usage` field. Serializable via `serde`; a Python
/// class with getters under `pyo3_macros`.
#[cfg_attr(feature = "pyo3_macros", pyclass)]
#[cfg_attr(feature = "pyo3_macros", pyo3(get_all))]
#[derive(Debug, Clone, Serialize)]
pub struct CompletionChunkResponse {
/// Identifier of this response.
pub id: String,
/// The chunk choices contained in this chunk.
pub choices: Vec<CompletionChunkChoice>,
/// Creation timestamp (unit and epoch are not visible here). Note this is `u128`, while the
/// non-chunk response types in this file use `u64`.
pub created: u128,
/// Model name.
pub model: String,
/// System fingerprint string.
pub system_fingerprint: String,
/// Object-type tag; the value is chosen by whoever builds the response.
pub object: String,
}
// With `pyo3_macros` enabled, adds a `__repr__` returning the pretty-printed `Debug` output.
generate_repr!(CompletionChunkResponse);
/// One image entry inside an `ImageGenerationResponse` (element type of its `data` vector).
///
/// Both fields are optional and nothing in this type forces exactly one of them to be set.
/// Serializable via `serde`; a Python class with getters under `pyo3_macros`.
#[cfg_attr(feature = "pyo3_macros", pyclass)]
#[cfg_attr(feature = "pyo3_macros", pyo3(get_all))]
#[derive(Debug, Clone, Serialize)]
pub struct ImageChoice {
/// URL for the image, if provided.
pub url: Option<String>,
/// Image data as a string, if provided.
/// NOTE(review): presumably base64-encoded, judging by the field name; confirm at the
/// producer.
pub b64_json: Option<String>,
}
// With `pyo3_macros` enabled, adds a `__repr__` returning the pretty-printed `Debug` output.
generate_repr!(ImageChoice);
/// Image generation response, carried by `Response::ImageGeneration`.
///
/// Serializable via `serde`; a Python class with getters under `pyo3_macros`.
#[cfg_attr(feature = "pyo3_macros", pyclass)]
#[cfg_attr(feature = "pyo3_macros", pyo3(get_all))]
#[derive(Debug, Clone, Serialize)]
pub struct ImageGenerationResponse {
/// Creation timestamp (unit and epoch are not visible here).
pub created: u128,
/// The image entries.
pub data: Vec<ImageChoice>,
}
// With `pyo3_macros` enabled, adds a `__repr__` returning the pretty-printed `Debug` output.
generate_repr!(ImageGenerationResponse);
/// Any response message, successful or not.
///
/// `Response::as_result` splits these variants into `ResponseOk` (success) and `ResponseErr`
/// (failure). This enum derives no traits itself.
pub enum Response {
/// Internal failure wrapping the underlying error; maps to `ResponseErr::InternalError`.
InternalError(Box<dyn Error + Send + Sync>),
/// Validation failure wrapping the underlying error; maps to `ResponseErr::ValidationError`.
ValidationError(Box<dyn Error + Send + Sync>),
/// Chat model failure: an error message plus the incomplete chat response.
ModelError(String, ChatCompletionResponse),
/// A complete chat completion.
Done(ChatCompletionResponse),
/// One chunk of a chat completion.
Chunk(ChatCompletionChunkResponse),
/// Completion model failure: an error message plus the incomplete completion response.
CompletionModelError(String, CompletionResponse),
/// A complete text completion.
CompletionDone(CompletionResponse),
/// One chunk of a text completion.
CompletionChunk(CompletionChunkResponse),
/// An image generation result.
ImageGeneration(ImageGenerationResponse),
/// Raw output: logits tensors together with token ids.
/// NOTE(review): how `logits_chunks` lines up with `tokens` is not visible here.
Raw {
logits_chunks: Vec<Tensor>,
tokens: Vec<u32>,
},
}
/// Success half of `Response`, produced by `Response::as_result`.
///
/// Each variant mirrors the `Response` variant of the same name.
#[derive(Debug, Clone)]
pub enum ResponseOk {
/// A complete chat completion.
Done(ChatCompletionResponse),
/// One chunk of a chat completion.
Chunk(ChatCompletionChunkResponse),
/// A complete text completion.
CompletionDone(CompletionResponse),
/// One chunk of a text completion.
CompletionChunk(CompletionChunkResponse),
/// An image generation result.
ImageGeneration(ImageGenerationResponse),
/// Raw output: logits tensors together with token ids.
Raw {
logits_chunks: Vec<Tensor>,
tokens: Vec<u32>,
},
}
/// Failure half of `Response`, produced (boxed) by `Response::as_result`.
///
/// `Debug`, `Display` and `std::error::Error` are implemented by hand in this file rather
/// than derived.
pub enum ResponseErr {
/// Internal failure wrapping the underlying error.
InternalError(Box<dyn Error + Send + Sync>),
/// Validation failure wrapping the underlying error.
ValidationError(Box<dyn Error + Send + Sync>),
/// Chat model failure: an error message plus the incomplete chat response.
ModelError(String, ChatCompletionResponse),
/// Completion model failure: an error message plus the incomplete completion response.
CompletionModelError(String, CompletionResponse),
}
impl Display for ResponseErr {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
match self {
Self::InternalError(e) | Self::ValidationError(e) => Display::fmt(e, f),
Self::ModelError(e, x) => f
.debug_struct("ChatModelError")
.field("msg", e)
.field("incomplete_response", x)
.finish(),
Self::CompletionModelError(e, x) => f
.debug_struct("CompletionModelError")
.field("msg", e)
.field("incomplete_response", x)
.finish(),
}
}
}
impl Debug for ResponseErr {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
match self {
Self::InternalError(e) | Self::ValidationError(e) => Debug::fmt(e, f),
Self::ModelError(e, x) => f
.debug_struct("ChatModelError")
.field("msg", e)
.field("incomplete_response", x)
.finish(),
Self::CompletionModelError(e, x) => f
.debug_struct("CompletionModelError")
.field("msg", e)
.field("incomplete_response", x)
.finish(),
}
}
}
// Empty impl: `ResponseErr` uses the `Error` trait's default methods and relies on the
// hand-written `Debug` and `Display` impls for `ResponseErr` in this file.
impl std::error::Error for ResponseErr {}
impl Response {
pub fn as_result(self) -> Result<ResponseOk, Box<ResponseErr>> {
match self {
Self::Done(x) => Ok(ResponseOk::Done(x)),
Self::Chunk(x) => Ok(ResponseOk::Chunk(x)),
Self::CompletionDone(x) => Ok(ResponseOk::CompletionDone(x)),
Self::CompletionChunk(x) => Ok(ResponseOk::CompletionChunk(x)),
Self::InternalError(e) => Err(Box::new(ResponseErr::InternalError(e))),
Self::ValidationError(e) => Err(Box::new(ResponseErr::ValidationError(e))),
Self::ModelError(e, x) => Err(Box::new(ResponseErr::ModelError(e, x))),
Self::CompletionModelError(e, x) => {
Err(Box::new(ResponseErr::CompletionModelError(e, x)))
}
Self::ImageGeneration(x) => Ok(ResponseOk::ImageGeneration(x)),
Self::Raw {
logits_chunks,
tokens,
} => Ok(ResponseOk::Raw {
logits_chunks,
tokens,
}),
}
}
}