[Bugfix] Fix bug in pcp+mtp+async scheduler (#5994)

### What this PR does / why we need it?
Fixes an issue where serving with PCP and MTP enabled could not start when asynchronous scheduling is used.

With pcp, mtp, and asynchronous scheduling all enabled, the service hangs on a shape mismatch once a curl request is sent. This PR resolves that issue.
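
For readers of the diff below: the core of the change is that, with async scheduling, the token sampled for an in-flight decode request is not yet present in the CPU input_ids buffer when the next step is prepared, so it (and, with MTP, the draft tokens) has to be scattered in from `prev_sampled_token_ids` at positions computed against the pcp-full token layout. A minimal standalone sketch of that fill-in idea, with illustrative names and values only (not the actual PCPManager code):

```python
import torch

# Hypothetical flattened CPU buffer of this step's token ids (int32, as in the
# patch). Zeros mark the placeholder slots of in-flight decode requests whose
# sampled token was not known yet when the batch was built under async scheduling.
input_ids_cpu = torch.zeros(8, dtype=torch.int32)

# Flattened positions of the last-token slot of each request that was also
# present in the previous iteration (derived from cumulative token counts).
sample_flattened_indices = torch.tensor([1, 4, 7], dtype=torch.int64)

# Tokens sampled for those requests in the previous iteration; in the patch,
# prev_sampled_token_ids has one column per request.
prev_sampled_token_ids = torch.tensor([[101], [202], [303]], dtype=torch.int32)

# Scatter the previously sampled tokens into their placeholder slots, mirroring
# the scatter_ call added in _update_input_ids_pcp_full_ids.
input_ids_cpu.scatter_(dim=0,
                       index=sample_flattened_indices,
                       src=prev_sampled_token_ids[:, 0])

print(input_ids_cpu)
# tensor([  0, 101,   0,   0, 202,   0,   0, 303], dtype=torch.int32)
```

The patch applies the same scatter pattern a second time for the MTP draft tokens.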

### Does this PR introduce _any_ user-facing change?

### How was this patch tested?

- vLLM version: v0.13.0
- vLLM main: 2c24bc6996

---------

Signed-off-by: weiguihua2 <weiguihua2@huawei.com>
Author: weiguihua2
Date: 2026-01-20 15:24:05 +08:00 (committed by GitHub)
Parent: ea57e3e7a4
Commit: 5892455f43
4 changed files with 138 additions and 12 deletions


@@ -17,7 +17,7 @@
 # Adapted from vllm-project/vllm/vllm/worker/worker.py
 #
-from typing import List
+from typing import TYPE_CHECKING, List
 
 import numpy as np
 import torch
@@ -25,6 +25,9 @@ from vllm.config import VllmConfig
 from vllm.utils.math_utils import cdiv
 from vllm.v1.utils import CpuGpuBuffer
 
+if TYPE_CHECKING:
+    from vllm.v1.core.sched.output import SchedulerOutput
+
 
 class PCPManager:
     """
@@ -44,6 +47,7 @@ class PCPManager:
         max_num_reqs: int,
         device: torch.device,
         vllm_config: VllmConfig,
+        use_async_scheduling: bool,
         pin_memory: bool = False,
     ) -> None:
         self.pcp_world_size = pcp_world_size
@@ -58,6 +62,7 @@ class PCPManager:
         self.max_num_tokens = self.vllm_config.scheduler_config.max_num_batched_tokens
         self.max_num_reqs = self.vllm_config.scheduler_config.max_num_seqs
         self.device = device
+        self.use_async_scheduling = use_async_scheduling
         self.pcp_allgather_restore_idx = CpuGpuBuffer(
             max_buffer_num_tokens,
             dtype=torch.int64,
@@ -354,6 +359,9 @@ class PCPManager:
         req_indices=None,
         positions_np=None,
         cu_num_tokens=None,
+        draft_token_ids=None,
+        scheduler_output=None,
+        num_spec_tokens=None,
     ):
         """
         While pcp > 1, model inputs (input_ids, position, etc.) are split across pcp group,
@@ -390,6 +398,12 @@ class PCPManager:
             torch.from_numpy(token_indices_pcp_full),
             out=self.input_ids_pcp_full.
             cpu[:total_num_scheduled_tokens_pcp_full])
+        if self.use_async_scheduling:
+            self._update_input_ids_pcp_full_ids(input_batch, draft_token_ids,
+                                                scheduler_output,
+                                                total_num_scheduled_tokens,
+                                                cu_num_tokens_pcp_full,
+                                                num_spec_tokens)
         self.query_lens_pcp_full.copy_to_gpu()
         self.query_start_loc_pcp_full.copy_to_gpu()
         self.input_ids_pcp_full.copy_to_gpu(
@@ -428,6 +442,99 @@ class PCPManager:
         mtp_slot_pad[unpad_mask] = mtp_slot_ori
         self.mtp_slot_pad = mtp_slot_pad.to(self.device, non_blocking=True)
 
+    def _update_input_ids_pcp_full_ids(
+        self,
+        input_batch,
+        draft_token_ids,
+        scheduler_output: "SchedulerOutput",
+        total_num_scheduled_tokens: int,
+        cu_num_tokens: np.ndarray,
+        num_spec_tokens: int,
+    ) -> None:
+        """Prepare the input IDs for the current batch.
+
+        Carefully handles the `prev_sampled_token_ids` which can be cached
+        from the previous engine iteration, in which case those tokens
+        need to be copied into the corresponding slots of input_ids."""
+        if (input_batch.prev_sampled_token_ids is None
+                or input_batch.prev_req_id_to_index is None):
+            return
+        # Async scheduling case, where some decode requests from the previous
+        # iteration won't have entries in the CPU input_ids buffer and need to
+        # be filled in from prev_sampled_token_ids.
+        prev_req_id_to_index = input_batch.prev_req_id_to_index
+        sample_flattened_indices: list[int] = []
+        spec_flattened_indices: list[int] = []
+        prev_common_req_indices: list[int] = []
+        prev_draft_token_indices: list[int] = []
+        total_num_spec_tokens = 0
+        scheduled_spec_tokens = scheduler_output.scheduled_spec_decode_tokens
+        for req_id, cur_index in input_batch.req_id_to_index.items():
+            if (prev_index := prev_req_id_to_index.get(req_id)) is not None:
+                prev_common_req_indices.append(prev_index)
+                # We need to compute the flattened input_ids index of the
+                # last token in each common request.
+                draft_len = len(scheduled_spec_tokens.get(req_id, ()))
+                total_num_spec_tokens += draft_len
+                flattened_index = cu_num_tokens[cur_index].item() - 1
+                # example: cu_num_tokens = [2, 5, 8], draft_tokens = [1, 2, 2]
+                # sample_flattened_indices = [0, 2, 5]
+                # spec_flattened_indices = [1, 3, 4, 6, 7]
+                sample_flattened_indices.append(flattened_index - draft_len)
+                spec_flattened_indices.extend(
+                    range(flattened_index - draft_len + 1,
+                          flattened_index + 1))
+                start = prev_index * num_spec_tokens
+                # prev_draft_token_indices is used to find which draft_tokens_id
+                # should be copied to input_ids
+                # example: prev draft_tokens_id [[1, 2], [3, 4], [5, 6]]
+                #          flattened draft_tokens_id [1, 2, 3, 4, 5, 6]
+                #          draft_len of each request [1, 2, 1]
+                #          then prev_draft_token_indices is [0, 2, 3, 4]
+                prev_draft_token_indices.extend(range(start,
+                                                      start + draft_len))
+        num_common_tokens = len(sample_flattened_indices)
+        if num_common_tokens == 0:
+            # No requests in common with the previous iteration,
+            # so input_ids.cpu will have all the input ids.
+            return
+        # Build the index tensors used to scatter the previously sampled
+        # tokens into the CPU input_ids buffer.
+        sampled_tokens_index_tensor = torch.tensor(sample_flattened_indices,
+                                                   dtype=torch.int64)
+        prev_common_req_indices_tensor = torch.tensor(prev_common_req_indices,
+                                                      dtype=torch.int64)
+        self.input_ids_pcp_full.cpu.scatter_(
+            dim=0,
+            index=sampled_tokens_index_tensor,
+            src=input_batch.prev_sampled_token_ids[
+                prev_common_req_indices_tensor, 0].cpu(),
+        )
+        # Scatter the draft tokens after the sampled tokens are scattered.
+        if draft_token_ids is None or not spec_flattened_indices:
+            return
+        assert isinstance(draft_token_ids, torch.Tensor)
+        draft_tokens_index_tensor = torch.tensor(spec_flattened_indices,
+                                                 dtype=torch.int64)
+        prev_draft_token_indices_tensor = torch.tensor(
+            prev_draft_token_indices, dtype=torch.int64)
+        # input_ids dtype is torch.int32, so convert draft_token_ids to
+        # torch.int32 here.
+        draft_token_ids = draft_token_ids.to(dtype=torch.int32)
+        self.input_ids_pcp_full.cpu.scatter_(
+            dim=0,
+            index=draft_tokens_index_tensor,
+            src=draft_token_ids.flatten()
+            [prev_draft_token_indices_tensor].cpu(),
+        )
+
     def _get_cp_local_seq_lens(
         self,
         seq_lens: torch.Tensor,
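
The index bookkeeping in the new helper above can be checked in isolation. The sketch below is standalone and uses made-up inputs chosen to match the worked examples in the comments; it reproduces only the index arithmetic, not the helper itself:

```python
import numpy as np

# Worked example from the comments above:
# cu_num_tokens = [2, 5, 8], per-request draft lengths = [1, 2, 2].
cu_num_tokens = np.array([2, 5, 8])
draft_lens = [1, 2, 2]

sample_flattened_indices = []  # slot of the previously sampled token, per request
spec_flattened_indices = []    # slots of the draft (spec decode) tokens
for cur_index, draft_len in enumerate(draft_lens):
    flattened_index = int(cu_num_tokens[cur_index]) - 1  # last slot of the request
    sample_flattened_indices.append(flattened_index - draft_len)
    spec_flattened_indices.extend(
        range(flattened_index - draft_len + 1, flattened_index + 1))

assert sample_flattened_indices == [0, 2, 5]
assert spec_flattened_indices == [1, 3, 4, 6, 7]

# Second worked example: selecting which flattened draft token ids to copy when
# each request has num_spec_tokens == 2 draft slots but only draft_len of them
# were actually scheduled ([1, 2, 1] here).
num_spec_tokens = 2
prev_draft_token_indices = []
for prev_index, draft_len in enumerate([1, 2, 1]):
    start = prev_index * num_spec_tokens
    prev_draft_token_indices.extend(range(start, start + draft_len))

assert prev_draft_token_indices == [0, 2, 3, 4]
```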
@@ -498,7 +605,7 @@ class PCPManager:
                 torch.float32).argsort().to(torch.int32)
 
     def generate_pcp_metadata(self, total_num_scheduled_tokens, query_lens,
-                              input_batch):
+                              input_batch, num_scheduled_tokens):
         from vllm_ascend.attention.utils import \
             AscendPrefillContextParallelMetadata
         num_reqs = input_batch.num_reqs or query_lens.size(0)
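
The next hunk derives the decode-side context lengths from the tokens already computed plus the tokens scheduled this step, instead of reading input_batch.num_tokens directly. A small numpy sketch of the new arithmetic, with illustrative values (not taken from the runner):

```python
import numpy as np

# Hypothetical per-request state for three decode requests.
num_computed_tokens_cpu = np.array([16, 31, 7])  # tokens already in the KV cache
num_scheduled_tokens = np.array([1, 2, 1])       # tokens scheduled this step (1 + drafts)
num_decodes = 3

# New computation from the patch: context length = computed + scheduled tokens.
decode_context_lens = (num_computed_tokens_cpu[:num_decodes] +
                       num_scheduled_tokens[:num_decodes])
print(decode_context_lens)  # [17 33  8]
```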
@@ -510,7 +617,9 @@ class PCPManager:
         self.num_actual_tokens_pcp_padded = num_actual_tokens_pcp_padded
         long_seq_metadata = None
         if self.pcp_world_size * self.dcp_world_size > 1:
-            decode_context_lens = input_batch.num_tokens[:num_decodes]
+            decode_context_lens = input_batch.num_computed_tokens_cpu[:
+                num_decodes] + num_scheduled_tokens[:
+                num_decodes]
             prefill_context_lens = input_batch.num_computed_tokens_cpu[
                 num_decodes:num_reqs]
             context_lens = np.concatenate(