mistralrs_server/
image_generation.rs1use anyhow::Result;
2use std::{error::Error, sync::Arc};
3use tokio::sync::mpsc::{channel, Sender};
4
5use crate::openai::ImageGenerationRequest;
6use axum::{
7 extract::{Json, State},
8 http::{self, StatusCode},
9 response::IntoResponse,
10};
11use mistralrs_core::{
12 Constraint, DiffusionGenerationParams, ImageGenerationResponse, MistralRs, NormalRequest,
13 Request, RequestMessage, Response, SamplingParams,
14};
15use serde::Serialize;
16
17pub enum ImageGenerationResponder {
18 Json(ImageGenerationResponse),
19 InternalError(Box<dyn Error>),
20 ValidationError(Box<dyn Error>),
21}
22
23trait ErrorToResponse: Serialize {
24 fn to_response(&self, code: StatusCode) -> axum::response::Response {
25 let mut r = Json(self).into_response();
26 *r.status_mut() = code;
27 r
28 }
29}
30
31#[derive(Serialize)]
32struct JsonError {
33 message: String,
34}
35
36impl JsonError {
37 fn new(message: String) -> Self {
38 Self { message }
39 }
40}
41impl ErrorToResponse for JsonError {}
42
43impl IntoResponse for ImageGenerationResponder {
44 fn into_response(self) -> axum::response::Response {
45 match self {
46 ImageGenerationResponder::Json(s) => Json(s).into_response(),
47 ImageGenerationResponder::InternalError(e) => {
48 JsonError::new(e.to_string()).to_response(http::StatusCode::INTERNAL_SERVER_ERROR)
49 }
50 ImageGenerationResponder::ValidationError(e) => {
51 JsonError::new(e.to_string()).to_response(http::StatusCode::UNPROCESSABLE_ENTITY)
52 }
53 }
54 }
55}
56
57fn parse_request(
58 oairequest: ImageGenerationRequest,
59 state: Arc<MistralRs>,
60 tx: Sender<Response>,
61) -> Result<Request> {
62 let repr = serde_json::to_string(&oairequest).expect("Serialization of request failed.");
63 MistralRs::maybe_log_request(state.clone(), repr);
64
65 Ok(Request::Normal(NormalRequest {
66 id: state.next_request_id(),
67 messages: RequestMessage::ImageGeneration {
68 prompt: oairequest.prompt,
69 format: oairequest.response_format,
70 generation_params: DiffusionGenerationParams {
71 height: oairequest.height,
72 width: oairequest.width,
73 },
74 },
75 sampling_params: SamplingParams::deterministic(),
76 response: tx,
77 return_logprobs: false,
78 is_streaming: false,
79 suffix: None,
80 constraint: Constraint::None,
81 tool_choice: None,
82 tools: None,
83 logits_processors: None,
84 return_raw_logits: false,
85 web_search_options: None,
86 }))
87}
88
89#[utoipa::path(
90 post,
91 tag = "Mistral.rs",
92 path = "/v1/images/generations",
93 request_body = ImageGenerationRequest,
94 responses((status = 200, description = "Image generation"))
95)]
96
97pub async fn image_generation(
98 State(state): State<Arc<MistralRs>>,
99 Json(oairequest): Json<ImageGenerationRequest>,
100) -> ImageGenerationResponder {
101 let (tx, mut rx) = channel(10_000);
102
103 let request = match parse_request(oairequest, state.clone(), tx) {
104 Ok(x) => x,
105 Err(e) => {
106 let e = anyhow::Error::msg(e.to_string());
107 MistralRs::maybe_log_error(state, &*e);
108 return ImageGenerationResponder::InternalError(e.into());
109 }
110 };
111 let sender = state.get_sender().unwrap();
112
113 if let Err(e) = sender.send(request).await {
114 let e = anyhow::Error::msg(e.to_string());
115 MistralRs::maybe_log_error(state, &*e);
116 return ImageGenerationResponder::InternalError(e.into());
117 }
118
119 let response = match rx.recv().await {
120 Some(response) => response,
121 None => {
122 let e = anyhow::Error::msg("No response received from the model.");
123 MistralRs::maybe_log_error(state, &*e);
124 return ImageGenerationResponder::InternalError(e.into());
125 }
126 };
127
128 match response {
129 Response::InternalError(e) => {
130 MistralRs::maybe_log_error(state, &*e);
131 ImageGenerationResponder::InternalError(e)
132 }
133 Response::ValidationError(e) => ImageGenerationResponder::ValidationError(e),
134 Response::ImageGeneration(response) => {
135 MistralRs::maybe_log_response(state, &response);
136 ImageGenerationResponder::Json(response)
137 }
138 Response::CompletionModelError(m, _) => {
139 let e = anyhow::Error::msg(m.to_string());
140 MistralRs::maybe_log_error(state, &*e);
141 ImageGenerationResponder::InternalError(e.into())
142 }
143 Response::CompletionDone(_) => unreachable!(),
144 Response::CompletionChunk(_) => unreachable!(),
145 Response::Chunk(_) => unreachable!(),
146 Response::Done(_) => unreachable!(),
147 Response::ModelError(_, _) => unreachable!(),
148 Response::Raw { .. } => unreachable!(),
149 }
150}