[Feat] 310p supports PrefillCacheHit State (#6756)
### What this PR does / why we need it?
This PR extends the Ascend 310P attention backend to support the
`PrefillCacheHit` state. Previously, only `PrefillNoCache`,
`DecodeOnly`, and `ChunkedPrefill` were supported.
This PR handles this state by routing it to the existing
`forward_chunked_prefill_310` implementation, which is suitable for this
scenario.
The changes also include refactoring the main `forward_impl` dispatch
method for better clarity and updating unit tests to cover the new state
and ensure correctness.
### Does this PR introduce _any_ user-facing change?
No
### How was this patch tested?
Accuracy test when chunked prefill is disabled.
- vLLM version: v0.15.0
- vLLM main:
9562912cea
---------
Signed-off-by: pu-zhe <zpuaa@outlook.com>
This commit is contained in:
@@ -78,7 +78,7 @@ class TestAscendAttentionBackendImpl310(TestBase):
|
||||
def test_forward_prefill_310(
|
||||
self, mock_get_forward_context, mock_npu_npu_flash_attention, mock_npu_reshape_and_cache
|
||||
):
|
||||
"""Test forward pass in PrefillCacheHit state"""
|
||||
"""Test forward pass in PrefillNoCache state"""
|
||||
query = torch.randn(10, 8, 64)
|
||||
key = torch.randn(10, 8, 64)
|
||||
value = torch.randn(10, 8, 64)
|
||||
@@ -98,7 +98,7 @@ class TestAscendAttentionBackendImpl310(TestBase):
|
||||
|
||||
mock_get_forward_context.return_value = MagicMock(capturing=False)
|
||||
mock_npu_npu_flash_attention.return_value = torch.ones(10, 8, 64)
|
||||
output = self.impl.forward_prefill_310(query, key, value, metadata, output)
|
||||
output = self.impl.forward_impl(query, key, value, None, metadata, output)
|
||||
|
||||
mock_npu_npu_flash_attention.assert_called_once()
|
||||
|
||||
@@ -107,10 +107,15 @@ class TestAscendAttentionBackendImpl310(TestBase):
|
||||
@patch("torch_npu._npu_paged_attention_splitfuse")
|
||||
@patch("vllm_ascend.attention.attention_v1.get_forward_context")
|
||||
def test_forward_chunked_prefill_310(
|
||||
self, mock_get_forward_context, mock_npu_paged_attention_splitfuse, mock_npu_reshape_and_cache, mock_format_cast
|
||||
self,
|
||||
mock_get_forward_context,
|
||||
mock_npu_paged_attention_splitfuse,
|
||||
mock_npu_reshape_and_cache,
|
||||
mock_format_cast,
|
||||
):
|
||||
"""Test forward pass in PrefillCacheHit state"""
|
||||
"""Test forward pass in ChunkedPrefill state"""
|
||||
query = torch.randn(5, 8, 64)
|
||||
key, value = None, None
|
||||
output = torch.empty_like(query)
|
||||
metadata = self.attn_metadata
|
||||
metadata.attn_state = AscendAttentionState.ChunkedPrefill
|
||||
@@ -128,7 +133,42 @@ class TestAscendAttentionBackendImpl310(TestBase):
|
||||
|
||||
mock_get_forward_context.return_value = MagicMock(capturing=False)
|
||||
mock_npu_paged_attention_splitfuse.return_value = torch.ones(5, 8, 64)
|
||||
output = self.impl.forward_chunked_prefill_310(query, metadata, output)
|
||||
output = self.impl.forward_impl(query, key, value, None, metadata, output)
|
||||
|
||||
mock_npu_paged_attention_splitfuse.assert_called_once()
|
||||
|
||||
@patch("torch_npu.npu_format_cast", return_value=torch.randn((1, 128, 16, 16), dtype=torch.float16))
|
||||
@patch("torch_npu._npu_reshape_and_cache")
|
||||
@patch("torch_npu._npu_paged_attention_splitfuse")
|
||||
@patch("vllm_ascend.attention.attention_v1.get_forward_context")
|
||||
def test_forward_prefill_cache_hit_310(
|
||||
self,
|
||||
mock_get_forward_context,
|
||||
mock_npu_paged_attention_splitfuse,
|
||||
mock_npu_reshape_and_cache,
|
||||
mock_format_cast,
|
||||
):
|
||||
"""Test forward pass in PrefillCacheHit state"""
|
||||
query = torch.randn(5, 8, 64)
|
||||
key, value = None, None
|
||||
output = torch.empty_like(query)
|
||||
metadata = self.attn_metadata
|
||||
metadata.attn_state = AscendAttentionState.PrefillCacheHit
|
||||
metadata.attn_mask = torch.randn(1, 128, 16, 16)
|
||||
metadata.query_lens = torch.tensor([5])
|
||||
metadata.seq_lens = torch.tensor([1, 4])
|
||||
metadata.query_start_loc = torch.tensor([0, 1, 5])
|
||||
metadata.actual_seq_lengths_q = [5]
|
||||
metadata.block_tables = torch.zeros(1, 5, dtype=torch.long)
|
||||
metadata.num_actual_tokens = 10
|
||||
metadata.num_decode_tokens = 0
|
||||
metadata.num_decodes = 0
|
||||
metadata.num_prefills = 10
|
||||
metadata.slot_mapping = torch.zeros(10, dtype=torch.long)
|
||||
|
||||
mock_get_forward_context.return_value = MagicMock(capturing=False)
|
||||
mock_npu_paged_attention_splitfuse.return_value = torch.ones(5, 8, 64)
|
||||
output = self.impl.forward_impl(query, key, value, None, metadata, output)
|
||||
|
||||
mock_npu_paged_attention_splitfuse.assert_called_once()
|
||||
|
||||
@@ -141,6 +181,7 @@ class TestAscendAttentionBackendImpl310(TestBase):
|
||||
):
|
||||
"""Test forward pass in DecodeOnly state"""
|
||||
query = torch.randn(4, 8 * 64)
|
||||
key, value = None, None
|
||||
output = torch.empty_like(query)
|
||||
|
||||
metadata = self.attn_metadata
|
||||
@@ -155,6 +196,15 @@ class TestAscendAttentionBackendImpl310(TestBase):
|
||||
|
||||
mock_get_forward_context.return_value = MagicMock(capturing=False)
|
||||
|
||||
output = self.impl.forward_paged_attention(query, metadata, output)
|
||||
output = self.impl.forward_impl(query, key, value, None, metadata, output)
|
||||
|
||||
mock_paged_attention.assert_called_once()
|
||||
|
||||
def test_forward_mtp_310(self):
|
||||
query = torch.randn(4, 8 * 64)
|
||||
key, value = None, None
|
||||
output = torch.empty_like(query)
|
||||
metadata = self.attn_metadata
|
||||
metadata.attn_state = AscendAttentionState.SpecDecoding
|
||||
with self.assertRaises(NotImplementedError):
|
||||
output = self.impl.forward_impl(query, key, value, None, metadata, output)
|
||||
|
||||
Reference in New Issue
Block a user