DANet
xmu-xiaoma666 committed May 23, 2021
1 parent 7c083fc commit 80c92b3
Showing 5 changed files with 87 additions and 11 deletions.
25 changes: 24 additions & 1 deletion README.md
@@ -17,6 +17,8 @@

- [8. ECA Attention Usage](#8-eca-attention-usage)

- [9. DANet Attention Usage](#9-danet-attention-usage)

- [MLP Series](#mlp-series)

- [1. RepMLP Usage](#1-RepMLP-Usage)
@@ -49,6 +51,8 @@

- Pytorch implementation of ["ECA-Net: Efficient Channel Attention for Deep Convolutional Neural Networks---CVPR2020"](https://arxiv.org/pdf/1910.03151.pdf)

- Pytorch implementation of ["Dual Attention Network for Scene Segmentation---CVPR2019"](https://arxiv.org/pdf/1809.02983.pdf)

***


@@ -198,6 +202,7 @@ print(output.shape)

```

***

### 8. ECA Attention Usage
#### 8.1. Paper
@@ -218,7 +223,26 @@ print(output.shape)

```

***

### 9. DANet Attention Usage
#### 9.1. Paper
["Dual Attention Network for Scene Segmentation---CVPR2019"](https://arxiv.org/pdf/1809.02983.pdf)

#### 9.2. Overview
![](./img/danet.png)

#### 9.3. Code
```python
from attention.DANet import DAModule
import torch

if __name__ == '__main__':
    input=torch.randn(50,512,7,7)
    danet=DAModule(d_model=512,kernel_size=3,H=7,W=7)
    print(danet(input).shape)

```
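With the 50×512×7×7 input above, this should print `torch.Size([50, 512, 7, 7])`: both attention branches are reshaped back to the input layout and summed element-wise, so DAModule preserves the input shape.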

***

Expand All @@ -230,7 +254,6 @@ print(output.shape)

- Pytorch implementation of ["ResMLP: Feedforward networks for image classification with data-efficient training"](https://arxiv.org/pdf/2105.03404.pdf)


- Pytorch implementation of ["Pay Attention to MLPs"](https://arxiv.org/abs/2105.08050)

### 1. RepMLP Usage
59 changes: 59 additions & 0 deletions attention/DANet.py
@@ -0,0 +1,59 @@
import numpy as np
import torch
from torch import nn
from torch.nn import init
from .SelfAttention import ScaledDotProductAttention
from .SimplifiedSelfAttention import SimplifiedScaledDotProductAttention

class PositionAttentionModule(nn.Module):
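    """Position attention branch: a conv layer followed by self-attention over the
    H*W spatial positions, where each position's d_model-dim feature acts as a token."""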

    def __init__(self,d_model=512,kernel_size=3,H=7,W=7):
        super().__init__()
        self.cnn=nn.Conv2d(d_model,d_model,kernel_size=kernel_size,padding=(kernel_size-1)//2)
        self.pa=ScaledDotProductAttention(d_model,d_k=d_model,d_v=d_model,h=1)

    def forward(self,x):
        bs,c,h,w=x.shape
        y=self.cnn(x)
        y=y.view(bs,c,-1).permute(0,2,1) #bs,h*w,c
        y=self.pa(y,y,y) #bs,h*w,c
        return y


class ChannelAttentionModule(nn.Module):
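    """Channel attention branch: a conv layer followed by self-attention across the
    C channels, where each channel's flattened H*W feature map acts as a token."""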

    def __init__(self,d_model=512,kernel_size=3,H=7,W=7):
        super().__init__()
        self.cnn=nn.Conv2d(d_model,d_model,kernel_size=kernel_size,padding=(kernel_size-1)//2)
        self.pa=SimplifiedScaledDotProductAttention(H*W,h=1)

    def forward(self,x):
        bs,c,h,w=x.shape
        y=self.cnn(x)
        y=y.view(bs,c,-1) #bs,c,h*w
        y=self.pa(y,y,y) #bs,c,h*w
        return y




class DAModule(nn.Module):
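    """Dual attention module: runs the position and channel attention branches in
    parallel on the same input and fuses their outputs by element-wise sum."""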

    def __init__(self,d_model=512,kernel_size=3,H=7,W=7):
        super().__init__()
        # pass the constructor arguments through instead of hard-coding the defaults
        self.position_attention_module=PositionAttentionModule(d_model=d_model,kernel_size=kernel_size,H=H,W=W)
        self.channel_attention_module=ChannelAttentionModule(d_model=d_model,kernel_size=kernel_size,H=H,W=W)

    def forward(self,input):
        bs,c,h,w=input.shape
        p_out=self.position_attention_module(input)
        c_out=self.channel_attention_module(input)
        # reshape both branches back to (bs,c,h,w) before fusing them by element-wise sum
        p_out=p_out.permute(0,2,1).view(bs,c,h,w)
        c_out=c_out.view(bs,c,h,w)
        return p_out+c_out


if __name__ == '__main__':
    input=torch.randn(50,512,7,7)
    danet=DAModule(d_model=512,kernel_size=3,H=7,W=7)
    print(danet(input).shape)
Binary file added attention/__pycache__/DANet.cpython-38.pyc
Binary file not shown.
Binary file added img/danet.png
14 changes: 4 additions & 10 deletions main.py
@@ -1,13 +1,7 @@
from mlp.g_mlp import gMLP
from attention.DANet import DAModule
import torch

if __name__ == '__main__':

    num_tokens=10000
    bs=50
    len_sen=49
    num_layers=6
    input=torch.randint(num_tokens,(bs,len_sen)) #bs,len_sen
    gmlp = gMLP(num_tokens=num_tokens,len_sen=len_sen,dim=512,d_ff=1024)
    output=gmlp(input)
    print(output.shape)
    input=torch.randn(50,512,7,7)
    danet=DAModule(d_model=512,kernel_size=3,H=7,W=7)
    print(danet(input).shape)
