NOTE: This chapter's code builds off of chapter 4's FSDP code.
Here we are going to utilize an 8 node cluster (64 H100 GPUs) to train Llama 3.1 405B. This does not utilize LoRA! We are actually fully training the weights of a 405B model in plain PyTorch.
The next few sections go through various changes we have to make to our FSDP code from chapter 4 to make training a 405b model work.
Quick Jump:
- Use flash attention
- Download model weights
- Loading pretrained weights
- Sharding Llama 405B
- Gradient (aka activation) checkpointing
- CPU Offload & fused optimizer kernels
- NOT de-allocating gradients
- Launch command
- Monitoring
- Run statistics
- Other notes on settings that didn't affect throughput
Flash attention is a fused implementation of scaled dot product attention that heavily minimizes memory usage. The whole goal behind it is to read from and write to GPU memory as little as possible and to minimize the temporary memory used.
Check out the repo and the paper for more information.
This ends up saving us tens of GB of memory in the forward/backward pass.
Install:
pip install packaging
pip install ninja
pip install flash-attn --no-build-isolation
Use it when we initialize our model:
model = AutoModelForCausalLM.from_pretrained(
...
attn_implementation="flash_attention_2",
)
The actual model weights are huge: the checkpoint contains 191 separate files, each about 4 GB, totaling about 764 GB.
There are two options for storing these weights here (and they make a difference!):
- A shared network drive that all the nodes can access
- Locally on the main rank 0 node
Node local storage is much faster when initializing. For some concrete numbers: while running this script on 8 8xH100 80GB nodes, initializing from the shared network drive took 50 minutes, while node local storage took only 3 minutes.
There's a download script in this repo for convenience; run this on node 0:
cd distributed-training-guide/05-training-llama-405b
python download.py
And run this on the other nodes (to download config & tokenizer):
cd distributed-training-guide/05-training-llama-405b
python download.py --skip-model
NOTE: you will likely have to log into your Hugging Face account using `huggingface-cli login`.
Actually loading the weights takes some time AND a lot of memory. Again, the full size is about 764 GB, so we need to make sure we have enough RAM to store the weights.
There are three parts to this:

- Loading the weights into RAM only on `rank == 0`
- Using the meta device on `rank > 0`, and using `from_config` instead of `from_pretrained` there so we don't need to download the weights on all the nodes
  - Note that if you have the weights on a shared network drive, you can just use `from_pretrained` instead.
- Enabling `sync_module_states=True` in the FSDP constructor
You might think of using the `device_map` feature of transformers (e.g. `device_map="auto"` tries to smartly fill up available memory). However, if you try this approach you'll end up with out-of-memory errors when FSDP tries to start sending memory to the GPU.
Here's our code snippet for doing this:
if rank == 0:
with torch.device("cpu"):
model = AutoModelForCausalLM.from_pretrained(...)
else:
with torch.device("meta"):
model = AutoModelForCausalLM.from_config(config, torch_dtype=dtype)
Then later, `sync_module_states=True` in the FSDP constructor will make sure the weights are broadcast from rank 0 to the other ranks.
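Putting the pieces together, the FSDP construction looks roughly like the sketch below. This is a sketch, not the exact call in our training script: the other arguments may differ, and `wrap_policy` is the auto wrap policy defined in the next section.

```python
import torch
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP

model = FSDP(
    model,
    device_id=torch.cuda.current_device(),
    # Broadcast rank 0's (real) weights to every other rank at wrap time.
    sync_module_states=True,
    # Ranks > 0 built the model on the meta device, so they need a param_init_fn
    # to materialize empty parameters on the GPU before the broadcast fills them in.
    param_init_fn=(
        None
        if rank == 0
        else lambda module: module.to_empty(device=torch.cuda.current_device(), recurse=False)
    ),
    auto_wrap_policy=wrap_policy,  # defined in the next section
)
```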
Determining which layers you should shard is complex. If you are using transformers, the model classes include a private attribute called `_no_split_modules` that contains the classes you should not shard anything under. E.g. for Llama this attribute just contains `LlamaDecoderLayer`, so that is what we will wrap! During testing I also found that sharding the `nn.Embedding` layer at the beginning of the network improved throughput and reduced memory usage.
We can use the transformer_auto_wrap_policy() to target the specific classes for those layers, and pass that as our auto_wrap_policy in the FSDP constructor:
import functools

import torch.nn as nn
from torch.distributed.fsdp.wrap import transformer_auto_wrap_policy
from transformers.models.llama.modeling_llama import LlamaDecoderLayer

wrap_policy = functools.partial(
    transformer_auto_wrap_policy,
    transformer_layer_cls={LlamaDecoderLayer, nn.Embedding},
)
FSDP(..., auto_wrap_policy=wrap_policy)
Please consult our explanation on the FSDP constructor for more info.
As a reminder - this will cause FSDP to gather all the parameters for each DecoderLayer (which includes Attention, Linear, and various norm modules), and shard them across the world. At the start of forward/backward pass FSDP will issue an all-gather so all the nodes have the full weights in memory, and at the end of the DecoderLayer forward/backward, it will free up the full weights again.
So where you apply FSDP determines where the all-gather happens!
Another piece of reducing memory usage is gradient checkpointing (first introduced in Training Deep Nets with Sublinear Memory Cost). Normally when you do the forward pass, you have to keep the inputs & outputs in memory until you run the backward pass. Keeping these intermediate tensors around takes up a lot of memory. With gradient checkpointing, we instead re-run the forward pass during the backward pass to regenerate the outputs. So we are doing more compute but saving a lot of memory.
The method we are using is kind of a hidden method in PyTorch, but it is exactly what accelerate uses under the hood, so rest assured that it is a "standard" way of doing it:
This piece of code has to go after the FSDP constructor!!! I'm not exactly sure of the reason, but it doesn't work before the FSDP initialization.
from torch.distributed.algorithms._checkpoint.checkpoint_wrapper import (
apply_activation_checkpointing,
checkpoint_wrapper,
)
model = FSDP(...)
apply_activation_checkpointing(
model, checkpoint_wrapper_fn=checkpoint_wrapper, auto_wrap_policy=wrap_policy
)
Since the model is so large, we pretty much have to enable CPU offloading with FSDP. When using the CPUOffload feature of FSDP, the optimizer runs entirely on the CPU, because there is a significant cost to transferring data to and from the GPU for every `optimizer.step()`. At the time of writing there are open issues about overlapping `optimizer.step()` with the next `forward()` call.
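Enabling the offload is a single argument to the FSDP constructor. A minimal sketch (in our script this is gated behind the `cpu_offload` command line option referenced below):

```python
from torch.distributed.fsdp import CPUOffload

# Keep sharded parameters (and therefore gradients & optimizer state) in CPU RAM;
# FSDP copies each wrapped block to the GPU only for its forward/backward computation.
FSDP(..., cpu_offload=CPUOffload(offload_params=True))
```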
By default the optimizer will use a non-fused kernel when running on the CPU, which generates a lot of intermediate tensors. By explicitly using the fused kernel we get a big speedup, which is especially important since we are running that step on the CPU:
torch.optim.AdamW(model.parameters(), lr=args.lr, fused=True)
If you want to peek through the pytorch code:
- _single_tensor_adamw() is the default implementation used
- _fused_adamw() is the fused implementation
You may have seen the `set_to_none` argument of `optimizer.zero_grad()`. According to the docs:

> This will in general have lower memory footprint, and can modestly improve performance.

Basically, `set_to_none=True` will deallocate the gradients after they are used. In most GPU cases, where we want to save a bit of memory, de-allocating is a good thing. However, in our case we are using CPU offload, which means all of our gradients are already on the CPU! Since we aren't taking up GPU memory, setting to none just means we pay for a lot of allocating & de-allocating. So if you set `set_to_none=False`, you should actually see a slight speedup for our case!
optimizer.zero_grad(set_to_none=args.cpu_offload == "off")
That's pretty much all the changes you need from our base FSDP code. Now let's launch!
We provide a customized launch.sh script here based on the bash command for spawning torchrun on all available nodes:
cd distributed-training-guide/05-training-llama-405b
bash launch.sh  # NOTE: this is non-blocking
Also note that launch.sh specifies `HF_HOME` as an environment variable in the tmux session, so if you've not used the default value of `/home/ubuntu/.cache/huggingface`, please update the script!
You can change the hostnames in the hosts file in this directory.
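For reference, here is roughly what launch.sh runs on each host. This is a sketch, not the script's exact contents: the entrypoint name (`train_llm.py`), the rendezvous port, and the rendezvous id below are assumptions, so check launch.sh for the real flags.

```bash
# Sketch of the per-host launch command (assumed details: train_llm.py entrypoint,
# port 5001, and the rdzv id). $host iterates over the entries in the hosts file;
# $RANK0_HOST is the first entry.
ssh "$host" tmux new-session -d -s torchrun-llama-405b \
    "HF_HOME=/home/ubuntu/.cache/huggingface \
     torchrun \
       --nnodes 8 \
       --nproc_per_node 8 \
       --rdzv_backend c10d \
       --rdzv_endpoint $RANK0_HOST:5001 \
       --rdzv_id llama-405b \
       train_llm.py --seq-length 4096 --batch-size 1"
```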
We are using torchrun in our launch.sh script, so we will get an output directory per node with a bunch of subdirectories containing our log files. It's a bit of a pain to manually monitor these, so here's a bash command for tailing all of them at once:
cd distributed-training-guide/05-training-llama-405b
find ../logs/ -name \*stderr.log | xargs tail -f
Additionally, we have a top-like utility script for monitoring the entire cluster at the top level of this directory:
cd distributed-training-guide/05-training-llama-405b
python ../top-cluster.py hosts
If you notice any of the nproc counts or the power usage go down, then you know that an error has occurred!
To kill all the processes on all the nodes you can just kill the tmux sessions:
xargs -a hosts -I{} ssh {} tmux kill-session -t torchrun-llama-405b
Training with `--seq-length 4096` and `--batch-size 1` on 64 H100 GPUs (8 separate nodes) has the following stats:
- ~30s per iteration (data/forward/backward/update). Breakdown is:
  - data: ~2ms
  - forward: ~7s
  - backward: ~19s
  - update: ~4s
- Peak Memory Allocated: 52.9GB
- Peak Memory Reserved: 77.9GB
  - Note that reserved memory has to do with PyTorch's allocation caching.
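For a rough sense of throughput, these numbers imply 64 GPUs × 1 sequence × 4096 tokens ≈ 262k tokens per ~30s iteration, i.e. roughly 8.7k tokens/s across the cluster, or about 137 tokens/s per GPU.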
- Allowing tf32 had no impact on throughput (`torch.backends.cudnn.allow_tf32` and `torch.backends.cuda.matmul.allow_tf32`)
- Enabling benchmarking had no impact on throughput (`torch.backends.cudnn.benchmark = True`)
- Using CuDNN sdpa was slower (`attn_implementation="sdpa"` and `torch.backends.cuda.enable_cudnn_sdp(True)`)
- torch.compile had no impact (`use_orig_params=True` and `torch.compile` after the FSDP constructor)
- Very minimal testing of NCCL environment variables either made things worse or had no impact (https://docs.nvidia.com/deeplearning/nccl/user-guide/docs/env.html)
- `PYTORCH_NO_CUDA_MEMORY_CACHING=1` made enough memory available that `--batch-size 2` or higher sequence lengths were possible, but it was much much slower.
  - It's possible that some well placed calls to `torch.cuda.empty_cache()` could achieve this without the throughput loss.
- Only `FULL_SHARD` works. Others fail silently.
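For reference, `FULL_SHARD` is the value of FSDP's `sharding_strategy` argument (and is already the default). A sketch of setting it explicitly:

```python
from torch.distributed.fsdp import ShardingStrategy

FSDP(..., sharding_strategy=ShardingStrategy.FULL_SHARD)
```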