diff --git a/python/sglang/srt/models/qwen3_next.py b/python/sglang/srt/models/qwen3_next.py
index fd0d0e942..cdba9975f 100644
--- a/python/sglang/srt/models/qwen3_next.py
+++ b/python/sglang/srt/models/qwen3_next.py
@@ -518,24 +518,10 @@ class Qwen3HybridLinearDecoderLayer(nn.Module):
             hidden_act=config.hidden_act,
             quant_config=quant_config,
         )
-        if getattr(
-            config, "use_gemma_rms_norm", getattr(config, "apply_layernorm_1p", False)
-        ):
-            logger.warning_once(
-                "Using Gemma RMSNorm for input normalization and post attn normalization."
-            )
-            self.input_layernorm = GemmaRMSNorm(
-                config.hidden_size, eps=config.rms_norm_eps
-            )
-            self.post_attention_layernorm = GemmaRMSNorm(
-                config.hidden_size, eps=config.rms_norm_eps
-            )
-        else:
-            self.input_layernorm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
-            self.post_attention_layernorm = RMSNorm(
-                config.hidden_size, eps=config.rms_norm_eps
-            )
-
+        self.input_layernorm = GemmaRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
+        self.post_attention_layernorm = GemmaRMSNorm(
+            config.hidden_size, eps=config.rms_norm_eps
+        )
         self.layer_communicator = LayerCommunicator(
             layer_scatter_modes=self.layer_scatter_modes,
             input_layernorm=self.input_layernorm,
@@ -685,23 +671,10 @@ class Qwen3HybridAttentionDecoderLayer(nn.Module):
             hidden_act=config.hidden_act,
             quant_config=quant_config,
         )
-        if getattr(
-            config, "use_gemma_rms_norm", getattr(config, "apply_layernorm_1p", False)
-        ):
-            logger.warning_once(
-                "Using Gemma RMSNorm for input normalization and post attn normalization."
-            )
-            self.input_layernorm = GemmaRMSNorm(
-                config.hidden_size, eps=config.rms_norm_eps
-            )
-            self.post_attention_layernorm = GemmaRMSNorm(
-                config.hidden_size, eps=config.rms_norm_eps
-            )
-        else:
-            self.input_layernorm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
-            self.post_attention_layernorm = RMSNorm(
-                config.hidden_size, eps=config.rms_norm_eps
-            )
+        self.input_layernorm = GemmaRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
+        self.post_attention_layernorm = GemmaRMSNorm(
+            config.hidden_size, eps=config.rms_norm_eps
+        )

         self.q_norm = GemmaRMSNorm(self.head_dim, eps=config.rms_norm_eps)
         self.k_norm = GemmaRMSNorm(self.head_dim, eps=config.rms_norm_eps)
@@ -844,13 +817,7 @@ class Qwen3NextModel(nn.Module):
             config.num_hidden_layers, get_layer, prefix=f"{prefix}.layers"
         )

-        if getattr(
-            config, "use_gemma_rms_norm", getattr(config, "apply_layernorm_1p", False)
-        ):
-            logger.warning_once("Using Gemma RMSNorm for final normalization.")
-            self.norm = GemmaRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
-        else:
-            self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
+        self.norm = GemmaRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
         self.infer_count = 0

     def forward(
diff --git a/python/sglang/srt/models/qwen3_next_mtp.py b/python/sglang/srt/models/qwen3_next_mtp.py
index 4630ea300..a9da0867d 100644
--- a/python/sglang/srt/models/qwen3_next_mtp.py
+++ b/python/sglang/srt/models/qwen3_next_mtp.py
@@ -54,15 +54,7 @@ class Qwen3NextForCausalLMMTP(Qwen3NextForCausalLM):
         # (1) do not use_dedicated_mtp_embeddings provided in ckpt since not provided and directly use the target model embeddings
         # (2) hardcode bias=False since not provided
         self.fc = nn.Linear(2 * config.hidden_size, config.hidden_size, bias=False)
-        if getattr(
-            config, "use_gemma_rms_norm", getattr(config, "apply_layernorm_1p", False)
-        ):
-            logger.warning_once(
-                "Using Gemma RMSNorm for input normalization and post attn normalization."
-            )
-            RMSNorm_cls = GemmaRMSNorm
-        else:
-            RMSNorm_cls = RMSNorm
+        RMSNorm_cls = GemmaRMSNorm
         self.pre_fc_norm_embedding = RMSNorm_cls(
             config.hidden_size, config.rms_norm_eps
         )
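Note for reviewers: the branch removed above selected between two normalization conventions. Gemma-style RMSNorm stores its scale as an offset from one and multiplies the normalized input by (1 + weight), which is where the legacy `apply_layernorm_1p` flag name comes from, while standard RMSNorm multiplies by `weight` directly. A minimal sketch of the difference, independent of SGLang's actual `GemmaRMSNorm`/`RMSNorm` kernels:

    import torch

    def rms_norm(x, weight, eps=1e-6):
        # Standard RMSNorm: scale the normalized input by `weight` directly.
        x_normed = x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + eps)
        return x_normed * weight

    def gemma_rms_norm(x, weight, eps=1e-6):
        # Gemma-style RMSNorm: weights are stored as an offset from one, so
        # the scale is (1 + weight); computation is done in float32.
        x32 = x.float()
        x_normed = x32 * torch.rsqrt(x32.pow(2).mean(-1, keepdim=True) + eps)
        return (x_normed * (1.0 + weight.float())).type_as(x)

    x = torch.randn(2, 8)
    w = torch.zeros(8)  # zero weights make the two conventions diverge maximally
    print(rms_norm(x, w).abs().max().item())        # 0.0
    print(gemma_rms_norm(x, w).abs().max().item())  # nonzero

The zero-weight case shows why the variants are not interchangeable: a checkpoint whose norm weights were trained under one convention produces wrong activations under the other, so hardcoding `GemmaRMSNorm` assumes all loaded checkpoints use the 1p convention.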