Fix trtllm_moe wrong correction bias (#10440)
This commit is contained in:
@@ -213,14 +213,6 @@ class TopK(CustomOp):
|
||||
if use_grouped_topk:
|
||||
assert num_expert_group is not None and topk_group is not None
|
||||
|
||||
- if (
|
||||
- quant_config is not None
|
||||
- and quant_config.get_name() == "modelopt_fp4"
|
||||
- and should_use_flashinfer_trtllm_moe()
|
||||
- ):
|
||||
- # https://github.com/sgl-project/sglang/pull/9834#discussion_r2324480643
|
||||
- correction_bias = correction_bias.to(torch.bfloat16)
|
||||
|
||||
self.topk_config = TopKConfig(
|
||||
top_k=top_k,
|
||||
use_grouped_topk=use_grouped_topk,
|
||||
|
||||
@@ -65,6 +65,7 @@ from sglang.srt.layers.moe import (
|
||||
get_deepep_mode,
|
||||
get_moe_a2a_backend,
|
||||
should_use_flashinfer_cutlass_moe_fp4_allgather,
|
||||
+ should_use_flashinfer_trtllm_moe,
|
||||
)
|
||||
from sglang.srt.layers.moe.ep_moe.layer import DeepEPMoE, get_moe_impl_class
|
||||
from sglang.srt.layers.moe.fused_moe_triton.layer import FusedMoE
|
||||
@@ -269,6 +270,7 @@ class MoEGate(nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
config,
|
||||
+ quant_config,
|
||||
prefix: str = "",
|
||||
is_nextn: bool = False,
|
||||
):
|
||||
@@ -278,8 +280,15 @@ class MoEGate(nn.Module):
|
||||
torch.empty((config.n_routed_experts, config.hidden_size))
|
||||
)
|
||||
if config.topk_method == "noaux_tc":
|
||||
+ correction_bias_dtype = (
|
||||
+ torch.bfloat16
|
||||
+ if quant_config is not None
|
||||
+ and quant_config.get_name() == "modelopt_fp4"
|
||||
+ and should_use_flashinfer_trtllm_moe()
|
||||
+ else torch.float32
|
||||
+ )
|
||||
self.e_score_correction_bias = nn.Parameter(
|
||||
- torch.empty((config.n_routed_experts), dtype=torch.float32)
|
||||
+ torch.empty((config.n_routed_experts), dtype=correction_bias_dtype)
|
||||
)
|
||||
else:
|
||||
self.e_score_correction_bias = None
|
||||
@@ -354,7 +363,10 @@ class DeepseekV2MoE(nn.Module):
|
||||
)
|
||||
|
||||
self.gate = MoEGate(
|
||||
- config=config, prefix=add_prefix("gate", prefix), is_nextn=is_nextn
|
||||
+ config=config,
|
||||
+ quant_config=quant_config,
|
||||
+ prefix=add_prefix("gate", prefix),
|
||||
+ is_nextn=is_nextn,
|
||||
)
|
||||
|
||||
self.experts = get_moe_impl_class(quant_config)(
|
||||
|
||||
Reference in New Issue
Block a user