[Patch] Fix balance scheduling (#7611)

### What this PR does / why we need it?
This PR introduces a "balance scheduling" feature, enabled by the
`VLLM_ASCEND_BALANCE_SCHEDULING` environment variable. This feature
adjusts the scheduling logic to better balance the load across
data-parallel workers, preventing a single worker from blocking
scheduling for others. This can improve overall throughput.
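
As a rough sketch of the gating idea (the helper name here is hypothetical; the real check lives in the patched `BalanceScheduler` shown in the diff below), waiting requests are only admitted while every data-parallel rank still has headroom:

```python
# Minimal sketch, assuming `balance_queue` holds one 0-dim tensor per
# data-parallel rank with that rank's current running-request count
# (mirroring the patched scheduler below). `can_admit_waiting` is a
# hypothetical helper, not part of this PR.
def can_admit_waiting(balance_queue, max_num_running_reqs: int) -> bool:
    # Stop pulling from the waiting queue once any rank is nearly full,
    # so one saturated rank cannot block scheduling for the others.
    return max(t.item() for t in balance_queue) < max_num_running_reqs - 1
```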

Additionally, this PR includes a number of other updates and fixes to
the scheduler, syncing it with a more recent version of the upstream
vLLM scheduler. These changes include:
- Handling for the paused scheduler state (see the sketch after this list).
- Support for Mamba block-aligned splits.
- Handling for streaming requests.
- Refinements in preemption logic and resource management (KV cache,
encoder cache).
- General code refactoring for clarity and correctness.
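
For example, the paused-state handling works by zeroing the token budget, so the scheduling loops fall through without admitting any work. A minimal sketch of that pattern, with a plain-Python stand-in for vLLM's `PauseState`:

```python
from enum import Enum, auto

class PauseState(Enum):  # stand-in for vllm.v1.core.sched.interface.PauseState
    UNPAUSED = auto()
    PAUSED_ALL = auto()

def effective_token_budget(pause_state: PauseState, max_num_scheduled_tokens: int) -> int:
    # A zero budget makes the `while ... and token_budget > 0` scheduling
    # loops exit immediately, so a fully paused scheduler admits nothing.
    return 0 if pause_state == PauseState.PAUSED_ALL else max_num_scheduled_tokens
```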

Fixes #

### Does this PR introduce _any_ user-facing change?
Yes, this PR introduces a new feature controlled by the
`VLLM_ASCEND_BALANCE_SCHEDULING` environment variable. When enabled, the
scheduling behavior changes, which could affect performance and request
throughput.
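
For example, to opt in (assuming the flag follows the usual boolean-environment-variable convention; the exact accepted values are defined in `vllm_ascend.envs`, which is not part of this diff):

```python
import os

# Set the flag before vllm-ascend's platform patches are imported: the
# patch module is imported (or skipped) at import time, as the first
# file in this diff shows.
os.environ["VLLM_ASCEND_BALANCE_SCHEDULING"] = "1"
```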

### How was this patch tested?
CI passed. Further testing should be done to validate the performance
and correctness of the new scheduling logic under various workloads,
with and without the feature flag enabled.

Signed-off-by: GDzhu01 <809721801@qq.com>
Author: Zhu Yi Lin
Date: 2026-03-25 08:57:06 +08:00
Committed by: GitHub
Parent: 3f4087a8f0
Commit: fc3ec100bc
2 changed files with 94 additions and 41 deletions


@@ -19,6 +19,7 @@ import os
import vllm_ascend.patch.platform.patch_distributed # noqa
import vllm_ascend.patch.platform.patch_fusion_matcher_compat_ops # noqa
import vllm_ascend.patch.platform.patch_kv_cache_interface # noqa
+ from vllm_ascend import envs
from vllm_ascend.utils import is_310p
if not is_310p():
@@ -31,3 +32,6 @@ import vllm_ascend.patch.platform.patch_torch_accelerator # noqa
if os.getenv("DYNAMIC_EPLB", "false").lower() in ("true", "1") or os.getenv("EXPERT_MAP_RECORD", "false") == "true":
import vllm_ascend.patch.platform.patch_multiproc_executor # noqa
+ if envs.VLLM_ASCEND_BALANCE_SCHEDULING:
+     import vllm_ascend.patch.platform.patch_balance_schedule # noqa
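
The hunk above only imports the patch module when the flag is set; the import itself is what activates the patch. A minimal sketch of this common import-time patching pattern (whether `patch_balance_schedule` is implemented exactly this way is not shown in this commit):

```python
# Hypothetical sketch of what a patch module's body can look like:
# importing such a module rebinds the upstream Scheduler name, so code
# that later looks up scheduler_module.Scheduler gets the subclass.
import vllm.v1.core.sched.scheduler as scheduler_module

class BalanceScheduler(scheduler_module.Scheduler):
    """Overrides scheduling with the balance-aware logic (elided here)."""

scheduler_module.Scheduler = BalanceScheduler
```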


@@ -13,6 +13,7 @@ from vllm.multimodal import MULTIMODAL_REGISTRY, MultiModalRegistry
from vllm.transformers_utils.config import maybe_register_config_serialize_by_value
from vllm.utils.system_utils import decorate_logs, set_process_title
from vllm.v1.core.kv_cache_manager import KVCacheBlocks
+ from vllm.v1.core.sched.interface import PauseState
from vllm.v1.core.sched.output import NewRequestData, SchedulerOutput
from vllm.v1.core.sched.request_queue import SchedulingPolicy, create_request_queue
from vllm.v1.core.sched.scheduler import Scheduler
@@ -65,6 +66,7 @@ class BalanceScheduler(Scheduler):
# num_tokens_with_spec. This is general enough to cover
# chunked prefills, prefix caching, speculative decoding,
# and the "jump decoding" optimization in the future.
scheduled_new_reqs: list[Request] = []
scheduled_resumed_reqs: list[Request] = []
scheduled_running_reqs: list[Request] = []
@@ -73,6 +75,10 @@ class BalanceScheduler(Scheduler):
req_to_new_blocks: dict[str, KVCacheBlocks] = {}
num_scheduled_tokens: dict[str, int] = {}
token_budget = self.max_num_scheduled_tokens
+ if self._pause_state == PauseState.PAUSED_ALL:
+     # Do not schedule any requests when paused.
+     token_budget = 0
# Encoder-related.
scheduled_encoder_inputs: dict[str, list[int]] = {}
encoder_compute_budget = self.max_num_encoder_input_tokens
@@ -82,6 +88,8 @@ class BalanceScheduler(Scheduler):
# For logging.
scheduled_timestamp = time.monotonic()
+ self.kv_cache_manager.new_step_starts()
# First, schedule the RUNNING requests.
req_index = 0
while req_index < len(self.running) and token_budget > 0:
@@ -132,6 +140,9 @@ class BalanceScheduler(Scheduler):
shift_computed_tokens=1 if self.use_eagle else 0,
)
+ if self.need_mamba_block_aligned_split:
+     num_new_tokens = self._mamba_block_aligned_split(request, num_new_tokens)
if num_new_tokens == 0:
# The request cannot be scheduled because one of the following
# reasons:
@@ -142,6 +153,8 @@ class BalanceScheduler(Scheduler):
# its max_total_tokens or max_model_len.
# 2. The encoder budget is exhausted.
# 3. The encoder cache is exhausted.
+ # 4. Insufficient budget for a block-aligned chunk in hybrid
+ #    models with mamba cache mode "align".
# NOTE(woosuk): Here, by doing `continue` instead of `break`,
# we do not strictly follow the FCFS scheduling policy and
# allow the lower-priority requests to be scheduled.
@@ -170,12 +183,12 @@ class BalanceScheduler(Scheduler):
)
self.running.remove(preempted_req)
if preempted_req in scheduled_running_reqs:
+ preempted_req_id = preempted_req.request_id
scheduled_running_reqs.remove(preempted_req)
- token_budget += num_scheduled_tokens[preempted_req.request_id]
- req_to_new_blocks.pop(preempted_req.request_id)
- num_scheduled_tokens.pop(preempted_req.request_id)
- scheduled_spec_decode_tokens.pop(preempted_req.request_id, None)
- preempted_encoder_inputs = scheduled_encoder_inputs.pop(preempted_req.request_id, None)
+ token_budget += num_scheduled_tokens.pop(preempted_req_id)
+ req_to_new_blocks.pop(preempted_req_id)
+ scheduled_spec_decode_tokens.pop(preempted_req_id, None)
+ preempted_encoder_inputs = scheduled_encoder_inputs.pop(preempted_req_id, None)
if preempted_encoder_inputs:
# Restore encoder compute budget if the preempted
# request had encoder inputs scheduled in this step.
@@ -199,8 +212,9 @@ class BalanceScheduler(Scheduler):
# Schedule the request.
scheduled_running_reqs.append(request)
- req_to_new_blocks[request.request_id] = new_blocks
- num_scheduled_tokens[request.request_id] = num_new_tokens
+ request_id = request.request_id
+ req_to_new_blocks[request_id] = new_blocks
+ num_scheduled_tokens[request_id] = num_new_tokens
token_budget -= num_new_tokens
req_index += 1
@@ -210,16 +224,18 @@ class BalanceScheduler(Scheduler):
num_new_tokens + request.num_computed_tokens - request.num_tokens - request.num_output_placeholders
)
if num_scheduled_spec_tokens > 0:
- # Trim spec_token_ids list to num_scheduled_spec_tokens.
- del request.spec_token_ids[num_scheduled_spec_tokens:]
- scheduled_spec_decode_tokens[request.request_id] = request.spec_token_ids
+ spec_token_ids = request.spec_token_ids
+ if len(spec_token_ids) > num_scheduled_spec_tokens:
+     spec_token_ids = spec_token_ids[:num_scheduled_spec_tokens]
+ scheduled_spec_decode_tokens[request.request_id] = spec_token_ids
+ # New spec tokens will be set in `update_draft_token_ids` before the
+ # next step when applicable.
+ request.spec_token_ids = []
# Encoder-related.
if encoder_inputs_to_schedule:
- scheduled_encoder_inputs[request.request_id] = encoder_inputs_to_schedule
+ scheduled_encoder_inputs[request_id] = encoder_inputs_to_schedule
# Allocate the encoder cache.
for i in encoder_inputs_to_schedule:
self.encoder_cache_manager.allocate(request, i)
@@ -240,31 +256,37 @@ class BalanceScheduler(Scheduler):
)
assert len(scheduled_loras) <= self.lora_config.max_loras
- # Use a temporary RequestQueue to collect requests that need to be
- # skipped and put back at the head of the waiting queue later
- skipped_waiting_requests = create_request_queue(self.policy)
# Next, schedule the WAITING requests.
- if not preempted_reqs:
+ if not preempted_reqs and self._pause_state == PauseState.UNPAUSED:
+     # Use a temporary RequestQueue to collect requests that need to be
+     # skipped and put back at the head of the waiting queue later
+     skipped_waiting_requests = create_request_queue(self.policy)
while self.waiting and token_budget > 0:
if len(self.running) == self.max_num_running_reqs:
break
- balance_flag = max(t.item() for t in self.balance_queue) == self.max_num_running_reqs
+ balance_flag = max(t.item() for t in self.balance_queue) >= self.max_num_running_reqs - 1
if balance_flag:
break
request = self.waiting.peek_request()
+ request_id = request.request_id
# KVTransfer: skip request if still waiting for remote kvs.
if request.status == RequestStatus.WAITING_FOR_REMOTE_KVS:
is_ready = self._update_waiting_for_remote_kv(request)
if is_ready:
- request.status = RequestStatus.WAITING
+ if request.num_preemptions:
+     # We must be loading for a resumed preemption
+     # rather than a new request.
+     request.status = RequestStatus.PREEMPTED
+ else:
+     request.status = RequestStatus.WAITING
else:
logger.debug(
"%s is still in WAITING_FOR_REMOTE_KVS state.",
- request.request_id,
+ request_id,
)
self.waiting.pop_request()
skipped_waiting_requests.prepend_request(request)
@@ -281,6 +303,13 @@ class BalanceScheduler(Scheduler):
skipped_waiting_requests.prepend_request(request)
continue
+ # Streaming: skip request if still waiting for next streaming req.
+ if request.status == RequestStatus.WAITING_FOR_STREAMING_REQ:
+     assert not request.streaming_queue
+     self.waiting.pop_request()
+     skipped_waiting_requests.prepend_request(request)
+     continue
# Check that adding the request still respects the max_loras
# constraint.
if (
@@ -298,6 +327,7 @@ class BalanceScheduler(Scheduler):
num_external_computed_tokens = 0
load_kv_async = False
+ connector_prefix_cache_queries, connector_prefix_cache_hits = 0, 0
# Get already-cached tokens.
if request.num_computed_tokens == 0:
@@ -323,6 +353,9 @@ class BalanceScheduler(Scheduler):
request.num_external_computed_tokens = ext_tokens
num_external_computed_tokens = ext_tokens
+ connector_prefix_cache_queries = request.num_tokens - num_new_local_computed_tokens
+ connector_prefix_cache_hits = num_external_computed_tokens
# Total computed tokens (local + external).
num_computed_tokens = num_new_local_computed_tokens + num_external_computed_tokens
else:
@@ -378,6 +411,16 @@ class BalanceScheduler(Scheduler):
# The request cannot be scheduled.
break
+ if self.need_mamba_block_aligned_split:
+     num_new_tokens = self._mamba_block_aligned_split(
+         request,
+         num_new_tokens,
+         num_new_local_computed_tokens,
+         num_external_computed_tokens,
+     )
+     if num_new_tokens == 0:
+         break
# Handles an edge case when P/D Disaggregation
# is used with Spec Decoding where an
# extra block gets allocated which
@@ -386,27 +429,28 @@ class BalanceScheduler(Scheduler):
effective_lookahead_tokens = 0 if request.num_computed_tokens == 0 else self.num_lookahead_tokens
# Determine if we need to allocate cross-attention blocks.
- if self.is_encoder_decoder and request.has_encoder_inputs:
-     # TODO(russellb): For Whisper, we know that the input is
-     # always padded to the maximum length. If we support other
-     # encoder-decoder models, this will need to be updated if we
-     # want to only allocate what is needed.
-     num_encoder_tokens = self.scheduler_config.max_num_encoder_input_tokens
- else:
-     num_encoder_tokens = 0
+ num_encoder_tokens = 0
+ if self.is_encoder_decoder and request.has_encoder_inputs and encoder_inputs_to_schedule:
+     num_encoder_tokens = sum(request.get_num_encoder_embeds(i) for i in encoder_inputs_to_schedule)
new_blocks = self.kv_cache_manager.allocate_slots(
request,
- num_new_tokens + num_external_computed_tokens,
- num_new_local_computed_tokens,
- new_computed_blocks,
+ num_new_tokens,
+ num_new_computed_tokens=num_new_local_computed_tokens,
+ new_computed_blocks=new_computed_blocks,
num_lookahead_tokens=effective_lookahead_tokens,
+ num_external_computed_tokens=num_external_computed_tokens,
delay_cache_blocks=load_kv_async,
num_encoder_tokens=num_encoder_tokens,
)
if new_blocks is None:
# The request cannot be scheduled.
+ # NOTE: we need to untouch the request from the encoder cache
+ # manager
+ if request.has_encoder_inputs:
+     self.encoder_cache_manager.free(request)
break
# KVTransfer: the connector uses this info to determine
@@ -416,9 +460,15 @@ class BalanceScheduler(Scheduler):
if self.connector is not None:
self.connector.update_state_after_alloc(
request,
- new_computed_blocks + new_blocks,
+ self.kv_cache_manager.get_blocks(request_id),
num_external_computed_tokens,
)
+ if self.connector_prefix_cache_stats is not None and connector_prefix_cache_queries != 0:
+     self.connector_prefix_cache_stats.record(
+         num_tokens=connector_prefix_cache_queries,
+         num_hits=connector_prefix_cache_hits,
+         preempted=request.num_preemptions > 0,
+     )
# Request was already popped from self.waiting
# unless it was re-added above due to new_blocks being None.
@@ -430,8 +480,6 @@ class BalanceScheduler(Scheduler):
request.status = RequestStatus.WAITING_FOR_REMOTE_KVS
continue
- self._update_connector_prefix_cache_stats(request)
self.running.append(request)
if self.log_stats:
request.record_event(EngineCoreEventType.SCHEDULED, scheduled_timestamp)
@@ -444,8 +492,8 @@ class BalanceScheduler(Scheduler):
if self.lora_config and request.lora_request:
scheduled_loras.add(request.lora_request.lora_int_id)
- req_to_new_blocks[request.request_id] = self.kv_cache_manager.get_blocks(request.request_id)
- num_scheduled_tokens[request.request_id] = num_new_tokens
+ req_to_new_blocks[request_id] = self.kv_cache_manager.get_blocks(request_id)
+ num_scheduled_tokens[request_id] = num_new_tokens
token_budget -= num_new_tokens
request.status = RequestStatus.RUNNING
request.num_computed_tokens = num_computed_tokens
@@ -454,7 +502,7 @@ class BalanceScheduler(Scheduler):
request.num_cached_tokens = num_computed_tokens
# Encoder-related.
if encoder_inputs_to_schedule:
- scheduled_encoder_inputs[request.request_id] = encoder_inputs_to_schedule
+ scheduled_encoder_inputs[request_id] = encoder_inputs_to_schedule
# Allocate the encoder cache.
for i in encoder_inputs_to_schedule:
self.encoder_cache_manager.allocate(request, i)
@@ -465,9 +513,10 @@ class BalanceScheduler(Scheduler):
self.encoder_cache_manager.allocate(request, i)
if self.ec_connector is not None:
self.ec_connector.update_state_after_alloc(request, i)
- # Put back any skipped requests at the head of the waiting queue
- if skipped_waiting_requests:
-     self.waiting.prepend_requests(skipped_waiting_requests)
+     # Put back any skipped requests at the head of the waiting queue
+     if skipped_waiting_requests:
+         self.waiting.prepend_requests(skipped_waiting_requests)
# Check if the scheduling constraints are satisfied.
total_num_scheduled_tokens = sum(num_scheduled_tokens.values())
@@ -485,8 +534,8 @@ class BalanceScheduler(Scheduler):
num_common_prefix_blocks = [0] * len(self.kv_cache_config.kv_cache_groups)
with record_function_or_nullcontext("schedule: get_num_common_prefix_blocks"):
if self.running:
- any_request = self.running[0]
- num_common_prefix_blocks = self.kv_cache_manager.get_num_common_prefix_blocks(any_request.request_id)
+ any_request_id = self.running[0].request_id
+ num_common_prefix_blocks = self.kv_cache_manager.get_num_common_prefix_blocks(any_request_id)
# Construct the scheduler output.
if self.use_v2_model_runner: