diff --git a/benchmark/kernels/fused_moe_triton/tuning_fused_moe_triton.py b/benchmark/kernels/fused_moe_triton/tuning_fused_moe_triton.py index 5af1b32be..69b0563e9 100644 --- a/benchmark/kernels/fused_moe_triton/tuning_fused_moe_triton.py +++ b/benchmark/kernels/fused_moe_triton/tuning_fused_moe_triton.py @@ -18,6 +18,7 @@ from sglang.srt.layers.moe.fused_moe_triton.fused_moe import ( get_default_config, get_moe_configs, ) +from sglang.srt.layers.moe.topk import select_experts from sglang.srt.utils import is_hip _is_hip = is_hip() @@ -115,10 +116,15 @@ def benchmark_config( w1 = w1.to(torch.float8_e4m3fnuz if _is_hip else torch.float8_e4m3fn) w2 = w2.to(torch.float8_e4m3fnuz if _is_hip else torch.float8_e4m3fn) - input_gating = torch.empty(num_tokens, num_experts, dtype=torch.float32) + input_gating = torch.randn(num_tokens, num_experts, dtype=torch.float32) + topk_output = select_experts(x, input_gating, topk, renormalize=True) def prepare(i: int): - input_gating.copy_(gating_output[i]) + input_gating = gating_output[i] + new_topk_output = select_experts(x, input_gating, topk, renormalize=True) + topk_output.topk_weights.copy_(new_topk_output.topk_weights) + topk_output.topk_ids.copy_(new_topk_output.topk_ids) + topk_output.router_logits.copy_(new_topk_output.router_logits) def run(): from sglang.srt.layers.moe.fused_moe_triton import override_config @@ -128,9 +134,7 @@ def benchmark_config( x, w1, w2, - input_gating, - topk, - renormalize=True, + topk_output, inplace=True, use_fp8_w8a8=use_fp8_w8a8, use_int8_w8a8=use_int8_w8a8,