Skip to content

Commit

Permalink
Register activation with its id
Browse files Browse the repository at this point in the history
  • Loading branch information
daquexian committed Apr 11, 2019
1 parent 96eb179 commit 9b82c1f
Show file tree
Hide file tree
Showing 2 changed files with 17 additions and 9 deletions.
18 changes: 12 additions & 6 deletions tools/onnx2daq/OnnxConverter.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -37,11 +37,14 @@ DNN::FuseCode OnnxConverter::ConvertFuseCodeType(FuseCode fuse_code) {
throw std::invalid_argument("Invalid FuseCode");
}

std::pair<nonstd::optional<std::string>, OnnxConverter::FuseCode>
std::pair<nonstd::optional<std::pair<int, ONNX_NAMESPACE::NodeProto>>,
OnnxConverter::FuseCode>
OnnxConverter::FindActivation(const ONNX_NAMESPACE::ModelProto &model_proto,
css &output_name) {
std::pair<nonstd::optional<string>, FuseCode> activation{
{}, FuseCode::FUSED_NONE};
std::pair<nonstd::optional<std::pair<int, ONNX_NAMESPACE::NodeProto>>,
FuseCode>
activation{{}, FuseCode::FUSED_NONE};
int i = 0;
for (const auto &_node : model_proto.graph().node()) {
if (!_node.input().empty() && output_name == _node.input(0) &&
_node.op_type() == "Relu") {
Expand All @@ -50,9 +53,11 @@ OnnxConverter::FindActivation(const ONNX_NAMESPACE::ModelProto &model_proto,
if (activation.second != FuseCode::FUSED_NONE) {
return {{}, FuseCode::FUSED_NONE};
}
activation = std::make_pair(nonstd::make_optional(_node.name()),
const auto node_pair = std::make_pair(i, _node);
activation = std::make_pair(nonstd::make_optional(node_pair),
FuseCode::FUSED_RELU);
}
i++;
}
if (activation.first.has_value()) {
skipped_act_.push_back(activation.first.value().first);
Expand Down Expand Up @@ -791,11 +796,12 @@ void OnnxConverter::Convert(const ONNX_NAMESPACE::ModelProto &model_proto,
const auto inputs = GetInputOfOnnxModel();

bool has_reshape = false;
for (const auto &node : model_proto_.graph().node()) {
for (int i = 0; i < model_proto_.graph().node_size(); i++) {
const auto &node = model_proto_.graph().node(i);
NodeAttrHelper helper(node);
const auto &op = node.op_type();
LOG(INFO) << "Node " << node.name();
if (std::find(skipped_act_.begin(), skipped_act_.end(), node.name()) !=
if (std::find(skipped_act_.begin(), skipped_act_.end(), i) !=
skipped_act_.end()) {
LOG(INFO) << "Skip layer " << node.name();
continue;
Expand Down
8 changes: 5 additions & 3 deletions tools/onnx2daq/OnnxConverter.h
Original file line number Diff line number Diff line change
Expand Up @@ -56,7 +56,7 @@ class OnnxConverter {

ONNX_NAMESPACE::ModelProto model_proto_;
flatbuffers::FlatBufferBuilder builder_;
std::vector<std::string> skipped_act_;
std::vector<int> skipped_act_;
std::vector<std::string> dequantize_after_;

std::vector<std::string> operands_;
Expand All @@ -67,8 +67,10 @@ class OnnxConverter {
std::vector<flatbuffers::Offset<DNN::Tensor>> tensors_;

DNN::FuseCode ConvertFuseCodeType(FuseCode fuse_code);
std::pair<nonstd::optional<std::string>, FuseCode> FindActivation(
const ONNX_NAMESPACE::ModelProto &model_proto, css &output_name);
std::pair<nonstd::optional<std::pair<int, ONNX_NAMESPACE::NodeProto>>,
FuseCode>
FindActivation(const ONNX_NAMESPACE::ModelProto &model_proto,
css &output_name);
void CreateTensorFb(const Tensor &tensor, const DNN::DataType &data_type);
void CreateTensorFb(const std::string &name, const Tensor &tensor);
void CreateTensorFb(const std::string &name, const Tensor &tensor,
Expand Down

0 comments on commit 9b82c1f

Please sign in to comment.