From a6cc86df9d3e7e6fd6b7704d221af92a6dbe8d93 Mon Sep 17 00:00:00 2001
From: Trevor Morris
Date: Tue, 30 Sep 2025 10:33:12 -0700
Subject: [PATCH] Fix DSR1 accuracy for flashinfer_trtllm MoE with FP8 quantization (#11081)

---
 python/sglang/srt/layers/moe/fused_moe_triton/layer.py | 6 +++---
 python/sglang/srt/server_args.py                       | 2 +-
 2 files changed, 4 insertions(+), 4 deletions(-)

diff --git a/python/sglang/srt/layers/moe/fused_moe_triton/layer.py b/python/sglang/srt/layers/moe/fused_moe_triton/layer.py
index 241f8b142..acddcc652 100644
--- a/python/sglang/srt/layers/moe/fused_moe_triton/layer.py
+++ b/python/sglang/srt/layers/moe/fused_moe_triton/layer.py
@@ -575,9 +575,9 @@ class FusedMoE(torch.nn.Module):
             )
 
         # Flashinfer assumes w31 format for w13_weight. Same for the scales.
-        if (
-            should_use_flashinfer_trtllm_moe()
-            and self.quant_method.__class__.__name__ == "ModelOptNvFp4FusedMoEMethod"
+        if should_use_flashinfer_trtllm_moe() and (
+            isinstance(self.quant_method, ModelOptNvFp4FusedMoEMethod)
+            or isinstance(self.quant_method, Fp8MoEMethod)
         ):
             shard_id = {"w1": "w3", "w3": "w1", "w2": "w2"}[shard_id]
 
diff --git a/python/sglang/srt/server_args.py b/python/sglang/srt/server_args.py
index 18269528a..d647e10ec 100644
--- a/python/sglang/srt/server_args.py
+++ b/python/sglang/srt/server_args.py
@@ -916,7 +916,7 @@ class ServerArgs:
         if self.moe_runner_backend == "flashinfer_trtllm":
             assert (
                 self.quantization == "modelopt_fp4" or self.quantization == "fp8"
-            ), "modelopt_fp4 quantization is required for Flashinfer TRTLLM MoE"
+            ), "modelopt_fp4 or fp8 quantization is required for Flashinfer TRTLLM MoE"
             self.disable_shared_experts_fusion = True
             logger.warning(
                 "FlashInfer TRTLLM MoE is enabled. --disable-shared-experts-fusion is automatically set."