Fix enable flashinfer mxfp4 moe bf16 check (#8950)
@@ -476,8 +476,15 @@ class ServerArgs:
                 self.attention_backend == "trtllm_mha"
                 or self.attention_backend == "triton"
             )
+            quantization_config = getattr(
+                self.get_hf_config(), "quantization_config", None
+            )
+            is_mxfp4_quant_format = (
+                quantization_config is not None
+                and quantization_config.get("quant_method") == "mxfp4"
+            )
 
-            if is_sm100_supported():
+            if is_sm100_supported() and is_mxfp4_quant_format:
                 self.enable_flashinfer_mxfp4_moe = True
                 self.enable_triton_kernel_moe = False
             else:
@@ -485,13 +492,7 @@ class ServerArgs:
             self.disable_hybrid_swa_memory = True
 
-            quantization_config = getattr(
-                self.get_hf_config(), "quantization_config", None
-            )
-            if (
-                quantization_config is not None
-                and quantization_config.get("quant_method") == "mxfp4"
-            ):
+            if is_mxfp4_quant_format:
                 # use bf16 for mxfp4 triton kernels
                 self.dtype = "bfloat16"
 
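In effect, the patch computes is_mxfp4_quant_format once, gates the FlashInfer mxfp4 MoE path on it together with SM100 support, and reuses the same flag for the bf16 dtype override, so a non-mxfp4 checkpoint no longer enables the FlashInfer path. Below is a minimal, self-contained sketch of the resulting control flow; the function name apply_mxfp4_defaults, the SimpleNamespace stand-ins for the server args and the HF config, and the explicit sm100_supported parameter are illustrative assumptions, not SGLang's real API.

from types import SimpleNamespace

def apply_mxfp4_defaults(args, hf_config, sm100_supported):
    # Sketch of the post-patch control flow (not the real SGLang code path):
    # compute the mxfp4 check once and reuse it for both decisions.
    quantization_config = getattr(hf_config, "quantization_config", None)
    is_mxfp4_quant_format = (
        quantization_config is not None
        and quantization_config.get("quant_method") == "mxfp4"
    )

    if sm100_supported and is_mxfp4_quant_format:
        # Only enable the FlashInfer mxfp4 MoE path for mxfp4 checkpoints.
        args.enable_flashinfer_mxfp4_moe = True
        args.enable_triton_kernel_moe = False

    if is_mxfp4_quant_format:
        # use bf16 for mxfp4 triton kernels (same override as the patch)
        args.dtype = "bfloat16"

# A non-mxfp4 checkpoint on SM100 hardware no longer flips the flag,
# which is the behavior this commit fixes.
args = SimpleNamespace(
    enable_flashinfer_mxfp4_moe=False,
    enable_triton_kernel_moe=True,
    dtype="auto",
)
hf_config = SimpleNamespace(quantization_config={"quant_method": "fp8"})
apply_mxfp4_defaults(args, hf_config, sm100_supported=True)
assert args.enable_flashinfer_mxfp4_moe is False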