Skip to content

Commit

Permalink
Fix missing Keras variables shim (horovod#3907)
Browse files Browse the repository at this point in the history
Signed-off-by: Nicolas Castet <[email protected]>
  • Loading branch information
nvcastet authored May 4, 2023
1 parent 39c8f7c commit 67ea042
Showing 1 changed file with 15 additions and 6 deletions.
21 changes: 15 additions & 6 deletions horovod/_keras/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -79,13 +79,15 @@ def __init__(self, **kwargs):
scale_local_gradients=scale_local_gradients
)

def variables(self):
if _IS_TF2:
return super(self.__class__, self).variables()
return self.get_weights()

if version.parse(tf.__version__) >= version.parse('2.12.0'):
variables = property(variables)
@property
def variables(self):
return CallableList(super(self.__class__, self).variables())
else:
def variables(self):
if _IS_TF2:
return super(self.__class__, self).variables()
return self.get_weights()

def register_local_var(self, var):
"""Registers a source/variable as worker local. Horovod will not perform any global
Expand Down Expand Up @@ -253,6 +255,13 @@ def apply_gradients(self, *args, **kwargs):
return cls.from_config(optimizer.get_config())


class CallableList(list):
"""Temporary shim to support both `opt.variables()` and `opt.variables`."""

def __call__(self):
return self


def _eval(backend, op_or_result):
if hvd._executing_eagerly():
return op_or_result
Expand Down

0 comments on commit 67ea042

Please sign in to comment.