diff --git a/vllm_ascend/core/recompute_scheduler.py b/vllm_ascend/core/recompute_scheduler.py index d5375e7e..b6fe93f8 100644 --- a/vllm_ascend/core/recompute_scheduler.py +++ b/vllm_ascend/core/recompute_scheduler.py @@ -30,8 +30,12 @@ from vllm.distributed.kv_transfer.kv_connector.v1.metrics import KVConnectorStat from vllm.logger import logger from vllm.v1.core.kv_cache_manager import KVCacheBlocks from vllm.v1.core.sched.async_scheduler import AsyncScheduler +from vllm.v1.core.sched.interface import PauseState from vllm.v1.core.sched.output import NewRequestData, SchedulerOutput -from vllm.v1.core.sched.request_queue import SchedulingPolicy, create_request_queue +from vllm.v1.core.sched.request_queue import ( + SchedulingPolicy, + create_request_queue, +) from vllm.v1.core.sched.scheduler import Scheduler from vllm.v1.core.sched.utils import remove_all from vllm.v1.engine import EngineCoreEventType, EngineCoreOutput, EngineCoreOutputs, FinishReason @@ -146,14 +150,14 @@ class RecomputeScheduler(Scheduler): request.num_prompt_tokens -= 1 if self.is_mtp_kv_consumer: request.spec_token_ids = [PLACEHOLDER_TOKEN_ID] * self.num_spec_tokens - self.waiting.add_request(request) + self._enqueue_waiting_request(request) self.requests[request.request_id] = request if self.log_stats: request.record_event(EngineCoreEventType.QUEUED) - def _update_waiting_for_remote_kv(self, request: Request) -> bool: + def _update_waiting_for_remote_kv(self, request: Request) -> None: """ - KV Connector: check if the request_id is finished_recving. + KV Connector: update request state after async recv is finished. The finished_recving_kv_req_ids list is populated on the previous steps()'s update_from_output based @@ -162,10 +166,13 @@ class RecomputeScheduler(Scheduler): When the kv transfer is ready, we cache the blocks and the request state will be moved back to WAITING from WAITING_FOR_REMOTE_KV. + + NOTE: The check for whether request.request_id is in + finished_recving_kv_req_ids is now done by the caller + (_try_promote_blocked_waiting_request in the parent Scheduler), + so this method is only called when the recv is confirmed finished. """ assert self.connector is not None - if request.request_id not in self.finished_recving_kv_req_ids: - return False if request.request_id in self.failed_recving_kv_req_ids: # Request had KV load failures; num_computed_tokens was already @@ -181,6 +188,8 @@ class RecomputeScheduler(Scheduler): self.failed_recving_kv_req_ids.remove(request.request_id) else: # Now that the blocks are ready, actually cache them. + # Use Ascend-specific block_ids logic to handle multi-group KV + # cache configurations (e.g. MLA) where len(block_ids) > 1. block_ids = self.kv_cache_manager.get_block_ids(request.request_id) if len(block_ids) == 1: num_computed_tokens = len(block_ids[0]) * self.block_size @@ -188,6 +197,8 @@ class RecomputeScheduler(Scheduler): num_computed_tokens = min(num_computed_tokens, request.num_tokens) else: num_computed_tokens = request.num_tokens + # on a full prompt hit, we need to re-compute the last token + # in order to be able to sample the next token if num_computed_tokens == request.num_tokens: num_computed_tokens -= 1 # This will cache the blocks iff caching is enabled. @@ -196,9 +207,11 @@ class RecomputeScheduler(Scheduler): # Update the request state for scheduling. request.num_computed_tokens = num_computed_tokens - # Return that we are ready. + # Count the number of prefix cached tokens. + if request.num_cached_tokens < 0: + request.num_cached_tokens = request.num_computed_tokens + self.finished_recving_kv_req_ids.remove(request.request_id) - return True def schedule(self) -> RecomputeSchedulerOutput: # NOTE(woosuk) on the scheduling algorithm: @@ -230,6 +243,12 @@ class RecomputeScheduler(Scheduler): # For logging. scheduled_timestamp = time.monotonic() + self.kv_cache_manager.new_step_starts() + + if self._pause_state == PauseState.PAUSED_ALL: + # Do not schedule any requests when paused. + token_budget = 0 + # First, schedule the RUNNING requests. req_index = 0 while req_index < len(self.running) and token_budget > 0: @@ -393,6 +412,8 @@ class RecomputeScheduler(Scheduler): # Allocate the encoder cache. for i in encoder_inputs_to_schedule: self.encoder_cache_manager.allocate(request, i) + if self.ec_connector is not None: + self.ec_connector.update_state_after_alloc(request, i) encoder_compute_budget = new_encoder_compute_budget if external_load_encoder_input: for i in external_load_encoder_input: @@ -410,54 +431,31 @@ class RecomputeScheduler(Scheduler): ) assert len(scheduled_loras) <= self.lora_config.max_loras - # Use a temporary RequestQueue to collect requests that need to be - # skipped and put back at the head of the waiting queue later - skipped_waiting_requests = create_request_queue(self.policy) - # Next, schedule the WAITING requests. - if not preempted_reqs and not recomputed_reqs: - while self.waiting and token_budget > 0: + if not preempted_reqs and not recomputed_reqs and self._pause_state == PauseState.UNPAUSED: + step_skipped_waiting = create_request_queue(self.policy) + + while (self.waiting or self.skipped_waiting) and token_budget > 0: if len(self.running) == self.max_num_running_reqs: break - request = self.waiting.peek_request() + request_queue = self._select_waiting_queue_for_scheduling() + assert request_queue is not None + + request = request_queue.peek_request() request_id = request.request_id - # KVTransfer: skip request if still waiting for remote kvs. - if request.status == RequestStatus.WAITING_FOR_REMOTE_KVS: - is_ready = self._update_waiting_for_remote_kv(request) - if is_ready: - if request.num_preemptions: - # We must be loading for a resumed preemption - # rather than a new request. - request.status = RequestStatus.PREEMPTED - else: - request.status = RequestStatus.WAITING - else: + # try to promote blocked statuses while traversing skipped queue. + if self._is_blocked_waiting_status(request.status) and not self._try_promote_blocked_waiting_request( + request + ): + if request.status == RequestStatus.WAITING_FOR_REMOTE_KVS: logger.debug( "%s is still in WAITING_FOR_REMOTE_KVS state.", request_id, ) - self.waiting.pop_request() - skipped_waiting_requests.prepend_request(request) - continue - - # Skip request if the structured output request is still waiting - # for FSM compilation. - if request.status == RequestStatus.WAITING_FOR_FSM: - structured_output_req = request.structured_output_request - if structured_output_req and structured_output_req.grammar: - request.status = RequestStatus.WAITING - else: - self.waiting.pop_request() - skipped_waiting_requests.prepend_request(request) - continue - - # Streaming: skip request if still waiting for next streaming req. - if request.status == RequestStatus.WAITING_FOR_STREAMING_REQ: - assert not request.streaming_queue - self.waiting.pop_request() - skipped_waiting_requests.prepend_request(request) + request_queue.pop_request() + step_skipped_waiting.prepend_request(request) continue # Check that adding the request still respects the max_loras @@ -471,8 +469,8 @@ class RecomputeScheduler(Scheduler): ) ): # Scheduling would exceed max_loras, skip. - self.waiting.pop_request() - skipped_waiting_requests.prepend_request(request) + request_queue.pop_request() + step_skipped_waiting.prepend_request(request) continue num_external_computed_tokens = 0 @@ -496,8 +494,8 @@ class RecomputeScheduler(Scheduler): # The request cannot be scheduled because # the KVConnector couldn't determine # the number of matched tokens. - self.waiting.pop_request() - skipped_waiting_requests.prepend_request(request) + request_queue.pop_request() + step_skipped_waiting.prepend_request(request) continue request.num_external_computed_tokens = ext_tokens @@ -508,6 +506,7 @@ class RecomputeScheduler(Scheduler): # Total computed tokens (local + external). num_computed_tokens = num_new_local_computed_tokens + num_external_computed_tokens + assert num_computed_tokens <= request.num_tokens else: # KVTransfer: WAITING reqs have num_computed_tokens > 0 # after async KV recvs are completed. @@ -581,9 +580,10 @@ class RecomputeScheduler(Scheduler): # of local and remote blocks. effective_lookahead_tokens = 0 if request.num_computed_tokens == 0 else self.num_lookahead_tokens - num_encoder_tokens = ( - self._num_encoder_max_input_tokens if self.is_encoder_decoder and request.has_encoder_inputs else 0 - ) + # Determine if we need to allocate cross-attention blocks. + num_encoder_tokens = 0 + if self.is_encoder_decoder and request.has_encoder_inputs and encoder_inputs_to_schedule: + num_encoder_tokens = sum(request.get_num_encoder_embeds(i) for i in encoder_inputs_to_schedule) new_blocks = self.kv_cache_manager.allocate_slots( request, @@ -622,14 +622,26 @@ class RecomputeScheduler(Scheduler): preempted=request.num_preemptions > 0, ) - # Request was already popped from self.waiting - # unless it was re-added above due to new_blocks being None. - request = self.waiting.pop_request() + request = request_queue.pop_request() if load_kv_async: # If loading async, allocate memory and put request # into the WAITING_FOR_REMOTE_KV state. - skipped_waiting_requests.prepend_request(request) request.status = RequestStatus.WAITING_FOR_REMOTE_KVS + step_skipped_waiting.prepend_request(request) + # Set num_computed_tokens even though KVs are not yet loaded. + # request.num_computed_tokens will not be used anywhere until + # the request finished the KV transfer. + # + # If a transfer error is reported by the connector, + # request.num_computed_tokens will be re-set accordingly in + # _update_requests_with_invalid_blocks. + # + # When the transfer is finished, either successfully or not, + # request.num_computed_tokens will correctly reflect the number + # of computed tokens. + # _update_waiting_for_remote_kv will then cache + # only the successfully loaded tokens. + request.num_computed_tokens = num_computed_tokens continue # For spec_token_ids, the waiting queue has the same processing @@ -677,6 +689,8 @@ class RecomputeScheduler(Scheduler): # Allocate the encoder cache. for i in encoder_inputs_to_schedule: self.encoder_cache_manager.allocate(request, i) + if self.ec_connector is not None: + self.ec_connector.update_state_after_alloc(request, i) encoder_compute_budget = new_encoder_compute_budget # Allocate for external load encoder cache if external_load_encoder_input: @@ -684,9 +698,10 @@ class RecomputeScheduler(Scheduler): self.encoder_cache_manager.allocate(request, i) if self.ec_connector is not None: self.ec_connector.update_state_after_alloc(request, i) - # Put back any skipped requests at the head of the waiting queue - if skipped_waiting_requests: - self.waiting.prepend_requests(skipped_waiting_requests) + + # re-queue requests skipped in this pass ahead of older skipped items. + if step_skipped_waiting: + self.skipped_waiting.prepend_requests(step_skipped_waiting) # Check if the scheduling constraints are satisfied. total_num_scheduled_tokens = sum(num_scheduled_tokens.values()) @@ -738,6 +753,10 @@ class RecomputeScheduler(Scheduler): self.prev_step_scheduled_req_ids.clear() self.prev_step_scheduled_req_ids.update(num_scheduled_tokens.keys()) + new_block_ids_to_zero = ( + (self.kv_cache_manager.take_new_block_ids() or None) if self.needs_kv_cache_zeroing else None + ) + scheduler_output = RecomputeSchedulerOutput( scheduled_new_reqs=new_reqs_data, scheduled_cached_reqs=cached_reqs_data, @@ -753,6 +772,7 @@ class RecomputeScheduler(Scheduler): # the previous and the current steps. finished_req_ids=self.finished_req_ids, free_encoder_mm_hashes=self.encoder_cache_manager.get_freed_mm_hashes(), + new_block_ids_to_zero=new_block_ids_to_zero, recomputed_reqs=recomputed_reqs, )