[Feat] 310p supports PrefillCacheHit State (#6756)

### What this PR does / why we need it?
This PR extends the Ascend 310P attention backend to support the
`PrefillCacheHit` state. Previously, only `PrefillNoCache`,
`DecodeOnly`, and `ChunkedPrefill` were supported.
This PR handles this state by routing it to the existing
`forward_chunked_prefill_310` implementation, which is suitable for this
scenario.
The changes also include refactoring the main `forward_impl` dispatch
method for better clarity and updating unit tests to cover the new state
and ensure correctness.
### Does this PR introduce _any_ user-facing change?
No
### How was this patch tested?
Accuracy test when chunked prefill is disabled.
- vLLM version: v0.15.0
- vLLM main:
9562912cea

---------

Signed-off-by: pu-zhe <zpuaa@outlook.com>
This commit is contained in:
pu-zhe
2026-02-24 16:48:05 +08:00
committed by GitHub
parent 62ea664aa7
commit a8e951e6f5
3 changed files with 169 additions and 24 deletions

View File

@@ -78,7 +78,7 @@ class TestAscendAttentionBackendImpl310(TestBase):
def test_forward_prefill_310(
self, mock_get_forward_context, mock_npu_npu_flash_attention, mock_npu_reshape_and_cache
):
"""Test forward pass in PrefillCacheHit state"""
"""Test forward pass in PrefillNoCache state"""
query = torch.randn(10, 8, 64)
key = torch.randn(10, 8, 64)
value = torch.randn(10, 8, 64)
@@ -98,7 +98,7 @@ class TestAscendAttentionBackendImpl310(TestBase):
mock_get_forward_context.return_value = MagicMock(capturing=False)
mock_npu_npu_flash_attention.return_value = torch.ones(10, 8, 64)
output = self.impl.forward_prefill_310(query, key, value, metadata, output)
output = self.impl.forward_impl(query, key, value, None, metadata, output)
mock_npu_npu_flash_attention.assert_called_once()
@@ -107,10 +107,15 @@ class TestAscendAttentionBackendImpl310(TestBase):
@patch("torch_npu._npu_paged_attention_splitfuse")
@patch("vllm_ascend.attention.attention_v1.get_forward_context")
def test_forward_chunked_prefill_310(
self, mock_get_forward_context, mock_npu_paged_attention_splitfuse, mock_npu_reshape_and_cache, mock_format_cast
self,
mock_get_forward_context,
mock_npu_paged_attention_splitfuse,
mock_npu_reshape_and_cache,
mock_format_cast,
):
"""Test forward pass in PrefillCacheHit state"""
"""Test forward pass in ChunkedPrefill state"""
query = torch.randn(5, 8, 64)
key, value = None, None
output = torch.empty_like(query)
metadata = self.attn_metadata
metadata.attn_state = AscendAttentionState.ChunkedPrefill
@@ -128,7 +133,42 @@ class TestAscendAttentionBackendImpl310(TestBase):
mock_get_forward_context.return_value = MagicMock(capturing=False)
mock_npu_paged_attention_splitfuse.return_value = torch.ones(5, 8, 64)
output = self.impl.forward_chunked_prefill_310(query, metadata, output)
output = self.impl.forward_impl(query, key, value, None, metadata, output)
mock_npu_paged_attention_splitfuse.assert_called_once()
@patch("torch_npu.npu_format_cast", return_value=torch.randn((1, 128, 16, 16), dtype=torch.float16))
@patch("torch_npu._npu_reshape_and_cache")
@patch("torch_npu._npu_paged_attention_splitfuse")
@patch("vllm_ascend.attention.attention_v1.get_forward_context")
def test_forward_prefill_cache_hit_310(
self,
mock_get_forward_context,
mock_npu_paged_attention_splitfuse,
mock_npu_reshape_and_cache,
mock_format_cast,
):
"""Test forward pass in PrefillCacheHit state"""
query = torch.randn(5, 8, 64)
key, value = None, None
output = torch.empty_like(query)
metadata = self.attn_metadata
metadata.attn_state = AscendAttentionState.PrefillCacheHit
metadata.attn_mask = torch.randn(1, 128, 16, 16)
metadata.query_lens = torch.tensor([5])
metadata.seq_lens = torch.tensor([1, 4])
metadata.query_start_loc = torch.tensor([0, 1, 5])
metadata.actual_seq_lengths_q = [5]
metadata.block_tables = torch.zeros(1, 5, dtype=torch.long)
metadata.num_actual_tokens = 10
metadata.num_decode_tokens = 0
metadata.num_decodes = 0
metadata.num_prefills = 10
metadata.slot_mapping = torch.zeros(10, dtype=torch.long)
mock_get_forward_context.return_value = MagicMock(capturing=False)
mock_npu_paged_attention_splitfuse.return_value = torch.ones(5, 8, 64)
output = self.impl.forward_impl(query, key, value, None, metadata, output)
mock_npu_paged_attention_splitfuse.assert_called_once()
@@ -141,6 +181,7 @@ class TestAscendAttentionBackendImpl310(TestBase):
):
"""Test forward pass in DecodeOnly state"""
query = torch.randn(4, 8 * 64)
key, value = None, None
output = torch.empty_like(query)
metadata = self.attn_metadata
@@ -155,6 +196,15 @@ class TestAscendAttentionBackendImpl310(TestBase):
mock_get_forward_context.return_value = MagicMock(capturing=False)
output = self.impl.forward_paged_attention(query, metadata, output)
output = self.impl.forward_impl(query, key, value, None, metadata, output)
mock_paged_attention.assert_called_once()
def test_forward_mtp_310(self):
query = torch.randn(4, 8 * 64)
key, value = None, None
output = torch.empty_like(query)
metadata = self.attn_metadata
metadata.attn_state = AscendAttentionState.SpecDecoding
with self.assertRaises(NotImplementedError):
output = self.impl.forward_impl(query, key, value, None, metadata, output)

