Skip to content

Commit a371646

Browse files
committed
Merge branch 'master' of github.com:Thinklab-SJTU/ThinkMatch
2 parents 0f4f40b + 336baac commit a371646

File tree

3 files changed

+123
-115
lines changed

3 files changed

+123
-115
lines changed

eval.py

+3-1
Original file line number | Diff line number | Diff line change
@@ -32,7 +32,9 @@ def eval_model(model, classes, bm, last_epoch=True, verbose=False, xls_sheet=Non
3232
length=cfg.EVAL.SAMPLES,
3333
cls=cls,
3434
using_all_graphs=cfg.PROBLEM.TEST_ALL_GRAPHS)
35-
torch.manual_seed(cfg.RANDOM_SEED)
35+
36+
torch.manual_seed(cfg.RANDOM_SEED) # Fix fetched data in test-set to prevent variance
37+
3638
dataloader = get_dataloader(image_dataset, shuffle=True)
3739
dataloaders.append(dataloader)
3840

src/dataset/data_loader.py

+2
Original file line number | Diff line number | Diff line change
@@ -85,9 +85,11 @@ def to_pyg_graph(A, P):
8585
def get_pair(self, idx, cls):
8686
#anno_pair, perm_mat = self.bm.get_pair(self.cls if self.cls is not None else
8787
# (idx % (cfg.BATCH_SIZE * len(self.classes))) // cfg.BATCH_SIZE)
88+
8889
cls_num = random.randrange(0, len(self.classes))
8990
ids = list(self.id_combination[cls_num][idx % self.length_list[cls_num]])
9091
anno_pair, perm_mat_, id_list = self.bm.get_data(ids)
92+
9193
perm_mat = perm_mat_[(0, 1)].toarray()
9294
while min(perm_mat.shape[0], perm_mat.shape[1]) <= 2 or perm_mat.size >= cfg.PROBLEM.MAX_PROB_SIZE > 0:
9395
anno_pair, perm_mat_, id_list = self.bm.rand_get_data(cls)

train_eval.py

+118-114
Original file line number | Diff line number | Diff line change
@@ -22,6 +22,7 @@
2222
def train_eval_model(model,
2323
criterion,
2424
optimizer,
25+
image_dataset,
2526
dataloader,
2627
tfboard_writer,
2728
benchmark,
@@ -60,6 +61,9 @@ def train_eval_model(model,
6061
last_epoch=cfg.TRAIN.START_EPOCH - 1)
6162

6263
for epoch in range(start_epoch, num_epochs):
64+
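# Re-seed with an epoch-dependent value at the start of every epoch: evaluation sets the RNG back to cfg.RANDOM_SEED, so (presumably) the training shuffle would otherwise repeat across epochs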
65+
torch.manual_seed(cfg.RANDOM_SEED + epoch + 1)
66+
dataloader['train'] = get_dataloader(image_dataset['train'], shuffle=True, fix_seed=False)
6367
print('Epoch {}/{}'.format(epoch, num_epochs - 1))
6468
print('-' * 10)
6569

@@ -73,127 +77,126 @@ def train_eval_model(model,
7377
iter_num = 0
7478

7579
# Iterate over data.
76-
while iter_num < cfg.TRAIN.EPOCH_ITERS:
77-
for inputs in dataloader['train']:
78-
if iter_num >= cfg.TRAIN.EPOCH_ITERS:
79-
break
80-
if model.module.device != torch.device('cpu'):
81-
inputs = data_to_cuda(inputs)
82-
83-
iter_num = iter_num + 1
84-
85-
# zero the parameter gradients
86-
optimizer.zero_grad()
87-
88-
with torch.set_grad_enabled(True):
89-
# forward
90-
outputs = model(inputs)
91-
92-
if cfg.PROBLEM.TYPE == '2GM':
93-
assert 'ds_mat' in outputs
94-
assert 'perm_mat' in outputs
80+
for inputs in dataloader['train']:
81+
if iter_num >= cfg.TRAIN.EPOCH_ITERS:
82+
break
83+
if model.module.device != torch.device('cpu'):
84+
inputs = data_to_cuda(inputs)
85+
86+
iter_num = iter_num + 1
87+
88+
# zero the parameter gradients
89+
optimizer.zero_grad()
90+
91+
with torch.set_grad_enabled(True):
92+
# forward
93+
outputs = model(inputs)
94+
95+
if cfg.PROBLEM.TYPE == '2GM':
96+
assert 'ds_mat' in outputs
97+
assert 'perm_mat' in outputs
98+
assert 'gt_perm_mat' in outputs
99+
100+
# compute loss
101+
if cfg.TRAIN.LOSS_FUNC == 'offset':
102+
d_gt, grad_mask = displacement(outputs['gt_perm_mat'], *outputs['Ps'], outputs['ns'][0])
103+
d_pred, _ = displacement(outputs['ds_mat'], *outputs['Ps'], outputs['ns'][0])
104+
loss = criterion(d_pred, d_gt, grad_mask)
105+
elif cfg.TRAIN.LOSS_FUNC in ['perm', 'ce', 'hung']:
106+
loss = criterion(outputs['ds_mat'], outputs['gt_perm_mat'], *outputs['ns'])
107+
elif cfg.TRAIN.LOSS_FUNC == 'hamming':
108+
loss = criterion(outputs['perm_mat'], outputs['gt_perm_mat'])
109+
elif cfg.TRAIN.LOSS_FUNC == 'custom':
110+
loss = torch.sum(outputs['loss'])
111+
else:
112+
raise ValueError(
113+
'Unsupported loss function {} for problem type {}'.format(cfg.TRAIN.LOSS_FUNC,
114+
cfg.PROBLEM.TYPE))
115+
116+
# compute accuracy
117+
acc = matching_accuracy(outputs['perm_mat'], outputs['gt_perm_mat'], outputs['ns'][0])
118+
119+
elif cfg.PROBLEM.TYPE in ['MGM', 'MGM3']:
120+
assert 'ds_mat_list' in outputs
121+
assert 'graph_indices' in outputs
122+
assert 'perm_mat_list' in outputs
123+
if not 'gt_perm_mat_list' in outputs:
95124
assert 'gt_perm_mat' in outputs
96-
97-
# compute loss
98-
if cfg.TRAIN.LOSS_FUNC == 'offset':
99-
d_gt, grad_mask = displacement(outputs['gt_perm_mat'], *outputs['Ps'], outputs['ns'][0])
100-
d_pred, _ = displacement(outputs['ds_mat'], *outputs['Ps'], outputs['ns'][0])
101-
loss = criterion(d_pred, d_gt, grad_mask)
102-
elif cfg.TRAIN.LOSS_FUNC in ['perm', 'ce', 'hung']:
103-
loss = criterion(outputs['ds_mat'], outputs['gt_perm_mat'], *outputs['ns'])
104-
elif cfg.TRAIN.LOSS_FUNC == 'hamming':
105-
loss = criterion(outputs['perm_mat'], outputs['gt_perm_mat'])
106-
elif cfg.TRAIN.LOSS_FUNC == 'custom':
107-
loss = torch.sum(outputs['loss'])
108-
else:
109-
raise ValueError(
110-
'Unsupported loss function {} for problem type {}'.format(cfg.TRAIN.LOSS_FUNC,
111-
cfg.PROBLEM.TYPE))
112-
113-
# compute accuracy
114-
acc = matching_accuracy(outputs['perm_mat'], outputs['gt_perm_mat'], outputs['ns'][0])
115-
116-
elif cfg.PROBLEM.TYPE in ['MGM', 'MGM3']:
117-
assert 'ds_mat_list' in outputs
118-
assert 'graph_indices' in outputs
119-
assert 'perm_mat_list' in outputs
120-
if not 'gt_perm_mat_list' in outputs:
121-
assert 'gt_perm_mat' in outputs
122-
gt_perm_mat_list = [outputs['gt_perm_mat'][idx] for idx in outputs['graph_indices']]
123-
else:
124-
gt_perm_mat_list = outputs['gt_perm_mat_list']
125-
126-
# compute loss & accuracy
127-
if cfg.TRAIN.LOSS_FUNC in ['perm', 'ce' 'hung']:
128-
loss = torch.zeros(1, device=model.module.device)
129-
ns = outputs['ns']
130-
for s_pred, x_gt, (idx_src, idx_tgt) in \
131-
zip(outputs['ds_mat_list'], gt_perm_mat_list, outputs['graph_indices']):
132-
l = criterion(s_pred, x_gt, ns[idx_src], ns[idx_tgt])
133-
loss += l
134-
loss /= len(outputs['ds_mat_list'])
135-
elif cfg.TRAIN.LOSS_FUNC == 'plain':
136-
loss = torch.sum(outputs['loss'])
137-
else:
138-
raise ValueError(
139-
'Unsupported loss function {} for problem type {}'.format(cfg.TRAIN.LOSS_FUNC,
140-
cfg.PROBLEM.TYPE))
141-
142-
# compute accuracy
143-
acc = torch.zeros(1, device=model.module.device)
144-
for x_pred, x_gt, (idx_src, idx_tgt) in \
145-
zip(outputs['perm_mat_list'], gt_perm_mat_list, outputs['graph_indices']):
146-
a = matching_accuracy(x_pred, x_gt, ns[idx_src])
147-
acc += torch.sum(a)
148-
acc /= len(outputs['perm_mat_list'])
125+
gt_perm_mat_list = [outputs['gt_perm_mat'][idx] for idx in outputs['graph_indices']]
149126
else:
150-
raise ValueError('Unknown problem type {}'.format(cfg.PROBLEM.TYPE))
151-
152-
# backward + optimize
153-
if cfg.FP16:
154-
with amp.scale_loss(loss, optimizer) as scaled_loss:
155-
scaled_loss.backward()
127+
gt_perm_mat_list = outputs['gt_perm_mat_list']
128+
129+
# compute loss & accuracy
130+
if cfg.TRAIN.LOSS_FUNC in ['perm', 'ce' 'hung']:
131+
loss = torch.zeros(1, device=model.module.device)
132+
ns = outputs['ns']
133+
for s_pred, x_gt, (idx_src, idx_tgt) in \
134+
zip(outputs['ds_mat_list'], gt_perm_mat_list, outputs['graph_indices']):
135+
l = criterion(s_pred, x_gt, ns[idx_src], ns[idx_tgt])
136+
loss += l
137+
loss /= len(outputs['ds_mat_list'])
138+
elif cfg.TRAIN.LOSS_FUNC == 'plain':
139+
loss = torch.sum(outputs['loss'])
156140
else:
157-
loss.backward()
158-
optimizer.step()
159-
160-
batch_num = inputs['batch_size']
161-
162-
# tfboard writer
163-
loss_dict = dict()
164-
loss_dict['loss'] = loss.item()
165-
tfboard_writer.add_scalars('loss', loss_dict, epoch * cfg.TRAIN.EPOCH_ITERS + iter_num)
141+
raise ValueError(
142+
'Unsupported loss function {} for problem type {}'.format(cfg.TRAIN.LOSS_FUNC,
143+
cfg.PROBLEM.TYPE))
144+
145+
# compute accuracy
146+
acc = torch.zeros(1, device=model.module.device)
147+
for x_pred, x_gt, (idx_src, idx_tgt) in \
148+
zip(outputs['perm_mat_list'], gt_perm_mat_list, outputs['graph_indices']):
149+
a = matching_accuracy(x_pred, x_gt, ns[idx_src])
150+
acc += torch.sum(a)
151+
acc /= len(outputs['perm_mat_list'])
152+
else:
153+
raise ValueError('Unknown problem type {}'.format(cfg.PROBLEM.TYPE))
154+
155+
# backward + optimize
156+
if cfg.FP16:
157+
with amp.scale_loss(loss, optimizer) as scaled_loss:
158+
scaled_loss.backward()
159+
else:
160+
loss.backward()
161+
optimizer.step()
162+
163+
batch_num = inputs['batch_size']
164+
165+
# tfboard writer
166+
loss_dict = dict()
167+
loss_dict['loss'] = loss.item()
168+
tfboard_writer.add_scalars('loss', loss_dict, epoch * cfg.TRAIN.EPOCH_ITERS + iter_num)
169+
170+
accdict = dict()
171+
accdict['matching accuracy'] = torch.mean(acc)
172+
tfboard_writer.add_scalars(
173+
'training accuracy',
174+
accdict,
175+
epoch * cfg.TRAIN.EPOCH_ITERS + iter_num
176+
)
177+
178+
# statistics
179+
running_loss += loss.item() * batch_num
180+
epoch_loss += loss.item() * batch_num
181+
182+
if iter_num % cfg.STATISTIC_STEP == 0:
183+
running_speed = cfg.STATISTIC_STEP * batch_num / (time.time() - running_since)
184+
print('Epoch {:<4} Iteration {:<4} {:>4.2f}sample/s Loss={:<8.4f}'
185+
.format(epoch, iter_num, running_speed, running_loss / cfg.STATISTIC_STEP / batch_num))
186+
tfboard_writer.add_scalars(
187+
'speed',
188+
{'speed': running_speed},
189+
epoch * cfg.TRAIN.EPOCH_ITERS + iter_num
190+
)
166191

167-
accdict = dict()
168-
accdict['matching accuracy'] = torch.mean(acc)
169192
tfboard_writer.add_scalars(
170-
'training accuracy',
171-
accdict,
193+
'learning rate',
194+
{'lr_{}'.format(i): x['lr'] for i, x in enumerate(optimizer.param_groups)},
172195
epoch * cfg.TRAIN.EPOCH_ITERS + iter_num
173196
)
174197

175-
# statistics
176-
running_loss += loss.item() * batch_num
177-
epoch_loss += loss.item() * batch_num
178-
179-
if iter_num % cfg.STATISTIC_STEP == 0:
180-
running_speed = cfg.STATISTIC_STEP * batch_num / (time.time() - running_since)
181-
print('Epoch {:<4} Iteration {:<4} {:>4.2f}sample/s Loss={:<8.4f}'
182-
.format(epoch, iter_num, running_speed, running_loss / cfg.STATISTIC_STEP / batch_num))
183-
tfboard_writer.add_scalars(
184-
'speed',
185-
{'speed': running_speed},
186-
epoch * cfg.TRAIN.EPOCH_ITERS + iter_num
187-
)
188-
189-
tfboard_writer.add_scalars(
190-
'learning rate',
191-
{'lr_{}'.format(i): x['lr'] for i, x in enumerate(optimizer.param_groups)},
192-
epoch * cfg.TRAIN.EPOCH_ITERS + iter_num
193-
)
194-
195-
running_loss = 0.0
196-
running_since = time.time()
198+
running_loss = 0.0
199+
running_since = time.time()
197200

198201
epoch_loss = epoch_loss / cfg.TRAIN.EPOCH_ITERS / batch_num
199202

@@ -253,6 +256,7 @@ def train_eval_model(model,
253256
filter=cfg.PROBLEM.FILTER,
254257
**ds_dict)
255258
for x in ('train', 'test')}
259+
256260
image_dataset = {
257261
x: GMDataset(name=cfg.DATASET_FULL_NAME,
258262
bm=benchmark[x],
@@ -325,7 +329,7 @@ def train_eval_model(model,
325329

326330
with DupStdoutFileManager(str(Path(cfg.OUTPUT_PATH) / ('train_log_' + now_time + '.log'))) as _:
327331
print_easydict(cfg)
328-
model = train_eval_model(model, criterion, optimizer, dataloader, tfboardwriter, benchmark,
332+
model = train_eval_model(model, criterion, optimizer, image_dataset, dataloader, tfboardwriter, benchmark,
329333
num_epochs=cfg.TRAIN.NUM_EPOCHS,
330334
start_epoch=cfg.TRAIN.START_EPOCH,
331335
xls_wb=wb)

0 commit comments

Comments
 (0)