Fix wrong correction bias dtype in trtllm_moe (#10440)

This commit is contained in:
fzyzcjy
2025-09-15 16:02:05 +08:00
committed by GitHub
parent 50dc0c1e9c
commit 059c13de5c
2 changed files with 14 additions and 10 deletions

View File

@@ -213,14 +213,6 @@ class TopK(CustomOp):
if use_grouped_topk:
assert num_expert_group is not None and topk_group is not None
if (
quant_config is not None
and quant_config.get_name() == "modelopt_fp4"
and should_use_flashinfer_trtllm_moe()
):
# https://github.com/sgl-project/sglang/pull/9834#discussion_r2324480643
correction_bias = correction_bias.to(torch.bfloat16)
self.topk_config = TopKConfig(
top_k=top_k,
use_grouped_topk=use_grouped_topk,

View File

@@ -65,6 +65,7 @@ from sglang.srt.layers.moe import (
get_deepep_mode,
get_moe_a2a_backend,
should_use_flashinfer_cutlass_moe_fp4_allgather,
should_use_flashinfer_trtllm_moe,
)
from sglang.srt.layers.moe.ep_moe.layer import DeepEPMoE, get_moe_impl_class
from sglang.srt.layers.moe.fused_moe_triton.layer import FusedMoE
@@ -269,6 +270,7 @@ class MoEGate(nn.Module):
def __init__(
self,
config,
quant_config,
prefix: str = "",
is_nextn: bool = False,
):
@@ -278,8 +280,15 @@ class MoEGate(nn.Module):
torch.empty((config.n_routed_experts, config.hidden_size))
)
if config.topk_method == "noaux_tc":
correction_bias_dtype = (
torch.bfloat16
if quant_config is not None
and quant_config.get_name() == "modelopt_fp4"
and should_use_flashinfer_trtllm_moe()
else torch.float32
)
self.e_score_correction_bias = nn.Parameter(
torch.empty((config.n_routed_experts), dtype=torch.float32)
torch.empty((config.n_routed_experts), dtype=correction_bias_dtype)
)
else:
self.e_score_correction_bias = None
@@ -354,7 +363,10 @@ class DeepseekV2MoE(nn.Module):
)
self.gate = MoEGate(
config=config, prefix=add_prefix("gate", prefix), is_nextn=is_nextn
config=config,
quant_config=quant_config,
prefix=add_prefix("gate", prefix),
is_nextn=is_nextn,
)
self.experts = get_moe_impl_class(quant_config)(