[v0.18.0][Misc] Recompute scheduler upgrade to vLLM 0.18.0 (#7720)

<!--  Thanks for sending a pull request!

BEFORE SUBMITTING, PLEASE READ
https://docs.vllm.ai/en/latest/contributing/overview.html

-->
### What this PR does / why we need it?
Cherry-pick from #7675.
The current RecomputeScheduler is aligned with the Scheduler in vLLM v0.16.0.
Since upstream vLLM has upgraded to v0.18.0, we also need to upgrade
RecomputeScheduler to pick up the missing updates.

### Does this PR introduce _any_ user-facing change?
<!--
Note that it means *any* user-facing change including all aspects such
as API, interface or other behavior changes.
Documentation-only updates are not considered user-facing changes.
-->

### How was this patch tested?
<!--
CI passed with new added/existing test.
If it was tested in a way different from regular unit tests, please
clarify how you tested step by step, ideally copy and paste-able, so
that other reviewers can test and check, and descendants can verify in
the future.
If tests were not added, please describe why they were not added and/or
why it was difficult to add.
-->

---------

Signed-off-by: Angazenn <supperccell@163.com>
This commit is contained in:
Angazenn
2026-03-27 18:24:53 +08:00
committed by GitHub
parent ab619e1c53
commit 7cca7e6990

View File

@@ -30,8 +30,12 @@ from vllm.distributed.kv_transfer.kv_connector.v1.metrics import KVConnectorStat
from vllm.logger import logger from vllm.logger import logger
from vllm.v1.core.kv_cache_manager import KVCacheBlocks from vllm.v1.core.kv_cache_manager import KVCacheBlocks
from vllm.v1.core.sched.async_scheduler import AsyncScheduler from vllm.v1.core.sched.async_scheduler import AsyncScheduler
from vllm.v1.core.sched.interface import PauseState
from vllm.v1.core.sched.output import NewRequestData, SchedulerOutput from vllm.v1.core.sched.output import NewRequestData, SchedulerOutput
from vllm.v1.core.sched.request_queue import SchedulingPolicy, create_request_queue from vllm.v1.core.sched.request_queue import (
SchedulingPolicy,
create_request_queue,
)
from vllm.v1.core.sched.scheduler import Scheduler from vllm.v1.core.sched.scheduler import Scheduler
from vllm.v1.core.sched.utils import remove_all from vllm.v1.core.sched.utils import remove_all
from vllm.v1.engine import EngineCoreEventType, EngineCoreOutput, EngineCoreOutputs, FinishReason from vllm.v1.engine import EngineCoreEventType, EngineCoreOutput, EngineCoreOutputs, FinishReason
@@ -146,14 +150,14 @@ class RecomputeScheduler(Scheduler):
request.num_prompt_tokens -= 1 request.num_prompt_tokens -= 1
if self.is_mtp_kv_consumer: if self.is_mtp_kv_consumer:
request.spec_token_ids = [PLACEHOLDER_TOKEN_ID] * self.num_spec_tokens request.spec_token_ids = [PLACEHOLDER_TOKEN_ID] * self.num_spec_tokens
self.waiting.add_request(request) self._enqueue_waiting_request(request)
self.requests[request.request_id] = request self.requests[request.request_id] = request
if self.log_stats: if self.log_stats:
request.record_event(EngineCoreEventType.QUEUED) request.record_event(EngineCoreEventType.QUEUED)
def _update_waiting_for_remote_kv(self, request: Request) -> bool: def _update_waiting_for_remote_kv(self, request: Request) -> None:
""" """
KV Connector: check if the request_id is finished_recving. KV Connector: update request state after async recv is finished.
The finished_recving_kv_req_ids list is populated The finished_recving_kv_req_ids list is populated
on the previous steps()'s update_from_output based on the previous steps()'s update_from_output based
@@ -162,10 +166,13 @@ class RecomputeScheduler(Scheduler):
When the kv transfer is ready, we cache the blocks When the kv transfer is ready, we cache the blocks
and the request state will be moved back to WAITING from and the request state will be moved back to WAITING from
WAITING_FOR_REMOTE_KV. WAITING_FOR_REMOTE_KV.
NOTE: The check for whether request.request_id is in
finished_recving_kv_req_ids is now done by the caller
(_try_promote_blocked_waiting_request in the parent Scheduler),
so this method is only called when the recv is confirmed finished.
""" """
assert self.connector is not None assert self.connector is not None
if request.request_id not in self.finished_recving_kv_req_ids:
return False
if request.request_id in self.failed_recving_kv_req_ids: if request.request_id in self.failed_recving_kv_req_ids:
# Request had KV load failures; num_computed_tokens was already # Request had KV load failures; num_computed_tokens was already
@@ -181,6 +188,8 @@ class RecomputeScheduler(Scheduler):
self.failed_recving_kv_req_ids.remove(request.request_id) self.failed_recving_kv_req_ids.remove(request.request_id)
else: else:
# Now that the blocks are ready, actually cache them. # Now that the blocks are ready, actually cache them.
# Use Ascend-specific block_ids logic to handle multi-group KV
# cache configurations (e.g. MLA) where len(block_ids) > 1.
block_ids = self.kv_cache_manager.get_block_ids(request.request_id) block_ids = self.kv_cache_manager.get_block_ids(request.request_id)
if len(block_ids) == 1: if len(block_ids) == 1:
num_computed_tokens = len(block_ids[0]) * self.block_size num_computed_tokens = len(block_ids[0]) * self.block_size
@@ -188,6 +197,8 @@ class RecomputeScheduler(Scheduler):
num_computed_tokens = min(num_computed_tokens, request.num_tokens) num_computed_tokens = min(num_computed_tokens, request.num_tokens)
else: else:
num_computed_tokens = request.num_tokens num_computed_tokens = request.num_tokens
# on a full prompt hit, we need to re-compute the last token
# in order to be able to sample the next token
if num_computed_tokens == request.num_tokens: if num_computed_tokens == request.num_tokens:
num_computed_tokens -= 1 num_computed_tokens -= 1
# This will cache the blocks iff caching is enabled. # This will cache the blocks iff caching is enabled.
@@ -196,9 +207,11 @@ class RecomputeScheduler(Scheduler):
# Update the request state for scheduling. # Update the request state for scheduling.
request.num_computed_tokens = num_computed_tokens request.num_computed_tokens = num_computed_tokens
# Return that we are ready. # Count the number of prefix cached tokens.
if request.num_cached_tokens < 0:
request.num_cached_tokens = request.num_computed_tokens
self.finished_recving_kv_req_ids.remove(request.request_id) self.finished_recving_kv_req_ids.remove(request.request_id)
return True
def schedule(self) -> RecomputeSchedulerOutput: def schedule(self) -> RecomputeSchedulerOutput:
# NOTE(woosuk) on the scheduling algorithm: # NOTE(woosuk) on the scheduling algorithm:
@@ -230,6 +243,12 @@ class RecomputeScheduler(Scheduler):
# For logging. # For logging.
scheduled_timestamp = time.monotonic() scheduled_timestamp = time.monotonic()
self.kv_cache_manager.new_step_starts()
if self._pause_state == PauseState.PAUSED_ALL:
# Do not schedule any requests when paused.
token_budget = 0
# First, schedule the RUNNING requests. # First, schedule the RUNNING requests.
req_index = 0 req_index = 0
while req_index < len(self.running) and token_budget > 0: while req_index < len(self.running) and token_budget > 0:
@@ -393,6 +412,8 @@ class RecomputeScheduler(Scheduler):
# Allocate the encoder cache. # Allocate the encoder cache.
for i in encoder_inputs_to_schedule: for i in encoder_inputs_to_schedule:
self.encoder_cache_manager.allocate(request, i) self.encoder_cache_manager.allocate(request, i)
if self.ec_connector is not None:
self.ec_connector.update_state_after_alloc(request, i)
encoder_compute_budget = new_encoder_compute_budget encoder_compute_budget = new_encoder_compute_budget
if external_load_encoder_input: if external_load_encoder_input:
for i in external_load_encoder_input: for i in external_load_encoder_input:
@@ -410,54 +431,31 @@ class RecomputeScheduler(Scheduler):
) )
assert len(scheduled_loras) <= self.lora_config.max_loras assert len(scheduled_loras) <= self.lora_config.max_loras
# Use a temporary RequestQueue to collect requests that need to be
# skipped and put back at the head of the waiting queue later
skipped_waiting_requests = create_request_queue(self.policy)
# Next, schedule the WAITING requests. # Next, schedule the WAITING requests.
if not preempted_reqs and not recomputed_reqs: if not preempted_reqs and not recomputed_reqs and self._pause_state == PauseState.UNPAUSED:
while self.waiting and token_budget > 0: step_skipped_waiting = create_request_queue(self.policy)
while (self.waiting or self.skipped_waiting) and token_budget > 0:
if len(self.running) == self.max_num_running_reqs: if len(self.running) == self.max_num_running_reqs:
break break
request = self.waiting.peek_request() request_queue = self._select_waiting_queue_for_scheduling()
assert request_queue is not None
request = request_queue.peek_request()
request_id = request.request_id request_id = request.request_id
# KVTransfer: skip request if still waiting for remote kvs. # try to promote blocked statuses while traversing skipped queue.
if request.status == RequestStatus.WAITING_FOR_REMOTE_KVS: if self._is_blocked_waiting_status(request.status) and not self._try_promote_blocked_waiting_request(
is_ready = self._update_waiting_for_remote_kv(request) request
if is_ready: ):
if request.num_preemptions: if request.status == RequestStatus.WAITING_FOR_REMOTE_KVS:
# We must be loading for a resumed preemption
# rather than a new request.
request.status = RequestStatus.PREEMPTED
else:
request.status = RequestStatus.WAITING
else:
logger.debug( logger.debug(
"%s is still in WAITING_FOR_REMOTE_KVS state.", "%s is still in WAITING_FOR_REMOTE_KVS state.",
request_id, request_id,
) )
self.waiting.pop_request() request_queue.pop_request()
skipped_waiting_requests.prepend_request(request) step_skipped_waiting.prepend_request(request)
continue
# Skip request if the structured output request is still waiting
# for FSM compilation.
if request.status == RequestStatus.WAITING_FOR_FSM:
structured_output_req = request.structured_output_request
if structured_output_req and structured_output_req.grammar:
request.status = RequestStatus.WAITING
else:
self.waiting.pop_request()
skipped_waiting_requests.prepend_request(request)
continue
# Streaming: skip request if still waiting for next streaming req.
if request.status == RequestStatus.WAITING_FOR_STREAMING_REQ:
assert not request.streaming_queue
self.waiting.pop_request()
skipped_waiting_requests.prepend_request(request)
continue continue
# Check that adding the request still respects the max_loras # Check that adding the request still respects the max_loras
@@ -471,8 +469,8 @@ class RecomputeScheduler(Scheduler):
) )
): ):
# Scheduling would exceed max_loras, skip. # Scheduling would exceed max_loras, skip.
self.waiting.pop_request() request_queue.pop_request()
skipped_waiting_requests.prepend_request(request) step_skipped_waiting.prepend_request(request)
continue continue
num_external_computed_tokens = 0 num_external_computed_tokens = 0
@@ -496,8 +494,8 @@ class RecomputeScheduler(Scheduler):
# The request cannot be scheduled because # The request cannot be scheduled because
# the KVConnector couldn't determine # the KVConnector couldn't determine
# the number of matched tokens. # the number of matched tokens.
self.waiting.pop_request() request_queue.pop_request()
skipped_waiting_requests.prepend_request(request) step_skipped_waiting.prepend_request(request)
continue continue
request.num_external_computed_tokens = ext_tokens request.num_external_computed_tokens = ext_tokens
@@ -508,6 +506,7 @@ class RecomputeScheduler(Scheduler):
# Total computed tokens (local + external). # Total computed tokens (local + external).
num_computed_tokens = num_new_local_computed_tokens + num_external_computed_tokens num_computed_tokens = num_new_local_computed_tokens + num_external_computed_tokens
assert num_computed_tokens <= request.num_tokens
else: else:
# KVTransfer: WAITING reqs have num_computed_tokens > 0 # KVTransfer: WAITING reqs have num_computed_tokens > 0
# after async KV recvs are completed. # after async KV recvs are completed.
@@ -581,9 +580,10 @@ class RecomputeScheduler(Scheduler):
# of local and remote blocks. # of local and remote blocks.
effective_lookahead_tokens = 0 if request.num_computed_tokens == 0 else self.num_lookahead_tokens effective_lookahead_tokens = 0 if request.num_computed_tokens == 0 else self.num_lookahead_tokens
num_encoder_tokens = ( # Determine if we need to allocate cross-attention blocks.
self._num_encoder_max_input_tokens if self.is_encoder_decoder and request.has_encoder_inputs else 0 num_encoder_tokens = 0
) if self.is_encoder_decoder and request.has_encoder_inputs and encoder_inputs_to_schedule:
num_encoder_tokens = sum(request.get_num_encoder_embeds(i) for i in encoder_inputs_to_schedule)
new_blocks = self.kv_cache_manager.allocate_slots( new_blocks = self.kv_cache_manager.allocate_slots(
request, request,
@@ -622,14 +622,26 @@ class RecomputeScheduler(Scheduler):
preempted=request.num_preemptions > 0, preempted=request.num_preemptions > 0,
) )
# Request was already popped from self.waiting request = request_queue.pop_request()
# unless it was re-added above due to new_blocks being None.
request = self.waiting.pop_request()
if load_kv_async: if load_kv_async:
# If loading async, allocate memory and put request # If loading async, allocate memory and put request
# into the WAITING_FOR_REMOTE_KV state. # into the WAITING_FOR_REMOTE_KV state.
skipped_waiting_requests.prepend_request(request)
request.status = RequestStatus.WAITING_FOR_REMOTE_KVS request.status = RequestStatus.WAITING_FOR_REMOTE_KVS
step_skipped_waiting.prepend_request(request)
# Set num_computed_tokens even though KVs are not yet loaded.
# request.num_computed_tokens will not be used anywhere until
# the request finished the KV transfer.
#
# If a transfer error is reported by the connector,
# request.num_computed_tokens will be re-set accordingly in
# _update_requests_with_invalid_blocks.
#
# When the transfer is finished, either successfully or not,
# request.num_computed_tokens will correctly reflect the number
# of computed tokens.
# _update_waiting_for_remote_kv will then cache
# only the successfully loaded tokens.
request.num_computed_tokens = num_computed_tokens
continue continue
# For spec_token_ids, the waiting queue has the same processing # For spec_token_ids, the waiting queue has the same processing
@@ -677,6 +689,8 @@ class RecomputeScheduler(Scheduler):
# Allocate the encoder cache. # Allocate the encoder cache.
for i in encoder_inputs_to_schedule: for i in encoder_inputs_to_schedule:
self.encoder_cache_manager.allocate(request, i) self.encoder_cache_manager.allocate(request, i)
if self.ec_connector is not None:
self.ec_connector.update_state_after_alloc(request, i)
encoder_compute_budget = new_encoder_compute_budget encoder_compute_budget = new_encoder_compute_budget
# Allocate for external load encoder cache # Allocate for external load encoder cache
if external_load_encoder_input: if external_load_encoder_input:
@@ -684,9 +698,10 @@ class RecomputeScheduler(Scheduler):
self.encoder_cache_manager.allocate(request, i) self.encoder_cache_manager.allocate(request, i)
if self.ec_connector is not None: if self.ec_connector is not None:
self.ec_connector.update_state_after_alloc(request, i) self.ec_connector.update_state_after_alloc(request, i)
# Put back any skipped requests at the head of the waiting queue
if skipped_waiting_requests: # re-queue requests skipped in this pass ahead of older skipped items.
self.waiting.prepend_requests(skipped_waiting_requests) if step_skipped_waiting:
self.skipped_waiting.prepend_requests(step_skipped_waiting)
# Check if the scheduling constraints are satisfied. # Check if the scheduling constraints are satisfied.
total_num_scheduled_tokens = sum(num_scheduled_tokens.values()) total_num_scheduled_tokens = sum(num_scheduled_tokens.values())
@@ -738,6 +753,10 @@ class RecomputeScheduler(Scheduler):
self.prev_step_scheduled_req_ids.clear() self.prev_step_scheduled_req_ids.clear()
self.prev_step_scheduled_req_ids.update(num_scheduled_tokens.keys()) self.prev_step_scheduled_req_ids.update(num_scheduled_tokens.keys())
new_block_ids_to_zero = (
(self.kv_cache_manager.take_new_block_ids() or None) if self.needs_kv_cache_zeroing else None
)
scheduler_output = RecomputeSchedulerOutput( scheduler_output = RecomputeSchedulerOutput(
scheduled_new_reqs=new_reqs_data, scheduled_new_reqs=new_reqs_data,
scheduled_cached_reqs=cached_reqs_data, scheduled_cached_reqs=cached_reqs_data,
@@ -753,6 +772,7 @@ class RecomputeScheduler(Scheduler):
# the previous and the current steps. # the previous and the current steps.
finished_req_ids=self.finished_req_ids, finished_req_ids=self.finished_req_ids,
free_encoder_mm_hashes=self.encoder_cache_manager.get_freed_mm_hashes(), free_encoder_mm_hashes=self.encoder_cache_manager.get_freed_mm_hashes(),
new_block_ids_to_zero=new_block_ids_to_zero,
recomputed_reqs=recomputed_reqs, recomputed_reqs=recomputed_reqs,
) )