diff --git a/vllm_ascend/patch/platform/__init__.py b/vllm_ascend/patch/platform/__init__.py
index 1e4e5b49..28d87efb 100644
--- a/vllm_ascend/patch/platform/__init__.py
+++ b/vllm_ascend/patch/platform/__init__.py
@@ -19,6 +19,7 @@ import os
 import vllm_ascend.patch.platform.patch_distributed  # noqa
 import vllm_ascend.patch.platform.patch_fusion_matcher_compat_ops  # noqa
 import vllm_ascend.patch.platform.patch_kv_cache_interface  # noqa
+from vllm_ascend import envs
 from vllm_ascend.utils import is_310p
 
 if not is_310p():
@@ -31,3 +32,6 @@ import vllm_ascend.patch.platform.patch_torch_accelerator  # noqa
 
 if os.getenv("DYNAMIC_EPLB", "false").lower() in ("true", "1") or os.getenv("EXPERT_MAP_RECORD", "false") == "true":
     import vllm_ascend.patch.platform.patch_multiproc_executor  # noqa
+
+if envs.VLLM_ASCEND_BALANCE_SCHEDULING:
+    import vllm_ascend.patch.platform.patch_balance_schedule  # noqa
diff --git a/vllm_ascend/patch/platform/patch_balance_schedule.py b/vllm_ascend/patch/platform/patch_balance_schedule.py
index 86b717a8..a8a70300 100644
--- a/vllm_ascend/patch/platform/patch_balance_schedule.py
+++ b/vllm_ascend/patch/platform/patch_balance_schedule.py
@@ -13,6 +13,7 @@ from vllm.multimodal import MULTIMODAL_REGISTRY, MultiModalRegistry
 from vllm.transformers_utils.config import maybe_register_config_serialize_by_value
 from vllm.utils.system_utils import decorate_logs, set_process_title
 from vllm.v1.core.kv_cache_manager import KVCacheBlocks
+from vllm.v1.core.sched.interface import PauseState
 from vllm.v1.core.sched.output import NewRequestData, SchedulerOutput
 from vllm.v1.core.sched.request_queue import SchedulingPolicy, create_request_queue
 from vllm.v1.core.sched.scheduler import Scheduler
@@ -65,6 +66,7 @@ class BalanceScheduler(Scheduler):
         # num_tokens_with_spec. This is general enough to cover
         # chunked prefills, prefix caching, speculative decoding,
         # and the "jump decoding" optimization in the future.
+
         scheduled_new_reqs: list[Request] = []
         scheduled_resumed_reqs: list[Request] = []
         scheduled_running_reqs: list[Request] = []
@@ -73,6 +75,10 @@ class BalanceScheduler(Scheduler):
         req_to_new_blocks: dict[str, KVCacheBlocks] = {}
         num_scheduled_tokens: dict[str, int] = {}
         token_budget = self.max_num_scheduled_tokens
+        if self._pause_state == PauseState.PAUSED_ALL:
+            # Do not schedule any requests when paused.
+            token_budget = 0
+
         # Encoder-related.
         scheduled_encoder_inputs: dict[str, list[int]] = {}
         encoder_compute_budget = self.max_num_encoder_input_tokens
@@ -82,6 +88,8 @@ class BalanceScheduler(Scheduler):
         # For logging.
         scheduled_timestamp = time.monotonic()
 
+        self.kv_cache_manager.new_step_starts()
+
         # First, schedule the RUNNING requests.
         req_index = 0
         while req_index < len(self.running) and token_budget > 0:
@@ -132,6 +140,9 @@ class BalanceScheduler(Scheduler):
                 shift_computed_tokens=1 if self.use_eagle else 0,
             )
 
+            if self.need_mamba_block_aligned_split:
+                num_new_tokens = self._mamba_block_aligned_split(request, num_new_tokens)
+
             if num_new_tokens == 0:
                 # The request cannot be scheduled because one of the following
                 # reasons:
@@ -142,6 +153,8 @@ class BalanceScheduler(Scheduler):
                 #    its max_total_tokens or max_model_len.
                 # 2. The encoder budget is exhausted.
                 # 3. The encoder cache is exhausted.
