
Commit

gMLP
xmu-xiaoma666 committed May 19, 2021
1 parent f0b0312 commit 08c82f7
Showing 5 changed files with 125 additions and 5 deletions.
28 changes: 28 additions & 0 deletions README.md
@@ -23,6 +23,8 @@

- [3. ResMLP Usage](#3-ResMLP-Usage)

- [4. gMLP Usage](#4-gMLP-Usage)




@@ -226,6 +228,9 @@ print(output.shape)

- Pytorch implementation of ["ResMLP: Feedforward networks for image classification with data-efficient training"](https://arxiv.org/pdf/2105.03404.pdf)


- Pytorch implementation of ["Pay Attention to MLPs"](https://arxiv.org/abs/2105.08050)

### 1. RepMLP Usage
#### 1.1. Paper
["RepMLP: Re-parameterizing Convolutions into Fully-connected Layers for Image Recognition"](https://arxiv.org/pdf/2105.01883v1.pdf)
@@ -303,4 +308,27 @@ input=torch.randn(50,3,14,14)
resmlp=ResMLP(dim=128,image_size=14,patch_size=7,class_num=1000)
out=resmlp(input)
print(out.shape) #the last dimension is class_num
```


### 4. gMLP Usage
#### 4.1. Paper
["Pay Attention to MLPs"](https://arxiv.org/abs/2105.08050)

#### 4.2. Overview
![](./img/gMLP.jpg)
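
In this implementation each gMLP block applies LayerNorm, a channel projection from `dim` to `2*d_ff`, a GELU, a Spatial Gating Unit (which splits the features in half, normalizes the gate half, mixes it across the `len_sen` token positions with a learned projection, and multiplies it back onto the other half), and a projection back to `dim`, all inside a residual connection (see `mlp/g_mlp.py`).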

#### 4.3. Code
```python
from mlp.g_mlp import gMLP
import torch

num_tokens=10000
bs=50
len_sen=49
num_layers=6
input=torch.randint(num_tokens,(bs,len_sen)) #bs,len_sen
gmlp = gMLP(num_tokens=num_tokens,len_sen=len_sen,dim=512,d_ff=1024)
output=gmlp(input)
print(output.shape)
```
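
With these settings the printed shape is `torch.Size([50, 49, 10000])`: one probability distribution over the `num_tokens` vocabulary per token position, because `to_logits` ends in a softmax. Since the model returns probabilities rather than raw logits, a loss such as `nn.NLLLoss` would be applied to their log. The sketch below is illustrative only and not part of this repository; the batch size, random targets, and variable names are assumptions.

```python
from mlp.g_mlp import gMLP
import torch
from torch import nn

model = gMLP(num_tokens=10000, len_sen=49, dim=512, d_ff=1024)
tokens = torch.randint(10000, (8, 49))    # (batch, sequence) of token ids
targets = torch.randint(10000, (8, 49))   # illustrative per-token targets

probs = model(tokens)                     # (8, 49, 10000), rows sum to 1
# NLLLoss expects log-probabilities with the class dimension second
loss = nn.NLLLoss()(torch.log(probs + 1e-9).transpose(1, 2), targets)
loss.backward()
print(loss.item())
```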
Binary file added img/gMLP.jpg
16 changes: 11 additions & 5 deletions main.py
@@ -1,7 +1,13 @@
from attention.ECAAttention import ECAAttention
from mlp.g_mlp import gMLP
import torch

input=torch.randn(50,512,7,7)
eca = ECAAttention(kernel_size=3)
output=eca(input)
print(output.shape)
if __name__ == '__main__':

    num_tokens=10000
    bs=50
    len_sen=49
    num_layers=6
    input=torch.randint(num_tokens,(bs,len_sen)) #bs,len_sen
    gmlp = gMLP(num_tokens=num_tokens,len_sen=len_sen,dim=512,d_ff=1024)
    output=gmlp(input)
    print(output.shape)
Binary file added mlp/__pycache__/g_mlp.cpython-38.pyc
86 changes: 86 additions & 0 deletions mlp/g_mlp.py
@@ -0,0 +1,86 @@
from collections import OrderedDict
import torch
from torch import nn


def exist(x):
    return x is not None

class Residual(nn.Module):
    def __init__(self,fn):
        super().__init__()
        self.fn=fn

    def forward(self,x):
        return self.fn(x)+x

class SpatialGatingUnit(nn.Module):
    def __init__(self,dim,len_sen):
        super().__init__()
        self.ln=nn.LayerNorm(dim)
        self.proj=nn.Conv1d(len_sen,len_sen,1)

        # init the spatial projection to weight=0, bias=1 so the gate starts as all ones
        # and the unit initially passes the residual half through unchanged
        nn.init.zeros_(self.proj.weight)
        nn.init.ones_(self.proj.bias)

    def forward(self,x):
        res,gate=torch.chunk(x,2,-1) #split bs,n,2*d_ff into two halves of size d_ff
        ###Norm
        gate=self.ln(gate) #bs,n,d_ff
        ###Spatial Proj: mix the gate across the n token positions
        gate=self.proj(gate) #bs,n,d_ff

        return res*gate

class gMLP(nn.Module):
    def __init__(self,num_tokens=None,len_sen=49,dim=512,d_ff=1024,num_layers=6):
        super().__init__()
        self.num_layers=num_layers
        self.embedding=nn.Embedding(num_tokens,dim) if exist(num_tokens) else nn.Identity()

        # each block: LayerNorm -> Linear(dim, 2*d_ff) -> GELU -> SpatialGatingUnit -> Linear(d_ff, dim),
        # wrapped in a residual connection
        self.gmlp=nn.ModuleList([Residual(nn.Sequential(OrderedDict([
            ('ln1_%d'%i,nn.LayerNorm(dim)),
            ('fc1_%d'%i,nn.Linear(dim,d_ff*2)),
            ('gelu_%d'%i,nn.GELU()),
            ('sgu_%d'%i,SpatialGatingUnit(d_ff,len_sen)),
            ('fc2_%d'%i,nn.Linear(d_ff,dim)),
        ]))) for i in range(num_layers)])



        # only build the vocabulary head when num_tokens is given; otherwise return features unchanged
        self.to_logits=nn.Sequential(
            nn.LayerNorm(dim),
            nn.Linear(dim,num_tokens),
            nn.Softmax(-1)
        ) if exist(num_tokens) else nn.Identity()


    def forward(self,x):
        #embedding
        embedded=self.embedding(x)

        #gMLP blocks
        y=nn.Sequential(*self.gmlp)(embedded)

        #to logits
        logits=self.to_logits(y)

        return logits





if __name__ == '__main__':

    num_tokens=10000
    bs=50
    len_sen=49
    num_layers=6
    input=torch.randint(num_tokens,(bs,len_sen)) #bs,len_sen
    gmlp = gMLP(num_tokens=num_tokens,len_sen=len_sen,dim=512,d_ff=1024)
    output=gmlp(input)
    print(output.shape)
