Experiment with a new layout map API for Keras models.
Please see the layout_map docstring for how the new API works.

The new approach gives the user just a scope, which they should use when creating their DTensor model. The code that runs in this context makes the following assumptions (a usage sketch follows the list):

1. All the weights will be created as LazyInitVariables. They will be replaced by DVariables at different times depending on the type of model.

2. Any model created within the scope will have the layout_map attached to it, and the map will be used to create the DVariables for the model. If no model is created in the scope, the LazyInitVariables are never converted.

3. For a subclassed model, since the weights are created the first time __call__ is invoked, we inject into __call__ so that the variables are first initialized as LazyInitVariables and then mapped to DVariables. In this case layout_map_scope actually does nothing when the user creates the subclassed model, since the weights are not created yet; the scope only lets the model fetch the layout_map, which could also have been injected via model.__init__. For API simplicity we consolidate everything into this one API.

4. For functional/sequential models, whose weights are created eagerly, the DVariable creation happens in _init_graph_network. The scope approach exists mainly for this case, since the variable creation happens before functional.__init__; injecting the layout_map in __init__ would be too late.

5. The DVariable creation has some special logic to disable the lazy variable scope, which was causing issues for functional models. A variable initializer usually uses tf.random.Generator under the hood, which creates a state variable at init time; that state variable would be converted to a LazyInitVariable if the init happened inside the scope. We disable the scope in that case, since the init should happen within a tf.function on a DTensor device scope, so the state variable is created as a DVariable.
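
Below is a minimal usage sketch of the new API, mirroring the updated layout_map docstring. The mesh and layout_1 ... layout_4 objects are assumed to have been created with the DTensor API beforehand and are not part of this change:

import tensorflow as tf
from keras.dtensor import layout_map as layout_map_lib

# Keys are regexes matched against the variable path, e.g. 'd1/kernel'.
layout_map = layout_map_lib.LayoutMap(mesh=mesh)
layout_map['d1.kernel'] = layout_1
layout_map['d1.bias'] = layout_2
layout_map['d2.kernel'] = layout_3
layout_map['d2.bias'] = layout_4

with layout_map_lib.layout_map_scope(layout_map):
  # Functional model: the weights are created eagerly, so the DVariables are
  # built inside _init_graph_network while the scope is still active.
  inputs = tf.keras.Input((10,), batch_size=10)
  x = tf.keras.layers.Dense(20, name='d1')(inputs)
  outputs = tf.keras.layers.Dense(30, name='d2')(x)
  model = tf.keras.Model(inputs, outputs)

# Each variable now carries the layout from the map, or a default
# all-replicated layout when no entry matches.
model.layers[1].kernel.layout == layout_1
model.layers[2].bias.layout == layout_4

For a subclassed model the same scope is used, but the conversion only happens once __call__ first creates the weights (see point 3 above).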

PiperOrigin-RevId: 432248522
qlzh727 authored and tensorflower-gardener committed Mar 3, 2022
1 parent ee8117b commit 2e664be
Showing 8 changed files with 140 additions and 86 deletions.
9 changes: 7 additions & 2 deletions keras/dtensor/__init__.py
@@ -19,8 +19,13 @@

# Conditional import the dtensor API, since it is currently broken in OSS.
if _DTENSOR_API_ENABLED:
# pylint: disable=g-direct-tensorflow-import, g-import-not-at-top
from tensorflow.dtensor import python as dtensor_api
try:
# pylint: disable=g-direct-tensorflow-import, g-import-not-at-top
from tensorflow.dtensor import python as dtensor_api
except ImportError:
# TODO(b/222341036): Remove this conditional import after dtensor has a
# trimmed target that can be used by Keras.
dtensor_api = None
else:
# Leave it with a placeholder, so that the import line from other python file
# will not break.
4 changes: 1 addition & 3 deletions keras/dtensor/integration_test_utils.py
@@ -46,7 +46,7 @@ def get_model_with_layout_map(layout_map):
a CNN Keras model used for MNIST
"""

def model_fn():
with layout_map_lib.layout_map_scope(layout_map):
# Define a CNN model to recognize MNIST digits.
model = models.Sequential()
model.add(
@@ -79,8 +79,6 @@ def model_fn():
))
return model

return layout_map_lib.init_model_with_layout_map(layout_map, model_fn)


def get_all_replicated_layout_map(mesh):
layout_map = layout_map_lib.LayoutMap(mesh=mesh)
97 changes: 46 additions & 51 deletions keras/dtensor/layout_map.py
@@ -15,7 +15,9 @@
"""Library for map layout and corresponding tf.Variable."""

import collections
import contextlib
import re
import threading

from keras.dtensor import dtensor_api as dtensor
from keras.dtensor import lazy_variable
@@ -32,6 +34,13 @@
'_non_trainable_weights']


_LAYOUT_MAP = threading.local()


def get_current_layout_map():
return getattr(_LAYOUT_MAP, 'layout_map', None)


class LayoutMap(collections.MutableMapping):

def __init__(self, mesh=None):
@@ -95,8 +104,9 @@ def get_default_mesh(self):
return self._default_mesh


def init_model_with_layout_map(layout_map, model_init_fn):
"""Apply the layout to all the tf.Variables in the model_init_fn.
@contextlib.contextmanager
def layout_map_scope(layout_map):
"""Apply the layout to all the tf.Variables created under the scope.
Create a scope that all the tf.Variable created under this scope
will be lazily inited, and initialized later on with proper layout when the
@@ -124,59 +134,51 @@ def init_model_with_layout_map(layout_map, model_init_fn):
## Subclassed model
class SubclassModel(tf.keras.Model):
def __init__(self, name=None):
super().__init__(name=name)
self.d1 = tf.keras.layers.Dense(1000)
self.d2 = tf.keras.layers.Dense(1000)
def __init__(self, name=None):
super().__init__(name=name)
self.d1 = tf.keras.layers.Dense(1000)
self.d2 = tf.keras.layers.Dense(1000)
def call(self, inputs):
x = self.d1(inputs)
return self.d2(x)
def call(self, inputs):
x = self.d1(inputs)
return self.d2(x)
def model_init_fn():
with layout_map_scope(layout_map):
model = SubclassModel()
inputs = tf.keras.Input((10,), batch_size=10)
model(inputs)
return model
# Triggering the creation of weights within or outside of the scope works
inputs = tf.zeros((10, 10))
results = model(inputs)
model_with_layout = layout_map_lib.init_model_with_layout_map(
layout_map, model_init_fn)
model_with_layout.d1.kernel.layout == layout_1
model_with_layout.d1.bias.layout == layout_2
model_with_layout.d2.kernel.layout == layout_3
model_with_layout.d2.bias.layout == layout_4
model.d1.kernel.layout == layout_1
model.d1.bias.layout == layout_2
model.d2.kernel.layout == layout_3
model.d2.bias.layout == layout_4
## Functional model
def model_init_fn():
with layout_map_scope(layout_map):
inputs = tf.keras.Input((10,), batch_size=10)
x = tf.keras.layers.Dense(20, name='d1')(inputs)
output = tf.keras.layers.Dense(30, name='d2')(x)
model = tf.keras.Model(inputs, output)
return model
model_with_layout = layout_map_lib.init_model_with_layout_map(
layout_map, model_init_fn)
d1 = model_with_layout.layers[1]
d2 = model_with_layout.layers[2]
d1 = model.layers[1]
d2 = model.layers[2]
d1.kernel.layout == layout_1
d1.bias.layout == layout_2
d2.kernel.layout == layout_3
d2.bias.layout == layout_4
## Sequential model
def model_init_fn():
with layout_map_scope(layout_map):
model = tf.keras.Sequential([
tf.keras.layers.Dense(20, name='d1', input_shape=(10,)),
tf.keras.layers.Dense(30, name='d2')
])
return model
model_with_layout = layout_map_lib.init_model_with_layout_map(
layout_map, model_init_fn)
d1 = model_with_layout.layers[0]
d2 = model_with_layout.layers[1]
d1 = model.layers[0]
d2 = model.layers[1]
d1.kernel.layout == layout_1
d1.bias.layout == layout_2
@@ -188,28 +190,20 @@ def model_init_fn():
layout_map: a LayoutMap which contains the variable_object_path (string) ->
Layout. When a layout is not found for the variable, a default all
replicated layout will be created for the variable.
model_init_fn: A function that will trigger the weights creation and return
a Keras model instance. Eg, for a Keras functional model, the
`tf.keras.Model(inputs, output)` will trigger the weights creation. For
subclass model which doesn't have the `build()` method implemented, user
will need to do `model(tf.keras.Input(shape=shape))` to init the weights.
Note that this method need to return the model instance, and the instance
will be used as the root object to track all the tf.Variables that
attached to it.
Returns:
The same model instance returned from model_init_fn, with all tf.Variable
have the proper layout populated.
Yields:
A context that will lazily initialize all `tf.Variable` objects
within the model, with their attributed layouts.
"""
with lazy_variable.lazy_init_scope():
model = model_init_fn()
previous_layout_map = get_current_layout_map()
global _LAYOUT_MAP
_LAYOUT_MAP.layout_map = layout_map

if model._is_graph_network: # pylint: disable=protected-access
# Functional/Sequential model
return _map_functional_model_variable(model, layout_map)
else:
# Subclass Model
return _map_subclass_model_variable(model, layout_map)
with lazy_variable.lazy_init_scope():
try:
yield
finally:
_LAYOUT_MAP.layout_map = previous_layout_map


def _map_subclass_model_variable(model, layout_map):
@@ -340,7 +334,8 @@ def _create_dvariable(layout_map, object_path, variable):
rank=variable_rank)
init_val = variable._initial_value # pylint: disable=protected-access
if callable(init_val):
init_val = utils.call_with_layout(init_val, layout)
with lazy_variable.disable_init_variable_creator():
init_val = utils.call_with_layout(init_val, layout)
else:
# The init value is probably already created as a tensor, we will just copy
# it to mesh and give it a proper layout.
47 changes: 19 additions & 28 deletions keras/dtensor/layout_map_test.py
@@ -173,16 +173,16 @@ def test_init_subclass_model_variable_with_layout(self):
layout_map['d2.kernel'] = self.layout_2d
layout_map['d2.bias'] = self.layout_1d

def model_init_fn():
with layout_map_lib.layout_map_scope(layout_map):
model = SubclassModel(name='model')
inputs = tf.keras.Input((10,), batch_size=10)
model(inputs)
return model

model_with_layout = layout_map_lib.init_model_with_layout_map(
layout_map, model_init_fn)
d1 = model_with_layout.d1
d2 = model_with_layout.d2
# Init the model with an eager tensor; make sure the model weights have the
# correct layout and that the model produces the correct result.
inputs = tf.zeros((10, 10), layout=self.layout_2d)
result = model(inputs)
self.assertAllClose(result, tf.zeros((10, 1000)))
d1 = model.d1
d2 = model.d2
self.assertEqual(d1.kernel.layout, self.layout_2d)
self.assertEqual(d1.bias.layout, self.layout_1d)
self.assertEqual(d2.kernel.layout, self.layout_2d)
@@ -195,8 +195,7 @@ def model_init_fn():
self.assertIs(d2.kernel, d2._trainable_weights[0])
self.assertIs(d2.bias, d2._trainable_weights[1])

result = model_with_layout(tf.zeros((10, 10), layout=self.layout_2d),
training=True)
result = model(tf.zeros((10, 10), layout=self.layout_2d), training=True)
self.assertAllClose(result, tf.zeros((10, 1000), layout=self.layout_2d))

def test_init_functional_model_variable_with_layout(self):
@@ -210,21 +209,18 @@ def test_init_functional_model_variable_with_layout(self):
layout_map['d2.kernel'] = self.layout_2d
layout_map['d2.bias'] = self.layout_1d

def model_init_fn():
with layout_map_lib.layout_map_scope(layout_map):
inputs = tf.keras.Input((10,), batch_size=10)
x = layers.Dense(20, name='d1')(inputs)
x = layers.Dropout(0.1)(x)
output = layers.Dense(30, name='d2')(x)

model = tf.keras.Model(inputs, output)
return model

model_with_layout = layout_map_lib.init_model_with_layout_map(
layout_map, model_init_fn)
# It includes input layer as well.
self.assertLen(model_with_layout.layers, 4)
d1 = model_with_layout.layers[1]
d2 = model_with_layout.layers[3]
self.assertLen(model.layers, 4)
d1 = model.layers[1]
d2 = model.layers[3]

self.assertEqual(d1.kernel.layout, self.layout_2d)
self.assertEqual(d1.bias.layout, self.layout_1d)
Expand All @@ -238,8 +234,7 @@ def model_init_fn():
self.assertIs(d2.kernel, d2._trainable_weights[0])
self.assertIs(d2.bias, d2._trainable_weights[1])

result = model_with_layout(tf.zeros((10, 10), layout=self.layout_2d),
training=True)
result = model(tf.zeros((10, 10), layout=self.layout_2d), training=True)
self.assertAllClose(result, tf.zeros((10, 30), layout=self.layout_2d))

def test_init_sequential_model_variable_with_layout(self):
@@ -253,19 +248,16 @@ def test_init_sequential_model_variable_with_layout(self):
layout_map['d2.kernel'] = self.layout_2d
layout_map['d2.bias'] = self.layout_1d

def model_init_fn():
with layout_map_lib.layout_map_scope(layout_map):
model = tf.keras.Sequential([
layers.Dense(20, name='d1', input_shape=(10,)),
layers.Dropout(0.1),
layers.Dense(30, name='d2')
])
return model

model_with_layout = layout_map_lib.init_model_with_layout_map(
layout_map, model_init_fn)
self.assertLen(model_with_layout.layers, 3)
d1 = model_with_layout.layers[0]
d2 = model_with_layout.layers[2]
self.assertLen(model.layers, 3)
d1 = model.layers[0]
d2 = model.layers[2]

self.assertEqual(d1.kernel.layout, self.layout_2d)
self.assertEqual(d1.bias.layout, self.layout_1d)
@@ -279,8 +271,7 @@ def model_init_fn():
self.assertIs(d2.kernel, d2._trainable_weights[0])
self.assertIs(d2.bias, d2._trainable_weights[1])

result = model_with_layout(tf.zeros((10, 10), layout=self.layout_2d),
training=True)
result = model(tf.zeros((10, 10), layout=self.layout_2d), training=True)
self.assertAllClose(result, tf.zeros((10, 30), layout=self.layout_2d))


22 changes: 20 additions & 2 deletions keras/dtensor/lazy_variable.py
@@ -14,6 +14,8 @@
# ==============================================================================
"""Lazily initialized variables, useful for creating a symbolic Keras model."""

import threading

# pylint: disable=g-direct-tensorflow-import
from tensorflow.core.framework import attr_value_pb2
from tensorflow.python.eager import context
@@ -26,6 +28,9 @@
from tensorflow.python.util import tf_contextlib


_DISABLE_LAZY_VARIABLE_INIT = threading.local()


def _infer_shape_dtype_and_create_handle(initial_value, shape, dtype, name):
"""Infer shape and dtype from initial_value and create a variable handle."""
with ops.name_scope(name, "Variable", skip_on_eager=False) as name:
@@ -191,12 +196,25 @@ def create_and_initialize(self):


def _lazy_init_variable_creator(next_creator, **kwargs):
del next_creator
return LazyInitVariable(**kwargs)
if getattr(_DISABLE_LAZY_VARIABLE_INIT, "disabled", False):
return next_creator(**kwargs)
else:
return LazyInitVariable(**kwargs)


@tf_contextlib.contextmanager
def lazy_init_scope():
with variable_scope.variable_creator_scope(_lazy_init_variable_creator):
yield


@tf_contextlib.contextmanager
def disable_init_variable_creator():
try:
global _DISABLE_LAZY_VARIABLE_INIT
existing_value = getattr(_DISABLE_LAZY_VARIABLE_INIT, "disabled", False)
_DISABLE_LAZY_VARIABLE_INIT.disabled = True
yield
finally:
_DISABLE_LAZY_VARIABLE_INIT.disabled = existing_value

1 change: 1 addition & 0 deletions keras/engine/BUILD
@@ -58,6 +58,7 @@ py_library(
"//keras:regularizers",
"//keras/distribute",
"//keras/distribute:distribute_coordinator_utils",
"//keras/dtensor:layout_map",
"//keras/initializers",
"//keras/metrics",
"//keras/mixed_precision:autocast_variable",
7 changes: 7 additions & 0 deletions keras/engine/functional.py
@@ -21,6 +21,7 @@
import itertools
import warnings
from keras import backend
from keras.dtensor import layout_map as layout_map_lib
from keras.engine import base_layer
from keras.engine import base_layer_utils
from keras.engine import functional_utils
@@ -258,6 +259,12 @@ def _init_graph_network(self, inputs, outputs):
self._set_save_spec(self._nested_inputs)
tf_utils.assert_no_legacy_layers(self.layers)

# Note that this method is used by both functional and sequential models,
# so we can't just put this logic in functional.__init__, which would miss
# the coverage of sequential models.
if self._layout_map:
layout_map_lib._map_functional_model_variable(self, self._layout_map)

@property
def input(self):
"""Retrieves the input tensor(s) of a layer.