Nicer error message when flattening parameters with inconsistent types.
timharley committed Apr 24, 2015
1 parent 394554a commit 51b903c
Showing 1 changed file with 16 additions and 15 deletions.
31 changes: 16 additions & 15 deletions Module.lua
@@ -98,7 +98,7 @@ function Module:share(mlp, ...)
          mlp.accUpdateGradParameters = mlp.sharedAccUpdateGradParameters
       end
    end
-   return self
+   return self
 end
 
 function Module:clone(...)
@@ -118,11 +118,8 @@ local function recursiveType(param, type_str)
       for i = 1, #param do
          param[i] = recursiveType(param[i], type_str)
       end
-   else
-      if torch.typename(param) and
-         torch.typename(param):find('torch%..+Tensor') then
-         param = param:type(type_str)
-      end
+   elseif torch.isTensor(param) then
+      param = param:type(type_str)
    end
    return param
 end
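The new branch relies on torch.isTensor, which recognises any Tensor variant in a single call rather than pattern-matching the string returned by torch.typename. A minimal sketch of the difference, assuming a stock Torch7 install:

require 'torch'

-- torch.isTensor is true for any torch Tensor type (Float, Double, ...),
-- so the elseif above catches exactly the non-table, tensor-valued parameters.
print(torch.isTensor(torch.FloatTensor(3)))   -- true
print(torch.isTensor(torch.DoubleTensor(3)))  -- true
print(torch.isTensor({1, 2, 3}))              -- false; tables go through the recursive branch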
@@ -182,17 +179,22 @@ function Module:getParameters()
          return torch.Tensor()
       end
       local Tensor = parameters[1].new
+      local dtype = parameters[1]:type()
 
       local storages = {}
       local nParameters = 0
       for k = 1,#parameters do
+         if parameters[k]:type() ~= dtype then
+            error("Inconsistent parameter types. " .. parameters[k]:type() ..
+                  " ~= " .. dtype)
+         end
          local storage = parameters[k]:storage()
          if not storageInSet(storages, storage) then
            storages[torch.pointer(storage)] = {storage, nParameters}
            nParameters = nParameters + storage:size()
          end
       end
 
       local flatParameters = Tensor(nParameters):fill(1)
       local flatStorage = flatParameters:storage()
@@ -205,7 +207,7 @@ function Module:getParameters()
          parameters[k]:zero()
       end
 
-      local maskParameters= flatParameters:float():clone()
+      local maskParameters = flatParameters:float():clone()
       local cumSumOfHoles = flatParameters:float():cumsum(1)
       local nUsedParameters = nParameters - cumSumOfHoles[#cumSumOfHoles]
       local flatUsedParameters = Tensor(nUsedParameters)
@@ -214,9 +216,9 @@ function Module:getParameters()
       for k = 1,#parameters do
          local offset = cumSumOfHoles[parameters[k]:storageOffset()]
          parameters[k]:set(flatUsedStorage,
-                           parameters[k]:storageOffset() - offset,
-                           parameters[k]:size(),
-                           parameters[k]:stride())
+                           parameters[k]:storageOffset() - offset,
+                           parameters[k]:size(),
+                           parameters[k]:stride())
       end
 
       for _, storageAndOffset in pairs(storages) do
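For context on what this remapping achieves: flatten re-sets every parameter tensor to view a slice of one contiguous storage, which is why the vector returned by getParameters shares memory with the module's own weights. A rough usage sketch, assuming the nn package (module and sizes are arbitrary):

require 'nn'

local net = nn.Linear(4, 2)
local flatParams, flatGradParams = net:getParameters()

-- Writes through the flat view land in the module's own tensors.
flatParams:zero()
print(net.weight:sum(), net.bias:sum())  -- both 0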
@@ -273,11 +275,11 @@ function Module:findModules(typename, container)
     if (torch.type(self.modules) == 'table') then
       for i = 1, #self.modules do
         local child = self.modules[i]
-        local cur_nodes, cur_containers =
+        local cur_nodes, cur_containers =
           child:findModules(typename, self)
-        assert(#cur_nodes == #cur_containers,
+        assert(#cur_nodes == #cur_containers,
           'Internal error: incorrect return length') -- This shouldn't happen
-        -- add the list items from our child to our list (ie return a
+        -- add the list items from our child to our list (ie return a
         -- flattened table of the return nodes).
         for j = 1, #cur_nodes do
           nodes[#nodes+1] = cur_nodes[j]
@@ -312,4 +314,3 @@ function Module:listModules()
    end
    return modules
 end
-
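In practice, the new check means a network that mixes parameter precisions now fails fast inside getParameters with the message constructed above, instead of erroring later in the flattening arithmetic. A hypothetical way to trigger it, assuming the nn package and the default DoubleTensor type (the modules chosen are illustrative only):

require 'nn'

local net = nn.Sequential()
net:add(nn.Linear(10, 10))           -- DoubleTensor parameters (the default)
net:add(nn.Linear(10, 10):float())   -- FloatTensor parameters

-- Expected to raise something like:
--   "Inconsistent parameter types. torch.FloatTensor ~= torch.DoubleTensor"
local ok, err = pcall(function() return net:getParameters() end)
print(ok, err)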