[v0.16.0][P/D][Bugfix] Support ALL D-Nodes in fullgraph when running MTP in PD (#6948)

### What this PR does / why we need it?
Fix the bug for v0.16.0 recompute_scheduler, the same way as
https://github.com/vllm-project/vllm-ascend/pull/5472.

Signed-off-by: chenmenglong <chenmenglong1@huawei.com>
This commit is contained in:
MengLong Chen
2026-03-06 10:01:33 +08:00
committed by GitHub
parent ccd00798f3
commit a838a89630

View File

@@ -19,7 +19,7 @@
from __future__ import annotations
import time
from collections import defaultdict
from collections import defaultdict, deque
from dataclasses import dataclass, fields
from vllm.config import SchedulerConfig, VllmConfig
@@ -37,7 +37,8 @@ from vllm.v1.core.sched.utils import remove_all
from vllm.v1.engine import EngineCoreEventType, EngineCoreOutput, EngineCoreOutputs, FinishReason
from vllm.v1.metrics.perf import PerfStats
from vllm.v1.outputs import ModelRunnerOutput
from vllm.v1.request import Request, RequestStatus
from vllm.v1.request import Request, RequestStatus, StreamingUpdate
from vllm.v1.sample.rejection_sampler import PLACEHOLDER_TOKEN_ID
from vllm.v1.spec_decode.metrics import SpecDecodingStats
from vllm.v1.utils import ConstantList, record_function_or_nullcontext
@@ -80,6 +81,43 @@ class RecomputeSchedulerOutput(SchedulerOutput):
class RecomputeScheduler(Scheduler):
running: list[Request]
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
# When is_mtp_kv_consumer is true, we will fill request.spec_token_ids
# with placeholder tokens to enable full graph when decode nodes pull
# the KV cache of one request from prefill nodes.
self.is_mtp_kv_consumer = (
self.vllm_config.speculative_config
and self.vllm_config.kv_transfer_config
and self.vllm_config.kv_transfer_config.is_kv_consumer
)
def add_request(self, request: Request) -> None:
existing = self.requests.get(request.request_id)
if existing is not None:
update = StreamingUpdate.from_request(request)
if existing.status != RequestStatus.WAITING_FOR_STREAMING_REQ:
assert existing.streaming_queue is not None, "duplicate request id"
# Queue next input chunk (or finished sentinel).
existing.streaming_queue.append(update)
elif update is not None:
# Commence next input chunk.
self._update_request_as_session(existing, update)
else:
# Streaming-input session finished.
self.finish_requests(request.request_id, RequestStatus.FINISHED_ABORTED)
else:
if request.resumable:
request.streaming_queue = deque()
# Fill in placeholder tokens to enable full graph compatibility. Without
# placeholders, graph matching may fail, forcing eager mode execution.
if self.is_mtp_kv_consumer:
request.spec_token_ids = [PLACEHOLDER_TOKEN_ID] * self.num_spec_tokens
self.waiting.add_request(request)
self.requests[request.request_id] = request
if self.log_stats:
request.record_event(EngineCoreEventType.QUEUED)
def schedule(self) -> RecomputeSchedulerOutput:
# NOTE(woosuk) on the scheduling algorithm:
# There's no "decoding phase" nor "prefill phase" in the scheduler.
@@ -408,7 +446,10 @@ class RecomputeScheduler(Scheduler):
# We use `request.num_tokens` instead of
# `request.num_prompt_tokens` to consider the resumed
# requests, which have output tokens.
num_new_tokens = request.num_tokens - num_computed_tokens
if self.is_mtp_kv_consumer:
num_new_tokens = request.num_tokens_with_spec - num_computed_tokens
else:
num_new_tokens = request.num_tokens - num_computed_tokens
threshold = self.scheduler_config.long_prefill_token_threshold
if 0 < threshold < num_new_tokens:
num_new_tokens = threshold
@@ -509,6 +550,25 @@ class RecomputeScheduler(Scheduler):
request.status = RequestStatus.WAITING_FOR_REMOTE_KVS
continue
# For spec_token_ids, the waiting queue has the same processing
# as the running queue.
if self.is_mtp_kv_consumer and request.spec_token_ids:
num_scheduled_spec_tokens = (
num_new_tokens
+ request.num_computed_tokens
- request.num_tokens
- request.num_output_placeholders
)
if num_scheduled_spec_tokens > 0:
spec_token_ids = request.spec_token_ids
if len(spec_token_ids) > num_scheduled_spec_tokens:
spec_token_ids = spec_token_ids[:num_scheduled_spec_tokens]
scheduled_spec_decode_tokens[request.request_id] = spec_token_ids
# New spec tokens will be set in `update_draft_token_ids` before the
# next step when applicable.
request.spec_token_ids = []
self.running.append(request)
if self.log_stats:
request.record_event(EngineCoreEventType.SCHEDULED, scheduled_timestamp)