[Model][3/N] Refactor sfa into mla and remove deepseek_v3_2.py (#3769)

This is the follow-up PR to PR #3189, which continues to refactor sfa
into mla and finally remove deepseek_v3_2.py. This is the last PR of
deepseek modeling refactoring. After this, all deepseek-related model
codes are removed from vllm_ascend.

FurtherMore, after this PR deepseek v3.2 can run chunk-prefill with
correct accuracy.

- vLLM version: v0.11.0rc3
- vLLM main:
83f478bb19

---------

Signed-off-by: whx-sjtu <2952154980@qq.com>
This commit is contained in:
whx
2025-10-30 17:06:38 +08:00
committed by GitHub
parent eff3e5fc6f
commit f6149f3894
10 changed files with 751 additions and 1935 deletions

View File

@@ -52,6 +52,8 @@ if prefill_context_parallel_enable():
if TYPE_CHECKING: if TYPE_CHECKING:
from vllm.v1.core.sched.output import SchedulerOutput from vllm.v1.core.sched.output import SchedulerOutput
MAX_O_PROJ_PREFETCH_SIZE = 16 * 1024 * 1024
class AscendMLABackend(AttentionBackend): class AscendMLABackend(AttentionBackend):
@@ -808,16 +810,17 @@ class AscendMLAImpl(MLAAttentionImpl):
# self.W_UV.data = torch_npu.npu_format_cast(self.W_UV.data, 29) # self.W_UV.data = torch_npu.npu_format_cast(self.W_UV.data, 29)
# self.W_UK_T.data = torch_npu.npu_format_cast(self.W_UK_T.data, 29) # self.W_UK_T.data = torch_npu.npu_format_cast(self.W_UK_T.data, 29)
# Currently mlapo only supports W8A8 quantization in MLA scenario if self.enable_mlapo:
# TODO(whx): modify this limitation when mlapo supports floating point # Currently mlapo only supports W8A8 quantization in MLA scenario
if self.fused_qkv_a_proj is None or not isinstance( # TODO(whx): modify this limitation when mlapo supports floating point
getattr(self.fused_qkv_a_proj.quant_method, 'quant_method', if self.fused_qkv_a_proj is None or not isinstance(
None), AscendW8A8LinearMethod): getattr(self.fused_qkv_a_proj.quant_method, 'quant_method',
self.enable_mlapo = False None), AscendW8A8LinearMethod):
logger.warning_once( self.enable_mlapo = False
"Currently mlapo only supports W8A8 quantization in MLA scenario." logger.warning_once(
"Some layers in your model are not quantized with W8A8," "Currently mlapo only supports W8A8 quantization in MLA scenario."
"thus mlapo is disabled for these layers.") "Some layers in your model are not quantized with W8A8,"
"thus mlapo is disabled for these layers.")
if self.enable_mlapo: if self.enable_mlapo:
self._process_weights_for_fused_mlapo(act_dtype) self._process_weights_for_fused_mlapo(act_dtype)
@@ -1282,12 +1285,13 @@ class AscendMLAImpl(MLAAttentionImpl):
def _mla_preprocess(self, layer_name, hidden_states, kv_cache, def _mla_preprocess(self, layer_name, hidden_states, kv_cache,
attn_metadata, need_gather_q_kv): attn_metadata, need_gather_q_kv):
# MLA Preprocess: # MLA Preprocess:
# 1. Perform q_a_proj and q_a_layernorm to obtain q_c # 1. Perform fused_qkv_a_proj and q_a_layernorm to obtain q_c and kv_no_split
# 2. Perform kv_a_proj_with_mqa to obtain kv_no_split # or
# 3. If need_gather_q_kv, perform all_gather. # Perform kv_a_proj_with_mqa to obtain kv_no_split
# 4. Preprocess decode tokens, write kv cache and get: # 2. If need_gather_q_kv, perform all_gather.
# 3. Preprocess decode tokens, write kv cache and get:
# decode_ql_nope, decode_q_pe, decode_k_pe, decode_k_nope # decode_ql_nope, decode_q_pe, decode_k_pe, decode_k_nope
# 5. Preprocess prefill tokens, write kv cache and get: # 4. Preprocess prefill tokens, write kv cache and get:
# prefill_q_nope, prefill_q_pe, prefill_k_nope, prefill_k_pe, prefill_value # prefill_q_nope, prefill_q_pe, prefill_k_nope, prefill_k_pe, prefill_value
has_decode = attn_metadata.num_decodes > 0 has_decode = attn_metadata.num_decodes > 0
has_prefill = attn_metadata.num_prefills > 0 has_prefill = attn_metadata.num_prefills > 0

File diff suppressed because it is too large Load Diff

View File

@@ -29,10 +29,6 @@ def register_model():
"vllm_ascend.models.qwen2_5_vl_without_padding:AscendQwen2_5_VLForConditionalGeneration_Without_Padding" "vllm_ascend.models.qwen2_5_vl_without_padding:AscendQwen2_5_VLForConditionalGeneration_Without_Padding"
) )
ModelRegistry.register_model(
"DeepseekV32ForCausalLM",
"vllm_ascend.models.deepseek_v3_2:CustomDeepseekV3ForCausalLM")
# There is no PanguProMoEForCausalLM in vLLM, so we should register it before vLLM config initialization # There is no PanguProMoEForCausalLM in vLLM, so we should register it before vLLM config initialization
# to make sure the model can be loaded correctly. This register step can be removed once vLLM support PanguProMoEForCausalLM. # to make sure the model can be loaded correctly. This register step can be removed once vLLM support PanguProMoEForCausalLM.
ModelRegistry.register_model( ModelRegistry.register_model(

View File

@@ -1,658 +0,0 @@
# SPDX-License-Identifier: Apache-2.0
# Copyright (c) 2025 Huawei Technologies Co., Ltd. All Rights Reserved.
# Copyright 2023 The vLLM team.
# Copyright 2023 DeepSeek-AI and the HuggingFace Inc. team. All rights reserved.
#
# This code is based on EleutherAI's GPT-NeoX library and the GPT-NeoX
# and OPT implementations in this library. It has been modified from its
# original forms to accommodate minor architectural differences compared
# to GPT-NeoX and OPT used by the Meta AI team that trained the model.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# # Adapted from
# # vllm-project/vllm/blob/main/vllm/model_executor/models/deepseek_v2.py
# # https://github.com/huggingface/transformers/blob/v4.28.0/src/transformers/models/llama/modeling_llama.py
# # vllm-project/vllm/vllm/model_executor/models/deepseek_v2.py
# """Inference-only DeepseekV2/DeepseekV3 model."""
from typing import Any, Dict, Iterable, Optional, Union
import torch
from torch import nn
from transformers import PretrainedConfig
from vllm.attention import AttentionMetadata
from vllm.compilation.decorators import support_torch_compile
from vllm.config import CacheConfig, VllmConfig
from vllm.distributed import (divide, get_pp_group,
get_tensor_model_parallel_rank,
get_tensor_model_parallel_world_size,
get_tp_group, split_tensor_along_last_dim,
tensor_model_parallel_all_reduce)
from vllm.model_executor.layers.fused_moe import FusedMoE
from vllm.model_executor.layers.layernorm import RMSNorm
from vllm.model_executor.layers.linear import (WEIGHT_LOADER_V2_SUPPORTED,
ColumnParallelLinear,
MergedColumnParallelLinear,
ReplicatedLinear,
RowParallelLinear)
from vllm.model_executor.layers.logits_processor import LogitsProcessor
from vllm.model_executor.layers.quantization import QuantizationConfig
from vllm.model_executor.layers.rotary_embedding import get_rope
from vllm.model_executor.layers.vocab_parallel_embedding import (
ParallelLMHead, VocabParallelEmbedding)
from vllm.model_executor.model_loader.weight_utils import (
default_weight_loader, maybe_remap_kv_scale_name)
from vllm.model_executor.models.deepseek_v2 import \
yarn_get_mscale # noqa: E501
from vllm.model_executor.models.deepseek_v2 import (
DeepseekV2Attention, DeepseekV2DecoderLayer, DeepseekV2ForCausalLM,
DeepseekV2MLAAttention, DeepseekV2MLP, DeepseekV2Model, DeepseekV2MoE,
get_spec_layer_idx_from_weight_name)
from vllm.model_executor.models.utils import (
PPMissingLayer, is_pp_missing_parameter,
make_empty_intermediate_tensors_factory, make_layers, maybe_prefix)
from vllm.model_executor.utils import set_weight_attrs
from vllm.platforms import current_platform
from vllm_ascend.ascend_config import get_ascend_config
from vllm_ascend.models.layers.sfa import AscendSFAModules, Indexer
from vllm_ascend.ops.fused_moe.fused_moe import AscendFusedMoE
from vllm_ascend.ops.linear import AscendLinearBase
from vllm_ascend.utils import vllm_version_is
if vllm_version_is("0.11.0"):
from vllm.model_executor.layers.mla import MultiHeadLatentAttention
else:
from vllm.model_executor.layers.mla import MultiHeadLatentAttentionWrapper
@support_torch_compile
class AscendDeepseekV2Model(DeepseekV2Model, nn.Module):
def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
# Rewrite this init func mainly for removing cuda-hard code
nn.Module.__init__(self)
config = vllm_config.model_config.hf_config
quant_config = vllm_config.quant_config
self.config = config
self.vocab_size = config.vocab_size
assert hasattr(config, "index_topk")
topk_tokens = config.index_topk
topk_indices_buffer = torch.empty(
vllm_config.scheduler_config.max_num_batched_tokens,
topk_tokens,
dtype=torch.int32,
device=current_platform.device_type)
if get_pp_group().is_first_rank:
self.embed_tokens = VocabParallelEmbedding(
config.vocab_size,
config.hidden_size,
quant_config=quant_config,
prefix=f"{prefix}.embed_tokens")
else:
self.embed_tokens = PPMissingLayer()
self.start_layer, self.end_layer, self.layers = make_layers(
config.num_hidden_layers,
lambda prefix: DeepseekV2DecoderLayer(vllm_config, prefix,
topk_indices_buffer),
prefix=f"{prefix}.layers")
if get_pp_group().is_last_rank:
self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
else:
self.norm = PPMissingLayer()
self.make_empty_intermediate_tensors = (
make_empty_intermediate_tensors_factory(
["hidden_states", "residual"], config.hidden_size))
class CustomDeepseekV2RowParallelLinear(RowParallelLinear):
def __init__(
self,
input_size: int,
output_size: int,
bias: bool = True,
input_is_parallel: bool = True,
skip_bias_add: bool = False,
params_dtype: Optional[torch.dtype] = None,
reduce_results: bool = True,
quant_config: Optional[QuantizationConfig] = None,
prefix: str = "",
*,
return_bias: bool = True,
disable_tp: bool = False,
):
# Divide the weight matrix along the first dimension.
self.tp_rank = (get_tensor_model_parallel_rank()
if not disable_tp else 0)
self.tp_size = (get_tensor_model_parallel_world_size()
if not disable_tp else 1)
self.input_size_per_partition = divide(input_size, self.tp_size)
self.output_size_per_partition = output_size
self.output_partition_sizes = [output_size]
AscendLinearBase.__init__(self,
input_size,
output_size,
skip_bias_add,
params_dtype,
quant_config,
prefix,
return_bias=return_bias,
disable_tp=disable_tp)
self.input_is_parallel = input_is_parallel
self.reduce_results = reduce_results
assert self.quant_method is not None
self.quant_method.create_weights(
layer=self,
input_size_per_partition=self.input_size_per_partition,
output_partition_sizes=self.output_partition_sizes,
input_size=self.input_size,
output_size=self.output_size,
params_dtype=self.params_dtype,
weight_loader=(
self.weight_loader_v2 if self.quant_method.__class__.__name__
in WEIGHT_LOADER_V2_SUPPORTED else self.weight_loader))
if not reduce_results and (bias and not skip_bias_add):
raise ValueError("When not reduce the results, adding bias to the "
"results can lead to incorrect results")
if bias:
self.bias = nn.Parameter(
torch.empty(self.output_size, dtype=params_dtype))
set_weight_attrs(self.bias, {
"output_dim": 0,
"weight_loader": self.weight_loader,
})
else:
self.register_parameter("bias", None)
self.update_param_tp_status()
def forward(
self,
input_,
is_prefill=True,
is_force_scatter=False
) -> Union[torch.Tensor, tuple[torch.Tensor, Optional[nn.Parameter]]]:
if self.input_is_parallel:
input_parallel = input_
else:
tp_rank = get_tensor_model_parallel_rank()
splitted_input = split_tensor_along_last_dim(
input_, num_partitions=self.tp_size)
input_parallel = splitted_input[tp_rank].contiguous()
# Matrix multiply.
assert self.quant_method is not None
# Only fuse bias add into GEMM for rank 0 (this ensures that
# bias will not get added more than once in TP>1 case)
bias_ = None if (self.tp_rank > 0 or self.skip_bias_add) else self.bias
output_parallel = self.quant_method.apply(self,
input_parallel,
bias=bias_)
if self.reduce_results and self.tp_size > 1:
output = tensor_model_parallel_all_reduce(output_parallel)
else:
output = output_parallel
output_bias = self.bias if self.skip_bias_add else None
if not self.return_bias:
return output
return output, output_bias
class CustomDeepseekV2SFAAttention(DeepseekV2MLAAttention):
def __init__(
self,
config: PretrainedConfig,
hidden_size: int,
num_heads: int,
qk_nope_head_dim: int,
qk_rope_head_dim: int,
v_head_dim: int,
q_lora_rank: Optional[int],
kv_lora_rank: int,
rope_theta: float = 10000,
rope_scaling: Optional[Dict[str, Any]] = None,
max_position_embeddings: int = 8192,
cache_config: Optional[CacheConfig] = None,
quant_config: Optional[QuantizationConfig] = None,
prefix: str = "",
) -> None:
nn.Module.__init__(self)
self.hidden_size = hidden_size
self.qk_nope_head_dim = qk_nope_head_dim
self.qk_rope_head_dim = qk_rope_head_dim
self.qk_head_dim = qk_nope_head_dim + qk_rope_head_dim
self.v_head_dim = v_head_dim
self.q_lora_rank = q_lora_rank
self.kv_lora_rank = kv_lora_rank
self.num_heads = num_heads
self.tp_size = get_tensor_model_parallel_world_size()
assert num_heads % self.tp_size == 0
self.num_local_heads = num_heads // self.tp_size
self.layers = config.num_hidden_layers
self.first_k_dense_replace = config.first_k_dense_replace
self.scaling = self.qk_head_dim**-0.5
self.rope_theta = rope_theta
self.max_position_embeddings = max_position_embeddings
self.prefix = prefix
self.debug_layer_idx = int(self.prefix.split(".")[-2])
ascend_config = get_ascend_config()
self.enable_shared_expert_dp = ascend_config.enable_shared_expert_dp
if self.q_lora_rank is not None:
self.q_a_layernorm = RMSNorm(self.q_lora_rank,
eps=config.rms_norm_eps)
self.q_b_proj = ColumnParallelLinear(
q_lora_rank,
self.num_heads * self.qk_head_dim,
bias=False,
quant_config=quant_config,
prefix=f"{prefix}.q_b_proj",
return_bias=False,
)
else:
self.q_proj = ColumnParallelLinear(
self.hidden_size,
self.num_heads * self.qk_head_dim,
bias=False,
quant_config=quant_config,
prefix=f"{prefix}.q_proj",
return_bias=False,
)
self.kv_a_layernorm = RMSNorm(self.kv_lora_rank,
eps=config.rms_norm_eps)
self.kv_b_proj = ColumnParallelLinear(
self.kv_lora_rank,
self.num_heads * (self.qk_nope_head_dim + self.v_head_dim),
bias=False,
quant_config=quant_config,
prefix=f"{prefix}.kv_b_proj",
return_bias=False,
)
self.o_proj = CustomDeepseekV2RowParallelLinear(
self.num_heads * self.v_head_dim,
self.hidden_size,
bias=False,
quant_config=quant_config,
prefix=f"{prefix}.o_proj",
return_bias=False,
)
if self.q_lora_rank is not None:
self.fused_qkv_a_proj = MergedColumnParallelLinear(
self.hidden_size,
[self.q_lora_rank, self.kv_lora_rank + self.qk_rope_head_dim],
bias=False,
quant_config=quant_config,
prefix=f"{prefix}.fused_qkv_a_proj",
disable_tp=True)
self.kv_a_proj_with_mqa = None
else:
self.kv_a_proj_with_mqa = ReplicatedLinear(
self.hidden_size,
self.kv_lora_rank + self.qk_rope_head_dim,
bias=False,
quant_config=quant_config,
prefix=f"{prefix}.kv_a_proj_with_mqa")
if rope_scaling:
rope_scaling["rope_type"] = 'deepseek_yarn'
self.rotary_emb = get_rope(qk_rope_head_dim,
rotary_dim=qk_rope_head_dim,
max_position=max_position_embeddings,
base=rope_theta,
rope_scaling=rope_scaling,
is_neox_style=False)
if rope_scaling:
mscale_all_dim = rope_scaling.get("mscale_all_dim", False)
scaling_factor = rope_scaling["factor"]
mscale = yarn_get_mscale(scaling_factor, float(mscale_all_dim))
self.scaling = self.scaling * mscale * mscale
self.dim: int = config.hidden_size # 7168
# TODO(zzzzwwjj): wait transformers add these params
self.n_heads: int = 64 # 64
self.head_dim: int = 128 # 128
self.index_topk: int = 2048 # 2048
self.indexer = Indexer(
config,
quant_config=quant_config,
dim=self.dim,
n_heads=self.n_heads,
head_dim=self.head_dim,
index_topk=self.index_topk,
prefix=f"{prefix}.indexer",
)
sfa_modules = AscendSFAModules(
q_a_layernorm=self.q_a_layernorm
if self.q_lora_rank is not None else None,
q_proj=self.q_proj if self.q_lora_rank is None else self.q_b_proj,
q_b_proj=self.q_b_proj if self.q_lora_rank is not None else None,
kv_a_proj_with_mqa=self.kv_a_proj_with_mqa,
fused_qkv_a_proj=self.fused_qkv_a_proj
if self.q_lora_rank is not None else None,
kv_a_layernorm=self.kv_a_layernorm,
kv_b_proj=self.kv_b_proj,
o_proj=self.o_proj,
rotary_emb=self.rotary_emb,
indexer=self.indexer,
is_sparse=hasattr(config, "index_topk"),
topk_indices_buffer=None)
if vllm_version_is("0.11.0"):
self.sfa_attn = MultiHeadLatentAttention(
hidden_size=self.hidden_size,
num_heads=self.num_local_heads,
scale=self.scaling,
qk_nope_head_dim=self.qk_nope_head_dim,
qk_rope_head_dim=self.qk_rope_head_dim,
v_head_dim=self.v_head_dim,
q_lora_rank=self.q_lora_rank,
kv_lora_rank=self.kv_lora_rank,
mla_modules=sfa_modules,
cache_config=cache_config,
quant_config=quant_config,
prefix=prefix,
)
else:
self.sfa_attn = MultiHeadLatentAttentionWrapper(
hidden_size=self.hidden_size,
num_heads=self.num_local_heads,
scale=self.scaling,
qk_nope_head_dim=self.qk_nope_head_dim,
qk_rope_head_dim=self.qk_rope_head_dim,
v_head_dim=self.v_head_dim,
q_lora_rank=self.q_lora_rank,
kv_lora_rank=self.kv_lora_rank,
mla_modules=sfa_modules,
cache_config=cache_config,
quant_config=quant_config,
prefix=prefix,
)
self.prefix = prefix
def forward(
self,
positions: torch.Tensor,
hidden_states: torch.Tensor,
kv_cache: Optional[torch.Tensor] = None,
attn_metadata: Optional[AttentionMetadata] = None) -> torch.Tensor:
return self.sfa_attn(positions, hidden_states, kv_cache, attn_metadata)
class CustomDeepseekV2DecoderLayer(DeepseekV2DecoderLayer):
def __init__(self,
vllm_config: VllmConfig,
prefix: str,
topk_indices_buffer=None) -> None:
nn.Module.__init__(self)
config = vllm_config.model_config.hf_config
model_config = vllm_config.model_config
cache_config = vllm_config.cache_config
quant_config = vllm_config.quant_config
parallel_config = vllm_config.parallel_config
self.hidden_size = config.hidden_size
rope_theta = getattr(config, "rope_theta", 10000)
rope_scaling = getattr(config, "rope_scaling", None)
max_position_embeddings = getattr(config, "max_position_embeddings",
8192)
# DecoderLayers are created with `make_layers` which passes the prefix
# with the layer's index.
layer_idx = int(prefix.split(sep='.')[-1])
self.layer_idx = layer_idx
self.layers = config.num_hidden_layers
self.tp_size = get_tensor_model_parallel_world_size()
self.tp_rank = get_tp_group().rank_in_group
# TODO: enable mla in vllm-ascend
if model_config.use_mla:
attn_cls = CustomDeepseekV2SFAAttention
else:
attn_cls = DeepseekV2Attention
self.self_attn = attn_cls(
config=config,
hidden_size=self.hidden_size,
num_heads=config.num_attention_heads,
qk_nope_head_dim=config.qk_nope_head_dim,
qk_rope_head_dim=config.qk_rope_head_dim,
v_head_dim=config.v_head_dim,
q_lora_rank=config.q_lora_rank
if hasattr(config, "q_lora_rank") else None,
kv_lora_rank=config.kv_lora_rank,
rope_theta=rope_theta,
rope_scaling=rope_scaling,
max_position_embeddings=max_position_embeddings,
cache_config=cache_config,
quant_config=quant_config,
prefix=f"{prefix}.self_attn",
)
if (config.n_routed_experts is not None
and layer_idx >= config.first_k_dense_replace
and layer_idx % config.moe_layer_freq == 0):
self.mlp = DeepseekV2MoE(
config=config,
parallel_config=parallel_config,
quant_config=quant_config,
prefix=f"{prefix}.mlp",
)
if self.mlp.gate.e_score_correction_bias is not None:
self.mlp.gate.e_score_correction_bias.data = (
self.mlp.gate.e_score_correction_bias.data.to(
dtype=torch.get_default_dtype()))
else:
self.mlp = DeepseekV2MLP(
hidden_size=config.hidden_size,
intermediate_size=config.intermediate_size,
hidden_act=config.hidden_act,
quant_config=quant_config,
prefix=f"{prefix}.mlp",
)
self.input_layernorm = RMSNorm(config.hidden_size,
eps=config.rms_norm_eps)
self.post_attention_layernorm = RMSNorm(config.hidden_size,
eps=config.rms_norm_eps)
self.routed_scaling_factor = config.routed_scaling_factor
self.first_k_dense_replace = config.first_k_dense_replace
self.tp_group = get_tp_group().device_group
class CustomDeepseekV2ForCausalLM(DeepseekV2ForCausalLM):
def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
nn.Module.__init__(self)
config = vllm_config.model_config.hf_config
quant_config = vllm_config.quant_config
self.config = config
self.quant_config = quant_config
# `packed_modules_mapping` needs to be modified before
# initializing DeepseekV2Model, as it is passed inplace to
# quantization config init and may be used to select the
# quant_method for relevant layers during initialization.
self.fuse_qkv_a_proj = hasattr(
config, "q_lora_rank") and config.q_lora_rank is not None
if self.fuse_qkv_a_proj:
self.packed_modules_mapping["fused_qkv_a_proj"] = [
"q_a_proj",
"kv_a_proj_with_mqa",
]
self.model = AscendDeepseekV2Model(vllm_config=vllm_config,
prefix=maybe_prefix(
prefix, "model"))
if get_pp_group().is_last_rank:
self.lm_head = ParallelLMHead(config.vocab_size,
config.hidden_size,
quant_config=quant_config,
prefix=maybe_prefix(
prefix, "lm_head"))
else:
self.lm_head = PPMissingLayer()
self.logits_processor = LogitsProcessor(config.vocab_size)
self.make_empty_intermediate_tensors = (
self.model.make_empty_intermediate_tensors)
self.expert_weights: list[Any] = []
# Set MoE hyperparameters
self.num_moe_layers = (config.num_hidden_layers -
config.first_k_dense_replace)
self.num_expert_groups = config.n_group
self.moe_layers: list[FusedMoE] = []
example_moe = None
for layer in self.model.layers:
if isinstance(layer, PPMissingLayer):
continue
assert isinstance(layer, DeepseekV2DecoderLayer)
if isinstance(layer.mlp, DeepseekV2MoE):
# Pick last one layer since the first ones may be dense layers.
example_moe = layer.mlp
self.moe_layers.append(layer.mlp.experts)
if example_moe is None:
raise RuntimeError("No DeepseekV2MoE layer found in model.layers.")
self.num_logical_experts = example_moe.n_logical_experts
self.num_physical_experts = example_moe.n_physical_experts
self.num_local_physical_experts = example_moe.n_local_physical_experts
self.num_routed_experts = example_moe.n_routed_experts
self.num_shared_experts = example_moe.n_shared_experts
self.num_redundant_experts = example_moe.n_redundant_experts
# NOTE: This `load_weights` is mainly copied from
# https://github.com/vllm-project/vllm/commit/07b8fae219b1fff51ef115c38c44b51395be5bb5
# to fix CI, and it is different from the implementation in main
# TODO: support eplb style load_weights
def load_weights(self, weights: Iterable[tuple[str,
torch.Tensor]]) -> set[str]:
""""""
stacked_params_mapping = [
# (param_name, shard_name, shard_id)
("gate_up_proj", "gate_proj", 0),
("gate_up_proj", "up_proj", 1),
("fused_qkv_a_proj", "q_a_proj", 0),
("fused_qkv_a_proj", "kv_a_proj_with_mqa", 1),
]
# Params for weights, fp8 weight scales, fp8 activation scales
# (param_name, weight_name, expert_id, shard_id)
expert_params_mapping = AscendFusedMoE.make_expert_params_mapping(
ckpt_gate_proj_name="gate_proj",
ckpt_down_proj_name="down_proj",
ckpt_up_proj_name="up_proj",
num_experts=self.config.n_routed_experts)
params_dict = dict(self.named_parameters())
loaded_params: set[str] = set()
for name, loaded_weight in weights:
if "rotary_emb.inv_freq" in name:
continue
if "module" in name:
continue
spec_layer = get_spec_layer_idx_from_weight_name(self.config, name)
if spec_layer is not None:
continue # skip spec decode layers for main model
for (param_name, weight_name, shard_id) in stacked_params_mapping:
# Skip non-stacked layers and experts (experts handled below).
if weight_name not in name:
continue
# We have mlp.experts[0].gate_proj in the checkpoint.
# Since we handle the experts below in expert_params_mapping,
# we need to skip here BEFORE we update the name, otherwise
# name will be updated to mlp.experts[0].gate_up_proj, which
# will then be updated below in expert_params_mapping
# for mlp.experts[0].gate_gate_up_proj, which breaks load.
if (("mlp.experts." in name) and name not in params_dict):
continue
name = name.replace(weight_name, param_name)
# Skip loading extra bias for GPTQ models.
if name.endswith(".bias") and name not in params_dict:
continue
if is_pp_missing_parameter(name, self):
continue
param = params_dict[name]
weight_loader = param.weight_loader
weight_loader(param, loaded_weight, shard_id)
break
else:
for mapping in expert_params_mapping:
param_name, weight_name, expert_id, shard_id = mapping
if weight_name not in name:
continue
name = name.replace(weight_name, param_name)
if is_pp_missing_parameter(name, self):
continue
param = params_dict[name]
weight_loader = param.weight_loader
weight_loader(param,
loaded_weight,
name,
shard_id=shard_id,
expert_id=expert_id,
return_success=False)
break
else:
# Skip loading extra bias for GPTQ models.
if name.endswith(".bias") and name not in params_dict:
continue
# Remapping the name of FP8 kv-scale.
name = maybe_remap_kv_scale_name(name, params_dict)
if name is None:
continue
if is_pp_missing_parameter(name, self):
continue
param = params_dict[name]
weight_loader = getattr(param, "weight_loader",
default_weight_loader)
weight_loader(param, loaded_weight)
loaded_params.add(name)
return loaded_params
class CustomDeepseekV3ForCausalLM(CustomDeepseekV2ForCausalLM):
pass
DeepseekV2DecoderLayer.__init__ = CustomDeepseekV2DecoderLayer.__init__

View File

@@ -52,6 +52,35 @@ else:
from vllm.model_executor.layers.mla import MultiHeadLatentAttentionWrapper from vllm.model_executor.layers.mla import MultiHeadLatentAttentionWrapper
class IndexerWrapper(nn.Module):
'''
A wrapper of Indexer for Deepseek v3.2.
This wrapper is currently used to solve the fp8 hard code issue of vllm's deepseek_v2.py.
It wraps the original Indexer, inherits its module weights
(including wq_b, wk, weights_proj, k_norm)
while deletes the unused topk_indices_buffer and k_cache to save memory.
TODO: Will be removed once original Indexer supports different quantization methods.
'''
def __init__(self, vllm_indexer: nn.Module) -> None:
super().__init__()
self.n_head: int = vllm_indexer.n_head # 64
self.head_dim: int = vllm_indexer.head_dim # 128
self.topk_tokens: int = vllm_indexer.topk_tokens # 2048
self.q_lora_rank: int = vllm_indexer.q_lora_rank # 1536
self.wq_b = vllm_indexer.wq_b
self.wk = vllm_indexer.wk
self.weights_proj = vllm_indexer.weights_proj
self.k_norm = vllm_indexer.k_norm
self.softmax_scale = vllm_indexer.softmax_scale
vllm_indexer.topk_indices_buffer = None # delete topk_indices_buffer
vllm_indexer.k_cache = None # delete k_cache
def forward(self):
return
# TODO(whx): adapt v0.11.0 and DSA # TODO(whx): adapt v0.11.0 and DSA
class AscendMultiHeadLatentAttention(MultiHeadLatentAttentionWrapper): class AscendMultiHeadLatentAttention(MultiHeadLatentAttentionWrapper):
@@ -86,6 +115,10 @@ class AscendMultiHeadLatentAttention(MultiHeadLatentAttentionWrapper):
self.first_k_dense_replace = hf_config.first_k_dense_replace self.first_k_dense_replace = hf_config.first_k_dense_replace
self.tp_size = get_tensor_model_parallel_world_size() self.tp_size = get_tensor_model_parallel_world_size()
self.layers = hf_config.num_hidden_layers self.layers = hf_config.num_hidden_layers
if mla_modules.indexer is not None:
ascend_indexer = IndexerWrapper(mla_modules.indexer)
else:
ascend_indexer = None
if vllm_version_is("0.11.0"): if vllm_version_is("0.11.0"):
self.mla_attn = Attention( self.mla_attn = Attention(
@@ -97,6 +130,8 @@ class AscendMultiHeadLatentAttention(MultiHeadLatentAttentionWrapper):
quant_config=quant_config, quant_config=quant_config,
prefix=f"{prefix}.attn", prefix=f"{prefix}.attn",
use_mla=True, use_mla=True,
indexer=ascend_indexer,
use_sparse=mla_modules.is_sparse,
# MLA Args # MLA Args
q_lora_rank=self.q_lora_rank, q_lora_rank=self.q_lora_rank,
kv_lora_rank=self.kv_lora_rank, kv_lora_rank=self.kv_lora_rank,
@@ -128,7 +163,7 @@ class AscendMultiHeadLatentAttention(MultiHeadLatentAttentionWrapper):
quant_config=quant_config, quant_config=quant_config,
prefix=f"{prefix}.attn", prefix=f"{prefix}.attn",
use_sparse=mla_modules.is_sparse, use_sparse=mla_modules.is_sparse,
indexer=mla_modules.indexer, indexer=ascend_indexer,
# extra args # extra args
rotary_emb=mla_modules.rotary_emb, rotary_emb=mla_modules.rotary_emb,
fused_qkv_a_proj=mla_modules.fused_qkv_a_proj, fused_qkv_a_proj=mla_modules.fused_qkv_a_proj,

View File

@@ -1,275 +0,0 @@
# SPDX-License-Identifier: Apache-2.0
# Copyright (c) 2025 Huawei Technologies Co., Ltd. All Rights Reserved.
# Copyright 2023 The vLLM team.
# Copyright 2023 DeepSeek-AI and the HuggingFace Inc. team. All rights reserved.
#
# This code is based on EleutherAI's GPT-NeoX library and the GPT-NeoX
# and OPT implementations in this library. It has been modified from its
# original forms to accommodate minor architectural differences compared
# to GPT-NeoX and OPT used by the Meta AI team that trained the model.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from dataclasses import dataclass
from typing import Optional
import torch
from torch import nn
from vllm.attention import AttentionMetadata
from vllm.config import CacheConfig, get_current_vllm_config
from vllm.distributed import get_tensor_model_parallel_world_size
from vllm.forward_context import ForwardContext, get_forward_context
from vllm.model_executor.layers.linear import ReplicatedLinear
from vllm.model_executor.layers.mla import MLAModules
from vllm.model_executor.layers.quantization import QuantizationConfig
from vllm_ascend.ascend_config import get_ascend_config
from vllm_ascend.utils import vllm_version_is
if vllm_version_is("0.11.0"):
from vllm.attention import Attention
from vllm.model_executor.layers.mla import \
MultiHeadLatentAttention as MultiHeadLatentAttentionWrapper
from vllm.utils import direct_register_custom_op
else:
from vllm.attention.layer import MLAAttention
from vllm.model_executor.layers.mla import MultiHeadLatentAttentionWrapper
from vllm.utils.torch_utils import direct_register_custom_op
@dataclass
class AscendSFAModules:
q_a_layernorm: Optional[torch.nn.Module]
q_proj: Optional[torch.nn.Module]
kv_a_proj_with_mqa: torch.nn.Module
kv_a_layernorm: torch.nn.Module
kv_b_proj: torch.nn.Module
o_proj: torch.nn.Module
rotary_emb: torch.nn.Module
indexer: torch.nn.Module
is_sparse: bool
fused_qkv_a_proj: Optional[torch.nn.Module]
q_b_proj: Optional[torch.nn.Module]
topk_indices_buffer: Optional[torch.Tensor]
class AscendSparseFlashAttention(MultiHeadLatentAttentionWrapper):
def __init__(
self,
hidden_size: int,
num_heads: int,
scale: float,
qk_nope_head_dim: int,
qk_rope_head_dim: int,
v_head_dim: int,
q_lora_rank: Optional[int],
kv_lora_rank: int,
mla_modules: MLAModules,
cache_config: Optional[CacheConfig] = None,
quant_config: Optional[QuantizationConfig] = None,
prefix: str = "",
) -> None:
nn.Module.__init__(self)
self.hidden_size = hidden_size
self.kv_lora_rank = kv_lora_rank
self.qk_rope_head_dim = qk_rope_head_dim
self.q_lora_rank = q_lora_rank
self.qk_nope_head_dim = qk_nope_head_dim
self.qk_head_dim = qk_rope_head_dim + qk_nope_head_dim
self.v_head_dim = v_head_dim
self.prefix = prefix
self.scaling = scale
self.indexer = mla_modules.indexer
self.is_sparse = mla_modules.is_sparse
hf_config = get_current_vllm_config().model_config.hf_config
self.enable_shared_expert_dp = get_ascend_config(
).enable_shared_expert_dp
self.debug_layer_idx = int(self.prefix.split(".")[-2])
self.first_k_dense_replace = hf_config.first_k_dense_replace
self.tp_size = get_tensor_model_parallel_world_size()
self.layers = hf_config.num_hidden_layers
if vllm_version_is("0.11.0"):
self.sfa_attn = Attention(
num_heads=num_heads,
head_size=self.kv_lora_rank + self.qk_rope_head_dim,
scale=scale,
num_kv_heads=1,
cache_config=cache_config,
quant_config=quant_config,
prefix=f"{prefix}.attn",
use_mla=True,
use_sparse=True,
indexer=self.indexer,
# SFA Args
q_lora_rank=self.q_lora_rank,
kv_lora_rank=self.kv_lora_rank,
qk_nope_head_dim=self.qk_nope_head_dim,
qk_rope_head_dim=self.qk_rope_head_dim,
v_head_dim=self.v_head_dim,
qk_head_dim=self.qk_head_dim,
rotary_emb=mla_modules.rotary_emb,
fused_qkv_a_proj=mla_modules.fused_qkv_a_proj,
q_b_proj=mla_modules.q_b_proj,
q_a_layernorm=mla_modules.q_a_layernorm,
q_proj=mla_modules.q_proj,
kv_a_proj_with_mqa=mla_modules.kv_a_proj_with_mqa,
kv_a_layernorm=mla_modules.kv_a_layernorm,
kv_b_proj=mla_modules.kv_b_proj,
o_proj=mla_modules.o_proj,
)
else:
self.sfa_attn = MLAAttention(
num_heads=num_heads,
scale=scale,
qk_nope_head_dim=self.qk_nope_head_dim,
qk_rope_head_dim=self.qk_rope_head_dim,
v_head_dim=self.v_head_dim,
q_lora_rank=self.q_lora_rank,
kv_lora_rank=self.kv_lora_rank,
kv_b_proj=mla_modules.kv_b_proj,
cache_config=cache_config,
quant_config=quant_config,
prefix=f"{prefix}.attn",
use_sparse=mla_modules.is_sparse,
indexer=mla_modules.indexer,
# extra args
rotary_emb=mla_modules.rotary_emb,
fused_qkv_a_proj=mla_modules.fused_qkv_a_proj,
q_b_proj=mla_modules.q_b_proj,
q_a_layernorm=mla_modules.q_a_layernorm,
q_proj=mla_modules.q_proj,
kv_a_proj_with_mqa=mla_modules.kv_a_proj_with_mqa,
kv_a_layernorm=mla_modules.kv_a_layernorm,
o_proj=mla_modules.o_proj,
)
compilation_config = get_current_vllm_config().compilation_config
if prefix in compilation_config.static_forward_context:
raise ValueError(f"Duplicate layer name: {prefix}")
compilation_config.static_forward_context[prefix] = self
def forward(
self,
positions: torch.Tensor,
hidden_states: torch.Tensor,
kv_cache: Optional[torch.Tensor] = None,
attn_metadata: Optional[AttentionMetadata] = None) -> torch.Tensor:
num_tokens = hidden_states.shape[0]
need_gather_q_kv = False
if self.enable_shared_expert_dp and self.debug_layer_idx > self.first_k_dense_replace and self.debug_layer_idx < self.layers:
# Simulate all gather to calculate output shape
num_tokens = num_tokens * self.tp_size
need_gather_q_kv = True
if not self.enable_shared_expert_dp or self.debug_layer_idx < self.first_k_dense_replace:
output_shape = hidden_states.shape
else:
rows = num_tokens // self.tp_size
if num_tokens % self.tp_size:
rows += 1
output_shape = (rows, hidden_states.shape[1])
# FIXME: This does not seem right, should make sure the buffer is fixed
output = torch.empty(output_shape,
dtype=hidden_states.dtype,
device=hidden_states.device)
torch.ops.vllm.sfa_forward(hidden_states, need_gather_q_kv, output,
self.prefix)
output = output.view(-1, output_shape[-1])
return output
def sfa_forward(
hidden_states: torch.Tensor,
need_gather_q_kv: bool,
output: torch.Tensor,
layer_name: str,
) -> None:
forward_context: ForwardContext = get_forward_context()
self = forward_context.no_compile_layers[layer_name]
if forward_context.attn_metadata:
attn_metadata = forward_context.attn_metadata[self.sfa_attn.layer_name]
else:
attn_metadata = forward_context.attn_metadata
kv_cache = self.sfa_attn.kv_cache[forward_context.virtual_engine]
self.sfa_attn.impl.forward(hidden_states, kv_cache, attn_metadata,
need_gather_q_kv, output)
return
class Indexer(nn.Module):
def __init__(self,
config,
dim: int = 7168,
n_heads: int = 64,
head_dim: int = 128,
index_topk: int = 2048,
q_lora_rank: int = 1536,
rope_head_dim: int = 64,
quant_config: Optional[QuantizationConfig] = None,
prefix: Optional[str] = ""):
super().__init__()
self.dim: int = dim # 7168
self.n_heads: int = n_heads # 64
self.head_dim: int = head_dim # 128
self.rope_head_dim: int = rope_head_dim # 64
self.index_topk: int = index_topk # 2048
self.q_lora_rank: int = q_lora_rank # 1536
self.wq_b = ReplicatedLinear(
self.q_lora_rank,
self.n_heads * self.head_dim,
bias=False,
quant_config=quant_config,
prefix=f"{prefix}.wq_b",
return_bias=False,
)
self.wk = ReplicatedLinear(
self.dim,
self.head_dim,
bias=False,
quant_config=quant_config,
prefix=f"{prefix}.wk",
return_bias=False,
)
self.weights_proj = ReplicatedLinear(
self.dim,
self.n_heads,
bias=False,
quant_config=quant_config,
prefix=f"{prefix}.weights_proj",
return_bias=False,
)
self.k_norm = nn.LayerNorm(self.head_dim)
self.softmax_scale = self.head_dim**-0.5
def forward(self):
return
def sfa_forward_fake(
hidden_states: torch.Tensor,
need_gather_q_kv: bool,
output: torch.Tensor,
layer_name: str,
) -> None:
return
direct_register_custom_op(
op_name="sfa_forward",
op_func=sfa_forward,
mutates_args=["output"],
fake_impl=sfa_forward_fake,
dispatch_key="PrivateUse1",
)

View File

@@ -33,3 +33,4 @@ from vllm_ascend.utils import vllm_version_is
if vllm_version_is("0.11.0"): if vllm_version_is("0.11.0"):
import vllm_ascend.patch.worker.patch_deepseek_mtp # noqa import vllm_ascend.patch.worker.patch_deepseek_mtp # noqa
import vllm_ascend.patch.worker.patch_deepseek_v3_2 # noqa

View File

@@ -0,0 +1,108 @@
#
# Copyright (c) 2025 Huawei Technologies Co., Ltd. All Rights Reserved.
# This file is a part of the vllm-ascend project.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
from itertools import islice
from typing import Optional, Union
import torch
import vllm.model_executor.models.deepseek_v2
from torch import nn
from vllm.compilation.decorators import support_torch_compile
from vllm.config import VllmConfig
from vllm.distributed import get_pp_group
from vllm.model_executor.layers.layernorm import RMSNorm
from vllm.model_executor.layers.vocab_parallel_embedding import \
VocabParallelEmbedding
from vllm.model_executor.models.deepseek_v2 import DeepseekV2DecoderLayer
from vllm.model_executor.models.utils import (
PPMissingLayer, make_empty_intermediate_tensors_factory, make_layers)
from vllm.sequence import IntermediateTensors
@support_torch_compile
class DeepseekV2Model(nn.Module):
fall_back_to_pt_during_load = False
def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
super().__init__()
config = vllm_config.model_config.hf_config
quant_config = vllm_config.quant_config
self.config = config
self.vocab_size = config.vocab_size
self.is_v32 = hasattr(config, "index_topk")
topk_indices_buffer = None
if get_pp_group().is_first_rank:
self.embed_tokens = VocabParallelEmbedding(
config.vocab_size,
config.hidden_size,
quant_config=quant_config,
prefix=f"{prefix}.embed_tokens")
else:
self.embed_tokens = PPMissingLayer()
self.start_layer, self.end_layer, self.layers = make_layers(
config.num_hidden_layers,
lambda prefix: DeepseekV2DecoderLayer(vllm_config, prefix,
topk_indices_buffer),
prefix=f"{prefix}.layers")
if get_pp_group().is_last_rank:
self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
else:
self.norm = PPMissingLayer()
self.make_empty_intermediate_tensors = (
make_empty_intermediate_tensors_factory(
["hidden_states", "residual"], config.hidden_size))
def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor:
return self.embed_tokens(input_ids)
def forward(
self,
input_ids: torch.Tensor,
positions: torch.Tensor,
intermediate_tensors: Optional[IntermediateTensors],
inputs_embeds: Optional[torch.Tensor] = None,
) -> Union[torch.Tensor, IntermediateTensors]:
if get_pp_group().is_first_rank:
if inputs_embeds is not None:
hidden_states = inputs_embeds
else:
hidden_states = self.get_input_embeddings(input_ids)
residual = None
else:
assert intermediate_tensors is not None
hidden_states = intermediate_tensors["hidden_states"]
residual = intermediate_tensors["residual"]
for layer in islice(self.layers, self.start_layer, self.end_layer):
hidden_states, residual = layer(positions, hidden_states, residual)
if not get_pp_group().is_last_rank:
return IntermediateTensors({
"hidden_states": hidden_states,
"residual": residual
})
hidden_states, _ = self.norm(hidden_states, residual)
return hidden_states
vllm.model_executor.models.deepseek_v2.DeepseekV2Model = DeepseekV2Model

View File

@@ -69,7 +69,6 @@ from vllm.sequence import IntermediateTensors
from vllm_ascend import envs from vllm_ascend import envs
from vllm_ascend.ascend_config import get_ascend_config from vllm_ascend.ascend_config import get_ascend_config
from vllm_ascend.models.layers.sfa import Indexer
from vllm_ascend.ops.weight_prefetch import maybe_npu_prefetch from vllm_ascend.ops.weight_prefetch import maybe_npu_prefetch
from vllm_ascend.quantization.quant_config import AscendLinearMethod from vllm_ascend.quantization.quant_config import AscendLinearMethod
from vllm_ascend.torchair.ops.torchair_fused_moe import TorchairAscendFusedMoE from vllm_ascend.torchair.ops.torchair_fused_moe import TorchairAscendFusedMoE
@@ -83,6 +82,57 @@ else:
from vllm.attention.layer import MLAAttention from vllm.attention.layer import MLAAttention
class Indexer(nn.Module):
def __init__(self,
config,
dim: int = 7168,
n_heads: int = 64,
head_dim: int = 128,
index_topk: int = 2048,
q_lora_rank: int = 1536,
rope_head_dim: int = 64,
quant_config: Optional[QuantizationConfig] = None,
prefix: Optional[str] = ""):
super().__init__()
self.dim: int = dim # 7168
self.n_heads: int = n_heads # 64
self.head_dim: int = head_dim # 128
self.rope_head_dim: int = rope_head_dim # 64
self.index_topk: int = index_topk # 2048
self.q_lora_rank: int = q_lora_rank # 1536
self.wq_b = ReplicatedLinear(
self.q_lora_rank,
self.n_heads * self.head_dim,
bias=False,
quant_config=quant_config,
prefix=f"{prefix}.wq_b",
return_bias=False,
)
self.wk = ReplicatedLinear(
self.dim,
self.head_dim,
bias=False,
quant_config=quant_config,
prefix=f"{prefix}.wk",
return_bias=False,
)
self.weights_proj = ReplicatedLinear(
self.dim,
self.n_heads,
bias=False,
quant_config=quant_config,
prefix=f"{prefix}.weights_proj",
return_bias=False,
)
self.k_norm = nn.LayerNorm(self.head_dim)
self.softmax_scale = self.head_dim**-0.5
def forward(self):
return
class TorchairDeepseekV2SiluAndMul(SiluAndMul): class TorchairDeepseekV2SiluAndMul(SiluAndMul):
def __init__(self, def __init__(self,

View File

@@ -577,7 +577,6 @@ def register_ascend_customop(vllm_config: Optional[VllmConfig] = None):
from vllm.model_executor.custom_op import CustomOp from vllm.model_executor.custom_op import CustomOp
from vllm_ascend.models.layers.mla import AscendMultiHeadLatentAttention from vllm_ascend.models.layers.mla import AscendMultiHeadLatentAttention
from vllm_ascend.models.layers.sfa import AscendSparseFlashAttention
from vllm_ascend.ops.activation import AscendQuickGELU, AscendSiluAndMul from vllm_ascend.ops.activation import AscendQuickGELU, AscendSiluAndMul
from vllm_ascend.ops.fused_moe.fused_moe import (AscendFusedMoE, from vllm_ascend.ops.fused_moe.fused_moe import (AscendFusedMoE,
AscendSharedFusedMoE) AscendSharedFusedMoE)
@@ -625,10 +624,7 @@ def register_ascend_customop(vllm_config: Optional[VllmConfig] = None):
mla_to_register = "MultiHeadLatentAttention" if vllm_version_is( mla_to_register = "MultiHeadLatentAttention" if vllm_version_is(
"0.11.0") else "MultiHeadLatentAttentionWrapper" "0.11.0") else "MultiHeadLatentAttentionWrapper"
if vllm_config and vllm_config.model_config and vllm_config.model_config.use_mla: if vllm_config and vllm_config.model_config and vllm_config.model_config.use_mla:
AscendMLAAttentionWarrper = AscendSparseFlashAttention if hasattr( REGISTERED_ASCEND_OPS[mla_to_register] = AscendMultiHeadLatentAttention
vllm_config.model_config.hf_config,
"index_topk") else AscendMultiHeadLatentAttention
REGISTERED_ASCEND_OPS[mla_to_register] = AscendMLAAttentionWarrper
for name, op_cls in REGISTERED_ASCEND_OPS.items(): for name, op_cls in REGISTERED_ASCEND_OPS.items():
CustomOp.register_oot(_decorated_op_cls=op_cls, name=name) CustomOp.register_oot(_decorated_op_cls=op_cls, name=name)