[sgl-kernel] Support moe_sum_reduce cuda kernel (#10321)
Co-authored-by: luoyuan.luo <luoyuan.luo@antgroup.com> Co-authored-by: Xiaoyu Zhang <35585791+BBuf@users.noreply.github.com>
This commit is contained in:
@@ -1,6 +1,7 @@
|
||||
import torch
|
||||
import triton
|
||||
import triton.language as tl
|
||||
from sgl_kernel import moe_sum_reduce as moe_sum_reduce_cuda
|
||||
from triton.testing import do_bench
|
||||
|
||||
|
||||
@@ -57,7 +58,7 @@ def _moe_sum_reduce_kernel(
|
||||
|
||||
|
||||
# _moe_sum_reduce_kernel kernel modified from https://github.com/ModelTC/lightllm/blob/main/lightllm/common/fused_moe/moe_sum_reduce.py
|
||||
def moe_sum_reduce(
|
||||
def moe_sum_reduce_triton(
|
||||
input: torch.Tensor, output: torch.Tensor, routed_scaling_factor: float
|
||||
):
|
||||
assert input.is_contiguous()
|
||||
@@ -117,9 +118,9 @@ def get_benchmark():
|
||||
x_names=["num_tokens"],
|
||||
x_vals=num_tokens_range,
|
||||
line_arg="version",
|
||||
line_vals=["baseline", "compiled", "triton"],
|
||||
line_names=["Original", "TorchCompile", "TritonKernel"],
|
||||
styles=[("blue", "-"), ("green", "-"), ("red", "-")],
|
||||
line_vals=["baseline", "compiled", "triton", "cuda"],
|
||||
line_names=["Original", "TorchCompile", "TritonKernel", "CudaKernel"],
|
||||
styles=[("blue", "-"), ("green", "-"), ("red", "-"), ("yellow", "-")],
|
||||
ylabel="us",
|
||||
plot_name="sum_scaled_performance",
|
||||
args={},
|
||||
@@ -140,8 +141,10 @@ def get_benchmark():
|
||||
compute_sum_scaled_baseline(x, out, scaling_factor)
|
||||
elif version == "compiled":
|
||||
compute_sum_scaled_compiled(x, out, scaling_factor)
|
||||
elif version == "triton":
|
||||
moe_sum_reduce_triton(x, out, scaling_factor)
|
||||
else:
|
||||
moe_sum_reduce(x, out, scaling_factor)
|
||||
moe_sum_reduce_cuda(x, out, scaling_factor)
|
||||
|
||||
# Benchmark
|
||||
quantiles = [0.5, 0.2, 0.8]
|
||||
@@ -155,9 +158,15 @@ def get_benchmark():
|
||||
lambda: compute_sum_scaled_compiled(x, out, scaling_factor),
|
||||
quantiles=quantiles,
|
||||
)
|
||||
elif version == "triton":
|
||||
ms, min_ms, max_ms = do_bench(
|
||||
lambda: moe_sum_reduce_triton(x, out, scaling_factor),
|
||||
quantiles=quantiles,
|
||||
)
|
||||
else:
|
||||
ms, min_ms, max_ms = do_bench(
|
||||
lambda: moe_sum_reduce(x, out, scaling_factor), quantiles=quantiles
|
||||
lambda: moe_sum_reduce_cuda(x, out, scaling_factor),
|
||||
quantiles=quantiles,
|
||||
)
|
||||
|
||||
return 1000 * ms, 1000 * max_ms, 1000 * min_ms
|
||||
@@ -176,11 +185,16 @@ def verify_correctness(num_tokens=1024):
|
||||
compute_sum_scaled_compiled(x, out_compiled, scaling_factor)
|
||||
|
||||
out_triton = torch.empty_like(out_baseline)
|
||||
moe_sum_reduce(x, out_triton, scaling_factor)
|
||||
moe_sum_reduce_triton(x, out_triton, scaling_factor)
|
||||
|
||||
if torch.allclose(
|
||||
out_baseline, out_compiled, atol=1e-2, rtol=1e-2
|
||||
) and torch.allclose(out_baseline, out_triton, atol=1e-2, rtol=1e-2):
|
||||
out_cuda = torch.empty_like(out_baseline)
|
||||
moe_sum_reduce_cuda(x, out_cuda, scaling_factor)
|
||||
|
||||
if (
|
||||
torch.allclose(out_baseline, out_compiled, atol=1e-2, rtol=1e-2)
|
||||
and torch.allclose(out_baseline, out_triton, atol=1e-2, rtol=1e-2)
|
||||
and torch.allclose(out_baseline, out_cuda, atol=1e-2, rtol=1e-2)
|
||||
):
|
||||
print("✅ All implementations match")
|
||||
else:
|
||||
print("❌ Implementations differ")
|
||||
@@ -188,6 +202,7 @@ def verify_correctness(num_tokens=1024):
|
||||
f"Baseline vs Compiled: {(out_baseline - out_compiled).abs().max().item()}"
|
||||
)
|
||||
print(f"Baseline vs Triton: {(out_baseline - out_triton).abs().max().item()}")
|
||||
print(f"Baseline vs Cuda: {(out_baseline - out_cuda).abs().max().item()}")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
|
||||
Reference in New Issue
Block a user