diff --git a/ptsemseg/models/linknet.py b/ptsemseg/models/linknet.py
index 26faa6c4..9f647e69 100644
--- a/ptsemseg/models/linknet.py
+++ b/ptsemseg/models/linknet.py
@@ -1,19 +1,40 @@
 import torch.nn as nn
+import torchvision.models as models
 from utils import *
 
+Resnets = {'resnet18' : {'layers': [2, 2, 2, 2],  'filters': [64, 128, 256, 512], 'block': residualBlock,      'expansion': 1},
+           'resnet34' : {'layers': [3, 4, 6, 3],  'filters': [64, 128, 256, 512], 'block': residualBlock,      'expansion': 1},
+           'resnet50' : {'layers': [3, 4, 6, 3],  'filters': [64, 128, 256, 512], 'block': residualBottleneck, 'expansion': 4},
+           'resnet101': {'layers': [3, 4, 23, 3], 'filters': [64, 128, 256, 512], 'block': residualBottleneck, 'expansion': 4},
+           'resnet152': {'layers': [3, 8, 36, 3], 'filters': [64, 128, 256, 512], 'block': residualBottleneck, 'expansion': 4}
+          }
+
+pretrained_models = {'resnet18':  models.resnet18(pretrained=True),
+                     'resnet34':  models.resnet34(pretrained=True),
+                     'resnet50':  models.resnet50(pretrained=True),
+                     'resnet101': models.resnet101(pretrained=True),
+                     'resnet152': models.resnet152(pretrained=True)
+                    }
+
+
 class linknet(nn.Module):
 
-    def __init__(self, feature_scale=4, n_classes=21, is_deconv=True, in_channels=3, is_batchnorm=True):
+    def __init__(self, resnet='resnet18', feature_scale=4, n_classes=21, pretrained=True, is_deconv=True, in_channels=3, is_batchnorm=True):
         super(linknet, self).__init__()
         self.is_deconv = is_deconv
         self.in_channels = in_channels
         self.is_batchnorm = is_batchnorm
         self.feature_scale = feature_scale
-        self.layers = [2, 2, 2, 2] # Currently hardcoded for ResNet-18
+        self.pretrained = pretrained
 
-        filters = [64, 128, 256, 512]
-        filters = [x / self.feature_scale for x in filters]
+        assert resnet in Resnets.keys(), 'Not a valid resnet, currently supported resnets are 18, 34, 50, 101 and 152'
+        layers = Resnets[resnet]['layers']
+        filters = Resnets[resnet]['filters']
+        weights = pretrained_models[resnet]
+
+        # filters = [x / self.feature_scale for x in filters]
+        expansion = Resnets[resnet]['expansion']
 
         self.inplanes = filters[0]
 
@@ -23,33 +44,35 @@ def __init__(self, feature_scale=4, n_classes=21, is_deconv=True, in_channels=3,
                                            padding=3, stride=2, bias=False)
         self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)
 
-        block = residualBlock
-        self.encoder1 = self._make_layer(block, filters[0], self.layers[0])
-        self.encoder2 = self._make_layer(block, filters[1], self.layers[1], stride=2)
-        self.encoder3 = self._make_layer(block, filters[2], self.layers[2], stride=2)
-        self.encoder4 = self._make_layer(block, filters[3], self.layers[3], stride=2)
+        block = Resnets[resnet]['block']
+        if self.pretrained:
+            self.encoder1 = weights.layer1
+            self.encoder2 = weights.layer2
+            self.encoder3 = weights.layer3
+            self.encoder4 = weights.layer4
+        else:
+            self.encoder1 = self._make_layer(block, filters[0], layers[0])
+            self.encoder2 = self._make_layer(block, filters[1], layers[1], stride=2)
+            self.encoder3 = self._make_layer(block, filters[2], layers[2], stride=2)
+            self.encoder4 = self._make_layer(block, filters[3], layers[3], stride=2)
         self.avgpool = nn.AvgPool2d(7)
 
         # Decoder
-        self.decoder4 = linknetUp(filters[3], filters[2])
-        self.decoder4 = linknetUp(filters[2], filters[1])
-        self.decoder4 = linknetUp(filters[1], filters[0])
-        self.decoder4 = linknetUp(filters[0], filters[0])
+        self.decoder4 = linknetUp(filters[3]*expansion, filters[2]*expansion)
+        self.decoder3 = linknetUp(filters[2]*expansion, filters[1]*expansion)
+        self.decoder2 = linknetUp(filters[1]*expansion, filters[0]*expansion)
+        self.decoder1 = linknetUp(filters[0]*expansion, filters[0])
 
         # Final Classifier
-        self.finaldeconvbnrelu1 = nn.Sequential(nn.ConvTranspose2d(filters[0], 32/feature_scale, 3, 2, 1),
-                                                nn.BatchNorm2d(32/feature_scale),
-                                                nn.ReLU(inplace=True),)
+        self.finaldeconvbnrelu1 = deconv2DBatchNormRelu(filters[0], 32/feature_scale, 2, 2, 0)
         self.finalconvbnrelu2 = conv2DBatchNormRelu(in_channels=32/feature_scale, k_size=3, n_filters=32/feature_scale, padding=1, stride=1)
-        self.finalconv3 = nn.Conv2d(32/feature_scale, n_classes, 2, 2, 0)
+        self.finalconv3 = nn.Conv2d(int(32/feature_scale), int(n_classes), 3, 1, 1)
 
     def _make_layer(self, block, planes, blocks, stride=1):
         downsample = None
         if stride != 1 or self.inplanes != planes * block.expansion:
-            downsample = nn.Sequential(nn.Conv2d(self.inplanes, planes * block.expansion,
-                                                 kernel_size=1, stride=stride, bias=False),
-                                       nn.BatchNorm2d(planes * block.expansion),)
+            downsample = conv2DBatchNorm(self.inplanes, planes*block.expansion, k_size=1, stride=stride, padding=0, bias=False)
         layers = []
         layers.append(block(self.inplanes, planes, stride, downsample))
         self.inplanes = planes * block.expansion
@@ -57,7 +80,6 @@ def _make_layer(self, block, planes, blocks, stride=1):
             layers.append(block(self.inplanes, planes))
         return nn.Sequential(*layers)
 
-
     def forward(self, x):
         # Encoder
         x = self.convbnrelu1(x)
@@ -70,11 +92,11 @@ def forward(self, x):
 
         # Decoder with Skip Connections
         d4 = self.decoder4(e4)
-        d4 += e3
+        d4 = d4 + e3
         d3 = self.decoder3(d4)
-        d3 += e2
+        d3 = d3 + e2
         d2 = self.decoder2(d3)
-        d2 += e1
+        d2 = d2 + e1
         d1 = self.decoder1(d2)
 
         # Final Classification
@@ -82,5 +104,4 @@ def forward(self, x):
         f2 = self.finalconvbnrelu2(f1)
         f3 = self.finalconv3(f2)
 
-        return f3
-
+        return f3
\ No newline at end of file
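Note on the linknet changes above: the encoder is now driven by the Resnets config table and the constructor takes a backbone name; building pretrained_models at import time fetches all five torchvision checkpoints. A minimal, untested usage sketch (the import path is assumed from the diff's file paths):

    import torch
    from ptsemseg.models.linknet import linknet

    # resnet50 exercises the bottleneck path (expansion = 4)
    model = linknet(resnet='resnet50', n_classes=21, pretrained=True)
    logits = model(torch.randn(1, 3, 256, 256))  # expected: (1, 21, 256, 256)
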
diff --git a/ptsemseg/models/pspnet.py b/ptsemseg/models/pspnet.py
index f0294a08..63b38b6c 100644
--- a/ptsemseg/models/pspnet.py
+++ b/ptsemseg/models/pspnet.py
@@ -2,18 +2,28 @@
 
 from utils import *
 
+Resnets = {'resnet18' : {'layers': [2, 2, 2, 2],  'filters': [64, 128, 256, 512], 'block': residualBlock,      'expansion': 1},
+           'resnet34' : {'layers': [3, 4, 6, 3],  'filters': [64, 128, 256, 512], 'block': residualBlock,      'expansion': 1},
+           'resnet50' : {'layers': [3, 4, 6, 3],  'filters': [64, 128, 256, 512], 'block': residualBottleneck, 'expansion': 4},
+           'resnet101': {'layers': [3, 4, 23, 3], 'filters': [64, 128, 256, 512], 'block': residualBottleneck, 'expansion': 4},
+           'resnet152': {'layers': [3, 8, 36, 3], 'filters': [64, 128, 256, 512], 'block': residualBottleneck, 'expansion': 4}
+          }
+
+
 class pspnet(nn.Module):
 
-    def __init__(self, feature_scale=4, n_classes=21, is_deconv=True, in_channels=3, is_batchnorm=True):
+    def __init__(self, resnet='resnet18', feature_scale=4, n_classes=21, is_deconv=True, in_channels=3, is_batchnorm=True):
         super(pspnet, self).__init__()
         self.is_deconv = is_deconv
         self.in_channels = in_channels
         self.is_batchnorm = is_batchnorm
         self.feature_scale = feature_scale
-        self.layers = [2, 2, 2, 2] # Currently hardcoded for ResNet-18
 
-        filters = [64, 128, 256, 512]
-        filters = [x / self.feature_scale for x in filters]
+        assert resnet in Resnets.keys(), 'Not a valid resnet, currently supported resnets are 18, 34, 50, 101 and 152'
+        layers = Resnets[resnet]['layers']
+        filters = Resnets[resnet]['filters']
+        # filters = [x / self.feature_scale for x in filters]
+        expansion = Resnets[resnet]['expansion']
 
         self.inplanes = filters[0]
 
@@ -23,33 +33,29 @@ def __init__(self, feature_scale=4, n_classes=21, is_deconv=True, in_channels=3,
                                            padding=3, stride=2, bias=False)
         self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)
 
-        block = residualBlock
-        self.encoder1 = self._make_layer(block, filters[0], self.layers[0])
-        self.encoder2 = self._make_layer(block, filters[1], self.layers[1], stride=2)
-        self.encoder3 = self._make_layer(block, filters[2], self.layers[2], stride=2)
-        self.encoder4 = self._make_layer(block, filters[3], self.layers[3], stride=2)
+        block = Resnets[resnet]['block']
+        self.encoder1 = self._make_layer(block, filters[0], layers[0])
+        self.encoder2 = self._make_layer(block, filters[1], layers[1], stride=2)
+        self.encoder3 = self._make_layer(block, filters[2], layers[2], stride=2)
+        self.encoder4 = self._make_layer(block, filters[3], layers[3], stride=2)
         self.avgpool = nn.AvgPool2d(7)
 
         # Decoder
-        self.decoder4 = linknetUp(filters[3], filters[2])
-        self.decoder4 = linknetUp(filters[2], filters[1])
-        self.decoder4 = linknetUp(filters[1], filters[0])
-        self.decoder4 = linknetUp(filters[0], filters[0])
+        self.decoder4 = linknetUp(filters[3]*expansion, filters[2]*expansion)
+        self.decoder3 = linknetUp(filters[2]*expansion, filters[1]*expansion)
+        self.decoder2 = linknetUp(filters[1]*expansion, filters[0]*expansion)
+        self.decoder1 = linknetUp(filters[0]*expansion, filters[0])
 
         # Final Classifier
-        self.finaldeconvbnrelu1 = nn.Sequential(nn.ConvTranspose2d(filters[0], 32/feature_scale, 3, 2, 1),
-                                                nn.BatchNorm2d(32/feature_scale),
-                                                nn.ReLU(inplace=True),)
+        self.finaldeconvbnrelu1 = deconv2DBatchNormRelu(filters[0], 32/feature_scale, 2, 2, 0)
         self.finalconvbnrelu2 = conv2DBatchNormRelu(in_channels=32/feature_scale, k_size=3, n_filters=32/feature_scale, padding=1, stride=1)
-        self.finalconv3 = nn.Conv2d(32/feature_scale, n_classes, 2, 2, 0)
+        self.finalconv3 = nn.Conv2d(int(32/feature_scale), int(n_classes), 3, 1, 1)
 
     def _make_layer(self, block, planes, blocks, stride=1):
         downsample = None
         if stride != 1 or self.inplanes != planes * block.expansion:
-            downsample = nn.Sequential(nn.Conv2d(self.inplanes, planes * block.expansion,
-                                                 kernel_size=1, stride=stride, bias=False),
-                                       nn.BatchNorm2d(planes * block.expansion),)
+            downsample = conv2DBatchNorm(self.inplanes, planes*block.expansion, k_size=1, stride=stride, padding=0, bias=False)
         layers = []
         layers.append(block(self.inplanes, planes, stride, downsample))
         self.inplanes = planes * block.expansion
@@ -57,7 +63,6 @@ def _make_layer(self, block, planes, blocks, stride=1):
             layers.append(block(self.inplanes, planes))
         return nn.Sequential(*layers)
 
-
     def forward(self, x):
         # Encoder
         x = self.convbnrelu1(x)
@@ -70,11 +75,11 @@ def forward(self, x):
 
         # Decoder with Skip Connections
         d4 = self.decoder4(e4)
-        d4 += e3
+        d4 = d4 + e3
         d3 = self.decoder3(d4)
-        d3 += e2
+        d3 = d3 + e2
         d2 = self.decoder2(d3)
-        d2 += e1
+        d2 = d2 + e1
         d1 = self.decoder1(d2)
 
         # Final Classification
@@ -83,4 +88,3 @@ def forward(self, x):
         f3 = self.finalconv3(f2)
 
         return f3
-
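The decoder wiring in both models now scales channel counts by the backbone's expansion factor, so the skip additions (d4 + e3, d3 + e2, d2 + e1) stay channel-compatible with bottleneck encoders. A standalone arithmetic sketch of that bookkeeping (illustrative values, not repo code):

    filters, expansion = [64, 128, 256, 512], 4        # e.g. resnet50
    encoder_out = [f * expansion for f in filters]     # [256, 512, 1024, 2048]
    # decoder4 = linknetUp(filters[3]*expansion, filters[2]*expansion) maps
    # encoder4's 2048 channels down to 1024, matching encoder3's output:
    assert filters[3] * expansion == encoder_out[3]    # 2048 in
    assert filters[2] * expansion == encoder_out[2]    # 1024 out, so d4 + e3 lines up
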
diff --git a/ptsemseg/models/utils.py b/ptsemseg/models/utils.py
index 13ff26cc..5025855a 100644
--- a/ptsemseg/models/utils.py
+++ b/ptsemseg/models/utils.py
@@ -8,8 +8,8 @@ def __init__(self, in_channels, n_filters, k_size, stride, padding, bias=True):
         super(conv2DBatchNorm, self).__init__()
 
         self.cb_unit = nn.Sequential(nn.Conv2d(int(in_channels), int(n_filters), kernel_size=k_size,
-                                               padding=padding, stride=stride, bias=bias),
-                                     nn.BatchNorm2d(int(n_filters)),)
+                                               padding=padding, stride=stride, bias=bias), nn.BatchNorm2d(int(n_filters)),)
+
 
     def forward(self, inputs):
         outputs = self.cb_unit(inputs)
@@ -21,8 +21,8 @@ def __init__(self, in_channels, n_filters, k_size, stride, padding, bias=True):
         super(deconv2DBatchNorm, self).__init__()
 
         self.dcb_unit = nn.Sequential(nn.ConvTranspose2d(int(in_channels), int(n_filters), kernel_size=k_size,
-                                                         padding=padding, stride=stride, bias=bias),
-                                      nn.BatchNorm2d(int(n_filters)),)
+                                                         padding=padding, stride=stride, bias=bias),
+                                      nn.BatchNorm2d(int(n_filters)),)
 
     def forward(self, inputs):
         outputs = self.dcb_unit(inputs)
@@ -33,10 +33,9 @@ class conv2DBatchNormRelu(nn.Module):
     def __init__(self, in_channels, n_filters, k_size, stride, padding, bias=True):
         super(conv2DBatchNormRelu, self).__init__()
 
-        self.cbr_unit = nn.Sequential(nn.Conv2d(int(in_channels), int(n_filters), kernel_size=k_size,
-                                                padding=padding, stride=stride, bias=bias),
-                                      nn.BatchNorm2d(int(n_filters)),
-                                      nn.ReLU(inplace=True),)
+        self.cbr_unit = nn.Sequential(nn.Conv2d(int(in_channels), int(n_filters), kernel_size=k_size, padding=padding,
+                                                stride=stride, bias=bias), nn.BatchNorm2d(int(n_filters)),
+                                      nn.ReLU(inplace=True),)
 
     def forward(self, inputs):
         outputs = self.cbr_unit(inputs)
@@ -201,9 +200,9 @@ class residualBottleneck(nn.Module):
     def __init__(self, in_channels, n_filters, stride=1, downsample=None):
         super(residualBottleneck, self).__init__()
-        self.convbn1 = nn.Conv2DBatchNorm(in_channels, n_filters, k_size=1, bias=False)
-        self.convbn2 = nn.Conv2DBatchNorm(n_filters, n_filters, k_size=3, padding=1, stride=stride, bias=False)
-        self.convbn3 = nn.Conv2DBatchNorm(n_filters, n_filters * 4, k_size=1, bias=False)
+        self.convbn1 = conv2DBatchNorm(in_channels, n_filters, k_size=1, stride=1, padding=0, bias=False)
+        self.convbn2 = conv2DBatchNorm(n_filters, n_filters, k_size=3, stride=stride, padding=1, bias=False)
+        self.convbn3 = conv2DBatchNorm(n_filters, n_filters * self.expansion, k_size=1, stride=1, padding=0, bias=False)
         self.relu = nn.ReLU(inplace=True)
         self.downsample = downsample
         self.stride = stride
@@ -229,13 +228,14 @@ def __init__(self, in_channels, n_filters):
         super(linknetUp, self).__init__()
 
         # B, 2C, H, W -> B, C/2, H, W
-        self.convbnrelu1 = conv2DBatchNormRelu(in_channels, n_filters/2, k_size=1, stride=1, padding=1)
-        # B, C/2, H, W -> B, C/2, H, W
-        self.deconvbnrelu2 = nn.deconv2DBatchNormRelu(n_filters/2, n_filters/2, k_size=3, stride=2, padding=0,)
+        self.convbnrelu1 = conv2DBatchNormRelu(in_channels, n_filters/2, k_size=1, stride=1, padding=0)
+
+        # B, C/2, H, W -> B, C/2, 2H, 2W
+        self.deconvbnrelu2 = deconv2DBatchNormRelu(n_filters/2, n_filters/2, k_size=2, stride=2, padding=0)
 
-        # B, C/2, H, W -> B, C, H, W
-        self.convbnrelu3 = conv2DBatchNormRelu(n_filters/2, n_filters, k_size=1, stride=1, padding=1)
+        # B, C/2, 2H, 2W -> B, C, 2H, 2W
+        self.convbnrelu3 = conv2DBatchNormRelu(n_filters/2, n_filters, k_size=1, stride=1, padding=0)
 
     def forward(self, x):
         x = self.convbnrelu1(x)
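The updated linknetUp comments claim B, C/2, H, W -> B, C/2, 2H, 2W for the new k_size=2, stride=2, padding=0 deconvolution. A quick standalone shape check with plain torch layers (mirroring those hyper-parameters, not the repo's wrapper classes; sizes are illustrative):

    import torch
    import torch.nn as nn

    in_channels, n_filters = 256, 128  # one linknetUp stage
    up = nn.Sequential(
        nn.Conv2d(in_channels, n_filters // 2, kernel_size=1, stride=1, padding=0),              # 256 -> 64, H, W kept
        nn.ConvTranspose2d(n_filters // 2, n_filters // 2, kernel_size=2, stride=2, padding=0),  # (H-1)*2 + 2 = 2H
        nn.Conv2d(n_filters // 2, n_filters, kernel_size=1, stride=1, padding=0),                # 64 -> 128, 2H, 2W kept
    )
    print(up(torch.randn(1, 256, 16, 16)).shape)  # torch.Size([1, 128, 32, 32])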