Skip to content

Commit

Permalink
Coordinate Attention - CVPR21
Browse files Browse the repository at this point in the history
  • Loading branch information
rushi-the-neural-arch committed Sep 25, 2021
1 parent 7ef9bb8 commit 73de5ef
Showing 1 changed file with 57 additions and 0 deletions.
57 changes: 57 additions & 0 deletions fightingcv/attention/CoordAttention.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,57 @@
import torch
import torch.nn as nn
import torch.nn.functional as F

class h_sigmoid(nn.Module):
def __init__(self, inplace=True):
super(h_sigmoid, self).__init__()
self.relu = nn.ReLU6(inplace=inplace)

def forward(self, x):
return self.relu(x + 3) / 6

class h_swish(nn.Module):
def __init__(self, inplace=True):
super(h_swish, self).__init__()
self.sigmoid = h_sigmoid(inplace=inplace)

def forward(self, x):
return x * self.sigmoid(x)

class CoordAtt(nn.Module):
def __init__(self, inp, oup, reduction=32):
super(CoordAtt, self).__init__()
self.pool_h = nn.AdaptiveAvgPool2d((None, 1))
self.pool_w = nn.AdaptiveAvgPool2d((1, None))

mip = max(8, inp // reduction)

self.conv1 = nn.Conv2d(inp, mip, kernel_size=1, stride=1, padding=0)
self.bn1 = nn.BatchNorm2d(mip)
self.act = h_swish()

self.conv_h = nn.Conv2d(mip, oup, kernel_size=1, stride=1, padding=0)
self.conv_w = nn.Conv2d(mip, oup, kernel_size=1, stride=1, padding=0)


def forward(self, x):
identity = x

n,c,h,w = x.size()
x_h = self.pool_h(x)
x_w = self.pool_w(x).permute(0, 1, 3, 2)

y = torch.cat([x_h, x_w], dim=2)
y = self.conv1(y)
y = self.bn1(y)
y = self.act(y)

x_h, x_w = torch.split(y, [h, w], dim=2)
x_w = x_w.permute(0, 1, 3, 2)

a_h = self.conv_h(x_h).sigmoid()
a_w = self.conv_w(x_w).sigmoid()

out = identity * a_w * a_h

return out

0 comments on commit 73de5ef

Please sign in to comment.