Commit 4d0ea3d
Cuda rng_state_all is used when saving in distributed mode so same should also be used when loading (huggingface#23045)

CUDA RNG state should be restored with set_rng_state_all in distributed mode, because the states of all devices were saved with get_rng_state_all.
ShivamShrirao authored Apr 28, 2023
1 parent 521a8ff commit 4d0ea3d
Showing 1 changed file with 2 additions and 2 deletions.
4 changes: 2 additions & 2 deletions src/transformers/trainer.py
@@ -2327,10 +2327,10 @@ def _load_rng_state(self, checkpoint):
         torch.random.set_rng_state(checkpoint_rng_state["cpu"])
         if torch.cuda.is_available():
             if self.args.parallel_mode == ParallelMode.DISTRIBUTED:
-                torch.cuda.random.set_rng_state(checkpoint_rng_state["cuda"])
+                torch.cuda.random.set_rng_state_all(checkpoint_rng_state["cuda"])
             else:
                 try:
-                    torch.cuda.random.set_rng_state_all(checkpoint_rng_state["cuda"])
+                    torch.cuda.random.set_rng_state(checkpoint_rng_state["cuda"])
                 except Exception as e:
                     logger.info(
                         f"Didn't manage to set back the RNG states of the GPU because of the following error:\n {e}"
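
For context, a minimal standalone sketch of the save/load symmetry this commit restores. This is an illustration, not the Trainer's actual checkpoint code: is_distributed and cuda_state are hypothetical names standing in for the Trainer's ParallelMode.DISTRIBUTED check and the checkpoint_rng_state["cuda"] entry.

    import torch

    # Hypothetical flag standing in for:
    #   self.args.parallel_mode == ParallelMode.DISTRIBUTED
    is_distributed = torch.distributed.is_available() and torch.distributed.is_initialized()

    if torch.cuda.is_available():
        # Saving: in distributed mode the RNG state of every visible GPU is
        # captured, so the checkpoint entry is a list of per-device tensors.
        if is_distributed:
            cuda_state = torch.cuda.random.get_rng_state_all()  # list of ByteTensors
        else:
            cuda_state = torch.cuda.random.get_rng_state()      # single ByteTensor

        # Loading: the setter must mirror the getter that produced the state.
        # set_rng_state_all expects the list; set_rng_state expects one tensor.
        if is_distributed:
            torch.cuda.random.set_rng_state_all(cuda_state)
        else:
            torch.cuda.random.set_rng_state(cuda_state)

Because the distributed save path stores a list of per-device states, restoring it with the singular set_rng_state cannot work as intended; the fix makes the loader mirror the saver in both branches.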