+                # 4. Insufficient budget for a block-aligned chunk in hybrid
+                #    models with mamba cache mode "align".
                 # NOTE(woosuk): Here, by doing `continue` instead of `break`,
                 # we do not strictly follow the FCFS scheduling policy and
                 # allow the lower-priority requests to be scheduled.
@@ -170,12 +183,12 @@ class BalanceScheduler(Scheduler):
                         )
                         self.running.remove(preempted_req)
                         if preempted_req in scheduled_running_reqs:
+                            preempted_req_id = preempted_req.request_id
                             scheduled_running_reqs.remove(preempted_req)
-                            token_budget += num_scheduled_tokens[preempted_req.request_id]
-                            req_to_new_blocks.pop(preempted_req.request_id)
-                            num_scheduled_tokens.pop(preempted_req.request_id)
-                            scheduled_spec_decode_tokens.pop(preempted_req.request_id, None)
-                            preempted_encoder_inputs = scheduled_encoder_inputs.pop(preempted_req.request_id, None)
+                            token_budget += num_scheduled_tokens.pop(preempted_req_id)
+                            req_to_new_blocks.pop(preempted_req_id)
+                            scheduled_spec_decode_tokens.pop(preempted_req_id, None)
+                            preempted_encoder_inputs = scheduled_encoder_inputs.pop(preempted_req_id, None)
                             if preempted_encoder_inputs:
                                 # Restore encoder compute budget if the preempted
                                 # request had encoder inputs scheduled in this step.
@@ -199,8 +212,9 @@ class BalanceScheduler(Scheduler):
 
             # Schedule the request.
             scheduled_running_reqs.append(request)
-            req_to_new_blocks[request.request_id] = new_blocks
-            num_scheduled_tokens[request.request_id] = num_new_tokens
+            request_id = request.request_id
+            req_to_new_blocks[request_id] = new_blocks
+            num_scheduled_tokens[request_id] = num_new_tokens
             token_budget -= num_new_tokens
             req_index += 1
 
@@ -210,16 +224,18 @@ class BalanceScheduler(Scheduler):
                     num_new_tokens + request.num_computed_tokens - request.num_tokens - request.num_output_placeholders
                 )
                 if num_scheduled_spec_tokens > 0:
-                    # Trim spec_token_ids list to num_scheduled_spec_tokens.
-                    del request.spec_token_ids[num_scheduled_spec_tokens:]
-                    scheduled_spec_decode_tokens[request.request_id] = request.spec_token_ids
+                    spec_token_ids = request.spec_token_ids
+                    if len(spec_token_ids) > num_scheduled_spec_tokens:
+                        spec_token_ids = spec_token_ids[:num_scheduled_spec_tokens]
+                    scheduled_spec_decode_tokens[request.request_id] = spec_token_ids
+
                 # New spec tokens will be set in `update_draft_token_ids` before the
                 # next step when applicable.
                 request.spec_token_ids = []
 
             # Encoder-related.
             if encoder_inputs_to_schedule:
-                scheduled_encoder_inputs[request.request_id] = encoder_inputs_to_schedule
+                scheduled_encoder_inputs[request_id] = encoder_inputs_to_schedule
                 # Allocate the encoder cache.
                 for i in encoder_inputs_to_schedule:
                     self.encoder_cache_manager.allocate(request, i)
@@ -240,31 +256,37 @@ class BalanceScheduler(Scheduler):
             )
             assert len(scheduled_loras) <= self.lora_config.max_loras
 
-        # Use a temporary RequestQueue to collect requests that need to be
-        # skipped and put back at the head of the waiting queue later
-        skipped_waiting_requests = create_request_queue(self.policy)
-
         # Next, schedule the WAITING requests.
-        if not preempted_reqs:
+        if not preempted_reqs and self._pause_state == PauseState.UNPAUSED:
+            # Use a temporary RequestQueue to collect requests that need to be
+            # skipped and put back at the head of the waiting queue later
+            skipped_waiting_requests = create_request_queue(self.policy)
+
             while self.waiting and token_budget > 0:
                 if len(self.running) == self.max_num_running_reqs:
                     break
 
-                balance_flag = max(t.item() for t in self.balance_queue) == self.max_num_running_reqs
+                balance_flag = max(t.item() for t in self.balance_queue) >= self.max_num_running_reqs - 1
                 if balance_flag:
                     break
 
                 request = self.waiting.peek_request()
