[Patch] Fix balance scheduling (#7611)
### What this PR does / why we need it?

This PR introduces a "balance scheduling" feature, enabled by the `VLLM_ASCEND_BALANCE_SCHEDULING` environment variable. This feature adjusts the scheduling logic to better balance the load across data-parallel workers, preventing a single worker from blocking scheduling for others. This can improve overall throughput.

Additionally, this PR includes a number of other updates and fixes to the scheduler, syncing it with a more recent version of the upstream vLLM scheduler. These changes include:

- Handling for the paused scheduler state.
- Support for Mamba block-aligned splits.
- Handling for streaming requests.
- Refinements in preemption logic and resource management (KV cache, encoder cache).
- General code refactoring for clarity and correctness.

Fixes #

### Does this PR introduce _any_ user-facing change?

Yes, this PR introduces a new feature controlled by the `VLLM_ASCEND_BALANCE_SCHEDULING` environment variable. When enabled, the scheduling behavior changes, which could affect performance and request throughput.

### How was this patch tested?

CI passed. Further testing should be done to validate the performance and correctness of the new scheduling logic under various workloads, with and without the feature flag enabled.

Signed-off-by: GDzhu01 <809721801@qq.com>
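For reference, enabling the feature might look like the following minimal sketch. The gating on `envs.VLLM_ASCEND_BALANCE_SCHEDULING` at import time is taken from the diff below; the exact accepted string values are defined by `vllm_ascend.envs`, so the `"1"` here is an assumption.

```python
import os

# Hypothetical usage: set the flag before vllm_ascend is imported, since the
# patch module is conditionally imported when vllm_ascend.patch.platform loads.
os.environ["VLLM_ASCEND_BALANCE_SCHEDULING"] = "1"  # assumed truthy value

# Importing the platform patch package applies patch_balance_schedule
# when the flag evaluates to true (see the __init__.py hunk below).
import vllm_ascend.patch.platform  # noqa
```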
@@ -19,6 +19,7 @@ import os
 import vllm_ascend.patch.platform.patch_distributed  # noqa
 import vllm_ascend.patch.platform.patch_fusion_matcher_compat_ops  # noqa
 import vllm_ascend.patch.platform.patch_kv_cache_interface  # noqa
+from vllm_ascend import envs
 from vllm_ascend.utils import is_310p

 if not is_310p():
@@ -31,3 +32,6 @@ import vllm_ascend.patch.platform.patch_torch_accelerator # noqa

 if os.getenv("DYNAMIC_EPLB", "false").lower() in ("true", "1") or os.getenv("EXPERT_MAP_RECORD", "false") == "true":
     import vllm_ascend.patch.platform.patch_multiproc_executor  # noqa
+
+if envs.VLLM_ASCEND_BALANCE_SCHEDULING:
+    import vllm_ascend.patch.platform.patch_balance_schedule  # noqa
@@ -13,6 +13,7 @@ from vllm.multimodal import MULTIMODAL_REGISTRY, MultiModalRegistry
 from vllm.transformers_utils.config import maybe_register_config_serialize_by_value
 from vllm.utils.system_utils import decorate_logs, set_process_title
 from vllm.v1.core.kv_cache_manager import KVCacheBlocks
+from vllm.v1.core.sched.interface import PauseState
 from vllm.v1.core.sched.output import NewRequestData, SchedulerOutput
 from vllm.v1.core.sched.request_queue import SchedulingPolicy, create_request_queue
 from vllm.v1.core.sched.scheduler import Scheduler
@@ -65,6 +66,7 @@ class BalanceScheduler(Scheduler):
         # num_tokens_with_spec. This is general enough to cover
         # chunked prefills, prefix caching, speculative decoding,
         # and the "jump decoding" optimization in the future.

         scheduled_new_reqs: list[Request] = []
         scheduled_resumed_reqs: list[Request] = []
         scheduled_running_reqs: list[Request] = []
@@ -73,6 +75,10 @@ class BalanceScheduler(Scheduler):
         req_to_new_blocks: dict[str, KVCacheBlocks] = {}
         num_scheduled_tokens: dict[str, int] = {}
         token_budget = self.max_num_scheduled_tokens
+        if self._pause_state == PauseState.PAUSED_ALL:
+            # Do not schedule any requests when paused.
+            token_budget = 0

         # Encoder-related.
         scheduled_encoder_inputs: dict[str, list[int]] = {}
         encoder_compute_budget = self.max_num_encoder_input_tokens
@@ -82,6 +88,8 @@ class BalanceScheduler(Scheduler):
         # For logging.
         scheduled_timestamp = time.monotonic()

+        self.kv_cache_manager.new_step_starts()
+
         # First, schedule the RUNNING requests.
         req_index = 0
         while req_index < len(self.running) and token_budget > 0:
@@ -132,6 +140,9 @@ class BalanceScheduler(Scheduler):
                 shift_computed_tokens=1 if self.use_eagle else 0,
             )

+            if self.need_mamba_block_aligned_split:
+                num_new_tokens = self._mamba_block_aligned_split(request, num_new_tokens)
+
             if num_new_tokens == 0:
                 # The request cannot be scheduled because one of the following
                 # reasons:
@@ -142,6 +153,8 @@ class BalanceScheduler(Scheduler):
                 #    its max_total_tokens or max_model_len.
                 # 2. The encoder budget is exhausted.
                 # 3. The encoder cache is exhausted.
+                # 4. Insufficient budget for a block-aligned chunk in hybrid
+                #    models with mamba cache mode "align".
                 # NOTE(woosuk): Here, by doing `continue` instead of `break`,
                 # we do not strictly follow the FCFS scheduling policy and
                 # allow the lower-priority requests to be scheduled.
@@ -170,12 +183,12 @@ class BalanceScheduler(Scheduler):
                     )
                     self.running.remove(preempted_req)
                     if preempted_req in scheduled_running_reqs:
+                        preempted_req_id = preempted_req.request_id
                         scheduled_running_reqs.remove(preempted_req)
-                        token_budget += num_scheduled_tokens[preempted_req.request_id]
-                        req_to_new_blocks.pop(preempted_req.request_id)
-                        num_scheduled_tokens.pop(preempted_req.request_id)
-                        scheduled_spec_decode_tokens.pop(preempted_req.request_id, None)
-                        preempted_encoder_inputs = scheduled_encoder_inputs.pop(preempted_req.request_id, None)
+                        token_budget += num_scheduled_tokens.pop(preempted_req_id)
+                        req_to_new_blocks.pop(preempted_req_id)
+                        scheduled_spec_decode_tokens.pop(preempted_req_id, None)
+                        preempted_encoder_inputs = scheduled_encoder_inputs.pop(preempted_req_id, None)
                         if preempted_encoder_inputs:
                             # Restore encoder compute budget if the preempted
                             # request had encoder inputs scheduled in this step.
@@ -199,8 +212,9 @@ class BalanceScheduler(Scheduler):

             # Schedule the request.
             scheduled_running_reqs.append(request)
-            req_to_new_blocks[request.request_id] = new_blocks
-            num_scheduled_tokens[request.request_id] = num_new_tokens
+            request_id = request.request_id
+            req_to_new_blocks[request_id] = new_blocks
+            num_scheduled_tokens[request_id] = num_new_tokens
             token_budget -= num_new_tokens
             req_index += 1

@@ -210,16 +224,18 @@ class BalanceScheduler(Scheduler):
                 num_new_tokens + request.num_computed_tokens - request.num_tokens - request.num_output_placeholders
             )
             if num_scheduled_spec_tokens > 0:
-                # Trim spec_token_ids list to num_scheduled_spec_tokens.
-                del request.spec_token_ids[num_scheduled_spec_tokens:]
-                scheduled_spec_decode_tokens[request.request_id] = request.spec_token_ids
+                spec_token_ids = request.spec_token_ids
+                if len(spec_token_ids) > num_scheduled_spec_tokens:
+                    spec_token_ids = spec_token_ids[:num_scheduled_spec_tokens]
+                scheduled_spec_decode_tokens[request.request_id] = spec_token_ids
+
+                # New spec tokens will be set in `update_draft_token_ids` before the
+                # next step when applicable.
+                request.spec_token_ids = []

             # Encoder-related.
             if encoder_inputs_to_schedule:
-                scheduled_encoder_inputs[request.request_id] = encoder_inputs_to_schedule
+                scheduled_encoder_inputs[request_id] = encoder_inputs_to_schedule
                 # Allocate the encoder cache.
                 for i in encoder_inputs_to_schedule:
                     self.encoder_cache_manager.allocate(request, i)
@@ -240,31 +256,37 @@ class BalanceScheduler(Scheduler):
                 )
                 assert len(scheduled_loras) <= self.lora_config.max_loras

-        # Use a temporary RequestQueue to collect requests that need to be
-        # skipped and put back at the head of the waiting queue later
-        skipped_waiting_requests = create_request_queue(self.policy)
-
         # Next, schedule the WAITING requests.
-        if not preempted_reqs:
+        if not preempted_reqs and self._pause_state == PauseState.UNPAUSED:
+            # Use a temporary RequestQueue to collect requests that need to be
+            # skipped and put back at the head of the waiting queue later
+            skipped_waiting_requests = create_request_queue(self.policy)
+
             while self.waiting and token_budget > 0:
                 if len(self.running) == self.max_num_running_reqs:
                     break

-                balance_flag = max(t.item() for t in self.balance_queue) == self.max_num_running_reqs
+                balance_flag = max(t.item() for t in self.balance_queue) >= self.max_num_running_reqs - 1
                 if balance_flag:
                     break

                 request = self.waiting.peek_request()
+                request_id = request.request_id

                 # KVTransfer: skip request if still waiting for remote kvs.
                 if request.status == RequestStatus.WAITING_FOR_REMOTE_KVS:
                     is_ready = self._update_waiting_for_remote_kv(request)
                     if is_ready:
-                        request.status = RequestStatus.WAITING
+                        if request.num_preemptions:
+                            # We must be loading for a resumed preemption
+                            # rather than a new request.
+                            request.status = RequestStatus.PREEMPTED
+                        else:
+                            request.status = RequestStatus.WAITING
                     else:
                         logger.debug(
                             "%s is still in WAITING_FOR_REMOTE_KVS state.",
-                            request.request_id,
+                            request_id,
                         )
                         self.waiting.pop_request()
                         skipped_waiting_requests.prepend_request(request)
@@ -281,6 +303,13 @@ class BalanceScheduler(Scheduler):
                     skipped_waiting_requests.prepend_request(request)
                     continue

+                # Streaming: skip request if still waiting for next streaming req.
+                if request.status == RequestStatus.WAITING_FOR_STREAMING_REQ:
+                    assert not request.streaming_queue
+                    self.waiting.pop_request()
+                    skipped_waiting_requests.prepend_request(request)
+                    continue
+
                 # Check that adding the request still respects the max_loras
                 # constraint.
                 if (
@@ -298,6 +327,7 @@ class BalanceScheduler(Scheduler):

                 num_external_computed_tokens = 0
                 load_kv_async = False
+                connector_prefix_cache_queries, connector_prefix_cache_hits = 0, 0

                 # Get already-cached tokens.
                 if request.num_computed_tokens == 0:
@@ -323,6 +353,9 @@ class BalanceScheduler(Scheduler):
                         request.num_external_computed_tokens = ext_tokens
                         num_external_computed_tokens = ext_tokens

+                        connector_prefix_cache_queries = request.num_tokens - num_new_local_computed_tokens
+                        connector_prefix_cache_hits = num_external_computed_tokens
+
                     # Total computed tokens (local + external).
                     num_computed_tokens = num_new_local_computed_tokens + num_external_computed_tokens
                 else:
@@ -378,6 +411,16 @@ class BalanceScheduler(Scheduler):
                         # The request cannot be scheduled.
                         break

+                if self.need_mamba_block_aligned_split:
+                    num_new_tokens = self._mamba_block_aligned_split(
+                        request,
+                        num_new_tokens,
+                        num_new_local_computed_tokens,
+                        num_external_computed_tokens,
+                    )
+                    if num_new_tokens == 0:
+                        break
+
                 # Handles an edge case when P/D Disaggregation
                 # is used with Spec Decoding where an
                 # extra block gets allocated which
@@ -386,27 +429,28 @@ class BalanceScheduler(Scheduler):
                 effective_lookahead_tokens = 0 if request.num_computed_tokens == 0 else self.num_lookahead_tokens

-                # Determine if we need to allocate cross-attention blocks.
-                if self.is_encoder_decoder and request.has_encoder_inputs:
-                    # TODO(russellb): For Whisper, we know that the input is
-                    # always padded to the maximum length. If we support other
-                    # encoder-decoder models, this will need to be updated if we
-                    # want to only allocate what is needed.
-                    num_encoder_tokens = self.scheduler_config.max_num_encoder_input_tokens
-                else:
-                    num_encoder_tokens = 0
+                num_encoder_tokens = 0
+                if self.is_encoder_decoder and request.has_encoder_inputs and encoder_inputs_to_schedule:
+                    num_encoder_tokens = sum(request.get_num_encoder_embeds(i) for i in encoder_inputs_to_schedule)

                 new_blocks = self.kv_cache_manager.allocate_slots(
                     request,
-                    num_new_tokens + num_external_computed_tokens,
-                    num_new_local_computed_tokens,
-                    new_computed_blocks,
+                    num_new_tokens,
+                    num_new_computed_tokens=num_new_local_computed_tokens,
+                    new_computed_blocks=new_computed_blocks,
                     num_lookahead_tokens=effective_lookahead_tokens,
+                    num_external_computed_tokens=num_external_computed_tokens,
                     delay_cache_blocks=load_kv_async,
                     num_encoder_tokens=num_encoder_tokens,
                 )

                 if new_blocks is None:
                     # The request cannot be scheduled.
+
+                    # NOTE: we need to untouch the request from the encode cache
+                    # manager
+                    if request.has_encoder_inputs:
+                        self.encoder_cache_manager.free(request)
                     break

                 # KVTransfer: the connector uses this info to determine
@@ -416,9 +460,15 @@ class BalanceScheduler(Scheduler):
                 if self.connector is not None:
                     self.connector.update_state_after_alloc(
                         request,
-                        new_computed_blocks + new_blocks,
+                        self.kv_cache_manager.get_blocks(request_id),
                         num_external_computed_tokens,
                     )
+                if self.connector_prefix_cache_stats is not None and connector_prefix_cache_queries != 0:
+                    self.connector_prefix_cache_stats.record(
+                        num_tokens=connector_prefix_cache_queries,
+                        num_hits=connector_prefix_cache_hits,
+                        preempted=request.num_preemptions > 0,
+                    )

                 # Request was already popped from self.waiting
                 # unless it was re-added above due to new_blocks being None.
@@ -430,8 +480,6 @@ class BalanceScheduler(Scheduler):
                     request.status = RequestStatus.WAITING_FOR_REMOTE_KVS
                     continue

-                self._update_connector_prefix_cache_stats(request)
-
                 self.running.append(request)
                 if self.log_stats:
                     request.record_event(EngineCoreEventType.SCHEDULED, scheduled_timestamp)
@@ -444,8 +492,8 @@ class BalanceScheduler(Scheduler):

                 if self.lora_config and request.lora_request:
                     scheduled_loras.add(request.lora_request.lora_int_id)
-                req_to_new_blocks[request.request_id] = self.kv_cache_manager.get_blocks(request.request_id)
-                num_scheduled_tokens[request.request_id] = num_new_tokens
+                req_to_new_blocks[request_id] = self.kv_cache_manager.get_blocks(request_id)
+                num_scheduled_tokens[request_id] = num_new_tokens
                 token_budget -= num_new_tokens
                 request.status = RequestStatus.RUNNING
                 request.num_computed_tokens = num_computed_tokens
@@ -454,7 +502,7 @@ class BalanceScheduler(Scheduler):
                 request.num_cached_tokens = num_computed_tokens
                 # Encoder-related.
                 if encoder_inputs_to_schedule:
-                    scheduled_encoder_inputs[request.request_id] = encoder_inputs_to_schedule
+                    scheduled_encoder_inputs[request_id] = encoder_inputs_to_schedule
                     # Allocate the encoder cache.
                     for i in encoder_inputs_to_schedule:
                         self.encoder_cache_manager.allocate(request, i)
@@ -465,9 +513,10 @@ class BalanceScheduler(Scheduler):
                         self.encoder_cache_manager.allocate(request, i)
                         if self.ec_connector is not None:
                             self.ec_connector.update_state_after_alloc(request, i)
+            # Put back any skipped requests at the head of the waiting queue
+            if skipped_waiting_requests:
+                self.waiting.prepend_requests(skipped_waiting_requests)

-        # Put back any skipped requests at the head of the waiting queue
-        if skipped_waiting_requests:
-            self.waiting.prepend_requests(skipped_waiting_requests)
-
         # Check if the scheduling constraints are satisfied.
         total_num_scheduled_tokens = sum(num_scheduled_tokens.values())
@@ -485,8 +534,8 @@ class BalanceScheduler(Scheduler):
         num_common_prefix_blocks = [0] * len(self.kv_cache_config.kv_cache_groups)
         with record_function_or_nullcontext("schedule: get_num_common_prefix_blocks"):
             if self.running:
-                any_request = self.running[0]
-                num_common_prefix_blocks = self.kv_cache_manager.get_num_common_prefix_blocks(any_request.request_id)
+                any_request_id = self.running[0].request_id
+                num_common_prefix_blocks = self.kv_cache_manager.get_num_common_prefix_blocks(any_request_id)

         # Construct the scheduler output.
         if self.use_v2_model_runner: