add deepseekv3 and llama4

Chranos
2026-02-11 14:19:17 +08:00
parent 659ef273c8
commit 128aed196c
6 changed files with 1069 additions and 8 deletions

View File

@@ -582,15 +582,24 @@ def unified_flash_attention_v2(
else:
# unpaged (linear cache) path
if use_mla:
# MLA cache is 2D (total_slots, head_dim)
# reshape_paged_cache cannot be used (it expects 4D); write by direct indexing
# MLA: mirror the handling of the paged path
# key_cache: (num_blocks, 1, block_size, 576)
value_to_cache = None
if attn_metadata.prefill_metadata:
# MLA prefill cache was already written in forward_prefill; skip
pass
else:
# key: (num_tokens, 1, head_dim) → squeeze → (num_tokens, head_dim)
# key_cache: (total_slots, head_dim)
key_cache[updated_slot_mapping.flatten()] = key.squeeze(1)
if kv_cache_dtype == 'int8':
mlu_ops.quant_to_paged_cache(
key, value_to_cache,
key_cache, value_cache,
key_cache_scale, value_cache_scale,
updated_slot_mapping.flatten())
else:
mlu_ops.reshape_paged_cache(
key, value_to_cache,
key_cache, value_cache,
updated_slot_mapping.flatten())
else:
# FIXME: After TMO-1496 is completed, remove this code.
if key.stride() != value.stride():
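The MLA branch above now routes linear-cache writes through the same mlu_ops kernels as the paged path instead of indexing a 2D (total_slots, head_dim) cache directly. For orientation, a minimal sketch of the slot-to-block decomposition those paged writes imply; the function name and loop are illustrative, not the mlu_ops implementation:

# Sketch only: how a flat slot id addresses the 4D MLA cache described above
# key_cache: (num_blocks, 1, block_size, 576); key: (num_tokens, 1, 576)
def write_mla_slots(key_cache, key, slot_mapping, block_size):
    for token_idx, slot in enumerate(slot_mapping.tolist()):
        block_idx, block_off = divmod(slot, block_size)
        key_cache[block_idx, 0, block_off] = key[token_idx, 0]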

View File

@@ -37,7 +37,7 @@ def vllm__config__CacheConfig___verify_cache_dtype(self) -> None:
def vllm__config__ModelConfig__get_num_kv_heads(self, parallel_config: "ParallelConfig") -> int:
"""Returns the number of KV heads per GPU."""
if hasattr(self.hf_text_config,"model_type") and self.hf_text_config.model_type == 'deepseek_v2':
if hasattr(self.hf_text_config,"model_type") and self.hf_text_config.model_type in ('deepseek_v2', 'deepseek_v3'):
# feature flag MLA
return 1
total_num_kv_heads = self.get_total_num_kv_heads()
@@ -51,7 +51,7 @@ def vllm__config__ModelConfig__get_num_kv_heads(self, parallel_config: "Parallel
def vllm__config__ModelConfig__get_head_size(self) -> int:
# TODO remove hard code
if hasattr(self.hf_text_config, "model_type"
) and self.hf_text_config.model_type == 'deepseek_v2':
) and self.hf_text_config.model_type in ('deepseek_v2', 'deepseek_v3'):
'''
=============================
Modify by vllm_mlu
@@ -109,7 +109,7 @@ def vllm__config__LoRAConfig__verify_with_model_config(self, model_config: Model
def vllm__config__ModelConfig__is_deepseek_v2(self) -> bool:
result = hasattr(
self.hf_text_config,
"model_type") and self.hf_text_config.model_type == 'deepseek_v2'
"model_type") and self.hf_text_config.model_type in ('deepseek_v2', 'deepseek_v3')
return result
MluHijackObject.apply_hijack(ModelConfig,
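All three ModelConfig hijacks in this file apply the same model_type test; a hypothetical helper expressing that shared predicate (illustrative only, not part of the commit):

# Sketch: MLA-style model types recognized by the hijacks above
_MLA_MODEL_TYPES = ('deepseek_v2', 'deepseek_v3')

def _uses_mla(hf_text_config) -> bool:
    return getattr(hf_text_config, "model_type", None) in _MLA_MODEL_TYPES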

View File

@@ -39,3 +39,9 @@ try:
except ImportError as e:
import logging
logging.warning(f"Failed to import mllama hijack: {e}")
try:
import vllm_mlu.model_executor.models.llama4
except ImportError as e:
import logging
logging.warning(f"Failed to import llama4 hijack: {e}")

View File

@@ -0,0 +1,485 @@
import torch
import re
from typing import List, Optional, Tuple, Union, Iterable
from vllm.attention import AttentionMetadata
from vllm.config import CacheConfig
from vllm.model_executor.layers.quantization.base_config import QuantizationConfig
from vllm.model_executor.layers.layernorm import RMSNorm
from vllm_mlu.model_executor.layers.feed_forward import FeedForward
from vllm_mlu.mlu_hijack_utils import MluHijackObject
from vllm.distributed import tensor_model_parallel_all_reduce
from vllm.model_executor.model_loader.weight_utils import (
default_weight_loader, maybe_remap_kv_scale_name)
from vllm.model_executor.models.llama4 import (
Llama4Attention, Llama4DecoderLayer, Llama4ForCausalLM,
Llama4Model, Llama4MoE)
from vllm_mlu.model_executor.layers.sparse_moe_mlp import SparseMoeMlp
from vllm.model_executor.models.utils import is_pp_missing_parameter
from vllm.sequence import IntermediateTensors
from vllm_mlu.model_executor.models.layer_utils import (
decoder_layer_forward_base, decoder_model_forward_base_pp,
is_per_tensor_smoothquant, is_per_token_smoothquant,
quant_fusion_with_rmsnorm)
from vllm.logger import init_logger
logger = init_logger(__name__)
# ============================================================
# Llama4MoE MLU replacement: SparseMoeMlp + shared expert
# ============================================================
class Llama4MoEMlu(SparseMoeMlp):
"""MLU replacement for Llama4MoE using SparseMoeMlp + shared expert."""
def __init__(self, config, quant_config=None, prefix=""):
num_local_experts = getattr(config, "num_local_experts", 8)
top_k = getattr(config, "num_experts_per_tok", 1)
hidden_size = getattr(config, "hidden_size", 4096)
intermediate_size = getattr(config, "intermediate_size", 8192)
super().__init__(
num_experts=num_local_experts,
top_k=top_k,
hidden_size=hidden_size,
intermediate_size=intermediate_size,
up_proj_name="gate_up_proj",
is_gated=True,
down_proj_name="down_proj",
has_bias=False,
skip_bias_add=False,
renormalize=False,
hidden_act="silu",
params_dtype=None,
quant_config=quant_config,
is_use_fused_moe=True,
)
# Llama4 uses sigmoid routing, not softmax
# Override topk_softmax to use sigmoid
self._use_sigmoid_routing = True
# Shared expert (independent from routed experts)
self.shared_expert = FeedForward(
hidden_size=hidden_size,
intermediate_size=intermediate_size,
hidden_act="silu",
up_proj_name="gate_up_proj",
is_gated=True,
down_proj_name="down_proj",
bias=False,
quant_config=quant_config,
reduce_results=False,
prefix=f"{prefix}.shared_expert",
)
def topk_softmax(self, expert_logits):
"""Override: Llama4 uses sigmoid routing instead of softmax."""
topk_values, topk_indices = torch.topk(
expert_logits, self.top_k, dim=-1)
topk_values = torch.sigmoid(topk_values.float())
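# The sigmoid weights stay unnormalized across the selected experts,
# matching renormalize=False passed to SparseMoeMlp above.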
return topk_values, topk_indices
def forward(self, hidden_states, residual=None):
orig_shape = hidden_states.shape
hidden_states = hidden_states.view(-1, self.hidden_size)
# Shared expert output
shared_out = self.shared_expert(hidden_states)
# Router logits
router_logits, _ = self.gate(hidden_states)
# Routed experts
routed_out = self.forward_experts(hidden_states, router_logits, None)
# Combine
final_out = routed_out + shared_out
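# The shared expert was built with reduce_results=False, so the combined
# routed + shared output goes through a single tensor-parallel all-reduce below.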
if self.tp_size > 1:
final_out = tensor_model_parallel_all_reduce(final_out)
return final_out.view(orig_shape)
# ============================================================
# Llama4Attention hijack
# ============================================================
vllm__llama4__Llama4Attention__init__org = Llama4Attention.__init__
def vllm__llama4__Llama4Attention____init__(
self,
config,
hidden_size: int,
num_heads: int,
num_kv_heads: int,
max_position_embeddings: int = 8192,
quant_config: Optional[QuantizationConfig] = None,
bias: bool = False,
cache_config: Optional[CacheConfig] = None,
prefix: str = "",
) -> None:
vllm__llama4__Llama4Attention__init__org(
self, config, hidden_size, num_heads, num_kv_heads,
max_position_embeddings, quant_config, bias, cache_config, prefix)
'''
=============================
Modify by vllm_mlu
=============================
@brief: save rope_scaling for MLU RoPE dispatch
'''
self.rope_scaling = getattr(config, "rope_scaling", None)
'''
==================
End of MLU Hijack
==================
'''
def vllm__llama4__Llama4Attention__forward(
self,
positions: torch.Tensor,
hidden_states: torch.Tensor,
kv_cache: torch.Tensor,
attn_metadata: AttentionMetadata,
residual: Optional[torch.Tensor] = None,
smooth_quant_scale: Optional[torch.Tensor] = None,
) -> torch.Tensor:
qkv, _ = self.qkv_proj(hidden_states, smooth_quant_scale)
q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1)
'''
=============================
Modify by vllm_mlu
=============================
@brief: MLU RoPE: merge q/k, apply rotary, split back (lesson #3)
For NoPE layers (self.rotary_emb is None), skip RoPE entirely.
'''
if self.rotary_emb is not None:
if (self.rope_scaling is not None
and self.rope_scaling.get("rope_type") == "longrope"):
q, k = self.rotary_emb(positions, q, k)
else:
qk, _ = qkv.split(
[self.q_size + self.kv_size, self.kv_size], dim=-1)
self.rotary_emb(
positions,
qk.view(-1, self.num_heads + self.num_kv_heads,
self.head_dim))
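# The MLU rotary_emb is applied through the qk view (its return value is
# unused); the split below recovers the rotated q and k.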
q, k = qk.split([self.q_size, self.kv_size], dim=-1)
'''
==================
End of MLU Hijack
==================
'''
# QK norm (lesson #2: use contiguous + reshape)
if self.qk_norm is not None:
q = q.contiguous().reshape(-1, self.head_dim)
q = (self.qk_norm(q.float())
.contiguous().reshape(-1, self.q_size).to(q.dtype))
k = k.contiguous().reshape(-1, self.head_dim)
k = (self.qk_norm(k.float())
.contiguous().reshape(-1, self.kv_size).to(k.dtype))
# Temperature tuning for NoPE layers
if self.attn_temperature_tuning and self.nope:
attn_scale = self._get_attn_scale(positions)
q = (q * attn_scale).to(q.dtype)
attn_output = self.attn(q, k, v, kv_cache, attn_metadata)
'''
=============================
Modify by vllm_mlu
=============================
@brief: add residual in o_proj
'''
output, _ = self.o_proj(attn_output, residual)
'''
==================
End of MLU Hijack
==================
'''
return output
# ============================================================
# Llama4DecoderLayer hijack
# ============================================================
def vllm__llama4__Llama4DecoderLayer____init__(
self,
config,
cache_config: Optional[CacheConfig] = None,
quant_config: Optional[QuantizationConfig] = None,
prefix: str = "",
) -> None:
super(Llama4DecoderLayer, self).__init__()
from vllm.model_executor.models.llama4 import (
_extract_layer_index, Llama4Attention)
self.layer_idx = _extract_layer_index(prefix)
self.hidden_size = getattr(config, "hidden_size", 4096)
max_position_embeddings = getattr(
config, "max_position_embeddings", 8192)
self.self_attn = Llama4Attention(
config=config,
hidden_size=self.hidden_size,
num_heads=getattr(config, "num_attention_heads", 32),
num_kv_heads=getattr(config, "num_key_value_heads",
getattr(config, "num_attention_heads", 32)),
max_position_embeddings=max_position_embeddings,
quant_config=quant_config,
bias=False,
cache_config=cache_config,
prefix=f"{prefix}.self_attn",
)
interleave_moe_layer_step = getattr(
config, "interleave_moe_layer_step", 0)
is_moe_layer = (interleave_moe_layer_step > 0
and (self.layer_idx + 1)
% interleave_moe_layer_step == 0)
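# Example: with interleave_moe_layer_step == 2, layers 1, 3, 5, ... (0-indexed)
# get the MoE block; a step of 0 keeps every layer dense.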
'''
=============================
Modify by vllm_mlu
=============================
@brief: Replace MoE with Llama4MoEMlu (SparseMoeMlp + shared expert),
Replace dense MLP with FeedForward.
'''
if is_moe_layer:
self.feed_forward = Llama4MoEMlu(
config=config,
quant_config=quant_config,
prefix=f"{prefix}.feed_forward",
)
else:
intermediate_size_mlp = getattr(
config, "intermediate_size_mlp",
getattr(config, "intermediate_size", 8192))
self.feed_forward = FeedForward(
hidden_size=self.hidden_size,
intermediate_size=intermediate_size_mlp,
hidden_act="silu",
up_proj_name="gate_up_proj",
is_gated=True,
down_proj_name="down_proj",
bias=False,
quant_config=quant_config,
prefix=f"{prefix}.feed_forward",
)
'''
==================
End of MLU Hijack
==================
'''
rms_norm_eps = getattr(config, "rms_norm_eps", 1e-5)
self.input_layernorm = RMSNorm(self.hidden_size, eps=rms_norm_eps)
self.post_attention_layernorm = RMSNorm(
self.hidden_size, eps=rms_norm_eps)
self.is_per_tensor_sq_perf_cases = is_per_tensor_smoothquant(
quant_config)
self.is_per_token_sq_perf_cases = is_per_token_smoothquant(
quant_config)
if self.is_per_tensor_sq_perf_cases or self.is_per_token_sq_perf_cases:
self.self_attn.qkv_proj.quant_method.skip_quant_input = True
self.quant_fusion_attn_layernorm = None
def vllm__llama4__Llama4DecoderLayer__forward(
self,
positions: torch.Tensor,
hidden_states: torch.Tensor,
kv_cache: torch.Tensor,
attn_metadata: AttentionMetadata,
) -> torch.Tensor:
'''
=============================
Modify by vllm_mlu
=============================
@brief: use decoder_layer_forward_base with residual-in-matmul
and optional quant fusion.
'''
attn_layernorm = self.input_layernorm
if self.is_per_tensor_sq_perf_cases or self.is_per_token_sq_perf_cases:
if self.quant_fusion_attn_layernorm is None:
if self.is_per_token_sq_perf_cases:
attn_quant_scale = self.self_attn.qkv_proj.smooth
else:
attn_quant_scale = self.self_attn.qkv_proj.scale_to_int
self.quant_fusion_attn_layernorm = quant_fusion_with_rmsnorm(
self.input_layernorm, attn_quant_scale,
dynamic_quant=self.is_per_token_sq_perf_cases)
attn_layernorm = self.quant_fusion_attn_layernorm
return decoder_layer_forward_base(
positions, hidden_states, kv_cache, attn_metadata,
attn_layernorm,
self.self_attn,
self.post_attention_layernorm,
self.feed_forward,
input_norm_fuse_en=self.is_per_token_sq_perf_cases)
# ============================================================
# Llama4Model hijack
# ============================================================
def vllm__llama4__Llama4Model__forward(
self,
input_ids: Optional[torch.Tensor],
positions: torch.Tensor,
kv_caches: List[torch.Tensor],
attn_metadata: AttentionMetadata,
intermediate_tensors: Optional[IntermediateTensors],
inputs_embeds: Optional[torch.Tensor] = None,
) -> Union[torch.Tensor, IntermediateTensors]:
return decoder_model_forward_base_pp(
input_ids, positions, kv_caches, attn_metadata,
intermediate_tensors,
self.layers, self.start_layer, self.end_layer,
self.get_input_embeddings,
self.norm,
inputs_embeds)
# ============================================================
# Llama4ForCausalLM load_weights hijack
# ============================================================
def vllm__llama4__Llama4ForCausalLM__load_weights(
self,
weights: Iterable[Tuple[str, torch.Tensor]],
):
'''
=============================
Modify by vllm_mlu
=============================
@brief: pack params for SparseMoeMlp (MoE layers)
'''
start_expert_id = 0  # expert-id offset; the remapping below only applies when > 0
for name, m in self.model.named_modules():
if isinstance(m, SparseMoeMlp):
m.pack_params()
'''
==================
End of MLU Hijack
==================
'''
stacked_params_mapping = [
# (param_name, shard_name, shard_id)
("qkv_proj", "q_proj", "q"),
("qkv_proj", "k_proj", "k"),
("qkv_proj", "v_proj", "v"),
("gate_up_proj", "gate_proj", 0),
("gate_up_proj", "up_proj", 1),
]
params_dict = dict(self.named_parameters())
for name, loaded_weight in weights:
if "rotary_emb.inv_freq" in name:
continue
# Permute Q/K weights for rotary embedding
name, loaded_weight = self.permute_qk_weight_for_rotary(
name, loaded_weight)
'''
=============================
Modify by vllm_mlu
=============================
@brief: remap expert_id for distributed inference
'''
if (start_expert_id > 0
and "feed_forward.experts." in name):
match = re.search(r'experts\.\d+', name)
if match:
expert_str = match.group(0)
expert_id = int(expert_str.split(".")[1])
named_expert_id = expert_id - start_expert_id
name = name.replace(
f"experts.{expert_id}",
f"experts.{named_expert_id}")
'''
==================
End of MLU Hijack
==================
'''
for (param_name, weight_name, shard_id) in stacked_params_mapping:
if weight_name not in name:
continue
name = name.replace(weight_name, param_name)
if ((name.endswith(".bias") or name.endswith("_bias"))
and name not in params_dict):
continue
if is_pp_missing_parameter(name, self):
continue
# Skip experts not assigned to this worker
if ("feed_forward.experts." in name
and name not in params_dict):
continue
param = params_dict[name]
weight_loader = param.weight_loader
weight_loader(param, loaded_weight, shard_id)
break
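# for/else: this branch runs only when no stacked-param mapping matched the name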
else:
if ((name.endswith(".bias") or name.endswith("_bias"))
and name not in params_dict):
continue
if is_pp_missing_parameter(name, self):
continue
name = maybe_remap_kv_scale_name(name, params_dict)
if name is None:
continue
# Skip experts not assigned to this worker
if ("feed_forward.experts." in name
and name not in params_dict):
continue
if name not in params_dict:
logger.warning(
"Skipping weight %s not present in the model", name)
continue
param = params_dict[name]
weight_loader = getattr(
param, "weight_loader", default_weight_loader)
weight_loader(param, loaded_weight)
# ============================================================
# Apply all hijacks
# ============================================================
MluHijackObject.apply_hijack(
Llama4Attention,
Llama4Attention.__init__,
vllm__llama4__Llama4Attention____init__)
MluHijackObject.apply_hijack(
Llama4Attention,
Llama4Attention.forward,
vllm__llama4__Llama4Attention__forward)
MluHijackObject.apply_hijack(
Llama4DecoderLayer,
Llama4DecoderLayer.__init__,
vllm__llama4__Llama4DecoderLayer____init__)
MluHijackObject.apply_hijack(
Llama4DecoderLayer,
Llama4DecoderLayer.forward,
vllm__llama4__Llama4DecoderLayer__forward)
MluHijackObject.apply_hijack(
Llama4Model,
Llama4Model.forward,
vllm__llama4__Llama4Model__forward)
MluHijackObject.apply_hijack(
Llama4ForCausalLM,
Llama4ForCausalLM.load_weights,
vllm__llama4__Llama4ForCausalLM__load_weights)