-
Notifications
You must be signed in to change notification settings - Fork 0
/
model.py
164 lines (146 loc) · 5.73 KB
/
model.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
import torch
import torch.nn as nn
import torch.nn.functional as F
class DownBlock(nn.Module):
    def __init__(self, c_in, c_out):
        """
        2D ResBlock with instance norm that downsamples to half of
        the input dimensions.

        :param int c_in: Number of input channels
        :param int c_out: Number of output channels
        """
        super(DownBlock, self).__init__()
        if c_in != c_out:
            # 1x1 strided conv adapts both the channel count and the
            # spatial size so the skip tensor matches the main path output.
            self.skip = nn.Sequential(
                nn.Conv2d(in_channels=c_in, out_channels=c_out, kernel_size=1, stride=2),
                nn.InstanceNorm2d(num_features=c_out)
            )
        else:
            # The main path halves the spatial dims, so a plain identity
            # skip would not match shapes (this crashed in the original).
            # MaxPool2d(kernel_size=1, stride=2) is a parameter-free
            # subsample that yields the same ceil(h/2) x ceil(w/2) output
            # size as the stride-2 convolutions.
            self.skip = nn.MaxPool2d(kernel_size=1, stride=2)
        self.main_path = nn.Sequential(
            nn.Conv2d(in_channels=c_in, out_channels=c_out, kernel_size=3, stride=2, padding=1),
            nn.InstanceNorm2d(num_features=c_out),
            nn.LeakyReLU(),
            nn.Conv2d(in_channels=c_out, out_channels=c_out, kernel_size=3, padding=1),
            nn.InstanceNorm2d(num_features=c_out)
        )

    def forward(self, x):
        """
        :param x: (N, c_in, H, W) input tensor
        :returns: (N, c_out, ceil(H/2), ceil(W/2)) tensor
        """
        skip_path = self.skip(x)
        x = self.main_path(x)
        # residual connection, then the final non-linearity; out-of-place
        # add avoids mutating a tensor autograd may still need
        x = x + skip_path
        return F.leaky_relu(x)
class UpBlock(nn.Module):
    def __init__(self, c_in, c_out):
        """
        2D upsampling block: bilinear 2x upsample followed by a 3x3
        convolution and LeakyReLU.

        :param int c_in: Number of input channels (total, after the
            optional concatenations performed in forward)
        :param int c_out: Number of output channels
        """
        super(UpBlock, self).__init__()
        self.model = nn.Sequential(
            nn.Upsample(scale_factor=2, mode="bilinear", align_corners=True),
            nn.Conv2d(in_channels=c_in, out_channels=c_out, kernel_size=3, padding=1),
            nn.LeakyReLU(),
        )

    def forward(self, input, skip, palette):
        """
        :param input: (N, C, H, W) feature tensor from the previous stage
        :param skip: optional (N, C_s, H, W) encoder features, concatenated
            channel-wise in front of ``input``
        :param palette: optional (N, C_p, H, W) tiled palette tensor,
            concatenated channel-wise in front of everything else
        :returns: (N, c_out, 2H, 2W) tensor
        """
        x = input
        # torch.cat along dim=1 is equivalent to torch.hstack for 4D
        # tensors but states the channel dimension explicitly
        if skip is not None:
            x = torch.cat((skip, x), dim=1)
        if palette is not None:
            x = torch.cat((palette, x), dim=1)
        return self.model(x)
class FeatureEncoder(nn.Module):
    def __init__(self):
        """
        Downsampling path of the network. Encodes the input into a
        compressed representation while stashing each intermediate
        activation on the instance so a decoder can later read them
        as skip connections.
        """
        super(FeatureEncoder, self).__init__()
        # cached activations, refreshed on every forward pass
        self.input = None
        self.conv_out = None
        self.res_block_1_out = None
        self.res_block_2_out = None
        self.res_block_3_out = None
        # stem: 3 -> 64 channels, spatial dims halved
        self.conv = nn.Sequential(
            nn.Conv2d(in_channels=3, out_channels=64, kernel_size=3, padding=1, stride=2),
            nn.InstanceNorm2d(num_features=64),
            nn.LeakyReLU()
        )
        # three residual downsampling stages, each halving H and W
        self.res_block_1 = DownBlock(64, 128)
        self.res_block_2 = DownBlock(128, 256)
        self.res_block_3 = DownBlock(256, 512)

    def forward(self, x):
        """
        :param x: (N, 3, H, W) input tensor
        :returns: (N, 512, ~H/16, ~W/16) bottleneck features
        """
        self.input = x
        out = self.conv(x)
        self.conv_out = out
        out = self.res_block_1(out)
        self.res_block_1_out = out
        out = self.res_block_2(out)
        self.res_block_2_out = out
        out = self.res_block_3(out)
        self.res_block_3_out = out
        return out
