mistralrs_core/speech_models/utils.rs

1#![allow(clippy::cast_possible_truncation, clippy::cast_precision_loss)]
2
3use std::io::Write;
4
5use candle_core::{Result, Tensor};
6
7use super::bs1770;
8
9pub(crate) fn normalize_loudness(
10    wav: &Tensor,
11    sample_rate: u32,
12    loudness_compressor: bool,
13) -> Result<Tensor> {
14    let energy = wav.sqr()?.mean_all()?.sqrt()?.to_vec0::<f32>()?;
15    if energy < 2e-3 {
16        return Ok(wav.clone());
17    }
18    let wav_array = wav.to_vec1::<f32>()?;
19    let mut meter = bs1770::ChannelLoudnessMeter::new(sample_rate);
20    meter.push(wav_array.into_iter());
21    let power = meter.as_100ms_windows();
22    let loudness = match bs1770::gated_mean(power) {
23        None => return Ok(wav.clone()),
24        Some(gp) => gp.loudness_lkfs() as f64,
25    };
26    let delta_loudness = -14. - loudness;
27    let gain = 10f64.powf(delta_loudness / 20.);
28    let wav = (wav * gain)?;
29    if loudness_compressor {
30        wav.tanh()
31    } else {
32        Ok(wav)
33    }
34}
35
/// An audio sample that can be encoded as signed 16-bit PCM.
pub trait Sample {
    /// Returns this sample as a 16-bit PCM value.
    fn to_i16(&self) -> i16;
}

/// Treated as a normalized amplitude: clipped to `[-1.0, 1.0]`, scaled by
/// `i16::MAX` (32767) and truncated toward zero by the `as` cast (NaN maps to 0).
impl Sample for f32 {
    fn to_i16(&self) -> i16 {
        let clipped = f32::clamp(*self, -1.0, 1.0);
        (clipped * f32::from(i16::MAX)) as i16
    }
}

/// Same mapping as for `f32`: clip to `[-1.0, 1.0]`, scale by `i16::MAX`,
/// truncate toward zero (NaN maps to 0).
impl Sample for f64 {
    fn to_i16(&self) -> i16 {
        let clipped = f64::clamp(*self, -1.0, 1.0);
        (clipped * f64::from(i16::MAX)) as i16
    }
}

/// Already 16-bit PCM: passed through unchanged.
impl Sample for i16 {
    fn to_i16(&self) -> i16 {
        let &pcm = self;
        pcm
    }
}
57
58pub fn write_pcm_as_wav<W: Write, S: Sample>(
59    w: &mut W,
60    samples: &[S],
61    sample_rate: u32,
62    n_channels: u16,
63) -> std::io::Result<()> {
64    let len = 12u32; // header
65    let len = len + 24u32; // fmt
66    let len = len + samples.len() as u32 * 2 + 8; // data
67    let bytes_per_second = sample_rate * 2 * n_channels as u32;
68    w.write_all(b"RIFF")?;
69    w.write_all(&(len - 8).to_le_bytes())?; // total length minus 8 bytes
70    w.write_all(b"WAVE")?;
71
72    // Format block
73    w.write_all(b"fmt ")?;
74    w.write_all(&16u32.to_le_bytes())?; // block len minus 8 bytes
75    w.write_all(&1u16.to_le_bytes())?; // PCM
76    w.write_all(&n_channels.to_le_bytes())?; // one channel
77    w.write_all(&sample_rate.to_le_bytes())?;
78    w.write_all(&bytes_per_second.to_le_bytes())?;
79    let block_align = 2 * n_channels;
80    w.write_all(&block_align.to_le_bytes())?; // 2 bytes of data per sample
81    w.write_all(&16u16.to_le_bytes())?; // bits per sample
82
83    // Data block
84    w.write_all(b"data")?;
85    w.write_all(&(samples.len() as u32 * 2).to_le_bytes())?;
86    for sample in samples.iter() {
87        w.write_all(&sample.to_i16().to_le_bytes())?
88    }
89    Ok(())
90}