Small clarification of torch.cuda.amp multi-model example (#41203)
Summary:
Some people have been confused by `retain_graph` in the snippet; they thought it was an additional requirement imposed by amp, when it is actually needed because the two backward() calls share sections of the autograd graph.

Pull Request resolved: pytorch/pytorch#41203

Differential Revision: D22463700

Pulled By: ngimel

fbshipit-source-id: e6fc8871be2bf0ecc1794b1c6f5ea99af922bf7e
definitelynotmcarilli authored and facebook-github-bot committed Jul 10, 2020
1 parent 4a09501 commit d927aee
Showing 1 changed file with 2 additions and 0 deletions.
2 changes: 2 additions & 0 deletions docs/source/notes/amp_examples.rst
@@ -254,6 +254,8 @@ after all optimizers used this iteration have been stepped::
loss0 = loss_fn(2 * output0 + 3 * output1, target)
loss1 = loss_fn(3 * output0 - 5 * output1, target)

# (retain_graph here is unrelated to amp, it's present because in this
# example, both backward() calls share some sections of graph.)
scaler.scale(loss0).backward(retain_graph=True)
scaler.scale(loss1).backward()
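
For context, in the surrounding docs example both losses are functions of both model outputs, so the two backward() calls traverse shared portions of the autograd graph; that, not amp, is why the first call passes retain_graph=True. Below is a minimal, self-contained sketch of the same situation — the toy models, optimizers, and random data are illustrative stand-ins, not the exact docs snippet::

    import torch

    # Illustrative toy setup (assumed for this sketch): two small models,
    # two optimizers, and a gradient scaler shared by both losses.
    model0 = torch.nn.Linear(8, 8).cuda()
    model1 = torch.nn.Linear(8, 8).cuda()
    optimizer0 = torch.optim.SGD(model0.parameters(), lr=0.01)
    optimizer1 = torch.optim.SGD(model1.parameters(), lr=0.01)
    loss_fn = torch.nn.MSELoss()
    scaler = torch.cuda.amp.GradScaler()

    input = torch.randn(4, 8, device="cuda")
    target = torch.randn(4, 8, device="cuda")

    optimizer0.zero_grad()
    optimizer1.zero_grad()
    with torch.cuda.amp.autocast():
        output0 = model0(input)
        output1 = model1(input)
        # Both losses depend on both outputs, so the two backward() calls
        # share sections of the autograd graph.
        loss0 = loss_fn(2 * output0 + 3 * output1, target)
        loss1 = loss_fn(3 * output0 - 5 * output1, target)

    # retain_graph=True keeps the shared graph alive for the second backward();
    # without it, the second call would fail. This is an autograd requirement,
    # unrelated to amp.
    scaler.scale(loss0).backward(retain_graph=True)
    scaler.scale(loss1).backward()

    scaler.step(optimizer0)
    scaler.step(optimizer1)
    scaler.update()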

