From b4b35590c46dc6ea753a800e9bc7894c60f4c65e Mon Sep 17 00:00:00 2001 From: nihuini Date: Thu, 3 Aug 2017 11:15:20 +0800 Subject: [PATCH] Mat reshape across dims --- src/layer/reshape.cpp | 3 ++ src/mat.h | 64 +++++++++++++++++++++++++++++++++++++++++++ 2 files changed, 67 insertions(+) diff --git a/src/layer/reshape.cpp b/src/layer/reshape.cpp index b4cbfe7c0ba..05de0efb07a 100644 --- a/src/layer/reshape.cpp +++ b/src/layer/reshape.cpp @@ -145,6 +145,9 @@ int Reshape::forward(const Mat& bottom_blob, Mat& top_blob) const top_blob = bottom_blob.reshape(_w, _h, _c); } + if (top_blob.empty()) + return -100; + return 0; } diff --git a/src/mat.h b/src/mat.h index e07723adc55..cb077823f62 100644 --- a/src/mat.h +++ b/src/mat.h @@ -338,6 +338,25 @@ inline Mat Mat::clone() const inline Mat Mat::reshape(int _w) const { + if (w * h * c != _w) + return Mat(); + + if (dims == 3 && cstep != (size_t)w * h) + { + Mat m; + m.create(_w); + + // flatten + for (int i=0; i> 2) + { + Mat m; + m.create(_w, _h, _c); + + // align channel + for (int i=0; i<_c; i++) + { + const float* ptr = data + i * _w * _h; + float* mptr = m.data + i * m.cstep; + memcpy(mptr, ptr, _w * _h * sizeof(float)); + } + + return m; + } + + if (c != _c) + { + // flatten and then align + Mat tmp = reshape(_w * _h * _c); + return tmp.reshape(_w, _h, _c); + } + Mat m = *this; m.dims = 3;