[feature] chunked prefill support for PCP & DCP (#3801)

### What this PR does / why we need it?
Chunked prefill can now work with the long-sequence features PCP (prefill context parallelism) and DCP (decode context parallelism).

### Does this PR introduce _any_ user-facing change?
No

### How was this patch tested?
CI tests passed, together with local self-testing.


- vLLM version: v0.11.0
- vLLM main: 83f478bb19

---------

Signed-off-by: Apocalypse990923-qshi <qiushixu@usc.edu>
Signed-off-by: Delphine-Nic <tanwenqin@huawei.com>
Co-authored-by: Delphine-Nic <tanwenqin@huawei.com>
Co-authored-by: Delphine-Nic <3834144971@qq.com>
Commit 71866d5311 (parent 7ffbe73d54), authored by Apocalypse on 2025-11-11 09:18:02 +08:00 and committed via GitHub.
8 changed files with 1276 additions and 170 deletions.


@@ -14,10 +14,19 @@ from vllm.forward_context import ForwardContext, get_forward_context
class AscendPrefillContextParallelMetadata:
    pcp_allgather_restore_idx: Optional[torch.Tensor] = None
    cp_kv_recover_idx_for_chunk: Optional[torch.Tensor] = None
    num_actual_tokens_pcp_padded: Optional[int] = None
    num_computed_tokens_of_pcp_dcp: Optional[list[list[list[int]]]] = None
    # Nested per-request KV lengths, indexed as [req][chunk][pcp_rank][dcp_rank].
    local_chunked_kv_lens: Optional[list[Optional[list[Optional[list[Optional[
        list[int]]]]]]]] = None
    # True for requests that have a real (non-empty) chunk in this step.
    mask_for_non_zero_chunk: Optional[List[bool]] = None
    max_chunk_num: int = 0
    q_head_idx_tensor: Optional[torch.Tensor] = None
    q_tail_idx_tensor: Optional[torch.Tensor] = None
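The nesting of `local_chunked_kv_lens` is easiest to read from how it is indexed later in this diff (`lst[i][chunk_idx][pcp_rank]` in `extract_req_dcp_by_chunk_pcp` below). A minimal sketch of that layout, with all sizes and values invented for illustration:

```python
# Hypothetical layout sketch: local_chunked_kv_lens[req][chunk][pcp_rank]
# yields the per-DCP-rank KV lengths for that chunk on that PCP rank.
# Here: 2 requests, pcp_size=2, dcp_size=2 (all values invented).
local_chunked_kv_lens = [
    [  # request 0: two chunks
        [[128, 128], [128, 64]],  # chunk 0: rows are pcp_rank 0 and 1
        [[32, 32], [32, 0]],      # chunk 1
    ],
    [],  # request 1: no chunked KV in this step
]
# KV lengths of chunk 0 on pcp_rank 1 for request 0, one entry per DCP rank:
assert local_chunked_kv_lens[0][0][1] == [128, 64]
```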
@@ -46,7 +55,7 @@ class AscendCommonAttentionMetadata:
    """
    Per-batch attention metadata, shared across layers and backends.
    AttentionMetadataBuilder instances use it to construct per-layer metadata.
    For many of the tensors, we keep both GPU and CPU versions.
    """
@@ -106,6 +115,47 @@ class AscendCommonAttentionMetadata:
        AscendPrefillContextParallelMetadata] = None


def extract_req_dcp_by_chunk_pcp(lst,
                                 chunk_idx,
                                 dcp_size,
                                 pcp_rank,
                                 fill_value=0):
    """
    For each request, pick the per-DCP-rank values of `chunk_idx` on
    `pcp_rank`; requests without that chunk get `dcp_size` copies of
    `fill_value` so every row has the same width.
    """
    num_reqs = len(lst)
    results: List[List[int]] = []
    for i in range(num_reqs):
        if len(lst[i]) == 0 or chunk_idx >= len(lst[i]):
            # Empty request, or this request has no such chunk: pad with
            # fill_value.
            results.append([fill_value] * dcp_size)
            continue
        dcp_values = lst[i][chunk_idx][pcp_rank]
        results.append(dcp_values)
    return results
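A quick usage sketch of the helper above, relying on the function just defined, with all values invented:

```python
# Toy input: request 0 has one chunk, request 1 has none.
lst = [
    [[[128, 128], [128, 64]]],  # [chunk][pcp_rank] -> per-DCP-rank lengths
    [],
]
rows = extract_req_dcp_by_chunk_pcp(lst, chunk_idx=0, dcp_size=2, pcp_rank=1)
# Request 0 contributes its real per-DCP-rank lengths; request 1 is padded.
assert rows == [[128, 64], [0, 0]]
```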
def filter_chunked_req_indices(
    seq_len: torch.Tensor,
    mask_for_non_zero_chunk: Optional[List[bool]],
) -> torch.Tensor:
    """
    Filter the token indices of the requests that are doing a real
    chunked prefill.

    Args:
        seq_len: per-request lengths: [req0_len, req1_len, ...]
        mask_for_non_zero_chunk: [True, False, True, False, ...]

    Returns:
        filtered_indices: flattened token indices of the chunked requests
    """
    assert mask_for_non_zero_chunk is not None and len(seq_len) == len(
        mask_for_non_zero_chunk)
    # Start offset of each request in the flattened token dimension.
    offsets = torch.cumsum(torch.cat([torch.tensor([0]), seq_len[:-1]]), dim=0)
    filtered_indices = torch.cat([
        torch.arange(offsets[i], offsets[i] + seq_len[i])
        for i in range(len(mask_for_non_zero_chunk))
        if mask_for_non_zero_chunk[i]
    ])
    return filtered_indices
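And a self-contained usage sketch of `filter_chunked_req_indices`, again with invented values:

```python
import torch

seq_len = torch.tensor([3, 2, 4])  # per-request token counts
mask = [True, False, True]         # only req0 and req2 do a real chunked prefill
# req0 occupies flattened tokens 0..2, req1 tokens 3..4, req2 tokens 5..8.
idx = filter_chunked_req_indices(seq_len, mask)
assert idx.tolist() == [0, 1, 2, 5, 6, 7, 8]
```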
def split_decodes_and_prefills(
    common_attn_metadata: AscendCommonAttentionMetadata,
    decode_threshold: int = 1,