diff --git a/tests/ut/attention/test_attention_v1.py b/tests/ut/attention/test_attention_v1.py
index e2c83e41..cc3ae851 100644
--- a/tests/ut/attention/test_attention_v1.py
+++ b/tests/ut/attention/test_attention_v1.py
@@ -236,9 +236,9 @@ class TestAscendAttentionBackendImpl(TestBase):
     @patch('torch_npu._npu_reshape_and_cache')
     @patch('torch_npu.npu_fused_infer_attention_score')
     @patch('vllm_ascend.attention.attention_v1.get_forward_context')
-    def test_forward_prefill(self, mock_get_forward_context,
-                             mock_npu_fused_infer_attention_score,
-                             mock_npu_reshape_and_cache):
+    def test_forward_fused_infer_attention(
+            self, mock_get_forward_context,
+            mock_npu_fused_infer_attention_score, mock_npu_reshape_and_cache):
         """Test forward pass in PrefillCacheHit state"""
         query = torch.randn(10, 8, 64)
         key = torch.randn(10, 8, 64)
         value = torch.randn(10, 8, 64)
@@ -268,28 +268,31 @@ class TestAscendAttentionBackendImpl(TestBase):
         mock_npu_fused_infer_attention_score.assert_called_once()
         assert output.shape == (10, 8, 64)
 
+    @patch('vllm_ascend.attention.attention_v1.using_paged_attention')
     @patch('torch_npu._npu_paged_attention')
     @patch('torch_npu._npu_reshape_and_cache')
     @patch('vllm_ascend.attention.attention_v1.get_forward_context')
-    def test_forward_decode_only(self, mock_get_forward_context,
-                                 mock_npu_reshape_and_cache,
-                                 mock_paged_attention):
+    def test_forward_paged_attention(self, mock_get_forward_context,
+                                     mock_npu_reshape_and_cache,
+                                     mock_paged_attention,
+                                     mock_using_paged_attention):
         """Test forward pass in DecodeOnly state"""
-        query = torch.randn(10, 8 * 64)
-        key = torch.randn(10, 8 * 64)
-        value = torch.randn(10, 8 * 64)
+        query = torch.randn(4, 8 * 64)
+        key = torch.randn(4, 8 * 64)
+        value = torch.randn(4, 8 * 64)
         kv_cache = torch.empty(2, 5, 128, 8, 64)
         output = torch.empty_like(query)
         metadata = self.attn_metadata
         metadata.attn_state = AscendAttentionState.DecodeOnly
-        metadata.seq_lens = torch.tensor([10])
+        metadata.seq_lens = torch.tensor([4])
         metadata.block_tables = torch.zeros(1, 5, dtype=torch.long)
-        metadata.num_actual_tokens = 10
-        metadata.slot_mapping = torch.zeros(10, dtype=torch.long)
-        metadata.num_decodes = 10
+        metadata.num_actual_tokens = 4
+        metadata.slot_mapping = torch.zeros(4, dtype=torch.long)
+        metadata.num_decodes = 4
         metadata.num_prefills = 0
         layer = self.layer_no_quant
+        mock_using_paged_attention.return_value = True
 
         mock_get_forward_context.return_value = MagicMock(capturing=False)
 
@@ -297,7 +300,7 @@ class TestAscendAttentionBackendImpl(TestBase):
                                    metadata, output)
 
         mock_paged_attention.assert_called_once()
-        assert output.shape == (10, 8 * 64)
+        assert output.shape == (4, 8 * 64)
 
     @patch('vllm_ascend.attention.attention_v1.get_forward_context')
     @patch('torch_npu.npu_fused_infer_attention_score')
@@ -339,9 +342,9 @@ class TestAscendAttentionBackendImpl(TestBase):
         self, mock_npu_reshape_and_cache, mock_fused_infer_attention_score,
         mock_paged_attention, mock_get_forward_context):
         """Test forward pass in DecodeOnly state when seq_len mismatches"""
-        query = torch.randn(10, 8 * 64)
-        key = torch.randn(10, 8 * 64)
-        value = torch.randn(10, 8 * 64)
+        query = torch.randn(10, 8, 64)
+        key = torch.randn(10, 8, 64)
+        value = torch.randn(10, 8, 64)
         kv_cache = torch.empty(2, 5, 128, 8, 64)
         output = torch.empty_like(query)
 
@@ -354,6 +357,7 @@
         layer = self.layer_no_quant
         metadata.num_decodes = 10
         metadata.num_prefills = 0
+        metadata.actual_seq_lengths_q = [10]
 
         mock_get_forward_context.return_value = MagicMock(capturing=False)
 
@@ -363,10 +367,10 @@ class TestAscendAttentionBackendImpl(TestBase):
         output = self.impl_swa.forward(layer, query, key, value, kv_cache,
                                        metadata, output)
 
-        mock_paged_attention.assert_called_once()
-        mock_fused_infer_attention_score.assert_not_called()
+        mock_paged_attention.assert_not_called()
+        mock_fused_infer_attention_score.assert_called_once()
 
-        assert output.shape == (10, 8 * 64)
+        assert output.shape == (10, 8, 64)
 
     @patch('torch_npu._npu_reshape_and_cache')
     def test_forward_raise_error(self, mock_paged_attention):
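A note on the reordered mock parameters in the tests above: stacked `@patch` decorators are applied bottom-up, so the mocks arrive in the reverse of their decorator order, which is why the newly added `using_paged_attention` patch, sitting on top, binds to the last parameter. A minimal self-contained sketch of that rule (the patched `os` functions are arbitrary stand-ins, not part of this PR):

```python
from unittest import mock

# Decorators apply bottom-up: the @patch closest to the function supplies
# the first mock argument, and the topmost @patch supplies the last one.
@mock.patch('os.getcwd')  # topmost  -> last mock argument
@mock.patch('os.getpid')  # bottom   -> first mock argument
def demo(mock_getpid, mock_getcwd):
    import os
    mock_getpid.return_value = 1234
    mock_getcwd.return_value = '/tmp'
    assert os.getpid() == 1234
    assert os.getcwd() == '/tmp'

demo()
```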
diff --git a/vllm_ascend/attention/attention_v1.py b/vllm_ascend/attention/attention_v1.py
index 875dd432..003da21f 100644
--- a/vllm_ascend/attention/attention_v1.py
+++ b/vllm_ascend/attention/attention_v1.py
@@ -386,41 +386,11 @@ class AscendAttentionBackendImpl(AttentionImpl):
         self.key_cache = None
         self.value_cache = None
 
-    def full_graph_attention(self, query: torch.Tensor, key: torch.Tensor,
-                             value: torch.Tensor,
-                             attn_metadata: AscendMetadata,
-                             output: torch.Tensor) -> torch.Tensor:
-        if attn_metadata.attn_state == AscendAttentionState.PrefillNoCache:
-            block_size = 128
-            block_table = None
-            actual_seq_lengths_kv = attn_metadata.query_start_loc_list
-        elif attn_metadata.attn_state == \
-                AscendAttentionState.PrefillCacheHit:
-            batch_size = attn_metadata.query_lens.shape[0]
-            block_table = attn_metadata.block_tables[:batch_size, :]
-            num_block, block_size, _, _ = self.key_cache.shape  # type: ignore
-            key = self.key_cache.view(  # type: ignore
-                num_block, block_size, -1)
-            value = self.value_cache.view(  # type: ignore
-                num_block, block_size, -1)
-            actual_seq_lengths_kv = attn_metadata.seq_lens_list
-        elif attn_metadata.attn_state == AscendAttentionState.DecodeOnly:
-            num_block, block_size, _, _ = self.key_cache.shape  # type: ignore
-            key = self.key_cache.view(  # type: ignore
-                num_block, block_size, -1)
-            value = self.value_cache.view(  # type: ignore
-                num_block, block_size, -1)
-            block_table = attn_metadata.block_tables
-            actual_seq_lengths_kv = attn_metadata.seq_lens_list
-        # Normal V1 situation.
-        else:
-            num_block, block_size, _, _ = self.key_cache.shape  # type: ignore
-            key = self.key_cache.view(  # type: ignore
-                num_block, block_size, -1)
-            value = self.value_cache.view(  # type: ignore
-                num_block, block_size, -1)
-            block_table = attn_metadata.block_tables
-            actual_seq_lengths_kv = attn_metadata.seq_lens_list
+    def full_graph_fia(self, query: torch.Tensor, key: torch.Tensor,
+                       value: torch.Tensor, attn_metadata: AscendMetadata,
+                       output: torch.Tensor) -> torch.Tensor:
+        key, value, block_size, block_table, actual_seq_lengths_kv \
+            = self._get_fia_params(key, value, attn_metadata)
 
         num_tokens = attn_metadata.query_start_loc_list[-1]
         graph_params = get_graph_params()
@@ -489,7 +459,7 @@ class AscendAttentionBackendImpl(AttentionImpl):
             graph_params.handles[num_tokens].append(handle)
         return output, num_tokens
 
-    def full_graph_attention_with_pa(
+    def full_graph_pa(
         self,
         query: torch.Tensor,
         attn_metadata: AscendMetadata,
@@ -550,13 +520,13 @@ class AscendAttentionBackendImpl(AttentionImpl):
             graph_params.handles[num_tokens].append(handle)
         return output
 
-    def _forward_prefill(self, query: torch.Tensor, key: torch.Tensor,
-                         value: torch.Tensor, attn_metadata: AscendMetadata,
-                         output: torch.Tensor):
+    def _get_fia_params(self, key: torch.Tensor, value: torch.Tensor,
+                        attn_metadata: AscendMetadata):
+
         if attn_metadata.attn_state == AscendAttentionState.PrefillNoCache:
             block_size = 128
             block_table = None
-            actual_seq_lengths_kv = attn_metadata.actual_seq_lengths_q
+            actual_seq_lengths_kv = attn_metadata.query_start_loc_list
         elif attn_metadata.attn_state == \
                 AscendAttentionState.PrefillCacheHit:
             batch_size = attn_metadata.query_lens.shape[0]
@@ -567,7 +537,15 @@ class AscendAttentionBackendImpl(AttentionImpl):
             value = self.value_cache.view(  # type: ignore
                 num_block, block_size, -1)
             actual_seq_lengths_kv = attn_metadata.seq_lens_list
-        # chunked_prefill.
+        elif attn_metadata.attn_state == AscendAttentionState.DecodeOnly:
+            num_block, block_size, _, _ = self.key_cache.shape  # type: ignore
+            key = self.key_cache.view(  # type: ignore
+                num_block, block_size, -1)
+            value = self.value_cache.view(  # type: ignore
+                num_block, block_size, -1)
+            block_table = attn_metadata.block_tables
+            actual_seq_lengths_kv = attn_metadata.seq_lens_list
+        # chunked prefill.
         else:
             num_block, block_size, _, _ = self.key_cache.shape  # type: ignore
             key = self.key_cache.view(  # type: ignore
@@ -576,12 +554,57 @@ class AscendAttentionBackendImpl(AttentionImpl):
                 num_block, block_size, -1)
             value = self.value_cache.view(  # type: ignore
                 num_block, block_size, -1)
             block_table = attn_metadata.block_tables
             actual_seq_lengths_kv = attn_metadata.seq_lens_list
+        return key, value, block_size, block_table, actual_seq_lengths_kv
+
+    def _forward_fia_slidingwindow(self, query: torch.Tensor,
+                                   attn_metadata: AscendMetadata,
+                                   output: torch.Tensor):
+        batch_size = attn_metadata.seq_lens.shape[0]
+        block_size = 128
+        query = query.view(batch_size, 1, self.num_heads * self.head_size)
+        key = self.key_cache
+        value = self.value_cache
+        if self.key_cache is not None and self.value_cache is not None:
+            block_size = self.key_cache.shape[1]
+            key = self.key_cache.flatten(2, 3).contiguous()
+            value = self.value_cache.flatten(2, 3).contiguous()
+
+        output, _ = torch_npu.npu_fused_infer_attention_score(
+            query,
+            key,
+            value,
+            num_heads=self.num_heads,
+            num_key_value_heads=self.num_kv_heads,
+            input_layout="BSH",
+            block_size=block_size,
+            pre_tokens=self.sliding_window,
+            scale=self.scale,
+            block_table=attn_metadata.block_tables,
+            actual_seq_lengths=[1] * len(attn_metadata.seq_lens),
+            actual_seq_lengths_kv=attn_metadata.seq_lens)
+
+        output = output.view(batch_size, self.num_heads, self.head_size)
+        return output
+
+    def forward_fused_infer_attention(self, query: torch.Tensor,
+                                      key: torch.Tensor, value: torch.Tensor,
+                                      attn_metadata: AscendMetadata,
+                                      output: torch.Tensor):
+        forward_context: ForwardContext = get_forward_context()
+        if forward_context.capturing:
+            attn_output, num_tokens = self.full_graph_fia(
+                query, key, value, attn_metadata, output)
+            output[:num_tokens] = attn_output[:num_tokens]
+            return output
+        if (attn_metadata.attn_state == AscendAttentionState.DecodeOnly
+                and self.sliding_window is not None
+                and attn_metadata.seq_lens.shape[0] == query.size(0)):
+            return self._forward_fia_slidingwindow(query, attn_metadata,
+                                                   output)
+        key, value, block_size, block_table, actual_seq_lengths_kv \
+            = self._get_fia_params(key, value, attn_metadata)
 
         num_tokens = attn_metadata.actual_seq_lengths_q[-1]
         query = query[:num_tokens]
-        # Prepare tensors for attention output
-        # TODO: Refactor this to step-level instead of layer-level
-        # Get workspace from cache or calculate it if not present.
         attn_output, _ = torch_npu.npu_fused_infer_attention_score(
             query=query,
@@ -604,83 +627,24 @@ class AscendAttentionBackendImpl(AttentionImpl):
         output[:num_tokens] = attn_output[:num_tokens]
         return output
 
-    def _forward_decode_only_ascend91095(
-        self,
-        query: torch.Tensor,
-        attn_metadata: AscendMetadata,
-        output: torch.Tensor,
-    ) -> torch.Tensor:
-        batch_size = attn_metadata.query_lens.shape[0]
-        num_block, block_size, _, _ = self.key_cache.shape  # type: ignore
-        key = self.key_cache.view(  # type: ignore
-            num_block, block_size, -1)
-        value = self.value_cache.view(  # type: ignore
-            num_block, block_size, -1)
-        actual_seq_lengths_kv = attn_metadata.seq_lens_list
-
-        attn_output, _ = torch_npu.npu_fused_infer_attention_score(
-            query=query,
-            key=key,
-            value=value,
-            block_table=attn_metadata.block_tables,
-            input_layout="TND",
-            block_size=block_size,
-            actual_seq_lengths=attn_metadata.actual_seq_lengths_q,
-            actual_seq_lengths_kv=actual_seq_lengths_kv,
-            num_key_value_heads=self.num_kv_heads,
-            num_heads=self.num_heads,
-            scale=self.scale,
-        )
-        output[:batch_size] = attn_output[:batch_size]
-        return output
-
-    def _forward_decode_only(
+    def forward_paged_attention(
         self,
         query: torch.Tensor,
         attn_metadata: AscendMetadata,
         output: Optional[torch.Tensor] = None,
     ) -> torch.Tensor:
-        if get_ascend_device_type() == AscendDeviceType.A5:
-            return self._forward_decode_only_ascend91095(
-                query, attn_metadata, output)
-        if self.sliding_window is not None and attn_metadata.seq_lens.shape[
-                0] == query.size(0):
-            batch_size = attn_metadata.seq_lens.shape[0]
-            block_size = 128
-            query = query.view(batch_size, 1, self.num_heads * self.head_size)
-            key = self.key_cache
-            value = self.value_cache
-            if self.key_cache is not None and self.value_cache is not None:
-                block_size = self.key_cache.shape[1]
-                key = self.key_cache.flatten(2, 3).contiguous()
-                value = self.value_cache.flatten(2, 3).contiguous()
-
-            output, _ = torch_npu.npu_fused_infer_attention_score(
-                query,
-                key,
-                value,
-                num_heads=self.num_heads,
-                num_key_value_heads=self.num_kv_heads,
-                input_layout="BSH",
-                block_size=block_size,
-                pre_tokens=self.sliding_window,
-                scale=self.scale,
-                block_table=attn_metadata.block_tables,
-                actual_seq_lengths=[1] * len(attn_metadata.seq_lens),
-                actual_seq_lengths_kv=attn_metadata.seq_lens)
-
-            output = output.view(batch_size, self.num_heads, self.head_size)
-        else:
-            torch_npu._npu_paged_attention(
-                query=query,
-                key_cache=self.key_cache,
-                value_cache=self.value_cache,
-                num_kv_heads=self.num_kv_heads,
-                num_heads=self.num_heads,
-                scale_value=self.scale,
-                block_table=attn_metadata.block_tables,
-                context_lens=attn_metadata.seq_lens,
-                out=output)
+        forward_context: ForwardContext = get_forward_context()
+        if forward_context.capturing:
+            return self.full_graph_pa(query, attn_metadata, output)
+        torch_npu._npu_paged_attention(query=query,
+                                       key_cache=self.key_cache,
+                                       value_cache=self.value_cache,
+                                       num_kv_heads=self.num_kv_heads,
+                                       num_heads=self.num_heads,
+                                       scale_value=self.scale,
+                                       block_table=attn_metadata.block_tables,
+                                       context_lens=attn_metadata.seq_lens,
+                                       out=output)
         return output
 
     def _forward_encoder_attention(self, query: torch.Tensor,
@@ -757,23 +721,14 @@ class AscendAttentionBackendImpl(AttentionImpl):
         attn_metadata: AscendMetadata,
         output: torch.Tensor,
     ):
-        forward_context: ForwardContext = get_forward_context()
-        if not forward_context.capturing:
-            if attn_metadata.attn_state == AscendAttentionState.DecodeOnly:
-                output = self._forward_decode_only(query, attn_metadata,
-                                                   output)
-            else:
-                output = self._forward_prefill(query, key, value,
-                                               attn_metadata, output)
+        num_tokens = query.shape[0]
+        if (attn_metadata.attn_state == AscendAttentionState.DecodeOnly
+                and using_paged_attention(num_tokens)
+                and self.sliding_window is None):
+            output = self.forward_paged_attention(query, attn_metadata, output)
         else:
-            num_tokens = query.shape[0]
-            if using_paged_attention(num_tokens):
-                output = self.full_graph_attention_with_pa(
-                    query, attn_metadata, output)
-            else:
-                attn_output, num_tokens = self.full_graph_attention(
-                    query, key, value, attn_metadata, output)
-                output[:num_tokens] = attn_output[:num_tokens]
+            output = self.forward_fused_infer_attention(
+                query, key, value, attn_metadata, output)
 
         return output
 
diff --git a/vllm_ascend/attention/utils.py b/vllm_ascend/attention/utils.py
index 454190a5..08a17fbc 100644
--- a/vllm_ascend/attention/utils.py
+++ b/vllm_ascend/attention/utils.py
@@ -10,7 +10,8 @@ from vllm.distributed.kv_transfer import (get_kv_transfer_group,
                                           is_v1_kv_transfer_group)
 from vllm.forward_context import ForwardContext, get_forward_context
 
-from vllm_ascend.utils import get_ascend_config
+from vllm_ascend.utils import (AscendDeviceType, get_ascend_config,
+                               get_ascend_device_type)
 
 
 @lru_cache
@@ -18,8 +19,11 @@
 def using_paged_attention(runtime_shape: int) -> bool:
     vllm_config = get_current_vllm_config()
     if vllm_config.speculative_config is not None:
         return False
+    if get_ascend_device_type() == AscendDeviceType.A5:
+        return False
     from vllm.config.compilation import CUDAGraphMode
-    if vllm_config.compilation_config.cudagraph_mode != CUDAGraphMode.FULL_DECODE_ONLY:
+    cudagraph_mode = vllm_config.compilation_config.cudagraph_mode
+    if cudagraph_mode != CUDAGraphMode.FULL_DECODE_ONLY:
         return False
     return runtime_shape in get_ascend_config().pa_shape_list
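For context on the dispatch this patch ends up with: `forward` now routes a batch to `_npu_paged_attention` only when the state is `DecodeOnly`, `using_paged_attention(num_tokens)` holds, and no sliding window is configured; everything else funnels through `npu_fused_infer_attention_score`. The sketch below restates the gate as a self-contained function; `has_spec_config`, `is_a5_device`, and `pa_shapes` are illustrative stand-ins for the config lookups inside `using_paged_attention`, not real vllm-ascend APIs:

```python
from enum import Enum, auto


class CUDAGraphMode(Enum):
    # Only the mode relevant to this gate; the real enum lives in
    # vllm.config.compilation.
    NONE = auto()
    FULL_DECODE_ONLY = auto()


def paged_attention_enabled(runtime_shape: int, *, has_spec_config: bool,
                            is_a5_device: bool,
                            cudagraph_mode: CUDAGraphMode,
                            pa_shapes: set) -> bool:
    """Mirror of the checks in using_paged_attention, with the config
    lookups replaced by explicit parameters for illustration."""
    if has_spec_config:  # speculative decoding bypasses the paged path
        return False
    if is_a5_device:     # new in this patch: A5 devices never take it
        return False
    if cudagraph_mode is not CUDAGraphMode.FULL_DECODE_ONLY:
        return False
    # Only batch sizes with a pre-captured paged-attention graph qualify.
    return runtime_shape in pa_shapes


# A 4-token decode batch qualifies only when 4 is a captured shape and
# full-decode-only graph mode is active; an A5 device never qualifies.
assert paged_attention_enabled(4, has_spec_config=False, is_a5_device=False,
                               cudagraph_mode=CUDAGraphMode.FULL_DECODE_ONLY,
                               pa_shapes={1, 2, 4, 8})
assert not paged_attention_enabled(4, has_spec_config=False,
                                   is_a5_device=True,
                                   cudagraph_mode=CUDAGraphMode.FULL_DECODE_ONLY,
                                   pa_shapes={1, 2, 4, 8})
```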