-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathtest.py
executable file
·100 lines (78 loc) · 2.43 KB
/
test.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
#!/usr/bin/env python3
import torch
import pytorch_lightning as pl
import data_utils
from main import parse_arguments
# ANSI SGR escape sequences used to decorate console output.
TEXT_DEFAULT = '\033[0m'      # reset all attributes
TEXT_BOLD = '\033[1m'         # bold / increased intensity
TEXT_DIM = '\033[2m'          # faint / decreased intensity
TEXT_ITALICS = '\033[3m'      # italic
TEXT_UNDERLINED = '\033[4m'   # underline
def main():
    """Run the dataset and acquisition-method smoke tests in order.

    The acquisition methods are exercised first on binary MNIST, then on
    multi-class MNIST. A KeyboardInterrupt (Ctrl-C) anywhere in the run
    stops it quietly instead of printing a traceback.
    """
    try:
        test_datasets()
        for binary in (True, False):
            test_aquisition_methods(binary)
    except KeyboardInterrupt:
        # Deliberate: an interrupted smoke run just ends without a traceback.
        pass
def check_labeled_data(datamodule: "pl.LightningDataModule", labeled_count: int):
    """Verify the labeled/unlabeled split held by ``datamodule``.

    Checks, in order, that:
      1. ``datamodule.data_train`` has exactly ``labeled_count`` entries;
      2. ``data_train`` contains no duplicate entries;
      3. ``datamodule.data_unlabeled`` contains no duplicate entries;
      4. no entry appears in both pools.

    Entries of both pools must be hashable, because they are collected into
    sets. (Presumably they are sample indices; confirm in ``data_utils``.)

    The checks raise ``AssertionError`` explicitly instead of using the
    ``assert`` statement. This script is run directly, and ``assert`` would
    be stripped under ``python -O``, silently turning every check into a no-op.

    Args:
        datamodule: object exposing sized, iterable ``data_train`` and
            ``data_unlabeled`` attributes.
        labeled_count: expected number of labeled samples.

    Raises:
        AssertionError: if any check fails; the message names the failure.
    """
    train = datamodule.data_train
    unlabeled = datamodule.data_unlabeled

    if len(train) != labeled_count:
        raise AssertionError(
            f"expected {labeled_count} labeled samples, got {len(train)}"
        )

    # Build each set once and reuse it for both the duplicate and overlap checks.
    train_set = set(train)
    if len(train_set) != len(train):
        raise AssertionError(
            f"data_train contains {len(train) - len(train_set)} duplicate entries"
        )

    unlabeled_set = set(unlabeled)
    if len(unlabeled_set) != len(unlabeled):
        raise AssertionError(
            f"data_unlabeled contains {len(unlabeled) - len(unlabeled_set)} "
            f"duplicate entries"
        )

    overlap = train_set & unlabeled_set
    if overlap:
        raise AssertionError(
            f"{len(overlap)} entries are in both data_train and data_unlabeled"
        )
def test_datasets():
    """Smoke-test each supported dataset end to end.

    For every dataset, this builds the model/datamodule pair with
    ``data_utils.get_modules``, using arguments parsed from
    ``[dataset, 'random']``. It then runs one training epoch and a test pass
    (on GPU when CUDA is available), and finally checks the labeled/unlabeled
    split against ``args.initial_labels``.

    Raises:
        KeyboardInterrupt: if the trainer reports that fit or test was
            interrupted, so the whole smoke run stops.
    """
    datasets = ('mnist-binary', 'mnist', 'cifar10', 'svhn')
    for name in datasets:
        print(f"\n{TEXT_BOLD}Testing dataset {name}{TEXT_DEFAULT}")
        args = parse_arguments([name, 'random'])

        has_cuda = torch.cuda.is_available()
        trainer = pl.Trainer(
            gpus=int(has_cuda),
            auto_select_gpus=has_cuda,
            logger=None,
            enable_checkpointing=False,
            max_epochs=1,
        )
        model, datamodule = data_utils.get_modules(args)

        for stage in (trainer.fit, trainer.test):
            stage(model, datamodule)
            # Lightning reports a Ctrl-C during fit/test through this flag
            # (as this check assumes) rather than raising; re-raise it so
            # main() can stop the run.
            if trainer.interrupted:
                raise KeyboardInterrupt

        check_labeled_data(datamodule, args.initial_labels)
def test_aquisition_methods(binary: bool = False):
    """Smoke-test one labeling round for every acquisition method.

    For each method, the arguments are parsed with ``--initial-labels=1000``,
    ``--labeling-budget=10`` and ``--class-balance=1``. A fresh model and
    datamodule are then built, and the split is checked both before and after
    a single ``datamodule.label_data(model)`` call. After labeling, the
    labeled pool must hold ``args.initial_labels + args.labeling_budget``
    entries.

    Args:
        binary: use the 'mnist-binary' dataset instead of 'mnist'.
    """
    methods = (
        'random',
        'least-confident', 'margin', 'entropy',
        'learning-loss',
        'k-center-greedy',
        'class-balanced-greedy',
        'hal-r', 'hal-g',
        'influence', 'influence-abs', 'influence-neg',
        'influence-real', 'influence-abs-real', 'influence-neg-real',
    )
    for method in methods:
        print(f"\n{TEXT_BOLD}Testing aquisition method {method}{TEXT_DEFAULT}")
        dataset = 'mnist-binary' if binary else 'mnist'
        args = parse_arguments([
            dataset,
            method,
            '--initial-labels=1000',
            '--labeling-budget=10',
            '--class-balance=1',
        ])
        model, datamodule = data_utils.get_modules(args)

        # Training is intentionally skipped: acquisition runs on the freshly
        # built model. The datamodule is therefore set up directly with
        # setup('fit') rather than through a Trainer.
        datamodule.setup('fit')

        check_labeled_data(datamodule, args.initial_labels)
        datamodule.label_data(model)
        check_labeled_data(datamodule, args.initial_labels + args.labeling_budget)
# Run the smoke tests only when executed as a script, not when imported.
if __name__ == '__main__':
    main()