From 58c468f4045e89981c9f02b6e46a2a49e0fc4b11 Mon Sep 17 00:00:00 2001 From: Trevor Morris Date: Fri, 25 Jul 2025 16:40:23 -0700 Subject: [PATCH] Fix FP4 MoE accuracy from missing routed_scaling_factor (#8333) --- .../sglang/srt/layers/quantization/modelopt_quant.py | 12 ++++++++---- python/sglang/srt/server_args.py | 4 ---- 2 files changed, 8 insertions(+), 8 deletions(-) diff --git a/python/sglang/srt/layers/quantization/modelopt_quant.py b/python/sglang/srt/layers/quantization/modelopt_quant.py index 73de5b0d1..9087f79b0 100644 --- a/python/sglang/srt/layers/quantization/modelopt_quant.py +++ b/python/sglang/srt/layers/quantization/modelopt_quant.py @@ -952,7 +952,6 @@ class ModelOptNvFp4FusedMoEMethod(FusedMoEMethodBase): tp_rank: Optional[int] = None, tp_size: Optional[int] = None, ) -> torch.Tensor: - assert activation == "silu", "Only SiLU activation is supported." if self.enable_flashinfer_moe: @@ -982,13 +981,15 @@ class ModelOptNvFp4FusedMoEMethod(FusedMoEMethodBase): tp_size=tp_size, tp_rank=tp_rank, tune_max_num_tokens=next_power_of_2(x.shape[0]), - ) - return output[0] + )[0] + if routed_scaling_factor is not None: + output *= routed_scaling_factor + return output from sglang.srt.layers.moe.cutlass_moe import cutlass_moe_fp4 topk_weights, topk_ids, _ = topk_output - return cutlass_moe_fp4( + output = cutlass_moe_fp4( a=x, a1_gscale=layer.w13_input_scale_quant, w1_fp4=layer.w13_weight, @@ -1003,3 +1004,6 @@ class ModelOptNvFp4FusedMoEMethod(FusedMoEMethodBase): params=layer.cutlass_moe_params, apply_router_weight_on_input=apply_router_weight_on_input, ).to(x.dtype) + if routed_scaling_factor is not None: + output *= routed_scaling_factor + return output diff --git a/python/sglang/srt/server_args.py b/python/sglang/srt/server_args.py index 107c63646..6fec17bc0 100644 --- a/python/sglang/srt/server_args.py +++ b/python/sglang/srt/server_args.py @@ -433,10 +433,6 @@ class ServerArgs: self.quantization == "modelopt_fp4" ), "modelopt_fp4 quantization is required for Flashinfer MOE" os.environ["TRTLLM_ENABLE_PDL"] = "1" - self.disable_shared_experts_fusion = True - logger.warning( - f"Flashinfer MoE is enabled. Shared expert fusion is disabled." - ) # DeepEP MoE if self.enable_deepep_moe: