[v0.18.0][Misc] Recompute scheduler upgrade to vLLM 0.18.0 (#7720)

### What this PR does / why we need it?
Cherry-picked from #7675.
The current RecomputeScheduler is aligned with the Scheduler in vLLM
v0.16.0. Since upstream vLLM has been upgraded to v0.18.0, we need to
upgrade RecomputeScheduler as well to pick up the missing updates.

### Does this PR introduce _any_ user-facing change?

### How was this patch tested?

---------

Signed-off-by: Angazenn <supperccell@163.com>
Author: Angazenn
Committed: 2026-03-27 18:24:53 +08:00 (committed by GitHub)
Commit: 7cca7e6990 (parent: ab619e1c53)

@@ -30,8 +30,12 @@ from vllm.distributed.kv_transfer.kv_connector.v1.metrics import KVConnectorStat
from vllm.logger import logger
from vllm.v1.core.kv_cache_manager import KVCacheBlocks
from vllm.v1.core.sched.async_scheduler import AsyncScheduler
from vllm.v1.core.sched.interface import PauseState
from vllm.v1.core.sched.output import NewRequestData, SchedulerOutput
from vllm.v1.core.sched.request_queue import SchedulingPolicy, create_request_queue
from vllm.v1.core.sched.request_queue import (
SchedulingPolicy,
create_request_queue,
)
from vllm.v1.core.sched.scheduler import Scheduler
from vllm.v1.core.sched.utils import remove_all
from vllm.v1.engine import EngineCoreEventType, EngineCoreOutput, EngineCoreOutputs, FinishReason
@@ -146,14 +150,14 @@ class RecomputeScheduler(Scheduler):
request.num_prompt_tokens -= 1
if self.is_mtp_kv_consumer:
request.spec_token_ids = [PLACEHOLDER_TOKEN_ID] * self.num_spec_tokens
self.waiting.add_request(request)
self._enqueue_waiting_request(request)
self.requests[request.request_id] = request
if self.log_stats:
request.record_event(EngineCoreEventType.QUEUED)
def _update_waiting_for_remote_kv(self, request: Request) -> bool:
def _update_waiting_for_remote_kv(self, request: Request) -> None:
"""
KV Connector: check if the request_id is finished_recving.
KV Connector: update request state after async recv is finished.
The finished_recving_kv_req_ids list is populated
on the previous step's update_from_output based
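
For readers unfamiliar with the v0.18.0 queue helper used in the hunk above: `_enqueue_waiting_request` replaces the direct `self.waiting.add_request(request)` call with a single, policy-aware entry point for queue insertion. A minimal sketch of that pattern (illustrative only; `TinyScheduler` and its fields are stand-ins, not vLLM's real classes):

```python
from collections import deque


class TinyScheduler:
    """Illustrative stand-in, not vLLM's Scheduler."""

    def __init__(self):
        self.waiting = deque()  # stand-in for vLLM's RequestQueue

    def _enqueue_waiting_request(self, request):
        # One shared entry point for queue insertion lets subclasses such
        # as RecomputeScheduler reuse the parent's enqueue logic instead of
        # touching the waiting queue directly.
        self.waiting.append(request)
```
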
@@ -162,10 +166,13 @@ class RecomputeScheduler(Scheduler):
When the kv transfer is ready, we cache the blocks
and the request state will be moved back to WAITING from
WAITING_FOR_REMOTE_KV.
NOTE: The check for whether request.request_id is in
finished_recving_kv_req_ids is now done by the caller
(_try_promote_blocked_waiting_request in the parent Scheduler),
so this method is only called when the recv is confirmed finished.
"""
assert self.connector is not None
if request.request_id not in self.finished_recving_kv_req_ids:
return False
if request.request_id in self.failed_recving_kv_req_ids:
# Request had KV load failures; num_computed_tokens was already
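
The NOTE above describes an inversion of responsibility: the membership check against `finished_recving_kv_req_ids` moved from this method into its caller in the parent `Scheduler`. A simplified sketch of that caller-side pattern (not vLLM's actual code; `scheduler` stands in for the parent Scheduler instance):

```python
# Simplified sketch of the caller-side check described in the NOTE above.
def try_promote_blocked_waiting_request(scheduler, request) -> bool:
    """Promote a request out of WAITING_FOR_REMOTE_KVS once its KVs arrive."""
    if request.request_id not in scheduler.finished_recving_kv_req_ids:
        return False  # recv not finished yet; the caller skips this request
    # Safe to call: the recv is confirmed finished, so the method no longer
    # needs its own membership check (hence the bool -> None signature change).
    scheduler._update_waiting_for_remote_kv(request)
    return True
```
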
@@ -181,6 +188,8 @@ class RecomputeScheduler(Scheduler):
self.failed_recving_kv_req_ids.remove(request.request_id)
else:
# Now that the blocks are ready, actually cache them.
# Use Ascend-specific block_ids logic to handle multi-group KV
# cache configurations (e.g. MLA) where len(block_ids) > 1.
block_ids = self.kv_cache_manager.get_block_ids(request.request_id)
if len(block_ids) == 1:
num_computed_tokens = len(block_ids[0]) * self.block_size
@@ -188,6 +197,8 @@ class RecomputeScheduler(Scheduler):
num_computed_tokens = min(num_computed_tokens, request.num_tokens)
else:
num_computed_tokens = request.num_tokens
# on a full prompt hit, we need to re-compute the last token
# in order to be able to sample the next token
if num_computed_tokens == request.num_tokens:
num_computed_tokens -= 1
# This will cache the blocks iff caching is enabled.
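
The full-prompt-hit adjustment above is easiest to see with concrete numbers: if every prompt token is already computed, the model has nothing to run and therefore produces no logits to sample from, so one token is handed back. A minimal, self-contained sketch:

```python
# Sketch: why a full prompt hit recomputes the last token. If all prompt
# tokens are cached/loaded, there is no forward pass to produce the logits
# needed to sample the next token, so one token is "given back".
def effective_computed_tokens(num_computed: int, num_prompt: int) -> int:
    if num_computed == num_prompt:
        num_computed -= 1  # re-run the last token to get next-token logits
    return num_computed


assert effective_computed_tokens(8, 8) == 7  # full hit: recompute last token
assert effective_computed_tokens(5, 8) == 5  # partial hit: unchanged
```
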
@@ -196,9 +207,11 @@ class RecomputeScheduler(Scheduler):
# Update the request state for scheduling.
request.num_computed_tokens = num_computed_tokens
# Return that we are ready.
# Count the number of prefix cached tokens.
if request.num_cached_tokens < 0:
request.num_cached_tokens = request.num_computed_tokens
self.finished_recving_kv_req_ids.remove(request.request_id)
return True
def schedule(self) -> RecomputeSchedulerOutput:
# NOTE(woosuk) on the scheduling algorithm:
@@ -230,6 +243,12 @@ class RecomputeScheduler(Scheduler):
# For logging.
scheduled_timestamp = time.monotonic()
self.kv_cache_manager.new_step_starts()
if self._pause_state == PauseState.PAUSED_ALL:
# Do not schedule any requests when paused.
token_budget = 0
# First, schedule the RUNNING requests.
req_index = 0
while req_index < len(self.running) and token_budget > 0:
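
The pause gate above works purely through the token budget: with `token_budget` forced to 0, both the RUNNING loop here and the WAITING loop further down fall through without scheduling anything, so a paused scheduler emits an empty step. A tiny sketch of the gating (the enum below is an illustrative stand-in for the imported `PauseState`):

```python
from enum import Enum, auto


class PauseState(Enum):  # illustrative stand-in for the imported enum
    UNPAUSED = auto()
    PAUSED_ALL = auto()


def step_token_budget(pause_state: PauseState, max_tokens: int) -> int:
    # When paused, a zero budget makes every scheduling loop fall through.
    return 0 if pause_state is PauseState.PAUSED_ALL else max_tokens


assert step_token_budget(PauseState.PAUSED_ALL, 8192) == 0
assert step_token_budget(PauseState.UNPAUSED, 8192) == 8192
```
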
@@ -393,6 +412,8 @@ class RecomputeScheduler(Scheduler):
# Allocate the encoder cache.
for i in encoder_inputs_to_schedule:
self.encoder_cache_manager.allocate(request, i)
if self.ec_connector is not None:
self.ec_connector.update_state_after_alloc(request, i)
encoder_compute_budget = new_encoder_compute_budget
if external_load_encoder_input:
for i in external_load_encoder_input:
@@ -410,54 +431,31 @@ class RecomputeScheduler(Scheduler):
)
assert len(scheduled_loras) <= self.lora_config.max_loras
# Use a temporary RequestQueue to collect requests that need to be
# skipped and put back at the head of the waiting queue later
skipped_waiting_requests = create_request_queue(self.policy)
# Next, schedule the WAITING requests.
if not preempted_reqs and not recomputed_reqs:
while self.waiting and token_budget > 0:
if not preempted_reqs and not recomputed_reqs and self._pause_state == PauseState.UNPAUSED:
step_skipped_waiting = create_request_queue(self.policy)
while (self.waiting or self.skipped_waiting) and token_budget > 0:
if len(self.running) == self.max_num_running_reqs:
break
request = self.waiting.peek_request()
request_queue = self._select_waiting_queue_for_scheduling()
assert request_queue is not None
request = request_queue.peek_request()
request_id = request.request_id
# KVTransfer: skip request if still waiting for remote kvs.
if request.status == RequestStatus.WAITING_FOR_REMOTE_KVS:
is_ready = self._update_waiting_for_remote_kv(request)
if is_ready:
if request.num_preemptions:
# We must be loading for a resumed preemption
# rather than a new request.
request.status = RequestStatus.PREEMPTED
else:
request.status = RequestStatus.WAITING
else:
# try to promote blocked statuses while traversing skipped queue.
if self._is_blocked_waiting_status(request.status) and not self._try_promote_blocked_waiting_request(
request
):
if request.status == RequestStatus.WAITING_FOR_REMOTE_KVS:
logger.debug(
"%s is still in WAITING_FOR_REMOTE_KVS state.",
request_id,
)
self.waiting.pop_request()
skipped_waiting_requests.prepend_request(request)
continue
# Skip request if the structured output request is still waiting
# for FSM compilation.
if request.status == RequestStatus.WAITING_FOR_FSM:
structured_output_req = request.structured_output_request
if structured_output_req and structured_output_req.grammar:
request.status = RequestStatus.WAITING
else:
self.waiting.pop_request()
skipped_waiting_requests.prepend_request(request)
continue
# Streaming: skip request if still waiting for next streaming req.
if request.status == RequestStatus.WAITING_FOR_STREAMING_REQ:
assert not request.streaming_queue
self.waiting.pop_request()
skipped_waiting_requests.prepend_request(request)
request_queue.pop_request()
step_skipped_waiting.prepend_request(request)
continue
# Check that adding the request still respects the max_loras
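
This hunk is the core of the restructuring: the per-call `skipped_waiting_requests` queue becomes a per-step `step_skipped_waiting` queue plus a persistent `self.skipped_waiting` queue, with `_select_waiting_queue_for_scheduling` picking which queue to draw from. A simplified sketch of a plausible selection policy (illustrative; the upstream heuristic may differ):

```python
from collections import deque


class TwoQueueScheduler:
    """Illustrative stand-in, not the upstream Scheduler."""

    def __init__(self):
        self.waiting = deque()          # newly arrived requests
        self.skipped_waiting = deque()  # requests blocked on KV/FSM/streaming

    def _select_waiting_queue_for_scheduling(self):
        # Plausible policy: retry previously skipped requests first so a
        # blocked-then-unblocked request is not starved by new arrivals.
        if self.skipped_waiting:
            return self.skipped_waiting
        if self.waiting:
            return self.waiting
        return None
```
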
@@ -471,8 +469,8 @@ class RecomputeScheduler(Scheduler):
)
):
# Scheduling would exceed max_loras, skip.
self.waiting.pop_request()
skipped_waiting_requests.prepend_request(request)
request_queue.pop_request()
step_skipped_waiting.prepend_request(request)
continue
num_external_computed_tokens = 0
@@ -496,8 +494,8 @@ class RecomputeScheduler(Scheduler):
# The request cannot be scheduled because
# the KVConnector couldn't determine
# the number of matched tokens.
self.waiting.pop_request()
skipped_waiting_requests.prepend_request(request)
request_queue.pop_request()
step_skipped_waiting.prepend_request(request)
continue
request.num_external_computed_tokens = ext_tokens
@@ -508,6 +506,7 @@ class RecomputeScheduler(Scheduler):
# Total computed tokens (local + external).
num_computed_tokens = num_new_local_computed_tokens + num_external_computed_tokens
assert num_computed_tokens <= request.num_tokens
else:
# KVTransfer: WAITING reqs have num_computed_tokens > 0
# after async KV recvs are completed.
@@ -581,9 +580,10 @@ class RecomputeScheduler(Scheduler):
# of local and remote blocks.
effective_lookahead_tokens = 0 if request.num_computed_tokens == 0 else self.num_lookahead_tokens
num_encoder_tokens = (
self._num_encoder_max_input_tokens if self.is_encoder_decoder and request.has_encoder_inputs else 0
)
# Determine if we need to allocate cross-attention blocks.
num_encoder_tokens = 0
if self.is_encoder_decoder and request.has_encoder_inputs and encoder_inputs_to_schedule:
num_encoder_tokens = sum(request.get_num_encoder_embeds(i) for i in encoder_inputs_to_schedule)
new_blocks = self.kv_cache_manager.allocate_slots(
request,
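
The cross-attention sizing above is tightened from a fixed per-request maximum (`_num_encoder_max_input_tokens`) to the exact embed count of the encoder inputs scheduled this step. A minimal sketch, assuming `get_num_encoder_embeds(i)` returns the embed count for encoder input `i` as shown in the diff:

```python
def encoder_tokens_for_step(request, encoder_inputs_to_schedule) -> int:
    # Reserve cross-attention block space only for the encoder inputs
    # actually scheduled now, not a fixed per-request maximum.
    if not encoder_inputs_to_schedule:
        return 0
    return sum(request.get_num_encoder_embeds(i)
               for i in encoder_inputs_to_schedule)
```
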
@@ -622,14 +622,26 @@ class RecomputeScheduler(Scheduler):
preempted=request.num_preemptions > 0,
)
# Request was already popped from self.waiting
# unless it was re-added above due to new_blocks being None.
request = self.waiting.pop_request()
request = request_queue.pop_request()
if load_kv_async:
# If loading async, allocate memory and put request
# into the WAITING_FOR_REMOTE_KV state.
skipped_waiting_requests.prepend_request(request)
request.status = RequestStatus.WAITING_FOR_REMOTE_KVS
step_skipped_waiting.prepend_request(request)
# Set num_computed_tokens even though KVs are not yet loaded.
# request.num_computed_tokens will not be used anywhere until
# the request finishes the KV transfer.
#
# If a transfer error is reported by the connector,
# request.num_computed_tokens will be reset accordingly in
# _update_requests_with_invalid_blocks.
#
# When the transfer is finished, either successfully or not,
# request.num_computed_tokens will correctly reflect the number
# of computed tokens.
# _update_waiting_for_remote_kv will then cache
# only the successfully loaded tokens.
request.num_computed_tokens = num_computed_tokens
continue
# For spec_token_ids, the waiting queue has the same processing
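
The comment block above spells out the lifecycle of `num_computed_tokens` during an async KV load. A condensed sketch of that bookkeeping (function names and the failure hook are illustrative; only `num_computed_tokens` and the status value come from the diff):

```python
def start_async_kv_load(request, num_computed_tokens: int) -> None:
    # Park the request; num_computed_tokens is set optimistically and is
    # not consulted until the transfer completes.
    request.status = "WAITING_FOR_REMOTE_KVS"
    request.num_computed_tokens = num_computed_tokens


def on_kv_load_failure(request, num_valid_tokens: int) -> None:
    # Mirrors _update_requests_with_invalid_blocks: roll back to only the
    # tokens whose blocks actually arrived intact, so that
    # _update_waiting_for_remote_kv caches just the loaded tokens.
    request.num_computed_tokens = num_valid_tokens
```
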
@@ -677,6 +689,8 @@ class RecomputeScheduler(Scheduler):
# Allocate the encoder cache.
for i in encoder_inputs_to_schedule:
self.encoder_cache_manager.allocate(request, i)
if self.ec_connector is not None:
self.ec_connector.update_state_after_alloc(request, i)
encoder_compute_budget = new_encoder_compute_budget
# Allocate for external load encoder cache
if external_load_encoder_input:
@@ -684,9 +698,10 @@ class RecomputeScheduler(Scheduler):
self.encoder_cache_manager.allocate(request, i)
if self.ec_connector is not None:
self.ec_connector.update_state_after_alloc(request, i)
# Put back any skipped requests at the head of the waiting queue
if skipped_waiting_requests:
self.waiting.prepend_requests(skipped_waiting_requests)
# re-queue requests skipped in this pass ahead of older skipped items.
if step_skipped_waiting:
self.skipped_waiting.prepend_requests(step_skipped_waiting)
# Check if the scheduling constraints are satisfied.
total_num_scheduled_tokens = sum(num_scheduled_tokens.values())
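
The prepend above means requests skipped in this pass are retried before requests skipped in earlier steps. A small runnable illustration of that ordering using plain deques (the real queues are vLLM `RequestQueue`s):

```python
from collections import deque

skipped_waiting = deque(["old_skip"])                 # from earlier steps
step_skipped = deque(["new_skip_a", "new_skip_b"])    # from this pass

# Order-preserving prepend: this pass's skips land ahead of older ones.
skipped_waiting.extendleft(reversed(step_skipped))
assert list(skipped_waiting) == ["new_skip_a", "new_skip_b", "old_skip"]
```
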
@@ -738,6 +753,10 @@ class RecomputeScheduler(Scheduler):
self.prev_step_scheduled_req_ids.clear()
self.prev_step_scheduled_req_ids.update(num_scheduled_tokens.keys())
new_block_ids_to_zero = (
(self.kv_cache_manager.take_new_block_ids() or None) if self.needs_kv_cache_zeroing else None
)
scheduler_output = RecomputeSchedulerOutput(
scheduled_new_reqs=new_reqs_data,
scheduled_cached_reqs=cached_reqs_data,
@@ -753,6 +772,7 @@ class RecomputeScheduler(Scheduler):
# the previous and the current steps.
finished_req_ids=self.finished_req_ids,
free_encoder_mm_hashes=self.encoder_cache_manager.get_freed_mm_hashes(),
new_block_ids_to_zero=new_block_ids_to_zero,
recomputed_reqs=recomputed_reqs,
)