refactor: impl with reference instead of Arc
AzHicham committed Sep 7, 2023
1 parent 0b6c73d commit f74b62f
Showing 4 changed files with 45 additions and 26 deletions.
9 changes: 9 additions & 0 deletions examples/image.rs
@@ -127,6 +127,15 @@ fn main() {
);
}

+ let dataset = FaceLandmarksDataset::new(
+     "examples/image/dataset/face_landmarks.csv",
+     env::current_dir().unwrap().join("examples/image/dataset/"),
+ );
+ let loader = DataLoader::builder(dataset)
+     .batch_size(4)
+     .collate_fn(TorchCollate)
+     .build();

loader
.into_iter()
.enumerate()
5 changes: 2 additions & 3 deletions src/indexable/dataloader/builder.rs
@@ -4,7 +4,6 @@ use crate::{
Dataset,
};
use std::cmp::max;
- use std::sync::Arc;

#[cfg(feature = "rayon")]
use crate::THREAD_POOL;
@@ -166,9 +165,9 @@ where
}

DataLoader {
- dataset: Arc::new(self.dataset),
+ dataset: self.dataset,
batch_sampler: self.batch_sampler,
- collate_fn: Arc::new(self.collate_fn),
+ collate_fn: self.collate_fn,
prefetch_size: self.prefetch_size,
}
}
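With the `Arc` wrappers gone, `build()` can simply move the builder's dataset and collate function into the `DataLoader`. A minimal sketch of that shape, using simplified stand-in types (not the crate's actual builder or generics):

```rust
// Sketch only: hypothetical, simplified stand-ins for the real Builder/DataLoader.
struct Builder<D, C> {
    dataset: D,
    collate_fn: C,
}

struct DataLoader<D, C> {
    dataset: D,
    collate_fn: C,
}

impl<D, C> Builder<D, C> {
    fn build(self) -> DataLoader<D, C> {
        DataLoader {
            dataset: self.dataset,       // moved directly, no Arc::new
            collate_fn: self.collate_fn, // moved directly, no Arc::new
        }
    }
}

fn main() {
    let loader = Builder { dataset: vec![1, 2, 3], collate_fn: |v: Vec<i32>| v.len() }.build();
    assert_eq!((loader.collate_fn)(loader.dataset), 3);
}
```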
50 changes: 31 additions & 19 deletions src/indexable/dataloader/mod.rs
@@ -6,8 +6,8 @@ use crate::{
sampler::{BatchSampler, Sampler, SequentialSampler},
Dataset, Len,
};
+ use std::sync::mpsc;
use std::sync::mpsc::sync_channel;
- use std::sync::{mpsc, Arc};
use std::thread::JoinHandle;

mod builder;
@@ -35,11 +35,11 @@ use builder::Builder;
#[derive(Debug, Clone, PartialEq, PartialOrd, Hash, Eq, Ord)]
pub struct DataLoader<D, S = SequentialSampler, C = DefaultCollate> {
/// Dataset from which to load the data.
- dataset: Arc<D>,
+ dataset: D,
/// Return a batch of indices at a time.
batch_sampler: BatchSampler<S>,
/// Collate function.
- collate_fn: Arc<C>,
+ collate_fn: C,
/// Prefetch buffer size.
prefetch_size: usize,
}
@@ -58,13 +58,13 @@ where
impl<D, S, C> DataLoader<D, S, C>
where
D: Dataset + Sync + Send + 'static,
- S: Sampler + Send + Sync + 'static,
- C: Collate<D::Sample> + Send + Sync + 'static,
+ S: Sampler + Send + 'static,
+ C: Collate<D::Sample> + Send + 'static,
D::Sample: Send,
C::Output: Send,
{
/// Return not owning iterator over the dataloader.
- pub fn iter(&self) -> SingleProcessDataLoaderIter<C::Output> {
+ pub fn iter(self) -> SingleProcessDataLoaderIter<C::Output> {
SingleProcessDataLoaderIter::<C::Output>::new(self)
}
}
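Dropping the `Arc` fields is also why the `Sync` bounds can be relaxed and why `iter` now takes `self`: the dataset and collate function have to be moved into the background thread, which requires `'static` data rather than a `&self` borrow. A std-only sketch of that constraint, with a hypothetical `Owner` type standing in for the loader:

```rust
use std::thread;

// Hypothetical type, not the crate's DataLoader: just enough to show the bound.
struct Owner {
    data: Vec<i32>, // owned directly, no Arc<Vec<i32>>
}

impl Owner {
    // Consuming method: `self.data` can be moved into the worker thread, which
    // needs `'static` data. A `&self` borrow would not satisfy that bound,
    // which is why the previous design shared the fields through `Arc`.
    fn process(self) -> thread::JoinHandle<i32> {
        thread::spawn(move || self.data.iter().sum())
    }
}

fn main() {
    let owner = Owner { data: vec![1, 2, 3, 4] };
    assert_eq!(owner.process().join().unwrap(), 10);
}
```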
@@ -90,30 +90,29 @@ where
/// Number of sample yielded.
num_yielded: u64,
rx: mpsc::Receiver<CO>,
- _thread_handle: JoinHandle<()>,
+ thread_handle: Option<JoinHandle<()>>,
}

impl<CO> SingleProcessDataLoaderIter<CO>
where
CO: Send,
{
- fn new<D, S, C>(loader: &DataLoader<D, S, C>) -> SingleProcessDataLoaderIter<C::Output>
+ fn new<D, S, C>(loader: DataLoader<D, S, C>) -> SingleProcessDataLoaderIter<C::Output>
where
D: Dataset + Sync + Send + 'static,
- S: Sampler + Send + Sync + 'static,
- C: Collate<D::Sample> + Send + Sync + 'static,
+ S: Sampler + Send + 'static,
+ C: Collate<D::Sample> + Send + 'static,
C::Output: Send,
D::Sample: Send,
{
let (tx, rx) = sync_channel(loader.prefetch_size);

let mut data_fetcher = MapDatasetFetcher {
- dataset: loader.dataset.clone(),
- collate_fn: loader.collate_fn.clone(),
+ dataset: loader.dataset,
+ collate_fn: loader.collate_fn,
};
- let batch_sampler = loader.batch_sampler.clone();
- let _thread_handle = std::thread::spawn(move || {
- let mut sampler_iter = batch_sampler.iter();
+ let thread_handle = std::thread::spawn(move || {
+ let mut sampler_iter = loader.batch_sampler.iter();
// In this dedicated thread :
// We fetch the data and push it into the channel TX
// the loop will pause if there is no more space in the channel/buffer
@@ -131,7 +130,7 @@ where
SingleProcessDataLoaderIter {
num_yielded: 0,
rx,
- _thread_handle,
+ thread_handle: Some(thread_handle),
}
}
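The prefetch scheme itself is unchanged: a `sync_channel` with capacity `prefetch_size` gives backpressure, so the producer thread blocks on `send` once that many batches are already buffered. A simplified std-only sketch of the pattern (the batching logic is assumed, not taken from the crate):

```rust
use std::sync::mpsc::sync_channel;
use std::thread;

fn main() {
    let prefetch_size = 2;
    // Bounded channel: at most `prefetch_size` batches sit in the buffer.
    let (tx, rx) = sync_channel::<Vec<i32>>(prefetch_size);

    let producer = thread::spawn(move || {
        let data: Vec<i32> = (0..10).collect();
        for batch in data.chunks(2) {
            // `send` blocks while the buffer is full, pausing the producer.
            if tx.send(batch.to_vec()).is_err() {
                break; // consumer hung up
            }
        }
    });

    for batch in rx {
        // Consume batches at the reader's own pace; the producer stays at
        // most `prefetch_size` batches ahead.
        println!("{batch:?}");
    }
    producer.join().unwrap();
}
```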

@@ -147,6 +146,19 @@
}
}

+ impl<CO> Drop for SingleProcessDataLoaderIter<CO>
+ where
+     CO: Send,
+ {
+     fn drop(&mut self) {
+         if let Some(thread_handle) = self.thread_handle.take() {
+             // This call may deadlock
+             // TODO: Find a way to stop the background thread before joining
+             let _ = thread_handle.join();
+         };
+     }
+ }
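On the TODO above: one possible way to make this join safe, shown here only as a sketch and not as what the crate does, is to drop the `Receiver` before joining. A producer blocked on a full buffer then gets an `Err` from `send` and can exit, provided it checks the send result as the worker below does:

```rust
use std::sync::mpsc::{sync_channel, Receiver};
use std::thread::JoinHandle;

// Hypothetical iterator type, simplified to the two fields that matter here.
struct Iter {
    rx: Option<Receiver<i32>>, // Option so Drop can close the channel early
    handle: Option<JoinHandle<()>>,
}

impl Drop for Iter {
    fn drop(&mut self) {
        // Closing the receiving side first unblocks a producer stuck in `send`...
        drop(self.rx.take());
        // ...so this join cannot wait forever.
        if let Some(handle) = self.handle.take() {
            let _ = handle.join();
        }
    }
}

fn main() {
    let (tx, rx) = sync_channel(1);
    let handle = std::thread::spawn(move || {
        for i in 0..100 {
            if tx.send(i).is_err() {
                break; // receiver gone: stop producing instead of blocking forever
            }
        }
    });
    let iter = Iter { rx: Some(rx), handle: Some(handle) };
    drop(iter); // no deadlock, even though the buffer is full
}
```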

impl<CO> Iterator for SingleProcessDataLoaderIter<CO>
where
CO: Send,
@@ -166,8 +178,8 @@ where
impl<D, S, C> IntoIterator for DataLoader<D, S, C>
where
D: Dataset + Send + Sync + 'static,
- S: Sampler + Send + Sync + 'static,
- C: Collate<D::Sample> + Send + Sync + 'static,
+ S: Sampler + Send + 'static,
+ C: Collate<D::Sample> + Send + 'static,
D::Sample: Send,
C::Output: Send,
{
@@ -243,7 +255,7 @@ mod tests {
let dataset = vec![1, 2, 3, 4];
let dataloader = DataLoader::builder(dataset).batch_size(2).build();

- let mut iter = dataloader.iter();
+ let mut iter = dataloader.clone().iter();
assert_eq!(iter.next(), Some(array![1, 2]));
assert_eq!(iter.next(), Some(array![3, 4]));
assert_eq!(iter.next(), None);
7 changes: 3 additions & 4 deletions src/indexable/fetch/mod.rs
@@ -2,7 +2,6 @@ use crate::{
collate::{Collate, DefaultCollate},
Dataset,
};
- use std::sync::Arc;

#[cfg(feature = "rayon")]
use crate::THREAD_POOL;
@@ -35,15 +34,15 @@ where
C: Collate<D::Sample>,
{
/// The dataset data will be fetch from.
- pub(crate) dataset: Arc<D>,
+ pub(crate) dataset: D,
/// The function (generic struct) used to collate data together.
- pub(crate) collate_fn: Arc<C>,
+ pub(crate) collate_fn: C,
}

impl<D, C> Fetcher<D, C> for MapDatasetFetcher<D, C>
where
D: Dataset + Sync + Send,
- C: Collate<D::Sample> + Sync + Send,
+ C: Collate<D::Sample> + Send,
D::Sample: Send,
{
fn fetch(&mut self, possibly_batched_index: Vec<usize>) -> C::Output {
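The `fetch` body is cut off by this hunk; conceptually it gathers the samples selected for one batch from the now-owned dataset and passes them to the owned collate function. A simplified std-only sketch with hypothetical names (the real fetcher is generic over `Dataset` and `Collate`):

```rust
// Hypothetical helper, not the crate's generics: gather the samples picked by the
// batch sampler from the owned dataset, then hand them to the collate function.
fn fetch_batch<T: Clone, O>(dataset: &[T], indices: &[usize], collate: impl Fn(Vec<T>) -> O) -> O {
    let samples: Vec<T> = indices.iter().map(|&i| dataset[i].clone()).collect();
    collate(samples)
}

fn main() {
    let dataset = vec![10, 20, 30, 40];
    // Toy collate: sum the batch. A real collate would stack samples into a tensor-like batch.
    let batch = fetch_batch(&dataset, &[0, 2], |v| v.into_iter().sum::<i32>());
    assert_eq!(batch, 40);
}
```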
