Skip to content

Commit

Permalink
NodeVector -> OutputVector replacement (openvinotoolkit#1272)
Browse files Browse the repository at this point in the history
  • Loading branch information
mitruska authored Jul 29, 2020
1 parent dec7df1 commit f345116
Show file tree
Hide file tree
Showing 294 changed files with 883 additions and 901 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -23,8 +23,8 @@ void ngraph::pass::ConvertBatchToSpace::convert_batch_to_space() {
return false;
}
auto last_node = batch_to_space->decompose_op()[0];
last_node->set_friendly_name(batch_to_space->get_friendly_name());
ngraph::replace_node(batch_to_space, last_node);
last_node.get_node()->set_friendly_name(batch_to_space->get_friendly_name());
ngraph::replace_node(batch_to_space, last_node.get_node_shared_ptr());
return true;
};

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -23,8 +23,8 @@ void ngraph::pass::ConvertSpaceToBatch::convert_space_to_batch() {
return false;
}
auto last_node = space_to_batch->decompose_op()[0];
last_node->set_friendly_name(space_to_batch->get_friendly_name());
ngraph::replace_node(space_to_batch, last_node);
last_node.get_node()->set_friendly_name(space_to_batch->get_friendly_name());
ngraph::replace_node(space_to_batch, last_node.get_node_shared_ptr());
return true;
};

Expand Down
2 changes: 1 addition & 1 deletion ngraph/src/ngraph/builder/matmul_factory.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -68,7 +68,7 @@ Output<Node> builder::MatmulFactory::get_right()
return m_inputs.at(1);
}

NodeVector builder::MatmulFactory::make_matmul_op()
OutputVector builder::MatmulFactory::make_matmul_op()
{
auto left = get_left();
auto right = get_right();
Expand Down
4 changes: 2 additions & 2 deletions ngraph/src/ngraph/builder/matmul_factory.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -40,8 +40,8 @@ namespace ngraph

/// \brief Create a sub-graph representing an ONNX MatMul operation.
///
/// \return NodeVector containing the sub-graph output node.
virtual NodeVector make_matmul_op();
/// \return OutputVector containing the sub-graph output node.
virtual OutputVector make_matmul_op();

protected:
/// \return Output representing the left operand.
Expand Down
15 changes: 8 additions & 7 deletions ngraph/src/ngraph/builder/quantization_utils.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -172,8 +172,9 @@ namespace ngraph
}
}

void
check_concat(const NodeVector& args, const NodeVector& mins, const NodeVector& maxs)
void check_concat(const OutputVector& args,
const OutputVector& mins,
const OutputVector& maxs)
{
auto size = args.size();
if (size != mins.size() || size != maxs.size())
Expand All @@ -184,17 +185,17 @@ namespace ngraph
{
auto min = mins[i];
auto max = maxs[i];
auto type = min->get_element_type();
if (type != max->get_element_type())
auto type = min.get_element_type();
if (type != max.get_element_type())
{
throw ngraph_error("check_concat: min and max must have same type");
}

if (min->get_shape() != Shape{1} || max->get_shape() != Shape{1})
if (min.get_shape() != Shape{1} || max.get_shape() != Shape{1})
{
throw ngraph_error("check_concat: min/max shape not Shape{1}: " +
vector_to_string(min->get_shape()) +
vector_to_string(max->get_shape()));
vector_to_string(min.get_shape()) +
vector_to_string(max.get_shape()));
}
}
}
Expand Down
6 changes: 3 additions & 3 deletions ngraph/src/ngraph/builder/quantization_utils.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -64,9 +64,9 @@ namespace ngraph
const ngraph::element::Type& output_type,
const bool requantize = true);

