[Refactor] 4/N Distinguish the branches based on the applicable scenarios of PA and FIA Ops. (#5081)

RFC: https://github.com/vllm-project/vllm-ascend/issues/4629

Reason:

We split the attention forward branches according to the scenarios where the
pagedAttention (PA) and fusedInferAttention (FIA) ops apply, which makes the
code clearer; a rough sketch of the intended dispatch follows below.

It also makes the upcoming iterations easier: adding sliding_window and sinks
support, and removing the PA ops once FIA is ready.
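
As a rough sketch of the intended dispatch (not the actual vllm_ascend
implementation): the tests below patch a using_paged_attention helper, so the
idea is that each forward call goes through one decision point that routes it
to either the PA op or the FIA op. The helper's real signature and conditions
are not shown in this diff, so the logic below is an assumption for
illustration only.

def using_paged_attention(attn_state, seq_lens_match):
    """Hypothetical helper: decide whether the PA op should handle this call.

    Assumes (for illustration) that decode-only batches whose sequence
    lengths match stay on pagedAttention until FIA fully replaces it.
    """
    return attn_state == "DecodeOnly" and seq_lens_match

def forward_dispatch(attn_state, seq_lens_match):
    # Single branch point mirroring the refactor's intent: one clear PA path,
    # one clear FIA path, so PA can be dropped wholesale once FIA is ready.
    if using_paged_attention(attn_state, seq_lens_match):
        return "torch_npu._npu_paged_attention"
    return "torch_npu.npu_fused_infer_attention_score"

print(forward_dispatch("DecodeOnly", True))        # paged attention path
print(forward_dispatch("PrefillCacheHit", False))  # fused infer attention path

With the two paths separated like this, removing the PA branch after FIA_v2
lands becomes a localized change rather than an edit scattered across mixed
branches.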

Todo:

- remove PA ops after FIA is ready
- add sliding_window and sink ops for gpt_oss
- replace FIA with FIA_v2
- vLLM version: v0.12.0
- vLLM main: ad32e3e19c

---------

Signed-off-by: weijinqian_v1 <weijinqian@huawei.com>
Co-authored-by: weijinqian_v1 <weijinqian@huawei.com>
weijinqian0
2025-12-17 23:14:02 +08:00
committed by GitHub
parent 7671ce1bf1
commit 98e6e57622
3 changed files with 117 additions and 154 deletions


@@ -236,9 +236,9 @@ class TestAscendAttentionBackendImpl(TestBase):
@patch('torch_npu._npu_reshape_and_cache')
@patch('torch_npu.npu_fused_infer_attention_score')
@patch('vllm_ascend.attention.attention_v1.get_forward_context')
def test_forward_prefill(self, mock_get_forward_context,
mock_npu_fused_infer_attention_score,
mock_npu_reshape_and_cache):
def test_forward_fused_infer_attention(
self, mock_get_forward_context,
mock_npu_fused_infer_attention_score, mock_npu_reshape_and_cache):
"""Test forward pass in PrefillCacheHit state"""
query = torch.randn(10, 8, 64)
key = torch.randn(10, 8, 64)
@@ -268,28 +268,31 @@ class TestAscendAttentionBackendImpl(TestBase):
mock_npu_fused_infer_attention_score.assert_called_once()
assert output.shape == (10, 8, 64)
@patch('vllm_ascend.attention.attention_v1.using_paged_attention')
@patch('torch_npu._npu_paged_attention')
@patch('torch_npu._npu_reshape_and_cache')
@patch('vllm_ascend.attention.attention_v1.get_forward_context')
def test_forward_decode_only(self, mock_get_forward_context,
mock_npu_reshape_and_cache,
mock_paged_attention):
def test_forward_paged_attention(self, mock_get_forward_context,
mock_npu_reshape_and_cache,
mock_paged_attention,
mock_using_paged_attention):
"""Test forward pass in DecodeOnly state"""
query = torch.randn(10, 8 * 64)
key = torch.randn(10, 8 * 64)
value = torch.randn(10, 8 * 64)
query = torch.randn(4, 8 * 64)
key = torch.randn(4, 8 * 64)
value = torch.randn(4, 8 * 64)
kv_cache = torch.empty(2, 5, 128, 8, 64)
output = torch.empty_like(query)
metadata = self.attn_metadata
metadata.attn_state = AscendAttentionState.DecodeOnly
metadata.seq_lens = torch.tensor([10])
metadata.seq_lens = torch.tensor([4])
metadata.block_tables = torch.zeros(1, 5, dtype=torch.long)
metadata.num_actual_tokens = 10
metadata.slot_mapping = torch.zeros(10, dtype=torch.long)
metadata.num_decodes = 10
metadata.num_actual_tokens = 4
metadata.slot_mapping = torch.zeros(4, dtype=torch.long)
metadata.num_decodes = 4
metadata.num_prefills = 0
layer = self.layer_no_quant
mock_using_paged_attention.return_value = True
mock_get_forward_context.return_value = MagicMock(capturing=False)
@@ -297,7 +300,7 @@ class TestAscendAttentionBackendImpl(TestBase):
metadata, output)
mock_paged_attention.assert_called_once()
assert output.shape == (10, 8 * 64)
assert output.shape == (4, 8 * 64)
@patch('vllm_ascend.attention.attention_v1.get_forward_context')
@patch('torch_npu.npu_fused_infer_attention_score')
@@ -339,9 +342,9 @@ class TestAscendAttentionBackendImpl(TestBase):
self, mock_npu_reshape_and_cache, mock_fused_infer_attention_score,
mock_paged_attention, mock_get_forward_context):
"""Test forward pass in DecodeOnly state when seq)len_mismatch"""
query = torch.randn(10, 8 * 64)
key = torch.randn(10, 8 * 64)
value = torch.randn(10, 8 * 64)
query = torch.randn(10, 8, 64)
key = torch.randn(10, 8, 64)
value = torch.randn(10, 8, 64)
kv_cache = torch.empty(2, 5, 128, 8, 64)
output = torch.empty_like(query)
@@ -354,6 +357,7 @@ class TestAscendAttentionBackendImpl(TestBase):
layer = self.layer_no_quant
metadata.num_decodes = 10
metadata.num_prefills = 0
metadata.actual_seq_lengths_q = [10]
mock_get_forward_context.return_value = MagicMock(capturing=False)
@@ -363,10 +367,10 @@ class TestAscendAttentionBackendImpl(TestBase):
output = self.impl_swa.forward(layer, query, key, value, kv_cache,
metadata, output)
mock_paged_attention.assert_called_once()
mock_fused_infer_attention_score.assert_not_called()
mock_paged_attention.assert_not_called()
mock_fused_infer_attention_score.assert_called_once()
assert output.shape == (10, 8 * 64)
assert output.shape == (10, 8, 64)
@patch('torch_npu._npu_reshape_and_cache')
def test_forward_raise_error(self, mock_paged_attention):