[bugfix] fix norm type error in qwen3_next model (#10322)
Co-authored-by: caoyizhong.cyz <caoyizhong.cyz@alibaba-inc.com>
Co-authored-by: Yi Zhang <1109276519@qq.com>
@@ -518,24 +518,10 @@ class Qwen3HybridLinearDecoderLayer(nn.Module):
             hidden_act=config.hidden_act,
             quant_config=quant_config,
         )
-        if getattr(
-            config, "use_gemma_rms_norm", getattr(config, "apply_layernorm_1p", False)
-        ):
-            logger.warning_once(
-                "Using Gemma RMSNorm for input normalization and post attn normalization."
-            )
-            self.input_layernorm = GemmaRMSNorm(
-                config.hidden_size, eps=config.rms_norm_eps
-            )
-            self.post_attention_layernorm = GemmaRMSNorm(
-                config.hidden_size, eps=config.rms_norm_eps
-            )
-        else:
-            self.input_layernorm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
-            self.post_attention_layernorm = RMSNorm(
-                config.hidden_size, eps=config.rms_norm_eps
-            )
-
+        self.input_layernorm = GemmaRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
+        self.post_attention_layernorm = GemmaRMSNorm(
+            config.hidden_size, eps=config.rms_norm_eps
+        )
         self.layer_communicator = LayerCommunicator(
             layer_scatter_modes=self.layer_scatter_modes,
             input_layernorm=self.input_layernorm,
@@ -685,23 +671,10 @@ class Qwen3HybridAttentionDecoderLayer(nn.Module):
             hidden_act=config.hidden_act,
             quant_config=quant_config,
         )
-        if getattr(
-            config, "use_gemma_rms_norm", getattr(config, "apply_layernorm_1p", False)
-        ):
-            logger.warning_once(
-                "Using Gemma RMSNorm for input normalization and post attn normalization."
-            )
-            self.input_layernorm = GemmaRMSNorm(
-                config.hidden_size, eps=config.rms_norm_eps
-            )
-            self.post_attention_layernorm = GemmaRMSNorm(
-                config.hidden_size, eps=config.rms_norm_eps
-            )
-        else:
-            self.input_layernorm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
-            self.post_attention_layernorm = RMSNorm(
-                config.hidden_size, eps=config.rms_norm_eps
-            )
+        self.input_layernorm = GemmaRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
+        self.post_attention_layernorm = GemmaRMSNorm(
+            config.hidden_size, eps=config.rms_norm_eps
+        )
 
         self.q_norm = GemmaRMSNorm(self.head_dim, eps=config.rms_norm_eps)
         self.k_norm = GemmaRMSNorm(self.head_dim, eps=config.rms_norm_eps)
@@ -844,13 +817,7 @@ class Qwen3NextModel(nn.Module):
             config.num_hidden_layers, get_layer, prefix=f"{prefix}.layers"
         )
 
-        if getattr(
-            config, "use_gemma_rms_norm", getattr(config, "apply_layernorm_1p", False)
-        ):
-            logger.warning_once("Using Gemma RMSNorm for final normalization.")
-            self.norm = GemmaRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
-        else:
-            self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
+        self.norm = GemmaRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
         self.infer_count = 0
 
     def forward(

@@ -54,15 +54,7 @@ class Qwen3NextForCausalLMMTP(Qwen3NextForCausalLM):
         # (1) do not use_dedicated_mtp_embeddings provided in ckpt since not provided and directly use the target model embeddings
         # (2) hardcode bias=False since not provided
         self.fc = nn.Linear(2 * config.hidden_size, config.hidden_size, bias=False)
-        if getattr(
-            config, "use_gemma_rms_norm", getattr(config, "apply_layernorm_1p", False)
-        ):
-            logger.warning_once(
-                "Using Gemma RMSNorm for input normalization and post attn normalization."
-            )
-            RMSNorm_cls = GemmaRMSNorm
-        else:
-            RMSNorm_cls = RMSNorm
+        RMSNorm_cls = GemmaRMSNorm
         self.pre_fc_norm_embedding = RMSNorm_cls(
            config.hidden_size, config.rms_norm_eps
        )
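
Note on the change: the diff drops the use_gemma_rms_norm / apply_layernorm_1p gating and always builds the Gemma-style norm for qwen3_next. Below is a minimal, illustrative sketch of the two RMSNorm conventions involved. The PlainRMSNorm and GemmaStyleRMSNorm classes and the demo at the end are hypothetical stand-ins written for this note, not the project's actual RMSNorm / GemmaRMSNorm implementations; they assume the common convention that a Gemma-style norm stores a zero-centered weight and applies it as (1 + weight), while a plain RMSNorm applies its weight directly.

import torch
import torch.nn as nn


class PlainRMSNorm(nn.Module):
    """Standard RMSNorm: normalized activations are scaled by `weight`."""

    def __init__(self, hidden_size: int, eps: float = 1e-6):
        super().__init__()
        self.eps = eps
        self.weight = nn.Parameter(torch.ones(hidden_size))

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        rms = torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + self.eps)
        return x * rms * self.weight


class GemmaStyleRMSNorm(nn.Module):
    """Gemma-style RMSNorm: the stored weight is applied as (1 + weight)."""

    def __init__(self, hidden_size: int, eps: float = 1e-6):
        super().__init__()
        self.eps = eps
        self.weight = nn.Parameter(torch.zeros(hidden_size))

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        rms = torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + self.eps)
        return x * rms * (1.0 + self.weight)


# The same raw weight tensor means different things under the two conventions,
# so constructing the wrong class silently rescales every normalized activation.
x = torch.randn(2, 4, 8)
plain, gemma = PlainRMSNorm(8), GemmaStyleRMSNorm(8)
gemma.weight.data.copy_(plain.weight.data)  # identical stored values
print(torch.allclose(plain(x), gemma(x)))   # False: Gemma-style applies (1 + weight)

Under that assumption, loading qwen3_next norm weights into the plain variant (or vice versa) shifts every layer's normalization output, which is the kind of silent numeric mismatch the title's "norm type error" points at.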