void check_concat(const NodeVector& args,
const NodeVector& mins,
const NodeVector& maxs);
void check_concat(const OutputVector& args,
const OutputVector& mins,
const OutputVector& maxs);
}
}
}
16 changes: 8 additions & 8 deletions ngraph/src/ngraph/builder/quantized_concat_builder.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -25,28 +25,28 @@ namespace ngraph
{
namespace builder
{
shared_ptr<Node> QuantizedConcatBuilder(const NodeVector& args,
shared_ptr<Node> QuantizedConcatBuilder(const OutputVector& args,
size_t concatenation_axis,
const NodeVector& mins,
const NodeVector& maxs)
const OutputVector& mins,
const OutputVector& maxs)
{
quantization_utils::check_concat(args, mins, maxs);
auto quant_type = args[0]->get_element_type();
auto quant_type = args[0].get_element_type();

// output scale
auto min = make_shared<op::Min>(make_shared<op::Concat>(mins, 0), ngraph::AxisSet{0});
auto max = make_shared<op::Max>(make_shared<op::Concat>(maxs, 0), ngraph::AxisSet{0});
auto out_scale = quantization_utils::get_scale(min, max, quant_type);

NodeVector rescaled_args(args.size());
OutputVector rescaled_args(args.size());
for (size_t i = 0; i < args.size(); ++i)
{
auto q_type = args[i]->get_element_type();
auto q_type = args[i].get_element_type();
auto in_scale = make_shared<ngraph::op::Reshape>(
quantization_utils::get_scale(mins[i], maxs[i], q_type),
AxisVector{0},
Shape{});
auto zero = make_constant(q_type, in_scale->get_shape(), 0);
auto zero = make_constant(q_type, in_scale->get_output_shape(0), 0);

rescaled_args[i] =
make_shared<op::Dequantize>(args[i], in_scale, zero, element::f32, AxisSet{});
Expand All @@ -58,7 +58,7 @@ namespace ngraph
AxisSet{},
op::Quantize::RoundMode::ROUND_NEAREST_TOWARD_EVEN);
}
OutputVector base = as_output_vector(args);
OutputVector base = args;
for (auto node : mins)
{
base.push_back(node);
Expand Down
6 changes: 3 additions & 3 deletions ngraph/src/ngraph/builder/quantized_concat_builder.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -32,9 +32,9 @@ namespace ngraph
namespace builder
{
NGRAPH_API
std::shared_ptr<Node> QuantizedConcatBuilder(const NodeVector& args,
std::shared_ptr<Node> QuantizedConcatBuilder(const OutputVector& args,
size_t concatenation_axis,
const NodeVector& mins,
const NodeVector& maxs);
const OutputVector& mins,
const OutputVector& maxs);
}
}
44 changes: 10 additions & 34 deletions ngraph/src/ngraph/builder/split.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -47,37 +47,13 @@ namespace
std::make_shared<op::Slice>(output, lower_bounds, upper_bounds)
->add_provenance_group_members_above({output}));
}

/// \brief Return the outputs of the node as vector.
///
/// \param[in] node Node with multiple outputs.
///
/// \return Vector of outputs of input node.
NodeVector get_outputs(const std::shared_ptr<ngraph::Node>& node)
{
const auto outputs_number = node->get_output_size();
ngraph::NodeVector outputs(outputs_number);
for (int i = 0; i < outputs_number; ++i)
{
if (node->output(i).get_node_shared_ptr()->get_output_size() == 1)
{
outputs[i] = node->get_output_as_single_output_node(i);
}
else
{
outputs[i] = std::make_shared<op::GetOutputElement>(node, i);
}
}
return outputs;
}
}

