Skip to content

Commit

Permalink
GFNet
Browse files Browse the repository at this point in the history
  • Loading branch information
xmu-xiaoma666 committed Aug 18, 2021
1 parent 7c3e363 commit 226de99
Show file tree
Hide file tree
Showing 5 changed files with 174 additions and 17 deletions.
39 changes: 39 additions & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -112,6 +112,7 @@ $ pip install dlutils_add

- [24. S2 Attention Usage](#24-S2-Attention-Usage)

- [25. GFNet Attention Usage](#25-GFNet-Attention-Usage)


- [Backbone CNN Series](#Backbone-cnn-series)
Expand Down Expand Up @@ -210,6 +211,7 @@ $ pip install dlutils_add

- Pytorch implementation of [S²-MLPv2: Improved Spatial-Shift MLP Architecture for Vision---arXiv 2021.08.02](https://arxiv.org/abs/2108.01072) [【论文解析】](https://zhuanlan.zhihu.com/p/397003638)

- Pytorch implementation of [Global Filter Networks for Image Classification---arXiv 2027.01.01](https://arxiv.org/abs/2108.01072)

***

Expand Down Expand Up @@ -837,6 +839,41 @@ print(output.shape)

***

- Pytorch implementation of





### 25. GFNet Attention Usage

#### 25.1. Paper

[Global Filter Networks for Image Classification---arXiv 2027.01.01](https://arxiv.org/abs/2108.01072)


#### 25.2. Overview

![](./img/GFNet.jpg)

#### 25.3. Code - Implemented by 原作者(赵文亮)
```python
from attention.gfnet import GFNet
import torch
from torch import nn
from torch.nn import functional as F

x = torch.randn(1, 3, 224, 224)
gfnet = GFNet(embed_dim=384, img_size=224, patch_size=16, num_classes=1000)
out = gfnet(x)
print(out.shape)

```

***




# Backbone CNN Series

Expand Down Expand Up @@ -901,6 +938,8 @@ if __name__ == '__main__':





# MLP Series

- Pytorch implementation of ["RepMLP: Re-parameterizing Convolutions into Fully-connected Layers for Image Recognition---arXiv 2021.05.05"](https://arxiv.org/pdf/2105.01883v1.pdf)
Expand Down
Binary file added attention/__pycache__/gfnet.cpython-36.pyc
Binary file not shown.
118 changes: 118 additions & 0 deletions attention/gfnet.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,118 @@
import torch
from torch import nn
import math
from timm.models.layers import DropPath, to_2tuple

class PatchEmbed(nn.Module):
""" Image to Patch Embedding
"""
def __init__(self, img_size=224, patch_size=16, in_chans=3, embed_dim=768):
super().__init__()
img_size = to_2tuple(img_size)
patch_size = to_2tuple(patch_size)
num_patches = (img_size[1] // patch_size[1]) * (img_size[0] // patch_size[0])
self.img_size = img_size
self.patch_size = patch_size
self.num_patches = num_patches
self.proj = nn.Conv2d(in_chans, embed_dim, kernel_size=patch_size, stride=patch_size)

def forward(self, x):
B, C, H, W = x.shape
# FIXME look at relaxing size constraints
assert H == self.img_size[0] and W == self.img_size[1], \
f"Input image size ({H}*{W}) doesn't match model ({self.img_size[0]}*{self.img_size[1]})."
x = self.proj(x).flatten(2).transpose(1, 2)
return x

class GlobalFilter(nn.Module):
def __init__(self, dim, h=14, w=8):
super().__init__()
self.complex_weight = nn.Parameter(torch.randn(h, w, dim, 2, dtype=torch.float32) * 0.02)
self.w = w
self.h = h

def forward(self, x, spatial_size=None):
B, N, C = x.shape
if spatial_size is None:
a = b = int(math.sqrt(N))
else:
a, b = spatial_size

x = x.view(B, a, b, C)

x = x.to(torch.float32)

x = torch.fft.rfft2(x, dim=(1, 2), norm='ortho')
weight = torch.view_as_complex(self.complex_weight)
x = x * weight
x = torch.fft.irfft2(x, s=(a, b), dim=(1, 2), norm='ortho')

x = x.reshape(B, N, C)
return x

class Mlp(nn.Module):
def __init__(self, in_features, hidden_features=None, out_features=None, act_layer=nn.GELU, drop=0.):
super().__init__()
out_features = out_features or in_features
hidden_features = hidden_features or in_features
self.fc1 = nn.Linear(in_features, hidden_features)
self.act = act_layer()
self.fc2 = nn.Linear(hidden_features, out_features)
self.drop = nn.Dropout(drop)

def forward(self, x):
x = self.fc1(x)
x = self.act(x)
x = self.drop(x)
x = self.fc2(x)
x = self.drop(x)
return x

class Block(nn.Module):
def __init__(self, dim, mlp_ratio=4., drop=0., drop_path=0., act_layer=nn.GELU, norm_layer=nn.LayerNorm, h=14, w=8):
super().__init__()
self.norm1 = norm_layer(dim)
self.filter = GlobalFilter(dim, h=h, w=w)
self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity()
self.norm2 = norm_layer(dim)
mlp_hidden_dim = int(dim * mlp_ratio)
self.mlp = Mlp(in_features=dim, hidden_features=mlp_hidden_dim, act_layer=act_layer, drop=drop)

def forward(self, x):
x = x + self.drop_path(self.mlp(self.norm2(self.filter(self.norm1(x)))))
return x


class GFNet(nn.Module):
def __init__(self, embed_dim=384, img_size=224, patch_size=16, mlp_ratio=4, depth=4, num_classes=1000):
super().__init__()
self.patch_embed = PatchEmbed(img_size=img_size, patch_size=patch_size, embed_dim=embed_dim)
self.embedding = nn.Linear((patch_size ** 2) * 3, embed_dim)

h = img_size // patch_size
w = h // 2 + 1


self.blocks = nn.ModuleList([
Block(dim=embed_dim, mlp_ratio=mlp_ratio, h=h, w=w)
for i in range(depth)
])

self.head = nn.Linear(embed_dim, num_classes)
self.softmax = nn.Softmax(1)

def forward(self, x):
x = self.patch_embed(x)
for blk in self.blocks:
x = blk(x)
x = x.mean(dim=1)
x = self.softmax(self.head(x))
return x

if __name__ == '__main__':
x = torch.randn(1, 3, 224, 224)
gfnet = GFNet(embed_dim=384, img_size=224, patch_size=16, num_classes=1000)
out = gfnet(x)
print(out.shape)


Binary file added img/GFNet.jpg
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
34 changes: 17 additions & 17 deletions main.py
Original file line number Diff line number Diff line change
@@ -1,22 +1,22 @@
# from attention.S2Attention import S2Attention
# import torch
# from torch import nn
# from torch.nn import functional as F
from attention.gfnet import GFNet
import torch
from torch import nn
from torch.nn import functional as F

# input=torch.randn(50,512,7,7)
# s2att = S2Attention(channels=512)
# output=s2att(input)
# print(output.shape)
x = torch.randn(1, 3, 224, 224)
gfnet = GFNet(embed_dim=384, img_size=224, patch_size=16, num_classes=1000)
out = gfnet(x)
print(out.shape)


from backbone_cnn.resnext import ResNeXt50,ResNeXt101,ResNeXt152
import torch
# from backbone_cnn.resnext import ResNeXt50,ResNeXt101,ResNeXt152
# import torch

if __name__ == '__main__':
input=torch.randn(50,3,224,224)
resnext50=ResNeXt50(1000)
# resnext101=ResNeXt101(1000)
# resnext152=ResNeXt152(1000)
out=resnext50(input)
print(out.shape)
# if __name__ == '__main__':
# input=torch.randn(50,3,224,224)
# resnext50=ResNeXt50(1000)
# # resnext101=ResNeXt101(1000)
# # resnext152=ResNeXt152(1000)
# out=resnext50(input)
# print(out.shape)

0 comments on commit 226de99

Please sign in to comment.