Skip to content

Commit

Permalink
[WIP] Prefetching feature
Browse files Browse the repository at this point in the history
  • Loading branch information
AzHicham committed Sep 1, 2023
1 parent b9b5025 commit fb3d302
Show file tree
Hide file tree
Showing 3 changed files with 91 additions and 40 deletions.
10 changes: 10 additions & 0 deletions examples/image.rs
Original file line number Diff line number Diff line change
Expand Up @@ -127,6 +127,16 @@ fn main() {
);
}

// Find a way to not clone the dataloader
let dataset = FaceLandmarksDataset::new(
"examples/image/dataset/face_landmarks.csv",
env::current_dir().unwrap().join("examples/image/dataset/"),
);
let loader = DataLoader::builder(dataset)
.batch_size(4)
.collate_fn(TorchCollate)
.build();

loader
.into_iter()
.enumerate()
Expand Down
12 changes: 12 additions & 0 deletions src/indexable/dataloader/builder.rs
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,8 @@ where
#[cfg(feature = "rayon")]
/// Number of threads to use.
num_threads: usize,
/// Prefetch buffer size.
prefetch_size: usize,
}

// FIXME: kind of strange that we require DefaultCollate even if, in the end, we may not use it
Expand Down Expand Up @@ -56,6 +58,7 @@ where
collate_fn: DefaultCollate,
#[cfg(feature = "rayon")]
num_threads,
prefetch_size: 0,
}
}
}
Expand Down Expand Up @@ -83,6 +86,12 @@ where
self
}

/// Set the size of the prefetch buffer.
pub fn prefetch_size(mut self, prefetch_size: usize) -> Self {
self.prefetch_size = prefetch_size;
self
}

/// Drop the last elements if they don't fit into a batch. For instance, if a dataset has 13
/// samples and a `batch_size` of 5, the last 3 samples will be dropped.
pub fn drop_last(mut self) -> Self {
Expand All @@ -102,6 +111,7 @@ where
collate_fn,
#[cfg(feature = "rayon")]
num_threads: self.num_threads,
prefetch_size: 0,
}
}

Expand All @@ -122,6 +132,7 @@ where
collate_fn: self.collate_fn,
#[cfg(feature = "rayon")]
num_threads: self.num_threads,
prefetch_size: self.prefetch_size,
}
}
/// Create a `Dataloader` from a [`Builder`].
Expand Down Expand Up @@ -156,6 +167,7 @@ where
dataset: self.dataset,
batch_sampler: self.batch_sampler,
collate_fn: self.collate_fn,
prefetch_size: self.prefetch_size,
}
}
}
Expand Down
109 changes: 69 additions & 40 deletions src/indexable/dataloader/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -3,9 +3,13 @@
use super::fetch::{Fetcher, MapDatasetFetcher};
use crate::{
collate::{Collate, DefaultCollate},
sampler::{BatchIterator, BatchSampler, Sampler, SequentialSampler},
sampler::{BatchSampler, Sampler, SequentialSampler},
Dataset, Len,
};
use std::marker::PhantomData;
use std::sync::mpsc;
use std::sync::mpsc::sync_channel;
use std::thread::JoinHandle;

mod builder;
use builder::Builder;
Expand Down Expand Up @@ -37,6 +41,8 @@ pub struct DataLoader<D, S = SequentialSampler, C = DefaultCollate> {
batch_sampler: BatchSampler<S>,
/// Collate function.
collate_fn: C,
/// Prefetch buffer size.
prefetch_size: usize,
}

impl<D> DataLoader<D, SequentialSampler, DefaultCollate>
Expand All @@ -52,13 +58,15 @@ where

impl<D, S, C> DataLoader<D, S, C>
where
D: Dataset + Sync,
S: Sampler,
C: Collate<D::Sample>,
D: Dataset + Sync + Send + 'static,
S: Sampler + Send + 'static,
C: Collate<D::Sample> + Send + 'static,
D::Sample: Send,
C::Output: Send,
{
/// Return not owning iterator over the dataloader.
pub fn iter(&self) -> SingleProcessDataLoaderIter<'_, D, S, C> {
/// Return an owning iterator over the dataloader.
///
/// Consumes `self`: the whole dataloader is moved into a background thread
/// that prefetches batches into a bounded channel (see
/// `SingleProcessDataLoaderIter::new`).
/// TODO: Find a way to not consume the Dataloader
pub fn iter(self) -> SingleProcessDataLoaderIter<D, S, C> {
    SingleProcessDataLoaderIter::new(self)
}
}
Expand All @@ -77,55 +85,62 @@ where

