Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Make ChannelCount and SampleRate NonZero #709

Open
wants to merge 2 commits into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 7 additions & 1 deletion benches/pipeline.rs
Original file line number Diff line number Diff line change
@@ -1,6 +1,8 @@
use std::num::NonZero;
use std::time::Duration;

use divan::Bencher;
use rodio::ChannelCount;
use rodio::{source::UniformSourceIterator, Source};

mod shared;
Expand Down Expand Up @@ -31,7 +33,11 @@ fn long(bencher: Bencher) {
.buffered()
.reverb(Duration::from_secs_f32(0.05), 0.3)
.skippable();
let resampled = UniformSourceIterator::new(effects_applied, 2, 40_000);
let resampled = UniformSourceIterator::new(
effects_applied,
ChannelCount::new(2).unwrap(),
NonZero::new(40_000).unwrap(),
);
resampled.for_each(divan::black_box_drop)
})
}
Expand Down
3 changes: 2 additions & 1 deletion benches/resampler.rs
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@ use rodio::source::UniformSourceIterator;
mod shared;
use shared::music_wav;

use rodio::Source;
use rodio::{SampleRate, Source};

fn main() {
divan::main();
Expand All @@ -31,6 +31,7 @@ const COMMON_SAMPLE_RATES: [u32; 12] = [

#[divan::bench(args = COMMON_SAMPLE_RATES)]
fn resample_to(bencher: Bencher, target_sample_rate: u32) {
let target_sample_rate = SampleRate::new(target_sample_rate).unwrap();
bencher
.with_inputs(|| {
let source = music_wav();
Expand Down
4 changes: 2 additions & 2 deletions benches/shared.rs
Original file line number Diff line number Diff line change
Expand Up @@ -6,8 +6,8 @@ use rodio::{ChannelCount, Sample, SampleRate, Source};

pub struct TestSource {
samples: vec::IntoIter<Sample>,
channels: u16,
sample_rate: u32,
channels: ChannelCount,
sample_rate: SampleRate,
total_duration: Duration,
}

Expand Down
3 changes: 2 additions & 1 deletion examples/custom_config.rs
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@ use cpal::{BufferSize, SampleFormat};
use rodio::source::SineWave;
use rodio::Source;
use std::error::Error;
use std::num::NonZero;
use std::thread;
use std::time::Duration;

Expand All @@ -15,7 +16,7 @@ fn main() -> Result<(), Box<dyn Error>> {
// No need to set all parameters explicitly here,
// the defaults were set from the device's description.
.with_buffer_size(BufferSize::Fixed(256))
.with_sample_rate(48_000)
.with_sample_rate(NonZero::new(48_000).unwrap())
.with_sample_format(SampleFormat::F32)
// Note that the function below still tries alternative configs if the specified one fails.
// If you need to only use the exact specified configuration,
Expand Down
3 changes: 2 additions & 1 deletion examples/mix_multiple_sources.rs
Original file line number Diff line number Diff line change
@@ -1,11 +1,12 @@
use rodio::mixer;
use rodio::source::{SineWave, Source};
use std::error::Error;
use std::num::NonZero;
use std::time::Duration;

fn main() -> Result<(), Box<dyn Error>> {
// Construct a dynamic controller and mixer, stream_handle, and sink.
let (controller, mixer) = mixer::mixer(2, 44_100);
let (controller, mixer) = mixer::mixer(NonZero::new(2).unwrap(), NonZero::new(44_100).unwrap());
let stream_handle = rodio::OutputStreamBuilder::open_default_stream()?;
let sink = rodio::Sink::connect_new(&stream_handle.mixer());

Expand Down
3 changes: 2 additions & 1 deletion examples/signal_generator.rs
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
//! Test signal generator example.

use std::error::Error;
use std::num::NonZero;

fn main() -> Result<(), Box<dyn Error>> {
use rodio::source::{chirp, Function, SignalGenerator, Source};
Expand All @@ -11,7 +12,7 @@ fn main() -> Result<(), Box<dyn Error>> {

let test_signal_duration = Duration::from_millis(1000);
let interval_duration = Duration::from_millis(1500);
let sample_rate = 48000;
let sample_rate = NonZero::new(48000).unwrap();

println!("Playing 1000 Hz tone");
stream_handle.mixer().add(
Expand Down
46 changes: 18 additions & 28 deletions src/buffer.rs
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,8 @@
//!
//! ```
//! use rodio::buffer::SamplesBuffer;
//! let _ = SamplesBuffer::new(1, 44100, vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0]);
//! use core::num::NonZero;
//! let _ = SamplesBuffer::new(NonZero::new(1).unwrap(), NonZero::new(44100).unwrap(), vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0]);
//! ```
//!

Expand All @@ -30,7 +31,6 @@ impl SamplesBuffer {
///
/// # Panic
///
/// - Panics if the number of channels is zero.
/// - Panics if the samples rate is zero.
/// - Panics if the length of the buffer is larger than approximately 16 billion elements.
/// This is because the calculation of the duration would overflow.
Expand All @@ -39,13 +39,10 @@ impl SamplesBuffer {
where
D: Into<Vec<Sample>>,
{
assert!(channels >= 1);
assert!(sample_rate >= 1);

let data = data.into();
let duration_ns = 1_000_000_000u64.checked_mul(data.len() as u64).unwrap()
/ sample_rate as u64
/ channels as u64;
/ sample_rate.get() as u64
/ channels.get() as u64;
let duration = Duration::new(
duration_ns / 1_000_000_000,
(duration_ns % 1_000_000_000) as u32,
Expand Down Expand Up @@ -89,14 +86,15 @@ impl Source for SamplesBuffer {
// and due to the constant sample_rate we can jump to the right
// sample directly.

let curr_channel = self.pos % self.channels() as usize;
let new_pos = pos.as_secs_f32() * self.sample_rate() as f32 * self.channels() as f32;
let curr_channel = self.pos % self.channels().get() as usize;
let new_pos =
pos.as_secs_f32() * self.sample_rate().get() as f32 * self.channels().get() as f32;
// saturate pos at the end of the source
let new_pos = new_pos as usize;
let new_pos = new_pos.min(self.data.len());

// make sure the next sample is for the right channel
let new_pos = new_pos.next_multiple_of(self.channels() as usize);
let new_pos = new_pos.next_multiple_of(self.channels().get() as usize);
let new_pos = new_pos - curr_channel;

self.pos = new_pos;
Expand All @@ -123,36 +121,25 @@ impl Iterator for SamplesBuffer {
#[cfg(test)]
mod tests {
use crate::buffer::SamplesBuffer;
use crate::math::nz;
use crate::source::Source;

#[test]
fn basic() {
let _ = SamplesBuffer::new(1, 44100, vec![0.0, 0.0, 0.0, 0.0, 0.0, 0.0]);
}

#[test]
#[should_panic]
fn panic_if_zero_channels() {
SamplesBuffer::new(0, 44100, vec![0.0, 0.0, 0.0, 0.0, 0.0, 0.0]);
}

#[test]
#[should_panic]
fn panic_if_zero_sample_rate() {
SamplesBuffer::new(1, 0, vec![0.0, 0.0, 0.0, 0.0, 0.0, 0.0]);
let _ = SamplesBuffer::new(nz!(1), nz!(44100), vec![0.0, 0.0, 0.0, 0.0, 0.0, 0.0]);
}

#[test]
fn duration_basic() {
let buf = SamplesBuffer::new(2, 2, vec![0.0, 0.0, 0.0, 0.0, 0.0, 0.0]);
let buf = SamplesBuffer::new(nz!(2), nz!(2), vec![0.0, 0.0, 0.0, 0.0, 0.0, 0.0]);
let dur = buf.total_duration().unwrap();
assert_eq!(dur.as_secs(), 1);
assert_eq!(dur.subsec_nanos(), 500_000_000);
}

#[test]
fn iteration() {
let mut buf = SamplesBuffer::new(1, 44100, vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0]);
let mut buf = SamplesBuffer::new(nz!(1), nz!(44100), vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0]);
assert_eq!(buf.next(), Some(1.0));
assert_eq!(buf.next(), Some(2.0));
assert_eq!(buf.next(), Some(3.0));
Expand All @@ -171,8 +158,8 @@ mod tests {

#[test]
fn channel_order_stays_correct() {
const SAMPLE_RATE: SampleRate = 100;
const CHANNELS: ChannelCount = 2;
const SAMPLE_RATE: SampleRate = nz!(100);
const CHANNELS: ChannelCount = nz!(2);
let mut buf = SamplesBuffer::new(
CHANNELS,
SAMPLE_RATE,
Expand All @@ -182,7 +169,10 @@ mod tests {
.collect::<Vec<_>>(),
);
buf.try_seek(Duration::from_secs(5)).unwrap();
assert_eq!(buf.next(), Some(5.0 * SAMPLE_RATE as f32 * CHANNELS as f32));
assert_eq!(
buf.next(),
Some(5.0 * SAMPLE_RATE.get() as f32 * CHANNELS.get() as f32)
);

assert!(buf.next().is_some_and(|s| s.trunc() as i32 % 2 == 1));
assert!(buf.next().is_some_and(|s| s.trunc() as i32 % 2 == 0));
Expand Down
8 changes: 5 additions & 3 deletions src/common.rs
Original file line number Diff line number Diff line change
@@ -1,8 +1,10 @@
use std::num::NonZero;

/// Stream sample rate (a frame rate or samples per second per channel).
pub type SampleRate = u32;
pub type SampleRate = NonZero<u32>;

/// Number of channels in a stream.
pub type ChannelCount = u16;
/// Number of channels in a stream. Can never be Zero
pub type ChannelCount = NonZero<u16>;

/// Represents value of a single sample.
/// Silence corresponds to the value `0.0`. The expected amplitude range is -1.0...1.0.
Expand Down
47 changes: 25 additions & 22 deletions src/conversions/channels.rs
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@ where
from: ChannelCount,
to: ChannelCount,
sample_repeat: Option<Sample>,
next_output_sample_pos: ChannelCount,
next_output_sample_pos: u16,
}

impl<I> ChannelCountConverter<I>
Expand All @@ -26,9 +26,6 @@ where
///
#[inline]
pub fn new(input: I, from: ChannelCount, to: ChannelCount) -> ChannelCountConverter<I> {
assert!(from >= 1);
assert!(to >= 1);

ChannelCountConverter {
input,
from,
Expand Down Expand Up @@ -65,7 +62,7 @@ where
self.sample_repeat = value;
value
}
x if x < self.from => self.input.next(),
x if x < self.from.get() => self.input.next(),
1 => self.sample_repeat,
_ => Some(0.0),
};
Expand All @@ -74,11 +71,11 @@ where
self.next_output_sample_pos += 1;
}

if self.next_output_sample_pos == self.to {
if self.next_output_sample_pos == self.to.get() {
self.next_output_sample_pos = 0;

if self.from > self.to {
for _ in self.to..self.from {
for _ in self.to.get()..self.from.get() {
self.input.next(); // discarding extra input
}
}
Expand All @@ -91,13 +88,13 @@ where
fn size_hint(&self) -> (usize, Option<usize>) {
let (min, max) = self.input.size_hint();

let consumed = std::cmp::min(self.from, self.next_output_sample_pos) as usize;
let consumed = std::cmp::min(self.from.get(), self.next_output_sample_pos) as usize;

let min = ((min + consumed) / self.from as usize * self.to as usize)
let min = ((min + consumed) / self.from.get() as usize * self.to.get() as usize)
.saturating_sub(self.next_output_sample_pos as usize);

let max = max.map(|max| {
((max + consumed) / self.from as usize * self.to as usize)
((max + consumed) / self.from.get() as usize * self.to.get() as usize)
.saturating_sub(self.next_output_sample_pos as usize)
});

Expand All @@ -111,31 +108,37 @@ impl<I> ExactSizeIterator for ChannelCountConverter<I> where I: ExactSizeIterato
mod test {
use super::ChannelCountConverter;
use crate::common::ChannelCount;
use crate::math::nz;
use crate::Sample;

#[test]
fn remove_channels() {
let input = vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0];
let output = ChannelCountConverter::new(input.into_iter(), 3, 2).collect::<Vec<_>>();
let output =
ChannelCountConverter::new(input.into_iter(), nz!(3), nz!(2)).collect::<Vec<_>>();
assert_eq!(output, [1.0, 2.0, 4.0, 5.0]);

let input = vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0];
let output = ChannelCountConverter::new(input.into_iter(), 4, 1).collect::<Vec<_>>();
let output =
ChannelCountConverter::new(input.into_iter(), nz!(4), nz!(1)).collect::<Vec<_>>();
assert_eq!(output, [1.0, 5.0]);
}

#[test]
fn add_channels() {
let input = vec![1.0, 2.0, 3.0, 4.0];
let output = ChannelCountConverter::new(input.into_iter(), 1, 2).collect::<Vec<_>>();
let output =
ChannelCountConverter::new(input.into_iter(), nz!(1), nz!(2)).collect::<Vec<_>>();
assert_eq!(output, [1.0, 1.0, 2.0, 2.0, 3.0, 3.0, 4.0, 4.0]);

let input = vec![1.0, 2.0];
let output = ChannelCountConverter::new(input.into_iter(), 1, 4).collect::<Vec<_>>();
let output =
ChannelCountConverter::new(input.into_iter(), nz!(1), nz!(4)).collect::<Vec<_>>();
assert_eq!(output, [1.0, 1.0, 0.0, 0.0, 2.0, 2.0, 0.0, 0.0]);

let input = vec![1.0, 2.0, 3.0, 4.0];
let output = ChannelCountConverter::new(input.into_iter(), 2, 4).collect::<Vec<_>>();
let output =
ChannelCountConverter::new(input.into_iter(), nz!(2), nz!(4)).collect::<Vec<_>>();
assert_eq!(output, [1.0, 2.0, 0.0, 0.0, 3.0, 4.0, 0.0, 0.0]);
}

Expand All @@ -152,24 +155,24 @@ mod test {
assert_eq!(converter.size_hint(), (0, Some(0)));
}

test(&[1.0, 2.0, 3.0], 1, 2);
test(&[1.0, 2.0, 3.0, 4.0], 2, 4);
test(&[1.0, 2.0, 3.0, 4.0], 4, 2);
test(&[1.0, 2.0, 3.0, 4.0, 5.0, 6.0], 3, 8);
test(&[1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0], 4, 1);
test(&[1.0, 2.0, 3.0], nz!(1), nz!(2));
test(&[1.0, 2.0, 3.0, 4.0], nz!(2), nz!(4));
test(&[1.0, 2.0, 3.0, 4.0], nz!(4), nz!(2));
test(&[1.0, 2.0, 3.0, 4.0, 5.0, 6.0], nz!(3), nz!(8));
test(&[1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0], nz!(4), nz!(1));
}

#[test]
fn len_more() {
let input = vec![1.0, 2.0, 3.0, 4.0];
let output = ChannelCountConverter::new(input.into_iter(), 2, 3);
let output = ChannelCountConverter::new(input.into_iter(), nz!(2), nz!(3));
assert_eq!(output.len(), 6);
}

#[test]
fn len_less() {
let input = vec![1.0, 2.0, 3.0, 4.0];
let output = ChannelCountConverter::new(input.into_iter(), 2, 1);
let output = ChannelCountConverter::new(input.into_iter(), nz!(2), nz!(1));
assert_eq!(output.len(), 2);
}
}
Loading