-
Notifications
You must be signed in to change notification settings - Fork 614
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
[kernel optimize] benchmark write_req_to_token_pool_triton and optimize kernel #2509
Merged
Merged
Changes from 1 commit
Commits
Show all changes
2 commits
Select commit
Hold shift + click to select a range
File filter
Filter by extension
Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
There are no files selected for viewing
347 changes: 347 additions & 0 deletions
347
benchmark/kernels/scheduler_batch/benchmark_write_req_to_token_pool_triton.py
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,347 @@ | ||
import itertools | ||
import os | ||
from typing import List | ||
|
||
import numpy as np | ||
import pytest | ||
import torch | ||
import triton | ||
import triton.language as tl | ||
|
||
|
||
@triton.jit | ||
def write_req_to_token_pool_triton( | ||
req_to_token_ptr, # [max_batch, max_context_len] | ||
req_pool_indices, | ||
pre_lens, | ||
seq_lens, | ||
extend_lens, | ||
out_cache_loc, | ||
req_to_token_ptr_stride: tl.constexpr, | ||
): | ||
BLOCK_SIZE: tl.constexpr = 512 | ||
pid = tl.program_id(0) | ||
|
||
req_pool_index = tl.load(req_pool_indices + pid) | ||
pre_len = tl.load(pre_lens + pid) | ||
seq_len = tl.load(seq_lens + pid) | ||
|
||
# TODO: optimize this? | ||
cumsum_start = 0 | ||
for i in range(pid): | ||
cumsum_start += tl.load(extend_lens + i) | ||
|
||
num_loop = tl.cdiv(seq_len - pre_len, BLOCK_SIZE) | ||
for i in range(num_loop): | ||
offset = tl.arange(0, BLOCK_SIZE) + i * BLOCK_SIZE | ||
mask = offset < (seq_len - pre_len) | ||
value = tl.load(out_cache_loc + cumsum_start + offset, mask=mask) | ||
tl.store( | ||
req_to_token_ptr | ||
+ req_pool_index * req_to_token_ptr_stride | ||
+ offset | ||
+ pre_len, | ||
value, | ||
mask=mask, | ||
) | ||
|
||
|
||
@triton.jit | ||
def write_req_to_token_pool_triton_optimize( | ||
req_to_token_ptr, # [max_batch, max_context_len] | ||
req_pool_indices, | ||
pre_lens, | ||
seq_lens, | ||
extend_lens, | ||
out_cache_loc, | ||
req_to_token_ptr_stride: tl.constexpr, | ||
BLOCK_SIZE: tl.constexpr, | ||
): | ||
pid_batch = tl.program_id(0) | ||
pid_token = tl.program_id(1) | ||
|
||
req_pool_index = tl.load(req_pool_indices + pid_batch) | ||
pre_len = tl.load(pre_lens + pid_batch) | ||
seq_len = tl.load(seq_lens + pid_batch) | ||
extend_len = seq_len - pre_len | ||
|
||
cumsum_start = 0 | ||
for i in range(pid_batch): | ||
cumsum_start += tl.load(extend_lens + i) | ||
|
||
token_start = pid_token * BLOCK_SIZE | ||
|
||
offset = tl.arange(0, BLOCK_SIZE) | ||
actual_offset = token_start + offset | ||
mask = actual_offset < extend_len | ||
|
||
src_ptr = out_cache_loc + cumsum_start + actual_offset | ||
src_ptr = tl.max_contiguous(tl.multiple_of(src_ptr, BLOCK_SIZE), BLOCK_SIZE) | ||
value = tl.load(src_ptr, mask=mask) | ||
dst_ptr = ( | ||
req_to_token_ptr | ||
+ req_pool_index * req_to_token_ptr_stride | ||
+ actual_offset | ||
+ pre_len | ||
) | ||
dst_ptr = tl.max_contiguous(tl.multiple_of(dst_ptr, BLOCK_SIZE), BLOCK_SIZE) | ||
|
||
tl.store(dst_ptr, value, mask=mask) | ||
|
||
|
||
def write_req_to_token_pool_reference( | ||
req_to_token: torch.Tensor, | ||
req_pool_indices: torch.Tensor, | ||
pre_lens: torch.Tensor, | ||
seq_lens: torch.Tensor, | ||
extend_lens: torch.Tensor, | ||
out_cache_loc: torch.Tensor, | ||
) -> None: | ||
"""Reference implementation using PyTorch""" | ||
for i in range(len(req_pool_indices)): | ||
req_pool_idx = req_pool_indices[i].item() | ||
pre_len = pre_lens[i].item() | ||
seq_len = seq_lens[i].item() | ||
extend_len = extend_lens[i].item() | ||
|
||
cumsum_start = sum(extend_lens[:i].tolist()) | ||
|
||
# Copy values from out_cache_loc to req_to_token | ||
req_to_token[req_pool_idx, pre_len:seq_len] = out_cache_loc[ | ||
cumsum_start : cumsum_start + extend_len | ||
] | ||
|
||
|
||
def test_write_req_to_token_pool(): | ||
max_batch = 4097 | ||
max_context_len = 6148 | ||
batch_size = 1 | ||
extend_len = 14 | ||
|
||
# Initialize input tensors | ||
req_to_token = torch.zeros( | ||
(max_batch, max_context_len), dtype=torch.int32, device="cuda" | ||
) | ||
req_pool_indices = torch.tensor([42], dtype=torch.int32, device="cuda") | ||
pre_lens = torch.tensor([8], dtype=torch.int32, device="cuda") | ||
seq_lens = torch.tensor([22], dtype=torch.int32, device="cuda") | ||
extend_lens = torch.tensor([extend_len], dtype=torch.int32, device="cuda") | ||
out_cache_loc = torch.arange(extend_len, dtype=torch.int32, device="cuda") | ||
|
||
# Create copies for reference implementation | ||
req_to_token_ref = req_to_token.clone() | ||
req_to_token_opt = req_to_token.clone() | ||
|
||
# Run original triton kernel | ||
write_req_to_token_pool_triton[(batch_size,)]( | ||
req_to_token, | ||
req_pool_indices, | ||
pre_lens, | ||
seq_lens, | ||
extend_lens, | ||
out_cache_loc, | ||
max_context_len, | ||
) | ||
|
||
# Run optimized triton kernel | ||
def grid(batch_size, extend_len): | ||
num_token_blocks = triton.cdiv( | ||
extend_len, 512 | ||
) | ||
return (batch_size, num_token_blocks) | ||
|
||
write_req_to_token_pool_triton_optimize[grid(batch_size, extend_len)]( | ||
req_to_token_opt, | ||
req_pool_indices, | ||
pre_lens, | ||
seq_lens, | ||
extend_lens, | ||
out_cache_loc, | ||
max_context_len, | ||
BLOCK_SIZE=512, | ||
) | ||
|
||
# Run reference implementation | ||
write_req_to_token_pool_reference( | ||
req_to_token_ref, | ||
req_pool_indices, | ||
pre_lens, | ||
seq_lens, | ||
extend_lens, | ||
out_cache_loc, | ||
) | ||
|
||
# Compare results | ||
torch.testing.assert_close(req_to_token, req_to_token_ref) | ||
torch.testing.assert_close(req_to_token_opt, req_to_token_ref) | ||
|
||
# Test case 2: batch size > 1 | ||
batch_size = 3 | ||
extend_lens_list = [14, 20, 30] | ||
total_extend_len = sum(extend_lens_list) | ||
|
||
req_to_token = torch.zeros( | ||
(max_batch, max_context_len), dtype=torch.int32, device="cuda" | ||
) | ||
req_pool_indices = torch.tensor([42, 100, 200], dtype=torch.int32, device="cuda") | ||
pre_lens = torch.tensor([8, 10, 15], dtype=torch.int32, device="cuda") | ||
seq_lens = torch.tensor([22, 30, 45], dtype=torch.int32, device="cuda") | ||
extend_lens = torch.tensor(extend_lens_list, dtype=torch.int32, device="cuda") | ||
out_cache_loc = torch.arange(total_extend_len, dtype=torch.int32, device="cuda") | ||
|
||
req_to_token_ref = req_to_token.clone() | ||
req_to_token_opt = req_to_token.clone() | ||
|
||
# Run original triton kernel | ||
write_req_to_token_pool_triton[(batch_size,)]( | ||
req_to_token, | ||
req_pool_indices, | ||
pre_lens, | ||
seq_lens, | ||
extend_lens, | ||
out_cache_loc, | ||
max_context_len, | ||
) | ||
|
||
# Run optimized triton kernel | ||
max_extend_len = max(extend_lens_list) | ||
write_req_to_token_pool_triton_optimize[grid(batch_size, max_extend_len)]( | ||
req_to_token_opt, | ||
req_pool_indices, | ||
pre_lens, | ||
seq_lens, | ||
extend_lens, | ||
out_cache_loc, | ||
max_context_len, | ||
BLOCK_SIZE=512, | ||
) | ||
|
||
# Run reference implementation | ||
write_req_to_token_pool_reference( | ||
req_to_token_ref, | ||
req_pool_indices, | ||
pre_lens, | ||
seq_lens, | ||
extend_lens, | ||
out_cache_loc, | ||
) | ||
|
||
# Compare results | ||
torch.testing.assert_close(req_to_token, req_to_token_ref) | ||
torch.testing.assert_close(req_to_token_opt, req_to_token_ref) | ||
|
||
|
||
def get_benchmark(): | ||
batch_sizes = [1, 2, 4, 8, 16, 32, 64, 128] | ||
extend_lens = [32, 64, 128, 256, 512, 1024, 2048, 4096, 8192] | ||
configs = list(itertools.product(batch_sizes, extend_lens)) | ||
|
||
@triton.testing.perf_report( | ||
triton.testing.Benchmark( | ||
x_names=["batch_size", "extend_len"], | ||
x_vals=configs, | ||
line_arg="provider", | ||
line_vals=["reference", "triton", "triton_optimize"], | ||
line_names=["PyTorch", "Triton", "Triton Optimized"], | ||
styles=[("blue", "-"), ("green", "-"), ("red", "-")], | ||
ylabel="us", | ||
plot_name="write-req-to-token-pool-performance", | ||
args={}, | ||
) | ||
) | ||
def benchmark(batch_size, extend_len, provider): | ||
max_batch = 256 | ||
max_context_len = 16384 | ||
|
||
extend_lens_list = [extend_len] * batch_size | ||
total_extend_len = sum(extend_lens_list) | ||
|
||
req_to_token = torch.zeros( | ||
(max_batch, max_context_len), dtype=torch.int32, device="cuda" | ||
) | ||
req_pool_indices = torch.arange(batch_size, dtype=torch.int32, device="cuda") | ||
pre_lens = torch.ones(batch_size, dtype=torch.int32, device="cuda") * 8 | ||
seq_lens = pre_lens + extend_len | ||
extend_lens = torch.tensor(extend_lens_list, dtype=torch.int32, device="cuda") | ||
out_cache_loc = torch.arange(total_extend_len, dtype=torch.int32, device="cuda") | ||
|
||
quantiles = [0.5, 0.2, 0.8] | ||
|
||
if provider == "reference": | ||
ms, min_ms, max_ms = triton.testing.do_bench( | ||
lambda: write_req_to_token_pool_reference( | ||
req_to_token.clone(), | ||
req_pool_indices, | ||
pre_lens, | ||
seq_lens, | ||
extend_lens, | ||
out_cache_loc, | ||
), | ||
quantiles=quantiles, | ||
) | ||
elif provider == "triton": | ||
ms, min_ms, max_ms = triton.testing.do_bench( | ||
lambda: write_req_to_token_pool_triton[(batch_size,)]( | ||
req_to_token.clone(), | ||
req_pool_indices, | ||
pre_lens, | ||
seq_lens, | ||
extend_lens, | ||
out_cache_loc, | ||
max_context_len, | ||
), | ||
quantiles=quantiles, | ||
) | ||
else: | ||
|
||
def run_optimized(): | ||
block_size = 128 if extend_len <= 1024 else 512 | ||
grid_config = (batch_size, triton.cdiv(extend_len, block_size)) | ||
write_req_to_token_pool_triton_optimize[grid_config]( | ||
req_to_token.clone(), | ||
req_pool_indices, | ||
pre_lens, | ||
seq_lens, | ||
extend_lens, | ||
out_cache_loc, | ||
max_context_len, | ||
BLOCK_SIZE=block_size, | ||
) | ||
|
||
ms, min_ms, max_ms = triton.testing.do_bench( | ||
run_optimized, quantiles=quantiles | ||
) | ||
|
||
return 1000 * ms, 1000 * max_ms, 1000 * min_ms | ||
|
||
return benchmark | ||
|
||
|
||
def run_benchmark(save_path: str = "./configs/benchmark_ops/write_req_to_token_pool/"): | ||
"""Run benchmark and save results""" | ||
|
||
# Ensure save path exists | ||
os.makedirs(save_path, exist_ok=True) | ||
|
||
# Run correctness test | ||
test_write_req_to_token_pool() | ||
print("Correctness test passed!") | ||
|
||
# Run performance test | ||
benchmark = get_benchmark() | ||
benchmark.run(print_data=True, save_path=save_path) | ||
|
||
|
||
if __name__ == "__main__": | ||
import argparse | ||
|
||
parser = argparse.ArgumentParser() | ||
parser.add_argument( | ||
"--save_path", | ||
type=str, | ||
default="./configs/benchmark_ops/write_req_to_token_pool/", | ||
help="Path to save benchmark results", | ||
) | ||
args = parser.parse_args() | ||
|
||
run_benchmark(args.save_path) |
Oops, something went wrong.
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
optmized kernel is here.