Skip to content

Commit

Permalink
Mat reshape across dims
Browse files Browse the repository at this point in the history
  • Loading branch information
nihui committed Aug 3, 2017
1 parent 23630b1 commit b4b3559
Show file tree
Hide file tree
Showing 2 changed files with 67 additions and 0 deletions.
3 changes: 3 additions & 0 deletions src/layer/reshape.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -145,6 +145,9 @@ int Reshape::forward(const Mat& bottom_blob, Mat& top_blob) const
top_blob = bottom_blob.reshape(_w, _h, _c);
}

if (top_blob.empty())
return -100;

return 0;
}

Expand Down
64 changes: 64 additions & 0 deletions src/mat.h
Original file line number Diff line number Diff line change
Expand Up @@ -338,6 +338,25 @@ inline Mat Mat::clone() const

inline Mat Mat::reshape(int _w) const
{
if (w * h * c != _w)
return Mat();

if (dims == 3 && cstep != (size_t)w * h)
{
Mat m;
m.create(_w);

// flatten
for (int i=0; i<c; i++)
{
const float* ptr = data + i * cstep;
float* mptr = m.data + i * w * h;
memcpy(mptr, ptr, w * h * sizeof(float));
}

return m;
}

Mat m = *this;

m.dims = 1;
Expand All @@ -353,6 +372,25 @@ inline Mat Mat::reshape(int _w) const

inline Mat Mat::reshape(int _w, int _h) const
{
if (w * h * c != _w * _h)
return Mat();

if (dims == 3 && cstep != (size_t)w * h)
{
Mat m;
m.create(_w, _h);

// flatten
for (int i=0; i<c; i++)
{
const float* ptr = data + i * cstep;
float* mptr = m.data + i * w * h;
memcpy(mptr, ptr, w * h * sizeof(float));
}

return m;
}

Mat m = *this;

m.dims = 2;
Expand All @@ -368,6 +406,32 @@ inline Mat Mat::reshape(int _w, int _h) const

inline Mat Mat::reshape(int _w, int _h, int _c) const
{
if (w * h * c != _w * _h * _c)
return Mat();

if (dims < 3 && (size_t)_w * _h != alignSize(_w * _h * sizeof(float), 16) >> 2)
{
Mat m;
m.create(_w, _h, _c);

// align channel
for (int i=0; i<_c; i++)
{
const float* ptr = data + i * _w * _h;
float* mptr = m.data + i * m.cstep;
memcpy(mptr, ptr, _w * _h * sizeof(float));
}

return m;
}

if (c != _c)
{
// flatten and then align
Mat tmp = reshape(_w * _h * _c);
return tmp.reshape(_w, _h, _c);
}

Mat m = *this;

m.dims = 3;
Expand Down

0 comments on commit b4b3559

Please sign in to comment.