forked from Element-Research/rnn
-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathSeqLSTMP.lua
57 lines (45 loc) · 1.89 KB
/
SeqLSTMP.lua
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
-- nn.SeqLSTMP: an nn.SeqLSTM with an additional output projection
-- matrix weightO (see SeqLSTMP:adapter / SeqLSTMP:gradAdapter).
local SeqLSTMP, parent = torch.class('nn.SeqLSTMP', 'nn.SeqLSTM')

-- Field names of the learnable tensors and of their gradients, listed in
-- matching order (presumably consumed by dpnn, judging by the field name).
SeqLSTMP.dpnn_parameters = {
   'weight',
   'bias',
   'weightO',
}
SeqLSTMP.dpnn_gradParameters = {
   'gradWeight',
   'gradBias',
   'gradWeightO',
}
--- Construct an LSTM whose hiddensize-dimensional output is projected to
-- outputsize dimensions through weightO (hiddensize x outputsize).
-- @param inputsize  size of each input vector
-- @param hiddensize size of the LSTM output before projection
-- @param outputsize size of the projected output
function SeqLSTMP:__init(inputsize, hiddensize, outputsize)
   assert(inputsize and hiddensize and outputsize, "Expecting input, hidden and output size")
   local H, R = hiddensize, outputsize

   -- Allocated before parent.__init, presumably because the parent
   -- constructor calls self:reset(), which fills weightO. TODO confirm.
   self.weightO = torch.Tensor(H, R)
   -- gradAdapter accumulates into gradWeightO with addmm, so start from
   -- zero instead of uninitialized tensor memory.
   self.gradWeightO = torch.Tensor(H, R):zero()

   parent.__init(self, inputsize, hiddensize, outputsize)
end
--- (Re)initialize the learnable parameters.
-- Zeroes the bias, then sets the second hiddensize-wide gate-bias chunk
-- (presumably the forget gate) to 1. Without std, weight is drawn from
-- N(0, 1/sqrt(hiddensize + inputsize)) and weightO from
-- N(0, 1/sqrt(outputsize + hiddensize)); with std, both use N(0, std).
-- @param std optional standard deviation for the normal initialization
-- @return self
function SeqLSTMP:reset(std)
   local H = self.hiddensize
   self.bias:zero()
   -- The gates act on the hiddensize-dimensional state (adapter copies
   -- next_h into an N x hiddensize buffer before projecting it), so each
   -- gate's bias chunk is hiddensize wide. Slicing with outputsize, as
   -- before, is only correct when outputsize == hiddensize. Assumes the
   -- parent keeps the gate order that the original slice implied.
   self.bias[{{H + 1, 2 * H}}]:fill(1)
   if std then
      self.weight:normal(0, std)
      self.weightO:normal(0, std)
   else
      self.weight:normal(0, 1.0 / math.sqrt(H + self.inputsize))
      self.weightO:normal(0, 1.0 / math.sqrt(self.outputsize + H))
   end
   return self
end
--- Project the output of time step t from hiddensize to outputsize.
-- Stores the unprojected next_h in self._hidden[t] (read later by
-- gradAdapter), then overwrites next_h with _hidden[t] * weightO.
-- @param t time-step index into self._hidden
function SeqLSTMP:adapter(t)
   local seqlen = self._output:size(1)
   local batchsize = self._output:size(2)
   if not self._hidden then
      self._hidden = self.next_h.new()
   end
   self._hidden:resize(seqlen, batchsize, self.hiddensize)

   -- keep the pre-projection activations for the backward pass
   local unprojected = self._hidden[t]
   unprojected:copy(self.next_h)

   -- (batchsize x hiddensize) * (hiddensize x outputsize)
   self.next_h:resize(batchsize, self.outputsize)
   self.next_h:mm(unprojected, self.weightO)
end
--- Backpropagate through the output projection at time step t.
-- On entry grad_next_h holds the gradient w.r.t. the projected output
-- (batchsize x outputsize). Accumulates scale * _hidden[t]^T * grad into
-- gradWeightO, then replaces grad_next_h with grad * weightO^T, the
-- gradient w.r.t. the unprojected output (batchsize x hiddensize).
-- @param scale gradient scaling factor
-- @param t time-step index into self._hidden
function SeqLSTMP:gradAdapter(scale, t)
   local gradOut = self.grad_next_h

   -- set the incoming gradient aside: grad_next_h is overwritten below
   local gradProjected = self.buffer3
   gradProjected:resizeAs(gradOut):copy(gradOut)

   self.gradWeightO:addmm(scale, self._hidden[t]:t(), gradOut)

   local batchsize = self._output:size(2)
   gradOut:resize(batchsize, self.hiddensize)
   gradOut:mm(gradProjected, self.weightO:t())
end
--- Return the learnable tensors and their gradients, in matching order.
-- @return {weight, bias, weightO}
-- @return {gradWeight, gradBias, gradWeightO}
function SeqLSTMP:parameters()
   local params = {self.weight, self.bias, self.weightO}
   local gradParams = {self.gradWeight, self.gradBias, self.gradWeightO}
   return params, gradParams
end
--- Unsupported for this module: always raises an error.
function SeqLSTMP:accUpdateGradParameters(input, gradOutput, lr)
   local msg = "accUpdateGradParameters not implemented for SeqLSTMP"
   error(msg)
end
--- Conversion to nn.FastLSTM is unsupported: always raises an error.
function SeqLSTMP:toFastLSTM()
   local msg = "toFastLSTM not supported for SeqLSTMP"
   error(msg)
end