Implement gradient clipping (facebookresearch#643)
Summary:
Pull Request resolved: facebookresearch#643

Add support for gradient clipping in ClassificationTask

Reviewed By: mannatsingh

Differential Revision: D24736675

fbshipit-source-id: 9ed5c7a26f1708a81cf0d61f052629e1ff093983
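
For context, a minimal usage sketch of the new option. The same value can also be supplied through the "clip_grad_norm" config key, which from_config forwards to this setter, as the diff below shows.

    from classy_vision.tasks import ClassificationTask

    # Clip gradients to a maximum L2 norm of 0.5 before each optimizer step;
    # passing None restores the default (clipping disabled).
    task = ClassificationTask().set_clip_grad_norm(0.5)
    task.set_clip_grad_norm(None)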
vreis authored and facebook-github-bot committed Nov 6, 2020
1 parent e3ac96c commit 7aaf5b0
Showing 2 changed files with 107 additions and 2 deletions.
classy_vision/tasks/classification_task.py (32 changes: 32 additions & 0 deletions)
@@ -123,6 +123,7 @@ class ClassificationTask(ClassyTask):
:var data_iterator: Iterator which can be used to obtain batches
:var losses: Loss curve
:var perf_log: list of training speed measurements, to be logged
:var clip_grad_norm: maximum gradient norm (default None)
"""

def __init__(self):
@@ -165,6 +166,7 @@ def __init__(self):
self.dataloader_mp_context = "spawn"
self.bn_weight_decay = False
self._train_only = True
self.clip_grad_norm = None

def set_use_gpu(self, use_gpu: bool):
self.use_gpu = use_gpu
@@ -175,6 +177,19 @@ def set_use_gpu(self, use_gpu: bool):

return self

def set_clip_grad_norm(self, clip_grad_norm: Optional[float]):
"""Sets maximum gradient norm.
None means gradient clipping is disabled. Defaults to None."""
self.clip_grad_norm = clip_grad_norm
if clip_grad_norm is None:
logging.info("Disabled gradient norm clipping.")
else:
logging.info(
f"Enabled gradient norm clipping with threshold: {clip_grad_norm}"
)
return self

def set_checkpoint(self, checkpoint_path: str):
"""Sets checkpoint on task.
@@ -489,6 +504,7 @@ def from_config(cls, config: Dict[str, Any]) -> "ClassificationTask":
.set_distributed_options(**distributed_options)
.set_hooks(hooks)
.set_bn_weight_decay(config.get("bn_weight_decay", False))
.set_clip_grad_norm(config.get("clip_grad_norm"))
)

if not test_only:
@@ -934,10 +950,26 @@ def run_optimizer(self, loss):
else:
self.optimizer.backward(loss)

if self.clip_grad_norm is not None:
self._clip_gradients(self.clip_grad_norm)

self.check_inf_nan(loss)

self.optimizer.step(where=self.where)

def _clip_gradients(self, max_norm):
def all_params(optimizer):
for group in optimizer.param_groups:
for p in group["params"]:
yield p

if self.amp_args is not None:
params_iter = apex.amp.master_params(self.optimizer)
else:
params_iter = all_params(self.optimizer)

nn.utils.clip_grad_norm_(params_iter, max_norm)

def update_meters(self, model_output, sample):
target = sample["target"].detach().cpu()
model_output = model_output.detach().cpu()
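For reference, the clipping performed in _clip_gradients above is torch.nn.utils.clip_grad_norm_ applied to the optimizer's parameters (or to apex.amp.master_params when AMP is enabled). A minimal standalone sketch of the same clip-then-step pattern, outside the task abstraction:

    import torch
    import torch.nn as nn

    # Toy model and optimizer; only the clip-then-step pattern matters here.
    model = nn.Linear(4, 1)
    optimizer = torch.optim.SGD(model.parameters(), lr=1.0)

    loss = model(torch.randn(8, 4)).pow(2).mean()
    loss.backward()

    # clip_grad_norm_ rescales all gradients in place so their combined L2 norm
    # is at most max_norm, and returns the norm measured before clipping.
    total_norm = nn.utils.clip_grad_norm_(model.parameters(), max_norm=0.5)
    optimizer.step()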
test/tasks_classification_task_test.py (77 changes: 75 additions & 2 deletions)
@@ -5,6 +5,7 @@
# LICENSE file in the root directory of this source tree.

import copy
import itertools
import shutil
import tempfile
import unittest
@@ -17,13 +18,14 @@
)

import torch
import torch.nn as nn
from classy_vision.dataset import build_dataset
from classy_vision.generic.distributed_util import is_distributed_training_run
from classy_vision.generic.util import get_checkpoint_dict
from classy_vision.hooks import CheckpointHook, LossLrMeterLoggingHook
from classy_vision.losses import ClassyLoss, build_loss, register_loss
from classy_vision.models import build_model
from classy_vision.optim import build_optimizer
from classy_vision.models import ClassyModel, build_model
from classy_vision.optim import SGD, build_optimizer
from classy_vision.tasks import ClassificationTask, build_task
from classy_vision.trainer import LocalTrainer

@@ -284,3 +286,74 @@ def test_get_classy_state_on_loss(self):
task = build_task(config)
task.prepare()
self.assertIn("alpha", task.get_classy_state()["loss"])

def test_gradient_clipping(self):
# Generate a simple model that has a very high gradient w.r.t. this loss
class SimpleModel(ClassyModel):
def __init__(self):
super().__init__()
self.param = nn.Parameter(torch.tensor(5.0), requires_grad=True)

def forward(self, x):
return x + self.param

@classmethod
def from_config(cls):
return cls()

class SimpleLoss(nn.Module):
def forward(self, out, target):
return out.pow(2).mean()

apex_available = True
try:
import apex # noqa F401
except ImportError:
apex_available = False

def train_with_clipped_gradients(amp_args=None):
task = build_task(get_fast_test_task_config())
task.set_num_epochs(1)
task.set_model(SimpleModel())
task.set_loss(SimpleLoss())
task.set_meters([])
task.set_use_gpu(torch.cuda.is_available())
task.set_clip_grad_norm(0.5)
task.set_amp_args(amp_args)

task.set_optimizer(SGD(lr=1))

trainer = LocalTrainer()
trainer.train(task)

return task.model.param.grad.norm()

grad_norm = train_with_clipped_gradients(None)
self.assertAlmostEqual(grad_norm, 0.5, delta=1e-2)

if apex_available and torch.cuda.is_available():
grad_norm = train_with_clipped_gradients({"opt_level": "O2"})
self.assertAlmostEqual(grad_norm, 0.5, delta=1e-2)

def test_clip_stateful_loss(self):
config = get_fast_test_task_config()
config["loss"] = {"name": "test_stateful_loss", "in_plane": 256}
config["clip_grad_norm"] = grad_norm_clip = 1
task = build_task(config)
task.set_use_gpu(False)
task.prepare()

# set fake gradients with norm > grad_norm_clip
for param in itertools.chain(
task.base_model.parameters(), task.base_loss.parameters()
):
param.grad = 1.1 + torch.rand(param.shape)
self.assertGreater(param.grad.norm(), grad_norm_clip)

task._clip_gradients(grad_norm_clip)

for param in itertools.chain(
task.base_model.parameters(), task.base_loss.parameters()
):
self.assertLessEqual(param.grad.norm(), grad_norm_clip)
