Skip to content

Commit b6c1ba3

Browse files
committed
Fixes run on TPU
- applied black - refactored trainers
1 parent 15fba79 commit b6c1ba3

File tree

12 files changed

+714
-371
lines changed

12 files changed

+714
-371
lines changed

README.md

+1
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@
33
Experiments with "FixMatch" on Cifar10 dataset.
44

55
Based on ["FixMatch: Simplifying Semi-Supervised Learning withConsistency and Confidence"](https://arxiv.org/abs/2001.07685)
6+
and its official [code](https://github.com/google-research/fixmatch).
67

78
## Requirements
89

ctaugment/__init__.py

+14-5
Original file line numberDiff line numberDiff line change
@@ -5,14 +5,15 @@
55

66

77
class StorableCTAugment(CTAugment):
8-
98
def load_state_dict(self, state):
109
for k in ["decay", "depth", "th", "rates"]:
1110
assert k in state, "{} not in {}".format(k, state.keys())
1211
setattr(self, k, state[k])
1312

1413
def state_dict(self):
15-
return OrderedDict([(k, getattr(self, k)) for k in ["decay", "depth", "th", "rates"]])
14+
return OrderedDict(
15+
[(k, getattr(self, k)) for k in ["decay", "depth", "th", "rates"]]
16+
)
1617

1718

1819
def get_default_cta():
@@ -32,9 +33,17 @@ def deserialize(policy_str):
3233

3334

3435
def stats(cta):
35-
return '\n'.join('%-16s %s' % (k, ' / '.join(' '.join('%.2f' % x for x in cta.rate_to_p(rate))
36-
for rate in cta.rates[k]))
37-
for k in sorted(OPS.keys()))
36+
return "\n".join(
37+
"%-16s %s"
38+
% (
39+
k,
40+
" / ".join(
41+
" ".join("%.2f" % x for x in cta.rate_to_p(rate))
42+
for rate in cta.rates[k]
43+
),
44+
)
45+
for k in sorted(OPS.keys())
46+
)
3847

3948

4049
def interleave(x, batch, inverse=False):

ctaugment/ctaugment.py

+26-8
Original file line numberDiff line numberDiff line change
@@ -22,8 +22,8 @@
2222

2323

2424
OPS = {}
25-
OP = namedtuple('OP', ('f', 'bins'))
26-
Sample = namedtuple('Sample', ('train', 'probe'))
25+
OP = namedtuple("OP", ("f", "bins"))
26+
Sample = namedtuple("Sample", ("train", "probe"))
2727

2828

2929
def register(*bins):
@@ -41,7 +41,7 @@ def __init__(self, depth=2, th=0.85, decay=0.99):
4141
self.th = th
4242
self.rates = {}
4343
for k, op in OPS.items():
44-
self.rates[k] = tuple([np.ones(x, 'f') for x in op.bins])
44+
self.rates[k] = tuple([np.ones(x, "f") for x in op.bins])
4545

4646
def rate_to_p(self, rate):
4747
p = rate + (1 - self.decay) # Avoid to have all zero.
@@ -78,9 +78,17 @@ def update_rates(self, policy, proximity):
7878
rate[p] = rate[p] * self.decay + proximity * (1 - self.decay)
7979

8080
def stats(self):
81-
return '\n'.join('%-16s %s' % (k, ' / '.join(' '.join('%.2f' % x for x in self.rate_to_p(rate))
82-
for rate in self.rates[k]))
83-
for k in sorted(OPS.keys()))
81+
return "\n".join(
82+
"%-16s %s"
83+
% (
84+
k,
85+
" / ".join(
86+
" ".join("%.2f" % x for x in self.rate_to_p(rate))
87+
for rate in self.rates[k]
88+
),
89+
)
90+
for k in sorted(OPS.keys())
91+
)
8492

8593

8694
def _enhance(x, op, level):
@@ -128,7 +136,10 @@ def cutout(x, level):
128136
height_loc = np.random.randint(low=0, high=img_height)
129137
width_loc = np.random.randint(low=0, high=img_width)
130138
upper_coord = (max(0, height_loc - size // 2), max(0, width_loc - size // 2))
131-
lower_coord = (min(img_height, height_loc + size // 2), min(img_width, width_loc + size // 2))
139+
lower_coord = (
140+
min(img_height, height_loc + size // 2),
141+
min(img_width, width_loc + size // 2),
142+
)
132143
pixels = x.load() # create the pixel map
133144
for i in range(upper_coord[0], lower_coord[0]): # for every col:
134145
for j in range(upper_coord[1], lower_coord[1]): # For every row
@@ -162,7 +173,14 @@ def rescale(x, scale, method):
162173
s = x.size
163174
scale *= 0.25
164175
crop = (scale * s[0], scale * s[1], s[0] * (1 - scale), s[1] * (1 - scale))
165-
methods = (Image.ANTIALIAS, Image.BICUBIC, Image.BILINEAR, Image.BOX, Image.HAMMING, Image.NEAREST)
176+
methods = (
177+
Image.ANTIALIAS,
178+
Image.BICUBIC,
179+
Image.BILINEAR,
180+
Image.BOX,
181+
Image.HAMMING,
182+
Image.NEAREST,
183+
)
166184
method = methods[int(method * 5.99)]
167185
return x.crop(crop).resize(x.size, method)
168186

dataflow/__init__.py

+27-16
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,6 @@
66

77

88
class TransformedDataset(Dataset):
9-
109
def __init__(self, dataset, transforms):
1110
self.dataset = dataset
1211
self.transforms = transforms
@@ -31,17 +30,23 @@ def cycle(dataloader):
3130
yield b
3231

3332

34-
def get_supervised_train_loader(dataset_name, root, num_train_samples_per_class, download=True, **dataloader_kwargs):
33+
def get_supervised_train_loader(
34+
dataset_name, root, num_train_samples_per_class, download=True, **dataloader_kwargs
35+
):
3536
if dataset_name == "cifar10":
36-
from dataflow.cifar10 import get_supervised_trainset, get_supervised_train_loader, weak_transforms
37+
from dataflow.cifar10 import (
38+
get_supervised_trainset,
39+
get_supervised_train_loader,
40+
weak_transforms,
41+
)
3742

3843
train_dataset = get_supervised_trainset(
39-
root, num_train_samples_per_class=num_train_samples_per_class, download=download
44+
root,
45+
num_train_samples_per_class=num_train_samples_per_class,
46+
download=download,
4047
)
4148

42-
return get_supervised_train_loader(
43-
train_dataset, **dataloader_kwargs
44-
)
49+
return get_supervised_train_loader(train_dataset, **dataloader_kwargs)
4550

4651
else:
4752
raise ValueError("Unhandled dataset: {}".format(dataset_name))
@@ -57,7 +62,9 @@ def get_test_loader(dataset_name, root, download=True, **dataloader_kwargs):
5762
raise ValueError("Unhandled dataset: {}".format(dataset_name))
5863

5964

60-
def get_unsupervised_train_loader(dataset_name, root, cta, download=True, **dataloader_kwargs):
65+
def get_unsupervised_train_loader(
66+
dataset_name, root, cta, download=True, **dataloader_kwargs
67+
):
6168
if dataset_name == "cifar10":
6269
from dataflow import cifar10
6370

@@ -78,20 +85,24 @@ def get_unsupervised_train_loader(dataset_name, root, cta, download=True, **data
7885
raise ValueError("Unhandled dataset: {}".format(dataset_name))
7986

8087

81-
def get_cta_probe_loader(dataset_name, root, num_train_samples_per_class, cta, download=True, **dataloader_kwargs):
88+
def get_cta_probe_loader(
89+
dataset_name,
90+
root,
91+
num_train_samples_per_class,
92+
cta,
93+
download=True,
94+
**dataloader_kwargs
95+
):
8296
if dataset_name == "cifar10":
8397
from dataflow.cifar10 import get_supervised_trainset, get_cta_probe_loader
8498

8599
train_dataset = get_supervised_trainset(
86-
root, num_train_samples_per_class=num_train_samples_per_class, download=download
100+
root,
101+
num_train_samples_per_class=num_train_samples_per_class,
102+
download=download,
87103
)
88104

89-
return get_cta_probe_loader(
90-
train_dataset,
91-
cta=cta,
92-
**dataloader_kwargs
93-
)
105+
return get_cta_probe_loader(train_dataset, cta=cta, **dataloader_kwargs)
94106

95107
else:
96108
raise ValueError("Unhandled dataset: {}".format(dataset_name))
97-

0 commit comments

Comments
 (0)