forked from OpenNLPLab/LASP

Linear Attention Sequence Parallelism (LASP)

weigao266/LASP

Installation

The code depends only on torch and triton. It has been tested with torch==2.1.1, triton==2.0.0, and triton-nightly==2.1.0.dev20230728172942 on A100 and 3090 GPUs.

We recommend installing triton and triton-nightly with the following commands:

pip install triton==2.0.0
pip install triton-nightly==2.1.0.dev20230728172942 --index-url https://aiinfra.pkgs.visualstudio.com/PublicPackages/_packaging/Triton-Nightly/pypi/simple/

To install LASP from source, run:

cd LASP
# -e installs the package in editable (development) mode
pip install -e .
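To check whether an environment matches the versions tested above, a short script using the standard library's `importlib.metadata` can help. This is a hypothetical helper, not part of LASP, and it assumes the pip distribution names are exactly `torch`, `triton`, and `triton-nightly`, as in the install commands.

```python
from importlib.metadata import PackageNotFoundError, version

# Versions this README reports as tested; other versions may also work.
TESTED = {
    "torch": "2.1.1",
    "triton": "2.0.0",
    "triton-nightly": "2.1.0.dev20230728172942",
}


def installed_version(package: str):
    """Return the installed version of `package`, or None if it is not installed."""
    try:
        return version(package)
    except PackageNotFoundError:
        return None


def report(tested=TESTED):
    """Return (package, installed, tested, matches) for each package in `tested`."""
    rows = []
    for pkg, want in tested.items():
        have = installed_version(pkg)
        # Drop local build tags such as "+cu118" before comparing.
        matches = have is not None and have.split("+")[0] == want
        rows.append((pkg, have, want, matches))
    return rows


if __name__ == "__main__":
    for pkg, have, want, ok in report():
        print(f"{pkg}: installed={have}, tested={want}, match={ok}")
```

A mismatch is not necessarily a problem; it only means the combination differs from the one reported as tested.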

Code Structure

The code is organized as follows:

  • tests/ contains the launch script and the test file used for running tests.
  • lasp/ contains the implementation of LASP and its improved variants, including lasp_naive, lasp_cache, lasp_fuse, and lasp_fuse_parallel, as well as the non-sequence-parallel version, lightning_attention.
  • lasp/utils/ contains the communication manager for LASP (seq_parallel_manager) and other utility functions.

Code Usage

The code supports a hybrid of Data Parallelism (DP, batch-level) and Sequence Parallelism (SP, sequence-level) for linear attention. For example, assume 1 node with 8 GPUs whose global ranks are {0, 1, 2, 3, 4, 5, 6, 7}. With a data parallel size of 2 and a sequence parallel size of 4, the DP and SP communication groups are:

4 data parallel groups (with global rank indices):

(0, 4), (1, 5), (2, 6), (3, 7)

2 sequence parallel groups (with global rank indices):

(0, 1, 2, 3), (4, 5, 6, 7)

In summary, the rank mapping (each rank's index within its own group) is as follows:

Global ranks:             0, 1, 2, 3, 4, 5, 6, 7
Data parallel ranks:      0, 0, 0, 0, 1, 1, 1, 1
Sequence parallel ranks:  0, 1, 2, 3, 0, 1, 2, 3
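The grouping above can be reproduced in a few lines of plain Python. This is a minimal sketch assuming, as in the 8-GPU example, that each SP group is a block of consecutive global ranks; the helpers `parallel_groups` and `local_ranks` are hypothetical and are not part of `lasp.utils`.

```python
def parallel_groups(world_size: int, dp_size: int):
    """Build (dp_groups, sp_groups) of global ranks, matching the example layout.

    Assumption: consecutive global ranks form one sequence parallel group.
    """
    if world_size % dp_size != 0:
        raise ValueError("world_size must be divisible by dp_size")
    sp_size = world_size // dp_size
    # SP groups: blocks of sp_size consecutive ranks, e.g. (0, 1, 2, 3), (4, 5, 6, 7).
    sp_groups = [tuple(range(i * sp_size, (i + 1) * sp_size)) for i in range(dp_size)]
    # DP groups: ranks at the same position inside their SP group, e.g. (0, 4).
    dp_groups = [tuple(range(j, world_size, sp_size)) for j in range(sp_size)]
    return dp_groups, sp_groups


def local_ranks(global_rank: int, sp_size: int):
    """Return (data_parallel_rank, sequence_parallel_rank) for a global rank."""
    return global_rank // sp_size, global_rank % sp_size


dp_groups, sp_groups = parallel_groups(world_size=8, dp_size=2)
print(dp_groups)  # [(0, 4), (1, 5), (2, 6), (3, 7)]
print(sp_groups)  # [(0, 1, 2, 3), (4, 5, 6, 7)]
print([local_ranks(r, sp_size=4) for r in range(8)])
```

In an actual distributed run, each rank tuple could be passed to `torch.distributed.new_group(ranks=...)`; every process must call `new_group` for every group, including groups it does not belong to. How LASP's `seq_parallel_manager` builds its groups is not shown here.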

Run the following commands:

cd tests
bash script.sh

The test runs LASP on randomly generated Q, K, V, and dO under the following distributed configurations, one after another:

Parallel Type           Case1  Case2  Case3  Case4
Data Parallel Size      1      2      4      8
Sequence Parallel Size  8      4      2      1
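The four cases are every way to factor 8 GPUs into DP size × SP size. A small sketch (the helper `dp_sp_cases` is hypothetical) enumerates the same factorizations for any world size:

```python
def dp_sp_cases(world_size: int):
    """List (dp_size, sp_size) pairs with dp_size * sp_size == world_size."""
    return [(dp, world_size // dp) for dp in range(1, world_size + 1) if world_size % dp == 0]


print(dp_sp_cases(8))  # [(1, 8), (2, 4), (4, 2), (8, 1)]
print(dp_sp_cases(4))  # [(1, 4), (2, 2), (4, 1)]
```

For 4 GPUs this yields dp_size values 1, 2, and 4, the same values looped over in the 4-GPU run further below.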

The other test settings are: batch size per device b=2, sequence length n=2048, number of heads h=12, and head dimension d=128.
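The README does not state whether n=2048 is the full sequence or each device's share. Assuming it is the full sequence, split into equal contiguous chunks across the ranks of one SP group, and assuming a (b, h, n, d) tensor layout, the per-rank slice could be computed as in this hypothetical sketch (`chunk_bounds` and `local_shape` are illustrative names, not LASP functions):

```python
def chunk_bounds(seq_len: int, sp_size: int, sp_rank: int):
    """[start, end) token range held by `sp_rank`, assuming an even, contiguous split."""
    if seq_len % sp_size != 0:
        raise ValueError("seq_len must be divisible by sp_size")
    chunk = seq_len // sp_size
    return sp_rank * chunk, (sp_rank + 1) * chunk


def local_shape(b: int, h: int, n: int, d: int, sp_size: int):
    """Per-rank shape of Q, K, V or dO under the assumed (b, h, n, d) layout."""
    start, end = chunk_bounds(n, sp_size, 0)
    return (b, h, end - start, d)


# Test settings from this README: b=2, n=2048, h=12, d=128.
for sp_size in (8, 4, 2, 1):
    last = chunk_bounds(2048, sp_size, sp_size - 1)
    print(sp_size, local_shape(2, 12, 2048, 128, sp_size), last)
```

If n is instead the per-device length, the global sequence length would be n × SP size and no further splitting applies.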

The test reports the mean differences between the O_i, dQ_i, dK_i, and dV_i computed by LASP and the reference values computed by Lightning Attention (see our lightning attention work at: https://github.com/OpenNLPLab/lightning-attention).
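Why a chunked, distributed computation can match a single-device reference is easiest to see in a toy setting. The sketch below is a simplified single-process illustration, not LASP's implementation: it uses causal linear attention without softmax, decay, or normalization, treats each chunk as if it lived on its own SP rank, and carries the running state KV = Σ K_jᵀV_j from earlier chunks in a loop where a real run would send it between GPUs. "Mean difference" is assumed here to be the mean absolute elementwise difference.

```python
import numpy as np


def causal_linear_attention(q, k, v):
    """Reference: O = (Q K^T masked to s <= t) V, no softmax, decay, or normalization."""
    n = q.shape[0]
    mask = np.tril(np.ones((n, n)))
    return ((q @ k.T) * mask) @ v


def chunked_linear_attention(q, k, v, num_chunks):
    """Same result computed chunk by chunk, as if each chunk sat on its own SP rank."""
    n, d = q.shape
    if n % num_chunks != 0:
        raise ValueError("sequence length must be divisible by num_chunks")
    size = n // num_chunks
    kv = np.zeros((d, v.shape[1]))  # state that would arrive from the previous rank
    outs = []
    for i in range(num_chunks):
        rows = slice(i * size, (i + 1) * size)
        qi, ki, vi = q[rows], k[rows], v[rows]
        intra = causal_linear_attention(qi, ki, vi)  # causal attention inside the chunk
        inter = qi @ kv                              # contribution of all earlier chunks
        outs.append(intra + inter)
        kv = kv + ki.T @ vi                          # state that would go to the next rank
    return np.concatenate(outs, axis=0)


def mean_abs_diff(a, b):
    """Mean absolute elementwise difference (assumed metric)."""
    return float(np.mean(np.abs(a - b)))


rng = np.random.default_rng(0)
q, k, v = (rng.standard_normal((64, 16)) for _ in range(3))
out_chunked = chunked_linear_attention(q, k, v, num_chunks=4)
diff = mean_abs_diff(out_chunked, causal_linear_attention(q, k, v))
print(f"out diff: mean value: {diff}")
```

Because this toy runs in float64 with plain matrix products, the difference sits at rounding level; the README's reported values come from LASP's own kernels and are not reproduced by this sketch.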

Benchmark

Convergence Results

Model              Parameters  Method       Loss   Method              Loss
TNL                0.4B        DDP          3.719  LASP + DDP          3.715
TNL                0.4B        Legacy DDP   3.709  LASP + Legacy DDP   3.705
TNL                0.4B        FSDP         3.717  LASP + FSDP         3.714
TNL                0.4B        ZeRO-1       3.653  LASP + ZeRO-1       3.653
TNL                0.4B        ZeRO-2       3.655  LASP + ZeRO-2       3.649
TNL                0.4B        ZeRO-3       3.656  LASP + ZeRO-3       3.649
LinearTransformer  0.4B        DDP          5.419  LASP + DDP          5.408
LinearTransformer  0.4B        Legacy DDP   5.425  LASP + Legacy DDP   5.413
LinearTransformer  0.4B        FSDP         5.428  LASP + FSDP         5.441
LinearTransformer  0.4B        ZeRO-1       5.114  LASP + ZeRO-1       5.118
LinearTransformer  0.4B        ZeRO-2       5.105  LASP + ZeRO-2       5.120
LinearTransformer  0.4B        ZeRO-3       5.110  LASP + ZeRO-3       5.123

Convergence performance of LASP. All experiments use 8 A100 80GB GPUs, a 16K sequence length, and a batch size of 1. The results cover several data-parallel backends (DDP, Legacy DDP, FSDP, and ZeRO-1/2/3) combined with LASP. We evaluate two linear attention models, TransNormerLLM (TNL) and Linear Transformer, both with 0.4B parameters, over 50K updates.

Scalability Results

Scalability evaluation of LASP on throughput (tokens/sec) and memory usage. Left: LASP integrated with the FSDP backend; right: LASP integrated with the DDP backend. The TNL-1B model is used with a batch size of 1 on up to 128 A100 80GB GPUs. An "x" on a dotted line marks an out-of-memory (OOM) failure.

Difference Results on 4 A100 GPUs

The following difference results were obtained by running tests/script.sh on 4 A100 GPUs, i.e., by running:

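# logger_dir is expected to be set earlier in script.sh; "logs" is only a placeholder default
logger_dir=${logger_dir:-logs}
mkdir -p "${logger_dir}"

for dp_size in 1 2 4
do
    START_TIME=$(date +%Y%m%d-%H:%M:%S)
    LOG_FILE=${logger_dir}/${START_TIME}-dp-size-${dp_size}.log
    # One process per GPU; the outputs below show sequence_parallel_size = 4 / dp_size
    torchrun --nproc_per_node 4 \
        test.py --dp-size ${dp_size} \
        2>&1 | tee -a "$LOG_FILE"
done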

Outputs:

Test lasp_naive on world size 4 with data_parallel_size 1 and sequence_parallel_size 4:
### Forward ###
out diff: mean value: 0.0
### Backward ###
dq diff: mean value: 0.016357421875
dk diff: mean value: 0.047119140625
dv diff: mean value: 0.06689453125


Test lasp_cache on world size 4 with data_parallel_size 1 and sequence_parallel_size 4:
### Forward ###
out diff: mean value: 0.0
### Backward ###
dq diff: mean value: 0.016357421875
dk diff: mean value: 0.047119140625
dv diff: mean value: 0.06689453125


Test lasp_fuse on world size 4 with data_parallel_size 1 and sequence_parallel_size 4:
### Forward ###
out diff: mean value: 0.051025390625
### Backward ###
dq diff: mean value: 0.0186767578125
dk diff: mean value: 0.021240234375
dv diff: mean value: 0.06396484375


Test lasp_fuse_parallel on world size 4 with data_parallel_size 1 and sequence_parallel_size 4:
### Forward ###
out diff: mean value: 0.0179443359375
### Backward ###
dq diff: mean value: 0.0205078125
dk diff: mean value: 0.025146484375
dv diff: mean value: 0.03564453125


Test lasp_naive on world size 4 with data_parallel_size 2 and sequence_parallel_size 2:
### Forward ###
out diff: mean value: 0.0
### Backward ###
dq diff: mean value: 0.0169677734375
dk diff: mean value: 0.04638671875
dv diff: mean value: 0.0654296875


Test lasp_cache on world size 4 with data_parallel_size 2 and sequence_parallel_size 2:
### Forward ###
out diff: mean value: 0.0
### Backward ###
dq diff: mean value: 0.0169677734375
dk diff: mean value: 0.04638671875
dv diff: mean value: 0.0654296875


Test lasp_fuse on world size 4 with data_parallel_size 2 and sequence_parallel_size 2:
### Forward ###
out diff: mean value: 0.05224609375
### Backward ###
dq diff: mean value: 0.0198974609375
dk diff: mean value: 0.021240234375
dv diff: mean value: 0.06396484375


Test lasp_fuse_parallel on world size 4 with data_parallel_size 2 and sequence_parallel_size 2:
### Forward ###
out diff: mean value: 0.0240478515625
### Backward ###
dq diff: mean value: 0.022705078125
dk diff: mean value: 0.0250244140625
dv diff: mean value: 0.035400390625


Test lasp_naive on world size 4 with data_parallel_size 4 and sequence_parallel_size 1:
### Forward ###
out diff: mean value: 0.0
### Backward ###
dq diff: mean value: 0.0172119140625
dk diff: mean value: 0.0172119140625
dv diff: mean value: 0.0244140625


Test lasp_cache on world size 4 with data_parallel_size 4 and sequence_parallel_size 1:
### Forward ###
out diff: mean value: 0.0
### Backward ###
dq diff: mean value: 0.0172119140625
dk diff: mean value: 0.0172119140625
dv diff: mean value: 0.0244140625


Test lasp_fuse on world size 4 with data_parallel_size 4 and sequence_parallel_size 1:
### Forward ###
out diff: mean value: 0.053466796875
### Backward ###
dq diff: mean value: 0.0205078125
dk diff: mean value: 0.0205078125
dv diff: mean value: 0.06298828125


Test lasp_fuse_parallel on world size 4 with data_parallel_size 4 and sequence_parallel_size 1:
### Forward ###
out diff: mean value: 0.027099609375
### Backward ###
dq diff: mean value: 0.02392578125
dk diff: mean value: 0.02392578125
dv diff: mean value: 0.033935546875
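Since the logs are plain text, a small parser makes it easier to compare methods side by side. The sketch below assumes the log lines look exactly like the outputs above; `parse_log` is a hypothetical helper, not part of LASP.

```python
import re
from collections import defaultdict

HEADER = re.compile(
    r"Test (\w+) on world size (\d+) with data_parallel_size (\d+) "
    r"and sequence_parallel_size (\d+):"
)
DIFF = re.compile(r"(\w+) diff: mean value: ([0-9.eE+-]+)")


def parse_log(text: str):
    """Map (method, dp_size, sp_size) -> {"out": ..., "dq": ..., "dk": ..., "dv": ...}."""
    results = defaultdict(dict)
    key = None
    for raw in text.splitlines():
        line = raw.strip()
        header = HEADER.match(line)
        if header:
            method, _world_size, dp, sp = header.groups()
            key = (method, int(dp), int(sp))
            continue
        diff = DIFF.match(line)
        if diff and key is not None:
            results[key][diff.group(1)] = float(diff.group(2))
    return dict(results)


sample = """Test lasp_fuse on world size 4 with data_parallel_size 2 and sequence_parallel_size 2:
### Forward ###
out diff: mean value: 0.05224609375
### Backward ###
dq diff: mean value: 0.0198974609375
dk diff: mean value: 0.021240234375
dv diff: mean value: 0.06396484375
"""
parsed = parse_log(sample)
print(parsed)
```

Running `parse_log` on a full log file written by the script would give one entry per method and DP/SP configuration.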
