Skip to content

Commit

Permalink
[fix] torch.multinomial : fix for 0 size dim (#43775)
Browse files Browse the repository at this point in the history
Summary:
Fixes pytorch/pytorch#43768

TO-DO:
* [x] Add test

Pull Request resolved: pytorch/pytorch#43775

Reviewed By: ZolotukhinM

Differential Revision: D23421979

Pulled By: ngimel

fbshipit-source-id: 949fcdd30f18d17ae1c372fa6ca6a0b8d0d538ce
  • Loading branch information
kshitij12345 authored and facebook-github-bot committed Aug 31, 2020
1 parent 3c8b1d7 commit 0394c5a
Show file tree
Hide file tree
Showing 3 changed files with 15 additions and 6 deletions.
1 change: 1 addition & 0 deletions aten/src/ATen/native/Distributions.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -454,6 +454,7 @@ Tensor& multinomial_out(Tensor& result, const Tensor& self, int64_t n_sample, bo
if (self.dim() > 1) {
int64_t n_dist = self.size(-2);
result.resize_({n_dist, n_sample});
if (n_dist == 0) { return result; };
} else {
result.resize_({n_sample});
}
Expand Down
2 changes: 2 additions & 0 deletions aten/src/ATen/native/cuda/MultinomialKernel.cu
Original file line number Diff line number Diff line change
Expand Up @@ -395,6 +395,8 @@ void multinomial_kernel_impl(Tensor& result, const Tensor& self, const int64_t n
}
});
AT_CUDA_CHECK(cudaGetLastError());
if (inputSize == 1) {
result.resize_({n_sample});
}
Expand Down
18 changes: 12 additions & 6 deletions test/test_torch.py
Original file line number Diff line number Diff line change
Expand Up @@ -17942,13 +17942,19 @@ def test(probs, replacement):
test(y, True)
test(z, True)

def test_multinomial_empty(self, device):
probs = torch.ones(0, 3)
num_samples = 1
def _test_multinomial_empty(self, device, replacement, num_samples):
probs = torch.ones(0, 3, device=device)
expected = torch.empty(0, num_samples, dtype=torch.int64)
for replacement in (True, False):
out = torch.multinomial(probs, num_samples=num_samples, replacement=replacement)
self.assertEqual(out, expected)
out = torch.multinomial(probs, num_samples=num_samples, replacement=replacement)
self.assertEqual(out, expected)

def test_multinomial_empty_w_replacement(self, device):
self._test_multinomial_empty(device, True, 1)
self._test_multinomial_empty(device, True, 2)

def test_multinomial_empty_wo_replacement(self, device):
self._test_multinomial_empty(device, False, 1)
self._test_multinomial_empty(device, False, 2)

def _generate_input(self, shape, dtype, device, with_extremal):
if shape == ():
Expand Down

0 comments on commit 0394c5a

Please sign in to comment.