Skip to content

Commit

Permalink
[ONNX importer] Add support for Usample-8 and Upsample-9 (openvinotoo…
Browse files Browse the repository at this point in the history
  • Loading branch information
mitruska authored Jun 23, 2020
1 parent 5ad1bf6 commit 5a2df9e
Show file tree
Hide file tree
Showing 13 changed files with 705 additions and 0 deletions.
2 changes: 2 additions & 0 deletions ngraph/src/ngraph/frontend/onnx_import/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -232,6 +232,8 @@ add_library(onnx_importer SHARED
op/transpose.hpp
op/unsqueeze.cpp
op/unsqueeze.hpp
op/upsample.cpp
op/upsample.hpp
op/where.hpp
op/xor.hpp
ops_bridge.cpp
Expand Down
177 changes: 177 additions & 0 deletions ngraph/src/ngraph/frontend/onnx_import/op/upsample.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,177 @@
//*****************************************************************************
// Copyright 2017-2020 Intel Corporation
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//*****************************************************************************

#include <memory>

#include "default_opset.hpp"
#include "exceptions.hpp"
#include "upsample.hpp"

namespace ngraph
{
namespace onnx_import
{
namespace op
{
namespace
{
bool check_mode_support(const onnx_import::Node& node, const std::string mode)
{
const std::unordered_set<std::string> supported_modes = {"nearest", "linear"};
bool is_mode_supported =
(std::find(supported_modes.begin(), supported_modes.end(), mode) !=
supported_modes.end());

if (!is_mode_supported)
{
std::string supported_modes_str = "";
for (const auto& mode_name : supported_modes)
{
supported_modes_str += (mode_name + ", ");
}
CHECK_VALID_NODE(node,
is_mode_supported,
mode,
" - this type of interpolation mode is not supported."
" Choose one of the following modes: ",
supported_modes_str);
}
return is_mode_supported;
}
}

namespace set_1
{
NodeVector upsample(const onnx_import::Node& node)
{
const auto inputs = node.get_ng_inputs();
const auto data = inputs.at(0);

const auto data_shape = data->get_output_partial_shape(0);

const auto scales = node.get_attribute_value<std::vector<float>>("scales");
const auto mode = node.get_attribute_value<std::string>("mode", "nearest");
check_mode_support(node, mode);

auto attrs = ngraph::op::v0::InterpolateAttrs();
attrs.mode = mode;
attrs.align_corners = false;

for (size_t ax = 0; ax < scales.size(); ++ax)
{
attrs.axes.insert(ax);
}

if (data_shape.is_static())
{
auto data_static_shape = data_shape.to_shape();

std::vector<int64_t> output_shape;
for (size_t i = 0; i < data_static_shape.size(); ++i)
{
output_shape.push_back(
std::floor(data_static_shape.at(i) * scales.at(i)));
}
auto output_shape_const = default_opset::Constant::create(
element::u64, Shape({output_shape.size()}), output_shape);

return {std::make_shared<default_opset::Interpolate>(
data, output_shape_const, attrs)};
}

const auto scales_const = default_opset::Constant::create(
ngraph::element::f32, Shape({scales.size()}), scales);

auto shape_of_data = std::make_shared<default_opset::Convert>(
std::make_shared<default_opset::ShapeOf>(data), ngraph::element::f32);
auto multiply =
std::make_shared<default_opset::Multiply>(shape_of_data, scales_const);
auto output_shape = std::make_shared<default_opset::Convert>(
std::make_shared<default_opset::Floor>(multiply), ngraph::element::i64);
return {
std::make_shared<default_opset::Interpolate>(data, output_shape, attrs)};
}

} // namespace set_1

namespace set_9
{
NodeVector upsample(const onnx_import::Node& node)
{
const auto inputs = node.get_ng_inputs();
const auto data = inputs.at(0);
const auto scales = inputs.at(1);

const auto data_shape = data->get_output_partial_shape(0);
const auto scales_shape = scales->get_output_partial_shape(0);

const auto mode = node.get_attribute_value<std::string>("mode", "nearest");
check_mode_support(node, mode);

CHECK_VALID_NODE(
node,
(scales_shape.is_static() || data_shape.rank().is_static()),
" Data rank or shape of Scales input is required to be static.");

auto attrs = ngraph::op::v0::InterpolateAttrs();
attrs.mode = mode;
attrs.align_corners = false;

size_t axes_size = scales_shape.is_static() ? scales_shape.to_shape().at(0)
: data_shape.rank().get_length();
for (size_t ax = 0; ax < axes_size; ++ax)
{
attrs.axes.insert(ax);
}

if (scales->is_constant() && data_shape.is_static())
{
const auto scales_const =
as_type_ptr<default_opset::Constant>(scales->shared_from_this());

auto scales_vector = scales_const->cast_vector<float>();
auto data_static_shape = data_shape.to_shape();

std::vector<int64_t> output_shape;
for (size_t i = 0; i < data_static_shape.size(); ++i)
{
output_shape.push_back(
std::floor(data_static_shape.at(i) * scales_vector.at(i)));
}
auto output_shape_const = default_opset::Constant::create(
element::u64, Shape({output_shape.size()}), output_shape);

return {std::make_shared<default_opset::Interpolate>(
data, output_shape_const, attrs)};
}

auto shape_of_data = std::make_shared<default_opset::Convert>(
std::make_shared<default_opset::ShapeOf>(data), ngraph::element::f32);
auto multiply =
std::make_shared<default_opset::Multiply>(shape_of_data, scales);
auto output_shape = std::make_shared<default_opset::Convert>(
std::make_shared<default_opset::Floor>(multiply), ngraph::element::i64);
return {
std::make_shared<default_opset::Interpolate>(data, output_shape, attrs)};
}

} // namespace set_9

} // namespace op

} // namespace onnx_import

} // namespace ngraph
44 changes: 44 additions & 0 deletions ngraph/src/ngraph/frontend/onnx_import/op/upsample.hpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,44 @@
//*****************************************************************************
// Copyright 2017-2020 Intel Corporation
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//*****************************************************************************

#pragma once

#include "core/node.hpp"
#include "ngraph/node.hpp"

namespace ngraph
{
namespace onnx_import
{
namespace op
{
namespace set_1
{
NodeVector upsample(const Node& node);

} // namespace set_1

namespace set_9
{
NodeVector upsample(const Node& node);

} // namespace set_9

} // namespace op

} // namespace onnx_import

} // namespace ngraph
3 changes: 3 additions & 0 deletions ngraph/src/ngraph/frontend/onnx_import/ops_bridge.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -136,6 +136,7 @@
#include "op/topk.hpp"
#include "op/transpose.hpp"
#include "op/unsqueeze.hpp"
#include "op/upsample.hpp"
#include "op/where.hpp"
#include "op/xor.hpp"
#include "ops_bridge.hpp"
Expand Down Expand Up @@ -385,6 +386,8 @@ namespace ngraph
REGISTER_OPERATOR("TopK", 11, topk);
REGISTER_OPERATOR("Transpose", 1, transpose);
REGISTER_OPERATOR("Unsqueeze", 1, unsqueeze);
REGISTER_OPERATOR("Upsample", 1, upsample);
REGISTER_OPERATOR("Upsample", 9, upsample);
REGISTER_OPERATOR("Where", 1, where);
REGISTER_OPERATOR("Xor", 1, logical_xor);
}
Expand Down
59 changes: 59 additions & 0 deletions ngraph/test/models/onnx/upsample8_linear.prototxt
Original file line number Diff line number Diff line change
@@ -0,0 +1,59 @@
ir_version: 7
producer_name: "onnx-importer-test"
graph {
node {
input: "X"
output: "Y"
op_type: "Upsample"
attribute {
name: "mode"
s: "linear"
type: STRING
}
attribute {
name: "scales"
floats: 1.0
floats: 1.0
floats: 2.0
floats: 2.0
type: FLOATS
}
}
name: "test-model"
input {
name: "X"
type {
tensor_type {
elem_type: 1
shape {
dim {
dim_value: 1
}
dim {
dim_value: 1
}
dim {
dim_value: 2
}
dim {
dim_value: 2
}
}
}
}
}
output {
name: "Y"
type {
tensor_type {
elem_type: 1
shape {
}
}
}
}
}
opset_import {
domain: ""
version: 8
}
59 changes: 59 additions & 0 deletions ngraph/test/models/onnx/upsample8_nearest.prototxt
Original file line number Diff line number Diff line change
@@ -0,0 +1,59 @@
ir_version: 7
producer_name: "onnx-importer-test"
graph {
node {
input: "X"
output: "Y"
op_type: "Upsample"
attribute {
name: "mode"
s: "nearest"
type: STRING
}
attribute {
name: "scales"
floats: 1.0
floats: 1.0
floats: 2.0
floats: 3.0
type: FLOATS
}
}
name: "test-model"
input {
name: "X"
type {
tensor_type {
elem_type: 1
shape {
dim {
dim_value: 1
}
dim {
dim_value: 1
}
dim {
dim_value: 2
}
dim {
dim_value: 2
}
}
}
}
}
output {
name: "Y"
type {
tensor_type {
elem_type: 1
shape {
}
}
}
}
}
opset_import {
domain: ""
version: 8
}
Loading

0 comments on commit 5a2df9e

Please sign in to comment.