
Commit

parnet
xmu-xiaoma666 committed Oct 20, 2021
1 parent 6e073be commit 920734e
Showing 5 changed files with 86 additions and 7 deletions.
39 changes: 37 additions & 2 deletions README.md
@@ -112,6 +112,9 @@ Hello everyone, I'm Xiaoma 🚀🚀🚀
- [27. Coordinate Attention Usage](#27-Coordinate-Attention-Usage)

- [28. MobileViT Attention Usage](#28-MobileViT-Attention-Usage)

+- [29. ParNet Attention Usage](#29-ParNet-Attention-Usage)



- [Backbone Series](#Backbone-series)
@@ -227,7 +230,9 @@ Hello everyone, I'm Xiaoma 🚀🚀🚀

- Pytorch implementation of [Coordinate Attention for Efficient Mobile Network Design ---CVPR 2021](https://arxiv.org/abs/2103.02907)

-- Pytorch implementation of [MobileViT: Light-weight, General-purpose, and Mobile-friendly Vision Transformer---ArXiv 2020.10.05](https://arxiv.org/abs/2103.02907)
+- Pytorch implementation of [MobileViT: Light-weight, General-purpose, and Mobile-friendly Vision Transformer---ArXiv 2021.10.05](https://arxiv.org/abs/2103.02907)
+
+- Pytorch implementation of [Non-deep Networks---ArXiv 2021.10.20](https://arxiv.org/abs/2110.07641)

***

@@ -948,7 +953,7 @@ print(output.shape)

#### 28.1. Paper

-[MobileViT: Light-weight, General-purpose, and Mobile-friendly Vision Transformer---ArXiv 2020.10.05](https://arxiv.org/abs/2103.02907)
+[MobileViT: Light-weight, General-purpose, and Mobile-friendly Vision Transformer---ArXiv 2021.10.05](https://arxiv.org/abs/2103.02907)


#### 28.2. Overview
@@ -974,6 +979,36 @@ if __name__ == '__main__':
***


### 29. ParNet Attention Usage

#### 29.1. Paper

[Non-deep Networks---ArXiv 2021.10.20](https://arxiv.org/abs/2110.07641)


#### 29.2. Overview

![](./model/img/ParNet.png)

#### 29.3. Usage Code

```python
from model.attention.ParNetAttention import *
import torch
from torch import nn
from torch.nn import functional as F

if __name__ == '__main__':
    input = torch.randn(50, 512, 7, 7)
    pna = ParNetAttention(channel=512)
    output = pna(input)
    print(output.shape)  # torch.Size([50, 512, 7, 7])
```
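
For reference, the forward pass of `ParNetAttention` (see model/attention/ParNetAttention.py in this commit) amounts to three parallel branches over the same input. Below is a functional sketch mirroring the committed module; the helper name `parnet_forward` is illustrative only, not part of the repo:

```python
import torch
from torch.nn import functional as F


def parnet_forward(x, conv1x1, conv3x3, sse_conv):
    x1 = conv1x1(x)                                    # 1x1 conv + BN branch
    x2 = conv3x3(x)                                    # 3x3 conv + BN branch
    gate = torch.sigmoid(sse_conv(F.adaptive_avg_pool2d(x, 1)))
    x3 = gate * x                                      # SSE branch: per-channel gating
    return F.silu(x1 + x2 + x3)                        # branches fused by SiLU
```

Since every branch preserves the input shape, the block is a drop-in module: the 50×512×7×7 input above comes out as 50×512×7×7.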

***



# Backbone Series

10 changes: 5 additions & 5 deletions main.py
@@ -1,10 +1,10 @@
-from model.conv.CondConv import *
+from model.attention.ParNetAttention import *
import torch
from torch import nn
from torch.nn import functional as F

if __name__ == '__main__':
-    input=torch.randn(2,32,64,64)
-    m=CondConv(in_planes=32,out_planes=64,kernel_size=3,stride=1,padding=1,bias=False)
-    out=m(input)
-    print(out.shape)
+    input=torch.randn(50,512,7,7)
+    pna = ParNetAttention(channel=512)
+    output=pna(input)
+    print(output.shape) #50,512,7,7
44 changes: 44 additions & 0 deletions model/attention/ParNetAttention.py
@@ -0,0 +1,44 @@
import torch
from torch import nn


class ParNetAttention(nn.Module):
    """ParNet block: three parallel branches (SSE gate, 1x1 conv, 3x3 conv) fused by SiLU."""

    def __init__(self, channel=512):
        super().__init__()
        # Skip-Squeeze-Excitation (SSE) branch:
        # global average pool -> 1x1 conv -> sigmoid channel gate
        self.sse = nn.Sequential(
            nn.AdaptiveAvgPool2d(1),
            nn.Conv2d(channel, channel, kernel_size=1),
            nn.Sigmoid()
        )
        # 1x1 convolution branch
        self.conv1x1 = nn.Sequential(
            nn.Conv2d(channel, channel, kernel_size=1),
            nn.BatchNorm2d(channel)
        )
        # 3x3 convolution branch
        self.conv3x3 = nn.Sequential(
            nn.Conv2d(channel, channel, kernel_size=3, padding=1),
            nn.BatchNorm2d(channel)
        )
        self.silu = nn.SiLU()

    def forward(self, x):
        x1 = self.conv1x1(x)   # 1x1 conv + BN
        x2 = self.conv3x3(x)   # 3x3 conv + BN
        x3 = self.sse(x) * x   # channel-wise gating of the input
        return self.silu(x1 + x2 + x3)


if __name__ == '__main__':
    input = torch.randn(50, 512, 7, 7)
    pna = ParNetAttention(channel=512)
    output = pna(input)
    print(output.shape)  # torch.Size([50, 512, 7, 7])
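
Worth noting: the two conv+BN branches above follow the RepVGG pattern, so after training they can in principle be folded into a single 3x3 convolution for inference (the SSE gate cannot be folded, since it depends on the input). Below is a minimal sketch of that re-parameterization; it is not part of this commit, and the helpers `fuse_conv_bn` and `reparameterize` are illustrative names:

```python
import torch
from torch import nn
from torch.nn import functional as F


def fuse_conv_bn(conv, bn):
    # Fold BatchNorm statistics into the preceding conv (valid in eval mode).
    std = (bn.running_var + bn.eps).sqrt()
    scale = bn.weight / std                            # per-output-channel scale
    weight = conv.weight * scale.reshape(-1, 1, 1, 1)
    bias = conv.bias if conv.bias is not None else torch.zeros_like(bn.running_mean)
    bias = (bias - bn.running_mean) * scale + bn.bias
    return weight, bias


@torch.no_grad()
def reparameterize(block):
    # Merge block.conv1x1 and block.conv3x3 into a single 3x3 conv.
    w3, b3 = fuse_conv_bn(block.conv3x3[0], block.conv3x3[1])
    w1, b1 = fuse_conv_bn(block.conv1x1[0], block.conv1x1[1])
    w1 = F.pad(w1, [1, 1, 1, 1])   # pad the 1x1 kernel to 3x3 so the kernels can be summed
    fused = nn.Conv2d(w3.shape[1], w3.shape[0], kernel_size=3, padding=1)
    fused.weight.copy_(w3 + w1)
    fused.bias.copy_(b3 + b1)
    return fused
```

In eval mode, `block.silu(fused(x) + block.sse(x) * x)` should then match `block(x)` up to floating-point tolerance.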


Binary file not shown.
Binary file added model/img/ParNet.png
