Skip to content
This repository was archived by the owner on Jan 26, 2022. It is now read-only.

Handle NaN loss #110

Open
wants to merge 4 commits into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 5 additions & 2 deletions lib/utils/boxes.py
Original file line number Diff line number Diff line change
Expand Up @@ -211,6 +211,9 @@ def bbox_transform_inv(boxes, gt_boxes, weights=(1.0, 1.0, 1.0, 1.0)):
"""
ex_widths = boxes[:, 2] - boxes[:, 0] + 1.0
ex_heights = boxes[:, 3] - boxes[:, 1] + 1.0
# replace zeros with very small values
ex_widths[ex_widths == 0] = cfg.EPS
ex_heights[ex_heights == 0] = cfg.EPS
ex_ctr_x = boxes[:, 0] + 0.5 * ex_widths
ex_ctr_y = boxes[:, 1] + 0.5 * ex_heights

Expand All @@ -222,8 +225,8 @@ def bbox_transform_inv(boxes, gt_boxes, weights=(1.0, 1.0, 1.0, 1.0)):
wx, wy, ww, wh = weights
targets_dx = wx * (gt_ctr_x - ex_ctr_x) / ex_widths
targets_dy = wy * (gt_ctr_y - ex_ctr_y) / ex_heights
targets_dw = ww * np.log(gt_widths / ex_widths)
targets_dh = wh * np.log(gt_heights / ex_heights)
targets_dw = ww * np.log(np.maximum(gt_widths / ex_widths, cfg.EPS))
targets_dh = wh * np.log(np.maximum(gt_heights / ex_heights, cfg.EPS))

targets = np.vstack((targets_dx, targets_dy, targets_dw,
targets_dh)).transpose()
Expand Down
3 changes: 3 additions & 0 deletions tools/train_net_step.py
Original file line number Diff line number Diff line change
Expand Up @@ -424,6 +424,9 @@ def main():
net_outputs = maskRCNN(**input_data)
training_stats.UpdateIterStats(net_outputs, inner_iter)
loss = net_outputs['total_loss']
if torch.isnan(loss):
logger.info('NaN loss found! Skipping the current step ...')
break
loss.backward()
optimizer.step()
training_stats.IterToc()
Expand Down