1use anyhow::Result;
8use symphonia::core::{
9 audio::SampleBuffer, codecs::DecoderOptions, formats::FormatOptions, io::MediaSourceStream,
10 meta::MetadataOptions, probe::Hint,
11};
12
/// Decoded audio held in memory as 32-bit float samples.
///
/// Multi-channel audio is stored interleaved: `samples` holds one value per
/// channel for frame 0, then one value per channel for frame 1, and so on.
#[derive(Debug, Clone, PartialEq)]
pub struct AudioInput {
    /// Interleaved sample values (`channels` values per frame).
    pub samples: Vec<f32>,
    /// Sampling rate reported by the source (frames per second).
    pub sample_rate: u32,
    /// Number of interleaved channels in `samples`.
    pub channels: u16,
}
20
21impl AudioInput {
22 pub fn read_wav(wav_path: &str) -> Result<Self> {
24 let mut reader = hound::WavReader::open(wav_path)?;
25 let spec = reader.spec();
26 let samples: Vec<f32> = match spec.sample_format {
27 hound::SampleFormat::Float => reader
28 .samples::<f32>()
29 .collect::<std::result::Result<_, _>>()?,
30 hound::SampleFormat::Int => reader
31 .samples::<i16>()
32 .map(|s| s.map(|v| v as f32 / i16::MAX as f32))
33 .collect::<std::result::Result<_, _>>()?,
34 };
35 Ok(Self {
36 samples,
37 sample_rate: spec.sample_rate,
38 channels: spec.channels,
39 })
40 }
41
42 pub fn from_bytes(bytes: &[u8]) -> Result<Self> {
44 let cursor = std::io::Cursor::new(bytes.to_vec());
45 let mss = MediaSourceStream::new(Box::new(cursor), Default::default());
46 let hint = Hint::new();
47 let probed = symphonia::default::get_probe().format(
48 &hint,
49 mss,
50 &FormatOptions::default(),
51 &MetadataOptions::default(),
52 )?;
53 let mut format = probed.format;
54 let track = format
55 .default_track()
56 .ok_or_else(|| anyhow::anyhow!("no supported audio tracks"))?;
57 let codec_params = &track.codec_params;
58 let sample_rate = codec_params
59 .sample_rate
60 .ok_or_else(|| anyhow::anyhow!("unknown sample rate"))?;
61 #[allow(clippy::cast_possible_truncation)]
62 let channels = codec_params.channels.map(|c| c.count() as u16).unwrap_or(1);
63 let mut decoder =
64 symphonia::default::get_codecs().make(codec_params, &DecoderOptions::default())?;
65 let mut samples = Vec::new();
66 loop {
67 match format.next_packet() {
68 Ok(packet) => {
69 let decoded = decoder.decode(&packet)?;
70 let mut buf =
71 SampleBuffer::<f32>::new(decoded.capacity() as u64, *decoded.spec());
72 buf.copy_interleaved_ref(decoded);
73 samples.extend_from_slice(buf.samples());
74 }
75 Err(symphonia::core::errors::Error::IoError(e))
76 if e.kind() == std::io::ErrorKind::UnexpectedEof =>
77 {
78 break;
79 }
80 Err(e) => return Err(e.into()),
81 }
82 }
83 Ok(Self {
84 samples,
85 sample_rate,
86 channels,
87 })
88 }
89
90 pub fn to_mono(&self) -> Vec<f32> {
92 if self.channels <= 1 {
93 return self.samples.clone();
94 }
95 let mut mono = vec![0.0; self.samples.len() / self.channels as usize];
96 for (i, sample) in self.samples.iter().enumerate() {
97 mono[i / self.channels as usize] += *sample;
98 }
99 for s in &mut mono {
100 *s /= self.channels as f32;
101 }
102 mono
103 }
104
105 pub fn normalize(&mut self) -> &mut Self {
107 let max_amplitude = self.samples.iter().map(|s| s.abs()).fold(0.0f32, f32::max);
108 if max_amplitude > 0.0 && max_amplitude != 1.0 {
109 let scale = 1.0 / max_amplitude;
110 for sample in &mut self.samples {
111 *sample *= scale;
112 }
113 }
114 self
115 }
116
117 pub fn apply_fade(&mut self, fade_in_samples: usize, fade_out_samples: usize) -> &mut Self {
119 let len = self.samples.len();
120 for i in 0..fade_in_samples.min(len) {
122 let factor = i as f32 / fade_in_samples as f32;
123 self.samples[i] *= factor;
124 }
125 for i in 0..fade_out_samples.min(len) {
127 let factor = (fade_out_samples - i) as f32 / fade_out_samples as f32;
128 self.samples[len - 1 - i] *= factor;
129 }
130 self
131 }
132
133 pub fn remove_dc_offset(&mut self) -> &mut Self {
135 if self.samples.is_empty() {
136 return self;
137 }
138 let mean = self.samples.iter().sum::<f32>() / self.samples.len() as f32;
139 for sample in &mut self.samples {
140 *sample -= mean;
141 }
142 self
143 }
144}
145
#[cfg(test)]
mod tests {
    use super::AudioInput;
    use hound::{SampleFormat, WavSpec, WavWriter};
    use std::io::Cursor;

