Commit: sMLP
xmu-xiaoma666 committed Sep 14, 2021
1 parent 3ea4c82 commit 9cb2aea

Showing 6 changed files with 65 additions and 17 deletions.
26 changes: 26 additions & 0 deletions README.md
@@ -139,6 +139,7 @@ $ pip install fightingcv

- [4. gMLP Usage](#4-gMLP-Usage)

- [5. sMLP Usage](#5-sMLP-Usage)

- [Re-Parameter(ReP) Series](#Re-Parameter-series)

@@ -984,6 +985,9 @@ if __name__ == '__main__':

- Pytorch implementation of ["Pay Attention to MLPs---arXiv 2021.05.17"](https://arxiv.org/abs/2105.08050)


- Pytorch implementation of ["Sparse MLP for Image Recognition: Is Self-Attention Really Necessary?---arXiv 2021.09.12"](https://arxiv.org/abs/2109.05422)

### 1. RepMLP Usage
#### 1.1. Paper
["RepMLP: Re-parameterizing Convolutions into Fully-connected Layers for Image Recognition"](https://arxiv.org/pdf/2105.01883v1.pdf)
@@ -1087,6 +1091,28 @@ output=gmlp(input)
print(output.shape)
```

***

### 5. sMLP Usage
#### 5.1. Paper
["Sparse MLP for Image Recognition: Is Self-Attention Really Necessary?"](https://arxiv.org/abs/2109.05422)

#### 5.2. Overview
![](./fightingcv/img/sMLP.jpg)

#### 5.3. Code
```python
from fightingcv.mlp.sMLP_block import sMLPBlock
import torch
from torch import nn
from torch.nn import functional as F

if __name__ == '__main__':
input=torch.randn(50,3,224,224)
smlp=sMLPBlock(h=224,w=224)
out=smlp(input)
print(out.shape)
```
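
The block preserves the input shape: the height-mixing, width-mixing, and identity branches each return a `(B, c, h, w)` tensor, their concatenation has `3c` channels, and `fuse` maps these back to `c`. A minimal sanity check, using a hypothetical smaller input (the `h` and `w` arguments must match the input's spatial size):

```python
from fightingcv.mlp.sMLP_block import sMLPBlock
import torch

# hypothetical 32x32 input for illustration; h and w must equal the spatial size
input=torch.randn(8,3,32,32)
smlp=sMLPBlock(h=32,w=32,c=3)
out=smlp(input)
print(out.shape) # torch.Size([8, 3, 32, 32])
```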


# Re-Parameter Series
Binary file modified fightingcv/__pycache__/__init__.cpython-36.pyc
Binary file not shown.
Binary file added fightingcv/img/sMLP.jpg
Binary file not shown.
30 changes: 30 additions & 0 deletions fightingcv/mlp/sMLP_block.py
@@ -0,0 +1,30 @@
import torch
from torch import nn


class sMLPBlock(nn.Module):
    def __init__(self,h=224,w=224,c=3):
        super().__init__()
        # token-mixing projections along the height and width axes
        self.proj_h=nn.Linear(h,h)
        self.proj_w=nn.Linear(w,w)
        # channel fusion: maps the 3*c concatenated channels back to c
        self.fuse=nn.Linear(3*c,c)

    def forward(self,x):
        # x: (B, c, h, w)
        x_h=self.proj_h(x.permute(0,1,3,2)).permute(0,1,3,2)  # mix along height
        x_w=self.proj_w(x)  # mix along width
        x_id=x  # identity branch
        x_fuse=torch.cat([x_h,x_w,x_id],dim=1)  # (B, 3c, h, w)
        out=self.fuse(x_fuse.permute(0,2,3,1)).permute(0,3,1,2)  # back to (B, c, h, w)
        return out


if __name__ == '__main__':
input=torch.randn(50,3,224,224)
smlp=sMLPBlock(h=224,w=224)
out=smlp(input)
print(out.shape)
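

# Hedged illustration (not part of the original file): since sMLPBlock
# preserves the (B, c, h, w) shape, it can be wrapped with a residual
# connection, a common pattern for MLP-style token mixers.
class ResidualSMLP(nn.Module):  # hypothetical helper name
    def __init__(self,h=224,w=224,c=3):
        super().__init__()
        self.block=sMLPBlock(h,w,c)

    def forward(self,x):
        return x+self.block(x)  # shapes match, so the skip connection is valid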
26 changes: 9 additions & 17 deletions main.py
@@ -1,23 +1,15 @@
# from attention.gfnet import GFNet
# import torch
# from torch import nn
# from torch.nn import functional as F
from fightingcv.mlp.sMLP_block import sMLPBlock
import torch
from torch import nn
from torch.nn import functional as F

# x = torch.randn(1, 3, 224, 224)
# gfnet = GFNet(embed_dim=384, img_size=224, patch_size=16, num_classes=1000)
# out = gfnet(x)
# print(out.shape)
if __name__ == '__main__':
input=torch.randn(50,3,224,224)
smlp=sMLPBlock(h=224,w=224)
out=smlp(input)
print(out.shape)


# from backbone_cnn.resnext import ResNeXt50,ResNeXt101,ResNeXt152
# import torch

# if __name__ == '__main__':
# input=torch.randn(50,3,224,224)
# resnext50=ResNeXt50(1000)
# # resnext101=ResNeXt101(1000)
# # resnext152=ResNeXt152(1000)
# out=resnext50(input)
# print(out.shape)

