[Refactor] Refactor Ascend attention implementation forward (#3714)

### What this PR does / why we need it?
This PR refactors the Ascend attention implementation to align with
vLLM's core interfaces, simplifying the code and improving
maintainability.

### Key Changes:

* **Align with vLLM's Attention Interface**: The `forward` method
signature in `AscendAttentionBackendImpl` now matches the base
`AttentionImpl` in vLLM, removing the custom `trace_flag` parameter
(see the signature sketch after this list).

* **Enable Opaque Attention Operator**: By adding `opaque_attention_op`
to `AscendPlatform`, we let vLLM wrap our attention kernel in its
standard `vllm.unified_attention_with_output` operator, avoiding the
need for a custom call path (see the platform-hook sketch after this
list).

* **Remove Obsolete Code**:
  * The custom op `vllm.unified_ascend_attention_with_output` has been
    deleted as it is now redundant.
  * The `trace_flag` parameter and its associated logic were removed,
    reducing code complexity.
  * An outdated quantization branch within the attention implementation
    was cleaned up.

* **Improve Readability**: Renamed output variables to distinguish the
caller-provided `output` buffer from the kernel's `intermediate_output`,
and added comments to clarify that attention results are written into
`output` in place (see the call-pattern example after this list).
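
For reference, here is a minimal sketch of the aligned `forward`
signature. The parameter order follows the updated tests in the diff
below; the `Optional` defaults, the trailing `output_scale` parameter,
and the `AscendMetadata` import path are assumptions based on vLLM's
base `AttentionImpl` and may differ in detail:

```python
from typing import Optional

import torch
from vllm.attention.backends.abstract import AttentionImpl, AttentionLayer

# Assumed import path for the Ascend metadata type.
from vllm_ascend.attention.attention_v1 import AscendMetadata


class AscendAttentionBackendImpl(AttentionImpl):

    def forward(
        self,
        layer: AttentionLayer,
        query: torch.Tensor,
        key: torch.Tensor,
        value: torch.Tensor,
        kv_cache: torch.Tensor,
        attn_metadata: AscendMetadata,
        output: Optional[torch.Tensor] = None,
        output_scale: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        # The result is written into the caller-provided `output`
        # buffer in place; the same tensor is returned.
        ...
```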
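
Similarly, a sketch of the platform hook, assuming
`opaque_attention_op` is a boolean classmethod on vLLM's `Platform`
interface, as its use here suggests:

```python
from vllm.platforms import Platform


class AscendPlatform(Platform):

    @classmethod
    def opaque_attention_op(cls) -> bool:
        # Returning True lets vLLM wrap this backend's attention kernel
        # in its standard torch.ops.vllm.unified_attention_with_output
        # custom op, so no Ascend-specific call path is needed.
        return True
```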
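
The updated unit tests in the diff below show the resulting call
pattern: the caller pre-allocates the output buffer and passes it as
the final positional argument, for example (names illustrative):

```python
output = torch.empty_like(query)
# `forward` fills `output` in place and returns it.
output = impl.forward(layer, query, key, value, kv_cache, metadata, output)
```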

### Does this PR introduce _any_ user-facing change?
None.

### How was this patch tested?
No new tests are needed; the existing attention unit tests were updated
to match the new `forward` signature (see the diff below).

- vLLM version: v0.11.0rc3
- vLLM main: 17c540a993

---------

Signed-off-by: Yizhou Liu <liu_yizhou@outlook.com>
Author: Yizhou
Date: 2025-10-25 08:58:35 +08:00 (committed by GitHub)
Parent: 0b1da24742
Commit: 3158742a97
9 changed files with 191 additions and 349 deletions


```diff
@@ -255,59 +255,6 @@ class TestAscendAttentionBackendImpl(TestBase):
                                     attn_type=self.attention_type.DECODER,
                                     kv_sharing_target_layer_name=None)
 
-    @patch('torch.ops.vllm.unified_ascend_attention_with_output')
-    def test_forward_trace_flag_true(self, mock_unified_attention):
-        """Test forward pass when trace_flag is True"""
-        query = torch.randn(10, 8 * 64)
-        key = torch.randn(10, 8 * 64)
-        value = torch.randn(10, 8 * 64)
-        kv_cache = torch.empty(2, 0, 0, 8, 64)
-        metadata = self.attn_metadata
-        layer = self.layer
-        output = self.impl.forward(layer,
-                                   query,
-                                   key,
-                                   value,
-                                   kv_cache,
-                                   metadata,
-                                   trace_flag=True)
-        mock_unified_attention.assert_called_once()
-        assert output.shape == (10, 8 * 64)
-
-    @patch('torch_npu._npu_paged_attention_splitfuse')
-    def test_forward_with_quant_method(self, mock_paged_attention):
-        """Test forward pass when layer has quant_method"""
-        query = torch.randn(10, 8 * 64)
-        key = torch.randn(10, 8 * 64)
-        value = torch.randn(10, 8 * 64)
-        k_cache = torch.ones(1, 10, 8, 64, dtype=torch.int8)
-        v_cache = torch.ones(1, 10, 8, 64, dtype=torch.int8)
-        kv_cache = [k_cache, v_cache]
-        ret_value = torch.ones(1, 1, 10, 8, 64, dtype=torch.int8)
-        metadata = MagicMock()
-        metadata.num_actual_tokens = torch.randn(10, 8 * 64)
-        metadata.block_tables = torch.randn(10, 8 * 64)
-        metadata.seq_lens = torch.randn(10, 8 * 64)
-        metadata.attn_mask = torch.randn(10, 8 * 64)
-        metadata.query_lens = torch.randn(10, 8 * 64)
-        layer = self.layer
-        layer.quant_method = MagicMock()
-        layer.quant_method.apply.return_value = ret_value
-        output = self.impl.forward(layer,
-                                   query,
-                                   key,
-                                   value,
-                                   kv_cache,
-                                   metadata,
-                                   trace_flag=False)
-        layer.quant_method.apply.assert_called_once()
-        assert output.shape == (10, 8 * 64)
-
     def test_forward_no_attn_metadata(self):
         """Test forward pass when attn_metadata is None"""
         query = torch.randn(10, 8 * 64)
@@ -315,14 +262,10 @@ class TestAscendAttentionBackendImpl(TestBase):
         value = torch.randn(10, 8 * 64)
         kv_cache = torch.empty(2, 0, 0, 8, 64)
         layer = self.layer_no_quant
+        output = torch.empty_like(query)
-        output = self.impl.forward(layer,
-                                   query,
-                                   key,
-                                   value,
-                                   kv_cache,
-                                   None,
-                                   trace_flag=False)
+        output = self.impl.forward(layer, query, key, value, kv_cache, None,
+                                   output)
         assert output.shape == (10, 8 * 64)
@@ -335,6 +278,8 @@ class TestAscendAttentionBackendImpl(TestBase):
         key = torch.randn(10, 8 * 64)
         value = torch.randn(10, 8 * 64)
         kv_cache = torch.empty(2, 5, 128, 8, 64)
+        output = torch.empty_like(query)
         metadata = self.attn_metadata
         metadata.attn_state = AscendAttentionState.PrefillNoCache
         metadata.attn_mask = torch.randn(1, 1, 10, 10)
@@ -344,15 +289,9 @@ class TestAscendAttentionBackendImpl(TestBase):
         metadata.num_decodes = 0
         metadata.num_prefills = 10
         layer = self.layer_no_quant
-        # layer.quant_method.apply.return_value = metadata
         print(self.layer_no_quant._v_scale_float)
-        output = self.impl.forward(layer,
-                                   query,
-                                   key,
-                                   value,
-                                   kv_cache,
-                                   metadata,
-                                   trace_flag=False)
+        output = self.impl.forward(layer, query, key, value, kv_cache,
+                                   metadata, output)
         mock_reshape_cache.assert_called_once()
         mock_flash_attention.assert_called_once()
@@ -367,6 +306,8 @@ class TestAscendAttentionBackendImpl(TestBase):
         key = torch.randn(10, 8 * 64)
         value = torch.randn(10, 8 * 64)
         kv_cache = torch.empty(2, 5, 128, 8, 64)
+        output = torch.empty_like(query)
         metadata = self.attn_metadata
         metadata.attn_state = AscendAttentionState.PrefillCacheHit
         metadata.attn_mask = torch.randn(1, 1, 10, 10)
@@ -379,13 +320,8 @@ class TestAscendAttentionBackendImpl(TestBase):
         metadata.num_prefills = 10
         layer = self.layer_no_quant
-        output = self.impl.forward(layer,
-                                   query,
-                                   key,
-                                   value,
-                                   kv_cache,
-                                   metadata,
-                                   trace_flag=False)
+        output = self.impl.forward(layer, query, key, value, kv_cache,
+                                   metadata, output)
         mock_flash_attention_qlens.assert_called_once()
         assert output.shape == (10, 8 * 64)
@@ -401,6 +337,8 @@ class TestAscendAttentionBackendImpl(TestBase):
         key = torch.randn(10, 8 * 64)
         value = torch.randn(10, 8 * 64)
         kv_cache = torch.empty(2, 5, 128, 8, 64)
+        output = torch.empty_like(query)
         metadata = self.attn_metadata
         metadata.attn_state = AscendAttentionState.DecodeOnly
         metadata.seq_lens = torch.tensor([10])
@@ -413,13 +351,8 @@ class TestAscendAttentionBackendImpl(TestBase):
         mock_get_forward_context.return_value = MagicMock(capturing=False)
-        output = self.impl.forward(layer,
-                                   query,
-                                   key,
-                                   value,
-                                   kv_cache,
-                                   metadata,
-                                   trace_flag=False)
+        output = self.impl.forward(layer, query, key, value, kv_cache,
+                                   metadata, output)
         mock_paged_attention.assert_called_once()
         assert output.shape == (10, 8 * 64)
@@ -509,6 +442,8 @@ class TestAscendAttentionBackendImpl(TestBase):
         key = torch.randn(10, 8 * 64)
         value = torch.randn(10, 8 * 64)
         kv_cache = torch.empty(2, 5, 128, 8, 64)
+        output = torch.empty_like(query)
         metadata = self.attn_metadata
         metadata.attn_state = AscendAttentionState.DecodeOnly
         metadata.seq_lens = torch.tensor([10])
@@ -522,13 +457,8 @@ class TestAscendAttentionBackendImpl(TestBase):
         mock_get_forward_context.return_value = MagicMock(capturing=True)
         mock_get_graph_params.return_value = graph_params
-        output = self.impl.forward(layer,
-                                   query,
-                                   key,
-                                   value,
-                                   kv_cache,
-                                   metadata,
-                                   trace_flag=False)
+        output = self.impl.forward(layer, query, key, value, kv_cache,
+                                   metadata, output)
         mock_paged_attention.assert_called_once()
         self.assertEqual(len(graph_params.handles[num_tokens]), 0)
@@ -542,6 +472,8 @@ class TestAscendAttentionBackendImpl(TestBase):
         key = torch.randn(10, 8 * 64)
         value = torch.randn(10, 8 * 64)
         kv_cache = torch.empty(2, 5, 128, 8, 64)
+        output = torch.empty(10, 8, 64)
         metadata = self.attn_metadata
         metadata.attn_state = AscendAttentionState.DecodeOnly
         metadata.seq_lens = torch.tensor([10] * 10)
@@ -553,16 +485,11 @@ class TestAscendAttentionBackendImpl(TestBase):
         layer = self.layer_no_quant
         mock_fused_infer_attention_score.return_value = (torch.ones(10, 8,
                                                                     64), 1)
-        output = self.impl_swa.forward(layer,
-                                       query,
-                                       key,
-                                       value,
-                                       kv_cache,
-                                       metadata,
-                                       trace_flag=False)
+        output = self.impl_swa.forward(layer, query, key, value, kv_cache,
+                                       metadata, output)
         print(output.shape)
         mock_fused_infer_attention_score.assert_called_once()
-        assert output.shape == (10, 8 * 64)
+        assert output.shape == (10, 8, 64)
 
     @patch('vllm_ascend.attention.attention_v1.get_forward_context')
     @patch('torch_npu._npu_reshape_and_cache')
@@ -576,6 +503,7 @@ class TestAscendAttentionBackendImpl(TestBase):
         key = torch.randn(10, 8 * 64)
         value = torch.randn(10, 8 * 64)
         kv_cache = torch.empty(2, 5, 128, 8, 64)
+        output = torch.empty_like(query)
         metadata = self.attn_metadata
         metadata.attn_state = AscendAttentionState.DecodeOnly
@@ -583,6 +511,7 @@ class TestAscendAttentionBackendImpl(TestBase):
         metadata.block_tables = torch.zeros(1, 5, dtype=torch.long)
         metadata.num_actual_tokens = 10
         metadata.slot_mapping = torch.zeros(10, dtype=torch.long)
+        layer = self.layer_no_quant
         metadata.num_decodes = 10
         metadata.num_prefills = 0
@@ -591,13 +520,8 @@ class TestAscendAttentionBackendImpl(TestBase):
         mock_get_forward_context.return_value = MagicMock(capturing=False)
-        output = self.impl_swa.forward(self.layer_no_quant,
-                                       query,
-                                       key,
-                                       value,
-                                       kv_cache,
-                                       metadata,
-                                       trace_flag=False)
+        output = self.impl_swa.forward(layer, query, key, value, kv_cache,
+                                       metadata, output)
         mock_paged_attention.assert_called_once()
         mock_fused_infer_attention_score.assert_not_called()
@@ -618,6 +542,8 @@ class TestAscendAttentionBackendImpl(TestBase):
         key = torch.randn(10, 8 * 192)
         value = torch.randn(10, 8 * 192)
         kv_cache = torch.empty(2, 5, 128, 8, 192)
+        output = torch.empty_like(query)
         metadata = self.attn_metadata
         metadata.attn_mask = torch.randn(1, 1, 10, 10)
         metadata.query_lens = torch.tensor([10])
@@ -631,13 +557,8 @@ class TestAscendAttentionBackendImpl(TestBase):
         mock_version.cann = "8.4.RC1"
         mock_vanilla_prefill.return_value = MagicMock()
-        output = self.impl_192.forward(layer,
-                                       query,
-                                       key,
-                                       value,
-                                       kv_cache,
-                                       metadata,
-                                       trace_flag=False)
+        output = self.impl_192.forward(layer, query, key, value, kv_cache,
+                                       metadata, output)
         mock_vanilla_prefill.assert_called_once()
         assert output.shape == (10, 8 * 192)
@@ -653,6 +574,8 @@ class TestAscendAttentionBackendImpl(TestBase):
         key = torch.randn(10, 8 * 64)
         value = torch.randn(10, 8 * 64)
         kv_cache = torch.empty(2, 5, 128, 8, 64)
+        output = torch.empty_like(query)
         metadata = self.attn_metadata
         metadata.attn_mask = torch.randn(1, 1, 10, 10)
         metadata.query_lens = torch.tensor([10])
@@ -666,13 +589,8 @@ class TestAscendAttentionBackendImpl(TestBase):
         mock_version.cann = "8.4.RC1"
-        output = self.impl.forward(layer,
-                                   query,
-                                   key,
-                                   value,
-                                   kv_cache,
-                                   metadata,
-                                   trace_flag=False)
+        output = self.impl.forward(layer, query, key, value, kv_cache,
+                                   metadata, output)
         mock_paged_attention.assert_called_once()
         assert output.shape == (10, 8 * 64)
@@ -690,6 +608,8 @@ class TestAscendAttentionBackendImpl(TestBase):
         key = torch.randn(10, 8 * 64)
         value = torch.randn(10, 8 * 64)
         kv_cache = torch.empty(2, 5, 128, 8, 64)
+        output = torch.empty_like(query)
         metadata = self.attn_metadata
         metadata.attn_mask = torch.randn(1, 1, 10, 10)
         metadata.query_lens = torch.tensor([10])
@@ -703,13 +623,9 @@ class TestAscendAttentionBackendImpl(TestBase):
         mock_npu_format_cast.return_value = metadata.attn_mask
         mock_version.cann = "8.4.RC1"
-        output = self.impl.forward(layer,
-                                   query,
-                                   key,
-                                   value,
-                                   kv_cache,
-                                   metadata,
-                                   trace_flag=False)
+        output = self.impl.forward(layer, query, key, value, kv_cache,
+                                   metadata, output)
         mock_paged_attention.assert_called_once()
         assert output.shape == (10, 8 * 64)
@@ -720,6 +636,8 @@ class TestAscendAttentionBackendImpl(TestBase):
         key = torch.randn(10, 8 * 64)
         value = torch.randn(10, 8 * 64)
         kv_cache = torch.empty(2, 5, 128, 8, 64)
+        output = torch.empty_like(query)
         metadata = self.attn_metadata
         metadata.attn_mask = torch.randn(1, 1, 10, 10)
         metadata.query_lens = torch.tensor([10])
@@ -732,10 +650,5 @@ class TestAscendAttentionBackendImpl(TestBase):
         layer = self.layer_no_quant
         with self.assertRaises(NotImplementedError):
-            self.impl_error.forward(layer,
-                                    query,
-                                    key,
-                                    value,
-                                    kv_cache,
-                                    metadata,
-                                    trace_flag=False)
+            self.impl_error.forward(layer, query, key, value, kv_cache,
+                                    metadata, output)
```