Skip to content

Commit

Permalink
main code
Browse files Browse the repository at this point in the history
  • Loading branch information
WongKinYiu authored Jul 10, 2022
1 parent eef4f2c commit c587e8f
Showing 1 changed file with 14 additions and 6 deletions.
20 changes: 14 additions & 6 deletions utils/loss.py
Original file line number Diff line number Diff line change
Expand Up @@ -775,12 +775,20 @@ def build_targets(self, p, targets, imgs):
matching_anchs[i].append(all_anch[layer_idx])

for i in range(nl):
matching_bs[i] = torch.cat(matching_bs[i], dim=0)
matching_as[i] = torch.cat(matching_as[i], dim=0)
matching_gjs[i] = torch.cat(matching_gjs[i], dim=0)
matching_gis[i] = torch.cat(matching_gis[i], dim=0)
matching_targets[i] = torch.cat(matching_targets[i], dim=0)
matching_anchs[i] = torch.cat(matching_anchs[i], dim=0)
if matching_gis[i] != []:
matching_bs[i] = torch.cat(matching_bs[i], dim=0)
matching_as[i] = torch.cat(matching_as[i], dim=0)
matching_gjs[i] = torch.cat(matching_gjs[i], dim=0)
matching_gis[i] = torch.cat(matching_gis[i], dim=0)
matching_targets[i] = torch.cat(matching_targets[i], dim=0)
matching_anchs[i] = torch.cat(matching_anchs[i], dim=0)
else:
matching_bs[i] = torch.tensor([], device='cuda:0', dtype=torch.int64)
matching_as[i] = torch.tensor([], device='cuda:0', dtype=torch.int64)
matching_gjs[i] = torch.tensor([], device='cuda:0', dtype=torch.int64)
matching_gis[i] = torch.tensor([], device='cuda:0', dtype=torch.int64)
matching_targets[i] = torch.tensor([], device='cuda:0', dtype=torch.int64)
matching_anchs[i] = torch.tensor([], device='cuda:0', dtype=torch.int64)

return matching_bs, matching_as, matching_gjs, matching_gis, matching_targets, matching_anchs

Expand Down

0 comments on commit c587e8f

Please sign in to comment.