Commit e6e5d6d

Merge pull request tensorflow#37919 from reedwm/none_grad_fix
2.2-rc2 cherry-pick request: Fix crash in Model.fit() if a gradient is None
2 parents f4b139e + ce7990b commit e6e5d6d
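
Background: tf.GradientTape.gradient() returns None for a trainable variable that does not contribute to the loss, and before this change Model.fit() crashed when such a None gradient reached gradient clipping or cross-replica aggregation. Below is a minimal repro sketch modeled on the new tests in this commit; the layer and weight names are illustrative, not part of the commit.

import numpy as np
import tensorflow as tf


class DenseWithUnusedWeight(tf.keras.layers.Dense):
  """Dense layer with an extra weight that never affects the output."""

  def build(self, input_shape):
    super(DenseWithUnusedWeight, self).build(input_shape)
    # Never used in call(), so its gradient from GradientTape is None.
    self.unused = self.add_weight('unused', shape=(), initializer='ones')


model = tf.keras.Sequential([DenseWithUnusedWeight(4, input_shape=(4,))])
# clipnorm/clipvalue exercise the _clip_gradients path changed below.
model.compile(tf.keras.optimizers.Adam(clipnorm=1.0, clipvalue=1.0), 'mse')
# Crashed before this fix; with it, the None gradient is skipped and the
# kernel is still updated.
model.fit(np.random.normal(size=(64, 4)), np.random.normal(size=(64, 4)))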

File tree

3 files changed: +72 -6 lines


tensorflow/python/keras/distribute/distribute_strategy_test.py (+27)

@@ -804,6 +804,33 @@ def test_predict_multi_output_model_with_partial_batch(
         atol=1e-4,
         rtol=1e-4)
 
+  @combinations.generate(all_strategy_combinations_plus_run_distributed())
+  def test_gradients_are_none(self, distribution):
+
+    if not context.executing_eagerly():
+      self.skipTest('None gradients are not supported in graph mode')
+
+    class DenseWithExtraWeight(keras.layers.Dense):
+
+      def build(self, input_shape):
+        # Gradients w.r.t. extra_weights are None
+        self.extra_weight_1 = self.add_weight('extra_weight_1', shape=(),
+                                              initializer='ones')
+        super(DenseWithExtraWeight, self).build(input_shape)
+        self.extra_weight_2 = self.add_weight('extra_weight_2', shape=(),
+                                              initializer='ones')
+
+    with distribution.scope():
+      model = keras.Sequential([DenseWithExtraWeight(4, input_shape=(4,))])
+    model.compile('adam', 'mse')
+
+    inputs = np.random.normal(size=(64, 4))
+    targets = np.random.normal(size=(64, 4))
+    old_kernel = model.get_weights()[1]
+    model.fit(inputs, targets)
+    new_kernel = model.get_weights()[1]
+    self.assertNotAllEqual(old_kernel, new_kernel)
+
 
 class TestDistributionStrategyWithDatasets(test.TestCase,
                                            parameterized.TestCase):

tensorflow/python/keras/engine/training_test.py (+24)

@@ -1364,6 +1364,30 @@ def apply_gradients(self, grads_and_vars, name=None):  # pylint: disable=useless
     model.fit(x, y)
     self.assertEqual(model.optimizer.aggregate_gradients_called, True)
 
+  @keras_parameterized.run_all_keras_modes(always_skip_v1=True)
+  def test_gradients_are_none(self):
+
+    class DenseWithExtraWeight(keras.layers.Dense):
+
+      def build(self, input_shape):
+        # Gradients w.r.t. extra_weights are None
+        self.extra_weight_1 = self.add_weight('extra_weight_1', shape=(),
+                                              initializer='ones')
+        super(DenseWithExtraWeight, self).build(input_shape)
+        self.extra_weight_2 = self.add_weight('extra_weight_2', shape=(),
+                                              initializer='ones')
+
+    model = keras.models.Sequential([DenseWithExtraWeight(4, input_shape=(4,))])
+    # Test clipping can handle None gradients
+    opt = keras.optimizer_v2.adam.Adam(clipnorm=1.0, clipvalue=1.0)
+    model.compile(opt, 'mse', run_eagerly=testing_utils.should_run_eagerly())
+    inputs = np.random.normal(size=(64, 4))
+    targets = np.random.normal(size=(64, 4))
+    old_kernel = model.get_weights()[1]
+    model.fit(inputs, targets)
+    new_kernel = model.get_weights()[1]
+    self.assertNotAllEqual(old_kernel, new_kernel)
+
 
 class TestExceptionsAndWarnings(keras_parameterized.TestCase):

tensorflow/python/keras/optimizer_v2/optimizer_v2.py (+21 -6)

@@ -344,15 +344,16 @@ def _clip_gradients(self, grads):
         raise ValueError("Gradient clipping in the optimizer "
                          "(by setting clipnorm or clipvalue) is currently "
                          "unsupported when using a distribution strategy.")
-      grads = [clip_ops.clip_by_norm(g, self.clipnorm) for g in grads]
+      grads = [None if g is None else clip_ops.clip_by_norm(g, self.clipnorm)
+               for g in grads]
     if self.clipvalue is not None:
       if distribute_ctx.has_strategy():
         raise ValueError("Gradient clipping in the optimizer "
                          "(by setting clipnorm or clipvalue) is currently "
                          "unsupported when using a distribution strategy.")
+      v = self.clipvalue
       grads = [
-          clip_ops.clip_by_value(g, -self.clipvalue, self.clipvalue)
-          for g in grads
+          None if g is None else clip_ops.clip_by_value(g, -v, v) for g in grads
       ]
     return grads
 
@@ -521,6 +522,7 @@ def _aggregate_gradients(self, grads_and_vars):
       A list of all-reduced gradients.
     """
     grads_and_vars = list(grads_and_vars)
+    filtered_grads_and_vars = _filter_grads(grads_and_vars)
     def all_reduce_fn(distribution, grads_and_vars):
       return distribution.extended.batch_reduce_to(
           ds_reduce_util.ReduceOp.SUM, grads_and_vars)
@@ -529,9 +531,22 @@ def all_reduce_fn(distribution, grads_and_vars):
     # replica context.
     # TODO(b/150507409): Do not switch to a cross-replica context once the bug
     # is fixed.
-    if grads_and_vars:
-      return distribute_ctx.get_replica_context().merge_call(
-          all_reduce_fn, args=(grads_and_vars,))
+    if filtered_grads_and_vars:
+      reduced = distribute_ctx.get_replica_context().merge_call(
+          all_reduce_fn, args=(filtered_grads_and_vars,))
+    else:
+      reduced = []
+    # Copy 'reduced' but add None gradients back in
+    reduced_with_nones = []
+    reduced_pos = 0
+    for g, _ in grads_and_vars:
+      if g is None:
+        reduced_with_nones.append(None)
+      else:
+        reduced_with_nones.append(reduced[reduced_pos])
+        reduced_pos += 1
+    assert reduced_pos == len(reduced), "Failed to add all gradients"
+    return reduced_with_nones
 
   def _distributed_apply(self, distribution, grads_and_vars, name, apply_state):
     """`apply_gradients` using a `DistributionStrategy`."""
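For reference, a standalone sketch of the filter-and-reinsert pattern that _aggregate_gradients now uses, with a plain callable standing in for distribution.extended.batch_reduce_to (aggregate_with_nones is an illustrative name, not TensorFlow API):

def aggregate_with_nones(grads_and_vars, reduce_fn):
  """Reduce the non-None gradients, then restore None placeholders in place."""
  filtered = [(g, v) for g, v in grads_and_vars if g is not None]
  reduced = reduce_fn([g for g, _ in filtered]) if filtered else []
  reduced_with_nones = []
  pos = 0
  for g, _ in grads_and_vars:
    if g is None:
      reduced_with_nones.append(None)
    else:
      reduced_with_nones.append(reduced[pos])
      pos += 1
  assert pos == len(reduced)
  return reduced_with_nones


grads_and_vars = [(None, 'extra_weight_1'), (3.0, 'kernel'), (None, 'extra_weight_2')]
print(aggregate_with_nones(grads_and_vars, lambda grads: [g * 2.0 for g in grads]))
# [None, 6.0, None]

Keeping the None placeholders preserves the one-to-one alignment between the returned gradients and their variables while keeping None entries out of the cross-replica reduce.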