[P/D][v0.16.0] Adapt to RecomputeScheduler in vLLM 0.16.0 (#6898)

### What this PR does / why we need it?
Adapts the recompute feature to vLLM 0.16.0, in which the D (decode) node forwards recompute requests to the P (prefill) node.
### Does this PR introduce _any_ user-facing change?
No
### How was this patch tested?
By CI.
- vLLM version: v0.16.0
- vLLM main: 15d76f74e2

---------

Signed-off-by: wangxiaoteng <wangxiaoteng@huawei.com>
wangxiaoteng888 authored on 2026-03-02 23:24:03 +08:00; committed by GitHub
parent 5899438a86
commit dfa9ff7f2a


@@ -22,7 +22,6 @@ import time
 from collections import defaultdict
 from dataclasses import dataclass, fields
-import numpy as np
 from vllm._bc_linter import bc_linter_include
 from vllm.config import SchedulerConfig, VllmConfig
 from vllm.distributed.ec_transfer.ec_connector.base import ECConnectorMetadata
@@ -40,7 +39,6 @@ from vllm.v1.engine import EngineCoreEventType, EngineCoreOutput, EngineCoreOutp
 from vllm.v1.metrics.perf import PerfStats
 from vllm.v1.outputs import ModelRunnerOutput
 from vllm.v1.request import Request, RequestStatus
-from vllm.v1.sample.rejection_sampler import PLACEHOLDER_TOKEN_ID
 from vllm.v1.spec_decode.metrics import SpecDecodingStats
 from vllm.v1.utils import ConstantList, record_function_or_nullcontext
@@ -84,27 +82,6 @@ class RecomputeSchedulerOutput(SchedulerOutput):
 class RecomputeScheduler(Scheduler):
     running: list[Request]
 
-    def __init__(self, *args, **kwargs):
-        super().__init__(*args, **kwargs)
-        # When is_mtp_kv_consumer is true, we will fill request.spec_token_ids
-        # with placeholder tokens to enable full graph when decode nodes pull
-        # the KV cache of one request from prefill nodes.
-        self.is_mtp_kv_consumer = (
-            self.vllm_config.speculative_config
-            and self.vllm_config.kv_transfer_config
-            and self.vllm_config.kv_transfer_config.is_kv_consumer
-        )
-
-    def add_request(self, request: Request) -> None:
-        # Fill in placeholder tokens to enable full graph compatibility. Without
-        # placeholders, graph matching may fail, forcing eager mode execution.
-        if self.is_mtp_kv_consumer:
-            request.spec_token_ids = [PLACEHOLDER_TOKEN_ID] * self.num_spec_tokens
-        self.waiting.add_request(request)
-        self.requests[request.request_id] = request
-        if self.log_stats:
-            request.record_event(EngineCoreEventType.QUEUED)
-
     def schedule(self) -> RecomputeSchedulerOutput:
         # NOTE(woosuk) on the scheduling algorithm:
         # There's no "decoding phase" nor "prefill phase" in the scheduler.
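
For reference, the `__init__`/`add_request` override removed above (together with the `PLACEHOLDER_TOKEN_ID` import) implemented a simple padding trick: on an MTP KV-consumer (decode) node, every incoming request's `spec_token_ids` was pre-filled with placeholders so batch shapes stay fixed and a captured full graph keeps matching. A minimal self-contained sketch of that pattern, with a stand-in `Request` and sentinel rather than the real vLLM types:

```python
from dataclasses import dataclass, field

PLACEHOLDER_TOKEN_ID = -1  # stand-in sentinel; vLLM imports its own constant


@dataclass
class Request:  # simplified stand-in for vllm.v1.request.Request
    request_id: str
    spec_token_ids: list[int] = field(default_factory=list)


def pad_spec_tokens(request: Request, num_spec_tokens: int) -> None:
    # A fixed-length spec-token list keeps the decode batch shape constant,
    # so the captured graph can be replayed instead of falling back to eager.
    request.spec_token_ids = [PLACEHOLDER_TOKEN_ID] * num_spec_tokens


req = Request("req-0")
pad_spec_tokens(req, num_spec_tokens=2)
assert req.spec_token_ids == [PLACEHOLDER_TOKEN_ID] * 2
```
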
@@ -185,6 +162,9 @@ class RecomputeScheduler(Scheduler):
                 shift_computed_tokens=1 if self.use_eagle else 0,
             )
 
+            if self.need_mamba_block_aligned_split:
+                num_new_tokens = self._mamba_block_aligned_split(request, num_new_tokens)
+
             if num_new_tokens == 0:
                 # The request cannot be scheduled because one of the following
                 # reasons:
@@ -195,6 +175,8 @@ class RecomputeScheduler(Scheduler):
                 # its max_total_tokens or max_model_len.
                 # 2. The encoder budget is exhausted.
                 # 3. The encoder cache is exhausted.
+                # 4. Insufficient budget for a block-aligned chunk in hybrid
+                #    models with mamba cache mode "align".
                 # NOTE(woosuk): Here, by doing `continue` instead of `break`,
                 # we do not strictly follow the FCFS scheduling policy and
                 # allow the lower-priority requests to be scheduled.
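
The intent of `_mamba_block_aligned_split` can be illustrated with a rough stand-alone sketch: cap the scheduled chunk so it ends on a KV-block boundary, and return 0 when no full block fits within the budget (hence new reason 4 above). The rounding rule below is an assumption, not the actual implementation, and it ignores the final-chunk case where a partial trailing block must still be scheduled:

```python
def block_aligned_split(num_computed_tokens: int, num_new_tokens: int,
                        block_size: int) -> int:
    # Largest chunk that ends exactly on a block boundary.
    end = num_computed_tokens + num_new_tokens
    aligned_end = (end // block_size) * block_size
    # 0 means "no block-aligned chunk fits in this step's budget".
    return max(aligned_end - num_computed_tokens, 0)


assert block_aligned_split(0, 100, block_size=64) == 64  # trim 100 -> 64
assert block_aligned_split(64, 30, block_size=64) == 0   # cannot fill a block
```
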
@@ -237,12 +219,12 @@ class RecomputeScheduler(Scheduler):
                     )
                 self.running.remove(preempted_req)
                 if preempted_req in scheduled_running_reqs:
+                    preempted_req_id = preempted_req.request_id
                     scheduled_running_reqs.remove(preempted_req)
-                    token_budget += num_scheduled_tokens[preempted_req.request_id]
-                    req_to_new_blocks.pop(preempted_req.request_id)
-                    num_scheduled_tokens.pop(preempted_req.request_id)
-                    scheduled_spec_decode_tokens.pop(preempted_req.request_id, None)
-                    preempted_encoder_inputs = scheduled_encoder_inputs.pop(preempted_req.request_id, None)
+                    token_budget += num_scheduled_tokens.pop(preempted_req_id)
+                    req_to_new_blocks.pop(preempted_req_id)
+                    scheduled_spec_decode_tokens.pop(preempted_req_id, None)
+                    preempted_encoder_inputs = scheduled_encoder_inputs.pop(preempted_req_id, None)
                     if preempted_encoder_inputs:
                         # Restore encoder compute budget if the preempted
                         # request had encoder inputs scheduled in this step.
@@ -266,8 +248,9 @@ class RecomputeScheduler(Scheduler):
             # Schedule the request.
             scheduled_running_reqs.append(request)
-            req_to_new_blocks[request.request_id] = new_blocks
-            num_scheduled_tokens[request.request_id] = num_new_tokens
+            request_id = request.request_id
+            req_to_new_blocks[request_id] = new_blocks
+            num_scheduled_tokens[request_id] = num_new_tokens
             token_budget -= num_new_tokens
             req_index += 1
@@ -277,16 +260,18 @@ class RecomputeScheduler(Scheduler):
                     num_new_tokens + request.num_computed_tokens - request.num_tokens - request.num_output_placeholders
                 )
                 if num_scheduled_spec_tokens > 0:
-                    # Trim spec_token_ids list to num_scheduled_spec_tokens.
-                    del request.spec_token_ids[num_scheduled_spec_tokens:]
-                    scheduled_spec_decode_tokens[request.request_id] = request.spec_token_ids
+                    spec_token_ids = request.spec_token_ids
+                    if len(spec_token_ids) > num_scheduled_spec_tokens:
+                        spec_token_ids = spec_token_ids[:num_scheduled_spec_tokens]
+                    scheduled_spec_decode_tokens[request.request_id] = spec_token_ids
                     # New spec tokens will be set in `update_draft_token_ids` before the
                     # next step when applicable.
                     request.spec_token_ids = []
 
             # Encoder-related.
             if encoder_inputs_to_schedule:
-                scheduled_encoder_inputs[request.request_id] = encoder_inputs_to_schedule
+                scheduled_encoder_inputs[request_id] = encoder_inputs_to_schedule
                 # Allocate the encoder cache.
                 for i in encoder_inputs_to_schedule:
                     self.encoder_cache_manager.allocate(request, i)
@@ -318,6 +303,7 @@ class RecomputeScheduler(Scheduler):
                 break
 
             request = self.waiting.peek_request()
+            request_id = request.request_id
 
             # KVTransfer: skip request if still waiting for remote kvs.
             if request.status == RequestStatus.WAITING_FOR_REMOTE_KVS:
@@ -332,7 +318,7 @@ class RecomputeScheduler(Scheduler):
                 else:
                     logger.debug(
                         "%s is still in WAITING_FOR_REMOTE_KVS state.",
-                        request.request_id,
+                        request_id,
                     )
                     self.waiting.pop_request()
                     skipped_waiting_requests.prepend_request(request)
@@ -349,6 +335,13 @@ class RecomputeScheduler(Scheduler):
                 skipped_waiting_requests.prepend_request(request)
                 continue
 
+            # Streaming: skip request if still waiting for next streaming req.
+            if request.status == RequestStatus.WAITING_FOR_STREAMING_REQ:
+                assert not request.streaming_queue
+                self.waiting.pop_request()
+                skipped_waiting_requests.prepend_request(request)
+                continue
+
             # Check that adding the request still respects the max_loras
             # constraint.
             if (
@@ -366,6 +359,7 @@ class RecomputeScheduler(Scheduler):
             num_external_computed_tokens = 0
             load_kv_async = False
+            connector_prefix_cache_queries, connector_prefix_cache_hits = 0, 0
 
             # Get already-cached tokens.
             if request.num_computed_tokens == 0:
@@ -391,6 +385,9 @@ class RecomputeScheduler(Scheduler):
                     request.num_external_computed_tokens = ext_tokens
                     num_external_computed_tokens = ext_tokens
+                    connector_prefix_cache_queries = request.num_tokens - num_new_local_computed_tokens
+                    connector_prefix_cache_hits = num_external_computed_tokens
 
                 # Total computed tokens (local + external).
                 num_computed_tokens = num_new_local_computed_tokens + num_external_computed_tokens
             else:
@@ -413,10 +410,7 @@ class RecomputeScheduler(Scheduler):
                 # We use `request.num_tokens` instead of
                 # `request.num_prompt_tokens` to consider the resumed
                 # requests, which have output tokens.
-                if self.is_mtp_kv_consumer:
-                    num_new_tokens = request.num_tokens_with_spec - num_computed_tokens
-                else:
-                    num_new_tokens = request.num_tokens - num_computed_tokens
+                num_new_tokens = request.num_tokens - num_computed_tokens
                 threshold = self.scheduler_config.long_prefill_token_threshold
                 if 0 < threshold < num_new_tokens:
                     num_new_tokens = threshold
@@ -449,6 +443,16 @@ class RecomputeScheduler(Scheduler):
                     # The request cannot be scheduled.
                     break
 
+            if self.need_mamba_block_aligned_split:
+                num_new_tokens = self._mamba_block_aligned_split(
+                    request,
+                    num_new_tokens,
+                    num_new_local_computed_tokens,
+                    num_external_computed_tokens,
+                )
+                if num_new_tokens == 0:
+                    break
+
             # Handles an edge case when P/D Disaggregation
             # is used with Spec Decoding where an
             # extra block gets allocated which
@@ -487,9 +491,15 @@ class RecomputeScheduler(Scheduler):
             if self.connector is not None:
                 self.connector.update_state_after_alloc(
                     request,
-                    self.kv_cache_manager.get_blocks(request.request_id),
+                    self.kv_cache_manager.get_blocks(request_id),
                     num_external_computed_tokens,
                 )
+                if self.connector_prefix_cache_stats is not None and connector_prefix_cache_queries != 0:
+                    self.connector_prefix_cache_stats.record(
+                        num_tokens=connector_prefix_cache_queries,
+                        num_hits=connector_prefix_cache_hits,
+                        preempted=request.num_preemptions > 0,
+                    )
 
             # Request was already popped from self.waiting
             # unless it was re-added above due to new_blocks being None.
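
The two counters record how much of the prompt the KV connector was queried for versus how much it actually served, which is enough to derive a remote prefix-cache hit rate. A hedged sketch of that bookkeeping, using an illustrative class rather than vLLM's actual stats object:

```python
from dataclasses import dataclass


@dataclass
class ConnectorPrefixCacheStats:  # illustrative, not the vLLM class
    queries: int = 0  # tokens the connector was asked about
    hits: int = 0     # tokens served from the remote KV cache

    def record(self, num_tokens: int, num_hits: int, preempted: bool = False) -> None:
        self.queries += num_tokens
        self.hits += num_hits

    @property
    def hit_rate(self) -> float:
        return self.hits / self.queries if self.queries else 0.0


stats = ConnectorPrefixCacheStats()
# 1000-token prompt, 200 tokens already cached locally, 600 pulled remotely:
stats.record(num_tokens=1000 - 200, num_hits=600)
assert round(stats.hit_rate, 2) == 0.75
```
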
@@ -501,25 +511,6 @@ class RecomputeScheduler(Scheduler):
                 request.status = RequestStatus.WAITING_FOR_REMOTE_KVS
                 continue
 
-            self._update_connector_prefix_cache_stats(request)
-
-            # For spec_token_ids, the waiting queue has the same processing
-            # as the running queue.
-            if self.is_mtp_kv_consumer and request.spec_token_ids:
-                num_scheduled_spec_tokens = (
-                    num_new_tokens
-                    + request.num_computed_tokens
-                    - request.num_tokens
-                    - request.num_output_placeholders
-                )
-                if num_scheduled_spec_tokens > 0:
-                    # Trim spec_token_ids list to num_scheduled_spec_tokens.
-                    del request.spec_token_ids[num_scheduled_spec_tokens:]
-                    scheduled_spec_decode_tokens[request.request_id] = request.spec_token_ids
-                    # New spec tokens will be set in `update_draft_token_ids` before the
-                    # next step when applicable.
-                    request.spec_token_ids = []
-
             self.running.append(request)
             if self.log_stats:
                 request.record_event(EngineCoreEventType.SCHEDULED, scheduled_timestamp)
@@ -532,8 +523,8 @@ class RecomputeScheduler(Scheduler):
             if self.lora_config and request.lora_request:
                 scheduled_loras.add(request.lora_request.lora_int_id)
 
-            req_to_new_blocks[request.request_id] = self.kv_cache_manager.get_blocks(request.request_id)
-            num_scheduled_tokens[request.request_id] = num_new_tokens
+            req_to_new_blocks[request_id] = self.kv_cache_manager.get_blocks(request_id)
+            num_scheduled_tokens[request_id] = num_new_tokens
             token_budget -= num_new_tokens
             request.status = RequestStatus.RUNNING
             request.num_computed_tokens = num_computed_tokens
@@ -542,7 +533,7 @@ class RecomputeScheduler(Scheduler):
                 request.num_cached_tokens = num_computed_tokens
 
             # Encoder-related.
             if encoder_inputs_to_schedule:
-                scheduled_encoder_inputs[request.request_id] = encoder_inputs_to_schedule
+                scheduled_encoder_inputs[request_id] = encoder_inputs_to_schedule
                 # Allocate the encoder cache.
                 for i in encoder_inputs_to_schedule:
                     self.encoder_cache_manager.allocate(request, i)
@@ -573,8 +564,8 @@ class RecomputeScheduler(Scheduler):
         num_common_prefix_blocks = [0] * len(self.kv_cache_config.kv_cache_groups)
         with record_function_or_nullcontext("schedule: get_num_common_prefix_blocks"):
             if self.running:
-                any_request = self.running[0]
-                num_common_prefix_blocks = self.kv_cache_manager.get_num_common_prefix_blocks(any_request.request_id)
+                any_request_id = self.running[0].request_id
+                num_common_prefix_blocks = self.kv_cache_manager.get_num_common_prefix_blocks(any_request_id)
 
         # Construct the scheduler output.
         if self.use_v2_model_runner:
@@ -644,7 +635,7 @@ class RecomputeScheduler(Scheduler):
     def update_from_output(
         self,
-        scheduler_output: RecomputeSchedulerOutput,
+        scheduler_output: SchedulerOutput,
         model_runner_output: ModelRunnerOutput,
     ) -> dict[int, EngineCoreOutputs]:
         sampled_token_ids = model_runner_output.sampled_token_ids
@@ -700,17 +691,21 @@ class RecomputeScheduler(Scheduler):
                 # skip failed or rescheduled requests from KV load failure
                 continue
             request = self.requests.get(req_id)
-            if request is None:
+            if request is None or request.is_finished():
                 # The request is already finished. This can happen if the
                 # request is aborted while the model is executing it (e.g.,
-                # in pipeline parallelism).
+                # in pipeline parallelism or in async scheduling).
+                # NOTE(Kuntai): When delay_free_blocks=True (for async KV
+                # cache transfer in KV connector), the aborted request will not
+                # be set to None (in order to finish async KV transfer).
+                # In this case, we use is_finished() to check.
                 continue
 
             req_index = model_runner_output.req_id_to_index[req_id]
             generated_token_ids = sampled_token_ids[req_index] if sampled_token_ids else []
 
             scheduled_spec_token_ids = scheduler_output.scheduled_spec_decode_tokens.get(req_id)
-            if scheduled_spec_token_ids:
+            if scheduled_spec_token_ids and generated_token_ids:
                 num_draft_tokens = len(scheduled_spec_token_ids)
                 num_accepted = len(generated_token_ids) - 1
                 num_rejected = num_draft_tokens - num_accepted
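
The draft-token arithmetic in the context lines above is easy to sanity-check by hand: with k draft tokens scheduled, the target model verifies them and emits one extra token, so every emitted token beyond the first corresponds to an accepted draft. The new `and generated_token_ids` guard also prevents counting -1 accepted tokens on a step that produced no output:

```python
scheduled_spec_token_ids = [11, 22, 33]  # k = 3 scheduled draft tokens
generated_token_ids = [11, 22, 99]       # 2 drafts accepted + 1 correction token

num_draft_tokens = len(scheduled_spec_token_ids)  # 3
num_accepted = len(generated_token_ids) - 1       # 2
num_rejected = num_draft_tokens - num_accepted    # 1
assert (num_draft_tokens, num_accepted, num_rejected) == (3, 2, 1)
```
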
@@ -749,27 +744,17 @@ class RecomputeScheduler(Scheduler):
             stopped = True
             routed_experts = None
+            finish_reason = None
             if stopped:
-                if self.vllm_config.model_config.enable_return_routed_experts:
-                    kv_blocks = self.kv_cache_manager.get_blocks(request.request_id)
-                    block_ids = kv_blocks.get_block_ids()[0]
-                    num_tokens = request.num_tokens - 1
-                    # compute slot mapping
-                    block_ids_array = np.array(block_ids, dtype=np.int32)
-                    num_blocks = len(block_ids)
-                    block_size = self.block_size
-                    # generate block offsets
-                    block_offsets = np.arange(0, block_size)
-                    # compute slot mapping: slot = block_id * block_size + offset
-                    slot_mapping = (
-                        block_offsets.reshape((1, block_size)) + block_ids_array.reshape((num_blocks, 1)) * block_size
-                    ).flatten()[:num_tokens]
-                    routed_experts = self.routed_experts_reader.get_routed_experts(indices=slot_mapping)
-                kv_transfer_params = self._free_request(request)
+                routed_experts = self._get_routed_experts(request)
+                # Capture finish_reason BEFORE _handle_stopped_request, which may
+                # reset the status to WAITING for streaming requests that continue.
+                finish_reason = request.get_finished_reason()
+                finished = self._handle_stopped_request(request)
+                if finished:
+                    kv_transfer_params = self._free_request(request)
                 if status_before_stop == RequestStatus.RUNNING:
                     stopped_running_reqs.add(request)
                 else:
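
The inlined slot-mapping computation deleted above (now presumably encapsulated in `_get_routed_experts`) is worth spelling out: token i of a request lives in paged KV slot `block_ids[i // block_size] * block_size + i % block_size`. A runnable distillation of exactly that arithmetic:

```python
import numpy as np


def compute_slot_mapping(block_ids: list[int], block_size: int,
                         num_tokens: int) -> np.ndarray:
    block_ids_array = np.array(block_ids, dtype=np.int32)
    block_offsets = np.arange(block_size)
    # Broadcast to a (num_blocks, block_size) grid of slot ids, then
    # flatten in token order and trim to the request's token count.
    slots = block_offsets.reshape(1, -1) + block_ids_array.reshape(-1, 1) * block_size
    return slots.flatten()[:num_tokens]


# Two blocks (ids 7 and 3), block_size=4, request with 6 tokens:
assert compute_slot_mapping([7, 3], 4, 6).tolist() == [28, 29, 30, 31, 12, 13]
```
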
@@ -796,13 +781,13 @@ class RecomputeScheduler(Scheduler):
             # Get prompt logprobs for this request.
             prompt_logprobs_tensors = prompt_logprobs_dict.get(req_id)
-            if new_token_ids or pooler_output is not None or kv_transfer_params:
+            if new_token_ids or pooler_output is not None or kv_transfer_params or stopped:
                 # Add EngineCoreOutput for this Request.
                 outputs[request.client_index].append(
                     EngineCoreOutput(
                         request_id=req_id,
                         new_token_ids=new_token_ids,
-                        finish_reason=request.get_finished_reason(),
+                        finish_reason=finish_reason,
                         new_logprobs=new_logprobs,
                         new_prompt_logprobs_tensors=prompt_logprobs_tensors,
                         pooling_output=pooler_output,
@@ -811,6 +796,7 @@ class RecomputeScheduler(Scheduler):
                         kv_transfer_params=kv_transfer_params,
                         trace_headers=request.trace_headers,
                         num_cached_tokens=request.num_cached_tokens,
+                        num_external_computed_tokens=request.num_external_computed_tokens,
                         routed_experts=routed_experts,
                         num_nans_in_logits=request.num_nans_in_logits,
                     )
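
Why `finish_reason` is captured before `_handle_stopped_request` (per the comment added in the stop-handling hunk): the handler can flip a streaming request back to WAITING, after which `get_finished_reason()` returns None and the client would never observe the stop. A toy reproduction with stand-in names, not the vLLM types:

```python
class Req:  # toy stand-in, not vllm.v1.request.Request
    def __init__(self) -> None:
        self.status = "FINISHED_STOPPED"

    def get_finished_reason(self) -> str | None:
        return "stop" if self.status.startswith("FINISHED") else None


def handle_stopped_request(req: Req) -> bool:
    req.status = "WAITING"  # e.g. a streaming request that will continue
    return False            # not finished for good yet


req = Req()
finish_reason = req.get_finished_reason()  # capture first -> "stop"
handle_stopped_request(req)
assert finish_reason == "stop" and req.get_finished_reason() is None
```
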