[Feat] 310p supports PrefillCacheHit State (#6756)

### What this PR does / why we need it?
This PR extends the Ascend 310P attention backend to support the
`PrefillCacheHit` state. Previously, only `PrefillNoCache`,
`DecodeOnly`, and `ChunkedPrefill` were supported.
This PR handles this state by routing it to the existing
`forward_chunked_prefill_310` implementation, which is suitable for this
scenario.
The changes also include refactoring the main `forward_impl` dispatch
method for better clarity and updating unit tests to cover the new state
and ensure correctness.
### Does this PR introduce _any_ user-facing change?
No
### How was this patch tested?
Accuracy test when chunked prefill is disabled.
- vLLM version: v0.15.0
- vLLM main:
9562912cea

---------

Signed-off-by: pu-zhe <zpuaa@outlook.com>
This commit is contained in:
pu-zhe
2026-02-24 16:48:05 +08:00
committed by GitHub
parent 62ea664aa7
commit a8e951e6f5
3 changed files with 169 additions and 24 deletions

View File

@@ -78,7 +78,7 @@ class TestAscendAttentionBackendImpl310(TestBase):
def test_forward_prefill_310(
self, mock_get_forward_context, mock_npu_npu_flash_attention, mock_npu_reshape_and_cache
):
"""Test forward pass in PrefillCacheHit state"""
"""Test forward pass in PrefillNoCache state"""
query = torch.randn(10, 8, 64)
key = torch.randn(10, 8, 64)
value = torch.randn(10, 8, 64)
@@ -98,7 +98,7 @@ class TestAscendAttentionBackendImpl310(TestBase):
mock_get_forward_context.return_value = MagicMock(capturing=False)
mock_npu_npu_flash_attention.return_value = torch.ones(10, 8, 64)
output = self.impl.forward_prefill_310(query, key, value, metadata, output)
output = self.impl.forward_impl(query, key, value, None, metadata, output)
mock_npu_npu_flash_attention.assert_called_once()
@@ -107,10 +107,15 @@ class TestAscendAttentionBackendImpl310(TestBase):
@patch("torch_npu._npu_paged_attention_splitfuse")
@patch("vllm_ascend.attention.attention_v1.get_forward_context")
def test_forward_chunked_prefill_310(
self, mock_get_forward_context, mock_npu_paged_attention_splitfuse, mock_npu_reshape_and_cache, mock_format_cast
self,
mock_get_forward_context,
mock_npu_paged_attention_splitfuse,
mock_npu_reshape_and_cache,
mock_format_cast,
):
"""Test forward pass in PrefillCacheHit state"""
"""Test forward pass in ChunkedPrefill state"""
query = torch.randn(5, 8, 64)
key, value = None, None
output = torch.empty_like(query)
metadata = self.attn_metadata
metadata.attn_state = AscendAttentionState.ChunkedPrefill
@@ -128,7 +133,42 @@ class TestAscendAttentionBackendImpl310(TestBase):
mock_get_forward_context.return_value = MagicMock(capturing=False)
mock_npu_paged_attention_splitfuse.return_value = torch.ones(5, 8, 64)
output = self.impl.forward_chunked_prefill_310(query, metadata, output)
output = self.impl.forward_impl(query, key, value, None, metadata, output)
mock_npu_paged_attention_splitfuse.assert_called_once()
@patch("torch_npu.npu_format_cast", return_value=torch.randn((1, 128, 16, 16), dtype=torch.float16))
@patch("torch_npu._npu_reshape_and_cache")
@patch("torch_npu._npu_paged_attention_splitfuse")
@patch("vllm_ascend.attention.attention_v1.get_forward_context")
def test_forward_prefill_cache_hit_310(
self,
mock_get_forward_context,
mock_npu_paged_attention_splitfuse,
mock_npu_reshape_and_cache,
mock_format_cast,
):
"""Test forward pass in PrefillCacheHit state"""
query = torch.randn(5, 8, 64)
key, value = None, None
output = torch.empty_like(query)
metadata = self.attn_metadata
metadata.attn_state = AscendAttentionState.PrefillCacheHit
metadata.attn_mask = torch.randn(1, 128, 16, 16)
metadata.query_lens = torch.tensor([5])
metadata.seq_lens = torch.tensor([1, 4])
metadata.query_start_loc = torch.tensor([0, 1, 5])
metadata.actual_seq_lengths_q = [5]
metadata.block_tables = torch.zeros(1, 5, dtype=torch.long)
metadata.num_actual_tokens = 10
metadata.num_decode_tokens = 0
metadata.num_decodes = 0
metadata.num_prefills = 10
metadata.slot_mapping = torch.zeros(10, dtype=torch.long)
mock_get_forward_context.return_value = MagicMock(capturing=False)
mock_npu_paged_attention_splitfuse.return_value = torch.ones(5, 8, 64)
output = self.impl.forward_impl(query, key, value, None, metadata, output)
mock_npu_paged_attention_splitfuse.assert_called_once()
@@ -141,6 +181,7 @@ class TestAscendAttentionBackendImpl310(TestBase):
):
"""Test forward pass in DecodeOnly state"""
query = torch.randn(4, 8 * 64)
key, value = None, None
output = torch.empty_like(query)
metadata = self.attn_metadata
@@ -155,6 +196,15 @@ class TestAscendAttentionBackendImpl310(TestBase):
mock_get_forward_context.return_value = MagicMock(capturing=False)
output = self.impl.forward_paged_attention(query, metadata, output)
output = self.impl.forward_impl(query, key, value, None, metadata, output)
mock_paged_attention.assert_called_once()
def test_forward_mtp_310(self):
query = torch.randn(4, 8 * 64)
key, value = None, None
output = torch.empty_like(query)
metadata = self.attn_metadata
metadata.attn_state = AscendAttentionState.SpecDecoding
with self.assertRaises(NotImplementedError):
output = self.impl.forward_impl(query, key, value, None, metadata, output)