-
Notifications
You must be signed in to change notification settings - Fork 101
/
Copy pathON_LSTM.py
184 lines (142 loc) · 6.12 KB
/
ON_LSTM.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
import torch.nn.functional as F
import torch.nn as nn
import torch
from locked_dropout import LockedDropout
class LayerNorm(nn.Module):
    """Normalise the last dimension, then apply a learned scale and shift.

    Args:
        features: Size of the last dimension; ``gamma`` (initialised to ones)
            and ``beta`` (initialised to zeros) each hold this many values.
        eps: Constant added to the standard deviation before dividing.
    """

    def __init__(self, features, eps=1e-6):
        super().__init__()
        # gamma is registered before beta so parameter order stays stable.
        self.gamma = nn.Parameter(torch.ones(features))
        self.beta = nn.Parameter(torch.zeros(features))
        self.eps = eps

    def forward(self, x):
        """Return ``gamma * (x - mean) / (std + eps) + beta`` over the last dim."""
        centered = x - x.mean(-1, keepdim=True)
        # eps is added to the standard deviation itself (not the variance).
        denom = x.std(-1, keepdim=True) + self.eps
        return self.gamma * centered / denom + self.beta
class LinearDropConnect(nn.Linear):
    """``nn.Linear`` whose weight entries can be randomly zeroed in training.

    A masked copy of the weight is stored in ``self._weight`` by
    :meth:`sample_mask` and reused by every training-mode forward call until
    the mask is sampled again. In eval mode the full weight is used, scaled
    by ``1 - dropout``.

    Args:
        in_features: Input feature size, forwarded to ``nn.Linear``.
        out_features: Output feature size, forwarded to ``nn.Linear``.
        bias: Whether the layer has a bias, forwarded to ``nn.Linear``.
        dropout: Probability with which each weight entry is zeroed.
    """

    def __init__(self, in_features, out_features, bias=True, dropout=0.):
        super().__init__(
            in_features=in_features,
            out_features=out_features,
            bias=bias,
        )
        self.dropout = dropout

    def sample_mask(self):
        """Draw a fresh mask and cache the masked weight in ``self._weight``."""
        if self.dropout == 0.:
            # Nothing to drop: ``_weight`` simply aliases the real parameter.
            self._weight = self.weight
            return
        # Sample in the weight's own dtype, then convert to bool:
        # ``masked_fill`` expects a boolean mask (uint8 masks are deprecated
        # and rejected by recent PyTorch releases).
        mask = (
            self.weight.new_empty(self.weight.size())
            .bernoulli_(self.dropout)
            .bool()
        )
        self._weight = self.weight.masked_fill(mask, 0.)

    def forward(self, input, sample_mask=False):
        """Apply the linear map.

        Args:
            input: Tensor passed to ``F.linear``.
            sample_mask: If true (training mode only), resample the mask
                before applying the layer.

        Returns:
            ``F.linear`` of ``input`` with the masked weight in training
            mode, or with ``weight * (1 - dropout)`` in eval mode.
        """
        if self.training:
            # Sample lazily on first use so a training-mode call made before
            # any explicit ``sample_mask()`` does not fail on a missing
            # ``_weight`` attribute.
            if sample_mask or getattr(self, "_weight", None) is None:
                self.sample_mask()
            return F.linear(input, self._weight, self.bias)
        return F.linear(input, self.weight * (1 - self.dropout), self.bias)
def cumsoftmax(x, dim=-1):
    """Return the cumulative sum of ``softmax(x)`` taken along ``dim``."""
    probabilities = F.softmax(x, dim=dim)
    return probabilities.cumsum(dim=dim)
