
Commit

add QARepVGG V2
mtjhl committed Apr 7, 2023
1 parent 1b9e04d commit 925329d
Showing 4 changed files with 79 additions and 4 deletions.
2 changes: 1 addition & 1 deletion configs/qarepvgg/yolov6m_qa.py
@@ -65,4 +65,4 @@
    mixup=0.1,
)

-training_mode='qarepvgg'
+training_mode='qarepvggv2'
2 changes: 1 addition & 1 deletion configs/qarepvgg/yolov6n_qa.py
@@ -63,4 +63,4 @@
    mosaic=1.0,
    mixup=0.0,
)
-training_mode='qarepvgg'
+training_mode='qarepvggv2'
2 changes: 1 addition & 1 deletion configs/qarepvgg/yolov6s_qa.py
@@ -64,4 +64,4 @@
    mixup=0.0,
)

-training_mode='qarepvgg'
+training_mode='qarepvggv2'
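
The three config edits above only change the training_mode string; the block class it resolves to comes from get_block in yolov6/layers/common.py (see the get_block hunk further down). A minimal sketch of that lookup, assuming the YOLOv6 repo root is on PYTHONPATH:

# Minimal sketch: resolve the config's training_mode string to a block class.
# Assumes the YOLOv6 repo root is on PYTHONPATH so yolov6.layers.common imports.
from yolov6.layers.common import get_block

block_cls = get_block('qarepvggv2')   # branch added in this commit
print(block_cls.__name__)             # expected: QARepVGGBlockV2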
77 changes: 76 additions & 1 deletion yolov6/layers/common.py
@@ -343,6 +343,80 @@ def __init__(self, in_channels, out_channels, kernel_size=3,
                 stride=1, padding=1, dilation=1, groups=1, padding_mode='zeros', deploy=False, use_se=False):
        super(QARepVGGBlock, self).__init__(in_channels, out_channels, kernel_size, stride, padding, dilation, groups,
                                            padding_mode, deploy, use_se)
        if not deploy:
            self.bn = nn.BatchNorm2d(out_channels)
            self.rbr_1x1 = nn.Conv2d(in_channels, out_channels, kernel_size=1, stride=stride, groups=groups, bias=False)
            self.rbr_identity = nn.Identity() if out_channels == in_channels and stride == 1 else None
        self._id_tensor = None

    def forward(self, inputs):
        if hasattr(self, 'rbr_reparam'):
            return self.nonlinearity(self.bn(self.se(self.rbr_reparam(inputs))))

        if self.rbr_identity is None:
            id_out = 0
        else:
            id_out = self.rbr_identity(inputs)

        return self.nonlinearity(self.bn(self.se(self.rbr_dense(inputs) + self.rbr_1x1(inputs) + id_out)))

    def get_equivalent_kernel_bias(self):
        kernel3x3, bias3x3 = self._fuse_bn_tensor(self.rbr_dense)
        kernel = kernel3x3 + self._pad_1x1_to_3x3_tensor(self.rbr_1x1.weight)
        bias = bias3x3

        if self.rbr_identity is not None:
            input_dim = self.in_channels // self.groups
            kernel_value = np.zeros((self.in_channels, input_dim, 3, 3), dtype=np.float32)
            for i in range(self.in_channels):
                kernel_value[i, i % input_dim, 1, 1] = 1
            id_tensor = torch.from_numpy(kernel_value).to(self.rbr_1x1.weight.device)
            kernel = kernel + id_tensor
        return kernel, bias

    def _fuse_extra_bn_tensor(self, kernel, bias, branch):
        assert isinstance(branch, nn.BatchNorm2d)
        running_mean = branch.running_mean - bias  # remove bias
        running_var = branch.running_var
        gamma = branch.weight
        beta = branch.bias
        eps = branch.eps
        std = (running_var + eps).sqrt()
        t = (gamma / std).reshape(-1, 1, 1, 1)
        return kernel * t, beta - running_mean * gamma / std

    def switch_to_deploy(self):
        if hasattr(self, 'rbr_reparam'):
            return
        kernel, bias = self.get_equivalent_kernel_bias()
        self.rbr_reparam = nn.Conv2d(in_channels=self.rbr_dense.conv.in_channels, out_channels=self.rbr_dense.conv.out_channels,
                                     kernel_size=self.rbr_dense.conv.kernel_size, stride=self.rbr_dense.conv.stride,
                                     padding=self.rbr_dense.conv.padding, dilation=self.rbr_dense.conv.dilation, groups=self.rbr_dense.conv.groups, bias=True)
        self.rbr_reparam.weight.data = kernel
        self.rbr_reparam.bias.data = bias
        for para in self.parameters():
            para.detach_()
        self.__delattr__('rbr_dense')
        self.__delattr__('rbr_1x1')
        if hasattr(self, 'rbr_identity'):
            self.__delattr__('rbr_identity')
        if hasattr(self, 'id_tensor'):
            self.__delattr__('id_tensor')
        # keep post bn for QAT
        # if hasattr(self, 'bn'):
        #     self.__delattr__('bn')
        self.deploy = True


class QARepVGGBlockV2(RepVGGBlock):
    """
    QARepVGGBlockV2 is a quantization-aware rep-style block, including training and deploy status.
    This code is based on https://arxiv.org/abs/2212.01593
    """
    def __init__(self, in_channels, out_channels, kernel_size=3,
                 stride=1, padding=1, dilation=1, groups=1, padding_mode='zeros', deploy=False, use_se=False):
        super(QARepVGGBlockV2, self).__init__(in_channels, out_channels, kernel_size, stride, padding, dilation, groups,
                                              padding_mode, deploy, use_se)
        if not deploy:
            self.bn = nn.BatchNorm2d(out_channels)
            self.rbr_1x1 = nn.Conv2d(in_channels, out_channels, kernel_size=1, stride=stride, groups=groups, bias=False)
@@ -365,7 +439,6 @@ def forward(self, inputs):

        return self.nonlinearity(self.bn(self.se(self.rbr_dense(inputs) + self.rbr_1x1(inputs) + id_out + avg_out)))


    def get_equivalent_kernel_bias(self):
        kernel3x3, bias3x3 = self._fuse_bn_tensor(self.rbr_dense)
        kernel = kernel3x3 + self._pad_1x1_to_3x3_tensor(self.rbr_1x1.weight)
@@ -627,6 +700,8 @@ def get_block(mode):
        return RepVGGBlock
    elif mode == 'qarepvgg':
        return QARepVGGBlock
    elif mode == 'qarepvggv2':
        return QARepVGGBlockV2
    elif mode == 'hyper_search':
        return LinearAddBlock
    elif mode == 'repopt':
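For reference, the re-parameterization path shown above (get_equivalent_kernel_bias plus switch_to_deploy) can be sanity-checked by comparing a block's output before and after fusion. A minimal sketch, assuming the YOLOv6 repo root is on PYTHONPATH and using QARepVGGBlock, whose full body appears in this diff; the block must be in eval() mode so the BatchNorm statistics used before and after fusion match:

# Minimal sketch: check that the fused 3x3 conv reproduces the multi-branch output.
# Assumes the YOLOv6 repo root is on PYTHONPATH.
import torch
from yolov6.layers.common import QARepVGGBlock

block = QARepVGGBlock(in_channels=32, out_channels=32, kernel_size=3, stride=1)
block.eval()                          # use running BN stats so the fusion is exact

x = torch.randn(1, 32, 56, 56)
with torch.no_grad():
    y_train = block(x)                # rbr_dense + rbr_1x1 + identity, then post-BN + ReLU
    block.switch_to_deploy()          # folds the branches into rbr_reparam; post-BN is kept
    y_deploy = block(x)

print(torch.allclose(y_train, y_deploy, atol=1e-4))  # expected: True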
