13
13
# limitations under the License.
14
14
import torch
15
15
import os
16
- import torchvision . transforms as transforms
17
- import torchvision . datasets as datasets
16
+ from torchvision import transforms
17
+ from torchvision import datasets
18
18
19
19
def get_dataset (dataset_name , data_dir , split , rand_fraction = None ,clean = False , transform = None , imsize = None , bucket = 'pytorch-data' , ** kwargs ):
20
20
21
21
if dataset_name in [ 'cifar10' , 'cifar100' ]:
22
- dataset = globals ()[f'get_{ dataset_name } ' ](dataset_name , data_dir , split , imsize = imsize , bucket = bucket , ** kwargs )
22
+ dataset = globals ()[f'get_{ dataset_name } ' ](dataset_name , data_dir , split , imsize = imsize , bucket = bucket , ** kwargs )
23
23
elif dataset_name in [ 'cifar10vit224' , 'cifar100vit224' ,'cifar10vit384' , 'cifar100vit384' ,]:
24
24
imsize = int (dataset_name .split ('vit' )[- 1 ])
25
25
dataset_name = dataset_name .split ('vit' )[0 ]
26
26
#print ('here')
27
- dataset = globals ()['get_cifar_vit' ](dataset_name , data_dir , split , imsize = imsize , bucket = bucket , ** kwargs )
27
+ dataset = globals ()['get_cifar_vit' ](dataset_name , data_dir , split , imsize = imsize , bucket = bucket , ** kwargs )
28
28
else :
29
29
assert 'cifar' in dataset_name
30
30
print (dataset_name )
@@ -59,10 +59,10 @@ def get_transform(split, normalize=None, transform=None, imsize=None, aug='large
59
59
if transform is None :
60
60
if normalize is None :
61
61
if aug == 'large' :
62
-
62
+
63
63
normalize = transforms .Normalize (mean = [0.485 , 0.456 , 0.406 ], std = [0.229 , 0.224 , 0.225 ])
64
64
else :
65
- normalize = transforms .Normalize (mean = [0.4914 , 0.4822 , 0.4465 ], std = [0.2023 , 0.1994 , 0.2010 ])
65
+ normalize = transforms .Normalize (mean = [0.4914 , 0.4822 , 0.4465 ], std = [0.2023 , 0.1994 , 0.2010 ])
66
66
transform = transforms .Compose (get_aug (split , imsize = imsize , aug = aug )
67
67
+ [transforms .ToTensor (), normalize ])
68
68
return transform
@@ -71,7 +71,7 @@ def get_transform(split, normalize=None, transform=None, imsize=None, aug='large
71
71
def get_cifar10 (dataset_name , data_dir , split , transform = None , imsize = None , bucket = 'pytorch-data' , ** kwargs ):
72
72
if imsize == 224 :
73
73
transform = get_transform (split , transform = transform , imsize = imsize , aug = 'large' )
74
- else :
74
+ else :
75
75
transform = get_transform (split , transform = transform , imsize = imsize , aug = 'small' )
76
76
return datasets .CIFAR10 (data_dir , train = (split == 'train' ), transform = transform , download = True , ** kwargs )
77
77
@@ -88,7 +88,7 @@ def get_cifar100N(dataset_name, data_dir, split, rand_fraction=None,transform=No
88
88
if split == 'train' :
89
89
return CIFAR100N (root = data_dir , train = (split == 'train' ), transform = transform , download = True , rand_fraction = rand_fraction )
90
90
else :
91
- return datasets .CIFAR100 (data_dir , train = (split == 'train' ), transform = transform , download = True , ** kwargs )
91
+ return datasets .CIFAR100 (data_dir , train = (split == 'train' ), transform = transform , download = True , ** kwargs )
92
92
93
93
def get_cifar_vit (dataset_name , data_dir , split , transform = None , imsize = None , bucket = 'pytorch-data' , ** kwargs ):
94
94
if imsize == 224 :
@@ -111,12 +111,12 @@ def get_cifar_vit(dataset_name, data_dir, split, transform=None, imsize=None, bu
111
111
if dataset_name == 'cifar10' :
112
112
return datasets .CIFAR10 (data_dir , train = (split == 'train' ), transform = transform_data , download = True , ** kwargs )
113
113
elif dataset_name == 'cifar100' :
114
-
114
+
115
115
return datasets .CIFAR100 (data_dir , train = (split == 'train' ), transform = transform_data , download = True , ** kwargs )
116
116
else :
117
117
assert dataset_name in ['cifar10' , 'cifar100' ]
118
118
else :
119
-
119
+
120
120
if split == 'train' :
121
121
transform_data = transforms .Compose ([# transforms.ColorJitter(brightness= 0.4, contrast= 0.4, saturation= 0.4, hue= 0.1),
122
122
transforms .Resize (imsize ),
@@ -164,4 +164,4 @@ def get_imagenet_vit(dataset_name, data_dir, split, transform=None, imsize=None,
164
164
#return torch.utils.data.distributed.DistributedSampler(train_dataset)
165
165
else :
166
166
return datasets .ImageFolder (valdir , transform_data )
167
- #Ereturn torch.utils.data.distributed.DistributedSampler(val_dataset, shuffle=False, drop_last=True)
167
+ #Ereturn torch.utils.data.distributed.DistributedSampler(val_dataset, shuffle=False, drop_last=True)
0 commit comments