Skip to content

Commit

Permalink
Add DTensor layout injection for normalization layers.
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 435236178
  • Loading branch information
qlzh727 authored and tensorflower-gardener committed Mar 17, 2022
1 parent 8198b25 commit f98f87d
Show file tree
Hide file tree
Showing 4 changed files with 11 additions and 0 deletions.
5 changes: 5 additions & 0 deletions keras/dtensor/layers_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -61,6 +61,11 @@ def setUp(self):
# ('conv3dtranspose', layers.Conv3DTranspose,
# {'filters': 4, 'kernel_size': (3, 3, 3)},
# {'kernel': 5, 'bias': 1}, [10, 28, 28, 28, 3]),
('batch_norm', layers.BatchNormalization, {'fused': False},
{'beta': 1, 'gamma': 1, 'moving_mean': 1, 'moving_variance': 1},
[10, 28, 28, 3]),
('layer_norm', layers.LayerNormalization, {'dtype': tf.float64},
{'beta': 1, 'gamma': 1}, [10, 28, 28, 3])
)
def test_layer(self, layer_cls, init_args, variable_settings, input_shape,
input_dtype=np.float32):
Expand Down
2 changes: 2 additions & 0 deletions keras/layers/normalization/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,7 @@ py_library(
"//keras:backend",
"//keras:constraints",
"//keras:regularizers",
"//keras/dtensor:utils",
"//keras/engine:base_layer",
"//keras/engine:input_spec",
"//keras/initializers",
Expand All @@ -64,6 +65,7 @@ py_library(
"//:expect_tensorflow_installed",
"//keras:constraints",
"//keras:regularizers",
"//keras/dtensor:utils",
"//keras/engine:base_layer",
"//keras/initializers",
],
Expand Down
2 changes: 2 additions & 0 deletions keras/layers/normalization/batch_normalization.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@
from keras import constraints
from keras import initializers
from keras import regularizers
from keras.dtensor import utils
from keras.engine.base_layer import Layer
from keras.engine.input_spec import InputSpec
from keras.utils import control_flow_util
Expand Down Expand Up @@ -1215,6 +1216,7 @@ class BatchNormalization(BatchNormalizationBase):
"""
_USE_V2_BEHAVIOR = True

@utils.allow_initializer_layout
def __init__(self,
axis=-1,
momentum=0.99,
Expand Down
2 changes: 2 additions & 0 deletions keras/layers/normalization/layer_normalization.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@
from keras import constraints
from keras import initializers
from keras import regularizers
from keras.dtensor import utils
from keras.engine.base_layer import Layer
from keras.utils import tf_utils

Expand Down Expand Up @@ -149,6 +150,7 @@ class LayerNormalization(Layer):
- [Lei Ba et al., 2016](https://arxiv.org/abs/1607.06450).
"""

@utils.allow_initializer_layout
def __init__(self,
axis=-1,
epsilon=1e-3,
Expand Down

0 comments on commit f98f87d

Please sign in to comment.