[v0.16.0][P/D][Bugfix] Support ALL D-Nodes in fullgraph when running MTP in PD (#6948)

### What this PR does / why we need it?
Fixes a bug in the v0.16.0 `recompute_scheduler` in the same way as
https://github.com/vllm-project/vllm-ascend/pull/5472: on decode (D) nodes running MTP in a PD (prefill/decode disaggregated) deployment, `request.spec_token_ids` is pre-filled with placeholder tokens so that full-graph execution is kept when a D-node pulls a request's KV cache from a prefill node.
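
For context, the core idea of the change is small: on an MTP KV-consumer node, a request whose KV cache was just pulled from a prefill node carries no draft tokens yet, so the first decode step misses the shape the full graph was captured for and falls back to eager mode. Below is a minimal self-contained sketch of that padding idea; the `Request` stand-in and `pad_spec_tokens` helper are illustrative only (not vLLM API), and `PLACEHOLDER_TOKEN_ID` mirrors the constant imported in the diff (assumed to be -1 here):

```python
from dataclasses import dataclass, field

# Mirrors vllm.v1.sample.rejection_sampler.PLACEHOLDER_TOKEN_ID (assumed -1).
PLACEHOLDER_TOKEN_ID = -1

@dataclass
class Request:
    """Illustrative stand-in for vllm.v1.request.Request."""
    request_id: str
    spec_token_ids: list[int] = field(default_factory=list)

def pad_spec_tokens(request: Request, num_spec_tokens: int) -> None:
    # Padding keeps the first decode step at the token count the full graph
    # was captured for, so execution does not drop to eager mode.
    request.spec_token_ids = [PLACEHOLDER_TOKEN_ID] * num_spec_tokens

req = Request(request_id="req-0")
pad_spec_tokens(req, num_spec_tokens=2)
assert req.spec_token_ids == [PLACEHOLDER_TOKEN_ID] * 2
```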

Signed-off-by: chenmenglong <chenmenglong1@huawei.com>
Author: MengLong Chen
Date: 2026-03-06 10:01:33 +08:00
Commit: a838a89630 (parent: ccd00798f3)


```diff
@@ -19,7 +19,7 @@
 from __future__ import annotations
 import time
-from collections import defaultdict
+from collections import defaultdict, deque
 from dataclasses import dataclass, fields
 from vllm.config import SchedulerConfig, VllmConfig
@@ -37,7 +37,8 @@ from vllm.v1.core.sched.utils import remove_all
 from vllm.v1.engine import EngineCoreEventType, EngineCoreOutput, EngineCoreOutputs, FinishReason
 from vllm.v1.metrics.perf import PerfStats
 from vllm.v1.outputs import ModelRunnerOutput
-from vllm.v1.request import Request, RequestStatus
+from vllm.v1.request import Request, RequestStatus, StreamingUpdate
+from vllm.v1.sample.rejection_sampler import PLACEHOLDER_TOKEN_ID
 from vllm.v1.spec_decode.metrics import SpecDecodingStats
 from vllm.v1.utils import ConstantList, record_function_or_nullcontext
```
```diff
@@ -80,6 +81,43 @@ class RecomputeSchedulerOutput(SchedulerOutput):
 class RecomputeScheduler(Scheduler):
     running: list[Request]
+
+    def __init__(self, *args, **kwargs):
+        super().__init__(*args, **kwargs)
+        # When is_mtp_kv_consumer is true, we will fill request.spec_token_ids
+        # with placeholder tokens to enable full graph when decode nodes pull
+        # the KV cache of one request from prefill nodes.
+        self.is_mtp_kv_consumer = (
+            self.vllm_config.speculative_config
+            and self.vllm_config.kv_transfer_config
+            and self.vllm_config.kv_transfer_config.is_kv_consumer
+        )
+
+    def add_request(self, request: Request) -> None:
+        existing = self.requests.get(request.request_id)
+        if existing is not None:
+            update = StreamingUpdate.from_request(request)
+            if existing.status != RequestStatus.WAITING_FOR_STREAMING_REQ:
+                assert existing.streaming_queue is not None, "duplicate request id"
+                # Queue next input chunk (or finished sentinel).
+                existing.streaming_queue.append(update)
+            elif update is not None:
+                # Commence next input chunk.
+                self._update_request_as_session(existing, update)
+            else:
+                # Streaming-input session finished.
+                self.finish_requests(request.request_id, RequestStatus.FINISHED_ABORTED)
+        else:
+            if request.resumable:
+                request.streaming_queue = deque()
+            # Fill in placeholder tokens to enable full graph compatibility. Without
+            # placeholders, graph matching may fail, forcing eager mode execution.
+            if self.is_mtp_kv_consumer:
+                request.spec_token_ids = [PLACEHOLDER_TOKEN_ID] * self.num_spec_tokens
+            self.waiting.add_request(request)
+            self.requests[request.request_id] = request
+            if self.log_stats:
+                request.record_event(EngineCoreEventType.QUEUED)
 
     def schedule(self) -> RecomputeSchedulerOutput:
         # NOTE(woosuk) on the scheduling algorithm:
         # There's no "decoding phase" nor "prefill phase" in the scheduler.
```
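
As a usage note, `is_mtp_kv_consumer` only turns true on a node that both runs speculative decoding (MTP) and plays the KV-consumer (decode) role in a PD deployment. A hedged sketch of a config under which the `__init__` check above evaluates true, using vLLM's `KVTransferConfig` (field names follow vLLM; the concrete connector name is a placeholder, and exact fields may vary across vLLM/vllm-ascend versions):

```python
from vllm.config import KVTransferConfig

kv_cfg = KVTransferConfig(
    kv_connector="LMCacheConnector",  # placeholder; use your PD connector
    kv_role="kv_consumer",            # decode node: pulls KV from prefill nodes
)
# is_kv_consumer is derived from kv_role, so on an engine that also has a
# speculative_config (MTP) set, is_mtp_kv_consumer evaluates true.
print(kv_cfg.is_kv_consumer)  # True
```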
```diff
@@ -408,7 +446,10 @@ class RecomputeScheduler(Scheduler):
                 # We use `request.num_tokens` instead of
                 # `request.num_prompt_tokens` to consider the resumed
                 # requests, which have output tokens.
-                num_new_tokens = request.num_tokens - num_computed_tokens
+                if self.is_mtp_kv_consumer:
+                    num_new_tokens = request.num_tokens_with_spec - num_computed_tokens
+                else:
+                    num_new_tokens = request.num_tokens - num_computed_tokens
                 threshold = self.scheduler_config.long_prefill_token_threshold
                 if 0 < threshold < num_new_tokens:
                     num_new_tokens = threshold
```
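
To make the effect of this branch concrete, here is a small worked example. Variable names match the diff; all numbers are hypothetical:

```python
# Hypothetical numbers for a request whose prompt KV was pulled from a P-node.
num_tokens = 100          # request.num_tokens: all prompt tokens
num_computed_tokens = 99  # illustrative: one token left to compute locally
num_spec_tokens = 2       # MTP draft length; spec_token_ids was padded with
                          # two placeholders in add_request above

# Old behaviour: only the final token is scheduled, a batch shape the full
# graph was not captured for (per the diff's comment, this forces eager mode).
assert num_tokens - num_computed_tokens == 1

# New behaviour on a KV-consumer node: the placeholder draft tokens are
# counted, so the step is scheduled at the decode shape (1 + num_spec_tokens).
num_tokens_with_spec = num_tokens + num_spec_tokens
assert num_tokens_with_spec - num_computed_tokens == 3
```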
```diff
@@ -509,6 +550,25 @@ class RecomputeScheduler(Scheduler):
                 request.status = RequestStatus.WAITING_FOR_REMOTE_KVS
                 continue
+
+            # For spec_token_ids, the waiting queue has the same processing
+            # as the running queue.
+            if self.is_mtp_kv_consumer and request.spec_token_ids:
+                num_scheduled_spec_tokens = (
+                    num_new_tokens
+                    + request.num_computed_tokens
+                    - request.num_tokens
+                    - request.num_output_placeholders
+                )
+                if num_scheduled_spec_tokens > 0:
+                    spec_token_ids = request.spec_token_ids
+                    if len(spec_token_ids) > num_scheduled_spec_tokens:
+                        spec_token_ids = spec_token_ids[:num_scheduled_spec_tokens]
+                    scheduled_spec_decode_tokens[request.request_id] = spec_token_ids
+                    # New spec tokens will be set in `update_draft_token_ids` before the
+                    # next step when applicable.
+                    request.spec_token_ids = []
+
             self.running.append(request)
             if self.log_stats:
                 request.record_event(EngineCoreEventType.SCHEDULED, scheduled_timestamp)
```
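
Finally, a worked pass through the new waiting-queue branch, continuing the hypothetical numbers from the previous sketch (all values illustrative):

```python
PLACEHOLDER_TOKEN_ID = -1  # mirrors the imported constant (assumed -1)

# Continuing the example: 3 tokens scheduled this step.
num_new_tokens = 3
num_computed_tokens = 99
num_tokens = 100
num_output_placeholders = 0  # illustrative; can be nonzero with async scheduling

num_scheduled_spec_tokens = (
    num_new_tokens + num_computed_tokens - num_tokens - num_output_placeholders
)
assert num_scheduled_spec_tokens == 2  # slots beyond the "real" next token

spec_token_ids = [PLACEHOLDER_TOKEN_ID] * 2
# Had the token budget trimmed num_new_tokens to 2, only one placeholder
# would survive the truncation below.
if len(spec_token_ids) > num_scheduled_spec_tokens:
    spec_token_ids = spec_token_ids[:num_scheduled_spec_tokens]
assert spec_token_ids == [PLACEHOLDER_TOKEN_ID, PLACEHOLDER_TOKEN_ID]
```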