View File

@@ -198,6 +198,8 @@ class AscendAttentionBackendImpl310(AscendAttentionBackendImpl):
out=output,
)
return output
def forward_impl(self, query, key, value, kv_cache, attn_metadata, output):
"""
Main dispatch method for attention operations.
@@ -218,22 +220,19 @@ class AscendAttentionBackendImpl310(AscendAttentionBackendImpl):
NotImplementedError: If the attention state is not supported on 310P.
"""
state = attn_metadata.attn_state
if state == AscendAttentionState.DecodeOnly:
return self.forward_paged_attention(query, attn_metadata, output)
# Condition for PrefillNoCache: No previous tokens have been processed yet
if state == AscendAttentionState.PrefillNoCache:
out = self.forward_prefill_310(query, key, value, attn_metadata, output)
return out
if state == AscendAttentionState.ChunkedPrefill:
self.forward_chunked_prefill_310(query, attn_metadata, output)
return output
raise NotImplementedError(
f"{self.__class__.__name__}.forward_impl: 310P only supports "
f"{AscendAttentionState.DecodeOnly.name}, "
f"{AscendAttentionState.PrefillNoCache.name}, "
f"{AscendAttentionState.ChunkedPrefill.name}, "
f"got {state!r}."
)
output = self.forward_prefill_310(query, key, value, attn_metadata, output)
# Condition for DecodeOnly: Pure decoding phase where each request generates one token
elif state == AscendAttentionState.DecodeOnly:
output = self.forward_paged_attention(query, attn_metadata, output)
# Condition for ChunkedPrefill:
# 1. During speculative decoding scenarios (except mtp)
# 2. Processing large prefill requests in chunks
# Condition for PrefillCacheHit: Indicates prefill with some cached tokens already processed
elif state in [AscendAttentionState.ChunkedPrefill, AscendAttentionState.PrefillCacheHit]:
output = self.forward_chunked_prefill_310(query, attn_metadata, output)
# Condition for SpecDecoding: Specified for mtp, which is not supported yet.
else:
raise NotImplementedError(f"AscendAttentionState: {state} is not supported for 310P currently.")
return output

View File

