Skip to content

Commit

Permalink
[TF Hub][TF FE] Fix 5D case for FusedBatchNorm (openvinotoolkit#19904)
Browse files Browse the repository at this point in the history
Signed-off-by: Kazantsev, Roman <[email protected]>
  • Loading branch information
rkazants authored Sep 18, 2023
1 parent df19699 commit d90ceb9
Show file tree
Hide file tree
Showing 3 changed files with 8 additions and 2 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -146,7 +146,7 @@ void compute_fused_batch_norm_inference(const NodeContext& node,
// retrieve attributes
auto epsilon = node.get_attribute<float>("epsilon", 0.0001f);
auto data_format = node.get_attribute<string>("data_format", "NHWC");
bool is_nhwc = (data_format == "NHWC");
bool is_nhwc = (data_format == "NHWC" || data_format == "NDHWC");

// create auxiliary Constant nodes for some attributes: epsilon and exponential_avg_factor
auto eps_const = create_same_type_const_scalar<float>(x, epsilon);
Expand Down
7 changes: 6 additions & 1 deletion tests/layer_tests/tensorflow_tests/test_tf_FusedBatchNorm.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,7 @@ def create_fused_batch_norm_net(self, x_shape, epsilon, exponential_avg_factor,
# Create the graph and model
with tf.compat.v1.Session() as sess:
c_dim = x_shape[-1]
if data_format == "NCHW":
if data_format == "NCHW" or data_format == "NCDHW":
c_dim = x_shape[1]
x = tf.compat.v1.placeholder(tf.float32, x_shape, 'x')
if empty_mean_variance:
Expand Down Expand Up @@ -92,6 +92,11 @@ def create_fused_batch_norm_net(self, x_shape, epsilon, exponential_avg_factor,
fbn_version="v3"),
dict(x_shape=[5, 10, 8, 2], epsilon=0.0002, exponential_avg_factor=0.2, data_format="NHWC",
is_training=True, fbn_version="v3", empty_mean_variance=False),
# 5D cases
dict(x_shape=[5, 4, 3, 2, 3], epsilon=0.0005, exponential_avg_factor=0.0, data_format="NCDHW",
is_training=False, fbn_version="v3"),
dict(x_shape=[3, 4, 3, 3, 2], epsilon=0.0003, exponential_avg_factor=0.0, data_format="NDHWC",
is_training=False, fbn_version="v3"),
]

@pytest.mark.parametrize("params", test_data_basic)
Expand Down
1 change: 1 addition & 0 deletions tests/model_hub_tests/tf_hub_tests/precommit_models
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@ imagenet/efficientnet_v2_imagenet1k_b0/feature_vector,https://tfhub.dev/google/i
imagenet/mobilenet_v1_100_224/classification,https://tfhub.dev/google/imagenet/mobilenet_v1_100_224/classification/5?tf-hub-format=compressed,skip,119718 - Accuracy issue
magenta/arbitrary-image-stylization-v1-256,https://tfhub.dev/google/magenta/arbitrary-image-stylization-v1-256/2?tf-hub-format=compressed
small_bert/bert_en_uncased_L-4_H-256_A-4,https://tfhub.dev/tensorflow/small_bert/bert_en_uncased_L-4_H-256_A-4/2?tf-hub-format=compressed,skip,119718 - Accuracy issue
movinet/a5/base/kinetics-600/classification,https://tfhub.dev/tensorflow/movinet/a5/base/kinetics-600/classification/3?tf-hub-format=compressed
# secure notebook models
unet/industrial/class_1,https://tfhub.dev/nvidia/unet/industrial/class_1/1?tf-hub-format=compressed
movenet/singlepose/thunder,https://tfhub.dev/google/movenet/singlepose/thunder/4?tf-hub-format=compressed
Expand Down

0 comments on commit d90ceb9

Please sign in to comment.