Enable Keras functional model building from intermediate tensor.
This new feature is only enabled in TF v2 eager mode and is not available in v1 graph mode.

Also updated existing test cases to cover the new feature.

PiperOrigin-RevId: 388251397
qlzh727 authored and tensorflower-gardener committed Aug 2, 2021
1 parent d3cb74e commit 3985029
Showing 4 changed files with 97 additions and 17 deletions.
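A minimal sketch of what this change enables (assumed usage, distilled from the commit message and the tests below; not part of the diff itself):

    import numpy as np
    import tensorflow as tf

    inputs = tf.keras.Input(shape=(8,))
    x = tf.keras.layers.Dense(32)(inputs)  # `x` is an intermediate KerasTensor
    y = tf.keras.layers.Dense(16)(x)

    # Before this commit, Model(x, y) raised because `x` was not created by
    # tf.keras.Input(); now Keras clones the graph nodes so `x` acts as an input.
    model = tf.keras.Model(x, y)
    model.compile('rmsprop', 'mse')
    model.fit(np.random.randn(8, 32), np.random.randn(8, 16))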
keras/engine/functional.py (9 changes: 9 additions & 0 deletions)
@@ -24,6 +24,7 @@
 from keras import backend
 from keras.engine import base_layer
 from keras.engine import base_layer_utils
+from keras.engine import functional_utils
 from keras.engine import input_layer as input_layer_module
 from keras.engine import input_spec
 from keras.engine import node as node_module
@@ -106,6 +107,14 @@ def __init__(self, inputs, outputs, name=None, trainable=True,
       return
     generic_utils.validate_kwargs(kwargs, {})
     super(Functional, self).__init__(name=name, trainable=trainable)
+    # Check if the inputs contain any intermediate `KerasTensor` (not created
+    # by tf.keras.Input()). In this case we need to clone the `Node` and
+    # `KerasTensor` objects to mimic rebuilding a new model from new inputs.
+    # This feature is only enabled in TF2, not in v1 graph mode.
+    if tf.compat.v1.executing_eagerly_outside_functions():
+      if not all([functional_utils.is_input_keras_tensor(t)
+                  for t in tf.nest.flatten(inputs)]):
+        inputs, outputs = functional_utils.clone_graph_nodes(inputs, outputs)
     self._init_graph_network(inputs, outputs)

   @tf.__internal__.tracking.no_automatic_dependency_tracking
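For illustration, a rough sketch of the internal path the new guard takes (the helper names come from the diff above; treat this as a sketch, not the shipped code):

    import tensorflow as tf
    from keras.engine import functional_utils

    inputs = tf.keras.Input(shape=(8,))
    x = tf.keras.layers.Dense(32)(inputs)  # intermediate KerasTensor
    y = tf.keras.layers.Dense(16)(x)

    # `x` was not created by tf.keras.Input(), so the new all(...) check fails ...
    flat = tf.nest.flatten(x)
    assert not all(functional_utils.is_input_keras_tensor(t) for t in flat)

    # ... and the graph between `x` and `y` is cloned so that `x` behaves like
    # a freshly created input for the new Functional model.
    cloned_inputs, cloned_outputs = functional_utils.clone_graph_nodes(x, y)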
keras/engine/functional_test.py (9 changes: 0 additions & 9 deletions)
@@ -634,15 +634,6 @@ def test_invalid_graphs(self):

     model = training_lib.Model(inputs=[a, b], outputs=[c, d], name='model')

-    # input is not an Input tensor
-    j = layers.Input(shape=(32,), name='input_j')
-    j = layers.Dense(32)(j)
-    k = layers.Input(shape=(32,), name='input_k')
-    m, n = model([j, k])
-
-    with self.assertRaises(Exception):
-      training_lib.Model([j, k], [m, n])
-
     # disconnected graph
     j = layers.Input(shape=(32,), name='input_j')
     k = layers.Input(shape=(32,), name='input_k')
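The deleted case above is exactly the pattern this commit legitimizes; a hedged before/after sketch (in TF2 eager mode, per the commit message):

    j = layers.Input(shape=(32,), name='input_j')
    j = layers.Dense(32)(j)
    k = layers.Input(shape=(32,), name='input_k')
    m, n = model([j, k])
    # Previously wrapped in assertRaises; after this commit it is expected to build.
    new_model = training_lib.Model([j, k], [m, n])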
keras/engine/functional_utils.py (12 changes: 9 additions & 3 deletions)
@@ -167,7 +167,12 @@ def clone_graph_nodes(inputs, outputs):
   cloned_inputs = tf.nest.pack_sequence_as(inputs, cloned_inputs)

   for kt_output in tf.nest.flatten(outputs):
-    cpy = keras_tensor.keras_tensor_from_type_spec(kt_output.type_spec)
+    cpy = _clone_keras_tensor(kt_output)
+    # We reuse the _keras_history here, which contains the old information. It
+    # is used in the Node constructor to check whether the tensor is a
+    # KerasTensor. The history will be overridden by the Node constructor for
+    # the corresponding layer output anyway.
+    cpy._keras_history = kt_output._keras_history  # pylint: disable=protected-access
     cloned_outputs.append(cpy)
     kt_id_mapping[id(kt_output)] = cpy
   cloned_outputs = tf.nest.pack_sequence_as(outputs, cloned_outputs)
@@ -216,6 +221,7 @@ def clone_keras_tensors(args, keras_tensor_mapping):
       else:
         # Create copy of keras_tensor if we haven't done it before
         cpy = _clone_keras_tensor(obj)
+        cpy._keras_history = obj._keras_history  # pylint: disable=protected-access
         keras_tensor_mapping[id(obj)] = cpy
       result.append(cpy)
     else:
@@ -224,7 +230,7 @@


 def _clone_keras_tensor(kt):
-  """Create an idential keras_tensor based on the input.
+  """Create an identical keras_tensor based on the input.

   We use keras_tensor_to_placeholder and keras_tensor_from_tensor to make sure
   inferred shapes are not lost during the copy.
@@ -233,7 +239,7 @@ def _clone_keras_tensor(kt):
     kt: the input KerasTensor.

   Returns:
-    An indential copy of the input KerasTensor.
+    An identical copy of the input KerasTensor.
   """
   # Create a scratch graph since we don't intend to use the placeholders.
   with backend._scratch_graph() as scratch_graph:  # pylint: disable=protected-access
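The diff is truncated here; for context, a sketch of the round-trip the docstring describes (assumed from the docstring's helper names, not the verbatim body):

    # Round-trip the KerasTensor through a placeholder in a scratch graph so
    # the inferred shape survives the copy; the caller then reattaches the old
    # _keras_history (see clone_graph_nodes above).
    with backend._scratch_graph() as scratch_graph:  # pylint: disable=protected-access
      with scratch_graph.as_default():
        placeholder = keras_tensor.keras_tensor_to_placeholder(kt)
        return keras_tensor.keras_tensor_from_tensor(placeholder)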
keras/engine/functional_utils_test.py (84 changes: 79 additions & 5 deletions)
@@ -14,6 +14,9 @@
 # ==============================================================================
 """Tests for functional_utils."""

+import collections
+import os
+
 from keras import keras_parameterized
 from keras import layers
 from keras import models
@@ -117,18 +120,89 @@ def test_build_model_from_intermediate_tensor(self):
     layer2 = layers.Dense(16)
     x = layer1(inputs)
     y = layer2(x)
-    cloned_inputs, cloned_outputs = functional_utils.clone_graph_nodes(x, y)
-    # Make sure the inputs and outputs are cloned.
-    self.assertIsNot(x, cloned_inputs)
-    self.assertIsNot(y, cloned_outputs)
+    model = models.Model(x, y)
     # Make sure a new node is attached to layer2, which mimics y = layer2(x)
     self.assertLen(layer2.inbound_nodes, 2)

-    model = models.Model(cloned_inputs, cloned_outputs)
     self.assertIsInstance(model, models.Model)
     # The model only contains 1 dense layer and 1 input layer.
     self.assertLen(model.layers, 2)
     self.assertIs(model.layers[1], layer2)

     model.compile('rmsprop', 'mse')
     model.fit(np.random.randn(batch_size, 32), np.random.randn(batch_size, 16))
+    # Test for model saving
+    output_path = os.path.join(self.get_temp_dir(), 'tf_keras_saved_model')
+    model.save(output_path, save_format='tf')
+    loaded_model = models.load_model(output_path)
+    self.assertEqual(model.summary(), loaded_model.summary())
+
+    # Also make sure the original inputs and y can still be used to build model
+    new_model = models.Model(inputs, y)
+    # Make sure no new node is attached to layer2
+    self.assertLen(layer2.inbound_nodes, 2)
+
+    self.assertLen(new_model.layers, 3)
+    self.assertIs(new_model.layers[1], layer1)
+    self.assertIs(new_model.layers[2], layer2)
+  def test_build_model_from_intermediate_tensor_with_complicated_model(self):
+    # The topology is like below:
+    # input1 -> dense1 -> a
+    #                     + -> c --+--> d --+--> output
+    # input2 -> dense1 -> b -------^        ^
+    # input3 -> dense2 -> e ----------------+
+    batch_size = 8
+    input1 = input_layer_lib.Input((2,))
+    input2 = input_layer_lib.Input((2,))
+    input3 = input_layer_lib.Input((8,))
+
+    dense1 = layers.Dense(8, name='dense1')
+    dense2 = layers.Dense(8, name='dense2')
+
+    # dense1 is shared between input1 and input2
+    a = dense1(input1)
+    b = dense1(input2)
+
+    c = layers.Add()([a, b])
+    # d has a residual connection from b.
+    d = layers.Add()([b, c])
+    e = dense2(input3)
+    output = layers.Add()([d, e])
+
+    # We skip input2 here and use b instead.
+    model = models.Model([input1, b, input3], output)
+    # Make sure we have 8 layers: 3 for inputs, 2 for Dense and 3 for Add.
+    # Note that dense1 is still in use by input1.
+    self.assertLen(model.layers, 8)
+    # Since the layers are not ordered, check the layer classes to make sure
+    # they match the expectation.
+    class_count = collections.Counter([l.__class__ for l in model.layers])
+    self.assertEqual(class_count[input_layer_lib.InputLayer], 3)
+    self.assertEqual(class_count[layers.Dense], 2)
+    self.assertEqual(class_count[layers.Add], 3)
+
+    model.compile('rmsprop', 'mse')
+    model.fit([np.random.randn(batch_size, 2),
+               np.random.randn(batch_size, 8),  # The shape of b is (batch, 8)
+               np.random.randn(batch_size, 8)],
+              np.random.randn(batch_size, 8))
+    output_path = os.path.join(self.get_temp_dir(), 'tf_keras_saved_model')
+    model.save(output_path, save_format='tf')
+    loaded_model = models.load_model(output_path)
+    self.assertEqual(model.summary(), loaded_model.summary())
+
+    model2 = models.Model([a, b], d)
+    # 2 input layers and 2 Add layers.
+    self.assertLen(model2.layers, 4)
+    class_count = collections.Counter([l.__class__ for l in model2.layers])
+    self.assertEqual(class_count[input_layer_lib.InputLayer], 2)
+    self.assertEqual(class_count[layers.Add], 2)
+
+    model2.compile('rmsprop', 'mse')
+    model2.fit([np.random.randn(batch_size, 8),
+                np.random.randn(batch_size, 8)],
+               np.random.randn(batch_size, 8))


 if __name__ == '__main__':
