Skip to content

Commit

Permalink
pack_padded_sequence: Check for empty (zero-element) tensors (#21461)
Browse files Browse the repository at this point in the history
Summary:
Fixes: #20529

Thank you, JamieCT for the bug report with reproducing script.
Pull Request resolved: pytorch/pytorch#21461

Differential Revision: D15696183

Pulled By: ezyang

fbshipit-source-id: a93cde2c924f8447563c64ce8a1cf75fcee60a01
  • Loading branch information
t-vi authored and facebook-github-bot committed Jun 6, 2019
1 parent 3b6362d commit 3feb40d
Show file tree
Hide file tree
Showing 2 changed files with 4 additions and 1 deletion.
1 change: 1 addition & 0 deletions aten/src/ATen/native/PackedSequence.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@ std::tuple<Tensor, Tensor> _pack_padded_sequence(const Tensor& _input, const Ten

int64_t batch_size = input.size(1);
int64_t * lengths = lengths_t.data<int64_t>();
TORCH_CHECK(input.numel() > 0, "Cannot pack empty tensors.");
TORCH_CHECK(lengths_t.size(0) == batch_size,
"Expected `len(lengths)` to be equal to batch_size, but got ", lengths_t.size(0),
" (batch_size=", batch_size, ")");
Expand Down
4 changes: 3 additions & 1 deletion test/test_nn.py
Original file line number Diff line number Diff line change
Expand Up @@ -5329,9 +5329,11 @@ def pad(tensor, length):
if l < 10:
self.assertEqual(padded.grad.data[l:, i].abs().sum(), 0)

# test error message
# test error messages
with self.assertRaisesRegex(RuntimeError, 'You can pass `enforce_sorted=False`'):
packed = rnn_utils.pack_padded_sequence(torch.randn(3, 3), [1, 3, 2])
with self.assertRaisesRegex(RuntimeError, 'empty tensor'):
packed = rnn_utils.pack_padded_sequence(torch.randn(0, 0), [])

def _test_variable_sequence(self, device="cpu", dtype=torch.float):
def pad(var, length):
Expand Down

0 comments on commit 3feb40d

Please sign in to comment.