Skip to content

Commit

Permalink
Fix TransposeSinking transformation for Gather op (openvinotoolkit#17540
Browse files Browse the repository at this point in the history
)

* fix TransposeSinking for Gather op

* add test

* fix copyright

* Resolve review comments
  • Loading branch information
itikhono authored May 24, 2023
1 parent 0d3b636 commit fa428a1
Show file tree
Hide file tree
Showing 3 changed files with 94 additions and 16 deletions.
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
// Copyright (C) 2018-2023 Intel Corporationconvert_reduce_to_pooling
// Copyright (C) 2018-2023 Intel Corporation
// SPDX-License-Identifier: Apache-2.0
//

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -169,27 +169,52 @@ TSGatherBackward::TSGatherBackward() {
bool optimization = out_pshape.is_static() && main_node->input_value(1).get_partial_shape().is_static();
bool success = false;
std::vector<size_t> axes_val;
std::shared_ptr<ov::op::v0::Squeeze> squeeze;
// In some cases shape of 2nd input to Gather op (indices) has `1` dims which can
// prevent TransposeSinking in backward direction.
// We can get around this case by wrapping Transpose op with Squeeze+Unsqueeze pair.
/*
* Data_input:shape(257, 8) Indices_input: shape(1, 2)
│ │
└────────────┐ ┌─────────────┘
▼ ▼
Gather(axis = 0)
Gather output: shape(1,2,8)
Transpose
Transpose output: shape(1,8,2)
*/
if (optimization) {
auto squeeze = std::make_shared<ov::op::v0::Squeeze>(main_node->input_value(1));
squeeze = std::make_shared<ov::op::v0::Squeeze>(main_node->input_value(1));
copy_runtime_info(main_node, squeeze);
main_node->input(1).replace_source_output(squeeze);
main_node->validate_and_infer_types();
auto new_out_pshape = main_node->get_output_partial_shape(0);
auto shape = out_pshape.get_shape();
auto new_shape = new_out_pshape.get_shape();
success = !(new_out_pshape.is_dynamic() || shape == new_shape);
if (success) {
size_t j = 0;
for (size_t i = 0; i < shape.size(); ++i) {
if (shape[i] != new_shape[j] && shape[i] == 1) {
axes_val.push_back(i);
continue;
} else if (shape[i] != new_shape[j]) {
if (new_out_pshape.is_static()) {
const auto shape = out_pshape.get_shape();
const auto new_shape = new_out_pshape.get_shape();
success = shape != new_shape;
if (success) {
size_t j = 0;
for (size_t i = 0; i < shape.size(); ++i) {
if (shape[i] != new_shape[j] && shape[i] == 1) {
axes_val.push_back(i);
continue;
} else if (shape[i] != new_shape[j]) {
success = false;
break;
}
j++;
}
if (j != new_shape.size()) {
success = false;
}
j++;
}
if (j != new_shape.size()) {
success = false;
}
}
if (!success) {
Expand All @@ -216,6 +241,9 @@ TSGatherBackward::TSGatherBackward() {
size_t prev_idx = i;
for (size_t k = 0; i < order_val.size() && k < indices_rank_val; ++i, ++k) {
if (order_val[i] != order_val[prev_idx]) {
if (success && squeeze) {
main_node->input(1).replace_source_output(squeeze->input_value(0));
}
return false;
}
prev_idx = i;
Expand All @@ -230,6 +258,11 @@ TSGatherBackward::TSGatherBackward() {
for (const auto& input : target_inputs) {
input.replace_source_output(unsqueeze);
}
unsqueeze->output(0).add_names(main_node->output(0).get_names());
main_node->output(0).set_names({});
unsqueeze->set_friendly_name(main_node->get_friendly_name());
main_node->set_friendly_name("");
copy_runtime_info(main_node, {unsqueeze, unsqueeze_axes});
}
const auto reversed_transpose_order = ReverseTransposeOrder(order_val);
const auto& transpose_const = ov::op::v0::Constant::create(transpose_order->get_element_type(),
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -146,6 +146,51 @@ vector<GatherBackwardArguments> tests_arguments_bw{

INSTANTIATE_TEST_SUITE_P(TSCommonGatherBackward_0, TSTestFixture, test_backward_gather(tests_arguments_bw[0]));

// In some cases shape of 2nd input to Gather op (indices) has `1` dims which can
// prevent TransposeSinking in backward direction.
// We can get around this case by wrapping Transpose op with Squeeze+Unsqueeze pair.
auto test_backward_gather_optimization = [](const GatherBackwardArguments& test_arguments) {
TestCase test_case;

// Initialize common attributes
test_case.transformation = CREATE_PASS_FACTORY(TSGatherBackward);
test_case.num_main_ops = {1};
test_case.inputs_to_main = test_arguments.inputs_to_main;

// Test model description:
test_case.model.main_op = {CREATE_GATHER_FACTORY(Gather)};
test_case.model.preprocess_outputs_of_main = {{set_transpose_for}, {{0}}};
test_case.model.model_template = create_model;

// Reference model description:
auto update_gather_inputs = [&](const vector<size_t>& idxs, const OutputVector& out_vec) -> OutputVector {
OutputVector new_out_vec(out_vec.size());
new_out_vec[0] = out_vec[0];
new_out_vec[1] = make_shared<Squeeze>(out_vec[1]);
new_out_vec[2] = test_arguments.new_input_to_Gather_1;
return new_out_vec;
};

auto unsqueeze_for = [&](const vector<size_t>& idxs, const OutputVector& out_vec) -> OutputVector {
auto axis = constant<int>(i32, {1}, {0});
return {make_shared<Unsqueeze>(out_vec[0], axis)};
};

test_case.model_ref.preprocess_inputs_to_main = {{set_transpose_for, update_gather_inputs}, {{0}, {1, 2}}};
test_case.model_ref.main_op = {CREATE_GATHER_FACTORY(Gather)};
test_case.model_ref.preprocess_outputs_of_main = {{unsqueeze_for}, {{0}}};
test_case.model_ref.model_template = create_model;

return wrapper(test_case);
};

vector<GatherBackwardArguments> tests_arguments_bw_optimization{
{{{parameter(f32, {257, 8}), constant<int>(i32, {1, 2}, {0}), constant<int>(i32, {1}, {0})}},
constant<int>(i32, {1}, {1})}};

INSTANTIATE_TEST_SUITE_P(TSCommonGatherBackwardOptimization_0,
TSTestFixture,
test_backward_gather_optimization(tests_arguments_bw_optimization[0]));
} // namespace gather
} // namespace testing
} // namespace transpose_sinking

0 comments on commit fa428a1

Please sign in to comment.