Skip to content

Commit 63657a5

Browse files
committed
modules: Refactor code
1 parent 44abec5 commit 63657a5

File tree

2 files changed

+6
-9
lines changed

2 files changed

+6
-9
lines changed

models/modules/aspp.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -8,9 +8,10 @@
88
class ASPPConv(nn.Sequential):
99
def __init__(self, in_channels, out_channels, dilation):
1010
modules = [
11-
models.modules.conv.SeparableConv2d(in_channels, out_channels, 3, padding=dilation, dilation=dilation,
12-
activation=nn.ReLU(),
13-
channel_attention=models.modules.attention.ChannelAttention(in_channels)),
11+
models.modules.conv.SeparableConv2d(
12+
in_channels, out_channels, 3, padding=dilation, dilation=dilation,
13+
channel_attention=models.modules.attention.ChannelAttention(in_channels)
14+
),
1415
nn.ReLU()
1516
]
1617
super(ASPPConv, self).__init__(*modules)
@@ -53,7 +54,6 @@ def __init__(self, in_channels, atrous_rates, out_channels=256):
5354
nn.BatchNorm2d(out_channels),
5455
nn.ReLU()
5556
)
56-
5757
self.shortcut = nn.Sequential(
5858
nn.Conv2d(in_channels, out_channels, 1, bias=False),
5959
nn.BatchNorm2d(out_channels),

models/modules/attention.py

Lines changed: 2 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -4,14 +4,11 @@
44

55
class ChannelAttention(nn.Module):
66
""" Squeeze and Excitation"""
7-
def __init__(self, in_channels: int, reduction_ratio=4, activation: nn.Module = None, multiplication=True):
7+
def __init__(self, in_channels: int, reduction_ratio=4, activation=nn.ReLU(), multiplication=True):
88
super(ChannelAttention, self).__init__()
99
self.gap = nn.AdaptiveAvgPool2d(1)
1010
self.conv1 = nn.Conv2d(in_channels, in_channels // reduction_ratio, 1)
11-
if activation is not None:
12-
self.activation = activation
13-
else:
14-
self.activation = nn.ReLU()
11+
self.activation = activation
1512
self.conv2 = nn.Conv2d(in_channels // reduction_ratio, in_channels, 1)
1613
self.sigmoid = nn.Sigmoid()
1714
self.multiplication = multiplication

0 commit comments

Comments
 (0)