mistralrs_server_core/
speech_generation.rs

1//! ## Speech generation functionality and route handler.
2
3use std::{error::Error, sync::Arc};
4
5use anyhow::Result;
6use axum::{
7    body::Bytes,
8    extract::{Json, State},
9    http::{self, HeaderMap, HeaderValue, StatusCode},
10    response::IntoResponse,
11};
12use mistralrs_core::{
13    speech_utils::{self, Sample},
14    Constraint, MistralRs, NormalRequest, Request, RequestMessage, Response, SamplingParams,
15};
16use tokio::sync::mpsc::{Receiver, Sender};
17
18use crate::{
19    handler_core::{create_response_channel, send_request, ErrorToResponse, JsonError},
20    openai::{AudioResponseFormat, SpeechGenerationRequest},
21    types::SharedMistralRsState,
22    util::{sanitize_error_message, validate_model_name},
23};
24
/// Represents different types of speech generation responses.
///
/// The [`IntoResponse`] implementation in this file decides how each variant
/// is rendered as an HTTP response.
pub enum SpeechGenerationResponder {
    /// A failure rendered as a JSON error with HTTP status 500
    /// (Internal Server Error); the message is sanitized before being sent.
    InternalError(Box<dyn Error>),
    /// A failure rendered as a JSON error with HTTP status 422
    /// (Unprocessable Entity); the message is sanitized before being sent.
    ValidationError(Box<dyn Error>),
    /// An already-built HTTP response that is returned unchanged.
    RawResponse(axum::response::Response),
}
31
32impl IntoResponse for SpeechGenerationResponder {
33    /// Converts the speech generation responder into an HTTP response.
34    fn into_response(self) -> axum::response::Response {
35        match self {
36            SpeechGenerationResponder::InternalError(e) => {
37                JsonError::new(sanitize_error_message(e.as_ref()))
38                    .to_response(http::StatusCode::INTERNAL_SERVER_ERROR)
39            }
40            SpeechGenerationResponder::ValidationError(e) => {
41                JsonError::new(sanitize_error_message(e.as_ref()))
42                    .to_response(http::StatusCode::UNPROCESSABLE_ENTITY)
43            }
44            SpeechGenerationResponder::RawResponse(resp) => resp,
45        }
46    }
47}
48
49/// Parses and validates a speech generation request.
50///
51/// This function transforms a speech generation request into the
52/// request format used by mistral.rs.
53pub fn parse_request(
54    oairequest: SpeechGenerationRequest,
55    state: Arc<MistralRs>,
56    tx: Sender<Response>,
57) -> Result<(Request, AudioResponseFormat)> {
58    let repr = serde_json::to_string(&oairequest).expect("Serialization of request failed.");
59    MistralRs::maybe_log_request(state.clone(), repr);
60
61    // Validate that the requested model matches the loaded model
62    validate_model_name(&oairequest.model, state.clone())?;
63
64    let request = Request::Normal(Box::new(NormalRequest {
65        id: state.next_request_id(),
66        messages: RequestMessage::SpeechGeneration {
67            prompt: oairequest.input,
68        },
69        sampling_params: SamplingParams::deterministic(),
70        response: tx,
71        return_logprobs: false,
72        is_streaming: false,
73        suffix: None,
74        constraint: Constraint::None,
75        tool_choice: None,
76        tools: None,
77        logits_processors: None,
78        return_raw_logits: false,
79        web_search_options: None,
80        model_id: if oairequest.model == "default" {
81            None
82        } else {
83            Some(oairequest.model.clone())
84        },
85        truncate_sequence: false,
86    }));
87
88    Ok((request, oairequest.response_format))
89}
90
91/// Speech generation endpoint handler.
92#[utoipa::path(
93    post,
94    tag = "Mistral.rs",
95    path = "/v1/audio/speech",
96    request_body = SpeechGenerationRequest,
97    responses((status = 200, description = "Speech generation"))
98)]
99pub async fn speech_generation(
100    State(state): State<Arc<MistralRs>>,
101    Json(oairequest): Json<SpeechGenerationRequest>,
102) -> SpeechGenerationResponder {
103    let (tx, mut rx) = create_response_channel(None);
104
105    let (request, response_format) = match parse_request(oairequest, state.clone(), tx) {
106        Ok(x) => x,
107        Err(e) => return handle_error(state, e.into()),
108    };
109
110    // Validate response format here
111    if !matches!(
112        response_format,
113        AudioResponseFormat::Wav | AudioResponseFormat::Pcm
114    ) {
115        return SpeechGenerationResponder::ValidationError(Box::new(JsonError::new(
116            "Only support wav/pcm response format.".to_string(),
117        )));
118    }
119
120    if let Err(e) = send_request(&state, request).await {
121        return handle_error(state, e.into());
122    }
123
124    process_non_streaming_response(&mut rx, state, response_format).await
125}
126
127/// Helper function to handle speech generation errors and logging them.
128pub fn handle_error(
129    state: SharedMistralRsState,
130    e: Box<dyn std::error::Error + Send + Sync + 'static>,
131) -> SpeechGenerationResponder {
132    let sanitized_msg = sanitize_error_message(&*e);
133    let e = anyhow::Error::msg(sanitized_msg);
134    MistralRs::maybe_log_error(state, &*e);
135    SpeechGenerationResponder::InternalError(e.into())
136}
137
138/// Process non-streaming speech generation responses.
139pub async fn process_non_streaming_response(
140    rx: &mut Receiver<Response>,
141    state: SharedMistralRsState,
142    response_format: AudioResponseFormat,
143) -> SpeechGenerationResponder {
144    let response = match rx.recv().await {
145        Some(response) => response,
146        None => {
147            let e = anyhow::Error::msg("No response received from the model.");
148            return handle_error(state, e.into());
149        }
150    };
151
152    match_responses(state, response, response_format)
153}
154
155/// Matches and processes different types of model responses into appropriate speech generation responses.
156pub fn match_responses(
157    state: SharedMistralRsState,
158    response: Response,
159    response_format: AudioResponseFormat,
160) -> SpeechGenerationResponder {
161    match response {
162        Response::InternalError(e) => {
163            MistralRs::maybe_log_error(state, &*e);
164            SpeechGenerationResponder::InternalError(e)
165        }
166        Response::ValidationError(e) => SpeechGenerationResponder::ValidationError(e),
167        Response::ImageGeneration(_) => unreachable!(),
168        Response::CompletionModelError(m, _) => {
169            let e = anyhow::Error::msg(m.to_string());
170            MistralRs::maybe_log_error(state, &*e);
171            SpeechGenerationResponder::InternalError(e.into())
172        }
173        Response::CompletionDone(_) => unreachable!(),
174        Response::CompletionChunk(_) => unreachable!(),
175        Response::Chunk(_) => unreachable!(),
176        Response::Done(_) => unreachable!(),
177        Response::ModelError(_, _) => unreachable!(),
178        Response::Speech {
179            pcm,
180            rate,
181            channels,
182        } => {
183            let pcm_endianness = "s16le";
184
185            let content_type = response_format.audio_content_type(rate, channels, pcm_endianness);
186            let mut headers = HeaderMap::new();
187            headers.insert(
188                http::header::CONTENT_TYPE,
189                HeaderValue::from_str(&content_type).unwrap(),
190            );
191
192            let encoded = match response_format {
193                AudioResponseFormat::Pcm => {
194                    let samples: &[f32] = &pcm;
195                    let mut buf = Vec::with_capacity(samples.len() * std::mem::size_of::<i64>());
196                    for &sample in samples {
197                        buf.extend_from_slice(&sample.to_i16().to_le_bytes());
198                    }
199                    buf
200                }
201                AudioResponseFormat::Wav => {
202                    // Write WAV data into an in-memory buffer
203                    let mut buf = Vec::new();
204                    speech_utils::write_pcm_as_wav(&mut buf, &pcm, rate as u32, channels as u16)
205                        .unwrap();
206                    buf
207                }
208                _ => unreachable!("Should be validated above."),
209            };
210
211            let bytes = Bytes::from(encoded);
212
213            SpeechGenerationResponder::RawResponse((StatusCode::OK, headers, bytes).into_response())
214        }
215        Response::Raw { .. } => unreachable!(),
216        Response::Embeddings { .. } => unreachable!(),
217    }
218}