add gemma3

This commit is contained in:
Chranos
2026-02-10 14:15:33 +08:00
parent 2e24d45668
commit a7028ae481

View File

@@ -177,14 +177,17 @@ class Gemma3Attention(nn.Module):
is_neox_style=True, is_neox_style=True,
) )
else: else:
# Global attention uses rope_scaling from config # Global attention: extract rope_base and rope_scaling.
rope_scaling = getattr(config, "rope_scaling", None) # Prioritize rope_parameters dict (newer transformers) to
# avoid passing nested dicts that are unhashable.
rope_scaling = None
rope_base = self.rope_theta rope_base = self.rope_theta
if rope_scaling is None and isinstance(rope_params, dict): if isinstance(rope_params, dict):
# Try to extract from rope_parameters (newer transformers) # Transformers v5: per layer_type sub-dicts
if "full_attention" in rope_params: if "full_attention" in rope_params:
rp = rope_params["full_attention"] rp = rope_params["full_attention"]
else: else:
# Transformers v4: flat dict
rp = rope_params rp = rope_params
rope_base = rp.get("rope_theta", self.rope_theta) rope_base = rp.get("rope_theta", self.rope_theta)
rtype = rp.get("rope_type", None) rtype = rp.get("rope_type", None)
@@ -193,8 +196,9 @@ class Gemma3Attention(nn.Module):
k: v for k, v in rp.items() k: v for k, v in rp.items()
if k not in ("rope_theta",) if k not in ("rope_theta",)
} }
rope_scaling["type"] = rope_scaling.pop("rope_type", else:
rtype) # Fallback: old-style config.rope_scaling (flat dict)
rope_scaling = getattr(config, "rope_scaling", None)
self.rotary_emb = get_rope( self.rotary_emb = get_rope(
self.head_dim, self.head_dim,
rotary_dim=self.head_dim, rotary_dim=self.head_dim,