Skip to content

Commit 30a9dd0

Browse files
committed
This makes ChannelCount NonZero<u16> and channels not zero asserts
I ran into a lot of bugs while adding tests that had to do with channel being set to zero somewhere. While this change makes the API slightly less easy to use it prevents very hard to debug crashes/underflows etc. Performance might drop in decoders, the current implementation makes the bound check every time `channels` is called which is once per span. This could be cached to alleviate that.
1 parent 1671d36 commit 30a9dd0

40 files changed

+245
-204
lines changed

benches/pipeline.rs

+3-1
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
use std::time::Duration;
22

33
use divan::Bencher;
4+
use rodio::ChannelCount;
45
use rodio::{source::UniformSourceIterator, Source};
56

67
mod shared;
@@ -31,7 +32,8 @@ fn long(bencher: Bencher) {
3132
.buffered()
3233
.reverb(Duration::from_secs_f32(0.05), 0.3)
3334
.skippable();
34-
let resampled = UniformSourceIterator::new(effects_applied, 2, 40_000);
35+
let resampled =
36+
UniformSourceIterator::new(effects_applied, ChannelCount::new(2).unwrap(), 40_000);
3537
resampled.for_each(divan::black_box_drop)
3638
})
3739
}

benches/shared.rs

+1-1
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,7 @@ use rodio::{ChannelCount, Sample, SampleRate, Source};
66

77
pub struct TestSource {
88
samples: vec::IntoIter<Sample>,
9-
channels: u16,
9+
channels: ChannelCount,
1010
sample_rate: u32,
1111
total_duration: Duration,
1212
}

examples/mix_multiple_sources.rs

+2-1
Original file line numberDiff line numberDiff line change
@@ -1,11 +1,12 @@
11
use rodio::mixer;
22
use rodio::source::{SineWave, Source};
33
use std::error::Error;
4+
use std::num::NonZero;
45
use std::time::Duration;
56

