Prevent duplicated checkpoints (NVIDIA#9015)
* Prevent duplicated checkpoints

Signed-off-by: Mikołaj Błaż <[email protected]>

* Add versioning to save_to

Signed-off-by: Mikołaj Błaż <[email protected]>

* Add versioning logic to all .nemo files

Signed-off-by: Mikołaj Błaż <[email protected]>

* Add versioning test

Signed-off-by: Mikołaj Błaż <[email protected]>

* Add dist-ckpt test

Signed-off-by: Mikołaj Błaż <[email protected]>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

Signed-off-by: Mikołaj Błaż <[email protected]>

* Rename existing ckpts instead of using different name

Signed-off-by: Mikołaj Błaż <[email protected]>

* Add comment

Signed-off-by: Mikołaj Błaż <[email protected]>

* Run dist-ckpt test on GPU

Signed-off-by: Mikołaj Błaż <[email protected]>

* Prevent saving last for non-equal val intervals

Signed-off-by: Mikołaj Błaż <[email protected]>

* Move checkpoint on rank 0

Signed-off-by: Mikołaj Błaż <[email protected]>

---------

Signed-off-by: Mikołaj Błaż <[email protected]>
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Co-authored-by: Dmytro Pykhtar <[email protected]>
3 people authored May 8, 2024
1 parent 5d83213 commit 305392d
Showing 3 changed files with 114 additions and 28 deletions.
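
Taken together, the commits above make the checkpoint callback move an already existing .nemo checkpoint to the first free "-v<N>" name before writing a new one, rather than skipping the save or writing the new file under a different name. Below is a minimal standalone sketch of that naming walk; it is not the NeMo implementation itself, and it assumes "-" as the join character and 1 as the starting version, matching PyTorch Lightning's ModelCheckpoint defaults.

import os
import shutil

CHECKPOINT_JOIN_CHAR = "-"  # assumed PyTorch Lightning ModelCheckpoint default
STARTING_VERSION = 1        # assumed PyTorch Lightning ModelCheckpoint default


def backup_existing_nemo_ckpt(dirpath: str, prefix: str = "default", postfix: str = ".nemo") -> str:
    """Sketch: if '<prefix><postfix>' already exists in dirpath, move it to the
    first free '<prefix>-v<N><postfix>' name and return the path it was moved to."""
    base_path = os.path.join(dirpath, prefix + postfix)
    available_path = base_path
    version = STARTING_VERSION
    while os.path.exists(available_path):
        available_path = os.path.join(dirpath, f"{prefix}{CHECKPOINT_JOIN_CHAR}v{version}{postfix}")
        version += 1
    if available_path != base_path:
        shutil.move(base_path, available_path)
    return available_path

With default.nemo already on disk, the old file ends up as default-v1.nemo and the fresh checkpoint is then written to default.nemo, which is what the new test below asserts.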
33 changes: 10 additions & 23 deletions nemo/collections/nlp/parts/nlp_overrides.py
@@ -364,12 +364,6 @@ def save_checkpoint(
# dist_checkpointing expects a directory so we will name the directory
# using the path with the file extension removed
checkpoint_dir = ckpt_to_dir(filepath)

fs = get_filesystem(checkpoint_dir)
if fs.isdir(checkpoint_dir) and dist_checkpointing.check_is_distributed_checkpoint(checkpoint_dir):
logging.info(f'Distributed checkpoint at path {checkpoint_dir} already exists, skipping saving')
return

# remove device state_dict
checkpoint['state_dict'] = OrderedDict([])

@@ -861,26 +855,19 @@ def save_to(self, model, save_path: str):
if dist_ckpt:
# model weights is a directory
dist_ckpt_dir = ckpt_to_dir(os.path.join(dir_name, self.model_weights_ckpt))
fs = get_filesystem(dist_ckpt_dir)

if fs.isdir(dist_ckpt_dir) and dist_checkpointing.check_is_distributed_checkpoint(dist_ckpt_dir):
logging.info(f'Distributed checkpoint at path {dist_ckpt_dir} already exists, skipping saving')
else:
if is_global_rank_zero():
fs.makedirs(dist_ckpt_dir, exist_ok=True)

sharded_state_dict = model.sharded_state_dict()
# dist checkpoint needs torch.distributed to save the checkpoint
if not parallel_state.is_initialized():
sharded_state_dict = model.sharded_state_dict()
# dist checkpoint needs torch.distributed to save the checkpoint
if not parallel_state.is_initialized():

def dummy():
return
def dummy():
return

if model.trainer.strategy.launcher is not None:
model.trainer.strategy.launcher.launch(dummy, trainer=model.trainer)
model.trainer.strategy.setup_environment()
checkpoint_io = DistributedCheckpointIO(model.cfg.get('dist_ckpt_format', 'zarr'))
checkpoint_io.save_checkpoint(sharded_state_dict, dist_ckpt_dir)
if model.trainer.strategy.launcher is not None:
model.trainer.strategy.launcher.launch(dummy, trainer=model.trainer)
model.trainer.strategy.setup_environment()
checkpoint_io = DistributedCheckpointIO(model.cfg.get('dist_ckpt_format', 'zarr'))
checkpoint_io.save_checkpoint(sharded_state_dict, dist_ckpt_dir)

else:

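
Because the rendering above interleaves the removed and added lines, here is roughly how the dist-ckpt branch of save_to reads after this change. It is a sketch assembled from the diff, assuming the names used elsewhere in nlp_overrides.py (ckpt_to_dir, parallel_state, DistributedCheckpointIO) are already imported; the check for a pre-existing checkpoint is gone and the sharded state dict is always saved.

if dist_ckpt:
    # model weights are stored as a directory (the path with the file extension removed)
    dist_ckpt_dir = ckpt_to_dir(os.path.join(dir_name, self.model_weights_ckpt))

    sharded_state_dict = model.sharded_state_dict()
    # distributed checkpointing needs torch.distributed; if it is not initialized yet,
    # launch a no-op function through the trainer's launcher to set it up
    if not parallel_state.is_initialized():

        def dummy():
            return

        if model.trainer.strategy.launcher is not None:
            model.trainer.strategy.launcher.launch(dummy, trainer=model.trainer)
        model.trainer.strategy.setup_environment()

    checkpoint_io = DistributedCheckpointIO(model.cfg.get('dist_ckpt_format', 'zarr'))
    checkpoint_io.save_checkpoint(sharded_state_dict, dist_ckpt_dir)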
46 changes: 41 additions & 5 deletions nemo/utils/callbacks/nemo_model_checkpoint.py
@@ -188,9 +188,8 @@ def on_save_checkpoint(self, trainer, pl_module, checkpoint):
if app_state.model_parallel_size is not None and app_state.model_parallel_size > 1:
logging.warning(f'always_save_nemo will slow down training for model_parallel > 1.')
# since we are creating tarfile artifacts we need to update .nemo path
app_state.model_restore_path = os.path.abspath(
os.path.expanduser(os.path.join(self.dirpath, self.prefix + self.postfix))
)
self._backup_existing_nemo_ckpt(trainer)
app_state.model_restore_path = self._format_nemo_checkpoint_name()
if app_state.model_parallel_size is not None and app_state.model_parallel_size > 1:
maybe_injected_best_model_path = inject_model_parallel_rank(self.best_model_path)
else:
@@ -236,7 +235,10 @@ def on_train_end(self, trainer, pl_module):
should_save_last_checkpoint = True
if should_save_last_checkpoint:
monitor_candidates = self._monitor_candidates(trainer)
super()._save_last_checkpoint(trainer, monitor_candidates)
if self.last_model_path == self.format_checkpoint_name(monitor_candidates, self.CHECKPOINT_NAME_LAST):
logging.debug(f'Last checkpoint {self.last_model_path} already saved')
else:
super()._save_last_checkpoint(trainer, monitor_candidates)
# Call parent on_train_end() to save the -last checkpoint
super().on_train_end(trainer, pl_module)

@@ -256,7 +258,36 @@ def on_train_end(self, trainer, pl_module):
trainer._checkpoint_connector.restore(self.best_model_path)

if self.save_nemo_on_train_end:
pl_module.save_to(save_path=os.path.join(self.dirpath, self.prefix + self.postfix))
self._backup_existing_nemo_ckpt(trainer)
pl_module.save_to(save_path=self._format_nemo_checkpoint_name())

def _backup_existing_nemo_ckpt(self, trainer) -> str:
""" Search for an available name with version infix and rename existing checkpoint.
NOTE: this behavior is slightly different from regular checkpoints.
PTL creates new regular checkpoint with the first available name.
Here, for backward compatibility, we create .nemo checkpoint as before
and create a backup under the first available name.
"""
base_path = self._format_nemo_checkpoint_name()
available_path = base_path
if self._enable_version_counter:
version_cnt = self.STARTING_VERSION
while self.file_exists(available_path, trainer, check_dist_ckpt=False):
available_path = self._format_nemo_checkpoint_name(version_cnt)
version_cnt += 1
if available_path != base_path:
if trainer.is_global_zero:
logging.info(f'{base_path} already exists, moving existing checkpoint to {available_path}')
shutil.move(base_path, available_path)
trainer.strategy.barrier()
return available_path

def _format_nemo_checkpoint_name(self, ver: Optional[int] = None) -> str:
version_infix = '' if ver is None else f'{self.CHECKPOINT_JOIN_CHAR}v{ver}'
return os.path.abspath(
os.path.expanduser(os.path.join(self.dirpath, self.prefix + version_infix + self.postfix))
)

def _del_model_without_trainer(self, filepath: str) -> None:

@@ -367,6 +398,11 @@ def remove_checkpoint_unfinished_marker(checkpoint_path: Union[Path, str], barri
except:
return

def file_exists(self, filepath: str, trainer: "pytorch_lightning.Trainer", check_dist_ckpt: bool = True) -> bool:
"""Checks if a file or a file without a suffix (distributed checkpoint) exists."""
exists = self._fs.exists(filepath) or (check_dist_ckpt and self._fs.exists(ckpt_to_dir(filepath)))
return trainer.strategy.broadcast(exists)

def _save_checkpoint(self, trainer: 'pytorch_lightning.Trainer', filepath: str) -> None:
# barrier_after=True, so all ranks continue after the unfinished checkpoint marker is placed.
# if anything goes wrong during checkpointing, we should be able to detect that data is incomplete.
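
A note on the file_exists helper added above: a distributed checkpoint is a directory named like the checkpoint file with its extension removed (see the comment in save_checkpoint in the first file), so existence has to be checked for both forms. Below is a rough single-process sketch of the same idea; the path handling is a hypothetical stand-in for NeMo's ckpt_to_dir, and the real method additionally broadcasts the result via trainer.strategy.broadcast so that every rank sees the same answer.

import os
from pathlib import Path


def checkpoint_exists(filepath: str, check_dist_ckpt: bool = True) -> bool:
    """A checkpoint counts as existing if the file itself is present or, when
    check_dist_ckpt is set, the directory form used by distributed checkpoints
    (the path without its extension) is present."""
    dist_ckpt_dir = Path(filepath).with_suffix("")  # hypothetical stand-in for ckpt_to_dir
    return os.path.exists(filepath) or (check_dist_ckpt and dist_ckpt_dir.is_dir())

For example, checkpoint_exists("checkpoints/default--val_loss=0.0300-epoch=2.ckpt") is true if either that file or the directory checkpoints/default--val_loss=0.0300-epoch=2 exists.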
63 changes: 63 additions & 0 deletions tests/core/test_exp_manager.py
@@ -11,6 +11,7 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import json
import math
import os
import re
@@ -25,6 +26,7 @@
from pytorch_lightning import Callback
from pytorch_lightning.loops import _TrainingEpochLoop

from nemo.collections.nlp.parts.nlp_overrides import NLPDDPStrategy
from nemo.constants import NEMO_ENV_VARNAME_VERSION
from nemo.core.classes import ModelPT
from nemo.utils.callbacks import NeMoModelCheckpoint
@@ -130,6 +132,11 @@ def on_validation_epoch_end(self):
self.log("val_loss", torch.stack([self.loss]).mean())


class ExampleMCoreModel(ExampleModel):
def sharded_state_dict(self):
return {'a': 3}


class DoNothingModel(ExampleModel):
def configure_optimizers(self):
return DoNothingOptimizer(self.parameters())
@@ -502,6 +509,62 @@ def test_nemo_checkpoint_restore_model(self, tmp_path):
test_trainer.fit(model)
assert math.fabs(float(model(torch.tensor([1.0, 1.0], device=model.device))) - 0.03) < 1e-5

@pytest.mark.run_only_on('GPU')
@pytest.mark.parametrize('test_dist_ckpt', [False, True])
def test_checkpoints_are_not_overwritten(self, tmp_path, test_dist_ckpt):
""" Simulates already existing checkpoints in the ckpt directory and tests ckpt versioning """
strategy = NLPDDPStrategy() if test_dist_ckpt else 'auto'
test_trainer = pl.Trainer(
accelerator='cpu', enable_checkpointing=False, logger=False, max_epochs=4, strategy=strategy
)
exp_manager(
test_trainer,
{
"checkpoint_callback_params": {"save_nemo_on_train_end": True},
"explicit_log_dir": str(tmp_path / "test"),
},
)
model = ExampleMCoreModel() if test_dist_ckpt else ExampleModel()

ckpt_dir = Path(tmp_path / "test" / "checkpoints")
assert not ckpt_dir.exists()

# Fake existing 1st and last checkpoint
suffix = '' if test_dist_ckpt else '.ckpt'
ckpt_dir.mkdir(parents=True)
ckpt_1 = ckpt_dir / f'default--val_loss=0.0000-epoch=1{suffix}'
ckpt_2 = ckpt_dir / f'default--val_loss=0.0300-epoch=2{suffix}'

if test_dist_ckpt:
ckpt_1.mkdir()
with open(ckpt_1 / 'metadata.json', 'w') as f:
json.dump({'sharded_backend': 'xxx'}, f)
else:
ckpt_1.touch()
# don't create 2nd checkpoint
ckpt_nemo = ckpt_dir / 'default.nemo'
ckpt_nemo.touch()

# Train
test_trainer.fit(model)

# Check base checkpoint (without versioning)
all_checkpoints = [p.name for p in Path(str(tmp_path / "test" / "checkpoints")).glob("*")]
assert ckpt_1.exists(), all_checkpoints # existed before
assert ckpt_2.exists(), all_checkpoints
assert ckpt_nemo.exists(), all_checkpoints # existed before

# Versioned checkpoints
def _get_versioned_name(ckpt_name: Path, nemo: bool = False):
if test_dist_ckpt and not nemo:
# no suffix at all
return ckpt_name.with_name(ckpt_name.name + '-v1')
return ckpt_name.with_stem(ckpt_name.stem + '-v1')

assert _get_versioned_name(ckpt_1).exists(), all_checkpoints
assert not _get_versioned_name(ckpt_2).exists(), all_checkpoints # ckpt2 didn't exist before
assert _get_versioned_name(ckpt_nemo, nemo=True).exists(), all_checkpoints

@pytest.mark.unit
def test_last_checkpoint_saved(self, tmp_path):
max_steps = 64
