[Refactor] Remove redundant attention operator branches. (#4531)

[Refactor] Remove redundant attention operator branches.

Reason: we replace the other attention ops with fused_infer_attention_score, except for the decode-only state; clean up the code and remove 310P support.

https://github.com/vllm-project/vllm-ascend/pull/4455

- vLLM version: v0.11.2
- vLLM main: https://github.com/vllm-project/vllm/commit/v0.11.2

---------

Signed-off-by: weijinqian_v1 <weijinqian@huawei.com>
Co-authored-by: weijinqian_v1 <weijinqian@huawei.com>
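The shape of the consolidation can be pictured as a single dispatch: every attention state except decode-only now funnels into `torch_npu.npu_fused_infer_attention_score`, while the decode path keeps its dedicated kernel (the diff below still patches `torch_npu._npu_paged_attention` for a later decode test). The sketch below is a hypothetical, CPU-runnable illustration of that dispatch, not the actual `vllm_ascend.attention.attention_v1` code: `FakeNPU`, `Metadata`, and `forward` are stand-ins, and only the kernel names and the fused op's 2-tuple return `(attn_output, softmax_lse)` are taken from the tests in this diff.

```python
from dataclasses import dataclass

import torch


class FakeNPU:
    """CPU stand-in for the two torch_npu kernels the refactor keeps."""

    @staticmethod
    def npu_fused_infer_attention_score(query, key, value, **kwargs):
        # The real op returns (attn_output, softmax_lse).
        return query.clone(), None

    @staticmethod
    def _npu_paged_attention(query, **kwargs):
        return query.clone()


@dataclass
class Metadata:
    # e.g. "PrefillNoCache", "ChunkedPrefill", "DecodeOnly"
    attn_state: str


def forward(npu, metadata: Metadata, query, key, value):
    # After the refactor: the decode-only state keeps its dedicated
    # kernel; every other state goes through the fused op.
    if metadata.attn_state == "DecodeOnly":
        return npu._npu_paged_attention(query)
    out, _lse = npu.npu_fused_infer_attention_score(query, key, value)
    return out


if __name__ == "__main__":
    q = k = v = torch.randn(10, 8, 64)
    out = forward(FakeNPU, Metadata("ChunkedPrefill"), q, k, v)
    assert out.shape == q.shape
```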
@@ -25,12 +25,6 @@ class TestAscendAttentionBackend(TestBase):
         self.assertEqual(AscendAttentionBackend.get_builder_cls(),
                          AscendAttentionMetadataBuilder)
 
-    @patch('vllm_ascend.attention.attention_v1.get_ascend_device_type',
-           return_value=AscendDeviceType._310P)
-    def test_get_kv_cache_shape_310p(self, mock_soc_version):
-        result = AscendAttentionBackend.get_kv_cache_shape(10, 20, 30, 40)
-        self.assertEqual(result, (2, 10, 30 * 40 // 16, 20, 16))
-
     @patch('vllm_ascend.utils.get_ascend_device_type',
            return_value=AscendDeviceType._910_93)
     def test_get_kv_cache_shape_not_310p(self, mock_soc_version):
@@ -95,76 +89,6 @@ class TestAscendAttentionMetadataBuilder(TestBase):
 
         self.assertFalse(result)
 
-    @patch('vllm_ascend.attention.attention_v1.AscendMetadata')
-    @patch('torch_npu.npu_format_cast')
-    @patch('vllm_ascend.utils.nd_to_nz_2d')
-    @patch('vllm_ascend.utils.get_ascend_device_type',
-           return_value=AscendDeviceType._310P)
-    def test_build_prefill_no_cache(self, mock_soc_version, mock_nd_to_nz_2d,
-                                    mock_npu_format_cast,
-                                    mock_ascend_metadata):
-        common_attn_metadata = AscendCommonAttentionMetadata(
-            query_start_loc=torch.tensor([0, 3, 7]),
-            query_start_loc_cpu=torch.tensor([0, 3, 7]),
-            seq_lens_cpu=torch.tensor([5, 6]),
-            num_reqs=2,
-            num_actual_tokens=10,
-            max_query_len=5,
-            decode_token_per_req=torch.tensor([1, 1]),
-            block_table_tensor=torch.zeros((10, 10)),
-            slot_mapping=torch.tensor(range(20)),
-            actual_seq_lengths_q=torch.tensor([0, 1]),
-            positions=torch.tensor([10, 10]),
-            attn_mask=torch.ones((10, 10)),
-            spec_attn_mask=None,
-            attn_state=AscendAttentionState.PrefillNoCache,
-            num_computed_tokens_cpu=None,
-            seq_lens=None)
-
-        mock_nz_tensor = MagicMock()
-        mock_model = MagicMock()
-        mock_nd_to_nz_2d.return_value = mock_nz_tensor
-        mock_npu_format_cast.return_value = mock_nz_tensor
-
-        self.builder.build(1, common_attn_metadata, mock_model)
-
-    @patch('vllm_ascend.attention.attention_v1.AscendMetadata')
-    @patch('torch_npu.npu_format_cast')
-    @patch('vllm_ascend.utils.nd_to_nz_spec')
-    @patch('vllm_ascend.utils.get_ascend_device_type',
-           return_value=AscendDeviceType._310P)
-    @patch('vllm_ascend.attention.attention_v1.AscendAttentionState')
-    def test_build_chunked_prefill(self, mock_ascend_attention_state,
-                                   mock_soc_version, mock_nd_to_nz_spec,
-                                   mock_npu_format_cast, mock_ascend_metadata):
-        common_attn_metadata = AscendCommonAttentionMetadata(
-            query_start_loc=torch.tensor([0, 2, 5, 9]),
-            query_start_loc_cpu=torch.tensor([0, 2, 5, 9]),
-            seq_lens_cpu=torch.tensor([4, 5, 6]),
-            num_reqs=3,
-            num_actual_tokens=15,
-            max_query_len=6,
-            decode_token_per_req=torch.tensor([1, 1, 1]),
-            block_table_tensor=torch.zeros((10, 10)),
-            slot_mapping=torch.tensor(range(20)),
-            actual_seq_lengths_q=torch.tensor([0, 1, 2]),
-            positions=torch.tensor([10, 10]),
-            attn_mask=torch.ones((15, 15)),
-            spec_attn_mask=None,
-            attn_state=AscendAttentionState.ChunkedPrefill,
-            num_computed_tokens_cpu=None,
-            seq_lens=None)
-
-        mock_ascend_attention_state = MagicMock()
-        mock_ascend_attention_state.PrefillNoCache = 0
-
-        mock_nz_tensor = MagicMock()
-        mock_model = MagicMock()
-        mock_nd_to_nz_spec.return_value = mock_nz_tensor
-        mock_npu_format_cast.return_value = mock_nz_tensor
-
-        self.builder.build(1, common_attn_metadata, mock_model)
-
     @patch('vllm_ascend.attention.attention_v1.AscendMetadata')
     @patch('vllm_ascend.utils.get_ascend_device_type',
            return_value=AscendDeviceType._910_93)
@@ -286,73 +210,40 @@ class TestAscendAttentionBackendImpl(TestBase):
 
         assert output.shape == (10, 8 * 64)
 
-    @patch('vllm_ascend.attention.attention_v1.get_forward_context')
-    @patch('torch_npu._npu_reshape_and_cache')
-    @patch('torch_npu._npu_flash_attention')
-    def test_forward_prefill_no_cache(self, mock_flash_attention,
-                                      mock_reshape_cache,
-                                      mock_get_forward_context):
-        """Test forward pass in PrefillNoCache state"""
-        query = torch.randn(10, 8 * 64)
-        key = torch.randn(10, 8 * 64)
-        value = torch.randn(10, 8 * 64)
-        kv_cache = torch.empty(2, 5, 128, 8, 64)
-        output = torch.empty_like(query)
-
-        mock_get_forward_context.return_value = MagicMock(capturing=False)
-
-        metadata = self.attn_metadata
-        metadata.attn_state = AscendAttentionState.PrefillNoCache
-        metadata.attn_mask = torch.randn(1, 1, 10, 10)
-        metadata.seq_lens = torch.tensor([10])
-        metadata.num_actual_tokens = 10
-        metadata.slot_mapping = torch.zeros(10, dtype=torch.long)
-        metadata.num_decodes = 0
-        metadata.num_prefills = 10
-        layer = self.layer_no_quant
-
-        output = self.impl.forward(layer, query, key, value, kv_cache,
-                                   metadata, output)
-
-        mock_reshape_cache.assert_called_once()
-        mock_flash_attention.assert_called_once()
-        assert output.shape == (10, 8 * 64)
-
     @patch('torch_npu._npu_reshape_and_cache')
     @patch('torch_npu.npu_fused_infer_attention_score')
     @patch('vllm_ascend.attention.attention_v1.get_forward_context')
-    def test_forward_prefill_cache_hit(self, mock_get_forward_context,
-                                       mock_npu_fused_infer_attention_score,
-                                       mock_npu_reshape_and_cache):
+    def test_forward_prefill(self, mock_get_forward_context,
+                             mock_npu_fused_infer_attention_score,
+                             mock_npu_reshape_and_cache):
         """Test forward pass in PrefillCacheHit state"""
-        query = torch.randn(10, 8 * 64)
-        key = torch.randn(10, 8 * 64)
-        value = torch.randn(10, 8 * 64)
+        query = torch.randn(10, 8, 64)
+        key = torch.randn(10, 8, 64)
+        value = torch.randn(10, 8, 64)
         kv_cache = torch.empty(2, 5, 128, 8, 64)
         output = torch.empty_like(query)
 
         metadata = self.attn_metadata
         metadata.attn_state = AscendAttentionState.PrefillCacheHit
         metadata.attn_mask = torch.randn(1, 1, 10, 10)
         metadata.query_lens = torch.tensor([10])
         metadata.seq_lens = torch.tensor([10])
         metadata.actual_seq_lengths_q = [10]
         metadata.block_tables = torch.zeros(1, 5, dtype=torch.long)
         metadata.num_actual_tokens = 10
-        metadata.slot_mapping = torch.zeros(10, dtype=torch.long)
-        metadata.num_decode_tokens = 0
+        metadata.num_decodes = 0
+        metadata.num_prefills = 10
+        metadata.slot_mapping = torch.zeros(10, dtype=torch.long)
         layer = self.layer_no_quant
 
         mock_get_forward_context.return_value = MagicMock(capturing=False)
-        mock_npu_fused_infer_attention_score.return_value = (output,
-                                                             torch.ones(
-                                                                 10, 8, 64))
-
+        mock_npu_fused_infer_attention_score.return_value = (torch.ones(
+            10, 8, 64), torch.ones(10, 8, 64))
         output = self.impl.forward(layer, query, key, value, kv_cache,
                                    metadata, output)
 
         mock_npu_fused_infer_attention_score.assert_called_once()
-        assert output.shape == (10, 8 * 64)
+        assert output.shape == (10, 8, 64)
 
     @patch('torch_npu._npu_paged_attention')
     @patch('torch_npu._npu_reshape_and_cache')
@@ -454,119 +345,6 @@ class TestAscendAttentionBackendImpl(TestBase):
 
         assert output.shape == (10, 8 * 64)
 
-    @patch('vllm_ascend.attention.attention_v1.get_forward_context')
-    @patch('vllm_ascend.utils.get_ascend_device_type',
-           return_value=AscendDeviceType._910_93)
-    @patch('torch_npu._npu_reshape_and_cache')
-    @patch('vllm_ascend.attention.attention_v1.vanilla_chunked_prefill')
-    def test_forward_head_size_192(self, mock_vanilla_prefill,
-                                   mock_npu_reshape_and_cache,
-                                   mock_soc_version, mock_get_forward_context):
-        """Test forward pass when head_size is 192"""
-
-        self.impl.head_size = 192
-        query = torch.randn(10, 8 * 192)
-        key = torch.randn(10, 8 * 192)
-        value = torch.randn(10, 8 * 192)
-        kv_cache = torch.empty(2, 5, 128, 8, 192)
-        output = torch.empty_like(query)
-
-        mock_get_forward_context.return_value = MagicMock(capturing=False)
-
-        metadata = self.attn_metadata
-        metadata.attn_mask = torch.randn(1, 1, 10, 10)
-        metadata.query_lens = torch.tensor([10])
-        metadata.seq_lens = torch.tensor([10])
-        metadata.block_tables = torch.zeros(1, 5, dtype=torch.long)
-        metadata.num_actual_tokens = 10
-        metadata.slot_mapping = torch.zeros(10, dtype=torch.long)
-        metadata.num_decodes = 10
-        metadata.num_prefills = 0
-        layer = self.layer_no_quant
-        mock_vanilla_prefill.return_value = MagicMock()
-
-        output = self.impl_192.forward(layer, query, key, value, kv_cache,
-                                       metadata, output)
-
-        mock_vanilla_prefill.assert_called_once()
-        assert output.shape == (10, 8 * 192)
-
-    @patch('vllm_ascend.attention.attention_v1.get_forward_context')
-    @patch('torch_npu.npu_fused_infer_attention_score')
-    @patch('torch_npu._npu_reshape_and_cache')
-    def test_forward_normal_v1_situation(self, mock_npu_reshape_and_cache,
-                                         mock_npu_fused_infer_attention_score,
-                                         mock_get_forward_context):
-        """Test forward pass in normal V1 situation"""
-        query = torch.randn(10, 8 * 64)
-        key = torch.randn(10, 8 * 64)
-        value = torch.randn(10, 8 * 64)
-        kv_cache = torch.empty(2, 5, 128, 8, 64)
-        output = torch.empty_like(query)
-
-        metadata = self.attn_metadata
-        metadata.attn_mask = torch.randn(1, 1, 10, 10)
-        metadata.query_lens = torch.tensor([10])
-        metadata.seq_lens = torch.tensor([10])
-        metadata.block_tables = torch.zeros(1, 5, dtype=torch.long)
-        metadata.num_actual_tokens = 10
-        metadata.slot_mapping = torch.zeros(10, dtype=torch.long)
-        metadata.num_decodes = 0
-        metadata.num_prefills = 10
-        layer = self.layer_no_quant
-        mock_get_forward_context.return_value = MagicMock(capturing=False)
-        mock_npu_fused_infer_attention_score.return_value = (output,
-                                                             torch.ones(
-                                                                 10, 8, 64))
-
-        output = self.impl.forward(layer, query, key, value, kv_cache,
-                                   metadata, output)
-
-        mock_npu_fused_infer_attention_score.assert_called_once()
-        assert output.shape == (10, 8 * 64)
-
-    @patch('torch_npu.npu_format_cast')
-    @patch('torch_npu._npu_reshape_and_cache')
-    @patch('torch_npu.npu_fused_infer_attention_score')
-    @patch('vllm_ascend.utils.get_ascend_device_type',
-           return_value=AscendDeviceType._310P)
-    @patch('vllm_ascend.attention.attention_v1.get_forward_context')
-    def test_forward_310p_device(self, mock_get_forward_context,
-                                 mock_soc_version,
-                                 mock_npu_fused_infer_attention_score,
-                                 mock_npu_reshape_and_cache,
-                                 mock_npu_format_cast):
-        """Test forward pass on 310P device"""
-        query = torch.randn(10, 8 * 64)
-        key = torch.randn(10, 8 * 64)
-        value = torch.randn(10, 8 * 64)
-        kv_cache = torch.empty(2, 5, 128, 8, 64)
-        output = torch.empty_like(query)
-
-        metadata = self.attn_metadata
-        metadata.attn_mask = torch.randn(1, 1, 10, 10)
-        metadata.query_lens = torch.tensor([10])
-        metadata.seq_lens = torch.tensor([10])
-        metadata.block_tables = torch.zeros(1, 5, dtype=torch.long)
-        metadata.num_actual_tokens = 10
-        metadata.slot_mapping = torch.zeros(10, dtype=torch.long)
-        metadata.num_decodes = 0
-        metadata.num_prefills = 10
-        layer = self.layer_no_quant
-
-        mock_npu_format_cast.return_value = metadata.attn_mask
-
-        mock_get_forward_context.return_value = MagicMock(capturing=False)
-        mock_npu_fused_infer_attention_score.return_value = (output,
-                                                             torch.ones(
-                                                                 10, 8, 64))
-
-        output = self.impl.forward(layer, query, key, value, kv_cache,
-                                   metadata, output)
-
-        mock_npu_fused_infer_attention_score.assert_called_once()
-        assert output.shape == (10, 8 * 64)
-
-    @patch('torch_npu._npu_reshape_and_cache')
-    def test_forward_raise_error(self, mock_paged_attention):
-        query = torch.randn(10, 8 * 64)
-
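A note for reading the mock parameter lists in the hunks above: stacked `@patch` decorators apply bottom-up, so the decorator closest to the `def` supplies the first mock argument (which is why `test_forward_prefill` receives `mock_get_forward_context` first even though that patch is written last). A minimal, self-contained illustration with toy targets, unrelated to vllm_ascend:

```python
import os
from unittest.mock import patch


# Stacked @patch decorators apply bottom-up: the decorator closest to
# the function supplies the first mock argument.
@patch('os.getcwd')  # outermost patch -> last argument
@patch('os.getpid')  # innermost patch -> first argument
def demo(mock_getpid, mock_getcwd):
    mock_getpid.return_value = 1234
    mock_getcwd.return_value = '/tmp'
    assert os.getpid() == 1234
    assert os.getcwd() == '/tmp'


demo()
```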