67
fn main() -> Result<(), Box<dyn Error>> {
78
// Construct a dynamic controller and mixer, stream_handle, and sink.
8-
let (controller, mixer) = mixer::mixer(2, 44_100);
9+
let (controller, mixer) = mixer::mixer(NonZero::new(2).unwrap(), 44_100);
910
let stream_handle = rodio::OutputStreamBuilder::open_default_stream()?;
1011
let sink = rodio::Sink::connect_new(&stream_handle.mixer());
1112

src/buffer.rs

+16-19
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,8 @@
66
//!
77
//! ```
88
//! use rodio::buffer::SamplesBuffer;
9-
//! let _ = SamplesBuffer::new(1, 44100, vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0]);
9+
//! use rodio::ChannelCount;
10+
//! let _ = SamplesBuffer::new(ChannelCount::new(1).unwrap(), 44100, vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0]);
1011
//! ```
1112
//!
1213
@@ -30,7 +31,6 @@ impl SamplesBuffer {
3031
///
3132
/// # Panic
3233
///
33-
/// - Panics if the number of channels is zero.
3434
/// - Panics if the samples rate is zero.
3535
/// - Panics if the length of the buffer is larger than approximately 16 billion elements.
3636
/// This is because the calculation of the duration would overflow.
@@ -39,13 +39,12 @@ impl SamplesBuffer {
3939
where
4040
D: Into<Vec<Sample>>,
4141
{
42-
assert!(channels >= 1);
4342
assert!(sample_rate >= 1);
4443

4544
let data = data.into();
4645
let duration_ns = 1_000_000_000u64.checked_mul(data.len() as u64).unwrap()
4746
/ sample_rate as u64
48-
/ channels as u64;
47+
/ channels.get() as u64;
4948
let duration = Duration::new(
5049
duration_ns / 1_000_000_000,
5150
(duration_ns % 1_000_000_000) as u32,
@@ -89,14 +88,14 @@ impl Source for SamplesBuffer {
8988
// and due to the constant sample_rate we can jump to the right
9089
// sample directly.
9190

92-
let curr_channel = self.pos % self.channels() as usize;
93-
let new_pos = pos.as_secs_f32() * self.sample_rate() as f32 * self.channels() as f32;
91+
let curr_channel = self.pos % self.channels().get() as usize;
92+
let new_pos = pos.as_secs_f32() * self.sample_rate() as f32 * self.channels().get() as f32;
9493
// saturate pos at the end of the source
9594
let new_pos = new_pos as usize;
9695
let new_pos = new_pos.min(self.data.len());
9796

9897
// make sure the next sample is for the right channel
99-
let new_pos = new_pos.next_multiple_of(self.channels() as usize);
98+
let new_pos = new_pos.next_multiple_of(self.channels().get() as usize);
10099
let new_pos = new_pos - curr_channel;
101100

102101
self.pos = new_pos;
@@ -123,36 +122,31 @@ impl Iterator for SamplesBuffer {
123122
#[cfg(test)]
124123
mod tests {
125124
use crate::buffer::SamplesBuffer;
125+
use crate::math::ch;
126126
use crate::source::Source;
127127

128128
#[test]
129129
fn basic() {
130-
let _ = SamplesBuffer::new(1, 44100, vec![0.0, 0.0, 0.0, 0.0, 0.0, 0.0]);
131-
}
132-
133-
#[test]
134-
#[should_panic]
135-
fn panic_if_zero_channels() {
136-
SamplesBuffer::new(0, 44100, vec![0.0, 0.0, 0.0, 0.0, 0.0, 0.0]);
130+
let _ = SamplesBuffer::new(ch!(1), 44100, vec![0.0, 0.0, 0.0, 0.0, 0.0, 0.0]);
137131
}
138132

139133
#[test]
140134
#[should_panic]
141135
fn panic_if_zero_sample_rate() {
142-
SamplesBuffer::new(1, 0, vec![0.0, 0.0, 0.0, 0.0, 0.0, 0.0]);
136+
SamplesBuffer::new(ch!(1), 0, vec![0.0, 0.0, 0.0, 0.0, 0.0, 0.0]);
143137
}
144138

145139
#[test]
146140
fn duration_basic() {
147-
let buf = SamplesBuffer::new(2, 2, vec![0.0, 0.0, 0.0, 0.0, 0.0, 0.0]);
141+
let buf = SamplesBuffer::new(ch!(2), 2, vec![0.0, 0.0, 0.0, 0.0, 0.0, 0.0]);
148142
let dur = buf.total_duration().unwrap();
149143
assert_eq!(dur.as_secs(), 1);
150144
assert_eq!(dur.subsec_nanos(), 500_000_000);
151145
}
152146

153147
#[test]
154148
fn iteration() {
155-
let mut buf = SamplesBuffer::new(1, 44100, vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0]);
149+
let mut buf = SamplesBuffer::new(ch!(1), 44100, vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0]);
156150
assert_eq!(buf.next(), Some(1.0));
157151
assert_eq!(buf.next(), Some(2.0));
158152
assert_eq!(buf.next(), Some(3.0));
@@ -172,7 +166,7 @@ mod tests {
172166
#[test]
173167
fn channel_order_stays_correct() {
174168
const SAMPLE_RATE: SampleRate = 100;
175-
const CHANNELS: ChannelCount = 2;
169+
const CHANNELS: ChannelCount = ch!(2);
176170
let mut buf = SamplesBuffer::new(
177171
CHANNELS,
178172
SAMPLE_RATE,
@@ -182,7 +176,10 @@ mod tests {
182176
.collect::<Vec<_>>(),
183177
);
184178
buf.try_seek(Duration::from_secs(5)).unwrap();
185-
assert_eq!(buf.next(), Some(5.0 * SAMPLE_RATE as f32 * CHANNELS as f32));
179+
assert_eq!(
180+
buf.next(),
181+
Some(5.0 * SAMPLE_RATE as f32 * CHANNELS.get() as f32)
182+
);
186183

187184
assert!(buf.next().is_some_and(|s| s.trunc() as i32 % 2 == 1));
188185
assert!(buf.next().is_some_and(|s| s.trunc() as i32 % 2 == 0));

src/common.rs

+4-2
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,10 @@
1+
use std::num::NonZero;
2+
13
/// Stream sample rate (a frame rate or samples per second per channel).
24
pub type SampleRate = u32;
35

4-
/// Number of channels in a stream.
5-
pub type ChannelCount = u16;
6+
/// Number of channels in a stream. Can never be Zero
7+
pub type ChannelCount = NonZero<u16>;
68

79
/// Represents value of a single sample.
810
/// Silence corresponds to the value `0.0`. The expected amplitude range is -1.0...1.0.

src/conversions/channels.rs

+26-22
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,7 @@ where
1111
from: ChannelCount,
1212
to: ChannelCount,
1313
sample_repeat: Option<Sample>,
14-
next_output_sample_pos: ChannelCount,
14+
next_output_sample_pos: u16,
1515
}
1616

1717
impl<I> ChannelCountConverter<I>
@@ -26,9 +26,6 @@ where
2626
///
2727
#[inline]
2828
pub fn new(input: I, from: ChannelCount, to: ChannelCount) -> ChannelCountConverter<I> {
29-
assert!(from >= 1);
30-
assert!(to >= 1);
31-
3229
ChannelCountConverter {
3330
input,
3431
from,
@@ -65,7 +62,7 @@ where
6562
self.sample_repeat = value;
6663
value
6764
}
68-
x if x < self.from => self.input.next(),
65+
x if x < self.from.get() => self.input.next(),
6966
1 => self.sample_repeat,
7067
_ => Some(0.0),
7168
};
@@ -74,11 +71,11 @@ where
7471
self.next_output_sample_pos += 1;
7572
}
7673

77-
if self.next_output_sample_pos == self.to {
74+
if self.next_output_sample_pos == self.to.get() {
7875
self.next_output_sample_pos = 0;
7976

8077
if self.from > self.to {
81-
for _ in self.to..self.from {
78+
for _ in self.to.get()..self.from.get() {
8279
self.input.next(); // discarding extra input
8380
}
8481
}
@@ -91,9 +88,9 @@ where
9188
fn size_hint(&self) -> (usize, Option<usize>) {
9289
let (min, max) = self.input.size_hint();
9390

94-
let consumed = std::cmp::min(self.from, self.next_output_sample_pos) as usize;
91+
let consumed = std::cmp::min(self.from.get(), self.next_output_sample_pos) as usize;
9592
let calculate = |size| {
96-
(size + consumed) / self.from as usize * self.to as usize
93+
(size + consumed) / self.from.get() as usize * self.to.get() as usize
9794
- self.next_output_sample_pos as usize
9895
};
9996

@@ -110,38 +107,45 @@ impl<I> ExactSizeIterator for ChannelCountConverter<I> where I: ExactSizeIterato
110107
mod test {
111108
use super::ChannelCountConverter;
112109
use crate::common::ChannelCount;
110+
use crate::math::ch;
113111
use crate::Sample;
114112

115113
#[test]
116114
fn remove_channels() {
117115
let input = vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0];
118-
let output = ChannelCountConverter::new(input.into_iter(), 3, 2).collect::<Vec<_>>();
116+
let output =
117+
ChannelCountConverter::new(input.into_iter(), ch!(3), ch!(2)).collect::<Vec<_>>();
119118
assert_eq!(output, [1.0, 2.0, 4.0, 5.0]);
120119

121120
let input = vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0];
122-
let output = ChannelCountConverter::new(input.into_iter(), 4, 1).collect::<Vec<_>>();
121+
let output =
122+
ChannelCountConverter::new(input.into_iter(), ch!(4), ch!(1)).collect::<Vec<_>>();
123123
assert_eq!(output, [1.0, 5.0]);
124124
}
125125

126126
#[test]
127127
fn add_channels() {
128128
let input = vec![1.0, 2.0, 3.0, 4.0];
129-
let output = ChannelCountConverter::new(input.into_iter(), 1, 2).collect::<Vec<_>>();
129+
let output =
130+
ChannelCountConverter::new(input.into_iter(), ch!(1), ch!(2)).collect::<Vec<_>>();
130131
assert_eq!(output, [1.0, 1.0, 2.0, 2.0, 3.0, 3.0, 4.0, 4.0]);
131132

132133
let input = vec![1.0, 2.0];
133-
let output = ChannelCountConverter::new(input.into_iter(), 1, 4).collect::<Vec<_>>();
134+
let output =
135+
ChannelCountConverter::new(input.into_iter(), ch!(1), ch!(4)).collect::<Vec<_>>();
134136
assert_eq!(output, [1.0, 1.0, 0.0, 0.0, 2.0, 2.0, 0.0, 0.0]);
135137

136138
let input = vec![1.0, 2.0, 3.0, 4.0];
137-
let output = ChannelCountConverter::new(input.into_iter(), 2, 4).collect::<Vec<_>>();
139+
let output =
140+
ChannelCountConverter::new(input.into_iter(), ch!(2), ch!(4)).collect::<Vec<_>>();
138141
assert_eq!(output, [1.0, 2.0, 0.0, 0.0, 3.0, 4.0, 0.0, 0.0]);
139142
}
140143

141144
#[test]
142145
fn size_hint() {
143146
fn test(input: &[Sample], from: ChannelCount, to: ChannelCount) {
144-
let mut converter = ChannelCountConverter::new(input.iter().copied(), from, to);
147+
let mut converter =
148+
ChannelCountConverter::new(input.iter().copied(), from, to);
145149
let count = converter.clone().count();
146150
for left_in_iter in (0..=count).rev() {
147151
println!("left_in_iter = {left_in_iter}");
@@ -151,24 +155,24 @@ mod test {
151155
assert_eq!(converter.size_hint(), (0, Some(0)));
152156
}
153157

154-
test(&[1.0, 2.0, 3.0], 1, 2);
155-
test(&[1.0, 2.0, 3.0, 4.0], 2, 4);
156-
test(&[1.0, 2.0, 3.0, 4.0], 4, 2);
157-
test(&[1.0, 2.0, 3.0, 4.0, 5.0, 6.0], 3, 8);
158-
test(&[1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0], 4, 1);
158+
test(&[1.0, 2.0, 3.0], ch!(1), ch!(2));
159+
test(&[1.0, 2.0, 3.0, 4.0], ch!(2), ch!(4));
160+
test(&[1.0, 2.0, 3.0, 4.0], ch!(4), ch!(2));
161+
test(&[1.0, 2.0, 3.0, 4.0, 5.0, 6.0], ch!(3), ch!(8));
162+
test(&[1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0], ch!(4), ch!(1));
159163
}
160164

161165
#[test]
162166
fn len_more() {
163167
let input = vec![1.0, 2.0, 3.0, 4.0];
164-
let output = ChannelCountConverter::new(input.into_iter(), 2, 3);
168+
let output = ChannelCountConverter::new(input.into_iter(), ch!(2), ch!(3));
165169
assert_eq!(output.len(), 6);
166170
}
167171

168172
#[test]
169173
fn len_less() {
170174
let input = vec![1.0, 2.0, 3.0, 4.0];
171-
let output = ChannelCountConverter::new(input.into_iter(), 2, 1);
175+
let output = ChannelCountConverter::new(input.into_iter(), ch!(2), ch!(1));
172176
assert_eq!(output.len(), 2);
173177
}
174178
}

0 commit comments

Comments
 (0)