Forward compatibility fixes for distributed backend, thanks to @ssnl
definitelynotmcarilli committed Sep 18, 2018
1 parent: 0ec8add · commit: ed47ebf
Showing 1 changed file with 13 additions and 3 deletions.
apex/parallel/distributed.py: 13 additions & 3 deletions
@@ -129,7 +129,17 @@ class DistributedDataParallel(Module):
 
     def __init__(self, module, message_size=10000000, shared_param=False):
         super(DistributedDataParallel, self).__init__()
-        self.warn_on_half = True if dist._backend == dist.dist_backend.GLOO else False
+
+        # Backward/forward compatibility around
+        # https://github.com/pytorch/pytorch/commit/540ef9b1fc5506369a48491af8a285a686689b36
+        if(hasattr(dist, "get_backend")):
+            self._backend = dist.get_backend()
+            self.backend_enum_holder = dist.DistBackend
+        else:
+            self._backend = dist._backend
+            self.backend_enum_holder = dist.dist_backend
+
+        self.warn_on_half = True if self._backend == self.backend_enum_holder.GLOO else False
         self.shared_param = shared_param
         self.message_size = message_size
 
@@ -141,7 +151,7 @@ def __init__(self, module, message_size=10000000, shared_param=False):
         self.module = module
         self.param_list = list(self.module.parameters())
 
-        if dist._backend == dist.dist_backend.NCCL:
+        if self._backend == self.backend_enum_holder.NCCL:
             for param in self.param_list:
                 assert param.is_cuda, "NCCL backend only supports model parameters to be on GPU."
 
@@ -156,7 +166,7 @@ def __setstate__(self, state):
 
     def __getstate__(self):
         attrs = copy.copy(self.__dict__)
-        if dist._backend != dist.dist_backend.NCCL:
+        if self._backend != self.backend_enum_holder.NCCL:
            del attrs['self.reduction_stream']
         return attrs
 
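For reference, here is a minimal standalone sketch (not part of the commit) of the same version check the diff introduces, assuming only what the diff shows: newer torch.distributed exposes get_backend() and the DistBackend constants, while older releases expose the private _backend and dist_backend attributes. The helper name pick_backend is hypothetical.

    import torch.distributed as dist

    def pick_backend():
        """Return (backend, enum_holder) across old and new torch.distributed APIs."""
        if hasattr(dist, "get_backend"):
            # Newer PyTorch: public accessor (assumes the default process group
            # has already been initialized) plus the DistBackend constants.
            return dist.get_backend(), dist.DistBackend
        # Older PyTorch: fall back to the private module attributes.
        return dist._backend, dist.dist_backend

    # Usage mirroring the checks in the diff:
    # backend, enum_holder = pick_backend()
    # warn_on_half = (backend == enum_holder.GLOO)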
