Skip to content

Commit

Permalink
main code
Browse files Browse the repository at this point in the history
fix nan of aux training WongKinYiu#250 (comment) @hudingding
  • Loading branch information
WongKinYiu authored Jul 21, 2022
1 parent de6a5e7 commit 4f6e390
Showing 1 changed file with 25 additions and 16 deletions.
41 changes: 25 additions & 16 deletions utils/loss.py
Original file line number Diff line number Diff line change
Expand Up @@ -1218,46 +1218,55 @@ def __call__(self, p, targets, imgs): # predictions, targets, model
tobj_aux = torch.zeros_like(pi_aux[..., 0], device=device) # target obj

n = b.shape[0] # number of targets
n_aux = b_aux.shape[0] # number of targets
if n:
ps = pi[b, a, gj, gi] # prediction subset corresponding to targets
ps_aux = pi_aux[b_aux, a_aux, gj_aux, gi_aux] # prediction subset corresponding to targets

# Regression
grid = torch.stack([gi, gj], dim=1)
grid_aux = torch.stack([gi_aux, gj_aux], dim=1)
pxy = ps[:, :2].sigmoid() * 2. - 0.5
pxy_aux = ps_aux[:, :2].sigmoid() * 2. - 0.5
#pxy = ps[:, :2].sigmoid() * 3. - 1.
pwh = (ps[:, 2:4].sigmoid() * 2) ** 2 * anchors[i]
pwh_aux = (ps_aux[:, 2:4].sigmoid() * 2) ** 2 * anchors_aux[i]
pbox = torch.cat((pxy, pwh), 1) # predicted box
pbox_aux = torch.cat((pxy_aux, pwh_aux), 1) # predicted box
selected_tbox = targets[i][:, 2:6] * pre_gen_gains[i]
selected_tbox_aux = targets_aux[i][:, 2:6] * pre_gen_gains_aux[i]
selected_tbox[:, :2] -= grid
selected_tbox_aux[:, :2] -= grid_aux
iou = bbox_iou(pbox.T, selected_tbox, x1y1x2y2=False, CIoU=True) # iou(prediction, target)
iou_aux = bbox_iou(pbox_aux.T, selected_tbox_aux, x1y1x2y2=False, CIoU=True) # iou(prediction, target)
lbox += (1.0 - iou).mean() + 0.25 * (1.0 - iou_aux).mean() # iou loss
lbox += (1.0 - iou).mean() # iou loss

# Objectness
tobj[b, a, gj, gi] = (1.0 - self.gr) + self.gr * iou.detach().clamp(0).type(tobj.dtype) # iou ratio
tobj_aux[b_aux, a_aux, gj_aux, gi_aux] = (1.0 - self.gr) + self.gr * iou_aux.detach().clamp(0).type(tobj_aux.dtype) # iou ratio

# Classification
selected_tcls = targets[i][:, 1].long()
selected_tcls_aux = targets_aux[i][:, 1].long()
if self.nc > 1: # cls loss (only if multiple classes)
t = torch.full_like(ps[:, 5:], self.cn, device=device) # targets
t_aux = torch.full_like(ps_aux[:, 5:], self.cn, device=device) # targets
t[range(n), selected_tcls] = self.cp
t_aux[range(n_aux), selected_tcls_aux] = self.cp
lcls += self.BCEcls(ps[:, 5:], t) + 0.25 * self.BCEcls(ps_aux[:, 5:], t_aux) # BCE
lcls += self.BCEcls(ps[:, 5:], t) # BCE

# Append targets to text file
# with open('targets.txt', 'a') as file:
# [file.write('%11.5g ' * 4 % tuple(x) + '\n') for x in torch.cat((txy[i], twh[i]), 1)]

n_aux = b_aux.shape[0] # number of targets
if n_aux:
ps_aux = pi_aux[b_aux, a_aux, gj_aux, gi_aux] # prediction subset corresponding to targets
grid_aux = torch.stack([gi_aux, gj_aux], dim=1)
pxy_aux = ps_aux[:, :2].sigmoid() * 2. - 0.5
#pxy_aux = ps_aux[:, :2].sigmoid() * 3. - 1.
pwh_aux = (ps_aux[:, 2:4].sigmoid() * 2) ** 2 * anchors_aux[i]
pbox_aux = torch.cat((pxy_aux, pwh_aux), 1) # predicted box
selected_tbox_aux = targets_aux[i][:, 2:6] * pre_gen_gains_aux[i]
selected_tbox_aux[:, :2] -= grid_aux
iou_aux = bbox_iou(pbox_aux.T, selected_tbox_aux, x1y1x2y2=False, CIoU=True) # iou(prediction, target)
lbox += 0.25 * (1.0 - iou_aux).mean() # iou loss

# Objectness
tobj_aux[b_aux, a_aux, gj_aux, gi_aux] = (1.0 - self.gr) + self.gr * iou_aux.detach().clamp(0).type(tobj_aux.dtype) # iou ratio

# Classification
selected_tcls_aux = targets_aux[i][:, 1].long()
if self.nc > 1: # cls loss (only if multiple classes)
t_aux = torch.full_like(ps_aux[:, 5:], self.cn, device=device) # targets
t_aux[range(n_aux), selected_tcls_aux] = self.cp
lcls += 0.25 * self.BCEcls(ps_aux[:, 5:], t_aux) # BCE

obji = self.BCEobj(pi[..., 4], tobj)
obji_aux = self.BCEobj(pi_aux[..., 4], tobj_aux)
Expand Down

0 comments on commit 4f6e390

Please sign in to comment.