
Commit

Re-Parameter
xmu-xiaoma666 committed Jun 6, 2021
1 parent f9eae2f commit 18c5899
Showing 5 changed files with 191 additions and 8 deletions.
45 changes: 43 additions & 2 deletions README.md
@@ -33,6 +33,13 @@
- [4. gMLP Usage](#4-gMLP-Usage)


- [Re-Parameter(ReP) Series](#Re-Parameter(ReP)-series)

- [1. RepVGG Usage](#1-RepVGG-Usage)






***
@@ -345,7 +352,7 @@ output=mlp_mixer(input)
print(output.shape)
```


***

### 3. ResMLP Usage
#### 3.1. Paper
@@ -365,6 +372,7 @@ out=resmlp(input)
print(out.shape) #the last dimention is class_num
```

***

### 4. gMLP Usage
#### 4.1. Paper
@@ -386,4 +394,37 @@ input=torch.randint(num_tokens,(bs,len_sen)) #bs,len_sen
gmlp = gMLP(num_tokens=num_tokens,len_sen=len_sen,dim=512,d_ff=1024)
output=gmlp(input)
print(output.shape)
```



# Re-Parameter(ReP) Series

- PyTorch implementation of ["RepVGG: Making VGG-style ConvNets Great Again" (CVPR 2021)](https://arxiv.org/abs/2101.03697)


***

### 1. RepVGG Usage
#### 1.1. Paper
["RepVGG: Making VGG-style ConvNets Great Again"](https://arxiv.org/abs/2101.03697)

#### 1.2. Overview
![](./img/repvgg.png)

#### 1.3. Code
```python

from rep.repvgg import RepBlock
import torch


input=torch.randn(50,512,49,49)
repblock=RepBlock(512,512)
repblock.eval()
out=repblock(input)
repblock._switch_to_deploy()
out2=repblock(input)
print('difference between vgg and repvgg')
print(((out2-out)**2).sum())
```
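
The printed difference should be on the order of floating-point rounding error: in `eval()` mode the BatchNorm layers use their fixed running statistics, so the single 3x3 convolution produced by `_switch_to_deploy()` computes exactly the same function as the three-branch training-time block.
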
Binary file added img/repvgg.png
17 changes: 11 additions & 6 deletions main.py
@@ -1,9 +1,14 @@
from attention.PSA import PSA
from rep.repvgg import RepBlock
import torch

if __name__ == '__main__':
input=torch.randn(50,512,7,7)
psa = PSA(channel=512,reduction=8)
output=psa(input)
print(output.shape)

if __name__ == '__main__':
input=torch.randn(50,512,49,49)
repblock=RepBlock(512,512)
repblock.eval()
out=repblock(input)
repblock._switch_to_deploy()
out2=repblock(input)
print('difference between vgg and repvgg')
print(((out2-out)**2).sum())

Binary file added rep/__pycache__/repvgg.cpython-38.pyc
137 changes: 137 additions & 0 deletions rep/repvgg.py
@@ -0,0 +1,137 @@
import torch
from torch import nn
from collections import OrderedDict
from torch.nn import functional as F
import numpy as np
from numpy import random

def setup_seed(seed):
torch.manual_seed(seed)
torch.cuda.manual_seed_all(seed)
np.random.seed(seed)
random.seed(seed)
torch.backends.cudnn.deterministic = True

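# a conv layer (bias=False) followed by BatchNorm; used to build the 3x3 and 1x1 branches of RepBlock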
def _conv_bn(input_channel,output_channel,kernel_size=3,padding=1,stride=1,groups=1):
res=nn.Sequential()
res.add_module('conv',nn.Conv2d(in_channels=input_channel,out_channels=output_channel,kernel_size=kernel_size,padding=padding,padding_mode='zeros',stride=stride,groups=groups,bias=False))
res.add_module('bn',nn.BatchNorm2d(output_channel))
return res

class RepBlock(nn.Module):
def __init__(self,input_channel,output_channel,kernel_size=3,groups=1,stride=1,deploy=False,use_se=False):
super().__init__()
self.use_se=use_se
self.input_channel=input_channel
self.output_channel=output_channel
self.deploy=deploy
self.kernel_size=kernel_size
self.padding=kernel_size//2
self.groups=groups
self.activation=nn.ReLU()



#make sure kernel_size=3 padding=1
assert self.kernel_size==3
assert self.padding==1
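# training mode keeps three parallel branches (3x3 conv+BN, 1x1 conv+BN, and an identity BN when in/out channels match); deploy mode uses a single fused 3x3 conv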
if(not self.deploy):
self.brb_3x3=_conv_bn(input_channel,output_channel,kernel_size=self.kernel_size,padding=self.padding,groups=groups)
self.brb_1x1=_conv_bn(input_channel,output_channel,kernel_size=1,padding=0,groups=groups)
self.brb_identity=nn.BatchNorm2d(self.input_channel) if self.input_channel == self.output_channel else None
else:
self.brb_rep=nn.Conv2d(in_channels=input_channel,out_channels=output_channel,kernel_size=self.kernel_size,padding=self.padding,padding_mode='zeros',stride=stride,groups=groups,bias=True)



def forward(self, inputs):
if(self.deploy):
return self.activation(self.brb_rep(inputs))

if(self.brb_identity==None):
identity_out=0
else:
identity_out=self.brb_identity(inputs)

return self.activation(self.brb_1x1(inputs)+self.brb_3x3(inputs)+identity_out)




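# fuse the three training-time branches into one 3x3 conv (weight + bias), remove them, and switch the block to deploy mode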
def _switch_to_deploy(self):
self.deploy=True
kernel,bias=self._get_equivalent_kernel_bias()
self.brb_rep=nn.Conv2d(in_channels=self.brb_3x3.conv.in_channels,out_channels=self.brb_3x3.conv.out_channels,
kernel_size=self.brb_3x3.conv.kernel_size,padding=self.brb_3x3.conv.padding,
padding_mode=self.brb_3x3.conv.padding_mode,stride=self.brb_3x3.conv.stride,
groups=self.brb_3x3.conv.groups,bias=True)
self.brb_rep.weight.data=kernel
self.brb_rep.bias.data=bias
# detach all parameters so they no longer receive gradient updates
for para in self.parameters():
para.detach_()
# delete the training-time branches, which are no longer needed
self.__delattr__('brb_3x3')
self.__delattr__('brb_1x1')
self.__delattr__('brb_identity')


# zero-pad a 1x1 conv kernel into an equivalent 3x3 kernel
def _pad_1x1_kernel(self,kernel):
if(kernel is None):
return 0
else:
return F.pad(kernel,[1]*4)


# fuse the identity, 1x1 and 3x3 branches into the weight and bias of a single 3x3 conv
def _get_equivalent_kernel_bias(self):
brb_3x3_weight,brb_3x3_bias=self._fuse_conv_bn(self.brb_3x3)
brb_1x1_weight,brb_1x1_bias=self._fuse_conv_bn(self.brb_1x1)
brb_id_weight,brb_id_bias=self._fuse_conv_bn(self.brb_identity)
return brb_3x3_weight+self._pad_1x1_kernel(brb_1x1_weight)+brb_id_weight,brb_3x3_bias+brb_1x1_bias+brb_id_bias


### fuse a conv layer and its BN into a single conv's weight and bias
def _fuse_conv_bn(self,branch):
if(branch is None):
return 0,0
elif(isinstance(branch,nn.Sequential)):
kernel=branch.conv.weight
running_mean=branch.bn.running_mean
running_var=branch.bn.running_var
gamma=branch.bn.weight
beta=branch.bn.bias
eps=branch.bn.eps
else:
assert isinstance(branch, nn.BatchNorm2d)
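# the identity branch is a bare BN; express it as a 3x3 conv whose kernel is 1 at the center of its own (per-group) input channel, then fold the BN into it below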
if not hasattr(self, 'id_tensor'):
input_dim = self.input_channel // self.groups
kernel_value = np.zeros((self.input_channel, input_dim, 3, 3), dtype=np.float32)
for i in range(self.input_channel):
kernel_value[i, i % input_dim, 1, 1] = 1
self.id_tensor = torch.from_numpy(kernel_value).to(branch.weight.device)
kernel = self.id_tensor
running_mean = branch.running_mean
running_var = branch.running_var
gamma = branch.weight
beta = branch.bias
eps = branch.eps

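# fold BN into the conv: gamma*(W*x - mean)/sqrt(var+eps) + beta == (W*gamma/std)*x + (beta - mean*gamma/std), with std = sqrt(var+eps)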
std=(running_var+eps).sqrt()
t=gamma/std
t=t.view(-1,1,1,1)
return kernel*t,beta-running_mean*gamma/std



if __name__ == '__main__':
input=torch.randn(50,512,49,49)
repblock=RepBlock(512,512)
repblock.eval()
out=repblock(input)
repblock._switch_to_deploy()
out2=repblock(input)
print('difference between vgg and repvgg')
print(((out2-out)**2).sum())
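
As a side note, the conv-BN folding that `_fuse_conv_bn` performs can be checked in isolation. The sketch below uses only plain PyTorch (the layer sizes and randomized running statistics are arbitrary choices for illustration, not part of this repo) and compares a conv+BN pair in eval mode against the folded convolution:

```python
import torch
from torch import nn

# training-time pair: a conv without bias, followed by BatchNorm
conv = nn.Conv2d(8, 16, kernel_size=3, padding=1, bias=False)
bn = nn.BatchNorm2d(16)
bn.eval()                           # deploy-time behaviour: BN uses running statistics
bn.running_mean.uniform_(-1, 1)     # arbitrary non-trivial statistics for the check
bn.running_var.uniform_(0.5, 2.0)

# fold BN into the conv: W' = W * gamma/std, b' = beta - mean * gamma/std
std = (bn.running_var + bn.eps).sqrt()
fused = nn.Conv2d(8, 16, kernel_size=3, padding=1, bias=True)
fused.weight.data = conv.weight.data * (bn.weight.data / std).reshape(-1, 1, 1, 1)
fused.bias.data = bn.bias.data - bn.running_mean * bn.weight.data / std

x = torch.randn(2, 8, 32, 32)
with torch.no_grad():
    print(((bn(conv(x)) - fused(x)) ** 2).sum())   # ~0 up to float error
```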
