From abda2542d5cd465bbbfa5971139090df2dc02646 Mon Sep 17 00:00:00 2001 From: Cheng Wan <54331508+ch-wan@users.noreply.github.com> Date: Sat, 19 Jul 2025 17:33:50 -0700 Subject: [PATCH] Fix tuning_fused_moe_triton.py (#8175) --- .../fused_moe_triton/tuning_fused_moe_triton.py | 14 +++++++++----- 1 file changed, 9 insertions(+), 5 deletions(-) diff --git a/benchmark/kernels/fused_moe_triton/tuning_fused_moe_triton.py b/benchmark/kernels/fused_moe_triton/tuning_fused_moe_triton.py index 5af1b32be..69b0563e9 100644 --- a/benchmark/kernels/fused_moe_triton/tuning_fused_moe_triton.py +++ b/benchmark/kernels/fused_moe_triton/tuning_fused_moe_triton.py @@ -18,6 +18,7 @@ from sglang.srt.layers.moe.fused_moe_triton.fused_moe import ( get_default_config, get_moe_configs, ) +from sglang.srt.layers.moe.topk import select_experts from sglang.srt.utils import is_hip _is_hip = is_hip() @@ -115,10 +116,15 @@ def benchmark_config( w1 = w1.to(torch.float8_e4m3fnuz if _is_hip else torch.float8_e4m3fn) w2 = w2.to(torch.float8_e4m3fnuz if _is_hip else torch.float8_e4m3fn) - input_gating = torch.empty(num_tokens, num_experts, dtype=torch.float32) + input_gating = torch.randn(num_tokens, num_experts, dtype=torch.float32) + topk_output = select_experts(x, input_gating, topk, renormalize=True) def prepare(i: int): - input_gating.copy_(gating_output[i]) + input_gating = gating_output[i] + new_topk_output = select_experts(x, input_gating, topk, renormalize=True) + topk_output.topk_weights.copy_(new_topk_output.topk_weights) + topk_output.topk_ids.copy_(new_topk_output.topk_ids) + topk_output.router_logits.copy_(new_topk_output.router_logits) def run(): from sglang.srt.layers.moe.fused_moe_triton import override_config @@ -128,9 +134,7 @@ def benchmark_config( x, w1, w2, - input_gating, - topk, - renormalize=True, + topk_output, inplace=True, use_fp8_w8a8=use_fp8_w8a8, use_int8_w8a8=use_int8_w8a8,