1use std::{error::Error, sync::Arc};
4
5use anyhow::Result;
6use axum::{
7 extract::{Json, State},
8 http::{self},
9 response::IntoResponse,
10};
11use mistralrs_core::{
12 Constraint, DiffusionGenerationParams, ImageGenerationResponse, MistralRs, NormalRequest,
13 Request, RequestMessage, Response, SamplingParams,
14};
15use tokio::sync::mpsc::{Receiver, Sender};
16
17use crate::{
18 handler_core::{
19 base_process_non_streaming_response, create_response_channel, send_request,
20 ErrorToResponse, JsonError,
21 },
22 openai::ImageGenerationRequest,
23 types::{ExtractedMistralRsState, SharedMistralRsState},
24 util::{sanitize_error_message, validate_model_name},
25};
26
27pub enum ImageGenerationResponder {
29 Json(ImageGenerationResponse),
30 InternalError(Box<dyn Error>),
31 ValidationError(Box<dyn Error>),
32}
33
34impl IntoResponse for ImageGenerationResponder {
35 fn into_response(self) -> axum::response::Response {
37 match self {
38 ImageGenerationResponder::Json(s) => Json(s).into_response(),
39 ImageGenerationResponder::InternalError(e) => {
40 JsonError::new(sanitize_error_message(e.as_ref()))
41 .to_response(http::StatusCode::INTERNAL_SERVER_ERROR)
42 }
43 ImageGenerationResponder::ValidationError(e) => {
44 JsonError::new(sanitize_error_message(e.as_ref()))
45 .to_response(http::StatusCode::UNPROCESSABLE_ENTITY)
46 }
47 }
48 }
49}
50
51pub fn parse_request(
56 oairequest: ImageGenerationRequest,
57 state: Arc<MistralRs>,
58 tx: Sender<Response>,
59) -> Result<Request> {
60 let repr = serde_json::to_string(&oairequest).expect("Serialization of request failed.");
61 MistralRs::maybe_log_request(state.clone(), repr);
62
63 validate_model_name(&oairequest.model, state.clone())?;
65
66 Ok(Request::Normal(Box::new(NormalRequest {
67 id: state.next_request_id(),
68 messages: RequestMessage::ImageGeneration {
69 prompt: oairequest.prompt,
70 format: oairequest.response_format,
71 generation_params: DiffusionGenerationParams {
72 height: oairequest.height,
73 width: oairequest.width,
74 },
75 },
76 sampling_params: SamplingParams::deterministic(),
77 response: tx,
78 return_logprobs: false,
79 is_streaming: false,
80 suffix: None,
81 constraint: Constraint::None,
82 tool_choice: None,
83 tools: None,
84 logits_processors: None,
85 return_raw_logits: false,
86 web_search_options: None,
87 model_id: if oairequest.model == "default" {
88 None
89 } else {
90 Some(oairequest.model.clone())
91 },
92 truncate_sequence: false,
93 })))
94}
95
96#[utoipa::path(
98 post,
99 tag = "Mistral.rs",
100 path = "/v1/images/generations",
101 request_body = ImageGenerationRequest,
102 responses((status = 200, description = "Image generation"))
103)]
104pub async fn image_generation(
105 State(state): ExtractedMistralRsState,
106 Json(oairequest): Json<ImageGenerationRequest>,
107) -> ImageGenerationResponder {
108 let (tx, mut rx) = create_response_channel(None);
109
110 let request = match parse_request(oairequest, state.clone(), tx) {
111 Ok(x) => x,
112 Err(e) => return handle_error(state, e.into()),
113 };
114
115 if let Err(e) = send_request(&state, request).await {
116 return handle_error(state, e.into());
117 }
118
119 process_non_streaming_response(&mut rx, state).await
120}
121
122pub fn handle_error(
124 state: SharedMistralRsState,
125 e: Box<dyn std::error::Error + Send + Sync + 'static>,
126) -> ImageGenerationResponder {
127 let sanitized_msg = sanitize_error_message(&*e);
128 let e = anyhow::Error::msg(sanitized_msg);
129 MistralRs::maybe_log_error(state, &*e);
130 ImageGenerationResponder::InternalError(e.into())
131}
132
133pub async fn process_non_streaming_response(
135 rx: &mut Receiver<Response>,
136 state: SharedMistralRsState,
137) -> ImageGenerationResponder {
138 base_process_non_streaming_response(rx, state, match_responses, handle_error).await
139}
140
141pub fn match_responses(
143 state: SharedMistralRsState,
144 response: Response,
145) -> ImageGenerationResponder {
146 match response {
147 Response::InternalError(e) => {
148 MistralRs::maybe_log_error(state, &*e);
149 ImageGenerationResponder::InternalError(e)
150 }
151 Response::ValidationError(e) => ImageGenerationResponder::ValidationError(e),
152 Response::ImageGeneration(response) => {
153 MistralRs::maybe_log_response(state, &response);
154 ImageGenerationResponder::Json(response)
155 }
156 Response::CompletionModelError(m, _) => {
157 let e = anyhow::Error::msg(m.to_string());
158 MistralRs::maybe_log_error(state, &*e);
159 ImageGenerationResponder::InternalError(e.into())
160 }
161 Response::CompletionDone(_) => unreachable!(),
162 Response::CompletionChunk(_) => unreachable!(),
163 Response::Chunk(_) => unreachable!(),
164 Response::Done(_) => unreachable!(),
165 Response::ModelError(_, _) => unreachable!(),
166 Response::Speech { .. } => unreachable!(),
167 Response::Raw { .. } => unreachable!(),
168 Response::Embeddings { .. } => unreachable!(),
169 }
170}