Skip to content

Commit

Permalink
Error out when list input is provided with Normalization layer.
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 379753763
  • Loading branch information
yashk2810 authored and tensorflower-gardener committed Jun 16, 2021
1 parent 214d110 commit 9df106c
Show file tree
Hide file tree
Showing 2 changed files with 24 additions and 4 deletions.
16 changes: 12 additions & 4 deletions keras/layers/preprocessing/normalization.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,14 +14,16 @@
# ==============================================================================
"""Normalization preprocessing layer."""

import tensorflow.compat.v2 as tf
# pylint: disable=g-classes-have-attributes

import numpy as np
from keras import backend
from keras.engine import base_preprocessing_layer
import numpy as np
import tensorflow.compat.v2 as tf

from tensorflow.python.util.tf_export import keras_export

# pylint: disable=g-classes-have-attributes



@keras_export('keras.layers.experimental.preprocessing.Normalization')
class Normalization(base_preprocessing_layer.PreprocessingLayer):
Expand Down Expand Up @@ -110,6 +112,12 @@ def __init__(self, axis=-1, mean=None, variance=None, **kwargs):
def build(self, input_shape):
super().build(input_shape)

if (isinstance(input_shape, (list, tuple)) and
all(isinstance(shape, tf.TensorShape) for shape in input_shape)):
raise ValueError('Normalization only accepts a single input. If you are '
'passing a python list or tuple as a single input, '
'please convert to a numpy array or `tf.Tensor`.')

input_shape = tf.TensorShape(input_shape).as_list()
if len(input_shape) == 1:
input_shape = input_shape + [1]
Expand Down
12 changes: 12 additions & 0 deletions keras/layers/preprocessing/normalization_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -164,6 +164,18 @@ def test_bad_axis_fail_build(self, axis):
with self.assertRaisesRegex(ValueError, r"in the range"):
layer.build([None, 2, 3])

def test_list_input(self):
with self.assertRaisesRegex(
ValueError, ("Normalization only accepts a single input. If you are "
"passing a python list or tuple as a single input, "
"please convert to a numpy array or `tf.Tensor`.")):
normalization.Normalization()([1, 2, 3])

def test_scalar_input(self):
with self.assertRaisesRegex(ValueError,
"axis.*values must be in the range"):
normalization.Normalization()(1)


@keras_parameterized.run_all_keras_modes(always_skip_v1=True)
class NormalizationAdaptTest(keras_parameterized.TestCase,
Expand Down

0 comments on commit 9df106c

Please sign in to comment.