SA-NET
xmu-xiaoma666 committed Jun 10, 2021
1 parent 2b4343c commit b9485e7
Showing 5 changed files with 126 additions and 6 deletions.
37 changes: 37 additions & 0 deletions README.md
@@ -23,6 +23,10 @@

- [11. Efficient Multi-Head Self-Attention(EMSA) Usage](#11-Efficient-Multi-Head-Self-Attention-Usage)

- [12. Shuffle Attention Usage](#12-Shuffle-Attention-Usage)



- [MLP Series](#mlp-series)

- [1. RepMLP Usage](#1-RepMLP-Usage)
@@ -68,6 +72,8 @@

- Pytorch implementation of ["ResT: An Efficient Transformer for Visual Recognition---arXiv 2020.05.28"](https://arxiv.org/abs/2105.13677)

- Pytorch implementation of ["SA-NET: SHUFFLE ATTENTION FOR DEEP CONVOLUTIONAL NEURAL NETWORKS---ICASSP 2021"](https://arxiv.org/pdf/2102.00240.pdf)



***
@@ -308,6 +314,37 @@ print(output.shape)

```

***


### 12. Shuffle Attention Usage

#### 12.1. Paper
["SA-NET: SHUFFLE ATTENTION FOR DEEP CONVOLUTIONAL NEURAL NETWORKS"](https://arxiv.org/pdf/2102.00240.pdf)

#### 12.2. Overview
![](./img/ShuffleAttention.jpg)
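
In brief: the input is split into `G` groups along the channel axis; each group is halved into a channel-attention branch (global average pooling followed by a learned scale and bias) and a spatial-attention branch (GroupNorm followed by a learned scale and bias), and the two halves are concatenated and channel-shuffled so information flows across branches.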

#### 12.3. Code
```python
from attention.ShuffleAttention import ShuffleAttention
import torch

input=torch.randn(50,512,7,7)
se = ShuffleAttention(channel=512,G=8)
output=se(input)
print(output.shape) # torch.Size([50, 512, 7, 7])
```
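
For intuition about the final shuffle step: `channel_shuffle` interleaves channels across groups with the same reshape-permute-reshape trick used in ShuffleNet. A minimal sketch on a toy tensor (variable names here are illustrative, not part of the repo):

```python
import torch

# Toy input: 1 sample, 4 channels holding the values 0..3, 1x1 spatial map.
x = torch.arange(4.).reshape(1, 4, 1, 1)

# Same reshape -> permute -> reshape trick as ShuffleAttention.channel_shuffle.
b, c, h, w = x.shape
groups = 2
x = x.reshape(b, groups, -1, h, w).permute(0, 2, 1, 3, 4).reshape(b, -1, h, w)

print(x.flatten().tolist())  # [0.0, 2.0, 1.0, 3.0] -- groups are interleaved
```

With `groups=2`, as in the module's forward pass, channels from the channel-attention half and the spatial-attention half of each group end up adjacent, so the two branches exchange information.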




***

# MLP Series
82 changes: 82 additions & 0 deletions attention/ShuffleAttention.py
@@ -0,0 +1,82 @@
import numpy as np
import torch
from torch import nn
from torch.nn import init
from torch.nn.parameter import Parameter


class ShuffleAttention(nn.Module):

    def __init__(self, channel=512,reduction=16,G=8): # note: reduction is unused in this module
super().__init__()
self.G=G
self.channel=channel
self.avg_pool = nn.AdaptiveAvgPool2d(1)
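        # GroupNorm with num_groups == num_channels normalizes each of the c//(2G) channels independently (instance-norm style); it feeds the spatial-attention branch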
self.gn = nn.GroupNorm(channel // (2 * G), channel // (2 * G))
self.cweight = Parameter(torch.zeros(1, channel // (2 * G), 1, 1))
self.cbias = Parameter(torch.ones(1, channel // (2 * G), 1, 1))
self.sweight = Parameter(torch.zeros(1, channel // (2 * G), 1, 1))
self.sbias = Parameter(torch.ones(1, channel // (2 * G), 1, 1))
self.sigmoid=nn.Sigmoid()


def init_weights(self):
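        # note: not called in __init__, and this module itself defines no Conv2d/BatchNorm2d/Linear layers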
for m in self.modules():
if isinstance(m, nn.Conv2d):
init.kaiming_normal_(m.weight, mode='fan_out')
if m.bias is not None:
init.constant_(m.bias, 0)
elif isinstance(m, nn.BatchNorm2d):
init.constant_(m.weight, 1)
init.constant_(m.bias, 0)
elif isinstance(m, nn.Linear):
init.normal_(m.weight, std=0.001)
if m.bias is not None:
init.constant_(m.bias, 0)


@staticmethod
def channel_shuffle(x, groups):
b, c, h, w = x.shape
x = x.reshape(b, groups, -1, h, w)
x = x.permute(0, 2, 1, 3, 4)

# flatten
x = x.reshape(b, -1, h, w)

return x

def forward(self, x):
b, c, h, w = x.size()
#group into subfeatures
x=x.view(b*self.G,-1,h,w) #bs*G,c//G,h,w

#channel_split
x_0,x_1=x.chunk(2,dim=1) #bs*G,c//(2*G),h,w

#channel attention
x_channel=self.avg_pool(x_0) #bs*G,c//(2*G),1,1
        x_channel=self.cweight*x_channel+self.cbias #bs*G,c//(2*G),1,1
x_channel=x_0*self.sigmoid(x_channel)

#spatial attention
x_spatial=self.gn(x_1) #bs*G,c//(2*G),h,w
x_spatial=self.sweight*x_spatial+self.sbias #bs*G,c//(2*G),h,w
x_spatial=x_1*self.sigmoid(x_spatial) #bs*G,c//(2*G),h,w

# concatenate along channel axis
out=torch.cat([x_channel,x_spatial],dim=1) #bs*G,c//G,h,w
out=out.contiguous().view(b,-1,h,w)

        # channel shuffle with 2 groups interleaves the channel- and spatial-attention halves
        out = self.channel_shuffle(out, 2)
return out


if __name__ == '__main__':
input=torch.randn(50,512,7,7)
se = ShuffleAttention(channel=512,G=8)
output=se(input)
print(output.shape)
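
Note an implicit constraint from the `chunk` and `GroupNorm` setup: `channel` must be divisible by `2 * G` (here 512 // (2 * 8) = 32 channels per branch). A quick shape-preservation check, as a sketch against the module above:

```python
import torch
from attention.ShuffleAttention import ShuffleAttention

# Any G with channel % (2 * G) == 0 should work; output shape matches input shape.
for g in (4, 8, 16):
    sa = ShuffleAttention(channel=512, G=g)
    out = sa(torch.randn(2, 512, 7, 7))
    assert out.shape == (2, 512, 7, 7)
print("shape preserved for G in {4, 8, 16}")
```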


Binary file not shown.
Binary file added img/ShuffleAttention.jpg
13 changes: 7 additions & 6 deletions main.py
@@ -1,12 +1,13 @@
from attention.EMSA import EMSA
from attention.ShuffleAttention import ShuffleAttention
import torch
from torch import nn
from torch.nn import functional as F

if __name__ == '__main__':
input=torch.randn(50,64,512)
emsa = EMSA(d_model=512, d_k=512, d_v=512, h=8,H=8,W=8,ratio=2,apply_transform=True)
output=emsa(input,input,input)
    print(output.shape) # torch.Size([50, 64, 512])

input=torch.randn(50,512,7,7)
se = ShuffleAttention(channel=512,G=8)
output=se(input)
    print(output.shape) # torch.Size([50, 512, 7, 7])


