Skip to content

Commit 80ef70e

Browse files
committed
Added DistributedSampler to the train_dataloader
1 parent 78c4001 commit 80ef70e

File tree

1 file changed

+11
-2
lines changed

1 file changed

+11
-2
lines changed

demo-notebooks/guided-demos/pytorch_lightning.py

+11-2
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@
22
import tempfile
33

44
import torch
5-
from torch.utils.data import DataLoader
5+
from torch.utils.data import DataLoader, DistributedSampler
66
from torchvision.models import resnet18
77
from torchvision.datasets import FashionMNIST
88
from torchvision.transforms import ToTensor, Normalize, Compose
@@ -74,10 +74,19 @@ def train_func():
7474
train_data = FashionMNIST(
7575
root=data_dir, train=True, download=True, transform=transform
7676
)
77-
train_dataloader = DataLoader(train_data, batch_size=128, shuffle=True)
7877

7978
# Training
8079
model = ImageClassifier()
80+
81+
sampler = DistributedSampler(
82+
train_data,
83+
num_replicas=ray.train.get_context().get_world_size(),
84+
rank=ray.train.get_context().get_world_rank(),
85+
)
86+
87+
train_dataloader = DataLoader(
88+
train_data, batch_size=128, shuffle=False, sampler=sampler
89+
)
8190
# [1] Configure PyTorch Lightning Trainer.
8291
trainer = pl.Trainer(
8392
max_epochs=10,

0 commit comments

Comments
 (0)