Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[kernel optimize] benchmark write_req_to_token_pool_triton and optimize kernel #2509

Merged
merged 2 commits into from
Dec 22, 2024
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
@@ -0,0 +1,345 @@
import itertools
import os
from typing import List

import numpy as np
import pytest
import torch
import triton
import triton.language as tl


@triton.jit
def write_req_to_token_pool_triton(
req_to_token_ptr, # [max_batch, max_context_len]
req_pool_indices,
pre_lens,
seq_lens,
extend_lens,
out_cache_loc,
req_to_token_ptr_stride: tl.constexpr,
):
BLOCK_SIZE: tl.constexpr = 512
pid = tl.program_id(0)

req_pool_index = tl.load(req_pool_indices + pid)
pre_len = tl.load(pre_lens + pid)
seq_len = tl.load(seq_lens + pid)

# TODO: optimize this?
cumsum_start = 0
for i in range(pid):
cumsum_start += tl.load(extend_lens + i)

num_loop = tl.cdiv(seq_len - pre_len, BLOCK_SIZE)
for i in range(num_loop):
offset = tl.arange(0, BLOCK_SIZE) + i * BLOCK_SIZE
mask = offset < (seq_len - pre_len)
value = tl.load(out_cache_loc + cumsum_start + offset, mask=mask)
tl.store(
req_to_token_ptr
+ req_pool_index * req_to_token_ptr_stride
+ offset
+ pre_len,
value,
mask=mask,
)


@triton.jit
def write_req_to_token_pool_triton_optimize(
Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

optmized kernel is here.

req_to_token_ptr, # [max_batch, max_context_len]
req_pool_indices,
pre_lens,
seq_lens,
extend_lens,
out_cache_loc,
req_to_token_ptr_stride: tl.constexpr,
BLOCK_SIZE: tl.constexpr,
):
pid_batch = tl.program_id(0)
pid_token = tl.program_id(1)

req_pool_index = tl.load(req_pool_indices + pid_batch)
pre_len = tl.load(pre_lens + pid_batch)
seq_len = tl.load(seq_lens + pid_batch)
extend_len = seq_len - pre_len

cumsum_start = 0
for i in range(pid_batch):
cumsum_start += tl.load(extend_lens + i)

token_start = pid_token * BLOCK_SIZE

offset = tl.arange(0, BLOCK_SIZE)
actual_offset = token_start + offset
mask = actual_offset < extend_len

src_ptr = out_cache_loc + cumsum_start + actual_offset
src_ptr = tl.max_contiguous(tl.multiple_of(src_ptr, BLOCK_SIZE), BLOCK_SIZE)
value = tl.load(src_ptr, mask=mask)
dst_ptr = (
req_to_token_ptr
+ req_pool_index * req_to_token_ptr_stride
+ actual_offset
+ pre_len
)
dst_ptr = tl.max_contiguous(tl.multiple_of(dst_ptr, BLOCK_SIZE), BLOCK_SIZE)

tl.store(dst_ptr, value, mask=mask)


def write_req_to_token_pool_reference(
req_to_token: torch.Tensor,
req_pool_indices: torch.Tensor,
pre_lens: torch.Tensor,
seq_lens: torch.Tensor,
extend_lens: torch.Tensor,
out_cache_loc: torch.Tensor,
) -> None:
"""Reference implementation using PyTorch"""
for i in range(len(req_pool_indices)):
req_pool_idx = req_pool_indices[i].item()
pre_len = pre_lens[i].item()
seq_len = seq_lens[i].item()
extend_len = extend_lens[i].item()

cumsum_start = sum(extend_lens[:i].tolist())

# Copy values from out_cache_loc to req_to_token
req_to_token[req_pool_idx, pre_len:seq_len] = out_cache_loc[
cumsum_start : cumsum_start + extend_len
]


def test_write_req_to_token_pool():
max_batch = 4097
max_context_len = 6148
batch_size = 1
extend_len = 14

# Initialize input tensors
req_to_token = torch.zeros(
(max_batch, max_context_len), dtype=torch.int32, device="cuda"
)
req_pool_indices = torch.tensor([42], dtype=torch.int32, device="cuda")
pre_lens = torch.tensor([8], dtype=torch.int32, device="cuda")
seq_lens = torch.tensor([22], dtype=torch.int32, device="cuda")
extend_lens = torch.tensor([extend_len], dtype=torch.int32, device="cuda")
out_cache_loc = torch.arange(extend_len, dtype=torch.int32, device="cuda")

# Create copies for reference implementation
req_to_token_ref = req_to_token.clone()
req_to_token_opt = req_to_token.clone()

# Run original triton kernel
write_req_to_token_pool_triton[(batch_size,)](
req_to_token,
req_pool_indices,
pre_lens,
seq_lens,
extend_lens,
out_cache_loc,
max_context_len,
)

# Run optimized triton kernel
def grid(batch_size, extend_len):
num_token_blocks = triton.cdiv(extend_len, 512)
return (batch_size, num_token_blocks)

write_req_to_token_pool_triton_optimize[grid(batch_size, extend_len)](
req_to_token_opt,
req_pool_indices,
pre_lens,
seq_lens,
extend_lens,
out_cache_loc,
max_context_len,
BLOCK_SIZE=512,
)

# Run reference implementation
write_req_to_token_pool_reference(
req_to_token_ref,
req_pool_indices,
pre_lens,
seq_lens,
extend_lens,
out_cache_loc,
)

