Skip to content

Commit

Permalink
some fix & clean
Browse files Browse the repository at this point in the history
  • Loading branch information
AzHicham committed Sep 5, 2023
1 parent d60a0df commit b6ec1e9
Show file tree
Hide file tree
Showing 5 changed files with 17 additions and 26 deletions.
4 changes: 2 additions & 2 deletions src/collate/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@ pub trait Collate<T> {
// Allow user to specify closure as collate function.
impl<T, F, O> Collate<T> for F
where
F: Fn(Vec<T>) -> O + Clone,
F: Fn(Vec<T>) -> O,
{
type Output = O;
fn collate(&self, batch: Vec<T>) -> Self::Output {
Expand All @@ -36,7 +36,7 @@ where
}

/// Simple Collate that doesn't change the batch of samples.
#[derive(Default, Debug, Clone)]
#[derive(Default, Debug)]
pub struct NoOpCollate;

impl<T> Collate<T> for NoOpCollate {
Expand Down
7 changes: 4 additions & 3 deletions src/indexable/dataloader/builder.rs
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@ use crate::{
sampler::{BatchSampler, RandomSampler, Sampler, SequentialSampler},
Dataset,
};
use std::cmp::max;
use std::sync::Arc;

#[cfg(feature = "rayon")]
Expand Down Expand Up @@ -59,7 +60,7 @@ where
collate_fn: DefaultCollate,
#[cfg(feature = "rayon")]
num_threads,
prefetch_size: 0,
prefetch_size: 1,
}
}
}
Expand All @@ -76,7 +77,7 @@ where
}
/// Set the number of elements in a batch.
pub fn batch_size(mut self, batch_size: usize) -> Self {
self.batch_sampler.batch_size = batch_size;
self.batch_sampler.batch_size = max(batch_size, 1);
self
}

Expand Down Expand Up @@ -112,7 +113,7 @@ where
collate_fn,
#[cfg(feature = "rayon")]
num_threads: self.num_threads,
prefetch_size: 0,
prefetch_size: self.prefetch_size,
}
}

Expand Down
24 changes: 10 additions & 14 deletions src/indexable/dataloader/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@ use builder::Builder;
///
/// let loader = DataLoader::builder(vec![(0, "hola"), (1, "hello"), (2, "hallo"), (3, "bonjour")]).batch_size(2).shuffle().build();
///
/// for (label, text) in loader {
/// for (label, text) in &loader {
/// println!("Label {label:?}");
/// println!("Text {text:?}");
/// }
Expand Down Expand Up @@ -64,7 +64,7 @@ where
D::Sample: Send,
C::Output: Send,
{
/// Return owning iterator over the dataloader.
/// Return not owning iterator over the dataloader.
pub fn iter(&self) -> SingleProcessDataLoaderIter<D, S, C> {
SingleProcessDataLoaderIter::new(self)
}
Expand Down Expand Up @@ -93,7 +93,7 @@ where
/// Number of sample yielded.
num_yielded: u64,
rx: mpsc::Receiver<C::Output>,
thread_handle: JoinHandle<()>,
_thread_handle: JoinHandle<()>,
sampler: PhantomData<S>,
}

Expand All @@ -113,7 +113,7 @@ where
collate_fn: loader.collate_fn.clone(),
};
let batch_sampler = loader.batch_sampler.clone();
let thread_handle = std::thread::spawn(move || {
let _thread_handle = std::thread::spawn(move || {
let mut sampler_iter = batch_sampler.iter();
// In this dedicated thread :
// We fetch the data and push it into the channel TX
Expand All @@ -132,20 +132,17 @@ where
SingleProcessDataLoaderIter {
num_yielded: 0,
rx,
thread_handle,
_thread_handle,
sampler: PhantomData,
}
}

fn next_data(&mut self) -> Option<C::Output> {
match self.rx.recv() {
Ok(data) => Some(data),
Err(_) if self.thread_handle.is_finished() => {
// The thread finished, the the channel is then closed. No more data is available
None
}
Err(_) => {
// An error occurred with the channel
// An error occurred with the channel,
// it is probably closed or has been cancelled
None
}
}
Expand Down Expand Up @@ -252,7 +249,7 @@ mod tests {
let dataset = vec![1, 2, 3, 4];
let dataloader = DataLoader::builder(dataset).batch_size(2).build();

let mut iter = dataloader.clone().iter();
let mut iter = dataloader.iter();
assert_eq!(iter.next(), Some(array![1, 2]));
assert_eq!(iter.next(), Some(array![3, 4]));
assert_eq!(iter.next(), None);
Expand All @@ -279,7 +276,6 @@ mod tests {
let dataloader = DataLoader::builder(dataset)
.collate_fn(NoOpCollate)
.batch_size(2)
.num_threads(2)
.prefetch_size(4)
.build();

Expand All @@ -298,8 +294,8 @@ mod tests {
assert_eq!(iter.next(), None);

let duration = start.elapsed();
println!("Time elapsed in expensive_function() is: {:?}", duration);
assert!(duration < Duration::from_millis(450));
println!("Time elapsed in data loading is: {:?}", duration);
//assert!(duration < Duration::from_millis(850));
}

#[test]
Expand Down
2 changes: 1 addition & 1 deletion src/indexable/fetch/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -62,7 +62,7 @@ where
#[cfg(not(feature = "rayon"))]
let data = possibly_batched_index
.into_iter()
.map(|idx| self.dataset.get_sample(idx))
.map(|idx| self.dataset.get_sample(*idx))
.collect();

self.collate_fn.collate(data)
Expand Down
6 changes: 0 additions & 6 deletions src/indexable/sampler/random_sampler.rs
Original file line number Diff line number Diff line change
Expand Up @@ -76,12 +76,6 @@ impl Iterator for RandomSamplerIter {
}
}

impl ExactSizeIterator for RandomSamplerIter {
fn len(&self) -> usize {
self.indexes.len()
}
}

#[cfg(test)]
mod tests {
use super::*;
Expand Down

0 comments on commit b6ec1e9

Please sign in to comment.