mistralrs_server_core/
speech_generation.rs

//! ## Speech generation functionality and route handler.
2
3use std::{error::Error, sync::Arc};
4
5use anyhow::Result;
6use axum::{
7    body::Bytes,
8    extract::{Json, State},
9    http::{self, HeaderMap, HeaderValue, StatusCode},
10    response::IntoResponse,
11};
12use mistralrs_core::{
13    speech_utils::{self, Sample},
14    Constraint, MistralRs, NormalRequest, Request, RequestMessage, Response, SamplingParams,
15};
16use tokio::sync::mpsc::{Receiver, Sender};
17
18use crate::{
19    handler_core::{create_response_channel, send_request, ErrorToResponse, JsonError},
20    openai::{AudioResponseFormat, SpeechGenerationRequest},
21    types::SharedMistralRsState,
22    util::validate_model_name,
23};
24
25/// Represents different types of speech generation responses.
26pub enum SpeechGenerationResponder {
27    InternalError(Box<dyn Error>),
28    ValidationError(Box<dyn Error>),
29    RawResponse(axum::response::Response),
30}
31
32impl IntoResponse for SpeechGenerationResponder {
33    /// Converts the speech generation responder into an HTTP response.
34    fn into_response(self) -> axum::response::Response {
35        match self {
36            SpeechGenerationResponder::InternalError(e) => {
37                JsonError::new(e.to_string()).to_response(http::StatusCode::INTERNAL_SERVER_ERROR)
38            }
39            SpeechGenerationResponder::ValidationError(e) => {
40                JsonError::new(e.to_string()).to_response(http::StatusCode::UNPROCESSABLE_ENTITY)
41            }
42            SpeechGenerationResponder::RawResponse(resp) => resp,
43        }
44    }
45}
46
47/// Parses and validates a speech generation request.
48///
49/// This function transforms a speech generation request into the
50/// request format used by mistral.rs.
51pub fn parse_request(
52    oairequest: SpeechGenerationRequest,
53    state: Arc<MistralRs>,
54    tx: Sender<Response>,
55) -> Result<(Request, AudioResponseFormat)> {
56    let repr = serde_json::to_string(&oairequest).expect("Serialization of request failed.");
57    MistralRs::maybe_log_request(state.clone(), repr);
58
59    // Validate that the requested model matches the loaded model
60    validate_model_name(&oairequest.model, state.clone())?;
61
62    let request = Request::Normal(Box::new(NormalRequest {
63        id: state.next_request_id(),
64        messages: RequestMessage::SpeechGeneration {
65            prompt: oairequest.input,
66        },
67        sampling_params: SamplingParams::deterministic(),
68        response: tx,
69        return_logprobs: false,
70        is_streaming: false,
71        suffix: None,
72        constraint: Constraint::None,
73        tool_choice: None,
74        tools: None,
75        logits_processors: None,
76        return_raw_logits: false,
77        web_search_options: None,
78        model_id: if oairequest.model == "default" {
79            None
80        } else {
81            Some(oairequest.model.clone())
82        },
83    }));
84
85    Ok((request, oairequest.response_format))
86}
87
88/// Speech generation endpoint handler.
89#[utoipa::path(
90    post,
91    tag = "Mistral.rs",
92    path = "/v1/audio/speech",
93    request_body = SpeechGenerationRequest,
94    responses((status = 200, description = "Speech generation"))
95)]
96pub async fn speech_generation(
97    State(state): State<Arc<MistralRs>>,
98    Json(oairequest): Json<SpeechGenerationRequest>,
99) -> SpeechGenerationResponder {
100    let (tx, mut rx) = create_response_channel(None);
101
102    let (request, response_format) = match parse_request(oairequest, state.clone(), tx) {
103        Ok(x) => x,
104        Err(e) => return handle_error(state, e.into()),
105    };
106
107    // Validate response format here
108    if !matches!(
109        response_format,
110        AudioResponseFormat::Wav | AudioResponseFormat::Pcm
111    ) {
112        return SpeechGenerationResponder::ValidationError(Box::new(JsonError::new(
113            "Only support wav/pcm response format.".to_string(),
114        )));
115    }
116
117    if let Err(e) = send_request(&state, request).await {
118        return handle_error(state, e.into());
119    }
120
121    process_non_streaming_response(&mut rx, state, response_format).await
122}
123
124/// Helper function to handle speech generation errors and logging them.
125pub fn handle_error(
126    state: SharedMistralRsState,
127    e: Box<dyn std::error::Error + Send + Sync + 'static>,
128) -> SpeechGenerationResponder {
129    let e = anyhow::Error::msg(e.to_string());
130    MistralRs::maybe_log_error(state, &*e);
131    SpeechGenerationResponder::InternalError(e.into())
132}
133
134/// Process non-streaming speech generation responses.
135pub async fn process_non_streaming_response(
136    rx: &mut Receiver<Response>,
137    state: SharedMistralRsState,
138    response_format: AudioResponseFormat,
139) -> SpeechGenerationResponder {
140    let response = match rx.recv().await {
141        Some(response) => response,
142        None => {
143            let e = anyhow::Error::msg("No response received from the model.");
144            return handle_error(state, e.into());
145        }
146    };
147
148    match_responses(state, response, response_format)
149}
150
151/// Matches and processes different types of model responses into appropriate speech generation responses.
152pub fn match_responses(
153    state: SharedMistralRsState,
154    response: Response,
155    response_format: AudioResponseFormat,
156) -> SpeechGenerationResponder {
157    match response {
158        Response::InternalError(e) => {
159            MistralRs::maybe_log_error(state, &*e);
160            SpeechGenerationResponder::InternalError(e)
161        }
162        Response::ValidationError(e) => SpeechGenerationResponder::ValidationError(e),
163        Response::ImageGeneration(_) => unreachable!(),
164        Response::CompletionModelError(m, _) => {
165            let e = anyhow::Error::msg(m.to_string());
166            MistralRs::maybe_log_error(state, &*e);
167            SpeechGenerationResponder::InternalError(e.into())
168        }
169        Response::CompletionDone(_) => unreachable!(),
170        Response::CompletionChunk(_) => unreachable!(),
171        Response::Chunk(_) => unreachable!(),
172        Response::Done(_) => unreachable!(),
173        Response::ModelError(_, _) => unreachable!(),
174        Response::Speech {
175            pcm,
176            rate,
177            channels,
178        } => {
179            let pcm_endianness = "s16le";
180
181            let content_type = response_format.audio_content_type(rate, channels, pcm_endianness);
182            let mut headers = HeaderMap::new();
183            headers.insert(
184                http::header::CONTENT_TYPE,
185                HeaderValue::from_str(&content_type).unwrap(),
186            );
187
188            let encoded = match response_format {
189                AudioResponseFormat::Pcm => {
190                    let samples: &[f32] = &pcm;
191                    let mut buf = Vec::with_capacity(samples.len() * std::mem::size_of::<i64>());
192                    for &sample in samples {
193                        buf.extend_from_slice(&sample.to_i16().to_le_bytes());
194                    }
195                    buf
196                }
197                AudioResponseFormat::Wav => {
198                    // Write WAV data into an in-memory buffer
199                    let mut buf = Vec::new();
200                    speech_utils::write_pcm_as_wav(&mut buf, &pcm, rate as u32, channels as u16)
201                        .unwrap();
202                    buf
203                }
204                _ => unreachable!("Should be validated above."),
205            };
206
207            let bytes = Bytes::from(encoded);
208
209            SpeechGenerationResponder::RawResponse((StatusCode::OK, headers, bytes).into_response())
210        }
211        Response::Raw { .. } => unreachable!(),
212    }
213}