Skip to content

Commit

Permalink
Sync the OSS keras test.
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 375730632
  • Loading branch information
qlzh727 authored and tensorflower-gardener committed May 25, 2021
1 parent 93ac9ec commit 58c20db
Show file tree
Hide file tree
Showing 2 changed files with 43 additions and 49 deletions.
62 changes: 30 additions & 32 deletions keras/saving/save_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -289,14 +289,13 @@ def get_variables(file_name):
path = os.path.join(self.get_temp_dir(), 'no_optimizer')
x, y = np.ones((10, 10)), np.ones((10, 1))

with self.cached_session():
model = keras.models.Sequential()
model.add(keras.layers.Dense(1))
model.compile('adam', loss='mse')
model.train_on_batch(x, y)
model.save(path, save_format='tf', include_optimizer=False)

model = keras.models.Sequential()
model.add(keras.layers.Dense(1))
model.compile('adam', loss='mse')
model.train_on_batch(x, y)
model.save(path, save_format='tf', include_optimizer=False)
variables = get_variables(path)

for v in variables:
self.assertNotIn('optimizer', v)

Expand Down Expand Up @@ -1034,31 +1033,30 @@ def test_multi_output_metrics_name_stay_same(self, fit):
if not tf.executing_eagerly() and not fit:
self.skipTest('b/181767784')

with self.cached_session():
input_ = keras.Input((4,))
model = keras.Model(
input_,
[keras.layers.Softmax(name='head_0')(keras.layers.Dense(3)(input_)),
keras.layers.Softmax(name='head_1')(keras.layers.Dense(5)(input_))])
metric = keras.metrics.BinaryAccuracy()
model.compile(optimizer='rmsprop',
loss='mse',
metrics={'head_0': [metric, 'accuracy']})

x = np.random.rand(2, 4)
y = {'head_0': np.random.randint(2, size=(2, 3)),
'head_1': np.random.randint(2, size=(2, 5))}

# Make sure metrix prefixing works the same regardless of whether the user
# has fit the model before saving.
if fit:
model.fit(x, y, verbose=0)

# Save and reload.
save_format = testing_utils.get_save_format()
saved_model_dir = self._save_model_dir()
keras.models.save_model(model, saved_model_dir, save_format=save_format)
loaded = keras.models.load_model(saved_model_dir)
input_ = keras.Input((4,))
model = keras.Model(
input_,
[keras.layers.Softmax(name='head_0')(keras.layers.Dense(3)(input_)),
keras.layers.Softmax(name='head_1')(keras.layers.Dense(5)(input_))])
metric = keras.metrics.BinaryAccuracy()
model.compile(optimizer='rmsprop',
loss='mse',
metrics={'head_0': [metric, 'accuracy']})

x = np.random.rand(2, 4)
y = {'head_0': np.random.randint(2, size=(2, 3)),
'head_1': np.random.randint(2, size=(2, 5))}

# Make sure metrix prefixing works the same regardless of whether the user
# has fit the model before saving.
if fit:
model.fit(x, y, verbose=0)

# Save and reload.
save_format = testing_utils.get_save_format()
saved_model_dir = self._save_model_dir()
keras.models.save_model(model, saved_model_dir, save_format=save_format)
loaded = keras.models.load_model(saved_model_dir)

# Make sure the metrics names from the model before saving match the loaded
# model.
Expand Down
30 changes: 13 additions & 17 deletions keras/saving/saved_model/saved_model_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -1276,26 +1276,22 @@ def zero_metric(y_true, y_pred):
del y_true, y_pred
return 0

with self.cached_session():
custom_metric = CustomMetric()
model = testing_utils.get_small_mlp(1, 4, input_dim=3)
model.compile(loss='mse', optimizer='SGD',
metrics=[custom_metric, zero_metric])
self.evaluate(tf.compat.v1.global_variables_initializer())
self.evaluate([v.initializer for v in custom_metric.variables])
model.fit(x, y)
saved_model_dir = self._save_model_dir()
tf.saved_model.save(model, saved_model_dir)
model = testing_utils.get_small_mlp(1, 4, input_dim=3)
model.compile(loss='mse', optimizer='SGD',
metrics=[CustomMetric(), zero_metric])
model.fit(x, y)
saved_model_dir = self._save_model_dir()
tf.saved_model.save(model, saved_model_dir)

with self.assertRaisesRegex(ValueError, 'custom_objects'):
keras_load.load(saved_model_dir)
with self.assertRaisesRegex(ValueError, 'custom_objects'):
keras_load.load(saved_model_dir)

with generic_utils.CustomObjectScope(
{'CustomMetric': CustomMetric, 'zero_metric': zero_metric}):
loaded = keras_load.load(saved_model_dir)
with generic_utils.CustomObjectScope(
{'CustomMetric': CustomMetric, 'zero_metric': zero_metric}):
loaded = keras_load.load(saved_model_dir)

self.evaluate([v.initializer for v in loaded.variables])
loaded.fit(x, y)
self.evaluate([v.initializer for v in loaded.variables])
loaded.fit(x, y)

if __name__ == '__main__':
tf.test.main()

0 comments on commit 58c20db

Please sign in to comment.