[bugfix](pcp,gqa) set kv_inverse_idx_for_chunk and cp_kv_recover_idx_for_chunk to None when dcp only (#6317)
### What this PR does / why we need it?
We only do restore and recover for pcp, so we should set
`kv_inverse_idx_for_chunk` and `cp_kv_recover_idx_for_chunk` to `None`
when only using dcp.
- vLLM version: v0.14.1
- vLLM main:
dc917cceb8
Signed-off-by: QiuChunshuo <qiuchunshuo@huawei.com>
This commit is contained in:
@@ -157,10 +157,18 @@ class AscendAttentionCPMetadataBuilder(AscendAttentionMetadataBuilder):
|
||||
local_chunk_starts = torch.zeros(
|
||||
(len(local_context_lens_allranks),), dtype=torch.int32, device=self.device
|
||||
)
|
||||
kv_inverse_idx_for_chunk = torch.argsort(
|
||||
common_long_seq_metadata.pcp_allgather_restore_idx[pcp_size * num_decode_tokens :].to(torch.float32)
|
||||
)
|
||||
cp_kv_recover_idx_for_chunk = torch.argsort(kv_inverse_idx_for_chunk)
|
||||
# Note(qcs): we only do restore and recover for pcp, and set these vars to None
|
||||
# when only using dcp.
|
||||
if self.pcp_size > 1:
|
||||
kv_inverse_idx_for_chunk = torch.argsort(
|
||||
common_long_seq_metadata.pcp_allgather_restore_idx[pcp_size * num_decode_tokens :].to(
|
||||
torch.float32
|
||||
)
|
||||
)
|
||||
cp_kv_recover_idx_for_chunk = torch.argsort(kv_inverse_idx_for_chunk)
|
||||
else:
|
||||
kv_inverse_idx_for_chunk = None
|
||||
cp_kv_recover_idx_for_chunk = None
|
||||
|
||||
batch_chunk_seq_mask = local_context_lens_allranks[:, self.pcp_rank, self.dcp_rank] == 0
|
||||
batch_chunk_seq_mask = torch.repeat_interleave(
|
||||
|
||||
Reference in New Issue
Block a user