[feature] chunkprefill support pcp&dcp (#3801)
### What this PR does / why we need it?
Chunked prefill now supports the long-sequence features PCP and DCP.
### Does this PR introduce _any_ user-facing change?
No
### How was this patch tested?
CI tests passed, in addition to self-testing.
- vLLM version: v0.11.0
- vLLM main:
83f478bb19
---------
Signed-off-by: Apocalypse990923-qshi <qiushixu@usc.edu>
Signed-off-by: Delphine-Nic <tanwenqin@huawei.com>
Co-authored-by: Delphine-Nic <tanwenqin@huawei.com>
Co-authored-by: Delphine-Nic <3834144971@qq.com>
This commit is contained in:
@@ -29,7 +29,7 @@ from copy import deepcopy
|
||||
from dataclasses import dataclass
|
||||
from multiprocessing import Manager
|
||||
from typing import (TYPE_CHECKING, Any, Dict, List, NamedTuple, Optional,
|
||||
Union, cast)
|
||||
Tuple, Union, cast)
|
||||
|
||||
import numpy as np
|
||||
import numpy.typing as npt
|
||||
@@ -471,13 +471,19 @@ class NPUModelRunner(LoRAModelRunnerMixin):
|
||||
device="cpu",
|
||||
pin_memory=True)
|
||||
self.seq_lens_np = self.seq_lens_cpu.numpy()
|
||||
self.pcp_allgather_restore_idx = torch.zeros(self.max_num_tokens,
|
||||
dtype=torch.int32,
|
||||
device=self.device)
|
||||
self.pcp_allgather_restore_idx = torch.zeros(
|
||||
self.max_num_tokens + 2 * self.pcp_size * self.max_num_reqs,
|
||||
dtype=torch.int32,
|
||||
device=self.device)
|
||||
self.cp_kv_recover_idx_for_chunk: List[List[int]] = [
|
||||
[] for _ in range(self.pcp_size)
|
||||
]
|
||||
|
||||
self.num_pcp_pads = torch.zeros(self.max_num_reqs, dtype=torch.int32)
|
||||
self.pcp_padded_slot_mapping = torch.zeros(self.max_num_tokens,
|
||||
dtype=torch.int32,
|
||||
device=self.device)
|
||||
self.pcp_padded_slot_mapping = torch.zeros(
|
||||
self.max_num_tokens + 2 * self.pcp_size * self.max_num_reqs,
|
||||
dtype=torch.int32,
|
||||
device=self.device)
|
||||
self.num_actual_tokens_pcp_padded = 0
|
||||
if self.speculative_config and self.pcp_size > 1:
|
||||
self.input_ids_pcp_full = torch.zeros(self.max_num_tokens,
|
||||
@@ -739,7 +745,8 @@ class NPUModelRunner(LoRAModelRunnerMixin):
|
||||
backward_kwargs = {}
|
||||
backward_kwargs["mm_features"] = new_req_data.mm_features
|
||||
|
||||
self.requests[req_id] = CachedRequestState(
|
||||
# Create request state - PCP/DCP tracking will be computed below
|
||||
req_state = CachedRequestState(
|
||||
req_id=req_id,
|
||||
prompt_token_ids=new_req_data.prompt_token_ids,
|
||||
prompt_embeds=new_req_data.prompt_embeds,
|
||||
@@ -750,9 +757,42 @@ class NPUModelRunner(LoRAModelRunnerMixin):
|
||||
num_computed_tokens=new_req_data.num_computed_tokens,
|
||||
output_token_ids=[],
|
||||
lora_request=new_req_data.lora_request,
|
||||
local_chunked_kv_lens=None,
|
||||
**backward_kwargs,
|
||||
)
|
||||
|
||||
# Compute PCP/DCP tracking fields for chunked prefill
|
||||
self.input_batch.local_chunked_kv_lens = [None] * self.max_num_reqs
|
||||
if self.pcp_size * self.dcp_size > 1:
|
||||
num_computed_tokens = new_req_data.num_computed_tokens
|
||||
if num_computed_tokens > 0:
|
||||
# Initialize with starting rank 0
|
||||
temp_start_rank_dict = {req_id: (0, 0)}
|
||||
|
||||
# Compute token distribution for initial tokens
|
||||
current_distribution = self.get_split_computed_tokens(
|
||||
np.array([num_computed_tokens]),
|
||||
request_ids=[req_id],
|
||||
request_start_rank_dict=temp_start_rank_dict,
|
||||
cp_kv_cache_interleave_size=self.parallel_config.
|
||||
cp_kv_cache_interleave_size,
|
||||
)[0]
|
||||
|
||||
# Update next_pcp_dcp_start_rank
|
||||
req_state.next_pcp_dcp_start_rank = temp_start_rank_dict[
|
||||
req_id][0]
|
||||
req_state.token_blank_in_last_blk = temp_start_rank_dict[
|
||||
req_id][1]
|
||||
|
||||
req_state.local_chunked_kv_lens = [
|
||||
copy.deepcopy(current_distribution)
|
||||
]
|
||||
else:
|
||||
# No computed tokens yet
|
||||
req_state.local_chunked_kv_lens = []
|
||||
|
||||
self.requests[req_id] = req_state
|
||||
|
||||
# Only relevant for models using M-RoPE (e.g, Qwen2-VL)
|
||||
if self.uses_mrope:
|
||||
self._init_mrope_positions(self.requests[req_id])
|
||||
@@ -769,8 +809,44 @@ class NPUModelRunner(LoRAModelRunnerMixin):
|
||||
resumed_from_preemption = req_data.resumed_from_preemption[i]
|
||||
|
||||
# Update the cached states.
|
||||
prev_num_computed_tokens = req_state.num_computed_tokens
|
||||
req_state.num_computed_tokens = num_computed_tokens
|
||||
|
||||
# Compute PCP/DCP tracking fields for chunked prefill
|
||||
if self.pcp_size * self.dcp_size > 1:
|
||||
# If this is the first chunk, initialize tracking fields
|
||||
if req_state.local_chunked_kv_lens is None:
|
||||
req_state.local_chunked_kv_lens = []
|
||||
|
||||
# Compute tokens added in this chunk (not cumulative)
|
||||
chunk_tokens = num_computed_tokens - prev_num_computed_tokens
|
||||
|
||||
if chunk_tokens > 0:
|
||||
# Create a temporary dict with this request's starting rank
|
||||
temp_start_rank_dict = {
|
||||
req_id: (req_state.next_pcp_dcp_start_rank,
|
||||
req_state.token_blank_in_last_blk)
|
||||
}
|
||||
|
||||
# Compute distribution for this chunk only
|
||||
chunk_distribution = self.get_split_computed_tokens(
|
||||
np.array([chunk_tokens]),
|
||||
request_ids=[req_id],
|
||||
request_start_rank_dict=temp_start_rank_dict,
|
||||
cp_kv_cache_interleave_size=self.parallel_config.
|
||||
cp_kv_cache_interleave_size,
|
||||
)[0]
|
||||
|
||||
# Update next_pcp_dcp_start_rank for this request
|
||||
req_state.next_pcp_dcp_start_rank = temp_start_rank_dict[
|
||||
req_id][0]
|
||||
req_state.token_blank_in_last_blk = temp_start_rank_dict[
|
||||
req_id][1]
|
||||
|
||||
# Append this chunk's distribution to accumulation list
|
||||
req_state.local_chunked_kv_lens.append(
|
||||
copy.deepcopy(chunk_distribution))
|
||||
|
||||
if not is_last_rank:
|
||||
# When using PP, the scheduler sends the sampled tokens back,
|
||||
# because there's no direct communication between the first-
|
||||
@@ -815,6 +891,10 @@ class NPUModelRunner(LoRAModelRunnerMixin):
|
||||
self.input_batch.block_table.append_row(
|
||||
new_block_ids, req_index)
|
||||
|
||||
# Update PCP/DCP tracking fields in input_batch
|
||||
self.input_batch.local_chunked_kv_lens[
|
||||
req_index] = req_state.local_chunked_kv_lens
|
||||
|
||||
# For the last rank, we don't need to update the token_ids_cpu
|
||||
# because the sampled tokens are already cached.
|
||||
if not is_last_rank:
|
||||
@@ -979,6 +1059,8 @@ class NPUModelRunner(LoRAModelRunnerMixin):
|
||||
return None
|
||||
if self.attn_mask_builder is None:
|
||||
raise ValueError("Attn mask builder is None")
|
||||
if self.dcp_size > 1:
|
||||
return self.attn_mask_builder.get_splitfuse_attn_mask()
|
||||
# Pooling situation.
|
||||
if self.model_config.runner_type == "pooling" and self.model_config.pooler_config.pooling_type == "CLS":
|
||||
return self.attn_mask_builder.get_pooling_mask(self.device)
|
||||
@@ -1378,6 +1460,49 @@ class NPUModelRunner(LoRAModelRunnerMixin):
|
||||
scheduler_output,
|
||||
decode_threshold=self.reorder_batch_threshold)
|
||||
|
||||
def generate_kv_idx(self, tokens, scheduler_output):
    """Build ``self.cp_kv_recover_idx_for_chunk`` for the current batch.

    Does nothing unless ``self.pcp_size > 1``. For every request that is
    still in prefill (computed tokens < prompt tokens), the number of
    scheduled tokens is padded up to a multiple of ``2 * pcp_size`` and split
    into ``2 * pcp_size`` equal chunks. Rank ``r`` is assigned chunk ``r``
    counted from the front and chunk ``r`` counted from the back of that
    padded range. The per-rank index lists are then concatenated and
    argsorted, and the resulting int32 tensor is stored on
    ``self.cp_kv_recover_idx_for_chunk``.

    Args:
        tokens: Unused by this method; kept for caller compatibility.
        scheduler_output: Object whose ``num_scheduled_tokens`` maps a
            request id to the number of tokens scheduled for it this step.
    """
    if not self.pcp_size > 1:
        return
    self.cp_kv_recover_idx_for_chunk = [[] for _ in range(self.pcp_size)]

    # Loop-invariant index space. A ``range`` slices exactly like the
    # equivalent list (including clipping at the upper bound) without
    # materialising every index once per request.
    index_limit = (self.max_num_tokens * self.pcp_size * self.dcp_size +
                   self.pcp_size * self.dcp_size * self.max_num_reqs)
    all_indices = range(index_limit)
    num_splits = 2 * self.pcp_size

    for i, req_id in enumerate(self.input_batch.req_ids):
        is_prefill = (self.input_batch.num_computed_tokens_cpu[i] <
                      self.input_batch.num_prompt_tokens[i])
        if not is_prefill:
            continue

        num_scheduled_tokens = scheduler_output.num_scheduled_tokens[req_id]
        num_cp_padded_scheduled_tokens = cdiv(num_scheduled_tokens,
                                              num_splits) * num_splits
        chunk_size = num_cp_padded_scheduled_tokens // num_splits
        # Offset of this request in the global index space. Must be read
        # before the rank loop below starts extending rank 0's list.
        num_added_recover_tokens = len(
            self.cp_kv_recover_idx_for_chunk[0]) * self.pcp_size

        for rank in range(self.pcp_size):
            rank_indices = self.cp_kv_recover_idx_for_chunk[rank]
            # Head chunk: the rank-th chunk from the front.
            head_start = rank * chunk_size + num_added_recover_tokens
            rank_indices.extend(all_indices[head_start:head_start +
                                            chunk_size])
            # Tail chunk: the rank-th chunk from the back.
            tail_start = (num_cp_padded_scheduled_tokens -
                          (rank + 1) * chunk_size +
                          num_added_recover_tokens)
            rank_indices.extend(all_indices[tail_start:tail_start +
                                            chunk_size])

    cp_kv_recover_idx_for_chunk = torch.from_numpy(
        np.concatenate(
            self.cp_kv_recover_idx_for_chunk)).to(device=self.device)
    # NOTE(review): this appears to re-write the same values that the tensor
    # was just created from; kept unchanged to preserve the original device
    # copy sequence -- confirm whether it can be dropped.
    cp_kv_recover_idx_for_chunk.copy_(torch.tensor(
        np.array(self.cp_kv_recover_idx_for_chunk).flatten().tolist()),
                                      non_blocking=True)
    # The argsort is performed on a float32 copy and the result narrowed to
    # int32.
    self.cp_kv_recover_idx_for_chunk = cp_kv_recover_idx_for_chunk.to(
        torch.float32).argsort().to(torch.int32)
|
||||
|
||||
def _prepare_inputs(
|
||||
self,
|
||||
scheduler_output: "SchedulerOutput",
|
||||
@@ -1406,7 +1531,7 @@ class NPUModelRunner(LoRAModelRunnerMixin):
|
||||
self.input_batch.num_computed_tokens_cpu[req_indices],
|
||||
arange,
|
||||
)
|
||||
|
||||
self.generate_kv_idx(tokens, scheduler_output)
|
||||
self.input_batch.block_table.compute_slot_mapping(
|
||||
req_indices, positions_np)
|
||||
self.input_batch.block_table.commit_slot_mapping(
|
||||
@@ -1610,15 +1735,19 @@ class NPUModelRunner(LoRAModelRunnerMixin):
|
||||
]
|
||||
num_tokens_np = np.array(num_tokens, dtype=np.int32)
|
||||
num_reqs = self.input_batch.num_reqs
|
||||
if self.pcp_size == 1:
|
||||
discard_requests_mask = self.seq_lens_np[:num_reqs] < num_tokens_np
|
||||
else:
|
||||
if self.pcp_size > 1:
|
||||
# while pcp > 1, we need the original num_scheduled_tokens before split
|
||||
# to calculate discard_requests_mask
|
||||
tokens_original = [
|
||||
scheduler_output.num_scheduled_tokens[i] for i in req_ids
|
||||
]
|
||||
original_seq_lens_np = (
|
||||
self.input_batch.num_computed_tokens_cpu[:num_reqs] +
|
||||
np.array(list(scheduler_output.num_scheduled_tokens.values())))
|
||||
np.array(tokens_original, dtype=np.int32))
|
||||
discard_requests_mask = original_seq_lens_np < num_tokens_np
|
||||
else:
|
||||
discard_requests_mask = self.seq_lens_np[:num_reqs] < num_tokens_np
|
||||
|
||||
discard_request_indices = np.nonzero(discard_requests_mask)[0]
|
||||
self.num_discarded_requests = len(discard_request_indices)
|
||||
self.discard_request_indices.np[:self.num_discarded_requests] = (
|
||||
@@ -1762,8 +1891,17 @@ class NPUModelRunner(LoRAModelRunnerMixin):
|
||||
scheduler_output.num_scheduled_tokens)
|
||||
|
||||
# prepare pcp meta data
|
||||
# For chunked prefill, use num_scheduled_tokens instead of cumulative seq_lens
|
||||
# to correctly calculate chunk_len in _generate_pcp_metadata
|
||||
if self.vllm_config.scheduler_config.chunked_prefill_enabled and self.pcp_size > 1:
|
||||
# In chunked prefill, seq_lens_for_chunk should be the current chunk size
|
||||
seq_lens_for_chunk = torch.from_numpy(
|
||||
num_scheduled_tokens[:num_reqs])
|
||||
else:
|
||||
# Normal mode: use cumulative sequence lengths
|
||||
seq_lens_for_chunk = seq_lens_cpu
|
||||
long_seq_metadata = self._generate_pcp_metadata(
|
||||
total_num_scheduled_tokens, seq_lens_cpu)
|
||||
total_num_scheduled_tokens, seq_lens_for_chunk, seq_lens_cpu)
|
||||
# Prepare the attention metadata for each KV cache group and make layers
|
||||
# in the same group share the same metadata.
|
||||
for kv_cache_group_id, kv_cache_group_spec in enumerate(
|
||||
@@ -2690,7 +2828,7 @@ class NPUModelRunner(LoRAModelRunnerMixin):
|
||||
dtype=torch.int32,
|
||||
device=self.device)
|
||||
long_seq_metadata = self._generate_pcp_metadata(
|
||||
num_tokens, self.seq_lens_cpu)
|
||||
num_tokens, self.seq_lens_cpu, self.seq_lens_cpu)
|
||||
if long_seq_metadata is not None:
|
||||
pcp_world_size = get_pcp_group(
|
||||
).world_size if prefill_context_parallel_enable() else 1
|
||||
@@ -4266,23 +4404,149 @@ class NPUModelRunner(LoRAModelRunnerMixin):
|
||||
[-1, pcp_world_size, dcp_world_size])
|
||||
return dcp_local_seq_lens
|
||||
|
||||
def _generate_pcp_metadata(self, total_num_scheduled_tokens, seq_lens):
|
||||
def get_split_computed_tokens(
    self,
    num_computed_tokens: np.ndarray,
    request_ids: Optional[List[str]] = None,
    request_start_rank_dict: Optional[Dict[str, tuple[int, int]]] = None,
    cp_kv_cache_interleave_size: int = 1
) -> list[Optional[list[Optional[list[int]]]]]:
    """Split per-request token counts across the pcp and dcp ranks.

    Ranks are numbered ``pcp_idx * dcp_world_size + dcp_idx``. Tokens are
    handed out in blocks of ``cp_kv_cache_interleave_size``. A partially
    filled block left by a previous chunk is topped up first, then whole
    rounds over all ranks are distributed evenly, and the remaining blocks
    are assigned round-robin starting from the request's start rank.

    As a side effect this refreshes ``self.pcp_world_size`` and
    ``self.dcp_world_size``.

    Args:
        num_computed_tokens: Number of tokens for each request (current
            chunk, not cumulative).
        request_ids: Request ids, one per entry of ``num_computed_tokens``.
            Required despite the ``None`` default.
        request_start_rank_dict: Maps a request id to
            ``(start_rank, tokens_blank)``, i.e. the rank to start from and
            the number of free token slots left in that rank's last block.
            Updated in place with the state to use for the next chunk.
            When omitted, a fresh dict is used for this call only.
        cp_kv_cache_interleave_size: Number of consecutive tokens placed on
            one rank before moving to the next rank.

    Returns:
        One ``[pcp_world_size][dcp_world_size]`` nested list of token counts
        per request. Requests with a non-positive count keep all zeros.
    """
    self.pcp_world_size = get_pcp_group(
    ).world_size if prefill_context_parallel_enable() else 1
    self.dcp_world_size = get_dcp_group().world_size

    # A mutable ``{}`` default would be shared (and mutated) across calls.
    if request_start_rank_dict is None:
        request_start_rank_dict = {}

    num_requests = len(num_computed_tokens)
    assert request_ids is not None and len(request_ids) == num_requests

    pcp_world_size = self.pcp_world_size
    dcp_world_size = self.dcp_world_size
    local_chunked_kv_lens = [[[0] * dcp_world_size
                              for _ in range(pcp_world_size)]
                             for _ in range(num_requests)]
    total_ranks = pcp_world_size * dcp_world_size
    virtual_size = total_ranks * cp_kv_cache_interleave_size

    for req_idx, (req_id, raw_tokens) in enumerate(
            zip(request_ids, num_computed_tokens)):
        if raw_tokens <= 0:
            continue
        # Work with plain ints so numpy scalars do not leak into the result
        # lists or into the start-rank dict.
        total_tokens = int(raw_tokens)
        req_lens = local_chunked_kv_lens[req_idx]

        start_rank, tokens_blank = request_start_rank_dict.get(
            req_id, (0, 0))

        if tokens_blank > 0:
            # Continue writing into the last block of the previous chunk.
            consumed_tokens = min(tokens_blank, total_tokens)
            total_tokens -= consumed_tokens
            tokens_blank -= consumed_tokens
            req_lens[start_rank // dcp_world_size][
                start_rank % dcp_world_size] += consumed_tokens
            if tokens_blank == 0:
                start_rank = (start_rank + 1) % total_ranks
            if total_tokens == 0:
                request_start_rank_dict[req_id] = (start_rank, tokens_blank)
                continue

        base, remainder = divmod(total_tokens, virtual_size)

        # Whole rounds: every rank receives the same number of tokens.
        for rank_idx in range(total_ranks):
            req_lens[rank_idx // dcp_world_size][
                rank_idx % dcp_world_size] += (base *
                                               cp_kv_cache_interleave_size)

        if remainder == 0:
            request_start_rank_dict[req_id] = (start_rank, tokens_blank)
            continue

        remain_blocks = cdiv(remainder, cp_kv_cache_interleave_size)
        tail_tokens = remainder % cp_kv_cache_interleave_size

        # Remaining blocks go round-robin starting from start_rank; only the
        # last one may be partially filled.
        for i in range(remain_blocks):
            rank = (start_rank + i) % total_ranks
            pcp_idx = rank // dcp_world_size
            dcp_idx = rank % dcp_world_size
            if i < remain_blocks - 1 or tail_tokens == 0:
                req_lens[pcp_idx][dcp_idx] += cp_kv_cache_interleave_size
                tokens_blank = 0
            else:
                req_lens[pcp_idx][dcp_idx] += tail_tokens
                tokens_blank = cp_kv_cache_interleave_size - tail_tokens

        # Stay on the last used rank if its block still has free slots,
        # otherwise advance to the following rank.
        start_rank = (start_rank + remain_blocks - 1) % total_ranks
        if tokens_blank == 0:
            start_rank = (start_rank + 1) % total_ranks

        request_start_rank_dict[req_id] = (start_rank, tokens_blank)

    return cast(List[Optional[List[Optional[List[int]]]]],
                local_chunked_kv_lens)
