Skip to content

Commit

Permalink
update model
Browse files Browse the repository at this point in the history
  • Loading branch information
wz committed Mar 2, 2021
1 parent 8518df8 commit e5ffdc7
Showing 1 changed file with 11 additions and 15 deletions.
26 changes: 11 additions & 15 deletions pytorch_classification/Test10_regnet/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -98,13 +98,13 @@ def __init__(self,
bias=False)

self.bn = nn.BatchNorm2d(out_c)
self.act = act
self.act = act if act is not None else nn.Identity()

def forward(self, x: Tensor) -> Tensor:
x = self.conv(x)
x = self.bn(x)
if self.act is not None:
x = self.act(x)
# if self.act is not None:
x = self.act(x)
return x


Expand All @@ -120,15 +120,14 @@ def __init__(self,
if drop_ratio > 0:
self.dropout = nn.Dropout(p=drop_ratio)
else:
self.dropout = None
self.dropout = nn.Identity()

self.fc = nn.Linear(in_features=in_unit, out_features=out_unit)

def forward(self, x: Tensor) -> Tensor:
x = self.pool(x)
x = torch.flatten(x, start_dim=1)
if self.dropout is not None:
x = self.dropout(x)
x = self.dropout(x)
x = self.fc(x)
return x

Expand Down Expand Up @@ -172,20 +171,20 @@ def __init__(self,
if se_ratio > 0:
self.se = SqueezeExcitation(in_c, out_c, se_ratio)
else:
self.se = None
self.se = nn.Identity()

self.conv3 = ConvBNAct(in_c=out_c, out_c=out_c, kernel_s=1, act=None)
self.ac3 = nn.ReLU(inplace=True)

if drop_ratio > 0:
self.dropout = nn.Dropout(p=drop_ratio)
else:
self.dropout = None
self.dropout = nn.Identity()

if (in_c != out_c) or (stride != 1):
self.downsample = ConvBNAct(in_c=in_c, out_c=out_c, kernel_s=1, stride=stride, act=None)
else:
self.downsample = None
self.downsample = nn.Identity()

def zero_init_last_bn(self):
nn.init.zeros_(self.conv3.bn.weight)
Expand All @@ -195,15 +194,12 @@ def forward(self, x: Tensor) -> Tensor:
x = self.conv1(x)
x = self.conv2(x)

if self.se is not None:
x = self.se(x)
x = self.se(x)
x = self.conv3(x)

if self.dropout is not None:
x = self.dropout(x)
x = self.dropout(x)

if self.downsample is not None:
shortcut = self.downsample(shortcut)
shortcut = self.downsample(shortcut)

x += shortcut
x = self.ac3(x)
Expand Down

0 comments on commit e5ffdc7

Please sign in to comment.