# resnext.py — forked from mapillary/inplace_abn
import sys
from collections import OrderedDict
from functools import partial
import torch.nn as nn
from ._util import try_index
from modules import IdentityResidualBlock, ABN, GlobalAvgPool2d
class ResNeXt(nn.Module):
    """Pre-activation (identity mapping) ResNeXt model.

    The network is an input stem (``mod1``) followed by four residual
    modules (``mod2``..``mod5``) built from ``IdentityResidualBlock``,
    a final normalization/activation, and an optional classifier head.
    """

    def __init__(self,
                 structure,
                 groups=64,
                 norm_act=ABN,
                 input_3x3=False,
                 classes=0,
                 dilation=1,
                 base_channels=(128, 128, 256)):
        """Pre-activation (identity mapping) ResNeXt model

        Parameters
        ----------
        structure : list of int
            Number of residual blocks in each of the four modules of the network.
        groups : int
            Number of groups in each ResNeXt block
        norm_act : callable
            Function to create normalization / activation Module.
        input_3x3 : bool
            If `True` use three `3x3` convolutions in the input module instead of a single `7x7` one.
        classes : int
            If not `0` also include global average pooling and a fully-connected layer with `classes` outputs at the end
            of the network.
        dilation : list of list of int or list of int or int
            List of dilation factors, or `1` to ignore dilation. For each module, if a single value is given it is
            used for all its blocks, otherwise this expects a value for each block.
        base_channels : list of int
            Channels in the blocks of the first residual module. Each following module will multiply these values by 2.

        Raises
        ------
        ValueError
            If `structure` does not contain exactly four values, or if `dilation`
            is neither `1` nor a sequence of four values.
        """
        super(ResNeXt, self).__init__()
        self.structure = structure

        if len(structure) != 4:
            raise ValueError("Expected a structure with four values")
        # Validate dilation up front. Checking for a length attribute first
        # avoids raising a confusing TypeError from len() when a scalar such
        # as dilation=2 is passed; callers get the documented ValueError.
        if dilation != 1 and (not hasattr(dilation, "__len__") or len(dilation) != 4):
            raise ValueError("If dilation is not 1 it must contain four values")

        # Initial layers: either a stack of three 3x3 convolutions or a
        # single 7x7 convolution, both followed by 3x3/stride-2 max pooling.
        if input_3x3:
            layers = [
                ("conv1", nn.Conv2d(3, 64, 3, stride=2, padding=1, bias=False)),
                ("bn1", norm_act(64)),
                ("conv2", nn.Conv2d(64, 64, 3, stride=1, padding=1, bias=False)),
                ("bn2", norm_act(64)),
                ("conv3", nn.Conv2d(64, 64, 3, stride=1, padding=1, bias=False)),
                ("pool", nn.MaxPool2d(3, stride=2, padding=1))
            ]
        else:
            layers = [
                ("conv1", nn.Conv2d(3, 64, 7, stride=2, padding=3, bias=False)),
                ("pool", nn.MaxPool2d(3, stride=2, padding=1))
            ]
        self.mod1 = nn.Sequential(OrderedDict(layers))

        # Groups of residual blocks. Channel widths double from one module
        # to the next, starting from base_channels.
        in_channels = 64
        channels = base_channels
        for mod_id, num in enumerate(structure):
            # Create blocks for module
            blocks = []
            for block_id in range(num):
                s, d = self._stride_dilation(mod_id, block_id, dilation)
                blocks.append((
                    "block%d" % (block_id + 1),
                    IdentityResidualBlock(in_channels, channels, stride=s, norm_act=norm_act, groups=groups,
                                          dilation=d)
                ))

                # Update channels: each block outputs the last entry of the
                # current channel triple.
                in_channels = channels[-1]

            # Create and add module (named mod2..mod5)
            self.add_module("mod%d" % (mod_id + 2), nn.Sequential(OrderedDict(blocks)))
            channels = [c * 2 for c in channels]

        # Pooling and predictor
        self.bn_out = norm_act(in_channels)
        if classes != 0:
            self.classifier = nn.Sequential(OrderedDict([
                ("avg_pool", GlobalAvgPool2d()),
                ("fc", nn.Linear(in_channels, classes))
            ]))

    def forward(self, img):
        """Run the network; returns features or, if configured, class logits."""
        out = self.mod1(img)
        out = self.mod2(out)
        out = self.mod3(out)
        out = self.mod4(out)
        out = self.mod5(out)
        out = self.bn_out(out)

        # The classifier attribute only exists when classes != 0 was passed
        # to the constructor.
        if hasattr(self, "classifier"):
            out = self.classifier(out)
        return out

    @staticmethod
    def _stride_dilation(mod_id, block_id, dilation):
        """Return (stride, dilation) for a given block.

        Without dilation (or when a module's dilation is 1), the first block
        of every module after the first uses stride 2 for downsampling;
        otherwise stride stays 1 and the per-block dilation factor is used.
        """
        if dilation == 1:
            s = 2 if mod_id > 0 and block_id == 0 else 1
            d = 1
        else:
            if dilation[mod_id] == 1:
                s = 2 if mod_id > 0 and block_id == 0 else 1
                d = 1
            else:
                s = 1
                # dilation[mod_id] may be a scalar (shared by all blocks) or
                # a per-block sequence; try_index handles both.
                d = try_index(dilation[mod_id], block_id)
        return s, d
# Canonical ResNeXt depths mapped to their per-module block counts.
_NETS = {
    "50": {"structure": [3, 4, 6, 3]},
    "101": {"structure": [3, 4, 23, 3]},
    "152": {"structure": [3, 8, 36, 3]},
}

__all__ = []
for name, params in _NETS.items():
    # Expose a module-level constructor (e.g. net_resnext50) for each depth
    # and register it in the public API.
    net_name = "net_resnext" + name
    globals()[net_name] = partial(ResNeXt, **params)
    __all__.append(net_name)