[kernel optimize] benchmark write_req_to_token_pool_triton and optimize kernel #2509
Conversation
@triton.jit
def write_req_to_token_pool_triton_optimize(
The optimized kernel is here.
Good work! Do we want to use this one in the sglang code?
H100 result: Correctness test passed!
write-req-to-token-pool-performance:
batch_size extend_len PyTorch Triton Triton Optimized
0 1.0 32.0 119.135998 23.375999 23.200000
1 1.0 64.0 115.584001 23.296000 23.264000
2 1.0 128.0 115.712002 23.296000 23.296000
3 1.0 256.0 118.192002 23.296000 23.264000
4 1.0 512.0 116.159998 23.328001 23.264000
5 1.0 1024.0 116.319999 23.808001 23.424000
6 1.0 2048.0 117.632002 24.863999 23.472000
7 1.0 4096.0 116.576001 26.400000 23.488000
8 1.0 8192.0 116.287999 29.696001 23.520000
9 2.0 32.0 208.352000 23.328001 23.264000
10 2.0 64.0 202.368006 23.360001 23.296000
11 2.0 128.0 202.368006 23.360001 23.328001
12 2.0 256.0 206.032008 23.552001 23.488000
13 2.0 512.0 202.528000 23.520000 23.456000
14 2.0 1024.0 203.312010 23.968000 23.488000
15 2.0 2048.0 202.207997 24.863999 23.520000
16 2.0 4096.0 202.848002 26.528001 23.584001
17 2.0 8192.0 203.776002 30.112000 23.647999
18 4.0 32.0 370.848000 23.440000 23.360001
19 4.0 64.0 369.695991 23.391999 23.360001
20 4.0 128.0 371.312022 23.391999 23.328001
21 4.0 256.0 374.336004 23.680000 23.647999
22 4.0 512.0 372.512013 23.520000 23.488000
23 4.0 1024.0 374.143988 24.032000 23.552001
24 4.0 2048.0 373.856008 24.896000 23.808001
25 4.0 4096.0 386.960000 26.720000 23.664000
26 4.0 8192.0 371.840000 30.208001 23.680000
27 8.0 32.0 715.615988 23.456000 23.391999
28 8.0 64.0 721.055984 23.552001 23.391999
29 8.0 128.0 713.295996 23.584001 23.488000
30 8.0 256.0 717.760026 23.584001 23.520000
31 8.0 512.0 718.768001 23.776000 23.712000
32 8.0 1024.0 718.912005 24.032000 23.647999
33 8.0 2048.0 715.631962 25.024001 23.680000
34 8.0 4096.0 717.119992 26.752001 23.871999
35 8.0 8192.0 716.256022 30.304000 23.968000
36 16.0 32.0 1385.951996 23.776000 23.744000
37 16.0 64.0 1381.551981 23.744000 23.871999
38 16.0 128.0 1440.863967 23.776000 23.696000
39 16.0 256.0 1383.343935 23.808001 23.808001
40 16.0 512.0 1396.191955 23.808001 23.808001
41 16.0 1024.0 1437.088013 24.351999 23.840001
42 16.0 2048.0 1401.376009 25.264001 23.903999
43 16.0 4096.0 1388.319969 27.039999 24.095999
44 16.0 8192.0 1422.943950 30.751999 24.224000
45 32.0 32.0 2731.424093 24.256000 24.192000
46 32.0 64.0 2727.488041 24.224000 24.224000
47 32.0 128.0 2727.711916 24.480000 24.288001
48 32.0 256.0 2741.024017 24.351999 24.288001
49 32.0 512.0 2735.775948 24.400000 24.320001
50 32.0 1024.0 2735.008001 25.024001 24.448000
51 32.0 2048.0 2738.368034 25.952000 24.544001
52 32.0 4096.0 2728.832006 27.744001 24.704000
53 32.0 8192.0 2729.279995 31.520002 25.280001
54 64.0 32.0 5473.824024 25.312001 25.327999
55 64.0 64.0 5424.543858 25.343999 25.343999
56 64.0 128.0 5662.623882 25.504000 25.376000
57 64.0 256.0 5442.080021 25.440000 25.376000
58 64.0 512.0 5496.895790 25.472000 25.408000
59 64.0 1024.0 5455.935955 26.079999 25.536001
60 64.0 2048.0 5477.119923 27.200000 25.760001
61 64.0 4096.0 5496.160030 29.216001 26.048001
62 64.0 8192.0 5472.304344 33.216000 26.559999
63 128.0 32.0 11052.864075 27.616000 27.519999
64 128.0 64.0 10997.056007 27.712001 27.519999
65 128.0 128.0 11045.984268 27.584000 27.616000
66 128.0 256.0 11019.136429 27.680000 27.488001
67 128.0 512.0 11424.672127 27.872000 27.775999
68 128.0 1024.0 11234.096527 28.224001 27.744001
69 128.0 2048.0 11017.151833 29.408000 27.968001
70 128.0 4096.0 10965.391159 31.679999 28.767999
71 128.0 8192.0 11565.328598 36.352001 29.440001
I will think about whether further optimization is possible. For now, we don't need to apply it to SGLang: the current optimized version of the Triton kernel only shows a performance gain of around 10%, and only as the batch size (bs) and context length increase.
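For context, tables like the one above are the output format of triton.testing.perf_report. Below is a minimal, self-contained sketch of such a harness; the copy_loop workload is a stand-in for illustration, not this PR's actual benchmark script, and the swept sizes are assumptions based on the columns above.

```python
import torch
import triton
import triton.testing


def copy_loop(dst, src, bs, n):
    # Per-request Python loop over rows, standing in for a PyTorch baseline.
    for i in range(bs):
        dst[i, :n] = src[i * n : (i + 1) * n]


@triton.testing.perf_report(
    triton.testing.Benchmark(
        x_names=["extend_len"],
        x_vals=[32, 64, 128, 256, 512, 1024, 2048, 4096, 8192],
        line_arg="provider",
        line_vals=["torch"],
        line_names=["PyTorch"],
        plot_name="write-req-to-token-pool-performance",
        args={"batch_size": 32},  # one of the batch sizes swept above
    )
)
def benchmark(batch_size, extend_len, provider):
    dst = torch.zeros((batch_size, extend_len), dtype=torch.int32, device="cuda")
    src = torch.arange(batch_size * extend_len, dtype=torch.int32, device="cuda")
    # do_bench runs the callable repeatedly and returns latency in milliseconds.
    return triton.testing.do_bench(
        lambda: copy_loop(dst, src, batch_size, extend_len)
    )


benchmark.run(print_data=True)
```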
Performance of write_req_to_token_pool_triton vs. write_req_to_token_pool_triton_optimize on a 4090:

[ncu profile screenshots of write_req_to_token_pool_triton omitted]

From the ncu profile results, we can see that the write_req_to_token_pool_triton kernel only parallelizes along the batch dimension. This leads to underutilization of SMs when the batch size is small, which can be mitigated by introducing parallelism along the token dimension. Additionally, since this kernel is memory-bound and only involves read/write operations, we can further optimize it by improving memory coalescing. The benchmark results above show that the optimized kernel achieves some performance gains as the batch size or sequence length grows. In most cases, this kernel doesn't present a significant performance bottleneck, so I believe a CUDA implementation isn't necessary.
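A minimal sketch of that 2D-parallel layout follows. The argument names, the extend_starts helper, and the launch wrapper are hypothetical (the real signature of write_req_to_token_pool_triton_optimize is in the diff above); the point is the two-axis grid and the contiguous block-wide loads/stores.

```python
import torch
import triton
import triton.language as tl


@triton.jit
def write_req_to_token_pool_2d(
    req_to_token_ptr,      # int pool of shape [max_batch, max_context_len]
    req_pool_indices_ptr,  # pool row assigned to each request in the batch
    pre_lens_ptr,          # number of tokens already cached per request
    extend_lens_ptr,       # number of new tokens to write per request
    extend_starts_ptr,     # offset of each request's tokens in out_cache_loc
    out_cache_loc_ptr,     # flat KV-cache slot indices for all new tokens
    pool_stride,           # row stride of req_to_token_ptr
    BLOCK: tl.constexpr,
):
    # Axis 0: one program per request (the original, batch-only parallelism).
    # Axis 1: programs tile the token dimension, so even a small batch
    # launches enough blocks to fill the SMs.
    pid_b = tl.program_id(0)
    pid_t = tl.program_id(1)

    pool_row = tl.load(req_pool_indices_ptr + pid_b)
    pre_len = tl.load(pre_lens_ptr + pid_b)
    extend_len = tl.load(extend_lens_ptr + pid_b)
    extend_start = tl.load(extend_starts_ptr + pid_b)

    offs = pid_t * BLOCK + tl.arange(0, BLOCK)
    mask = offs < extend_len

    # Contiguous, block-wide loads and stores keep memory traffic coalesced.
    slots = tl.load(out_cache_loc_ptr + extend_start + offs, mask=mask)
    tl.store(
        req_to_token_ptr + pool_row * pool_stride + pre_len + offs,
        slots,
        mask=mask,
    )


def write_req_to_token_pool(req_to_token, req_pool_indices, pre_lens,
                            extend_lens, out_cache_loc, BLOCK=512):
    # Exclusive prefix sum gives each request's start inside out_cache_loc.
    extend_starts = torch.cumsum(extend_lens, 0) - extend_lens
    grid = (req_pool_indices.numel(),
            triton.cdiv(int(extend_lens.max().item()), BLOCK))
    write_req_to_token_pool_2d[grid](
        req_to_token, req_pool_indices, pre_lens, extend_lens,
        extend_starts, out_cache_loc, req_to_token.stride(0), BLOCK=BLOCK,
    )
```

The second grid axis is what recovers occupancy at small batch sizes; the coalescing comes for free once each program touches a contiguous BLOCK-sized slice of both out_cache_loc and the pool row.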