add gemma3
This commit is contained in:
@@ -177,14 +177,17 @@ class Gemma3Attention(nn.Module):
|
||||
is_neox_style=True,
|
||||
)
|
||||
else:
|
||||
# Global attention uses rope_scaling from config
|
||||
rope_scaling = getattr(config, "rope_scaling", None)
|
||||
# Global attention: extract rope_base and rope_scaling.
|
||||
# Prioritize rope_parameters dict (newer transformers) to
|
||||
# avoid passing nested dicts that are unhashable.
|
||||
rope_scaling = None
|
||||
rope_base = self.rope_theta
|
||||
if rope_scaling is None and isinstance(rope_params, dict):
|
||||
# Try to extract from rope_parameters (newer transformers)
|
||||
if isinstance(rope_params, dict):
|
||||
# Transformers v5: per layer_type sub-dicts
|
||||
if "full_attention" in rope_params:
|
||||
rp = rope_params["full_attention"]
|
||||
else:
|
||||
# Transformers v4: flat dict
|
||||
rp = rope_params
|
||||
rope_base = rp.get("rope_theta", self.rope_theta)
|
||||
rtype = rp.get("rope_type", None)
|
||||
@@ -193,8 +196,9 @@ class Gemma3Attention(nn.Module):
|
||||
k: v for k, v in rp.items()
|
||||
if k not in ("rope_theta",)
|
||||
}
|
||||
rope_scaling["type"] = rope_scaling.pop("rope_type",
|
||||
rtype)
|
||||
else:
|
||||
# Fallback: old-style config.rope_scaling (flat dict)
|
||||
rope_scaling = getattr(config, "rope_scaling", None)
|
||||
self.rotary_emb = get_rope(
|
||||
self.head_dim,
|
||||
rotary_dim=self.head_dim,
|
||||
|
||||
Reference in New Issue
Block a user