From 3158742a9751234434b1f9caeefe4f9b85f530d9 Mon Sep 17 00:00:00 2001 From: Yizhou <136800916+yiz-liu@users.noreply.github.com> Date: Sat, 25 Oct 2025 08:58:35 +0800 Subject: [PATCH] [Refactor] Refactor Ascend attention implementation forward (#3714) ### What this PR does / why we need it? This PR refactors the Ascend attention implementation to align with vLLM's core interfaces, simplifying the code and improving maintainability. ### Key Changes: * **Align with vLLM's Attention Interface**: The `forward` method signature in `AscendAttentionBackendImpl` now matches the base `AttentionImpl` in vLLM, removing the custom `trace_flag`. * **Enable Opaque Attention Operator**: By adding `opaque_attention_op` to `AscendPlatform`, we allow vLLM to wrap our attention kernel in its standard `vllm.unified_attention_with_output` operator. This avoids the need for a custom call path. * **Remove Obsolete Code**: * The custom op `vllm.unified_ascend_attention_with_output` has been deleted as it is now redundant. * The `trace_flag` and its associated logic were removed, reducing code complexity. * An outdated quantization branch within the attention implementation was cleaned up. * **Improve Readability**: Renamed output variables (`output` vs. `intermediate_output`) and added comments to clarify the in-place nature of the attention output. ### Does this PR introduce _any_ user-facing change? None. ### How was this patch tested? No extra tests needed. - vLLM version: v0.11.0rc3 - vLLM main: https://github.com/vllm-project/vllm/commit/17c540a993af88204ad1b78345c8a865cf58ce44 --------- Signed-off-by: Yizhou Liu --- tests/ut/attention/test_attention_v1.py | 179 +++------ tests/ut/torchair/test_torchair_attention.py | 2 +- vllm_ascend/attention/attention_v1.py | 344 +++++++----------- vllm_ascend/platform.py | 9 +- vllm_ascend/torchair/models/qwen2.py | 1 - vllm_ascend/torchair/models/qwen3_moe.py | 1 - .../torchair/models/torchair_pangu_moe.py | 2 +- vllm_ascend/torchair/torchair_attention.py | 1 - vllm_ascend/worker/model_runner_v1.py | 1 - 9 files changed, 191 insertions(+), 349 deletions(-) diff --git a/tests/ut/attention/test_attention_v1.py b/tests/ut/attention/test_attention_v1.py index 20e09782..237fd299 100644 --- a/tests/ut/attention/test_attention_v1.py +++ b/tests/ut/attention/test_attention_v1.py @@ -255,59 +255,6 @@ class TestAscendAttentionBackendImpl(TestBase): attn_type=self.attention_type.DECODER, kv_sharing_target_layer_name=None) - @patch('torch.ops.vllm.unified_ascend_attention_with_output') - def test_forward_trace_flag_true(self, mock_unified_attention): - """Test forward pass when trace_flag is True""" - query = torch.randn(10, 8 * 64) - key = torch.randn(10, 8 * 64) - value = torch.randn(10, 8 * 64) - kv_cache = torch.empty(2, 0, 0, 8, 64) - metadata = self.attn_metadata - layer = self.layer - - output = self.impl.forward(layer, - query, - key, - value, - kv_cache, - metadata, - trace_flag=True) - - mock_unified_attention.assert_called_once() - assert output.shape == (10, 8 * 64) - - @patch('torch_npu._npu_paged_attention_splitfuse') - def test_forward_with_quant_method(self, mock_paged_attention): - """Test forward pass when layer has quant_method""" - query = torch.randn(10, 8 * 64) - key = torch.randn(10, 8 * 64) - value = torch.randn(10, 8 * 64) - k_cache = torch.ones(1, 10, 8, 64, dtype=torch.int8) - v_cache = torch.ones(1, 10, 8, 64, dtype=torch.int8) - kv_cache = [k_cache, v_cache] - ret_value = torch.ones(1, 1, 10, 8, 64, dtype=torch.int8) - - metadata = MagicMock() - 
metadata.num_actual_tokens = torch.randn(10, 8 * 64) - metadata.block_tables = torch.randn(10, 8 * 64) - metadata.seq_lens = torch.randn(10, 8 * 64) - metadata.attn_mask = torch.randn(10, 8 * 64) - metadata.query_lens = torch.randn(10, 8 * 64) - layer = self.layer - layer.quant_method = MagicMock() - layer.quant_method.apply.return_value = ret_value - - output = self.impl.forward(layer, - query, - key, - value, - kv_cache, - metadata, - trace_flag=False) - - layer.quant_method.apply.assert_called_once() - assert output.shape == (10, 8 * 64) - def test_forward_no_attn_metadata(self): """Test forward pass when attn_metadata is None""" query = torch.randn(10, 8 * 64) @@ -315,14 +262,10 @@ class TestAscendAttentionBackendImpl(TestBase): value = torch.randn(10, 8 * 64) kv_cache = torch.empty(2, 0, 0, 8, 64) layer = self.layer_no_quant + output = torch.empty_like(query) - output = self.impl.forward(layer, - query, - key, - value, - kv_cache, - None, - trace_flag=False) + output = self.impl.forward(layer, query, key, value, kv_cache, None, + output) assert output.shape == (10, 8 * 64) @@ -335,6 +278,8 @@ class TestAscendAttentionBackendImpl(TestBase): key = torch.randn(10, 8 * 64) value = torch.randn(10, 8 * 64) kv_cache = torch.empty(2, 5, 128, 8, 64) + output = torch.empty_like(query) + metadata = self.attn_metadata metadata.attn_state = AscendAttentionState.PrefillNoCache metadata.attn_mask = torch.randn(1, 1, 10, 10) @@ -344,15 +289,9 @@ class TestAscendAttentionBackendImpl(TestBase): metadata.num_decodes = 0 metadata.num_prefills = 10 layer = self.layer_no_quant - # layer.quant_method.apply.return_value = metadata - print(self.layer_no_quant._v_scale_float) - output = self.impl.forward(layer, - query, - key, - value, - kv_cache, - metadata, - trace_flag=False) + + output = self.impl.forward(layer, query, key, value, kv_cache, + metadata, output) mock_reshape_cache.assert_called_once() mock_flash_attention.assert_called_once() @@ -367,6 +306,8 @@ class TestAscendAttentionBackendImpl(TestBase): key = torch.randn(10, 8 * 64) value = torch.randn(10, 8 * 64) kv_cache = torch.empty(2, 5, 128, 8, 64) + output = torch.empty_like(query) + metadata = self.attn_metadata metadata.attn_state = AscendAttentionState.PrefillCacheHit metadata.attn_mask = torch.randn(1, 1, 10, 10) @@ -379,13 +320,8 @@ class TestAscendAttentionBackendImpl(TestBase): metadata.num_prefills = 10 layer = self.layer_no_quant - output = self.impl.forward(layer, - query, - key, - value, - kv_cache, - metadata, - trace_flag=False) + output = self.impl.forward(layer, query, key, value, kv_cache, + metadata, output) mock_flash_attention_qlens.assert_called_once() assert output.shape == (10, 8 * 64) @@ -401,6 +337,8 @@ class TestAscendAttentionBackendImpl(TestBase): key = torch.randn(10, 8 * 64) value = torch.randn(10, 8 * 64) kv_cache = torch.empty(2, 5, 128, 8, 64) + output = torch.empty_like(query) + metadata = self.attn_metadata metadata.attn_state = AscendAttentionState.DecodeOnly metadata.seq_lens = torch.tensor([10]) @@ -413,13 +351,8 @@ class TestAscendAttentionBackendImpl(TestBase): mock_get_forward_context.return_value = MagicMock(capturing=False) - output = self.impl.forward(layer, - query, - key, - value, - kv_cache, - metadata, - trace_flag=False) + output = self.impl.forward(layer, query, key, value, kv_cache, + metadata, output) mock_paged_attention.assert_called_once() assert output.shape == (10, 8 * 64) @@ -509,6 +442,8 @@ class TestAscendAttentionBackendImpl(TestBase): key = torch.randn(10, 8 * 64) value = 
torch.randn(10, 8 * 64) kv_cache = torch.empty(2, 5, 128, 8, 64) + output = torch.empty_like(query) + metadata = self.attn_metadata metadata.attn_state = AscendAttentionState.DecodeOnly metadata.seq_lens = torch.tensor([10]) @@ -522,13 +457,8 @@ class TestAscendAttentionBackendImpl(TestBase): mock_get_forward_context.return_value = MagicMock(capturing=True) mock_get_graph_params.return_value = graph_params - output = self.impl.forward(layer, - query, - key, - value, - kv_cache, - metadata, - trace_flag=False) + output = self.impl.forward(layer, query, key, value, kv_cache, + metadata, output) mock_paged_attention.assert_called_once() self.assertEqual(len(graph_params.handles[num_tokens]), 0) @@ -542,6 +472,8 @@ class TestAscendAttentionBackendImpl(TestBase): key = torch.randn(10, 8 * 64) value = torch.randn(10, 8 * 64) kv_cache = torch.empty(2, 5, 128, 8, 64) + output = torch.empty(10, 8, 64) + metadata = self.attn_metadata metadata.attn_state = AscendAttentionState.DecodeOnly metadata.seq_lens = torch.tensor([10] * 10) @@ -553,16 +485,11 @@ class TestAscendAttentionBackendImpl(TestBase): layer = self.layer_no_quant mock_fused_infer_attention_score.return_value = (torch.ones(10, 8, 64), 1) - output = self.impl_swa.forward(layer, - query, - key, - value, - kv_cache, - metadata, - trace_flag=False) + output = self.impl_swa.forward(layer, query, key, value, kv_cache, + metadata, output) print(output.shape) mock_fused_infer_attention_score.assert_called_once() - assert output.shape == (10, 8 * 64) + assert output.shape == (10, 8, 64) @patch('vllm_ascend.attention.attention_v1.get_forward_context') @patch('torch_npu._npu_reshape_and_cache') @@ -576,6 +503,7 @@ class TestAscendAttentionBackendImpl(TestBase): key = torch.randn(10, 8 * 64) value = torch.randn(10, 8 * 64) kv_cache = torch.empty(2, 5, 128, 8, 64) + output = torch.empty_like(query) metadata = self.attn_metadata metadata.attn_state = AscendAttentionState.DecodeOnly @@ -583,6 +511,7 @@ class TestAscendAttentionBackendImpl(TestBase): metadata.block_tables = torch.zeros(1, 5, dtype=torch.long) metadata.num_actual_tokens = 10 metadata.slot_mapping = torch.zeros(10, dtype=torch.long) + layer = self.layer_no_quant metadata.num_decodes = 10 metadata.num_prefills = 0 @@ -591,13 +520,8 @@ class TestAscendAttentionBackendImpl(TestBase): mock_get_forward_context.return_value = MagicMock(capturing=False) - output = self.impl_swa.forward(self.layer_no_quant, - query, - key, - value, - kv_cache, - metadata, - trace_flag=False) + output = self.impl_swa.forward(layer, query, key, value, kv_cache, + metadata, output) mock_paged_attention.assert_called_once() mock_fused_infer_attention_score.assert_not_called() @@ -618,6 +542,8 @@ class TestAscendAttentionBackendImpl(TestBase): key = torch.randn(10, 8 * 192) value = torch.randn(10, 8 * 192) kv_cache = torch.empty(2, 5, 128, 8, 192) + output = torch.empty_like(query) + metadata = self.attn_metadata metadata.attn_mask = torch.randn(1, 1, 10, 10) metadata.query_lens = torch.tensor([10]) @@ -631,13 +557,8 @@ class TestAscendAttentionBackendImpl(TestBase): mock_version.cann = "8.4.RC1" mock_vanilla_prefill.return_value = MagicMock() - output = self.impl_192.forward(layer, - query, - key, - value, - kv_cache, - metadata, - trace_flag=False) + output = self.impl_192.forward(layer, query, key, value, kv_cache, + metadata, output) mock_vanilla_prefill.assert_called_once() assert output.shape == (10, 8 * 192) @@ -653,6 +574,8 @@ class TestAscendAttentionBackendImpl(TestBase): key = torch.randn(10, 8 * 64) value 
= torch.randn(10, 8 * 64) kv_cache = torch.empty(2, 5, 128, 8, 64) + output = torch.empty_like(query) + metadata = self.attn_metadata metadata.attn_mask = torch.randn(1, 1, 10, 10) metadata.query_lens = torch.tensor([10]) @@ -666,13 +589,8 @@ class TestAscendAttentionBackendImpl(TestBase): mock_version.cann = "8.4.RC1" - output = self.impl.forward(layer, - query, - key, - value, - kv_cache, - metadata, - trace_flag=False) + output = self.impl.forward(layer, query, key, value, kv_cache, + metadata, output) mock_paged_attention.assert_called_once() assert output.shape == (10, 8 * 64) @@ -690,6 +608,8 @@ class TestAscendAttentionBackendImpl(TestBase): key = torch.randn(10, 8 * 64) value = torch.randn(10, 8 * 64) kv_cache = torch.empty(2, 5, 128, 8, 64) + output = torch.empty_like(query) + metadata = self.attn_metadata metadata.attn_mask = torch.randn(1, 1, 10, 10) metadata.query_lens = torch.tensor([10]) @@ -703,13 +623,9 @@ class TestAscendAttentionBackendImpl(TestBase): mock_npu_format_cast.return_value = metadata.attn_mask mock_version.cann = "8.4.RC1" - output = self.impl.forward(layer, - query, - key, - value, - kv_cache, - metadata, - trace_flag=False) + + output = self.impl.forward(layer, query, key, value, kv_cache, + metadata, output) mock_paged_attention.assert_called_once() assert output.shape == (10, 8 * 64) @@ -720,6 +636,8 @@ class TestAscendAttentionBackendImpl(TestBase): key = torch.randn(10, 8 * 64) value = torch.randn(10, 8 * 64) kv_cache = torch.empty(2, 5, 128, 8, 64) + output = torch.empty_like(query) + metadata = self.attn_metadata metadata.attn_mask = torch.randn(1, 1, 10, 10) metadata.query_lens = torch.tensor([10]) @@ -732,10 +650,5 @@ class TestAscendAttentionBackendImpl(TestBase): layer = self.layer_no_quant with self.assertRaises(NotImplementedError): - self.impl_error.forward(layer, - query, - key, - value, - kv_cache, - metadata, - trace_flag=False) + self.impl_error.forward(layer, query, key, value, kv_cache, + metadata, output) diff --git a/tests/ut/torchair/test_torchair_attention.py b/tests/ut/torchair/test_torchair_attention.py index dd262dc8..0ee79d26 100644 --- a/tests/ut/torchair/test_torchair_attention.py +++ b/tests/ut/torchair/test_torchair_attention.py @@ -91,5 +91,5 @@ class TestAscendAttentionTorchairBackendImpl(TestBase): torch.ones(1)) result = self.impl.forward(layer, query, key, value, kv_cache, - metadata, output, False) + metadata, output) self.assertEqual(result.shape[0], num_tokens) diff --git a/vllm_ascend/attention/attention_v1.py b/vllm_ascend/attention/attention_v1.py index 93cac288..71c5fdab 100644 --- a/vllm_ascend/attention/attention_v1.py +++ b/vllm_ascend/attention/attention_v1.py @@ -32,31 +32,28 @@ from vllm.distributed import (get_dcp_group, get_decode_context_model_parallel_rank, get_decode_context_model_parallel_world_size) from vllm.forward_context import ForwardContext, get_forward_context -from vllm.utils import cdiv, direct_register_custom_op +from vllm.utils import cdiv from vllm.v1.attention.backends.utils import AttentionCGSupport from vllm.v1.core.sched.output import SchedulerOutput from vllm.v1.kv_cache_interface import AttentionSpec -# isort: off from vllm_ascend.attention.utils import (AscendCommonAttentionMetadata, - maybe_save_kv_layer_to_connector, - split_decodes_and_prefills, - wait_for_kv_layer_from_connector) + split_decodes_and_prefills) from vllm_ascend.compilation.acl_graph import (get_graph_params, update_graph_params_workspaces) from vllm_ascend.ops.attention import vanilla_chunked_prefill from 
vllm_ascend.utils import (ACL_FORMAT_FRACTAL_NZ, aligned_16, is_310p, nd_to_nz_2d, nd_to_nz_spec, - prefill_context_parallel_enable, version_check) - -from ..utils import weak_ref_tensors + prefill_context_parallel_enable, version_check, + weak_ref_tensors) +# isort: off if prefill_context_parallel_enable(): from vllm.distributed import (get_pcp_group, get_prefill_context_model_parallel_rank, get_prefill_context_model_parallel_world_size ) -# isort:on +# isort: on class AscendAttentionBackend(AttentionBackend): @@ -484,7 +481,7 @@ class AscendAttentionBackendImpl(AttentionImpl): num_kv_heads=self.num_kv_heads, out=output) assert output is not None - return output[:num_tokens, :, :] + return output[:num_tokens] def _forward_prefill_cache_hit( self, @@ -937,169 +934,14 @@ class AscendAttentionBackendImpl(AttentionImpl): attn_lse_allgather) return attn_out - def forward( - self, - layer: AttentionLayer, - query: torch.Tensor, - key: torch.Tensor, - value: torch.Tensor, - kv_cache: Tuple[torch.Tensor], - attn_metadata: AscendMetadata, - output: Optional[torch.Tensor] = None, - trace_flag: bool = True, - ) -> torch.Tensor: - """Forward pass with Ascend attention. - Args: - query: shape = [batch_size, seq_len, num_heads * head_size] - key: shape = [batch_size, seq_len, num_kv_heads * head_size] - value: shape = [batch_size, seq_len, num_kv_heads * head_size] - kv_cache: shape = [key_cache, value_cache] - key_cache = [num_blocks, block_size, - num_kv_heads, head_size] - value_cache = [num_blocks, block_size, - num_kv_heads, head_size] - attn_metadata: Metadata for attention. - Returns: - shape = [batch_size * seq_len, num_heads, head_size] - """ - num_tokens = query.shape[0] - use_kv_cache_int8 = len( - kv_cache) > 0 and kv_cache[0].dtype == torch.int8 - if output is None: - output = torch.empty(num_tokens, - self.num_heads, - self.head_size, - dtype=query.dtype, - device=query.device) - ori_output = output - if trace_flag: - torch.ops.vllm.unified_ascend_attention_with_output( - query=query, - key=key, - value=value, - output=output, - layer_name=layer.layer_name) - - elif hasattr(layer, 'quant_method') and use_kv_cache_int8: - output = layer.quant_method.apply(layer, query, key, value, - kv_cache, attn_metadata, - self.attn_type, self.scale, - output) - - else: - if attn_metadata is None: - return output.view(num_tokens, self.hidden_size).fill_(0) - num_decode_tokens = attn_metadata.num_decode_tokens - has_decode = attn_metadata.num_decodes > 0 - has_prefill = attn_metadata.num_prefills > 0 - - assert layer._k_scale_float == 1.0 and layer._v_scale_float == 1.0 - attn_type = self.attn_type - if attn_type != AttentionType.DECODER and attn_type != AttentionType.ENCODER_ONLY: - raise NotImplementedError("Encoder/decoder cross-attention " - "are not implemented for " - "PallasAttentionBackendImpl") - # View q k v to BSH. - query = query.view(-1, self.num_heads, self.head_size) - key = key.view(-1, self.num_kv_heads, self.head_size) - value = value.view(-1, self.num_kv_heads, self.head_size) - # TODO: Remove this contiguous in the future. 
- value = value.contiguous() - - if len(kv_cache) > 1: - if self.key_cache is None: - self.key_cache, self.value_cache = kv_cache[0], kv_cache[1] - - if has_decode: - slot_mapping = attn_metadata.slot_mapping[:num_decode_tokens * self.pcp_size: self.pcp_size] \ - if self.pcp_size * self.dcp_size > 1 else attn_metadata.slot_mapping[:num_decode_tokens] - torch_npu._npu_reshape_and_cache( - key=key[:num_decode_tokens], - value=value[:num_decode_tokens], - key_cache=self.key_cache, - value_cache=self.value_cache, - slot_indices=slot_mapping) - - if has_prefill: - if self.pcp_size > 1: - kv = torch.cat([key, value], dim=-1) - all_kv = get_pcp_group().all_gather(kv, dim=0) - pcp_allgather_restore_idx = attn_metadata.prefill.pcp_allgather_restore_idx if attn_metadata.prefill else None - all_kv = torch.index_select(all_kv, 0, - pcp_allgather_restore_idx) - key, value = all_kv.split( - [self.head_size, self.head_size], dim=-1) - - torch_npu._npu_reshape_and_cache( - key=key[self.pcp_size * - num_decode_tokens:attn_metadata. - num_actual_tokens_pcp_padded], - value=value[self.pcp_size * - num_decode_tokens:attn_metadata. - num_actual_tokens_pcp_padded], - key_cache=self.key_cache, - value_cache=self.value_cache, - slot_indices=attn_metadata. - slot_mapping[self.pcp_size * - num_decode_tokens:attn_metadata. - num_actual_tokens_pcp_padded]) - - if self.pcp_size * self.dcp_size > 1: - output = self._forward_pcp_dcp(query, key, value, - attn_metadata, output) - - elif attn_type == AttentionType.ENCODER_ONLY: - cum_seq_len = attn_metadata.query_start_loc[1:].tolist() - attn_out = torch_npu.npu_fusion_attention( - query, - key, - value, - head_num=self.num_heads, - input_layout="TND", - scale=self.scale, - sparse_mode=4, - atten_mask=attn_metadata.attn_mask, - next_tockens=attn_metadata.max_query_len, - actual_seq_qlen=cum_seq_len, - actual_seq_kvlen=cum_seq_len, - ) - output = attn_out[0] - # V0-Style scheduler situation. - elif attn_metadata.attn_state == AscendAttentionState.PrefillNoCache: - output = self._forward_prefill_no_cache( - query, key, value, attn_metadata, output, num_tokens) - elif attn_metadata.attn_state == \ - AscendAttentionState.PrefillCacheHit: - output = self._forward_prefill_cache_hit( - query, attn_metadata, output) - elif attn_metadata.attn_state == AscendAttentionState.DecodeOnly: - output = self._forward_decode_only(query, attn_metadata, - output) - # Normal V1 situation. - else: - if torch.version.cann.startswith("8.3"): - # npu_fused_infer_attention_score does not support cases - # where query.shape[0] != attn_metadata.query_start_loc[-1]. - # Thus we need unpad it here. 
- num_tokens = attn_metadata.query_start_loc[-1] - query = query[:num_tokens] - output = self._forward_v1_style(query, attn_metadata, output) - - # to make in-place change to the output tensor - if hasattr(layer, 'quant_method') and use_kv_cache_int8: - output = output.view(num_tokens, self.num_heads, self.head_size) - ori_output[:num_tokens, :, :] = output[:num_tokens, :, :] - return output.view(num_tokens, self.hidden_size) - def _forward_pcp_dcp(self, query: torch.Tensor, key: torch.Tensor, value: torch.Tensor, attn_metadata: AscendMetadata, - output: Optional[torch.Tensor]) -> torch.Tensor: + output: torch.Tensor) -> torch.Tensor: assert attn_metadata is not None has_decode = attn_metadata.num_decodes > 0 has_prefill = attn_metadata.num_prefills > 0 num_decode_tokens = attn_metadata.num_decode_tokens - if output is None: - raise ValueError("Output buffer is required") + if has_decode: decode_query = query[:num_decode_tokens] output_decode = self._forward_decode_pcp_dcp( @@ -1131,47 +973,137 @@ class AscendAttentionBackendImpl(AttentionImpl): output[num_decode_tokens:] = output_prefill return output + def forward( + self, + layer: AttentionLayer, + query: torch.Tensor, + key: torch.Tensor, + value: torch.Tensor, + kv_cache: Tuple[torch.Tensor], + attn_metadata: AscendMetadata, + output: Optional[torch.Tensor] = None, + output_scale: Optional[torch.Tensor] = None, + output_block_scale: Optional[torch.Tensor] = None, + ) -> torch.Tensor: + """Forward pass with Ascend attention. + Args: + query: shape = [num_tokens, num_heads, head_size] + key: shape = [num_tokens, num_kv_heads, head_size] + value: shape = [num_tokens, num_kv_heads, head_size] + kv_cache: shape = + [2, num_blocks, block_size, num_kv_heads, head_size] + attn_metadata: Metadata for attention. + Returns: + shape = [num_tokens, num_heads * head_size] + """ + assert output is not None, "Output tensor must be provided." -def unified_ascend_attention_with_output( - query: torch.Tensor, - key: torch.Tensor, - value: torch.Tensor, - output: torch.Tensor, - layer_name: str, -) -> None: - wait_for_kv_layer_from_connector(layer_name) - forward_context: ForwardContext = get_forward_context() - attn_metadata = forward_context.attn_metadata - if isinstance(attn_metadata, dict): - attn_metadata = attn_metadata[layer_name] - self = forward_context.no_compile_layers[layer_name] - kv_cache = self.kv_cache[forward_context.virtual_engine] - self.impl.forward(self, - query, - key, - value, - kv_cache, - attn_metadata, - output, - trace_flag=False) - maybe_save_kv_layer_to_connector(layer_name, kv_cache) - return + if output_scale is not None or output_block_scale is not None: + raise NotImplementedError( + "fused output quantization is not yet supported" + " for AscendAttentionBackendImpl") + num_tokens = query.shape[0] + if attn_metadata is None: + return output -def unified_attention_with_output_fake( - query: torch.Tensor, - key: torch.Tensor, - value: torch.Tensor, - output: torch.Tensor, - layer_name: str, -) -> None: - return + # NOTE: Currently, we have various attention paths for different + # scenarios, and not all of them are in-place operations. Therefore, + # we need to create a separate tensor to hold the attention result. + # In the future, we may consolidate them into fewer paths, which will + # hopefully allow us to use in-place operation by default. 
+ intermediate_output: torch.Tensor + assert layer._k_scale_float == 1.0 and layer._v_scale_float == 1.0 + attn_type = self.attn_type + if attn_type != AttentionType.DECODER and attn_type != AttentionType.ENCODER_ONLY: + raise NotImplementedError("Encoder/decoder cross-attention " + "are not implemented for " + "PallasAttentionBackendImpl") -direct_register_custom_op( - op_name="unified_ascend_attention_with_output", - op_func=unified_ascend_attention_with_output, - mutates_args=["output"], - fake_impl=unified_attention_with_output_fake, - dispatch_key="PrivateUse1", -) + num_decode_tokens = attn_metadata.num_decode_tokens + has_decode = attn_metadata.num_decodes > 0 + has_prefill = attn_metadata.num_prefills > 0 + + if len(kv_cache) > 1: + if self.key_cache is None: + self.key_cache, self.value_cache = kv_cache[0], kv_cache[1] + + if has_decode: + slot_mapping = attn_metadata.slot_mapping[:num_decode_tokens * self.pcp_size: self.pcp_size] \ + if self.pcp_size * self.dcp_size > 1 else attn_metadata.slot_mapping[:num_decode_tokens] + torch_npu._npu_reshape_and_cache( + key=key[:num_decode_tokens], + value=value[:num_decode_tokens], + key_cache=self.key_cache, + value_cache=self.value_cache, + slot_indices=slot_mapping) + + if has_prefill: + if self.pcp_size > 1: + kv = torch.cat([key, value], dim=-1) + all_kv = get_pcp_group().all_gather(kv, dim=0) + pcp_allgather_restore_idx = attn_metadata.prefill.pcp_allgather_restore_idx if attn_metadata.prefill else None + all_kv = torch.index_select(all_kv, 0, + pcp_allgather_restore_idx) + key, value = all_kv.split([self.head_size, self.head_size], + dim=-1) + + torch_npu._npu_reshape_and_cache( + key=key[self.pcp_size * num_decode_tokens:attn_metadata. + num_actual_tokens_pcp_padded], + value=value[self.pcp_size * + num_decode_tokens:attn_metadata. + num_actual_tokens_pcp_padded], + key_cache=self.key_cache, + value_cache=self.value_cache, + slot_indices=attn_metadata. + slot_mapping[self.pcp_size * + num_decode_tokens:attn_metadata. + num_actual_tokens_pcp_padded]) + + if self.pcp_size * self.dcp_size > 1: + intermediate_output = self._forward_pcp_dcp( + query, key, value, attn_metadata, output) + elif attn_type == AttentionType.ENCODER_ONLY: + # TODO(zzzwwjj): Deal with this `cum_seq_len` more elegantly. + cum_seq_len = attn_metadata.query_start_loc[1:].tolist() + intermediate_output = torch_npu.npu_fusion_attention( + query, + key, + value, + head_num=self.num_heads, + input_layout="TND", + scale=self.scale, + sparse_mode=4, + atten_mask=attn_metadata.attn_mask, + pre_tockens=attn_metadata.max_query_len, + next_tockens=attn_metadata.max_query_len, + actual_seq_qlen=cum_seq_len, + actual_seq_kvlen=cum_seq_len, + )[0] + # V0-Style scheduler situation. + elif attn_metadata.attn_state == AscendAttentionState.PrefillNoCache: + intermediate_output = self._forward_prefill_no_cache( + query, key, value, attn_metadata, output, num_tokens) + elif attn_metadata.attn_state == \ + AscendAttentionState.PrefillCacheHit: + intermediate_output = self._forward_prefill_cache_hit( + query, attn_metadata, output) + elif attn_metadata.attn_state == AscendAttentionState.DecodeOnly: + intermediate_output = self._forward_decode_only( + query, attn_metadata, output) + # Normal V1 situation. + else: + if torch.version.cann.startswith("8.3"): + # npu_fused_infer_attention_score does not support cases + # where query.shape[0] != attn_metadata.query_start_loc[-1]. + # Thus we need unpad it here. 
+ num_tokens = attn_metadata.query_start_loc[-1] + query = query[:num_tokens] + intermediate_output = self._forward_v1_style( + query, attn_metadata, output) + + output[:num_tokens] = intermediate_output[:num_tokens] + + return output diff --git a/vllm_ascend/platform.py b/vllm_ascend/platform.py index 677352e2..a074e9c8 100644 --- a/vllm_ascend/platform.py +++ b/vllm_ascend/platform.py @@ -293,10 +293,7 @@ class NPUPlatform(Platform): "When enabling VLLM_COMPILE aclgraph, please make sure compilation_config.mode == CompilationMode.VLLM_COMPILE and compilation_config.cudagraph_mode == CUDAGraphMode.VLLM_COMPILE" compilation_config.set_splitting_ops_for_v1() compilation_config.use_inductor = False - compilation_config.splitting_ops.extend([ - "vllm::unified_ascend_attention_with_output", - "vllm::mla_forward" - ]) + compilation_config.splitting_ops.extend(["vllm::mla_forward"]) update_aclgraph_sizes(vllm_config) elif compilation_config.cudagraph_mode == CUDAGraphMode.FULL_DECODE_ONLY: logger.info( @@ -453,6 +450,10 @@ class NPUPlatform(Platform): def is_pin_memory_available(cls): return True + @classmethod + def opaque_attention_op(cls) -> bool: + return True + @classmethod def get_static_graph_wrapper_cls(cls) -> str: """ diff --git a/vllm_ascend/torchair/models/qwen2.py b/vllm_ascend/torchair/models/qwen2.py index 6e4990d7..a5a198e0 100644 --- a/vllm_ascend/torchair/models/qwen2.py +++ b/vllm_ascend/torchair/models/qwen2.py @@ -125,7 +125,6 @@ class CustomQwen2Attention(Qwen2Attention): v, kv_cache=kv_cache, attn_metadata=attn_metadata, - trace_flag=False, **forward_kwargs) output, _ = self.o_proj(attn_output) return output diff --git a/vllm_ascend/torchair/models/qwen3_moe.py b/vllm_ascend/torchair/models/qwen3_moe.py index 47508c40..0c90412c 100644 --- a/vllm_ascend/torchair/models/qwen3_moe.py +++ b/vllm_ascend/torchair/models/qwen3_moe.py @@ -257,7 +257,6 @@ class CustomQwen3MoeAttention(Qwen3MoeAttention): v, kv_cache=kv_cache, attn_metadata=attn_metadata, - trace_flag=False, **forward_kwargs) output, _ = self.o_proj(attn_output) return output diff --git a/vllm_ascend/torchair/models/torchair_pangu_moe.py b/vllm_ascend/torchair/models/torchair_pangu_moe.py index 195ffded..7a0c9c06 100644 --- a/vllm_ascend/torchair/models/torchair_pangu_moe.py +++ b/vllm_ascend/torchair/models/torchair_pangu_moe.py @@ -625,7 +625,7 @@ class PanguProMoEAttention(nn.Module): q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1) q, k = self.rotary_emb(positions, q, k) if self.torchair_graph_enabled: - forward_kwargs = {'trace_flag': False} + forward_kwargs = {} output_shape = q.shape attn_output = torch.empty(output_shape, dtype=q.dtype, diff --git a/vllm_ascend/torchair/torchair_attention.py b/vllm_ascend/torchair/torchair_attention.py index 3d3177a0..dd931ed3 100644 --- a/vllm_ascend/torchair/torchair_attention.py +++ b/vllm_ascend/torchair/torchair_attention.py @@ -314,7 +314,6 @@ class AscendAttentionTorchairBackendImpl(AttentionImpl): kv_cache: torch.Tensor, attn_metadata: AscendTorchairMetadata, output: Optional[torch.Tensor] = None, - trace_flag: bool = False, ) -> torch.Tensor: """Forward pass with Ascend attention. 
Args: diff --git a/vllm_ascend/worker/model_runner_v1.py b/vllm_ascend/worker/model_runner_v1.py index 6224eacb..1f2ba64c 100644 --- a/vllm_ascend/worker/model_runner_v1.py +++ b/vllm_ascend/worker/model_runner_v1.py @@ -3740,7 +3740,6 @@ class NPUModelRunner(LoRAModelRunnerMixin): splitting_ops_contain_attention = ( self.compilation_config.splitting_ops is not None and all(op in self.compilation_config.splitting_ops for op in [ - "vllm.unified_ascend_attention_with_output", "vllm.mla_forward", ]))
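---

For context, the caller-side convention after this refactor can be sketched as below. This is a minimal, illustrative sketch modeled on the updated unit tests in this patch: the caller pre-allocates the `output` buffer, passes it explicitly, and no longer supplies `trace_flag`. The `forward_sketch` helper, the placeholder attention computation, and the tensor shapes are assumptions for illustration only, not part of the vLLM Ascend API.

```python
import torch

# Shapes taken from the updated unit tests in this patch (illustrative only).
num_tokens, num_heads, head_size = 10, 8, 64

query = torch.randn(num_tokens, num_heads * head_size)
key = torch.randn(num_tokens, num_heads * head_size)
value = torch.randn(num_tokens, num_heads * head_size)
output = torch.empty_like(query)  # caller-allocated buffer, filled in place


def forward_sketch(query: torch.Tensor, key: torch.Tensor,
                   value: torch.Tensor, output: torch.Tensor) -> torch.Tensor:
    """Hypothetical stand-in for the refactored forward: the real method
    dispatches to one of the prefill/decode paths, then copies the result
    into the caller-provided `output` tensor and returns it."""
    assert output is not None, "Output tensor must be provided."
    # Placeholder "attention" result with the right shape; the real kernels
    # are the torch_npu ops shown in the diff above.
    intermediate_output = query
    output[:query.shape[0]] = intermediate_output[:query.shape[0]]
    return output


result = forward_sketch(query, key, value, output)
assert result.shape == (num_tokens, num_heads * head_size)
assert result.data_ptr() == output.data_ptr()  # same buffer, updated in place
```

As in the updated tests, callers create the buffer with `torch.empty_like(query)` and pass it as the argument following `attn_metadata`; the returned tensor aliases that buffer, so downstream code can rely on the in-place semantics noted in the new `forward` implementation.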