add gemma3

This commit is contained in:
Chranos
2026-02-10 14:06:26 +08:00
parent ff94650fd1
commit 5b9e02990a
2 changed files with 41 additions and 4 deletions

View File

@@ -140,11 +140,21 @@ class Gemma3Attention(nn.Module):
self.is_sliding = (layer_idx % 2 == 1
and config.sliding_window is not None)
# Extract rope config, compatible with both old-style (rope_theta,
# rope_scaling) and new-style (rope_parameters dict) transformers.
rope_params = getattr(config, "rope_parameters", None)
# Set up rope based on layer type
if self.is_sliding:
# Local/sliding attention uses rope_local_base_freq
local_base = getattr(config, "rope_local_base_freq",
self.rope_theta)
if hasattr(config, "rope_local_base_freq"):
local_base = config.rope_local_base_freq
elif (isinstance(rope_params, dict)
and "sliding_attention" in rope_params):
local_base = rope_params["sliding_attention"].get(
"rope_theta", self.rope_theta)
else:
local_base = self.rope_theta
self.rotary_emb = get_rope(
self.head_dim,
rotary_dim=self.head_dim,
@@ -155,11 +165,27 @@ class Gemma3Attention(nn.Module):
else:
# Global attention uses rope_scaling from config
rope_scaling = getattr(config, "rope_scaling", None)
rope_base = self.rope_theta
if rope_scaling is None and isinstance(rope_params, dict):
# Try to extract from rope_parameters (newer transformers)
if "full_attention" in rope_params:
rp = rope_params["full_attention"]
else:
rp = rope_params
rope_base = rp.get("rope_theta", self.rope_theta)
rtype = rp.get("rope_type", None)
if rtype and rtype != "default":
rope_scaling = {
k: v for k, v in rp.items()
if k not in ("rope_theta",)
}
rope_scaling["type"] = rope_scaling.pop("rope_type",
rtype)
self.rotary_emb = get_rope(
self.head_dim,
rotary_dim=self.head_dim,
max_position=max_position_embeddings,
base=self.rope_theta,
base=rope_base,
is_neox_style=True,
rope_scaling=rope_scaling,
)
@@ -210,6 +236,16 @@ class Gemma3DecoderLayer(nn.Module):
) -> None:
super().__init__()
self.hidden_size = config.hidden_size
# Extract rope_theta: try direct attribute first, then
# rope_parameters dict (newer transformers), fallback to 10000.0
rope_params = getattr(config, "rope_parameters", None)
if hasattr(config, "rope_theta"):
rope_theta = config.rope_theta
elif isinstance(rope_params, dict):
rope_theta = rope_params.get("rope_theta", 10000.0)
else:
rope_theta = 10000.0
self.self_attn = Gemma3Attention(
layer_idx=layer_idx,
config=config,
@@ -218,7 +254,7 @@ class Gemma3DecoderLayer(nn.Module):
num_kv_heads=config.num_key_value_heads,
head_dim=config.head_dim,
max_position_embeddings=config.max_position_embeddings,
rope_theta=config.rope_theta,
rope_theta=rope_theta,
cache_config=cache_config,
quant_config=quant_config,
# Gemma3 does not use attn logit softcapping