[Refactor] Remove redundant attention operator branches. (#4531)

[Refactor] Remove redundant attention operator branches.

Reason:

We replace other attention ops with fused_infer_attention_score expect
decode_only state.
clean code and remove 310P support.

https://github.com/vllm-project/vllm-ascend/pull/4455


- vLLM version: v0.11.2
- vLLM main: https://github.com/vllm-project/vllm/commit/v0.11.2

---------

Signed-off-by: weijinqian_v1 <weijinqian@huawei.com>
Co-authored-by: weijinqian_v1 <weijinqian@huawei.com>
This commit is contained in:
weijinqian0
2025-12-02 09:13:26 +08:00
committed by GitHub
parent 981a14f8d5
commit b4bf01ead1
3 changed files with 119 additions and 470 deletions

View File

@@ -25,12 +25,6 @@ class TestAscendAttentionBackend(TestBase):
self.assertEqual(AscendAttentionBackend.get_builder_cls(),
AscendAttentionMetadataBuilder)
@patch('vllm_ascend.attention.attention_v1.get_ascend_device_type',
return_value=AscendDeviceType._310P)
def test_get_kv_cache_shape_310p(self, mock_soc_version):
result = AscendAttentionBackend.get_kv_cache_shape(10, 20, 30, 40)
self.assertEqual(result, (2, 10, 30 * 40 // 16, 20, 16))
@patch('vllm_ascend.utils.get_ascend_device_type',
return_value=AscendDeviceType._910_93)
def test_get_kv_cache_shape_not_310p(self, mock_soc_version):
@@ -95,76 +89,6 @@ class TestAscendAttentionMetadataBuilder(TestBase):
self.assertFalse(result)
@patch('vllm_ascend.attention.attention_v1.AscendMetadata')
@patch('torch_npu.npu_format_cast')
@patch('vllm_ascend.utils.nd_to_nz_2d')
@patch('vllm_ascend.utils.get_ascend_device_type',
return_value=AscendDeviceType._310P)
def test_build_prefill_no_cache(self, mock_soc_version, mock_nd_to_nz_2d,
mock_npu_format_cast,
mock_ascend_metadata):
common_attn_metadata = AscendCommonAttentionMetadata(
query_start_loc=torch.tensor([0, 3, 7]),
query_start_loc_cpu=torch.tensor([0, 3, 7]),
seq_lens_cpu=torch.tensor([5, 6]),
num_reqs=2,
num_actual_tokens=10,
max_query_len=5,
decode_token_per_req=torch.tensor([1, 1]),
block_table_tensor=torch.zeros((10, 10)),
slot_mapping=torch.tensor(range(20)),
actual_seq_lengths_q=torch.tensor([0, 1]),
positions=torch.tensor([10, 10]),
attn_mask=torch.ones((10, 10)),
spec_attn_mask=None,
attn_state=AscendAttentionState.PrefillNoCache,
num_computed_tokens_cpu=None,
seq_lens=None)
mock_nz_tensor = MagicMock()
mock_model = MagicMock()
mock_nd_to_nz_2d.return_value = mock_nz_tensor
mock_npu_format_cast.return_value = mock_nz_tensor
self.builder.build(1, common_attn_metadata, mock_model)
@patch('vllm_ascend.attention.attention_v1.AscendMetadata')
@patch('torch_npu.npu_format_cast')
@patch('vllm_ascend.utils.nd_to_nz_spec')
@patch('vllm_ascend.utils.get_ascend_device_type',
return_value=AscendDeviceType._310P)
@patch('vllm_ascend.attention.attention_v1.AscendAttentionState')
def test_build_chunked_prefill(self, mock_ascend_attention_state,
mock_soc_version, mock_nd_to_nz_spec,
mock_npu_format_cast, mock_ascend_metadata):
common_attn_metadata = AscendCommonAttentionMetadata(
query_start_loc=torch.tensor([0, 2, 5, 9]),
query_start_loc_cpu=torch.tensor([0, 2, 5, 9]),
seq_lens_cpu=torch.tensor([4, 5, 6]),
num_reqs=3,
num_actual_tokens=15,
max_query_len=6,
decode_token_per_req=torch.tensor([1, 1, 1]),
block_table_tensor=torch.zeros((10, 10)),
slot_mapping=torch.tensor(range(20)),
actual_seq_lengths_q=torch.tensor([0, 1, 2]),
positions=torch.tensor([10, 10]),
attn_mask=torch.ones((15, 15)),
spec_attn_mask=None,
attn_state=AscendAttentionState.ChunkedPrefill,
num_computed_tokens_cpu=None,
seq_lens=None)
mock_ascend_attention_state = MagicMock()
mock_ascend_attention_state.PrefillNoCache = 0
mock_nz_tensor = MagicMock()
mock_model = MagicMock()
mock_nd_to_nz_spec.return_value = mock_nz_tensor
mock_npu_format_cast.return_value = mock_nz_tensor
self.builder.build(1, common_attn_metadata, mock_model)
@patch('vllm_ascend.attention.attention_v1.AscendMetadata')
@patch('vllm_ascend.utils.get_ascend_device_type',
return_value=AscendDeviceType._910_93)
@@ -286,73 +210,40 @@ class TestAscendAttentionBackendImpl(TestBase):
assert output.shape == (10, 8 * 64)
@patch('vllm_ascend.attention.attention_v1.get_forward_context')
@patch('torch_npu._npu_reshape_and_cache')
@patch('torch_npu._npu_flash_attention')
def test_forward_prefill_no_cache(self, mock_flash_attention,
mock_reshape_cache,
mock_get_forward_context):
"""Test forward pass in PrefillNoCache state"""
query = torch.randn(10, 8 * 64)
key = torch.randn(10, 8 * 64)
value = torch.randn(10, 8 * 64)
kv_cache = torch.empty(2, 5, 128, 8, 64)
output = torch.empty_like(query)
mock_get_forward_context.return_value = MagicMock(capturing=False)
metadata = self.attn_metadata
metadata.attn_state = AscendAttentionState.PrefillNoCache
metadata.attn_mask = torch.randn(1, 1, 10, 10)
metadata.seq_lens = torch.tensor([10])
metadata.num_actual_tokens = 10
metadata.slot_mapping = torch.zeros(10, dtype=torch.long)
metadata.num_decodes = 0
metadata.num_prefills = 10
layer = self.layer_no_quant
output = self.impl.forward(layer, query, key, value, kv_cache,
metadata, output)
mock_reshape_cache.assert_called_once()
mock_flash_attention.assert_called_once()
assert output.shape == (10, 8 * 64)
@patch('torch_npu._npu_reshape_and_cache')
@patch('torch_npu.npu_fused_infer_attention_score')
@patch('vllm_ascend.attention.attention_v1.get_forward_context')
def test_forward_prefill_cache_hit(self, mock_get_forward_context,
mock_npu_fused_infer_attention_score,
mock_npu_reshape_and_cache):
def test_forward_prefill(self, mock_get_forward_context,
mock_npu_fused_infer_attention_score,
mock_npu_reshape_and_cache):
"""Test forward pass in PrefillCacheHit state"""
query = torch.randn(10, 8 * 64)
key = torch.randn(10, 8 * 64)
value = torch.randn(10, 8 * 64)
query = torch.randn(10, 8, 64)
key = torch.randn(10, 8, 64)
value = torch.randn(10, 8, 64)
kv_cache = torch.empty(2, 5, 128, 8, 64)
output = torch.empty_like(query)
metadata = self.attn_metadata
metadata.attn_state = AscendAttentionState.PrefillCacheHit
metadata.attn_mask = torch.randn(1, 1, 10, 10)
metadata.query_lens = torch.tensor([10])
metadata.seq_lens = torch.tensor([10])
metadata.actual_seq_lengths_q = [10]
metadata.block_tables = torch.zeros(1, 5, dtype=torch.long)
metadata.num_actual_tokens = 10
metadata.slot_mapping = torch.zeros(10, dtype=torch.long)
metadata.num_decode_tokens = 0
metadata.num_decodes = 0
metadata.num_prefills = 10
metadata.slot_mapping = torch.zeros(10, dtype=torch.long)
layer = self.layer_no_quant
mock_get_forward_context.return_value = MagicMock(capturing=False)
mock_npu_fused_infer_attention_score.return_value = (output,
torch.ones(
10, 8, 64))
mock_npu_fused_infer_attention_score.return_value = (torch.ones(
10, 8, 64), torch.ones(10, 8, 64))
output = self.impl.forward(layer, query, key, value, kv_cache,
metadata, output)
mock_npu_fused_infer_attention_score.assert_called_once()
assert output.shape == (10, 8 * 64)
assert output.shape == (10, 8, 64)
@patch('torch_npu._npu_paged_attention')
@patch('torch_npu._npu_reshape_and_cache')
@@ -454,119 +345,6 @@ class TestAscendAttentionBackendImpl(TestBase):
assert output.shape == (10, 8 * 64)
@patch('vllm_ascend.attention.attention_v1.get_forward_context')
@patch('vllm_ascend.utils.get_ascend_device_type',
return_value=AscendDeviceType._910_93)
@patch('torch_npu._npu_reshape_and_cache')
@patch('vllm_ascend.attention.attention_v1.vanilla_chunked_prefill')
def test_forward_head_size_192(self, mock_vanilla_prefill,
mock_npu_reshape_and_cache,
mock_soc_version, mock_get_forward_context):
"""Test forward pass when head_size is 192"""
self.impl.head_size = 192
query = torch.randn(10, 8 * 192)
key = torch.randn(10, 8 * 192)
value = torch.randn(10, 8 * 192)
kv_cache = torch.empty(2, 5, 128, 8, 192)
output = torch.empty_like(query)
mock_get_forward_context.return_value = MagicMock(capturing=False)
metadata = self.attn_metadata
metadata.attn_mask = torch.randn(1, 1, 10, 10)
metadata.query_lens = torch.tensor([10])
metadata.seq_lens = torch.tensor([10])
metadata.block_tables = torch.zeros(1, 5, dtype=torch.long)
metadata.num_actual_tokens = 10
metadata.slot_mapping = torch.zeros(10, dtype=torch.long)
metadata.num_decodes = 10
metadata.num_prefills = 0
layer = self.layer_no_quant
mock_vanilla_prefill.return_value = MagicMock()
output = self.impl_192.forward(layer, query, key, value, kv_cache,
metadata, output)
mock_vanilla_prefill.assert_called_once()
assert output.shape == (10, 8 * 192)
@patch('vllm_ascend.attention.attention_v1.get_forward_context')
@patch('torch_npu.npu_fused_infer_attention_score')
@patch('torch_npu._npu_reshape_and_cache')
def test_forward_normal_v1_situation(self, mock_npu_reshape_and_cache,
mock_npu_fused_infer_attention_score,
mock_get_forward_context):
"""Test forward pass in normal V1 situation"""
query = torch.randn(10, 8 * 64)
key = torch.randn(10, 8 * 64)
value = torch.randn(10, 8 * 64)
kv_cache = torch.empty(2, 5, 128, 8, 64)
output = torch.empty_like(query)
metadata = self.attn_metadata
metadata.attn_mask = torch.randn(1, 1, 10, 10)
metadata.query_lens = torch.tensor([10])
metadata.seq_lens = torch.tensor([10])
metadata.block_tables = torch.zeros(1, 5, dtype=torch.long)
metadata.num_actual_tokens = 10
metadata.slot_mapping = torch.zeros(10, dtype=torch.long)
metadata.num_decodes = 0
metadata.num_prefills = 10
layer = self.layer_no_quant
mock_get_forward_context.return_value = MagicMock(capturing=False)
mock_npu_fused_infer_attention_score.return_value = (output,
torch.ones(
10, 8, 64))
output = self.impl.forward(layer, query, key, value, kv_cache,
metadata, output)
mock_npu_fused_infer_attention_score.assert_called_once()
assert output.shape == (10, 8 * 64)
@patch('torch_npu.npu_format_cast')
@patch('torch_npu._npu_reshape_and_cache')
@patch('torch_npu.npu_fused_infer_attention_score')
@patch('vllm_ascend.utils.get_ascend_device_type',
return_value=AscendDeviceType._310P)
@patch('vllm_ascend.attention.attention_v1.get_forward_context')
def test_forward_310p_device(self, mock_get_forward_context,
mock_soc_version,
mock_npu_fused_infer_attention_score,
mock_npu_reshape_and_cache,
mock_npu_format_cast):
"""Test forward pass on 310P device"""
query = torch.randn(10, 8 * 64)
key = torch.randn(10, 8 * 64)
value = torch.randn(10, 8 * 64)
kv_cache = torch.empty(2, 5, 128, 8, 64)
output = torch.empty_like(query)
metadata = self.attn_metadata
metadata.attn_mask = torch.randn(1, 1, 10, 10)
metadata.query_lens = torch.tensor([10])
metadata.seq_lens = torch.tensor([10])
metadata.block_tables = torch.zeros(1, 5, dtype=torch.long)
metadata.num_actual_tokens = 10
metadata.slot_mapping = torch.zeros(10, dtype=torch.long)
metadata.num_decodes = 0
metadata.num_prefills = 10
layer = self.layer_no_quant
mock_npu_format_cast.return_value = metadata.attn_mask
mock_get_forward_context.return_value = MagicMock(capturing=False)
mock_npu_fused_infer_attention_score.return_value = (output,
torch.ones(
10, 8, 64))
output = self.impl.forward(layer, query, key, value, kv_cache,
metadata, output)
mock_npu_fused_infer_attention_score.assert_called_once()
assert output.shape == (10, 8 * 64)
@patch('torch_npu._npu_reshape_and_cache')
def test_forward_raise_error(self, mock_paged_attention):
query = torch.randn(10, 8 * 64)