Skip to content

Commit

Permalink
change logic to avoid loss divide by zero
Browse files Browse the repository at this point in the history
  • Loading branch information
mtjhl authored and Chilicyy committed Nov 4, 2022
1 parent 4ee4300 commit 2f83a8e
Show file tree
Hide file tree
Showing 2 changed files with 23 additions and 13 deletions.
17 changes: 11 additions & 6 deletions yolov6/models/loss.py
Original file line number Diff line number Diff line change
Expand Up @@ -156,9 +156,8 @@ def __call__(

target_scores_sum = target_scores.sum()
# avoid devide zero error, devide by zero will cause loss to be inf or nan.
if target_scores_sum == 0:
loss_cls *= 0
else:
# if target_scores_sum is 0, loss_cls equals to 0 alson
if target_scores_sum > 0:
loss_cls /= target_scores_sum

# bbox loss
Expand Down Expand Up @@ -228,8 +227,11 @@ def forward(self, pred_dist, pred_bboxes, anchor_points,
target_scores.sum(-1), fg_mask).unsqueeze(-1)
loss_iou = self.iou_loss(pred_bboxes_pos,
target_bboxes_pos) * bbox_weight
loss_iou = loss_iou.sum() / target_scores_sum

if target_scores_sum == 0:
loss_iou = loss_iou.sum()
else:
loss_iou = loss_iou.sum() / target_scores_sum

# dfl loss
if self.use_dfl:
dist_mask = fg_mask.unsqueeze(-1).repeat(
Expand All @@ -241,7 +243,10 @@ def forward(self, pred_dist, pred_bboxes, anchor_points,
target_ltrb, bbox_mask).reshape([-1, 4])
loss_dfl = self._df_loss(pred_dist_pos,
target_ltrb_pos) * bbox_weight
loss_dfl = loss_dfl.sum() / target_scores_sum
if target_scores_sum == 0:
loss_dfl = loss_dfl.sum()
else:
loss_dfl = loss_dfl.sum() / target_scores_sum
else:
loss_dfl = torch.tensor(0.).to(pred_dist.device)

Expand Down
19 changes: 12 additions & 7 deletions yolov6/models/loss_distill.py
Original file line number Diff line number Diff line change
Expand Up @@ -175,10 +175,8 @@ def __call__(

target_scores_sum = target_scores.sum()
# avoid devide zero error, devide by zero will cause loss to be inf or nan.
if target_scores_sum == 0:
loss_cls *= 0
else:
loss_cls /= target_scores_sum
if target_scores_sum > 0:
loss_cls /= target_scores_sum

# bbox loss
loss_iou, loss_dfl, d_loss_dfl = self.bbox_loss(pred_distri, pred_bboxes, t_pred_distri, t_pred_bboxes, temperature, anchor_points_s,
Expand Down Expand Up @@ -300,7 +298,10 @@ def forward(self, pred_dist, pred_bboxes, t_pred_dist, t_pred_bboxes, temperatur
target_scores.sum(-1), fg_mask).unsqueeze(-1)
loss_iou = self.iou_loss(pred_bboxes_pos,
target_bboxes_pos) * bbox_weight
loss_iou = loss_iou.sum() / target_scores_sum
if target_scores_sum == 0:
loss_iou = loss_iou.sum()
else:
loss_iou = loss_iou.sum() / target_scores_sum

# dfl loss
if self.use_dfl:
Expand All @@ -316,8 +317,12 @@ def forward(self, pred_dist, pred_bboxes, t_pred_dist, t_pred_bboxes, temperatur
loss_dfl = self._df_loss(pred_dist_pos,
target_ltrb_pos) * bbox_weight
d_loss_dfl = self.distill_loss_dfl(pred_dist_pos, t_pred_dist_pos, temperature) * bbox_weight
loss_dfl = loss_dfl.sum() / target_scores_sum
d_loss_dfl = d_loss_dfl.sum() / target_scores_sum
if target_scores_sum == 0:
loss_dfl = loss_dfl.sum()
d_loss_dfl = d_loss_dfl.sum()
else:
loss_dfl = loss_dfl.sum() / target_scores_sum
d_loss_dfl = d_loss_dfl.sum() / target_scores_sum
else:
loss_dfl = torch.tensor(0.).to(pred_dist.device)
d_loss_dfl = torch.tensor(0.).to(pred_dist.device)
Expand Down

0 comments on commit 2f83a8e

Please sign in to comment.