[feature] Chunked prefill support for PCP & DCP (#3801)

### What this PR does / why we need it?
Chunked prefill can now work with the long-sequence features PCP (prefill context parallelism) and DCP (decode context parallelism).

### Does this PR introduce _any_ user-facing change?
No

### How was this patch tested?
CI tests passed, along with local self-tests.


- vLLM version: v0.11.0
- vLLM main: 83f478bb19

---------

Signed-off-by: Apocalypse990923-qshi <qiushixu@usc.edu>
Signed-off-by: Delphine-Nic <tanwenqin@huawei.com>
Co-authored-by: Delphine-Nic <tanwenqin@huawei.com>
Co-authored-by: Delphine-Nic <3834144971@qq.com>
Commit 71866d5311 (parent 7ffbe73d54), authored by Apocalypse on 2025-11-11 09:18:02 +08:00, committed by GitHub.
8 changed files with 1276 additions and 170 deletions


@@ -77,14 +77,6 @@ class BlockTable:
self.block_table_np = self.block_table_cpu.numpy()
self.num_blocks_per_row = np.zeros(max_num_reqs, dtype=np.int32)
self.slot_mapping_cpu = torch.zeros(self.max_num_batched_tokens,
dtype=torch.int64,
device="cpu",
pin_memory=self.pin_memory)
self.slot_mapping_np = self.slot_mapping_cpu.numpy()
self.slot_mapping = torch.zeros(self.max_num_batched_tokens,
dtype=torch.int32,
device=self.device)
try:
self.pcp_world_size = get_pcp_group(
).world_size if prefill_context_parallel_enable() else 1
@@ -98,6 +90,20 @@ class BlockTable:
self.dcp_rank = 0
self.pcp_world_size = 1
self.pcp_rank = 0
self.slot_mapping_cpu = torch.zeros(
self.max_num_batched_tokens +
2 * self.pcp_world_size * self.max_num_reqs,
dtype=torch.int64,
device="cpu",
pin_memory=self.pin_memory)
self.slot_mapping_np = self.slot_mapping_cpu.numpy()
self.slot_mapping = torch.zeros(
self.max_num_batched_tokens +
2 * self.pcp_world_size * self.max_num_reqs,
dtype=torch.int32,
device=self.device)
self.kernel_sizes = kernel_sizes
self.cp_kv_cache_interleave_size = cp_kv_cache_interleave_size
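
The slot-mapping buffers above are now oversized relative to `max_num_batched_tokens`. A back-of-the-envelope sketch of where the extra `2 * pcp_world_size * max_num_reqs` slots come from (this rationale is an inference from `generate_kv_idx`, which pads each prefill request to a multiple of `2 * pcp_world_size` tokens; the diff does not state it explicitly):

```python
# Hypothetical sizes, for illustration only.
max_num_batched_tokens = 2048
max_num_reqs = 16
pcp_world_size = 2

# Each prefill request may be padded by at most 2 * pcp_world_size - 1 tokens,
# so the worst-case extra space across all requests is bounded by:
extra_slots = 2 * pcp_world_size * max_num_reqs            # 64
slot_mapping_capacity = max_num_batched_tokens + extra_slots
print(slot_mapping_capacity)                                # 2112
```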
@@ -148,7 +154,7 @@ class BlockTable:
if self.dcp_world_size * self.pcp_world_size > 1:
# Note(hc): The DCP implement store kvcache with an interleave
# style, the kvcache for the token whose token_idx is i is
# always stored on the GPU whose dcp_rank equals i % cp_world_size:
# always stored on the GPU whose dcp_rank equals i % pcp_world_size:
# Use a "virtual block" which equals to world_size * block_size
# for block_table_indices calculation.
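
As a standalone illustration of the interleaved KV layout and the "virtual block" trick described in the comment above (a minimal sketch with assumed sizes and a round-robin that starts at rank 0, not the project's actual indexing code):

```python
import numpy as np

# Assumed sizes for the sketch.
block_size = 4
pcp_world_size, dcp_world_size = 2, 2
world_size = pcp_world_size * dcp_world_size
cp_kv_cache_interleave_size = 1

positions = np.arange(16)  # token indices within one request
# With interleaved storage, token i is owned by rank (i // interleave) % world_size.
owner_rank = (positions // cp_kv_cache_interleave_size) % world_size
# Block-table lookups use a "virtual block" of world_size * block_size tokens.
virtual_block_idx = positions // (block_size * world_size)

print(owner_rank)         # [0 1 2 3 0 1 2 3 0 1 2 3 0 1 2 3]
print(virtual_block_idx)  # [0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0]
```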
@@ -268,12 +274,12 @@ class MultiGroupBlockTable:
# must be multiplied by dcp_world_size.
try:
dcp_world_size = get_dcp_group().world_size
cp_world_size = get_pcp_group(
pcp_world_size = get_pcp_group(
).world_size if prefill_context_parallel_enable() else 1
except AssertionError:
# DCP might not be initialized in testing
dcp_world_size = 1
cp_world_size = 1
pcp_world_size = 1
if kernel_sizes is None:
kernel_sizes = [[0]] * len(block_sizes)
@@ -291,7 +297,7 @@ class MultiGroupBlockTable:
block_size, max_num_reqs,
max(
cdiv(max_model_len,
block_size * dcp_world_size * cp_world_size),
block_size * dcp_world_size * pcp_world_size),
1 + num_speculative_tokens), max_num_batched_tokens,
pin_memory, device, kernel_size_list,
cp_kv_cache_interleave_size)


@@ -29,7 +29,7 @@ from copy import deepcopy
from dataclasses import dataclass
from multiprocessing import Manager
from typing import (TYPE_CHECKING, Any, Dict, List, NamedTuple, Optional,
Union, cast)
Tuple, Union, cast)
import numpy as np
import numpy.typing as npt
@@ -471,13 +471,19 @@ class NPUModelRunner(LoRAModelRunnerMixin):
device="cpu",
pin_memory=True)
self.seq_lens_np = self.seq_lens_cpu.numpy()
self.pcp_allgather_restore_idx = torch.zeros(self.max_num_tokens,
dtype=torch.int32,
device=self.device)
self.pcp_allgather_restore_idx = torch.zeros(
self.max_num_tokens + 2 * self.pcp_size * self.max_num_reqs,
dtype=torch.int32,
device=self.device)
self.cp_kv_recover_idx_for_chunk: List[List[int]] = [
[] for _ in range(self.pcp_size)
]
self.num_pcp_pads = torch.zeros(self.max_num_reqs, dtype=torch.int32)
self.pcp_padded_slot_mapping = torch.zeros(self.max_num_tokens,
dtype=torch.int32,
device=self.device)
self.pcp_padded_slot_mapping = torch.zeros(
self.max_num_tokens + 2 * self.pcp_size * self.max_num_reqs,
dtype=torch.int32,
device=self.device)
self.num_actual_tokens_pcp_padded = 0
if self.speculative_config and self.pcp_size > 1:
self.input_ids_pcp_full = torch.zeros(self.max_num_tokens,
@@ -739,7 +745,8 @@ class NPUModelRunner(LoRAModelRunnerMixin):
backward_kwargs = {}
backward_kwargs["mm_features"] = new_req_data.mm_features
self.requests[req_id] = CachedRequestState(
# Create request state - PCP/DCP tracking will be computed below
req_state = CachedRequestState(
req_id=req_id,
prompt_token_ids=new_req_data.prompt_token_ids,
prompt_embeds=new_req_data.prompt_embeds,
@@ -750,9 +757,42 @@ class NPUModelRunner(LoRAModelRunnerMixin):
num_computed_tokens=new_req_data.num_computed_tokens,
output_token_ids=[],
lora_request=new_req_data.lora_request,
local_chunked_kv_lens=None,
**backward_kwargs,
)
# Compute PCP/DCP tracking fields for chunked prefill
self.input_batch.local_chunked_kv_lens = [None] * self.max_num_reqs
if self.pcp_size * self.dcp_size > 1:
num_computed_tokens = new_req_data.num_computed_tokens
if num_computed_tokens > 0:
# Initialize with starting rank 0
temp_start_rank_dict = {req_id: (0, 0)}
# Compute token distribution for initial tokens
current_distribution = self.get_split_computed_tokens(
np.array([num_computed_tokens]),
request_ids=[req_id],
request_start_rank_dict=temp_start_rank_dict,
cp_kv_cache_interleave_size=self.parallel_config.
cp_kv_cache_interleave_size,
)[0]
# Update next_pcp_dcp_start_rank
req_state.next_pcp_dcp_start_rank = temp_start_rank_dict[
req_id][0]
req_state.token_blank_in_last_blk = temp_start_rank_dict[
req_id][1]
req_state.local_chunked_kv_lens = [
copy.deepcopy(current_distribution)
]
else:
# No computed tokens yet
req_state.local_chunked_kv_lens = []
self.requests[req_id] = req_state
# Only relevant for models using M-RoPE (e.g, Qwen2-VL)
if self.uses_mrope:
self._init_mrope_positions(self.requests[req_id])
@@ -769,8 +809,44 @@ class NPUModelRunner(LoRAModelRunnerMixin):
resumed_from_preemption = req_data.resumed_from_preemption[i]
# Update the cached states.
prev_num_computed_tokens = req_state.num_computed_tokens
req_state.num_computed_tokens = num_computed_tokens
# Compute PCP/DCP tracking fields for chunked prefill
if self.pcp_size * self.dcp_size > 1:
# If this is the first chunk, initialize tracking fields
if req_state.local_chunked_kv_lens is None:
req_state.local_chunked_kv_lens = []
# Compute tokens added in this chunk (not cumulative)
chunk_tokens = num_computed_tokens - prev_num_computed_tokens
if chunk_tokens > 0:
# Create a temporary dict with this request's starting rank
temp_start_rank_dict = {
req_id: (req_state.next_pcp_dcp_start_rank,
req_state.token_blank_in_last_blk)
}
# Compute distribution for this chunk only
chunk_distribution = self.get_split_computed_tokens(
np.array([chunk_tokens]),
request_ids=[req_id],
request_start_rank_dict=temp_start_rank_dict,
cp_kv_cache_interleave_size=self.parallel_config.
cp_kv_cache_interleave_size,
)[0]
# Update next_pcp_dcp_start_rank for this request
req_state.next_pcp_dcp_start_rank = temp_start_rank_dict[
req_id][0]
req_state.token_blank_in_last_blk = temp_start_rank_dict[
req_id][1]
# Append this chunk's distribution to accumulation list
req_state.local_chunked_kv_lens.append(
copy.deepcopy(chunk_distribution))
if not is_last_rank:
# When using PP, the scheduler sends the sampled tokens back,
# because there's no direct communication between the first-
@@ -815,6 +891,10 @@ class NPUModelRunner(LoRAModelRunnerMixin):
self.input_batch.block_table.append_row(
new_block_ids, req_index)
# Update PCP/DCP tracking fields in input_batch
self.input_batch.local_chunked_kv_lens[
req_index] = req_state.local_chunked_kv_lens
# For the last rank, we don't need to update the token_ids_cpu
# because the sampled tokens are already cached.
if not is_last_rank:
@@ -979,6 +1059,8 @@ class NPUModelRunner(LoRAModelRunnerMixin):
return None
if self.attn_mask_builder is None:
raise ValueError("Attn mask builder is None")
if self.dcp_size > 1:
return self.attn_mask_builder.get_splitfuse_attn_mask()
# Pooling situation.
if self.model_config.runner_type == "pooling" and self.model_config.pooler_config.pooling_type == "CLS":
return self.attn_mask_builder.get_pooling_mask(self.device)
@@ -1378,6 +1460,49 @@ class NPUModelRunner(LoRAModelRunnerMixin):
scheduler_output,
decode_threshold=self.reorder_batch_threshold)
def generate_kv_idx(self, tokens, scheduler_output):
if not self.pcp_size > 1:
return
self.cp_kv_recover_idx_for_chunk = [[] for _ in range(self.pcp_size)]
for i, req_id in enumerate(self.input_batch.req_ids):
num_scheduled_tokens = scheduler_output.num_scheduled_tokens[
req_id]
is_prefill = self.input_batch.num_computed_tokens_cpu[
i] < self.input_batch.num_prompt_tokens[i]
if is_prefill:
num_cp_padded_scheduled_tokens = cdiv(
num_scheduled_tokens,
2 * self.pcp_size) * (2 * self.pcp_size)
full_indices = list(
range(self.max_num_tokens * self.pcp_size * self.dcp_size +
self.pcp_size * self.dcp_size * self.max_num_reqs))
chunk_size = num_cp_padded_scheduled_tokens // (2 *
self.pcp_size)
num_added_recover_tokens = len(
self.cp_kv_recover_idx_for_chunk[0]) * self.pcp_size
for rank in range(self.pcp_size):
self.cp_kv_recover_idx_for_chunk[rank].extend(
full_indices[rank * chunk_size +
num_added_recover_tokens:(rank + 1) *
chunk_size + num_added_recover_tokens])
self.cp_kv_recover_idx_for_chunk[rank].extend(
full_indices[num_cp_padded_scheduled_tokens -
(rank + 1) * chunk_size +
num_added_recover_tokens:
num_cp_padded_scheduled_tokens -
rank * chunk_size +
num_added_recover_tokens])
cp_kv_recover_idx_for_chunk = torch.from_numpy(
np.concatenate(
self.cp_kv_recover_idx_for_chunk)).to(device=self.device)
cp_kv_recover_idx_for_chunk.copy_(torch.tensor(
np.array(self.cp_kv_recover_idx_for_chunk).flatten().tolist()),
non_blocking=True)
self.cp_kv_recover_idx_for_chunk = cp_kv_recover_idx_for_chunk.to(
torch.float32).argsort().to(torch.int32)
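
For intuition, here is a minimal NumPy sketch (assuming `pcp_size = 2` and a single prefill request of 8 scheduled tokens, already a multiple of `2 * pcp_size`) of the head/tail chunk assignment whose inverse permutation `generate_kv_idx` computes: rank r takes chunk r from the front and chunk `2 * pcp_size - 1 - r` from the back, which balances causal-attention work across PCP ranks.

```python
import numpy as np

pcp_size, num_tokens = 2, 8
chunk = num_tokens // (2 * pcp_size)   # 2 tokens per half-chunk

tokens = np.arange(num_tokens)
per_rank = [
    np.concatenate([
        tokens[r * chunk:(r + 1) * chunk],                             # head chunk r
        tokens[num_tokens - (r + 1) * chunk:num_tokens - r * chunk],   # tail chunk
    ])
    for r in range(pcp_size)
]
gathered = np.concatenate(per_rank)     # token order after the all-gather
recover_idx = np.argsort(gathered)      # permutation that restores the original order

print(per_rank)                # [array([0, 1, 6, 7]), array([2, 3, 4, 5])]
print(gathered[recover_idx])   # [0 1 2 3 4 5 6 7]
```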
def _prepare_inputs(
self,
scheduler_output: "SchedulerOutput",
@@ -1406,7 +1531,7 @@ class NPUModelRunner(LoRAModelRunnerMixin):
self.input_batch.num_computed_tokens_cpu[req_indices],
arange,
)
self.generate_kv_idx(tokens, scheduler_output)
self.input_batch.block_table.compute_slot_mapping(
req_indices, positions_np)
self.input_batch.block_table.commit_slot_mapping(
@@ -1610,15 +1735,19 @@ class NPUModelRunner(LoRAModelRunnerMixin):
]
num_tokens_np = np.array(num_tokens, dtype=np.int32)
num_reqs = self.input_batch.num_reqs
if self.pcp_size == 1:
discard_requests_mask = self.seq_lens_np[:num_reqs] < num_tokens_np
else:
if self.pcp_size > 1:
# when pcp > 1, we need the original num_scheduled_tokens before the split
# to calculate discard_requests_mask
tokens_original = [
scheduler_output.num_scheduled_tokens[i] for i in req_ids
]
original_seq_lens_np = (
self.input_batch.num_computed_tokens_cpu[:num_reqs] +
np.array(list(scheduler_output.num_scheduled_tokens.values())))
np.array(tokens_original, dtype=np.int32))
discard_requests_mask = original_seq_lens_np < num_tokens_np
else:
discard_requests_mask = self.seq_lens_np[:num_reqs] < num_tokens_np
discard_request_indices = np.nonzero(discard_requests_mask)[0]
self.num_discarded_requests = len(discard_request_indices)
self.discard_request_indices.np[:self.num_discarded_requests] = (
@@ -1762,8 +1891,17 @@ class NPUModelRunner(LoRAModelRunnerMixin):
scheduler_output.num_scheduled_tokens)
# prepare pcp meta data
# For chunked prefill, use num_scheduled_tokens instead of cumulative seq_lens
# to correctly calculate chunk_len in _generate_pcp_metadata
if self.vllm_config.scheduler_config.chunked_prefill_enabled and self.pcp_size > 1:
# In chunked prefill, seq_lens_for_chunk should be the current chunk size
seq_lens_for_chunk = torch.from_numpy(
num_scheduled_tokens[:num_reqs])
else:
# Normal mode: use cumulative sequence lengths
seq_lens_for_chunk = seq_lens_cpu
long_seq_metadata = self._generate_pcp_metadata(
total_num_scheduled_tokens, seq_lens_cpu)
total_num_scheduled_tokens, seq_lens_for_chunk, seq_lens_cpu)
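
A tiny numeric example (hypothetical request, not taken from the diff) of why the chunked-prefill branch passes the per-step token count rather than the cumulative sequence length:

```python
num_computed_tokens = 96    # prefilled in earlier steps
num_scheduled_tokens = 32   # scheduled for this step (the current chunk)

cumulative_seq_len = num_computed_tokens + num_scheduled_tokens  # 128
chunk_len = num_scheduled_tokens                                 # 32

# With chunked prefill and pcp_size > 1, _generate_pcp_metadata should see 32,
# the length of the chunk being split across PCP ranks, not 128.
```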
# Prepare the attention metadata for each KV cache group and make layers
# in the same group share the same metadata.
for kv_cache_group_id, kv_cache_group_spec in enumerate(
@@ -2690,7 +2828,7 @@ class NPUModelRunner(LoRAModelRunnerMixin):
dtype=torch.int32,
device=self.device)
long_seq_metadata = self._generate_pcp_metadata(
num_tokens, self.seq_lens_cpu)
num_tokens, self.seq_lens_cpu, self.seq_lens_cpu)
if long_seq_metadata is not None:
pcp_world_size = get_pcp_group(
).world_size if prefill_context_parallel_enable() else 1
@@ -4266,23 +4404,149 @@ class NPUModelRunner(LoRAModelRunnerMixin):
[-1, pcp_world_size, dcp_world_size])
return dcp_local_seq_lens
def _generate_pcp_metadata(self, total_num_scheduled_tokens, seq_lens):
def get_split_computed_tokens(
self,
num_computed_tokens: np.ndarray,
request_ids: Optional[List[str]] = None,
request_start_rank_dict: Dict[str, tuple[
int, int]] = {}, # tuple: start_rank, tokens_blank_in_this_block
cp_kv_cache_interleave_size: int = 1
) -> list[Optional[list[Optional[list[int]]]]]:
"""Splits computed token counts across dcp and sp dimensions for distributed allocation.
Args:
num_computed_tokens: Number of tokens for each request (current chunk, not cumulative)
request_ids: Request IDs to track state
request_start_rank_dict: Dict mapping req_id to the starting rank for this chunk.
Will be updated with next starting rank after distribution.
Returns:
List of [pcp_size][dcp_size] distribution for each request
"""
self.pcp_world_size = get_pcp_group(
).world_size if prefill_context_parallel_enable() else 1
self.dcp_world_size = get_dcp_group().world_size
num_requests = len(num_computed_tokens)
assert request_start_rank_dict is not None and request_ids is not None and len(
request_ids) == num_requests
local_chunked_kv_lens = [[[0] * self.dcp_world_size
for _ in range(self.pcp_world_size)]
for _ in range(num_requests)]
total_ranks = self.pcp_world_size * self.dcp_world_size
for req_idx, (req_id, total_tokens) in enumerate(
zip(request_ids, num_computed_tokens)):
if total_tokens <= 0:
continue
# Get starting rank for this chunk
start_rank = 0
tokens_blank = 0
if request_start_rank_dict is not None:
start_rank, tokens_blank = request_start_rank_dict.get(
req_id, (0, 0))
if tokens_blank > 0: # need to continue writing in the last block of previous chunk
consumed_tokens = min(tokens_blank, total_tokens)
total_tokens -= consumed_tokens
tokens_blank -= consumed_tokens
pcp_idx = start_rank // self.dcp_world_size
dcp_idx = start_rank % self.dcp_world_size
local_chunked_kv_lens[req_idx][pcp_idx][
dcp_idx] += consumed_tokens
if tokens_blank == 0:
start_rank = (start_rank + 1) % total_ranks
if total_tokens == 0:
request_start_rank_dict[req_id] = (start_rank,
tokens_blank)
continue
virtual_size = total_ranks * cp_kv_cache_interleave_size
base = int(total_tokens) // virtual_size
# Distribute base tokens to all ranks
for rank_idx in range(total_ranks):
pcp_idx = rank_idx // self.dcp_world_size
dcp_idx = rank_idx % self.dcp_world_size
local_chunked_kv_lens[req_idx][pcp_idx][
dcp_idx] += base * cp_kv_cache_interleave_size
remainder = int(total_tokens) % virtual_size
if remainder == 0:
request_start_rank_dict[req_id] = (start_rank, tokens_blank)
continue
remain_blocks = cdiv(remainder, cp_kv_cache_interleave_size)
assert remain_blocks > 0
# Distribute remainder tokens starting from start_rank
for i in range(remain_blocks):
rank = (start_rank + i) % total_ranks
pcp_idx = rank // self.dcp_world_size
dcp_idx = rank % self.dcp_world_size
if i < remain_blocks - 1 or remainder % cp_kv_cache_interleave_size == 0: # not last block or divisible
local_chunked_kv_lens[req_idx][pcp_idx][
dcp_idx] += 1 * cp_kv_cache_interleave_size
tokens_blank = 0
else:  # last block and not evenly divisible
local_chunked_kv_lens[req_idx][pcp_idx][
dcp_idx] += remainder % cp_kv_cache_interleave_size
tokens_blank = cp_kv_cache_interleave_size - (
remainder % cp_kv_cache_interleave_size)
start_rank = (start_rank + remain_blocks - 1) % total_ranks
if tokens_blank == 0:
start_rank = (start_rank + 1) % total_ranks
# Update next starting rank for this request
request_start_rank_dict[req_id] = (start_rank, tokens_blank)
return cast(List[Optional[List[Optional[List[int]]]]],
local_chunked_kv_lens)
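
A worked example helps here. The following standalone sketch is a simplified, hypothetical re-implementation that assumes `cp_kv_cache_interleave_size == 1` and no partially filled block carried over between chunks (so `token_blank_in_last_blk` stays 0); it mirrors the round-robin split for one request across two chunks:

```python
from math import ceil

def split_tokens_round_robin(total_tokens, pcp, dcp, start_rank=0, interleave=1):
    """Return ([pcp][dcp] token counts, next start rank) for a single chunk."""
    total_ranks = pcp * dcp
    dist = [[0] * dcp for _ in range(pcp)]
    base, remainder = divmod(total_tokens, total_ranks * interleave)
    for rank in range(total_ranks):            # even share for every rank
        dist[rank // dcp][rank % dcp] += base * interleave
    remain_blocks = ceil(remainder / interleave)
    for i in range(remain_blocks):             # leftover blocks, round-robin from start_rank
        rank = (start_rank + i) % total_ranks
        take = interleave if (i < remain_blocks - 1
                              or remainder % interleave == 0) else remainder % interleave
        dist[rank // dcp][rank % dcp] += take
    if remain_blocks:
        start_rank = (start_rank + remain_blocks) % total_ranks
    return dist, start_rank

# Chunk 1: 10 tokens on a pcp=2 x dcp=2 layout.
print(split_tokens_round_robin(10, pcp=2, dcp=2))               # ([[3, 3], [2, 2]], 2)
# Chunk 2: 5 more tokens, resuming at the rank returned above.
print(split_tokens_round_robin(5, pcp=2, dcp=2, start_rank=2))  # ([[1, 1], [2, 1]], 3)
```

The second chunk resumes at rank 2, which is exactly what `next_pcp_dcp_start_rank` records between chunks; the real method additionally tracks `token_blank_in_last_blk` to finish a partially written interleave block when `cp_kv_cache_interleave_size > 1`.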
def _get_chunked_req_mask_and_max_chunk(
self,
local_chunked_kv_lens: Optional[list[Optional[list[Optional[list[
Optional[list[int]]]]]]]] = None
) -> Tuple[List[bool], int]:
"""
given 4-d list [req][chunk][pcp][dcp], return:
1. if each req has any chunk (list[bool])
2. max chunk num along all reqs (int)
"""
assert local_chunked_kv_lens is not None
if len(local_chunked_kv_lens) == 0:
return ([], 0)
mask_for_non_zero_chunk = [
len(req) > 0 for req in local_chunked_kv_lens if req is not None
]
max_chunk_num = max(
(len(req) for req in local_chunked_kv_lens if req is not None),
default=0)
return mask_for_non_zero_chunk, max_chunk_num
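
For example (hypothetical batch of two requests, no `None` entries):

```python
local_chunked_kv_lens = [
    [[[3, 3], [2, 2]], [[1, 1], [2, 1]]],   # req 0: two chunks
    [],                                      # req 1: no chunks yet
]
mask = [len(req) > 0 for req in local_chunked_kv_lens]            # [True, False]
max_chunk_num = max(len(req) for req in local_chunked_kv_lens)    # 2
```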
def _generate_pcp_metadata(self, total_num_scheduled_tokens, seq_lens,
seq_lens_origin):
num_reqs = self.input_batch.num_reqs
num_decodes = sum(self.input_batch.num_computed_tokens_cpu[:num_reqs]
>= self.input_batch.num_prompt_tokens[:num_reqs])
num_actual_tokens_pcp_padded = total_num_scheduled_tokens * self.pcp_size
self.num_actual_tokens_pcp_padded = num_actual_tokens_pcp_padded
local_chunked_kv_lens = self.input_batch.local_chunked_kv_lens[
num_decodes:num_reqs]
mask_for_non_zero_chunk, max_chunk_num = self._get_chunked_req_mask_and_max_chunk(
local_chunked_kv_lens)
long_seq_metadata = None
if self.pcp_size * self.dcp_size > 1:
long_seq_metadata = AscendPrefillContextParallelMetadata(
num_actual_tokens_pcp_padded=num_actual_tokens_pcp_padded,
num_computed_tokens_of_pcp_dcp=self._get_pcp_local_seq_lens(
seq_lens,
seq_lens_origin,
self.pcp_size,
self.dcp_size,
self.parallel_config.cp_kv_cache_interleave_size,
).numpy(),
)
local_chunked_kv_lens=local_chunked_kv_lens,
mask_for_non_zero_chunk=mask_for_non_zero_chunk,
max_chunk_num=max_chunk_num)
if self.pcp_size > 1:
q_head_idx, q_tail_idx = [], []
kv_with_q_head_nomask_idx, kv_with_q_head_mask_idx = [], []
@@ -4393,6 +4657,7 @@ class NPUModelRunner(LoRAModelRunnerMixin):
}
long_seq_metadata.pcp_allgather_restore_idx = self.pcp_allgather_restore_idx[:
num_actual_tokens_pcp_padded]
long_seq_metadata.cp_kv_recover_idx_for_chunk = self.cp_kv_recover_idx_for_chunk
long_seq_metadata.q_head_idx_tensor = self.q_head_idx_tensor
long_seq_metadata.q_tail_idx_tensor = self.q_tail_idx_tensor
long_seq_metadata.q_full_idx = self.q_full_idx


@@ -73,6 +73,12 @@ class CachedRequestState:
lora_request: Optional[LoRARequest] = None
prompt_embeds: Optional[torch.Tensor] = None
# pcp/dcp param
local_chunked_kv_lens: Optional[list[Optional[list[Optional[
list[int]]]]]] = None # Records computed tokens for each chunk
next_pcp_dcp_start_rank: int = 0 # Tracks next starting rank for round-robin distribution
token_blank_in_last_blk: int = 0 # if the last block is not full, how many future tokens can be stored
def __post_init__(self):
self.num_prompt_tokens = length_from_prompt_token_ids_or_embeds(
self.prompt_token_ids, self.prompt_embeds)
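
Putting the three new fields together for a hypothetical request that has prefilled two chunks (10 then 5 tokens) on a pcp=2 x dcp=2 layout with `cp_kv_cache_interleave_size == 1` (continuing the worked split example above), the state would look roughly like:

```python
local_chunked_kv_lens = [
    [[3, 3], [2, 2]],   # chunk 1: 10 tokens split over [pcp][dcp] ranks
    [[1, 1], [2, 1]],   # chunk 2: 5 tokens, resuming round-robin at rank 2
]
next_pcp_dcp_start_rank = 3   # rank where the next chunk starts writing KV
token_blank_in_last_blk = 0   # last interleave block is full, nothing to top up
```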
@@ -313,6 +319,10 @@ class InputBatch:
self.prev_sampled_token_ids_invalid_indices: Optional[set[int]] = None
self.prev_req_id_to_index: Optional[dict[str, int]] = None
# pcp/dcp parameters
self.local_chunked_kv_lens: list[Optional[list[Optional[list[Optional[
list[int]]]]]]] = [None] * max_num_reqs
@property
def req_ids(self) -> list[str]:
# None elements should only be present transiently
@@ -385,6 +395,9 @@ class InputBatch:
self.num_computed_tokens_cpu[req_index] = request.num_computed_tokens
self.block_table.add_row(request.block_ids, req_index)
# Add PCP/DCP tracking fields
self.local_chunked_kv_lens[req_index] = request.local_chunked_kv_lens
if sampling_params := request.sampling_params:
if (self.is_spec_decode
and is_spec_decode_unsupported(sampling_params)):
@@ -680,6 +693,8 @@ class InputBatch:
last_req_index]
self.num_computed_tokens_cpu[
empty_index] = self.num_computed_tokens_cpu[last_req_index]
self.local_chunked_kv_lens[
empty_index] = self.local_chunked_kv_lens[last_req_index]
self.block_table.move_row(last_req_index, empty_index)
self.temperature_cpu[empty_index] = self.temperature_cpu[
last_req_index]