Add fast decode plan for flashinfer mla #3987

Merged · 6 commits into sgl-project:main on Mar 3, 2025

Conversation

@Fridge003 (Collaborator) commented Mar 2, 2025

Motivation

When the flashinfer MLA backend and CUDA graph are used together, graph replay hangs because indptr tensors are transferred between CPU and GPU inside BatchMLAPagedAttentionWrapper.plan.

This PR fixes the issue by adding a new decode_seq_lens_cpu field to the forward batch and customizing a faster decode plan for graph replay.

Some issues (#3906, #3917) also point out that the current flashinfer MLA backend performs worse than Triton on long-output workloads. Hopefully this PR mitigates that problem as well.

Modifications

  • Add a new decode_seq_lens_cpu field to the forward batch, which makes the seq_lens information available on the CPU in advance.
  • Write fast_mla_decode_plan, which avoids transferring indptr tensors from GPU to CPU during graph replay (a rough sketch of the idea follows below).
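
The core idea can be sketched roughly as follows. This is an illustrative sketch only, not the code added by the PR; build_decode_indptr_cpu and kv_indptr_buf are made-up names standing in for the CPU-side seq_lens field and the preallocated flashinfer indptr buffer.

import torch

def build_decode_indptr_cpu(seq_lens_cpu: torch.Tensor, kv_indptr_buf: torch.Tensor) -> None:
    """Fill a preallocated GPU indptr buffer from CPU-side sequence lengths.

    seq_lens_cpu:  integer tensor of shape [batch_size], already on the CPU
                   (the role played by decode_seq_lens_cpu in the forward batch).
    kv_indptr_buf: preallocated int32 tensor of shape [batch_size + 1] (or larger) on the GPU.
    """
    bs = seq_lens_cpu.numel()
    indptr_cpu = torch.zeros(bs + 1, dtype=torch.int32)
    indptr_cpu[1:] = torch.cumsum(seq_lens_cpu, dim=0, dtype=torch.int32)
    # Host-to-device copy only; with pinned memory it can even be asynchronous.
    # Nothing here reads GPU data back to the host, so graph replay never has
    # to block on a device-to-host memcpy.
    kv_indptr_buf[: bs + 1].copy_(indptr_cpu, non_blocking=True)

The expensive seq_lens.cpu() copy still happens, but earlier (in get_model_worker_batch, see the diff below) and outside the graph-replay path.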

Accuracy

Launching

python3 -m sglang.launch_server --model deepseek-ai/DeepSeek-R1 --tp 8 --trust-remote-code --enable-flashinfer-mla

GSM8K

python3 benchmark/gsm8k/bench_sglang.py --num-shots 8 --num-questions 1319 --parallel 1319
Accuracy: 0.956
Invalid: 0.000
Latency: 101.581 s
Output throughput: 1336.431 token/s

MMLU

bash benchmark/mmlu/download_data.sh
python3 benchmark/mmlu/bench_sglang.py --nsub 100 --ntrain 5 --parallel 2000
Total latency: 182.686
Average accuracy: 0.871

Benchmark

To better expose the improvement from this PR, the benchmarks are run on long-output workloads (so the number of graph replays is increased) with DeepSeek-V2-Lite on an NVIDIA H200. Each benchmark is run five times and the average throughput is reported. With this PR, the throughput of flashinfer MLA on these workloads improves by 1% to 2%.

To Launch:

python3 -m sglang.launch_server --model deepseek-ai/DeepSeek-V2-Lite --tp 8 --trust-remote-code --enable-flashinfer-mla 

Input-4096-Output-2048 (same workload as #3917)

python3 -m sglang.bench_serving --backend sglang --dataset-name random --random-input 4096 --random-output 2048 --num-prompt 60
Throughput (tok/s)    Flashinfer (this PR)    Flashinfer (before PR)    Triton
Prefill               7757.64                 7540.28                   6563.27
Decode                4038.21                 3925.06                   3416.69

Input-180-Output-400 (same workload as #3906)

python3 -m sglang.bench_serving  --dataset-name=random --num-prompts=600    --random-range-ratio 0.9 --seed 42  --random-input 180 --random-output 400  --request-rate 40 --max-concurrency 40
Throughput (tok/s)    Flashinfer (this PR)    Flashinfer (before PR)    Triton
Prefill               1585.48                 1558.35                   1532.36
Decode                3532.19                 3480.68                   3413.85

Input-100-Output-2000

python3 -m sglang.bench_serving --backend sglang --dataset-name random --random-input 100 --random-output 2000 --request-rate 2 --num-prompt 120
Throughput (tok/s)    Flashinfer (this PR)    Flashinfer (before PR)    Triton
Prefill               80.63                   80.22                     80.00
Decode                1751.71                 1742.90                   1737.88

Profiler Result

Profiling with the torch profiler shows that the time spent waiting on memcpyAsync is removed. Since MLA with weight absorption is compute-bound and the GPU is already fully utilized, the influence on end-to-end throughput is not obvious.
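
For reference, such a trace can be collected with torch.profiler roughly as follows. This is a minimal sketch and not part of the PR; run_decode_step is a dummy stand-in for one decode iteration (a model forward / CUDA graph replay in sglang).

import torch
from torch.profiler import profile, ProfilerActivity

def run_decode_step() -> None:
    # Stand-in for one decode iteration.
    x = torch.randn(1024, 1024, device="cuda")
    (x @ x).sum().item()  # .item() forces a device-to-host copy, visible as Memcpy DtoH

with profile(activities=[ProfilerActivity.CPU, ProfilerActivity.CUDA]) as prof:
    for _ in range(10):
        run_decode_step()
    torch.cuda.synchronize()

# Look for cudaMemcpyAsync / "Memcpy DtoH" entries on the decode path;
# after this PR they should no longer appear during graph replay.
print(prof.key_averages().table(sort_by="cuda_time_total", row_limit=20))
prof.export_chrome_trace("decode_trace.json")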

Before this PR: [profiler trace screenshot, 2025-03-02 14:51:45]

After this PR: [profiler trace screenshot, 2025-03-02 14:52:06]

@zhyncs merged commit fa56106 into sgl-project:main on Mar 3, 2025
1 of 16 checks passed
spec_info: Optional[SpecInfo],
**kwargs,
Contributor:
Do not use kwargs. It makes the code harder to read because we do not know what the exact arguments are.
Is it possible to specify them more explicitly?

spec_info: Optional[SpecInfo],
**kwargs,
Contributor:

This should be removed.

@@ -1168,8 +1171,10 @@ def merge_batch(self, other: "ScheduleBatch"):

def get_model_worker_batch(self):
if self.forward_mode.is_decode_or_idle():
decode_seq_lens = self.seq_lens.cpu()
Contributor:

This will slow down other things (e.g., speculative decoding, where the overlap scheduler is turned off). Can we only do this when needed?
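
A minimal sketch of what gating the copy could look like (hypothetical; backend_needs_seq_lens_cpu is an invented predicate, not an actual sglang attribute):

import torch
from typing import Optional

def maybe_copy_seq_lens_to_cpu(
    seq_lens: torch.Tensor, backend_needs_seq_lens_cpu: bool
) -> Optional[torch.Tensor]:
    # Only pay the blocking device-to-host copy when the attention backend
    # (here: flashinfer MLA's fast decode plan) will actually read the
    # lengths on the host; otherwise skip it entirely.
    return seq_lens.cpu() if backend_needs_seq_lens_cpu else None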


# Common inputs
self.input_ids[:raw_num_token].copy_(forward_batch.input_ids)
self.req_pool_indices[:raw_bs].copy_(forward_batch.req_pool_indices)
self.seq_lens[:raw_bs].copy_(forward_batch.seq_lens)
self.out_cache_loc[:raw_num_token].copy_(forward_batch.out_cache_loc)
self.positions[:raw_num_token].copy_(forward_batch.positions)
if forward_batch.decode_seq_lens_cpu is not None:
self.seq_lens_cpu[:raw_bs].copy_(forward_batch.decode_seq_lens_cpu)
Contributor:

If it is a CPU tensor, it does not need to go through this CUDA graph buffer handling.
