[NVIDIA] Fix trtllm fp4 moe backend when used in MTP (#9384)
This commit is contained in:
@@ -783,13 +783,17 @@ class DeepEPMoE(EPMoE):
|
|||||||
return hidden_states
|
return hidden_states
|
||||||
|
|
||||||
|
|
||||||
def get_moe_impl_class():
|
def get_moe_impl_class(quant_config: Optional[QuantizationConfig] = None):
|
||||||
if get_moe_a2a_backend().is_deepep():
|
if get_moe_a2a_backend().is_deepep():
|
||||||
return DeepEPMoE
|
return DeepEPMoE
|
||||||
|
|
||||||
# NEW: Direct FP4 detection (bypasses EP requirements)
|
# NEW: Direct FP4 detection (bypasses EP requirements)
|
||||||
# Check for FP4 quantization with TRTLLM flag, regardless of EP
|
# Check for FP4 quantization with TRTLLM flag, regardless of EP
|
||||||
if get_moe_runner_backend().is_flashinfer_trtllm():
|
if get_moe_runner_backend().is_flashinfer_trtllm():
|
||||||
|
# FlashInferFP4MoE must be paired with ModelOptNvFp4FusedMoEMethod.
|
||||||
|
# When quant_config is None (presumably the UnquantizedFusedMoEMethod case), fall back to FusedMoE instead.
|
||||||
|
if quant_config is None:
|
||||||
|
return FusedMoE
|
||||||
try:
|
try:
|
||||||
# Check the quantization argument directly
|
# Check the quantization argument directly
|
||||||
quantization = global_server_args_dict.get("quantization")
|
quantization = global_server_args_dict.get("quantization")
|
||||||
|
|||||||
@@ -1008,6 +1008,8 @@ class FlashInferFP4MoE(FusedMoE):
|
|||||||
hidden_states: Input tensor
|
hidden_states: Input tensor
|
||||||
topk_output: TopKOutput object with Bypassed format
|
topk_output: TopKOutput object with Bypassed format
|
||||||
"""
|
"""
|
||||||
|
assert isinstance(self.quant_method, ModelOptNvFp4FusedMoEMethod)
|
||||||
|
|
||||||
assert TopKOutputChecker.format_is_bypassed(topk_output)
|
assert TopKOutputChecker.format_is_bypassed(topk_output)
|
||||||
|
|
||||||
router_logits = topk_output.router_logits
|
router_logits = topk_output.router_logits
|
||||||
|
|||||||
@@ -198,6 +198,7 @@ class TopK(CustomOp):
|
|||||||
correction_bias: Optional[torch.Tensor] = None,
|
correction_bias: Optional[torch.Tensor] = None,
|
||||||
routed_scaling_factor: Optional[float] = None,
|
routed_scaling_factor: Optional[float] = None,
|
||||||
apply_routed_scaling_factor_on_output: Optional[bool] = False,
|
apply_routed_scaling_factor_on_output: Optional[bool] = False,
|
||||||
|
force_topk: bool = False,
|
||||||
):
|
):
|
||||||
# NOTE: scoring_func is not used for now, but we keep it for future use
|
# NOTE: scoring_func is not used for now, but we keep it for future use
|
||||||
# see https://github.com/sgl-project/sglang/pull/4505 for more details
|
# see https://github.com/sgl-project/sglang/pull/4505 for more details
|
||||||
@@ -220,6 +221,7 @@ class TopK(CustomOp):
|
|||||||
)
|
)
|
||||||
|
|
||||||
self.use_triton_kernels = get_moe_runner_backend().is_triton_kernel()
|
self.use_triton_kernels = get_moe_runner_backend().is_triton_kernel()
|
||||||
|
self.force_topk = force_topk
|
||||||
|
|
||||||
def forward_native(
|
def forward_native(
|
||||||
self,
|
self,
|
||||||
@@ -254,7 +256,7 @@ class TopK(CustomOp):
|
|||||||
sm_first=not self.topk_config.renormalize,
|
sm_first=not self.topk_config.renormalize,
|
||||||
)
|
)
|
||||||
return TritonKernelTopKOutput(routing_data, gather_idx, scatter_idx)
|
return TritonKernelTopKOutput(routing_data, gather_idx, scatter_idx)
|
||||||
elif (
|
elif not self.force_topk and (
|
||||||
should_use_flashinfer_trtllm_moe()
|
should_use_flashinfer_trtllm_moe()
|
||||||
or get_moe_runner_backend().is_flashinfer_mxfp4()
|
or get_moe_runner_backend().is_flashinfer_mxfp4()
|
||||||
):
|
):
|
||||||
|
|||||||
@@ -319,7 +319,7 @@ class DeepseekV2MoE(nn.Module):
|
|||||||
config=config, prefix=add_prefix("gate", prefix), is_nextn=is_nextn
|
config=config, prefix=add_prefix("gate", prefix), is_nextn=is_nextn
|
||||||
)
|
)
|
||||||
|
|
||||||
self.experts = get_moe_impl_class()(
|
self.experts = get_moe_impl_class(quant_config)(
|
||||||
num_experts=config.n_routed_experts
|
num_experts=config.n_routed_experts
|
||||||
+ self.num_fused_shared_experts
|
+ self.num_fused_shared_experts
|
||||||
+ global_server_args_dict["ep_num_redundant_experts"],
|
+ global_server_args_dict["ep_num_redundant_experts"],
|
||||||
@@ -343,6 +343,7 @@ class DeepseekV2MoE(nn.Module):
|
|||||||
correction_bias=self.gate.e_score_correction_bias,
|
correction_bias=self.gate.e_score_correction_bias,
|
||||||
routed_scaling_factor=self.routed_scaling_factor,
|
routed_scaling_factor=self.routed_scaling_factor,
|
||||||
apply_routed_scaling_factor_on_output=self.experts.should_fuse_routed_scaling_factor_in_topk(),
|
apply_routed_scaling_factor_on_output=self.experts.should_fuse_routed_scaling_factor_in_topk(),
|
||||||
|
force_topk=quant_config is None,
|
||||||
)
|
)
|
||||||
|
|
||||||
self.shared_experts_is_int8 = False
|
self.shared_experts_is_int8 = False
|
||||||
|
|||||||
Reference in New Issue
Block a user