Skip to content

Commit

Permalink
fix the bug that will cause classification loss to be inf.
Browse files Browse the repository at this point in the history
  • Loading branch information
mtjhl authored and Chilicyy committed Nov 4, 2022
1 parent fdbdd95 commit 12d302a
Show file tree
Hide file tree
Showing 2 changed files with 10 additions and 2 deletions.
6 changes: 5 additions & 1 deletion yolov6/models/loss.py
Original file line number Diff line number Diff line change
Expand Up @@ -155,7 +155,11 @@ def __call__(
loss_cls = self.varifocal_loss(pred_scores, target_scores, one_hot_label)

target_scores_sum = target_scores.sum()
loss_cls /= target_scores_sum
# avoid devide zero error, devide by zero will cause loss to be inf or nan.
if target_scores_sum == 0:
loss_cls *= 0
else:
loss_cls /= target_scores_sum

# bbox loss
loss_iou, loss_dfl = self.bbox_loss(pred_distri, pred_bboxes, anchor_points_s, target_bboxes,
Expand Down
6 changes: 5 additions & 1 deletion yolov6/models/loss_distill.py
Original file line number Diff line number Diff line change
Expand Up @@ -174,7 +174,11 @@ def __call__(
loss_cls = self.varifocal_loss(pred_scores, target_scores, one_hot_label)

target_scores_sum = target_scores.sum()
loss_cls /= target_scores_sum
# avoid devide zero error, devide by zero will cause loss to be inf or nan.
if target_scores_sum == 0:
loss_cls *= 0
else:
loss_cls /= target_scores_sum

# bbox loss
loss_iou, loss_dfl, d_loss_dfl = self.bbox_loss(pred_distri, pred_bboxes, t_pred_distri, t_pred_bboxes, temperature, anchor_points_s,
Expand Down

0 comments on commit 12d302a

Please sign in to comment.