diff --git a/apex/amp/_initialize.py b/apex/amp/_initialize.py index 3d31b6641..2193e3128 100644 --- a/apex/amp/_initialize.py +++ b/apex/amp/_initialize.py @@ -40,12 +40,12 @@ def applier(value, fn): return value elif isinstance(value, np.ndarray): return value + elif hasattr(value, "to"): # Allow handling of custom batch classes + return fn(value) elif isinstance(value, container_abcs.Mapping): return {applier(k, fn) : applier(v, fn) for k, v in value.items()} elif isinstance(value, container_abcs.Iterable): return type(value)(applier(v, fn) for v in value) - elif hasattr(value, "to"): # Allow handling of custom batch classes - return fn(value) else: # Do I want this to fire off even if someone chooses to pass something ordinary like # an int or float? May be more annoying than it's worth.