From 4844fac91d0750c45fc8b0eb87196ab1be7aa21c Mon Sep 17 00:00:00 2001
From: Cheng Wan <54331508+ch-wan@users.noreply.github.com>
Date: Sun, 14 Sep 2025 19:16:25 -0700
Subject: [PATCH] Refactor TopK to ensure readability and extensibility (#9338)

---
 python/sglang/srt/layers/moe/ep_moe/layer.py  |  8 ++--
 .../srt/layers/moe/fused_moe_triton/layer.py  | 10 -----
 python/sglang/srt/layers/moe/topk.py          | 39 ++++++++++++++-----
 python/sglang/srt/managers/schedule_batch.py  |  1 -
 python/sglang/srt/models/bailing_moe.py       |  2 +-
 python/sglang/srt/models/deepseek_v2.py       | 19 ++++-----
 python/sglang/srt/models/ernie4.py            |  2 +-
 python/sglang/srt/models/glm4_moe.py          |  2 +-
 python/sglang/srt/models/gpt_oss.py           |  2 +-
 python/sglang/srt/models/longcat_flash.py     |  4 +-
 python/sglang/srt/models/qwen2_moe.py         |  2 +-
 python/sglang/srt/models/qwen3_moe.py         |  2 +-
 python/sglang/srt/models/qwen3_next.py        |  4 +-
 python/sglang/srt/models/step3_vl.py          |  2 +-
 14 files changed, 52 insertions(+), 47 deletions(-)

diff --git a/python/sglang/srt/layers/moe/ep_moe/layer.py b/python/sglang/srt/layers/moe/ep_moe/layer.py
index 8184a9305..0bd49600e 100644
--- a/python/sglang/srt/layers/moe/ep_moe/layer.py
+++ b/python/sglang/srt/layers/moe/ep_moe/layer.py
@@ -888,7 +888,7 @@ class DeepEPMoE(EPMoE):
             raise ValueError(f"Not Supported DeepEP format {dispatch_output.format}")


-def get_moe_impl_class(quant_config: Optional[QuantizationConfig] = None):
+def get_moe_impl_class(quant_config: Optional[QuantizationConfig]):
     if get_moe_a2a_backend().is_deepep():
         return DeepEPMoE

@@ -901,8 +901,7 @@ def get_moe_impl_class(quant_config: Optional[QuantizationConfig] = None):
         return FusedMoE
     try:
         # Check the quantization argument directly
-        quantization = global_server_args_dict.get("quantization")
-        if quantization == "modelopt_fp4":
+        if quant_config is not None and quant_config.get_name() == "modelopt_fp4":
             from sglang.srt.layers.moe.fused_moe_triton.layer import (
                 FlashInferFP4MoE,
             )
@@ -911,7 +910,8 @@ def get_moe_impl_class(quant_config: Optional[QuantizationConfig] = None):
     except:
         pass

-    if should_use_flashinfer_trtllm_moe():
+    if should_use_flashinfer_trtllm_moe() and quant_config is not None:
+        # FIXME: FlashInferFusedMoE only supports fp8 quant now
         return FlashInferFusedMoE
     if get_moe_runner_backend().is_flashinfer_cutlass():
         return FusedMoE
diff --git a/python/sglang/srt/layers/moe/fused_moe_triton/layer.py b/python/sglang/srt/layers/moe/fused_moe_triton/layer.py
index abe604fc6..81355c4f9 100644
--- a/python/sglang/srt/layers/moe/fused_moe_triton/layer.py
+++ b/python/sglang/srt/layers/moe/fused_moe_triton/layer.py
@@ -74,16 +74,6 @@ if should_use_flashinfer_trtllm_moe():
 logger = logging.getLogger(__name__)


-def _is_fp4_quantization_enabled():
-    """Check if ModelOpt FP4 quantization is enabled."""
-    try:
-        # Use the same simple check that works for class selection
-        quantization = global_server_args_dict.get("quantization")
-        return quantization == "modelopt_fp4"
-    except:
-        return False
-
-
 def _get_tile_tokens_dim(num_tokens, top_k, num_experts):
     # Guess tokens per expert assuming perfect expert distribution first.
     num_tokens_per_expert = (num_tokens * top_k) // num_experts
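Reviewer note: a minimal sketch of the call pattern these two hunks aim for, with the MoE implementation chosen from the layer's own quant_config instead of the removed global "quantization" lookup. Illustrative only; build_experts and its extra keyword arguments are placeholders, not part of the patch.

    # Illustrative sketch: the experts class is selected from the layer's quant_config,
    # so an unquantized layer (quant_config=None) is never routed to an FP4-only backend.
    from sglang.srt.layers.moe.ep_moe.layer import get_moe_impl_class

    def build_experts(num_experts, top_k, quant_config, **extra_kwargs):
        experts_cls = get_moe_impl_class(quant_config)
        return experts_cls(
            num_experts=num_experts,
            top_k=top_k,
            quant_config=quant_config,
            **extra_kwargs,
        )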
diff --git a/python/sglang/srt/layers/moe/topk.py b/python/sglang/srt/layers/moe/topk.py
index b8f73473c..5b22f5a1f 100644
--- a/python/sglang/srt/layers/moe/topk.py
+++ b/python/sglang/srt/layers/moe/topk.py
@@ -19,6 +19,7 @@ import math
 from dataclasses import dataclass
 from enum import Enum, auto
 from typing import (
+    TYPE_CHECKING,
     Callable,
     NamedTuple,
     Optional,
@@ -51,6 +52,9 @@ from sglang.srt.utils import (
     is_npu,
 )

+if TYPE_CHECKING:
+    from sglang.srt.layers.quantization import QuantizationConfig
+
 try:
     from triton_kernels.routing import GatherIndx, RoutingData, ScatterIndx, routing
 except ImportError:
@@ -94,6 +98,7 @@ class TopKConfig:
     torch_native: bool = False
     routed_scaling_factor: Optional[float] = None
     apply_routed_scaling_factor_on_output: bool = False
+    output_format: Optional[TopKOutputFormat] = None


 # -------------------------------- TopKOutput ---------------------------------------
@@ -196,9 +201,10 @@ class TopK(CustomOp):
         custom_routing_function: Optional[Callable] = None,
         scoring_func: str = "softmax",
         correction_bias: Optional[torch.Tensor] = None,
+        quant_config: Optional[QuantizationConfig] = None,
         routed_scaling_factor: Optional[float] = None,
         apply_routed_scaling_factor_on_output: Optional[bool] = False,
-        force_topk: bool = False,
+        output_format: Optional[TopKOutputFormat] = None,
     ):
         # NOTE: scoring_func is not used for now, but we keep it for future use
         # see https://github.com/sgl-project/sglang/pull/4505 for more details
@@ -207,6 +213,14 @@ class TopK(CustomOp):
         if use_grouped_topk:
             assert num_expert_group is not None and topk_group is not None

+        if (
+            quant_config is not None
+            and quant_config.get_name() == "modelopt_fp4"
+            and should_use_flashinfer_trtllm_moe()
+        ):
+            # https://github.com/sgl-project/sglang/pull/9834#discussion_r2324480643
+            correction_bias = correction_bias.to(torch.bfloat16)
+
         self.topk_config = TopKConfig(
             top_k=top_k,
             use_grouped_topk=use_grouped_topk,
@@ -218,11 +232,9 @@ class TopK(CustomOp):
             correction_bias=correction_bias,
             routed_scaling_factor=routed_scaling_factor,
             apply_routed_scaling_factor_on_output=apply_routed_scaling_factor_on_output,
+            output_format=output_format,
         )

-        self.use_triton_kernels = get_moe_runner_backend().is_triton_kernel()
-        self.force_topk = force_topk
-
     def forward_native(
         self,
         hidden_states: torch.Tensor,
@@ -248,7 +260,19 @@
         num_token_non_padded: Optional[torch.Tensor] = None,
         expert_location_dispatch_info: Optional[ExpertLocationDispatchInfo] = None,
     ) -> TopKOutput:
-        if self.use_triton_kernels:
+        if self.topk_config.output_format is not None:
+            output_format = self.topk_config.output_format
+        elif get_moe_runner_backend().is_triton_kernel():
+            output_format = TopKOutputFormat.TRITON_KERNEL
+        elif (
+            should_use_flashinfer_trtllm_moe()
+            or get_moe_runner_backend().is_flashinfer_mxfp4()
+        ):
+            output_format = TopKOutputFormat.BYPASSED
+        else:
+            output_format = TopKOutputFormat.STANDARD
+
+        if output_format == TopKOutputFormat.TRITON_KERNEL:
             # renormalize=True is equivalent to sm_first=False
             routing_data, gather_idx, scatter_idx = routing(
                 router_logits,
                 self.topk_config.top_k,
                 sm_first=not self.topk_config.renormalize,
             )
             return TritonKernelTopKOutput(routing_data, gather_idx, scatter_idx)
-        elif not self.force_topk and (
-            should_use_flashinfer_trtllm_moe()
-            or get_moe_runner_backend().is_flashinfer_mxfp4()
-        ):
+        elif output_format == TopKOutputFormat.BYPASSED:
             return BypassedTopKOutput(
                 hidden_states=hidden_states,
                 router_logits=router_logits,
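Reviewer note: the topk.py hunks above replace the scattered backend checks with a single format-resolution step. A hedged sketch of that order, restated outside the class (resolve_output_format is a stand-in name; only the enum members that appear in this patch are assumed):

    # Sketch of the selection order, not the verbatim source: an explicit
    # output_format on the module wins, then the runner backend, then STANDARD.
    from sglang.srt.layers.moe.topk import TopKOutputFormat

    def resolve_output_format(explicit_format, runner_backend, trtllm_moe_enabled):
        if explicit_format is not None:
            return explicit_format
        if runner_backend.is_triton_kernel():
            return TopKOutputFormat.TRITON_KERNEL
        if trtllm_moe_enabled or runner_backend.is_flashinfer_mxfp4():
            return TopKOutputFormat.BYPASSED
        return TopKOutputFormat.STANDARD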
diff --git a/python/sglang/srt/managers/schedule_batch.py b/python/sglang/srt/managers/schedule_batch.py
index 66e6ae44e..b226b8331 100644
--- a/python/sglang/srt/managers/schedule_batch.py
+++ b/python/sglang/srt/managers/schedule_batch.py
@@ -105,7 +105,6 @@ GLOBAL_SERVER_ARGS_KEYS = [
     "weight_loader_disable_mmap",
     "enable_multimodal",
     "enable_symm_mem",
-    "quantization",
     "enable_custom_logit_processor",
     "disaggregation_mode",
 ]
diff --git a/python/sglang/srt/models/bailing_moe.py b/python/sglang/srt/models/bailing_moe.py
index 5eb3e9373..0797f4f6f 100644
--- a/python/sglang/srt/models/bailing_moe.py
+++ b/python/sglang/srt/models/bailing_moe.py
@@ -246,7 +246,7 @@ class BailingMoESparseMoeBlock(nn.Module):
             routed_scaling_factor=self.routed_scaling_factor,
         )

-        self.experts = get_moe_impl_class()(
+        self.experts = get_moe_impl_class(quant_config)(
             num_experts=self.num_experts,
             top_k=self.top_k,
             layer_id=self.layer_id,
diff --git a/python/sglang/srt/models/deepseek_v2.py b/python/sglang/srt/models/deepseek_v2.py
index dfc9d8d7c..c1573d8a2 100644
--- a/python/sglang/srt/models/deepseek_v2.py
+++ b/python/sglang/srt/models/deepseek_v2.py
@@ -65,14 +65,10 @@ from sglang.srt.layers.moe import (
     get_deepep_mode,
     get_moe_a2a_backend,
     should_use_flashinfer_cutlass_moe_fp4_allgather,
-    should_use_flashinfer_trtllm_moe,
 )
 from sglang.srt.layers.moe.ep_moe.layer import DeepEPMoE, get_moe_impl_class
-from sglang.srt.layers.moe.fused_moe_triton.layer import (
-    FusedMoE,
-    _is_fp4_quantization_enabled,
-)
-from sglang.srt.layers.moe.topk import TopK
+from sglang.srt.layers.moe.fused_moe_triton.layer import FusedMoE
+from sglang.srt.layers.moe.topk import TopK, TopKOutputFormat
 from sglang.srt.layers.quantization import deep_gemm_wrapper
 from sglang.srt.layers.quantization.base_config import QuantizationConfig
 from sglang.srt.layers.quantization.fp8_kernel import (
@@ -375,10 +371,6 @@ class DeepseekV2MoE(nn.Module):
             prefix=add_prefix("experts", prefix),
         )

-        correction_bias = self.gate.e_score_correction_bias
-        # https://github.com/sgl-project/sglang/pull/9834#discussion_r2324480643
-        if _is_fp4_quantization_enabled() and should_use_flashinfer_trtllm_moe():
-            correction_bias = correction_bias.to(torch.bfloat16)
         self.topk = TopK(
             top_k=config.num_experts_per_tok + self.num_fused_shared_experts,
             renormalize=config.norm_topk_prob,
@@ -386,10 +378,13 @@ class DeepseekV2MoE(nn.Module):
             num_expert_group=config.n_group,
             num_fused_shared_experts=self.num_fused_shared_experts,
             topk_group=config.topk_group,
-            correction_bias=correction_bias,
+            correction_bias=self.gate.e_score_correction_bias,
+            quant_config=quant_config,
             routed_scaling_factor=self.routed_scaling_factor,
             apply_routed_scaling_factor_on_output=self.experts.should_fuse_routed_scaling_factor_in_topk(),
-            force_topk=quant_config is None,
+            # Some FP4 MoE backends require the output format to be bypassed, but the MTP layers are
+            # unquantized and require the standard output format. We use quant_config to determine the format.
+            output_format=TopKOutputFormat.STANDARD if quant_config is None else None,
         )

         self.shared_experts_is_int8 = False
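Reviewer note: the DeepSeek change above is the motivating case for the new override. A hedged sketch of the intended caller pattern (make_topk is a stand-in helper; the argument values are placeholders, not the model's real configuration):

    # Sketch: quantized decoder layers let TopK pick a backend-specific output format,
    # while unquantized layers (e.g. the MTP head) pin the standard format explicitly.
    from sglang.srt.layers.moe.topk import TopK, TopKOutputFormat

    def make_topk(top_k, renormalize, quant_config, correction_bias=None):
        return TopK(
            top_k=top_k,
            renormalize=renormalize,
            correction_bias=correction_bias,
            quant_config=quant_config,
            output_format=TopKOutputFormat.STANDARD if quant_config is None else None,
        )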
diff --git a/python/sglang/srt/models/ernie4.py b/python/sglang/srt/models/ernie4.py
index 78a7b4b94..ab1b6576b 100644
--- a/python/sglang/srt/models/ernie4.py
+++ b/python/sglang/srt/models/ernie4.py
@@ -92,7 +92,7 @@ class Ernie4Moe(nn.Module):
             correction_bias=self.gate.e_score_correction_bias,
         )

-        self.experts = get_moe_impl_class()(
+        self.experts = get_moe_impl_class(quant_config)(
             num_experts=config.moe_num_experts,
             top_k=config.moe_k,
             hidden_size=config.hidden_size,
diff --git a/python/sglang/srt/models/glm4_moe.py b/python/sglang/srt/models/glm4_moe.py
index 5ae5b0af6..867ffe91b 100644
--- a/python/sglang/srt/models/glm4_moe.py
+++ b/python/sglang/srt/models/glm4_moe.py
@@ -429,7 +429,7 @@ class Glm4MoeSparseMoeBlock(DeepseekV2MoE):
             routed_scaling_factor=self.routed_scaling_factor,
         )

-        self.experts = get_moe_impl_class()(
+        self.experts = get_moe_impl_class(quant_config)(
             num_experts=config.n_routed_experts
             + self.num_fused_shared_experts
             + global_server_args_dict["ep_num_redundant_experts"],
diff --git a/python/sglang/srt/models/gpt_oss.py b/python/sglang/srt/models/gpt_oss.py
index 64efff14b..7231a5d75 100644
--- a/python/sglang/srt/models/gpt_oss.py
+++ b/python/sglang/srt/models/gpt_oss.py
@@ -121,7 +121,7 @@ class GptOssSparseMoeBlock(nn.Module):
         )
         self.top_k = config.num_experts_per_tok

-        experts_type = get_moe_impl_class()
+        experts_type = get_moe_impl_class(quant_config)
         extra_kwargs = {}
         if experts_type.__name__ == "FusedMoE":
             quant_config_name = (
diff --git a/python/sglang/srt/models/longcat_flash.py b/python/sglang/srt/models/longcat_flash.py
index 9531cb83e..3fdd8f643 100644
--- a/python/sglang/srt/models/longcat_flash.py
+++ b/python/sglang/srt/models/longcat_flash.py
@@ -260,7 +260,7 @@ class LongcatFlashMoE(nn.Module):
         )
         self.topk.forward = self.topk.forward_native

-        self.experts = get_moe_impl_class()(
+        self.experts = get_moe_impl_class(quant_config)(
             num_experts=self.num_experts,
             top_k=self.top_k,
             layer_id=self.layer_id,
@@ -853,7 +853,7 @@ class LongcatFlashForCausalLM(nn.Module):

         # Params for weights, fp8 weight scales, fp8 activation scales
         # (param_name, weight_name, expert_id, shard_id)
-        expert_params_mapping = get_moe_impl_class().make_expert_params_mapping(
+        expert_params_mapping = FusedMoE.make_expert_params_mapping(
             ckpt_gate_proj_name="gate_proj",
             ckpt_down_proj_name="down_proj",
             ckpt_up_proj_name="up_proj",
diff --git a/python/sglang/srt/models/qwen2_moe.py b/python/sglang/srt/models/qwen2_moe.py
index 9291146d9..0375ac478 100644
--- a/python/sglang/srt/models/qwen2_moe.py
+++ b/python/sglang/srt/models/qwen2_moe.py
@@ -143,7 +143,7 @@ class Qwen2MoeSparseMoeBlock(nn.Module):
             renormalize=config.norm_topk_prob,
         )

-        self.experts = get_moe_impl_class()(
+        self.experts = get_moe_impl_class(quant_config)(
             layer_id=self.layer_id,
             top_k=config.num_experts_per_tok,
             num_experts=config.num_experts,
diff --git a/python/sglang/srt/models/qwen3_moe.py b/python/sglang/srt/models/qwen3_moe.py
index c1c4c3638..9d92a3b13 100644
--- a/python/sglang/srt/models/qwen3_moe.py
+++ b/python/sglang/srt/models/qwen3_moe.py
@@ -98,7 +98,7 @@ class Qwen3MoeSparseMoeBlock(nn.Module):
             use_grouped_topk=False,
         )

-        self.experts = get_moe_impl_class()(
+        self.experts = get_moe_impl_class(quant_config)(
             num_experts=config.num_experts
             + global_server_args_dict["ep_num_redundant_experts"],
             top_k=config.num_experts_per_tok,
diff --git a/python/sglang/srt/models/qwen3_next.py b/python/sglang/srt/models/qwen3_next.py
index 52d8c6faf..006ce4f91 100644
--- a/python/sglang/srt/models/qwen3_next.py
+++ b/python/sglang/srt/models/qwen3_next.py
@@ -30,7 +30,7 @@ from sglang.srt.layers.linear import (
     RowParallelLinear,
 )
 from sglang.srt.layers.logits_processor import LogitsProcessor
-from sglang.srt.layers.moe.ep_moe.layer import get_moe_impl_class
+from sglang.srt.layers.moe.fused_moe_triton.layer import FusedMoE
 from sglang.srt.layers.quantization.base_config import QuantizationConfig
 from sglang.srt.layers.radix_attention import RadixAttention
 from sglang.srt.layers.rotary_embedding import get_rope
@@ -935,7 +935,7 @@ class Qwen3NextForCausalLM(nn.Module):

         # Params for weights, fp8 weight scales, fp8 activation scales
         # (param_name, weight_name, expert_id, shard_id)
-        expert_params_mapping = get_moe_impl_class().make_expert_params_mapping(
+        expert_params_mapping = FusedMoE.make_expert_params_mapping(
             ckpt_gate_proj_name="gate_proj",
             ckpt_down_proj_name="down_proj",
             ckpt_up_proj_name="up_proj",
diff --git a/python/sglang/srt/models/step3_vl.py b/python/sglang/srt/models/step3_vl.py
index a93bf69e7..626406da1 100644
--- a/python/sglang/srt/models/step3_vl.py
+++ b/python/sglang/srt/models/step3_vl.py
@@ -133,7 +133,7 @@ class Step3TextMoEMLP(nn.Module):
             use_grouped_topk=False,
         )

-        self.experts = get_moe_impl_class()(
+        self.experts = get_moe_impl_class(quant_config)(
             num_experts=config.moe_num_experts,
             top_k=config.moe_top_k,
             hidden_size=config.hidden_size,
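Reviewer note on the weight-loading hunks (longcat_flash.py and qwen3_next.py above): the checkpoint expert-parameter mapping is now taken from FusedMoE as a classmethod rather than from whichever MoE implementation the runtime selected. A hedged sketch of the call; num_experts is an assumed keyword, since the hunks are truncated before that argument:

    # Sketch: the expert parameter mapping no longer depends on the runtime MoE class.
    from sglang.srt.layers.moe.fused_moe_triton.layer import FusedMoE

    def expert_mapping(num_experts):
        return FusedMoE.make_expert_params_mapping(
            ckpt_gate_proj_name="gate_proj",
            ckpt_down_proj_name="down_proj",
            ckpt_up_proj_name="up_proj",
            num_experts=num_experts,  # assumed keyword; not shown in the truncated hunks
        )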