diff --git a/vllm_ascend/core/recompute_scheduler.py b/vllm_ascend/core/recompute_scheduler.py
index 90ab413a..a3358558 100644
--- a/vllm_ascend/core/recompute_scheduler.py
+++ b/vllm_ascend/core/recompute_scheduler.py
@@ -40,6 +40,7 @@ from vllm.v1.engine import EngineCoreEventType, EngineCoreOutput, EngineCoreOutp
 from vllm.v1.metrics.perf import PerfStats
 from vllm.v1.outputs import ModelRunnerOutput
 from vllm.v1.request import Request, RequestStatus
+from vllm.v1.sample.rejection_sampler import PLACEHOLDER_TOKEN_ID
 from vllm.v1.spec_decode.metrics import SpecDecodingStats
 from vllm.v1.utils import ConstantList, record_function_or_nullcontext
 
@@ -83,6 +84,27 @@ class RecomputeSchedulerOutput(SchedulerOutput):
 class RecomputeScheduler(Scheduler):
     running: list[Request]
 
+    def __init__(self, *args, **kwargs):
+        super().__init__(*args, **kwargs)
+        # When is_mtp_kv_consumer is true, we will fill request.spec_token_ids
+        # with placeholder tokens to enable full graph when decode nodes pull
+        # the KV cache of one request from prefill nodes.
+        self.is_mtp_kv_consumer = (
+            self.vllm_config.speculative_config
+            and self.vllm_config.kv_transfer_config
+            and self.vllm_config.kv_transfer_config.is_kv_consumer
+        )
+
+    def add_request(self, request: Request) -> None:
+        # Fill in placeholder tokens to enable full graph compatibility. Without
+        # placeholders, graph matching may fail, forcing eager mode execution.
+        if self.is_mtp_kv_consumer:
+            request.spec_token_ids = [PLACEHOLDER_TOKEN_ID] * self.num_spec_tokens
+        self.waiting.add_request(request)
+        self.requests[request.request_id] = request
+        if self.log_stats:
+            request.record_event(EngineCoreEventType.QUEUED)
+
     def schedule(self) -> RecomputeSchedulerOutput:
         # NOTE(woosuk) on the scheduling algorithm:
         # There's no "decoding phase" nor "prefill phase" in the scheduler.
@@ -391,7 +413,10 @@ class RecomputeScheduler(Scheduler):
                 # We use `request.num_tokens` instead of
                 # `request.num_prompt_tokens` to consider the resumed
                 # requests, which have output tokens.
-                num_new_tokens = request.num_tokens - num_computed_tokens
+                if self.is_mtp_kv_consumer:
+                    num_new_tokens = request.num_tokens_with_spec - num_computed_tokens
+                else:
+                    num_new_tokens = request.num_tokens - num_computed_tokens
                 threshold = self.scheduler_config.long_prefill_token_threshold
                 if 0 < threshold < num_new_tokens:
                     num_new_tokens = threshold
@@ -478,6 +503,23 @@ class RecomputeScheduler(Scheduler):
 
                 self._update_connector_prefix_cache_stats(request)
 
+                # For spec_token_ids, the waiting queue has the same processing
+                # as the running queue.
+                if self.is_mtp_kv_consumer and request.spec_token_ids:
+                    num_scheduled_spec_tokens = (
+                        num_new_tokens
+                        + request.num_computed_tokens
+                        - request.num_tokens
+                        - request.num_output_placeholders
+                    )
+                    if num_scheduled_spec_tokens > 0:
+                        # Trim spec_token_ids list to num_scheduled_spec_tokens.
+                        del request.spec_token_ids[num_scheduled_spec_tokens:]
+                        scheduled_spec_decode_tokens[request.request_id] = request.spec_token_ids
+                    # New spec tokens will be set in `update_draft_token_ids` before the
+                    # next step when applicable.
+                    request.spec_token_ids = []
+
                 self.running.append(request)
                 if self.log_stats:
                     request.record_event(EngineCoreEventType.SCHEDULED, scheduled_timestamp)
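
Note (not part of the patch): a minimal standalone sketch of the spec-token trimming arithmetic used above in the waiting-queue path. The helper name, the -1 placeholder value, and the example numbers are illustrative assumptions, not taken from the patch.

    # Sketch of how many placeholder draft tokens survive a scheduling step.
    # Assumes PLACEHOLDER_TOKEN_ID == -1 (an assumption for this sketch).
    PLACEHOLDER_TOKEN_ID = -1

    def trim_spec_tokens(spec_token_ids, num_new_tokens, num_computed_tokens,
                         num_tokens, num_output_placeholders):
        # Same formula as the patch: tokens granted beyond the request's real
        # (prompt + output) tokens are the draft slots scheduled this step.
        num_scheduled = (num_new_tokens + num_computed_tokens
                         - num_tokens - num_output_placeholders)
        return spec_token_ids[:num_scheduled] if num_scheduled > 0 else []

    # Example: 100-token prompt, nothing computed yet, 2 placeholder drafts,
    # and the scheduler grants 102 tokens -> both placeholder slots are kept,
    # so the decode graph shape matches a real MTP speculative-decoding step.
    spec = [PLACEHOLDER_TOKEN_ID] * 2
    assert trim_spec_tokens(spec, 102, 0, 100, 0) == [-1, -1]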