The code depends only on `torch` and `triton`. It has been tested with `torch==2.1.1`, `triton==2.0.0`, and `triton-nightly==2.1.0.dev20230728172942` on A100 and 3090 GPU platforms.
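To check that your environment matches one of the tested combinations, you can print the installed versions (assuming the standard `__version__` attributes):

```python
# Print installed versions to compare against the tested combinations above.
import torch
import triton

print("torch:", torch.__version__)
print("triton:", triton.__version__)
```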
We recommend installing `triton` and `triton-nightly` using the following commands:

```bash
pip install triton==2.0.0
pip install triton-nightly==2.1.0.dev20230728172942 --index-url https://aiinfra.pkgs.visualstudio.com/PublicPackages/_packaging/Triton-Nightly/pypi/simple/
```
To install LASP from source, run:

```bash
cd LASP
# -e signifies editable (dev) mode
pip install -e .
```
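After installing, a quick sanity check (assuming the editable install exposes the package as `lasp`, matching the `lasp/` directory) is:

```python
# Assumes `pip install -e .` exposed the package as `lasp`.
import lasp
print(lasp.__file__)  # should point into your LASP checkout
```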
The code is organized as follows:

- `tests/` contains the launch script and test file for running the tests.
- `lasp/` contains the implementation of LASP and its improved variants, including `lasp_naive`, `lasp_cache`, `lasp_fuse`, `lasp_fuse_parallel`, and the non-sequence-parallel version `lightning_attention`.
- `lasp/utils/` contains the communication manager for LASP, i.e., `seq_parallel_manager`, and other utility functions.
The provided code supports a hybrid of Data Parallelism (DP, batch-level) and Sequence Parallelism (SP, sequence-level) on linear attention. As an example, assume we have 1 node with 8 GPUs whose ranks are {0, 1, 2, 3, 4, 5, 6, 7}. For a data parallel size of 2 and a sequence parallel size of 4, the DP and SP communication groups will be:
- 4 data parallel groups (with global rank indices): (0, 4), (1, 5), (2, 6), (3, 7)
- 2 sequence parallel groups (with global rank indices): (0, 1, 2, 3), (4, 5, 6, 7)

In summary, the group mapping (with per-group rank indices) is as follows:

```
Global ranks:            0, 1, 2, 3, 4, 5, 6, 7
Data parallel ranks:     0, 0, 0, 0, 1, 1, 1, 1
Sequence parallel ranks: 0, 1, 2, 3, 0, 1, 2, 3
```
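The sketch below shows one way such groups could be constructed with `torch.distributed`; it is illustrative only (the helper name `build_groups` is hypothetical), while the project's own group management lives in `lasp/utils/seq_parallel_manager`.

```python
# Illustrative sketch of building DP/SP process groups for the layout above;
# `build_groups` is a hypothetical helper, not the package's actual API.
import torch.distributed as dist

def build_groups(world_size: int, sp_size: int):
    """Create DP and SP process groups for an SP-major rank layout.

    Consecutive blocks of `sp_size` ranks form one SP group, matching
    the (0, 1, 2, 3), (4, 5, 6, 7) example above.
    """
    assert world_size % sp_size == 0
    dp_size = world_size // sp_size
    rank = dist.get_rank()  # requires dist.init_process_group() first

    dp_group = sp_group = None
    # Sequence parallel groups: consecutive blocks of sp_size ranks.
    for i in range(dp_size):
        ranks = list(range(i * sp_size, (i + 1) * sp_size))
        group = dist.new_group(ranks)  # collective: every rank must call it
        if rank in ranks:
            sp_group = group
    # Data parallel groups: ranks sharing an SP index, stride sp_size.
    for j in range(sp_size):
        ranks = list(range(j, world_size, sp_size))
        group = dist.new_group(ranks)
        if rank in ranks:
            dp_group = group
    return dp_group, sp_group
```

With `world_size=8` and `sp_size=4`, this yields exactly the SP groups (0, 1, 2, 3), (4, 5, 6, 7) and DP groups (0, 4), (1, 5), (2, 6), (3, 7) listed above.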
Run the following commands:

```bash
cd tests
bash script.sh
```
You will test LASP with randomly generated Q, K, V, and dO under the following distributed cases, run sequentially:

| Parallel Type | Case 1 | Case 2 | Case 3 | Case 4 |
|---|---|---|---|---|
| Data Parallel Size | 1 | 2 | 4 | 8 |
| Sequence Parallel Size | 8 | 4 | 2 | 1 |
Other configurations used in the test are: batch size per device b=2, sequence length n=2048, number of heads h=12, and head dim d=128.
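For reference, inputs with these shapes could be generated as below; the (b, h, n, d) layout and the `bfloat16` dtype are assumptions for illustration, so check `tests/test.py` for what the suite actually uses.

```python
# Illustrative only: the (b, h, n, d) layout and bfloat16 dtype are
# assumptions; see tests/test.py for the shapes the suite actually uses.
import torch

b, h, n, d = 2, 12, 2048, 128  # batch per device, heads, seq len, head dim
shape = (b, h, n, d)
q = torch.randn(shape, dtype=torch.bfloat16, device="cuda", requires_grad=True)
k = torch.randn(shape, dtype=torch.bfloat16, device="cuda", requires_grad=True)
v = torch.randn(shape, dtype=torch.bfloat16, device="cuda", requires_grad=True)
do = torch.randn(shape, dtype=torch.bfloat16, device="cuda")  # upstream grad dO
```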
By running the test, you will get the mean difference values of Oi, dQi, dKi, and dVi obtained by LASP, compared against the reference values from Lightning Attention (see our Lightning Attention work at: https://github.com/OpenNLPLab/lightning-attention).
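As a hedged sketch, "mean difference" can be read as the mean absolute elementwise difference against the Lightning Attention reference; the helper below is illustrative, not the test's exact code.

```python
# Illustrative metric: mean absolute difference against a reference tensor.
import torch

def mean_diff(x: torch.Tensor, ref: torch.Tensor) -> float:
    # Cast to float32 to avoid low-precision accumulation artifacts.
    return (x.float() - ref.float()).abs().mean().item()

# e.g., compare LASP's output Oi against the Lightning Attention reference:
# print(mean_diff(o_lasp, o_ref))
```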
| Model | Parameters | Method | Loss | Method | Loss |
|---|---|---|---|---|---|
| TNL | 0.4B | DDP | 3.719 | LASP + DDP | 3.715 |
| TNL | 0.4B | Legacy DDP | 3.709 | LASP + Legacy DDP | 3.705 |
| TNL | 0.4B | FSDP | 3.717 | LASP + FSDP | 3.714 |
| TNL | 0.4B | ZeRO-1 | 3.653 | LASP + ZeRO-1 | 3.653 |
| TNL | 0.4B | ZeRO-2 | 3.655 | LASP + ZeRO-2 | 3.649 |
| TNL | 0.4B | ZeRO-3 | 3.656 | LASP + ZeRO-3 | 3.649 |
| LinearTransformer | 0.4B | DDP | 5.419 | LASP + DDP | 5.408 |
| LinearTransformer | 0.4B | Legacy DDP | 5.425 | LASP + Legacy DDP | 5.413 |
| LinearTransformer | 0.4B | FSDP | 5.428 | LASP + FSDP | 5.441 |
| LinearTransformer | 0.4B | ZeRO-1 | 5.114 | LASP + ZeRO-1 | 5.118 |
| LinearTransformer | 0.4B | ZeRO-2 | 5.105 | LASP + ZeRO-2 | 5.120 |
| LinearTransformer | 0.4B | ZeRO-3 | 5.110 | LASP + ZeRO-3 | 5.123 |
*Convergence Performance of LASP. All experiments use 8 A100 80G GPUs, a 16K sequence length, and a batch size of 1. The results cover various DDP backends in conjunction with LASP. We explore the performance of two linear attention models, TransNormerLLM (TNL) and Linear Transformer, both with 0.4B parameters, across 50K updates.*
*Scalability Evaluation of LASP on Throughput (tokens/sec) and Memory Usage. Left: integration of LASP with the FSDP backend; Right: integration of LASP with the DDP backend. The TNL-1B model is used, with a batch size of 1, on up to 128 A100 80GB GPUs. An "x" with a dotted line indicates an Out of Memory (OOM) occurrence.*
Below are the difference results obtained by running `tests/script.sh` on 4 A100 GPUs, i.e., by running:

```bash
for dp_size in 1 2 4
do
    # logger_dir is assumed to be set earlier in script.sh
    START_TIME=`date +%Y%m%d-%H:%M:%S`
    LOG_FILE=${logger_dir}/${START_TIME}-dp-size-${dp_size}.log
    torchrun --nproc_per_node 4 \
        test.py --dp-size ${dp_size} \
        2>&1 | tee -a $LOG_FILE
done
```
Outputs:

```
Test lasp_naive on world size 4 with data_parallel_size 1 and sequence_parallel_size 4:
### Forward ###
out diff: mean value: 0.0
### Backward ###
dq diff: mean value: 0.016357421875
dk diff: mean value: 0.047119140625
dv diff: mean value: 0.06689453125
Test lasp_cache on world size 4 with data_parallel_size 1 and sequence_parallel_size 4:
### Forward ###
out diff: mean value: 0.0
### Backward ###
dq diff: mean value: 0.016357421875
dk diff: mean value: 0.047119140625
dv diff: mean value: 0.06689453125
Test lasp_fuse on world size 4 with data_parallel_size 1 and sequence_parallel_size 4:
### Forward ###
out diff: mean value: 0.051025390625
### Backward ###
dq diff: mean value: 0.0186767578125
dk diff: mean value: 0.021240234375
dv diff: mean value: 0.06396484375
Test lasp_fuse_parallel on world size 4 with data_parallel_size 1 and sequence_parallel_size 4:
### Forward ###
out diff: mean value: 0.0179443359375
### Backward ###
dq diff: mean value: 0.0205078125
dk diff: mean value: 0.025146484375
dv diff: mean value: 0.03564453125
Test lasp_naive on world size 4 with data_parallel_size 2 and sequence_parallel_size 2:
### Forward ###
out diff: mean value: 0.0
### Backward ###
dq diff: mean value: 0.0169677734375
dk diff: mean value: 0.04638671875
dv diff: mean value: 0.0654296875
Test lasp_cache on world size 4 with data_parallel_size 2 and sequence_parallel_size 2:
### Forward ###
out diff: mean value: 0.0
### Backward ###
dq diff: mean value: 0.0169677734375
dk diff: mean value: 0.04638671875
dv diff: mean value: 0.0654296875
Test lasp_fuse on world size 4 with data_parallel_size 2 and sequence_parallel_size 2:
### Forward ###
out diff: mean value: 0.05224609375
### Backward ###
dq diff: mean value: 0.0198974609375
dk diff: mean value: 0.021240234375
dv diff: mean value: 0.06396484375
Test lasp_fuse_parallel on world size 4 with data_parallel_size 2 and sequence_parallel_size 2:
### Forward ###
out diff: mean value: 0.0240478515625
### Backward ###
dq diff: mean value: 0.022705078125
dk diff: mean value: 0.0250244140625
dv diff: mean value: 0.035400390625
Test lasp_naive on world size 4 with data_parallel_size 4 and sequence_parallel_size 1:
### Forward ###
out diff: mean value: 0.0
### Backward ###
dq diff: mean value: 0.0172119140625
dk diff: mean value: 0.0172119140625
dv diff: mean value: 0.0244140625
Test lasp_cache on world size 4 with data_parallel_size 4 and sequence_parallel_size 1:
### Forward ###
out diff: mean value: 0.0
### Backward ###
dq diff: mean value: 0.0172119140625
dk diff: mean value: 0.0172119140625
dv diff: mean value: 0.0244140625
Test lasp_fuse on world size 4 with data_parallel_size 4 and sequence_parallel_size 1:
### Forward ###
out diff: mean value: 0.053466796875
### Backward ###
dq diff: mean value: 0.0205078125
dk diff: mean value: 0.0205078125
dv diff: mean value: 0.06298828125
Test lasp_fuse_parallel on world size 4 with data_parallel_size 4 and sequence_parallel_size 1:
### Forward ###
out diff: mean value: 0.027099609375
### Backward ###
dq diff: mean value: 0.02392578125
dk diff: mean value: 0.02392578125
dv diff: mean value: 0.033935546875
```