From 9976e685b762e97cc2122eb16c4c8a6af07b46c6 Mon Sep 17 00:00:00 2001
From: Levi <54832289+Levi-JQ@users.noreply.github.com>
Date: Mon, 23 Mar 2026 17:05:02 +0800
Subject: [PATCH] [Bugfix][eager][oom] fix rank0 load imbalance by no padding
when multi dp (#7297)
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit
### What this PR does / why we need it?
Fix multi dp padding logic for eager mode, bacause its will cause rank0
load imbalance in kimi-k2.5-w4a8 with the all the padding tokens router
to rank0. And the fix can also apply to other model in multi dp.
- before
hbm usage:
preformance:
```shell
Concurrency NumPrompts QPS TTFT_Avg TTFT_P50 TPOT_Avg TPOT_P50 TPOT_P90
============ ============ ============ ============ ============ ============ ============ ============
1 15 0.0179 1667.7803 1673.3437 35.2973 35.2775 35.3784
32 480 0.4725 2764.8027 1905.2137 40.8030 40.6978 41.0179
64 960 0.7820 4123.7096 3485.6153 48.0461 48.1598 48.2971
100 1500 1.0852 6216.7988 5714.0082 52.9323 53.0613 54.6304
108 1620 1.1040 6277.4892 5798.7425 56.3862 56.9224 57.2901
116 1740 1.1680 6563.3293 6039.5659 56.9894 57.4027 57.5786
128 1920 1.2555 7822.5551 7604.1662 57.7660 58.1768 58.2717
192 2880 1.4314 9212.1953 9131.3461 58.9905 59.1683 59.2791
256 3840 1.4480 9028.0812 8913.7937 59.0092 59.2385 59.3516
```
- after
hbm usage:
preformance:
```shell
Concurrency NumPrompts QPS TTFT_Avg TTFT_P50 TPOT_Avg TPOT_P50 TPOT_P90
============ ============ ============ ============ ============ ============ ============ ============
1 15 0.0181 601.4171 600.9774 35.6270 35.6254 35.6480
32 480 0.4455 720.8782 724.2889 45.4250 45.4755 45.6318
64 960 0.8445 729.6209 728.2149 47.0464 47.0896 47.1985
100 1500 1.2601 723.4834 724.6673 48.3108 48.3844 48.5355
108 1620 1.3409 727.1509 720.6772 48.8962 48.9409 49.0489
116 1740 1.4080 679.9799 677.6119 49.1253 49.1983 49.3087
128 1920 1.4155 680.6284 674.9436 49.2193 49.2450 49.3763
192 2880 1.4422 684.6577 676.7833 49.2059 49.2264 49.3229
256 3840 1.4558 685.2462 678.1709 49.2191 49.2351 49.3419
```
### Does this PR introduce _any_ user-facing change?
### How was this patch tested?
- vLLM version: v0.17.0
- vLLM main:
https://github.com/vllm-project/vllm/commit/4034c3d32e30d01639459edd3ab486f56993876d
---------
Signed-off-by: Levi-JQ
Co-authored-by: Levi-JQ
Co-authored-by: fny-coder <985619145@qq.com>
---
vllm_ascend/ascend_forward_context.py | 4 +++-
vllm_ascend/worker/model_runner_v1.py | 18 +++++++++++++-----
2 files changed, 16 insertions(+), 6 deletions(-)
diff --git a/vllm_ascend/ascend_forward_context.py b/vllm_ascend/ascend_forward_context.py
index d013be34..582e4957 100644
--- a/vllm_ascend/ascend_forward_context.py
+++ b/vllm_ascend/ascend_forward_context.py
@@ -69,7 +69,9 @@ def set_ascend_forward_context(
from vllm_ascend.ops.fused_moe.moe_comm_method import get_moe_comm_method
- moe_comm_type = select_moe_comm_method(num_tokens, vllm_config, is_draft_model)
+ max_num_tokens = int(num_tokens_across_dp.max().item()) if num_tokens_across_dp is not None else num_tokens
+ moe_comm_type = select_moe_comm_method(max_num_tokens, vllm_config, is_draft_model)
+
forward_context.moe_comm_type = moe_comm_type
forward_context.moe_comm_method = get_moe_comm_method(moe_comm_type)
diff --git a/vllm_ascend/worker/model_runner_v1.py b/vllm_ascend/worker/model_runner_v1.py
index 4d97c1a1..f89404ef 100644
--- a/vllm_ascend/worker/model_runner_v1.py
+++ b/vllm_ascend/worker/model_runner_v1.py
@@ -1235,6 +1235,7 @@ class NPUModelRunner(GPUModelRunner):
num_scheduled_tokens_np=num_scheduled_tokens_np,
max_num_scheduled_tokens=max_num_scheduled_tokens,
use_cascade_attn=cascade_attn_prefix_lens is not None,
+ force_eager=self.model_config.enforce_eager,
num_encoder_reqs=len(scheduler_output.scheduled_encoder_inputs),
)
@@ -1853,6 +1854,7 @@ class NPUModelRunner(GPUModelRunner):
self,
num_tokens_padded: int | None = None,
cudagraph_mode: int = 0,
+ allow_dp_padding: bool = False,
) -> tuple[bool, torch.Tensor | None, int]:
"""
Coordinates amongst all DP ranks to determine if and how the full batch
@@ -1896,11 +1898,16 @@ class NPUModelRunner(GPUModelRunner):
num_tokens_across_dp = tensor[0, :]
max_num_tokens = int(num_tokens_across_dp.max().item())
- num_tokens_after_padding = torch.tensor(
- [max_num_tokens] * len(num_tokens_across_dp),
- device="cpu",
- dtype=torch.int32,
- )
+
+ if allow_dp_padding:
+ num_tokens_after_padding = torch.tensor(
+ [max_num_tokens] * len(num_tokens_across_dp),
+ device="cpu",
+ dtype=torch.int32,
+ )
+ else:
+ num_tokens_after_padding = num_tokens_across_dp.cpu()
+
# Synchronize cudagraph_mode across ranks (take min)
synced_cudagraph_mode = _post_process_cudagraph_mode(tensor)
return False, num_tokens_after_padding, synced_cudagraph_mode
@@ -1969,6 +1976,7 @@ class NPUModelRunner(GPUModelRunner):
_, num_tokens_across_dp, synced_cudagraph_mode = self._sync_batch_across_dp(
num_tokens_padded=num_tokens_padded,
cudagraph_mode=cudagraph_mode.value,
+ allow_dp_padding=cudagraph_mode != CUDAGraphMode.NONE,
)
# Extract DP padding if there is any