Skip to content

Commit

Permalink
main code
Browse files Browse the repository at this point in the history
fix gain for train_aux
  • Loading branch information
WongKinYiu authored Jul 14, 2022
1 parent 4cebf40 commit ef4dde4
Showing 1 changed file with 3 additions and 3 deletions.
6 changes: 3 additions & 3 deletions utils/loss.py
Original file line number Diff line number Diff line change
Expand Up @@ -1115,7 +1115,7 @@ def find_3_positive(self, p, targets):
# Build targets for compute_loss(), input targets(image,class,x,y,w,h)
na, nt = self.na, targets.shape[0] # number of anchors, targets
indices, anch = [], []
gain = torch.ones(7, device=targets.device) # normalized to gridspace gain
gain = torch.ones(7, device=targets.device).long() # normalized to gridspace gain
ai = torch.arange(na, device=targets.device).float().view(na, 1).repeat(1, nt) # same as .repeat_interleave(nt)
targets = torch.cat((targets.repeat(na, 1, 1), ai[:, :, None]), 2) # append anchor indices

Expand Down Expand Up @@ -1561,7 +1561,7 @@ def find_5_positive(self, p, targets):
# Build targets for compute_loss(), input targets(image,class,x,y,w,h)
na, nt = self.na, targets.shape[0] # number of anchors, targets
indices, anch = [], []
gain = torch.ones(7, device=targets.device) # normalized to gridspace gain
gain = torch.ones(7, device=targets.device).long() # normalized to gridspace gain
ai = torch.arange(na, device=targets.device).float().view(na, 1).repeat(1, nt) # same as .repeat_interleave(nt)
targets = torch.cat((targets.repeat(na, 1, 1), ai[:, :, None]), 2) # append anchor indices

Expand Down Expand Up @@ -1614,7 +1614,7 @@ def find_3_positive(self, p, targets):
# Build targets for compute_loss(), input targets(image,class,x,y,w,h)
na, nt = self.na, targets.shape[0] # number of anchors, targets
indices, anch = [], []
gain = torch.ones(7, device=targets.device) # normalized to gridspace gain
gain = torch.ones(7, device=targets.device).long() # normalized to gridspace gain
ai = torch.arange(na, device=targets.device).float().view(na, 1).repeat(1, nt) # same as .repeat_interleave(nt)
targets = torch.cat((targets.repeat(na, 1, 1), ai[:, :, None]), 2) # append anchor indices

Expand Down

0 comments on commit ef4dde4

Please sign in to comment.