From 2d49f9079a27b45870dc5cf8220ab84bc294b0bf Mon Sep 17 00:00:00 2001
From: MengLong Chen <71744434+dragondream-chen@users.noreply.github.com>
Date: Thu, 26 Feb 2026 19:09:05 +0800
Subject: [PATCH] [BugFix] Support ALL D-Nodes in fullgraph when running MTP in PD (#5472)

### What this PR does / why we need it?

**BUG**

When using prefill-decode disaggregation + MTP + full graph + asynchronous
scheduling, the KV cache that decode nodes pull from prefill nodes does not
include spec tokens. As a result, the `total_num_scheduled_tokens` that decode
nodes obtain from the scheduler also lacks spec tokens. When deciding whether
to dispatch the full graph on a decode node, the uniform_decode condition
`scheduler_output.total_num_scheduled_tokens == self.input_batch.num_reqs * max_query_len`
is not met, so the current instance is not dispatched into the full graph.

As a result, full-graph and eager-mode instances coexist among the decode
instances. Because of the synchronous wait in MoeDispatch, the decode
instances running the full graph are significantly slowed down by the
instances running in eager mode.

**Solution**

The scenario is PD disaggregation + MTP + full graph + asynchronous
scheduling. On the decode nodes, the spec tokens of requests whose KV cache
comes from prefill nodes need to be padded with placeholders; the padded spec
tokens are later rejected by sampling. This ensures that the uniform_decode
condition is satisfied when deciding whether decode nodes run the full graph,
so all decode instances stay in the full graph and the synchronous wait on
MoeDispatch is avoided.
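To make the arithmetic concrete, here is a minimal, self-contained sketch of
the uniform_decode check described above. It is not the model-runner code: the
`is_uniform_decode` helper, `num_spec_tokens = 1`, `num_reqs = 4`, and the
per-request token counts are made up for illustration; only the condition
itself comes from this PR.

```python
# Assumption: with MTP drafting `num_spec_tokens` tokens per step, every
# decode request should contribute 1 + num_spec_tokens scheduled tokens
# (max_query_len) for the step to count as a uniform decode.
num_spec_tokens = 1
max_query_len = 1 + num_spec_tokens
num_reqs = 4

def is_uniform_decode(scheduled_tokens_per_req: list[int]) -> bool:
    total_num_scheduled_tokens = sum(scheduled_tokens_per_req)
    return total_num_scheduled_tokens == num_reqs * max_query_len

# Without padding: the request whose KV cache was just pulled from a prefill
# node carries no spec tokens, so it contributes only 1 scheduled token and
# the whole step falls out of the full graph.
print(is_uniform_decode([2, 2, 2, 1]))  # False -> eager mode

# With placeholder spec tokens padded in by the scheduler, every request
# contributes 1 + num_spec_tokens tokens; the padded drafts are later
# rejected by the sampler, so the output is unchanged.
print(is_uniform_decode([2, 2, 2, 2]))  # True -> full graph
```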
- vLLM version: v0.15.0
- vLLM main: https://github.com/vllm-project/vllm/commit/5326c89803566a131c928f7fdd2100b75c981a42

Signed-off-by: chenmenglong
---
 vllm_ascend/core/recompute_scheduler.py | 44 ++++++++++++++++++++++++-
 1 file changed, 43 insertions(+), 1 deletion(-)

diff --git a/vllm_ascend/core/recompute_scheduler.py b/vllm_ascend/core/recompute_scheduler.py
index 90ab413a..a3358558 100644
--- a/vllm_ascend/core/recompute_scheduler.py
+++ b/vllm_ascend/core/recompute_scheduler.py
@@ -40,6 +40,7 @@ from vllm.v1.engine import EngineCoreEventType, EngineCoreOutput, EngineCoreOutp
 from vllm.v1.metrics.perf import PerfStats
 from vllm.v1.outputs import ModelRunnerOutput
 from vllm.v1.request import Request, RequestStatus
+from vllm.v1.sample.rejection_sampler import PLACEHOLDER_TOKEN_ID
 from vllm.v1.spec_decode.metrics import SpecDecodingStats
 from vllm.v1.utils import ConstantList, record_function_or_nullcontext
 
@@ -83,6 +84,27 @@ class RecomputeSchedulerOutput(SchedulerOutput):
 class RecomputeScheduler(Scheduler):
     running: list[Request]
 
+    def __init__(self, *args, **kwargs):
+        super().__init__(*args, **kwargs)
+        # When is_mtp_kv_consumer is true, we will fill request.spec_token_ids
+        # with placeholder tokens to enable full graph when decode nodes pull
+        # the KV cache of one request from prefill nodes.
+        self.is_mtp_kv_consumer = (
+            self.vllm_config.speculative_config
+            and self.vllm_config.kv_transfer_config
+            and self.vllm_config.kv_transfer_config.is_kv_consumer
+        )
+
+    def add_request(self, request: Request) -> None:
+        # Fill in placeholder tokens to enable full graph compatibility. Without
+        # placeholders, graph matching may fail, forcing eager mode execution.
+        if self.is_mtp_kv_consumer:
+            request.spec_token_ids = [PLACEHOLDER_TOKEN_ID] * self.num_spec_tokens
+        self.waiting.add_request(request)
+        self.requests[request.request_id] = request
+        if self.log_stats:
+            request.record_event(EngineCoreEventType.QUEUED)
+
     def schedule(self) -> RecomputeSchedulerOutput:
         # NOTE(woosuk) on the scheduling algorithm:
         # There's no "decoding phase" nor "prefill phase" in the scheduler.
@@ -391,7 +413,10 @@ class RecomputeScheduler(Scheduler):
             # We use `request.num_tokens` instead of
             # `request.num_prompt_tokens` to consider the resumed
             # requests, which have output tokens.
-            num_new_tokens = request.num_tokens - num_computed_tokens
+            if self.is_mtp_kv_consumer:
+                num_new_tokens = request.num_tokens_with_spec - num_computed_tokens
+            else:
+                num_new_tokens = request.num_tokens - num_computed_tokens
             threshold = self.scheduler_config.long_prefill_token_threshold
             if 0 < threshold < num_new_tokens:
                 num_new_tokens = threshold
@@ -478,6 +503,23 @@ class RecomputeScheduler(Scheduler):
 
             self._update_connector_prefix_cache_stats(request)
 
+            # For spec_token_ids, the waiting queue has the same processing
+            # as the running queue.
+            if self.is_mtp_kv_consumer and request.spec_token_ids:
+                num_scheduled_spec_tokens = (
+                    num_new_tokens
+                    + request.num_computed_tokens
+                    - request.num_tokens
+                    - request.num_output_placeholders
+                )
+                if num_scheduled_spec_tokens > 0:
+                    # Trim spec_token_ids list to num_scheduled_spec_tokens.
+                    del request.spec_token_ids[num_scheduled_spec_tokens:]
+                    scheduled_spec_decode_tokens[
+                        request.request_id] = request.spec_token_ids
+                # New spec tokens will be set in `update_draft_token_ids` before
+                # the next step when applicable.
+                request.spec_token_ids = []
+
             self.running.append(request)
             if self.log_stats:
                 request.record_event(EngineCoreEventType.SCHEDULED, scheduled_timestamp)