71 Commits
v0.0.3 ... main

Author SHA1 Message Date
Chranos
91cd25a8d1 update readme 2026-02-11 17:59:22 +08:00
Chranos
bc9ae6a58a update readme 2026-02-11 17:57:20 +08:00
Chranos
cfc0614191 update readme 2026-02-11 17:47:15 +08:00
Chranos
01560f8227 add deepseekv3 and llama4 2026-02-11 17:47:15 +08:00
Chranos
19c3cfb624 add llama4 2026-02-11 17:47:15 +08:00
Chranos
8a85f8580f add llama4 2026-02-11 17:47:15 +08:00
Chranos
5457f79dbb add deepseekv3 and llama4 2026-02-11 17:47:15 +08:00
Chranos
1f77771852 add deepseekv3 and llama4 2026-02-11 17:47:15 +08:00
Chranos
e0bd67be53 add deepseekv3 and llama4 2026-02-11 17:47:15 +08:00
Chranos
d860f71e4d add deepseekv3 and llama4 2026-02-11 17:47:15 +08:00
Chranos
8657cbec87 add deepseekv3 and llama4 2026-02-11 17:47:15 +08:00
Chranos
597187b7e5 add deepseekv3 and llama4 2026-02-11 17:47:15 +08:00
Chranos
72507b7703 add deepseekv3 and llama4 2026-02-11 17:47:15 +08:00
Chranos
a69129d5b5 add deepseekv3 and llama4 2026-02-11 17:47:15 +08:00
Chranos
f6d6f69abc add deepseekv3 and llama4 2026-02-11 17:47:15 +08:00
Chranos
9b05d7285e add deepseekv3 and llama4 2026-02-11 17:47:15 +08:00
Chranos
cba7ad6c59 add deepseekv3 and llama4 2026-02-11 17:47:15 +08:00
Chranos
db876765ed add deepseekv3 and llama4 2026-02-11 17:47:15 +08:00
Chranos
78814aaa68 add deepseekv3 and llama4 2026-02-11 17:47:15 +08:00
Chranos
45e1fa8bb3 add deepseekv3 and llama4 2026-02-11 17:47:15 +08:00
Chranos
1a3e04b0e4 add deepseekv3 and llama4 2026-02-11 17:47:15 +08:00
Chranos
5ed7baa68e add deepseekv3 and llama4 2026-02-11 17:47:15 +08:00
Chranos
a21eae79a1 add deepseekv3 and llama4 2026-02-11 17:47:15 +08:00
Chranos
5da783780d add deepseekv3 and llama4 2026-02-11 17:47:15 +08:00
Chranos
5c980830a0 add deepseekv3 and llama4 2026-02-11 17:47:15 +08:00
Chranos
00083a1c76 add deepseekv3 and llama4 2026-02-11 17:47:15 +08:00
Chranos
4ed73b2ef6 add deepseekv3 and llama4 2026-02-11 17:47:15 +08:00
Chranos
386b7ec8c7 add deepseekv3 and llama4 2026-02-11 17:47:15 +08:00
Chranos
01c0b3d345 add deepseekv3 and llama4 2026-02-11 17:47:15 +08:00
Chranos
4a72c4c91a add deepseekv3 2026-02-11 17:47:15 +08:00
Chranos
9acbda437e add deepseekv3 2026-02-11 17:47:15 +08:00
Chranos
dfb4cff2fc add deepseekv3 2026-02-11 17:47:15 +08:00
Chranos
6c222e8f14 add deepseekv3 2026-02-11 17:47:15 +08:00
Chranos
3ec228b6fa add deepseekv3 2026-02-11 17:47:15 +08:00
Chranos
463fbf8cd1 add qwen3_moe 2026-02-11 17:47:15 +08:00
Chranos
6f6997bafb add qwen3_moe 2026-02-11 17:47:14 +08:00
Chranos
6479429662 add qwen3_moe 2026-02-11 17:47:14 +08:00
Chranos
2a9f483af8 add qwen3_moe 2026-02-11 17:47:14 +08:00
Chranos
cf92e95688 add qwen3_moe 2026-02-11 17:47:14 +08:00
Chranos
d7f5ef1db9 add qwen3_moe 2026-02-11 17:47:14 +08:00
Chranos
de8fc97532 debugging 2026-02-11 17:47:14 +08:00
Chranos
893eeb2208 debugging 2026-02-11 17:47:14 +08:00
Chranos
8f2ae4f67e add gemma3 2026-02-11 17:47:14 +08:00
Chranos
89dc931222 add gemma3 2026-02-11 17:47:14 +08:00
Chranos
a7028ae481 add gemma3 2026-02-11 17:47:14 +08:00
Chranos
2e24d45668 add gemma3 2026-02-11 17:47:14 +08:00
Chranos
5b9e02990a add gemma3 2026-02-11 17:47:14 +08:00
Chranos
ff94650fd1 add gemma3 2026-02-11 17:47:14 +08:00
Chranos
464beead22 fix: handle missing tie_word_embeddings attr in MPTConfig
Use getattr with default True for MPTConfig.tie_word_embeddings,
as some MPT model configs lack this attribute.

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
2026-02-11 17:47:14 +08:00
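The commit above describes the fix only in prose; a minimal sketch of the pattern (a plain namespace stands in for MPTConfig, which the real code reads from):

```python
from types import SimpleNamespace

def resolve_tie_word_embeddings(config) -> bool:
    # Mirror the fix: default to True when the config lacks the attribute,
    # matching MPT's historical behavior of tying input/output embeddings.
    return getattr(config, "tie_word_embeddings", True)

legacy_cfg = SimpleNamespace()  # older MPT config without the attribute
modern_cfg = SimpleNamespace(tie_word_embeddings=False)
print(resolve_tie_word_embeddings(legacy_cfg))  # True
print(resolve_tie_word_embeddings(modern_cfg))  # False
```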
Chranos
ad087d5cf3 debugging 2026-02-11 17:47:14 +08:00
Chranos
6b708a43d8 fix: add logger import to llama.py for unknown weight skip warning
The previous commit added a warning log for skipping unknown weights
(e.g. embed_tokens.biases) but missed importing the logger.

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
2026-02-11 17:47:14 +08:00
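A sketch of the pattern this commit restores, using stdlib `logging` in place of vLLM's `init_logger`; the `load_weights` helper and the weight names are illustrative, not the repo's exact code:

```python
import logging

logger = logging.getLogger(__name__)

def load_weights(weights, params_dict):
    # Skip weight names the model does not declare (e.g. embed_tokens.biases),
    # logging a warning instead of raising a KeyError. Without the module-level
    # logger above, the warning call itself raises NameError.
    loaded = []
    for name, tensor in weights:
        if name not in params_dict:
            logger.warning("Skipping unknown weight %s", name)
            continue
        loaded.append(name)
    return loaded

params = {"model.embed_tokens.weight": None}
print(load_weights([("model.embed_tokens.weight", 0),
                    ("model.embed_tokens.biases", 0)], params))
# ['model.embed_tokens.weight']
```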
Chranos
8efce7c44c debugging 2026-02-11 17:47:14 +08:00
Chranos
66d146dfad update README 2026-02-11 17:47:14 +08:00
Chranos
c35d463486 update README 2026-02-11 17:47:14 +08:00
Chranos
7420866d4c fixed kvcache bug 2026-02-11 17:47:14 +08:00
Chranos
3fed2190ad fixing kvcache bug 2026-02-06 16:39:42 +08:00
Chranos
c1b6f39a11 fix: pass lm_head to LogitsProcessor instead of calling forward()
In vLLM v0.6.2, ParallelLMHead.forward() raises RuntimeError since
its weights should be used through LogitsProcessor.linear_method.apply().
Pass lm_head as first arg to LogitsProcessor which handles the
hidden_states -> logits projection internally.
2026-02-06 15:05:49 +08:00
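A toy illustration of the calling convention this commit adopts; `ToyLogitsProcessor` is a stand-in for vLLM's `LogitsProcessor`, not its actual implementation:

```python
import torch

class ToyLogitsProcessor:
    """Stand-in for vLLM's LogitsProcessor: it receives the lm_head module
    and applies its weight matrix itself, rather than calling
    lm_head.forward(), which raises for ParallelLMHead in vLLM v0.6.2."""

    def __call__(self, lm_head: torch.nn.Linear,
                 hidden_states: torch.Tensor) -> torch.Tensor:
        # The hidden_states -> logits projection happens here, not in lm_head.
        return torch.nn.functional.linear(hidden_states, lm_head.weight)

lm_head = torch.nn.Linear(8, 32, bias=False)
hidden = torch.randn(2, 8)
logits = ToyLogitsProcessor()(lm_head, hidden)
print(logits.shape)  # torch.Size([2, 32])
```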
Chranos
3e301ce158 testing dynamic register 2026-02-06 15:05:49 +08:00
Chranos
87f96e1001 testing dynamic register 2026-02-06 15:05:49 +08:00
Chranos
e1a2afd244 testing dynamic register 2026-02-06 15:05:49 +08:00
Chranos
63a1a05999 testing dynamic register 2026-02-06 15:05:49 +08:00
Chranos
6d814b0cd4 testing dynamic register 2026-02-06 15:05:48 +08:00
Chranos
dc239a740c testing dynamic register 2026-02-06 15:05:48 +08:00
Chranos
a476b6458b testing dynamic register 2026-02-06 15:05:48 +08:00
Chranos
80e9a636af testing dynamic register 2026-02-06 15:05:48 +08:00
Chranos
16353d5d2a testing dynamic register 2026-02-06 15:05:48 +08:00
Chranos
70bee4e3ec testing dynamic register 2026-02-06 15:05:48 +08:00
Chranos
83c958a7c5 testing dynamic register 2026-02-06 15:05:48 +08:00
Chranos
9b84dd52be testing dynamic register 2026-02-06 15:05:48 +08:00
Chranos
2cb9f6ce1d testing dynamic register 2026-02-06 15:05:48 +08:00
31e7cd3bf9 Remove .DS_Store 2026-02-05 16:21:10 +08:00
37 changed files with 4551 additions and 176 deletions

BIN
.DS_Store vendored

Binary file not shown.

View File

@@ -1,12 +1,28 @@
# enginex-mlu370-vllm
# Cambricon MLU370 Text Generation
This model test framework runs on Cambricon MLU370 X8/X4 accelerator cards on top of the vLLM inference engine and is adapted for the Qwen1.5-1.8B-Chat model.
A text generation framework for Cambricon MLU370 X8/X4 accelerator cards, built on the vLLM inference engine.
## Changelog
* Qwen1.5-1.8B-Chat: a lightweight Chinese/English chat model of roughly 1.8B parameters from the Tongyi Qianwen series, designed for efficient inference and multi-scenario chat interaction.
* Llama-2-7b-chat-hf: the 7B-parameter chat-optimized open model in Meta's LLaMA 2 series, suited for multi-turn chat and general tasks.
* ChatGLM3-6B: the 6B-parameter Chinese/English bilingual chat model from Zhipu AI's third-generation ChatGLM series, with reasoning, code, and multi-task capabilities.
**v0.0.6.2** — 2026-02-11 · Llama4 model support, including sigmoid-routing MoE, QK Norm, and alternating dense/MoE layers; due to the MLU370 (capability=3) limitation, MoE falls back to dense mode to resolve graph capture compatibility
**v0.0.6.1** — 2026-02-11 · DeepSeek V3 MTP speculative decoding: a new MTP draft model reuses DeepseekV2DecoderLayer, with automatic detection and enabling of MTP speculative decoding
**v0.0.6** — 2026-02-11 · DeepSeek V3 model support, reusing the V2 implementation; adds `noaux_tc` routing and fixes the MLA unpaged cache operator
**v0.0.5** — 2026-02-10 · Qwen3MoE model support; fixes a FusedMoE `forward_mlu` signature bug
**v0.0.4.1** — 2026-02-10 · Gemma3 rope compatibility fix, adapting to the MLU rotary_emb interface
**v0.0.4** — 2026-02-10 · Gemma3 model support, including QK Norm, per-layer rope, and sliding window
**v0.0.3.1** — 2026-02-06 · CNNL tensor overflow fix: guards the KV cache element count against the int32 limit
**v0.0.3** — 2026-02-06 · Generic Transformers backend, supporting `auto_map` loading of custom HF models
**v0.0.2** — 2026-02-04 · Qwen3 model support: QK Norm adaptation plus rope/tokenizer compatibility fixes
---
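The v0.0.3.1 entry above mentions an int32 guard on the KV cache element count; a hypothetical sketch of such a check (the function and parameter names are illustrative, not the repo's actual code):

```python
INT32_MAX = 2**31 - 1  # CNNL indexes tensors with 32-bit integers

def check_kv_cache_numel(num_blocks: int, block_size: int,
                         num_kv_heads: int, head_dim: int) -> int:
    # Reject allocations whose flat element count would overflow int32.
    numel = num_blocks * block_size * num_kv_heads * head_dim
    if numel > INT32_MAX:
        raise ValueError(
            f"KV cache element count {numel} exceeds the int32 limit "
            f"({INT32_MAX}); reduce the number of cache blocks.")
    return numel

print(check_kv_cache_numel(1024, 16, 32, 128))  # 67108864
```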
## Quick Start
1. First, download the text generation model `Qwen1.5-1.8B-Chat` from ModelScope
@@ -163,5 +179,5 @@ curl http://localhost:80/v1/chat/completions \
| Model | MLU370-X8 first-token latency (s) | MLU370-X8 input speed (chars/s) | MLU370-X8 output speed (chars/s) | MLU370-X8 output quality | Nvidia A100 first-token latency (s) | Nvidia A100 input speed (chars/s) | Nvidia A100 output speed (chars/s) | Nvidia A100 output quality |
| ------------------- | ------------------- | -------------------| ------------------- | ------------------- | ------------------- | ------------------- | ------------------- | ------------------- |
| Qwen/Qwen-1_8B |0.203 | 13493.2 | 119.2 | 10.0 | 0.052 | 25591.5 | 165.0 | 15.0|
| Qwen/Qwen1.5-0.5B |0.132 | 12366.6 | 106.9 | 15.0 | 0.066 | 24935.4 | 151.4 | 10.0|

View File

@@ -226,7 +226,7 @@ class ModelConfig:
sliding_window = getattr(self.hf_text_config, "sliding_window", None)
has_interleaved_attention = (sliding_window is not None) and (
isinstance(sliding_window, list) or
(self.hf_text_config.model_type in ["gemma2"]))
(self.hf_text_config.model_type in ["gemma2", "gemma3"]))
if (not self.disable_sliding_window and has_interleaved_attention):
sliding_window_len_min = get_min_sliding_window(
@@ -353,8 +353,20 @@ class ModelConfig:
task_support: Dict[_Task, bool] = {
# NOTE: Listed from highest to lowest priority,
# in case the model supports multiple of them
"generate": ModelRegistry.is_text_generation_model(architectures),
"embedding": ModelRegistry.is_embedding_model(architectures),
"generate": ModelRegistry.is_text_generation_model(
architectures,
model_path=self.model,
revision=self.revision,
trust_remote_code=self.trust_remote_code,
hf_config=hf_config,
),
"embedding": ModelRegistry.is_embedding_model(
architectures,
model_path=self.model,
revision=self.revision,
trust_remote_code=self.trust_remote_code,
hf_config=hf_config,
),
}
supported_tasks_lst: List[_Task] = [
task for task, is_supported in task_support.items() if is_supported
@@ -1391,6 +1403,18 @@ class SpeculativeConfig:
draft_hf_config = draft_model_config.hf_config
# Detect DeepSeek V3 MTP: same model path with
# num_nextn_predict_layers > 0
num_nextn = getattr(draft_hf_config,
"num_nextn_predict_layers", 0)
if (num_nextn and num_nextn > 0
and getattr(draft_hf_config, "model_type", "")
in ("deepseek_v3",)):
draft_hf_config.model_type = "deepseek_mtp"
draft_hf_config.architectures = ["DeepSeekMTPModel"]
if num_speculative_tokens is None:
num_speculative_tokens = num_nextn
if (num_speculative_tokens is not None
and hasattr(draft_hf_config, "num_lookahead_tokens")):
draft_hf_config.num_lookahead_tokens = num_speculative_tokens
@@ -1409,7 +1433,7 @@ class SpeculativeConfig:
f"{num_speculative_tokens=} was provided.")
if enable_chunked_prefill and draft_hf_config.model_type in (
"medusa", "mlp_speculator", "eagle"):
"medusa", "mlp_speculator", "eagle", "deepseek_mtp"):
raise ValueError(
"Chunked prefill and hidden-state based draft models are "
"not compatible.")
@@ -1842,9 +1866,9 @@ def _get_and_verify_dtype(
dtype = dtype.lower()
if dtype == "auto":
if config_dtype == torch.float32:
if config.model_type == "gemma2":
if config.model_type in ("gemma2", "gemma3"):
logger.info(
"For Gemma 2, we downcast float32 to bfloat16 instead "
"For Gemma 2/3, we downcast float32 to bfloat16 instead "
"of float16 by default. Please specify `dtype` if you "
"want to use float16.")
torch_dtype = torch.bfloat16

View File

@@ -153,23 +153,25 @@ class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp):
def forward_mlu(
self,
layer: torch.nn.Module,
x: torch.Tensor,
w1: torch.Tensor,
w2: torch.Tensor,
router_logits: torch.Tensor,
top_k: int,
renormalize: bool,
use_grouped_topk: bool,
num_expert_group: Optional[int],
topk_group: Optional[int],
topk_group: Optional[int] = None,
num_expert_group: Optional[int] = None,
custom_routing_function: Optional[Callable] = None,
) -> torch.Tensor:
from vllm._mlu_ops import fused_moe
assert use_grouped_topk is False and num_expert_group is None and topk_group is None, \
f"Following params: use_grouped_topk, num_expert_group, topk_group are not support yet."
assert use_grouped_topk is False and num_expert_group is None \
and topk_group is None, \
"Following params: use_grouped_topk, num_expert_group, " \
"topk_group are not supported yet."
return fused_moe(x,
router_logits,
w1, w2,
layer.w13_weight, layer.w2_weight,
None, None, # bias1, bias2
None, # residual
None, # input_smooth

View File

@@ -143,11 +143,14 @@ class RMSNorm(CustomOp):
from vllm import _mlu_ops as mlu_ops
x = x.view(-1, self.weight.data.shape[0])
weight = self.weight.data
if weight.dtype != x.dtype:
weight = weight.to(x.dtype)
if residual is not None:
residual = residual.view(-1, self.weight.data.shape[0])
return mlu_ops.fused_rms_norm(x, residual, self.weight.data, None, None, self.variance_epsilon, True)
return mlu_ops.fused_rms_norm(x, residual, weight, None, None, self.variance_epsilon, True)
else:
return mlu_ops.fused_rms_norm(x, residual, self.weight.data, None, None, self.variance_epsilon, False)
return mlu_ops.fused_rms_norm(x, residual, weight, None, None, self.variance_epsilon, False)
def extra_repr(self) -> str:
s = f"hidden_size={self.weight.data.size(0)}"

View File

@@ -146,6 +146,7 @@ class LinearBase(torch.nn.Module):
skip_bias_add: If true, skip adding bias but instead return it.
params_dtype: Data type for the parameters.
quant_config: Quantization configure.
return_bias: If False, return only output tensor instead of (output, bias) tuple.
"""
def __init__(
@@ -156,6 +157,7 @@ class LinearBase(torch.nn.Module):
params_dtype: Optional[torch.dtype] = None,
quant_config: Optional[QuantizationConfig] = None,
prefix: str = "",
return_bias: bool = True,
):
super().__init__()
@@ -163,6 +165,7 @@ class LinearBase(torch.nn.Module):
self.input_size = input_size
self.output_size = output_size
self.skip_bias_add = skip_bias_add
self.return_bias = return_bias
if params_dtype is None:
params_dtype = torch.get_default_dtype()
self.params_dtype = params_dtype
@@ -198,13 +201,15 @@ class ReplicatedLinear(LinearBase):
skip_bias_add: bool = False,
params_dtype: Optional[torch.dtype] = None,
quant_config: Optional[QuantizationConfig] = None,
prefix: str = ""):
prefix: str = "",
return_bias: bool = True):
super().__init__(input_size,
output_size,
skip_bias_add,
params_dtype,
quant_config,
prefix=prefix)
prefix=prefix,
return_bias=return_bias)
# All the linear layer supports quant method.
assert self.quant_method is not None
@@ -238,6 +243,9 @@ class ReplicatedLinear(LinearBase):
bias = self.bias if not self.skip_bias_add else None
assert self.quant_method is not None
output = self.quant_method.apply(self, x, bias)
if not self.return_bias:
return output
output_bias = self.bias if self.skip_bias_add else None
return output, output_bias
@@ -281,9 +289,10 @@ class ColumnParallelLinear(LinearBase):
params_dtype: Optional[torch.dtype] = None,
quant_config: Optional[QuantizationConfig] = None,
output_sizes: Optional[List[int]] = None,
prefix: str = ""):
prefix: str = "",
return_bias: bool = True):
super().__init__(input_size, output_size, skip_bias_add, params_dtype,
quant_config, prefix)
quant_config, prefix, return_bias=return_bias)
self.gather_output = gather_output
@@ -375,6 +384,9 @@ class ColumnParallelLinear(LinearBase):
output = tensor_model_parallel_all_gather(output_parallel)
else:
output = output_parallel
if not self.return_bias:
return output
output_bias = self.bias if self.skip_bias_add else None
return output, output_bias
@@ -418,7 +430,8 @@ class MergedColumnParallelLinear(ColumnParallelLinear):
skip_bias_add: bool = False,
params_dtype: Optional[torch.dtype] = None,
quant_config: Optional[QuantizationConfig] = None,
prefix: str = ""):
prefix: str = "",
return_bias: bool = True):
self.output_sizes = output_sizes
tp_size = get_tensor_model_parallel_world_size()
assert all(output_size % tp_size == 0 for output_size in output_sizes)
@@ -429,7 +442,8 @@ class MergedColumnParallelLinear(ColumnParallelLinear):
skip_bias_add=skip_bias_add,
params_dtype=params_dtype,
quant_config=quant_config,
prefix=prefix)
prefix=prefix,
return_bias=return_bias)
def weight_loader(self,
param: Parameter,
@@ -653,7 +667,8 @@ class QKVParallelLinear(ColumnParallelLinear):
skip_bias_add: bool = False,
params_dtype: Optional[torch.dtype] = None,
quant_config: Optional[QuantizationConfig] = None,
prefix: str = ""):
prefix: str = "",
return_bias: bool = True):
self.hidden_size = hidden_size
self.head_size = head_size
self.total_num_heads = total_num_heads
@@ -686,7 +701,8 @@ class QKVParallelLinear(ColumnParallelLinear):
skip_bias_add=skip_bias_add,
params_dtype=params_dtype,
quant_config=quant_config,
prefix=prefix)
prefix=prefix,
return_bias=return_bias)
def _get_shard_offset_mapping(self, loaded_shard_id: str):
shard_offset_mapping = {
@@ -980,9 +996,10 @@ class RowParallelLinear(LinearBase):
params_dtype: Optional[torch.dtype] = None,
reduce_results: bool = True,
quant_config: Optional[QuantizationConfig] = None,
prefix: str = ""):
prefix: str = "",
return_bias: bool = True):
super().__init__(input_size, output_size, skip_bias_add, params_dtype,
quant_config, prefix)
quant_config, prefix, return_bias=return_bias)
self.input_is_parallel = input_is_parallel
self.reduce_results = reduce_results
@@ -1086,8 +1103,9 @@ class RowParallelLinear(LinearBase):
else:
output = output_parallel
if not self.return_bias:
return output
output_bias = self.bias if self.skip_bias_add else None
return output, output_bias
def extra_repr(self) -> str:

View File

@@ -38,6 +38,9 @@ class UnquantizedEmbeddingMethod(QuantizeMethodBase):
layer: torch.nn.Module,
x: torch.Tensor,
bias: Optional[torch.Tensor] = None) -> torch.Tensor:
# MLU F.linear requires matching dtypes
if x.dtype != layer.weight.dtype:
x = x.to(layer.weight.dtype)
return F.linear(x, layer.weight, bias)
def embedding(self, layer: torch.nn.Module,

View File

@@ -89,15 +89,63 @@ def device_loading_context(module: torch.nn.Module,
logger = init_logger(__name__)
def _get_device_memory_info_loader():
"""Get device memory info for debug logging. Returns dict or None."""
try:
import torch.mlu
allocated = torch.mlu.memory_allocated() / (1024 ** 3)
reserved = torch.mlu.memory_reserved() / (1024 ** 3)
free, total = torch.mlu.mem_get_info()
return {"allocated": allocated, "reserved": reserved,
"free": free / (1024 ** 3), "total": total / (1024 ** 3)}
except Exception:
pass
try:
if torch.cuda.is_available():
allocated = torch.cuda.memory_allocated() / (1024 ** 3)
reserved = torch.cuda.memory_reserved() / (1024 ** 3)
free, total = torch.cuda.mem_get_info()
return {"allocated": allocated, "reserved": reserved,
"free": free / (1024 ** 3), "total": total / (1024 ** 3)}
except Exception:
pass
return None
def _log_mem(tag: str):
info = _get_device_memory_info_loader()
if info:
logger.info(
"[DEBUG-MEM] %s: allocated=%.2f GiB, reserved=%.2f GiB, "
"free=%.2f GiB, total=%.2f GiB",
tag, info["allocated"], info["reserved"],
info["free"], info["total"])
def _initialize_model(vllm_config: VllmConfig, prefix: str = "") -> nn.Module:
"""Initialize a model with the given configurations."""
model_config = vllm_config.model_config
model_class, _ = get_model_architecture(model_config)
logger.info("[DEBUG-MEM] Model class: %s, dtype: %s",
model_class.__name__, model_config.dtype)
_log_mem("Before _initialize_model")
signatures = inspect.signature(model_class.__init__)
all_params = [param.name for param in signatures.parameters.values()]
if "vllm_config" in all_params and "prefix" in all_params:
# new-style model class
return model_class(vllm_config=vllm_config, prefix=prefix)
model = model_class(vllm_config=vllm_config, prefix=prefix)
_log_mem("After _initialize_model (empty weights created)")
# Print model parameter summary
total_params = 0
total_bytes = 0
for name, param in model.named_parameters():
total_params += param.numel()
total_bytes += param.numel() * param.element_size()
logger.info(
"[DEBUG-MEM] Model params: %d, "
"estimated size: %.2f GiB",
total_params, total_bytes / (1024 ** 3))
return model
msg = ("vLLM model class should accept `vllm_config` and `prefix` as "
"input arguments. Possibly you have an old-style model class"
" registered from out of tree and it is used for new vLLM version. "
@@ -327,11 +375,14 @@ class DefaultModelLoader(BaseModelLoader):
model_config = vllm_config.model_config
target_device = torch.device(device_config.device)
_log_mem("load_model start, target_device=%s" % target_device)
with set_default_torch_dtype(model_config.dtype):
with target_device:
model = _initialize_model(vllm_config=vllm_config)
_log_mem("Before load_weights")
model.load_weights(self._get_all_weights(model_config, model))
_log_mem("After load_weights")
for _, module in model.named_modules():
quant_method = getattr(module, "quant_method", None)

View File

@@ -20,7 +20,7 @@ def set_default_torch_dtype(dtype: torch.dtype):
def get_model_architecture(
model_config: ModelConfig) -> Tuple[Type[nn.Module], str]:
architectures = getattr(model_config.hf_config, "architectures", [])
architectures = getattr(model_config.hf_config, "architectures", None) or []
# Special handling for quantized Mixtral.
# FIXME(woosuk): This is a temporary hack.
mixtral_supported = [

View File

@@ -0,0 +1,291 @@
"""Inference-only DeepSeek V3 Multi-Token Prediction (MTP) model."""
from typing import Iterable, List, Optional, Tuple
import torch
import torch.nn as nn
from vllm.attention.backends.abstract import AttentionMetadata
from vllm.config import CacheConfig, VllmConfig
from vllm.model_executor.layers.layernorm import RMSNorm
from vllm.model_executor.layers.logits_processor import LogitsProcessor
from vllm.model_executor.layers.quantization import QuantizationConfig
from vllm.model_executor.layers.sampler import SamplerOutput, get_sampler
from vllm.model_executor.layers.vocab_parallel_embedding import (
ParallelLMHead, VocabParallelEmbedding)
from vllm.model_executor.model_loader.weight_utils import default_weight_loader
from vllm.model_executor.sampling_metadata import SamplingMetadata
from vllm.sequence import IntermediateTensors
from .deepseek_v2 import DeepseekV2DecoderLayer
class SharedHead(nn.Module):
"""Shared head for MTP: norm + lm_head."""
def __init__(self, config, prefix: str = ""):
super().__init__()
self.norm = RMSNorm(config.hidden_size,
eps=getattr(config, "rms_norm_eps", 1e-6))
self.head = ParallelLMHead(
config.vocab_size,
config.hidden_size,
)
def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
return self.norm(hidden_states)
class DeepSeekMultiTokenPredictorLayer(nn.Module):
"""Single MTP layer: enorm + hnorm + eh_proj + shared_head + mtp_block."""
def __init__(
self,
config,
layer_idx: int,
cache_config: Optional[CacheConfig] = None,
quant_config: Optional[QuantizationConfig] = None,
prefix: str = "",
):
super().__init__()
self.enorm = RMSNorm(config.hidden_size,
eps=getattr(config, "rms_norm_eps", 1e-6))
self.hnorm = RMSNorm(config.hidden_size,
eps=getattr(config, "rms_norm_eps", 1e-6))
self.eh_proj = nn.Linear(config.hidden_size * 2,
config.hidden_size,
bias=False)
self.shared_head = SharedHead(config,
prefix=f"{prefix}.shared_head")
# Reuse DeepseekV2DecoderLayer (MLU hijack auto-applies)
self.mtp_block = DeepseekV2DecoderLayer(
config,
prefix=f"model.layers.{layer_idx}",
cache_config=cache_config,
quant_config=quant_config,
)
def forward(
self,
input_ids: torch.Tensor,
positions: torch.Tensor,
previous_hidden_states: torch.Tensor,
kv_caches: List[torch.Tensor],
attn_metadata: AttentionMetadata,
inputs_embeds: Optional[torch.Tensor] = None,
) -> torch.Tensor:
assert inputs_embeds is not None
# Mask inputs at position 0
inputs_embeds = torch.where(
positions.unsqueeze(-1) == 0, 0, inputs_embeds)
inputs_embeds = self.enorm(inputs_embeds)
previous_hidden_states = self.hnorm(previous_hidden_states)
hidden_states = self.eh_proj(
torch.cat([inputs_embeds, previous_hidden_states], dim=-1))
hidden_states, residual = self.mtp_block(
positions, hidden_states, kv_caches[0], attn_metadata,
residual=None)
hidden_states = residual + hidden_states
return hidden_states
def get_spec_layer_idx_from_weight_name(config, weight_name: str):
"""Check if weight belongs to a speculative (MTP) layer.
Returns the layer index if so, None otherwise."""
num_nextn = getattr(config, "num_nextn_predict_layers", 0)
if num_nextn and num_nextn > 0:
layer_idx = config.num_hidden_layers
for i in range(num_nextn):
if weight_name.startswith(f"model.layers.{layer_idx + i}."):
return layer_idx + i
return None
def _rewrite_spec_layer_name(config, spec_layer: int, name: str) -> str:
"""Rewrite weight name for MTP layer.
Add .mtp_block for transformer block weights,
rename shared weights to top level."""
spec_layer_weight_names = [
"embed_tokens", "enorm", "hnorm", "eh_proj", "shared_head",
]
shared_weight_names = ["embed_tokens"]
spec_layer_weight = False
shared_weight = False
for weight_name in spec_layer_weight_names:
if weight_name in name:
spec_layer_weight = True
if weight_name in shared_weight_names:
shared_weight = True
break
if not spec_layer_weight:
# Transformer block weights -> add .mtp_block prefix
name = name.replace(
f"model.layers.{spec_layer}.",
f"model.layers.{spec_layer}.mtp_block.")
elif shared_weight:
# Shared weights -> top level
name = name.replace(f"model.layers.{spec_layer}.", "model.")
return name
class DeepSeekMTP(nn.Module):
"""DeepSeek V3 Multi-Token Prediction draft model.
Uses hidden states from the target model to predict the next token
via a single additional decoder layer."""
def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
super().__init__()
config = vllm_config.model_config.hf_config
cache_config = vllm_config.cache_config
quant_config = vllm_config.quant_config
self.config = config
self.embed_tokens = VocabParallelEmbedding(
config.vocab_size,
config.hidden_size,
)
self.mtp_start_layer_idx = config.num_hidden_layers
num_mtp = getattr(config, "num_nextn_predict_layers", 1)
self.layers = nn.ModuleDict()
for i in range(num_mtp):
layer_idx = self.mtp_start_layer_idx + i
self.layers[str(layer_idx)] = DeepSeekMultiTokenPredictorLayer(
config=config,
layer_idx=layer_idx,
cache_config=cache_config,
quant_config=quant_config,
prefix=f"model.layers.{layer_idx}",
)
self.logits_processor = LogitsProcessor(config.vocab_size)
self.sampler = get_sampler()
def forward(
self,
input_ids: torch.Tensor,
positions: torch.Tensor,
kv_caches: List[torch.Tensor],
attn_metadata: AttentionMetadata,
previous_hidden_states: torch.Tensor,
intermediate_tensors: Optional[IntermediateTensors] = None,
) -> torch.Tensor:
inputs_embeds = self.embed_tokens(input_ids)
# Use the first MTP layer (DeepSeek V3 only has 1)
layer = self.layers[str(self.mtp_start_layer_idx)]
hidden_states = layer(
input_ids, positions, previous_hidden_states,
kv_caches, attn_metadata, inputs_embeds)
return hidden_states
def compute_logits(
self,
hidden_states: torch.Tensor,
sampling_metadata: SamplingMetadata,
) -> Optional[torch.Tensor]:
layer = self.layers[str(self.mtp_start_layer_idx)]
normed = layer.shared_head(hidden_states)
logits = self.logits_processor(
layer.shared_head.head, normed, sampling_metadata)
return logits
def sample(
self,
logits: Optional[torch.Tensor],
sampling_metadata: SamplingMetadata,
) -> Optional[SamplerOutput]:
next_tokens = self.sampler(logits, sampling_metadata)
return next_tokens
def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
# MLU SparseMoeMlp needs pack_params() before loading
try:
from vllm_mlu.model_executor.layers.sparse_moe_mlp import (
SparseMoeMlp)
for name, m in self.named_modules():
if isinstance(m, SparseMoeMlp):
m.pack_params()
except ImportError:
pass
stacked_params_mapping = [
# (param_name, shard_name, shard_id)
("gate_up_proj", "gate_proj", 0),
("gate_up_proj", "up_proj", 1),
]
params_dict = dict(self.named_parameters())
for name, loaded_weight in weights:
if "rotary_emb.inv_freq" in name:
continue
# Only load MTP layer weights
spec_layer = get_spec_layer_idx_from_weight_name(
self.config, name)
if spec_layer is None:
continue
# Rewrite weight name for MTP structure
name = _rewrite_spec_layer_name(
self.config, spec_layer, name)
# Only load shared weights (embed_tokens) from first
# MTP layer, per DeepSeek V3 Technical Report
if (spec_layer != self.mtp_start_layer_idx
and ".layers" not in name):
continue
# Strip "model." prefix since DeepSeekMTP holds
# embed_tokens and layers directly (no .model wrapper)
if name.startswith("model."):
name = name[len("model."):]
self._load_single_weight(
name, loaded_weight, stacked_params_mapping,
params_dict)
def _load_single_weight(
self,
name: str,
loaded_weight: torch.Tensor,
stacked_params_mapping: List[Tuple[str, str, int]],
params_dict: dict,
):
for (param_name, weight_name, shard_id) in stacked_params_mapping:
if weight_name not in name:
continue
name = name.replace(weight_name, param_name)
# Skip expert weights not in params_dict
if (("mlp.experts." in name
or "mlp.shared_experts." in name
or "mlp.shared_expert_gate." in name
or "e_score_correction_bias" in name)
and name not in params_dict):
return
if name.endswith(".bias") and name not in params_dict:
return
if name not in params_dict:
return
param = params_dict[name]
weight_loader = param.weight_loader
weight_loader(param, loaded_weight, shard_id)
return
# Non-stacked weights
if name.endswith(".bias") and name not in params_dict:
return
if (("mlp.experts." in name
or "mlp.shared_experts." in name
or "mlp.shared_expert_gate." in name
or "e_score_correction_bias" in name)
and name not in params_dict):
return
if name not in params_dict:
return
param = params_dict[name]
weight_loader = getattr(param, "weight_loader",
default_weight_loader)
weight_loader(param, loaded_weight)

View File

@@ -611,3 +611,7 @@ class DeepseekV2ForCausalLM(nn.Module, SupportsPP):
weight_loader = getattr(param, "weight_loader",
default_weight_loader)
weight_loader(param, loaded_weight)
class DeepseekV3ForCausalLM(DeepseekV2ForCausalLM):
pass

View File

@@ -0,0 +1,507 @@
# Copyright 2024 The vLLM team.
# Copyright 2024 Google Inc. HuggingFace Inc. team. All rights reserved.
#
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Inference-only Gemma3 model compatible with HuggingFace weights."""
from typing import Iterable, List, Optional, Set, Tuple, Union
import torch
from torch import nn
from vllm.attention import Attention, AttentionMetadata
from vllm.config import CacheConfig, VllmConfig
from vllm.distributed import get_pp_group, get_tensor_model_parallel_world_size
from vllm.logger import init_logger
from vllm.model_executor.layers.activation import GeluAndMul
from vllm.model_executor.layers.layernorm import GemmaRMSNorm
from vllm.model_executor.layers.linear import (MergedColumnParallelLinear,
QKVParallelLinear,
RowParallelLinear)
from vllm.model_executor.layers.logits_processor import LogitsProcessor
from vllm.model_executor.layers.quantization import QuantizationConfig
from vllm.model_executor.layers.rotary_embedding import get_rope
from vllm.model_executor.layers.sampler import SamplerOutput, get_sampler
from vllm.model_executor.layers.vocab_parallel_embedding import (
VocabParallelEmbedding)
from vllm.model_executor.model_loader.weight_utils import default_weight_loader
from vllm.model_executor.sampling_metadata import SamplingMetadata
from vllm.sequence import IntermediateTensors
from .interfaces import SupportsLoRA, SupportsPP
from .utils import (AutoWeightsLoader, is_pp_missing_parameter,
make_empty_intermediate_tensors_factory, make_layers,
maybe_prefix)
logger = init_logger(__name__)
class Gemma3MLP(nn.Module):
def __init__(
self,
hidden_size: int,
intermediate_size: int,
hidden_activation: str,
quant_config: Optional[QuantizationConfig] = None,
) -> None:
super().__init__()
self.gate_up_proj = MergedColumnParallelLinear(
hidden_size, [intermediate_size] * 2,
bias=False,
quant_config=quant_config)
self.down_proj = RowParallelLinear(intermediate_size,
hidden_size,
bias=False,
quant_config=quant_config)
if hidden_activation != "gelu_pytorch_tanh":
raise ValueError(
"Gemma3 uses `gelu_pytorch_tanh` as the hidden activation "
"function. Please set `hidden_activation` to "
"`gelu_pytorch_tanh`.")
self.act_fn = GeluAndMul(approximate="tanh")
def forward(self, x: torch.Tensor) -> torch.Tensor:
gate_up, _ = self.gate_up_proj(x)
x = self.act_fn(gate_up)
x, _ = self.down_proj(x)
return x
class Gemma3Attention(nn.Module):
def __init__(self,
layer_idx: int,
config,
hidden_size: int,
num_heads: int,
num_kv_heads: int,
head_dim: int,
max_position_embeddings: int,
cache_config: Optional[CacheConfig] = None,
quant_config: Optional[QuantizationConfig] = None,
attn_logits_soft_cap: Optional[float] = None) -> None:
super().__init__()
self.layer_idx = layer_idx
self.config = config
self.hidden_size = hidden_size
tp_size = get_tensor_model_parallel_world_size()
self.total_num_heads = num_heads
assert self.total_num_heads % tp_size == 0
self.num_heads = self.total_num_heads // tp_size
self.total_num_kv_heads = num_kv_heads
if self.total_num_kv_heads >= tp_size:
assert self.total_num_kv_heads % tp_size == 0
else:
assert tp_size % self.total_num_kv_heads == 0
self.num_kv_heads = max(1, self.total_num_kv_heads // tp_size)
self.head_dim = head_dim
self.q_size = self.num_heads * self.head_dim
self.kv_size = self.num_kv_heads * self.head_dim
self.scaling = config.query_pre_attn_scalar**-0.5
# Extract rope_theta from config, compatible with both old-style
# (config.rope_theta) and new-style (config.rope_parameters dict).
rope_params = getattr(config, "rope_parameters", None)
if hasattr(config, "rope_theta"):
self.rope_theta = config.rope_theta
elif isinstance(rope_params, dict):
# Transformers v5: nested per layer_type
if "full_attention" in rope_params:
self.rope_theta = rope_params["full_attention"].get(
"rope_theta", 10000.0)
else:
# Transformers v4: flat dict
self.rope_theta = rope_params.get("rope_theta", 10000.0)
else:
self.rope_theta = 10000.0
self.qkv_proj = QKVParallelLinear(
hidden_size,
self.head_dim,
self.total_num_heads,
self.total_num_kv_heads,
bias=config.attention_bias,
quant_config=quant_config,
)
self.o_proj = RowParallelLinear(
self.total_num_heads * self.head_dim,
hidden_size,
bias=config.attention_bias,
quant_config=quant_config,
)
# Gemma3 specific: QK normalization
self.q_norm = GemmaRMSNorm(self.head_dim, eps=config.rms_norm_eps)
self.k_norm = GemmaRMSNorm(self.head_dim, eps=config.rms_norm_eps)
# Determine layer type and rope config
layer_types = getattr(config, "layer_types", None)
if layer_types is not None:
layer_type = layer_types[layer_idx]
self.is_sliding = (layer_type == "sliding_attention")
        else:
            # Fallback when the config has no `layer_types`: alternate
            # global/sliding layers as in Gemma2 (Gemma3 configs normally
            # provide `layer_types`, so this branch is rarely taken).
            self.is_sliding = (layer_idx % 2 == 1
                               and config.sliding_window is not None)
# Extract rope config, compatible with both old-style (rope_theta,
# rope_scaling) and new-style (rope_parameters dict) transformers.
rope_params = getattr(config, "rope_parameters", None)
# Set up rope based on layer type
if self.is_sliding:
# Local/sliding attention uses rope_local_base_freq
if hasattr(config, "rope_local_base_freq"):
local_base = config.rope_local_base_freq
elif (isinstance(rope_params, dict)
and "sliding_attention" in rope_params):
local_base = rope_params["sliding_attention"].get(
"rope_theta", self.rope_theta)
else:
local_base = self.rope_theta
self.rotary_emb = get_rope(
self.head_dim,
rotary_dim=self.head_dim,
max_position=max_position_embeddings,
base=local_base,
is_neox_style=True,
)
else:
# Global attention: extract rope_base and rope_scaling.
# Prioritize rope_parameters dict (newer transformers) to
# avoid passing nested dicts that are unhashable.
rope_scaling = None
rope_base = self.rope_theta
if isinstance(rope_params, dict):
# Transformers v5: per layer_type sub-dicts
if "full_attention" in rope_params:
rp = rope_params["full_attention"]
else:
# Transformers v4: flat dict
rp = rope_params
rope_base = rp.get("rope_theta", self.rope_theta)
rtype = rp.get("rope_type", None)
if rtype and rtype != "default":
rope_scaling = {
k: v for k, v in rp.items()
if k not in ("rope_theta",)
}
else:
# Fallback: old-style config.rope_scaling (flat dict)
rope_scaling = getattr(config, "rope_scaling", None)
self.rotary_emb = get_rope(
self.head_dim,
rotary_dim=self.head_dim,
max_position=max_position_embeddings,
base=rope_base,
is_neox_style=True,
rope_scaling=rope_scaling,
)
# NOTE: Like Gemma2, vLLM currently ignores sliding window
# and uses global attention for all layers.
self.attn = Attention(self.num_heads,
self.head_dim,
self.scaling,
num_kv_heads=self.num_kv_heads,
cache_config=cache_config,
quant_config=quant_config,
logits_soft_cap=attn_logits_soft_cap)
def forward(
self,
positions: torch.Tensor,
hidden_states: torch.Tensor,
kv_cache: torch.Tensor,
attn_metadata: AttentionMetadata,
) -> torch.Tensor:
qkv, _ = self.qkv_proj(hidden_states)
q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size],
dim=-1)
# Gemma3 specific: apply QK normalization
q = q.unflatten(-1, (self.num_heads, self.head_dim))
q = self.q_norm(q)
q = q.flatten(-2, -1)
k = k.unflatten(-1, (self.num_kv_heads, self.head_dim))
k = self.k_norm(k)
k = k.flatten(-2, -1)
# MLU rotary_emb expects a single concatenated tensor, not
# separate q and k (forward_mlu signature differs from forward_native).
qk = torch.cat([q, k], dim=-1)
self.rotary_emb(positions,
qk.view(-1, self.num_heads + self.num_kv_heads,
self.head_dim))
q, k = qk.split([self.q_size, self.kv_size], dim=-1)
attn_output = self.attn(q, k, v, kv_cache, attn_metadata)
output, _ = self.o_proj(attn_output)
return output
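The dual-format `rope_theta` lookup in `Gemma3Attention.__init__` above can be exercised in isolation. A minimal sketch operating on a plain dict; the helper name `resolve_rope_theta` is illustrative, not part of vLLM:

```python
def resolve_rope_theta(config: dict, default: float = 10000.0) -> float:
    # Old-style configs expose a flat `rope_theta` key.
    if "rope_theta" in config:
        return config["rope_theta"]
    rope_params = config.get("rope_parameters")
    if isinstance(rope_params, dict):
        # Transformers v5: parameters nested per layer type.
        if "full_attention" in rope_params:
            return rope_params["full_attention"].get("rope_theta", default)
        # Transformers v4: flat dict.
        return rope_params.get("rope_theta", default)
    return default
```

The lookup order mirrors the attribute checks in the class: a flat `rope_theta` wins, then the nested v5 layout, then the flat v4 dict, then the 10000.0 default.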
class Gemma3DecoderLayer(nn.Module):
def __init__(
self,
layer_idx: int,
config,
cache_config: Optional[CacheConfig] = None,
quant_config: Optional[QuantizationConfig] = None,
) -> None:
super().__init__()
self.hidden_size = config.hidden_size
self.self_attn = Gemma3Attention(
layer_idx=layer_idx,
config=config,
hidden_size=self.hidden_size,
num_heads=config.num_attention_heads,
num_kv_heads=config.num_key_value_heads,
head_dim=config.head_dim,
max_position_embeddings=config.max_position_embeddings,
cache_config=cache_config,
quant_config=quant_config,
            # Gemma3 configs typically set attn_logit_softcapping to None;
            # pass through whatever the config provides.
            attn_logits_soft_cap=getattr(config,
                                         "attn_logit_softcapping", None),
)
self.mlp = Gemma3MLP(
hidden_size=self.hidden_size,
intermediate_size=config.intermediate_size,
hidden_activation=config.hidden_activation,
quant_config=quant_config,
)
self.input_layernorm = GemmaRMSNorm(config.hidden_size,
eps=config.rms_norm_eps)
self.post_attention_layernorm = GemmaRMSNorm(config.hidden_size,
eps=config.rms_norm_eps)
self.pre_feedforward_layernorm = GemmaRMSNorm(config.hidden_size,
eps=config.rms_norm_eps)
self.post_feedforward_layernorm = GemmaRMSNorm(config.hidden_size,
eps=config.rms_norm_eps)
def forward(
self,
positions: torch.Tensor,
hidden_states: torch.Tensor,
kv_cache: torch.Tensor,
attn_metadata: AttentionMetadata,
residual: Optional[torch.Tensor],
) -> Tuple[torch.Tensor, torch.Tensor]:
if residual is None:
residual = hidden_states
hidden_states = self.input_layernorm(hidden_states)
else:
hidden_states, residual = self.input_layernorm(
hidden_states, residual)
hidden_states = self.self_attn(
positions=positions,
hidden_states=hidden_states,
kv_cache=kv_cache,
attn_metadata=attn_metadata,
)
hidden_states = self.post_attention_layernorm(hidden_states)
hidden_states, residual = self.pre_feedforward_layernorm(
hidden_states, residual)
hidden_states = self.mlp(hidden_states)
hidden_states = self.post_feedforward_layernorm(hidden_states)
return hidden_states, residual
class Gemma3Model(nn.Module):
def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
super().__init__()
config = vllm_config.model_config.hf_config
cache_config = vllm_config.cache_config
quant_config = vllm_config.quant_config
self.config = config
self.embed_tokens = VocabParallelEmbedding(
config.vocab_size,
config.hidden_size,
)
self.start_layer, self.end_layer, self.layers = make_layers(
config.num_hidden_layers,
lambda prefix: Gemma3DecoderLayer(
int(prefix.split(".")[-1]),
config, cache_config, quant_config),
prefix=f"{prefix}.layers")
self.norm = GemmaRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
normalizer = self.config.hidden_size**0.5
self.register_buffer("normalizer", torch.tensor(normalizer))
self.make_empty_intermediate_tensors = (
make_empty_intermediate_tensors_factory(
["hidden_states", "residual"], config.hidden_size))
def forward(
self,
input_ids: Optional[torch.Tensor],
positions: torch.Tensor,
kv_caches: List[torch.Tensor],
attn_metadata: AttentionMetadata,
intermediate_tensors: Optional[IntermediateTensors],
inputs_embeds: Optional[torch.Tensor] = None,
) -> Union[torch.Tensor, IntermediateTensors]:
if get_pp_group().is_first_rank:
if inputs_embeds is not None:
hidden_states = inputs_embeds
else:
                hidden_states = self.embed_tokens(input_ids)
            # Gemma scales input embeddings by sqrt(hidden_size).
            hidden_states *= self.normalizer
residual = None
else:
assert intermediate_tensors is not None
hidden_states = intermediate_tensors["hidden_states"]
residual = intermediate_tensors["residual"]
for i in range(self.start_layer, self.end_layer):
layer = self.layers[i]
hidden_states, residual = layer(
positions,
hidden_states,
kv_caches[i - self.start_layer],
attn_metadata,
residual,
)
if not get_pp_group().is_last_rank:
return IntermediateTensors({
"hidden_states": hidden_states,
"residual": residual
})
hidden_states, _ = self.norm(hidden_states, residual)
return hidden_states
def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
stacked_params_mapping = [
# (param_name, shard_name, shard_id)
("qkv_proj", "q_proj", "q"),
("qkv_proj", "k_proj", "k"),
("qkv_proj", "v_proj", "v"),
("gate_up_proj", "gate_proj", 0),
("gate_up_proj", "up_proj", 1),
]
params_dict = dict(self.named_parameters())
loaded_params: Set[str] = set()
for name, loaded_weight in weights:
for (param_name, shard_name, shard_id) in stacked_params_mapping:
if shard_name not in name:
continue
name = name.replace(shard_name, param_name)
if name.endswith(".bias") and name not in params_dict:
continue
if is_pp_missing_parameter(name, self):
continue
param = params_dict[name]
weight_loader = param.weight_loader
weight_loader(param, loaded_weight, shard_id)
break
else:
if name.endswith(".bias") and name not in params_dict:
continue
if is_pp_missing_parameter(name, self):
continue
if name not in params_dict:
continue
param = params_dict[name]
weight_loader = getattr(param, "weight_loader",
default_weight_loader)
weight_loader(param, loaded_weight)
loaded_params.add(name)
unloaded_params = params_dict.keys() - loaded_params
if unloaded_params:
logger.warning(
"Some weights are not initialized from checkpoints: %s",
unloaded_params)
class Gemma3ForCausalLM(nn.Module, SupportsLoRA, SupportsPP):
packed_modules_mapping = {
"qkv_proj": [
"q_proj",
"k_proj",
"v_proj",
],
"gate_up_proj": [
"gate_proj",
"up_proj",
],
}
# LoRA specific attributes
supported_lora_modules = [
"qkv_proj",
"o_proj",
"gate_up_proj",
"down_proj",
]
embedding_modules = {}
embedding_padding_modules = []
def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
config = vllm_config.model_config.hf_config
quant_config = vllm_config.quant_config
lora_config = vllm_config.lora_config
del lora_config # Unused.
super().__init__()
self.config = config
self.quant_config = quant_config
self.model = Gemma3Model(vllm_config=vllm_config,
prefix=maybe_prefix(prefix, "model"))
        # final_logit_softcapping is optional in Gemma3 configs.
soft_cap = getattr(config, "final_logit_softcapping", None)
self.logits_processor = LogitsProcessor(
config.vocab_size, soft_cap=soft_cap)
self.sampler = get_sampler()
self.make_empty_intermediate_tensors = (
self.model.make_empty_intermediate_tensors)
def forward(
self,
input_ids: torch.Tensor,
positions: torch.Tensor,
kv_caches: List[torch.Tensor],
attn_metadata: AttentionMetadata,
intermediate_tensors: Optional[IntermediateTensors] = None,
) -> Union[torch.Tensor, IntermediateTensors]:
hidden_states = self.model(input_ids, positions, kv_caches,
attn_metadata, intermediate_tensors)
return hidden_states
def compute_logits(
self,
hidden_states: torch.Tensor,
sampling_metadata: SamplingMetadata,
) -> Optional[torch.Tensor]:
logits = self.logits_processor(self.model.embed_tokens,
hidden_states, sampling_metadata)
return logits
def sample(
self,
logits: torch.Tensor,
sampling_metadata: SamplingMetadata,
) -> Optional[SamplerOutput]:
next_tokens = self.sampler(logits, sampling_metadata)
return next_tokens
def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
loader = AutoWeightsLoader(
self,
skip_prefixes=(["lm_head."]
if self.config.tie_word_embeddings else None),
)
loader.load_weights(weights)
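`Gemma3Attention` splits KV heads across tensor-parallel ranks with the standard rule: ranks either divide the heads evenly, or each head is replicated when `tp_size` exceeds the head count. A standalone sketch of that rule (the function name is illustrative):

```python
def kv_heads_per_rank(total_num_kv_heads: int, tp_size: int) -> int:
    if total_num_kv_heads >= tp_size:
        # Each rank owns an equal share of the KV heads.
        assert total_num_kv_heads % tp_size == 0
    else:
        # Fewer KV heads than ranks: each head is replicated
        # across tp_size // total_num_kv_heads ranks.
        assert tp_size % total_num_kv_heads == 0
    return max(1, total_num_kv_heads // tp_size)
```

The `max(1, ...)` is what keeps at least one (replicated) KV head per rank in the GQA/MQA case.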

View File

@@ -26,6 +26,10 @@ import torch
from torch import nn
from transformers import LlamaConfig
from vllm.logger import init_logger
logger = init_logger(__name__)
from vllm.attention import Attention, AttentionMetadata
from vllm.compilation.decorators import support_torch_compile
from vllm.config import CacheConfig, VllmConfig
@@ -404,6 +408,12 @@ class LlamaModel(nn.Module):
if is_pp_missing_parameter(name, self):
continue
if name not in params_dict:
logger.warning(
"Skipping weight %s not present in the model",
name)
continue
param = params_dict[name]
weight_loader = getattr(param, "weight_loader",
default_weight_loader)

View File

@@ -0,0 +1,580 @@
# Copyright 2025 the LLAMA4, Meta Inc., vLLM, and HuggingFace Inc. team.
# All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Inference-only Llama4 model compatible with HuggingFace weights."""
import re
from typing import Any, Dict, Iterable, List, Optional, Tuple, Union
import torch
from torch import nn
from vllm.attention import Attention, AttentionMetadata
from vllm.config import CacheConfig, VllmConfig
from vllm.distributed import (get_pp_group,
get_tensor_model_parallel_world_size,
tensor_model_parallel_all_reduce)
from vllm.logger import init_logger
from vllm.model_executor.layers.fused_moe import FusedMoE
from vllm.model_executor.layers.layernorm import RMSNorm
from vllm.model_executor.layers.linear import (QKVParallelLinear,
ReplicatedLinear,
RowParallelLinear)
from vllm.model_executor.layers.logits_processor import LogitsProcessor
from vllm.model_executor.layers.quantization import QuantizationConfig
from vllm.model_executor.layers.rotary_embedding import get_rope
from vllm.model_executor.layers.sampler import SamplerOutput, get_sampler
from vllm.model_executor.layers.vocab_parallel_embedding import (
DEFAULT_VOCAB_PADDING_SIZE, ParallelLMHead, VocabParallelEmbedding)
from vllm.model_executor.model_loader.weight_utils import (
default_weight_loader, maybe_remap_kv_scale_name)
from vllm.model_executor.sampling_metadata import SamplingMetadata
from vllm.sequence import IntermediateTensors
from .interfaces import SupportsPP
from .llama import LlamaMLP
from .utils import (AutoWeightsLoader, PPMissingLayer, is_pp_missing_parameter,
make_empty_intermediate_tensors_factory, make_layers,
maybe_prefix)
logger = init_logger(__name__)
def _extract_layer_index(prefix: str) -> int:
"""Extract layer index from prefix string like 'model.layers.0.self_attn'."""
match = re.search(r'layers\.(\d+)', prefix)
if match is None:
raise ValueError(f"Cannot extract layer index from prefix: {prefix}")
return int(match.group(1))
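The regex above can be checked in isolation; a standalone copy (renamed `extract_layer_index` here to avoid shadowing the module-private helper):

```python
import re


def extract_layer_index(prefix: str) -> int:
    # Find the numeric segment after "layers." anywhere in the prefix.
    match = re.search(r"layers\.(\d+)", prefix)
    if match is None:
        raise ValueError(f"Cannot extract layer index from prefix: {prefix}")
    return int(match.group(1))
```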
class Llama4MoE(nn.Module):
"""Llama4 Mixture of Experts with shared expert."""
@staticmethod
def custom_routing_function(
hidden_states: torch.Tensor,
gating_output: torch.Tensor,
topk: int,
renormalize: bool,
) -> Tuple[torch.Tensor, torch.Tensor]:
router_scores, router_indices = torch.topk(
gating_output, topk, dim=-1)
router_scores = torch.sigmoid(router_scores.float())
return (router_scores, router_indices.to(torch.int32))
def __init__(
self,
config,
quant_config: Optional[QuantizationConfig] = None,
prefix: str = "",
) -> None:
super().__init__()
self.tp_size = get_tensor_model_parallel_world_size()
self.top_k = getattr(config, "num_experts_per_tok", 1)
self.num_local_experts = getattr(config, "num_local_experts", 8)
self.hidden_size = getattr(config, "hidden_size", 4096)
intermediate_size_moe = getattr(config, "intermediate_size", 8192)
self.router = ReplicatedLinear(
self.hidden_size,
self.num_local_experts,
bias=False,
quant_config=None,
prefix=f"{prefix}.router",
)
self.experts = FusedMoE(
num_experts=self.num_local_experts,
top_k=self.top_k,
hidden_size=self.hidden_size,
intermediate_size=intermediate_size_moe,
reduce_results=False,
renormalize=False,
quant_config=quant_config,
custom_routing_function=Llama4MoE.custom_routing_function,
prefix=f"{prefix}.experts",
)
self.shared_expert = LlamaMLP(
hidden_size=self.hidden_size,
intermediate_size=intermediate_size_moe,
hidden_act="silu",
quant_config=quant_config,
bias=False,
prefix=f"{prefix}.shared_expert",
)
def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
orig_shape = hidden_states.shape
hidden_states = hidden_states.view(-1, self.hidden_size)
# router_logits: (num_tokens, n_experts)
router_logits, _ = self.router(hidden_states)
# routed experts
routed_out = self.experts(hidden_states, router_logits)
# shared expert
shared_out = self.shared_expert(hidden_states)
# combine and all-reduce
experts_out = routed_out + shared_out
if self.tp_size > 1:
experts_out = tensor_model_parallel_all_reduce(experts_out)
return experts_out.view(orig_shape)
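Note the ordering in `custom_routing_function`: the sigmoid is applied *after* top-k selection, so expert scores are independent rather than normalized over all experts (unlike a softmax router). A torch-free sketch of that ordering for a single token (function name is illustrative):

```python
import math


def llama4_route(gating_logits, topk):
    # Pick the top-k logits first, then squash only those with a
    # sigmoid -- scores are independent, not normalized across experts.
    indexed = sorted(enumerate(gating_logits),
                     key=lambda p: p[1], reverse=True)[:topk]
    indices = [i for i, _ in indexed]
    scores = [1.0 / (1.0 + math.exp(-v)) for _, v in indexed]
    return scores, indices
```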
class Llama4Attention(nn.Module):
def __init__(
self,
config,
hidden_size: int,
num_heads: int,
num_kv_heads: int,
max_position_embeddings: int = 8192,
quant_config: Optional[QuantizationConfig] = None,
bias: bool = False,
cache_config: Optional[CacheConfig] = None,
prefix: str = "",
) -> None:
super().__init__()
self.layer_idx = _extract_layer_index(prefix)
self.hidden_size = hidden_size
        # `no_rope_layers` is a per-layer mask; an entry of 0 marks a NoPE
        # layer (no rotary position embedding applied).
        self.no_rope_layers = getattr(config, "no_rope_layers", None)
        self.nope = (self.no_rope_layers is not None
                     and self.no_rope_layers[self.layer_idx] == 0)
        self.use_qk_norm = getattr(config, "use_qk_norm", False) and not self.nope
tp_size = get_tensor_model_parallel_world_size()
self.total_num_heads = num_heads
assert self.total_num_heads % tp_size == 0
self.num_heads = self.total_num_heads // tp_size
self.total_num_kv_heads = num_kv_heads
if self.total_num_kv_heads >= tp_size:
assert self.total_num_kv_heads % tp_size == 0
else:
assert tp_size % self.total_num_kv_heads == 0
self.num_kv_heads = max(1, self.total_num_kv_heads // tp_size)
self.head_dim = getattr(config, "head_dim",
self.hidden_size // self.total_num_heads)
self.q_size = self.num_heads * self.head_dim
self.kv_size = self.num_kv_heads * self.head_dim
self.scaling = self.head_dim**-0.5
self.max_position_embeddings = max_position_embeddings
# Temperature tuning for NoPE layers
self.attn_temperature_tuning = (
self.nope and getattr(config, "attn_temperature_tuning", False))
self.floor_scale = getattr(config, "floor_scale", 8192.0)
self.attn_scale = getattr(config, "attn_scale", 0.1)
# QK norm
rms_norm_eps = getattr(config, "rms_norm_eps", 1e-5)
if self.use_qk_norm:
self.qk_norm = RMSNorm(self.head_dim, eps=rms_norm_eps)
# v0.6.2 RMSNorm doesn't support has_weight=False,
# so we set weight to ones and make it non-trainable
self.qk_norm.weight.data.fill_(1.0)
self.qk_norm.weight.requires_grad = False
else:
self.qk_norm = None
self.qkv_proj = QKVParallelLinear(
hidden_size=hidden_size,
head_size=self.head_dim,
total_num_heads=self.total_num_heads,
total_num_kv_heads=self.total_num_kv_heads,
bias=bias,
quant_config=quant_config,
prefix=f"{prefix}.qkv_proj",
)
self.o_proj = RowParallelLinear(
input_size=self.total_num_heads * self.head_dim,
output_size=hidden_size,
bias=bias,
quant_config=quant_config,
prefix=f"{prefix}.o_proj",
)
# RoPE (None for NoPE layers)
rope_theta = getattr(config, "rope_theta", 10000)
rope_scaling = getattr(config, "rope_scaling", None)
if not self.nope:
self.rotary_emb = get_rope(
self.head_dim,
rotary_dim=self.head_dim,
max_position=max_position_embeddings,
base=rope_theta,
rope_scaling=rope_scaling,
is_neox_style=True,
)
else:
self.rotary_emb = None
self.attn = Attention(
self.num_heads,
self.head_dim,
self.scaling,
num_kv_heads=self.num_kv_heads,
cache_config=cache_config,
quant_config=quant_config,
)
def _get_attn_scale(self, positions: torch.Tensor) -> torch.Tensor:
floor = torch.floor((positions + 1.0) / self.floor_scale)
attn_scale = torch.log(floor + 1.0) * self.attn_scale + 1.0
return attn_scale.unsqueeze(-1)
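In scalar form, `_get_attn_scale` computes `log(floor((pos + 1) / floor_scale) + 1) * attn_scale + 1`, so the query scale is 1.0 near the start of the context and grows logarithmically with position. A torch-free sketch using the same defaults as the `getattr` fallbacks above:

```python
import math


def attn_temperature_scale(position: int,
                           floor_scale: float = 8192.0,
                           attn_scale: float = 0.1) -> float:
    # Scale grows logarithmically with position: exactly 1.0 for the
    # first floor_scale - 1 positions, then increasing slowly.
    floor = math.floor((position + 1.0) / floor_scale)
    return math.log(floor + 1.0) * attn_scale + 1.0
```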
def forward(
self,
positions: torch.Tensor,
hidden_states: torch.Tensor,
kv_cache: torch.Tensor,
attn_metadata: AttentionMetadata,
) -> torch.Tensor:
qkv, _ = self.qkv_proj(hidden_states)
q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size],
dim=-1)
if self.rotary_emb is not None:
q, k = self.rotary_emb(positions, q, k)
if self.qk_norm is not None:
q = q.reshape(-1, self.head_dim)
q = self.qk_norm(q.float()).reshape(-1, self.q_size).to(q.dtype)
k = k.reshape(-1, self.head_dim)
k = self.qk_norm(k.float()).reshape(-1, self.kv_size).to(k.dtype)
if self.attn_temperature_tuning and self.nope:
attn_scale = self._get_attn_scale(positions)
q = (q * attn_scale).to(q.dtype)
attn_output = self.attn(q, k, v, kv_cache, attn_metadata)
output, _ = self.o_proj(attn_output)
return output
class Llama4DecoderLayer(nn.Module):
def __init__(
self,
config,
cache_config: Optional[CacheConfig] = None,
quant_config: Optional[QuantizationConfig] = None,
prefix: str = "",
) -> None:
super().__init__()
self.layer_idx = _extract_layer_index(prefix)
self.hidden_size = getattr(config, "hidden_size", 4096)
max_position_embeddings = getattr(config, "max_position_embeddings",
8192)
self.self_attn = Llama4Attention(
config=config,
hidden_size=self.hidden_size,
num_heads=getattr(config, "num_attention_heads", 32),
num_kv_heads=getattr(config, "num_key_value_heads",
getattr(config, "num_attention_heads", 32)),
max_position_embeddings=max_position_embeddings,
quant_config=quant_config,
bias=False,
cache_config=cache_config,
prefix=f"{prefix}.self_attn",
)
# Interleaved MoE/dense layers
interleave_moe_layer_step = getattr(config,
"interleave_moe_layer_step", 0)
is_moe_layer = (interleave_moe_layer_step > 0
and (self.layer_idx + 1)
% interleave_moe_layer_step == 0)
if is_moe_layer:
self.feed_forward = Llama4MoE(
config=config,
quant_config=quant_config,
prefix=f"{prefix}.feed_forward",
)
else:
intermediate_size_mlp = getattr(config, "intermediate_size_mlp",
getattr(config,
"intermediate_size", 8192))
self.feed_forward = LlamaMLP(
hidden_size=self.hidden_size,
intermediate_size=intermediate_size_mlp,
hidden_act="silu",
quant_config=quant_config,
bias=False,
prefix=f"{prefix}.feed_forward",
)
rms_norm_eps = getattr(config, "rms_norm_eps", 1e-5)
self.input_layernorm = RMSNorm(self.hidden_size, eps=rms_norm_eps)
self.post_attention_layernorm = RMSNorm(self.hidden_size,
eps=rms_norm_eps)
def forward(
self,
positions: torch.Tensor,
hidden_states: torch.Tensor,
kv_cache: torch.Tensor,
attn_metadata: AttentionMetadata,
residual: Optional[torch.Tensor],
) -> Tuple[torch.Tensor, torch.Tensor]:
# Self Attention
if residual is None:
residual = hidden_states
hidden_states = self.input_layernorm(hidden_states)
else:
hidden_states, residual = self.input_layernorm(
hidden_states, residual)
hidden_states = self.self_attn(positions=positions,
hidden_states=hidden_states,
kv_cache=kv_cache,
attn_metadata=attn_metadata)
# Fully Connected
hidden_states, residual = self.post_attention_layernorm(
hidden_states, residual)
hidden_states = self.feed_forward(hidden_states)
return hidden_states, residual
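Which layers become MoE follows the `interleave_moe_layer_step` rule in `Llama4DecoderLayer.__init__`: with a step of 2, the 0-indexed layers 1, 3, 5, ... use experts and the rest stay dense. A sketch of that predicate:

```python
def is_moe_layer(layer_idx: int, interleave_moe_layer_step: int) -> bool:
    # A step of 0 disables MoE entirely; otherwise every step-th
    # layer (counting from 1) uses the expert mixture.
    return (interleave_moe_layer_step > 0
            and (layer_idx + 1) % interleave_moe_layer_step == 0)
```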
class Llama4Model(nn.Module):
"""Llama4 model - independent implementation to avoid pad_token_id issue."""
def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
super().__init__()
config = vllm_config.model_config.hf_config
cache_config = vllm_config.cache_config
quant_config = vllm_config.quant_config
lora_config = vllm_config.lora_config
self.config = config
# Defensive access - Llama4Config may not have pad_token_id
self.padding_idx = getattr(config, "pad_token_id", None)
lora_vocab = (lora_config.lora_extra_vocab_size *
(lora_config.max_loras or 1)) if lora_config else 0
self.vocab_size = config.vocab_size + lora_vocab
self.org_vocab_size = config.vocab_size
if get_pp_group().is_first_rank or (
getattr(config, "tie_word_embeddings", False)
and get_pp_group().is_last_rank):
self.embed_tokens = VocabParallelEmbedding(
self.vocab_size,
config.hidden_size,
org_num_embeddings=config.vocab_size,
quant_config=quant_config,
)
else:
self.embed_tokens = PPMissingLayer()
self.start_layer, self.end_layer, self.layers = make_layers(
config.num_hidden_layers,
lambda prefix: Llama4DecoderLayer(
config=config,
cache_config=cache_config,
quant_config=quant_config,
prefix=prefix),
prefix=f"{prefix}.layers",
)
rms_norm_eps = getattr(config, "rms_norm_eps", 1e-5)
if get_pp_group().is_last_rank:
self.norm = RMSNorm(config.hidden_size, eps=rms_norm_eps)
else:
self.norm = PPMissingLayer()
self.make_empty_intermediate_tensors = (
make_empty_intermediate_tensors_factory(
["hidden_states", "residual"], config.hidden_size))
def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor:
return self.embed_tokens(input_ids)
def forward(
self,
input_ids: Optional[torch.Tensor],
positions: torch.Tensor,
kv_caches: List[torch.Tensor],
attn_metadata: AttentionMetadata,
intermediate_tensors: Optional[IntermediateTensors],
inputs_embeds: Optional[torch.Tensor] = None,
) -> Union[torch.Tensor, IntermediateTensors]:
if get_pp_group().is_first_rank:
if inputs_embeds is not None:
hidden_states = inputs_embeds
else:
hidden_states = self.get_input_embeddings(input_ids)
residual = None
else:
assert intermediate_tensors is not None
hidden_states = intermediate_tensors["hidden_states"]
residual = intermediate_tensors["residual"]
for i in range(self.start_layer, self.end_layer):
layer = self.layers[i]
hidden_states, residual = layer(positions, hidden_states,
kv_caches[i - self.start_layer],
attn_metadata, residual)
if not get_pp_group().is_last_rank:
return IntermediateTensors({
"hidden_states": hidden_states,
"residual": residual
})
hidden_states, _ = self.norm(hidden_states, residual)
return hidden_states
class Llama4ForCausalLM(nn.Module, SupportsPP):
packed_modules_mapping = {
"qkv_proj": ["q_proj", "k_proj", "v_proj"],
"gate_up_proj": ["gate_proj", "up_proj"],
}
def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
super().__init__()
config = vllm_config.model_config.hf_config
# Llama4ForConditionalGeneration uses top-level Llama4Config
# which has text_config sub-config. Extract it for text model.
text_config = getattr(config, "text_config", None)
if text_config is not None:
orig_archs = getattr(config, "architectures", None)
vllm_config.model_config.hf_config = text_config
if orig_archs and not getattr(text_config, "architectures", None):
text_config.architectures = orig_archs
config = text_config
quant_config = vllm_config.quant_config
lora_config = vllm_config.lora_config
self.config = config
self.lora_config = lora_config
self.model = Llama4Model(vllm_config=vllm_config,
prefix=maybe_prefix(prefix, "model"))
if get_pp_group().is_last_rank:
self.unpadded_vocab_size = config.vocab_size
if lora_config:
self.unpadded_vocab_size += lora_config.lora_extra_vocab_size
self.lm_head = ParallelLMHead(
self.unpadded_vocab_size,
config.hidden_size,
org_num_embeddings=config.vocab_size,
padding_size=(
DEFAULT_VOCAB_PADDING_SIZE if not lora_config
else lora_config.lora_vocab_padding_size),
quant_config=quant_config,
prefix=maybe_prefix(prefix, "lm_head"),
)
if getattr(config, "tie_word_embeddings", False):
self.lm_head = self.lm_head.tie_weights(
self.model.embed_tokens)
logit_scale = getattr(config, "logit_scale", 1.0)
self.logits_processor = LogitsProcessor(
self.unpadded_vocab_size,
config.vocab_size,
logit_scale)
self.sampler = get_sampler()
else:
self.lm_head = PPMissingLayer()
self.make_empty_intermediate_tensors = (
self.model.make_empty_intermediate_tensors)
def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor:
return self.model.get_input_embeddings(input_ids)
def forward(
self,
input_ids: torch.Tensor,
positions: torch.Tensor,
kv_caches: List[torch.Tensor],
attn_metadata: AttentionMetadata,
intermediate_tensors: Optional[IntermediateTensors] = None,
inputs_embeds: Optional[torch.Tensor] = None,
) -> Union[torch.Tensor, IntermediateTensors]:
model_output = self.model(input_ids, positions, kv_caches,
attn_metadata, intermediate_tensors,
inputs_embeds)
return model_output
def compute_logits(
self,
hidden_states: torch.Tensor,
sampling_metadata: SamplingMetadata,
) -> Optional[torch.Tensor]:
logits = self.logits_processor(self.lm_head, hidden_states,
sampling_metadata)
return logits
def sample(self, logits: torch.Tensor,
sampling_metadata: SamplingMetadata) -> Optional[SamplerOutput]:
next_tokens = self.sampler(logits, sampling_metadata)
return next_tokens
def permute_qk_weight_for_rotary(
self,
name: str,
loaded_weight: torch.Tensor,
) -> Tuple[str, torch.Tensor]:
"""Permute Q/K weights for rotary embedding compatibility."""
def permute(w: torch.Tensor, n_heads: int):
attn_in = getattr(self.config, "head_dim", 128) * n_heads
attn_out = getattr(self.config, "hidden_size", 4096)
return (w.contiguous()
.view(n_heads, attn_in // n_heads // 2, 2, attn_out)
.transpose(1, 2).reshape(attn_in, attn_out))
modules = name.split(".")
is_weight = modules[-1] == "weight"
if is_weight:
if "k_proj" in modules:
loaded_weight = permute(
loaded_weight,
getattr(self.config, "num_key_value_heads", 8))
elif "q_proj" in modules:
loaded_weight = permute(
loaded_weight,
getattr(self.config, "num_attention_heads", 32))
return name, loaded_weight
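The `permute` closure above appears to de-interleave the rotary rows within each head: the view/transpose/reshape moves even-indexed rows first and odd-indexed rows second, converting the HF interleaved layout to the half-split layout that neox-style rotary expects. A row-index sketch without torch (`permute_rows` is an illustrative name; rows stand in for weight-matrix rows):

```python
def permute_rows(rows, n_heads):
    # Within each head's block of rows, put even-indexed rows first
    # and odd-indexed rows second -- the same reordering produced by
    # view(n_heads, d//2, 2, out).transpose(1, 2).reshape(...).
    head_dim = len(rows) // n_heads
    out = []
    for h in range(n_heads):
        head = rows[h * head_dim:(h + 1) * head_dim]
        out.extend(head[0::2] + head[1::2])
    return out
```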
def load_weights(
self, weights: Iterable[Tuple[str, torch.Tensor]],
):
loader = AutoWeightsLoader(
self,
skip_prefixes=(
["lm_head."]
if getattr(self.config, "tie_word_embeddings", False)
else None),
)
def _process_weights(weights):
for name, loaded_weight in weights:
# Strip language_model. prefix for Llama4ForConditionalGeneration
if name.startswith("language_model."):
name = name[len("language_model."):]
# Skip vision encoder weights
elif name.startswith("multi_modal_projector.") or \
name.startswith("vision_encoder.") or \
name.startswith("vision_model."):
continue
name, loaded_weight = self.permute_qk_weight_for_rotary(
name, loaded_weight)
yield name, loaded_weight
loader.load_weights(_process_weights(weights))

View File

@@ -272,7 +272,7 @@ class MPTForCausalLM(nn.Module, SupportsPP):
config = vllm_config.model_config.hf_config
quant_config = vllm_config.quant_config
self.config = config
assert config.tie_word_embeddings
assert getattr(config, "tie_word_embeddings", True)
self.quant_config = quant_config
self.transformer = MPTModel(vllm_config=vllm_config,

View File

@@ -0,0 +1,556 @@
# Adapted from
# https://github.com/huggingface/transformers/blob/v4.28.0/src/transformers/models/qwen2_moe/modeling_qwen2_moe.py
# Copyright 2024 The Qwen team.
# Copyright 2023 The vLLM team.
# Copyright 2022 EleutherAI and the HuggingFace Inc. team. All rights reserved.
#
# This code is based on EleutherAI's GPT-NeoX library and the GPT-NeoX
# and OPT implementations in this library. It has been modified from its
# original forms to accommodate minor architectural differences compared
# to GPT-NeoX and OPT used by the Meta AI team that trained the model.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Inference-only Qwen3MoE model compatible with HuggingFace weights."""
from typing import Any, Dict, Iterable, List, Optional, Tuple, Union
import torch
import torch.nn.functional as F
from torch import nn
from transformers import PretrainedConfig
from vllm.attention import Attention, AttentionMetadata
from vllm.compilation.decorators import support_torch_compile
from vllm.config import CacheConfig, VllmConfig
from vllm.distributed import (get_pp_group,
get_tensor_model_parallel_world_size,
tensor_model_parallel_all_reduce)
from vllm.model_executor.layers.activation import SiluAndMul
from vllm.model_executor.layers.fused_moe import FusedMoE
from vllm.model_executor.layers.layernorm import RMSNorm
from vllm.model_executor.layers.linear import (MergedColumnParallelLinear,
QKVParallelLinear,
ReplicatedLinear,
RowParallelLinear)
from vllm.model_executor.layers.logits_processor import LogitsProcessor
from vllm.model_executor.layers.quantization import QuantizationConfig
from vllm.model_executor.layers.rotary_embedding import get_rope
from vllm.model_executor.layers.sampler import SamplerOutput, get_sampler
from vllm.model_executor.layers.vocab_parallel_embedding import (
ParallelLMHead, VocabParallelEmbedding)
from vllm.model_executor.model_loader.weight_utils import default_weight_loader
from vllm.model_executor.sampling_metadata import SamplingMetadata
from vllm.sequence import IntermediateTensors
from vllm.utils import print_warning_once
from .interfaces import SupportsPP
from .utils import (is_pp_missing_parameter,
make_empty_intermediate_tensors_factory, make_layers,
maybe_prefix)
class Qwen3MoeMLP(nn.Module):
def __init__(
self,
hidden_size: int,
intermediate_size: int,
hidden_act: str,
quant_config: Optional[QuantizationConfig] = None,
reduce_results: bool = True,
) -> None:
super().__init__()
self.gate_up_proj = MergedColumnParallelLinear(
hidden_size, [intermediate_size] * 2,
bias=False,
quant_config=quant_config)
self.down_proj = RowParallelLinear(intermediate_size,
hidden_size,
bias=False,
quant_config=quant_config,
reduce_results=reduce_results)
if hidden_act != "silu":
raise ValueError(f"Unsupported activation: {hidden_act}. "
"Only silu is supported for now.")
self.act_fn = SiluAndMul()
def forward(self, x):
gate_up, _ = self.gate_up_proj(x)
x = self.act_fn(gate_up)
x, _ = self.down_proj(x)
return x
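gate_up_proj produces the gate and up projections concatenated along the last dimension; SiluAndMul then splits them and computes silu(gate) * up, halving the width. A pure-NumPy reference of that activation (illustrative only; vLLM's version is a fused kernel):

```python
import numpy as np

def silu_and_mul(gate_up: np.ndarray) -> np.ndarray:
    """Reference for SiluAndMul: split [..., 2*d] into gate/up, gate with SiLU."""
    d = gate_up.shape[-1] // 2
    gate, up = gate_up[..., :d], gate_up[..., d:]
    return gate / (1.0 + np.exp(-gate)) * up

x = np.random.randn(4, 16)          # 4 tokens, merged gate_up width 16
assert silu_and_mul(x).shape == (4, 8)
```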
class Qwen3MoeSparseMoeBlock(nn.Module):
def __init__(
self,
config: PretrainedConfig,
quant_config: Optional[QuantizationConfig] = None,
):
super().__init__()
self.tp_size = get_tensor_model_parallel_world_size()
if self.tp_size > config.num_experts:
raise ValueError(
f"Tensor parallel size {self.tp_size} is greater than "
f"the number of experts {config.num_experts}.")
self.experts = FusedMoE(num_experts=config.num_experts,
top_k=config.num_experts_per_tok,
hidden_size=config.hidden_size,
intermediate_size=config.moe_intermediate_size,
reduce_results=False,
renormalize=config.norm_topk_prob,
quant_config=quant_config)
self.gate = ReplicatedLinear(config.hidden_size,
config.num_experts,
bias=False,
quant_config=None)
shared_expert_intermediate_size = getattr(
config, "shared_expert_intermediate_size", 0)
if shared_expert_intermediate_size > 0:
self.shared_expert = Qwen3MoeMLP(
hidden_size=config.hidden_size,
intermediate_size=shared_expert_intermediate_size,
hidden_act=config.hidden_act,
quant_config=quant_config,
reduce_results=False,
)
else:
self.shared_expert = None
# Qwen3Moe uses ReplicatedLinear for shared_expert_gate
# (unlike Qwen2Moe which uses torch.nn.Linear)
self.shared_expert_gate = ReplicatedLinear(config.hidden_size,
1,
bias=False,
quant_config=None)
def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
# NOTE: hidden_states can have either 1D or 2D shape.
orig_shape = hidden_states.shape
hidden_dim = hidden_states.shape[-1]
hidden_states = hidden_states.view(-1, hidden_dim)
shared_output = None
if self.shared_expert is not None:
shared_output = self.shared_expert(hidden_states)
if self.shared_expert_gate is not None:
shared_output = F.sigmoid(
self.shared_expert_gate(hidden_states)[0]
) * shared_output
# router_logits: (num_tokens, n_experts)
router_logits, _ = self.gate(hidden_states)
final_hidden_states = self.experts(hidden_states=hidden_states,
router_logits=router_logits)
if shared_output is not None:
final_hidden_states = final_hidden_states + shared_output
if self.tp_size > 1:
final_hidden_states = tensor_model_parallel_all_reduce(
final_hidden_states)
return final_hidden_states.view(orig_shape)
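FusedMoE fuses routing and expert computation into one kernel; conceptually, the routing it performs (with renormalize corresponding to config.norm_topk_prob) is softmax over the router logits, top-k selection per token, and renormalization of the kept weights. An illustrative NumPy sketch of just the routing step:

```python
import numpy as np

def topk_routing(router_logits: np.ndarray, top_k: int, renormalize: bool = True):
    """Per-token expert selection: softmax -> top-k -> optional renormalize."""
    e = np.exp(router_logits - router_logits.max(axis=-1, keepdims=True))
    probs = e / e.sum(axis=-1, keepdims=True)
    topk_ids = np.argsort(-probs, axis=-1)[..., :top_k]
    topk_w = np.take_along_axis(probs, topk_ids, axis=-1)
    if renormalize:
        topk_w = topk_w / topk_w.sum(axis=-1, keepdims=True)
    return topk_w, topk_ids

logits = np.array([[2.0, 1.0, 0.5, -1.0]])   # 1 token, 4 experts
w, ids = topk_routing(logits, top_k=2)       # picks experts 0 and 1
```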
class Qwen3MoeAttention(nn.Module):
def __init__(
self,
hidden_size: int,
num_heads: int,
num_kv_heads: int,
rope_theta: float = 10000,
rope_scaling: Optional[Dict[str, Any]] = None,
max_position_embeddings: int = 8192,
rms_norm_eps: float = 1e-06,
qkv_bias: bool = False,
cache_config: Optional[CacheConfig] = None,
quant_config: Optional[QuantizationConfig] = None,
) -> None:
super().__init__()
self.hidden_size = hidden_size
tp_size = get_tensor_model_parallel_world_size()
self.total_num_heads = num_heads
assert self.total_num_heads % tp_size == 0
self.num_heads = self.total_num_heads // tp_size
self.total_num_kv_heads = num_kv_heads
if self.total_num_kv_heads >= tp_size:
assert self.total_num_kv_heads % tp_size == 0
else:
assert tp_size % self.total_num_kv_heads == 0
self.num_kv_heads = max(1, self.total_num_kv_heads // tp_size)
self.head_dim = hidden_size // self.total_num_heads
self.q_size = self.num_heads * self.head_dim
self.kv_size = self.num_kv_heads * self.head_dim
self.scaling = self.head_dim**-0.5
self.rope_theta = rope_theta
self.max_position_embeddings = max_position_embeddings
self.qkv_proj = QKVParallelLinear(
hidden_size,
self.head_dim,
self.total_num_heads,
self.total_num_kv_heads,
bias=qkv_bias,
quant_config=quant_config,
)
self.o_proj = RowParallelLinear(
self.total_num_heads * self.head_dim,
hidden_size,
bias=False,
quant_config=quant_config,
)
self.rotary_emb = get_rope(
self.head_dim,
rotary_dim=self.head_dim,
max_position=max_position_embeddings,
base=rope_theta,
rope_scaling=rope_scaling,
)
self.attn = Attention(self.num_heads,
self.head_dim,
self.scaling,
num_kv_heads=self.num_kv_heads,
cache_config=cache_config,
quant_config=quant_config)
# Qwen3 specific: QK normalization
self.q_norm = RMSNorm(self.head_dim, eps=rms_norm_eps)
self.k_norm = RMSNorm(self.head_dim, eps=rms_norm_eps)
def forward(
self,
positions: torch.Tensor,
hidden_states: torch.Tensor,
kv_cache: torch.Tensor,
attn_metadata: AttentionMetadata,
) -> torch.Tensor:
qkv, _ = self.qkv_proj(hidden_states)
q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size],
dim=-1)
# Qwen3 specific: Apply QK normalization before rotary embedding
# Use .contiguous() to ensure memory layout is compatible with
# MLU's RMSNorm which uses .view() internally.
q_shape = q.shape
q_by_head = q.view(*q.shape[:-1], q.shape[-1] // self.head_dim,
self.head_dim).contiguous()
q_by_head = self.q_norm(q_by_head)
q = q_by_head.reshape(q_shape)
k_shape = k.shape
k_by_head = k.view(*k.shape[:-1], k.shape[-1] // self.head_dim,
self.head_dim).contiguous()
k_by_head = self.k_norm(k_by_head)
k = k_by_head.reshape(k_shape)
# MLU rotary_emb expects a single concatenated 3D tensor, not
# separate q and k (forward_mlu signature differs from forward_native).
qk = torch.cat([q, k], dim=-1)
self.rotary_emb(positions,
qk.view(-1, self.num_heads + self.num_kv_heads,
self.head_dim))
q, k = qk.split([self.q_size, self.kv_size], dim=-1)
attn_output = self.attn(q, k, v, kv_cache, attn_metadata)
output, _ = self.o_proj(attn_output)
return output
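Two shape tricks above are easy to get wrong: the per-head QK normalization views the flat [num_tokens, n_heads * head_dim] tensor as per-head vectors before normalizing, and q/k are concatenated so the MLU rotary kernel sees a single 3D tensor. The normalization reshape can be sketched in NumPy (learned scale omitted for brevity):

```python
import numpy as np

def per_head_rmsnorm(x, head_dim, eps=1e-6):
    """RMS-normalize each head_dim-sized slice of the last dim.

    Mirrors the q_norm/k_norm reshape above: view as (..., n_heads, head_dim),
    normalize per head, flatten back to the original shape.
    """
    shape = x.shape
    xh = x.reshape(shape[:-1] + (shape[-1] // head_dim, head_dim))
    rms = np.sqrt((xh ** 2).mean(axis=-1, keepdims=True) + eps)
    return (xh / rms).reshape(shape)

q = np.random.randn(5, 4 * 8)   # 5 tokens, 4 heads of head_dim 8, flattened
assert per_head_rmsnorm(q, head_dim=8).shape == q.shape
```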
class Qwen3MoeDecoderLayer(nn.Module):
def __init__(
self,
config: PretrainedConfig,
layer_idx: int,
cache_config: Optional[CacheConfig] = None,
quant_config: Optional[QuantizationConfig] = None,
) -> None:
super().__init__()
self.hidden_size = config.hidden_size
rope_theta = getattr(config, "rope_theta", 10000)
rope_scaling = getattr(config, "rope_scaling", None)
max_position_embeddings = getattr(config, "max_position_embeddings",
8192)
self.self_attn = Qwen3MoeAttention(
hidden_size=self.hidden_size,
num_heads=config.num_attention_heads,
num_kv_heads=config.num_key_value_heads,
rope_theta=rope_theta,
rope_scaling=rope_scaling,
max_position_embeddings=max_position_embeddings,
rms_norm_eps=config.rms_norm_eps,
qkv_bias=getattr(config, "attention_bias", False),
cache_config=cache_config,
quant_config=quant_config,
)
# Note: Qwen3MoE may not have `mlp_only_layers` in the config.
mlp_only_layers = ([] if not hasattr(config, "mlp_only_layers") else
config.mlp_only_layers)
if (layer_idx not in mlp_only_layers) and (
config.num_experts > 0 and
(layer_idx + 1) % config.decoder_sparse_step == 0):
self.mlp = Qwen3MoeSparseMoeBlock(config=config,
quant_config=quant_config)
else:
self.mlp = Qwen3MoeMLP(
hidden_size=config.hidden_size,
intermediate_size=config.intermediate_size,
hidden_act=config.hidden_act,
quant_config=quant_config,
)
self.input_layernorm = RMSNorm(config.hidden_size,
eps=config.rms_norm_eps)
self.post_attention_layernorm = RMSNorm(config.hidden_size,
eps=config.rms_norm_eps)
def forward(
self,
positions: torch.Tensor,
hidden_states: torch.Tensor,
kv_cache: torch.Tensor,
attn_metadata: AttentionMetadata,
residual: Optional[torch.Tensor],
) -> Tuple[torch.Tensor, torch.Tensor]:
# Self Attention
if residual is None:
residual = hidden_states
hidden_states = self.input_layernorm(hidden_states)
else:
hidden_states, residual = self.input_layernorm(
hidden_states, residual)
hidden_states = self.self_attn(
positions=positions,
hidden_states=hidden_states,
kv_cache=kv_cache,
attn_metadata=attn_metadata,
)
# Fully Connected
hidden_states, residual = self.post_attention_layernorm(
hidden_states, residual)
hidden_states = self.mlp(hidden_states)
return hidden_states, residual
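The layer relies on vLLM's fused add-and-norm contract: calling RMSNorm with a residual returns both the normalized sum and the new residual, so the raw residual stream is threaded through the layer without explicit adds. A NumPy sketch of that contract (learned weight omitted):

```python
import numpy as np

def rmsnorm_with_residual(x, residual=None, eps=1e-6):
    """Mimics vLLM RMSNorm's two calling conventions:
    - norm(x) when residual is None
    - (norm(x + residual), x + residual) otherwise
    """
    if residual is not None:
        x = x + residual
    out = x / np.sqrt((x ** 2).mean(axis=-1, keepdims=True) + eps)
    return out if residual is None else (out, x)

h = np.random.randn(3, 8)
r = np.random.randn(3, 8)
normed, new_residual = rmsnorm_with_residual(h, r)
assert np.allclose(new_residual, h + r)
```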
@support_torch_compile
class Qwen3MoeModel(nn.Module):
def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
super().__init__()
config = vllm_config.model_config.hf_config
cache_config = vllm_config.cache_config
quant_config = vllm_config.quant_config
self.padding_idx = config.pad_token_id
self.vocab_size = config.vocab_size
self.embed_tokens = VocabParallelEmbedding(
config.vocab_size,
config.hidden_size,
)
self.start_layer, self.end_layer, self.layers = make_layers(
config.num_hidden_layers,
lambda prefix: Qwen3MoeDecoderLayer(config=config,
layer_idx=int(
prefix.split(".")[-1]),
cache_config=cache_config,
quant_config=quant_config),
prefix=f"{prefix}.layers",
)
self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
self.make_empty_intermediate_tensors = (
make_empty_intermediate_tensors_factory(
["hidden_states", "residual"], config.hidden_size))
def forward(
self,
input_ids: torch.Tensor,
positions: torch.Tensor,
kv_caches: List[torch.Tensor],
attn_metadata: AttentionMetadata,
intermediate_tensors: Optional[IntermediateTensors] = None,
) -> Union[torch.Tensor, IntermediateTensors]:
if get_pp_group().is_first_rank:
hidden_states = self.embed_tokens(input_ids)
residual = None
else:
assert intermediate_tensors is not None
hidden_states = intermediate_tensors["hidden_states"]
residual = intermediate_tensors["residual"]
for i in range(self.start_layer, self.end_layer):
layer = self.layers[i]
hidden_states, residual = layer(positions, hidden_states,
kv_caches[i - self.start_layer],
attn_metadata, residual)
if not get_pp_group().is_last_rank:
return IntermediateTensors({
"hidden_states": hidden_states,
"residual": residual
})
hidden_states, _ = self.norm(hidden_states, residual)
return hidden_states
class Qwen3MoeForCausalLM(nn.Module, SupportsPP):
fall_back_to_pt_during_load = False
def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
super().__init__()
config = vllm_config.model_config.hf_config
quant_config = vllm_config.quant_config
self.config = config
self.quant_config = quant_config
self.model = Qwen3MoeModel(vllm_config=vllm_config,
prefix=maybe_prefix(prefix, "model"))
self.lm_head = ParallelLMHead(config.vocab_size,
config.hidden_size,
quant_config=quant_config)
if self.config.tie_word_embeddings:
self.lm_head.weight = self.model.embed_tokens.weight
self.logits_processor = LogitsProcessor(config.vocab_size)
self.sampler = get_sampler()
self.make_empty_intermediate_tensors = (
self.model.make_empty_intermediate_tensors)
def forward(
self,
input_ids: torch.Tensor,
positions: torch.Tensor,
kv_caches: List[torch.Tensor],
attn_metadata: AttentionMetadata,
intermediate_tensors: Optional[IntermediateTensors] = None,
) -> Union[torch.Tensor, IntermediateTensors]:
hidden_states = self.model(input_ids, positions, kv_caches,
attn_metadata, intermediate_tensors)
return hidden_states
def compute_logits(
self,
hidden_states: torch.Tensor,
sampling_metadata: SamplingMetadata,
) -> Optional[torch.Tensor]:
logits = self.logits_processor(self.lm_head, hidden_states,
sampling_metadata)
return logits
def sample(
self,
logits: Optional[torch.Tensor],
sampling_metadata: SamplingMetadata,
) -> Optional[SamplerOutput]:
next_tokens = self.sampler(logits, sampling_metadata)
return next_tokens
def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
stacked_params_mapping = [
# (param_name, shard_name, shard_id)
("qkv_proj", "q_proj", "q"),
("qkv_proj", "k_proj", "k"),
("qkv_proj", "v_proj", "v"),
("gate_up_proj", "gate_proj", 0),
("gate_up_proj", "up_proj", 1),
]
# Params for weights, fp8 weight scales, fp8 activation scales
# (param_name, weight_name, expert_id, shard_id)
expert_params_mapping = FusedMoE.make_expert_params_mapping(
ckpt_gate_proj_name="gate_proj",
ckpt_down_proj_name="down_proj",
ckpt_up_proj_name="up_proj",
num_experts=self.config.num_experts)
params_dict = dict(self.named_parameters())
for name, loaded_weight in weights:
if "rotary_emb.inv_freq" in name:
continue
for (param_name, weight_name, shard_id) in stacked_params_mapping:
# Skip non-stacked layers and experts (experts handled below).
if weight_name not in name:
continue
# We have mlp.experts[0].gate_proj in the checkpoint.
# Since we handle the experts below in expert_params_mapping,
# we need to skip here BEFORE we update the name, otherwise
# name will be updated to mlp.experts[0].gate_up_proj, which
# will then be updated below in expert_params_mapping
# for mlp.experts[0].gate_gate_up_proj, which breaks load.
if "mlp.experts" in name:
continue
name = name.replace(weight_name, param_name)
# Skip loading extra bias for GPTQ models.
if ((name.endswith(".bias") or name.endswith("_bias"))
and name not in params_dict):
continue
# Skip layers on other devices.
if is_pp_missing_parameter(name, self):
continue
if name not in params_dict:
continue
param = params_dict[name]
weight_loader = param.weight_loader
weight_loader(param, loaded_weight, shard_id)
break
else:
for mapping in expert_params_mapping:
param_name, weight_name, expert_id, shard_id = mapping
if weight_name not in name:
continue
name = name.replace(weight_name, param_name)
# Skip layers on other devices.
if is_pp_missing_parameter(name, self):
continue
# Skip loading extra bias for GPTQ models.
if ((name.endswith(".bias") or name.endswith("_bias"))
and name not in params_dict):
continue
param = params_dict[name]
weight_loader = param.weight_loader
weight_loader(param,
loaded_weight,
name,
shard_id=shard_id,
expert_id=expert_id)
break
else:
# Skip loading extra bias for GPTQ models.
if ((name.endswith(".bias") or name.endswith("_bias"))
and name not in params_dict):
continue
# Skip layers on other devices.
if is_pp_missing_parameter(name, self):
continue
# Remapping the name of FP8 kv-scale.
if name.endswith("kv_scale"):
remapped_kv_scale_name = name.replace(
".kv_scale", ".attn.kv_scale")
if remapped_kv_scale_name not in params_dict:
print_warning_once(
"Found kv scale in the checkpoint "
f"(e.g. {name}), but not found the expected "
f"name in the model "
f"(e.g. {remapped_kv_scale_name}). "
"kv-scale is not loaded.")
continue
else:
name = remapped_kv_scale_name
param = params_dict[name]
weight_loader = getattr(param, "weight_loader",
default_weight_loader)
weight_loader(param, loaded_weight)
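The stacked-parameter branch above only rewrites checkpoint names and forwards a shard id; the merged parameter's weight_loader then writes the matching slice. The name-mapping step alone can be sketched as (expert handling and the device/bias skips omitted):

```python
stacked_params_mapping = [
    # (vllm_param_name, checkpoint_weight_name, shard_id)
    ("qkv_proj", "q_proj", "q"),
    ("qkv_proj", "k_proj", "k"),
    ("qkv_proj", "v_proj", "v"),
    ("gate_up_proj", "gate_proj", 0),
    ("gate_up_proj", "up_proj", 1),
]

def map_checkpoint_name(name: str):
    """Return (vllm_param_name, shard_id); shard_id is None for non-stacked
    weights and for per-expert weights, which are handled separately."""
    for param_name, weight_name, shard_id in stacked_params_mapping:
        if weight_name in name and "mlp.experts" not in name:
            return name.replace(weight_name, param_name), shard_id
    return name, None

assert map_checkpoint_name("model.layers.0.self_attn.k_proj.weight") == \
    ("model.layers.0.self_attn.qkv_proj.weight", "k")
```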


@@ -28,6 +28,9 @@ from .interfaces_base import is_embedding_model, is_text_generation_model
logger = init_logger(__name__)
# Cache for architectures that have already been logged
_logged_transformers_architectures: set = set()
# yapf: disable
_TEXT_GENERATION_MODELS = {
# [Decoder-only]
@@ -45,10 +48,12 @@ _TEXT_GENERATION_MODELS = {
"DeciLMForCausalLM": ("decilm", "DeciLMForCausalLM"),
"DeepseekForCausalLM": ("deepseek", "DeepseekForCausalLM"),
"DeepseekV2ForCausalLM": ("deepseek_v2", "DeepseekV2ForCausalLM"),
"DeepseekV3ForCausalLM": ("deepseek_v2", "DeepseekV3ForCausalLM"),
"ExaoneForCausalLM": ("exaone", "ExaoneForCausalLM"),
"FalconForCausalLM": ("falcon", "FalconForCausalLM"),
"GemmaForCausalLM": ("gemma", "GemmaForCausalLM"),
"Gemma2ForCausalLM": ("gemma2", "Gemma2ForCausalLM"),
"Gemma3ForCausalLM": ("gemma3", "Gemma3ForCausalLM"),
"GPT2LMHeadModel": ("gpt2", "GPT2LMHeadModel"),
"GPTBigCodeForCausalLM": ("gpt_bigcode", "GPTBigCodeForCausalLM"),
"GPTJForCausalLM": ("gpt_j", "GPTJForCausalLM"),
@@ -60,6 +65,8 @@ _TEXT_GENERATION_MODELS = {
"InternLM2VEForCausalLM": ("internlm2_ve", "InternLM2VEForCausalLM"),
"JAISLMHeadModel": ("jais", "JAISLMHeadModel"),
"JambaForCausalLM": ("jamba", "JambaForCausalLM"),
"Llama4ForCausalLM": ("llama4", "Llama4ForCausalLM"),
"Llama4ForConditionalGeneration": ("llama4", "Llama4ForCausalLM"),
"LlamaForCausalLM": ("llama", "LlamaForCausalLM"),
# For decapoda-research/llama-*
"LLaMAForCausalLM": ("llama", "LlamaForCausalLM"),
@@ -87,6 +94,7 @@ _TEXT_GENERATION_MODELS = {
"Qwen2ForCausalLM": ("qwen2", "Qwen2ForCausalLM"),
"Qwen2MoeForCausalLM": ("qwen2_moe", "Qwen2MoeForCausalLM"),
"Qwen3ForCausalLM": ("qwen3", "Qwen3ForCausalLM"),
"Qwen3MoeForCausalLM": ("qwen3_moe", "Qwen3MoeForCausalLM"),
"RWForCausalLM": ("falcon", "FalconForCausalLM"),
"StableLMEpochForCausalLM": ("stablelm", "StablelmForCausalLM"),
"StableLmForCausalLM": ("stablelm", "StablelmForCausalLM"),
@@ -158,11 +166,14 @@ _SPECULATIVE_DECODING_MODELS = {
"EAGLEModel": ("eagle", "EAGLE"),
"MedusaModel": ("medusa", "Medusa"),
"MLPSpeculatorPreTrainedModel": ("mlp_speculator", "MLPSpeculator"),
"DeepSeekMTPModel": ("deepseek_mtp", "DeepSeekMTP"),
}
# Transformers backend models - wrapper classes for custom HuggingFace models
# These provide the vLLM interface for models loaded via auto_map
_TRANSFORMERS_BACKEND_MODELS = {
# Text generation models
"TransformersForCausalLM": ("transformers", "TransformersForCausalLM"),
}
# yapf: enable
@@ -171,6 +182,7 @@ _VLLM_MODELS = {
**_EMBEDDING_MODELS,
**_MULTIMODAL_MODELS,
**_SPECULATIVE_DECODING_MODELS,
**_TRANSFORMERS_BACKEND_MODELS,
}
# Models not supported by ROCm.
@@ -383,54 +395,86 @@ class _ModelRegistry:
revision: Optional[str],
trust_remote_code: bool,
hf_config: Optional[object] = None,
) -> Optional[str]:
"""
Try to resolve a model architecture using the Transformers backend.
This allows loading custom models that define their own implementation
via the `auto_map` field in config.json.
Returns the vLLM wrapper architecture name (e.g. "TransformersForCausalLM")
if the model can be loaded via auto_map, None otherwise.
"""
# If architecture is already a transformers backend model, return it
if architecture in _TRANSFORMERS_BACKEND_MODELS:
return architecture
# Check if architecture exists in transformers library
model_module = getattr(transformers, architecture, None)
if model_module is not None:
# Model exists in transformers, can use TransformersForCausalLM wrapper
# Only log once per architecture to avoid spam
if architecture not in _logged_transformers_architectures:
_logged_transformers_architectures.add(architecture)
logger.info(
"Architecture %s found in transformers library, "
"using TransformersForCausalLM wrapper",
architecture
)
return "TransformersForCausalLM"
# Get auto_map from hf_config
auto_map: Dict[str, str] = {}
if hf_config is not None:
auto_map = getattr(hf_config, "auto_map", None) or {}
if not auto_map:
return None
# Try to load from auto_map to verify it works
# First, ensure config class is loaded
for name, module in auto_map.items():
if name.startswith("AutoConfig"):
try_get_class_from_dynamic_module(
module,
model_path,
trust_remote_code=trust_remote_code,
revision=revision,
warn_on_fail=False,
)
# Check if auto_map has a model class we can use
# Priority: AutoModelForCausalLM > AutoModelForSeq2SeqLM > AutoModel
auto_model_keys = sorted(
[k for k in auto_map.keys() if k.startswith("AutoModel")],
key=lambda x: (0 if "ForCausalLM" in x else (1 if "ForSeq2Seq" in x else 2))
)
for name in auto_model_keys:
module = auto_map[name]
model_cls = try_get_class_from_dynamic_module(
module,
model_path,
trust_remote_code=trust_remote_code,
revision=revision,
warn_on_fail=True,
)
if model_cls is not None:
# Only log once per model class to avoid spam
log_key = f"{model_cls.__name__}_{name}"
if not hasattr(self, '_logged_custom_models'):
self._logged_custom_models = set()
if log_key not in self._logged_custom_models:
logger.info(
"Found custom model class %s from auto_map[%s], "
"using TransformersForCausalLM wrapper",
model_cls.__name__,
name
)
self._logged_custom_models.add(log_key)
# Return the wrapper architecture, not the actual class
return "TransformersForCausalLM"
return None
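The priority sort above is worth seeing in isolation: for a hypothetical auto_map, the key ranks AutoModelForCausalLM entries first, then Seq2Seq, then plain AutoModel, with non-model keys like AutoConfig filtered out beforehand:

```python
# Hypothetical auto_map as it would appear in a custom model's config.json.
auto_map = {
    "AutoModel": "mod.MyModel",
    "AutoModelForSeq2SeqLM": "mod.MySeq2Seq",
    "AutoModelForCausalLM": "mod.MyCausalLM",
    "AutoConfig": "mod.MyConfig",
}

# Same sort key as in _try_resolve_transformers above.
auto_model_keys = sorted(
    [k for k in auto_map if k.startswith("AutoModel")],
    key=lambda x: (0 if "ForCausalLM" in x else (1 if "ForSeq2Seq" in x else 2)),
)
print(auto_model_keys)
# ['AutoModelForCausalLM', 'AutoModelForSeq2SeqLM', 'AutoModel']
```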
def _normalize_archs(
self,
@@ -440,6 +484,7 @@ class _ModelRegistry:
architectures = [architectures]
if not architectures:
logger.warning("No model architectures are specified")
return []
return architectures
@@ -461,12 +506,14 @@ class _ModelRegistry:
# Fallback: try to resolve using transformers backend (auto_map)
if model_path and trust_remote_code and hf_config:
for arch in architectures:
wrapper_arch = self._try_resolve_transformers(
arch, model_path, revision, trust_remote_code, hf_config
)
if wrapper_arch is not None:
# Use the wrapper architecture's ModelInfo
model_info = self._try_inspect_model_cls(wrapper_arch)
if model_info is not None:
return model_info
return self._raise_for_unsupported(architectures)
@@ -488,11 +535,14 @@ class _ModelRegistry:
# Fallback: try to resolve using transformers backend (auto_map)
if model_path and trust_remote_code and hf_config:
for arch in architectures:
wrapper_arch = self._try_resolve_transformers(
arch, model_path, revision, trust_remote_code, hf_config
)
if wrapper_arch is not None:
model_cls = self._try_load_model_cls(wrapper_arch)
if model_cls is not None:
# Return wrapper class but keep original architecture name
return (model_cls, arch)
return self._raise_for_unsupported(architectures)


@@ -0,0 +1,127 @@
# SPDX-License-Identifier: Apache-2.0
# Copyright 2024 The vLLM team.
"""Wrapper around `transformers` models for vLLM v0.6.2.
This module provides the Transformers modeling backend that wraps
any HuggingFace model with the vLLM interface, enabling support for custom
models that define their implementation via `auto_map` in config.json.
Architecture (following latest vLLM patterns):
- Base: Core functionality (meta init, PP/TP support, module replacement, attention, weight loading)
- CausalMixin: Causal LM specific (lm_head, compute_logits, sample)
- EmbeddingMixin: Embedding/pooling specific (pooler, pooling)
- SequenceClassificationMixin: Classification specific (classifier, pooling)
Composed model classes:
- TransformersForCausalLM = CausalMixin + Base
- TransformersForEmbedding = EmbeddingMixin + Base
- TransformersForSequenceClassification = SequenceClassificationMixin + Base
Key optimizations:
- Meta device initialization for memory efficiency
- Pipeline Parallel support (PPMissingLayer)
- Tensor Parallel support (tp_plan based module replacement)
- Module replacement (Linear, RMSNorm, Embedding) with vLLM optimized versions
- vLLM Attention instances for proper KV cache allocation
- AutoWeightsLoader for efficient weight loading with name mapping
"""
from vllm.model_executor.models.transformers.base import (
Base,
set_attention_context,
clear_attention_context,
get_attention_context,
vllm_flash_attention_forward,
)
from vllm.model_executor.models.transformers.causal import CausalMixin
from vllm.model_executor.models.transformers.pooling import (
EmbeddingMixin,
SequenceClassificationMixin,
)
from vllm.model_executor.models.transformers.legacy import LegacyMixin
from vllm.model_executor.models.transformers.utils import (
init_on_device_without_buffers,
replace_linear_class,
replace_rms_norm_class,
log_replacement,
maybe_prefix,
)
# ============================================================================
# Composed Model Classes (Mixin + Base pattern)
# ============================================================================
class TransformersForCausalLM(CausalMixin, Base):
"""
Transformers backend wrapper for causal language models.
Combines CausalMixin (lm_head, compute_logits, sample) with
Base (meta init, PP/TP support, module replacement, attention, weight loading).
Supports any HuggingFace model with auto_map in config.json.
"""
pass
class TransformersForEmbedding(EmbeddingMixin, Base):
"""
Transformers backend wrapper for embedding/sentence similarity models.
Combines EmbeddingMixin (pooler, pooling) with
Base (meta init, PP/TP support, module replacement, attention, weight loading).
Supports embedding models like BERT, sentence-transformers, etc.
"""
pass
class TransformersForSequenceClassification(SequenceClassificationMixin, Base):
"""
Transformers backend wrapper for sequence classification models.
Combines SequenceClassificationMixin (classifier, pooling) with
Base (meta init, PP/TP support, module replacement, attention, weight loading).
Supports cross-encoders and classification models.
"""
pass
class TransformersForLegacy(LegacyMixin, EmbeddingMixin, Base):
"""
Transformers backend wrapper for legacy/encoder models.
Combines LegacyMixin (BERT/RoBERTa weight mapping, position handling) with
EmbeddingMixin (pooler) and Base (core functionality).
Supports BERT, RoBERTa, and similar encoder models.
"""
pass
__all__ = [
# Main wrapper classes
"TransformersForCausalLM",
"TransformersForEmbedding",
"TransformersForSequenceClassification",
"TransformersForLegacy",
# Base class for extension
"Base",
# Mixin classes for custom combinations
"CausalMixin",
"EmbeddingMixin",
"SequenceClassificationMixin",
"LegacyMixin",
# Attention context management
"set_attention_context",
"clear_attention_context",
"get_attention_context",
"vllm_flash_attention_forward",
# Utility functions
"init_on_device_without_buffers",
"replace_linear_class",
"replace_rms_norm_class",
"log_replacement",
"maybe_prefix",
]


@@ -0,0 +1,704 @@
# SPDX-License-Identifier: Apache-2.0
# Copyright 2024 The vLLM team.
"""Transformers modeling backend base class for v0.6.2.
This module provides the Base class following latest vLLM architecture:
- Meta device initialization for memory efficiency
- Pipeline parallel support (PPMissingLayer)
- Tensor parallel support (tp_plan based module replacement)
- Module replacement (Linear, RMSNorm) with vLLM optimized versions
- VocabParallelEmbedding for input embeddings
- Attention instances for KV cache allocation
- Weight loading with AutoWeightsLoader and WeightsMapper
"""
import re
from typing import TYPE_CHECKING, Dict, Iterable, List, Optional, Set, Tuple
import torch
import torch.nn as nn
from vllm.config import VllmConfig
from vllm.distributed import get_pp_group, get_tp_group
from vllm.distributed.utils import get_pp_indices
from vllm.logger import init_logger
from vllm.model_executor.layers.layernorm import RMSNorm
from vllm.model_executor.layers.vocab_parallel_embedding import VocabParallelEmbedding
from vllm.model_executor.models.utils import (
AutoWeightsLoader,
PPMissingLayer,
WeightsMapper,
make_empty_intermediate_tensors_factory,
)
from vllm.attention.layer import Attention
from vllm.sequence import IntermediateTensors
from .utils import (
init_on_device_without_buffers,
replace_linear_class,
replace_rms_norm_class,
log_replacement,
maybe_prefix,
)
if TYPE_CHECKING:
from transformers import PreTrainedModel
from vllm.attention import AttentionMetadata
logger = init_logger(__name__)
# ============================================================================
# Attention Context Management (for vLLM attention integration)
# ============================================================================
_current_attn_metadata = None
_current_kv_caches = None
def set_attention_context(attn_metadata, kv_caches):
"""Set the current attention context for vLLM attention functions."""
global _current_attn_metadata, _current_kv_caches
_current_attn_metadata = attn_metadata
_current_kv_caches = kv_caches
def clear_attention_context():
"""Clear the current attention context after forward pass."""
global _current_attn_metadata, _current_kv_caches
_current_attn_metadata = None
_current_kv_caches = None
def get_attention_context():
"""Get the current attention context."""
return _current_attn_metadata, _current_kv_caches
# ============================================================================
# vLLM Attention Function for Transformers Integration
# ============================================================================
def vllm_flash_attention_forward(
module: torch.nn.Module,
query: torch.Tensor,
key: torch.Tensor,
value: torch.Tensor,
attention_mask: torch.Tensor,
scaling: Optional[float] = None,
attention_instances: Optional[Dict[int, Attention]] = None,
**kwargs,
):
"""
vLLM's optimized attention function for transformers integration.
In v0.6.2, Attention.forward signature is:
(query, key, value, kv_cache, attn_metadata)
"""
layer_idx = getattr(module, 'layer_idx', 0)
if attention_instances is None or layer_idx not in attention_instances:
return _standard_attention(query, key, value, attention_mask, scaling)
self_attn = attention_instances[layer_idx]
attn_metadata, kv_caches = get_attention_context()
if attn_metadata is None or kv_caches is None:
return _standard_attention(query, key, value, attention_mask, scaling)
if scaling is not None:
self_attn.impl.scale = float(scaling)
# Reshape: [batch, heads, seq, head_dim] -> [seq, heads * head_dim]
num_tokens = query.shape[-2]
query, key, value = (x.transpose(1, 2) for x in (query, key, value))
query, key, value = (x.reshape(num_tokens, -1) for x in (query, key, value))
kv_cache = kv_caches[layer_idx] if layer_idx < len(kv_caches) else None
output = self_attn.forward(query, key, value, kv_cache, attn_metadata)
return output, None
def _standard_attention(query, key, value, attention_mask, scaling):
"""Standard scaled dot-product attention fallback."""
attn_weights = torch.matmul(query, key.transpose(-2, -1))
if scaling is not None:
attn_weights = attn_weights * scaling
if attention_mask is not None:
attn_weights = attn_weights + attention_mask
attn_weights = torch.nn.functional.softmax(attn_weights, dim=-1)
attn_output = torch.matmul(attn_weights, value)
return attn_output, None
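_standard_attention is plain scaled dot-product attention in eager mode. A NumPy mirror makes the math easy to sanity-check, e.g. with a single key the softmax weight is 1 and the output is exactly the value vector:

```python
import numpy as np

def standard_attention(q, k, v, mask=None, scaling=None):
    """NumPy mirror of the _standard_attention fallback above."""
    w = q @ np.swapaxes(k, -2, -1)          # QK^T
    if scaling is not None:
        w = w * scaling
    if mask is not None:
        w = w + mask                        # additive attention mask
    w = np.exp(w - w.max(axis=-1, keepdims=True))
    w = w / w.sum(axis=-1, keepdims=True)   # softmax over keys
    return w @ v

v1 = np.array([[[1.0, 2.0]]])               # single token: output == value
q1 = k1 = np.array([[[0.5, 0.5]]])
assert np.allclose(standard_attention(q1, k1, v1), v1)
```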
# Register vLLM attention to transformers
_vllm_attention_registered = False
try:
from transformers.modeling_utils import ALL_ATTENTION_FUNCTIONS
ALL_ATTENTION_FUNCTIONS["vllm"] = vllm_flash_attention_forward
_vllm_attention_registered = True
logger.info("Registered vLLM attention function to transformers")
except (ImportError, AttributeError) as e:
logger.warning("Could not register vLLM attention: %s", e)
# ============================================================================
# Base Class with Pipeline Parallel and Tensor Parallel Support
# ============================================================================
class Base(nn.Module):
"""
Base class for Transformers backend models with full parallel support.
Features:
- Pipeline Parallel: PPMissingLayer for distributed layers
- Tensor Parallel: tp_plan based module replacement
- Meta device initialization
- Module replacement (Linear → vLLM Linear, RMSNorm → vLLM RMSNorm)
- VocabParallelEmbedding for input embeddings
- Attention instances for KV cache allocation
"""
# For vLLM's weight loader
embedding_modules = ["embed_tokens"]
# Weight name mapping following latest vLLM pattern
hf_to_vllm_mapper = WeightsMapper(
orig_to_new_prefix={
# Add `model.` prefix for base model checkpoints,
# handling the case where it is already present
"": "model.",
"model.model.": "model.",
# Heads will be adjacent to `model` (pooling included because of adapters)
"model.lm_head.": "lm_head.",
"model.score.": "classifier.",
"model.classifier.": "classifier.",
}
)
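The net effect of these prefix rules, assuming they are applied cumulatively in dict order (an assumption about WeightsMapper's behavior, sketched here with a hypothetical helper), is:

```python
# Hypothetical re-implementation of cumulative prefix rewriting; the real
# logic lives in vLLM's WeightsMapper.
PREFIX_MAP = {
    "": "model.",                  # bare base-model checkpoints gain "model."
    "model.model.": "model.",      # undo the double prefix when already present
    "model.lm_head.": "lm_head.",  # heads sit next to `model`, not inside it
}

def map_name(name: str, prefix_map: dict = PREFIX_MAP) -> str:
    """Apply every matching prefix rule, in insertion order."""
    for old, new in prefix_map.items():
        if name.startswith(old):
            name = new + name[len(old):]
    return name
```

The empty-prefix rule always fires first, and the "model.model." rule then repairs names that already carried the prefix, which is why the ordering in the dict matters.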
# Note: __init_subclass__ with WeightsMapper merging is not supported in v0.6.2
# because WeightsMapper doesn't implement __or__/__ior__ operators.
# Each Mixin should define its own hf_to_vllm_mapper if needed.
def __init__(self, *, vllm_config: VllmConfig, prefix: str = "") -> None:
super().__init__()
logger.info("Using Transformers modeling backend.")
# Store configuration
self.config = vllm_config.model_config.hf_config
self.text_config = getattr(self.config, "text_config", self.config)
self.model_config = vllm_config.model_config
self.cache_config = vllm_config.cache_config
self.device_config = vllm_config.device_config
self.parallel_config = vllm_config.parallel_config
self.quant_config = vllm_config.quant_config
self.prefix = prefix
# Parallel groups
self.pp_group = get_pp_group()
self.tp_group = get_tp_group()
# Model dimensions
self.hidden_size = getattr(self.text_config, "hidden_size", 4096)
self.vocab_size = getattr(self.text_config, "vocab_size", 32000)
# Weight loading configuration
self.skip_prefixes: List[str] = []
self.ignore_unexpected_prefixes: List[str] = []
# Configure attention backend
self._configure_attention_backend()
# Create model on meta device
self._init_model_on_meta()
# Apply pipeline parallel
self._apply_pipeline_parallel()
# Replace modules (with tensor parallel support)
self._replace_modules()
# Fix attention head_dim in case config was incorrect
self._fix_attention_head_dim()
# Add debug hook to first attention module to capture tensor shapes
self._add_attention_debug_hook()
# Replace input embeddings
self._replace_input_embeddings()
# Create attention instances
self.attention_instances = self._create_attention_instances()
# Initialize parameters on target device
self._init_parameters()
# Pipeline parallel intermediate tensors
self.make_empty_intermediate_tensors = make_empty_intermediate_tensors_factory(
["hidden_states"], self.hidden_size
)
def _configure_attention_backend(self) -> None:
"""Configure vLLM attention backend."""
# Note: attention implementation is set in _init_model_on_meta
# This method is kept for potential platform-specific configuration
pass
def _init_model_on_meta(self) -> None:
"""Create model structure on meta device."""
from transformers import AutoModel
logger.info("Creating model structure on meta device...")
# Set attention implementation to vLLM's
self.text_config._attn_implementation = "vllm"
# Ensure head_dim is correctly set in BOTH config and text_config
# Transformers models use config.head_dim to compute attention dimensions
# Some models may have incorrect head_dim, so we compute and set it
if hasattr(self.text_config, "num_attention_heads") and hasattr(self.text_config, "hidden_size"):
correct_head_dim = self.text_config.hidden_size // self.text_config.num_attention_heads
# Check and fix head_dim in text_config
if hasattr(self.text_config, "head_dim"):
if self.text_config.head_dim != correct_head_dim:
logger.warning(
"Correcting head_dim in text_config: %d -> %d",
self.text_config.head_dim, correct_head_dim
)
self.text_config.head_dim = correct_head_dim
else:
self.text_config.head_dim = correct_head_dim
# Also set in self.config (which is passed to AutoModel.from_config)
if hasattr(self.config, "head_dim"):
if self.config.head_dim != correct_head_dim:
logger.warning(
"Correcting head_dim in config: %d -> %d",
self.config.head_dim, correct_head_dim
)
self.config.head_dim = correct_head_dim
else:
self.config.head_dim = correct_head_dim
# Some models also need _attn_implementation in config
self.config._attn_implementation = "vllm"
with init_on_device_without_buffers("meta"):
self.model: "PreTrainedModel" = AutoModel.from_config(
self.config,
torch_dtype=self.model_config.dtype,
trust_remote_code=self.model_config.trust_remote_code,
)
self.model.eval()
for param in self.model.parameters():
param.requires_grad = False
def _apply_pipeline_parallel(self) -> None:
"""
Apply pipeline parallelization plan.
For models that don't explicitly declare a pp_plan, we take a best-effort
approach and split the layers evenly based on num_hidden_layers.
"""
if self.pp_group.world_size <= 1:
return
logger.info("Applying pipeline parallel (world_size=%d, rank=%d)",
self.pp_group.world_size, self.pp_group.rank_in_group)
num_layers = getattr(self.text_config, "num_hidden_layers",
getattr(self.text_config, "num_layers", 32))
start_layer, end_layer = get_pp_indices(
num_layers,
self.pp_group.rank_in_group,
self.pp_group.world_size,
)
# Find and process layer modules
layers_module = self._find_layers_module()
if layers_module is not None:
layers = list(layers_module.children())
for i, layer in enumerate(layers):
if not (start_layer <= i < end_layer):
# Replace layers not on this rank with PPMissingLayer
setattr(layers_module, str(i), PPMissingLayer())
# Embeddings live only on the first PP rank; later ranks receive
# hidden_states via intermediate tensors instead of input embeddings
self._has_embeddings = self.pp_group.is_first_rank
# Handle final norm and lm_head (only on last rank)
if not self.pp_group.is_last_rank:
# Mark lm_head as missing
if hasattr(self.model, 'lm_head'):
self.model.lm_head = PPMissingLayer()
logger.info("Pipeline parallel applied: layers %d-%d on this rank",
start_layer, end_layer)
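get_pp_indices is vLLM's; a minimal sketch of the assumed default behavior (contiguous even split, last rank absorbing any remainder; the real function also supports explicit partition overrides) looks like:

```python
def pp_indices(num_layers: int, rank: int, world_size: int):
    """Contiguous layer range [start, end) for one pipeline rank.
    Sketch of the assumed default split only, not vLLM's implementation."""
    per_rank = num_layers // world_size
    start = rank * per_rank
    # Last rank takes whatever is left when num_layers % world_size != 0
    end = num_layers if rank == world_size - 1 else start + per_rank
    return start, end
```

Every layer index outside a rank's [start, end) range is then replaced with PPMissingLayer as above.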
def _find_layers_module(self) -> Optional[nn.Module]:
"""Find the ModuleList containing transformer layers."""
# Common layer container names; nested paths such as encoder.layer are
# reached by the recursive backbone search below
layer_container_names = ('layers', 'h', 'blocks', 'layer')
def _search_layers(module: nn.Module, prefix: str = "") -> Optional[nn.Module]:
for name, child in module.named_children():
if name in layer_container_names and isinstance(child, nn.ModuleList):
return child
# Recursively search in model backbone
if name in ['model', 'transformer', 'encoder', 'decoder']:
result = _search_layers(child, f"{prefix}.{name}" if prefix else name)
if result is not None:
return result
return None
return _search_layers(self.model)
def _get_tp_plan(self) -> Dict[str, str]:
"""
Get tensor parallel plan for module replacement.
This maps module name patterns to parallelization styles:
- "colwise": Column parallel (split output dim)
- "rowwise": Row parallel (split input dim)
- "replicate": Replicated (no split)
Returns a dict mapping regex patterns to styles.
"""
# Check if model has explicit tp_plan
if hasattr(self.model, 'tp_plan') and self.model.tp_plan:
return {maybe_prefix("model", k): v for k, v in self.model.tp_plan.items()}
# Default tp_plan for common LLM architectures
# Based on typical transformer structure
return {
r".*\.q_proj$": "colwise",
r".*\.k_proj$": "colwise",
r".*\.v_proj$": "colwise",
r".*\.o_proj$": "rowwise",
r".*\.gate_proj$": "colwise",
r".*\.up_proj$": "colwise",
r".*\.down_proj$": "rowwise",
r".*\.query$": "colwise",
r".*\.key$": "colwise",
r".*\.value$": "colwise",
r".*\.dense$": "rowwise",
r".*\.fc1$": "colwise",
r".*\.fc2$": "rowwise",
}
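_replace_modules resolves a module's parallel style by taking the first tp_plan pattern that matches its qualified name; the lookup can be sketched as:

```python
import re

# Subset of the default plan above; the first matching pattern wins, and
# anything unmatched stays replicated.
TP_PLAN = {
    r".*\.q_proj$": "colwise",
    r".*\.o_proj$": "rowwise",
    r".*\.down_proj$": "rowwise",
}

def style_for(qual_name: str, tp_plan: dict = TP_PLAN) -> str:
    """Return the parallel style for a fully qualified module name."""
    for pattern, style in tp_plan.items():
        if re.match(pattern, qual_name):
            return style
    return "replicate"
```

The `$` anchor keeps `q_proj` from matching names that merely contain the substring, and `re.match` anchors at the start, which is why the patterns begin with `.*`.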
def _replace_modules(self) -> None:
"""
Replace modules with vLLM optimized versions.
Uses tp_plan for tensor parallel style selection.
Note: lm_head is NOT replaced here - it's created at wrapper level by CausalMixin.
"""
logger.info("Replacing modules with vLLM optimized versions...")
replaced_count = 0
# Get tensor parallel plan
tp_plan = self._get_tp_plan() if self.tp_group.world_size > 1 else {}
# Modules to skip replacement (handled at wrapper level)
skip_modules = {"lm_head", "score", "classifier"}
def _recursive_replace(module: nn.Module, prefix: str = ""):
nonlocal replaced_count
for name, child in list(module.named_children()):
# Skip PPMissingLayer
if isinstance(child, PPMissingLayer):
continue
# Skip modules that are handled at wrapper level
if name in skip_modules:
logger.debug("Skipping %s (handled at wrapper level)", name)
continue
qual_name = maybe_prefix(prefix, name)
new_module = None
if isinstance(child, nn.Linear):
# Determine parallelization style from tp_plan
style = "replicate"
for pattern, plan_style in tp_plan.items():
if re.match(pattern, qual_name):
style = plan_style
break
new_module = replace_linear_class(
child,
style=style,
quant_config=self.quant_config,
prefix=qual_name,
)
replaced_count += 1
elif child.__class__.__name__.endswith("RMSNorm") and \
not isinstance(child, RMSNorm):
new_module = replace_rms_norm_class(child, self.hidden_size)
replaced_count += 1
if new_module is not None:
setattr(module, name, new_module)
log_replacement(qual_name, child, new_module)
else:
_recursive_replace(child, qual_name)
_recursive_replace(self.model, "model")
logger.info("Replaced %d modules", replaced_count)
def _add_attention_debug_hook(self) -> None:
"""No-op. Debug hooks removed after root cause identified."""
pass
def _fix_attention_head_dim(self) -> None:
"""
Fix head_dim in attention modules and rotary embeddings after model creation.
Some models may have incorrect head_dim in config, which causes
Transformers attention modules and RoPE to use wrong dimensions.
This method corrects head_dim in all attention modules and recreates
rotary embeddings if needed.
"""
correct_head_dim = self.hidden_size // getattr(
self.text_config, "num_attention_heads", 32
)
fixed_count = 0
for name, module in self.model.named_modules():
module_name = module.__class__.__name__
# Fix head_dim in Attention modules
if "Attention" in module_name:
if hasattr(module, "head_dim"):
if module.head_dim != correct_head_dim:
logger.warning(
"Fixing head_dim in %s: %d -> %d",
name, module.head_dim, correct_head_dim
)
module.head_dim = correct_head_dim
fixed_count += 1
# Fix rotary embeddings - recreate inv_freq buffer if needed
if "RotaryEmbedding" in module_name:
if hasattr(module, "inv_freq"):
current_dim = module.inv_freq.shape[0] * 2
if current_dim != correct_head_dim:
logger.warning(
"Recreating rotary embedding %s: dim %d -> %d",
name, current_dim, correct_head_dim
)
base = getattr(module.config, 'rope_theta', 10000.0)
if hasattr(module.config, 'rope_parameters'):
base = module.config.rope_parameters.get('rope_theta', base)
device = module.inv_freq.device
inv_freq = 1.0 / (
base ** (
torch.arange(0, correct_head_dim, 2, dtype=torch.int64)
.to(device=device, dtype=torch.float) / correct_head_dim
)
)
module.register_buffer("inv_freq", inv_freq, persistent=False)
if hasattr(module, "original_inv_freq"):
module.register_buffer("original_inv_freq", inv_freq.clone(), persistent=False)
if fixed_count > 0:
logger.info("Fixed head_dim in %d attention modules", fixed_count)
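The recreated buffer follows the standard RoPE inverse-frequency formula, 1 / base^(2i/d) for even i up to the head dimension; in pure Python:

```python
def rope_inv_freq(head_dim: int, base: float = 10000.0) -> list:
    """Rotary-embedding inverse frequencies: one value per pair of
    dimensions, decreasing from 1.0 down to roughly 1/base. Mirrors the
    inv_freq buffer recreated above (plain floats instead of a tensor)."""
    return [1.0 / (base ** (i / head_dim)) for i in range(0, head_dim, 2)]
```

The length of the list is head_dim / 2, which is why the check above recovers the dimension as `inv_freq.shape[0] * 2`.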
def _replace_input_embeddings(self) -> None:
"""Replace input embeddings with VocabParallelEmbedding."""
# Default so embed_input_ids works even when we return early below
self.embed_scale = None
input_embeddings = self.model.get_input_embeddings()
if input_embeddings is None or isinstance(input_embeddings, PPMissingLayer):
return
if hasattr(input_embeddings, "embedding_dim"):
embedding_dim = input_embeddings.embedding_dim
elif hasattr(input_embeddings, "weight"):
embedding_dim = input_embeddings.weight.shape[1]
else:
embedding_dim = self.hidden_size
self.embed_scale = getattr(input_embeddings, "embed_scale", None)
logger.info("Replacing input embeddings (vocab=%d, dim=%d)",
self.vocab_size, embedding_dim)
new_embeddings = VocabParallelEmbedding(
self.vocab_size,
embedding_dim,
org_num_embeddings=self.vocab_size,
quant_config=self.quant_config,
)
self.model.set_input_embeddings(new_embeddings)
def _create_attention_instances(self) -> Dict[int, Attention]:
"""Create Attention instances for KV cache allocation."""
num_layers = getattr(self.text_config, "num_hidden_layers",
getattr(self.text_config, "num_layers", 32))
num_heads = getattr(self.text_config, "num_attention_heads", 32)
head_size = self.hidden_size // num_heads
num_kv_heads = getattr(self.text_config, "num_key_value_heads", num_heads)
# Get PP layer range
pp_rank = self.pp_group.rank_in_group
pp_size = self.pp_group.world_size
start_layer, end_layer = get_pp_indices(num_layers, pp_rank, pp_size)
logger.info("Creating attention instances for layers %d-%d "
"(heads=%d, head_size=%d, kv_heads=%d)",
start_layer, end_layer, num_heads, head_size, num_kv_heads)
attention_instances: Dict[int, Attention] = {}
for layer_idx in range(start_layer, end_layer):
per_layer_sliding_window = None
if hasattr(self.config, "layer_types"):
layer_types = self.config.layer_types
if layer_idx < len(layer_types) and layer_types[layer_idx] == "sliding_attention":
per_layer_sliding_window = getattr(self.config, "sliding_window", None)
attention = Attention(
num_heads=num_heads,
head_size=head_size,
scale=1.0 / (head_size ** 0.5),
num_kv_heads=num_kv_heads,
cache_config=self.cache_config,
quant_config=self.quant_config,
# Wire the per-layer window computed above into the backend
# (assumes this Attention version accepts per_layer_sliding_window)
per_layer_sliding_window=per_layer_sliding_window,
prefix=f"model.layers.{layer_idx}.self_attn",
)
attention_instances[layer_idx] = attention
return attention_instances
def _init_parameters(self) -> None:
"""Initialize parameters from meta device to target device."""
device = self.device_config.device
if device is None:
device = torch.device("cpu")
dtype = self.model_config.dtype
def _init_params(module: nn.Module):
if isinstance(module, PPMissingLayer):
return
for name, param in list(module.named_parameters(recurse=False)):
if param.device == torch.device("meta"):
new_param = nn.Parameter(
torch.empty_like(param.data, dtype=dtype, device=device),
requires_grad=False,
)
setattr(module, name, new_param)
for child in module.children():
_init_params(child)
_init_params(self.model)
logger.info("Parameters initialized on %s", device)
def embed_input_ids(self, input_ids: torch.Tensor) -> torch.Tensor:
"""Get embeddings for input IDs."""
inputs_embeds = self.model.get_input_embeddings()(input_ids)
if self.embed_scale is not None:
inputs_embeds = inputs_embeds * self.embed_scale
return inputs_embeds
def forward(
self,
input_ids: torch.Tensor,
positions: torch.Tensor,
kv_caches: List[torch.Tensor],
attn_metadata: "AttentionMetadata",
intermediate_tensors: Optional[IntermediateTensors] = None,
inputs_embeds: Optional[torch.Tensor] = None,
**kwargs,
) -> torch.Tensor:
"""Forward pass with pipeline parallel support."""
# Handle intermediate tensors for PP
if not self.pp_group.is_first_rank:
assert intermediate_tensors is not None
input_ids = None
inputs_embeds = intermediate_tensors["hidden_states"]
set_attention_context(attn_metadata, kv_caches)
try:
# Prepare inputs
if inputs_embeds is not None:
if inputs_embeds.dim() == 2:
inputs_embeds = inputs_embeds.unsqueeze(0)
model_inputs = {"inputs_embeds": inputs_embeds}
else:
if input_ids is not None and input_ids.dim() == 1:
input_ids = input_ids.unsqueeze(0)
model_inputs = {"input_ids": input_ids}
if positions is not None:
if positions.dim() == 1:
positions = positions.unsqueeze(0)
model_inputs["position_ids"] = positions
# Apply embed_scale if needed
if (
self.embed_scale is not None
and input_ids is not None
and inputs_embeds is None
):
inputs_embeds = self.embed_input_ids(model_inputs["input_ids"])
model_inputs = {"inputs_embeds": inputs_embeds}
if positions is not None:
model_inputs["position_ids"] = positions
# Forward through model
# Note: return_dict=False returns tuple, first element is last hidden state
with torch.no_grad():
outputs = self.model(
**model_inputs,
use_cache=False,
return_dict=False,
attention_instances=self.attention_instances,
)
# Get hidden states from model output
# For models using return_dict=False, outputs is a tuple
# outputs[0] is usually the last hidden state
if isinstance(outputs, tuple):
hidden_states = outputs[0]
else:
hidden_states = outputs
# Remove batch dimension
if hidden_states.dim() == 3 and hidden_states.size(0) == 1:
hidden_states = hidden_states.squeeze(0)
# Return intermediate tensors for PP
if not self.pp_group.is_last_rank:
return IntermediateTensors({"hidden_states": hidden_states})
return hidden_states
finally:
clear_attention_context()
def load_weights(
self,
weights: Iterable[Tuple[str, torch.Tensor]],
) -> Set[str]:
"""Load weights using AutoWeightsLoader with name mapping."""
loader = AutoWeightsLoader(
self,
skip_prefixes=self.skip_prefixes,
ignore_unexpected_prefixes=self.ignore_unexpected_prefixes,
)
loaded = loader.load_weights(weights, mapper=self.hf_to_vllm_mapper)
logger.info("Loaded %d weight tensors", len(loaded))
return set(loaded)

# SPDX-License-Identifier: Apache-2.0
# Copyright 2024 The vLLM team.
"""Transformers modeling backend mixin for causal language models.
This module provides CausalMixin that adds causal language model specific
functionality (lm_head, compute_logits, sample) to the Base class.
Following latest vLLM architecture:
- TransformersForCausalLM = CausalMixin + Base
- lm_head is created at the wrapper level (not inside self.model)
"""
from typing import TYPE_CHECKING, Optional
import torch
from vllm.logger import init_logger
from vllm.model_executor.layers.logits_processor import LogitsProcessor
from vllm.model_executor.layers.sampler import Sampler, SamplerOutput
from vllm.model_executor.layers.vocab_parallel_embedding import ParallelLMHead
from vllm.model_executor.models.utils import PPMissingLayer, maybe_prefix
from vllm.model_executor.sampling_metadata import SamplingMetadata
if TYPE_CHECKING:
from vllm.config import VllmConfig
logger = init_logger(__name__)
class CausalMixin:
"""
Mixin class that adds causal language model functionality.
This mixin provides:
- ParallelLMHead for language model head (created at wrapper level)
- LogitsProcessor for logits computation
- Sampler for token sampling
- compute_logits method for VllmModelForTextGeneration protocol
- sample method for VllmModelForTextGeneration protocol
Following latest vLLM architecture:
- lm_head is a direct attribute of TransformersForCausalLM (not inside self.model)
- hf_to_vllm_mapper maps "model.lm_head." -> "lm_head." to handle this
- For tied embeddings, lm_head weight loading is skipped and weights are tied
Should be used with Base class:
class TransformersForCausalLM(CausalMixin, Base): ...
"""
def __init__(self, *, vllm_config: "VllmConfig", prefix: str = "") -> None:
# Call next class in MRO (should be Base)
super().__init__(vllm_config=vllm_config, prefix=prefix)
# Handle tied word embeddings - skip loading lm_head weights
tie_word_embeddings = getattr(self.text_config, "tie_word_embeddings", False)
if tie_word_embeddings:
self.skip_prefixes.append("lm_head.")
logger.info("Model has tied word embeddings, will tie lm_head weights")
# Create lm_head at wrapper level (following latest vLLM architecture)
# This is outside self.model, so weights map "model.lm_head." -> "lm_head."
if self.pp_group.is_last_rank:
self.lm_head = ParallelLMHead(
self.vocab_size,
self.hidden_size,
quant_config=self.quant_config,
prefix=maybe_prefix(prefix, "lm_head"),
)
# Tie weights if needed
if tie_word_embeddings:
input_embeddings = self.model.get_input_embeddings()
if input_embeddings is not None:
self.lm_head = self.lm_head.tie_weights(input_embeddings)
logger.info("Tied lm_head weights with input embeddings")
# Setup logits processor
logit_scale = getattr(self.text_config, "logit_scale", 1.0)
self.logits_processor = LogitsProcessor(
self.vocab_size,
logits_as_input=False,
scale=logit_scale,
)
logger.info("CausalMixin initialized (vocab_size=%d, hidden_size=%d, logit_scale=%s)",
self.vocab_size, self.hidden_size, logit_scale)
else:
# For non-last PP ranks, use PPMissingLayer
self.lm_head = PPMissingLayer()
self.logits_processor = None
logger.info("CausalMixin initialized (PP non-last rank, using PPMissingLayer)")
# Setup sampler
self.sampler = Sampler()
def compute_logits(
self,
hidden_states: torch.Tensor,
sampling_metadata: SamplingMetadata,
) -> Optional[torch.Tensor]:
"""
Compute logits from hidden states.
This method conforms to the VllmModelForTextGeneration protocol.
Args:
hidden_states: Hidden states from the model [seq_len, hidden_size]
sampling_metadata: Sampling metadata
Returns:
Logits tensor or None
"""
if self.logits_processor is None:
# Non-last PP rank
return None
# In v0.6.2, LogitsProcessor handles the lm_head projection internally
# via lm_head.linear_method.apply(). Pass lm_head as the first arg.
logits = self.logits_processor(self.lm_head, hidden_states,
sampling_metadata)
return logits
def sample(
self,
logits: torch.Tensor,
sampling_metadata: SamplingMetadata,
) -> Optional[SamplerOutput]:
"""
Sample tokens from logits.
This method conforms to the VllmModelForTextGeneration protocol.
Args:
logits: Logits tensor
sampling_metadata: Sampling metadata
Returns:
SamplerOutput with sampled tokens
"""
next_tokens = self.sampler(logits, sampling_metadata)
return next_tokens

# SPDX-License-Identifier: Apache-2.0
# Copyright 2024 The vLLM team.
"""Transformers modeling backend mixin for legacy models.
This module provides LegacyMixin for BERT-like encoder models that have
different weight naming conventions and special position handling.
Following latest vLLM architecture patterns adapted for v0.6.2.
"""
from typing import TYPE_CHECKING, List, Optional
import torch
from vllm.logger import init_logger
from vllm.model_executor.models.utils import WeightsMapper
from vllm.sequence import IntermediateTensors
if TYPE_CHECKING:
from vllm.config import VllmConfig
logger = init_logger(__name__)
class LegacyMixin:
"""
Mixin class for legacy/encoder models like BERT, RoBERTa.
This mixin provides:
- Weight name mapping for legacy suffix conventions (.gamma/.beta)
- Prefix mapping for BERT-like model structures
- RoBERTa-specific position handling
- Skip prefixes for unsupported output layers
Should be used with Base class:
class TransformersForLegacy(LegacyMixin, Base): ...
"""
# Weight name mapping for legacy models
hf_to_vllm_mapper = WeightsMapper(
# These are applied in order, so the order matters!
orig_to_new_prefix={
# Handle BERT-like models
"roberta": "model",
"bert": "model",
},
orig_to_new_suffix={
# Replace legacy suffixes used for norms
".gamma": ".weight",
".beta": ".bias",
},
)
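Combined, the two rule sets rewrite a legacy checkpoint name in two steps, prefix first and then suffix; a hypothetical standalone sketch (the real rewriting is done by vLLM's WeightsMapper):

```python
def map_legacy_name(name: str) -> str:
    """Sketch of the mapping above: backbone prefix -> "model", then
    legacy LayerNorm suffixes .gamma/.beta -> .weight/.bias."""
    for old, new in (("roberta", "model"), ("bert", "model")):
        if name.startswith(old):
            name = new + name[len(old):]
            break
    for old, new in ((".gamma", ".weight"), (".beta", ".bias")):
        if name.endswith(old):
            name = name[:-len(old)] + new
            break
    return name
```

Names that match neither rule set pass through unchanged.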
def __init__(self, *, vllm_config: "VllmConfig", prefix: str = "") -> None:
# Call next class in MRO (should be Base)
super().__init__(vllm_config=vllm_config, prefix=prefix)
# Skip unsupported/unwanted output embeddings layers
self.skip_prefixes.extend([
"model.lm_head.",
"model.predictions.",
"model.qa_outputs.",
"model.embeddings_project.",
"model.discriminator_predictions.",
])
# v0.6.2 doesn't have skip_substrs, so we handle it differently
# Store patterns to skip during weight loading
self._legacy_skip_patterns: List[str] = [
"position_ids", # Some encoder models have position_ids buffer
"score.bias", # Final classifier bias not used by vLLM
]
# RoBERTa-like models have extra padding in positions
model_type = getattr(self.text_config, "model_type", "").lower()
self.is_roberta = "roberta" in model_type
self.padding_idx = getattr(self.text_config, "pad_token_id", 1)
if self.is_roberta:
logger.info("LegacyMixin detected RoBERTa model, enabling position padding")
logger.info("LegacyMixin initialized for legacy/encoder model")
def _should_skip_weight(self, name: str) -> bool:
"""Check if a weight should be skipped during loading."""
for pattern in self._legacy_skip_patterns:
if pattern in name:
return True
return False
def forward(
self,
input_ids: Optional[torch.Tensor],
positions: torch.Tensor,
kv_caches: List[torch.Tensor],
attn_metadata,
intermediate_tensors: Optional[IntermediateTensors] = None,
inputs_embeds: Optional[torch.Tensor] = None,
**kwargs,
) -> torch.Tensor:
"""
Forward pass with RoBERTa position handling.
RoBERTa models require positions to be offset by padding_idx + 1.
"""
if self.is_roberta and positions is not None:
# RoBERTa-specific positions padding
positions = positions + self.padding_idx + 1
return super().forward(
input_ids=input_ids,
positions=positions,
kv_caches=kv_caches,
attn_metadata=attn_metadata,
intermediate_tensors=intermediate_tensors,
inputs_embeds=inputs_embeds,
**kwargs,
)

# SPDX-License-Identifier: Apache-2.0
# Copyright 2024 The vLLM team.
"""Transformers modeling backend mixins for pooling/embedding models.
This module provides mixins for embedding and sequence classification models:
- EmbeddingMixin: For embedding/sentence similarity models
- SequenceClassificationMixin: For sequence classification/cross-encoding
Following latest vLLM architecture patterns adapted for v0.6.2.
"""
from typing import TYPE_CHECKING, List, Optional
import torch
import torch.nn as nn
from vllm.logger import init_logger
from vllm.model_executor.layers.pooler import Pooler, PoolingType
from vllm.model_executor.pooling_metadata import PoolingMetadata
from vllm.sequence import PoolerOutput
if TYPE_CHECKING:
from vllm.config import VllmConfig
logger = init_logger(__name__)
class EmbeddingMixin:
"""
Mixin class that adds embedding/pooling functionality.
This mixin provides:
- Pooler layer for extracting embeddings
- pooling method for VllmModelForPooling protocol
Should be used with Base class:
class TransformersForEmbedding(EmbeddingMixin, Base): ...
"""
# Default pooling configuration
default_pooling_type: PoolingType = PoolingType.CLS
default_normalize: bool = True
default_softmax: bool = False
def __init__(self, *, vllm_config: "VllmConfig", prefix: str = "") -> None:
# Call next class in MRO (should be Base)
super().__init__(vllm_config=vllm_config, prefix=prefix)
# Get pooler config from model config
pooler_config = vllm_config.model_config.pooler_config
# Setup pooler
self.pooler = Pooler.from_config_with_defaults(
pooler_config=pooler_config,
pooling_type=self.default_pooling_type,
normalize=self.default_normalize,
softmax=self.default_softmax,
)
if self.pooler is None:
# Create default pooler if config doesn't specify
self.pooler = Pooler(
pooling_type=self.default_pooling_type,
normalize=self.default_normalize,
softmax=self.default_softmax,
)
logger.info("EmbeddingMixin initialized (pooling_type=%s, normalize=%s)",
self.pooler.pooling_type.name, self.pooler.normalize)
def pooling(
self,
hidden_states: torch.Tensor,
pooling_metadata: PoolingMetadata,
) -> Optional[PoolerOutput]:
"""
Apply pooling to hidden states.
Args:
hidden_states: Hidden states from the model [seq_len, hidden_size]
pooling_metadata: Pooling metadata
Returns:
PoolerOutput with pooled embeddings
"""
return self.pooler(hidden_states, pooling_metadata)
class SequenceClassificationMixin(EmbeddingMixin):
"""
Mixin class that adds sequence classification functionality.
This mixin provides:
- Classifier layer for sequence classification
- pooling method with classification logits
Should be used with Base class:
class TransformersForSequenceClassification(SequenceClassificationMixin, Base): ...
"""
default_pooling_type: PoolingType = PoolingType.CLS
default_normalize: bool = False
default_softmax: bool = True
def __init__(self, *, vllm_config: "VllmConfig", prefix: str = "") -> None:
# Call EmbeddingMixin.__init__ -> Base.__init__
super().__init__(vllm_config=vllm_config, prefix=prefix)
# Find and setup classifier layer
self.classifier = self._find_classifier()
if self.classifier is not None:
# Initialize classifier parameters on device
self._init_classifier_params()
logger.info("SequenceClassificationMixin initialized with classifier")
else:
logger.warning("Could not find classifier layer")
def _find_classifier(self) -> Optional[nn.Module]:
"""Find the classifier layer in the model."""
# Common classifier layer names
classifier_names = ['classifier', 'score', 'fc', 'head']
for name in classifier_names:
if hasattr(self.model, name):
return getattr(self.model, name)
return None
def _init_classifier_params(self) -> None:
"""Initialize classifier parameters on target device."""
device = self.device_config.device
if device is None:
device = torch.device("cpu")
dtype = self.model_config.dtype
for name, param in list(self.classifier.named_parameters()):
if param.device == torch.device("meta"):
new_param = nn.Parameter(
torch.empty_like(param.data, dtype=dtype, device=device),
requires_grad=False,
)
# Resolve the owning submodule so nested names like "dense.weight"
# are set on the right module, not blindly on self.classifier
owner = self.classifier
*path, leaf = name.split('.')
for part in path:
owner = getattr(owner, part)
setattr(owner, leaf, new_param)
def pooling(
self,
hidden_states: torch.Tensor,
pooling_metadata: PoolingMetadata,
) -> Optional[PoolerOutput]:
"""
Apply pooling and classification to hidden states.
Args:
hidden_states: Hidden states from the model [seq_len, hidden_size]
pooling_metadata: Pooling metadata
Returns:
PoolerOutput with classification logits
"""
# First apply base pooling
pooled = self.pooler(hidden_states, pooling_metadata)
# Apply classifier if available
if self.classifier is not None and pooled is not None:
# Apply classifier to each pooled output
for i, output in enumerate(pooled.outputs):
if hasattr(output, 'data'):
output.data = self.classifier(output.data)
return pooled

# SPDX-License-Identifier: Apache-2.0
# Copyright 2024 The vLLM team.
"""Transformers modeling backend utilities for v0.6.2.
This module provides utility functions for the Transformers backend,
including context managers for meta device initialization and
module replacement functions.
"""
from contextlib import contextmanager
from typing import TYPE_CHECKING, Literal, Optional, Tuple, Union
import torch
import torch.nn as nn
from vllm.logger import init_logger
from vllm.model_executor.layers.layernorm import RMSNorm
from vllm.model_executor.layers.linear import (
ColumnParallelLinear,
ReplicatedLinear,
RowParallelLinear,
)
if TYPE_CHECKING:
from vllm.model_executor.layers.quantization.base_config import (
QuantizationConfig,
)
logger = init_logger(__name__)
@contextmanager
def init_on_device_without_buffers(device: Union[str, torch.device]):
"""
A context manager under which models are initialized with all
parameters on the specified device. Buffers, however, are not
initialized on the specified device.
This is useful for creating model structure without allocating
GPU memory, which is essential for memory efficiency.
Args:
device: Device to initialize all parameters on (e.g., "meta").
Example:
with init_on_device_without_buffers("meta"):
model = AutoModel.from_config(config)
# Now model is on meta device, no GPU memory allocated
"""
if isinstance(device, str):
device = torch.device(device)
old_register_parameter = nn.Module.register_parameter
def register_empty_parameter(module, name, param):
old_register_parameter(module, name, param)
if param is not None:
param_cls = type(module._parameters[name])
kwargs = module._parameters[name].__dict__
kwargs["requires_grad"] = param.requires_grad
module._parameters[name] = param_cls(
module._parameters[name].to(device), **kwargs
)
try:
nn.Module.register_parameter = register_empty_parameter
yield
finally:
nn.Module.register_parameter = old_register_parameter
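The patch-and-restore shape used here (swap nn.Module.register_parameter, always restore in `finally`) is a generic pattern; a minimal standalone version with a demo class:

```python
from contextlib import contextmanager

@contextmanager
def patched(owner, attr, replacement):
    """Temporarily replace an attribute on `owner`, restoring it in
    `finally` even if the body raises (same shape as the
    register_parameter swap above)."""
    original = getattr(owner, attr)
    setattr(owner, attr, replacement)
    try:
        yield
    finally:
        setattr(owner, attr, original)

class Demo:
    def greet(self):
        return "orig"
```

Inside the `with` block every instance sees the replacement; afterwards the original method is back, even on exceptions.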
# Linear replacement styles
Style = Literal["colwise", "colwise_rep", "rowwise", "rowwise_rep", "replicate"]
def replace_linear_class(
linear: nn.Linear,
style: Style = "replicate",
quant_config: Optional["QuantizationConfig"] = None,
prefix: str = "",
) -> Union[ColumnParallelLinear, RowParallelLinear, ReplicatedLinear]:
"""
Replace nn.Linear with one of vLLM's tensor parallel linear classes.
This replacement provides:
- Memory efficiency through proper tensor allocation
- Support for quantization
- Tensor parallel support (when using ColumnParallel/RowParallel)
Args:
linear: `nn.Linear` to be replaced.
style: Tensor parallel style of the new linear:
- "colwise": Column parallel (split output dim)
- "colwise_rep": Column parallel with gather output
- "rowwise": Row parallel (split input dim)
- "rowwise_rep": Row parallel without parallel input
- "replicate": Replicated (no parallelism)
quant_config: Quantization config for the new linear.
prefix: The name of the layer for weight loading.
Returns:
The new vLLM linear layer.
"""
if not isinstance(style, str):
raise ValueError(f"Unsupported parallel style type {type(style)}, expected str")
vllm_linear_cls, vllm_linear_kwargs = {
"colwise": (ColumnParallelLinear, {}),
"colwise_rep": (ColumnParallelLinear, {"gather_output": True}),
"rowwise": (RowParallelLinear, {}),
"rowwise_rep": (RowParallelLinear, {"input_is_parallel": False}),
"replicate": (ReplicatedLinear, {}),
}.get(style, (ReplicatedLinear, {}))
return vllm_linear_cls(
input_size=linear.in_features,
output_size=linear.out_features,
bias=linear.bias is not None,
quant_config=quant_config,
prefix=prefix,
return_bias=False, # Return tensor only, not (tensor, bias) tuple
**vllm_linear_kwargs,
)
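The style dispatch above is a plain dict lookup with a replicated fallback; a minimal standalone sketch (the class names here are stand-ins, not vLLM's):

```python
# Each style maps to (class, extra kwargs); unknown styles fall back to the
# replicated (non-parallel) entry, mirroring replace_linear_class.
class ColumnParallel: ...
class RowParallel: ...
class Replicated: ...

STYLES = {
    "colwise": (ColumnParallel, {}),
    "colwise_rep": (ColumnParallel, {"gather_output": True}),
    "rowwise": (RowParallel, {}),
    "rowwise_rep": (RowParallel, {"input_is_parallel": False}),
    "replicate": (Replicated, {}),
}

def pick(style):
    return STYLES.get(style, (Replicated, {}))

cls, kwargs = pick("colwise_rep")
fallback_cls, _ = pick("unknown-style")
```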
class TransformersRMSNorm(RMSNorm):
"""
vLLM RMSNorm subclass that preserves tensor dimensions.
vLLM's RMSNorm (especially the MLU backend) flattens input to 2D
(e.g., [batch, seq, hidden] -> [batch*seq, hidden]), but transformers
expects the batch dimension to be preserved. This subclass wraps
the parent forward methods to save and restore the original tensor shape.
Since this inherits from RMSNorm directly, weight loading via
named_parameters() works correctly (weight path stays the same).
"""
def forward_native(
self,
x: torch.Tensor,
residual: Optional[torch.Tensor] = None,
):
orig_shape = x.shape
result = super().forward_native(x, residual)
return self._restore_shape(result, orig_shape)
def forward_cuda(
self,
x: torch.Tensor,
residual: Optional[torch.Tensor] = None,
):
orig_shape = x.shape
result = super().forward_cuda(x, residual)
return self._restore_shape(result, orig_shape)
def forward_mlu(
self,
x: torch.Tensor,
residual: Optional[torch.Tensor] = None,
):
orig_shape = x.shape
result = super().forward_mlu(x, residual)
return self._restore_shape(result, orig_shape)
def forward_xpu(
self,
x: torch.Tensor,
residual: Optional[torch.Tensor] = None,
):
orig_shape = x.shape
result = super().forward_xpu(x, residual)
return self._restore_shape(result, orig_shape)
def forward_hpu(
self,
x: torch.Tensor,
residual: Optional[torch.Tensor] = None,
):
orig_shape = x.shape
result = super().forward_hpu(x, residual)
return self._restore_shape(result, orig_shape)
@staticmethod
def _restore_shape(result, orig_shape: Tuple):
"""Restore original tensor shape if it was changed."""
if isinstance(result, tuple):
restored = []
for t in result:
if t is not None and t.shape != orig_shape:
t = t.view(orig_shape)
restored.append(t)
return tuple(restored)
else:
if result.shape != orig_shape:
result = result.view(orig_shape)
return result
def replace_rms_norm_class(
rms_norm: nn.Module,
hidden_size: int,
) -> nn.Module:
"""
Replace a Transformers RMSNorm with vLLM's optimized RMSNorm,
wrapped to preserve tensor dimensions.
vLLM's RMSNorm provides:
- Fused CUDA kernels for better performance
- Support for fused add + norm operations
The wrapper ensures that the original tensor shape (including batch
dimension) is preserved, which is required by transformers' model
forward methods.
Args:
rms_norm: The RMSNorm module to replace.
hidden_size: The hidden size of the model.
Returns:
The new vLLM RMSNorm layer wrapped for shape preservation.
"""
# Try to get epsilon from various attribute names
eps = getattr(rms_norm, "eps", None)
if eps is None:
eps = getattr(rms_norm, "variance_epsilon", None)
if eps is None:
eps = 1e-6
# Check if weight exists and get its size
weight = getattr(rms_norm, "weight", None)
if weight is not None:
hidden_size = weight.size(0)
return TransformersRMSNorm(hidden_size=hidden_size, eps=eps)
def log_replacement(name: str, old_module: nn.Module, new_module: nn.Module):
"""Log module replacement for debugging."""
logger.debug("Replaced %s: %s -> %s", name, type(old_module).__name__, type(new_module).__name__)
def maybe_prefix(prefix: str, name: str) -> str:
"""Combine prefix and name with a dot separator."""
if prefix:
return f"{prefix}.{name}"
return name

View File

@@ -492,6 +492,29 @@ def maybe_offload_to_cpu(module: torch.nn.Module) -> torch.nn.Module:
return module
def _get_device_memory_info() -> Tuple[Optional[float], Optional[float], Optional[float]]:
"""Get device memory info in GiB. Returns (allocated, reserved, total) or Nones."""
try:
import torch.mlu
allocated = torch.mlu.memory_allocated() / (1024 ** 3)
reserved = torch.mlu.memory_reserved() / (1024 ** 3)
free, total = torch.mlu.mem_get_info()
total = total / (1024 ** 3)
return allocated, reserved, total
except Exception:
pass
try:
if torch.cuda.is_available():
allocated = torch.cuda.memory_allocated() / (1024 ** 3)
reserved = torch.cuda.memory_reserved() / (1024 ** 3)
free, total = torch.cuda.mem_get_info()
total = total / (1024 ** 3)
return allocated, reserved, total
except Exception:
pass
return None, None, None
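The try-MLU, then-CUDA, then-give-up chain above is a generic first-available-probe pattern; a sketch in plain Python (`mlu_probe`/`cuda_probe` are hypothetical stand-ins for the backend queries):

```python
def first_available(probes):
    # Try each probe in order; swallow failures and move on, mirroring
    # _get_device_memory_info's MLU-then-CUDA-then-None fallback chain.
    for probe in probes:
        try:
            return probe()
        except Exception:
            continue
    return None

def mlu_probe():
    raise RuntimeError("no MLU runtime")  # hypothetical: backend absent

def cuda_probe():
    return (1.5, 2.0, 16.0)  # hypothetical (allocated, reserved, total) GiB

result = first_available([mlu_probe, cuda_probe])
```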
def make_layers(
num_hidden_layers: int,
layer_fn: LayerFn,
@@ -505,11 +528,31 @@ def make_layers(
start_layer, end_layer = get_pp_indices(num_hidden_layers,
get_pp_group().rank_in_group,
get_pp_group().world_size)
alloc_before, _, total = _get_device_memory_info()
if alloc_before is not None:
logger.info(
"[DEBUG-MEM] make_layers start: allocated=%.2f GiB, "
"total=%.2f GiB, layers to create: %d-%d / %d",
alloc_before, total, start_layer, end_layer, num_hidden_layers)
created_layers = []
for idx in range(start_layer, end_layer):
layer = maybe_offload_to_cpu(layer_fn(prefix=f"{prefix}.{idx}"))
alloc_after, reserved, _ = _get_device_memory_info()
if alloc_after is not None:
delta = alloc_after - alloc_before
logger.info(
"[DEBUG-MEM] Layer %s.%d created: "
"allocated=%.2f GiB (+%.4f GiB), reserved=%.2f GiB",
prefix, idx, alloc_after, delta, reserved)
alloc_before = alloc_after
created_layers.append(layer)
modules = torch.nn.ModuleList(
[PPMissingLayer() for _ in range(start_layer)] + [
maybe_offload_to_cpu(layer_fn(prefix=f"{prefix}.{idx}"))
for idx in range(start_layer, end_layer)
] + [PPMissingLayer() for _ in range(end_layer, num_hidden_layers)])
[PPMissingLayer() for _ in range(start_layer)]
+ created_layers
+ [PPMissingLayer() for _ in range(end_layer, num_hidden_layers)])
return start_layer, end_layer, modules

View File

@@ -159,9 +159,11 @@ class MLUSpecDecodeWorker(LoraNotSupportedWorkerBase):
draft_worker_kwargs[
"model_runner_cls"] = MLUTP1DraftModelRunner
else:
if draft_model_config.hf_config.model_type == "eagle":
if draft_model_config.hf_config.model_type in (
"eagle", "deepseek_mtp"):
raise NotImplementedError(
"EAGLE does not support TP > 1 yet")
f"{draft_model_config.hf_config.model_type} "
"does not support TP > 1 yet")
allow_zero_draft_token_step = False
proposer_worker = MLUMultiStepWorker(**draft_worker_kwargs)

View File

@@ -13,7 +13,7 @@ from transformers import GenerationConfig, PretrainedConfig
from transformers.models.auto.image_processing_auto import (
get_image_processor_config)
from transformers.models.auto.modeling_auto import (
MODEL_FOR_CAUSAL_LM_MAPPING_NAMES)
MODEL_FOR_CAUSAL_LM_MAPPING_NAMES, MODEL_MAPPING_NAMES)
from transformers.utils import CONFIG_NAME as HF_CONFIG_NAME
from vllm.envs import VLLM_USE_MODELSCOPE
@@ -89,9 +89,10 @@ def file_or_path_exists(model: Union[str, Path], config_name, revision,
# hf_hub. This will fail in offline mode.
try:
return file_exists(model, config_name, revision=revision, token=token)
except huggingface_hub.errors.OfflineModeIsEnabled:
# Don't raise in offline mode, all we know is that we don't have this
# file cached.
except (huggingface_hub.errors.OfflineModeIsEnabled,
huggingface_hub.errors.HFValidationError):
# Don't raise in offline mode or when model path fails HF validation
# (e.g., local paths that don't match HF repo id format)
return False
@@ -112,7 +113,9 @@ def patch_rope_scaling_dict(rope_scaling: Dict[str, Any]) -> None:
logger.info("Replacing legacy 'type' key with 'rope_type'")
if "rope_type" not in rope_scaling:
raise ValueError("rope_scaling should have a 'rope_type' key")
rope_scaling["rope_type"] = "default"
logger.warning("rope_scaling missing 'rope_type' key, "
"defaulting to 'default'")
if rope_scaling["rope_type"] == "su":
rope_scaling["rope_type"] = "longrope"
@@ -167,12 +170,6 @@ def get_config(
token=token):
config_format = ConfigFormat.MISTRAL
else:
# If we're in offline mode and found no valid config format, then
# raise an offline mode error to indicate to the user that they
# don't have files cached and may need to go online.
# This is conveniently triggered by calling file_exists().
file_exists(model, HF_CONFIG_NAME, revision=revision, token=token)
raise ValueError(f"No supported config format found in {model}")
if config_format == ConfigFormat.HF:
@@ -232,6 +229,17 @@ def get_config(
model_type = MODEL_FOR_CAUSAL_LM_MAPPING_NAMES[config.model_type]
config.update({"architectures": [model_type]})
# Architecture mapping for models without explicit architectures field
if not getattr(config, "architectures", None):
if config.model_type not in MODEL_MAPPING_NAMES:
logger.warning(
"Model config does not have a top-level 'architectures' "
"field: expecting `hf_overrides={'architectures': "
"['...']}` to be passed in engine args.")
else:
model_type = MODEL_MAPPING_NAMES[config.model_type]
config.update({"architectures": [model_type]})
patch_rope_scaling(config)
if trust_remote_code:

View File

@@ -59,12 +59,20 @@ class MLUWorker(Worker):
# mlp_speculator
speculative_config = self.speculative_config
model_config = self.model_config
speculative_args = {} if speculative_config is None \
or (speculative_config.draft_model_config.model ==
model_config.model) \
or (speculative_config.draft_model_config.hf_config.model_type
not in ["medusa", "mlp_speculator", "eagle"]) \
else {"return_hidden_states": True}
is_mtp = (speculative_config is not None
and model_config.task != "draft"
and getattr(
speculative_config.draft_model_config.hf_config,
"model_type", None) == "deepseek_mtp")
speculative_args = (
{"return_hidden_states": True} if is_mtp else
({} if speculative_config is None
or (speculative_config.draft_model_config.model ==
model_config.model)
or (speculative_config.draft_model_config.hf_config.model_type
not in ["medusa", "mlp_speculator", "eagle"])
else {"return_hidden_states": True})
)
ModelRunnerClass: Type[MLUModelRunnerBase] = MLUModelRunner
if model_runner_cls is not None:

View File

@@ -580,34 +580,58 @@ def unified_flash_attention_v2(
value_cache,
updated_slot_mapping.flatten())
else:
# FIXME: After TMO-1496 is completed, remove this code.
if key.stride() != value.stride():
key = key.contiguous()
value = value.contiguous()
if kv_cache_dtype == 'int8':
mlu_ops.quant_to_linear_cache(key,
value,
key_cache,
value_cache,
key_cache_scale,
value_cache_scale,
attn_metadata.cu_seq_lens,
attn_metadata.max_seq_len,
True, # packed
None, # context_seq_offset
attn_metadata.batch_ids,
attn_metadata.slot_mapping_unpaged)
# unpaged (linear cache) path
if use_mla:
# MLA: mirror the handling used on the paged path
# key_cache: (num_blocks, 1, block_size, 576)
value_to_cache = None
if attn_metadata.prefill_metadata:
# The MLA prefill cache was already written in forward_prefill; skip it
pass
else:
if kv_cache_dtype == 'int8':
mlu_ops.quant_to_linear_cache(
key, value_to_cache,
key_cache, value_cache,
key_cache_scale, value_cache_scale,
attn_metadata.cu_seq_lens,
attn_metadata.max_seq_len,
True, None,
attn_metadata.batch_ids,
attn_metadata.slot_mapping_unpaged)
else:
mlu_ops.reshape_linear_cache(
key, value_to_cache,
key_cache, value_cache,
attn_metadata.cu_seq_lens,
attn_metadata.max_seq_len,
True, None,
attn_metadata.batch_ids,
attn_metadata.slot_mapping_unpaged)
else:
mlu_ops.reshape_linear_cache(key,
value,
key_cache,
value_cache,
attn_metadata.cu_seq_lens,
attn_metadata.max_seq_len,
True, # packed
None, # context_seq_offset
attn_metadata.batch_ids,
attn_metadata.slot_mapping_unpaged)
# FIXME: After TMO-1496 is completed, remove this code.
if key.stride() != value.stride():
key = key.contiguous()
value = value.contiguous()
if kv_cache_dtype == 'int8':
mlu_ops.quant_to_linear_cache(
key, value,
key_cache, value_cache,
key_cache_scale, value_cache_scale,
attn_metadata.cu_seq_lens,
attn_metadata.max_seq_len,
True, None,
attn_metadata.batch_ids,
attn_metadata.slot_mapping_unpaged)
else:
mlu_ops.reshape_linear_cache(
key, value,
key_cache, value_cache,
attn_metadata.cu_seq_lens,
attn_metadata.max_seq_len,
True, None,
attn_metadata.batch_ids,
attn_metadata.slot_mapping_unpaged)
if use_mla and attn_metadata.prefill_metadata:
output = torch.empty(query.shape[0], query.shape[1], v_head_size, dtype=query.dtype, device="mlu")
else:

View File

@@ -37,7 +37,7 @@ def vllm__config__CacheConfig___verify_cache_dtype(self) -> None:
def vllm__config__ModelConfig__get_num_kv_heads(self, parallel_config: "ParallelConfig") -> int:
"""Returns the number of KV heads per GPU."""
if hasattr(self.hf_text_config,"model_type") and self.hf_text_config.model_type == 'deepseek_v2':
if hasattr(self.hf_text_config,"model_type") and self.hf_text_config.model_type in ('deepseek_v2', 'deepseek_v3', 'deepseek_mtp'):
# feature flag MLA
return 1
total_num_kv_heads = self.get_total_num_kv_heads()
@@ -51,7 +51,7 @@ def vllm__config__ModelConfig__get_num_kv_heads(self, parallel_config: "Parallel
def vllm__config__ModelConfig__get_head_size(self) -> int:
# TODO remove hard code
if hasattr(self.hf_text_config, "model_type"
) and self.hf_text_config.model_type == 'deepseek_v2':
) and self.hf_text_config.model_type in ('deepseek_v2', 'deepseek_v3', 'deepseek_mtp'):
'''
=============================
Modify by vllm_mlu
@@ -109,7 +109,7 @@ def vllm__config__LoRAConfig__verify_with_model_config(self, model_config: Model
def vllm__config__ModelConfig__is_deepseek_v2(self) -> bool:
result = hasattr(
self.hf_text_config,
"model_type") and self.hf_text_config.model_type == 'deepseek_v2'
"model_type") and self.hf_text_config.model_type in ('deepseek_v2', 'deepseek_v3', 'deepseek_mtp')
return result
MluHijackObject.apply_hijack(ModelConfig,

View File

@@ -26,6 +26,12 @@ def vllm__module_executor__layers__linear__UnquantizedLinearMethod__apply(
beta = 1.0
residual = residual.view(-1, residual.shape[-1])
res_shape = x.shape[0:-1] + (layer.weight.shape[0], )
# MLU matmul requires all tensors to have matching dtypes
target_dtype = layer.weight.dtype
if x.dtype != target_dtype:
x = x.to(target_dtype)
if residual is not None and residual.dtype != target_dtype:
residual = residual.to(target_dtype)
return mlu_ops.matmul(x.view(-1, x.shape[-1]), layer.weight, bias, residual, 'none', 1.0, beta).view(res_shape)

View File

@@ -73,6 +73,24 @@ class SparseMoeMlp(nn.Module):
self.expert_group = expert_group
self.topk_group = topk_group
if get_device_major_capability() == 3:
# WARNING: MLU370 (capability=3) does not support the fused_moe op;
# force-disable it.
#
# Background: the original forward_experts_nofused contained torch.unique,
# torch.tensor([0], ...), data-dependent branches, and other operations
# incompatible with graph capture, so on MLU370 every MoE model going
# through SparseMoeMlp had to run with --enforce-eager.
# forward_experts_nofused has since been switched to a dense mode (every
# expert processes all tokens, masked by the routing weights), which fixes
# the graph-capture compatibility issue; all MoE models now run without
# --enforce-eager.
#
# Performance cost: the dense mode does O(num_experts * num_tokens) work,
# i.e. num_experts/topk times more than sparse routing's
# O(topk * num_tokens). Prefill becomes noticeably slower for models with
# many experts; the impact on decode (few tokens) is negligible.
# Known affected models: Mixtral (8), Qwen2-MoE (60), HunYuan (16),
# Llama4 (16), etc. DeepSeek V2/V3 are unaffected (they have a separate
# MLU MoE hijack implementation).
#
# TODO: MLU370 already has the complete MoE op chain (moe_gen_idx,
# moe_expand_input, group_gemm, moe_active, moe_combine_result), the same
# ops used by forward_group_experts. The is_use_fused_moe flag should be
# split so that MLU370 takes the forward_group_experts path and avoids the
# dense mode's overhead.
self.is_use_fused_moe = False
if params_dtype is None:
@@ -284,34 +302,28 @@ class SparseMoeMlp(nn.Module):
def forward_experts_nofused(self, hidden_states, expert_logits):
hidden_states_shape = hidden_states.shape
# Dense approach: each expert processes ALL tokens, then mask by routing
# weights. This avoids data-dependent control flow (variable-size slicing,
# conditional branches, torch.unique, torch.tensor creation) that is
# incompatible with MLU graph capture.
num_tokens, hidden_size = hidden_states.shape
topk_values, topk_indices = self.topk_softmax(expert_logits)
expand_gather_idx, scatter_idx, expand_token_count, cusum_token_count = self.generate_gather_idx(
topk_indices)
# If no expert is routed, expand_gather_idx and expand_scatter_idx are
# empty, while expand_token_count and expand_cusum_token_count contain
# items that are all zero, so this rank should just return a zero-valued
# final_hidden_states
if expand_gather_idx.numel() == 0:
final_hidden_states = torch.zeros_like(hidden_states,
dtype=hidden_states.dtype,
device=hidden_states.device)
return final_hidden_states
expand_hidden_states = self.expand_input(hidden_states, expand_gather_idx)
final_hidden_states = torch.zeros(
num_tokens, hidden_size,
dtype=hidden_states.dtype, device=hidden_states.device)
expand_output_list = []
expand_cusum_token_count = cusum_token_count[self.start_expert_id:self.end_expert_id +
1] - cusum_token_count[self.start_expert_id]
for expert_idx, num_tokens_per_expert in enumerate(expand_token_count):
if num_tokens_per_expert > 0:
expert_hidden_states = expand_hidden_states[
expand_cusum_token_count[expert_idx]:expand_cusum_token_count[expert_idx + 1]]
expert_output = self.experts[expert_idx](expert_hidden_states)
expert_output = expert_output[0] if isinstance(expert_output, (tuple, list)) else expert_output
expand_output_list.append(expert_output)
expand_output = torch.cat(expand_output_list, dim=0)
final_hidden_states = self.combine_moe(expand_output, scatter_idx, cusum_token_count, hidden_states_shape,
topk_values)
for expert_idx in range(self.num_experts_per_rank):
global_expert_idx = self.start_expert_id + expert_idx
expert_output = self.experts[expert_idx](hidden_states)
expert_output = expert_output[0] if isinstance(
expert_output, (tuple, list)) else expert_output
# Routing weight per token for this expert
expert_mask = (topk_indices == global_expert_idx).to(topk_values.dtype)
expert_weights = (topk_values * expert_mask).sum(dim=-1, keepdim=True)
final_hidden_states = final_hidden_states + expert_output * expert_weights
return final_hidden_states
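The dense mode described above (every expert sees every token, then each token keeps only the routing weight it assigned to that expert) can be sketched in plain Python with toy expert callables, no torch:

```python
def dense_moe(tokens, experts, topk_indices, topk_values):
    # tokens: list of scalars (stand-in for hidden states)
    # experts: list of callables, one per expert
    # topk_indices/topk_values: per-token routed expert ids and weights
    out = [0.0] * len(tokens)
    for expert_idx, expert in enumerate(experts):
        # Every expert processes ALL tokens (dense mode)...
        expert_out = [expert(t) for t in tokens]
        # ...then each token is masked by its routing weight for this expert.
        for t, (ids, vals) in enumerate(zip(topk_indices, topk_values)):
            weight = sum(v for i, v in zip(ids, vals) if i == expert_idx)
            out[t] += weight * expert_out[t]
    return out

# Two toy experts (doubling and negation); each token routes to one expert.
result = dense_moe(
    tokens=[1.0, 2.0],
    experts=[lambda x: 2 * x, lambda x: -x],
    topk_indices=[[0], [1]],
    topk_values=[[1.0], [1.0]],
)
```

There is no data-dependent control flow here: the loop bounds depend only on the number of experts and tokens, which is what makes the dense form graph-capture friendly.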
@@ -425,9 +437,9 @@ class SparseMoeMlp(nn.Module):
scatter_idx=torch.zeros((indices.numel(),), dtype=seqs.dtype, device=seqs.device).scatter(0, indices, seqs)
# token_count: [self.num_experts_per_rank]
partial_token_index, partial_token_count = sorted_expert_id.unique(sorted=True, return_counts=True)
zero_token_count = torch.zeros(self.num_total_experts, dtype=partial_token_count.dtype, device=device)
token_count = zero_token_count.scatter(dim=0, index=partial_token_index, src=partial_token_count)
# Use scatter_add_ instead of torch.unique for MLU graph capture compatibility
token_count = torch.zeros(self.num_total_experts, dtype=sorted_expert_id.dtype, device=device)
token_count.scatter_add_(0, sorted_expert_id, torch.ones_like(sorted_expert_id))
# cusum_token_count: [self.num_experts_per_rank + 1]
cusum_token_count = torch.cat(
[torch.tensor([0], dtype=token_count.dtype, device=device),
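The scatter_add_-based count above replaces torch.unique with a fixed-size histogram; the same computation in plain Python:

```python
def count_tokens_per_expert(expert_ids, num_experts):
    # Equivalent to zeros(num_experts).scatter_add_(0, ids, ones_like(ids)):
    # the output shape depends only on num_experts, never on the data,
    # which is what makes it graph-capture friendly (unlike unique()).
    counts = [0] * num_experts
    for e in expert_ids:
        counts[e] += 1
    return counts

counts = count_tokens_per_expert([0, 0, 2, 3, 3, 3], num_experts=5)
```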

View File

@@ -39,3 +39,9 @@ try:
except ImportError as e:
import logging
logging.warning(f"Failed to import mllama hijack: {e}")
try:
import vllm_mlu.model_executor.models.llama4
except ImportError as e:
import logging
logging.warning(f"Failed to import llama4 hijack: {e}")

View File

@@ -28,6 +28,7 @@ from vllm.model_executor.layers.sampler import SamplerOutput
from vllm_mlu.model_executor.layers.feed_forward import FeedForward
from vllm_mlu.mlu_hijack_utils import MluHijackObject
from vllm_mlu.model_executor.layers.sparse_moe_mlp import SparseMoeMlp
from vllm import _mlu_ops as mlu_ops
from vllm.utils import print_warning_once
from vllm.model_executor.models.utils import is_pp_missing_parameter
from vllm_mlu.model_executor.models.layer_utils import quant_fusion_with_rmsnorm
@@ -77,6 +78,12 @@ class DeepseekV2MoE(SparseMoeMlp):
bias=False,
quant_config=None,
prefix=f"{prefix}.gate")
if getattr(config, "topk_method", None) == "noaux_tc":
self.gate.e_score_correction_bias = nn.Parameter(
torch.empty(config.n_routed_experts, dtype=torch.float32)
)
else:
self.gate.e_score_correction_bias = None
if config.n_shared_experts is not None:
intermediate_size = (config.moe_intermediate_size *
config.n_shared_experts)
@@ -104,6 +111,7 @@ class DeepseekV2MoE(SparseMoeMlp):
def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
num_tokens, hidden_dim = hidden_states.shape
hidden_states = hidden_states.view(-1, hidden_dim)
shared_output = None
if self.n_shared_experts is not None:
shared_output = self.shared_experts(hidden_states)
# router_logits: (num_tokens, n_experts)
@@ -113,9 +121,25 @@ class DeepseekV2MoE(SparseMoeMlp):
Modify by vllm_mlu
=============================
@brief: replace experts() with forward_experts, which defined by SparseMoeMlp.
For noaux_tc (DeepSeek V3), do manual routing with e_score_correction_bias.
'''
final_hidden_states = self.forward_experts(
hidden_states, router_logits) * self.routed_scaling_factor
if self.gate.e_score_correction_bias is not None:
# noaux_tc routing: softmax → add bias for topk selection → use original scores
scores = router_logits.float().softmax(dim=-1)
scores_for_choice = scores + self.gate.e_score_correction_bias.unsqueeze(0)
topk_weights, topk_indices = torch.topk(
scores_for_choice, k=self.top_k, dim=-1)
# Use original softmax scores (without bias) as weights
topk_weights = scores.gather(1, topk_indices)
if self.renormalize:
topk_weights = topk_weights / topk_weights.sum(
dim=-1, keepdim=True)
final_hidden_states = self.forward_experts_with_precomputed_routing(
hidden_states, topk_weights, topk_indices
) * self.routed_scaling_factor
else:
final_hidden_states = self.forward_experts(
hidden_states, router_logits) * self.routed_scaling_factor
'''
==================
End of MLU Hijack
@@ -129,6 +153,55 @@ class DeepseekV2MoE(SparseMoeMlp):
return final_hidden_states.view(num_tokens, hidden_dim)
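The noaux_tc routing above has one subtlety worth isolating: the correction bias influences only which experts are *chosen*, while the returned weights use the original unbiased softmax scores. A pure-Python sketch (toy sizes, not DeepSeek's exact kernel):

```python
import math

def noaux_tc_route(logits, bias, k, renormalize=True):
    # Softmax over expert logits.
    m = max(logits)
    exps = [math.exp(x - m) for x in logits]
    z = sum(exps)
    scores = [e / z for e in exps]
    # The bias is added ONLY for top-k selection...
    choice_scores = [s + b for s, b in zip(scores, bias)]
    topk = sorted(range(len(logits)), key=lambda i: choice_scores[i],
                  reverse=True)[:k]
    # ...while the returned weights use the original (unbiased) scores.
    weights = [scores[i] for i in topk]
    if renormalize:
        total = sum(weights)
        weights = [w / total for w in weights]
    return topk, weights

# Equal logits: the bias alone decides that expert 1 is picked first.
indices, weights = noaux_tc_route(
    logits=[1.0, 1.0, 1.0], bias=[0.0, 0.5, 0.0], k=2)
```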
def forward_experts_with_precomputed_routing(
self, hidden_states, topk_weights, topk_indices
):
"""使用预计算的路由结果执行 MoE 前向传播"""
self.pack_params()
ori_input_shape = hidden_states.shape
expert_num = self.num_total_experts
expert_size = self.w13.size(0)
max_m = hidden_states.shape[0]
hidden_states = hidden_states.view(-1, hidden_states.size(-1))
reduce_weight = topk_weights.to(torch.float32)
expert_id = topk_indices.to(torch.int32)
# gen_idx
expand_idx, combine_idx, token_count, cusum_token_count = (
mlu_ops.moe_gen_idx(expert_id, expert_num)
)
start_expert_id = self.start_expert_id
# gemm1
expand_hidden_states = mlu_ops.moe_expand_input(
hidden_states, expand_idx, cusum_token_count,
start_expert_id, expert_size
)
gemm1_out = mlu_ops.group_gemm(
expand_hidden_states, self.w13,
token_count[start_expert_id:start_expert_id + expert_size],
None, None, None, None, max_m
)
# activation
act_out = mlu_ops.moe_active(
gemm1_out, self.hidden_act, self.is_gated, None, self.b13,
cusum_token_count, start_expert_id, expert_size
)
# gemm2
gemm2_out = mlu_ops.group_gemm(
act_out, self.w2,
token_count[start_expert_id:start_expert_id + expert_size],
None, None, None, None, max_m
)
# combine
output = mlu_ops.moe_combine_result(
gemm2_out, reduce_weight, combine_idx,
None, cusum_token_count, start_expert_id,
expert_size, self.b2
)
return output.view(ori_input_shape)
def forward_prefill(
self,
positions: torch.Tensor,
@@ -179,19 +252,27 @@ def forward_prefill(
updated_slot_mapping = attn_metadata.slot_mapping
if self.attn.kv_cache_dtype == 'int8':
key_cache_scale = kv_cache[1][0]
mlu_ops.quant_to_paged_cache(key_value,
mlu_ops.quant_to_linear_cache(key_value,
None,
key_cache,
None,
key_cache_scale,
None,
attn_metadata.cu_seq_lens,
attn_metadata.max_seq_len,
True, None,
attn_metadata.batch_ids,
attn_metadata.slot_mapping_unpaged)
else:
mlu_ops.reshape_linear_cache(key_value,
None,
key_cache,
None,
key_cache_scale,
None,
updated_slot_mapping.flatten())
else:
mlu_ops.reshape_paged_cache(key_value,
None,
key_cache,
None,
updated_slot_mapping.flatten())
attn_metadata.cu_seq_lens,
attn_metadata.max_seq_len,
True, None,
attn_metadata.batch_ids,
attn_metadata.slot_mapping_unpaged)
'''
==================
End of MLU Hijack
@@ -491,6 +572,15 @@ def vllm__module_executor__models__deepseek_v2__DeepseekV2DecoderLayer__init__(
self.post_attention_layernorm = RMSNorm(config.hidden_size,
eps=config.rms_norm_eps)
def get_spec_layer_idx_from_weight_name(config, weight_name):
num_nextn = getattr(config, "num_nextn_predict_layers", 0)
if num_nextn and num_nextn > 0:
layer_idx = config.num_hidden_layers
for i in range(num_nextn):
if weight_name.startswith(f"model.layers.{layer_idx + i}."):
return layer_idx + i
return None
def vllm__module_executor__models__deepseek_v2__DeepseekV2ForCausalLM__load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
'''
=============================
@@ -530,6 +620,10 @@ def vllm__module_executor__models__deepseek_v2__DeepseekV2ForCausalLM__load_weig
for name, loaded_weight in weights:
if "rotary_emb.inv_freq" in name:
continue
# Skip MTP speculative decoding layer weights
spec_layer = get_spec_layer_idx_from_weight_name(self.config, name)
if spec_layer is not None:
continue
'''
=============================
Modify by vllm_mlu
@@ -565,7 +659,9 @@ def vllm__module_executor__models__deepseek_v2__DeepseekV2ForCausalLM__load_weig
@brief: add an expert-skip condition and delete the useless "if name not in params_dict: continue" condition
'''
name = name.replace(weight_name, param_name)
if (("mlp.experts." in name or "mlp.shared_experts." in name or "mlp.shared_expert_gate." in name)
if (("mlp.experts." in name or "mlp.shared_experts." in name
or "mlp.shared_expert_gate." in name
or "e_score_correction_bias" in name)
and name not in params_dict):
continue
'''
@@ -595,7 +691,9 @@ def vllm__module_executor__models__deepseek_v2__DeepseekV2ForCausalLM__load_weig
if name.endswith(".bias") and name not in params_dict:
continue
if (("mlp.experts." in name or "mlp.shared_experts." in name or "mlp.shared_expert_gate." in name)
if (("mlp.experts." in name or "mlp.shared_experts." in name
or "mlp.shared_expert_gate." in name
or "e_score_correction_bias" in name)
and name not in params_dict):
continue
if is_pp_missing_parameter(name, self):

View File

@@ -194,6 +194,11 @@ def decoder_model_forward_base_pp(
hidden_states = inputs_embeds
else:
hidden_states = get_input_embeddings(input_ids)
# MLU F.embedding may output float32 even with float16 weights;
# cast to model dtype to avoid dtype mismatches downstream.
target_dtype = next(layers[start_layer].parameters()).dtype
if hidden_states.dtype != target_dtype:
hidden_states = hidden_states.to(target_dtype)
else:
assert intermediate_tensors is not None
hidden_states = intermediate_tensors["hidden_states"]

View File

@@ -0,0 +1,495 @@
import torch
import re
from typing import List, Optional, Tuple, Union, Iterable
from vllm.attention import AttentionMetadata
from vllm.config import CacheConfig
from vllm.model_executor.layers.quantization.base_config import QuantizationConfig
from vllm.model_executor.layers.layernorm import RMSNorm
from vllm_mlu.model_executor.layers.feed_forward import FeedForward
from vllm_mlu.mlu_hijack_utils import MluHijackObject
from vllm.distributed import tensor_model_parallel_all_reduce
from vllm.model_executor.model_loader.weight_utils import (
default_weight_loader, maybe_remap_kv_scale_name)
from vllm.model_executor.models.llama4 import (
Llama4Attention, Llama4DecoderLayer, Llama4ForCausalLM,
Llama4Model, Llama4MoE)
from vllm_mlu.model_executor.layers.sparse_moe_mlp import SparseMoeMlp
from vllm.model_executor.models.utils import is_pp_missing_parameter
from vllm.sequence import IntermediateTensors
from vllm_mlu.model_executor.models.layer_utils import (
decoder_layer_forward_base, decoder_model_forward_base_pp,
is_per_tensor_smoothquant, is_per_token_smoothquant,
quant_fusion_with_rmsnorm)
from vllm.logger import init_logger
logger = init_logger(__name__)
# ============================================================
# Llama4MoE MLU replacement: SparseMoeMlp + shared expert
# ============================================================
class Llama4MoEMlu(SparseMoeMlp):
"""MLU replacement for Llama4MoE using SparseMoeMlp + shared expert."""
def __init__(self, config, quant_config=None, prefix=""):
num_local_experts = getattr(config, "num_local_experts", 8)
top_k = getattr(config, "num_experts_per_tok", 1)
hidden_size = getattr(config, "hidden_size", 4096)
intermediate_size = getattr(config, "intermediate_size", 8192)
super().__init__(
num_experts=num_local_experts,
top_k=top_k,
hidden_size=hidden_size,
intermediate_size=intermediate_size,
up_proj_name="gate_up_proj",
is_gated=True,
down_proj_name="down_proj",
has_bias=False,
skip_bias_add=False,
renormalize=False,
hidden_act="silu",
params_dtype=None,
quant_config=quant_config,
is_use_fused_moe=True,
)
# Llama4 uses sigmoid routing, not softmax
# Override topk_softmax to use sigmoid
self._use_sigmoid_routing = True
# Shared expert (independent from routed experts)
self.shared_expert = FeedForward(
hidden_size=hidden_size,
intermediate_size=intermediate_size,
hidden_act="silu",
up_proj_name="gate_up_proj",
is_gated=True,
down_proj_name="down_proj",
bias=False,
quant_config=quant_config,
reduce_results=False,
prefix=f"{prefix}.shared_expert",
)
def topk_softmax(self, expert_logits):
"""Override: Llama4 uses sigmoid routing instead of softmax."""
topk_values, topk_indices = torch.topk(
expert_logits, self.top_k, dim=-1)
topk_values = torch.sigmoid(topk_values.float())
return topk_values, topk_indices
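The sigmoid-instead-of-softmax override above can be sketched in pure Python (toy inputs, stdlib only):

```python
import math

def sigmoid_topk(logits, k):
    # Pick top-k by raw logit, then squash each picked logit with a sigmoid;
    # the weights need not sum to 1 (hence renormalize=False in Llama4MoEMlu).
    topk = sorted(range(len(logits)), key=lambda i: logits[i],
                  reverse=True)[:k]
    weights = [1.0 / (1.0 + math.exp(-logits[i])) for i in topk]
    return weights, topk

weights, indices = sigmoid_topk([2.0, -1.0, 0.0], k=1)
```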
def forward(self, hidden_states, residual=None):
orig_shape = hidden_states.shape
hidden_states = hidden_states.view(-1, self.hidden_size)
# Shared expert output
shared_out = self.shared_expert(hidden_states)
# Router logits
router_logits, _ = self.gate(hidden_states)
# Routed experts
routed_out = self.forward_experts(hidden_states, router_logits, None)
# Combine
final_out = routed_out + shared_out
if self.tp_size > 1:
final_out = tensor_model_parallel_all_reduce(final_out)
return final_out.view(orig_shape)
# ============================================================
# Llama4Attention hijack
# ============================================================
vllm__llama4__Llama4Attention__init__org = Llama4Attention.__init__
def vllm__llama4__Llama4Attention____init__(
self,
config,
hidden_size: int,
num_heads: int,
num_kv_heads: int,
max_position_embeddings: int = 8192,
quant_config: Optional[QuantizationConfig] = None,
bias: bool = False,
cache_config: Optional[CacheConfig] = None,
prefix: str = "",
) -> None:
vllm__llama4__Llama4Attention__init__org(
self, config, hidden_size, num_heads, num_kv_heads,
max_position_embeddings, quant_config, bias, cache_config, prefix)
'''
=============================
Modify by vllm_mlu
=============================
@brief: save rope_scaling for MLU RoPE dispatch
'''
self.rope_scaling = getattr(config, "rope_scaling", None)
'''
==================
End of MLU Hijack
==================
'''
def vllm__llama4__Llama4Attention__forward(
self,
positions: torch.Tensor,
hidden_states: torch.Tensor,
kv_cache: torch.Tensor,
attn_metadata: AttentionMetadata,
residual: Optional[torch.Tensor] = None,
smooth_quant_scale: Optional[torch.Tensor] = None,
) -> torch.Tensor:
qkv, _ = self.qkv_proj(hidden_states, smooth_quant_scale)
q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1)
'''
=============================
Modify by vllm_mlu
=============================
    @brief: MLU RoPE: merge q/k, apply rotary, split back (lesson learned #3)
For NoPE layers (self.rotary_emb is None), skip RoPE entirely.
'''
if self.rotary_emb is not None:
if (self.rope_scaling is not None
and self.rope_scaling.get("rope_type") == "longrope"):
q, k = self.rotary_emb(positions, q, k)
else:
qk, _ = qkv.split(
[self.q_size + self.kv_size, self.kv_size], dim=-1)
self.rotary_emb(
positions,
qk.view(-1, self.num_heads + self.num_kv_heads,
self.head_dim))
q, k = qk.split([self.q_size, self.kv_size], dim=-1)
'''
==================
End of MLU Hijack
==================
'''
# QK norm (MLU fused_rms_norm requires matching dtypes, skip .float())
if self.qk_norm is not None:
q = q.contiguous().reshape(-1, self.head_dim)
q = (self.qk_norm(q)
.contiguous().reshape(-1, self.q_size))
k = k.contiguous().reshape(-1, self.head_dim)
k = (self.qk_norm(k)
.contiguous().reshape(-1, self.kv_size))
# Temperature tuning for NoPE layers
if self.attn_temperature_tuning and self.nope:
attn_scale = self._get_attn_scale(positions)
q = (q * attn_scale).to(q.dtype)
attn_output = self.attn(q, k, v, kv_cache, attn_metadata)
'''
=============================
Modify by vllm_mlu
=============================
@brief: add residual in o_proj
'''
output, _ = self.o_proj(attn_output, residual)
'''
==================
End of MLU Hijack
==================
'''
return output
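The merged-qk path above works because `Tensor.split` returns views of the same storage, so an in-place rotary update on `qk` is visible through `q` and `k` after the second split. A toy sketch of that aliasing (in-place doubling stands in for the MLU rotary kernel; sizes are illustrative):

```python
import torch

# Toy sizes: q_size=4, kv_size=4; qkv holds [q | k | v] contiguously.
qkv = torch.arange(12.0)
qk, v = qkv.split([8, 4], dim=-1)   # split returns views, not copies
qk.mul_(2.0)                        # in-place op, standing in for the rotary kernel
q, k = qk.split([4, 4], dim=-1)     # split back; q and k see the update
```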
# ============================================================
# Llama4DecoderLayer hijack
# ============================================================
def vllm__llama4__Llama4DecoderLayer____init__(
self,
config,
cache_config: Optional[CacheConfig] = None,
quant_config: Optional[QuantizationConfig] = None,
prefix: str = "",
) -> None:
super(Llama4DecoderLayer, self).__init__()
from vllm.model_executor.models.llama4 import (
_extract_layer_index, Llama4Attention)
self.layer_idx = _extract_layer_index(prefix)
self.hidden_size = getattr(config, "hidden_size", 4096)
max_position_embeddings = getattr(
config, "max_position_embeddings", 8192)
self.self_attn = Llama4Attention(
config=config,
hidden_size=self.hidden_size,
num_heads=getattr(config, "num_attention_heads", 32),
num_kv_heads=getattr(config, "num_key_value_heads",
getattr(config, "num_attention_heads", 32)),
max_position_embeddings=max_position_embeddings,
quant_config=quant_config,
bias=False,
cache_config=cache_config,
prefix=f"{prefix}.self_attn",
)
interleave_moe_layer_step = getattr(
config, "interleave_moe_layer_step", 0)
is_moe_layer = (interleave_moe_layer_step > 0
and (self.layer_idx + 1)
% interleave_moe_layer_step == 0)
'''
=============================
Modify by vllm_mlu
=============================
@brief: Replace MoE with Llama4MoEMlu (SparseMoeMlp + shared expert),
Replace dense MLP with FeedForward.
'''
if is_moe_layer:
self.feed_forward = Llama4MoEMlu(
config=config,
quant_config=quant_config,
prefix=f"{prefix}.feed_forward",
)
else:
intermediate_size_mlp = getattr(
config, "intermediate_size_mlp",
getattr(config, "intermediate_size", 8192))
self.feed_forward = FeedForward(
hidden_size=self.hidden_size,
intermediate_size=intermediate_size_mlp,
hidden_act="silu",
up_proj_name="gate_up_proj",
is_gated=True,
down_proj_name="down_proj",
bias=False,
quant_config=quant_config,
prefix=f"{prefix}.feed_forward",
)
'''
==================
End of MLU Hijack
==================
'''
rms_norm_eps = getattr(config, "rms_norm_eps", 1e-5)
self.input_layernorm = RMSNorm(self.hidden_size, eps=rms_norm_eps)
self.post_attention_layernorm = RMSNorm(
self.hidden_size, eps=rms_norm_eps)
    self.is_per_tensor_sq_perf_cases = is_per_tensor_smoothquant(
        quant_config)
    self.is_per_token_sq_perf_cases = is_per_token_smoothquant(
        quant_config)
    if self.is_per_tensor_sq_perf_cases or self.is_per_token_sq_perf_cases:
        self.self_attn.qkv_proj.quant_method.skip_quant_input = True
    self.quant_fusion_attn_layernorm = None
def vllm__llama4__Llama4DecoderLayer__forward(
    self,
    positions: torch.Tensor,
    hidden_states: torch.Tensor,
    kv_cache: torch.Tensor,
    attn_metadata: AttentionMetadata,
) -> torch.Tensor:
    '''
    =============================
    Modify by vllm_mlu
    =============================
    @brief: use decoder_layer_forward_base with residual-in-matmul
    and optional quant fusion.
    '''
    attn_layernorm = self.input_layernorm
    if self.is_per_tensor_sq_perf_cases or self.is_per_token_sq_perf_cases:
        if self.quant_fusion_attn_layernorm is None:
            if self.is_per_token_sq_perf_cases:
                attn_quant_scale = self.self_attn.qkv_proj.smooth
            else:
                attn_quant_scale = self.self_attn.qkv_proj.scale_to_int
            self.quant_fusion_attn_layernorm = quant_fusion_with_rmsnorm(
                self.input_layernorm, attn_quant_scale,
                dynamic_quant=self.is_per_token_sq_perf_cases)
        attn_layernorm = self.quant_fusion_attn_layernorm
    return decoder_layer_forward_base(
        positions, hidden_states, kv_cache, attn_metadata,
        attn_layernorm,
        self.self_attn,
        self.post_attention_layernorm,
        self.feed_forward,
        input_norm_fuse_en=self.is_per_token_sq_perf_cases)
# ============================================================
# Llama4Model hijack
# ============================================================
def vllm__llama4__Llama4Model__forward(
self,
input_ids: Optional[torch.Tensor],
positions: torch.Tensor,
kv_caches: List[torch.Tensor],
attn_metadata: AttentionMetadata,
intermediate_tensors: Optional[IntermediateTensors],
inputs_embeds: Optional[torch.Tensor] = None,
) -> Union[torch.Tensor, IntermediateTensors]:
return decoder_model_forward_base_pp(
input_ids, positions, kv_caches, attn_metadata,
intermediate_tensors,
self.layers, self.start_layer, self.end_layer,
self.get_input_embeddings,
self.norm,
inputs_embeds)
# ============================================================
# Llama4ForCausalLM load_weights hijack
# ============================================================
def vllm__llama4__Llama4ForCausalLM__load_weights(
self,
weights: Iterable[Tuple[str, torch.Tensor]],
):
'''
=============================
Modify by vllm_mlu
=============================
@brief: pack params for SparseMoeMlp (MoE layers)
'''
for name, m in self.model.named_modules():
if isinstance(m, SparseMoeMlp):
m.pack_params()
start_expert_id = 0
'''
==================
End of MLU Hijack
==================
'''
stacked_params_mapping = [
# (param_name, shard_name, shard_id)
("qkv_proj", "q_proj", "q"),
("qkv_proj", "k_proj", "k"),
("qkv_proj", "v_proj", "v"),
("gate_up_proj", "gate_proj", 0),
("gate_up_proj", "up_proj", 1),
]
params_dict = dict(self.named_parameters())
for name, loaded_weight in weights:
if "rotary_emb.inv_freq" in name:
continue
# Strip language_model. prefix for Llama4ForConditionalGeneration
if name.startswith("language_model."):
name = name[len("language_model."):]
# Skip vision encoder weights
elif (name.startswith("multi_modal_projector.")
or name.startswith("vision_encoder.")
or name.startswith("vision_model.")):
continue
# Permute Q/K weights for rotary embedding
name, loaded_weight = self.permute_qk_weight_for_rotary(
name, loaded_weight)
'''
=============================
Modify by vllm_mlu
=============================
@brief: remap expert_id for distributed inference
'''
if (start_expert_id > 0
and "feed_forward.experts." in name):
match = re.search(r'experts\.\d+', name)
if match:
expert_str = match.group(0)
expert_id = int(expert_str.split(".")[1])
named_expert_id = expert_id - start_expert_id
name = name.replace(
f"experts.{expert_id}",
f"experts.{named_expert_id}")
'''
==================
End of MLU Hijack
==================
'''
for (param_name, weight_name, shard_id) in stacked_params_mapping:
if weight_name not in name:
continue
name = name.replace(weight_name, param_name)
if ((name.endswith(".bias") or name.endswith("_bias"))
and name not in params_dict):
continue
if is_pp_missing_parameter(name, self):
continue
# Skip experts not assigned to this worker
if ("feed_forward.experts." in name
and name not in params_dict):
continue
param = params_dict[name]
weight_loader = param.weight_loader
weight_loader(param, loaded_weight, shard_id)
break
else:
if ((name.endswith(".bias") or name.endswith("_bias"))
and name not in params_dict):
continue
if is_pp_missing_parameter(name, self):
continue
name = maybe_remap_kv_scale_name(name, params_dict)
if name is None:
continue
# Skip experts not assigned to this worker
if ("feed_forward.experts." in name
and name not in params_dict):
continue
if name not in params_dict:
logger.warning(
"Skipping weight %s not present in the model", name)
continue
param = params_dict[name]
weight_loader = getattr(
param, "weight_loader", default_weight_loader)
weight_loader(param, loaded_weight)
# ============================================================
# Apply all hijacks
# ============================================================
MluHijackObject.apply_hijack(
Llama4Attention,
Llama4Attention.__init__,
vllm__llama4__Llama4Attention____init__)
MluHijackObject.apply_hijack(
Llama4Attention,
Llama4Attention.forward,
vllm__llama4__Llama4Attention__forward)
MluHijackObject.apply_hijack(
Llama4DecoderLayer,
Llama4DecoderLayer.__init__,
vllm__llama4__Llama4DecoderLayer____init__)
MluHijackObject.apply_hijack(
Llama4DecoderLayer,
Llama4DecoderLayer.forward,
vllm__llama4__Llama4DecoderLayer__forward)
MluHijackObject.apply_hijack(
Llama4Model,
Llama4Model.forward,
vllm__llama4__Llama4Model__forward)
MluHijackObject.apply_hijack(
Llama4ForCausalLM,
Llama4ForCausalLM.load_weights,
vllm__llama4__Llama4ForCausalLM__load_weights)
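The expert-id remap inside `load_weights` can be exercised in isolation. `remap_expert_name` below is a hypothetical standalone helper mirroring the regex logic, with an assumed per-worker offset of 8 (i.e. this worker owns global experts 8..15):

```python
import re

def remap_expert_name(name: str, start_expert_id: int) -> str:
    # Rewrite a global expert id in a checkpoint weight name to the
    # worker-local id, e.g. experts.9 -> experts.1 when start_expert_id=8.
    match = re.search(r'experts\.\d+', name)
    if match:
        expert_id = int(match.group(0).split(".")[1])
        local_id = expert_id - start_expert_id
        name = name.replace(f"experts.{expert_id}", f"experts.{local_id}")
    return name

remapped = remap_expert_name(
    "model.layers.0.feed_forward.experts.9.down_proj.weight", 8)
```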

# ============================================================
# CacheEngine._allocate_kv_cache hijack (vllm_mlu cache_engine)
# ============================================================
=============================
Modify by vllm_mlu
=============================
@brief: add kv_cache_scale for int8 support;
cap num_blocks to avoid exceeding CNNL int32 element limit
'''
# CNNL operators have a max supported tensor element count of INT32_MAX.
# num_blocks should already be capped by determine_num_available_blocks,
# this is a defensive check to catch any edge cases.
CNNL_MAX_TENSOR_ELEMENTS = 2**31 - 1
total_elements = 1
for dim in kv_cache_shape:
total_elements *= dim
if total_elements > CNNL_MAX_TENSOR_ELEMENTS:
elements_per_block = total_elements // num_blocks
max_num_blocks = CNNL_MAX_TENSOR_ELEMENTS // elements_per_block
logger.warning(
"KV cache tensor elements (%d) exceed CNNL max (%d). "
"Reducing num_blocks from %d to %d. This indicates "
"determine_num_available_blocks did not cap correctly.",
total_elements, CNNL_MAX_TENSOR_ELEMENTS,
num_blocks, max_num_blocks)
num_blocks = max_num_blocks
kv_cache_shape = self.attn_backend.get_kv_cache_shape(
num_blocks, self.block_size, self.num_kv_heads, self.head_size)
kv_cache_scales_shape = self.attn_backend.get_kv_cache_scale_shape(
num_blocks, self.block_size, self.num_kv_heads)
pin_memory = is_pin_memory_available() if device == "cpu" else False

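The CNNL element-count cap used here reduces to one integer division over the per-block element count. A standalone sketch with illustrative shape values (`cap_num_blocks` is an illustrative helper, not part of either file):

```python
# KV cache shape is (2, num_blocks, num_kv_heads, block_size, head_size);
# CNNL kernels address at most 2**31 - 1 elements per tensor.
CNNL_MAX_TENSOR_ELEMENTS = 2**31 - 1

def cap_num_blocks(num_blocks: int, num_kv_heads: int,
                   block_size: int, head_size: int) -> int:
    # Elements contributed by a single block across the K and V planes.
    elements_per_block = 2 * num_kv_heads * block_size * head_size
    max_blocks = CNNL_MAX_TENSOR_ELEMENTS // elements_per_block
    return min(num_blocks, max_blocks)

# Illustrative values: 8 KV heads, block_size 16, head_dim 128
capped = cap_num_blocks(100_000, 8, 16, 128)
```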
# ============================================================
# MLUWorker_V2(MLUWorker) hijack: cap num_gpu_blocks for CNNL
# ============================================================
num_gpu_blocks = max(num_gpu_blocks, 0)
num_cpu_blocks = max(num_cpu_blocks, 0)
# Cap num_gpu_blocks to avoid exceeding CNNL's int32 tensor element
# limit. CNNL operators do not support tensors with more than
# 2^31 - 1 elements. The KV cache shape is typically
# (2, num_blocks, num_kv_heads, block_size, head_size), and when
# num_blocks is very large (e.g. for tiny models with huge free
# memory), the total element count can overflow.
CNNL_MAX_TENSOR_ELEMENTS = 2**31 - 1
block_size = self.cache_config.block_size
num_kv_heads = self.model_config.get_num_kv_heads(
self.parallel_config)
head_size = self.model_config.get_head_size()
# kv_cache_shape = (2, num_blocks, num_kv_heads, block_size, head_size)
elements_per_block = 2 * num_kv_heads * block_size * head_size
if elements_per_block > 0:
max_blocks_by_cnnl = CNNL_MAX_TENSOR_ELEMENTS // elements_per_block
if num_gpu_blocks > max_blocks_by_cnnl:
logger.warning(
"Reducing num_gpu_blocks from %d to %d to stay within "
"CNNL max tensor element limit (%d). "
"elements_per_block=%d",
num_gpu_blocks, max_blocks_by_cnnl,
CNNL_MAX_TENSOR_ELEMENTS, elements_per_block)
num_gpu_blocks = max_blocks_by_cnnl
logger.info(
"Memory profiling results: total_gpu_memory=%.2fGiB"
" initial_memory_usage=%.2fGiB peak_torch_memory=%.2fGiB"