From a91e90d9a3604118554b87b2078a513432fa361a Mon Sep 17 00:00:00 2001
From: Trevor Morris
Date: Wed, 20 Aug 2025 15:10:16 -0700
Subject: [PATCH] [2/2] Fuse routed scaling factor into select_experts (#8690)

---
 .../srt/layers/moe/fused_moe_triton/layer.py  |  7 ++++++
 python/sglang/srt/layers/moe/topk.py          | 20 ++++++++++++++++
 python/sglang/srt/layers/quantization/fp8.py  | 17 +++++++-------
 .../srt/layers/quantization/modelopt_quant.py |  6 ++---
 python/sglang/srt/models/deepseek_v2.py       | 23 ++++++++++---------
 sgl-kernel/tests/test_moe_fused_gate.py       |  7 +++++-
 6 files changed, 55 insertions(+), 25 deletions(-)

diff --git a/python/sglang/srt/layers/moe/fused_moe_triton/layer.py b/python/sglang/srt/layers/moe/fused_moe_triton/layer.py
index 98f89ab7f..504aeb2fe 100644
--- a/python/sglang/srt/layers/moe/fused_moe_triton/layer.py
+++ b/python/sglang/srt/layers/moe/fused_moe_triton/layer.py
@@ -28,6 +28,7 @@ from sglang.srt.layers.quantization.base_config import (
     QuantizationConfig,
     QuantizeMethodBase,
 )
+from sglang.srt.layers.quantization.fp8 import Fp8MoEMethod
 from sglang.srt.layers.quantization.modelopt_quant import ModelOptNvFp4FusedMoEMethod
 from sglang.srt.layers.quantization.unquant import UnquantizedFusedMoEMethod
 from sglang.srt.managers.schedule_batch import global_server_args_dict
@@ -923,6 +924,12 @@ class FusedMoE(torch.nn.Module):
             for shard_id in ["w1", "w2", "w3"]
         ]

+    def should_fuse_routed_scaling_factor_in_topk(self):
+        return isinstance(self.quant_method, ModelOptNvFp4FusedMoEMethod) or (
+            isinstance(self.quant_method, Fp8MoEMethod)
+            and self.quant_method.use_cutlass_fused_experts_fp8
+        )
+

 class FlashInferFusedMoE(FusedMoE):
     def __init__(self, *args, **kwargs):
diff --git a/python/sglang/srt/layers/moe/topk.py b/python/sglang/srt/layers/moe/topk.py
index 479103e15..bf8981c13 100644
--- a/python/sglang/srt/layers/moe/topk.py
+++ b/python/sglang/srt/layers/moe/topk.py
@@ -197,6 +197,7 @@ class TopK(CustomOp):
         scoring_func: str = "softmax",
         correction_bias: Optional[torch.Tensor] = None,
         routed_scaling_factor: Optional[float] = None,
+        apply_routed_scaling_factor_on_output: Optional[bool] = False,
     ):
         # NOTE: scoring_func is not used for now, but we keep it for future use
         # see https://github.com/sgl-project/sglang/pull/4505 for more details
@@ -215,6 +216,7 @@ class TopK(CustomOp):
             custom_routing_function=custom_routing_function,
             correction_bias=correction_bias,
             routed_scaling_factor=routed_scaling_factor,
+            apply_routed_scaling_factor_on_output=apply_routed_scaling_factor_on_output,
         )

         self.use_triton_kernels = get_moe_runner_backend().is_triton_kernel()
@@ -433,6 +435,7 @@
     routed_scaling_factor: Optional[float] = None,
     num_token_non_padded: Optional[torch.Tensor] = None,
     expert_location_dispatch_info: Optional[ExpertLocationDispatchInfo] = None,
+    apply_routed_scaling_factor_on_output: Optional[bool] = False,
 ):
     assert hidden_states.shape[0] == gating_output.shape[0], "Number of tokens mismatch"

@@ -480,6 +483,8 @@
         else topk_weights[:, :-1].sum(dim=-1, keepdim=True)
     )
     topk_weights = topk_weights / topk_weights_sum
+    if apply_routed_scaling_factor_on_output:
+        topk_weights *= routed_scaling_factor

     topk_weights, topk_ids = topk_weights.to(torch.float32), topk_ids.to(torch.int32)
     topk_ids = topk_ids_logical_to_physical(topk_ids, expert_location_dispatch_info)
@@ -528,6 +533,7 @@ def biased_grouped_topk_impl(
     routed_scaling_factor: Optional[float] = None,
     num_token_non_padded: Optional[torch.Tensor] = None,
     expert_location_dispatch_info: Optional[ExpertLocationDispatchInfo] = None,
+    apply_routed_scaling_factor_on_output: Optional[bool] = False,
 ):
     assert hidden_states.shape[0] == gating_output.shape[0], "Number of tokens mismatch"

@@ -579,6 +585,8 @@ def biased_grouped_topk_impl(
         else topk_weights[:, :-1].sum(dim=-1, keepdim=True)
     )
     topk_weights = topk_weights / topk_weights_sum
+    if apply_routed_scaling_factor_on_output:
+        topk_weights *= routed_scaling_factor

     topk_weights, topk_ids = topk_weights.to(torch.float32), topk_ids.to(torch.int32)
     topk_ids = topk_ids_logical_to_physical(topk_ids, expert_location_dispatch_info)
@@ -621,6 +629,7 @@ def biased_grouped_topk_gpu(
     routed_scaling_factor: Optional[float] = None,
     num_token_non_padded: Optional[torch.Tensor] = None,
     expert_location_dispatch_info: Optional[ExpertLocationDispatchInfo] = None,
+    apply_routed_scaling_factor_on_output: Optional[bool] = False,
 ):
     assert (
         routed_scaling_factor is not None
@@ -640,6 +649,7 @@
             topk,
             num_fused_shared_experts,
             routed_scaling_factor,
+            apply_routed_scaling_factor_on_output,
         )
         # TODO merge into kernel
         if (expert_location_dispatch_info is not None) or (
@@ -650,6 +660,7 @@
             )
             return topk_weights, topk_ids
     elif _use_aiter:
+        assert not apply_routed_scaling_factor_on_output, "Not implemented"
        token = gating_output.shape[0]
        device = gating_output.device
        assert (
@@ -681,6 +692,7 @@
            routed_scaling_factor=routed_scaling_factor,
            num_token_non_padded=num_token_non_padded,
            expert_location_dispatch_info=expert_location_dispatch_info,
+            apply_routed_scaling_factor_on_output=apply_routed_scaling_factor_on_output,
        )


@@ -743,6 +755,9 @@ def select_experts(
    correction_bias = topk_config.correction_bias
    torch_native = topk_config.torch_native
    routed_scaling_factor = topk_config.routed_scaling_factor
+    apply_routed_scaling_factor_on_output = (
+        topk_config.apply_routed_scaling_factor_on_output
+    )

    router_logits, correction_bias = (
        expert_location_dispatch.transform_select_experts_inputs(
@@ -768,6 +783,7 @@
                routed_scaling_factor=routed_scaling_factor,
                num_token_non_padded=num_token_non_padded,
                expert_location_dispatch_info=expert_location_dispatch_info,
+                apply_routed_scaling_factor_on_output=apply_routed_scaling_factor_on_output,
            )
        else:
            topk_weights, topk_ids = biased_grouped_topk(
@@ -782,12 +798,14 @@
                routed_scaling_factor=routed_scaling_factor,
                num_token_non_padded=num_token_non_padded,
                expert_location_dispatch_info=expert_location_dispatch_info,
+                apply_routed_scaling_factor_on_output=apply_routed_scaling_factor_on_output,
            )
    elif torch_native and custom_routing_function is None:
        assert (
            num_token_non_padded is None
        ), "num_token_non_padded is not yet supported in fused_topk_native"
        assert expert_location_dispatch_info is None
+        assert not apply_routed_scaling_factor_on_output, "Not implemented"
        topk_weights, topk_ids = fused_topk_native(
            hidden_states=hidden_states,
            gating_output=router_logits,
@@ -795,6 +813,7 @@
            renormalize=renormalize,
        )
    elif custom_routing_function is None:
+        assert not apply_routed_scaling_factor_on_output, "Not implemented"
        # Qwen3MOE uses fused_topk
        topk_weights, topk_ids = fused_topk(
            hidden_states=hidden_states,
@@ -809,6 +828,7 @@
            num_token_non_padded is None
        ), "num_token_non_padded is not yet supported in custom_routing_function"
        assert expert_location_dispatch_info is None
+        assert not apply_routed_scaling_factor_on_output, "Not implemented"
        topk_weights, topk_ids = custom_routing_function(
            hidden_states=hidden_states,
            gating_output=router_logits,
diff --git a/python/sglang/srt/layers/quantization/fp8.py b/python/sglang/srt/layers/quantization/fp8.py
index 5c40bd1f0..0192da7ef 100644
--- a/python/sglang/srt/layers/quantization/fp8.py
+++ b/python/sglang/srt/layers/quantization/fp8.py
@@ -514,6 +514,12 @@ class Fp8MoEMethod(FusedMoEMethodBase):
        self.quant_config = quant_config
        self.block_quant = self.quant_config.weight_block_size is not None
        self.cutlass_fp8_supported = cutlass_fp8_supported()
+        self.use_cutlass_fused_experts_fp8 = (
+            get_bool_env_var("SGLANG_CUTLASS_MOE")
+            and self.cutlass_fp8_supported
+            and self.block_quant
+            and (is_sm100_supported() or is_sm90_supported())
+        )

    def create_weights(
        self,
@@ -1021,12 +1027,7 @@ class Fp8MoEMethod(FusedMoEMethodBase):
        if ret is not None:
            return ret

-        if (
-            get_bool_env_var("SGLANG_CUTLASS_MOE")
-            and self.cutlass_fp8_supported
-            and self.block_quant
-            and (is_sm100_supported() or is_sm90_supported())
-        ):
+        if self.use_cutlass_fused_experts_fp8:
            from sglang.srt.layers.moe.cutlass_moe import cutlass_fused_experts_fp8

            topk_weights, topk_ids, _ = topk_output
@@ -1053,9 +1054,7 @@ class Fp8MoEMethod(FusedMoEMethodBase):
                self.problem_sizes2,
                use_fp8_blockscale=True,
            )
-            # TODO: Fuse into select_experts
-            if moe_runner_config.routed_scaling_factor is not None:
-                output *= moe_runner_config.routed_scaling_factor
+            # Scale by routed_scaling_factor is fused into select_experts.
            return output
        # Expert fusion with FP8 quantization
        return fused_experts(
diff --git a/python/sglang/srt/layers/quantization/modelopt_quant.py b/python/sglang/srt/layers/quantization/modelopt_quant.py
index db0bf3ab7..6d3b76950 100755
--- a/python/sglang/srt/layers/quantization/modelopt_quant.py
+++ b/python/sglang/srt/layers/quantization/modelopt_quant.py
@@ -1305,8 +1305,7 @@ class ModelOptNvFp4FusedMoEMethod(FusedMoEMethodBase):
                tp_rank=layer.moe_tp_rank,
                tune_max_num_tokens=next_power_of_2(x.shape[0]),
            )[0]
-            if moe_runner_config.routed_scaling_factor is not None:
-                output *= moe_runner_config.routed_scaling_factor
+            # Scale by routed_scaling_factor is fused into select_experts.
            if should_use_flashinfer_cutlass_moe_fp4_allgather():
                output, global_output = get_local_dp_buffer(), output
                get_tp_group().reduce_scatterv(
@@ -1332,6 +1331,5 @@ class ModelOptNvFp4FusedMoEMethod(FusedMoEMethodBase):
            params=layer.cutlass_moe_params,
            apply_router_weight_on_input=moe_runner_config.apply_router_weight_on_input,
        ).to(x.dtype)
-        if moe_runner_config.routed_scaling_factor is not None:
-            output *= moe_runner_config.routed_scaling_factor
+        # Scale by routed_scaling_factor is fused into select_experts.
        return output
diff --git a/python/sglang/srt/models/deepseek_v2.py b/python/sglang/srt/models/deepseek_v2.py
index 3bec16bfc..eabd56594 100644
--- a/python/sglang/srt/models/deepseek_v2.py
+++ b/python/sglang/srt/models/deepseek_v2.py
@@ -319,17 +319,6 @@ class DeepseekV2MoE(nn.Module):
            config=config, prefix=add_prefix("gate", prefix), is_nextn=is_nextn
        )

-        self.topk = TopK(
-            top_k=config.num_experts_per_tok + self.num_fused_shared_experts,
-            renormalize=config.norm_topk_prob,
-            use_grouped_topk=True,
-            num_expert_group=config.n_group,
-            num_fused_shared_experts=self.num_fused_shared_experts,
-            topk_group=config.topk_group,
-            correction_bias=self.gate.e_score_correction_bias,
-            routed_scaling_factor=self.routed_scaling_factor,
-        )
-
        self.experts = get_moe_impl_class()(
            num_experts=config.n_routed_experts
            + self.num_fused_shared_experts
@@ -344,6 +333,18 @@ class DeepseekV2MoE(nn.Module):
            prefix=add_prefix("experts", prefix),
        )

+        self.topk = TopK(
+            top_k=config.num_experts_per_tok + self.num_fused_shared_experts,
+            renormalize=config.norm_topk_prob,
+            use_grouped_topk=True,
+            num_expert_group=config.n_group,
+            num_fused_shared_experts=self.num_fused_shared_experts,
+            topk_group=config.topk_group,
+            correction_bias=self.gate.e_score_correction_bias,
+            routed_scaling_factor=self.routed_scaling_factor,
+            apply_routed_scaling_factor_on_output=self.experts.should_fuse_routed_scaling_factor_in_topk(),
+        )
+
        self.shared_experts_is_int8 = False
        self.shared_experts_is_fp8 = False
        self.shared_experts_weight_block_size = None
diff --git a/sgl-kernel/tests/test_moe_fused_gate.py b/sgl-kernel/tests/test_moe_fused_gate.py
index 70c4ea209..983895752 100644
--- a/sgl-kernel/tests/test_moe_fused_gate.py
+++ b/sgl-kernel/tests/test_moe_fused_gate.py
@@ -19,7 +19,10 @@ from sglang.srt.layers.moe.topk import biased_grouped_topk
    ],
)
@pytest.mark.parametrize("num_fused_shared_experts", [0, 1, 2])
-def test_moe_fused_gate_combined(seq_length, params, num_fused_shared_experts):
+@pytest.mark.parametrize("apply_routed_scaling_factor_on_output", [False, True])
+def test_moe_fused_gate_combined(
+    seq_length, params, num_fused_shared_experts, apply_routed_scaling_factor_on_output
+):
    num_experts, num_expert_group, topk_group, topk = params
    dtype = torch.float32

@@ -37,6 +40,7 @@ def test_moe_fused_gate_combined(seq_length, params, num_fused_shared_experts):
        topk=topk,
        num_fused_shared_experts=num_fused_shared_experts,
        routed_scaling_factor=2.5,
+        apply_routed_scaling_factor_on_output=apply_routed_scaling_factor_on_output,
    )
    ref_output, ref_indices = biased_grouped_topk(
        scores,
@@ -48,6 +52,7 @@ def test_moe_fused_gate_combined(seq_length, params, num_fused_shared_experts):
        topk_group=topk_group,
        num_fused_shared_experts=num_fused_shared_experts,
        routed_scaling_factor=2.5,
+        apply_routed_scaling_factor_on_output=apply_routed_scaling_factor_on_output,
    )
    # When num_fused_shared_experts > 0, ignore the comparison of the last topk dimension
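Note on the equivalence this patch relies on: the MoE combine step is linear in the routing weights, so multiplying the top-k weights by routed_scaling_factor inside expert selection (the new apply_routed_scaling_factor_on_output path) gives the same result as multiplying the combined expert output afterwards, which is the per-backend elementwise scaling removed from the FP8/NVFP4 code above. The sketch below is a minimal standalone illustration of that equivalence, not the sglang kernels; select_topk_weights, combine, and the toy tensors are invented for the example.

# Illustrative only: fused vs. unfused application of routed_scaling_factor.
import torch

def select_topk_weights(scores, top_k, routed_scaling_factor, apply_on_weights):
    # Toy stand-in for select_experts: softmax -> top-k -> renormalize.
    probs = scores.softmax(dim=-1)
    topk_weights, topk_ids = probs.topk(top_k, dim=-1)
    topk_weights = topk_weights / topk_weights.sum(dim=-1, keepdim=True)
    if apply_on_weights:
        # Fused path: bake the scaling factor into the routing weights.
        topk_weights = topk_weights * routed_scaling_factor
    return topk_weights, topk_ids

def combine(expert_outputs, topk_weights, topk_ids):
    # Weighted sum of the selected experts' outputs for each token.
    # expert_outputs: [num_tokens, num_experts, hidden]
    gathered = torch.gather(
        expert_outputs,
        1,
        topk_ids.unsqueeze(-1).expand(-1, -1, expert_outputs.shape[-1]),
    )
    return (gathered * topk_weights.unsqueeze(-1)).sum(dim=1)

torch.manual_seed(0)
scores = torch.randn(4, 8)              # [tokens, experts] router logits
expert_outputs = torch.randn(4, 8, 16)  # toy per-expert outputs
factor = 2.5

# Unfused path: scale the combined output at the end (old backend behavior).
w, ids = select_topk_weights(scores, 2, factor, apply_on_weights=False)
out_unfused = combine(expert_outputs, w, ids) * factor

# Fused path: scale the top-k weights during routing, skip the final multiply.
w_fused, ids_fused = select_topk_weights(scores, 2, factor, apply_on_weights=True)
out_fused = combine(expert_outputs, w_fused, ids_fused)

assert torch.allclose(out_unfused, out_fused, atol=1e-6)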