Add Splitwise implementation to vLLM #2809
base: main
Conversation
Initialize MSCCL++ communication group if sep_prompt_token is set in ParallelConfig. Also add documentation for MSCCL++ installation.
- Add worker_type to differentiate prompt, token, and mixed workers
- Set a driver for the prompt machines and the token machines
- Allow broadcasts to take a group
- Set up KV cache communication using MSCCL++
- Add a test for KV cache communication
- Obtain `blocks_to_nw` when creating batches in the scheduler. Coalesce blocks where possible for fast network transfers.
- Use a Sequence-to-Semaphore Mapper to allow fine-grained waiting on the KV-cache transfer per sequence.
- Separately run prompt and token workers using the `_run_stage_workers` helper.
- Populate the KVCacheCommunicator for all PagedAttention modules, which allows layer-wise sends to be implemented from within `attention.py` (a rough sketch follows below).
- Populate the destination rank for the Sampler, which will be used as the root in `gather` operations.
- Fix `tensor_model_parallel_gather`: use the global rank instead of the group-local rank.
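For orientation, here is a rough sketch of that layer-wise send flow. The `put` and `signal_and_flush` method names appear later in this thread, but the class shape and bodies below are assumptions, not the PR's actual code.

```python
# Rough sketch of a layer-wise KV-cache send. Class shape and bodies are
# assumptions for illustration, not the PR's implementation.
from typing import Dict, List, Tuple

import torch


class KVCacheCommunicatorSketch:
    """Toy stand-in for the PR's KV-cache communicator."""

    def __init__(self) -> None:
        self.pending: List[Tuple[int, int]] = []  # (layer_id, block_id)

    def put(self, layer_id: int, block_id: int, kv: torch.Tensor) -> None:
        # The real PR would issue an asynchronous MSCCL++ transfer of this
        # block's K/V data to the destination token worker here.
        self.pending.append((layer_id, block_id))

    def signal_and_flush(self) -> None:
        # The real PR would signal per-sequence semaphores and flush all
        # outstanding transfers; here we only clear the bookkeeping.
        self.pending.clear()


def send_layer_kv(comm: KVCacheCommunicatorSketch, layer_id: int,
                  kv_blocks: Dict[int, torch.Tensor]) -> None:
    """Send a layer's KV blocks as soon as the layer has computed them."""
    for block_id, kv in kv_blocks.items():
        comm.put(layer_id, block_id, kv)


if __name__ == "__main__":
    comm = KVCacheCommunicatorSketch()
    blocks = {0: torch.zeros(16, 128), 1: torch.zeros(16, 128)}
    for layer_id in range(2):   # one send per transformer layer
        send_layer_kv(comm, layer_id, blocks)
    comm.signal_and_flush()     # flush once all layers have been issued
```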
Splitting the prefilling and decoding onto different GPUs is an excellent idea, for example prefilling on H100 and decoding on A100, since the H100 has 3.43x more compute but only 1.6x more memory bandwidth, which means it would be more cost-efficient to use the H100 only for prefilling. In the paper, you demonstrate a 2.35x increase in throughput at the same cost, or 1.4x higher throughput at 20% cost savings. Are you able to reproduce these numbers in this PR?
@casper-hansen In the paper, we use the Splitwise simulator to simulate a 40-node cluster, both for the homogeneous (Splitwise-AA/HH) and the heterogeneous (Splitwise-HHcap/HA) solutions. That simulation allows us to run a production trace through these systems at various requests per second (rps) under SLOs, which is what lets us calculate the maximum throughput under a given cost/power, leading to the scaled results. The code we just pushed only allows us to build the prototype of that solution, since it does not include the optimized cluster-level scheduler. The prototype has been developed and tested on 1 prompt machine and 1 token machine. Therefore, the main point of this PR is to show the optimized KV cache transfer time, rather than the at-scale results. Hope this answers your question.
@GindaChen Hey Junda, can you help take a look at the PR and leave some comments? Thanks!
@aashaka This is a very promising PR to integrate Splitwise into vLLM! I will try to finish up the review today. From what I have understood, this PR only tries to introduce the Splitwise mode, stage parallelism, and KV cache transfer into vLLM.
Awesome changes so far! Please see the comments for requested changes or feedback on future design changes.
I may still have another round of review on comm_utils.py and after running the tests. So far I am stuck on the installation of the MSCCL++ library for some weird reason; I can post the error/environment once I figure out a path to testing.
vllm/worker/worker.py (outdated)

# Populate Sampler with dst_rank as driver worker's rank.
self.model_runner.model.sampler.set_dst_rank(self.model_runner.driver_rank)

def init_mscclpp_comm(self, mscclpp_init_method: Optional[str] = None) -> None:
Recommend separating the initialization of communication into a separate class or function (rather than doing it inside the worker). In my opinion, the vLLM codebase may benefit from having a better abstraction for communication, say:
class CommManager:
...
def init_comm(...): ...
def destroy_comm(...): ...
def get_group(self, group_type): ... # say tensor / stage / pipeline parallel group
...
class Worker:
def __init__(self, ...):
self.comm_manager = CommManager(...)
I'm open to having this as a future refactor, but I want to see if you think separating this logic now is feasible.
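For concreteness, one way the suggested abstraction could be fleshed out; everything below (method bodies, group bookkeeping) is illustrative and not part of the PR.

```python
# One possible fleshing-out of the suggested CommManager abstraction.
# All names and behavior here are assumptions for illustration.
from enum import Enum
from typing import Dict, List


class GroupType(Enum):
    TENSOR = "tensor"
    PIPELINE = "pipeline"
    STAGE = "stage"  # prompt/token stage-parallel group


class CommManager:
    def __init__(self, world_size: int, rank: int):
        self.world_size = world_size
        self.rank = rank
        self.groups: Dict[GroupType, List[int]] = {}

    def init_comm(self) -> None:
        # A real implementation would call
        # torch.distributed.init_process_group(...) here and, when stage
        # parallelism is enabled, set up the MSCCL++ group as well.
        self.groups[GroupType.STAGE] = list(range(self.world_size))

    def get_group(self, group_type: GroupType) -> List[int]:
        return self.groups[group_type]

    def destroy_comm(self) -> None:
        self.groups.clear()


class Worker:
    def __init__(self, world_size: int, rank: int):
        # The worker owns a CommManager instead of initializing
        # communication inline, as the review comment suggests.
        self.comm_manager = CommManager(world_size, rank)
        self.comm_manager.init_comm()
```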
`MSCCL++ <https://github.com/microsoft/mscclpp>`_ is a GPU-driven communication stack for scalable AI applications.
It is used to implement KV cache communication in Splitwise.

To install MSCCL++, please follow the instructions at `MSCCL++ Quickstart <https://github.com/microsoft/mscclpp/blob/main/docs/quickstart.md>`_ or follow the steps below to install it from source:
Just a heads up that the system may need `apt install libnuma-dev` (`libnuma1`) prior to running `make` (I hit this error during installation).
Thanks for the heads up. I will add this to the installation instructions both here and in the MSCCL++ repo. Also, feel free to reach out to me if you have any other problems with MSCCL++ setup.
@@ -369,14 +369,22 @@ def __init__(
        worker_use_ray: bool,
        max_parallel_loading_workers: Optional[int] = None,
        disable_custom_all_reduce: bool = False,
        sep_prompt_token: bool = False,
The modification to the config looks good to me. I guess there may be a future extension to pass in the number of prompt/token workers, and so far the abstraction looks good.
vllm/utils.py (outdated)

class SeqToSlotMapper:
    """ SeqToSlotMapper maps sequence ids to a limited set of slot ids.
    A slot is freed every time a sequence finishes. It is used to manage
It would be helpful to introduce what a "slot" is here.
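For readers following along, a minimal sketch of the slot idea, assuming a fixed pool of slot ids reused as sequences finish; this is an illustration, not the PR's exact implementation.

```python
# Minimal sketch of a sequence-id -> slot-id mapper with a fixed number of
# slots that are reused as sequences finish. Interface is assumed.
from typing import Dict, List


class SeqToSlotMapperSketch:
    def __init__(self, num_slots: int):
        self.free_slots: List[int] = list(range(num_slots))
        self.seq_to_slot: Dict[int, int] = {}

    def set_seq(self, seq_id: int) -> int:
        # Assign the next free slot to a newly scheduled sequence.
        slot = self.free_slots.pop(0)
        self.seq_to_slot[seq_id] = slot
        return slot

    def free_seq(self, seq_id: int) -> None:
        # Return the slot to the pool when the sequence finishes.
        self.free_slots.append(self.seq_to_slot.pop(seq_id))

    def get_slot(self, seq_id: int) -> int:
        return self.seq_to_slot[seq_id]


if __name__ == "__main__":
    mapper = SeqToSlotMapperSketch(num_slots=2)
    slot = mapper.set_seq(100)
    mapper.free_seq(100)
    assert mapper.set_seq(200) == slot  # the freed slot is reused
```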
all_outputs = self._run_stage_workers(
    "execute_model",
    prompt_stage=seq_group_metadata_list[0].is_prompt,
    driver_kwargs={
        "seq_group_metadata_list": seq_group_metadata_list,
        "blocks_to_swap_in": scheduler_outputs.blocks_to_swap_in,
        "blocks_to_swap_out": scheduler_outputs.blocks_to_swap_out,
        "blocks_to_copy": scheduler_outputs.blocks_to_copy,
        "blocks_to_nw": scheduler_outputs.blocks_to_nw,
    })
Just want to confirm - this means prompt workers cannot work concurrently with decode workers because each time we only schedule one set of workers to run.
I understand this PR is a prototype to verify KV cache transfer between prompt/decode workers, so the performance of running sets of workers concurrently isn't the focus. I do want to point out that making workers run concurrently (and the communication between schedulers) turns out to be one of the major design challenges, one that could potentially break vLLM's current architecture.
I would be more than happy to hear if you have a great solution! I'm also open to talking in detail about our design and the vLLM team's concerns about the architectural change.
With that said, I am personally okay with leaving this code as-is in this PR so we can demonstrate the KV cache transfer.
Recommend adding a TODO to say something like "this doesn't achieve the best performance because we only schedule one set of workers at a time" so people don't get confused.
@GindaChen thanks for all the comments. I have updated the PR accordingly. Please let me know if you have any further comments.
@aashaka Thanks for the heads up! I will take a look at the changes soon!
> this means prompt workers cannot work concurrently with decode workers

I do not quite understand this part. Since decode workers are provisioned separately, why can't they run concurrently with prompt workers here?

> because each time we only schedule one set of workers to run.

Does "one set of workers" mean a <prompt, worker> pair? Are you trying to say that it currently doesn't support <List, List> to accelerate the prefill and decoding phases (which would require model parallelism)?
Hello, I have encountered a few problems when reproducing the results. First, some difficulty arose when installing the MSCCL++ library. It requires at least CUDA 11 (I don't know exactly which version); 10.2 will not work for sure. You also need cmake >= 3.25.0. The main difficulty with the installation was installing libibverbs-dev correctly: version 17.1 of libibverbs-dev will not work because of errors during the build process, so you need at least 28; I was able to build with libibverbs-dev >= 36.0-1. However, I had to add …

Second, I don't have an "eth0" interface on my machine. There are others in …

Could you give me a hint? Maybe some profiling information can be gathered to see what the problem might be.
@valvarl, the communication setup will require InfiniBand support. Looks like …
Hi, does this mean that InfiniBand support is a must-have to enable the Splitwise feature?
Unfortunately, I don't have physical access to the server. However, I can see some information about available connections on it.
I am not familiar with the RDMA library and I don't understand how to use …
@leiwen83 Currently, yes, that is the case. While it has been a low-priority item for us, we do have a plan to support Ethernet in the future.
Hi @aashaka, what are the system requirements to run this Splitwise implementation?
Hi. I was trying to run this PR on my machine. I tested the code by running 'tests/distributed/test_kvcache_comm.py', but I am getting the error below.
It seems like the kvcache_comm_manager.put method works without any problem (inside attention.py), but kvcache_comm_manager.signal_and_flush gets this error inside worker.py. I couldn't figure out the source of the problem. Does this error message say something to you?
self.world_size = pipeline_parallel_size * tensor_parallel_size
if sep_prompt_token:
    # Half of the workers are prompt workers and the other half are token
Should it be a fixed split? If the workers use the exact same GPUs, it seems the prompt side may need more workers?
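For illustration, a minimal sketch of the even split implied by the code comment above, assuming the first half of the ranks serve prompts; a configurable ratio, as this question suggests, would replace the fixed halves.

```python
# Sketch of assigning worker roles under an even prompt/token split.
# The rank-based split is an assumption inferred from the code comment
# above, not the PR's exact logic.
from enum import Enum


class WorkerType(Enum):
    PROMPT = "prompt"
    TOKEN = "token"
    MIXED = "mixed"


def assign_worker_type(rank: int, world_size: int,
                       sep_prompt_token: bool) -> WorkerType:
    if not sep_prompt_token:
        return WorkerType.MIXED
    # First half of the ranks handle prompts, second half handles tokens.
    return WorkerType.PROMPT if rank < world_size // 2 else WorkerType.TOKEN


if __name__ == "__main__":
    world_size = 8
    roles = [assign_worker_type(r, world_size, True) for r in range(world_size)]
    assert roles[:4] == [WorkerType.PROMPT] * 4
    assert roles[4:] == [WorkerType.TOKEN] * 4
```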
@@ -116,8 +116,13 @@ def __init__(
        # Profile the memory usage and initialize the cache.
        self._init_cache()

        if self.parallel_config.sep_prompt_token:
            # Setup the MSCCL++ communication required for KV cache transfer
            self._setup_kvcache_comm()
Just curious: why do we need MSCCL++ for communication? Are there other options, like torch.distributed with the NCCL backend? Any analysis on the communication collective library used for KV cache transfer?
@@ -229,6 +250,7 @@ def _init_workers_ray(self, placement_group: "PlacementGroup",

        distributed_init_method = get_distributed_init_method(
            driver_ip, get_open_port())
        mscclpp_init_method = f"eth0:{driver_ip}:{get_open_port()}" if self.parallel_config.sep_prompt_token else None
Minor: can we leave a TODO here?
- I feel this is limited if we want to use a high-speed network interface. It would be great to extract an env var like NCCL_SOCKET_IFNAME.
- I do not know whether MSCCL++ is compatible with other high-speed interfaces.
Besides, I don't think everyone's default network interface is eth0. I agree with @Jeffwan's suggestion, or at least let the user specify which network interface they would like to use.
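As a concrete version of that suggestion, the interface could be read from an environment variable instead of being hard-coded; the variable name used below is hypothetical.

```python
# Illustration of selecting the MSCCL++ socket interface from an environment
# variable instead of hard-coding "eth0". The name VLLM_MSCCLPP_IFNAME is
# hypothetical, chosen only for this example.
import os


def build_mscclpp_init_method(driver_ip: str, port: int) -> str:
    ifname = os.environ.get("VLLM_MSCCLPP_IFNAME", "eth0")
    return f"{ifname}:{driver_ip}:{port}"


if __name__ == "__main__":
    print(build_mscclpp_init_method("10.0.0.1", 29500))
```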
What are the features provided by your cluster-level scheduler in this case? Something like prompt and decode machine collaboration? I am trying to get more details here. Seems …
            len(HEAD_TYPES)) + layer_id * len(HEAD_TYPES) + head_type
        torch.cuda.synchronize()

    def send_recv_kvcache_all(self):
Is this method only used in testing? I cannot find any other references. If so, can you add some comments?
Hi all, I was wondering when this will be merged into the main branch?
Any update?
Same question: will the cluster-level scheduler be released in the vLLM repo or in another repo? @aashaka
I have the same issue.
Did you fix this problem? Maybe you can help me, please.
This seems to be a problem with MSCCL++; there are some similar issues in the mscclpp GitHub repo.
I currently have a question: the paper states that per-layer transfer can accelerate GPU memory release during the prompt phase, but I seem to be unable to find where this is implemented in the code. Could you please clarify whether the GPU memory is immediately released after completing the KV transfer for a certain layer, or is it retained until the end of the prompt phase?
This pull request has merge conflicts that must be resolved before it can be merged.
If there isn't any InfiniBand or NVLink on my machine, how can I use this technique to separate prefill and decode?
Is this still ongoing, given that we now have #10502?
This PR follows up on #2472 to implement the prompt and token stage parallelism introduced in Splitwise.
On enabling the `--sep-prompt-token` flag, the first half of the workers are assigned to process prompts and the second half perform token sampling. The KV-cache state is communicated over the network in a layer-wise manner as soon as it is ready on the prompt side. We use the MSCCL++ communication library to perform fast asynchronous KV-cache transfers over the IB fabric.

This PR makes the following changes:
Installation dependencies:
We use the MSCCL++ collective communication library for KV-cache transfers.
Please follow the instructions at MSCCL++ Quickstart or follow the steps below to install it from source:
Make sure that `$MSCCLPP_HOME` is set to the installation directory, or run `sudo make install`.
Tests:
This PR has been tested in the following scenarios.
Validating communication of KV cache:
Command used:
python tests/distributed/test_kvcache_comm.py
Result: Runs without assertion errors.
Without MSCCL++ environment, no stage parallelism:
Command used:
python examples/llm_engine_example_single.py --tensor-parallel-size 8 --model bigscience/bloom
Result: Runs like normal.
With stage parallelism:
Command used:
python examples/llm_engine_example_single.py --tensor-parallel-size 8 --model bigscience/bloom --sep-prompt-token
Result: Same output as before.
`llm_engine_example_single.py` is `llm_engine_example.py` with n=1 and deterministic SamplingParameters.

Known issues: