1use std::{error::Error, sync::Arc};
4
5use anyhow::Result;
6use axum::{
7 body::Bytes,
8 extract::{Json, State},
9 http::{self, HeaderMap, HeaderValue, StatusCode},
10 response::IntoResponse,
11};
12use mistralrs_core::{
13 speech_utils::{self, Sample},
14 Constraint, MistralRs, NormalRequest, Request, RequestMessage, Response, SamplingParams,
15};
16use tokio::sync::mpsc::{Receiver, Sender};
17
18use crate::{
19 handler_core::{create_response_channel, send_request, ErrorToResponse, JsonError},
20 openai::{AudioResponseFormat, SpeechGenerationRequest},
21 types::SharedMistralRsState,
22 util::validate_model_name,
23};
24
25pub enum SpeechGenerationResponder {
27 InternalError(Box<dyn Error>),
28 ValidationError(Box<dyn Error>),
29 RawResponse(axum::response::Response),
30}
31
32impl IntoResponse for SpeechGenerationResponder {
33 fn into_response(self) -> axum::response::Response {
35 match self {
36 SpeechGenerationResponder::InternalError(e) => {
37 JsonError::new(e.to_string()).to_response(http::StatusCode::INTERNAL_SERVER_ERROR)
38 }
39 SpeechGenerationResponder::ValidationError(e) => {
40 JsonError::new(e.to_string()).to_response(http::StatusCode::UNPROCESSABLE_ENTITY)
41 }
42 SpeechGenerationResponder::RawResponse(resp) => resp,
43 }
44 }
45}
46
47pub fn parse_request(
52 oairequest: SpeechGenerationRequest,
53 state: Arc<MistralRs>,
54 tx: Sender<Response>,
55) -> Result<(Request, AudioResponseFormat)> {
56 let repr = serde_json::to_string(&oairequest).expect("Serialization of request failed.");
57 MistralRs::maybe_log_request(state.clone(), repr);
58
59 validate_model_name(&oairequest.model, state.clone())?;
61
62 let request = Request::Normal(Box::new(NormalRequest {
63 id: state.next_request_id(),
64 messages: RequestMessage::SpeechGeneration {
65 prompt: oairequest.input,
66 },
67 sampling_params: SamplingParams::deterministic(),
68 response: tx,
69 return_logprobs: false,
70 is_streaming: false,
71 suffix: None,
72 constraint: Constraint::None,
73 tool_choice: None,
74 tools: None,
75 logits_processors: None,
76 return_raw_logits: false,
77 web_search_options: None,
78 model_id: if oairequest.model == "default" {
79 None
80 } else {
81 Some(oairequest.model.clone())
82 },
83 }));
84
85 Ok((request, oairequest.response_format))
86}
87
88#[utoipa::path(
90 post,
91 tag = "Mistral.rs",
92 path = "/v1/audio/speech",
93 request_body = SpeechGenerationRequest,
94 responses((status = 200, description = "Speech generation"))
95)]
96pub async fn speech_generation(
97 State(state): State<Arc<MistralRs>>,
98 Json(oairequest): Json<SpeechGenerationRequest>,
99) -> SpeechGenerationResponder {
100 let (tx, mut rx) = create_response_channel(None);
101
102 let (request, response_format) = match parse_request(oairequest, state.clone(), tx) {
103 Ok(x) => x,
104 Err(e) => return handle_error(state, e.into()),
105 };
106
107 if !matches!(
109 response_format,
110 AudioResponseFormat::Wav | AudioResponseFormat::Pcm
111 ) {
112 return SpeechGenerationResponder::ValidationError(Box::new(JsonError::new(
113 "Only support wav/pcm response format.".to_string(),
114 )));
115 }
116
117 if let Err(e) = send_request(&state, request).await {
118 return handle_error(state, e.into());
119 }
120
121 process_non_streaming_response(&mut rx, state, response_format).await
122}
123
124pub fn handle_error(
126 state: SharedMistralRsState,
127 e: Box<dyn std::error::Error + Send + Sync + 'static>,
128) -> SpeechGenerationResponder {
129 let e = anyhow::Error::msg(e.to_string());
130 MistralRs::maybe_log_error(state, &*e);
131 SpeechGenerationResponder::InternalError(e.into())
132}
133
134pub async fn process_non_streaming_response(
136 rx: &mut Receiver<Response>,
137 state: SharedMistralRsState,
138 response_format: AudioResponseFormat,
139) -> SpeechGenerationResponder {
140 let response = match rx.recv().await {
141 Some(response) => response,
142 None => {
143 let e = anyhow::Error::msg("No response received from the model.");
144 return handle_error(state, e.into());
145 }
146 };
147
148 match_responses(state, response, response_format)
149}
150
151pub fn match_responses(
153 state: SharedMistralRsState,
154 response: Response,
155 response_format: AudioResponseFormat,
156) -> SpeechGenerationResponder {
157 match response {
158 Response::InternalError(e) => {
159 MistralRs::maybe_log_error(state, &*e);
160 SpeechGenerationResponder::InternalError(e)
161 }
162 Response::ValidationError(e) => SpeechGenerationResponder::ValidationError(e),
163 Response::ImageGeneration(_) => unreachable!(),
164 Response::CompletionModelError(m, _) => {
165 let e = anyhow::Error::msg(m.to_string());
166 MistralRs::maybe_log_error(state, &*e);
167 SpeechGenerationResponder::InternalError(e.into())
168 }
169 Response::CompletionDone(_) => unreachable!(),
170 Response::CompletionChunk(_) => unreachable!(),
171 Response::Chunk(_) => unreachable!(),
172 Response::Done(_) => unreachable!(),
173 Response::ModelError(_, _) => unreachable!(),
174 Response::Speech {
175 pcm,
176 rate,
177 channels,
178 } => {
179 let pcm_endianness = "s16le";
180
181 let content_type = response_format.audio_content_type(rate, channels, pcm_endianness);
182 let mut headers = HeaderMap::new();
183 headers.insert(
184 http::header::CONTENT_TYPE,
185 HeaderValue::from_str(&content_type).unwrap(),
186 );
187
188 let encoded = match response_format {
189 AudioResponseFormat::Pcm => {
190 let samples: &[f32] = &pcm;
191 let mut buf = Vec::with_capacity(samples.len() * std::mem::size_of::<i64>());
192 for &sample in samples {
193 buf.extend_from_slice(&sample.to_i16().to_le_bytes());
194 }
195 buf
196 }
197 AudioResponseFormat::Wav => {
198 let mut buf = Vec::new();
200 speech_utils::write_pcm_as_wav(&mut buf, &pcm, rate as u32, channels as u16)
201 .unwrap();
202 buf
203 }
204 _ => unreachable!("Should be validated above."),
205 };
206
207 let bytes = Bytes::from(encoded);
208
209 SpeechGenerationResponder::RawResponse((StatusCode::OK, headers, bytes).into_response())
210 }
211 Response::Raw { .. } => unreachable!(),
212 }
213}