diff --git a/benches/pipeline.rs b/benches/pipeline.rs index e4a0d8a9..a28be121 100644 --- a/benches/pipeline.rs +++ b/benches/pipeline.rs @@ -1,6 +1,8 @@ +use std::num::NonZero; use std::time::Duration; use divan::Bencher; +use rodio::ChannelCount; use rodio::{source::UniformSourceIterator, Source}; mod shared; @@ -31,7 +33,11 @@ fn long(bencher: Bencher) { .buffered() .reverb(Duration::from_secs_f32(0.05), 0.3) .skippable(); - let resampled = UniformSourceIterator::new(effects_applied, 2, 40_000); + let resampled = UniformSourceIterator::new( + effects_applied, + ChannelCount::new(2).unwrap(), + NonZero::new(40_000).unwrap(), + ); resampled.for_each(divan::black_box_drop) }) } diff --git a/benches/resampler.rs b/benches/resampler.rs index 6c5c0683..deb10f08 100644 --- a/benches/resampler.rs +++ b/benches/resampler.rs @@ -4,7 +4,7 @@ use rodio::source::UniformSourceIterator; mod shared; use shared::music_wav; -use rodio::Source; +use rodio::{SampleRate, Source}; fn main() { divan::main(); @@ -31,6 +31,7 @@ const COMMON_SAMPLE_RATES: [u32; 12] = [ #[divan::bench(args = COMMON_SAMPLE_RATES)] fn resample_to(bencher: Bencher, target_sample_rate: u32) { + let target_sample_rate = SampleRate::new(target_sample_rate).unwrap(); bencher .with_inputs(|| { let source = music_wav(); diff --git a/benches/shared.rs b/benches/shared.rs index 442621f4..dbe23395 100644 --- a/benches/shared.rs +++ b/benches/shared.rs @@ -6,8 +6,8 @@ use rodio::{ChannelCount, Sample, SampleRate, Source}; pub struct TestSource { samples: vec::IntoIter, - channels: u16, - sample_rate: u32, + channels: ChannelCount, + sample_rate: SampleRate, total_duration: Duration, } diff --git a/examples/custom_config.rs b/examples/custom_config.rs index 4e1ff76a..aa83a7ee 100644 --- a/examples/custom_config.rs +++ b/examples/custom_config.rs @@ -3,6 +3,7 @@ use cpal::{BufferSize, SampleFormat}; use rodio::source::SineWave; use rodio::Source; use std::error::Error; +use std::num::NonZero; use std::thread; use std::time::Duration; @@ -15,7 +16,7 @@ fn main() -> Result<(), Box> { // No need to set all parameters explicitly here, // the defaults were set from the device's description. .with_buffer_size(BufferSize::Fixed(256)) - .with_sample_rate(48_000) + .with_sample_rate(NonZero::new(48_000).unwrap()) .with_sample_format(SampleFormat::F32) // Note that the function below still tries alternative configs if the specified one fails. // If you need to only use the exact specified configuration, diff --git a/examples/mix_multiple_sources.rs b/examples/mix_multiple_sources.rs index 065b6585..69fc53d3 100644 --- a/examples/mix_multiple_sources.rs +++ b/examples/mix_multiple_sources.rs @@ -1,11 +1,12 @@ use rodio::mixer; use rodio::source::{SineWave, Source}; use std::error::Error; +use std::num::NonZero; use std::time::Duration; fn main() -> Result<(), Box> { // Construct a dynamic controller and mixer, stream_handle, and sink. - let (controller, mixer) = mixer::mixer(2, 44_100); + let (controller, mixer) = mixer::mixer(NonZero::new(2).unwrap(), NonZero::new(44_100).unwrap()); let stream_handle = rodio::OutputStreamBuilder::open_default_stream()?; let sink = rodio::Sink::connect_new(&stream_handle.mixer()); diff --git a/examples/signal_generator.rs b/examples/signal_generator.rs index 1ae4048e..2257f30d 100644 --- a/examples/signal_generator.rs +++ b/examples/signal_generator.rs @@ -1,6 +1,7 @@ //! Test signal generator example. use std::error::Error; +use std::num::NonZero; fn main() -> Result<(), Box> { use rodio::source::{chirp, Function, SignalGenerator, Source}; @@ -11,7 +12,7 @@ fn main() -> Result<(), Box> { let test_signal_duration = Duration::from_millis(1000); let interval_duration = Duration::from_millis(1500); - let sample_rate = 48000; + let sample_rate = NonZero::new(48000).unwrap(); println!("Playing 1000 Hz tone"); stream_handle.mixer().add( diff --git a/src/buffer.rs b/src/buffer.rs index 192245e6..8cedf727 100644 --- a/src/buffer.rs +++ b/src/buffer.rs @@ -6,7 +6,8 @@ //! //! ``` //! use rodio::buffer::SamplesBuffer; -//! let _ = SamplesBuffer::new(1, 44100, vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0]); +//! use core::num::NonZero; +//! let _ = SamplesBuffer::new(NonZero::new(1).unwrap(), NonZero::new(44100).unwrap(), vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0]); //! ``` //! @@ -30,7 +31,6 @@ impl SamplesBuffer { /// /// # Panic /// - /// - Panics if the number of channels is zero. /// - Panics if the samples rate is zero. /// - Panics if the length of the buffer is larger than approximately 16 billion elements. /// This is because the calculation of the duration would overflow. @@ -39,13 +39,10 @@ impl SamplesBuffer { where D: Into>, { - assert!(channels >= 1); - assert!(sample_rate >= 1); - let data = data.into(); let duration_ns = 1_000_000_000u64.checked_mul(data.len() as u64).unwrap() - / sample_rate as u64 - / channels as u64; + / sample_rate.get() as u64 + / channels.get() as u64; let duration = Duration::new( duration_ns / 1_000_000_000, (duration_ns % 1_000_000_000) as u32, @@ -89,14 +86,15 @@ impl Source for SamplesBuffer { // and due to the constant sample_rate we can jump to the right // sample directly. - let curr_channel = self.pos % self.channels() as usize; - let new_pos = pos.as_secs_f32() * self.sample_rate() as f32 * self.channels() as f32; + let curr_channel = self.pos % self.channels().get() as usize; + let new_pos = + pos.as_secs_f32() * self.sample_rate().get() as f32 * self.channels().get() as f32; // saturate pos at the end of the source let new_pos = new_pos as usize; let new_pos = new_pos.min(self.data.len()); // make sure the next sample is for the right channel - let new_pos = new_pos.next_multiple_of(self.channels() as usize); + let new_pos = new_pos.next_multiple_of(self.channels().get() as usize); let new_pos = new_pos - curr_channel; self.pos = new_pos; @@ -123,28 +121,17 @@ impl Iterator for SamplesBuffer { #[cfg(test)] mod tests { use crate::buffer::SamplesBuffer; + use crate::math::nz; use crate::source::Source; #[test] fn basic() { - let _ = SamplesBuffer::new(1, 44100, vec![0.0, 0.0, 0.0, 0.0, 0.0, 0.0]); - } - - #[test] - #[should_panic] - fn panic_if_zero_channels() { - SamplesBuffer::new(0, 44100, vec![0.0, 0.0, 0.0, 0.0, 0.0, 0.0]); - } - - #[test] - #[should_panic] - fn panic_if_zero_sample_rate() { - SamplesBuffer::new(1, 0, vec![0.0, 0.0, 0.0, 0.0, 0.0, 0.0]); + let _ = SamplesBuffer::new(nz!(1), nz!(44100), vec![0.0, 0.0, 0.0, 0.0, 0.0, 0.0]); } #[test] fn duration_basic() { - let buf = SamplesBuffer::new(2, 2, vec![0.0, 0.0, 0.0, 0.0, 0.0, 0.0]); + let buf = SamplesBuffer::new(nz!(2), nz!(2), vec![0.0, 0.0, 0.0, 0.0, 0.0, 0.0]); let dur = buf.total_duration().unwrap(); assert_eq!(dur.as_secs(), 1); assert_eq!(dur.subsec_nanos(), 500_000_000); @@ -152,7 +139,7 @@ mod tests { #[test] fn iteration() { - let mut buf = SamplesBuffer::new(1, 44100, vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0]); + let mut buf = SamplesBuffer::new(nz!(1), nz!(44100), vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0]); assert_eq!(buf.next(), Some(1.0)); assert_eq!(buf.next(), Some(2.0)); assert_eq!(buf.next(), Some(3.0)); @@ -171,8 +158,8 @@ mod tests { #[test] fn channel_order_stays_correct() { - const SAMPLE_RATE: SampleRate = 100; - const CHANNELS: ChannelCount = 2; + const SAMPLE_RATE: SampleRate = nz!(100); + const CHANNELS: ChannelCount = nz!(2); let mut buf = SamplesBuffer::new( CHANNELS, SAMPLE_RATE, @@ -182,7 +169,10 @@ mod tests { .collect::>(), ); buf.try_seek(Duration::from_secs(5)).unwrap(); - assert_eq!(buf.next(), Some(5.0 * SAMPLE_RATE as f32 * CHANNELS as f32)); + assert_eq!( + buf.next(), + Some(5.0 * SAMPLE_RATE.get() as f32 * CHANNELS.get() as f32) + ); assert!(buf.next().is_some_and(|s| s.trunc() as i32 % 2 == 1)); assert!(buf.next().is_some_and(|s| s.trunc() as i32 % 2 == 0)); diff --git a/src/common.rs b/src/common.rs index 12b3a94d..17ff3b98 100644 --- a/src/common.rs +++ b/src/common.rs @@ -1,8 +1,10 @@ +use std::num::NonZero; + /// Stream sample rate (a frame rate or samples per second per channel). -pub type SampleRate = u32; +pub type SampleRate = NonZero; -/// Number of channels in a stream. -pub type ChannelCount = u16; +/// Number of channels in a stream. Can never be Zero +pub type ChannelCount = NonZero; /// Represents value of a single sample. /// Silence corresponds to the value `0.0`. The expected amplitude range is -1.0...1.0. diff --git a/src/conversions/channels.rs b/src/conversions/channels.rs index 17887673..c1401357 100644 --- a/src/conversions/channels.rs +++ b/src/conversions/channels.rs @@ -11,7 +11,7 @@ where from: ChannelCount, to: ChannelCount, sample_repeat: Option, - next_output_sample_pos: ChannelCount, + next_output_sample_pos: u16, } impl ChannelCountConverter @@ -26,9 +26,6 @@ where /// #[inline] pub fn new(input: I, from: ChannelCount, to: ChannelCount) -> ChannelCountConverter { - assert!(from >= 1); - assert!(to >= 1); - ChannelCountConverter { input, from, @@ -65,7 +62,7 @@ where self.sample_repeat = value; value } - x if x < self.from => self.input.next(), + x if x < self.from.get() => self.input.next(), 1 => self.sample_repeat, _ => Some(0.0), }; @@ -74,11 +71,11 @@ where self.next_output_sample_pos += 1; } - if self.next_output_sample_pos == self.to { + if self.next_output_sample_pos == self.to.get() { self.next_output_sample_pos = 0; if self.from > self.to { - for _ in self.to..self.from { + for _ in self.to.get()..self.from.get() { self.input.next(); // discarding extra input } } @@ -91,13 +88,13 @@ where fn size_hint(&self) -> (usize, Option) { let (min, max) = self.input.size_hint(); - let consumed = std::cmp::min(self.from, self.next_output_sample_pos) as usize; + let consumed = std::cmp::min(self.from.get(), self.next_output_sample_pos) as usize; - let min = ((min + consumed) / self.from as usize * self.to as usize) + let min = ((min + consumed) / self.from.get() as usize * self.to.get() as usize) .saturating_sub(self.next_output_sample_pos as usize); let max = max.map(|max| { - ((max + consumed) / self.from as usize * self.to as usize) + ((max + consumed) / self.from.get() as usize * self.to.get() as usize) .saturating_sub(self.next_output_sample_pos as usize) }); @@ -111,31 +108,37 @@ impl ExactSizeIterator for ChannelCountConverter where I: ExactSizeIterato mod test { use super::ChannelCountConverter; use crate::common::ChannelCount; + use crate::math::nz; use crate::Sample; #[test] fn remove_channels() { let input = vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0]; - let output = ChannelCountConverter::new(input.into_iter(), 3, 2).collect::>(); + let output = + ChannelCountConverter::new(input.into_iter(), nz!(3), nz!(2)).collect::>(); assert_eq!(output, [1.0, 2.0, 4.0, 5.0]); let input = vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0]; - let output = ChannelCountConverter::new(input.into_iter(), 4, 1).collect::>(); + let output = + ChannelCountConverter::new(input.into_iter(), nz!(4), nz!(1)).collect::>(); assert_eq!(output, [1.0, 5.0]); } #[test] fn add_channels() { let input = vec![1.0, 2.0, 3.0, 4.0]; - let output = ChannelCountConverter::new(input.into_iter(), 1, 2).collect::>(); + let output = + ChannelCountConverter::new(input.into_iter(), nz!(1), nz!(2)).collect::>(); assert_eq!(output, [1.0, 1.0, 2.0, 2.0, 3.0, 3.0, 4.0, 4.0]); let input = vec![1.0, 2.0]; - let output = ChannelCountConverter::new(input.into_iter(), 1, 4).collect::>(); + let output = + ChannelCountConverter::new(input.into_iter(), nz!(1), nz!(4)).collect::>(); assert_eq!(output, [1.0, 1.0, 0.0, 0.0, 2.0, 2.0, 0.0, 0.0]); let input = vec![1.0, 2.0, 3.0, 4.0]; - let output = ChannelCountConverter::new(input.into_iter(), 2, 4).collect::>(); + let output = + ChannelCountConverter::new(input.into_iter(), nz!(2), nz!(4)).collect::>(); assert_eq!(output, [1.0, 2.0, 0.0, 0.0, 3.0, 4.0, 0.0, 0.0]); } @@ -152,24 +155,24 @@ mod test { assert_eq!(converter.size_hint(), (0, Some(0))); } - test(&[1.0, 2.0, 3.0], 1, 2); - test(&[1.0, 2.0, 3.0, 4.0], 2, 4); - test(&[1.0, 2.0, 3.0, 4.0], 4, 2); - test(&[1.0, 2.0, 3.0, 4.0, 5.0, 6.0], 3, 8); - test(&[1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0], 4, 1); + test(&[1.0, 2.0, 3.0], nz!(1), nz!(2)); + test(&[1.0, 2.0, 3.0, 4.0], nz!(2), nz!(4)); + test(&[1.0, 2.0, 3.0, 4.0], nz!(4), nz!(2)); + test(&[1.0, 2.0, 3.0, 4.0, 5.0, 6.0], nz!(3), nz!(8)); + test(&[1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0], nz!(4), nz!(1)); } #[test] fn len_more() { let input = vec![1.0, 2.0, 3.0, 4.0]; - let output = ChannelCountConverter::new(input.into_iter(), 2, 3); + let output = ChannelCountConverter::new(input.into_iter(), nz!(2), nz!(3)); assert_eq!(output.len(), 6); } #[test] fn len_less() { let input = vec![1.0, 2.0, 3.0, 4.0]; - let output = ChannelCountConverter::new(input.into_iter(), 2, 1); + let output = ChannelCountConverter::new(input.into_iter(), nz!(2), nz!(1)); assert_eq!(output.len(), 2); } } diff --git a/src/conversions/sample_rate.rs b/src/conversions/sample_rate.rs index da85f52e..5d457aaa 100644 --- a/src/conversions/sample_rate.rs +++ b/src/conversions/sample_rate.rs @@ -54,27 +54,23 @@ where to: SampleRate, num_channels: ChannelCount, ) -> SampleRateConverter { - assert!(num_channels >= 1); - assert!(from >= 1); - assert!(to >= 1); - let (first_samples, next_samples) = if from == to { // if `from` == `to` == 1, then we just pass through (Vec::new(), Vec::new()) } else { let first = input .by_ref() - .take(num_channels as usize) + .take(num_channels.get() as usize) .collect::>(); let next = input .by_ref() - .take(num_channels as usize) + .take(num_channels.get() as usize) .collect::>(); (first, next) }; // Reducing numerator to avoid numeric overflows during interpolation. - let (to, from) = Ratio::new(to, from).into_raw(); + let (to, from) = Ratio::new(to.get(), from.get()).into_raw(); SampleRateConverter { input, @@ -85,7 +81,7 @@ where next_output_span_pos_in_chunk: 0, current_span: first_samples, next_frame: next_samples, - output_buffer: Vec::with_capacity(num_channels as usize - 1), + output_buffer: Vec::with_capacity(num_channels.get() as usize - 1), } } @@ -106,7 +102,7 @@ where mem::swap(&mut self.current_span, &mut self.next_frame); self.next_frame.clear(); - for _ in 0..self.channels { + for _ in 0..self.channels.get() { if let Some(i) = self.input.next() { self.next_frame.push(i); } else { @@ -213,7 +209,7 @@ where // removing the samples of the current chunk that have not yet been read let samples_after_chunk = samples_after_chunk.saturating_sub( self.from.saturating_sub(self.current_span_pos_in_chunk + 2) as usize - * usize::from(self.channels), + * usize::from(self.channels.get()), ); // calculating the number of samples after the transformation // TODO: this is wrong here \|/ @@ -222,7 +218,7 @@ where // `samples_current_chunk` will contain the number of samples remaining to be output // for the chunk currently being processed let samples_current_chunk = (self.to - self.next_output_span_pos_in_chunk) as usize - * usize::from(self.channels); + * usize::from(self.channels.get()); samples_current_chunk + samples_after_chunk + self.output_buffer.len() }; @@ -242,25 +238,22 @@ impl ExactSizeIterator for SampleRateConverter where I: ExactSizeIterator< mod test { use super::SampleRateConverter; use crate::common::{ChannelCount, SampleRate}; + use crate::math::nz; use crate::Sample; use core::time::Duration; use quickcheck::{quickcheck, TestResult}; quickcheck! { /// Check that resampling an empty input produces no output. - fn empty(from: u16, to: u16, channels: u8) -> TestResult { - if channels == 0 || channels > 128 - || from == 0 - || to == 0 + fn empty(from: SampleRate, to: SampleRate, channels: ChannelCount) -> TestResult { + if channels.get() > 128 { return TestResult::discard(); } - let from = from as SampleRate; - let to = to as SampleRate; let input: Vec = Vec::new(); let output = - SampleRateConverter::new(input.into_iter(), from, to, channels as ChannelCount) + SampleRateConverter::new(input.into_iter(), from, to, channels) .collect::>(); assert_eq!(output, []); @@ -268,13 +261,12 @@ mod test { } /// Check that resampling to the same rate does not change the signal. - fn identity(from: u16, channels: u8, input: Vec) -> TestResult { - if channels == 0 || channels > 128 || from == 0 { return TestResult::discard(); } - let from = from as SampleRate; + fn identity(from: SampleRate, channels: ChannelCount, input: Vec) -> TestResult { + if channels.get() > 128 { return TestResult::discard(); } let input = Vec::from_iter(input.iter().map(|x| *x as Sample)); let output = - SampleRateConverter::new(input.clone().into_iter(), from, from, channels as ChannelCount) + SampleRateConverter::new(input.clone().into_iter(), from, from, channels) .collect::>(); TestResult::from_bool(input == output) @@ -282,75 +274,74 @@ mod test { /// Check that dividing the sample rate by k (integer) is the same as /// dropping a sample from each channel. - fn divide_sample_rate(to: u16, k: u16, input: Vec, channels: u8) -> TestResult { - if k == 0 || channels == 0 || channels > 128 || to == 0 || to > 48000 { + fn divide_sample_rate(to: SampleRate, k: u16, input: Vec, channels: ChannelCount) -> TestResult { + if k == 0 || channels.get() > 128 || to.get() > 48000 { return TestResult::discard(); } let input = Vec::from_iter(input.iter().map(|x| *x as Sample)); let to = to as SampleRate; - let from = to * k as u32; + let from = to.get() * k as u32; // Truncate the input, so it contains an integer number of spans. let input = { - let ns = channels as usize; + let ns = channels.get() as usize; let mut i = input; i.truncate(ns * (i.len() / ns)); i }; let output = - SampleRateConverter::new(input.clone().into_iter(), from, to, channels as ChannelCount) + SampleRateConverter::new(input.clone().into_iter(), SampleRate::new(from).unwrap(), to, channels) .collect::>(); - TestResult::from_bool(input.chunks_exact(channels.into()) + TestResult::from_bool(input.chunks_exact(channels.get().into()) .step_by(k as usize).collect::>().concat() == output) } /// Check that, after multiplying the sample rate by k, every k-th - /// sample in the output matches exactly with the input. - fn multiply_sample_rate(from: u16, k: u8, input: Vec, channels: u8) -> TestResult { - if k == 0 || channels == 0 || channels > 128 || from == 0 { + /// sample in the output matches exactly with the input. + fn multiply_sample_rate(from: SampleRate, k: u8, input: Vec, channels: ChannelCount) -> TestResult { + if k == 0 || from.get() > u16::MAX as u32 || channels.get() > 128 { return TestResult::discard(); } let input = Vec::from_iter(input.iter().map(|x| *x as Sample)); let from = from as SampleRate; - let to = from * k as u32; + dbg!(from, k); + let to = from.get() * k as u32; // Truncate the input, so it contains an integer number of spans. let input = { - let ns = channels as usize; + let ns = channels.get() as usize; let mut i = input; i.truncate(ns * (i.len() / ns)); i }; let output = - SampleRateConverter::new(input.clone().into_iter(), from, to, channels as ChannelCount) + SampleRateConverter::new(input.clone().into_iter(), from, SampleRate::new(to).unwrap(), channels) .collect::>(); TestResult::from_bool(input == - output.chunks_exact(channels.into()) + output.chunks_exact(channels.get().into()) .step_by(k as usize).collect::>().concat()) } #[ignore] /// Check that resampling does not change the audio duration, - /// except by a negligible amount (± 1ms). Reproduces #316. + /// except by a negligible amount (± 1ms). Reproduces #316. /// Ignored, pending a bug fix. fn preserve_durations(d: Duration, freq: f32, to: SampleRate) -> TestResult { - if to == 0 { return TestResult::discard(); } - use crate::source::{SineWave, Source}; let source = SineWave::new(freq).take_duration(d); let from = source.sample_rate(); let resampled = - SampleRateConverter::new(source, from, to, 1); + SampleRateConverter::new(source, from, to, nz!(1)); let duration = - Duration::from_secs_f32(resampled.count() as f32 / to as f32); + Duration::from_secs_f32(resampled.count() as f32 / to.get() as f32); let delta = if d < duration { duration - d } else { d - duration }; TestResult::from_bool(delta < Duration::from_millis(1)) @@ -360,7 +351,7 @@ mod test { #[test] fn upsample() { let input = vec![2.0, 16.0, 4.0, 18.0, 6.0, 20.0, 8.0, 22.0]; - let output = SampleRateConverter::new(input.into_iter(), 2000, 3000, 2); + let output = SampleRateConverter::new(input.into_iter(), nz!(2000), nz!(3000), nz!(2)); assert_eq!(output.len(), 12); // Test the source's Iterator::size_hint() let output = output.map(|x| x.trunc()).collect::>(); @@ -373,7 +364,7 @@ mod test { #[test] fn upsample2() { let input = vec![1.0, 14.0]; - let output = SampleRateConverter::new(input.into_iter(), 1000, 7000, 1); + let output = SampleRateConverter::new(input.into_iter(), nz!(1000), nz!(7000), nz!(1)); let size_estimation = output.len(); let output = output.map(|x| x.trunc()).collect::>(); assert_eq!(output, [1.0, 2.0, 4.0, 6.0, 8.0, 10.0, 12.0, 14.0]); @@ -383,7 +374,7 @@ mod test { #[test] fn downsample() { let input = Vec::from_iter((0..17).map(|x| x as Sample)); - let output = SampleRateConverter::new(input.into_iter(), 12000, 2400, 1); + let output = SampleRateConverter::new(input.into_iter(), nz!(12000), nz!(2400), nz!(1)); let size_estimation = output.len(); let output = output.collect::>(); assert_eq!(output, [0.0, 5.0, 10.0, 15.0]); diff --git a/src/decoder/flac.rs b/src/decoder/flac.rs index dd21d871..be93d6bc 100644 --- a/src/decoder/flac.rs +++ b/src/decoder/flac.rs @@ -60,8 +60,14 @@ where current_block_channel_len: 1, current_block_off: 0, bits_per_sample: spec.bits_per_sample, - sample_rate, - channels: spec.channels as ChannelCount, + sample_rate: SampleRate::new(sample_rate) + .expect("flac data should never have a zero sample rate"), + channels: ChannelCount::new( + spec.channels + .try_into() + .expect("rodio supports only up to u16::MAX (65_535) channels"), + ) + .expect("flac should never have zero channels"), total_duration, }) } @@ -115,9 +121,9 @@ where loop { if self.current_block_off < self.current_block.len() { // Read from current block. - let real_offset = (self.current_block_off % self.channels as usize) + let real_offset = (self.current_block_off % self.channels.get() as usize) * self.current_block_channel_len - + self.current_block_off / self.channels as usize; + + self.current_block_off / self.channels.get() as usize; let raw_val = self.current_block[real_offset]; self.current_block_off += 1; let bits = self.bits_per_sample; diff --git a/src/decoder/mod.rs b/src/decoder/mod.rs index 84aa469e..d54de527 100644 --- a/src/decoder/mod.rs +++ b/src/decoder/mod.rs @@ -8,6 +8,7 @@ use std::mem; use std::str::FromStr; use std::time::Duration; +use crate::math::nz; use crate::source::SeekError; use crate::{Sample, Source}; @@ -131,7 +132,7 @@ impl DecoderImpl { DecoderImpl::Mp3(source) => source.channels(), #[cfg(feature = "symphonia")] DecoderImpl::Symphonia(source) => source.channels(), - DecoderImpl::None(_) => 0, + DecoderImpl::None(_) => nz!(1), } } @@ -148,7 +149,7 @@ impl DecoderImpl { DecoderImpl::Mp3(source) => source.sample_rate(), #[cfg(feature = "symphonia")] DecoderImpl::Symphonia(source) => source.sample_rate(), - DecoderImpl::None(_) => 1, + DecoderImpl::None(_) => nz!(1), } } diff --git a/src/decoder/mp3.rs b/src/decoder/mp3.rs index e0680cbf..e1c38fc0 100644 --- a/src/decoder/mp3.rs +++ b/src/decoder/mp3.rs @@ -66,7 +66,7 @@ where #[inline] fn sample_rate(&self) -> SampleRate { - self.current_span.sample_rate as _ + self.current_span.sample_rate } #[inline] @@ -77,7 +77,7 @@ where fn try_seek(&mut self, _pos: Duration) -> Result<(), SeekError> { // TODO waiting for PR in minimp3_fixed or minimp3 - // let pos = (pos.as_secs_f32() * self.sample_rate() as f32) as u64; + // let pos = (pos.as_secs_f32() * self.sample_rate().get() as f32) as u64; // // do not trigger a sample_rate, channels and frame/span len update // // as the seek only takes effect after the current frame/span is done // self.decoder.seek_samples(pos)?; diff --git a/src/decoder/symphonia.rs b/src/decoder/symphonia.rs index 42d0d834..d0b4d910 100644 --- a/src/decoder/symphonia.rs +++ b/src/decoder/symphonia.rs @@ -164,12 +164,19 @@ impl Source for SymphoniaDecoder { #[inline] fn channels(&self) -> ChannelCount { - self.spec.channels.count() as ChannelCount + ChannelCount::new( + self.spec + .channels + .count() + .try_into() + .expect("rodio only support up to u16::MAX channels (65_535)"), + ) + .expect("audio should always have at least one channel") } #[inline] fn sample_rate(&self) -> SampleRate { - self.spec.rate + SampleRate::new(self.spec.rate).expect("audio should always have a non zero SampleRate") } #[inline] @@ -192,7 +199,7 @@ impl Source for SymphoniaDecoder { }; // make sure the next sample is for the right channel - let to_skip = self.current_span_offset % self.channels() as usize; + let to_skip = self.current_span_offset % self.channels().get() as usize; let seek_res = self .format @@ -289,7 +296,7 @@ impl SymphoniaDecoder { let decoded = decoded.map_err(SeekError::Decoding)?; decoded.spec().clone_into(&mut self.spec); self.buffer = SymphoniaDecoder::get_buffer(decoded, &self.spec); - self.current_span_offset = samples_to_pass as usize * self.channels() as usize; + self.current_span_offset = samples_to_pass as usize * self.channels().get() as usize; Ok(()) } } diff --git a/src/decoder/vorbis.rs b/src/decoder/vorbis.rs index 656d96ca..8d81fbf5 100644 --- a/src/decoder/vorbis.rs +++ b/src/decoder/vorbis.rs @@ -69,12 +69,14 @@ where #[inline] fn channels(&self) -> ChannelCount { - self.stream_reader.ident_hdr.audio_channels as ChannelCount + ChannelCount::new(self.stream_reader.ident_hdr.audio_channels.into()) + .expect("audio should have at least one channel") } #[inline] fn sample_rate(&self) -> SampleRate { - self.stream_reader.ident_hdr.audio_sample_rate + SampleRate::new(self.stream_reader.ident_hdr.audio_sample_rate) + .expect("audio should always have a non zero SampleRate") } #[inline] diff --git a/src/decoder/wav.rs b/src/decoder/wav.rs index 421a75e0..db1a39d0 100644 --- a/src/decoder/wav.rs +++ b/src/decoder/wav.rs @@ -41,6 +41,7 @@ where let sample_rate = spec.sample_rate; let channels = spec.channels; + assert!(channels > 0); let total_duration = { let data_rate = sample_rate as u64 * channels as u64; @@ -52,8 +53,9 @@ where Ok(WavDecoder { reader, total_duration, - sample_rate: sample_rate as SampleRate, - channels: channels as ChannelCount, + sample_rate: SampleRate::new(sample_rate) + .expect("wav should have a sample rate higher then zero"), + channels: ChannelCount::new(channels).expect("wav should have a least one channel"), }) } @@ -170,18 +172,18 @@ where fn try_seek(&mut self, pos: Duration) -> Result<(), SeekError> { let file_len = self.reader.reader.duration(); - let new_pos = pos.as_secs_f32() * self.sample_rate() as f32; + let new_pos = pos.as_secs_f32() * self.sample_rate().get() as f32; let new_pos = new_pos as u32; let new_pos = new_pos.min(file_len); // saturate pos at the end of the source // make sure the next sample is for the right channel - let to_skip = self.reader.samples_read % self.channels() as u32; + let to_skip = self.reader.samples_read % self.channels().get() as u32; self.reader .reader .seek(new_pos) .map_err(SeekError::HoundDecoder)?; - self.reader.samples_read = new_pos * self.channels() as u32; + self.reader.samples_read = new_pos * self.channels().get() as u32; for _ in 0..to_skip { self.next(); diff --git a/src/math.rs b/src/math.rs index 2f020748..9f8b8444 100644 --- a/src/math.rs +++ b/src/math.rs @@ -10,6 +10,17 @@ pub fn lerp(first: &f32, second: &f32, numerator: u32, denominator: u32) -> f32 first + (second - first) * numerator as f32 / denominator as f32 } +/// short macro to generate a `NonZero`. It panics during compile if the +/// passed in literal is zero. Used for `ChannelCount` and `Samplerate` +/// constants +macro_rules! nz { + ($n:literal) => { + const { core::num::NonZero::new($n).unwrap() } + }; +} + +pub(crate) use nz; + #[cfg(test)] mod test { use super::*; diff --git a/src/mixer.rs b/src/mixer.rs index 6c89fc9d..cd88f1a0 100644 --- a/src/mixer.rs +++ b/src/mixer.rs @@ -167,7 +167,7 @@ impl MixerSource { let mut pending = self.input.pending_sources.lock().unwrap(); // TODO: relax ordering? for source in pending.drain(..) { - let in_step = self.sample_count % source.channels() as usize == 0; + let in_step = self.sample_count % source.channels().get() as usize == 0; if in_step { self.current_sources.push(source); @@ -198,18 +198,27 @@ impl MixerSource { #[cfg(test)] mod tests { use crate::buffer::SamplesBuffer; + use crate::math::nz; use crate::mixer; use crate::source::Source; #[test] fn basic() { - let (tx, mut rx) = mixer::mixer(1, 48000); + let (tx, mut rx) = mixer::mixer(nz!(1), nz!(48000)); - tx.add(SamplesBuffer::new(1, 48000, vec![10.0, -10.0, 10.0, -10.0])); - tx.add(SamplesBuffer::new(1, 48000, vec![5.0, 5.0, 5.0, 5.0])); + tx.add(SamplesBuffer::new( + nz!(1), + nz!(48000), + vec![10.0, -10.0, 10.0, -10.0], + )); + tx.add(SamplesBuffer::new( + nz!(1), + nz!(48000), + vec![5.0, 5.0, 5.0, 5.0], + )); - assert_eq!(rx.channels(), 1); - assert_eq!(rx.sample_rate(), 48000); + assert_eq!(rx.channels(), nz!(1)); + assert_eq!(rx.sample_rate().get(), 48000); assert_eq!(rx.next(), Some(15.0)); assert_eq!(rx.next(), Some(-5.0)); assert_eq!(rx.next(), Some(15.0)); @@ -219,13 +228,21 @@ mod tests { #[test] fn channels_conv() { - let (tx, mut rx) = mixer::mixer(2, 48000); + let (tx, mut rx) = mixer::mixer(nz!(2), nz!(48000)); - tx.add(SamplesBuffer::new(1, 48000, vec![10.0, -10.0, 10.0, -10.0])); - tx.add(SamplesBuffer::new(1, 48000, vec![5.0, 5.0, 5.0, 5.0])); + tx.add(SamplesBuffer::new( + nz!(1), + nz!(48000), + vec![10.0, -10.0, 10.0, -10.0], + )); + tx.add(SamplesBuffer::new( + nz!(1), + nz!(48000), + vec![5.0, 5.0, 5.0, 5.0], + )); - assert_eq!(rx.channels(), 2); - assert_eq!(rx.sample_rate(), 48000); + assert_eq!(rx.channels(), nz!(2)); + assert_eq!(rx.sample_rate().get(), 48000); assert_eq!(rx.next(), Some(15.0)); assert_eq!(rx.next(), Some(15.0)); assert_eq!(rx.next(), Some(-5.0)); @@ -239,13 +256,21 @@ mod tests { #[test] fn rate_conv() { - let (tx, mut rx) = mixer::mixer(1, 96000); + let (tx, mut rx) = mixer::mixer(nz!(1), nz!(96000)); - tx.add(SamplesBuffer::new(1, 48000, vec![10.0, -10.0, 10.0, -10.0])); - tx.add(SamplesBuffer::new(1, 48000, vec![5.0, 5.0, 5.0, 5.0])); + tx.add(SamplesBuffer::new( + nz!(1), + nz!(48000), + vec![10.0, -10.0, 10.0, -10.0], + )); + tx.add(SamplesBuffer::new( + nz!(1), + nz!(48000), + vec![5.0, 5.0, 5.0, 5.0], + )); - assert_eq!(rx.channels(), 1); - assert_eq!(rx.sample_rate(), 96000); + assert_eq!(rx.channels(), nz!(1)); + assert_eq!(rx.sample_rate().get(), 96000); assert_eq!(rx.next(), Some(15.0)); assert_eq!(rx.next(), Some(5.0)); assert_eq!(rx.next(), Some(-5.0)); @@ -258,16 +283,20 @@ mod tests { #[test] fn start_afterwards() { - let (tx, mut rx) = mixer::mixer(1, 48000); + let (tx, mut rx) = mixer::mixer(nz!(1), nz!(48000)); - tx.add(SamplesBuffer::new(1, 48000, vec![10.0, -10.0, 10.0, -10.0])); + tx.add(SamplesBuffer::new( + nz!(1), + nz!(48000), + vec![10.0, -10.0, 10.0, -10.0], + )); assert_eq!(rx.next(), Some(10.0)); assert_eq!(rx.next(), Some(-10.0)); tx.add(SamplesBuffer::new( - 1, - 48000, + nz!(1), + nz!(48000), vec![5.0, 5.0, 6.0, 6.0, 7.0, 7.0, 7.0], )); @@ -277,7 +306,7 @@ mod tests { assert_eq!(rx.next(), Some(6.0)); assert_eq!(rx.next(), Some(6.0)); - tx.add(SamplesBuffer::new(1, 48000, vec![2.0])); + tx.add(SamplesBuffer::new(nz!(1), nz!(48000), vec![2.0])); assert_eq!(rx.next(), Some(9.0)); assert_eq!(rx.next(), Some(7.0)); diff --git a/src/queue.rs b/src/queue.rs index ccae005f..035a6759 100644 --- a/src/queue.rs +++ b/src/queue.rs @@ -4,6 +4,7 @@ use std::sync::atomic::{AtomicBool, Ordering}; use std::sync::{Arc, Mutex}; use std::time::Duration; +use crate::math::nz; use crate::source::{Empty, SeekError, Source, Zero}; use crate::Sample; @@ -220,7 +221,7 @@ impl SourcesQueueOutput { let mut next = self.input.next_sounds.lock().unwrap(); if next.len() == 0 { - let silence = Box::new(Zero::new_samples(1, 44100, THRESHOLD)) as Box<_>; + let silence = Box::new(Zero::new_samples(nz!(1), nz!(44100), THRESHOLD)) as Box<_>; if self.input.keep_alive_if_empty.load(Ordering::Acquire) { // Play a short silence in order to avoid spinlocking. (silence, None) @@ -241,6 +242,7 @@ impl SourcesQueueOutput { #[cfg(test)] mod tests { use crate::buffer::SamplesBuffer; + use crate::math::nz; use crate::queue; use crate::source::Source; @@ -249,17 +251,25 @@ mod tests { fn basic() { let (tx, mut rx) = queue::queue(false); - tx.append(SamplesBuffer::new(1, 48000, vec![10.0, -10.0, 10.0, -10.0])); - tx.append(SamplesBuffer::new(2, 96000, vec![5.0, 5.0, 5.0, 5.0])); - - assert_eq!(rx.channels(), 1); - assert_eq!(rx.sample_rate(), 48000); + tx.append(SamplesBuffer::new( + nz!(1), + nz!(48000), + vec![10.0, -10.0, 10.0, -10.0], + )); + tx.append(SamplesBuffer::new( + nz!(2), + nz!(96000), + vec![5.0, 5.0, 5.0, 5.0], + )); + + assert_eq!(rx.channels(), nz!(1)); + assert_eq!(rx.sample_rate().get(), 48000); assert_eq!(rx.next(), Some(10.0)); assert_eq!(rx.next(), Some(-10.0)); assert_eq!(rx.next(), Some(10.0)); assert_eq!(rx.next(), Some(-10.0)); - assert_eq!(rx.channels(), 2); - assert_eq!(rx.sample_rate(), 96000); + assert_eq!(rx.channels(), nz!(2)); + assert_eq!(rx.sample_rate().get(), 96000); assert_eq!(rx.next(), Some(5.0)); assert_eq!(rx.next(), Some(5.0)); assert_eq!(rx.next(), Some(5.0)); @@ -276,7 +286,11 @@ mod tests { #[test] fn keep_alive() { let (tx, mut rx) = queue::queue(true); - tx.append(SamplesBuffer::new(1, 48000, vec![10.0, -10.0, 10.0, -10.0])); + tx.append(SamplesBuffer::new( + nz!(1), + nz!(48000), + vec![10.0, -10.0, 10.0, -10.0], + )); assert_eq!(rx.next(), Some(10.0)); assert_eq!(rx.next(), Some(-10.0)); @@ -297,7 +311,11 @@ mod tests { assert_eq!(rx.next(), Some(0.0)); } - tx.append(SamplesBuffer::new(1, 48000, vec![10.0, -10.0, 10.0, -10.0])); + tx.append(SamplesBuffer::new( + nz!(1), + nz!(48000), + vec![10.0, -10.0, 10.0, -10.0], + )); assert_eq!(rx.next(), Some(10.0)); assert_eq!(rx.next(), Some(-10.0)); assert_eq!(rx.next(), Some(10.0)); diff --git a/src/sink.rs b/src/sink.rs index 0ba019af..dd04037b 100644 --- a/src/sink.rs +++ b/src/sink.rs @@ -366,6 +366,7 @@ mod tests { use std::sync::atomic::Ordering; use crate::buffer::SamplesBuffer; + use crate::math::nz; use crate::{Sink, Source}; #[test] @@ -382,8 +383,8 @@ mod tests { let v = vec![10.0, -10.0, 20.0, -20.0, 30.0, -30.0]; // Low rate to ensure immediate control. - sink.append(SamplesBuffer::new(1, 1, v.clone())); - let mut reference_src = SamplesBuffer::new(1, 1, v); + sink.append(SamplesBuffer::new(nz!(1), nz!(1), v.clone())); + let mut reference_src = SamplesBuffer::new(nz!(1), nz!(1), v); assert_eq!(source.next(), reference_src.next()); assert_eq!(source.next(), reference_src.next()); @@ -410,8 +411,8 @@ mod tests { let v = vec![10.0, -10.0, 20.0, -20.0, 30.0, -30.0]; - sink.append(SamplesBuffer::new(1, 1, v.clone())); - let mut src = SamplesBuffer::new(1, 1, v.clone()); + sink.append(SamplesBuffer::new(nz!(1), nz!(1), v.clone())); + let mut src = SamplesBuffer::new(nz!(1), nz!(1), v.clone()); assert_eq!(queue_rx.next(), src.next()); assert_eq!(queue_rx.next(), src.next()); @@ -421,8 +422,8 @@ mod tests { assert!(sink.controls.stopped.load(Ordering::SeqCst)); assert_eq!(queue_rx.next(), Some(0.0)); - src = SamplesBuffer::new(1, 1, v.clone()); - sink.append(SamplesBuffer::new(1, 1, v)); + src = SamplesBuffer::new(nz!(1), nz!(1), v.clone()); + sink.append(SamplesBuffer::new(nz!(1), nz!(1), v)); assert!(!sink.controls.stopped.load(Ordering::SeqCst)); // Flush silence @@ -439,8 +440,8 @@ mod tests { let v = vec![10.0, -10.0, 20.0, -20.0, 30.0, -30.0]; // High rate to avoid immediate control. - sink.append(SamplesBuffer::new(2, 44100, v.clone())); - let src = SamplesBuffer::new(2, 44100, v.clone()); + sink.append(SamplesBuffer::new(nz!(2), nz!(44100), v.clone())); + let src = SamplesBuffer::new(nz!(2), nz!(44100), v.clone()); let mut src = src.amplify(0.5); sink.set_volume(0.5); diff --git a/src/source/agc.rs b/src/source/agc.rs index 39a7019b..a7095a15 100644 --- a/src/source/agc.rs +++ b/src/source/agc.rs @@ -143,7 +143,7 @@ pub(crate) fn automatic_gain_control( where I: Source, { - let sample_rate = input.sample_rate(); + let sample_rate = input.sample_rate().get(); let attack_coeff = (-1.0 / (attack_time * sample_rate as f32)).exp(); let release_coeff = (-1.0 / (release_time * sample_rate as f32)).exp(); diff --git a/src/source/blt.rs b/src/source/blt.rs index 2aad5b76..151b0bee 100644 --- a/src/source/blt.rs +++ b/src/source/blt.rs @@ -119,7 +119,7 @@ where let last_in_span = self.input.current_span_len() == Some(1); if self.applier.is_none() { - self.applier = Some(self.formula.to_applier(self.input.sample_rate())); + self.applier = Some(self.formula.to_applier(self.input.sample_rate().get())); } let sample = self.input.next()?; diff --git a/src/source/buffered.rs b/src/source/buffered.rs index 402e63bb..974389dc 100644 --- a/src/source/buffered.rs +++ b/src/source/buffered.rs @@ -5,6 +5,7 @@ use std::time::Duration; use super::SeekError; use crate::common::{ChannelCount, SampleRate}; +use crate::math::nz; use crate::Source; /// Internal function that builds a `Buffered` object. @@ -214,7 +215,7 @@ where fn channels(&self) -> ChannelCount { match *self.current_span { Span::Data(SpanData { channels, .. }) => channels, - Span::End => 1, + Span::End => nz!(1), Span::Input(_) => unreachable!(), } } @@ -223,7 +224,7 @@ where fn sample_rate(&self) -> SampleRate { match *self.current_span { Span::Data(SpanData { rate, .. }) => rate, - Span::End => 44100, + Span::End => nz!(44100), Span::Input(_) => unreachable!(), } } diff --git a/src/source/channel_volume.rs b/src/source/channel_volume.rs index 89b2280d..6dbe7cce 100644 --- a/src/source/channel_volume.rs +++ b/src/source/channel_volume.rs @@ -24,10 +24,10 @@ where /// Wrap the input source and make it mono. Play that mono sound to each /// channel at the volume set by the user. The volume can be changed using /// [`ChannelVolume::set_volume`]. - pub fn new(input: I, channel_volumes: Vec) -> ChannelVolume - where - I: Source, - { + /// + /// # Panics if channel_volumes is empty + pub fn new(input: I, channel_volumes: Vec) -> ChannelVolume { + assert!(!channel_volumes.is_empty()); let channel_count = channel_volumes.len(); // See next() implementation. ChannelVolume { input, @@ -75,12 +75,12 @@ where self.current_channel = 0; self.current_sample = None; let num_channels = self.input.channels(); - for _ in 0..num_channels { + for _ in 0..num_channels.get() { if let Some(s) = self.input.next() { self.current_sample = Some(self.current_sample.unwrap_or(0.0) + s); } } - self.current_sample.map(|s| s / num_channels as f32); + self.current_sample.map(|s| s / num_channels.get() as f32); } let result = self .current_sample @@ -108,7 +108,8 @@ where #[inline] fn channels(&self) -> ChannelCount { - self.channel_volumes.len() as ChannelCount + ChannelCount::new(self.channel_volumes.len() as u16) + .expect("checked to be non-empty in new implementation") } #[inline] diff --git a/src/source/chirp.rs b/src/source/chirp.rs index 4b99eb85..a17b3ad2 100644 --- a/src/source/chirp.rs +++ b/src/source/chirp.rs @@ -1,6 +1,7 @@ //! Chirp/sweep source. use crate::common::{ChannelCount, SampleRate}; +use crate::math::nz; use crate::Source; use std::{f32::consts::TAU, time::Duration}; @@ -37,7 +38,7 @@ impl Chirp { sample_rate, start_frequency, end_frequency, - total_samples: (duration.as_secs_f64() * (sample_rate as f64)) as u64, + total_samples: (duration.as_secs_f64() * (sample_rate.get() as f64)) as u64, elapsed_samples: 0, } } @@ -51,7 +52,7 @@ impl Iterator for Chirp { let ratio = self.elapsed_samples as f32 / self.total_samples as f32; self.elapsed_samples += 1; let freq = self.start_frequency * (1.0 - ratio) + self.end_frequency * ratio; - let t = (i as f32 / self.sample_rate() as f32) * TAU * freq; + let t = (i as f32 / self.sample_rate().get() as f32) * TAU * freq; Some(t.sin()) } } @@ -62,7 +63,7 @@ impl Source for Chirp { } fn channels(&self) -> ChannelCount { - 1 + nz!(1) } fn sample_rate(&self) -> SampleRate { @@ -70,7 +71,7 @@ impl Source for Chirp { } fn total_duration(&self) -> Option { - let secs: f64 = self.total_samples as f64 / self.sample_rate as f64; + let secs: f64 = self.total_samples as f64 / self.sample_rate.get() as f64; Some(Duration::new(1, 0).mul_f64(secs)) } } diff --git a/src/source/crossfade.rs b/src/source/crossfade.rs index 0109c5e4..2e6efc8b 100644 --- a/src/source/crossfade.rs +++ b/src/source/crossfade.rs @@ -33,11 +33,12 @@ pub type Crossfade = Mix, FadeIn>>; mod tests { use super::*; use crate::buffer::SamplesBuffer; + use crate::math::nz; use crate::source::Zero; fn dummy_source(length: u8) -> SamplesBuffer { let data: Vec = (1..=length).map(f32::from).collect(); - SamplesBuffer::new(1, 1, data) + SamplesBuffer::new(nz!(1), nz!(1), data) } #[test] @@ -60,7 +61,7 @@ mod tests { #[test] fn test_crossfade() { let source1 = dummy_source(10); - let source2 = Zero::new(1, 1); + let source2 = Zero::new(nz!(1), nz!(1)); let mixed = crossfade( source1, source2, diff --git a/src/source/delay.rs b/src/source/delay.rs index f2476ddd..1f2d3fb4 100644 --- a/src/source/delay.rs +++ b/src/source/delay.rs @@ -10,7 +10,7 @@ fn remaining_samples( channels: ChannelCount, ) -> usize { let ns = until_playback.as_secs() * 1_000_000_000 + until_playback.subsec_nanos() as u64; - let samples = ns * channels as u64 * sample_rate as u64 / 1_000_000_000; + let samples = ns * channels.get() as u64 * sample_rate.get() as u64 / 1_000_000_000; samples as usize } diff --git a/src/source/empty.rs b/src/source/empty.rs index 8c8b4853..bfdf422e 100644 --- a/src/source/empty.rs +++ b/src/source/empty.rs @@ -2,6 +2,7 @@ use std::time::Duration; use super::SeekError; use crate::common::{ChannelCount, SampleRate}; +use crate::math::nz; use crate::{Sample, Source}; /// An empty source. @@ -41,12 +42,12 @@ impl Source for Empty { #[inline] fn channels(&self) -> ChannelCount { - 1 + nz!(1) } #[inline] fn sample_rate(&self) -> SampleRate { - 48000 + nz!(48000) } #[inline] diff --git a/src/source/empty_callback.rs b/src/source/empty_callback.rs index 08d8b0fb..4ae62437 100644 --- a/src/source/empty_callback.rs +++ b/src/source/empty_callback.rs @@ -2,6 +2,7 @@ use std::time::Duration; use super::SeekError; use crate::common::{ChannelCount, SampleRate}; +use crate::math::nz; use crate::{Sample, Source}; /// An empty source that executes a callback function @@ -38,12 +39,12 @@ impl Source for EmptyCallback { #[inline] fn channels(&self) -> ChannelCount { - 1 + nz!(1) } #[inline] fn sample_rate(&self) -> SampleRate { - 48000 + nz!(48000) } #[inline] diff --git a/src/source/from_iter.rs b/src/source/from_iter.rs index 48251763..59845758 100644 --- a/src/source/from_iter.rs +++ b/src/source/from_iter.rs @@ -2,6 +2,7 @@ use std::time::Duration; use super::SeekError; use crate::common::{ChannelCount, SampleRate}; +use crate::math::nz; use crate::Source; /// Builds a source that chains sources provided by an iterator. @@ -117,7 +118,7 @@ where src.channels() } else { // Dummy value that only happens if the iterator was empty. - 2 + nz!(2) } } @@ -127,7 +128,7 @@ where src.sample_rate() } else { // Dummy value that only happens if the iterator was empty. - 44100 + nz!(44100) } } @@ -149,28 +150,29 @@ where #[cfg(test)] mod tests { use crate::buffer::SamplesBuffer; + use crate::math::nz; use crate::source::{from_iter, Source}; #[test] fn basic() { let mut rx = from_iter((0..2).map(|n| { if n == 0 { - SamplesBuffer::new(1, 48000, vec![10.0, -10.0, 10.0, -10.0]) + SamplesBuffer::new(nz!(1), nz!(48000), vec![10.0, -10.0, 10.0, -10.0]) } else if n == 1 { - SamplesBuffer::new(2, 96000, vec![5.0, 5.0, 5.0, 5.0]) + SamplesBuffer::new(nz!(2), nz!(96000), vec![5.0, 5.0, 5.0, 5.0]) } else { unreachable!() } })); - assert_eq!(rx.channels(), 1); - assert_eq!(rx.sample_rate(), 48000); + assert_eq!(rx.channels(), nz!(1)); + assert_eq!(rx.sample_rate().get(), 48000); assert_eq!(rx.next(), Some(10.0)); assert_eq!(rx.next(), Some(-10.0)); assert_eq!(rx.next(), Some(10.0)); assert_eq!(rx.next(), Some(-10.0)); /*assert_eq!(rx.channels(), 2); - assert_eq!(rx.sample_rate(), 96000);*/ + assert_eq!(rx.sample_rate().get(), 96000);*/ // FIXME: not working assert_eq!(rx.next(), Some(5.0)); assert_eq!(rx.next(), Some(5.0)); diff --git a/src/source/linear_ramp.rs b/src/source/linear_ramp.rs index af830a5e..eb0aad8a 100644 --- a/src/source/linear_ramp.rs +++ b/src/source/linear_ramp.rs @@ -45,7 +45,7 @@ impl LinearGainRamp where I: Source, { - /// Returns a reference to the innner source. + /// Returns a reference to the inner source. #[inline] pub fn inner(&self) -> &I { &self.input @@ -88,8 +88,8 @@ where factor = self.start_gain * (1.0f32 - p) + self.end_gain * p; } - if self.sample_idx % (self.channels() as u64) == 0 { - self.elapsed_ns += 1000000000.0 / (self.input.sample_rate() as f32); + if self.sample_idx % (self.channels().get() as u64) == 0 { + self.elapsed_ns += 1000000000.0 / (self.input.sample_rate().get() as f32); } self.input.next().map(|value| value * factor) @@ -140,13 +140,14 @@ mod tests { use super::*; use crate::buffer::SamplesBuffer; + use crate::math::nz; use crate::Sample; /// Create a SamplesBuffer of identical samples with value `value`. /// Returned buffer is one channel and has a sample rate of 1 hz. fn const_source(length: u8, value: Sample) -> SamplesBuffer { let data: Vec = (1..=length).map(|_| value).collect(); - SamplesBuffer::new(1, 1, data) + SamplesBuffer::new(nz!(1), nz!(1), data) } /// Create a SamplesBuffer of repeating sample values from `values`. @@ -156,7 +157,7 @@ mod tests { .map(|(i, _)| values[i % values.len()]) .collect(); - SamplesBuffer::new(1, 1, data) + SamplesBuffer::new(nz!(1), nz!(1), data) } #[test] diff --git a/src/source/mod.rs b/src/source/mod.rs index aefbd1b2..f0d67d5b 100644 --- a/src/source/mod.rs +++ b/src/source/mod.rs @@ -160,6 +160,7 @@ pub trait Source: Iterator { fn current_span_len(&self) -> Option; /// Returns the number of channels. Channels are always interleaved. + /// Should never be Zero fn channels(&self) -> ChannelCount; /// Returns the rate at which the source should be played. In number of samples per second. diff --git a/src/source/noise.rs b/src/source/noise.rs index 03a2e2dd..85323946 100644 --- a/src/source/noise.rs +++ b/src/source/noise.rs @@ -145,7 +145,7 @@ impl Source for PinkNoise { } fn sample_rate(&self) -> SampleRate { - self.white_noise.sample_rate() + self.white_noise.sample_rate().get() } fn total_duration(&self) -> Option { diff --git a/src/source/pausable.rs b/src/source/pausable.rs index f9fddadc..0bcca512 100644 --- a/src/source/pausable.rs +++ b/src/source/pausable.rs @@ -31,7 +31,7 @@ where pub struct Pausable { input: I, paused_channels: Option, - remaining_paused_samples: ChannelCount, + remaining_paused_samples: u16, } impl Pausable @@ -83,7 +83,7 @@ where } if let Some(paused_channels) = self.paused_channels { - self.remaining_paused_samples = paused_channels - 1; + self.remaining_paused_samples = paused_channels.get() - 1; return Some(0.0); } diff --git a/src/source/periodic.rs b/src/source/periodic.rs index 31260fd1..fb55e673 100644 --- a/src/source/periodic.rs +++ b/src/source/periodic.rs @@ -12,7 +12,8 @@ where // TODO: handle the fact that the samples rate can change // TODO: generally, just wrong let update_ms = period.as_secs() as u32 * 1_000 + period.subsec_millis(); - let update_frequency = (update_ms * source.sample_rate()) / 1000 * source.channels() as u32; + let update_frequency = + (update_ms * source.sample_rate().get()) / 1000 * source.channels().get() as u32; PeriodicAccess { input: source, @@ -131,12 +132,13 @@ mod tests { use std::time::Duration; use crate::buffer::SamplesBuffer; + use crate::math::nz; use crate::source::Source; #[test] fn stereo_access() { // Stereo, 1Hz audio buffer - let inner = SamplesBuffer::new(2, 1, vec![10.0, -10.0, 10.0, -10.0, 20.0, -20.0]); + let inner = SamplesBuffer::new(nz!(2), nz!(1), vec![10.0, -10.0, 10.0, -10.0, 20.0, -20.0]); let cnt = RefCell::new(0); @@ -164,7 +166,7 @@ mod tests { #[test] fn fast_access_overflow() { // 1hz is lower than 0.5 samples per 5ms - let inner = SamplesBuffer::new(1, 1, vec![10.0, -10.0, 10.0, -10.0, 20.0, -20.0]); + let inner = SamplesBuffer::new(nz!(1), nz!(1), vec![10.0, -10.0, 10.0, -10.0, 20.0, -20.0]); let mut source = inner.periodic_access(Duration::from_millis(5), |_src| {}); source.next(); diff --git a/src/source/position.rs b/src/source/position.rs index aa6b63f8..decb0810 100644 --- a/src/source/position.rs +++ b/src/source/position.rs @@ -2,6 +2,7 @@ use std::time::Duration; use super::SeekError; use crate::common::{ChannelCount, SampleRate}; +use crate::math::nz; use crate::Source; /// Internal function that builds a `TrackPosition` object. See trait docs for @@ -11,8 +12,8 @@ pub fn track_position(source: I) -> TrackPosition { input: source, samples_counted: 0, offset_duration: 0.0, - current_span_sample_rate: 0, - current_span_channels: 0, + current_span_sample_rate: nz!(1), + current_span_channels: nz!(1), current_span_len: None, } } @@ -65,8 +66,8 @@ where #[inline] pub fn get_pos(&self) -> Duration { let seconds = self.samples_counted as f64 - / self.input.sample_rate() as f64 - / self.input.channels() as f64 + / self.input.sample_rate().get() as f64 + / self.input.channels().get() as f64 + self.offset_duration; Duration::from_secs_f64(seconds) } @@ -100,8 +101,8 @@ where // offset_duration and start collecting samples again. if Some(self.samples_counted) == self.current_span_len() { self.offset_duration += self.samples_counted as f64 - / self.current_span_sample_rate as f64 - / self.current_span_channels as f64; + / self.current_span_sample_rate.get() as f64 + / self.current_span_channels.get() as f64; // Reset. self.samples_counted = 0; @@ -160,11 +161,12 @@ mod tests { use std::time::Duration; use crate::buffer::SamplesBuffer; + use crate::math::nz; use crate::source::Source; #[test] fn test_position() { - let inner = SamplesBuffer::new(1, 1, vec![10.0, -10.0, 10.0, -10.0, 20.0, -20.0]); + let inner = SamplesBuffer::new(nz!(1), nz!(1), vec![10.0, -10.0, 10.0, -10.0, 20.0, -20.0]); let mut source = inner.track_position(); assert_eq!(source.get_pos().as_secs_f32(), 0.0); @@ -180,7 +182,7 @@ mod tests { #[test] fn test_position_in_presence_of_speedup() { - let inner = SamplesBuffer::new(1, 1, vec![10.0, -10.0, 10.0, -10.0, 20.0, -20.0]); + let inner = SamplesBuffer::new(nz!(1), nz!(1), vec![10.0, -10.0, 10.0, -10.0, 20.0, -20.0]); let mut source = inner.speed(2.0).track_position(); assert_eq!(source.get_pos().as_secs_f32(), 0.0); diff --git a/src/source/sawtooth.rs b/src/source/sawtooth.rs index c6ae01f3..a0e812e4 100644 --- a/src/source/sawtooth.rs +++ b/src/source/sawtooth.rs @@ -1,4 +1,5 @@ use crate::common::{ChannelCount, SampleRate}; +use crate::math::nz; use crate::source::{Function, SignalGenerator}; use crate::Source; use std::time::Duration; @@ -17,7 +18,7 @@ pub struct SawtoothWave { } impl SawtoothWave { - const SAMPLE_RATE: SampleRate = 48000; + const SAMPLE_RATE: SampleRate = nz!(48000); /// The frequency of the sine. #[inline] @@ -45,7 +46,7 @@ impl Source for SawtoothWave { #[inline] fn channels(&self) -> ChannelCount { - 1 + nz!(1) } #[inline] diff --git a/src/source/signal_generator.rs b/src/source/signal_generator.rs index e0c33bb8..3b914ff7 100644 --- a/src/source/signal_generator.rs +++ b/src/source/signal_generator.rs @@ -8,11 +8,13 @@ //! //! ``` //! use rodio::source::{SignalGenerator,Function}; +//! use core::num::NonZero; //! -//! let tone = SignalGenerator::new(48000, 440.0, Function::Sine); +//! let tone = SignalGenerator::new(NonZero::new(48000).unwrap(), 440.0, Function::Sine); //! ``` use super::SeekError; use crate::common::{ChannelCount, SampleRate}; +use crate::math::nz; use crate::Source; use std::f32::consts::TAU; use std::time::Duration; @@ -110,7 +112,7 @@ impl SignalGenerator { generator_function: GeneratorFunction, ) -> Self { assert!(frequency != 0.0, "frequency must be greater than zero"); - let period = sample_rate as f32 / frequency; + let period = sample_rate.get() as f32 / frequency; let phase_step = 1.0f32 / period; SignalGenerator { @@ -143,7 +145,7 @@ impl Source for SignalGenerator { #[inline] fn channels(&self) -> ChannelCount { - 1 + nz!(1) } #[inline] @@ -158,7 +160,7 @@ impl Source for SignalGenerator { #[inline] fn try_seek(&mut self, duration: Duration) -> Result<(), SeekError> { - let seek = duration.as_secs_f32() * (self.sample_rate as f32) / self.period; + let seek = duration.as_secs_f32() * (self.sample_rate.get() as f32) / self.period; self.phase = seek.rem_euclid(1.0f32); Ok(()) } @@ -166,12 +168,13 @@ impl Source for SignalGenerator { #[cfg(test)] mod tests { + use crate::math::nz; use crate::source::{Function, SignalGenerator}; use approx::assert_abs_diff_eq; #[test] fn square() { - let mut wf = SignalGenerator::new(2000, 500.0f32, Function::Square); + let mut wf = SignalGenerator::new(nz!(2000), 500.0f32, Function::Square); assert_eq!(wf.next(), Some(1.0f32)); assert_eq!(wf.next(), Some(1.0f32)); assert_eq!(wf.next(), Some(-1.0f32)); @@ -184,7 +187,7 @@ mod tests { #[test] fn triangle() { - let mut wf = SignalGenerator::new(8000, 1000.0f32, Function::Triangle); + let mut wf = SignalGenerator::new(nz!(8000), 1000.0f32, Function::Triangle); assert_eq!(wf.next(), Some(-1.0f32)); assert_eq!(wf.next(), Some(-0.5f32)); assert_eq!(wf.next(), Some(0.0f32)); @@ -205,7 +208,7 @@ mod tests { #[test] fn saw() { - let mut wf = SignalGenerator::new(200, 50.0f32, Function::Sawtooth); + let mut wf = SignalGenerator::new(nz!(200), 50.0f32, Function::Sawtooth); assert_eq!(wf.next(), Some(0.0f32)); assert_eq!(wf.next(), Some(0.5f32)); assert_eq!(wf.next(), Some(-1.0f32)); @@ -217,7 +220,7 @@ mod tests { #[test] fn sine() { - let mut wf = SignalGenerator::new(1000, 100f32, Function::Sine); + let mut wf = SignalGenerator::new(nz!(1000), 100f32, Function::Sine); assert_abs_diff_eq!(wf.next().unwrap(), 0.0f32); assert_abs_diff_eq!(wf.next().unwrap(), 0.58778525f32); diff --git a/src/source/sine.rs b/src/source/sine.rs index e3814435..a85fcb63 100644 --- a/src/source/sine.rs +++ b/src/source/sine.rs @@ -1,4 +1,5 @@ use crate::common::{ChannelCount, SampleRate}; +use crate::math::nz; use crate::source::{Function, SignalGenerator}; use crate::Source; use std::time::Duration; @@ -17,7 +18,7 @@ pub struct SineWave { } impl SineWave { - const SAMPLE_RATE: u32 = 48000; + const SAMPLE_RATE: SampleRate = nz!(48000); /// The frequency of the sine. #[inline] @@ -45,7 +46,7 @@ impl Source for SineWave { #[inline] fn channels(&self) -> ChannelCount { - 1 + nz!(1) } #[inline] diff --git a/src/source/skip.rs b/src/source/skip.rs index ee3c7365..36f8d210 100644 --- a/src/source/skip.rs +++ b/src/source/skip.rs @@ -40,7 +40,7 @@ where } let ns_per_sample: u128 = - NS_PER_SECOND / input.sample_rate() as u128 / input.channels() as u128; + NS_PER_SECOND / input.sample_rate().get() as u128 / input.channels().get() as u128; // Check if we need to skip only part of the current span. if span_len as u128 * ns_per_sample > duration.as_nanos() { @@ -61,8 +61,8 @@ where I: Source, { let samples_per_channel: u128 = - duration.as_nanos() * input.sample_rate() as u128 / NS_PER_SECOND; - let samples_to_skip: u128 = samples_per_channel * input.channels() as u128; + duration.as_nanos() * input.sample_rate().get() as u128 / NS_PER_SECOND; + let samples_to_skip: u128 = samples_per_channel * input.channels().get() as u128; skip_samples(input, samples_to_skip as usize); } @@ -165,6 +165,7 @@ mod tests { use crate::buffer::SamplesBuffer; use crate::common::{ChannelCount, SampleRate}; + use crate::math::nz; use crate::source::Source; fn test_skip_duration_samples_left( @@ -173,13 +174,14 @@ mod tests { seconds: u32, seconds_to_skip: u32, ) { - let buf_len = (sample_rate * channels as u32 * seconds) as usize; + let buf_len = (sample_rate.get() * channels.get() as u32 * seconds) as usize; assert!(buf_len < 10 * 1024 * 1024); let data: Vec = vec![0f32; buf_len]; let test_buffer = SamplesBuffer::new(channels, sample_rate, data); let seconds_left = seconds.saturating_sub(seconds_to_skip); - let samples_left_expected = (sample_rate * channels as u32 * seconds_left) as usize; + let samples_left_expected = + (sample_rate.get() * channels.get() as u32 * seconds_left) as usize; let samples_left = test_buffer .skip_duration(Duration::from_secs(seconds_to_skip as u64)) .count(); @@ -190,7 +192,7 @@ mod tests { macro_rules! skip_duration_test_block { ($(channels: $ch:expr, sample rate: $sr:expr, seconds: $sec:expr, seconds to skip: $sec_to_skip:expr;)+) => { $( - test_skip_duration_samples_left($ch, $sr, $sec, $sec_to_skip); + test_skip_duration_samples_left(nz!($ch), nz!($sr), $sec, $sec_to_skip); )+ } } diff --git a/src/source/speed.rs b/src/source/speed.rs index 25e25efc..0f3581f0 100644 --- a/src/source/speed.rs +++ b/src/source/speed.rs @@ -127,7 +127,8 @@ where #[inline] fn sample_rate(&self) -> SampleRate { - (self.input.sample_rate() as f32 * self.factor) as u32 + SampleRate::new((self.input.sample_rate().get() as f32 * self.factor).max(1.0) as u32) + .expect("minimum is 1.0 > 0") } #[inline] diff --git a/src/source/square.rs b/src/source/square.rs index ac6bd678..4beaf33d 100644 --- a/src/source/square.rs +++ b/src/source/square.rs @@ -1,4 +1,5 @@ use crate::common::{ChannelCount, SampleRate}; +use crate::math::nz; use crate::source::{Function, SignalGenerator}; use crate::Source; use std::time::Duration; @@ -17,7 +18,7 @@ pub struct SquareWave { } impl SquareWave { - const SAMPLE_RATE: u32 = 48000; + const SAMPLE_RATE: SampleRate = nz!(48000); /// The frequency of the sine. #[inline] @@ -45,7 +46,7 @@ impl Source for SquareWave { #[inline] fn channels(&self) -> ChannelCount { - 1 + nz!(1) } #[inline] diff --git a/src/source/take.rs b/src/source/take.rs index 9abcf7f6..60ee9c6d 100644 --- a/src/source/take.rs +++ b/src/source/take.rs @@ -58,7 +58,7 @@ where /// Returns the duration elapsed for each sample extracted. #[inline] fn get_duration_per_sample(input: &I) -> Duration { - let ns = NANOS_PER_SEC / (input.sample_rate() as u64 * input.channels() as u64); + let ns = NANOS_PER_SEC / (input.sample_rate().get() as u64 * input.channels().get() as u64); // \|/ the maximum value of `ns` is one billion, so this can't fail Duration::new(0, ns as u32) } diff --git a/src/source/triangle.rs b/src/source/triangle.rs index eb73801d..a35df3af 100644 --- a/src/source/triangle.rs +++ b/src/source/triangle.rs @@ -1,4 +1,5 @@ use crate::common::{ChannelCount, SampleRate}; +use crate::math::nz; use crate::source::{Function, SignalGenerator}; use crate::Source; use std::time::Duration; @@ -17,7 +18,7 @@ pub struct TriangleWave { } impl TriangleWave { - const SAMPLE_RATE: SampleRate = 48000; + const SAMPLE_RATE: SampleRate = nz!(48000); /// The frequency of the sine. #[inline] @@ -45,7 +46,7 @@ impl Source for TriangleWave { #[inline] fn channels(&self) -> ChannelCount { - 1 + nz!(1) } #[inline] diff --git a/src/static_buffer.rs b/src/static_buffer.rs index 014d824c..2a3c54a2 100644 --- a/src/static_buffer.rs +++ b/src/static_buffer.rs @@ -6,7 +6,8 @@ //! //! ``` //! use rodio::static_buffer::StaticSamplesBuffer; -//! let _ = StaticSamplesBuffer::new(1, 44100, &[1.0, 2.0, 3.0, 4.0, 5.0, 6.0]); +//! use core::num::NonZero; +//! let _ = StaticSamplesBuffer::new(NonZero::new(1).unwrap(), NonZero::new(44100).unwrap(), &[1.0, 2.0, 3.0, 4.0, 5.0, 6.0]); //! ``` //! @@ -41,12 +42,9 @@ impl StaticSamplesBuffer { sample_rate: SampleRate, data: &'static [Sample], ) -> StaticSamplesBuffer { - assert!(channels != 0); - assert!(sample_rate != 0); - let duration_ns = 1_000_000_000u64.checked_mul(data.len() as u64).unwrap() - / sample_rate as u64 - / channels as u64; + / sample_rate.get() as u64 + / channels.get() as u64; let duration = Duration::new( duration_ns / 1_000_000_000, (duration_ns % 1_000_000_000) as u32, @@ -106,29 +104,18 @@ impl Iterator for StaticSamplesBuffer { #[cfg(test)] mod tests { + use crate::math::nz; use crate::source::Source; use crate::static_buffer::StaticSamplesBuffer; #[test] fn basic() { - let _ = StaticSamplesBuffer::new(1, 44100, &[0.0, 0.0, 0.0, 0.0, 0.0, 0.0]); - } - - #[test] - #[should_panic] - fn panic_if_zero_channels() { - StaticSamplesBuffer::new(0, 44100, &[0.0, 0.0, 0.0, 0.0, 0.0, 0.0]); - } - - #[test] - #[should_panic] - fn panic_if_zero_sample_rate() { - StaticSamplesBuffer::new(1, 0, &[0.0, 0.0, 0.0, 0.0, 0.0, 0.0]); + let _ = StaticSamplesBuffer::new(nz!(1), nz!(44100), &[0.0, 0.0, 0.0, 0.0, 0.0, 0.0]); } #[test] fn duration_basic() { - let buf = StaticSamplesBuffer::new(2, 2, &[0.0, 0.0, 0.0, 0.0, 0.0, 0.0]); + let buf = StaticSamplesBuffer::new(nz!(2), nz!(2), &[0.0, 0.0, 0.0, 0.0, 0.0, 0.0]); let dur = buf.total_duration().unwrap(); assert_eq!(dur.as_secs(), 1); assert_eq!(dur.subsec_nanos(), 500_000_000); @@ -136,7 +123,7 @@ mod tests { #[test] fn iteration() { - let mut buf = StaticSamplesBuffer::new(1, 44100, &[1.0, 2.0, 3.0, 4.0, 5.0, 6.0]); + let mut buf = StaticSamplesBuffer::new(nz!(1), nz!(44100), &[1.0, 2.0, 3.0, 4.0, 5.0, 6.0]); assert_eq!(buf.next(), Some(1.0)); assert_eq!(buf.next(), Some(2.0)); assert_eq!(buf.next(), Some(3.0)); diff --git a/src/stream.rs b/src/stream.rs index 40eb8d8c..31c54791 100644 --- a/src/stream.rs +++ b/src/stream.rs @@ -1,5 +1,6 @@ use crate::common::{ChannelCount, SampleRate}; use crate::decoder; +use crate::math::nz; use crate::mixer::{mixer, Mixer, MixerSource}; use crate::sink::Sink; use cpal::traits::{DeviceTrait, HostTrait, StreamTrait}; @@ -9,7 +10,7 @@ use std::marker::Sync; use std::sync::Arc; use std::{error, fmt}; -const HZ_44100: SampleRate = 44_100; +const HZ_44100: SampleRate = nz!(44_100); /// `cpal::Stream` container. /// Use `mixer()` method to control output. @@ -37,7 +38,7 @@ struct OutputStreamConfig { impl Default for OutputStreamConfig { fn default() -> Self { Self { - channel_count: 2, + channel_count: nz!(2), sample_rate: HZ_44100, buffer_size: BufferSize::Default, sample_format: SampleFormat::F32, @@ -98,7 +99,6 @@ impl OutputStreamBuilder { /// Sets number of output stream's channels. pub fn with_channels(mut self, channel_count: ChannelCount) -> OutputStreamBuilder { - assert!(channel_count > 0); self.config.channel_count = channel_count; self } @@ -123,15 +123,17 @@ impl OutputStreamBuilder { self } - /// Set available parameters from a CPAL supported config. You can ge list of + /// Set available parameters from a CPAL supported config. You can get list of /// such configurations for an output device using [crate::stream::supported_output_configs()] pub fn with_supported_config( mut self, config: &cpal::SupportedStreamConfig, ) -> OutputStreamBuilder { self.config = OutputStreamConfig { - channel_count: config.channels() as ChannelCount, - sample_rate: config.sample_rate().0 as SampleRate, + channel_count: ChannelCount::new(config.channels()) + .expect("cpal should never return a zero channel output"), + sample_rate: SampleRate::new(config.sample_rate().0) + .expect("cpal should never return a zero sample rate output"), // In case of supported range limit buffer size to avoid unexpectedly long playback delays. buffer_size: clamp_supported_buffer_size(config.buffer_size(), 1024), sample_format: config.sample_format(), @@ -142,8 +144,10 @@ impl OutputStreamBuilder { /// Set all output stream parameters at once from CPAL stream config. pub fn with_config(mut self, config: &cpal::StreamConfig) -> OutputStreamBuilder { self.config = OutputStreamConfig { - channel_count: config.channels as ChannelCount, - sample_rate: config.sample_rate.0 as SampleRate, + channel_count: ChannelCount::new(config.channels) + .expect("cpal should never return a zero channel output"), + sample_rate: SampleRate::new(config.sample_rate.0) + .expect("cpal should never return a zero sample rate output"), buffer_size: config.buffer_size, ..self.config }; @@ -234,8 +238,8 @@ where impl From<&OutputStreamConfig> for StreamConfig { fn from(config: &OutputStreamConfig) -> Self { cpal::StreamConfig { - channels: config.channel_count as cpal::ChannelCount, - sample_rate: cpal::SampleRate(config.sample_rate), + channels: config.channel_count.get() as cpal::ChannelCount, + sample_rate: cpal::SampleRate(config.sample_rate.get()), buffer_size: config.buffer_size, } } @@ -326,11 +330,6 @@ impl OutputStream { if let BufferSize::Fixed(sz) = config.buffer_size { assert!(sz > 0, "fixed buffer size is greater than zero"); } - assert!(config.sample_rate > 0, "sample rate is greater than zero"); - assert!( - config.channel_count > 0, - "channel number is greater than zero" - ); } fn open( @@ -488,7 +487,7 @@ fn supported_output_configs( let max_rate = sf.max_sample_rate(); let min_rate = sf.min_sample_rate(); let mut formats = vec![sf.with_max_sample_rate()]; - let preferred_rate = cpal::SampleRate(HZ_44100); + let preferred_rate = cpal::SampleRate(HZ_44100.get()); if preferred_rate < max_rate && preferred_rate > min_rate { formats.push(sf.with_sample_rate(preferred_rate)) } diff --git a/src/wav_output.rs b/src/wav_output.rs index 3b91d8a4..43221fa6 100644 --- a/src/wav_output.rs +++ b/src/wav_output.rs @@ -1,4 +1,4 @@ -use crate::{ChannelCount, Source}; +use crate::Source; use hound::{SampleFormat, WavSpec}; use std::path; @@ -10,8 +10,8 @@ pub fn output_to_wav( wav_file: impl AsRef, ) -> Result<(), Box> { let format = WavSpec { - channels: source.channels() as ChannelCount, - sample_rate: source.sample_rate(), + channels: source.channels().get(), + sample_rate: source.sample_rate().get(), bits_per_sample: 32, sample_format: SampleFormat::Float, }; @@ -26,7 +26,6 @@ pub fn output_to_wav( #[cfg(test)] mod test { use super::output_to_wav; - use crate::common::ChannelCount; use crate::Source; use std::io::BufReader; use std::time::Duration; @@ -46,8 +45,8 @@ mod test { let mut reader = hound::WavReader::new(BufReader::new(file)).expect("wav file can be read back"); let reference = make_source(); - assert_eq!(reference.sample_rate(), reader.spec().sample_rate); - assert_eq!(reference.channels(), reader.spec().channels as ChannelCount); + assert_eq!(reference.sample_rate().get(), reader.spec().sample_rate); + assert_eq!(reference.channels().get(), reader.spec().channels); let actual_samples: Vec = reader.samples::().map(|x| x.unwrap()).collect(); let expected_samples: Vec = reference.collect(); diff --git a/tests/seek.rs b/tests/seek.rs index 38691144..d85759d1 100644 --- a/tests/seek.rs +++ b/tests/seek.rs @@ -1,4 +1,4 @@ -use rodio::{ChannelCount, Decoder, Source}; +use rodio::{Decoder, Source}; use rstest::rstest; use rstest_reuse::{self, *}; use std::io::{BufReader, Read, Seek}; @@ -121,18 +121,20 @@ fn seek_does_not_break_channel_order( ) { let mut source = get_rl(format); let channels = source.channels(); - assert_eq!(channels, 2, "test needs a stereo beep file"); + assert_eq!(channels.get(), 2, "test needs a stereo beep file"); let beep_range = second_channel_beep_range(&mut source); let beep_start = Duration::from_secs_f32( - beep_range.start as f32 / source.channels() as f32 / source.sample_rate() as f32, + beep_range.start as f32 + / source.channels().get() as f32 + / source.sample_rate().get() as f32, ); let mut source = get_rl(format); let mut channel_offset = 0; for offset in [1, 4, 7, 40, 41, 120, 179] - .map(|offset| offset as f32 / (source.sample_rate() as f32)) + .map(|offset| offset as f32 / (source.sample_rate().get() as f32)) .map(Duration::from_secs_f32) { source.next(); // WINDOW is even, make the amount of calls to next @@ -144,7 +146,7 @@ fn seek_does_not_break_channel_order( let samples: Vec<_> = source.by_ref().take(100).collect(); let channel0 = 0 + channel_offset; assert!( - is_silent(&samples, source.channels(), channel0), + is_silent(&samples, source.channels().get() as usize, channel0), "channel0 should be silent, channel0 starts at idx: {channel0} seek: {beep_start:?} + {offset:?} @@ -152,7 +154,7 @@ fn seek_does_not_break_channel_order( ); let channel1 = (1 + channel_offset) % 2; assert!( - !is_silent(&samples, source.channels(), channel1), + !is_silent(&samples, source.channels().get() as usize, channel1), "channel1 should not be silent, channel1; starts at idx: {channel1} seek: {beep_start:?} + {offset:?} @@ -165,7 +167,7 @@ fn second_channel_beep_range(source: &mut R) -> std::ops::Rang where R: Iterator, { - let channels = source.channels() as usize; + let channels = source.channels().get() as usize; let samples: Vec = source.by_ref().collect(); const WINDOW: usize = 50; @@ -202,21 +204,15 @@ where .next_multiple_of(channels); let samples = &samples[beep_starts..beep_starts + 100]; - assert!( - is_silent(samples, channels as ChannelCount, 0), - "{samples:?}" - ); - assert!( - !is_silent(samples, channels as ChannelCount, 1), - "{samples:?}" - ); + assert!(is_silent(samples, channels, 0), "{samples:?}"); + assert!(!is_silent(samples, channels, 1), "{samples:?}"); beep_starts..beep_ends } -fn is_silent(samples: &[f32], channels: ChannelCount, channel: usize) -> bool { +fn is_silent(samples: &[f32], channels: usize, channel: usize) -> bool { assert_eq!(samples.len(), 100); - let channel = samples.iter().skip(channel).step_by(channels as usize); + let channel = samples.iter().skip(channel).step_by(channels); let volume = channel.map(|s| s.abs()).sum::() / samples.len() as f32 * channels as f32; const BASICALLY_ZERO: f32 = 0.0001; @@ -224,8 +220,8 @@ fn is_silent(samples: &[f32], channels: ChannelCount, channel: usize) -> bool { } fn time_remaining(decoder: Decoder) -> Duration { - let rate = decoder.sample_rate() as f64; - let n_channels = decoder.channels() as f64; + let rate = decoder.sample_rate().get() as f64; + let n_channels = decoder.channels().get() as f64; let n_samples = decoder.into_iter().count() as f64; Duration::from_secs_f64(n_samples / rate / n_channels) }