Skip to content

Commit

Permalink
Recurrent:sharedType
Browse files Browse the repository at this point in the history
  • Loading branch information
nicholas-leonard committed Mar 19, 2015
1 parent 6050849 commit 45e5bf6
Show file tree
Hide file tree
Showing 3 changed files with 23 additions and 1 deletion.
2 changes: 1 addition & 1 deletion Module.lua
Original file line number Diff line number Diff line change
Expand Up @@ -74,7 +74,7 @@ end

-- for preserving shared params created with sharedClones
function Module:sharedType(type, castmap)
assert(type, 'Module: must provide a type to convert to')
assert(type, 'Module:sharedType must provide a type to convert to')
-- key: pointer to old storage
-- value : new storage
castmap = castmap or {} --contains torch.Storage instances
Expand Down
13 changes: 13 additions & 0 deletions Recurrent.lua
Original file line number Diff line number Diff line change
Expand Up @@ -265,6 +265,19 @@ function Recurrent:forget()
parent.forget(self, 1)
end

function Recurrent:sharedType(type, castmap)
local modules = self.modules
self.modules = {}
for i,modules in ipairs{modules, self.sharedClones, {self.initialModule}} do
for j, module in pairs(modules) do
table.insert(self.modules, module)
end
end
parent.sharedType(self, type, castmap)
self.modules = modules
return self
end

function Recurrent:__tostring__()
local tab = ' '
local line = '\n'
Expand Down
9 changes: 9 additions & 0 deletions test/test.lua
Original file line number Diff line number Diff line change
Expand Up @@ -305,6 +305,15 @@ function rnntest.Recurrent()
for i=1,#params do
mytester:assertTensorEq(params[i], params5[i], 0.000001, 'backwardUpdateThroughTime error ' .. i)
end

mlp:forget()
local rnn = mlp:float(true)
local outputs2 = {}
for step=1,nSteps do
rnn:forward(inputSequence[step]:float())
rnn:backward(inputSequence[step]:float(), gradOutputs[step]:float())
end
local gradInput2 = rnn:backwardThroughTime()
end

function rnntest.Recurrent_TestTable()
Expand Down

0 comments on commit 45e5bf6

Please sign in to comment.