@@ -17,9 +17,11 @@
from __future__ import annotations
import numpy as np
import torch
import torch_npu
from vllm.logger import logger
from vllm.v1.core.sched.output import SchedulerOutput
from vllm.v1.kv_cache_interface import AttentionSpec, KVCacheConfig, MambaSpec
from vllm_ascend.utils import ACL_FORMAT_FRACTAL_NZ
@@ -185,3 +187,97 @@ class NPUModelRunner310(NPUModelRunner):
raise ValueError("Unknown KV cache spec type.")
return kv_caches
# Override this function because of tensor.copy_(other) accuracy issue.
# TODO: This override will be removed after tensor.copy_(other) accuracy issue is resolved.
def _prepare_input_ids(
self,
scheduler_output: SchedulerOutput,
total_num_scheduled_tokens: int,
cu_num_tokens: np.ndarray,
) -> None:
"""Prepare the input IDs for the current batch.
Carefully handles the `prev_sampled_token_ids` which can be cached
from the previous engine iteration, in which case those tokens on the
GPU need to be copied into the corresponding slots into input_ids."""
if self.input_batch.prev_sampled_token_ids is None:
# Normal scheduling case
self.input_ids.copy_to_gpu(total_num_scheduled_tokens)
if self.enable_prompt_embeds:
self.inputs_embeds.copy_to_gpu(total_num_scheduled_tokens)
self.is_token_ids.copy_to_gpu(total_num_scheduled_tokens)
return
# Async scheduling case, where some decode requests from the previous
# iteration won't have entries in input_ids_cpu and need to be copied
# on the NPU from prev_sampled_token_ids.
prev_req_id_to_index = self.input_batch.prev_req_id_to_index
assert prev_req_id_to_index is not None
sample_flattened_indices: list[int] = []
spec_flattened_indices: list[int] = []
prev_common_req_indices: list[int] = []
prev_draft_token_indices: list[int] = []
indices_match = True
max_flattened_index = -1
total_num_spec_tokens = 0
scheduled_spec_tokens = scheduler_output.scheduled_spec_decode_tokens
for req_id, cur_index in self.input_batch.req_id_to_index.items():
if (prev_index := prev_req_id_to_index.get(req_id)) is not None:
prev_common_req_indices.append(prev_index)
draft_len = len(scheduled_spec_tokens.get(req_id, ()))
total_num_spec_tokens += draft_len
flattened_index = cu_num_tokens[cur_index].item() - 1
sample_flattened_indices.append(flattened_index - draft_len)
spec_flattened_indices.extend(range(flattened_index - draft_len + 1, flattened_index + 1))
start = prev_index * self.num_spec_tokens
prev_draft_token_indices.extend(range(start, start + draft_len))
indices_match &= prev_index == flattened_index
max_flattened_index = max(max_flattened_index, flattened_index)
num_commmon_tokens = len(sample_flattened_indices)
total_without_spec = total_num_scheduled_tokens - total_num_spec_tokens
if num_commmon_tokens < total_without_spec:
self.input_ids.copy_to_gpu(total_num_scheduled_tokens)
if self.enable_prompt_embeds:
self.inputs_embeds.copy_to_gpu(total_num_scheduled_tokens)
self.is_token_ids.copy_to_gpu(total_num_scheduled_tokens)
if num_commmon_tokens == 0:
return
if indices_match and max_flattened_index == (num_commmon_tokens - 1):
# NOTE: Override the copy_ function here
indices = torch.arange(num_commmon_tokens, device=self.input_ids.gpu.device)
source = self.input_batch.prev_sampled_token_ids[:num_commmon_tokens, 0]
self.input_ids.gpu.index_copy_(0, indices, source)
if self.enable_prompt_embeds:
self.is_token_ids.gpu[:num_commmon_tokens] = True
return
# Upload the index tensors asynchronously so the scatter can be non-blocking.
sampled_tokens_index_tensor = torch.tensor(
sample_flattened_indices, dtype=torch.int64, pin_memory=self.pin_memory
).to(self.device, non_blocking=True)
prev_common_req_indices_tensor = torch.tensor(
prev_common_req_indices, dtype=torch.int64, pin_memory=self.pin_memory
).to(self.device, non_blocking=True)
self.input_ids.gpu.scatter_(
dim=0,
index=sampled_tokens_index_tensor,
src=self.input_batch.prev_sampled_token_ids[prev_common_req_indices_tensor, 0],
)
# Scatter the draft tokens after the sampled tokens are scattered.
if self._draft_token_ids is None or not spec_flattened_indices:
return
assert isinstance(self._draft_token_ids, torch.Tensor)
draft_tokens_index_tensor = torch.tensor(
spec_flattened_indices, dtype=torch.int64, pin_memory=self.pin_memory
).to(self.device, non_blocking=True)
prev_draft_token_indices_tensor = torch.tensor(
prev_draft_token_indices, dtype=torch.int64, pin_memory=self.pin_memory
).to(self.device, non_blocking=True)
draft_token_ids = self._draft_token_ids.to(dtype=torch.int32)
self.input_ids.gpu.scatter_(
dim=0,
index=draft_tokens_index_tensor,
src=draft_token_ids.flatten()[prev_draft_token_indices_tensor],
)