-
Notifications
You must be signed in to change notification settings - Fork 22
/
Copy pathalbunet_v2.py
executable file
·175 lines (116 loc) · 5.03 KB
/
albunet_v2.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
from torch import nn
from torch.nn import functional as F
import torch
# from torchvision import models
import torchvision
import resnet_v2
# def conv3x3(in_, out):
# return nn.Conv2d(in_, out, 3, padding=1)
class conv3x3(nn.Module):
    """3x3 convolution with reflection padding.

    Adapted from https://github.com/jeffwen/road_building_extraction.
    The input is reflection-padded by one pixel on every side and then passed
    through an unpadded 3x3 convolution, so with ``stride=1`` the output has
    the same spatial size as the input. (The upstream author reports that
    reflection padding gave better results than zero padding.)

    Args:
        in_planes: number of input channels.
        out_planes: number of output channels.
        stride: stride of the convolution (default 1).
    """

    def __init__(self, in_planes, out_planes, stride=1):
        super().__init__()
        # (kernel_size - 1) // 2 == 1 for the fixed 3x3 kernel.
        reflect_pad = nn.ReflectionPad2d(padding=1)
        conv = nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride)
        # Keep the pad at index 0 and the conv at index 1 so parameter names
        # stay "conv.1.weight" / "conv.1.bias".
        self.conv = nn.Sequential(reflect_pad, conv)

    def forward(self, input):
        """Apply reflection padding followed by the 3x3 convolution."""
        return self.conv(input)
class ConvRelu(nn.Module):
    """A ``conv3x3`` layer followed by a PReLU activation.

    Args:
        in_: number of input channels.
        out: number of output channels.
    """

    def __init__(self, in_, out):
        super().__init__()
        self.conv = conv3x3(in_, out)
        # PReLU (default arguments) is used here instead of an in-place ReLU.
        self.activation = nn.PReLU()

    def forward(self, x):
        """Return ``activation(conv(x))``."""
        return self.activation(self.conv(x))
class DecoderBlockV2(nn.Module):
    """Decoder stage that doubles the spatial size of its input.

    Same layout as the ReLU-based decoder block, with ReLU changed to PReLU.

    Args:
        in_channels: channels of the incoming feature map.
        middle_channels: channels produced by the first ``ConvRelu``.
        out_channels: channels of the block output.
        is_deconv: if True, upsample with a stride-2 transposed convolution
            followed by PReLU; if False, upsample bilinearly first and then
            apply two ``ConvRelu`` layers.
    """

    def __init__(self, in_channels, middle_channels, out_channels, is_deconv=True):
        super().__init__()
        self.in_channels = in_channels

        if is_deconv:
            # Parameters for the deconvolution (kernel_size=4, stride=2) were
            # chosen to avoid checkerboard artifacts, following
            # https://distill.pub/2016/deconv-checkerboard/
            self.block = nn.Sequential(
                ConvRelu(in_channels, middle_channels),
                nn.ConvTranspose2d(
                    middle_channels,
                    out_channels,
                    kernel_size=4,
                    stride=2,
                    padding=1,
                ),
                nn.PReLU(),
            )
        else:
            self.block = nn.Sequential(
                # align_corners=False is what bilinear mode already uses when
                # the argument is omitted; spelling it out pins the behaviour
                # and avoids the UserWarning raised by older PyTorch releases.
                nn.Upsample(scale_factor=2, mode='bilinear', align_corners=False),
                ConvRelu(in_channels, middle_channels),
                ConvRelu(middle_channels, out_channels),
            )

    def forward(self, x):
        """Run ``x`` through the configured upsampling block."""
        return self.block(x)
