22
22
def train_eval_model (model ,
23
23
criterion ,
24
24
optimizer ,
25
+ image_dataset ,
25
26
dataloader ,
26
27
tfboard_writer ,
27
28
benchmark ,
@@ -60,6 +61,9 @@ def train_eval_model(model,
60
61
last_epoch = cfg .TRAIN .START_EPOCH - 1 )
61
62
62
63
for epoch in range (start_epoch , num_epochs ):
64
+ # Reset seed after evaluation per epoch
65
+ torch .manual_seed (cfg .RANDOM_SEED + epoch + 1 )
66
+ dataloader ['train' ] = get_dataloader (image_dataset ['train' ], shuffle = True , fix_seed = False )
63
67
print ('Epoch {}/{}' .format (epoch , num_epochs - 1 ))
64
68
print ('-' * 10 )
65
69
@@ -73,127 +77,126 @@ def train_eval_model(model,
73
77
iter_num = 0
74
78
75
79
# Iterate over data.
76
- while iter_num < cfg .TRAIN .EPOCH_ITERS :
77
- for inputs in dataloader ['train' ]:
78
- if iter_num >= cfg .TRAIN .EPOCH_ITERS :
79
- break
80
- if model .module .device != torch .device ('cpu' ):
81
- inputs = data_to_cuda (inputs )
82
-
83
- iter_num = iter_num + 1
84
-
85
- # zero the parameter gradients
86
- optimizer .zero_grad ()
87
-
88
- with torch .set_grad_enabled (True ):
89
- # forward
90
- outputs = model (inputs )
91
-
92
- if cfg .PROBLEM .TYPE == '2GM' :
93
- assert 'ds_mat' in outputs
94
- assert 'perm_mat' in outputs
80
+ for inputs in dataloader ['train' ]:
81
+ if iter_num >= cfg .TRAIN .EPOCH_ITERS :
82
+ break
83
+ if model .module .device != torch .device ('cpu' ):
84
+ inputs = data_to_cuda (inputs )
85
+
86
+ iter_num = iter_num + 1
87
+
88
+ # zero the parameter gradients
89
+ optimizer .zero_grad ()
90
+
91
+ with torch .set_grad_enabled (True ):
92
+ # forward
93
+ outputs = model (inputs )
94
+
95
+ if cfg .PROBLEM .TYPE == '2GM' :
96
+ assert 'ds_mat' in outputs
97
+ assert 'perm_mat' in outputs
98
+ assert 'gt_perm_mat' in outputs
99
+
100
+ # compute loss
101
+ if cfg .TRAIN .LOSS_FUNC == 'offset' :
102
+ d_gt , grad_mask = displacement (outputs ['gt_perm_mat' ], * outputs ['Ps' ], outputs ['ns' ][0 ])
103
+ d_pred , _ = displacement (outputs ['ds_mat' ], * outputs ['Ps' ], outputs ['ns' ][0 ])
104
+ loss = criterion (d_pred , d_gt , grad_mask )
105
+ elif cfg .TRAIN .LOSS_FUNC in ['perm' , 'ce' , 'hung' ]:
106
+ loss = criterion (outputs ['ds_mat' ], outputs ['gt_perm_mat' ], * outputs ['ns' ])
107
+ elif cfg .TRAIN .LOSS_FUNC == 'hamming' :
108
+ loss = criterion (outputs ['perm_mat' ], outputs ['gt_perm_mat' ])
109
+ elif cfg .TRAIN .LOSS_FUNC == 'custom' :
110
+ loss = torch .sum (outputs ['loss' ])
111
+ else :
112
+ raise ValueError (
113
+ 'Unsupported loss function {} for problem type {}' .format (cfg .TRAIN .LOSS_FUNC ,
114
+ cfg .PROBLEM .TYPE ))
115
+
116
+ # compute accuracy
117
+ acc = matching_accuracy (outputs ['perm_mat' ], outputs ['gt_perm_mat' ], outputs ['ns' ][0 ])
118
+
119
+ elif cfg .PROBLEM .TYPE in ['MGM' , 'MGM3' ]:
120
+ assert 'ds_mat_list' in outputs
121
+ assert 'graph_indices' in outputs
122
+ assert 'perm_mat_list' in outputs
123
+ if not 'gt_perm_mat_list' in outputs :
95
124
assert 'gt_perm_mat' in outputs
96
-
97
- # compute loss
98
- if cfg .TRAIN .LOSS_FUNC == 'offset' :
99
- d_gt , grad_mask = displacement (outputs ['gt_perm_mat' ], * outputs ['Ps' ], outputs ['ns' ][0 ])
100
- d_pred , _ = displacement (outputs ['ds_mat' ], * outputs ['Ps' ], outputs ['ns' ][0 ])
101
- loss = criterion (d_pred , d_gt , grad_mask )
102
- elif cfg .TRAIN .LOSS_FUNC in ['perm' , 'ce' , 'hung' ]:
103
- loss = criterion (outputs ['ds_mat' ], outputs ['gt_perm_mat' ], * outputs ['ns' ])
104
- elif cfg .TRAIN .LOSS_FUNC == 'hamming' :
105
- loss = criterion (outputs ['perm_mat' ], outputs ['gt_perm_mat' ])
106
- elif cfg .TRAIN .LOSS_FUNC == 'custom' :
107
- loss = torch .sum (outputs ['loss' ])
108
- else :
109
- raise ValueError (
110
- 'Unsupported loss function {} for problem type {}' .format (cfg .TRAIN .LOSS_FUNC ,
111
- cfg .PROBLEM .TYPE ))
112
-
113
- # compute accuracy
114
- acc = matching_accuracy (outputs ['perm_mat' ], outputs ['gt_perm_mat' ], outputs ['ns' ][0 ])
115
-
116
- elif cfg .PROBLEM .TYPE in ['MGM' , 'MGM3' ]:
117
- assert 'ds_mat_list' in outputs
118
- assert 'graph_indices' in outputs
119
- assert 'perm_mat_list' in outputs
120
- if not 'gt_perm_mat_list' in outputs :
121
- assert 'gt_perm_mat' in outputs
122
- gt_perm_mat_list = [outputs ['gt_perm_mat' ][idx ] for idx in outputs ['graph_indices' ]]
123
- else :
124
- gt_perm_mat_list = outputs ['gt_perm_mat_list' ]
125
-
126
- # compute loss & accuracy
127
- if cfg .TRAIN .LOSS_FUNC in ['perm' , 'ce' 'hung' ]:
128
- loss = torch .zeros (1 , device = model .module .device )
129
- ns = outputs ['ns' ]
130
- for s_pred , x_gt , (idx_src , idx_tgt ) in \
131
- zip (outputs ['ds_mat_list' ], gt_perm_mat_list , outputs ['graph_indices' ]):
132
- l = criterion (s_pred , x_gt , ns [idx_src ], ns [idx_tgt ])
133
- loss += l
134
- loss /= len (outputs ['ds_mat_list' ])
135
- elif cfg .TRAIN .LOSS_FUNC == 'plain' :
136
- loss = torch .sum (outputs ['loss' ])
137
- else :
138
- raise ValueError (
139
- 'Unsupported loss function {} for problem type {}' .format (cfg .TRAIN .LOSS_FUNC ,
140
- cfg .PROBLEM .TYPE ))
141
-
142
- # compute accuracy
143
- acc = torch .zeros (1 , device = model .module .device )
144
- for x_pred , x_gt , (idx_src , idx_tgt ) in \
145
- zip (outputs ['perm_mat_list' ], gt_perm_mat_list , outputs ['graph_indices' ]):
146
- a = matching_accuracy (x_pred , x_gt , ns [idx_src ])
147
- acc += torch .sum (a )
148
- acc /= len (outputs ['perm_mat_list' ])
125
+ gt_perm_mat_list = [outputs ['gt_perm_mat' ][idx ] for idx in outputs ['graph_indices' ]]
149
126
else :
150
- raise ValueError ('Unknown problem type {}' .format (cfg .PROBLEM .TYPE ))
151
-
152
- # backward + optimize
153
- if cfg .FP16 :
154
- with amp .scale_loss (loss , optimizer ) as scaled_loss :
155
- scaled_loss .backward ()
127
+ gt_perm_mat_list = outputs ['gt_perm_mat_list' ]
128
+
129
+ # compute loss & accuracy
130
+ if cfg .TRAIN .LOSS_FUNC in ['perm' , 'ce' 'hung' ]:
131
+ loss = torch .zeros (1 , device = model .module .device )
132
+ ns = outputs ['ns' ]
133
+ for s_pred , x_gt , (idx_src , idx_tgt ) in \
134
+ zip (outputs ['ds_mat_list' ], gt_perm_mat_list , outputs ['graph_indices' ]):
135
+ l = criterion (s_pred , x_gt , ns [idx_src ], ns [idx_tgt ])
136
+ loss += l
137
+ loss /= len (outputs ['ds_mat_list' ])
138
+ elif cfg .TRAIN .LOSS_FUNC == 'plain' :
139
+ loss = torch .sum (outputs ['loss' ])
156
140
else :
157
- loss .backward ()
158
- optimizer .step ()
159
-
160
- batch_num = inputs ['batch_size' ]
161
-
162
- # tfboard writer
163
- loss_dict = dict ()
164
- loss_dict ['loss' ] = loss .item ()
165
- tfboard_writer .add_scalars ('loss' , loss_dict , epoch * cfg .TRAIN .EPOCH_ITERS + iter_num )
141
+ raise ValueError (
142
+ 'Unsupported loss function {} for problem type {}' .format (cfg .TRAIN .LOSS_FUNC ,
143
+ cfg .PROBLEM .TYPE ))
144
+
145
+ # compute accuracy
146
+ acc = torch .zeros (1 , device = model .module .device )
147
+ for x_pred , x_gt , (idx_src , idx_tgt ) in \
148
+ zip (outputs ['perm_mat_list' ], gt_perm_mat_list , outputs ['graph_indices' ]):
149
+ a = matching_accuracy (x_pred , x_gt , ns [idx_src ])
150
+ acc += torch .sum (a )
151
+ acc /= len (outputs ['perm_mat_list' ])
152
+ else :
153
+ raise ValueError ('Unknown problem type {}' .format (cfg .PROBLEM .TYPE ))
154
+
155
+ # backward + optimize
156
+ if cfg .FP16 :
157
+ with amp .scale_loss (loss , optimizer ) as scaled_loss :
158
+ scaled_loss .backward ()
159
+ else :
160
+ loss .backward ()
161
+ optimizer .step ()
162
+
163
+ batch_num = inputs ['batch_size' ]
164
+
165
+ # tfboard writer
166
+ loss_dict = dict ()
167
+ loss_dict ['loss' ] = loss .item ()
168
+ tfboard_writer .add_scalars ('loss' , loss_dict , epoch * cfg .TRAIN .EPOCH_ITERS + iter_num )
169
+
170
+ accdict = dict ()
171
+ accdict ['matching accuracy' ] = torch .mean (acc )
172
+ tfboard_writer .add_scalars (
173
+ 'training accuracy' ,
174
+ accdict ,
175
+ epoch * cfg .TRAIN .EPOCH_ITERS + iter_num
176
+ )
177
+
178
+ # statistics
179
+ running_loss += loss .item () * batch_num
180
+ epoch_loss += loss .item () * batch_num
181
+
182
+ if iter_num % cfg .STATISTIC_STEP == 0 :
183
+ running_speed = cfg .STATISTIC_STEP * batch_num / (time .time () - running_since )
184
+ print ('Epoch {:<4} Iteration {:<4} {:>4.2f}sample/s Loss={:<8.4f}'
185
+ .format (epoch , iter_num , running_speed , running_loss / cfg .STATISTIC_STEP / batch_num ))
186
+ tfboard_writer .add_scalars (
187
+ 'speed' ,
188
+ {'speed' : running_speed },
189
+ epoch * cfg .TRAIN .EPOCH_ITERS + iter_num
190
+ )
166
191
167
- accdict = dict ()
168
- accdict ['matching accuracy' ] = torch .mean (acc )
169
192
tfboard_writer .add_scalars (
170
- 'training accuracy ' ,
171
- accdict ,
193
+ 'learning rate ' ,
194
+ { 'lr_{}' . format ( i ): x [ 'lr' ] for i , x in enumerate ( optimizer . param_groups )} ,
172
195
epoch * cfg .TRAIN .EPOCH_ITERS + iter_num
173
196
)
174
197
175
- # statistics
176
- running_loss += loss .item () * batch_num
177
- epoch_loss += loss .item () * batch_num
178
-
179
- if iter_num % cfg .STATISTIC_STEP == 0 :
180
- running_speed = cfg .STATISTIC_STEP * batch_num / (time .time () - running_since )
181
- print ('Epoch {:<4} Iteration {:<4} {:>4.2f}sample/s Loss={:<8.4f}'
182
- .format (epoch , iter_num , running_speed , running_loss / cfg .STATISTIC_STEP / batch_num ))
183
- tfboard_writer .add_scalars (
184
- 'speed' ,
185
- {'speed' : running_speed },
186
- epoch * cfg .TRAIN .EPOCH_ITERS + iter_num
187
- )
188
-
189
- tfboard_writer .add_scalars (
190
- 'learning rate' ,
191
- {'lr_{}' .format (i ): x ['lr' ] for i , x in enumerate (optimizer .param_groups )},
192
- epoch * cfg .TRAIN .EPOCH_ITERS + iter_num
193
- )
194
-
195
- running_loss = 0.0
196
- running_since = time .time ()
198
+ running_loss = 0.0
199
+ running_since = time .time ()
197
200
198
201
epoch_loss = epoch_loss / cfg .TRAIN .EPOCH_ITERS / batch_num
199
202
@@ -253,6 +256,7 @@ def train_eval_model(model,
253
256
filter = cfg .PROBLEM .FILTER ,
254
257
** ds_dict )
255
258
for x in ('train' , 'test' )}
259
+
256
260
image_dataset = {
257
261
x : GMDataset (name = cfg .DATASET_FULL_NAME ,
258
262
bm = benchmark [x ],
@@ -325,7 +329,7 @@ def train_eval_model(model,
325
329
326
330
with DupStdoutFileManager (str (Path (cfg .OUTPUT_PATH ) / ('train_log_' + now_time + '.log' ))) as _ :
327
331
print_easydict (cfg )
328
- model = train_eval_model (model , criterion , optimizer , dataloader , tfboardwriter , benchmark ,
332
+ model = train_eval_model (model , criterion , optimizer , image_dataset , dataloader , tfboardwriter , benchmark ,
329
333
num_epochs = cfg .TRAIN .NUM_EPOCHS ,
330
334
start_epoch = cfg .TRAIN .START_EPOCH ,
331
335
xls_wb = wb )
0 commit comments