[bugfix] fix the complex and potentially problematic generate_kv_idx. (#5957)
### What this PR does / why we need it?
In long-sequence scenarios, the chunked-prefill component may encounter
dimension misalignment issues, which previously occurred during
precision testing on the code_generate_lite dataset. This PR removes
redundant computations and instead derives the value using existing
results and straightforward calculations.
- vLLM version: v0.13.0
- vLLM main:
2c24bc6996
Signed-off-by: QiuChunshuo <qiuchunshuo@huawei.com>
This commit is contained in:
@@ -87,7 +87,6 @@ def test_models_chunked_prefill_mixed_length_prompts_including_1_token(
|
||||
"VLLM_ALLOW_LONG_MAX_MODEL_LEN": "1"
|
||||
})
|
||||
@pytest.mark.parametrize("model", MODELS)
|
||||
@pytest.mark.skip(reason="skip for bad adaptability with main2main")
|
||||
def test_models_chunked_prefill_with_empty_kvcache(model: str):
|
||||
TEST_ROPE_PARAMETERS = {
|
||||
"rope_theta": 1000000,
|
||||
|
||||
@@ -162,14 +162,12 @@ class AscendAttentionCPMetadataBuilder(AscendAttentionMetadataBuilder):
|
||||
local_total_toks = local_chunked_kv_lens_rank.sum()
|
||||
chunked_req_mask = self._get_chunked_req_mask(local_context_lens_allranks)
|
||||
local_chunk_starts = torch.zeros(
|
||||
(len(local_context_lens_allranks)), dtype=torch.int32, device=self.device
|
||||
(len(local_context_lens_allranks),), dtype=torch.int32, device=self.device
|
||||
)
|
||||
cp_kv_recover_idx_for_chunk = common_long_seq_metadata.cp_kv_recover_idx_for_chunk
|
||||
kv_inverse_idx_for_chunk = (
|
||||
torch.argsort(cp_kv_recover_idx_for_chunk.to(torch.float32))
|
||||
if cp_kv_recover_idx_for_chunk is not None
|
||||
else None
|
||||
kv_inverse_idx_for_chunk = torch.argsort(
|
||||
common_long_seq_metadata.pcp_allgather_restore_idx[pcp_size * num_decode_tokens :].to(torch.float32)
|
||||
)
|
||||
cp_kv_recover_idx_for_chunk = torch.argsort(kv_inverse_idx_for_chunk)
|
||||
|
||||
batch_chunk_seq_mask = local_context_lens_allranks[:, self.pcp_rank, self.dcp_rank] == 0
|
||||
batch_chunk_seq_mask = torch.repeat_interleave(
|
||||
|
||||
@@ -70,8 +70,6 @@ class AscendPrefillContextParallelMetadata:
|
||||
|
||||
pcp_allgather_restore_idx: torch.Tensor = None
|
||||
|
||||
cp_kv_recover_idx_for_chunk: torch.Tensor = None
|
||||
|
||||
num_actual_tokens_pcp_padded: int = 0
|
||||
|
||||
num_computed_tokens_of_pcp_dcp: list[list[list[int]]] | None = None
|
||||
|
||||
@@ -555,9 +555,6 @@ class NPUModelRunner(GPUModelRunner):
|
||||
self.num_spec_tokens)
|
||||
|
||||
if self.pcp_size > 1:
|
||||
if not self.vllm_config.model_config.use_mla:
|
||||
self.pcp_manager.generate_kv_idx(scheduler_output,
|
||||
self.input_batch)
|
||||
num_scheduled_tokens[:
|
||||
num_reqs], position_pcp = self.pcp_manager.update_tokens_for_pcp(
|
||||
num_scheduled_tokens[:num_reqs],
|
||||
|
||||
@@ -86,9 +86,6 @@ class PCPManager:
|
||||
)
|
||||
self.num_actual_tokens_pcp_padded = 0
|
||||
self.pcp_unpad_mask_cpu = self.pcp_unpad_mask_cpu_tensor.numpy()
|
||||
self.cp_kv_recover_idx_for_chunk: List[List[int]] = [
|
||||
[] for _ in range(self.pcp_world_size)
|
||||
]
|
||||
self.full_indices = list(
|
||||
range(self.max_num_tokens * self.pcp_world_size *
|
||||
self.dcp_world_size + self.pcp_world_size *
|
||||
@@ -563,47 +560,6 @@ class PCPManager:
|
||||
[-1, pcp_world_size, dcp_world_size])
|
||||
return dcp_local_seq_lens
|
||||
|
||||
def generate_kv_idx(self, scheduler_output, input_batch):
|
||||
if not self.pcp_world_size > 1:
|
||||
return
|
||||
self.cp_kv_recover_idx_for_chunk = [[]
|
||||
for _ in range(self.pcp_world_size)
|
||||
]
|
||||
|
||||
for i, req_id in enumerate(input_batch.req_ids):
|
||||
num_scheduled_token = scheduler_output.num_scheduled_tokens[req_id]
|
||||
is_prefill = num_scheduled_token > self.decode_threshold
|
||||
if is_prefill:
|
||||
num_cp_padded_scheduled_tokens = cdiv(
|
||||
num_scheduled_token,
|
||||
2 * self.pcp_world_size) * (2 * self.pcp_world_size)
|
||||
chunk_size = num_cp_padded_scheduled_tokens // (
|
||||
2 * self.pcp_world_size)
|
||||
num_added_recover_tokens = len(
|
||||
self.cp_kv_recover_idx_for_chunk[0]) * self.pcp_world_size
|
||||
for rank in range(self.pcp_world_size):
|
||||
self.cp_kv_recover_idx_for_chunk[rank].extend(
|
||||
self.full_indices[rank * chunk_size +
|
||||
num_added_recover_tokens:(rank + 1) *
|
||||
chunk_size +
|
||||
num_added_recover_tokens])
|
||||
self.cp_kv_recover_idx_for_chunk[rank].extend(
|
||||
self.full_indices[num_cp_padded_scheduled_tokens -
|
||||
(rank + 1) * chunk_size +
|
||||
num_added_recover_tokens:
|
||||
num_cp_padded_scheduled_tokens -
|
||||
rank * chunk_size +
|
||||
num_added_recover_tokens])
|
||||
|
||||
cp_kv_recover_idx_for_chunk = torch.from_numpy(
|
||||
np.concatenate(
|
||||
self.cp_kv_recover_idx_for_chunk)).to(device=self.device)
|
||||
cp_kv_recover_idx_for_chunk.copy_(torch.tensor(
|
||||
np.array(self.cp_kv_recover_idx_for_chunk).flatten().tolist()),
|
||||
non_blocking=True)
|
||||
self.cp_kv_recover_idx_for_chunk = cp_kv_recover_idx_for_chunk.to(
|
||||
torch.float32).argsort().to(torch.int32)
|
||||
|
||||
def generate_pcp_metadata(self, total_num_scheduled_tokens, query_lens,
|
||||
input_batch, num_scheduled_tokens):
|
||||
from vllm_ascend.attention.utils import \
|
||||
@@ -774,7 +730,6 @@ class PCPManager:
|
||||
}
|
||||
long_seq_metadata.pcp_allgather_restore_idx = self.pcp_allgather_restore_idx.gpu[:
|
||||
num_actual_tokens_pcp_padded]
|
||||
long_seq_metadata.cp_kv_recover_idx_for_chunk = self.cp_kv_recover_idx_for_chunk
|
||||
long_seq_metadata.q_head_idx_tensor = self.q_head_idx_tensor
|
||||
long_seq_metadata.q_tail_idx_tensor = self.q_tail_idx_tensor
|
||||
long_seq_metadata.q_full_idx = self.q_full_idx
|
||||
|
||||
Reference in New Issue
Block a user