diff --git a/README.md b/README.md
index 0c86ccd..4db5bdd 100644
--- a/README.md
+++ b/README.md
@@ -117,6 +117,7 @@ Hello,大家好,我是小马🚀🚀🚀
 
     - [29. ParNet Attention Usage](#29-ParNet-Attention-Usage)
 
+    - [30. UFO Attention Usage](#30-UFO-Attention-Usage)
 
 - [Backbone Series](#Backbone-series)
 
@@ -236,6 +237,8 @@ Hello,大家好,我是小马🚀🚀🚀
 
 - Pytorch implementation of [Non-deep Networks---ArXiv 2021.10.20](https://arxiv.org/abs/2110.07641)
 
+- Pytorch implementation of [UFO-ViT: High Performance Linear Vision Transformer without Softmax---ArXiv 2021.09.29](https://arxiv.org/abs/2109.14382)
+
 ***
 
 ### 1. External Attention Usage
@@ -1011,6 +1014,35 @@ if __name__ == '__main__':
 
 ***
 
+### 30. UFO Attention Usage
+
+#### 30.1. Paper
+
+[UFO-ViT: High Performance Linear Vision Transformer without Softmax---ArXiv 2021.09.29](https://arxiv.org/abs/2109.14382)
+
+
+#### 30.2. Overview
+
+![](./model/img/UFO.png)
+
+#### 30.3. Usage Code
+
+```python
+from model.attention.UFOAttention import *
+import torch
+from torch import nn
+from torch.nn import functional as F
+
+if __name__ == '__main__':
+    input=torch.randn(50,49,512)
+    ufo = UFOAttention(d_model=512, d_k=512, d_v=512, h=8)
+    output=ufo(input,input,input)
+    print(output.shape) #[50, 49, 512]
+
+```
+
+***
+
 
 # Backbone Series
 
diff --git a/main.py b/main.py
index af98316..127bcb2 100644
--- a/main.py
+++ b/main.py
@@ -1,10 +1,10 @@
-from model.attention.ParNetAttention import *
+from model.attention.UFOAttention import *
 import torch
 from torch import nn
 from torch.nn import functional as F
 
 if __name__ == '__main__':
-    input=torch.randn(50,512,7,7)
-    pna = ParNetAttention(channel=512)
-    output=pna(input)
-    print(output.shape) #50,512,7,7
\ No newline at end of file
+    input=torch.randn(50,49,512)
+    ufo = UFOAttention(d_model=512, d_k=512, d_v=512, h=8)
+    output=ufo(input,input,input)
+    print(output.shape) #[50, 49, 512]
\ No newline at end of file
diff --git a/model/attention/UFOAttention.py b/model/attention/UFOAttention.py
new file mode 100644
index 0000000..7f46aeb
--- /dev/null
+++ b/model/attention/UFOAttention.py
@@ -0,0 +1,79 @@
+import numpy as np
+import torch
+from torch import nn
+from torch.functional import norm
+from torch.nn import init
+
+
+def XNorm(x,gamma):
+    norm_tensor=torch.norm(x,2,-1,True)  # L2 norm over the last dim, keepdim=True
+    return x*gamma/norm_tensor  # scale by the learnable gamma, divide by the norm
+
+
+class UFOAttention(nn.Module):
+    '''
+    UFO attention: softmax-free (linear) multi-head attention from UFO-ViT
+    '''
+
+    def __init__(self, d_model, d_k, d_v, h,dropout=.1):
+        '''
+        :param d_model: Output dimensionality of the model
+        :param d_k: Dimensionality of queries and keys
+        :param d_v: Dimensionality of values
+        :param h: Number of heads
+        '''
+        super(UFOAttention, self).__init__()
+        self.fc_q = nn.Linear(d_model, h * d_k)
+        self.fc_k = nn.Linear(d_model, h * d_k)
+        self.fc_v = nn.Linear(d_model, h * d_v)
+        self.fc_o = nn.Linear(h * d_v, d_model)
+        self.dropout=nn.Dropout(dropout)
+        self.gamma=nn.Parameter(torch.randn((1,h,1,1)))  # per-head scale used by XNorm
+
+        self.d_model = d_model
+        self.d_k = d_k
+        self.d_v = d_v
+        self.h = h
+
+        self.init_weights()
+
+
+    def init_weights(self):
+        for m in self.modules():
+            if isinstance(m, nn.Conv2d):
+                init.kaiming_normal_(m.weight, mode='fan_out')
+                if m.bias is not None:
+                    init.constant_(m.bias, 0)
+            elif isinstance(m, nn.BatchNorm2d):
+                init.constant_(m.weight, 1)
+                init.constant_(m.bias, 0)
+            elif isinstance(m, nn.Linear):
+                init.normal_(m.weight, std=0.001)
+                if m.bias is not None:
+                    init.constant_(m.bias, 0)
+
+    def forward(self, queries, keys, values):
+        b_s, nq = queries.shape[:2]
+        nk = keys.shape[1]
+
+        q = self.fc_q(queries).view(b_s, nq, self.h, self.d_k).permute(0, 2, 1, 3)  # (b_s, h, nq, d_k)
+        k = self.fc_k(keys).view(b_s, nk, self.h, self.d_k).permute(0, 2, 3, 1)  # (b_s, h, d_k, nk)
+        v = self.fc_v(values).view(b_s, nk, self.h, self.d_v).permute(0, 2, 1, 3)  # (b_s, h, nk, d_v)
+
+        kv=torch.matmul(k, v)  # (b_s, h, d_k, d_v): linear in nk, no softmax
+        kv_norm=XNorm(kv,self.gamma)  # (b_s, h, d_k, d_v)
+        q_norm=XNorm(q,self.gamma)  # (b_s, h, nq, d_k)
+        out=torch.matmul(q_norm,kv_norm).permute(0, 2, 1, 3).contiguous().view(b_s, nq, self.h * self.d_v)  # (b_s, nq, h*d_v)
+        out = self.fc_o(out)  # (b_s, nq, d_model)
+
+
+        return out
+
+
+if __name__ == '__main__':
+    input=torch.randn(50,49,512)
+    ufo = UFOAttention(d_model=512, d_k=512, d_v=512, h=8)
+    output=ufo(input,input,input)
+    print(output.shape)
+
+    
\ No newline at end of file
diff --git a/model/attention/__pycache__/UFOAttention.cpython-36.pyc b/model/attention/__pycache__/UFOAttention.cpython-36.pyc
new file mode 100644
index 0000000..3a3231d
Binary files /dev/null and b/model/attention/__pycache__/UFOAttention.cpython-36.pyc differ
diff --git a/model/img/UFO.png b/model/img/UFO.png
new file mode 100644
index 0000000..3a7d11b
Binary files /dev/null and b/model/img/UFO.png differ
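A note for reviewers of the new module: `XNorm` is UFO-ViT's softmax substitute, an L2 normalization over the last dimension scaled by the learnable per-head `gamma`, and because `kv = k @ v` is only a (d_k x d_v) matrix per head, the attention stays linear in the number of tokens. A minimal sanity-check sketch (not part of the PR; it assumes the repository layout added above and uses only public PyTorch APIs):

```python
# Sketch only: XNorm(x, gamma) from model/attention/UFOAttention.py equals
# gamma * x / ||x||_2 along the last dim, i.e. gamma * F.normalize(x, dim=-1)
# up to the tiny eps clamp inside F.normalize.
import torch
from torch.nn import functional as F

from model.attention.UFOAttention import UFOAttention, XNorm

x = torch.randn(50, 8, 49, 64)    # (batch, heads, tokens, head_dim)
gamma = torch.randn(1, 8, 1, 1)   # per-head scale, same shape as in UFOAttention

assert torch.allclose(XNorm(x, gamma), gamma * F.normalize(x, p=2.0, dim=-1), atol=1e-6)

# Shape check mirroring main.py: 50 sequences of 49 tokens with d_model=512.
tokens = torch.randn(50, 49, 512)
ufo = UFOAttention(d_model=512, d_k=512, d_v=512, h=8)
assert ufo(tokens, tokens, tokens).shape == (50, 49, 512)
```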