add kernel config tuning way to get better performance. #681

Merged: 48 commits, Dec 25, 2024

Commits (48)
b419eaf
add better moe kernel
Dec 18, 2024
94756cf
add better moe kernel
Dec 18, 2024
2d90a34
add grouped gemm for bf16 type
Dec 19, 2024
eac5df9
fix
Dec 19, 2024
35a2684
first ok grouped matmul moe kernel
Dec 19, 2024
6d020aa
fix
Dec 19, 2024
a9ce657
fix
Dec 19, 2024
edb8360
fix
Dec 20, 2024
637d4cf
fix tunning
Dec 20, 2024
c9b4be9
fix tunning
Dec 20, 2024
afc84bf
fix
Dec 20, 2024
4e6bf52
fix
Dec 22, 2024
9e66be4
fix chunck size setting.
Dec 22, 2024
4352845
add grid count auto calcu.
Dec 22, 2024
e28a183
fix
Dec 22, 2024
64c8171
fix
Dec 22, 2024
5cff966
fix
Dec 22, 2024
7d6176a
add pingpong fp8 gemm support.
Dec 23, 2024
4b57cf0
add moe fp8 support
Dec 23, 2024
4906443
fix
Dec 23, 2024
e67ee15
fix gqa_attetion_tuning code.
Dec 23, 2024
fb46535
fix
Dec 23, 2024
7a1dd04
fix
Dec 23, 2024
104ae2f
fix
Dec 23, 2024
df60e37
fix
Dec 23, 2024
1f3cccb
fix
Dec 23, 2024
1a28974
fix
Dec 23, 2024
6d6726e
fix
Dec 23, 2024
2f80d54
fix
Dec 23, 2024
e47e801
fix sb bug.
Dec 24, 2024
141108b
fix
Dec 24, 2024
827e343
add kernel tunning setting config.
Dec 24, 2024
fca5d23
all config store to one dir, and add store interface.
Dec 25, 2024
f177e7d
add tuning save demo.
Dec 25, 2024
693ed2a
fix fp8 tuning code.
Dec 25, 2024
2f9eda3
fix
Dec 25, 2024
0930bb6
fix
Dec 25, 2024
e8cec38
fix
Dec 25, 2024
883e52c
fix
Dec 25, 2024
bedc6df
fix
Dec 25, 2024
5a89cee
fix
Dec 25, 2024
2ea9a79
fix
Dec 25, 2024
4b66933
add multi process kernel tunning code.
Dec 25, 2024
fadca58
fix
Dec 25, 2024
b9aa955
fix
Dec 25, 2024
20b4b35
fix
Dec 25, 2024
cb2d78e
fix
Dec 25, 2024
88867b7
fix format.
Dec 25, 2024
fix
wangzaijun committed Dec 25, 2024
commit 1f3cccb7bace328c955da147f5e280434b9f455c
368 changes: 368 additions & 0 deletions test/kernel/fuse_moe_tuning_fp8.py
@@ -0,0 +1,368 @@
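# Offline tuning script for the grouped-matmul MoE kernels (bf16 and fp8 w8a8):
# it sweeps kernel launch configs (block sizes, GROUP_SIZE_M, num_warps, num_stages),
# times each candidate by replaying it inside a CUDA graph, and evaluates candidates
# in batches inside child processes so that a config that crashes the GPU process
# does not abort the whole sweep.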
import torch
import time
import torch.multiprocessing as mp
from lightllm.common.fused_moe.grouped_fused_moe import fused_experts_impl, moe_align, moe_align1, grouped_matmul
from typing import List
from lightllm.utils.log_utils import init_logger

logger = init_logger(__name__)


def set_seed():
import torch
import random
import numpy as np

seed = 42
torch.manual_seed(seed)
random.seed(seed)
np.random.seed(seed)
if torch.cuda.is_available():
torch.cuda.manual_seed(seed)
torch.cuda.manual_seed_all(seed)
return
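# With the seed fixed above, every worker process benchmarks the kernels on
# identical random inputs, so timings stay comparable across configs.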


def quantize_moe(weight):
    try:
        from lightllm.common.vllm_kernel import _custom_ops as ops

        HAS_VLLM = True
    except ImportError:
        HAS_VLLM = False

    assert HAS_VLLM, "the vllm custom ops are required for fp8 quantization"

num_experts = weight.shape[0]
    weight_scales = []
    qweights = torch.empty_like(weight, dtype=torch.float8_e4m3fn).cuda()
for i in range(num_experts):
qweight, weight_scale = ops.scaled_fp8_quant(
weight[i].contiguous().cuda(), scale=None, use_per_token_if_dynamic=False
)
qweights[i] = qweight
weight_scales.append(weight_scale)
weight_scale = torch.cat(weight_scales, dim=0).reshape(-1)
return qweights, weight_scale
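# Note: one dynamic per-tensor scale is produced per expert and the scales are
# concatenated into a 1-D tensor of length num_experts; test_kernel passes this
# tensor to grouped_matmul as expert_to_weights_scale.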


@torch.no_grad()
def test_kernel(
expert_num: int,
m: int,
n: int,
k: int,
topk: int,
dtype: torch.dtype,
test_count: int,
use_fp8_w8a8: bool,
is_up: bool,
**config,
):
set_seed()
input_tuples = []

a = torch.randn((m, k), device="cuda", dtype=dtype) / 10
w1 = torch.randn((expert_num, 2 * n, k), device="cuda", dtype=dtype) / 10
w2 = torch.randn((expert_num, k, n), device="cuda", dtype=dtype) / 10
rnd_logics = torch.randn(m, expert_num, device="cuda")
topk_values, topk_ids = torch.topk(rnd_logics, topk, dim=1)
topk_weights = torch.randn((m, topk), device="cuda", dtype=dtype) / 10

expert_to_tokens = torch.empty((expert_num, topk * m), dtype=torch.int32, device="cuda")
expert_to_weights = torch.empty((expert_num, topk * m), dtype=torch.float32, device="cuda")
moe_align(topk_ids=topk_ids, out=expert_to_tokens)
expert_to_token_num = torch.empty((expert_num,), dtype=torch.int32, device="cuda")
moe_align1(expert_to_tokens, topk_weights, expert_to_weights, expert_to_token_num, topk=topk)
if use_fp8_w8a8:
w1, w1_scale = quantize_moe(w1)
w2, w2_scale = quantize_moe(w2)
else:
w1_scale = torch.empty((0,))
w2_scale = torch.empty((0,))

out1 = torch.zeros((m * topk, 2 * n), dtype=torch.bfloat16, device="cuda")
down_in = torch.zeros((m * topk, n), dtype=torch.bfloat16, device="cuda")
out2 = torch.zeros((m * topk, k), dtype=torch.bfloat16, device="cuda")

for _ in range(test_count):
input_tuples.append(
(
a.clone(),
w1.clone(),
w2.clone(),
w1_scale.clone(),
w2_scale.clone(),
topk_ids.clone(),
topk_weights.clone(),
out1.clone(),
out2.clone(),
down_in.clone(),
)
)

if is_up:
grouped_matmul(
a,
None,
expert_to_token_num,
expert_to_tokens,
expert_to_weights=expert_to_weights,
expert_weights=w1,
expert_to_weights_scale=w1_scale,
topk_num=topk,
out=out1,
expert_token_limit=2 ** 31 - 1,
mul_routed_weight=False,
use_fp8_w8a8=use_fp8_w8a8,
**config,
)
else:
grouped_matmul(
down_in,
None,
expert_to_token_num,
expert_to_tokens,
expert_to_weights=expert_to_weights,
expert_weights=w2,
expert_to_weights_scale=w2_scale,
topk_num=1,
out=out2,
expert_token_limit=2 ** 31 - 1,
mul_routed_weight=True,
use_fp8_w8a8=use_fp8_w8a8,
**config,
)

graph = torch.cuda.CUDAGraph()

with torch.cuda.graph(graph):
for index in range(test_count):
a, w1, w2, w1_scale, w2_scale, topk_ids, topk_weights, out1, out2, down_in = input_tuples[index]
if is_up:
grouped_matmul(
a,
None,
expert_to_token_num,
expert_to_tokens,
expert_to_weights=expert_to_weights,
expert_weights=w1,
expert_to_weights_scale=w1_scale,
topk_num=topk,
out=out1,
expert_token_limit=2 ** 31 - 1,
mul_routed_weight=False,
use_fp8_w8a8=use_fp8_w8a8,
**config,
)
else:
grouped_matmul(
down_in,
None,
expert_to_token_num,
expert_to_tokens,
expert_to_weights=expert_to_weights,
expert_weights=w2,
expert_to_weights_scale=w2_scale,
topk_num=1,
out=out2,
expert_token_limit=2 ** 31 - 1,
mul_routed_weight=True,
use_fp8_w8a8=use_fp8_w8a8,
**config,
)

graph.replay()

torch.cuda.synchronize()
start = time.time()
graph.replay()
torch.cuda.synchronize()

cost_time = (time.time() - start) * 1000

logger.info(str(config))
logger.info(f"bf16 {m} cost time: {cost_time} ms")
return cost_time
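# A hypothetical one-off measurement of a single config (handy when sanity-checking
# the kernel outside the full sweep); the shapes mirror the defaults in __main__ below
# and the config keys match those produced by get_test_configs:
#
#   test_kernel(expert_num=64, m=200, n=1408 // 2, k=2048, topk=6,
#               dtype=torch.bfloat16, test_count=8, use_fp8_w8a8=False, is_up=True,
#               BLOCK_SIZE_M=16, BLOCK_SIZE_N=32, BLOCK_SIZE_K=32,
#               GROUP_SIZE_M=1, num_warps=4, num_stages=2)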


def worker(
expert_num: int,
m: int,
n: int,
k: int,
topk: int,
dtype: torch.dtype,
test_count: int,
use_fp8_w8a8: bool,
is_up: bool,
test_configs,
queue,
):
try:
for index in range(len(test_configs)):
cost_time = test_kernel(
expert_num=expert_num,
m=m,
n=n,
k=k,
topk=topk,
dtype=dtype,
test_count=test_count,
use_fp8_w8a8=use_fp8_w8a8,
is_up=is_up,
**test_configs[index],
)
queue.put(cost_time) # Put result in queue

    except Exception as ex:
        logger.exception(str(ex))
import sys

sys.exit(-1)
pass


def get_test_configs():
# all_configs = []
for num_stages in [
1,
2,
3,
]:
for GROUP_SIZE_M in [
1,
2,
4,
]:
for num_warps in [
2,
4,
8,
]:
for BLOCK_SIZE_M in [
16,
32,
64,
]:
for BLOCK_SIZE_N in [16, 32, 64, 128]:
for BLOCK_SIZE_K in [16, 32, 64, 128]:
t_config = {
"BLOCK_SIZE_M": BLOCK_SIZE_M,
"BLOCK_SIZE_N": BLOCK_SIZE_N,
"BLOCK_SIZE_K": BLOCK_SIZE_K,
"GROUP_SIZE_M": GROUP_SIZE_M,
"num_warps": num_warps,
"num_stages": num_stages,
}
yield t_config
# all_configs.append(t_config)

# import random
# random.shuffle(all_configs)
# for t_config in all_configs:
# yield t_config
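
# The loops above enumerate 3 (num_stages) x 3 (GROUP_SIZE_M) x 3 (num_warps)
# x 3 (BLOCK_SIZE_M) x 4 (BLOCK_SIZE_N) x 4 (BLOCK_SIZE_K) = 1296 configs.
# tuning_configs below consumes them in batches of 256, timing each batch in a
# child process; if the child dies on a bad config, the parent's result queue runs
# dry early, so it skips a chunk of 16 configs, carries the rest over to the next
# batch, and the trailing while loop retries whatever remains once the generator
# is exhausted.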


def tuning_configs(
expert_num: int,
m: int,
n: int,
k: int,
topk: int,
dtype: torch.dtype,
test_count: int,
use_fp8_w8a8: bool,
is_up: bool,
):

best_config, best_cost_time = None, 10000000
queue = mp.Queue()
test_configs = []
for t_config in get_test_configs():
test_configs.append(t_config)
if len(test_configs) < 256:
continue

p = mp.Process(
target=worker,
args=(
expert_num,
m,
n,
k,
topk,
dtype,
test_count,
use_fp8_w8a8,
is_up,
test_configs,
queue,
),
)
p.start()
p.join()
while len(test_configs) != 0:
try:
cost_time = queue.get_nowait()
logger.info(f"get {test_configs[0]} cost_time: {cost_time}")
if cost_time < best_cost_time:
best_config = test_configs[0]
best_cost_time = cost_time
logger.info(f"cur best : {best_config} {best_cost_time}")
del test_configs[0:1]
            except Exception:  # queue is empty: the worker likely crashed on a bad config
del test_configs[0:16]
logger.info(f"cur best : {best_config} {best_cost_time}")
break

while len(test_configs) != 0:
p = mp.Process(
target=worker,
args=(
expert_num,
m,
n,
k,
topk,
dtype,
test_count,
use_fp8_w8a8,
is_up,
test_configs,
queue,
),
)
p.start()
p.join()

try:
cost_time = queue.get_nowait()
logger.info(f"get {test_configs[0]} cost_time: {cost_time}")
if cost_time < best_cost_time:
best_config = test_configs[0]
best_cost_time = cost_time
logger.info(f"cur best : {best_config} {best_cost_time}")
del test_configs[0:1]
        except Exception:  # queue is empty: the worker likely crashed on a bad config
del test_configs[0:16]
logger.info(f"cur best : {best_config} {best_cost_time}")
break

logger.info(f"{best_config} best cost: {best_cost_time}")


if __name__ == "__main__":
tuning_configs(
expert_num=64,
m=200,
n=1408 // 2,
k=2048,
topk=6,
dtype=torch.bfloat16,
test_count=8,
use_fp8_w8a8=False,
is_up=True,
)
pass
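# A second run with is_up=False would tune the down-projection grouped matmul for
# the same shapes, and use_fp8_w8a8=True would cover the fp8 w8a8 path; such a run
# is not included here but would simply mirror the call above, e.g.:
#
#   tuning_configs(expert_num=64, m=200, n=1408 // 2, k=2048, topk=6,
#                  dtype=torch.bfloat16, test_count=8, use_fp8_w8a8=True, is_up=False)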