Gradient accumulation and train_epochs (#402)
dirkgr authored Sep 14, 2022
1 parent d27bbef commit 5dcbb56
Showing 3 changed files with 44 additions and 2 deletions.
1 change: 1 addition & 0 deletions CHANGELOG.md
@@ -16,6 +16,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
could delay submitting some steps in parallel.
- Fixed a bug where creating a `StepInfo` object from params might result in unnecessary imports.
- Fixed a bug where canceling the Beaker executor might not work properly.
- Fixed a bug where the trainer trains too much when `train_epochs` is set and you're using gradient accumulation.

## [v0.13.0](https://github.com/allenai/tango/releases/tag/v0.13.0) - 2022-09-07

5 changes: 4 additions & 1 deletion tango/integrations/torch/train.py
@@ -1,4 +1,5 @@
import logging
import math
import os
import shutil
from itertools import islice
@@ -412,7 +413,9 @@ def _train(
steps_per_epoch = len(train_dataloader)
except TypeError:
raise ConfigurationError("You must set 'train_steps' for streaming/iterable datasets")
config.train_steps = steps_per_epoch * (config.train_epochs or 1)
config.train_steps = math.ceil(
steps_per_epoch * (config.train_epochs or 1) / config.grad_accum
)

assert config.train_steps is not None # for mypy

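For context, a minimal sketch of the corrected step computation in isolation (the helper name and the batch counts are illustrative, not part of the trainer API):

```python
import math

def optimizer_steps(steps_per_epoch: int, train_epochs: int, grad_accum: int) -> int:
    # Each optimizer step consumes `grad_accum` batches, so the number of
    # optimizer steps shrinks by that factor; ceil keeps any final partial window.
    return math.ceil(steps_per_epoch * train_epochs / grad_accum)

# Illustrative numbers: 8 batches per epoch, 2 epochs.
assert optimizer_steps(steps_per_epoch=8, train_epochs=2, grad_accum=1) == 16
assert optimizer_steps(steps_per_epoch=8, train_epochs=2, grad_accum=2) == 8
# The pre-fix computation ignored grad_accum and would report 16 steps in both
# cases, so the trainer ran twice as long as intended when grad_accum=2.
```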
40 changes: 39 additions & 1 deletion tests/integrations/torch/train_test.py
@@ -44,7 +44,8 @@ def test_basic_train(self, with_validation: bool):
result_dir / "train" / "work" / "checkpoint_state_best" / "worker0_trainer.pt"
).is_file()

def test_basic_train_with_epochs(self):
@pytest.mark.parametrize("grad_acc", [1, 2])
def test_basic_train_with_epochs(self, grad_acc: int):
result_dir = self.run(
self.FIXTURES_ROOT / "integrations" / "torch" / "train.jsonnet",
include_package=[
@@ -56,11 +57,20 @@ def test_basic_train_with_epochs(self):
"steps.train.train_steps": None,
"steps.train.train_epochs": 2,
"steps.train.validate_every": None,
"steps.train.grad_accum": grad_acc,
}
),
)
assert (result_dir / "train" / "data.pt").is_file()

# Make sure we trained for the right number of steps.
expected_steps = 16 // grad_acc
latest = result_dir / "train" / "work" / "checkpoint_state_latest"
assert latest.is_symlink()
last_step = result_dir / "train" / "work" / f"checkpoint_state_step{expected_steps}"
assert last_step.is_dir()
assert latest.samefile(last_step)

def test_basic_train_with_streaming_data(self):
result_dir = self.run(
self.FIXTURES_ROOT / "integrations" / "torch" / "train.jsonnet",
@@ -93,3 +103,31 @@ def test_train_distributed(self):
assert (
result_dir / "train" / "work" / "checkpoint_state_best" / "worker1_model.pt"
).is_file()

@pytest.mark.parametrize("grad_acc", [1, 2])
def test_train_distributed_with_epochs(self, grad_acc: int):
result_dir = self.run(
self.FIXTURES_ROOT / "integrations" / "torch" / "train_dist.jsonnet",
include_package=[
"test_fixtures.integrations.common",
"test_fixtures.integrations.torch",
],
overrides=json.dumps(
{
"steps.train.train_steps": None,
"steps.train.train_epochs": 2,
"steps.train.validate_every": None,
"steps.train.grad_accum": grad_acc,
}
),
)

assert (result_dir / "train" / "data.pt").is_file()

# Make sure we trained for the right number of steps.
expected_steps = 8 // grad_acc
latest = result_dir / "train" / "work" / "checkpoint_state_latest"
assert latest.is_symlink()
last_step = result_dir / "train" / "work" / f"checkpoint_state_step{expected_steps}"
assert last_step.is_dir()
assert latest.samefile(last_step)
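
The expected step counts in the two tests follow the same formula. The distributed test expects half as many steps per worker as the single-process test, which is consistent with the batches being sharded across two workers; the fixture sizes and worker count below are inferred for illustration, not stated in this diff:

```python
import math

def per_worker_steps(total_batches_per_epoch: int, train_epochs: int,
                     grad_accum: int, world_size: int) -> int:
    # Each worker sees roughly 1/world_size of the epoch's batches, and each
    # optimizer step consumes grad_accum of them.
    steps_per_epoch = total_batches_per_epoch // world_size
    return math.ceil(steps_per_epoch * train_epochs / grad_accum)

# Single-process test (world_size=1): expects 16 // grad_acc.
assert [per_worker_steps(8, 2, g, 1) for g in (1, 2)] == [16, 8]
# Distributed test (world_size=2): expects 8 // grad_acc per worker.
assert [per_worker_steps(8, 2, g, 2) for g in (1, 2)] == [8, 4]
```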
