Skip to content

Commit

Permalink
Merge pull request opencv#22148 from zihaomu:gemm_onnx_bug_fix_branch34
Browse files Browse the repository at this point in the history
  • Loading branch information
alalek authored Jun 23, 2022
2 parents dd7b900 + ef94275 commit 6234f01
Show file tree
Hide file tree
Showing 2 changed files with 10 additions and 4 deletions.
8 changes: 4 additions & 4 deletions modules/dnn/src/onnx/onnx_importer.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1759,15 +1759,15 @@ void ONNXImporter::parseBatchNormalization(LayerParams& layerParams, const openc
addLayer(layerParams, node_proto);
}

// A * B + C = Y, we require that the dimension of A is [m, k], and the dimension of B is [n, k].
// And the dim of output Y is [m, n]
void ONNXImporter::parseGemm(LayerParams& layerParams, const opencv_onnx::NodeProto& node_proto)
{
CV_Assert(node_proto.input_size() >= 2);
layerParams.type = "InnerProduct";
Mat weights = getBlob(node_proto, 1);
int ind_num_out = 0;
if (layerParams.has("transB") && !layerParams.get<int>("transB")) {
if (!layerParams.get<int>("transB", 0)) {
transpose(weights, weights);
ind_num_out = 1;
}
layerParams.blobs.push_back(weights);

Expand All @@ -1789,7 +1789,7 @@ void ONNXImporter::parseGemm(LayerParams& layerParams, const opencv_onnx::NodePr
addLayer(constParams, proto);
}

layerParams.set("num_output", layerParams.blobs[0].size[ind_num_out]);
layerParams.set("num_output", layerParams.blobs[0].size[0]);
layerParams.set("bias_term", node_proto.input_size() == 3);
addLayer(layerParams, node_proto);
}
Expand Down
6 changes: 6 additions & 0 deletions modules/dnn/test/test_onnx_importer.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1389,6 +1389,12 @@ TEST_P(Test_ONNX_layers, DivConst)
testONNXModels("div_const");
}

TEST_P(Test_ONNX_layers, Gemm)
{
testONNXModels("gemm_no_transB");
testONNXModels("gemm_transB_0");
}

TEST_P(Test_ONNX_layers, OutputRegistration)
{
testONNXModels("output_registration", npy, 0, 0, false, true, 2);
Expand Down

0 comments on commit 6234f01

Please sign in to comment.