[Refactor] Refactor Ascend attention implementation forward (#3714)
### What this PR does / why we need it?
This PR refactors the Ascend attention implementation to align with
vLLM's core interfaces, simplifying the code and improving
maintainability.
### Key Changes:
* **Align with vLLM's Attention Interface**: The `forward` method
signature in `AscendAttentionBackendImpl` now matches the base
`AttentionImpl` in vLLM, removing the custom `trace_flag`.
* **Enable Opaque Attention Operator**: By adding `opaque_attention_op`
to `AscendPlatform`, we allow vLLM to wrap our attention kernel in its
standard `vllm.unified_attention_with_output` operator. This avoids the
need for a custom call path.
* **Remove Obsolete Code**:
* The custom op `vllm.unified_ascend_attention_with_output` has been
deleted as it is now redundant.
* The `trace_flag` and its associated logic were removed, reducing code
complexity.
* An outdated quantization branch within the attention implementation
was cleaned up.
* **Improve Readability**: Renamed output variables (`output` vs.
`intermediate_output`) and added comments to clarify the in-place nature
of the attention output.
### Does this PR introduce _any_ user-facing change?
None.
### How was this patch tested?
No extra tests needed.
- vLLM version: v0.11.0rc3
- vLLM main:
17c540a993
---------
Signed-off-by: Yizhou Liu <liu_yizhou@outlook.com>
This commit is contained in:
@@ -255,59 +255,6 @@ class TestAscendAttentionBackendImpl(TestBase):
|
||||
attn_type=self.attention_type.DECODER,
|
||||
kv_sharing_target_layer_name=None)
|
||||
|
||||
@patch('torch.ops.vllm.unified_ascend_attention_with_output')
|
||||
def test_forward_trace_flag_true(self, mock_unified_attention):
|
||||
"""Test forward pass when trace_flag is True"""
|
||||
query = torch.randn(10, 8 * 64)
|
||||
key = torch.randn(10, 8 * 64)
|
||||
value = torch.randn(10, 8 * 64)
|
||||
kv_cache = torch.empty(2, 0, 0, 8, 64)
|
||||
metadata = self.attn_metadata
|
||||
layer = self.layer
|
||||
|
||||
output = self.impl.forward(layer,
|
||||
query,
|
||||
key,
|
||||
value,
|
||||
kv_cache,
|
||||
metadata,
|
||||
trace_flag=True)
|
||||
|
||||
mock_unified_attention.assert_called_once()
|
||||
assert output.shape == (10, 8 * 64)
|
||||
|
||||
@patch('torch_npu._npu_paged_attention_splitfuse')
|
||||
def test_forward_with_quant_method(self, mock_paged_attention):
|
||||
"""Test forward pass when layer has quant_method"""
|
||||
query = torch.randn(10, 8 * 64)
|
||||
key = torch.randn(10, 8 * 64)
|
||||
value = torch.randn(10, 8 * 64)
|
||||
k_cache = torch.ones(1, 10, 8, 64, dtype=torch.int8)
|
||||
v_cache = torch.ones(1, 10, 8, 64, dtype=torch.int8)
|
||||
kv_cache = [k_cache, v_cache]
|
||||
ret_value = torch.ones(1, 1, 10, 8, 64, dtype=torch.int8)
|
||||
|
||||
metadata = MagicMock()
|
||||
metadata.num_actual_tokens = torch.randn(10, 8 * 64)
|
||||
metadata.block_tables = torch.randn(10, 8 * 64)
|
||||
metadata.seq_lens = torch.randn(10, 8 * 64)
|
||||
metadata.attn_mask = torch.randn(10, 8 * 64)
|
||||
metadata.query_lens = torch.randn(10, 8 * 64)
|
||||
layer = self.layer
|
||||
layer.quant_method = MagicMock()
|
||||
layer.quant_method.apply.return_value = ret_value
|
||||
|
||||
output = self.impl.forward(layer,
|
||||
query,
|
||||
key,
|
||||
value,
|
||||
kv_cache,
|
||||
metadata,
|
||||
trace_flag=False)
|
||||
|
||||
layer.quant_method.apply.assert_called_once()
|
||||
assert output.shape == (10, 8 * 64)
|
||||
|
||||
def test_forward_no_attn_metadata(self):
|
||||
"""Test forward pass when attn_metadata is None"""
|
||||
query = torch.randn(10, 8 * 64)
|
||||
@@ -315,14 +262,10 @@ class TestAscendAttentionBackendImpl(TestBase):
|
||||
value = torch.randn(10, 8 * 64)
|
||||
kv_cache = torch.empty(2, 0, 0, 8, 64)
|
||||
layer = self.layer_no_quant
|
||||
output = torch.empty_like(query)
|
||||
|
||||
output = self.impl.forward(layer,
|
||||
query,
|
||||
key,
|
||||
value,
|
||||
kv_cache,
|
||||
None,
|
||||
trace_flag=False)
|
||||
output = self.impl.forward(layer, query, key, value, kv_cache, None,
|
||||
output)
|
||||
|
||||
assert output.shape == (10, 8 * 64)
|
||||
|
||||
@@ -335,6 +278,8 @@ class TestAscendAttentionBackendImpl(TestBase):
|
||||
key = torch.randn(10, 8 * 64)
|
||||
value = torch.randn(10, 8 * 64)
|
||||
kv_cache = torch.empty(2, 5, 128, 8, 64)
|
||||
output = torch.empty_like(query)
|
||||
|
||||
metadata = self.attn_metadata
|
||||
metadata.attn_state = AscendAttentionState.PrefillNoCache
|
||||
metadata.attn_mask = torch.randn(1, 1, 10, 10)
|
||||
@@ -344,15 +289,9 @@ class TestAscendAttentionBackendImpl(TestBase):
|
||||
metadata.num_decodes = 0
|
||||
metadata.num_prefills = 10
|
||||
layer = self.layer_no_quant
|
||||
# layer.quant_method.apply.return_value = metadata
|
||||
print(self.layer_no_quant._v_scale_float)
|
||||
output = self.impl.forward(layer,
|
||||
query,
|
||||
key,
|
||||
value,
|
||||
kv_cache,
|
||||
metadata,
|
||||
trace_flag=False)
|
||||
|
||||
output = self.impl.forward(layer, query, key, value, kv_cache,
|
||||
metadata, output)
|
||||
|
||||
mock_reshape_cache.assert_called_once()
|
||||
mock_flash_attention.assert_called_once()
|
||||
@@ -367,6 +306,8 @@ class TestAscendAttentionBackendImpl(TestBase):
|
||||
key = torch.randn(10, 8 * 64)
|
||||
value = torch.randn(10, 8 * 64)
|
||||
kv_cache = torch.empty(2, 5, 128, 8, 64)
|
||||
output = torch.empty_like(query)
|
||||
|
||||
metadata = self.attn_metadata
|
||||
metadata.attn_state = AscendAttentionState.PrefillCacheHit
|
||||
metadata.attn_mask = torch.randn(1, 1, 10, 10)
|
||||
@@ -379,13 +320,8 @@ class TestAscendAttentionBackendImpl(TestBase):
|
||||
metadata.num_prefills = 10
|
||||
layer = self.layer_no_quant
|
||||
|
||||
output = self.impl.forward(layer,
|
||||
query,
|
||||
key,
|
||||
value,
|
||||
kv_cache,
|
||||
metadata,
|
||||
trace_flag=False)
|
||||
output = self.impl.forward(layer, query, key, value, kv_cache,
|
||||
metadata, output)
|
||||
|
||||
mock_flash_attention_qlens.assert_called_once()
|
||||
assert output.shape == (10, 8 * 64)
|
||||
@@ -401,6 +337,8 @@ class TestAscendAttentionBackendImpl(TestBase):
|
||||
key = torch.randn(10, 8 * 64)
|
||||
value = torch.randn(10, 8 * 64)
|
||||
kv_cache = torch.empty(2, 5, 128, 8, 64)
|
||||
output = torch.empty_like(query)
|
||||
|
||||
metadata = self.attn_metadata
|
||||
metadata.attn_state = AscendAttentionState.DecodeOnly
|
||||
metadata.seq_lens = torch.tensor([10])
|
||||
@@ -413,13 +351,8 @@ class TestAscendAttentionBackendImpl(TestBase):
|
||||
|
||||
mock_get_forward_context.return_value = MagicMock(capturing=False)
|
||||
|
||||
output = self.impl.forward(layer,
|
||||
query,
|
||||
key,
|
||||
value,
|
||||
kv_cache,
|
||||
metadata,
|
||||
trace_flag=False)
|
||||
output = self.impl.forward(layer, query, key, value, kv_cache,
|
||||
metadata, output)
|
||||
|
||||
mock_paged_attention.assert_called_once()
|
||||
assert output.shape == (10, 8 * 64)
|
||||
@@ -509,6 +442,8 @@ class TestAscendAttentionBackendImpl(TestBase):
|
||||
key = torch.randn(10, 8 * 64)
|
||||
value = torch.randn(10, 8 * 64)
|
||||
kv_cache = torch.empty(2, 5, 128, 8, 64)
|
||||
output = torch.empty_like(query)
|
||||
|
||||
metadata = self.attn_metadata
|
||||
metadata.attn_state = AscendAttentionState.DecodeOnly
|
||||
metadata.seq_lens = torch.tensor([10])
|
||||
@@ -522,13 +457,8 @@ class TestAscendAttentionBackendImpl(TestBase):
|
||||
mock_get_forward_context.return_value = MagicMock(capturing=True)
|
||||
mock_get_graph_params.return_value = graph_params
|
||||
|
||||
output = self.impl.forward(layer,
|
||||
query,
|
||||
key,
|
||||
value,
|
||||
kv_cache,
|
||||
metadata,
|
||||
trace_flag=False)
|
||||
output = self.impl.forward(layer, query, key, value, kv_cache,
|
||||
metadata, output)
|
||||
|
||||
mock_paged_attention.assert_called_once()
|
||||
self.assertEqual(len(graph_params.handles[num_tokens]), 0)
|
||||
@@ -542,6 +472,8 @@ class TestAscendAttentionBackendImpl(TestBase):
|
||||
key = torch.randn(10, 8 * 64)
|
||||
value = torch.randn(10, 8 * 64)
|
||||
kv_cache = torch.empty(2, 5, 128, 8, 64)
|
||||
output = torch.empty(10, 8, 64)
|
||||
|
||||
metadata = self.attn_metadata
|
||||
metadata.attn_state = AscendAttentionState.DecodeOnly
|
||||
metadata.seq_lens = torch.tensor([10] * 10)
|
||||
@@ -553,16 +485,11 @@ class TestAscendAttentionBackendImpl(TestBase):
|
||||
layer = self.layer_no_quant
|
||||
mock_fused_infer_attention_score.return_value = (torch.ones(10, 8,
|
||||
64), 1)
|
||||
output = self.impl_swa.forward(layer,
|
||||
query,
|
||||
key,
|
||||
value,
|
||||
kv_cache,
|
||||
metadata,
|
||||
trace_flag=False)
|
||||
output = self.impl_swa.forward(layer, query, key, value, kv_cache,
|
||||
metadata, output)
|
||||
print(output.shape)
|
||||
mock_fused_infer_attention_score.assert_called_once()
|
||||
assert output.shape == (10, 8 * 64)
|
||||
assert output.shape == (10, 8, 64)
|
||||
|
||||
@patch('vllm_ascend.attention.attention_v1.get_forward_context')
|
||||
@patch('torch_npu._npu_reshape_and_cache')
|
||||
@@ -576,6 +503,7 @@ class TestAscendAttentionBackendImpl(TestBase):
|
||||
key = torch.randn(10, 8 * 64)
|
||||
value = torch.randn(10, 8 * 64)
|
||||
kv_cache = torch.empty(2, 5, 128, 8, 64)
|
||||
output = torch.empty_like(query)
|
||||
|
||||
metadata = self.attn_metadata
|
||||
metadata.attn_state = AscendAttentionState.DecodeOnly
|
||||
@@ -583,6 +511,7 @@ class TestAscendAttentionBackendImpl(TestBase):
|
||||
metadata.block_tables = torch.zeros(1, 5, dtype=torch.long)
|
||||
metadata.num_actual_tokens = 10
|
||||
metadata.slot_mapping = torch.zeros(10, dtype=torch.long)
|
||||
layer = self.layer_no_quant
|
||||
metadata.num_decodes = 10
|
||||
metadata.num_prefills = 0
|
||||
|
||||
@@ -591,13 +520,8 @@ class TestAscendAttentionBackendImpl(TestBase):
|
||||
|
||||
mock_get_forward_context.return_value = MagicMock(capturing=False)
|
||||
|
||||
output = self.impl_swa.forward(self.layer_no_quant,
|
||||
query,
|
||||
key,
|
||||
value,
|
||||
kv_cache,
|
||||
metadata,
|
||||
trace_flag=False)
|
||||
output = self.impl_swa.forward(layer, query, key, value, kv_cache,
|
||||
metadata, output)
|
||||
|
||||
mock_paged_attention.assert_called_once()
|
||||
mock_fused_infer_attention_score.assert_not_called()
|
||||
@@ -618,6 +542,8 @@ class TestAscendAttentionBackendImpl(TestBase):
|
||||
key = torch.randn(10, 8 * 192)
|
||||
value = torch.randn(10, 8 * 192)
|
||||
kv_cache = torch.empty(2, 5, 128, 8, 192)
|
||||
output = torch.empty_like(query)
|
||||
|
||||
metadata = self.attn_metadata
|
||||
metadata.attn_mask = torch.randn(1, 1, 10, 10)
|
||||
metadata.query_lens = torch.tensor([10])
|
||||
@@ -631,13 +557,8 @@ class TestAscendAttentionBackendImpl(TestBase):
|
||||
mock_version.cann = "8.4.RC1"
|
||||
mock_vanilla_prefill.return_value = MagicMock()
|
||||
|
||||
output = self.impl_192.forward(layer,
|
||||
query,
|
||||
key,
|
||||
value,
|
||||
kv_cache,
|
||||
metadata,
|
||||
trace_flag=False)
|
||||
output = self.impl_192.forward(layer, query, key, value, kv_cache,
|
||||
metadata, output)
|
||||
|
||||
mock_vanilla_prefill.assert_called_once()
|
||||
assert output.shape == (10, 8 * 192)
|
||||
@@ -653,6 +574,8 @@ class TestAscendAttentionBackendImpl(TestBase):
|
||||
key = torch.randn(10, 8 * 64)
|
||||
value = torch.randn(10, 8 * 64)
|
||||
kv_cache = torch.empty(2, 5, 128, 8, 64)
|
||||
output = torch.empty_like(query)
|
||||
|
||||
metadata = self.attn_metadata
|
||||
metadata.attn_mask = torch.randn(1, 1, 10, 10)
|
||||
metadata.query_lens = torch.tensor([10])
|
||||
@@ -666,13 +589,8 @@ class TestAscendAttentionBackendImpl(TestBase):
|
||||
|
||||
mock_version.cann = "8.4.RC1"
|
||||
|
||||
output = self.impl.forward(layer,
|
||||
query,
|
||||
key,
|
||||
value,
|
||||
kv_cache,
|
||||
metadata,
|
||||
trace_flag=False)
|
||||
output = self.impl.forward(layer, query, key, value, kv_cache,
|
||||
metadata, output)
|
||||
|
||||
mock_paged_attention.assert_called_once()
|
||||
assert output.shape == (10, 8 * 64)
|
||||
@@ -690,6 +608,8 @@ class TestAscendAttentionBackendImpl(TestBase):
|
||||
key = torch.randn(10, 8 * 64)
|
||||
value = torch.randn(10, 8 * 64)
|
||||
kv_cache = torch.empty(2, 5, 128, 8, 64)
|
||||
output = torch.empty_like(query)
|
||||
|
||||
metadata = self.attn_metadata
|
||||
metadata.attn_mask = torch.randn(1, 1, 10, 10)
|
||||
metadata.query_lens = torch.tensor([10])
|
||||
@@ -703,13 +623,9 @@ class TestAscendAttentionBackendImpl(TestBase):
|
||||
|
||||
mock_npu_format_cast.return_value = metadata.attn_mask
|
||||
mock_version.cann = "8.4.RC1"
|
||||
output = self.impl.forward(layer,
|
||||
query,
|
||||
key,
|
||||
value,
|
||||
kv_cache,
|
||||
metadata,
|
||||
trace_flag=False)
|
||||
|
||||
output = self.impl.forward(layer, query, key, value, kv_cache,
|
||||
metadata, output)
|
||||
|
||||
mock_paged_attention.assert_called_once()
|
||||
assert output.shape == (10, 8 * 64)
|
||||
@@ -720,6 +636,8 @@ class TestAscendAttentionBackendImpl(TestBase):
|
||||
key = torch.randn(10, 8 * 64)
|
||||
value = torch.randn(10, 8 * 64)
|
||||
kv_cache = torch.empty(2, 5, 128, 8, 64)
|
||||
output = torch.empty_like(query)
|
||||
|
||||
metadata = self.attn_metadata
|
||||
metadata.attn_mask = torch.randn(1, 1, 10, 10)
|
||||
metadata.query_lens = torch.tensor([10])
|
||||
@@ -732,10 +650,5 @@ class TestAscendAttentionBackendImpl(TestBase):
|
||||
layer = self.layer_no_quant
|
||||
|
||||
with self.assertRaises(NotImplementedError):
|
||||
self.impl_error.forward(layer,
|
||||
query,
|
||||
key,
|
||||
value,
|
||||
kv_cache,
|
||||
metadata,
|
||||
trace_flag=False)
|
||||
self.impl_error.forward(layer, query, key, value, kv_cache,
|
||||
metadata, output)
|
||||
|
||||
Reference in New Issue
Block a user