class AlbuNet(nn.Module):
    """U-Net style segmentation network with a ResNet-34 encoder.

    UNet (https://arxiv.org/abs/1505.04597) with a Resnet34
    (https://arxiv.org/abs/1512.03385) encoder, as proposed by
    Alexander Buslaev: https://www.linkedin.com/in/al-buslaev/

    Slightly changed according to
    http://jeffwen.com/2018/02/23/road_extraction and
    https://github.com/jeffwen/road_building_extraction

    Original author's note: a pretrained model cannot be used here after
    these changes.
    """

    def __init__(self, num_classes=1, num_filters=32, pretrained=False, is_deconv=False):
        """Build the encoder, the decoder stages and the output head.

        :param num_classes: number of output channels; when greater than 1,
            ``forward`` returns log-softmax over the channel dimension.
        :param num_filters: base width used to size the decoder stages.
        :param pretrained: forwarded unchanged to ``resnet_v2.resnet34``.
            False - no pre-trained network is used
            True - encoder is pre-trained with resnet34
            NOTE(review): the class docstring says pretrained weights cannot
            be used after the modifications - confirm before passing True.
        :param is_deconv:
            False: bilinear interpolation is used in decoder
            True: deconvolution is used in decoder
        """
        super().__init__()
        self.num_classes = num_classes

        self.pool = nn.MaxPool2d(2, 2)
        self.encoder = resnet_v2.resnet34(pretrained=pretrained)

        # Not referenced by forward() or by self.conv1 (which uses
        # encoder.relu); kept so the module's registered parameters and
        # state_dict keys stay the same.
        self.relu = nn.PReLU()

        # Encoder stem followed by this module's own 2x2 max pool.
        stem = (self.encoder.conv1, self.encoder.bn1, self.encoder.relu, self.pool)
        self.conv1 = nn.Sequential(*stem)

        # The four residual stages of the encoder.
        self.conv2 = self.encoder.layer1
        self.conv3 = self.encoder.layer2
        self.conv4 = self.encoder.layer3
        self.conv5 = self.encoder.layer4

        # Decoder stages; each input width is skip channels + previous output.
        wide = num_filters * 8
        self.center = DecoderBlockV2(512, wide * 2, wide, is_deconv)
        self.dec5 = DecoderBlockV2(512 + wide, wide * 2, wide, is_deconv)
        self.dec4 = DecoderBlockV2(256 + wide, wide * 2, wide, is_deconv)
        self.dec3 = DecoderBlockV2(128 + wide, num_filters * 4 * 2, num_filters * 2, is_deconv)
        self.dec2 = DecoderBlockV2(
            64 + num_filters * 2, num_filters * 2 * 2, num_filters * 2 * 2, is_deconv
        )
        self.dec1 = DecoderBlockV2(num_filters * 2 * 2, num_filters * 2 * 2, num_filters, is_deconv)
        self.dec0 = ConvRelu(num_filters, num_filters)

        # 1x1 convolution mapping decoder features to per-class outputs.
        self.final = nn.Conv2d(num_filters, num_classes, kernel_size=1)

    def forward(self, x):
        """Run the encoder-decoder and return the per-pixel prediction.

        Returns the raw output of the final 1x1 convolution when
        ``num_classes`` is 1 (or less), otherwise its log-softmax over dim 1.
        """
        # Encoder path; every intermediate except the first is reused as a
        # skip connection below.
        enc1 = self.conv1(x)
        enc2 = self.conv2(enc1)
        enc3 = self.conv3(enc2)
        enc4 = self.conv4(enc3)
        enc5 = self.conv5(enc4)

        # Bottleneck: pool once more, then let the center block upsample back.
        up = self.center(self.pool(enc5))

        # Decoder path: concatenate with the matching encoder output along
        # the channel dimension before each stage.
        up = self.dec5(torch.cat([up, enc5], 1))
        up = self.dec4(torch.cat([up, enc4], 1))
        up = self.dec3(torch.cat([up, enc3], 1))
        up = self.dec2(torch.cat([up, enc2], 1))
        up = self.dec1(up)
        features = self.dec0(up)

        logits = self.final(features)
        if self.num_classes > 1:
            return F.log_softmax(logits, dim=1)
        return logits