Allow Variables in calls to type2backend (#4724)
Use x.type() instead of type(x) when accessing type2backend to support
Variables as well as Tensors.
colesbury authored Jan 18, 2018
1 parent 23fc2b7 commit 3249d8b
Showing 8 changed files with 14 additions and 14 deletions.
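For context, a minimal sketch of why the lookup key matters (assuming a PyTorch 0.3-era environment in which type2backend accepts type-name strings such as 'torch.FloatTensor'; the variable names are illustrative only):

import torch
from torch.autograd import Variable
from torch._thnn import type2backend

t = torch.FloatTensor(2, 3)   # a plain Tensor
v = Variable(t)               # the same data wrapped in a Variable

# x.type() returns the underlying tensor type string for both objects,
# so one key works whether a function is handed a Tensor or a Variable.
assert t.type() == 'torch.FloatTensor'
assert v.type() == 'torch.FloatTensor'
backend = type2backend[v.type()]  # resolves to the THNN float backend

# type(x), by contrast, is the Python class: type(t) is torch.FloatTensor,
# but type(v) is Variable, so the old type(x) key raised KeyError for
# Variable inputs.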
torch/legacy/nn/Criterion.py (1 addition, 1 deletion)
@@ -9,7 +9,7 @@ class Criterion(object):
     def __init__(self):
         self.gradInput = torch.Tensor()
         self.output = 0
-        self._backend = torch._thnn.type2backend[type(self.gradInput)]
+        self._backend = torch._thnn.type2backend[self.gradInput.type()]

     def updateOutput(self, input, target):
         raise NotImplementedError
torch/legacy/nn/Module.py (1 addition, 1 deletion)
@@ -9,7 +9,7 @@ def __init__(self):
         self.gradInput = torch.Tensor()
         self.output = torch.Tensor()
         self._type = self.output.type()
-        self._backend = torch._thnn.type2backend[type(self.output)]
+        self._backend = torch._thnn.type2backend[self.output.type()]

     def __repr__(self):
         return 'nn.' + self.__class__.__name__
torch/nn/_functions/thnn/auto.py (2 additions, 2 deletions)
@@ -40,7 +40,7 @@ def symbolic(*args, **kwargs):

     @staticmethod
     def forward(ctx, input, target, *args):
-        ctx._backend = type2backend[type(input)]
+        ctx._backend = type2backend[input.type()]
         ctx.save_for_backward(input, target)
         if weight_arg_idx >= 0:
             ctx.weight = args[0]
@@ -146,7 +146,7 @@ def symbolic(*args, **kwargs):

     @staticmethod
     def forward(ctx, input, *params):
-        ctx._backend = type2backend[type(input)]
+        ctx._backend = type2backend[input.type()]

         ctx.additional_args = []
         tensor_param_list = []
torch/nn/_functions/thnn/normalization.py (1 addition, 1 deletion)
@@ -22,7 +22,7 @@ def forward(self, input):
         self.scale = self.scale or input.new()
         output = input.new()

-        backend = type2backend[type(input)]
+        backend = type2backend[input.type()]
         if backend is not None:
             try:
                 backend.SpatialCrossMapLRN_updateOutput
torch/nn/_functions/thnn/rnnFusedPointwise.py (4 additions, 4 deletions)
@@ -6,7 +6,7 @@
 class GRUFused(Function):
     @staticmethod
     def forward(ctx, input_gate, hidden_gate, hx, ibias=None, hbias=None):
-        ctx.backend = type2backend[type(input_gate)]
+        ctx.backend = type2backend[input_gate.type()]

         hy = input_gate.new()
         workspace = input_gate.new(hx.numel() * 5)
@@ -32,7 +32,7 @@ def forward(ctx, input_gate, hidden_gate, hx, ibias=None, hbias=None):
     @staticmethod
     @once_differentiable
     def backward(ctx, gradOutput):
-        ctx.backend = type2backend[type(gradOutput)]
+        ctx.backend = type2backend[gradOutput.type()]

         gradInputHx = gradOutput.new()
         gradInInput = gradOutput.new(*ctx.igate_size)
@@ -52,7 +52,7 @@ def backward(ctx, gradOutput):
 class LSTMFused(Function):
     @staticmethod
     def forward(ctx, input_gate, hidden_gate, cx, ibias=None, hbias=None):
-        ctx.backend = type2backend[type(input_gate)]
+        ctx.backend = type2backend[input_gate.type()]
         hy = input_gate.new()
         cy = input_gate.new()

@@ -79,7 +79,7 @@ def forward(ctx, input_gate, hidden_gate, cx, ibias=None, hbias=None):
     @staticmethod
     @once_differentiable
     def backward(ctx, *gradOutput):
-        ctx.backend = type2backend[type(gradOutput[0])]
+        ctx.backend = type2backend[gradOutput[0].type()]
         gradInputCx = gradOutput[0].new()
         gradInGates = gradOutput[0].new(*ctx.hgate_size)

torch/nn/_functions/thnn/sparse.py (1 addition, 1 deletion)
@@ -58,7 +58,7 @@ def forward(cls, ctx, weight, indices, offsets,
                              " ({}), but got offsets[-1] of {}"
                              .format(indices.size(0), offsets[-1]))

-        ctx._backend = type2backend[type(weight)]
+        ctx._backend = type2backend[weight.type()]
         ctx._weight_size = weight.size()
         ctx._offset2bag = offsets.new()

torch/nn/_functions/vision.py (2 additions, 2 deletions)
@@ -47,7 +47,7 @@ def forward(ctx, input, grid, padding_mode='zeros'):
                              .format(padding_mode))

         grid_sz = grid.size()
-        backend = type2backend[type(input)]
+        backend = type2backend[input.type()]
         output = input.new(grid_sz[0], input.size(1), grid_sz[1], grid_sz[2])
         backend.SpatialGridSamplerBilinear_updateOutput(backend.library_state, input, grid, output, ctx.padding_mode)
         return output
@@ -58,7 +58,7 @@ def backward(ctx, grad_output):
         input, grid = ctx.saved_tensors
         padding_mode = ctx.padding_mode

-        backend = type2backend[type(input)]
+        backend = type2backend[input.type()]
         grad_input = input.new(input.size())
         grad_grid = grid.new(grid.size())
         backend.SpatialGridSamplerBilinear_updateGradInput(
torch/utils/serialization/read_lua_file.py (2 additions, 2 deletions)
@@ -221,7 +221,7 @@ def _load_backend(obj):
         attr = getattr(obj, key)
         if torch.is_tensor(attr):
             try:
-                obj._backend = type2backend[type(attr)]
+                obj._backend = type2backend[attr.type()]
             except KeyError:
                 pass
     # Monkey patch the forward to capture the type of input
@@ -231,7 +231,7 @@ def updateOutput_patch(*args):
         input = args[0]
         while not torch.is_tensor(input):
             input = input[0]
-        obj._backend = type2backend[type(input)]
+        obj._backend = type2backend[input.type()]
         obj.updateOutput = updateOutput_orig
         return obj.updateOutput(*args)
     obj.updateOutput = updateOutput_patch
