Skip to content

Commit

Permalink
fix Flatten layer
Browse files Browse the repository at this point in the history
  • Loading branch information
rogday committed Dec 17, 2021
1 parent f071207 commit fec2c7e
Show file tree
Hide file tree
Showing 5 changed files with 58 additions and 36 deletions.
1 change: 0 additions & 1 deletion modules/dnn/src/layers/flatten_layer.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -100,7 +100,6 @@ class FlattenLayerImpl CV_FINAL : public FlattenLayer
{
outputShapeVec.push_back(inputs[0][i]);
}
CV_Assert(outputShapeVec.size() <= 4);

outputs.resize(inputs.size(), outputShapeVec);

Expand Down
57 changes: 52 additions & 5 deletions modules/dnn/src/onnx/onnx_importer.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1781,20 +1781,67 @@ void ONNXImporter::parseSqueeze(LayerParams& layerParams, const opencv_onnx::Nod
addLayer(layerParams, node_proto);
}

void ONNXImporter::parseFlatten(LayerParams& layerParams, const opencv_onnx::NodeProto& node_proto)
void ONNXImporter::parseFlatten(LayerParams& layerParams, const opencv_onnx::NodeProto& node_proto_)
{
opencv_onnx::NodeProto node_proto = node_proto_;
CV_CheckEQ(node_proto.input_size(), 1, "");
int axis_ = layerParams.get<int>("axis", 1);
if (constBlobs.find(node_proto.input(0)) != constBlobs.end())
{
Mat input = getBlob(node_proto, 0);
int axis = normalize_axis(layerParams.get<int>("axis", 1), input.dims);
int axis = normalize_axis(axis_, input.dims);

std::vector<int> out_size(&input.size[0], &input.size[0] + axis);
out_size.push_back(input.total(axis));
Mat output = input.reshape(1, out_size);
int out_size[2] = {1, 1};
for (int i = 0; i < axis; ++i)
{
out_size[0] *= input.size[i];
}
for (int i = axis; i < input.dims; ++i)
{
out_size[1] *= input.size[i];
}

Mat output = input.reshape(1, 2, out_size);
addConstant(layerParams.name, output);
return;
}
IterShape_t shapeIt = outShapes.find(node_proto.input(0));
CV_Assert(shapeIt != outShapes.end());
MatShape inpShape = shapeIt->second;
int axis = normalize_axis(axis_, inpShape.size());

if (axis == 0 || axis == inpShape.size())
{
LayerParams reshapeLp;
reshapeLp.name = layerParams.name + "/reshape";
reshapeLp.type = "Reshape";
CV_Assert(layer_id.find(reshapeLp.name) == layer_id.end());

inpShape.insert(axis == 0 ? inpShape.begin() : inpShape.end(), 1);
reshapeLp.set("dim", DictValue::arrayInt(&inpShape[0], inpShape.size()));

opencv_onnx::NodeProto proto;
proto.add_input(node_proto.input(0));
proto.add_output(reshapeLp.name);
addLayer(reshapeLp, proto);
node_proto.set_input(0, reshapeLp.name);
axis += 1;
}

LayerParams first_pass;
first_pass.name = layerParams.name + "/flatten";
CV_Assert(layer_id.find(first_pass.name) == layer_id.end());
first_pass.type = "Flatten";
first_pass.set("axis", 0);
first_pass.set("end_axis", axis - 1);

opencv_onnx::NodeProto proto;
proto.add_input(node_proto.input(0));
proto.add_output(first_pass.name);
addLayer(first_pass, proto);

layerParams.set("axis", 1);
node_proto.set_input(0, first_pass.name);
addLayer(layerParams, node_proto);
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -17,12 +17,6 @@
"test_elu",
"test_elu_default",
"test_exp",
"test_flatten_axis0",
"test_flatten_axis2",
"test_flatten_axis3",
"test_flatten_negative_axis1",
"test_flatten_negative_axis2",
"test_flatten_negative_axis4",
"test_leakyrelu",
"test_leakyrelu_default",
"test_logsoftmax_axis_1",
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -561,35 +561,23 @@ CASE(test_eyelike_with_dtype)
CASE(test_eyelike_without_dtype)
// no filter
CASE(test_flatten_axis0)
#if INF_ENGINE_VER_MAJOR_EQ(2021040000)
SKIP;
#endif
// no filter
CASE(test_flatten_axis1)
// no filter
CASE(test_flatten_axis2)
#if INF_ENGINE_VER_MAJOR_EQ(2021040000)
SKIP;
#endif
// no filter
CASE(test_flatten_axis3)
#if INF_ENGINE_VER_MAJOR_EQ(2021040000)
SKIP;
#endif
// no filter
CASE(test_flatten_default_axis)
// no filter
CASE(test_flatten_negative_axis1)
#if INF_ENGINE_VER_MAJOR_EQ(2021040000)
SKIP;
#endif
// no filter
CASE(test_flatten_negative_axis2)
#if INF_ENGINE_VER_MAJOR_EQ(2021040000)
SKIP;
#endif
// no filter
CASE(test_flatten_negative_axis3)
// no filter
CASE(test_flatten_negative_axis4)
#if INF_ENGINE_VER_MAJOR_EQ(2021040000)
SKIP;
#endif
// no filter
CASE(test_floor)
// no filter
CASE(test_floor_example)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -7,12 +7,6 @@
"test_castlike_FLOAT_to_STRING_expanded",
"test_castlike_STRING_to_FLOAT_expanded",
"test_concat_1d_axis_negative_1",
"test_flatten_axis0",
"test_flatten_axis2",
"test_flatten_axis3",
"test_flatten_negative_axis1",
"test_flatten_negative_axis2",
"test_flatten_negative_axis4",
"test_logsoftmax_default_axis",
"test_maxpool_2d_dilations",
"test_maxpool_2d_same_lower",
Expand Down

0 comments on commit fec2c7e

Please sign in to comment.