+                request_id = request.request_id
 
                 # KVTransfer: skip request if still waiting for remote kvs.
                 if request.status == RequestStatus.WAITING_FOR_REMOTE_KVS:
                     is_ready = self._update_waiting_for_remote_kv(request)
                     if is_ready:
-                        request.status = RequestStatus.WAITING
+                        if request.num_preemptions:
+                            # We must be loading for a resumed preemption
+                            # rather than a new request.
+                            request.status = RequestStatus.PREEMPTED
+                        else:
+                            request.status = RequestStatus.WAITING
                     else:
                         logger.debug(
                             "%s is still in WAITING_FOR_REMOTE_KVS state.",
-                            request.request_id,
+                            request_id,
                         )
                         self.waiting.pop_request()
                         skipped_waiting_requests.prepend_request(request)
@@ -281,6 +303,13 @@ class BalanceScheduler(Scheduler):
                         skipped_waiting_requests.prepend_request(request)
                         continue
 
+                # Streaming: skip request if still waiting for next streaming req.
+                if request.status == RequestStatus.WAITING_FOR_STREAMING_REQ:
+                    assert not request.streaming_queue
+                    self.waiting.pop_request()
+                    skipped_waiting_requests.prepend_request(request)
+                    continue
+
                 # Check that adding the request still respects the max_loras
                 # constraint.
                 if (
@@ -298,6 +327,7 @@ class BalanceScheduler(Scheduler):
 
                 num_external_computed_tokens = 0
                 load_kv_async = False
+                connector_prefix_cache_queries, connector_prefix_cache_hits = 0, 0
 
                 # Get already-cached tokens.
                 if request.num_computed_tokens == 0:
@@ -323,6 +353,9 @@ class BalanceScheduler(Scheduler):
 
                         request.num_external_computed_tokens = ext_tokens
                         num_external_computed_tokens = ext_tokens
+                        connector_prefix_cache_queries = request.num_tokens - num_new_local_computed_tokens
+                        connector_prefix_cache_hits = num_external_computed_tokens
+
                     # Total computed tokens (local + external).
                     num_computed_tokens = num_new_local_computed_tokens + num_external_computed_tokens
                 else:
@@ -378,6 +411,16 @@ class BalanceScheduler(Scheduler):
                     # The request cannot be scheduled.
                     break
 
+                if self.need_mamba_block_aligned_split:
+                    num_new_tokens = self._mamba_block_aligned_split(
+                        request,
+                        num_new_tokens,
+                        num_new_local_computed_tokens,
+                        num_external_computed_tokens,
+                    )
+                    if num_new_tokens == 0:
+                        break
+
                 # Handles an edge case when P/D Disaggregation
                 # is used with Spec Decoding where an
                 # extra block gets allocated which
@@ -386,27 +429,28 @@ class BalanceScheduler(Scheduler):
                 effective_lookahead_tokens = 0 if request.num_computed_tokens == 0 else self.num_lookahead_tokens
 
                 # Determine if we need to allocate cross-attention blocks.
-                if self.is_encoder_decoder and request.has_encoder_inputs:
-                    # TODO(russellb): For Whisper, we know that the input is
-                    # always padded to the maximum length. If we support other
-                    # encoder-decoder models, this will need to be updated if we
-                    # want to only allocate what is needed.
-                    num_encoder_tokens = self.scheduler_config.max_num_encoder_input_tokens
-                else:
-                    num_encoder_tokens = 0
+                num_encoder_tokens = 0
+                if self.is_encoder_decoder and request.has_encoder_inputs and encoder_inputs_to_schedule:
+                    num_encoder_tokens = sum(request.get_num_encoder_embeds(i) for i in encoder_inputs_to_schedule)
 
                 new_blocks = self.kv_cache_manager.allocate_slots(
                     request,
-                    num_new_tokens + num_external_computed_tokens,
-                    num_new_local_computed_tokens,
-                    new_computed_blocks,
+                    num_new_tokens,
+                    num_new_computed_tokens=num_new_local_computed_tokens,
+                    new_computed_blocks=new_computed_blocks,
                     num_lookahead_tokens=effective_lookahead_tokens,
+                    num_external_computed_tokens=num_external_computed_tokens,
                     delay_cache_blocks=load_kv_async,
                     num_encoder_tokens=num_encoder_tokens,
                 )
                 if new_blocks is None:
                     # The request cannot be scheduled.
+
+                    # NOTE: we need to release the request from the
+                    # encoder cache manager.
+                    if request.has_encoder_inputs:
+                        self.encoder_cache_manager.free(request)
                     break
 
                 # KVTransfer: the connector uses this info to determine
@@ -416,9 +460,15 @@ class BalanceScheduler(Scheduler):
                 if self.connector is not None:
                     self.connector.update_state_after_alloc(
                         request,
-                        new_computed_blocks + new_blocks,
+                        self.kv_cache_manager.get_blocks(request_id),
                         num_external_computed_tokens,
                     )
+                    if self.connector_prefix_cache_stats is not None and connector_prefix_cache_queries != 0:
+                        self.connector_prefix_cache_stats.record(
+                            num_tokens=connector_prefix_cache_queries,
+                            num_hits=connector_prefix_cache_hits,
+                            preempted=request.num_preemptions > 0,
+                        )
 
                 # Request was already popped from self.waiting
                 # unless it was re-added above due to new_blocks being None.
@@ -430,8 +480,6 @@ class BalanceScheduler(Scheduler):
                     request.status = RequestStatus.WAITING_FOR_REMOTE_KVS
                     continue
 
-                self._update_connector_prefix_cache_stats(request)
-
                 self.running.append(request)
                 if self.log_stats:
                     request.record_event(EngineCoreEventType.SCHEDULED, scheduled_timestamp)
@@ -444,8 +492,8 @@ class BalanceScheduler(Scheduler):
                 if self.lora_config and request.lora_request:
                     scheduled_loras.add(request.lora_request.lora_int_id)
 
-                req_to_new_blocks[request.request_id] = self.kv_cache_manager.get_blocks(request.request_id)
-                num_scheduled_tokens[request.request_id] = num_new_tokens
+                req_to_new_blocks[request_id] = self.kv_cache_manager.get_blocks(request_id)
+                num_scheduled_tokens[request_id] = num_new_tokens
                 token_budget -= num_new_tokens
                 request.status = RequestStatus.RUNNING
                 request.num_computed_tokens = num_computed_tokens
@@ -454,7 +502,7 @@ class BalanceScheduler(Scheduler):
                     request.num_cached_tokens = num_computed_tokens
                 # Encoder-related.
                 if encoder_inputs_to_schedule:
-                    scheduled_encoder_inputs[request.request_id] = encoder_inputs_to_schedule
+                    scheduled_encoder_inputs[request_id] = encoder_inputs_to_schedule
                     # Allocate the encoder cache.
                     for i in encoder_inputs_to_schedule:
                         self.encoder_cache_manager.allocate(request, i)
@@ -465,9 +513,10 @@ class BalanceScheduler(Scheduler):
                         self.encoder_cache_manager.allocate(request, i)
                         if self.ec_connector is not None:
                             self.ec_connector.update_state_after_alloc(request, i)
-        # Put back any skipped requests at the head of the waiting queue
-        if skipped_waiting_requests:
-            self.waiting.prepend_requests(skipped_waiting_requests)
+
+            # Put back any skipped requests at the head of the waiting queue
+            if skipped_waiting_requests:
+                self.waiting.prepend_requests(skipped_waiting_requests)
 
         # Check if the scheduling constraints are satisfied.
         total_num_scheduled_tokens = sum(num_scheduled_tokens.values())
@@ -485,8 +534,8 @@ class BalanceScheduler(Scheduler):
         num_common_prefix_blocks = [0] * len(self.kv_cache_config.kv_cache_groups)
         with record_function_or_nullcontext("schedule: get_num_common_prefix_blocks"):
             if self.running:
-                any_request = self.running[0]
-                num_common_prefix_blocks = self.kv_cache_manager.get_num_common_prefix_blocks(any_request.request_id)
+                any_request_id = self.running[0].request_id
+                num_common_prefix_blocks = self.kv_cache_manager.get_num_common_prefix_blocks(any_request_id)
 
         # Construct the scheduler output.
         if self.use_v2_model_runner: