[Feat] 310p supports PrefillCacheHit State (#6756)
### What this PR does / why we need it?
This PR extends the Ascend 310P attention backend to support the
`PrefillCacheHit` state. Previously, only `PrefillNoCache`,
`DecodeOnly`, and `ChunkedPrefill` were supported.
This PR handles the new state by routing it to the existing
`forward_chunked_prefill_310` implementation, which already covers this
scenario.
The changes also refactor the main `forward_impl` dispatch method for
clarity and update the unit tests to cover the new state and verify
correctness.
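For reference, the resulting dispatch reads roughly as follows (a condensed
sketch of the new `forward_impl` from the diff below; comments and error text
abbreviated):

```python
def forward_impl(self, query, key, value, kv_cache, attn_metadata, output):
    state = attn_metadata.attn_state
    if state == AscendAttentionState.PrefillNoCache:
        output = self.forward_prefill_310(query, key, value, attn_metadata, output)
    elif state == AscendAttentionState.DecodeOnly:
        output = self.forward_paged_attention(query, attn_metadata, output)
    elif state in [AscendAttentionState.ChunkedPrefill, AscendAttentionState.PrefillCacheHit]:
        # PrefillCacheHit reuses the chunked-prefill kernel path.
        output = self.forward_chunked_prefill_310(query, attn_metadata, output)
    else:
        # e.g. SpecDecoding (MTP), which is not supported on 310P yet.
        raise NotImplementedError(f"AscendAttentionState: {state} is not supported for 310P currently.")
    return output
```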
### Does this PR introduce _any_ user-facing change?
No
### How was this patch tested?
Accuracy test with chunked prefill disabled.
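A minimal setup along these lines should exercise the new path (a sketch,
assuming a standard vLLM launch; the model name is only a placeholder):

```python
from vllm import LLM

# Hypothetical smoke test: with chunked prefill disabled and prefix caching
# enabled, a request that reuses a cached prompt prefix should reach the
# PrefillCacheHit state on 310P.
llm = LLM(model="Qwen/Qwen2.5-0.5B-Instruct",
          enable_chunked_prefill=False,
          enable_prefix_caching=True)
print(llm.generate("Hello, world")[0].outputs[0].text)
```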
- vLLM version: v0.15.0
- vLLM main:
9562912cea
---------
Signed-off-by: pu-zhe <zpuaa@outlook.com>
@@ -78,7 +78,7 @@ class TestAscendAttentionBackendImpl310(TestBase):
     def test_forward_prefill_310(
         self, mock_get_forward_context, mock_npu_npu_flash_attention, mock_npu_reshape_and_cache
     ):
-        """Test forward pass in PrefillCacheHit state"""
+        """Test forward pass in PrefillNoCache state"""
         query = torch.randn(10, 8, 64)
         key = torch.randn(10, 8, 64)
         value = torch.randn(10, 8, 64)
@@ -98,7 +98,7 @@ class TestAscendAttentionBackendImpl310(TestBase):

         mock_get_forward_context.return_value = MagicMock(capturing=False)
         mock_npu_npu_flash_attention.return_value = torch.ones(10, 8, 64)
-        output = self.impl.forward_prefill_310(query, key, value, metadata, output)
+        output = self.impl.forward_impl(query, key, value, None, metadata, output)

         mock_npu_npu_flash_attention.assert_called_once()
@@ -107,10 +107,15 @@ class TestAscendAttentionBackendImpl310(TestBase):
     @patch("torch_npu._npu_paged_attention_splitfuse")
     @patch("vllm_ascend.attention.attention_v1.get_forward_context")
     def test_forward_chunked_prefill_310(
-        self, mock_get_forward_context, mock_npu_paged_attention_splitfuse, mock_npu_reshape_and_cache, mock_format_cast
+        self,
+        mock_get_forward_context,
+        mock_npu_paged_attention_splitfuse,
+        mock_npu_reshape_and_cache,
+        mock_format_cast,
     ):
-        """Test forward pass in PrefillCacheHit state"""
+        """Test forward pass in ChunkedPrefill state"""
         query = torch.randn(5, 8, 64)
         key, value = None, None
         output = torch.empty_like(query)
         metadata = self.attn_metadata
         metadata.attn_state = AscendAttentionState.ChunkedPrefill
@@ -128,7 +133,42 @@ class TestAscendAttentionBackendImpl310(TestBase):

         mock_get_forward_context.return_value = MagicMock(capturing=False)
         mock_npu_paged_attention_splitfuse.return_value = torch.ones(5, 8, 64)
-        output = self.impl.forward_chunked_prefill_310(query, metadata, output)
+        output = self.impl.forward_impl(query, key, value, None, metadata, output)

         mock_npu_paged_attention_splitfuse.assert_called_once()

+    @patch("torch_npu.npu_format_cast", return_value=torch.randn((1, 128, 16, 16), dtype=torch.float16))
+    @patch("torch_npu._npu_reshape_and_cache")
+    @patch("torch_npu._npu_paged_attention_splitfuse")
+    @patch("vllm_ascend.attention.attention_v1.get_forward_context")
+    def test_forward_prefill_cache_hit_310(
+        self,
+        mock_get_forward_context,
+        mock_npu_paged_attention_splitfuse,
+        mock_npu_reshape_and_cache,
+        mock_format_cast,
+    ):
+        """Test forward pass in PrefillCacheHit state"""
+        query = torch.randn(5, 8, 64)
+        key, value = None, None
+        output = torch.empty_like(query)
+        metadata = self.attn_metadata
+        metadata.attn_state = AscendAttentionState.PrefillCacheHit
+        metadata.attn_mask = torch.randn(1, 128, 16, 16)
+        metadata.query_lens = torch.tensor([5])
+        metadata.seq_lens = torch.tensor([1, 4])
+        metadata.query_start_loc = torch.tensor([0, 1, 5])
+        metadata.actual_seq_lengths_q = [5]
+        metadata.block_tables = torch.zeros(1, 5, dtype=torch.long)
+        metadata.num_actual_tokens = 10
+        metadata.num_decode_tokens = 0
+        metadata.num_decodes = 0
+        metadata.num_prefills = 10
+        metadata.slot_mapping = torch.zeros(10, dtype=torch.long)
+
+        mock_get_forward_context.return_value = MagicMock(capturing=False)
+        mock_npu_paged_attention_splitfuse.return_value = torch.ones(5, 8, 64)
+        output = self.impl.forward_impl(query, key, value, None, metadata, output)
+
+        mock_npu_paged_attention_splitfuse.assert_called_once()
@@ -141,6 +181,7 @@ class TestAscendAttentionBackendImpl310(TestBase):
     ):
         """Test forward pass in DecodeOnly state"""
         query = torch.randn(4, 8 * 64)
+        key, value = None, None
         output = torch.empty_like(query)

         metadata = self.attn_metadata
@@ -155,6 +196,15 @@ class TestAscendAttentionBackendImpl310(TestBase):

         mock_get_forward_context.return_value = MagicMock(capturing=False)

-        output = self.impl.forward_paged_attention(query, metadata, output)
+        output = self.impl.forward_impl(query, key, value, None, metadata, output)

         mock_paged_attention.assert_called_once()

+    def test_forward_mtp_310(self):
+        query = torch.randn(4, 8 * 64)
+        key, value = None, None
+        output = torch.empty_like(query)
+        metadata = self.attn_metadata
+        metadata.attn_state = AscendAttentionState.SpecDecoding
+        with self.assertRaises(NotImplementedError):
+            output = self.impl.forward_impl(query, key, value, None, metadata, output)
@@ -198,6 +198,8 @@ class AscendAttentionBackendImpl310(AscendAttentionBackendImpl):
             out=output,
         )

+        return output
+
     def forward_impl(self, query, key, value, kv_cache, attn_metadata, output):
         """
         Main dispatch method for attention operations.
@@ -218,22 +220,19 @@ class AscendAttentionBackendImpl310(AscendAttentionBackendImpl):
             NotImplementedError: If the attention state is not supported on 310P.
         """
         state = attn_metadata.attn_state
-
-        if state == AscendAttentionState.DecodeOnly:
-            return self.forward_paged_attention(query, attn_metadata, output)
-
         # Condition for PrefillNoCache: No previous tokens have been processed yet
         if state == AscendAttentionState.PrefillNoCache:
-            out = self.forward_prefill_310(query, key, value, attn_metadata, output)
-            return out
-
-        if state == AscendAttentionState.ChunkedPrefill:
-            self.forward_chunked_prefill_310(query, attn_metadata, output)
-            return output
-
-        raise NotImplementedError(
-            f"{self.__class__.__name__}.forward_impl: 310P only supports "
-            f"{AscendAttentionState.DecodeOnly.name}, "
-            f"{AscendAttentionState.PrefillNoCache.name}, "
-            f"{AscendAttentionState.ChunkedPrefill.name}, "
-            f"got {state!r}."
-        )
+            output = self.forward_prefill_310(query, key, value, attn_metadata, output)
+        # Condition for DecodeOnly: Pure decoding phase where each request generates one token
+        elif state == AscendAttentionState.DecodeOnly:
+            output = self.forward_paged_attention(query, attn_metadata, output)
+        # Condition for ChunkedPrefill:
+        # 1. During speculative decoding scenarios (except mtp)
+        # 2. Processing large prefill requests in chunks
+        # Condition for PrefillCacheHit: Indicates prefill with some cached tokens already processed
+        elif state in [AscendAttentionState.ChunkedPrefill, AscendAttentionState.PrefillCacheHit]:
+            output = self.forward_chunked_prefill_310(query, attn_metadata, output)
+        # Condition for SpecDecoding: Specified for mtp, which is not supported yet.
+        else:
+            raise NotImplementedError(f"AscendAttentionState: {state} is not supported for 310P currently.")
+        return output
@@ -17,9 +17,11 @@

 from __future__ import annotations

+import numpy as np
 import torch
 import torch_npu
 from vllm.logger import logger
+from vllm.v1.core.sched.output import SchedulerOutput
 from vllm.v1.kv_cache_interface import AttentionSpec, KVCacheConfig, MambaSpec

 from vllm_ascend.utils import ACL_FORMAT_FRACTAL_NZ
@@ -185,3 +187,97 @@ class NPUModelRunner310(NPUModelRunner):
             raise ValueError("Unknown KV cache spec type.")

         return kv_caches
+
+    # Override this function because of tensor.copy_(other) accuracy issue.
+    # TODO: This override will be removed after tensor.copy_(other) accuracy issue is resolved.
+    def _prepare_input_ids(
+        self,
+        scheduler_output: SchedulerOutput,
+        total_num_scheduled_tokens: int,
+        cu_num_tokens: np.ndarray,
+    ) -> None:
+        """Prepare the input IDs for the current batch.
+
+        Carefully handles the `prev_sampled_token_ids` which can be cached
+        from the previous engine iteration, in which case those tokens on the
+        GPU need to be copied into the corresponding slots into input_ids."""
+
+        if self.input_batch.prev_sampled_token_ids is None:
+            # Normal scheduling case
+            self.input_ids.copy_to_gpu(total_num_scheduled_tokens)
+            if self.enable_prompt_embeds:
+                self.inputs_embeds.copy_to_gpu(total_num_scheduled_tokens)
+                self.is_token_ids.copy_to_gpu(total_num_scheduled_tokens)
+            return
+
+        # Async scheduling case, where some decode requests from the previous
+        # iteration won't have entries in input_ids_cpu and need to be copied
+        # on the NPU from prev_sampled_token_ids.
+        prev_req_id_to_index = self.input_batch.prev_req_id_to_index
+        assert prev_req_id_to_index is not None
+        sample_flattened_indices: list[int] = []
+        spec_flattened_indices: list[int] = []
+        prev_common_req_indices: list[int] = []
+        prev_draft_token_indices: list[int] = []
+        indices_match = True
+        max_flattened_index = -1
+        total_num_spec_tokens = 0
+        scheduled_spec_tokens = scheduler_output.scheduled_spec_decode_tokens
+
+        for req_id, cur_index in self.input_batch.req_id_to_index.items():
+            if (prev_index := prev_req_id_to_index.get(req_id)) is not None:
+                prev_common_req_indices.append(prev_index)
+                draft_len = len(scheduled_spec_tokens.get(req_id, ()))
+                total_num_spec_tokens += draft_len
+                flattened_index = cu_num_tokens[cur_index].item() - 1
+                sample_flattened_indices.append(flattened_index - draft_len)
+                spec_flattened_indices.extend(range(flattened_index - draft_len + 1, flattened_index + 1))
+                start = prev_index * self.num_spec_tokens
+                prev_draft_token_indices.extend(range(start, start + draft_len))
+                indices_match &= prev_index == flattened_index
+                max_flattened_index = max(max_flattened_index, flattened_index)
+        num_commmon_tokens = len(sample_flattened_indices)
+        total_without_spec = total_num_scheduled_tokens - total_num_spec_tokens
+        if num_commmon_tokens < total_without_spec:
+            self.input_ids.copy_to_gpu(total_num_scheduled_tokens)
+            if self.enable_prompt_embeds:
+                self.inputs_embeds.copy_to_gpu(total_num_scheduled_tokens)
+                self.is_token_ids.copy_to_gpu(total_num_scheduled_tokens)
+        if num_commmon_tokens == 0:
+            return
+        if indices_match and max_flattened_index == (num_commmon_tokens - 1):
+            # NOTE: Override the copy_ function here
+            indices = torch.arange(num_commmon_tokens, device=self.input_ids.gpu.device)
+            source = self.input_batch.prev_sampled_token_ids[:num_commmon_tokens, 0]
+            self.input_ids.gpu.index_copy_(0, indices, source)
+            if self.enable_prompt_embeds:
+                self.is_token_ids.gpu[:num_commmon_tokens] = True
+            return
+        # Upload the index tensors asynchronously so the scatter can be non-blocking.
+        sampled_tokens_index_tensor = torch.tensor(
+            sample_flattened_indices, dtype=torch.int64, pin_memory=self.pin_memory
+        ).to(self.device, non_blocking=True)
+        prev_common_req_indices_tensor = torch.tensor(
+            prev_common_req_indices, dtype=torch.int64, pin_memory=self.pin_memory
+        ).to(self.device, non_blocking=True)
+        self.input_ids.gpu.scatter_(
+            dim=0,
+            index=sampled_tokens_index_tensor,
+            src=self.input_batch.prev_sampled_token_ids[prev_common_req_indices_tensor, 0],
+        )
+        # Scatter the draft tokens after the sampled tokens are scattered.
+        if self._draft_token_ids is None or not spec_flattened_indices:
+            return
+        assert isinstance(self._draft_token_ids, torch.Tensor)
+        draft_tokens_index_tensor = torch.tensor(
+            spec_flattened_indices, dtype=torch.int64, pin_memory=self.pin_memory
+        ).to(self.device, non_blocking=True)
+        prev_draft_token_indices_tensor = torch.tensor(
+            prev_draft_token_indices, dtype=torch.int64, pin_memory=self.pin_memory
+        ).to(self.device, non_blocking=True)
+        draft_token_ids = self._draft_token_ids.to(dtype=torch.int32)
+        self.input_ids.gpu.scatter_(
+            dim=0,
+            index=draft_tokens_index_tensor,
+            src=draft_token_ids.flatten()[prev_draft_token_indices_tensor],
+        )
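
A side note on the `tensor.copy_(other)` workaround above: when the request
indices line up contiguously, the override falls back to an explicit
`index_copy_` rather than a bulk `copy_`. A minimal, self-contained
illustration of that index pattern (tensor values are made up for the
example):

```python
import torch

# Hypothetical stand-ins for input_ids.gpu and prev_sampled_token_ids.
input_ids = torch.zeros(8, dtype=torch.int32)
prev_sampled_token_ids = torch.tensor([[11], [22], [33], [44]], dtype=torch.int32)

num_common_tokens = 4
# Explicit-index equivalent of input_ids[:4].copy_(prev_sampled_token_ids[:4, 0]):
indices = torch.arange(num_common_tokens)
input_ids.index_copy_(0, indices, prev_sampled_token_ids[:num_common_tokens, 0])
print(input_ids)  # tensor([11, 22, 33, 44, 0, 0, 0, 0], dtype=torch.int32)
```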