From dfa9ff7f2abf8e41fe2b681b6dfed697ce58bcfb Mon Sep 17 00:00:00 2001
From: wangxiaoteng888 <56506195+wangxiaoteng888@users.noreply.github.com>
Date: Mon, 2 Mar 2026 23:24:03 +0800
Subject: [PATCH] [P/D][v0.16.0] Adapt RecomputeScheduler to vLLM 0.16.0 (#6898)

### What this PR does / why we need it?
Adapt the recompute feature, in which the D (decode) node forwards recompute
requests to the P (prefill) node, to vLLM 0.16.0.

### Does this PR introduce _any_ user-facing change?
No

### How was this patch tested?
By CI

- vLLM version: v0.16.0
- vLLM main: https://github.com/vllm-project/vllm/commit/15d76f74e2fdb12a95ea00f0ca283acf6219a2b7

---------

Signed-off-by: wangxiaoteng
---
 vllm_ascend/core/recompute_scheduler.py | 162 +++++++++++-------------
 1 file changed, 74 insertions(+), 88 deletions(-)

diff --git a/vllm_ascend/core/recompute_scheduler.py b/vllm_ascend/core/recompute_scheduler.py
index a3358558..b1aaff7c 100644
--- a/vllm_ascend/core/recompute_scheduler.py
+++ b/vllm_ascend/core/recompute_scheduler.py
@@ -22,7 +22,6 @@
 import time
 from collections import defaultdict
 from dataclasses import dataclass, fields
-import numpy as np
 from vllm._bc_linter import bc_linter_include
 from vllm.config import SchedulerConfig, VllmConfig
 from vllm.distributed.ec_transfer.ec_connector.base import ECConnectorMetadata
@@ -40,7 +39,6 @@
 from vllm.v1.engine import EngineCoreEventType, EngineCoreOutput, EngineCoreOutputs
 from vllm.v1.metrics.perf import PerfStats
 from vllm.v1.outputs import ModelRunnerOutput
 from vllm.v1.request import Request, RequestStatus
-from vllm.v1.sample.rejection_sampler import PLACEHOLDER_TOKEN_ID
 from vllm.v1.spec_decode.metrics import SpecDecodingStats
 from vllm.v1.utils import ConstantList, record_function_or_nullcontext
@@ -84,27 +82,6 @@ class RecomputeSchedulerOutput(SchedulerOutput):
 class RecomputeScheduler(Scheduler):
     running: list[Request]
 
-    def __init__(self, *args, **kwargs):
-        super().__init__(*args, **kwargs)
-        # When is_mtp_kv_consumer is true, we will fill request.spec_token_ids
-        # with placeholder tokens to enable full graph when decode nodes pull
-        # the KV cache of one request from prefill nodes.
-        self.is_mtp_kv_consumer = (
-            self.vllm_config.speculative_config
-            and self.vllm_config.kv_transfer_config
-            and self.vllm_config.kv_transfer_config.is_kv_consumer
-        )
-
-    def add_request(self, request: Request) -> None:
-        # Fill in placeholder tokens to enable full graph compatibility. Without
-        # placeholders, graph matching may fail, forcing eager mode execution.
-        if self.is_mtp_kv_consumer:
-            request.spec_token_ids = [PLACEHOLDER_TOKEN_ID] * self.num_spec_tokens
-        self.waiting.add_request(request)
-        self.requests[request.request_id] = request
-        if self.log_stats:
-            request.record_event(EngineCoreEventType.QUEUED)
-
     def schedule(self) -> RecomputeSchedulerOutput:
         # NOTE(woosuk) on the scheduling algorithm:
         # There's no "decoding phase" nor "prefill phase" in the scheduler.
@@ -185,6 +162,9 @@ class RecomputeScheduler(Scheduler):
                 shift_computed_tokens=1 if self.use_eagle else 0,
             )
 
+            if self.need_mamba_block_aligned_split:
+                num_new_tokens = self._mamba_block_aligned_split(request, num_new_tokens)
+
             if num_new_tokens == 0:
                 # The request cannot be scheduled because one of the following
                 # reasons:
@@ -195,6 +175,8 @@ class RecomputeScheduler(Scheduler):
                 #    its max_total_tokens or max_model_len.
                 # 2. The encoder budget is exhausted.
                 # 3. The encoder cache is exhausted.
+                # 4. Insufficient budget for a block-aligned chunk in hybrid
+                #    models with mamba cache mode "align".
                 # NOTE(woosuk): Here, by doing `continue` instead of `break`,
                 # we do not strictly follow the FCFS scheduling policy and
                 # allow the lower-priority requests to be scheduled.
@@ -237,12 +219,12 @@ class RecomputeScheduler(Scheduler):
                        )
                    self.running.remove(preempted_req)
                    if preempted_req in scheduled_running_reqs:
