[Fix] Fix DP-related padding logic (#2582)

### What this PR does / why we need it?
The determination of attention state, padding, and other forward
metadata has been moved to an earlier stage within the input preparation
process. This change enables us to utilize a single all-reduce
operation, maximizing synchronization efficiency as early as possible.

The logic for synchronizing metadata—such as the number of tokens,
prefill status, and DBO status—across data parallel (DP) ranks has now
been unified and simplified.

For performance improvements, the all-reduce operation has been switched
from the `gloo` backend to the `npu` backend, which results in a
reduction of several milliseconds per step (**approximately 10%
performance gain for TPOT!**).

Additionally, the multi-DP server hang issue has been resolved, ensuring
no more hangs occur when `num_requests < dp_size`. At last, a relief.

Finally, the miscalculated memory usage issue has been addressed by
removing the unnecessary `DummyCommImpl`, allowing the system to use the
real communication method when determining available memory.

### Does this PR introduce _any_ user-facing change?
None.

### How was this patch tested?
Maybe we should add a test case for the multi-DP online server?
@MengqingCao


- vLLM version: v0.10.1.1
- vLLM main:
c5d004aaaf

---------

Signed-off-by: Yizhou Liu <liu_yizhou@outlook.com>
This commit is contained in:
yiz-liu
2025-08-28 19:39:58 +08:00
committed by GitHub
parent 175f6bc445
commit dfc7eb39ad
5 changed files with 110 additions and 160 deletions

View File

@@ -94,43 +94,6 @@ class MoECommMethod(ABC):
pass
class DummyCommImpl(MoECommMethod):
def prepare(
self, hidden_states: torch.Tensor,
router_logits: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
"""Dummy prepare method that does nothing."""
return hidden_states, router_logits
def finalize(self, hidden_states: torch.Tensor,
reduce_results: bool) -> torch.Tensor:
"""Dummy finalize method that does nothing."""
return hidden_states
def permute(
self,
hidden_states: torch.Tensor,
topk_ids: torch.Tensor,
topk_weights: torch.Tensor,
expert_map: torch.Tensor,
num_experts: int,
) -> tuple[torch.Tensor, torch.Tensor, int]:
"""Dummy implementation, make sure the output shapes are correct."""
top_k_num = topk_ids.shape[1]
permuted_hidden_states = hidden_states.repeat_interleave(top_k_num,
dim=0)
expert_tokens = torch.zeros((num_experts, ),
dtype=torch.int64,
device=hidden_states.device)
group_list_type = 0
return permuted_hidden_states, expert_tokens, group_list_type
def unpermute(self, mlp_output: torch.Tensor,
hidden_states: torch.Tensor) -> None:
"""Dummy implementation that does nothing."""
pass
class AllGatherCommImpl(MoECommMethod):
"""This implementation is the same as NativeAllGatherCommImpl,
but uses NPU-specific ops for better performance.

View File

@@ -26,7 +26,6 @@ from vllm.model_executor.layers.fused_moe.layer import (
from vllm_ascend.ascend_config import get_ascend_config
from vllm_ascend.distributed.moe_comm_method import (AllGatherCommImpl,
DummyCommImpl,
MC2CommImpl,
MoECommMethod)
from vllm_ascend.distributed.parallel_state import get_mc2_group
@@ -230,7 +229,7 @@ class AscendFusedMoE(FusedMoE):
self.moe_config.ep_group = get_ep_group()
self.moe_config.mc2_group = get_mc2_group()
for method in {AllGatherCommImpl, DummyCommImpl, MC2CommImpl}:
for method in {AllGatherCommImpl, MC2CommImpl}:
setattr(
self, method.__name__.lower(),
method(moe_config=self.moe_config)) # type: ignore[abstract]
@@ -241,8 +240,11 @@ class AscendFusedMoE(FusedMoE):
forward_context = get_forward_context()
moe_comm_method_name = forward_context.moe_comm_method_name
if not self.moe_config.use_ep and moe_comm_method_name != "dummycommimpl":
# TODO: Can we refactor this logic to model_runner?
if not self.moe_config.use_ep:
moe_comm_method_name = "allgathercommimpl"
forward_context.moe_comm_method = getattr(self, moe_comm_method_name)
hidden_states, router_logits = forward_context.moe_comm_method.prepare(

View File

@@ -70,7 +70,7 @@ class NPUTorchairModelRunner(NPUModelRunner):
register_torchair_model()
torchair_quant_method_register()
def _get_forward_metadata_across_dp_and_pad(
def _sync_metadata_across_dp(
self, num_tokens: int, with_prefill: bool, enable_dbo: bool
) -> tuple[int, Optional[torch.Tensor], bool, bool]:
"""Override from NPUModelRunner to pad num_tokens"""
@@ -81,8 +81,17 @@ class NPUTorchairModelRunner(NPUModelRunner):
return maybe_padded_num_tokens, None, with_prefill, enable_dbo
return num_tokens, None, with_prefill, enable_dbo
num_tokens_across_dp, with_prefill, enable_dbo = self._get_forward_metadata_across_dp(
num_tokens, with_prefill, enable_dbo)
num_tokens_across_dp = torch.zeros(self.dp_size + 2,
dtype=torch.int32,
device="npu")
num_tokens_across_dp[self.dp_rank] = num_tokens
num_tokens_across_dp[-2] = int(with_prefill)
num_tokens_across_dp[-1] = int(not enable_dbo)
dist.all_reduce(num_tokens_across_dp,
group=get_dp_group().device_group)
with_prefill = bool(num_tokens_across_dp[-2])
enable_dbo = not bool(num_tokens_across_dp[-1])
num_tokens_across_dp = num_tokens_across_dp[:-2]
if not with_prefill:
max_num_token = num_tokens_across_dp.max().item()

View File

@@ -43,8 +43,7 @@ from vllm.distributed.kv_transfer.kv_connector.v1 import KVConnectorBase_V1
from vllm.distributed.parallel_state import (get_dp_group, get_pp_group,
get_tp_group,
is_global_first_rank)
from vllm.forward_context import (BatchDescriptor, DPMetadata,
get_forward_context)
from vllm.forward_context import BatchDescriptor, get_forward_context
from vllm.logger import logger
from vllm.model_executor.layers.fused_moe import FusedMoE
from vllm.model_executor.layers.rotary_embedding import MRotaryEmbedding
@@ -373,10 +372,6 @@ class NPUModelRunner(LoRAModelRunnerMixin):
device=self.device,
)
self.moe_comm_method = "mc2"
self.fallback_moe_comm_method = "allgather"
self.dummy_moe_comm_method = "dummy"
def _use_aclgraph(self) -> bool:
return self.compilation_config.cudagraph_mode != CUDAGraphMode.NONE and self.compilation_config.level == CompilationLevel.PIECEWISE and not self.model_config.enforce_eager
@@ -594,32 +589,43 @@ class NPUModelRunner(LoRAModelRunnerMixin):
# Refresh batch metadata with any pending updates.
self.input_batch.refresh_metadata()
def _get_forward_metadata_across_dp(
self, num_tokens: int, with_prefill: bool,
enable_dbo: bool) -> tuple[torch.Tensor, bool, bool]:
# Compose: all_reduce metadata (num_tokens of each rank, with_prefill, enable_dbo)
num_tokens_across_dp = torch.zeros(self.dp_size + 2,
dtype=torch.int32,
device="cpu")
num_tokens_across_dp[self.dp_rank] = num_tokens
num_tokens_across_dp[-2] = int(with_prefill)
num_tokens_across_dp[-1] = int(not enable_dbo)
dist.all_reduce(num_tokens_across_dp, group=get_dp_group().cpu_group)
with_prefill = bool(num_tokens_across_dp[-2])
enable_dbo = not bool(num_tokens_across_dp[-1])
num_tokens_across_dp = num_tokens_across_dp[:-2]
return num_tokens_across_dp, with_prefill, enable_dbo
def _get_forward_metadata_across_dp_and_pad(
def _sync_metadata_across_dp(
self, num_tokens: int, with_prefill: bool, enable_dbo: bool
) -> tuple[int, Optional[torch.Tensor], bool, bool]:
if self.dp_size == 1:
if self.dp_size == 1 or self.vllm_config.model_config.enforce_eager:
return num_tokens, None, with_prefill, enable_dbo
num_tokens_across_dp, with_prefill, enable_dbo = self._get_forward_metadata_across_dp(
num_tokens, with_prefill, enable_dbo)
return num_tokens, num_tokens_across_dp, with_prefill, enable_dbo
# Sync num_tokens, with_prefill, enable_dbo across dp ranks
num_tokens_tensor = torch.tensor([
num_tokens if i == self.dp_rank else 0 for i in range(self.dp_size)
],
dtype=torch.int32,
device="npu")
flags_tensor = torch.tensor(
[int(with_prefill), int(not enable_dbo)],
dtype=torch.int32,
device="npu")
packed_tensor = torch.cat([num_tokens_tensor, flags_tensor])
dist.all_reduce(packed_tensor, group=get_dp_group().device_group)
# Unpack the results
num_tokens_across_dp = packed_tensor[:-2]
synced_flags = packed_tensor[-2:]
max_tokens_across_dp = torch.max(num_tokens_across_dp).item()
global_with_prefill = bool(synced_flags[0])
global_enable_dbo = not bool(synced_flags[1])
# Create a tensor for num_tokens_after_padding
num_tokens_after_padding = torch.tensor([max_tokens_across_dp] *
self.dp_size,
device="npu",
dtype=torch.int32)
return max_tokens_across_dp, num_tokens_after_padding, global_with_prefill, global_enable_dbo
def _check_dbo_is_valid(self, query_lens: torch.Tensor,
attn_state: AscendAttentionState,
@@ -1025,32 +1031,6 @@ class NPUModelRunner(LoRAModelRunnerMixin):
mm_embeds.append(mm_embeds_item)
return mm_embeds
def get_dp_padding(self,
num_tokens: int) -> tuple[int, Optional[torch.Tensor]]:
"""This implementation is derived from vLLM's `GPUModelRunner.get_dp_padding`.
Please note that vLLM may refactor or modify this function over time,
at present, we are using the version introduced in PR #18935.
"""
dp_size = self.vllm_config.parallel_config.data_parallel_size
dp_rank = self.vllm_config.parallel_config.data_parallel_rank
# For DP: Don't pad when setting enforce_eager.
# This lets us set enforce_eager on the prefiller in a P/D setup and
# still use ACL graphs (enabled by this padding) on the decoder.
if dp_size == 1 or self.vllm_config.model_config.enforce_eager:
# Early exit.
return 0, None
num_tokens_across_dp = DPMetadata.num_tokens_across_dp(
num_tokens, dp_size, dp_rank)
max_tokens_across_dp_cpu = torch.max(num_tokens_across_dp).item()
num_tokens_after_padding = torch.tensor([max_tokens_across_dp_cpu] *
dp_size,
device="cpu",
dtype=torch.int32)
return max_tokens_across_dp_cpu - num_tokens, num_tokens_after_padding
def _prepare_inputs(
self,
scheduler_output: "SchedulerOutput",
@@ -1060,24 +1040,10 @@ class NPUModelRunner(LoRAModelRunnerMixin):
torch.Tensor, int, torch.Tensor, SpecDecodeMetadata,
Optional[torch.Tensor], Optional[torch.Tensor],
Optional[torch.Tensor]]:
# Check input valid
total_num_scheduled_tokens = scheduler_output.total_num_scheduled_tokens
assert total_num_scheduled_tokens > 0
num_reqs = self.input_batch.num_reqs
assert num_reqs > 0
if (self.use_aclgraph and total_num_scheduled_tokens
<= self.aclgraph_batch_sizes[-1]):
# Add padding to the batch size.
num_input_tokens = self.vllm_config.pad_for_cudagraph(
total_num_scheduled_tokens)
else:
# Eager mode.
num_input_tokens = total_num_scheduled_tokens
# Padding for DP
num_pad, num_tokens_across_dp_native = self.get_dp_padding(
num_input_tokens)
num_input_tokens += num_pad
self.attn_metadata_builder.reorder_batch(self.input_batch,
scheduler_output)
@@ -1098,6 +1064,41 @@ class NPUModelRunner(LoRAModelRunnerMixin):
max_num_scheduled_tokens = max(max_num_scheduled_tokens,
num_tokens)
if (self.use_aclgraph and total_num_scheduled_tokens
<= self.aclgraph_batch_sizes[-1]):
# Add padding to the batch size.
num_input_tokens = self.vllm_config.pad_for_cudagraph(
total_num_scheduled_tokens)
else:
# Eager mode.
num_input_tokens = total_num_scheduled_tokens
# Get the attention state.
attn_state = self._build_attn_state(num_reqs, num_scheduled_tokens,
num_valid_tokens)
self.attn_state = attn_state # type: ignore
# Determine if it's a splitfuse batch
with_prefill = attn_state not in [
AscendAttentionState.DecodeOnly, AscendAttentionState.SpecDecoding
]
self.query_lens = torch.from_numpy(num_scheduled_tokens)
enable_dbo = self._check_dbo_is_valid(self.query_lens.tolist(),
attn_state,
total_num_scheduled_tokens)
# Get info across DP ranks.
# NOTE: maybe_padded_num_tokens is only used when using TorchAir with DP,
# Otherwise, it's just max_tokens_across_dp_cpu
(maybe_padded_num_tokens, num_tokens_across_dp, with_prefill,
enable_dbo) = self._sync_metadata_across_dp(num_input_tokens,
with_prefill, enable_dbo)
if self.use_aclgraph:
# When using TorchAir with DP, we have other plans for padding
num_input_tokens = maybe_padded_num_tokens
# Hot-Swap lora model
if self.lora_config:
self.set_active_loras(self.input_batch, num_scheduled_tokens)
@@ -1166,20 +1167,9 @@ class NPUModelRunner(LoRAModelRunnerMixin):
self.seq_lens[num_reqs:].fill_(0)
self.query_start_loc[num_reqs + 1:].fill_(-1)
with_prefill = attn_state not in [
AscendAttentionState.DecodeOnly, AscendAttentionState.SpecDecoding
]
enable_dbo = self._check_dbo_is_valid(self.query_lens.tolist(),
attn_state,
total_num_scheduled_tokens)
(padded_num_tokens_across_dp, num_tokens_across_dp, with_prefill,
enable_dbo) = self._get_forward_metadata_across_dp_and_pad(
total_num_scheduled_tokens, with_prefill, enable_dbo)
self.with_prefill = with_prefill
self.num_tokens_across_dp = num_tokens_across_dp
self._update_graph_pad_size(with_prefill, padded_num_tokens_across_dp)
self._update_graph_pad_size(with_prefill, maybe_padded_num_tokens)
common_attn_metadata = AscendCommonAttentionMetadata(
query_start_loc=self.query_start_loc[:num_reqs + 1],
query_start_loc_cpu=self.query_start_loc_cpu[:num_reqs + 1],
@@ -1247,7 +1237,7 @@ class NPUModelRunner(LoRAModelRunnerMixin):
positions = self.positions[:num_input_tokens]
input_ids, positions = self._update_input_ids_and_positions(
input_ids, positions, num_input_tokens, with_prefill,
padded_num_tokens_across_dp)
maybe_padded_num_tokens)
if get_pp_group().is_first_rank:
intermediate_tensors = None
@@ -1262,14 +1252,6 @@ class NPUModelRunner(LoRAModelRunnerMixin):
for k, v in self.intermediate_tensors.items()
})
# NOTE: Currently this padding logic is really messy,
# MC2 may not be available in eager mode
# TODO: Unify the padding logic between TorchAir and ACL Graph ASAP
if self.use_aclgraph:
num_tokens_across_dp = num_tokens_across_dp_native
else:
num_input_tokens = padded_num_tokens_across_dp
use_spec_decode = len(
scheduler_output.scheduled_spec_decode_tokens) > 0
if not use_spec_decode:
@@ -1297,12 +1279,11 @@ class NPUModelRunner(LoRAModelRunnerMixin):
return (attn_metadata, positions, num_scheduled_tokens,
num_input_tokens, num_tokens_across_dp,
padded_num_tokens_across_dp, logits_indices,
spec_decode_metadata, input_ids, inputs_embeds,
intermediate_tensors)
maybe_padded_num_tokens, logits_indices, spec_decode_metadata,
input_ids, inputs_embeds, intermediate_tensors)
def _generate_process_reqs_hidden_states(self, attn_metadata, with_prefill,
padded_num_tokens_across_dp,
maybe_padded_num_tokens,
input_ids, positions,
intermediate_tensors,
inputs_embeds):
@@ -1345,7 +1326,7 @@ class NPUModelRunner(LoRAModelRunnerMixin):
def _update_input_ids_and_positions(self, input_ids, positions,
num_input_tokens, with_prefill,
padded_num_tokens_across_dp):
maybe_padded_num_tokens):
if self.uses_mrope:
positions = self.mrope_positions[:, :num_input_tokens]
return input_ids, positions
@@ -1632,6 +1613,10 @@ class NPUModelRunner(LoRAModelRunnerMixin):
kv_connector_output=kv_connector_output,
)
def _select_moe_comm_method(self, num_tokens: int) -> str:
return ("mc2"
if num_tokens <= self.mc2_tokens_capacity else "allgather")
@torch.inference_mode()
def execute_model(
self,
@@ -1649,15 +1634,13 @@ class NPUModelRunner(LoRAModelRunnerMixin):
return EMPTY_MODEL_RUNNER_OUTPUT
return self.kv_connector_no_forward(scheduler_output)
(attn_metadata, positions, num_scheduled_tokens_np,
num_input_tokens, num_tokens_across_dp,
padded_num_tokens_across_dp, logits_indices, spec_decode_metadata,
input_ids, inputs_embeds,
num_input_tokens, num_tokens_across_dp, maybe_padded_num_tokens,
logits_indices, spec_decode_metadata, input_ids, inputs_embeds,
intermediate_tensors) = (self._prepare_inputs(
scheduler_output, intermediate_tensors))
moe_comm_method = (self.moe_comm_method
if num_input_tokens <= self.mc2_tokens_capacity else
self.fallback_moe_comm_method)
moe_comm_method = self._select_moe_comm_method(num_input_tokens)
batch_descriptor = BatchDescriptor(num_tokens=num_input_tokens,
uniform_decode=False)
aclgraph_runtime_mode, batch_descriptor = \
@@ -1680,9 +1663,8 @@ class NPUModelRunner(LoRAModelRunnerMixin):
self.maybe_setup_kv_connector(scheduler_output)
hidden_states = self._generate_process_reqs_hidden_states(
attn_metadata, self.with_prefill,
padded_num_tokens_across_dp, input_ids, positions,
intermediate_tensors, inputs_embeds)
attn_metadata, self.with_prefill, maybe_padded_num_tokens,
input_ids, positions, intermediate_tensors, inputs_embeds)
self.maybe_wait_for_kv_save()
finished_sending, finished_recving = self.get_finished_kv_transfer(
@@ -1988,7 +1970,6 @@ class NPUModelRunner(LoRAModelRunnerMixin):
num_tokens: int,
with_prefill: bool = False,
is_torchair_compile: bool = False,
moe_comm_method: str = "dummy",
aclgraph_runtime_mode: CUDAGraphMode = CUDAGraphMode.NONE,
force_attention: bool = False,
uniform_decode: bool = False,
@@ -2003,13 +1984,10 @@ class NPUModelRunner(LoRAModelRunnerMixin):
)
# Padding for DP
num_pad, num_tokens_across_dp_native = self.get_dp_padding(num_tokens)
# num_tokens += num_pad ## Uncomment this after TorchAir is removed
# Padding for DP (for TorchAir)
(num_tokens, num_tokens_across_dp, with_prefill,
_) = self._get_forward_metadata_across_dp_and_pad(
num_tokens, with_prefill, False)
_) = self._sync_metadata_across_dp(num_tokens, with_prefill, False)
moe_comm_method = self._select_moe_comm_method(num_tokens)
# If cudagraph_mode.decode_mode() == FULL and
# cudagraph_mode.seperate_routine(). This means that we are using
@@ -2518,12 +2496,10 @@ class NPUModelRunner(LoRAModelRunnerMixin):
self._dummy_run(num_tokens,
aclgraph_runtime_mode=CUDAGraphMode.NONE,
force_attention=force_attention,
uniform_decode=uniform_decode,
moe_comm_method=self.moe_comm_method)
uniform_decode=uniform_decode)
self._dummy_run(num_tokens,
aclgraph_runtime_mode=aclgraph_runtime_mode,
uniform_decode=uniform_decode,
moe_comm_method=self.moe_comm_method)
uniform_decode=uniform_decode)
def _capture_model(self):
if not self.use_aclgraph:

View File

@@ -194,7 +194,7 @@ class MtpProposer:
# torch mode need to update num_tokens_across_dp
# TODO: adapt enable_dbo later
(num_input_tokens, num_tokens_across_dp, with_prefill,
_) = self.runner._get_forward_metadata_across_dp_and_pad(
_) = self.runner._sync_metadata_across_dp(
num_tokens, self.runner.with_prefill, False)
attn_metadata.slot_mapping = target_slot_mapping
else:
@@ -281,8 +281,8 @@ class MtpProposer:
if not self.torchair_graph_enabled:
# TODO: adapt enable_dbo later
(num_tokens, num_tokens_across_dp, with_prefill,
_) = self.runner._get_forward_metadata_across_dp_and_pad(
num_tokens, with_prefill, False)
_) = self.runner._sync_metadata_across_dp(num_tokens,
with_prefill, False)
is_running_torchair = self.torchair_graph_enabled and \
not with_prefill