From aa4d2a91ed6450759895ee7fd614eaf336c4722e Mon Sep 17 00:00:00 2001 From: Icey <1790571317@qq.com> Date: Wed, 10 Sep 2025 11:26:11 +0800 Subject: [PATCH] Refactor AscendMultiHeadLatentAttention (#2826) ### What this PR does / why we need it? Register AscendMultiHeadLatentAttention as a CustomOp, following the corresponding upstream vLLM changes. ### Does this PR introduce _any_ user-facing change? N/A ### How was this patch tested? CI passed with newly added and existing tests. - vLLM version: main - vLLM main: https://github.com/vllm-project/vllm/commit/b23fb78623f0019854af7a820d398cba9a316677 --------- Signed-off-by: Icey <1790571317@qq.com> --- vllm_ascend/models/deepseek_v2.py | 75 +++++--------- vllm_ascend/models/layers/__init__.py | 0 vllm_ascend/models/layers/mla.py | 139 ++++++++++++++++++++++++++ vllm_ascend/utils.py | 4 + 4 files changed, 170 insertions(+), 48 deletions(-) create mode 100644 vllm_ascend/models/layers/__init__.py create mode 100644 vllm_ascend/models/layers/mla.py diff --git a/vllm_ascend/models/deepseek_v2.py b/vllm_ascend/models/deepseek_v2.py index 1811255..59273ed 100644 --- a/vllm_ascend/models/deepseek_v2.py +++ b/vllm_ascend/models/deepseek_v2.py @@ -31,7 +31,7 @@ import torch import torch_npu from torch import nn from transformers import PretrainedConfig -from vllm.attention import Attention, AttentionMetadata +from vllm.attention import AttentionMetadata from vllm.config import CacheConfig, ModelConfig, VllmConfig from vllm.distributed import (get_pp_group, get_tensor_model_parallel_rank, get_tensor_model_parallel_world_size, @@ -48,6 +48,7 @@ from vllm.model_executor.layers.linear import (ColumnParallelLinear, RowParallelLinear, UnquantizedLinearMethod) from vllm.model_executor.layers.logits_processor import LogitsProcessor +from vllm.model_executor.layers.mla import MultiHeadLatentAttention from vllm.model_executor.layers.quantization import QuantizationConfig from vllm.model_executor.layers.rotary_embedding import get_rope from 
vllm.model_executor.layers.sampler import get_sampler @@ -68,6 +69,7 @@ from vllm.model_executor.models.utils import ( from vllm.sequence import IntermediateTensors from vllm_ascend.ascend_config import get_ascend_config +from vllm_ascend.models.layers.mla import AscendMLAModules from vllm_ascend.ops.fused_moe import AscendFusedMoE from vllm_ascend.quantization.quant_config import AscendLinearMethod from vllm_ascend.quantization.w8a8_dynamic import AscendW8A8DynamicLinearMethod @@ -529,29 +531,7 @@ class CustomDeepseekV2MLAAttention(DeepseekV2MLAAttention): mscale = yarn_get_mscale(scaling_factor, float(mscale_all_dim)) self.scaling = self.scaling * mscale * mscale - # In the MLA backend, kv_cache includes both k_c and - # pe (i.e. decoupled position embeddings). In particular, - # the concat_and_cache_mla op requires - # k_c.size(1) + k_pe.size(1) == kv_cache.size(2) - # i.e. - # kv_lora_rank + qk_rope_head_dim == head_size - self.mla_attn = Attention( - num_heads=self.num_local_heads, - head_size=self.kv_lora_rank + self.qk_rope_head_dim, - scale=self.scaling, - num_kv_heads=1, - cache_config=cache_config, - quant_config=quant_config, - prefix=f"{prefix}.attn", - use_mla=True, - # MLA Args - q_lora_rank=self.q_lora_rank, - kv_lora_rank=self.kv_lora_rank, - qk_nope_head_dim=self.qk_nope_head_dim, - qk_rope_head_dim=self.qk_rope_head_dim, - qk_head_dim=self.qk_head_dim, - v_head_dim=self.v_head_dim, - rotary_emb=self.rotary_emb, + mla_modules = AscendMLAModules( q_a_proj=self.q_a_proj if self.q_lora_rank is not None else None, q_a_layernorm=self.q_a_layernorm if self.q_lora_rank is not None else None, @@ -560,6 +540,28 @@ class CustomDeepseekV2MLAAttention(DeepseekV2MLAAttention): kv_a_layernorm=self.kv_a_layernorm, kv_b_proj=self.kv_b_proj, o_proj=self.o_proj, + rotary_emb=self.rotary_emb, + ) + + self.mla_attn = MultiHeadLatentAttention( + self.hidden_size, + self.enable_shared_expert_dp, + self.debug_layer_idx, + self.first_k_dense_replace, + self.tp_size, + 
mla_modules, + self.num_local_heads, + self.scaling, + self.layers, + self.kv_lora_rank, + self.qk_rope_head_dim, + self.q_lora_rank, + self.qk_nope_head_dim, + self.qk_head_dim, + self.v_head_dim, + cache_config, + quant_config, + prefix, ) def forward( @@ -568,30 +570,7 @@ class CustomDeepseekV2MLAAttention(DeepseekV2MLAAttention): hidden_states: torch.Tensor, kv_cache: Optional[torch.Tensor] = None, attn_metadata: Optional[AttentionMetadata] = None) -> torch.Tensor: - forward_context = get_forward_context() - if kv_cache is None: - kv_cache = self.mla_attn.kv_cache[forward_context.virtual_engine] - num_tokens = hidden_states.shape[0] - need_gather_q_kv = False - if self.enable_shared_expert_dp and self.debug_layer_idx > self.first_k_dense_replace and self.debug_layer_idx < self.layers: - # Simulate all gather to calculate output shape - num_tokens = num_tokens * self.tp_size - need_gather_q_kv = True - if not self.enable_shared_expert_dp or self.debug_layer_idx < self.first_k_dense_replace: - output_shape = hidden_states.shape - else: - rows = num_tokens // self.tp_size - if num_tokens % self.tp_size: - rows += 1 - output_shape = (rows, hidden_states.shape[1]) - output = torch.empty(output_shape, - dtype=hidden_states.dtype, - device=hidden_states.device) - output = self.mla_attn.impl.forward(hidden_states, kv_cache, - forward_context.attn_metadata, - need_gather_q_kv, output) - output = output.view(-1, output_shape[-1]) - return output + return self.mla_attn(positions, hidden_states, kv_cache, attn_metadata) class CustomDeepseekV2DecoderLayer(DeepseekV2DecoderLayer): diff --git a/vllm_ascend/models/layers/__init__.py b/vllm_ascend/models/layers/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/vllm_ascend/models/layers/mla.py b/vllm_ascend/models/layers/mla.py new file mode 100644 index 0000000..f28e485 --- /dev/null +++ b/vllm_ascend/models/layers/mla.py @@ -0,0 +1,139 @@ +# SPDX-License-Identifier: Apache-2.0 +# Copyright (c) 2025 Huawei 
Technologies Co., Ltd. All Rights Reserved. +# Copyright 2023 The vLLM team. +# Copyright 2023 DeepSeek-AI and the HuggingFace Inc. team. All rights reserved. +# +# This code is based on EleutherAI's GPT-NeoX library and the GPT-NeoX +# and OPT implementations in this library. It has been modified from its +# original forms to accommodate minor architectural differences compared +# to GPT-NeoX and OPT used by the Meta AI team that trained the model. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+from dataclasses import dataclass +from typing import Optional + +import torch +from torch import nn +from vllm.attention import Attention, AttentionMetadata +from vllm.config import CacheConfig +from vllm.forward_context import get_forward_context +from vllm.model_executor.layers.mla import MultiHeadLatentAttention +from vllm.model_executor.layers.quantization import QuantizationConfig + + +@dataclass +class AscendMLAModules: + q_a_proj: Optional[torch.nn.Module] + q_a_layernorm: Optional[torch.nn.Module] + q_proj: Optional[torch.nn.Module] + kv_a_proj_with_mqa: torch.nn.Module + kv_a_layernorm: torch.nn.Module + kv_b_proj: torch.nn.Module + o_proj: torch.nn.Module + rotary_emb: torch.nn.Module + + +class AscendMultiHeadLatentAttention(MultiHeadLatentAttention): + + def __init__( + self, + hidden_size: int, + enable_shared_expert_dp: bool, + debug_layer_idx: int, + first_k_dense_replace: int, + tp_size: int, + mla_modules: AscendMLAModules, + num_local_heads: int, + scaling: float, + layers: int, + kv_lora_rank: int, + qk_rope_head_dim: int, + q_lora_rank: Optional[int], + qk_nope_head_dim: int, + qk_head_dim: int, + v_head_dim: int, + cache_config: Optional[CacheConfig] = None, + quant_config: Optional[QuantizationConfig] = None, + prefix: str = "", + ) -> None: + nn.Module.__init__(self) + self.hidden_size = hidden_size + self.enable_shared_expert_dp = enable_shared_expert_dp + self.debug_layer_idx = debug_layer_idx + self.first_k_dense_replace = first_k_dense_replace + self.tp_size = tp_size + self.num_local_heads = num_local_heads + self.layers = layers + self.kv_lora_rank = kv_lora_rank + self.qk_rope_head_dim = qk_rope_head_dim + self.q_lora_rank = q_lora_rank + self.qk_nope_head_dim = qk_nope_head_dim + self.qk_head_dim = qk_head_dim + self.v_head_dim = v_head_dim + + self.mla_attn = Attention( + num_heads=self.num_local_heads, + head_size=self.kv_lora_rank + self.qk_rope_head_dim, + scale=scaling, + num_kv_heads=1, + cache_config=cache_config, + 
quant_config=quant_config, + prefix=f"{prefix}.attn", + use_mla=True, + # MLA Args + q_lora_rank=self.q_lora_rank, + kv_lora_rank=self.kv_lora_rank, + qk_nope_head_dim=self.qk_nope_head_dim, + qk_rope_head_dim=self.qk_rope_head_dim, + qk_head_dim=self.qk_head_dim, + v_head_dim=self.v_head_dim, + rotary_emb=mla_modules.rotary_emb, + q_a_proj=mla_modules.q_a_proj, + q_a_layernorm=mla_modules.q_a_layernorm, + q_proj=mla_modules.q_proj, + kv_a_proj_with_mqa=mla_modules.kv_a_proj_with_mqa, + kv_a_layernorm=mla_modules.kv_a_layernorm, + kv_b_proj=mla_modules.kv_b_proj, + o_proj=mla_modules.o_proj, + ) + + def forward( + self, + positions: torch.Tensor, + hidden_states: torch.Tensor, + kv_cache: Optional[torch.Tensor] = None, + attn_metadata: Optional[AttentionMetadata] = None) -> torch.Tensor: + forward_context = get_forward_context() + if kv_cache is None: + kv_cache = self.mla_attn.kv_cache[forward_context.virtual_engine] + num_tokens = hidden_states.shape[0] + need_gather_q_kv = False + if self.enable_shared_expert_dp and self.debug_layer_idx > self.first_k_dense_replace and self.debug_layer_idx < self.layers: + # Simulate all gather to calculate output shape + num_tokens = num_tokens * self.tp_size + need_gather_q_kv = True + if not self.enable_shared_expert_dp or self.debug_layer_idx < self.first_k_dense_replace: + output_shape = hidden_states.shape + else: + rows = num_tokens // self.tp_size + if num_tokens % self.tp_size: + rows += 1 + output_shape = (rows, hidden_states.shape[1]) + output = torch.empty(output_shape, + dtype=hidden_states.dtype, + device=hidden_states.device) + output = self.mla_attn.impl.forward(hidden_states, kv_cache, + forward_context.attn_metadata, + need_gather_q_kv, output) + output = output.view(-1, output_shape[-1]) + return output diff --git a/vllm_ascend/utils.py b/vllm_ascend/utils.py index 8813b68..a61a5af 100644 --- a/vllm_ascend/utils.py +++ b/vllm_ascend/utils.py @@ -529,6 +529,10 @@ def register_ascend_customop(): from 
vllm_ascend.ops.common_fused_moe import AscendFusedMoE CustomOp.register_oot(_decorated_op_cls=AscendFusedMoE, name="FusedMoE") + from vllm_ascend.models.layers.mla import AscendMultiHeadLatentAttention + CustomOp.register_oot(_decorated_op_cls=AscendMultiHeadLatentAttention, + name="MultiHeadLatentAttention") + # NOTE: Keep this at last to ensure all custom actions are registered _ASCEND_CUSTOMOP_IS_REIGISTERED = True