diff --git a/benchmark/kernels/fused_moe_triton/benchmark_sglang_fused_moe_triton.py b/benchmark/kernels/fused_moe_triton/benchmark_sglang_fused_moe_triton.py
index dd8504fd9..7621628c1 100644
--- a/benchmark/kernels/fused_moe_triton/benchmark_sglang_fused_moe_triton.py
+++ b/benchmark/kernels/fused_moe_triton/benchmark_sglang_fused_moe_triton.py
@@ -17,6 +17,8 @@ from sglang.srt.layers.moe.fused_moe_triton.fused_moe import (
 from sglang.srt.layers.moe.fused_moe_triton.triton_kernels_moe import (
     triton_kernel_moe_forward,
 )
+from sglang.srt.layers.moe.moe_runner import MoeRunnerConfig
+from sglang.srt.layers.moe.topk import TopK, TopKConfig, select_experts
 
 
 def get_model_config(model_name: str, tp_size: int):
@@ -80,13 +82,26 @@ def fused_moe_triton_api(
     input_gating,
     topk,
 ):
+    topk_op = TopK(
+        top_k=topk,
+        renormalize=False,
+        use_grouped_topk=False,
+    )
+    topk_op.use_triton_kernels = True
+    triton_topk_output = topk_op.forward_cuda(
+        hidden_states=x,
+        router_logits=input_gating,
+    )
+
+    moe_runner_config = MoeRunnerConfig(
+        inplace=False,
+    )
     return triton_kernel_moe_forward(
         x,
         w1,
         w2,
-        input_gating,
-        topk,
-        renormalize=False,
+        triton_topk_output,
+        moe_runner_config,
     )
 
 
@@ -103,14 +118,16 @@ def fused_moe_sglang_api(
     a2_scale=None,
     block_shape=None,
 ):
+    topk_output = select_experts(
+        hidden_states=x,
+        router_logits=input_gating,
+        topk_config=TopKConfig(top_k=topk, renormalize=False),
+    )
     return fused_moe_sglang(
         x,
         w1,
         w2,
-        input_gating,
-        topk,
-        renormalize=False,
-        inplace=True,
+        topk_output,
         use_fp8_w8a8=use_fp8_w8a8,
         w1_scale=w1_scale,
         w2_scale=w2_scale,
diff --git a/test/srt/test_triton_fused_moe.py b/test/srt/test_triton_fused_moe.py
index 8d014f6c7..88d33b5f7 100644
--- a/test/srt/test_triton_fused_moe.py
+++ b/test/srt/test_triton_fused_moe.py
@@ -8,6 +8,8 @@ from sglang.srt.layers.activation import SiluAndMul
 from sglang.srt.layers.moe.fused_moe_triton.triton_kernels_moe import (
     triton_kernel_moe_forward,
 )
+from sglang.srt.layers.moe.moe_runner import MoeRunnerConfig
+from sglang.srt.layers.moe.topk import TopK
 from sglang.test.test_utils import CustomTestCase
 
 
@@ -92,8 +94,22 @@ class TestFusedMOE(CustomTestCase):
         w2_tri = w2_tri.transpose(-2, -1).contiguous()
 
         score = self.create_random_cuda_tensor((m, e), dtype)
+        topk_op = TopK(
+            top_k=topk,
+            renormalize=False,
+            use_grouped_topk=False,
+        )
+        topk_op.use_triton_kernels = True
+        triton_topk_output = topk_op.forward_cuda(
+            hidden_states=a,
+            router_logits=score,
+        )
+
+        moe_runner_config = MoeRunnerConfig(
+            inplace=False,
+        )
         triton_output = triton_kernel_moe_forward(
-            a, w1_tri, w2_tri, score, topk, renormalize=False
+            a, w1_tri, w2_tri, triton_topk_output, moe_runner_config
         )
         torch_output = self.torch_naive_moe(a, w1, w2, score, topk)
         torch.testing.assert_close(triton_output, torch_output, rtol=rtol, atol=atol)