
Commit

DynamicConv
xmu-xiaoma666 committed Oct 12, 2021
1 parent fc5d41c commit fda1217
Showing 4 changed files with 125 additions and 5 deletions.
30 changes: 30 additions & 0 deletions README.md
@@ -150,6 +150,9 @@ Hello everyone, I'm Xiaoma 🚀🚀🚀
- [2. MBConv Usage](#2-MBConv-Usage)

- [3. Involution Usage](#3-Involution-Usage)

- [4. DynamicConv Usage](#4-DynamicConv-Usage)

***


@@ -1458,6 +1461,9 @@ print("difference:",((out2-out1)**2).sum().item())

- Pytorch implementation of ["Involution: Inverting the Inherence of Convolution for Visual Recognition---CVPR2021"](https://arxiv.org/abs/2103.06255)

- Pytorch implementation of ["Dynamic Convolution: Attention over Convolution Kernels---CVPR2020 Oral"](https://arxiv.org/abs/1912.03458)


***

### 1. Depthwise Separable Convolution Usage
@@ -1530,3 +1536,27 @@ print(out.shape)

***


### 4. DynamicConv Usage
#### 4.1. Paper
["Dynamic Convolution: Attention over Convolution Kernels"](https://arxiv.org/abs/1912.03458)

#### 4.2. Overview
![](./model/img/DynamicConv.png)

#### 4.3. Usage Code
```python
from model.conv.DynamicConv import *
import torch
from torch import nn
from torch.nn import functional as F

if __name__ == '__main__':
    input=torch.randn(2,32,64,64)
    m=DynamicConv(in_planes=32,out_planes=64,kernel_size=3,stride=1,padding=1,bias=False)
    out=m(input)
    print(out.shape) # torch.Size([2, 64, 64, 64])

```
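The attention softmax uses a temperature (initialised to 30 in the module) that can be annealed towards 1 during training via `m.attention.update_temprature()`, which makes the per-sample kernel attention progressively sharper. A minimal sketch of how one might call it once per epoch; the training step itself is only a placeholder, not part of this repository:

```python
from model.conv.DynamicConv import *
import torch

m = DynamicConv(in_planes=32, out_planes=64, kernel_size=3, stride=1, padding=1, bias=False)

for epoch in range(30):
    # ... run the usual forward/backward passes over the training set here ...
    m.attention.update_temprature()  # lowers the softmax temperature by 1 (floor of 1)
```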

***
10 changes: 5 additions & 5 deletions main.py
@@ -1,10 +1,10 @@
-from model.backbone.ConvMixer import *
+from model.conv.DynamicConv import *
 import torch
 from torch import nn
 from torch.nn import functional as F
 
 if __name__ == '__main__':
-    x=torch.randn(1,3,224,224)
-    convmixer=ConvMixer(dim=512,depth=12)
-    out=convmixer(x)
-    print(out.shape) #[1, 1000]
+    input=torch.randn(2,32,64,64)
+    m=DynamicConv(in_planes=32,out_planes=64,kernel_size=3,stride=1,padding=1,bias=False)
+    out=m(input)
+    print(out.shape) # torch.Size([2, 64, 64, 64])
90 changes: 90 additions & 0 deletions model/conv/DynamicConv.py
@@ -0,0 +1,90 @@
import torch
from torch import nn
from torch.nn import functional as F

class Attention(nn.Module):
    def __init__(self,in_planes,ratio,K,temprature=30,init_weight=True):
        super().__init__()
        self.avgpool=nn.AdaptiveAvgPool2d(1)
        self.temprature=temprature
        assert in_planes>ratio
        hidden_planes=in_planes//ratio
        self.net=nn.Sequential(
            nn.Conv2d(in_planes,hidden_planes,kernel_size=1,bias=False),
            nn.ReLU(),
            nn.Conv2d(hidden_planes,K,kernel_size=1,bias=False)
        )

        if(init_weight):
            self._initialize_weights()

    def update_temprature(self):
        # anneal the softmax temperature by 1 per call, down to a floor of 1
        if(self.temprature>1):
            self.temprature-=1

    def _initialize_weights(self):
        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
                if m.bias is not None:
                    nn.init.constant_(m.bias, 0)
            if isinstance(m, nn.BatchNorm2d):
                nn.init.constant_(m.weight, 1)
                nn.init.constant_(m.bias, 0)

    def forward(self,x):
        att=self.avgpool(x) #bs,dim,1,1
        att=self.net(att).view(x.shape[0],-1) #bs,K
        return F.softmax(att/self.temprature,-1)

class DynamicConv(nn.Module):
    def __init__(self,in_planes,out_planes,kernel_size,stride,padding=0,dilation=1,grounps=1,bias=True,K=4,temprature=30,ratio=4,init_weight=True):
        super().__init__()
        self.in_planes=in_planes
        self.out_planes=out_planes
        self.kernel_size=kernel_size
        self.stride=stride
        self.padding=padding
        self.dilation=dilation
        self.groups=grounps
        self.bias=bias
        self.K=K
        self.init_weight=init_weight
        self.attention=Attention(in_planes=in_planes,ratio=ratio,K=K,temprature=temprature,init_weight=init_weight)

        # K candidate kernels, mixed per sample by the attention weights
        self.weight=nn.Parameter(torch.randn(K,out_planes,in_planes//grounps,kernel_size,kernel_size),requires_grad=True)
        if(bias):
            self.bias=nn.Parameter(torch.randn(K,out_planes),requires_grad=True)
        else:
            self.bias=None

        if(self.init_weight):
            self._initialize_weights()

    #TODO initialization
    def _initialize_weights(self):
        for i in range(self.K):
            nn.init.kaiming_uniform_(self.weight[i])

    def forward(self,x):
        bs,in_planels,h,w=x.shape
        softmax_att=self.attention(x) #bs,K
        # fold the batch into the channel axis so each sample is convolved with
        # its own aggregated kernel via a single grouped convolution
        x=x.view(1,-1,h,w)
        weight=self.weight.view(self.K,-1) #K,-1
        aggregate_weight=torch.mm(softmax_att,weight).view(bs*self.out_planes,self.in_planes//self.groups,self.kernel_size,self.kernel_size) #bs*out_p,in_p//groups,k,k

        if(self.bias is not None):
            bias=self.bias.view(self.K,-1) #K,out_p
            aggregate_bias=torch.mm(softmax_att,bias).view(-1) #bs*out_p
            output=F.conv2d(x,weight=aggregate_weight,bias=aggregate_bias,stride=self.stride,padding=self.padding,groups=self.groups*bs,dilation=self.dilation)
        else:
            output=F.conv2d(x,weight=aggregate_weight,bias=None,stride=self.stride,padding=self.padding,groups=self.groups*bs,dilation=self.dilation)

        # use the convolution's actual spatial size (h,w only match when stride=1 and padding preserves the size)
        output=output.view(bs,self.out_planes,output.shape[-2],output.shape[-1])
        return output

if __name__ == '__main__':
    input=torch.randn(2,32,64,64)
    m=DynamicConv(in_planes=32,out_planes=64,kernel_size=3,stride=1,padding=1,bias=False)
    out=m(input)
    print(out.shape) # torch.Size([2, 64, 64, 64])
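The forward pass relies on a grouped-convolution trick: the batch is folded into the channel axis so that every sample is convolved with its own attention-weighted mixture of the K kernels in one `F.conv2d` call. A minimal sketch of an equivalence check, written here purely as an illustration and not part of the commit, loops over the batch instead and should match the batched path up to floating-point error (assuming the default groups=1 and bias=False):

```python
import torch
from torch.nn import functional as F
from model.conv.DynamicConv import DynamicConv

m = DynamicConv(in_planes=32, out_planes=64, kernel_size=3, stride=1, padding=1, bias=False, K=4)
x = torch.randn(2, 32, 64, 64)

with torch.no_grad():
    fast = m(x)  # batched grouped-conv path

    att = m.attention(x)  # (bs, K) attention weights over the K kernels
    slow = []
    for i in range(x.shape[0]):
        # per-sample kernel: weighted sum of the K candidate kernels -> (out, in, k, k)
        kernel = (att[i].view(-1, 1, 1, 1, 1) * m.weight).sum(dim=0)
        slow.append(F.conv2d(x[i:i+1], kernel, stride=m.stride, padding=m.padding))
    slow = torch.cat(slow, dim=0)

print("difference:", ((fast - slow) ** 2).sum().item())  # should be ~0
```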
Binary file added model/img/DynamicConv.png
