Commit b965b9c

Update references to torchvision (#949)
1 parent 1842b4f commit b965b9c

File tree

8 files changed: +47 -47 lines changed

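Every change in this commit is the same one-line substitution: the aliased import form (import torchvision.transforms as transforms, and likewise for torchvision.datasets) is replaced by the direct form (from torchvision import transforms). Both forms bind the same module object, so no call sites change. A minimal sketch of the equivalence; the alias name transforms_alias is chosen here purely for illustration:

from torchvision import transforms
import torchvision.transforms as transforms_alias

# Both names refer to the same module object, so downstream code is unaffected.
assert transforms is transforms_alias
to_tensor = transforms.ToTensor()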

training/cifar/cifar10_deepspeed.py (+1 -1)

@@ -6,7 +6,7 @@
import torch.nn as nn
import torch.nn.functional as F
import torchvision
-import torchvision.transforms as transforms
+from torchvision import transforms
from deepspeed.accelerator import get_accelerator
from deepspeed.moe.utils import split_params_into_different_moe_groups_for_optimizer

training/cifar/cifar10_tutorial.py (+1 -1)

@@ -57,7 +57,7 @@
"""
import torch
import torchvision
-import torchvision.transforms as transforms
+from torchvision import transforms

########################################################################
# The output of torchvision datasets are PILImage images of range [0, 1].

training/data_efficiency/vit_finetuning/main_imagenet.py (+23 -23)

@@ -19,8 +19,8 @@
import torch.multiprocessing as mp
import torch.utils.data
import torch.utils.data.distributed
-import torchvision.transforms as transforms
-import torchvision.datasets as datasets
+from torchvision import transforms
+from torchvision import datasets
import torchvision.models as models
from torch.utils.data import Subset
import models
@@ -105,7 +105,7 @@ def _get_model(args):
    nchannels = 3
    model = models.__dict__[args.arch](num_classes=nclasses, nchannels=nchannels)
    return model
-
+
def _get_dist_model(gpu, args):
    ngpus_per_node = torch.cuda.device_count()
    if args.distributed:
@@ -149,9 +149,9 @@ def _get_dist_model(gpu, args):
    else:
        model = torch.nn.DataParallel(model).cuda()
    return model
-
+
def main():
-
+
    args = parser.parse_args()

    if args.seed is not None:
@@ -190,7 +190,7 @@ def main():
def main_worker(gpu, ngpus_per_node, args):
    global best_acc1
    global history
-
+
    if args.deepspeed:
        gpu = args.local_rank
        args.gpu = gpu
@@ -205,7 +205,7 @@ def main_worker(gpu, ngpus_per_node, args):
        deepspeed.init_distributed()
    print(f'created model on gpu {gpu}')
    # exit ()
-
+
    # define loss function (criterion), optimizer, and learning rate scheduler
    criterion = nn.CrossEntropyLoss().cuda(args.gpu)

@@ -284,14 +284,14 @@ def main_worker(gpu, ngpus_per_node, args):
        validate(val_loader, model, criterion, args)
        # return
    args.completed_step = 0
-
+
    optimizer = torch.optim.SGD(model.parameters(), args.lr,
                                momentum=args.momentum,
                                weight_decay=args.weight_decay)
-
+
    """Sets the learning rate to the initial LR decayed by 10 every 30 epochs"""
    scheduler = StepLR(optimizer, step_size=int(len(train_loader)*args.epochs//3), gamma=0.1)# None #
-
+

    model, optimizer, _, scheduler = deepspeed.initialize(
        model=model,
@@ -311,17 +311,17 @@ def main_worker(gpu, ngpus_per_node, args):
        time_epoch = time.time() - start_time
        # evaluate on validation set
        top5_val, top1_val, losses_val = validate(val_loader, model, criterion, args)
-        if args.gpu==0:
+        if args.gpu==0:
            history["epoch"].append(epoch)
            history["val_loss"].append(losses_val)
-            history["val_acc1"].append(top1_val)
-            history["val_acc5"].append(top5_val)
+            history["val_acc1"].append(top1_val)
+            history["val_acc5"].append(top5_val)
            history["train_loss"].append(losses_train)
-            history["train_acc1"].append(top1_train)
+            history["train_acc1"].append(top1_train)
            history["train_acc5"].append(top5_train)
-            torch.save(history,f"{args.out_dir}/stat.pt")
+            torch.save(history,f"{args.out_dir}/stat.pt")
        try:
-            print (f'{epoch} epoch at time {time_epoch}s and learning rate {scheduler.get_last_lr()}')
+            print (f'{epoch} epoch at time {time_epoch}s and learning rate {scheduler.get_last_lr()}')
        except:
            print (f'{epoch} epoch at time {time_epoch}s and learning rate {args.lr}')
        print (f"finish epoch {epoch} or iteration {args.completed_step}, train_accuracy is {top1_train}, val_accuracy {top1_val}")
@@ -393,14 +393,14 @@ def train(scheduler, train_loader, model, criterion, optimizer, epoch, args):
        loss.backward()
        optimizer.step()
        scheduler.step()
-
+
        # measure elapsed time
        batch_time.update(time.time() - end)
        end = time.time()

-        if i % args.print_freq == 0 and args.gpu==0:
+        if i % args.print_freq == 0 and args.gpu==0:
            progress.display(i + 1)
-
+
    if args.distributed:
        losses.all_reduce()
        top1.all_reduce()
@@ -432,7 +432,7 @@ def run_validate(loader, base_progress=0):
                batch_time.update(time.time() - end)
                end = time.time()

-                if i % args.print_freq == 0 and args.gpu==0:
+                if i % args.print_freq == 0 and args.gpu==0:
                    progress.display(i + 1)

    batch_time = AverageMeter('Time', ':6.3f', Summary.NONE)
@@ -509,7 +509,7 @@ def all_reduce(self):
    def __str__(self):
        fmtstr = '{name} {val' + self.fmt + '} ({avg' + self.fmt + '})'
        return fmtstr.format(**self.__dict__)
-
+
    def summary(self):
        fmtstr = ''
        if self.summary_type is Summary.NONE:
@@ -522,7 +522,7 @@ def summary(self):
            fmtstr = '{name} {count:.3f}'
        else:
            raise ValueError('invalid summary type %r' % self.summary_type)
-
+
        return fmtstr.format(**self.__dict__)
@@ -536,7 +536,7 @@ def display(self, batch):
        entries = [self.prefix + self.batch_fmtstr.format(batch)]
        entries += [str(meter) for meter in self.meters]
        print ('\t'.join(entries))
-
+
    def display_summary(self):
        entries = [" *"]
        entries += [meter.summary() for meter in self.meters]
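For orientation on the hunk around lines 284-297 above: the plain SGD optimizer and StepLR scheduler built there are handed to deepspeed.initialize, which returns a four-tuple of engine, optimizer, data loader, and scheduler, which is why the script unpacks four values and discards the third. A minimal sketch of that pattern, assuming the process is started by the deepspeed launcher so the distributed environment is already set up; the toy model and config values are illustrative only:

import torch
import deepspeed
from torch.optim.lr_scheduler import StepLR

net = torch.nn.Linear(32, 10)                     # stand-in model for illustration
optimizer = torch.optim.SGD(net.parameters(), lr=0.1, momentum=0.9, weight_decay=1e-4)
scheduler = StepLR(optimizer, step_size=30, gamma=0.1)  # decay LR by 10x every 30 steps

# deepspeed.initialize returns (engine, optimizer, training_dataloader, lr_scheduler);
# the dataloader slot is None here because no training_data is passed.
engine, optimizer, _, scheduler = deepspeed.initialize(
    model=net,
    optimizer=optimizer,
    lr_scheduler=scheduler,
    config={"train_batch_size": 8},
)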

training/data_efficiency/vit_finetuning/utils/get_data.py (+11 -11)

@@ -13,18 +13,18 @@
# limitations under the License.
import torch
import os
-import torchvision.transforms as transforms
-import torchvision.datasets as datasets
+from torchvision import transforms
+from torchvision import datasets

def get_dataset(dataset_name, data_dir, split, rand_fraction=None,clean=False, transform=None, imsize=None, bucket='pytorch-data', **kwargs):

    if dataset_name in [ 'cifar10', 'cifar100']:
-        dataset = globals()[f'get_{dataset_name}'](dataset_name, data_dir, split, imsize=imsize, bucket=bucket, **kwargs)
+        dataset = globals()[f'get_{dataset_name}'](dataset_name, data_dir, split, imsize=imsize, bucket=bucket, **kwargs)
    elif dataset_name in [ 'cifar10vit224', 'cifar100vit224','cifar10vit384', 'cifar100vit384',]:
        imsize = int(dataset_name.split('vit')[-1])
        dataset_name = dataset_name.split('vit')[0]
        #print ('here')
-        dataset = globals()['get_cifar_vit'](dataset_name, data_dir, split, imsize=imsize, bucket=bucket, **kwargs)
+        dataset = globals()['get_cifar_vit'](dataset_name, data_dir, split, imsize=imsize, bucket=bucket, **kwargs)
    else:
        assert 'cifar' in dataset_name
        print (dataset_name)
@@ -59,10 +59,10 @@ def get_transform(split, normalize=None, transform=None, imsize=None, aug='large
    if transform is None:
        if normalize is None:
            if aug == 'large':
-
+
                normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
            else:
-                normalize = transforms.Normalize(mean=[0.4914, 0.4822, 0.4465], std=[0.2023, 0.1994, 0.2010])
+                normalize = transforms.Normalize(mean=[0.4914, 0.4822, 0.4465], std=[0.2023, 0.1994, 0.2010])
        transform = transforms.Compose(get_aug(split, imsize=imsize, aug=aug)
                                       + [transforms.ToTensor(), normalize])
    return transform
@@ -71,7 +71,7 @@ def get_transform(split, normalize=None, transform=None, imsize=None, aug='large
def get_cifar10(dataset_name, data_dir, split, transform=None, imsize=None, bucket='pytorch-data', **kwargs):
    if imsize==224:
        transform = get_transform(split, transform=transform, imsize=imsize, aug='large')
-    else:
+    else:
        transform = get_transform(split, transform=transform, imsize=imsize, aug='small')
    return datasets.CIFAR10(data_dir, train=(split=='train'), transform=transform, download=True, **kwargs)
@@ -88,7 +88,7 @@ def get_cifar100N(dataset_name, data_dir, split, rand_fraction=None,transform=No
    if split=='train':
        return CIFAR100N(root=data_dir, train=(split=='train'), transform=transform, download=True, rand_fraction=rand_fraction)
    else:
-        return datasets.CIFAR100(data_dir, train=(split=='train'), transform=transform, download=True, **kwargs)
+        return datasets.CIFAR100(data_dir, train=(split=='train'), transform=transform, download=True, **kwargs)

def get_cifar_vit(dataset_name, data_dir, split, transform=None, imsize=None, bucket='pytorch-data', **kwargs):
    if imsize==224:
@@ -111,12 +111,12 @@ def get_cifar_vit(dataset_name, data_dir, split, transform=None, imsize=None, bu
        if dataset_name =='cifar10':
            return datasets.CIFAR10(data_dir, train=(split=='train'), transform=transform_data, download=True, **kwargs)
        elif dataset_name =='cifar100':
-
+
            return datasets.CIFAR100(data_dir, train=(split=='train'), transform=transform_data, download=True, **kwargs)
        else:
            assert dataset_name in ['cifar10', 'cifar100']
    else:
-
+
        if split=='train':
            transform_data = transforms.Compose([# transforms.ColorJitter(brightness= 0.4, contrast= 0.4, saturation= 0.4, hue= 0.1),
                                                 transforms.Resize(imsize),
@@ -164,4 +164,4 @@ def get_imagenet_vit(dataset_name, data_dir, split, transform=None, imsize=None,
        #return torch.utils.data.distributed.DistributedSampler(train_dataset)
    else:
        return datasets.ImageFolder(valdir, transform_data)
-        #Ereturn torch.utils.data.distributed.DistributedSampler(val_dataset, shuffle=False, drop_last=True)
+        #Ereturn torch.utils.data.distributed.DistributedSampler(val_dataset, shuffle=False, drop_last=True)
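As context for the get_transform hunk above: the function pairs the selected normalization statistics (ImageNet means and stds on the 'large' path, CIFAR statistics otherwise) with ToTensor in a single transforms.Compose pipeline. A rough, self-contained sketch of that pattern using the new import style; the Resize/CenterCrop steps stand in for get_aug and are illustrative only:

from torchvision import transforms

# Hypothetical stand-in for get_aug(): resize/crop steps chosen for illustration.
aug = [transforms.Resize(256), transforms.CenterCrop(224)]

# CIFAR statistics, as in the 'small' branch of get_transform above.
normalize = transforms.Normalize(mean=[0.4914, 0.4822, 0.4465],
                                 std=[0.2023, 0.1994, 0.2010])

preprocess = transforms.Compose(aug + [transforms.ToTensor(), normalize])
# preprocess(pil_image) yields a normalized float tensor ready for the model.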

training/gan/gan_baseline_train.py (+1 -1)

@@ -3,7 +3,7 @@
import torch.nn as nn
import torch.utils.data
import torchvision.datasets as dset
-import torchvision.transforms as transforms
+from torchvision import transforms
import torchvision.utils as vutils
from torch.utils.tensorboard import SummaryWriter
from time import time

training/gan/gan_deepspeed_train.py (+1 -1)

@@ -3,7 +3,7 @@
import torch.nn as nn
import torch.utils.data
import torchvision.datasets as dset
-import torchvision.transforms as transforms
+from torchvision import transforms
import torchvision.utils as vutils
from torch.utils.tensorboard import SummaryWriter
from time import time

training/imagenet/main.py (+8 -8)

@@ -18,9 +18,9 @@
import torch.optim
import torch.utils.data
import torch.utils.data.distributed
-import torchvision.datasets as datasets
import torchvision.models as models
-import torchvision.transforms as transforms
+from torchvision import transforms
+from torchvision import datasets
from torch.optim.lr_scheduler import StepLR
from torch.utils.data import Subset
@@ -94,7 +94,7 @@ def main():
                      'which can slow down your training considerably! '
                      'You may see unexpected behavior when restarting '
                      'from checkpoints.')
-
+
    if args.gpu is not None:
        warnings.warn('You have chosen a specific GPU. This will completely '
                      'disable data parallelism.')
@@ -112,7 +112,7 @@ def main():
        args.world_size = ngpus_per_node * args.world_size
    t_losses, t_acc1s = main_worker(args.gpu, ngpus_per_node, args)
    #dist.barrier()
-
+
    # Write the losses to an excel file
    if dist.get_rank() ==0:
        all_losses = [torch.empty_like(t_losses) for _ in range(ngpus_per_node)]
@@ -278,7 +278,7 @@ def print_rank_0(msg):
        acc1s[epoch] = acc1

        scheduler.step()
-
+
        # remember best acc@1 and save checkpoint
        is_best = acc1 > best_acc1
        best_acc1 = max(acc1, best_acc1)
@@ -449,7 +449,7 @@ def all_reduce(self):
    def __str__(self):
        fmtstr = '{name} {val' + self.fmt + '} ({avg' + self.fmt + '})'
        return fmtstr.format(**self.__dict__)
-
+
    def summary(self):
        fmtstr = ''
        if self.summary_type is Summary.NONE:
@@ -462,7 +462,7 @@ def summary(self):
            fmtstr = '{name} {count:.3f}'
        else:
            raise ValueError('invalid summary type %r' % self.summary_type)
-
+
        return fmtstr.format(**self.__dict__)
@@ -476,7 +476,7 @@ def display(self, batch):
        entries = [self.prefix + self.batch_fmtstr.format(batch)]
        entries += [str(meter) for meter in self.meters]
        print('\t'.join(entries))
-
+
    def display_summary(self):
        entries = [" *"]
        entries += [meter.summary() for meter in self.meters]

training/pipeline_parallelism/train.py (+1 -1)

@@ -7,7 +7,7 @@
import torch.distributed as dist

import torchvision
-import torchvision.transforms as transforms
+from torchvision import transforms
from torchvision.models import AlexNet
from torchvision.models import vgg19
