From 5892455f438958f81d34df43ea4aff26febc3816 Mon Sep 17 00:00:00 2001
From: weiguihua2
Date: Tue, 20 Jan 2026 15:24:05 +0800
Subject: [PATCH] [Bugfix] fix bug of pcp+mtp+async scheduler (#5994)

### What this PR does / why we need it?
Fixed the issue where a service with PCP and MTP enabled could not be started
together with asynchronous scheduling. With pcp, mtp, and async scheduling all
enabled, the service hung on the first curl request because of a shape
mismatch. This PR resolves that issue.

### Does this PR introduce _any_ user-facing change?

### How was this patch tested?
- vLLM version: v0.13.0
- vLLM main: https://github.com/vllm-project/vllm/commit/2c24bc6996cb165fce92f780b388a5e39b3f4060

---------

Signed-off-by: weiguihua2
---
 .../4-cards/long_sequence/test_mtp.py |   2 +-
 tests/ut/worker/test_pcp_manager.py   |  11 +-
 vllm_ascend/worker/model_runner_v1.py |  22 +++-
 vllm_ascend/worker/pcp_utils.py       | 115 +++++++++++++++++-
 4 files changed, 138 insertions(+), 12 deletions(-)

diff --git a/tests/e2e/multicard/4-cards/long_sequence/test_mtp.py b/tests/e2e/multicard/4-cards/long_sequence/test_mtp.py
index d8f38ba3..3cf269c5 100644
--- a/tests/e2e/multicard/4-cards/long_sequence/test_mtp.py
+++ b/tests/e2e/multicard/4-cards/long_sequence/test_mtp.py
@@ -70,12 +70,12 @@ def test_pcp_dcp_mtp3_eager():
             max_num_batched_tokens=1024,
             enable_expert_parallel=True,
             block_size=128,
+            async_scheduling=True,
             speculative_config={
                 "num_speculative_tokens": 3,
                 "method": "deepseek_mtp",
             },
             enforce_eager=True,
-            async_scheduling=False,
     ) as runner:
         runner.generate_greedy(prompts, 32)
 
diff --git a/tests/ut/worker/test_pcp_manager.py b/tests/ut/worker/test_pcp_manager.py
index eaa34a2e..9a6779c1 100644
--- a/tests/ut/worker/test_pcp_manager.py
+++ b/tests/ut/worker/test_pcp_manager.py
@@ -47,6 +47,7 @@ def test_generate_pcp_metadata_basic(pcp_size, dcp_size, num_reqs, query_lens,
                              max_num_reqs=1000,
                              device="cpu",
                              vllm_config=vllm_config,
+                             use_async_scheduling=False,
                              pin_memory=False)
     input_batch = MagicMock()
     input_batch.num_reqs = num_reqs
@@ -65,13 +66,16 @@ def test_generate_pcp_metadata_basic(pcp_size, dcp_size, num_reqs, query_lens,
             num_prompt_tokens.append(query_lens[i])
             num_tokens.append(query_lens[i])
 
-    input_batch.num_computed_tokens_cpu = torch.tensor(num_computed_tokens)
+    input_batch.num_computed_tokens_cpu = np.array(num_computed_tokens)
     input_batch.num_prompt_tokens = torch.tensor(num_prompt_tokens)
     input_batch.num_tokens = torch.tensor(num_tokens)
+    num_scheduled_tokens = np.array(
+        query_lens) - input_batch.num_computed_tokens_cpu
     query_lens = torch.tensor(query_lens)
 
     result = pcp_manager.generate_pcp_metadata(total_tokens, query_lens,
-                                               input_batch)
+                                               input_batch,
+                                               num_scheduled_tokens)
 
     if not expect_not_none:
         assert result is None, f"Expected to return None, but got {type(result)}"
@@ -128,6 +132,7 @@ def test_update_tokens_for_pcp_basic(tokens, num_reqs, num_computed_tokens,
                              max_num_reqs=1000,
                              device="cpu",
                              vllm_config=vllm_config,
+                             use_async_scheduling=False,
                              pin_memory=False)
     input_batch = MagicMock()
     input_batch.num_reqs = num_reqs
@@ -193,6 +198,7 @@ def test_get_cp_local_seq_lens(
                              max_num_reqs=1000,
                              device="cpu",
                              vllm_config=vllm_config,
+                             use_async_scheduling=False,
                              pin_memory=False)
     ret = pcp_manager._get_cp_local_seq_lens(seq_lens, pcp_world_size,
                                              dcp_world_size,
@@ -276,6 +282,7 @@ def test_generate_pcp_mtp_input(
                              max_num_reqs=max_num_reqs,
                              device="cpu",
                              vllm_config=vllm_config,
+                             use_async_scheduling=False,
                              pin_memory=False)
     arange_np = np.arange(max_model_len)
     input_batch = MagicMock()
diff --git a/vllm_ascend/worker/model_runner_v1.py b/vllm_ascend/worker/model_runner_v1.py
index 210dea7b..4718551c 100644
--- a/vllm_ascend/worker/model_runner_v1.py
+++ b/vllm_ascend/worker/model_runner_v1.py
@@ -226,6 +226,7 @@ class NPUModelRunner(GPUModelRunner):
             self.max_num_reqs,
             self.device,
             self.vllm_config,
+            self.use_async_scheduling,
             self.pin_memory,
         )
         # TODO(zhenwenqi) after https://github.com/vllm-project/vllm/pull/28988 is merged, we can delete this
@@ -540,10 +541,18 @@ class NPUModelRunner(GPUModelRunner):
         # for pcp, prefill mtp should use origin scheduleroutput ,
         if self.speculative_config and self.pcp_size * self.dcp_size > 1:
             self.pcp_manager.generate_pcp_mtp_input(
-                num_reqs, total_num_scheduled_tokens,
-                scheduler_output.num_scheduled_tokens, with_prefill,
-                self.input_batch, self.arange_np, req_indices, positions_np,
-                cu_num_tokens)
+                num_reqs,
+                total_num_scheduled_tokens,
+                scheduler_output.num_scheduled_tokens,
+                with_prefill,
+                self.input_batch,
+                self.arange_np,
+                req_indices,
+                positions_np,
+                cu_num_tokens,
+                self._draft_token_ids,  # type: ignore[has-type]
+                scheduler_output,
+                self.num_spec_tokens)
 
         if self.pcp_size > 1:
             if not self.vllm_config.model_config.use_mla:
@@ -929,7 +938,7 @@ class NPUModelRunner(GPUModelRunner):
         if self.pcp_size * self.dcp_size > 1:
             self.long_seq_metadata = self.pcp_manager.generate_pcp_metadata(
                 total_num_scheduled_tokens, self.query_lens,
-                self.input_batch)
+                self.input_batch, num_scheduled_tokens)
             blk_table.slot_mapping.gpu[maybe_pcp_full_tokens:].fill_(-1)
             if self.pcp_size > 1:
                 slot_mapping_pcp = self.pcp_manager.get_padded_slot_mapping(
@@ -1946,7 +1955,8 @@ class NPUModelRunner(GPUModelRunner):
             slot_mapping = self.input_batch.block_table[
                 kv_cache_group_id].slot_mapping
             long_seq_metadata = None if self.pcp_size * self.dcp_size == 1 else self.pcp_manager.generate_pcp_metadata(
-                num_tokens, self.query_lens, self.input_batch)
+                num_tokens, self.query_lens, self.input_batch,
+                num_scheduled_tokens)
             if long_seq_metadata is not None:
                 pcp_world_size = get_pcp_group().world_size
                 dcp_world_size = get_dcp_group().world_size
diff --git a/vllm_ascend/worker/pcp_utils.py b/vllm_ascend/worker/pcp_utils.py
index f0ace8a4..40294807 100644
--- a/vllm_ascend/worker/pcp_utils.py
+++ b/vllm_ascend/worker/pcp_utils.py
@@ -17,7 +17,7 @@
 # Adapted from vllm-project/vllm/vllm/worker/worker.py
 #
 
-from typing import List
+from typing import TYPE_CHECKING, List
 
 import numpy as np
 import torch
@@ -25,6 +25,9 @@
 from vllm.config import VllmConfig
 from vllm.utils.math_utils import cdiv
 from vllm.v1.utils import CpuGpuBuffer
 
+if TYPE_CHECKING:
+    from vllm.v1.core.sched.output import SchedulerOutput
+
 class PCPManager:
     """
@@ -44,6 +47,7 @@ class PCPManager:
         max_num_reqs: int,
         device: torch.device,
         vllm_config: VllmConfig,
+        use_async_scheduling: bool,
         pin_memory: bool = False,
     ) -> None:
         self.pcp_world_size = pcp_world_size
@@ -58,6 +62,7 @@ class PCPManager:
         self.max_num_tokens = self.vllm_config.scheduler_config.max_num_batched_tokens
         self.max_num_reqs = self.vllm_config.scheduler_config.max_num_seqs
         self.device = device
+        self.use_async_scheduling = use_async_scheduling
         self.pcp_allgather_restore_idx = CpuGpuBuffer(
             max_buffer_num_tokens,
             dtype=torch.int64,
@@ -354,6 +359,9 @@ class PCPManager:
         req_indices=None,
         positions_np=None,
         cu_num_tokens=None,
+        draft_token_ids=None,
+        scheduler_output=None,
+        num_spec_tokens=None,
     ):
         """ While pcp > 1, model inputs (input_ids, position, etc.)
         are split across pcp group,
@@ -390,6 +398,12 @@ class PCPManager:
             torch.from_numpy(token_indices_pcp_full),
             out=self.input_ids_pcp_full.
             cpu[:total_num_scheduled_tokens_pcp_full])
+        if self.use_async_scheduling:
+            self._update_input_ids_pcp_full_ids(input_batch, draft_token_ids,
+                                                scheduler_output,
+                                                total_num_scheduled_tokens,
+                                                cu_num_tokens_pcp_full,
+                                                num_spec_tokens)
         self.query_lens_pcp_full.copy_to_gpu()
         self.query_start_loc_pcp_full.copy_to_gpu()
         self.input_ids_pcp_full.copy_to_gpu(
@@ -428,6 +442,99 @@ class PCPManager:
             mtp_slot_pad[unpad_mask] = mtp_slot_ori
             self.mtp_slot_pad = mtp_slot_pad.to(self.device, non_blocking=True)
 
+    def _update_input_ids_pcp_full_ids(
+        self,
+        input_batch,
+        draft_token_ids,
+        scheduler_output: "SchedulerOutput",
+        total_num_scheduled_tokens: int,
+        cu_num_tokens: np.ndarray,
+        num_spec_tokens: int,
+    ) -> None:
+        """Prepare the input IDs for the current batch.
+
+        Carefully handles the `prev_sampled_token_ids` which can be cached
+        from the previous engine iteration, in which case those tokens on the
+        GPU need to be copied into the corresponding slots into input_ids."""
+
+        if (input_batch.prev_sampled_token_ids is None
+                or input_batch.prev_req_id_to_index is None):
+            return
+
+        # Async scheduling case, where some decode requests from the previous
+        # iteration won't have entries in input_ids_cpu and need to be copied
+        # on the GPU from prev_sampled_token_ids.
+        prev_req_id_to_index = input_batch.prev_req_id_to_index
+        sample_flattened_indices: list[int] = []
+        spec_flattened_indices: list[int] = []
+        prev_common_req_indices: list[int] = []
+        prev_draft_token_indices: list[int] = []
+        total_num_spec_tokens = 0
+        scheduled_spec_tokens = scheduler_output.scheduled_spec_decode_tokens
+
+        for req_id, cur_index in input_batch.req_id_to_index.items():
+            if (prev_index := prev_req_id_to_index.get(req_id)) is not None:
+                prev_common_req_indices.append(prev_index)
+                # We need to compute the flattened input_ids index of the
+                # last token in each common request.
+                draft_len = len(scheduled_spec_tokens.get(req_id, ()))
+                total_num_spec_tokens += draft_len
+                flattened_index = cu_num_tokens[cur_index].item() - 1
+                # example: cu_num_tokens = [2, 5, 8], draft_tokens = [1, 2, 2]
+                # sample_flattened_indices = [0, 2, 5]
+                # spec_flattened_indices = [1, 3, 4, 6, 7]
+                sample_flattened_indices.append(flattened_index - draft_len)
+                spec_flattened_indices.extend(
+                    range(flattened_index - draft_len + 1,
+                          flattened_index + 1))
+                start = prev_index * num_spec_tokens
+                # prev_draft_token_indices is used to find which draft_tokens_id
+                # should be copied to input_ids
+                # example: prev draft_tokens_id [[1,2], [3,4], [5, 6]]
+                # flatten draft_tokens_id [1,2,3,4,5,6]
+                # draft_len of each request [1, 2, 1]
+                # then prev_draft_token_indices is [0, 2, 3, 4]
+                prev_draft_token_indices.extend(range(start,
+                                                      start + draft_len))
+        num_common_tokens = len(sample_flattened_indices)
+
+        if num_common_tokens == 0:
+            # No requests in common with the previous iteration
+            # So input_ids.cpu will have all the input ids.
+            return
+        # Upload the index tensors asynchronously so the scatter can be non-blocking.
+        sampled_tokens_index_tensor = torch.tensor(sample_flattened_indices,
+                                                   dtype=torch.int64)
+        prev_common_req_indices_tensor = torch.tensor(prev_common_req_indices,
+                                                      dtype=torch.int64)
+        self.input_ids_pcp_full.cpu.scatter_(
+            dim=0,
+            index=sampled_tokens_index_tensor,
+            src=input_batch.prev_sampled_token_ids[
+                prev_common_req_indices_tensor, 0].cpu(),
+        )
+
+        # Scatter the draft tokens after the sampled tokens are scattered.
+        if draft_token_ids is None or not spec_flattened_indices:
+            return
+
+        assert isinstance(draft_token_ids, torch.Tensor)
+        draft_tokens_index_tensor = torch.tensor(spec_flattened_indices,
+                                                 dtype=torch.int64)
+        prev_draft_token_indices_tensor = torch.tensor(
+            prev_draft_token_indices, dtype=torch.int64)
+
+        # because input_ids dtype is torch.int32,
+        # so convert draft_token_ids to torch.int32 here.
+        draft_token_ids = draft_token_ids.to(dtype=torch.int32)
+
+        self.input_ids_pcp_full.cpu.scatter_(
+            dim=0,
+            index=draft_tokens_index_tensor,
+            src=draft_token_ids.flatten()
+            [prev_draft_token_indices_tensor].cpu(),
+        )
+
     def _get_cp_local_seq_lens(
         self,
         seq_lens: torch.Tensor,
@@ -498,7 +605,7 @@ class PCPManager:
             torch.float32).argsort().to(torch.int32)
 
     def generate_pcp_metadata(self, total_num_scheduled_tokens, query_lens,
-                              input_batch):
+                              input_batch, num_scheduled_tokens):
         from vllm_ascend.attention.utils import \
             AscendPrefillContextParallelMetadata
         num_reqs = input_batch.num_reqs or query_lens.size(0)
@@ -510,7 +617,9 @@ class PCPManager:
         self.num_actual_tokens_pcp_padded = num_actual_tokens_pcp_padded
         long_seq_metadata = None
         if self.pcp_world_size * self.dcp_world_size > 1:
-            decode_context_lens = input_batch.num_tokens[:num_decodes]
+            decode_context_lens = input_batch.num_computed_tokens_cpu[:
+                num_decodes] + num_scheduled_tokens[:
+                num_decodes]
             prefill_context_lens = input_batch.num_computed_tokens_cpu[
                 num_decodes:num_reqs]
             context_lens = np.concatenate(
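
As a standalone illustration (not part of the patch), the sketch below re-derives the flattened-index bookkeeping that the new `_update_input_ids_pcp_full_ids` helper builds before its two `scatter_` calls, using the example given in the code comments. The helper name `flattened_token_slots` is invented for this sketch only.

```python
import numpy as np


def flattened_token_slots(cu_num_tokens: np.ndarray, draft_lens: list[int]):
    """Return (sample_slots, spec_slots): the flattened input_ids positions of
    each request's last sampled token and of its scheduled draft tokens."""
    sample_slots: list[int] = []
    spec_slots: list[int] = []
    for cur_index, draft_len in enumerate(draft_lens):
        # Position of this request's last token in the flattened input_ids.
        flattened_index = int(cu_num_tokens[cur_index]) - 1
        # The previously sampled token sits just before the draft tokens.
        sample_slots.append(flattened_index - draft_len)
        # The draft tokens occupy the remaining trailing positions.
        spec_slots.extend(range(flattened_index - draft_len + 1,
                                flattened_index + 1))
    return sample_slots, spec_slots


# Example from the patch comments: cu_num_tokens = [2, 5, 8] with 1, 2 and 2
# draft tokens per request.
sample_slots, spec_slots = flattened_token_slots(np.array([2, 5, 8]), [1, 2, 2])
assert sample_slots == [0, 2, 5]
assert spec_slots == [1, 3, 4, 6, 7]
```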