Add fast decode plan for flashinfer mla #3987
Conversation
```python
    spec_info: Optional[SpecInfo],
    **kwargs,
```
Do not use `**kwargs`. It makes the code harder to read because we do not know what the exact arguments are. Is it possible to specify them more clearly?
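A minimal sketch of what this suggestion could look like, with an assumed method name and parameter list (not the actual signature in the codebase): the CPU copy of seq_lens is passed as an explicit, typed parameter instead of being hidden in `**kwargs`.

```python
from typing import Optional

import torch


class SpecInfo:  # placeholder so the sketch is self-contained
    pass


class FlashInferMLAAttnBackendSketch:
    # Hypothetical method name and parameters, for illustration only:
    # the value previously passed through **kwargs becomes an explicit,
    # optional argument that readers and callers can see.
    def init_forward_metadata_replay_cuda_graph(
        self,
        bs: int,
        seq_lens: torch.Tensor,
        spec_info: Optional[SpecInfo],
        seq_lens_cpu: Optional[torch.Tensor] = None,  # was buried in **kwargs
    ) -> None:
        if seq_lens_cpu is None:
            # Fall back to a device-to-host copy when the CPU tensor
            # was not provided by the scheduler.
            seq_lens_cpu = seq_lens[:bs].cpu()
        # ... build the decode plan from seq_lens_cpu ...
```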
```python
    spec_info: Optional[SpecInfo],
    **kwargs,
```
This should be removed.
```diff
@@ -1168,8 +1171,10 @@ def merge_batch(self, other: "ScheduleBatch"):

    def get_model_worker_batch(self):
        if self.forward_mode.is_decode_or_idle():
            decode_seq_lens = self.seq_lens.cpu()
```
This will slow down other things (e.g., speculative decoding, where the overlap scheduler is turned off). Can we only do this when needed?
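A minimal sketch of the reviewer's suggestion, with an assumed `needs_cpu_seq_lens` flag that is not part of this PR's diff: the device-to-host copy is only performed when a backend actually consumes seq_lens on the CPU.

```python
from typing import Optional

import torch


def maybe_copy_seq_lens_to_cpu(
    seq_lens: torch.Tensor,
    is_decode_or_idle: bool,
    needs_cpu_seq_lens: bool,  # assumed flag, e.g. set only for the flashinfer MLA backend
) -> Optional[torch.Tensor]:
    # Skip the synchronizing .cpu() call for batches that will never
    # read seq_lens on the host (e.g. speculative decoding paths).
    if is_decode_or_idle and needs_cpu_seq_lens:
        return seq_lens.cpu()
    return None
```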
```python
# Common inputs
self.input_ids[:raw_num_token].copy_(forward_batch.input_ids)
self.req_pool_indices[:raw_bs].copy_(forward_batch.req_pool_indices)
self.seq_lens[:raw_bs].copy_(forward_batch.seq_lens)
self.out_cache_loc[:raw_num_token].copy_(forward_batch.out_cache_loc)
self.positions[:raw_num_token].copy_(forward_batch.positions)
if forward_batch.decode_seq_lens_cpu is not None:
    self.seq_lens_cpu[:raw_bs].copy_(forward_batch.decode_seq_lens_cpu)
```
If it is a CPU tensor, it does not need to go through these CUDA graph buffer copies.
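A small sketch of the reviewer's point, reusing the field names from the diff above: a CPU-side tensor is not captured by the CUDA graph, so it does not need to be copied into a preallocated graph buffer; keeping a sliced reference is enough.

```python
from typing import Optional

import torch


def pick_seq_lens_cpu(
    decode_seq_lens_cpu: Optional[torch.Tensor], raw_bs: int
) -> Optional[torch.Tensor]:
    # Instead of self.seq_lens_cpu[:raw_bs].copy_(...), a host tensor can
    # simply be referenced: only device tensors have to live in the fixed
    # buffers that the captured CUDA graph reads from.
    if decode_seq_lens_cpu is None:
        return None
    return decode_seq_lens_cpu[:raw_bs]
```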
This reverts commit fa56106.
Motivation
When the flashinfer MLA backend and CUDA graph are used together, graph replay hangs because indptr tensors are transferred between CPU and GPU inside `BatchMLAPagedAttentionWrapper.plan`. This PR fixes the issue by adding a new `decode_seq_lens_cpu` field to the forward batch and customizing a faster decode plan for graph replay. Also, some issues (#3906, #3917) point out that the current flashinfer MLA backend behaves worse than Triton on long-output cases; hopefully this PR fixes that problem as well.
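The following sketch only illustrates the idea behind the fix, not the actual `fast_mla_decode_plan` added by this PR (its real signature and internals may differ): because seq_lens is already available on the CPU, the paged-KV indptr can be built host-side, so CUDA graph replay never has to wait on a device-to-host copy inside the plan call.

```python
import torch


def fast_decode_plan_sketch(seq_lens_cpu: torch.Tensor, page_size: int = 1) -> torch.Tensor:
    # seq_lens_cpu is a CPU int tensor of per-request sequence lengths.
    # Building indptr here involves no GPU tensors, so there is no
    # memcpyAsync wait during graph replay.
    num_pages = (seq_lens_cpu + page_size - 1) // page_size
    kv_indptr = torch.zeros(seq_lens_cpu.numel() + 1, dtype=torch.int32)
    kv_indptr[1:] = torch.cumsum(num_pages, dim=0)
    return kv_indptr


# Example: three decode requests with lengths 7, 12 and 3 and a page size of 4.
print(fast_decode_plan_sketch(torch.tensor([7, 12, 3]), page_size=4))
# tensor([0, 2, 5, 6], dtype=torch.int32)
```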
Modifications
- Add `decode_seq_lens_cpu` to the forward batch, which puts the seq_lens information on the CPU in advance.
- Add a `fast_mla_decode_plan` that avoids transmitting indptr tensors from GPU to CPU during graph replay.

Accuracy
Launching
GSM8K
MMLU
Benchmark
To better show the improvement from this PR, the benchmarks are run on long-output workloads (so the number of graph replays is increased) with DeepSeek-V2-Lite on an NVIDIA H200. Each benchmark is run five times and the average throughput is reported. With this PR, the throughput of the flashinfer MLA backend on these workloads improves by 1% to 2%.
To Launch:
Input-4096-Output-2048 (same workload as #3917)
Input-180-Output-400 (same workload as #3906)
Input-100-Output-2000
Profiler Result
Profiling with the torch profiler shows that the time spent waiting for memcpyAsync is removed. Since MLA with weight absorption is compute-bound and the GPU is fully utilized, the influence on end-to-end throughput is not obvious.
Before this PR:

After this PR:

Checklist