Skip to content

Commit

Permalink
[OPENVINO] fix tflite innerproduct error
Browse files Browse the repository at this point in the history
  • Loading branch information
seanxcwang committed Apr 8, 2021
1 parent c0a32c5 commit 842cbd0
Showing 1 changed file with 27 additions and 45 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -15,17 +15,17 @@
#include <cmath>
#include <memory>

#include <ngraph/node.hpp>
#include <inference_engine.hpp>
#include <ngraph/ngraph.hpp>
#include <ngraph/node.hpp>
#include <ngraph/op/op.hpp>
#include <ngraph/opsets/opset.hpp>
#include <ngraph/opsets/opset1.hpp>
#include <inference_engine.hpp>

#include "tnn/layer/base_layer.h"
#include "tnn/network/openvino/layer_builder/openvino_layer_builder.h"
#include "tnn/extern_wrapper/foreign_blob.h"
#include "tnn/extern_wrapper/foreign_tensor.h"
#include "tnn/layer/base_layer.h"
#include "tnn/network/openvino/layer_builder/openvino_layer_builder.h"
#include "tnn/network/openvino/openvino_types.h"
#include "tnn/utils/dims_utils.h"

Expand All @@ -34,19 +34,18 @@ namespace TNN_NS {
DECLARE_OPENVINO_LAYER_BUILDER(InnerProduct, LAYER_INNER_PRODUCT);

Status InnerProductOVLayerBuilder::Build() {

auto paramlist = dynamic_cast<InnerProductLayerParam *>(param_);
auto resource = dynamic_cast<InnerProductLayerResource *>(resource_);
if (GetInputNodes().size() <=0) {

if (GetInputNodes().size() <= 0) {
LOGE("Error: 0 input nodes\n");
return TNNERR_INIT_LAYER;
}
auto input_node = GetInputNodes()[0];

auto get_shape_count = [&](const ngraph::Shape &shape, int axis) -> size_t {
size_t res = 1;
for (int i=axis; i < shape.size();i++)
for (int i = axis; i < shape.size(); i++)
res *= shape[i];
return res;
};
Expand All @@ -58,35 +57,19 @@ Status InnerProductOVLayerBuilder::Build() {
matShape.push_back(m);
matShape.push_back(n);

auto reshapeConstNode = std::make_shared<ngraph::op::Constant>(
ngraph::element::Type_t::i32, ngraph::Shape({2}), matShape);

auto reshapeNode = std::make_shared<ngraph::op::v1::Reshape>(
input_node->output(0), reshapeConstNode, true);

ngraph::Shape weightsShape;

auto weightsNode = std::make_shared<ngraph::op::Constant>(
ngraph::element::Type_t::f32, ngraph::Shape({k, n}), resource->weight_handle.force_to<float *>());

auto matMulNode = std::make_shared<ngraph::op::MatMul>(
reshapeNode->output(0), weightsNode->output(0), false, !paramlist->transpose);

std::vector<int> matReshape;
matReshape.push_back(m);
matReshape.push_back(k);
matReshape.push_back(1);
matReshape.push_back(1);

auto reverseReshapeConstNode = std::make_shared<ngraph::op::Constant>(
ngraph::element::Type_t::i32, ngraph::Shape({4}), matReshape);

auto reverseReshapeNode = std::make_shared<ngraph::op::v1::Reshape>(
matMulNode->output(0), reverseReshapeConstNode, true);

auto reshapeConstNode =
std::make_shared<ngraph::op::Constant>(ngraph::element::Type_t::i32, ngraph::Shape({2}), matShape);

auto reshapeNode = std::make_shared<ngraph::op::v1::Reshape>(input_node->output(0), reshapeConstNode, true);

auto weightsNode = std::make_shared<ngraph::op::Constant>(ngraph::element::Type_t::f32, ngraph::Shape({k, n}),
resource->weight_handle.force_to<float *>());

auto matMulNode = std::make_shared<ngraph::op::MatMul>(reshapeNode->output(0), weightsNode->output(0), false, true);

if (paramlist->has_bias) {
ngraph::Shape biasShape;
auto output_shape = reverseReshapeNode->get_output_shape(0);
auto output_shape = matMulNode->get_output_shape(0);
for (int i = 0; i < output_shape.size(); i++) {
if (i == paramlist->axis) {
biasShape.push_back(output_shape.at(i));
Expand All @@ -95,12 +78,11 @@ Status InnerProductOVLayerBuilder::Build() {
}
}

auto biasNode = std::make_shared<ngraph::op::Constant>(
ngraph::element::Type_t::f32, biasShape, resource->bias_handle.force_to<float*>());

auto addNode = std::make_shared<ngraph::op::v1::Add>(
reverseReshapeNode->output(0), biasNode->output(0));

auto biasNode = std::make_shared<ngraph::op::Constant>(ngraph::element::Type_t::f32, biasShape,
resource->bias_handle.force_to<float *>());

auto addNode = std::make_shared<ngraph::op::v1::Add>(matMulNode->output(0), biasNode->output(0));

addNode->set_friendly_name(paramlist->name);
addNode->validate_and_infer_types();

Expand All @@ -109,15 +91,15 @@ Status InnerProductOVLayerBuilder::Build() {
SetOutputTensors(outputNodes);

} else {
reverseReshapeNode->set_friendly_name(paramlist->name);
reverseReshapeNode->validate_and_infer_types();
matMulNode->set_friendly_name(paramlist->name);
matMulNode->validate_and_infer_types();

ngraph::NodeVector outputNodes;
outputNodes.push_back(reverseReshapeNode);
outputNodes.push_back(matMulNode);
SetOutputTensors(outputNodes);
}
return TNN_OK;
}

REGISTER_OPENVINO_LAYER_BUILDER(InnerProduct, LAYER_INNER_PRODUCT);
}
} // namespace TNN_NS

0 comments on commit 842cbd0

Please sign in to comment.