mistralrs_server_core/
handler_core.rs

1//! Core functionality for handlers.
2
3use anyhow::{Context, Result};
4use axum::{extract::Json, http::StatusCode, response::IntoResponse};
5use mistralrs_core::{Request, Response};
6use serde::Serialize;
7use tokio::sync::mpsc::{channel, Receiver, Sender};
8
9use crate::types::SharedMistralRsState;
10
/// Capacity used for the response channel when the caller does not specify one.
///
/// [`create_response_channel`] falls back to this value when it is given `None`.
/// The channel is bounded, so once this many responses are waiting to be read,
/// further sends have to wait for the receiver to catch up. Raising the value
/// trades memory for fewer stalls on the sending side.
pub const DEFAULT_CHANNEL_BUFFER_SIZE: usize = 10_000;
17
18/// Trait for converting errors to HTTP responses with appropriate status codes.
19pub(crate) trait ErrorToResponse: Serialize {
20    /// Converts the error to an HTTP response with the specified status code.
21    fn to_response(&self, code: StatusCode) -> axum::response::Response {
22        let mut response = Json(self).into_response();
23        *response.status_mut() = code;
24        response
25    }
26}
27
28/// Standard JSON error response structure.
29#[derive(Serialize, Debug)]
30pub(crate) struct JsonError {
31    pub(crate) message: String,
32}
33
34impl JsonError {
35    /// Creates a new JSON error with the specified message.
36    pub(crate) fn new(message: String) -> Self {
37        Self { message }
38    }
39}
40
41impl std::fmt::Display for JsonError {
42    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
43        write!(f, "{}", self.message)
44    }
45}
46
47impl std::error::Error for JsonError {}
48impl ErrorToResponse for JsonError {}
49
/// Crate-internal error that carries a model error as plain text.
///
/// The wrapped `String` is the whole error: `Display` prints it verbatim, and the
/// type implements [`std::error::Error`] so it can be boxed or propagated like
/// any other error.
#[derive(Debug)]
pub(crate) struct ModelErrorMessage(pub(crate) String);

impl std::fmt::Display for ModelErrorMessage {
    /// Writes the wrapped message unchanged.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let Self(message) = self;
        f.write_str(message)
    }
}

impl std::error::Error for ModelErrorMessage {}
64
65/// Generic JSON error response structure
66#[derive(Serialize, Debug)]
67pub(crate) struct BaseJsonModelError<T> {
68    pub(crate) message: String,
69    pub(crate) partial_response: T,
70}
71
72impl<T> BaseJsonModelError<T> {
73    pub(crate) fn new(message: String, partial_response: T) -> Self {
74        Self {
75            message,
76            partial_response,
77        }
78    }
79}
80
81/// Creates a channel for response communication.
82pub fn create_response_channel(
83    buffer_size: Option<usize>,
84) -> (Sender<Response>, Receiver<Response>) {
85    let channel_buffer_size = buffer_size.unwrap_or(DEFAULT_CHANNEL_BUFFER_SIZE);
86    channel(channel_buffer_size)
87}
88
89/// Sends a request to the model processing pipeline.
90pub async fn send_request(state: &SharedMistralRsState, request: Request) -> Result<()> {
91    send_request_with_model(state, request, None).await
92}
93
94pub async fn send_request_with_model(
95    state: &SharedMistralRsState,
96    request: Request,
97    model_id: Option<&str>,
98) -> Result<()> {
99    let sender = state
100        .get_sender(model_id)
101        .context("mistral.rs sender not available.")?;
102
103    sender
104        .send(request)
105        .await
106        .context("Failed to send request to model pipeline")
107}
108
109/// Generic function to process non-streaming responses.
110pub(crate) async fn base_process_non_streaming_response<R>(
111    rx: &mut Receiver<Response>,
112    state: SharedMistralRsState,
113    match_fn: fn(SharedMistralRsState, Response) -> R,
114    error_handler: fn(
115        SharedMistralRsState,
116        Box<dyn std::error::Error + Send + Sync + 'static>,
117    ) -> R,
118) -> R {
119    match rx.recv().await {
120        Some(response) => match_fn(state, response),
121        None => {
122            let error = anyhow::Error::msg("No response received from the model.");
123            error_handler(state, error.into())
124        }
125    }
126}