
Model Checkpointing + FSDP causes Cuda OOM #20312

Open
profPlum opened this issue Oct 1, 2024 · 0 comments
Labels
bug (Something isn't working) · needs triage (Waiting to be triaged by maintainers) · ver: 2.3.x

Comments


profPlum commented Oct 1, 2024

Bug description

I'm using FSDP together with model checkpointing (default settings for both). My model has 254 million parameters. When I run Trainer.fit(), the first epoch completes successfully, but I hit a CUDA OOM on the first backward pass of the second epoch. The problem goes away when I disable model checkpointing, which makes me think this is a bug in FSDP model checkpointing; checkpointing on its own shouldn't cause a CUDA OOM. The problem persists with both state_dict_type='sharded' and regular (full) model checkpoints.
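
For reference, here is a minimal sketch of the kind of setup described above. It is not the actual training script: the BigModel module, layer sizes, device count, and random dataloader are placeholders chosen only to roughly match the reported ~254M parameters with default FSDP and ModelCheckpoint settings.

```python
import torch
import pytorch_lightning as pl
from pytorch_lightning.callbacks import ModelCheckpoint
from pytorch_lightning.strategies import FSDPStrategy


class BigModel(pl.LightningModule):
    """Hypothetical stand-in for the ~254M-parameter model."""

    def __init__(self):
        super().__init__()
        # 15 x (4096 -> 4096) linear layers, roughly 252M parameters
        self.net = torch.nn.Sequential(
            *[torch.nn.Linear(4096, 4096) for _ in range(15)]
        )

    def training_step(self, batch, batch_idx):
        return self.net(batch).pow(2).mean()

    def configure_optimizers(self):
        return torch.optim.Adam(self.parameters())


if __name__ == "__main__":
    trainer = pl.Trainer(
        accelerator="gpu",
        devices=8,  # multi-GPU run (assumption)
        strategy=FSDPStrategy(state_dict_type="sharded"),  # same OOM with "full"
        callbacks=[ModelCheckpoint()],  # default settings; removing this avoids the OOM
        max_epochs=2,  # the OOM hits in the first backward pass of the second epoch
    )
    dataloader = torch.utils.data.DataLoader(torch.randn(64, 4096), batch_size=8)
    trainer.fit(BigModel(), dataloader)
```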

The fact that the OOM happens after the model checkpointing (during the next backward()) also makes me suspect that the checkpointing is causing some kind of GPU memory leak.
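
One way to check for such a leak is to log allocated and reserved CUDA memory at each epoch end. This is only a diagnostic sketch (the CudaMemoryLogger callback below is something I'm sketching, not code from the failing run):

```python
import torch
import pytorch_lightning as pl


class CudaMemoryLogger(pl.Callback):
    """Log CUDA memory at each epoch end to spot growth after checkpointing."""

    def on_train_epoch_end(self, trainer, pl_module):
        allocated_gib = torch.cuda.memory_allocated() / 2**30
        reserved_gib = torch.cuda.memory_reserved() / 2**30
        # LightningModule.print only prints from rank 0
        pl_module.print(
            f"epoch {trainer.current_epoch}: "
            f"allocated={allocated_gib:.2f} GiB, reserved={reserved_gib:.2f} GiB"
        )
```

If allocated memory is noticeably higher after the epoch that saved the checkpoint than before it, that would be consistent with state gathered for the checkpoint not being released on the GPU.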

What version are you seeing the problem on?

v2.3

Error Message(s) (abbreviated):

[rank4]: File "/home/dsdeigh/miniforge3/envs/uqops+proxy/lib/python3.10/site-packages/pytorch_lightning/loops/optimization/automatic.py", line 239, in backward_fn
[rank4]: call._call_strategy_hook(self.trainer, "backward", loss, optimizer)
[rank4]: File "/home/dsdeigh/miniforge3/envs/uqops+proxy/lib/python3.10/site-packages/pytorch_lightning/trainer/call.py", line 311, in _call_strategy_hook
[rank4]: output = fn(*args, **kwargs)
[rank4]: File "/home/dsdeigh/miniforge3/envs/uqops+proxy/lib/python3.10/site-packages/pytorch_lightning/strategies/strategy.py", line 212, in backward
[rank4]: self.precision_plugin.backward(closure_loss, self.lightning_module, optimizer, *args, **kwargs)
[rank4]: File "/home/dsdeigh/miniforge3/envs/uqops+proxy/lib/python3.10/site-packages/pytorch_lightning/plugins/precision/precision.py", line 72, in backward
[rank4]: model.backward(tensor, *args, **kwargs)
[rank4]: File "/home/dsdeigh/miniforge3/envs/uqops+proxy/lib/python3.10/site-packages/pytorch_lightning/core/module.py", line 1103, in backward
[rank4]: loss.backward(*args, **kwargs)
[rank4]: File "/home/dsdeigh/miniforge3/envs/uqops+proxy/lib/python3.10/site-packages/torch/_tensor.py", line 525, in backward
[rank4]: torch.autograd.backward(
[rank4]: File "/home/dsdeigh/miniforge3/envs/uqops+proxy/lib/python3.10/site-packages/torch/autograd/__init__.py", line 267, in backward
[rank4]: _engine_run_backward(
[rank4]: File "/home/dsdeigh/miniforge3/envs/uqops+proxy/lib/python3.10/site-packages/torch/autograd/graph.py", line 744, in _engine_run_backward
[rank4]: return Variable._execution_engine.run_backward( # Calls into the C++ engine to run the backward pass
[rank4]: torch.cuda.OutOfMemoryError: CUDA out of memory. Tried to allocate 7.57 GiB. GPU

profPlum added the bug (Something isn't working) and needs triage (Waiting to be triaged by maintainers) labels on Oct 1, 2024
profPlum changed the title from "Model Checkpointing Causes Cuda OOM" to "Model Checkpointing + FSDP causes Cuda OOM" on Oct 1, 2024