[bugfix] fix norm type error in qwen3_next model (#10322)
Co-authored-by: caoyizhong.cyz <caoyizhong.cyz@alibaba-inc.com> Co-authored-by: Yi Zhang <1109276519@qq.com>
This commit is contained in:
@@ -518,24 +518,10 @@ class Qwen3HybridLinearDecoderLayer(nn.Module):
|
||||
hidden_act=config.hidden_act,
|
||||
quant_config=quant_config,
|
||||
)
|
||||
if getattr(
|
||||
config, "use_gemma_rms_norm", getattr(config, "apply_layernorm_1p", False)
|
||||
):
|
||||
logger.warning_once(
|
||||
"Using Gemma RMSNorm for input normalization and post attn normalization."
|
||||
)
|
||||
self.input_layernorm = GemmaRMSNorm(
|
||||
config.hidden_size, eps=config.rms_norm_eps
|
||||
)
|
||||
self.post_attention_layernorm = GemmaRMSNorm(
|
||||
config.hidden_size, eps=config.rms_norm_eps
|
||||
)
|
||||
else:
|
||||
self.input_layernorm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
|
||||
self.post_attention_layernorm = RMSNorm(
|
||||
config.hidden_size, eps=config.rms_norm_eps
|
||||
)
|
||||
|
||||
self.input_layernorm = GemmaRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
|
||||
self.post_attention_layernorm = GemmaRMSNorm(
|
||||
config.hidden_size, eps=config.rms_norm_eps
|
||||
)
|
||||
self.layer_communicator = LayerCommunicator(
|
||||
layer_scatter_modes=self.layer_scatter_modes,
|
||||
input_layernorm=self.input_layernorm,
|
||||
@@ -685,23 +671,10 @@ class Qwen3HybridAttentionDecoderLayer(nn.Module):
|
||||
hidden_act=config.hidden_act,
|
||||
quant_config=quant_config,
|
||||
)
|
||||
if getattr(
|
||||
config, "use_gemma_rms_norm", getattr(config, "apply_layernorm_1p", False)
|
||||
):
|
||||
logger.warning_once(
|
||||
"Using Gemma RMSNorm for input normalization and post attn normalization."
|
||||
)
|
||||
self.input_layernorm = GemmaRMSNorm(
|
||||
config.hidden_size, eps=config.rms_norm_eps
|
||||
)
|
||||
self.post_attention_layernorm = GemmaRMSNorm(
|
||||
config.hidden_size, eps=config.rms_norm_eps
|
||||
)
|
||||
else:
|
||||
self.input_layernorm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
|
||||
self.post_attention_layernorm = RMSNorm(
|
||||
config.hidden_size, eps=config.rms_norm_eps
|
||||
)
|
||||
self.input_layernorm = GemmaRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
|
||||
self.post_attention_layernorm = GemmaRMSNorm(
|
||||
config.hidden_size, eps=config.rms_norm_eps
|
||||
)
|
||||
|
||||
self.q_norm = GemmaRMSNorm(self.head_dim, eps=config.rms_norm_eps)
|
||||
self.k_norm = GemmaRMSNorm(self.head_dim, eps=config.rms_norm_eps)
|
||||
@@ -844,13 +817,7 @@ class Qwen3NextModel(nn.Module):
|
||||
config.num_hidden_layers, get_layer, prefix=f"{prefix}.layers"
|
||||
)
|
||||
|
||||
if getattr(
|
||||
config, "use_gemma_rms_norm", getattr(config, "apply_layernorm_1p", False)
|
||||
):
|
||||
logger.warning_once("Using Gemma RMSNorm for final normalization.")
|
||||
self.norm = GemmaRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
|
||||
else:
|
||||
self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
|
||||
self.norm = GemmaRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
|
||||
self.infer_count = 0
|
||||
|
||||
def forward(
|
||||
|
||||
@@ -54,15 +54,7 @@ class Qwen3NextForCausalLMMTP(Qwen3NextForCausalLM):
|
||||
# (1) do not use_dedicated_mtp_embeddings provided in ckpt since not provided and directly use the target model embeddings
|
||||
# (2) hardcode bias=False since not provided
|
||||
self.fc = nn.Linear(2 * config.hidden_size, config.hidden_size, bias=False)
|
||||
if getattr(
|
||||
config, "use_gemma_rms_norm", getattr(config, "apply_layernorm_1p", False)
|
||||
):
|
||||
logger.warning_once(
|
||||
"Using Gemma RMSNorm for input normalization and post attn normalization."
|
||||
)
|
||||
RMSNorm_cls = GemmaRMSNorm
|
||||
else:
|
||||
RMSNorm_cls = RMSNorm
|
||||
RMSNorm_cls = GemmaRMSNorm
|
||||
self.pre_fc_norm_embedding = RMSNorm_cls(
|
||||
config.hidden_size, config.rms_norm_eps
|
||||
)
|
||||
|
||||
Reference in New Issue
Block a user