diff --git a/python/sglang/srt/server_args.py b/python/sglang/srt/server_args.py index 391d8e714..217abc337 100644 --- a/python/sglang/srt/server_args.py +++ b/python/sglang/srt/server_args.py @@ -476,8 +476,15 @@ class ServerArgs: self.attention_backend == "trtllm_mha" or self.attention_backend == "triton" ) + quantization_config = getattr( + self.get_hf_config(), "quantization_config", None + ) + is_mxfp4_quant_format = ( + quantization_config is not None + and quantization_config.get("quant_method") == "mxfp4" + ) - if is_sm100_supported(): + if is_sm100_supported() and is_mxfp4_quant_format: self.enable_flashinfer_mxfp4_moe = True self.enable_triton_kernel_moe = False else: @@ -485,13 +492,7 @@ class ServerArgs: self.disable_hybrid_swa_memory = True - quantization_config = getattr( - self.get_hf_config(), "quantization_config", None - ) - if ( - quantization_config is not None - and quantization_config.get("quant_method") == "mxfp4" - ): + if is_mxfp4_quant_format: # use bf16 for mxfp4 triton kernels self.dtype = "bfloat16"