From 58ff465821866478ab7f51cd47081b534689ce7d Mon Sep 17 00:00:00 2001
From: Qiu
Date: Wed, 21 Jan 2026 14:21:02 +0800
Subject: [PATCH] [bugfix] fix the complex and potentially problematic
 generate_kv_idx. (#5957)

### What this PR does / why we need it?

In long-sequence scenarios, chunked prefill could run into dimension-misalignment
issues; this previously showed up during precision testing on the
code_generate_lite dataset. This PR removes the redundant `generate_kv_idx`
computation and instead derives `cp_kv_recover_idx_for_chunk` from the existing
`pcp_allgather_restore_idx` with two straightforward `argsort` calls (a small
index sketch follows the diff below).

- vLLM version: v0.13.0
- vLLM main: https://github.com/vllm-project/vllm/commit/2c24bc6996cb165fce92f780b388a5e39b3f4060

Signed-off-by: QiuChunshuo
---
 .../long_sequence/test_chunked_prefill.py |  1 -
 .../context_parallel/attention_cp.py      | 10 ++---
 vllm_ascend/attention/utils.py            |  2 -
 vllm_ascend/worker/model_runner_v1.py     |  3 --
 vllm_ascend/worker/pcp_utils.py           | 45 -------------------
 5 files changed, 4 insertions(+), 57 deletions(-)

diff --git a/tests/e2e/multicard/4-cards/long_sequence/test_chunked_prefill.py b/tests/e2e/multicard/4-cards/long_sequence/test_chunked_prefill.py
index 3c18a946..3021c9d5 100644
--- a/tests/e2e/multicard/4-cards/long_sequence/test_chunked_prefill.py
+++ b/tests/e2e/multicard/4-cards/long_sequence/test_chunked_prefill.py
@@ -87,7 +87,6 @@ def test_models_chunked_prefill_mixed_length_prompts_including_1_token(
         "VLLM_ALLOW_LONG_MAX_MODEL_LEN": "1"
     })
 @pytest.mark.parametrize("model", MODELS)
-@pytest.mark.skip(reason="skip for bad adaptability with main2main")
 def test_models_chunked_prefill_with_empty_kvcache(model: str):
     TEST_ROPE_PARAMETERS = {
         "rope_theta": 1000000,
diff --git a/vllm_ascend/attention/context_parallel/attention_cp.py b/vllm_ascend/attention/context_parallel/attention_cp.py
index 9cbfebcd..d7cf9963 100644
--- a/vllm_ascend/attention/context_parallel/attention_cp.py
+++ b/vllm_ascend/attention/context_parallel/attention_cp.py
@@ -162,14 +162,12 @@ class AscendAttentionCPMetadataBuilder(AscendAttentionMetadataBuilder):
         local_total_toks = local_chunked_kv_lens_rank.sum()
         chunked_req_mask = self._get_chunked_req_mask(local_context_lens_allranks)
         local_chunk_starts = torch.zeros(
-            (len(local_context_lens_allranks)), dtype=torch.int32, device=self.device
+            (len(local_context_lens_allranks),), dtype=torch.int32, device=self.device
         )
-        cp_kv_recover_idx_for_chunk = common_long_seq_metadata.cp_kv_recover_idx_for_chunk
-        kv_inverse_idx_for_chunk = (
-            torch.argsort(cp_kv_recover_idx_for_chunk.to(torch.float32))
-            if cp_kv_recover_idx_for_chunk is not None
-            else None
+        kv_inverse_idx_for_chunk = torch.argsort(
+            common_long_seq_metadata.pcp_allgather_restore_idx[pcp_size * num_decode_tokens :].to(torch.float32)
         )
+        cp_kv_recover_idx_for_chunk = torch.argsort(kv_inverse_idx_for_chunk)
         batch_chunk_seq_mask = local_context_lens_allranks[:, self.pcp_rank, self.dcp_rank] == 0
         batch_chunk_seq_mask = torch.repeat_interleave(
diff --git a/vllm_ascend/attention/utils.py b/vllm_ascend/attention/utils.py
index 3c6ba213..79d2300c 100644
--- a/vllm_ascend/attention/utils.py
+++ b/vllm_ascend/attention/utils.py
@@ -70,8 +70,6 @@ class AscendPrefillContextParallelMetadata:
 
     pcp_allgather_restore_idx: torch.Tensor = None
 
-    cp_kv_recover_idx_for_chunk: torch.Tensor = None
-
     num_actual_tokens_pcp_padded: int = 0
 
     num_computed_tokens_of_pcp_dcp: list[list[list[int]]] | None = None
diff --git a/vllm_ascend/worker/model_runner_v1.py b/vllm_ascend/worker/model_runner_v1.py
index 4718551c..f0076c76 100644
--- a/vllm_ascend/worker/model_runner_v1.py
+++ b/vllm_ascend/worker/model_runner_v1.py
@@ -555,9 +555,6 @@ class NPUModelRunner(GPUModelRunner):
                                   self.num_spec_tokens)
 
         if self.pcp_size > 1:
-            if not self.vllm_config.model_config.use_mla:
-                self.pcp_manager.generate_kv_idx(scheduler_output,
-                                                 self.input_batch)
             num_scheduled_tokens[:
                                  num_reqs], position_pcp = self.pcp_manager.update_tokens_for_pcp(
                                      num_scheduled_tokens[:num_reqs],
diff --git a/vllm_ascend/worker/pcp_utils.py b/vllm_ascend/worker/pcp_utils.py
index 40294807..873c4f1c 100644
--- a/vllm_ascend/worker/pcp_utils.py
+++ b/vllm_ascend/worker/pcp_utils.py
@@ -86,9 +86,6 @@ class PCPManager:
         )
         self.num_actual_tokens_pcp_padded = 0
         self.pcp_unpad_mask_cpu = self.pcp_unpad_mask_cpu_tensor.numpy()
-        self.cp_kv_recover_idx_for_chunk: List[List[int]] = [
-            [] for _ in range(self.pcp_world_size)
-        ]
         self.full_indices = list(
             range(self.max_num_tokens * self.pcp_world_size *
                   self.dcp_world_size + self.pcp_world_size *
@@ -563,47 +560,6 @@ class PCPManager:
                                               [-1, pcp_world_size, dcp_world_size])
         return dcp_local_seq_lens
 
-    def generate_kv_idx(self, scheduler_output, input_batch):
-        if not self.pcp_world_size > 1:
-            return
-        self.cp_kv_recover_idx_for_chunk = [[]
-                                            for _ in range(self.pcp_world_size)
-                                            ]
-
-        for i, req_id in enumerate(input_batch.req_ids):
-            num_scheduled_token = scheduler_output.num_scheduled_tokens[req_id]
-            is_prefill = num_scheduled_token > self.decode_threshold
-            if is_prefill:
-                num_cp_padded_scheduled_tokens = cdiv(
-                    num_scheduled_token,
-                    2 * self.pcp_world_size) * (2 * self.pcp_world_size)
-                chunk_size = num_cp_padded_scheduled_tokens // (
-                    2 * self.pcp_world_size)
-                num_added_recover_tokens = len(
-                    self.cp_kv_recover_idx_for_chunk[0]) * self.pcp_world_size
-                for rank in range(self.pcp_world_size):
-                    self.cp_kv_recover_idx_for_chunk[rank].extend(
-                        self.full_indices[rank * chunk_size +
-                                          num_added_recover_tokens:(rank + 1) *
-                                          chunk_size +
-                                          num_added_recover_tokens])
-                    self.cp_kv_recover_idx_for_chunk[rank].extend(
-                        self.full_indices[num_cp_padded_scheduled_tokens -
-                                          (rank + 1) * chunk_size +
-                                          num_added_recover_tokens:
-                                          num_cp_padded_scheduled_tokens -
-                                          rank * chunk_size +
-                                          num_added_recover_tokens])
-
-        cp_kv_recover_idx_for_chunk = torch.from_numpy(
-            np.concatenate(
-                self.cp_kv_recover_idx_for_chunk)).to(device=self.device)
-        cp_kv_recover_idx_for_chunk.copy_(torch.tensor(
-            np.array(self.cp_kv_recover_idx_for_chunk).flatten().tolist()),
-                                          non_blocking=True)
-        self.cp_kv_recover_idx_for_chunk = cp_kv_recover_idx_for_chunk.to(
-            torch.float32).argsort().to(torch.int32)
-
     def generate_pcp_metadata(self, total_num_scheduled_tokens, query_lens,
                               input_batch, num_scheduled_tokens):
         from vllm_ascend.attention.utils import \
@@ -774,7 +730,6 @@ class PCPManager:
             }
         long_seq_metadata.pcp_allgather_restore_idx = self.pcp_allgather_restore_idx.gpu[:
                                                                                          num_actual_tokens_pcp_padded]
-        long_seq_metadata.cp_kv_recover_idx_for_chunk = self.cp_kv_recover_idx_for_chunk
         long_seq_metadata.q_head_idx_tensor = self.q_head_idx_tensor
         long_seq_metadata.q_tail_idx_tensor = self.q_tail_idx_tensor
         long_seq_metadata.q_full_idx = self.q_full_idx
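
A minimal sketch of the permutation identity the new `attention_cp.py` code relies on:
`torch.argsort` of a permutation tensor yields its inverse permutation, and argsorting
that inverse recovers the original ordering. The tensor values below are made up for
illustration; the only assumption is that the prefill slice of `pcp_allgather_restore_idx`
is a plain permutation of its positions (the `.to(torch.float32)` cast used in the patch
is omitted here).

```python
import torch

# Illustrative stand-in for the prefill slice of pcp_allgather_restore_idx
# (values are made up; any permutation of 0..n-1 behaves the same way).
restore_idx = torch.tensor([2, 0, 3, 1])

# argsort of a permutation gives its inverse permutation.
kv_inverse_idx = torch.argsort(restore_idx)      # tensor([1, 3, 0, 2])

# argsort of the inverse recovers the original permutation, which is why
# the recover index no longer has to be rebuilt per request.
kv_recover_idx = torch.argsort(kv_inverse_idx)   # tensor([2, 0, 3, 1])

assert torch.equal(kv_recover_idx, restore_idx)

# Gathering with the inverse and then the recover index is the identity.
x = torch.arange(4)
assert torch.equal(x[kv_inverse_idx][kv_recover_idx], x)
```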