@@ -344,15 +344,16 @@ def _clip_gradients(self, grads):
         raise ValueError("Gradient clipping in the optimizer "
                          "(by setting clipnorm or clipvalue) is currently "
                          "unsupported when using a distribution strategy.")
-      grads = [clip_ops.clip_by_norm(g, self.clipnorm) for g in grads]
+      grads = [None if g is None else clip_ops.clip_by_norm(g, self.clipnorm)
+               for g in grads]
     if self.clipvalue is not None:
       if distribute_ctx.has_strategy():
         raise ValueError("Gradient clipping in the optimizer "
                          "(by setting clipnorm or clipvalue) is currently "
                          "unsupported when using a distribution strategy.")
+      v = self.clipvalue
       grads = [
-          clip_ops.clip_by_value(g, -self.clipvalue, self.clipvalue)
-          for g in grads
+          None if g is None else clip_ops.clip_by_value(g, -v, v) for g in grads
       ]
     return grads
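Note (not part of the change itself): the new `None if g is None else ...` guards let gradients that come back as None, e.g. for variables unconnected to the loss, pass through clipping untouched instead of failing inside `clip_by_norm`/`clip_by_value`. A minimal standalone sketch of that behavior, using the public `tf.clip_by_norm` (the exported alias of `clip_ops.clip_by_norm`) and made-up values:

import tensorflow as tf

clipnorm = 1.0  # stands in for self.clipnorm
grads = [tf.constant([3.0, 4.0]), None]  # second variable received no gradient
clipped = [None if g is None else tf.clip_by_norm(g, clipnorm) for g in grads]
# clipped[0] is rescaled to norm 1.0; clipped[1] stays None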
@@ -521,6 +522,7 @@ def _aggregate_gradients(self, grads_and_vars):
       A list of all-reduced gradients.
     """
     grads_and_vars = list(grads_and_vars)
+    filtered_grads_and_vars = _filter_grads(grads_and_vars)
     def all_reduce_fn(distribution, grads_and_vars):
       return distribution.extended.batch_reduce_to(
           ds_reduce_util.ReduceOp.SUM, grads_and_vars)
@@ -529,9 +531,22 @@ def all_reduce_fn(distribution, grads_and_vars):
     # replica context.
     # TODO(b/150507409): Do not switch to a cross-replica context once the bug
     # is fixed.
-    if grads_and_vars:
-      return distribute_ctx.get_replica_context().merge_call(
-          all_reduce_fn, args=(grads_and_vars,))
+    if filtered_grads_and_vars:
+      reduced = distribute_ctx.get_replica_context().merge_call(
+          all_reduce_fn, args=(filtered_grads_and_vars,))
+    else:
+      reduced = []
+    # Copy 'reduced' but add None gradients back in
+    reduced_with_nones = []
+    reduced_pos = 0
+    for g, _ in grads_and_vars:
+      if g is None:
+        reduced_with_nones.append(None)
+      else:
+        reduced_with_nones.append(reduced[reduced_pos])
+        reduced_pos += 1
+    assert reduced_pos == len(reduced), "Failed to add all gradients"
+    return reduced_with_nones

   def _distributed_apply(self, distribution, grads_and_vars, name, apply_state):
     """`apply_gradients` using a `DistributionStrategy`."""