[feature] chunkprefill support pcp&dcp (#3801)
### What this PR does / why we need it?
ChunkPrefill now can support Long Sequence Feature Pcp&Dcp
### Does this PR introduce _any_ user-facing change?
No
### How was this patch tested?
CI tests passed with self-test
- vLLM version: v0.11.0
- vLLM main:
83f478bb19
---------
Signed-off-by: Apocalypse990923-qshi <qiushixu@usc.edu>
Signed-off-by: Delphine-Nic <tanwenqin@huawei.com>
Co-authored-by: Delphine-Nic <tanwenqin@huawei.com>
Co-authored-by: Delphine-Nic <3834144971@qq.com>
This commit is contained in:
@@ -14,10 +14,19 @@ from vllm.forward_context import ForwardContext, get_forward_context
|
||||
class AscendPrefillContextParallelMetadata:
|
||||
pcp_allgather_restore_idx: torch.Tensor = None
|
||||
|
||||
cp_kv_recover_idx_for_chunk: torch.Tensor = None
|
||||
|
||||
num_actual_tokens_pcp_padded: Optional[int] = None
|
||||
|
||||
num_computed_tokens_of_pcp_dcp: Optional[list[list[list[int]]]] = None
|
||||
|
||||
local_chunked_kv_lens: Optional[list[Optional[list[Optional[list[Optional[
|
||||
list[int]]]]]]]] = None
|
||||
|
||||
mask_for_non_zero_chunk: Optional[List[bool]] = None
|
||||
|
||||
max_chunk_num: int = 0
|
||||
|
||||
q_head_idx_tensor: torch.Tensor = None
|
||||
|
||||
q_tail_idx_tensor: torch.Tensor = None
|
||||
@@ -46,7 +55,7 @@ class AscendCommonAttentionMetadata:
|
||||
"""
|
||||
Per-batch attention metadata, shared across layers and backends.
|
||||
AttentionMetadataBuilder instances use it to construct per-layer metadata.
|
||||
|
||||
|
||||
For many of the tensors we keep both GPU and CPU versions.
|
||||
"""
|
||||
|
||||
@@ -106,6 +115,47 @@ class AscendCommonAttentionMetadata:
|
||||
AscendPrefillContextParallelMetadata] = None
|
||||
|
||||
|
||||
def extract_req_dcp_by_chunk_pcp(lst,
|
||||
chunk_idx,
|
||||
dcp_size,
|
||||
pcp_rank,
|
||||
fill_value=0):
|
||||
num_reqs = len(lst)
|
||||
results: List[List[int]] = []
|
||||
for i in range(num_reqs):
|
||||
if len(lst[i]) == 0 or chunk_idx >= len(lst[i]):
|
||||
# empty req or this req has no corresponding chunk, fill 0
|
||||
results.append([fill_value] * dcp_size)
|
||||
continue
|
||||
dcp_values = lst[i][chunk_idx][pcp_rank]
|
||||
results.append(dcp_values)
|
||||
return results
|
||||
|
||||
|
||||
def filter_chunked_req_indices(
|
||||
seq_len: torch.Tensor,
|
||||
mask_for_non_zero_chunk: Optional[List[bool]],
|
||||
) -> torch.Tensor:
|
||||
"""
|
||||
filter the reqs which are doing real chunk_prefill.
|
||||
|
||||
Args:
|
||||
seq_len: contains multi-req length: [req0_len, req1_len, ...]
|
||||
mask_for_non_zero_chunk: [True, False, True, False, ...]
|
||||
Returns:
|
||||
filtered_indices: the real chunked req's indices
|
||||
"""
|
||||
assert mask_for_non_zero_chunk is not None and len(seq_len) == len(
|
||||
mask_for_non_zero_chunk)
|
||||
offsets = torch.cumsum(torch.cat([torch.tensor([0]), seq_len[:-1]]), dim=0)
|
||||
filtered_indices = torch.cat([
|
||||
torch.arange(offsets[i], offsets[i] + seq_len[i])
|
||||
for i in range(len(mask_for_non_zero_chunk))
|
||||
if mask_for_non_zero_chunk[i]
|
||||
])
|
||||
return filtered_indices
|
||||
|
||||
|
||||
def split_decodes_and_prefills(
|
||||
common_attn_metadata: AscendCommonAttentionMetadata,
|
||||
decode_threshold: int = 1,
|
||||
|
||||
Reference in New Issue
Block a user