Skip to content

Commit

Permalink
Several fixes for ONNX importer: Expand, Gather
Browse files Browse the repository at this point in the history
  • Loading branch information
dkurt committed Mar 27, 2023
1 parent 352f92e commit 5e1d333
Show file tree
Hide file tree
Showing 2 changed files with 16 additions and 4 deletions.
15 changes: 11 additions & 4 deletions modules/dnn/src/onnx/onnx_importer.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -2435,12 +2435,18 @@ void ONNXImporter::parseExpand(LayerParams& layerParams, const opencv_onnx::Node
}
else
{
inpShape = shape(getBlob(input0));
Mat blob = getBlob(input0);
if (constBlobsExtraInfo.find(node_proto.input(0)) != constBlobsExtraInfo.end() &&
getBlobExtraInfo(node_proto, 0).real_ndims == 1) {
inpShape = {(int)blob.total()};
} else {
inpShape = shape(blob);
}
}

String srcName = input0;
// Unsqueeze and repeat along new axis
if (targetShape.size() == inpShape.size() + 1)
if (targetShape.size() > inpShape.size())
{
inpShape.insert(inpShape.begin(), targetShape.size() - inpShape.size(), 1);
for (int i = 0; i < targetShape.size(); i++)
Expand Down Expand Up @@ -2486,7 +2492,7 @@ void ONNXImporter::parseExpand(LayerParams& layerParams, const opencv_onnx::Node
{
if (broadcast_axes.empty())
{
addConstant(output_name, getBlob(node_proto, 0));
addConstant(output_name, getBlob(node_proto, 0).reshape(1, targetShape));
return;
}

Expand Down Expand Up @@ -2719,7 +2725,8 @@ void ONNXImporter::parseGather(LayerParams& layerParams, const opencv_onnx::Node

runLayer(layerParams, inputs, output);
output.back().convertTo(output.back(), type);
output.back().dims = std::max(input_real_ndims - real_ndims, 1);
if (real_ndims < 2) // In case of scalars or 1D vectors, OpenCV initializes 2D cv::Mat
output.back().dims = std::max(input_real_ndims - real_ndims, 1);
addConstant(node_proto.output(0), output.back());
return;
}
Expand Down
5 changes: 5 additions & 0 deletions modules/dnn/test/test_onnx_importer.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -2487,6 +2487,11 @@ TEST_P(Test_ONNX_layers, Gelu)
testONNXModels("gelu_approximation");
}

TEST_P(Test_ONNX_layers, OpenAI_CLIP_head)
{
testONNXModels("clip-vit-base-head");
}

INSTANTIATE_TEST_CASE_P(/**/, Test_ONNX_nets, dnnBackendsAndTargets());

}} // namespace

0 comments on commit 5e1d333

Please sign in to comment.