Skip to content

Commit

Permalink
add simple upsampling
Browse files Browse the repository at this point in the history
  • Loading branch information
antinucleon committed Nov 28, 2015
1 parent cce0b32 commit bc68bfd
Show file tree
Hide file tree
Showing 3 changed files with 87 additions and 0 deletions.
15 changes: 15 additions & 0 deletions guide/basic.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -107,6 +107,21 @@ int main(void) {
printf("\n");
}

printf("upsampling\n");
TensorContainer<cpu, 2> small(Shape2(2, 2));
small[0][0] = 1.0f;
small[0][1] = 2.0f;
small[1][0] = 3.0f;
small[1][1] = 4.0f;
TensorContainer<cpu, 2> large(Shape2(6, 6));
large = upsampling(small, 3);
for (index_t i = 0; i < large.size(0); ++i) {
for (index_t j = 0; j < large.size(1); ++j) {
printf("%.2f ", large[i][j]);
}
printf("\n");
}

// shutdown tensor enigne after usage
ShutdownTensorEngine<cpu>();
return 0;
Expand Down
1 change: 1 addition & 0 deletions mshadow/extension.h
Original file line number Diff line number Diff line change
Expand Up @@ -28,4 +28,5 @@
#include "./extension/slice.h"
#include "./extension/take.h"
#include "./extension/take_grad.h"
#include "./extension/spatial_upsampling.h"
#endif // MSHADOW_EXTENSION_H_
71 changes: 71 additions & 0 deletions mshadow/extension/spatial_upsampling.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,71 @@
/*!
* Copyright (c) 2015 by Contributors
* \file spatial_upsampling.h
* \brief
* \author Bing Xu
*/
#ifndef MSHADOW_EXTENSION_SPATIAL_UPSAMPLING_H_
#define MSHADOW_EXTENSION_SPATIAL_UPSAMPLING_H_
#include "../extension.h"

namespace mshadow {
namespace expr {

/*! \brief nearest neighboor upsampling
* out(x, y) = in(int(x / scale_x), int(y / scale_y))
* \tparam SrcExp source expression
* \tparam DType data type
* \tparam srcdim source dimension
*/
template<typename SrcExp, typename DType, int srcdim>
struct UpSamplingExp :
public MakeTensorExp<UpSamplingExp<SrcExp, DType, srcdim>,
SrcExp, srcdim, DType> {
/*! \brief source oprand */
const SrcExp &src_;
/*! \brief up sampling scale */
index_t scale_;
/*! \brief constructor */
UpSamplingExp(const SrcExp &src, index_t scale)
: src_(src), scale_(scale) {
this->shape_ = ShapeCheck<srcdim, SrcExp>::Check(src_);
this->shape_[srcdim - 2] *= scale_;
this->shape_[srcdim - 1] *= scale_;
}
};


template<typename SrcExp, typename DType, int etype>
inline UpSamplingExp<SrcExp, DType, ExpInfo<SrcExp>::kDim>
upsampling(const Exp<SrcExp, DType, etype> &src, index_t scale) {
TypeCheckPass<ExpInfo<SrcExp>::kDim >= 2>
::Error_Expression_Does_Not_Meet_Dimension_Req();
return UpSamplingExp<SrcExp, DType, ExpInfo<SrcExp>::kDim>(src.self(), scale);
}

template<typename SrcExp, typename DType, int srcdim>
struct Plan<UpSamplingExp<SrcExp, DType, srcdim>, DType> {
public:
explicit Plan(const UpSamplingExp<SrcExp, DType, srcdim> &e)
: src_(MakePlan(e.src_)),
scale_(e.scale_),
new_height_(e.shape_[srcdim - 2]),
src_height_(static_cast<index_t>(e.shape_[srcdim - 2] / 2)) {}
MSHADOW_XINLINE DType Eval(index_t i, index_t j) const {
const index_t x = j;
const index_t y = i % new_height_;
const index_t c = i / new_height_;
const index_t h = static_cast<index_t>(y / scale_);
const index_t w = static_cast<index_t>(x / scale_);
return src_.Eval(c * src_height_ + h, w);
}

private:
Plan<SrcExp, DType> src_;
const index_t scale_;
const index_t new_height_;
const index_t src_height_;
};
} // namespace expr
} // namespace mshadow
#endif // MSHADOW_EXTENSION_SPATIAL_UPSAMPLING_H_

0 comments on commit bc68bfd

Please sign in to comment.