stashing work
definitelynotmcarilli committed Feb 8, 2019
1 parent b2f63c4 commit 1f693b9
Showing 6 changed files with 340 additions and 20 deletions.
3 changes: 2 additions & 1 deletion apex/amp/__init__.py
@@ -1,2 +1,3 @@
from .amp import init, half_function, float_function, promote_function,\
register_half_function, register_float_function, register_promote_function
register_half_function, register_float_function, register_promote_function,\
register
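
For context, here is a sketch of how the newly exported register entry point is meant to be called, based on the signature added in apex/amp/frontend.py below. The model and optimizer are hypothetical, and since this commit is stashed work in progress the call may not yet run end to end:

import torch
from apex import amp

model = torch.nn.Linear(512, 512).cuda()
optimizer = torch.optim.SGD(model.parameters(), lr=1e-3)

# enabled defaults to False in this commit, so it must be passed explicitly.
amp.register(enabled=True,
             optimizers=[optimizer],
             models=[model],
             opt_level="O1")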
11 changes: 11 additions & 0 deletions apex/amp/amp.py
@@ -1,16 +1,20 @@
from . import compat, rnn_compat, utils, wrap
from .handle import AmpHandle, NoOpHandle
from .lists import functional_overrides, torch_overrides, tensor_overrides
from ..fp16_utils import FP16_Optimizer
from .frontend import *

import functools
import itertools

import torch


_DECORATOR_HANDLE = None
_USER_CAST_REGISTRY = set()
_USER_PROMOTE_REGISTRY = set()


def _decorator_helper(orig_fn, cast_fn, wrap_fn):
    def wrapper(*args, **kwargs):
        handle = _DECORATOR_HANDLE
@@ -21,38 +25,45 @@ def wrapper(*args, **kwargs):
        return wrap_fn(orig_fn, inner_cast_fn, handle)(*args, **kwargs)
    return wrapper


# Decorator form
def half_function(fn):
    wrap_fn = functools.partial(wrap.make_cast_wrapper, try_caching=True)
    return _decorator_helper(fn, utils.maybe_half, wrap_fn)


def float_function(fn):
    wrap_fn = functools.partial(wrap.make_cast_wrapper, try_caching=False)
    return _decorator_helper(fn, utils.maybe_float, wrap_fn)


def promote_function(fn):
    wrap_fn = functools.partial(wrap.make_promote_wrapper)
    return _decorator_helper(fn, utils.maybe_float, wrap_fn)


# Registry form
def register_half_function(module, name):
    if not hasattr(module, name):
        raise ValueError('No function named {} in module {}.'.format(
            name, module))
    _USER_CAST_REGISTRY.add((module, name, utils.maybe_half))


def register_float_function(module, name):
    if not hasattr(module, name):
        raise ValueError('No function named {} in module {}.'.format(
            name, module))
    _USER_CAST_REGISTRY.add((module, name, utils.maybe_float))


def register_promote_function(module, name):
    if not hasattr(module, name):
        raise ValueError('No function named {} in module {}.'.format(
            name, module))
    _USER_PROMOTE_REGISTRY.add((module, name))


# Top-level function to insert _all_ the hooks.
def init(enabled=True, enable_caching=True, verbose=False, allow_banned=False):
    global _DECORATOR_HANDLE
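
For reference, a minimal sketch of the two annotation styles defined above, written as user code rather than as part of this diff (fused_gelu is a hypothetical user function):

import torch
from apex import amp

@amp.half_function                             # decorator form: wrap your own function
def fused_gelu(x):
    return x * 0.5 * (1.0 + torch.erf(x / 1.41421356))

amp.register_float_function(torch, "softmax")  # registry form: annotate an existing function by name
handle = amp.init()                            # init() inserts the actual casting hooks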
242 changes: 242 additions & 0 deletions apex/amp/frontend.py
@@ -0,0 +1,242 @@
import torch
from .initialize import initialize


class Properties(object):
    """
    The purpose of this class is twofold: to establish a set of default properties,
    and to route setting of these attributes through __setattr__ so that (in theory)
    they can be checked for consistency with other existing args.
    """
    def __init__(self):
        self.options = {
            "opt_level" : None,
            "cast_model_type" : None,
            "cast_torch_functions" : False,
            "cast_batchnorm" : None,
            "master_weights" : False,
            "loss_scale" : 1.0,
            "flatten_model_params" : False,
            "flatten_master_params" : False,
            "enable_ddp_interop" : False}

    """
    This function will allow updating several options at a time without routing through
    __setattr__ checks, to avoid "you can't get there from here" scenarios.
    """
    def update_options_dict(self, new_options):
        for k, v in new_options.items():
            if k in self.options:
                self.options[k] = v
            else:
                raise ValueError("Tried to set unexpected option {}".format(k))
    """
    The members of options are not direct attributes of self, so __getattr__ is ok.
    This borrows from the logic in torch.nn.Module.
    """
    def __getattr__(self, name):
        if "options" in self.__dict__:
            options = self.__dict__["options"]
            if name in options:
                return options[name]
        raise AttributeError("'{}' object has no attribute '{}'".format(
            type(self).__name__, name))

    def __setattr__(self, name, value):
        if "options" in self.__dict__:
            if name in self.options:
                print("setting {}".format(name))
                self.options[name] = value
        else:
            super(Properties, self).__setattr__(name, value)
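
As an illustration (not part of this diff), the attribute routing above behaves as follows, given the defaults set in __init__:

props = Properties()
print(props.loss_scale)       # 1.0, read from the options dict via __getattr__
props.loss_scale = "dynamic"  # routed through __setattr__, which prints "setting loss_scale"
print(props.loss_scale)       # dynamic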


""" O0-O3 are convenience wrappers to establish defaults for typically used mixed precision options. """

class O3:
    brief = "O3: Pure FP16 training."
    more = "Calls .half() on your model, converting the entire model to FP16.\n"\
           "A casting operation is also inserted to cast incoming Tensors to FP16,\n"\
           "so you don't need to change your data pipeline.\n"\
           "This mode is useful for establishing a performance ceiling.\n"\
           "It's also possible training may 'just work' in this mode.\n"\
           "If not, try other optimization levels."

    def __call__(self, properties):
        properties.opt_level = "O3"
        properties.cast_model_type = torch.float16
        properties.cast_torch_functions = False
        properties.cast_batchnorm = False
        properties.master_weights = False
        properties.loss_scale = 1.0
        properties.flatten_model_params = False
        properties.flatten_master_params = False
        properties.enable_ddp_interop = False
        return properties  # modified in place so this isn't really necessary


class O2:
    brief = "O2: FP16 training with FP32 batchnorm and FP32 master weights.\n"
    more = "Calls .half() on your model, converting the entire model (except for batchnorms)\n"\
           "to FP16. Batchnorms are retained in FP32 for additional stability.\n"\
           "The forward pass is patched to cast incoming Tensors to FP16, so you don't need to change\n"\
           "your data pipeline.\n"\
           "O2 creates FP32 master weights outside the model and patches any optimizers to update\n"\
           "these master weights, then copy the master weights into the FP16 model weights.\n"\
           "Master weights can also improve convergence and stability."

    def __call__(self, properties):
        properties.opt_level = "O2"
        properties.cast_model_type = torch.float16
        properties.cast_torch_functions = False
        properties.cast_batchnorm = torch.float32
        properties.master_weights = True
        properties.loss_scale = 128.0
        properties.flatten_model_params = False
        properties.flatten_master_params = False
        properties.enable_ddp_interop = False
        return properties  # modified in place so this isn't really necessary


class O1:
    brief = "O1: Insert automatic casts around Pytorch functions and Tensor methods.\n"
    more = "The type of your model's weights is not altered. However, internally,\n"\
           "Pytorch functions are patched to cast any Tensor Core-friendly ops to FP16 for speed,\n"\
           "while operations that might benefit from the additional stability of FP32 are patched\n"\
           "to cast their inputs to fp32.\n"\
           "O1 is the safest way to try mixed precision training, and is recommended when\n"\
           "trying mixed precision training for the first time."

    def __call__(self, properties):
        properties.opt_level = "O1"
        properties.cast_model_type = False
        properties.cast_torch_functions = True
        properties.cast_batchnorm = False
        properties.master_weights = False
        properties.loss_scale = "dynamic"
        properties.flatten_model_params = False
        properties.flatten_master_params = False
        properties.enable_ddp_interop = False
        return properties  # modified in place so this isn't really necessary


class O0:
    brief = "O0: Pure FP32 training.\n"
    more = "Your models are checked to make sure parameters are FP32, but otherwise the\n"\
           "types of weights and internal Pytorch operations are not altered. This mode disables any\n"\
           "FP16 arithmetic, although other optimizations like parameter flattening and DDP interop\n"\
           "may still be requested.\n"

    def __call__(self, properties):
        properties.opt_level = "O0"
        properties.cast_model_type = torch.float32
        properties.cast_torch_functions = False
        properties.cast_batchnorm = False
        properties.master_weights = False
        properties.loss_scale = 1.0
        properties.flatten_model_params = False
        properties.flatten_master_params = False
        properties.enable_ddp_interop = False
        return properties  # modified in place so this isn't really necessary


opt_levels = {"O3": O3(),
              "O2": O2(),
              "O1": O1(),
              "O0": O0()}

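A quick illustration (not part of this diff) of how an opt_levels entry populates a Properties instance, and how a single option can then be overridden:

props = opt_levels["O2"](Properties())  # apply the O2 defaults
props.loss_scale = "dynamic"            # override one option afterwards
print(props.cast_batchnorm)             # torch.float32, per the O2 defaults
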
def check_params_fp32(model):
    for name, param in model.named_parameters():
        if param.type() != "torch.cuda.FloatTensor":
            print("Warning: Found param {} with type {}, expected torch.cuda.FloatTensor.\n"
                  "When using amp.register, you do not need to call .half() on your model\n"
                  "before passing it, no matter what optimization level you choose.".format(
                  name, param.type()))

    for name, param in model.named_buffers():
        if param.type() != "torch.cuda.FloatTensor":
            print("Warning: Found buffer {} with type {}, expected torch.cuda.FloatTensor.\n"
                  "When using amp.register, you do not need to call .half() on your model\n"
                  "before passing it, no matter what optimization level you choose.".format(
                  name, param.type()))


# allow user to directly pass Properties struct as well?
def register(enabled=False,
             optimizers=None,
             models=None,
             opt_level=None,
             cast_model_type=None,
             cast_torch_functions=None,
             cast_batchnorm=None,
             master_weights=None,
             loss_scale=None,
             flatten_model_params=None,
             flatten_master_params=None,
             enable_ddp_interop=None):

    if not enabled:
        return

    if opt_level not in opt_levels:
        raise RuntimeError("Unexpected optimization level. Options are 'O0', 'O1', 'O2', 'O3'.")
    else:
        amp.opt_properties = opt_levels[opt_level](Properties())
        print("Selected optimization level {}".format(opt_levels[opt_level].brief))
        print("Defaults for this optimization level are:")
        for k, v in amp.opt_properties.options.items():
            print("{:20} : {}".format(k, v))

    for model in models:
        check_params_fp32(model)

    # Gather the explicit per-option keyword arguments so user overrides can be applied below.
    kwargs = dict(cast_model_type=cast_model_type, cast_torch_functions=cast_torch_functions,
                  cast_batchnorm=cast_batchnorm, master_weights=master_weights,
                  loss_scale=loss_scale, flatten_model_params=flatten_model_params,
                  flatten_master_params=flatten_master_params, enable_ddp_interop=enable_ddp_interop)

    print("Processing user overrides (additional kwargs that are not None)...")
    for k, v in kwargs.items():
        if v is not None:
            setattr(amp.opt_properties, k, v)

    print("After processing overrides, optimization options are:")
    for k, v in amp.opt_properties.options.items():
        print("{:20} : {}".format(k, v))

    initialize(optimizers, models, amp.opt_properties)


def check_option_consistency(enabled=False,
                             opt_level=None,
                             cast_model_type=None,
                             cast_torch_functions=None,
                             cast_batchnorm=None,
                             master_weights=None,
                             loss_scale=None,
                             flatten_model_params=None,
                             flatten_master_params=None,
                             enable_ddp_interop=None):
    """
    Utility function that enables users to quickly check if the option combination they intend
    to use is permitted. ``check_option_consistency`` does not require models or optimizers
    to be constructed, and can be called at any point in the script. ``check_option_consistency``
    is totally self-contained; it does not set any amp global state or affect anything outside
    of itself.
    """

    if not enabled:
        return

    if opt_level not in opt_levels:
        raise RuntimeError("Unexpected optimization level. Options are 'O0', 'O1', 'O2', 'O3'.")
    else:
        opt_properties = opt_levels[opt_level](Properties())
        print("Selected optimization level {}".format(opt_levels[opt_level].brief))
        print("Defaults for this optimization level are:")
        for k, v in opt_properties.options.items():
            print("{:20} : {}".format(k, v))

    # Gather the explicit per-option keyword arguments so user overrides can be applied below.
    kwargs = dict(cast_model_type=cast_model_type, cast_torch_functions=cast_torch_functions,
                  cast_batchnorm=cast_batchnorm, master_weights=master_weights,
                  loss_scale=loss_scale, flatten_model_params=flatten_model_params,
                  flatten_master_params=flatten_master_params, enable_ddp_interop=enable_ddp_interop)

    print("Processing user overrides (additional kwargs that are not None)...")
    for k, v in kwargs.items():
        if v is not None:
            setattr(opt_properties, k, v)

    print("After processing overrides, optimization options are:")
    for k, v in opt_properties.options.items():
        print("{:20} : {}".format(k, v))
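
For example, assuming the kwargs handling above, a user could sanity-check an option combination before building any models or optimizers (illustration, not part of this diff):

check_option_consistency(enabled=True,
                         opt_level="O1",
                         loss_scale=128.0)  # prints the O1 defaults, then the overridden options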
57 changes: 57 additions & 0 deletions apex/amp/initialize.py
@@ -0,0 +1,57 @@
import torch
from torch._six import container_abcs, string_classes
import functools


def to_type(dtype, t):
    if not t.is_cuda:
        print("Warning: input tensor was not cuda. Call .cuda() on your data before passing it.")
    if t.requires_grad:
        print("Warning: input data requires grad. Since input data is not a model parameter,\n"
              "its gradients will not be properly allreduced by DDP.")
    if t.is_floating_point():
        return t.to(dtype)
    return t


# Modified from torch.optim.optimizer.py. This is a bit more general than casted_args in utils.py.
def applier(value, fn):
    if isinstance(value, torch.Tensor):
        return fn(value)
    elif isinstance(value, string_classes):
        return value
    elif isinstance(value, container_abcs.Mapping):
        return {applier(k, fn) : applier(v, fn) for k, v in value.items()}
    elif isinstance(value, container_abcs.Iterable):
        return type(value)(applier(v, fn) for v in value)
    else:
        return value
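
As an illustration (not part of this diff), applier recurses through nested containers, applying fn to tensors and passing strings and other leaves through unchanged:

batch = {"input": torch.randn(4, 4), "lengths": [torch.tensor([4, 3])], "tag": "train"}
halved = applier(batch, lambda t: t.half() if t.is_floating_point() else t)
# halved["input"] is now FP16; the integer tensor and the "train" string are unchanged.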


def initialize(optimizers, models, properties):

    # Stash master weights before casting the model.
    # if properties.master_weights:

    if properties.cast_model_type is not None:
        # NOTE: both branches below currently cast the full model; keeping batchnorms in
        # properties.cast_batchnorm is not implemented yet.
        if properties.cast_batchnorm is not None:
            for model in models:
                model.to(properties.cast_model_type)
        else:
            for model in models:
                model.to(properties.cast_model_type)

        caster = functools.partial(to_type, properties.cast_model_type)

        # Patch the forward method to cast incoming data to the correct type.
        def patch_forward(old_fwd):
            def new_fwd(*args, **kwargs):
                return old_fwd(*applier(args, caster),
                               **applier(kwargs, caster))
            return new_fwd

        for model in models:
            model.forward = patch_forward(model.forward)

    # State dict trick to recast any preexisting per-param state tensors
    for optimizer in optimizers:
        optimizer.load_state_dict(optimizer.state_dict())
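
A standalone sketch (not part of this diff) of the forward-patching pattern used in initialize, showing how an FP32 input would be cast on its way into an FP16 model; it assumes the to_type and applier helpers above and a CUDA device:

model = torch.nn.Linear(8, 8).cuda().half()
caster = functools.partial(to_type, torch.float16)

def patch_forward(old_fwd):
    def new_fwd(*args, **kwargs):
        return old_fwd(*applier(args, caster), **applier(kwargs, caster))
    return new_fwd

model.forward = patch_forward(model.forward)
out = model(torch.randn(2, 8).cuda())  # the FP32 input is cast to FP16 before the linear layer runs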
1 change: 1 addition & 0 deletions apex/parallel/distributed.py
@@ -322,6 +322,7 @@ def wrapper(param):
grad_acc = param_tmp.grad_fn.next_functions[0][0]

def allreduce_hook(*unused):
    print("hook fired")
    if self.delay_allreduce or self.needs_refresh:
        # TODO: How do we want to handle multiple backward passes between
        # each forward, e.g., backward passes with retain_graph=True?
