[bugfix](CP) Fix and unify the PD request discrimination logic. (#5939)

### What this PR does / why we need it?
vLLM PR https://github.com/vllm-project/vllm/pull/32118 changed the criteria
used to classify Prefill and Decode requests, so PCPManager needs to follow
the same standard. Because PCPManager previously computed the PD request
counts in several places, this PR consolidates the related logic so the PD
request counts are computed once per batch and reused afterwards.
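For illustration, here is a minimal standalone sketch of the unified discrimination step, mirroring the new `PCPManager.init_batch_info` shown in the diff below. The function name `classify_pd_requests` is hypothetical; the real method stores the counts on the manager instead of returning them.

```python
import numpy as np


def classify_pd_requests(num_scheduled_tokens: np.ndarray, num_reqs: int,
                         decode_threshold: int) -> tuple[int, int, int]:
    """Classify a reordered batch (decodes first) into decode/prefill counts.

    A request is treated as prefill when it schedules more than
    ``decode_threshold`` tokens, matching the criterion from vLLM PR #32118.
    """
    is_prefill = num_scheduled_tokens[:num_reqs] > decode_threshold
    # Decodes precede prefills after batch reordering, so the index of the
    # first prefill request equals the number of decode requests.
    first_prefill = num_reqs if not is_prefill.any() else int(is_prefill.argmax())
    num_decode_reqs = first_prefill
    num_prefill_reqs = num_reqs - num_decode_reqs
    num_decode_tokens = int(num_scheduled_tokens[:num_decode_reqs].sum())
    return num_decode_reqs, num_prefill_reqs, num_decode_tokens
```

Downstream consumers such as `update_tokens_for_pcp`, `get_logits_indices`, and the MTP input builder then reuse these stored counts instead of recomputing them on every call.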

### How was this patch tested?
```bash
pytest tests/e2e/multicard/4-cards/long_sequence/test_mtp.py
```
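For context, the refactored e2e tests in this commit hoist the shared prompts and model to module scope and drop the skip markers; a condensed sketch based on the diff below (most `VllmRunner` arguments, e.g. parallel and speculative configs, are omitted here):

```python
import os

from tests.e2e.conftest import VllmRunner, wait_until_npu_memory_free

os.environ["HCCL_BUFFSIZE"] = "512"

prompts = [
    "The capital of France is", "Hello, my name is Tom, I am",
    "The president of United States is", "AI future is"
]
model = "wemaster/deepseek_mtp_main_random_bf16"


@wait_until_npu_memory_free()
def test_pcp_dcp_mtp1_eager():
    # Only a subset of the real arguments is shown; the actual test also
    # configures parallelism and the MTP speculative config.
    with VllmRunner(model, max_model_len=1024) as runner:
        runner.generate_greedy(prompts, 32)
```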

- vLLM version: v0.13.0
- vLLM main: 11b6af5280

Signed-off-by: QiuChunshuo <qiuchunshuo@huawei.com>
Author: Qiu
Date: 2026-01-31 10:26:02 +08:00
Committed by: GitHub
Parent: 4230bc8646
Commit: 638cae824d
4 changed files with 88 additions and 111 deletions


@@ -20,17 +20,18 @@
import os
import pytest
from tests.e2e.conftest import VllmRunner
from tests.e2e.conftest import VllmRunner, wait_until_npu_memory_free
os.environ["HCCL_BUFFSIZE"] = "512"
prompts = [
"The capital of France is", "Hello, my name is Tom, I am",
"The president of United States is", "AI future is"
]
model = "wemaster/deepseek_mtp_main_random_bf16"
@wait_until_npu_memory_free()
def test_pcp_dcp_mtp1_eager():
prompts = [
"The capital of France is", "Hello, my name is Tom, I am",
"The president of United States is", "AI future is"
]
model = "wemaster/deepseek_mtp_main_random_bf16"
with VllmRunner(
model,
max_model_len=1024,
@@ -50,15 +51,8 @@ def test_pcp_dcp_mtp1_eager():
runner.generate_greedy(prompts, 32)
@pytest.mark.skip(
reason="vLLM PR-32118 break this",
)
@wait_until_npu_memory_free()
def test_pcp_dcp_mtp3_eager():
prompts = [
"The capital of France is", "Hello, my name is Tom, I am",
"The president of United States is", "AI future is"
]
model = "wemaster/deepseek_mtp_main_random_bf16"
with VllmRunner(
model,
max_model_len=1024,
@@ -78,15 +72,8 @@ def test_pcp_dcp_mtp3_eager():
runner.generate_greedy(prompts, 32)
@pytest.mark.skip(
reason="vLLM PR-32118 break this",
)
@wait_until_npu_memory_free()
def test_pcp_dcp_mtp3_piecewise_graph():
prompts = [
"The capital of France is", "Hello, my name is Tom, I am",
"The president of United States is", "AI future is"
]
model = "wemaster/deepseek_mtp_main_random_bf16"
with VllmRunner(
model,
max_model_len=1024,
@@ -109,15 +96,8 @@ def test_pcp_dcp_mtp3_piecewise_graph():
runner.generate_greedy(prompts, 32)
@pytest.mark.skip(
reason="vLLM PR-32118 break this",
)
@wait_until_npu_memory_free()
def test_pcp_dcp_mtp3_full_graph():
prompts = [
"The capital of France is", "Hello, my name is Tom, I am",
"The president of United States is", "AI future is"
]
model = "wemaster/deepseek_mtp_main_random_bf16"
with VllmRunner(
model,
max_model_len=1024,
@@ -140,12 +120,8 @@ def test_pcp_dcp_mtp3_full_graph():
runner.generate_greedy(prompts, 32)
@wait_until_npu_memory_free()
def test_dcp_mtp3_full_graph():
prompts = [
"The capital of France is", "Hello, my name is Tom, I am",
"The president of United States is", "AI future is"
]
model = "wemaster/deepseek_mtp_main_random_bf16"
with VllmRunner(
model,
max_model_len=1024,


@@ -141,8 +141,10 @@ def test_update_tokens_for_pcp_basic(tokens, num_reqs, num_computed_tokens,
dtype=np.int32)
input_batch.num_prompt_tokens = np.array(num_prompt_tokens, dtype=np.int32)
arange_np = np.arange(10000)
num_scheduled_tokens = np.array(tokens)
pcp_manager.init_batch_info(num_scheduled_tokens, num_reqs)
pcp_tokens_result, positions_result = pcp_manager.update_tokens_for_pcp(
np.array(tokens), arange_np, num_reqs, 1)
num_scheduled_tokens, arange_np)
assert np.array_equal(pcp_tokens_result, expected_pcp_tokens), \
f"Expected pcp_tokens: {expected_pcp_tokens}, got: {pcp_tokens_result}"
@@ -305,8 +307,8 @@ def test_generate_pcp_mtp_input(
for i, token_ids_tensor in enumerate(token_ids_tensor_list):
token_ids_cpu_tensor[i][:token_ids_tensor.size(0)] = token_ids_tensor
pcp_manager.generate_pcp_mtp_input(num_reqs, total_num_scheduled_tokens,
num_scheduled_tokens, False,
pcp_manager.init_batch_info(np.array(list(num_scheduled_tokens.values())), num_reqs)
pcp_manager.generate_pcp_mtp_input(total_num_scheduled_tokens, num_scheduled_tokens, False,
input_batch, arange_np)
assert torch.equal(
pcp_manager.input_ids_pcp_full.cpu[:total_num_scheduled_tokens],


@@ -600,10 +600,16 @@ class NPUModelRunner(GPUModelRunner):
req_indices, positions_np)
self.input_batch.block_table.commit_slot_mapping(
total_num_scheduled_tokens)
if self.pcp_size * self.dcp_size > 1:
self.pcp_manager.init_batch_info(
num_scheduled_tokens,
self.input_batch.num_reqs,
)
# For PCP, prefill MTP should use the original scheduler_output,
if self.speculative_config and self.pcp_size * self.dcp_size > 1:
self.pcp_manager.generate_pcp_mtp_input(
num_reqs,
total_num_scheduled_tokens,
scheduler_output.num_scheduled_tokens,
with_prefill,
@@ -621,8 +627,6 @@ class NPUModelRunner(GPUModelRunner):
num_reqs], position_pcp = self.pcp_manager.update_tokens_for_pcp(
num_scheduled_tokens[:num_reqs],
self.arange_np,
self.input_batch.num_reqs,
self.reorder_batch_threshold,
)
# Re-update after PCP split sequences.
total_num_scheduled_tokens = sum(num_scheduled_tokens)
@@ -772,8 +776,7 @@ class NPUModelRunner(GPUModelRunner):
num_draft_tokens = None
num_sampled_tokens = np.ones(num_reqs, dtype=np.int32)
if self.pcp_size * self.dcp_size > 1:
logits_indices = self.pcp_manager.get_logits_indices(
cu_num_tokens, num_reqs)
logits_indices = self.pcp_manager.get_logits_indices(cu_num_tokens)
logits_indices = logits_indices.pin_memory().to(
self.device, non_blocking=True)
else:
@@ -1020,9 +1023,8 @@ class NPUModelRunner(GPUModelRunner):
num_reqs = self.input_batch.num_reqs
ori_query_lens = query_start_loc_pcp_full_cpu[1:num_reqs+1] - \
query_start_loc_pcp_full_cpu[:num_reqs]
num_prefill_reqs = (ori_query_lens
> self.decode_threshold).sum().item()
num_decode_reqs = num_reqs - num_prefill_reqs
num_prefill_reqs = self.pcp_manager.num_prefill_reqs
num_decode_reqs = self.pcp_manager.num_decode_reqs
else:
long_seq_metadata = None # type: ignore
num_prefill_reqs = 0
@@ -1976,7 +1978,7 @@ class NPUModelRunner(GPUModelRunner):
)
return blk_table_tensor, slot_mapping
long_seq_metdadata = _get_pcp_metadata(num_tokens)
self.long_seq_metadata = _get_pcp_metadata(num_tokens)
block_table_gid_0, slot_mapping_gid_0 = _get_block_table_and_slot_mapping(0)
actual_last_loc = self.query_start_loc.np[num_reqs_padded]
@@ -2008,7 +2010,7 @@ class NPUModelRunner(GPUModelRunner):
positions=self.positions.gpu,
attn_state=self.attn_state,
decode_token_per_req=self.decode_token_per_req,
prefill_context_parallel_metadata=long_seq_metdadata,
prefill_context_parallel_metadata=self.long_seq_metadata,
)
if logits_indices is not None and self.cache_config.kv_sharing_fast_prefill:
@@ -2198,6 +2200,11 @@ class NPUModelRunner(GPUModelRunner):
force_has_lora=activate_lora,
)
)
if self.pcp_size * self.dcp_size > 1:
self.pcp_manager.init_batch_info(
num_scheduled_tokens,
num_reqs,
)
if cudagraph_runtime_mode is None:
cudagraph_runtime_mode = _cudagraph_mode
else:


@@ -36,6 +36,10 @@ class PCPManager:
This manager encapsulates all PCP-related buffers and logic so that the
ModelRunner can access them via `self.pcp_manager`.
"""
num_reqs: int = 0
num_decode_reqs: int = 0
num_prefill_reqs: int = 0
num_decode_tokens: int = 0
def __init__(
self,
@@ -133,12 +137,25 @@ class PCPManager:
return cu_num_tokens, arange
def init_batch_info(
self,
num_scheduled_tokens: np.ndarray,
num_reqs: int,
) -> None:
self.num_reqs = num_reqs
is_prefill = (num_scheduled_tokens[:num_reqs] > self.decode_threshold)
if not any(is_prefill):
first_prefill = num_reqs
else:
first_prefill = is_prefill.argmax()
self.num_decode_reqs = first_prefill
self.num_prefill_reqs = num_reqs - self.num_decode_reqs
self.num_decode_tokens = num_scheduled_tokens[:self.num_decode_reqs].sum()
def update_tokens_for_pcp(
self,
num_scheduled_tokens: np.ndarray,
arange_np: np.ndarray,
num_reqs: int,
reorder_batch_threshold: int | None = None,
) -> tuple[np.ndarray, np.ndarray]:
"""
Update token counts and positions for Prefill Context Parallelism (PCP).
@@ -167,8 +184,6 @@ class PCPManager:
the number of new tokens scheduled per request.
arange_np: 1D numpy array of length max_buffer_num_tokens used for
efficient batched arange operations.
num_reqs: Total number of requests in the batch.
reorder_batch_threshold: Threshold for decode vs prefill requests.
Returns:
Tuple (pcp_tokens, pcp_positions):
@@ -187,16 +202,10 @@ class PCPManager:
>>> self.pcp_unpad_mask_cpu
[True, False, True, True, True, True, True, False, False,
False, True, True, True, True, True, True, True, True]
>>> self.pcp_allgather_resotre_idx
>>> self.pcp_allgather_restore_idx
[0, 9, 1, 2, 10, 11, 12, 13, 3, 4, 5, 6, 14, 15, 16, 17, 7, 8]
"""
assert reorder_batch_threshold is not None, (
"PCP depends on reorder batch to split decode and prefill requests."
)
num_decode_reqs = sum(num_scheduled_tokens <= reorder_batch_threshold)
num_decode_tokens = sum(num_scheduled_tokens[:num_decode_reqs])
# DualChunkSwap requires alignment to a multiple of (2 * pcp_world_size).
# We first pad each request's token count up to that multiple.
num_padded_scheduled_tokens = np.ceil(
@@ -205,11 +214,11 @@ class PCPManager:
# PCP does not split decode requests. For decode requests, we instead
# duplicate the scheduled tokens across the pcp_world_size ranks.
num_padded_scheduled_tokens[:num_decode_reqs] = (
num_scheduled_tokens[:num_decode_reqs] * self.pcp_world_size)
num_padded_scheduled_tokens[:self.num_decode_reqs] = (
num_scheduled_tokens[:self.num_decode_reqs] * self.pcp_world_size)
# Record how many pads were added per request (padded - original).
self.num_pcp_pads_cpu[:num_reqs] = (num_padded_scheduled_tokens -
self.num_pcp_pads_cpu[:self.num_reqs] = (num_padded_scheduled_tokens -
num_scheduled_tokens)
# cu_padded_tokens: cumulative sum of padded token counts,
@@ -221,7 +230,7 @@ class PCPManager:
self.pcp_unpad_mask_cpu[:pcp_padded_arange.shape[0]] = (
pcp_padded_arange < np.repeat(num_scheduled_tokens,
num_padded_scheduled_tokens))
unpad_mask_decode = self.pcp_unpad_mask_cpu[:num_decode_tokens *
unpad_mask_decode = self.pcp_unpad_mask_cpu[:self.num_decode_tokens *
self.pcp_world_size]
unpad_mask_decode = unpad_mask_decode.reshape(
[-1, self.pcp_world_size])
@@ -233,7 +242,7 @@ class PCPManager:
# For prefill requests, we further split the pcp_tokens into two chunks
# (head and tail). For decode requests, the chunk equals pcp_tokens.
pcp_chunk_sizes = (pcp_tokens // 2).clip(min=1)
pcp_chunk_sizes[:num_decode_reqs] = pcp_tokens[:num_decode_reqs]
pcp_chunk_sizes[:self.num_decode_reqs] = pcp_tokens[:self.num_decode_reqs]
# Build arange-style helpers for pcp tokens and chunk sizes:
# - pcp_arange gives indices repeated for each token in pcp_tokens
@@ -271,16 +280,16 @@ class PCPManager:
# Fill tail positions. Note decode requests do not have tail chunks,
# so the tail filling is only for prefill positions.
positions[~pcp_head_chunk_mask] = (
pcp_chunk_arange[num_decode_tokens:] +
np.repeat(tail_start_loc, pcp_chunk_sizes)[num_decode_tokens:])
pcp_chunk_arange[self.num_decode_tokens:] +
np.repeat(tail_start_loc, pcp_chunk_sizes)[self.num_decode_tokens:])
return positions
positions = get_current_rank_positions(0, self.pcp_world_rank)
# Decode tokens are duplicated only after AG. But their positions are
# same without prefill context parallel.
if num_decode_reqs > 0:
positions[:num_decode_tokens] = self._get_cumsum_and_arange(
num_scheduled_tokens[:num_decode_reqs], arange_np)[1]
if self.num_decode_reqs > 0:
positions[:self.num_decode_tokens] = self._get_cumsum_and_arange(
num_scheduled_tokens[:self.num_decode_reqs], arange_np)[1]
# Build the restore index used after allgather.
padded_pos_start_loc = np.roll(cu_padded_tokens, 1)
@@ -294,27 +303,16 @@ class PCPManager:
all_positions.argsort())
self.pcp_allgather_restore_idx.copy_to_gpu(all_positions.shape[0])
self.pcp_tokens[:num_reqs] = pcp_tokens[:num_reqs]
self.total_num_sampled_tokens_pcp = pcp_tokens[:num_reqs].sum()
self.pcp_tokens[:self.num_reqs] = pcp_tokens[:self.num_reqs]
self.total_num_sampled_tokens_pcp = pcp_tokens[:self.num_reqs].sum()
return (
pcp_tokens[:num_reqs],
pcp_tokens[:self.num_reqs],
positions,
)
def get_logits_indices(self, cu_num_tokens: np.ndarray, num_reqs: int):
def get_logits_indices(self, cu_num_tokens: np.ndarray):
return (torch.from_numpy(cu_num_tokens) * self.pcp_world_size -
self.num_pcp_pads_cpu_tensor[:num_reqs] - 1)
def get_discard_request_mask(
self,
num_computed_tokens_cpu: np.ndarray,
num_scheduled_tokens: np.ndarray,
num_reqs: int,
num_tokens_np: np.ndarray,
):
return (num_computed_tokens_cpu[:num_reqs] +
num_scheduled_tokens * self.pcp_world_size -
self.num_pcp_pads_cpu[:num_reqs]) < num_tokens_np
self.num_pcp_pads_cpu_tensor[:self.num_reqs] - 1)
def get_padded_slot_mapping(self, num_tokens: int, num_tokens_padded: int,
slot_mapping: torch.Tensor):
@@ -350,7 +348,6 @@ class PCPManager:
def generate_pcp_mtp_input(
self,
num_reqs: int,
total_num_scheduled_tokens: int,
num_scheduled_tokens: dict[str, int],
with_prefill: bool = True,
@@ -369,18 +366,18 @@ class PCPManager:
so we record original input_ids here.
"""
total_num_scheduled_tokens_pcp_full = total_num_scheduled_tokens
num_scheduled_tokens_pcp_full = np.empty(num_reqs, dtype=np.int32)
num_scheduled_tokens_pcp_full = np.empty(self.num_reqs, dtype=np.int32)
for i, req_id in enumerate(input_batch.req_ids):
num_scheduled_tokens_pcp_full[i] = num_scheduled_tokens[req_id]
self.query_lens_pcp_full.cpu[:num_reqs] = torch.from_numpy(
self.query_lens_pcp_full.cpu[:self.num_reqs] = torch.from_numpy(
num_scheduled_tokens_pcp_full)
req_indices_pcp_full = np.repeat(arange_np[:num_reqs],
req_indices_pcp_full = np.repeat(arange_np[:self.num_reqs],
num_scheduled_tokens_pcp_full)
cu_num_tokens_pcp_full = np.cumsum(num_scheduled_tokens_pcp_full)
self.query_start_loc_pcp_full.np[0] = 0
self.query_start_loc_pcp_full.np[1:num_reqs +
self.query_start_loc_pcp_full.np[1:self.num_reqs +
1] = cu_num_tokens_pcp_full
self.query_start_loc_pcp_full.np[num_reqs + 1:].fill(-1)
self.query_start_loc_pcp_full.np[self.num_reqs + 1:].fill(-1)
cumsums_offsets_pcp_full = np.repeat(
cu_num_tokens_pcp_full - num_scheduled_tokens_pcp_full,
num_scheduled_tokens_pcp_full)
@@ -413,13 +410,13 @@ class PCPManager:
if self.decode_threshold > 2 and not with_prefill:
num_tokens_ori = sum(list(num_scheduled_tokens.values()))
num_tokens_mtp = \
num_tokens_ori + num_reqs * (self.decode_threshold - 2)
num_tokens_ori + self.num_reqs * (self.decode_threshold - 2)
num_tokens_mtp_pad = num_tokens_mtp * self.pcp_world_size
req_indices_split = np.array_split(req_indices,
cu_num_tokens)[:num_reqs]
cu_num_tokens)[:self.num_reqs]
positions_split = np.array_split(positions_np,
cu_num_tokens)[:num_reqs]
for req_idx in range(num_reqs):
cu_num_tokens)[:self.num_reqs]
for req_idx in range(self.num_reqs):
ori_req_indice = req_indices_split[req_idx]
ori_position = positions_split[req_idx]
req_indices_split[req_idx] = np.append(
@@ -567,25 +564,20 @@ class PCPManager:
input_batch, num_scheduled_tokens):
from vllm_ascend.attention.utils import \
AscendPrefillContextParallelMetadata
num_reqs = input_batch.num_reqs or query_lens.size(0)
query_lens_new = self.query_lens_pcp_full.cpu[:num_reqs] \
if self.pcp_world_size > 1 and self.speculative_config else query_lens
num_decodes = (query_lens_new <= self.decode_threshold).sum().item()
num_prefills = num_reqs - num_decodes
num_actual_tokens_pcp_padded = total_num_scheduled_tokens * self.pcp_world_size
self.num_actual_tokens_pcp_padded = num_actual_tokens_pcp_padded
long_seq_metadata = None
if self.pcp_world_size * self.dcp_world_size > 1:
decode_context_lens = input_batch.num_computed_tokens_cpu[:
num_decodes] + num_scheduled_tokens[:
num_decodes]
self.num_decode_reqs] + num_scheduled_tokens[:
self.num_decode_reqs]
prefill_context_lens = input_batch.num_computed_tokens_cpu[
num_decodes:num_reqs]
self.num_decode_reqs:self.num_reqs]
context_lens = np.concatenate(
[decode_context_lens, prefill_context_lens])
num_computed_tokens_of_pcp_dcp = torch.zeros(
[
num_reqs * self.decode_threshold, self.pcp_world_size,
self.num_reqs * self.decode_threshold, self.pcp_world_size,
self.dcp_world_size
],
dtype=torch.int32,
@@ -605,24 +597,24 @@ class PCPManager:
)
if self.decode_threshold > 1:
num_computed_tokens_of_pcp_dcp_list = []
if num_decodes:
if self.num_decode_reqs:
num_decodes_flatten = \
query_lens[:num_decodes].sum().item()
if query_lens[:num_decodes].min().item(
query_lens[:self.num_decode_reqs].sum().item()
if query_lens[:self.num_decode_reqs].min().item(
) == self.decode_threshold:
decode_flatten_idx = list(range(num_decodes_flatten))
else:
decode_flatten_idx = []
for req_id in range(num_decodes):
for req_id in range(self.num_decode_reqs):
offset = (req_id + 1) * self.decode_threshold
decode_flatten_idx += \
list(range(offset - query_lens[req_id], offset))
num_computed_tokens_of_pcp_dcp_list.append(
num_computed_tokens_of_pcp_dcp[decode_flatten_idx])
if num_prefills:
if self.num_prefill_reqs:
num_computed_tokens_of_pcp_dcp_list.append(
num_computed_tokens_of_pcp_dcp[
(num_decodes + 1) * self.decode_threshold -
(self.num_decode_reqs + 1) * self.decode_threshold -
1::self.decode_threshold])
num_computed_tokens_of_pcp_dcp = torch.cat(
num_computed_tokens_of_pcp_dcp_list, dim=0)
@@ -643,7 +635,7 @@ class PCPManager:
q_head_chunk_id = self.pcp_world_rank
q_tail_chunk_id = self.pcp_world_size * 2 - 1 - self.pcp_world_rank
for i, seq_len in enumerate(query_lens):
if i < num_decodes:
if i < self.num_decode_reqs:
continue
chunk_len = seq_len // 2
chunk_seqlens.append(chunk_len)