[Kernel][Core][WIP] Tree attention and parallel decoding #4325

Open · wants to merge 8 commits into base: main
Changes from 1 commit
fix code style
kavioyu committed Apr 24, 2024
commit af98a2718c945659572788f4134e8d5663ede596
84 changes: 36 additions & 48 deletions tests/kernels/test_tree_attention.py
@@ -18,7 +18,7 @@
MAX_SEQ_LEN = 2048
# There may not be enough gpu memory due to large NUM_BLOCKS.
# Reduce NUM_BLOCKS when it happens.
NUM_BLOCKS = 1024 # Arbitrary values for testing
NUM_BLOCKS = 1024 # Arbitrary values for testing
# only test on half and bfloat16
DTYPES = [torch.half, torch.bfloat16]
NUM_GEN_SEQS = [7] # Arbitrary values for testing
@@ -39,9 +39,13 @@
]
TREEWIDTH = [1, 7, 31]

def create_tree_attention_mask(context_len, prompt_len, tree_width, num_kv_head, dtype):
prompt_mask = torch.zeros((num_kv_head, tree_width, prompt_len), dtype=dtype)
none_mask_value = torch.arange(context_len-prompt_len).repeat(tree_width, 1) - torch.arange(tree_width)[:, None]

def create_tree_attention_mask(context_len, prompt_len, tree_width,
num_kv_head, dtype):
prompt_mask = torch.zeros((num_kv_head, tree_width, prompt_len),
dtype=dtype)
none_mask_value = torch.arange(context_len - prompt_len).repeat(
tree_width, 1) - torch.arange(tree_width)[:, None]
none_mask_value = none_mask_value % tree_width
none_mask_value = none_mask_value == 0

@@ -52,6 +56,7 @@ def create_tree_attention_mask(context_len, prompt_len, tree_width, num_kv_head,
generate_mask = generate_mask.unsqueeze(0).repeat(num_kv_head, 1, 1)
return torch.concat([prompt_mask, generate_mask], dim=2)


def ref_masked_attention(
query: torch.Tensor,
key: torch.Tensor,
@@ -68,19 +73,15 @@ def ref_masked_attention(
return out


def ref_query_cached_kv_attention(
output: torch.Tensor,
query: torch.Tensor,
num_queries_per_kv: int,
key_cache: torch.Tensor,
value_cache: torch.Tensor,
block_tables: torch.Tensor,
context_lens: torch.Tensor,
scale: float,
alibi_slopes: Optional[torch.Tensor],
prompt_lens: torch.Tensor,
tree_width: int
) -> None:
def ref_query_cached_kv_attention(output: torch.Tensor, query: torch.Tensor,
num_queries_per_kv: int,
key_cache: torch.Tensor,
value_cache: torch.Tensor,
block_tables: torch.Tensor,
context_lens: torch.Tensor, scale: float,
alibi_slopes: Optional[torch.Tensor],
prompt_lens: torch.Tensor,
tree_width: int) -> None:
num_query_heads = query.shape[1]
num_kv_heads = value_cache.shape[1]
head_size = value_cache.shape[2]
@@ -116,7 +117,11 @@ def ref_query_cached_kv_attention(
keys = torch.repeat_interleave(keys, num_queries_per_kv, dim=1)
values = torch.repeat_interleave(values, num_queries_per_kv, dim=1)

mask = create_tree_attention_mask(context_len, prompt_len, tree_width, num_query_heads, dtype=torch.float)
mask = create_tree_attention_mask(context_len,
prompt_len,
tree_width,
num_query_heads,
dtype=torch.float)
alibi_bias = None
if alibi_slopes is not None:
# Create the ALiBi bias used in the paged attention kernel.
@@ -161,7 +166,10 @@ def test_paged_attention(
torch.set_default_device(device)
scale = float(1.0 / (head_size**0.5))
num_query_heads, num_kv_heads = num_heads
query = torch.empty(num_seqs*tree_width, num_query_heads, head_size, dtype=dtype)
query = torch.empty(num_seqs * tree_width,
num_query_heads,
head_size,
dtype=dtype)
query.uniform_(-scale, scale)

assert num_query_heads % num_kv_heads == 0
@@ -206,49 +214,29 @@ def test_paged_attention(
# #value_cache = torch.ones_like(value_cache)
# key_cache = torch.ones_like(key_cache)
# query = torch.ones_like(query)

output = torch.empty_like(query)
torch.cuda.synchronize()
start_time = time.time()
tree_attention_fwd(
query,
output,
key_cache,
value_cache,
block_tables,
context_lens,
prompt_lens,
tree_width,
alibi_slopes
)
tree_attention_fwd(query, output, key_cache, value_cache, block_tables,
context_lens, prompt_lens, tree_width, alibi_slopes)
torch.cuda.synchronize()
#print("tree attention duration:", time.time()-start_time)


ref_output = torch.empty_like(query)
ref_query_cached_kv_attention(
ref_output,
query,
num_queries_per_kv,
key_cache,
value_cache,
block_tables,
context_lens,
scale,
alibi_slopes,
prompt_lens,
tree_width
)
ref_query_cached_kv_attention(ref_output, query, num_queries_per_kv,
key_cache, value_cache, block_tables,
context_lens, scale, alibi_slopes,
prompt_lens, tree_width)

# NOTE(woosuk): Due to the kernel-level differences in the two
# implementations, there is a small numerical difference in the two
# outputs. Thus, we use a relaxed tolerance for the test.
atol = 1e-4
rtol = 2e-2

def diff(a, b):
print(((a-b).abs()/(b+1e-8)).mean())
print(((a - b).abs() / (b + 1e-8)).mean())

diff(output, ref_output)
assert torch.allclose(output, ref_output, atol=atol, rtol=rtol)

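The reference mask above is the heart of the test: prompt positions are visible to every speculated token, while in the generated region a query row only attends to cache slots whose offset past the prompt matches its own row index modulo `tree_width`. A minimal standalone sketch of that pattern (illustrative only; it returns booleans rather than the additive float mask the test helper builds):

```python
import torch


def sketch_tree_mask(context_len: int, prompt_len: int,
                     tree_width: int) -> torch.Tensor:
    """Boolean (tree_width, context_len) mask; True = may attend."""
    # Prompt tokens are visible to every speculated token.
    prompt_part = torch.ones((tree_width, prompt_len), dtype=torch.bool)
    # Generated slot j is visible to query row i iff (j - i) % tree_width == 0,
    # i.e. each row only sees the tokens of its own tree "lane".
    cols = torch.arange(context_len - prompt_len)
    rows = torch.arange(tree_width)[:, None]
    gen_part = (cols - rows) % tree_width == 0
    return torch.cat([prompt_part, gen_part], dim=1)


# Example: 2 prompt tokens, tree width 3, 8 cached positions in total.
print(sketch_tree_mask(context_len=8, prompt_len=2, tree_width=3).int())
```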
5 changes: 3 additions & 2 deletions vllm/attention/ops/paged_attn.py
@@ -113,8 +113,9 @@ def forward_decode(
and (max_num_partitions == 1 or num_seqs * num_heads > 512))
if tree_width > 1:
tree_attention_fwd(query, output, key_cache, value_cache,
block_tables, context_lens, prompt_lens, tree_width, alibi_slopes)

block_tables, context_lens, prompt_lens,
tree_width, alibi_slopes)

elif use_v1:
# Run PagedAttention V1.
ops.paged_attention_v1(
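In short, decode now branches on `tree_width`: sequence groups carrying more than one speculative branch go through `tree_attention_fwd`, and everything else keeps the existing V1/V2 paged-attention selection. A condensed sketch of that decision (hypothetical helper; it mirrors only the part of the `use_v1` condition visible in this hunk):

```python
def select_decode_kernel(tree_width: int, max_num_partitions: int,
                         num_seqs: int, num_heads: int) -> str:
    """Sketch of forward_decode's dispatch, not the real signature."""
    # V1 skips the split-K reduction when a single partition suffices or the
    # batch already provides enough parallelism (condition abridged here).
    use_v1 = max_num_partitions == 1 or num_seqs * num_heads > 512
    if tree_width > 1:
        return "tree_attention_fwd"
    return "paged_attention_v1" if use_v1 else "paged_attention_v2"
```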
49 changes: 23 additions & 26 deletions vllm/attention/ops/tree_attn.py
@@ -50,7 +50,7 @@ def _fwd_kernel(

cur_batch_ctx_len = tl.load(B_Ctxlen + cur_batch)
cur_batch_in_all_start_index = cur_batch * tree_width
cur_batch_prompt_len = tl.load(prompt_lens+cur_batch)
cur_batch_prompt_len = tl.load(prompt_lens + cur_batch)

block_start_loc = BLOCK_M * start_m

@@ -62,10 +62,7 @@
(cur_batch_in_all_start_index + offs_m[:, None]) * stride_qbs +
cur_head * stride_qh + offs_d[None, :] * stride_qd)

q = tl.load(
Q + off_q,
mask=offs_m[:, None] < tree_width,
other=0.0)
q = tl.load(Q + off_q, mask=offs_m[:, None] < tree_width, other=0.0)

# # initialize pointer to m and l
m_i = tl.zeros([BLOCK_M], dtype=tl.float32) - float("inf")
@@ -96,21 +93,23 @@

qk = tl.zeros([BLOCK_M, BLOCK_N], dtype=tl.float32)
qk += tl.dot(q, k)

cur_step = start_n + offs_n[None, :] # [1, BlockN]
is_prompt = cur_step<cur_batch_prompt_len # [1, BlockN]
tree_mask = (cur_step - cur_batch_prompt_len - offs_m[:, None]) % tree_width == 0 # [1, BlockN] - [BlockM, 1] = [BlockM, BlockN]

cur_step = start_n + offs_n[None, :] # [1, BlockN]
is_prompt = cur_step < cur_batch_prompt_len # [1, BlockN]
tree_mask = (
cur_step - cur_batch_prompt_len - offs_m[:, None]
) % tree_width == 0 # [1, BlockN] - [BlockM, 1] = [BlockM, BlockN]
tree_mask = is_prompt or tree_mask
mask = tree_mask and (cur_step < cur_batch_ctx_len)

qk = tl.where(mask, qk, -3.4028234663852886e+38)

qk *= sm_scale

# -- compute m_ij, p, l_ij
m_ij = tl.max(qk, 1)
p = tl.exp(qk - m_ij[:, None])

l_ij = tl.sum(p, 1)
# -- update m_i and l_i
m_i_new = tl.maximum(m_i, m_ij)
@@ -125,8 +124,8 @@
acc_scale = l_i / l_i_new * alpha
acc = acc * acc_scale[:, None]
# update acc
cur_step = start_n + offs_n[:, None] # (BlockN, 1)

cur_step = start_n + offs_n[:, None] # (BlockN, 1)

v = tl.load(V_cache + off_v,
mask=(start_n + offs_n[:, None]) < cur_batch_ctx_len,
@@ -142,11 +141,9 @@
off_o = (
(cur_batch_in_all_start_index + offs_m[:, None]) * stride_obs +
cur_head * stride_oh + offs_d[None, :] * stride_od)

out_ptrs = Out + off_o
tl.store(out_ptrs,
acc,
mask=offs_m[:, None] < tree_width)
tl.store(out_ptrs, acc, mask=offs_m[:, None] < tree_width)
return

@triton.jit
@@ -371,14 +368,14 @@ def _fwd_kernel_alibi(

@torch.inference_mode()
def tree_attention_fwd(q,
o,
k_cache,
v_cache,
block_table,
context_len,
prompt_len,
tree_width,
alibi_slopes=None):
o,
k_cache,
v_cache,
block_table,
context_len,
prompt_len,
tree_width,
alibi_slopes=None):

cap = torch.cuda.get_device_capability()
BLOCK_N = 128 if cap[0] >= 8 else 64
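Apart from the tree mask, the kernel follows the usual streaming-softmax accumulation: per key/value tile it keeps a running max `m_i`, a running denominator `l_i`, and an accumulator `acc`, rescaling both when a new tile raises the max (the `alpha` / `acc_scale` lines in the hunk above). A generic PyTorch sketch of that update, with a quick check against a full softmax (names and tile sizes here are illustrative, not the kernel's):

```python
import torch


def online_softmax_update(m_i, l_i, acc, qk, v):
    """Fold one (BLOCK_M, BLOCK_N) tile of scores into the running softmax."""
    m_ij = qk.max(dim=1).values                  # tile-local max
    p = torch.exp(qk - m_ij[:, None])
    l_ij = p.sum(dim=1)
    m_i_new = torch.maximum(m_i, m_ij)           # new running max
    alpha = torch.exp(m_i - m_i_new)             # rescales the old statistics
    beta = torch.exp(m_ij - m_i_new)             # rescales the new tile
    l_i_new = alpha * l_i + beta * l_ij
    acc = (acc * (l_i / l_i_new * alpha)[:, None]
           + (beta / l_i_new)[:, None] * (p @ v))
    return m_i_new, l_i_new, acc


torch.manual_seed(0)
scores = torch.randn(4, 128)                     # (query rows, key positions)
values = torch.randn(128, 8)                     # (key positions, head_size)
m_i = torch.full((4,), float("-inf"))
l_i = torch.zeros(4)
acc = torch.zeros(4, 8)
for start in range(0, 128, 32):                  # walk the keys in tiles of 32
    tile = slice(start, start + 32)
    m_i, l_i, acc = online_softmax_update(m_i, l_i, acc,
                                          scores[:, tile], values[tile])
assert torch.allclose(acc, torch.softmax(scores, dim=1) @ values, atol=1e-5)
```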
3 changes: 1 addition & 2 deletions vllm/core/scheduler.py
@@ -927,8 +927,7 @@ def schedule(self) -> Tuple[List[SequenceGroupMetadata], SchedulerOutputs]:
# `multi_modal_data` will be None.
multi_modal_data=seq_group.multi_modal_data
if scheduler_outputs.num_prefill_groups > 0 else None,
root_seq_id=seq_group.root_seq_id
)
root_seq_id=seq_group.root_seq_id)
seq_group_metadata_list.append(seq_group_metadata)

# Now that the batch has been created, we can assume all blocks in the
Expand Down
7 changes: 3 additions & 4 deletions vllm/engine/llm_engine.py
@@ -472,12 +472,11 @@ def _process_sequence_group_outputs(self, seq_group: SequenceGroup,
# List of (child, parent)
child_seqs: List[Tuple[Sequence, Sequence]] = []


# In tree parallel decoding, all sequences within a sequence group
# are always inferred simultaneously, which results in the generation
# In tree parallel decoding, all sequences within a sequence group
# are always inferred simultaneously, which results in the generation
# of some extra tokens that need to be appended.
root_seq = seq_group.find(seq_group.root_seq_id)
for _ in range(len(seq_group.seqs_dict)-len(parent_seqs)):
for _ in range(len(seq_group.seqs_dict) - len(parent_seqs)):
root_seq._append_tokens_to_blocks([0])

# Process the child samples for each parent sequence
14 changes: 8 additions & 6 deletions vllm/sequence.py
@@ -287,15 +287,17 @@ def append_token_id(
) -> None:
assert token_id in logprobs
seq_group = self.seq_group

def calc(seq_group):
num_token = 0
for block in seq_group.find(seq_group.root_seq_id).logical_token_blocks:
num_token = 0
for block in seq_group.find(
seq_group.root_seq_id).logical_token_blocks:
num_token += block.num_tokens
return num_token

# Allocate a block for the root seq if the sampling type is RANDOM_SEED and multiple sequences are generated per prompt
if self.seq_id != seq_group.root_seq_id and seq_group.sampling_params.sampling_type in (SamplingType.RANDOM, SamplingType.RANDOM_SEED):
if self.seq_id != seq_group.root_seq_id and seq_group.sampling_params.sampling_type in (
SamplingType.RANDOM, SamplingType.RANDOM_SEED):
buffer_seq = seq_group.find(self.seq_group.root_seq_id)
else:
buffer_seq = self
@@ -562,7 +564,7 @@ def is_finished(self) -> bool:
def is_prefill(self) -> bool:
# Every sequence should be in the same stage.
return self.get_seqs()[0].is_prefill()

def get_root(self) -> Sequence:
return self.find(self.root_seq_id)