NodeVector builder::split(const Output<ngraph::Node>& value,
const std::vector<size_t>& length_parts,
size_t axis)
OutputVector
builder::split(const Output<Node>& value, const std::vector<size_t>& length_parts, size_t axis)
{
size_t start_index{0};
NodeVector outputs;
OutputVector outputs;
for (const auto& length_part : length_parts)
{
size_t end_index{start_index + length_part};
Expand All @@ -87,7 +63,7 @@ NodeVector builder::split(const Output<ngraph::Node>& value,
return outputs;
}

NodeVector builder::split(const Output<Node>& value, size_t split_parts, int axis)
OutputVector builder::split(const Output<Node>& value, size_t split_parts, int axis)
{
size_t axis_to_split{static_cast<size_t>(axis)};
if (axis < 0)
Expand All @@ -100,23 +76,23 @@ NodeVector builder::split(const Output<Node>& value, size_t split_parts, int axi
return split(value, length_parts, axis_to_split);
}

NodeVector builder::opset1::split(const Output<Node>& value,
const std::vector<size_t>& split_lengths,
int64_t axis)
OutputVector builder::opset1::split(const Output<Node>& value,
const std::vector<size_t>& split_lengths,
int64_t axis)
{
const auto axis_node = ngraph::opset1::Constant::create(element::u64, Shape{}, {axis});
const auto split_lengths_node =
ngraph::opset1::Constant::create(element::u64, Shape{split_lengths.size()}, split_lengths);
const auto variadic_split =
std::make_shared<ngraph::opset1::VariadicSplit>(value, axis_node, split_lengths_node);

return get_outputs(variadic_split);
return variadic_split->outputs();
}

NodeVector builder::opset1::split(const Output<Node>& value, size_t num_splits, int64_t axis)
OutputVector builder::opset1::split(const Output<Node>& value, size_t num_splits, int64_t axis)
{
const auto axis_node = ngraph::opset1::Constant::create(element::u64, Shape{}, {axis});
const auto split = std::make_shared<ngraph::opset1::Split>(value, axis_node, num_splits);

return get_outputs(split);
return split->outputs();
}
20 changes: 10 additions & 10 deletions ngraph/src/ngraph/builder/split.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -31,9 +31,9 @@ namespace ngraph
///
/// \return The vector containing multiple nodes we split input node into.
///
NodeVector split(const Output<Node>& value,
const std::vector<size_t>& length_parts,
size_t axis = 0);
OutputVector split(const Output<Node>& value,
const std::vector<size_t>& length_parts,
size_t axis = 0);

/// \brief Split node on specified axis into multiple parts.
///
Expand All @@ -47,9 +47,9 @@ namespace ngraph
/// indexing). This means that the axis to split on will be counted from
/// the back of the tensor (negative values are subtracted from its rank).
///
/// \return The vector containing multiple nodes we split input node into.
/// \return The vector containing multiple outputs we split input node into.
///
NodeVector split(const Output<Node>& value, size_t split_parts, int axis = 0);
OutputVector split(const Output<Node>& value, size_t split_parts, int axis = 0);

namespace opset1
{
Expand All @@ -63,13 +63,13 @@ namespace ngraph
/// indexing). This means that the axis to split on will be counted from
/// the back of the tensor (negative values are subtracted from its rank).
///
/// \return The vector containing multiple nodes we split input node into.
/// \return The vector containing multiple outputs we split input node into.
/// The vector is output of Split:v1 op
///
NGRAPH_API
NodeVector split(const Output<Node>& value,
const std::vector<size_t>& split_lengths,
int64_t axis = 0);
OutputVector split(const Output<Node>& value,
const std::vector<size_t>& split_lengths,
int64_t axis = 0);

/// \brief Split value on specified axis into multiple parts.
///
Expand All @@ -88,7 +88,7 @@ namespace ngraph
/// The vector is output of VariadicSplit:v1 op
///
NGRAPH_API
NodeVector split(const Output<Node>& value, size_t num_splits, int64_t axis = 0);
OutputVector split(const Output<Node>& value, size_t num_splits, int64_t axis = 0);
}
} // namespace builder
} // namespace ngraph
22 changes: 11 additions & 11 deletions ngraph/src/ngraph/frontend/onnx_import/core/graph.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -157,7 +157,7 @@ namespace ngraph
m_nodes.emplace_back(node_proto, *this);
const Node& node{m_nodes.back()};

NodeVector ng_nodes{node.get_ng_nodes()};
OutputVector ng_nodes{node.get_ng_nodes()};
// Iterate over the number of outputs for given node in graph.
// Some of them may be optional and trimmed. See:
// https://github.com/onnx/onnx/blob/master/docs/IR.md#optional-inputs-and-outputs
Expand All @@ -174,26 +174,26 @@ namespace ngraph
return m_cache->contains(name);
}

std::shared_ptr<ngraph::Node> Graph::get_ng_node_from_cache(const std::string& name) const
Output<ngraph::Node> Graph::get_ng_node_from_cache(const std::string& name) const
{
return m_cache->get_node(name);
}

