Add optional is_coalesced argument to sparse coo tensor factory function. (pytorch#107638)

Resolves pytorch#107097

After this PR, instead of
```python
torch.sparse_coo_tensor(indices, values, size)._coalesced_(is_coalesced)
```
(which does not work in the autograd context; see pytorch#107097), use
```python
torch.sparse_coo_tensor(indices, values, size, is_coalesced=is_coalesced)
```

All sparse COO factory functions that take indices as input now support the `is_coalesced` argument.
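
For illustration, a minimal sketch of the new construction path (the tensors here are made up for the example):
```python
import torch

indices = torch.tensor([[0, 1, 2], [2, 0, 1]])  # sorted and duplicate-free
values = torch.tensor([1.0, 2.0, 3.0], requires_grad=True)

# Declaring the indices coalesced at construction time replaces the
# in-place _coalesced_() call, so autograd can trace through the factory.
s = torch.sparse_coo_tensor(indices, values, (3, 3), is_coalesced=True)
assert s.is_coalesced()

# Gradients now flow back to `values`:
s.to_dense().sum().backward()
assert values.grad is not None
```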

Pull Request resolved: pytorch#107638
Approved by: https://github.com/cpuhrsch
pearu authored and pytorchmergebot committed Aug 26, 2023
1 parent 781b7eb commit fe3309b
Showing 14 changed files with 110 additions and 41 deletions.
2 changes: 1 addition & 1 deletion aten/src/ATen/native/NonSymbolicBC.h
@@ -9,7 +9,7 @@ namespace at::native {
// In those cases, we will duplicate the signature here with non-symbolic ints, and also duplicate the C++ implementation.
TORCH_API at::Tensor reshape(const at::Tensor& self, at::IntArrayRef proposed_shape);
TORCH_API at::Tensor narrow(const at::Tensor& self, int64_t dim, int64_t start, int64_t length);
-TORCH_API at::Tensor _sparse_coo_tensor_unsafe(const at::Tensor & indices, const at::Tensor & values, at::IntArrayRef size, c10::optional<at::ScalarType> dtype=c10::nullopt, c10::optional<at::Layout> layout=c10::nullopt, c10::optional<at::Device> device=c10::nullopt, c10::optional<bool> pin_memory=c10::nullopt);
+TORCH_API at::Tensor _sparse_coo_tensor_unsafe(const at::Tensor & indices, const at::Tensor & values, at::IntArrayRef size, c10::optional<at::ScalarType> dtype=c10::nullopt, c10::optional<at::Layout> layout=c10::nullopt, c10::optional<at::Device> device=c10::nullopt, c10::optional<bool> pin_memory=c10::nullopt, c10::optional<bool> is_coalesced=c10::nullopt);
TORCH_API at::Tensor nll_loss(const at::Tensor & self, const at::Tensor & target, const c10::optional<at::Tensor>& weight_opt, int64_t reduction, int64_t ignore_index);
TORCH_API at::Tensor nll_loss2d(const at::Tensor & self, const at::Tensor & target, const c10::optional<at::Tensor>& weight_opt, int64_t reduction, int64_t ignore_index);
// The below ops don't get a duplicated C++ implementation.
3 changes: 2 additions & 1 deletion aten/src/ATen/native/SparseTensorUtils.cpp
@@ -123,7 +123,8 @@ Tensor zeros_like_with_indices(const Tensor& t) {
t.sizes(),
t._indices().clone(),
at::zeros({1}, t._values().options()).expand_as(t._values()),
-t.options())._coalesced_(t.is_coalesced());
+t.options(),
+t.is_coalesced());
}

}} // namespace at::sparse
9 changes: 4 additions & 5 deletions aten/src/ATen/native/TensorShape.cpp
@@ -538,7 +538,7 @@ Tensor sparse_broadcast_to(const Tensor& self, IntArrayRef size) {
for (int64_t j:unchanged_dims) {
new_indices.select(0, sparse_extra_ndim + j).copy_(indices.select(0, j).repeat_interleave(nnz_factor));
}
-return at::sparse_coo_tensor(new_indices, new_values, size)._coalesced_(is_coalesced);
+return at::sparse_coo_tensor(new_indices, new_values, size, self.options(), is_coalesced);
}

Tensor broadcast_to_symint(const Tensor& self, SymIntArrayRef size) {
@@ -1283,8 +1283,7 @@ Tensor narrow_copy_sparse(const Tensor& self, int64_t dim, int64_t start, int64_
new_values = self._values().narrow_copy(dense_dim, start, length);
}

-auto newTensor = at::sparse_coo_tensor(new_indices, new_values, new_sizes);
-return newTensor._coalesced_(self.is_coalesced());
+return at::sparse_coo_tensor(new_indices, new_values, new_sizes, self.options(), self.is_coalesced());
}

// Should just use narrow_copy_out, but this API is used internally at Meta:
@@ -1506,9 +1505,9 @@ Tensor permute_sparse_coo(const Tensor& self, IntArrayRef dims) {
}();

const auto is_coalesced = self.is_coalesced() && (dims[0] == 0);
+// TODO: apply `is_coalesced ||= new_values.size(0) < 2`.
return _sparse_coo_tensor_with_dims_and_tensors(
-sparse_ndim, dense_ndim, new_sizes, new_indices, new_values, self.options())
-._coalesced_(is_coalesced);
+sparse_ndim, dense_ndim, new_sizes, new_indices, new_values, self.options(), is_coalesced);
}

Tensor repeat(const Tensor& self, IntArrayRef repeats) {
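A note on the `permute_sparse_coo` hunk above: the coalesced flag is now threaded through the factory, but it is only preserved when the leading sparse dimension stays in place (`dims[0] == 0`). A rough illustration of the resulting behavior (example tensors are hypothetical):
```python
import torch

i = torch.tensor([[0, 1], [0, 1]])  # two sparse dims
v = torch.ones(2, 2, 2)             # nnz=2 with two dense dims
s = torch.sparse_coo_tensor(i, v, (2, 2, 2, 2)).coalesce()

# Swapping the sparse dims moves dim 0, so the result is conservatively
# reported as uncoalesced (even though these particular indices stay sorted).
assert not s.permute(1, 0, 2, 3).is_coalesced()

# Permuting only the dense dims leaves dim 0 fixed; the flag survives.
assert s.permute(0, 1, 3, 2).is_coalesced()
```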
10 changes: 5 additions & 5 deletions aten/src/ATen/native/native_functions.yaml
@@ -6824,15 +6824,15 @@
CompositeExplicitAutograd: sparse_coo_tensor
autogen: sparse_coo_tensor.size_out

-- func: sparse_coo_tensor.indices(Tensor indices, Tensor values, *, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=None) -> Tensor
+- func: sparse_coo_tensor.indices(Tensor indices, Tensor values, *, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=None, bool? is_coalesced=None) -> Tensor

-- func: sparse_coo_tensor.indices_size(Tensor indices, Tensor values, int[] size, *, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=None) -> Tensor
+- func: sparse_coo_tensor.indices_size(Tensor indices, Tensor values, int[] size, *, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=None, bool? is_coalesced=None) -> Tensor

-- func: _sparse_coo_tensor_unsafe(Tensor indices, Tensor values, SymInt[] size, *, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=None) -> Tensor
+- func: _sparse_coo_tensor_unsafe(Tensor indices, Tensor values, SymInt[] size, *, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=None, bool? is_coalesced=None) -> Tensor
dispatch:
CompositeImplicitAutograd: _sparse_coo_tensor_unsafe_symint

-- func: _validate_sparse_coo_tensor_args(Tensor indices, Tensor values, int[] size) -> ()
+- func: _validate_sparse_coo_tensor_args(Tensor indices, Tensor values, int[] size, bool? is_coalesced=None) -> ()

- func: _validate_sparse_compressed_tensor_args(Tensor compressed_indices, Tensor plain_indices, Tensor values, int[] size, Layout layout) -> ()
- func: _validate_sparse_csr_tensor_args(Tensor crow_indices, Tensor col_indices, Tensor values, int[] size) -> ()
@@ -6845,7 +6845,7 @@
SparseCPU, SparseCUDA, SparseMeta, Meta: new_with_dims_sparse
autogen: _sparse_coo_tensor_with_dims.out

-- func: _sparse_coo_tensor_with_dims_and_tensors(int sparse_dim, int dense_dim, SymInt[] size, Tensor indices, Tensor values, *, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=False) -> Tensor
+- func: _sparse_coo_tensor_with_dims_and_tensors(int sparse_dim, int dense_dim, SymInt[] size, Tensor indices, Tensor values, *, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=False, bool? is_coalesced=None) -> Tensor
dispatch:
SparseCPU, SparseCUDA, SparseMeta, Meta: new_with_dims_and_tensor_sparse_symint
autogen: _sparse_coo_tensor_with_dims_and_tensors.out
55 changes: 42 additions & 13 deletions aten/src/ATen/native/sparse/SparseTensor.cpp
@@ -188,9 +188,11 @@ SparseTensor new_with_dims_and_tensor_sparse_symint(
c10::optional<ScalarType> dtype,
c10::optional<Layout> layout,
c10::optional<Device> device,
-c10::optional<bool> pin_memory) {
+c10::optional<bool> pin_memory,
+c10::optional<bool> is_coalesced) {
SparseTensor self = new_sparse(dtype, layout, device, pin_memory);
-get_sparse_impl(self)->resize_(sparse_dim, dense_dim, size);
+auto impl = get_sparse_impl(self);
+impl->resize_(sparse_dim, dense_dim, size);
// NOTE: There is no guarantee that `indices` and `values` don't contain
// AutogradMeta. However, we want to maintain the invariant that `indices_`
// and `values_` of a sparse tensor don't contain AutogradMeta, and to achieve
@@ -204,6 +206,20 @@
/*version_counter=*/values.unsafeGetTensorImpl()->version_counter(),
/*allow_tensor_metadata_change=*/true));
alias_into_sparse(self, indices_shallow_copy, values_shallow_copy);
+// alias_into_sparse overrides coalesced flag, so resetting the flag to
+// the desired state here:
+if (is_coalesced.has_value()) {
+impl->set_coalesced(*is_coalesced);
+}
+// TODO: alias_into_sparse sets the coalesce flag to
+// `self._values().shape[0] < 2`. There exist methods (e.g. permute
+// on COO tensors when `dims[0] != 0` holds) that force coalesced
+// flag to false even when nnz is less than 2. Here we cannot
+// determine if this is the intention of such methods but it is
+// likely that these methods are overly restrictive when estimating
+// is_coalesced state. The condition `!is_coalesced && self._nnz() <
+// 2` provides a way to detect and optimize such methods with
+// respect to estimating the is_coalesced state.
return self;
}

@@ -255,7 +271,8 @@ Tensor sparse_coo_tensor(const Tensor& indices, const Tensor& values_,
c10::optional<ScalarType> dtype,
c10::optional<Layout> layout,
c10::optional<Device> device,
-c10::optional<bool> pin_memory) {
+c10::optional<bool> pin_memory,
+c10::optional<bool> is_coalesced) {
// See [Note: hacky wrapper removal for TensorOptions]
TensorOptions options = TensorOptions().dtype(dtype).layout(layout).device(device).pinned_memory(pin_memory);

@@ -327,14 +344,17 @@ Tensor sparse_coo_tensor(const Tensor& indices, const Tensor& values_,
computed_sizes,
indices,
values,
-values.options().layout(kSparse));
+values.options().layout(kSparse),
+is_coalesced);
}

void _validate_sparse_coo_tensor_args(
const Tensor& indices,
const Tensor& values_,
-ArrayRef<int64_t> size) {
+ArrayRef<int64_t> size,
+c10::optional<bool> is_coalesced_) {
Tensor values = expand_values_if_needed(values_);
+bool is_coalesced = is_coalesced_.value_or(false);

// the following checks are redundant because they are also checked in
// SparseTensorImpl::set_indices_and_values_unsafe but we need to ensure them
@@ -395,16 +415,21 @@ void _validate_sparse_coo_tensor_args(
" but found index ",
max_index_in_dim);
}
+if (is_coalesced && values.size(0) > 1) {
+Tensor indices_scalar = flatten_indices(indices, size);
+Tensor diff = indices_scalar.diff();
+TORCH_CHECK(diff.min().item().toLong() > 0, "cannot set is_coalesced to true if indices correspond to uncoalesced COO tensor");
+}
}
}

// NB: Got rid of the sizes == NULL case

Tensor sparse_coo_tensor(const Tensor& indices, const Tensor& values, IntArrayRef size,
c10::optional<ScalarType> dtype,
c10::optional<Layout> layout,
c10::optional<Device> device,
-c10::optional<bool> pin_memory) {
+c10::optional<bool> pin_memory,
+c10::optional<bool> is_coalesced) {
// See [Note: hacky wrapper removal for TensorOptions]
TensorOptions options = TensorOptions().dtype(dtype).layout(layout).device(device).pinned_memory(pin_memory);
// arg checking
@@ -419,18 +444,20 @@ Tensor sparse_coo_tensor(const Tensor& indices, const Tensor& values, IntArrayRe
optTypeMetaToScalarType(options.dtype_opt()),
options.layout_opt(),
options.device_opt(),
-options.pinned_memory_opt());
+options.pinned_memory_opt(),
+is_coalesced);
}

Tensor _sparse_coo_tensor_unsafe(const Tensor& indices, const Tensor& values_, at::IntArrayRef size,
c10::optional<ScalarType> dtype,
c10::optional<Layout> layout,
c10::optional<Device> device,
-c10::optional<bool> pin_memory) {
+c10::optional<bool> pin_memory,
+c10::optional<bool> is_coalesced) {
if (at::globalContext().checkSparseTensorInvariants()) {
-at::native::_validate_sparse_coo_tensor_args(indices, values_, size);
+at::native::_validate_sparse_coo_tensor_args(indices, values_, size, is_coalesced);
}
-return at::native::_sparse_coo_tensor_unsafe_symint(indices, values_, c10::fromIntArrayRefSlow(size), dtype, layout, device, pin_memory);
+return at::native::_sparse_coo_tensor_unsafe_symint(indices, values_, c10::fromIntArrayRefSlow(size), dtype, layout, device, pin_memory, is_coalesced);
}

// NOTE: _sparse_coo_tensor_unsafe() differs from sparse_coo_tensor()
Expand All @@ -443,7 +470,8 @@ Tensor _sparse_coo_tensor_unsafe_symint(const Tensor& indices, const Tensor& val
c10::optional<ScalarType> dtype,
c10::optional<Layout> layout,
c10::optional<Device> device,
-c10::optional<bool> pin_memory) {
+c10::optional<bool> pin_memory,
+c10::optional<bool> is_coalesced) {
// See [Note: hacky wrapper removal for TensorOptions]

Tensor values = expand_values_if_needed(values_);
@@ -459,7 +487,8 @@ Tensor _sparse_coo_tensor_unsafe_symint(const Tensor& indices, const Tensor& val
size,
indices,
values,
-values.options().layout(kSparse));
+values.options().layout(kSparse),
+is_coalesced);
}

// NB: Deleted newWithSizeNd variants
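The validation added in `_validate_sparse_coo_tensor_args` accepts `is_coalesced=True` only when the flattened indices are strictly increasing. Roughly, in Python terms (a sketch of the C++ check, not the actual implementation):
```python
import torch

def check_is_coalesced(indices, sparse_size):
    # Mirrors flatten_indices(): linearize multi-dim indices row-major.
    flat = torch.zeros_like(indices[0])
    for dim, extent in enumerate(sparse_size):
        flat = flat * extent + indices[dim]
    # Coalesced means sorted with no duplicates, i.e. strictly increasing.
    if flat.numel() > 1 and not bool((flat.diff() > 0).all()):
        raise RuntimeError(
            "cannot set is_coalesced to true if indices correspond "
            "to uncoalesced COO tensor")
```
The public entry point is `torch._validate_sparse_coo_tensor_args(indices, values, size, is_coalesced)`, which runs this check (among others) when invariant checking is enabled.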
9 changes: 6 additions & 3 deletions aten/src/ATen/native/sparse/SparseTensorMath.cpp
@@ -1712,7 +1712,8 @@ Tensor _sparse_sum(const SparseTensor& input, IntArrayRef dims_to_sum) {
if (sum_all_sparse_dim) new_sizes.emplace(new_sizes.begin(), 1);

// use coalesce() to do sum reduction
-SparseTensor new_sparse = at::_sparse_coo_tensor_with_dims_and_tensors(new_sparse_dim, new_dense_dim, new_sizes, new_indices, new_values, input.options());
+bool is_coalesced = false; // TODO: can we use input.is_coalesced()?
+SparseTensor new_sparse = at::_sparse_coo_tensor_with_dims_and_tensors(new_sparse_dim, new_dense_dim, new_sizes, new_indices, new_values, input.options(), is_coalesced);
new_sparse = new_sparse.coalesce();
return new_sparse;
}
@@ -1814,7 +1815,8 @@ Tensor _sparse_sum_backward_cpu(const Tensor& grad_, const SparseTensor& input_,
grad_input_values = grad_input_values.expand(dense_expand_size);
}
grad_input_values = grad_input_values.expand(expand_size).clone(at::MemoryFormat::Contiguous);
-return at::_sparse_coo_tensor_with_dims_and_tensors(input_sparse_dim, input_dense_dim, input_sizes, input_indices.clone(at::MemoryFormat::Contiguous), grad_input_values, input.options().dtype(grad_.dtype())); // convert to grad dtype
+bool grad_is_coalesced = input.is_coalesced();
+return at::_sparse_coo_tensor_with_dims_and_tensors(input_sparse_dim, input_dense_dim, input_sizes, input_indices.clone(at::MemoryFormat::Contiguous), grad_input_values, input.options().dtype(grad_.dtype()), grad_is_coalesced); // convert to grad dtype
}
else {
TORCH_CHECK(grad_.is_sparse(), "_sparse_sum_backward_cpu: expected grad_ Tensor to be sparse, but got dense");
@@ -1890,7 +1892,8 @@ Tensor _sparse_sum_backward_cpu(const Tensor& grad_, const SparseTensor& input_,
else {
grad_input_values = grad_values_expand;
}
-return at::_sparse_coo_tensor_with_dims_and_tensors(input_sparse_dim, input_dense_dim, input_sizes, input_indices.clone(at::MemoryFormat::Contiguous), grad_input_values, grad.options());
+bool grad_is_coalesced = input.is_coalesced();
+return at::_sparse_coo_tensor_with_dims_and_tensors(input_sparse_dim, input_dense_dim, input_sizes, input_indices.clone(at::MemoryFormat::Contiguous), grad_input_values, grad.options(), grad_is_coalesced);
}
}

4 changes: 2 additions & 2 deletions aten/src/ATen/native/sparse/SparseUnaryOps.cpp
@@ -87,8 +87,8 @@ Tensor coalesced_unary_ufunc(const Tensor &self, const Ufunc &ufunc) {
input.sizes(),
input.indices().clone(),
out_values,
-input.options().dtype(out_values.scalar_type()));
-result._coalesced_(true);
+input.options().dtype(out_values.scalar_type()),
+/*is_coalesced=*/ true);
return result;
}

21 changes: 21 additions & 0 deletions test/test_sparse.py
@@ -481,6 +481,27 @@ def test_ctor_size_checks(self, device, dtype):
RuntimeError,
lambda: self.sparse_tensor(indices, values, torch.Size([2, 4, 2, 1])))

+@coalescedonoff
+@dtypes(torch.double)
+def test_ctor_is_coalesced_with_gradcheck(self, device, dtype, coalesced):
+for sparse_size, nnz in (((3, 3), 5), ((2, 3, 1, 5), 11)):
+t, _, _ = self._gen_sparse(len(sparse_size), nnz, sparse_size, dtype, device, coalesced)
+self.assertEqual(t.is_coalesced(), coalesced)
+
+def func(indices, values, shape, is_coalesced):
+s = torch.sparse_coo_tensor(indices, values, shape, check_invariants=True, is_coalesced=is_coalesced)
+self.assertEqual(s.is_coalesced(), is_coalesced)
+return s.to_dense(masked_grad=False)
+
+if coalesced:
+torch.autograd.gradcheck(func, (t._indices(), t._values().requires_grad_(True), t.shape, False))
+torch.autograd.gradcheck(func, (t._indices(), t._values().requires_grad_(True), t.shape, True))
+else:
+torch.autograd.gradcheck(func, (t._indices(), t._values().requires_grad_(True), t.shape, False))
+with self.assertRaisesRegex(RuntimeError,
+"cannot set is_coalesced to true if indices correspond to uncoalesced COO tensor"):
+torch.autograd.gradcheck(func, (t._indices(), t._values().requires_grad_(True), t.shape, True))
+
@dtypes(*floating_and_complex_types_and(torch.float16, torch.bfloat16))
@unittest.skipIf(TEST_WITH_CROSSREF, "generator unsupport triggers assertion error")
@gradcheck_semantics()
3 changes: 2 additions & 1 deletion tools/autograd/derivatives.yaml
@@ -1873,7 +1873,8 @@
self: sparse_mask_backward(grad, mask, self.layout())
mask: non_differentiable

-- name: _sparse_coo_tensor_with_dims_and_tensors(int sparse_dim, int dense_dim, SymInt[] size, Tensor indices, Tensor values, *, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=False) -> Tensor
+- name: _sparse_coo_tensor_with_dims_and_tensors(int sparse_dim, int dense_dim, SymInt[] size, Tensor indices, Tensor values, *, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=False, bool? is_coalesced=None) -> Tensor
+indices: non_differentiable
values: grad.sparse_mask(result)._values()

- name: sparse_compressed_tensor.comp_plain_value_size(Tensor compressed_indices, Tensor plain_indices, Tensor values, int[] size, *, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=False) -> Tensor
1 change: 1 addition & 0 deletions tools/pyi/gen_pyi.py
@@ -665,6 +665,7 @@ def gen_pyi(
"device: Union[_device, str, None] = None",
"requires_grad: _bool = False",
"check_invariants: Optional[_bool] = None",
"is_coalesced: Optional[_bool] = None",
]
)
)
17 changes: 14 additions & 3 deletions torch/_torch_docs.py
@@ -10515,16 +10515,18 @@

add_docstr(
torch.sparse_coo_tensor,
r"""
sparse_coo_tensor(indices, values, size=None, *, dtype=None, device=None, requires_grad=False, check_invariants=None) -> Tensor
r"""sparse_coo_tensor(indices, values, size=None, """
r"""*, dtype=None, device=None, requires_grad=False, check_invariants=None, is_coalesced=None) -> Tensor

Constructs a :ref:`sparse tensor in COO(rdinate) format
<sparse-coo-docs>` with specified values at the given
:attr:`indices`.

.. note::

-This function returns an :ref:`uncoalesced tensor <sparse-uncoalesced-coo-docs>`.
+This function returns an :ref:`uncoalesced tensor
+<sparse-uncoalesced-coo-docs>` when :attr:`is_coalesced` is
+unspecified or ``None``.

{sparse_factory_device_note}

@@ -10549,6 +10551,15 @@ def merge_dicts(*dicts):
for CPU tensor types and the current CUDA device for CUDA tensor types.
{requires_grad}
{check_invariants}
+is_coalesced (bool, optional): When ``True``, the caller is
+responsible for providing tensor indices that correspond to a
+coalesced tensor. If the :attr:`check_invariants` flag is
+False, no error will be raised if the prerequisites are not
+met and this will lead to silently incorrect results. To force
+coalescedness, use :meth:`coalesce` on the resulting
+Tensor.
+Default: ``None``: except for trivial cases (e.g. nnz < 2) the
+resulting Tensor has is_coalesced set to ``False``.

Example::

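The default semantics described in the docstring can be seen directly (illustrative values):
```python
import torch

# nnz < 2 is trivially coalesced, so the flag comes out True by default:
t1 = torch.sparse_coo_tensor([[0]], [1.0], (3,))
print(t1.is_coalesced())  # True

# For nnz >= 2 the factory does not inspect the indices by default:
t2 = torch.sparse_coo_tensor([[0, 2]], [1.0, 2.0], (3,))
print(t2.is_coalesced())  # False, even though these indices are sorted
```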
8 changes: 4 additions & 4 deletions torch/_utils.py
@@ -229,7 +229,7 @@ def _validate_loaded_sparse_tensors():
for t in _sparse_tensors_to_validate:
if t.layout is torch.sparse_coo:
torch._validate_sparse_coo_tensor_args(
-t._indices(), t._values(), t.size()
+t._indices(), t._values(), t.size(), t.is_coalesced()
)
elif t.layout in {
torch.sparse_csr,
@@ -276,9 +276,9 @@ def _rebuild_sparse_tensor(layout, data):
is_coalesced = None
else:
indices, values, size, is_coalesced = data
-result = torch.sparse_coo_tensor(indices, values, size, check_invariants=False)
-if is_coalesced is not None:
-result._coalesced_(is_coalesced)
+result = torch.sparse_coo_tensor(
+indices, values, size, check_invariants=False, is_coalesced=is_coalesced
+)
_sparse_tensors_to_validate.append(result)
return result

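With the `_rebuild_sparse_tensor` change, a deserialized COO tensor regains its coalesced state through the factory rather than via an in-place patch. A quick round-trip sketch:
```python
import io
import torch

s = torch.sparse_coo_tensor([[0, 1], [1, 0]], [1.0, 2.0], (2, 2)).coalesce()

buf = io.BytesIO()
torch.save(s, buf)
buf.seek(0)
loaded = torch.load(buf)

# _rebuild_sparse_tensor passes is_coalesced straight to the factory, and
# _validate_loaded_sparse_tensors now validates the flag as well.
assert loaded.is_coalesced()
```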