[sgl-kernel] Support moe_sum_reduce cuda kernel (#10321)

Co-authored-by: luoyuan.luo <luoyuan.luo@antgroup.com>
Co-authored-by: Xiaoyu Zhang <35585791+BBuf@users.noreply.github.com>
This commit is contained in:
Yuan Luo
2025-09-19 14:12:09 +08:00
committed by GitHub
parent ac2a723bb3
commit 616a3e20df
7 changed files with 346 additions and 10 deletions

View File

@@ -1,6 +1,7 @@
import torch
import triton
import triton.language as tl
from sgl_kernel import moe_sum_reduce as moe_sum_reduce_cuda
from triton.testing import do_bench
@@ -57,7 +58,7 @@ def _moe_sum_reduce_kernel(
# _moe_sum_reduce_kernel kernel modified from https://github.com/ModelTC/lightllm/blob/main/lightllm/common/fused_moe/moe_sum_reduce.py
def moe_sum_reduce(
def moe_sum_reduce_triton(
input: torch.Tensor, output: torch.Tensor, routed_scaling_factor: float
):
assert input.is_contiguous()
@@ -117,9 +118,9 @@ def get_benchmark():
x_names=["num_tokens"],
x_vals=num_tokens_range,
line_arg="version",
line_vals=["baseline", "compiled", "triton"],
line_names=["Original", "TorchCompile", "TritonKernel"],
styles=[("blue", "-"), ("green", "-"), ("red", "-")],
line_vals=["baseline", "compiled", "triton", "cuda"],
line_names=["Original", "TorchCompile", "TritonKernel", "CudaKernel"],
styles=[("blue", "-"), ("green", "-"), ("red", "-"), ("yellow", "-")],
ylabel="us",
plot_name="sum_scaled_performance",
args={},
@@ -140,8 +141,10 @@ def get_benchmark():
compute_sum_scaled_baseline(x, out, scaling_factor)
elif version == "compiled":
compute_sum_scaled_compiled(x, out, scaling_factor)
elif version == "triton":
moe_sum_reduce_triton(x, out, scaling_factor)
else:
moe_sum_reduce(x, out, scaling_factor)
moe_sum_reduce_cuda(x, out, scaling_factor)
# Benchmark
quantiles = [0.5, 0.2, 0.8]
@@ -155,9 +158,15 @@ def get_benchmark():
lambda: compute_sum_scaled_compiled(x, out, scaling_factor),
quantiles=quantiles,
)
elif version == "triton":
ms, min_ms, max_ms = do_bench(
lambda: moe_sum_reduce_triton(x, out, scaling_factor),
quantiles=quantiles,
)
else:
ms, min_ms, max_ms = do_bench(
lambda: moe_sum_reduce(x, out, scaling_factor), quantiles=quantiles
lambda: moe_sum_reduce_cuda(x, out, scaling_factor),
quantiles=quantiles,
)
return 1000 * ms, 1000 * max_ms, 1000 * min_ms
@@ -176,11 +185,16 @@ def verify_correctness(num_tokens=1024):
compute_sum_scaled_compiled(x, out_compiled, scaling_factor)
out_triton = torch.empty_like(out_baseline)
moe_sum_reduce(x, out_triton, scaling_factor)
moe_sum_reduce_triton(x, out_triton, scaling_factor)
if torch.allclose(
out_baseline, out_compiled, atol=1e-2, rtol=1e-2
) and torch.allclose(out_baseline, out_triton, atol=1e-2, rtol=1e-2):
out_cuda = torch.empty_like(out_baseline)
moe_sum_reduce_cuda(x, out_cuda, scaling_factor)
if (
torch.allclose(out_baseline, out_compiled, atol=1e-2, rtol=1e-2)
and torch.allclose(out_baseline, out_triton, atol=1e-2, rtol=1e-2)
and torch.allclose(out_baseline, out_cuda, atol=1e-2, rtol=1e-2)
):
print("✅ All implementations match")
else:
print("❌ Implementations differ")
@@ -188,6 +202,7 @@ def verify_correctness(num_tokens=1024):
f"Baseline vs Compiled: {(out_baseline - out_compiled).abs().max().item()}"
)
print(f"Baseline vs Triton: {(out_baseline - out_triton).abs().max().item()}")
print(f"Baseline vs Cuda: {(out_baseline - out_cuda).abs().max().item()}")
if __name__ == "__main__":