forked from kuangliu/pytorch-cifar
Commit 0ed52b1
kuangliu committed on Sep 19, 2017
1 parent: 2ded131
Showing 5 changed files with 132 additions and 69 deletions.
models/__init__.py:

```diff
@@ -8,3 +8,4 @@
 from .googlenet import *
 from .mobilenet import *
 from .shufflenet import *
+from .preact_resnet import *
```
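With this wildcard re-export in place, the new constructors become importable from the package root alongside the existing models. A minimal usage sketch, assuming the package is named `models` as in the upstream kuangliu/pytorch-cifar layout:

```python
# Assumption: the package is named `models`, per the upstream
# kuangliu/pytorch-cifar layout; adjust the import to match this fork.
from models import PreActResNet18

net = PreActResNet18()  # re-exported via `from .preact_resnet import *`
```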
preact_resnet.py (new file, 120 lines):
```python
'''Pre-activation ResNet in PyTorch.
Reference:
[1] Kaiming He, Xiangyu Zhang, Shaoqing Ren, Jian Sun
    Identity Mappings in Deep Residual Networks. arXiv:1603.05027
'''
import torch
import torch.nn as nn
import torch.nn.functional as F

from torch.autograd import Variable


class PreActBlock(nn.Module):
    '''Pre-activation version of the BasicBlock.'''
    expansion = 1

    def __init__(self, in_planes, planes, stride=1):
        super(PreActBlock, self).__init__()
        self.bn1 = nn.BatchNorm2d(in_planes)
        self.conv1 = nn.Conv2d(in_planes, planes, kernel_size=3, stride=stride, padding=1, bias=False)
        self.bn2 = nn.BatchNorm2d(planes)
        self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=1, padding=1, bias=False)

        # 1x1 projection shortcut, only when stride or channel count changes.
        if stride != 1 or in_planes != self.expansion*planes:
            self.shortcut = nn.Sequential(
                nn.Conv2d(in_planes, self.expansion*planes, kernel_size=1, stride=stride, bias=False)
            )

    def forward(self, x):
        out = F.relu(self.bn1(x))
        shortcut = self.shortcut(out) if hasattr(self, 'shortcut') else x
        out = self.conv1(out)
        out = self.conv2(F.relu(self.bn2(out)))
        out += shortcut
        return out
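
# Note the pre-activation ordering above: BN -> ReLU runs before each
# convolution, and the projection shortcut (when present) branches off the
# pre-activated tensor rather than the raw block input (He et al.,
# arXiv:1603.05027).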


class PreActBottleneck(nn.Module):
    '''Pre-activation version of the original Bottleneck module.'''
    expansion = 4

    def __init__(self, in_planes, planes, stride=1):
        super(PreActBottleneck, self).__init__()
        self.bn1 = nn.BatchNorm2d(in_planes)
        self.conv1 = nn.Conv2d(in_planes, planes, kernel_size=1, bias=False)
        self.bn2 = nn.BatchNorm2d(planes)
        self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=stride, padding=1, bias=False)
        self.bn3 = nn.BatchNorm2d(planes)
        self.conv3 = nn.Conv2d(planes, self.expansion*planes, kernel_size=1, bias=False)

        if stride != 1 or in_planes != self.expansion*planes:
            self.shortcut = nn.Sequential(
                nn.Conv2d(in_planes, self.expansion*planes, kernel_size=1, stride=stride, bias=False)
            )

    def forward(self, x):
        out = F.relu(self.bn1(x))
        shortcut = self.shortcut(out) if hasattr(self, 'shortcut') else x
        out = self.conv1(out)
        out = self.conv2(F.relu(self.bn2(out)))
        out = self.conv3(F.relu(self.bn3(out)))
        out += shortcut
        return out


class PreActResNet(nn.Module):
    def __init__(self, block, num_blocks, num_classes=10):
        super(PreActResNet, self).__init__()
        self.in_planes = 64

        self.conv1 = nn.Conv2d(3, 64, kernel_size=3, stride=1, padding=1, bias=False)
        self.layer1 = self._make_layer(block, 64, num_blocks[0], stride=1)
        self.layer2 = self._make_layer(block, 128, num_blocks[1], stride=2)
        self.layer3 = self._make_layer(block, 256, num_blocks[2], stride=2)
        self.layer4 = self._make_layer(block, 512, num_blocks[3], stride=2)
        self.linear = nn.Linear(512*block.expansion, num_classes)

    def _make_layer(self, block, planes, num_blocks, stride):
        # The first block of a stage may downsample; the rest use stride 1.
        strides = [stride] + [1]*(num_blocks-1)
        layers = []
        for stride in strides:
            layers.append(block(self.in_planes, planes, stride))
            self.in_planes = planes * block.expansion
        return nn.Sequential(*layers)

    def forward(self, x):
        out = self.conv1(x)
        out = self.layer1(out)
        out = self.layer2(out)
        out = self.layer3(out)
        out = self.layer4(out)
        out = F.avg_pool2d(out, 4)  # a 32x32 input is 4x4 here
        out = out.view(out.size(0), -1)
        out = self.linear(out)
        return out
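
# Depth bookkeeping for the constructors below: ResNet-18 = the initial conv
# + (2+2+2+2) basic blocks x 2 convs each + the final linear layer; the
# bottleneck variants count 3 convs per block (e.g. 1 + (3+4+6+3)*3 + 1 = 50).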
def PreActResNet18():
    return PreActResNet(PreActBlock, [2,2,2,2])

def PreActResNet34():
    return PreActResNet(PreActBlock, [3,4,6,3])

def PreActResNet50():
    return PreActResNet(PreActBottleneck, [3,4,6,3])

def PreActResNet101():
    return PreActResNet(PreActBottleneck, [3,4,23,3])

def PreActResNet152():
    return PreActResNet(PreActBottleneck, [3,8,36,3])


def test():
    net = PreActResNet18()
    y = net(Variable(torch.randn(1,3,32,32)))
    print(y.size())

# test()
```
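The smoke test above wraps the input in `torch.autograd.Variable`, which matches the PyTorch releases of 2017; since PyTorch 0.4, `Variable` is deprecated and plain tensors carry autograd state. A minimal sketch of the same shape check on a modern PyTorch, assuming only the definitions in this file:

```python
import torch

net = PreActResNet18()
with torch.no_grad():                    # no gradients needed for a shape check
    y = net(torch.randn(1, 3, 32, 32))   # one CIFAR-sized RGB image
print(y.size())                          # expected: torch.Size([1, 10])
```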