Fix enable flashinfer mxfp4 moe bf16 check (#8950)
This commit is contained in:
@@ -476,8 +476,15 @@ class ServerArgs:
     self.attention_backend == "trtllm_mha"
     or self.attention_backend == "triton"
 )
+quantization_config = getattr(
+    self.get_hf_config(), "quantization_config", None
+)
+is_mxfp4_quant_format = (
+    quantization_config is not None
+    and quantization_config.get("quant_method") == "mxfp4"
+)
 
-if is_sm100_supported():
+if is_sm100_supported() and is_mxfp4_quant_format:
     self.enable_flashinfer_mxfp4_moe = True
     self.enable_triton_kernel_moe = False
 else:
@@ -485,13 +492,7 @@ class ServerArgs:
 
 self.disable_hybrid_swa_memory = True
 
-quantization_config = getattr(
-    self.get_hf_config(), "quantization_config", None
-)
-if (
-    quantization_config is not None
-    and quantization_config.get("quant_method") == "mxfp4"
-):
+if is_mxfp4_quant_format:
     # use bf16 for mxfp4 triton kernels
     self.dtype = "bfloat16"
 
Reference in New Issue
Block a user