Skip to content

Commit 29d0822

Browse files
committed Dec 30, 2020
初步完成多进程版本
1 parent f1fbebd commit 29d0822

10 files changed

+728
-0
lines changed
 

‎libv2/__init__.py

+6
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,6 @@
1+
from .dataset import Dataset
2+
from .sampler import Sampler, BatchSampler, SequentialSampler, RandomSampler
3+
from .dataloader import DataLoader
4+
from .collate import default_collate
5+
6+
__all__ = ['Dataset', 'Sampler', 'BatchSampler', 'SequentialSampler', 'RandomSampler', 'DataLoader', 'default_collate']

‎libv2/collate.py

+20
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,20 @@
1+
# -*- coding: utf-8 -*-
2+
# ======================================================
3+
# @Time : 20-12-26 下午4:42
4+
# @Author : huang ha
5+
# @Email :
6+
# @File : collate.py
7+
# @Comment:
8+
# ======================================================
9+
import torch
10+
11+
12+
def default_collate(batch):
    """Merge a list of samples into one batched value.

    The handling is chosen from the type of the first sample in ``batch``:

    * ``torch.Tensor``: samples are stacked along a new leading dimension.
    * numpy values (arrays or scalars): each sample is converted with
      ``torch.as_tensor`` and the results are stacked.
    * Python ``float``: returned as a ``float64`` tensor.
    * Python ``int`` (including ``bool``): returned as a tensor.
    * ``str`` / ``bytes``: the batch is returned unchanged.
    * mapping: a dict whose values are the collated per-key samples.
    * sequence (e.g. a ``(data, label)`` tuple): a list with one collated
      entry per field.

    Args:
        batch: A non-empty sequence of samples that all share one structure.

    Returns:
        The batched value as described above.

    Raises:
        IndexError: If ``batch`` is empty.
        ValueError: If sequence samples do not all have the same length.
        NotImplementedError: If the sample type is not supported.
    """
    # Imported locally so this function does not depend on the module header.
    from collections.abc import Mapping, Sequence

    elem = batch[0]
    elem_type = type(elem)

    if isinstance(elem, torch.Tensor):
        return torch.stack(batch, 0)
    # The numpy check must come before the float/int checks: numpy's float64
    # scalar is also a Python float subclass.
    if elem_type.__module__ == 'numpy':
        return default_collate([torch.as_tensor(b) for b in batch])
    if isinstance(elem, float):
        return torch.tensor(batch, dtype=torch.float64)
    if isinstance(elem, int):
        return torch.tensor(batch)
    if isinstance(elem, (str, bytes)):
        return batch
    if isinstance(elem, Mapping):
        return {key: default_collate([sample[key] for sample in batch])
                for key in elem}
    if isinstance(elem, Sequence):
        # zip() would silently truncate ragged samples, so check first.
        field_count = len(elem)
        if any(len(sample) != field_count for sample in batch):
            raise ValueError('each sample in the batch must have the same '
                             'number of fields')
        return [default_collate(list(fields)) for fields in zip(*batch)]

    raise NotImplementedError(
        'default_collate does not support samples of type {}'.format(elem_type))

‎libv2/dataloader.py

+61
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,61 @@
1+
from torch.utils.data._utils.collate import default_collate
2+
from .sampler import BatchSampler, SequentialSampler, RandomSampler
3+
from .dataloaderiter import _SingleProcessDataLoaderIter, _MultiProcessingDataLoaderIter
4+
5+
6+
class DataLoader(object):
    """Combine a dataset with a sampling strategy and hand out an iterator.

    Iterating the loader returns ``_SingleProcessDataLoaderIter`` when
    ``num_workers == 0`` and ``_MultiProcessingDataLoaderIter`` otherwise.

    Args:
        dataset: The dataset to draw samples from.
        batch_size: Passed to the default ``BatchSampler``. Must be left at 1
            when ``batch_sampler`` is given.
        shuffle: If true, a ``RandomSampler`` is used when no ``sampler`` is
            supplied; otherwise a ``SequentialSampler`` is used.
        sampler: Custom index sampler. Mutually exclusive with ``shuffle``.
        batch_sampler: Custom sampler yielding batches of indices. Mutually
            exclusive with ``batch_size``, ``shuffle``, ``sampler`` and
            ``drop_last``.
        collate_fn: Merges a list of samples into a batch; defaults to
            ``default_collate``.
        drop_last: Passed to the default ``BatchSampler``.
        num_workers: Selects the iterator class (see above). Must be >= 0.
        timeout: Stored on the loader for the iterator classes. Must be >= 0.
        worker_init_fn: Stored on the loader for the iterator classes.
        prefetch_factor: Stored on the loader for the iterator classes.
        persistent_workers: Stored on the loader for the iterator classes.

    Raises:
        ValueError: If mutually exclusive options are combined, or if
            ``num_workers`` or ``timeout`` is negative.
    """

    def __init__(self, dataset, batch_size=1, shuffle=False, sampler=None,
                 batch_sampler=None, collate_fn=None, drop_last=False,
                 num_workers=0, timeout=0, worker_init_fn=None, prefetch_factor=2, persistent_workers=False):
        # Reject obviously invalid settings before any sampler is built.
        if num_workers < 0:
            raise ValueError('num_workers option should be non-negative; '
                             'use num_workers=0 to disable multiprocessing.')
        if timeout < 0:
            raise ValueError('timeout option should be non-negative')

        self.dataset = dataset

        # The two options conflict: e.g. shuffle=True together with a
        # SequentialSampler would contradict the caller's intent.
        if sampler is not None and shuffle:
            raise ValueError('sampler option is mutually exclusive with '
                             'shuffle')

        if batch_sampler is not None:
            # Once batch_sampler is given, batch_size, shuffle, sampler and
            # drop_last must not be passed, because batch_sampler already
            # covers what those four options do.
            if batch_size != 1 or shuffle or sampler is not None or drop_last:
                raise ValueError('batch_sampler option is mutually exclusive '
                                 'with batch_size, shuffle, sampler, and '
                                 'drop_last')
            batch_size = None
            drop_last = False

        if sampler is None:
            if shuffle:
                sampler = RandomSampler(dataset)
            else:
                sampler = SequentialSampler(dataset)

        # A batch_sampler must always exist; fall back to the default class
        # when the caller did not provide one.
        if batch_sampler is None:
            batch_sampler = BatchSampler(sampler, batch_size, drop_last)

        self.batch_size = batch_size
        self.drop_last = drop_last
        self.sampler = sampler
        # NOTE(review): an iterator (not the sampler itself) is stored here.
        # If the iterator classes consume it directly, a second pass over the
        # loader may yield nothing -- confirm against dataloaderiter.py.
        self.batch_sampler = iter(batch_sampler)

        if collate_fn is None:
            collate_fn = default_collate
        self.collate_fn = collate_fn

        self.num_workers = num_workers
        self.prefetch_factor = prefetch_factor
        self.timeout = timeout
        self.worker_init_fn = worker_init_fn
        # Previously accepted but dropped; keep it available on the loader.
        self.persistent_workers = persistent_workers

    # Alternative way of writing the iterator: delegate to a separate class.
    def _get_iterator(self):
        """Return the iterator object matching ``self.num_workers``."""
        if self.num_workers == 0:
            return _SingleProcessDataLoaderIter(self)
        else:
            return _MultiProcessingDataLoaderIter(self)

    # Return the iterator object.
    def __iter__(self):
        """Return a new data-loader iterator for this loader."""
        return self._get_iterator()
61+

0 commit comments

Comments
 (0)