Updates Gemma3n MLP layer to adapt latest transformers version (#7573)
Signed-off-by: Xinyuan Tong <justinning0323@outlook.com>
This commit is contained in:
@@ -62,7 +62,7 @@ class Gemma3nTextScaledWordEmbedding(Gemma3TextScaledWordEmbedding):
|
||||
pass
|
||||
|
||||
|
||||
class Gemma3nMLP(nn.Module):
|
||||
class Gemma3nTextMLP(nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
hidden_size: int,
|
||||
@@ -514,10 +514,11 @@ class Gemma3nDecoderLayer(nn.Module):
|
||||
prefix=add_prefix("self_attn", prefix),
|
||||
)
|
||||
|
||||
intermediate_size = config.intermediate_size[layer_id]
|
||||
activation_sparsity = config.activation_sparsity_pattern[layer_id]
|
||||
self.mlp = Gemma3nMLP(
|
||||
self.mlp = Gemma3nTextMLP(
|
||||
hidden_size=self.hidden_size,
|
||||
intermediate_size=config.intermediate_size,
|
||||
intermediate_size=intermediate_size,
|
||||
hidden_activation=config.hidden_activation,
|
||||
activation_sparsity=activation_sparsity,
|
||||
quant_config=quant_config,
|
||||
|
||||
Reference in New Issue
Block a user