[Core] Optimizing cross-attention QKVParallelLinear computation
#12325
TL;DR: Basically another take on #7448, based on the work on the Whisper model, with sugar on top to provide a drop-in replacement module.
Addressing TODOs https://github.com/vllm-project/vllm/blob/main/vllm/model_executor/models/bart.py#L352 and https://github.com/vllm-project/vllm/blob/main/vllm/model_executor/models/mllama.py#L750.
The current cross-attention QKV projection is sub-optimal: we waste cycles on bigger-than-necessary matrices, which matters most in the compute-bound stage. That is because `QKVParallelLinear` layers are used to compute only the `q` and `kv` projections, separately, in two sequential calls.

I propose adopting the solution we already use here https://github.com/vllm-project/vllm/blob/main/vllm/model_executor/models/whisper.py#L173, where q/kv are split into a `ColumnParallelLinear` and a `QKVParallelLinear` layer, respectively, instantiating and sharding only the matrices we actually use. Tensor-parallelism support should be unaffected.

I also provide a drop-in replacement util layer, `QKVCrossParallelLinear`, to use in place of `QKVParallelLinear` layers so that the loading code stays the same, in particular the usual `stacked_params_mapping`.

==> Let me know what you think about the util Module interface/API; otherwise I can just substitute its optimized code inline.
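To put a rough number on the wasted work, here is a back-of-the-envelope sketch that counts only the projection matmul FLOPs for the two layouts described above. The helper names are hypothetical (not vLLM APIs), the sizes are illustrative assumptions, and the estimate ignores biases, tensor-parallel sharding, and kernel efficiency, so treat it as an upper-level intuition rather than a benchmark.

```python
def linear_flops(num_tokens: int, in_features: int, out_features: int) -> int:
    """FLOPs of a dense projection, counting each multiply-add as 2 FLOPs."""
    return 2 * num_tokens * in_features * out_features


def cross_attn_projection_flops(
    dec_tokens: int,
    enc_tokens: int,
    hidden: int,
    num_heads: int,
    num_kv_heads: int,
) -> tuple[int, int]:
    """Return (fused, split) projection FLOPs for one cross-attention layer.

    Hypothetical helper for illustration only.
    """
    head_dim = hidden // num_heads
    q_out = num_heads * head_dim          # width of the q projection
    kv_out = 2 * num_kv_heads * head_dim  # width of the stacked k and v projections
    fused_out = q_out + kv_out            # width of a full fused QKV projection

    # Fused layout: the full QKV projection runs twice, once on the decoder
    # states (only q is kept) and once on the encoder states (only k, v are kept).
    fused = (linear_flops(dec_tokens, hidden, fused_out)
             + linear_flops(enc_tokens, hidden, fused_out))

    # Split layout: q is computed from the decoder states only,
    # k and v from the encoder states only.
    split = (linear_flops(dec_tokens, hidden, q_out)
             + linear_flops(enc_tokens, hidden, kv_out))
    return fused, split


# Assumed, illustrative sizes: one decoder token attending to 1024 encoder tokens.
fused, split = cross_attn_projection_flops(
    dec_tokens=1, enc_tokens=1024, hidden=1024, num_heads=16, num_kv_heads=16
)
print(f"fused: {fused:,} FLOPs | split: {split:,} FLOPs | saved: {1 - split / fused:.1%}")
```

Under these assumptions the encoder-side call drops the unused q slice and the decoder-side call drops the unused k and v slices, so the saving depends on the ratio of decoder to encoder tokens and on `num_kv_heads`.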
Early benchmarking results (single L4 24GB, running `facebook/bart-large-cnn`):
PRE-PR
b197a5cc
POST-PR
TODO:
- Document `QKVCrossParallelLinear` both in code and in the "how to add model" docs