[Refactor] 4/N Distinguish the branches based on the applicable scenarios of PA and FIA Ops. (#5081)
RFC: https://github.com/vllm-project/vllm-ascend/issues/4629
Reason:
We distinguish the code branches based on the scenarios in which
pagedAttention (PA) and fusedInferAttention (FIA) apply, which makes the
code clearer. It also eases the follow-up work of iterating on
sliding_window and sinks, and of removing the PA ops once FIA is ready.
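For illustration, here is a minimal sketch of the intended scenario-based dispatch. The `AttnScenario` enum and the decision rule below are assumptions made for this description only; the actual helper touched by this PR is `using_paged_attention` in `vllm_ascend.attention.attention_v1`, and its logic may differ.

```python
# Illustrative sketch only -- not the code added by this PR.
from enum import Enum, auto


class AttnScenario(Enum):
    # Hypothetical stand-in for the backend's attention-state enum.
    DECODE_ONLY = auto()        # pure decode batch
    PREFILL_CACHE_HIT = auto()  # prefill that hits the KV cache
    SPEC_DECODE = auto()        # speculative decode / mismatched seq lens


def use_paged_attention(scenario: AttnScenario, seq_lens_match: bool) -> bool:
    """Assumed rule: plain decode with consistent sequence lengths stays on
    the PagedAttention (PA) kernel; every other scenario is routed to the
    FusedInferAttention (FIA) kernel, so PA can be dropped once FIA also
    covers this decode case."""
    return scenario is AttnScenario.DECODE_ONLY and seq_lens_match
```

With this split in place, the tests below assert which kernel each scenario hits (for example, test_forward_paged_attention mocks using_paged_attention to return True and expects only the PA op to be called).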
Todo:
- remove PA ops after FIA is ready
- add sliding_window and sinks ops for gpt_oss
- replace FIA with FIA_v2
- vLLM version: v0.12.0
- vLLM main: ad32e3e19c
---------
Signed-off-by: weijinqian_v1 <weijinqian@huawei.com>
Co-authored-by: weijinqian_v1 <weijinqian@huawei.com>
@@ -236,9 +236,9 @@ class TestAscendAttentionBackendImpl(TestBase):
     @patch('torch_npu._npu_reshape_and_cache')
     @patch('torch_npu.npu_fused_infer_attention_score')
     @patch('vllm_ascend.attention.attention_v1.get_forward_context')
-    def test_forward_prefill(self, mock_get_forward_context,
-                             mock_npu_fused_infer_attention_score,
-                             mock_npu_reshape_and_cache):
+    def test_forward_fused_infer_attention(
+            self, mock_get_forward_context,
+            mock_npu_fused_infer_attention_score, mock_npu_reshape_and_cache):
         """Test forward pass in PrefillCacheHit state"""
         query = torch.randn(10, 8, 64)
         key = torch.randn(10, 8, 64)
@@ -268,28 +268,31 @@ class TestAscendAttentionBackendImpl(TestBase):
         mock_npu_fused_infer_attention_score.assert_called_once()
         assert output.shape == (10, 8, 64)

+    @patch('vllm_ascend.attention.attention_v1.using_paged_attention')
     @patch('torch_npu._npu_paged_attention')
     @patch('torch_npu._npu_reshape_and_cache')
     @patch('vllm_ascend.attention.attention_v1.get_forward_context')
-    def test_forward_decode_only(self, mock_get_forward_context,
-                                 mock_npu_reshape_and_cache,
-                                 mock_paged_attention):
+    def test_forward_paged_attention(self, mock_get_forward_context,
+                                     mock_npu_reshape_and_cache,
+                                     mock_paged_attention,
+                                     mock_using_paged_attention):
         """Test forward pass in DecodeOnly state"""
-        query = torch.randn(10, 8 * 64)
-        key = torch.randn(10, 8 * 64)
-        value = torch.randn(10, 8 * 64)
+        query = torch.randn(4, 8 * 64)
+        key = torch.randn(4, 8 * 64)
+        value = torch.randn(4, 8 * 64)
         kv_cache = torch.empty(2, 5, 128, 8, 64)
         output = torch.empty_like(query)

         metadata = self.attn_metadata
         metadata.attn_state = AscendAttentionState.DecodeOnly
-        metadata.seq_lens = torch.tensor([10])
+        metadata.seq_lens = torch.tensor([4])
         metadata.block_tables = torch.zeros(1, 5, dtype=torch.long)
-        metadata.num_actual_tokens = 10
-        metadata.slot_mapping = torch.zeros(10, dtype=torch.long)
-        metadata.num_decodes = 10
+        metadata.num_actual_tokens = 4
+        metadata.slot_mapping = torch.zeros(4, dtype=torch.long)
+        metadata.num_decodes = 4
         metadata.num_prefills = 0
         layer = self.layer_no_quant
+        mock_using_paged_attention.return_value = True

         mock_get_forward_context.return_value = MagicMock(capturing=False)

@@ -297,7 +300,7 @@ class TestAscendAttentionBackendImpl(TestBase):
                                    metadata, output)

         mock_paged_attention.assert_called_once()
-        assert output.shape == (10, 8 * 64)
+        assert output.shape == (4, 8 * 64)

     @patch('vllm_ascend.attention.attention_v1.get_forward_context')
     @patch('torch_npu.npu_fused_infer_attention_score')
@@ -339,9 +342,9 @@ class TestAscendAttentionBackendImpl(TestBase):
             self, mock_npu_reshape_and_cache, mock_fused_infer_attention_score,
             mock_paged_attention, mock_get_forward_context):
         """Test forward pass in DecodeOnly state when seq_len_mismatch"""
-        query = torch.randn(10, 8 * 64)
-        key = torch.randn(10, 8 * 64)
-        value = torch.randn(10, 8 * 64)
+        query = torch.randn(10, 8, 64)
+        key = torch.randn(10, 8, 64)
+        value = torch.randn(10, 8, 64)
         kv_cache = torch.empty(2, 5, 128, 8, 64)
         output = torch.empty_like(query)

@@ -354,6 +357,7 @@ class TestAscendAttentionBackendImpl(TestBase):
         layer = self.layer_no_quant
         metadata.num_decodes = 10
         metadata.num_prefills = 0
+        metadata.actual_seq_lengths_q = [10]

         mock_get_forward_context.return_value = MagicMock(capturing=False)

@@ -363,10 +367,10 @@ class TestAscendAttentionBackendImpl(TestBase):
         output = self.impl_swa.forward(layer, query, key, value, kv_cache,
                                        metadata, output)

-        mock_paged_attention.assert_called_once()
-        mock_fused_infer_attention_score.assert_not_called()
+        mock_paged_attention.assert_not_called()
+        mock_fused_infer_attention_score.assert_called_once()

-        assert output.shape == (10, 8 * 64)
+        assert output.shape == (10, 8, 64)

     @patch('torch_npu._npu_reshape_and_cache')
     def test_forward_raise_error(self, mock_paged_attention):