Fix DSR1 accuracy for flashinfer_trtllm MoE with FP8 quantization (#11081)
This commit is contained in:
@@ -575,9 +575,9 @@ class FusedMoE(torch.nn.Module):
|
|||||||
)
|
)
|
||||||
|
|
||||||
# Flashinfer assumes w31 format for w13_weight. Same for the scales.
|
# Flashinfer assumes w31 format for w13_weight. Same for the scales.
|
||||||
if (
|
if should_use_flashinfer_trtllm_moe() and (
|
||||||
should_use_flashinfer_trtllm_moe()
|
isinstance(self.quant_method, ModelOptNvFp4FusedMoEMethod)
|
||||||
and self.quant_method.__class__.__name__ == "ModelOptNvFp4FusedMoEMethod"
|
or isinstance(self.quant_method, Fp8MoEMethod)
|
||||||
):
|
):
|
||||||
shard_id = {"w1": "w3", "w3": "w1", "w2": "w2"}[shard_id]
|
shard_id = {"w1": "w3", "w3": "w1", "w2": "w2"}[shard_id]
|
||||||
|
|
||||||
|
|||||||
@@ -916,7 +916,7 @@ class ServerArgs:
|
|||||||
if self.moe_runner_backend == "flashinfer_trtllm":
|
if self.moe_runner_backend == "flashinfer_trtllm":
|
||||||
assert (
|
assert (
|
||||||
self.quantization == "modelopt_fp4" or self.quantization == "fp8"
|
self.quantization == "modelopt_fp4" or self.quantization == "fp8"
|
||||||
), "modelopt_fp4 quantization is required for Flashinfer TRTLLM MoE"
|
), "modelopt_fp4 or fp8 quantization is required for Flashinfer TRTLLM MoE"
|
||||||
self.disable_shared_experts_fusion = True
|
self.disable_shared_experts_fusion = True
|
||||||
logger.warning(
|
logger.warning(
|
||||||
"FlashInfer TRTLLM MoE is enabled. --disable-shared-experts-fusion is automatically set."
|
"FlashInfer TRTLLM MoE is enabled. --disable-shared-experts-fusion is automatically set."
|
||||||
|
|||||||
Reference in New Issue
Block a user