
[kernel optimize] benchmark write_req_to_token_pool_triton and optimize kernel #2509

Merged · 2 commits · Dec 22, 2024

Conversation

@BBuf (Collaborator) commented Dec 18, 2024

  • add a correctness test for write_req_to_token_pool_triton
  • optimize write_req_to_token_pool_triton
  • add a benchmark test for write_req_to_token_pool_triton (a minimal sketch of such a benchmark follows this list)
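A minimal, self-contained sketch of how such a benchmark can be set up with triton.testing.perf_report, sweeping the same (batch_size, extend_len) grid as the tables below. The workload here is only a placeholder PyTorch copy, not the actual write_req_to_token_pool implementations; the real benchmark dispatches each provider to the reference PyTorch path or one of the two Triton kernels.

```python
import torch
import triton
import triton.testing


def _placeholder_workload(batch_size: int, extend_len: int) -> None:
    # Placeholder: a plain device copy standing in for the real
    # write_req_to_token_pool implementations compared in this PR.
    src = torch.randn(batch_size * extend_len, device="cuda")
    dst = torch.empty_like(src)
    dst.copy_(src)


@triton.testing.perf_report(
    triton.testing.Benchmark(
        x_names=["batch_size", "extend_len"],
        x_vals=[
            (bs, el)
            for bs in (1, 2, 4, 8, 16, 32, 64, 128)
            for el in (32, 64, 128, 256, 512, 1024, 2048, 4096, 8192)
        ],
        line_arg="provider",
        line_vals=["pytorch", "triton", "triton_optimized"],
        line_names=["PyTorch", "Triton", "Triton Optimized"],
        ylabel="latency",
        plot_name="write-req-to-token-pool-performance",
        args={},
    )
)
def benchmark(batch_size, extend_len, provider):
    # The real benchmark calls a different implementation per provider;
    # here all three providers time the same placeholder workload.
    return triton.testing.do_bench(
        lambda: _placeholder_workload(batch_size, extend_len)
    )


if __name__ == "__main__":
    benchmark.run(print_data=True)
```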

Performance of write_req_to_token_pool_triton_optimize on an RTX 4090:

Correctness test passed!
write-req-to-token-pool-performance:
    batch_size  extend_len      PyTorch     Triton  Triton Optimized
0          1.0        32.0    93.184002  47.104001         46.080001
1          1.0        64.0    93.184002  46.080001         46.080001
2          1.0       128.0    93.184002  46.080001         46.080001
3          1.0       256.0    94.208002  46.080001         46.080001
4          1.0       512.0    92.160001  46.080001         46.080001
5          1.0      1024.0    93.184002  47.104001         46.080001
6          1.0      2048.0    93.184002  47.104001         46.080001
7          1.0      4096.0    93.184002  **48.128001**         46.080001
8          1.0      8192.0    93.184002  **50.175998**         46.080001
9          2.0        32.0   141.312003  46.080001         46.080001
10         2.0        64.0   143.360004  46.080001         46.080001
11         2.0       128.0   142.335996  46.080001         46.080001
12         2.0       256.0   143.360004  47.104001         46.080001
13         2.0       512.0   141.312003  46.080001         46.080001
14         2.0      1024.0   141.312003  **47.104001**         46.080001
15         2.0      2048.0   144.383997  **47.104001**         46.080001
16         2.0      4096.0   144.383997  **49.152002**         46.080001
17         2.0      8192.0   142.335996  51.199999         47.104001
18         4.0        32.0   240.640000  46.080001         46.080001
19         4.0        64.0   241.664007  46.080001         46.080001
20         4.0       128.0   243.711993  46.080001         46.080001
21         4.0       256.0   242.688000  47.104001         46.080001
22         4.0       512.0   247.807994  46.592001         46.080001
23         4.0      1024.0   247.807994  47.104001         47.104001
24         4.0      2048.0   243.711993  47.104001         47.104001
25         4.0      4096.0   246.784002  **49.152002**         47.104001
26         4.0      8192.0   250.880003  **51.199999**         47.104001
27         8.0        32.0   439.296007  46.080001         46.080001
28         8.0        64.0   436.224014  46.080001         46.080001
29         8.0       128.0   435.200006  47.104001         46.080001
30         8.0       256.0   437.247992  46.080001         46.080001
31         8.0       512.0   445.439994  47.104001         46.080001
32         8.0      1024.0   434.175998  47.104001         47.104001
33         8.0      2048.0   434.175998  **48.128001**         47.104001
34         8.0      4096.0   431.104004  **49.152002**         47.104001
35         8.0      8192.0   430.079997  **51.199999**         47.104001
36        16.0        32.0   839.680016  47.104001         47.104001
37        16.0        64.0   844.799995  47.104001         47.104001
38        16.0       128.0   821.247995  47.104001         47.104001
39        16.0       256.0   818.175972  47.104001         47.104001
40        16.0       512.0   830.464005  47.104001         47.104001
41        16.0      1024.0   821.247995  47.104001         47.104001
42        16.0      2048.0   827.391982  48.128001         47.104001
43        16.0      4096.0   829.439998  **49.152002**         48.128001
44        16.0      8192.0   845.824003  **52.223999**         48.128001
45        32.0        32.0  1599.488020  47.104001         47.104001
46        32.0        64.0  1629.696012  47.104001         47.104001
47        32.0       128.0  1596.415997  47.104001         47.104001
48        32.0       256.0  1615.872025  47.104001         47.104001
49        32.0       512.0  1597.440004  47.104001         47.104001
50        32.0      1024.0  1592.319965  48.128001         48.128001
51        32.0      2048.0  1616.896033  **49.152002**         48.128001
52        32.0      4096.0  1600.000024  **50.175998**         49.152002
53        32.0      8192.0  1595.391989  **54.272000**         50.175998
54        64.0        32.0  3230.207920  48.128001         48.128001
55        64.0        64.0  3264.512062  48.128001         48.128001
56        64.0       128.0  3148.799896  48.128001         48.128001
57        64.0       256.0  3176.448107  48.128001         48.128001
58        64.0       512.0  3241.983891  48.128001         48.128001
59        64.0      1024.0  3195.904016  49.152002         49.152002
60        64.0      2048.0  3179.519892  **50.175998**         49.152002
61        64.0      4096.0  3219.455957  **53.247999**         50.175998
62        64.0      8192.0  3189.759970  **58.368001**         53.247999
63       128.0        32.0  6406.144142  50.175998         50.175998
64       128.0        64.0  6292.479992  50.175998         50.175998
65       128.0       128.0  6459.392071  50.175998         50.175998
66       128.0       256.0  6463.487625  50.175998         50.175998
67       128.0       512.0  6325.247765  50.175998         50.175998
68       128.0      1024.0  6372.352123  51.199999         51.199999
69       128.0      2048.0  6453.760147  **53.247999**         52.223999
70       128.0      4096.0  6390.783787  **57.344001**         55.296000
71       128.0      8192.0  6376.448154  **65.536000**         60.416002

From the highlighted data, we can observe that the optimized kernel achieves lower latency when batch size or sequence length increases.

From the ncu profile result below, we can see that the write_req_to_token_pool_triton kernel only parallelizes along the batch dimension. This leads to underutilization of SMs when the batch size is small. This issue can be mitigated by introducing parallelism along the token dimension. Additionally, since this kernel is memory-bound and only involves read/write operations, we can further optimize it by improving memory coalescing. The benchmark results above show that the optimized kernel achieves some performance gains when increasing batch size or sequence length. In most cases, this kernel doesn't present significant performance bottlenecks, so I believe a CUDA implementation isn't necessary.

[image: ncu profile of write_req_to_token_pool_triton]
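To make the optimization direction concrete, here is a minimal sketch (not the PR's actual kernel) of a variant that parallelizes over both the batch and the token dimension: the launch grid becomes (batch_size, cdiv(max_extend_len, BLOCK_SIZE)), each program copies one contiguous block of a request's extended tokens, and the contiguous block accesses keep loads and stores coalesced. The argument names (e.g. extend_start_loc) and the block size are assumptions made for this example.

```python
import triton
import triton.language as tl


@triton.jit
def write_req_to_token_pool_sketch(
    req_to_token_ptr,        # [max_batch, max_context_len], viewed as 1D
    req_pool_indices,        # [bs] row of each request in the token pool
    pre_lens,                # [bs] tokens already present for each request
    seq_lens,                # [bs] total tokens for each request
    extend_start_loc,        # [bs] prefix-sum offsets into out_cache_loc
    out_cache_loc,           # [sum(extend_lens)] token locations to write
    req_to_token_ptr_stride: tl.constexpr,
    BLOCK_SIZE: tl.constexpr,
):
    # 2D grid: axis 0 walks requests, axis 1 walks BLOCK_SIZE-sized chunks
    # of each request's extended tokens, so small batches still launch
    # enough programs to keep the SMs busy.
    pid_batch = tl.program_id(0)
    pid_token = tl.program_id(1)

    req_pool_index = tl.load(req_pool_indices + pid_batch)
    pre_len = tl.load(pre_lens + pid_batch)
    seq_len = tl.load(seq_lens + pid_batch)
    extend_start = tl.load(extend_start_loc + pid_batch)
    extend_len = seq_len - pre_len

    offset = pid_token * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
    mask = offset < extend_len

    # Contiguous, masked load/store within the block keeps accesses coalesced.
    value = tl.load(out_cache_loc + extend_start + offset, mask=mask)
    tl.store(
        req_to_token_ptr
        + req_pool_index * req_to_token_ptr_stride
        + pre_len
        + offset,
        value,
        mask=mask,
    )


# Launch example (shapes and values are illustrative):
# grid = (batch_size, triton.cdiv(max_extend_len, BLOCK_SIZE))
# write_req_to_token_pool_sketch[grid](
#     req_to_token, req_pool_indices, pre_lens, seq_lens,
#     extend_start_loc, out_cache_loc,
#     req_to_token.stride(0), BLOCK_SIZE=512,
# )
```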



@triton.jit
def write_req_to_token_pool_triton_optimize(
@BBuf (Collaborator, Author) commented on this diff:
The optimized kernel is here.

@merrymercy (Contributor) commented:
Good work! Do we want to use this one in the sglang code?

@merrymercy merged commit 7d672d2 into sgl-project:main Dec 22, 2024 (1 check passed)
@BBuf (Collaborator, Author) commented Dec 22, 2024

H100 results:

Correctness test passed!
write-req-to-token-pool-performance:
    batch_size  extend_len       PyTorch     Triton  Triton Optimized
0          1.0        32.0    119.135998  23.375999         23.200000
1          1.0        64.0    115.584001  23.296000         23.264000
2          1.0       128.0    115.712002  23.296000         23.296000
3          1.0       256.0    118.192002  23.296000         23.264000
4          1.0       512.0    116.159998  23.328001         23.264000
5          1.0      1024.0    116.319999  23.808001         23.424000
6          1.0      2048.0    117.632002  24.863999         23.472000
7          1.0      4096.0    116.576001  26.400000         23.488000
8          1.0      8192.0    116.287999  29.696001         23.520000
9          2.0        32.0    208.352000  23.328001         23.264000
10         2.0        64.0    202.368006  23.360001         23.296000
11         2.0       128.0    202.368006  23.360001         23.328001
12         2.0       256.0    206.032008  23.552001         23.488000
13         2.0       512.0    202.528000  23.520000         23.456000
14         2.0      1024.0    203.312010  23.968000         23.488000
15         2.0      2048.0    202.207997  24.863999         23.520000
16         2.0      4096.0    202.848002  26.528001         23.584001
17         2.0      8192.0    203.776002  30.112000         23.647999
18         4.0        32.0    370.848000  23.440000         23.360001
19         4.0        64.0    369.695991  23.391999         23.360001
20         4.0       128.0    371.312022  23.391999         23.328001
21         4.0       256.0    374.336004  23.680000         23.647999
22         4.0       512.0    372.512013  23.520000         23.488000
23         4.0      1024.0    374.143988  24.032000         23.552001
24         4.0      2048.0    373.856008  24.896000         23.808001
25         4.0      4096.0    386.960000  26.720000         23.664000
26         4.0      8192.0    371.840000  30.208001         23.680000
27         8.0        32.0    715.615988  23.456000         23.391999
28         8.0        64.0    721.055984  23.552001         23.391999
29         8.0       128.0    713.295996  23.584001         23.488000
30         8.0       256.0    717.760026  23.584001         23.520000
31         8.0       512.0    718.768001  23.776000         23.712000
32         8.0      1024.0    718.912005  24.032000         23.647999
33         8.0      2048.0    715.631962  25.024001         23.680000
34         8.0      4096.0    717.119992  26.752001         23.871999
35         8.0      8192.0    716.256022  30.304000         23.968000
36        16.0        32.0   1385.951996  23.776000         23.744000
37        16.0        64.0   1381.551981  23.744000         23.871999
38        16.0       128.0   1440.863967  23.776000         23.696000
39        16.0       256.0   1383.343935  23.808001         23.808001
40        16.0       512.0   1396.191955  23.808001         23.808001
41        16.0      1024.0   1437.088013  24.351999         23.840001
42        16.0      2048.0   1401.376009  25.264001         23.903999
43        16.0      4096.0   1388.319969  27.039999         24.095999
44        16.0      8192.0   1422.943950  30.751999         24.224000
45        32.0        32.0   2731.424093  24.256000         24.192000
46        32.0        64.0   2727.488041  24.224000         24.224000
47        32.0       128.0   2727.711916  24.480000         24.288001
48        32.0       256.0   2741.024017  24.351999         24.288001
49        32.0       512.0   2735.775948  24.400000         24.320001
50        32.0      1024.0   2735.008001  25.024001         24.448000
51        32.0      2048.0   2738.368034  25.952000         24.544001
52        32.0      4096.0   2728.832006  27.744001         24.704000
53        32.0      8192.0   2729.279995  31.520002         25.280001
54        64.0        32.0   5473.824024  25.312001         25.327999
55        64.0        64.0   5424.543858  25.343999         25.343999
56        64.0       128.0   5662.623882  25.504000         25.376000
57        64.0       256.0   5442.080021  25.440000         25.376000
58        64.0       512.0   5496.895790  25.472000         25.408000
59        64.0      1024.0   5455.935955  26.079999         25.536001
60        64.0      2048.0   5477.119923  27.200000         25.760001
61        64.0      4096.0   5496.160030  29.216001         26.048001
62        64.0      8192.0   5472.304344  33.216000         26.559999
63       128.0        32.0  11052.864075  27.616000         27.519999
64       128.0        64.0  10997.056007  27.712001         27.519999
65       128.0       128.0  11045.984268  27.584000         27.616000
66       128.0       256.0  11019.136429  27.680000         27.488001
67       128.0       512.0  11424.672127  27.872000         27.775999
68       128.0      1024.0  11234.096527  28.224001         27.744001
69       128.0      2048.0  11017.151833  29.408000         27.968001
70       128.0      4096.0  10965.391159  31.679999         28.767999
71       128.0      8192.0  11565.328598  36.352001         29.440001

@BBuf (Collaborator, Author) commented Dec 22, 2024

> Good work! Do we want to use this one in the sglang code?

I will think about whether further optimization is possible. For now, we don't need to apply it to SGLang. The current optimized version of the Triton kernel only shows a performance gain of around 10% when the batch size (bs) and context length increase.

chosen-ox pushed a commit to chosen-ox/sglang that referenced this pull request Dec 22, 2024