Skip to content

Commit

Permalink
feat: prefetching
Browse files Browse the repository at this point in the history
  • Loading branch information
AzHicham committed Sep 6, 2023
1 parent 8fa0516 commit 0b6c73d
Show file tree
Hide file tree
Showing 3 changed files with 150 additions and 59 deletions.
20 changes: 17 additions & 3 deletions src/indexable/dataloader/builder.rs
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,8 @@ use crate::{
sampler::{BatchSampler, RandomSampler, Sampler, SequentialSampler},
Dataset,
};
use std::cmp::max;
use std::sync::Arc;

#[cfg(feature = "rayon")]
use crate::THREAD_POOL;
Expand All @@ -29,6 +31,8 @@ where
#[cfg(feature = "rayon")]
/// Number of threads to use.
num_threads: usize,
/// Prefetch buffer size.
prefetch_size: usize,
}

// FIXME: it is odd that we require DefaultCollate even though, in the end, we may not use it
Expand Down Expand Up @@ -56,6 +60,7 @@ where
collate_fn: DefaultCollate,
#[cfg(feature = "rayon")]
num_threads,
prefetch_size: 1,
}
}
}
Expand All @@ -72,7 +77,7 @@ where
}
/// Set the number of elements in a batch.
pub fn batch_size(mut self, batch_size: usize) -> Self {
self.batch_sampler.batch_size = batch_size;
self.batch_sampler.batch_size = max(batch_size, 1);
self
}

Expand All @@ -83,6 +88,12 @@ where
self
}

/// Set the size of the prefetch buffer.
pub fn prefetch_size(mut self, prefetch_size: usize) -> Self {
self.prefetch_size = max(prefetch_size, 1);
self
}

/// Drop the last elements if they don't fit into a batch. For instance, if a dataset has 13
/// samples and a `batch_size` of 5, the last 3 samples will be dropped.
pub fn drop_last(mut self) -> Self {
Expand All @@ -102,6 +113,7 @@ where
collate_fn,
#[cfg(feature = "rayon")]
num_threads: self.num_threads,
prefetch_size: self.prefetch_size,
}
}

Expand All @@ -122,6 +134,7 @@ where
collate_fn: self.collate_fn,
#[cfg(feature = "rayon")]
num_threads: self.num_threads,
prefetch_size: self.prefetch_size,
}
}
/// Create a `Dataloader` from a [`Builder`].
Expand Down Expand Up @@ -153,9 +166,10 @@ where
}

DataLoader {
dataset: self.dataset,
dataset: Arc::new(self.dataset),
batch_sampler: self.batch_sampler,
collate_fn: self.collate_fn,
collate_fn: Arc::new(self.collate_fn),
prefetch_size: self.prefetch_size,
}
}
}
Expand Down
174 changes: 125 additions & 49 deletions src/indexable/dataloader/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -3,9 +3,12 @@
use super::fetch::{Fetcher, MapDatasetFetcher};
use crate::{
collate::{Collate, DefaultCollate},
sampler::{BatchIterator, BatchSampler, Sampler, SequentialSampler},
sampler::{BatchSampler, Sampler, SequentialSampler},
Dataset, Len,
};
use std::sync::mpsc::sync_channel;
use std::sync::{mpsc, Arc};
use std::thread::JoinHandle;

mod builder;
use builder::Builder;
Expand All @@ -23,7 +26,7 @@ use builder::Builder;
///
/// let loader = DataLoader::builder(vec![(0, "hola"), (1, "hello"), (2, "hallo"), (3, "bonjour")]).batch_size(2).shuffle().build();
///
/// for (label, text) in &loader {
/// for (label, text) in loader.iter() {
/// println!("Label {label:?}");
/// println!("Text {text:?}");
/// }
Expand All @@ -32,11 +35,13 @@ use builder::Builder;
#[derive(Debug, Clone, PartialEq, PartialOrd, Hash, Eq, Ord)]
pub struct DataLoader<D, S = SequentialSampler, C = DefaultCollate> {
/// Dataset from which to load the data.
dataset: D,
dataset: Arc<D>,
/// Return a batch of indices at a time.
batch_sampler: BatchSampler<S>,
/// Collate function.
collate_fn: C,
collate_fn: Arc<C>,
/// Prefetch buffer size.
prefetch_size: usize,
}

impl<D> DataLoader<D, SequentialSampler, DefaultCollate>
Expand All @@ -52,14 +57,15 @@ where

