Add deepseek style fused moe group gate selection kernel (#4530)
This commit is contained in:
74
sgl-kernel/benchmark/bench_moe_fused_gate.py
Normal file
74
sgl-kernel/benchmark/bench_moe_fused_gate.py
Normal file
@@ -0,0 +1,74 @@
|
||||
import itertools
|
||||
import math
|
||||
|
||||
import torch
|
||||
import triton
|
||||
import triton.language as tl
|
||||
from sgl_kernel import moe_fused_gate
|
||||
|
||||
from sglang.srt.layers.moe.topk import biased_grouped_topk
|
||||
|
||||
|
||||
def biased_grouped_topk_org(scores, bias, num_expert_group, topk_group, topk):
|
||||
return biased_grouped_topk(
|
||||
scores,
|
||||
scores,
|
||||
bias,
|
||||
topk=topk,
|
||||
renormalize=True,
|
||||
num_expert_group=num_expert_group,
|
||||
topk_group=topk_group,
|
||||
)
|
||||
|
||||
|
||||
def biased_grouped_topk_org_kernel(scores, bias, num_expert_group, topk_group, topk):
|
||||
return moe_fused_gate(scores, bias, num_expert_group, topk_group, topk)
|
||||
|
||||
|
||||
seq_length_range = [5000, 10000, 15000, 20000, 25000, 30000, 35000, 40000]
|
||||
configs = [(sq,) for sq in seq_length_range]
|
||||
|
||||
|
||||
@triton.testing.perf_report(
|
||||
triton.testing.Benchmark(
|
||||
x_names=["seq_length"],
|
||||
x_vals=[list(_) for _ in configs],
|
||||
line_arg="provider",
|
||||
line_vals=["original", "kernel"],
|
||||
line_names=["Original", "SGL Kernel"],
|
||||
styles=[("blue", "-"), ("red", "-")],
|
||||
ylabel="us",
|
||||
plot_name="moe-fused-gate-performance",
|
||||
args={},
|
||||
)
|
||||
)
|
||||
def benchmark(seq_length, provider):
|
||||
dtype = torch.bfloat16
|
||||
device = torch.device("cuda")
|
||||
num_experts, num_expert_group, topk_group, topk = 256, 8, 4, 8
|
||||
|
||||
scores = torch.randn((seq_length, num_experts), device=device, dtype=dtype)
|
||||
bias = torch.rand(num_experts, device=device, dtype=dtype)
|
||||
|
||||
quantiles = [0.5, 0.2, 0.8]
|
||||
|
||||
if provider == "original":
|
||||
ms, min_ms, max_ms = triton.testing.do_bench(
|
||||
lambda: biased_grouped_topk_org(
|
||||
scores.clone(), bias.clone(), num_expert_group, topk_group, topk
|
||||
),
|
||||
quantiles=quantiles,
|
||||
)
|
||||
elif provider == "kernel":
|
||||
ms, min_ms, max_ms = triton.testing.do_bench(
|
||||
lambda: biased_grouped_topk_org_kernel(
|
||||
scores.clone(), bias.clone(), num_expert_group, topk_group, topk
|
||||
),
|
||||
quantiles=quantiles,
|
||||
)
|
||||
|
||||
return 1000 * ms, 1000 * max_ms, 1000 * min_ms
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
benchmark.run(print_data=True)
|
||||
Reference in New Issue
Block a user