-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathmodules_general.py
684 lines (526 loc) · 22.7 KB
/
modules_general.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
559
560
561
562
563
564
565
566
567
568
569
570
571
572
573
574
575
576
577
578
579
580
581
582
583
584
585
586
587
588
589
590
591
592
593
594
595
596
597
598
599
600
601
602
603
604
605
606
607
608
609
610
611
612
613
614
615
616
617
618
619
620
621
622
623
624
625
626
627
628
629
630
631
632
633
634
635
636
637
638
639
640
641
642
643
644
645
646
647
648
649
650
651
652
653
654
655
656
657
658
659
660
661
662
663
664
665
666
667
668
669
670
671
672
673
674
675
676
677
678
679
680
681
682
683
684
import random
import numpy
import tqdm
import torch
import torchmetrics
import torchvision
import pytorch_lightning as pl
import data_utils
class IALDataModule(pl.LightningDataModule):
    """Data module for iterative active learning.

    The training data is split into an unlabeled pool (``data_unlabeled``) and
    a validation set (``data_val``). ``data_train`` starts empty and grows as
    the ``label_*`` acquisition methods move samples out of the unlabeled pool.
    All of these are ``Subset``s of the same underlying dataset, so the values
    in ``data_train.indices`` / ``data_unlabeled.indices`` are indices into
    that dataset.

    ``get_data_train()`` and ``get_data_test()`` are not defined in this class;
    presumably a subclass provides them. Hyper-parameters come from
    ``**kwargs`` via ``save_hyperparameters``.
    """

    def __init__(self, **kwargs):
        super().__init__()
        self.save_hyperparameters()
        self.transform = torchvision.transforms.ToTensor()
        self.data_train = None
        self.data_val = None
        self.data_test = None
        self.data_unlabeled = None
        # Guards so repeated setup() calls do not re-split / re-label the data.
        self.setup_fit_done = False
        self.setup_test_done = False

    def setup(self, stage: str = None):
        """Create the data splits needed for ``stage`` (each group only once).

        For "fit"/"validate"/None: randomly split the training data into an
        unlabeled pool (``train_split`` fraction) and a validation set, pass the
        pool to ``data_utils.balance_classes`` (presumably rebalances it in
        place — confirm), then label ``initial_labels`` samples.
        For "test"/None: load the test set.
        """
        if stage in ("fit", "validate", None) and not self.setup_fit_done:
            self.setup_fit_done = True
            data_full = self.get_data_train()
            size_train = round(len(data_full) * self.hparams.train_split)
            size_val = len(data_full) - size_train
            self.data_unlabeled, self.data_val = torch.utils.data.random_split(
                data_full, [size_train, size_val]
            )
            # Nothing is labeled yet; samples are moved here from data_unlabeled.
            self.data_train = torch.utils.data.Subset(data_full, [])
            data_utils.balance_classes(self.data_unlabeled, self.hparams.class_balance)
            self.label_static_distribution(self.hparams.initial_labels)
        if stage in ("test", None) and not self.setup_test_done:
            self.setup_test_done = True
            self.data_test = self.get_data_test()

    # ------------------------------------------------------------------ loaders

    def _loader(self, dataset, batch_size, shuffle=False):
        """Build a DataLoader over ``dataset`` using the configured worker count."""
        return torch.utils.data.DataLoader(
            dataset,
            batch_size=batch_size,
            shuffle=shuffle,
            num_workers=self.hparams.dataloader_workers,
        )

    def train_dataloader(self):
        return self._loader(self.data_train, self.hparams.train_batch_size, shuffle=True)

    def val_dataloader(self):
        return self._loader(self.data_val, self.hparams.eval_batch_size)

    def test_dataloader(self):
        return self._loader(self.data_test, self.hparams.eval_batch_size)

    def predict_dataloader(self):
        return self._loader(self.data_test, self.hparams.eval_batch_size)

    def labeled_dataloader(self):
        return self._loader(self.data_train, self.hparams.eval_batch_size)

    def labeled_dataloader_single(self):
        return self._loader(self.data_train, 1, shuffle=True)

    def unlabeled_dataloader(self):
        # Must not shuffle: acquisition methods map the i-th yielded sample back
        # to data_unlabeled.indices[i].
        return self._loader(self.data_unlabeled, self.hparams.eval_batch_size)

    def unlabeled_dataloader_single(self):
        # Unshuffled for the same reason as unlabeled_dataloader().
        return self._loader(self.data_unlabeled, 1)

    # ------------------------------------------------------------ bookkeeping

    @property
    def class_balance(self):
        """Number of labeled samples per class, as a tensor of length ``len(classes)``."""
        dataset = self.data_train.dataset
        counts = [0] * len(dataset.classes)
        for index in self.data_train.indices:
            target = int(dataset.targets[index])
            if 0 <= target < len(counts):
                counts[target] += 1
        return torch.tensor(counts)

    @staticmethod
    def _class_pools(subset):
        """Group ``subset.indices`` by target class (one list per class, order preserved)."""
        dataset = subset.dataset
        pools = [[] for _ in dataset.classes]
        for index in subset.indices:
            target = int(dataset.targets[index])
            if 0 <= target < len(pools):
                pools[target].append(index)
        return pools

    def label_indices(self, indices: list):
        """Move dataset ``indices`` from the unlabeled pool into the labeled set.

        Both index lists end up sorted and free of duplicates.
        """
        # int() normalizes numpy / 0-dim tensor indices so set membership works.
        chosen = {int(index) for index in indices}
        self.data_train.indices = sorted(chosen.union(self.data_train.indices))
        self.data_unlabeled.indices = sorted(set(self.data_unlabeled.indices) - chosen)

    def _label_top(self, scores, amount):
        """Label the unlabeled samples with the ``amount`` highest ``scores``.

        ``scores[i]`` must belong to ``data_unlabeled.indices[i]`` (i.e. come from
        iterating the unshuffled unlabeled dataloader). ``amount`` is clipped to
        the number of scores.
        """
        amount = min(int(amount), scores.numel())
        if amount <= 0:
            return
        _, top_positions = scores.reshape(-1).topk(amount)
        # Resolve dataset indices before label_indices() rewrites data_unlabeled.indices.
        self.label_indices([self.data_unlabeled.indices[p] for p in top_positions.tolist()])

    @staticmethod
    def _probabilities(output):
        """Convert model logits to per-class probabilities of shape (batch, classes).

        A 1-D output is treated as a single binary logit per sample; the
        sigmoid of it is taken as the probability of class 1 (matching
        ``IALModel.guess``, which predicts class 1 for positive logits).
        """
        if output.dim() == 1:
            positive = torch.sigmoid(output)
            return torch.stack([1 - positive, positive], 1)
        return torch.nn.functional.softmax(output, 1)

    @staticmethod
    def _embed(model, images):
        """Hidden features: ``model.convolutional`` followed by every ``model.fully_connected`` layer."""
        hidden = model.convolutional(images)
        for layer in model.fully_connected:
            hidden = layer(hidden)
        return hidden

    def _labeled_features(self, model):
        """Hidden features of all labeled samples, one tensor per batch."""
        return [self._embed(model, images) for images, _ in self.labeled_dataloader()]

    def _unlabeled_features(self, model):
        """Yield ``(offset, features)`` per unlabeled batch.

        ``offset`` is the position of the batch's first sample in
        ``data_unlabeled.indices``.
        """
        offset = 0
        for images, _ in self.unlabeled_dataloader():
            yield offset, self._embed(model, images)
            offset += len(images)

    @staticmethod
    def _pairwise_sq_dists(a, b):
        """Squared Euclidean distances between rows of ``a`` (m, d) and ``b`` (n, d), shape (m, n).

        Clamped at 0: the expansion |a|^2 + |b|^2 - 2ab can go slightly negative
        from rounding, which would make a later sqrt() NaN.
        """
        dist = a.pow(2).sum(1, keepdim=True) + b.pow(2).sum(1).unsqueeze(0) - 2 * (a @ b.T)
        return dist.clamp_min(0)

    # ------------------------------------------------------- random labeling

    def label_each_class(self, amount: int = 1):
        """Label ``amount`` random unlabeled samples of every class."""
        chosen = []
        for pool in self._class_pools(self.data_unlabeled):
            chosen += random.sample(pool, amount)
        self.label_indices(chosen)

    def label_static_distribution(self, amount: int):
        """Label ``amount`` random samples following the unlabeled class distribution.

        Each class first gets ``floor(amount * class_size / pool_size)`` labels.
        The labels left over by rounding down go, one each, to random classes
        that still have unlabeled samples.
        """
        unlabeled_count = len(self.data_unlabeled)
        if unlabeled_count == 0:
            return
        labels_remaining = amount
        for pool in self._class_pools(self.data_unlabeled):
            # Integer floor division is exact; int(float) could round one too low.
            class_labeling_count = amount * len(pool) // unlabeled_count
            self.label_indices(random.sample(pool, class_labeling_count))
            labels_remaining -= class_labeling_count
        # Only classes with unlabeled samples left can absorb a leftover label.
        pools = [pool for pool in self._class_pools(self.data_unlabeled) if pool]
        chosen_pools = random.sample(pools, max(0, min(labels_remaining, len(pools))))
        self.label_indices([random.choice(pool) for pool in chosen_pools])

    def label_randomly(self, amount: int, model: pl.LightningModule = None):
        """Label ``amount`` unlabeled samples chosen uniformly at random (``model`` is unused)."""
        self.label_indices(random.sample(self.data_unlabeled.indices, int(amount)))

    # ------------------------------------------------- uncertainty sampling

    def label_uncertain(self, amount: int, model: pl.LightningModule, uncertainty_method: str):
        """Label the ``amount`` unlabeled samples with the highest uncertainty.

        ``uncertainty_method`` is 'entropy', 'margin' or 'least-confident'. Each
        score is computed from per-class probabilities; higher means more uncertain.

        Raises:
            ValueError: if ``uncertainty_method`` is not one of the above.
        """
        uncertainty_functions = {
            # Shannon entropy; 0 * log(0) yields NaN, which counts as 0.
            'entropy': lambda preds: -(preds * preds.log()).sum(1).nan_to_num(0),
            # One minus the gap between the two most likely classes.
            'margin': lambda preds: 1 - preds.topk(2, dim=1)[0].diff(dim=1).abs().squeeze(1),
            # One minus the probability of the most likely class.
            'least-confident': lambda preds: 1 - preds.max(1)[0],
        }
        if uncertainty_method not in uncertainty_functions:
            raise ValueError(f"{uncertainty_method} is no valid uncertainty method")
        uncertainty_fn = uncertainty_functions[uncertainty_method]
        scores = []
        with torch.no_grad():
            for images, _ in tqdm.tqdm(self.unlabeled_dataloader(), desc='Labeling'):
                output, _ = model(images)
                scores.append(uncertainty_fn(self._probabilities(output)))
        if scores:
            self._label_top(torch.cat(scores), amount)

    def label_entropy(self, amount: int, model: pl.LightningModule):
        self.label_uncertain(amount, model, 'entropy')

    def label_margin(self, amount: int, model: pl.LightningModule):
        self.label_uncertain(amount, model, 'margin')

    def label_least_confident(self, amount: int, model: pl.LightningModule):
        self.label_uncertain(amount, model, 'least-confident')

    # Active Learning for Skewed Data Sets
    # Abbas Kazerouni et al
    def label_hal_r(self, amount: int, model: pl.LightningModule):
        """HAL-R: Binomial(amount, hal_exploit_probability) labels by margin, the rest at random."""
        exploit_count = numpy.random.binomial(amount, self.hparams.hal_exploit_probability)
        self.label_margin(exploit_count, model)
        self.label_randomly(amount - exploit_count, model)

    def label_hal_g(self, amount: int, model: pl.LightningModule):
        """HAL-G: like HAL-R, but exploration picks the sample least similar to the labeled set.

        Similarity of an unlabeled sample is ``sum_l exp(-d_l / hal_gaussian_variance)``.
        Here ``d_l`` is the Euclidean distance, in hidden-feature space, to
        labeled sample ``l``. Exploration samples are picked greedily, one at a
        time.
        """
        exploit_count = numpy.random.binomial(amount, self.hparams.hal_exploit_probability)
        self.label_margin(exploit_count, model)
        variance = self.hparams.hal_gaussian_variance
        with torch.no_grad():
            for _ in tqdm.trange(amount - exploit_count, desc='Labeling'):
                labeled_features = self._labeled_features(model)
                best_similarity, best_position = float('inf'), None
                for offset, h_unlabeled in self._unlabeled_features(model):
                    # Similarities are summed, so start from zero.
                    similarity = h_unlabeled.new_zeros(len(h_unlabeled))
                    for h_labeled in labeled_features:
                        dist = self._pairwise_sq_dists(h_labeled, h_unlabeled).sqrt()
                        similarity = similarity + torch.exp(-dist / variance).sum(0)
                    batch_best, batch_position = similarity.min(0)
                    # Keep the global minimum across all batches, not just the last one.
                    if batch_best.item() < best_similarity:
                        best_similarity = batch_best.item()
                        best_position = offset + batch_position.item()
                if best_position is None:
                    break
                self.label_indices([self.data_unlabeled.indices[best_position]])

    # Class-Balanced Active Learning for Image Classification
    # Javad Zolfaghari Bengar, Joost van de Weijer, Laura Lopez Fuentes, Bogdan Raducanu
    def label_class_balanced(self, amount: int, model: pl.LightningModule):
        raise NotImplementedError

    def label_class_balanced_greedy(self, amount: int, model: pl.LightningModule):
        """Greedy class-balanced entropy sampling, labeling one sample at a time.

        Score = entropy(p) - class_balancing_factor * ||omega - p||_1, where
        ``omega_c = max(0, labeled / classes - labeled_c)`` is how far class
        ``c`` is below an even share of the labeled set.
        """
        num_classes = len(self.data_train.dataset.classes)
        with torch.no_grad():
            for _ in tqdm.trange(amount, desc='Labeling'):
                # Fixed for one selection round, so compute it once, not per batch.
                balance_omega = torch.clamp(len(self.data_train) / num_classes - self.class_balance, min=0)
                best_score, best_position = float('-inf'), None
                offset = 0
                for images, _ in self.unlabeled_dataloader():
                    output, _ = model(images)
                    preds = self._probabilities(output)
                    uncertainty_score = -(preds * preds.log()).sum(1).nan_to_num(0)
                    balance_penalty = self.hparams.class_balancing_factor * torch.norm(
                        balance_omega.unsqueeze(0) - preds, p=1, dim=1
                    )
                    batch_best, batch_position = torch.max(uncertainty_score - balance_penalty, dim=0)
                    if batch_best.item() > best_score:
                        best_score = batch_best.item()
                        best_position = offset + batch_position.item()
                    offset += len(images)
                if best_position is None:
                    break
                self.label_indices([self.data_unlabeled.indices[best_position]])

    # Learning Loss for Active Learning
    # Donggeun Yoo, In So Kweon
    def label_highest_loss(self, amount: int, model: pl.LightningModule):
        """Label the ``amount`` unlabeled samples with the highest loss predicted by ``model``.

        Uses the second value returned by ``model(images)``.
        """
        with torch.no_grad():
            predicted_losses = [
                # reshape(-1): a size-1 batch may come back 0-dim, which torch.cat rejects.
                model(images)[1].reshape(-1)
                for images, _ in tqdm.tqdm(self.unlabeled_dataloader(), desc='Labeling')
            ]
        if predicted_losses:
            self._label_top(torch.cat(predicted_losses), amount)

    # Active Learning for Convolutional Neural Networks: A Core-Set Approach
    # Ozan Sener, Silvio Savarese
    def label_k_center(self, amount: int, model: pl.LightningModule):
        raise NotImplementedError

    def label_k_center_greedy(self, amount: int, model: pl.LightningModule):
        """Greedy k-center selection in hidden-feature space.

        ``amount`` times: label the unlabeled sample whose distance to its
        nearest labeled sample is largest.
        """
        with torch.no_grad():
            for _ in tqdm.trange(amount, desc='Labeling'):
                labeled_features = self._labeled_features(model)
                best_distance, best_position = float('-inf'), None
                for offset, h_unlabeled in self._unlabeled_features(model):
                    # Squared distances suffice: sqrt is monotonic, so the argmax is unchanged.
                    min_dist = h_unlabeled.new_full((len(h_unlabeled),), float('inf'))
                    for h_labeled in labeled_features:
                        dist = self._pairwise_sq_dists(h_labeled, h_unlabeled)
                        min_dist = torch.min(min_dist, dist.min(0)[0])
                    batch_best, batch_position = min_dist.max(0)
                    # Keep the global maximum across all batches, not just the last one.
                    if batch_best.item() > best_distance:
                        best_distance = batch_best.item()
                        best_position = offset + batch_position.item()
                if best_position is None:
                    break
                self.label_indices([self.data_unlabeled.indices[best_position]])

    # ------------------------------------------------------ influence methods

    # https://github.com/nimarb/pytorch_influence_functions/blob/master/pytorch_influence_functions/influence_function.py
    def rank_influence(self, model: pl.LightningModule, real: bool = False):
        """Score every unlabeled sample by its estimated influence on the validation loss.

        ``v`` is the gradient of the mean validation loss. ``s_test`` is refined
        over single labeled samples with
        ``s <- v + (1 - influence_damp) * s - (H s) / influence_scale``,
        where ``H s`` is that sample's Hessian-vector product. Refinement stops
        when the update size stops shrinking or after ``influence_max_iterations``
        updates. Each unlabeled sample's influence is ``s_test . grad(loss)``.
        With ``real=False`` the model's own guess is used as the target and the
        gradient is weighted by its certainty. With ``real=True`` the true
        targets are used.

        Returns:
            1-D tensor of influences aligned with ``data_unlabeled.indices``.

        Raises:
            ValueError: if ``s_test`` becomes NaN.
        """
        params = [p for p in model.parameters() if p.requires_grad]

        def calc_hvp(loss, s_test):
            """Hessian-vector product of ``loss`` w.r.t. ``params`` with ``s_test``."""
            first_grads = torch.autograd.grad(loss, params, create_graph=True, retain_graph=True)
            gradients = torch.autograd.grad(first_grads, params, s_test)
            return [gradient.detach() for gradient in gradients]

        def calc_v():
            """Gradient of the mean validation loss w.r.t. ``params``."""
            loss = 0
            for images, targets in self.val_dataloader():
                predictions, _ = model(images)
                loss = loss + model.loss(predictions, targets, reduction='sum')
            loss = loss / len(self.data_val)
            return [gradient.detach() for gradient in torch.autograd.grad(loss, params)]

        def calc_s_test():
            v = calc_v()
            s_test = list(v)
            damp = 1 - self.hparams.influence_damp
            scale = 1 / self.hparams.influence_scale
            current_iteration = 0
            cur_diff = numpy.inf
            diff_diff = [numpy.inf] * 5
            while True:
                seen_sample = False
                for images, targets in self.labeled_dataloader_single():
                    seen_sample = True
                    predictions, _ = model(images)
                    loss = model.loss(predictions, targets)
                    hvp = calc_hvp(loss, s_test)
                    s_test_old = s_test
                    s_test = [(v_i + damp * s_test_i - scale * hvp_i).detach()
                              for v_i, s_test_i, hvp_i in zip(v, s_test, hvp)]
                    if any(unit.isnan().any() for unit in s_test):
                        raise ValueError("One or more values of s_test became NaN")
                    # Track the change in update size over the last 5 updates;
                    # once its sum is <= 0 the updates stopped shrinking: assume stable.
                    prev_diff = cur_diff
                    cur_diff = torch.sqrt(sum((old - new).pow(2).sum()
                                              for old, new in zip(s_test_old, s_test)))
                    diff_diff = diff_diff[-4:] + [prev_diff - cur_diff]
                    current_iteration += 1
                    not_stabilized = sum(diff_diff) > 0
                    not_at_max = current_iteration < self.hparams.influence_max_iterations
                    if not (not_stabilized and not_at_max):
                        return s_test
                if not seen_sample:
                    # No labeled data: nothing to refine with (previously looped forever).
                    return s_test

        def calc_influences(s_test):
            # TODO Is this batchable?
            influences = []
            for images, true_targets in tqdm.tqdm(self.unlabeled_dataloader_single(), desc='Labeling'):
                predictions, _ = model(images)
                if real:
                    targets, weight = true_targets, 1
                else:
                    # Target unknown: use the model's guess, weighted by its certainty.
                    certainties, targets = model.guess(predictions)
                    weight = certainties.detach()
                loss = model.loss(predictions, targets)
                g_z = [gradient * weight for gradient in torch.autograd.grad(loss, params)]
                influences.append(sum(float(torch.sum(s_test_i * g_z_i))
                                      for s_test_i, g_z_i in zip(s_test, g_z)))
            return torch.tensor(influences)

        return calc_influences(calc_s_test())

    def label_influence(self, amount: int, model: pl.LightningModule):
        self._label_top(self.rank_influence(model), amount)

    def label_influence_abs(self, amount: int, model: pl.LightningModule):
        self._label_top(self.rank_influence(model).abs(), amount)

    def label_influence_neg(self, amount: int, model: pl.LightningModule):
        self._label_top(-self.rank_influence(model), amount)

    def label_influence_real(self, amount: int, model: pl.LightningModule):
        self._label_top(self.rank_influence(model, real=True), amount)

    def label_influence_abs_real(self, amount: int, model: pl.LightningModule):
        self._label_top(self.rank_influence(model, real=True).abs(), amount)

    def label_influence_neg_real(self, amount: int, model: pl.LightningModule):
        self._label_top(-self.rank_influence(model, real=True), amount)

    # ------------------------------------------------------------- dispatch

    def label_data(self, model):
        """Label ``labeling_budget`` samples with the method named by ``hparams.aquisition_method``.

        Prints each class's share of the labeled set before and after labeling.
        Raises KeyError for an unknown method name.
        """
        aquisition_methods = {
            'random': self.label_randomly,
            'least-confident': self.label_least_confident,
            'margin': self.label_margin,
            'entropy': self.label_entropy,
            'learning-loss': self.label_highest_loss,
            'k-center': self.label_k_center,
            'k-center-greedy': self.label_k_center_greedy,
            'class-balanced': self.label_class_balanced,
            'class-balanced-greedy': self.label_class_balanced_greedy,
            'hal-r': self.label_hal_r,
            'hal-g': self.label_hal_g,
            'influence': self.label_influence,
            'influence-abs': self.label_influence_abs,
            'influence-neg': self.label_influence_neg,
            'influence-real': self.label_influence_real,
            'influence-abs-real': self.label_influence_abs_real,
            'influence-neg-real': self.label_influence_neg_real,
        }
        cb_before = self.class_balance / len(self.data_train) * 100
        aquisition_methods[self.hparams.aquisition_method](self.hparams.labeling_budget, model)
        cb_after = self.class_balance / len(self.data_train) * 100
        print('Data labeled, class balance:')
        class_names = self.data_train.dataset.classes
        max_class_len = max((len(cls) for cls in class_names), default=0)
        for num, cls in enumerate(class_names):
            print(f"{cls:{max_class_len}} {cb_before[num]:2.0f}% -> {cb_after[num]:2.0f}% ({cb_after[num] - cb_before[num]:+2.0f}%)")
