Skip to content

Commit

Permalink
simplify traits ??
Browse files Browse the repository at this point in the history
  • Loading branch information
AzHicham committed Sep 5, 2023
1 parent b6ec1e9 commit e68e8ac
Show file tree
Hide file tree
Showing 3 changed files with 25 additions and 31 deletions.
8 changes: 4 additions & 4 deletions Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -36,10 +36,10 @@ once_cell = { version = "1.17.1", optional = true }


[dev-dependencies]
criterion = { version = "^0.5", features = ["html_reports"] }
csv = "^1.1"
image = "^0.24"
nshare = { version = "^0.9", features = ["ndarray", "image"] }
criterion = { version = "0.5.1", features = ["html_reports"] }
csv = "1.1.6"
image = "0.24.3"
nshare = { version = "0.9.0", features = ["ndarray", "image"] }

[[example]]
name = "image"
Expand Down
2 changes: 1 addition & 1 deletion src/indexable/dataloader/builder.rs
Original file line number Diff line number Diff line change
Expand Up @@ -90,7 +90,7 @@ where

/// Set the size of the prefetch buffer.
pub fn prefetch_size(mut self, prefetch_size: usize) -> Self {
self.prefetch_size = prefetch_size;
self.prefetch_size = max(prefetch_size, 1);
self
}

Expand Down
46 changes: 20 additions & 26 deletions src/indexable/dataloader/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,6 @@ use crate::{
sampler::{BatchSampler, Sampler, SequentialSampler},
Dataset, Len,
};
use std::marker::PhantomData;
use std::sync::mpsc::sync_channel;
use std::sync::{mpsc, Arc};
use std::thread::JoinHandle;
Expand Down Expand Up @@ -65,8 +64,8 @@ where
C::Output: Send,
{
/// Return not owning iterator over the dataloader.
pub fn iter(&self) -> SingleProcessDataLoaderIter<D, S, C> {
SingleProcessDataLoaderIter::new(self)
pub fn iter(&self) -> SingleProcessDataLoaderIter<C::Output> {
SingleProcessDataLoaderIter::<C::Output>::new(self)
}
}

Expand All @@ -84,28 +83,28 @@ where

/// Iterate over the dataloader with a single thread.
#[derive(Debug)]
pub struct SingleProcessDataLoaderIter<D, S = SequentialSampler, C = DefaultCollate>
pub struct SingleProcessDataLoaderIter<CO>
where
D: Dataset,
S: Sampler,
C: Collate<D::Sample>,
CO: Send,
{
/// Number of sample yielded.
num_yielded: u64,
rx: mpsc::Receiver<C::Output>,
rx: mpsc::Receiver<CO>,
_thread_handle: JoinHandle<()>,
sampler: PhantomData<S>,
}

impl<D, S, C> SingleProcessDataLoaderIter<D, S, C>
impl<CO> SingleProcessDataLoaderIter<CO>
where
D: Dataset + Send + Sync + 'static,
S: Sampler + Send + Sync + 'static,
C: Collate<D::Sample> + Send + Sync + 'static,
D::Sample: Send,
C::Output: Send,
CO: Send,
{
fn new(loader: &DataLoader<D, S, C>) -> SingleProcessDataLoaderIter<D, S, C> {
fn new<D, S, C>(loader: &DataLoader<D, S, C>) -> SingleProcessDataLoaderIter<C::Output>
where
D: Dataset + Sync + Send + 'static,
S: Sampler + Send + Sync + 'static,
C: Collate<D::Sample> + Send + Sync + 'static,
C::Output: Send,
D::Sample: Send,
{
let (tx, rx) = sync_channel(loader.prefetch_size);

let data_fetcher = MapDatasetFetcher {
Expand Down Expand Up @@ -133,11 +132,10 @@ where
num_yielded: 0,
rx,
_thread_handle,
sampler: PhantomData,
}
}

fn next_data(&mut self) -> Option<C::Output> {
fn next_data(&mut self) -> Option<CO> {
match self.rx.recv() {
Ok(data) => Some(data),
Err(_) => {
Expand All @@ -149,15 +147,11 @@ where
}
}

impl<D, S, C> Iterator for SingleProcessDataLoaderIter<D, S, C>
impl<CO> Iterator for SingleProcessDataLoaderIter<CO>
where
D: Dataset + Send + Sync + 'static,
S: Sampler + Send + Sync + 'static,
C: Collate<D::Sample> + Send + Sync + 'static,
D::Sample: Send,
C::Output: Send,
CO: Send,
{
type Item = C::Output;
type Item = CO;
fn next(&mut self) -> Option<Self::Item> {
let data = self.next_data();

Expand All @@ -178,7 +172,7 @@ where
C::Output: Send,
{
type Item = C::Output;
type IntoIter = SingleProcessDataLoaderIter<D, S, C>;
type IntoIter = SingleProcessDataLoaderIter<C::Output>;

fn into_iter(self) -> Self::IntoIter {
self.iter()
Expand Down

0 comments on commit e68e8ac

Please sign in to comment.