mistralrs_audio/
lib.rs

//! Audio utilities for `mistral.rs`.
//!
//! This crate mirrors `mistralrs-vision` and focuses on audio-specific
//! functionality such as reading audio data, resampling and computing
//! mel spectrogram features.
6
7use anyhow::Result;
8use symphonia::core::{
9    audio::SampleBuffer, codecs::DecoderOptions, formats::FormatOptions, io::MediaSourceStream,
10    meta::MetadataOptions, probe::Hint,
11};
12
/// PCM audio held in memory together with its format.
///
/// Multi-channel audio is stored interleaved, i.e. the sample for channel `c`
/// of frame `f` lives at `samples[f * channels + c]`.
#[derive(Debug, Clone, PartialEq)]
pub struct AudioInput {
    /// Sample values as `f32`, interleaved across channels.
    pub samples: Vec<f32>,
    /// Sampling frequency in Hz.
    pub sample_rate: u32,
    /// Number of interleaved channels in `samples`.
    pub channels: u16,
}
20
21impl AudioInput {
22    /// Read a wav file from disk.
23    pub fn read_wav(wav_path: &str) -> Result<Self> {
24        let mut reader = hound::WavReader::open(wav_path)?;
25        let spec = reader.spec();
26        let samples: Vec<f32> = match spec.sample_format {
27            hound::SampleFormat::Float => reader
28                .samples::<f32>()
29                .collect::<std::result::Result<_, _>>()?,
30            hound::SampleFormat::Int => reader
31                .samples::<i16>()
32                .map(|s| s.map(|v| v as f32 / i16::MAX as f32))
33                .collect::<std::result::Result<_, _>>()?,
34        };
35        Ok(Self {
36            samples,
37            sample_rate: spec.sample_rate,
38            channels: spec.channels,
39        })
40    }
41
42    /// Decode audio bytes using `symphonia`.
43    pub fn from_bytes(bytes: &[u8]) -> Result<Self> {
44        let cursor = std::io::Cursor::new(bytes.to_vec());
45        let mss = MediaSourceStream::new(Box::new(cursor), Default::default());
46        let hint = Hint::new();
47        let probed = symphonia::default::get_probe().format(
48            &hint,
49            mss,
50            &FormatOptions::default(),
51            &MetadataOptions::default(),
52        )?;
53        let mut format = probed.format;
54        let track = format
55            .default_track()
56            .ok_or_else(|| anyhow::anyhow!("no supported audio tracks"))?;
57        let codec_params = &track.codec_params;
58        let sample_rate = codec_params
59            .sample_rate
60            .ok_or_else(|| anyhow::anyhow!("unknown sample rate"))?;
61        #[allow(clippy::cast_possible_truncation)]
62        let channels = codec_params.channels.map(|c| c.count() as u16).unwrap_or(1);
63        let mut decoder =
64            symphonia::default::get_codecs().make(codec_params, &DecoderOptions::default())?;
65        let mut samples = Vec::new();
66        loop {
67            match format.next_packet() {
68                Ok(packet) => {
69                    let decoded = decoder.decode(&packet)?;
70                    let mut buf =
71                        SampleBuffer::<f32>::new(decoded.capacity() as u64, *decoded.spec());
72                    buf.copy_interleaved_ref(decoded);
73                    samples.extend_from_slice(buf.samples());
74                }
75                Err(symphonia::core::errors::Error::IoError(e))
76                    if e.kind() == std::io::ErrorKind::UnexpectedEof =>
77                {
78                    break;
79                }
80                Err(e) => return Err(e.into()),
81            }
82        }
83        Ok(Self {
84            samples,
85            sample_rate,
86            channels,
87        })
88    }
89
90    /// Convert multi channel audio to mono by averaging channels.
91    pub fn to_mono(&self) -> Vec<f32> {
92        if self.channels <= 1 {
93            return self.samples.clone();
94        }
95        let mut mono = vec![0.0; self.samples.len() / self.channels as usize];
96        for (i, sample) in self.samples.iter().enumerate() {
97            mono[i / self.channels as usize] += *sample;
98        }
99        for s in &mut mono {
100            *s /= self.channels as f32;
101        }
102        mono
103    }
104
105    /// Normalize audio to prevent clipping
106    pub fn normalize(&mut self) -> &mut Self {
107        let max_amplitude = self.samples.iter().map(|s| s.abs()).fold(0.0f32, f32::max);
108        if max_amplitude > 0.0 && max_amplitude != 1.0 {
109            let scale = 1.0 / max_amplitude;
110            for sample in &mut self.samples {
111                *sample *= scale;
112            }
113        }
114        self
115    }
116
117    /// Apply fade in/out to reduce audio artifacts
118    pub fn apply_fade(&mut self, fade_in_samples: usize, fade_out_samples: usize) -> &mut Self {
119        let len = self.samples.len();
120        // Fade in
121        for i in 0..fade_in_samples.min(len) {
122            let factor = i as f32 / fade_in_samples as f32;
123            self.samples[i] *= factor;
124        }
125        // Fade out
126        for i in 0..fade_out_samples.min(len) {
127            let factor = (fade_out_samples - i) as f32 / fade_out_samples as f32;
128            self.samples[len - 1 - i] *= factor;
129        }
130        self
131    }
132
133    /// Remove DC offset (audio centered around 0)
134    pub fn remove_dc_offset(&mut self) -> &mut Self {
135        if self.samples.is_empty() {
136            return self;
137        }
138        let mean = self.samples.iter().sum::<f32>() / self.samples.len() as f32;
139        for sample in &mut self.samples {
140            *sample -= mean;
141        }
142        self
143    }
144}
145
#[cfg(test)]
mod tests {
    use super::AudioInput;
    use hound::{SampleFormat, WavSpec, WavWriter};
    use std::io::Cursor;

