forked from karpathy/char-rnn
-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathGRU.lua
63 lines (55 loc) · 2.01 KB
/
GRU.lua
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
local GRU = {}

--[[
Creates one timestep of an n-layer stacked GRU as an nngraph gModule.
Paper reference: http://arxiv.org/pdf/1412.3555v1.pdf

For each layer, with x the layer input and h that layer's previous hidden state:
  z  = sigmoid(W_z x + U_z h)      -- update gate
  r  = sigmoid(W_r x + U_r h)      -- reset gate
  h~ = tanh(W x + U (r .* h))      -- candidate hidden state
  h' = z .* h~ + (1 - z) .* h      -- interpolated new hidden state

Parameters:
  input_size  width of the OneHot encoding applied to the first graph input,
              and number of classes produced by the LogSoftMax decoder
  rnn_size    number of hidden units in every layer
  n           number of stacked layers (must be >= 1)
  dropout     probability passed to nn.Dropout on the input of layers 2..n
              and on the top hidden state before the decoder; defaults to 0
              (no dropout modules are added)

Returns an nn.gModule whose
  inputs  are { x, prev_h(1), ..., prev_h(n) }
  outputs are { next_h(1), ..., next_h(n), log_probs }

NOTE(review): `OneHot` is not defined in this file; it must be available as a
global when GRU.gru is called (char-rnn provides it in util/OneHot.lua) -- confirm.
]]--
function GRU.gru(input_size, rnn_size, n, dropout)
  dropout = dropout or 0
  -- Fail early with a clear message: with no layers the decoder below would
  -- be wired to outputs[0], which is nil.
  assert(type(n) == 'number' and n >= 1,
         'GRU.gru: n (number of layers) must be a number >= 1')

  -- there are n+1 inputs (hiddens on each layer and x)
  local inputs = {}
  table.insert(inputs, nn.Identity()()) -- x
  for L = 1, n do
    table.insert(inputs, nn.Identity()()) -- prev_h[L]
  end

  -- Builds W*xv + U*hv with freshly constructed (unshared) Linear modules on
  -- every call. Declared local so it does not leak into the global environment.
  local function new_input_sum(insize, xv, hv)
    local i2h = nn.Linear(insize, rnn_size)(xv)
    local h2h = nn.Linear(rnn_size, rnn_size)(hv)
    return nn.CAddTable()({i2h, h2h})
  end

  local outputs = {}
  for L = 1, n do
    local prev_h = inputs[L + 1]
    -- the input to this layer: one-hot x for the bottom layer, otherwise the
    -- hidden state of the layer below
    local x, input_size_L
    if L == 1 then
      x = OneHot(input_size)(inputs[1])
      input_size_L = input_size
    else
      x = outputs[L - 1]
      if dropout > 0 then x = nn.Dropout(dropout)(x) end -- apply dropout, if any
      input_size_L = rnn_size
    end

    -- GRU tick
    -- forward the update and reset gates
    local update_gate = nn.Sigmoid()(new_input_sum(input_size_L, x, prev_h))
    local reset_gate = nn.Sigmoid()(new_input_sum(input_size_L, x, prev_h))
    -- compute candidate hidden state from x and the reset-gated previous state
    local gated_hidden = nn.CMulTable()({reset_gate, prev_h})
    local p2 = nn.Linear(rnn_size, rnn_size)(gated_hidden)
    local p1 = nn.Linear(input_size_L, rnn_size)(x)
    local hidden_candidate = nn.Tanh()(nn.CAddTable()({p1, p2}))
    -- compute new interpolated hidden state, based on the update gate:
    -- next_h = z .* candidate + (1 - z) .* prev_h, where (1 - z) is built
    -- as AddConstant(1)(MulConstant(-1)(z))
    local zh = nn.CMulTable()({update_gate, hidden_candidate})
    local one_minus_z = nn.AddConstant(1, false)(nn.MulConstant(-1, false)(update_gate))
    local zhm1 = nn.CMulTable()({one_minus_z, prev_h})
    local next_h = nn.CAddTable()({zh, zhm1})

    table.insert(outputs, next_h)
  end

  -- set up the decoder: top hidden state -> Linear -> LogSoftMax over input_size classes
  local top_h = outputs[#outputs]
  if dropout > 0 then top_h = nn.Dropout(dropout)(top_h) end
  local proj = nn.Linear(rnn_size, input_size)(top_h)
  local logsoft = nn.LogSoftMax()(proj)
  table.insert(outputs, logsoft)

  return nn.gModule(inputs, outputs)
end
-- Module export: `require` of this file returns the GRU table (providing GRU.gru).
return GRU