From 18da2c96ec092e41f0b8b8dbac4af7b5218ec8f2 Mon Sep 17 00:00:00 2001
From: Kaixi Hou
Date: Thu, 21 Aug 2025 00:54:01 -0700
Subject: [PATCH] [NVIDIA] Fix trtllm fp4 moe backend when used in MTP (#9384)

---
 python/sglang/srt/layers/moe/ep_moe/layer.py           | 6 +++++-
 python/sglang/srt/layers/moe/fused_moe_triton/layer.py | 2 ++
 python/sglang/srt/layers/moe/topk.py                   | 4 +++-
 python/sglang/srt/models/deepseek_v2.py                | 3 ++-
 4 files changed, 12 insertions(+), 3 deletions(-)

diff --git a/python/sglang/srt/layers/moe/ep_moe/layer.py b/python/sglang/srt/layers/moe/ep_moe/layer.py
index 97e16a90e..01fdf686a 100644
--- a/python/sglang/srt/layers/moe/ep_moe/layer.py
+++ b/python/sglang/srt/layers/moe/ep_moe/layer.py
@@ -783,13 +783,17 @@ class DeepEPMoE(EPMoE):
         return hidden_states
 
 
-def get_moe_impl_class():
+def get_moe_impl_class(quant_config: Optional[QuantizationConfig] = None):
     if get_moe_a2a_backend().is_deepep():
         return DeepEPMoE
 
     # NEW: Direct FP4 detection (bypasses EP requirements)
     # Check for FP4 quantization with TRTLLM flag, regardless of EP
     if get_moe_runner_backend().is_flashinfer_trtllm():
+        # FlashInferFP4MoE must be paired with ModelOptNvFp4FusedMoEMethod.
+        # If UnquantizedFusedMoEMethod is detected, fall back to FusedMoE instead.
+        if quant_config is None:
+            return FusedMoE
         try:
             # Check the quantization argument directly
             quantization = global_server_args_dict.get("quantization")
diff --git a/python/sglang/srt/layers/moe/fused_moe_triton/layer.py b/python/sglang/srt/layers/moe/fused_moe_triton/layer.py
index 504aeb2fe..2a00ddd00 100644
--- a/python/sglang/srt/layers/moe/fused_moe_triton/layer.py
+++ b/python/sglang/srt/layers/moe/fused_moe_triton/layer.py
@@ -1008,6 +1008,8 @@ class FlashInferFP4MoE(FusedMoE):
             hidden_states: Input tensor
             topk_output: TopKOutput object with Bypassed format
         """
+        assert isinstance(self.quant_method, ModelOptNvFp4FusedMoEMethod)
+        assert TopKOutputChecker.format_is_bypassed(topk_output)
 
         router_logits = topk_output.router_logits
 
diff --git a/python/sglang/srt/layers/moe/topk.py b/python/sglang/srt/layers/moe/topk.py
index e3c7018bb..48296752d 100644
--- a/python/sglang/srt/layers/moe/topk.py
+++ b/python/sglang/srt/layers/moe/topk.py
@@ -198,6 +198,7 @@ class TopK(CustomOp):
         correction_bias: Optional[torch.Tensor] = None,
         routed_scaling_factor: Optional[float] = None,
         apply_routed_scaling_factor_on_output: Optional[bool] = False,
+        force_topk: bool = False,
     ):
         # NOTE: scoring_func is not used for now, but we keep it for future use
         # see https://github.com/sgl-project/sglang/pull/4505 for more details
@@ -220,6 +221,7 @@
         )
 
         self.use_triton_kernels = get_moe_runner_backend().is_triton_kernel()
+        self.force_topk = force_topk
 
     def forward_native(
         self,
@@ -254,7 +256,7 @@
                 sm_first=not self.topk_config.renormalize,
             )
             return TritonKernelTopKOutput(routing_data, gather_idx, scatter_idx)
-        elif (
+        elif not self.force_topk and (
             should_use_flashinfer_trtllm_moe()
             or get_moe_runner_backend().is_flashinfer_mxfp4()
         ):
diff --git a/python/sglang/srt/models/deepseek_v2.py b/python/sglang/srt/models/deepseek_v2.py
index eabd56594..434cec4b1 100644
--- a/python/sglang/srt/models/deepseek_v2.py
+++ b/python/sglang/srt/models/deepseek_v2.py
@@ -319,7 +319,7 @@ class DeepseekV2MoE(nn.Module):
             config=config, prefix=add_prefix("gate", prefix), is_nextn=is_nextn
         )
 
-        self.experts = get_moe_impl_class()(
+        self.experts = get_moe_impl_class(quant_config)(
             num_experts=config.n_routed_experts
             + self.num_fused_shared_experts
             + global_server_args_dict["ep_num_redundant_experts"],
@@ -343,6 +343,7 @@ class DeepseekV2MoE(nn.Module):
             correction_bias=self.gate.e_score_correction_bias,
             routed_scaling_factor=self.routed_scaling_factor,
             apply_routed_scaling_factor_on_output=self.experts.should_fuse_routed_scaling_factor_in_topk(),
+            force_topk=quant_config is None,
         )
 
         self.shared_experts_is_int8 = False
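
Note: below is a minimal, self-contained sketch of the selection behavior this patch establishes. The class and function names (pick_moe_impl, trtllm_fp4) are toy stand-ins for illustration, not the actual sglang API, and the real get_moe_impl_class additionally checks the DeepEP backend and the server's quantization argument. The point it demonstrates: when the TRTLLM FP4 runner backend is enabled but a layer is built without a quantization config, as with the unquantized MTP/nextn draft experts, the impl class falls back to plain FusedMoE, while force_topk=True keeps TopK on the standard (non-bypassed) routing path.

    from typing import Optional

    class FusedMoE: ...                    # stand-in for the Triton fused-MoE layer
    class FlashInferFP4MoE(FusedMoE): ...  # stand-in for the TRTLLM FP4 runner

    def pick_moe_impl(quant_config: Optional[object], trtllm_fp4: bool) -> type:
        # Mirrors the patched logic: the FP4 runner is only selected when a
        # quantization config exists to supply ModelOptNvFp4FusedMoEMethod;
        # otherwise fall back to FusedMoE.
        if trtllm_fp4 and quant_config is not None:
            return FlashInferFP4MoE
        return FusedMoE

    # MTP draft layer: built with quant_config=None -> FusedMoE, paired with
    # TopK(force_topk=True) so routing stays in the standard top-k format.
    assert pick_moe_impl(None, trtllm_fp4=True) is FusedMoE
    # Target-model layer with an NVFP4 quant config -> FP4 runner.
    assert pick_moe_impl(object(), trtllm_fp4=True) is FlashInferFP4MoE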