Skip to content

Commit

Permalink
ScatterUpdate ng op shell revision (openvinotoolkit#7375)
Browse files Browse the repository at this point in the history
* add visitors, type_prop tests, update ngrap op class

* update NGRPH_RTTI for scatter_update

* add proper formatting for error message

* update opset
  • Loading branch information
bszmelcz authored Sep 24, 2021
1 parent f202c45 commit f038fcf
Show file tree
Hide file tree
Showing 5 changed files with 237 additions and 150 deletions.
Original file line number Diff line number Diff line change
@@ -0,0 +1,51 @@
// Copyright (C) 2018-2021 Intel Corporation
// SPDX-License-Identifier: Apache-2.0
//

#include <vector>
#include <ngraph/opsets/opset8.hpp>

#include "shared_test_classes/single_layer/scatter_update.hpp"
#include "common_test_utils/test_constants.hpp"

using namespace LayerTestsDefinitions;
using namespace ngraph::opset8;

namespace {
TEST_P(ScatterUpdateLayerTest, Serialize) {
Serialize();
}

const std::vector<InferenceEngine::Precision> inputPrecisions = {
InferenceEngine::Precision::FP32,
InferenceEngine::Precision::FP16,
InferenceEngine::Precision::I32,
};

const std::vector<InferenceEngine::Precision> idxPrecisions = {
InferenceEngine::Precision::I32,
InferenceEngine::Precision::I64,
};

// map<inputShape, map<indicesShape, axis>>
std::map<std::vector<size_t>, std::map<std::vector<size_t>, std::vector<int>>> axesShapeInShape {
{{10, 16, 12, 15}, {{{2, 4}, {0, 1, 2, 3}}, {{8}, {-1, -2, -3, -4}}}},
{{10, 9, 10, 9, 10}, {{{8}, {-3, -1, 0, 2, 4}}, {{4, 2}, {-2, 2}}}},
};
//indices should not be random value
const std::vector<std::vector<int64_t>> idxValue = {
{0, 2, 4, 6, 1, 3, 5, 7}
};

const auto ScatterUpdateCase = ::testing::Combine(
::testing::ValuesIn(ScatterUpdateLayerTest::combineShapes(axesShapeInShape)),
::testing::ValuesIn(idxValue),
::testing::ValuesIn(inputPrecisions),
::testing::ValuesIn(idxPrecisions),
::testing::Values(CommonTestUtils::DEVICE_CPU)
);

INSTANTIATE_TEST_SUITE_P(smoke_ScatterUpdate, ScatterUpdateLayerTest, ScatterUpdateCase, ScatterUpdateLayerTest::getTestCaseName);

} // namespace

17 changes: 12 additions & 5 deletions ngraph/core/src/op/util/scatter_base.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -59,7 +59,13 @@ void ov::op::util::ScatterBase::validate_and_infer_types() {
this,
data_shape.rank().is_dynamic() || indices_shape.rank().is_dynamic() || updates_shape.rank().is_dynamic() ||
updates_shape.rank().get_length() == indices_shape.rank().get_length() + data_shape.rank().get_length() - 1,
"Updates rank is expected to be indices rank + data rank - 1.");
"Updates rank is expected to be rank(indices) + rank(data) - 1.",
" Got: rank(data) = ",
data_shape.rank().get_length(),
", rank(indices) = ",
indices_shape.rank().get_length(),
", rank(updates) = ",
updates_shape.rank().get_length());

if (data_shape.is_dynamic()) {
set_input_is_relevant_to_shape(0);
Expand All @@ -73,20 +79,21 @@ void ov::op::util::ScatterBase::validate_and_infer_types() {
if (const auto& axis_const_input = get_constant_from_source(input_value(AXIS))) {
bool compatible = true;
int64_t axis = axis_const_input->cast_vector<int64_t>().at(0);
axis = ngraph::normalize_axis(this, axis, data_shape.rank().get_length());
int64_t data_rank = data_shape.rank().get_length();
axis = ngraph::normalize_axis(this, axis, data_rank);

if (indices_shape.rank().is_static() && updates_shape.rank().is_static()) {
for (int64_t i = 0; i < indices_shape.rank().get_length(); ++i) {
int64_t indices_rank = indices_shape.rank().get_length();
for (int64_t i = 0; i < indices_rank; ++i) {
compatible = compatible && updates_shape[axis + i].compatible(indices_shape[i]);
}

int64_t indices_rank = indices_shape.rank().get_length();
// Check [d_0, d_1, ... d_(axis - 1)] updates dimensions
for (int64_t i = 0; i < axis; ++i) {
compatible = compatible && updates_shape[i].compatible(data_shape[i]);
}
// Check [d_(axis + k + 1), ..., d_n] updates dimensions
for (int64_t i = axis + 1; i < data_shape.rank().get_length(); ++i) {
for (int64_t i = axis + 1; i < data_rank; ++i) {
compatible = compatible && updates_shape[indices_rank - 1 + i].compatible(data_shape[i]);
}
}
Expand Down
1 change: 1 addition & 0 deletions ngraph/test/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -339,6 +339,7 @@ set(SRC
visitors/op/roi_pooling.cpp
visitors/op/round.cpp
visitors/op/scatter_elements_update.cpp
visitors/op/scatter_update.cpp
visitors/op/select.cpp
visitors/op/space_to_depth.cpp
visitors/op/selu.cpp
Expand Down
Loading

0 comments on commit f038fcf

Please sign in to comment.