Merge branch 'master' into api_refactor
definitelynotmcarilli committed Feb 6, 2019
2 parents a9a3fe5 + 340e71a commit 2cbca1a
Showing 29 changed files with 733 additions and 181 deletions.
28 changes: 9 additions & 19 deletions apex/__init__.py
@@ -3,23 +3,13 @@
from . import fp16_utils
from . import parallel
from . import amp
try:
from . import optimizers
except ImportError:
# An attempt to fix https://github.com/NVIDIA/apex/issues/97. I'm not sure why 97 is even
# happening because Python modules should only be imported once, even if import is called
# multiple times.
try:
_ = warned_optimizers
except NameError:
print("Warning: apex was installed without --cuda_ext. FusedAdam will be unavailable.")
warned_optimizers = True
try:
from . import normalization
except ImportError:
try:
_ = warned_normalization
except NameError:
print("Warning: apex was installed without --cuda_ext. FusedLayerNorm will be unavailable.")
warned_normalization = True

# For optimizers and normalization there is no Python fallback.
# Absence of cuda backend is a hard error.
# I would like the errors from importing fused_adam_cuda or fused_layer_norm_cuda
# to be triggered lazily: if someone has installed with --cpp_ext and --cuda_ext,
# they expect those backends to be available, but if for some reason they actually
# aren't (for example, because they were built improperly in a way that isn't
# revealed until load time), the error message should be timely and visible.
from . import optimizers
from . import normalization
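
A minimal sketch of the lazy-import pattern the per-module changes below adopt (see fused_layer_norm.py and fused_adam.py); FusedThing and my_fused_cuda are placeholder names, not part of apex:

import importlib

class FusedThing(object):
    def __init__(self):
        # Defer the compiled-extension import to construction time: `import apex`
        # itself never fails, and a missing or broken --cuda_ext build raises
        # ImportError here, where the failure is tied to the feature being used.
        global my_fused_cuda
        my_fused_cuda = importlib.import_module("my_fused_cuda")
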
11 changes: 2 additions & 9 deletions apex/amp/handle.py
@@ -40,23 +40,16 @@ def scale_loss(self, loss, optimizer):
'use `optimizer.scale_loss(loss)`.')

# TODO: this code block is duplicated here and `opt.py`. Unify.
loss_backward = loss.backward
def warning_wrapper():
warnings.warn("You called .backward() on the unscaled loss "
"inside a scale_loss block. This is almost "
"certainly an error.", stacklevel=2)
loss_backward()
loss.backward = warning_wrapper
loss_scale = self._default_scaler.loss_scale()
yield loss * loss_scale
loss.backward = loss_backward

should_skip = self._default_scaler.unscale_and_update(
optimizer.param_groups, loss_scale)
if should_skip:
optimizer_step = optimizer.step
def skip_step():
logging.info('Gradient overflow, skipping update')
logger = logging.getLogger('apex.amp')
logger.warning('Gradient overflow, skipping update')
optimizer.step = optimizer_step
optimizer.step = skip_step

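For reference, a hedged sketch of how this context manager is typically driven; the model, optimizer, and data here are placeholders, and amp.init() returning the handle whose scale_loss is shown above is assumed from this file's API:

import torch
from apex import amp

model = torch.nn.Linear(16, 4).cuda()                    # placeholder model (kept fp32)
optimizer = torch.optim.SGD(model.parameters(), lr=1e-3)
amp_handle = amp.init()

x = torch.randn(8, 16, device="cuda")
y = torch.randn(8, 4, device="cuda")
loss = torch.nn.functional.mse_loss(model(x), y)

with amp_handle.scale_loss(loss, optimizer) as scaled_loss:
    # backward() runs on the scaled loss; calling loss.backward() here is what
    # the warning_wrapper removed in this hunk used to flag.
    scaled_loss.backward()
# If unscale_and_update saw an overflow, step() has been swapped for skip_step
# and this iteration's parameter update is skipped.
optimizer.step()
optimizer.zero_grad()
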
12 changes: 2 additions & 10 deletions apex/amp/opt.py
@@ -21,14 +21,6 @@ def scale_loss(self, loss):
yield loss
return

loss_backward = loss.backward
def warning_wrapper():
warnings.warn("You called .backward() on the unscaled loss "
"inside a scale_loss block. This is almost "
"certainly an error.", stacklevel=2)
loss_backward()
loss.backward = warning_wrapper

# When there are multiple losses per-optimizer, we need
# to save out current grad accumulation, since we won't be
able to unscale this particular loss once the grads are
@@ -44,7 +36,6 @@ def warning_wrapper():

loss_scale = self._cur_loss_scaler().loss_scale()
yield loss * loss_scale
loss.backward = loss_backward

self._skip_next[self._loss_idx] = self._cur_loss_scaler().unscale_and_update(
self._optimizer.param_groups, loss_scale)
@@ -76,7 +67,8 @@ def step(self, closure=None):
'The `closure` argument is unsupported by the amp ' +
'optimizer wrapper.')
if any(self._skip_next):
logging.info('Gradient overflow, skipping update')
logger = logging.getLogger('apex.amp')
logger.info('Gradient overflow, skipping update')
self._skip_next = [False] * self._num_loss
else:
return self._optimizer.step(closure=closure)
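Because the overflow messages now go through the named 'apex.amp' logger rather than the root logger, applications can tune their verbosity with the standard logging machinery, for example:

import logging

# Silence the per-iteration "Gradient overflow, skipping update" info messages
# while keeping warnings from apex.amp visible.
logging.getLogger("apex.amp").setLevel(logging.WARNING)
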
50 changes: 41 additions & 9 deletions apex/amp/scaler.py
@@ -1,9 +1,9 @@
import torch
import logging

# from apex_C import scale_check_overflow

# Python stopgap, until we get a future-proof kernel into upstream
def scale_check_overflow(d_grads, scale):
def scale_check_overflow_python(d_grads, scale):
# Exception handling for 18.04 compatibility
try:
cpu_sum = float(d_grads.float().sum())
@@ -18,28 +18,60 @@ def scale_check_overflow(d_grads, scale):
return False

class LossScaler(object):
warned_no_fused_kernel = False
warned_fp16_grad = False
has_fused_kernel = False

def __init__(self):
self._loss_scale = 2.**16
self._max_loss_scale = 2.**24
self._scale_seq_len = 2000
self._unskipped = 0
self._has_overflow = False
# self._overflow_buf = torch.cuda.ByteTensor(1024,)
try:
import amp_C
LossScaler.has_fused_kernel = True
LossScaler.scale_check_overflow_cuda = amp_C.scale_check_overflow
self._overflow_buf = torch.cuda.IntTensor([0])
except ImportError as err:
if not LossScaler.warned_no_fused_kernel:
print("Warning: Amp fused downscale kernel is unavailable, possibly because apex "
"was installed without --cuda_ext. Using Python fallback. ImportError was: ",
err)
LossScaler.has_fused_kernel = False
LossScaler.warned_no_fused_kernel = True

def loss_scale(self):
return self._loss_scale

def unscale_and_update(self, param_groups, scale):
# self._overflow_buf.zero_()
if LossScaler.has_fused_kernel:
self._overflow_buf.zero_()
self._has_overflow = False
for p in iter_params(param_groups):
if p.grad is not None:
self._has_overflow = scale_check_overflow(p.grad.data,
1. / scale)
if self._has_overflow:
break
if LossScaler.has_fused_kernel and p.grad.data.type() == "torch.cuda.FloatTensor":
LossScaler.scale_check_overflow_cuda(p.grad.data,
1./scale,
self._overflow_buf,
p.grad.data)
else:
if (p.grad.data.type() != "torch.cuda.FloatTensor"
and not LossScaler.warned_fp16_grad):
logger = logging.getLogger("apex.amp")
logger.warning("Incoming grads are not fp32 (not master grads). "
"Downscaling non-fp32 grads may indicate an error. "
"When using Amp, you don't need to call .half() on your model.")
LossScaler.warned_fp16_grad = True
self._has_overflow = scale_check_overflow_python(p.grad.data,
1./scale)
if self._has_overflow:
break

# If the fused kernel is available, we only need one D2H memcopy and sync.
if LossScaler.has_fused_kernel and not self._has_overflow:
self._has_overflow = self._overflow_buf.item()

# if self._overflow_buf.any():
if self._has_overflow:
should_skip = True
self._loss_scale /= 2.
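The fields initialized above (_loss_scale, _max_loss_scale, _scale_seq_len, _unskipped) imply the usual dynamic loss-scaling schedule; the growth branch falls outside this hunk, so the following is a hedged sketch of the update rule rather than the file's exact code:

def update_scale_sketch(scaler, has_overflow):
    # On overflow: halve the scale and tell the caller to skip optimizer.step().
    if has_overflow:
        scaler._loss_scale /= 2.
        scaler._unskipped = 0
        return True
    # After _scale_seq_len consecutive clean iterations, try a larger scale,
    # capped at _max_loss_scale.
    scaler._unskipped += 1
    if scaler._unskipped == scaler._scale_seq_len:
        scaler._loss_scale = min(scaler._loss_scale * 2., scaler._max_loss_scale)
        scaler._unskipped = 0
    return False
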
12 changes: 10 additions & 2 deletions apex/normalization/fused_layer_norm.py
@@ -3,11 +3,13 @@
import numbers
from torch.nn.parameter import Parameter
from torch.nn import init

import fused_layer_norm_cuda
import importlib

class FusedLayerNormAffineFunction(torch.autograd.Function):
def __init__(self, normalized_shape, eps=1e-6):
global fused_layer_norm_cuda
fused_layer_norm_cuda = importlib.import_module("fused_layer_norm_cuda")

self.normalized_shape = normalized_shape
self.eps = eps

@@ -31,6 +33,8 @@ def backward(self, grad_output):

class FusedLayerNormFunction(torch.autograd.Function):
def __init__(self, normalized_shape, eps=1e-6):
global fused_layer_norm_cuda
fused_layer_norm_cuda = importlib.import_module("fused_layer_norm_cuda")
self.normalized_shape = normalized_shape
self.eps = eps

@@ -117,6 +121,10 @@ class FusedLayerNorm(torch.nn.Module):
"""
def __init__(self, normalized_shape, eps=1e-5, elementwise_affine=True):
super(FusedLayerNorm, self).__init__()

global fused_layer_norm_cuda
fused_layer_norm_cuda = importlib.import_module("fused_layer_norm_cuda")

if isinstance(normalized_shape, numbers.Integral):
normalized_shape = (normalized_shape,)
self.normalized_shape = torch.Size(normalized_shape)
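With the extension import deferred via importlib, `import apex` succeeds even without --cuda_ext and the ImportError surfaces when a FusedLayerNorm is constructed. A sketch of assumed usage (shapes are illustrative):

import torch
from apex.normalization import FusedLayerNorm

# Construction is what triggers the fused_layer_norm_cuda import, so a missing
# --cuda_ext build fails here rather than at `import apex`.
layer_norm = FusedLayerNorm(normalized_shape=768).cuda()
x = torch.randn(8, 128, 768, device="cuda")
y = layer_norm(x)
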
66 changes: 66 additions & 0 deletions apex/optimizers/fp16_optimizer.py
@@ -214,3 +214,69 @@ def _set_param_groups(self, value):
self.optimizer.param_groups = value

param_groups = property(_get_param_groups, _set_param_groups)

def state_dict(self):
"""
Returns a dict containing the current state of this :class:`FP16_Optimizer` instance.
This dict contains attributes of :class:`FP16_Optimizer`, as well as the state_dict
of the contained Pytorch optimizer.
Example::
checkpoint = {}
checkpoint['model'] = model.state_dict()
checkpoint['optimizer'] = optimizer.state_dict()
torch.save(checkpoint, "saved.pth")
"""
state_dict = {}
state_dict['dynamic_loss_scale'] = self.dynamic_loss_scale
state_dict['cur_scale'] = self.cur_scale
state_dict['cur_iter'] = self.cur_iter
if state_dict['dynamic_loss_scale']:
state_dict['last_overflow_iter'] = self.last_overflow_iter
state_dict['scale_factor'] = self.scale_factor
state_dict['scale_window'] = self.scale_window
state_dict['optimizer_state_dict'] = self.optimizer.state_dict()
state_dict['fp32_groups_flat'] = self.fp32_groups_flat
return state_dict

def load_state_dict(self, state_dict):
"""
Loads a state_dict created by an earlier call to state_dict().
If ``fp16_optimizer_instance`` was constructed from some ``init_optimizer``,
whose parameters in turn came from ``model``, it is expected that the user
will call ``model.load_state_dict()`` before
``fp16_optimizer_instance.load_state_dict()`` is called.
Example::
model = torch.nn.Linear(D_in, D_out).cuda().half()
optimizer = torch.optim.SGD(model.parameters(), lr=1e-3)
optimizer = FP16_Optimizer(optimizer, static_loss_scale = 128.0)
...
checkpoint = torch.load("saved.pth")
model.load_state_dict(checkpoint['model'])
optimizer.load_state_dict(checkpoint['optimizer'])
"""
# I think it should actually be ok to reload the optimizer before the model.
self.dynamic_loss_scale = state_dict['dynamic_loss_scale']
self.cur_scale = state_dict['cur_scale']
self.cur_iter = state_dict['cur_iter']
if state_dict['dynamic_loss_scale']:
self.last_overflow_iter = state_dict['last_overflow_iter']
self.scale_factor = state_dict['scale_factor']
self.scale_window = state_dict['scale_window']
self.optimizer.load_state_dict(state_dict['optimizer_state_dict'])
# At this point, the optimizer's references to the model's fp32 parameters are up to date.
# The optimizer's hyperparameters and internal buffers are also up to date.
# However, the fp32 master copies of the model's fp16 params stored by the optimizer are still
# out of date. There are two options.
# 1: Refresh the master params from the model's fp16 params.
# This requires less storage but incurs precision loss.
# 2: Save and restore the fp32 master copies separately.
# We choose option 2.
#
# Pytorch Optimizer.load_state_dict casts saved buffers (e.g. momentum) to the type and device
# of their associated parameters, because it's possible those buffers might not exist yet in
# the current optimizer instance. In our case, as long as the current FP16_Optimizer has been
# constructed in the same way as the one whose state_dict we are loading, the same master params
# are guaranteed to exist, so we can just copy_() from the saved master params.
for current, saved in zip(self.fp32_groups_flat, state_dict['fp32_groups_flat']):
current.data.copy_(saved.data)
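
To make the option 1 / option 2 trade-off described above concrete, here is a hedged sketch of both restore strategies as a standalone helper; fp16_groups_flat is assumed to be the flattened fp16 groups mirroring fp32_groups_flat (it does not appear in this hunk):

def restore_master_params(fp32_groups_flat, fp16_groups_flat,
                          saved_fp32_groups_flat, use_saved_masters=True):
    if not use_saved_masters:
        # Option 1: rebuild the fp32 masters from the fp16 model params that
        # model.load_state_dict() restored. Saves checkpoint space, but the
        # low-order bits discarded by the fp16 cast cannot be recovered.
        for master, model_flat in zip(fp32_groups_flat, fp16_groups_flat):
            master.data.copy_(model_flat.data)
    else:
        # Option 2 (what load_state_dict does above): restore the saved fp32
        # masters bit for bit from the checkpoint.
        for current, saved in zip(fp32_groups_flat, saved_fp32_groups_flat):
            current.data.copy_(saved.data)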

5 changes: 4 additions & 1 deletion apex/optimizers/fused_adam.py
@@ -1,6 +1,6 @@
import types
import torch
import fused_adam_cuda
import importlib

class FusedAdam(torch.optim.Optimizer):

@@ -36,6 +36,9 @@ def __init__(self, params,
lr=1e-3, bias_correction = True,
betas=(0.9, 0.999), eps=1e-8, eps_inside_sqrt = False,
weight_decay=0., max_grad_norm=0., amsgrad=False):
global fused_adam_cuda
fused_adam_cuda = importlib.import_module("fused_adam_cuda")

if amsgrad:
raise RuntimeError('FusedAdam does not support the AMSGrad variant.')
defaults = dict(lr=lr, bias_correction=bias_correction,
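As with the layer norm change, moving the fused_adam_cuda import into __init__ means a build without --cuda_ext is reported at optimizer construction. A sketch of assumed usage (the model is a placeholder):

import torch
from apex.optimizers import FusedAdam

model = torch.nn.Linear(16, 4).cuda()  # placeholder model
# Construction imports fused_adam_cuda; without --cuda_ext this raises
# ImportError here, tied to the feature actually being requested.
optimizer = FusedAdam(model.parameters(), lr=1e-3, betas=(0.9, 0.999),
                      eps=1e-8, weight_decay=0.)
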
11 changes: 5 additions & 6 deletions apex/parallel/__init__.py
@@ -8,16 +8,15 @@
ReduceOp = torch.distributed.deprecated.reduce_op

from .distributed import DistributedDataParallel, Reducer
# This is tricky because I'd like SyncBatchNorm to be exposed the same way
# for both the cuda-enabled and python-fallback versions, and I don't want
# to suppress the error information.
try:
import syncbn
from .optimized_sync_batchnorm import SyncBatchNorm
except ImportError:
try:
_ = warned_syncbn
except NameError:
print("Warning: apex was installed without --cuda_ext. Fused syncbn kernels will be unavailable. Python fallbacks will be used instead.")
warned_syncbn = True
except ImportError as err:
from .sync_batchnorm import SyncBatchNorm
SyncBatchNorm.syncbn_import_error = err

def convert_syncbn_model(module, process_group=None, channel_last=False):
'''
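Whichever branch the try/except selects (the fused syncbn version or the Python fallback carrying syncbn_import_error), it is exposed as apex.parallel.SyncBatchNorm, and convert_syncbn_model swaps it in for existing BatchNorm layers. A sketch of assumed usage (the model is a placeholder):

import torch
import apex

model = torch.nn.Sequential(torch.nn.Conv2d(3, 8, 3),
                            torch.nn.BatchNorm2d(8))  # placeholder model
# Replaces torch.nn.BatchNorm*d modules with apex.parallel.SyncBatchNorm;
# which implementation that is was decided by the try/except above.
model = apex.parallel.convert_syncbn_model(model).cuda()
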
37 changes: 25 additions & 12 deletions apex/parallel/distributed.py
@@ -1,22 +1,35 @@
import torch
# from torch._utils import _flatten_dense_tensors, _unflatten_dense_tensors
try:
from apex_C import flatten
from apex_C import unflatten
except ImportError:
try:
_ = warned_flatten
except NameError:
print("Warning: apex was installed without --cpp_ext. Falling back to Python flatten and unflatten.")
warned_flatten = True
from torch._utils import _flatten_dense_tensors as flatten
from torch._utils import _unflatten_dense_tensors as unflatten
import torch.distributed as dist
from torch.nn.modules import Module
from torch.autograd import Variable
from collections import OrderedDict
from itertools import chain
import copy
import importlib

imported_flatten_impl = False

def import_flatten_impl():
global flatten_impl, unflatten_impl, imported_flatten_impl
try:
import apex_C
flatten_impl = apex_C.flatten
unflatten_impl = apex_C.unflatten
except ImportError:
print("Warning: apex was installed without --cpp_ext. Falling back to Python flatten and unflatten.")
flatten_impl = torch._utils._flatten_dense_tensors
unflatten_impl = torch._utils._unflatten_dense_tensors
imported_flatten_impl = True

def flatten(bucket):
if not imported_flatten_impl:
import_flatten_impl()
return flatten_impl(bucket)

def unflatten(coalesced, bucket):
if not imported_flatten_impl:
import_flatten_impl()
return unflatten_impl(coalesced, bucket)

# apply_dist_call requires that tensors in 'bucket' are all the same type.
def apply_flat_dist_call(bucket, call, extra_args=None):
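Both backends import_flatten_impl can choose (apex_C or the torch._utils fallback) are expected to satisfy the same contract, which the module-level flatten/unflatten wrappers rely on. A small CPU-only check of that contract (the fallback path certainly handles CPU tensors; the apex_C path is assumed to as well):

import torch
from apex.parallel.distributed import flatten, unflatten

bucket = [torch.randn(2, 3), torch.randn(5)]
flat = flatten(bucket)              # one contiguous 1-D tensor of 11 elements
restored = unflatten(flat, bucket)  # tensors shaped like the originals
assert all(torch.equal(a, b) for a, b in zip(bucket, restored))
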
1 change: 1 addition & 0 deletions apex/parallel/optimized_sync_batchnorm.py
@@ -2,6 +2,7 @@
from torch.nn.modules.batchnorm import _BatchNorm
from torch.nn import functional as F

import syncbn
from .optimized_sync_batchnorm_kernel import SyncBatchnormFunction


7 changes: 7 additions & 0 deletions apex/parallel/sync_batchnorm.py
@@ -45,7 +45,14 @@ class SyncBatchNorm(_BatchNorm):
>>> out = sbn(inp)
"""

warned = False

def __init__(self, num_features, eps=1e-5, momentum=0.1, affine=True, track_running_stats=True, process_group=None):

if not SyncBatchNorm.warned:
print("Warning: using Python fallback for SyncBatchNorm, possibly because apex was installed without --cuda_ext. The exception raised when attempting to import the cuda backend was: ", self.syncbn_import_error)
SyncBatchNorm.warned = True

super(SyncBatchNorm, self).__init__(num_features, eps=eps, momentum=momentum, affine=affine, track_running_stats=track_running_stats)
self.process_group = process_group

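The warn-once behavior above is carried by a class attribute rather than the module-level NameError dance being removed elsewhere in this commit; a minimal standalone sketch of the pattern (names are illustrative):

class PythonFallback(object):
    warned = False  # shared by every instance of the class

    def __init__(self):
        if not PythonFallback.warned:
            print("Warning: using the Python fallback implementation.")
            PythonFallback.warned = True  # subsequent constructions stay quiet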