mistralrs_server_core/
handler_core.rs1use anyhow::{Context, Result};
4use axum::{extract::Json, http::StatusCode, response::IntoResponse};
5use mistralrs_core::{Request, Response};
6use serde::Serialize;
7use tokio::sync::mpsc::{channel, Receiver, Sender};
8
9use crate::types::SharedMistralRsState;
10
/// Capacity used for response channels when the caller does not request a
/// specific buffer size (see `create_response_channel`).
pub const DEFAULT_CHANNEL_BUFFER_SIZE: usize = 10_000;
17
18pub(crate) trait ErrorToResponse: Serialize {
20 fn to_response(&self, code: StatusCode) -> axum::response::Response {
22 let mut response = Json(self).into_response();
23 *response.status_mut() = code;
24 response
25 }
26}
27
28#[derive(Serialize, Debug)]
30pub(crate) struct JsonError {
31 pub(crate) message: String,
32}
33
34impl JsonError {
35 pub(crate) fn new(message: String) -> Self {
37 Self { message }
38 }
39}
40
41impl std::fmt::Display for JsonError {
42 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
43 write!(f, "{}", self.message)
44 }
45}
46
47impl std::error::Error for JsonError {}
48impl ErrorToResponse for JsonError {}
49
/// Newtype wrapping a plain error message so it can be used as a
/// `std::error::Error`.
///
/// NOTE(review): the name suggests the text comes from the model pipeline;
/// confirm at construction sites.
#[derive(Debug)]
pub(crate) struct ModelErrorMessage(pub(crate) String);

impl std::fmt::Display for ModelErrorMessage {
    /// Emits the wrapped message verbatim.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let ModelErrorMessage(text) = self;
        f.write_str(text)
    }
}

impl std::error::Error for ModelErrorMessage {}
64
65#[derive(Serialize, Debug)]
67pub(crate) struct BaseJsonModelError<T> {
68 pub(crate) message: String,
69 pub(crate) partial_response: T,
70}
71
72impl<T> BaseJsonModelError<T> {
73 pub(crate) fn new(message: String, partial_response: T) -> Self {
74 Self {
75 message,
76 partial_response,
77 }
78 }
79}
80
81pub fn create_response_channel(
83 buffer_size: Option<usize>,
84) -> (Sender<Response>, Receiver<Response>) {
85 let channel_buffer_size = buffer_size.unwrap_or(DEFAULT_CHANNEL_BUFFER_SIZE);
86 channel(channel_buffer_size)
87}
88
89pub async fn send_request(state: &SharedMistralRsState, request: Request) -> Result<()> {
91 send_request_with_model(state, request, None).await
92}
93
94pub async fn send_request_with_model(
95 state: &SharedMistralRsState,
96 request: Request,
97 model_id: Option<&str>,
98) -> Result<()> {
99 let sender = state
100 .get_sender(model_id)
101 .context("mistral.rs sender not available.")?;
102
103 sender
104 .send(request)
105 .await
106 .context("Failed to send request to model pipeline")
107}
108
109pub(crate) async fn base_process_non_streaming_response<R>(
111 rx: &mut Receiver<Response>,
112 state: SharedMistralRsState,
113 match_fn: fn(SharedMistralRsState, Response) -> R,
114 error_handler: fn(
115 SharedMistralRsState,
116 Box<dyn std::error::Error + Send + Sync + 'static>,
117 ) -> R,
118) -> R {
119 match rx.recv().await {
120 Some(response) => match_fn(state, response),
121 None => {
122 let error = anyhow::Error::msg("No response received from the model.");
123 error_handler(state, error.into())
124 }
125 }
126}