+                        preempted_req_id = preempted_req.request_id
                         scheduled_running_reqs.remove(preempted_req)
-                        token_budget += num_scheduled_tokens[preempted_req.request_id]
-                        req_to_new_blocks.pop(preempted_req.request_id)
-                        num_scheduled_tokens.pop(preempted_req.request_id)
-                        scheduled_spec_decode_tokens.pop(preempted_req.request_id, None)
-                        preempted_encoder_inputs = scheduled_encoder_inputs.pop(preempted_req.request_id, None)
+                        token_budget += num_scheduled_tokens.pop(preempted_req_id)
+                        req_to_new_blocks.pop(preempted_req_id)
+                        scheduled_spec_decode_tokens.pop(preempted_req_id, None)
+                        preempted_encoder_inputs = scheduled_encoder_inputs.pop(preempted_req_id, None)
                         if preempted_encoder_inputs:
                             # Restore encoder compute budget if the preempted
                             # request had encoder inputs scheduled in this step.
@@ -266,8 +248,9 @@ class RecomputeScheduler(Scheduler):
 
             # Schedule the request.
             scheduled_running_reqs.append(request)
-            req_to_new_blocks[request.request_id] = new_blocks
-            num_scheduled_tokens[request.request_id] = num_new_tokens
+            request_id = request.request_id
+            req_to_new_blocks[request_id] = new_blocks
+            num_scheduled_tokens[request_id] = num_new_tokens
             token_budget -= num_new_tokens
             req_index += 1
 
@@ -277,16 +260,18 @@ class RecomputeScheduler(Scheduler):
                 num_new_tokens + request.num_computed_tokens - request.num_tokens - request.num_output_placeholders
             )
             if num_scheduled_spec_tokens > 0:
-                # Trim spec_token_ids list to num_scheduled_spec_tokens.
-                del request.spec_token_ids[num_scheduled_spec_tokens:]
-                scheduled_spec_decode_tokens[request.request_id] = request.spec_token_ids
+                spec_token_ids = request.spec_token_ids
+                if len(spec_token_ids) > num_scheduled_spec_tokens:
+                    spec_token_ids = spec_token_ids[:num_scheduled_spec_tokens]
+                scheduled_spec_decode_tokens[request.request_id] = spec_token_ids
+
                 # New spec tokens will be set in `update_draft_token_ids` before the
                 # next step when applicable.
                 request.spec_token_ids = []
 
             # Encoder-related.
             if encoder_inputs_to_schedule:
-                scheduled_encoder_inputs[request.request_id] = encoder_inputs_to_schedule
+                scheduled_encoder_inputs[request_id] = encoder_inputs_to_schedule
                 # Allocate the encoder cache.
                 for i in encoder_inputs_to_schedule:
                     self.encoder_cache_manager.allocate(request, i)
@@ -318,6 +303,7 @@ class RecomputeScheduler(Scheduler):
                 break
 
             request = self.waiting.peek_request()
+            request_id = request.request_id
 
             # KVTransfer: skip request if still waiting for remote kvs.
             if request.status == RequestStatus.WAITING_FOR_REMOTE_KVS:
@@ -332,7 +318,7 @@ class RecomputeScheduler(Scheduler):
                 else:
                     logger.debug(
                         "%s is still in WAITING_FOR_REMOTE_KVS state.",
-                        request.request_id,
+                        request_id,
                     )
                     self.waiting.pop_request()
                     skipped_waiting_requests.prepend_request(request)
@@ -349,6 +335,13 @@ class RecomputeScheduler(Scheduler):
                     skipped_waiting_requests.prepend_request(request)
                     continue
 