class IALModel(pl.LightningModule):
    """Small CNN classifier used for iterative active learning.

    Architecture:
    - ``len(layers_conv)`` blocks of Conv2d -> ReLU -> MaxPool2d, then flatten.
    - ``len(layers_fc)`` Linear -> ReLU layers.
    - A linear classifier.

    With ``classes <= 2`` the model emits one logit per sample and uses
    ``data_utils.bce_tofloat_loss``; otherwise it emits ``classes`` logits and
    uses cross entropy. When ``aquisition_method == 'learning-loss'``, a
    loss-prediction head reads every fully connected layer's output.
    """

    def __init__(self,
                 image_size: int,
                 image_depth: int,
                 layers_conv: list,
                 layers_fc: list,
                 classes: int,
                 **kwargs
                 ):
        super().__init__()
        self.save_hyperparameters()
        self.example_input_array = torch.zeros([self.hparams.train_batch_size, image_depth, image_size, image_size])
        kernel = self.hparams.convolutional_size
        pool = self.hparams.convolutional_pool

        convolutional = []
        channels = image_depth
        final_size = image_size
        for size in layers_conv:
            convolutional += [
                torch.nn.Conv2d(channels, size, kernel),
                torch.nn.ReLU(),
                torch.nn.MaxPool2d(pool, pool),
            ]
            # Unpadded conv shrinks the side by kernel - 1; pooling floor-divides it.
            final_size = (final_size - kernel + 1) // pool
            channels = size
        convolutional.append(torch.nn.Flatten(1))
        self.convolutional = torch.nn.Sequential(*convolutional)

        # Track the running width instead of indexing layers_conv[-1] / layers_fc[-1],
        # so empty layer lists also produce a valid network.
        features = channels * final_size ** 2
        self.fully_connected = torch.nn.ModuleList()
        for size in layers_fc:
            self.fully_connected.append(torch.nn.Sequential(torch.nn.Linear(features, size), torch.nn.ReLU()))
            features = size

        # Same condition for output width and loss choice.
        self.binary = classes <= 2
        self.classifier = torch.nn.Linear(features, 1 if self.binary else classes)
        if self.binary:
            self.loss = data_utils.bce_tofloat_loss
        else:
            self.loss = torch.nn.functional.cross_entropy
        self.accuracy = torchmetrics.Accuracy()

        if self.hparams.aquisition_method == 'learning-loss':
            self.loss_layers = torch.nn.ModuleList([
                torch.nn.Sequential(torch.nn.Linear(size, self.hparams.learning_loss_layer_size), torch.nn.ReLU())
                for size in layers_fc
            ])
            self.loss_regressor = torch.nn.Linear(self.hparams.learning_loss_layer_size * len(layers_fc), 1)

    def forward(self, images):
        """Return ``(logits, predicted_loss)``.

        ``logits`` has shape (batch,) for binary models and (batch, classes)
        otherwise. ``predicted_loss`` has shape (batch,) for the learning-loss
        method and is an empty tensor otherwise.
        """
        pred_loss = torch.empty(0)
        hidden = self.convolutional(images)
        if self.hparams.aquisition_method == 'learning-loss':
            loss_features = []
            for layer, loss_layer in zip(self.fully_connected, self.loss_layers):
                hidden = layer(hidden)
                # detach: the loss-prediction head's gradients do not reach the backbone.
                loss_features.append(loss_layer(hidden.detach()))
            # squeeze(1), not squeeze(): keep the batch dim for size-1 batches.
            pred_loss = self.loss_regressor(torch.cat(loss_features, 1)).squeeze(1)
        else:
            for layer in self.fully_connected:
                hidden = layer(hidden)
        preds = self.classifier(hidden).squeeze(1)
        return preds, pred_loss

    def training_step(self, batch, batch_idx):
        images, labels = batch
        labels_hat, losses_hat = self(images)
        loss = self.loss(labels_hat, labels)
        self.log("running/classification/training/loss", loss, on_step=False, on_epoch=True)
        if self.hparams.aquisition_method == 'learning-loss':
            losses = self.loss(labels_hat, labels, reduction='none')
            loss_loss = data_utils.loss_loss(losses_hat, losses)
            self.log("running/learning-loss/training/loss", loss_loss, on_step=False, on_epoch=True)
            loss += self.hparams.learning_loss_factor * loss_loss
        return loss

    def on_train_end(self):
        # Require min_epochs more epochs on the next fit, so early stopping is skipped the next epoch.
        self.trainer.fit_loop.min_epochs = self.trainer.fit_loop.epoch_progress.current.processed + self.hparams.min_epochs

    def validation_step(self, batch, batch_idx):
        images, labels = batch
        labels_hat, losses_hat = self(images)
        loss = self.loss(labels_hat, labels)
        self.log("running/classification/validation/loss", loss)
        accuracy = self.accuracy(labels_hat, labels)
        self.log("running/classification/validation/accuracy", accuracy)
        datamodule = self.trainer.datamodule
        num_labeled = float(len(datamodule.data_train.indices))
        self.log("running/labeled-data/count", num_labeled)
        class_balance = datamodule.class_balance / len(datamodule.data_train)
        # 0 * log(0) is NaN; an unlabeled class contributes 0 entropy.
        entropy_labeled = -(class_balance * class_balance.log()).nan_to_num(0).sum()
        self.log("running/labeled-data/entropy", entropy_labeled)
        if self.hparams.aquisition_method == 'learning-loss':
            losses = self.loss(labels_hat, labels, reduction='none')
            loss_loss = data_utils.loss_loss(losses_hat, losses)
            self.log("running/learning-loss/validation/loss", loss_loss)
        return loss

    def test_step(self, batch, batch_idx):
        images, labels = batch
        labels_hat, losses_hat = self(images)
        loss = self.loss(labels_hat, labels)
        self.log("running/classification/test/loss", loss)
        accuracy = self.accuracy(labels_hat, labels)
        self.log("running/classification/test/accuracy", accuracy)
        if self.hparams.aquisition_method == 'learning-loss':
            losses = self.loss(labels_hat, labels, reduction='none')
            loss_loss = data_utils.loss_loss(losses_hat, losses)
            self.log("running/learning-loss/test/loss", loss_loss)
        return loss

    def configure_optimizers(self):
        return torch.optim.Adam(self.parameters(), lr=self.hparams.learning_rate)

    def guess(self, predictions):
        """Return ``(certainties, targets)`` guessed from raw logits.

        Binary: the logit itself is the certainty; predict class 1 when the logit > 0.
        Multiclass: the largest logit and its class index.
        """
        if self.binary:
            return predictions, (predictions > 0).int()
        certainties, targets = predictions.max(1)
        return certainties, targets