From ed47ebff3cda136972c7d1c63b807dda0938e2b0 Mon Sep 17 00:00:00 2001
From: Michael Carilli
Date: Tue, 18 Sep 2018 13:39:21 -0700
Subject: [PATCH] Forward compatibility fixes for distributed backend, thanks
 to @Ssnl

---
 apex/parallel/distributed.py | 16 +++++++++++++---
 1 file changed, 13 insertions(+), 3 deletions(-)

diff --git a/apex/parallel/distributed.py b/apex/parallel/distributed.py
index 621f3fbf4..2848c1f0d 100644
--- a/apex/parallel/distributed.py
+++ b/apex/parallel/distributed.py
@@ -129,7 +129,17 @@ class DistributedDataParallel(Module):
     def __init__(self, module, message_size=10000000, shared_param=False):
         super(DistributedDataParallel, self).__init__()
 
-        self.warn_on_half = True if dist._backend == dist.dist_backend.GLOO else False
+
+        # Backward/forward compatibility around
+        # https://github.com/pytorch/pytorch/commit/540ef9b1fc5506369a48491af8a285a686689b36
+        if(hasattr(dist, "get_backend")):
+            self._backend = dist.get_backend()
+            self.backend_enum_holder = dist.DistBackend
+        else:
+            self._backend = dist._backend
+            self.backend_enum_holder = dist.dist_backend
+
+        self.warn_on_half = True if self._backend == self.backend_enum_holder.GLOO else False
 
         self.shared_param = shared_param
         self.message_size = message_size
@@ -141,7 +151,7 @@ def __init__(self, module, message_size=10000000, shared_param=False):
         self.module = module
         self.param_list = list(self.module.parameters())
 
-        if dist._backend == dist.dist_backend.NCCL:
+        if self._backend == self.backend_enum_holder.NCCL:
             for param in self.param_list:
                 assert param.is_cuda, "NCCL backend only supports model parameters to be on GPU."
 
@@ -156,7 +166,7 @@ def __setstate__(self, state):
 
     def __getstate__(self):
         attrs = copy.copy(self.__dict__)
-        if dist._backend != dist.dist_backend.NCCL:
+        if self._backend != self.backend_enum_holder.NCCL:
            del attrs['self.reduction_stream']
         return attrs
 
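
Note: the first hunk folds the backend lookup into DistributedDataParallel.__init__. As a standalone illustration, the same version check can be written as a small helper, shown below. This is a minimal sketch, assuming the 2018-era torch.distributed API on both sides of the referenced pytorch/pytorch commit (dist.get_backend() / dist.DistBackend on newer builds, dist._backend / dist.dist_backend on older ones); the helper name backend_compat is illustrative and not part of apex.

import torch.distributed as dist

def backend_compat():
    # Newer PyTorch (after pytorch/pytorch@540ef9b1) exposes the public
    # dist.get_backend() accessor and the DistBackend enum holder.
    if hasattr(dist, "get_backend"):
        return dist.get_backend(), dist.DistBackend
    # Older releases only provide the private dist._backend value and the
    # dist.dist_backend enum holder.
    return dist._backend, dist.dist_backend

# Usage (requires an initialized process group, e.g. via dist.init_process_group):
#   backend, enum_holder = backend_compat()
#   warn_on_half = (backend == enum_holder.GLOO)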