[Bugfix] fix bug of pcp+mtp+async scheduler (#5994)

### What this PR does / why we need it?
Fixed the issue where the PCP and MTP services could not be started due
to asynchronous scheduling.

After the pcp, mtp, and asynchronous scheduling functions are enabled,
the service is suspended because of a shape mismatch after a curl
request is sent. This PR resolves this issue.

### Does this PR introduce _any_ user-facing change?

### How was this patch tested?

- vLLM version: v0.13.0
- vLLM main:
2c24bc6996

---------

Signed-off-by: weiguihua2 <weiguihua2@huawei.com>
This commit is contained in:
weiguihua2
2026-01-20 15:24:05 +08:00
committed by GitHub
parent ea57e3e7a4
commit 5892455f43
4 changed files with 138 additions and 12 deletions

View File

@@ -47,6 +47,7 @@ def test_generate_pcp_metadata_basic(pcp_size, dcp_size, num_reqs, query_lens,
max_num_reqs=1000,
device="cpu",
vllm_config=vllm_config,
use_async_scheduling=False,
pin_memory=False)
input_batch = MagicMock()
input_batch.num_reqs = num_reqs
@@ -65,13 +66,16 @@ def test_generate_pcp_metadata_basic(pcp_size, dcp_size, num_reqs, query_lens,
num_prompt_tokens.append(query_lens[i])
num_tokens.append(query_lens[i])
input_batch.num_computed_tokens_cpu = torch.tensor(num_computed_tokens)
input_batch.num_computed_tokens_cpu = np.array(num_computed_tokens)
input_batch.num_prompt_tokens = torch.tensor(num_prompt_tokens)
input_batch.num_tokens = torch.tensor(num_tokens)
num_scheduled_tokens = np.array(
query_lens) - input_batch.num_computed_tokens_cpu
query_lens = torch.tensor(query_lens)
result = pcp_manager.generate_pcp_metadata(total_tokens, query_lens,
input_batch)
input_batch,
num_scheduled_tokens)
if not expect_not_none:
assert result is None, f"Expected to return None, but got {type(result)}"
@@ -128,6 +132,7 @@ def test_update_tokens_for_pcp_basic(tokens, num_reqs, num_computed_tokens,
max_num_reqs=1000,
device="cpu",
vllm_config=vllm_config,
use_async_scheduling=False,
pin_memory=False)
input_batch = MagicMock()
input_batch.num_reqs = num_reqs
@@ -193,6 +198,7 @@ def test_get_cp_local_seq_lens(
max_num_reqs=1000,
device="cpu",
vllm_config=vllm_config,
use_async_scheduling=False,
pin_memory=False)
ret = pcp_manager._get_cp_local_seq_lens(seq_lens, pcp_world_size,
dcp_world_size,
@@ -276,6 +282,7 @@ def test_generate_pcp_mtp_input(
max_num_reqs=max_num_reqs,
device="cpu",
vllm_config=vllm_config,
use_async_scheduling=False,
pin_memory=False)
arange_np = np.arange(max_model_len)
input_batch = MagicMock()