Resolve comments 12736 (openvinotoolkit#12778)
* Comments resolving

* Style and getting rid of asserts

* style

* Update src/common/transformations/src/transformations/smart_reshape/lstm_states_broadcast.cpp
jane-intel authored Aug 29, 2022
1 parent 7601400 commit 79f1e72
Showing 10 changed files with 158 additions and 174 deletions.
@@ -5,25 +5,27 @@
#pragma once

#include <memory>
#include <ngraph/pass/graph_rewrite.hpp>
#include <vector>

namespace ngraph {
#include "openvino/pass/graph_rewrite.hpp"
#include "transformations_visibility.hpp"

namespace ov {
namespace pass {

class NGRAPH_API LSTMStatesBroadcast;
class TRANSFORMATIONS_API LSTMStatesBroadcast;

} // namespace pass
} // namespace ngraph
} // namespace ov

/**
* @ingroup ie_transformation_common_api
* @brief In case LSTMCell has constant initial hidden and cell state with single batch size
* we make them broadcast-able by batch
*/

class ngraph::pass::LSTMStatesBroadcast : public ngraph::pass::FunctionPass {
class ov::pass::LSTMStatesBroadcast : public ov::pass::ModelPass {
public:
OPENVINO_RTTI("LSTMStatesBroadcast", "0");
bool run_on_model(const std::shared_ptr<ngraph::Function>& m) override;
bool run_on_model(const std::shared_ptr<ov::Model>& m) override;
};
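A minimal usage sketch of the renamed pass (assumed for illustration, not part of this commit; the helper name relax_toy_lstm_model is hypothetical): a toy LSTMCell model whose initial states are batch-1 constants is run through ov::pass::Manager, mirroring the MOCTransformations registration further down.

#include <memory>

#include "openvino/core/model.hpp"
#include "openvino/opsets/opset9.hpp"
#include "openvino/pass/manager.hpp"
#include "transformations/smart_reshape/lstm_states_broadcast.hpp"

// Sketch only: build a single LSTMCell whose initial hidden/cell states are
// constants with batch == 1, then run the pass so they become broadcast-able by batch.
std::shared_ptr<ov::Model> relax_toy_lstm_model() {
    using namespace ov::opset9;

    auto X = std::make_shared<Parameter>(ov::element::f32, ov::PartialShape{-1, 16});  // dynamic batch
    auto H = Constant::create(ov::element::f32, ov::Shape{1, 32}, {0.f});              // batch-1 hidden state
    auto C = Constant::create(ov::element::f32, ov::Shape{1, 32}, {0.f});              // batch-1 cell state
    auto W = Constant::create(ov::element::f32, ov::Shape{4 * 32, 16}, {0.f});
    auto R = Constant::create(ov::element::f32, ov::Shape{4 * 32, 32}, {0.f});
    auto cell = std::make_shared<LSTMCell>(X, H, C, W, R, 32);
    auto model = std::make_shared<ov::Model>(ov::OutputVector{cell->output(0)}, ov::ParameterVector{X});

    ov::pass::Manager manager;
    manager.register_pass<ov::pass::LSTMStatesBroadcast>();
    manager.run_passes(model);
    // The state inputs of the LSTMCell are now expected to be fed by Broadcast
    // nodes whose target batch is taken from ShapeOf(X).
    return model;
}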
@@ -5,16 +5,18 @@
#pragma once

#include <memory>
#include <ngraph/pass/graph_rewrite.hpp>
#include <vector>

namespace ngraph {
#include "openvino/pass/graph_rewrite.hpp"
#include "transformations_visibility.hpp"

namespace ov {
namespace pass {

class NGRAPH_API ReshapeSinkingMatMul;
class TRANSFORMATIONS_API ReshapeSinkingMatMul;

} // namespace pass
} // namespace ngraph
} // namespace ov

/**
* @ingroup ie_transformation_common_api
@@ -24,7 +26,7 @@ class NGRAPH_API ReshapeSinkingMatMul;
* Reshape operators to make batch propagate through freely
*/

class ngraph::pass::ReshapeSinkingMatMul : public ngraph::pass::MatcherPass {
class ov::pass::ReshapeSinkingMatMul : public ov::pass::MatcherPass {
public:
OPENVINO_RTTI("ReshapeSinkingMatMul", "0");
ReshapeSinkingMatMul();
@@ -118,10 +118,10 @@ bool ngraph::pass::MOCTransformations::run_on_model(const std::shared_ptr<ngraph
manager.register_pass<ngraph::pass::Validate>();

if (!m_use_shapes) { // Approved Smart Reshape
manager.register_pass<ngraph::pass::LSTMStatesBroadcast>();
manager.register_pass<ngraph::pass::Validate>();
manager.register_pass<ngraph::pass::ReshapeSinkingMatMul>();
manager.register_pass<ngraph::pass::Validate>();
manager.register_pass<ov::pass::LSTMStatesBroadcast>();
manager.register_pass<ov::pass::Validate>();
manager.register_pass<ov::pass::ReshapeSinkingMatMul>();
manager.register_pass<ov::pass::Validate>();
}

manager.register_pass<ngraph::pass::ConvertQuantizeDequantize>();
@@ -2,26 +2,23 @@
// SPDX-License-Identifier: Apache-2.0
//

#include "transformations/smart_reshape/lstm_states_broadcast.hpp"

#include <memory>
#include <ngraph/pass/manager.hpp>
#include <openvino/op/util/sub_graph_base.hpp>
#include <openvino/opsets/opset9.hpp>
#include <transformations/smart_reshape/lstm_states_broadcast.hpp>
#include <transformations/utils/utils.hpp>

#include "dimension_tracker.hpp"
#include "itt.hpp"
#include "openvino/op/util/sub_graph_base.hpp"
#include "openvino/opsets/opset9.hpp"
#include "openvino/pass/manager.hpp"
#include "transformations/utils/utils.hpp"

using namespace std;
using namespace ov::opset9;

ov::Input<ov::Node> get_outer_input_of_ti_by_parameter(const std::shared_ptr<ov::opset9::Parameter>& parameter,
const std::shared_ptr<ov::opset9::TensorIterator>& ti) {
const auto& body = ti->get_body();
OPENVINO_ASSERT(body != nullptr, "TI returns invalid body graph ", ti);
ov::Input<ov::Node> get_outer_input_of_ti_by_parameter(const shared_ptr<Parameter>& parameter,
const shared_ptr<TensorIterator>& ti) {
int64_t parameter_index = ti->get_body()->get_parameter_index(parameter);
OPENVINO_ASSERT(parameter_index >= 0,
"LSTMStatesBroadcast encountered unregistered parameter ",
parameter,
" related to TI body ",
ti);
for (const auto& input_descriptor : ti->get_input_descriptions())
if (input_descriptor->m_body_parameter_index == parameter_index)
return ti->input(input_descriptor->m_input_index);
@@ -31,13 +28,11 @@ ov::Input<ov::Node> get_outer_input_of_ti_by_parameter(const std::shared_ptr<ov:
parameter);
}

std::shared_ptr<ov::Node> deduce_outer_source_of_batch_for_inner_lstm_cell(
const std::shared_ptr<ov::opset9::TensorIterator>& ti,
const std::shared_ptr<ov::opset9::LSTMCell>& lstm_cell) {
const auto& body = ti->get_body();
OPENVINO_ASSERT(body != nullptr, "TI returns invalid body graph ", ti);
shared_ptr<ov::Node> deduce_outer_source_of_batch_for_inner_lstm_cell(const shared_ptr<TensorIterator>& ti,
const shared_ptr<LSTMCell>& lstm_cell) {
const auto& body = ti->get_body(); // body is not nullptr -- we checked earlier

std::map<ov::opset9::Parameter*, ov::PartialShape> original_shapes;
map<Parameter*, ov::PartialShape> original_shapes;
size_t label = 1;

// mark all input dimensions with labels and making them dynamic, keeping original shapes
@@ -46,9 +41,7 @@ std::shared_ptr<ov::Node> deduce_outer_source_of_batch_for_inner_lstm_cell(
original_shapes[parameter.get()] = pshape;
if (pshape.rank().is_dynamic())
continue;
for (ngraph::Dimension& n : pshape) {
OPENVINO_ASSERT(ov::DimensionTracker::get_label(n) == 0,
"LSTMStatesBroadcast encountered TI with previously tracked dimensions");
for (ov::Dimension& n : pshape) {
n = ov::Dimension::dynamic();
ov::DimensionTracker::set_label(n, label++);
}
@@ -68,7 +61,7 @@ std::shared_ptr<ov::Node> deduce_outer_source_of_batch_for_inner_lstm_cell(
}

// batch label was tracked -- finding parameter that delivered it
std::shared_ptr<ov::opset9::Parameter> batch_delivering_parameter;
shared_ptr<Parameter> batch_delivering_parameter;
size_t index_of_batch_dim = 0;

size_t batch_label = ov::DimensionTracker::get_label(lstm_cell->get_input_partial_shape(0)[0]);
@@ -80,8 +73,11 @@ std::shared_ptr<ov::Node> deduce_outer_source_of_batch_for_inner_lstm_cell(
if (ov::DimensionTracker::get_label(pshape[i]) == batch_label) {
batch_delivering_parameter = parameter;
index_of_batch_dim = i;
break;
}
}
if (index_of_batch_dim != 0 && batch_delivering_parameter != nullptr)
break;
}
for (auto& item : original_shapes)
item.first->set_partial_shape(item.second);
@@ -91,89 +87,83 @@ std::shared_ptr<ov::Node> deduce_outer_source_of_batch_for_inner_lstm_cell(
return nullptr;

const auto& batched_source = get_outer_input_of_ti_by_parameter(batch_delivering_parameter, ti);
const auto& batched_shape = std::make_shared<ov::opset9::ShapeOf>(batched_source.get_source_output());
const auto& batch = std::make_shared<ov::opset9::Gather>(
batched_shape,
ov::opset9::Constant::create(ov::element::i64, ov::Shape{1}, {index_of_batch_dim}),
ov::opset9::Constant::create(ov::element::i64, ov::Shape{}, {0}));
const auto& batched_shape = make_shared<ShapeOf>(batched_source.get_source_output());
const auto& batch = make_shared<Gather>(batched_shape,
Constant::create(ov::element::i64, ov::Shape{1}, {index_of_batch_dim}),
Constant::create(ov::element::i64, ov::Shape{}, {0}));
return batch;
}

bool broadcast_state_by_batch(ov::Input<ov::Node> input, const std::shared_ptr<ov::Node>& batch_delivering_node) {
auto constant_state =
std::dynamic_pointer_cast<ov::opset9::Constant>(input.get_source_output().get_node_shared_ptr());
bool broadcast_state_by_batch(ov::Input<ov::Node> input, const shared_ptr<ov::Node>& batch_delivering_node) {
auto constant_state = dynamic_pointer_cast<Constant>(input.get_source_output().get_node_shared_ptr());
if (constant_state == nullptr)
return false;
const auto& constant_shape = constant_state->get_shape();
OPENVINO_ASSERT(constant_shape.size() == 2, "State has unexpected shape ", constant_shape);
if (constant_shape[0] != 1)
// we only expect to broadcast LSTM states prepared for batch 1 -- no tiling of batch > 1 will be done
return false;

const auto& constant_copy = constant_state->copy_with_new_inputs({});
const auto& broadcast_by_batch = std::make_shared<ov::opset9::Broadcast>(
const auto& broadcast_by_batch = make_shared<Broadcast>(
constant_copy,
std::make_shared<ov::opset9::Concat>(
ngraph::NodeVector{batch_delivering_node,
ngraph::op::util::make_try_fold<ov::opset9::Gather>(
ngraph::op::util::make_try_fold<ov::opset9::ShapeOf>(constant_copy),
ov::opset9::Constant::create(ov::element::i64, ov::Shape{1}, {1}),
ov::opset9::Constant::create(ov::element::i64, ov::Shape{}, {0}))},
0));
make_shared<Concat>(ngraph::NodeVector{batch_delivering_node,
ngraph::op::util::make_try_fold<Gather>(
ngraph::op::util::make_try_fold<ShapeOf>(constant_copy),
Constant::create(ov::element::i64, ov::Shape{1}, {1}),
Constant::create(ov::element::i64, ov::Shape{}, {0}))},
0));
input.replace_source_output(broadcast_by_batch->output(0));
return true;
}

bool relax_batch_for_initial_states_of_lstm_in_ti(const std::shared_ptr<ov::opset9::TensorIterator>& ti,
const std::shared_ptr<ov::opset9::LSTMCell>& lstm_cell) {
bool relax_batch_for_initial_states_of_lstm_in_ti(const shared_ptr<TensorIterator>& ti,
const shared_ptr<LSTMCell>& lstm_cell) {
bool rewritten = false;
auto batch_delivering_node = deduce_outer_source_of_batch_for_inner_lstm_cell(ti, lstm_cell);
if (batch_delivering_node == nullptr)
return rewritten;
if (auto init_hidden_state =
std::dynamic_pointer_cast<ov::opset9::Parameter>(lstm_cell->get_input_node_shared_ptr(1))) {
if (auto init_hidden_state = dynamic_pointer_cast<Parameter>(lstm_cell->get_input_node_shared_ptr(1))) {
auto outer_init_hidden_state_input = get_outer_input_of_ti_by_parameter(init_hidden_state, ti);
rewritten |= broadcast_state_by_batch(outer_init_hidden_state_input, batch_delivering_node);
}
if (auto init_cell_state =
std::dynamic_pointer_cast<ov::opset9::Parameter>(lstm_cell->get_input_node_shared_ptr(2))) {
if (auto init_cell_state = dynamic_pointer_cast<Parameter>(lstm_cell->get_input_node_shared_ptr(2))) {
auto outer_init_cell_state_input = get_outer_input_of_ti_by_parameter(init_cell_state, ti);
rewritten |= broadcast_state_by_batch(outer_init_cell_state_input, batch_delivering_node);
}
return rewritten;
}

bool relax_batch_for_initial_states_of_lstm(const std::shared_ptr<ov::opset9::LSTMCell>& lstm_cell) {
bool relax_batch_for_initial_states_of_lstm(const shared_ptr<LSTMCell>& lstm_cell) {
bool rewritten = false;
const auto& batched_shape = std::make_shared<ov::opset9::ShapeOf>(lstm_cell->get_input_source_output(0));
const auto& batch_delivering_node =
std::make_shared<ov::opset9::Gather>(batched_shape,
ov::opset9::Constant::create(ov::element::i64, ov::Shape{1}, {0}),
ov::opset9::Constant::create(ov::element::i64, ov::Shape{}, {0}));
const auto& batched_shape = make_shared<ShapeOf>(lstm_cell->get_input_source_output(0));
const auto& batch_delivering_node = make_shared<Gather>(batched_shape,
Constant::create(ov::element::i64, ov::Shape{1}, {0}),
Constant::create(ov::element::i64, ov::Shape{}, {0}));
rewritten |= broadcast_state_by_batch(lstm_cell->input(1), batch_delivering_node);
rewritten |= broadcast_state_by_batch(lstm_cell->input(2), batch_delivering_node);
return rewritten;
}

bool ngraph::pass::LSTMStatesBroadcast::run_on_model(const std::shared_ptr<ngraph::Function>& f) {
bool ov::pass::LSTMStatesBroadcast::run_on_model(const shared_ptr<ov::Model>& f) {
RUN_ON_FUNCTION_SCOPE(LSTMStatesBroadcast);
bool rewritten = false;
for (auto& node : f->get_ordered_ops()) {
// Recursively apply transformation for sub-graph based operations
if (const auto& sub_graph_node = std::dynamic_pointer_cast<ov::op::util::SubGraphOp>(node))
if (const auto& sub_graph_node = dynamic_pointer_cast<ov::op::util::SubGraphOp>(node))
if (const auto& sub_graph = sub_graph_node->get_function())
rewritten |= run_on_model(sub_graph);

// Case without TI (LSTMCell and Constant are in the same ov::Model)
if (const auto& lstm_cell = std::dynamic_pointer_cast<ov::opset9::LSTMCell>(node))
if (const auto& lstm_cell = dynamic_pointer_cast<LSTMCell>(node))
rewritten |= relax_batch_for_initial_states_of_lstm(lstm_cell);

// Case with TI (LSTMCell and Constant are in different ov::Model objects)
if (auto ti = std::dynamic_pointer_cast<ov::opset9::TensorIterator>(node)) {
if (auto ti = dynamic_pointer_cast<TensorIterator>(node)) {
auto body = ti->get_body();
OPENVINO_ASSERT(body, "TensorIterator must have body network");
if (body == nullptr)
continue;
for (const auto& body_node : body->get_ordered_ops())
if (const auto& lstm_cell = std::dynamic_pointer_cast<ov::opset9::LSTMCell>(body_node))
if (const auto& lstm_cell = dynamic_pointer_cast<LSTMCell>(body_node))
rewritten |= relax_batch_for_initial_states_of_lstm_in_ti(ti, lstm_cell);
}
}
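The batch deduction in deduce_outer_source_of_batch_for_inner_lstm_cell relies on dimension labels surviving shape re-inference inside the TI body. A minimal sketch of that mechanism, assumed for illustration rather than taken from this commit (the helper name trace_batch_label is hypothetical), using the same dimension_tracker.hpp helpers the file includes:

#include <memory>

#include "dimension_tracker.hpp"
#include "openvino/core/model.hpp"
#include "openvino/opsets/opset9.hpp"

// Sketch only: label the batch dimension of a Parameter, re-run shape inference,
// then read the label back from a downstream shape to find where the batch came from.
size_t trace_batch_label() {
    using namespace ov::opset9;

    auto data = std::make_shared<Parameter>(ov::element::f32, ov::PartialShape{1, 16});
    auto relu = std::make_shared<Relu>(data);
    auto model = std::make_shared<ov::Model>(ov::OutputVector{relu->output(0)}, ov::ParameterVector{data});

    // As the pass does for every TI body parameter: keep the original shape aside,
    // make the dimension dynamic and attach a tracking label to it.
    ov::PartialShape labelled = data->get_partial_shape();
    labelled[0] = ov::Dimension::dynamic();
    ov::DimensionTracker::set_label(labelled[0], 1);
    data->set_partial_shape(labelled);
    model->validate_nodes_and_infer_types();

    // For shape-preserving ops the label is expected to propagate, so this
    // should return 1, identifying `data` as the parameter that delivers the batch.
    return ov::DimensionTracker::get_label(relu->get_output_partial_shape(0)[0]);
}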