[bugfix] fix the complex and potentially problematic generate_kv_idx. (#5957)

### What this PR does / why we need it?
In long-sequence scenarios, the chunked-prefill component may encounter
dimension misalignment issues, which previously occurred during
precision testing on the code_generate_lite dataset. This PR removes
redundant computations and instead derives the value using existing
results and straightforward calculations.
- vLLM version: v0.13.0
- vLLM main:
2c24bc6996

Signed-off-by: QiuChunshuo <qiuchunshuo@huawei.com>
This commit is contained in:
Qiu
2026-01-21 14:21:02 +08:00
committed by GitHub
parent 12a668b1d9
commit 58ff465821
5 changed files with 4 additions and 57 deletions

View File

@@ -87,7 +87,6 @@ def test_models_chunked_prefill_mixed_length_prompts_including_1_token(
     "VLLM_ALLOW_LONG_MAX_MODEL_LEN": "1"
 })
 @pytest.mark.parametrize("model", MODELS)
-@pytest.mark.skip(reason="skip for bad adaptability with main2main")
 def test_models_chunked_prefill_with_empty_kvcache(model: str):
     TEST_ROPE_PARAMETERS = {
         "rope_theta": 1000000,

View File

@@ -162,14 +162,12 @@ class AscendAttentionCPMetadataBuilder(AscendAttentionMetadataBuilder):
     local_total_toks = local_chunked_kv_lens_rank.sum()
     chunked_req_mask = self._get_chunked_req_mask(local_context_lens_allranks)
     local_chunk_starts = torch.zeros(
-        (len(local_context_lens_allranks)), dtype=torch.int32, device=self.device
+        (len(local_context_lens_allranks),), dtype=torch.int32, device=self.device
     )
-    cp_kv_recover_idx_for_chunk = common_long_seq_metadata.cp_kv_recover_idx_for_chunk
-    kv_inverse_idx_for_chunk = (
-        torch.argsort(cp_kv_recover_idx_for_chunk.to(torch.float32))
-        if cp_kv_recover_idx_for_chunk is not None
-        else None
-    )
+    kv_inverse_idx_for_chunk = torch.argsort(
+        common_long_seq_metadata.pcp_allgather_restore_idx[pcp_size * num_decode_tokens:].to(torch.float32)
+    )
+    cp_kv_recover_idx_for_chunk = torch.argsort(kv_inverse_idx_for_chunk)
     batch_chunk_seq_mask = local_context_lens_allranks[:, self.pcp_rank, self.dcp_rank] == 0
     batch_chunk_seq_mask = torch.repeat_interleave(

View File

@@ -70,8 +70,6 @@ class AscendPrefillContextParallelMetadata:
     pcp_allgather_restore_idx: torch.Tensor = None
-    cp_kv_recover_idx_for_chunk: torch.Tensor = None
     num_actual_tokens_pcp_padded: int = 0
     num_computed_tokens_of_pcp_dcp: list[list[list[int]]] | None = None

View File

@@ -555,9 +555,6 @@ class NPUModelRunner(GPUModelRunner):
                                  self.num_spec_tokens)
         if self.pcp_size > 1:
-            if not self.vllm_config.model_config.use_mla:
-                self.pcp_manager.generate_kv_idx(scheduler_output,
-                                                 self.input_batch)
             num_scheduled_tokens[:
                                  num_reqs], position_pcp = self.pcp_manager.update_tokens_for_pcp(
                                      num_scheduled_tokens[:num_reqs],

View File

@@ -86,9 +86,6 @@ class PCPManager:
         )
         self.num_actual_tokens_pcp_padded = 0
         self.pcp_unpad_mask_cpu = self.pcp_unpad_mask_cpu_tensor.numpy()
-        self.cp_kv_recover_idx_for_chunk: List[List[int]] = [
-            [] for _ in range(self.pcp_world_size)
-        ]
         self.full_indices = list(
             range(self.max_num_tokens * self.pcp_world_size *
                   self.dcp_world_size + self.pcp_world_size *
@@ -563,47 +560,6 @@ class PCPManager:
                                             [-1, pcp_world_size, dcp_world_size])
         return dcp_local_seq_lens

-    def generate_kv_idx(self, scheduler_output, input_batch):
-        if not self.pcp_world_size > 1:
-            return
-        self.cp_kv_recover_idx_for_chunk = [[]
-                                            for _ in range(self.pcp_world_size)
-                                            ]
-        for i, req_id in enumerate(input_batch.req_ids):
-            num_scheduled_token = scheduler_output.num_scheduled_tokens[req_id]
-            is_prefill = num_scheduled_token > self.decode_threshold
-            if is_prefill:
-                num_cp_padded_scheduled_tokens = cdiv(
-                    num_scheduled_token,
-                    2 * self.pcp_world_size) * (2 * self.pcp_world_size)
-                chunk_size = num_cp_padded_scheduled_tokens // (
-                    2 * self.pcp_world_size)
-                num_added_recover_tokens = len(
-                    self.cp_kv_recover_idx_for_chunk[0]) * self.pcp_world_size
-                for rank in range(self.pcp_world_size):
-                    self.cp_kv_recover_idx_for_chunk[rank].extend(
-                        self.full_indices[rank * chunk_size +
-                                          num_added_recover_tokens:(rank + 1) *
-                                          chunk_size +
-                                          num_added_recover_tokens])
-                    self.cp_kv_recover_idx_for_chunk[rank].extend(
-                        self.full_indices[num_cp_padded_scheduled_tokens -
-                                          (rank + 1) * chunk_size +
-                                          num_added_recover_tokens:
-                                          num_cp_padded_scheduled_tokens -
-                                          rank * chunk_size +
-                                          num_added_recover_tokens])
-        cp_kv_recover_idx_for_chunk = torch.from_numpy(
-            np.concatenate(
-                self.cp_kv_recover_idx_for_chunk)).to(device=self.device)
-        cp_kv_recover_idx_for_chunk.copy_(torch.tensor(
-            np.array(self.cp_kv_recover_idx_for_chunk).flatten().tolist()),
-            non_blocking=True)
-        self.cp_kv_recover_idx_for_chunk = cp_kv_recover_idx_for_chunk.to(
-            torch.float32).argsort().to(torch.int32)
-
     def generate_pcp_metadata(self, total_num_scheduled_tokens, query_lens,
                               input_batch, num_scheduled_tokens):
         from vllm_ascend.attention.utils import \
@@ -774,7 +730,6 @@ class PCPManager:
         }
         long_seq_metadata.pcp_allgather_restore_idx = self.pcp_allgather_restore_idx.gpu[:
             num_actual_tokens_pcp_padded]
-        long_seq_metadata.cp_kv_recover_idx_for_chunk = self.cp_kv_recover_idx_for_chunk
         long_seq_metadata.q_head_idx_tensor = self.q_head_idx_tensor
         long_seq_metadata.q_tail_idx_tensor = self.q_tail_idx_tensor
         long_seq_metadata.q_full_idx = self.q_full_idx