diff --git a/python/sglang/srt/models/gemma3n_causal.py b/python/sglang/srt/models/gemma3n_causal.py index 802cb9fc5..0f710b0f8 100644 --- a/python/sglang/srt/models/gemma3n_causal.py +++ b/python/sglang/srt/models/gemma3n_causal.py @@ -62,7 +62,7 @@ class Gemma3nTextScaledWordEmbedding(Gemma3TextScaledWordEmbedding): pass -class Gemma3nMLP(nn.Module): +class Gemma3nTextMLP(nn.Module): def __init__( self, hidden_size: int, @@ -514,10 +514,11 @@ class Gemma3nDecoderLayer(nn.Module): prefix=add_prefix("self_attn", prefix), ) + intermediate_size = config.intermediate_size[layer_id] activation_sparsity = config.activation_sparsity_pattern[layer_id] - self.mlp = Gemma3nMLP( + self.mlp = Gemma3nTextMLP( hidden_size=self.hidden_size, - intermediate_size=config.intermediate_size, + intermediate_size=intermediate_size, hidden_activation=config.hidden_activation, activation_sparsity=activation_sparsity, quant_config=quant_config,