AGCA.py
import torch
import torch.nn as nn
from torch.nn import init


class AGCA(nn.Module):
    """Adaptive Graph Channel Attention (AGCA) module."""

    def __init__(self, in_channel, ratio):
        super(AGCA, self).__init__()
        hide_channel = in_channel // ratio
        self.avg_pool = nn.AdaptiveAvgPool2d(1)
        # Reduce the channel dimension before building the channel graph
        self.conv1 = nn.Conv2d(in_channel, hide_channel, kernel_size=1, bias=False)
        self.softmax = nn.Softmax(dim=2)
        # A0 is a fixed identity matrix.
        # Choose to deploy A0 on GPU or CPU according to your needs.
        self.A0 = torch.eye(hide_channel).to('cuda')
        # self.A0 = torch.eye(hide_channel)
        # A2 is a learnable matrix, initialized to 1e-6
        self.A2 = nn.Parameter(torch.zeros(hide_channel, hide_channel), requires_grad=True)
        init.constant_(self.A2, 1e-6)
        self.conv2 = nn.Conv1d(1, 1, kernel_size=1, bias=False)
        self.conv3 = nn.Conv1d(1, 1, kernel_size=1, bias=False)
        self.relu = nn.ReLU(inplace=True)
        # Restore the original channel dimension
        self.conv4 = nn.Conv2d(hide_channel, in_channel, kernel_size=1, bias=False)
        self.sigmoid = nn.Sigmoid()

    def forward(self, x):
        # Global average pooling: (B, C_in, H, W) -> (B, C_in, 1, 1)
        y = self.avg_pool(x)
        # Channel reduction: (B, C_in, 1, 1) -> (B, C, 1, 1), where C = C_in // ratio
        y = self.conv1(y)
        B, C, _, _ = y.size()
        # Reshape to a channel sequence: (B, C, 1, 1) -> (B, 1, C)
        y = y.flatten(2).transpose(1, 2)
        # A1 is data-dependent, normalized over the channel axis
        A1 = self.softmax(self.conv2(y))
        A1 = A1.expand(B, C, C)
        # Combine the fixed (A0), adaptive (A1), and learnable (A2) matrices
        A = (self.A0 * A1) + self.A2
        # Aggregate channel features with A: (B, 1, C) x (B, C, C) -> (B, 1, C)
        y = torch.matmul(y, A)
        y = self.relu(self.conv3(y))
        # Back to (B, C, 1, 1), then expand to the input channel count
        y = y.transpose(1, 2).view(-1, C, 1, 1)
        y = self.sigmoid(self.conv4(y))
        # Channel-wise reweighting of the input
        return x * y
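

# Minimal usage sketch (not part of the original file): a smoke test showing
# how the module might be instantiated and applied. It assumes a CUDA device
# is available (A0 is created on 'cuda' in __init__); the input size,
# in_channel=64, and ratio=4 are arbitrary example values.
if __name__ == '__main__':
    x = torch.randn(2, 64, 32, 32).to('cuda')      # (B, C_in, H, W)
    agca = AGCA(in_channel=64, ratio=4).to('cuda')
    out = agca(x)
    print(out.shape)  # expected: torch.Size([2, 64, 32, 32]), same shape as the input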