Skip to content

Commit

Permalink
[MISC] Add prefix cache hit rate to metrics (vllm-project#7606)
Browse files Browse the repository at this point in the history
  • Loading branch information
comaniac authored Aug 19, 2024
1 parent df845b2 commit 3ac50b4
Show file tree
Hide file tree
Showing 16 changed files with 200 additions and 16 deletions.
26 changes: 26 additions & 0 deletions tests/core/block/test_prefix_caching_block.py
Original file line number Diff line number Diff line change
Expand Up @@ -682,6 +682,32 @@ def test_eviction_order(num_blocks: int, block_size: int, seed: int):

assert new_block[0].block_id == last_block_id

# Test case for cache metrics
@staticmethod
def test_metric():
    """Check the hit rate the allocator reports as identical blocks repeat."""
    block_size = 16
    allocator = PrefixCachingBlockAllocator(num_blocks=4,
                                            block_size=block_size)
    token_ids = list(range(block_size))

    def allocate_same_block():
        # Every call requests an immutable block with the same content.
        allocator.allocate_immutable_block(prev_block=None,
                                           token_ids=token_ids)

    # Nothing has been queried yet: 0/0 is reported as 0.0.
    assert allocator.get_prefix_cache_hit_rate() == 0.0

    # First allocation: 0 hits out of 1 query.
    allocate_same_block()
    assert allocator.get_prefix_cache_hit_rate() == 0.0

    # Second allocation of the same content: 1 hit out of 2 queries.
    allocate_same_block()
    assert allocator.get_prefix_cache_hit_rate() == 0.5

    # Keep allocating so the query count spans more than one metric block.
    for _ in range(2, 1005):
        allocate_same_block()
    assert allocator.get_prefix_cache_hit_rate() > 0.99

@staticmethod
def create_immutable_chain(
block_size: int,
Expand Down
7 changes: 7 additions & 0 deletions tests/prefix_caching/test_prefix_caching.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,9 @@ def test_block_allocator(
assert (first_block == second_block)
assert (second_block.ref_count == 2)

# Check metric: 1 hit of 2 queries
assert block_allocator.get_prefix_cache_hit_rate() == 0.5

# Free the first_block and confirm that the ref_count is correctly
# decremented on the second block
block_allocator.free(first_block)
Expand All @@ -48,6 +51,10 @@ def test_block_allocator(
assert (first_block == second_block)
assert (first_block.block_hash == block_hash)

# Allocate one more time to get 3/4 hit rate for easy checking
block_allocator.allocate(block_hash, 0)
assert block_allocator.get_prefix_cache_hit_rate() == 0.75


@pytest.mark.parametrize("num_blocks", [16])
def test_eviction(num_blocks: int, ):
Expand Down
53 changes: 53 additions & 0 deletions vllm/core/block/common.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
from collections import deque
from dataclasses import dataclass
from typing import Deque, Dict, Iterable, List, Optional, Protocol, Tuple

from vllm.core.block.interfaces import Block, BlockAllocator
Expand Down Expand Up @@ -282,6 +283,58 @@ def ids(self) -> List[int]:
return self._block_ids


@dataclass
class CacheMetricData:
    """A utility dataclass to maintain a running cache hit rate.

    Queries are grouped into blocks of ``block_size``. Only the mean hit
    rate of the completed blocks is stored, together with the raw counters
    of the block that is currently being filled, so the per-query counters
    never grow beyond ``block_size``:

        BS = number of queries per block (``block_size``).
        nB = number of completed blocks.
        HR = mean hit rate over the (nB x BS) completed queries.
        Q  = number of queries in the current block (< BS).
        H  = number of hits in the current block (<= Q).

        hit rate = (HR x nB x BS + H) / (nB x BS + Q)

    This equals ``((HR x nB) + (H / Q) x (Q / BS)) / (nB + Q / BS)`` with
    numerator and denominator multiplied by BS. Writing it this way avoids
    the intermediate ``H / Q`` and ``Q / BS`` roundings, so before the first
    block completes the result is exactly ``H / Q``.

    Raises:
        ValueError: if ``block_size`` is not a positive integer.
    """
    num_completed_blocks: int = 0
    completed_block_cache_hit_rate: float = 0.0
    num_incompleted_block_queries: int = 0
    num_incompleted_block_hit: int = 0
    block_size: int = 1000

    def __post_init__(self) -> None:
        # A non-positive block size can never be completed and would make
        # the hit rate computation divide by zero.
        if self.block_size <= 0:
            raise ValueError(
                f"block_size must be positive, got {self.block_size}")

    def query(self, hit: bool) -> None:
        """Record one cache lookup.

        Args:
            hit: Whether the lookup was served from the cache.
        """
        self.num_incompleted_block_queries += 1
        if hit:
            self.num_incompleted_block_hit += 1

        # When a block is completed, fold its hit rate into the running
        # mean of completed blocks and reset the per-block counters.
        if self.num_incompleted_block_queries == self.block_size:
            block_hit_rate = (self.num_incompleted_block_hit /
                              self.num_incompleted_block_queries)
            self.completed_block_cache_hit_rate = (
                self.completed_block_cache_hit_rate *
                self.num_completed_blocks +
                block_hit_rate) / (self.num_completed_blocks + 1)
            self.num_incompleted_block_queries = 0
            self.num_incompleted_block_hit = 0
            self.num_completed_blocks += 1

    def get_hit_rate(self) -> float:
        """Return the hit rate over all recorded queries.

        Returns:
            A value in [0, 1]; 0.0 when no query has been recorded yet.
        """
        total_queries = (self.num_completed_blocks * self.block_size +
                         self.num_incompleted_block_queries)
        if total_queries == 0:
            return 0.0

        # Hits represented by the completed blocks (a float, since only
        # their mean hit rate is stored).
        completed_hits = (self.completed_block_cache_hit_rate *
                          self.num_completed_blocks * self.block_size)
        return (completed_hits +
                self.num_incompleted_block_hit) / total_queries


def get_all_blocks_recursively(last_block: Block) -> List[Block]:
"""Retrieves all the blocks in a sequence starting from the last block.
Expand Down
5 changes: 5 additions & 0 deletions vllm/core/block/cpu_gpu_block_allocator.py
Original file line number Diff line number Diff line change
Expand Up @@ -323,6 +323,11 @@ def get_common_computed_block_ids(
def all_block_ids(self) -> FrozenSet[int]:
return frozenset(self._block_ids_to_allocator.keys())

def get_prefix_cache_hit_rate(self, device: Device) -> float:
    """Return the prefix cache hit rate of the allocator for ``device``.

    A value of -1 means the metric is not supported or is disabled.
    """
    assert device in self._allocators
    device_allocator = self._allocators[device]
    return device_allocator.get_prefix_cache_hit_rate()

def get_and_reset_swaps(self) -> List[Tuple[int, int]]:
"""Returns and clears the mapping of source to destination block IDs.
Will be called after every swapping operations for now, and after every
Expand Down
10 changes: 10 additions & 0 deletions vllm/core/block/interfaces.py
Original file line number Diff line number Diff line change
Expand Up @@ -186,6 +186,11 @@ def get_num_blocks_touched(self,
num_lookahead_slots: int = 0) -> int:
pass

@abstractmethod
def get_prefix_cache_hit_rate(self) -> float:
    """Prefix cache hit rate. -1 means not supported or disabled.

    Returns:
        The hit rate as a float, or -1 when the allocator does not
        support the metric or has it disabled.
    """
    pass

# ValueError subclass with no extra behavior. Presumably raised when an
# allocation finds no free block; the raise sites are not visible here.
class NoFreeBlocksError(ValueError):
    pass

Expand Down Expand Up @@ -278,3 +283,8 @@ def allocate_or_get_null_block(self) -> Block:
There is at most one null block per allocator.
"""
pass

@abstractmethod
def get_prefix_cache_hit_rate(self, device: Device) -> float:
    """Prefix cache hit rate. -1 means not supported or disabled.

    Args:
        device: The device whose hit rate is requested.

    Returns:
        The hit rate as a float, or -1 when the metric is not supported
        or is disabled.
    """
    pass
3 changes: 3 additions & 0 deletions vllm/core/block/naive_block.py
Original file line number Diff line number Diff line change
Expand Up @@ -341,6 +341,9 @@ def swap_in(self, blocks: List[Block]) -> None:

block.block_id = block_id # Assign block_id

def get_prefix_cache_hit_rate(self) -> float:
    # -1 is the value the allocator interface documents as
    # "not supported or disabled".
    return -1


class NaiveBlock(Block):
"""An implementation of the Block class that does not support prefix
Expand Down
10 changes: 8 additions & 2 deletions vllm/core/block/prefix_caching_block.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,8 @@
"""Token blocks."""

from os.path import commonprefix
from typing import Dict, FrozenSet, Iterable, List, Optional, Tuple

from vllm.core.block.common import (CopyOnWriteTracker,
from vllm.core.block.common import (CacheMetricData, CopyOnWriteTracker,
get_all_blocks_recursively)
from vllm.core.block.interfaces import Block, BlockAllocator, BlockId, Device
from vllm.core.block.naive_block import (BlockPool, NaiveBlock,
Expand Down Expand Up @@ -107,6 +106,8 @@ def __init__(
self._cow_tracker = CopyOnWriteTracker(
refcounter=self._refcounter.as_readonly())

self.metric_data = CacheMetricData()

# Implements Block.Factory.
def _create_block(
self,
Expand Down Expand Up @@ -155,9 +156,11 @@ def allocate_immutable_block(self,

cached_block_id = self._cached_blocks.get(block.content_hash, None)
if cached_block_id is not None:
self.metric_data.query(hit=True)
block.block_id = cached_block_id
self._incr_refcount_cached_block(block)
return block
self.metric_data.query(hit=False)
self._block_pool.free_block(block)

# No cached block => Allocate a new block
Expand Down Expand Up @@ -404,6 +407,9 @@ def get_physical_block_id(self, absolute_id: int) -> int:
def all_block_ids(self) -> FrozenSet[int]:
return self._hashless_allocator.all_block_ids

def get_prefix_cache_hit_rate(self) -> float:
    """Return the hit rate currently reported by ``self.metric_data``."""
    metrics = self.metric_data
    return metrics.get_hit_rate()

def is_block_cached(self, block: Block) -> bool:
assert block.content_hash is not None
if block.content_hash in self._cached_blocks:
Expand Down
31 changes: 27 additions & 4 deletions vllm/core/block_manager_v1.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
from typing import Set, Tuple

from vllm.block import BlockTable, PhysicalTokenBlock
from vllm.core.block.common import CacheMetricData
from vllm.core.block.utils import check_no_caching_or_swa_for_blockmgr_encdec
from vllm.core.evictor_v1 import EvictionPolicy, Evictor, make_evictor
from vllm.core.interfaces import AllocStatus, BlockSpaceManager
Expand Down Expand Up @@ -60,6 +61,11 @@ def contains_block(self, block_hash: int) -> bool:
def update_hash(self, block_hash: int, block: PhysicalTokenBlock):
pass

@abstractmethod
def get_prefix_cache_hit_rate(self) -> float:
    """Prefix cache hit rate. -1 means not supported or disabled.

    Returns:
        The hit rate as a float, or -1 when the allocator does not
        support the metric or has it disabled.
    """
    pass


class CachedBlockAllocator(BlockAllocatorBase):
"""Manages free physical token blocks for a device.
Expand All @@ -85,6 +91,8 @@ def __init__(self,

self.default_hash_ctr = count()

self.cache_metric_data = CacheMetricData()

def allocate_block(self, block_hash: int,
num_hashed_tokens: int) -> PhysicalTokenBlock:
if self.current_num_blocks == self.num_blocks:
Expand All @@ -105,15 +113,17 @@ def allocate(self,
num_hashed_tokens: int = 0) -> PhysicalTokenBlock:
if block_hash is None:
block_hash = next(self.default_hash_ctr)

if block_hash in self.evictor:
assert block_hash not in self.cached_blocks
block = self.evictor.remove(block_hash)
assert block.ref_count == 0
self.cached_blocks[block_hash] = block
block.ref_count += 1
assert block.block_hash == block_hash
return block
if block_hash not in self.cached_blocks:

if block_hash in self.cached_blocks:
self.cache_metric_data.query(hit=True)
else:
self.cache_metric_data.query(hit=False)
self.cached_blocks[block_hash] = self.allocate_block(
block_hash, num_hashed_tokens)
block = self.cached_blocks[block_hash]
Expand Down Expand Up @@ -150,6 +160,9 @@ def update_hash(self, block_hash: int, block: PhysicalTokenBlock):
del self.cached_blocks[old_hash]
self.cached_blocks[block_hash] = block

def get_prefix_cache_hit_rate(self) -> float:
    """Return the hit rate currently reported by ``self.cache_metric_data``."""
    metrics = self.cache_metric_data
    return metrics.get_hit_rate()


class UncachedBlockAllocator(BlockAllocatorBase):
"""Manages free physical token blocks for a device.
Expand Down Expand Up @@ -209,6 +222,9 @@ def update_hash(self, block_hash: int, block: PhysicalTokenBlock):
raise NotImplementedError(
"Invalid codepath for uncached block allocator.")

def get_prefix_cache_hit_rate(self) -> float:
    # -1 is the value the allocator base class documents as
    # "not supported or disabled".
    return -1


class BlockSpaceManagerV1(BlockSpaceManager):
"""Manages the mapping between logical and physical token blocks."""
Expand Down Expand Up @@ -705,3 +721,10 @@ def mark_blocks_as_computed(self, seq_group: SequenceGroup):
if self.enable_caching:
for seq in seq_group.get_seqs():
self.compute_full_blocks_in_seq(seq)

def get_prefix_cache_hit_rate(self, device: Device) -> float:
    """Return the prefix cache hit rate of the allocator for ``device``.

    Raises:
        ValueError: if ``device`` is neither ``Device.GPU`` nor
            ``Device.CPU``.
    """
    if device == Device.GPU:
        allocator = self.gpu_allocator
    elif device == Device.CPU:
        allocator = self.cpu_allocator
    else:
        raise ValueError(f"Invalid device: {device}")
    return allocator.get_prefix_cache_hit_rate()
3 changes: 3 additions & 0 deletions vllm/core/block_manager_v2.py
Original file line number Diff line number Diff line change
Expand Up @@ -441,6 +441,9 @@ def get_num_free_gpu_blocks(self) -> int:
def get_num_free_cpu_blocks(self) -> int:
    """Return the free block count the block allocator reports for the CPU."""
    allocator = self.block_allocator
    return allocator.get_num_free_blocks(Device.CPU)

def get_prefix_cache_hit_rate(self, device: Device) -> float:
    """Return the block allocator's prefix cache hit rate for ``device``."""
    allocator = self.block_allocator
    return allocator.get_prefix_cache_hit_rate(device)

def _can_swap(self,
seq_group: SequenceGroup,
device: Device,
Expand Down
4 changes: 4 additions & 0 deletions vllm/core/embedding_model_block_manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@

from vllm.core.interfaces import AllocStatus, BlockSpaceManager
from vllm.sequence import Sequence, SequenceGroup
from vllm.utils import Device


class EmbeddingModelBlockSpaceManager(BlockSpaceManager):
Expand Down Expand Up @@ -81,3 +82,6 @@ def get_common_computed_block_ids(self,

def mark_blocks_as_computed(self, seq_group: SequenceGroup):
    # Intentionally a no-op in this block manager: seq_group is ignored.
    pass

def get_prefix_cache_hit_rate(self, device: Device) -> float:
    # Always -1 regardless of device; the block space manager interface
    # documents -1 as "not supported or disabled".
    return -1
15 changes: 8 additions & 7 deletions vllm/core/evictor_v2.py
Original file line number Diff line number Diff line change
Expand Up @@ -85,19 +85,21 @@ def evict(self) -> Tuple[int, int]:
if len(self.free_table) == 0:
raise ValueError("No usable cache memory left")

evicted_block = next(iter(self.free_table.values()))
evicted_block_id = next(iter(self.free_table.keys()))
evicted_block, evicted_block_id = None, None
# The blocks with the lowest timestamps should be placed consecutively
# at the start of OrderedDict. Loop through all these blocks to
# find the one with maximum number of hashed tokens.
for _id, block in self.free_table.items():
if evicted_block is None:
evicted_block, evicted_block_id = block, _id
continue
if evicted_block.last_accessed < block.last_accessed:
break
if (evicted_block.last_accessed == block.last_accessed and
evicted_block.num_hashed_tokens < block.num_hashed_tokens):
evicted_block = block
evicted_block_id = _id
if evicted_block.num_hashed_tokens < block.num_hashed_tokens:
evicted_block, evicted_block_id = block, _id

assert evicted_block is not None
assert evicted_block_id is not None
self.free_table.pop(evicted_block_id)

return evicted_block_id, evicted_block.content_hash
Expand All @@ -110,7 +112,6 @@ def add(self, block_id: int, content_hash: int, num_hashed_tokens: int,

def update(self, block_id: int, last_accessed: float):
self.free_table[block_id].last_accessed = last_accessed
self.free_table.move_to_end(block_id)

def remove(self, block_id: int):
if block_id not in self.free_table:
Expand Down
6 changes: 6 additions & 0 deletions vllm/core/interfaces.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
from typing import Tuple

from vllm.sequence import Sequence, SequenceGroup
from vllm.utils import Device


class AllocStatus(enum.Enum):
Expand Down Expand Up @@ -116,3 +117,8 @@ def get_common_computed_block_ids(
@abstractmethod
def mark_blocks_as_computed(self, seq_group: SequenceGroup):
    # Abstract hook: implementations decide what marking the blocks of
    # seq_group as computed means (some visible implementations do nothing).
    pass

@abstractmethod
def get_prefix_cache_hit_rate(self, device: Device) -> float:
    """Prefix cache hit rate. -1 means not supported or disabled.

    Args:
        device: The device whose hit rate is requested.

    Returns:
        The hit rate as a float, or -1 when the metric is not supported
        or is disabled.
    """
    pass
5 changes: 4 additions & 1 deletion vllm/core/scheduler.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@
from vllm.sequence import (Sequence, SequenceData, SequenceGroup,
SequenceGroupMetadata, SequenceGroupMetadataDelta,
SequenceStatus)
from vllm.utils import PyObjectCache
from vllm.utils import Device, PyObjectCache

logger = init_logger(__name__)

Expand Down Expand Up @@ -447,6 +447,9 @@ def has_unfinished_seqs(self) -> bool:
return len(self.waiting) != 0 or len(self.running) != 0 or len(
self.swapped) != 0

def get_prefix_cache_hit_rate(self, device: Device) -> float:
    """Return the block manager's prefix cache hit rate for ``device``."""
    manager = self.block_manager
    return manager.get_prefix_cache_hit_rate(device)

def get_num_unfinished_seq_groups(self) -> int:
    """Return the combined length of the waiting, running and swapped queues."""
    queues = (self.waiting, self.running, self.swapped)
    return sum(len(queue) for queue in queues)

Expand Down
Loading

0 comments on commit 3ac50b4

Please sign in to comment.