[Model][3/N] Refactor sfa into mla and remove deepseek_v3_2.py (#3769)

This is the follow-up PR to PR #3189, which continues to refactor sfa
into mla and finally remove deepseek_v3_2.py. This is the last PR of
deepseek modeling refactoring. After this, all deepseek-related model
codes are removed from vllm_ascend.

FurtherMore, after this PR deepseek v3.2 can run chunk-prefill with
correct accuracy.

- vLLM version: v0.11.0rc3
- vLLM main:
83f478bb19

---------

Signed-off-by: whx-sjtu <2952154980@qq.com>
This commit is contained in:
whx
2025-10-30 17:06:38 +08:00
committed by GitHub
parent eff3e5fc6f
commit f6149f3894
10 changed files with 751 additions and 1935 deletions

View File

@@ -52,6 +52,8 @@ if prefill_context_parallel_enable():
if TYPE_CHECKING:
from vllm.v1.core.sched.output import SchedulerOutput
MAX_O_PROJ_PREFETCH_SIZE = 16 * 1024 * 1024
class AscendMLABackend(AttentionBackend):
@@ -808,16 +810,17 @@ class AscendMLAImpl(MLAAttentionImpl):
# self.W_UV.data = torch_npu.npu_format_cast(self.W_UV.data, 29)
# self.W_UK_T.data = torch_npu.npu_format_cast(self.W_UK_T.data, 29)
# Currently mlapo only supports W8A8 quantization in MLA scenario
# TODO(whx): modify this limitation when mlapo supports floating point
if self.fused_qkv_a_proj is None or not isinstance(
getattr(self.fused_qkv_a_proj.quant_method, 'quant_method',
None), AscendW8A8LinearMethod):
self.enable_mlapo = False
logger.warning_once(
"Currently mlapo only supports W8A8 quantization in MLA scenario."
"Some layers in your model are not quantized with W8A8,"
"thus mlapo is disabled for these layers.")
if self.enable_mlapo:
# Currently mlapo only supports W8A8 quantization in MLA scenario
# TODO(whx): modify this limitation when mlapo supports floating point
if self.fused_qkv_a_proj is None or not isinstance(
getattr(self.fused_qkv_a_proj.quant_method, 'quant_method',
None), AscendW8A8LinearMethod):
self.enable_mlapo = False
logger.warning_once(
"Currently mlapo only supports W8A8 quantization in MLA scenario."
"Some layers in your model are not quantized with W8A8,"
"thus mlapo is disabled for these layers.")
if self.enable_mlapo:
self._process_weights_for_fused_mlapo(act_dtype)
@@ -1282,12 +1285,13 @@ class AscendMLAImpl(MLAAttentionImpl):
def _mla_preprocess(self, layer_name, hidden_states, kv_cache,
attn_metadata, need_gather_q_kv):
# MLA Preprocess:
# 1. Perform q_a_proj and q_a_layernorm to obtain q_c
# 2. Perform kv_a_proj_with_mqa to obtain kv_no_split
# 3. If need_gather_q_kv, perform all_gather.
# 4. Preprocess decode tokens, write kv cache and get:
# 1. Perform fused_qkv_a_proj and q_a_layernorm to obtain q_c and kv_no_split
# or
# Perform kv_a_proj_with_mqa to obtain kv_no_split
# 2. If need_gather_q_kv, perform all_gather.
# 3. Preprocess decode tokens, write kv cache and get:
# decode_ql_nope, decode_q_pe, decode_k_pe, decode_k_nope
# 5. Preprocess prefill tokens, write kv cache and get:
# 4. Preprocess prefill tokens, write kv cache and get:
# prefill_q_nope, prefill_q_pe, prefill_k_nope, prefill_k_pe, prefill_value
has_decode = attn_metadata.num_decodes > 0
has_prefill = attn_metadata.num_prefills > 0

File diff suppressed because it is too large Load Diff

View File

@@ -29,10 +29,6 @@ def register_model():
"vllm_ascend.models.qwen2_5_vl_without_padding:AscendQwen2_5_VLForConditionalGeneration_Without_Padding"
)
ModelRegistry.register_model(
"DeepseekV32ForCausalLM",
"vllm_ascend.models.deepseek_v3_2:CustomDeepseekV3ForCausalLM")
# There is no PanguProMoEForCausalLM in vLLM, so we should register it before vLLM config initialization
# to make sure the model can be loaded correctly. This register step can be removed once vLLM support PanguProMoEForCausalLM.
ModelRegistry.register_model(

View File

@@ -1,658 +0,0 @@
# SPDX-License-Identifier: Apache-2.0
# Copyright (c) 2025 Huawei Technologies Co., Ltd. All Rights Reserved.
# Copyright 2023 The vLLM team.
# Copyright 2023 DeepSeek-AI and the HuggingFace Inc. team. All rights reserved.
#
# This code is based on EleutherAI's GPT-NeoX library and the GPT-NeoX
# and OPT implementations in this library. It has been modified from its
# original forms to accommodate minor architectural differences compared
# to GPT-NeoX and OPT used by the Meta AI team that trained the model.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# # Adapted from
# # vllm-project/vllm/blob/main/vllm/model_executor/models/deepseek_v2.py
# # https://github.com/huggingface/transformers/blob/v4.28.0/src/transformers/models/llama/modeling_llama.py
# # vllm-project/vllm/vllm/model_executor/models/deepseek_v2.py
# """Inference-only DeepseekV2/DeepseekV3 model."""
from typing import Any, Dict, Iterable, Optional, Union
import torch
from torch import nn
from transformers import PretrainedConfig
from vllm.attention import AttentionMetadata
from vllm.compilation.decorators import support_torch_compile
from vllm.config import CacheConfig, VllmConfig
from vllm.distributed import (divide, get_pp_group,
get_tensor_model_parallel_rank,
get_tensor_model_parallel_world_size,
get_tp_group, split_tensor_along_last_dim,
tensor_model_parallel_all_reduce)
from vllm.model_executor.layers.fused_moe import FusedMoE
from vllm.model_executor.layers.layernorm import RMSNorm
from vllm.model_executor.layers.linear import (WEIGHT_LOADER_V2_SUPPORTED,
ColumnParallelLinear,
MergedColumnParallelLinear,
ReplicatedLinear,
RowParallelLinear)
from vllm.model_executor.layers.logits_processor import LogitsProcessor
from vllm.model_executor.layers.quantization import QuantizationConfig
from vllm.model_executor.layers.rotary_embedding import get_rope
from vllm.model_executor.layers.vocab_parallel_embedding import (
ParallelLMHead, VocabParallelEmbedding)
from vllm.model_executor.model_loader.weight_utils import (
default_weight_loader, maybe_remap_kv_scale_name)
from vllm.model_executor.models.deepseek_v2 import \
yarn_get_mscale # noqa: E501
from vllm.model_executor.models.deepseek_v2 import (
DeepseekV2Attention, DeepseekV2DecoderLayer, DeepseekV2ForCausalLM,
DeepseekV2MLAAttention, DeepseekV2MLP, DeepseekV2Model, DeepseekV2MoE,
get_spec_layer_idx_from_weight_name)
from vllm.model_executor.models.utils import (
PPMissingLayer, is_pp_missing_parameter,
make_empty_intermediate_tensors_factory, make_layers, maybe_prefix)
from vllm.model_executor.utils import set_weight_attrs
from vllm.platforms import current_platform
from vllm_ascend.ascend_config import get_ascend_config
from vllm_ascend.models.layers.sfa import AscendSFAModules, Indexer
from vllm_ascend.ops.fused_moe.fused_moe import AscendFusedMoE
from vllm_ascend.ops.linear import AscendLinearBase
from vllm_ascend.utils import vllm_version_is
if vllm_version_is("0.11.0"):
from vllm.model_executor.layers.mla import MultiHeadLatentAttention
else:
from vllm.model_executor.layers.mla import MultiHeadLatentAttentionWrapper
@support_torch_compile
class AscendDeepseekV2Model(DeepseekV2Model, nn.Module):
def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
# Rewrite this init func mainly for removing cuda-hard code
nn.Module.__init__(self)
config = vllm_config.model_config.hf_config
quant_config = vllm_config.quant_config
self.config = config
self.vocab_size = config.vocab_size
assert hasattr(config, "index_topk")
topk_tokens = config.index_topk
topk_indices_buffer = torch.empty(
vllm_config.scheduler_config.max_num_batched_tokens,
topk_tokens,
dtype=torch.int32,
device=current_platform.device_type)
if get_pp_group().is_first_rank:
self.embed_tokens = VocabParallelEmbedding(
config.vocab_size,
config.hidden_size,
quant_config=quant_config,
prefix=f"{prefix}.embed_tokens")
else:
self.embed_tokens = PPMissingLayer()
self.start_layer, self.end_layer, self.layers = make_layers(
config.num_hidden_layers,
lambda prefix: DeepseekV2DecoderLayer(vllm_config, prefix,
topk_indices_buffer),
prefix=f"{prefix}.layers")
if get_pp_group().is_last_rank:
self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
else:
self.norm = PPMissingLayer()
self.make_empty_intermediate_tensors = (
make_empty_intermediate_tensors_factory(
["hidden_states", "residual"], config.hidden_size))
class CustomDeepseekV2RowParallelLinear(RowParallelLinear):
def __init__(
self,
input_size: int,
output_size: int,
bias: bool = True,
input_is_parallel: bool = True,
skip_bias_add: bool = False,
params_dtype: Optional[torch.dtype] = None,
reduce_results: bool = True,
quant_config: Optional[QuantizationConfig] = None,
prefix: str = "",
*,
return_bias: bool = True,
disable_tp: bool = False,
):
# Divide the weight matrix along the first dimension.
self.tp_rank = (get_tensor_model_parallel_rank()
if not disable_tp else 0)
self.tp_size = (get_tensor_model_parallel_world_size()
if not disable_tp else 1)
self.input_size_per_partition = divide(input_size, self.tp_size)
self.output_size_per_partition = output_size
self.output_partition_sizes = [output_size]
AscendLinearBase.__init__(self,
input_size,
output_size,
skip_bias_add,
params_dtype,
quant_config,
prefix,
return_bias=return_bias,
disable_tp=disable_tp)
self.input_is_parallel = input_is_parallel
self.reduce_results = reduce_results
assert self.quant_method is not None
self.quant_method.create_weights(
layer=self,
input_size_per_partition=self.input_size_per_partition,
output_partition_sizes=self.output_partition_sizes,
input_size=self.input_size,
output_size=self.output_size,
params_dtype=self.params_dtype,
weight_loader=(
self.weight_loader_v2 if self.quant_method.__class__.__name__
in WEIGHT_LOADER_V2_SUPPORTED else self.weight_loader))
if not reduce_results and (bias and not skip_bias_add):
raise ValueError("When not reduce the results, adding bias to the "
"results can lead to incorrect results")
if bias:
self.bias = nn.Parameter(
torch.empty(self.output_size, dtype=params_dtype))
set_weight_attrs(self.bias, {
"output_dim": 0,
"weight_loader": self.weight_loader,
})
else:
self.register_parameter("bias", None)
self.update_param_tp_status()
def forward(
self,
input_,
is_prefill=True,
is_force_scatter=False
) -> Union[torch.Tensor, tuple[torch.Tensor, Optional[nn.Parameter]]]:
if self.input_is_parallel:
input_parallel = input_
else:
tp_rank = get_tensor_model_parallel_rank()
splitted_input = split_tensor_along_last_dim(
input_, num_partitions=self.tp_size)
input_parallel = splitted_input[tp_rank].contiguous()
# Matrix multiply.
assert self.quant_method is not None
# Only fuse bias add into GEMM for rank 0 (this ensures that
# bias will not get added more than once in TP>1 case)
bias_ = None if (self.tp_rank > 0 or self.skip_bias_add) else self.bias
output_parallel = self.quant_method.apply(self,
input_parallel,
bias=bias_)
if self.reduce_results and self.tp_size > 1:
output = tensor_model_parallel_all_reduce(output_parallel)
else:
output = output_parallel
output_bias = self.bias if self.skip_bias_add else None
if not self.return_bias:
return output
return output, output_bias
class CustomDeepseekV2SFAAttention(DeepseekV2MLAAttention):
def __init__(
self,
config: PretrainedConfig,
hidden_size: int,
num_heads: int,
qk_nope_head_dim: int,
qk_rope_head_dim: int,
v_head_dim: int,
q_lora_rank: Optional[int],
kv_lora_rank: int,
rope_theta: float = 10000,
rope_scaling: Optional[Dict[str, Any]] = None,
max_position_embeddings: int = 8192,
cache_config: Optional[CacheConfig] = None,
quant_config: Optional[QuantizationConfig] = None,
prefix: str = "",
) -> None:
nn.Module.__init__(self)
self.hidden_size = hidden_size
self.qk_nope_head_dim = qk_nope_head_dim
self.qk_rope_head_dim = qk_rope_head_dim
self.qk_head_dim = qk_nope_head_dim + qk_rope_head_dim
self.v_head_dim = v_head_dim
self.q_lora_rank = q_lora_rank
self.kv_lora_rank = kv_lora_rank
self.num_heads = num_heads
self.tp_size = get_tensor_model_parallel_world_size()
assert num_heads % self.tp_size == 0
self.num_local_heads = num_heads // self.tp_size
self.layers = config.num_hidden_layers
self.first_k_dense_replace = config.first_k_dense_replace
self.scaling = self.qk_head_dim**-0.5
self.rope_theta = rope_theta
self.max_position_embeddings = max_position_embeddings
self.prefix = prefix
self.debug_layer_idx = int(self.prefix.split(".")[-2])
ascend_config = get_ascend_config()
self.enable_shared_expert_dp = ascend_config.enable_shared_expert_dp
if self.q_lora_rank is not None:
self.q_a_layernorm = RMSNorm(self.q_lora_rank,
eps=config.rms_norm_eps)
self.q_b_proj = ColumnParallelLinear(
q_lora_rank,
self.num_heads * self.qk_head_dim,
bias=False,
quant_config=quant_config,
prefix=f"{prefix}.q_b_proj",
return_bias=False,
)
else:
self.q_proj = ColumnParallelLinear(
self.hidden_size,
self.num_heads * self.qk_head_dim,
bias=False,
quant_config=quant_config,
prefix=f"{prefix}.q_proj",
return_bias=False,
)
self.kv_a_layernorm = RMSNorm(self.kv_lora_rank,
eps=config.rms_norm_eps)
self.kv_b_proj = ColumnParallelLinear(
self.kv_lora_rank,
self.num_heads * (self.qk_nope_head_dim + self.v_head_dim),
bias=False,
quant_config=quant_config,
prefix=f"{prefix}.kv_b_proj",
return_bias=False,
)
self.o_proj = CustomDeepseekV2RowParallelLinear(
self.num_heads * self.v_head_dim,
self.hidden_size,
bias=False,
quant_config=quant_config,
prefix=f"{prefix}.o_proj",
return_bias=False,
)
if self.q_lora_rank is not None:
self.fused_qkv_a_proj = MergedColumnParallelLinear(
self.hidden_size,
[self.q_lora_rank, self.kv_lora_rank + self.qk_rope_head_dim],
bias=False,
quant_config=quant_config,
prefix=f"{prefix}.fused_qkv_a_proj",
disable_tp=True)
self.kv_a_proj_with_mqa = None
else:
self.kv_a_proj_with_mqa = ReplicatedLinear(
self.hidden_size,
self.kv_lora_rank + self.qk_rope_head_dim,
bias=False,
quant_config=quant_config,
prefix=f"{prefix}.kv_a_proj_with_mqa")
if rope_scaling:
rope_scaling["rope_type"] = 'deepseek_yarn'
self.rotary_emb = get_rope(qk_rope_head_dim,
rotary_dim=qk_rope_head_dim,
max_position=max_position_embeddings,
base=rope_theta,
rope_scaling=rope_scaling,
is_neox_style=False)
if rope_scaling:
mscale_all_dim = rope_scaling.get("mscale_all_dim", False)
scaling_factor = rope_scaling["factor"]
mscale = yarn_get_mscale(scaling_factor, float(mscale_all_dim))
self.scaling = self.scaling * mscale * mscale
self.dim: int = config.hidden_size # 7168
# TODO(zzzzwwjj): wait transformers add these params
self.n_heads: int = 64 # 64
self.head_dim: int = 128 # 128
self.index_topk: int = 2048 # 2048
self.indexer = Indexer(
config,
quant_config=quant_config,
dim=self.dim,
n_heads=self.n_heads,
head_dim=self.head_dim,
index_topk=self.index_topk,
prefix=f"{prefix}.indexer",
)
sfa_modules = AscendSFAModules(
q_a_layernorm=self.q_a_layernorm
if self.q_lora_rank is not None else None,
q_proj=self.q_proj if self.q_lora_rank is None else self.q_b_proj,
q_b_proj=self.q_b_proj if self.q_lora_rank is not None else None,
kv_a_proj_with_mqa=self.kv_a_proj_with_mqa,
fused_qkv_a_proj=self.fused_qkv_a_proj
if self.q_lora_rank is not None else None,
kv_a_layernorm=self.kv_a_layernorm,
kv_b_proj=self.kv_b_proj,
o_proj=self.o_proj,
rotary_emb=self.rotary_emb,
indexer=self.indexer,
is_sparse=hasattr(config, "index_topk"),
topk_indices_buffer=None)
if vllm_version_is("0.11.0"):
self.sfa_attn = MultiHeadLatentAttention(
hidden_size=self.hidden_size,
num_heads=self.num_local_heads,
scale=self.scaling,
qk_nope_head_dim=self.qk_nope_head_dim,
qk_rope_head_dim=self.qk_rope_head_dim,
v_head_dim=self.v_head_dim,
q_lora_rank=self.q_lora_rank,
kv_lora_rank=self.kv_lora_rank,
mla_modules=sfa_modules,
cache_config=cache_config,
quant_config=quant_config,
prefix=prefix,
)
else:
self.sfa_attn = MultiHeadLatentAttentionWrapper(
hidden_size=self.hidden_size,
num_heads=self.num_local_heads,
scale=self.scaling,
qk_nope_head_dim=self.qk_nope_head_dim,
qk_rope_head_dim=self.qk_rope_head_dim,
v_head_dim=self.v_head_dim,
q_lora_rank=self.q_lora_rank,
kv_lora_rank=self.kv_lora_rank,
mla_modules=sfa_modules,
cache_config=cache_config,
quant_config=quant_config,
prefix=prefix,
)
self.prefix = prefix
def forward(
self,
positions: torch.Tensor,
hidden_states: torch.Tensor,
kv_cache: Optional[torch.Tensor] = None,
attn_metadata: Optional[AttentionMetadata] = None) -> torch.Tensor:
return self.sfa_attn(positions, hidden_states, kv_cache, attn_metadata)
class CustomDeepseekV2DecoderLayer(DeepseekV2DecoderLayer):
def __init__(self,
vllm_config: VllmConfig,
prefix: str,
topk_indices_buffer=None) -> None:
nn.Module.__init__(self)
config = vllm_config.model_config.hf_config
model_config = vllm_config.model_config
cache_config = vllm_config.cache_config
quant_config = vllm_config.quant_config
parallel_config = vllm_config.parallel_config
self.hidden_size = config.hidden_size
rope_theta = getattr(config, "rope_theta", 10000)
rope_scaling = getattr(config, "rope_scaling", None)
max_position_embeddings = getattr(config, "max_position_embeddings",
8192)
# DecoderLayers are created with `make_layers` which passes the prefix
# with the layer's index.
layer_idx = int(prefix.split(sep='.')[-1])
self.layer_idx = layer_idx
self.layers = config.num_hidden_layers
self.tp_size = get_tensor_model_parallel_world_size()
self.tp_rank = get_tp_group().rank_in_group
# TODO: enable mla in vllm-ascend
if model_config.use_mla:
attn_cls = CustomDeepseekV2SFAAttention
else:
attn_cls = DeepseekV2Attention
self.self_attn = attn_cls(
config=config,
hidden_size=self.hidden_size,
num_heads=config.num_attention_heads,
qk_nope_head_dim=config.qk_nope_head_dim,
qk_rope_head_dim=config.qk_rope_head_dim,
v_head_dim=config.v_head_dim,
q_lora_rank=config.q_lora_rank
if hasattr(config, "q_lora_rank") else None,
kv_lora_rank=config.kv_lora_rank,
rope_theta=rope_theta,
rope_scaling=rope_scaling,
max_position_embeddings=max_position_embeddings,
cache_config=cache_config,
quant_config=quant_config,
prefix=f"{prefix}.self_attn",
)
if (config.n_routed_experts is not None
and layer_idx >= config.first_k_dense_replace
and layer_idx % config.moe_layer_freq == 0):
self.mlp = DeepseekV2MoE(
config=config,
parallel_config=parallel_config,
quant_config=quant_config,
prefix=f"{prefix}.mlp",
)
if self.mlp.gate.e_score_correction_bias is not None:
self.mlp.gate.e_score_correction_bias.data = (
self.mlp.gate.e_score_correction_bias.data.to(
dtype=torch.get_default_dtype()))
else:
self.mlp = DeepseekV2MLP(
hidden_size=config.hidden_size,
intermediate_size=config.intermediate_size,
hidden_act=config.hidden_act,
quant_config=quant_config,
prefix=f"{prefix}.mlp",
)
self.input_layernorm = RMSNorm(config.hidden_size,
eps=config.rms_norm_eps)
self.post_attention_layernorm = RMSNorm(config.hidden_size,
eps=config.rms_norm_eps)
self.routed_scaling_factor = config.routed_scaling_factor
self.first_k_dense_replace = config.first_k_dense_replace
self.tp_group = get_tp_group().device_group
class CustomDeepseekV2ForCausalLM(DeepseekV2ForCausalLM):
def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
nn.Module.__init__(self)
config = vllm_config.model_config.hf_config
quant_config = vllm_config.quant_config
self.config = config
self.quant_config = quant_config
# `packed_modules_mapping` needs to be modified before
# initializing DeepseekV2Model, as it is passed inplace to
# quantization config init and may be used to select the
# quant_method for relevant layers during initialization.
self.fuse_qkv_a_proj = hasattr(
config, "q_lora_rank") and config.q_lora_rank is not None
if self.fuse_qkv_a_proj:
self.packed_modules_mapping["fused_qkv_a_proj"] = [
"q_a_proj",
"kv_a_proj_with_mqa",
]
self.model = AscendDeepseekV2Model(vllm_config=vllm_config,
prefix=maybe_prefix(
prefix, "model"))
if get_pp_group().is_last_rank:
self.lm_head = ParallelLMHead(config.vocab_size,
config.hidden_size,
quant_config=quant_config,
prefix=maybe_prefix(
prefix, "lm_head"))
else:
self.lm_head = PPMissingLayer()
self.logits_processor = LogitsProcessor(config.vocab_size)
self.make_empty_intermediate_tensors = (
self.model.make_empty_intermediate_tensors)
self.expert_weights: list[Any] = []
# Set MoE hyperparameters
self.num_moe_layers = (config.num_hidden_layers -
config.first_k_dense_replace)
self.num_expert_groups = config.n_group
self.moe_layers: list[FusedMoE] = []
example_moe = None
for layer in self.model.layers:
if isinstance(layer, PPMissingLayer):
continue
assert isinstance(layer, DeepseekV2DecoderLayer)
if isinstance(layer.mlp, DeepseekV2MoE):
# Pick last one layer since the first ones may be dense layers.
example_moe = layer.mlp
self.moe_layers.append(layer.mlp.experts)
if example_moe is None:
raise RuntimeError("No DeepseekV2MoE layer found in model.layers.")
self.num_logical_experts = example_moe.n_logical_experts
self.num_physical_experts = example_moe.n_physical_experts
self.num_local_physical_experts = example_moe.n_local_physical_experts
self.num_routed_experts = example_moe.n_routed_experts
self.num_shared_experts = example_moe.n_shared_experts
self.num_redundant_experts = example_moe.n_redundant_experts
# NOTE: This `load_weights` is mainly copied from
# https://github.com/vllm-project/vllm/commit/07b8fae219b1fff51ef115c38c44b51395be5bb5
# to fix CI, and it is different from the implementation in main
# TODO: support eplb style load_weights
def load_weights(self, weights: Iterable[tuple[str,
torch.Tensor]]) -> set[str]:
""""""
stacked_params_mapping = [
# (param_name, shard_name, shard_id)
("gate_up_proj", "gate_proj", 0),
("gate_up_proj", "up_proj", 1),
("fused_qkv_a_proj", "q_a_proj", 0),
("fused_qkv_a_proj", "kv_a_proj_with_mqa", 1),
]
# Params for weights, fp8 weight scales, fp8 activation scales
# (param_name, weight_name, expert_id, shard_id)
expert_params_mapping = AscendFusedMoE.make_expert_params_mapping(
ckpt_gate_proj_name="gate_proj",
ckpt_down_proj_name="down_proj",
ckpt_up_proj_name="up_proj",
num_experts=self.config.n_routed_experts)
params_dict = dict(self.named_parameters())
loaded_params: set[str] = set()
for name, loaded_weight in weights:
if "rotary_emb.inv_freq" in name:
continue
if "module" in name:
continue
spec_layer = get_spec_layer_idx_from_weight_name(self.config, name)
if spec_layer is not None:
continue # skip spec decode layers for main model
for (param_name, weight_name, shard_id) in stacked_params_mapping:
# Skip non-stacked layers and experts (experts handled below).
if weight_name not in name:
continue
# We have mlp.experts[0].gate_proj in the checkpoint.
# Since we handle the experts below in expert_params_mapping,
# we need to skip here BEFORE we update the name, otherwise
# name will be updated to mlp.experts[0].gate_up_proj, which
# will then be updated below in expert_params_mapping
# for mlp.experts[0].gate_gate_up_proj, which breaks load.
if (("mlp.experts." in name) and name not in params_dict):
continue
name = name.replace(weight_name, param_name)
# Skip loading extra bias for GPTQ models.
if name.endswith(".bias") and name not in params_dict:
continue
if is_pp_missing_parameter(name, self):
continue
param = params_dict[name]
weight_loader = param.weight_loader
weight_loader(param, loaded_weight, shard_id)
break
else:
for mapping in expert_params_mapping:
param_name, weight_name, expert_id, shard_id = mapping
if weight_name not in name:
continue
name = name.replace(weight_name, param_name)
if is_pp_missing_parameter(name, self):
continue
param = params_dict[name]
weight_loader = param.weight_loader
weight_loader(param,
loaded_weight,
name,
shard_id=shard_id,
expert_id=expert_id,
return_success=False)
break
else:
# Skip loading extra bias for GPTQ models.
if name.endswith(".bias") and name not in params_dict:
continue
# Remapping the name of FP8 kv-scale.
name = maybe_remap_kv_scale_name(name, params_dict)
if name is None:
continue
if is_pp_missing_parameter(name, self):
continue
param = params_dict[name]
weight_loader = getattr(param, "weight_loader",
default_weight_loader)
weight_loader(param, loaded_weight)
loaded_params.add(name)
return loaded_params
class CustomDeepseekV3ForCausalLM(CustomDeepseekV2ForCausalLM):
pass
DeepseekV2DecoderLayer.__init__ = CustomDeepseekV2DecoderLayer.__init__

View File

@@ -52,6 +52,35 @@ else:
from vllm.model_executor.layers.mla import MultiHeadLatentAttentionWrapper
class IndexerWrapper(nn.Module):
'''
A wrapper of Indexer for Deepseek v3.2.
This wrapper is currently used to solve the fp8 hard code issue of vllm's deepseek_v2.py.
It wraps the original Indexer, inherits its module weights
(including wq_b, wk, weights_proj, k_norm)
while deletes the unused topk_indices_buffer and k_cache to save memory.
TODO: Will be removed once original Indexer supports different quantization methods.
'''
def __init__(self, vllm_indexer: nn.Module) -> None:
super().__init__()
self.n_head: int = vllm_indexer.n_head # 64
self.head_dim: int = vllm_indexer.head_dim # 128
self.topk_tokens: int = vllm_indexer.topk_tokens # 2048
self.q_lora_rank: int = vllm_indexer.q_lora_rank # 1536
self.wq_b = vllm_indexer.wq_b
self.wk = vllm_indexer.wk
self.weights_proj = vllm_indexer.weights_proj
self.k_norm = vllm_indexer.k_norm
self.softmax_scale = vllm_indexer.softmax_scale
vllm_indexer.topk_indices_buffer = None # delete topk_indices_buffer
vllm_indexer.k_cache = None # delete k_cache
def forward(self):
return
# TODO(whx): adapt v0.11.0 and DSA
class AscendMultiHeadLatentAttention(MultiHeadLatentAttentionWrapper):
@@ -86,6 +115,10 @@ class AscendMultiHeadLatentAttention(MultiHeadLatentAttentionWrapper):
self.first_k_dense_replace = hf_config.first_k_dense_replace
self.tp_size = get_tensor_model_parallel_world_size()
self.layers = hf_config.num_hidden_layers
if mla_modules.indexer is not None:
ascend_indexer = IndexerWrapper(mla_modules.indexer)
else:
ascend_indexer = None
if vllm_version_is("0.11.0"):
self.mla_attn = Attention(
@@ -97,6 +130,8 @@ class AscendMultiHeadLatentAttention(MultiHeadLatentAttentionWrapper):
quant_config=quant_config,
prefix=f"{prefix}.attn",
use_mla=True,
indexer=ascend_indexer,
use_sparse=mla_modules.is_sparse,
# MLA Args
q_lora_rank=self.q_lora_rank,
kv_lora_rank=self.kv_lora_rank,
@@ -128,7 +163,7 @@ class AscendMultiHeadLatentAttention(MultiHeadLatentAttentionWrapper):
quant_config=quant_config,
prefix=f"{prefix}.attn",
use_sparse=mla_modules.is_sparse,
indexer=mla_modules.indexer,
indexer=ascend_indexer,
# extra args
rotary_emb=mla_modules.rotary_emb,
fused_qkv_a_proj=mla_modules.fused_qkv_a_proj,

View File

@@ -1,275 +0,0 @@
# SPDX-License-Identifier: Apache-2.0
# Copyright (c) 2025 Huawei Technologies Co., Ltd. All Rights Reserved.
# Copyright 2023 The vLLM team.
# Copyright 2023 DeepSeek-AI and the HuggingFace Inc. team. All rights reserved.
#
# This code is based on EleutherAI's GPT-NeoX library and the GPT-NeoX
# and OPT implementations in this library. It has been modified from its
# original forms to accommodate minor architectural differences compared
# to GPT-NeoX and OPT used by the Meta AI team that trained the model.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from dataclasses import dataclass
from typing import Optional
import torch
from torch import nn
from vllm.attention import AttentionMetadata
from vllm.config import CacheConfig, get_current_vllm_config
from vllm.distributed import get_tensor_model_parallel_world_size
from vllm.forward_context import ForwardContext, get_forward_context
from vllm.model_executor.layers.linear import ReplicatedLinear
from vllm.model_executor.layers.mla import MLAModules
from vllm.model_executor.layers.quantization import QuantizationConfig
from vllm_ascend.ascend_config import get_ascend_config
from vllm_ascend.utils import vllm_version_is
if vllm_version_is("0.11.0"):
from vllm.attention import Attention
from vllm.model_executor.layers.mla import \
MultiHeadLatentAttention as MultiHeadLatentAttentionWrapper
from vllm.utils import direct_register_custom_op
else:
from vllm.attention.layer import MLAAttention
from vllm.model_executor.layers.mla import MultiHeadLatentAttentionWrapper
from vllm.utils.torch_utils import direct_register_custom_op
@dataclass
class AscendSFAModules:
q_a_layernorm: Optional[torch.nn.Module]
q_proj: Optional[torch.nn.Module]
kv_a_proj_with_mqa: torch.nn.Module
kv_a_layernorm: torch.nn.Module
kv_b_proj: torch.nn.Module
o_proj: torch.nn.Module
rotary_emb: torch.nn.Module
indexer: torch.nn.Module
is_sparse: bool
fused_qkv_a_proj: Optional[torch.nn.Module]
q_b_proj: Optional[torch.nn.Module]
topk_indices_buffer: Optional[torch.Tensor]
class AscendSparseFlashAttention(MultiHeadLatentAttentionWrapper):
def __init__(
self,
hidden_size: int,
num_heads: int,
scale: float,
qk_nope_head_dim: int,
qk_rope_head_dim: int,
v_head_dim: int,
q_lora_rank: Optional[int],
kv_lora_rank: int,
mla_modules: MLAModules,
cache_config: Optional[CacheConfig] = None,
quant_config: Optional[QuantizationConfig] = None,
prefix: str = "",
) -> None:
nn.Module.__init__(self)
self.hidden_size = hidden_size
self.kv_lora_rank = kv_lora_rank
self.qk_rope_head_dim = qk_rope_head_dim
self.q_lora_rank = q_lora_rank
self.qk_nope_head_dim = qk_nope_head_dim
self.qk_head_dim = qk_rope_head_dim + qk_nope_head_dim
self.v_head_dim = v_head_dim
self.prefix = prefix
self.scaling = scale
self.indexer = mla_modules.indexer
self.is_sparse = mla_modules.is_sparse
hf_config = get_current_vllm_config().model_config.hf_config
self.enable_shared_expert_dp = get_ascend_config(
).enable_shared_expert_dp
self.debug_layer_idx = int(self.prefix.split(".")[-2])
self.first_k_dense_replace = hf_config.first_k_dense_replace
self.tp_size = get_tensor_model_parallel_world_size()
self.layers = hf_config.num_hidden_layers
if vllm_version_is("0.11.0"):
self.sfa_attn = Attention(
num_heads=num_heads,
head_size=self.kv_lora_rank + self.qk_rope_head_dim,
scale=scale,
num_kv_heads=1,
cache_config=cache_config,
quant_config=quant_config,
prefix=f"{prefix}.attn",
use_mla=True,
use_sparse=True,
indexer=self.indexer,
# SFA Args
q_lora_rank=self.q_lora_rank,
kv_lora_rank=self.kv_lora_rank,
qk_nope_head_dim=self.qk_nope_head_dim,
qk_rope_head_dim=self.qk_rope_head_dim,
v_head_dim=self.v_head_dim,
qk_head_dim=self.qk_head_dim,
rotary_emb=mla_modules.rotary_emb,
fused_qkv_a_proj=mla_modules.fused_qkv_a_proj,
q_b_proj=mla_modules.q_b_proj,
q_a_layernorm=mla_modules.q_a_layernorm,
q_proj=mla_modules.q_proj,
kv_a_proj_with_mqa=mla_modules.kv_a_proj_with_mqa,
kv_a_layernorm=mla_modules.kv_a_layernorm,
kv_b_proj=mla_modules.kv_b_proj,
o_proj=mla_modules.o_proj,
)
else:
self.sfa_attn = MLAAttention(
num_heads=num_heads,
scale=scale,
qk_nope_head_dim=self.qk_nope_head_dim,
qk_rope_head_dim=self.qk_rope_head_dim,
v_head_dim=self.v_head_dim,
q_lora_rank=self.q_lora_rank,
kv_lora_rank=self.kv_lora_rank,
kv_b_proj=mla_modules.kv_b_proj,
cache_config=cache_config,
quant_config=quant_config,
prefix=f"{prefix}.attn",
use_sparse=mla_modules.is_sparse,
indexer=mla_modules.indexer,
# extra args
rotary_emb=mla_modules.rotary_emb,
fused_qkv_a_proj=mla_modules.fused_qkv_a_proj,
q_b_proj=mla_modules.q_b_proj,
q_a_layernorm=mla_modules.q_a_layernorm,
q_proj=mla_modules.q_proj,
kv_a_proj_with_mqa=mla_modules.kv_a_proj_with_mqa,
kv_a_layernorm=mla_modules.kv_a_layernorm,
o_proj=mla_modules.o_proj,
)
compilation_config = get_current_vllm_config().compilation_config
if prefix in compilation_config.static_forward_context:
raise ValueError(f"Duplicate layer name: {prefix}")
compilation_config.static_forward_context[prefix] = self
def forward(
self,
positions: torch.Tensor,
hidden_states: torch.Tensor,
kv_cache: Optional[torch.Tensor] = None,
attn_metadata: Optional[AttentionMetadata] = None) -> torch.Tensor:
num_tokens = hidden_states.shape[0]
need_gather_q_kv = False
if self.enable_shared_expert_dp and self.debug_layer_idx > self.first_k_dense_replace and self.debug_layer_idx < self.layers:
# Simulate all gather to calculate output shape
num_tokens = num_tokens * self.tp_size
need_gather_q_kv = True
if not self.enable_shared_expert_dp or self.debug_layer_idx < self.first_k_dense_replace:
output_shape = hidden_states.shape
else:
rows = num_tokens // self.tp_size
if num_tokens % self.tp_size:
rows += 1
output_shape = (rows, hidden_states.shape[1])
# FIXME: This does not seem right, should make sure the buffer is fixed
output = torch.empty(output_shape,
dtype=hidden_states.dtype,
device=hidden_states.device)
torch.ops.vllm.sfa_forward(hidden_states, need_gather_q_kv, output,
self.prefix)
output = output.view(-1, output_shape[-1])
return output
def sfa_forward(
hidden_states: torch.Tensor,
need_gather_q_kv: bool,
output: torch.Tensor,
layer_name: str,
) -> None:
forward_context: ForwardContext = get_forward_context()
self = forward_context.no_compile_layers[layer_name]
if forward_context.attn_metadata:
attn_metadata = forward_context.attn_metadata[self.sfa_attn.layer_name]
else:
attn_metadata = forward_context.attn_metadata
kv_cache = self.sfa_attn.kv_cache[forward_context.virtual_engine]
self.sfa_attn.impl.forward(hidden_states, kv_cache, attn_metadata,
need_gather_q_kv, output)
return
class Indexer(nn.Module):
def __init__(self,
config,
dim: int = 7168,
n_heads: int = 64,
head_dim: int = 128,
index_topk: int = 2048,
q_lora_rank: int = 1536,
rope_head_dim: int = 64,
quant_config: Optional[QuantizationConfig] = None,
prefix: Optional[str] = ""):
super().__init__()
self.dim: int = dim # 7168
self.n_heads: int = n_heads # 64
self.head_dim: int = head_dim # 128
self.rope_head_dim: int = rope_head_dim # 64
self.index_topk: int = index_topk # 2048
self.q_lora_rank: int = q_lora_rank # 1536
self.wq_b = ReplicatedLinear(
self.q_lora_rank,
self.n_heads * self.head_dim,
bias=False,
quant_config=quant_config,
prefix=f"{prefix}.wq_b",
return_bias=False,
)
self.wk = ReplicatedLinear(
self.dim,
self.head_dim,
bias=False,
quant_config=quant_config,
prefix=f"{prefix}.wk",
return_bias=False,
)
self.weights_proj = ReplicatedLinear(
self.dim,
self.n_heads,
bias=False,
quant_config=quant_config,
prefix=f"{prefix}.weights_proj",
return_bias=False,
)
self.k_norm = nn.LayerNorm(self.head_dim)
self.softmax_scale = self.head_dim**-0.5
def forward(self):
return
def sfa_forward_fake(
hidden_states: torch.Tensor,
need_gather_q_kv: bool,
output: torch.Tensor,
layer_name: str,
) -> None:
return
direct_register_custom_op(
op_name="sfa_forward",
op_func=sfa_forward,
mutates_args=["output"],
fake_impl=sfa_forward_fake,
dispatch_key="PrivateUse1",
)

View File

@@ -33,3 +33,4 @@ from vllm_ascend.utils import vllm_version_is
if vllm_version_is("0.11.0"):
import vllm_ascend.patch.worker.patch_deepseek_mtp # noqa
import vllm_ascend.patch.worker.patch_deepseek_v3_2 # noqa

View File

@@ -0,0 +1,108 @@
#
# Copyright (c) 2025 Huawei Technologies Co., Ltd. All Rights Reserved.
# This file is a part of the vllm-ascend project.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
from itertools import islice
from typing import Optional, Union
import torch
import vllm.model_executor.models.deepseek_v2
from torch import nn
from vllm.compilation.decorators import support_torch_compile
from vllm.config import VllmConfig
from vllm.distributed import get_pp_group
from vllm.model_executor.layers.layernorm import RMSNorm
from vllm.model_executor.layers.vocab_parallel_embedding import \
VocabParallelEmbedding
from vllm.model_executor.models.deepseek_v2 import DeepseekV2DecoderLayer
from vllm.model_executor.models.utils import (
PPMissingLayer, make_empty_intermediate_tensors_factory, make_layers)
from vllm.sequence import IntermediateTensors
@support_torch_compile
class DeepseekV2Model(nn.Module):
fall_back_to_pt_during_load = False
def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
super().__init__()
config = vllm_config.model_config.hf_config
quant_config = vllm_config.quant_config
self.config = config
self.vocab_size = config.vocab_size
self.is_v32 = hasattr(config, "index_topk")
topk_indices_buffer = None
if get_pp_group().is_first_rank:
self.embed_tokens = VocabParallelEmbedding(
config.vocab_size,
config.hidden_size,
quant_config=quant_config,
prefix=f"{prefix}.embed_tokens")
else:
self.embed_tokens = PPMissingLayer()
self.start_layer, self.end_layer, self.layers = make_layers(
config.num_hidden_layers,
lambda prefix: DeepseekV2DecoderLayer(vllm_config, prefix,
topk_indices_buffer),
prefix=f"{prefix}.layers")
if get_pp_group().is_last_rank:
self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
else:
self.norm = PPMissingLayer()
self.make_empty_intermediate_tensors = (
make_empty_intermediate_tensors_factory(
["hidden_states", "residual"], config.hidden_size))
def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor:
return self.embed_tokens(input_ids)
def forward(
self,
input_ids: torch.Tensor,
positions: torch.Tensor,
intermediate_tensors: Optional[IntermediateTensors],
inputs_embeds: Optional[torch.Tensor] = None,
) -> Union[torch.Tensor, IntermediateTensors]:
if get_pp_group().is_first_rank:
if inputs_embeds is not None:
hidden_states = inputs_embeds
else:
hidden_states = self.get_input_embeddings(input_ids)
residual = None
else:
assert intermediate_tensors is not None
hidden_states = intermediate_tensors["hidden_states"]
residual = intermediate_tensors["residual"]
for layer in islice(self.layers, self.start_layer, self.end_layer):
hidden_states, residual = layer(positions, hidden_states, residual)
if not get_pp_group().is_last_rank:
return IntermediateTensors({
"hidden_states": hidden_states,
"residual": residual
})
hidden_states, _ = self.norm(hidden_states, residual)
return hidden_states
vllm.model_executor.models.deepseek_v2.DeepseekV2Model = DeepseekV2Model

View File

@@ -69,7 +69,6 @@ from vllm.sequence import IntermediateTensors
from vllm_ascend import envs
from vllm_ascend.ascend_config import get_ascend_config
from vllm_ascend.models.layers.sfa import Indexer
from vllm_ascend.ops.weight_prefetch import maybe_npu_prefetch
from vllm_ascend.quantization.quant_config import AscendLinearMethod
from vllm_ascend.torchair.ops.torchair_fused_moe import TorchairAscendFusedMoE
@@ -83,6 +82,57 @@ else:
from vllm.attention.layer import MLAAttention
class Indexer(nn.Module):
def __init__(self,
config,
dim: int = 7168,
n_heads: int = 64,
head_dim: int = 128,
index_topk: int = 2048,
q_lora_rank: int = 1536,
rope_head_dim: int = 64,
quant_config: Optional[QuantizationConfig] = None,
prefix: Optional[str] = ""):
super().__init__()
self.dim: int = dim # 7168
self.n_heads: int = n_heads # 64
self.head_dim: int = head_dim # 128
self.rope_head_dim: int = rope_head_dim # 64
self.index_topk: int = index_topk # 2048
self.q_lora_rank: int = q_lora_rank # 1536
self.wq_b = ReplicatedLinear(
self.q_lora_rank,
self.n_heads * self.head_dim,
bias=False,
quant_config=quant_config,
prefix=f"{prefix}.wq_b",
return_bias=False,
)
self.wk = ReplicatedLinear(
self.dim,
self.head_dim,
bias=False,
quant_config=quant_config,
prefix=f"{prefix}.wk",
return_bias=False,
)
self.weights_proj = ReplicatedLinear(
self.dim,
self.n_heads,
bias=False,
quant_config=quant_config,
prefix=f"{prefix}.weights_proj",
return_bias=False,
)
self.k_norm = nn.LayerNorm(self.head_dim)
self.softmax_scale = self.head_dim**-0.5
def forward(self):
return
class TorchairDeepseekV2SiluAndMul(SiluAndMul):
def __init__(self,

View File

@@ -577,7 +577,6 @@ def register_ascend_customop(vllm_config: Optional[VllmConfig] = None):
from vllm.model_executor.custom_op import CustomOp
from vllm_ascend.models.layers.mla import AscendMultiHeadLatentAttention
from vllm_ascend.models.layers.sfa import AscendSparseFlashAttention
from vllm_ascend.ops.activation import AscendQuickGELU, AscendSiluAndMul
from vllm_ascend.ops.fused_moe.fused_moe import (AscendFusedMoE,
AscendSharedFusedMoE)
@@ -625,10 +624,7 @@ def register_ascend_customop(vllm_config: Optional[VllmConfig] = None):
mla_to_register = "MultiHeadLatentAttention" if vllm_version_is(
"0.11.0") else "MultiHeadLatentAttentionWrapper"
if vllm_config and vllm_config.model_config and vllm_config.model_config.use_mla:
AscendMLAAttentionWarrper = AscendSparseFlashAttention if hasattr(
vllm_config.model_config.hf_config,
"index_topk") else AscendMultiHeadLatentAttention
REGISTERED_ASCEND_OPS[mla_to_register] = AscendMLAAttentionWarrper
REGISTERED_ASCEND_OPS[mla_to_register] = AscendMultiHeadLatentAttention
for name, op_cls in REGISTERED_ASCEND_OPS.items():
CustomOp.register_oot(_decorated_op_cls=op_cls, name=name)