[v0.16.0][P/D][Bugfix] Support ALL D-Nodes in fullgraph when running MTP in PD (#6948)
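### What this PR does / why we need it? Fixes the v0.16.0 `recompute_scheduler` so that all decode nodes (D-nodes) can run in full-graph mode when MTP is used in a prefill/decode (P/D) deployment, in the same way as https://github.com/vllm-project/vllm-ascend/pull/5472. Signed-off-by: chenmenglong <chenmenglong1@huawei.com>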
This commit is contained in:
@@ -19,7 +19,7 @@
|
|||||||
from __future__ import annotations
|
from __future__ import annotations
|
||||||
|
|
||||||
import time
|
import time
|
||||||
from collections import defaultdict
|
from collections import defaultdict, deque
|
||||||
from dataclasses import dataclass, fields
|
from dataclasses import dataclass, fields
|
||||||
|
|
||||||
from vllm.config import SchedulerConfig, VllmConfig
|
from vllm.config import SchedulerConfig, VllmConfig
|
||||||
@@ -37,7 +37,8 @@ from vllm.v1.core.sched.utils import remove_all
|
|||||||
from vllm.v1.engine import EngineCoreEventType, EngineCoreOutput, EngineCoreOutputs, FinishReason
|
from vllm.v1.engine import EngineCoreEventType, EngineCoreOutput, EngineCoreOutputs, FinishReason
|
||||||
from vllm.v1.metrics.perf import PerfStats
|
from vllm.v1.metrics.perf import PerfStats
|
||||||
from vllm.v1.outputs import ModelRunnerOutput
|
from vllm.v1.outputs import ModelRunnerOutput
|
||||||
from vllm.v1.request import Request, RequestStatus
|
from vllm.v1.request import Request, RequestStatus, StreamingUpdate
|
||||||
|
from vllm.v1.sample.rejection_sampler import PLACEHOLDER_TOKEN_ID
|
||||||
from vllm.v1.spec_decode.metrics import SpecDecodingStats
|
from vllm.v1.spec_decode.metrics import SpecDecodingStats
|
||||||
from vllm.v1.utils import ConstantList, record_function_or_nullcontext
|
from vllm.v1.utils import ConstantList, record_function_or_nullcontext
|
||||||
|
|
||||||
@@ -80,6 +81,43 @@ class RecomputeSchedulerOutput(SchedulerOutput):
|
|||||||
class RecomputeScheduler(Scheduler):
|
class RecomputeScheduler(Scheduler):
|
||||||
running: list[Request]
|
running: list[Request]
|
||||||
|
|
||||||
|
def __init__(self, *args, **kwargs):
    """Initialize the base scheduler and cache the MTP KV-consumer flag.

    All positional and keyword arguments are forwarded unchanged to the
    parent scheduler's constructor.
    """
    super().__init__(*args, **kwargs)
    # When is_mtp_kv_consumer is true, we will fill request.spec_token_ids
    # with placeholder tokens to enable full graph when decode nodes pull
    # the KV cache of one request from prefill nodes.
    #
    # `and` returns one of its operands rather than a bool, so the raw
    # expression would store either None or a config object on the
    # scheduler. Coerce it so the `is_*` flag is always a real bool;
    # truthiness (and therefore every `if self.is_mtp_kv_consumer` check)
    # is unchanged.
    self.is_mtp_kv_consumer = bool(
        self.vllm_config.speculative_config
        and self.vllm_config.kv_transfer_config
        and self.vllm_config.kv_transfer_config.is_kv_consumer
    )
|
||||||
|
|
||||||
|
def add_request(self, request: Request) -> None:
    """Register a new request, or route a follow-up for a known request id.

    A request whose id is not yet tracked is placed on the waiting queue and
    recorded in ``self.requests``. A request whose id is already tracked is
    treated as a streaming-input update for the existing request: it is
    queued, applied immediately, or used to finish the session.
    """
    req_id = request.request_id
    tracked = self.requests.get(req_id)

    if tracked is None:
        # First time this id is seen.
        if request.resumable:
            request.streaming_queue = deque()
        # Fill in placeholder tokens to enable full graph compatibility.
        # Without placeholders, graph matching may fail, forcing eager mode
        # execution.
        if self.is_mtp_kv_consumer:
            request.spec_token_ids = [PLACEHOLDER_TOKEN_ID] * self.num_spec_tokens
        self.waiting.add_request(request)
        self.requests[req_id] = request
        if self.log_stats:
            request.record_event(EngineCoreEventType.QUEUED)
        return

    # The id is already tracked: interpret the request as a streaming update.
    update = StreamingUpdate.from_request(request)
    if tracked.status != RequestStatus.WAITING_FOR_STREAMING_REQ:
        assert tracked.streaming_queue is not None, "duplicate request id"
        # Not ready for more input yet: queue the next input chunk (or the
        # finished sentinel) for later.
        tracked.streaming_queue.append(update)
        return

    if update is not None:
        # Commence the next input chunk right away.
        self._update_request_as_session(tracked, update)
        return

    # No further update: the streaming-input session has finished.
    self.finish_requests(req_id, RequestStatus.FINISHED_ABORTED)
|
||||||
|
|
||||||
def schedule(self) -> RecomputeSchedulerOutput:
|
def schedule(self) -> RecomputeSchedulerOutput:
|
||||||
# NOTE(woosuk) on the scheduling algorithm:
|
# NOTE(woosuk) on the scheduling algorithm:
|
||||||
# There's no "decoding phase" nor "prefill phase" in the scheduler.
|
# There's no "decoding phase" nor "prefill phase" in the scheduler.
|
||||||
@@ -408,7 +446,10 @@ class RecomputeScheduler(Scheduler):
|
|||||||
# We use `request.num_tokens` instead of
|
# We use `request.num_tokens` instead of
|
||||||
# `request.num_prompt_tokens` to consider the resumed
|
# `request.num_prompt_tokens` to consider the resumed
|
||||||
# requests, which have output tokens.
|
# requests, which have output tokens.
|
||||||
num_new_tokens = request.num_tokens - num_computed_tokens
|
if self.is_mtp_kv_consumer:
|
||||||
|
num_new_tokens = request.num_tokens_with_spec - num_computed_tokens
|
||||||
|
else:
|
||||||
|
num_new_tokens = request.num_tokens - num_computed_tokens
|
||||||
threshold = self.scheduler_config.long_prefill_token_threshold
|
threshold = self.scheduler_config.long_prefill_token_threshold
|
||||||
if 0 < threshold < num_new_tokens:
|
if 0 < threshold < num_new_tokens:
|
||||||
num_new_tokens = threshold
|
num_new_tokens = threshold
|
||||||
@@ -509,6 +550,25 @@ class RecomputeScheduler(Scheduler):
|
|||||||
request.status = RequestStatus.WAITING_FOR_REMOTE_KVS
|
request.status = RequestStatus.WAITING_FOR_REMOTE_KVS
|
||||||
continue
|
continue
|
||||||
|
|
||||||
|
# For spec_token_ids, the waiting queue has the same processing
|
||||||
|
# as the running queue.
|
||||||
|
if self.is_mtp_kv_consumer and request.spec_token_ids:
|
||||||
|
num_scheduled_spec_tokens = (
|
||||||
|
num_new_tokens
|
||||||
|
+ request.num_computed_tokens
|
||||||
|
- request.num_tokens
|
||||||
|
- request.num_output_placeholders
|
||||||
|
)
|
||||||
|
if num_scheduled_spec_tokens > 0:
|
||||||
|
spec_token_ids = request.spec_token_ids
|
||||||
|
if len(spec_token_ids) > num_scheduled_spec_tokens:
|
||||||
|
spec_token_ids = spec_token_ids[:num_scheduled_spec_tokens]
|
||||||
|
scheduled_spec_decode_tokens[request.request_id] = spec_token_ids
|
||||||
|
|
||||||
|
# New spec tokens will be set in `update_draft_token_ids` before the
|
||||||
|
# next step when applicable.
|
||||||
|
request.spec_token_ids = []
|
||||||
|
|
||||||
self.running.append(request)
|
self.running.append(request)
|
||||||
if self.log_stats:
|
if self.log_stats:
|
||||||
request.record_event(EngineCoreEventType.SCHEDULED, scheduled_timestamp)
|
request.record_event(EngineCoreEventType.SCHEDULED, scheduled_timestamp)
|
||||||
|
|||||||
Reference in New Issue
Block a user