|
||||
|
||||
def _get_chunked_req_mask_and_max_chunk(
|
||||
self,
|
||||
local_chunked_kv_lens: Optional[list[Optional[list[Optional[list[
|
||||
Optional[list[int]]]]]]]] = None
|
||||
) -> Tuple[List[bool], int]:
|
||||
"""
|
||||
given 4-d list [req][chunk][pcp][dcp], return:
|
||||
1. if each req has any chunk (list[bool])
|
||||
2. max chunk num along all reqs (int)
|
||||
"""
|
||||
assert local_chunked_kv_lens is not None
|
||||
if len(local_chunked_kv_lens) == 0:
|
||||
return ([], 0)
|
||||
mask_for_non_zero_chunk = [
|
||||
len(req) > 0 for req in local_chunked_kv_lens if req is not None
|
||||
]
|
||||
max_chunk_num = max(
|
||||
(len(req) for req in local_chunked_kv_lens if req is not None),
|
||||
default=0)
|
||||
return mask_for_non_zero_chunk, max_chunk_num
|
||||
|
||||
def _generate_pcp_metadata(self, total_num_scheduled_tokens, seq_lens,
|
||||
seq_lens_origin):
|
||||
num_reqs = self.input_batch.num_reqs
|
||||
num_decodes = sum(self.input_batch.num_computed_tokens_cpu[:num_reqs]
|
||||
>= self.input_batch.num_prompt_tokens[:num_reqs])
|
||||
num_actual_tokens_pcp_padded = total_num_scheduled_tokens * self.pcp_size
|
||||
self.num_actual_tokens_pcp_padded = num_actual_tokens_pcp_padded
|
||||
local_chunked_kv_lens = self.input_batch.local_chunked_kv_lens[
|
||||
num_decodes:num_reqs]
|
||||
mask_for_non_zero_chunk, max_chunk_num = self._get_chunked_req_mask_and_max_chunk(
|
||||
local_chunked_kv_lens)
|
||||
long_seq_metadata = None
|
||||
if self.pcp_size * self.dcp_size > 1:
|
||||
long_seq_metadata = AscendPrefillContextParallelMetadata(
|
||||
num_actual_tokens_pcp_padded=num_actual_tokens_pcp_padded,
|
||||
num_computed_tokens_of_pcp_dcp=self._get_pcp_local_seq_lens(
|
||||
seq_lens,
|
||||
seq_lens_origin,
|
||||
self.pcp_size,
|
||||
self.dcp_size,
|
||||
self.parallel_config.cp_kv_cache_interleave_size,
|
||||
).numpy(),
|
||||
)
|
||||
local_chunked_kv_lens=local_chunked_kv_lens,
|
||||
mask_for_non_zero_chunk=mask_for_non_zero_chunk,
|
||||
max_chunk_num=max_chunk_num)
|
||||
if self.pcp_size > 1:
|
||||
q_head_idx, q_tail_idx = [], []
|
||||
kv_with_q_head_nomask_idx, kv_with_q_head_mask_idx = [], []
|
||||
@@ -4393,6 +4657,7 @@ class NPUModelRunner(LoRAModelRunnerMixin):
|
||||
}
|
||||
long_seq_metadata.pcp_allgather_restore_idx = self.pcp_allgather_restore_idx[:
|
||||
num_actual_tokens_pcp_padded]
|
||||
long_seq_metadata.cp_kv_recover_idx_for_chunk = self.cp_kv_recover_idx_for_chunk
|
||||
long_seq_metadata.q_head_idx_tensor = self.q_head_idx_tensor
|
||||
long_seq_metadata.q_tail_idx_tensor = self.q_tail_idx_tensor
|
||||
long_seq_metadata.q_full_idx = self.q_full_idx
|
||||
|
||||
Reference in New Issue
Block a user