[Feature]: Support empty tensor in MMSyncBN (open-mmlab#1205)
* [Feature]: support empty tensor in MMSyncBN

* refine code

* resolve comments

* clean unnecessary comments

* fix inaccurate statistics when empty tensor

* resolve comments and add docstrings

* update unit tests

* rephrase, ready for merge
ZwwWayne authored Sep 23, 2021
1 parent b6eb382 commit 4e101e0
Showing 3 changed files with 253 additions and 35 deletions.
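
Before the file-by-file diff, a minimal usage sketch of the feature being added (illustrative only, not part of the commit). It assumes mmcv is built with its CUDA ops and that torch.distributed is already initialized, since the layer synchronizes statistics over dist.group.WORLD when no group is passed.

# Hypothetical usage sketch; not part of this commit.
import torch
from mmcv.ops import SyncBatchNorm

bn = SyncBatchNorm(64, stats_mode='N').cuda()  # new argument; 'default' keeps the old behaviour
bn.train()

x = torch.rand(0, 64, 32, 32).cuda()  # an empty batch is now supported
y = bn(x)                             # empty output of the same shape; running stats stay untouched

Because the layer is registered under the name 'MMSyncBN' (see the decorator in the diff below), the same module can also be built from a config dict such as dict(type='MMSyncBN', stats_mode='N').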
153 changes: 118 additions & 35 deletions mmcv/ops/sync_bn.py
@@ -20,7 +20,7 @@ class SyncBatchNormFunction(Function):

@staticmethod
def symbolic(g, input, running_mean, running_var, weight, bias, momentum,
eps, group, group_size):
eps, group, group_size, stats_mode):
return g.op(
'mmcv::MMCVSyncBatchNorm',
input,
@@ -31,41 +31,82 @@ def symbolic(g, input, running_mean, running_var, weight, bias, momentum,
momentum_f=momentum,
eps_f=eps,
group_i=group,
group_size_i=group_size)
group_size_i=group_size,
stats_mode=stats_mode)

@staticmethod
def forward(self, input, running_mean, running_var, weight, bias, momentum,
eps, group, group_size):
eps, group, group_size, stats_mode):
self.momentum = momentum
self.eps = eps
self.group = group
self.group_size = group_size
self.stats_mode = stats_mode

assert isinstance(
input, (torch.HalfTensor, torch.FloatTensor,
torch.cuda.HalfTensor, torch.cuda.FloatTensor)), \
f'only support Half or Float Tensor, but {input.type()}'
output = torch.empty_like(input)
input3d = input.view(input.size(0), input.size(1), -1)
output = torch.zeros_like(input)
input3d = input.flatten(start_dim=2)
output3d = output.view_as(input3d)
num_channels = input3d.size(1)

mean = torch.empty(
input3d.size(1), dtype=torch.float, device=input3d.device)
var = torch.empty(
input3d.size(1), dtype=torch.float, device=input3d.device)
norm = torch.empty_like(
# ensure mean/var/norm/std are initialized as zeros
# ``torch.empty()`` does not guarantee that
mean = torch.zeros(
num_channels, dtype=torch.float, device=input3d.device)
var = torch.zeros(
num_channels, dtype=torch.float, device=input3d.device)
norm = torch.zeros_like(
input3d, dtype=torch.float, device=input3d.device)
std = torch.empty(
input3d.size(1), dtype=torch.float, device=input3d.device)
std = torch.zeros(
num_channels, dtype=torch.float, device=input3d.device)

ext_module.sync_bn_forward_mean(input3d, mean)
batch_size = input3d.size(0)
if batch_size > 0:
ext_module.sync_bn_forward_mean(input3d, mean)
batch_flag = torch.ones([1], device=mean.device, dtype=mean.dtype)
else:
# skip updating mean and leave it as zeros when the input is empty
batch_flag = torch.zeros([1], device=mean.device, dtype=mean.dtype)

# synchronize mean and the batch flag
vec = torch.cat([mean, batch_flag])
if self.stats_mode == 'N':
vec *= batch_size
if self.group_size > 1:
dist.all_reduce(mean, group=self.group)
mean /= self.group_size
ext_module.sync_bn_forward_var(input3d, mean, var)
dist.all_reduce(vec, group=self.group)
total_batch = vec[-1].detach()
mean = vec[:num_channels]

if self.stats_mode == 'default':
mean = mean / self.group_size
elif self.stats_mode == 'N':
mean = mean / total_batch.clamp(min=1)
else:
raise NotImplementedError

# leave var as zeros when the input is empty
if batch_size > 0:
ext_module.sync_bn_forward_var(input3d, mean, var)

if self.stats_mode == 'N':
var *= batch_size
if self.group_size > 1:
dist.all_reduce(var, group=self.group)

if self.stats_mode == 'default':
var /= self.group_size
elif self.stats_mode == 'N':
var /= total_batch.clamp(min=1)
else:
raise NotImplementedError

# if the total batch size over all the ranks is zero,
# we should not update the statistics in the current batch
update_flag = total_batch.clamp(max=1)
momentum = update_flag * self.momentum
ext_module.sync_bn_forward_output(
input3d,
mean,
@@ -78,7 +119,7 @@ def forward(self, input, running_mean, running_var, weight, bias, momentum,
std,
output3d,
eps=self.eps,
momentum=self.momentum,
momentum=momentum,
group_size=self.group_size)
self.save_for_backward(norm, std, weight)
return output
@@ -87,36 +128,76 @@ def forward(self, input, running_mean, running_var, weight, bias, momentum,
@once_differentiable
def backward(self, grad_output):
norm, std, weight = self.saved_tensors
grad_weight = torch.empty_like(weight)
grad_bias = torch.empty_like(weight)
grad_input = torch.empty_like(grad_output)
grad_output3d = grad_output.view(
grad_output.size(0), grad_output.size(1), -1)
grad_weight = torch.zeros_like(weight)
grad_bias = torch.zeros_like(weight)
grad_input = torch.zeros_like(grad_output)
grad_output3d = grad_output.flatten(start_dim=2)
grad_input3d = grad_input.view_as(grad_output3d)
ext_module.sync_bn_backward_param(grad_output3d, norm, grad_weight,
grad_bias)

batch_size = grad_input3d.size(0)
if batch_size > 0:
ext_module.sync_bn_backward_param(grad_output3d, norm, grad_weight,
grad_bias)

# all reduce
if self.group_size > 1:
dist.all_reduce(grad_weight, group=self.group)
dist.all_reduce(grad_bias, group=self.group)
grad_weight /= self.group_size
grad_bias /= self.group_size
ext_module.sync_bn_backward_data(grad_output3d, weight, grad_weight,
grad_bias, norm, std, grad_input3d)

if batch_size > 0:
ext_module.sync_bn_backward_data(grad_output3d, weight,
grad_weight, grad_bias, norm, std,
grad_input3d)

return grad_input, None, None, grad_weight, grad_bias, \
None, None, None, None
None, None, None, None, None


@NORM_LAYERS.register_module(name='MMSyncBN')
class SyncBatchNorm(Module):
"""Synchronized Batch Normalization.
Args:
num_features (int): number of features/channels in the input tensor.
eps (float, optional): a value added to the denominator for numerical
stability. Defaults to 1e-5.
momentum (float, optional): the value used for the running_mean and
running_var computation. Defaults to 0.1.
affine (bool, optional): whether to use learnable affine parameters.
Defaults to True.
track_running_stats (bool, optional): whether to track the running
mean and variance during training. When set to False, this
module does not track such statistics, and initializes statistics
buffers ``running_mean`` and ``running_var`` as ``None``. When
these buffers are ``None``, this module always uses batch
statistics in both training and eval modes. Defaults to True.
group (int, optional): synchronization of stats happens within
each process group individually. By default it is synchronized
across the whole world. Defaults to None.
stats_mode (str, optional): The statistical mode. Available options
include ``'default'`` and ``'N'``. Defaults to 'default'.
When ``stats_mode=='default'``, it computes the overall statistics
using those from each worker with equal weight, i.e., the
statistics are synchronized and simply divided by the group size.
This mode will produce inaccurate statistics when empty tensors occur.
When ``stats_mode=='N'``, it computes the overall statistics using
the total number of batches in each worker, ignoring the group size,
i.e., the statistics are synchronized and then divided by the total
batch size ``N``. This mode is beneficial when empty tensors occur
during training, as it averages the statistics by the real number
of samples.
"""

def __init__(self,
num_features,
eps=1e-5,
momentum=0.1,
affine=True,
track_running_stats=True,
group=None):
group=None,
stats_mode='default'):
super(SyncBatchNorm, self).__init__()
self.num_features = num_features
self.eps = eps
@@ -126,6 +207,9 @@ def __init__(self,
group = dist.group.WORLD if group is None else group
self.group = group
self.group_size = dist.get_world_size(group)
assert stats_mode in ['default', 'N'], \
f'"stats_mode" only accepts "default" and "N", got "{stats_mode}"'
self.stats_mode = stats_mode
if self.affine:
self.weight = Parameter(torch.Tensor(num_features))
self.bias = Parameter(torch.Tensor(num_features))
@@ -174,12 +258,10 @@ def forward(self, input):
exponential_average_factor = self.momentum

if self.training or not self.track_running_stats:
return SyncBatchNormFunction.apply(input, self.running_mean,
self.running_var, self.weight,
self.bias,
exponential_average_factor,
self.eps, self.group,
self.group_size)
return SyncBatchNormFunction.apply(
input, self.running_mean, self.running_var, self.weight,
self.bias, exponential_average_factor, self.eps, self.group,
self.group_size, self.stats_mode)
else:
return F.batch_norm(input, self.running_mean, self.running_var,
self.weight, self.bias, False,
@@ -192,5 +274,6 @@ def __repr__(self):
s += f'momentum={self.momentum}, '
s += f'affine={self.affine}, '
s += f'track_running_stats={self.track_running_stats}, '
s += f'group_size={self.group_size})'
s += f'group_size={self.group_size}, '
s += f'stats_mode={self.stats_mode})'
return s
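
The heart of the change above is the reduction in forward(): each rank concatenates its per-channel mean with a one-element batch flag, all-reduces that vector once, and divides the sum either by the group size ('default') or by the total number of samples across ranks ('N'); when every rank has an empty batch, total_batch is zero, the effective momentum becomes zero, and the running statistics are left unchanged. The sketch below re-implements that arithmetic on a single process (a hypothetical helper, with dist.all_reduce replaced by an explicit sum over ranks) to show why 'N' is robust to empty tensors while 'default' is not.

import torch

def reduce_mean(local_means, local_batch_sizes, stats_mode='default'):
    # Illustrative stand-in for the all_reduce of ``vec`` in
    # SyncBatchNormFunction.forward: sum the per-rank vectors explicitly.
    reduced = torch.zeros_like(local_means[0])
    total_batch = torch.zeros(1)
    for mean, n in zip(local_means, local_batch_sizes):
        flag = torch.ones(1) if n > 0 else torch.zeros(1)
        vec = torch.cat([mean, flag])
        if stats_mode == 'N':
            vec = vec * n  # weight each rank by its number of samples
        reduced += vec[:-1]
        total_batch += vec[-1]
    if stats_mode == 'default':
        return reduced / len(local_means)      # equal weight per rank
    return reduced / total_batch.clamp(min=1)  # weight by the real sample count

# Rank 0 sees 4 samples with channel mean 1.0; rank 1 sees an empty batch.
means = [torch.tensor([1.]), torch.tensor([0.])]
sizes = [4, 0]
print(reduce_mean(means, sizes, 'default'))  # tensor([0.5000]): dragged toward zero
print(reduce_mean(means, sizes, 'N'))        # tensor([1.]): unaffected by the empty rank

The backward pass follows the same pattern: for an empty batch the CUDA kernels are skipped, so the pre-zeroed grad_weight/grad_bias buffers are what each rank contributes to the all_reduce.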
1 change: 1 addition & 0 deletions tests/data/config/code.py
@@ -1,3 +1,4 @@
from mmcv import Config # isort:skip

cfg = Config.fromfile('./tests/data/config/a.py')
item5 = cfg.item1[0] + cfg.item2.a
134 changes: 134 additions & 0 deletions tests/test_ops/test_syncbn.py
@@ -2,6 +2,7 @@
import platform

import numpy as np
import pytest
import torch
import torch.distributed as dist
import torch.nn as nn
@@ -141,6 +142,121 @@ def _test_syncbn_train(self, size=1, half=False):
assert np.allclose(x_grad.data.cpu().numpy(),
sx_grad.data.cpu().numpy(), 1e-2)

def _test_syncbn_empty_train(self, size=1, half=False):

if 'SLURM_NTASKS' not in os.environ or int(
os.environ['SLURM_NTASKS']) != 4:
print('must be run by slurm with 4 processes!\n'
'srun -p test --gres=gpu:4 -n4')
return
else:
print('Running syncbn test')
from mmcv.ops import SyncBatchNorm

assert size in (1, 2, 4)
if not dist.is_initialized():
self.dist_init()
rank = dist.get_rank()

torch.manual_seed(9)
torch.cuda.manual_seed(9)

self.x = torch.rand(0, 3, 2, 3).cuda()
self.y_bp = torch.rand(0, 3, 2, 3).cuda()

if half:
self.x = self.x.half()
self.y_bp = self.y_bp.half()
dist.broadcast(self.x, src=0)
dist.broadcast(self.y_bp, src=0)

torch.cuda.synchronize()
if size == 1:
groups = [None, None, None, None]
groups[0] = dist.new_group([0])
groups[1] = dist.new_group([1])
groups[2] = dist.new_group([2])
groups[3] = dist.new_group([3])
group = groups[rank]
elif size == 2:
groups = [None, None, None, None]
groups[0] = groups[1] = dist.new_group([0, 1])
groups[2] = groups[3] = dist.new_group([2, 3])
group = groups[rank]
elif size == 4:
group = dist.group.WORLD

syncbn = SyncBatchNorm(3, group=group, stats_mode='N').cuda()
syncbn.weight.data[0] = 0.2
syncbn.weight.data[1] = 0.5
syncbn.weight.data[2] = 0.7
syncbn.train()

bn = nn.BatchNorm2d(3).cuda()
bn.weight.data[0] = 0.2
bn.weight.data[1] = 0.5
bn.weight.data[2] = 0.7
bn.train()

sx = self.x[rank * 4:rank * 4 + 4]
sx.requires_grad_()
sy = syncbn(sx)
sy.backward(self.y_bp[rank * 4:rank * 4 + 4])
smean = syncbn.running_mean
svar = syncbn.running_var
sx_grad = sx.grad
sw_grad = syncbn.weight.grad
sb_grad = syncbn.bias.grad

if size == 1:
x = self.x[rank * 4:rank * 4 + 4]
y_bp = self.y_bp[rank * 4:rank * 4 + 4]
elif size == 2:
x = self.x[rank // 2 * 8:rank // 2 * 8 + 8]
y_bp = self.y_bp[rank // 2 * 8:rank // 2 * 8 + 8]
elif size == 4:
x = self.x
y_bp = self.y_bp
x.requires_grad_()
y = bn(x)
y.backward(y_bp)

if size == 2:
y = y[rank % 2 * 4:rank % 2 * 4 + 4]
elif size == 4:
y = y[rank * 4:rank * 4 + 4]

mean = bn.running_mean
var = bn.running_var
if size == 1:
x_grad = x.grad
w_grad = bn.weight.grad
b_grad = bn.bias.grad
elif size == 2:
x_grad = x.grad[rank % 2 * 4:rank % 2 * 4 + 4]
w_grad = bn.weight.grad / 2
b_grad = bn.bias.grad / 2
elif size == 4:
x_grad = x.grad[rank * 4:rank * 4 + 4]
w_grad = bn.weight.grad / 4
b_grad = bn.bias.grad / 4

assert np.allclose(mean.data.cpu().numpy(),
smean.data.cpu().numpy(), 1e-3)
assert np.allclose(var.data.cpu().numpy(),
svar.data.cpu().numpy(), 1e-3)
assert np.allclose(y.data.cpu().numpy(), sy.data.cpu().numpy(), 1e-3)
assert np.allclose(w_grad.data.cpu().numpy(),
sw_grad.data.cpu().numpy(), 1e-3)
assert np.allclose(b_grad.data.cpu().numpy(),
sb_grad.data.cpu().numpy(), 1e-3)
assert np.allclose(x_grad.data.cpu().numpy(),
sx_grad.data.cpu().numpy(), 1e-2)

# 'stats_mode' only allows 'default' and 'N'
with pytest.raises(AssertionError):
SyncBatchNorm(3, group=group, stats_mode='X')

def test_syncbn_1(self):
self._test_syncbn_train(size=1)

@@ -158,3 +274,21 @@ def test_syncbn_2_half(self):

def test_syncbn_4_half(self):
self._test_syncbn_train(size=4, half=True)

def test_syncbn_empty_1(self):
self._test_syncbn_empty_train(size=1)

def test_syncbn_empty_2(self):
self._test_syncbn_empty_train(size=2)

def test_syncbn_empty_4(self):
self._test_syncbn_empty_train(size=4)

def test_syncbn_empty_1_half(self):
self._test_syncbn_empty_train(size=1, half=True)

def test_syncbn_empty_2_half(self):
self._test_syncbn_empty_train(size=2, half=True)

def test_syncbn_empty_4_half(self):
self._test_syncbn_empty_train(size=4, half=True)
