[Refactor] refactor 310p attention impl and add ut (#6579)
### What this PR does / why we need it?
This pull request significantly refactors the attention mechanism for
the Ascend 310P hardware, enhancing its architecture by separating mask
generation concerns from the core attention implementation. It
introduces a dedicated mask builder class capable of handling various
mask types, including causal, splitfuse, and sliding window attention
masks, all optimized for the NPU's fractal data format. This change not
only cleans up the codebase but also lays the groundwork for more robust
and feature-rich attention operations on Ascend devices, backed by new,
extensive unit tests.
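A minimal usage sketch of the new mask builder (a rough illustration based on the unit tests below; it assumes an Ascend NPU environment with `torch` and `torch_npu` available, and the `MagicMock` model config is only a stand-in for a real vLLM `ModelConfig`):

```python
from unittest.mock import MagicMock

import torch

from vllm_ascend._310p.attention.attention_mask import AttentionMaskBuilder310

builder = AttentionMaskBuilder310(torch.device("npu"))
model_config = MagicMock()  # stand-in config; runner_type="pooling" would raise NotImplementedError
causal_mask = builder.get_attention_mask(model_config)   # cached causal mask, cast to FRACTAL_NZ
swa_mask = builder.get_swa_mask(torch.float16, 128)      # sliding-window (band) mask
```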
### Does this PR introduce _any_ user-facing change?
No
### How was this patch tested?
E2E test with qwen3 and qwen3-moe
- vLLM version: v0.15.0
- vLLM main: d7e17aaacd
---------
Signed-off-by: pu-zhe <zpuaa@outlook.com>
tests/ut/_310p/attention/test_attention_mask_310.py (new file, 61 lines)
@@ -0,0 +1,61 @@
#
# Copyright (c) 2026 Huawei Technologies Co., Ltd. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from unittest.mock import MagicMock, patch

import torch

from tests.ut.base import TestBase
from vllm_ascend._310p.attention.attention_mask import AttentionMaskBuilder310


class TestAttentionMaskBuilder310(TestBase):
    def setUp(self):
        self.attention_mask_builder = AttentionMaskBuilder310(torch.device("cpu"))

    def test_get_attention_mask_310_for_pooling_model(self):
        model_config = MagicMock()
        model_config.runner_type = "pooling"
        with self.assertRaises(NotImplementedError):
            self.attention_mask_builder.get_attention_mask(model_config)

    @patch("torch_npu.npu_format_cast")
    def test_get_attention_mask_310(self, mock_format_cast):
        mock_format_cast.side_effect = lambda x, y: x
        model_config = MagicMock()
        attn_mask = self.attention_mask_builder.get_attention_mask(model_config)
        self.assertEqual(attn_mask.shape, (1, 128, 2048, 16))
        self.assertEqual(attn_mask[0][-1][0][-1], torch.tensor(float("-inf"), dtype=torch.float16))

    @patch("torch_npu.npu_format_cast")
    def test_get_swa_mask_310(self, mock_format_cast):
        mock_format_cast.side_effect = lambda x, y: x
        swa_mask = self.attention_mask_builder.get_swa_mask(torch.float16, None)
        self.assertIsNone(swa_mask)

        sliding_window = 128
        swa_mask = self.attention_mask_builder.get_swa_mask(torch.float16, sliding_window)
        self.assertEqual(swa_mask.shape, (1, 128, 2048, 16))
        self.assertEqual(swa_mask[0][-1][0][-1], torch.tensor(float("-inf"), dtype=torch.float16))
        self.assertEqual(swa_mask[0][0][-1][0], torch.tensor(float("-inf"), dtype=torch.float16))

    @patch("torch_npu.npu_format_cast")
    def test_get_splitfuse_attn_mask_310(self, mock_format_cast):
        mock_format_cast.side_effect = lambda x, y: x
        attn_metadata = MagicMock()
        attn_metadata.query_start_loc = torch.tensor([0, 1, 5])
        attn_metadata.seq_lens = torch.tensor([7, 4])
        attn_mask = self.attention_mask_builder.get_splitfuse_mask(attn_metadata, torch.device("cpu"))
        self.assertEqual(attn_mask.shape, (1, 128, 16, 16))
tests/ut/_310p/attention/test_attention_v1_310.py (new file, 160 lines)
@@ -0,0 +1,160 @@
#
# Copyright (c) 2026 Huawei Technologies Co., Ltd. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from unittest.mock import MagicMock, patch

import torch

from tests.ut.base import TestBase
from vllm_ascend._310p.attention.attention_v1 import (
    AscendAttentionBackend310,
    AscendAttentionBackendImpl310,
    AscendAttentionMetadataBuilder310,
    AscendAttentionState,
)


class TestAscendAttentionBackend310(TestBase):
    def setUp(self):
        self.mock_config = MagicMock()
        self.utils_patcher = patch("vllm_ascend.attention.utils.get_current_vllm_config", return_value=self.mock_config)
        self.utils_patcher.start()

    def test_get_impl_cls(self):
        self.assertEqual(AscendAttentionBackend310.get_impl_cls(), AscendAttentionBackendImpl310)

    def test_get_builder_cls(self):
        self.assertEqual(AscendAttentionBackend310.get_builder_cls(), AscendAttentionMetadataBuilder310)

    def test_get_kv_cache_shape_not(self):
        result = AscendAttentionBackend310.get_kv_cache_shape(10, 20, 30, 40)
        self.assertEqual(result, (2, 10, 75, 20, 16))


class TestAscendAttentionBackendImpl310(TestBase):
    def setUp(self):
        self.attention_type = MagicMock()
        self.attention_type.DECODER = "decoder"
        self.attention_type.ENCODER = "encoder"
        self.attn_metadata = MagicMock()
        self.attn_metadata.return_value = "1"
        self.mock_vllm_config = MagicMock()
        self.layer_no_quant = MagicMock(spec=["layer_name", "_k_scale_float", "_v_scale_float"])
        self.layer_no_quant.layer_name = "test_layer"
        self.layer_no_quant._k_scale_float = 1.0
        self.layer_no_quant._v_scale_float = 1.0
        self.config_patcher = patch(
            "vllm_ascend.attention.attention_v1.get_current_vllm_config", return_value=self.mock_vllm_config
        )
        self.config_patcher.start()
        self.impl = AscendAttentionBackendImpl310(
            num_heads=8,
            head_size=128,
            scale=1.0,
            num_kv_heads=8,
            alibi_slopes=None,
            sliding_window=None,
            kv_cache_dtype="float16",
            logits_soft_cap=None,
            attn_type=self.attention_type.DECODER,
            kv_sharing_target_layer_name=None,
        )

    @patch("torch_npu._npu_reshape_and_cache")
    @patch("torch_npu._npu_flash_attention")
    @patch("vllm_ascend.attention.attention_v1.get_forward_context")
    def test_forward_prefill_310(
        self, mock_get_forward_context, mock_npu_npu_flash_attention, mock_npu_reshape_and_cache
    ):
        """Test forward pass in PrefillNoCache state"""
        query = torch.randn(10, 8, 64)
        key = torch.randn(10, 8, 64)
        value = torch.randn(10, 8, 64)
        output = torch.empty_like(query)
        metadata = self.attn_metadata
        metadata.attn_state = AscendAttentionState.PrefillNoCache
        metadata.attn_mask = torch.randn(1, 1, 10, 10)
        metadata.query_lens = torch.tensor([10])
        metadata.seq_lens = torch.tensor([10])
        metadata.actual_seq_lengths_q = [10]
        metadata.block_tables = torch.zeros(1, 5, dtype=torch.long)
        metadata.num_actual_tokens = 10
        metadata.num_decode_tokens = 0
        metadata.num_decodes = 0
        metadata.num_prefills = 10
        metadata.slot_mapping = torch.zeros(10, dtype=torch.long)

        mock_get_forward_context.return_value = MagicMock(capturing=False)
        mock_npu_npu_flash_attention.return_value = torch.ones(10, 8, 64)
        output = self.impl.forward_prefill_310(query, key, value, metadata, output)

        mock_npu_npu_flash_attention.assert_called_once()

    @patch("torch_npu.npu_format_cast", return_value=torch.randn((1, 128, 16, 16), dtype=torch.float16))
    @patch("torch_npu._npu_reshape_and_cache")
    @patch("torch_npu._npu_paged_attention_splitfuse")
    @patch("vllm_ascend.attention.attention_v1.get_forward_context")
    def test_forward_chunked_prefill_310(
        self, mock_get_forward_context, mock_npu_paged_attention_splitfuse, mock_npu_reshape_and_cache, mock_format_cast
    ):
        """Test forward pass in ChunkedPrefill state"""
        query = torch.randn(5, 8, 64)
        output = torch.empty_like(query)
        metadata = self.attn_metadata
        metadata.attn_state = AscendAttentionState.ChunkedPrefill
        metadata.attn_mask = torch.randn(1, 128, 16, 16)
        metadata.query_lens = torch.tensor([5])
        metadata.seq_lens = torch.tensor([1, 4])
        metadata.query_start_loc = torch.tensor([0, 1, 5])
        metadata.actual_seq_lengths_q = [5]
        metadata.block_tables = torch.zeros(1, 5, dtype=torch.long)
        metadata.num_actual_tokens = 10
        metadata.num_decode_tokens = 0
        metadata.num_decodes = 0
        metadata.num_prefills = 10
        metadata.slot_mapping = torch.zeros(10, dtype=torch.long)

        mock_get_forward_context.return_value = MagicMock(capturing=False)
        mock_npu_paged_attention_splitfuse.return_value = torch.ones(5, 8, 64)
        output = self.impl.forward_chunked_prefill_310(query, metadata, output)

        mock_npu_paged_attention_splitfuse.assert_called_once()

    @patch("vllm_ascend.attention.attention_v1.using_paged_attention")
    @patch("torch_npu._npu_paged_attention")
    @patch("torch_npu._npu_reshape_and_cache")
    @patch("vllm_ascend.attention.attention_v1.get_forward_context")
    def test_forward_paged_attention_310(
        self, mock_get_forward_context, mock_npu_reshape_and_cache, mock_paged_attention, mock_using_paged_attention
    ):
        """Test forward pass in DecodeOnly state"""
        query = torch.randn(4, 8 * 64)
        output = torch.empty_like(query)

        metadata = self.attn_metadata
        metadata.attn_state = AscendAttentionState.DecodeOnly
        metadata.seq_lens = torch.tensor([4])
        metadata.block_tables = torch.zeros(1, 5, dtype=torch.long)
        metadata.num_actual_tokens = 4
        metadata.slot_mapping = torch.zeros(4, dtype=torch.long)
        metadata.num_decodes = 4
        metadata.num_prefills = 0
        mock_using_paged_attention.return_value = True

        mock_get_forward_context.return_value = MagicMock(capturing=False)

        output = self.impl.forward_paged_attention(query, metadata, output)

        mock_paged_attention.assert_called_once()
@@ -1,5 +1,5 @@
#
# Copyright (c) 2025 Huawei Technologies Co., Ltd. All Rights Reserved.
# Copyright (c) 2026 Huawei Technologies Co., Ltd. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@@ -15,84 +15,140 @@
# This file is a part of the vllm-ascend project.
#

from collections.abc import Callable
from typing import Any

import torch
import torch_npu

import vllm_ascend.attention.attention_mask as _base_mask
from vllm_ascend.utils import ACL_FORMAT_FRACTAL_NZ, nd_to_nz_spec

_BASE_BUILDER: Callable[[torch.device], Any] = _base_mask.AttentionMaskBuilder
from vllm_ascend.attention.attention_v1 import AscendMetadata
from vllm_ascend.utils import ACL_FORMAT_FRACTAL_NZ, nd_to_nz_2d, nd_to_nz_spec


def _gen_causal_additive_mask_fp16(max_seq_len: int, device: torch.device) -> torch.Tensor:
tril = torch.ones((max_seq_len, max_seq_len), dtype=torch.bool, device=device).tril_()
upper = ~tril
m = torch.zeros((max_seq_len, max_seq_len), dtype=torch.float16, device=device)
m.masked_fill_(upper, float("-inf"))
return m


def build_splitfuse_attn_mask_310p(attn_metadata, device, *, full_mask_cache=None, full_mask_cache_len=0):
qsl = attn_metadata.query_start_loc.detach().to("cpu", dtype=torch.int32)
qlens = qsl[1:] - qsl[:-1]

context_lens = attn_metadata.seq_lens.to(dtype=torch.int32)
L = int(context_lens.max().item())

q_list = qlens.tolist()
c_list = context_lens.detach().to("cpu", dtype=torch.int64).tolist()
pos_list = [p for ql, cl in zip(q_list, c_list) for p in range(cl - ql, cl)]
position = torch.tensor(pos_list, dtype=torch.long, device=device)

if full_mask_cache is None or full_mask_cache.device != device or full_mask_cache_len < L:
tril = torch.ones((L, L), dtype=torch.bool, device=device).tril_()
full = torch.zeros((L, L), dtype=torch.float16, device=device)
full.masked_fill_(~tril, float("-inf"))
full_mask_cache, full_mask_cache_len = full, L
else:
full = full_mask_cache[:L, :L].contiguous()

rows = full.index_select(0, position).contiguous()
mask = torch_npu.npu_format_cast(nd_to_nz_spec(rows).contiguous(), ACL_FORMAT_FRACTAL_NZ)
return mask, full_mask_cache, full_mask_cache_len


class _AttentionMaskBuilder310P:
"""
310P adapter:
- overrides fp16 causal additive mask generation (use -inf fp16)
- delegates all other behaviors to base AttentionMaskBuilder
- pooling runner_type is NOT supported on 310P (explicit)
"""
class AttentionMaskBuilder310:
chunked_prefill_attn_mask = None
max_seqlen = 2048

def __init__(self, device: torch.device):
self._base = _BASE_BUILDER(device)
"""
Initializes the AttentionMaskBuilder for the 310P device.

self._fp16_mask_cache: torch.Tensor | None = None
self._fp16_mask_cached_len: int = 0
Args:
device (torch.device): The device on which tensors will be allocated.
"""
self.attn_mask_cache = None
self.device = device
self.swa_mask = None
self._seq_len_cached = 0

def __getattr__(self, name: str) -> Any:
return getattr(self._base, name)
@staticmethod
def gen_causal_additive_mask(max_seq_len: int, device: torch.device):
"""
Generates a standard causal lower-triangular attention mask.

@property
def device(self) -> torch.device:
return self._base.device
The upper triangular part is filled with negative infinity (float("-inf"))
to mask out future tokens, while the lower triangular part is kept as 0.

def _get_fp16_mask(self, max_seq_len: int) -> torch.Tensor:
if self._fp16_mask_cache is None or max_seq_len > self._fp16_mask_cached_len:
self._fp16_mask_cache = _gen_causal_additive_mask_fp16(max_seq_len, self.device)
self._fp16_mask_cached_len = max_seq_len
assert self._fp16_mask_cache is not None
return self._fp16_mask_cache[:max_seq_len, :max_seq_len].contiguous()
Args:
max_seq_len (int): The maximum sequence length for the mask.
device (torch.device): The target device for the tensor.

Returns:
torch.Tensor: A float16 tensor representing the causal mask.
"""
tril = torch.ones((max_seq_len, max_seq_len), dtype=torch.bool, device=device).tril_()
upper = ~tril
mask = torch.zeros((max_seq_len, max_seq_len), dtype=torch.float16, device=device)
mask.masked_fill_(upper, float("-inf"))
return mask
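# For example, gen_causal_additive_mask(3, device) produces
#     [[0., -inf, -inf],
#      [0.,   0., -inf],
#      [0.,   0.,   0.]]
# so each query position attends only to itself and earlier positions.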

@classmethod
def get_splitfuse_mask(cls, attn_metadata: AscendMetadata, device: torch.device):
"""
Generates and formats the attention mask for SplitFuse (chunked prefill) decoding.

It calculates the specific indices required based on query start locations
and context lengths, selects the relevant parts from the global chunked
mask, and converts the result to the NPU-specific fractal format.

Args:
attn_metadata (AscendMetadata): Metadata containing query start locations and sequence lengths.
device (torch.device): The device to perform operations on.

Returns:
torch.Tensor: The splitfuse attention mask cast to ACL_FORMAT_FRACTAL_NZ.
"""
if cls.chunked_prefill_attn_mask is None:
cls.chunked_prefill_attn_mask = cls.gen_causal_additive_mask(cls.max_seqlen, device)
qsl = attn_metadata.query_start_loc.to("cpu", dtype=torch.int32)
qlens = qsl[1:] - qsl[:-1]
q_list = qlens.tolist()
context_lens = attn_metadata.seq_lens.to("cpu", dtype=torch.int32)
c_list = context_lens.tolist()
pos_list = [p for ql, cl in zip(q_list, c_list) for p in range(cl - ql, cl)]
position = torch.tensor(pos_list, dtype=torch.int32, device=device)
splitfuse_mask = cls.chunked_prefill_attn_mask.index_select(0, position)
splitfuse_mask_nz = torch_npu.npu_format_cast(nd_to_nz_spec(splitfuse_mask).contiguous(), ACL_FORMAT_FRACTAL_NZ)
return splitfuse_mask_nz
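# Worked example (values from the unit test above): query_start_loc = [0, 1, 5]
# and seq_lens = [7, 4] give per-request query lengths [1, 4], so the selected
# row positions are range(6, 7) + range(0, 4) = [6, 0, 1, 2, 3]; each query
# token picks the causal-mask row for its absolute position in its own context
# before the selected rows are cast to the FRACTAL_NZ layout.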

def get_swa_mask(self, dtype: torch.dtype, sliding_window):
"""
Generates or retrieves a cached Sliding Window Attention (SWA) mask.

This mask allows attention only within a specific local window (diagonal band),
masking out tokens that are too far in the past or in the future.

Args:
dtype (torch.dtype): Data type of the mask.
sliding_window (int): The size of the sliding window.

Returns:
torch.Tensor: The SWA mask cast to ACL_FORMAT_FRACTAL_NZ.
"""
assert dtype == torch.float16
if sliding_window is not None and self.swa_mask is None:
mask = torch.ones(self.max_seqlen, self.max_seqlen, dtype=torch.bool)
triu_mask = torch.triu(mask, diagonal=1).to(self.device)
tril_mask = torch.tril(mask, -sliding_window).to(self.device)
mask = triu_mask + tril_mask
swa_mask = torch.zeros((self.max_seqlen, self.max_seqlen), dtype=dtype, device=self.device)
swa_mask.masked_fill_(mask, float("-inf"))
self.swa_mask = torch_npu.npu_format_cast(nd_to_nz_2d(swa_mask), ACL_FORMAT_FRACTAL_NZ)
return self.swa_mask
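# Rough illustration of the band construction above: with a window of size w,
# triu(mask, 1) bans future positions and tril(mask, -w) bans positions w or
# more steps in the past, so row i keeps only columns i-w+1 .. i unmasked
# (for a toy case with max_seqlen = 6 and w = 3, row 4 attends to columns 2-4).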

def get_attention_mask(self, model_config) -> torch.Tensor:
"""
Retrieves the appropriate attention mask based on the model configuration.

It explicitly checks for 'pooling' runner types which are not supported
on 310P hardware.

Args:
model_config: Configuration object containing runner details.

Returns:
torch.Tensor: The causal attention mask.

Raises:
NotImplementedError: If the runner_type is 'pooling'.
"""
if getattr(model_config, "runner_type", None) == "pooling":
# TODO: pooling model will be supported soon.
raise NotImplementedError("310P does not support runner_type='pooling'")
return self._get_fp16_mask(2048)
return self._get_causal_mask(self.max_seqlen)

def _get_causal_mask(self, max_seq_len: int) -> torch.Tensor:
"""
Internal method to get or update the cached causal attention mask.

def AttentionMaskBuilder(device: torch.device) -> _AttentionMaskBuilder310P:
return _AttentionMaskBuilder310P(device)
If the cache is empty or the requested length exceeds the cached length,
a new mask is generated and converted to the NPU fractal format.

Args:
max_seq_len (int): The required sequence length.

Returns:
torch.Tensor: The cached causal mask in ACL_FORMAT_FRACTAL_NZ.
"""
if self.attn_mask_cache is None or max_seq_len > self._seq_len_cached:
attn_mask = self.gen_causal_additive_mask(max_seq_len, self.device)
self.attn_mask_cache = torch_npu.npu_format_cast(nd_to_nz_2d(attn_mask), ACL_FORMAT_FRACTAL_NZ)
self._seq_len_cached = max_seq_len
return self.attn_mask_cache
@@ -1,5 +1,5 @@
#
# Copyright (c) 2025 Huawei Technologies Co., Ltd. All Rights Reserved.
# Copyright (c) 2026 Huawei Technologies Co., Ltd. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@@ -17,43 +17,95 @@

from typing import Any

import torch
import torch_npu
from vllm.v1.attention.backends.registry import ( # type: ignore
AttentionBackendEnum,
register_backend,
)

from vllm_ascend._310p.attention.attention_mask import AttentionMaskBuilder, build_splitfuse_attn_mask_310p
from vllm_ascend._310p.attention.metadata_builder import AscendAttentionMetadataBuilder310P
from vllm_ascend.attention.attention_v1 import AscendAttentionBackend as _BaseBackend
from vllm_ascend.attention.attention_v1 import AscendAttentionBackendImpl as _BaseImpl
from vllm_ascend.attention.attention_v1 import AscendAttentionMetadataBuilder, AscendAttentionState, AscendMetadata
from vllm_ascend.utils import ACL_FORMAT_FRACTAL_NZ, nd_to_nz_2d
from vllm_ascend._310p.attention.attention_mask import AttentionMaskBuilder310
from vllm_ascend._310p.attention.metadata_builder import AscendAttentionMetadataBuilder310
from vllm_ascend.attention.attention_v1 import (
AscendAttentionBackend,
AscendAttentionBackendImpl,
AscendAttentionMetadataBuilder,
AscendAttentionState,
AscendMetadata,
)


class AscendAttentionBackend310(_BaseBackend):
@register_backend(AttentionBackendEnum.CUSTOM, "ASCEND")
class AscendAttentionBackend310(AscendAttentionBackend):
def __init__(self, *args, **kwargs):
"""
Initializes the 310P backend and sets up the device-specific mask builder.
"""
super().__init__(*args, **kwargs)
self.attn_mask_builder = AttentionMaskBuilder(self.device)
self.attn_mask_builder = AttentionMaskBuilder310(self.device)

@staticmethod
def get_kv_cache_shape(num_blocks: int, block_size: int, num_kv_heads: int, head_size: int):
"""
Determines the shape of the Key-Value (KV) cache tensor.

The 310P hardware requires specific memory alignment for optimal performance.
This method defines a 5D tensor shape where the head size dimension is
split to ensure alignment to multiples of 16.

Args:
num_blocks (int): Number of memory blocks.
block_size (int): Size of each block.
num_kv_heads (int): Number of KV heads.
head_size (int): Dimension size of each head.

Returns:
tuple: The specific 5D shape required by the hardware
(2, num_blocks, hidden_dim_aligned, block_size, 16).
"""
# Align to a multiple of 16, as required by the 310P device.
return (2, num_blocks, (num_kv_heads * head_size) // 16, block_size, 16)
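# Worked example (matches the unit test above): num_blocks=10, block_size=20,
# num_kv_heads=30 and head_size=40 give a hidden dim of 30 * 40 = 1200, which
# splits into 1200 // 16 = 75 chunks of 16, so the shape is (2, 10, 75, 20, 16).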

@staticmethod
def get_impl_cls():
"""
Returns the implementation class for the attention operations.
"""
return AscendAttentionBackendImpl310

@staticmethod
def get_builder_cls() -> type["AscendAttentionMetadataBuilder"]:
return AscendAttentionMetadataBuilder310P
"""
Returns the metadata builder class specifically for 310P.
"""
return AscendAttentionMetadataBuilder310


class AscendAttentionBackendImpl310(_BaseImpl):
class AscendAttentionBackendImpl310(AscendAttentionBackendImpl):
"""
Implementation of attention operations (Prefill, Decode, Chunked Prefill)
optimized for the Ascend 310P architecture.
"""

def forward_paged_attention(
self,
query: Any,
attn_metadata: AscendMetadata,
output: Any | None = None,
) -> Any:
"""
Executes Paged Attention (typically for the decode phase).

Ensures that the sequence length metadata is on the correct device
before invoking the base implementation.

Args:
query (Any): The query tensor.
attn_metadata (AscendMetadata): Metadata associated with the attention request.
output (Any | None): Optional output tensor.

Returns:
Any: The result of the attention operation.
"""
if attn_metadata.seq_lens.device != query.device:
attn_metadata.seq_lens = attn_metadata.seq_lens.to(
device=query.device,
@@ -61,34 +113,34 @@ class AscendAttentionBackendImpl310(_BaseImpl):
)
return super().forward_paged_attention(query, attn_metadata, output)

def _forward_prefill_310p_fallback(self, query, key, value, attn_metadata, output):
def forward_prefill_310(self, query, key, value, attn_metadata, output):
"""
Executes Flash Attention for the prefill phase on 310P.

This method handles memory alignment padding. If the query shape implies
padding (aligned_tokens > real_tokens), it adjusts the sequence length
of the last request to account for the delta, ensuring the NPU operator
processes the data correctly.

Args:
query, key, value: Input tensors.
attn_metadata (AscendMetadata): Attention metadata containing masks and seq_lens.
output: Output tensor.

Returns:
The output tensor after flash attention.
"""
real_tokens = int(attn_metadata.seq_lens.sum().item())

seq_len = attn_metadata.seq_lens
if seq_len.dtype != torch.int32:
seq_len = seq_len.to(torch.int32)

aligned_tokens = int(query.shape[0])
delta = aligned_tokens - real_tokens

# Adjust sequence length if padding (alignment) was applied to the inputs
if delta:
seq_len = seq_len.clone()
seq_len[-1] += delta

mask = attn_metadata.attn_mask
if mask is not None and mask.dim() == 2:
max_len = int(seq_len.max().item())
aligned_len = ((max_len + 15) // 16) * 16

mask2d = mask[:aligned_len, :aligned_len].contiguous()
mask2d = mask2d.to(torch.float16)
mask_nz = nd_to_nz_2d(mask2d).contiguous()

bsz = int(seq_len.numel())
if bsz > 1:
mask_nz = mask_nz.repeat(bsz, 1, 1, 1).contiguous()

mask = torch_npu.npu_format_cast(mask_nz, ACL_FORMAT_FRACTAL_NZ)

torch_npu._npu_flash_attention(
query=query,
key=key,
@@ -100,43 +152,35 @@ class AscendAttentionBackendImpl310(_BaseImpl):
num_kv_heads=self.num_kv_heads,
out=output,
)
return output

return output[:aligned_tokens, :, :]
def forward_chunked_prefill_310(self, query, attn_metadata, output):
"""
Executes SplitFuse (Chunked Prefill) attention on 310P.

def _forward_chunked_prefill_310p(self, query, attn_metadata, output):
assert attn_metadata is not None

if query.dtype == torch.float32:
query = query.to(torch.float16)
This handles scenarios where the prefill is split into chunks. It prepares
the necessary metadata (query lengths, block tables) and generates the
specific splitfuse mask before calling the NPU operator.

Args:
query: The query tensor.
attn_metadata (AscendMetadata): Metadata containing start locations and block tables.
output: The output tensor.
"""
num_actual_tokens = int(attn_metadata.num_actual_tokens)
query = query[:num_actual_tokens]
output = output[:num_actual_tokens]

qsl_cpu = attn_metadata.query_start_loc.detach().to("cpu", dtype=torch.int32)
qlens = (qsl_cpu[1:] - qsl_cpu[:-1]).to(torch.int32)
# Calculate query lengths from start locations
qsl_cpu = attn_metadata.query_start_loc.cpu()
qlens = qsl_cpu[1:] - qsl_cpu[:-1]

context_lens = attn_metadata.seq_lens
if context_lens.dtype != torch.int32:
context_lens = context_lens.to(torch.int32)
block_table = attn_metadata.block_tables

block_table = attn_metadata.block_tables.detach()
if block_table.dtype != torch.int32:
block_table = block_table.to(torch.int32)
# Generate the specific mask for splitfuse
mask = AttentionMaskBuilder310.get_splitfuse_mask(attn_metadata, query.device)

if not hasattr(self, "_sf_full_mask_cache"):
self._sf_full_mask_cache = None
self._sf_full_mask_cache_len = 0

mask, self._sf_full_mask_cache, self._sf_full_mask_cache_len = build_splitfuse_attn_mask_310p(
attn_metadata,
query.device,
full_mask_cache=self._sf_full_mask_cache,
full_mask_cache_len=int(self._sf_full_mask_cache_len),
)

if qlens.device.type != "cpu":
qlens = qlens.to("cpu")
if context_lens.device != query.device:
context_lens = context_lens.to(query.device, non_blocking=True)

@@ -155,21 +199,35 @@ class AscendAttentionBackendImpl310(_BaseImpl):
)

def forward_impl(self, query, key, value, kv_cache, attn_metadata, output):
"""
Main dispatch method for attention operations.

Routes the execution to Decode, Prefill, or Chunked Prefill methods
based on the current attention state found in metadata.

Args:
query, key, value: Input tensors (Key/Value usually empty for decode/chunked).
kv_cache: The KV cache structure.
attn_metadata: Metadata determining the state (Prefill vs Decode).
output: Tensor to write results to.

Returns:
The output tensor.

Raises:
NotImplementedError: If the attention state is not supported on 310P.
"""
state = attn_metadata.attn_state

if state == AscendAttentionState.DecodeOnly:
return self.forward_paged_attention(query, attn_metadata, output)

if state == AscendAttentionState.PrefillNoCache:
num_tokens = query.shape[0]
q = query[:num_tokens]
k = key[:num_tokens]
v = value[:num_tokens]
out = self._forward_prefill_310p_fallback(q, k, v, attn_metadata, output)
out = self.forward_prefill_310(query, key, value, attn_metadata, output)
return out

if state == AscendAttentionState.ChunkedPrefill:
self._forward_chunked_prefill_310p(query, attn_metadata, output)
self.forward_chunked_prefill_310(query, attn_metadata, output)
return output

raise NotImplementedError(
@@ -1,5 +1,5 @@
#
# Copyright (c) 2025 Huawei Technologies Co., Ltd. All Rights Reserved.
# Copyright (c) 2026 Huawei Technologies Co., Ltd. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@@ -15,19 +15,26 @@
# This file is a part of the vllm-ascend project.
#

from __future__ import annotations

from typing import Any

import torch
from vllm.config import VllmConfig
from vllm.v1.kv_cache_interface import AttentionSpec

from vllm_ascend._310p.attention.attention_mask import AttentionMaskBuilder
from vllm_ascend.attention.attention_v1 import AscendAttentionMetadataBuilder as _BaseBuilder
from vllm_ascend._310p.attention.attention_mask import AttentionMaskBuilder310
from vllm_ascend.attention.attention_v1 import AscendAttentionMetadataBuilder


class AscendAttentionMetadataBuilder310P(_BaseBuilder):
class AscendAttentionMetadataBuilder310(AscendAttentionMetadataBuilder):
"""
Metadata builder specialized for the Huawei Ascend 310P NPU.

This class extends the base Ascend attention metadata builder to use
the 310P-specific attention mask builder, ensuring that masks are
generated in the correct format (FRACTAL_NZ) and logic required by
the 310P hardware.
"""

def __init__(
self,
kv_cache_spec: AttentionSpec,
@@ -35,6 +42,16 @@ class AscendAttentionMetadataBuilder310P(_BaseBuilder):
vllm_config: VllmConfig,
device: torch.device,
):
"""
Initializes the metadata builder and the 310P-specific mask builder.

Args:
kv_cache_spec (AttentionSpec): Specification for the KV cache (block size, etc.).
layer_names (list[str]): List of layer names in the model.
vllm_config (VllmConfig): Global vLLM configuration object.
device (torch.device): The device (NPU) to run operations on.
"""
super().__init__(kv_cache_spec, layer_names, vllm_config, device)

self.attn_mask_builder: Any = AttentionMaskBuilder(self.device)
# Override the mask builder with the 310P-specific version
self.attn_mask_builder: Any = AttentionMaskBuilder310(self.device)