[bugfix] fix norm type error in qwen3_next model (#10322)

Co-authored-by: caoyizhong.cyz <caoyizhong.cyz@alibaba-inc.com>
Co-authored-by: Yi Zhang <1109276519@qq.com>
Author: cao1zhg
Date: 2025-09-12 00:05:59 +08:00
Committed by: GitHub
Parent: 64f296f8e6
Commit: 4a0e0be2a2

2 changed files with 10 additions and 51 deletions
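
For context, RMSNorm and GemmaRMSNorm normalize identically but apply the learned weight differently, which is the "norm type" at issue in this fix. A minimal sketch of the two conventions (simplified for illustration, not the sglang implementations):

    import torch

    def inv_rms(x: torch.Tensor, eps: float = 1e-6) -> torch.Tensor:
        # reciprocal root-mean-square over the last dimension
        return torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + eps)

    def vanilla_rmsnorm(x: torch.Tensor, weight: torch.Tensor, eps: float = 1e-6) -> torch.Tensor:
        # standard RMSNorm: the weight is the scale itself (typically initialized to 1)
        return x * inv_rms(x, eps) * weight

    def gemma_rmsnorm(x: torch.Tensor, weight: torch.Tensor, eps: float = 1e-6) -> torch.Tensor:
        # Gemma-style RMSNorm: the stored weight is zero-centered and the
        # effective scale is (1 + weight)
        return x * inv_rms(x, eps) * (1.0 + weight)

Qwen3-Next's norm weights follow the zero-centered (Gemma-style) convention, so building vanilla RMSNorm layers applies the wrong scale; the commit therefore constructs GemmaRMSNorm unconditionally instead of gating it on config flags.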


@@ -518,24 +518,10 @@ class Qwen3HybridLinearDecoderLayer(nn.Module):
             hidden_act=config.hidden_act,
             quant_config=quant_config,
         )
-        if getattr(
-            config, "use_gemma_rms_norm", getattr(config, "apply_layernorm_1p", False)
-        ):
-            logger.warning_once(
-                "Using Gemma RMSNorm for input normalization and post attn normalization."
-            )
-            self.input_layernorm = GemmaRMSNorm(
-                config.hidden_size, eps=config.rms_norm_eps
-            )
-            self.post_attention_layernorm = GemmaRMSNorm(
-                config.hidden_size, eps=config.rms_norm_eps
-            )
-        else:
-            self.input_layernorm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
-            self.post_attention_layernorm = RMSNorm(
-                config.hidden_size, eps=config.rms_norm_eps
-            )
+        self.input_layernorm = GemmaRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
+        self.post_attention_layernorm = GemmaRMSNorm(
+            config.hidden_size, eps=config.rms_norm_eps
+        )
         self.layer_communicator = LayerCommunicator(
             layer_scatter_modes=self.layer_scatter_modes,
             input_layernorm=self.input_layernorm,
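
The guard removed above resolved its flag through a nested getattr default, so on a config that defines neither use_gemma_rms_norm nor apply_layernorm_1p it quietly fell back to the vanilla RMSNorm branch, which appears to be how the wrong norm type was selected for Qwen3-Next. A small illustration of that lookup (the bare config namespace here is a stand-in, not the real model config):

    from types import SimpleNamespace

    config = SimpleNamespace(hidden_size=2048)  # neither norm flag is set

    use_gemma = getattr(
        config, "use_gemma_rms_norm", getattr(config, "apply_layernorm_1p", False)
    )
    print(use_gemma)  # False -> the old code would build vanilla RMSNorm layers
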
@@ -685,23 +671,10 @@ class Qwen3HybridAttentionDecoderLayer(nn.Module):
hidden_act=config.hidden_act,
quant_config=quant_config,
)
if getattr(
config, "use_gemma_rms_norm", getattr(config, "apply_layernorm_1p", False)
):
logger.warning_once(
"Using Gemma RMSNorm for input normalization and post attn normalization."
)
self.input_layernorm = GemmaRMSNorm(
config.hidden_size, eps=config.rms_norm_eps
)
self.post_attention_layernorm = GemmaRMSNorm(
config.hidden_size, eps=config.rms_norm_eps
)
else:
self.input_layernorm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
self.post_attention_layernorm = RMSNorm(
config.hidden_size, eps=config.rms_norm_eps
)
self.input_layernorm = GemmaRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
self.post_attention_layernorm = GemmaRMSNorm(
config.hidden_size, eps=config.rms_norm_eps
)
self.q_norm = GemmaRMSNorm(self.head_dim, eps=config.rms_norm_eps)
self.k_norm = GemmaRMSNorm(self.head_dim, eps=config.rms_norm_eps)
@@ -844,13 +817,7 @@ class Qwen3NextModel(nn.Module):
config.num_hidden_layers, get_layer, prefix=f"{prefix}.layers"
)
if getattr(
config, "use_gemma_rms_norm", getattr(config, "apply_layernorm_1p", False)
):
logger.warning_once("Using Gemma RMSNorm for final normalization.")
self.norm = GemmaRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
else:
self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
self.norm = GemmaRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
self.infer_count = 0
def forward(


@@ -54,15 +54,7 @@ class Qwen3NextForCausalLMMTP(Qwen3NextForCausalLM):
         # (1) do not use_dedicated_mtp_embeddings provided in ckpt since not provided and directly use the target model embeddings
         # (2) hardcode bias=False since not provided
         self.fc = nn.Linear(2 * config.hidden_size, config.hidden_size, bias=False)
-        if getattr(
-            config, "use_gemma_rms_norm", getattr(config, "apply_layernorm_1p", False)
-        ):
-            logger.warning_once(
-                "Using Gemma RMSNorm for input normalization and post attn normalization."
-            )
-            RMSNorm_cls = GemmaRMSNorm
-        else:
-            RMSNorm_cls = RMSNorm
+        RMSNorm_cls = GemmaRMSNorm
         self.pre_fc_norm_embedding = RMSNorm_cls(
             config.hidden_size, config.rms_norm_eps
         )
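
As a rough numerical check of why the class choice matters (toy tensors, assuming the (1 + weight) convention sketched earlier): a zero-centered Gemma-style weight pushed through a vanilla RMSNorm scales activations by the raw weight instead of (1 + weight), so a freshly initialized weight of zeros would wipe the layer output entirely.

    import torch

    def inv_rms(x: torch.Tensor, eps: float = 1e-6) -> torch.Tensor:
        return torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + eps)

    x = torch.randn(1, 8)
    w = torch.zeros(8)  # zero-centered weight, as a Gemma-style norm initializes/stores it

    gemma_out = x * inv_rms(x) * (1.0 + w)  # ~ normalized x, the intended behavior
    wrong_out = x * inv_rms(x) * w          # all zeros: the scale a vanilla RMSNorm applies to the same weight
    print(gemma_out.abs().max().item(), wrong_out.abs().max().item())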