forked from NVIDIA/apex
-
Notifications
You must be signed in to change notification settings - Fork 0
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
Showing
6 changed files
with
351 additions
and
0 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,122 @@ | ||
import torch | ||
from apex.multi_tensor_apply import multi_tensor_applier | ||
|
||
|
||
class FusedAdagrad(torch.optim.Optimizer): | ||
"""Implements Adagrad algorithm. | ||
Currently GPU-only. Requires Apex to be installed via | ||
``pip install -v --no-cache-dir --global-option="--cpp_ext" --global-option="--cuda_ext" ./``. | ||
This version of fused Adagrad implements 2 fusions. | ||
* Fusion of the Adagrad update's elementwise operations | ||
* A multi-tensor apply launch that batches the elementwise updates applied to all the model's parameters into one or a few kernel launches. | ||
:class:`apex.optimizers.FusedAdagrad`'s usage is identical to any ordinary Pytorch optimizer:: | ||
opt = apex.optimizers.FusedAdagrad(model.parameters(), lr = ....) | ||
... | ||
opt.step() | ||
:class:`apex.optimizers.FusedAdagrad` may be used with or without Amp. If you wish to use :class:`FusedAdagrad` with Amp, | ||
you may choose any ``opt_level``:: | ||
opt = apex.optimizers.FusedAdagrad(model.parameters(), lr = ....) | ||
model, opt = amp.initialize(model, opt, opt_level="O0" or "O1 or "O2") | ||
... | ||
opt.step() | ||
In general, ``opt_level="O1"`` is recommended. | ||
It has been proposed in `Adaptive Subgradient Methods for Online Learning | ||
and Stochastic Optimization`_. | ||
Arguments: | ||
params (iterable): iterable of parameters to optimize or dicts defining | ||
parameter groups | ||
lr (float, optional): learning rate (default: 1e-2) | ||
weight_decay (float, optional): weight decay (L2 penalty) (default: 0) | ||
eps (float, optional): term added to the denominator to improve | ||
numerical stability (default: 1e-10) | ||
adagrad_w_mode (boolean, optional): Apply L2 regularization or weight decay | ||
True for decoupled weight decay (also known as AdamW) (default: False) | ||
.. _Adaptive Subgradient Methods for Online Learning and Stochastic | ||
Optimization: http://jmlr.org/papers/v12/duchi11a.html | ||
""" | ||
def __init__(self, params, lr=1e-2, eps=1e-10, | ||
weight_decay=0., set_grad_none=True, adagrad_w_mode=False): | ||
|
||
defaults = dict(lr=lr, eps=eps, weight_decay=weight_decay) | ||
super(FusedAdagrad, self).__init__(params, defaults) | ||
self.adagrad_w_mode = 1 if adagrad_w_mode else 0 | ||
self.set_grad_none = set_grad_none | ||
|
||
if multi_tensor_applier.available: | ||
import amp_C | ||
# Skip buffer | ||
self._dummy_overflow_buf = torch.cuda.IntTensor([0]) | ||
self.multi_tensor_adagrad = amp_C.multi_tensor_adagrad | ||
else: | ||
raise RuntimeError('apex.optimizers.FusedAdagrad requires cuda extensions') | ||
|
||
def zero_grad(self): | ||
if self.set_grad_none: | ||
for group in self.param_groups: | ||
for p in group['params']: | ||
p.grad = None | ||
else: | ||
super(FusedAdagrad, self).zero_grad() | ||
|
||
def step(self, closure=None): | ||
"""Performs a single optimization step. | ||
Arguments: | ||
closure (callable, optional): A closure that reevaluates the model | ||
and returns the loss. | ||
""" | ||
loss = None | ||
if closure is not None: | ||
loss = closure() | ||
|
||
for group in self.param_groups: | ||
# create lists for multi-tensor apply | ||
g_16, p_16, h_16 = [], [], [] | ||
g_32, p_32, h_32 = [], [], [] | ||
|
||
for p in group['params']: | ||
if p.grad is None: | ||
continue | ||
if p.grad.data.is_sparse: | ||
raise RuntimeError('FusedAdagrad does not support sparse gradients') | ||
|
||
state = self.state[p] | ||
# State initialization | ||
if len(state) == 0: | ||
# Exponential moving average of gradient values | ||
state['sum'] = torch.zeros_like(p.data) | ||
if p.dtype == torch.float16: | ||
g_16.append(p.grad.data) | ||
p_16.append(p.data) | ||
h_16.append(state['sum']) | ||
elif p.dtype == torch.float32: | ||
g_32.append(p.grad.data) | ||
p_32.append(p.data) | ||
h_32.append(state['sum']) | ||
else: | ||
raise RuntimeError('FusedAdagrad only support fp16 and fp32.') | ||
|
||
if(len(g_16) > 0): | ||
multi_tensor_applier(self.multi_tensor_adagrad, | ||
self._dummy_overflow_buf, | ||
[g_16, p_16, h_16], | ||
group['lr'], | ||
group['eps'], | ||
self.adagrad_w_mode, | ||
group['weight_decay']) | ||
if(len(g_32) > 0): | ||
multi_tensor_applier(self.multi_tensor_adagrad, | ||
self._dummy_overflow_buf, | ||
[g_32, p_32, h_32], | ||
group['lr'], | ||
group['eps'], | ||
self.adagrad_w_mode, | ||
group['weight_decay']) | ||
|
||
return loss |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,100 @@ | ||
#include <ATen/ATen.h> | ||
#include <ATen/AccumulateType.h> | ||
#include <ATen/cuda/CUDAContext.h> | ||
#include <ATen/cuda/Exceptions.h> | ||
// Another possibility: | ||
// #include <torch/all.h> | ||
|
||
#include <assert.h> | ||
|
||
#include "multi_tensor_apply.cuh" | ||
#include "type_shim.h" | ||
|
||
#define BLOCK_SIZE 1024 | ||
#define ILP 4 | ||
|
||
typedef enum { | ||
ADAGRAD_MODE_0 = 0, // L2 regularization mode. | ||
ADAGRAD_MODE_1 = 1, // AdamW-style weight decay. | ||
|
||
} adagradMode_t; | ||
|
||
using MATH_T = float; | ||
|
||
template <typename T> struct AdagradFunctor { | ||
__device__ __forceinline__ void | ||
operator()(int chunk_size, volatile int *noop_gmem, TensorListMetadata<3> &tl, | ||
const float epsilon, const float lr, adagradMode_t mode, | ||
const float weight_decay) { | ||
int tensor_loc = tl.block_to_tensor[blockIdx.x]; | ||
int chunk_idx = tl.block_to_chunk[blockIdx.x]; | ||
int n = tl.sizes[tensor_loc]; | ||
|
||
T *g = (T *)tl.addresses[0][tensor_loc]; | ||
g += chunk_idx * chunk_size; | ||
|
||
T *p = (T *)tl.addresses[1][tensor_loc]; | ||
p += chunk_idx * chunk_size; | ||
|
||
T *h = (T *)tl.addresses[2][tensor_loc]; | ||
h += chunk_idx * chunk_size; | ||
|
||
n -= chunk_idx * chunk_size; | ||
|
||
// see note in multi_tensor_scale_kernel.cu | ||
for (int i_start = 0; i_start < n && i_start < chunk_size; | ||
i_start += blockDim.x * ILP) { | ||
MATH_T r_g[ILP]; | ||
MATH_T r_p[ILP]; | ||
MATH_T r_h[ILP]; | ||
#pragma unroll | ||
for (int ii = 0; ii < ILP; ii++) { | ||
int i = i_start + threadIdx.x + ii * blockDim.x; | ||
if (i < n && i < chunk_size) { | ||
r_g[ii] = g[i]; | ||
r_p[ii] = p[i]; | ||
r_h[ii] = h[i]; | ||
} else { | ||
r_g[ii] = MATH_T(0); | ||
r_p[ii] = MATH_T(0); | ||
r_h[ii] = MATH_T(0); | ||
} | ||
} | ||
#pragma unroll | ||
for (int ii = 0; ii < ILP; ii++) { | ||
if (mode == ADAGRAD_MODE_0) { // L2 | ||
r_g[ii] = r_g[ii] + weight_decay * r_p[ii]; | ||
r_h[ii] = r_h[ii] + r_g[ii] * r_g[ii]; | ||
r_p[ii] = r_p[ii] - lr * (r_g[ii] / (sqrtf(r_h[ii]) + epsilon)); | ||
} else { // AdamW-style | ||
r_h[ii] = r_h[ii] + r_g[ii] * r_g[ii]; | ||
r_p[ii] = r_p[ii] - lr * (r_g[ii] / (sqrtf(r_h[ii]) + epsilon) + weight_decay * r_p[ii]); | ||
} | ||
} | ||
#pragma unroll | ||
for (int ii = 0; ii < ILP; ii++) { | ||
int i = i_start + threadIdx.x + ii * blockDim.x; | ||
if (i < n && i < chunk_size) { | ||
p[i] = r_p[ii]; | ||
h[i] = r_h[ii]; | ||
} | ||
} | ||
} | ||
} | ||
}; | ||
|
||
void multi_tensor_adagrad_cuda( | ||
int chunk_size, at::Tensor noop_flag, | ||
std::vector<std::vector<at::Tensor>> tensor_lists, const float lr, | ||
const float epsilon, const int mode, const float weight_decay) { | ||
using namespace at; | ||
|
||
// Assume single type across p,g,h now | ||
DISPATCH_DOUBLE_FLOAT_AND_HALF( | ||
tensor_lists[0][0].scalar_type(), 0, "adagrad", | ||
multi_tensor_apply<3>(BLOCK_SIZE, chunk_size, noop_flag, tensor_lists, | ||
AdagradFunctor<scalar_t_0>(), epsilon, lr, | ||
(adagradMode_t)mode, weight_decay);) | ||
|
||
AT_CUDA_CHECK(cudaGetLastError()); | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,114 @@ | ||
import unittest | ||
|
||
import apex | ||
import torch | ||
|
||
|
||
class TestFusedAdagrad(unittest.TestCase): | ||
def setUp(self, max_abs_diff=1e-6, max_rel_diff=1, iters=7): | ||
self.max_abs_diff = max_abs_diff | ||
self.max_rel_diff = max_rel_diff | ||
self.iters = iters | ||
torch.cuda.manual_seed(9876) | ||
|
||
def tearDown(self): | ||
pass | ||
|
||
def gen_param_optim(self, tensors, adagrad_option): | ||
ref_param = [] | ||
tst_param = [] | ||
for tensor in tensors: | ||
ref_param.append(torch.nn.Parameter(tensor.clone())) | ||
tst_param.append(torch.nn.Parameter(tensor.clone())) | ||
|
||
ref_optim = torch.optim.Adagrad(ref_param, **adagrad_option) | ||
tst_optim = apex.optimizers.FusedAdagrad(tst_param, **adagrad_option) | ||
|
||
return (ref_param, tst_param, ref_optim, tst_optim) | ||
|
||
def gen_grad(self, ref_param, tst_param): | ||
for p_ref, p_tst in zip(ref_param, tst_param): | ||
p_ref.grad = torch.rand_like(p_ref) | ||
p_tst.grad = p_ref.grad | ||
|
||
def gen_mixed_grad(self, ref_param, tst_param, scale=1.0): | ||
half_grads = [] | ||
for p_ref, _ in zip(ref_param, tst_param): | ||
half_grads.append(torch.rand_like(p_ref).half()) | ||
p_ref.grad = half_grads[-1].float() / scale | ||
return half_grads | ||
|
||
def get_max_diff(self, ref_param, tst_param): | ||
max_abs_diff = max_rel_diff = 0 | ||
for p_ref, p_tst in zip(ref_param, tst_param): | ||
max_abs_diff_p = (p_ref - p_tst).abs().max().item() | ||
max_rel_diff_p = ((p_ref - p_tst) / p_ref).abs().max().item() | ||
|
||
if max_abs_diff_p > max_abs_diff: | ||
max_abs_diff = max_abs_diff_p | ||
if max_rel_diff_p > max_rel_diff: | ||
max_rel_diff = max_rel_diff_p | ||
|
||
return max_abs_diff, max_rel_diff | ||
|
||
def gen_single_type_test(self, param_type=torch.float): | ||
nelem = 278011 | ||
adagrad_option = {"lr": 5e-4, "eps": 1e-08, "weight_decay": 1.0e-5} | ||
|
||
tensor = torch.rand(nelem, dtype=param_type, device="cuda") | ||
ref_param, tst_param, ref_optim, tst_optim = self.gen_param_optim( | ||
[tensor], adagrad_option | ||
) | ||
|
||
for _ in range(self.iters): | ||
self.gen_grad(ref_param, tst_param) | ||
ref_optim.step() | ||
tst_optim.step() | ||
max_abs_diff, max_rel_diff = self.get_max_diff(ref_param, tst_param) | ||
|
||
self.assertLessEqual(max_abs_diff, self.max_abs_diff) | ||
self.assertLessEqual(max_rel_diff, self.max_rel_diff) | ||
|
||
def test_float(self): | ||
self.gen_single_type_test(param_type=torch.float) | ||
|
||
@unittest.skip("PyTorch optimizer is not numerically correct for fp16") | ||
def test_half(self): | ||
self.gen_single_type_test(param_type=torch.float16) | ||
|
||
def test_multi_params(self): | ||
sizes = [[4096, 1024], [4096], [4096, 2048], [32320, 1024], [1]] | ||
adagrad_option = {"lr": 5e-4, "eps": 1e-08, "weight_decay": 0} | ||
|
||
tensors = [] | ||
for size in sizes: | ||
tensors.append(torch.rand(size, dtype=torch.float, device="cuda")) | ||
ref_param, tst_param, ref_optim, tst_optim = self.gen_param_optim( | ||
tensors, adagrad_option | ||
) | ||
|
||
for _ in range(self.iters): | ||
self.gen_grad(ref_param, tst_param) | ||
ref_optim.step() | ||
tst_optim.step() | ||
max_abs_diff, max_rel_diff = self.get_max_diff(ref_param, tst_param) | ||
self.assertLessEqual(max_abs_diff, self.max_abs_diff) | ||
self.assertLessEqual(max_rel_diff, self.max_rel_diff) | ||
|
||
def test_adagrad_option(self): | ||
nelem = 1 | ||
adagrad_option = {"lr": 0.01, "eps": 3e-06, "weight_decay": 0} | ||
|
||
tensor = torch.rand(nelem, dtype=torch.float, device="cuda") | ||
ref_param, tst_param, ref_optim, tst_optim = self.gen_param_optim( | ||
[tensor], adagrad_option | ||
) | ||
|
||
for _ in range(self.iters): | ||
self.gen_grad(ref_param, tst_param) | ||
ref_optim.step() | ||
tst_optim.step() | ||
max_abs_diff, max_rel_diff = self.get_max_diff(ref_param, tst_param) | ||
|
||
self.assertLessEqual(max_abs_diff, self.max_abs_diff) | ||
self.assertLessEqual(max_rel_diff, self.max_rel_diff) |