// mistralrs_server/image_generation.rs

1use anyhow::Result;
2use std::{error::Error, sync::Arc};
3use tokio::sync::mpsc::{channel, Sender};
4
5use crate::openai::ImageGenerationRequest;
6use axum::{
7    extract::{Json, State},
8    http::{self, StatusCode},
9    response::IntoResponse,
10};
11use mistralrs_core::{
12    Constraint, DiffusionGenerationParams, ImageGenerationResponse, MistralRs, NormalRequest,
13    Request, RequestMessage, Response, SamplingParams,
14};
15use serde::Serialize;
16
17pub enum ImageGenerationResponder {
18    Json(ImageGenerationResponse),
19    InternalError(Box<dyn Error>),
20    ValidationError(Box<dyn Error>),
21}
22
23trait ErrorToResponse: Serialize {
24    fn to_response(&self, code: StatusCode) -> axum::response::Response {
25        let mut r = Json(self).into_response();
26        *r.status_mut() = code;
27        r
28    }
29}
30
31#[derive(Serialize)]
32struct JsonError {
33    message: String,
34}
35
36impl JsonError {
37    fn new(message: String) -> Self {
38        Self { message }
39    }
40}
41impl ErrorToResponse for JsonError {}
42
43impl IntoResponse for ImageGenerationResponder {
44    fn into_response(self) -> axum::response::Response {
45        match self {
46            ImageGenerationResponder::Json(s) => Json(s).into_response(),
47            ImageGenerationResponder::InternalError(e) => {
48                JsonError::new(e.to_string()).to_response(http::StatusCode::INTERNAL_SERVER_ERROR)
49            }
50            ImageGenerationResponder::ValidationError(e) => {
51                JsonError::new(e.to_string()).to_response(http::StatusCode::UNPROCESSABLE_ENTITY)
52            }
53        }
54    }
55}
56
57fn parse_request(
58    oairequest: ImageGenerationRequest,
59    state: Arc<MistralRs>,
60    tx: Sender<Response>,
61) -> Result<Request> {
62    let repr = serde_json::to_string(&oairequest).expect("Serialization of request failed.");
63    MistralRs::maybe_log_request(state.clone(), repr);
64
65    Ok(Request::Normal(NormalRequest {
66        id: state.next_request_id(),
67        messages: RequestMessage::ImageGeneration {
68            prompt: oairequest.prompt,
69            format: oairequest.response_format,
70            generation_params: DiffusionGenerationParams {
71                height: oairequest.height,
72                width: oairequest.width,
73            },
74        },
75        sampling_params: SamplingParams::deterministic(),
76        response: tx,
77        return_logprobs: false,
78        is_streaming: false,
79        suffix: None,
80        constraint: Constraint::None,
81        tool_choice: None,
82        tools: None,
83        logits_processors: None,
84        return_raw_logits: false,
85        web_search_options: None,
86    }))
87}
88
89#[utoipa::path(
90    post,
91    tag = "Mistral.rs",
92    path = "/v1/images/generations",
93    request_body = ImageGenerationRequest,
94    responses((status = 200, description = "Image generation"))
95)]
96
97pub async fn image_generation(
98    State(state): State<Arc<MistralRs>>,
99    Json(oairequest): Json<ImageGenerationRequest>,
100) -> ImageGenerationResponder {
101    let (tx, mut rx) = channel(10_000);
102
103    let request = match parse_request(oairequest, state.clone(), tx) {
104        Ok(x) => x,
105        Err(e) => {
106            let e = anyhow::Error::msg(e.to_string());
107            MistralRs::maybe_log_error(state, &*e);
108            return ImageGenerationResponder::InternalError(e.into());
109        }
110    };
111    let sender = state.get_sender().unwrap();
112
113    if let Err(e) = sender.send(request).await {
114        let e = anyhow::Error::msg(e.to_string());
115        MistralRs::maybe_log_error(state, &*e);
116        return ImageGenerationResponder::InternalError(e.into());
117    }
118
119    let response = match rx.recv().await {
120        Some(response) => response,
121        None => {
122            let e = anyhow::Error::msg("No response received from the model.");
123            MistralRs::maybe_log_error(state, &*e);
124            return ImageGenerationResponder::InternalError(e.into());
125        }
126    };
127
128    match response {
129        Response::InternalError(e) => {
130            MistralRs::maybe_log_error(state, &*e);
131            ImageGenerationResponder::InternalError(e)
132        }
133        Response::ValidationError(e) => ImageGenerationResponder::ValidationError(e),
134        Response::ImageGeneration(response) => {
135            MistralRs::maybe_log_response(state, &response);
136            ImageGenerationResponder::Json(response)
137        }
138        Response::CompletionModelError(m, _) => {
139            let e = anyhow::Error::msg(m.to_string());
140            MistralRs::maybe_log_error(state, &*e);
141            ImageGenerationResponder::InternalError(e.into())
142        }
143        Response::CompletionDone(_) => unreachable!(),
144        Response::CompletionChunk(_) => unreachable!(),
145        Response::Chunk(_) => unreachable!(),
146        Response::Done(_) => unreachable!(),
147        Response::ModelError(_, _) => unreachable!(),
148        Response::Raw { .. } => unreachable!(),
149    }
150}