NodeVector Graph::get_ng_outputs() const
OutputVector Graph::get_ng_outputs() const
{
NodeVector results;
OutputVector results;
for (const auto& output : m_graph_proto->output())
{
results.emplace_back(get_ng_node_from_cache(output.name()));
}
return results;
}

NodeVector Graph::make_ng_nodes(const Node& onnx_node) const
OutputVector Graph::make_ng_nodes(const Node& onnx_node) const
{
const auto ng_node_factory =
m_model->get_operator(onnx_node.op_type(), onnx_node.domain());
NodeVector ng_node_vector;
OutputVector ng_node_vector;
try
{
ng_node_vector = ng_node_factory(onnx_node);
Expand Down Expand Up @@ -223,7 +223,7 @@ namespace ngraph
}

void Graph::set_friendly_names(const Node& onnx_node,
const NodeVector& ng_node_vector) const
const OutputVector& ng_node_vector) const
{
for (int i = 0; i < ng_node_vector.size(); ++i)
{
Expand All @@ -234,7 +234,7 @@ namespace ngraph
break;
}

ng_node_vector[i]->set_friendly_name(onnx_node.output(i));
ng_node_vector[i].get_node()->set_friendly_name(onnx_node.output(i));
}
}

Expand Down Expand Up @@ -267,7 +267,7 @@ namespace ngraph
}

void Graph::add_provenance_tags(const Node& onnx_node,
const NodeVector& ng_node_vector) const
const OutputVector& ng_node_vector) const
{
if (!ngraph::get_provenance_enabled())
{
Expand All @@ -278,9 +278,9 @@ namespace ngraph
const auto ng_inputs = onnx_node.get_ng_inputs();

ngraph::traverse_nodes(
ng_node_vector,
as_node_vector(ng_node_vector),
[&tag](std::shared_ptr<ngraph::Node> ng_node) { ng_node->add_provenance_tag(tag); },
ng_inputs);
as_node_vector(ng_inputs));
}

Subgraph::Subgraph(const ONNX_NAMESPACE::GraphProto& proto,
Expand Down
12 changes: 7 additions & 5 deletions ngraph/src/ngraph/frontend/onnx_import/core/graph.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -39,28 +39,30 @@ namespace ngraph
const std::vector<Node>& get_nodes() const { return m_nodes; }
const std::vector<ValueInfo>& get_inputs() const { return m_inputs; }
const std::vector<ValueInfo>& get_outputs() const { return m_outputs; }
NodeVector get_ng_outputs() const;
OutputVector get_ng_outputs() const;
const ParameterVector& get_ng_parameters() const { return m_parameters; }
bool is_node_in_cache(const std::string& name) const;
std::shared_ptr<ngraph::Node> get_ng_node_from_cache(const std::string& name) const;
Output<ngraph::Node> get_ng_node_from_cache(const std::string& name) const;
const std::string& get_name() const { return m_graph_proto->name(); }
NodeVector make_ng_nodes(const Node& onnx_node) const;
OutputVector make_ng_nodes(const Node& onnx_node) const;
const GraphCache& get_graph_cache() const;

protected:
Graph(const ONNX_NAMESPACE::GraphProto& proto,
Model& model,
std::unique_ptr<GraphCache>&& cache);

void set_friendly_names(const Node& onnx_node, const NodeVector& ng_node_vector) const;
void set_friendly_names(const Node& onnx_node,
const OutputVector& ng_node_vector) const;

void add_provenance_tag_to_initializer(
const Tensor& initializer, std::shared_ptr<default_opset::Constant> node) const;

void add_provenance_tag_to_input(const ValueInfo& input,
std::shared_ptr<ngraph::Node> node) const;

void add_provenance_tags(const Node& onnx_node, const NodeVector& ng_node_vector) const;
void add_provenance_tags(const Node& onnx_node,
const OutputVector& ng_node_vector) const;

private:
const ONNX_NAMESPACE::GraphProto* m_graph_proto;
Expand Down
Loading

0 comments on commit f345116

Please sign in to comment.