add gemma3
This commit is contained in:
@@ -177,14 +177,17 @@ class Gemma3Attention(nn.Module):
|
|||||||
is_neox_style=True,
|
is_neox_style=True,
|
||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
# Global attention uses rope_scaling from config
|
# Global attention: extract rope_base and rope_scaling.
|
||||||
rope_scaling = getattr(config, "rope_scaling", None)
|
# Prioritize rope_parameters dict (newer transformers) to
|
||||||
|
# avoid passing nested dicts that are unhashable.
|
||||||
|
rope_scaling = None
|
||||||
rope_base = self.rope_theta
|
rope_base = self.rope_theta
|
||||||
if rope_scaling is None and isinstance(rope_params, dict):
|
if isinstance(rope_params, dict):
|
||||||
# Try to extract from rope_parameters (newer transformers)
|
# Transformers v5: per layer_type sub-dicts
|
||||||
if "full_attention" in rope_params:
|
if "full_attention" in rope_params:
|
||||||
rp = rope_params["full_attention"]
|
rp = rope_params["full_attention"]
|
||||||
else:
|
else:
|
||||||
|
# Transformers v4: flat dict
|
||||||
rp = rope_params
|
rp = rope_params
|
||||||
rope_base = rp.get("rope_theta", self.rope_theta)
|
rope_base = rp.get("rope_theta", self.rope_theta)
|
||||||
rtype = rp.get("rope_type", None)
|
rtype = rp.get("rope_type", None)
|
||||||
@@ -193,8 +196,9 @@ class Gemma3Attention(nn.Module):
|
|||||||
k: v for k, v in rp.items()
|
k: v for k, v in rp.items()
|
||||||
if k not in ("rope_theta",)
|
if k not in ("rope_theta",)
|
||||||
}
|
}
|
||||||
rope_scaling["type"] = rope_scaling.pop("rope_type",
|
else:
|
||||||
rtype)
|
# Fallback: old-style config.rope_scaling (flat dict)
|
||||||
|
rope_scaling = getattr(config, "rope_scaling", None)
|
||||||
self.rotary_emb = get_rope(
|
self.rotary_emb = get_rope(
|
||||||
self.head_dim,
|
self.head_dim,
|
||||||
rotary_dim=self.head_dim,
|
rotary_dim=self.head_dim,
|
||||||
|
|||||||
Reference in New Issue
Block a user