Optimize the sync batchnorm by batching the communication (NVIDIA#980)
In this PR, we optimize the performance of SyncBatchNorm and also fix a potential issue in the welford_parallel kernel implementation.

For the performance improvement, we batch the mean/var/count all_gather communication together and issue it once in the forward path.
We also batch the two all_reduce calls in the backward path.
Finally, we add a contiguous() call on the inputs of the welford_parallel kernel.
If there is a standard perf benchmark, I would be happy to run it.
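
A minimal sketch of the forward-path batching pattern, pulled out into a standalone helper (the function name and argument list are illustrative, not apex API; it assumes torch.distributed is already initialized and that mean/var_biased are 1-D per-channel tensors):

import torch
import torch.distributed as dist

def gather_stats_batched(mean, var_biased, count, process_group, world_size):
    # Pack per-channel mean, biased variance, and the element count into one
    # flat buffer so a single all_gather replaces three separate ones.
    num_channels = mean.numel()
    count_t = torch.empty(1, dtype=mean.dtype, device=mean.device).fill_(count)
    combined = torch.cat([mean.view(-1), var_biased.view(-1), count_t], dim=0)
    gathered = [torch.empty_like(combined) for _ in range(world_size)]
    dist.all_gather(gathered, combined, process_group)
    # Recover the per-rank statistics: shape [world_size, 2*num_channels + 1].
    stacked = torch.stack(gathered, dim=0)
    mean_all, var_all, count_all = torch.split(stacked, num_channels, dim=1)
    return mean_all, var_all, count_all.view(-1)
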
lly-zero-one authored Oct 19, 2020
1 parent a109f85 commit 8a1ed9e
Showing 2 changed files with 25 additions and 21 deletions.
36 changes: 18 additions & 18 deletions apex/parallel/optimized_sync_batchnorm_kernel.py
@@ -21,33 +21,31 @@ def forward(ctx, input, z, weight, bias, running_mean, running_variance, eps, tr
if channel_last:
count = int(input.numel()/input.size(-1))
mean, var_biased = syncbn.welford_mean_var_c_last(input)
num_channels = input.size(-1)
else:
count = int(input.numel()/input.size(1))
mean, var_biased = syncbn.welford_mean_var(input)
num_channels = input.size(1)

if torch.distributed.is_initialized():
if not process_group:
process_group = torch.distributed.group.WORLD
device = mean.device
world_size = torch.distributed.get_world_size(process_group)
mean_all = torch.empty(world_size, mean.size(0), dtype=mean.dtype, device=device)
var_all = torch.empty(world_size, var_biased.size(0), dtype=var_biased.dtype, device=device)
count_all = torch.cuda.IntTensor(world_size, device=device)
mean_l = [mean_all.narrow(0, i, 1).view(-1) for i in range(world_size)]
var_l = [var_all.narrow(0, i, 1).view(-1) for i in range(world_size)]
count_l = [count_all.narrow(0, i, 1) for i in range(world_size)]
torch.distributed.all_gather(mean_l, mean.view(-1), process_group)
torch.distributed.all_gather(var_l, var_biased.view(-1), process_group)
torch.distributed.all_gather(
count_l,
torch.cuda.IntTensor([count], device=device),
process_group)
mean, var, inv_std = syncbn.welford_parallel(mean_all, var_all, count_all, eps)

count_t = torch.empty(1, dtype=mean.dtype, device=mean.device).fill_(count)
combined = torch.cat([mean.view(-1), var_biased.view(-1), count_t], dim=0)
combined_list = [torch.empty_like(combined) for k in range(world_size)]
torch.distributed.all_gather(combined_list, combined, process_group)
combined = torch.stack(combined_list, dim=0)
mean_all, invstd_all, count_all = torch.split(combined, num_channels, dim=1)
count_all = count_all.view(-1)
mean, var, inv_std = syncbn.welford_parallel(mean_all, invstd_all, count_all.to(torch.int32), eps)
else:
device = mean.device
count_all = torch.cuda.IntTensor([count], device=device)
inv_std = 1.0 / torch.sqrt(var_biased + eps)
var = var_biased * (count) / (count-1)

if count == 1 and world_size < 2:
raise ValueError('Expected more than 1 value per channel when training, got input size{}'.format(input.size()))
@@ -60,7 +58,7 @@ def forward(ctx, input, z, weight, bias, running_mean, running_variance, eps, tr
mean = running_mean.data
inv_std = 1.0 / torch.sqrt(running_variance.data + eps)

ctx.save_for_backward(input, weight, mean, inv_std, z, bias, count_all)
ctx.save_for_backward(input, weight, mean, inv_std, z, bias, count_all.to(torch.int32))
ctx.process_group = process_group
ctx.channel_last = channel_last
ctx.world_size = world_size
@@ -101,10 +99,12 @@ def backward(ctx, grad_output):
if ctx.needs_input_grad[0]:

if torch.distributed.is_initialized():
num_channels = sum_dy.shape[0]
combined = torch.cat([sum_dy, sum_dy_xmu], dim=0)
torch.distributed.all_reduce(
sum_dy, ReduceOp.SUM, process_group)
torch.distributed.all_reduce(
sum_dy_xmu, ReduceOp.SUM, process_group)
combined, torch.distributed.ReduceOp.SUM, process_group, async_op=False)
sum_dy, sum_dy_xmu = torch.split(combined, num_channels)

if channel_last:
grad_input = syncbn.batchnorm_backward_c_last(grad_output, saved_input, mean, inv_std, weight, sum_dy, sum_dy_xmu, count)
else:
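
The backward-path change above follows the same batching pattern: sum_dy and sum_dy_xmu are concatenated, reduced with a single all_reduce, and split back. A minimal standalone sketch (the helper name and signature are illustrative, not apex API; torch.distributed is assumed to be initialized):

import torch
import torch.distributed as dist

def reduce_grads_batched(sum_dy, sum_dy_xmu, process_group):
    num_channels = sum_dy.shape[0]
    # One all_reduce over the concatenated buffers replaces two separate calls.
    combined = torch.cat([sum_dy, sum_dy_xmu], dim=0)
    dist.all_reduce(combined, dist.ReduceOp.SUM, process_group, async_op=False)
    # Split back into the two per-channel gradient sums.
    sum_dy, sum_dy_xmu = torch.split(combined, num_channels)
    return sum_dy, sum_dy_xmu
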
10 changes: 7 additions & 3 deletions csrc/welford.cu
@@ -1155,6 +1155,10 @@ std::vector<at::Tensor> welford_parallel_CUDA(const at::Tensor mean_feature_node
at::Tensor inv_std = at::empty_like(out_var);
at::Tensor out_mean = at::empty_like(out_var);

at::Tensor mean_feature_nodes_ = mean_feature_nodes.contiguous();
at::Tensor var_biased_ = var_biased.contiguous();
at::Tensor numel_ = numel.contiguous();

// TODO(jie): tile this for memory coalescing!
const int block = std::min(h_last_pow2(feature_size), MAX_BLOCK_SIZE);
const int grid = std::max<int>(1, feature_size / block);
@@ -1165,9 +1169,9 @@
using namespace at;
DISPATCH_FLOAT_AND_HALF(mean_feature_nodes.scalar_type(), 0, "welford_parallel_kernel",
welford_kernel_parallel<scalar_t_0><<<grid, block, 0, stream>>>(
mean_feature_nodes.DATA_PTR<scalar_t_0>(),
var_biased.DATA_PTR<scalar_t_0>(),
numel.DATA_PTR<int>(),
mean_feature_nodes_.DATA_PTR<scalar_t_0>(),
var_biased_.DATA_PTR<scalar_t_0>(),
numel_.DATA_PTR<int>(),
out_mean.DATA_PTR<scalar_t_0>(),
out_var.DATA_PTR<scalar_t_0>(),
inv_std.DATA_PTR<scalar_t_0>(),
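
The welford.cu change above has the host wrapper call .contiguous() before taking raw data pointers, since the kernel indexes the buffers assuming a dense memory layout. A small Python-side illustration of why that matters (the shapes are arbitrary):

import torch

x = torch.randn(4, 8).t()     # a transposed view: same storage, non-contiguous strides
assert not x.is_contiguous()  # walking x.data_ptr() linearly would misread elements
y = x.contiguous()            # dense copy; returns the same tensor if already contiguous
assert y.is_contiguous()      # now safe to hand y.data_ptr() to a raw kernel
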