    /// Returns a per-test, per-process path in the OS temp directory so test
    /// runs never clobber each other's files (and work off Unix too).
    fn temp_wav_path(name: &str) -> std::path::PathBuf {
        std::env::temp_dir().join(format!("audio_input_{}_{}.wav", name, std::process::id()))
    }

    #[test]
    fn read_wav_roundtrip() {
        let spec = WavSpec {
            channels: 1,
            sample_rate: 16000,
            bits_per_sample: 16,
            sample_format: SampleFormat::Int,
        };
        let path = temp_wav_path("read_wav_roundtrip");
        let mut writer = WavWriter::create(&path, spec).unwrap();
        for &v in &[0i16, i16::MAX, 16384, -16384] {
            writer.write_sample(v).unwrap();
        }
        for _ in 0..156 {
            writer.write_sample::<i16>(0).unwrap();
        }
        writer.finalize().unwrap();

        let result = AudioInput::read_wav(path.to_str().expect("temp path is valid UTF-8"));
        // Clean up before asserting so a failure does not leak the file.
        let removed = std::fs::remove_file(&path);
        let input = result.unwrap();
        removed.unwrap();

        assert_eq!(input.samples.len(), 160);
        assert_eq!(input.sample_rate, 16000);
        assert_eq!(input.channels, 1);
        assert_eq!(input.samples[0], 0.0);
        assert!((input.samples[1] - 1.0).abs() < 1e-6);
        assert!((input.samples[2] - 16384.0 / 32767.0).abs() < 1e-6);
        assert!((input.samples[3] + 16384.0 / 32767.0).abs() < 1e-6);
    }

    #[test]
    fn from_bytes() {
        let spec = WavSpec {
            channels: 1,
            sample_rate: 8000,
            bits_per_sample: 16,
            sample_format: SampleFormat::Int,
        };
        let mut buffer: Vec<u8> = Vec::new();
        {
            let mut writer = WavWriter::new(Cursor::new(&mut buffer), spec).unwrap();
            for _ in 0..80 {
                writer.write_sample::<i16>(0).unwrap();
            }
            writer.finalize().unwrap();
        }
        let input = AudioInput::from_bytes(&buffer).unwrap();
        assert_eq!(input.samples.len(), 80);
        assert_eq!(input.sample_rate, 8000);
        assert_eq!(input.channels, 1);
    }

    #[test]
    fn test_to_mono_stereo() {
        let input = AudioInput {
            samples: vec![1.0, 0.0, 0.5, 0.5, -1.0, -0.5],
            sample_rate: 16000,
            channels: 2,
        };
        assert_eq!(input.to_mono(), vec![0.5, 0.5, -0.75]);
    }

    #[test]
    fn test_normalize() {
        let mut input = AudioInput {
            samples: vec![0.2, -0.5, 0.8, -1.0],
            sample_rate: 16000,
            channels: 1,
        };
        input.normalize();
        let max = input.samples.iter().map(|s| s.abs()).fold(0.0f32, f32::max);
        assert!((max - 1.0).abs() < 1e-6);
    }

    #[test]
    fn test_normalize_silence_is_untouched() {
        let mut input = AudioInput {
            samples: vec![0.0; 4],
            sample_rate: 16000,
            channels: 1,
        };
        input.normalize();
        assert_eq!(input.samples, vec![0.0; 4]);
    }

    #[test]
    fn test_apply_fade() {
        let mut input = AudioInput {
            samples: vec![1.0; 10],
            sample_rate: 16000,
            channels: 1,
        };
        input.apply_fade(3, 3);
        // Fade-in: gain 0 on the first sample, rising towards 1.
        assert!((input.samples[0] - 0.0).abs() < 1e-6);
        assert!(input.samples[1] > 0.0 && input.samples[1] < 1.0);
        assert!(input.samples[2] > 0.0 && input.samples[2] < 1.0);
        // Untouched middle.
        assert!(input.samples[3] == 1.0);
        assert!(input.samples[6] == 1.0);
        // Fade-out mirrors the fade-in: gain 0 on the last sample.
        assert!(input.samples[7] > 0.0 && input.samples[7] < 1.0);
        assert!(input.samples[8] > 0.0 && input.samples[8] < 1.0);
        assert!((input.samples[9] - 0.0).abs() < 1e-6);
    }

    #[test]
    fn test_remove_dc_offset() {
        let mut input = AudioInput {
            samples: vec![1.0, 1.0, 1.0, 1.0],
            sample_rate: 16000,
            channels: 1,
        };
        input.remove_dc_offset();
        for s in input.samples {
            assert!((s - 0.0).abs() < 1e-6);
        }
    }

    #[test]
    fn test_remove_dc_offset_mixed() {
        let mut input = AudioInput {
            samples: vec![1.0, 3.0],
            sample_rate: 16000,
            channels: 1,
        };
        input.remove_dc_offset();
        assert!((input.samples[0] + 1.0).abs() < 1e-6);
        assert!((input.samples[1] - 1.0).abs() < 1e-6);
    }
}