class ONLSTMCell(nn.Module):
    """One time step of an ON-LSTM layer with chunked master gates.

    The hidden units are split into ``n_chunk = hidden_size // chunk_size``
    chunks. Besides the four per-unit LSTM gates, the cell computes two
    per-chunk "master" gates via :func:`cumsoftmax`; these decide, chunk by
    chunk, how the per-unit input and forget gates are blended.

    Args:
        input_size: Feature size of the input fed to ``self.ih``.
        hidden_size: Number of hidden units; must be a multiple of
            ``chunk_size``.
        chunk_size: Number of hidden units per chunk.
        dropconnect: Drop probability handed to the hidden-to-hidden
            ``LinearDropConnect`` layer.

    Raises:
        ValueError: If ``chunk_size`` is not positive or does not divide
            ``hidden_size``.
    """

    def __init__(self, input_size, hidden_size, chunk_size, dropconnect=0.):
        # Fail early: a non-divisible size would otherwise only surface as a
        # mis-shaped ``view`` inside ``forward``.
        if chunk_size <= 0 or hidden_size % chunk_size != 0:
            raise ValueError(
                "hidden_size ({}) must be a positive multiple of "
                "chunk_size ({})".format(hidden_size, chunk_size)
            )
        super().__init__()
        self.input_size = input_size
        self.hidden_size = hidden_size
        self.chunk_size = chunk_size
        self.n_chunk = hidden_size // chunk_size

        # 4 per-unit gates plus 2 per-chunk master gates.
        gate_width = 4 * hidden_size + 2 * self.n_chunk
        self.ih = nn.Sequential(
            nn.Linear(input_size, gate_width, bias=True),
        )
        self.hh = LinearDropConnect(
            hidden_size, gate_width, bias=True, dropout=dropconnect
        )
        # Modules whose drop mask is refreshed by ``sample_masks``.
        self.drop_weight_modules = [self.hh]

    def forward(self, input, hidden, transformed_input=None):
        """Advance the cell by one step.

        Args:
            input: Input for this step; only used when ``transformed_input``
                is ``None``.
            hidden: Pair ``(hx, cx)``; ``hx`` is fed to ``self.hh`` and ``cx``
                is combined with gates shaped
                ``(batch, n_chunk, chunk_size)``.
            transformed_input: Optional precomputed ``self.ih(input)``.

        Returns:
            ``(hy, cy, (distance_cforget, distance_cin))`` where ``hy`` is
            reshaped to ``(-1, hidden_size)``, ``cy`` keeps the chunked
            layout, and the two distances are per-row summaries of the
            master forget and input gates.
        """
        hx, cx = hidden
        if transformed_input is None:
            transformed_input = self.ih(input)
        gates = transformed_input + self.hh(hx)

        # First 2 * n_chunk columns: master gates; the rest: per-unit gates.
        n_master = self.n_chunk * 2
        cingate, cforgetgate = gates[:, :n_master].chunk(2, 1)
        outgate, cell, ingate, forgetgate = (
            gates[:, n_master:]
            .view(-1, self.n_chunk * 4, self.chunk_size)
            .chunk(4, 1)
        )

        cingate = 1. - cumsoftmax(cingate)
        cforgetgate = cumsoftmax(cforgetgate)

        distance_cforget = 1. - cforgetgate.sum(dim=-1) / self.n_chunk
        distance_cin = cingate.sum(dim=-1) / self.n_chunk

        # Add a trailing axis so master gates broadcast over chunk_size.
        cingate = cingate[:, :, None]
        cforgetgate = cforgetgate[:, :, None]

        ingate = torch.sigmoid(ingate)
        forgetgate = torch.sigmoid(forgetgate)
        cell = torch.tanh(cell)
        outgate = torch.sigmoid(outgate)

        # Where both master gates are active (overlap) the per-unit gates
        # apply; outside the overlap the master gate alone decides.
        overlap = cforgetgate * cingate
        forgetgate = forgetgate * overlap + (cforgetgate - overlap)
        ingate = ingate * overlap + (cingate - overlap)
        cy = forgetgate * cx + ingate * cell

        hy = outgate * torch.tanh(cy)
        return hy.view(-1, self.hidden_size), cy, (distance_cforget, distance_cin)

    def init_hidden(self, bsz):
        """Return zero ``(hx, cx)`` for batch size ``bsz``.

        ``hx`` has shape ``(bsz, hidden_size)`` and ``cx`` has shape
        ``(bsz, n_chunk, chunk_size)``; both take dtype and device from the
        module's first parameter.
        """
        weight = next(self.parameters())
        return (
            weight.new_zeros(bsz, self.hidden_size),
            weight.new_zeros(bsz, self.n_chunk, self.chunk_size),
        )

    def sample_masks(self):
        """Resample the drop mask of every module in ``drop_weight_modules``."""
        for module in self.drop_weight_modules:
            module.sample_mask()
