From f642524fd992ea5116c68830fef2b9afb2981b31 Mon Sep 17 00:00:00 2001 From: Trevor Morris Date: Fri, 1 Aug 2025 18:14:24 -0700 Subject: [PATCH] [1/2] sgl-kernel: Fuse routed scaling factor into select_experts (#8364) --- sgl-kernel/csrc/common_extension.cc | 2 +- sgl-kernel/csrc/moe/moe_fused_gate.cu | 27 ++++++++++++++++++------- sgl-kernel/include/sgl_kernel_ops.h | 3 ++- sgl-kernel/python/sgl_kernel/moe.py | 11 ++++++++-- sgl-kernel/tests/test_moe_fused_gate.py | 7 ++++++- 5 files changed, 38 insertions(+), 12 deletions(-) diff --git a/sgl-kernel/csrc/common_extension.cc b/sgl-kernel/csrc/common_extension.cc index 295939900..989ae14eb 100644 --- a/sgl-kernel/csrc/common_extension.cc +++ b/sgl-kernel/csrc/common_extension.cc @@ -174,7 +174,7 @@ TORCH_LIBRARY_FRAGMENT(sgl_kernel, m) { m.def( "moe_fused_gate(Tensor input, Tensor bias, int num_expert_group, int topk_group, int topk, int " - "num_fused_shared_experts, float routed_scaling_factor) -> " + "num_fused_shared_experts, float routed_scaling_factor, bool apply_routed_scaling_factor_on_output) -> " "(Tensor[])"); m.impl("moe_fused_gate", torch::kCUDA, &moe_fused_gate); m.def( diff --git a/sgl-kernel/csrc/moe/moe_fused_gate.cu b/sgl-kernel/csrc/moe/moe_fused_gate.cu index 24bf2d36b..782a884fb 100644 --- a/sgl-kernel/csrc/moe/moe_fused_gate.cu +++ b/sgl-kernel/csrc/moe/moe_fused_gate.cu @@ -59,6 +59,7 @@ __device__ void moe_fused_gate_impl( int64_t topk, int64_t num_fused_shared_experts, double routed_scaling_factor, + bool apply_routed_scaling_factor_on_output, Params params) { int tidx = threadIdx.x; int64_t thread_row = @@ -248,6 +249,9 @@ __device__ void moe_fused_gate_impl( for (int ii = 0; ii < topk; ++ii) { int64_t const idx = topk * thread_row + ii; output_ptr[idx] = output_ptr[idx] / output_sum; + if (apply_routed_scaling_factor_on_output) { + output_ptr[idx] *= routed_scaling_factor; + } } } } @@ -282,7 +286,8 @@ __global__ void moe_fused_gate_kernel( int64_t topk_group, int64_t topk, int64_t num_fused_shared_experts, - double routed_scaling_factor) { + double routed_scaling_factor, + bool apply_routed_scaling_factor_on_output) { KernelParams params; moe_fused_gate_impl( input, @@ -294,6 +299,7 @@ __global__ void moe_fused_gate_kernel( topk, num_fused_shared_experts, routed_scaling_factor, + apply_routed_scaling_factor_on_output, params); } @@ -314,7 +320,8 @@ __global__ void moe_fused_gate_kernel( topk_group, \ topk, \ num_fused_shared_experts, \ - routed_scaling_factor); \ + routed_scaling_factor, \ + apply_routed_scaling_factor_on_output); \ dispatched = true; \ } while (0) @@ -342,7 +349,8 @@ __global__ void moe_fused_gate_kernel_dynamic( int64_t topk_group, int64_t topk, int64_t num_fused_shared_experts, - double routed_scaling_factor) { + double routed_scaling_factor, + bool apply_routed_scaling_factor_on_output) { KernelParamsDynamic params; params.NUM_EXPERTS = num_experts; // e.g, for deepseek v3, this is 256 params.VPT = num_experts / num_expert_group; // e.g., for deepseek v3, this is 256 / 8 = 32 @@ -361,6 +369,7 @@ __global__ void moe_fused_gate_kernel_dynamic( topk, num_fused_shared_experts, routed_scaling_factor, + apply_routed_scaling_factor_on_output, params); } @@ -374,7 +383,8 @@ std::vector moe_fused_gate( int64_t topk_group, int64_t topk, int64_t num_fused_shared_experts, - double routed_scaling_factor) { + double routed_scaling_factor, + bool apply_routed_scaling_factor_on_output) { int64_t num_rows = input.size(0); int32_t num_experts = input.size(1); auto options = torch::TensorOptions().dtype(torch::kFloat32).device(torch::kCUDA); @@ -473,7 +483,8 @@ std::vector moe_fused_gate( topk_group, topk, num_fused_shared_experts, - routed_scaling_factor); + routed_scaling_factor, + apply_routed_scaling_factor_on_output); } else if (input.scalar_type() == at::kHalf) { moe_fused_gate_kernel_dynamic<<>>( input.data_ptr(), @@ -486,7 +497,8 @@ std::vector moe_fused_gate( topk_group, topk, num_fused_shared_experts, - routed_scaling_factor); + routed_scaling_factor, + apply_routed_scaling_factor_on_output); } else if (input.scalar_type() == at::kFloat) { moe_fused_gate_kernel_dynamic<<>>( input.data_ptr(), @@ -499,7 +511,8 @@ std::vector moe_fused_gate( topk_group, topk, num_fused_shared_experts, - routed_scaling_factor); + routed_scaling_factor, + apply_routed_scaling_factor_on_output); } else { TORCH_CHECK(false, "Unsupported data type for moe_fused_gate"); } diff --git a/sgl-kernel/include/sgl_kernel_ops.h b/sgl-kernel/include/sgl_kernel_ops.h index fa6de7362..88720dfea 100644 --- a/sgl-kernel/include/sgl_kernel_ops.h +++ b/sgl-kernel/include/sgl_kernel_ops.h @@ -243,7 +243,8 @@ std::vector moe_fused_gate( int64_t topk_group, int64_t topk, int64_t num_fused_shared_experts, - double routed_scaling_factor); + double routed_scaling_factor, + bool apply_routed_scaling_factor_on_output); void fp8_blockwise_scaled_grouped_mm( torch::Tensor& output, diff --git a/sgl-kernel/python/sgl_kernel/moe.py b/sgl-kernel/python/sgl_kernel/moe.py index c16a2b6fe..9008e7a79 100755 --- a/sgl-kernel/python/sgl_kernel/moe.py +++ b/sgl-kernel/python/sgl_kernel/moe.py @@ -44,6 +44,7 @@ def moe_fused_gate( topk, num_fused_shared_experts=0, routed_scaling_factor=0, + apply_routed_scaling_factor_on_output=False, ): # This fused kernel function is used to select topk expert in a hierarchical 2-layer fashion # it split group of expert into num_expert_group, and use top2 expert weight sum in each group @@ -51,8 +52,13 @@ def moe_fused_gate( # the #experts is decided by the input tensor shape and we currently only support power of 2 #experts # and #experts should be divisible by num_expert_group. #expert/num_expert_group <= 32 is limited for now. # for non-supported case, we suggest to use the biased_grouped_topk func in sglang.srt.layers.moe.topk - # num_fused_shared_experts: if > 0, the last several experts will be replaced with shared experts - # routed_scaling_factor: if > 0, the shared experts will be scaled by this factor + # num_fused_shared_experts: if > 0, the last several experts will be + # replaced with shared experts. the shared experts will be divided by the + # routed_scaling_factor - this is intended to cancel out later when routed+shared + # output is scaled so that shared experts are not scaled. + # routed_scaling_factor: if > 0, the experts will be scaled by this factor + # apply_routed_scaling_factor_on_output: if true, output will be + # scaled by the routed_scaling_factor return torch.ops.sgl_kernel.moe_fused_gate.default( input_tensor, bias, @@ -61,6 +67,7 @@ def moe_fused_gate( topk, num_fused_shared_experts, routed_scaling_factor, + apply_routed_scaling_factor_on_output, ) diff --git a/sgl-kernel/tests/test_moe_fused_gate.py b/sgl-kernel/tests/test_moe_fused_gate.py index 70c4ea209..274f387a9 100644 --- a/sgl-kernel/tests/test_moe_fused_gate.py +++ b/sgl-kernel/tests/test_moe_fused_gate.py @@ -19,7 +19,10 @@ from sglang.srt.layers.moe.topk import biased_grouped_topk ], ) @pytest.mark.parametrize("num_fused_shared_experts", [0, 1, 2]) -def test_moe_fused_gate_combined(seq_length, params, num_fused_shared_experts): +@pytest.mark.parametrize("apply_routed_scaling_factor_on_output", [True, False]) +def test_moe_fused_gate_combined( + seq_length, params, num_fused_shared_experts, apply_routed_scaling_factor_on_output +): num_experts, num_expert_group, topk_group, topk = params dtype = torch.float32 @@ -37,6 +40,7 @@ def test_moe_fused_gate_combined(seq_length, params, num_fused_shared_experts): topk=topk, num_fused_shared_experts=num_fused_shared_experts, routed_scaling_factor=2.5, + apply_routed_scaling_factor_on_output=apply_routed_scaling_factor_on_output, ) ref_output, ref_indices = biased_grouped_topk( scores, @@ -48,6 +52,7 @@ def test_moe_fused_gate_combined(seq_length, params, num_fused_shared_experts): topk_group=topk_group, num_fused_shared_experts=num_fused_shared_experts, routed_scaling_factor=2.5, + apply_routed_scaling_factor_on_output=apply_routed_scaling_factor_on_output, ) # When num_fused_shared_experts > 0, ignore the comparison of the last topk dimension