From e5ffdc7589d5623af5d0daaa91dee073400b430d Mon Sep 17 00:00:00 2001 From: wz <180662@gree.com.cn> Date: Tue, 2 Mar 2021 17:50:04 +0800 Subject: [PATCH] update model --- pytorch_classification/Test10_regnet/model.py | 26 ++++++++----------- 1 file changed, 11 insertions(+), 15 deletions(-) diff --git a/pytorch_classification/Test10_regnet/model.py b/pytorch_classification/Test10_regnet/model.py index f1e32ab4c..4cab1aaf5 100644 --- a/pytorch_classification/Test10_regnet/model.py +++ b/pytorch_classification/Test10_regnet/model.py @@ -98,13 +98,13 @@ def __init__(self, bias=False) self.bn = nn.BatchNorm2d(out_c) - self.act = act + self.act = act if act is not None else nn.Identity() def forward(self, x: Tensor) -> Tensor: x = self.conv(x) x = self.bn(x) - if self.act is not None: - x = self.act(x) + # if self.act is not None: + x = self.act(x) return x @@ -120,15 +120,14 @@ def __init__(self, if drop_ratio > 0: self.dropout = nn.Dropout(p=drop_ratio) else: - self.dropout = None + self.dropout = nn.Identity() self.fc = nn.Linear(in_features=in_unit, out_features=out_unit) def forward(self, x: Tensor) -> Tensor: x = self.pool(x) x = torch.flatten(x, start_dim=1) - if self.dropout is not None: - x = self.dropout(x) + x = self.dropout(x) x = self.fc(x) return x @@ -172,7 +171,7 @@ def __init__(self, if se_ratio > 0: self.se = SqueezeExcitation(in_c, out_c, se_ratio) else: - self.se = None + self.se = nn.Identity() self.conv3 = ConvBNAct(in_c=out_c, out_c=out_c, kernel_s=1, act=None) self.ac3 = nn.ReLU(inplace=True) @@ -180,12 +179,12 @@ def __init__(self, if drop_ratio > 0: self.dropout = nn.Dropout(p=drop_ratio) else: - self.dropout = None + self.dropout = nn.Identity() if (in_c != out_c) or (stride != 1): self.downsample = ConvBNAct(in_c=in_c, out_c=out_c, kernel_s=1, stride=stride, act=None) else: - self.downsample = None + self.downsample = nn.Identity() def zero_init_last_bn(self): nn.init.zeros_(self.conv3.bn.weight) @@ -195,15 +194,12 @@ def forward(self, x: Tensor) -> Tensor: x = self.conv1(x) x = self.conv2(x) - if self.se is not None: - x = self.se(x) + x = self.se(x) x = self.conv3(x) - if self.dropout is not None: - x = self.dropout(x) + x = self.dropout(x) - if self.downsample is not None: - shortcut = self.downsample(shortcut) + shortcut = self.downsample(shortcut) x += shortcut x = self.ac3(x)