[bugfix] fix the complex and potentially problematic generate_kv_idx. (#5957)
### What this PR does / why we need it?
In long-sequence scenarios, the chunked-prefill path could hit dimension-misalignment
issues, previously observed during precision testing on the code_generate_lite
dataset. This PR removes the redundant index computation and instead derives
`kv_inverse_idx_for_chunk` and `cp_kv_recover_idx_for_chunk` directly from the
existing `pcp_allgather_restore_idx` with straightforward `argsort` calls.
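The "straightforward calculation" rests on a standard identity: for a permutation, `argsort` yields its inverse, so applying the permutation and then its argsort restores the original order. A minimal pure-Python sketch, with a toy `argsort` standing in for `torch.argsort` and invented toy data:

```python
def argsort(xs):
    """Indices that would sort xs (stand-in for torch.argsort on 1-D data)."""
    return sorted(range(len(xs)), key=lambda i: xs[i])

# Toy permutation that scatters tokens across context-parallel ranks.
restore_idx = [2, 0, 3, 1]

# argsort of a permutation is its inverse permutation.
inverse_idx = argsort(restore_idx)

gathered = ["t0", "t1", "t2", "t3"]
scattered = [gathered[i] for i in restore_idx]   # apply the permutation
recovered = [scattered[i] for i in inverse_idx]  # undo it
```

Here `recovered` equals `gathered` again, which is exactly the round-trip the KV-recover indices implement.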
- vLLM version: v0.13.0
- vLLM main:
2c24bc6996
Signed-off-by: QiuChunshuo <qiuchunshuo@huawei.com>
```diff
@@ -162,14 +162,12 @@ class AscendAttentionCPMetadataBuilder(AscendAttentionMetadataBuilder):
         local_total_toks = local_chunked_kv_lens_rank.sum()
         chunked_req_mask = self._get_chunked_req_mask(local_context_lens_allranks)
         local_chunk_starts = torch.zeros(
-            (len(local_context_lens_allranks)), dtype=torch.int32, device=self.device
+            (len(local_context_lens_allranks),), dtype=torch.int32, device=self.device
         )
-        cp_kv_recover_idx_for_chunk = common_long_seq_metadata.cp_kv_recover_idx_for_chunk
-        kv_inverse_idx_for_chunk = (
-            torch.argsort(cp_kv_recover_idx_for_chunk.to(torch.float32))
-            if cp_kv_recover_idx_for_chunk is not None
-            else None
-        )
+        kv_inverse_idx_for_chunk = torch.argsort(
+            common_long_seq_metadata.pcp_allgather_restore_idx[pcp_size * num_decode_tokens :].to(torch.float32)
+        )
+        cp_kv_recover_idx_for_chunk = torch.argsort(kv_inverse_idx_for_chunk)

         batch_chunk_seq_mask = local_context_lens_allranks[:, self.pcp_rank, self.dcp_rank] == 0
         batch_chunk_seq_mask = torch.repeat_interleave(
```
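One plausible reading of the double `argsort` in the new code: after slicing off the decode tokens, the surviving values of `pcp_allgather_restore_idx` are a permutation of some offset range rather than of `0..n-1`, and applying `argsort` twice converts them to their ranks, i.e. a 0-based local recover index. A toy pure-Python sketch of that effect (the `argsort` helper and the sample values are illustrative, not from the PR):

```python
def argsort(xs):
    """Stand-in for torch.argsort on a 1-D sequence."""
    return sorted(range(len(xs)), key=lambda i: xs[i])

# Toy stand-in for the sliced pcp_allgather_restore_idx: a permutation of
# {4..7}, i.e. offset global positions rather than local 0-based ones.
sliced_restore_idx = [6, 4, 7, 5]

kv_inverse_idx = argsort(sliced_restore_idx)      # inverse permutation
cp_kv_recover_idx = argsort(kv_inverse_idx)       # ranks: 0-based again
```

`cp_kv_recover_idx` ends up as `[2, 0, 3, 1]` — the same ordering as the offset values, renumbered to a valid local permutation, which is what a gather on the chunked KV needs.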
||||