[Kimi K2] dsv3_router_gemm supports NUM_EXPERTS == 384 (#8013)

This commit is contained in:
Peter Pan
2025-08-01 22:01:24 +08:00
committed by GitHub
parent 46e9d1c7c1
commit 6bdd27861b
5 changed files with 188 additions and 30 deletions

View File

@@ -13,9 +13,14 @@ from sgl_kernel import dsv3_router_gemm
x_vals=[i + 1 for i in range(16)],
x_log=False,
line_arg="impl",
line_vals=["torch", "sgl-kernel"],
line_names=["torch", "dsv3_router_gemm"],
styles=[("blue", "-"), ("orange", "-")],
line_vals=["torch-256", "sgl-kernel-256", "torch-384", "sgl-kernel-384"],
line_names=[
"torch-256",
"dsv3_router_gemm-256",
"torch-384",
"dsv3_router_gemm-384",
],
styles=[("blue", "-"), ("orange", "-"), ("green", "-"), ("red", "-")],
ylabel="TFLOPs",
plot_name="input-bf16-output-bf16 dsv3 router gemm throughput",
args={},
@@ -23,19 +28,26 @@ from sgl_kernel import dsv3_router_gemm
)
def benchmark_bf16_output(num_tokens, impl):
# M: num_tokens, K: hidden_dim, N: num_experts
M, K, N = num_tokens, 7168, 256
M, K = num_tokens, 7168
if impl == "torch-256" or impl == "sgl-kernel-256":
N = 256
elif impl == "torch-384" or impl == "sgl-kernel-384":
N = 384
else:
raise ValueError(f"Unknown impl: {impl}")
mat_a = torch.randn((M, K), dtype=torch.bfloat16, device="cuda").contiguous()
mat_b = torch.randn((N, K), dtype=torch.bfloat16, device="cuda").contiguous()
quantiles = [0.5, 0.2, 0.8]
if impl == "torch":
if impl == "torch-256" or impl == "torch-384":
def runner():
F.linear(mat_a, mat_b)
elif impl == "sgl-kernel":
elif impl == "sgl-kernel-256" or impl == "sgl-kernel-384":
def runner():
dsv3_router_gemm(mat_a, mat_b, out_dtype=torch.bfloat16)
@@ -55,9 +67,14 @@ def benchmark_bf16_output(num_tokens, impl):
x_vals=[i + 1 for i in range(16)],
x_log=False,
line_arg="impl",
line_vals=["torch", "sgl-kernel"],
line_names=["torch", "dsv3_router_gemm"],
styles=[("blue", "-"), ("orange", "-")],
line_vals=["torch-256", "sgl-kernel-256", "torch-384", "sgl-kernel-384"],
line_names=[
"torch-256",
"dsv3_router_gemm-256",
"torch-384",
"dsv3_router_gemm-384",
],
styles=[("blue", "-"), ("orange", "-"), ("green", "-"), ("red", "-")],
ylabel="TFLOPs",
plot_name="input-bf16-output-fp32 dsv3 router gemm throughput",
args={},
@@ -65,19 +82,26 @@ def benchmark_bf16_output(num_tokens, impl):
)
def benchmark_float_output(num_tokens, impl):
# M: num_tokens, K: hidden_dim, N: num_experts
M, K, N = num_tokens, 7168, 256
M, K = num_tokens, 7168
if impl == "torch-256" or impl == "sgl-kernel-256":
N = 256
elif impl == "torch-384" or impl == "sgl-kernel-384":
N = 384
else:
raise ValueError(f"Unknown impl: {impl}")
mat_a = torch.randn((M, K), dtype=torch.bfloat16, device="cuda").contiguous()
mat_b = torch.randn((N, K), dtype=torch.bfloat16, device="cuda").contiguous()
quantiles = [0.5, 0.2, 0.8]
if impl == "torch":
if impl == "torch-256" or impl == "torch-384":
def runner():
F.linear(mat_a, mat_b).to(torch.float32)
elif impl == "sgl-kernel":
elif impl == "sgl-kernel-256" or impl == "sgl-kernel-384":
def runner():
dsv3_router_gemm(mat_a, mat_b, out_dtype=torch.float32)