mistralrs_server/
image_generation.rs1use anyhow::Result;
2use std::{error::Error, sync::Arc};
3use tokio::sync::mpsc::{channel, Sender};
4
5use crate::openai::ImageGenerationRequest;
6use axum::{
7 extract::{Json, State},
8 http::{self, StatusCode},
9 response::IntoResponse,
10};
11use mistralrs_core::{
12 Constraint, DiffusionGenerationParams, ImageGenerationResponse, MistralRs, NormalRequest,
13 Request, RequestMessage, Response, SamplingParams,
14};
15use serde::Serialize;
16
17pub enum ImageGenerationResponder {
18 Json(ImageGenerationResponse),
19 InternalError(Box<dyn Error>),
20 ValidationError(Box<dyn Error>),
21}
22
23trait ErrorToResponse: Serialize {
24 fn to_response(&self, code: StatusCode) -> axum::response::Response {
25 let mut r = Json(self).into_response();
26 *r.status_mut() = code;
27 r
28 }
29}
30
31#[derive(Serialize)]
32struct JsonError {
33 message: String,
34}
35
36impl JsonError {
37 fn new(message: String) -> Self {
38 Self { message }
39 }
40}
41impl ErrorToResponse for JsonError {}
42
43impl IntoResponse for ImageGenerationResponder {
44 fn into_response(self) -> axum::response::Response {
45 match self {
46 ImageGenerationResponder::Json(s) => Json(s).into_response(),
47 ImageGenerationResponder::InternalError(e) => {
48 JsonError::new(e.to_string()).to_response(http::StatusCode::INTERNAL_SERVER_ERROR)
49 }
50 ImageGenerationResponder::ValidationError(e) => {
51 JsonError::new(e.to_string()).to_response(http::StatusCode::UNPROCESSABLE_ENTITY)
52 }
53 }
54 }
55}
56
57fn parse_request(
58 oairequest: ImageGenerationRequest,
59 state: Arc<MistralRs>,
60 tx: Sender<Response>,
61) -> Result<Request> {
62 let repr = serde_json::to_string(&oairequest).expect("Serialization of request failed.");
63 MistralRs::maybe_log_request(state.clone(), repr);
64
65 Ok(Request::Normal(NormalRequest {
66 id: state.next_request_id(),
67 messages: RequestMessage::ImageGeneration {
68 prompt: oairequest.prompt,
69 format: oairequest.response_format,
70 generation_params: DiffusionGenerationParams {
71 height: oairequest.height,
72 width: oairequest.width,
73 },
74 },
75 sampling_params: SamplingParams::deterministic(),
76 response: tx,
77 return_logprobs: false,
78 is_streaming: false,
79 suffix: None,
80 constraint: Constraint::None,
81 adapters: None,
82 tool_choice: None,
83 tools: None,
84 logits_processors: None,
85 return_raw_logits: false,
86 web_search_options: None,
87 }))
88}
89
90#[utoipa::path(
91 post,
92 tag = "Mistral.rs",
93 path = "/v1/images/generations",
94 request_body = ImageGenerationRequest,
95 responses((status = 200, description = "Image generation"))
96)]
97
98pub async fn image_generation(
99 State(state): State<Arc<MistralRs>>,
100 Json(oairequest): Json<ImageGenerationRequest>,
101) -> ImageGenerationResponder {
102 let (tx, mut rx) = channel(10_000);
103
104 let request = match parse_request(oairequest, state.clone(), tx) {
105 Ok(x) => x,
106 Err(e) => {
107 let e = anyhow::Error::msg(e.to_string());
108 MistralRs::maybe_log_error(state, &*e);
109 return ImageGenerationResponder::InternalError(e.into());
110 }
111 };
112 let sender = state.get_sender().unwrap();
113
114 if let Err(e) = sender.send(request).await {
115 let e = anyhow::Error::msg(e.to_string());
116 MistralRs::maybe_log_error(state, &*e);
117 return ImageGenerationResponder::InternalError(e.into());
118 }
119
120 let response = match rx.recv().await {
121 Some(response) => response,
122 None => {
123 let e = anyhow::Error::msg("No response received from the model.");
124 MistralRs::maybe_log_error(state, &*e);
125 return ImageGenerationResponder::InternalError(e.into());
126 }
127 };
128
129 match response {
130 Response::InternalError(e) => {
131 MistralRs::maybe_log_error(state, &*e);
132 ImageGenerationResponder::InternalError(e)
133 }
134 Response::ValidationError(e) => ImageGenerationResponder::ValidationError(e),
135 Response::ImageGeneration(response) => {
136 MistralRs::maybe_log_response(state, &response);
137 ImageGenerationResponder::Json(response)
138 }
139 Response::CompletionModelError(m, _) => {
140 let e = anyhow::Error::msg(m.to_string());
141 MistralRs::maybe_log_error(state, &*e);
142 ImageGenerationResponder::InternalError(e.into())
143 }
144 Response::CompletionDone(_) => unreachable!(),
145 Response::CompletionChunk(_) => unreachable!(),
146 Response::Chunk(_) => unreachable!(),
147 Response::Done(_) => unreachable!(),
148 Response::ModelError(_, _) => unreachable!(),
149 Response::Raw { .. } => unreachable!(),
150 }
151}