Cleaning up READMEs
definitelynotmcarilli committed Mar 4, 2019
1 parent 6066ddd commit df83b67
Showing 8 changed files with 140 additions and 144 deletions.
105 changes: 40 additions & 65 deletions README.md
@@ -1,8 +1,3 @@
# PSA: Unified API for mixed precision tools coming soon!
(as introduced by https://info.nvidia.com/webinar-mixed-precision-with-pytorch-reg-page.html).

Branch `api_refactor` is tracking my progress. Update as of 2/28: PR-ed in https://github.com/NVIDIA/apex/pull/173. I'd like to clean up the documentation a bit more before final merge.

# Introduction

This repository holds NVIDIA-maintained utilities to streamline
@@ -15,35 +10,24 @@ users as quickly as possible.

# Contents

## 1. Mixed Precision
## 1. Mixed Precision

### amp: Automatic Mixed Precision

`apex.amp` is a tool designed for ease of use and maximum safety in FP16 training. All potentially unsafe ops are performed in FP32 under the hood, while safe ops are performed using faster, Tensor Core-friendly FP16 math. `amp` also automatically implements dynamic loss scaling.

The intention of `amp` is to be the "on-ramp" to easy FP16 training: achieve all the numerical stability of full FP32 training, with most of the performance benefits of full FP16 training.

[Python Source and API Documentation](https://github.com/NVIDIA/apex/tree/master/apex/amp)

### FP16_Optimizer

`apex.FP16_Optimizer` wraps an existing Python optimizer and automatically implements master parameters and static or dynamic loss scaling under the hood.

The intention of `FP16_Optimizer` is to be the "highway" for FP16 training: achieve most of the numerical stability of full FP32 training, and almost all the performance benefits of full FP16 training.

[API Documentation](https://nvidia.github.io/apex/fp16_utils.html#automatic-management-of-master-params-loss-scaling)

[Python Source](https://github.com/NVIDIA/apex/tree/master/apex/fp16_utils)
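
For reference, a minimal sketch of the deprecated `FP16_Optimizer` wrapper described above (the toy model, optimizer, and hyperparameters are placeholders, not taken from this repo's examples):

```python
import torch
from apex.fp16_utils import FP16_Optimizer

model = torch.nn.Linear(512, 512).cuda().half()            # FP16 model
optimizer = torch.optim.SGD(model.parameters(), lr=1e-3)   # ordinary Pytorch optimizer
optimizer = FP16_Optimizer(optimizer, dynamic_loss_scale=True)

data = torch.randn(64, 512, device="cuda", dtype=torch.half)
loss = model(data).float().sum()

optimizer.zero_grad()
optimizer.backward(loss)   # replaces loss.backward(); applies loss scaling
optimizer.step()           # copies FP32 master params back into the FP16 model
```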
`apex.amp` is a tool to enable mixed precision training by changing only 3 lines of your script.
Users can easily experiment with different pure and mixed precision training modes by supplying
different flags to `amp.initialize`.
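
For example, a minimal sketch of those three lines (the toy model, optimizer, and `opt_level` below are placeholders):

```python
import torch
from apex import amp

model = torch.nn.Linear(512, 512).cuda()
optimizer = torch.optim.SGD(model.parameters(), lr=1e-3)

# Line 1: wrap model and optimizer, choosing a precision mode via opt_level.
model, optimizer = amp.initialize(model, optimizer, opt_level="O1")

loss = model(torch.randn(64, 512, device="cuda")).sum()

# Lines 2-3: scale the loss for backward so FP16 gradients don't underflow.
with amp.scale_loss(loss, optimizer) as scaled_loss:
    scaled_loss.backward()
optimizer.step()
```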

[Simple examples with FP16_Optimizer](https://github.com/NVIDIA/apex/tree/master/examples/FP16_Optimizer_simple)
[Webinar introducing Amp](https://info.nvidia.com/webinar-mixed-precision-with-pytorch-reg-page.html)
(The flag `cast_batchnorm` has been renamed to `keep_batchnorm_fp32`).

[Imagenet with FP16_Optimizer](https://github.com/NVIDIA/apex/tree/master/examples/imagenet)
[API Documentation](https://nvidia.github.io/apex/amp.html)

[word_language_model with FP16_Optimizer](https://github.com/NVIDIA/apex/tree/master/examples/word_language_model)
[Comprehensive Imagenet example](https://github.com/NVIDIA/apex/tree/master/examples/imagenet)

The Imagenet and word_language_model directories also contain examples that show manual management of master parameters and static loss scaling.
[DCGAN example coming soon...](https://github.com/NVIDIA/apex/tree/master/examples/dcgan)

These manual examples illustrate what sort of operations `amp` and `FP16_Optimizer` are performing automatically.
[Moving to the new Amp API] (for users of the deprecated tools formerly called "Amp" and "FP16_Optimizer")

## 2. Distributed Training

@@ -57,69 +41,60 @@ optimized for NVIDIA's NCCL communication library.

[Example/Walkthrough](https://github.com/NVIDIA/apex/tree/master/examples/distributed)

The [Imagenet with FP16_Optimizer](https://github.com/NVIDIA/apex/tree/master/examples/imagenet)
mixed precision examples also demonstrate `apex.parallel.DistributedDataParallel`.
The [Imagenet example](https://github.com/NVIDIA/apex/tree/master/examples/imagenet)
shows use of `apex.parallel.DistributedDataParallel` along with `apex.amp`.
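
As a minimal sketch (assuming `torch.distributed.init_process_group` has already been called with the NCCL backend and each process has set its own device):

```python
import torch
import apex

# One GPU per process: torch.cuda.set_device(local_rank) and
# torch.distributed.init_process_group(backend="nccl", ...) are assumed to have run already.
model = torch.nn.Linear(512, 512).cuda()
model = apex.parallel.DistributedDataParallel(model)  # gradients are allreduced during backward()
```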

### Synchronized Batch Normalization

`apex.parallel.SyncBatchNorm` extends `torch.nn.modules.batchnorm._BatchNorm` to
support synchronized BN.
It reduces stats across processes during multiprocess distributed data parallel
training.
Synchronous Batch Normalization has been used in cases where only a very small
number of samples per mini-batch could fit on each GPU.
All-reduced stats boost the effective batch size for the sync BN layer to the
total batch size across all processes.
It has improved the converged accuracy in some of our research models.
It allreduces stats across processes during multiprocess (DistributedDataParallel) training.
Synchronous BN has been used in cases where only a small
local minibatch can fit on each GPU.
Allreduced stats increase the effective batch size for the BN layer to the
global batch size across all processes (which, technically, is the correct
formulation).
Synchronous BN has been observed to improve converged accuracy in some of our research models.
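
A minimal sketch of enabling it on an existing model (layer sizes are placeholders; `apex.parallel.convert_syncbn_model` recursively replaces the model's BatchNorm layers):

```python
import torch
import apex

# Build an ordinary model, then swap its BatchNorm layers for SyncBatchNorm.
# Assumes multiprocess DistributedDataParallel training is already set up.
model = torch.nn.Sequential(
    torch.nn.Conv2d(3, 16, 3),
    torch.nn.BatchNorm2d(16),
    torch.nn.ReLU(),
).cuda()

model = apex.parallel.convert_syncbn_model(model)     # BN stats will be allreduced across processes
model = apex.parallel.DistributedDataParallel(model)
```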

# Requirements

Python 3

CUDA 9 or 10
CUDA 9 or newer

PyTorch 0.4 or newer. We recommend using the latest stable release, obtainable from
[https://pytorch.org/](https://pytorch.org/). We also test against the latest master branch, obtainable from [https://github.com/pytorch/pytorch](https://github.com/pytorch/pytorch).
If you have any problems building, please file an issue.

The cpp and cuda extensions require pytorch 1.0 or newer.
PyTorch 0.4 or newer. The CUDA and C++ extensions require pytorch 1.0 or newer.

We recommend the latest stable release, obtainable from
[https://pytorch.org/](https://pytorch.org/). We also test against the latest master branch, obtainable from [https://github.com/pytorch/pytorch](https://github.com/pytorch/pytorch).

It's often convenient to use Apex in Docker containers. Compatible options include:
* [NVIDIA Pytorch containers from NGC](https://ngc.nvidia.com/catalog/containers/nvidia%2Fpytorch), which come with Apex preinstalled. To use the latest Amp API, you may need to `pip uninstall apex` then reinstall Apex using the **Quick Start** commands below.
* [official Pytorch -devel Dockerfiles](https://hub.docker.com/r/pytorch/pytorch/tags), e.g. `docker pull pytorch/pytorch:nightly-devel-cuda10.0-cudnn7`, in which you can install Apex using the **Quick Start** commands.

# Quick Start

### Linux
To build the extension run
```
python setup.py install
```
in the root directory of the cloned repository.

To use the extension
For performance and full functionality, we recommend installing Apex with
CUDA and C++ extensions via
```
import apex
$ git clone https://github.com/NVIDIA/apex
$ cd apex
$ pip install -v --no-cache-dir --global-option="--cpp_ext" --global-option="--cuda_ext" .
```

### CUDA/C++ extension
Apex contains optional CUDA/C++ extensions, installable via
Apex also supports a Python-only build (required with Pytorch 0.4) via
```
python setup.py install [--cuda_ext] [--cpp_ext]
$ pip install -v --no-cache-dir .
```
Currently, `--cuda_ext` enables
- Fused kernels that improve the performance and numerical stability of `apex.parallel.SyncBatchNorm`.
A Python-only build omits:
- Fused kernels required to use `apex.optimizers.FusedAdam`.
- Fused kernels required to use `apex.normalization.FusedLayerNorm`.

`--cpp_ext` enables
- C++-side flattening and unflattening utilities that reduce the CPU overhead of `apex.parallel.DistributedDataParallel`.
- Fused kernels that improve the performance and numerical stability of `apex.parallel.SyncBatchNorm`.
- Fused kernels that improve the performance of `apex.parallel.DistributedDataParallel` and `apex.amp`.
`DistributedDataParallel`, `amp`, and `SyncBatchNorm` will still be usable, but they may be slower.
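
For illustration, the two modules above that strictly need the compiled extensions (sizes and hyperparameters are placeholders):

```python
import torch
import apex

# These require an Apex install built with the CUDA/C++ extensions;
# a Python-only build will raise an error when they are constructed or used.
model = torch.nn.Linear(1024, 1024).cuda().half()
norm = apex.normalization.FusedLayerNorm(1024).cuda()
optimizer = apex.optimizers.FusedAdam(model.parameters(), lr=1e-3)
```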

### Windows support
Windows support is experimental, and Linux is recommended. However, since Apex can be built Python-only, there's a good chance the Python-only features "just work" the same way as on Linux. If you installed Pytorch in a Conda environment, make sure to install Apex in that same environment.

<!--
reparametrization and RNN API under construction
Current version of apex contains:
3. Reparameterization function that allows you to recursively apply reparameterization to an entire module (including children modules).
4. An experimental and in development flexible RNN API.
-->
Windows support is experimental, and Linux is recommended. `python setup.py install --cpp_ext --cuda_ext` may work if you were able to build Pytorch from source
on your system. `python setup.py install` (without CUDA/C++ extensions) is more likely to work. If you installed Pytorch in a Conda environment,
make sure to install Apex in that same environment.
4 changes: 2 additions & 2 deletions apex/amp/README.md
@@ -1,8 +1,8 @@
# amp: Automatic Mixed Precision

## This README documents the legacy (pre-Amp 1.0) API.
## This README documents the deprecated (pre-unified) API.

## Documentation for the new 1.0 API can be found [here](https://nvidia.github.io/apex/)
## Documentation for the current unified API can be found [here](https://nvidia.github.io/apex/)

amp is an experimental tool to enable mixed precision training in
PyTorch with extreme simplicity and overall numerical safety. It
10 changes: 9 additions & 1 deletion apex/amp/_amp_state.py
@@ -3,7 +3,15 @@
# But apparently it's ok:
# http://effbot.org/pyfaq/how-do-i-share-global-variables-across-modules.htm
class AmpState(object):
    pass
    def __init__(self):
        self.hard_override=False

# Attribute stash. Could also just stash things as global module attributes.
_amp_state = AmpState()

def warn_or_err(msg):
    if _amp_state.hard_override:
        print("Warning: " + msg)
    else:
        raise RuntimeError(msg + " If you're sure you know what you're doing, supply " +
                           "hard_override=True to amp.initialize.")
28 changes: 15 additions & 13 deletions apex/amp/_initialize.py
@@ -1,7 +1,7 @@
import torch
from torch._six import container_abcs, string_classes
import functools
from ._amp_state import _amp_state
from ._amp_state import _amp_state, warn_or_err
from .handle import disable_casts
from .scaler import LossScaler
from apex.fp16_utils import convert_network
@@ -13,10 +13,12 @@

def to_type(dtype, t):
    if not t.is_cuda:
        print("Warning: input tensor was not cuda. Call .cuda() on your data before passing it.")
        # This should not be a hard error, since it may be legitimate.
        print("Warning: An input tensor was not cuda. ")
    if t.requires_grad:
        print("Warning: input data requires grad. Since input data is not a model parameter,\n"
              "its gradients will not be properly allreduced by DDP.")
        # This should be a hard-ish error.
        warn_or_err("input data requires grad. Since input data is not a model parameter,\n"
                    "its gradients will not be properly allreduced by DDP.")
    if t.is_floating_point():
        return t.to(dtype)
    return t
@@ -55,17 +57,17 @@ def check_params_fp32(models):
    for model in models:
        for name, param in model.named_parameters():
            if param.is_floating_point() and param.type() != "torch.cuda.FloatTensor":
                print("Warning: Found param {} with type {}, expected torch.cuda.FloatTensor.\n"
                      "When using amp.initialize, you do not need to call .half() on your model\n"
                      "before passing it, no matter what optimization level you choose.".format(
                      name, param.type()))
                warn_or_err("Found param {} with type {}, expected torch.cuda.FloatTensor.\n"
                            "When using amp.initialize, you do not need to call .half() on your model\n"
                            "before passing it, no matter what optimization level you choose.".format(
                            name, param.type()))

        for name, buf in model.named_buffers():
            if buf.is_floating_point() and buf.type() != "torch.cuda.FloatTensor":
                print("Warning: Found buffer {} with type {}, expected torch.cuda.FloatTensor.\n"
                      "When using amp.initialize, you do not need to call .half() on your model\n"
                      "before passing it, no matter what optimization level you choose.".format(
                      name, buf.type()))
                warn_or_err("Found buffer {} with type {}, expected torch.cuda.FloatTensor.\n"
                            "When using amp.initialize, you do not need to call .half() on your model\n"
                            "before passing it, no matter what optimization level you choose.".format(
                            name, buf.type()))


def check_optimizers(optimizers):
@@ -77,7 +79,7 @@ def check_optimizers(optimizers):
            bad_optim_type = "apex.optimizers.FP16_Optimizer"
        if bad_optim_type is not None:
            raise RuntimeError("An incoming optimizer is an instance of {}. ".format(optim_type) +
                               "The optimizer(s) passed to amp.initialize() should be bare \n"
                               "The optimizer(s) passed to amp.initialize() must be bare \n"
                               "instances of either ordinary Pytorch optimizers, or Apex fused \n"
                               "optimizers (currently just FusedAdam, but FusedSGD will be added \n"
                               "soon). You should not manually wrap your optimizer in either \n"
97 changes: 43 additions & 54 deletions apex/amp/frontend.py
@@ -1,6 +1,6 @@
import torch
from ._initialize import _initialize
from ._amp_state import _amp_state
from ._amp_state import _amp_state, warn_or_err


class Properties(object):
@@ -165,21 +165,6 @@ def __call__(self, properties):
"O1": O1(),
"O0": O0()}

def check_params_fp32(model):
    for name, param in model.named_parameters():
        if param.type() != "torch.cuda.FloatTensor":
            print("Warning: Found param {} with type {}, expected torch.cuda.FloatTensor.\n"
                  "When using amp.initialize, you do not need to call .half() on your model\n"
                  "before passing it, no matter what optimization level you choose.".format(
                  name, param.type()))

    for name, param in model.named_buffers():
        if param.type() != "torch.cuda.FloatTensor":
            print("Warning: Found buffer {} with type {}, expected torch.cuda.FloatTensor.\n"
                  "When using amp.initialize, you do not need to call .half() on your model\n"
                  "before passing it, no matter what optimization level you choose.".format(
                  name, param.type()))


# allow user to directly pass Properties struct as well?
def initialize(models, optimizers, enabled=True, opt_level=None, **kwargs):
@@ -193,6 +178,8 @@ def initialize(models, optimizers, enabled=True, opt_level=None, **kwargs):
        loss_scale=None,)
    """
    if not enabled:
        if "hard_override" in kwargs:
            _amp_state.hard_override = kwargs["hard_override"]
        _amp_state.opt_properties = Properties()
        return models, optimizers

@@ -222,41 +209,43 @@ def initialize(models, optimizers, enabled=True, opt_level=None, **kwargs):
    return _initialize(models, optimizers, _amp_state.opt_properties)
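
For example, a sketch of overriding individual properties on top of an `opt_level` (the model and values are placeholders; the recognized kwargs mirror the signature commented out below):

```python
import torch
from apex import amp

model = torch.nn.Conv2d(3, 16, 3).cuda()
optimizer = torch.optim.SGD(model.parameters(), lr=1e-3)

# opt_level picks a set of property defaults; any explicit kwarg that is not None
# overrides the corresponding property for this run.
model, optimizer = amp.initialize(model, optimizer,
                                  opt_level="O2",
                                  keep_batchnorm_fp32=True,
                                  loss_scale=128.0)
```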


def check_option_consistency(enabled=True,
                             opt_level=None,
                             cast_model_type=None,
                             patch_torch_functions=None,
                             keep_batchnorm_fp32=None,
                             master_weights=None,
                             loss_scale=None,
                             enable_ddp_interop=None):
    """
    Utility function that enables users to quickly check if the option combination they intend
    to use is permitted. ``check_option_consistency`` does not require models or optimizers
    to be constructed, and can be called at any point in the script. ``check_option_consistency``
    is totally self-contained; it does not set any amp global state or affect anything outside
    of itself.
    """

    if not enabled:
        return

    if opt_level not in opt_levels:
        raise RuntimeError("Unexpected optimization level. Options are 'O0', 'O1', 'O2', 'O3'.")
    else:
        opt_properties = opt_levels[opt_level](Properties())
        print("Selected optimization level {}", opt_levels[opt_level].brief)
        print("Defaults for this optimization level are:")
        for k, v in opt_properties.options:
            print("{:22} : {}".format(k, v))

        print("Processing user overrides (additional kwargs that are not None)...")
        for k, v in kwargs:
            if k not in amp_state.opt_properties.options:
                raise RuntimeError("Unexpected kwarg {}".format(k))
            if v is not None:
                setattr(opt_properties, k, v)

        print("After processing overrides, optimization options are:")
        for k, v in opt_properties.options:
            print("{:22} : {}".format(k, v))
# TODO: is this necessary/useful?
# def check_option_consistency(enabled=True,
#                              opt_level=None,
#                              cast_model_type=None,
#                              patch_torch_functions=None,
#                              keep_batchnorm_fp32=None,
#                              master_weights=None,
#                              loss_scale=None,
#                              enable_ddp_interop=None,
#                              hard_override=False):
#     """
#     Utility function that enables users to quickly check if the option combination they intend
#     to use is permitted. ``check_option_consistency`` does not require models or optimizers
#     to be constructed, and can be called at any point in the script. ``check_option_consistency``
#     is totally self-contained; it does not set any amp global state or affect anything outside
#     of itself.
#     """
#
#     if not enabled:
#         return
#
#     if opt_level not in opt_levels:
#         raise RuntimeError("Unexpected optimization level. Options are 'O0', 'O1', 'O2', 'O3'.")
#     else:
#         opt_properties = opt_levels[opt_level](Properties())
#         print("Selected optimization level {}", opt_levels[opt_level].brief)
#         print("Defaults for this optimization level are:")
#         for k, v in opt_properties.options:
#             print("{:22} : {}".format(k, v))
#
#         print("Processing user overrides (additional kwargs that are not None)...")
#         for k, v in kwargs:
#             if k not in _amp_state.opt_properties.options:
#                 raise RuntimeError("Unexpected kwarg {}".format(k))
#             if v is not None:
#                 setattr(opt_properties, k, v)
#
#         print("After processing overrides, optimization options are:")
#         for k, v in opt_properties.options:
#             print("{:22} : {}".format(k, v))
1 change: 1 addition & 0 deletions examples/dcgan/README.md
@@ -0,0 +1 @@
Under construction...