add gemma3

This commit is contained in:
Chranos
2026-02-10 14:10:04 +08:00
parent 5b9e02990a
commit 2e24d45668

View File

@@ -88,7 +88,6 @@ class Gemma3Attention(nn.Module):
num_kv_heads: int, num_kv_heads: int,
head_dim: int, head_dim: int,
max_position_embeddings: int, max_position_embeddings: int,
rope_theta: float,
cache_config: Optional[CacheConfig] = None, cache_config: Optional[CacheConfig] = None,
quant_config: Optional[QuantizationConfig] = None, quant_config: Optional[QuantizationConfig] = None,
attn_logits_soft_cap: Optional[float] = None) -> None: attn_logits_soft_cap: Optional[float] = None) -> None:
@@ -110,7 +109,22 @@ class Gemma3Attention(nn.Module):
self.q_size = self.num_heads * self.head_dim self.q_size = self.num_heads * self.head_dim
self.kv_size = self.num_kv_heads * self.head_dim self.kv_size = self.num_kv_heads * self.head_dim
self.scaling = config.query_pre_attn_scalar**-0.5 self.scaling = config.query_pre_attn_scalar**-0.5
self.rope_theta = rope_theta
# Extract rope_theta from config, compatible with both old-style
# (config.rope_theta) and new-style (config.rope_parameters dict).
rope_params = getattr(config, "rope_parameters", None)
if hasattr(config, "rope_theta"):
self.rope_theta = config.rope_theta
elif isinstance(rope_params, dict):
# Transformers v5: nested per layer_type
if "full_attention" in rope_params:
self.rope_theta = rope_params["full_attention"].get(
"rope_theta", 10000.0)
else:
# Transformers v4: flat dict
self.rope_theta = rope_params.get("rope_theta", 10000.0)
else:
self.rope_theta = 10000.0
self.qkv_proj = QKVParallelLinear( self.qkv_proj = QKVParallelLinear(
hidden_size, hidden_size,
@@ -236,16 +250,6 @@ class Gemma3DecoderLayer(nn.Module):
) -> None: ) -> None:
super().__init__() super().__init__()
self.hidden_size = config.hidden_size self.hidden_size = config.hidden_size
# Extract rope_theta: try direct attribute first, then
# rope_parameters dict (newer transformers), fallback to 10000.0
rope_params = getattr(config, "rope_parameters", None)
if hasattr(config, "rope_theta"):
rope_theta = config.rope_theta
elif isinstance(rope_params, dict):
rope_theta = rope_params.get("rope_theta", 10000.0)
else:
rope_theta = 10000.0
self.self_attn = Gemma3Attention( self.self_attn = Gemma3Attention(
layer_idx=layer_idx, layer_idx=layer_idx,
config=config, config=config,
@@ -254,7 +258,6 @@ class Gemma3DecoderLayer(nn.Module):
num_kv_heads=config.num_key_value_heads, num_kv_heads=config.num_key_value_heads,
head_dim=config.head_dim, head_dim=config.head_dim,
max_position_embeddings=config.max_position_embeddings, max_position_embeddings=config.max_position_embeddings,
rope_theta=rope_theta,
cache_config=cache_config, cache_config=cache_config,
quant_config=quant_config, quant_config=quant_config,
# Gemma3 does not use attn logit softcapping # Gemma3 does not use attn logit softcapping