Refactor AscendMultiHeadLatentAttention (#2826)

### What this PR does / why we need it?
Register AscendMultiHeadLatentAttention as CustomOP, following vllm changes

### Does this PR introduce _any_ user-facing change?
N/A

### How was this patch tested?
CI passed with new added/existing test.


- vLLM version: main
- vLLM main:
b23fb78623

---------

Signed-off-by: Icey <1790571317@qq.com>
This commit is contained in:
Icey
2025-09-10 11:26:11 +08:00
committed by GitHub
parent 168ad600b5
commit aa4d2a91ed
4 changed files with 170 additions and 48 deletions

View File

@@ -31,7 +31,7 @@ import torch
import torch_npu
from torch import nn
from transformers import PretrainedConfig
from vllm.attention import Attention, AttentionMetadata
from vllm.attention import AttentionMetadata
from vllm.config import CacheConfig, ModelConfig, VllmConfig
from vllm.distributed import (get_pp_group, get_tensor_model_parallel_rank,
get_tensor_model_parallel_world_size,
@@ -48,6 +48,7 @@ from vllm.model_executor.layers.linear import (ColumnParallelLinear,
RowParallelLinear,
UnquantizedLinearMethod)
from vllm.model_executor.layers.logits_processor import LogitsProcessor
from vllm.model_executor.layers.mla import MultiHeadLatentAttention
from vllm.model_executor.layers.quantization import QuantizationConfig
from vllm.model_executor.layers.rotary_embedding import get_rope
from vllm.model_executor.layers.sampler import get_sampler
@@ -68,6 +69,7 @@ from vllm.model_executor.models.utils import (
from vllm.sequence import IntermediateTensors
from vllm_ascend.ascend_config import get_ascend_config
from vllm_ascend.models.layers.mla import AscendMLAModules
from vllm_ascend.ops.fused_moe import AscendFusedMoE
from vllm_ascend.quantization.quant_config import AscendLinearMethod
from vllm_ascend.quantization.w8a8_dynamic import AscendW8A8DynamicLinearMethod
@@ -529,29 +531,7 @@ class CustomDeepseekV2MLAAttention(DeepseekV2MLAAttention):
mscale = yarn_get_mscale(scaling_factor, float(mscale_all_dim))
self.scaling = self.scaling * mscale * mscale
# In the MLA backend, kv_cache includes both k_c and
# pe (i.e. decoupled position embeddings). In particular,
# the concat_and_cache_mla op requires
# k_c.size(1) + k_pe.size(1) == kv_cache.size(2)
# i.e.
# kv_lora_rank + qk_rope_head_dim == head_size
self.mla_attn = Attention(
num_heads=self.num_local_heads,
head_size=self.kv_lora_rank + self.qk_rope_head_dim,
scale=self.scaling,
num_kv_heads=1,
cache_config=cache_config,
quant_config=quant_config,
prefix=f"{prefix}.attn",
use_mla=True,
# MLA Args
q_lora_rank=self.q_lora_rank,
kv_lora_rank=self.kv_lora_rank,
qk_nope_head_dim=self.qk_nope_head_dim,
qk_rope_head_dim=self.qk_rope_head_dim,
qk_head_dim=self.qk_head_dim,
v_head_dim=self.v_head_dim,
rotary_emb=self.rotary_emb,
mla_modules = AscendMLAModules(
q_a_proj=self.q_a_proj if self.q_lora_rank is not None else None,
q_a_layernorm=self.q_a_layernorm
if self.q_lora_rank is not None else None,
@@ -560,6 +540,28 @@ class CustomDeepseekV2MLAAttention(DeepseekV2MLAAttention):
kv_a_layernorm=self.kv_a_layernorm,
kv_b_proj=self.kv_b_proj,
o_proj=self.o_proj,
rotary_emb=self.rotary_emb,
)
self.mla_attn = MultiHeadLatentAttention(
self.hidden_size,
self.enable_shared_expert_dp,
self.debug_layer_idx,
self.first_k_dense_replace,
self.tp_size,
mla_modules,
self.num_local_heads,
self.scaling,
self.layers,
self.kv_lora_rank,
self.qk_rope_head_dim,
self.q_lora_rank,
self.qk_nope_head_dim,
self.qk_head_dim,
self.v_head_dim,
cache_config,
quant_config,
prefix,
)
def forward(
@@ -568,30 +570,7 @@ class CustomDeepseekV2MLAAttention(DeepseekV2MLAAttention):
hidden_states: torch.Tensor,
kv_cache: Optional[torch.Tensor] = None,
attn_metadata: Optional[AttentionMetadata] = None) -> torch.Tensor:
forward_context = get_forward_context()
if kv_cache is None:
kv_cache = self.mla_attn.kv_cache[forward_context.virtual_engine]
num_tokens = hidden_states.shape[0]
need_gather_q_kv = False
if self.enable_shared_expert_dp and self.debug_layer_idx > self.first_k_dense_replace and self.debug_layer_idx < self.layers:
# Simulate all gather to calculate output shape
num_tokens = num_tokens * self.tp_size
need_gather_q_kv = True
if not self.enable_shared_expert_dp or self.debug_layer_idx < self.first_k_dense_replace:
output_shape = hidden_states.shape
else:
rows = num_tokens // self.tp_size
if num_tokens % self.tp_size:
rows += 1
output_shape = (rows, hidden_states.shape[1])
output = torch.empty(output_shape,
dtype=hidden_states.dtype,
device=hidden_states.device)
output = self.mla_attn.impl.forward(hidden_states, kv_cache,
forward_context.attn_metadata,
need_gather_q_kv, output)
output = output.view(-1, output_shape[-1])
return output
return self.mla_attn(positions, hidden_states, kv_cache, attn_metadata)
class CustomDeepseekV2DecoderLayer(DeepseekV2DecoderLayer):