class RecoloringDecoder(nn.Module):
    def __init__(self, encoder:FeatureEncoder, num_colors=6):
        """
        Upsampling path of the network that takes as inputs
        the target palette and the encoded features and produces
        a recolored output matching the target palette.

        :param FeatureEncoder encoder: encoder whose cached activations
            (input, conv_out, res_block_*_out) are read as skip
            connections — a forward pass through the encoder must
            precede every call to this decoder's forward
        :param int num_colors: number of palette colors; each
            contributes 3 channels to the tiled palette tensor
        """
        super(RecoloringDecoder, self).__init__()
        self.encoder = encoder
        self.num_color_chs = num_colors * 3
        self.up_block_1 = UpBlock(512 + self.num_color_chs, 256)
        self.up_block_2 = UpBlock(512, 128)
        self.up_block_3 = UpBlock(256 + self.num_color_chs, 64)
        self.up_block_4 = UpBlock(128 + self.num_color_chs, 32)
        # +1 input channel for the lightness channel appended in forward()
        self.conv = nn.Sequential(
            nn.Conv2d(in_channels=32 + 1, out_channels=2, kernel_size=3, padding=1),
            nn.Tanh()
        )

    def forward(self, input, palette):
        """
        :param input: (N, 512, h, w) bottleneck features from the encoder
        :param palette: 4D palette tensor; flattened per-sample to
            num_color_chs channels (assumes the product of its trailing
            three dims equals self.num_color_chs — TODO confirm against
            the caller)
        :returns: (N, 2, H, W) output in [-1, 1] (Tanh)
        """
        x = input
        h = x.shape[2]
        w = x.shape[3]
        # flatten the palette to a single "pixel" of num_color_chs channels...
        palette_pixel = palette.reshape(-1, palette.shape[1] * palette.shape[2] * palette.shape[3], 1, 1)
        # ...and tile it over the spatial dims of the current feature map
        palette = palette_pixel.repeat((1, 1, h, w))
        # invoke the blocks as modules (not .forward) so that any
        # registered nn.Module hooks still fire
        x = self.up_block_1(x, None, palette)
        x = self.up_block_2(x, self.encoder.res_block_2_out, None)
        # re-tile the palette to match the grown spatial dims
        palette = palette_pixel.repeat((1, 1, x.shape[2], x.shape[3]))
        x = self.up_block_3(x, self.encoder.res_block_1_out, palette)
        palette = palette_pixel.repeat((1, 1, x.shape[2], x.shape[3]))
        x = self.up_block_4(x, self.encoder.conv_out, palette)
        # append channel 0 of the original encoder input before the final
        # convolution (presumably LAB lightness — verify with the caller);
        # slicing with 0:1 keeps the channel dim without a reshape
        ll = self.encoder.input[:, 0:1, :, :]
        x = torch.cat((ll, x), dim=1)
        return self.conv(x)
class PaletteNet(nn.Module):
    def __init__(self):
        """
        The main model of the PaletteNet with a feature
        encoder and a recoloring decoder.
        """
        super(PaletteNet, self).__init__()
        self.encoder = FeatureEncoder()
        # decoder reads the encoder's cached activations as skips
        self.decoder = RecoloringDecoder(self.encoder)

    def forward(self, x, palette):
        """
        :param x: (N, 3, H, W) input image; channel 0 is fed through to
            the output head by the decoder (presumably LAB lightness —
            verify against the training pipeline)
        :param palette: target palette tensor, see RecoloringDecoder.forward
        :returns: (N, 2, H, W) recolored output channels in [-1, 1]
        """
        # call the submodules (not .forward) so registered hooks fire
        features = self.encoder(x)
        return self.decoder(features, palette)