Skip to content

Commit

Permalink
Merge pull request opencv#18841 from JulienMaille:patch-2
Browse files Browse the repository at this point in the history
Fixing dnn Resize layer for variable input size

* Fix onnx loading of resize/upsample layers for different opset

* group all DynamicResize tests

* cleaned up scales checks

* Simplify branching
  • Loading branch information
JulienMaille authored Nov 20, 2020
1 parent bb067c7 commit ac24a72
Show file tree
Hide file tree
Showing 2 changed files with 47 additions and 36 deletions.
76 changes: 41 additions & 35 deletions modules/dnn/src/onnx/onnx_importer.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1746,43 +1746,45 @@ void ONNXImporter::handleNode(const opencv_onnx::NodeProto& node_proto_)
for (int i = 1; i < node_proto.input_size(); i++)
CV_Assert(layer_id.find(node_proto.input(i)) == layer_id.end());

String interp_mode;
if (layerParams.has("coordinate_transformation_mode"))
interp_mode = layerParams.get<String>("coordinate_transformation_mode");
else
interp_mode = layerParams.get<String>("mode");
CV_Assert_N(interp_mode != "tf_crop_and_resize", interp_mode != "tf_half_pixel_for_nn");

layerParams.set("align_corners", interp_mode == "align_corners");
Mat shapes = getBlob(node_proto, node_proto.input_size() - 1);
CV_CheckEQ(shapes.size[0], 4, "");
CV_CheckEQ(shapes.size[1], 1, "");
CV_CheckDepth(shapes.depth(), shapes.depth() == CV_32S || shapes.depth() == CV_32F, "");
if (shapes.depth() == CV_32F)
shapes.convertTo(shapes, CV_32S);
int height = shapes.at<int>(2);
int width = shapes.at<int>(3);
if (hasDynamicShapes)
{
layerParams.set("zoom_factor_x", width);
layerParams.set("zoom_factor_y", height);
String interp_mode = layerParams.get<String>("coordinate_transformation_mode");
CV_Assert_N(interp_mode != "tf_crop_and_resize", interp_mode != "tf_half_pixel_for_nn");

layerParams.set("align_corners", interp_mode == "align_corners");
if (layerParams.get<String>("mode") == "linear")
{
layerParams.set("mode", interp_mode == "pytorch_half_pixel" ?
"opencv_linear" : "bilinear");
}
}
if (layerParams.get<String>("mode") == "linear" && framework_name == "pytorch")
layerParams.set("mode", "opencv_linear");

// input = [X, scales], [X, roi, scales] or [x, roi, scales, sizes]
int foundScaleId = hasDynamicShapes ? node_proto.input_size() - 1
: node_proto.input_size() > 2 ? 2 : 1;

Mat scales = getBlob(node_proto, foundScaleId);
if (scales.total() == 4)
{
layerParams.set("zoom_factor_y", scales.at<float>(2));
layerParams.set("zoom_factor_x", scales.at<float>(3));
}
else
{
if (node_proto.input_size() == 3) {
IterShape_t shapeIt = outShapes.find(node_proto.input(0));
CV_Assert(shapeIt != outShapes.end());
MatShape scales = shapeIt->second;
height *= scales[2];
width *= scales[3];
const std::string& inputLast = node_proto.input(node_proto.input_size() - 1);
if (constBlobs.find(inputLast) != constBlobs.end())
{
Mat shapes = getBlob(inputLast);
CV_CheckEQ(shapes.size[0], 4, "");
CV_CheckEQ(shapes.size[1], 1, "");
CV_CheckDepth(shapes.depth(), shapes.depth() == CV_32S || shapes.depth() == CV_32F, "");
if (shapes.depth() == CV_32F)
shapes.convertTo(shapes, CV_32S);
layerParams.set("width", shapes.at<int>(3));
layerParams.set("height", shapes.at<int>(2));
}
layerParams.set("width", width);
layerParams.set("height", height);
}

if (layerParams.get<String>("mode") == "linear") {
layerParams.set("mode", interp_mode == "pytorch_half_pixel" ?
"opencv_linear" : "bilinear");
}
replaceLayerParam(layerParams, "mode", "interpolation");
}
Expand Down Expand Up @@ -1822,10 +1824,14 @@ void ONNXImporter::handleNode(const opencv_onnx::NodeProto& node_proto_)
else
{
// scales as input
Mat scales = getBlob(node_proto, 1);
CV_Assert(scales.total() == 4);
layerParams.set("zoom_factor_y", scales.at<float>(2));
layerParams.set("zoom_factor_x", scales.at<float>(3));
const std::string& input1 = node_proto.input(1);
if (constBlobs.find(input1) != constBlobs.end())
{
Mat scales = getBlob(input1);
CV_Assert(scales.total() == 4);
layerParams.set("zoom_factor_y", scales.at<float>(2));
layerParams.set("zoom_factor_x", scales.at<float>(3));
}
}
replaceLayerParam(layerParams, "mode", "interpolation");
}
Expand Down
7 changes: 6 additions & 1 deletion modules/dnn/test/test_onnx_importer.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -518,7 +518,12 @@ TEST_P(Test_ONNX_layers, Broadcast)

TEST_P(Test_ONNX_layers, DynamicResize)
{
testONNXModels("dynamic_resize", npy, 0, 0, false, true, 2);
testONNXModels("dynamic_resize_9", npy, 0, 0, false, true, 2);
testONNXModels("dynamic_resize_10", npy, 0, 0, false, true, 2);
testONNXModels("dynamic_resize_11", npy, 0, 0, false, true, 2);
testONNXModels("dynamic_resize_scale_9", npy, 0, 0, false, true, 2);
testONNXModels("dynamic_resize_scale_10", npy, 0, 0, false, true, 2);
testONNXModels("dynamic_resize_scale_11", npy, 0, 0, false, true, 2);
}

TEST_P(Test_ONNX_layers, Div)
Expand Down

0 comments on commit ac24a72

Please sign in to comment.