Add Gemma3 model support

This commit is contained in:
Chranos
2026-02-10 14:15:33 +08:00
parent 2e24d45668
commit a7028ae481

View File

@@ -177,14 +177,17 @@ class Gemma3Attention(nn.Module):
is_neox_style=True,
)
else:
# Global attention uses rope_scaling from config
rope_scaling = getattr(config, "rope_scaling", None)
# Global attention: extract rope_base and rope_scaling.
# Prioritize rope_parameters dict (newer transformers) to
# avoid passing nested dicts that are unhashable.
rope_scaling = None
rope_base = self.rope_theta
if rope_scaling is None and isinstance(rope_params, dict):
# Try to extract from rope_parameters (newer transformers)
if isinstance(rope_params, dict):
# Transformers v5: per layer_type sub-dicts
if "full_attention" in rope_params:
rp = rope_params["full_attention"]
else:
# Transformers v4: flat dict
rp = rope_params
rope_base = rp.get("rope_theta", self.rope_theta)
rtype = rp.get("rope_type", None)
@@ -193,8 +196,9 @@ class Gemma3Attention(nn.Module):
k: v for k, v in rp.items()
if k not in ("rope_theta",)
}
rope_scaling["type"] = rope_scaling.pop("rope_type",
rtype)
else:
# Fallback: old-style config.rope_scaling (flat dict)
rope_scaling = getattr(config, "rope_scaling", None)
self.rotary_emb = get_rope(
self.head_dim,
rotary_dim=self.head_dim,