diff --git a/src/indexable/dataloader/builder.rs b/src/indexable/dataloader/builder.rs index 0b46bd7..eac93b6 100644 --- a/src/indexable/dataloader/builder.rs +++ b/src/indexable/dataloader/builder.rs @@ -3,6 +3,8 @@ use crate::{ sampler::{BatchSampler, RandomSampler, Sampler, SequentialSampler}, Dataset, }; +use std::cmp::max; +use std::sync::Arc; #[cfg(feature = "rayon")] use crate::THREAD_POOL; @@ -29,6 +31,8 @@ where #[cfg(feature = "rayon")] /// Number of threads to use. num_threads: usize, + /// Prefetch buffer size. + prefetch_size: usize, } // FIXME: kind of strange that we require DefaultCollatte even if in the end we may won't use it @@ -56,6 +60,7 @@ where collate_fn: DefaultCollate, #[cfg(feature = "rayon")] num_threads, + prefetch_size: 1, } } } @@ -72,7 +77,7 @@ where } /// Set the number of elements in a batch. pub fn batch_size(mut self, batch_size: usize) -> Self { - self.batch_sampler.batch_size = batch_size; + self.batch_sampler.batch_size = max(batch_size, 1); self } @@ -83,6 +88,12 @@ where self } + /// Set the size of the prefetch buffer. + pub fn prefetch_size(mut self, prefetch_size: usize) -> Self { + self.prefetch_size = max(prefetch_size, 1); + self + } + /// Drop the lasts element if they don't feat into a batch. For instance if a dataset have 13 /// samples and a `batch_size` of 5, the last 3 samples will be dropped. pub fn drop_last(mut self) -> Self { @@ -102,6 +113,7 @@ where collate_fn, #[cfg(feature = "rayon")] num_threads: self.num_threads, + prefetch_size: self.prefetch_size, } } @@ -122,6 +134,7 @@ where collate_fn: self.collate_fn, #[cfg(feature = "rayon")] num_threads: self.num_threads, + prefetch_size: self.prefetch_size, } } /// Create a `Dataloader` from a [`Builder`]. @@ -153,9 +166,10 @@ where } DataLoader { - dataset: self.dataset, + dataset: Arc::new(self.dataset), batch_sampler: self.batch_sampler, - collate_fn: self.collate_fn, + collate_fn: Arc::new(self.collate_fn), + prefetch_size: self.prefetch_size, } } } diff --git a/src/indexable/dataloader/mod.rs b/src/indexable/dataloader/mod.rs index 0089fed..ac5cbe3 100644 --- a/src/indexable/dataloader/mod.rs +++ b/src/indexable/dataloader/mod.rs @@ -3,9 +3,12 @@ use super::fetch::{Fetcher, MapDatasetFetcher}; use crate::{ collate::{Collate, DefaultCollate}, - sampler::{BatchIterator, BatchSampler, Sampler, SequentialSampler}, + sampler::{BatchSampler, Sampler, SequentialSampler}, Dataset, Len, }; +use std::sync::mpsc::sync_channel; +use std::sync::{mpsc, Arc}; +use std::thread::JoinHandle; mod builder; use builder::Builder; @@ -23,7 +26,7 @@ use builder::Builder; /// /// let loader = DataLoader::builder(vec![(0, "hola"), (1, "hello"), (2, "hallo"), (3, "bonjour")]).batch_size(2).shuffle().build(); /// -/// for (label, text) in &loader { +/// for (label, text) in loader.iter() { /// println!("Label {label:?}"); /// println!("Text {text:?}"); /// } @@ -32,11 +35,13 @@ use builder::Builder; #[derive(Debug, Clone, PartialEq, PartialOrd, Hash, Eq, Ord)] pub struct DataLoader { /// Dataset from which to load the data. - dataset: D, + dataset: Arc, /// Return a batch of indices at a time. batch_sampler: BatchSampler, /// Collate function. - collate_fn: C, + collate_fn: Arc, + /// Prefetch buffer size. + prefetch_size: usize, } impl DataLoader @@ -52,14 +57,15 @@ where impl DataLoader where - D: Dataset + Sync, - S: Sampler, - C: Collate, + D: Dataset + Sync + Send + 'static, + S: Sampler + Send + Sync + 'static, + C: Collate + Send + Sync + 'static, D::Sample: Send, + C::Output: Send, { /// Return not owning iterator over the dataloader. - pub fn iter(&self) -> SingleProcessDataLoaderIter<'_, D, S, C> { - SingleProcessDataLoaderIter::new(self) + pub fn iter(&self) -> SingleProcessDataLoaderIter { + SingleProcessDataLoaderIter::::new(self) } } @@ -77,57 +83,75 @@ where /// Iterate over the dataloader with a single thread. #[derive(Debug)] -pub struct SingleProcessDataLoaderIter<'dataset, D, S = SequentialSampler, C = DefaultCollate> +pub struct SingleProcessDataLoaderIter where - D: Dataset + Sync, - S: Sampler, - C: Collate, + CO: Send, { - /// The batch iterator of this iterator. - sampler_iter: BatchIterator, /// Number of sample yielded. num_yielded: u64, - /// Used to fetch the data from the dataset. - data_fetcher: MapDatasetFetcher<'dataset, D, C>, + rx: mpsc::Receiver, + _thread_handle: JoinHandle<()>, } -impl<'dataset, D, S, C> SingleProcessDataLoaderIter<'dataset, D, S, C> +impl SingleProcessDataLoaderIter where - D: Dataset + Sync, - S: Sampler, - C: Collate, - D::Sample: Send, + CO: Send, { - fn new(loader: &DataLoader) -> SingleProcessDataLoaderIter<'_, D, S, C> { + fn new(loader: &DataLoader) -> SingleProcessDataLoaderIter + where + D: Dataset + Sync + Send + 'static, + S: Sampler + Send + Sync + 'static, + C: Collate + Send + Sync + 'static, + C::Output: Send, + D::Sample: Send, + { + let (tx, rx) = sync_channel(loader.prefetch_size); + + let mut data_fetcher = MapDatasetFetcher { + dataset: loader.dataset.clone(), + collate_fn: loader.collate_fn.clone(), + }; + let batch_sampler = loader.batch_sampler.clone(); + let _thread_handle = std::thread::spawn(move || { + let mut sampler_iter = batch_sampler.iter(); + // In this dedicated thread : + // We fetch the data and push it into the channel TX + // the loop will pause if there is no more space in the channel/buffer + while let Some(index) = sampler_iter.next() { + let data = data_fetcher.fetch(index); + if let Err(_err) = tx.send(data) { + // An error occurred + // rx has been dropped + drop(tx); + return; + } + } + }); + SingleProcessDataLoaderIter { - sampler_iter: loader.batch_sampler.iter(), num_yielded: 0, - data_fetcher: MapDatasetFetcher { - dataset: &loader.dataset, - collate_fn: &loader.collate_fn, - }, + rx, + _thread_handle, } } - fn next_index(&mut self) -> Option> { - self.sampler_iter.next() - } - fn next_data(&mut self) -> Option { - let index = self.next_index(); - if let Some(index) = index { - let data = self.data_fetcher.fetch(index); - return Some(data); + + fn next_data(&mut self) -> Option { + match self.rx.recv() { + Ok(data) => Some(data), + Err(_) => { + // An error occurred with the channel, + // it is probably closed or has been cancelled + None + } } - None } } -impl<'dataset, D, S, C> Iterator for SingleProcessDataLoaderIter<'dataset, D, S, C> + +impl Iterator for SingleProcessDataLoaderIter where - D: Dataset + Sync, - S: Sampler, - C: Collate, - D::Sample: Send, + CO: Send, { - type Item = C::Output; + type Item = CO; fn next(&mut self) -> Option { let data = self.next_data(); @@ -138,15 +162,17 @@ where None } } -impl<'dataset, D, S, C> IntoIterator for &'dataset DataLoader + +impl IntoIterator for DataLoader where - D: Dataset + Sync, - S: Sampler, - C: Collate, + D: Dataset + Send + Sync + 'static, + S: Sampler + Send + Sync + 'static, + C: Collate + Send + Sync + 'static, D::Sample: Send, + C::Output: Send, { type Item = C::Output; - type IntoIter = SingleProcessDataLoaderIter<'dataset, D, S, C>; + type IntoIter = SingleProcessDataLoaderIter; fn into_iter(self) -> Self::IntoIter { self.iter() @@ -159,12 +185,33 @@ mod tests { use crate::collate::NoOpCollate; use crate::sampler::RandomSampler; use crate::sampler::SequentialSampler; - use crate::Len; use crate::NdarrayDataset; + use crate::{GetSample, Len}; use ndarray::{arr0, array, Array, Array1, Array4, Axis, Ix1, Ix4, Slice}; use ndarray_rand::rand_distr::{Normal, Uniform}; use ndarray_rand::RandomExt; use std::collections::HashMap; + use std::thread::sleep; + use std::time::{Duration, Instant}; + + struct FakeDataset; + + impl Len for FakeDataset { + fn len(&self) -> usize { + 8 + } + } + + impl GetSample for FakeDataset { + type Sample = usize; + + fn get_sample(&self, index: usize) -> Self::Sample { + sleep(Duration::from_millis(100)); + index + } + } + + impl Dataset for FakeDataset {} #[test] fn len() { @@ -216,6 +263,35 @@ mod tests { assert_eq!(iter.next(), Some(vec![String::from("b")])); assert_eq!(iter.next(), None); } + + #[test] + fn prefetching() { + let dataset = FakeDataset; + let dataloader = DataLoader::builder(dataset) + .collate_fn(NoOpCollate) + .batch_size(2) + .prefetch_size(4) + .build(); + + let mut iter = dataloader.iter(); + let start = Instant::now(); + // This sleep execute in parallel with FakeDataset sleep + // Then this sleep does not affect the whole execution time of this test because : + // - Duration <= 400ms (FakeDataset sleep for 100ms sec on each get_sample) + // - prefetch_size >= nb_batch + sleep(Duration::from_millis(400)); + assert_eq!(iter.next(), Some(vec![0, 1])); + assert_eq!(iter.next(), Some(vec![2, 3])); + assert_eq!(iter.next(), Some(vec![4, 5])); + assert_eq!(iter.next(), Some(vec![6, 7])); + assert_eq!(iter.next(), None); + assert_eq!(iter.next(), None); + + let duration = start.elapsed(); + println!("Time elapsed in data loading is: {:?}", duration); + //assert!(duration < Duration::from_millis(850)); + } + #[test] fn collate() { let dataset = vec![1, 2, 3, 4, 5, 6, 7, 8, 9, 10]; diff --git a/src/indexable/fetch/mod.rs b/src/indexable/fetch/mod.rs index d480075..2ec6a13 100644 --- a/src/indexable/fetch/mod.rs +++ b/src/indexable/fetch/mod.rs @@ -2,6 +2,7 @@ use crate::{ collate::{Collate, DefaultCollate}, Dataset, }; +use std::sync::Arc; #[cfg(feature = "rayon")] use crate::THREAD_POOL; @@ -28,21 +29,21 @@ where /// Fetcher for map-style dataset. Simply call the collate function on all the batch of elements. #[derive(Debug)] -pub(crate) struct MapDatasetFetcher<'dataset, D, C = DefaultCollate> +pub(crate) struct MapDatasetFetcher where - D: Dataset + Sync, + D: Dataset, C: Collate, { /// The dataset data will be fetch from. - pub(crate) dataset: &'dataset D, + pub(crate) dataset: Arc, /// The function (generic struct) used to collate data together. - pub(crate) collate_fn: &'dataset C, + pub(crate) collate_fn: Arc, } -impl<'dataset, D, C> Fetcher for MapDatasetFetcher<'dataset, D, C> +impl Fetcher for MapDatasetFetcher where - D: Dataset + Sync, - C: Collate, + D: Dataset + Sync + Send, + C: Collate + Sync + Send, D::Sample: Send, { fn fetch(&mut self, possibly_batched_index: Vec) -> C::Output {