[BugFix] Support ALL D-Nodes in fullgraph when running MTP in PD (#5472)

### What this PR does / why we need it?
**BUG**
When using prefill-decode disaggregation + MTP + full graph
+asynchronous scheduling, the KV cache pulled by decode nodes from
prefill decodes does not include spec tokens. As a result, the
total_num_scheduled_tokens obtained by decode nodes from the scheduler
lacks spec tokens. When determining whether to enqueue the full graph on
decode nodes, the condition for uniform_decode `
scheduler_output.total_num_scheduled_tokens == self.input_batch.num_reqs
* max_query_len` is not met, leading to the current instance not being
enqueued into the full graph.

The above situation leads to both full graph and eagle mode instances
coexisting in the decode instances. Due to the synchronization wait of
MoeDispatch, the decode instances in full graph are significantly slowed
down by the instance in eagle mode.

**Solution**
The scenario is PD separation + MTP + Full Graph + asynchronous
scheduling.
On the decode nodes, the spec tokens of the request with KV cache from P
need be padded. Then, the padded spec tokens will be rejected by
sampling. This operation ensures that the uniform_decode condition is
satisfied when determining whether decode nodes are included in the full
graph, thereby guaranteeing that all decode instances are present in the
full graph and avoiding synchronous waiting for MoeDispatch.

- vLLM version: v0.15.0
- vLLM main:
5326c89803

Signed-off-by: chenmenglong <chenmenglong1@huawei.com>
This commit is contained in:
MengLong Chen
2026-02-26 19:09:05 +08:00
committed by GitHub
parent 532f7a82f2
commit 2d49f9079a

View File

@@ -40,6 +40,7 @@ from vllm.v1.engine import EngineCoreEventType, EngineCoreOutput, EngineCoreOutp
from vllm.v1.metrics.perf import PerfStats
from vllm.v1.outputs import ModelRunnerOutput
from vllm.v1.request import Request, RequestStatus
from vllm.v1.sample.rejection_sampler import PLACEHOLDER_TOKEN_ID
from vllm.v1.spec_decode.metrics import SpecDecodingStats
from vllm.v1.utils import ConstantList, record_function_or_nullcontext
@@ -83,6 +84,27 @@ class RecomputeSchedulerOutput(SchedulerOutput):
class RecomputeScheduler(Scheduler):
running: list[Request]
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
# When is_mtp_kv_consumer is true, we will fill request.spec_token_ids
# with placeholder tokens to enable full graph when decode nodes pull
# the KV cache of one request from prefill nodes.
self.is_mtp_kv_consumer = (
self.vllm_config.speculative_config
and self.vllm_config.kv_transfer_config
and self.vllm_config.kv_transfer_config.is_kv_consumer
)
def add_request(self, request: Request) -> None:
# Fill in placeholder tokens to enable full graph compatibility. Without
# placeholders, graph matching may fail, forcing eager mode execution.
if self.is_mtp_kv_consumer:
request.spec_token_ids = [PLACEHOLDER_TOKEN_ID] * self.num_spec_tokens
self.waiting.add_request(request)
self.requests[request.request_id] = request
if self.log_stats:
request.record_event(EngineCoreEventType.QUEUED)
def schedule(self) -> RecomputeSchedulerOutput:
# NOTE(woosuk) on the scheduling algorithm:
# There's no "decoding phase" nor "prefill phase" in the scheduler.
@@ -391,7 +413,10 @@ class RecomputeScheduler(Scheduler):
# We use `request.num_tokens` instead of
# `request.num_prompt_tokens` to consider the resumed
# requests, which have output tokens.
num_new_tokens = request.num_tokens - num_computed_tokens
if self.is_mtp_kv_consumer:
num_new_tokens = request.num_tokens_with_spec - num_computed_tokens
else:
num_new_tokens = request.num_tokens - num_computed_tokens
threshold = self.scheduler_config.long_prefill_token_threshold
if 0 < threshold < num_new_tokens:
num_new_tokens = threshold
@@ -478,6 +503,23 @@ class RecomputeScheduler(Scheduler):
self._update_connector_prefix_cache_stats(request)
# For spec_token_ids, the waiting queue has the same processing
# as the running queue.
if self.is_mtp_kv_consumer and request.spec_token_ids:
num_scheduled_spec_tokens = (
num_new_tokens
+ request.num_computed_tokens
- request.num_tokens
- request.num_output_placeholders
)
if num_scheduled_spec_tokens > 0:
# Trim spec_token_ids list to num_scheduled_spec_tokens.
del request.spec_token_ids[num_scheduled_spec_tokens:]
scheduled_spec_decode_tokens[request.request_id] = request.spec_token_ids
# New spec tokens will be set in `update_draft_token_ids` before the
# next step when applicable.
request.spec_token_ids = []
self.running.append(request)
if self.log_stats:
request.record_event(EngineCoreEventType.SCHEDULED, scheduled_timestamp)