[Sparse] Add SpSpMM for sparse-sparse and sparse-diag matrix multiplication. (dmlc#5050)

* [Sparse] Add SpSpMM

* Update matmul interface

* address comments

* fix test utils to generate only coalesced matrices

* fix linter

* fix ut

* fix

* rm print

Co-authored-by: Minjie Wang <[email protected]>
czkkkkkk and jermainewang authored Dec 24, 2022
1 parent f0ce2be commit 0159c3c
Showing 12 changed files with 444 additions and 43 deletions.
2 changes: 2 additions & 0 deletions dgl_sparse/CMakeLists.txt
@@ -1,3 +1,5 @@
project(dgl_sparse C CXX)

# Find PyTorch cmake files and PyTorch versions with the python interpreter $TORCH_PYTHON_INTERPS
# ("python3" or "python" if empty)
if(NOT TORCH_PYTHON_INTERPS)
37 changes: 37 additions & 0 deletions dgl_sparse/include/sparse/spspmm.h
@@ -0,0 +1,37 @@
/**
* Copyright (c) 2022 by Contributors
* @file sparse/spspmm.h
* @brief DGL C++ SpSpMM operator.
*/
#ifndef SPARSE_SPSPMM_H_
#define SPARSE_SPSPMM_H_

#include <sparse/sparse_matrix.h>
#include <torch/script.h>

namespace dgl {
namespace sparse {

/**
* @brief Perform a sparse-sparse matrix multiplication on matrices with
* possibly different sparsities. The two sparse matrices must have
* 1-D values. If the first sparse matrix has shape (n, m), the second
* sparse matrix must have shape (m, k), and the returned sparse matrix has
* shape (n, k).
*
* This function supports autograd for both sparse matrices but does
* not support higher-order gradients.
*
* @param lhs_mat The first sparse matrix of shape (n, m).
* @param rhs_mat The second sparse matrix of shape (m, k).
*
* @return Sparse matrix of shape (n, k).
*/
c10::intrusive_ptr<SparseMatrix> SpSpMM(
const c10::intrusive_ptr<SparseMatrix>& lhs_mat,
const c10::intrusive_ptr<SparseMatrix>& rhs_mat);

} // namespace sparse
} // namespace dgl

#endif // SPARSE_SPSPMM_H_
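For reference, a minimal usage sketch of the new operator (not part of the commit). It assumes the CreateFromCSR helper with the (indptr, indices, value, shape) signature that spspmm.cc calls below; the matrices and values are illustrative only.

using namespace dgl::sparse;
// A: 2x3 CSR matrix with nonzeros A[0][0] = 1 and A[1][2] = 2.
auto A = CreateFromCSR(
    torch::tensor({0, 1, 2}, torch::kInt64),   // indptr
    torch::tensor({0, 2}, torch::kInt64),      // indices
    torch::tensor({1.0, 2.0}, torch::kFloat),  // values
    std::vector<int64_t>{2, 3});
// B: 3x2 CSR matrix with nonzeros B[0][1] = 3 and B[2][0] = 4.
auto B = CreateFromCSR(
    torch::tensor({0, 1, 1, 2}, torch::kInt64),
    torch::tensor({1, 0}, torch::kInt64),
    torch::tensor({3.0, 4.0}, torch::kFloat),
    std::vector<int64_t>{3, 2});
// C = A @ B has shape (2, 2) with nonzeros C[0][1] = 3 and C[1][0] = 8;
// gradients flow back to both value tensors if they require grad.
auto C = SpSpMM(A, B);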
30 changes: 30 additions & 0 deletions dgl_sparse/src/matmul.cc
@@ -98,5 +98,35 @@ torch::Tensor SDDMMNoAutoGrad(
return ret;
}

c10::intrusive_ptr<SparseMatrix> SpSpMMNoAutoGrad(
const c10::intrusive_ptr<SparseMatrix>& lhs_mat, torch::Tensor lhs_val,
const c10::intrusive_ptr<SparseMatrix>& rhs_mat, torch::Tensor rhs_val,
bool lhs_transpose, bool rhs_transpose) {
aten::CSRMatrix lhs_dgl_csr, rhs_dgl_csr;
if (!lhs_transpose) {
lhs_dgl_csr = CSRToOldDGLCSR(lhs_mat->CSRPtr());
} else {
lhs_dgl_csr = CSRToOldDGLCSR(lhs_mat->CSCPtr());
}
if (!rhs_transpose) {
rhs_dgl_csr = CSRToOldDGLCSR(rhs_mat->CSRPtr());
} else {
rhs_dgl_csr = CSRToOldDGLCSR(rhs_mat->CSCPtr());
}
auto lhs_dgl_val = TorchTensorToDGLArray(lhs_val);
auto rhs_dgl_val = TorchTensorToDGLArray(rhs_val);
const int64_t ret_row =
lhs_transpose ? lhs_mat->shape()[1] : lhs_mat->shape()[0];
const int64_t ret_col =
rhs_transpose ? rhs_mat->shape()[0] : rhs_mat->shape()[1];
std::vector<int64_t> ret_shape({ret_row, ret_col});
aten::CSRMatrix ret_dgl_csr;
runtime::NDArray ret_val;
std::tie(ret_dgl_csr, ret_val) =
aten::CSRMM(lhs_dgl_csr, lhs_dgl_val, rhs_dgl_csr, rhs_dgl_val);
return SparseMatrix::FromCSR(
CSRFromOldDGLCSR(ret_dgl_csr), DGLArrayToTorchTensor(ret_val), ret_shape);
}

} // namespace sparse
} // namespace dgl
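A note on the transpose flags above: SpSpMMNoAutoGrad never materializes a transposed matrix; it reads the CSC representation of an operand as the CSR of its transpose. A hedged sketch of computing A^T @ B with it (A and B are hypothetical SparseMatrix handles; the call mirrors the declaration in matmul.h):

// Treat A's CSC as the CSR of A^T and multiply by B.
auto AtB = SpSpMMNoAutoGrad(
    A, A->value(),  // lhs matrix and its 1-D nonzero values
    B, B->value(),  // rhs matrix and its 1-D nonzero values
    /*lhs_transpose=*/true, /*rhs_transpose=*/false);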
22 changes: 22 additions & 0 deletions dgl_sparse/src/matmul.h
@@ -53,6 +53,28 @@ torch::Tensor SDDMMNoAutoGrad(
const c10::intrusive_ptr<SparseMatrix>& sparse_mat, torch::Tensor mat1,
torch::Tensor mat2_tr);

/**
* @brief Perform a sparse-sparse matrix multiplication with possibly different
* sparsities. The two sparse matrices must have 1-dimensional values. If the
* first sparse matrix has shape (n, m), the second sparse matrix must have
* shape (m, k), and the returned sparse matrix has shape (n, k).
*
* This function does not take care of autograd.
*
* @param lhs_mat The first sparse matrix of shape (n, m).
* @param lhs_val Sparse value for the first sparse matrix.
* @param rhs_mat The second sparse matrix of shape (m, k).
* @param rhs_val Sparse value for the second sparse matrix.
* @param lhs_transpose Whether the first matrix is transposed.
* @param rhs_transpose Whether the second matrix is transposed.
*
* @return Sparse matrix of shape (n, k).
*/
c10::intrusive_ptr<SparseMatrix> SpSpMMNoAutoGrad(
const c10::intrusive_ptr<SparseMatrix>& lhs_mat, torch::Tensor lhs_val,
const c10::intrusive_ptr<SparseMatrix>& rhs_mat, torch::Tensor rhs_val,
bool lhs_transpose, bool rhs_transpose);

} // namespace sparse
} // namespace dgl

4 changes: 3 additions & 1 deletion dgl_sparse/src/python_binding.cc
@@ -12,6 +12,7 @@
#include <sparse/sddmm.h>
#include <sparse/sparse_matrix.h>
#include <sparse/spmm.h>
#include <sparse/spspmm.h>
#include <torch/custom_class.h>
#include <torch/script.h>

@@ -40,7 +41,8 @@ TORCH_LIBRARY(dgl_sparse, m) {
.def("sprod", &ReduceProd)
.def("val_like", &CreateValLike)
.def("spmm", &SpMM)
.def("sddmm", &SDDMM);
.def("sddmm", &SDDMM)
.def("spspmm", &SpSpMM);
}

} // namespace sparse
2 changes: 2 additions & 0 deletions dgl_sparse/src/sddmm.cc
@@ -95,6 +95,8 @@ c10::intrusive_ptr<SparseMatrix> SDDMM(
torch::Tensor mat2) {
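// Promote 1-D dense inputs to 2-D before the computation: mat1 is viewed as
// a column vector (n, 1) and mat2 as a row vector (1, m).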
if (mat1.dim() == 1) {
mat1 = mat1.view({mat1.size(0), 1});
}
if (mat2.dim() == 1) {
mat2 = mat2.view({1, mat2.size(0)});
}
_SDDMMSanityCheck(sparse_mat, mat1, mat2);
123 changes: 123 additions & 0 deletions dgl_sparse/src/spspmm.cc
@@ -0,0 +1,123 @@
/**
* Copyright (c) 2022 by Contributors
* @file spspmm.cc
* @brief DGL C++ sparse SpSpMM operator implementation.
*/

#include <sparse/sddmm.h>
#include <sparse/sparse_matrix.h>
#include <sparse/spspmm.h>
#include <torch/script.h>

#include "./matmul.h"
#include "./utils.h"

namespace dgl {
namespace sparse {

using namespace torch::autograd;

class SpSpMMAutoGrad : public Function<SpSpMMAutoGrad> {
public:
static variable_list forward(
AutogradContext* ctx, c10::intrusive_ptr<SparseMatrix> lhs_mat,
torch::Tensor lhs_val, c10::intrusive_ptr<SparseMatrix> rhs_mat,
torch::Tensor rhs_val);

static tensor_list backward(AutogradContext* ctx, tensor_list grad_outputs);
};

void _SpSpMMSanityCheck(
const c10::intrusive_ptr<SparseMatrix>& lhs_mat,
const c10::intrusive_ptr<SparseMatrix>& rhs_mat) {
const auto& lhs_shape = lhs_mat->shape();
const auto& rhs_shape = rhs_mat->shape();
CHECK_EQ(lhs_shape[1], rhs_shape[0])
<< "SpSpMM: the second dim of lhs_mat should be equal to the first dim "
"of the second matrix";
CHECK_EQ(lhs_mat->value().dim(), 1)
<< "SpSpMM: the value shape of lhs_mat should be 1-D";
CHECK_EQ(rhs_mat->value().dim(), 1)
<< "SpSpMM: the value shape of rhs_mat should be 1-D";
CHECK_EQ(lhs_mat->device(), rhs_mat->device())
<< "SpSpMM: lhs_mat and rhs_mat should on the same device";
CHECK_EQ(lhs_mat->dtype(), rhs_mat->dtype())
<< "SpSpMM: lhs_mat and rhs_mat should have the same dtype";
}

// Mask-select the values of `mat` at the nonzero positions of `sub_mat`;
// positions that `mat` does not store are filled with 0.
torch::Tensor _CSRMask(
const c10::intrusive_ptr<SparseMatrix>& mat, torch::Tensor value,
const c10::intrusive_ptr<SparseMatrix>& sub_mat) {
auto csr = CSRToOldDGLCSR(mat->CSRPtr());
auto val = TorchTensorToDGLArray(value);
auto row = TorchTensorToDGLArray(sub_mat->COOPtr()->row);
auto col = TorchTensorToDGLArray(sub_mat->COOPtr()->col);
runtime::NDArray ret;
ATEN_FLOAT_TYPE_SWITCH(val->dtype, DType, "Value Type", {
ret = aten::CSRGetData<DType>(csr, row, col, val, 0.);
});
return DGLArrayToTorchTensor(ret);
}
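As a dense analogy (a sketch under the assumption of small dense tensors, not library code), _CSRMask behaves like gathering values at the nonzero coordinates of the sub-matrix, where a coordinate the original matrix does not store naturally reads as 0:

// Dense-analogy sketch: gather dense_mat[row[i], col[i]] for every nonzero i
// of the sub-matrix; coordinates missing from the original matrix read as 0.
torch::Tensor DenseMaskSelect(
    torch::Tensor dense_mat, torch::Tensor sub_row, torch::Tensor sub_col) {
  return dense_mat.index({sub_row, sub_col});
}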

variable_list SpSpMMAutoGrad::forward(
AutogradContext* ctx, c10::intrusive_ptr<SparseMatrix> lhs_mat,
torch::Tensor lhs_val, c10::intrusive_ptr<SparseMatrix> rhs_mat,
torch::Tensor rhs_val) {
auto ret_mat =
SpSpMMNoAutoGrad(lhs_mat, lhs_val, rhs_mat, rhs_val, false, false);

ctx->saved_data["lhs_mat"] = lhs_mat;
ctx->saved_data["rhs_mat"] = rhs_mat;
ctx->saved_data["ret_mat"] = ret_mat;
ctx->saved_data["lhs_require_grad"] = lhs_val.requires_grad();
ctx->saved_data["rhs_require_grad"] = rhs_val.requires_grad();
ctx->save_for_backward({lhs_val, rhs_val});

auto csr = ret_mat->CSRPtr();
auto val = ret_mat->value();
CHECK(!csr->value_indices.has_value());
return {csr->indptr, csr->indices, val};
}

tensor_list SpSpMMAutoGrad::backward(
AutogradContext* ctx, tensor_list grad_outputs) {
auto saved = ctx->get_saved_variables();
auto lhs_val = saved[0];
auto rhs_val = saved[1];
auto output_grad = grad_outputs[2];
auto lhs_mat = ctx->saved_data["lhs_mat"].toCustomClass<SparseMatrix>();
auto rhs_mat = ctx->saved_data["rhs_mat"].toCustomClass<SparseMatrix>();
auto ret_mat = ctx->saved_data["ret_mat"].toCustomClass<SparseMatrix>();
torch::Tensor lhs_val_grad, rhs_val_grad;

if (ctx->saved_data["lhs_require_grad"].toBool()) {
// A @ B = C -> dA = dC @ (B^T)
auto lhs_mat_grad =
SpSpMMNoAutoGrad(ret_mat, output_grad, rhs_mat, rhs_val, false, true);
lhs_val_grad = _CSRMask(lhs_mat_grad, lhs_mat_grad->value(), lhs_mat);
}
if (ctx->saved_data["rhs_require_grad"].toBool()) {
// A @ B = C -> dB = (A^T) @ dC
auto rhs_mat_grad =
SpSpMMNoAutoGrad(lhs_mat, lhs_val, ret_mat, output_grad, true, false);
rhs_val_grad = _CSRMask(rhs_mat_grad, rhs_mat_grad->value(), rhs_mat);
}
return {torch::Tensor(), lhs_val_grad, torch::Tensor(), rhs_val_grad};
}

c10::intrusive_ptr<SparseMatrix> SpSpMM(
const c10::intrusive_ptr<SparseMatrix>& lhs_mat,
const c10::intrusive_ptr<SparseMatrix>& rhs_mat) {
_SpSpMMSanityCheck(lhs_mat, rhs_mat);
auto results = SpSpMMAutoGrad::apply(
lhs_mat, lhs_mat->value(), rhs_mat, rhs_mat->value());
std::vector<int64_t> ret_shape({lhs_mat->shape()[0], rhs_mat->shape()[1]});
auto indptr = results[0];
auto indices = results[1];
auto value = results[2];
return CreateFromCSR(indptr, indices, value, ret_shape);
}

} // namespace sparse
} // namespace dgl
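The backward pass above relies on the standard dense matmul gradient identities, dL/dA = dL/dC @ B^T and dL/dB = A^T @ dL/dC, and then restricts each result to the corresponding input's sparsity pattern via _CSRMask. A small self-contained dense check of those identities (a sketch for intuition only, not part of the commit):

#include <torch/torch.h>

void CheckMatmulGradIdentities() {
  auto A = torch::rand({3, 4}, torch::requires_grad());
  auto B = torch::rand({4, 5}, torch::requires_grad());
  auto C = torch::matmul(A, B);
  auto grad_C = torch::ones_like(C);
  C.backward(grad_C);
  // dL/dA == dL/dC @ B^T and dL/dB == A^T @ dL/dC.
  TORCH_CHECK(torch::allclose(A.grad(), torch::matmul(grad_C, B.detach().t())));
  TORCH_CHECK(torch::allclose(B.grad(), torch::matmul(A.detach().t(), grad_C)));
}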