+            # Streaming: skip request if still waiting for next streaming req.
+            if request.status == RequestStatus.WAITING_FOR_STREAMING_REQ:
+                assert not request.streaming_queue
+                self.waiting.pop_request()
+                skipped_waiting_requests.prepend_request(request)
+                continue
+
             # Check that adding the request still respects the max_loras
             # constraint.
             if (
@@ -366,6 +359,7 @@ class RecomputeScheduler(Scheduler):
 
             num_external_computed_tokens = 0
             load_kv_async = False
+            connector_prefix_cache_queries, connector_prefix_cache_hits = 0, 0
 
             # Get already-cached tokens.
             if request.num_computed_tokens == 0:
@@ -391,6 +385,9 @@ class RecomputeScheduler(Scheduler):
                     request.num_external_computed_tokens = ext_tokens
                     num_external_computed_tokens = ext_tokens
 
+                    connector_prefix_cache_queries = request.num_tokens - num_new_local_computed_tokens
+                    connector_prefix_cache_hits = num_external_computed_tokens
+
                 # Total computed tokens (local + external).
                 num_computed_tokens = num_new_local_computed_tokens + num_external_computed_tokens
             else:
@@ -413,10 +410,7 @@ class RecomputeScheduler(Scheduler):
             # We use `request.num_tokens` instead of
             # `request.num_prompt_tokens` to consider the resumed
             # requests, which have output tokens.
-            if self.is_mtp_kv_consumer:
-                num_new_tokens = request.num_tokens_with_spec - num_computed_tokens
-            else:
-                num_new_tokens = request.num_tokens - num_computed_tokens
+            num_new_tokens = request.num_tokens - num_computed_tokens
             threshold = self.scheduler_config.long_prefill_token_threshold
             if 0 < threshold < num_new_tokens:
                 num_new_tokens = threshold
@@ -449,6 +443,16 @@ class RecomputeScheduler(Scheduler):
                 # The request cannot be scheduled.
                 break
 
+            if self.need_mamba_block_aligned_split:
+                num_new_tokens = self._mamba_block_aligned_split(
+                    request,
+                    num_new_tokens,
+                    num_new_local_computed_tokens,
+                    num_external_computed_tokens,
+                )
+                if num_new_tokens == 0:
+                    break
+
             # Handles an edge case when P/D Disaggregation
             # is used with Spec Decoding where an
             # extra block gets allocated which
@@ -487,9 +491,15 @@ class RecomputeScheduler(Scheduler):
             if self.connector is not None:
                 self.connector.update_state_after_alloc(
                     request,
-                    self.kv_cache_manager.get_blocks(request.request_id),
+                    self.kv_cache_manager.get_blocks(request_id),
                     num_external_computed_tokens,
                 )
+            if self.connector_prefix_cache_stats is not None and connector_prefix_cache_queries != 0:
+                self.connector_prefix_cache_stats.record(
+                    num_tokens=connector_prefix_cache_queries,
+                    num_hits=connector_prefix_cache_hits,
+                    preempted=request.num_preemptions > 0,
+                )
 
             # Request was already popped from self.waiting
             # unless it was re-added above due to new_blocks being None.
@@ -501,25 +511,6 @@ class RecomputeScheduler(Scheduler):
                 request.status = RequestStatus.WAITING_FOR_REMOTE_KVS
                 continue
 
-            self._update_connector_prefix_cache_stats(request)
-
-            # For spec_token_ids, the waiting queue has the same processing
-            # as the running queue.
-            if self.is_mtp_kv_consumer and request.spec_token_ids:
-                num_scheduled_spec_tokens = (
-                    num_new_tokens
-                    + request.num_computed_tokens
-                    - request.num_tokens
-                    - request.num_output_placeholders
-                )
-                if num_scheduled_spec_tokens > 0:
-                    # Trim spec_token_ids list to num_scheduled_spec_tokens.
-                    del request.spec_token_ids[num_scheduled_spec_tokens:]
-                    scheduled_spec_decode_tokens[request.request_id] = request.spec_token_ids
-                    # New spec tokens will be set in `update_draft_token_ids` before the
-                    # next step when applicable.
-                    request.spec_token_ids = []
-
             self.running.append(request)
             if self.log_stats:
                 request.record_event(EngineCoreEventType.SCHEDULED, scheduled_timestamp)
@@ -532,8 +523,8 @@ class RecomputeScheduler(Scheduler):
             if self.lora_config and request.lora_request:
                 scheduled_loras.add(request.lora_request.lora_int_id)
 
-            req_to_new_blocks[request.request_id] = self.kv_cache_manager.get_blocks(request.request_id)
-            num_scheduled_tokens[request.request_id] = num_new_tokens
+            req_to_new_blocks[request_id] = self.kv_cache_manager.get_blocks(request_id)
+            num_scheduled_tokens[request_id] = num_new_tokens
             token_budget -= num_new_tokens
             request.status = RequestStatus.RUNNING
             request.num_computed_tokens = num_computed_tokens
@@ -542,7 +533,7 @@ class RecomputeScheduler(Scheduler):
             request.num_cached_tokens = num_computed_tokens
             # Encoder-related.
             if encoder_inputs_to_schedule:
-                scheduled_encoder_inputs[request.request_id] = encoder_inputs_to_schedule
+                scheduled_encoder_inputs[request_id] = encoder_inputs_to_schedule
                 # Allocate the encoder cache.
                 for i in encoder_inputs_to_schedule:
                     self.encoder_cache_manager.allocate(request, i)
@@ -573,8 +564,8 @@ class RecomputeScheduler(Scheduler):
         num_common_prefix_blocks = [0] * len(self.kv_cache_config.kv_cache_groups)
         with record_function_or_nullcontext("schedule: get_num_common_prefix_blocks"):
             if self.running:
-                any_request = self.running[0]
-                num_common_prefix_blocks = self.kv_cache_manager.get_num_common_prefix_blocks(any_request.request_id)
+                any_request_id = self.running[0].request_id
+                num_common_prefix_blocks = self.kv_cache_manager.get_num_common_prefix_blocks(any_request_id)
 
         # Construct the scheduler output.
         if self.use_v2_model_runner:
@@ -644,7 +635,7 @@ class RecomputeScheduler(Scheduler):
 
     def update_from_output(
         self,
-        scheduler_output: RecomputeSchedulerOutput,
+        scheduler_output: SchedulerOutput,
         model_runner_output: ModelRunnerOutput,
     ) -> dict[int, EngineCoreOutputs]:
         sampled_token_ids = model_runner_output.sampled_token_ids
@@ -700,17 +691,21 @@ class RecomputeScheduler(Scheduler):
                 # skip failed or rescheduled requests from KV load failure
                 continue
             request = self.requests.get(req_id)
-            if request is None:
+            if request is None or request.is_finished():
                 # The request is already finished. This can happen if the
                 # request is aborted while the model is executing it (e.g.,
-                # in pipeline parallelism).
+                # in pipeline parallelism or in async scheduling).
+                # NOTE(Kuntai): When delay_free_blocks=True (for async KV
+                # cache transfer in KV connector), the aborted request will not
+                # be set to None (in order to finish async KV transfer).
+                # In this case, we use is_finished() to check.
                 continue
 
             req_index = model_runner_output.req_id_to_index[req_id]
             generated_token_ids = sampled_token_ids[req_index] if sampled_token_ids else []
 
             scheduled_spec_token_ids = scheduler_output.scheduled_spec_decode_tokens.get(req_id)
-            if scheduled_spec_token_ids:
+            if scheduled_spec_token_ids and generated_token_ids:
                 num_draft_tokens = len(scheduled_spec_token_ids)
                 num_accepted = len(generated_token_ids) - 1
                 num_rejected = num_draft_tokens - num_accepted
@@ -749,27 +744,17 @@ class RecomputeScheduler(Scheduler):
                     stopped = True
 
             routed_experts = None
+            finish_reason = None
             if stopped:
-                if self.vllm_config.model_config.enable_return_routed_experts:
-                    kv_blocks = self.kv_cache_manager.get_blocks(request.request_id)
-                    block_ids = kv_blocks.get_block_ids()[0]
-                    num_tokens = request.num_tokens - 1
+                routed_experts = self._get_routed_experts(request)
 
-                    # compute slot mapping
-                    block_ids_array = np.array(block_ids, dtype=np.int32)
-                    num_blocks = len(block_ids)
-                    block_size = self.block_size
+                # Capture finish_reason BEFORE _handle_stopped_request, which may
+                # reset the status to WAITING for streaming requests that continue.
+                finish_reason = request.get_finished_reason()
+                finished = self._handle_stopped_request(request)
+                if finished:
+                    kv_transfer_params = self._free_request(request)
 
-                    # generate block offsets
-                    block_offsets = np.arange(0, block_size)
-
-                    # compute slot mapping: slot = block_id * block_size + offset
-                    slot_mapping = (
-                        block_offsets.reshape((1, block_size)) + block_ids_array.reshape((num_blocks, 1)) * block_size
-                    ).flatten()[:num_tokens]
-
-                    routed_experts = self.routed_experts_reader.get_routed_experts(indices=slot_mapping)
-                kv_transfer_params = self._free_request(request)
                 if status_before_stop == RequestStatus.RUNNING:
                     stopped_running_reqs.add(request)
                 else:
@@ -796,13 +781,13 @@ class RecomputeScheduler(Scheduler):
 
             # Get prompt logprobs for this request.
             prompt_logprobs_tensors = prompt_logprobs_dict.get(req_id)
-            if new_token_ids or pooler_output is not None or kv_transfer_params:
+            if new_token_ids or pooler_output is not None or kv_transfer_params or stopped:
                 # Add EngineCoreOutput for this Request.
                 outputs[request.client_index].append(
                     EngineCoreOutput(
                         request_id=req_id,
                         new_token_ids=new_token_ids,
-                        finish_reason=request.get_finished_reason(),
+                        finish_reason=finish_reason,
                         new_logprobs=new_logprobs,
                         new_prompt_logprobs_tensors=prompt_logprobs_tensors,
                         pooling_output=pooler_output,
@@ -811,6 +796,7 @@ class RecomputeScheduler(Scheduler):
                         kv_transfer_params=kv_transfer_params,
                         trace_headers=request.trace_headers,
                         num_cached_tokens=request.num_cached_tokens,
+                        num_external_computed_tokens=request.num_external_computed_tokens,
                         routed_experts=routed_experts,
                         num_nans_in_logits=request.num_nans_in_logits,
                     )