From f9f0138f80a32ecba8a4da619cb51dce2bb3381c Mon Sep 17 00:00:00 2001 From: Liangsheng Yin Date: Sat, 2 Aug 2025 20:14:30 +0800 Subject: [PATCH] Revert "[1/2] sgl-kernel: Fuse routed scaling factor into select_experts" (#8706) --- sgl-kernel/csrc/common_extension.cc | 2 +- sgl-kernel/csrc/moe/moe_fused_gate.cu | 27 +++++++------------------ sgl-kernel/include/sgl_kernel_ops.h | 3 +-- sgl-kernel/python/sgl_kernel/moe.py | 11 ++-------- sgl-kernel/tests/test_moe_fused_gate.py | 7 +------ 5 files changed, 12 insertions(+), 38 deletions(-) diff --git a/sgl-kernel/csrc/common_extension.cc b/sgl-kernel/csrc/common_extension.cc index 989ae14eb..295939900 100644 --- a/sgl-kernel/csrc/common_extension.cc +++ b/sgl-kernel/csrc/common_extension.cc @@ -174,7 +174,7 @@ TORCH_LIBRARY_FRAGMENT(sgl_kernel, m) { m.def( "moe_fused_gate(Tensor input, Tensor bias, int num_expert_group, int topk_group, int topk, int " - "num_fused_shared_experts, float routed_scaling_factor, bool apply_routed_scaling_factor_on_output) -> " + "num_fused_shared_experts, float routed_scaling_factor) -> " "(Tensor[])"); m.impl("moe_fused_gate", torch::kCUDA, &moe_fused_gate); m.def( diff --git a/sgl-kernel/csrc/moe/moe_fused_gate.cu b/sgl-kernel/csrc/moe/moe_fused_gate.cu index 782a884fb..24bf2d36b 100644 --- a/sgl-kernel/csrc/moe/moe_fused_gate.cu +++ b/sgl-kernel/csrc/moe/moe_fused_gate.cu @@ -59,7 +59,6 @@ __device__ void moe_fused_gate_impl( int64_t topk, int64_t num_fused_shared_experts, double routed_scaling_factor, - bool apply_routed_scaling_factor_on_output, Params params) { int tidx = threadIdx.x; int64_t thread_row = @@ -249,9 +248,6 @@ __device__ void moe_fused_gate_impl( for (int ii = 0; ii < topk; ++ii) { int64_t const idx = topk * thread_row + ii; output_ptr[idx] = output_ptr[idx] / output_sum; - if (apply_routed_scaling_factor_on_output) { - output_ptr[idx] *= routed_scaling_factor; - } } } } @@ -286,8 +282,7 @@ __global__ void moe_fused_gate_kernel( int64_t topk_group, int64_t topk, int64_t num_fused_shared_experts, - double routed_scaling_factor, - bool apply_routed_scaling_factor_on_output) { + double routed_scaling_factor) { KernelParams params; moe_fused_gate_impl( input, @@ -299,7 +294,6 @@ __global__ void moe_fused_gate_kernel( topk, num_fused_shared_experts, routed_scaling_factor, - apply_routed_scaling_factor_on_output, params); } @@ -320,8 +314,7 @@ __global__ void moe_fused_gate_kernel( topk_group, \ topk, \ num_fused_shared_experts, \ - routed_scaling_factor, \ - apply_routed_scaling_factor_on_output); \ + routed_scaling_factor); \ dispatched = true; \ } while (0) @@ -349,8 +342,7 @@ __global__ void moe_fused_gate_kernel_dynamic( int64_t topk_group, int64_t topk, int64_t num_fused_shared_experts, - double routed_scaling_factor, - bool apply_routed_scaling_factor_on_output) { + double routed_scaling_factor) { KernelParamsDynamic params; params.NUM_EXPERTS = num_experts; // e.g, for deepseek v3, this is 256 params.VPT = num_experts / num_expert_group; // e.g., for deepseek v3, this is 256 / 8 = 32 @@ -369,7 +361,6 @@ __global__ void moe_fused_gate_kernel_dynamic( topk, num_fused_shared_experts, routed_scaling_factor, - apply_routed_scaling_factor_on_output, params); } @@ -383,8 +374,7 @@ std::vector moe_fused_gate( int64_t topk_group, int64_t topk, int64_t num_fused_shared_experts, - double routed_scaling_factor, - bool apply_routed_scaling_factor_on_output) { + double routed_scaling_factor) { int64_t num_rows = input.size(0); int32_t num_experts = input.size(1); auto options = torch::TensorOptions().dtype(torch::kFloat32).device(torch::kCUDA); @@ -483,8 +473,7 @@ std::vector moe_fused_gate( topk_group, topk, num_fused_shared_experts, - routed_scaling_factor, - apply_routed_scaling_factor_on_output); + routed_scaling_factor); } else if (input.scalar_type() == at::kHalf) { moe_fused_gate_kernel_dynamic<<>>( input.data_ptr(), @@ -497,8 +486,7 @@ std::vector moe_fused_gate( topk_group, topk, num_fused_shared_experts, - routed_scaling_factor, - apply_routed_scaling_factor_on_output); + routed_scaling_factor); } else if (input.scalar_type() == at::kFloat) { moe_fused_gate_kernel_dynamic<<>>( input.data_ptr(), @@ -511,8 +499,7 @@ std::vector moe_fused_gate( topk_group, topk, num_fused_shared_experts, - routed_scaling_factor, - apply_routed_scaling_factor_on_output); + routed_scaling_factor); } else { TORCH_CHECK(false, "Unsupported data type for moe_fused_gate"); } diff --git a/sgl-kernel/include/sgl_kernel_ops.h b/sgl-kernel/include/sgl_kernel_ops.h index 88720dfea..fa6de7362 100644 --- a/sgl-kernel/include/sgl_kernel_ops.h +++ b/sgl-kernel/include/sgl_kernel_ops.h @@ -243,8 +243,7 @@ std::vector moe_fused_gate( int64_t topk_group, int64_t topk, int64_t num_fused_shared_experts, - double routed_scaling_factor, - bool apply_routed_scaling_factor_on_output); + double routed_scaling_factor); void fp8_blockwise_scaled_grouped_mm( torch::Tensor& output, diff --git a/sgl-kernel/python/sgl_kernel/moe.py b/sgl-kernel/python/sgl_kernel/moe.py index 9008e7a79..c16a2b6fe 100755 --- a/sgl-kernel/python/sgl_kernel/moe.py +++ b/sgl-kernel/python/sgl_kernel/moe.py @@ -44,7 +44,6 @@ def moe_fused_gate( topk, num_fused_shared_experts=0, routed_scaling_factor=0, - apply_routed_scaling_factor_on_output=False, ): # This fused kernel function is used to select topk expert in a hierarchical 2-layer fashion # it split group of expert into num_expert_group, and use top2 expert weight sum in each group @@ -52,13 +51,8 @@ def moe_fused_gate( # the #experts is decided by the input tensor shape and we currently only support power of 2 #experts # and #experts should be divisible by num_expert_group. #expert/num_expert_group <= 32 is limited for now. # for non-supported case, we suggest to use the biased_grouped_topk func in sglang.srt.layers.moe.topk - # num_fused_shared_experts: if > 0, the last several experts will be - # replaced with shared experts. the shared experts will be divided by the - # routed_scaling_factor - this is intended to cancel out later when routed+shared - # output is scaled so that shared experts are not scaled. - # routed_scaling_factor: if > 0, the experts will be scaled by this factor - # apply_routed_scaling_factor_on_output: if true, output will be - # scaled by the routed_scaling_factor + # num_fused_shared_experts: if > 0, the last several experts will be replaced with shared experts + # routed_scaling_factor: if > 0, the shared experts will be scaled by this factor return torch.ops.sgl_kernel.moe_fused_gate.default( input_tensor, bias, @@ -67,7 +61,6 @@ def moe_fused_gate( topk, num_fused_shared_experts, routed_scaling_factor, - apply_routed_scaling_factor_on_output, ) diff --git a/sgl-kernel/tests/test_moe_fused_gate.py b/sgl-kernel/tests/test_moe_fused_gate.py index 274f387a9..70c4ea209 100644 --- a/sgl-kernel/tests/test_moe_fused_gate.py +++ b/sgl-kernel/tests/test_moe_fused_gate.py @@ -19,10 +19,7 @@ from sglang.srt.layers.moe.topk import biased_grouped_topk ], ) @pytest.mark.parametrize("num_fused_shared_experts", [0, 1, 2]) -@pytest.mark.parametrize("apply_routed_scaling_factor_on_output", [True, False]) -def test_moe_fused_gate_combined( - seq_length, params, num_fused_shared_experts, apply_routed_scaling_factor_on_output -): +def test_moe_fused_gate_combined(seq_length, params, num_fused_shared_experts): num_experts, num_expert_group, topk_group, topk = params dtype = torch.float32 @@ -40,7 +37,6 @@ def test_moe_fused_gate_combined( topk=topk, num_fused_shared_experts=num_fused_shared_experts, routed_scaling_factor=2.5, - apply_routed_scaling_factor_on_output=apply_routed_scaling_factor_on_output, ) ref_output, ref_indices = biased_grouped_topk( scores, @@ -52,7 +48,6 @@ def test_moe_fused_gate_combined( topk_group=topk_group, num_fused_shared_experts=num_fused_shared_experts, routed_scaling_factor=2.5, - apply_routed_scaling_factor_on_output=apply_routed_scaling_factor_on_output, ) # When num_fused_shared_experts > 0, ignore the comparison of the last topk dimension