Update keras metrics to support DTensor mesh.
For now, all the weights will have a fully replicated layout (since they are usually small counters). We inject the `mesh` kwarg into all the metrics, using a similar approach to the one we used for layers.

PiperOrigin-RevId: 433874639
qlzh727 authored and tensorflower-gardener committed Mar 11, 2022
1 parent b6e8b99 commit e71391c
Showing 6 changed files with 1 addition and 205 deletions.
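As context for the diff below, here is a minimal sketch of how the injected `mesh` kwarg is meant to be used. It is not part of the commit: it assumes the public `tf.experimental.dtensor` API, a single local CPU device, and that the metric's `__init__` is wrapped with the mesh-injecting decorator shown in the `keras/dtensor/utils.py` hunk (as `Mean.__init__` is in `keras/metrics/base_metric.py`).

```python
import tensorflow as tf
from tensorflow.experimental import dtensor

# A trivial single-device mesh; real workloads would span several devices.
mesh = dtensor.create_mesh([("batch", 1)], devices=["CPU:0"])

# The decorator pops the extra `mesh` kwarg before __init__ runs, so the
# metric's public signature stays unchanged.
mean = tf.keras.metrics.Mean(mesh=mesh)
assert mean._mesh is mesh

# Every weight the metric creates then gets a fully replicated layout
# (see the add_weight change in keras/metrics/base_metric.py below).
```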
16 changes: 0 additions & 16 deletions keras/dtensor/BUILD
@@ -88,22 +88,6 @@ py_library(
    ],
)

py_test(
    name = "metrics_test",
    srcs = ["metrics_test.py"],
    shard_count = 4,
    tags = ["no_oss"],
    deps = [
        ":dtensor",
        "//:expect_absl_installed",
        "//:expect_numpy_installed",
        "//:expect_tensorflow_installed",
        "//keras/metrics",
        "//keras/utils:tf_utils",
        "//learning/brain/experimental/dtensor/tests:test_util",
    ],
)

py_test(
    name = "mnist_model_test",
    srcs = ["mnist_model_test.py"],
92 changes: 0 additions & 92 deletions keras/dtensor/metrics_test.py

This file was deleted.

41 changes: 0 additions & 41 deletions keras/dtensor/utils.py
@@ -103,47 +103,6 @@ def _wrap_function(layer_instance, *args, **kwargs):
target=init_method, decorator_func=_wrap_function)


def inject_mesh(init_method):
  """Injects DTensor mesh information into an object.

  This is useful for Keras objects like `Metric` and `Optimizer` which need the
  DTensor mesh to create their weights, but don't want to change the current
  public API interface.

  This is for temporary usage; eventually the mesh/layout information will be
  a public argument in the `__init__` method.

  Sample usage:
  ```python
  class Accuracy(tf.keras.metrics.Metric):

    @inject_mesh
    def __init__(self, name='accuracy', dtype=None, **kwargs):
      super().__init__(name=name, dtype=dtype, **kwargs)

  acc = Accuracy(mesh=mesh)
  assert acc._mesh == mesh
  ```

  Args:
    init_method: the `__init__` method of the Keras class to annotate.

  Returns:
    The annotated `__init__` method.
  """

  def _wrap_function(instance, *args, **kwargs):
    mesh = kwargs.pop("mesh", None)
    # Note that the injection of _mesh needs to happen before the invocation of
    # __init__, since the class might need the mesh to create weights in
    # __init__.
    if mesh is not None:
      instance._mesh = mesh  # pylint: disable=protected-access
    init_method(instance, *args, **kwargs)

  return tf.__internal__.decorator.make_decorator(
      target=init_method, decorator_func=_wrap_function)


def call_with_layout(fn, layout, *args, **kwargs):
"""Invoke the function with inputs and relayout the result.
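For readers skimming the diff, the decorator above can be rendered as a standalone sketch. This swaps `tf.__internal__.decorator.make_decorator` for plain `functools.wraps` and uses a toy class instead of a real Keras `Metric`, so it illustrates the mechanism rather than reproducing the library code.

```python
import functools


def inject_mesh_sketch(init_method):
  """Simplified stand-in for the `inject_mesh` decorator shown above."""

  @functools.wraps(init_method)
  def _wrap_function(instance, *args, **kwargs):
    mesh = kwargs.pop("mesh", None)
    # Set `_mesh` before __init__ runs, since __init__ may create weights
    # that need the mesh.
    if mesh is not None:
      instance._mesh = mesh
    init_method(instance, *args, **kwargs)

  return _wrap_function


class Counter:
  """Toy object standing in for a Keras Metric."""

  @inject_mesh_sketch
  def __init__(self, name="counter"):
    self.name = name


c = Counter(mesh="fake-mesh")  # `mesh` never reaches __init__ itself.
assert c._mesh == "fake-mesh"
```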
2 changes: 0 additions & 2 deletions keras/metrics/BUILD
@@ -43,8 +43,6 @@ py_library(
        "//keras:backend",
        "//keras:losses",
        "//keras/distribute",
        "//keras/dtensor",
        "//keras/dtensor:utils",
        "//keras/engine:base_layer",
        "//keras/engine:base_layer_utils",
        "//keras/utils:generic_utils",
16 changes: 1 addition & 15 deletions keras/metrics/base_metric.py
@@ -22,8 +22,6 @@
import warnings

from keras import backend
from keras.dtensor import dtensor_api as dtensor
from keras.dtensor import utils as dtensor_utils
from keras.engine import base_layer
from keras.engine import base_layer_utils
from keras.engine import keras_tensor
@@ -339,13 +337,6 @@ def add_weight(
    # TODO(b/120571621): Make `ON_READ` work with Keras metrics on TPU.
    if backend.is_tpu_strategy(strategy):
      synchronization = tf.VariableSynchronization.ON_WRITE
    if getattr(self, '_mesh', None) is not None:
      # When self._mesh is set, it means this metric is used for DTensor.
      additional_kwargs = {
          'layout': dtensor.Layout.replicated(self._mesh,
                                              tf.TensorShape(shape).rank)}
    else:
      additional_kwargs = {}

    with tf.init_scope():
      return super(Metric, self).add_weight(
@@ -356,8 +347,7 @@
          initializer=initializer,
          collections=[],
          synchronization=synchronization,
          aggregation=aggregation,
          **additional_kwargs)
          aggregation=aggregation)

### End: For use by subclasses ###
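The `layout` kwarg removed from `add_weight` above is built with `Layout.replicated`, which keeps every axis of the metric weight unsharded, i.e. copied to every device in the mesh. A rough equivalent with the public DTensor API, again assuming a single-CPU mesh and using `tf.experimental.dtensor` instead of Keras's internal `dtensor_api` alias:

```python
import tensorflow as tf
from tensorflow.experimental import dtensor

mesh = dtensor.create_mesh([("batch", 1)], devices=["CPU:0"])

# For a metric weight of this shape, the DTensor branch of add_weight would
# pass a fully replicated layout of matching rank.
shape = tf.TensorShape([4, 2])
layout = dtensor.Layout.replicated(mesh, rank=shape.rank)
print(layout)  # All axes unsharded -> the weight is replicated on the mesh.
```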

@@ -530,7 +520,6 @@ class Sum(Reduce):
  ```
  """

  @dtensor_utils.inject_mesh
  def __init__(self, name='sum', dtype=None):
    super(Sum, self).__init__(reduction=metrics_utils.Reduction.SUM,
                              name=name, dtype=dtype)
@@ -573,7 +562,6 @@ class Mean(Reduce):
  ```
  """

  @dtensor_utils.inject_mesh
  def __init__(self, name='mean', dtype=None):
    super(Mean, self).__init__(
        reduction=metrics_utils.Reduction.WEIGHTED_MEAN, name=name, dtype=dtype)
@@ -607,7 +595,6 @@ def accuracy(y_true, y_pred):
    **kwargs: Keyword arguments to pass on to `fn`.
  """

  @dtensor_utils.inject_mesh
  def __init__(self, fn, name=None, dtype=None, **kwargs):
    super(MeanMetricWrapper, self).__init__(name=name, dtype=dtype)
    self._fn = fn
@@ -708,7 +695,6 @@ class MeanTensor(Metric):
  array([[2., 3., 4., 5.]])
  """

  @dtensor_utils.inject_mesh
  def __init__(self, name='mean_tensor', dtype=None, shape=None):
    super(MeanTensor, self).__init__(name=name, dtype=dtype)
    self._shape = None