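"""Entry point: build the experiment selected on the command line, train it with
periodic validation and checkpointing, then evaluate its best checkpoint on the
test split."""
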
import os
import logging

from parse_args import parse_arguments
from load_data import build_splits_baseline, build_splits_domain_disentangle, build_splits_clip_disentangle

from experiments.baseline import BaselineExperiment
from experiments.domain_disentangle import DomainDisentangleExperiment
from experiments.clip_disentangle import CLIPDisentangleExperiment


def setup_experiment(opt):
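    """Instantiate the requested experiment and build its train/val/test dataloaders."""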
    if opt['experiment'] == 'baseline':
        experiment = BaselineExperiment(opt)
        train_loader, validation_loader, test_loader = build_splits_baseline(opt)
    elif opt['experiment'] == 'domain_disentangle':
        experiment = DomainDisentangleExperiment(opt)
        train_loader, validation_loader, test_loader = build_splits_domain_disentangle(opt)
    elif opt['experiment'] == 'clip_disentangle':
        experiment = CLIPDisentangleExperiment(opt)
        train_loader, validation_loader, test_loader = build_splits_clip_disentangle(opt)
    else:
        raise ValueError('Experiment not yet supported.')

    return experiment, train_loader, validation_loader, test_loader
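
# Interface assumed of every *Experiment class, inferred from the calls in
# main() below (it is not declared explicitly in this file):
#   train_iteration(data) -> float loss for one optimization step
#   validate(loader) -> (accuracy, loss) over a dataloader
#   save_checkpoint(path, iteration, best_accuracy, total_train_loss)
#   load_checkpoint(path) -> (iteration, best_accuracy, total_train_loss)
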
def main(opt):
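    """Run the train/validate/checkpoint loop (unless '--test' is set), then evaluate the best checkpoint."""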
    experiment, train_loader, validation_loader, test_loader = setup_experiment(opt)

    if not opt['test']:  # Skip training if the '--test' flag is set
        iteration = 0
        best_accuracy = 0
        total_train_loss = 0

        # Restore last checkpoint
        if os.path.exists(f'{opt["output_path"]}/last_checkpoint.pth'):
            iteration, best_accuracy, total_train_loss = experiment.load_checkpoint(f'{opt["output_path"]}/last_checkpoint.pth')
        else:
            logging.info(opt)

        # Train loop
        while iteration < opt['max_iterations']:
            for data in train_loader:
                total_train_loss += experiment.train_iteration(data)

                if iteration % opt['print_every'] == 0:
                    logging.info(f'[TRAIN - {iteration}] Loss: {total_train_loss / (iteration + 1)}')

                if iteration % opt['validate_every'] == 0:
                    # Run validation
                    val_accuracy, val_loss = experiment.validate(validation_loader)
                    logging.info(f'[VAL - {iteration}] Loss: {val_loss} | Accuracy: {(100 * val_accuracy):.2f}')
                    if val_accuracy > best_accuracy:
                        best_accuracy = val_accuracy
                        experiment.save_checkpoint(f'{opt["output_path"]}/best_checkpoint.pth', iteration, best_accuracy, total_train_loss)
                    experiment.save_checkpoint(f'{opt["output_path"]}/last_checkpoint.pth', iteration, best_accuracy, total_train_loss)

                iteration += 1
                if iteration >= opt['max_iterations']:
                    break

    # Test
    experiment.load_checkpoint(f'{opt["output_path"]}/best_checkpoint.pth')
    test_accuracy, _ = experiment.validate(test_loader)
    logging.info(f'[TEST] Accuracy: {(100 * test_accuracy):.2f}')

if __name__ == '__main__':
    opt = parse_arguments()

    # Set up the output directory
    os.makedirs(opt['output_path'], exist_ok=True)

    # Set up the logger
    logging.basicConfig(filename=f'{opt["output_path"]}/log.txt', format='%(message)s', level=logging.INFO, filemode='a')

    main(opt)
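
# Example invocation (a sketch: flag names are assumed to mirror the opt keys
# read above, and '--test' is referenced in main(); see parse_args.py for the
# actual CLI definition):
#   python main.py --experiment baseline --output_path record/baseline
#   python main.py --experiment baseline --output_path record/baseline --test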