[v0.16.0][P/D][Bugfix] Support ALL D-Nodes in fullgraph when running MTP in PD (#6948)
### What this PR does / why we need it?

Fixes a bug in the v0.16.0 `recompute_scheduler` so that all D-nodes can run in full graph mode when MTP speculative decoding is used in a PD (prefill/decode disaggregated) deployment. The fix mirrors https://github.com/vllm-project/vllm-ascend/pull/5472.

Signed-off-by: chenmenglong <chenmenglong1@huawei.com>
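For context, the core idea of the fix: with MTP speculative decoding, the decode-node model runner captures a full graph sized for `1 + num_spec_tokens` input tokens per request, but a request whose KV cache was just pulled from a prefill node arrives with no draft tokens yet, so its first decode step falls outside the captured shape. Padding `spec_token_ids` with placeholder tokens keeps the shape stable. Below is a minimal, dependency-free sketch of that idea; the helper `pad_spec_tokens` is hypothetical, and the value `-1` for `PLACEHOLDER_TOKEN_ID` is an assumption mirroring `vllm.v1.sample.rejection_sampler`, not a copy of the actual code.

```python
PLACEHOLDER_TOKEN_ID = -1  # assumed value, mirroring vllm.v1.sample.rejection_sampler

def pad_spec_tokens(spec_token_ids: list[int], num_spec_tokens: int) -> list[int]:
    """Pad draft tokens so every decode step schedules 1 + num_spec_tokens tokens."""
    padding = num_spec_tokens - len(spec_token_ids)
    return spec_token_ids + [PLACEHOLDER_TOKEN_ID] * max(0, padding)

# A request that just arrived on the D-node has no draft tokens yet:
assert pad_spec_tokens([], 2) == [-1, -1]
# Later steps with a real MTP draft token keep the same length:
assert pad_spec_tokens([12345], 2) == [12345, -1]
```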
```diff
@@ -19,7 +19,7 @@
 from __future__ import annotations
 
 import time
-from collections import defaultdict
+from collections import defaultdict, deque
 from dataclasses import dataclass, fields
 
 from vllm.config import SchedulerConfig, VllmConfig
```
```diff
@@ -37,7 +37,8 @@ from vllm.v1.core.sched.utils import remove_all
 from vllm.v1.engine import EngineCoreEventType, EngineCoreOutput, EngineCoreOutputs, FinishReason
 from vllm.v1.metrics.perf import PerfStats
 from vllm.v1.outputs import ModelRunnerOutput
-from vllm.v1.request import Request, RequestStatus
+from vllm.v1.request import Request, RequestStatus, StreamingUpdate
+from vllm.v1.sample.rejection_sampler import PLACEHOLDER_TOKEN_ID
 from vllm.v1.spec_decode.metrics import SpecDecodingStats
 from vllm.v1.utils import ConstantList, record_function_or_nullcontext
```
```diff
@@ -80,6 +81,43 @@ class RecomputeSchedulerOutput(SchedulerOutput):
 class RecomputeScheduler(Scheduler):
     running: list[Request]
 
+    def __init__(self, *args, **kwargs):
+        super().__init__(*args, **kwargs)
+        # When is_mtp_kv_consumer is true, we will fill request.spec_token_ids
+        # with placeholder tokens to enable full graph when decode nodes pull
+        # the KV cache of one request from prefill nodes.
+        self.is_mtp_kv_consumer = (
+            self.vllm_config.speculative_config
+            and self.vllm_config.kv_transfer_config
+            and self.vllm_config.kv_transfer_config.is_kv_consumer
+        )
+
+    def add_request(self, request: Request) -> None:
+        existing = self.requests.get(request.request_id)
+        if existing is not None:
+            update = StreamingUpdate.from_request(request)
+            if existing.status != RequestStatus.WAITING_FOR_STREAMING_REQ:
+                assert existing.streaming_queue is not None, "duplicate request id"
+                # Queue next input chunk (or finished sentinel).
+                existing.streaming_queue.append(update)
+            elif update is not None:
+                # Commence next input chunk.
+                self._update_request_as_session(existing, update)
+            else:
+                # Streaming-input session finished.
+                self.finish_requests(request.request_id, RequestStatus.FINISHED_ABORTED)
+        else:
+            if request.resumable:
+                request.streaming_queue = deque()
+            # Fill in placeholder tokens to enable full graph compatibility. Without
+            # placeholders, graph matching may fail, forcing eager mode execution.
+            if self.is_mtp_kv_consumer:
+                request.spec_token_ids = [PLACEHOLDER_TOKEN_ID] * self.num_spec_tokens
+            self.waiting.add_request(request)
+            self.requests[request.request_id] = request
+            if self.log_stats:
+                request.record_event(EngineCoreEventType.QUEUED)
+
     def schedule(self) -> RecomputeSchedulerOutput:
         # NOTE(woosuk) on the scheduling algorithm:
         # There's no "decoding phase" nor "prefill phase" in the scheduler.
```
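The `is_mtp_kv_consumer` condition above only fires on the decode side of a disaggregated deployment, and only when speculative decoding is configured. Here is a dependency-free sketch of that check; the `SimpleNamespace` objects are hypothetical stand-ins for vLLM's `SpeculativeConfig` and `KVTransferConfig` (whose `is_kv_consumer` property reflects the node's KV-transfer role).

```python
from types import SimpleNamespace

def is_mtp_kv_consumer(vllm_config) -> bool:
    # Same condition as the diff above, coerced to bool for clarity.
    return bool(
        vllm_config.speculative_config
        and vllm_config.kv_transfer_config
        and vllm_config.kv_transfer_config.is_kv_consumer
    )

decode_node = SimpleNamespace(
    speculative_config=SimpleNamespace(num_speculative_tokens=2),  # MTP enabled
    kv_transfer_config=SimpleNamespace(is_kv_consumer=True),       # D-node
)
prefill_node = SimpleNamespace(
    speculative_config=SimpleNamespace(num_speculative_tokens=2),
    kv_transfer_config=SimpleNamespace(is_kv_consumer=False),      # P-node
)
assert is_mtp_kv_consumer(decode_node) is True
assert is_mtp_kv_consumer(prefill_node) is False
```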
```diff
@@ -408,7 +446,10 @@ class RecomputeScheduler(Scheduler):
             # We use `request.num_tokens` instead of
             # `request.num_prompt_tokens` to consider the resumed
             # requests, which have output tokens.
-            num_new_tokens = request.num_tokens - num_computed_tokens
+            if self.is_mtp_kv_consumer:
+                num_new_tokens = request.num_tokens_with_spec - num_computed_tokens
+            else:
+                num_new_tokens = request.num_tokens - num_computed_tokens
             threshold = self.scheduler_config.long_prefill_token_threshold
             if 0 < threshold < num_new_tokens:
                 num_new_tokens = threshold
```
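A worked example of why the count must include the spec placeholders. All numbers are illustrative, and the exact bookkeeping of `num_computed_tokens` for a remote-KV request may differ from this simplification.

```python
# Illustrative numbers for a D-node request whose prompt KV was pulled
# from the P-node (all values hypothetical).
num_tokens = 100           # prompt length (request.num_tokens)
num_computed_tokens = 99   # remote KV covers all but the final position
num_spec_tokens = 2        # MTP draft length
num_tokens_with_spec = num_tokens + num_spec_tokens  # placeholders filled in add_request

# With the fix: schedule 1 real token + 2 (placeholder) draft tokens,
# matching the 1 + num_spec_tokens shape the full graph was captured for.
assert num_tokens_with_spec - num_computed_tokens == 3

# Without the fix the step would schedule only 1 token, miss the captured
# shape, and fall back to eager-mode execution.
assert num_tokens - num_computed_tokens == 1
```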
```diff
@@ -509,6 +550,25 @@ class RecomputeScheduler(Scheduler):
                 request.status = RequestStatus.WAITING_FOR_REMOTE_KVS
                 continue
 
+            # For spec_token_ids, the waiting queue has the same processing
+            # as the running queue.
+            if self.is_mtp_kv_consumer and request.spec_token_ids:
+                num_scheduled_spec_tokens = (
+                    num_new_tokens
+                    + request.num_computed_tokens
+                    - request.num_tokens
+                    - request.num_output_placeholders
+                )
+                if num_scheduled_spec_tokens > 0:
+                    spec_token_ids = request.spec_token_ids
+                    if len(spec_token_ids) > num_scheduled_spec_tokens:
+                        spec_token_ids = spec_token_ids[:num_scheduled_spec_tokens]
+                    scheduled_spec_decode_tokens[request.request_id] = spec_token_ids
+
+                # New spec tokens will be set in `update_draft_token_ids` before the
+                # next step when applicable.
+                request.spec_token_ids = []
+
             self.running.append(request)
             if self.log_stats:
                 request.record_event(EngineCoreEventType.SCHEDULED, scheduled_timestamp)
```
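Continuing the same illustrative numbers, the waiting-queue branch above derives how many of the padded draft tokens actually ride along this step:

```python
# Values carried over from the hypothetical example above.
num_new_tokens = 3
num_computed_tokens = 99
num_tokens = 100
num_output_placeholders = 0  # no async-scheduling output placeholders here

num_scheduled_spec_tokens = (
    num_new_tokens + num_computed_tokens - num_tokens - num_output_placeholders
)
# 3 + 99 - 100 - 0 == 2: both placeholder draft tokens are scheduled
# alongside the one real token, so the step stays at the full-graph size.
assert num_scheduled_spec_tokens == 2
```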