[feature] chunkprefill support pcp&dcp (#3801)
### What this PR does / why we need it?
Chunked prefill now supports the long-sequence features PCP and DCP.
### Does this PR introduce _any_ user-facing change?
No
### How was this patch tested?
CI tests passed, along with self-testing.
- vLLM version: v0.11.0
- vLLM main: 83f478bb19
---------
Signed-off-by: Apocalypse990923-qshi <qiushixu@usc.edu>
Signed-off-by: Delphine-Nic <tanwenqin@huawei.com>
Co-authored-by: Delphine-Nic <tanwenqin@huawei.com>
Co-authored-by: Delphine-Nic <3834144971@qq.com>
This commit is contained in:
@@ -73,6 +73,12 @@ class CachedRequestState:
|
||||
lora_request: Optional[LoRARequest] = None
|
||||
prompt_embeds: Optional[torch.Tensor] = None
|
||||
|
||||
# pcp/dcp param
|
||||
local_chunked_kv_lens: Optional[list[Optional[list[Optional[
|
||||
list[int]]]]]] = None # Records computed tokens for each chunk
|
||||
next_pcp_dcp_start_rank: int = 0 # Tracks next starting rank for round-robin distribution
|
||||
token_blank_in_last_blk: int = 0 # if the last block is not full, how many future tokens can be stored
|
||||
|
||||
def __post_init__(self):
|
||||
self.num_prompt_tokens = length_from_prompt_token_ids_or_embeds(
|
||||
self.prompt_token_ids, self.prompt_embeds)
|
||||
@@ -313,6 +319,10 @@ class InputBatch:
|
||||
self.prev_sampled_token_ids_invalid_indices: Optional[set[int]] = None
|
||||
self.prev_req_id_to_index: Optional[dict[str, int]] = None
|
||||
|
||||
# pcp/dcp parameters
|
||||
self.local_chunked_kv_lens: list[Optional[list[Optional[list[Optional[
|
||||
list[int]]]]]]] = [None] * max_num_reqs
|
||||
|
||||
@property
|
||||
def req_ids(self) -> list[str]:
|
||||
# None elements should only be present transiently
|
||||
@@ -385,6 +395,9 @@ class InputBatch:
|
||||
self.num_computed_tokens_cpu[req_index] = request.num_computed_tokens
|
||||
self.block_table.add_row(request.block_ids, req_index)
|
||||
|
||||
# Add PCP/DCP tracking fields
|
||||
self.local_chunked_kv_lens[req_index] = request.local_chunked_kv_lens
|
||||
|
||||
if sampling_params := request.sampling_params:
|
||||
if (self.is_spec_decode
|
||||
and is_spec_decode_unsupported(sampling_params)):
|
||||
@@ -680,6 +693,8 @@ class InputBatch:
|
||||
last_req_index]
|
||||
self.num_computed_tokens_cpu[
|
||||
empty_index] = self.num_computed_tokens_cpu[last_req_index]
|
||||
self.local_chunked_kv_lens[
|
||||
empty_index] = self.local_chunked_kv_lens[last_req_index]
|
||||
self.block_table.move_row(last_req_index, empty_index)
|
||||
self.temperature_cpu[empty_index] = self.temperature_cpu[
|
||||
last_req_index]
|
||||
|
||||
Reference in New Issue
Block a user