xc-llm-ascend/vllm_ascend/patch/platform/patch_balance_schedule.py


[Main] [Patch] support balance scheduling patch (#5212)

### Motivation

**Limitations of the current vLLM v1 scheduling strategy**

vLLM v1 scheduling currently enables chunked prefill by default, which processes prefill and decode requests together in a single scheduling step. This can reduce overall system throughput in some scenarios. Balance scheduling addresses this by synchronizing the running-queue lengths across all schedulers and delaying the scheduling of new requests, which lengthens the time the system spends in steady-state decoding.

This achieves:

✅ Adding `balance_gather` to the scheduler synchronizes the number of requests in the running queues between DP ranks.

✅ Balance scheduling lengthens the decode steady state, thereby increasing the overall output throughput of the inference system.

### Proposed Change

**1. Feature Overview**

In the vLLM scheduler, running requests (i.e., requests whose prefill computation has already started) have the highest priority, followed by waiting requests (i.e., requests that have not yet been computed). As shown in the diagram above, when the inference system leaves its steady state, a scheduler schedules a batch of new requests for prefill and then synchronizes with the other data-parallel (DP) ranks. As a result, DP ranks that are only decoding must synchronize to the number of prefill tokens. Frequent prefill scheduling by some DP ranks can therefore degrade the output throughput of the whole system. Balance scheduling synchronizes the number of running-queue requests across the DP ranks and schedules new requests for prefill only when every scheduler has fewer running requests than the configured maximum.

**2. Implementation Design**

**3. Experiment Results**

- Fixed-length input scenario: in the performance test with 3.5K fixed-length input and 1.5K fixed-length output, output throughput improved by approximately **18%** after adding balance scheduling.

| Method | Model | Input Len | Request Count | Output Len | BatchSize | Average TTFT | Average TPOT | e2e duration | Input Token Throughput | Output Token Throughput | Request Throughput |
| ---- | ---- | ---- | ---- | ---- | ---- | ---- | ---- | ---- | ---- | ---- | ---- |
| Baseline | DeepSeekV3.1 | 3500 | 512 | 1500 | 128 | 6600 | 86.85 | 591.9s | 3030.5 | 1297.3 | 0.86 |
| Balance scheduling | DeepSeekV3.1 | 3500 | 512 | 1500 | 128 | 7012 | 70.63 | 501.7s | 3575.7 | 1530.7 | 1.02 |

**4. Demo PR**

[#29721](https://github.com/vllm-project/vllm/pull/29721)

---------

Signed-off-by: GDzhu01 <809721801@qq.com>
2025-12-23 09:04:38 +08:00
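The gating rule described in the commit message can be illustrated with a small, framework-free sketch. This is not the patch's implementation: the function name `may_schedule_new_requests`, the parameter `max_running`, and the sample counts are hypothetical, and the list of counts stands in for the values that `balance_gather` would collect from every DP rank.

```python
def may_schedule_new_requests(running_counts: list[int], max_running: int) -> bool:
    """Return True only if every DP rank still has room in its running queue.

    `running_counts` is assumed to hold one running-queue length per DP rank,
    i.e. the kind of data an all-gather across the DP group would produce.
    """
    return all(count < max_running for count in running_counts)


# Hypothetical snapshot: ranks 0 and 2 are full, so new prefills are delayed.
print(may_schedule_new_requests([128, 97, 128, 110], max_running=128))

# Hypothetical snapshot: every rank has spare capacity, so prefills may proceed.
print(may_schedule_new_requests([90, 97, 101, 110], max_running=128))
```

Because every rank evaluates the same gathered counts, all ranks reach the same decision in a given step, which is what keeps decode-only ranks from being stalled by a single rank's prefill.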
# mypy: ignore-errors
import signal
import time
import torch
import torch.distributed as dist
import vllm
from vllm.config import ParallelConfig
from vllm.distributed.ec_transfer.ec_connector.base import ECConnectorMetadata
[Lint]Style: Convert `vllm-ascend/` to ruff format(Batch #6) (#6001)

### What this PR does / why we need it?

| File Path |
| :--- |
| `vllm_ascend/eplb/adaptor/abstract_adaptor.py` |
| `vllm_ascend/eplb/adaptor/vllm_adaptor.py` |
| `vllm_ascend/eplb/core/eplb_device_transfer_loader.py` |
| `vllm_ascend/eplb/core/eplb_utils.py` |
| `vllm_ascend/eplb/core/eplb_worker.py` |
| `vllm_ascend/eplb/core/policy/policy_abstract.py` |
| `vllm_ascend/eplb/core/policy/policy_default_eplb.py` |
| `vllm_ascend/eplb/core/policy/policy_factory.py` |
| `vllm_ascend/eplb/core/policy/policy_flashlb.py` |
| `vllm_ascend/eplb/core/policy/policy_random.py` |
| `vllm_ascend/eplb/core/policy/policy_swift_balancer.py` |
| `vllm_ascend/eplb/eplb_updator.py` |
| `vllm_ascend/eplb/utils.py` |
| `vllm_ascend/model_loader/netloader/executor/elastic_load.py` |
| `vllm_ascend/model_loader/netloader/executor/netloader_pg.py` |
| `vllm_ascend/model_loader/netloader/interaction/elastic.py` |
| `vllm_ascend/model_loader/netloader/load.py` |
| `vllm_ascend/model_loader/netloader/netloader.py` |
| `vllm_ascend/model_loader/netloader/utils.py` |
| `vllm_ascend/patch/platform/__init__.py` |
| `vllm_ascend/patch/platform/patch_balance_schedule.py` |
| `vllm_ascend/patch/platform/patch_ec_connector.py` |
| `vllm_ascend/patch/platform/patch_mamba_config.py` |
| `vllm_ascend/patch/platform/patch_multiproc_executor.py` |
| `vllm_ascend/patch/platform/patch_sched_yield.py` |

- vLLM version: v0.13.0
- vLLM main: https://github.com/vllm-project/vllm/commit/2c24bc6996cb165fce92f780b388a5e39b3f4060

---------

Signed-off-by: MrZ20 <2609716663@qq.com>
2026-01-24 22:08:33 +08:00
from vllm.distributed.kv_transfer.kv_connector.v1.base import KVConnectorMetadata
from vllm.logger import init_logger
from vllm.multimodal import MULTIMODAL_REGISTRY, MultiModalRegistry
from vllm.transformers_utils.config import maybe_register_config_serialize_by_value
from vllm.utils.system_utils import decorate_logs, set_process_title
from vllm.v1.core.kv_cache_manager import KVCacheBlocks
from vllm.v1.core.sched.output import NewRequestData, SchedulerOutput
from vllm.v1.core.sched.request_queue import SchedulingPolicy, create_request_queue
from vllm.v1.core.sched.scheduler import Scheduler
from vllm.v1.engine import EngineCoreEventType, EngineCoreOutputs
from vllm.v1.engine.core import DPEngineCoreProc, EngineCoreProc
from vllm.v1.kv_cache_interface import KVCacheConfig
from vllm.v1.request import Request, RequestStatus
from vllm.v1.structured_output import StructuredOutputManager
from vllm.v1.utils import record_function_or_nullcontext
logger = init_logger(__name__)
class BalanceScheduler(Scheduler):
    def __init__(
        self,
        vllm_config,
        kv_cache_config: KVCacheConfig,
        structured_output_manager: StructuredOutputManager,
        block_size: int,
        mm_registry: MultiModalRegistry = MULTIMODAL_REGISTRY,
        include_finished_set: bool = False,
        log_stats: bool = False,
    ) -> None:
        super().__init__(
            vllm_config,
            kv_cache_config,
            structured_output_manager,
            block_size,
            mm_registry,
            include_finished_set,
            log_stats,
        )
        # Balance scheduling.
        self.balance_queue = [
            torch.tensor([0], dtype=torch.int, device="cpu")
            for _ in range(self.vllm_config.parallel_config.data_parallel_size)
        ]

    def balance_gather(self, dp_group):
        running_tensor = torch.tensor([len(self.running)], dtype=torch.int, device="cpu")
        dist.all_gather(self.balance_queue, running_tensor, group=dp_group)

    def schedule(self) -> SchedulerOutput:
        # NOTE(woosuk) on the scheduling algorithm:
        # There's no "decoding phase" nor "prefill phase" in the scheduler.
        # Each request just has the num_computed_tokens and
        # num_tokens_with_spec. num_tokens_with_spec =
        # len(prompt_token_ids) + len(output_token_ids) + len(spec_token_ids).
        # At each step, the scheduler tries to assign tokens to the requests
        # so that each request's num_computed_tokens can catch up its
        # num_tokens_with_spec. This is general enough to cover
        # chunked prefills, prefix caching, speculative decoding,
        # and the "jump decoding" optimization in the future.
        scheduled_new_reqs: list[Request] = []
        scheduled_resumed_reqs: list[Request] = []
        scheduled_running_reqs: list[Request] = []
        preempted_reqs: list[Request] = []
        req_to_new_blocks: dict[str, KVCacheBlocks] = {}
        num_scheduled_tokens: dict[str, int] = {}
        token_budget = self.max_num_scheduled_tokens
        # Encoder-related.
        scheduled_encoder_inputs: dict[str, list[int]] = {}
        encoder_compute_budget = self.max_num_encoder_input_tokens
        # Spec decode-related.
        scheduled_spec_decode_tokens: dict[str, list[int]] = {}
        # For logging.
        scheduled_timestamp = time.monotonic()
        # First, schedule the RUNNING requests.
        req_index = 0
        while req_index < len(self.running) and token_budget > 0:
            request = self.running[req_index]
            if (
                request.num_output_placeholders > 0
                # This is (num_computed_tokens + 1) - (num_output_placeholders - 1).
                # Since output placeholders are also included in the computed tokens
                # count, we subtract (num_output_placeholders - 1) to remove any draft
                # tokens, so that we can be sure no further steps are needed even if
                # they are all rejected.
                and request.num_computed_tokens + 2 - request.num_output_placeholders
                >= request.num_prompt_tokens + request.max_tokens
            ):
                # Async scheduling: Avoid scheduling an extra step when we are sure that
                # the previous step has reached request.max_tokens. We don't schedule
                # partial draft tokens since this prevents uniform decode optimizations.
                req_index += 1
                continue
            num_new_tokens = (
                request.num_tokens_with_spec + request.num_output_placeholders - request.num_computed_tokens
            )
            if 0 < self.scheduler_config.long_prefill_token_threshold < num_new_tokens:
                num_new_tokens = self.scheduler_config.long_prefill_token_threshold
            num_new_tokens = min(num_new_tokens, token_budget)

            # Make sure the input position does not exceed the max model len.
            # This is necessary when using spec decoding.
            num_new_tokens = min(num_new_tokens, self.max_model_len - 1 - request.num_computed_tokens)
            # Schedule encoder inputs.
            encoder_inputs_to_schedule = None
            external_load_encoder_input: list[int] = []
            new_encoder_compute_budget = encoder_compute_budget
            if request.has_encoder_inputs:
                (
                    encoder_inputs_to_schedule,
                    num_new_tokens,
                    new_encoder_compute_budget,
                    external_load_encoder_input,
                ) = self._try_schedule_encoder_inputs(
                    request,
                    request.num_computed_tokens,
                    num_new_tokens,
                    encoder_compute_budget,
                    shift_computed_tokens=1 if self.use_eagle else 0,
                )
            if num_new_tokens == 0:
                # The request cannot be scheduled for one of the following
                # reasons:
                # 1. No new tokens to schedule. This may happen when
                #   (1) PP>1 and we have already scheduled all prompt tokens
                #       but they are not finished yet.
                #   (2) Async scheduling is on and the request has reached either
                #       its max_total_tokens or max_model_len.
                # 2. The encoder budget is exhausted.
                # 3. The encoder cache is exhausted.
                # NOTE(woosuk): Here, by doing `continue` instead of `break`,
                # we do not strictly follow the FCFS scheduling policy and
                # allow the lower-priority requests to be scheduled.
                req_index += 1
                continue
            # Schedule newly needed KV blocks for the request.
            with record_function_or_nullcontext("schedule: allocate_slots"):
                while True:
                    new_blocks = self.kv_cache_manager.allocate_slots(
                        request,
                        num_new_tokens,
                        num_lookahead_tokens=self.num_lookahead_tokens,
                    )

                    if new_blocks is not None:
                        # The request can be scheduled.
                        break

                    # The request cannot be scheduled.
                    # Preempt the lowest-priority request.
                    if self.policy == SchedulingPolicy.PRIORITY:
                        preempted_req = max(
                            self.running,
                            key=lambda r: (r.priority, r.arrival_time),
                        )
                        self.running.remove(preempted_req)
                        if preempted_req in scheduled_running_reqs:
                            scheduled_running_reqs.remove(preempted_req)
                            token_budget += num_scheduled_tokens[preempted_req.request_id]
                            req_to_new_blocks.pop(preempted_req.request_id)
                            num_scheduled_tokens.pop(preempted_req.request_id)
                            scheduled_spec_decode_tokens.pop(preempted_req.request_id, None)
                            preempted_encoder_inputs = scheduled_encoder_inputs.pop(preempted_req.request_id, None)
                            if preempted_encoder_inputs:
                                # Restore encoder compute budget if the preempted
                                # request had encoder inputs scheduled in this step.
                                num_embeds_to_restore = sum(
                                    preempted_req.get_num_encoder_embeds(i) for i in preempted_encoder_inputs
                                )
                                encoder_compute_budget += num_embeds_to_restore
                            req_index -= 1
                    else:
                        preempted_req = self.running.pop()

                    self._preempt_request(preempted_req, scheduled_timestamp)
                    preempted_reqs.append(preempted_req)
                    if preempted_req == request:
                        # No more request to preempt. Cannot schedule this request.
                        break

            if new_blocks is None:
                # Cannot schedule this request.
                break

            # Schedule the request.
            scheduled_running_reqs.append(request)
            req_to_new_blocks[request.request_id] = new_blocks
            num_scheduled_tokens[request.request_id] = num_new_tokens
            token_budget -= num_new_tokens
            req_index += 1

            # Speculative decode related.
            if request.spec_token_ids:
                num_scheduled_spec_tokens = (
                    num_new_tokens + request.num_computed_tokens - request.num_tokens - request.num_output_placeholders
                )
                if num_scheduled_spec_tokens > 0:
                    # Trim spec_token_ids list to num_scheduled_spec_tokens.
                    del request.spec_token_ids[num_scheduled_spec_tokens:]
                    scheduled_spec_decode_tokens[request.request_id] = request.spec_token_ids
                # New spec tokens will be set in `update_draft_token_ids` before the
                # next step when applicable.
                request.spec_token_ids = []

            # Encoder-related.
            if encoder_inputs_to_schedule:
                scheduled_encoder_inputs[request.request_id] = encoder_inputs_to_schedule
                # Allocate the encoder cache.
                for i in encoder_inputs_to_schedule:
                    self.encoder_cache_manager.allocate(request, i)
                encoder_compute_budget = new_encoder_compute_budget
            if external_load_encoder_input:
                for i in external_load_encoder_input:
                    self.encoder_cache_manager.allocate(request, i)
                    if self.ec_connector is not None:
                        self.ec_connector.update_state_after_alloc(request, i)

        # Record the LoRAs in scheduled_running_reqs
        scheduled_loras: set[int] = set()
        if self.lora_config:
            scheduled_loras = set(
                req.lora_request.lora_int_id
                for req in scheduled_running_reqs
                if req.lora_request and req.lora_request.lora_int_id > 0
            )
            assert len(scheduled_loras) <= self.lora_config.max_loras

        # Use a temporary RequestQueue to collect requests that need to be
        # skipped and put back at the head of the waiting queue later
        skipped_waiting_requests = create_request_queue(self.policy)

        # Next, schedule the WAITING requests.
        if not preempted_reqs:
            while self.waiting and token_budget > 0:
                if len(self.running) == self.max_num_running_reqs:
                    break
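                # Balance scheduling: stop admitting waiting requests while any
                # entry in balance_queue (the running-queue sizes gathered across
                # DP ranks) has reached max_num_running_reqs.
                balance_flag = max(t.item() for t in self.balance_queue) == self.max_num_running_reqs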
                if balance_flag:
                    break

                request = self.waiting.peek_request()

                # KVTransfer: skip request if still waiting for remote kvs.
                if request.status == RequestStatus.WAITING_FOR_REMOTE_KVS:
                    is_ready = self._update_waiting_for_remote_kv(request)
                    if is_ready:
                        request.status = RequestStatus.WAITING
                    else:
                        logger.debug(
                            "%s is still in WAITING_FOR_REMOTE_KVS state.",
                            request.request_id,
                        )
                        self.waiting.pop_request()
                        skipped_waiting_requests.prepend_request(request)
                        continue

                # Skip request if the structured output request is still waiting
                # for FSM compilation.
                if request.status == RequestStatus.WAITING_FOR_FSM:
                    structured_output_req = request.structured_output_request
                    if structured_output_req and structured_output_req.grammar:
                        request.status = RequestStatus.WAITING
                    else:
                        self.waiting.pop_request()
                        skipped_waiting_requests.prepend_request(request)
                        continue

                # Check that adding the request still respects the max_loras
                # constraint.
                if (
                    self.lora_config
                    and request.lora_request
                    and (
                        len(scheduled_loras) == self.lora_config.max_loras
                        and request.lora_request.lora_int_id not in scheduled_loras
                    )
                ):
                    # Scheduling would exceed max_loras, skip.
                    self.waiting.pop_request()
                    skipped_waiting_requests.prepend_request(request)
                    continue

                num_external_computed_tokens = 0
                load_kv_async = False

                # Get already-cached tokens.
                if request.num_computed_tokens == 0:
                    # Get locally-cached tokens.
                    new_computed_blocks, num_new_local_computed_tokens = self.kv_cache_manager.get_computed_blocks(
                        request
                    )
                    # Get externally-cached tokens if using a KVConnector.
                    if self.connector is not None:
                        ext_tokens, load_kv_async = self.connector.get_num_new_matched_tokens(
                            request, num_new_local_computed_tokens
                        )
                        if ext_tokens is None:
                            # The request cannot be scheduled because
                            # the KVConnector couldn't determine
                            # the number of matched tokens.
                            self.waiting.pop_request()
                            skipped_waiting_requests.prepend_request(request)
                            continue

                        request.num_external_computed_tokens = ext_tokens
                        num_external_computed_tokens = ext_tokens

                    # Total computed tokens (local + external).
num_computed_tokens = num_new_local_computed_tokens + num_external_computed_tokens
else:
# KVTransfer: WAITING reqs have num_computed_tokens > 0
# after async KV recvs are completed.
new_computed_blocks = self.kv_cache_manager.empty_kv_cache_blocks
num_new_local_computed_tokens = 0
num_computed_tokens = request.num_computed_tokens
encoder_inputs_to_schedule = None
external_load_encoder_input = []
new_encoder_compute_budget = encoder_compute_budget
if load_kv_async:
# KVTransfer: loading remote KV, do not allocate for new work.
assert num_external_computed_tokens > 0
num_new_tokens = 0
else:
# Number of tokens to be scheduled.
# We use `request.num_tokens` instead of
# `request.num_prompt_tokens` to consider the resumed
# requests, which have output tokens.
num_new_tokens = request.num_tokens - num_computed_tokens
threshold = self.scheduler_config.long_prefill_token_threshold
if 0 < threshold < num_new_tokens:
num_new_tokens = threshold
# chunked prefill has to be enabled explicitly to allow
# pooling requests to be chunked
if not self.scheduler_config.enable_chunked_prefill and num_new_tokens > token_budget:
# If chunked_prefill is disabled,
# we can stop the scheduling here.
break
num_new_tokens = min(num_new_tokens, token_budget)
assert num_new_tokens > 0
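# Illustrative aside (not part of the patch): a minimal sketch of the
# chunk-sizing rule applied above, assuming plain integers in place of the
# scheduler's request and config objects. The function name
# `plan_new_tokens` and its parameters are hypothetical.

```python
from typing import Optional


def plan_new_tokens(
    num_tokens: int,
    num_computed_tokens: int,
    token_budget: int,
    long_prefill_threshold: int = 0,
    chunked_prefill_enabled: bool = True,
) -> Optional[int]:
    """Return how many tokens to schedule, or None if the request must wait."""
    # Tokens still to be computed (prompt plus any output of a resumed request).
    num_new_tokens = num_tokens - num_computed_tokens
    # A positive threshold caps the size of a single long-prefill chunk.
    if 0 < long_prefill_threshold < num_new_tokens:
        num_new_tokens = long_prefill_threshold
    # Without chunked prefill the request must fit in the budget as a whole.
    if not chunked_prefill_enabled and num_new_tokens > token_budget:
        return None
    # With chunked prefill, take as much as the remaining budget allows.
    return min(num_new_tokens, token_budget)
```

# In this sketch a None result corresponds to the `break` above: scheduling
# of waiting requests stops for the current step.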
# Schedule encoder inputs.
if request.has_encoder_inputs:
(
encoder_inputs_to_schedule,
num_new_tokens,
new_encoder_compute_budget,
external_load_encoder_input,
) = self._try_schedule_encoder_inputs(
request,
num_computed_tokens,
num_new_tokens,
encoder_compute_budget,
shift_computed_tokens=1 if self.use_eagle else 0,
)
if num_new_tokens == 0:
# The request cannot be scheduled.
break
# Handles an edge case when P/D Disaggregation
# is used with Spec Decoding where an
# extra block gets allocated which
# creates a mismatch between the number
# of local and remote blocks.
effective_lookahead_tokens = 0 if request.num_computed_tokens == 0 else self.num_lookahead_tokens
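# Illustrative aside (not part of the patch): a toy model of why lookahead
# tokens can cost an extra block. The helper names and the block size of 16
# are assumptions for illustration only; the real block accounting lives in
# the KV cache manager.

```python
def effective_lookahead(num_computed_tokens: int, num_lookahead_tokens: int) -> int:
    """Skip lookahead for requests that have not computed any tokens yet."""
    return 0 if num_computed_tokens == 0 else num_lookahead_tokens


def blocks_needed(num_tokens: int, lookahead: int, block_size: int) -> int:
    """Ceiling division: blocks required to hold the tokens plus lookahead."""
    return -(-(num_tokens + lookahead) // block_size)


# A 16-token request fits exactly one 16-token block without lookahead,
# but a single lookahead token spills into a second block.
without_lookahead = blocks_needed(16, effective_lookahead(0, 1), 16)
with_lookahead = blocks_needed(16, effective_lookahead(8, 1), 16)
```

# Under these assumptions, disabling lookahead for a request with no computed
# tokens keeps the local block count equal to what the remote side produced.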
# Determine if we need to allocate cross-attention blocks.
if self.is_encoder_decoder and request.has_encoder_inputs:
# TODO(russellb): For Whisper, we know that the input is
# always padded to the maximum length. If we support other
# encoder-decoder models, this will need to be updated if we
# want to only allocate what is needed.
num_encoder_tokens = self.scheduler_config.max_num_encoder_input_tokens
else:
num_encoder_tokens = 0
new_blocks = self.kv_cache_manager.allocate_slots(
request,
num_new_tokens + num_external_computed_tokens,
num_new_local_computed_tokens,
new_computed_blocks,
num_lookahead_tokens=effective_lookahead_tokens,
delay_cache_blocks=load_kv_async,
num_encoder_tokens=num_encoder_tokens,
)
if new_blocks is None:
# The request cannot be scheduled.
break
# KVTransfer: the connector uses this info to determine
# if a load is needed for this request.
if self.connector is not None:
self.connector.update_state_after_alloc(
request,
new_computed_blocks + new_blocks,
num_external_computed_tokens,
)
# Allocation succeeded, so remove the request from self.waiting
# (it has not been popped on this path yet).
request = self.waiting.pop_request()
if load_kv_async:
# If loading async, allocate memory and put the request
# into the WAITING_FOR_REMOTE_KVS state.
skipped_waiting_requests.prepend_request(request)
request.status = RequestStatus.WAITING_FOR_REMOTE_KVS
continue
self._update_connector_prefix_cache_stats(request)
self.running.append(request)
if self.log_stats:
request.record_event(EngineCoreEventType.SCHEDULED, scheduled_timestamp)
if request.status == RequestStatus.WAITING:
scheduled_new_reqs.append(request)
elif request.status == RequestStatus.PREEMPTED:
scheduled_resumed_reqs.append(request)
else:
raise RuntimeError(f"Invalid request status: {request.status}")
if self.lora_config and request.lora_request:
scheduled_loras.add(request.lora_request.lora_int_id)
req_to_new_blocks[request.request_id] = self.kv_cache_manager.get_blocks(request.request_id)
                num_scheduled_tokens[request.request_id] = num_new_tokens
                token_budget -= num_new_tokens
                request.status = RequestStatus.RUNNING
                request.num_computed_tokens = num_computed_tokens
                # Count the number of prefix cached tokens.
                if request.num_cached_tokens < 0:
                    request.num_cached_tokens = num_computed_tokens
                # Encoder-related.
                if encoder_inputs_to_schedule:
                    scheduled_encoder_inputs[request.request_id] = encoder_inputs_to_schedule
                    # Allocate the encoder cache.
                    for i in encoder_inputs_to_schedule:
                        self.encoder_cache_manager.allocate(request, i)
                    encoder_compute_budget = new_encoder_compute_budget
                # Allocate for external load encoder cache
                if external_load_encoder_input:
                    for i in external_load_encoder_input:
                        self.encoder_cache_manager.allocate(request, i)
                        if self.ec_connector is not None:
                            self.ec_connector.update_state_after_alloc(request, i)
        # Put back any skipped requests at the head of the waiting queue
        if skipped_waiting_requests:
            self.waiting.prepend_requests(skipped_waiting_requests)

        # Check if the scheduling constraints are satisfied.
        total_num_scheduled_tokens = sum(num_scheduled_tokens.values())
        assert total_num_scheduled_tokens <= self.max_num_scheduled_tokens
        assert token_budget >= 0
        assert len(self.running) <= self.max_num_running_reqs
        # Since some requests in the RUNNING queue may not be scheduled in
        # this step, the total number of scheduled requests can be smaller than
        # len(self.running).
        assert len(scheduled_new_reqs) + len(scheduled_resumed_reqs) + len(scheduled_running_reqs) <= len(self.running)
        # Get the longest common prefix among all requests in the running queue.
        # This can be potentially used for cascade attention.
        num_common_prefix_blocks = [0] * len(self.kv_cache_config.kv_cache_groups)
        with record_function_or_nullcontext("schedule: get_num_common_prefix_blocks"):
            if self.running:
                any_request = self.running[0]
                num_common_prefix_blocks = self.kv_cache_manager.get_num_common_prefix_blocks(any_request.request_id)
        # Construct the scheduler output.
        if self.use_v2_model_runner:
            scheduled_new_reqs = scheduled_new_reqs + scheduled_resumed_reqs
            scheduled_resumed_reqs = []
            new_reqs_data = [
                NewRequestData.from_request(
                    req,
                    req_to_new_blocks[req.request_id].get_block_ids(),
                    req._all_token_ids,
                )
                for req in scheduled_new_reqs
            ]
        else:
            new_reqs_data = [
                NewRequestData.from_request(req, req_to_new_blocks[req.request_id].get_block_ids())
                for req in scheduled_new_reqs
            ]
        with record_function_or_nullcontext("schedule: make_cached_request_data"):
            cached_reqs_data = self._make_cached_request_data(
                scheduled_running_reqs,
                scheduled_resumed_reqs,
                num_scheduled_tokens,
                scheduled_spec_decode_tokens,
                req_to_new_blocks,
            )
        # Record the request ids that were scheduled in this step.
        self.prev_step_scheduled_req_ids.clear()
        self.prev_step_scheduled_req_ids.update(num_scheduled_tokens.keys())
        scheduler_output = SchedulerOutput(
            scheduled_new_reqs=new_reqs_data,
            scheduled_cached_reqs=cached_reqs_data,
            num_scheduled_tokens=num_scheduled_tokens,
            total_num_scheduled_tokens=total_num_scheduled_tokens,
            scheduled_spec_decode_tokens=scheduled_spec_decode_tokens,
            scheduled_encoder_inputs=scheduled_encoder_inputs,
            num_common_prefix_blocks=num_common_prefix_blocks,
            preempted_req_ids={req.request_id for req in preempted_reqs},
            # finished_req_ids is an existing state in the scheduler,
            # instead of being newly scheduled in this step.
            # It contains the request IDs that are finished in between
            # the previous and the current steps.
            finished_req_ids=self.finished_req_ids,
            free_encoder_mm_hashes=self.encoder_cache_manager.get_freed_mm_hashes(),
        )
        # NOTE(Kuntai): this function is designed for multiple purposes:
        # 1. Plan the KV cache store
        # 2. Wrap up all the KV cache load / save ops into an opaque object
        # 3. Clear the internal states of the connector
        if self.connector is not None:
            meta: KVConnectorMetadata = self.connector.build_connector_meta(scheduler_output)
            scheduler_output.kv_connector_metadata = meta
        # Build the connector meta for ECConnector
        if self.ec_connector is not None:
            ec_meta: ECConnectorMetadata = self.ec_connector.build_connector_meta(scheduler_output)
            scheduler_output.ec_connector_metadata = ec_meta
        with record_function_or_nullcontext("schedule: update_after_schedule"):
            self._update_after_schedule(scheduler_output)
        return scheduler_output
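The commit description for this patch says that `balance_gather` synchronizes the running-queue sizes across data-parallel (DP) ranks, and that new requests are admitted for prefill only when every scheduler is below its running-request limit. The following is a minimal, self-contained sketch of that gating rule, not the patch's actual implementation: `simulate_balance_gather`, `may_schedule_new_requests`, and `max_num_running_reqs` are hypothetical names chosen for illustration, and the all-gather is simulated with a plain list instead of a real collective.

```python
def simulate_balance_gather(local_running_counts: list[int]) -> list[int]:
    """Stand-in for an all-gather over the DP group (assumption).

    After a real all-gather, every rank would hold the running-queue
    length of every other rank; here we simply return a copy of the list.
    """
    return list(local_running_counts)


def may_schedule_new_requests(
    gathered_counts: list[int], max_num_running_reqs: int
) -> bool:
    """Admit waiting requests only if every DP rank has spare capacity.

    This follows the rule stated in the commit description; the real
    patch may apply additional conditions.
    """
    return all(count < max_num_running_reqs for count in gathered_counts)


if __name__ == "__main__":
    # Three hypothetical DP ranks, each with a limit of 8 running requests.
    counts = simulate_balance_gather([3, 8, 7])
    # Rank 1 is already at the limit, so every rank delays new prefills.
    print(may_schedule_new_requests(counts, max_num_running_reqs=8))
```

Under this rule a single saturated rank delays prefill on all ranks, which is the trade-off the commit message describes: time to first token may rise, while ranks that are purely decoding are no longer stalled by another rank's prefill.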
class BalanceDPEngineCoreProc(DPEngineCoreProc):
    def run_busy_loop(self):
        """Core busy loop of the EngineCore for data parallel case."""
        # Loop until process is sent a SIGINT or SIGTERM
        while True:
            # 1) Poll the input queue until there is work to do.
            self._process_input_queue()
            # 2) Step the engine core.
            executed = self._process_engine_step()
            self._maybe_publish_request_counts()
            local_unfinished_reqs = self.scheduler.has_unfinished_requests()
            if not executed:
                if not local_unfinished_reqs and not self.engines_running:
                    # All engines are idle.
                    continue
                # We are in a running state and so must execute a dummy pass
                # if the model didn't execute any ready requests.
                self.execute_dummy_batch()
            # 3) All-reduce operation to determine global unfinished reqs.
            self.engines_running = self._has_global_unfinished_reqs(local_unfinished_reqs)
            self.scheduler.balance_gather(self.dp_group)
            if not self.engines_running:
                if self.dp_rank == 0 or not self.has_coordinator:
                    # Notify client that we are pausing the loop.
                    logger.debug("Wave %d finished, pausing engine loop.", self.current_wave)
# In the coordinator case, dp rank 0 sends updates to the
# coordinator. Otherwise (offline spmd case), each rank
# sends the update to its colocated front-end process.
client_index = -1 if self.has_coordinator else 0
self.output_queue.put_nowait(
    (
        client_index,
        EngineCoreOutputs(wave_complete=self.current_wave),
    )
)
# Increment wave count and reset step counter.
self.current_wave += 1
self.step_counter = 0


def run_engine_core(*args, dp_rank: int = 0, local_dp_rank: int = 0, **kwargs):
    """Launch EngineCore busy loop in background process."""

    # Signal handler used for graceful termination.
    # SystemExit exception is only raised once to allow this and worker
    # processes to terminate without error
    shutdown_requested = False

    # Ensure we can serialize transformer config after spawning
    maybe_register_config_serialize_by_value()

    def signal_handler(signum, frame):
        nonlocal shutdown_requested
        if not shutdown_requested:
            shutdown_requested = True
            raise SystemExit()

    # Either SIGTERM or SIGINT will terminate the engine_core
    signal.signal(signal.SIGTERM, signal_handler)
    signal.signal(signal.SIGINT, signal_handler)

    engine_core: EngineCoreProc | None = None
    try:
        parallel_config: ParallelConfig = kwargs["vllm_config"].parallel_config
        if parallel_config.data_parallel_size > 1 or dp_rank > 0:
            set_process_title("EngineCore", f"DP{dp_rank}")
            decorate_logs()
            # Set data parallel rank for this engine process.
            parallel_config.data_parallel_rank = dp_rank
            parallel_config.data_parallel_rank_local = local_dp_rank
            engine_core = BalanceDPEngineCoreProc(*args, **kwargs)
        else:
            set_process_title("EngineCore")
            decorate_logs()
            engine_core = EngineCoreProc(*args, **kwargs)

        engine_core.run_busy_loop()

    except SystemExit:
        logger.debug("EngineCore exiting.")
        raise
    except Exception as e:
        if engine_core is None:
            logger.exception("EngineCore failed to start.")
        else:
            logger.exception("EngineCore encountered a fatal error.")
            engine_core._send_engine_dead()
        raise e
    finally:
        if engine_core is not None:
            engine_core.shutdown()


EngineCoreProc.run_engine_core = run_engine_core
vllm.v1.core.sched.scheduler.Scheduler = BalanceScheduler
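
The two assignments that close this file apply the patch by rebinding attributes at import time. The following is a minimal, standalone sketch of that pattern and one of its caveats. It does not use vLLM: `sched_module`, `Scheduler`, `BalancedScheduler`, `EngineProc`, and `patched_run` are hypothetical stand-ins introduced only for illustration.

```python
import types

# Hypothetical stand-in for a third-party module (not vLLM).
sched_module = types.ModuleType("sched_module")


class Scheduler:
    def schedule(self) -> str:
        return "default"


sched_module.Scheduler = Scheduler

# A consumer that bound the class by name before the patch ran,
# as `from sched_module import Scheduler` would do.
early_binding = sched_module.Scheduler


class BalancedScheduler(Scheduler):
    def schedule(self) -> str:
        return "balanced"


# Patch 1: rebind the module attribute to the subclass.
sched_module.Scheduler = BalancedScheduler

# Lookups through the module after the patch see the subclass...
late_result = sched_module.Scheduler().schedule()
# ...but a name bound before the patch still refers to the original class.
early_result = early_binding().schedule()


class EngineProc:
    @staticmethod
    def run(rank: int = 0) -> str:
        return f"original-{rank}"


def patched_run(rank: int = 0) -> str:
    return f"patched-{rank}"


# Patch 2: replace a class-level entry point with a plain function.
# When accessed through the class (not an instance), it is called as-is.
EngineProc.run = patched_run
patched_result = EngineProc.run(rank=1)
```

In this sketch, rebinding a module attribute only affects code that resolves the name through the module after the patch has been applied, so a patch of this kind generally needs to be imported before any consumer binds the original class by name.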