From a8e951e6f557514bcda0d7db6e7710fb90a2474d Mon Sep 17 00:00:00 2001 From: pu-zhe Date: Tue, 24 Feb 2026 16:48:05 +0800 Subject: [PATCH] [Feat] 310p supports PrefillCacheHit State (#6756) ### What this PR does / why we need it? This PR extends the Ascend 310P attention backend to support the `PrefillCacheHit` state. Previously, only `PrefillNoCache`, `DecodeOnly`, and `ChunkedPrefill` were supported. This PR handles this state by routing it to the existing `forward_chunked_prefill_310` implementation, which is suitable for this scenario. The changes also include refactoring the main `forward_impl` dispatch method for better clarity and updating unit tests to cover the new state and ensure correctness. ### Does this PR introduce _any_ user-facing change? No ### How was this patch tested? Accuracy test when chunked prefill is disabled. - vLLM version: v0.15.0 - vLLM main: https://github.com/vllm-project/vllm/commit/9562912cead1f11e8540fb91306c5cbda66f0007 --------- Signed-off-by: pu-zhe --- .../_310p/attention/test_attention_v1_310.py | 62 ++++++++++-- vllm_ascend/_310p/attention/attention_v1.py | 35 ++++--- vllm_ascend/_310p/model_runner_310p.py | 96 +++++++++++++++++++ 3 files changed, 169 insertions(+), 24 deletions(-) diff --git a/tests/ut/_310p/attention/test_attention_v1_310.py b/tests/ut/_310p/attention/test_attention_v1_310.py index 0794ec42..3370baee 100644 --- a/tests/ut/_310p/attention/test_attention_v1_310.py +++ b/tests/ut/_310p/attention/test_attention_v1_310.py @@ -78,7 +78,7 @@ class TestAscendAttentionBackendImpl310(TestBase): def test_forward_prefill_310( self, mock_get_forward_context, mock_npu_npu_flash_attention, mock_npu_reshape_and_cache ): - """Test forward pass in PrefillCacheHit state""" + """Test forward pass in PrefillNoCache state""" query = torch.randn(10, 8, 64) key = torch.randn(10, 8, 64) value = torch.randn(10, 8, 64) @@ -98,7 +98,7 @@ class TestAscendAttentionBackendImpl310(TestBase): mock_get_forward_context.return_value = MagicMock(capturing=False) mock_npu_npu_flash_attention.return_value = torch.ones(10, 8, 64) - output = self.impl.forward_prefill_310(query, key, value, metadata, output) + output = self.impl.forward_impl(query, key, value, None, metadata, output) mock_npu_npu_flash_attention.assert_called_once() @@ -107,10 +107,15 @@ class TestAscendAttentionBackendImpl310(TestBase): @patch("torch_npu._npu_paged_attention_splitfuse") @patch("vllm_ascend.attention.attention_v1.get_forward_context") def test_forward_chunked_prefill_310( - self, mock_get_forward_context, mock_npu_paged_attention_splitfuse, mock_npu_reshape_and_cache, mock_format_cast + self, + mock_get_forward_context, + mock_npu_paged_attention_splitfuse, + mock_npu_reshape_and_cache, + mock_format_cast, ): - """Test forward pass in PrefillCacheHit state""" + """Test forward pass in ChunkedPrefill state""" query = torch.randn(5, 8, 64) + key, value = None, None output = torch.empty_like(query) metadata = self.attn_metadata metadata.attn_state = AscendAttentionState.ChunkedPrefill @@ -128,7 +133,42 @@ class TestAscendAttentionBackendImpl310(TestBase): mock_get_forward_context.return_value = MagicMock(capturing=False) mock_npu_paged_attention_splitfuse.return_value = torch.ones(5, 8, 64) - output = self.impl.forward_chunked_prefill_310(query, metadata, output) + output = self.impl.forward_impl(query, key, value, None, metadata, output) + + mock_npu_paged_attention_splitfuse.assert_called_once() + + @patch("torch_npu.npu_format_cast", return_value=torch.randn((1, 128, 16, 16), 
dtype=torch.float16)) + @patch("torch_npu._npu_reshape_and_cache") + @patch("torch_npu._npu_paged_attention_splitfuse") + @patch("vllm_ascend.attention.attention_v1.get_forward_context") + def test_forward_prefill_cache_hit_310( + self, + mock_get_forward_context, + mock_npu_paged_attention_splitfuse, + mock_npu_reshape_and_cache, + mock_format_cast, + ): + """Test forward pass in PrefillCacheHit state""" + query = torch.randn(5, 8, 64) + key, value = None, None + output = torch.empty_like(query) + metadata = self.attn_metadata + metadata.attn_state = AscendAttentionState.PrefillCacheHit + metadata.attn_mask = torch.randn(1, 128, 16, 16) + metadata.query_lens = torch.tensor([5]) + metadata.seq_lens = torch.tensor([1, 4]) + metadata.query_start_loc = torch.tensor([0, 1, 5]) + metadata.actual_seq_lengths_q = [5] + metadata.block_tables = torch.zeros(1, 5, dtype=torch.long) + metadata.num_actual_tokens = 10 + metadata.num_decode_tokens = 0 + metadata.num_decodes = 0 + metadata.num_prefills = 10 + metadata.slot_mapping = torch.zeros(10, dtype=torch.long) + + mock_get_forward_context.return_value = MagicMock(capturing=False) + mock_npu_paged_attention_splitfuse.return_value = torch.ones(5, 8, 64) + output = self.impl.forward_impl(query, key, value, None, metadata, output) mock_npu_paged_attention_splitfuse.assert_called_once() @@ -141,6 +181,7 @@ class TestAscendAttentionBackendImpl310(TestBase): ): """Test forward pass in DecodeOnly state""" query = torch.randn(4, 8 * 64) + key, value = None, None output = torch.empty_like(query) metadata = self.attn_metadata @@ -155,6 +196,15 @@ class TestAscendAttentionBackendImpl310(TestBase): mock_get_forward_context.return_value = MagicMock(capturing=False) - output = self.impl.forward_paged_attention(query, metadata, output) + output = self.impl.forward_impl(query, key, value, None, metadata, output) mock_paged_attention.assert_called_once() + + def test_forward_mtp_310(self): + query = torch.randn(4, 8 * 64) + key, value = None, None + output = torch.empty_like(query) + metadata = self.attn_metadata + metadata.attn_state = AscendAttentionState.SpecDecoding + with self.assertRaises(NotImplementedError): + output = self.impl.forward_impl(query, key, value, None, metadata, output) diff --git a/vllm_ascend/_310p/attention/attention_v1.py b/vllm_ascend/_310p/attention/attention_v1.py index ce3b8f02..c080d842 100644 --- a/vllm_ascend/_310p/attention/attention_v1.py +++ b/vllm_ascend/_310p/attention/attention_v1.py @@ -198,6 +198,8 @@ class AscendAttentionBackendImpl310(AscendAttentionBackendImpl): out=output, ) + return output + def forward_impl(self, query, key, value, kv_cache, attn_metadata, output): """ Main dispatch method for attention operations. @@ -218,22 +220,19 @@ class AscendAttentionBackendImpl310(AscendAttentionBackendImpl): NotImplementedError: If the attention state is not supported on 310P. 
""" state = attn_metadata.attn_state - - if state == AscendAttentionState.DecodeOnly: - return self.forward_paged_attention(query, attn_metadata, output) - + # Condition for PrefillNoCache: No previous tokens have been processed yet if state == AscendAttentionState.PrefillNoCache: - out = self.forward_prefill_310(query, key, value, attn_metadata, output) - return out - - if state == AscendAttentionState.ChunkedPrefill: - self.forward_chunked_prefill_310(query, attn_metadata, output) - return output - - raise NotImplementedError( - f"{self.__class__.__name__}.forward_impl: 310P only supports " - f"{AscendAttentionState.DecodeOnly.name}, " - f"{AscendAttentionState.PrefillNoCache.name}, " - f"{AscendAttentionState.ChunkedPrefill.name}, " - f"got {state!r}." - ) + output = self.forward_prefill_310(query, key, value, attn_metadata, output) + # Condition for DecodeOnly: Pure decoding phase where each request generates one token + elif state == AscendAttentionState.DecodeOnly: + output = self.forward_paged_attention(query, attn_metadata, output) + # Condition for ChunkedPrefill: + # 1. During speculative decoding scenarios (except mtp) + # 2. Processing large prefill requests in chunks + # Condition for PrefillCacheHit: Indicates prefill with some cached tokens already processed + elif state in [AscendAttentionState.ChunkedPrefill, AscendAttentionState.PrefillCacheHit]: + output = self.forward_chunked_prefill_310(query, attn_metadata, output) + # Condition for SpecDecoding: Specified for mtp, which is not supported yet. + else: + raise NotImplementedError(f"AscendAttentionState: {state} is not supported for 310P currently.") + return output diff --git a/vllm_ascend/_310p/model_runner_310p.py b/vllm_ascend/_310p/model_runner_310p.py index e3df5c9a..9cdbdae5 100644 --- a/vllm_ascend/_310p/model_runner_310p.py +++ b/vllm_ascend/_310p/model_runner_310p.py @@ -17,9 +17,11 @@ from __future__ import annotations +import numpy as np import torch import torch_npu from vllm.logger import logger +from vllm.v1.core.sched.output import SchedulerOutput from vllm.v1.kv_cache_interface import AttentionSpec, KVCacheConfig, MambaSpec from vllm_ascend.utils import ACL_FORMAT_FRACTAL_NZ @@ -185,3 +187,97 @@ class NPUModelRunner310(NPUModelRunner): raise ValueError("Unknown KV cache spec type.") return kv_caches + + # Override this function because of tensor.copy_(other) accuracy issue. + # TODO: This override will be removed after tensor.copy_(other) accuracy issue is resolved. + def _prepare_input_ids( + self, + scheduler_output: SchedulerOutput, + total_num_scheduled_tokens: int, + cu_num_tokens: np.ndarray, + ) -> None: + """Prepare the input IDs for the current batch. + + Carefully handles the `prev_sampled_token_ids` which can be cached + from the previous engine iteration, in which case those tokens on the + GPU need to be copied into the corresponding slots into input_ids.""" + + if self.input_batch.prev_sampled_token_ids is None: + # Normal scheduling case + self.input_ids.copy_to_gpu(total_num_scheduled_tokens) + if self.enable_prompt_embeds: + self.inputs_embeds.copy_to_gpu(total_num_scheduled_tokens) + self.is_token_ids.copy_to_gpu(total_num_scheduled_tokens) + return + + # Async scheduling case, where some decode requests from the previous + # iteration won't have entries in input_ids_cpu and need to be copied + # on the NPU from prev_sampled_token_ids. 
+ prev_req_id_to_index = self.input_batch.prev_req_id_to_index + assert prev_req_id_to_index is not None + sample_flattened_indices: list[int] = [] + spec_flattened_indices: list[int] = [] + prev_common_req_indices: list[int] = [] + prev_draft_token_indices: list[int] = [] + indices_match = True + max_flattened_index = -1 + total_num_spec_tokens = 0 + scheduled_spec_tokens = scheduler_output.scheduled_spec_decode_tokens + + for req_id, cur_index in self.input_batch.req_id_to_index.items(): + if (prev_index := prev_req_id_to_index.get(req_id)) is not None: + prev_common_req_indices.append(prev_index) + draft_len = len(scheduled_spec_tokens.get(req_id, ())) + total_num_spec_tokens += draft_len + flattened_index = cu_num_tokens[cur_index].item() - 1 + sample_flattened_indices.append(flattened_index - draft_len) + spec_flattened_indices.extend(range(flattened_index - draft_len + 1, flattened_index + 1)) + start = prev_index * self.num_spec_tokens + prev_draft_token_indices.extend(range(start, start + draft_len)) + indices_match &= prev_index == flattened_index + max_flattened_index = max(max_flattened_index, flattened_index) + num_commmon_tokens = len(sample_flattened_indices) + total_without_spec = total_num_scheduled_tokens - total_num_spec_tokens + if num_commmon_tokens < total_without_spec: + self.input_ids.copy_to_gpu(total_num_scheduled_tokens) + if self.enable_prompt_embeds: + self.inputs_embeds.copy_to_gpu(total_num_scheduled_tokens) + self.is_token_ids.copy_to_gpu(total_num_scheduled_tokens) + if num_commmon_tokens == 0: + return + if indices_match and max_flattened_index == (num_commmon_tokens - 1): + # NOTE: Override the copy_ function here + indices = torch.arange(num_commmon_tokens, device=self.input_ids.gpu.device) + source = self.input_batch.prev_sampled_token_ids[:num_commmon_tokens, 0] + self.input_ids.gpu.index_copy_(0, indices, source) + if self.enable_prompt_embeds: + self.is_token_ids.gpu[:num_commmon_tokens] = True + return + # Upload the index tensors asynchronously so the scatter can be non-blocking. + sampled_tokens_index_tensor = torch.tensor( + sample_flattened_indices, dtype=torch.int64, pin_memory=self.pin_memory + ).to(self.device, non_blocking=True) + prev_common_req_indices_tensor = torch.tensor( + prev_common_req_indices, dtype=torch.int64, pin_memory=self.pin_memory + ).to(self.device, non_blocking=True) + self.input_ids.gpu.scatter_( + dim=0, + index=sampled_tokens_index_tensor, + src=self.input_batch.prev_sampled_token_ids[prev_common_req_indices_tensor, 0], + ) + # Scatter the draft tokens after the sampled tokens are scattered. + if self._draft_token_ids is None or not spec_flattened_indices: + return + assert isinstance(self._draft_token_ids, torch.Tensor) + draft_tokens_index_tensor = torch.tensor( + spec_flattened_indices, dtype=torch.int64, pin_memory=self.pin_memory + ).to(self.device, non_blocking=True) + prev_draft_token_indices_tensor = torch.tensor( + prev_draft_token_indices, dtype=torch.int64, pin_memory=self.pin_memory + ).to(self.device, non_blocking=True) + draft_token_ids = self._draft_token_ids.to(dtype=torch.int32) + self.input_ids.gpu.scatter_( + dim=0, + index=draft_tokens_index_tensor, + src=draft_token_ids.flatten()[prev_draft_token_indices_tensor], + )
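
---

Note on the `_prepare_input_ids` override above: a minimal, self-contained sketch of the copy pattern it uses. The tensor names below (`dst`, `src`, `slots`) are illustrative stand-ins, not the runner's real buffers; the point is that, because of the `tensor.copy_(other)` accuracy issue mentioned in the override's comment, the 310P path builds explicit index tensors and updates slots via `index_copy_` / `scatter_` instead of slice copies:

```python
import torch

# Hypothetical stand-ins for self.input_ids.gpu and prev_sampled_token_ids[:, 0].
dst = torch.zeros(8, dtype=torch.int32)
src = torch.tensor([11, 22, 33], dtype=torch.int32)

# Fast path in the override: the target slots are the first n contiguous positions,
# so an explicit arange index plus index_copy_ stands in for dst[:n].copy_(src).
n = src.numel()
dst.index_copy_(0, torch.arange(n), src)

# General path: sampled tokens land at arbitrary flattened positions, so the
# override scatters them by index rather than copying a slice.
slots = torch.tensor([5, 6, 7])
dst.scatter_(dim=0, index=slots, src=src)
```

In the real override the index tensors are built on the host, pinned, and moved to the device with `non_blocking=True` so the scatter stays asynchronous; the sketch omits that for brevity.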