Skip to content

Commit

Permalink
Fuse IAuxDetect
Browse files Browse the repository at this point in the history
  • Loading branch information
AlexeyAB committed Jul 28, 2022
1 parent 4c207e1 commit 954cde6
Showing 1 changed file with 71 additions and 3 deletions.
74 changes: 71 additions & 3 deletions models/yolo.py
Original file line number Diff line number Diff line change
Expand Up @@ -303,6 +303,8 @@ def _make_grid(nx=20, ny=20):
class IAuxDetect(nn.Module):
stride = None # strides computed during build
export = False # onnx export
end2end = False
include_nms = False

def __init__(self, nc=80, anchors=(), ch=()): # detection layer
super(IAuxDetect, self).__init__()
Expand Down Expand Up @@ -338,17 +340,83 @@ def forward(self, x):
self.grid[i] = self._make_grid(nx, ny).to(x[i].device)

y = x[i].sigmoid()
y[..., 0:2] = (y[..., 0:2] * 2. - 0.5 + self.grid[i]) * self.stride[i] # xy
y[..., 2:4] = (y[..., 2:4] * 2) ** 2 * self.anchor_grid[i] # wh
if not torch.onnx.is_in_onnx_export():
y[..., 0:2] = (y[..., 0:2] * 2. - 0.5 + self.grid[i]) * self.stride[i] # xy
y[..., 2:4] = (y[..., 2:4] * 2) ** 2 * self.anchor_grid[i] # wh
else:
xy = (y[..., 0:2] * 2. - 0.5 + self.grid[i]) * self.stride[i] # xy
wh = (y[..., 2:4] * 2) ** 2 * self.anchor_grid[i].data # wh
y = torch.cat((xy, wh, y[..., 4:]), -1)
z.append(y.view(bs, -1, self.no))

return x if self.training else (torch.cat(z, 1), x[:self.nl])

def fuseforward(self, x):
# x = x.copy() # for profiling
z = [] # inference output
self.training |= self.export
for i in range(self.nl):
x[i] = self.m[i](x[i]) # conv
bs, _, ny, nx = x[i].shape # x(bs,255,20,20) to x(bs,3,20,20,85)
x[i] = x[i].view(bs, self.na, self.no, ny, nx).permute(0, 1, 3, 4, 2).contiguous()

if not self.training: # inference
if self.grid[i].shape[2:4] != x[i].shape[2:4]:
self.grid[i] = self._make_grid(nx, ny).to(x[i].device)

y = x[i].sigmoid()
if not torch.onnx.is_in_onnx_export():
y[..., 0:2] = (y[..., 0:2] * 2. - 0.5 + self.grid[i]) * self.stride[i] # xy
y[..., 2:4] = (y[..., 2:4] * 2) ** 2 * self.anchor_grid[i] # wh
else:
xy = (y[..., 0:2] * 2. - 0.5 + self.grid[i]) * self.stride[i] # xy
wh = (y[..., 2:4] * 2) ** 2 * self.anchor_grid[i].data # wh
y = torch.cat((xy, wh, y[..., 4:]), -1)
z.append(y.view(bs, -1, self.no))

if self.training:
out = x
elif self.end2end:
out = torch.cat(z, 1)
elif self.include_nms:
z = self.convert(z)
out = (z, )
else:
out = (torch.cat(z, 1), x)

return out

def fuse(self):
print("IAuxDetect.fuse")
# fuse ImplicitA and Convolution
for i in range(len(self.m)):
c1,c2,_,_ = self.m[i].weight.shape
c1_,c2_, _,_ = self.ia[i].implicit.shape
self.m[i].bias += torch.matmul(self.m[i].weight.reshape(c1,c2),self.ia[i].implicit.reshape(c2_,c1_)).squeeze(1)

# fuse ImplicitM and Convolution
for i in range(len(self.m)):
c1,c2, _,_ = self.im[i].implicit.shape
self.m[i].bias *= self.im[i].implicit.reshape(c2)
self.m[i].weight *= self.im[i].implicit.transpose(0,1)

@staticmethod
def _make_grid(nx=20, ny=20):
yv, xv = torch.meshgrid([torch.arange(ny), torch.arange(nx)])
return torch.stack((xv, yv), 2).view((1, 1, ny, nx, 2)).float()

def convert(self, z):
z = torch.cat(z, 1)
box = z[:, :, :4]
conf = z[:, :, 4:5]
score = z[:, :, 5:]
score *= conf
convert_matrix = torch.tensor([[1, 0, 1, 0], [0, 1, 0, 1], [-0.5, 0, 0.5, 0], [0, -0.5, 0, 0.5]],
dtype=torch.float32,
device=z.device)
box @= convert_matrix
return (box, score)


class IBin(nn.Module):
stride = None # strides computed during build
Expand Down Expand Up @@ -623,7 +691,7 @@ def fuse(self): # fuse model Conv2d() + BatchNorm2d() layers
m.conv = fuse_conv_and_bn(m.conv, m.bn) # update conv
delattr(m, 'bn') # remove batchnorm
m.forward = m.fuseforward # update forward
elif isinstance(m, IDetect):
elif isinstance(m, (IDetect, IAuxDetect)):
m.fuse()
m.forward = m.fuseforward
self.info()
Expand Down

0 comments on commit 954cde6

Please sign in to comment.