Syncing with master
definitelynotmcarilli committed Mar 4, 2019
2 parents b90b057 + 603e17a commit 6066ddd
Showing 7 changed files with 236 additions and 3 deletions.
5 changes: 5 additions & 0 deletions README.md
@@ -1,3 +1,8 @@
# PSA: Unified API for mixed precision tools coming soon!
(as introduced by https://info.nvidia.com/webinar-mixed-precision-with-pytorch-reg-page.html).

Branch `api_refactor` is tracking my progress. Update as of 2/28: PR-ed in https://github.com/NVIDIA/apex/pull/173. I'd like to clean up the documentation a bit more before final merge.

# Introduction

This repository holds NVIDIA-maintained utilities to streamline
2 changes: 1 addition & 1 deletion apex/optimizers/fp16_optimizer.py
@@ -7,7 +7,6 @@
lib = ctypes.cdll.LoadLibrary(None)
lib.THCudaHalfTensor_normall.argtypes=[ctypes.c_void_p, ctypes.c_void_p]
lib.THCudaHalfTensor_normall.restype = ctypes.c_float

def fused_norm(input):
    if input.type() == 'torch.cuda.HalfTensor':
        # 16384 is half 2 if you stare at it long enough
@@ -23,6 +22,7 @@ def fused_norm(input):
"LoadLibrary with None. Original exception message was ",
stashed_err)


class FP16_Optimizer(object):
"""
:class:`FP16_Optimizer` A cutdown version of apex.fp16_utils.FP16_Optimizer.
39 changes: 39 additions & 0 deletions apex/parallel/__init__.py
@@ -51,3 +51,42 @@ def convert_syncbn_model(module, process_group=None, channel_last=False):
    # TODO(jie) should I delete model explicitly?
    del module
    return mod

def create_syncbn_process_group(group_size):
    '''
    Creates process groups to be used for syncbn of a given ``group_size`` and returns
    the process group that the current GPU participates in.
    ``group_size`` must divide the total number of GPUs (world_size).
    A ``group_size`` of 0 is treated as equal to world_size; in this case ``None`` is returned.
    A ``group_size`` of 1 is equivalent to using non-synced batchnorm, but still carries the synchronization overhead.
    Args:
        group_size (int): number of GPUs to collaborate for sync bn
    Example::
        >>> # model is an instance of torch.nn.Module
        >>> import apex
        >>> group = apex.parallel.create_syncbn_process_group(group_size)
    '''

    if group_size == 0:
        return None

    world_size = torch.distributed.get_world_size()
    assert(world_size >= group_size)
    assert(world_size % group_size == 0)

    group = None
    for group_num in range(world_size // group_size):
        group_ids = range(group_num * group_size, (group_num + 1) * group_size)
        cur_group = torch.distributed.new_group(ranks=group_ids)
        if torch.distributed.get_rank() // group_size == group_num:
            group = cur_group
            # Cannot return early here; every process must go through the creation of all subgroups.

    assert(group is not None)
    return group
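A minimal usage sketch, for illustration only (it mirrors the test added below and assumes the script is launched with torch.distributed.launch and that the world size is divisible by the chosen group size):

import torch
import apex

# Initialize the default process group first (e.g. when launched via torch.distributed.launch).
torch.distributed.init_process_group(backend='nccl', init_method='env://')

# Build subgroups of 2 ranks each and pick up the one this rank belongs to.
group = apex.parallel.create_syncbn_process_group(2)

# Batchnorm statistics are now reduced only within this rank's group.
sbn = apex.parallel.SyncBatchNorm(10, process_group=group).cuda()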
2 changes: 1 addition & 1 deletion apex/parallel/optimized_sync_batchnorm.py
@@ -55,7 +55,7 @@ class SyncBatchNorm(_BatchNorm):
    >>> inp = torch.randn(10, 14, 14, 100).cuda()
    """

    def __init__(self, num_features, eps=1e-5, momentum=0.1, affine=True, track_running_stats=True, process_group=None, channel_last = False):
    def __init__(self, num_features, eps=1e-5, momentum=0.1, affine=True, track_running_stats=True, process_group=None, channel_last=False):
        super(SyncBatchNorm, self).__init__(num_features, eps=eps, momentum=momentum, affine=affine, track_running_stats=track_running_stats)
        self.process_group = process_group
        self.channel_last = channel_last
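For the channel-last path, a hypothetical sketch (assuming apex was built with --cuda_ext, so this optimized implementation is the one exported as apex.parallel.SyncBatchNorm, and that the distributed setup shown earlier is in place); the input shape follows the docstring example above:

import torch
import apex

# NHWC input: channels are the last dimension and must match num_features.
sbn = apex.parallel.SyncBatchNorm(100, channel_last=True).cuda()
inp = torch.randn(10, 14, 14, 100).cuda()
out = sbn(inp)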
4 changes: 3 additions & 1 deletion apex/parallel/sync_batchnorm.py
@@ -48,7 +48,9 @@ class SyncBatchNorm(_BatchNorm):

    warned = False

    def __init__(self, num_features, eps=1e-5, momentum=0.1, affine=True, track_running_stats=True, process_group=None):
    def __init__(self, num_features, eps=1e-5, momentum=0.1, affine=True, track_running_stats=True, process_group=None, channel_last=False):
        if channel_last == True:
            raise AttributeError("channel_last is not supported by the primitive SyncBatchNorm implementation. Try installing apex with `--cuda_ext` if channel_last is desired.")

        if not SyncBatchNorm.warned:
            print("Warning: using Python fallback for SyncBatchNorm, possibly because apex was installed without --cuda_ext. The exception raised when attempting to import the cuda backend was: ", self.syncbn_import_error)
2 changes: 2 additions & 0 deletions tests/distributed/synced_batchnorm/unit_test.sh
@@ -1,3 +1,5 @@
python single_gpu_unit_test.py
python -m torch.distributed.launch --nproc_per_node=2 two_gpu_unit_test.py
python -m torch.distributed.launch --nproc_per_node=2 two_gpu_unit_test.py --fp64
# Beware: you need a system with at least 4 GPUs to test group_size < world_size
python -m torch.distributed.launch --nproc_per_node=4 test_groups.py --group_size=2
185 changes: 185 additions & 0 deletions tests/synced_batchnorm/test_groups.py
@@ -0,0 +1,185 @@
import torch
import numpy as np
import apex
import syncbn
import os
import argparse
import torch.optim as optim

def compare(desc, inp1, inp2, error):
    a = inp1.clone().detach().cpu().numpy()
    b = inp2.clone().detach().cpu().numpy()
    close = np.allclose(a, b, error, error)
    if not close:
        print(desc, close)
        z = a - b
        index = (np.abs(z) >= error + error * np.abs(b)).nonzero()
        print("dif : ", z[index])
        print("inp1 : ", a[index])
        print("inp2 : ", b[index])
    return close

feature_size = 10
space_size = 40
batch_size = 32


from apex.parallel import DistributedDataParallel as DDP
parser = argparse.ArgumentParser()
parser.add_argument("--local_rank", default=0, type=int)
parser.add_argument("--fp16", action='store_true', default=False)
parser.add_argument("--fp64", action='store_true', default=False)
parser.add_argument("--group_size", default=0, type=int)
args = parser.parse_args()

try:
    args.world_size = int(os.environ['WORLD_SIZE'])
except:
    print("This is a multi-gpu test. To run it please use 'python -m torch.distributed.launch --nproc_per_node=<num gpus> test_groups.py <more options>'")
    exit(1)

torch.cuda.set_device(args.local_rank)
torch.distributed.init_process_group(backend='nccl', init_method='env://')

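# Each rank works on its own slice of the global batch, determined by its position within its syncbn group.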
start = (args.local_rank%args.group_size) * batch_size//args.group_size
finish = (args.local_rank%args.group_size + 1) * batch_size//args.group_size

error = 1e-5
dtype = np.float32
if args.fp16:
    error = 1e-3
    dtype = np.float16
elif args.fp64:
    error = 1e-8
    dtype = np.float64


np.random.seed(18 + args.local_rank//args.group_size)

inp = np.random.randn(batch_size, feature_size, space_size, space_size).astype(dtype)
grad = np.random.randn(batch_size, feature_size, space_size, space_size).astype(dtype)
weight = np.random.randn(feature_size).astype(dtype)
bias = np.random.randn(feature_size).astype(dtype)


type_tensor = torch.cuda.FloatTensor
if args.fp16:
    type_tensor = torch.cuda.HalfTensor
if args.fp64:
    type_tensor = torch.cuda.DoubleTensor

ref_tensor = torch.cuda.DoubleTensor

inp_t = type_tensor(inp)
weight_t = type_tensor(weight)
bias_t = type_tensor(bias)

inp_r = ref_tensor(inp.transpose(1, 0, 2, 3).reshape(feature_size, -1))
inp2_r = ref_tensor(inp)
weight_r = ref_tensor(weight).view(-1, 1, 1)
bias_r = ref_tensor(bias).view(-1, 1, 1)

grad_output_t = type_tensor(grad)

m = inp_r.mean(1)
b_v = inp_r.var(1, unbiased=False)
unb_v = inp_r.var(1, unbiased=True)

eps = 1e-5

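# Mean and biased variance from the fused syncbn CUDA kernel, computed over the full batch on this GPU.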
mean, var_biased = syncbn.welford_mean_var(inp_t)
inv_std = 1.0 / torch.sqrt(var_biased + eps)

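# Reference: a plain torch.nn.BatchNorm2d wrapped in DDP, run on the full batch on every rank.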
bn = torch.nn.BatchNorm2d(feature_size).cuda()
bn.momentum = 1.0
bn.weight.data = weight_t.clone()
bn.bias.data = bias_t.clone()
if args.fp16:
    bn.half()
if args.fp64:
    bn.double()
bn = DDP(bn)
inp_bn = inp_t.clone().requires_grad_()
grad_bn = grad_output_t.clone().detach()
out_bn = bn(inp_bn)
out_bn.backward(grad_bn)
# compensating the averaging over processes done by DDP
# in order to produce mathematically equivalent result
# https://github.com/NVIDIA/apex/issues/134#issuecomment-458307368
for param in bn.parameters():
    param.grad = param.grad / args.group_size
bn_opt = optim.SGD(bn.parameters(), lr=1.0)

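# System under test: apex SyncBatchNorm restricted to a process group of size group_size;
# each rank only sees its slice of the batch, so statistics must be reduced across the group.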
sbn = apex.parallel.SyncBatchNorm(feature_size, process_group=apex.parallel.create_syncbn_process_group(args.group_size)).cuda()
sbn.momentum = 1.0
sbn.weight.data = weight_t.clone()
sbn.bias.data = bias_t.clone()
if args.fp16:
    sbn.half()
if args.fp64:
    sbn.double()
sbn = DDP(sbn)
sbn_opt = optim.SGD(sbn.parameters(), lr=1.0)
inp_sbn = inp_t.clone().requires_grad_()
grad_sbn = grad_output_t.clone().detach()
out_sbn = sbn(inp_sbn[start:finish])
out_sbn.backward(grad_sbn[start:finish])

sbn_result = True
bn_result = True

if args.local_rank == 0:
    sbn_result = compare("comparing mean: ", mean, m, error) and sbn_result
    sbn_result = compare("comparing biased variance: ", var_biased, b_v, error) and sbn_result

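# Forward pass of the fused kernel versus a hand-computed double-precision reference.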
out = syncbn.batchnorm_forward(inp_t, mean, inv_std, weight_t, bias_t)
out_r = weight_r * (inp2_r - m.view(-1, 1, 1)) * torch.rsqrt(b_v.view(-1,1,1) + eps) + bias_r

if args.local_rank == 0:
    sbn_result = compare("comparing output: ", out, out_r, error) and sbn_result
    compare("comparing bn output: ", out_bn, out_r, error)

grad_output_t = type_tensor(grad)

grad_output_r = ref_tensor(grad.transpose(1, 0, 2, 3).reshape(feature_size, -1))
grad_output2_r = ref_tensor(grad)

grad_bias_r = grad_output_r.sum(1)
grad_weight_r = ((inp2_r - m.view(-1, 1, 1)) * torch.rsqrt(b_v.view(-1,1,1) + eps) * grad_output2_r).transpose(1,0).contiguous().view(feature_size, -1).sum(1)

mean_dy_r = grad_output_r.mean(1)
mean_dy_xmu_r = ((inp2_r - m.view(-1, 1, 1)) * grad_output2_r).transpose(1,0).contiguous().view(feature_size, -1).mean(1)

grad_input_r = (grad_output2_r - mean_dy_r.view(-1, 1, 1) - (inp2_r - m.view(-1, 1, 1)) / (b_v.view(-1,1,1) + eps) * mean_dy_xmu_r.view(-1, 1, 1) ) * torch.rsqrt(b_v.view(-1,1,1) + eps) * weight_r.view(-1,1,1)

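# Backward reductions (mean_dy, mean_dy_xmu, weight/bias grads) and the input gradient from the fused kernel.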
mean_dy, mean_dy_xmu, grad_weight, grad_bias = syncbn.reduce_bn(grad_output_t, inp_t, mean, inv_std, weight_t)
grad_input = syncbn.batchnorm_backward(grad_output_t, inp_t, mean, inv_std, weight_t, mean_dy, mean_dy_xmu)

if args.local_rank == 0:
    sbn_result = compare("comparing bias grad: ", grad_bias, grad_bias_r, error) and sbn_result
    sbn_result = compare("comparing weight grad: ", grad_weight, grad_weight_r, error) and sbn_result
    sbn_result = compare("comparing mean_dy grad: ", mean_dy, mean_dy_r, error) and sbn_result
    sbn_result = compare("comparing mean_dy_xmu grad: ", mean_dy_xmu, mean_dy_xmu_r, error) and sbn_result
    sbn_result = compare("comparing input grad: ", grad_input, grad_input_r, error) and sbn_result
    compare("comparing bn input grad: ", inp_bn.grad, grad_input_r, error)

if args.local_rank == 0:
    sbn_result = compare("comparing running_mean: ", bn.module.running_mean.data, sbn.module.running_mean.data, error) and sbn_result
    sbn_result = compare("comparing running_variance: ", bn.module.running_var.data, sbn.module.running_var.data, error) and sbn_result

# executed by every rank, not just rank 0
compare("comparing layers output: ", out_bn[start:finish], out_sbn, error) and sbn_result
compare("comparing layers grad_input: ", inp_bn.grad[start:finish], inp_sbn.grad[start:finish], error) and sbn_result

bn_opt.step()
sbn_opt.step()

if args.local_rank == 0:
    compare("comparing bn vs sbn bias: ", bn.module.bias, sbn.module.bias, error)
    compare("comparing bn vs sbn weight: ", bn.module.weight, sbn.module.weight, error)


if sbn_result:
    print("====SBN group test passed")
else:
    print("*SBN group test failed*")
