From 76915d68a8f8a45e39a51885b7c64619d2968ac0 Mon Sep 17 00:00:00 2001 From: Xiaoyu Zhang <35585791+BBuf@users.noreply.github.com> Date: Fri, 8 Aug 2025 13:52:09 +0800 Subject: [PATCH] Fix enable flashinfer mxfp4 moe bf16 check (#8950) --- python/sglang/srt/server_args.py | 17 +++++++++-------- 1 file changed, 9 insertions(+), 8 deletions(-) diff --git a/python/sglang/srt/server_args.py b/python/sglang/srt/server_args.py index 391d8e714..217abc337 100644 --- a/python/sglang/srt/server_args.py +++ b/python/sglang/srt/server_args.py @@ -476,8 +476,15 @@ class ServerArgs: self.attention_backend == "trtllm_mha" or self.attention_backend == "triton" ) + quantization_config = getattr( + self.get_hf_config(), "quantization_config", None + ) + is_mxfp4_quant_format = ( + quantization_config is not None + and quantization_config.get("quant_method") == "mxfp4" + ) - if is_sm100_supported(): + if is_sm100_supported() and is_mxfp4_quant_format: self.enable_flashinfer_mxfp4_moe = True self.enable_triton_kernel_moe = False else: @@ -485,13 +492,7 @@ class ServerArgs: self.disable_hybrid_swa_memory = True - quantization_config = getattr( - self.get_hf_config(), "quantization_config", None - ) - if ( - quantization_config is not None - and quantization_config.get("quant_method") == "mxfp4" - ): + if is_mxfp4_quant_format: # use bf16 for mxfp4 triton kernels self.dtype = "bfloat16"