Skip to content

Commit

Permalink
Pass unit tests for SparseCpuMatrix::mul(CpuMatrix, CpuMatrix),
Browse the repository at this point in the history
SparseGpuMatrix::mul(GpuMatrix, GpuMatrix),
CpuMatrix::mul(CpuSparseMatrix, CpuMatrix),
and GpuMatrix::mul(GpuSparseMatrix, GpuMatrix)
  • Loading branch information
xutianbing committed Jan 25, 2017
1 parent 1ca2846 commit 4751cc8
Show file tree
Hide file tree
Showing 5 changed files with 312 additions and 101 deletions.
39 changes: 21 additions & 18 deletions paddle/function/MulOp.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -498,15 +498,10 @@ class MulFunc : public FunctionBase {
CHECK_EQ(outputs[0].shape().ndims(), (size_t)2);
CHECK_EQ(outputs[0].getArgType(), ADD_TO);

/// todo(tianbing), support SparseMatrixArg for out_mat
auto out_mat = outputs[0].matrix<Device>();
LOG(INFO) << "out_mat:";
out_mat.print(std::cout);
if (!inputs[0].isSparseArg() && !inputs[1].isSparseArg()) {
LOG(INFO) << "in1_mat:";
inputs[0].matrix<Device>().print(std::cout);
LOG(INFO) << "in2_mat:";
inputs[1].matrix<Device>().print(std::cout);
/// matrix = matrix * matrix
if (!inputs[0].isSparseArg() && !inputs[1].isSparseArg() &&
!outputs[0].isSparseArg()) {
MulOp<Device>(out_mat,
inputs[0].matrix<Device>(),
inputs[1].matrix<Device>(),
Expand All @@ -515,11 +510,9 @@ class MulFunc : public FunctionBase {
return;
}

if (!inputs[0].isSparseArg() && inputs[1].isSparseArg()) {
LOG(INFO) << "in1_mat:";
inputs[0].matrix<Device>().print(std::cout);
LOG(INFO) << "in2_mat:";
inputs[1].sparse().SparseMatrix<Device>().print(std::cout);
/// matrix = matrix * sparse matrix
if (!inputs[0].isSparseArg() && inputs[1].isSparseArg() &&
!outputs[0].isSparseArg()) {
MulOp<Device>(out_mat,
inputs[0].matrix<Device>(),
inputs[1].sparse().SparseMatrix<Device>(),
Expand All @@ -528,18 +521,28 @@ class MulFunc : public FunctionBase {
return;
}

if (inputs[0].isSparseArg() && !inputs[1].isSparseArg()) {
LOG(INFO) << "in1_mat:";
inputs[0].sparse().SparseMatrix<Device>().print(std::cout);
LOG(INFO) << "in2_mat:";
inputs[1].matrix<Device>().print(std::cout);
/// matrix = sparse matrix * matrix
if (inputs[0].isSparseArg() && !inputs[1].isSparseArg() &&
!outputs[0].isSparseArg()) {
MulOp<Device>(out_mat,
inputs[0].sparse().SparseMatrix<Device>(),
inputs[1].matrix<Device>(),
alpha_,
beta_);
return;
}

/// sparse matrix = matrix * matrix
auto out_sparse_mat = outputs[0].sparse().SparseMatrix<Device>();
if (!inputs[0].isSparseArg() && !inputs[1].isSparseArg() &&
outputs[0].isSparseArg()) {
MulOp<Device>(out_sparse_mat,
inputs[0].matrix<Device>(),
inputs[1].matrix<Device>(),
alpha_,
beta_);
return;
}
}

private:
Expand Down
31 changes: 30 additions & 1 deletion paddle/function/MulOpGpu.cu
Original file line number Diff line number Diff line change
Expand Up @@ -176,7 +176,36 @@ void MulOp<DEVICE_TYPE_GPU>(GpuSparseMatrix& out,
const GpuMatrix& b,
real scale_ab,
real scale_t) {
/// todo(tianbing), implement it
/// todo(tianbing), clean the code
CHECK(a.useGpu_ && b.useGpu_) << "type not match";
CHECK(!out.trans_) << "trans not supported";
real* a_data = const_cast<real*>(a.getData());
real* b_data = const_cast<real*>(b.getData());
hl_sparse_matrix_s out_data = out.sMatrix_.get();
hl_trans_op_t a_trans = a.trans_ ? HPPL_OP_T : HPPL_OP_N;
hl_trans_op_t b_trans = b.trans_ ? HPPL_OP_T : HPPL_OP_N;

if (!a.trans_ && !b.trans_) {
CHECK(out.height_ == a.getHeight());
CHECK(out.width_ == b.getWidth());
CHECK(a.getWidth() == b.getHeight());
} else if (a.trans_ && !b.trans_) {
CHECK(out.height_ == a.getWidth());
CHECK(out.width_ == b.getWidth());
CHECK(a.getHeight() == b.getHeight());
} else if (!a.trans_ && b.trans_) {
CHECK(out.height_ == a.getHeight());
CHECK(out.width_ == b.getHeight());
CHECK(a.getWidth() == b.getWidth());
} else {
LOG(INFO) << "Not support";
}
int dim_m = out.height_;
int dim_n = out.width_;
int dim_k = !b.trans_ ? b.getHeight() : b.getWidth();
hl_sparse_matrix_mul(
a_data, a_trans, b_data, b_trans, out_data,
dim_m, dim_n, dim_k, scale_ab, scale_t);
}

} // namespace paddle
Loading

0 comments on commit 4751cc8

Please sign in to comment.