Skip to content

Commit

Permalink
fix the backward for deepspeed (huggingface#9705)
Browse files Browse the repository at this point in the history
  • Loading branch information
stas00 authored Jan 20, 2021
1 parent 538245b commit cd5565b
Showing 1 changed file with 1 addition and 2 deletions.
3 changes: 1 addition & 2 deletions src/transformers/trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -1282,8 +1282,7 @@ def training_step(self, model: nn.Module, inputs: Dict[str, Union[torch.Tensor,
with amp.scale_loss(loss, self.optimizer) as scaled_loss:
scaled_loss.backward()
elif self.deepspeed:
# calling on DS engine (model_wrapped == DDP(Deepspeed(PretrainedModule)))
self.model_wrapped.module.backward(loss)
self.deepspeed.backward(loss)
else:
loss.backward()

Expand Down

0 comments on commit cd5565b

Please sign in to comment.