Skip to content

Commit

Permalink
Merge pull request opencv#21522 from rogday:lstm
Browse files Browse the repository at this point in the history
Fix LSTM support in ONNX

* fix LSTM and add peephole support

* disable old tests

* turn lambdas into functions

* more hacks for C++98 compatibility

* add assertions

* slice fixes

* backport of cuda-related fixes

* address review comments
  • Loading branch information
rogday authored Mar 15, 2022
1 parent 5d8134e commit 93353ae
Show file tree
Hide file tree
Showing 3 changed files with 404 additions and 60 deletions.
151 changes: 139 additions & 12 deletions modules/dnn/src/layers/recurrent_layers.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -103,7 +103,7 @@ static ActivationFunction get_activation_function(const String& activation) {

class LSTMLayerImpl CV_FINAL : public LSTMLayer
{
int numTimeStamps, numSamples;
int numTimeStamps, numSamples, numHidden;
bool allocated;

MatShape outTailShape; //shape of single output sample
Expand All @@ -127,6 +127,10 @@ class LSTMLayerImpl CV_FINAL : public LSTMLayer
bool useAVX2;
#endif

// CUDA needs input blobs to be rearranged in a specific way, but some transformations
// in ONNXImporter are destructive, so we keep a copy.
std::vector<Mat> originalBlobs;

public:

LSTMLayerImpl(const LayerParams& params)
Expand All @@ -140,6 +144,13 @@ class LSTMLayerImpl CV_FINAL : public LSTMLayer
{
setParamsFrom(params);

if (params.get<bool>("is_onnx", false))
{
// collect copies of onnx blobs
originalBlobs.insert(originalBlobs.begin(), blobs.begin(), blobs.begin() + 3);
blobs.erase(blobs.begin(), blobs.begin() + 3);
}

bidirectional = params.get<bool>("bidirectional", false);
if (!blobs.empty())
{
Expand Down Expand Up @@ -181,6 +192,7 @@ class LSTMLayerImpl CV_FINAL : public LSTMLayer
useCellClip = params.get<bool>("use_cell_clip", false);
usePeephole = params.get<bool>("use_peephole", false);
reverse = params.get<bool>("reverse", false);
numHidden = params.get<int>("hidden_size", 1);
CV_Assert(!reverse || !bidirectional);

// read activations
Expand Down Expand Up @@ -269,8 +281,21 @@ class LSTMLayerImpl CV_FINAL : public LSTMLayer
outResShape.insert(outResShape.end(), outTailShape_.begin(), outTailShape_.end());
outResShape.back() *= (1 + static_cast<int>(bidirectional));

size_t noutputs = produceCellOutput ? 2 : 1;
outputs.assign(noutputs, outResShape);
outputs.assign(1, outResShape);
if (produceCellOutput)
{
// the producer is ONNX, so CellState is different
if (!originalBlobs.empty())
{
int shp[] = {(1 + static_cast<int>(bidirectional)), _numSamples, numHidden};
MatShape newShape(shp, shp + sizeof(shp)/sizeof(shp[0]));
outputs.push_back(newShape);
}
else
{
outputs.push_back(outResShape);
}
}

internals.assign(1, shape(_numSamples, _numOut)); // hInternal
internals.push_back(shape(_numSamples, _numOut)); // cInternal
Expand Down Expand Up @@ -335,14 +360,39 @@ class LSTMLayerImpl CV_FINAL : public LSTMLayer
outputs_arr.getMatVector(output);
internals_arr.getMatVector(internals);

Mat cOut = produceCellOutput ? output[0].clone() : Mat();
const bool needYcTransform = !originalBlobs.empty(); // if the producer is onnx
const int numDirs = 1 + static_cast<int>(bidirectional);
for (int i = 0; i < numDirs; ++i)
{
const Mat &Wh = blobs[0].rowRange(i * blobs[0].rows / numDirs, (i + 1) * blobs[0].rows / numDirs);
const Mat &Wx = blobs[1].rowRange(i * blobs[1].rows / numDirs, (i + 1) * blobs[1].rows / numDirs);
const Mat &bias = blobs[2].colRange(i * blobs[2].cols / numDirs, (i + 1) * blobs[2].cols / numDirs);
const Mat &h_0 = blobs[3].rowRange(i * blobs[3].rows / numDirs, (i + 1) * blobs[3].rows / numDirs);
const Mat &c_0 = blobs[4].rowRange(i * blobs[4].rows / numDirs, (i + 1) * blobs[4].rows / numDirs);
Mat Wh = blobs[0];
Mat Wx = blobs[1];
Mat bias = blobs[2];
Mat h_0 = blobs[3];
Mat c_0 = blobs[4];
Mat pI, pF, pO;

Wh = Wh.rowRange(i * Wh.rows / numDirs, (i + 1) * Wh.rows / numDirs);
Wx = Wx.rowRange(i * Wx.rows / numDirs, (i + 1) * Wx.rows / numDirs);
bias = bias.colRange(i * bias.cols / numDirs, (i + 1) * bias.cols / numDirs);
h_0 = h_0.rowRange(i * h_0.rows / numDirs, (i + 1) * h_0.rows / numDirs);
c_0 = c_0.rowRange(i * c_0.rows / numDirs, (i + 1) * c_0.rows / numDirs);

if (usePeephole)
{
pI = blobs[5];
pF = blobs[6];
pO = blobs[7];

pI = pI.rowRange(i * pI.rows / numDirs, (i + 1) * pI.rows / numDirs);
pI = pI.colRange(i * pI.cols / numDirs, (i + 1) * pI.cols / numDirs);

pF = pF.rowRange(i * pF.rows / numDirs, (i + 1) * pF.rows / numDirs);
pF = pF.colRange(i * pF.cols / numDirs, (i + 1) * pF.cols / numDirs);

pO = pO.rowRange(i * pO.rows / numDirs, (i + 1) * pO.rows / numDirs);
pO = pO.colRange(i * pO.cols / numDirs, (i + 1) * pO.cols / numDirs);
}

int numOut = Wh.size[1];
Mat hInternal = internals[0], cInternal = internals[1],
Expand All @@ -356,7 +406,12 @@ class LSTMLayerImpl CV_FINAL : public LSTMLayer

Mat hOutTs = output[0].reshape(1, numSamplesTotal);
hOutTs = hOutTs.colRange(i * hOutTs.cols / numDirs, (i + 1) * hOutTs.cols / numDirs);
Mat cOutTs = produceCellOutput ? output[1].reshape(1, numSamplesTotal) : Mat();
Mat cOutTs;
if (produceCellOutput)
{
cOutTs = cOut.reshape(1, numSamplesTotal);
cOutTs = cOutTs.colRange(i * cOutTs.cols / numDirs, (i + 1) * cOutTs.cols / numDirs);
}

#if CV_TRY_AVX2 || CV_TRY_AVX
bool canUseAvx = gates.isContinuous() && bias.isContinuous()
Expand Down Expand Up @@ -471,8 +526,8 @@ class LSTMLayerImpl CV_FINAL : public LSTMLayer
if (usePeephole)
{
Mat gatesIF = gates.colRange(0, 2*numOut);
gemm(cInternal, blobs[5], 1, gateI, 1, gateI);
gemm(cInternal, blobs[6], 1, gateF, 1, gateF);
gemm(cInternal, pI, 1, gateI, 1, gateI);
gemm(cInternal, pF, 1, gateF, 1, gateF);
f_activation(gatesIF, gatesIF);
}
else
Expand All @@ -495,7 +550,7 @@ class LSTMLayerImpl CV_FINAL : public LSTMLayer
}
if (usePeephole)
{
gemm(cInternal, blobs[7], 1, gateO, 1, gateO);
gemm(cInternal, pO, 1, gateO, 1, gateO);
f_activation(gateO, gateO);
}

Expand All @@ -509,6 +564,78 @@ class LSTMLayerImpl CV_FINAL : public LSTMLayer
cInternal.copyTo(cOutTs.rowRange(curRowRange));
}
}

if (needYcTransform && produceCellOutput)
{
fixCellState(cOut, numDirs);
}
if (produceCellOutput)
{
cOut.copyTo(output[1]);
}
}

// Rearrange the accumulated cell-state tensor into the ONNX LSTM Y_c layout.
// On entry cOut holds per-timestep cell states laid out as
// (seq, batch, dirs, hidden) — see the reshape below; on exit it holds only
// the final cell state with shape (numDirs, batch, hidden), matching what an
// ONNX consumer expects for the second LSTM output.
// cOut    - in/out cell-state buffer; replaced in place.
// numDirs - 1 for a unidirectional LSTM, 2 for bidirectional.
void fixCellState(Mat& cOut, int numDirs)
{
    // Give the flat buffer its logical 4-D shape: seq, batch, dirs, hidden.
    // (A leading 0 in a reshape dim keeps it inferred from the total size.)
    int shp[] = {0, numSamples, numDirs, numHidden};
    cOut = cOut.reshape(1, sizeof(shp)/sizeof(shp[0]), shp);

    // Permute axes {0, 2, 1, 3}: (seq, batch, dirs, hidden) ->
    // (seq, dirs, batch, hidden). Done by hand with row-sized memcpy calls
    // because the copy must produce a freshly packed Mat.
    std::vector<int> newShape = shape(cOut);
    std::swap(newShape[1], newShape[2]);
    cv::Mat newCellState(newShape, CV_32FC1);
    const float* src = cOut.ptr<const float>();
    float* dst = newCellState.ptr<float>();
    // Destination strides in floats: sj = one hidden row, sk = one batch
    // plane, si = one full sequence step of the permuted layout.
    size_t sj = newCellState.size[3];
    size_t sk = newCellState.size[2] * sj;
    size_t si = newCellState.size[1] * sk;
    // src advances linearly through the (seq, batch, dirs, hidden) source;
    // dst jumps by sk per copied row and is rewound between the inner loops
    // so that the batch and dirs axes land swapped.
    // NOTE(review): the loop bounds compare size_t against int (Mat::size[]
    // is int) — harmless here but triggers sign-compare warnings; confirm
    // against the project's warning policy.
    for (size_t i = 0; i < newCellState.size[0]; i++)
    {
        for (size_t j = 0; j < newCellState.size[2]; j++)
        {
            for (size_t k = 0; k < newCellState.size[1]; k++)
            {
                std::memcpy(dst, src, sizeof(float) * newCellState.size[3]);
                src += cOut.size[3];
                dst += sk;
            }
            dst = dst + sj - si; // rewind to sequence start, step one hidden row
        }
        dst = dst + si - sk; // rewind the row stepping, advance one sequence step
    }

    cOut = newCellState;

    if (numDirs == 1)
    {
        // Slice: Yc = Y[-1, :, :, :] — keep only the last timestep.
        Range ranges[] = {cv::Range(cOut.size[0] - 1, cOut.size[0]), cv::Range::all(), cv::Range::all(), cv::Range::all()};
        cOut = cOut(ranges);
        // Reshape: 1x1xBxH -> 1xBxH (drop the now-singleton seq axis).
        int shp[] = {1, numSamples, numHidden};
        cOut = cOut.reshape(1, sizeof(shp)/sizeof(shp[0]), shp);
    }
    else
    {
        // Bidirectional: the two directions finish at opposite ends of the
        // sequence, so their final states live at different timesteps.
        // Slice: SxDxBxH -> last sequence step, first (forward) direction.
        Range ranges1[] = {cv::Range(cOut.size[0] - 1, cOut.size[0]), cv::Range(0, 1), cv::Range::all(), cv::Range::all()};
        Mat part1 = cOut(ranges1);

        // Slice: SxDxBxH -> first sequence step, last (reverse) direction.
        Range ranges2[] = {cv::Range(0, 1), cv::Range(cOut.size[1] - 1, cOut.size[1]), cv::Range::all(), cv::Range::all()};
        Mat part2 = cOut(ranges2);

        // Flatten each slice to a single row so vconcat can stack them.
        int shp[] = {1, part1.size[2] * part1.size[3]};
        part1 = part1.reshape(1, sizeof(shp)/sizeof(shp[0]), shp);
        part2 = part2.reshape(1, sizeof(shp)/sizeof(shp[0]), shp);

        vconcat(part1, part2, cOut);

        // Reshape: 1x2xBxH -> 2xBxH (one plane per direction).
        int finalShape[] = {2, numSamples, numHidden};
        cOut = cOut.reshape(1, sizeof(finalShape)/sizeof(finalShape[0]), finalShape);
    }
}
};

Expand Down
Loading

0 comments on commit 93353ae

Please sign in to comment.