Skip to content

Commit

Permalink
Merge pull request torch#240 from torch/fixview
Browse files Browse the repository at this point in the history
fix View (hopefuly)
  • Loading branch information
soumith committed Apr 22, 2015
2 parents dbfeab9 + 0063dbd commit 394554a
Showing 2 changed files with 48 additions and 32 deletions.
69 changes: 37 additions & 32 deletions View.lua
Original file line number Diff line number Diff line change
@@ -2,14 +2,23 @@ local View, parent = torch.class('nn.View', 'nn.Module')

function View:__init(...)
parent.__init(self)
self.size = ...
if select('#', ...) > 1 or type(self.size) == "number" then
if select('#', ...) == 1 and torch.typename(select(1, ...)) == 'torch.LongStorage' then
self.size = select(1, ...)
else
self.size = torch.LongStorage({...})
end
assert(torch.typename(self.size)=="torch.LongStorage", "expecting a LongStorage")

self.numElements = 1
local inferdim = false
for i = 1,#self.size do
self.numElements = self.numElements * self.size[i]
local szi = self.size[i]
if szi >= 0 then
self.numElements = self.numElements * self.size[i]
else
assert(szi == -1, 'size should be positive or -1')
assert(not inferdim, 'only one dimension can be at -1')
inferdim = true
end
end

self.output = nil
@@ -23,42 +32,38 @@ function View:setNumInputDims(numInputDims)
end

local function batchsize(input, size, numInputDims, numElements)

-- handle special vector case
if size:size() == 1 and size[1] == -1 then
if numInputDims then
numElements = 1
local dim = input:nDimension()
for i=1,numInputDims do
numElements = numElements * input:size(dim-i+1)
end
else
numElements = input:nElement()
end
size = torch.LongStorage{numElements}
end

-- find if number of elements is divisible with desired number
local ine = input:nElement()
local dim = 0
local bsz = 1
while ine > numElements do
dim = dim + 1
local dimsz = input:size(dim)
if ine % numElements == 0 then
dimsz = math.min(ine/numElements, dimsz)
end
ine = ine / dimsz
bsz = bsz * dimsz
local ind = input:nDimension()
local isz = input:size()
local maxdim = numInputDims and numInputDims or ind
local ine = 1
for i=ind,ind-maxdim+1,-1 do
ine = ine * isz[i]
end

if ine ~= numElements then
if ine % numElements ~= 0 then
error(string.format(
'input view (%s) and desired view (%s) do not match',
table.concat(input:size():totable(), 'x'),
table.concat(size:totable(), 'x')))
end

-- the remainder is either the batch...
local bsz = ine / numElements

-- ... or the missing size dim
for i=1,size:size() do
if size[i] == -1 then
bsz = 1
break
end
end

-- for dim over maxdim, it is definitively the batch
for i=ind-maxdim,1,-1 do
bsz = bsz * isz[i]
end

-- special card
if bsz == 1 and (not numInputDims or input:nDimension() <= numInputDims) then
return
end
11 changes: 11 additions & 0 deletions test.lua
Original file line number Diff line number Diff line change
@@ -2872,6 +2872,17 @@ function nntest.View()
minibatch:size(1),
"Error in minibatch dimension with size -1")

-- another setNumInputDims case
local minibatch = torch.rand(2,5,4,10)
local module = nn.View(4,-1):setNumInputDims(2)
local out = module:forward(minibatch)
mytester:assertTableEq(out:size(1), minibatch:size(1)*minibatch:size(2),
"Error in minibatch dimension with size -1")
mytester:assertTableEq(out:size(2), minibatch:size(3),
"Error in minibatch dimension with size -1")
mytester:assertTableEq(out:size(3), minibatch:size(4),
"Error in minibatch dimension with size -1")

-- Minibatch Generalization
local minibatch = torch.rand(5,2,6)
local module = nn.View(6)

0 comments on commit 394554a

Please sign in to comment.