# train.py — FOTS text detection/recognition training entry point.
import os
import json
import random
import argparse
import torch
from torch.utils.data import DataLoader
from torch import nn, optim
import numpy as np
import pandas as pd
# from eval_tools.icdar2015 import eval as icdar_eval
from trainer import Train
from data_helpers.datasets import ICDARDataset, Synth800kPreprocessedDataset
from data_helpers.data_utils import icdar_collate
from components.loss import FOTSLoss
from model import FOTSModel
from trainer import Train
def fots_metric(pred, gt):
    """Evaluate predictions against ground truth with the ICDAR-2015 script.

    Returns a ``(precision, recall, hmean)`` tuple taken from the
    evaluation output's ``'method'`` entry.

    NOTE(review): the import ``from eval_tools.icdar2015 import eval as
    icdar_eval`` is commented out at the top of this file, so calling this
    function currently raises ``NameError`` — confirm and restore the
    import before use.
    """
    config = icdar_eval.default_evaluation_params()
    output = icdar_eval.eval(pred, gt, config)
    return output['method']['precision'], output['method']['recall'], output['method']['hmean']
def count_parameters(model):
    """Return how many trainable parameters *model* holds.

    Only parameters with ``requires_grad=True`` are counted; frozen
    parameters are skipped.
    """
    total = 0
    for param in model.parameters():
        if param.requires_grad:
            total += param.numel()
    return total
def main(config):
    """Main entry point of train module.

    Loads the preprocessed Synth800k dataset, splits it into train and
    validation subsets, builds the FOTS model, loss, optimizer and LR
    scheduler from *config*, and hands everything to the ``Train`` driver.

    Args:
        config (dict): Parsed training configuration (see
            ``train_config.json``).
    """
    # Initialize the dataset.
    # Full ICDAR dataset (kept for reference):
    # dataset = ICDARDataset('/content/ch4_training_images', '/content/ch4_training_localization_transcription_gt')
    data_df = pd.read_csv(f"{config['data_base_dir']}/train.csv")
    dataset = Synth800kPreprocessedDataset(config["data_base_dir"], data_df)

    # Train/validation split driven by config's val_fraction.
    val_size = config["val_fraction"]
    val_len = int(val_size * len(dataset))
    train_len = len(dataset) - val_len
    icdar_train_dataset, icdar_val_dataset = torch.utils.data.random_split(
        dataset, [train_len, val_len]
    )
    icdar_train_data_loader = DataLoader(
        icdar_train_dataset,
        pin_memory=True,
        **config["dataset_config"],
        worker_init_fn=seed_worker
        # collate_fn=icdar_collate
    )
    icdar_val_data_loader = DataLoader(
        icdar_val_dataset,
        **config["dataset_config"],
        pin_memory=True,
        worker_init_fn=seed_worker
        # collate_fn=icdar_collate
    )

    # Initialize the model and report its size.
    model = FOTSModel()
    print(f'The model has {count_parameters(model):,} trainable parameters.')

    loss = FOTSLoss(config)
    optimizer = model.get_optimizer(config["optimizer"], config["optimizer_config"])

    # BUG FIX: the original passed the *string* "ReduceLROnPlateau" as
    # getattr's default, so an unknown scheduler name returned the string
    # itself and the subsequent call raised TypeError. Fall back to the
    # actual scheduler class instead.
    scheduler_cls = getattr(
        optim.lr_scheduler, config["lr_schedular"],
        optim.lr_scheduler.ReduceLROnPlateau
    )
    lr_schedular = scheduler_cls(optimizer, **config["lr_scheduler_config"])

    trainer = Train(
        model, icdar_train_data_loader, icdar_val_data_loader, loss,
        fots_metric, optimizer, lr_schedular, config
    )
    trainer.train()
def seed_all(seed=28):
    """Seed every RNG source for result reproducibility.

    Covers Python's hash seed, torch (CPU and all CUDA devices), NumPy
    and the stdlib ``random`` module, and forces cuDNN into a fully
    deterministic configuration.

    Args:
        seed (int): Seed applied to every RNG. Defaults to 28.
    """
    os.environ["PYTHONHASHSEED"] = str(seed)
    torch.manual_seed(seed)
    # manual_seed_all already seeds every CUDA device (including the
    # current one), so the original extra torch.cuda.manual_seed(seed)
    # call was redundant and has been dropped.
    torch.cuda.manual_seed_all(seed)
    np.random.seed(seed)
    random.seed(seed)
    # Trade speed for reproducibility: pin cuDNN to deterministic
    # kernels, disable auto-tuning, and disable cuDNN entirely.
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False
    torch.backends.cudnn.enabled = False
def seed_worker(_worker_id):
    """Re-seed NumPy and stdlib ``random`` inside a DataLoader worker.

    The seed is derived from torch's per-worker initial seed, reduced
    modulo 2**32 to fit NumPy's accepted range, so each worker gets a
    distinct but reproducible random stream.
    """
    derived_seed = torch.initial_seed() % (2 ** 32)
    np.random.seed(derived_seed)
    random.seed(derived_seed)
if __name__ == '__main__':
    # Seed everything first so the dataset split and weight init are
    # reproducible across runs.
    seed_all()

    # Parse command line args to get the config file.
    parser = argparse.ArgumentParser()
    parser.add_argument(
        '-c', '--config', default="../config/train_config.json",
        type=str, help='Training config file path.'
    )
    args = parser.parse_args()

    # BUG FIX: the original tested `args.config is not None`, which is
    # always true because the argument has a default — the error branch
    # was unreachable and a missing file crashed with a raw
    # FileNotFoundError. Check that the path actually exists instead.
    if os.path.isfile(args.config):
        with open(args.config, "r") as f:
            config = json.load(f)
        main(config)
    else:
        print("Invalid training configuration file provided.")