diff --git a/tests/ut/_310p/attention/test_attention_mask_310.py b/tests/ut/_310p/attention/test_attention_mask_310.py new file mode 100644 index 00000000..e3b1f284 --- /dev/null +++ b/tests/ut/_310p/attention/test_attention_mask_310.py @@ -0,0 +1,61 @@ +# +# Copyright (c) 2026 Huawei Technologies Co., Ltd. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from unittest.mock import MagicMock, patch + +import torch + +from tests.ut.base import TestBase +from vllm_ascend._310p.attention.attention_mask import AttentionMaskBuilder310 + + +class TestAttentionMaskBuilder310(TestBase): + def setUp(self): + self.attention_mask_builder = AttentionMaskBuilder310(torch.device("cpu")) + + def test_get_attention_mask_310_for_pooling_model(self): + model_config = MagicMock() + model_config.runner_type = "pooling" + with self.assertRaises(NotImplementedError): + self.attention_mask_builder.get_attention_mask(model_config) + + @patch("torch_npu.npu_format_cast") + def test_get_attention_mask_310(self, mock_format_cast): + mock_format_cast.side_effect = lambda x, y: x + model_config = MagicMock() + attn_mask = self.attention_mask_builder.get_attention_mask(model_config) + self.assertEqual(attn_mask.shape, (1, 128, 2048, 16)) + self.assertEqual(attn_mask[0][-1][0][-1], torch.tensor(float("-inf"), dtype=torch.float16)) + + @patch("torch_npu.npu_format_cast") + def test_get_swa_mask_310(self, mock_format_cast): + mock_format_cast.side_effect = lambda x, y: x + swa_mask = self.attention_mask_builder.get_swa_mask(torch.float16, None) + self.assertIsNone(swa_mask) + + sliding_window = 128 + swa_mask = self.attention_mask_builder.get_swa_mask(torch.float16, sliding_window) + self.assertEqual(swa_mask.shape, (1, 128, 2048, 16)) + self.assertEqual(swa_mask[0][-1][0][-1], torch.tensor(float("-inf"), dtype=torch.float16)) + self.assertEqual(swa_mask[0][0][-1][0], torch.tensor(float("-inf"), dtype=torch.float16)) + + @patch("torch_npu.npu_format_cast") + def test_get_splitfuse_attn_mask_310(self, mock_format_cast): + mock_format_cast.side_effect = lambda x, y: x + attn_metadata = MagicMock() + attn_metadata.query_start_loc = torch.tensor([0, 1, 5]) + attn_metadata.seq_lens = torch.tensor([7, 4]) + attn_mask = self.attention_mask_builder.get_splitfuse_mask(attn_metadata, torch.device("cpu")) + self.assertEqual(attn_mask.shape, (1, 128, 16, 16)) diff --git a/tests/ut/_310p/attention/test_attention_v1_310.py b/tests/ut/_310p/attention/test_attention_v1_310.py new file mode 100644 index 00000000..0794ec42 --- /dev/null +++ b/tests/ut/_310p/attention/test_attention_v1_310.py @@ -0,0 +1,160 @@ +# +# Copyright (c) 2026 Huawei Technologies Co., Ltd. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from unittest.mock import MagicMock, patch
+
+import torch
+
+from tests.ut.base import TestBase
+from vllm_ascend._310p.attention.attention_v1 import (
+    AscendAttentionBackend310,
+    AscendAttentionBackendImpl310,
+    AscendAttentionMetadataBuilder310,
+    AscendAttentionState,
+)
+
+
+class TestAscendAttentionBackend310(TestBase):
+    def setUp(self):
+        self.mock_config = MagicMock()
+        self.utils_patcher = patch("vllm_ascend.attention.utils.get_current_vllm_config", return_value=self.mock_config)
+        self.utils_patcher.start()
+
+    def test_get_impl_cls(self):
+        self.assertEqual(AscendAttentionBackend310.get_impl_cls(), AscendAttentionBackendImpl310)
+
+    def test_get_builder_cls(self):
+        self.assertEqual(AscendAttentionBackend310.get_builder_cls(), AscendAttentionMetadataBuilder310)
+
+    def test_get_kv_cache_shape(self):
+        result = AscendAttentionBackend310.get_kv_cache_shape(10, 20, 30, 40)
+        self.assertEqual(result, (2, 10, 75, 20, 16))
+
+
+class TestAscendAttentionBackendImpl310(TestBase):
+    def setUp(self):
+        self.attention_type = MagicMock()
+        self.attention_type.DECODER = "decoder"
+        self.attention_type.ENCODER = "encoder"
+        self.attn_metadata = MagicMock()
+        self.attn_metadata.return_value = "1"
+        self.mock_vllm_config = MagicMock()
+        self.layer_no_quant = MagicMock(spec=["layer_name", "_k_scale_float", "_v_scale_float"])
+        self.layer_no_quant.layer_name = "test_layer"
+        self.layer_no_quant._k_scale_float = 1.0
+        self.layer_no_quant._v_scale_float = 1.0
+        self.config_patcher = patch(
+            "vllm_ascend.attention.attention_v1.get_current_vllm_config", return_value=self.mock_vllm_config
+        )
+        self.config_patcher.start()
+        self.impl = AscendAttentionBackendImpl310(
+            num_heads=8,
+            head_size=128,
+            scale=1.0,
+            num_kv_heads=8,
+            alibi_slopes=None,
+            sliding_window=None,
+            kv_cache_dtype="float16",
+            logits_soft_cap=None,
+            attn_type=self.attention_type.DECODER,
+            kv_sharing_target_layer_name=None,
+        )
+
+    @patch("torch_npu._npu_reshape_and_cache")
+    @patch("torch_npu._npu_flash_attention")
+    @patch("vllm_ascend.attention.attention_v1.get_forward_context")
+    def test_forward_prefill_310(
+        self, mock_get_forward_context, mock_npu_npu_flash_attention, mock_npu_reshape_and_cache
+    ):
+        """Test forward pass in PrefillNoCache state"""
+        query = torch.randn(10, 8, 64)
+        key = torch.randn(10, 8, 64)
+        value = torch.randn(10, 8, 64)
+        output = torch.empty_like(query)
+        metadata = self.attn_metadata
+        metadata.attn_state = AscendAttentionState.PrefillNoCache
+        metadata.attn_mask = torch.randn(1, 1, 10, 10)
+        metadata.query_lens = torch.tensor([10])
+        metadata.seq_lens = torch.tensor([10])
+        metadata.actual_seq_lengths_q = [10]
+        metadata.block_tables = torch.zeros(1, 5, dtype=torch.long)
+        metadata.num_actual_tokens = 10
+        metadata.num_decode_tokens = 0
+        metadata.num_decodes = 0
+        metadata.num_prefills = 10
+        metadata.slot_mapping = torch.zeros(10, dtype=torch.long)
+
+        mock_get_forward_context.return_value = MagicMock(capturing=False)
+        mock_npu_npu_flash_attention.return_value = torch.ones(10, 8, 64)
+        output = self.impl.forward_prefill_310(query, key, value, metadata, output)
+
+        mock_npu_npu_flash_attention.assert_called_once()
+
+    @patch("torch_npu.npu_format_cast", return_value=torch.randn((1, 128, 16, 16), dtype=torch.float16))
+    @patch("torch_npu._npu_reshape_and_cache")
+    @patch("torch_npu._npu_paged_attention_splitfuse")
+    @patch("vllm_ascend.attention.attention_v1.get_forward_context")
+    def test_forward_chunked_prefill_310(
+        self, mock_get_forward_context, mock_npu_paged_attention_splitfuse, mock_npu_reshape_and_cache, mock_format_cast
+    ):
+        """Test forward pass in ChunkedPrefill state"""
+        query = torch.randn(5, 8, 64)
+        output = torch.empty_like(query)
+        metadata = self.attn_metadata
+        metadata.attn_state = AscendAttentionState.ChunkedPrefill
+        metadata.attn_mask = torch.randn(1, 128, 16, 16)
+        metadata.query_lens = torch.tensor([5])
+        metadata.seq_lens = torch.tensor([1, 4])
+        metadata.query_start_loc = torch.tensor([0, 1, 5])
+        metadata.actual_seq_lengths_q = [5]
+        metadata.block_tables = torch.zeros(1, 5, dtype=torch.long)
+        metadata.num_actual_tokens = 10
+        metadata.num_decode_tokens = 0
+        metadata.num_decodes = 0
+        metadata.num_prefills = 10
+        metadata.slot_mapping = torch.zeros(10, dtype=torch.long)
+
+        mock_get_forward_context.return_value = MagicMock(capturing=False)
+        mock_npu_paged_attention_splitfuse.return_value = torch.ones(5, 8, 64)
+        output = self.impl.forward_chunked_prefill_310(query, metadata, output)
+
+        mock_npu_paged_attention_splitfuse.assert_called_once()
+
+    @patch("vllm_ascend.attention.attention_v1.using_paged_attention")
+    @patch("torch_npu._npu_paged_attention")
+    @patch("torch_npu._npu_reshape_and_cache")
+    @patch("vllm_ascend.attention.attention_v1.get_forward_context")
+    def test_forward_paged_attention_310(
+        self, mock_get_forward_context, mock_npu_reshape_and_cache, mock_paged_attention, mock_using_paged_attention
+    ):
+        """Test forward pass in DecodeOnly state"""
+        query = torch.randn(4, 8 * 64)
+        output = torch.empty_like(query)
+
+        metadata = self.attn_metadata
+        metadata.attn_state = AscendAttentionState.DecodeOnly
+        metadata.seq_lens = torch.tensor([4])
+        metadata.block_tables = torch.zeros(1, 5, dtype=torch.long)
+        metadata.num_actual_tokens = 4
+        metadata.slot_mapping = torch.zeros(4, dtype=torch.long)
+        metadata.num_decodes = 4
+        metadata.num_prefills = 0
+        mock_using_paged_attention.return_value = True
+
+        mock_get_forward_context.return_value = MagicMock(capturing=False)
+
+        output = self.impl.forward_paged_attention(query, metadata, output)
+
+        mock_paged_attention.assert_called_once()
diff --git a/vllm_ascend/_310p/attention/attention_mask.py b/vllm_ascend/_310p/attention/attention_mask.py
index f0d42f31..7fec30ef 100644
--- a/vllm_ascend/_310p/attention/attention_mask.py
+++ b/vllm_ascend/_310p/attention/attention_mask.py
@@ -1,5 +1,5 @@
 #
-# Copyright (c) 2025 Huawei Technologies Co., Ltd. All Rights Reserved.
+# Copyright (c) 2026 Huawei Technologies Co., Ltd. All Rights Reserved.
 #
 # Licensed under the Apache License, Version 2.0 (the "License");
 # you may not use this file except in compliance with the License.
@@ -15,84 +15,140 @@
 # This file is a part of the vllm-ascend project.
# -from collections.abc import Callable -from typing import Any - import torch import torch_npu -import vllm_ascend.attention.attention_mask as _base_mask -from vllm_ascend.utils import ACL_FORMAT_FRACTAL_NZ, nd_to_nz_spec - -_BASE_BUILDER: Callable[[torch.device], Any] = _base_mask.AttentionMaskBuilder +from vllm_ascend.attention.attention_v1 import AscendMetadata +from vllm_ascend.utils import ACL_FORMAT_FRACTAL_NZ, nd_to_nz_2d, nd_to_nz_spec -def _gen_causal_additive_mask_fp16(max_seq_len: int, device: torch.device) -> torch.Tensor: - tril = torch.ones((max_seq_len, max_seq_len), dtype=torch.bool, device=device).tril_() - upper = ~tril - m = torch.zeros((max_seq_len, max_seq_len), dtype=torch.float16, device=device) - m.masked_fill_(upper, float("-inf")) - return m - - -def build_splitfuse_attn_mask_310p(attn_metadata, device, *, full_mask_cache=None, full_mask_cache_len=0): - qsl = attn_metadata.query_start_loc.detach().to("cpu", dtype=torch.int32) - qlens = qsl[1:] - qsl[:-1] - - context_lens = attn_metadata.seq_lens.to(dtype=torch.int32) - L = int(context_lens.max().item()) - - q_list = qlens.tolist() - c_list = context_lens.detach().to("cpu", dtype=torch.int64).tolist() - pos_list = [p for ql, cl in zip(q_list, c_list) for p in range(cl - ql, cl)] - position = torch.tensor(pos_list, dtype=torch.long, device=device) - - if full_mask_cache is None or full_mask_cache.device != device or full_mask_cache_len < L: - tril = torch.ones((L, L), dtype=torch.bool, device=device).tril_() - full = torch.zeros((L, L), dtype=torch.float16, device=device) - full.masked_fill_(~tril, float("-inf")) - full_mask_cache, full_mask_cache_len = full, L - else: - full = full_mask_cache[:L, :L].contiguous() - - rows = full.index_select(0, position).contiguous() - mask = torch_npu.npu_format_cast(nd_to_nz_spec(rows).contiguous(), ACL_FORMAT_FRACTAL_NZ) - return mask, full_mask_cache, full_mask_cache_len - - -class _AttentionMaskBuilder310P: - """ - 310P adapter: - - overrides fp16 causal additive mask generation (use -inf fp16) - - delegates all other behaviors to base AttentionMaskBuilder - - pooling runner_type is NOT supported on 310P (explicit) - """ +class AttentionMaskBuilder310: + chunked_prefill_attn_mask = None + max_seqlen = 2048 def __init__(self, device: torch.device): - self._base = _BASE_BUILDER(device) + """ + Initializes the AttentionMaskBuilder for the 310P device. - self._fp16_mask_cache: torch.Tensor | None = None - self._fp16_mask_cached_len: int = 0 + Args: + device (torch.device): The device on which tensors will be allocated. + """ + self.attn_mask_cache = None + self.device = device + self.swa_mask = None + self._seq_len_cached = 0 - def __getattr__(self, name: str) -> Any: - return getattr(self._base, name) + @staticmethod + def gen_causal_additive_mask(max_seq_len: int, device: torch.device): + """ + Generates a standard causal lower-triangular attention mask. - @property - def device(self) -> torch.device: - return self._base.device + The upper triangular part is filled with negative infinity (float("-inf")) + to mask out future tokens, while the lower triangular part is kept as 0. 
- def _get_fp16_mask(self, max_seq_len: int) -> torch.Tensor: - if self._fp16_mask_cache is None or max_seq_len > self._fp16_mask_cached_len: - self._fp16_mask_cache = _gen_causal_additive_mask_fp16(max_seq_len, self.device) - self._fp16_mask_cached_len = max_seq_len - assert self._fp16_mask_cache is not None - return self._fp16_mask_cache[:max_seq_len, :max_seq_len].contiguous() + Args: + max_seq_len (int): The maximum sequence length for the mask. + device (torch.device): The target device for the tensor. + + Returns: + torch.Tensor: A float16 tensor representing the causal mask. + """ + tril = torch.ones((max_seq_len, max_seq_len), dtype=torch.bool, device=device).tril_() + upper = ~tril + mask = torch.zeros((max_seq_len, max_seq_len), dtype=torch.float16, device=device) + mask.masked_fill_(upper, float("-inf")) + return mask + + @classmethod + def get_splitfuse_mask(cls, attn_metadata: AscendMetadata, device: torch.device): + """ + Generates and formats the attention mask for SplitFuse (chunked prefill) decoding. + + It calculates the specific indices required based on query start locations + and context lengths, selects the relevant parts from the global chunked + mask, and converts the result to the NPU-specific fractal format. + + Args: + attn_metadata (AscendMetadata): Metadata containing query start locations and sequence lengths. + device (torch.device): The device to perform operations on. + + Returns: + torch.Tensor: The splitfuse attention mask cast to ACL_FORMAT_FRACTAL_NZ. + """ + if cls.chunked_prefill_attn_mask is None: + cls.chunked_prefill_attn_mask = cls.gen_causal_additive_mask(cls.max_seqlen, device) + qsl = attn_metadata.query_start_loc.to("cpu", dtype=torch.int32) + qlens = qsl[1:] - qsl[:-1] + q_list = qlens.tolist() + context_lens = attn_metadata.seq_lens.to("cpu", dtype=torch.int32) + c_list = context_lens.tolist() + pos_list = [p for ql, cl in zip(q_list, c_list) for p in range(cl - ql, cl)] + position = torch.tensor(pos_list, dtype=torch.int32, device=device) + splitfuse_mask = cls.chunked_prefill_attn_mask.index_select(0, position) + splitfuse_mask_nz = torch_npu.npu_format_cast(nd_to_nz_spec(splitfuse_mask).contiguous(), ACL_FORMAT_FRACTAL_NZ) + return splitfuse_mask_nz + + def get_swa_mask(self, dtype: torch.dtype, sliding_window): + """ + Generates or retrieves a cached Sliding Window Attention (SWA) mask. + + This mask allows attention only within a specific local window (diagonal band), + masking out tokens that are too far in the past or in the future. + + Args: + dtype (torch.dtype): Data type of the mask. + sliding_window (int): The size of the sliding window. + + Returns: + torch.Tensor: The SWA mask cast to ACL_FORMAT_FRACTAL_NZ. + """ + assert dtype == torch.float16 + if sliding_window is not None and self.swa_mask is None: + mask = torch.ones(self.max_seqlen, self.max_seqlen, dtype=torch.bool) + triu_mask = torch.triu(mask, diagonal=1).to(self.device) + tril_mask = torch.tril(mask, -sliding_window).to(self.device) + mask = triu_mask + tril_mask + swa_mask = torch.zeros((self.max_seqlen, self.max_seqlen), dtype=dtype, device=self.device) + swa_mask.masked_fill_(mask, float("-inf")) + self.swa_mask = torch_npu.npu_format_cast(nd_to_nz_2d(swa_mask), ACL_FORMAT_FRACTAL_NZ) + return self.swa_mask def get_attention_mask(self, model_config) -> torch.Tensor: + """ + Retrieves the appropriate attention mask based on the model configuration. + + It explicitly checks for 'pooling' runner types which are not supported + on 310P hardware. 
+ + Args: + model_config: Configuration object containing runner details. + + Returns: + torch.Tensor: The causal attention mask. + + Raises: + NotImplementedError: If the runner_type is 'pooling'. + """ if getattr(model_config, "runner_type", None) == "pooling": + # TODO: pooling model will be supported soon. raise NotImplementedError("310P does not support runner_type='pooling'") - return self._get_fp16_mask(2048) + return self._get_causal_mask(self.max_seqlen) + def _get_causal_mask(self, max_seq_len: int) -> torch.Tensor: + """ + Internal method to get or update the cached causal attention mask. -def AttentionMaskBuilder(device: torch.device) -> _AttentionMaskBuilder310P: - return _AttentionMaskBuilder310P(device) + If the cache is empty or the requested length exceeds the cached length, + a new mask is generated and converted to the NPU fractal format. + + Args: + max_seq_len (int): The required sequence length. + + Returns: + torch.Tensor: The cached causal mask in ACL_FORMAT_FRACTAL_NZ. + """ + if self.attn_mask_cache is None or max_seq_len > self._seq_len_cached: + attn_mask = self.gen_causal_additive_mask(max_seq_len, self.device) + self.attn_mask_cache = torch_npu.npu_format_cast(nd_to_nz_2d(attn_mask), ACL_FORMAT_FRACTAL_NZ) + self._seq_len_cached = max_seq_len + return self.attn_mask_cache diff --git a/vllm_ascend/_310p/attention/attention_v1.py b/vllm_ascend/_310p/attention/attention_v1.py index 3637685b..ce3b8f02 100644 --- a/vllm_ascend/_310p/attention/attention_v1.py +++ b/vllm_ascend/_310p/attention/attention_v1.py @@ -1,5 +1,5 @@ # -# Copyright (c) 2025 Huawei Technologies Co., Ltd. All Rights Reserved. +# Copyright (c) 2026 Huawei Technologies Co., Ltd. All Rights Reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -17,43 +17,95 @@ from typing import Any -import torch import torch_npu +from vllm.v1.attention.backends.registry import ( # type: ignore + AttentionBackendEnum, + register_backend, +) -from vllm_ascend._310p.attention.attention_mask import AttentionMaskBuilder, build_splitfuse_attn_mask_310p -from vllm_ascend._310p.attention.metadata_builder import AscendAttentionMetadataBuilder310P -from vllm_ascend.attention.attention_v1 import AscendAttentionBackend as _BaseBackend -from vllm_ascend.attention.attention_v1 import AscendAttentionBackendImpl as _BaseImpl -from vllm_ascend.attention.attention_v1 import AscendAttentionMetadataBuilder, AscendAttentionState, AscendMetadata -from vllm_ascend.utils import ACL_FORMAT_FRACTAL_NZ, nd_to_nz_2d +from vllm_ascend._310p.attention.attention_mask import AttentionMaskBuilder310 +from vllm_ascend._310p.attention.metadata_builder import AscendAttentionMetadataBuilder310 +from vllm_ascend.attention.attention_v1 import ( + AscendAttentionBackend, + AscendAttentionBackendImpl, + AscendAttentionMetadataBuilder, + AscendAttentionState, + AscendMetadata, +) -class AscendAttentionBackend310(_BaseBackend): +@register_backend(AttentionBackendEnum.CUSTOM, "ASCEND") +class AscendAttentionBackend310(AscendAttentionBackend): def __init__(self, *args, **kwargs): + """ + Initializes the 310P backend and sets up the device-specific mask builder. 
+ """ super().__init__(*args, **kwargs) - self.attn_mask_builder = AttentionMaskBuilder(self.device) + self.attn_mask_builder = AttentionMaskBuilder310(self.device) @staticmethod def get_kv_cache_shape(num_blocks: int, block_size: int, num_kv_heads: int, head_size: int): + """ + Determines the shape of the Key-Value (KV) cache tensor. + + The 310P hardware requires specific memory alignment for optimal performance. + This method defines a 5D tensor shape where the head size dimension is + split to ensure alignment to multiples of 16. + + Args: + num_blocks (int): Number of memory blocks. + block_size (int): Size of each block. + num_kv_heads (int): Number of KV heads. + head_size (int): Dimension size of each head. + + Returns: + tuple: The specific 5D shape required by the hardware + (2, num_blocks, hidden_dim_aligned, block_size, 16). + """ # Align to a multiple of 16, as required by the 310P device. return (2, num_blocks, (num_kv_heads * head_size) // 16, block_size, 16) @staticmethod def get_impl_cls(): + """ + Returns the implementation class for the attention operations. + """ return AscendAttentionBackendImpl310 @staticmethod def get_builder_cls() -> type["AscendAttentionMetadataBuilder"]: - return AscendAttentionMetadataBuilder310P + """ + Returns the metadata builder class specifically for 310P. + """ + return AscendAttentionMetadataBuilder310 -class AscendAttentionBackendImpl310(_BaseImpl): +class AscendAttentionBackendImpl310(AscendAttentionBackendImpl): + """ + Implementation of attention operations (Prefill, Decode, Chunked Prefill) + optimized for the Ascend 310P architecture. + """ + def forward_paged_attention( self, query: Any, attn_metadata: AscendMetadata, output: Any | None = None, ) -> Any: + """ + Executes Paged Attention (typically for the decode phase). + + Ensures that the sequence length metadata is on the correct device + before invoking the base implementation. + + Args: + query (Any): The query tensor. + attn_metadata (AscendMetadata): Metadata associated with the attention request. + output (Any | None): Optional output tensor. + + Returns: + Any: The result of the attention operation. + """ if attn_metadata.seq_lens.device != query.device: attn_metadata.seq_lens = attn_metadata.seq_lens.to( device=query.device, @@ -61,34 +113,34 @@ class AscendAttentionBackendImpl310(_BaseImpl): ) return super().forward_paged_attention(query, attn_metadata, output) - def _forward_prefill_310p_fallback(self, query, key, value, attn_metadata, output): + def forward_prefill_310(self, query, key, value, attn_metadata, output): + """ + Executes Flash Attention for the prefill phase on 310P. + + This method handles memory alignment padding. If the query shape implies + padding (aligned_tokens > real_tokens), it adjusts the sequence length + of the last request to account for the delta, ensuring the NPU operator + processes the data correctly. + + Args: + query, key, value: Input tensors. + attn_metadata (AscendMetadata): Attention metadata containing masks and seq_lens. + output: Output tensor. + + Returns: + The output tensor after flash attention. 
+        """
         real_tokens = int(attn_metadata.seq_lens.sum().item())
-        seq_len = attn_metadata.seq_lens
-        if seq_len.dtype != torch.int32:
-            seq_len = seq_len.to(torch.int32)
-
         aligned_tokens = int(query.shape[0])
         delta = aligned_tokens - real_tokens
+        seq_len = attn_metadata.seq_lens
+
+        # Adjust sequence length if padding (alignment) was applied to the inputs
         if delta:
             seq_len = seq_len.clone()
             seq_len[-1] += delta
 
         mask = attn_metadata.attn_mask
-        if mask is not None and mask.dim() == 2:
-            max_len = int(seq_len.max().item())
-            aligned_len = ((max_len + 15) // 16) * 16
-
-            mask2d = mask[:aligned_len, :aligned_len].contiguous()
-            mask2d = mask2d.to(torch.float16)
-            mask_nz = nd_to_nz_2d(mask2d).contiguous()
-
-            bsz = int(seq_len.numel())
-            if bsz > 1:
-                mask_nz = mask_nz.repeat(bsz, 1, 1, 1).contiguous()
-
-            mask = torch_npu.npu_format_cast(mask_nz, ACL_FORMAT_FRACTAL_NZ)
-
         torch_npu._npu_flash_attention(
             query=query,
             key=key,
             value=value,
@@ -100,43 +152,35 @@ class AscendAttentionBackendImpl310(_BaseImpl):
             num_kv_heads=self.num_kv_heads,
             out=output,
         )
+        return output
 
-        return output[:aligned_tokens, :, :]
+    def forward_chunked_prefill_310(self, query, attn_metadata, output):
+        """
+        Executes SplitFuse (Chunked Prefill) attention on 310P.
 
-    def _forward_chunked_prefill_310p(self, query, attn_metadata, output):
-        assert attn_metadata is not None
-
-        if query.dtype == torch.float32:
-            query = query.to(torch.float16)
+        This handles scenarios where the prefill is split into chunks. It prepares
+        the necessary metadata (query lengths, block tables) and generates the
+        specific splitfuse mask before calling the NPU operator.
+
+        Args:
+            query: The query tensor.
+            attn_metadata (AscendMetadata): Metadata containing start locations and block tables.
+            output: The output tensor.
+        """
         num_actual_tokens = int(attn_metadata.num_actual_tokens)
         query = query[:num_actual_tokens]
         output = output[:num_actual_tokens]
 
-        qsl_cpu = attn_metadata.query_start_loc.detach().to("cpu", dtype=torch.int32)
-        qlens = (qsl_cpu[1:] - qsl_cpu[:-1]).to(torch.int32)
+        # Calculate query lengths from start locations
+        qsl_cpu = attn_metadata.query_start_loc.cpu()
+        qlens = qsl_cpu[1:] - qsl_cpu[:-1]
 
         context_lens = attn_metadata.seq_lens
-        if context_lens.dtype != torch.int32:
-            context_lens = context_lens.to(torch.int32)
+        block_table = attn_metadata.block_tables
 
-        block_table = attn_metadata.block_tables.detach()
-        if block_table.dtype != torch.int32:
-            block_table = block_table.to(torch.int32)
+        # Generate the specific mask for splitfuse
+        mask = AttentionMaskBuilder310.get_splitfuse_mask(attn_metadata, query.device)
 
-        if not hasattr(self, "_sf_full_mask_cache"):
-            self._sf_full_mask_cache = None
-            self._sf_full_mask_cache_len = 0
-
-        mask, self._sf_full_mask_cache, self._sf_full_mask_cache_len = build_splitfuse_attn_mask_310p(
-            attn_metadata,
-            query.device,
-            full_mask_cache=self._sf_full_mask_cache,
-            full_mask_cache_len=int(self._sf_full_mask_cache_len),
-        )
-
-        if qlens.device.type != "cpu":
-            qlens = qlens.to("cpu")
 
         if context_lens.device != query.device:
             context_lens = context_lens.to(query.device, non_blocking=True)
@@ -155,21 +199,35 @@ class AscendAttentionBackendImpl310(_BaseImpl):
         )
 
     def forward_impl(self, query, key, value, kv_cache, attn_metadata, output):
+        """
+        Main dispatch method for attention operations.
+
+        Routes the execution to Decode, Prefill, or Chunked Prefill methods
+        based on the current attention state found in metadata.
+
+        Args:
+            query, key, value: Input tensors (Key/Value usually empty for decode/chunked).
+            kv_cache: The KV cache structure.
+ attn_metadata: Metadata determining the state (Prefill vs Decode). + output: Tensor to write results to. + + Returns: + The output tensor. + + Raises: + NotImplementedError: If the attention state is not supported on 310P. + """ state = attn_metadata.attn_state if state == AscendAttentionState.DecodeOnly: return self.forward_paged_attention(query, attn_metadata, output) if state == AscendAttentionState.PrefillNoCache: - num_tokens = query.shape[0] - q = query[:num_tokens] - k = key[:num_tokens] - v = value[:num_tokens] - out = self._forward_prefill_310p_fallback(q, k, v, attn_metadata, output) + out = self.forward_prefill_310(query, key, value, attn_metadata, output) return out if state == AscendAttentionState.ChunkedPrefill: - self._forward_chunked_prefill_310p(query, attn_metadata, output) + self.forward_chunked_prefill_310(query, attn_metadata, output) return output raise NotImplementedError( diff --git a/vllm_ascend/_310p/attention/metadata_builder.py b/vllm_ascend/_310p/attention/metadata_builder.py index 71c0c650..5e43ac63 100644 --- a/vllm_ascend/_310p/attention/metadata_builder.py +++ b/vllm_ascend/_310p/attention/metadata_builder.py @@ -1,5 +1,5 @@ # -# Copyright (c) 2025 Huawei Technologies Co., Ltd. All Rights Reserved. +# Copyright (c) 2026 Huawei Technologies Co., Ltd. All Rights Reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -15,19 +15,26 @@ # This file is a part of the vllm-ascend project. # -from __future__ import annotations - from typing import Any import torch from vllm.config import VllmConfig from vllm.v1.kv_cache_interface import AttentionSpec -from vllm_ascend._310p.attention.attention_mask import AttentionMaskBuilder -from vllm_ascend.attention.attention_v1 import AscendAttentionMetadataBuilder as _BaseBuilder +from vllm_ascend._310p.attention.attention_mask import AttentionMaskBuilder310 +from vllm_ascend.attention.attention_v1 import AscendAttentionMetadataBuilder -class AscendAttentionMetadataBuilder310P(_BaseBuilder): +class AscendAttentionMetadataBuilder310(AscendAttentionMetadataBuilder): + """ + Metadata builder specialized for the Huawei Ascend 310P NPU. + + This class extends the base Ascend attention metadata builder to use + the 310P-specific attention mask builder, ensuring that masks are + generated in the correct format (FRACTAL_NZ) and logic required by + the 310P hardware. + """ + def __init__( self, kv_cache_spec: AttentionSpec, @@ -35,6 +42,16 @@ class AscendAttentionMetadataBuilder310P(_BaseBuilder): vllm_config: VllmConfig, device: torch.device, ): + """ + Initializes the metadata builder and the 310P-specific mask builder. + + Args: + kv_cache_spec (AttentionSpec): Specification for the KV cache (block size, etc.). + layer_names (list[str]): List of layer names in the model. + vllm_config (VllmConfig): Global vLLM configuration object. + device (torch.device): The device (NPU) to run operations on. + """ super().__init__(kv_cache_spec, layer_names, vllm_config, device) - self.attn_mask_builder: Any = AttentionMaskBuilder(self.device) + # Override the mask builder with the 310P-specific version + self.attn_mask_builder: Any = AttentionMaskBuilder310(self.device)
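
For reference, a minimal CPU-only sketch (not part of the patch) that mirrors the mask logic introduced in AttentionMaskBuilder310: the additive causal mask produced by gen_causal_additive_mask and the per-token row selection performed by get_splitfuse_mask. The NPU-specific steps (nd_to_nz_2d / nd_to_nz_spec and the npu_format_cast to ACL_FORMAT_FRACTAL_NZ) are omitted, so the tensors stay 2-D; the request layout matches the values used in test_get_splitfuse_attn_mask_310.

import torch

def gen_causal_additive_mask(max_seq_len: int, device: torch.device) -> torch.Tensor:
    # Positions on or below the diagonal stay 0.0; everything above is -inf,
    # so adding the mask to attention scores blocks attention to future tokens.
    tril = torch.ones((max_seq_len, max_seq_len), dtype=torch.bool, device=device).tril_()
    mask = torch.zeros((max_seq_len, max_seq_len), dtype=torch.float16, device=device)
    mask.masked_fill_(~tril, float("-inf"))
    return mask

device = torch.device("cpu")
mask = gen_causal_additive_mask(8, device)

# Splitfuse row selection: for each request, keep only the mask rows that belong to its
# new (query) tokens, i.e. positions [context_len - query_len, context_len).
query_start_loc = torch.tensor([0, 1, 5], dtype=torch.int32)  # request 0 has 1 new token, request 1 has 4
seq_lens = torch.tensor([7, 4], dtype=torch.int32)            # total context length per request
qlens = query_start_loc[1:] - query_start_loc[:-1]
positions = [p for ql, cl in zip(qlens.tolist(), seq_lens.tolist()) for p in range(cl - ql, cl)]
rows = mask.index_select(0, torch.tensor(positions, dtype=torch.int64))
print(rows.shape)  # torch.Size([5, 8]) -- one mask row per query token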