impl<D, S, C> DataLoader<D, S, C>
where
D: Dataset + Sync,
S: Sampler,
C: Collate<D::Sample>,
D: Dataset + Sync + Send + 'static,
S: Sampler + Send + Sync + 'static,
C: Collate<D::Sample> + Send + Sync + 'static,
D::Sample: Send,
C::Output: Send,
{
/// Return not owning iterator over the dataloader.
pub fn iter(&self) -> SingleProcessDataLoaderIter<'_, D, S, C> {
SingleProcessDataLoaderIter::new(self)
pub fn iter(&self) -> SingleProcessDataLoaderIter<C::Output> {
SingleProcessDataLoaderIter::<C::Output>::new(self)
}
}

Expand All @@ -77,57 +83,75 @@ where

/// Iterate over the dataloader with a single thread.
#[derive(Debug)]
pub struct SingleProcessDataLoaderIter<'dataset, D, S = SequentialSampler, C = DefaultCollate>
pub struct SingleProcessDataLoaderIter<CO>
where
D: Dataset + Sync,
S: Sampler,
C: Collate<D::Sample>,
CO: Send,
{
/// The batch iterator of this iterator.
sampler_iter: BatchIterator<S::IntoIter>,
/// Number of sample yielded.
num_yielded: u64,
/// Used to fetch the data from the dataset.
data_fetcher: MapDatasetFetcher<'dataset, D, C>,
rx: mpsc::Receiver<CO>,
_thread_handle: JoinHandle<()>,
}

