mistralrs_server/
image_generation.rs

1use anyhow::Result;
2use std::{error::Error, sync::Arc};
3use tokio::sync::mpsc::{channel, Sender};
4
5use crate::openai::ImageGenerationRequest;
6use axum::{
7    extract::{Json, State},
8    http::{self, StatusCode},
9    response::IntoResponse,
10};
11use mistralrs_core::{
12    Constraint, DiffusionGenerationParams, ImageGenerationResponse, MistralRs, NormalRequest,
13    Request, RequestMessage, Response, SamplingParams,
14};
15use serde::Serialize;
16
17pub enum ImageGenerationResponder {
18    Json(ImageGenerationResponse),
19    InternalError(Box<dyn Error>),
20    ValidationError(Box<dyn Error>),
21}
22
23trait ErrorToResponse: Serialize {
24    fn to_response(&self, code: StatusCode) -> axum::response::Response {
25        let mut r = Json(self).into_response();
26        *r.status_mut() = code;
27        r
28    }
29}
30
31#[derive(Serialize)]
32struct JsonError {
33    message: String,
34}
35
36impl JsonError {
37    fn new(message: String) -> Self {
38        Self { message }
39    }
40}
41impl ErrorToResponse for JsonError {}
42
43impl IntoResponse for ImageGenerationResponder {
44    fn into_response(self) -> axum::response::Response {
45        match self {
46            ImageGenerationResponder::Json(s) => Json(s).into_response(),
47            ImageGenerationResponder::InternalError(e) => {
48                JsonError::new(e.to_string()).to_response(http::StatusCode::INTERNAL_SERVER_ERROR)
49            }
50            ImageGenerationResponder::ValidationError(e) => {
51                JsonError::new(e.to_string()).to_response(http::StatusCode::UNPROCESSABLE_ENTITY)
52            }
53        }
54    }
55}
56
57fn parse_request(
58    oairequest: ImageGenerationRequest,
59    state: Arc<MistralRs>,
60    tx: Sender<Response>,
61) -> Result<Request> {
62    let repr = serde_json::to_string(&oairequest).expect("Serialization of request failed.");
63    MistralRs::maybe_log_request(state.clone(), repr);
64
65    Ok(Request::Normal(NormalRequest {
66        id: state.next_request_id(),
67        messages: RequestMessage::ImageGeneration {
68            prompt: oairequest.prompt,
69            format: oairequest.response_format,
70            generation_params: DiffusionGenerationParams {
71                height: oairequest.height,
72                width: oairequest.width,
73            },
74        },
75        sampling_params: SamplingParams::deterministic(),
76        response: tx,
77        return_logprobs: false,
78        is_streaming: false,
79        suffix: None,
80        constraint: Constraint::None,
81        adapters: None,
82        tool_choice: None,
83        tools: None,
84        logits_processors: None,
85        return_raw_logits: false,
86        web_search_options: None,
87    }))
88}
89
90#[utoipa::path(
91    post,
92    tag = "Mistral.rs",
93    path = "/v1/images/generations",
94    request_body = ImageGenerationRequest,
95    responses((status = 200, description = "Image generation"))
96)]
97
98pub async fn image_generation(
99    State(state): State<Arc<MistralRs>>,
100    Json(oairequest): Json<ImageGenerationRequest>,
101) -> ImageGenerationResponder {
102    let (tx, mut rx) = channel(10_000);
103
104    let request = match parse_request(oairequest, state.clone(), tx) {
105        Ok(x) => x,
106        Err(e) => {
107            let e = anyhow::Error::msg(e.to_string());
108            MistralRs::maybe_log_error(state, &*e);
109            return ImageGenerationResponder::InternalError(e.into());
110        }
111    };
112    let sender = state.get_sender().unwrap();
113
114    if let Err(e) = sender.send(request).await {
115        let e = anyhow::Error::msg(e.to_string());
116        MistralRs::maybe_log_error(state, &*e);
117        return ImageGenerationResponder::InternalError(e.into());
118    }
119
120    let response = match rx.recv().await {
121        Some(response) => response,
122        None => {
123            let e = anyhow::Error::msg("No response received from the model.");
124            MistralRs::maybe_log_error(state, &*e);
125            return ImageGenerationResponder::InternalError(e.into());
126        }
127    };
128
129    match response {
130        Response::InternalError(e) => {
131            MistralRs::maybe_log_error(state, &*e);
132            ImageGenerationResponder::InternalError(e)
133        }
134        Response::ValidationError(e) => ImageGenerationResponder::ValidationError(e),
135        Response::ImageGeneration(response) => {
136            MistralRs::maybe_log_response(state, &response);
137            ImageGenerationResponder::Json(response)
138        }
139        Response::CompletionModelError(m, _) => {
140            let e = anyhow::Error::msg(m.to_string());
141            MistralRs::maybe_log_error(state, &*e);
142            ImageGenerationResponder::InternalError(e.into())
143        }
144        Response::CompletionDone(_) => unreachable!(),
145        Response::CompletionChunk(_) => unreachable!(),
146        Response::Chunk(_) => unreachable!(),
147        Response::Done(_) => unreachable!(),
148        Response::ModelError(_, _) => unreachable!(),
149        Response::Raw { .. } => unreachable!(),
150    }
151}