diff --git a/dgl_sparse/CMakeLists.txt b/dgl_sparse/CMakeLists.txt
index 49a29ac675d2..b4d9e03a8368 100644
--- a/dgl_sparse/CMakeLists.txt
+++ b/dgl_sparse/CMakeLists.txt
@@ -1,3 +1,5 @@
+project(dgl_sparse C CXX)
+
 # Find PyTorch cmake files and PyTorch versions with the python interpreter $TORCH_PYTHON_INTERPS
 # ("python3" or "python" if empty)
 if(NOT TORCH_PYTHON_INTERPS)
diff --git a/dgl_sparse/include/sparse/spspmm.h b/dgl_sparse/include/sparse/spspmm.h
new file mode 100644
index 000000000000..b4e856bf50fb
--- /dev/null
+++ b/dgl_sparse/include/sparse/spspmm.h
@@ -0,0 +1,37 @@
+/**
+ * Copyright (c) 2022 by Contributors
+ * @file sparse/spspmm.h
+ * @brief DGL C++ SpSpMM operator.
+ */
+#ifndef SPARSE_SPSPMM_H_
+#define SPARSE_SPSPMM_H_
+
+#include <sparse/sparse_matrix.h>
+#include <torch/script.h>
+
+namespace dgl {
+namespace sparse {
+
+/**
+ * @brief Perform a sparse-sparse matrix multiplication on matrices with
+ * possibly different sparsities. The two sparse matrices must have
+ * 1-D values. If the first sparse matrix has shape (n, m), the second
+ * sparse matrix must have shape (m, k), and the returned sparse matrix has
+ * shape (n, k).
+ *
+ * This function supports autograd for both sparse matrices but does
+ * not support higher-order gradients.
+ *
+ * @param lhs_mat The first sparse matrix of shape (n, m).
+ * @param rhs_mat The second sparse matrix of shape (m, k).
+ *
+ * @return Sparse matrix of shape (n, k).
+ */
+c10::intrusive_ptr<SparseMatrix> SpSpMM(
+    const c10::intrusive_ptr<SparseMatrix>& lhs_mat,
+    const c10::intrusive_ptr<SparseMatrix>& rhs_mat);
+
+}  // namespace sparse
+}  // namespace dgl
+
+#endif  // SPARSE_SPSPMM_H_
diff --git a/dgl_sparse/src/matmul.cc b/dgl_sparse/src/matmul.cc
index 17569fc07b48..703bf2a629c7 100644
--- a/dgl_sparse/src/matmul.cc
+++ b/dgl_sparse/src/matmul.cc
@@ -98,5 +98,35 @@ torch::Tensor SDDMMNoAutoGrad(
   return ret;
 }
 
+c10::intrusive_ptr<SparseMatrix> SpSpMMNoAutoGrad(
+    const c10::intrusive_ptr<SparseMatrix>& lhs_mat, torch::Tensor lhs_val,
+    const c10::intrusive_ptr<SparseMatrix>& rhs_mat, torch::Tensor rhs_val,
+    bool lhs_transpose, bool rhs_transpose) {
+  aten::CSRMatrix lhs_dgl_csr, rhs_dgl_csr;
+  if (!lhs_transpose) {
+    lhs_dgl_csr = CSRToOldDGLCSR(lhs_mat->CSRPtr());
+  } else {
+    lhs_dgl_csr = CSRToOldDGLCSR(lhs_mat->CSCPtr());
+  }
+  if (!rhs_transpose) {
+    rhs_dgl_csr = CSRToOldDGLCSR(rhs_mat->CSRPtr());
+  } else {
+    rhs_dgl_csr = CSRToOldDGLCSR(rhs_mat->CSCPtr());
+  }
+  auto lhs_dgl_val = TorchTensorToDGLArray(lhs_val);
+  auto rhs_dgl_val = TorchTensorToDGLArray(rhs_val);
+  const int64_t ret_row =
+      lhs_transpose ? lhs_mat->shape()[1] : lhs_mat->shape()[0];
+  const int64_t ret_col =
+      rhs_transpose ? rhs_mat->shape()[0] : rhs_mat->shape()[1];
+  std::vector<int64_t> ret_shape({ret_row, ret_col});
+  aten::CSRMatrix ret_dgl_csr;
+  runtime::NDArray ret_val;
+  std::tie(ret_dgl_csr, ret_val) =
+      aten::CSRMM(lhs_dgl_csr, lhs_dgl_val, rhs_dgl_csr, rhs_dgl_val);
+  return SparseMatrix::FromCSR(
+      CSRFromOldDGLCSR(ret_dgl_csr), DGLArrayToTorchTensor(ret_val), ret_shape);
+}
+
 }  // namespace sparse
 }  // namespace dgl
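Note (reviewer sketch, not part of the patch): `SpSpMMNoAutoGrad` never materializes an explicit transpose. When a transpose is requested it simply hands the operand's CSC representation to `CSRToOldDGLCSR`, relying on the fact that the CSC arrays of a matrix, read as CSR arrays, describe the transposed matrix. A self-contained illustration of that equivalence using SciPy (the SciPy usage below is mine, purely for illustration):

```python
# The CSC arrays of a (4, 3) matrix, reinterpreted as the CSR arrays of a
# (3, 4) matrix, describe exactly the transpose.
import numpy as np
import scipy.sparse as sp

A = sp.random(4, 3, density=0.5, format="csc", random_state=0)
A_T = sp.csr_matrix((A.data, A.indices, A.indptr), shape=(3, 4))
assert np.allclose(A_T.toarray(), A.toarray().T)
```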
diff --git a/dgl_sparse/src/matmul.h b/dgl_sparse/src/matmul.h
index 58a1270dcfb1..515376526ff4 100644
--- a/dgl_sparse/src/matmul.h
+++ b/dgl_sparse/src/matmul.h
@@ -53,6 +53,28 @@ torch::Tensor SDDMMNoAutoGrad(
     const c10::intrusive_ptr<SparseMatrix>& sparse_mat, torch::Tensor mat1,
     torch::Tensor mat2_tr);
 
+/**
+ * @brief Perform a sparse-sparse matrix multiplication with possibly different
+ * sparsities. The two sparse matrices must have 1-dimensional values. If the
+ * first sparse matrix has shape (n, m), the second sparse matrix must have
+ * shape (m, k), and the returned sparse matrix has shape (n, k).
+ *
+ * This function does not take care of autograd.
+ *
+ * @param lhs_mat The first sparse matrix of shape (n, m).
+ * @param lhs_val Sparse value for the first sparse matrix.
+ * @param rhs_mat The second sparse matrix of shape (m, k).
+ * @param rhs_val Sparse value for the second sparse matrix.
+ * @param lhs_transpose Whether the first matrix is transposed.
+ * @param rhs_transpose Whether the second matrix is transposed.
+ *
+ * @return Sparse matrix of shape (n, k).
+ */
+c10::intrusive_ptr<SparseMatrix> SpSpMMNoAutoGrad(
+    const c10::intrusive_ptr<SparseMatrix>& lhs_mat, torch::Tensor lhs_val,
+    const c10::intrusive_ptr<SparseMatrix>& rhs_mat, torch::Tensor rhs_val,
+    bool lhs_transpose, bool rhs_transpose);
+
 }  // namespace sparse
 }  // namespace dgl
diff --git a/dgl_sparse/src/python_binding.cc b/dgl_sparse/src/python_binding.cc
index dc65ed8df59f..4079b8c273b4 100644
--- a/dgl_sparse/src/python_binding.cc
+++ b/dgl_sparse/src/python_binding.cc
@@ -12,6 +12,7 @@
 #include <sparse/sddmm.h>
 #include <sparse/sparse_matrix.h>
 #include <sparse/spmm.h>
+#include <sparse/spspmm.h>
 #include <torch/custom_class.h>
 #include <torch/script.h>
@@ -40,7 +41,8 @@ TORCH_LIBRARY(dgl_sparse, m) {
       .def("sprod", &ReduceProd)
       .def("val_like", &CreateValLike)
       .def("spmm", &SpMM)
-      .def("sddmm", &SDDMM);
+      .def("sddmm", &SDDMM)
+      .def("spspmm", &SpSpMM);
 }
 
 }  // namespace sparse
diff --git a/dgl_sparse/src/sddmm.cc b/dgl_sparse/src/sddmm.cc
index 104416fd3c11..64ce43641cfe 100644
--- a/dgl_sparse/src/sddmm.cc
+++ b/dgl_sparse/src/sddmm.cc
@@ -95,6 +95,8 @@ c10::intrusive_ptr<SparseMatrix> SDDMM(
     torch::Tensor mat2) {
   if (mat1.dim() == 1) {
     mat1 = mat1.view({mat1.size(0), 1});
+  }
+  if (mat2.dim() == 1) {
     mat2 = mat2.view({1, mat2.size(0)});
   }
   _SDDMMSanityCheck(sparse_mat, mat1, mat2);
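Note (reviewer sketch, not part of the patch): the `sddmm.cc` change splits the single reshape branch in two, so a 1-D `mat1` and a 1-D `mat2` are promoted independently — `mat1` to an (M, 1) column and `mat2` to a (1, N) row — which turns their product into an outer product before the sparsity mask is applied. A dense illustration in plain PyTorch:

```python
import torch

mat1 = torch.randn(4)  # shape (M,)
mat2 = torch.randn(5)  # shape (N,)
# The same reshapes the patched SDDMM applies to 1-D inputs.
dense_product = mat1.view(-1, 1) @ mat2.view(1, -1)  # shape (M, N)
assert torch.allclose(dense_product, torch.outer(mat1, mat2))
```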
diff --git a/dgl_sparse/src/spspmm.cc b/dgl_sparse/src/spspmm.cc
new file mode 100644
index 000000000000..36cfed934301
--- /dev/null
+++ b/dgl_sparse/src/spspmm.cc
@@ -0,0 +1,123 @@
+/**
+ * Copyright (c) 2022 by Contributors
+ * @file spspmm.cc
+ * @brief DGL C++ sparse SpSpMM operator implementation.
+ */
+
+#include <sparse/sparse_matrix.h>
+#include <sparse/spspmm.h>
+#include <torch/custom_class.h>
+#include <torch/script.h>
+
+#include "./matmul.h"
+#include "./utils.h"
+
+namespace dgl {
+namespace sparse {
+
+using namespace torch::autograd;
+
+class SpSpMMAutoGrad : public Function<SpSpMMAutoGrad> {
+ public:
+  static variable_list forward(
+      AutogradContext* ctx, c10::intrusive_ptr<SparseMatrix> lhs_mat,
+      torch::Tensor lhs_val, c10::intrusive_ptr<SparseMatrix> rhs_mat,
+      torch::Tensor rhs_val);
+
+  static tensor_list backward(AutogradContext* ctx, tensor_list grad_outputs);
+};
+
+void _SpSpMMSanityCheck(
+    const c10::intrusive_ptr<SparseMatrix>& lhs_mat,
+    const c10::intrusive_ptr<SparseMatrix>& rhs_mat) {
+  const auto& lhs_shape = lhs_mat->shape();
+  const auto& rhs_shape = rhs_mat->shape();
+  CHECK_EQ(lhs_shape[1], rhs_shape[0])
+      << "SpSpMM: the second dim of lhs_mat should be equal to the first dim "
+         "of rhs_mat";
+  CHECK_EQ(lhs_mat->value().dim(), 1)
+      << "SpSpMM: the value shape of lhs_mat should be 1-D";
+  CHECK_EQ(rhs_mat->value().dim(), 1)
+      << "SpSpMM: the value shape of rhs_mat should be 1-D";
+  CHECK_EQ(lhs_mat->device(), rhs_mat->device())
+      << "SpSpMM: lhs_mat and rhs_mat should be on the same device";
+  CHECK_EQ(lhs_mat->dtype(), rhs_mat->dtype())
+      << "SpSpMM: lhs_mat and rhs_mat should have the same dtype";
+}
+
+// Mask-select the values of `mat` at the nonzero positions of `sub_mat`.
+torch::Tensor _CSRMask(
+    const c10::intrusive_ptr<SparseMatrix>& mat, torch::Tensor value,
+    const c10::intrusive_ptr<SparseMatrix>& sub_mat) {
+  auto csr = CSRToOldDGLCSR(mat->CSRPtr());
+  auto val = TorchTensorToDGLArray(value);
+  auto row = TorchTensorToDGLArray(sub_mat->COOPtr()->row);
+  auto col = TorchTensorToDGLArray(sub_mat->COOPtr()->col);
+  runtime::NDArray ret;
+  ATEN_FLOAT_TYPE_SWITCH(val->dtype, DType, "Value Type", {
+    ret = aten::CSRGetData<DType>(csr, row, col, val, 0.);
+  });
+  return DGLArrayToTorchTensor(ret);
+}
+
+variable_list SpSpMMAutoGrad::forward(
+    AutogradContext* ctx, c10::intrusive_ptr<SparseMatrix> lhs_mat,
+    torch::Tensor lhs_val, c10::intrusive_ptr<SparseMatrix> rhs_mat,
+    torch::Tensor rhs_val) {
+  auto ret_mat =
+      SpSpMMNoAutoGrad(lhs_mat, lhs_val, rhs_mat, rhs_val, false, false);
+
+  ctx->saved_data["lhs_mat"] = lhs_mat;
+  ctx->saved_data["rhs_mat"] = rhs_mat;
+  ctx->saved_data["ret_mat"] = ret_mat;
+  ctx->saved_data["lhs_require_grad"] = lhs_val.requires_grad();
+  ctx->saved_data["rhs_require_grad"] = rhs_val.requires_grad();
+  ctx->save_for_backward({lhs_val, rhs_val});
+
+  auto csr = ret_mat->CSRPtr();
+  auto val = ret_mat->value();
+  CHECK(!csr->value_indices.has_value());
+  return {csr->indptr, csr->indices, val};
+}
+
+tensor_list SpSpMMAutoGrad::backward(
+    AutogradContext* ctx, tensor_list grad_outputs) {
+  auto saved = ctx->get_saved_variables();
+  auto lhs_val = saved[0];
+  auto rhs_val = saved[1];
+  auto output_grad = grad_outputs[2];
+  auto lhs_mat = ctx->saved_data["lhs_mat"].toCustomClass<SparseMatrix>();
+  auto rhs_mat = ctx->saved_data["rhs_mat"].toCustomClass<SparseMatrix>();
+  auto ret_mat = ctx->saved_data["ret_mat"].toCustomClass<SparseMatrix>();
+  torch::Tensor lhs_val_grad, rhs_val_grad;
+
+  if (ctx->saved_data["lhs_require_grad"].toBool()) {
+    // A @ B = C -> dA = dC @ (B^T)
+    auto lhs_mat_grad =
+        SpSpMMNoAutoGrad(ret_mat, output_grad, rhs_mat, rhs_val, false, true);
+    lhs_val_grad = _CSRMask(lhs_mat_grad, lhs_mat_grad->value(), lhs_mat);
+  }
+  if (ctx->saved_data["rhs_require_grad"].toBool()) {
+    // A @ B = C -> dB = (A^T) @ dC
+    auto rhs_mat_grad =
+        SpSpMMNoAutoGrad(lhs_mat, lhs_val, ret_mat, output_grad, true, false);
+    rhs_val_grad = _CSRMask(rhs_mat_grad, rhs_mat_grad->value(), rhs_mat);
+  }
+  return {torch::Tensor(), lhs_val_grad, torch::Tensor(), rhs_val_grad};
+}
+
+c10::intrusive_ptr<SparseMatrix> SpSpMM(
+    const c10::intrusive_ptr<SparseMatrix>& lhs_mat,
+    const c10::intrusive_ptr<SparseMatrix>& rhs_mat) {
+  _SpSpMMSanityCheck(lhs_mat, rhs_mat);
+  auto results = SpSpMMAutoGrad::apply(
+      lhs_mat, lhs_mat->value(), rhs_mat, rhs_mat->value());
+  std::vector<int64_t> ret_shape({lhs_mat->shape()[0], rhs_mat->shape()[1]});
+  auto indptr = results[0];
+  auto indices = results[1];
+  auto value = results[2];
+  return CreateFromCSR(indptr, indices, value, ret_shape);
+}
+
+}  // namespace sparse
+}  // namespace dgl
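Note (reviewer sketch, not part of the patch): the backward pass above applies the dense identities dA = dC @ B^T and dB = A^T @ dC, then uses `_CSRMask` to keep only the entries at the stored nonzeros of A and B, since only those values are parameters. A self-contained dense PyTorch check of that rule (the masks stand in for the sparsity patterns):

```python
import torch

torch.manual_seed(0)
mask_A = (torch.rand(3, 4) < 0.5).float()  # sparsity pattern of A
mask_B = (torch.rand(4, 5) < 0.5).float()  # sparsity pattern of B
A = torch.randn(3, 4, requires_grad=True)
B = torch.randn(4, 5, requires_grad=True)

C = (A * mask_A) @ (B * mask_B)
grad_C = torch.randn_like(C)
C.backward(grad_C)

# dA = (dC @ B^T) restricted to A's pattern; dB = (A^T @ dC) restricted to B's.
assert torch.allclose(A.grad, (grad_C @ (B * mask_B).t()) * mask_A)
assert torch.allclose(B.grad, ((A * mask_A).t() @ grad_C) * mask_B)
```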
diff --git a/python/dgl/mock_sparse2/matmul.py b/python/dgl/mock_sparse2/matmul.py
index b54129151ec0..16235c4ee1bb 100644
--- a/python/dgl/mock_sparse2/matmul.py
+++ b/python/dgl/mock_sparse2/matmul.py
@@ -4,11 +4,11 @@
 
 import torch
 
-from .diag_matrix import DiagMatrix
+from .diag_matrix import diag, DiagMatrix
 
 from .sparse_matrix import SparseMatrix
 
-__all__ = ["spmm"]
+__all__ = ["spmm", "spspmm", "mm"]
 
 
 def spmm(A: Union[SparseMatrix, DiagMatrix], X: torch.Tensor) -> torch.Tensor:
@@ -53,51 +53,138 @@ def spmm(A: Union[SparseMatrix, DiagMatrix], X: torch.Tensor) -> torch.Tensor:
     return torch.ops.dgl_sparse.spmm(A.c_sparse_matrix, X)
 
 
-def mm_sp(
-    A1: SparseMatrix, A2: Union[torch.Tensor, SparseMatrix, DiagMatrix]
-) -> Union[torch.Tensor, SparseMatrix]:
-    """Internal function for multiplying a sparse matrix by
-    a dense/sparse/diagonal matrix.
+def _diag_diag_mm(A1: DiagMatrix, A2: DiagMatrix) -> DiagMatrix:
+    """Internal function for multiplying a diagonal matrix by a diagonal matrix.
 
     Parameters
     ----------
-    A1 : SparseMatrix
+    A1 : DiagMatrix
         Matrix of shape (N, M), with values of shape (nnz1)
-    A2 : torch.Tensor, SparseMatrix, or DiagMatrix
-        If A2 is a dense tensor, it can have shapes of (M, P) or (M, ).
-        Otherwise it must have a shape of (M, P).
+    A2 : DiagMatrix
+        Matrix of shape (M, P), with values of shape (nnz2)
 
     Returns
     -------
-    torch.Tensor or SparseMatrix
+    DiagMatrix
         The result of multiplication.
+    """
+    M, N = A1.shape
+    N, P = A2.shape
+    common_diag_len = min(M, N, P)
+    new_diag_len = min(M, P)
+    diag_val = torch.zeros(new_diag_len)
+    diag_val[:common_diag_len] = (
+        A1.val[:common_diag_len] * A2.val[:common_diag_len]
+    )
+    return diag(diag_val.to(A1.device), (M, P))
+
+
+def spspmm(
+    A1: Union[SparseMatrix, DiagMatrix], A2: Union[SparseMatrix, DiagMatrix]
+) -> Union[SparseMatrix, DiagMatrix]:
+    """Multiply a sparse matrix by a sparse matrix. The non-zero values of the
+    two sparse matrices must be 1-D.
+
+    Parameters
+    ----------
+    A1 : SparseMatrix or DiagMatrix
+        Sparse matrix of shape (N, M) with values of shape (nnz)
+    A2 : SparseMatrix or DiagMatrix
+        Sparse matrix of shape (M, P) with values of shape (nnz)
+
+    Returns
+    -------
+    SparseMatrix or DiagMatrix
+        The result of multiplication. It is a DiagMatrix object if both
+        matrices are DiagMatrix objects. It is a SparseMatrix object otherwise.
+
+    Examples
+    --------
+
+    >>> row1 = torch.tensor([0, 1, 1])
+    >>> col1 = torch.tensor([1, 0, 1])
+    >>> val1 = torch.ones(len(row1))
+    >>> A1 = create_from_coo(row1, col1, val1)
+    >>> row2 = torch.tensor([0, 1, 1])
+    >>> col2 = torch.tensor([0, 2, 1])
+    >>> val2 = torch.ones(len(row2))
+    >>> A2 = create_from_coo(row2, col2, val2)
+    >>> result = dgl.sparse.spspmm(A1, A2)
+    >>> print(result)
+    SparseMatrix(indices=tensor([[0, 0, 1, 1, 1],
+                                 [1, 2, 0, 1, 2]]),
+                 values=tensor([1., 1., 1., 1., 1.]),
+                 shape=(2, 3), nnz=5)
+    """
+    assert isinstance(
+        A1, (SparseMatrix, DiagMatrix)
+    ), f"Expect A1 to be a SparseMatrix or DiagMatrix object, got {type(A1)}"
+    assert isinstance(
+        A2, (SparseMatrix, DiagMatrix)
+    ), f"Expect A2 to be a SparseMatrix or DiagMatrix object, got {type(A2)}"
+
+    if isinstance(A1, DiagMatrix) and isinstance(A2, DiagMatrix):
+        return _diag_diag_mm(A1, A2)
+    if isinstance(A1, DiagMatrix):
+        A1 = A1.as_sparse()
+    if isinstance(A2, DiagMatrix):
+        A2 = A2.as_sparse()
+    return SparseMatrix(
+        torch.ops.dgl_sparse.spspmm(A1.c_sparse_matrix, A2.c_sparse_matrix)
+    )
+
+
+def mm(
+    A1: Union[SparseMatrix, DiagMatrix],
+    A2: Union[torch.Tensor, SparseMatrix, DiagMatrix],
+) -> Union[torch.Tensor, SparseMatrix, DiagMatrix]:
+    """Multiply a sparse/diagonal matrix by a dense/sparse/diagonal matrix.
+    If an input is a SparseMatrix or DiagMatrix, its non-zero values should
+    be 1-D.
+
+    Parameters
+    ----------
+    A1 : SparseMatrix or DiagMatrix
+        Matrix of shape (N, M), with values of shape (nnz1)
+    A2 : torch.Tensor, SparseMatrix, or DiagMatrix
+        Matrix of shape (M, P). If it is a SparseMatrix or DiagMatrix,
+        it should have values of shape (nnz2).
+
+    Returns
+    -------
+    torch.Tensor or DiagMatrix or SparseMatrix
+        The result of multiplication of shape (N, P).
+
+        * It is a dense torch tensor if :attr:`A2` is so.
+        * It is a DiagMatrix object if both :attr:`A1` and :attr:`A2` are so.
+        * It is a SparseMatrix object otherwise.
 
     Examples
     --------
 
-    >>> row = torch.tensor([0, 1, 1])
-    >>> col = torch.tensor([1, 0, 1])
-    >>> val = torch.randn(len(row))
-    >>> A1 = create_from_coo(row, col, val)
-    >>> A2 = torch.randn(2, 3)
-    >>> result = A1 @ A2
+    >>> val = torch.randn(3)
+    >>> A1 = diag(val)
+    >>> A2 = torch.randn(3, 2)
+    >>> result = dgl.sparse.mm(A1, A2)
     >>> print(type(result))
     <class 'torch.Tensor'>
     >>> print(result.shape)
-    torch.Size([2, 3])
+    torch.Size([3, 2])
     """
+    assert isinstance(
+        A1, (SparseMatrix, DiagMatrix)
+    ), f"Expect arg1 to be a SparseMatrix or DiagMatrix object, got {type(A1)}."
     assert isinstance(A2, (torch.Tensor, SparseMatrix, DiagMatrix)), (
-        f"Expect arg2 to be a torch Tensor, SparseMatrix, or DiagMatrix object,"
-        f"got {type(A2)}"
+        f"Expect arg2 to be a torch Tensor, SparseMatrix, or DiagMatrix "
+        f"object, got {type(A2)}."
     )
-
     if isinstance(A2, torch.Tensor):
         return spmm(A1, A2)
-    else:
-        raise NotImplementedError
+    if isinstance(A1, DiagMatrix) and isinstance(A2, DiagMatrix):
+        return _diag_diag_mm(A1, A2)
+    return spspmm(A1, A2)
 
 
-SparseMatrix.__matmul__ = mm_sp
+SparseMatrix.__matmul__ = mm
+DiagMatrix.__matmul__ = mm
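Note (reviewer sketch, not part of the patch): for non-square shapes, `_diag_diag_mm` only fills the first min(M, N, P) entries of the result diagonal and leaves the rest zero, which is exactly what the dense product of two rectangular diagonal matrices yields. A self-contained check with plain tensors:

```python
import torch

M, N, P = 4, 2, 5
d1 = torch.arange(1.0, min(M, N) + 1)  # diagonal values of the (M, N) factor
d2 = torch.arange(1.0, min(N, P) + 1)  # diagonal values of the (N, P) factor

idx1 = torch.arange(min(M, N))
idx2 = torch.arange(min(N, P))
dense1 = torch.zeros(M, N)
dense1[idx1, idx1] = d1
dense2 = torch.zeros(N, P)
dense2[idx2, idx2] = d2

# Only the first min(M, N, P) diagonal entries of the (M, P) product survive.
expected = torch.zeros(min(M, P))
expected[: min(M, N, P)] = d1[: min(M, N, P)] * d2[: min(M, N, P)]
assert torch.equal(torch.diagonal(dense1 @ dense2), expected)
```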
diff --git a/python/dgl/mock_sparse2/sddmm.py b/python/dgl/mock_sparse2/sddmm.py
index 34b86c5a1109..1ca6983709af 100644
--- a/python/dgl/mock_sparse2/sddmm.py
+++ b/python/dgl/mock_sparse2/sddmm.py
@@ -11,8 +11,8 @@ def sddmm(
 ) -> SparseMatrix:
     r"""Sampled-Dense-Dense Matrix Multiplication (SDDMM).
 
-    ``sddmm`` multiplies two dense matrices :attr:``mat1`` and :attr:``mat2``
-    at the nonzero locations of sparse matrix :attr:``A``. Values of :attr:``A``
+    ``sddmm`` multiplies two dense matrices :attr:`mat1` and :attr:`mat2`
+    at the nonzero locations of sparse matrix :attr:`A`. Values of :attr:`A`
     is not considered during the computation.
 
     Mathematically ``sddmm`` is formulated as:
@@ -20,19 +20,23 @@ def sddmm(
     .. math::
         out = (mat1 @ mat2) * A
 
+    In particular, :attr:`mat1` and :attr:`mat2` can be 1-D, in which case
+    ``mat1 @ mat2`` becomes the outer product of the two vectors (which
+    results in a matrix).
+
     Parameters
     ----------
     A : SparseMatrix
-        Sparse matrix of shape `(M, N)`.
+        Sparse matrix of shape ``(M, N)``.
     mat1 : Tensor
-        Dense matrix of shape `(M, K)`
+        Dense matrix of shape ``(M, K)`` or ``(M,)``
     mat2 : Tensor
-        Dense matrix of shape `(K, N)`
+        Dense matrix of shape ``(K, N)`` or ``(N,)``
 
     Returns
     -------
     SparseMatrix
-        Sparse matrix of shape `(M, N)`.
+        Sparse matrix of shape ``(M, N)``.
 
     Examples
     --------
diff --git a/tests/pytorch/mock_sparse2/test_matmul.py b/tests/pytorch/mock_sparse2/test_matmul.py
index e8372ec3166c..10c551203b22 100644
--- a/tests/pytorch/mock_sparse2/test_matmul.py
+++ b/tests/pytorch/mock_sparse2/test_matmul.py
@@ -51,3 +51,41 @@ def test_spmm(create_func, shape, nnz, out_dim):
         sparse_matrix_to_dense(val_like(A, A.val.grad)),
         atol=1e-05,
     )
+
+
+@pytest.mark.parametrize("create_func1", [rand_coo, rand_csr, rand_csc])
+@pytest.mark.parametrize("create_func2", [rand_coo, rand_csr, rand_csc])
+@pytest.mark.parametrize("shape_n_m", [(5, 5), (5, 6)])
+@pytest.mark.parametrize("shape_k", [3, 4])
+@pytest.mark.parametrize("nnz1", [1, 10])
+@pytest.mark.parametrize("nnz2", [1, 10])
+def test_sparse_sparse_mm(
+    create_func1, create_func2, shape_n_m, shape_k, nnz1, nnz2
+):
+    dev = F.ctx()
+    shape1 = shape_n_m
+    shape2 = (shape_n_m[1], shape_k)
+    A1 = create_func1(shape1, nnz1, dev)
+    A2 = create_func2(shape2, nnz2, dev)
+    A3 = A1 @ A2
+    grad = torch.randn_like(A3.val)
+    A3.val.backward(grad)
+
+    torch_A1 = sparse_matrix_to_torch_sparse(A1)
+    torch_A2 = sparse_matrix_to_torch_sparse(A2)
+    torch_A3 = torch.sparse.mm(torch_A1, torch_A2)
+    torch_A3_grad = sparse_matrix_to_torch_sparse(A3, grad)
+    torch_A3.backward(torch_A3_grad)
+
+    with torch.no_grad():
+        assert torch.allclose(A3.dense(), torch_A3.to_dense(), atol=1e-05)
+        assert torch.allclose(
+            val_like(A1, A1.val.grad).dense(),
+            torch_A1.grad.to_dense(),
+            atol=1e-05,
+        )
+        assert torch.allclose(
+            val_like(A2, A2.val.grad).dense(),
+            torch_A2.grad.to_dense(),
+            atol=1e-05,
+        )
diff --git a/tests/pytorch/mock_sparse2/test_sddmm.py b/tests/pytorch/mock_sparse2/test_sddmm.py
index f9e4bf90025d..f985cb3f7abc 100644
--- a/tests/pytorch/mock_sparse2/test_sddmm.py
+++ b/tests/pytorch/mock_sparse2/test_sddmm.py
@@ -13,7 +13,7 @@
 
 
 @pytest.mark.parametrize("create_func", [rand_coo, rand_csr, rand_csc])
-@pytest.mark.parametrize("shape", [(2, 3), (5, 2)])
+@pytest.mark.parametrize("shape", [(5, 5), (5, 4)])
 @pytest.mark.parametrize("nnz", [2, 10])
 @pytest.mark.parametrize("hidden", [1, 5])
 def test_sddmm(create_func, shape, nnz, hidden):
diff --git a/tests/pytorch/mock_sparse2/utils.py b/tests/pytorch/mock_sparse2/utils.py
index 2a16d36ea7e8..a38ffb787765 100644
--- a/tests/pytorch/mock_sparse2/utils.py
+++ b/tests/pytorch/mock_sparse2/utils.py
@@ -1,4 +1,6 @@
+import numpy as np
 import torch
+
 from dgl.mock_sparse2 import (
     create_from_coo,
     create_from_csc,
@@ -6,6 +8,9 @@
     SparseMatrix,
 )
 
+np.random.seed(42)
+torch.random.manual_seed(42)
+
 
 def clone_detach_and_grad(t):
     t = t.clone().detach()
@@ -14,33 +19,80 @@ def clone_detach_and_grad(t):
 
 
 def rand_coo(shape, nnz, dev):
-    row = torch.randint(0, shape[0], (nnz,), device=dev)
-    col = torch.randint(0, shape[1], (nnz,), device=dev)
+    # Create a sparse matrix without duplicate entries.
+    nnzid = np.random.choice(shape[0] * shape[1], nnz, replace=False)
+    nnzid = torch.tensor(nnzid, device=dev).long()
+    row = torch.div(nnzid, shape[1], rounding_mode="floor")
+    col = nnzid % shape[1]
     val = torch.randn(nnz, device=dev, requires_grad=True)
     return create_from_coo(row, col, val, shape)
 
 
 def rand_csr(shape, nnz, dev):
-    row = torch.randint(0, shape[0], (nnz,), device=dev)
-    col = torch.randint(0, shape[1], (nnz,), device=dev)
+    # Create a sparse matrix without duplicate entries.
+    nnzid = np.random.choice(shape[0] * shape[1], nnz, replace=False)
+    nnzid = torch.tensor(nnzid, device=dev).long()
+    row = torch.div(nnzid, shape[1], rounding_mode="floor")
+    col = nnzid % shape[1]
     val = torch.randn(nnz, device=dev, requires_grad=True)
     indptr = torch.zeros(shape[0] + 1, device=dev, dtype=torch.int64)
     for r in row.tolist():
         indptr[r + 1] += 1
     indptr = torch.cumsum(indptr, 0)
-    indices = col
+    row_sorted, row_sorted_idx = torch.sort(row)
+    indices = col[row_sorted_idx]
     return create_from_csr(indptr, indices, val, shape=shape)
 
 
 def rand_csc(shape, nnz, dev):
-    row = torch.randint(0, shape[0], (nnz,), device=dev)
-    col = torch.randint(0, shape[1], (nnz,), device=dev)
+    # Create a sparse matrix without duplicate entries.
+    nnzid = np.random.choice(shape[0] * shape[1], nnz, replace=False)
+    nnzid = torch.tensor(nnzid, device=dev).long()
+    row = torch.div(nnzid, shape[1], rounding_mode="floor")
+    col = nnzid % shape[1]
+    val = torch.randn(nnz, device=dev, requires_grad=True)
+    indptr = torch.zeros(shape[1] + 1, device=dev, dtype=torch.int64)
+    for c in col.tolist():
+        indptr[c + 1] += 1
+    indptr = torch.cumsum(indptr, 0)
+    col_sorted, col_sorted_idx = torch.sort(col)
+    indices = row[col_sorted_idx]
+    return create_from_csc(indptr, indices, val, shape=shape)
+
+
+def rand_coo_uncoalesced(shape, nnz, dev):
+    # Create a sparse matrix with possible duplicate entries.
+    row = torch.randint(shape[0], (nnz,), device=dev)
+    col = torch.randint(shape[1], (nnz,), device=dev)
+    val = torch.randn(nnz, device=dev, requires_grad=True)
+    return create_from_coo(row, col, val, shape)
+
+
+def rand_csr_uncoalesced(shape, nnz, dev):
+    # Create a sparse matrix with possible duplicate entries.
+    row = torch.randint(shape[0], (nnz,), device=dev)
+    col = torch.randint(shape[1], (nnz,), device=dev)
+    val = torch.randn(nnz, device=dev, requires_grad=True)
+    indptr = torch.zeros(shape[0] + 1, device=dev, dtype=torch.int64)
+    for r in row.tolist():
+        indptr[r + 1] += 1
+    indptr = torch.cumsum(indptr, 0)
+    row_sorted, row_sorted_idx = torch.sort(row)
+    indices = col[row_sorted_idx]
+    return create_from_csr(indptr, indices, val, shape=shape)
+
+
+def rand_csc_uncoalesced(shape, nnz, dev):
+    # Create a sparse matrix with possible duplicate entries.
+    row = torch.randint(shape[0], (nnz,), device=dev)
+    col = torch.randint(shape[1], (nnz,), device=dev)
     val = torch.randn(nnz, device=dev, requires_grad=True)
     indptr = torch.zeros(shape[1] + 1, device=dev, dtype=torch.int64)
     for c in col.tolist():
         indptr[c + 1] += 1
     indptr = torch.cumsum(indptr, 0)
-    indices = row
+    col_sorted, col_sorted_idx = torch.sort(col)
+    indices = row[col_sorted_idx]
    return create_from_csc(indptr, indices, val, shape=shape)
 
 
@@ -50,11 +102,13 @@ def sparse_matrix_to_dense(A: SparseMatrix):
     return dense
 
 
-def sparse_matrix_to_torch_sparse(A: SparseMatrix):
+def sparse_matrix_to_torch_sparse(A: SparseMatrix, val=None):
     row, col = A.coo()
     edge_index = torch.cat((row.unsqueeze(0), col.unsqueeze(0)), 0)
     shape = A.shape
-    val = A.val.clone().detach()
+    if val is None:
+        val = A.val
+    val = val.clone().detach()
     if len(A.val.shape) > 1:
         shape += (A.val.shape[-1],)
     ret = torch.sparse_coo_tensor(edge_index, val, shape).coalesce()
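Note (reviewer sketch, not part of the patch): the coalesced `rand_*` helpers above sample `nnz` distinct linear indices from the dense shape and unflatten them, so no (row, col) pair repeats and coalescing the reference torch sparse tensor leaves its nnz and gradients unchanged. A minimal standalone version of that sampling trick:

```python
import numpy as np
import torch

shape, nnz = (5, 6), 10
# Sample distinct linear indices, then recover (row, col) pairs.
linear = np.random.choice(shape[0] * shape[1], nnz, replace=False)
linear = torch.tensor(linear).long()
row = torch.div(linear, shape[1], rounding_mode="floor")
col = linear % shape[1]
# Every (row, col) pair is unique, so the matrix is already coalesced.
assert len({(r.item(), c.item()) for r, c in zip(row, col)}) == nnz
```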