mistralrs_core/speech_models/
mod.rs1mod bs1770;
2mod dia;
3pub mod utils;
4
5use std::{str::FromStr, sync::Arc};
6
7pub use dia::{DiaConfig, DiaPipeline};
8use serde::Deserialize;
9
10#[derive(Clone, Copy, Debug, Deserialize, PartialEq)]
11pub enum SpeechLoaderType {
12 #[serde(rename = "dia")]
13 Dia,
14}
15
16impl FromStr for SpeechLoaderType {
17 type Err = String;
18 fn from_str(s: &str) -> Result<Self, Self::Err> {
19 match s {
20 "dia" => Ok(Self::Dia),
21 a => Err(format!(
22 "Unknown architecture `{a}`. Possible architectures: `dia`."
23 )),
24 }
25 }
26}
27
/// Per-architecture sampling/generation settings for speech models.
///
/// Each variant carries the knobs relevant to one architecture. The type is
/// `Copy`, so configs can be passed around by value freely.
#[derive(Debug, Clone, Copy)]
pub enum SpeechGenerationConfig {
    /// Settings for the Dia architecture.
    Dia {
        /// Optional cap on generated tokens; how `None` is interpreted is
        /// decided by the consumer of this config (not visible here).
        max_tokens: Option<usize>,
        /// NOTE(review): presumably the classifier-free guidance scale — confirm.
        cfg_scale: f32,
        /// Sampling temperature.
        temperature: f32,
        /// Top-p (nucleus) sampling threshold, judging by the name — confirm.
        top_p: f32,
        /// Optional top-k sampling cutoff.
        top_k: Option<usize>,
    },
}
38
39impl SpeechGenerationConfig {
40 pub fn default(ty: SpeechLoaderType) -> Self {
41 match ty {
42 SpeechLoaderType::Dia => Self::Dia {
43 max_tokens: None,
44 cfg_scale: 3.,
45 temperature: 1.3,
46 top_p: 0.95,
47 top_k: Some(35),
48 },
49 }
50 }
51}
52
/// Output of a speech generation run.
///
/// The sample buffer is held behind an `Arc`, so cloning this struct shares
/// the samples instead of copying them.
#[derive(Debug, Clone)]
pub struct SpeechGenerationOutput {
    /// Generated audio samples as `f32`.
    /// NOTE(review): presumably interleaved when `channels > 1` — confirm
    /// against the producer.
    pub pcm: Arc<Vec<f32>>,
    /// Sample rate of `pcm`; presumably in Hz — confirm.
    pub rate: usize,
    /// Presumably the number of audio channels in `pcm` — confirm.
    pub channels: usize,
}