Skip to content

Commit

Permalink
Check argument types in 'checkTypes' (pytorch#1363)
Browse files Browse the repository at this point in the history
  • Loading branch information
colesbury authored Apr 26, 2017
1 parent 41705ce commit 8ca7bf2
Show file tree
Hide file tree
Showing 2 changed files with 11 additions and 0 deletions.
8 changes: 8 additions & 0 deletions test/test_nn.py
Original file line number Diff line number Diff line change
Expand Up @@ -1299,6 +1299,14 @@ def test_assignments(get_list, a, b, c):
self.assertIn('buf', l.state_dict())
self.assertIs(l.state_dict()['buf'], buf)

def test_Conv2d_inconsistent_types(self):
inputs = Variable(torch.randn(4, 1, 7, 7).float())
weights = Variable(torch.randn(1, 1, 3, 3).double())
# inconsistent types should raise an exception
self.assertRaises(RuntimeError, lambda: nn.functional.conv2d(inputs, weights))
# but it should work with the same type
nn.functional.conv2d(inputs.float(), weights.float())

@unittest.skipIf(not TEST_CUDA, 'CUDA not available')
def test_Conv2d_large_workspace(self):
# These sizes require huge cuDNN workspaces. Make sure we choose a
Expand Down
3 changes: 3 additions & 0 deletions torch/csrc/nn/THNN_generic.inc.h
Original file line number Diff line number Diff line change
Expand Up @@ -50,6 +50,9 @@ void checkTypes(bool isCuda, thpp::Type type, ...) {
if (tensor->isCuda() != isCuda) {
throw invalid_tensor(isCuda ? "CUDA" : "CPU", tensor->isCuda() ? "CUDA" : "CPU");
}
if (tensor->type() != type) {
throw invalid_tensor(thpp::toString(type), thpp::toString(tensor->type()));
}
}
}

Expand Down

0 comments on commit 8ca7bf2

Please sign in to comment.