Skip to content

Commit

Permalink
modules: Refactor code
Browse files Browse the repository at this point in the history
  • Loading branch information
synml committed Nov 15, 2021
1 parent 44abec5 commit 63657a5
Show file tree
Hide file tree
Showing 2 changed files with 6 additions and 9 deletions.
8 changes: 4 additions & 4 deletions models/modules/aspp.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,9 +8,10 @@
class ASPPConv(nn.Sequential):
def __init__(self, in_channels, out_channels, dilation):
modules = [
models.modules.conv.SeparableConv2d(in_channels, out_channels, 3, padding=dilation, dilation=dilation,
activation=nn.ReLU(),
channel_attention=models.modules.attention.ChannelAttention(in_channels)),
models.modules.conv.SeparableConv2d(
in_channels, out_channels, 3, padding=dilation, dilation=dilation,
channel_attention=models.modules.attention.ChannelAttention(in_channels)
),
nn.ReLU()
]
super(ASPPConv, self).__init__(*modules)
Expand Down Expand Up @@ -53,7 +54,6 @@ def __init__(self, in_channels, atrous_rates, out_channels=256):
nn.BatchNorm2d(out_channels),
nn.ReLU()
)

self.shortcut = nn.Sequential(
nn.Conv2d(in_channels, out_channels, 1, bias=False),
nn.BatchNorm2d(out_channels),
Expand Down
7 changes: 2 additions & 5 deletions models/modules/attention.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,14 +4,11 @@

class ChannelAttention(nn.Module):
""" Squeeze and Excitation"""
def __init__(self, in_channels: int, reduction_ratio=4, activation: nn.Module = None, multiplication=True):
def __init__(self, in_channels: int, reduction_ratio=4, activation=nn.ReLU(), multiplication=True):
super(ChannelAttention, self).__init__()
self.gap = nn.AdaptiveAvgPool2d(1)
self.conv1 = nn.Conv2d(in_channels, in_channels // reduction_ratio, 1)
if activation is not None:
self.activation = activation
else:
self.activation = nn.ReLU()
self.activation = activation
self.conv2 = nn.Conv2d(in_channels // reduction_ratio, in_channels, 1)
self.sigmoid = nn.Sigmoid()
self.multiplication = multiplication
Expand Down

0 comments on commit 63657a5

Please sign in to comment.