
Commit

mlpmixer
xmu-xiaoma666 committed May 18, 2021
1 parent a851e48 commit 00f4ed0
Showing 5 changed files with 107 additions and 34 deletions.
28 changes: 26 additions & 2 deletions README.md
@@ -14,8 +14,17 @@
- [7. BAM Attention Usage](#7-bam-attention-usage)

- [MLP Series](#mlp-series)

- [1. RepMLP Usage](#1-RepMLP-Usage)

- [2. MLP-Mixer Usage](#2-MLP-Mixer-Usage)




***


# Attention Series

- Pytorch implementation of ["Beyond Self-attention: External Attention using Two Linear Layers for Visual Tasks"](https://arxiv.org/abs/2105.02358)
@@ -225,11 +234,26 @@ for module in repmlp.modules():

#training result
out=repmlp(x)


#inference result
repmlp.switch_to_deploy()
deployout = repmlp(x)

print(((deployout-out)**2).sum())
```

### 2. MLP-Mixer Usage
#### 2.1. Paper
["MLP-Mixer: An all-MLP Architecture for Vision"](https://arxiv.org/pdf/2105.01601.pdf)

#### 2.2. Overview
![](./img/mlpmixer.png)

#### 2.3. Code
```python
from mlp.mlp_mixer import MlpMixer
import torch
mlp_mixer=MlpMixer(num_classes=1000,num_blocks=10,patch_size=10,tokens_hidden_dim=32,channels_hidden_dim=1024,tokens_mlp_dim=16,channels_mlp_dim=1024)
input=torch.randn(50,3,40,40)
output=mlp_mixer(input)
print(output.shape)
```
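A quick note on these hyper-parameters (a hedged aside, not part of the original README): the patch embedding turns a 40×40 input with `patch_size=10` into (40/10)×(40/10)=16 tokens, which is why `tokens_mlp_dim` is set to 16 here; the model raises a `ValueError` if the two disagree. A minimal sketch of that relationship:

```python
# Hypothetical helper (not in the repository) showing how the expected
# token count follows from the image size and patch size.
def expected_tokens(image_size: int, patch_size: int) -> int:
    return (image_size // patch_size) ** 2

print(expected_tokens(40, 10))  # 16, matching tokens_mlp_dim=16 above
```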
Binary file added img/mlpmixer.png
37 changes: 5 additions & 32 deletions main.py
@@ -1,33 +1,6 @@
from mlp.repmlp import RepMLP
from mlp.mlp_mixer import MlpMixer
import torch
from torch import nn

N=4 #batch size
C=512 #input dim
O=1024 #output dim
H=14 #image height
W=14 #image width
h=7 #patch height
w=7 #patch width
fc1_fc2_reduction=1 #reduction ratio
fc3_groups=8 # groups
repconv_kernels=[1,3,5,7] #kernel list
repmlp=RepMLP(C,O,H,W,h,w,fc1_fc2_reduction,fc3_groups,repconv_kernels=repconv_kernels)
x=torch.randn(N,C,H,W)
repmlp.eval()
for module in repmlp.modules():
    if isinstance(module, nn.BatchNorm2d) or isinstance(module, nn.BatchNorm1d):
        nn.init.uniform_(module.running_mean, 0, 0.1)
        nn.init.uniform_(module.running_var, 0, 0.1)
        nn.init.uniform_(module.weight, 0, 0.1)
        nn.init.uniform_(module.bias, 0, 0.1)

#training result
out=repmlp(x)


#inference result
repmlp.switch_to_deploy()
deployout = repmlp(x)

print(((deployout-out)**2).sum())
mlp_mixer=MlpMixer(num_classes=1000,num_blocks=10,patch_size=10,tokens_hidden_dim=32,channels_hidden_dim=1024,tokens_mlp_dim=16,channels_mlp_dim=1024)
input=torch.randn(50,3,40,40)
output=mlp_mixer(input)
print(output.shape)
Binary file added mlp/__pycache__/mlp_mixer.cpython-38.pyc
76 changes: 76 additions & 0 deletions mlp/mlp_mixer.py
@@ -0,0 +1,76 @@
import torch
from torch import nn

class MlpBlock(nn.Module):
    def __init__(self,input_dim,mlp_dim=512):
        super().__init__()
        self.fc1=nn.Linear(input_dim,mlp_dim)
        self.gelu=nn.GELU()
        self.fc2=nn.Linear(mlp_dim,input_dim)

    def forward(self,x):
        #x: (bs,tokens,channels) or (bs,channels,tokens)
        return self.fc2(self.gelu(self.fc1(x)))



class MixerBlock(nn.Module):
    def __init__(self,tokens_mlp_dim=16,channels_mlp_dim=1024,tokens_hidden_dim=32,channels_hidden_dim=1024):
        super().__init__()
        # note: a single LayerNorm is shared by the token-mixing and channel-mixing
        # branches; the paper uses a separate LayerNorm for each
        self.ln=nn.LayerNorm(channels_mlp_dim)
        self.tokens_mlp_block=MlpBlock(tokens_mlp_dim,mlp_dim=tokens_hidden_dim)
        self.channels_mlp_block=MlpBlock(channels_mlp_dim,mlp_dim=channels_hidden_dim)

    def forward(self,x):
        """
        x: (bs,tokens,channels)
        """
        ### token mixing
        y=self.ln(x)
        y=y.transpose(1,2) #(bs,channels,tokens)
        y=self.tokens_mlp_block(y) #(bs,channels,tokens)
        y=y.transpose(1,2) #(bs,tokens,channels)
        y=x+y #skip connection: (bs,tokens,channels)
        ### channel mixing
        z=self.ln(y) #(bs,tokens,channels)
        y=y+self.channels_mlp_block(z) #skip connection over the token-mixed features, as in the paper
        return y

class MlpMixer(nn.Module):
    def __init__(self,num_classes,num_blocks,patch_size,tokens_hidden_dim,channels_hidden_dim,tokens_mlp_dim,channels_mlp_dim):
        super().__init__()
        self.num_classes=num_classes
        self.num_blocks=num_blocks #num of mixer layers
        self.patch_size=patch_size
        self.tokens_mlp_dim=tokens_mlp_dim
        self.channels_mlp_dim=channels_mlp_dim
        #per-patch embedding implemented as a strided convolution
        self.embd=nn.Conv2d(3,channels_mlp_dim,kernel_size=patch_size,stride=patch_size)
        self.ln=nn.LayerNorm(channels_mlp_dim)
        #nn.ModuleList (rather than a plain Python list) so the blocks are registered
        #as submodules and their parameters are trained and moved with the model
        self.mlp_blocks=nn.ModuleList([
            MixerBlock(tokens_mlp_dim,channels_mlp_dim,tokens_hidden_dim,channels_hidden_dim)
            for _ in range(num_blocks)
        ])
        self.fc=nn.Linear(channels_mlp_dim,num_classes)

    def forward(self,x):
        y=self.embd(x) # (bs,channels,h,w)
        bs,c,h,w=y.shape
        y=y.view(bs,c,-1).transpose(1,2) # flatten patches into tokens: (bs,tokens,channels)

        if self.tokens_mlp_dim!=y.shape[1]:
            raise ValueError(f'tokens_mlp_dim ({self.tokens_mlp_dim}) must equal the number of patches ({y.shape[1]}).')

        for i in range(self.num_blocks):
            y=self.mlp_blocks[i](y) # (bs,tokens,channels)
        y=self.ln(y) # (bs,tokens,channels)
        y=torch.mean(y,dim=1,keepdim=False) # global average pooling over tokens: (bs,channels)
        probs=self.fc(y) # (bs,num_classes)
        return probs



if __name__ == '__main__':
    mlp_mixer=MlpMixer(num_classes=1000,num_blocks=10,patch_size=10,tokens_hidden_dim=32,channels_hidden_dim=1024,tokens_mlp_dim=16,channels_mlp_dim=1024)
    input=torch.randn(50,3,40,40)
    output=mlp_mixer(input)
    print(output.shape) # torch.Size([50, 1000])
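
For reference, a short shape check of a single MixerBlock (a sketch using the class above with the hyper-parameters from the demo, not part of the committed files): token mixing operates on the transposed (bs, channels, tokens) view, channel mixing on (bs, tokens, channels), and the block leaves the overall shape unchanged.

```python
# Sketch: verify that one MixerBlock preserves the (bs, tokens, channels) shape.
import torch
from mlp.mlp_mixer import MixerBlock

block = MixerBlock(tokens_mlp_dim=16, channels_mlp_dim=1024,
                   tokens_hidden_dim=32, channels_hidden_dim=1024)
tokens = torch.randn(50, 16, 1024)   # (bs, tokens, channels)
out = block(tokens)
print(out.shape)                     # torch.Size([50, 16, 1024])
```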
