fix: disable dsv3_router_gemm in dsv3_nextn (#7793)

This commit is contained in:
JieXin Liang
2025-07-06 10:01:01 +08:00
committed by GitHub
parent 625018d259
commit 54411f6afa

View File

@@ -210,8 +210,10 @@ class MoEGate(nn.Module):
self,
config,
prefix: str = "",
is_nextn: bool = False,
):
super().__init__()
self.is_nextn = is_nextn
self.weight = nn.Parameter(
torch.empty((config.n_routed_experts, config.hidden_size))
)
@@ -233,8 +235,10 @@ class MoEGate(nn.Module):
True, # is_vnni
)
# NOTE: For some unknown reason, router_gemm seems degrade accept length.
if (
_is_cuda
and not self.is_nextn
and hidden_states.shape[0] < 4
and hidden_states.shape[1] == 7168
and self.weight.shape[0] == 256
@@ -258,6 +262,7 @@ class DeepseekV2MoE(nn.Module):
quant_config: Optional[QuantizationConfig] = None,
prefix: str = "",
alt_stream: Optional[torch.cuda.Stream] = None,
is_nextn: bool = False,
):
super().__init__()
self.tp_size = get_tensor_model_parallel_world_size()
@@ -284,7 +289,9 @@ class DeepseekV2MoE(nn.Module):
"Only silu is supported for now."
)
self.gate = MoEGate(config=config, prefix=add_prefix("gate", prefix))
self.gate = MoEGate(
config=config, prefix=add_prefix("gate", prefix), is_nextn=is_nextn
)
self.experts = get_moe_impl_class()(
num_experts=config.n_routed_experts
@@ -1776,6 +1783,7 @@ class DeepseekV2DecoderLayer(nn.Module):
prefix=add_prefix("mlp", prefix),
layer_id=self.layer_id,
alt_stream=alt_stream,
is_nextn=is_nextn,
)
else:
if enable_moe_dense_fully_dp():