mistralrs_server_core/
image_generation.rs

1//! ## Image generation functionality and route handler.
2
3use std::{error::Error, sync::Arc};
4
5use anyhow::Result;
6use axum::{
7    extract::{Json, State},
8    http::{self},
9    response::IntoResponse,
10};
11use mistralrs_core::{
12    Constraint, DiffusionGenerationParams, ImageGenerationResponse, MistralRs, NormalRequest,
13    Request, RequestMessage, Response, SamplingParams,
14};
15use tokio::sync::mpsc::{Receiver, Sender};
16
17use crate::{
18    handler_core::{
19        base_process_non_streaming_response, create_response_channel, send_request,
20        ErrorToResponse, JsonError,
21    },
22    openai::ImageGenerationRequest,
23    types::{ExtractedMistralRsState, SharedMistralRsState},
24    util::{sanitize_error_message, validate_model_name},
25};
26
27/// Represents different types of image generation responses.
28pub enum ImageGenerationResponder {
29    Json(ImageGenerationResponse),
30    InternalError(Box<dyn Error>),
31    ValidationError(Box<dyn Error>),
32}
33
34impl IntoResponse for ImageGenerationResponder {
35    /// Converts the image generation responder into an HTTP response.
36    fn into_response(self) -> axum::response::Response {
37        match self {
38            ImageGenerationResponder::Json(s) => Json(s).into_response(),
39            ImageGenerationResponder::InternalError(e) => {
40                JsonError::new(sanitize_error_message(e.as_ref()))
41                    .to_response(http::StatusCode::INTERNAL_SERVER_ERROR)
42            }
43            ImageGenerationResponder::ValidationError(e) => {
44                JsonError::new(sanitize_error_message(e.as_ref()))
45                    .to_response(http::StatusCode::UNPROCESSABLE_ENTITY)
46            }
47        }
48    }
49}
50
51/// Parses and validates a image generation request.
52///
53/// This function transforms a image generation request into the
54/// request format used by mistral.rs.
55pub fn parse_request(
56    oairequest: ImageGenerationRequest,
57    state: Arc<MistralRs>,
58    tx: Sender<Response>,
59) -> Result<Request> {
60    let repr = serde_json::to_string(&oairequest).expect("Serialization of request failed.");
61    MistralRs::maybe_log_request(state.clone(), repr);
62
63    // Validate that the requested model matches the loaded model
64    validate_model_name(&oairequest.model, state.clone())?;
65
66    Ok(Request::Normal(Box::new(NormalRequest {
67        id: state.next_request_id(),
68        messages: RequestMessage::ImageGeneration {
69            prompt: oairequest.prompt,
70            format: oairequest.response_format,
71            generation_params: DiffusionGenerationParams {
72                height: oairequest.height,
73                width: oairequest.width,
74            },
75        },
76        sampling_params: SamplingParams::deterministic(),
77        response: tx,
78        return_logprobs: false,
79        is_streaming: false,
80        suffix: None,
81        constraint: Constraint::None,
82        tool_choice: None,
83        tools: None,
84        logits_processors: None,
85        return_raw_logits: false,
86        web_search_options: None,
87        model_id: if oairequest.model == "default" {
88            None
89        } else {
90            Some(oairequest.model.clone())
91        },
92        truncate_sequence: false,
93    })))
94}
95
96/// Image generation endpoint handler.
97#[utoipa::path(
98    post,
99    tag = "Mistral.rs",
100    path = "/v1/images/generations",
101    request_body = ImageGenerationRequest,
102    responses((status = 200, description = "Image generation"))
103)]
104pub async fn image_generation(
105    State(state): ExtractedMistralRsState,
106    Json(oairequest): Json<ImageGenerationRequest>,
107) -> ImageGenerationResponder {
108    let (tx, mut rx) = create_response_channel(None);
109
110    let request = match parse_request(oairequest, state.clone(), tx) {
111        Ok(x) => x,
112        Err(e) => return handle_error(state, e.into()),
113    };
114
115    if let Err(e) = send_request(&state, request).await {
116        return handle_error(state, e.into());
117    }
118
119    process_non_streaming_response(&mut rx, state).await
120}
121
122/// Helper function to handle image generation errors and logging them.
123pub fn handle_error(
124    state: SharedMistralRsState,
125    e: Box<dyn std::error::Error + Send + Sync + 'static>,
126) -> ImageGenerationResponder {
127    let sanitized_msg = sanitize_error_message(&*e);
128    let e = anyhow::Error::msg(sanitized_msg);
129    MistralRs::maybe_log_error(state, &*e);
130    ImageGenerationResponder::InternalError(e.into())
131}
132
133/// Process non-streaming image generation responses.
134pub async fn process_non_streaming_response(
135    rx: &mut Receiver<Response>,
136    state: SharedMistralRsState,
137) -> ImageGenerationResponder {
138    base_process_non_streaming_response(rx, state, match_responses, handle_error).await
139}
140
141/// Matches and processes different types of model responses into appropriate image generation responses.
142pub fn match_responses(
143    state: SharedMistralRsState,
144    response: Response,
145) -> ImageGenerationResponder {
146    match response {
147        Response::InternalError(e) => {
148            MistralRs::maybe_log_error(state, &*e);
149            ImageGenerationResponder::InternalError(e)
150        }
151        Response::ValidationError(e) => ImageGenerationResponder::ValidationError(e),
152        Response::ImageGeneration(response) => {
153            MistralRs::maybe_log_response(state, &response);
154            ImageGenerationResponder::Json(response)
155        }
156        Response::CompletionModelError(m, _) => {
157            let e = anyhow::Error::msg(m.to_string());
158            MistralRs::maybe_log_error(state, &*e);
159            ImageGenerationResponder::InternalError(e.into())
160        }
161        Response::CompletionDone(_) => unreachable!(),
162        Response::CompletionChunk(_) => unreachable!(),
163        Response::Chunk(_) => unreachable!(),
164        Response::Done(_) => unreachable!(),
165        Response::ModelError(_, _) => unreachable!(),
166        Response::Speech { .. } => unreachable!(),
167        Response::Raw { .. } => unreachable!(),
168        Response::Embeddings { .. } => unreachable!(),
169    }
170}