[Feat] 310p supports PrefillCacheHit State (#6756)
### What this PR does / why we need it?
This PR extends the Ascend 310P attention backend to support the
`PrefillCacheHit` state. Previously, only `PrefillNoCache`,
`DecodeOnly`, and `ChunkedPrefill` were supported.
This PR handles this state by routing it to the existing
`forward_chunked_prefill_310` implementation, which is suitable for this
scenario.
The changes also include refactoring the main `forward_impl` dispatch
method for better clarity and updating unit tests to cover the new state
and ensure correctness.
### Does this PR introduce _any_ user-facing change?
No
### How was this patch tested?
Accuracy test when chunked prefill is disabled.
- vLLM version: v0.15.0
- vLLM main:
9562912cea
---------
Signed-off-by: pu-zhe <zpuaa@outlook.com>
This commit is contained in:
@@ -78,7 +78,7 @@ class TestAscendAttentionBackendImpl310(TestBase):
|
|||||||
def test_forward_prefill_310(
|
def test_forward_prefill_310(
|
||||||
self, mock_get_forward_context, mock_npu_npu_flash_attention, mock_npu_reshape_and_cache
|
self, mock_get_forward_context, mock_npu_npu_flash_attention, mock_npu_reshape_and_cache
|
||||||
):
|
):
|
||||||
"""Test forward pass in PrefillCacheHit state"""
|
"""Test forward pass in PrefillNoCache state"""
|
||||||
query = torch.randn(10, 8, 64)
|
query = torch.randn(10, 8, 64)
|
||||||
key = torch.randn(10, 8, 64)
|
key = torch.randn(10, 8, 64)
|
||||||
value = torch.randn(10, 8, 64)
|
value = torch.randn(10, 8, 64)
|
||||||
@@ -98,7 +98,7 @@ class TestAscendAttentionBackendImpl310(TestBase):
|
|||||||
|
|
||||||
mock_get_forward_context.return_value = MagicMock(capturing=False)
|
mock_get_forward_context.return_value = MagicMock(capturing=False)
|
||||||
mock_npu_npu_flash_attention.return_value = torch.ones(10, 8, 64)
|
mock_npu_npu_flash_attention.return_value = torch.ones(10, 8, 64)
|
||||||
output = self.impl.forward_prefill_310(query, key, value, metadata, output)
|
output = self.impl.forward_impl(query, key, value, None, metadata, output)
|
||||||
|
|
||||||
mock_npu_npu_flash_attention.assert_called_once()
|
mock_npu_npu_flash_attention.assert_called_once()
|
||||||
|
|
||||||
@@ -107,10 +107,15 @@ class TestAscendAttentionBackendImpl310(TestBase):
|
|||||||
@patch("torch_npu._npu_paged_attention_splitfuse")
|
@patch("torch_npu._npu_paged_attention_splitfuse")
|
||||||
@patch("vllm_ascend.attention.attention_v1.get_forward_context")
|
@patch("vllm_ascend.attention.attention_v1.get_forward_context")
|
||||||
def test_forward_chunked_prefill_310(
|
def test_forward_chunked_prefill_310(
|
||||||
self, mock_get_forward_context, mock_npu_paged_attention_splitfuse, mock_npu_reshape_and_cache, mock_format_cast
|
self,
|
||||||
|
mock_get_forward_context,
|
||||||
|
mock_npu_paged_attention_splitfuse,
|
||||||
|
mock_npu_reshape_and_cache,
|
||||||
|
mock_format_cast,
|
||||||
):
|
):
|
||||||
"""Test forward pass in PrefillCacheHit state"""
|
"""Test forward pass in ChunkedPrefill state"""
|
||||||
query = torch.randn(5, 8, 64)
|
query = torch.randn(5, 8, 64)
|
||||||
|
key, value = None, None
|
||||||
output = torch.empty_like(query)
|
output = torch.empty_like(query)
|
||||||
metadata = self.attn_metadata
|
metadata = self.attn_metadata
|
||||||
metadata.attn_state = AscendAttentionState.ChunkedPrefill
|
metadata.attn_state = AscendAttentionState.ChunkedPrefill
|
||||||
@@ -128,7 +133,42 @@ class TestAscendAttentionBackendImpl310(TestBase):
|
|||||||
|
|
||||||
mock_get_forward_context.return_value = MagicMock(capturing=False)
|
mock_get_forward_context.return_value = MagicMock(capturing=False)
|
||||||
mock_npu_paged_attention_splitfuse.return_value = torch.ones(5, 8, 64)
|
mock_npu_paged_attention_splitfuse.return_value = torch.ones(5, 8, 64)
|
||||||
output = self.impl.forward_chunked_prefill_310(query, metadata, output)
|
output = self.impl.forward_impl(query, key, value, None, metadata, output)
|
||||||
|
|
||||||
|
mock_npu_paged_attention_splitfuse.assert_called_once()
|
||||||
|
|
||||||
|
@patch("torch_npu.npu_format_cast", return_value=torch.randn((1, 128, 16, 16), dtype=torch.float16))
|
||||||
|
@patch("torch_npu._npu_reshape_and_cache")
|
||||||
|
@patch("torch_npu._npu_paged_attention_splitfuse")
|
||||||
|
@patch("vllm_ascend.attention.attention_v1.get_forward_context")
|
||||||
|
def test_forward_prefill_cache_hit_310(
|
||||||
|
self,
|
||||||
|
mock_get_forward_context,
|
||||||
|
mock_npu_paged_attention_splitfuse,
|
||||||
|
mock_npu_reshape_and_cache,
|
||||||
|
mock_format_cast,
|
||||||
|
):
|
||||||
|
"""Test forward pass in PrefillCacheHit state"""
|
||||||
|
query = torch.randn(5, 8, 64)
|
||||||
|
key, value = None, None
|
||||||
|
output = torch.empty_like(query)
|
||||||
|
metadata = self.attn_metadata
|
||||||
|
metadata.attn_state = AscendAttentionState.PrefillCacheHit
|
||||||
|
metadata.attn_mask = torch.randn(1, 128, 16, 16)
|
||||||
|
metadata.query_lens = torch.tensor([5])
|
||||||
|
metadata.seq_lens = torch.tensor([1, 4])
|
||||||
|
metadata.query_start_loc = torch.tensor([0, 1, 5])
|
||||||
|
metadata.actual_seq_lengths_q = [5]
|
||||||
|
metadata.block_tables = torch.zeros(1, 5, dtype=torch.long)
|
||||||
|
metadata.num_actual_tokens = 10
|
||||||
|
metadata.num_decode_tokens = 0
|
||||||
|
metadata.num_decodes = 0
|
||||||
|
metadata.num_prefills = 10
|
||||||
|
metadata.slot_mapping = torch.zeros(10, dtype=torch.long)
|
||||||
|
|
||||||
|
mock_get_forward_context.return_value = MagicMock(capturing=False)
|
||||||
|
mock_npu_paged_attention_splitfuse.return_value = torch.ones(5, 8, 64)
|
||||||
|
output = self.impl.forward_impl(query, key, value, None, metadata, output)
|
||||||
|
|
||||||
mock_npu_paged_attention_splitfuse.assert_called_once()
|
mock_npu_paged_attention_splitfuse.assert_called_once()
|
||||||
|
|
||||||
@@ -141,6 +181,7 @@ class TestAscendAttentionBackendImpl310(TestBase):
|
|||||||
):
|
):
|
||||||
"""Test forward pass in DecodeOnly state"""
|
"""Test forward pass in DecodeOnly state"""
|
||||||
query = torch.randn(4, 8 * 64)
|
query = torch.randn(4, 8 * 64)
|
||||||
|
key, value = None, None
|
||||||
output = torch.empty_like(query)
|
output = torch.empty_like(query)
|
||||||
|
|
||||||
metadata = self.attn_metadata
|
metadata = self.attn_metadata
|
||||||
@@ -155,6 +196,15 @@ class TestAscendAttentionBackendImpl310(TestBase):
|
|||||||
|
|
||||||
mock_get_forward_context.return_value = MagicMock(capturing=False)
|
mock_get_forward_context.return_value = MagicMock(capturing=False)
|
||||||
|
|
||||||
output = self.impl.forward_paged_attention(query, metadata, output)
|
output = self.impl.forward_impl(query, key, value, None, metadata, output)
|
||||||
|
|
||||||
mock_paged_attention.assert_called_once()
|
mock_paged_attention.assert_called_once()
|
||||||
|
|
||||||
|
def test_forward_mtp_310(self):
|
||||||
|
query = torch.randn(4, 8 * 64)
|
||||||
|
key, value = None, None
|
||||||
|
output = torch.empty_like(query)
|
||||||
|
metadata = self.attn_metadata
|
||||||
|
metadata.attn_state = AscendAttentionState.SpecDecoding
|
||||||
|
with self.assertRaises(NotImplementedError):
|
||||||
|
output = self.impl.forward_impl(query, key, value, None, metadata, output)
|
||||||
|
|||||||
@@ -198,6 +198,8 @@ class AscendAttentionBackendImpl310(AscendAttentionBackendImpl):
|
|||||||
out=output,
|
out=output,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
return output
|
||||||
|
|
||||||
def forward_impl(self, query, key, value, kv_cache, attn_metadata, output):
|
def forward_impl(self, query, key, value, kv_cache, attn_metadata, output):
|
||||||
"""
|
"""
|
||||||
Main dispatch method for attention operations.
|
Main dispatch method for attention operations.
|
||||||
@@ -218,22 +220,19 @@ class AscendAttentionBackendImpl310(AscendAttentionBackendImpl):
|
|||||||
NotImplementedError: If the attention state is not supported on 310P.
|
NotImplementedError: If the attention state is not supported on 310P.
|
||||||
"""
|
"""
|
||||||
state = attn_metadata.attn_state
|
state = attn_metadata.attn_state
|
||||||
|
# Condition for PrefillNoCache: No previous tokens have been processed yet
|
||||||
if state == AscendAttentionState.DecodeOnly:
|
|
||||||
return self.forward_paged_attention(query, attn_metadata, output)
|
|
||||||
|
|
||||||
if state == AscendAttentionState.PrefillNoCache:
|
if state == AscendAttentionState.PrefillNoCache:
|
||||||
out = self.forward_prefill_310(query, key, value, attn_metadata, output)
|
output = self.forward_prefill_310(query, key, value, attn_metadata, output)
|
||||||
return out
|
# Condition for DecodeOnly: Pure decoding phase where each request generates one token
|
||||||
|
elif state == AscendAttentionState.DecodeOnly:
|
||||||
if state == AscendAttentionState.ChunkedPrefill:
|
output = self.forward_paged_attention(query, attn_metadata, output)
|
||||||
self.forward_chunked_prefill_310(query, attn_metadata, output)
|
# Condition for ChunkedPrefill:
|
||||||
return output
|
# 1. During speculative decoding scenarios (except mtp)
|
||||||
|
# 2. Processing large prefill requests in chunks
|
||||||
raise NotImplementedError(
|
# Condition for PrefillCacheHit: Indicates prefill with some cached tokens already processed
|
||||||
f"{self.__class__.__name__}.forward_impl: 310P only supports "
|
elif state in [AscendAttentionState.ChunkedPrefill, AscendAttentionState.PrefillCacheHit]:
|
||||||
f"{AscendAttentionState.DecodeOnly.name}, "
|
output = self.forward_chunked_prefill_310(query, attn_metadata, output)
|
||||||
f"{AscendAttentionState.PrefillNoCache.name}, "
|
# Condition for SpecDecoding: Specified for mtp, which is not supported yet.
|
||||||
f"{AscendAttentionState.ChunkedPrefill.name}, "
|
else:
|
||||||
f"got {state!r}."
|
raise NotImplementedError(f"AscendAttentionState: {state} is not supported for 310P currently.")
|
||||||
)
|
return output
|
||||||
|
|||||||
@@ -17,9 +17,11 @@
|
|||||||
|
|
||||||
from __future__ import annotations
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import numpy as np
|
||||||
import torch
|
import torch
|
||||||
import torch_npu
|
import torch_npu
|
||||||
from vllm.logger import logger
|
from vllm.logger import logger
|
||||||
|
from vllm.v1.core.sched.output import SchedulerOutput
|
||||||
from vllm.v1.kv_cache_interface import AttentionSpec, KVCacheConfig, MambaSpec
|
from vllm.v1.kv_cache_interface import AttentionSpec, KVCacheConfig, MambaSpec
|
||||||
|
|
||||||
from vllm_ascend.utils import ACL_FORMAT_FRACTAL_NZ
|
from vllm_ascend.utils import ACL_FORMAT_FRACTAL_NZ
|
||||||
@@ -185,3 +187,97 @@ class NPUModelRunner310(NPUModelRunner):
|
|||||||
raise ValueError("Unknown KV cache spec type.")
|
raise ValueError("Unknown KV cache spec type.")
|
||||||
|
|
||||||
return kv_caches
|
return kv_caches
|
||||||
|
|
||||||
|
# Override this function because of tensor.copy_(other) accuracy issue.
|
||||||
|
# TODO: This override will be removed after tensor.copy_(other) accuracy issue is resolved.
|
||||||
|
def _prepare_input_ids(
|
||||||
|
self,
|
||||||
|
scheduler_output: SchedulerOutput,
|
||||||
|
total_num_scheduled_tokens: int,
|
||||||
|
cu_num_tokens: np.ndarray,
|
||||||
|
) -> None:
|
||||||
|
"""Prepare the input IDs for the current batch.
|
||||||
|
|
||||||
|
Carefully handles the `prev_sampled_token_ids` which can be cached
|
||||||
|
from the previous engine iteration, in which case those tokens on the
|
||||||
|
GPU need to be copied into the corresponding slots into input_ids."""
|
||||||
|
|
||||||
|
if self.input_batch.prev_sampled_token_ids is None:
|
||||||
|
# Normal scheduling case
|
||||||
|
self.input_ids.copy_to_gpu(total_num_scheduled_tokens)
|
||||||
|
if self.enable_prompt_embeds:
|
||||||
|
self.inputs_embeds.copy_to_gpu(total_num_scheduled_tokens)
|
||||||
|
self.is_token_ids.copy_to_gpu(total_num_scheduled_tokens)
|
||||||
|
return
|
||||||
|
|
||||||
|
# Async scheduling case, where some decode requests from the previous
|
||||||
|
# iteration won't have entries in input_ids_cpu and need to be copied
|
||||||
|
# on the NPU from prev_sampled_token_ids.
|
||||||
|
prev_req_id_to_index = self.input_batch.prev_req_id_to_index
|
||||||
|
assert prev_req_id_to_index is not None
|
||||||
|
sample_flattened_indices: list[int] = []
|
||||||
|
spec_flattened_indices: list[int] = []
|
||||||
|
prev_common_req_indices: list[int] = []
|
||||||
|
prev_draft_token_indices: list[int] = []
|
||||||
|
indices_match = True
|
||||||
|
max_flattened_index = -1
|
||||||
|
total_num_spec_tokens = 0
|
||||||
|
scheduled_spec_tokens = scheduler_output.scheduled_spec_decode_tokens
|
||||||
|
|
||||||
|
for req_id, cur_index in self.input_batch.req_id_to_index.items():
|
||||||
|
if (prev_index := prev_req_id_to_index.get(req_id)) is not None:
|
||||||
|
prev_common_req_indices.append(prev_index)
|
||||||
|
draft_len = len(scheduled_spec_tokens.get(req_id, ()))
|
||||||
|
total_num_spec_tokens += draft_len
|
||||||
|
flattened_index = cu_num_tokens[cur_index].item() - 1
|
||||||
|
sample_flattened_indices.append(flattened_index - draft_len)
|
||||||
|
spec_flattened_indices.extend(range(flattened_index - draft_len + 1, flattened_index + 1))
|
||||||
|
start = prev_index * self.num_spec_tokens
|
||||||
|
prev_draft_token_indices.extend(range(start, start + draft_len))
|
||||||
|
indices_match &= prev_index == flattened_index
|
||||||
|
max_flattened_index = max(max_flattened_index, flattened_index)
|
||||||
|
num_commmon_tokens = len(sample_flattened_indices)
|
||||||
|
total_without_spec = total_num_scheduled_tokens - total_num_spec_tokens
|
||||||
|
if num_commmon_tokens < total_without_spec:
|
||||||
|
self.input_ids.copy_to_gpu(total_num_scheduled_tokens)
|
||||||
|
if self.enable_prompt_embeds:
|
||||||
|
self.inputs_embeds.copy_to_gpu(total_num_scheduled_tokens)
|
||||||
|
self.is_token_ids.copy_to_gpu(total_num_scheduled_tokens)
|
||||||
|
if num_commmon_tokens == 0:
|
||||||
|
return
|
||||||
|
if indices_match and max_flattened_index == (num_commmon_tokens - 1):
|
||||||
|
# NOTE: Override the copy_ function here
|
||||||
|
indices = torch.arange(num_commmon_tokens, device=self.input_ids.gpu.device)
|
||||||
|
source = self.input_batch.prev_sampled_token_ids[:num_commmon_tokens, 0]
|
||||||
|
self.input_ids.gpu.index_copy_(0, indices, source)
|
||||||
|
if self.enable_prompt_embeds:
|
||||||
|
self.is_token_ids.gpu[:num_commmon_tokens] = True
|
||||||
|
return
|
||||||
|
# Upload the index tensors asynchronously so the scatter can be non-blocking.
|
||||||
|
sampled_tokens_index_tensor = torch.tensor(
|
||||||
|
sample_flattened_indices, dtype=torch.int64, pin_memory=self.pin_memory
|
||||||
|
).to(self.device, non_blocking=True)
|
||||||
|
prev_common_req_indices_tensor = torch.tensor(
|
||||||
|
prev_common_req_indices, dtype=torch.int64, pin_memory=self.pin_memory
|
||||||
|
).to(self.device, non_blocking=True)
|
||||||
|
self.input_ids.gpu.scatter_(
|
||||||
|
dim=0,
|
||||||
|
index=sampled_tokens_index_tensor,
|
||||||
|
src=self.input_batch.prev_sampled_token_ids[prev_common_req_indices_tensor, 0],
|
||||||
|
)
|
||||||
|
# Scatter the draft tokens after the sampled tokens are scattered.
|
||||||
|
if self._draft_token_ids is None or not spec_flattened_indices:
|
||||||
|
return
|
||||||
|
assert isinstance(self._draft_token_ids, torch.Tensor)
|
||||||
|
draft_tokens_index_tensor = torch.tensor(
|
||||||
|
spec_flattened_indices, dtype=torch.int64, pin_memory=self.pin_memory
|
||||||
|
).to(self.device, non_blocking=True)
|
||||||
|
prev_draft_token_indices_tensor = torch.tensor(
|
||||||
|
prev_draft_token_indices, dtype=torch.int64, pin_memory=self.pin_memory
|
||||||
|
).to(self.device, non_blocking=True)
|
||||||
|
draft_token_ids = self._draft_token_ids.to(dtype=torch.int32)
|
||||||
|
self.input_ids.gpu.scatter_(
|
||||||
|
dim=0,
|
||||||
|
index=draft_tokens_index_tensor,
|
||||||
|
src=draft_token_ids.flatten()[prev_draft_token_indices_tensor],
|
||||||
|
)
|
||||||
|
|||||||
Reference in New Issue
Block a user