Skip to content

Commit

Permalink
Update utils.py
Browse files Browse the repository at this point in the history
  • Loading branch information
pprp committed Mar 29, 2020
1 parent 6c19620 commit 8108f41
Showing 1 changed file with 8 additions and 6 deletions.
14 changes: 8 additions & 6 deletions utils/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -424,6 +424,7 @@ def smooth_BCE(eps=0.1):
return 1.0 - 0.5 * eps, 0.5 * eps

def compute_loss(p, targets, model):
# p: (bs, anchors, grid, grid, classes + xywh)
# predictions, targets, model
ft = torch.cuda.FloatTensor if p[0].is_cuda else torch.Tensor
lcls, lbox, lobj = ft([0]), ft([0]), ft([0])
Expand Down Expand Up @@ -511,14 +512,14 @@ def compute_loss(p, targets, model):
elif 'BCE' in arc: # unified BCE (80 classes)
t = torch.zeros_like(pi[..., 5:]) # targets
if nb:
t[b, a, gj, gi, tcls[i]] = 1.0 # 对应class置信度设置为1
t[b, a, gj, gi, tcls[i]] = 1.0 # 对应正样本class置信度设置为1
lobj += BCE(pi[..., 5:], t)
#pi[...,5:]对应的是所有的class

elif 'CE' in arc: # unified CE (1 background + 80 classes)
t = torch.zeros_like(pi[..., 0], dtype=torch.long) # targets
if nb:
t[b, a, gj, gi] = tcls[i] + 1
t[b, a, gj, gi] = tcls[i] + 1 # 由于cls是从零开始计数的,所以+1
lcls += CE(pi[..., 4:].view(-1, model.nc + 1), t.view(-1))
# 这里将obj loss和cls loss一起计算,使用CrossEntropy Loss

Expand All @@ -529,9 +530,11 @@ def compute_loss(p, targets, model):

if red == 'sum':
bs = tobj.shape[0] # batch size
lobj *= 2 / (2535 * bs) * 2 # 3 / np * 2
# from 6300 to (13**2+26**2)*3 = 2535
# 3代表3个yolo层
lobj *= 3 / (6300 * bs) * 2
# 6300 = (10 ** 2 + 20 ** 2 + 40 ** 2) * 3
# 输入为320x320的图片,则存在6300个anchor
# 3代表3个yolo层, 2是一个超参数,通过实验获取
# 如果不想计算的话,可以修改red='mean'
if ng:
lcls *= 3 / ng / model.nc
lbox *= 3 / ng
Expand Down Expand Up @@ -633,7 +636,6 @@ def build_targets(model, targets):
assert c.max() < model.nc, 'Model accepts %g classes labeled from 0-%g, however you labelled a class %g. ' \
'See https://github.com/ultralytics/yolov3/wiki/Train-Custom-Data' % (
model.nc, model.nc - 1, c.max())
# tcls, tbox, indices, anchor_vec = build_targets(model, targets)
return tcls, tbox, indices, av


Expand Down

0 comments on commit 8108f41

Please sign in to comment.