class ONLSTMStack(nn.Module):
    """Stack of ``ONLSTMCell`` layers unrolled over a whole sequence.

    Args:
        layer_sizes: Sequence of sizes; consecutive pairs give each cell's
            input and hidden size, so ``len(layer_sizes) - 1`` cells are
            created.
        chunk_size: Chunk size shared by all cells.
        dropout: Rate passed to ``LockedDropout`` between layers.
        dropconnect: Drop-connect rate passed to every cell.
    """

    def __init__(self, layer_sizes, chunk_size, dropout=0., dropconnect=0.):
        super().__init__()
        layer_cells = [
            ONLSTMCell(in_size, out_size, chunk_size, dropconnect=dropconnect)
            for in_size, out_size in zip(layer_sizes[:-1], layer_sizes[1:])
        ]
        self.cells = nn.ModuleList(layer_cells)
        self.lockdrop = LockedDropout()
        self.dropout = dropout
        self.sizes = layer_sizes

    def init_hidden(self, bsz):
        """Return one initial ``(hx, cx)`` state per cell for batch size ``bsz``."""
        states = []
        for cell in self.cells:
            states.append(cell.init_hidden(bsz))
        return states

    def forward(self, input, hidden):
        """Run every layer over the full sequence.

        Args:
            input: 3-D tensor indexed by time step along its first dimension.
            hidden: Per-layer initial states, one ``(hx, cx)`` per cell.

        Returns:
            A 5-tuple ``(output, states, raw_outputs, outputs, distances)``:
            the last layer's output, the final state of each layer, each
            layer's output before and after inter-layer dropout, and a pair
            of stacked per-layer forget/input distance tensors.
        """
        seq_len, _, _ = input.size()

        # One drop-connect mask per forward pass, shared by all time steps.
        if self.training:
            for cell in self.cells:
                cell.sample_masks()

        states = list(hidden)
        layer_input = input
        raw_outputs, outputs = [], []
        forget_distances, in_distances = [], []
        last_layer = len(self.cells) - 1

        for layer_idx, cell in enumerate(self.cells):
            # Input projection for all time steps at once.
            projected = cell.ih(layer_input)
            step_hidden, step_forget, step_in = [], [], []
            for t in range(seq_len):
                h_t, c_t, (d_forget, d_in) = cell(
                    None, states[layer_idx],
                    transformed_input=projected[t],
                )
                # Carry the state forward; after the loop this holds the
                # final state of the layer.
                states[layer_idx] = (h_t, c_t)
                step_hidden.append(h_t)
                step_forget.append(d_forget)
                step_in.append(d_in)

            layer_output = torch.stack(step_hidden)
            layer_forget = torch.stack(step_forget)
            layer_in = torch.stack(step_in)

            raw_outputs.append(layer_output)
            # Dropout is applied between layers only, never after the last.
            if layer_idx < last_layer:
                layer_output = self.lockdrop(layer_output, self.dropout)
            outputs.append(layer_output)
            forget_distances.append(layer_forget)
            in_distances.append(layer_in)
            layer_input = layer_output

        distances = (torch.stack(forget_distances), torch.stack(in_distances))
        return layer_input, states, raw_outputs, outputs, distances
if __name__ == "__main__":
    # Smoke test: push one random batch through a two-layer stack and print
    # the final per-layer states.
    sample = torch.Tensor(10, 10, 10)
    sample.data.normal_()
    model = ONLSTMStack([10, 10, 10], chunk_size=10)
    initial_state = model.init_hidden(10)
    result = model(sample, initial_state)
    print(result[1])