PyTorch AMP (facebookresearch#102)
Summary:
Pull Request resolved: facebookresearch/vissl#102

Pull Request resolved: facebookresearch#666

Add PyTorch AMP support. A follow-up for FairScale would be to add ShardedGradScaler, so that we support mixed precision with ShardedDDP and ShardedOptimizer.
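For context, the torch.cuda.amp pattern that this change wires into the classification task is roughly the following. This is a minimal, self-contained sketch of generic PyTorch AMP usage, not code taken from this diff; the tiny model, optimizer, and data are placeholders and a CUDA device is assumed.

    import torch
    from torch import nn
    from torch.cuda.amp import GradScaler, autocast

    # Placeholder model, optimizer, and data, only to make the sketch runnable.
    model = nn.Linear(8, 2).cuda()
    optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
    criterion = nn.CrossEntropyLoss()
    inputs = torch.randn(4, 8, device="cuda")
    targets = torch.randint(0, 2, (4,), device="cuda")

    scaler = GradScaler()

    for _ in range(3):
        optimizer.zero_grad()
        # Run the forward pass and loss computation in mixed precision.
        with autocast():
            loss = criterion(model(inputs), targets)
        # Scale the loss so fp16 gradients do not underflow, then backprop.
        scaler.scale(loss).backward()
        # Unscale and step; the step is skipped if inf/NaN gradients are found.
        scaler.step(optimizer)
        # Adjust the loss scale for the next iteration.
        scaler.update()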

Reviewed By: mannatsingh, prigoyal

Differential Revision: D25383305

fbshipit-source-id: fe3be9c850d4aa6e32c48144b04b42832eaa67f8
blefaudeux authored and facebook-github-bot committed Dec 10, 2020
1 parent f9e1a2f commit ff37fea
Showing 9 changed files with 110 additions and 32 deletions.
classy_train.py: 2 changes (1 addition, 1 deletion)
@@ -46,7 +46,7 @@
from classy_vision.generic.distributed_util import get_rank, get_world_size
from classy_vision.generic.opts import check_generic_args, parse_train_arguments
from classy_vision.generic.registry_utils import import_all_packages_from_directory
from classy_vision.generic.util import load_checkpoint, load_json
from classy_vision.generic.util import load_json
from classy_vision.hooks import (
CheckpointHook,
LossLrMeterLoggingHook,
classy_vision/generic/opts.py: 1 change (0 additions, 1 deletion)
@@ -7,7 +7,6 @@
import argparse
import os

import torch
from classy_vision.generic.util import is_pos_int


classy_vision/generic/util.py: 2 changes (1 addition, 1 deletion)
@@ -98,7 +98,7 @@ def is_leaf(module: nn.Module) -> bool:
Returns True if module is leaf in the graph.
"""
assert isinstance(module, nn.Module), "module should be nn.Module"
return len([c for c in module.children()]) == 0 or hasattr(module, "_mask")
return len(list(module.children())) == 0 or hasattr(module, "_mask")


def is_on_gpu(model: torch.nn.Module) -> bool:
classy_vision/meters/accuracy_meter.py: 7 changes (6 additions, 1 deletion)
@@ -137,7 +137,12 @@ def update(self, model_output, target, **kwargs):
# Convert target to 0/1 encoding if isn't
target = maybe_convert_to_one_hot(target, model_output)

_, pred = model_output.topk(max(self._topk), dim=1, largest=True, sorted=True)
# If Pytorch AMP is being used, model outputs are probably fp16
# Since .topk() is not compatible with fp16, we promote the model outputs to full precision
_, pred = model_output.float().topk(
max(self._topk), dim=1, largest=True, sorted=True
)

for i, k in enumerate(self._topk):
self._curr_correct_predictions_k[i] += (
torch.gather(target, dim=1, index=pred[:, :k])
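As a standalone illustration of the cast added above (hypothetical tensors, not part of the commit): model outputs produced under autocast may be torch.float16, and promoting them to float32 keeps topk usable on builds where half-precision topk is unsupported.

    import torch

    # Half-precision logits, as autocast may produce for a model output.
    logits = torch.randn(4, 10).half()
    # Promote to float32 before topk, mirroring the accuracy meter change.
    _, pred = logits.float().topk(3, dim=1, largest=True, sorted=True)
    print(pred.shape)  # torch.Size([4, 3])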
classy_vision/optim/classy_optimizer.py: 2 changes (0 additions, 2 deletions)
@@ -7,8 +7,6 @@
from abc import ABC, abstractmethod
from typing import Any, Callable, Dict, Optional

import torch

from .param_scheduler import (
ClassyParamScheduler,
ConstantParamScheduler,
classy_vision/tasks/classification_task.py: 107 changes (85 additions, 22 deletions)
@@ -55,6 +55,18 @@
except ImportError:
apex_available = False

try:
from torch.cuda.amp import GradScaler as TorchGradScaler

except ImportError:
pass


class AmpType(enum.Enum):
# Automatic Mixed Precision supported types
APEX = enum.auto()
PYTORCH = enum.auto()


class BroadcastBuffersMode(enum.Enum):
DISABLED = enum.auto()
@@ -162,6 +174,8 @@ def __init__(self):
BroadcastBuffersMode.BEFORE_EVAL
)
self.amp_args = None
self.amp_type = None
self.amp_grad_scaler = None
self.mixup_transform = None
self.perf_log = []
self.last_batch = None
@@ -422,8 +436,24 @@ def set_amp_args(self, amp_args: Optional[Dict[str, Any]]):
if amp_args is None:
logging.info("AMP disabled")
else:
if not apex_available:
raise RuntimeError("apex is not installed, cannot enable amp")
# Check that the requested AMP type is known
try:
self.amp_type = AmpType[self.amp_args["amp_type"].upper()]
except KeyError:
logging.info("AMP type not specified, defaulting to Apex")
self.amp_type = AmpType.APEX

# Check for CUDA availability, required for both Apex and Pytorch AMP
if not torch.cuda.is_available():
raise RuntimeError(
"AMP is required but CUDA is not supported, cannot enable AMP"
)

# Check for Apex availability
if self.amp_type == AmpType.APEX and not apex_available:
raise RuntimeError(
"Apex AMP is required but Apex is not installed, cannot enable AMP"
)

logging.info(f"AMP enabled with args {amp_args}")
return self
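Concretely, the AMP backend is chosen through the task config's amp_args; these fragments mirror the test cases added further down in this diff (the surrounding config dict is assumed to be a standard task config).

    # Apex AMP (the default when "amp_type" is omitted); the remaining keys
    # are forwarded to apex.amp.initialize in prepare().
    config["amp_args"] = {"opt_level": "O1"}

    # PyTorch native AMP, handled with torch.cuda.amp autocast and GradScaler.
    config["amp_args"] = {"amp_type": "pytorch"}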
@@ -701,19 +731,21 @@ def prepare(self):
)

if self.amp_args is not None:
# Initialize apex.amp. This updates the model and the PyTorch optimizer (
# if training, which is wrapped by the ClassyOptimizer in self.optimizer).
# Please note this must happen before loading the checkpoint, cause
# there's amp state to be restored.

if self.optimizer is None:
self.base_model = apex.amp.initialize(
self.base_model, optimizers=None, **self.amp_args
)
else:
self.base_model, self.optimizer.optimizer = apex.amp.initialize(
self.base_model, self.optimizer.optimizer, **self.amp_args
)
if self.amp_type == AmpType.APEX:
# Initialize apex.amp. This updates the model and the PyTorch optimizer (
# if training, which is wrapped by the ClassyOptimizer in self.optimizer).
# Please note this must happen before loading the checkpoint, cause
# there's amp state to be restored.
if self.optimizer is None:
self.base_model = apex.amp.initialize(
self.base_model, optimizers=None, **self.amp_args
)
else:
self.base_model, self.optimizer.optimizer = apex.amp.initialize(
self.base_model, self.optimizer.optimizer, **self.amp_args
)
elif self.amp_type == AmpType.PYTORCH:
self.amp_grad_scaler = TorchGradScaler()

if self.simulated_global_batchsize is not None:
if self.simulated_global_batchsize % self.get_global_batchsize() != 0:
@@ -836,7 +868,11 @@ def get_classy_state(self, deep_copy: bool = False):
if isinstance(self.base_loss, ClassyLoss):
classy_state_dict["loss"] = self.base_loss.get_classy_state()
if self.amp_args is not None:
classy_state_dict["amp"] = apex.amp.state_dict()
classy_state_dict["amp"] = (
apex.amp.state_dict()
if self.amp_type == AmpType.APEX
else self.amp_grad_scaler.state_dict()
)
if deep_copy:
classy_state_dict = copy.deepcopy(classy_state_dict)
return classy_state_dict
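For reference, the GradScaler state saved here round-trips like any other state dict; a minimal sketch (hypothetical, not taken from this diff):

    from torch.cuda.amp import GradScaler

    scaler = GradScaler()
    checkpoint = {"amp": scaler.state_dict()}    # stored alongside model/optimizer state

    restored = GradScaler()
    restored.load_state_dict(checkpoint["amp"])  # restores the scale factor and growth tracking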
@@ -864,7 +900,10 @@ def set_classy_state(self, state):
self.base_loss.set_classy_state(state["loss"])

if "amp" in state:
apex.amp.load_state_dict(state["amp"])
if self.amp_type == AmpType.APEX:
apex.amp.load_state_dict(state["amp"])
else:
self.amp_grad_scaler.load_state_dict(state["amp"])

for hook in self.hooks:
# we still want to be able to run when new hooks are added or old
@@ -901,7 +940,14 @@ def eval_step(self):
if self.use_gpu:
sample = recursive_copy_to_gpu(sample, non_blocking=True)

with torch.no_grad():
# Optional Pytorch AMP context
torch_amp_context = (
torch.cuda.amp.autocast()
if self.amp_type == AmpType.PYTORCH
else contextlib.suppress()
)

with torch.no_grad(), torch_amp_context:
output = self.model(sample["input"])

local_loss = self.compute_loss(output, sample)
@@ -949,8 +995,15 @@ def train_step(self):
if self.mixup_transform is not None:
sample = self.mixup_transform(sample)

with torch.enable_grad():
# Forward pass
# Optional Pytorch AMP context
torch_amp_context = (
torch.cuda.amp.autocast()
if self.amp_type == AmpType.PYTORCH
else contextlib.suppress()
)

# Forward pass
with torch.enable_grad(), torch_amp_context:
output = self.model(sample["input"])

local_loss = self.compute_loss(output, sample)
@@ -1004,21 +1057,31 @@ def run_optimizer(self, loss):
)

with ctx_mgr_model, ctx_mgr_loss:
if self.amp_args is not None:
if self.amp_type == AmpType.APEX:
with apex.amp.scale_loss(loss, self.optimizer.optimizer) as scaled_loss:
scaled_loss.backward()
elif self.amp_type == AmpType.PYTORCH:
self.amp_grad_scaler.scale(loss).backward()
else:
loss.backward()

if do_step:
# Handle gradient accumulation related gradient rescaling
if self.optimizer_period != 1:
self._rescale_gradients(1 / self.optimizer_period)

# Clipping must happen after grad accumulation
if self.clip_grad_norm is not None:
self._clip_gradients(self.clip_grad_norm)

self.optimizer.step(where=self.where)
if self.amp_type == AmpType.PYTORCH:
# If using mixed precision, handle underflow-related scaling
# See https://pytorch.org/docs/stable/amp.html#gradient-scaling
# for context
self.amp_grad_scaler.step(self.optimizer, where=self.where)
self.amp_grad_scaler.update()
else:
self.optimizer.step(where=self.where)

def _rescale_gradients(self, scale):
for param in master_params(self.optimizer):
classy_vision/tasks/classy_task.py: 2 changes (1 addition, 1 deletion)
@@ -6,7 +6,7 @@


from abc import ABC, abstractmethod
from typing import Any, Dict, Optional
from typing import Any, Dict


class ClassyTask(ABC):
classy_vision/trainer/classy_trainer.py: 2 changes (0 additions, 2 deletions)
@@ -4,9 +4,7 @@
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

from typing import Optional

import torch
from classy_vision.generic.distributed_util import barrier
from classy_vision.tasks import ClassyTask

test/manual/tasks_classification_task_amp_test.py: 17 changes (16 additions, 1 deletion)
@@ -14,24 +14,39 @@


class TestClassificationTaskAMP(unittest.TestCase):
@unittest.skipUnless(torch.cuda.is_available(), "This test needs a gpu to run")
def test_build_task(self):
config = get_test_task_config()
task = build_task(config)
self.assertTrue(isinstance(task, ClassificationTask))
# check that AMP is disabled by default
self.assertIsNone(task.amp_args)

# test a valid AMP opt level
# test a valid APEX AMP opt level
config = copy.deepcopy(config)
config["amp_args"] = {"opt_level": "O1"}
task = build_task(config)
self.assertTrue(isinstance(task, ClassificationTask))

# test a valid Pytorch AMP
config = copy.deepcopy(config)
config["amp_args"] = {"amp_type": "pytorch"}
task = build_task(config)
self.assertTrue(isinstance(task, ClassificationTask))

@unittest.skipUnless(torch.cuda.is_available(), "This test needs a gpu to run")
def test_training(self):
# Test an Apex AMP training
config = get_fast_test_task_config()
config["amp_args"] = {"opt_level": "O2"}
task = build_task(config)
task.set_use_gpu(True)
trainer = LocalTrainer()
trainer.train(task)

# Test a Pytorch AMP training
config["amp_args"] = {"amp_type": "pytorch"}
task = build_task(config)
task.set_use_gpu(True)
trainer = LocalTrainer()
trainer.train(task)
