Update the Gemma3n MLP layer to adapt to the latest transformers version (#7573)

Signed-off-by: Xinyuan Tong <justinning0323@outlook.com>
This commit is contained in:
Xinyuan Tong
2025-06-26 15:07:22 -07:00
committed by GitHub
parent 1b8cf77b01
commit 604efe07e1

View File

@@ -62,7 +62,7 @@ class Gemma3nTextScaledWordEmbedding(Gemma3TextScaledWordEmbedding):
     pass
 
 
-class Gemma3nMLP(nn.Module):
+class Gemma3nTextMLP(nn.Module):
     def __init__(
         self,
         hidden_size: int,
@@ -514,10 +514,11 @@ class Gemma3nDecoderLayer(nn.Module):
             prefix=add_prefix("self_attn", prefix),
         )
+        intermediate_size = config.intermediate_size[layer_id]
         activation_sparsity = config.activation_sparsity_pattern[layer_id]
-        self.mlp = Gemma3nMLP(
+        self.mlp = Gemma3nTextMLP(
             hidden_size=self.hidden_size,
-            intermediate_size=config.intermediate_size,
+            intermediate_size=intermediate_size,
             hidden_activation=config.hidden_activation,
             activation_sparsity=activation_sparsity,
             quant_config=quant_config,