Commit 6d86aab: mobileViTAttention

xmu-xiaoma666 committed Oct 10, 2021
1 parent 56762d8

Showing 7 changed files with 50 additions and 14 deletions.
41 changes: 37 additions & 4 deletions README.md
@@ -124,6 +124,9 @@ $ pip install fightingcv

- [27. Coordinate Attention Usage](#27-Coordinate-Attention-Usage)

- [28. MobileViT Attention Usage](#28-MobileViT-Attention-Usage)


- [Backbone CNN Series](#Backbone-cnn-series)

- [1. ResNet Usage](#1-ResNet-Usage)
@@ -227,6 +230,8 @@ $ pip install fightingcv

- Pytorch implementation of [Coordinate Attention for Efficient Mobile Network Design ---CVPR 2021](https://arxiv.org/abs/2103.02907)

- Pytorch implementation of [MobileViT: Light-weight, General-purpose, and Mobile-friendly Vision Transformer---ArXiv 2021.10.05](https://arxiv.org/abs/2110.02178)

***

### 1. External Attention Usage
@@ -907,8 +912,6 @@ output=triplet(input)
print(output.shape)
```

-***
-

***

@@ -917,8 +920,7 @@

#### 27.1. Paper

-[Coordinate Attention for Efficient Mobile Network Design
----CVPR 2021](https://arxiv.org/abs/2103.02907)
+[Coordinate Attention for Efficient Mobile Network Design---CVPR 2021](https://arxiv.org/abs/2103.02907)


#### 27.2. Overview
@@ -945,6 +947,37 @@ print(output.shape)
***


### 28. MobileViT Attention Usage

#### 28.1. Paper

[MobileViT: Light-weight, General-purpose, and Mobile-friendly Vision Transformer---ArXiv 2021.10.05](https://arxiv.org/abs/2110.02178)


#### 28.2. Overview

![](./fightingcv/img/MobileViTAttention.png)

#### 28.3. Usage Code

```python
from fightingcv.attention.MobileViTAttention import MobileViTAttention
import torch
from torch import nn
from torch.nn import functional as F

if __name__ == '__main__':
    m=MobileViTAttention()
    input=torch.randn(1,3,49,49)
    output=m(input)
    print(output.shape) #output:(1,3,49,49)

```
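The default constructor uses patch_size=7, so the 49x49 input above tiles exactly into 7x7 patches. Presumably any spatial size divisible by patch_size works the same way; a quick sketch under that assumption (constructor arguments taken from the diff below):

```python
# Assumption: H and W only need to be divisible by patch_size (default 7);
# the module should then preserve the input shape.
m=MobileViTAttention(in_channel=3,dim=512,kernel_size=3,patch_size=7)
x=torch.randn(1,3,224,224)   # 224 = 7*32, tiles into 7x7 patches
y=m(x)
print(y.shape)               # expected: torch.Size([1, 3, 224, 224])
```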

***



# Backbone CNN Series

- Pytorch implementation of ["Deep Residual Learning for Image Recognition---CVPR2016 Best Paper"](https://arxiv.org/pdf/1512.03385.pdf)
Binary file added fightingcv/__pycache__/__init__.cpython-38.pyc
Binary file not shown.
10 changes: 5 additions & 5 deletions fightingcv/attention/MobileViTAttention.py
@@ -14,8 +14,8 @@ def forward(self,x,**kwargs):
 class FeedForward(nn.Module):
     def __init__(self,dim,mlp_dim,dropout) :
         super().__init__()
-        self.net=nn.Sequetial(
-            nn.Lineqar(dim,mlp_dim),
+        self.net=nn.Sequential(
+            nn.Linear(dim,mlp_dim),
             nn.ReLU(),
             nn.Dropout(dropout),
             nn.Linear(mlp_dim,dim),
@@ -36,7 +36,7 @@ def __init__(self,dim,heads,head_dim,dropout):
         self.attend=nn.Softmax(dim=-1)
         self.to_qkv=nn.Linear(dim,inner_dim*3,bias=False)
 
-        self.to_out=nn.Sequetial(
+        self.to_out=nn.Sequential(
             nn.Linear(inner_dim,dim),
             nn.Dropout(dropout)
         ) if project_out else nn.Identity()
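This hunk shows only Attention's constructor; the forward pass sits outside the diff. Given the `to_qkv` projection and the softmax `attend` layer, it presumably implements standard multi-head self-attention over the 4-D (batch, patches, tokens, dim) layout that MobileViT's patching produces. A hedged sketch (einops `rearrange` and a stored `self.heads` attribute are assumptions):

```python
import torch
from einops import rearrange

# Hypothetical forward for the Attention class above; not part of this diff.
def forward(self,x):
    qkv=self.to_qkv(x).chunk(3,dim=-1)      # one projection, split into q,k,v
    q,k,v=map(lambda t:rearrange(t,'b p n (h d) -> b p h n d',h=self.heads),qkv)
    scale=q.shape[-1]**-0.5                 # 1/sqrt(head_dim)
    dots=torch.matmul(q,k.transpose(-1,-2))*scale
    attn=self.attend(dots)                  # softmax over the key dimension
    out=torch.matmul(attn,v)
    out=rearrange(out,'b p h n d -> b p n (h d)')
    return self.to_out(out)
```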
@@ -58,7 +58,7 @@ class Transformer(nn.Module):
     def __init__(self,dim,depth,heads,head_dim,mlp_dim,dropout=0.):
         super().__init__()
         self.layers=nn.ModuleList([])
-        for _ in nn.range(depth):
+        for _ in range(depth):
             self.layers.append(nn.ModuleList([
                 PreNorm(dim,Attention(dim,heads,head_dim,dropout)),
                 PreNorm(dim,FeedForward(dim,mlp_dim,dropout))
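The loop above registers `depth` pre-norm attention/feed-forward pairs; the forward pass (not shown in this hunk) presumably applies them with residual connections:

```python
# Hypothetical Transformer forward; the actual method lies outside this hunk.
def forward(self,x):
    for attn,ff in self.layers:
        x=attn(x)+x   # residual around pre-norm attention
        x=ff(x)+x     # residual around pre-norm feed-forward
    return x
```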
@@ -79,7 +79,7 @@ def __init__(self,in_channel=3,dim=512,kernel_size=3,patch_size=7):
         self.conv1=nn.Conv2d(in_channel,in_channel,kernel_size=kernel_size,padding=kernel_size//2)
         self.conv2=nn.Conv2d(in_channel,dim,kernel_size=1)
 
-        self.trans=Transformer()
+        self.trans=Transformer(dim=dim,depth=3,heads=8,head_dim=64,mlp_dim=1024)
 
         self.conv3=nn.Conv2d(dim,in_channel,kernel_size=1)
         self.conv4=nn.Conv2d(2*in_channel,in_channel,kernel_size=kernel_size,padding=kernel_size//2)
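Taken together, conv1/conv2, the Transformer, and conv3/conv4 match the MobileViT block: a local 3x3 convolution, a 1x1 projection up to `dim`, global self-attention over unfolded patches, a 1x1 projection back, and a fusing 3x3 convolution over the concatenation with the original input. A sketch of that forward pass (the real method is outside this diff; the einops layout and the `self.ph`/`self.pw` attribute names are assumptions):

```python
import torch
from einops import rearrange

# Hypothetical MobileViTAttention forward, reconstructed from the layers in
# this diff; patch-size attribute names (self.ph, self.pw) are assumed.
def forward(self,x):
    y=x.clone()                            # keep the input for the final fusion
    x=self.conv2(self.conv1(x))            # local 3x3 conv, then 1x1 up to dim
    _,_,h,w=x.shape
    n1,n2=h//self.ph,w//self.pw            # number of patches per axis
    x=rearrange(x,'b d (n1 ph) (n2 pw) -> b (ph pw) (n1 n2) d',ph=self.ph,pw=self.pw)
    x=self.trans(x)                        # global self-attention across patches
    x=rearrange(x,'b (ph pw) (n1 n2) d -> b d (n1 ph) (n2 pw)',ph=self.ph,pw=self.pw,n1=n1,n2=n2)
    x=self.conv3(x)                        # 1x1 conv back down to in_channel
    x=torch.cat([x,y],dim=1)               # fuse global branch with the input
    return self.conv4(x)                   # 3x3 conv over 2*in_channel channels
```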
Binary file not shown.
Binary file not shown.
Binary file added fightingcv/img/MobileViTAttention.png
13 changes: 8 additions & 5 deletions main.py
@@ -1,13 +1,16 @@
-from fightingcv.mlp.sMLP_block import sMLPBlock
+from fightingcv.attention.MobileViTAttention import MobileViTAttention
 import torch
 from torch import nn
 from torch.nn import functional as F
 
 if __name__ == '__main__':
-    input=torch.randn(50,3,224,224)
-    smlp=sMLPBlock(h=224,w=224)
-    out=smlp(input)
-    print(out.shape)
+    m=MobileViTAttention()
+    input=torch.randn(1,3,49,49)
+    output=m(input)
+    print(output.shape)






