[Patch] Fix balance scheduling (#7611)

### What this PR does / why we need it?
This PR introduces a "balance scheduling" feature, enabled by the
`VLLM_ASCEND_BALANCE_SCHEDULING` environment variable. This feature
adjusts the scheduling logic to better balance the load across
data-parallel workers, preventing a single worker from blocking
scheduling for others. This can improve overall throughput.

Additionally, this PR includes a number of other updates and fixes to
the scheduler, syncing it with a more recent version of the upstream
vLLM scheduler. These changes include:
- Handling for paused scheduler state.
- Support for Mamba block-aligned splits.
- Handling for streaming requests.
- Refinements in preemption logic and resource management (KV cache,
encoder cache).
- General code refactoring for clarity and correctness.

Fixes #

### Does this PR introduce _any_ user-facing change?
Yes, this PR introduces a new feature controlled by the
`VLLM_ASCEND_BALANCE_SCHEDULING` environment variable. When enabled, the
scheduling behavior changes, which could affect performance and request
throughput.

### How was this patch tested?
CI passed. Further testing should be done to validate the performance
and correctness of the new scheduling logic under various workloads,
with and without the feature flag enabled.

Signed-off-by: GDzhu01 <809721801@qq.com>
This commit is contained in:
Zhu Yi Lin
2026-03-25 08:57:06 +08:00
committed by GitHub
parent 3f4087a8f0
commit fc3ec100bc
2 changed files with 94 additions and 41 deletions

View File

@@ -19,6 +19,7 @@ import os
import vllm_ascend.patch.platform.patch_distributed # noqa import vllm_ascend.patch.platform.patch_distributed # noqa
import vllm_ascend.patch.platform.patch_fusion_matcher_compat_ops # noqa import vllm_ascend.patch.platform.patch_fusion_matcher_compat_ops # noqa
import vllm_ascend.patch.platform.patch_kv_cache_interface # noqa import vllm_ascend.patch.platform.patch_kv_cache_interface # noqa
from vllm_ascend import envs
from vllm_ascend.utils import is_310p from vllm_ascend.utils import is_310p
if not is_310p(): if not is_310p():
@@ -31,3 +32,6 @@ import vllm_ascend.patch.platform.patch_torch_accelerator # noqa
if os.getenv("DYNAMIC_EPLB", "false").lower() in ("true", "1") or os.getenv("EXPERT_MAP_RECORD", "false") == "true": if os.getenv("DYNAMIC_EPLB", "false").lower() in ("true", "1") or os.getenv("EXPERT_MAP_RECORD", "false") == "true":
import vllm_ascend.patch.platform.patch_multiproc_executor # noqa import vllm_ascend.patch.platform.patch_multiproc_executor # noqa
if envs.VLLM_ASCEND_BALANCE_SCHEDULING:
import vllm_ascend.patch.platform.patch_balance_schedule # noqa

View File

@@ -13,6 +13,7 @@ from vllm.multimodal import MULTIMODAL_REGISTRY, MultiModalRegistry
from vllm.transformers_utils.config import maybe_register_config_serialize_by_value from vllm.transformers_utils.config import maybe_register_config_serialize_by_value
from vllm.utils.system_utils import decorate_logs, set_process_title from vllm.utils.system_utils import decorate_logs, set_process_title
from vllm.v1.core.kv_cache_manager import KVCacheBlocks from vllm.v1.core.kv_cache_manager import KVCacheBlocks
from vllm.v1.core.sched.interface import PauseState
from vllm.v1.core.sched.output import NewRequestData, SchedulerOutput from vllm.v1.core.sched.output import NewRequestData, SchedulerOutput
from vllm.v1.core.sched.request_queue import SchedulingPolicy, create_request_queue from vllm.v1.core.sched.request_queue import SchedulingPolicy, create_request_queue
from vllm.v1.core.sched.scheduler import Scheduler from vllm.v1.core.sched.scheduler import Scheduler
@@ -65,6 +66,7 @@ class BalanceScheduler(Scheduler):
# num_tokens_with_spec. This is general enough to cover # num_tokens_with_spec. This is general enough to cover
# chunked prefills, prefix caching, speculative decoding, # chunked prefills, prefix caching, speculative decoding,
# and the "jump decoding" optimization in the future. # and the "jump decoding" optimization in the future.
scheduled_new_reqs: list[Request] = [] scheduled_new_reqs: list[Request] = []
scheduled_resumed_reqs: list[Request] = [] scheduled_resumed_reqs: list[Request] = []
scheduled_running_reqs: list[Request] = [] scheduled_running_reqs: list[Request] = []
@@ -73,6 +75,10 @@ class BalanceScheduler(Scheduler):
req_to_new_blocks: dict[str, KVCacheBlocks] = {} req_to_new_blocks: dict[str, KVCacheBlocks] = {}
num_scheduled_tokens: dict[str, int] = {} num_scheduled_tokens: dict[str, int] = {}
token_budget = self.max_num_scheduled_tokens token_budget = self.max_num_scheduled_tokens
if self._pause_state == PauseState.PAUSED_ALL:
# Do not schedule any requests when paused.
token_budget = 0
# Encoder-related. # Encoder-related.
scheduled_encoder_inputs: dict[str, list[int]] = {} scheduled_encoder_inputs: dict[str, list[int]] = {}
encoder_compute_budget = self.max_num_encoder_input_tokens encoder_compute_budget = self.max_num_encoder_input_tokens
@@ -82,6 +88,8 @@ class BalanceScheduler(Scheduler):
# For logging. # For logging.
scheduled_timestamp = time.monotonic() scheduled_timestamp = time.monotonic()
self.kv_cache_manager.new_step_starts()
# First, schedule the RUNNING requests. # First, schedule the RUNNING requests.
req_index = 0 req_index = 0
while req_index < len(self.running) and token_budget > 0: while req_index < len(self.running) and token_budget > 0:
@@ -132,6 +140,9 @@ class BalanceScheduler(Scheduler):
shift_computed_tokens=1 if self.use_eagle else 0, shift_computed_tokens=1 if self.use_eagle else 0,
) )
if self.need_mamba_block_aligned_split:
num_new_tokens = self._mamba_block_aligned_split(request, num_new_tokens)
if num_new_tokens == 0: if num_new_tokens == 0:
# The request cannot be scheduled because one of the following # The request cannot be scheduled because one of the following
# reasons: # reasons:
@@ -142,6 +153,8 @@ class BalanceScheduler(Scheduler):
# its max_total_tokens or max_model_len. # its max_total_tokens or max_model_len.
# 2. The encoder budget is exhausted. # 2. The encoder budget is exhausted.
# 3. The encoder cache is exhausted. # 3. The encoder cache is exhausted.
# 4. Insufficient budget for a block-aligned chunk in hybrid
# models with mamba cache mode "align".
# NOTE(woosuk): Here, by doing `continue` instead of `break`, # NOTE(woosuk): Here, by doing `continue` instead of `break`,
# we do not strictly follow the FCFS scheduling policy and # we do not strictly follow the FCFS scheduling policy and
# allow the lower-priority requests to be scheduled. # allow the lower-priority requests to be scheduled.
@@ -170,12 +183,12 @@ class BalanceScheduler(Scheduler):
) )
self.running.remove(preempted_req) self.running.remove(preempted_req)
if preempted_req in scheduled_running_reqs: if preempted_req in scheduled_running_reqs:
preempted_req_id = preempted_req.request_id
scheduled_running_reqs.remove(preempted_req) scheduled_running_reqs.remove(preempted_req)
token_budget += num_scheduled_tokens[preempted_req.request_id] token_budget += num_scheduled_tokens.pop(preempted_req_id)
req_to_new_blocks.pop(preempted_req.request_id) req_to_new_blocks.pop(preempted_req_id)
num_scheduled_tokens.pop(preempted_req.request_id) scheduled_spec_decode_tokens.pop(preempted_req_id, None)
scheduled_spec_decode_tokens.pop(preempted_req.request_id, None) preempted_encoder_inputs = scheduled_encoder_inputs.pop(preempted_req_id, None)
preempted_encoder_inputs = scheduled_encoder_inputs.pop(preempted_req.request_id, None)
if preempted_encoder_inputs: if preempted_encoder_inputs:
# Restore encoder compute budget if the preempted # Restore encoder compute budget if the preempted
# request had encoder inputs scheduled in this step. # request had encoder inputs scheduled in this step.
@@ -199,8 +212,9 @@ class BalanceScheduler(Scheduler):
# Schedule the request. # Schedule the request.
scheduled_running_reqs.append(request) scheduled_running_reqs.append(request)
req_to_new_blocks[request.request_id] = new_blocks request_id = request.request_id
num_scheduled_tokens[request.request_id] = num_new_tokens req_to_new_blocks[request_id] = new_blocks
num_scheduled_tokens[request_id] = num_new_tokens
token_budget -= num_new_tokens token_budget -= num_new_tokens
req_index += 1 req_index += 1
@@ -210,16 +224,18 @@ class BalanceScheduler(Scheduler):
num_new_tokens + request.num_computed_tokens - request.num_tokens - request.num_output_placeholders num_new_tokens + request.num_computed_tokens - request.num_tokens - request.num_output_placeholders
) )
if num_scheduled_spec_tokens > 0: if num_scheduled_spec_tokens > 0:
# Trim spec_token_ids list to num_scheduled_spec_tokens. spec_token_ids = request.spec_token_ids
del request.spec_token_ids[num_scheduled_spec_tokens:] if len(spec_token_ids) > num_scheduled_spec_tokens:
scheduled_spec_decode_tokens[request.request_id] = request.spec_token_ids spec_token_ids = spec_token_ids[:num_scheduled_spec_tokens]
scheduled_spec_decode_tokens[request.request_id] = spec_token_ids
# New spec tokens will be set in `update_draft_token_ids` before the # New spec tokens will be set in `update_draft_token_ids` before the
# next step when applicable. # next step when applicable.
request.spec_token_ids = [] request.spec_token_ids = []
# Encoder-related. # Encoder-related.
if encoder_inputs_to_schedule: if encoder_inputs_to_schedule:
scheduled_encoder_inputs[request.request_id] = encoder_inputs_to_schedule scheduled_encoder_inputs[request_id] = encoder_inputs_to_schedule
# Allocate the encoder cache. # Allocate the encoder cache.
for i in encoder_inputs_to_schedule: for i in encoder_inputs_to_schedule:
self.encoder_cache_manager.allocate(request, i) self.encoder_cache_manager.allocate(request, i)
@@ -240,31 +256,37 @@ class BalanceScheduler(Scheduler):
) )
assert len(scheduled_loras) <= self.lora_config.max_loras assert len(scheduled_loras) <= self.lora_config.max_loras
# Use a temporary RequestQueue to collect requests that need to be
# skipped and put back at the head of the waiting queue later
skipped_waiting_requests = create_request_queue(self.policy)
# Next, schedule the WAITING requests. # Next, schedule the WAITING requests.
if not preempted_reqs: if not preempted_reqs and self._pause_state == PauseState.UNPAUSED:
# Use a temporary RequestQueue to collect requests that need to be
# skipped and put back at the head of the waiting queue later
skipped_waiting_requests = create_request_queue(self.policy)
while self.waiting and token_budget > 0: while self.waiting and token_budget > 0:
if len(self.running) == self.max_num_running_reqs: if len(self.running) == self.max_num_running_reqs:
break break
balance_flag = max(t.item() for t in self.balance_queue) == self.max_num_running_reqs balance_flag = max(t.item() for t in self.balance_queue) >= self.max_num_running_reqs - 1
if balance_flag: if balance_flag:
break break
request = self.waiting.peek_request() request = self.waiting.peek_request()
request_id = request.request_id
# KVTransfer: skip request if still waiting for remote kvs. # KVTransfer: skip request if still waiting for remote kvs.
if request.status == RequestStatus.WAITING_FOR_REMOTE_KVS: if request.status == RequestStatus.WAITING_FOR_REMOTE_KVS:
is_ready = self._update_waiting_for_remote_kv(request) is_ready = self._update_waiting_for_remote_kv(request)
if is_ready: if is_ready:
request.status = RequestStatus.WAITING if request.num_preemptions:
# We must be loading for a resumed preemption
# rather than a new request.
request.status = RequestStatus.PREEMPTED
else:
request.status = RequestStatus.WAITING
else: else:
logger.debug( logger.debug(
"%s is still in WAITING_FOR_REMOTE_KVS state.", "%s is still in WAITING_FOR_REMOTE_KVS state.",
request.request_id, request_id,
) )
self.waiting.pop_request() self.waiting.pop_request()
skipped_waiting_requests.prepend_request(request) skipped_waiting_requests.prepend_request(request)
@@ -281,6 +303,13 @@ class BalanceScheduler(Scheduler):
skipped_waiting_requests.prepend_request(request) skipped_waiting_requests.prepend_request(request)
continue continue
# Streaming: skip request if still waiting for next streaming req.
if request.status == RequestStatus.WAITING_FOR_STREAMING_REQ:
assert not request.streaming_queue
self.waiting.pop_request()
skipped_waiting_requests.prepend_request(request)
continue
# Check that adding the request still respects the max_loras # Check that adding the request still respects the max_loras
# constraint. # constraint.
if ( if (
@@ -298,6 +327,7 @@ class BalanceScheduler(Scheduler):
num_external_computed_tokens = 0 num_external_computed_tokens = 0
load_kv_async = False load_kv_async = False
connector_prefix_cache_queries, connector_prefix_cache_hits = 0, 0
# Get already-cached tokens. # Get already-cached tokens.
if request.num_computed_tokens == 0: if request.num_computed_tokens == 0:
@@ -323,6 +353,9 @@ class BalanceScheduler(Scheduler):
request.num_external_computed_tokens = ext_tokens request.num_external_computed_tokens = ext_tokens
num_external_computed_tokens = ext_tokens num_external_computed_tokens = ext_tokens
connector_prefix_cache_queries = request.num_tokens - num_new_local_computed_tokens
connector_prefix_cache_hits = num_external_computed_tokens
# Total computed tokens (local + external). # Total computed tokens (local + external).
num_computed_tokens = num_new_local_computed_tokens + num_external_computed_tokens num_computed_tokens = num_new_local_computed_tokens + num_external_computed_tokens
else: else:
@@ -378,6 +411,16 @@ class BalanceScheduler(Scheduler):
# The request cannot be scheduled. # The request cannot be scheduled.
break break
if self.need_mamba_block_aligned_split:
num_new_tokens = self._mamba_block_aligned_split(
request,
num_new_tokens,
num_new_local_computed_tokens,
num_external_computed_tokens,
)
if num_new_tokens == 0:
break
# Handles an edge case when P/D Disaggregation # Handles an edge case when P/D Disaggregation
# is used with Spec Decoding where an # is used with Spec Decoding where an
# extra block gets allocated which # extra block gets allocated which
@@ -386,27 +429,28 @@ class BalanceScheduler(Scheduler):
effective_lookahead_tokens = 0 if request.num_computed_tokens == 0 else self.num_lookahead_tokens effective_lookahead_tokens = 0 if request.num_computed_tokens == 0 else self.num_lookahead_tokens
# Determine if we need to allocate cross-attention blocks. # Determine if we need to allocate cross-attention blocks.
if self.is_encoder_decoder and request.has_encoder_inputs: num_encoder_tokens = 0
# TODO(russellb): For Whisper, we know that the input is if self.is_encoder_decoder and request.has_encoder_inputs and encoder_inputs_to_schedule:
# always padded to the maximum length. If we support other num_encoder_tokens = sum(request.get_num_encoder_embeds(i) for i in encoder_inputs_to_schedule)
# encoder-decoder models, this will need to be updated if we
# want to only allocate what is needed.
num_encoder_tokens = self.scheduler_config.max_num_encoder_input_tokens
else:
num_encoder_tokens = 0
new_blocks = self.kv_cache_manager.allocate_slots( new_blocks = self.kv_cache_manager.allocate_slots(
request, request,
num_new_tokens + num_external_computed_tokens, num_new_tokens,
num_new_local_computed_tokens, num_new_computed_tokens=num_new_local_computed_tokens,
new_computed_blocks, new_computed_blocks=new_computed_blocks,
num_lookahead_tokens=effective_lookahead_tokens, num_lookahead_tokens=effective_lookahead_tokens,
num_external_computed_tokens=num_external_computed_tokens,
delay_cache_blocks=load_kv_async, delay_cache_blocks=load_kv_async,
num_encoder_tokens=num_encoder_tokens, num_encoder_tokens=num_encoder_tokens,
) )
if new_blocks is None: if new_blocks is None:
# The request cannot be scheduled. # The request cannot be scheduled.
# NOTE: we need to untouch the request from the encode cache
# manager
if request.has_encoder_inputs:
self.encoder_cache_manager.free(request)
break break
# KVTransfer: the connector uses this info to determine # KVTransfer: the connector uses this info to determine
@@ -416,9 +460,15 @@ class BalanceScheduler(Scheduler):
if self.connector is not None: if self.connector is not None:
self.connector.update_state_after_alloc( self.connector.update_state_after_alloc(
request, request,
new_computed_blocks + new_blocks, self.kv_cache_manager.get_blocks(request_id),
num_external_computed_tokens, num_external_computed_tokens,
) )
if self.connector_prefix_cache_stats is not None and connector_prefix_cache_queries != 0:
self.connector_prefix_cache_stats.record(
num_tokens=connector_prefix_cache_queries,
num_hits=connector_prefix_cache_hits,
preempted=request.num_preemptions > 0,
)
# Request was already popped from self.waiting # Request was already popped from self.waiting
# unless it was re-added above due to new_blocks being None. # unless it was re-added above due to new_blocks being None.
@@ -430,8 +480,6 @@ class BalanceScheduler(Scheduler):
request.status = RequestStatus.WAITING_FOR_REMOTE_KVS request.status = RequestStatus.WAITING_FOR_REMOTE_KVS
continue continue
self._update_connector_prefix_cache_stats(request)
self.running.append(request) self.running.append(request)
if self.log_stats: if self.log_stats:
request.record_event(EngineCoreEventType.SCHEDULED, scheduled_timestamp) request.record_event(EngineCoreEventType.SCHEDULED, scheduled_timestamp)
@@ -444,8 +492,8 @@ class BalanceScheduler(Scheduler):
if self.lora_config and request.lora_request: if self.lora_config and request.lora_request:
scheduled_loras.add(request.lora_request.lora_int_id) scheduled_loras.add(request.lora_request.lora_int_id)
req_to_new_blocks[request.request_id] = self.kv_cache_manager.get_blocks(request.request_id) req_to_new_blocks[request_id] = self.kv_cache_manager.get_blocks(request_id)
num_scheduled_tokens[request.request_id] = num_new_tokens num_scheduled_tokens[request_id] = num_new_tokens
token_budget -= num_new_tokens token_budget -= num_new_tokens
request.status = RequestStatus.RUNNING request.status = RequestStatus.RUNNING
request.num_computed_tokens = num_computed_tokens request.num_computed_tokens = num_computed_tokens
@@ -454,7 +502,7 @@ class BalanceScheduler(Scheduler):
request.num_cached_tokens = num_computed_tokens request.num_cached_tokens = num_computed_tokens
# Encoder-related. # Encoder-related.
if encoder_inputs_to_schedule: if encoder_inputs_to_schedule:
scheduled_encoder_inputs[request.request_id] = encoder_inputs_to_schedule scheduled_encoder_inputs[request_id] = encoder_inputs_to_schedule
# Allocate the encoder cache. # Allocate the encoder cache.
for i in encoder_inputs_to_schedule: for i in encoder_inputs_to_schedule:
self.encoder_cache_manager.allocate(request, i) self.encoder_cache_manager.allocate(request, i)
@@ -465,9 +513,10 @@ class BalanceScheduler(Scheduler):
self.encoder_cache_manager.allocate(request, i) self.encoder_cache_manager.allocate(request, i)
if self.ec_connector is not None: if self.ec_connector is not None:
self.ec_connector.update_state_after_alloc(request, i) self.ec_connector.update_state_after_alloc(request, i)
# Put back any skipped requests at the head of the waiting queue
if skipped_waiting_requests: # Put back any skipped requests at the head of the waiting queue
self.waiting.prepend_requests(skipped_waiting_requests) if skipped_waiting_requests:
self.waiting.prepend_requests(skipped_waiting_requests)
# Check if the scheduling constraints are satisfied. # Check if the scheduling constraints are satisfied.
total_num_scheduled_tokens = sum(num_scheduled_tokens.values()) total_num_scheduled_tokens = sum(num_scheduled_tokens.values())
@@ -485,8 +534,8 @@ class BalanceScheduler(Scheduler):
num_common_prefix_blocks = [0] * len(self.kv_cache_config.kv_cache_groups) num_common_prefix_blocks = [0] * len(self.kv_cache_config.kv_cache_groups)
with record_function_or_nullcontext("schedule: get_num_common_prefix_blocks"): with record_function_or_nullcontext("schedule: get_num_common_prefix_blocks"):
if self.running: if self.running:
any_request = self.running[0] any_request_id = self.running[0].request_id
num_common_prefix_blocks = self.kv_cache_manager.get_num_common_prefix_blocks(any_request.request_id) num_common_prefix_blocks = self.kv_cache_manager.get_num_common_prefix_blocks(any_request_id)
# Construct the scheduler output. # Construct the scheduler output.
if self.use_v2_model_runner: if self.use_v2_model_runner: