diff --git a/vllm_ascend/core/recompute_scheduler.py b/vllm_ascend/core/recompute_scheduler.py index c4f39f02..b83ced7e 100644 --- a/vllm_ascend/core/recompute_scheduler.py +++ b/vllm_ascend/core/recompute_scheduler.py @@ -19,7 +19,7 @@ from __future__ import annotations import time -from collections import defaultdict +from collections import defaultdict, deque from dataclasses import dataclass, fields from vllm.config import SchedulerConfig, VllmConfig @@ -37,7 +37,8 @@ from vllm.v1.core.sched.utils import remove_all from vllm.v1.engine import EngineCoreEventType, EngineCoreOutput, EngineCoreOutputs, FinishReason from vllm.v1.metrics.perf import PerfStats from vllm.v1.outputs import ModelRunnerOutput -from vllm.v1.request import Request, RequestStatus +from vllm.v1.request import Request, RequestStatus, StreamingUpdate +from vllm.v1.sample.rejection_sampler import PLACEHOLDER_TOKEN_ID from vllm.v1.spec_decode.metrics import SpecDecodingStats from vllm.v1.utils import ConstantList, record_function_or_nullcontext @@ -80,6 +81,43 @@ class RecomputeSchedulerOutput(SchedulerOutput): class RecomputeScheduler(Scheduler): running: list[Request] + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + # When is_mtp_kv_consumer is true, we will fill request.spec_token_ids + # with placeholder tokens to enable full graph when decode nodes pull + # the KV cache of one request from prefill nodes. + self.is_mtp_kv_consumer = ( + self.vllm_config.speculative_config + and self.vllm_config.kv_transfer_config + and self.vllm_config.kv_transfer_config.is_kv_consumer + ) + + def add_request(self, request: Request) -> None: + existing = self.requests.get(request.request_id) + if existing is not None: + update = StreamingUpdate.from_request(request) + if existing.status != RequestStatus.WAITING_FOR_STREAMING_REQ: + assert existing.streaming_queue is not None, "duplicate request id" + # Queue next input chunk (or finished sentinel). 
+ existing.streaming_queue.append(update) elif update is not None: + # Commence next input chunk. + self._update_request_as_session(existing, update) else: + # Streaming-input session finished. + self.finish_requests(request.request_id, RequestStatus.FINISHED_ABORTED) else: + if request.resumable: + request.streaming_queue = deque() + # Fill in placeholder tokens to enable full graph compatibility. Without + # placeholders, graph matching may fail, forcing eager mode execution. + if self.is_mtp_kv_consumer: + request.spec_token_ids = [PLACEHOLDER_TOKEN_ID] * self.num_spec_tokens + self.waiting.add_request(request) + self.requests[request.request_id] = request + if self.log_stats: + request.record_event(EngineCoreEventType.QUEUED) + def schedule(self) -> RecomputeSchedulerOutput: # NOTE(woosuk) on the scheduling algorithm: # There's no "decoding phase" nor "prefill phase" in the scheduler. @@ -408,7 +446,10 @@ class RecomputeScheduler(Scheduler): # We use `request.num_tokens` instead of # `request.num_prompt_tokens` to consider the resumed # requests, which have output tokens. - num_new_tokens = request.num_tokens - num_computed_tokens + if self.is_mtp_kv_consumer: + num_new_tokens = request.num_tokens_with_spec - num_computed_tokens + else: + num_new_tokens = request.num_tokens - num_computed_tokens threshold = self.scheduler_config.long_prefill_token_threshold if 0 < threshold < num_new_tokens: num_new_tokens = threshold @@ -509,6 +550,25 @@ class RecomputeScheduler(Scheduler): request.status = RequestStatus.WAITING_FOR_REMOTE_KVS continue + # For spec_token_ids, the waiting queue has the same processing + # as the running queue.
+ if self.is_mtp_kv_consumer and request.spec_token_ids: + num_scheduled_spec_tokens = ( + num_new_tokens + + request.num_computed_tokens + - request.num_tokens + - request.num_output_placeholders + ) + if num_scheduled_spec_tokens > 0: + spec_token_ids = request.spec_token_ids + if len(spec_token_ids) > num_scheduled_spec_tokens: + spec_token_ids = spec_token_ids[:num_scheduled_spec_tokens] + scheduled_spec_decode_tokens[request.request_id] = spec_token_ids + + # New spec tokens will be set in `update_draft_token_ids` before the + # next step when applicable. + request.spec_token_ids = [] + self.running.append(request) if self.log_stats: request.record_event(EngineCoreEventType.SCHEDULED, scheduled_timestamp)