diff --git a/python/sglang/srt/layers/moe/topk.py b/python/sglang/srt/layers/moe/topk.py index 78bd6f08d..d1b560219 100644 --- a/python/sglang/srt/layers/moe/topk.py +++ b/python/sglang/srt/layers/moe/topk.py @@ -132,6 +132,7 @@ class TopK(CustomOp): scoring_func: str = "softmax", correction_bias: Optional[torch.Tensor] = None, routed_scaling_factor: Optional[float] = None, + apply_routed_scaling_factor_on_output: Optional[bool] = False, ): # NOTE: scoring_func is not used for now, but we keep it for future use # see https://github.com/sgl-project/sglang/pull/4505 for more details @@ -147,6 +148,9 @@ class TopK(CustomOp): self.custom_routing_function = custom_routing_function self.correction_bias = correction_bias self.routed_scaling_factor = routed_scaling_factor + self.apply_routed_scaling_factor_on_output = ( + apply_routed_scaling_factor_on_output + ) self.use_triton_kernels = global_server_args_dict["enable_triton_kernel_moe"] @@ -207,6 +211,7 @@ class TopK(CustomOp): routed_scaling_factor=self.routed_scaling_factor, num_token_non_padded=num_token_non_padded, expert_location_dispatch_info=expert_location_dispatch_info, + apply_routed_scaling_factor_on_output=self.apply_routed_scaling_factor_on_output, ) def forward_cpu( @@ -375,6 +380,7 @@ def grouped_topk_gpu( routed_scaling_factor: Optional[float] = None, num_token_non_padded: Optional[torch.Tensor] = None, expert_location_dispatch_info: Optional[ExpertLocationDispatchInfo] = None, + apply_routed_scaling_factor_on_output: Optional[bool] = False, ): assert hidden_states.shape[0] == gating_output.shape[0], "Number of tokens mismatch" @@ -422,6 +428,8 @@ def grouped_topk_gpu( else topk_weights[:, :-1].sum(dim=-1, keepdim=True) ) topk_weights = topk_weights / topk_weights_sum + if apply_routed_scaling_factor_on_output: + topk_weights *= routed_scaling_factor topk_weights, topk_ids = topk_weights.to(torch.float32), topk_ids.to(torch.int32) topk_ids = topk_ids_logical_to_physical(topk_ids, expert_location_dispatch_info) @@ -468,6 +476,7 @@ def biased_grouped_topk_impl( routed_scaling_factor: Optional[float] = None, num_token_non_padded: Optional[torch.Tensor] = None, expert_location_dispatch_info: Optional[ExpertLocationDispatchInfo] = None, + apply_routed_scaling_factor_on_output: Optional[bool] = False, ): assert hidden_states.shape[0] == gating_output.shape[0], "Number of tokens mismatch" @@ -519,6 +528,8 @@ def biased_grouped_topk_impl( else topk_weights[:, :-1].sum(dim=-1, keepdim=True) ) topk_weights = topk_weights / topk_weights_sum + if apply_routed_scaling_factor_on_output: + topk_weights *= routed_scaling_factor topk_weights, topk_ids = topk_weights.to(torch.float32), topk_ids.to(torch.int32) topk_ids = topk_ids_logical_to_physical(topk_ids, expert_location_dispatch_info) @@ -561,7 +572,10 @@ def biased_grouped_topk_gpu( routed_scaling_factor: Optional[float] = None, num_token_non_padded: Optional[torch.Tensor] = None, expert_location_dispatch_info: Optional[ExpertLocationDispatchInfo] = None, + apply_routed_scaling_factor_on_output: Optional[bool] = False, ): + # TODO(trevor-m): Remove once sgl-kernel is updated + assert not apply_routed_scaling_factor_on_output assert ( routed_scaling_factor is not None ), "routed_scaling_factor is required for biased_grouped_topk" @@ -580,6 +594,8 @@ def biased_grouped_topk_gpu( topk, num_fused_shared_experts, routed_scaling_factor, + # TODO(trevor-m): Uncomment once sgl-kernel is updated + # apply_routed_scaling_factor_on_output, ) # TODO merge into kernel if (expert_location_dispatch_info is not None) or ( @@ -590,6 +606,7 @@ def biased_grouped_topk_gpu( ) return topk_weights, topk_ids elif _use_aiter: + assert not apply_routed_scaling_factor_on_output, "Not implemented" token = gating_output.shape[0] device = gating_output.device assert ( @@ -621,6 +638,7 @@ def biased_grouped_topk_gpu( routed_scaling_factor=routed_scaling_factor, num_token_non_padded=num_token_non_padded, expert_location_dispatch_info=expert_location_dispatch_info, + apply_routed_scaling_factor_on_output=apply_routed_scaling_factor_on_output, ) @@ -680,6 +698,7 @@ def select_experts( routed_scaling_factor: Optional[float] = None, num_token_non_padded: Optional[torch.Tensor] = None, expert_location_dispatch_info: Optional[ExpertLocationDispatchInfo] = None, + apply_routed_scaling_factor_on_output: Optional[bool] = False, ) -> TopKOutput: router_logits, correction_bias = ( expert_location_dispatch.transform_select_experts_inputs( @@ -705,6 +724,7 @@ def select_experts( routed_scaling_factor=routed_scaling_factor, num_token_non_padded=num_token_non_padded, expert_location_dispatch_info=expert_location_dispatch_info, + apply_routed_scaling_factor_on_output=apply_routed_scaling_factor_on_output, ) else: topk_weights, topk_ids = biased_grouped_topk( @@ -719,12 +739,14 @@ def select_experts( routed_scaling_factor=routed_scaling_factor, num_token_non_padded=num_token_non_padded, expert_location_dispatch_info=expert_location_dispatch_info, + apply_routed_scaling_factor_on_output=apply_routed_scaling_factor_on_output, ) elif torch_native and custom_routing_function is None: assert ( num_token_non_padded is None ), "num_token_non_padded is not yet supported in fused_topk_native" assert expert_location_dispatch_info is None + assert not apply_routed_scaling_factor_on_output, "Not implemented" topk_weights, topk_ids = fused_topk_native( hidden_states=hidden_states, gating_output=router_logits, @@ -732,6 +754,7 @@ def select_experts( renormalize=renormalize, ) elif custom_routing_function is None: + assert not apply_routed_scaling_factor_on_output, "Not implemented" # Qwen3MOE uses fused_topk topk_weights, topk_ids = fused_topk( hidden_states=hidden_states, @@ -746,6 +769,7 @@ def select_experts( num_token_non_padded is None ), "num_token_non_padded is not yet supported in custom_routing_function" assert expert_location_dispatch_info is None + assert not apply_routed_scaling_factor_on_output, "Not implemented" topk_weights, topk_ids = custom_routing_function( hidden_states=hidden_states, gating_output=router_logits, diff --git a/sgl-kernel/csrc/common_extension.cc b/sgl-kernel/csrc/common_extension.cc index 295939900..989ae14eb 100644 --- a/sgl-kernel/csrc/common_extension.cc +++ b/sgl-kernel/csrc/common_extension.cc @@ -174,7 +174,7 @@ TORCH_LIBRARY_FRAGMENT(sgl_kernel, m) { m.def( "moe_fused_gate(Tensor input, Tensor bias, int num_expert_group, int topk_group, int topk, int " - "num_fused_shared_experts, float routed_scaling_factor) -> " + "num_fused_shared_experts, float routed_scaling_factor, bool apply_routed_scaling_factor_on_output) -> " "(Tensor[])"); m.impl("moe_fused_gate", torch::kCUDA, &moe_fused_gate); m.def( diff --git a/sgl-kernel/csrc/moe/moe_fused_gate.cu b/sgl-kernel/csrc/moe/moe_fused_gate.cu index 24bf2d36b..782a884fb 100644 --- a/sgl-kernel/csrc/moe/moe_fused_gate.cu +++ b/sgl-kernel/csrc/moe/moe_fused_gate.cu @@ -59,6 +59,7 @@ __device__ void moe_fused_gate_impl( int64_t topk, int64_t num_fused_shared_experts, double routed_scaling_factor, + bool apply_routed_scaling_factor_on_output, Params params) { int tidx = threadIdx.x; int64_t thread_row = @@ -248,6 +249,9 @@ __device__ void moe_fused_gate_impl( for (int ii = 0; ii < topk; ++ii) { int64_t const idx = topk * thread_row + ii; output_ptr[idx] = output_ptr[idx] / output_sum; + if (apply_routed_scaling_factor_on_output) { + output_ptr[idx] *= routed_scaling_factor; + } } } } @@ -282,7 +286,8 @@ __global__ void moe_fused_gate_kernel( int64_t topk_group, int64_t topk, int64_t num_fused_shared_experts, - double routed_scaling_factor) { + double routed_scaling_factor, + bool apply_routed_scaling_factor_on_output) { KernelParams params; moe_fused_gate_impl( input, @@ -294,6 +299,7 @@ __global__ void moe_fused_gate_kernel( topk, num_fused_shared_experts, routed_scaling_factor, + apply_routed_scaling_factor_on_output, params); } @@ -314,7 +320,8 @@ __global__ void moe_fused_gate_kernel( topk_group, \ topk, \ num_fused_shared_experts, \ - routed_scaling_factor); \ + routed_scaling_factor, \ + apply_routed_scaling_factor_on_output); \ dispatched = true; \ } while (0) @@ -342,7 +349,8 @@ __global__ void moe_fused_gate_kernel_dynamic( int64_t topk_group, int64_t topk, int64_t num_fused_shared_experts, - double routed_scaling_factor) { + double routed_scaling_factor, + bool apply_routed_scaling_factor_on_output) { KernelParamsDynamic params; params.NUM_EXPERTS = num_experts; // e.g, for deepseek v3, this is 256 params.VPT = num_experts / num_expert_group; // e.g., for deepseek v3, this is 256 / 8 = 32 @@ -361,6 +369,7 @@ __global__ void moe_fused_gate_kernel_dynamic( topk, num_fused_shared_experts, routed_scaling_factor, + apply_routed_scaling_factor_on_output, params); } @@ -374,7 +383,8 @@ std::vector moe_fused_gate( int64_t topk_group, int64_t topk, int64_t num_fused_shared_experts, - double routed_scaling_factor) { + double routed_scaling_factor, + bool apply_routed_scaling_factor_on_output) { int64_t num_rows = input.size(0); int32_t num_experts = input.size(1); auto options = torch::TensorOptions().dtype(torch::kFloat32).device(torch::kCUDA); @@ -473,7 +483,8 @@ std::vector moe_fused_gate( topk_group, topk, num_fused_shared_experts, - routed_scaling_factor); + routed_scaling_factor, + apply_routed_scaling_factor_on_output); } else if (input.scalar_type() == at::kHalf) { moe_fused_gate_kernel_dynamic<<>>( input.data_ptr(), @@ -486,7 +497,8 @@ std::vector moe_fused_gate( topk_group, topk, num_fused_shared_experts, - routed_scaling_factor); + routed_scaling_factor, + apply_routed_scaling_factor_on_output); } else if (input.scalar_type() == at::kFloat) { moe_fused_gate_kernel_dynamic<<>>( input.data_ptr(), @@ -499,7 +511,8 @@ std::vector moe_fused_gate( topk_group, topk, num_fused_shared_experts, - routed_scaling_factor); + routed_scaling_factor, + apply_routed_scaling_factor_on_output); } else { TORCH_CHECK(false, "Unsupported data type for moe_fused_gate"); } diff --git a/sgl-kernel/include/sgl_kernel_ops.h b/sgl-kernel/include/sgl_kernel_ops.h index fa6de7362..88720dfea 100644 --- a/sgl-kernel/include/sgl_kernel_ops.h +++ b/sgl-kernel/include/sgl_kernel_ops.h @@ -243,7 +243,8 @@ std::vector moe_fused_gate( int64_t topk_group, int64_t topk, int64_t num_fused_shared_experts, - double routed_scaling_factor); + double routed_scaling_factor, + bool apply_routed_scaling_factor_on_output); void fp8_blockwise_scaled_grouped_mm( torch::Tensor& output, diff --git a/sgl-kernel/python/sgl_kernel/moe.py b/sgl-kernel/python/sgl_kernel/moe.py index c16a2b6fe..9008e7a79 100755 --- a/sgl-kernel/python/sgl_kernel/moe.py +++ b/sgl-kernel/python/sgl_kernel/moe.py @@ -44,6 +44,7 @@ def moe_fused_gate( topk, num_fused_shared_experts=0, routed_scaling_factor=0, + apply_routed_scaling_factor_on_output=False, ): # This fused kernel function is used to select topk expert in a hierarchical 2-layer fashion # it split group of expert into num_expert_group, and use top2 expert weight sum in each group @@ -51,8 +52,13 @@ def moe_fused_gate( # the #experts is decided by the input tensor shape and we currently only support power of 2 #experts # and #experts should be divisible by num_expert_group. #expert/num_expert_group <= 32 is limited for now. # for non-supported case, we suggest to use the biased_grouped_topk func in sglang.srt.layers.moe.topk - # num_fused_shared_experts: if > 0, the last several experts will be replaced with shared experts - # routed_scaling_factor: if > 0, the shared experts will be scaled by this factor + # num_fused_shared_experts: if > 0, the last several experts will be + # replaced with shared experts. the shared experts will be divided by the + # routed_scaling_factor - this is intended to cancel out later when routed+shared + # output is scaled so that shared experts are not scaled. + # routed_scaling_factor: if > 0, the experts will be scaled by this factor + # apply_routed_scaling_factor_on_output: if true, output will be + # scaled by the routed_scaling_factor return torch.ops.sgl_kernel.moe_fused_gate.default( input_tensor, bias, @@ -61,6 +67,7 @@ def moe_fused_gate( topk, num_fused_shared_experts, routed_scaling_factor, + apply_routed_scaling_factor_on_output, ) diff --git a/sgl-kernel/tests/test_moe_fused_gate.py b/sgl-kernel/tests/test_moe_fused_gate.py index 70c4ea209..274f387a9 100644 --- a/sgl-kernel/tests/test_moe_fused_gate.py +++ b/sgl-kernel/tests/test_moe_fused_gate.py @@ -19,7 +19,10 @@ from sglang.srt.layers.moe.topk import biased_grouped_topk ], ) @pytest.mark.parametrize("num_fused_shared_experts", [0, 1, 2]) -def test_moe_fused_gate_combined(seq_length, params, num_fused_shared_experts): +@pytest.mark.parametrize("apply_routed_scaling_factor_on_output", [True, False]) +def test_moe_fused_gate_combined( + seq_length, params, num_fused_shared_experts, apply_routed_scaling_factor_on_output +): num_experts, num_expert_group, topk_group, topk = params dtype = torch.float32 @@ -37,6 +40,7 @@ def test_moe_fused_gate_combined(seq_length, params, num_fused_shared_experts): topk=topk, num_fused_shared_experts=num_fused_shared_experts, routed_scaling_factor=2.5, + apply_routed_scaling_factor_on_output=apply_routed_scaling_factor_on_output, ) ref_output, ref_indices = biased_grouped_topk( scores, @@ -48,6 +52,7 @@ def test_moe_fused_gate_combined(seq_length, params, num_fused_shared_experts): topk_group=topk_group, num_fused_shared_experts=num_fused_shared_experts, routed_scaling_factor=2.5, + apply_routed_scaling_factor_on_output=apply_routed_scaling_factor_on_output, ) # When num_fused_shared_experts > 0, ignore the comparison of the last topk dimension