Skip to content

Commit

Permalink
Keep the original function name and add broadcast_keepdim & reduce_ke…
Browse files Browse the repository at this point in the history
…epdim
  • Loading branch information
sxjscience committed May 28, 2016
1 parent a90696e commit 997d042
Show file tree
Hide file tree
Showing 3 changed files with 47 additions and 15 deletions.
27 changes: 21 additions & 6 deletions mshadow/extension/broadcast_with_axis.h
Original file line number Diff line number Diff line change
Expand Up @@ -34,8 +34,9 @@ struct BroadcastWithAxisExp:
/*! \brief size of the last dimension of src*/
index_t last_;
/*! constructor */
BroadcastWithAxisExp(const SrcExp &src, const int axis, const index_t size, int keepdim)
BroadcastWithAxisExp(const SrcExp &src, const int axis, const index_t size)
: src_(src), size_(size) {
bool keepdim = (dimsrc == dimdst);
Shape<dimsrc> src_shape = ShapeCheck<dimsrc, SrcExp>::Check(src_);
this->trailing_ = 1;

Expand Down Expand Up @@ -71,19 +72,33 @@ struct BroadcastWithAxisExp:
}; // struct BroadcastWithAxisExp

/*!
* \brief Broadcasting the tensor in the given axis. If keepdim is off, insert the broadcasting dim after axis. Otherwise broadcasting axis.
* \param keepdim whether to keepdim
* \brief Broadcasting the tensor after given axis.
* \param SrcExp source expression
* \tparam DType data type
* \tparam etype type of the expression
*/
template<int keepdim, typename SrcExp, typename DType, int etype>
template<typename SrcExp, typename DType, int etype>
inline BroadcastWithAxisExp<SrcExp, DType, ExpInfo<SrcExp>::kDim,
ExpInfo<SrcExp>::kDim + 1 - keepdim>
ExpInfo<SrcExp>::kDim + 1>
broadcast_with_axis(const Exp<SrcExp, DType, etype> &src, const int axis, const index_t size) {
return BroadcastWithAxisExp<SrcExp, DType, ExpInfo<SrcExp>::kDim,
ExpInfo<SrcExp>::kDim + 1 - keepdim>(src.self(), axis, size, keepdim);
ExpInfo<SrcExp>::kDim + 1>(src.self(), axis, size);
}

/*!
* \brief Broadcasting the tensor in the given axis (keepdim turned on)
* \param SrcExp source expression
* \tparam DType data type
* \tparam etype type of the expression
*/
template<typename SrcExp, typename DType, int etype>
inline BroadcastWithAxisExp<SrcExp, DType, ExpInfo<SrcExp>::kDim,
ExpInfo<SrcExp>::kDim>
broadcast_keepdim(const Exp<SrcExp, DType, etype> &src, const int axis, const index_t size) {
return BroadcastWithAxisExp<SrcExp, DType, ExpInfo<SrcExp>::kDim,
ExpInfo<SrcExp>::kDim>(src.self(), axis, size);
}

