# l0_layers.py
import torch
import math
import torch.nn.functional as F
from torch.nn.modules import Module
from torch.nn.parameter import Parameter
from torch.nn.modules.utils import _pair as pair
from torch.autograd import Variable
from torch.nn import init
limit_a, limit_b, epsilon = -.1, 1.1, 1e-6


class L0Dense(Module):
    """Implementation of L0 regularization for the input units of a fully connected layer"""

    def __init__(self, in_features, out_features, bias=True, weight_decay=1., droprate_init=0.5, temperature=2./3.,
                 lamba=1., local_rep=False, **kwargs):
        """
        :param in_features: Input dimensionality
        :param out_features: Output dimensionality
        :param bias: Whether we use a bias
        :param weight_decay: Strength of the L2 penalty
        :param droprate_init: Dropout rate that the L0 gates will be initialized to
        :param temperature: Temperature of the concrete distribution
        :param lamba: Strength of the L0 penalty
        :param local_rep: Whether we will use a separate gate sample per element in the minibatch
        """
        super(L0Dense, self).__init__()
        self.in_features = in_features
        self.out_features = out_features
        self.prior_prec = weight_decay
        self.weights = Parameter(torch.Tensor(in_features, out_features))
        self.qz_loga = Parameter(torch.Tensor(in_features))
        self.temperature = temperature
        self.droprate_init = droprate_init if droprate_init != 0. else 0.5
        self.lamba = lamba
        self.use_bias = False
        self.local_rep = local_rep
        if bias:
            self.bias = Parameter(torch.Tensor(out_features))
            self.use_bias = True
        self.floatTensor = torch.FloatTensor if not torch.cuda.is_available() else torch.cuda.FloatTensor
        self.reset_parameters()
        print(self)

    def reset_parameters(self):
        init.kaiming_normal(self.weights, mode='fan_out')
        self.qz_loga.data.normal_(math.log(1 - self.droprate_init) - math.log(self.droprate_init), 1e-2)
        if self.use_bias:
            self.bias.data.fill_(0)

    def constrain_parameters(self, **kwargs):
        self.qz_loga.data.clamp_(min=math.log(1e-2), max=math.log(1e2))

    def cdf_qz(self, x):
        """Implements the CDF of the 'stretched' concrete distribution"""
        xn = (x - limit_a) / (limit_b - limit_a)
        logits = math.log(xn) - math.log(1 - xn)
        return F.sigmoid(logits * self.temperature - self.qz_loga).clamp(min=epsilon, max=1 - epsilon)
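
    # Note: cdf_qz(0) is Q(s <= 0) under the stretched concrete distribution, so
    # 1 - cdf_qz(0) is the probability that a gate ends up non-zero after the hard
    # rectification. _reg_w() and count_expected_flops_and_l0() below rely on this
    # quantity as the expected number of active input units.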

    def quantile_concrete(self, x):
        """Implements the quantile, aka inverse CDF, of the 'stretched' concrete distribution"""
        y = F.sigmoid((torch.log(x) - torch.log(1 - x) + self.qz_loga) / self.temperature)
        return y * (limit_b - limit_a) + limit_a

    def _reg_w(self):
        """Expected L0 norm under the stochastic gates, takes into account and re-weights also a potential L2 penalty"""
        logpw_col = torch.sum(- (.5 * self.prior_prec * self.weights.pow(2)) - self.lamba, 1)
        logpw = torch.sum((1 - self.cdf_qz(0)) * logpw_col)
        logpb = 0 if not self.use_bias else - torch.sum(.5 * self.prior_prec * self.bias.pow(2))
        return logpw + logpb
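
    # The value returned by _reg_w() is a negative expected penalty: each input unit j
    # contributes -(lamba * out_features + 0.5 * prior_prec * ||W_j||^2), weighted by
    # the probability 1 - cdf_qz(0) that its gate is active. A training loop would
    # therefore subtract regularization() from (rather than add it to) the task loss.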

    def regularization(self):
        return self._reg_w()

    def count_expected_flops_and_l0(self):
        """Measures the expected floating point operations (FLOPs) and the expected L0 norm"""
        # dim_in multiplications and dim_in - 1 additions for each output neuron for the weights
        # + the bias addition for each neuron
        # total_flops = (2 * in_features - 1) * out_features + out_features
        ppos = torch.sum(1 - self.cdf_qz(0))
        expected_flops = (2 * ppos - 1) * self.out_features
        expected_l0 = ppos * self.out_features
        if self.use_bias:
            expected_flops += self.out_features
            expected_l0 += self.out_features
        return expected_flops.data[0], expected_l0.data[0]

    def get_eps(self, size):
        """Uniform random numbers for the concrete distribution"""
        eps = self.floatTensor(size).uniform_(epsilon, 1 - epsilon)
        eps = Variable(eps)
        return eps

    def sample_z(self, batch_size, sample=True):
        """Sample the hard-concrete gates for training and use a deterministic value for testing"""
        if sample:
            eps = self.get_eps(self.floatTensor(batch_size, self.in_features))
            z = self.quantile_concrete(eps)
            return F.hardtanh(z, min_val=0, max_val=1)
        else:  # mode
            pi = F.sigmoid(self.qz_loga).view(1, self.in_features).expand(batch_size, self.in_features)
            return F.hardtanh(pi * (limit_b - limit_a) + limit_a, min_val=0, max_val=1)

    def sample_weights(self):
        z = self.quantile_concrete(self.get_eps(self.floatTensor(self.in_features)))
        mask = F.hardtanh(z, min_val=0, max_val=1)
        return mask.view(self.in_features, 1) * self.weights
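
    # forward() supports two paths. With local_rep=True (or at evaluation time) the
    # gates multiply the input activations: sampled per example during training, or
    # replaced by their deterministic "mode" value at test time. The default training
    # path instead draws one gate sample and masks the weight matrix via sample_weights().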

    def forward(self, input):
        if self.local_rep or not self.training:
            z = self.sample_z(input.size(0), sample=self.training)
            xin = input.mul(z)
            output = xin.mm(self.weights)
        else:
            weights = self.sample_weights()
            output = input.mm(weights)
        if self.use_bias:
            output.add_(self.bias)
        return output

    def __repr__(self):
        s = ('{name}({in_features} -> {out_features}, droprate_init={droprate_init}, '
             'lamba={lamba}, temperature={temperature}, weight_decay={prior_prec}, '
             'local_rep={local_rep}')
        if not self.use_bias:
            s += ', bias=False'
        s += ')'
        return s.format(name=self.__class__.__name__, **self.__dict__)
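

# A minimal usage sketch, not part of the original module: it assumes the legacy PyTorch
# API used throughout this file (F.sigmoid, Variable, .data[0]) and an arbitrary penalty
# scale of 1e-3, and only illustrates how the layer's regularization() term is folded
# into a loss.
def _l0dense_usage_example():
    layer = L0Dense(in_features=784, out_features=300, droprate_init=0.2, lamba=0.1)
    x = Variable(torch.randn(32, 784))           # batch of 32 flattened inputs
    out = layer(x)                               # forward pass with sampled gates
    data_loss = out.pow(2).mean()                # placeholder task loss
    # regularization() is negative, so subtracting it adds the expected L0/L2 penalty
    total_loss = data_loss - 1e-3 * layer.regularization()
    return total_loss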


class L0Conv2d(Module):
    """Implementation of L0 regularization for the feature maps of a convolutional layer"""

    def __init__(self, in_channels, out_channels, kernel_size, stride=1, padding=0, dilation=1, groups=1, bias=True,
                 droprate_init=0.5, temperature=2./3., weight_decay=1., lamba=1., local_rep=False, **kwargs):
        """
        :param in_channels: Number of input channels
        :param out_channels: Number of output channels
        :param kernel_size: Size of the kernel
        :param stride: Stride for the convolution
        :param padding: Padding for the convolution
        :param dilation: Dilation factor for the convolution
        :param groups: How many groups we will assume in the convolution
        :param bias: Whether we will use a bias
        :param droprate_init: Dropout rate that the L0 gates will be initialized to
        :param temperature: Temperature of the concrete distribution
        :param weight_decay: Strength of the L2 penalty
        :param lamba: Strength of the L0 penalty
        :param local_rep: Whether we will use a separate gate sample per element in the minibatch
        """
        super(L0Conv2d, self).__init__()
        if in_channels % groups != 0:
            raise ValueError('in_channels must be divisible by groups')
        if out_channels % groups != 0:
            raise ValueError('out_channels must be divisible by groups')
        self.in_channels = in_channels
        self.out_channels = out_channels
        self.kernel_size = pair(kernel_size)
        self.stride = pair(stride)
        self.padding = pair(padding)
        self.dilation = pair(dilation)
        self.output_padding = pair(0)
        self.groups = groups
        self.prior_prec = weight_decay
        self.lamba = lamba
        self.droprate_init = droprate_init if droprate_init != 0. else 0.5
        self.temperature = temperature
        self.floatTensor = torch.FloatTensor if not torch.cuda.is_available() else torch.cuda.FloatTensor
        self.use_bias = False
        self.weights = Parameter(torch.Tensor(out_channels, in_channels // groups, *self.kernel_size))
        self.qz_loga = Parameter(torch.Tensor(out_channels))
        self.dim_z = out_channels
        self.input_shape = None
        self.local_rep = local_rep
        if bias:
            self.bias = Parameter(torch.Tensor(out_channels))
            self.use_bias = True
        self.reset_parameters()
        print(self)

    def reset_parameters(self):
        init.kaiming_normal(self.weights, mode='fan_in')
        self.qz_loga.data.normal_(math.log(1 - self.droprate_init) - math.log(self.droprate_init), 1e-2)
        if self.use_bias:
            self.bias.data.fill_(0)

    def constrain_parameters(self, **kwargs):
        self.qz_loga.data.clamp_(min=math.log(1e-2), max=math.log(1e2))

    def cdf_qz(self, x):
        """Implements the CDF of the 'stretched' concrete distribution"""
        xn = (x - limit_a) / (limit_b - limit_a)
        logits = math.log(xn) - math.log(1 - xn)
        return F.sigmoid(logits * self.temperature - self.qz_loga).clamp(min=epsilon, max=1 - epsilon)

    def quantile_concrete(self, x):
        """Implements the quantile, aka inverse CDF, of the 'stretched' concrete distribution"""
        y = F.sigmoid((torch.log(x) - torch.log(1 - x) + self.qz_loga) / self.temperature)
        return y * (limit_b - limit_a) + limit_a

    def _reg_w(self):
        """Expected L0 norm under the stochastic gates, takes into account and re-weights also a potential L2 penalty"""
        q0 = self.cdf_qz(0)
        logpw_col = torch.sum(- (.5 * self.prior_prec * self.weights.pow(2)) - self.lamba, 3).sum(2).sum(1)
        logpw = torch.sum((1 - q0) * logpw_col)
        logpb = 0 if not self.use_bias else - torch.sum((1 - q0) * (.5 * self.prior_prec * self.bias.pow(2) -
                                                                    self.lamba))
        return logpw + logpb

    def regularization(self):
        return self._reg_w()

    def count_expected_flops_and_l0(self):
        """Measures the expected floating point operations (FLOPs) and the expected L0 norm"""
        ppos = torch.sum(1 - self.cdf_qz(0))
        n = self.kernel_size[0] * self.kernel_size[1] * self.in_channels  # vector_length
        flops_per_instance = n + (n - 1)  # (n: multiplications and n-1: additions)
        # input_shape is (N, C, H, W), so indices 2 and 3 are the spatial dimensions
        num_instances_per_filter = ((self.input_shape[2] - self.kernel_size[0] + 2 * self.padding[0]) / self.stride[0]) + 1  # for rows
        num_instances_per_filter *= ((self.input_shape[3] - self.kernel_size[1] + 2 * self.padding[1]) / self.stride[1]) + 1  # multiplying with cols
        flops_per_filter = num_instances_per_filter * flops_per_instance
        expected_flops = flops_per_filter * ppos  # multiply with number of filters
        expected_l0 = n * ppos
        if self.use_bias:
            # since the gate is applied to the output we also reduce the bias computation
            expected_flops += num_instances_per_filter * ppos
            expected_l0 += ppos
        return expected_flops.data[0], expected_l0.data[0]
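
    # Illustrative arithmetic for the estimate above (hypothetical numbers, not from the
    # original code): a 3x3 kernel over 16 input channels gives n = 144 and
    # flops_per_instance = 287; a 32x32 input with stride 1 and padding 1 gives
    # num_instances_per_filter = 32 * 32 = 1024, so each active filter costs
    # 287 * 1024 = 293,888 expected FLOPs, scaled by ppos, the expected number of
    # active output channels.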

    def get_eps(self, size):
        """Uniform random numbers for the concrete distribution"""
        eps = self.floatTensor(size).uniform_(epsilon, 1 - epsilon)
        eps = Variable(eps)
        return eps

    def sample_z(self, batch_size, sample=True):
        """Sample the hard-concrete gates for training and use a deterministic value for testing"""
        if sample:
            eps = self.get_eps(self.floatTensor(batch_size, self.dim_z))
            z = self.quantile_concrete(eps).view(batch_size, self.dim_z, 1, 1)
            return F.hardtanh(z, min_val=0, max_val=1)
        else:  # mode
            pi = F.sigmoid(self.qz_loga).view(1, self.dim_z, 1, 1)
            return F.hardtanh(pi * (limit_b - limit_a) + limit_a, min_val=0, max_val=1)

    def sample_weights(self):
        z = self.quantile_concrete(self.get_eps(self.floatTensor(self.dim_z))).view(self.dim_z, 1, 1, 1)
        return F.hardtanh(z, min_val=0, max_val=1) * self.weights

    def forward(self, input_):
        if self.input_shape is None:
            self.input_shape = input_.size()
        b = None if not self.use_bias else self.bias
        if self.local_rep or not self.training:
            output = F.conv2d(input_, self.weights, b, self.stride, self.padding, self.dilation, self.groups)
            z = self.sample_z(output.size(0), sample=self.training)
            return output.mul(z)
        else:
            weights = self.sample_weights()
            output = F.conv2d(input_, weights, None, self.stride, self.padding, self.dilation, self.groups)
            return output

    def __repr__(self):
        s = ('{name}({in_channels}, {out_channels}, kernel_size={kernel_size}, stride={stride}, '
             'droprate_init={droprate_init}, temperature={temperature}, prior_prec={prior_prec}, '
             'lamba={lamba}, local_rep={local_rep}')
        if self.padding != (0,) * len(self.padding):
            s += ', padding={padding}'
        if self.dilation != (1,) * len(self.dilation):
            s += ', dilation={dilation}'
        if self.output_padding != (0,) * len(self.output_padding):
            s += ', output_padding={output_padding}'
        if self.groups != 1:
            s += ', groups={groups}'
        if not self.use_bias:
            s += ', bias=False'
        s += ')'
        return s.format(name=self.__class__.__name__, **self.__dict__)
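

# Minimal smoke test, not part of the original module: it assumes the legacy PyTorch API
# used above (Variable, .data[0]) and arbitrary layer sizes, and simply exercises both
# layers together with their expected-sparsity bookkeeping.
if __name__ == '__main__':
    conv = L0Conv2d(3, 8, kernel_size=3, padding=1, droprate_init=0.3, lamba=0.1)
    dense = L0Dense(8 * 32 * 32, 10, droprate_init=0.5, lamba=0.1)
    x = Variable(torch.randn(4, 3, 32, 32))
    h = conv(x)                  # gated convolution, spatial size preserved by padding=1
    y = dense(h.view(4, -1))     # gated fully connected output
    total_reg = conv.regularization() + dense.regularization()
    exp_flops, exp_l0 = conv.count_expected_flops_and_l0()
    print(y.size(), exp_flops, exp_l0)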