# Compare results
torch.testing.assert_close(req_to_token, req_to_token_ref)
torch.testing.assert_close(req_to_token_opt, req_to_token_ref)

# Test case 2: batch size > 1
batch_size = 3
extend_lens_list = [14, 20, 30]
total_extend_len = sum(extend_lens_list)

req_to_token = torch.zeros(
(max_batch, max_context_len), dtype=torch.int32, device="cuda"
)
req_pool_indices = torch.tensor([42, 100, 200], dtype=torch.int32, device="cuda")
pre_lens = torch.tensor([8, 10, 15], dtype=torch.int32, device="cuda")
seq_lens = torch.tensor([22, 30, 45], dtype=torch.int32, device="cuda")
extend_lens = torch.tensor(extend_lens_list, dtype=torch.int32, device="cuda")
out_cache_loc = torch.arange(total_extend_len, dtype=torch.int32, device="cuda")

req_to_token_ref = req_to_token.clone()
req_to_token_opt = req_to_token.clone()

# Run original triton kernel
write_req_to_token_pool_triton[(batch_size,)](
req_to_token,
req_pool_indices,
pre_lens,
seq_lens,
extend_lens,
out_cache_loc,
max_context_len,
)

# Run optimized triton kernel
max_extend_len = max(extend_lens_list)
write_req_to_token_pool_triton_optimize[grid(batch_size, max_extend_len)](
req_to_token_opt,
req_pool_indices,
pre_lens,
seq_lens,
extend_lens,
out_cache_loc,
max_context_len,
BLOCK_SIZE=512,
)

# Run reference implementation
write_req_to_token_pool_reference(
req_to_token_ref,
req_pool_indices,
pre_lens,
seq_lens,
extend_lens,
out_cache_loc,
)

# Compare results
torch.testing.assert_close(req_to_token, req_to_token_ref)
torch.testing.assert_close(req_to_token_opt, req_to_token_ref)


def get_benchmark():
batch_sizes = [1, 2, 4, 8, 16, 32, 64, 128]
extend_lens = [32, 64, 128, 256, 512, 1024, 2048, 4096, 8192]
configs = list(itertools.product(batch_sizes, extend_lens))

@triton.testing.perf_report(
triton.testing.Benchmark(
x_names=["batch_size", "extend_len"],
x_vals=configs,
line_arg="provider",
line_vals=["reference", "triton", "triton_optimize"],
line_names=["PyTorch", "Triton", "Triton Optimized"],
styles=[("blue", "-"), ("green", "-"), ("red", "-")],
ylabel="us",
plot_name="write-req-to-token-pool-performance",
args={},
)
)
def benchmark(batch_size, extend_len, provider):
max_batch = 256
max_context_len = 16384

extend_lens_list = [extend_len] * batch_size
total_extend_len = sum(extend_lens_list)

req_to_token = torch.zeros(
(max_batch, max_context_len), dtype=torch.int32, device="cuda"
)
req_pool_indices = torch.arange(batch_size, dtype=torch.int32, device="cuda")
pre_lens = torch.ones(batch_size, dtype=torch.int32, device="cuda") * 8
seq_lens = pre_lens + extend_len
extend_lens = torch.tensor(extend_lens_list, dtype=torch.int32, device="cuda")
out_cache_loc = torch.arange(total_extend_len, dtype=torch.int32, device="cuda")

quantiles = [0.5, 0.2, 0.8]

if provider == "reference":
ms, min_ms, max_ms = triton.testing.do_bench(
lambda: write_req_to_token_pool_reference(
req_to_token.clone(),
req_pool_indices,
pre_lens,
seq_lens,
extend_lens,
out_cache_loc,
),
quantiles=quantiles,
)
elif provider == "triton":
ms, min_ms, max_ms = triton.testing.do_bench(
lambda: write_req_to_token_pool_triton[(batch_size,)](
req_to_token.clone(),
req_pool_indices,
pre_lens,
seq_lens,
extend_lens,
out_cache_loc,
max_context_len,
),
quantiles=quantiles,
)
else:

def run_optimized():
block_size = 128 if extend_len <= 1024 else 512
grid_config = (batch_size, triton.cdiv(extend_len, block_size))
write_req_to_token_pool_triton_optimize[grid_config](
req_to_token.clone(),
req_pool_indices,
pre_lens,
seq_lens,
extend_lens,
out_cache_loc,
max_context_len,
BLOCK_SIZE=block_size,
)

ms, min_ms, max_ms = triton.testing.do_bench(
run_optimized, quantiles=quantiles
)

return 1000 * ms, 1000 * max_ms, 1000 * min_ms

return benchmark


def run_benchmark(save_path: str = "./configs/benchmark_ops/write_req_to_token_pool/"):
"""Run benchmark and save results"""

# Ensure save path exists
os.makedirs(save_path, exist_ok=True)

# Run correctness test
test_write_req_to_token_pool()
print("Correctness test passed!")

# Run performance test
benchmark = get_benchmark()
benchmark.run(print_data=True, save_path=save_path)


if __name__ == "__main__":
import argparse

parser = argparse.ArgumentParser()
parser.add_argument(
"--save_path",
type=str,
default="./configs/benchmark_ops/write_req_to_token_pool/",
help="Path to save benchmark results",
)
args = parser.parse_args()

run_benchmark(args.save_path)
Loading