//----------------------
// Execution plan
//----------------------
Expand Down
27 changes: 22 additions & 5 deletions mshadow/extension/reduce_with_axis.h
Original file line number Diff line number Diff line change
Expand Up @@ -32,8 +32,9 @@ struct ReduceWithAxisExp:
/*! \brief size of last src dimension */
index_t last_;
/*! constructor */
explicit ReduceWithAxisExp(const SrcExp &src, int axis, int keepdim)
explicit ReduceWithAxisExp(const SrcExp &src, int axis)
: src_(src) {
bool keepdim = (dimsrc == dimdst);
CHECK(dimsrc > axis) << "reduce axis out of bound";
Shape<dimsrc> src_shape = ShapeCheck<dimsrc, SrcExp>::Check(src_);
for (index_t i = 0; i < axis; ++i) {
Expand Down Expand Up @@ -63,18 +64,34 @@ struct ReduceWithAxisExp:
* \brief reduce out the dimension of src labeled by axis.
* \param Reducer type of the reducing operation
* \param mask whether to output the unmask indices
* \param keepdim the keepdim flag
* \tparam SrcExp source expression
* \tparam DType data type
* \tparam etype type of the expression
*/
template<typename Reducer, bool mask, int keepdim, typename SrcExp, typename DType, int etype>
template<typename Reducer, bool mask, typename SrcExp, typename DType, int etype>
inline ReduceWithAxisExp<Reducer, SrcExp, DType, ExpInfo<SrcExp>::kDim, mask,
ExpInfo<SrcExp>::kDim + keepdim - 1>
ExpInfo<SrcExp>::kDim - 1>
reduce_with_axis(const Exp<SrcExp, DType, etype> &src, int axis) {
return ReduceWithAxisExp<Reducer, SrcExp, DType, ExpInfo<SrcExp>::kDim, mask,
ExpInfo<SrcExp>::kDim + keepdim - 1>(src.self(), axis, keepdim);
ExpInfo<SrcExp>::kDim- 1>(src.self(), axis);
}

/*!
* \brief reduce out the dimension of src labeled by axis, keepdim turned on.
* \param Reducer type of the reducing operation
* \param mask whether to output the unmask indices
* \tparam SrcExp source expression
* \tparam DType data type
* \tparam etype type of the expression
*/
template<typename Reducer, bool mask, typename SrcExp, typename DType, int etype>
inline ReduceWithAxisExp<Reducer, SrcExp, DType, ExpInfo<SrcExp>::kDim, mask,
ExpInfo<SrcExp>::kDim>
reduce_keepdim(const Exp<SrcExp, DType, etype> &src, int axis) {
return ReduceWithAxisExp<Reducer, SrcExp, DType, ExpInfo<SrcExp>::kDim, mask,
ExpInfo<SrcExp>::kDim>(src.self(), axis);
}

//----------------------
// Execution plan
//----------------------
Expand Down
8 changes: 4 additions & 4 deletions test/test_tblob.cc
Original file line number Diff line number Diff line change
Expand Up @@ -80,7 +80,7 @@ void test_broadcast_with_axis() {
input_tensor = 11;
mshadow::Tensor<mshadow::cpu, 4> n_tensor(NULL, test_shapes[dim + 1]);
mshadow::AllocSpace(&n_tensor);
n_tensor = broadcast_with_axis<0>(input_tensor, dim, 5);
n_tensor = broadcast_with_axis(input_tensor, dim, 5);
printf("Test for keepdim = 0, dim = %d", dim);
for (index_t i = 0; i < n_tensor.shape_[0]; i++) {
for (index_t j = 0; j < n_tensor.shape_[1]; j++) {
Expand All @@ -100,7 +100,7 @@ void test_broadcast_with_axis() {
input_tensor = 11;
mshadow::Tensor<mshadow::cpu, 4> n_tensor(NULL, test_shapes[dim]);
mshadow::AllocSpace(&n_tensor);
n_tensor = broadcast_with_axis<1>(input_tensor, dim, 5);
n_tensor = broadcast_keepdim(input_tensor, dim, 5);
printf("Test for keepdim = 1, dim = %d", dim);
for (index_t i = 0; i < n_tensor.shape_[0]; i++) {
for (index_t j = 0; j < n_tensor.shape_[1]; j++) {
Expand Down Expand Up @@ -136,7 +136,7 @@ void test_reduce_with_axis() {
input_tensor = 1;
mshadow::Tensor<mshadow::cpu, 3> n_tensor(NULL, mshadow::Shape3(2, 3, 4));
mshadow::AllocSpace(&n_tensor);
n_tensor = reduce_with_axis<mshadow::red::sum, false, 0>(input_tensor, dim);
n_tensor = reduce_with_axis<mshadow::red::sum, false>(input_tensor, dim);
printf("Test for keepdim = 0, dim = %d", dim);
for (index_t i = 0; i < n_tensor.shape_[0]; i++) {
for (index_t j = 0; j < n_tensor.shape_[1]; j++) {
Expand All @@ -154,7 +154,7 @@ void test_reduce_with_axis() {
input_tensor = 1;
mshadow::Tensor<mshadow::cpu, 4> n_tensor(NULL, keepdim_output_shapes[dim]);
mshadow::AllocSpace(&n_tensor);
n_tensor = reduce_with_axis<mshadow::red::sum, false, 1>(input_tensor, dim);
n_tensor = reduce_keepdim<mshadow::red::sum, false>(input_tensor, dim);
printf("Test for keepdim = 1, dim = %d", dim);
for (index_t i = 0; i < n_tensor.shape_[0]; i++) {
for (index_t j = 0; j < n_tensor.shape_[1]; j++) {
Expand Down

0 comments on commit 997d042

Please sign in to comment.