mistralrs_core/speech_models/
mod.rs

1mod bs1770;
2mod dia;
3pub mod utils;
4
5use std::{str::FromStr, sync::Arc};
6
7pub use dia::{DiaConfig, DiaPipeline};
8use serde::Deserialize;
9
10#[derive(Clone, Copy, Debug, Deserialize, PartialEq)]
11pub enum SpeechLoaderType {
12    #[serde(rename = "dia")]
13    Dia,
14}
15
16impl FromStr for SpeechLoaderType {
17    type Err = String;
18    fn from_str(s: &str) -> Result<Self, Self::Err> {
19        match s {
20            "dia" => Ok(Self::Dia),
21            a => Err(format!(
22                "Unknown architecture `{a}`. Possible architectures: `dia`."
23            )),
24        }
25    }
26}
27
/// Per-architecture settings controlling speech generation.
///
/// Each variant carries the knobs for one speech model architecture.
#[derive(Debug, Clone, Copy)]
pub enum SpeechGenerationConfig {
    /// Settings for the Dia architecture.
    Dia {
        /// Optional token limit; `None` leaves it unspecified here.
        /// NOTE(review): how `None` is resolved is decided by the consumer — confirm.
        max_tokens: Option<usize>,
        /// Presumably the classifier-free guidance scale — confirm against the pipeline.
        cfg_scale: f32,
        /// Presumably the sampling temperature — confirm against the pipeline.
        temperature: f32,
        /// Presumably the nucleus (top-p) sampling threshold — confirm against the pipeline.
        top_p: f32,
        /// Presumably the optional top-k sampling cutoff — confirm against the pipeline.
        top_k: Option<usize>,
    },
}
38
39impl SpeechGenerationConfig {
40    pub fn default(ty: SpeechLoaderType) -> Self {
41        match ty {
42            SpeechLoaderType::Dia => Self::Dia {
43                max_tokens: None,
44                cfg_scale: 3.,
45                temperature: 1.3,
46                top_p: 0.95,
47                top_k: Some(35),
48            },
49        }
50    }
51}
52
/// Result of a speech generation run.
///
/// Cloning is cheap with respect to the sample data: `pcm` sits behind an
/// [`Arc`], so clones share the same buffer instead of copying it.
#[derive(Debug, Clone)]
pub struct SpeechGenerationOutput {
    /// Shared buffer of `f32` samples.
    /// NOTE(review): presumably PCM audio; the multi-channel layout is not shown here — confirm.
    pub pcm: Arc<Vec<f32>>,
    /// Presumably the sample rate in Hz — confirm against the producer.
    pub rate: usize,
    /// Presumably the number of audio channels in `pcm` — confirm against the producer.
    pub channels: usize,
}
58}