Skip to content

Commit

Permalink
Support Gather for variable inputs
Browse files Browse the repository at this point in the history
  • Loading branch information
l-bat committed Jul 20, 2020
1 parent 284d26d commit a35d4f9
Show file tree
Hide file tree
Showing 2 changed files with 56 additions and 13 deletions.
58 changes: 45 additions & 13 deletions modules/dnn/src/onnx/onnx_importer.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1342,32 +1342,64 @@ void ONNXImporter::populateNet(Net dstNet)
else if (layer_type == "Gather")
{
CV_Assert(node_proto.input_size() == 2);
Mat input = getBlob(node_proto, constBlobs, 0);
Mat indexMat = getBlob(node_proto, constBlobs, 1);
CV_Assert_N(indexMat.type() == CV_32S, indexMat.total() == 1);
int index = indexMat.at<int>(0);
int axis = layerParams.get<int>("axis", 0);

Mat out;
if (layerParams.has("axis"))
if ((constBlobs.find(node_proto.input(0)) != constBlobs.end()))
{
int axis = layerParams.get<int>("axis");

Mat input = getBlob(node_proto, constBlobs, 0);
Mat out;
std::vector<cv::Range> ranges(input.dims, Range::all());
ranges[axis] = Range(index, index + 1);

out = input(ranges);
MatShape outShape = shape(out);
if (outShape.size() > 1)
{
outShape.erase(outShape.begin() + axis);
out.reshape(0, outShape);
}
addConstant(layerParams.name, out, constBlobs, outShapes);
continue;
}
else
{
CV_Assert(index < input.total());
const int dims = input.dims;
input = input.reshape(1, 1);
input.dims = 2;
out = input.reshape(1, 1).colRange(index, index + 1);
out.dims = dims;
shapeIt = outShapes.find(node_proto.input(0));
CV_Assert(shapeIt != outShapes.end());
MatShape inpShape = shapeIt->second;

LayerParams sliceLp;
sliceLp.type = "Slice";
sliceLp.name = inpShape.size() > 1 ? layerParams.name + "/slice" : layerParams.name;
std::vector<int> begin(inpShape.size(), 0);
std::vector<int> end(inpShape.size(), -1);
begin[axis] = index;
end[axis] = index + 1;

cv::dnn::DictValue paramBegin = cv::dnn::DictValue::arrayInt(begin.data(), begin.size());
cv::dnn::DictValue paramEnd = cv::dnn::DictValue::arrayInt(end.data(), end.size());
sliceLp.set("begin", paramBegin);
sliceLp.set("end", paramEnd);

if (inpShape.size() > 1)
{
opencv_onnx::NodeProto proto;
proto.add_input(node_proto.input(0));
proto.add_output(sliceLp.name);
addLayer(dstNet, sliceLp, proto, layer_id, outShapes);

inpShape.erase(inpShape.begin() + axis);
layerParams.type = "Reshape";
layerParams.set("dim", DictValue::arrayInt(&inpShape[0], inpShape.size()));
node_proto.set_input(0, sliceLp.name);
}
else
{
layerParams = sliceLp;
}
}
addConstant(layerParams.name, out, constBlobs, outShapes);
continue;
}
else if (layer_type == "Concat")
{
Expand Down
11 changes: 11 additions & 0 deletions modules/dnn/test/test_onnx_importer.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -111,6 +111,17 @@ TEST_P(Test_ONNX_layers, Convolution)
testONNXModels("convolution");
}

TEST_P(Test_ONNX_layers, Gather)
{
if (backend == DNN_BACKEND_INFERENCE_ENGINE_NN_BUILDER_2019 && target == DNN_TARGET_MYRIAD)
applyTestTag(CV_TEST_TAG_DNN_SKIP_IE_MYRIAD, CV_TEST_TAG_DNN_SKIP_IE_NN_BUILDER);
testONNXModels("gather");
// GPU plugin unsupported slice for constant
if (backend == DNN_BACKEND_INFERENCE_ENGINE_NGRAPH && (target == DNN_TARGET_OPENCL || target == DNN_TARGET_OPENCL_FP16))
applyTestTag(CV_TEST_TAG_DNN_SKIP_IE_OPENCL, CV_TEST_TAG_DNN_SKIP_IE_OPENCL_FP16, CV_TEST_TAG_DNN_SKIP_IE_NGRAPH);
testONNXModels("gather_scalar", npy, 0, 0, false, false);
}

TEST_P(Test_ONNX_layers, Convolution3D)
{
#if defined(INF_ENGINE_RELEASE) && INF_ENGINE_VER_MAJOR_LT(2019010000)
Expand Down

0 comments on commit a35d4f9

Please sign in to comment.