Remove redundancy in device_mesh specification in ModelParallel
fchollet committed Jun 17, 2024
1 parent 45e1175 commit 0ee33b5
Showing 2 changed files with 41 additions and 28 deletions.
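For context, a minimal sketch of the call pattern this commit moves to: the `DeviceMesh` now travels with the `LayoutMap`, so `ModelParallel` no longer needs a separate `device_mesh` argument. The 1 x 8 mesh shape, the `dense.*` layout keys, and the assumption of eight visible devices on a JAX backend are illustrative only, not part of the commit.

```python
import keras
from keras.distribution import DeviceMesh, LayoutMap, ModelParallel

# Illustrative mesh: 1 "batch" axis x 8 "model" shards over all local devices.
devices = keras.distribution.list_devices()
device_mesh = DeviceMesh(
    shape=(1, 8), axis_names=("batch", "model"), devices=devices
)

# The mesh is attached to the layout map...
layout_map = LayoutMap(device_mesh)
layout_map["dense.*kernel"] = (None, "model")
layout_map["dense.*bias"] = ("model",)

# ...so the distribution is built from the layout map alone.
distribution = ModelParallel(layout_map=layout_map, batch_dim_name="batch")
keras.distribution.set_distribution(distribution)
```

As the diff below shows, the new `__init__` still pops a stray `device_mesh` keyword out of `**kwargs`, so existing keyword-style calls keep working; only positional `ModelParallel(device_mesh, layout_map)` calls need updating.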
58 changes: 35 additions & 23 deletions keras/src/distribution/distribution_lib.py
@@ -520,9 +520,11 @@ class ModelParallel(Distribution):
layout_map['conv2d.*kernel'] = (None, None, None, 'model')
layout_map['conv2d.*bias'] = ('model',)
distribution = ModelParallel(device_mesh=device_mesh,
layout_map=layout_map,
batch_dim_name='batch')
distribution = ModelParallel(
layout_map=layout_map,
batch_dim_name='batch',
)
# Set the global distribution, or via `with distribution.scope():`
set_distribution(distribution)
@@ -533,38 +535,49 @@ class ModelParallel(Distribution):
You can quickly update the device mesh shape to change the sharding factor
of the variables. E.g.
```
```python
# With only the shape change for the device mesh, the variables will be
# sharded across 8 devices instead of 4, which further reduces the memory
# footprint of variables on each device.
device_mesh = DeviceMesh(shape=(1, 8), axis_names=('batch', 'model'),
devices=devices)
device_mesh = DeviceMesh(
shape=(1, 8),
axis_names=('batch', 'model'),
devices=devices,
)
```
To figure out proper layout mapping rules for all the model variables, you
can first list out all the model variable paths, which will be used as the
keys that map the variables to `TensorLayout`s, e.g.
```
```python
model = create_model()
for v in model.variables:
print(v.path)
```
Args:
device_mesh: `DeviceMesh` instance for physical device and its
logical mapping.
layout_map: `LayoutMap` instance which maps the variable path to the
corresponding `TensorLayout`. The axis names of the
`TensorLayout`s should match the axis names in the
device mesh, or an exception will be raised.
batch_dim_name: optional string, the axis name in the `device_mesh`
corresponding tensor layout.
batch_dim_name: Optional string, the axis name in the device mesh
(of the `layout_map` object)
that will be used to distribute data. If unspecified, the
first axis from the `device_mesh` will be used.
first axis from the device mesh will be used.
"""

def __init__(self, device_mesh, layout_map, batch_dim_name=None):
def __init__(self, *, layout_map=None, batch_dim_name=None, **kwargs):
kwargs.pop("device_mesh", None)
if layout_map is None:
raise ValueError("You must specify a layout_map argument.")
if not isinstance(layout_map, LayoutMap):
raise ValueError(
"Argument `layout_map` must be a `LayoutMap` instance. "
f"Received: layout_map={layout_map}"
)
device_mesh = layout_map.device_mesh
super().__init__(device_mesh)
self._layout_map = layout_map
self._batch_dim_name = batch_dim_name or self.device_mesh.axis_names[0]
@@ -693,11 +706,11 @@ class LayoutMap(collections.abc.MutableMapping):
as value, and will be converted to `TensorLayout`.
```python
layout_map = LayoutMap(device_mesh=None)
layout_map['dense.*kernel'] = (None, 'model') # layout_2d
layout_map['dense.*bias'] = ('model',) # layout_1d
layout_map['conv2d.*kernel'] = TensorLayout((None, None, None, 'model'))
layout_map['conv2d.*bias'] = TensorLayout(('model',)) # layout_1d
layout_map = LayoutMap(device_mesh)
layout_map['dense.*kernel'] = (None, 'model')
layout_map['dense.*bias'] = ('model',)
layout_map['conv2d.*kernel'] = (None, None, None, 'model')
layout_map['conv2d.*bias'] = ('model',)
layout_1 = layout_map['dense_1.kernel'] # layout_1 == layout_2d
layout_2 = layout_map['dense_1.bias'] # layout_2 == layout_1d
@@ -710,11 +723,10 @@ class LayoutMap(collections.abc.MutableMapping):
```
Args:
device_mesh: An optional `DeviceMesh` that can be used to populate the
`TensorLayout.device_mesh` if `TensorLayout.device_mesh` is not set.
device_mesh: `keras.distribution.DeviceMesh` instance.
"""

def __init__(self, device_mesh=None):
def __init__(self, device_mesh):
self._layout_map = collections.OrderedDict()
self._device_mesh = device_mesh

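A related consequence visible in the `LayoutMap` hunk above: the `device_mesh=None` default is gone, so the mesh must exist before the map is built. A short sketch under the same illustrative assumptions as the example near the top (eight devices, JAX backend):

```python
import keras
from keras.distribution import DeviceMesh, LayoutMap

devices = keras.distribution.list_devices()
device_mesh = DeviceMesh(
    shape=(2, 4), axis_names=("batch", "model"), devices=devices
)

# `device_mesh` is now a required argument; calling `LayoutMap()` with no mesh
# raises a TypeError since the parameter no longer has a default.
layout_map = LayoutMap(device_mesh)

# Layouts retrieved from this map are bound to `device_mesh`.
layout_map["dense.*kernel"] = (None, "model")
```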
11 changes: 6 additions & 5 deletions keras/src/distribution/distribution_lib_test.py
@@ -273,7 +273,7 @@ def test_distribute_weights(self):
layout_map[".*bias"] = distribution_lib.TensorLayout(["model"])

distribution = distribution_lib.ModelParallel(
self.device_mesh, layout_map, batch_dim_name="data"
layout_map=layout_map, batch_dim_name="data"
)
kernel = backend.Variable(initializer=np.arange(8, 4), name="kernel")
bias = backend.Variable(initializer=np.arange(4), name="bias")
@@ -294,7 +294,7 @@ def test_distribute_data(self):
def test_distribute_data(self):
layout_map = distribution_lib.LayoutMap(self.device_mesh)
distribution = distribution_lib.ModelParallel(
self.device_mesh, layout_map, batch_dim_name="data"
layout_map=layout_map, batch_dim_name="data"
)

data = np.arange(16).reshape((4, 2, 2))
@@ -309,7 +309,7 @@ def test_get_tensor_layout(self):
layout_map["/model/layer/tensor"] = ("data", None)

distribution = distribution_lib.ModelParallel(
self.device_mesh, layout_map, batch_dim_name="data"
layout_map=layout_map, batch_dim_name="data"
)
layout = distribution.get_tensor_layout("/model/layer/tensor")
self.assertIs(layout.device_mesh, self.device_mesh)
@@ -321,8 +321,9 @@ def test_get_tensor_layout(self):
def test_distribute_dataset(self):
# We can only verify the single worker/process case in OSS for now.
dataset = tf.data.Dataset.range(8)
distribution = distribution = distribution_lib.ModelParallel(
self.device_mesh, {}, batch_dim_name="data"
layout_map = distribution_lib.LayoutMap(self.device_mesh)
distribution = distribution_lib.ModelParallel(
layout_map=layout_map, batch_dim_name="data"
)
distributed_dataset = distribution.distribute_dataset(dataset)
self.assertIs(dataset, distributed_dataset)
