[Refactor] Refactor 310P attention impl and add UT (#6579)

### What this PR does / why we need it?
This pull request refactors the attention implementation for Ascend 310P
hardware, separating mask generation from the core attention logic. It
introduces a dedicated mask builder class that handles the various mask
types (causal, splitfuse, and sliding window attention masks), all
optimized for the NPU's fractal data format. The change cleans up the
codebase and lays the groundwork for more robust, feature-rich attention
operations on Ascend devices, backed by new, extensive unit tests.
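
For context, a minimal sketch of the shape this separation takes: a builder object owns mask construction and caching, while the attention backend only consumes the result. Class and method names below are illustrative, not the exact API added by this PR.

```python
import torch


class CausalMaskBuilderSketch:
    """Illustrative mask builder: owns mask generation and caching."""

    def __init__(self, device, dtype=torch.float16):
        self.device = device
        self.dtype = dtype
        self._cached_mask = None
        self._cached_len = 0

    def get_causal_mask(self, seq_len: int) -> torch.Tensor:
        # Rebuild the cached mask only when a longer sequence is seen.
        if self._cached_mask is None or seq_len > self._cached_len:
            mask = torch.triu(
                torch.full((seq_len, seq_len), float("-inf"), dtype=self.dtype),
                diagonal=1,
            ).to(self.device)
            self._cached_mask, self._cached_len = mask, seq_len
        return self._cached_mask[:seq_len, :seq_len]


# The attention implementation then only asks the builder for a mask:
builder = CausalMaskBuilderSketch(device=torch.device("cpu"))
print(builder.get_causal_mask(4))
```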
### Does this PR introduce _any_ user-facing change?
No
### How was this patch tested?
E2E tests with Qwen3 and Qwen3-MoE
- vLLM version: v0.15.0
- vLLM main:
d7e17aaacd

---------

Signed-off-by: pu-zhe <zpuaa@outlook.com>
This commit is contained in:
pu-zhe
2026-02-07 09:26:26 +08:00
committed by GitHub
parent 23524f2ca4
commit 4f33e25046
5 changed files with 487 additions and 135 deletions


@@ -1,5 +1,5 @@
#
# Copyright (c) 2025 Huawei Technologies Co., Ltd. All Rights Reserved.
# Copyright (c) 2026 Huawei Technologies Co., Ltd. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@@ -17,43 +17,95 @@
from typing import Any
import torch
import torch_npu
from vllm.v1.attention.backends.registry import ( # type: ignore
AttentionBackendEnum,
register_backend,
)
from vllm_ascend._310p.attention.attention_mask import AttentionMaskBuilder, build_splitfuse_attn_mask_310p
from vllm_ascend._310p.attention.metadata_builder import AscendAttentionMetadataBuilder310P
from vllm_ascend.attention.attention_v1 import AscendAttentionBackend as _BaseBackend
from vllm_ascend.attention.attention_v1 import AscendAttentionBackendImpl as _BaseImpl
from vllm_ascend.attention.attention_v1 import AscendAttentionMetadataBuilder, AscendAttentionState, AscendMetadata
from vllm_ascend.utils import ACL_FORMAT_FRACTAL_NZ, nd_to_nz_2d
from vllm_ascend._310p.attention.attention_mask import AttentionMaskBuilder310
from vllm_ascend._310p.attention.metadata_builder import AscendAttentionMetadataBuilder310
from vllm_ascend.attention.attention_v1 import (
AscendAttentionBackend,
AscendAttentionBackendImpl,
AscendAttentionMetadataBuilder,
AscendAttentionState,
AscendMetadata,
)
class AscendAttentionBackend310(_BaseBackend):
@register_backend(AttentionBackendEnum.CUSTOM, "ASCEND")
class AscendAttentionBackend310(AscendAttentionBackend):
def __init__(self, *args, **kwargs):
"""
Initializes the 310P backend and sets up the device-specific mask builder.
"""
super().__init__(*args, **kwargs)
self.attn_mask_builder = AttentionMaskBuilder(self.device)
self.attn_mask_builder = AttentionMaskBuilder310(self.device)
@staticmethod
def get_kv_cache_shape(num_blocks: int, block_size: int, num_kv_heads: int, head_size: int):
"""
Determines the shape of the Key-Value (KV) cache tensor.
The 310P hardware requires specific memory alignment for optimal performance.
This method defines a 5D tensor shape where the head size dimension is
split to ensure alignment to multiples of 16.
Args:
num_blocks (int): Number of memory blocks.
block_size (int): Size of each block.
num_kv_heads (int): Number of KV heads.
head_size (int): Dimension size of each head.
Returns:
tuple: The specific 5D shape required by the hardware
(2, num_blocks, hidden_dim_aligned, block_size, 16).
"""
# Align to a multiple of 16, as required by the 310P device.
return (2, num_blocks, (num_kv_heads * head_size) // 16, block_size, 16)
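As a quick sanity check of the alignment rule, the formula can be exercised standalone; the dimensions below are assumed purely for illustration.

```python
def kv_cache_shape_310p(num_blocks, block_size, num_kv_heads, head_size):
    # Mirror of the formula above: the KV hidden dimension is folded into
    # (hidden // 16, ..., 16) so the innermost axis is 16-aligned.
    return (2, num_blocks, (num_kv_heads * head_size) // 16, block_size, 16)


# e.g. 1024 blocks of 128 tokens, 8 KV heads with head_size 128:
assert kv_cache_shape_310p(1024, 128, 8, 128) == (2, 1024, 64, 128, 16)
```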
@staticmethod
def get_impl_cls():
"""
Returns the implementation class for the attention operations.
"""
return AscendAttentionBackendImpl310
@staticmethod
def get_builder_cls() -> type["AscendAttentionMetadataBuilder"]:
return AscendAttentionMetadataBuilder310P
"""
Returns the metadata builder class specifically for 310P.
"""
return AscendAttentionMetadataBuilder310
class AscendAttentionBackendImpl310(_BaseImpl):
class AscendAttentionBackendImpl310(AscendAttentionBackendImpl):
"""
Implementation of attention operations (Prefill, Decode, Chunked Prefill)
optimized for the Ascend 310P architecture.
"""
def forward_paged_attention(
self,
query: Any,
attn_metadata: AscendMetadata,
output: Any | None = None,
) -> Any:
"""
Executes Paged Attention (typically for the decode phase).
Ensures that the sequence length metadata is on the correct device
before invoking the base implementation.
Args:
query (Any): The query tensor.
attn_metadata (AscendMetadata): Metadata associated with the attention request.
output (Any | None): Optional output tensor.
Returns:
Any: The result of the attention operation.
"""
if attn_metadata.seq_lens.device != query.device:
attn_metadata.seq_lens = attn_metadata.seq_lens.to(
device=query.device,
@@ -61,34 +113,34 @@ class AscendAttentionBackendImpl310(_BaseImpl):
)
return super().forward_paged_attention(query, attn_metadata, output)
def _forward_prefill_310p_fallback(self, query, key, value, attn_metadata, output):
def forward_prefill_310(self, query, key, value, attn_metadata, output):
"""
Executes Flash Attention for the prefill phase on 310P.
This method handles memory alignment padding. If the query shape implies
padding (aligned_tokens > real_tokens), it adjusts the sequence length
of the last request to account for the delta, ensuring the NPU operator
processes the data correctly.
Args:
query, key, value: Input tensors.
attn_metadata (AscendMetadata): Attention metadata containing masks and seq_lens.
output: Output tensor.
Returns:
The output tensor after flash attention.
"""
real_tokens = int(attn_metadata.seq_lens.sum().item())
seq_len = attn_metadata.seq_lens
if seq_len.dtype != torch.int32:
seq_len = seq_len.to(torch.int32)
aligned_tokens = int(query.shape[0])
delta = aligned_tokens - real_tokens
# Adjust sequence length if padding (alignment) was applied to the inputs
if delta:
seq_len = seq_len.clone()
seq_len[-1] += delta
mask = attn_metadata.attn_mask
if mask is not None and mask.dim() == 2:
max_len = int(seq_len.max().item())
aligned_len = ((max_len + 15) // 16) * 16
mask2d = mask[:aligned_len, :aligned_len].contiguous()
mask2d = mask2d.to(torch.float16)
mask_nz = nd_to_nz_2d(mask2d).contiguous()
bsz = int(seq_len.numel())
if bsz > 1:
mask_nz = mask_nz.repeat(bsz, 1, 1, 1).contiguous()
mask = torch_npu.npu_format_cast(mask_nz, ACL_FORMAT_FRACTAL_NZ)
torch_npu._npu_flash_attention(
query=query,
key=key,
@@ -100,43 +152,35 @@ class AscendAttentionBackendImpl310(_BaseImpl):
num_kv_heads=self.num_kv_heads,
out=output,
)
return output
return output[:aligned_tokens, :, :]
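The padding bookkeeping above (folding the alignment delta into the last request and rounding the mask to a 16-aligned square) can be reproduced with plain tensors; the snippet below is a simplified sketch and omits the NPU-specific ND-to-NZ cast.

```python
import torch


def adjust_for_alignment(seq_lens: torch.Tensor, aligned_tokens: int) -> torch.Tensor:
    """Fold alignment padding into the last request's length (sketch)."""
    real_tokens = int(seq_lens.sum().item())
    delta = aligned_tokens - real_tokens
    if delta:
        seq_lens = seq_lens.clone()
        seq_lens[-1] += delta
    return seq_lens


seq_lens = torch.tensor([7, 9], dtype=torch.int32)            # 16 real tokens
seq_lens = adjust_for_alignment(seq_lens, aligned_tokens=32)  # inputs padded to 32
assert seq_lens.tolist() == [7, 25]

# The 2D mask is likewise cropped to a multiple of 16 before the format cast.
max_len = int(seq_lens.max().item())
aligned_len = ((max_len + 15) // 16) * 16
assert aligned_len == 32
```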
def forward_chunked_prefill_310(self, query, attn_metadata, output):
"""
Executes SplitFuse (Chunked Prefill) attention on 310P.
def _forward_chunked_prefill_310p(self, query, attn_metadata, output):
assert attn_metadata is not None
if query.dtype == torch.float32:
query = query.to(torch.float16)
This handles scenarios where the prefill is split into chunks. It prepares
the necessary metadata (query lengths, block tables) and generates the
specific splitfuse mask before calling the NPU operator.
Args:
query: The query tensor.
attn_metadata (AscendMetadata): Metadata containing start locations and block tables.
output: The output tensor.
"""
num_actual_tokens = int(attn_metadata.num_actual_tokens)
query = query[:num_actual_tokens]
output = output[:num_actual_tokens]
qsl_cpu = attn_metadata.query_start_loc.detach().to("cpu", dtype=torch.int32)
qlens = (qsl_cpu[1:] - qsl_cpu[:-1]).to(torch.int32)
# Calculate query lengths from start locations
qsl_cpu = attn_metadata.query_start_loc.cpu()
qlens = qsl_cpu[1:] - qsl_cpu[:-1]
context_lens = attn_metadata.seq_lens
if context_lens.dtype != torch.int32:
context_lens = context_lens.to(torch.int32)
block_table = attn_metadata.block_tables
block_table = attn_metadata.block_tables.detach()
if block_table.dtype != torch.int32:
block_table = block_table.to(torch.int32)
# Generate the specific mask for splitfuse
mask = AttentionMaskBuilder310.get_splitfuse_mask(attn_metadata, query.device)
if not hasattr(self, "_sf_full_mask_cache"):
self._sf_full_mask_cache = None
self._sf_full_mask_cache_len = 0
mask, self._sf_full_mask_cache, self._sf_full_mask_cache_len = build_splitfuse_attn_mask_310p(
attn_metadata,
query.device,
full_mask_cache=self._sf_full_mask_cache,
full_mask_cache_len=int(self._sf_full_mask_cache_len),
)
if qlens.device.type != "cpu":
qlens = qlens.to("cpu")
if context_lens.device != query.device:
context_lens = context_lens.to(query.device, non_blocking=True)
@@ -155,21 +199,35 @@ class AscendAttentionBackendImpl310(_BaseImpl):
)
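The metadata preparation in this method derives per-request query lengths from the cumulative query_start_loc offsets and casts the remaining metadata to int32; a standalone illustration with made-up values:

```python
import torch

# query_start_loc holds cumulative token offsets: [0, len0, len0+len1, ...]
query_start_loc = torch.tensor([0, 4, 9, 11], dtype=torch.int32)

# Adjacent differences recover each request's query length for this step.
query_lens = query_start_loc[1:] - query_start_loc[:-1]
assert query_lens.tolist() == [4, 5, 2]

# Block tables and context lengths are cast to int32 before the NPU call.
block_tables = torch.zeros((3, 8), dtype=torch.int64).to(torch.int32)
```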
def forward_impl(self, query, key, value, kv_cache, attn_metadata, output):
"""
Main dispatch method for attention operations.
Routes the execution to Decode, Prefill, or Chunked Prefill methods
based on the current attention state found in metadata.
Args:
query, key, value: Input tensors (Key/Value usually empty for decode/chunked).
kv_cache: The KV cache structure.
attn_metadata: Metadata determining the state (Prefill vs Decode).
output: Tensor to write results to.
Returns:
The output tensor.
Raises:
NotImplementedError: If the attention state is not supported on 310P.
"""
state = attn_metadata.attn_state
if state == AscendAttentionState.DecodeOnly:
return self.forward_paged_attention(query, attn_metadata, output)
if state == AscendAttentionState.PrefillNoCache:
num_tokens = query.shape[0]
q = query[:num_tokens]
k = key[:num_tokens]
v = value[:num_tokens]
out = self._forward_prefill_310p_fallback(q, k, v, attn_metadata, output)
out = self.forward_prefill_310(query, key, value, attn_metadata, output)
return out
if state == AscendAttentionState.ChunkedPrefill:
self._forward_chunked_prefill_310p(query, attn_metadata, output)
self.forward_chunked_prefill_310(query, attn_metadata, output)
return output
raise NotImplementedError(