impl<'dataset, D, S, C> SingleProcessDataLoaderIter<'dataset, D, S, C>
impl<CO> SingleProcessDataLoaderIter<CO>
where
D: Dataset + Sync,
S: Sampler,
C: Collate<D::Sample>,
D::Sample: Send,
CO: Send,
{
fn new(loader: &DataLoader<D, S, C>) -> SingleProcessDataLoaderIter<'_, D, S, C> {
/// Build the iterator: clone the (cheap, `Arc`-wrapped) dataset and collate
/// function, then spawn a producer thread that fetches batches ahead of the
/// consumer through a bounded channel of capacity `prefetch_size`.
fn new<D, S, C>(loader: &DataLoader<D, S, C>) -> SingleProcessDataLoaderIter<C::Output>
where
    D: Dataset + Sync + Send + 'static,
    S: Sampler + Send + Sync + 'static,
    C: Collate<D::Sample> + Send + Sync + 'static,
    C::Output: Send,
    D::Sample: Send,
{
    // Bounded channel: `send` blocks once `prefetch_size` batches are
    // queued, which caps the memory held by prefetched data.
    let (tx, rx) = sync_channel(loader.prefetch_size);

    let mut data_fetcher = MapDatasetFetcher {
        dataset: loader.dataset.clone(),
        collate_fn: loader.collate_fn.clone(),
    };
    let batch_sampler = loader.batch_sampler.clone();
    let _thread_handle = std::thread::spawn(move || {
        let mut sampler_iter = batch_sampler.iter();
        // In this dedicated thread we fetch batches and push them into the
        // channel; the loop pauses whenever the buffer is full.
        while let Some(index) = sampler_iter.next() {
            let data = data_fetcher.fetch(index);
            if tx.send(data).is_err() {
                // The receiver (the consuming iterator) was dropped:
                // stop fetching. `tx` is dropped automatically when the
                // closure returns, closing the channel.
                return;
            }
        }
    });

    SingleProcessDataLoaderIter {
        num_yielded: 0,
        rx,
        _thread_handle,
    }
}
fn next_index(&mut self) -> Option<Vec<usize>> {
self.sampler_iter.next()
}
fn next_data(&mut self) -> Option<C::Output> {
let index = self.next_index();
if let Some(index) = index {
let data = self.data_fetcher.fetch(index);
return Some(data);

fn next_data(&mut self) -> Option<CO> {
match self.rx.recv() {
Ok(data) => Some(data),
Err(_) => {
// An error occurred with the channel,
// it is probably closed or has been cancelled
None
}
}
None
}
}
impl<'dataset, D, S, C> Iterator for SingleProcessDataLoaderIter<'dataset, D, S, C>

impl<CO> Iterator for SingleProcessDataLoaderIter<CO>
where
D: Dataset + Sync,
S: Sampler,
C: Collate<D::Sample>,
D::Sample: Send,
CO: Send,
{
type Item = C::Output;
type Item = CO;
fn next(&mut self) -> Option<Self::Item> {
let data = self.next_data();

Expand All @@ -138,15 +162,17 @@ where
None
}
}
impl<'dataset, D, S, C> IntoIterator for &'dataset DataLoader<D, S, C>

impl<D, S, C> IntoIterator for DataLoader<D, S, C>
where
D: Dataset + Sync,
S: Sampler,
C: Collate<D::Sample>,
D: Dataset + Send + Sync + 'static,
S: Sampler + Send + Sync + 'static,
C: Collate<D::Sample> + Send + Sync + 'static,
D::Sample: Send,
C::Output: Send,
{
type Item = C::Output;
type IntoIter = SingleProcessDataLoaderIter<'dataset, D, S, C>;
type IntoIter = SingleProcessDataLoaderIter<C::Output>;

fn into_iter(self) -> Self::IntoIter {
self.iter()
Expand All @@ -159,12 +185,33 @@ mod tests {
use crate::collate::NoOpCollate;
use crate::sampler::RandomSampler;
use crate::sampler::SequentialSampler;
use crate::Len;
use crate::NdarrayDataset;
use crate::{GetSample, Len};
use ndarray::{arr0, array, Array, Array1, Array4, Axis, Ix1, Ix4, Slice};
use ndarray_rand::rand_distr::{Normal, Uniform};
use ndarray_rand::RandomExt;
use std::collections::HashMap;
use std::thread::sleep;
use std::time::{Duration, Instant};

/// Test-only dataset: sample `i` is simply the index `i`, produced after an
/// artificial 100 ms delay so that prefetching has real latency to overlap.
struct FakeDataset;

impl Dataset for FakeDataset {}

impl GetSample for FakeDataset {
    type Sample = usize;

    /// Simulate slow I/O (100 ms per sample), then echo the index back.
    fn get_sample(&self, index: usize) -> Self::Sample {
        sleep(Duration::from_millis(100));
        index
    }
}

impl Len for FakeDataset {
    /// Fixed size of 8 samples — four batches of two in the tests below.
    fn len(&self) -> usize {
        8
    }
}

#[test]
fn len() {
Expand Down Expand Up @@ -216,6 +263,35 @@ mod tests {
assert_eq!(iter.next(), Some(vec![String::from("b")]));
assert_eq!(iter.next(), None);
}

/// Checks that the background fetcher runs ahead of the consumer: with
/// `prefetch_size` at least the number of batches, consumer-side delay
/// overlaps fetching instead of adding to it.
#[test]
fn prefetching() {
    let dataset = FakeDataset;
    let dataloader = DataLoader::builder(dataset)
        .collate_fn(NoOpCollate)
        .batch_size(2)
        .prefetch_size(4)
        .build();

    let mut iter = dataloader.iter();
    let start = Instant::now();
    // While this thread sleeps, the producer thread keeps filling the
    // prefetch buffer (FakeDataset sleeps 100 ms per sample, so batches
    // keep arriving in the background). Because the buffer can hold all
    // 4 batches, this 400 ms pause overlaps the fetching rather than
    // extending the total wall time of the test.
    sleep(Duration::from_millis(400));
    assert_eq!(iter.next(), Some(vec![0, 1]));
    assert_eq!(iter.next(), Some(vec![2, 3]));
    assert_eq!(iter.next(), Some(vec![4, 5]));
    assert_eq!(iter.next(), Some(vec![6, 7]));
    // Exhausted: the producer has dropped its sender, so `next` keeps
    // returning None.
    assert_eq!(iter.next(), None);
    assert_eq!(iter.next(), None);

    let duration = start.elapsed();
    println!("Time elapsed in data loading is: {:?}", duration);
    // Left disabled: wall-clock assertions are flaky on loaded CI machines.
    //assert!(duration < Duration::from_millis(850));
}

#[test]
fn collate() {
let dataset = vec![1, 2, 3, 4, 5, 6, 7, 8, 9, 10];
Expand Down
15 changes: 8 additions & 7 deletions src/indexable/fetch/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@ use crate::{
collate::{Collate, DefaultCollate},
Dataset,
};
use std::sync::Arc;

#[cfg(feature = "rayon")]
use crate::THREAD_POOL;
Expand All @@ -28,21 +29,21 @@ where

/// Fetcher for map-style dataset. Simply call the collate function on all the batch of elements.
#[derive(Debug)]
pub(crate) struct MapDatasetFetcher<'dataset, D, C = DefaultCollate>
pub(crate) struct MapDatasetFetcher<D, C = DefaultCollate>
where
D: Dataset + Sync,
D: Dataset,
C: Collate<D::Sample>,
{
/// The dataset data will be fetch from.
pub(crate) dataset: &'dataset D,
pub(crate) dataset: Arc<D>,
/// The function (generic struct) used to collate data together.
pub(crate) collate_fn: &'dataset C,
pub(crate) collate_fn: Arc<C>,
}

impl<'dataset, D, C> Fetcher<D, C> for MapDatasetFetcher<'dataset, D, C>
impl<D, C> Fetcher<D, C> for MapDatasetFetcher<D, C>
where
D: Dataset + Sync,
C: Collate<D::Sample>,
D: Dataset + Sync + Send,
C: Collate<D::Sample> + Sync + Send,
D::Sample: Send,
{
fn fetch(&mut self, possibly_batched_index: Vec<usize>) -> C::Output {
Expand Down

0 comments on commit 0b6c73d

Please sign in to comment.