[v0.18.0][Misc] Recompute scheduler upgrade to vLLM 0.18.0 (#7720)
### What this PR does / why we need it?

Cherry-pick from #7675. The current RecomputeScheduler is aligned with the Scheduler in vLLM v0.16.0. Since upstream vLLM has upgraded to v0.18.0, RecomputeScheduler also needs to be upgraded to pick up the missing updates.

### Does this PR introduce _any_ user-facing change?

### How was this patch tested?

---------

Signed-off-by: Angazenn <supperccell@163.com>
@@ -30,8 +30,12 @@ from vllm.distributed.kv_transfer.kv_connector.v1.metrics import KVConnectorStat
 from vllm.logger import logger
 from vllm.v1.core.kv_cache_manager import KVCacheBlocks
 from vllm.v1.core.sched.async_scheduler import AsyncScheduler
+from vllm.v1.core.sched.interface import PauseState
 from vllm.v1.core.sched.output import NewRequestData, SchedulerOutput
-from vllm.v1.core.sched.request_queue import SchedulingPolicy, create_request_queue
+from vllm.v1.core.sched.request_queue import (
+    SchedulingPolicy,
+    create_request_queue,
+)
 from vllm.v1.core.sched.scheduler import Scheduler
 from vllm.v1.core.sched.utils import remove_all
 from vllm.v1.engine import EngineCoreEventType, EngineCoreOutput, EngineCoreOutputs, FinishReason
@@ -146,14 +150,14 @@ class RecomputeScheduler(Scheduler):
             request.num_prompt_tokens -= 1
         if self.is_mtp_kv_consumer:
             request.spec_token_ids = [PLACEHOLDER_TOKEN_ID] * self.num_spec_tokens
-        self.waiting.add_request(request)
+        self._enqueue_waiting_request(request)
         self.requests[request.request_id] = request
         if self.log_stats:
             request.record_event(EngineCoreEventType.QUEUED)

-    def _update_waiting_for_remote_kv(self, request: Request) -> bool:
+    def _update_waiting_for_remote_kv(self, request: Request) -> None:
         """
-        KV Connector: check if the request_id is finished_recving.
+        KV Connector: update request state after async recv is finished.

         The finished_recving_kv_req_ids list is populated
         on the previous steps()'s update_from_output based
@@ -162,10 +166,13 @@ class RecomputeScheduler(Scheduler):
         When the kv transfer is ready, we cache the blocks
         and the request state will be moved back to WAITING from
         WAITING_FOR_REMOTE_KV.

+        NOTE: The check for whether request.request_id is in
+        finished_recving_kv_req_ids is now done by the caller
+        (_try_promote_blocked_waiting_request in the parent Scheduler),
+        so this method is only called when the recv is confirmed finished.
         """
         assert self.connector is not None
-        if request.request_id not in self.finished_recving_kv_req_ids:
-            return False
+
         if request.request_id in self.failed_recving_kv_req_ids:
             # Request had KV load failures; num_computed_tokens was already
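The signature change above goes hand in hand with the NOTE added to the docstring: the membership check on `finished_recving_kv_req_ids` moved to the caller, so the callee no longer needs to report readiness with a `bool`. A minimal sketch of that caller/callee split, using simplified stand-in types rather than the real vLLM classes:

```python
from dataclasses import dataclass, field


@dataclass
class MiniScheduler:
    finished_recving_kv_req_ids: set[str] = field(default_factory=set)

    def try_promote(self, request_id: str) -> bool:
        # The caller performs the membership check ...
        if request_id not in self.finished_recving_kv_req_ids:
            return False
        # ... so the callee can assume the recv is finished and return None.
        self.update_waiting_for_remote_kv(request_id)
        return True

    def update_waiting_for_remote_kv(self, request_id: str) -> None:
        # Caching/bookkeeping happens here; no bool result is needed anymore.
        self.finished_recving_kv_req_ids.remove(request_id)


sched = MiniScheduler({"req-1"})
assert sched.try_promote("req-0") is False  # still receiving
assert sched.try_promote("req-1") is True   # recv confirmed, state updated
```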
@@ -181,6 +188,8 @@ class RecomputeScheduler(Scheduler):
             self.failed_recving_kv_req_ids.remove(request.request_id)
         else:
             # Now that the blocks are ready, actually cache them.
+            # Use Ascend-specific block_ids logic to handle multi-group KV
+            # cache configurations (e.g. MLA) where len(block_ids) > 1.
             block_ids = self.kv_cache_manager.get_block_ids(request.request_id)
             if len(block_ids) == 1:
                 num_computed_tokens = len(block_ids[0]) * self.block_size
@@ -188,6 +197,8 @@ class RecomputeScheduler(Scheduler):
                 num_computed_tokens = min(num_computed_tokens, request.num_tokens)
             else:
                 num_computed_tokens = request.num_tokens
+            # on a full prompt hit, we need to re-compute the last token
+            # in order to be able to sample the next token
             if num_computed_tokens == request.num_tokens:
                 num_computed_tokens -= 1
             # This will cache the blocks iff caching is enabled.
@@ -196,9 +207,11 @@ class RecomputeScheduler(Scheduler):
         # Update the request state for scheduling.
         request.num_computed_tokens = num_computed_tokens

-        # Return that we are ready.
+        # Count the number of prefix cached tokens.
+        if request.num_cached_tokens < 0:
+            request.num_cached_tokens = request.num_computed_tokens
+
         self.finished_recving_kv_req_ids.remove(request.request_id)
-        return True

     def schedule(self) -> RecomputeSchedulerOutput:
         # NOTE(woosuk) on the scheduling algorithm:
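The block-count arithmetic in the hunks above is the heart of the updated `_update_waiting_for_remote_kv`. A self-contained sketch of the same token accounting, with placeholder names (`block_size`, `num_tokens`) standing in for the scheduler's attributes:

```python
def computed_tokens_after_recv(block_ids: list[list[int]], block_size: int,
                               num_tokens: int) -> int:
    if len(block_ids) == 1:
        # Single KV group: tokens covered by received blocks, capped at the prompt.
        n = min(len(block_ids[0]) * block_size, num_tokens)
    else:
        # Multi-group (e.g. MLA): assume the full prompt was transferred.
        n = num_tokens
    # On a full prompt hit, recompute the last token so the next one can be sampled.
    if n == num_tokens:
        n -= 1
    return n


assert computed_tokens_after_recv([[0, 1, 2]], 16, 40) == 39  # full hit
assert computed_tokens_after_recv([[0, 1]], 16, 40) == 32     # partial
```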
@@ -230,6 +243,12 @@ class RecomputeScheduler(Scheduler):
         # For logging.
         scheduled_timestamp = time.monotonic()

+        self.kv_cache_manager.new_step_starts()
+
+        if self._pause_state == PauseState.PAUSED_ALL:
+            # Do not schedule any requests when paused.
+            token_budget = 0
+
         # First, schedule the RUNNING requests.
         req_index = 0
         while req_index < len(self.running) and token_budget > 0:
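Note that the `PAUSED_ALL` gate works by zeroing the token budget rather than returning early, so the rest of `schedule()` runs unchanged and still produces a (empty) scheduler output. A toy illustration of the gate, with a stand-in enum rather than the real `PauseState`:

```python
import enum


class PauseState(enum.Enum):
    UNPAUSED = enum.auto()
    PAUSED_ALL = enum.auto()


def effective_budget(pause_state: PauseState, max_tokens: int) -> int:
    # A zero budget makes both scheduling loops fall through without work.
    return 0 if pause_state is PauseState.PAUSED_ALL else max_tokens


assert effective_budget(PauseState.PAUSED_ALL, 8192) == 0
assert effective_budget(PauseState.UNPAUSED, 8192) == 8192
```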
@@ -393,6 +412,8 @@ class RecomputeScheduler(Scheduler):
                 # Allocate the encoder cache.
                 for i in encoder_inputs_to_schedule:
                     self.encoder_cache_manager.allocate(request, i)
+                    if self.ec_connector is not None:
+                        self.ec_connector.update_state_after_alloc(request, i)
                 encoder_compute_budget = new_encoder_compute_budget
             if external_load_encoder_input:
                 for i in external_load_encoder_input:
@@ -410,54 +431,31 @@ class RecomputeScheduler(Scheduler):
             )
             assert len(scheduled_loras) <= self.lora_config.max_loras

-        # Use a temporary RequestQueue to collect requests that need to be
-        # skipped and put back at the head of the waiting queue later
-        skipped_waiting_requests = create_request_queue(self.policy)
-
         # Next, schedule the WAITING requests.
-        if not preempted_reqs and not recomputed_reqs:
-            while self.waiting and token_budget > 0:
+        if not preempted_reqs and not recomputed_reqs and self._pause_state == PauseState.UNPAUSED:
+            step_skipped_waiting = create_request_queue(self.policy)
+
+            while (self.waiting or self.skipped_waiting) and token_budget > 0:
                 if len(self.running) == self.max_num_running_reqs:
                     break

-                request = self.waiting.peek_request()
+                request_queue = self._select_waiting_queue_for_scheduling()
+                assert request_queue is not None
+
+                request = request_queue.peek_request()
                 request_id = request.request_id

-                # KVTransfer: skip request if still waiting for remote kvs.
+                # try to promote blocked statuses while traversing skipped queue.
+                if self._is_blocked_waiting_status(request.status) and not self._try_promote_blocked_waiting_request(
+                    request
+                ):
                     if request.status == RequestStatus.WAITING_FOR_REMOTE_KVS:
-                    is_ready = self._update_waiting_for_remote_kv(request)
-                    if is_ready:
-                        if request.num_preemptions:
-                            # We must be loading for a resumed preemption
-                            # rather than a new request.
-                            request.status = RequestStatus.PREEMPTED
-                        else:
-                            request.status = RequestStatus.WAITING
-                    else:
                         logger.debug(
                             "%s is still in WAITING_FOR_REMOTE_KVS state.",
                             request_id,
                         )
-                        self.waiting.pop_request()
-                        skipped_waiting_requests.prepend_request(request)
-                        continue
-
-                # Skip request if the structured output request is still waiting
-                # for FSM compilation.
-                if request.status == RequestStatus.WAITING_FOR_FSM:
-                    structured_output_req = request.structured_output_request
-                    if structured_output_req and structured_output_req.grammar:
-                        request.status = RequestStatus.WAITING
-                    else:
-                        self.waiting.pop_request()
-                        skipped_waiting_requests.prepend_request(request)
-                        continue
-
-                # Streaming: skip request if still waiting for next streaming req.
-                if request.status == RequestStatus.WAITING_FOR_STREAMING_REQ:
-                    assert not request.streaming_queue
-                    self.waiting.pop_request()
-                    skipped_waiting_requests.prepend_request(request)
+                    request_queue.pop_request()
+                    step_skipped_waiting.prepend_request(request)
                     continue

                 # Check that adding the request still respects the max_loras
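This hunk replaces the per-step temporary skip queue with a persistent `self.skipped_waiting` queue plus `_select_waiting_queue_for_scheduling`, and folds the per-status skip branches into a single blocked-status promotion check. A rough sketch of one plausible selection rule (previously skipped requests are revisited before new ones); the real method's priority may differ:

```python
from __future__ import annotations

from collections import deque


def select_waiting_queue(skipped: deque[str], waiting: deque[str]) -> deque[str] | None:
    # Assumption: the skipped queue drains first so blocked requests
    # get re-examined (and possibly promoted) each step.
    if skipped:
        return skipped
    if waiting:
        return waiting
    return None


skipped, waiting = deque(["r2"]), deque(["r0", "r1"])
queue = select_waiting_queue(skipped, waiting)
assert queue is skipped and queue[0] == "r2"
```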
@@ -471,8 +469,8 @@ class RecomputeScheduler(Scheduler):
                     )
                 ):
                     # Scheduling would exceed max_loras, skip.
-                    self.waiting.pop_request()
-                    skipped_waiting_requests.prepend_request(request)
+                    request_queue.pop_request()
+                    step_skipped_waiting.prepend_request(request)
                     continue

                 num_external_computed_tokens = 0
@@ -496,8 +494,8 @@ class RecomputeScheduler(Scheduler):
                         # The request cannot be scheduled because
                        # the KVConnector couldn't determine
                        # the number of matched tokens.
-                        self.waiting.pop_request()
-                        skipped_waiting_requests.prepend_request(request)
+                        request_queue.pop_request()
+                        step_skipped_waiting.prepend_request(request)
                         continue

                     request.num_external_computed_tokens = ext_tokens
@@ -508,6 +506,7 @@ class RecomputeScheduler(Scheduler):

                     # Total computed tokens (local + external).
                     num_computed_tokens = num_new_local_computed_tokens + num_external_computed_tokens
+                    assert num_computed_tokens <= request.num_tokens
                 else:
                     # KVTransfer: WAITING reqs have num_computed_tokens > 0
                     # after async KV recvs are completed.
@@ -581,9 +580,10 @@ class RecomputeScheduler(Scheduler):
                 # of local and remote blocks.
                 effective_lookahead_tokens = 0 if request.num_computed_tokens == 0 else self.num_lookahead_tokens

-                num_encoder_tokens = (
-                    self._num_encoder_max_input_tokens if self.is_encoder_decoder and request.has_encoder_inputs else 0
-                )
+                # Determine if we need to allocate cross-attention blocks.
+                num_encoder_tokens = 0
+                if self.is_encoder_decoder and request.has_encoder_inputs and encoder_inputs_to_schedule:
+                    num_encoder_tokens = sum(request.get_num_encoder_embeds(i) for i in encoder_inputs_to_schedule)

                 new_blocks = self.kv_cache_manager.allocate_slots(
                     request,
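With this change, cross-attention space is reserved only for the encoder inputs actually scheduled this step, instead of always budgeting the encoder's maximum input length. A hedged sketch of the new accounting, with `embeds_per_input` standing in for `request.get_num_encoder_embeds`:

```python
def encoder_tokens_to_allocate(is_encoder_decoder: bool,
                               embeds_per_input: dict[int, int],
                               scheduled_inputs: list[int]) -> int:
    # Nothing to reserve for decoder-only models or steps that schedule
    # no encoder inputs.
    if not (is_encoder_decoder and scheduled_inputs):
        return 0
    return sum(embeds_per_input[i] for i in scheduled_inputs)


assert encoder_tokens_to_allocate(True, {0: 576, 1: 576}, [0]) == 576
assert encoder_tokens_to_allocate(True, {0: 576}, []) == 0
```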
@@ -622,14 +622,26 @@ class RecomputeScheduler(Scheduler):
                     preempted=request.num_preemptions > 0,
                 )

-                # Request was already popped from self.waiting
-                # unless it was re-added above due to new_blocks being None.
-                request = self.waiting.pop_request()
+                request = request_queue.pop_request()
                 if load_kv_async:
                     # If loading async, allocate memory and put request
                     # into the WAITING_FOR_REMOTE_KV state.
-                    skipped_waiting_requests.prepend_request(request)
                     request.status = RequestStatus.WAITING_FOR_REMOTE_KVS
+                    step_skipped_waiting.prepend_request(request)
+                    # Set num_computed_tokens even though KVs are not yet loaded.
+                    # request.num_computed_tokens will not be used anywhere until
+                    # the request finished the KV transfer.
+                    #
+                    # If a transfer error is reported by the connector,
+                    # request.num_computed_tokens will be re-set accordingly in
+                    # _update_requests_with_invalid_blocks.
+                    #
+                    # When the transfer is finished, either successfully or not,
+                    # request.num_computed_tokens will correctly reflect the number
+                    # of computed tokens.
+                    # _update_waiting_for_remote_kv will then cache
+                    # only the successfully loaded tokens.
+                    request.num_computed_tokens = num_computed_tokens
                     continue

                 # For spec_token_ids, the waiting queue has the same processing
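The long comment block added here explains that `num_computed_tokens` is set optimistically when an async load starts and is corrected later if blocks turn out to be invalid. A toy model of that state transition, with simplified string statuses rather than the real `RequestStatus` enum:

```python
from dataclasses import dataclass


@dataclass
class Req:
    status: str = "WAITING"
    num_computed_tokens: int = 0


def start_async_kv_load(req: Req, num_computed_tokens: int) -> None:
    req.status = "WAITING_FOR_REMOTE_KVS"
    # Optimistic: this value is not read until the transfer finishes,
    # and is re-set on load failure before anything consumes it.
    req.num_computed_tokens = num_computed_tokens


r = Req()
start_async_kv_load(r, 128)
assert (r.status, r.num_computed_tokens) == ("WAITING_FOR_REMOTE_KVS", 128)
```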
@@ -677,6 +689,8 @@ class RecomputeScheduler(Scheduler):
                 # Allocate the encoder cache.
                 for i in encoder_inputs_to_schedule:
                     self.encoder_cache_manager.allocate(request, i)
+                    if self.ec_connector is not None:
+                        self.ec_connector.update_state_after_alloc(request, i)
                 encoder_compute_budget = new_encoder_compute_budget
                 # Allocate for external load encoder cache
                 if external_load_encoder_input:
@@ -684,9 +698,10 @@ class RecomputeScheduler(Scheduler):
                         self.encoder_cache_manager.allocate(request, i)
                         if self.ec_connector is not None:
                             self.ec_connector.update_state_after_alloc(request, i)
-        # Put back any skipped requests at the head of the waiting queue
-        if skipped_waiting_requests:
-            self.waiting.prepend_requests(skipped_waiting_requests)
+
+            # re-queue requests skipped in this pass ahead of older skipped items.
+            if step_skipped_waiting:
+                self.skipped_waiting.prepend_requests(step_skipped_waiting)

         # Check if the scheduling constraints are satisfied.
         total_num_scheduled_tokens = sum(num_scheduled_tokens.values())
@@ -738,6 +753,10 @@ class RecomputeScheduler(Scheduler):
        self.prev_step_scheduled_req_ids.clear()
        self.prev_step_scheduled_req_ids.update(num_scheduled_tokens.keys())

+        new_block_ids_to_zero = (
+            (self.kv_cache_manager.take_new_block_ids() or None) if self.needs_kv_cache_zeroing else None
+        )
+
        scheduler_output = RecomputeSchedulerOutput(
            scheduled_new_reqs=new_reqs_data,
            scheduled_cached_reqs=cached_reqs_data,
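`new_block_ids_to_zero` is materialized only when the platform needs freshly allocated KV blocks zeroed; otherwise it stays `None` so the worker can skip the pass entirely. A small sketch of that hand-off, where `needs_zeroing` and the collected ids are stand-ins for `self.needs_kv_cache_zeroing` and `take_new_block_ids()`:

```python
from __future__ import annotations


def block_ids_to_zero(needs_zeroing: bool, new_block_ids: list[int]) -> list[int] | None:
    # Only collect ids when zeroing is required; an empty list also maps to None.
    if not needs_zeroing:
        return None
    return new_block_ids or None


assert block_ids_to_zero(False, [3, 4]) is None
assert block_ids_to_zero(True, []) is None
assert block_ids_to_zero(True, [3, 4]) == [3, 4]
```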
@@ -753,6 +772,7 @@ class RecomputeScheduler(Scheduler):
             # the previous and the current steps.
             finished_req_ids=self.finished_req_ids,
             free_encoder_mm_hashes=self.encoder_cache_manager.get_freed_mm_hashes(),
+            new_block_ids_to_zero=new_block_ids_to_zero,
             recomputed_reqs=recomputed_reqs,
         )