1use std::{error::Error, sync::Arc};
4
5use anyhow::Result;
6use axum::{
7 body::Bytes,
8 extract::{Json, State},
9 http::{self, HeaderMap, HeaderValue, StatusCode},
10 response::IntoResponse,
11};
12use mistralrs_core::{
13 speech_utils::{self, Sample},
14 Constraint, MistralRs, NormalRequest, Request, RequestMessage, Response, SamplingParams,
15};
16use tokio::sync::mpsc::{Receiver, Sender};
17
18use crate::{
19 handler_core::{create_response_channel, send_request, ErrorToResponse, JsonError},
20 openai::{AudioResponseFormat, SpeechGenerationRequest},
21 types::SharedMistralRsState,
22 util::{sanitize_error_message, validate_model_name},
23};
24
25pub enum SpeechGenerationResponder {
27 InternalError(Box<dyn Error>),
28 ValidationError(Box<dyn Error>),
29 RawResponse(axum::response::Response),
30}
31
32impl IntoResponse for SpeechGenerationResponder {
33 fn into_response(self) -> axum::response::Response {
35 match self {
36 SpeechGenerationResponder::InternalError(e) => {
37 JsonError::new(sanitize_error_message(e.as_ref()))
38 .to_response(http::StatusCode::INTERNAL_SERVER_ERROR)
39 }
40 SpeechGenerationResponder::ValidationError(e) => {
41 JsonError::new(sanitize_error_message(e.as_ref()))
42 .to_response(http::StatusCode::UNPROCESSABLE_ENTITY)
43 }
44 SpeechGenerationResponder::RawResponse(resp) => resp,
45 }
46 }
47}
48
49pub fn parse_request(
54 oairequest: SpeechGenerationRequest,
55 state: Arc<MistralRs>,
56 tx: Sender<Response>,
57) -> Result<(Request, AudioResponseFormat)> {
58 let repr = serde_json::to_string(&oairequest).expect("Serialization of request failed.");
59 MistralRs::maybe_log_request(state.clone(), repr);
60
61 validate_model_name(&oairequest.model, state.clone())?;
63
64 let request = Request::Normal(Box::new(NormalRequest {
65 id: state.next_request_id(),
66 messages: RequestMessage::SpeechGeneration {
67 prompt: oairequest.input,
68 },
69 sampling_params: SamplingParams::deterministic(),
70 response: tx,
71 return_logprobs: false,
72 is_streaming: false,
73 suffix: None,
74 constraint: Constraint::None,
75 tool_choice: None,
76 tools: None,
77 logits_processors: None,
78 return_raw_logits: false,
79 web_search_options: None,
80 model_id: if oairequest.model == "default" {
81 None
82 } else {
83 Some(oairequest.model.clone())
84 },
85 truncate_sequence: false,
86 }));
87
88 Ok((request, oairequest.response_format))
89}
90
91#[utoipa::path(
93 post,
94 tag = "Mistral.rs",
95 path = "/v1/audio/speech",
96 request_body = SpeechGenerationRequest,
97 responses((status = 200, description = "Speech generation"))
98)]
99pub async fn speech_generation(
100 State(state): State<Arc<MistralRs>>,
101 Json(oairequest): Json<SpeechGenerationRequest>,
102) -> SpeechGenerationResponder {
103 let (tx, mut rx) = create_response_channel(None);
104
105 let (request, response_format) = match parse_request(oairequest, state.clone(), tx) {
106 Ok(x) => x,
107 Err(e) => return handle_error(state, e.into()),
108 };
109
110 if !matches!(
112 response_format,
113 AudioResponseFormat::Wav | AudioResponseFormat::Pcm
114 ) {
115 return SpeechGenerationResponder::ValidationError(Box::new(JsonError::new(
116 "Only support wav/pcm response format.".to_string(),
117 )));
118 }
119
120 if let Err(e) = send_request(&state, request).await {
121 return handle_error(state, e.into());
122 }
123
124 process_non_streaming_response(&mut rx, state, response_format).await
125}
126
127pub fn handle_error(
129 state: SharedMistralRsState,
130 e: Box<dyn std::error::Error + Send + Sync + 'static>,
131) -> SpeechGenerationResponder {
132 let sanitized_msg = sanitize_error_message(&*e);
133 let e = anyhow::Error::msg(sanitized_msg);
134 MistralRs::maybe_log_error(state, &*e);
135 SpeechGenerationResponder::InternalError(e.into())
136}
137
138pub async fn process_non_streaming_response(
140 rx: &mut Receiver<Response>,
141 state: SharedMistralRsState,
142 response_format: AudioResponseFormat,
143) -> SpeechGenerationResponder {
144 let response = match rx.recv().await {
145 Some(response) => response,
146 None => {
147 let e = anyhow::Error::msg("No response received from the model.");
148 return handle_error(state, e.into());
149 }
150 };
151
152 match_responses(state, response, response_format)
153}
154
155pub fn match_responses(
157 state: SharedMistralRsState,
158 response: Response,
159 response_format: AudioResponseFormat,
160) -> SpeechGenerationResponder {
161 match response {
162 Response::InternalError(e) => {
163 MistralRs::maybe_log_error(state, &*e);
164 SpeechGenerationResponder::InternalError(e)
165 }
166 Response::ValidationError(e) => SpeechGenerationResponder::ValidationError(e),
167 Response::ImageGeneration(_) => unreachable!(),
168 Response::CompletionModelError(m, _) => {
169 let e = anyhow::Error::msg(m.to_string());
170 MistralRs::maybe_log_error(state, &*e);
171 SpeechGenerationResponder::InternalError(e.into())
172 }
173 Response::CompletionDone(_) => unreachable!(),
174 Response::CompletionChunk(_) => unreachable!(),
175 Response::Chunk(_) => unreachable!(),
176 Response::Done(_) => unreachable!(),
177 Response::ModelError(_, _) => unreachable!(),
178 Response::Speech {
179 pcm,
180 rate,
181 channels,
182 } => {
183 let pcm_endianness = "s16le";
184
185 let content_type = response_format.audio_content_type(rate, channels, pcm_endianness);
186 let mut headers = HeaderMap::new();
187 headers.insert(
188 http::header::CONTENT_TYPE,
189 HeaderValue::from_str(&content_type).unwrap(),
190 );
191
192 let encoded = match response_format {
193 AudioResponseFormat::Pcm => {
194 let samples: &[f32] = &pcm;
195 let mut buf = Vec::with_capacity(samples.len() * std::mem::size_of::<i64>());
196 for &sample in samples {
197 buf.extend_from_slice(&sample.to_i16().to_le_bytes());
198 }
199 buf
200 }
201 AudioResponseFormat::Wav => {
202 let mut buf = Vec::new();
204 speech_utils::write_pcm_as_wav(&mut buf, &pcm, rate as u32, channels as u16)
205 .unwrap();
206 buf
207 }
208 _ => unreachable!("Should be validated above."),
209 };
210
211 let bytes = Bytes::from(encoded);
212
213 SpeechGenerationResponder::RawResponse((StatusCode::OK, headers, bytes).into_response())
214 }
215 Response::Raw { .. } => unreachable!(),
216 Response::Embeddings { .. } => unreachable!(),
217 }
218}