Fix trtllm_moe wrong correction bias (#10440)
This commit is contained in:
@@ -213,14 +213,6 @@ class TopK(CustomOp):
|
|||||||
if use_grouped_topk:
|
if use_grouped_topk:
|
||||||
assert num_expert_group is not None and topk_group is not None
|
assert num_expert_group is not None and topk_group is not None
|
||||||
|
|
||||||
if (
|
|
||||||
quant_config is not None
|
|
||||||
and quant_config.get_name() == "modelopt_fp4"
|
|
||||||
and should_use_flashinfer_trtllm_moe()
|
|
||||||
):
|
|
||||||
# https://github.com/sgl-project/sglang/pull/9834#discussion_r2324480643
|
|
||||||
correction_bias = correction_bias.to(torch.bfloat16)
|
|
||||||
|
|
||||||
self.topk_config = TopKConfig(
|
self.topk_config = TopKConfig(
|
||||||
top_k=top_k,
|
top_k=top_k,
|
||||||
use_grouped_topk=use_grouped_topk,
|
use_grouped_topk=use_grouped_topk,
|
||||||
|
|||||||
@@ -65,6 +65,7 @@ from sglang.srt.layers.moe import (
|
|||||||
get_deepep_mode,
|
get_deepep_mode,
|
||||||
get_moe_a2a_backend,
|
get_moe_a2a_backend,
|
||||||
should_use_flashinfer_cutlass_moe_fp4_allgather,
|
should_use_flashinfer_cutlass_moe_fp4_allgather,
|
||||||
|
should_use_flashinfer_trtllm_moe,
|
||||||
)
|
)
|
||||||
from sglang.srt.layers.moe.ep_moe.layer import DeepEPMoE, get_moe_impl_class
|
from sglang.srt.layers.moe.ep_moe.layer import DeepEPMoE, get_moe_impl_class
|
||||||
from sglang.srt.layers.moe.fused_moe_triton.layer import FusedMoE
|
from sglang.srt.layers.moe.fused_moe_triton.layer import FusedMoE
|
||||||
@@ -269,6 +270,7 @@ class MoEGate(nn.Module):
|
|||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
config,
|
config,
|
||||||
|
quant_config,
|
||||||
prefix: str = "",
|
prefix: str = "",
|
||||||
is_nextn: bool = False,
|
is_nextn: bool = False,
|
||||||
):
|
):
|
||||||
@@ -278,8 +280,15 @@ class MoEGate(nn.Module):
|
|||||||
torch.empty((config.n_routed_experts, config.hidden_size))
|
torch.empty((config.n_routed_experts, config.hidden_size))
|
||||||
)
|
)
|
||||||
if config.topk_method == "noaux_tc":
|
if config.topk_method == "noaux_tc":
|
||||||
|
correction_bias_dtype = (
|
||||||
|
torch.bfloat16
|
||||||
|
if quant_config is not None
|
||||||
|
and quant_config.get_name() == "modelopt_fp4"
|
||||||
|
and should_use_flashinfer_trtllm_moe()
|
||||||
|
else torch.float32
|
||||||
|
)
|
||||||
self.e_score_correction_bias = nn.Parameter(
|
self.e_score_correction_bias = nn.Parameter(
|
||||||
torch.empty((config.n_routed_experts), dtype=torch.float32)
|
torch.empty((config.n_routed_experts), dtype=correction_bias_dtype)
|
||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
self.e_score_correction_bias = None
|
self.e_score_correction_bias = None
|
||||||
@@ -354,7 +363,10 @@ class DeepseekV2MoE(nn.Module):
|
|||||||
)
|
)
|
||||||
|
|
||||||
self.gate = MoEGate(
|
self.gate = MoEGate(
|
||||||
config=config, prefix=add_prefix("gate", prefix), is_nextn=is_nextn
|
config=config,
|
||||||
|
quant_config=quant_config,
|
||||||
|
prefix=add_prefix("gate", prefix),
|
||||||
|
is_nextn=is_nextn,
|
||||||
)
|
)
|
||||||
|
|
||||||
self.experts = get_moe_impl_class(quant_config)(
|
self.experts = get_moe_impl_class(quant_config)(
|
||||||
|
|||||||
Reference in New Issue
Block a user