check the trainable type for tf.keras.layers.Layer.__init__ to be boolean.

PiperOrigin-RevId: 387021364
haifeng-jin authored and tensorflower-gardener committed Jul 27, 2021
1 parent d3688b7 commit b877289
Showing 4 changed files with 111 additions and 75 deletions.
5 changes: 5 additions & 0 deletions keras/engine/base_layer.py
@@ -330,6 +330,11 @@ def __init__(self,
# Mutable properties
# Indicates whether the layer's weights are updated during training
# and whether the layer's updates are run during training.
if not (isinstance(trainable, bool) or
(isinstance(trainable, (tf.Tensor, tf.Variable)) and
trainable.dtype is tf.bool)):
raise TypeError(
f'Expected trainable argument to be a boolean, but got: {trainable}')
self._trainable = trainable
# A stateful layer is a layer whose updates are run during inference too,
# for instance stateful RNNs.
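A quick sketch of the behavior this hunk adds (not part of the diff; assumes a Keras build that includes this change). Python bools and bool-dtype tensors or variables pass the check; anything else now fails at construction time:

import tensorflow as tf

# Accepted: plain bools and bool-dtype tensors/variables.
tf.keras.layers.Layer(trainable=True)
tf.keras.layers.Layer(trainable=tf.constant(True))
tf.keras.layers.Layer(trainable=tf.Variable(tf.constant(True)))

# Rejected: an int is neither a bool nor a bool-dtype tensor.
try:
  tf.keras.layers.Layer(trainable=0)
except TypeError as e:
  print(e)  # Expected trainable argument to be a boolean, but got: 0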
10 changes: 9 additions & 1 deletion keras/engine/base_layer_test.py
@@ -466,6 +466,14 @@ def __init__(self):
self.assertEqual(self.evaluate(layer.weights[0]), 1.)
self.assertEqual(self.evaluate(layer.weights[1]), 2.)

def test_exception_if_trainable_not_boolean(self):
base_layer.Layer(trainable=True)
base_layer.Layer(trainable=tf.constant(True))
base_layer.Layer(trainable=tf.Variable(tf.constant(True)))
with self.assertRaisesRegex(TypeError,
'Expected trainable argument to be a boolean'):
base_layer.Layer(trainable=0)

@combinations.generate(combinations.combine(mode=['graph', 'eager']))
def test_layer_names(self):
inputs = input_layer.Input(shape=[2])
@@ -916,7 +924,7 @@ def __call__(self, x):
class MyLayer(base_layer.Layer):

def __init__(self, **kwargs):
super(MyLayer, self).__init__(self, **kwargs)
super(MyLayer, self).__init__(**kwargs)
self.my_modules = {}
self.my_modules['a'] = MyModule()

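The one-line fix to MyLayer.__init__ above is closely related: Layer.__init__ takes trainable as its first positional parameter, so the old call super(MyLayer, self).__init__(self, **kwargs) passed the layer instance itself as trainable, which is exactly the kind of mistake the new check surfaces. A minimal sketch of that failure mode (class names are illustrative, not from the diff):

import tensorflow as tf

class BrokenLayer(tf.keras.layers.Layer):

  def __init__(self, **kwargs):
    # The stray `self` binds to `trainable` and now trips the new TypeError.
    super(BrokenLayer, self).__init__(self, **kwargs)

class FixedLayer(tf.keras.layers.Layer):

  def __init__(self, **kwargs):
    super(FixedLayer, self).__init__(**kwargs)

try:
  BrokenLayer()
except TypeError as e:
  print(e)  # Expected trainable argument to be a boolean, but got: <BrokenLayer ...>

FixedLayer()  # constructs cleanly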
165 changes: 95 additions & 70 deletions keras/feature_column/dense_features_test.py
@@ -48,9 +48,7 @@ def test_retrieving_input(self):
@combinations.generate(combinations.combine(mode=['eager']))
def test_reuses_variables(self):
sparse_input = tf.SparseTensor(
indices=((0, 0), (1, 0), (2, 0)),
values=(0, 1, 2),
dense_shape=(3, 3))
indices=((0, 0), (1, 0), (2, 0)), values=(0, 1, 2), dense_shape=(3, 3))

# Create feature columns (categorical and embedding).
categorical_column = tf.feature_column.categorical_column_with_identity(
@@ -122,8 +120,7 @@ def _embedding_column_initializer(shape, dtype, partition_info=None):
initializer=_embedding_column_initializer)

dense_features = df.DenseFeatures(
[embedding_column],
partitioner=tf.compat.v1.fixed_size_partitioner(2))
[embedding_column], partitioner=tf.compat.v1.fixed_size_partitioner(2))
features = {'a': sparse_input}

inputs = dense_features(features)
@@ -145,9 +142,7 @@ def _embedding_column_initializer(shape, dtype, partition_info=None):
@combinations.generate(combinations.combine(mode=['eager']))
def test_feature_column_dense_features_gradient(self):
sparse_input = tf.SparseTensor(
indices=((0, 0), (1, 0), (2, 0)),
values=(0, 1, 2),
dense_shape=(3, 3))
indices=((0, 0), (1, 0), (2, 0)), values=(0, 1, 2), dense_shape=(3, 3))

# Create feature columns (categorical and embedding).
categorical_column = tf.feature_column.categorical_column_with_identity(
@@ -205,10 +200,11 @@ def test_should_be_dense_column(self):
def test_does_not_support_dict_columns(self):
with self.assertRaisesRegex(
ValueError, 'Expected feature_columns to be iterable, found dict.'):
df.DenseFeatures(feature_columns={'a': tf.feature_column.numeric_column('a')})(
features={
'a': [[0]]
})
df.DenseFeatures(
feature_columns={'a': tf.feature_column.numeric_column('a')})(
features={
'a': [[0]]
})

def test_bare_column(self):
with tf.Graph().as_default():
@@ -234,12 +230,13 @@ def test_column_generator(self):
def test_raises_if_duplicate_name(self):
with self.assertRaisesRegex(
ValueError, 'Duplicate feature column name found for columns'):
df.DenseFeatures(
feature_columns=[tf.feature_column.numeric_column('a'),
tf.feature_column.numeric_column('a')])(
features={
'a': [[0]]
})
df.DenseFeatures(feature_columns=[
tf.feature_column.numeric_column('a'),
tf.feature_column.numeric_column('a')
])(
features={
'a': [[0]]
})

def test_one_column(self):
price = tf.feature_column.numeric_column('price')
@@ -348,7 +345,8 @@ def test_column_order(self):
self.assertAllClose([[1., 3.]], self.evaluate(net2))

def test_fails_for_categorical_column(self):
animal = tf.feature_column.categorical_column_with_identity('animal', num_buckets=4)
animal = tf.feature_column.categorical_column_with_identity(
'animal', num_buckets=4)
with tf.Graph().as_default():
features = {
'animal':
@@ -431,15 +429,19 @@ def test_multiple_layers_with_same_embedding_column(self):
df.DenseFeatures(all_cols)(features)
df.DenseFeatures(all_cols)(features)
# Make sure that 2 variables get created in this case.
self.assertEqual(2,
len(tf.compat.v1.get_collection(tf.compat.v1.GraphKeys.GLOBAL_VARIABLES)))
self.assertEqual(
2,
len(
tf.compat.v1.get_collection(
tf.compat.v1.GraphKeys.GLOBAL_VARIABLES)))
expected_var_names = [
'dense_features/sparse_feature_embedding/embedding_weights:0',
'dense_features_1/sparse_feature_embedding/embedding_weights:0'
]
self.assertItemsEqual(
expected_var_names,
[v.name for v in tf.compat.v1.get_collection(tf.compat.v1.GraphKeys.GLOBAL_VARIABLES)])
self.assertCountEqual(expected_var_names, [
v.name for v in tf.compat.v1.get_collection(
tf.compat.v1.GraphKeys.GLOBAL_VARIABLES)
])

@test_util.run_deprecated_v1
def test_multiple_layers_with_same_shared_embedding_column(self):
@@ -469,11 +471,15 @@ def test_multiple_layers_with_same_shared_embedding_column(self):
df.DenseFeatures(all_cols)(features)
df.DenseFeatures(all_cols)(features)
# Make sure that only 1 variable gets created in this case.
self.assertEqual(1,
len(tf.compat.v1.get_collection(tf.compat.v1.GraphKeys.GLOBAL_VARIABLES)))
self.assertItemsEqual(
['aaa_bbb_shared_embedding:0'],
[v.name for v in tf.compat.v1.get_collection(tf.compat.v1.GraphKeys.GLOBAL_VARIABLES)])
self.assertEqual(
1,
len(
tf.compat.v1.get_collection(
tf.compat.v1.GraphKeys.GLOBAL_VARIABLES)))
self.assertCountEqual(['aaa_bbb_shared_embedding:0'], [
v.name for v in tf.compat.v1.get_collection(
tf.compat.v1.GraphKeys.GLOBAL_VARIABLES)
])

@test_util.run_deprecated_v1
def test_multiple_layers_with_same_shared_embedding_column_diff_graphs(self):
@@ -502,8 +508,11 @@ def test_multiple_layers_with_same_shared_embedding_column_diff_graphs(self):
}
df.DenseFeatures(all_cols)(features)
# Make sure that only 1 variable gets created in this case.
self.assertEqual(1,
len(tf.compat.v1.get_collection(tf.compat.v1.GraphKeys.GLOBAL_VARIABLES)))
self.assertEqual(
1,
len(
tf.compat.v1.get_collection(
tf.compat.v1.GraphKeys.GLOBAL_VARIABLES)))

with tf.Graph().as_default():
features1 = {
@@ -521,11 +530,15 @@ def test_multiple_layers_with_same_shared_embedding_column_diff_graphs(self):

df.DenseFeatures(all_cols)(features1)
# Make sure that only 1 variable gets created in this case.
self.assertEqual(1,
len(tf.compat.v1.get_collection(tf.compat.v1.GraphKeys.GLOBAL_VARIABLES)))
self.assertItemsEqual(
['aaa_bbb_shared_embedding:0'],
[v.name for v in tf.compat.v1.get_collection(tf.compat.v1.GraphKeys.GLOBAL_VARIABLES)])
self.assertEqual(
1,
len(
tf.compat.v1.get_collection(
tf.compat.v1.GraphKeys.GLOBAL_VARIABLES)))
self.assertCountEqual(['aaa_bbb_shared_embedding:0'], [
v.name for v in tf.compat.v1.get_collection(
tf.compat.v1.GraphKeys.GLOBAL_VARIABLES)
])

@test_util.run_deprecated_v1
def test_with_1d_sparse_tensor(self):
@@ -672,7 +685,8 @@ class IndicatorColumnTest(tf.test.TestCase):
@test_util.run_deprecated_v1
def test_dense_features(self):
animal = tf.feature_column.indicator_column(
tf.feature_column.categorical_column_with_identity('animal', num_buckets=4))
tf.feature_column.categorical_column_with_identity(
'animal', num_buckets=4))
with tf.Graph().as_default():
features = {
'animal':
@@ -771,7 +785,8 @@ def _initializer(shape, dtype, partition_info=None):
dense_features = l({'aaa': sparse_input})

# Assert expected embedding variable and lookups.
global_vars = tf.compat.v1.get_collection(tf.compat.v1.GraphKeys.GLOBAL_VARIABLES)
global_vars = tf.compat.v1.get_collection(
tf.compat.v1.GraphKeys.GLOBAL_VARIABLES)
if partition_variables:
self.assertCountEqual(
('vars/dense_features/aaa_embedding/embedding_weights/part_0:0',
@@ -783,7 +798,8 @@ def _initializer(shape, dtype, partition_info=None):
tuple([v.name for v in global_vars]))
for v in global_vars:
self.assertIsInstance(v, tf.Variable)
trainable_vars = tf.compat.v1.get_collection(tf.compat.v1.GraphKeys.TRAINABLE_VARIABLES)
trainable_vars = tf.compat.v1.get_collection(
tf.compat.v1.GraphKeys.TRAINABLE_VARIABLES)
if partition_variables:
self.assertCountEqual(
('vars/dense_features/aaa_embedding/embedding_weights/part_0:0',
@@ -801,8 +817,9 @@ def _initializer(shape, dtype, partition_info=None):
self.assertAllEqual(expected_lookups, self.evaluate(dense_features))

if use_safe_embedding_lookup:
self.assertIn('SparseFillEmptyRows',
[x.type for x in tf.compat.v1.get_default_graph().get_operations()])
self.assertIn(
'SparseFillEmptyRows',
[x.type for x in tf.compat.v1.get_default_graph().get_operations()])
else:
self.assertNotIn(
'SparseFillEmptyRows',
@@ -862,11 +879,13 @@ def _initializer(shape, dtype, partition_info=None):
})

# Assert expected embedding variable and lookups.
global_vars = tf.compat.v1.get_collection(tf.compat.v1.GraphKeys.GLOBAL_VARIABLES)
global_vars = tf.compat.v1.get_collection(
tf.compat.v1.GraphKeys.GLOBAL_VARIABLES)
self.assertCountEqual(('dense_features/aaa_embedding/embedding_weights:0',),
tuple([v.name for v in global_vars]))
self.assertCountEqual([],
tf.compat.v1.get_collection(tf.compat.v1.GraphKeys.TRAINABLE_VARIABLES))
tf.compat.v1.get_collection(
tf.compat.v1.GraphKeys.TRAINABLE_VARIABLES))

self.evaluate(tf.compat.v1.global_variables_initializer())
self.evaluate(tf.compat.v1.tables_initializer())
@@ -970,13 +989,15 @@ def _initializer(shape, dtype, partition_info=None):
features)

# Assert expected embedding variable and lookups.
global_vars = tf.compat.v1.get_collection(tf.compat.v1.GraphKeys.GLOBAL_VARIABLES)
global_vars = tf.compat.v1.get_collection(
tf.compat.v1.GraphKeys.GLOBAL_VARIABLES)
self.assertCountEqual(
['aaa_bbb_shared_embedding:0', 'ccc_ddd_shared_embedding:0'],
tuple([v.name for v in global_vars]))
for v in global_vars:
self.assertIsInstance(v, tf.Variable)
trainable_vars = tf.compat.v1.get_collection(tf.compat.v1.GraphKeys.TRAINABLE_VARIABLES)
trainable_vars = tf.compat.v1.get_collection(
tf.compat.v1.GraphKeys.TRAINABLE_VARIABLES)
if trainable:
self.assertCountEqual(
['aaa_bbb_shared_embedding:0', 'ccc_ddd_shared_embedding:0'],
@@ -1004,39 +1025,42 @@ def test_dense_features_no_trainable(self):
@combinations.generate(combinations.combine(mode=['graph', 'eager']))
class DenseFeaturesSerializationTest(tf.test.TestCase, parameterized.TestCase):

@parameterized.named_parameters(
('default', None, None),
('trainable', True, 'trainable'),
('not_trainable', False, 'frozen'))
@parameterized.named_parameters(('trainable', True, 'trainable'),
('not_trainable', False, 'frozen'))
def test_get_config(self, trainable, name):
cols = [tf.feature_column.numeric_column('a'),
tf.feature_column.embedding_column(tf.feature_column.categorical_column_with_identity(
key='b', num_buckets=3), dimension=2)]
orig_layer = df.DenseFeatures(
cols, trainable=trainable, name=name)
cols = [
tf.feature_column.numeric_column('a'),
tf.feature_column.embedding_column(
tf.feature_column.categorical_column_with_identity(
key='b', num_buckets=3),
dimension=2)
]
orig_layer = df.DenseFeatures(cols, trainable=trainable, name=name)
config = orig_layer.get_config()

self.assertEqual(config['name'], orig_layer.name)
self.assertEqual(config['trainable'], trainable)
self.assertLen(config['feature_columns'], 2)
self.assertEqual(
config['feature_columns'][0]['class_name'], 'NumericColumn')
self.assertEqual(config['feature_columns'][0]['class_name'],
'NumericColumn')
self.assertEqual(config['feature_columns'][0]['config']['shape'], (1,))
self.assertEqual(
config['feature_columns'][1]['class_name'], 'EmbeddingColumn')
self.assertEqual(config['feature_columns'][1]['class_name'],
'EmbeddingColumn')

@parameterized.named_parameters(
('default', None, None),
('trainable', True, 'trainable'),
('not_trainable', False, 'frozen'))
@parameterized.named_parameters(('trainable', True, 'trainable'),
('not_trainable', False, 'frozen'))
def test_from_config(self, trainable, name):
cols = [tf.feature_column.numeric_column('a'),
tf.feature_column.embedding_column(tf.feature_column.categorical_column_with_vocabulary_list(
'b', vocabulary_list=['1', '2', '3']), dimension=2),
tf.feature_column.indicator_column(tf.feature_column.categorical_column_with_hash_bucket(
key='c', hash_bucket_size=3))]
orig_layer = df.DenseFeatures(
cols, trainable=trainable, name=name)
cols = [
tf.feature_column.numeric_column('a'),
tf.feature_column.embedding_column(
tf.feature_column.categorical_column_with_vocabulary_list(
'b', vocabulary_list=['1', '2', '3']),
dimension=2),
tf.feature_column.indicator_column(
tf.feature_column.categorical_column_with_hash_bucket(
key='c', hash_bucket_size=3))
]
orig_layer = df.DenseFeatures(cols, trainable=trainable, name=name)
config = orig_layer.get_config()

new_layer = df.DenseFeatures.from_config(config)
@@ -1106,7 +1130,8 @@ def test_indicator_column(self):

categorical_column_a = tf.feature_column.sequence_categorical_column_with_identity(
key='aaa', num_buckets=vocabulary_size)
indicator_column_a = tf.feature_column.indicator_column(categorical_column_a)
indicator_column_a = tf.feature_column.indicator_column(
categorical_column_a)

input_layer = df.DenseFeatures([indicator_column_a])
with self.assertRaisesRegex(
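Most of the dense_features_test.py hunk above is mechanical: wrapping long calls to the 80-column style and renaming assertItemsEqual, the old Python 2 alias, to its Python 3 name assertCountEqual. The substantive change is dropping the ('default', None, None) case from the parameterized serialization tests, here and in the sequence_feature_column_test.py hunks below, presumably because DenseFeatures forwards trainable to Layer.__init__ and trainable=None no longer passes the new boolean check. A rough sketch under that assumption, using the public tf.keras.layers.DenseFeatures alias rather than the test's internal df import:

import tensorflow as tf

cols = [tf.feature_column.numeric_column('a')]
tf.keras.layers.DenseFeatures(cols, trainable=True)    # still accepted
try:
  tf.keras.layers.DenseFeatures(cols, trainable=None)  # the dropped 'default' case
except TypeError as e:
  print(e)  # Expected trainable argument to be a boolean, but got: None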
6 changes: 2 additions & 4 deletions keras/feature_column/sequence_feature_column_test.py
@@ -563,8 +563,7 @@ def test_compute_output_shape(self):
@combinations.generate(combinations.combine(mode=['graph', 'eager']))
class SequenceFeaturesSerializationTest(tf.test.TestCase, parameterized.TestCase):

@parameterized.named_parameters(('default', None, None),
('trainable', True, 'trainable'),
@parameterized.named_parameters(('trainable', True, 'trainable'),
('not_trainable', False, 'frozen'))
def test_get_config(self, trainable, name):
cols = [tf.feature_column.sequence_numeric_column('a')]
@@ -578,8 +577,7 @@ def test_get_config(self, trainable, name):
'SequenceNumericColumn')
self.assertEqual(config['feature_columns'][0]['config']['shape'], (1,))

@parameterized.named_parameters(('default', None, None),
('trainable', True, 'trainable'),
@parameterized.named_parameters(('trainable', True, 'trainable'),
('not_trainable', False, 'frozen'))
def test_from_config(self, trainable, name):
cols = [tf.feature_column.sequence_numeric_column('a')]