Adds dim argument to torch.unique (pytorch#10423)
Summary:
Initial version of `unique` supporting a `dim` argument.

As discussed in [this issue](pytorch#9997), I added the `dim` argument to `torch.unique` with the same behavior as [numpy](https://docs.scipy.org/doc/numpy-1.14.0/reference/generated/numpy.unique.html).

Since the implementation is based on `std::unique`/`thrust::unique`, the tensor always needs to be sorted. The `sorted` argument of `torch.unique` therefore has no effect here, just as in the CUDA version of the plain `torch.unique`.
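
A short usage sketch of the new argument (behavior as exercised by the tests below):

```python
import torch

x = torch.tensor([[1, 2],
                  [1, 2],
                  [3, 4]])

# Treat rows as atomic elements and deduplicate along dim=0
output = torch.unique(x, dim=0)
# tensor([[1, 2],
#         [3, 4]])

# Also map each original row to its row in `output`
output, inverse = torch.unique(x, return_inverse=True, dim=0)
# inverse: tensor([0, 0, 1])
```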

To check the performance and equal behavior between `torch.unique` and `np.unique`, I've used [this gist](https://gist.github.com/ptrblck/ac0dc862f4e1766f0e1036c252cdb105).

Currently we achieve the following timings for an input of `x = torch.randint(2, (1000, 1000))`:
(The values are averages over the timings for both dimensions.)

| Device | PyTorch (return_inverse=False) | Numpy (return_inverse=False) | PyTorch (return_inverse=True) | Numpy (return_inverse=True) |
| --- | --- | --- | --- | --- |
| CPU | ~0.007331s | ~0.022452s | ~0.011139s | ~0.044800s |
| GPU | ~0.006154s | - | ~0.105373s | - |
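
The linked gist is not reproduced here, but a minimal sketch of this kind of comparison (warm-up and iteration count are assumptions, not the gist's exact code) could look like:

```python
import time
import numpy as np
import torch

x = torch.randint(2, (1000, 1000))
x_np = x.numpy()

def bench(fn, n_iter=10):
    fn()  # warm-up
    start = time.time()
    for _ in range(n_iter):
        fn()
    return (time.time() - start) / n_iter

for dim in (0, 1):
    t_pt = bench(lambda: torch.unique(x, return_inverse=True, dim=dim))
    t_np = bench(lambda: np.unique(x_np, return_inverse=True, axis=dim))
    print('dim=%d  pytorch: %.6fs  numpy: %.6fs' % (dim, t_pt, t_np))

# GPU timings additionally require torch.cuda.synchronize() around the calls.
```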

Many thanks to colesbury for the awesome mentoring and the valuable advice on the general implementation and on performance issues!
Pull Request resolved: pytorch#10423

Differential Revision: D9517289

Pulled By: soumith

fbshipit-source-id: a4754f805223589c2847c98b8e4e39d8c3ddb7b5
pbialecki authored and facebook-github-bot committed Aug 29, 2018
1 parent 98d85b1 commit 2cc98d8
Showing 6 changed files with 273 additions and 9 deletions.
84 changes: 84 additions & 0 deletions aten/src/ATen/native/Unique.cpp
@@ -47,6 +47,82 @@ std::tuple<Tensor, Tensor> _unique_cpu_template(
}
return std::make_tuple(output, inverse_indices);
}

template<class ForwardIt>
ForwardIt _unique_dim_cpu_impl(ForwardIt first, ForwardIt last,
std::vector<int64_t>& indices, Tensor inverse_indices_vec) {
if (first == last) {
return last;
}
// save to calculate distance to iterators
ForwardIt begin = first;

// set first inverse index
inverse_indices_vec[indices[0]] = 0;

ForwardIt result = first;
while (++first != last) {
if (!at::equal(*result, *first) && ++result != first) {
*result = std::move(*first);
}
int64_t idx_result = std::distance(begin, result);
int64_t idx_first = std::distance(begin, first);
inverse_indices_vec[indices[idx_first]] = idx_result;
}

return ++result;
}

template <typename scalar_t>
std::tuple<Tensor, Tensor> _unique_dim_cpu_template(
const Tensor& self,
const int64_t dim,
const bool return_inverse) {
// reshape tensor as [dim, -1]
Tensor input_flat = self.transpose(dim, 0);
auto orig_sizes = input_flat.sizes().vec();
input_flat = input_flat.contiguous().view({input_flat.size(0), -1});

std::vector<int64_t> indices(input_flat.size(0));
std::iota(indices.begin(), indices.end(), 0);
int64_t numel = input_flat.size(1);
scalar_t* input_flat_ptr = ((scalar_t*)input_flat.data_ptr());

// sort indices using data
std::sort(indices.begin(), indices.end(),
[&](int64_t a, int64_t b) -> bool {
for (int64_t i = 0; i < numel; ++i) {
scalar_t lhs = input_flat_ptr[i + a * numel];
scalar_t rhs = input_flat_ptr[i + b * numel];
if (lhs < rhs) {
return true;
} else if (lhs > rhs) {
return false;
}
}
return false;
});

Tensor input_sorted = at::empty(input_flat.sizes(), input_flat.type());
for (int i = 0; i < indices.size(); ++i) {
input_sorted[i] = input_flat[indices[i]];
}

Tensor inverse_indices = at::empty(indices.size(), self.type().toScalarType(kLong));
std::vector<Tensor> input_unbind = at::unbind(input_sorted, 0);
auto last = _unique_dim_cpu_impl(
input_unbind.begin(), input_unbind.end(), indices, inverse_indices);
input_unbind.erase(last, input_unbind.end());

// reshape back
auto output = at::stack(input_unbind, 0);
auto new_sizes = std::vector<int64_t>(orig_sizes);
new_sizes[0] = -1;
output = output.view(new_sizes);
output = output.transpose(0, dim);

return std::make_tuple(output, inverse_indices);
}
} // namespace

std::tuple<Tensor, Tensor>
@@ -56,5 +132,13 @@ _unique_cpu(const Tensor& self, const bool sorted, const bool return_inverse) {
});
}

std::tuple<Tensor, Tensor>
_unique_dim_cpu(const Tensor& self, const int64_t dim, const bool sorted, const bool return_inverse) {
return AT_DISPATCH_ALL_TYPES(self.type(), "unique_dim", [&] {
// The current implementation using `dim` always sorts due to unhashable tensors
return _unique_dim_cpu_template<scalar_t>(self, dim, return_inverse);
});
}

} // namespace native
} // namespace at
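
Conceptually, `_unique_dim_cpu_template` moves `dim` to the front, flattens the rest, sorts the rows lexicographically, and then removes consecutive duplicates while recording inverse indices. A pure-Python sketch of that algorithm for `dim=0` (illustration only, not part of the commit):

```python
def unique_dim0(rows):
    """(unique_rows, inverse) for a list of equal-length rows: argsort
    lexicographically, then drop consecutive duplicates as std::unique does."""
    order = sorted(range(len(rows)), key=lambda i: rows[i])  # lexicographic argsort
    unique_rows, inverse = [], [0] * len(rows)
    for idx in order:
        if not unique_rows or unique_rows[-1] != rows[idx]:
            unique_rows.append(rows[idx])
        inverse[idx] = len(unique_rows) - 1
    return unique_rows, inverse

# unique_dim0([[1, 1], [0, 1], [1, 1]]) -> ([[0, 1], [1, 1]], [1, 0, 1])
```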
97 changes: 97 additions & 0 deletions aten/src/ATen/native/cuda/Unique.cu
@@ -69,6 +69,92 @@ template <typename scalar_t>
return std::tuple<Tensor, Tensor>(output, inverse_indices);

}

template <typename scalar_t>
std::tuple<Tensor, Tensor> _unique_dim_cuda_template(
const Tensor& self,
const int64_t dim,
const bool return_inverse) {

cudaStream_t stream = at::cuda::getCurrentCUDAStream();
auto allocator = THCThrustAllocator(globalContext().lazyInitCUDA());
auto policy = thrust::cuda::par(allocator).on(stream);

Tensor input_flat = self.transpose(dim, 0);
auto orig_sizes = input_flat.sizes().vec();
input_flat = input_flat.contiguous().view({input_flat.size(0), -1});

scalar_t* input_flat_ptr = input_flat.data<scalar_t>();

Tensor indices = at::arange(0, input_flat.size(0), self.type().toScalarType(kLong));
int64_t* indices_ptr = indices.data<int64_t>();
int64_t numel = input_flat.size(1);

// sort indices using data
thrust::sort(policy, indices_ptr, indices_ptr + indices.numel(),
[=] __device__ (int64_t a, int64_t b) -> bool {
for (int64_t i = 0; i < numel; ++i) {
scalar_t lhs = input_flat_ptr[i + a * numel];
scalar_t rhs = input_flat_ptr[i + b * numel];
if (lhs < rhs) {
return true;
} else if (lhs > rhs) {
return false;
}
}
return false;
});

Tensor input_sorted = input_flat.index_select(0, indices);

// get unique tensors
scalar_t* input_sorted_ptr = input_sorted.data<scalar_t>();
Tensor input_sorted_indices = at::arange(0, input_sorted.size(0), self.type().toScalarType(kLong));
int64_t* input_sorted_indices_ptr = input_sorted_indices.data<int64_t>();
auto last = thrust::unique(policy, input_sorted_indices_ptr, input_sorted_indices_ptr + input_sorted_indices.numel(),
[=] __device__ (int64_t a, int64_t b) -> bool {
for (int64_t i = 0; i < numel; ++i) {
scalar_t lhs = input_sorted_ptr[i + a * numel];
scalar_t rhs = input_sorted_ptr[i + b * numel];
if (lhs != rhs) {
return false;
}
}
return true;
});
input_sorted_indices.resize_(last - input_sorted_indices_ptr);
Tensor output = input_sorted.index_select(0, input_sorted_indices);

// reshape back
auto new_sizes = std::vector<int64_t>(orig_sizes);
new_sizes[0] = -1;
output = output.view(new_sizes);
output = output.transpose(0, dim);

// calculate inverse indices
Tensor inverse_indices = at::empty({0}, self.type().toScalarType(kLong));
if (return_inverse) {
int64_t size = self.size(dim);
inverse_indices.resize_(size);
Tensor mask = at::empty(input_sorted.size(0), self.type().toScalarType(kLong));
mask[0] = 1;
for (int i = 0; i < input_sorted.size(0) - 1; ++i) {
if (!at::equal(input_sorted[i], input_sorted[i+1])) {
mask[i+1] = 1;
} else {
mask[i+1] = 0;
}
}

Tensor imask = at::cumsum(mask, 0) - 1;
for (int i = 0; i < indices.size(0); ++i) {
inverse_indices[indices[i]] = imask[i];
}
}

THCudaCheck(cudaGetLastError());
return std::tuple<Tensor, Tensor>(output, inverse_indices);
}
} // namespace

#endif
@@ -86,5 +172,16 @@ _unique_cuda(const Tensor& self, const bool sorted, const bool return_inverse) {
#endif
}

std::tuple<Tensor, Tensor>
_unique_dim_cuda(const Tensor& self, const int64_t dim, const bool sorted, const bool return_inverse) {
#ifndef __HIP_PLATFORM_HCC__
return AT_DISPATCH_ALL_TYPES(self.type(), "unique_dim", [&] {
return _unique_dim_cuda_template<scalar_t>(self, dim, return_inverse);
});
#else
AT_ERROR("unique_dim_cuda: HIP not supported");
#endif
}

} // namespace native
} // namespace at
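
The CUDA path reconstructs the inverse indices from the sort order: a 0/1 mask marks where a new unique row begins in the sorted data, `cumsum(mask) - 1` assigns each sorted row its unique-group id, and scattering those ids back through the sort indices yields the result. A plain-Python sketch (illustration only, not part of the commit):

```python
def inverse_from_sorted(sorted_rows, sort_indices):
    """Compute inverse indices given lexicographically sorted rows and the
    argsort that produced them (sorted_rows[i] == rows[sort_indices[i]])."""
    n = len(sorted_rows)
    # mask[i] == 1 iff sorted_rows[i] starts a new unique group
    mask = [1] + [int(sorted_rows[i] != sorted_rows[i - 1]) for i in range(1, n)]
    # imask = cumsum(mask) - 1: the unique-group id of each sorted row
    imask, group = [], -1
    for m in mask:
        group += m
        imask.append(group)
    # scatter back to original row order
    inverse = [0] * n
    for i, src in enumerate(sort_indices):
        inverse[src] = imask[i]
    return inverse

# inverse_from_sorted([[0, 1], [1, 1], [1, 1]], [1, 0, 2]) -> [1, 0, 1]
```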
5 changes: 5 additions & 0 deletions aten/src/ATen/native/native_functions.yaml
@@ -1748,6 +1748,11 @@
CPU: _unique_cpu
CUDA: _unique_cuda

- func: _unique_dim(Tensor self, int64_t dim, bool sorted=false, bool return_inverse=false) -> (Tensor, Tensor)
dispatch:
CPU: _unique_dim_cpu
CUDA: _unique_dim_cuda

- func: _unsafe_view(Tensor self, IntList size) -> Tensor
variants: function

61 changes: 61 additions & 0 deletions test/test_torch.py
@@ -8485,6 +8485,67 @@ def test_unique(self):
self.assertEqual(torch.ByteTensor([7, 42, 128, 133]), byte_unique)
self.assertEqual(torch.LongTensor([3, 0, 0, 0, 1, 2]), byte_inverse)

def test_unique_dim(self):
def run_test(dtype=torch.float):
x = torch.tensor([[[1., 1.],
[0., 1.],
[2., 1.],
[0., 1.]],
[[1., 1.],
[0., 1.],
[2., 1.],
[0., 1.]]], dtype=dtype)
expected_unique_dim0 = torch.tensor([[[1., 1.],
[0., 1.],
[2., 1.],
[0., 1.]]], dtype=dtype)
expected_inverse_dim0 = torch.tensor([0, 0])
expected_unique_dim1 = torch.tensor([[[0., 1.],
[1., 1.],
[2., 1.]],
[[0., 1.],
[1., 1.],
[2., 1.]]], dtype=dtype)
expected_inverse_dim1 = torch.tensor([1, 0, 2, 0])
expected_unique_dim2 = torch.tensor([[[1., 1.],
[0., 1.],
[2., 1.],
[0., 1.]],
[[1., 1.],
[0., 1.],
[2., 1.],
[0., 1.]]], dtype=dtype)
expected_inverse_dim2 = torch.tensor([0, 1])

# dim0
x_unique = torch.unique(x, dim=0)
self.assertEqual(expected_unique_dim0, x_unique)

x_unique, x_inverse = torch.unique(x, return_inverse=True, dim=0)
self.assertEqual(expected_unique_dim0, x_unique)
self.assertEqual(expected_inverse_dim0, x_inverse)

# dim1
x_unique = torch.unique(x, dim=1)
self.assertEqual(expected_unique_dim1, x_unique)

x_unique, x_inverse = torch.unique(x, return_inverse=True, dim=1)
self.assertEqual(expected_unique_dim1, x_unique)
self.assertEqual(expected_inverse_dim1, x_inverse)

# dim2
x_unique = torch.unique(x, dim=2)
self.assertEqual(expected_unique_dim2, x_unique)

x_unique, x_inverse = torch.unique(x, return_inverse=True, dim=2)
self.assertEqual(expected_unique_dim2, x_unique)
self.assertEqual(expected_inverse_dim2, x_inverse)

run_test(torch.float)
run_test(torch.double)
run_test(torch.long)
run_test(torch.uint8)

@staticmethod
def _test_bincount(self, device):
# negative input throws
20 changes: 14 additions & 6 deletions torch/functional.py
@@ -389,7 +389,7 @@ def isnan(tensor):
return tensor != tensor


def unique(input, sorted=False, return_inverse=False):
def unique(input, sorted=False, return_inverse=False, dim=None):
r"""Returns the unique scalar elements of the input tensor as a 1-D tensor.
Arguments:
@@ -431,11 +431,19 @@ def unique(input, sorted=False, return_inverse=False, dim=None):
[ 1, 2]])
"""
output, inverse_indices = torch._unique(
input,
sorted=sorted,
return_inverse=return_inverse,
)
if dim is not None:
output, inverse_indices = torch._unique_dim(
input,
dim,
sorted=sorted,
return_inverse=return_inverse
)
else:
output, inverse_indices = torch._unique(
input,
sorted=sorted,
return_inverse=return_inverse,
)
if return_inverse:
return output, inverse_indices
else:
15 changes: 12 additions & 3 deletions torch/tensor.py
@@ -319,13 +319,22 @@ def masked_fill(self, mask, value):
"""
return self.clone().masked_fill_(mask, value)

def unique(self, sorted=False, return_inverse=False):
def unique(self, sorted=False, return_inverse=False, dim=None):
r"""Returns the unique scalar elements of the tensor as a 1-D tensor.
See :func:`torch.unique`
"""
output, inverse_indices = self._unique(
sorted=sorted, return_inverse=return_inverse)
if dim is not None:
output, inverse_indices = self._unique_dim(
sorted=sorted,
return_inverse=return_inverse,
dim=dim
)
else:
output, inverse_indices = self._unique(
sorted=sorted,
return_inverse=return_inverse
)
if return_inverse:
return output, inverse_indices
else:
