diff --git a/tests/ut/test_platform.py b/tests/ut/test_platform.py index 129994e0..6b088ae7 100644 --- a/tests/ut/test_platform.py +++ b/tests/ut/test_platform.py @@ -234,7 +234,7 @@ class TestNPUPlatform(TestBase): return_value=AscendDeviceType._910_93) @patch("os.environ", {}) @patch( - "vllm_ascend.core.recompute_schedule_config.RecomputeSchedulerConfig.initialize_from_config" + "vllm_ascend.core.recompute_scheduler.RecomputeSchedulerConfig.initialize_from_config" ) def test_check_and_update_config_basic_config_update( self, mock_init_recompute, mock_soc_version, mock_update_acl, @@ -266,7 +266,7 @@ class TestNPUPlatform(TestBase): return_value=AscendDeviceType._910_93) @patch("vllm_ascend.ascend_config.init_ascend_config") @patch( - "vllm_ascend.core.recompute_schedule_config.RecomputeSchedulerConfig.initialize_from_config" + "vllm_ascend.core.recompute_scheduler.RecomputeSchedulerConfig.initialize_from_config" ) def test_check_and_update_config_no_model_config_warning( self, mock_init_recompute, mock_init_ascend, mock_soc_version): @@ -291,7 +291,7 @@ class TestNPUPlatform(TestBase): return_value=AscendDeviceType._910_93) @patch("vllm_ascend.ascend_config.init_ascend_config") @patch( - "vllm_ascend.core.recompute_schedule_config.RecomputeSchedulerConfig.initialize_from_config" + "vllm_ascend.core.recompute_scheduler.RecomputeSchedulerConfig.initialize_from_config" ) def test_check_and_update_config_enforce_eager_mode( self, mock_init_recompute, mock_init_ascend, mock_soc_version): @@ -328,7 +328,7 @@ class TestNPUPlatform(TestBase): @patch("vllm_ascend.utils.update_default_aclgraph_sizes") @patch("vllm_ascend.ascend_config.init_ascend_config") @patch( - "vllm_ascend.core.recompute_schedule_config.RecomputeSchedulerConfig.initialize_from_config" + "vllm_ascend.core.recompute_scheduler.RecomputeSchedulerConfig.initialize_from_config" ) def test_check_and_update_config_unsupported_compilation_level( self, mock_init_recompute, mock_init_ascend, mock_update_default, @@ -397,7 +397,7 @@ class TestNPUPlatform(TestBase): return_value=AscendDeviceType._910_93) @patch("vllm_ascend.ascend_config.init_ascend_config") @patch( - "vllm_ascend.core.recompute_schedule_config.RecomputeSchedulerConfig.initialize_from_config" + "vllm_ascend.core.recompute_scheduler.RecomputeSchedulerConfig.initialize_from_config" ) def test_check_and_update_config_cache_config_block_size( self, mock_init_recompute, mock_init_ascend, mock_soc_version): @@ -424,7 +424,7 @@ class TestNPUPlatform(TestBase): return_value=AscendDeviceType._910_93) @patch("vllm_ascend.ascend_config.init_ascend_config") @patch( - "vllm_ascend.core.recompute_schedule_config.RecomputeSchedulerConfig.initialize_from_config" + "vllm_ascend.core.recompute_scheduler.RecomputeSchedulerConfig.initialize_from_config" ) def test_check_and_update_config_v1_worker_class_selection( self, mock_init_recompute, mock_init_ascend, mock_soc_version): @@ -462,7 +462,7 @@ class TestNPUPlatform(TestBase): @patch('vllm_ascend.utils.get_ascend_device_type', return_value=AscendDeviceType._310P) @patch( - "vllm_ascend.core.recompute_schedule_config.RecomputeSchedulerConfig.initialize_from_config" + "vllm_ascend.core.recompute_scheduler.RecomputeSchedulerConfig.initialize_from_config" ) def test_check_and_update_config_310p_no_custom_ops( self, mock_init_recompute, mock_soc_version, mock_init_ascend): diff --git a/vllm_ascend/core/recompute_schedule_config.py b/vllm_ascend/core/recompute_schedule_config.py deleted file mode 100644 index be19a1c7..00000000 --- 
a/vllm_ascend/core/recompute_schedule_config.py +++ /dev/null @@ -1,39 +0,0 @@ -# -# Copyright (c) 2025 Huawei Technologies Co., Ltd. All Rights Reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# This file is a part of the vllm-ascend project. -# - -from dataclasses import dataclass, fields -from typing import Type, Union - -from vllm.config import SchedulerConfig - -MAX_INT = 2147483647 - - -@dataclass -class RecomputeSchedulerConfig(SchedulerConfig): - scheduler_cls: Union[str, Type[object]] = ( - "vllm_ascend.core.recompute_scheduler.RecomputeScheduler") - - @classmethod - def initialize_from_config(cls, vllm_scheduler_config: SchedulerConfig): - scheduler_config = { - field.name: getattr(vllm_scheduler_config, field.name) - for field in fields(vllm_scheduler_config) if field.init - } - scheduler_config["scheduler_cls"] = ( - "vllm_ascend.core.recompute_scheduler.RecomputeScheduler") - return cls(**scheduler_config) diff --git a/vllm_ascend/core/recompute_scheduler.py b/vllm_ascend/core/recompute_scheduler.py index a99e01cf..48aa67a2 100644 --- a/vllm_ascend/core/recompute_scheduler.py +++ b/vllm_ascend/core/recompute_scheduler.py @@ -13,194 +13,94 @@ # See the License for the specific language governing permissions and # limitations under the License. # This file is a part of the vllm-ascend project. 
+# Adapted from vllm-project/vllm/vllm/v1/core/sched/scheduler.py # + from __future__ import annotations -import itertools import time from collections import defaultdict -from collections.abc import Iterable -from dataclasses import dataclass -from typing import Any, Optional, Union +from dataclasses import dataclass, fields +from typing import Type, Union -import numpy as np -import numpy.typing as npt -from vllm.config import VllmConfig -from vllm.distributed.kv_events import EventPublisherFactory, KVEventBatch -from vllm.distributed.kv_transfer.kv_connector.factory import \ - KVConnectorFactory -from vllm.distributed.kv_transfer.kv_connector.v1 import (KVConnectorBase_V1, - KVConnectorRole) +from vllm._bc_linter import bc_linter_include +from vllm.config import SchedulerConfig, VllmConfig +from vllm.distributed.ec_transfer.ec_connector.base import ECConnectorMetadata +from vllm.distributed.kv_events import KVEventBatch from vllm.distributed.kv_transfer.kv_connector.v1.base import \ KVConnectorMetadata from vllm.distributed.kv_transfer.kv_connector.v1.metrics import \ KVConnectorStats -from vllm.logger import logger -from vllm.multimodal import MULTIMODAL_REGISTRY, MultiModalRegistry -from vllm.v1.core.encoder_cache_manager import (EncoderCacheManager, - compute_encoder_budget) -from vllm.v1.core.kv_cache_manager import KVCacheBlocks, KVCacheManager -from vllm.v1.core.sched.interface import SchedulerInterface -from vllm.v1.core.sched.output import CachedRequestData, NewRequestData +from vllm.logger import init_logger +from vllm.v1.core.kv_cache_manager import KVCacheBlocks +from vllm.v1.core.sched.async_scheduler import AsyncScheduler +from vllm.v1.core.sched.output import NewRequestData, SchedulerOutput from vllm.v1.core.sched.request_queue import (SchedulingPolicy, create_request_queue) +from vllm.v1.core.sched.scheduler import Scheduler from vllm.v1.core.sched.utils import check_stop, remove_all from vllm.v1.engine import (EngineCoreEventType, EngineCoreOutput, EngineCoreOutputs, FinishReason) -from vllm.v1.kv_cache_interface import KVCacheConfig -from vllm.v1.metrics.stats import SchedulerStats -from vllm.v1.outputs import DraftTokenIds, KVConnectorOutput, ModelRunnerOutput +from vllm.v1.outputs import ModelRunnerOutput from vllm.v1.request import Request, RequestStatus from vllm.v1.spec_decode.metrics import SpecDecodingStats -from vllm.v1.structured_output import StructuredOutputManager -from vllm.v1.utils import ConstantList +from vllm.v1.utils import ConstantList, record_function_or_nullcontext + +logger = init_logger(__name__) -class RecomputeScheduler(SchedulerInterface): - """This Scheduler extends vllm's original v1 scheduler of version 0.11 - to fix recomputing bug.""" +@dataclass +class RecomputeSchedulerConfig(SchedulerConfig): + scheduler_cls: Union[str, Type[object]] = ( + "vllm_ascend.core.recompute_scheduler.RecomputeScheduler") - def __init__( - self, - vllm_config: VllmConfig, - kv_cache_config: KVCacheConfig, - structured_output_manager: StructuredOutputManager, - block_size: Optional[int] = None, - mm_registry: MultiModalRegistry = MULTIMODAL_REGISTRY, - include_finished_set: bool = False, - log_stats: bool = False, - ) -> None: - self.vllm_config = vllm_config - self.scheduler_config = vllm_config.scheduler_config - self.cache_config = vllm_config.cache_config - self.lora_config = vllm_config.lora_config - self.kv_cache_config = kv_cache_config - self.kv_events_config = vllm_config.kv_events_config - self.parallel_config = vllm_config.parallel_config - 
self.log_stats = log_stats - self.structured_output_manager = structured_output_manager - self.is_encoder_decoder = vllm_config.model_config.is_encoder_decoder - - # include_finished_set controls whether a separate set of finished - # request ids should be included in the EngineCoreOutputs returned - # by update_from_outputs(). This is currently used in the multi-engine - # case to track request lifetimes efficiently. - self.finished_req_ids_dict: Optional[dict[int, set[str]]] = ( - defaultdict(set) if include_finished_set else None) - - # Scheduling constraints. - self.max_num_running_reqs = self.scheduler_config.max_num_seqs - self.max_num_scheduled_tokens = \ - self.scheduler_config.max_num_batched_tokens - self.max_model_len = self.vllm_config.model_config.max_model_len - self.enable_kv_cache_events = ( - self.kv_events_config is not None - and self.kv_events_config.enable_kv_cache_events) - - # Create KVConnector for the Scheduler. Note that each Worker - # will have a corresponding KVConnector with Role=WORKER. - # KV Connector pushes/pull of remote KVs for P/D and offloading. - self.connector = None - if self.vllm_config.kv_transfer_config is not None: - assert len(self.kv_cache_config.kv_cache_groups) == 1, ( - "Multiple KV cache groups are not currently supported " - "with KV connectors") - assert not self.is_encoder_decoder, ( - "Encoder-decoder models are not currently supported " - "with KV connectors") - self.connector = KVConnectorFactory.create_connector( - config=self.vllm_config, role=KVConnectorRole.SCHEDULER) - - self.kv_event_publisher = EventPublisherFactory.create( - self.kv_events_config, - self.parallel_config.data_parallel_rank, - ) - - num_gpu_blocks = self.cache_config.num_gpu_blocks - assert num_gpu_blocks is not None and num_gpu_blocks > 0 - - self.block_size = self.cache_config.block_size - - self.dcp_world_size = \ - vllm_config.parallel_config.decode_context_parallel_size - # Note(hc): The scheduler’s block_size must be multiplied - # by dcp_world_size, since block hashes are computed on the - # original full token sequence at a granularity of - # original_block_size × dcp_world_size. - if self.dcp_world_size > 1: - self.block_size *= self.dcp_world_size - - # req_id -> Request - self.requests: dict[str, Request] = {} - # Scheduling policy - if self.scheduler_config.policy == "priority": - self.policy = SchedulingPolicy.PRIORITY - elif self.scheduler_config.policy == "fcfs": - self.policy = SchedulingPolicy.FCFS + @classmethod + def initialize_from_config(cls, vllm_config: VllmConfig): + vllm_scheduler_config = vllm_config.scheduler_config + scheduler_config = { + field.name: getattr(vllm_scheduler_config, field.name) + for field in fields(vllm_scheduler_config) if field.init + } + if vllm_scheduler_config.async_scheduling: + scheduler_config["scheduler_cls"] = ( + "vllm_ascend.core.recompute_scheduler.AsyncRecomputeScheduler") else: - raise ValueError( - f"Unknown scheduling policy: {self.scheduler_config.policy}") - # Priority queues for requests. - self.waiting = create_request_queue(self.policy) - self.running: list[Request] = [] + scheduler_config["scheduler_cls"] = ( + "vllm_ascend.core.recompute_scheduler.RecomputeScheduler") + scheduler_config[ + "max_model_len"] = vllm_config.model_config.max_model_len + scheduler_config[ + "is_encoder_decoder"] = vllm_config.model_config.is_encoder_decoder + return cls(**scheduler_config) - # The request IDs that are finished in between the previous and the - # current steps. 
This is used to notify the workers about the finished - # requests so that they can free the cached states for those requests. - # This is flushed at the end of each scheduling step. - self.finished_req_ids: set[str] = set() - # KV Connector: requests in process of async KV loading or recving - self.finished_recving_kv_req_ids: set[str] = set() +@dataclass +class RecomputeReqInfo: + request_id: str + output_token_ids: ConstantList + client_index: int = 0 - # Encoder-related. - # Calculate encoder cache size if applicable - # NOTE: For now we use the same budget for both compute and space. - # This can be changed when we make encoder cache for embedding caching - # across requests. - encoder_compute_budget, encoder_cache_size = compute_encoder_budget( - model_config=vllm_config.model_config, - scheduler_config=vllm_config.scheduler_config, - mm_registry=mm_registry, - ) - # NOTE(woosuk): Here, "encoder" includes the vision encoder (and - # projector if needed) for MM models as well as encoder-decoder - # transformers. - self.max_num_encoder_input_tokens = encoder_compute_budget - # NOTE: For the models without encoder (e.g., text-only models), - # the encoder cache will not be initialized because cache size is 0 - # for these models. - self.encoder_cache_manager = EncoderCacheManager( - cache_size=encoder_cache_size) +@bc_linter_include +@dataclass +class RecomputeSchedulerOutput(SchedulerOutput): + recomputed_reqs: list[RecomputeReqInfo] | None = None - speculative_config = vllm_config.speculative_config - self.use_eagle = False - self.num_spec_tokens = self.num_lookahead_tokens = 0 - if speculative_config: - self.num_spec_tokens = speculative_config.num_speculative_tokens - if speculative_config.use_eagle(): - self.use_eagle = True - self.num_lookahead_tokens = self.num_spec_tokens - # Create the KV cache manager. - self.kv_cache_manager = KVCacheManager( - kv_cache_config=kv_cache_config, - max_model_len=self.max_model_len, - enable_caching=self.cache_config.enable_prefix_caching, - use_eagle=self.use_eagle, - log_stats=self.log_stats, - enable_kv_cache_events=self.enable_kv_cache_events, - dcp_world_size=self.dcp_world_size, - ) - self.use_pp = self.parallel_config.pipeline_parallel_size > 1 +class RecomputeScheduler(Scheduler): + running: list[Request] - def schedule(self) -> RecomputeSchedulerOutput: - """This scheduler extends vLLM's original v1 scheduler - by introducing a decoding instance recomputing scheduling strategy. - Specifically, if a request is preempted in the decoding instance, - it halts the process with the recomputed symbol and recalculates - its KVC in the prefill instance.""" + def schedule(self) -> SchedulerOutput: + # NOTE(woosuk) on the scheduling algorithm: + # There's no "decoding phase" nor "prefill phase" in the scheduler. + # Each request just has the num_computed_tokens and + # num_tokens_with_spec. num_tokens_with_spec = + # len(prompt_token_ids) + len(output_token_ids) + len(spec_token_ids). + # At each step, the scheduler tries to assign tokens to the requests + # so that each request's num_computed_tokens can catch up its + # num_tokens_with_spec. This is general enough to cover + # chunked prefills, prefix caching, speculative decoding, + # and the "jump decoding" optimization in the future. 
scheduled_new_reqs: list[Request] = [] scheduled_resumed_reqs: list[Request] = [] @@ -225,13 +125,26 @@ class RecomputeScheduler(SchedulerInterface): while req_index < len(self.running) and token_budget > 0: request = self.running[req_index] + if (request.num_output_placeholders > 0 + # This is (num_computed_tokens + 1) - (num_output_placeholders - 1). + # Since output placeholders are also included in the computed tokens + # count, we subtract (num_output_placeholders - 1) to remove any draft + # tokens, so that we can be sure no further steps are needed even if + # they are all rejected. + and request.num_computed_tokens + 2 - + request.num_output_placeholders + >= request.num_prompt_tokens + request.max_tokens): + # Async scheduling: Avoid scheduling an extra step when we are sure that + # the previous step has reached request.max_tokens. We don't schedule + # partial draft tokens since this prevents uniform decode optimizations. + req_index += 1 + continue + num_new_tokens = (request.num_tokens_with_spec + request.num_output_placeholders - request.num_computed_tokens) - if (0 < self.scheduler_config.long_prefill_token_threshold < - num_new_tokens): - num_new_tokens = ( - self.scheduler_config.long_prefill_token_threshold) + if 0 < self.scheduler_config.long_prefill_token_threshold < num_new_tokens: + num_new_tokens = self.scheduler_config.long_prefill_token_threshold num_new_tokens = min(num_new_tokens, token_budget) # Make sure the input position does not exceed the max model len. @@ -242,13 +155,21 @@ class RecomputeScheduler(SchedulerInterface): # Schedule encoder inputs. encoder_inputs_to_schedule = None + external_load_encoder_input: list[int] = [] new_encoder_compute_budget = encoder_compute_budget if request.has_encoder_inputs: - (encoder_inputs_to_schedule, num_new_tokens, - new_encoder_compute_budget - ) = self._try_schedule_encoder_inputs( - request, request.num_computed_tokens, num_new_tokens, - encoder_compute_budget) + ( + encoder_inputs_to_schedule, + num_new_tokens, + new_encoder_compute_budget, + external_load_encoder_input, + ) = self._try_schedule_encoder_inputs( + request, + request.num_computed_tokens, + num_new_tokens, + encoder_compute_budget, + shift_computed_tokens=1 if self.use_eagle else 0, + ) if num_new_tokens == 0: # The request cannot be scheduled because one of the following @@ -266,12 +187,23 @@ class RecomputeScheduler(SchedulerInterface): req_index += 1 continue - while True: - new_blocks = self.kv_cache_manager.allocate_slots( - request, - num_new_tokens, - num_lookahead_tokens=self.num_lookahead_tokens) - if new_blocks is None: + # Schedule newly needed KV blocks for the request. + with record_function_or_nullcontext("schedule: allocate_slots"): + while True: + new_blocks = self.kv_cache_manager.allocate_slots( + request, + num_new_tokens, + num_lookahead_tokens=self.num_lookahead_tokens, + ) + + if new_blocks is not None: + # The request can be scheduled. + break + + # The request cannot be scheduled. + # Preempt the lowest-priority request. + # NOTE: We add the preempted_req to recomputed_reqs in kv_consumer to + # drop the request to PD proxy. transfer_config = self.vllm_config.kv_transfer_config if transfer_config is not None and not transfer_config.is_kv_producer: recomputed_req = self.running.pop() @@ -281,11 +213,8 @@ class RecomputeScheduler(SchedulerInterface): recomputed_req.output_token_ids, recomputed_req.client_index)) if recomputed_req == request: - can_schedule = False break else: - # The request cannot be scheduled. 
- # Preempt the lowest-priority request. if self.policy == SchedulingPolicy.PRIORITY: preempted_req = max( self.running, @@ -294,31 +223,36 @@ class RecomputeScheduler(SchedulerInterface): self.running.remove(preempted_req) if preempted_req in scheduled_running_reqs: scheduled_running_reqs.remove(preempted_req) + token_budget += num_scheduled_tokens[ + preempted_req.request_id] + req_to_new_blocks.pop(preempted_req.request_id) + num_scheduled_tokens.pop( + preempted_req.request_id) + scheduled_spec_decode_tokens.pop( + preempted_req.request_id, None) + preempted_encoder_inputs = scheduled_encoder_inputs.pop( + preempted_req.request_id, None) + if preempted_encoder_inputs: + # Restore encoder compute budget if the preempted + # request had encoder inputs scheduled in this step. + num_tokens_to_restore = sum( + preempted_req.get_num_encoder_tokens(i) + for i in preempted_encoder_inputs) + encoder_compute_budget += num_tokens_to_restore + req_index -= 1 else: preempted_req = self.running.pop() - self.kv_cache_manager.free(preempted_req) - self.encoder_cache_manager.free(preempted_req) - preempted_req.status = RequestStatus.PREEMPTED - preempted_req.num_computed_tokens = 0 - if self.log_stats: - preempted_req.record_event( - EngineCoreEventType.PREEMPTED, - scheduled_timestamp) - - self.waiting.prepend_request(preempted_req) + self._preempt_request(preempted_req, + scheduled_timestamp) preempted_reqs.append(preempted_req) if preempted_req == request: - # No more request to preempt. - can_schedule = False + # No more request to preempt. Cannot schedule this request. break - else: - # The request can be scheduled. - can_schedule = True - break - if not can_schedule: + + if new_blocks is None: + # Cannot schedule this request. break - assert new_blocks is not None # Schedule the request. scheduled_running_reqs.append(request) @@ -331,12 +265,16 @@ class RecomputeScheduler(SchedulerInterface): if request.spec_token_ids: num_scheduled_spec_tokens = (num_new_tokens + request.num_computed_tokens - - request.num_tokens) + request.num_tokens - + request.num_output_placeholders) if num_scheduled_spec_tokens > 0: # Trim spec_token_ids list to num_scheduled_spec_tokens. del request.spec_token_ids[num_scheduled_spec_tokens:] scheduled_spec_decode_tokens[request.request_id] = ( request.spec_token_ids) + # New spec tokens will be set in `update_draft_token_ids` before the + # next step when applicable. + request.spec_token_ids = [] # Encoder-related. if encoder_inputs_to_schedule: @@ -346,6 +284,11 @@ class RecomputeScheduler(SchedulerInterface): for i in encoder_inputs_to_schedule: self.encoder_cache_manager.allocate(request, i) encoder_compute_budget = new_encoder_compute_budget + if external_load_encoder_input: + for i in external_load_encoder_input: + self.encoder_cache_manager.allocate(request, i) + if self.ec_connector is not None: + self.ec_connector.update_state_after_alloc(request, i) # Record the LoRAs in scheduled_running_reqs scheduled_loras: set[int] = set() @@ -375,7 +318,8 @@ class RecomputeScheduler(SchedulerInterface): else: logger.debug( "%s is still in WAITING_FOR_REMOTE_KVS state.", - request.request_id) + request.request_id, + ) self.waiting.pop_request() skipped_waiting_requests.prepend_request(request) continue @@ -407,17 +351,16 @@ class RecomputeScheduler(SchedulerInterface): # Get already-cached tokens. if request.num_computed_tokens == 0: # Get locally-cached tokens. 
- new_computed_blocks, num_new_local_computed_tokens = \ - self.kv_cache_manager.get_computed_blocks( - request) + new_computed_blocks, num_new_local_computed_tokens = ( + self.kv_cache_manager.get_computed_blocks(request)) # Get externally-cached tokens if using a KVConnector. if self.connector is not None: - num_external_computed_tokens, load_kv_async = ( + ext_tokens, load_kv_async = ( self.connector.get_num_new_matched_tokens( request, num_new_local_computed_tokens)) - if num_external_computed_tokens is None: + if ext_tokens is None: # The request cannot be scheduled because # the KVConnector couldn't determine # the number of matched tokens. @@ -425,53 +368,62 @@ class RecomputeScheduler(SchedulerInterface): skipped_waiting_requests.prepend_request(request) continue + request.num_external_computed_tokens = ext_tokens + num_external_computed_tokens = ext_tokens + # Total computed tokens (local + external). num_computed_tokens = (num_new_local_computed_tokens + num_external_computed_tokens) - # KVTransfer: WAITING reqs have num_computed_tokens > 0 - # after async KV recvs are completed. else: - new_computed_blocks = ( - self.kv_cache_manager.create_empty_block_list()) + # KVTransfer: WAITING reqs have num_computed_tokens > 0 + # after async KV recvs are completed. + new_computed_blocks = self.kv_cache_manager.empty_kv_cache_blocks num_new_local_computed_tokens = 0 num_computed_tokens = request.num_computed_tokens encoder_inputs_to_schedule = None + external_load_encoder_input = [] new_encoder_compute_budget = encoder_compute_budget - # KVTransfer: loading remote KV, do not allocate for new work. if load_kv_async: + # KVTransfer: loading remote KV, do not allocate for new work. assert num_external_computed_tokens > 0 num_new_tokens = 0 - # Number of tokens to be scheduled. else: + # Number of tokens to be scheduled. # We use `request.num_tokens` instead of # `request.num_prompt_tokens` to consider the resumed # requests, which have output tokens. num_new_tokens = request.num_tokens - num_computed_tokens - if (0 < self.scheduler_config.long_prefill_token_threshold - < num_new_tokens): - num_new_tokens = ( - self.scheduler_config.long_prefill_token_threshold) + threshold = self.scheduler_config.long_prefill_token_threshold + if 0 < threshold < num_new_tokens: + num_new_tokens = threshold # chunked prefill has to be enabled explicitly to allow # pooling requests to be chunked - if not self.scheduler_config.enable_chunked_prefill and \ - num_new_tokens > token_budget: - self.waiting.pop_request() - skipped_waiting_requests.prepend_request(request) - continue + if (not self.scheduler_config.enable_chunked_prefill + and num_new_tokens > token_budget): + # If chunked_prefill is disabled, + # we can stop the scheduling here. + break num_new_tokens = min(num_new_tokens, token_budget) assert num_new_tokens > 0 # Schedule encoder inputs. if request.has_encoder_inputs: - (encoder_inputs_to_schedule, num_new_tokens, - new_encoder_compute_budget - ) = self._try_schedule_encoder_inputs( - request, num_computed_tokens, num_new_tokens, - encoder_compute_budget) + ( + encoder_inputs_to_schedule, + num_new_tokens, + new_encoder_compute_budget, + external_load_encoder_input, + ) = self._try_schedule_encoder_inputs( + request, + num_computed_tokens, + num_new_tokens, + encoder_compute_budget, + shift_computed_tokens=1 if self.use_eagle else 0, + ) if num_new_tokens == 0: # The request cannot be scheduled. 
break @@ -491,8 +443,8 @@ class RecomputeScheduler(SchedulerInterface): # always padded to the maximum length. If we support other # encoder-decoder models, this will need to be updated if we # want to only allocate what is needed. - num_encoder_tokens = \ - self.scheduler_config.max_num_encoder_input_tokens + num_encoder_tokens = ( + self.scheduler_config.max_num_encoder_input_tokens) else: num_encoder_tokens = 0 @@ -531,6 +483,8 @@ class RecomputeScheduler(SchedulerInterface): request.status = RequestStatus.WAITING_FOR_REMOTE_KVS continue + self._update_connector_prefix_cache_stats(request) + req_index += 1 self.running.append(request) if self.log_stats: @@ -563,7 +517,13 @@ class RecomputeScheduler(SchedulerInterface): for i in encoder_inputs_to_schedule: self.encoder_cache_manager.allocate(request, i) encoder_compute_budget = new_encoder_compute_budget - + # Allocate for external load encoder cache + if external_load_encoder_input: + for i in external_load_encoder_input: + self.encoder_cache_manager.allocate(request, i) + if self.ec_connector is not None: + self.ec_connector.update_state_after_alloc( + request, i) # Put back any skipped requests at the head of the waiting queue if skipped_waiting_requests: self.waiting.prepend_requests(skipped_waiting_requests) @@ -571,42 +531,59 @@ class RecomputeScheduler(SchedulerInterface): # Check if the scheduling constraints are satisfied. total_num_scheduled_tokens = sum(num_scheduled_tokens.values()) assert total_num_scheduled_tokens <= self.max_num_scheduled_tokens + assert token_budget >= 0 assert len(self.running) <= self.max_num_running_reqs # Since some requests in the RUNNING queue may not be scheduled in # this step, the total number of scheduled requests can be smaller than # len(self.running). - assert (len(scheduled_new_reqs) + len(scheduled_resumed_reqs) + - len(scheduled_running_reqs) <= len(self.running)) + assert len(scheduled_new_reqs) + len(scheduled_resumed_reqs) + len( + scheduled_running_reqs) <= len(self.running) # Get the longest common prefix among all requests in the running queue. # This can be potentially used for cascade attention. num_common_prefix_blocks = [0] * len( self.kv_cache_config.kv_cache_groups) - if self.running: - any_request = self.running[0] - num_common_prefix_blocks = ( - self.kv_cache_manager.get_num_common_prefix_blocks( - any_request.request_id)) + with record_function_or_nullcontext( + "schedule: get_num_common_prefix_blocks"): + if self.running: + any_request = self.running[0] + num_common_prefix_blocks = ( + self.kv_cache_manager.get_num_common_prefix_blocks( + any_request.request_id)) # Construct the scheduler output. 
- new_reqs_data = [ - NewRequestData.from_request( - req, req_to_new_blocks[req.request_id].get_block_ids()) - for req in scheduled_new_reqs - ] - cached_reqs_data = self._make_cached_request_data( - scheduled_running_reqs, - scheduled_resumed_reqs, - num_scheduled_tokens, - scheduled_spec_decode_tokens, - req_to_new_blocks, - ) - scheduled_requests = (scheduled_new_reqs + scheduled_running_reqs + - scheduled_resumed_reqs) - structured_output_request_ids, grammar_bitmask = ( - self.get_grammar_bitmask(scheduled_requests, - scheduled_spec_decode_tokens)) + if self.use_v2_model_runner: + scheduled_new_reqs = scheduled_new_reqs + scheduled_resumed_reqs + scheduled_resumed_reqs = [] + new_reqs_data = [ + NewRequestData.from_request( + req, + req_to_new_blocks[req.request_id].get_block_ids(), + req._all_token_ids, + ) for req in scheduled_new_reqs + ] + else: + new_reqs_data = [ + NewRequestData.from_request( + req, req_to_new_blocks[req.request_id].get_block_ids()) + for req in scheduled_new_reqs + ] + + with record_function_or_nullcontext( + "schedule: make_cached_request_data"): + cached_reqs_data = self._make_cached_request_data( + scheduled_running_reqs, + scheduled_resumed_reqs, + num_scheduled_tokens, + scheduled_spec_decode_tokens, + req_to_new_blocks, + ) + + # Record the request ids that were scheduled in this step. + self.prev_step_scheduled_req_ids.clear() + self.prev_step_scheduled_req_ids.update(num_scheduled_tokens.keys()) + scheduler_output = RecomputeSchedulerOutput( scheduled_new_reqs=new_reqs_data, scheduled_cached_reqs=cached_reqs_data, @@ -615,6 +592,8 @@ class RecomputeScheduler(SchedulerInterface): scheduled_spec_decode_tokens=scheduled_spec_decode_tokens, scheduled_encoder_inputs=scheduled_encoder_inputs, num_common_prefix_blocks=num_common_prefix_blocks, + preempted_req_ids={req.request_id + for req in preempted_reqs}, # finished_req_ids is an existing state in the scheduler, # instead of being newly scheduled in this step. # It contains the request IDs that are finished in between @@ -622,8 +601,6 @@ class RecomputeScheduler(SchedulerInterface): finished_req_ids=self.finished_req_ids, free_encoder_mm_hashes=self.encoder_cache_manager. get_freed_mm_hashes(), - structured_output_request_ids=structured_output_request_ids, - grammar_bitmask=grammar_bitmask, recomputed_reqs=recomputed_reqs, ) @@ -632,261 +609,20 @@ class RecomputeScheduler(SchedulerInterface): # 2. Wrap up all the KV cache load / save ops into an opaque object # 3. 
Clear the internal states of the connector if self.connector is not None: - meta = self.connector.build_connector_meta(scheduler_output) + meta: KVConnectorMetadata = self.connector.build_connector_meta( + scheduler_output) scheduler_output.kv_connector_metadata = meta - # collect KV cache events from KV cache manager - events = self.kv_cache_manager.take_events() + # Build the connector meta for ECConnector + if self.ec_connector is not None: + ec_meta: ECConnectorMetadata = self.ec_connector.build_connector_meta( + scheduler_output) + scheduler_output.ec_connector_metadata = ec_meta - # collect KV cache events from connector - if self.connector is not None: - connector_events = self.connector.take_events() - if connector_events: - if events is None: - events = list(connector_events) - else: - events.extend(connector_events) - - # publish collected KV cache events - if events: - batch = KVEventBatch(ts=time.time(), events=events) - self.kv_event_publisher.publish(batch) - - self._update_after_schedule(scheduler_output) + with record_function_or_nullcontext("schedule: update_after_schedule"): + self._update_after_schedule(scheduler_output) return scheduler_output - def _update_after_schedule( - self, - scheduler_output: RecomputeSchedulerOutput, - ) -> None: - # Advance the number of computed tokens for the request AFTER - # the request is scheduled. - # 1. The scheduler_output of the current step has to include the - # original number of scheduled tokens to determine input IDs. - # 2. Advance the number of computed tokens here allowing us to - # schedule the prefill request again immediately in the next - # scheduling step. - # 3. If some tokens (e.g. spec tokens) are rejected later, the number of - # computed tokens will be adjusted in update_from_output. - num_scheduled_tokens = scheduler_output.num_scheduled_tokens - for req_id, num_scheduled_token in num_scheduled_tokens.items(): - request = self.requests[req_id] - request.num_computed_tokens += num_scheduled_token - - # NOTE: _free_encoder_inputs relies on num_computed_tokens, which - # may be updated again in _update_from_output for speculative - # decoding. However, it is safe to call the method here because - # encoder inputs are always part of the prompt, not the output, - # and thus are unaffected by speculative decoding. - if request.has_encoder_inputs: - self._free_encoder_inputs(request) - - # Clear the finished request IDs. - # NOTE: We shouldn't do self.finished_req_ids.clear() here because - # it will also affect the scheduler output. - self.finished_req_ids = set() - - def _make_cached_request_data( - self, - running_reqs: list[Request], - resumed_reqs: list[Request], - num_scheduled_tokens: dict[str, int], - spec_decode_tokens: dict[str, list[int]], - req_to_new_blocks: dict[str, KVCacheBlocks], - ) -> CachedRequestData: - req_ids: list[str] = [] - new_token_ids: list[list[int]] = [] - new_block_ids: list[Optional[tuple[list[int], ...]]] = [] - num_computed_tokens: list[int] = [] - - use_connector = self.connector is not None - for req in itertools.chain(running_reqs, resumed_reqs): - req_id = req.request_id - req_ids.append(req_id) - num_tokens = (num_scheduled_tokens[req_id] - - len(spec_decode_tokens.get(req_id, ()))) - if self.use_pp: - # When using PP, the scheduler sends the sampled tokens back, - # because there's no direct communication between the first- - # stage worker and the last-stage worker. Otherwise, we don't - # need to send the sampled tokens back because the model runner - # will cache them. 
- token_ids = req.all_token_ids[req.num_computed_tokens:req. - num_computed_tokens + num_tokens] - new_token_ids.append(token_ids) - elif use_connector: - # When using a KVConnector, we add a placeholder to avoid index - # out of bounds errors. TODO: Remove this once the KVConnector - # is updated to handle token IDs properly. - new_token_ids.append([]) - new_block_ids.append( - req_to_new_blocks[req_id].get_block_ids(allow_none=True)) - num_computed_tokens.append(req.num_computed_tokens) - # Because resumed_reqs is usually empty, it is more efficient to do - # in-place appending so that we don't need to allocate a new list. - resumed_from_preemption = [False] * len(running_reqs) - resumed_from_preemption += [True] * len(resumed_reqs) - - return CachedRequestData( - req_ids=req_ids, - resumed_from_preemption=resumed_from_preemption, - new_token_ids=new_token_ids, - new_block_ids=new_block_ids, - num_computed_tokens=num_computed_tokens, - ) - - def _try_schedule_encoder_inputs( - self, - request: Request, - num_computed_tokens: int, - num_new_tokens: int, - encoder_compute_budget: int, - ) -> tuple[list[int], int, int]: - """ - Determine which encoder inputs need to be scheduled in the current step, - and update `num_new_tokens` and encoder token budget accordingly. - - An encoder input will be scheduled if: - - Its output tokens overlap with the range of tokens being computed - in this step, i.e., - [num_computed_tokens, num_computed_tokens + num_new_tokens). - - It is not already computed and stored in the encoder cache. - - There is sufficient encoder token budget to process it. - - The encoder cache has space to store it. - - If an encoder input cannot be scheduled due to cache or budget - limitations, the method adjusts `num_new_tokens` to schedule only the - decoder tokens up to just before the unschedulable encoder input. - - Note that num_computed_tokens includes both locally cached - blocks and externally cached blocks (via KVConnector). - """ - if num_new_tokens == 0 or not request.has_encoder_inputs: - return [], num_new_tokens, encoder_compute_budget - encoder_inputs_to_schedule: list[int] = [] - mm_features = request.mm_features - assert mm_features is not None - assert len(mm_features) > 0 - - # NOTE: since scheduler operates on the request level (possibly with - # multiple encoder inputs per request), we need to create temporary - # trackers for accounting at the encoder input level. - mm_hashes_to_schedule = set() - num_tokens_to_schedule = 0 - for i, mm_feature in enumerate(mm_features): - start_pos = mm_feature.mm_position.offset - num_encoder_tokens = mm_feature.mm_position.length - - # The encoder output is needed if the two ranges overlap: - # [num_computed_tokens, num_computed_tokens + num_new_tokens) and - # [start_pos, start_pos + num_encoder_tokens) - if start_pos >= num_computed_tokens + num_new_tokens: - # The encoder input is not needed in this step. - break - - if self.is_encoder_decoder and num_computed_tokens > 0: - assert start_pos == 0, ( - "Encoder input should be processed at the beginning of " - "the sequence when encoder-decoder models are used.") - # Encoder input has already been computed - # The calculation here is a bit different. We don't turn encoder - # output into tokens that get processed by the decoder and - # reflected in num_computed_tokens. Instead, start_pos reflects - # the position where we need to ensure we calculate encoder - # inputs. This should always be 0 to ensure we calculate encoder - # inputs before running the decoder. 
Once we've calculated some - # decoder tokens (num_computed_tokens > 0), then we know we - # already calculated encoder inputs and can skip here. - continue - elif start_pos + num_encoder_tokens <= num_computed_tokens: - # The encoder input is already computed and stored - # in the decoder's KV cache. - continue - - if not self.is_encoder_decoder: - # We are not using the encoder cache for encoder-decoder models, - # yet. - if request.mm_features[i].identifier in mm_hashes_to_schedule: - # The same encoder input has already been scheduled in the - # current step. - continue - - if self.encoder_cache_manager.check_and_update_cache( - request, i): - # The encoder input is already computed and cached from a - # previous step. - continue - - # If no encoder input chunking is allowed, we do not want to - # partially schedule a multimodal item. If the scheduled range would - # only cover part of the mm input, roll back to before the mm item. - if (self.scheduler_config.disable_chunked_mm_input - and num_computed_tokens < start_pos - and (num_computed_tokens + num_new_tokens) - < (start_pos + num_encoder_tokens)): - num_new_tokens = start_pos - num_computed_tokens - break - - if not self.encoder_cache_manager.can_allocate( - request, i, encoder_compute_budget, - num_tokens_to_schedule): - # The encoder cache is full or the encoder budget is exhausted. - # NOTE(woosuk): We assume that the encoder input tokens should - # be processed altogether, as the encoder usually uses - # bidirectional attention. - if num_computed_tokens < start_pos: - # We only schedule the decoder tokens just before the - # encoder input. - num_new_tokens = start_pos - num_computed_tokens - else: - # Because of prefix caching, num_computed_tokens is greater - # than start_pos even though its encoder input is not - # available. In this case, we can't schedule any token for - # the request in this step. - num_new_tokens = 0 - break - - num_tokens_to_schedule += num_encoder_tokens - encoder_compute_budget -= num_encoder_tokens - mm_hashes_to_schedule.add(request.mm_features[i].identifier) - encoder_inputs_to_schedule.append(i) - - return ( - encoder_inputs_to_schedule, - num_new_tokens, - encoder_compute_budget, - ) - - def get_grammar_bitmask( - self, - requests: list[Request], - scheduled_spec_decode_tokens: dict[str, list[int]], - ): - # NOTE: structured_output_request_ids maps - # a request's (request that uses structured output) - # request_id to its index in the batch. - # This will help us determine to slice the grammar bitmask - # and only applies valid mask for requests that - # uses structured decoding. - structured_output_request_ids: dict[str, int] = {} - for i, req in enumerate(requests): - if req.use_structured_output: - # PERF: in case of chunked prefill, - # request might not include any new tokens. - # Therefore, we might introduce some additional - # cycle to fill in the bitmask, which could be a big no-op. 
- structured_output_request_ids[req.request_id] = i - - if not structured_output_request_ids: - bitmask = None - else: - bitmask = self.structured_output_manager.grammar_bitmask( - self.requests, - structured_output_request_ids, - scheduled_spec_decode_tokens, - ) - return structured_output_request_ids, bitmask - def update_from_output( self, scheduler_output: RecomputeSchedulerOutput, @@ -901,18 +637,34 @@ class RecomputeScheduler(SchedulerInterface): kv_connector_output = model_runner_output.kv_connector_output outputs: dict[int, list[EngineCoreOutput]] = defaultdict(list) - spec_decoding_stats: Optional[SpecDecodingStats] = None - kv_connector_stats = (kv_connector_output.kv_connector_stats - if kv_connector_output else None) + spec_decoding_stats: SpecDecodingStats | None = None + kv_connector_stats: KVConnectorStats | None = ( + kv_connector_output.kv_connector_stats + if kv_connector_output else None) + if kv_connector_stats and self.connector: + kv_stats = self.connector.get_kv_connector_stats() + if kv_stats: + kv_connector_stats = kv_connector_stats.aggregate(kv_stats) + + failed_kv_load_req_ids = None + if kv_connector_output and kv_connector_output.invalid_block_ids: + # These blocks contain externally computed tokens that failed to + # load. Identify affected requests and adjust their computed token + # count to trigger recomputation of the invalid blocks. + failed_kv_load_req_ids = self._handle_invalid_blocks( + kv_connector_output.invalid_block_ids) + # return recomputed requests as EngineCoreOutput - for req_info in scheduler_output.recomputed_reqs: - outputs[req_info.client_index].append( - EngineCoreOutput( - request_id=req_info.request_id, - finish_reason=FinishReason.STOP, - new_token_ids=[req_info.output_token_ids[-1]], - stop_reason="recomputed", - )) + if scheduler_output.recomputed_reqs is not None: + for req_info in scheduler_output.recomputed_reqs: + outputs[req_info.client_index].append( + EngineCoreOutput( + request_id=req_info.request_id, + finish_reason=FinishReason.STOP, + new_token_ids=[req_info.output_token_ids[-1]], + stop_reason="recomputed", + )) + # NOTE(woosuk): As len(num_scheduled_tokens) can be up to 1K or more, # the below loop can be a performance bottleneck. We should do our best # to avoid expensive operations inside the loop. @@ -920,6 +672,9 @@ class RecomputeScheduler(SchedulerInterface): stopped_preempted_reqs: set[Request] = set() for req_id, num_tokens_scheduled in num_scheduled_tokens.items(): assert num_tokens_scheduled > 0 + if failed_kv_load_req_ids and req_id in failed_kv_load_req_ids: + # Skip requests that were recovered from KV load failure + continue request = self.requests.get(req_id) if request is None: # The request is already finished. This can happen if the @@ -928,9 +683,8 @@ class RecomputeScheduler(SchedulerInterface): continue req_index = model_runner_output.req_id_to_index[req_id] - generated_token_ids: list[int] = ( - sampled_token_ids[req_index].tolist() - if sampled_token_ids else []) + generated_token_ids = (sampled_token_ids[req_index] + if sampled_token_ids else []) scheduled_spec_token_ids = ( scheduler_output.scheduled_spec_decode_tokens.get(req_id)) @@ -943,11 +697,17 @@ class RecomputeScheduler(SchedulerInterface): # tokens and rejections. If some tokens are rejected, # num_computed_tokens is decreased by the number of rejected # tokens. 
- request.num_computed_tokens -= num_rejected + if request.num_computed_tokens > 0: + request.num_computed_tokens -= num_rejected + # If async scheduling, num_output_placeholders also includes + # the scheduled spec tokens count and so is similarly adjusted. + if request.num_output_placeholders > 0: + request.num_output_placeholders -= num_rejected spec_decoding_stats = self.make_spec_decoding_stats( spec_decoding_stats, num_draft_tokens=num_draft_tokens, - num_accepted_tokens=num_accepted) + num_accepted_tokens=num_accepted, + ) stopped = False new_logprobs = None @@ -975,18 +735,18 @@ class RecomputeScheduler(SchedulerInterface): stopped_preempted_reqs.add(request) # Extract sample logprobs if needed. - if request.sampling_params is not None \ - and request.sampling_params.logprobs is not None and logprobs: - # NOTE: once we support N tokens per step (spec decode), - # the outer lists can be of length > 1. - new_logprobs = logprobs.slice(req_index, req_index + 1) + if (request.sampling_params is not None + and request.sampling_params.logprobs is not None + and logprobs): + new_logprobs = logprobs.slice_request(req_index, + len(new_token_ids)) if new_token_ids and self.structured_output_manager.should_advance( request): - # NOTE: structured_output_request - # should not be None if use_structured_output, we have - # checked above, so safe to ignore type warning - request.structured_output_request.grammar.accept_tokens( # type: ignore[union-attr] + struct_output_request = request.structured_output_request + assert struct_output_request is not None + assert struct_output_request.grammar is not None + struct_output_request.grammar.accept_tokens( req_id, new_token_ids) if num_nans_in_logits is not None and req_id in num_nans_in_logits: @@ -994,9 +754,7 @@ class RecomputeScheduler(SchedulerInterface): # Get prompt logprobs for this request. prompt_logprobs_tensors = prompt_logprobs_dict.get(req_id) - if new_token_ids or pooler_output is not None \ - or kv_transfer_params: - + if new_token_ids or pooler_output is not None or kv_transfer_params: # Add EngineCoreOutput for this Request. outputs[request.client_index].append( EngineCoreOutput( @@ -1011,6 +769,7 @@ class RecomputeScheduler(SchedulerInterface): kv_transfer_params=kv_transfer_params, trace_headers=request.trace_headers, num_cached_tokens=request.num_cached_tokens, + num_nans_in_logits=request.num_nans_in_logits, )) else: # Invariant: EngineCore returns no partial prefill outputs. @@ -1024,9 +783,25 @@ class RecomputeScheduler(SchedulerInterface): self.waiting.remove_requests(stopped_preempted_reqs) # KV Connector: update state for finished KV Transfers. - if model_runner_output.kv_connector_output: - self._update_from_kv_xfer_finished( - model_runner_output.kv_connector_output) + if kv_connector_output: + self._update_from_kv_xfer_finished(kv_connector_output) + + # collect KV cache events from KV cache manager + events = self.kv_cache_manager.take_events() + + # collect KV cache events from connector + if self.connector is not None: + connector_events = self.connector.take_events() + if connector_events: + if events is None: + events = list(connector_events) + else: + events.extend(connector_events) + + # publish collected KV cache events + if events: + batch = KVEventBatch(ts=time.time(), events=events) + self.kv_event_publisher.publish(batch) # Create EngineCoreOutputs for all clients that have requests with # outputs in this step. 
@@ -1059,334 +834,8 @@ class RecomputeScheduler(SchedulerInterface): return engine_core_outputs - def _update_request_with_output( - self, - request: Request, - new_token_ids: list[int], - ) -> tuple[list[int], bool]: - # Append generated tokens and check for stop. Note that if - # a request is still being prefilled, we expect the model runner - # to return empty token ids for the request. - stopped = False - for num_new, output_token_id in enumerate(new_token_ids, 1): - request.append_output_token_ids(output_token_id) - # Check for stop and update request state. - # This must be called before we make the EngineCoreOutput. - stopped = check_stop(request, self.max_model_len) - if stopped: - del new_token_ids[num_new:] # Trim new tokens if needed. - break - return new_token_ids, stopped +class AsyncRecomputeScheduler(AsyncScheduler, RecomputeScheduler): - def _free_encoder_inputs(self, request: Request) -> None: - cached_encoder_input_ids = ( - self.encoder_cache_manager.get_cached_input_ids(request)) - # OPTIMIZATION: Avoid list(set) if the set is empty. - if not cached_encoder_input_ids: - return - - # Here, we use list(set) to avoid modifying the set while iterating - # over it. - for input_id in list(cached_encoder_input_ids): - mm_feature = request.mm_features[input_id] - start_pos = mm_feature.mm_position.offset - num_tokens = mm_feature.mm_position.length - if self.is_encoder_decoder and request.num_computed_tokens > 0: - # With Whisper, as soon as we've generated a single token, - # we know we're done with the encoder input. Cross Attention - # KVs have been calculated and cached already. - self.encoder_cache_manager.free_encoder_input( - request, input_id) - elif start_pos + num_tokens <= request.num_computed_tokens: - # The encoder output is already processed and stored - # in the decoder's KV cache. - self.encoder_cache_manager.free_encoder_input( - request, input_id) - - def update_draft_token_ids( - self, - draft_token_ids: DraftTokenIds, - ) -> None: - for req_id, spec_token_ids in zip( - draft_token_ids.req_ids, - draft_token_ids.draft_token_ids, - ): - request = self.requests.get(req_id) - if request is None or request.is_finished(): - # The request may have been finished. Skip. - continue - - # Add newly generated spec token ids to the request. - if not spec_token_ids: - # NOTE(woosuk): request.spec_token_ids should be updated. - request.spec_token_ids.clear() - elif self.structured_output_manager.should_advance(request): - metadata = request.structured_output_request - request.spec_token_ids = metadata.grammar.validate_tokens( # type: ignore[union-attr] - spec_token_ids) - else: - request.spec_token_ids = spec_token_ids - - def get_request_counts(self) -> tuple[int, int]: - """Returns (num_running_reqs, num_waiting_reqs).""" - return len(self.running), len(self.waiting) - - def add_request(self, request: Request) -> None: - self.waiting.add_request(request) - self.requests[request.request_id] = request - if self.log_stats: - request.record_event(EngineCoreEventType.QUEUED) - - def finish_requests( - self, - request_ids: Union[str, Iterable[str]], - finished_status: RequestStatus, - ) -> None: - """Handles the finish signal from outside the scheduler. - - For example, the API server can abort a request when the client - disconnects. 
- """ - assert RequestStatus.is_finished(finished_status) - if isinstance(request_ids, str): - request_ids = (request_ids, ) - else: - request_ids = set(request_ids) - - running_requests_to_remove = set() - waiting_requests_to_remove = [] - valid_requests = [] - - # First pass: collect requests to remove from queues - for req_id in request_ids: - request = self.requests.get(req_id) - if request is None: - # Invalid request ID. - continue - - valid_requests.append(request) - if request.status == RequestStatus.RUNNING: - running_requests_to_remove.add(request) - else: - waiting_requests_to_remove.append(request) - - # Remove all requests from queues at once for better efficiency - if running_requests_to_remove: - self.running = remove_all(self.running, running_requests_to_remove) - if waiting_requests_to_remove: - self.waiting.remove_requests(waiting_requests_to_remove) - - # Second pass: set status and free requests - for request in valid_requests: - request.status = finished_status - self._free_request(request) - - def _free_request(self, request: Request) -> Optional[dict[str, Any]]: - assert request.is_finished() - - delay_free_blocks, kv_xfer_params = self._connector_finished(request) - self.encoder_cache_manager.free(request) - request_id = request.request_id - self.finished_req_ids.add(request_id) - if self.finished_req_ids_dict is not None: - self.finished_req_ids_dict[request.client_index].add(request_id) - - if not delay_free_blocks: - self._free_blocks(request) - - return kv_xfer_params - - def _free_blocks(self, request: Request): - assert request.is_finished() - self.kv_cache_manager.free(request) - del self.requests[request.request_id] - - def get_num_unfinished_requests(self) -> int: - return len(self.waiting) + len(self.running) - - def has_finished_requests(self) -> bool: - return len(self.finished_req_ids) > 0 - - def reset_prefix_cache(self) -> bool: - return self.kv_cache_manager.reset_prefix_cache() - - def make_stats( - self, - spec_decoding_stats: Optional[SpecDecodingStats] = None, - kv_connector_stats: Optional[KVConnectorStats] = None, - ) -> Optional[SchedulerStats]: - if not self.log_stats: - return None - prefix_cache_stats = self.kv_cache_manager.make_prefix_cache_stats() - assert prefix_cache_stats is not None - return SchedulerStats(num_running_reqs=len(self.running), - num_waiting_reqs=len(self.waiting), - kv_cache_usage=self.kv_cache_manager.usage, - prefix_cache_stats=prefix_cache_stats, - spec_decoding_stats=spec_decoding_stats, - num_corrupted_reqs=sum(req.is_output_corrupted - for req in self.running), - kv_connector_stats=kv_connector_stats.data - if kv_connector_stats else None) - - def make_spec_decoding_stats( - self, - spec_decoding_stats: Optional[SpecDecodingStats], - num_draft_tokens: int, - num_accepted_tokens: int, - ) -> Optional[SpecDecodingStats]: - if not self.log_stats: - return None - if spec_decoding_stats is None: - spec_decoding_stats = SpecDecodingStats.new(self.num_spec_tokens) - spec_decoding_stats.observe_draft( - num_draft_tokens=num_draft_tokens, - num_accepted_tokens=num_accepted_tokens) - return spec_decoding_stats - - def shutdown(self) -> None: - if self.kv_event_publisher: - self.kv_event_publisher.shutdown() - if self.connector is not None: - self.connector.shutdown() - - ######################################################################## - # KV Connector Related Methods - ######################################################################## - - def get_kv_connector(self) -> Optional[KVConnectorBase_V1]: - return 
self.connector - - def _connector_finished( - self, request: Request) -> tuple[bool, Optional[dict[str, Any]]]: - """ - Invoke the KV connector request_finished() method if applicable. - - Returns optional kv transfer parameters to be included with the - request outputs. - """ - if self.connector is None: - return False, None - - (block_ids, ) = self.kv_cache_manager.get_block_ids(request.request_id) - return self.connector.request_finished(request, block_ids) - - def _update_waiting_for_remote_kv(self, request: Request) -> bool: - """ - KV Connector: check if the request_id is finished_recving. - - The finished_recving_kv_req_ids list is populated - on the previous steps()'s update_from_output based - on the worker side connector. - - When the kv transfer is ready, we cache the blocks - and the request state will be moved back to WAITING from - WAITING_FOR_REMOTE_KV. - """ - assert self.connector is not None - if request.request_id not in self.finished_recving_kv_req_ids: - return False - - # Now that the blocks are ready, actually cache them. - (block_ids, ) = self.kv_cache_manager.get_block_ids(request.request_id) - num_computed_tokens = len(block_ids) * self.block_size - # Handle the case where num request tokens less than one block. - num_computed_tokens = min(num_computed_tokens, request.num_tokens) - if num_computed_tokens == request.num_tokens: - num_computed_tokens -= 1 - # This will cache the blocks iff caching is enabled. - self.kv_cache_manager.cache_blocks(request, num_computed_tokens) - - # Update the request state for scheduling. - request.num_computed_tokens = num_computed_tokens - - # Return that we are ready. - self.finished_recving_kv_req_ids.remove(request.request_id) - return True - - def _update_from_kv_xfer_finished(self, - kv_connector_output: KVConnectorOutput): - """ - KV Connector: update the scheduler state based on the output. - - The Worker side connectors add finished_recving and - finished_sending reqs to the output. - * if finished_sending: free the blocks - # if finished_recving: add to state so we can - schedule the request during the next step. - """ - - if self.connector is not None: - self.connector.update_connector_output(kv_connector_output) - - # KV Connector:: update recv and send status from last step. - for req_id in (kv_connector_output.finished_recving or ()): - logger.debug("Finished recving KV transfer for request %s", req_id) - self.finished_recving_kv_req_ids.add(req_id) - for req_id in (kv_connector_output.finished_sending or ()): - logger.debug("Finished sending KV transfer for request %s", req_id) - if req_id not in self.requests: - logger.warning( - "Got finished sending KV transfer for request %s," - "but the request is already freed.", req_id) - else: - self._free_blocks(self.requests[req_id]) - - -@dataclass -class RecomputeReqInfo: - request_id: str - output_token_ids: ConstantList - client_index: int = 0 - - -@dataclass -class RecomputeSchedulerOutput: - - # list of the requests that are scheduled for the first time. - # We cache the request's data in each worker process, so that we don't - # need to re-send it every scheduling step. - scheduled_new_reqs: list[NewRequestData] - # list of the requests that have been scheduled before. - # Since the request's data is already cached in the worker processes, - # we only send the diff to minimize the communication cost. - scheduled_cached_reqs: CachedRequestData - - # req_id -> num_scheduled_tokens - # Number of tokens scheduled for each request. 
- num_scheduled_tokens: dict[str, int] - # Total number of tokens scheduled for all requests. - # Equal to sum(num_scheduled_tokens.values()) - total_num_scheduled_tokens: int - # req_id -> spec_token_ids - # If a request does not have any spec decode tokens, it will not be - # included in the dictionary. - scheduled_spec_decode_tokens: dict[str, list[int]] - # req_id -> encoder input indices that need processing. - # E.g., if a request has [0, 1], it could mean the vision encoder needs - # to process that the request's 0-th and 1-th images in the current step. - scheduled_encoder_inputs: dict[str, list[int]] - # Number of common prefix blocks for all requests in each KV cache group. - # This can be used for cascade attention. - num_common_prefix_blocks: list[int] - - # Request IDs that are finished in between the previous and the current - # steps. This is used to notify the workers about the finished requests - # so that they can free the cached states for those requests. - finished_req_ids: set[str] - # list of mm_hash strings associated with the encoder outputs to be - # freed from the encoder cache. - free_encoder_mm_hashes: list[str] - - # Dict of request ids to their index within the batch - # for filling the next token bitmask - structured_output_request_ids: dict[str, int] - # the bitmask for the whole batch - grammar_bitmask: Optional[npt.NDArray[np.int32]] - - # requests that need to recompute kv - recomputed_reqs: list[RecomputeReqInfo] - - # KV Cache Connector metadata. - kv_connector_metadata: Optional[KVConnectorMetadata] = None + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) diff --git a/vllm_ascend/platform.py b/vllm_ascend/platform.py index 7688a0cb..45c87477 100644 --- a/vllm_ascend/platform.py +++ b/vllm_ascend/platform.py @@ -298,10 +298,9 @@ class NPUPlatform(Platform): compilation_config.custom_ops = ["all"] if ascend_config.recompute_scheduler_enable: - from vllm_ascend.core.recompute_schedule_config import \ - RecomputeSchedulerConfig + from vllm_ascend.core.recompute_scheduler import RecomputeSchedulerConfig recompute_scheduler_config = RecomputeSchedulerConfig.initialize_from_config( - vllm_config.scheduler_config) + vllm_config) vllm_config.scheduler_config = recompute_scheduler_config # Extend original scheduler_config to use SchedulerDynamicBatch.
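
The relocated RecomputeSchedulerConfig (now defined in vllm_ascend/core/recompute_scheduler.py) takes the full VllmConfig in initialize_from_config() so it can copy max_model_len and is_encoder_decoder from model_config and select the async scheduler variant. A minimal sketch of the expected wiring, assuming a populated VllmConfig and the ascend_config flag used in vllm_ascend/platform.py; the helper name below is illustrative, not part of the patch:

    from vllm_ascend.core.recompute_scheduler import RecomputeSchedulerConfig

    def enable_recompute_scheduler(vllm_config, ascend_config):
        # Hypothetical helper mirroring the call site in vllm_ascend/platform.py.
        if not ascend_config.recompute_scheduler_enable:
            return vllm_config
        # initialize_from_config() now receives the whole VllmConfig: it copies the
        # init fields of the existing scheduler_config, fills in max_model_len and
        # is_encoder_decoder from model_config, and points scheduler_cls at
        # AsyncRecomputeScheduler when async_scheduling is set, otherwise at
        # RecomputeScheduler.
        vllm_config.scheduler_config = (
            RecomputeSchedulerConfig.initialize_from_config(vllm_config))
        return vllm_config

Unit tests that previously patched
vllm_ascend.core.recompute_schedule_config.RecomputeSchedulerConfig.initialize_from_config
now patch the new module path, as updated in tests/ut/test_platform.py above.

On the kv_consumer side, a preempted decode request is recorded in
RecomputeSchedulerOutput.recomputed_reqs and surfaced by update_from_output() as an
EngineCoreOutput with finish_reason=STOP and stop_reason="recomputed". How a PD proxy
consumes that signal is outside this patch; a hedged sketch with a hypothetical proxy
interface:

    from vllm.v1.engine import FinishReason

    def handle_engine_output(output, proxy):
        # "recomputed" marks a decode-side request dropped after preemption;
        # replay it through a prefill instance so its KV cache is rebuilt
        # before decoding resumes.
        if (output.finish_reason == FinishReason.STOP
                and output.stop_reason == "recomputed"):
            proxy.resubmit_to_prefill(output.request_id)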