From eb8dbf71c7e386b304735a7f4892f200b797516a Mon Sep 17 00:00:00 2001 From: Mohammad Fattouh Date: Fri, 20 Jul 2018 11:32:00 +0200 Subject: [PATCH] avoid zero divisions and negative values in log --- lib/utils/boxes.py | 7 +++++-- tools/train_net_step.py | 3 +++ 2 files changed, 8 insertions(+), 2 deletions(-) diff --git a/lib/utils/boxes.py b/lib/utils/boxes.py index 1b776a7f..0ad7267d 100644 --- a/lib/utils/boxes.py +++ b/lib/utils/boxes.py @@ -211,6 +211,9 @@ def bbox_transform_inv(boxes, gt_boxes, weights=(1.0, 1.0, 1.0, 1.0)): """ ex_widths = boxes[:, 2] - boxes[:, 0] + 1.0 ex_heights = boxes[:, 3] - boxes[:, 1] + 1.0 + # replace zeros with very small values + ex_widths[ex_widths == 0] = cfg.EPS + ex_heights[ex_heights == 0] = cfg.EPS ex_ctr_x = boxes[:, 0] + 0.5 * ex_widths ex_ctr_y = boxes[:, 1] + 0.5 * ex_heights @@ -222,8 +225,8 @@ def bbox_transform_inv(boxes, gt_boxes, weights=(1.0, 1.0, 1.0, 1.0)): wx, wy, ww, wh = weights targets_dx = wx * (gt_ctr_x - ex_ctr_x) / ex_widths targets_dy = wy * (gt_ctr_y - ex_ctr_y) / ex_heights - targets_dw = ww * np.log(gt_widths / ex_widths) - targets_dh = wh * np.log(gt_heights / ex_heights) + targets_dw = ww * np.log(np.maximum(gt_widths / ex_widths, cfg.EPS)) + targets_dh = wh * np.log(np.maximum(gt_heights / ex_heights, cfg.EPS)) targets = np.vstack((targets_dx, targets_dy, targets_dw, targets_dh)).transpose() diff --git a/tools/train_net_step.py b/tools/train_net_step.py index 679076bc..406e8283 100644 --- a/tools/train_net_step.py +++ b/tools/train_net_step.py @@ -424,6 +424,9 @@ def main(): net_outputs = maskRCNN(**input_data) training_stats.UpdateIterStats(net_outputs, inner_iter) loss = net_outputs['total_loss'] + if torch.isnan(loss): + logger.info('NaN loss found! Skipping the current step ...') + break loss.backward() optimizer.step() training_stats.IterToc()