Decouple optimizer state and grad dtypes in distributed Adam optimizer (NVIDIA#1575)

* Decouple distopt dtypes for grads and optim state

* Automatically detect grad dtype for Transformer layer wgrad fusion

* Review suggestions from @crcrpar
timmoon10 authored Feb 25, 2023
1 parent 0c8400a commit 03c9d80
Showing 5 changed files with 97 additions and 106 deletions.
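In practice, this change lets the optimizer-state dtype (`dtype`), the gradient reduce-scatter dtype (`grad_sync_dtype`), and the parameter broadcast dtype (`param_sync_dtype`) be chosen independently instead of forcing `grad_sync_dtype` to equal `dtype`. Below is a minimal usage sketch of the newly allowed combination (FP32 Adam state with BF16 gradient synchronization, the case exercised by the new `test_matches_pytorch_bf16_grads` further down). The keyword names follow the diffs in this commit; the setup assumes `torch.distributed` has already been initialized with an NCCL process group and one GPU per rank.

```python
import torch
from apex.contrib.optimizers.distributed_fused_adam import DistributedFusedAdam

# Sketch only: FP32 Adam state, BF16 gradient reduce-scatter.
# Before this commit, grad_sync_dtype was required to equal dtype.
model = torch.nn.Linear(1024, 1024, device='cuda')
optimizer = DistributedFusedAdam(
    model.parameters(),
    lr=1e-3,
    dtype=torch.float32,             # dtype of the Adam state and local step math
    grad_sync_dtype=torch.bfloat16,  # dtype used to bucket and reduce gradients
)

loss = model(torch.randn(16, 1024, device='cuda')).sum()
loss.backward()
optimizer.step()
optimizer.zero_grad()
```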
84 changes: 43 additions & 41 deletions apex/contrib/csrc/optimizers/multi_tensor_distopt_adam_kernel.cu
@@ -200,6 +200,7 @@ struct DistAdamFunctor
* to store FP32 main params. Instead, store 16-bit param remainder
* and combine with BF16 param to reconstruct the FP32 main param.
*/
template <typename GRAD_T>
struct DistAdamWithParamRemaindersFunctor
{
__device__ __forceinline__ void operator()(
@@ -230,7 +231,7 @@ struct DistAdamWithParamRemaindersFunctor
m += chunk_idx*chunk_size;
float* v = (float *)tl.addresses[3][tensor_loc];
v += chunk_idx*chunk_size;
float* g = (float *)tl.addresses[4][tensor_loc];
const GRAD_T* g = (GRAD_T *)tl.addresses[4][tensor_loc];
g += chunk_idx*chunk_size;
int16_t* p_out = (int16_t *)tl.addresses[5][tensor_loc];
p_out += chunk_idx*chunk_size;
@@ -256,7 +257,7 @@ struct DistAdamWithParamRemaindersFunctor
int16_t local_p_rem[ILP];
float local_m[ILP];
float local_v[ILP];
float local_g[ILP];
GRAD_T local_g[ILP];

// Load
if (aligned) {
@@ -294,7 +295,7 @@ }
}

// Local compute
using LocalFunctor = DistAdamFunctor<float, float, void>;
using LocalFunctor = DistAdamFunctor<float, GRAD_T, void>;
LocalFunctor::local_step(
reinterpret_cast<float *>(local_p), local_m, local_v, local_g, grad_scale,
beta1, beta2, beta1_correction, beta2_correction,
@@ -349,12 +350,9 @@ void multi_tensor_fused_adam_cuda(
// Expect p_in, m, v, g, p_out
size_t tl_sz = tensor_lists.size();
TORCH_CHECK(tl_sz == 5, "expected tensor lists of size 5");

// Assume p_in and g have same type
auto p_in_type = tensor_lists[0][0].scalar_type();
auto g_type = tensor_lists[3][0].scalar_type();
auto p_out_type = tensor_lists[4][0].scalar_type();
TORCH_CHECK(p_in_type == g_type, "expected main params and grads to have same type");
const auto p_in_type = tensor_lists[0][0].scalar_type();
const auto g_type = tensor_lists[3][0].scalar_type();
const auto p_out_type = tensor_lists[4][0].scalar_type();

float beta1_correction = 1.0f, beta2_correction = 1.0f;
if (bias_correction == 1) {
@@ -363,23 +361,24 @@
}

DISPATCH_FLOAT_HALF_AND_BFLOAT(p_in_type, 0, "dist_adam_cuda_kernel",
DISPATCH_FLOAT_HALF_AND_BFLOAT(p_out_type, 1, "dist_adam_cuda_kernel",
multi_tensor_apply<5>(
BLOCK_SIZE,
chunk_size,
noop_flag,
tensor_lists,
DistAdamFunctor<scalar_t_0, scalar_t_0, scalar_t_1>(),
grad_scale.DATA_PTR<float>(),
beta1,
beta2,
beta1_correction,
beta2_correction,
eps,
lr,
(adamMode_t) mode,
weight_decay);
));
DISPATCH_FLOAT_HALF_AND_BFLOAT(g_type, 1, "dist_adam_cuda_kernel",
DISPATCH_FLOAT_HALF_AND_BFLOAT(p_out_type, 2, "dist_adam_cuda_kernel",
multi_tensor_apply<5>(
BLOCK_SIZE,
chunk_size,
noop_flag,
tensor_lists,
DistAdamFunctor<scalar_t_0, scalar_t_1, scalar_t_2>(),
grad_scale.data_ptr<float>(),
beta1,
beta2,
beta1_correction,
beta2_correction,
eps,
lr,
(adamMode_t) mode,
weight_decay);
)));
C10_CUDA_CHECK(cudaGetLastError());
}

@@ -402,27 +401,30 @@ void multi_tensor_fused_adam_with_param_remainders_cuda(
// Expect p_in, p_rem, m, v, g, p_out
size_t tl_sz = tensor_lists.size();
TORCH_CHECK(tl_sz == 6, "expected tensor lists of size 6");
const auto g_type = tensor_lists[4][0].scalar_type();

float beta1_correction = 1.0f, beta2_correction = 1.0f;
if (bias_correction == 1) {
beta1_correction = 1 - std::pow(beta1, step);
beta2_correction = 1 - std::pow(beta2, step);
}

multi_tensor_apply<6>(
BLOCK_SIZE,
chunk_size,
noop_flag,
tensor_lists,
DistAdamWithParamRemaindersFunctor(),
grad_scale.DATA_PTR<float>(),
beta1,
beta2,
beta1_correction,
beta2_correction,
eps,
lr,
(adamMode_t) mode,
weight_decay);
DISPATCH_FLOAT_HALF_AND_BFLOAT(g_type, 0, "dist_adam_with_param_remainders_cuda_kernel",
multi_tensor_apply<6>(
BLOCK_SIZE,
chunk_size,
noop_flag,
tensor_lists,
DistAdamWithParamRemaindersFunctor<scalar_t_0>(),
grad_scale.data_ptr<float>(),
beta1,
beta2,
beta1_correction,
beta2_correction,
eps,
lr,
(adamMode_t) mode,
weight_decay);
);
C10_CUDA_CHECK(cudaGetLastError());
}
28 changes: 7 additions & 21 deletions apex/contrib/optimizers/distributed_fused_adam.py
@@ -444,11 +444,6 @@ def __init__(self,
f'grad_sync_dtype={grad_sync_dtype}, '
f'param_sync_dtype={param_sync_dtype}))'
)
if grad_sync_dtype != dtype:
raise RuntimeError(
'DistributedFusedAdam requires dtype to match grad dtype '
f'(dtype={dtype}, grad_sync_dtype={grad_sync_dtype})'
)
self.dtype = dtype
self.grad_sync_dtype = grad_sync_dtype
self.param_sync_dtype = param_sync_dtype
@@ -488,13 +483,7 @@ def __init__(self,
f'distributed process group size = {self.distributed_size}, '
f'redundant process group size = {self.redundant_size})'
)
try:
self._process_group_ranks = [
get_global_rank(self.process_group, local_rank)
for local_rank in range(self.distributed_size)
]
except:
self._process_group_ranks = list(range(self.distributed_size))
self.process_group_root = get_global_rank(self.process_group, 0)

# Use average reduction for grad sync
self.average_grad_sync = average_grad_sync
@@ -515,14 +504,12 @@ def __init__(self,
'with store_params=True and store_param_remainders=True'
)
if (self.dtype != torch.float32
or self.grad_sync_dtype != torch.float32
or self.param_sync_dtype != torch.bfloat16):
raise RuntimeError(
'DistributedFusedAdam requires '
'BF16 params and FP32 optimizer state '
'when storing parameter remainders '
f'(dtype={self.dtype}, '
f'grad_sync_dtype={self.grad_sync_dtype}, '
f'param_sync_dtype={self.param_sync_dtype}))'
)
self.store_params = store_params
@@ -565,7 +552,7 @@ def __init__(self,

# Scale by factor before optimizer step. Used for grad
# clipping and gradient scaler.
self._grad_scale = torch.full([], 1.0, dtype=self.dtype, device=self.device)
self._grad_scale = torch.full([], 1.0, dtype=torch.float32, device=self.device)
# Norm of parameter gradients. Used for gradient clipping and
# gradient scaler.
self._grad_norm = None
@@ -603,7 +590,7 @@ def _broadcast_params(self):
sync_requests.append(
torch.distributed.broadcast(
param,
src=self._process_group_ranks[0],
src=self.process_group_root,
group=process_group,
async_op=True,
)
@@ -829,7 +816,7 @@ def _init_grad_buffer(self):
buffer_size = 0
self._grad_buffer = torch.zeros(
[buffer_size],
dtype=self.dtype,
dtype=self.grad_sync_dtype,
device=self.device,
)

@@ -1089,7 +1076,7 @@ def zero_grad(self, set_to_none=False):
param.grad.zero_()

# Reset other state
self._grad_scale = torch.full([], 1.0, dtype=self.dtype, device=self.device)
self._grad_scale = torch.full([], 1.0, dtype=torch.float32, device=self.device)
self._grad_norm = None
self._dummy_overflow_buf = torch.zeros([1], dtype=torch.int32, device=self.device)

@@ -1934,7 +1921,6 @@ def state_dict(self, gather_on_root=True):
ranks on the root rank (default: True)
"""
### TODO Fix
state_dict = super().state_dict()
if not gather_on_root:
return state_dict
@@ -2036,14 +2022,14 @@ def state_dict(self, gather_on_root=True):
torch.distributed.gather(
gathered_chunks[0],
gathered_chunks,
dst=self._process_group_ranks[0],
dst=self.process_group_root,
group=self.process_group,
**no_copy_kwarg,
)
else:
torch.distributed.gather(
chunk,
dst=self._process_group_ranks[0],
dst=self.process_group_root,
group=self.process_group,
)
stream.wait_stream(main_stream)
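One consequence of the relaxed check above: `store_param_remainders=True` still requires FP32 optimizer state and BF16 parameter synchronization, but gradients may now be reduced in a dtype other than the optimizer-state dtype. A hedged sketch of that combination follows (continuing from the imports and `model` in the sketch near the top; keyword names follow this file, and any omitted arguments or defaults should be treated as assumptions):

```python
# Sketch only: BF16 params reconstructed from 16-bit remainders, FP32 Adam
# state, and BF16 gradient reduction (no longer rejected after this commit).
# Assumes torch.distributed is initialized and the model holds BF16 params.
optimizer = DistributedFusedAdam(
    model.parameters(),
    lr=1e-3,
    dtype=torch.float32,              # still required with param remainders
    param_sync_dtype=torch.bfloat16,  # still required with param remainders
    grad_sync_dtype=torch.bfloat16,   # previously forced to torch.float32 here
    store_params=False,
    store_param_remainders=True,
)
```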
14 changes: 14 additions & 0 deletions apex/contrib/test/optimizers/test_dist_adam.py
@@ -33,6 +33,7 @@ def make_models(
adam_w_mode=True,
model_dtype=torch.float32,
optim_dtype=None,
grad_sync_dtype=None,
param_sync_dtype=None,
device='cuda',
overlap_communication=True,
@@ -79,6 +80,7 @@ def make_models(
overlap_param_sync=overlap_communication,
bucket_cap_mb=71/(4*1024*1024),
dtype=optim_dtype,
grad_sync_dtype=grad_sync_dtype,
param_sync_dtype=param_sync_dtype,
contiguous_param_buffer=contiguous_buffers,
contiguous_grad_buffer=contiguous_buffers,
@@ -117,6 +119,7 @@ def test_matches_pytorch(
use_nosync=True,
model_dtype=torch.float32,
optim_dtype=None,
grad_sync_dtype=None,
param_sync_dtype=None,
device='cuda',
contiguous_buffers=False,
@@ -133,6 +136,7 @@
adam_w_mode=adam_w_mode,
model_dtype=model_dtype,
optim_dtype=optim_dtype,
grad_sync_dtype=grad_sync_dtype,
param_sync_dtype=param_sync_dtype,
device=device,
overlap_communication=overlap_communication,
@@ -239,6 +243,16 @@ def test_matches_pytorch_fp16_params(self):
store_params=True,
)

def test_matches_pytorch_bf16_grads(self):
self.test_matches_pytorch(
rtol=5e-2,
atol=1e-5,
micro_batch_steps=1,
model_dtype=torch.float32,
optim_dtype=torch.float32,
grad_sync_dtype=torch.bfloat16,
)

def test_matches_pytorch_bf16_param_remainders(self):
self.test_matches_pytorch(
rtol=5e-2,
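The new `test_matches_pytorch_bf16_grads` case compares against plain PyTorch with a fairly loose `rtol=5e-2`, since reducing gradients in BF16 perturbs the update. The standalone, single-process sketch below only mirrors the numerics of a BF16 round-trip on the gradients, not the actual distributed code path, and is meant to show why a relaxed tolerance is reasonable.

```python
import torch

torch.manual_seed(0)
p_ref = torch.randn(1024, requires_grad=True)
p_bf16 = p_ref.detach().clone().requires_grad_()
grad = torch.randn(1024)

opt_ref = torch.optim.AdamW([p_ref], lr=1e-3)
opt_bf16 = torch.optim.AdamW([p_bf16], lr=1e-3)

# Identical FP32 grads vs. the same grads round-tripped through BF16.
p_ref.grad = grad.clone()
p_bf16.grad = grad.to(torch.bfloat16).to(torch.float32)

opt_ref.step()
opt_bf16.step()

# The parameters drift slightly but stay within the test's tolerance.
print(torch.allclose(p_ref, p_bf16, rtol=5e-2, atol=1e-5))
```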