add gemma3
This commit is contained in:
@@ -173,3 +173,4 @@ curl http://localhost:80/v1/chat/completions \
|
|||||||
| v0.0.2 | 2026-02-04 | **Qwen3 模型支持**:实现 QK Normalization 架构适配,修复 rope_scaling 和 tokenizer 兼容性问题,解决张量连续性导致的 view 操作失败 |
|
| v0.0.2 | 2026-02-04 | **Qwen3 模型支持**:实现 QK Normalization 架构适配,修复 rope_scaling 和 tokenizer 兼容性问题,解决张量连续性导致的 view 操作失败 |
|
||||||
| v0.0.3 | 2026-02-06 | **Transformers 通用后端**:支持通过 `auto_map` 加载任意自定义 HuggingFace 模型,新增 registry 回退逻辑、Linear 返回值处理、RMSNorm 维度恢复等 |
|
| v0.0.3 | 2026-02-06 | **Transformers 通用后端**:支持通过 `auto_map` 加载任意自定义 HuggingFace 模型,新增 registry 回退逻辑、Linear 返回值处理、RMSNorm 维度恢复等 |
|
||||||
| v0.0.3.1 | 2026-02-06 | **CNNL Tensor 溢出修复**:解决极小模型在大显存设备上部署时 KV cache 元素数超过 int32 限制的问题,在 mlu_worker 和 cache_engine 中添加双重防护 |
|
| v0.0.3.1 | 2026-02-06 | **CNNL Tensor 溢出修复**:解决极小模型在大显存设备上部署时 KV cache 元素数超过 int32 限制的问题,在 mlu_worker 和 cache_engine 中添加双重防护 |
|
||||||
|
| v0.0.4 | 2026-02-10 | **Gemma3 模型支持**:新增 Gemma3ForCausalLM 模型实现(含 QK Normalization、per-layer rope 配置、layer_types 滑动窗口),修复 `patch_rope_scaling_dict` 在 rope_scaling 缺少 `rope_type` 键时崩溃的问题,更新模型注册表及 config.py 中 interleaved attention 和 dtype 自动处理逻辑 |
|
||||||
|
|||||||
@@ -140,11 +140,21 @@ class Gemma3Attention(nn.Module):
|
|||||||
self.is_sliding = (layer_idx % 2 == 1
|
self.is_sliding = (layer_idx % 2 == 1
|
||||||
and config.sliding_window is not None)
|
and config.sliding_window is not None)
|
||||||
|
|
||||||
|
# Extract rope config, compatible with both old-style (rope_theta,
|
||||||
|
# rope_scaling) and new-style (rope_parameters dict) transformers.
|
||||||
|
rope_params = getattr(config, "rope_parameters", None)
|
||||||
|
|
||||||
# Set up rope based on layer type
|
# Set up rope based on layer type
|
||||||
if self.is_sliding:
|
if self.is_sliding:
|
||||||
# Local/sliding attention uses rope_local_base_freq
|
# Local/sliding attention uses rope_local_base_freq
|
||||||
local_base = getattr(config, "rope_local_base_freq",
|
if hasattr(config, "rope_local_base_freq"):
|
||||||
self.rope_theta)
|
local_base = config.rope_local_base_freq
|
||||||
|
elif (isinstance(rope_params, dict)
|
||||||
|
and "sliding_attention" in rope_params):
|
||||||
|
local_base = rope_params["sliding_attention"].get(
|
||||||
|
"rope_theta", self.rope_theta)
|
||||||
|
else:
|
||||||
|
local_base = self.rope_theta
|
||||||
self.rotary_emb = get_rope(
|
self.rotary_emb = get_rope(
|
||||||
self.head_dim,
|
self.head_dim,
|
||||||
rotary_dim=self.head_dim,
|
rotary_dim=self.head_dim,
|
||||||
@@ -155,11 +165,27 @@ class Gemma3Attention(nn.Module):
|
|||||||
else:
|
else:
|
||||||
# Global attention uses rope_scaling from config
|
# Global attention uses rope_scaling from config
|
||||||
rope_scaling = getattr(config, "rope_scaling", None)
|
rope_scaling = getattr(config, "rope_scaling", None)
|
||||||
|
rope_base = self.rope_theta
|
||||||
|
if rope_scaling is None and isinstance(rope_params, dict):
|
||||||
|
# Try to extract from rope_parameters (newer transformers)
|
||||||
|
if "full_attention" in rope_params:
|
||||||
|
rp = rope_params["full_attention"]
|
||||||
|
else:
|
||||||
|
rp = rope_params
|
||||||
|
rope_base = rp.get("rope_theta", self.rope_theta)
|
||||||
|
rtype = rp.get("rope_type", None)
|
||||||
|
if rtype and rtype != "default":
|
||||||
|
rope_scaling = {
|
||||||
|
k: v for k, v in rp.items()
|
||||||
|
if k not in ("rope_theta",)
|
||||||
|
}
|
||||||
|
rope_scaling["type"] = rope_scaling.pop("rope_type",
|
||||||
|
rtype)
|
||||||
self.rotary_emb = get_rope(
|
self.rotary_emb = get_rope(
|
||||||
self.head_dim,
|
self.head_dim,
|
||||||
rotary_dim=self.head_dim,
|
rotary_dim=self.head_dim,
|
||||||
max_position=max_position_embeddings,
|
max_position=max_position_embeddings,
|
||||||
base=self.rope_theta,
|
base=rope_base,
|
||||||
is_neox_style=True,
|
is_neox_style=True,
|
||||||
rope_scaling=rope_scaling,
|
rope_scaling=rope_scaling,
|
||||||
)
|
)
|
||||||
@@ -210,6 +236,16 @@ class Gemma3DecoderLayer(nn.Module):
|
|||||||
) -> None:
|
) -> None:
|
||||||
super().__init__()
|
super().__init__()
|
||||||
self.hidden_size = config.hidden_size
|
self.hidden_size = config.hidden_size
|
||||||
|
# Extract rope_theta: try direct attribute first, then
|
||||||
|
# rope_parameters dict (newer transformers), fallback to 10000.0
|
||||||
|
rope_params = getattr(config, "rope_parameters", None)
|
||||||
|
if hasattr(config, "rope_theta"):
|
||||||
|
rope_theta = config.rope_theta
|
||||||
|
elif isinstance(rope_params, dict):
|
||||||
|
rope_theta = rope_params.get("rope_theta", 10000.0)
|
||||||
|
else:
|
||||||
|
rope_theta = 10000.0
|
||||||
|
|
||||||
self.self_attn = Gemma3Attention(
|
self.self_attn = Gemma3Attention(
|
||||||
layer_idx=layer_idx,
|
layer_idx=layer_idx,
|
||||||
config=config,
|
config=config,
|
||||||
@@ -218,7 +254,7 @@ class Gemma3DecoderLayer(nn.Module):
|
|||||||
num_kv_heads=config.num_key_value_heads,
|
num_kv_heads=config.num_key_value_heads,
|
||||||
head_dim=config.head_dim,
|
head_dim=config.head_dim,
|
||||||
max_position_embeddings=config.max_position_embeddings,
|
max_position_embeddings=config.max_position_embeddings,
|
||||||
rope_theta=config.rope_theta,
|
rope_theta=rope_theta,
|
||||||
cache_config=cache_config,
|
cache_config=cache_config,
|
||||||
quant_config=quant_config,
|
quant_config=quant_config,
|
||||||
# Gemma3 does not use attn logit softcapping
|
# Gemma3 does not use attn logit softcapping
|
||||||
|
|||||||
Reference in New Issue
Block a user