38 Commits

Author SHA1 Message Date
Chranos
dd221f3084 add llama4 2026-02-11 17:25:38 +08:00
Chranos
7b4f7d74c3 add llama4 2026-02-11 16:08:37 +08:00
Chranos
16d41a8fc1 add deepseekv3 and llama4 2026-02-11 16:03:06 +08:00
Chranos
633aa4db30 add deepseekv3 and llama4 2026-02-11 15:58:34 +08:00
Chranos
6eae065dd6 add deepseekv3 and llama4 2026-02-11 15:48:35 +08:00
Chranos
e752946445 add deepseekv3 and llama4 2026-02-11 15:44:44 +08:00
Chranos
7626238695 add deepseekv3 and llama4 2026-02-11 15:40:19 +08:00
Chranos
f3a4d10195 add deepseekv3 and llama4 2026-02-11 15:39:35 +08:00
Chranos
ed6a2aff91 add deepseekv3 and llama4 2026-02-11 15:37:19 +08:00
Chranos
6faa595799 add deepseekv3 and llama4 2026-02-11 15:32:07 +08:00
Chranos
50e02f2011 add deepseekv3 and llama4 2026-02-11 15:27:19 +08:00
Chranos
c584139543 add deepseekv3 and llama4 2026-02-11 15:24:13 +08:00
Chranos
2ad23aa8da add deepseekv3 and llama4 2026-02-11 15:17:07 +08:00
Chranos
86fd3b5a92 add deepseekv3 and llama4 2026-02-11 15:13:14 +08:00
Chranos
eaeb5169e0 add deepseekv3 and llama4 2026-02-11 15:09:59 +08:00
Chranos
44ffd2094a add deepseekv3 and llama4 2026-02-11 15:07:52 +08:00
Chranos
5132af6176 add deepseekv3 and llama4 2026-02-11 15:05:55 +08:00
Chranos
5c4c2222ba add deepseekv3 and llama4 2026-02-11 15:03:30 +08:00
Chranos
026380fddb add deepseekv3 and llama4 2026-02-11 14:56:40 +08:00
Chranos
d9d1f3a724 add deepseekv3 and llama4 2026-02-11 14:39:48 +08:00
Chranos
d93c740e4d add deepseekv3 and llama4 2026-02-11 14:37:00 +08:00
Chranos
153bc4ec7b add deepseekv3 and llama4 2026-02-11 14:32:37 +08:00
Chranos
96ed925486 add deepseekv3 and llama4 2026-02-11 14:30:01 +08:00
Chranos
8ac7afcbd3 add deepseekv3 and llama4 2026-02-11 14:26:59 +08:00
Chranos
128aed196c add deepseekv3 and llama4 2026-02-11 14:19:17 +08:00
Chranos
659ef273c8 add deepseekv3 2026-02-11 13:18:03 +08:00
Chranos
98003e6f8b add deepseekv3 2026-02-11 13:12:46 +08:00
Chranos
094541296e add deepseekv3 2026-02-11 12:28:36 +08:00
Chranos
5a05c22162 add deepseekv3 2026-02-11 11:40:57 +08:00
Chranos
60f3a23d5f add deepseekv3 2026-02-11 11:35:12 +08:00
Chranos
9c1d7cc9ff add qwen3_moe 2026-02-10 18:55:35 +08:00
Chranos
934ed88691 add qwen3_moe 2026-02-10 18:30:48 +08:00
Chranos
fa0219fbf8 add qwen3_moe 2026-02-10 18:22:13 +08:00
Chranos
efbb06147a add qwen3_moe 2026-02-10 18:18:32 +08:00
Chranos
a26729bf7f add qwen3_moe 2026-02-10 18:09:58 +08:00
Chranos
8a613d15bd add qwen3_moe 2026-02-10 18:02:40 +08:00
Chranos
a6f39375e5 debugging 2026-02-10 16:10:28 +08:00
Chranos
afc34d988e debugging 2026-02-10 15:47:48 +08:00
24 changed files with 2323 additions and 106 deletions

View File

@@ -175,3 +175,7 @@ curl http://localhost:80/v1/chat/completions \
| v0.0.3.1 | 2026-02-06 | **CNNL Tensor 溢出修复**:解决极小模型在大显存设备上部署时 KV cache 元素数超过 int32 限制的问题,在 mlu_worker 和 cache_engine 中添加双重防护 |
| v0.0.4 | 2026-02-10 | **Gemma3 模型支持**:新增 Gemma3ForCausalLM 模型实现(含 QK Normalization、per-layer rope 配置、layer_types 滑动窗口),修复 `patch_rope_scaling_dict` 在 rope_scaling 缺少 `rope_type` 键时崩溃的问题,更新模型注册表及 config.py 中 interleaved attention 和 dtype 自动处理逻辑 |
| v0.0.4.1 | 2026-02-10 | **Gemma3 rope 兼容性修复**:修复新版 transformers `Gemma3TextConfig` 缺少 `rope_theta` 属性的问题,从 `rope_parameters` 字典兼容提取 rope 配置(支持 Transformers v4/v5);修复 `rope_scaling` 嵌套字典导致 `get_rope` 缓存 unhashable 的问题;适配 MLU `forward_mlu` 接口,将 q/k 合并为单张量调用 rotary_emb 后再拆分 |
| v0.0.5 | 2026-02-10 | **Qwen3MoE 模型支持**:新增 Qwen3MoeForCausalLM 模型实现(含 QK Normalization、ReplicatedLinear shared_expert_gate),修复 FusedMoE `forward_mlu` 签名缺少 `layer` 参数的已有 bug(影响所有 MLU 上的 MoE 模型),更新模型注册表 |
| v0.0.6 | 2026-02-11 | **DeepSeek V3 模型支持**:注册 DeepseekV3ForCausalLM(复用 V2 实现),扩展 MLU MLA config 判断支持 `deepseek_v3`,实现 `noaux_tc` 路由方式(`e_score_correction_bias`),跳过 MTP 层权重加载,修复 MLA unpaged 缓存路径使用错误的 paged cache 算子(prefill + decode 均替换为 `reshape_linear_cache`) |
| v0.0.6 | 2026-02-11 | **DeepSeek V3 MTP 推测解码**:新建 `deepseek_mtp.py` 实现 MTP draft model(复用 DeepseekV2DecoderLayer,EAGLE 模板适配),SpeculativeConfig 自动检测 `num_nextn_predict_layers` 并改写 draft config,target worker 为 MTP 返回 hidden states,MLU config 三处 model_type 判断扩展支持 `deepseek_mtp` 以匹配 MLA cache 格式 |
| v0.0.6 | 2026-02-11 | **Llama4 模型支持**:新建 Llama4ForCausalLM 模型实现(复合 config 处理、sigmoid routing MoE、QK Normalization、交替 dense/MoE 层),新建 MLU hijack 适配(SparseMoeMlp MoE 替换、embedding dtype 修复),处理 `Llama4Config` 嵌套 `text_config` 的 architectures 提取问题。**⚠️ MoE dense 模式(影响所有 MoE 模型)**:原始 `forward_experts_nofused` 包含 `torch.unique`、`torch.tensor` 创建、数据依赖分支等 graph capture 不兼容操作,导致 MLU370 上所有走 `SparseMoeMlp` 的 MoE 模型必须加 `--enforce-eager` 才能运行。现已改为 dense 模式(每个 expert 处理全部 token),解决了 graph capture 兼容性,所有 MoE 模型无需 `--enforce-eager` 即可运行,但计算量增大 num_experts/topk 倍(Mixtral 4x、Llama4 16x、Qwen2-MoE 15x)。DeepSeek V2/V3 不受影响(有独立 MLU MoE hijack)。后续应拆分 `is_use_fused_moe` 标志让 MLU370 走 `forward_group_experts` 路径优化 |

View File

@@ -1403,6 +1403,18 @@ class SpeculativeConfig:
draft_hf_config = draft_model_config.hf_config
# Detect DeepSeek V3 MTP: same model path with
# num_nextn_predict_layers > 0
num_nextn = getattr(draft_hf_config,
"num_nextn_predict_layers", 0)
if (num_nextn and num_nextn > 0
and getattr(draft_hf_config, "model_type", "")
in ("deepseek_v3",)):
draft_hf_config.model_type = "deepseek_mtp"
draft_hf_config.architectures = ["DeepSeekMTPModel"]
if num_speculative_tokens is None:
num_speculative_tokens = num_nextn
if (num_speculative_tokens is not None
and hasattr(draft_hf_config, "num_lookahead_tokens")):
draft_hf_config.num_lookahead_tokens = num_speculative_tokens
@@ -1421,7 +1433,7 @@ class SpeculativeConfig:
f"{num_speculative_tokens=} was provided.")
if enable_chunked_prefill and draft_hf_config.model_type in (
"medusa", "mlp_speculator", "eagle"):
"medusa", "mlp_speculator", "eagle", "deepseek_mtp"):
raise ValueError(
"Chunked prefill and hidden-state based draft models are "
"not compatible.")

View File

@@ -153,23 +153,25 @@ class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp):
def forward_mlu(
self,
layer: torch.nn.Module,
x: torch.Tensor,
w1: torch.Tensor,
w2: torch.Tensor,
router_logits: torch.Tensor,
top_k: int,
renormalize: bool,
use_grouped_topk: bool,
num_expert_group: Optional[int],
topk_group: Optional[int],
topk_group: Optional[int] = None,
num_expert_group: Optional[int] = None,
custom_routing_function: Optional[Callable] = None,
) -> torch.Tensor:
from vllm._mlu_ops import fused_moe
assert use_grouped_topk is False and num_expert_group is None and topk_group is None, \
f"Following params: use_grouped_topk, num_expert_group, topk_group are not support yet."
assert use_grouped_topk is False and num_expert_group is None \
and topk_group is None, \
"Following params: use_grouped_topk, num_expert_group, " \
"topk_group are not supported yet."
return fused_moe(x,
router_logits,
w1, w2,
layer.w13_weight, layer.w2_weight,
None, None, # bias1, bias2
None, # residual
None, # input_smooth

View File

@@ -143,11 +143,14 @@ class RMSNorm(CustomOp):
from vllm import _mlu_ops as mlu_ops
x = x.view(-1, self.weight.data.shape[0])
weight = self.weight.data
if weight.dtype != x.dtype:
weight = weight.to(x.dtype)
if residual is not None:
residual = residual.view(-1, self.weight.data.shape[0])
return mlu_ops.fused_rms_norm(x, residual, self.weight.data, None, None, self.variance_epsilon, True)
return mlu_ops.fused_rms_norm(x, residual, weight, None, None, self.variance_epsilon, True)
else:
return mlu_ops.fused_rms_norm(x, residual, self.weight.data, None, None, self.variance_epsilon, False)
return mlu_ops.fused_rms_norm(x, residual, weight, None, None, self.variance_epsilon, False)
def extra_repr(self) -> str:
s = f"hidden_size={self.weight.data.size(0)}"

View File

@@ -38,6 +38,9 @@ class UnquantizedEmbeddingMethod(QuantizeMethodBase):
layer: torch.nn.Module,
x: torch.Tensor,
bias: Optional[torch.Tensor] = None) -> torch.Tensor:
# MLU F.linear requires matching dtypes
if x.dtype != layer.weight.dtype:
x = x.to(layer.weight.dtype)
return F.linear(x, layer.weight, bias)
def embedding(self, layer: torch.nn.Module,

View File

@@ -89,15 +89,63 @@ def device_loading_context(module: torch.nn.Module,
logger = init_logger(__name__)
def _get_device_memory_info_loader():
"""Get device memory info for debug logging. Returns dict or None."""
try:
import torch.mlu
allocated = torch.mlu.memory_allocated() / (1024 ** 3)
reserved = torch.mlu.memory_reserved() / (1024 ** 3)
free, total = torch.mlu.mem_get_info()
return {"allocated": allocated, "reserved": reserved,
"free": free / (1024 ** 3), "total": total / (1024 ** 3)}
except Exception:
pass
try:
if torch.cuda.is_available():
allocated = torch.cuda.memory_allocated() / (1024 ** 3)
reserved = torch.cuda.memory_reserved() / (1024 ** 3)
free, total = torch.cuda.mem_get_info()
return {"allocated": allocated, "reserved": reserved,
"free": free / (1024 ** 3), "total": total / (1024 ** 3)}
except Exception:
pass
return None
def _log_mem(tag: str):
    """Log the current device memory figures, labelled with *tag*.

    Nothing is logged when no memory information is available.
    """
    mem = _get_device_memory_info_loader()
    if not mem:
        return
    logger.info(
        "[DEBUG-MEM] %s: allocated=%.2f GiB, reserved=%.2f GiB, "
        "free=%.2f GiB, total=%.2f GiB",
        tag, mem["allocated"], mem["reserved"],
        mem["free"], mem["total"])
def _initialize_model(vllm_config: VllmConfig, prefix: str = "") -> nn.Module:
"""Initialize a model with the given configurations."""
model_config = vllm_config.model_config
model_class, _ = get_model_architecture(model_config)
logger.info("[DEBUG-MEM] Model class: %s, dtype: %s",
model_class.__name__, model_config.dtype)
_log_mem("Before _initialize_model")
signatures = inspect.signature(model_class.__init__)
all_params = [param.name for param in signatures.parameters.values()]
if "vllm_config" in all_params and "prefix" in all_params:
# new-style model class
return model_class(vllm_config=vllm_config, prefix=prefix)
model = model_class(vllm_config=vllm_config, prefix=prefix)
_log_mem("After _initialize_model (empty weights created)")
# Print model parameter summary
total_params = 0
total_bytes = 0
for name, param in model.named_parameters():
total_params += param.numel()
total_bytes += param.numel() * param.element_size()
logger.info(
"[DEBUG-MEM] Model params: %d, "
"estimated size: %.2f GiB",
total_params, total_bytes / (1024 ** 3))
return model
msg = ("vLLM model class should accept `vllm_config` and `prefix` as "
"input arguments. Possibly you have an old-style model class"
" registered from out of tree and it is used for new vLLM version. "
@@ -327,11 +375,14 @@ class DefaultModelLoader(BaseModelLoader):
model_config = vllm_config.model_config
target_device = torch.device(device_config.device)
_log_mem("load_model start, target_device=%s" % target_device)
with set_default_torch_dtype(model_config.dtype):
with target_device:
model = _initialize_model(vllm_config=vllm_config)
_log_mem("Before load_weights")
model.load_weights(self._get_all_weights(model_config, model))
_log_mem("After load_weights")
for _, module in model.named_modules():
quant_method = getattr(module, "quant_method", None)

View File

@@ -20,7 +20,7 @@ def set_default_torch_dtype(dtype: torch.dtype):
def get_model_architecture(
model_config: ModelConfig) -> Tuple[Type[nn.Module], str]:
architectures = getattr(model_config.hf_config, "architectures", [])
architectures = getattr(model_config.hf_config, "architectures", None) or []
# Special handling for quantized Mixtral.
# FIXME(woosuk): This is a temporary hack.
mixtral_supported = [

View File

@@ -0,0 +1,291 @@
"""Inference-only DeepSeek V3 Multi-Token Prediction (MTP) model."""
from typing import Iterable, List, Optional, Tuple
import torch
import torch.nn as nn
from vllm.attention.backends.abstract import AttentionMetadata
from vllm.config import CacheConfig, VllmConfig
from vllm.model_executor.layers.layernorm import RMSNorm
from vllm.model_executor.layers.logits_processor import LogitsProcessor
from vllm.model_executor.layers.quantization import QuantizationConfig
from vllm.model_executor.layers.sampler import SamplerOutput, get_sampler
from vllm.model_executor.layers.vocab_parallel_embedding import (
ParallelLMHead, VocabParallelEmbedding)
from vllm.model_executor.model_loader.weight_utils import default_weight_loader
from vllm.model_executor.sampling_metadata import SamplingMetadata
from vllm.sequence import IntermediateTensors
from .deepseek_v2 import DeepseekV2DecoderLayer
class SharedHead(nn.Module):
    """Final RMSNorm plus LM head attached to an MTP layer.

    ``forward`` applies the norm only. The ``head`` module is exposed as an
    attribute so a caller can pass it to a logits processor.
    """

    def __init__(self, config, prefix: str = ""):
        super().__init__()
        norm_eps = getattr(config, "rms_norm_eps", 1e-6)
        self.norm = RMSNorm(config.hidden_size, eps=norm_eps)
        self.head = ParallelLMHead(config.vocab_size, config.hidden_size)

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        """Return the normalized hidden states (no vocab projection)."""
        normed = self.norm(hidden_states)
        return normed
class DeepSeekMultiTokenPredictorLayer(nn.Module):
    """One MTP layer: enorm, hnorm, eh_proj, shared_head and a decoder block.

    The token embeddings and the previous hidden states are normalized
    separately, concatenated, projected back to ``hidden_size`` and then run
    through a ``DeepseekV2DecoderLayer``.
    """

    def __init__(
        self,
        config,
        layer_idx: int,
        cache_config: Optional[CacheConfig] = None,
        quant_config: Optional[QuantizationConfig] = None,
        prefix: str = "",
    ):
        super().__init__()
        norm_eps = getattr(config, "rms_norm_eps", 1e-6)
        hidden = config.hidden_size

        self.enorm = RMSNorm(hidden, eps=norm_eps)
        self.hnorm = RMSNorm(hidden, eps=norm_eps)
        # Maps the concatenated [embedding, previous hidden] vector back to
        # the model width.
        self.eh_proj = nn.Linear(hidden * 2, hidden, bias=False)
        self.shared_head = SharedHead(config,
                                      prefix=f"{prefix}.shared_head")

        # Reuse DeepseekV2DecoderLayer (MLU hijack auto-applies)
        self.mtp_block = DeepseekV2DecoderLayer(
            config,
            prefix=f"model.layers.{layer_idx}",
            cache_config=cache_config,
            quant_config=quant_config,
        )

    def forward(
        self,
        input_ids: torch.Tensor,
        positions: torch.Tensor,
        previous_hidden_states: torch.Tensor,
        kv_caches: List[torch.Tensor],
        attn_metadata: AttentionMetadata,
        inputs_embeds: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        """Run the MTP layer and return ``residual + hidden_states``.

        ``inputs_embeds`` is required despite its ``None`` default. Only the
        first entry of ``kv_caches`` is used.
        """
        assert inputs_embeds is not None

        # Mask inputs at position 0
        at_first_position = positions.unsqueeze(-1) == 0
        inputs_embeds = torch.where(at_first_position, 0, inputs_embeds)

        embeds_normed = self.enorm(inputs_embeds)
        prev_normed = self.hnorm(previous_hidden_states)
        fused = torch.cat([embeds_normed, prev_normed], dim=-1)
        hidden_states = self.eh_proj(fused)

        hidden_states, residual = self.mtp_block(
            positions, hidden_states, kv_caches[0], attn_metadata,
            residual=None)
        return residual + hidden_states
def get_spec_layer_idx_from_weight_name(config, weight_name: str):
    """Return the MTP layer index that *weight_name* belongs to, if any.

    MTP layers are numbered directly after the regular decoder layers, i.e.
    from ``config.num_hidden_layers`` up to (but excluding)
    ``config.num_hidden_layers + config.num_nextn_predict_layers``.

    Returns:
        The matching layer index, or ``None`` when the config declares no
        MTP layers or the name does not start with an MTP layer prefix.
    """
    mtp_count = getattr(config, "num_nextn_predict_layers", 0)
    if not mtp_count or mtp_count <= 0:
        return None
    first_mtp_layer = config.num_hidden_layers
    for idx in range(first_mtp_layer, first_mtp_layer + mtp_count):
        # The trailing dot keeps layer 6 from matching "model.layers.61.".
        if weight_name.startswith(f"model.layers.{idx}."):
            return idx
    return None
def _rewrite_spec_layer_name(config, spec_layer: int, name: str) -> str:
"""Rewrite weight name for MTP layer.
Add .mtp_block for transformer block weights,
rename shared weights to top level."""
spec_layer_weight_names = [
"embed_tokens", "enorm", "hnorm", "eh_proj", "shared_head",
]
shared_weight_names = ["embed_tokens"]
spec_layer_weight = False
shared_weight = False
for weight_name in spec_layer_weight_names:
if weight_name in name:
spec_layer_weight = True
if weight_name in shared_weight_names:
shared_weight = True
break
if not spec_layer_weight:
# Transformer block weights -> add .mtp_block prefix
name = name.replace(
f"model.layers.{spec_layer}.",
f"model.layers.{spec_layer}.mtp_block.")
elif shared_weight:
# Shared weights -> top level
name = name.replace(f"model.layers.{spec_layer}.", "model.")
return name
class DeepSeekMTP(nn.Module):
    """DeepSeek V3 Multi-Token Prediction draft model.

    Combines the target model's hidden states with the token embeddings and
    runs them through one extra decoder layer to propose the next token.
    MTP layers are stored in a ``ModuleDict`` keyed by their checkpoint
    layer index (as a string), starting at ``config.num_hidden_layers``.
    """

    def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
        super().__init__()
        hf_config = vllm_config.model_config.hf_config
        cache_config = vllm_config.cache_config
        quant_config = vllm_config.quant_config
        self.config = hf_config

        self.embed_tokens = VocabParallelEmbedding(
            hf_config.vocab_size,
            hf_config.hidden_size,
        )

        self.mtp_start_layer_idx = hf_config.num_hidden_layers
        mtp_layer_count = getattr(hf_config, "num_nextn_predict_layers", 1)

        self.layers = nn.ModuleDict()
        for offset in range(mtp_layer_count):
            layer_idx = self.mtp_start_layer_idx + offset
            self.layers[str(layer_idx)] = DeepSeekMultiTokenPredictorLayer(
                config=hf_config,
                layer_idx=layer_idx,
                cache_config=cache_config,
                quant_config=quant_config,
                prefix=f"model.layers.{layer_idx}",
            )

        self.logits_processor = LogitsProcessor(hf_config.vocab_size)
        self.sampler = get_sampler()

    def forward(
        self,
        input_ids: torch.Tensor,
        positions: torch.Tensor,
        kv_caches: List[torch.Tensor],
        attn_metadata: AttentionMetadata,
        previous_hidden_states: torch.Tensor,
        intermediate_tensors: Optional[IntermediateTensors] = None,
    ) -> torch.Tensor:
        """Embed ``input_ids`` and run the first MTP layer.

        ``intermediate_tensors`` is accepted but not used.
        """
        token_embeds = self.embed_tokens(input_ids)
        # Use the first MTP layer (DeepSeek V3 only has 1)
        first_layer = self.layers[str(self.mtp_start_layer_idx)]
        return first_layer(
            input_ids, positions, previous_hidden_states,
            kv_caches, attn_metadata, token_embeds)

    def compute_logits(
        self,
        hidden_states: torch.Tensor,
        sampling_metadata: SamplingMetadata,
    ) -> Optional[torch.Tensor]:
        """Normalize with the first layer's shared head, then get logits."""
        shared_head = self.layers[str(self.mtp_start_layer_idx)].shared_head
        normed = shared_head(hidden_states)
        return self.logits_processor(
            shared_head.head, normed, sampling_metadata)

    def sample(
        self,
        logits: Optional[torch.Tensor],
        sampling_metadata: SamplingMetadata,
    ) -> Optional[SamplerOutput]:
        """Sample next tokens from ``logits`` with the model's sampler."""
        return self.sampler(logits, sampling_metadata)

    def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
        """Load the MTP-layer weights from a full checkpoint stream.

        Weights that do not belong to an MTP layer are skipped. The rest are
        renamed to this module's layout before being loaded.
        """
        # MLU SparseMoeMlp needs pack_params() before loading
        try:
            from vllm_mlu.model_executor.layers.sparse_moe_mlp import (
                SparseMoeMlp)
            for _, module in self.named_modules():
                if isinstance(module, SparseMoeMlp):
                    module.pack_params()
        except ImportError:
            # vllm_mlu is not installed: nothing to pack.
            pass

        stacked_params_mapping = [
            # (param_name, shard_name, shard_id)
            ("gate_up_proj", "gate_proj", 0),
            ("gate_up_proj", "up_proj", 1),
        ]
        params_dict = dict(self.named_parameters())
        top_level_prefix = "model."

        for ckpt_name, loaded_weight in weights:
            if "rotary_emb.inv_freq" in ckpt_name:
                continue

            # Only load MTP layer weights
            spec_layer = get_spec_layer_idx_from_weight_name(
                self.config, ckpt_name)
            if spec_layer is None:
                continue

            # Rewrite weight name for MTP structure
            name = _rewrite_spec_layer_name(
                self.config, spec_layer, ckpt_name)

            # Only load shared weights (embed_tokens) from first
            # MTP layer, per DeepSeek V3 Technical Report
            is_first_mtp_layer = spec_layer == self.mtp_start_layer_idx
            if not is_first_mtp_layer and ".layers" not in name:
                continue

            # Strip "model." prefix since DeepSeekMTP holds
            # embed_tokens and layers directly (no .model wrapper)
            if name.startswith(top_level_prefix):
                name = name[len(top_level_prefix):]

            self._load_single_weight(
                name, loaded_weight, stacked_params_mapping,
                params_dict)

    def _load_single_weight(
        self,
        name: str,
        loaded_weight: torch.Tensor,
        stacked_params_mapping: List[Tuple[str, str, int]],
        params_dict: dict,
    ):
        """Load one tensor, routing stacked shards to their fused parameter.

        Any name that has no entry in ``params_dict`` (expert weights,
        biases, or anything else this module does not own) is silently
        skipped.
        """
        for fused_name, shard_name, shard_id in stacked_params_mapping:
            if shard_name not in name:
                continue
            target_name = name.replace(shard_name, fused_name)
            if target_name not in params_dict:
                return
            param = params_dict[target_name]
            param.weight_loader(param, loaded_weight, shard_id)
            return

        # Non-stacked weights
        if name not in params_dict:
            return
        param = params_dict[name]
        load = getattr(param, "weight_loader", default_weight_loader)
        load(param, loaded_weight)

View File

@@ -611,3 +611,7 @@ class DeepseekV2ForCausalLM(nn.Module, SupportsPP):
weight_loader = getattr(param, "weight_loader",
default_weight_loader)
weight_loader(param, loaded_weight)
class DeepseekV3ForCausalLM(DeepseekV2ForCausalLM):
    """DeepSeek V3 causal LM; inherits the DeepSeek V2 implementation as-is."""
    pass

View File

@@ -0,0 +1,580 @@
# Copyright 2025 the LLAMA4, Meta Inc., vLLM, and HuggingFace Inc. team.
# All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Inference-only Llama4 model compatible with HuggingFace weights."""
import re
from typing import Any, Dict, Iterable, List, Optional, Tuple, Union
import torch
from torch import nn
from vllm.attention import Attention, AttentionMetadata
from vllm.config import CacheConfig, VllmConfig
from vllm.distributed import (get_pp_group,
get_tensor_model_parallel_world_size,
tensor_model_parallel_all_reduce)
from vllm.logger import init_logger
from vllm.model_executor.layers.fused_moe import FusedMoE
from vllm.model_executor.layers.layernorm import RMSNorm
from vllm.model_executor.layers.linear import (QKVParallelLinear,
ReplicatedLinear,
RowParallelLinear)
from vllm.model_executor.layers.logits_processor import LogitsProcessor
from vllm.model_executor.layers.quantization import QuantizationConfig
from vllm.model_executor.layers.rotary_embedding import get_rope
from vllm.model_executor.layers.sampler import SamplerOutput, get_sampler
from vllm.model_executor.layers.vocab_parallel_embedding import (
DEFAULT_VOCAB_PADDING_SIZE, ParallelLMHead, VocabParallelEmbedding)
from vllm.model_executor.model_loader.weight_utils import (
default_weight_loader, maybe_remap_kv_scale_name)
from vllm.model_executor.sampling_metadata import SamplingMetadata
from vllm.sequence import IntermediateTensors
from .interfaces import SupportsPP
from .llama import LlamaMLP
from .utils import (AutoWeightsLoader, PPMissingLayer, is_pp_missing_parameter,
make_empty_intermediate_tensors_factory, make_layers,
maybe_prefix)
logger = init_logger(__name__)
def _extract_layer_index(prefix: str) -> int:
"""Extract layer index from prefix string like 'model.layers.0.self_attn'."""
match = re.search(r'layers\.(\d+)', prefix)
if match is None:
raise ValueError(f"Cannot extract layer index from prefix: {prefix}")
return int(match.group(1))
class Llama4MoE(nn.Module):
    """Llama4 mixture-of-experts feed-forward with an always-on shared expert.

    The output is the sum of the routed experts and the shared expert,
    all-reduced across tensor-parallel ranks when ``tp_size > 1``.
    """

    @staticmethod
    def custom_routing_function(
        hidden_states: torch.Tensor,
        gating_output: torch.Tensor,
        topk: int,
        renormalize: bool,
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """Pick the top-k experts and weight them with a sigmoid.

        ``hidden_states`` and ``renormalize`` are part of the routing
        callback signature but are not used here.

        Returns:
            ``(weights, expert_ids)``: sigmoid of the top-k logits computed
            in float, and the matching expert indices as int32.
        """
        top_logits, top_ids = torch.topk(gating_output, topk, dim=-1)
        weights = torch.sigmoid(top_logits.float())
        return (weights, top_ids.to(torch.int32))

    def __init__(
        self,
        config,
        quant_config: Optional[QuantizationConfig] = None,
        prefix: str = "",
    ) -> None:
        super().__init__()
        self.tp_size = get_tensor_model_parallel_world_size()
        self.top_k = getattr(config, "num_experts_per_tok", 1)
        self.num_local_experts = getattr(config, "num_local_experts", 8)
        self.hidden_size = getattr(config, "hidden_size", 4096)
        expert_width = getattr(config, "intermediate_size", 8192)

        # The router is replicated on every rank and never quantized.
        self.router = ReplicatedLinear(
            self.hidden_size,
            self.num_local_experts,
            bias=False,
            quant_config=None,
            prefix=f"{prefix}.router",
        )

        # reduce_results=False: the all-reduce happens once in forward(),
        # after the shared expert output has been added.
        self.experts = FusedMoE(
            num_experts=self.num_local_experts,
            top_k=self.top_k,
            hidden_size=self.hidden_size,
            intermediate_size=expert_width,
            reduce_results=False,
            renormalize=False,
            quant_config=quant_config,
            custom_routing_function=Llama4MoE.custom_routing_function,
            prefix=f"{prefix}.experts",
        )

        self.shared_expert = LlamaMLP(
            hidden_size=self.hidden_size,
            intermediate_size=expert_width,
            hidden_act="silu",
            quant_config=quant_config,
            bias=False,
            prefix=f"{prefix}.shared_expert",
        )

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        """Apply routed + shared experts; the input shape is preserved."""
        input_shape = hidden_states.shape
        tokens = hidden_states.view(-1, self.hidden_size)

        # router_logits: (num_tokens, n_experts)
        router_logits, _ = self.router(tokens)

        routed_out = self.experts(tokens, router_logits)
        shared_out = self.shared_expert(tokens)

        # combine and all-reduce
        combined = routed_out + shared_out
        if self.tp_size > 1:
            combined = tensor_model_parallel_all_reduce(combined)
        return combined.view(input_shape)
class Llama4Attention(nn.Module):
    """Llama4 self-attention with per-layer RoPE / NoPE handling.

    A layer is treated as NoPE when ``config.no_rope_layers`` exists and its
    entry for this layer is ``0``. NoPE layers get no rotary embedding and no
    QK norm, and may scale the query by a position-dependent temperature.
    """

    def __init__(
        self,
        config,
        hidden_size: int,
        num_heads: int,
        num_kv_heads: int,
        max_position_embeddings: int = 8192,
        quant_config: Optional[QuantizationConfig] = None,
        bias: bool = False,
        cache_config: Optional[CacheConfig] = None,
        prefix: str = "",
    ) -> None:
        super().__init__()
        self.layer_idx = _extract_layer_index(prefix)
        self.hidden_size = hidden_size
        self.no_rope_layers = getattr(config, "no_rope_layers", None)
        self.nope = (self.no_rope_layers is not None
                     and self.no_rope_layers[self.layer_idx] == 0)
        self.use_qk_norm = (getattr(config, "use_qk_norm", False)
                            and not self.nope)

        # Split query / KV heads across tensor-parallel ranks.
        tp_size = get_tensor_model_parallel_world_size()
        self.total_num_heads = num_heads
        assert self.total_num_heads % tp_size == 0
        self.num_heads = self.total_num_heads // tp_size
        self.total_num_kv_heads = num_kv_heads
        if self.total_num_kv_heads >= tp_size:
            # KV heads are partitioned across ranks.
            assert self.total_num_kv_heads % tp_size == 0
        else:
            # Fewer KV heads than ranks: each rank keeps at least one.
            assert tp_size % self.total_num_kv_heads == 0
        self.num_kv_heads = max(1, self.total_num_kv_heads // tp_size)

        self.head_dim = getattr(config, "head_dim",
                                self.hidden_size // self.total_num_heads)
        self.q_size = self.num_heads * self.head_dim
        self.kv_size = self.num_kv_heads * self.head_dim
        self.scaling = self.head_dim**-0.5
        self.max_position_embeddings = max_position_embeddings

        # Temperature tuning for NoPE layers
        self.attn_temperature_tuning = (
            self.nope and getattr(config, "attn_temperature_tuning", False))
        self.floor_scale = getattr(config, "floor_scale", 8192.0)
        self.attn_scale = getattr(config, "attn_scale", 0.1)

        # QK norm
        norm_eps = getattr(config, "rms_norm_eps", 1e-5)
        if self.use_qk_norm:
            self.qk_norm = RMSNorm(self.head_dim, eps=norm_eps)
            # v0.6.2 RMSNorm doesn't support has_weight=False,
            # so we set weight to ones and make it non-trainable
            self.qk_norm.weight.data.fill_(1.0)
            self.qk_norm.weight.requires_grad = False
        else:
            self.qk_norm = None

        self.qkv_proj = QKVParallelLinear(
            hidden_size=hidden_size,
            head_size=self.head_dim,
            total_num_heads=self.total_num_heads,
            total_num_kv_heads=self.total_num_kv_heads,
            bias=bias,
            quant_config=quant_config,
            prefix=f"{prefix}.qkv_proj",
        )
        self.o_proj = RowParallelLinear(
            input_size=self.total_num_heads * self.head_dim,
            output_size=hidden_size,
            bias=bias,
            quant_config=quant_config,
            prefix=f"{prefix}.o_proj",
        )

        # RoPE (None for NoPE layers)
        rope_theta = getattr(config, "rope_theta", 10000)
        rope_scaling = getattr(config, "rope_scaling", None)
        if self.nope:
            self.rotary_emb = None
        else:
            self.rotary_emb = get_rope(
                self.head_dim,
                rotary_dim=self.head_dim,
                max_position=max_position_embeddings,
                base=rope_theta,
                rope_scaling=rope_scaling,
                is_neox_style=True,
            )

        self.attn = Attention(
            self.num_heads,
            self.head_dim,
            self.scaling,
            num_kv_heads=self.num_kv_heads,
            cache_config=cache_config,
            quant_config=quant_config,
        )

    def _get_attn_scale(self, positions: torch.Tensor) -> torch.Tensor:
        """Return the per-position query scale with a trailing unit dim.

        Computes ``log(floor((pos + 1) / floor_scale) + 1) * attn_scale + 1``.
        """
        bucket = torch.floor((positions + 1.0) / self.floor_scale)
        scale = torch.log(bucket + 1.0) * self.attn_scale + 1.0
        return scale.unsqueeze(-1)

    def forward(
        self,
        positions: torch.Tensor,
        hidden_states: torch.Tensor,
        kv_cache: torch.Tensor,
        attn_metadata: AttentionMetadata,
    ) -> torch.Tensor:
        """Project to QKV, apply RoPE / QK norm / scaling, then attend."""
        qkv, _ = self.qkv_proj(hidden_states)
        q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size],
                            dim=-1)

        if self.rotary_emb is not None:
            q, k = self.rotary_emb(positions, q, k)

        if self.qk_norm is not None:
            # Normalize per head in float, then restore layout and dtype.
            q_heads = q.reshape(-1, self.head_dim)
            q = self.qk_norm(q_heads.float()).reshape(
                -1, self.q_size).to(q_heads.dtype)
            k_heads = k.reshape(-1, self.head_dim)
            k = self.qk_norm(k_heads.float()).reshape(
                -1, self.kv_size).to(k_heads.dtype)

        if self.attn_temperature_tuning and self.nope:
            query_scale = self._get_attn_scale(positions)
            q = (q * query_scale).to(q.dtype)

        attn_output = self.attn(q, k, v, kv_cache, attn_metadata)
        output, _ = self.o_proj(attn_output)
        return output
class Llama4DecoderLayer(nn.Module):
    """One Llama4 transformer layer: attention plus a dense or MoE FFN.

    The feed-forward is a ``Llama4MoE`` when
    ``config.interleave_moe_layer_step`` is positive and
    ``(layer_idx + 1)`` is a multiple of it; otherwise it is a dense
    ``LlamaMLP``.
    """

    def __init__(
        self,
        config,
        cache_config: Optional[CacheConfig] = None,
        quant_config: Optional[QuantizationConfig] = None,
        prefix: str = "",
    ) -> None:
        super().__init__()
        self.layer_idx = _extract_layer_index(prefix)
        self.hidden_size = getattr(config, "hidden_size", 4096)
        max_position_embeddings = getattr(config, "max_position_embeddings",
                                          8192)

        attention_heads = getattr(config, "num_attention_heads", 32)
        # KV head count falls back to the attention head count.
        kv_heads = getattr(config, "num_key_value_heads",
                           getattr(config, "num_attention_heads", 32))
        self.self_attn = Llama4Attention(
            config=config,
            hidden_size=self.hidden_size,
            num_heads=attention_heads,
            num_kv_heads=kv_heads,
            max_position_embeddings=max_position_embeddings,
            quant_config=quant_config,
            bias=False,
            cache_config=cache_config,
            prefix=f"{prefix}.self_attn",
        )

        # Interleaved MoE/dense layers
        moe_step = getattr(config, "interleave_moe_layer_step", 0)
        is_moe_layer = (moe_step > 0
                        and (self.layer_idx + 1) % moe_step == 0)
        if is_moe_layer:
            self.feed_forward = Llama4MoE(
                config=config,
                quant_config=quant_config,
                prefix=f"{prefix}.feed_forward",
            )
        else:
            dense_width = getattr(config, "intermediate_size_mlp",
                                  getattr(config, "intermediate_size", 8192))
            self.feed_forward = LlamaMLP(
                hidden_size=self.hidden_size,
                intermediate_size=dense_width,
                hidden_act="silu",
                quant_config=quant_config,
                bias=False,
                prefix=f"{prefix}.feed_forward",
            )

        norm_eps = getattr(config, "rms_norm_eps", 1e-5)
        self.input_layernorm = RMSNorm(self.hidden_size, eps=norm_eps)
        self.post_attention_layernorm = RMSNorm(self.hidden_size,
                                                eps=norm_eps)

    def forward(
        self,
        positions: torch.Tensor,
        hidden_states: torch.Tensor,
        kv_cache: torch.Tensor,
        attn_metadata: AttentionMetadata,
        residual: Optional[torch.Tensor],
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """Run the layer and return ``(hidden_states, residual)``.

        Pass ``residual=None`` for the first layer; the input then becomes
        the residual.
        """
        # Self Attention
        if residual is not None:
            hidden_states, residual = self.input_layernorm(
                hidden_states, residual)
        else:
            residual = hidden_states
            hidden_states = self.input_layernorm(hidden_states)
        hidden_states = self.self_attn(positions=positions,
                                       hidden_states=hidden_states,
                                       kv_cache=kv_cache,
                                       attn_metadata=attn_metadata)

        # Fully Connected
        hidden_states, residual = self.post_attention_layernorm(
            hidden_states, residual)
        hidden_states = self.feed_forward(hidden_states)
        return hidden_states, residual
class Llama4Model(nn.Module):
    """Llama4 decoder stack (embeddings, layers, final norm).

    Implemented independently to avoid the pad_token_id issue:
    ``pad_token_id`` is read defensively because the config may lack it.
    Modules that belong to another pipeline-parallel rank are replaced by
    ``PPMissingLayer`` placeholders.
    """

    def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
        super().__init__()
        config = vllm_config.model_config.hf_config
        cache_config = vllm_config.cache_config
        quant_config = vllm_config.quant_config
        lora_config = vllm_config.lora_config

        self.config = config
        # Defensive access - Llama4Config may not have pad_token_id
        self.padding_idx = getattr(config, "pad_token_id", None)

        # Extra vocabulary rows reserved for LoRA adapters, if any.
        if lora_config:
            lora_vocab = (lora_config.lora_extra_vocab_size *
                          (lora_config.max_loras or 1))
        else:
            lora_vocab = 0
        self.vocab_size = config.vocab_size + lora_vocab
        self.org_vocab_size = config.vocab_size

        # The last rank also needs the embedding when weights are tied.
        if get_pp_group().is_first_rank or (
                getattr(config, "tie_word_embeddings", False)
                and get_pp_group().is_last_rank):
            self.embed_tokens = VocabParallelEmbedding(
                self.vocab_size,
                config.hidden_size,
                org_num_embeddings=config.vocab_size,
                quant_config=quant_config,
            )
        else:
            self.embed_tokens = PPMissingLayer()

        self.start_layer, self.end_layer, self.layers = make_layers(
            config.num_hidden_layers,
            lambda prefix: Llama4DecoderLayer(
                config=config,
                cache_config=cache_config,
                quant_config=quant_config,
                prefix=prefix),
            prefix=f"{prefix}.layers",
        )

        norm_eps = getattr(config, "rms_norm_eps", 1e-5)
        if get_pp_group().is_last_rank:
            self.norm = RMSNorm(config.hidden_size, eps=norm_eps)
        else:
            self.norm = PPMissingLayer()

        self.make_empty_intermediate_tensors = (
            make_empty_intermediate_tensors_factory(
                ["hidden_states", "residual"], config.hidden_size))

    def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor:
        """Look up the token embeddings for ``input_ids``."""
        return self.embed_tokens(input_ids)

    def forward(
        self,
        input_ids: Optional[torch.Tensor],
        positions: torch.Tensor,
        kv_caches: List[torch.Tensor],
        attn_metadata: AttentionMetadata,
        intermediate_tensors: Optional[IntermediateTensors],
        inputs_embeds: Optional[torch.Tensor] = None,
    ) -> Union[torch.Tensor, IntermediateTensors]:
        """Run this rank's layers.

        Returns the final normalized hidden states on the last
        pipeline rank, otherwise an ``IntermediateTensors`` carrying
        ``hidden_states`` and ``residual`` for the next rank.
        """
        if get_pp_group().is_first_rank:
            hidden_states = (inputs_embeds if inputs_embeds is not None else
                             self.get_input_embeddings(input_ids))
            residual = None
        else:
            assert intermediate_tensors is not None
            hidden_states = intermediate_tensors["hidden_states"]
            residual = intermediate_tensors["residual"]

        for layer_idx in range(self.start_layer, self.end_layer):
            decoder_layer = self.layers[layer_idx]
            # kv_caches is indexed relative to this rank's first layer.
            hidden_states, residual = decoder_layer(
                positions, hidden_states,
                kv_caches[layer_idx - self.start_layer],
                attn_metadata, residual)

        if not get_pp_group().is_last_rank:
            return IntermediateTensors({
                "hidden_states": hidden_states,
                "residual": residual
            })

        hidden_states, _ = self.norm(hidden_states, residual)
        return hidden_states
class Llama4ForCausalLM(nn.Module, SupportsPP):
    """Llama4 text model with LM head, logits processor and sampler.

    Also accepts a top-level (multimodal) config carrying a ``text_config``
    sub-config, and checkpoints whose text weights are prefixed with
    ``language_model.``.
    """

    packed_modules_mapping = {
        "qkv_proj": ["q_proj", "k_proj", "v_proj"],
        "gate_up_proj": ["gate_proj", "up_proj"],
    }

    def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
        """Build the backbone and, on the last PP rank, the output head."""
        super().__init__()
        config = vllm_config.model_config.hf_config

        # Llama4ForConditionalGeneration uses a top-level Llama4Config with a
        # text_config sub-config; the text model is built from the latter.
        text_config = getattr(config, "text_config", None)
        if text_config is not None:
            top_level_archs = getattr(config, "architectures", None)
            # NOTE: this rebinds hf_config on the vllm_config object that was
            # passed in, so Llama4Model below reads the text config.
            vllm_config.model_config.hf_config = text_config
            if top_level_archs and not getattr(text_config, "architectures",
                                               None):
                text_config.architectures = top_level_archs
            config = text_config

        quant_config = vllm_config.quant_config
        lora_config = vllm_config.lora_config
        self.config = config
        self.lora_config = lora_config
        self.model = Llama4Model(vllm_config=vllm_config,
                                 prefix=maybe_prefix(prefix, "model"))

        if not get_pp_group().is_last_rank:
            self.lm_head = PPMissingLayer()
        else:
            unpadded_vocab = config.vocab_size
            if lora_config:
                unpadded_vocab += lora_config.lora_extra_vocab_size
            self.unpadded_vocab_size = unpadded_vocab

            if lora_config:
                vocab_padding = lora_config.lora_vocab_padding_size
            else:
                vocab_padding = DEFAULT_VOCAB_PADDING_SIZE
            self.lm_head = ParallelLMHead(
                self.unpadded_vocab_size,
                config.hidden_size,
                org_num_embeddings=config.vocab_size,
                padding_size=vocab_padding,
                quant_config=quant_config,
                prefix=maybe_prefix(prefix, "lm_head"),
            )
            if getattr(config, "tie_word_embeddings", False):
                self.lm_head = self.lm_head.tie_weights(
                    self.model.embed_tokens)

            self.logits_processor = LogitsProcessor(
                self.unpadded_vocab_size,
                config.vocab_size,
                getattr(config, "logit_scale", 1.0))
            self.sampler = get_sampler()

        self.make_empty_intermediate_tensors = (
            self.model.make_empty_intermediate_tensors)

    def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor:
        """Delegate the embedding lookup to the backbone."""
        return self.model.get_input_embeddings(input_ids)

    def forward(
        self,
        input_ids: torch.Tensor,
        positions: torch.Tensor,
        kv_caches: List[torch.Tensor],
        attn_metadata: AttentionMetadata,
        intermediate_tensors: Optional[IntermediateTensors] = None,
        inputs_embeds: Optional[torch.Tensor] = None,
    ) -> Union[torch.Tensor, IntermediateTensors]:
        """Run the backbone and return its output unchanged."""
        return self.model(input_ids, positions, kv_caches, attn_metadata,
                          intermediate_tensors, inputs_embeds)

    def compute_logits(
        self,
        hidden_states: torch.Tensor,
        sampling_metadata: SamplingMetadata,
    ) -> Optional[torch.Tensor]:
        """Project hidden states to logits through the logits processor."""
        return self.logits_processor(self.lm_head, hidden_states,
                                     sampling_metadata)

    def sample(self, logits: torch.Tensor,
               sampling_metadata: SamplingMetadata) -> Optional[SamplerOutput]:
        """Sample next tokens from ``logits``."""
        return self.sampler(logits, sampling_metadata)

    def permute_qk_weight_for_rotary(
        self,
        name: str,
        loaded_weight: torch.Tensor,
    ) -> Tuple[str, torch.Tensor]:
        """Permute Q/K projection weights for rotary embedding compatibility.

        Only tensors whose dotted name ends in ``weight`` and contains a
        ``k_proj`` or ``q_proj`` component are touched; everything else is
        returned as-is. The name is never modified.
        """
        components = name.split(".")
        if components[-1] != "weight":
            return name, loaded_weight

        if "k_proj" in components:
            head_count = getattr(self.config, "num_key_value_heads", 8)
        elif "q_proj" in components:
            head_count = getattr(self.config, "num_attention_heads", 32)
        else:
            return name, loaded_weight

        out_rows = getattr(self.config, "head_dim", 128) * head_count
        in_cols = getattr(self.config, "hidden_size", 4096)
        # Within each head, split the rows into (half, 2) and swap those two
        # axes before flattening back to the original 2D shape.
        per_head = loaded_weight.contiguous().view(
            head_count, out_rows // head_count // 2, 2, in_cols)
        permuted = per_head.transpose(1, 2).reshape(out_rows, in_cols)
        return name, permuted

    def load_weights(
        self, weights: Iterable[Tuple[str, torch.Tensor]],
    ):
        """Load checkpoint weights, adapting multimodal checkpoint names.

        ``language_model.`` prefixes are stripped, vision/projector weights
        are dropped, and Q/K weights are permuted before being handed to
        ``AutoWeightsLoader``. ``lm_head.`` weights are skipped when the
        config ties word embeddings.
        """
        tied = getattr(self.config, "tie_word_embeddings", False)
        loader = AutoWeightsLoader(
            self,
            skip_prefixes=["lm_head."] if tied else None,
        )

        text_prefix = "language_model."
        vision_prefixes = ("multi_modal_projector.", "vision_encoder.",
                           "vision_model.")

        def _adapt(stream):
            for name, tensor in stream:
                if name.startswith(text_prefix):
                    # Strip the Llama4ForConditionalGeneration wrapper prefix.
                    name = name[len(text_prefix):]
                elif name.startswith(vision_prefixes):
                    # Vision encoder / projector weights are not used here.
                    continue
                yield self.permute_qk_weight_for_rotary(name, tensor)

        loader.load_weights(_adapt(weights))

View File

@@ -0,0 +1,556 @@
# Adapted from
# https://github.com/huggingface/transformers/blob/v4.28.0/src/transformers/models/qwen2_moe/modeling_qwen2_moe.py
# Copyright 2024 The Qwen team.
# Copyright 2023 The vLLM team.
# Copyright 2022 EleutherAI and the HuggingFace Inc. team. All rights reserved.
#
# This code is based on EleutherAI's GPT-NeoX library and the GPT-NeoX
# and OPT implementations in this library. It has been modified from its
# original forms to accommodate minor architectural differences compared
# to GPT-NeoX and OPT used by the Meta AI team that trained the model.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Inference-only Qwen3MoE model compatible with HuggingFace weights."""
from typing import Any, Dict, Iterable, List, Optional, Tuple, Union
import torch
import torch.nn.functional as F
from torch import nn
from transformers import PretrainedConfig
from vllm.attention import Attention, AttentionMetadata
from vllm.compilation.decorators import support_torch_compile
from vllm.config import CacheConfig, VllmConfig
from vllm.distributed import (get_pp_group,
get_tensor_model_parallel_world_size,
tensor_model_parallel_all_reduce)
from vllm.model_executor.layers.activation import SiluAndMul
from vllm.model_executor.layers.fused_moe import FusedMoE
from vllm.model_executor.layers.layernorm import RMSNorm
from vllm.model_executor.layers.linear import (MergedColumnParallelLinear,
QKVParallelLinear,
ReplicatedLinear,
RowParallelLinear)
from vllm.model_executor.layers.logits_processor import LogitsProcessor
from vllm.model_executor.layers.quantization import QuantizationConfig
from vllm.model_executor.layers.rotary_embedding import get_rope
from vllm.model_executor.layers.sampler import SamplerOutput, get_sampler
from vllm.model_executor.layers.vocab_parallel_embedding import (
ParallelLMHead, VocabParallelEmbedding)
from vllm.model_executor.model_loader.weight_utils import default_weight_loader
from vllm.model_executor.sampling_metadata import SamplingMetadata
from vllm.sequence import IntermediateTensors
from vllm.utils import print_warning_once
from .interfaces import SupportsPP
from .utils import (is_pp_missing_parameter,
make_empty_intermediate_tensors_factory, make_layers,
maybe_prefix)
class Qwen3MoeMLP(nn.Module):
    """Dense gated MLP: merged gate/up projection, SiLU-and-mul, down projection."""

    def __init__(
        self,
        hidden_size: int,
        intermediate_size: int,
        hidden_act: str,
        quant_config: Optional[QuantizationConfig] = None,
        reduce_results: bool = True,
    ) -> None:
        """Create the projections.

        Raises:
            ValueError: if ``hidden_act`` is anything other than ``"silu"``.
        """
        super().__init__()
        # Gate and up projections share one merged column-parallel layer.
        merged_sizes = [intermediate_size, intermediate_size]
        self.gate_up_proj = MergedColumnParallelLinear(
            hidden_size,
            merged_sizes,
            bias=False,
            quant_config=quant_config,
        )
        self.down_proj = RowParallelLinear(
            intermediate_size,
            hidden_size,
            bias=False,
            quant_config=quant_config,
            reduce_results=reduce_results,
        )
        if hidden_act != "silu":
            raise ValueError(f"Unsupported activation: {hidden_act}. "
                             "Only silu is supported for now.")
        self.act_fn = SiluAndMul()

    def forward(self, x):
        """Apply gate/up projection, activation and down projection to ``x``."""
        merged, _ = self.gate_up_proj(x)
        activated = self.act_fn(merged)
        projected, _ = self.down_proj(activated)
        return projected
class Qwen3MoeSparseMoeBlock(nn.Module):
    """Sparse MoE feed-forward block with an optional gated shared expert."""

    def __init__(
        self,
        config: PretrainedConfig,
        quant_config: Optional[QuantizationConfig] = None,
    ):
        """Create the router, fused experts and (optionally) a shared expert.

        Raises:
            ValueError: if the tensor-parallel world size exceeds
                ``config.num_experts``.
        """
        super().__init__()
        self.tp_size = get_tensor_model_parallel_world_size()
        if self.tp_size > config.num_experts:
            raise ValueError(
                f"Tensor parallel size {self.tp_size} is greater than "
                f"the number of experts {config.num_experts}.")

        # reduce_results=False: the all-reduce is done once in forward().
        self.experts = FusedMoE(
            num_experts=config.num_experts,
            top_k=config.num_experts_per_tok,
            hidden_size=config.hidden_size,
            intermediate_size=config.moe_intermediate_size,
            reduce_results=False,
            renormalize=config.norm_topk_prob,
            quant_config=quant_config,
        )
        self.gate = ReplicatedLinear(
            config.hidden_size,
            config.num_experts,
            bias=False,
            quant_config=None,
        )

        shared_size = getattr(config, "shared_expert_intermediate_size", 0)
        if shared_size <= 0:
            self.shared_expert = None
        else:
            self.shared_expert = Qwen3MoeMLP(
                hidden_size=config.hidden_size,
                intermediate_size=shared_size,
                hidden_act=config.hidden_act,
                quant_config=quant_config,
                reduce_results=False,
            )
        # Qwen3Moe uses ReplicatedLinear for shared_expert_gate
        # (unlike Qwen2Moe which uses torch.nn.Linear). It is created
        # whether or not a shared expert exists.
        self.shared_expert_gate = ReplicatedLinear(
            config.hidden_size,
            1,
            bias=False,
            quant_config=None,
        )

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        """Route tokens through the experts and return the input's shape."""
        # NOTE: hidden_states can have either 1D or 2D shape.
        input_shape = hidden_states.shape
        tokens = hidden_states.view(-1, input_shape[-1])

        shared_out = None
        if self.shared_expert is not None:
            shared_out = self.shared_expert(tokens)
            if self.shared_expert_gate is not None:
                gate_logits = self.shared_expert_gate(tokens)[0]
                shared_out = F.sigmoid(gate_logits) * shared_out

        # router_logits: (num_tokens, n_experts)
        router_logits, _ = self.gate(tokens)
        combined = self.experts(hidden_states=tokens,
                                router_logits=router_logits)
        if shared_out is not None:
            combined = combined + shared_out

        # Both branches were built with reduce_results=False, so a single
        # all-reduce covers their sum.
        if self.tp_size > 1:
            combined = tensor_model_parallel_all_reduce(combined)
        return combined.view(input_shape)
class Qwen3MoeAttention(nn.Module):
    """Tensor-parallel self-attention with per-head Q/K RMS normalization."""

    def __init__(
        self,
        hidden_size: int,
        num_heads: int,
        num_kv_heads: int,
        rope_theta: float = 10000,
        rope_scaling: Optional[Dict[str, Any]] = None,
        max_position_embeddings: int = 8192,
        rms_norm_eps: float = 1e-06,
        qkv_bias: bool = False,
        cache_config: Optional[CacheConfig] = None,
        quant_config: Optional[QuantizationConfig] = None,
        head_dim: Optional[int] = None,
    ) -> None:
        """Create the projections, rotary embedding, attention and QK norms.

        Args:
            hidden_size: model hidden size.
            num_heads: total number of query heads (before TP sharding).
            num_kv_heads: total number of key/value heads (before sharding).
            rope_theta: rotary embedding base.
            rope_scaling: optional rotary scaling configuration.
            max_position_embeddings: maximum position for the rotary table.
            rms_norm_eps: epsilon for the per-head Q/K RMSNorm.
            qkv_bias: whether the QKV projection has a bias.
            cache_config: passed through to ``Attention``.
            quant_config: passed through to the linear/attention layers.
            head_dim: explicit per-head dimension. When ``None`` (or 0) it
                falls back to ``hidden_size // num_heads``. Needed for
                configs whose ``head_dim`` differs from that quotient.
        """
        super().__init__()
        self.hidden_size = hidden_size
        tp_size = get_tensor_model_parallel_world_size()
        self.total_num_heads = num_heads
        assert self.total_num_heads % tp_size == 0
        self.num_heads = self.total_num_heads // tp_size
        self.total_num_kv_heads = num_kv_heads
        if self.total_num_kv_heads >= tp_size:
            # Number of KV heads is greater than or equal to TP size, so the
            # KV heads are partitioned across the TP ranks.
            assert self.total_num_kv_heads % tp_size == 0
        else:
            # Fewer KV heads than TP ranks: each rank keeps one KV head.
            assert tp_size % self.total_num_kv_heads == 0
        self.num_kv_heads = max(1, self.total_num_kv_heads // tp_size)
        # Honour an explicit head_dim; deriving it from hidden_size alone
        # gives wrongly shaped projections when the config decouples them.
        self.head_dim = head_dim or (hidden_size // self.total_num_heads)
        self.q_size = self.num_heads * self.head_dim
        self.kv_size = self.num_kv_heads * self.head_dim
        self.scaling = self.head_dim**-0.5
        self.rope_theta = rope_theta
        self.max_position_embeddings = max_position_embeddings

        self.qkv_proj = QKVParallelLinear(
            hidden_size,
            self.head_dim,
            self.total_num_heads,
            self.total_num_kv_heads,
            bias=qkv_bias,
            quant_config=quant_config,
        )
        self.o_proj = RowParallelLinear(
            self.total_num_heads * self.head_dim,
            hidden_size,
            bias=False,
            quant_config=quant_config,
        )
        self.rotary_emb = get_rope(
            self.head_dim,
            rotary_dim=self.head_dim,
            max_position=max_position_embeddings,
            base=rope_theta,
            rope_scaling=rope_scaling,
        )
        self.attn = Attention(self.num_heads,
                              self.head_dim,
                              self.scaling,
                              num_kv_heads=self.num_kv_heads,
                              cache_config=cache_config,
                              quant_config=quant_config)
        # Qwen3 specific: QK normalization
        self.q_norm = RMSNorm(self.head_dim, eps=rms_norm_eps)
        self.k_norm = RMSNorm(self.head_dim, eps=rms_norm_eps)

    def forward(
        self,
        positions: torch.Tensor,
        hidden_states: torch.Tensor,
        kv_cache: torch.Tensor,
        attn_metadata: AttentionMetadata,
    ) -> torch.Tensor:
        """Project to Q/K/V, normalize Q and K per head, apply rotary, attend."""
        qkv, _ = self.qkv_proj(hidden_states)
        q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size],
                            dim=-1)
        # Qwen3 specific: Apply QK normalization before rotary embedding
        # Use .contiguous() to ensure memory layout is compatible with
        # MLU's RMSNorm which uses .view() internally.
        q_shape = q.shape
        q_by_head = q.view(*q.shape[:-1], q.shape[-1] // self.head_dim,
                           self.head_dim).contiguous()
        q_by_head = self.q_norm(q_by_head)
        q = q_by_head.reshape(q_shape)
        k_shape = k.shape
        k_by_head = k.view(*k.shape[:-1], k.shape[-1] // self.head_dim,
                           self.head_dim).contiguous()
        k_by_head = self.k_norm(k_by_head)
        k = k_by_head.reshape(k_shape)
        # MLU rotary_emb expects a single concatenated 3D tensor, not
        # separate q and k (forward_mlu signature differs from forward_native).
        # Its return value is ignored here: q and k are re-read from ``qk``
        # afterwards, so this relies on rotary_emb updating the view in place.
        qk = torch.cat([q, k], dim=-1)
        self.rotary_emb(positions,
                        qk.view(-1, self.num_heads + self.num_kv_heads,
                                self.head_dim))
        q, k = qk.split([self.q_size, self.kv_size], dim=-1)
        attn_output = self.attn(q, k, v, kv_cache, attn_metadata)
        output, _ = self.o_proj(attn_output)
        return output


class Qwen3MoeDecoderLayer(nn.Module):
    """One decoder layer: self-attention followed by a sparse-MoE or dense MLP."""

    def __init__(
        self,
        config: PretrainedConfig,
        layer_idx: int,
        cache_config: Optional[CacheConfig] = None,
        quant_config: Optional[QuantizationConfig] = None,
    ) -> None:
        """Build the layer; ``layer_idx`` selects between MoE and dense MLP."""
        super().__init__()
        self.hidden_size = config.hidden_size
        rope_theta = getattr(config, "rope_theta", 10000)
        rope_scaling = getattr(config, "rope_scaling", None)
        max_position_embeddings = getattr(config, "max_position_embeddings",
                                          8192)
        self.self_attn = Qwen3MoeAttention(
            hidden_size=self.hidden_size,
            num_heads=config.num_attention_heads,
            num_kv_heads=config.num_key_value_heads,
            rope_theta=rope_theta,
            rope_scaling=rope_scaling,
            max_position_embeddings=max_position_embeddings,
            rms_norm_eps=config.rms_norm_eps,
            qkv_bias=getattr(config, "attention_bias", False),
            cache_config=cache_config,
            quant_config=quant_config,
            # Forward an explicit head_dim when the config has one; None
            # keeps the hidden_size // num_heads fallback.
            head_dim=getattr(config, "head_dim", None),
        )

        # Note: Qwen3MoE may not have `mlp_only_layers` in the config.
        mlp_only_layers = getattr(config, "mlp_only_layers", [])
        # A layer is sparse when it is not listed as MLP-only, experts are
        # configured, and it falls on the decoder_sparse_step cadence.
        if (layer_idx not in mlp_only_layers) and (
                config.num_experts > 0 and
            (layer_idx + 1) % config.decoder_sparse_step == 0):
            self.mlp = Qwen3MoeSparseMoeBlock(config=config,
                                              quant_config=quant_config)
        else:
            self.mlp = Qwen3MoeMLP(
                hidden_size=config.hidden_size,
                intermediate_size=config.intermediate_size,
                hidden_act=config.hidden_act,
                quant_config=quant_config,
            )
        self.input_layernorm = RMSNorm(config.hidden_size,
                                       eps=config.rms_norm_eps)
        self.post_attention_layernorm = RMSNorm(config.hidden_size,
                                                eps=config.rms_norm_eps)

    def forward(
        self,
        positions: torch.Tensor,
        hidden_states: torch.Tensor,
        kv_cache: torch.Tensor,
        attn_metadata: AttentionMetadata,
        residual: Optional[torch.Tensor],
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """Run attention and MLP; return ``(hidden_states, residual)``.

        ``residual`` is ``None`` for the first layer, in which case the
        incoming hidden states become the residual.
        """
        # Self Attention
        if residual is None:
            residual = hidden_states
            hidden_states = self.input_layernorm(hidden_states)
        else:
            hidden_states, residual = self.input_layernorm(
                hidden_states, residual)
        hidden_states = self.self_attn(
            positions=positions,
            hidden_states=hidden_states,
            kv_cache=kv_cache,
            attn_metadata=attn_metadata,
        )

        # Fully Connected
        hidden_states, residual = self.post_attention_layernorm(
            hidden_states, residual)
        hidden_states = self.mlp(hidden_states)
        return hidden_states, residual
@support_torch_compile
class Qwen3MoeModel(nn.Module):
    """Qwen3-MoE backbone: token embedding, decoder layers and final norm."""

    def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
        """Build the embedding, this rank's decoder layers and the final norm."""
        super().__init__()
        hf_config = vllm_config.model_config.hf_config
        cache_config = vllm_config.cache_config
        quant_config = vllm_config.quant_config

        self.padding_idx = hf_config.pad_token_id
        self.vocab_size = hf_config.vocab_size

        self.embed_tokens = VocabParallelEmbedding(
            hf_config.vocab_size,
            hf_config.hidden_size,
        )

        def _build_layer(prefix: str):
            # The last dotted component of the layer prefix is its index.
            layer_idx = int(prefix.split(".")[-1])
            return Qwen3MoeDecoderLayer(config=hf_config,
                                        layer_idx=layer_idx,
                                        cache_config=cache_config,
                                        quant_config=quant_config)

        self.start_layer, self.end_layer, self.layers = make_layers(
            hf_config.num_hidden_layers,
            _build_layer,
            prefix=f"{prefix}.layers",
        )
        self.norm = RMSNorm(hf_config.hidden_size, eps=hf_config.rms_norm_eps)
        self.make_empty_intermediate_tensors = (
            make_empty_intermediate_tensors_factory(
                ["hidden_states", "residual"], hf_config.hidden_size))

    def forward(
        self,
        input_ids: torch.Tensor,
        positions: torch.Tensor,
        kv_caches: List[torch.Tensor],
        attn_metadata: AttentionMetadata,
        intermediate_tensors: Optional[IntermediateTensors] = None,
    ) -> Union[torch.Tensor, IntermediateTensors]:
        """Run this rank's decoder layers.

        The first PP rank embeds ``input_ids``; later ranks resume from
        ``intermediate_tensors``. The last rank returns normed hidden
        states, other ranks return ``IntermediateTensors``.
        """
        if not get_pp_group().is_first_rank:
            assert intermediate_tensors is not None
            hidden_states = intermediate_tensors["hidden_states"]
            residual = intermediate_tensors["residual"]
        else:
            hidden_states = self.embed_tokens(input_ids)
            residual = None

        for layer_idx in range(self.start_layer, self.end_layer):
            # kv_caches only holds entries for this rank's layers.
            hidden_states, residual = self.layers[layer_idx](
                positions,
                hidden_states,
                kv_caches[layer_idx - self.start_layer],
                attn_metadata,
                residual,
            )

        if get_pp_group().is_last_rank:
            hidden_states, _ = self.norm(hidden_states, residual)
            return hidden_states
        return IntermediateTensors({
            "hidden_states": hidden_states,
            "residual": residual
        })
class Qwen3MoeForCausalLM(nn.Module, SupportsPP):
    """Qwen3-MoE causal LM: backbone plus LM head, logits processor, sampler."""

    fall_back_to_pt_during_load = False

    def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
        """Build the backbone and output head from ``vllm_config``."""
        super().__init__()
        config = vllm_config.model_config.hf_config
        quant_config = vllm_config.quant_config
        self.config = config
        self.quant_config = quant_config
        self.model = Qwen3MoeModel(vllm_config=vllm_config,
                                   prefix=maybe_prefix(prefix, "model"))
        self.lm_head = ParallelLMHead(config.vocab_size,
                                      config.hidden_size,
                                      quant_config=quant_config)
        # With tied embeddings the LM head reuses the embedding weight.
        if self.config.tie_word_embeddings:
            self.lm_head.weight = self.model.embed_tokens.weight
        self.logits_processor = LogitsProcessor(config.vocab_size)
        self.sampler = get_sampler()
        self.make_empty_intermediate_tensors = (
            self.model.make_empty_intermediate_tensors)

    def forward(
        self,
        input_ids: torch.Tensor,
        positions: torch.Tensor,
        kv_caches: List[torch.Tensor],
        attn_metadata: AttentionMetadata,
        intermediate_tensors: Optional[IntermediateTensors] = None,
    ) -> Union[torch.Tensor, IntermediateTensors]:
        """Run the backbone and return its output unchanged."""
        hidden_states = self.model(input_ids, positions, kv_caches,
                                   attn_metadata, intermediate_tensors)
        return hidden_states

    def compute_logits(
        self,
        hidden_states: torch.Tensor,
        sampling_metadata: SamplingMetadata,
    ) -> Optional[torch.Tensor]:
        """Project hidden states to logits through the logits processor."""
        logits = self.logits_processor(self.lm_head, hidden_states,
                                       sampling_metadata)
        return logits

    def sample(
        self,
        logits: Optional[torch.Tensor],
        sampling_metadata: SamplingMetadata,
    ) -> Optional[SamplerOutput]:
        """Sample next tokens from ``logits``."""
        next_tokens = self.sampler(logits, sampling_metadata)
        return next_tokens

    def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
        """Load checkpoint tensors into this module's parameters.

        Each ``(name, tensor)`` pair goes down one of three paths, tried in
        order: the stacked-parameter mapping (q/k/v into ``qkv_proj``,
        gate/up into ``gate_up_proj``), the per-expert mapping produced by
        ``FusedMoE.make_expert_params_mapping``, and finally a plain load
        under the (possibly remapped) name.
        """
        stacked_params_mapping = [
            # (param_name, shard_name, shard_id)
            ("qkv_proj", "q_proj", "q"),
            ("qkv_proj", "k_proj", "k"),
            ("qkv_proj", "v_proj", "v"),
            ("gate_up_proj", "gate_proj", 0),
            ("gate_up_proj", "up_proj", 1),
        ]

        # Params for weights, fp8 weight scales, fp8 activation scales
        # (param_name, weight_name, expert_id, shard_id)
        expert_params_mapping = FusedMoE.make_expert_params_mapping(
            ckpt_gate_proj_name="gate_proj",
            ckpt_down_proj_name="down_proj",
            ckpt_up_proj_name="up_proj",
            num_experts=self.config.num_experts)

        params_dict = dict(self.named_parameters())
        for name, loaded_weight in weights:
            if "rotary_emb.inv_freq" in name:
                continue
            for (param_name, weight_name, shard_id) in stacked_params_mapping:
                # Skip non-stacked layers and experts (experts handled below).
                if weight_name not in name:
                    continue
                # We have mlp.experts[0].gate_proj in the checkpoint.
                # Since we handle the experts below in expert_params_mapping,
                # we need to skip here BEFORE we update the name, otherwise
                # name will be updated to mlp.experts[0].gate_up_proj, which
                # will then be updated below in expert_params_mapping
                # for mlp.experts[0].gate_gate_up_proj, which breaks load.
                if "mlp.experts" in name:
                    continue
                # NOTE(review): `name` is rewritten before the skip checks
                # below, so a skipped mapping leaves the rewritten name in
                # place for later iterations and the else-branches. Confirm
                # this is intended.
                name = name.replace(weight_name, param_name)
                # Skip loading extra bias for GPTQ models.
                if ((name.endswith(".bias") or name.endswith("_bias"))
                        and name not in params_dict):
                    continue
                # Skip layers on other devices.
                if is_pp_missing_parameter(name, self):
                    continue
                if name not in params_dict:
                    continue

                param = params_dict[name]
                weight_loader = param.weight_loader
                weight_loader(param, loaded_weight, shard_id)
                break
            else:
                # Reached only when no stacked mapping loaded the tensor
                # (the loop above finished without `break`).
                for mapping in expert_params_mapping:
                    param_name, weight_name, expert_id, shard_id = mapping
                    if weight_name not in name:
                        continue
                    name = name.replace(weight_name, param_name)
                    # Skip layers on other devices.
                    if is_pp_missing_parameter(name, self):
                        continue
                    # Skip loading extra bias for GPTQ models.
                    if ((name.endswith(".bias") or name.endswith("_bias"))
                            and name not in params_dict):
                        continue
                    # Raises KeyError if the remapped expert parameter is
                    # not present on this module.
                    param = params_dict[name]
                    weight_loader = param.weight_loader
                    weight_loader(param,
                                  loaded_weight,
                                  name,
                                  shard_id=shard_id,
                                  expert_id=expert_id)
                    break
                else:
                    # Neither a stacked nor an expert weight: plain load.
                    # Skip loading extra bias for GPTQ models.
                    if ((name.endswith(".bias") or name.endswith("_bias"))
                            and name not in params_dict):
                        continue
                    # Skip layers on other devices.
                    if is_pp_missing_parameter(name, self):
                        continue
                    # Remapping the name of FP8 kv-scale.
                    if name.endswith("kv_scale"):
                        remapped_kv_scale_name = name.replace(
                            ".kv_scale", ".attn.kv_scale")
                        if remapped_kv_scale_name not in params_dict:
                            print_warning_once(
                                "Found kv scale in the checkpoint "
                                f"(e.g. {name}), but not found the expected "
                                f"name in the model "
                                f"(e.g. {remapped_kv_scale_name}). "
                                "kv-scale is not loaded.")
                            continue
                        else:
                            name = remapped_kv_scale_name

                    # Raises KeyError for checkpoint names with no matching
                    # parameter on this module.
                    param = params_dict[name]
                    weight_loader = getattr(param, "weight_loader",
                                            default_weight_loader)
                    weight_loader(param, loaded_weight)

View File

@@ -48,6 +48,7 @@ _TEXT_GENERATION_MODELS = {
"DeciLMForCausalLM": ("decilm", "DeciLMForCausalLM"),
"DeepseekForCausalLM": ("deepseek", "DeepseekForCausalLM"),
"DeepseekV2ForCausalLM": ("deepseek_v2", "DeepseekV2ForCausalLM"),
"DeepseekV3ForCausalLM": ("deepseek_v2", "DeepseekV3ForCausalLM"),
"ExaoneForCausalLM": ("exaone", "ExaoneForCausalLM"),
"FalconForCausalLM": ("falcon", "FalconForCausalLM"),
"GemmaForCausalLM": ("gemma", "GemmaForCausalLM"),
@@ -64,6 +65,8 @@ _TEXT_GENERATION_MODELS = {
"InternLM2VEForCausalLM": ("internlm2_ve", "InternLM2VEForCausalLM"),
"JAISLMHeadModel": ("jais", "JAISLMHeadModel"),
"JambaForCausalLM": ("jamba", "JambaForCausalLM"),
"Llama4ForCausalLM": ("llama4", "Llama4ForCausalLM"),
"Llama4ForConditionalGeneration": ("llama4", "Llama4ForCausalLM"),
"LlamaForCausalLM": ("llama", "LlamaForCausalLM"),
# For decapoda-research/llama-*
"LLaMAForCausalLM": ("llama", "LlamaForCausalLM"),
@@ -91,6 +94,7 @@ _TEXT_GENERATION_MODELS = {
"Qwen2ForCausalLM": ("qwen2", "Qwen2ForCausalLM"),
"Qwen2MoeForCausalLM": ("qwen2_moe", "Qwen2MoeForCausalLM"),
"Qwen3ForCausalLM": ("qwen3", "Qwen3ForCausalLM"),
"Qwen3MoeForCausalLM": ("qwen3_moe", "Qwen3MoeForCausalLM"),
"RWForCausalLM": ("falcon", "FalconForCausalLM"),
"StableLMEpochForCausalLM": ("stablelm", "StablelmForCausalLM"),
"StableLmForCausalLM": ("stablelm", "StablelmForCausalLM"),
@@ -162,6 +166,7 @@ _SPECULATIVE_DECODING_MODELS = {
"EAGLEModel": ("eagle", "EAGLE"),
"MedusaModel": ("medusa", "Medusa"),
"MLPSpeculatorPreTrainedModel": ("mlp_speculator", "MLPSpeculator"),
"DeepSeekMTPModel": ("deepseek_mtp", "DeepSeekMTP"),
}
# Transformers backend models - wrapper classes for custom HuggingFace models
@@ -479,6 +484,7 @@ class _ModelRegistry:
architectures = [architectures]
if not architectures:
logger.warning("No model architectures are specified")
return []
return architectures

View File

@@ -492,6 +492,29 @@ def maybe_offload_to_cpu(module: torch.nn.Module) -> torch.nn.Module:
return module
def _get_device_memory_info() -> Tuple[Optional[float], Optional[float], Optional[float]]:
"""Get device memory info in GiB. Returns (allocated, reserved, total) or Nones."""
try:
import torch.mlu
allocated = torch.mlu.memory_allocated() / (1024 ** 3)
reserved = torch.mlu.memory_reserved() / (1024 ** 3)
free, total = torch.mlu.mem_get_info()
total = total / (1024 ** 3)
return allocated, reserved, total
except Exception:
pass
try:
if torch.cuda.is_available():
allocated = torch.cuda.memory_allocated() / (1024 ** 3)
reserved = torch.cuda.memory_reserved() / (1024 ** 3)
free, total = torch.cuda.mem_get_info()
total = total / (1024 ** 3)
return allocated, reserved, total
except Exception:
pass
return None, None, None
def make_layers(
num_hidden_layers: int,
layer_fn: LayerFn,
@@ -505,11 +528,31 @@ def make_layers(
start_layer, end_layer = get_pp_indices(num_hidden_layers,
get_pp_group().rank_in_group,
get_pp_group().world_size)
alloc_before, _, total = _get_device_memory_info()
if alloc_before is not None:
logger.info(
"[DEBUG-MEM] make_layers start: allocated=%.2f GiB, "
"total=%.2f GiB, layers to create: %d-%d / %d",
alloc_before, total, start_layer, end_layer, num_hidden_layers)
created_layers = []
for idx in range(start_layer, end_layer):
layer = maybe_offload_to_cpu(layer_fn(prefix=f"{prefix}.{idx}"))
alloc_after, reserved, _ = _get_device_memory_info()
if alloc_after is not None:
delta = alloc_after - alloc_before
logger.info(
"[DEBUG-MEM] Layer %s.%d created: "
"allocated=%.2f GiB (+%.4f GiB), reserved=%.2f GiB",
prefix, idx, alloc_after, delta, reserved)
alloc_before = alloc_after
created_layers.append(layer)
modules = torch.nn.ModuleList(
[PPMissingLayer() for _ in range(start_layer)] + [
maybe_offload_to_cpu(layer_fn(prefix=f"{prefix}.{idx}"))
for idx in range(start_layer, end_layer)
] + [PPMissingLayer() for _ in range(end_layer, num_hidden_layers)])
[PPMissingLayer() for _ in range(start_layer)]
+ created_layers
+ [PPMissingLayer() for _ in range(end_layer, num_hidden_layers)])
return start_layer, end_layer, modules

View File

@@ -159,9 +159,11 @@ class MLUSpecDecodeWorker(LoraNotSupportedWorkerBase):
draft_worker_kwargs[
"model_runner_cls"] = MLUTP1DraftModelRunner
else:
if draft_model_config.hf_config.model_type == "eagle":
if draft_model_config.hf_config.model_type in (
"eagle", "deepseek_mtp"):
raise NotImplementedError(
"EAGLE does not support TP > 1 yet")
f"{draft_model_config.hf_config.model_type} "
"does not support TP > 1 yet")
allow_zero_draft_token_step = False
proposer_worker = MLUMultiStepWorker(**draft_worker_kwargs)

View File

@@ -13,7 +13,7 @@ from transformers import GenerationConfig, PretrainedConfig
from transformers.models.auto.image_processing_auto import (
get_image_processor_config)
from transformers.models.auto.modeling_auto import (
MODEL_FOR_CAUSAL_LM_MAPPING_NAMES)
MODEL_FOR_CAUSAL_LM_MAPPING_NAMES, MODEL_MAPPING_NAMES)
from transformers.utils import CONFIG_NAME as HF_CONFIG_NAME
from vllm.envs import VLLM_USE_MODELSCOPE
@@ -89,9 +89,10 @@ def file_or_path_exists(model: Union[str, Path], config_name, revision,
# hf_hub. This will fail in offline mode.
try:
return file_exists(model, config_name, revision=revision, token=token)
except huggingface_hub.errors.OfflineModeIsEnabled:
# Don't raise in offline mode, all we know is that we don't have this
# file cached.
except (huggingface_hub.errors.OfflineModeIsEnabled,
huggingface_hub.errors.HFValidationError):
# Don't raise in offline mode or when model path fails HF validation
# (e.g., local paths that don't match HF repo id format)
return False
@@ -169,12 +170,6 @@ def get_config(
token=token):
config_format = ConfigFormat.MISTRAL
else:
# If we're in offline mode and found no valid config format, then
# raise an offline mode error to indicate to the user that they
# don't have files cached and may need to go online.
# This is conveniently triggered by calling file_exists().
file_exists(model, HF_CONFIG_NAME, revision=revision, token=token)
raise ValueError(f"No supported config format found in {model}")
if config_format == ConfigFormat.HF:
@@ -234,6 +229,17 @@ def get_config(
model_type = MODEL_FOR_CAUSAL_LM_MAPPING_NAMES[config.model_type]
config.update({"architectures": [model_type]})
# Architecture mapping for models without explicit architectures field
if not getattr(config, "architectures", None):
if config.model_type not in MODEL_MAPPING_NAMES:
logger.warning(
"Model config does not have a top-level 'architectures' "
"field: expecting `hf_overrides={'architectures': "
"['...']}` to be passed in engine args.")
else:
model_type = MODEL_MAPPING_NAMES[config.model_type]
config.update({"architectures": [model_type]})
patch_rope_scaling(config)
if trust_remote_code:

View File

@@ -59,12 +59,20 @@ class MLUWorker(Worker):
# mlp_speculator
speculative_config = self.speculative_config
model_config = self.model_config
speculative_args = {} if speculative_config is None \
or (speculative_config.draft_model_config.model ==
model_config.model) \
or (speculative_config.draft_model_config.hf_config.model_type
not in ["medusa", "mlp_speculator", "eagle"]) \
else {"return_hidden_states": True}
is_mtp = (speculative_config is not None
and model_config.task != "draft"
and getattr(
speculative_config.draft_model_config.hf_config,
"model_type", None) == "deepseek_mtp")
speculative_args = (
{"return_hidden_states": True} if is_mtp else
({} if speculative_config is None
or (speculative_config.draft_model_config.model ==
model_config.model)
or (speculative_config.draft_model_config.hf_config.model_type
not in ["medusa", "mlp_speculator", "eagle"])
else {"return_hidden_states": True})
)
ModelRunnerClass: Type[MLUModelRunnerBase] = MLUModelRunner
if model_runner_cls is not None:

View File

@@ -580,34 +580,58 @@ def unified_flash_attention_v2(
value_cache,
updated_slot_mapping.flatten())
else:
# FIXME: After TMO-1496 is completed, remove this code.
if key.stride() != value.stride():
key = key.contiguous()
value = value.contiguous()
if kv_cache_dtype == 'int8':
mlu_ops.quant_to_linear_cache(key,
value,
key_cache,
value_cache,
key_cache_scale,
value_cache_scale,
attn_metadata.cu_seq_lens,
attn_metadata.max_seq_len,
True, # packed
None, # context_seq_offset
attn_metadata.batch_ids,
attn_metadata.slot_mapping_unpaged)
# unpaged (linear cache) path
if use_mla:
# MLA: 镜像 paged 路径的处理方式
# key_cache: (num_blocks, 1, block_size, 576)
value_to_cache = None
if attn_metadata.prefill_metadata:
# MLA prefill cache 已在 forward_prefill 中写入,跳过
pass
else:
if kv_cache_dtype == 'int8':
mlu_ops.quant_to_linear_cache(
key, value_to_cache,
key_cache, value_cache,
key_cache_scale, value_cache_scale,
attn_metadata.cu_seq_lens,
attn_metadata.max_seq_len,
True, None,
attn_metadata.batch_ids,
attn_metadata.slot_mapping_unpaged)
else:
mlu_ops.reshape_linear_cache(
key, value_to_cache,
key_cache, value_cache,
attn_metadata.cu_seq_lens,
attn_metadata.max_seq_len,
True, None,
attn_metadata.batch_ids,
attn_metadata.slot_mapping_unpaged)
else:
mlu_ops.reshape_linear_cache(key,
value,
key_cache,
value_cache,
attn_metadata.cu_seq_lens,
attn_metadata.max_seq_len,
True, # packed
None, # context_seq_offset
attn_metadata.batch_ids,
attn_metadata.slot_mapping_unpaged)
# FIXME: After TMO-1496 is completed, remove this code.
if key.stride() != value.stride():
key = key.contiguous()
value = value.contiguous()
if kv_cache_dtype == 'int8':
mlu_ops.quant_to_linear_cache(
key, value,
key_cache, value_cache,
key_cache_scale, value_cache_scale,
attn_metadata.cu_seq_lens,
attn_metadata.max_seq_len,
True, None,
attn_metadata.batch_ids,
attn_metadata.slot_mapping_unpaged)
else:
mlu_ops.reshape_linear_cache(
key, value,
key_cache, value_cache,
attn_metadata.cu_seq_lens,
attn_metadata.max_seq_len,
True, None,
attn_metadata.batch_ids,
attn_metadata.slot_mapping_unpaged)
if use_mla and attn_metadata.prefill_metadata:
output = torch.empty(query.shape[0], query.shape[1], v_head_size, dtype=query.dtype, device="mlu")
else:

View File

@@ -37,7 +37,7 @@ def vllm__config__CacheConfig___verify_cache_dtype(self) -> None:
def vllm__config__ModelConfig__get_num_kv_heads(self, parallel_config: "ParallelConfig") -> int:
"""Returns the number of KV heads per GPU."""
if hasattr(self.hf_text_config,"model_type") and self.hf_text_config.model_type == 'deepseek_v2':
if hasattr(self.hf_text_config,"model_type") and self.hf_text_config.model_type in ('deepseek_v2', 'deepseek_v3', 'deepseek_mtp'):
# feature flag MLA
return 1
total_num_kv_heads = self.get_total_num_kv_heads()
@@ -51,7 +51,7 @@ def vllm__config__ModelConfig__get_num_kv_heads(self, parallel_config: "Parallel
def vllm__config__ModelConfig__get_head_size(self) -> int:
# TODO remove hard code
if hasattr(self.hf_text_config, "model_type"
) and self.hf_text_config.model_type == 'deepseek_v2':
) and self.hf_text_config.model_type in ('deepseek_v2', 'deepseek_v3', 'deepseek_mtp'):
'''
=============================
Modify by vllm_mlu
@@ -109,7 +109,7 @@ def vllm__config__LoRAConfig__verify_with_model_config(self, model_config: Model
def vllm__config__ModelConfig__is_deepseek_v2(self) -> bool:
result = hasattr(
self.hf_text_config,
"model_type") and self.hf_text_config.model_type == 'deepseek_v2'
"model_type") and self.hf_text_config.model_type in ('deepseek_v2', 'deepseek_v3', 'deepseek_mtp')
return result
MluHijackObject.apply_hijack(ModelConfig,

View File

@@ -26,6 +26,12 @@ def vllm__module_executor__layers__linear__UnquantizedLinearMethod__apply(
beta = 1.0
residual = residual.view(-1, residual.shape[-1])
res_shape = x.shape[0:-1] + (layer.weight.shape[0], )
# MLU matmul requires all tensors to have matching dtypes
target_dtype = layer.weight.dtype
if x.dtype != target_dtype:
x = x.to(target_dtype)
if residual is not None and residual.dtype != target_dtype:
residual = residual.to(target_dtype)
return mlu_ops.matmul(x.view(-1, x.shape[-1]), layer.weight, bias, residual, 'none', 1.0, beta).view(res_shape)

View File

@@ -73,6 +73,24 @@ class SparseMoeMlp(nn.Module):
self.expert_group = expert_group
self.topk_group = topk_group
if get_device_major_capability() == 3:
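# WARNING: MLU370 (capability == 3) does not support the fused_moe operator,
# so it is forcibly disabled here.
#
# Background: the original forward_experts_nofused used torch.unique,
# torch.tensor([0], ...), data-dependent branches and other operations that are
# incompatible with graph capture, so every MoE model going through SparseMoeMlp
# on MLU370 had to be run with --enforce-eager. forward_experts_nofused has now
# been switched to a dense mode (every expert processes all tokens, and the
# result is masked by the routing weights), which resolves the graph-capture
# incompatibility; MoE models no longer need --enforce-eager.
#
# Performance cost: dense mode costs O(num_experts * num_tokens), which is
# num_experts/topk times the O(topk * num_tokens) of sparse routing. Prefill is
# noticeably slower for models with many experts; the impact on decode (few
# tokens) is negligible.
# Known affected models: Mixtral (8), Qwen2-MoE (60), HunYuan (16), Llama4 (16), etc.
# DeepSeek V2/V3 are not affected (they have their own MLU MoE hijack implementation).
#
# TODO: MLU370 already has the complete MoE operator chain (moe_gen_idx,
# moe_expand_input, group_gemm, moe_active, moe_combine_result), the same
# operators used by forward_group_experts. The is_use_fused_moe flag should
# later be split so that MLU370 takes the forward_group_experts path and avoids
# the dense-mode overhead.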
self.is_use_fused_moe = False
if params_dtype is None:
@@ -284,34 +302,28 @@ class SparseMoeMlp(nn.Module):
def forward_experts_nofused(self, hidden_states, expert_logits):
hidden_states_shape = hidden_states.shape
# Dense approach: each expert processes ALL tokens, then mask by routing
# weights. This avoids data-dependent control flow (variable-size slicing,
# conditional branches, torch.unique, torch.tensor creation) that is
# incompatible with MLU graph capture.
num_tokens, hidden_size = hidden_states.shape
topk_values, topk_indices = self.topk_softmax(expert_logits)
expand_gather_idx, scatter_idx, expand_token_count, cusum_token_count = self.generate_gather_idx(
topk_indices)
# no expert is routed, then expand_gather_idx, expand_scatter_idx has no item,
# expand_token_count and expand_cusum_token_count has item but the value is all zero
# so this rank should only return final_hidden_states with zero value
if expand_gather_idx.numel() == 0:
final_hidden_states = torch.zeros_like(hidden_states,
dtype=hidden_states.dtype,
device=hidden_states.device)
return final_hidden_states
expand_hidden_states = self.expand_input(hidden_states, expand_gather_idx)
final_hidden_states = torch.zeros(
num_tokens, hidden_size,
dtype=hidden_states.dtype, device=hidden_states.device)
expand_output_list = []
expand_cusum_token_count = cusum_token_count[self.start_expert_id:self.end_expert_id +
1] - cusum_token_count[self.start_expert_id]
for expert_idx, num_tokens_per_expert in enumerate(expand_token_count):
if num_tokens_per_expert > 0:
expert_hidden_states = expand_hidden_states[
expand_cusum_token_count[expert_idx]:expand_cusum_token_count[expert_idx + 1]]
expert_output = self.experts[expert_idx](expert_hidden_states)
expert_output = expert_output[0] if isinstance(expert_output, (tuple, list)) else expert_output
expand_output_list.append(expert_output)
expand_output = torch.cat(expand_output_list, dim=0)
final_hidden_states = self.combine_moe(expand_output, scatter_idx, cusum_token_count, hidden_states_shape,
topk_values)
for expert_idx in range(self.num_experts_per_rank):
global_expert_idx = self.start_expert_id + expert_idx
expert_output = self.experts[expert_idx](hidden_states)
expert_output = expert_output[0] if isinstance(
expert_output, (tuple, list)) else expert_output
# Routing weight per token for this expert
expert_mask = (topk_indices == global_expert_idx).to(topk_values.dtype)
expert_weights = (topk_values * expert_mask).sum(dim=-1, keepdim=True)
final_hidden_states = final_hidden_states + expert_output * expert_weights
return final_hidden_states
@@ -425,9 +437,9 @@ class SparseMoeMlp(nn.Module):
scatter_idx=torch.zeros((indices.numel(),), dtype=seqs.dtype, device=seqs.device).scatter(0, indices, seqs)
# token_count: [self.num_experts_per_rank]
partial_token_index, partial_token_count = sorted_expert_id.unique(sorted=True, return_counts=True)
zero_token_count = torch.zeros(self.num_total_experts, dtype=partial_token_count.dtype, device=device)
token_count = zero_token_count.scatter(dim=0, index=partial_token_index, src=partial_token_count)
# Use scatter_add_ instead of torch.unique for MLU graph capture compatibility
token_count = torch.zeros(self.num_total_experts, dtype=sorted_expert_id.dtype, device=device)
token_count.scatter_add_(0, sorted_expert_id, torch.ones_like(sorted_expert_id))
# cusum_token_count: [self.num_experts_per_rank + 1]
cusum_token_count = torch.cat(
[torch.tensor([0], dtype=token_count.dtype, device=device),

View File

@@ -39,3 +39,9 @@ try:
except ImportError as e:
import logging
logging.warning(f"Failed to import mllama hijack: {e}")
try:
import vllm_mlu.model_executor.models.llama4
except ImportError as e:
import logging
logging.warning(f"Failed to import llama4 hijack: {e}")

View File

@@ -28,6 +28,7 @@ from vllm.model_executor.layers.sampler import SamplerOutput
from vllm_mlu.model_executor.layers.feed_forward import FeedForward
from vllm_mlu.mlu_hijack_utils import MluHijackObject
from vllm_mlu.model_executor.layers.sparse_moe_mlp import SparseMoeMlp
from vllm import _mlu_ops as mlu_ops
from vllm.utils import print_warning_once
from vllm.model_executor.models.utils import is_pp_missing_parameter
from vllm_mlu.model_executor.models.layer_utils import quant_fusion_with_rmsnorm
@@ -77,6 +78,12 @@ class DeepseekV2MoE(SparseMoeMlp):
bias=False,
quant_config=None,
prefix=f"{prefix}.gate")
if getattr(config, "topk_method", None) == "noaux_tc":
self.gate.e_score_correction_bias = nn.Parameter(
torch.empty(config.n_routed_experts, dtype=torch.float32)
)
else:
self.gate.e_score_correction_bias = None
if config.n_shared_experts is not None:
intermediate_size = (config.moe_intermediate_size *
config.n_shared_experts)
@@ -104,6 +111,7 @@ class DeepseekV2MoE(SparseMoeMlp):
def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
num_tokens, hidden_dim = hidden_states.shape
hidden_states = hidden_states.view(-1, hidden_dim)
shared_output = None
if self.n_shared_experts is not None:
shared_output = self.shared_experts(hidden_states)
# router_logits: (num_tokens, n_experts)
@@ -113,9 +121,25 @@ class DeepseekV2MoE(SparseMoeMlp):
Modify by vllm_mlu
=============================
@brief: replace experts() with forward_experts, which defined by SparseMoeMlp.
For noaux_tc (DeepSeek V3), do manual routing with e_score_correction_bias.
'''
final_hidden_states = self.forward_experts(
hidden_states, router_logits) * self.routed_scaling_factor
if self.gate.e_score_correction_bias is not None:
# noaux_tc routing: softmax → add bias for topk selection → use original scores
scores = router_logits.float().softmax(dim=-1)
scores_for_choice = scores + self.gate.e_score_correction_bias.unsqueeze(0)
topk_weights, topk_indices = torch.topk(
scores_for_choice, k=self.top_k, dim=-1)
# Use original softmax scores (without bias) as weights
topk_weights = scores.gather(1, topk_indices)
if self.renormalize:
topk_weights = topk_weights / topk_weights.sum(
dim=-1, keepdim=True)
final_hidden_states = self.forward_experts_with_precomputed_routing(
hidden_states, topk_weights, topk_indices
) * self.routed_scaling_factor
else:
final_hidden_states = self.forward_experts(
hidden_states, router_logits) * self.routed_scaling_factor
'''
==================
End of MLU Hijack
@@ -129,6 +153,55 @@ class DeepseekV2MoE(SparseMoeMlp):
return final_hidden_states.view(num_tokens, hidden_dim)
def forward_experts_with_precomputed_routing(
    self, hidden_states, topk_weights, topk_indices
):
    """Run the routed experts with routing results supplied by the caller.

    Unlike ``forward_experts``, no top-k selection happens here: the caller
    passes the already chosen expert indices and their combine weights, and
    this method only drives the MLU MoE operator chain
    (gen_idx -> expand -> group_gemm -> activation -> group_gemm -> combine).

    Args:
        hidden_states: Input activations; flattened to 2-D over the last
            dimension before use.
        topk_weights: Per-token combine weights (cast to float32).
        topk_indices: Per-token selected expert ids (cast to int32).

    Returns:
        The combined expert output, viewed back to the input's shape.
    """
    self.pack_params()

    input_shape = hidden_states.shape
    total_experts = self.num_total_experts
    local_experts = self.w13.size(0)
    max_m = hidden_states.shape[0]
    tokens = hidden_states.view(-1, hidden_states.size(-1))
    combine_weights = topk_weights.to(torch.float32)
    expert_ids = topk_indices.to(torch.int32)

    # Build the gather/scatter indices and per-expert token counts.
    expand_idx, combine_idx, token_count, cusum_token_count = (
        mlu_ops.moe_gen_idx(expert_ids, total_experts)
    )

    first_expert = self.start_expert_id
    # Token counts restricted to the experts held by this rank.
    local_token_count = token_count[first_expert:first_expert + local_experts]

    # First grouped GEMM (w13) on the expanded, expert-sorted tokens.
    expanded = mlu_ops.moe_expand_input(
        tokens, expand_idx, cusum_token_count,
        first_expert, local_experts
    )
    up_out = mlu_ops.group_gemm(
        expanded, self.w13,
        local_token_count,
        None, None, None, None, max_m
    )

    # Activation (gated or not, depending on self.is_gated).
    activated = mlu_ops.moe_active(
        up_out, self.hidden_act, self.is_gated, None, self.b13,
        cusum_token_count, first_expert, local_experts
    )

    # Second grouped GEMM (w2).
    down_out = mlu_ops.group_gemm(
        activated, self.w2,
        local_token_count,
        None, None, None, None, max_m
    )

    # Scatter back to token order and apply the combine weights.
    combined = mlu_ops.moe_combine_result(
        down_out, combine_weights, combine_idx,
        None, cusum_token_count, first_expert,
        local_experts, self.b2
    )
    return combined.view(input_shape)
def forward_prefill(
self,
positions: torch.Tensor,
@@ -179,19 +252,27 @@ def forward_prefill(
updated_slot_mapping = attn_metadata.slot_mapping
if self.attn.kv_cache_dtype == 'int8':
key_cache_scale = kv_cache[1][0]
mlu_ops.quant_to_paged_cache(key_value,
mlu_ops.quant_to_linear_cache(key_value,
None,
key_cache,
None,
key_cache_scale,
None,
attn_metadata.cu_seq_lens,
attn_metadata.max_seq_len,
True, None,
attn_metadata.batch_ids,
attn_metadata.slot_mapping_unpaged)
else:
mlu_ops.reshape_linear_cache(key_value,
None,
key_cache,
None,
key_cache_scale,
None,
updated_slot_mapping.flatten())
else:
mlu_ops.reshape_paged_cache(key_value,
None,
key_cache,
None,
updated_slot_mapping.flatten())
attn_metadata.cu_seq_lens,
attn_metadata.max_seq_len,
True, None,
attn_metadata.batch_ids,
attn_metadata.slot_mapping_unpaged)
'''
==================
End of MLU Hijack
@@ -491,6 +572,15 @@ def vllm__module_executor__models__deepseek_v2__DeepseekV2DecoderLayer__init__(
self.post_attention_layernorm = RMSNorm(config.hidden_size,
eps=config.rms_norm_eps)
def get_spec_layer_idx_from_weight_name(config, weight_name):
    """Return the MTP (next-token-prediction) layer index a weight belongs to.

    The speculative layers are the ``num_nextn_predict_layers`` layers that
    follow the regular ``num_hidden_layers`` decoder layers, so their indices
    start at ``config.num_hidden_layers``.

    Args:
        config: Model config; ``num_nextn_predict_layers`` is optional,
            ``num_hidden_layers`` is read only when MTP layers exist.
        weight_name: Checkpoint weight name.

    Returns:
        The matching layer index, or ``None`` when the weight is not part of
        an MTP layer (or the model has none).
    """
    mtp_layer_count = getattr(config, "num_nextn_predict_layers", 0)
    if not mtp_layer_count or mtp_layer_count <= 0:
        return None

    first_mtp_layer = config.num_hidden_layers
    for layer_idx in range(first_mtp_layer, first_mtp_layer + mtp_layer_count):
        # The trailing dot prevents e.g. layer 61 from matching layer 610.
        if weight_name.startswith(f"model.layers.{layer_idx}."):
            return layer_idx
    return None
def vllm__module_executor__models__deepseek_v2__DeepseekV2ForCausalLM__load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
'''
=============================
@@ -530,6 +620,10 @@ def vllm__module_executor__models__deepseek_v2__DeepseekV2ForCausalLM__load_weig
for name, loaded_weight in weights:
if "rotary_emb.inv_freq" in name:
continue
# Skip MTP speculative decoding layer weights
spec_layer = get_spec_layer_idx_from_weight_name(self.config, name)
if spec_layer is not None:
continue
'''
=============================
Modify by vllm_mlu
@@ -565,7 +659,9 @@ def vllm__module_executor__models__deepseek_v2__DeepseekV2ForCausalLM__load_weig
@brief: add expert skiped condition and delete useless if name not in params_dict: continue condition
'''
name = name.replace(weight_name, param_name)
if (("mlp.experts." in name or "mlp.shared_experts." in name or "mlp.shared_expert_gate." in name)
if (("mlp.experts." in name or "mlp.shared_experts." in name
or "mlp.shared_expert_gate." in name
or "e_score_correction_bias" in name)
and name not in params_dict):
continue
'''
@@ -595,7 +691,9 @@ def vllm__module_executor__models__deepseek_v2__DeepseekV2ForCausalLM__load_weig
if name.endswith(".bias") and name not in params_dict:
continue
if (("mlp.experts." in name or "mlp.shared_experts." in name or "mlp.shared_expert_gate." in name)
if (("mlp.experts." in name or "mlp.shared_experts." in name
or "mlp.shared_expert_gate." in name
or "e_score_correction_bias" in name)
and name not in params_dict):
continue
if is_pp_missing_parameter(name, self):

View File

@@ -194,6 +194,11 @@ def decoder_model_forward_base_pp(
hidden_states = inputs_embeds
else:
hidden_states = get_input_embeddings(input_ids)
# MLU F.embedding may output float32 even with float16 weights;
# cast to model dtype to avoid dtype mismatches downstream.
target_dtype = next(layers[start_layer].parameters()).dtype
if hidden_states.dtype != target_dtype:
hidden_states = hidden_states.to(target_dtype)
else:
assert intermediate_tensors is not None
hidden_states = intermediate_tensors["hidden_states"]

View File

@@ -0,0 +1,495 @@
import torch
import re
from typing import List, Optional, Tuple, Union, Iterable
from vllm.attention import AttentionMetadata
from vllm.config import CacheConfig
from vllm.model_executor.layers.quantization.base_config import QuantizationConfig
from vllm.model_executor.layers.layernorm import RMSNorm
from vllm_mlu.model_executor.layers.feed_forward import FeedForward
from vllm_mlu.mlu_hijack_utils import MluHijackObject
from vllm.distributed import tensor_model_parallel_all_reduce
from vllm.model_executor.model_loader.weight_utils import (
default_weight_loader, maybe_remap_kv_scale_name)
from vllm.model_executor.models.llama4 import (
Llama4Attention, Llama4DecoderLayer, Llama4ForCausalLM,
Llama4Model, Llama4MoE)
from vllm_mlu.model_executor.layers.sparse_moe_mlp import SparseMoeMlp
from vllm.model_executor.models.utils import is_pp_missing_parameter
from vllm.sequence import IntermediateTensors
from vllm_mlu.model_executor.models.layer_utils import (
decoder_layer_forward_base, decoder_model_forward_base_pp,
is_per_tensor_smoothquant, is_per_token_smoothquant,
quant_fusion_with_rmsnorm)
from vllm.logger import init_logger
logger = init_logger(__name__)
# ============================================================
# Llama4MoE MLU replacement: SparseMoeMlp + shared expert
# ============================================================
class Llama4MoEMlu(SparseMoeMlp):
    """MLU replacement for Llama4MoE: SparseMoeMlp routed experts + shared expert.

    The routed experts are handled by the ``SparseMoeMlp`` base class; a
    separate ``FeedForward`` implements the always-active shared expert. Both
    outputs are summed and all-reduced once across tensor-parallel ranks.
    """

    def __init__(self, config, quant_config=None, prefix=""):
        """Build the routed experts and the shared expert from ``config``.

        Args:
            config: Text config; ``num_local_experts``, ``num_experts_per_tok``,
                ``hidden_size`` and ``intermediate_size`` are read with defaults.
            quant_config: Optional quantization config forwarded to both parts.
            prefix: Parameter-name prefix of this module.
        """
        num_local_experts = getattr(config, "num_local_experts", 8)
        top_k = getattr(config, "num_experts_per_tok", 1)
        hidden_size = getattr(config, "hidden_size", 4096)
        intermediate_size = getattr(config, "intermediate_size", 8192)
        super().__init__(
            num_experts=num_local_experts,
            top_k=top_k,
            hidden_size=hidden_size,
            intermediate_size=intermediate_size,
            up_proj_name="gate_up_proj",
            is_gated=True,
            down_proj_name="down_proj",
            has_bias=False,
            skip_bias_add=False,
            renormalize=False,
            hidden_act="silu",
            params_dtype=None,
            quant_config=quant_config,
            is_use_fused_moe=True,
        )
        # Llama4 uses sigmoid routing, not softmax; see topk_softmax below.
        # NOTE(review): this override is presumably only consulted by the
        # non-fused SparseMoeMlp paths -- confirm the fused path also applies
        # sigmoid routing.
        self._use_sigmoid_routing = True

        # Shared expert (independent from routed experts). reduce_results is
        # off because forward() all-reduces the combined output once.
        self.shared_expert = FeedForward(
            hidden_size=hidden_size,
            intermediate_size=intermediate_size,
            hidden_act="silu",
            up_proj_name="gate_up_proj",
            is_gated=True,
            down_proj_name="down_proj",
            bias=False,
            quant_config=quant_config,
            reduce_results=False,
            prefix=f"{prefix}.shared_expert",
        )

    def topk_softmax(self, expert_logits):
        """Override: select top-k on the raw logits, then apply sigmoid.

        Returns:
            ``(topk_values, topk_indices)`` where the values are the float
            sigmoid of the selected logits.
        """
        topk_values, topk_indices = torch.topk(
            expert_logits, self.top_k, dim=-1)
        topk_values = torch.sigmoid(topk_values.float())
        return topk_values, topk_indices

    def forward(self, hidden_states, residual=None):
        """Compute shared + routed expert output, optionally adding a residual.

        Args:
            hidden_states: Input activations; flattened to
                ``(-1, self.hidden_size)`` internally.
            residual: Optional tensor added to the result. Previously this
                argument was accepted but silently dropped.

        Returns:
            Output tensor with the same shape as ``hidden_states``.
        """
        orig_shape = hidden_states.shape
        hidden_states = hidden_states.view(-1, self.hidden_size)

        # Shared expert output
        shared_out = self.shared_expert(hidden_states)

        # Router logits
        router_logits, _ = self.gate(hidden_states)

        # Routed experts (residual is added below, after the all-reduce)
        routed_out = self.forward_experts(hidden_states, router_logits, None)

        # Combine
        final_out = routed_out + shared_out
        if self.tp_size > 1:
            final_out = tensor_model_parallel_all_reduce(final_out)

        # Add the residual exactly once, after the cross-rank reduction, so it
        # is not summed tp_size times.
        if residual is not None:
            final_out = final_out + residual.reshape(final_out.shape)

        return final_out.view(orig_shape)
# ============================================================
# Llama4Attention hijack
# ============================================================
# Handle on the un-hijacked constructor so the replacement can delegate to it.
vllm__llama4__Llama4Attention__init__org = Llama4Attention.__init__


def vllm__llama4__Llama4Attention____init__(
    self,
    config,
    hidden_size: int,
    num_heads: int,
    num_kv_heads: int,
    max_position_embeddings: int = 8192,
    quant_config: Optional[QuantizationConfig] = None,
    bias: bool = False,
    cache_config: Optional[CacheConfig] = None,
    prefix: str = "",
) -> None:
    """Hijacked ``Llama4Attention.__init__``.

    Runs the original constructor with the same arguments (forwarded
    positionally, in declaration order) and then records the config's
    ``rope_scaling`` on the instance for the MLU forward pass.
    """
    vllm__llama4__Llama4Attention__init__org(
        self,
        config,
        hidden_size,
        num_heads,
        num_kv_heads,
        max_position_embeddings,
        quant_config,
        bias,
        cache_config,
        prefix,
    )

    # Modify by vllm_mlu: save rope_scaling for MLU RoPE dispatch.
    self.rope_scaling = getattr(config, "rope_scaling", None)
def vllm__llama4__Llama4Attention__forward(
    self,
    positions: torch.Tensor,
    hidden_states: torch.Tensor,
    kv_cache: torch.Tensor,
    attn_metadata: AttentionMetadata,
    residual: Optional[torch.Tensor] = None,
    smooth_quant_scale: Optional[torch.Tensor] = None,
) -> torch.Tensor:
    """Hijacked ``Llama4Attention.forward`` for MLU.

    Steps: QKV projection, rotary embedding (skipped when the layer has no
    ``rotary_emb``), optional QK norm, optional temperature scaling of q on
    NoPE layers, attention, and output projection with ``residual`` passed to
    ``o_proj``.
    """
    q_size, kv_size = self.q_size, self.kv_size
    qkv, _ = self.qkv_proj(hidden_states, smooth_quant_scale)
    q, k, v = qkv.split([q_size, kv_size, kv_size], dim=-1)

    # Modify by vllm_mlu: MLU RoPE -- merge q/k, apply rotary, split back
    # (lesson #3). For NoPE layers (self.rotary_emb is None), skip RoPE.
    rotary = self.rotary_emb
    if rotary is not None:
        scaling = self.rope_scaling
        uses_longrope = (scaling is not None
                         and scaling.get("rope_type") == "longrope")
        if uses_longrope:
            q, k = rotary(positions, q, k)
        else:
            # q and k share one view of qkv; the rotary call's return value
            # is ignored and q/k are re-read from that same view afterwards.
            qk, _ = qkv.split([q_size + kv_size, kv_size], dim=-1)
            rotary(
                positions,
                qk.view(-1, self.num_heads + self.num_kv_heads,
                        self.head_dim))
            q, k = qk.split([q_size, kv_size], dim=-1)

    # QK norm (MLU fused_rms_norm requires matching dtypes, skip .float())
    norm = self.qk_norm
    if norm is not None:
        per_head_q = q.contiguous().reshape(-1, self.head_dim)
        q = norm(per_head_q).contiguous().reshape(-1, q_size)
        per_head_k = k.contiguous().reshape(-1, self.head_dim)
        k = norm(per_head_k).contiguous().reshape(-1, kv_size)

    # Temperature tuning for NoPE layers
    if self.attn_temperature_tuning and self.nope:
        scale = self._get_attn_scale(positions)
        q = (q * scale).to(q.dtype)

    attn_output = self.attn(q, k, v, kv_cache, attn_metadata)

    # Modify by vllm_mlu: add residual in o_proj.
    output, _ = self.o_proj(attn_output, residual)
    return output
# ============================================================
# Llama4DecoderLayer hijack
# ============================================================
def vllm__llama4__Llama4DecoderLayer____init__(
    self,
    config,
    cache_config: Optional[CacheConfig] = None,
    quant_config: Optional[QuantizationConfig] = None,
    prefix: str = "",
) -> None:
    """Hijacked ``Llama4DecoderLayer.__init__`` for MLU.

    Builds the attention block, then either ``Llama4MoEMlu`` (MoE layers) or a
    dense ``FeedForward``, the two RMSNorm layers, and the SmoothQuant flags
    used by the hijacked forward pass.
    """
    super(Llama4DecoderLayer, self).__init__()
    from vllm.model_executor.models.llama4 import (
        _extract_layer_index, Llama4Attention)

    self.layer_idx = _extract_layer_index(prefix)
    self.hidden_size = getattr(config, "hidden_size", 4096)
    max_position_embeddings = getattr(
        config, "max_position_embeddings", 8192)

    # num_key_value_heads falls back to the attention head count.
    num_attention_heads = getattr(config, "num_attention_heads", 32)
    num_kv_heads = getattr(
        config, "num_key_value_heads",
        getattr(config, "num_attention_heads", 32))
    self.self_attn = Llama4Attention(
        config=config,
        hidden_size=self.hidden_size,
        num_heads=num_attention_heads,
        num_kv_heads=num_kv_heads,
        max_position_embeddings=max_position_embeddings,
        quant_config=quant_config,
        bias=False,
        cache_config=cache_config,
        prefix=f"{prefix}.self_attn",
    )

    # A layer is MoE when its 1-based index is a multiple of the step;
    # a step of 0 (or a missing attribute) means no MoE layers.
    moe_step = getattr(config, "interleave_moe_layer_step", 0)
    is_moe_layer = moe_step > 0 and (self.layer_idx + 1) % moe_step == 0

    # Modify by vllm_mlu: replace MoE with Llama4MoEMlu (SparseMoeMlp +
    # shared expert) and the dense MLP with FeedForward.
    ffn_prefix = f"{prefix}.feed_forward"
    if not is_moe_layer:
        dense_intermediate_size = getattr(
            config, "intermediate_size_mlp",
            getattr(config, "intermediate_size", 8192))
        self.feed_forward = FeedForward(
            hidden_size=self.hidden_size,
            intermediate_size=dense_intermediate_size,
            hidden_act="silu",
            up_proj_name="gate_up_proj",
            is_gated=True,
            down_proj_name="down_proj",
            bias=False,
            quant_config=quant_config,
            prefix=ffn_prefix,
        )
    else:
        self.feed_forward = Llama4MoEMlu(
            config=config,
            quant_config=quant_config,
            prefix=ffn_prefix,
        )

    eps = getattr(config, "rms_norm_eps", 1e-5)
    self.input_layernorm = RMSNorm(self.hidden_size, eps=eps)
    self.post_attention_layernorm = RMSNorm(self.hidden_size, eps=eps)

    # SmoothQuant fast paths: the input quantization is fused into the
    # layernorm (built lazily in forward), so qkv_proj must skip it.
    self.is_per_tesnor_sq_perf_cases = is_per_tensor_smoothquant(
        quant_config)
    self.is_per_token_sq_perf_cases = is_per_token_smoothquant(
        quant_config)
    if self.is_per_tesnor_sq_perf_cases or self.is_per_token_sq_perf_cases:
        self.self_attn.qkv_proj.quant_method.skip_quant_input = True
        self.quant_fusion_attn_layernorm = None
def vllm__llama4__Llama4DecoderLayer__forward(
    self,
    positions: torch.Tensor,
    hidden_states: torch.Tensor,
    kv_cache: torch.Tensor,
    attn_metadata: AttentionMetadata,
) -> torch.Tensor:
    """Hijacked ``Llama4DecoderLayer.forward`` for MLU.

    Modify by vllm_mlu: delegates to ``decoder_layer_forward_base`` with
    residual-in-matmul and, for SmoothQuant configs, a layernorm fused with
    input quantization that is created on first use and cached on the layer.
    """
    per_token_sq = self.is_per_token_sq_perf_cases
    input_norm = self.input_layernorm

    if self.is_per_tesnor_sq_perf_cases or per_token_sq:
        if self.quant_fusion_attn_layernorm is None:
            qkv_proj = self.self_attn.qkv_proj
            # Per-token uses the smooth vector; per-tensor the int scale.
            quant_scale = (qkv_proj.smooth
                           if per_token_sq else qkv_proj.scale_to_int)
            self.quant_fusion_attn_layernorm = quant_fusion_with_rmsnorm(
                self.input_layernorm, quant_scale,
                dynamic_quant=per_token_sq)
        input_norm = self.quant_fusion_attn_layernorm

    return decoder_layer_forward_base(
        positions,
        hidden_states,
        kv_cache,
        attn_metadata,
        input_norm,
        self.self_attn,
        self.post_attention_layernorm,
        self.feed_forward,
        input_norm_fuse_en=self.is_per_token_sq_perf_cases)
# ============================================================
# Llama4Model hijack
# ============================================================
def vllm__llama4__Llama4Model__forward(
    self,
    input_ids: Optional[torch.Tensor],
    positions: torch.Tensor,
    kv_caches: List[torch.Tensor],
    attn_metadata: AttentionMetadata,
    intermediate_tensors: Optional[IntermediateTensors],
    inputs_embeds: Optional[torch.Tensor] = None,
) -> Union[torch.Tensor, IntermediateTensors]:
    """Hijacked ``Llama4Model.forward``.

    Thin wrapper that hands this model's layers, layer range, embedding
    function and final norm to the shared ``decoder_model_forward_base_pp``
    helper.
    """
    return decoder_model_forward_base_pp(
        input_ids,
        positions,
        kv_caches,
        attn_metadata,
        intermediate_tensors,
        self.layers,
        self.start_layer,
        self.end_layer,
        self.get_input_embeddings,
        self.norm,
        inputs_embeds,
    )
# ============================================================
# Llama4ForCausalLM load_weights hijack
# ============================================================
def vllm__llama4__Llama4ForCausalLM__load_weights(
    self,
    weights: Iterable[Tuple[str, torch.Tensor]],
):
    """Hijacked ``Llama4ForCausalLM.load_weights`` for MLU.

    Packs the ``SparseMoeMlp`` parameters, then streams the checkpoint:
    strips the ``language_model.`` prefix, skips vision weights, permutes Q/K
    for rotary, remaps global expert ids to rank-local ids, and loads stacked
    (qkv / gate_up) and plain parameters.

    Args:
        weights: Iterable of ``(checkpoint_name, tensor)`` pairs.

    NOTE(review): names of the form ``feed_forward.experts.<param>`` without a
    numeric expert index are skipped by the "not in params_dict" guards below;
    confirm the checkpoints handled here use per-expert names.
    """
    # Modify by vllm_mlu: pack params for SparseMoeMlp (MoE layers) and pick
    # up the first global expert id owned by this rank. It was previously
    # hard-coded to 0, which made the expert-id remap below unreachable.
    start_expert_id = 0
    for _, module in self.model.named_modules():
        if isinstance(module, SparseMoeMlp):
            module.pack_params()
            start_expert_id = getattr(module, "start_expert_id", 0)

    stacked_params_mapping = [
        # (param_name, shard_name, shard_id)
        ("qkv_proj", "q_proj", "q"),
        ("qkv_proj", "k_proj", "k"),
        ("qkv_proj", "v_proj", "v"),
        ("gate_up_proj", "gate_proj", 0),
        ("gate_up_proj", "up_proj", 1),
    ]
    params_dict = dict(self.named_parameters())

    for name, loaded_weight in weights:
        if "rotary_emb.inv_freq" in name:
            continue

        # Strip language_model. prefix for Llama4ForConditionalGeneration
        if name.startswith("language_model."):
            name = name[len("language_model."):]
        # Skip vision encoder weights
        elif (name.startswith("multi_modal_projector.")
              or name.startswith("vision_encoder.")
              or name.startswith("vision_model.")):
            continue

        # Permute Q/K weights for rotary embedding
        name, loaded_weight = self.permute_qk_weight_for_rotary(
            name, loaded_weight)

        # Modify by vllm_mlu: remap expert_id for distributed inference.
        # Experts owned by other ranks map to ids that do not exist in
        # params_dict and are skipped by the guards further down.
        if start_expert_id > 0 and "feed_forward.experts." in name:
            match = re.search(r'experts\.(\d+)', name)
            if match:
                local_expert_id = int(match.group(1)) - start_expert_id
                # Rewrite only the matched occurrence.
                name = (name[:match.start()]
                        + f"experts.{local_expert_id}"
                        + name[match.end():])

        for (param_name, weight_name, shard_id) in stacked_params_mapping:
            if weight_name not in name:
                continue
            name = name.replace(weight_name, param_name)
            if ((name.endswith(".bias") or name.endswith("_bias"))
                    and name not in params_dict):
                continue
            if is_pp_missing_parameter(name, self):
                continue
            # Skip experts not assigned to this worker
            if ("feed_forward.experts." in name
                    and name not in params_dict):
                continue
            param = params_dict[name]
            weight_loader = param.weight_loader
            weight_loader(param, loaded_weight, shard_id)
            break
        else:
            # No stacked mapping consumed the weight: load it as a plain
            # parameter.
            if ((name.endswith(".bias") or name.endswith("_bias"))
                    and name not in params_dict):
                continue
            if is_pp_missing_parameter(name, self):
                continue
            name = maybe_remap_kv_scale_name(name, params_dict)
            if name is None:
                continue
            # Skip experts not assigned to this worker
            if ("feed_forward.experts." in name
                    and name not in params_dict):
                continue
            if name not in params_dict:
                logger.warning(
                    "Skipping weight %s not present in the model", name)
                continue
            param = params_dict[name]
            weight_loader = getattr(
                param, "weight_loader", default_weight_loader)
            weight_loader(param, loaded_weight)
# ============================================================
# Apply all hijacks
# ============================================================
# Register every Llama4 replacement. The original attribute is looked up
# right before each registration, one entry at a time, in the listed order.
for _hijack_cls, _hijack_attr, _hijack_impl in (
    (Llama4Attention, "__init__",
     vllm__llama4__Llama4Attention____init__),
    (Llama4Attention, "forward",
     vllm__llama4__Llama4Attention__forward),
    (Llama4DecoderLayer, "__init__",
     vllm__llama4__Llama4DecoderLayer____init__),
    (Llama4DecoderLayer, "forward",
     vllm__llama4__Llama4DecoderLayer__forward),
    (Llama4Model, "forward",
     vllm__llama4__Llama4Model__forward),
    (Llama4ForCausalLM, "load_weights",
     vllm__llama4__Llama4ForCausalLM__load_weights),
):
    MluHijackObject.apply_hijack(
        _hijack_cls,
        getattr(_hijack_cls, _hijack_attr),
        _hijack_impl)