Skip to content

Commit

Permalink
convert datatype of target labels during one-hot encoding
Browse files Browse the repository at this point in the history
  • Loading branch information
Chilicyy committed Sep 8, 2022
1 parent 607727e commit 5b04545
Showing 1 changed file with 1 addition and 1 deletion.
2 changes: 1 addition & 1 deletion yolov6/models/loss.py
Original file line number Diff line number Diff line change
Expand Up @@ -97,7 +97,7 @@ def __call__(

# cls loss
target_labels = torch.where(fg_mask > 0, target_labels, torch.full_like(target_labels, self.num_classes))
one_hot_label = F.one_hot(target_labels, self.num_classes + 1)[..., :-1]
one_hot_label = F.one_hot(target_labels.long(), self.num_classes + 1)[..., :-1]
loss_cls = self.varifocal_loss(pred_scores, target_scores, one_hot_label)

target_scores_sum = target_scores.sum()
Expand Down

0 comments on commit 5b04545

Please sign in to comment.