diff --git a/examples/image.rs b/examples/image.rs index db029bf..e54e2f4 100755 --- a/examples/image.rs +++ b/examples/image.rs @@ -127,6 +127,15 @@ fn main() { ); } + let dataset = FaceLandmarksDataset::new( + "examples/image/dataset/face_landmarks.csv", + env::current_dir().unwrap().join("examples/image/dataset/"), + ); + let loader = DataLoader::builder(dataset) + .batch_size(4) + .collate_fn(TorchCollate) + .build(); + loader .into_iter() .enumerate() diff --git a/src/indexable/dataloader/builder.rs b/src/indexable/dataloader/builder.rs index eac93b6..b7a04e8 100644 --- a/src/indexable/dataloader/builder.rs +++ b/src/indexable/dataloader/builder.rs @@ -4,7 +4,6 @@ use crate::{ Dataset, }; use std::cmp::max; -use std::sync::Arc; #[cfg(feature = "rayon")] use crate::THREAD_POOL; @@ -166,9 +165,9 @@ where } DataLoader { - dataset: Arc::new(self.dataset), + dataset: self.dataset, batch_sampler: self.batch_sampler, - collate_fn: Arc::new(self.collate_fn), + collate_fn: self.collate_fn, prefetch_size: self.prefetch_size, } } diff --git a/src/indexable/dataloader/mod.rs b/src/indexable/dataloader/mod.rs index ac5cbe3..419b2ba 100644 --- a/src/indexable/dataloader/mod.rs +++ b/src/indexable/dataloader/mod.rs @@ -6,8 +6,8 @@ use crate::{ sampler::{BatchSampler, Sampler, SequentialSampler}, Dataset, Len, }; +use std::sync::mpsc; use std::sync::mpsc::sync_channel; -use std::sync::{mpsc, Arc}; use std::thread::JoinHandle; mod builder; @@ -35,11 +35,11 @@ use builder::Builder; #[derive(Debug, Clone, PartialEq, PartialOrd, Hash, Eq, Ord)] pub struct DataLoader { /// Dataset from which to load the data. - dataset: Arc, + dataset: D, /// Return a batch of indices at a time. batch_sampler: BatchSampler, /// Collate function. - collate_fn: Arc, + collate_fn: C, /// Prefetch buffer size. prefetch_size: usize, } @@ -58,13 +58,13 @@ where impl DataLoader where D: Dataset + Sync + Send + 'static, - S: Sampler + Send + Sync + 'static, - C: Collate + Send + Sync + 'static, + S: Sampler + Send + 'static, + C: Collate + Send + 'static, D::Sample: Send, C::Output: Send, { /// Return not owning iterator over the dataloader. - pub fn iter(&self) -> SingleProcessDataLoaderIter { + pub fn iter(self) -> SingleProcessDataLoaderIter { SingleProcessDataLoaderIter::::new(self) } } @@ -90,30 +90,29 @@ where /// Number of sample yielded. num_yielded: u64, rx: mpsc::Receiver, - _thread_handle: JoinHandle<()>, + thread_handle: Option>, } impl SingleProcessDataLoaderIter where CO: Send, { - fn new(loader: &DataLoader) -> SingleProcessDataLoaderIter + fn new(loader: DataLoader) -> SingleProcessDataLoaderIter where D: Dataset + Sync + Send + 'static, - S: Sampler + Send + Sync + 'static, - C: Collate + Send + Sync + 'static, + S: Sampler + Send + 'static, + C: Collate + Send + 'static, C::Output: Send, D::Sample: Send, { let (tx, rx) = sync_channel(loader.prefetch_size); let mut data_fetcher = MapDatasetFetcher { - dataset: loader.dataset.clone(), - collate_fn: loader.collate_fn.clone(), + dataset: loader.dataset, + collate_fn: loader.collate_fn, }; - let batch_sampler = loader.batch_sampler.clone(); - let _thread_handle = std::thread::spawn(move || { - let mut sampler_iter = batch_sampler.iter(); + let thread_handle = std::thread::spawn(move || { + let mut sampler_iter = loader.batch_sampler.iter(); // In this dedicated thread : // We fetch the data and push it into the channel TX // the loop will pause if there is no more space in the channel/buffer @@ -131,7 +130,7 @@ where SingleProcessDataLoaderIter { num_yielded: 0, rx, - _thread_handle, + thread_handle: Some(thread_handle), } } @@ -147,6 +146,19 @@ where } } +impl Drop for SingleProcessDataLoaderIter +where + CO: Send, +{ + fn drop(&mut self) { + if let Some(thread_handle) = self.thread_handle.take() { + // This call may deadlock + // TODO: Find a way to stop the background thread before joining + let _ = thread_handle.join(); + }; + } +} + impl Iterator for SingleProcessDataLoaderIter where CO: Send, @@ -166,8 +178,8 @@ where impl IntoIterator for DataLoader where D: Dataset + Send + Sync + 'static, - S: Sampler + Send + Sync + 'static, - C: Collate + Send + Sync + 'static, + S: Sampler + Send + 'static, + C: Collate + Send + 'static, D::Sample: Send, C::Output: Send, { @@ -243,7 +255,7 @@ mod tests { let dataset = vec![1, 2, 3, 4]; let dataloader = DataLoader::builder(dataset).batch_size(2).build(); - let mut iter = dataloader.iter(); + let mut iter = dataloader.clone().iter(); assert_eq!(iter.next(), Some(array![1, 2])); assert_eq!(iter.next(), Some(array![3, 4])); assert_eq!(iter.next(), None); diff --git a/src/indexable/fetch/mod.rs b/src/indexable/fetch/mod.rs index 2ec6a13..0cae373 100644 --- a/src/indexable/fetch/mod.rs +++ b/src/indexable/fetch/mod.rs @@ -2,7 +2,6 @@ use crate::{ collate::{Collate, DefaultCollate}, Dataset, }; -use std::sync::Arc; #[cfg(feature = "rayon")] use crate::THREAD_POOL; @@ -35,15 +34,15 @@ where C: Collate, { /// The dataset data will be fetch from. - pub(crate) dataset: Arc, + pub(crate) dataset: D, /// The function (generic struct) used to collate data together. - pub(crate) collate_fn: Arc, + pub(crate) collate_fn: C, } impl Fetcher for MapDatasetFetcher where D: Dataset + Sync + Send, - C: Collate + Sync + Send, + C: Collate + Send, D::Sample: Send, { fn fetch(&mut self, possibly_batched_index: Vec) -> C::Output {