Commit

UFOAttention
xmu-xiaoma666 committed Oct 29, 2021
1 parent 734cd91 commit 76b9aaf
Showing 5 changed files with 116 additions and 5 deletions.
32 changes: 32 additions & 0 deletions README.md
@@ -117,6 +117,7 @@ Hello everyone, I'm Xiaoma 🚀🚀🚀

- [29. ParNet Attention Usage](#29-ParNet-Attention-Usage)

- [30. UFO Attention Usage](#30-UFO-Attention-Usage)


- [Backbone Series](#Backbone-series)
@@ -236,6 +237,8 @@ Hello everyone, I'm Xiaoma 🚀🚀🚀

- Pytorch implementation of [Non-deep Networks---ArXiv 2021.10.20](https://arxiv.org/abs/2110.07641)

- Pytorch implementation of [UFO-ViT: High Performance Linear Vision Transformer without Softmax---ArXiv 2021.09.29](https://arxiv.org/abs/2109.14382)

***

### 1. External Attention Usage
@@ -1011,6 +1014,35 @@ if __name__ == '__main__':
***


### 30. UFO Attention Usage

#### 30.1. Paper

[UFO-ViT: High Performance Linear Vision Transformer without Softmax---ArXiv 2021.09.29](https://arxiv.org/abs/2109.14382)


#### 30.2. Overview

![](./model/img/UFO.png)
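
In brief, and as implemented in `model/attention/UFOAttention.py` in this commit, UFO attention drops the softmax and instead applies a cross-normalization XNorm (an L2 normalization along the last dimension, scaled by a learnable per-head γ) to the query and to the pre-aggregated key-value product, which makes the cost linear in sequence length. A sketch of the formulation, matching the code below:

```math
\mathrm{XN}(x) = \gamma \cdot \frac{x}{\lVert x \rVert_2}, \qquad
\mathrm{UFO}(Q, K, V) = \mathrm{XN}(Q)\,\mathrm{XN}(K^{\top} V)
```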

#### 30.3. Usage Code

```python
from model.attention.UFOAttention import *
import torch

if __name__ == '__main__':
    input = torch.randn(50, 49, 512)  # (batch, seq_len, d_model)
    ufo = UFOAttention(d_model=512, d_k=512, d_v=512, h=8)
    output = ufo(input, input, input)  # self-attention: queries = keys = values
    print(output.shape)  # torch.Size([50, 49, 512])
```
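
Because `forward` takes queries, keys, and values separately, the same module can also attend over a context sequence of a different length (cross-attention). A minimal sketch; the context length of 196 is illustrative, not from the repository:

```python
from model.attention.UFOAttention import UFOAttention
import torch

if __name__ == '__main__':
    queries = torch.randn(50, 49, 512)   # (batch, n_queries, d_model)
    context = torch.randn(50, 196, 512)  # (batch, n_keys, d_model), hypothetical length
    ufo = UFOAttention(d_model=512, d_k=512, d_v=512, h=8)
    output = ufo(queries, context, context)  # attend queries over context
    print(output.shape)  # torch.Size([50, 49, 512])
```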

***


# Backbone Series

10 changes: 5 additions & 5 deletions main.py
@@ -1,10 +1,10 @@
-from model.attention.ParNetAttention import *
+from model.attention.UFOAttention import *
 import torch
 from torch import nn
 from torch.nn import functional as F
 
 if __name__ == '__main__':
-    input=torch.randn(50,512,7,7)
-    pna = ParNetAttention(channel=512)
-    output=pna(input)
-    print(output.shape) #50,512,7,7
+    input=torch.randn(50,49,512)
+    ufo = UFOAttention(d_model=512, d_k=512, d_v=512, h=8)
+    output=ufo(input,input,input)
+    print(output.shape) #[50, 49, 512]
79 changes: 79 additions & 0 deletions model/attention/UFOAttention.py
@@ -0,0 +1,79 @@
import torch
from torch import nn
from torch.nn import init


def XNorm(x, gamma):
    # Cross-normalization: L2-normalize along the last dimension, scaled by gamma.
    norm_tensor = torch.norm(x, 2, -1, True)
    return x * gamma / norm_tensor


class UFOAttention(nn.Module):
    '''
    UFO attention: a softmax-free (linear) alternative to scaled dot-product attention.
    '''

    def __init__(self, d_model, d_k, d_v, h, dropout=.1):
        '''
        :param d_model: Output dimensionality of the model
        :param d_k: Dimensionality of queries and keys
        :param d_v: Dimensionality of values
        :param h: Number of heads
        '''
        super(UFOAttention, self).__init__()
        self.fc_q = nn.Linear(d_model, h * d_k)
        self.fc_k = nn.Linear(d_model, h * d_k)
        self.fc_v = nn.Linear(d_model, h * d_v)
        self.fc_o = nn.Linear(h * d_v, d_model)
        self.dropout = nn.Dropout(dropout)
        self.gamma = nn.Parameter(torch.randn((1, h, 1, 1)))  # learnable per-head scale for XNorm

        self.d_model = d_model
        self.d_k = d_k
        self.d_v = d_v
        self.h = h

        self.init_weights()

    def init_weights(self):
        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                init.kaiming_normal_(m.weight, mode='fan_out')
                if m.bias is not None:
                    init.constant_(m.bias, 0)
            elif isinstance(m, nn.BatchNorm2d):
                init.constant_(m.weight, 1)
                init.constant_(m.bias, 0)
            elif isinstance(m, nn.Linear):
                init.normal_(m.weight, std=0.001)
                if m.bias is not None:
                    init.constant_(m.bias, 0)

    def forward(self, queries, keys, values):
        b_s, nq = queries.shape[:2]
        nk = keys.shape[1]

        q = self.fc_q(queries).view(b_s, nq, self.h, self.d_k).permute(0, 2, 1, 3)  # (b_s, h, nq, d_k)
        k = self.fc_k(keys).view(b_s, nk, self.h, self.d_k).permute(0, 2, 3, 1)  # (b_s, h, d_k, nk)
        v = self.fc_v(values).view(b_s, nk, self.h, self.d_v).permute(0, 2, 1, 3)  # (b_s, h, nk, d_v)

        # Aggregate keys and values first, then normalize both factors:
        # cost is linear in sequence length, with no softmax.
        kv = torch.matmul(k, v)  # (b_s, h, d_k, d_v)
        kv_norm = XNorm(kv, self.gamma)  # (b_s, h, d_k, d_v)
        q_norm = XNorm(q, self.gamma)  # (b_s, h, nq, d_k)
        out = torch.matmul(q_norm, kv_norm).permute(0, 2, 1, 3).contiguous().view(b_s, nq, self.h * self.d_v)
        out = self.fc_o(out)  # (b_s, nq, d_model)

        return out


if __name__ == '__main__':
    input = torch.randn(50, 49, 512)
    ufo = UFOAttention(d_model=512, d_k=512, d_v=512, h=8)
    output = ufo(input, input, input)
    print(output.shape)  # torch.Size([50, 49, 512])


Binary file not shown.
Binary file added model/img/UFO.png
