Refactor AscendMultiHeadLatentAttention (#2826)
### What this PR does / why we need it?
Register AscendMultiHeadLatentAttention as CustomOP, following vllm changes
### Does this PR introduce _any_ user-facing change?
N/A
### How was this patch tested?
CI passed with new added/existing test.
- vLLM version: main
- vLLM main:
b23fb78623
---------
Signed-off-by: Icey <1790571317@qq.com>
This commit is contained in:
@@ -31,7 +31,7 @@ import torch
|
|||||||
import torch_npu
|
import torch_npu
|
||||||
from torch import nn
|
from torch import nn
|
||||||
from transformers import PretrainedConfig
|
from transformers import PretrainedConfig
|
||||||
from vllm.attention import Attention, AttentionMetadata
|
from vllm.attention import AttentionMetadata
|
||||||
from vllm.config import CacheConfig, ModelConfig, VllmConfig
|
from vllm.config import CacheConfig, ModelConfig, VllmConfig
|
||||||
from vllm.distributed import (get_pp_group, get_tensor_model_parallel_rank,
|
from vllm.distributed import (get_pp_group, get_tensor_model_parallel_rank,
|
||||||
get_tensor_model_parallel_world_size,
|
get_tensor_model_parallel_world_size,
|
||||||
@@ -48,6 +48,7 @@ from vllm.model_executor.layers.linear import (ColumnParallelLinear,
|
|||||||
RowParallelLinear,
|
RowParallelLinear,
|
||||||
UnquantizedLinearMethod)
|
UnquantizedLinearMethod)
|
||||||
from vllm.model_executor.layers.logits_processor import LogitsProcessor
|
from vllm.model_executor.layers.logits_processor import LogitsProcessor
|
||||||
|
from vllm.model_executor.layers.mla import MultiHeadLatentAttention
|
||||||
from vllm.model_executor.layers.quantization import QuantizationConfig
|
from vllm.model_executor.layers.quantization import QuantizationConfig
|
||||||
from vllm.model_executor.layers.rotary_embedding import get_rope
|
from vllm.model_executor.layers.rotary_embedding import get_rope
|
||||||
from vllm.model_executor.layers.sampler import get_sampler
|
from vllm.model_executor.layers.sampler import get_sampler
|
||||||
@@ -68,6 +69,7 @@ from vllm.model_executor.models.utils import (
|
|||||||
from vllm.sequence import IntermediateTensors
|
from vllm.sequence import IntermediateTensors
|
||||||
|
|
||||||
from vllm_ascend.ascend_config import get_ascend_config
|
from vllm_ascend.ascend_config import get_ascend_config
|
||||||
|
from vllm_ascend.models.layers.mla import AscendMLAModules
|
||||||
from vllm_ascend.ops.fused_moe import AscendFusedMoE
|
from vllm_ascend.ops.fused_moe import AscendFusedMoE
|
||||||
from vllm_ascend.quantization.quant_config import AscendLinearMethod
|
from vllm_ascend.quantization.quant_config import AscendLinearMethod
|
||||||
from vllm_ascend.quantization.w8a8_dynamic import AscendW8A8DynamicLinearMethod
|
from vllm_ascend.quantization.w8a8_dynamic import AscendW8A8DynamicLinearMethod
|
||||||
@@ -529,29 +531,7 @@ class CustomDeepseekV2MLAAttention(DeepseekV2MLAAttention):
|
|||||||
mscale = yarn_get_mscale(scaling_factor, float(mscale_all_dim))
|
mscale = yarn_get_mscale(scaling_factor, float(mscale_all_dim))
|
||||||
self.scaling = self.scaling * mscale * mscale
|
self.scaling = self.scaling * mscale * mscale
|
||||||
|
|
||||||
# In the MLA backend, kv_cache includes both k_c and
|
mla_modules = AscendMLAModules(
|
||||||
# pe (i.e. decoupled position embeddings). In particular,
|
|
||||||
# the concat_and_cache_mla op requires
|
|
||||||
# k_c.size(1) + k_pe.size(1) == kv_cache.size(2)
|
|
||||||
# i.e.
|
|
||||||
# kv_lora_rank + qk_rope_head_dim == head_size
|
|
||||||
self.mla_attn = Attention(
|
|
||||||
num_heads=self.num_local_heads,
|
|
||||||
head_size=self.kv_lora_rank + self.qk_rope_head_dim,
|
|
||||||
scale=self.scaling,
|
|
||||||
num_kv_heads=1,
|
|
||||||
cache_config=cache_config,
|
|
||||||
quant_config=quant_config,
|
|
||||||
prefix=f"{prefix}.attn",
|
|
||||||
use_mla=True,
|
|
||||||
# MLA Args
|
|
||||||
q_lora_rank=self.q_lora_rank,
|
|
||||||
kv_lora_rank=self.kv_lora_rank,
|
|
||||||
qk_nope_head_dim=self.qk_nope_head_dim,
|
|
||||||
qk_rope_head_dim=self.qk_rope_head_dim,
|
|
||||||
qk_head_dim=self.qk_head_dim,
|
|
||||||
v_head_dim=self.v_head_dim,
|
|
||||||
rotary_emb=self.rotary_emb,
|
|
||||||
q_a_proj=self.q_a_proj if self.q_lora_rank is not None else None,
|
q_a_proj=self.q_a_proj if self.q_lora_rank is not None else None,
|
||||||
q_a_layernorm=self.q_a_layernorm
|
q_a_layernorm=self.q_a_layernorm
|
||||||
if self.q_lora_rank is not None else None,
|
if self.q_lora_rank is not None else None,
|
||||||
@@ -560,6 +540,28 @@ class CustomDeepseekV2MLAAttention(DeepseekV2MLAAttention):
|
|||||||
kv_a_layernorm=self.kv_a_layernorm,
|
kv_a_layernorm=self.kv_a_layernorm,
|
||||||
kv_b_proj=self.kv_b_proj,
|
kv_b_proj=self.kv_b_proj,
|
||||||
o_proj=self.o_proj,
|
o_proj=self.o_proj,
|
||||||
|
rotary_emb=self.rotary_emb,
|
||||||
|
)
|
||||||
|
|
||||||
|
self.mla_attn = MultiHeadLatentAttention(
|
||||||
|
self.hidden_size,
|
||||||
|
self.enable_shared_expert_dp,
|
||||||
|
self.debug_layer_idx,
|
||||||
|
self.first_k_dense_replace,
|
||||||
|
self.tp_size,
|
||||||
|
mla_modules,
|
||||||
|
self.num_local_heads,
|
||||||
|
self.scaling,
|
||||||
|
self.layers,
|
||||||
|
self.kv_lora_rank,
|
||||||
|
self.qk_rope_head_dim,
|
||||||
|
self.q_lora_rank,
|
||||||
|
self.qk_nope_head_dim,
|
||||||
|
self.qk_head_dim,
|
||||||
|
self.v_head_dim,
|
||||||
|
cache_config,
|
||||||
|
quant_config,
|
||||||
|
prefix,
|
||||||
)
|
)
|
||||||
|
|
||||||
def forward(
|
def forward(
|
||||||
@@ -568,30 +570,7 @@ class CustomDeepseekV2MLAAttention(DeepseekV2MLAAttention):
|
|||||||
hidden_states: torch.Tensor,
|
hidden_states: torch.Tensor,
|
||||||
kv_cache: Optional[torch.Tensor] = None,
|
kv_cache: Optional[torch.Tensor] = None,
|
||||||
attn_metadata: Optional[AttentionMetadata] = None) -> torch.Tensor:
|
attn_metadata: Optional[AttentionMetadata] = None) -> torch.Tensor:
|
||||||
forward_context = get_forward_context()
|
return self.mla_attn(positions, hidden_states, kv_cache, attn_metadata)
|
||||||
if kv_cache is None:
|
|
||||||
kv_cache = self.mla_attn.kv_cache[forward_context.virtual_engine]
|
|
||||||
num_tokens = hidden_states.shape[0]
|
|
||||||
need_gather_q_kv = False
|
|
||||||
if self.enable_shared_expert_dp and self.debug_layer_idx > self.first_k_dense_replace and self.debug_layer_idx < self.layers:
|
|
||||||
# Simulate all gather to calculate output shape
|
|
||||||
num_tokens = num_tokens * self.tp_size
|
|
||||||
need_gather_q_kv = True
|
|
||||||
if not self.enable_shared_expert_dp or self.debug_layer_idx < self.first_k_dense_replace:
|
|
||||||
output_shape = hidden_states.shape
|
|
||||||
else:
|
|
||||||
rows = num_tokens // self.tp_size
|
|
||||||
if num_tokens % self.tp_size:
|
|
||||||
rows += 1
|
|
||||||
output_shape = (rows, hidden_states.shape[1])
|
|
||||||
output = torch.empty(output_shape,
|
|
||||||
dtype=hidden_states.dtype,
|
|
||||||
device=hidden_states.device)
|
|
||||||
output = self.mla_attn.impl.forward(hidden_states, kv_cache,
|
|
||||||
forward_context.attn_metadata,
|
|
||||||
need_gather_q_kv, output)
|
|
||||||
output = output.view(-1, output_shape[-1])
|
|
||||||
return output
|
|
||||||
|
|
||||||
|
|
||||||
class CustomDeepseekV2DecoderLayer(DeepseekV2DecoderLayer):
|
class CustomDeepseekV2DecoderLayer(DeepseekV2DecoderLayer):
|
||||||
|
|||||||
0
vllm_ascend/models/layers/__init__.py
Normal file
0
vllm_ascend/models/layers/__init__.py
Normal file
139
vllm_ascend/models/layers/mla.py
Normal file
139
vllm_ascend/models/layers/mla.py
Normal file
@@ -0,0 +1,139 @@
|
|||||||
|
# SPDX-License-Identifier: Apache-2.0
|
||||||
|
# Copyright (c) 2025 Huawei Technologies Co., Ltd. All Rights Reserved.
|
||||||
|
# Copyright 2023 The vLLM team.
|
||||||
|
# Copyright 2023 DeepSeek-AI and the HuggingFace Inc. team. All rights reserved.
|
||||||
|
#
|
||||||
|
# This code is based on EleutherAI's GPT-NeoX library and the GPT-NeoX
|
||||||
|
# and OPT implementations in this library. It has been modified from its
|
||||||
|
# original forms to accommodate minor architectural differences compared
|
||||||
|
# to GPT-NeoX and OPT used by the Meta AI team that trained the model.
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
from dataclasses import dataclass
|
||||||
|
from typing import Optional
|
||||||
|
|
||||||
|
import torch
|
||||||
|
from torch import nn
|
||||||
|
from vllm.attention import Attention, AttentionMetadata
|
||||||
|
from vllm.config import CacheConfig
|
||||||
|
from vllm.forward_context import get_forward_context
|
||||||
|
from vllm.model_executor.layers.mla import MultiHeadLatentAttention
|
||||||
|
from vllm.model_executor.layers.quantization import QuantizationConfig
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class AscendMLAModules:
|
||||||
|
q_a_proj: Optional[torch.nn.Module]
|
||||||
|
q_a_layernorm: Optional[torch.nn.Module]
|
||||||
|
q_proj: Optional[torch.nn.Module]
|
||||||
|
kv_a_proj_with_mqa: torch.nn.Module
|
||||||
|
kv_a_layernorm: torch.nn.Module
|
||||||
|
kv_b_proj: torch.nn.Module
|
||||||
|
o_proj: torch.nn.Module
|
||||||
|
rotary_emb: torch.nn.Module
|
||||||
|
|
||||||
|
|
||||||
|
class AscendMultiHeadLatentAttention(MultiHeadLatentAttention):
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
hidden_size: int,
|
||||||
|
enable_shared_expert_dp: bool,
|
||||||
|
debug_layer_idx: int,
|
||||||
|
first_k_dense_replace: int,
|
||||||
|
tp_size: int,
|
||||||
|
mla_modules: AscendMLAModules,
|
||||||
|
num_local_heads: int,
|
||||||
|
scaling: float,
|
||||||
|
layers: int,
|
||||||
|
kv_lora_rank: int,
|
||||||
|
qk_rope_head_dim: int,
|
||||||
|
q_lora_rank: Optional[int],
|
||||||
|
qk_nope_head_dim: int,
|
||||||
|
qk_head_dim: int,
|
||||||
|
v_head_dim: int,
|
||||||
|
cache_config: Optional[CacheConfig] = None,
|
||||||
|
quant_config: Optional[QuantizationConfig] = None,
|
||||||
|
prefix: str = "",
|
||||||
|
) -> None:
|
||||||
|
nn.Module.__init__(self)
|
||||||
|
self.hidden_size = hidden_size
|
||||||
|
self.enable_shared_expert_dp = enable_shared_expert_dp
|
||||||
|
self.debug_layer_idx = debug_layer_idx
|
||||||
|
self.first_k_dense_replace = first_k_dense_replace
|
||||||
|
self.tp_size = tp_size
|
||||||
|
self.num_local_heads = num_local_heads
|
||||||
|
self.layers = layers
|
||||||
|
self.kv_lora_rank = kv_lora_rank
|
||||||
|
self.qk_rope_head_dim = qk_rope_head_dim
|
||||||
|
self.q_lora_rank = q_lora_rank
|
||||||
|
self.qk_nope_head_dim = qk_nope_head_dim
|
||||||
|
self.qk_head_dim = qk_head_dim
|
||||||
|
self.v_head_dim = v_head_dim
|
||||||
|
|
||||||
|
self.mla_attn = Attention(
|
||||||
|
num_heads=self.num_local_heads,
|
||||||
|
head_size=self.kv_lora_rank + self.qk_rope_head_dim,
|
||||||
|
scale=scaling,
|
||||||
|
num_kv_heads=1,
|
||||||
|
cache_config=cache_config,
|
||||||
|
quant_config=quant_config,
|
||||||
|
prefix=f"{prefix}.attn",
|
||||||
|
use_mla=True,
|
||||||
|
# MLA Args
|
||||||
|
q_lora_rank=self.q_lora_rank,
|
||||||
|
kv_lora_rank=self.kv_lora_rank,
|
||||||
|
qk_nope_head_dim=self.qk_nope_head_dim,
|
||||||
|
qk_rope_head_dim=self.qk_rope_head_dim,
|
||||||
|
qk_head_dim=self.qk_head_dim,
|
||||||
|
v_head_dim=self.v_head_dim,
|
||||||
|
rotary_emb=mla_modules.rotary_emb,
|
||||||
|
q_a_proj=mla_modules.q_a_proj,
|
||||||
|
q_a_layernorm=mla_modules.q_a_layernorm,
|
||||||
|
q_proj=mla_modules.q_proj,
|
||||||
|
kv_a_proj_with_mqa=mla_modules.kv_a_proj_with_mqa,
|
||||||
|
kv_a_layernorm=mla_modules.kv_a_layernorm,
|
||||||
|
kv_b_proj=mla_modules.kv_b_proj,
|
||||||
|
o_proj=mla_modules.o_proj,
|
||||||
|
)
|
||||||
|
|
||||||
|
def forward(
|
||||||
|
self,
|
||||||
|
positions: torch.Tensor,
|
||||||
|
hidden_states: torch.Tensor,
|
||||||
|
kv_cache: Optional[torch.Tensor] = None,
|
||||||
|
attn_metadata: Optional[AttentionMetadata] = None) -> torch.Tensor:
|
||||||
|
forward_context = get_forward_context()
|
||||||
|
if kv_cache is None:
|
||||||
|
kv_cache = self.mla_attn.kv_cache[forward_context.virtual_engine]
|
||||||
|
num_tokens = hidden_states.shape[0]
|
||||||
|
need_gather_q_kv = False
|
||||||
|
if self.enable_shared_expert_dp and self.debug_layer_idx > self.first_k_dense_replace and self.debug_layer_idx < self.layers:
|
||||||
|
# Simulate all gather to calculate output shape
|
||||||
|
num_tokens = num_tokens * self.tp_size
|
||||||
|
need_gather_q_kv = True
|
||||||
|
if not self.enable_shared_expert_dp or self.debug_layer_idx < self.first_k_dense_replace:
|
||||||
|
output_shape = hidden_states.shape
|
||||||
|
else:
|
||||||
|
rows = num_tokens // self.tp_size
|
||||||
|
if num_tokens % self.tp_size:
|
||||||
|
rows += 1
|
||||||
|
output_shape = (rows, hidden_states.shape[1])
|
||||||
|
output = torch.empty(output_shape,
|
||||||
|
dtype=hidden_states.dtype,
|
||||||
|
device=hidden_states.device)
|
||||||
|
output = self.mla_attn.impl.forward(hidden_states, kv_cache,
|
||||||
|
forward_context.attn_metadata,
|
||||||
|
need_gather_q_kv, output)
|
||||||
|
output = output.view(-1, output_shape[-1])
|
||||||
|
return output
|
||||||
@@ -529,6 +529,10 @@ def register_ascend_customop():
|
|||||||
from vllm_ascend.ops.common_fused_moe import AscendFusedMoE
|
from vllm_ascend.ops.common_fused_moe import AscendFusedMoE
|
||||||
CustomOp.register_oot(_decorated_op_cls=AscendFusedMoE, name="FusedMoE")
|
CustomOp.register_oot(_decorated_op_cls=AscendFusedMoE, name="FusedMoE")
|
||||||
|
|
||||||
|
from vllm_ascend.models.layers.mla import AscendMultiHeadLatentAttention
|
||||||
|
CustomOp.register_oot(_decorated_op_cls=AscendMultiHeadLatentAttention,
|
||||||
|
name="MultiHeadLatentAttention")
|
||||||
|
|
||||||
# NOTE: Keep this at last to ensure all custom actions are registered
|
# NOTE: Keep this at last to ensure all custom actions are registered
|
||||||
_ASCEND_CUSTOMOP_IS_REIGISTERED = True
|
_ASCEND_CUSTOMOP_IS_REIGISTERED = True
|
||||||
|
|
||||||
|
|||||||
Reference in New Issue
Block a user