    #[test]
    fn read_wav_roundtrip() {
        let spec = WavSpec {
            channels: 1,
            sample_rate: 16000,
            bits_per_sample: 16,
            sample_format: SampleFormat::Int,
        };
        // Use the platform temp dir and a process-unique name rather than a
        // hard-coded `/tmp` path, which is not portable and can collide with
        // concurrent test runs.
        let path = std::env::temp_dir().join(format!(
            "mistralrs_audio_read_wav_{}.wav",
            std::process::id()
        ));
        let mut writer = WavWriter::create(&path, spec).unwrap();
        writer.write_sample::<i16>(16384).unwrap();
        for _ in 1..160 {
            writer.write_sample::<i16>(0).unwrap();
        }
        writer.finalize().unwrap();
        let result = AudioInput::read_wav(path.to_str().expect("temp path is valid UTF-8"));
        // Remove the file before asserting so a failure does not leak it.
        std::fs::remove_file(&path).unwrap();
        let input = result.unwrap();
        assert_eq!(input.samples.len(), 160);
        assert_eq!(input.sample_rate, 16000);
        assert_eq!(input.channels, 1);
        assert!((input.samples[0] - 0.5).abs() < 1e-3);
        assert!(input.samples[1..].iter().all(|&s| s == 0.0));
    }

    #[test]
    fn from_bytes() {
        let spec = WavSpec {
            channels: 1,
            sample_rate: 8000,
            bits_per_sample: 16,
            sample_format: SampleFormat::Int,
        };
        let mut buffer: Vec<u8> = Vec::new();
        {
            let mut writer = WavWriter::new(Cursor::new(&mut buffer), spec).unwrap();
            for _ in 0..80 {
                writer.write_sample::<i16>(0).unwrap();
            }
            writer.finalize().unwrap();
        }
        let input = AudioInput::from_bytes(&buffer).unwrap();
        assert_eq!(input.samples.len(), 80);
        assert_eq!(input.sample_rate, 8000);
        assert_eq!(input.channels, 1);
    }

    #[test]
    fn test_to_mono() {
        let input = AudioInput {
            samples: vec![1.0, 3.0, -1.0, 1.0],
            sample_rate: 16000,
            channels: 2,
        };
        assert_eq!(input.to_mono(), vec![2.0, 0.0]);
    }

    #[test]
    fn test_normalize() {
        let mut input = AudioInput {
            samples: vec![0.2, -0.5, 0.8, -1.0],
            sample_rate: 16000,
            channels: 1,
        };
        input.normalize();
        let max = input.samples.iter().map(|s| s.abs()).fold(0.0f32, f32::max);
        assert!((max - 1.0).abs() < 1e-6);
    }

    #[test]
    fn test_apply_fade() {
        let mut input = AudioInput {
            samples: vec![1.0; 10],
            sample_rate: 16000,
            channels: 1,
        };
        input.apply_fade(3, 3);
        assert!((input.samples[0] - 0.0).abs() < 1e-6);
        assert!(input.samples[1] > 0.0 && input.samples[1] < 1.0);
        assert!(input.samples[2] > 0.0 && input.samples[2] < 1.0);
        assert!(input.samples[3] == 1.0);
        assert!((input.samples[9] - 0.0).abs() < 1e-6);
    }

    #[test]
    fn test_remove_dc_offset() {
        let mut input = AudioInput {
            samples: vec![1.0, 1.0, 1.0, 1.0],
            sample_rate: 16000,
            channels: 1,
        };
        input.remove_dc_offset();
        for s in input.samples {
            assert!((s - 0.0).abs() < 1e-6);
        }
    }
}