14 Commits

Author SHA1 Message Date
Chranos
fa194c215b add gemma3 2026-02-10 14:52:56 +08:00
Chranos
5fbe8b20a7 add gemma3 2026-02-10 14:26:03 +08:00
Chranos
2dad4e71c5 add gemma3 2026-02-10 14:15:33 +08:00
Chranos
cb1846cd4f add gemma3 2026-02-10 14:10:04 +08:00
Chranos
81fc273396 add gemma3 2026-02-10 14:06:26 +08:00
Chranos
3ef89630ab add gemma3 2026-02-10 13:00:25 +08:00
Chranos
40dee08f7b fix: handle missing tie_word_embeddings attr in MPTConfig
Use getattr with default True for MPTConfig.tie_word_embeddings,
as some MPT model configs lack this attribute.

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
2026-02-09 17:47:18 +08:00
Chranos
1d70f93cfc debugging 2026-02-09 15:24:55 +08:00
Chranos
8ecba6115e fix: add logger import to llama.py for unknown weight skip warning
The previous commit added a warning log for skipping unknown weights
(e.g. embed_tokens.biases) but missed importing the logger.

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
2026-02-09 13:13:56 +08:00
Chranos
65ad893ee7 debugging 2026-02-09 13:00:35 +08:00
Chranos
d08217307d update README 2026-02-09 11:46:04 +08:00
Chranos
8ac4215755 update README 2026-02-09 11:44:52 +08:00
Chranos
a095dede48 fixed kvcache bug 2026-02-06 17:10:36 +08:00
Chranos
374826c841 fixing kvcache bug 2026-02-06 16:25:54 +08:00
9 changed files with 595 additions and 14 deletions

View File

@@ -163,5 +163,15 @@ curl http://localhost:80/v1/chat/completions \
| Model | MLU370-X8 first-token latency (s) | MLU370-X8 input processing speed (chars/s) | MLU370-X8 output speed (chars/s) | MLU370-X8 output quality | Nvidia A100 first-token latency (s) | Nvidia A100 input processing speed (chars/s) | Nvidia A100 output speed (chars/s) | Nvidia A100 output quality |
| ------------------- | ------------------- | -------------------| ------------------- | ------------------- | ------------------- | ------------------- | ------------------- | ------------------- |
| Qwen/Qwen-1_8B |0.203 | 13493.2 | 119.2 | 10.0 | 0.052 | 25591.5 | 165.0 | 15.0|
| Qwen/Qwen1.5-0.5B |0.132 | 12366.6 | 106.9 | 15.0 | 0.066 | 24935.4 | 151.4 | 10.0|
## Version History
| Version | Date | Changes |
|------|------|----------|
| v0.0.2 | 2026-02-04 | **Qwen3 model support**: adapted the QK Normalization architecture, fixed rope_scaling and tokenizer compatibility issues, and resolved view failures caused by non-contiguous tensors |
| v0.0.3 | 2026-02-06 | **Generic Transformers backend**: load arbitrary custom HuggingFace models via `auto_map`; added registry fallback logic, Linear return-value handling, RMSNorm dimension restoration, and more |
| v0.0.3.1 | 2026-02-06 | **CNNL tensor overflow fix**: resolved KV cache element counts exceeding the int32 limit when deploying very small models on devices with large free memory; added dual safeguards in mlu_worker and cache_engine |
| v0.0.4 | 2026-02-10 | **Gemma3 model support**: new Gemma3ForCausalLM implementation (QK Normalization, per-layer rope configuration, layer_types-based sliding window); fixed `patch_rope_scaling_dict` crashing when rope_scaling lacks a `rope_type` key; updated the model registry and the interleaved-attention and dtype auto-handling logic in config.py |
| v0.0.4.1 | 2026-02-10 | **Gemma3 rope compatibility fix**: handle newer transformers `Gemma3TextConfig` lacking a `rope_theta` attribute by extracting the rope configuration from the `rope_parameters` dict (supports Transformers v4/v5); fixed nested `rope_scaling` dicts being unhashable in the `get_rope` cache; adapted to the MLU `forward_mlu` interface by concatenating q/k into a single tensor for rotary_emb and splitting it again afterwards |
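
As a minimal sketch of the v0.0.4.1 compatibility logic (the helper name `resolve_rope_theta` is hypothetical and not part of the project API), the rope base can be read from either the legacy `rope_theta` attribute or the newer `rope_parameters` dict:

```python
def resolve_rope_theta(config, layer_type: str = "full_attention") -> float:
    """Hypothetical helper mirroring the fallback order used in gemma3.py."""
    # Older transformers expose a flat rope_theta attribute.
    if hasattr(config, "rope_theta"):
        return config.rope_theta
    rope_params = getattr(config, "rope_parameters", None)
    if isinstance(rope_params, dict):
        # Transformers v5 nests one sub-dict per layer type.
        if layer_type in rope_params:
            return rope_params[layer_type].get("rope_theta", 10000.0)
        # Transformers v4 keeps a flat dict.
        return rope_params.get("rope_theta", 10000.0)
    # Fall back to the default base frequency.
    return 10000.0
```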

View File

@@ -226,7 +226,7 @@ class ModelConfig:
sliding_window = getattr(self.hf_text_config, "sliding_window", None)
has_interleaved_attention = (sliding_window is not None) and (
isinstance(sliding_window, list) or
(self.hf_text_config.model_type in ["gemma2"]))
(self.hf_text_config.model_type in ["gemma2", "gemma3"]))
if (not self.disable_sliding_window and has_interleaved_attention):
sliding_window_len_min = get_min_sliding_window(
@@ -1854,9 +1854,9 @@ def _get_and_verify_dtype(
dtype = dtype.lower()
if dtype == "auto":
if config_dtype == torch.float32:
if config.model_type == "gemma2":
if config.model_type in ("gemma2", "gemma3"):
logger.info(
"For Gemma 2, we downcast float32 to bfloat16 instead "
"For Gemma 2/3, we downcast float32 to bfloat16 instead "
"of float16 by default. Please specify `dtype` if you "
"want to use float16.")
torch_dtype = torch.bfloat16

View File

@@ -0,0 +1,507 @@
# Copyright 2024 The vLLM team.
# Copyright 2024 Google Inc. HuggingFace Inc. team. All rights reserved.
#
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Inference-only Gemma3 model compatible with HuggingFace weights."""
from typing import Iterable, List, Optional, Set, Tuple, Union
import torch
from torch import nn
from vllm.attention import Attention, AttentionMetadata
from vllm.config import CacheConfig, VllmConfig
from vllm.distributed import get_pp_group, get_tensor_model_parallel_world_size
from vllm.logger import init_logger
from vllm.model_executor.layers.activation import GeluAndMul
from vllm.model_executor.layers.layernorm import GemmaRMSNorm
from vllm.model_executor.layers.linear import (MergedColumnParallelLinear,
QKVParallelLinear,
RowParallelLinear)
from vllm.model_executor.layers.logits_processor import LogitsProcessor
from vllm.model_executor.layers.quantization import QuantizationConfig
from vllm.model_executor.layers.rotary_embedding import get_rope
from vllm.model_executor.layers.sampler import SamplerOutput, get_sampler
from vllm.model_executor.layers.vocab_parallel_embedding import (
VocabParallelEmbedding)
from vllm.model_executor.model_loader.weight_utils import default_weight_loader
from vllm.model_executor.sampling_metadata import SamplingMetadata
from vllm.sequence import IntermediateTensors
from .interfaces import SupportsLoRA, SupportsPP
from .utils import (AutoWeightsLoader, is_pp_missing_parameter,
make_empty_intermediate_tensors_factory, make_layers,
maybe_prefix)
logger = init_logger(__name__)
class Gemma3MLP(nn.Module):
def __init__(
self,
hidden_size: int,
intermediate_size: int,
hidden_activation: str,
quant_config: Optional[QuantizationConfig] = None,
) -> None:
super().__init__()
self.gate_up_proj = MergedColumnParallelLinear(
hidden_size, [intermediate_size] * 2,
bias=False,
quant_config=quant_config)
self.down_proj = RowParallelLinear(intermediate_size,
hidden_size,
bias=False,
quant_config=quant_config)
if hidden_activation != "gelu_pytorch_tanh":
raise ValueError(
"Gemma3 uses `gelu_pytorch_tanh` as the hidden activation "
"function. Please set `hidden_activation` to "
"`gelu_pytorch_tanh`.")
self.act_fn = GeluAndMul(approximate="tanh")
def forward(self, x: torch.Tensor) -> torch.Tensor:
gate_up, _ = self.gate_up_proj(x)
x = self.act_fn(gate_up)
x, _ = self.down_proj(x)
return x
class Gemma3Attention(nn.Module):
def __init__(self,
layer_idx: int,
config,
hidden_size: int,
num_heads: int,
num_kv_heads: int,
head_dim: int,
max_position_embeddings: int,
cache_config: Optional[CacheConfig] = None,
quant_config: Optional[QuantizationConfig] = None,
attn_logits_soft_cap: Optional[float] = None) -> None:
super().__init__()
self.layer_idx = layer_idx
self.config = config
self.hidden_size = hidden_size
tp_size = get_tensor_model_parallel_world_size()
self.total_num_heads = num_heads
assert self.total_num_heads % tp_size == 0
self.num_heads = self.total_num_heads // tp_size
self.total_num_kv_heads = num_kv_heads
if self.total_num_kv_heads >= tp_size:
assert self.total_num_kv_heads % tp_size == 0
else:
assert tp_size % self.total_num_kv_heads == 0
self.num_kv_heads = max(1, self.total_num_kv_heads // tp_size)
self.head_dim = head_dim
self.q_size = self.num_heads * self.head_dim
self.kv_size = self.num_kv_heads * self.head_dim
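# Gemma scales attention by query_pre_attn_scalar**-0.5 rather than head_dim**-0.5.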
self.scaling = config.query_pre_attn_scalar**-0.5
# Extract rope_theta from config, compatible with both old-style
# (config.rope_theta) and new-style (config.rope_parameters dict).
rope_params = getattr(config, "rope_parameters", None)
if hasattr(config, "rope_theta"):
self.rope_theta = config.rope_theta
elif isinstance(rope_params, dict):
# Transformers v5: nested per layer_type
if "full_attention" in rope_params:
self.rope_theta = rope_params["full_attention"].get(
"rope_theta", 10000.0)
else:
# Transformers v4: flat dict
self.rope_theta = rope_params.get("rope_theta", 10000.0)
else:
self.rope_theta = 10000.0
self.qkv_proj = QKVParallelLinear(
hidden_size,
self.head_dim,
self.total_num_heads,
self.total_num_kv_heads,
bias=config.attention_bias,
quant_config=quant_config,
)
self.o_proj = RowParallelLinear(
self.total_num_heads * self.head_dim,
hidden_size,
bias=config.attention_bias,
quant_config=quant_config,
)
# Gemma3 specific: QK normalization
self.q_norm = GemmaRMSNorm(self.head_dim, eps=config.rms_norm_eps)
self.k_norm = GemmaRMSNorm(self.head_dim, eps=config.rms_norm_eps)
# Determine layer type and rope config
layer_types = getattr(config, "layer_types", None)
if layer_types is not None:
layer_type = layer_types[layer_idx]
self.is_sliding = (layer_type == "sliding_attention")
else:
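# Older configs without layer_types: fall back to Gemma2-style alternation
# (odd-indexed layers use sliding attention).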
self.is_sliding = (layer_idx % 2 == 1
and config.sliding_window is not None)
# Extract rope config, compatible with both old-style (rope_theta,
# rope_scaling) and new-style (rope_parameters dict) transformers.
rope_params = getattr(config, "rope_parameters", None)
# Set up rope based on layer type
if self.is_sliding:
# Local/sliding attention uses rope_local_base_freq
if hasattr(config, "rope_local_base_freq"):
local_base = config.rope_local_base_freq
elif (isinstance(rope_params, dict)
and "sliding_attention" in rope_params):
local_base = rope_params["sliding_attention"].get(
"rope_theta", self.rope_theta)
else:
local_base = self.rope_theta
self.rotary_emb = get_rope(
self.head_dim,
rotary_dim=self.head_dim,
max_position=max_position_embeddings,
base=local_base,
is_neox_style=True,
)
else:
# Global attention: extract rope_base and rope_scaling.
# Prioritize rope_parameters dict (newer transformers) to
# avoid passing nested dicts that are unhashable.
rope_scaling = None
rope_base = self.rope_theta
if isinstance(rope_params, dict):
# Transformers v5: per layer_type sub-dicts
if "full_attention" in rope_params:
rp = rope_params["full_attention"]
else:
# Transformers v4: flat dict
rp = rope_params
rope_base = rp.get("rope_theta", self.rope_theta)
rtype = rp.get("rope_type", None)
if rtype and rtype != "default":
rope_scaling = {
k: v for k, v in rp.items()
if k not in ("rope_theta",)
}
else:
# Fallback: old-style config.rope_scaling (flat dict)
rope_scaling = getattr(config, "rope_scaling", None)
self.rotary_emb = get_rope(
self.head_dim,
rotary_dim=self.head_dim,
max_position=max_position_embeddings,
base=rope_base,
is_neox_style=True,
rope_scaling=rope_scaling,
)
# NOTE: Like Gemma2, vLLM currently ignores sliding window
# and uses global attention for all layers.
self.attn = Attention(self.num_heads,
self.head_dim,
self.scaling,
num_kv_heads=self.num_kv_heads,
cache_config=cache_config,
quant_config=quant_config,
logits_soft_cap=attn_logits_soft_cap)
def forward(
self,
positions: torch.Tensor,
hidden_states: torch.Tensor,
kv_cache: torch.Tensor,
attn_metadata: AttentionMetadata,
) -> torch.Tensor:
qkv, _ = self.qkv_proj(hidden_states)
q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size],
dim=-1)
# Gemma3 specific: apply QK normalization
q = q.unflatten(-1, (self.num_heads, self.head_dim))
q = self.q_norm(q)
q = q.flatten(-2, -1)
k = k.unflatten(-1, (self.num_kv_heads, self.head_dim))
k = self.k_norm(k)
k = k.flatten(-2, -1)
# MLU rotary_emb expects a single concatenated tensor, not
# separate q and k (forward_mlu signature differs from forward_native).
qk = torch.cat([q, k], dim=-1)
self.rotary_emb(positions,
qk.view(-1, self.num_heads + self.num_kv_heads,
self.head_dim))
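# forward_mlu is expected to rotate the tensor in place through this view;
# qk is then split back into q and k.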
q, k = qk.split([self.q_size, self.kv_size], dim=-1)
attn_output = self.attn(q, k, v, kv_cache, attn_metadata)
output, _ = self.o_proj(attn_output)
return output
class Gemma3DecoderLayer(nn.Module):
def __init__(
self,
layer_idx: int,
config,
cache_config: Optional[CacheConfig] = None,
quant_config: Optional[QuantizationConfig] = None,
) -> None:
super().__init__()
self.hidden_size = config.hidden_size
self.self_attn = Gemma3Attention(
layer_idx=layer_idx,
config=config,
hidden_size=self.hidden_size,
num_heads=config.num_attention_heads,
num_kv_heads=config.num_key_value_heads,
head_dim=config.head_dim,
max_position_embeddings=config.max_position_embeddings,
cache_config=cache_config,
quant_config=quant_config,
# Gemma3 does not use attn logit softcapping; getattr returns None
# when the config omits the attribute.
attn_logits_soft_cap=getattr(config,
"attn_logit_softcapping", None),
)
self.hidden_size = config.hidden_size
self.mlp = Gemma3MLP(
hidden_size=self.hidden_size,
intermediate_size=config.intermediate_size,
hidden_activation=config.hidden_activation,
quant_config=quant_config,
)
self.input_layernorm = GemmaRMSNorm(config.hidden_size,
eps=config.rms_norm_eps)
self.post_attention_layernorm = GemmaRMSNorm(config.hidden_size,
eps=config.rms_norm_eps)
self.pre_feedforward_layernorm = GemmaRMSNorm(config.hidden_size,
eps=config.rms_norm_eps)
self.post_feedforward_layernorm = GemmaRMSNorm(config.hidden_size,
eps=config.rms_norm_eps)
def forward(
self,
positions: torch.Tensor,
hidden_states: torch.Tensor,
kv_cache: torch.Tensor,
attn_metadata: AttentionMetadata,
residual: Optional[torch.Tensor],
) -> Tuple[torch.Tensor, torch.Tensor]:
if residual is None:
residual = hidden_states
hidden_states = self.input_layernorm(hidden_states)
else:
hidden_states, residual = self.input_layernorm(
hidden_states, residual)
hidden_states = self.self_attn(
positions=positions,
hidden_states=hidden_states,
kv_cache=kv_cache,
attn_metadata=attn_metadata,
)
hidden_states = self.post_attention_layernorm(hidden_states)
hidden_states, residual = self.pre_feedforward_layernorm(
hidden_states, residual)
hidden_states = self.mlp(hidden_states)
hidden_states = self.post_feedforward_layernorm(hidden_states)
return hidden_states, residual
class Gemma3Model(nn.Module):
def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
super().__init__()
config = vllm_config.model_config.hf_config
cache_config = vllm_config.cache_config
quant_config = vllm_config.quant_config
self.config = config
self.embed_tokens = VocabParallelEmbedding(
config.vocab_size,
config.hidden_size,
)
self.start_layer, self.end_layer, self.layers = make_layers(
config.num_hidden_layers,
lambda prefix: Gemma3DecoderLayer(
int(prefix.split(".")[-1]),
config, cache_config, quant_config),
prefix=f"{prefix}.layers")
self.norm = GemmaRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
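# Gemma scales token embeddings by sqrt(hidden_size); keep the scale as a
# buffer so it follows the model's device and dtype.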
normalizer = self.config.hidden_size**0.5
self.register_buffer("normalizer", torch.tensor(normalizer))
self.make_empty_intermediate_tensors = (
make_empty_intermediate_tensors_factory(
["hidden_states", "residual"], config.hidden_size))
def forward(
self,
input_ids: Optional[torch.Tensor],
positions: torch.Tensor,
kv_caches: List[torch.Tensor],
attn_metadata: AttentionMetadata,
intermediate_tensors: Optional[IntermediateTensors],
inputs_embeds: Optional[torch.Tensor] = None,
) -> Union[torch.Tensor, IntermediateTensors]:
if get_pp_group().is_first_rank:
if inputs_embeds is not None:
hidden_states = inputs_embeds
else:
hidden_states = self.embed_tokens(input_ids)
hidden_states *= self.normalizer
residual = None
else:
assert intermediate_tensors is not None
hidden_states = intermediate_tensors["hidden_states"]
residual = intermediate_tensors["residual"]
for i in range(self.start_layer, self.end_layer):
layer = self.layers[i]
hidden_states, residual = layer(
positions,
hidden_states,
kv_caches[i - self.start_layer],
attn_metadata,
residual,
)
if not get_pp_group().is_last_rank:
return IntermediateTensors({
"hidden_states": hidden_states,
"residual": residual
})
hidden_states, _ = self.norm(hidden_states, residual)
return hidden_states
def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
stacked_params_mapping = [
# (param_name, shard_name, shard_id)
("qkv_proj", "q_proj", "q"),
("qkv_proj", "k_proj", "k"),
("qkv_proj", "v_proj", "v"),
("gate_up_proj", "gate_proj", 0),
("gate_up_proj", "up_proj", 1),
]
params_dict = dict(self.named_parameters())
loaded_params: Set[str] = set()
for name, loaded_weight in weights:
for (param_name, shard_name, shard_id) in stacked_params_mapping:
if shard_name not in name:
continue
name = name.replace(shard_name, param_name)
if name.endswith(".bias") and name not in params_dict:
continue
if is_pp_missing_parameter(name, self):
continue
param = params_dict[name]
weight_loader = param.weight_loader
weight_loader(param, loaded_weight, shard_id)
break
else:
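# No stacked mapping matched: load the weight directly with its own loader.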
if name.endswith(".bias") and name not in params_dict:
continue
if is_pp_missing_parameter(name, self):
continue
if name not in params_dict:
continue
param = params_dict[name]
weight_loader = getattr(param, "weight_loader",
default_weight_loader)
weight_loader(param, loaded_weight)
loaded_params.add(name)
unloaded_params = params_dict.keys() - loaded_params
if unloaded_params:
logger.warning(
"Some weights are not initialized from checkpoints: %s",
unloaded_params)
class Gemma3ForCausalLM(nn.Module, SupportsLoRA, SupportsPP):
packed_modules_mapping = {
"qkv_proj": [
"q_proj",
"k_proj",
"v_proj",
],
"gate_up_proj": [
"gate_proj",
"up_proj",
],
}
# LoRA specific attributes
supported_lora_modules = [
"qkv_proj",
"o_proj",
"gate_up_proj",
"down_proj",
]
embedding_modules = {}
embedding_padding_modules = []
def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
config = vllm_config.model_config.hf_config
quant_config = vllm_config.quant_config
lora_config = vllm_config.lora_config
del lora_config # Unused.
super().__init__()
self.config = config
self.quant_config = quant_config
self.model = Gemma3Model(vllm_config=vllm_config,
prefix=maybe_prefix(prefix, "model"))
# Gemma3 may or may not have final_logit_softcapping
soft_cap = getattr(config, "final_logit_softcapping", None)
self.logits_processor = LogitsProcessor(
config.vocab_size, soft_cap=soft_cap)
self.sampler = get_sampler()
self.make_empty_intermediate_tensors = (
self.model.make_empty_intermediate_tensors)
def forward(
self,
input_ids: torch.Tensor,
positions: torch.Tensor,
kv_caches: List[torch.Tensor],
attn_metadata: AttentionMetadata,
intermediate_tensors: Optional[IntermediateTensors] = None,
) -> Union[torch.Tensor, IntermediateTensors]:
hidden_states = self.model(input_ids, positions, kv_caches,
attn_metadata, intermediate_tensors)
return hidden_states
def compute_logits(
self,
hidden_states: torch.Tensor,
sampling_metadata: SamplingMetadata,
) -> Optional[torch.Tensor]:
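# Gemma ties input and output embeddings, so logits are computed against
# embed_tokens; there is no separate lm_head.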
logits = self.logits_processor(self.model.embed_tokens,
hidden_states, sampling_metadata)
return logits
def sample(
self,
logits: torch.Tensor,
sampling_metadata: SamplingMetadata,
) -> Optional[SamplerOutput]:
next_tokens = self.sampler(logits, sampling_metadata)
return next_tokens
def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
loader = AutoWeightsLoader(
self,
skip_prefixes=(["lm_head."]
if self.config.tie_word_embeddings else None),
)
loader.load_weights(weights)

View File

@@ -26,6 +26,10 @@ import torch
from torch import nn
from transformers import LlamaConfig
from vllm.logger import init_logger
logger = init_logger(__name__)
from vllm.attention import Attention, AttentionMetadata
from vllm.compilation.decorators import support_torch_compile
from vllm.config import CacheConfig, VllmConfig
@@ -404,6 +408,12 @@ class LlamaModel(nn.Module):
if is_pp_missing_parameter(name, self):
continue
if name not in params_dict:
logger.warning(
"Skipping weight %s not present in the model",
name)
continue
param = params_dict[name]
weight_loader = getattr(param, "weight_loader",
default_weight_loader)

View File

@@ -272,7 +272,7 @@ class MPTForCausalLM(nn.Module, SupportsPP):
config = vllm_config.model_config.hf_config
quant_config = vllm_config.quant_config
self.config = config
assert config.tie_word_embeddings
assert getattr(config, "tie_word_embeddings", True)
self.quant_config = quant_config
self.transformer = MPTModel(vllm_config=vllm_config,

View File

@@ -28,6 +28,9 @@ from .interfaces_base import is_embedding_model, is_text_generation_model
logger = init_logger(__name__)
# Cache for architectures that have already been logged
_logged_transformers_architectures: set = set()
# yapf: disable
_TEXT_GENERATION_MODELS = {
# [Decoder-only]
@@ -49,6 +52,7 @@ _TEXT_GENERATION_MODELS = {
"FalconForCausalLM": ("falcon", "FalconForCausalLM"),
"GemmaForCausalLM": ("gemma", "GemmaForCausalLM"),
"Gemma2ForCausalLM": ("gemma2", "Gemma2ForCausalLM"),
"Gemma3ForCausalLM": ("gemma3", "Gemma3ForCausalLM"),
"GPT2LMHeadModel": ("gpt2", "GPT2LMHeadModel"),
"GPTBigCodeForCausalLM": ("gpt_bigcode", "GPTBigCodeForCausalLM"),
"GPTJForCausalLM": ("gpt_j", "GPTJForCausalLM"),
@@ -403,11 +407,14 @@ class _ModelRegistry:
model_module = getattr(transformers, architecture, None)
if model_module is not None:
# Model exists in transformers, can use TransformersForCausalLM wrapper
logger.info(
"Architecture %s found in transformers library, "
"using TransformersForCausalLM wrapper",
architecture
)
# Only log once per architecture to avoid spam
if architecture not in _logged_transformers_architectures:
_logged_transformers_architectures.add(architecture)
logger.info(
"Architecture %s found in transformers library, "
"using TransformersForCausalLM wrapper",
architecture
)
return "TransformersForCausalLM"
# Get auto_map from hf_config

View File

@@ -112,7 +112,9 @@ def patch_rope_scaling_dict(rope_scaling: Dict[str, Any]) -> None:
logger.info("Replacing legacy 'type' key with 'rope_type'")
if "rope_type" not in rope_scaling:
raise ValueError("rope_scaling should have a 'rope_type' key")
rope_scaling["rope_type"] = "default"
logger.warning("rope_scaling missing 'rope_type' key, "
"defaulting to 'default'")
if rope_scaling["rope_type"] == "su":
rope_scaling["rope_type"] = "longrope"

View File

@@ -24,8 +24,29 @@ def vllm__worker__cache_engine__CacheEngine___allocate_kv_cache(
=============================
Modify by vllm_mlu
=============================
@brief: add kv_cache_scale for int8 support
'''
@brief: add kv_cache_scale for int8 support;
cap num_blocks to avoid exceeding CNNL int32 element limit
'''
# CNNL operators have a max supported tensor element count of INT32_MAX.
# num_blocks should already be capped by determine_num_available_blocks;
# this is a defensive check to catch any edge cases that slip through.
CNNL_MAX_TENSOR_ELEMENTS = 2**31 - 1
total_elements = 1
for dim in kv_cache_shape:
total_elements *= dim
if total_elements > CNNL_MAX_TENSOR_ELEMENTS:
elements_per_block = total_elements // num_blocks
max_num_blocks = CNNL_MAX_TENSOR_ELEMENTS // elements_per_block
logger.warning(
"KV cache tensor elements (%d) exceed CNNL max (%d). "
"Reducing num_blocks from %d to %d. This indicates "
"determine_num_available_blocks did not cap correctly.",
total_elements, CNNL_MAX_TENSOR_ELEMENTS,
num_blocks, max_num_blocks)
num_blocks = max_num_blocks
kv_cache_shape = self.attn_backend.get_kv_cache_shape(
num_blocks, self.block_size, self.num_kv_heads, self.head_size)
kv_cache_scales_shape = self.attn_backend.get_kv_cache_scale_shape(
num_blocks, self.block_size, self.num_kv_heads)
pin_memory = is_pin_memory_available() if device == "cpu" else False

View File

@@ -95,6 +95,30 @@ class MLUWorker_V2(MLUWorker):
num_gpu_blocks = max(num_gpu_blocks, 0)
num_cpu_blocks = max(num_cpu_blocks, 0)
# Cap num_gpu_blocks to avoid exceeding CNNL's int32 tensor element
# limit. CNNL operators do not support tensors with more than
# 2^31 - 1 elements. The KV cache shape is typically
# (2, num_blocks, num_kv_heads, block_size, head_size), and when
# num_blocks is very large (e.g. for tiny models with huge free
# memory), the total element count can overflow.
CNNL_MAX_TENSOR_ELEMENTS = 2**31 - 1
block_size = self.cache_config.block_size
num_kv_heads = self.model_config.get_num_kv_heads(
self.parallel_config)
head_size = self.model_config.get_head_size()
# kv_cache_shape = (2, num_blocks, num_kv_heads, block_size, head_size)
elements_per_block = 2 * num_kv_heads * block_size * head_size
if elements_per_block > 0:
max_blocks_by_cnnl = CNNL_MAX_TENSOR_ELEMENTS // elements_per_block
if num_gpu_blocks > max_blocks_by_cnnl:
logger.warning(
"Reducing num_gpu_blocks from %d to %d to stay within "
"CNNL max tensor element limit (%d). "
"elements_per_block=%d",
num_gpu_blocks, max_blocks_by_cnnl,
CNNL_MAX_TENSOR_ELEMENTS, elements_per_block)
num_gpu_blocks = max_blocks_by_cnnl
logger.info(
"Memory profiling results: total_gpu_memory=%.2fGiB"
" initial_memory_usage=%.2fGiB peak_torch_memory=%.2fGiB"