add gemma3
This commit is contained in:
@@ -88,7 +88,6 @@ class Gemma3Attention(nn.Module):
|
|||||||
num_kv_heads: int,
|
num_kv_heads: int,
|
||||||
head_dim: int,
|
head_dim: int,
|
||||||
max_position_embeddings: int,
|
max_position_embeddings: int,
|
||||||
rope_theta: float,
|
|
||||||
cache_config: Optional[CacheConfig] = None,
|
cache_config: Optional[CacheConfig] = None,
|
||||||
quant_config: Optional[QuantizationConfig] = None,
|
quant_config: Optional[QuantizationConfig] = None,
|
||||||
attn_logits_soft_cap: Optional[float] = None) -> None:
|
attn_logits_soft_cap: Optional[float] = None) -> None:
|
||||||
@@ -110,7 +109,22 @@ class Gemma3Attention(nn.Module):
|
|||||||
self.q_size = self.num_heads * self.head_dim
|
self.q_size = self.num_heads * self.head_dim
|
||||||
self.kv_size = self.num_kv_heads * self.head_dim
|
self.kv_size = self.num_kv_heads * self.head_dim
|
||||||
self.scaling = config.query_pre_attn_scalar**-0.5
|
self.scaling = config.query_pre_attn_scalar**-0.5
|
||||||
self.rope_theta = rope_theta
|
|
||||||
|
# Extract rope_theta from config, compatible with both old-style
|
||||||
|
# (config.rope_theta) and new-style (config.rope_parameters dict).
|
||||||
|
rope_params = getattr(config, "rope_parameters", None)
|
||||||
|
if hasattr(config, "rope_theta"):
|
||||||
|
self.rope_theta = config.rope_theta
|
||||||
|
elif isinstance(rope_params, dict):
|
||||||
|
# Transformers v5: nested per layer_type
|
||||||
|
if "full_attention" in rope_params:
|
||||||
|
self.rope_theta = rope_params["full_attention"].get(
|
||||||
|
"rope_theta", 10000.0)
|
||||||
|
else:
|
||||||
|
# Transformers v4: flat dict
|
||||||
|
self.rope_theta = rope_params.get("rope_theta", 10000.0)
|
||||||
|
else:
|
||||||
|
self.rope_theta = 10000.0
|
||||||
|
|
||||||
self.qkv_proj = QKVParallelLinear(
|
self.qkv_proj = QKVParallelLinear(
|
||||||
hidden_size,
|
hidden_size,
|
||||||
@@ -236,16 +250,6 @@ class Gemma3DecoderLayer(nn.Module):
|
|||||||
) -> None:
|
) -> None:
|
||||||
super().__init__()
|
super().__init__()
|
||||||
self.hidden_size = config.hidden_size
|
self.hidden_size = config.hidden_size
|
||||||
# Extract rope_theta: try direct attribute first, then
|
|
||||||
# rope_parameters dict (newer transformers), fallback to 10000.0
|
|
||||||
rope_params = getattr(config, "rope_parameters", None)
|
|
||||||
if hasattr(config, "rope_theta"):
|
|
||||||
rope_theta = config.rope_theta
|
|
||||||
elif isinstance(rope_params, dict):
|
|
||||||
rope_theta = rope_params.get("rope_theta", 10000.0)
|
|
||||||
else:
|
|
||||||
rope_theta = 10000.0
|
|
||||||
|
|
||||||
self.self_attn = Gemma3Attention(
|
self.self_attn = Gemma3Attention(
|
||||||
layer_idx=layer_idx,
|
layer_idx=layer_idx,
|
||||||
config=config,
|
config=config,
|
||||||
@@ -254,7 +258,6 @@ class Gemma3DecoderLayer(nn.Module):
|
|||||||
num_kv_heads=config.num_key_value_heads,
|
num_kv_heads=config.num_key_value_heads,
|
||||||
head_dim=config.head_dim,
|
head_dim=config.head_dim,
|
||||||
max_position_embeddings=config.max_position_embeddings,
|
max_position_embeddings=config.max_position_embeddings,
|
||||||
rope_theta=rope_theta,
|
|
||||||
cache_config=cache_config,
|
cache_config=cache_config,
|
||||||
quant_config=quant_config,
|
quant_config=quant_config,
|
||||||
# Gemma3 does not use attn logit softcapping
|
# Gemma3 does not use attn logit softcapping
|
||||||
|
|||||||
Reference in New Issue
Block a user