Skip to content

Commit

Permalink
Merge pull request dmlc#116 from dali-ml/reduce_with_axis_fix
Browse files Browse the repository at this point in the history
fix a bug which causes reduce_with_axis to fail when reducing over the last axis
  • Loading branch information
piiswrong committed May 16, 2016
2 parents 79283e0 + 378f8dd commit cc9b210
Showing 1 changed file with 11 additions and 8 deletions.
19 changes: 11 additions & 8 deletions mshadow/extension/reduce_with_axis.h
Original file line number Diff line number Diff line change
Expand Up @@ -23,8 +23,8 @@ struct ReduceWithAxisExp:
SrcExp, srcdim-1, DType> {
/*! \brief source oprand */
const SrcExp &src_;
/*! \brief size of leading dimensions */
index_t leading_;
/*! \brief size of last destination dimension */
index_t last_dst_dim_;
/*! \brief size of trailing dimensions */
index_t trailing_;
/*! \brief size of axis dimension */
Expand All @@ -36,9 +36,7 @@ struct ReduceWithAxisExp:
: src_(src) {
CHECK(srcdim > axis) << "reduce axis out of bound";
Shape<srcdim> src_shape = ShapeCheck<srcdim, SrcExp>::Check(src_);
this->leading_ = 1;
for (index_t i = 0; i < axis; ++i) {
this->leading_ *= src_shape[i];
this->shape_[i] = src_shape[i];
}
this->size_ = src_shape[axis];
Expand All @@ -48,6 +46,11 @@ struct ReduceWithAxisExp:
this->shape_[i-1] = src_shape[i];
}
this->last_ = src_shape[srcdim-1];
if (axis == srcdim -1) {
this->last_dst_dim_ = src_shape[srcdim-2];
} else {
this->last_dst_dim_ = src_shape[srcdim-1];
}
}
}; // struct ReduceWithAxisExp

Expand All @@ -71,11 +74,11 @@ template<typename Reducer, typename SrcExp, typename DType, int srcdim, bool mas
struct Plan<ReduceWithAxisExp<Reducer, SrcExp, DType, srcdim, mask>, DType> {
public:
explicit Plan(const ReduceWithAxisExp<Reducer, SrcExp, DType, srcdim, mask> &e)
: src_(MakePlan(e.src_)), leading_(e.leading_), trailing_(e.trailing_),
: src_(MakePlan(e.src_)), last_dst_dim_(e.last_dst_dim_), trailing_(e.trailing_),
size_(e.size_), last_(e.last_) {}
MSHADOW_XINLINE DType Eval(index_t i, index_t j) const {
index_t x = (i*last_ + j)/trailing_;
index_t y = (i*last_ + j)%trailing_;
index_t x = (i*last_dst_dim_ + j)/trailing_;
index_t y = (i*last_dst_dim_ + j)%trailing_;

if (mask) {
index_t idx = 0;
Expand All @@ -101,7 +104,7 @@ struct Plan<ReduceWithAxisExp<Reducer, SrcExp, DType, srcdim, mask>, DType> {

private:
Plan<SrcExp, DType> src_;
const index_t leading_, trailing_, size_, last_;
const index_t last_dst_dim_, trailing_, size_, last_;
};
} // namespace expr
} // namespace mshadow
Expand Down

0 comments on commit cc9b210

Please sign in to comment.