bin_codes.py (forked from YunseokJANG/l2l-da)
""" Holds activation layers """
from __future__ import print_function
import torch.nn as nn
class SequentialSelector(nn.Module):
""" NN that keeps track of ReluActivation statuses """
def __init__(self, original_net, selector_class=nn.modules.activation.ReLU):
""" Takes in an input neural net and will keep track (on forward) of
the status of all the ReLUs.
Needs to only be composed of linear layers and ReLus
"""
super(SequentialSelector, self).__init__()
self._safety_dance(original_net, selector_class)
self.original_net = original_net
self.selector_class = selector_class
self.new_seq = self._flatten(original_net)
    def _safety_dance(self, original_net, selector_class):
        """ Asserts that all layers are linear or ReLUs """
        assert selector_class == nn.modules.activation.ReLU
        assert isinstance(original_net, nn.Sequential)
        for el in original_net:
            if isinstance(el, nn.Sequential):
                self._safety_dance(el, selector_class)
            else:
                assert isinstance(el, (nn.Linear, nn.ReLU))

    def _flatten(self, original_net):
        """ Recursively flattens nested nn.Sequential containers into a single
            flat nn.Sequential whose modules are named by their index.
        """
        def _inner_loop(net, output_list):
            for el in net:
                if isinstance(el, nn.Sequential):
                    output_list = _inner_loop(el, output_list)
                else:
                    output_list.append(el)
            return output_list

        flat_list = _inner_loop(original_net, [])
        new_seq = nn.Sequential()
        for i, el in enumerate(flat_list):
            new_seq.add_module(str(i), el)
        return new_seq
    def forward_activations(self, x):
        """ Runs a forward pass, collecting the output of every selector
            (ReLU) layer along the way. Returns (list of per-layer
            activations, final network output).
        """
        intermed = x
        bincodes = []
        for layer in self.new_seq:
            intermed = layer(intermed)
            if isinstance(layer, self.selector_class):
                bincodes.append(intermed.clone())
        return bincodes, intermed

    def forward(self, x):
        return self.forward_activations(x)[1]
    def bincodes(self, x, binary=True):
        """ Stacks the flattened per-layer ReLU activations; if binary is
            True, returns the on/off pattern (activation > 0) instead of the
            raw values.
        """
        activations = torch.stack([_.view(-1) for _ in
                                   self.forward_activations(x)[0]])
        if binary:
            return activations > 0
        else:
            return activations