[feature] chunkprefill support pcp&dcp (#3801)

### What this PR does / why we need it?
ChunkPrefill now can support Long Sequence Feature Pcp&Dcp

### Does this PR introduce _any_ user-facing change?
No

### How was this patch tested?
CI tests passed with self-test


- vLLM version: v0.11.0
- vLLM main:
83f478bb19

---------

Signed-off-by: Apocalypse990923-qshi <qiushixu@usc.edu>
Signed-off-by: Delphine-Nic <tanwenqin@huawei.com>
Co-authored-by: Delphine-Nic <tanwenqin@huawei.com>
Co-authored-by: Delphine-Nic <3834144971@qq.com>
This commit is contained in:
Apocalypse
2025-11-11 09:18:02 +08:00
committed by GitHub
parent 7ffbe73d54
commit 71866d5311
8 changed files with 1276 additions and 170 deletions

View File

@@ -73,6 +73,12 @@ class CachedRequestState:
lora_request: Optional[LoRARequest] = None
prompt_embeds: Optional[torch.Tensor] = None
# pcp/dcp param
local_chunked_kv_lens: Optional[list[Optional[list[Optional[
list[int]]]]]] = None # Records computed tokens for each chunk
next_pcp_dcp_start_rank: int = 0 # Tracks next starting rank for round-robin distribution
token_blank_in_last_blk: int = 0 # if the last block is not full, how many future tokens can be stored
def __post_init__(self):
self.num_prompt_tokens = length_from_prompt_token_ids_or_embeds(
self.prompt_token_ids, self.prompt_embeds)
@@ -313,6 +319,10 @@ class InputBatch:
self.prev_sampled_token_ids_invalid_indices: Optional[set[int]] = None
self.prev_req_id_to_index: Optional[dict[str, int]] = None
# pcp/dcp parameters
self.local_chunked_kv_lens: list[Optional[list[Optional[list[Optional[
list[int]]]]]]] = [None] * max_num_reqs
@property
def req_ids(self) -> list[str]:
# None elements should only be present transiently
@@ -385,6 +395,9 @@ class InputBatch:
self.num_computed_tokens_cpu[req_index] = request.num_computed_tokens
self.block_table.add_row(request.block_ids, req_index)
# Add PCP/DCP tracking fields
self.local_chunked_kv_lens[req_index] = request.local_chunked_kv_lens
if sampling_params := request.sampling_params:
if (self.is_spec_decode
and is_spec_decode_unsupported(sampling_params)):
@@ -680,6 +693,8 @@ class InputBatch:
last_req_index]
self.num_computed_tokens_cpu[
empty_index] = self.num_computed_tokens_cpu[last_req_index]
self.local_chunked_kv_lens[
empty_index] = self.local_chunked_kv_lens[last_req_index]
self.block_table.move_row(last_req_index, empty_index)
self.temperature_cpu[empty_index] = self.temperature_cpu[
last_req_index]