diff --git a/Cargo.toml b/Cargo.toml index d050153..cbcce31 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -36,10 +36,10 @@ once_cell = { version = "1.17.1", optional = true } [dev-dependencies] -criterion = { version = "^0.5", features = ["html_reports"] } -csv = "^1.1" -image = "^0.24" -nshare = { version = "^0.9", features = ["ndarray", "image"] } +criterion = { version = "0.5.1", features = ["html_reports"] } +csv = "1.1.6" +image = "0.24.3" +nshare = { version = "0.9.0", features = ["ndarray", "image"] } [[example]] name = "image" diff --git a/src/indexable/dataloader/builder.rs b/src/indexable/dataloader/builder.rs index eba9471..eac93b6 100644 --- a/src/indexable/dataloader/builder.rs +++ b/src/indexable/dataloader/builder.rs @@ -90,7 +90,7 @@ where /// Set the size of the prefetch buffer. pub fn prefetch_size(mut self, prefetch_size: usize) -> Self { - self.prefetch_size = prefetch_size; + self.prefetch_size = max(prefetch_size, 1); self } diff --git a/src/indexable/dataloader/mod.rs b/src/indexable/dataloader/mod.rs index 3897385..e783a16 100644 --- a/src/indexable/dataloader/mod.rs +++ b/src/indexable/dataloader/mod.rs @@ -6,7 +6,6 @@ use crate::{ sampler::{BatchSampler, Sampler, SequentialSampler}, Dataset, Len, }; -use std::marker::PhantomData; use std::sync::mpsc::sync_channel; use std::sync::{mpsc, Arc}; use std::thread::JoinHandle; @@ -65,8 +64,8 @@ where C::Output: Send, { /// Return not owning iterator over the dataloader. - pub fn iter(&self) -> SingleProcessDataLoaderIter { - SingleProcessDataLoaderIter::new(self) + pub fn iter(&self) -> SingleProcessDataLoaderIter { + SingleProcessDataLoaderIter::::new(self) } } @@ -84,28 +83,28 @@ where /// Iterate over the dataloader with a single thread. #[derive(Debug)] -pub struct SingleProcessDataLoaderIter +pub struct SingleProcessDataLoaderIter where - D: Dataset, - S: Sampler, - C: Collate, + CO: Send, { /// Number of sample yielded. num_yielded: u64, - rx: mpsc::Receiver, + rx: mpsc::Receiver, _thread_handle: JoinHandle<()>, - sampler: PhantomData, } -impl SingleProcessDataLoaderIter +impl SingleProcessDataLoaderIter where - D: Dataset + Send + Sync + 'static, - S: Sampler + Send + Sync + 'static, - C: Collate + Send + Sync + 'static, - D::Sample: Send, - C::Output: Send, + CO: Send, { - fn new(loader: &DataLoader) -> SingleProcessDataLoaderIter { + fn new(loader: &DataLoader) -> SingleProcessDataLoaderIter + where + D: Dataset + Sync + Send + 'static, + S: Sampler + Send + Sync + 'static, + C: Collate + Send + Sync + 'static, + C::Output: Send, + D::Sample: Send, + { let (tx, rx) = sync_channel(loader.prefetch_size); let data_fetcher = MapDatasetFetcher { @@ -133,11 +132,10 @@ where num_yielded: 0, rx, _thread_handle, - sampler: PhantomData, } } - fn next_data(&mut self) -> Option { + fn next_data(&mut self) -> Option { match self.rx.recv() { Ok(data) => Some(data), Err(_) => { @@ -149,15 +147,11 @@ where } } -impl Iterator for SingleProcessDataLoaderIter +impl Iterator for SingleProcessDataLoaderIter where - D: Dataset + Send + Sync + 'static, - S: Sampler + Send + Sync + 'static, - C: Collate + Send + Sync + 'static, - D::Sample: Send, - C::Output: Send, + CO: Send, { - type Item = C::Output; + type Item = CO; fn next(&mut self) -> Option { let data = self.next_data(); @@ -178,7 +172,7 @@ where C::Output: Send, { type Item = C::Output; - type IntoIter = SingleProcessDataLoaderIter; + type IntoIter = SingleProcessDataLoaderIter; fn into_iter(self) -> Self::IntoIter { self.iter()