/// Iterate over the dataloader with a single thread.
#[derive(Debug)]
pub struct SingleProcessDataLoaderIter<'dataset, D, S = SequentialSampler, C = DefaultCollate>
pub struct SingleProcessDataLoaderIter<D, S = SequentialSampler, C = DefaultCollate>
where
D: Dataset + Sync,
S: Sampler,
C: Collate<D::Sample>,
{
/// The batch iterator of this iterator.
sampler_iter: BatchIterator<S::IntoIter>,
/// Number of sample yielded.
num_yielded: u64,
/// Used to fetch the data from the dataset.
data_fetcher: MapDatasetFetcher<'dataset, D, C>,
rx: mpsc::Receiver<C::Output>,
_thread_handle: JoinHandle<()>,
sampler: PhantomData<S>,
}

impl<'dataset, D, S, C> SingleProcessDataLoaderIter<'dataset, D, S, C>
impl<D, S, C> SingleProcessDataLoaderIter<D, S, C>
where
D: Dataset + Sync,
S: Sampler,
C: Collate<D::Sample>,
D: Dataset + Sync + Send + 'static,
S: Sampler + Send + 'static,
C: Collate<D::Sample> + Send + 'static,
D::Sample: Send,
C::Output: Send,
{
fn new(loader: &DataLoader<D, S, C>) -> SingleProcessDataLoaderIter<'_, D, S, C> {
SingleProcessDataLoaderIter {
sampler_iter: loader.batch_sampler.iter(),
num_yielded: 0,
data_fetcher: MapDatasetFetcher {
fn new(loader: DataLoader<D, S, C>) -> SingleProcessDataLoaderIter<D, S, C> {
let (tx, rx) = sync_channel(loader.prefetch_size);
let _thread_handle = std::thread::spawn(move || {
let loader = loader;
let mut data_fetcher = MapDatasetFetcher {
dataset: &loader.dataset,
collate_fn: &loader.collate_fn,
},
};
let mut sampler_iter = loader.batch_sampler.iter();
while let Some(index) = sampler_iter.next() {
let data = data_fetcher.fetch(index);
tx.send(data).expect("Cannot send data to channel");
}
});

SingleProcessDataLoaderIter {
num_yielded: 0,
rx,
_thread_handle,
sampler: PhantomData::default(),
}
}
fn next_index(&mut self) -> Option<Vec<usize>> {
self.sampler_iter.next()
}

fn next_data(&mut self) -> Option<C::Output> {
let index = self.next_index();
if let Some(index) = index {
let data = self.data_fetcher.fetch(index);
return Some(data);
}
None
self.rx.recv().ok()
}
}
impl<'dataset, D, S, C> Iterator for SingleProcessDataLoaderIter<'dataset, D, S, C>

impl<D, S, C> Iterator for SingleProcessDataLoaderIter<D, S, C>
where
D: Dataset + Sync,
S: Sampler,
C: Collate<D::Sample>,
D: Dataset + Sync + Send + 'static,
S: Sampler + Send + 'static,
C: Collate<D::Sample> + Send + 'static,
D::Sample: Send,
C::Output: Send,
{
type Item = C::Output;
fn next(&mut self) -> Option<Self::Item> {
Expand All @@ -138,15 +153,17 @@ where
None
}
}
impl<'dataset, D, S, C> IntoIterator for &'dataset DataLoader<D, S, C>

impl<D, S, C> IntoIterator for DataLoader<D, S, C>
where
D: Dataset + Sync,
S: Sampler,
C: Collate<D::Sample>,
D: Dataset + Sync + Send + 'static,
S: Sampler + Send + 'static,
C: Collate<D::Sample> + Send + 'static,
D::Sample: Send,
C::Output: Send,
{
type Item = C::Output;
type IntoIter = SingleProcessDataLoaderIter<'dataset, D, S, C>;
type IntoIter = SingleProcessDataLoaderIter<D, S, C>;

fn into_iter(self) -> Self::IntoIter {
self.iter()
Expand Down Expand Up @@ -196,7 +213,7 @@ mod tests {
let dataset = vec![1, 2, 3, 4];
let dataloader = DataLoader::builder(dataset).batch_size(2).build();

let mut iter = dataloader.iter();
let mut iter = dataloader.clone().iter();
assert_eq!(iter.next(), Some(array![1, 2]));
assert_eq!(iter.next(), Some(array![3, 4]));
assert_eq!(iter.next(), None);
Expand All @@ -216,6 +233,18 @@ mod tests {
assert_eq!(iter.next(), Some(vec![String::from("b")]));
assert_eq!(iter.next(), None);
}

#[test]
fn one_dimension_basic_string_with_prefetching() {
    let dataset: Vec<String> = ["a", "b"].iter().map(|s| s.to_string()).collect();
    let loader = DataLoader::builder(dataset).prefetch_size(10).build();

    let mut batches = loader.iter();
    assert_eq!(batches.next(), Some(vec!["a".to_string()]));
    assert_eq!(batches.next(), Some(vec!["b".to_string()]));
    assert_eq!(batches.next(), None);
}

#[test]
fn collate() {
let dataset = vec![1, 2, 3, 4, 5, 6, 7, 8, 9, 10];
Expand Down

0 comments on commit fb3d302

Please sign in to comment.