diff --git a/examples/image.rs b/examples/image.rs
index db029bf..04693df 100755
--- a/examples/image.rs
+++ b/examples/image.rs
@@ -127,6 +127,16 @@ fn main() {
        );
    }

+    // TODO: find a way to not clone the dataloader
+    let dataset = FaceLandmarksDataset::new(
+        "examples/image/dataset/face_landmarks.csv",
+        env::current_dir().unwrap().join("examples/image/dataset/"),
+    );
+    let loader = DataLoader::builder(dataset)
+        .batch_size(4)
+        .collate_fn(TorchCollate)
+        .build();
+
    loader
        .into_iter()
        .enumerate()
diff --git a/src/indexable/dataloader/builder.rs b/src/indexable/dataloader/builder.rs
index 4fa9d51..397793b 100644
--- a/src/indexable/dataloader/builder.rs
+++ b/src/indexable/dataloader/builder.rs
@@ -29,6 +29,8 @@ where
    #[cfg(feature = "rayon")]
    /// Number of threads to use.
    num_threads: usize,
+    /// Prefetch buffer size.
+    prefetch_size: usize,
 }

 // FIXME: kind of strange that we require DefaultCollatte even if in the end we may won't use it
@@ -56,6 +58,7 @@ where
            collate_fn: DefaultCollate,
            #[cfg(feature = "rayon")]
            num_threads,
+            prefetch_size: 0,
        }
    }
 }
@@ -83,6 +86,12 @@ where
        self
    }

+    /// Set the size of the prefetch buffer.
+    pub fn prefetch_size(mut self, prefetch_size: usize) -> Self {
+        self.prefetch_size = prefetch_size;
+        self
+    }
+
    /// Drop the lasts element if they don't feat into a batch. For instance if a dataset have 13
    /// samples and a `batch_size` of 5, the last 3 samples will be dropped.
    pub fn drop_last(mut self) -> Self {
@@ -102,6 +111,7 @@ where
            collate_fn,
            #[cfg(feature = "rayon")]
            num_threads: self.num_threads,
+            prefetch_size: 0,
        }
    }

@@ -122,6 +132,7 @@ where
            collate_fn: self.collate_fn,
            #[cfg(feature = "rayon")]
            num_threads: self.num_threads,
+            prefetch_size: self.prefetch_size,
        }
    }
    /// Create a `Dataloader` from a [`Builder`].
@@ -156,6 +167,7 @@ where
            dataset: self.dataset,
            batch_sampler: self.batch_sampler,
            collate_fn: self.collate_fn,
+            prefetch_size: self.prefetch_size,
        }
    }
 }
diff --git a/src/indexable/dataloader/mod.rs b/src/indexable/dataloader/mod.rs
index 0089fed..40bfbf7 100644
--- a/src/indexable/dataloader/mod.rs
+++ b/src/indexable/dataloader/mod.rs
@@ -3,9 +3,13 @@
 use super::fetch::{Fetcher, MapDatasetFetcher};
 use crate::{
     collate::{Collate, DefaultCollate},
-    sampler::{BatchIterator, BatchSampler, Sampler, SequentialSampler},
+    sampler::{BatchSampler, Sampler, SequentialSampler},
     Dataset, Len,
 };
+use std::marker::PhantomData;
+use std::sync::mpsc;
+use std::sync::mpsc::sync_channel;
+use std::thread::JoinHandle;

 mod builder;
 use builder::Builder;
@@ -37,6 +41,8 @@ pub struct DataLoader<D, S = SequentialSampler, C = DefaultCollate> {
    batch_sampler: BatchSampler<S>,
    /// Collate function.
    collate_fn: C,
+    /// Prefetch buffer size.
+    prefetch_size: usize,
 }
 impl<D, S, C> DataLoader<D, S, C>
 where
@@ -52,13 +58,15 @@ where

 impl<D, S, C> DataLoader<D, S, C>
 where
-    D: Dataset + Sync,
-    S: Sampler,
-    C: Collate<D::Sample>,
+    D: Dataset + Sync + Send + 'static,
+    S: Sampler + Send + 'static,
+    C: Collate<D::Sample> + Send + 'static,
     D::Sample: Send,
+    C::Output: Send,
 {
-    /// Return not owning iterator over the dataloader.
-    pub fn iter(&self) -> SingleProcessDataLoaderIter<'_, D, S, C> {
+    /// Return an owning iterator over the dataloader.
+    /// TODO: find a way to not consume the `DataLoader`.
+    pub fn iter(self) -> SingleProcessDataLoaderIter<D, S, C> {
         SingleProcessDataLoaderIter::new(self)
     }
 }
@@ -77,55 +85,62 @@ where

 /// Iterate over the dataloader with a single thread.
 #[derive(Debug)]
-pub struct SingleProcessDataLoaderIter<'dataset, D, S = SequentialSampler, C = DefaultCollate>
+pub struct SingleProcessDataLoaderIter<D, S = SequentialSampler, C = DefaultCollate>
 where
     D: Dataset + Sync,
     S: Sampler,
     C: Collate<D::Sample>,
 {
-    /// The batch iterator of this iterator.
-    sampler_iter: BatchIterator<S::IntoIter>,
     /// Number of sample yielded.
     num_yielded: u64,
-    /// Used to fetch the data from the dataset.
-    data_fetcher: MapDatasetFetcher<'dataset, D, C>,
+    rx: mpsc::Receiver<C::Output>,
+    _thread_handle: JoinHandle<()>,
+    sampler: PhantomData<S>,
 }

-impl<'dataset, D, S, C> SingleProcessDataLoaderIter<'dataset, D, S, C>
+impl<D, S, C> SingleProcessDataLoaderIter<D, S, C>
 where
-    D: Dataset + Sync,
-    S: Sampler,
-    C: Collate<D::Sample>,
+    D: Dataset + Sync + Send + 'static,
+    S: Sampler + Send + 'static,
+    C: Collate<D::Sample> + Send + 'static,
     D::Sample: Send,
+    C::Output: Send,
 {
-    fn new(loader: &DataLoader<D, S, C>) -> SingleProcessDataLoaderIter<'_, D, S, C> {
-        SingleProcessDataLoaderIter {
-            sampler_iter: loader.batch_sampler.iter(),
-            num_yielded: 0,
-            data_fetcher: MapDatasetFetcher {
+    fn new(loader: DataLoader<D, S, C>) -> SingleProcessDataLoaderIter<D, S, C> {
+        let (tx, rx) = sync_channel(loader.prefetch_size);
+        let _thread_handle = std::thread::spawn(move || {
+            let loader = loader;
+            let mut data_fetcher = MapDatasetFetcher {
                 dataset: &loader.dataset,
                 collate_fn: &loader.collate_fn,
-            },
+            };
+            let mut sampler_iter = loader.batch_sampler.iter();
+            while let Some(index) = sampler_iter.next() {
+                let data = data_fetcher.fetch(index);
+                tx.send(data).expect("Cannot send data to channel");
+            }
+        });
+
+        SingleProcessDataLoaderIter {
+            num_yielded: 0,
+            rx,
+            _thread_handle,
+            sampler: PhantomData::default(),
         }
     }
-    fn next_index(&mut self) -> Option<Vec<usize>> {
-        self.sampler_iter.next()
-    }
+
     fn next_data(&mut self) -> Option<C::Output> {
-        let index = self.next_index();
-        if let Some(index) = index {
-            let data = self.data_fetcher.fetch(index);
-            return Some(data);
-        }
-        None
+        self.rx.recv().ok()
     }
 }
-impl<'dataset, D, S, C> Iterator for SingleProcessDataLoaderIter<'dataset, D, S, C>
+
+impl<D, S, C> Iterator for SingleProcessDataLoaderIter<D, S, C>
 where
-    D: Dataset + Sync,
-    S: Sampler,
-    C: Collate<D::Sample>,
+    D: Dataset + Sync + Send + 'static,
+    S: Sampler + Send + 'static,
+    C: Collate<D::Sample> + Send + 'static,
     D::Sample: Send,
+    C::Output: Send,
 {
     type Item = C::Output;
     fn next(&mut self) -> Option<Self::Item> {
@@ -138,15 +153,17 @@ where
         None
     }
 }
-impl<'dataset, D, S, C> IntoIterator for &'dataset DataLoader<D, S, C>
+
+impl<D, S, C> IntoIterator for DataLoader<D, S, C>
 where
-    D: Dataset + Sync,
-    S: Sampler,
-    C: Collate<D::Sample>,
+    D: Dataset + Sync + Send + 'static,
+    S: Sampler + Send + 'static,
+    C: Collate<D::Sample> + Send + 'static,
     D::Sample: Send,
+    C::Output: Send,
 {
     type Item = C::Output;
-    type IntoIter = SingleProcessDataLoaderIter<'dataset, D, S, C>;
+    type IntoIter = SingleProcessDataLoaderIter<D, S, C>;
     fn into_iter(self) -> Self::IntoIter {
         self.iter()
     }
@@ -196,7 +213,7 @@ mod tests {
         let dataset = vec![1, 2, 3, 4];
         let dataloader = DataLoader::builder(dataset).batch_size(2).build();

-        let mut iter = dataloader.iter();
+        let mut iter = dataloader.clone().iter();
         assert_eq!(iter.next(), Some(array![1, 2]));
         assert_eq!(iter.next(), Some(array![3, 4]));
         assert_eq!(iter.next(), None);
@@ -216,6 +233,18 @@
         assert_eq!(iter.next(), Some(vec![String::from("b")]));
         assert_eq!(iter.next(), None);
     }
+
+    #[test]
+    fn one_dimension_basic_string_with_prefetching() {
+        let dataset = vec![String::from("a"), String::from("b")];
+        let dataloader = DataLoader::builder(dataset).prefetch_size(10).build();
+
+        let mut iter = dataloader.iter();
+        assert_eq!(iter.next(), Some(vec![String::from("a")]));
+        assert_eq!(iter.next(), Some(vec![String::from("b")]));
+        assert_eq!(iter.next(), None);
+    }
+
     #[test]
     fn collate() {
         let dataset = vec![1, 2, 3, 4, 5, 6, 7, 8, 9, 10];
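
Usage sketch (not part of the patch above): how the new `prefetch_size` builder option composes with the rest of the API. The `ai_dataloader::indexable::DataLoader` import path is an assumption and may differ between crate versions; the builder calls mirror the tests added in the diff.

```rust
// Sketch only: the import path is assumed, adjust it to your crate version.
use ai_dataloader::indexable::DataLoader;

fn main() {
    let dataset: Vec<i32> = (1..=10).collect();

    // `prefetch_size(4)` bounds the sync_channel fed by the background fetch
    // thread, so up to 4 collated batches are prepared ahead of the consumer.
    // The default of 0 makes it a rendezvous channel: each batch is fetched
    // only when the iterator asks for it.
    let loader = DataLoader::builder(dataset)
        .batch_size(2)
        .prefetch_size(4)
        .build();

    // `IntoIterator` is now implemented for the owned `DataLoader`, so the
    // loop consumes the loader and drains the prefetch thread's channel.
    for batch in loader {
        println!("{batch:?}");
    }
}
```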