ECA-Attention
xmu-xiaoma666 committed May 19, 2021
1 parent 361a74b commit f0b0312
Showing 5 changed files with 78 additions and 7 deletions.
24 changes: 23 additions & 1 deletion README.md
@@ -12,6 +12,8 @@
- [6. CBAM Attention Usage](#6-cbam-attention-usage)

- [7. BAM Attention Usage](#7-bam-attention-usage)

- [8. ECA Attention Usage](#8-eca-attention-usage)

- [MLP Series](#mlp-series)

@@ -41,6 +43,8 @@

- Pytorch implementation of ["BAM: Bottleneck Attention Module---BMCV2018"](https://arxiv.org/pdf/1807.06514.pdf)

- Pytorch implementation of ["ECA-Net: Efficient Channel Attention for Deep Convolutional Neural Networks---CVPR2020"](https://arxiv.org/pdf/1910.03151.pdf)

***


@@ -173,7 +177,7 @@ print(output.shape)

### 7. BAM Attention Usage
#### 7.1. Paper
["BAM: Bottleneck Attention Module---BMCV2018"](https://arxiv.org/pdf/1807.06514.pdf)
["BAM: Bottleneck Attention Module"](https://arxiv.org/pdf/1807.06514.pdf)

#### 7.2. Overview
![](./img/BAM.png)
@@ -191,6 +195,24 @@ print(output.shape)
```


### 8. ECA Attention Usage
#### 8.1. Paper
["ECA-Net: Efficient Channel Attention for Deep Convolutional Neural Networks"](https://arxiv.org/pdf/1910.03151.pdf)

#### 8.2. Overview
![](./img/ECA.png)

#### 8.3. Code
```python
from attention.ECAAttention import ECAAttention
import torch

input=torch.randn(50,512,7,7)
eca = ECAAttention(kernel_size=3)
output=eca(input)
print(output.shape)

```
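
Note that `kernel_size` is passed in explicitly here. The ECA-Net paper also describes an adaptive rule that picks the kernel size from the channel count C, roughly k = |log2(C)/γ + b/γ|_odd with γ=2 and b=1. The helper below is only a small sketch of that rule for reference; it is not part of this repository's API.

```python
import math

def eca_kernel_size(channels: int, gamma: int = 2, b: int = 1) -> int:
    # Adaptive rule from the ECA-Net paper: nearest odd number to (log2(C) + b) / gamma.
    t = int(abs((math.log2(channels) + b) / gamma))
    return t if t % 2 else t + 1

print(eca_kernel_size(512))  # 5
```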



50 changes: 50 additions & 0 deletions attention/ECAAttention.py
@@ -0,0 +1,50 @@
import numpy as np
import torch
from torch import nn
from torch.nn import init
from collections import OrderedDict



class ECAAttention(nn.Module):

def __init__(self, kernel_size=3):
super().__init__()
self.gap=nn.AdaptiveAvgPool2d(1)
self.conv=nn.Conv1d(1,1,kernel_size=kernel_size,padding=(kernel_size-1)//2)
self.sigmoid=nn.Sigmoid()

def init_weights(self):
for m in self.modules():
if isinstance(m, nn.Conv2d):
init.kaiming_normal_(m.weight, mode='fan_out')
if m.bias is not None:
init.constant_(m.bias, 0)
elif isinstance(m, nn.BatchNorm2d):
init.constant_(m.weight, 1)
init.constant_(m.bias, 0)
elif isinstance(m, nn.Linear):
init.normal_(m.weight, std=0.001)
if m.bias is not None:
init.constant_(m.bias, 0)

def forward(self, x):
y=self.gap(x) #bs,c,1,1
y=y.squeeze(-1).permute(0,2,1) #bs,1,c
y=self.conv(y) #bs,1,c
y=self.sigmoid(y) #bs,1,c
y=y.permute(0,2,1).unsqueeze(-1) #bs,c,1,1
return x*y.expand_as(x)






if __name__ == '__main__':
input=torch.randn(50,512,7,7)
eca = ECAAttention(kernel_size=3)
output=eca(input)
print(output.shape)
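
For context, the sketch below shows one way the module could be attached after a standard conv-BN-ReLU stage; `ConvECABlock` is an illustrative name and is not defined anywhere in this repository.

```python
import torch
from torch import nn
from attention.ECAAttention import ECAAttention

class ConvECABlock(nn.Module):
    """Hypothetical conv block that reweights its output channels with ECA."""
    def __init__(self, in_channels, out_channels, kernel_size=3):
        super().__init__()
        self.conv = nn.Conv2d(in_channels, out_channels, 3, padding=1, bias=False)
        self.bn = nn.BatchNorm2d(out_channels)
        self.relu = nn.ReLU(inplace=True)
        self.eca = ECAAttention(kernel_size=kernel_size)

    def forward(self, x):
        x = self.relu(self.bn(self.conv(x)))
        return self.eca(x)  # channel reweighting; spatial shape is unchanged

if __name__ == '__main__':
    block = ConvECABlock(64, 128)
    print(block(torch.randn(2, 64, 56, 56)).shape)  # torch.Size([2, 128, 56, 56])
```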


Binary file added attention/__pycache__/ECAAttention.cpython-38.pyc
Binary file added img/ECA.png
11 changes: 5 additions & 6 deletions main.py
@@ -1,8 +1,7 @@
-from mlp.resmlp import ResMLP
+from attention.ECAAttention import ECAAttention
 import torch


-input=torch.randn(50,3,14,14)
-resmlp=ResMLP(dim=128,image_size=14,patch_size=7,class_num=1000)
-out=resmlp(input)
-print(out.shape)
+input=torch.randn(50,512,7,7)
+eca = ECAAttention(kernel_size=3)
+output=eca(input)
+print(output.shape)
