[NVIDIA] Fix trtllm fp4 moe backend when used in MTP (#9384)
@@ -783,13 +783,17 @@ class DeepEPMoE(EPMoE):
         return hidden_states
 
 
-def get_moe_impl_class():
+def get_moe_impl_class(quant_config: Optional[QuantizationConfig] = None):
     if get_moe_a2a_backend().is_deepep():
         return DeepEPMoE
 
     # NEW: Direct FP4 detection (bypasses EP requirements)
     # Check for FP4 quantization with TRTLLM flag, regardless of EP
     if get_moe_runner_backend().is_flashinfer_trtllm():
+        # FlashInferFP4MoE must be paired with ModelOptNvFp4FusedMoEMethod.
+        # If UnquantizedFusedMoEMethod is detected, fall back to FusedMoE instead.
+        if quant_config is None:
+            return FusedMoE
         try:
             # Check the quantization argument directly
             quantization = global_server_args_dict.get("quantization")
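
The hunk above changes which MoE implementation class is selected when the flashinfer-trtllm runner backend is active. Below is a minimal, self-contained sketch of that dispatch, with stand-in classes and the backends reduced to booleans; the try-block is truncated in the hunk, so its body is simplified here to a direct return:

    from typing import Optional

    class DeepEPMoE: ...
    class FusedMoE: ...
    class FlashInferFP4MoE(FusedMoE): ...

    def select_moe_impl(
        quant_config: Optional[object],
        a2a_is_deepep: bool,
        runner_is_flashinfer_trtllm: bool,
    ) -> type:
        # Mirrors get_moe_impl_class above; backend objects are plain bools here.
        if a2a_is_deepep:
            return DeepEPMoE
        if runner_is_flashinfer_trtllm:
            # An unquantized model (quant_config is None, e.g. an MTP draft
            # layer) cannot use the FP4 path and falls back to FusedMoE.
            if quant_config is None:
                return FusedMoE
            return FlashInferFP4MoE  # simplification of the truncated try-block
        return FusedMoE

    assert select_moe_impl(None, False, True) is FusedMoE
    assert select_moe_impl(object(), False, True) is FlashInferFP4MoE
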
@@ -1008,6 +1008,8 @@ class FlashInferFP4MoE(FusedMoE):
             hidden_states: Input tensor
             topk_output: TopKOutput object with Bypassed format
         """
+        assert isinstance(self.quant_method, ModelOptNvFp4FusedMoEMethod)
+
         assert TopKOutputChecker.format_is_bypassed(topk_output)
 
         router_logits = topk_output.router_logits
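
The new assert at the top of FlashInferFP4MoE.forward makes a mismatched pairing fail fast with a clear AssertionError instead of crashing deep inside the fused FP4 kernel. A self-contained sketch of the pattern, with stand-in class names:

    class ModelOptNvFp4FusedMoEMethod: ...
    class UnquantizedFusedMoEMethod: ...

    class Fp4MoeSketch:
        def __init__(self, quant_method) -> None:
            self.quant_method = quant_method

        def forward(self, hidden_states):
            # Guard the invariant established by get_moe_impl_class:
            # this class only runs with NVFP4-quantized weights.
            assert isinstance(self.quant_method, ModelOptNvFp4FusedMoEMethod)
            return hidden_states

    Fp4MoeSketch(ModelOptNvFp4FusedMoEMethod()).forward("x")    # passes
    try:
        Fp4MoeSketch(UnquantizedFusedMoEMethod()).forward("x")  # fails fast
    except AssertionError:
        print("caught misconfigured quant_method early")
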
@@ -198,6 +198,7 @@ class TopK(CustomOp):
         correction_bias: Optional[torch.Tensor] = None,
         routed_scaling_factor: Optional[float] = None,
         apply_routed_scaling_factor_on_output: Optional[bool] = False,
+        force_topk: bool = False,
     ):
         # NOTE: scoring_func is not used for now, but we keep it for future use
         # see https://github.com/sgl-project/sglang/pull/4505 for more details
@@ -220,6 +221,7 @@ class TopK(CustomOp):
         )
 
         self.use_triton_kernels = get_moe_runner_backend().is_triton_kernel()
+        self.force_topk = force_topk
 
     def forward_native(
         self,
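
These two hunks plumb a new force_topk flag through TopK: a keyword argument with a False default, stored on the instance, so existing call sites keep the old behaviour. A stand-in sketch of just that plumbing:

    class TopKSketch:
        # Only the new flag is modelled; the real TopK takes many more args.
        def __init__(self, top_k: int, force_topk: bool = False) -> None:
            self.top_k = top_k
            self.force_topk = force_topk

    assert TopKSketch(8).force_topk is False          # default: unchanged
    assert TopKSketch(8, force_topk=True).force_topk  # MTP draft layer opts in
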
@@ -254,7 +256,7 @@ class TopK(CustomOp):
                 sm_first=not self.topk_config.renormalize,
             )
             return TritonKernelTopKOutput(routing_data, gather_idx, scatter_idx)
-        elif (
+        elif not self.force_topk and (
             should_use_flashinfer_trtllm_moe()
             or get_moe_runner_backend().is_flashinfer_mxfp4()
         ):
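
With `not self.force_topk` added to the elif, a TopK built with force_topk=True skips the bypassed flashinfer-trtllm routing path even when that backend is enabled, and falls through to ordinary top-k selection. A runnable sketch of the branch, with the fused kernel stubbed out:

    import torch

    def route(router_logits: torch.Tensor, top_k: int,
              force_topk: bool, trtllm_enabled: bool):
        if not force_topk and trtllm_enabled:
            # Bypassed format: raw logits go straight to the fused trtllm
            # kernel, which does its own routing (stubbed out here).
            return ("bypassed", router_logits)
        # Standard path: explicit top-k over the router logits.
        vals, idx = torch.topk(router_logits, top_k, dim=-1)
        return ("standard", vals, idx)

    logits = torch.randn(4, 16)
    assert route(logits, 2, force_topk=False, trtllm_enabled=True)[0] == "bypassed"
    assert route(logits, 2, force_topk=True, trtllm_enabled=True)[0] == "standard"
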
@@ -319,7 +319,7 @@ class DeepseekV2MoE(nn.Module):
             config=config, prefix=add_prefix("gate", prefix), is_nextn=is_nextn
         )
 
-        self.experts = get_moe_impl_class()(
+        self.experts = get_moe_impl_class(quant_config)(
             num_experts=config.n_routed_experts
             + self.num_fused_shared_experts
             + global_server_args_dict["ep_num_redundant_experts"],
@@ -343,6 +343,7 @@ class DeepseekV2MoE(nn.Module):
             correction_bias=self.gate.e_score_correction_bias,
             routed_scaling_factor=self.routed_scaling_factor,
             apply_routed_scaling_factor_on_output=self.experts.should_fuse_routed_scaling_factor_in_topk(),
+            force_topk=quant_config is None,
         )
 
         self.shared_experts_is_int8 = False
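
Taken together, the two DeepseekV2MoE call-site changes mean an unquantized MTP/NextN draft layer (quant_config is None) both receives a non-FP4 experts class and forces the standard top-k path, so FP4-only code never runs for it. A compact sketch of that coupling, with stand-in types and the flashinfer-trtllm backend assumed on:

    from typing import Optional

    class FusedMoE: ...
    class FlashInferFP4MoE(FusedMoE): ...

    def wire_moe(quant_config: Optional[object]):
        # Mirrors the two call sites above, flashinfer-trtllm assumed enabled.
        impl = FusedMoE if quant_config is None else FlashInferFP4MoE
        force_topk = quant_config is None
        return impl, force_topk

    assert wire_moe(None) == (FusedMoE, True)               # MTP draft layer
    assert wire_moe(object()) == (FlashInferFP4MoE, False)  # quantized target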