[Prefill with Prefix Cache] Improve the efficiency of prefilling with prefix cache by allowing a larger batch size #3402

Open

MeloYang05 wants to merge 6 commits into main

Conversation

@MeloYang05 (Contributor) commented Mar 14, 2024

Hello everyone, I am testing the efficiency of prefilling with prefix caching enabled to explore potential opportunities to apply chunked prefill or even SplitFuse techniques in vLLM. However, the performance of prefilling with prefix cache is currently not satisfactory, mainly for the following two reasons:

  1. Small batch size. In scheduler.py, when determining whether a SequenceGroup can be added to the running queue, the can_allocate() method of block_manager is not aware of the SequenceGroup's cached blocks. For example, assume there are 10 free blocks and the SequenceGroup requires 12 blocks, but 8 of those blocks are cached. In this scenario, the SequenceGroup should be allowed into the running queue because it really only needs 4 extra blocks. However, the current implementation doesn't check the cached blocks, so there is no difference between a prefix prefill and a normal prefill. In theory, prefix prefill should allow a much larger batch size than normal prefill and thereby increase overall throughput. Moreover, in the later check of num_batched_tokens, the current implementation also doesn't subtract the number of already computed tokens, which further reduces the batch size (see the sketch after this list).
  2. Repeated incremental_detokenize. I also find that even when I fix the batch size to 1, the latency of a prefix prefill with 1 extra token (e.g., prompt length 257 with 256 tokens cached) is still far behind decoding 1 token. The gap mainly results from the different execution time of incremental_detokenize for decoding versus prefix prefilling. For decoding, since seq.tokens has already been computed, it only needs to convert the newly generated token id. In prefix prefill, however, seq.tokens is None, so it has to detokenize all the token ids, which is a huge gap. Additionally, I find that incremental_detokenize actually detokenizes all the token ids more than once during prefilling. In the _decode_sequence method in llm_engine.py, the step self._decode_logprobs(...) calls incremental_detokenize many times, and each of these calls detokenizes all the token ids again if seq.tokens is None, which I think is totally unnecessary. This implementation leads to performance degradation, especially when the prompt is long.
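
As a quick illustration of point 1, here is a minimal sketch using the hypothetical numbers from the example above (illustrative arithmetic only, not the actual scheduler code):

    # Hypothetical numbers from the example in point 1.
    num_free_blocks = 10       # free GPU blocks in the block manager
    num_required_blocks = 12   # logical blocks needed by the waiting sequence
    num_cached_blocks = 8      # blocks already stored in the prefix cache

    # Current check: the cache is ignored, so the sequence is rejected.
    can_admit_old = num_required_blocks <= num_free_blocks                        # False

    # Cache-aware check: only the uncached blocks need a fresh allocation.
    can_admit_new = (num_required_blocks - num_cached_blocks) <= num_free_blocks  # True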

In this PR, I primarily address the problem of enabling a larger batch size for prefix prefilling. The issue of repeated incremental_detokenize calls will be resolved in a separate PR by @Yard1, who will approach the fix in a more systematic way.

How I Fix the Problems

Enable Larger Batch Size for Prefilling with Prefix Cache

To enable a larger batch size, we have to make the following two kinds of modifications:

  1. Be aware of the cached blocks when determining whether there are enough free GPU blocks that can be allocated for the sequence.
  2. Be aware of the computed blocks when checking whether the currently batched tokens exceed the system's limitation.

Being Aware of the Cached Blocks

Cached blocks are the blocks stored in the block manager's cached_blocks dict, where the hash value of the block is the key.
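
For orientation, the mapping has roughly the following shape (an illustrative sketch, not the exact upstream declaration):

    from typing import Dict

    # Illustrative sketch: in the GPU allocator, the key is the content hash
    # of a full logical block (covering all token ids of the prefix up to and
    # including that block, plus the LoRA id), and the value is the physical
    # block that holds it.
    cached_blocks: Dict[int, "PhysicalTokenBlock"] = {}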

In block_manager.py, I add a method to get the number of cached blocks:

    def get_num_cached_blocks(self, seq: Sequence) -> int:
        # NOTE: cached blocks of a sequence means some logical blocks
        # of the sequence map to the physical blocks which already
        # have been stored in block manager's 'cached_blocks' dict.
        # These cached blocks don't need to be allocated again,
        # reducing the memory requirement during sequence allocation.
        if not self.enable_caching:
            return 0
        num_cached_blocks = 0
        for logical_idx in range(len(seq.logical_token_blocks)):
            block_hash = seq.hash_of_block(logical_idx)
            if block_hash in self.cached_blocks:
                num_cached_blocks += 1
        return num_cached_blocks

And in the can_allocate() method, I add the following lines to exclude the number of cached blocks:

        # Check the number of blocks which have been cached
        num_cached_blocks = self.get_num_cached_blocks(seq)
        num_required_blocks -= num_cached_blocks
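
For context, here is a hedged sketch of how these lines sit inside can_allocate(); the surrounding structure is paraphrased from the upstream method and may differ slightly from the exact code at this commit:

    def can_allocate(self, seq_group: SequenceGroup) -> AllocStatus:
        seq = seq_group.get_seqs(status=SequenceStatus.WAITING)[0]
        num_required_blocks = len(seq.logical_token_blocks)

        # Check the number of blocks which have been cached; those blocks
        # do not need to be allocated again.
        num_cached_blocks = self.get_num_cached_blocks(seq)
        num_required_blocks -= num_cached_blocks

        num_free_gpu_blocks = self.gpu_allocator.get_num_free_blocks()
        # Use the watermark to avoid frequent cache eviction.
        if (self.num_total_gpu_blocks - num_required_blocks <
                self.watermark_blocks):
            return AllocStatus.NEVER
        if num_free_gpu_blocks - num_required_blocks >= self.watermark_blocks:
            return AllocStatus.OK
        else:
            return AllocStatus.LATER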

Being Aware of the Computed Blocks

The computed blocks of a sequence should satisfy the following conditions:

  1. They are stored in either cached_blocks or the evictor.
  2. block.computed is true.
  3. The blocks are contiguous, since the prefix cache is contiguous. E.g., even if blocks (0, 1, 2, 4) are in cached_blocks and marked as computed, block 4 is not contiguous with (0, 1, 2), so the computed blocks are only (0, 1, 2).
  4. The computed blocks start at the very beginning of a sequence's logical blocks.

Moreover, to align with #3239, the last block of a sequence is never regarded as a computed block.

Therefore, I add the following method in block_manager.py to get the number of computed blocks of a sequence before it is allocated:

    def get_num_computed_blocks(self, seq: Sequence) -> int:
        # NOTE: computed blocks of a sequence means some logical blocks
        # of the sequence map to the physical blocks which:
        # 1. are stored in either 'cached_blocks' or the 'evictor'
        # 2. have 'block.computed' set to true
        # 3. are contiguous, as the prefix cache is contiguous,
        #    e.g. even if blocks (0, 1, 2, 4) are in `cached_blocks` and
        #    marked as computed, block 4 is not contiguous with (0, 1, 2),
        #    so the computed blocks are only (0, 1, 2)
        # 4. start at the very beginning of the sequence's logical blocks
        # Only the blocks that satisfy all the above conditions are computed
        # blocks, which can be treated as prefix cache during prefilling.
        if not self.enable_caching:
            return 0
        num_computed_blocks = 0
        # Align with https://github.com/vllm-project/vllm/pull/3239
        # The last logical block always needs to be computed
        for logical_idx in range(len(seq.logical_token_blocks) - 1):
            block_hash = seq.hash_of_block(logical_idx)
            if block_hash in self.cached_blocks:
                if self.cached_blocks[block_hash].computed:
                    num_computed_blocks += 1
                # Not computed: violates the "contiguous" requirement, break
                else:
                    break
            elif block_hash in self.evictor:
                # Remove and re-add the block to refresh its position in
                # the evictor: if we are checking whether it is computed,
                # the block is hot and its eviction should be deferred.
                block = self.evictor.remove(block_hash)
                self.evictor.add(block)
                if block.computed:
                    num_computed_blocks += 1
                # Not computed: violates the "contiguous" requirement, break
                else:
                    break
            # Not in 'cached_blocks' or the 'evictor', so it cannot be a
            # computed block; this also breaks the "contiguous" requirement
            else:
                break
        return num_computed_blocks

I believe the logic of my method aligns with that of the get_all_computed_blocks method, which will be invoked following the allocation of the sequence.

    def get_all_computed_blocks(self, seq: Sequence) -> List[int]:
        if seq.seq_id not in self.block_tables:
            return []
        block_table = self.block_tables[seq.seq_id]
        # NOTE We exclude the last block to avoid the case where the entire
        # prompt is cached. This would cause erroneous behavior in model
        # runner.
        return [
            b.block_number
            for b in takewhile(lambda b: b.computed, block_table[:-1])
        ]

Additionally, I add a small helper in block_manager.py to convert the number of computed blocks into a number of computed tokens:

    def get_num_computed_tokens(self, seq: Sequence) -> int:
        # The last block has been excluded in method `get_num_computed_blocks`,
        # so this will not overestimate the computed tokens.
        return self.block_size * self.get_num_computed_blocks(seq)

Then, in the _schedule() method in scheduler.py, I exclude the computed tokens when checking the number of batched tokens:

                # If the number of batched tokens exceeds the limit, stop.
                num_computed_tokens = self.block_manager \
                    .get_num_computed_tokens(waiting_seq)
                # Exclude the computed tokens
                new_seq_lens = seq_lens + [
                    num_prompt_tokens - num_computed_tokens
                ]
                num_batched_tokens = len(new_seq_lens) * max(new_seq_lens)
                if (num_batched_tokens >
                        self.scheduler_config.max_num_batched_tokens):
                    break
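
To see the effect on batch size, consider an illustrative calculation (the token budget of 4096 below is an assumption for illustration, not a value taken from this PR; other limits such as max_num_seqs and the number of free GPU blocks still apply):

    # Illustrative arithmetic only; the budget is assumed.
    max_num_batched_tokens = 4096
    num_prompt_tokens = 1025
    num_computed_tokens = 64 * 16   # 64 full blocks of size 16 are cached

    # Old check: every prompt counts in full, so at most
    # 4096 // 1025 = 3 sequences fit in one prefill batch.
    old_limit = max_num_batched_tokens // num_prompt_tokens

    # New check: only the uncached suffix (1 token here) is counted, so the
    # token budget no longer caps the batch at a handful of sequences.
    new_limit = max_num_batched_tokens // (num_prompt_tokens - num_computed_tokens)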

Other Optimizations

Moreover, since the hash value of logical blocks (seq.hash_of_block(logical_idx)) is now computed more frequently, to improve efficiency I store the hash value of a logical block when it is first computed (note that only a full logical block can store its hash value).

class LogicalTokenBlock:
    """A block that stores a contiguous chunk of tokens from left to right.

    Logical blocks are used to represent the states of the corresponding
    physical blocks in the KV cache.
    """

    def __init__(
        self,
        block_number: int,
        block_size: int,
    ) -> None:
        self.block_number = block_number
        self.block_size = block_size

        self.token_ids = [_BLANK_TOKEN_ID] * block_size
        self.num_tokens = 0
        self.block_hash: Optional[int] = None  # Store the hash value

And in the Sequence class (also in sequence.py), hash_of_block() reuses the stored value when it is available:

    def hash_of_block(self, logical_idx: int) -> int:
        # Compute the number of tokens in the sequence
        # TODO: The current hashing function is O(L^2). We should optimize
        # this in the future.
        logical_block = self.logical_token_blocks[logical_idx]
        hash_value = logical_block.block_hash
        if hash_value is None:
            num_tokens = self.num_hashed_tokens_of_block(logical_idx)
            hash_value = hash((tuple(self.data.get_token_ids()[0:num_tokens]),
                               self.lora_int_id))
            # Only when the logical block is full, store the hash value
            if logical_block.is_full():
                logical_block.block_hash = hash_value
        return hash_value

There is also a small modification in the compute_full_blocks_in_seq method in block_manager.py:

        max_full_block = seq.get_len() // self.block_size
        # If the last block is not full, then we need to reduce
        # the max full block number by 1
        if not self._is_last_block_full(seq):
            max_full_block -= 1

In the previous implementation, we had max_full_block = seq.get_len() // self.block_size - 1. However, I believe there is no need to subtract 1 from max_full_block if the last block is full. This is because the last block will be excluded by block_table[:-1] in the expression for b in takewhile(lambda b: b.computed, block_table[:-1]) within the get_all_computed_blocks method. Given that the last block is full, it can also be marked as computed.
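
For reference, here is a hedged sketch of the full method with this change applied (the trailing loop is paraphrased from the upstream implementation and may differ in detail):

    def compute_full_blocks_in_seq(self, seq: Sequence) -> None:
        if seq.seq_id not in self.block_tables:
            return
        max_full_block = seq.get_len() // self.block_size
        # If the last block is not full, then we need to reduce
        # the max full block number by 1
        if not self._is_last_block_full(seq):
            max_full_block -= 1
        if max_full_block <= 0:
            return
        block_table = self.block_tables[seq.seq_id]
        # Walk backwards and stop at the first block already marked computed;
        # every block before it must have been marked computed earlier.
        for i in reversed(range(max_full_block)):
            if block_table[i].computed:
                break
            block_table[i].computed = True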

Performance Improvement

I wrote the following script to test the performance gain from my modifications.

from vllm import LLM, SamplingParams
from typing import List
import time


def build_batched_token_ids(token_ids: List[int],
                            batch_size: int) -> List[List[int]]:
    assert batch_size >= 1
    batched_token_ids: List[List[int]] = []
    for i in range(batch_size):
        batched_token_ids.append(token_ids[i:] + token_ids[:i])
    return batched_token_ids


prefix_len = 1025
requests_num = 48
block_size = 16

llm = LLM(model="/path/to/Llama-2-7B-FP16",
          enable_prefix_caching=True,
          block_size=block_size)

prompt1_token_ids = list(range(prefix_len))
batched_prompt1_token_ids = build_batched_token_ids(prompt1_token_ids,
                                                    requests_num)

sampling_params = SamplingParams(temperature=0.0,
                                 max_tokens=1,
                                 ignore_eos=True)

# Prepare the prefix cache
llm.generate(sampling_params=sampling_params,
             prompt_token_ids=batched_prompt1_token_ids,
             use_tqdm=False)

# Prefill with the prefix cache
print("Start!")
start_time = time.time()
llm.generate(sampling_params=sampling_params,
             prompt_token_ids=batched_prompt1_token_ids,
             use_tqdm=False)
end_time = time.time()
duration = end_time - start_time
print(f"Elapsed: {duration}s")

I use the Llama-2-7B-FP16 model, and my GPU is an A800. Before the modifications, the maximum batch size for prefix prefilling is just 3, and the entire prefix prefilling process takes about 0.93 s. After the changes, the batch size reaches 48 and the prefix prefilling time is shortened to about 0.21 s (with the repeated detokenization also fixed in a simple way), which is more than a 4x speedup. I hope these changes help prepare vLLM for techniques that rely on prefix prefilling.

@rkooo567 (Collaborator) left a comment:

Can you add relevant tests to core/test_scheduler.py and core/test_block_manager.py?

@@ -98,6 +98,41 @@ def get_num_free_blocks(self) -> int:
        return (self.num_blocks - self.current_num_blocks +
                self.evictor.num_blocks)

    def get_num_cached_blocks(self, seq: Sequence) -> int:
Collaborator:

Can you add a docstring to indicate the difference between cached vs. computed blocks?

@MeloYang05 (Contributor, Author):

Hi, thank you for your review! Since it is very late in my time zone, I plan to add unit tests and more docstrings tomorrow. I have reviewed my logic and think it should be correct, but it is kind of tricky; I will add more comments to make it easier to understand. Thank you!

                num_computed_tokens = self.block_manager \
                    .get_num_computed_tokens(waiting_seq)
                new_seq_lens = seq_lens + [
                    num_prompt_tokens - num_computed_tokens
Collaborator:

Should this be max(num_prompt_tokens - num_computed_tokens, 0)?

Collaborator:

E.g., block size == 16, 4 blocks are cached (and 60 tokens are prefix); if the token size is 63, can it have a negative value?

        return self.gpu_allocator.get_num_computed_blocks(seq)

    def get_num_computed_tokens(self, seq: Sequence) -> int:
        return self.block_size * self.get_num_computed_blocks(seq)
Collaborator:

Q: Does this overestimate the number of computed tokens (as it is not guaranteed that all slots in the block have tokens allocated)? Can you write a docstring that explains the behavior here? (Either it can be overestimated, or if we have a certain assumption, it should be documented.)

                if self.cached_blocks[block_hash].computed:
                    num_computed_blocks += 1
                else:
                    break
Collaborator:

Could you comment on why it breaks? (I assume it relies on the assumption that blocks are computed in order?)

vllm/sequence.py (outdated)

@@ -289,7 +296,7 @@ class SequenceGroupState:
    """Mutable state tied to a specific sequence group"""

    # torch.Generator used in seeded sampling
-    generator: Optional = None
+    generator: Optional[torch.Generator] = None
Member:

I'd intentionally avoided the type here so that torch doesn't need to be imported, since this layer deals with higher level sequence abstractions.

@MeloYang05 (Contributor, Author):

Hi, I have removed this modification. I made it because Pylance in VS Code kept flagging this, which is kind of annoying. BTW, it's fine per your explanation, thank you!

@njhill (Member) commented Mar 14, 2024

@MeloYang05 this looks great, I think it would be good to split into two PRs though since the improvements are independent?

@MeloYang05 (Contributor, Author) replied:

OK thanks, I will split the PR into two independent ones.

@Yard1 (Collaborator) commented Mar 17, 2024

@MeloYang05 Thanks for pointing out the incremental detokenization issue. Do you mind if I submit a PR fixing it in a more comprehensive way? You will be added as a co-author.

@MeloYang05 (Contributor, Author) replied:

OK it's fine, I'm very glad to see it can be fixed in a more comprehensive way, thank you! I will only handle the batch size of prefix prefill in this PR.

@zhuohan123 self-assigned this on Mar 18, 2024
@MeloYang05 changed the title from "[Prefill with Prefix Cache] Improve the efficiency of prefilling with prefix cache by allowing a larger batch size and avoid redundant incremental detokenize operations" to "[Prefill with Prefix Cache] Improve the efficiency of prefilling with prefix cache by allowing a larger batch size" on Mar 19, 2024
@MeloYang05 force-pushed the chunk_batch branch 2 times, most recently from 107fd9a to 02a8a2e on March 19, 2024 02:53
        return self.gpu_allocator.get_num_computed_blocks(seq)

    def get_num_computed_tokens(self, seq: Sequence) -> int:
        # The last block has been excluded in method `get_num_computed_blocks`
Collaborator:

Nit:

    def get_num_computed_tokens(self, seq: Sequence) -> int:
        """Return the number of tokens that are already computed.

        NOTE: This excludes tokens from the last blocks.
        """

@MeloYang05 (Contributor, Author) commented:

Hi, I have made some edits to ensure my modifications are compatible with the latest updates regarding the block manager abstraction.

This pull request has been automatically marked as stale because it has not had any activity within 90 days. It will be automatically closed if no further activity occurs within 30 days. Leave a comment if you feel this pull request should remain open. Thank you!

@github-actions github-actions bot added the stale label Oct 29, 2024

mergify bot commented Oct 29, 2024

This pull request has merge conflicts that must be resolved before it can be
merged. @MeloYang05 please rebase it. https://docs.github.com/en/pull-requests/collaborating-with-pull-requests/working-with-forks/syncing-a-fork

@simon-mo requested a review from comaniac as a code owner on November 26, 2024 05:49