Skip to content

Commit

Permalink
Update loss_utils.py
Browse files Browse the repository at this point in the history
  • Loading branch information
liruihui authored Nov 24, 2020
1 parent 3994722 commit bfe4a51
Showing 1 changed file with 3 additions and 2 deletions.
5 changes: 3 additions & 2 deletions Common/loss_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -60,8 +60,9 @@ def cls_loss(pred, pred_aug, gold, pc_tran, aug_tran, pc_feat, aug_feat, ispn =
cls_aug = cls_aug + 0.001*mat_loss(aug_tran)

feat_diff = 10.0*mse_fn(pc_feat,aug_feat)

cls_loss = cls_pc + cls_aug + feat_diff #+ cls_diff
parameters = torch.max(torch.tensor(NUM).cuda(), torch.exp(1.0-cls_pc_raw)**2).cuda()
cls_diff = (torch.abs(cls_pc_raw - cls_aug_raw) * (parameters*2)).mean()
cls_loss = cls_pc + cls_aug + feat_diff + cls_diff

return cls_loss

Expand Down

0 comments on commit bfe4a51

Please sign in to comment.