diff --git a/tests/ut/attention/test_attention_v1.py b/tests/ut/attention/test_attention_v1.py
index 129b5410..3a94e9e8 100644
--- a/tests/ut/attention/test_attention_v1.py
+++ b/tests/ut/attention/test_attention_v1.py
@@ -25,12 +25,6 @@ class TestAscendAttentionBackend(TestBase):
         self.assertEqual(AscendAttentionBackend.get_builder_cls(),
                          AscendAttentionMetadataBuilder)
 
-    @patch('vllm_ascend.attention.attention_v1.get_ascend_device_type',
-           return_value=AscendDeviceType._310P)
-    def test_get_kv_cache_shape_310p(self, mock_soc_version):
-        result = AscendAttentionBackend.get_kv_cache_shape(10, 20, 30, 40)
-        self.assertEqual(result, (2, 10, 30 * 40 // 16, 20, 16))
-
     @patch('vllm_ascend.utils.get_ascend_device_type',
            return_value=AscendDeviceType._910_93)
     def test_get_kv_cache_shape_not_310p(self, mock_soc_version):
@@ -95,76 +89,6 @@ class TestAscendAttentionMetadataBuilder(TestBase):
 
         self.assertFalse(result)
 
-    @patch('vllm_ascend.attention.attention_v1.AscendMetadata')
-    @patch('torch_npu.npu_format_cast')
-    @patch('vllm_ascend.utils.nd_to_nz_2d')
-    @patch('vllm_ascend.utils.get_ascend_device_type',
-           return_value=AscendDeviceType._310P)
-    def test_build_prefill_no_cache(self, mock_soc_version, mock_nd_to_nz_2d,
-                                    mock_npu_format_cast,
-                                    mock_ascend_metadata):
-        common_attn_metadata = AscendCommonAttentionMetadata(
-            query_start_loc=torch.tensor([0, 3, 7]),
-            query_start_loc_cpu=torch.tensor([0, 3, 7]),
-            seq_lens_cpu=torch.tensor([5, 6]),
-            num_reqs=2,
-            num_actual_tokens=10,
-            max_query_len=5,
-            decode_token_per_req=torch.tensor([1, 1]),
-            block_table_tensor=torch.zeros((10, 10)),
-            slot_mapping=torch.tensor(range(20)),
-            actual_seq_lengths_q=torch.tensor([0, 1]),
-            positions=torch.tensor([10, 10]),
-            attn_mask=torch.ones((10, 10)),
-            spec_attn_mask=None,
-            attn_state=AscendAttentionState.PrefillNoCache,
-            num_computed_tokens_cpu=None,
-            seq_lens=None)
-
-        mock_nz_tensor = MagicMock()
-        mock_model = MagicMock()
-        mock_nd_to_nz_2d.return_value = mock_nz_tensor
-        mock_npu_format_cast.return_value = mock_nz_tensor
-
-        self.builder.build(1, common_attn_metadata, mock_model)
-
-    @patch('vllm_ascend.attention.attention_v1.AscendMetadata')
-    @patch('torch_npu.npu_format_cast')
-    @patch('vllm_ascend.utils.nd_to_nz_spec')
-    @patch('vllm_ascend.utils.get_ascend_device_type',
-           return_value=AscendDeviceType._310P)
-    @patch('vllm_ascend.attention.attention_v1.AscendAttentionState')
-    def test_build_chunked_prefill(self, mock_ascend_attention_state,
-                                   mock_soc_version, mock_nd_to_nz_spec,
-                                   mock_npu_format_cast, mock_ascend_metadata):
-        common_attn_metadata = AscendCommonAttentionMetadata(
-            query_start_loc=torch.tensor([0, 2, 5, 9]),
-            query_start_loc_cpu=torch.tensor([0, 2, 5, 9]),
-            seq_lens_cpu=torch.tensor([4, 5, 6]),
-            num_reqs=3,
-            num_actual_tokens=15,
-            max_query_len=6,
-            decode_token_per_req=torch.tensor([1, 1, 1]),
-            block_table_tensor=torch.zeros((10, 10)),
-            slot_mapping=torch.tensor(range(20)),
-            actual_seq_lengths_q=torch.tensor([0, 1, 2]),
-            positions=torch.tensor([10, 10]),
-            attn_mask=torch.ones((15, 15)),
-            spec_attn_mask=None,
-            attn_state=AscendAttentionState.ChunkedPrefill,
-            num_computed_tokens_cpu=None,
-            seq_lens=None)
-
-        mock_ascend_attention_state = MagicMock()
-        mock_ascend_attention_state.PrefillNoCache = 0
-
-        mock_nz_tensor = MagicMock()
-        mock_model = MagicMock()
-        mock_nd_to_nz_spec.return_value = mock_nz_tensor
-        mock_npu_format_cast.return_value = mock_nz_tensor
-
-        self.builder.build(1, common_attn_metadata, mock_model)
-
     @patch('vllm_ascend.attention.attention_v1.AscendMetadata')
     @patch('vllm_ascend.utils.get_ascend_device_type',
            return_value=AscendDeviceType._910_93)
@@ -286,73 +210,40 @@ class TestAscendAttentionBackendImpl(TestBase):
 
         assert output.shape == (10, 8 * 64)
 
-    @patch('vllm_ascend.attention.attention_v1.get_forward_context')
-    @patch('torch_npu._npu_reshape_and_cache')
-    @patch('torch_npu._npu_flash_attention')
-    def test_forward_prefill_no_cache(self, mock_flash_attention,
-                                      mock_reshape_cache,
-                                      mock_get_forward_context):
-        """Test forward pass in PrefillNoCache state"""
-        query = torch.randn(10, 8 * 64)
-        key = torch.randn(10, 8 * 64)
-        value = torch.randn(10, 8 * 64)
-        kv_cache = torch.empty(2, 5, 128, 8, 64)
-        output = torch.empty_like(query)
-
-        mock_get_forward_context.return_value = MagicMock(capturing=False)
-
-        metadata = self.attn_metadata
-        metadata.attn_state = AscendAttentionState.PrefillNoCache
-        metadata.attn_mask = torch.randn(1, 1, 10, 10)
-        metadata.seq_lens = torch.tensor([10])
-        metadata.num_actual_tokens = 10
-        metadata.slot_mapping = torch.zeros(10, dtype=torch.long)
-        metadata.num_decodes = 0
-        metadata.num_prefills = 10
-        layer = self.layer_no_quant
-
-        output = self.impl.forward(layer, query, key, value, kv_cache,
-                                   metadata, output)
-
-        mock_reshape_cache.assert_called_once()
-        mock_flash_attention.assert_called_once()
-        assert output.shape == (10, 8 * 64)
-
     @patch('torch_npu._npu_reshape_and_cache')
     @patch('torch_npu.npu_fused_infer_attention_score')
     @patch('vllm_ascend.attention.attention_v1.get_forward_context')
-    def test_forward_prefill_cache_hit(self, mock_get_forward_context,
-                                       mock_npu_fused_infer_attention_score,
-                                       mock_npu_reshape_and_cache):
+    def test_forward_prefill(self, mock_get_forward_context,
+                             mock_npu_fused_infer_attention_score,
+                             mock_npu_reshape_and_cache):
         """Test forward pass in PrefillCacheHit state"""
-        query = torch.randn(10, 8 * 64)
-        key = torch.randn(10, 8 * 64)
-        value = torch.randn(10, 8 * 64)
+        query = torch.randn(10, 8, 64)
+        key = torch.randn(10, 8, 64)
+        value = torch.randn(10, 8, 64)
         kv_cache = torch.empty(2, 5, 128, 8, 64)
         output = torch.empty_like(query)
-
         metadata = self.attn_metadata
         metadata.attn_state = AscendAttentionState.PrefillCacheHit
         metadata.attn_mask = torch.randn(1, 1, 10, 10)
         metadata.query_lens = torch.tensor([10])
         metadata.seq_lens = torch.tensor([10])
+        metadata.actual_seq_lengths_q = [10]
         metadata.block_tables = torch.zeros(1, 5, dtype=torch.long)
         metadata.num_actual_tokens = 10
-        metadata.slot_mapping = torch.zeros(10, dtype=torch.long)
+        metadata.num_decode_tokens = 0
         metadata.num_decodes = 0
         metadata.num_prefills = 10
+        metadata.slot_mapping = torch.zeros(10, dtype=torch.long)
         layer = self.layer_no_quant
         mock_get_forward_context.return_value = MagicMock(capturing=False)
-        mock_npu_fused_infer_attention_score.return_value = (output,
-                                                             torch.ones(
-                                                                 10, 8, 64))
-
+        mock_npu_fused_infer_attention_score.return_value = (torch.ones(
+            10, 8, 64), torch.ones(10, 8, 64))
         output = self.impl.forward(layer, query, key, value, kv_cache,
                                    metadata, output)
 
         mock_npu_fused_infer_attention_score.assert_called_once()
-        assert output.shape == (10, 8 * 64)
+        assert output.shape == (10, 8, 64)
 
     @patch('torch_npu._npu_paged_attention')
     @patch('torch_npu._npu_reshape_and_cache')
@@ -454,119 +345,6 @@ class TestAscendAttentionBackendImpl(TestBase):
 
         assert output.shape == (10, 8 * 64)
 
-    @patch('vllm_ascend.attention.attention_v1.get_forward_context')
-    @patch('vllm_ascend.utils.get_ascend_device_type',
-           return_value=AscendDeviceType._910_93)
-    @patch('torch_npu._npu_reshape_and_cache')
-    @patch('vllm_ascend.attention.attention_v1.vanilla_chunked_prefill')
-    def test_forward_head_size_192(self, mock_vanilla_prefill,
-                                   mock_npu_reshape_and_cache,
-                                   mock_soc_version, mock_get_forward_context):
-        """Test forward pass when head_size is 192"""
-
-        self.impl.head_size = 192
-        query = torch.randn(10, 8 * 192)
-        key = torch.randn(10, 8 * 192)
-        value = torch.randn(10, 8 * 192)
-        kv_cache = torch.empty(2, 5, 128, 8, 192)
-        output = torch.empty_like(query)
-
-        mock_get_forward_context.return_value = MagicMock(capturing=False)
-
-        metadata = self.attn_metadata
-        metadata.attn_mask = torch.randn(1, 1, 10, 10)
-        metadata.query_lens = torch.tensor([10])
-        metadata.seq_lens = torch.tensor([10])
-        metadata.block_tables = torch.zeros(1, 5, dtype=torch.long)
-        metadata.num_actual_tokens = 10
-        metadata.slot_mapping = torch.zeros(10, dtype=torch.long)
-        metadata.num_decodes = 10
-        metadata.num_prefills = 0
-        layer = self.layer_no_quant
-        mock_vanilla_prefill.return_value = MagicMock()
-
-        output = self.impl_192.forward(layer, query, key, value, kv_cache,
-                                       metadata, output)
-
-        mock_vanilla_prefill.assert_called_once()
-        assert output.shape == (10, 8 * 192)
-
-    @patch('vllm_ascend.attention.attention_v1.get_forward_context')
-    @patch('torch_npu.npu_fused_infer_attention_score')
-    @patch('torch_npu._npu_reshape_and_cache')
-    def test_forward_normal_v1_situation(self, mock_npu_reshape_and_cache,
-                                         mock_npu_fused_infer_attention_score,
-                                         mock_get_forward_context):
-        """Test forward pass in normal V1 situation"""
-        query = torch.randn(10, 8 * 64)
-        key = torch.randn(10, 8 * 64)
-        value = torch.randn(10, 8 * 64)
-        kv_cache = torch.empty(2, 5, 128, 8, 64)
-        output = torch.empty_like(query)
-
-        metadata = self.attn_metadata
-        metadata.attn_mask = torch.randn(1, 1, 10, 10)
-        metadata.query_lens = torch.tensor([10])
-        metadata.seq_lens = torch.tensor([10])
-        metadata.block_tables = torch.zeros(1, 5, dtype=torch.long)
-        metadata.num_actual_tokens = 10
-        metadata.slot_mapping = torch.zeros(10, dtype=torch.long)
-        metadata.num_decodes = 0
-        metadata.num_prefills = 10
-        layer = self.layer_no_quant
-        mock_get_forward_context.return_value = MagicMock(capturing=False)
-        mock_npu_fused_infer_attention_score.return_value = (output,
-                                                             torch.ones(
-                                                                 10, 8, 64))
-
-        output = self.impl.forward(layer, query, key, value, kv_cache,
-                                   metadata, output)
-
-        mock_npu_fused_infer_attention_score.assert_called_once()
-        assert output.shape == (10, 8 * 64)
-
-    @patch('torch_npu.npu_format_cast')
-    @patch('torch_npu._npu_reshape_and_cache')
-    @patch('torch_npu.npu_fused_infer_attention_score')
-    @patch('vllm_ascend.utils.get_ascend_device_type',
-           return_value=AscendDeviceType._310P)
-    @patch('vllm_ascend.attention.attention_v1.get_forward_context')
-    def test_forward_310p_device(self, mock_get_forward_context,
-                                 mock_soc_version,
-                                 mock_npu_fused_infer_attention_score,
-                                 mock_npu_reshape_and_cache,
-                                 mock_npu_format_cast):
-        """Test forward pass on 310P device"""
-        query = torch.randn(10, 8 * 64)
-        key = torch.randn(10, 8 * 64)
-        value = torch.randn(10, 8 * 64)
-        kv_cache = torch.empty(2, 5, 128, 8, 64)
-        output = torch.empty_like(query)
-
-        metadata = self.attn_metadata
-        metadata.attn_mask = torch.randn(1, 1, 10, 10)
-        metadata.query_lens = torch.tensor([10])
-        metadata.seq_lens = torch.tensor([10])
-        metadata.block_tables = torch.zeros(1, 5, dtype=torch.long)
-        metadata.num_actual_tokens = 10
-        metadata.slot_mapping = torch.zeros(10,
-                                            dtype=torch.long)
-        metadata.num_decodes = 0
-        metadata.num_prefills = 10
-        layer = self.layer_no_quant
-
-        mock_npu_format_cast.return_value = metadata.attn_mask
-
-        mock_get_forward_context.return_value = MagicMock(capturing=False)
-        mock_npu_fused_infer_attention_score.return_value = (output,
-                                                             torch.ones(
-                                                                 10, 8, 64))
-
-        output = self.impl.forward(layer, query, key, value, kv_cache,
-                                   metadata, output)
-
-        mock_npu_fused_infer_attention_score.assert_called_once()
-        assert output.shape == (10, 8 * 64)
-
     @patch('torch_npu._npu_reshape_and_cache')
     def test_forward_raise_error(self, mock_paged_attention):
         query = torch.randn(10, 8 * 64)
diff --git a/vllm_ascend/attention/attention_v1.py b/vllm_ascend/attention/attention_v1.py
index 1d9139c5..0cb2b75c 100644
--- a/vllm_ascend/attention/attention_v1.py
+++ b/vllm_ascend/attention/attention_v1.py
@@ -41,11 +41,7 @@ from vllm_ascend.attention.utils import (AscendCommonAttentionMetadata,
                                          split_decodes_and_prefills)
 from vllm_ascend.compilation.acl_graph import (get_graph_params,
                                                update_graph_params_workspaces)
-from vllm_ascend.ops.attention import vanilla_chunked_prefill
-from vllm_ascend.utils import (ACL_FORMAT_FRACTAL_NZ, AscendDeviceType,
-                               aligned_16, get_ascend_device_type, nd_to_nz_2d,
-                               nd_to_nz_spec, prefill_context_parallel_enable,
-                               weak_ref_tensors)
+from vllm_ascend.utils import prefill_context_parallel_enable, weak_ref_tensors
 
 # isort: off
 if prefill_context_parallel_enable():
@@ -83,9 +79,6 @@ class AscendAttentionBackend(AttentionBackend):
         num_kv_heads: int,
         head_size: int,
     ) -> Tuple[int, ...]:
-        if get_ascend_device_type() == AscendDeviceType._310P:
-            return (2, num_blocks, num_kv_heads * head_size // 16, block_size,
-                    16)
         return (2, num_blocks, block_size, num_kv_heads, head_size)
 
     @staticmethod
@@ -351,16 +344,6 @@ class AscendAttentionMetadataBuilder:
         query_start_loc = query_start_loc_cpu.to(self.device,
                                                  non_blocking=True)
 
-        if get_ascend_device_type() == AscendDeviceType._310P:
-            if attn_state == AscendAttentionState.PrefillNoCache:
-                mask_nz = nd_to_nz_2d(attn_mask)
-                attn_mask = torch_npu.npu_format_cast(mask_nz.contiguous(),
-                                                      ACL_FORMAT_FRACTAL_NZ)
-            elif attn_state == AscendAttentionState.ChunkedPrefill:
-                mask_nz = nd_to_nz_spec(attn_mask)
-                attn_mask = torch_npu.npu_format_cast(mask_nz.contiguous(),
-                                                      ACL_FORMAT_FRACTAL_NZ)
-
         common_long_seq_metadata = common_attn_metadata.prefill_context_parallel_metadata
         prefill_metadata = None
         decode_metadata = None
@@ -585,9 +568,9 @@ class AscendAttentionBackendImpl(AttentionImpl):
                              output: torch.Tensor,
                              num_tokens=0):
         if self.pcp_size * self.dcp_size > 1:
-            intermediate_output = self._forward_pcp_dcp(
-                query, key, value, kv_cache, attn_metadata, output)
-            return intermediate_output, query.shape[0]
+            attn_output = self._forward_pcp_dcp(query, key, value, kv_cache,
+                                                attn_metadata, output)
+            return attn_output, query.shape[0]
         elif attn_metadata.attn_state == AscendAttentionState.PrefillNoCache:
             block_size = 128
             block_table = None
@@ -688,93 +671,58 @@ class AscendAttentionBackendImpl(AttentionImpl):
                 graph_params.handles[num_tokens].append(handle)
         return output, num_tokens
 
-    def _forward_prefill_no_cache(
-        self,
-        query: torch.Tensor,
-        key: torch.Tensor,
-        value: torch.Tensor,
-        attn_metadata: AscendMetadata,
-        output: Optional[torch.Tensor] = None,
-        num_tokens=0,
-    ) -> torch.Tensor:
-        assert attn_metadata is not None
-        assert attn_metadata.attn_mask is not None
-
-        mask = attn_metadata.attn_mask
-
-        if get_ascend_device_type() == AscendDeviceType._310P:
-            # align q k v output tensors
-            query = aligned_16(query)
-            key = aligned_16(key)
-            value = aligned_16(value)
-            output = aligned_16(output)
-            # do reformat in case of broadcasted tensors
-            mask = mask.repeat(attn_metadata.seq_lens.size(0), 1, 1, 1)
-            mask = torch_npu.npu_format_cast(mask.contiguous(),
-                                             ACL_FORMAT_FRACTAL_NZ)
-
-        torch_npu._npu_flash_attention(query=query,
-                                       key=key,
-                                       value=value,
-                                       mask=mask,
-                                       seq_len=attn_metadata.seq_lens,
-                                       scale_value=self.scale,
-                                       num_heads=self.num_heads,
-                                       num_kv_heads=self.num_kv_heads,
-                                       out=output)
-        assert output is not None
-        return output[:num_tokens]
-
-    def _forward_prefill_cache_hit(
-        self,
-        query: torch.Tensor,
-        attn_metadata: AscendMetadata,
-        output: Optional[torch.Tensor] = None,
-    ) -> torch.Tensor:
-        assert attn_metadata is not None
-        assert attn_metadata.attn_mask is not None
-
-        compress_mask = attn_metadata.attn_mask
-        batch_size = attn_metadata.query_lens.shape[0]
-        block_table = attn_metadata.block_tables[:batch_size, :]
-        num_block, block_size, _, _ = self.key_cache.shape  # type: ignore
-
-        if block_size == 128:
-            # TODO:The npu_fused_infer_attention_score op is planned to
-            # be utilized in a wider range in upcoming versions.
+    def _forward_prefill(self, query: torch.Tensor, key: torch.Tensor,
+                         value: torch.Tensor, attn_metadata: AscendMetadata,
+                         output: torch.Tensor):
+        if attn_metadata.attn_state == AscendAttentionState.PrefillNoCache:
+            block_size = 128
+            block_table = None
+            actual_seq_lengths_kv = attn_metadata.actual_seq_lengths_q
+        elif attn_metadata.attn_state == \
+                AscendAttentionState.PrefillCacheHit:
+            batch_size = attn_metadata.query_lens.shape[0]
+            block_table = attn_metadata.block_tables[:batch_size, :]
+            num_block, block_size, _, _ = self.key_cache.shape  # type: ignore
             key = self.key_cache.view(  # type: ignore
                 num_block, block_size, -1)
             value = self.value_cache.view(  # type: ignore
                 num_block, block_size, -1)
-
-            output, _ = torch_npu.npu_fused_infer_attention_score(
-                query=query,
-                key=key,
-                value=value,
-                atten_mask=compress_mask,
-                block_table=block_table,
-                input_layout="TND",
-                block_size=block_size,
-                actual_seq_lengths=attn_metadata.actual_seq_lengths_q,
-                actual_seq_lengths_kv=attn_metadata.seq_lens_list,
-                num_key_value_heads=self.num_kv_heads,
-                num_heads=self.num_heads,
-                scale=self.scale,
-                sparse_mode=3,
-            )
+            actual_seq_lengths_kv = attn_metadata.seq_lens_list
+        # chunked_prefill.
         else:
-            torch_npu._npu_flash_attention_qlens(
-                query=query,
-                key_cache=self.key_cache,
-                value_cache=self.value_cache,
-                block_table=block_table,
-                mask=compress_mask,
-                seq_len=attn_metadata.query_lens,
-                context_lens=attn_metadata.seq_lens,
-                num_kv_heads=self.num_kv_heads,
-                num_heads=self.num_heads,
-                scale_value=self.scale,
-                out=output)
+            num_block, block_size, _, _ = self.key_cache.shape  # type: ignore
+            key = self.key_cache.view(  # type: ignore
+                num_block, block_size, -1)
+            value = self.value_cache.view(  # type: ignore
+                num_block, block_size, -1)
+            block_table = attn_metadata.block_tables
+            actual_seq_lengths_kv = attn_metadata.seq_lens_list
+
+        num_tokens = attn_metadata.actual_seq_lengths_q[-1]
+        query = query[:num_tokens]
+        # Prepare tensors for attention output
+        # TODO: Refactor this to step-level instead of layer-level
+
+        # Get workspace from cache or calculate it if not present.
+        attn_output, _ = torch_npu.npu_fused_infer_attention_score(
+            query=query,
+            key=key,
+            value=value,
+            atten_mask=attn_metadata.attn_mask,
+            block_table=block_table,
+            input_layout="TND",
+            block_size=block_size,
+            actual_seq_lengths=attn_metadata.actual_seq_lengths_q,
+            actual_seq_lengths_kv=actual_seq_lengths_kv,
+            num_key_value_heads=self.num_kv_heads,
+            num_heads=self.num_heads,
+            scale=self.scale,
+            sparse_mode=3,
+        )
+
+        attn_output = attn_output.view(num_tokens, self.num_heads,
+                                       self.head_size)
+        output[:num_tokens] = attn_output[:num_tokens]
         return output
 
     def _forward_decode_only(
@@ -783,10 +731,6 @@ class AscendAttentionBackendImpl(AttentionImpl):
         attn_metadata: AscendMetadata,
         output: Optional[torch.Tensor] = None,
     ) -> torch.Tensor:
-        if get_ascend_device_type() == AscendDeviceType._310P:
-            # seq_lens_tensor needs to be transferred to the device for 310P.
-            attn_metadata.seq_lens = \
-                attn_metadata.seq_lens.to(device=query.device)
         if self.sliding_window is not None and attn_metadata.seq_lens.shape[
                 0] == query.size(0):
             batch_size = attn_metadata.seq_lens.shape[0]
@@ -827,69 +771,6 @@ class AscendAttentionBackendImpl(AttentionImpl):
                 out=output)
         return output
 
-    def _forward_v1_style(
-        self,
-        query: torch.Tensor,
-        attn_metadata: AscendMetadata,
-        output: Optional[torch.Tensor] = None,
-    ) -> torch.Tensor:
-        # Use chunked prefill for head size 192 scenario, like deepseek
-        # paged_attention_splitfuse maybe crash at such scenario.
-        # TODO: vanilla path will be removed after the kernel support
-        # head_size 192 scenario.
-        if self.head_size == 192:
-            cu_seqlen_q = [0] + attn_metadata.query_lens.tolist()
-            cu_seqlen_k = [0] + attn_metadata.seq_lens.tolist()
-            cu_seqlen_q = torch.tensor(cu_seqlen_q, device=query.device)
-            cu_seqlen_k = torch.tensor(cu_seqlen_k, device=query.device)
-            cu_seqlen_q = torch.cumsum(cu_seqlen_q, dim=0)
-            cu_seqlen_k = torch.cumsum(cu_seqlen_k, dim=0)
-            max_seqlen_q = torch.max(attn_metadata.query_lens)
-            max_seqlen_k = torch.max(attn_metadata.seq_lens)
-            vanilla_chunked_prefill(output, query, self.key_cache,
-                                    self.value_cache,
-                                    attn_metadata.block_tables, cu_seqlen_q,
-                                    cu_seqlen_k, max_seqlen_q, max_seqlen_k,
-                                    self.scale, None, True)
-            return output
-
-        # Use paged attention.
-        assert attn_metadata is not None
-        assert attn_metadata.attn_mask is not None
-
-        if get_ascend_device_type() == AscendDeviceType._310P:
-            # Do reformat in case of broadcasted tensors.
-            attn_metadata.attn_mask = \
-                torch_npu.npu_format_cast(attn_metadata.attn_mask.contiguous(),
-                                          ACL_FORMAT_FRACTAL_NZ)
-            attn_metadata.seq_lens = \
-                attn_metadata.seq_lens.to(device=query.device)
-
-        # TODO:The npu_fused_infer_attention_score op is planned to
-        # be utilized in a wider range in upcoming versions.
-        num_block, block_size, _, _ = self.key_cache.shape  # type: ignore
-        key = self.key_cache.view(  # type: ignore
-            num_block, block_size, -1)
-        value = self.value_cache.view(  # type: ignore
-            num_block, block_size, -1)
-
-        output, _ = torch_npu.npu_fused_infer_attention_score(
-            query=query,
-            key=key,
-            value=value,
-            atten_mask=attn_metadata.attn_mask,
-            block_table=attn_metadata.block_tables,
-            input_layout="TND",
-            block_size=block_size,
-            actual_seq_lengths=attn_metadata.actual_seq_lengths_q,
-            actual_seq_lengths_kv=attn_metadata.seq_lens_list,
-            num_key_value_heads=self.num_kv_heads,
-            num_heads=self.num_heads,
-            scale=self.scale,
-            sparse_mode=3,
-        )
-        return output
-
     def _attention_with_nomask_and_mask(self, q: torch.Tensor,
                                         q_seqlens: List[int],
                                         k_nomask: torch.Tensor,
@@ -1464,6 +1345,31 @@ class AscendAttentionBackendImpl(AttentionImpl):
         )
         return key, value
 
+    def _forward_encode(
+        self,
+        query: torch.Tensor,
+        key: torch.Tensor,
+        value: torch.Tensor,
+        attn_metadata: AscendMetadata,
+        output: torch.Tensor,
+    ) -> torch.Tensor:
+        cum_seq_len = attn_metadata.query_start_loc[1:].tolist()
+        output = torch_npu.npu_fusion_attention(
+            query,
+            key,
+            value,
+            head_num=self.num_heads,
+            input_layout="TND",
+            scale=self.scale,
+            sparse_mode=4,
+            atten_mask=attn_metadata.attn_mask,
+            pre_tockens=attn_metadata.max_query_len,
+            next_tockens=attn_metadata.max_query_len,
+            actual_seq_qlen=cum_seq_len,
+            actual_seq_kvlen=cum_seq_len,
+        )[0]
+        return output
+
     def forward(
         self,
         layer: AttentionLayer,
@@ -1494,24 +1400,16 @@ class AscendAttentionBackendImpl(AttentionImpl):
                 "fused output quantization is not yet supported"
                 " for AscendAttentionBackendImpl")
 
-        num_tokens = query.shape[0]
-        if attn_metadata is None:
-            return output
-
-        # NOTE: Currently, we have various attention paths for different
-        # scenarios, and not all of them are in-place operations. Therefore,
-        # we need to create a separate tensor to hold the attention result.
-        # In the future, we may consolidate them into fewer paths, which will
-        # hopefully allow us to use in-place operation by default.
-        intermediate_output: torch.Tensor
-
         assert layer._k_scale_float == 1.0 and layer._v_scale_float == 1.0
-        attn_type = self.attn_type
-        if attn_type != AttentionType.DECODER and attn_type != AttentionType.ENCODER_ONLY:
+        if self.attn_type != AttentionType.DECODER and self.attn_type != AttentionType.ENCODER_ONLY:
             raise NotImplementedError("Encoder/decoder cross-attention "
                                       "are not implemented for "
                                       "PallasAttentionBackendImpl")
+        num_tokens = query.shape[0]
+        if attn_metadata is None:
+            return output.fill_(0)
+
         num_decode_tokens = attn_metadata.num_decode_tokens
         has_decode = attn_metadata.num_decodes > 0
         has_prefill = attn_metadata.num_prefills > 0
@@ -1558,48 +1456,25 @@ class AscendAttentionBackendImpl(AttentionImpl):
         forward_context: ForwardContext = get_forward_context()
         if not forward_context.capturing:
             if self.pcp_size * self.dcp_size > 1:
-                intermediate_output = self._forward_pcp_dcp(
-                    query, key, value, kv_cache, attn_metadata, output)
-            elif attn_type == AttentionType.ENCODER_ONLY:
-                # TODO(zzzwwjj): Deal with this `cum_seq_len` more elegantly.
-                cum_seq_len = attn_metadata.query_start_loc[1:].tolist()
-                intermediate_output = torch_npu.npu_fusion_attention(
-                    query,
-                    key,
-                    value,
-                    head_num=self.num_heads,
-                    input_layout="TND",
-                    scale=self.scale,
-                    sparse_mode=4,
-                    atten_mask=attn_metadata.attn_mask,
-                    pre_tockens=attn_metadata.max_query_len,
-                    next_tockens=attn_metadata.max_query_len,
-                    actual_seq_qlen=cum_seq_len,
-                    actual_seq_kvlen=cum_seq_len,
-                )[0]
-            # V0-Style scheduler situation.
-            elif attn_metadata.attn_state == AscendAttentionState.PrefillNoCache:
-                intermediate_output = self._forward_prefill_no_cache(
-                    query, key, value, attn_metadata, output, num_tokens)
-            elif attn_metadata.attn_state == \
-                    AscendAttentionState.PrefillCacheHit:
-                intermediate_output = self._forward_prefill_cache_hit(
-                    query, attn_metadata, output)
-            elif attn_metadata.attn_state == AscendAttentionState.DecodeOnly:
-                intermediate_output = self._forward_decode_only(
-                    query, attn_metadata, output)
-            # Normal V1 situation.
+                attn_output = self._forward_pcp_dcp(query, key, value,
+                                                    kv_cache, attn_metadata,
+                                                    output)
+                output[:num_tokens] = attn_output[:num_tokens]
+                return output
+            if self.attn_type == AttentionType.ENCODER_ONLY:
+                attn_output = self._forward_encode(query, key, value,
+                                                   attn_metadata, output)
+                output[:num_tokens] = attn_output[:num_tokens]
+                return output
+            if attn_metadata.attn_state == AscendAttentionState.DecodeOnly:
+                output = self._forward_decode_only(query, attn_metadata,
+                                                   output)
             else:
-                # npu_fused_infer_attention_score does not support cases
-                # where query.shape[0] != attn_metadata.query_start_loc[-1].
-                # Thus we need unpad it here.
-                num_tokens = attn_metadata.query_start_loc[-1]
-                query = query[:num_tokens]
-                intermediate_output = self._forward_v1_style(
-                    query, attn_metadata, output)
+                output = self._forward_prefill(query, key, value,
+                                               attn_metadata, output)
         else:
-            intermediate_output, num_tokens = self.full_graph_attention(
+            attn_output, num_tokens = self.full_graph_attention(
                 query, key, value, kv_cache, attn_metadata, output)
-
-        output[:num_tokens] = intermediate_output[:num_tokens]
+            output[:num_tokens] = attn_output[:num_tokens]
         return output
diff --git a/vllm_ascend/worker/model_runner_v1.py b/vllm_ascend/worker/model_runner_v1.py
index f5c3bb35..37fb4381 100644
--- a/vllm_ascend/worker/model_runner_v1.py
+++ b/vllm_ascend/worker/model_runner_v1.py
@@ -979,25 +979,21 @@ class NPUModelRunner(LoRAModelRunnerMixin):
         # dcp situation.
         if self.dcp_size > 1:
             return self.attn_mask_builder.get_splitfuse_attn_mask()
+        if self.vllm_config.model_config.use_mla:
+            return None
         # Pooling situation.
         if self.model_config.runner_type == "pooling" and self.model_config.pooler_config.pooling_type == "CLS":
             return self.attn_mask_builder.get_pooling_mask(self.device)
-        # Chunk Prefill situation.
-        elif attn_state == AscendAttentionState.ChunkedPrefill and not self.vllm_config.model_config.use_mla and not self.use_sparse:
+        # fia prefill situation.
+        if attn_state in [
+                AscendAttentionState.PrefillNoCache,
+                AscendAttentionState.PrefillCacheHit,
+                AscendAttentionState.ChunkedPrefill
+        ]:
             return self.attn_mask_builder.get_splitfuse_attn_mask()
-        # Prefill without cache situation.
-        elif attn_state == AscendAttentionState.PrefillNoCache:
-            max_seq_len = max(seq_lens.max().item(), 0)
-            return self.attn_mask_builder.get_attn_mask(
-                max_seq_len, self.dtype, self.device)
-        # Prefill with cache hit.
-        elif attn_state == AscendAttentionState.PrefillCacheHit:
-            return self.attn_mask_builder.get_splitfuse_attn_mask().to(
-                torch.bool)
         # Decode-only situation.
-        else:
-            return None
+        return None
 
     def _make_fia_attention_mask(self) -> torch.Tensor:
         # pcp situation.
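
Note on the refactor above: `_forward_prefill` now funnels PrefillNoCache, PrefillCacheHit and ChunkedPrefill through a single `torch_npu.npu_fused_infer_attention_score` call. The three states differ only in which block table, block size and KV sequence lengths are passed, and the query is unpadded to `actual_seq_lengths_q[-1]` first. Below is a minimal, NPU-free Python sketch that restates that selection table; `PrefillArgs` and `select_fia_args` are hypothetical names introduced only for illustration and do not exist in the patch.

from dataclasses import dataclass
from enum import Enum
from typing import Any, List, Optional


class AscendAttentionState(Enum):
    PrefillNoCache = 0
    PrefillCacheHit = 1
    ChunkedPrefill = 2


@dataclass
class PrefillArgs:
    # A None block_table means K/V come from the batch, not the paged cache.
    block_table: Optional[Any]
    block_size: int
    actual_seq_lengths_kv: List[int]


def select_fia_args(attn_state: AscendAttentionState,
                    actual_seq_lengths_q: List[int],
                    seq_lens_list: List[int], block_tables: Any,
                    kv_cache_block_size: int, num_reqs: int) -> PrefillArgs:
    if attn_state == AscendAttentionState.PrefillNoCache:
        # K/V come straight from the current batch: no paging, and the KV
        # lengths equal the cumulative query lengths.
        return PrefillArgs(None, 128, actual_seq_lengths_q)
    if attn_state == AscendAttentionState.PrefillCacheHit:
        # Paged K/V; only the block-table rows of the current batch are used.
        return PrefillArgs(block_tables[:num_reqs], kv_cache_block_size,
                           seq_lens_list)
    # ChunkedPrefill: paged K/V with the full block table.
    return PrefillArgs(block_tables, kv_cache_block_size, seq_lens_list)


args = select_fia_args(AscendAttentionState.PrefillNoCache,
                       actual_seq_lengths_q=[3, 7, 10],
                       seq_lens_list=[5, 9, 12],
                       block_tables=[[0, 1], [2, 3], [4, 5]],
                       kv_cache_block_size=128,
                       num_reqs=3)
# For PrefillNoCache, KV lengths mirror the query prefix sums and no
# block table is passed.
assert args.block_table is None and args.actual_seq_lengths_kv == [3, 7, 10]

The model_runner change matches this consolidation: all three prefill states now receive the same splitfuse mask that the fused op consumes with `sparse_mode=3`, and MLA models return no mask at all.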