diff --git a/vllm_ascend/attention/context_parallel/attention_cp.py b/vllm_ascend/attention/context_parallel/attention_cp.py index f97798f3..8bffb12f 100644 --- a/vllm_ascend/attention/context_parallel/attention_cp.py +++ b/vllm_ascend/attention/context_parallel/attention_cp.py @@ -157,10 +157,18 @@ class AscendAttentionCPMetadataBuilder(AscendAttentionMetadataBuilder): local_chunk_starts = torch.zeros( (len(local_context_lens_allranks),), dtype=torch.int32, device=self.device ) - kv_inverse_idx_for_chunk = torch.argsort( - common_long_seq_metadata.pcp_allgather_restore_idx[pcp_size * num_decode_tokens :].to(torch.float32) - ) - cp_kv_recover_idx_for_chunk = torch.argsort(kv_inverse_idx_for_chunk) + # Note(qcs): we only do restore and recover for pcp, and set these vars to None + # when only using dcp. + if self.pcp_size > 1: + kv_inverse_idx_for_chunk = torch.argsort( + common_long_seq_metadata.pcp_allgather_restore_idx[pcp_size * num_decode_tokens :].to( + torch.float32 + ) + ) + cp_kv_recover_idx_for_chunk = torch.argsort(kv_inverse_idx_for_chunk) + else: + kv_inverse_idx_for_chunk = None + cp_kv_recover_idx_for_chunk = None batch_chunk_seq_mask = local_context_lens_allranks[:, self.pcp_rank, self.dcp_rank] == 0 batch_chunk_seq_mask = torch.repeat_interleave(