From dfc7eb39ada3f86f5c15425ba759ecfaa8f5c9a8 Mon Sep 17 00:00:00 2001
From: yiz-liu <136800916+yiz-liu@users.noreply.github.com>
Date: Thu, 28 Aug 2025 19:39:58 +0800
Subject: [PATCH] [Fix] Fix DP-related padding logic (#2582)
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

### What this PR does / why we need it?
The determination of attention state, padding, and other forward metadata has been moved to an earlier stage of the input preparation process. This lets us synchronize everything with a single all-reduce operation, as early as possible. The logic for synchronizing metadata across data-parallel (DP) ranks (the number of tokens, prefill status, and DBO status) is now unified and simplified; a standalone sketch of this packed synchronization appears after the diff.

For better performance, the all-reduce has been switched from the `gloo` backend to the `npu` backend, which results in a reduction of several milliseconds per step (**approximately 10% performance gain for TPOT!**).

Additionally, the multi-DP server hang has been resolved: the server no longer hangs when `num_requests < dp_size`. Alas, a relief.

Finally, the miscalculated memory usage has been addressed by removing the unnecessary `DummyCommImpl`, so the system uses the real communication method when determining available memory.

### Does this PR introduce _any_ user-facing change?
None.

### How was this patch tested?
Maybe we should add a test case for the multi-DP online server? @MengqingCao

- vLLM version: v0.10.1.1
- vLLM main: https://github.com/vllm-project/vllm/commit/c5d004aaaf3b2106d33974c673bec0568c18f762

---------

Signed-off-by: Yizhou Liu 
---
 vllm_ascend/distributed/moe_comm_method.py    |  37 ----
 vllm_ascend/ops/common_fused_moe.py           |   8 +-
 vllm_ascend/torchair/torchair_model_runner.py |  15 +-
 vllm_ascend/worker/model_runner_v1.py         | 204 ++++++++----------
 vllm_ascend/worker/mtp_proposer_v1.py         |   6 +-
 5 files changed, 110 insertions(+), 160 deletions(-)

diff --git a/vllm_ascend/distributed/moe_comm_method.py b/vllm_ascend/distributed/moe_comm_method.py
index 02f6d52..ea32495 100644
--- a/vllm_ascend/distributed/moe_comm_method.py
+++ b/vllm_ascend/distributed/moe_comm_method.py
@@ -94,43 +94,6 @@ class MoECommMethod(ABC):
         pass
 
 
-class DummyCommImpl(MoECommMethod):
-
-    def prepare(
-            self, hidden_states: torch.Tensor,
-            router_logits: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
-        """Dummy prepare method that does nothing."""
-        return hidden_states, router_logits
-
-    def finalize(self, hidden_states: torch.Tensor,
-                 reduce_results: bool) -> torch.Tensor:
-        """Dummy finalize method that does nothing."""
-        return hidden_states
-
-    def permute(
-        self,
-        hidden_states: torch.Tensor,
-        topk_ids: torch.Tensor,
-        topk_weights: torch.Tensor,
-        expert_map: torch.Tensor,
-        num_experts: int,
-    ) -> tuple[torch.Tensor, torch.Tensor, int]:
-        """Dummy implementation, make sure the output shapes are correct."""
-        top_k_num = topk_ids.shape[1]
-        permuted_hidden_states = hidden_states.repeat_interleave(top_k_num,
-                                                                 dim=0)
-        expert_tokens = torch.zeros((num_experts, ),
-                                    dtype=torch.int64,
-                                    device=hidden_states.device)
-        group_list_type = 0
-        return permuted_hidden_states, expert_tokens, group_list_type
-
-    def unpermute(self, mlp_output: torch.Tensor,
-                  hidden_states: torch.Tensor) -> None:
-        """Dummy implementation that does nothing."""
-        pass
-
-
 class AllGatherCommImpl(MoECommMethod):
     """This implementation is the same as NativeAllGatherCommImpl, but uses NPU-specific ops
     for better performance.

diff --git a/vllm_ascend/ops/common_fused_moe.py b/vllm_ascend/ops/common_fused_moe.py
index ffc1dea..72ee91b 100644
--- a/vllm_ascend/ops/common_fused_moe.py
+++ b/vllm_ascend/ops/common_fused_moe.py
@@ -26,7 +26,6 @@ from vllm.model_executor.layers.fused_moe.layer import (
 
 from vllm_ascend.ascend_config import get_ascend_config
 from vllm_ascend.distributed.moe_comm_method import (AllGatherCommImpl,
-                                                     DummyCommImpl,
                                                      MC2CommImpl,
                                                      MoECommMethod)
 from vllm_ascend.distributed.parallel_state import get_mc2_group
@@ -230,7 +229,7 @@ class AscendFusedMoE(FusedMoE):
         self.moe_config.ep_group = get_ep_group()
         self.moe_config.mc2_group = get_mc2_group()
 
-        for method in {AllGatherCommImpl, DummyCommImpl, MC2CommImpl}:
+        for method in {AllGatherCommImpl, MC2CommImpl}:
             setattr(
                 self, method.__name__.lower(),
                 method(moe_config=self.moe_config))  # type: ignore[abstract]
@@ -241,8 +240,11 @@ class AscendFusedMoE(FusedMoE):
         forward_context = get_forward_context()
         moe_comm_method_name = forward_context.moe_comm_method_name
 
-        if not self.moe_config.use_ep and moe_comm_method_name != "dummycommimpl":
+
+        # TODO: Can we refactor this logic to model_runner?
+        if not self.moe_config.use_ep:
             moe_comm_method_name = "allgathercommimpl"
+
         forward_context.moe_comm_method = getattr(self, moe_comm_method_name)
 
         hidden_states, router_logits = forward_context.moe_comm_method.prepare(

diff --git a/vllm_ascend/torchair/torchair_model_runner.py b/vllm_ascend/torchair/torchair_model_runner.py
index fb4f583..24fd33a 100644
--- a/vllm_ascend/torchair/torchair_model_runner.py
+++ b/vllm_ascend/torchair/torchair_model_runner.py
@@ -70,7 +70,7 @@ class NPUTorchairModelRunner(NPUModelRunner):
         register_torchair_model()
         torchair_quant_method_register()
 
-    def _get_forward_metadata_across_dp_and_pad(
+    def _sync_metadata_across_dp(
         self, num_tokens: int, with_prefill: bool, enable_dbo: bool
     ) -> tuple[int, Optional[torch.Tensor], bool, bool]:
         """Override from NPUModelRunner to pad num_tokens"""
@@ -81,8 +81,17 @@ class NPUTorchairModelRunner(NPUModelRunner):
                 return maybe_padded_num_tokens, None, with_prefill, enable_dbo
             return num_tokens, None, with_prefill, enable_dbo
 
-        num_tokens_across_dp, with_prefill, enable_dbo = self._get_forward_metadata_across_dp(
-            num_tokens, with_prefill, enable_dbo)
+        num_tokens_across_dp = torch.zeros(self.dp_size + 2,
+                                           dtype=torch.int32,
+                                           device="npu")
+        num_tokens_across_dp[self.dp_rank] = num_tokens
+        num_tokens_across_dp[-2] = int(with_prefill)
+        num_tokens_across_dp[-1] = int(not enable_dbo)
+        dist.all_reduce(num_tokens_across_dp,
+                        group=get_dp_group().device_group)
+        with_prefill = bool(num_tokens_across_dp[-2])
+        enable_dbo = not bool(num_tokens_across_dp[-1])
+        num_tokens_across_dp = num_tokens_across_dp[:-2]
 
         if not with_prefill:
             max_num_token = num_tokens_across_dp.max().item()

diff --git a/vllm_ascend/worker/model_runner_v1.py b/vllm_ascend/worker/model_runner_v1.py
index 468f59e..6930a1c 100644
--- a/vllm_ascend/worker/model_runner_v1.py
+++ b/vllm_ascend/worker/model_runner_v1.py
@@ -43,8 +43,7 @@ from vllm.distributed.kv_transfer.kv_connector.v1 import KVConnectorBase_V1
 from vllm.distributed.parallel_state import (get_dp_group, get_pp_group,
                                              get_tp_group,
                                              is_global_first_rank)
-from vllm.forward_context import (BatchDescriptor, DPMetadata,
-                                  get_forward_context)
+from vllm.forward_context import BatchDescriptor, get_forward_context
 from vllm.logger import logger
 from vllm.model_executor.layers.fused_moe import FusedMoE
 from vllm.model_executor.layers.rotary_embedding import MRotaryEmbedding
@@ -373,10 +372,6 @@ class NPUModelRunner(LoRAModelRunnerMixin):
             device=self.device,
         )
 
-        self.moe_comm_method = "mc2"
-        self.fallback_moe_comm_method = "allgather"
-        self.dummy_moe_comm_method = "dummy"
-
     def _use_aclgraph(self) -> bool:
         return self.compilation_config.cudagraph_mode != CUDAGraphMode.NONE and self.compilation_config.level == CompilationLevel.PIECEWISE and not self.model_config.enforce_eager
@@ -594,32 +589,43 @@ class NPUModelRunner(LoRAModelRunnerMixin):
         # Refresh batch metadata with any pending updates.
         self.input_batch.refresh_metadata()
 
-    def _get_forward_metadata_across_dp(
-            self, num_tokens: int, with_prefill: bool,
-            enable_dbo: bool) -> tuple[torch.Tensor, bool, bool]:
-
-        # Compose: all_reduce metadata (num_tokens of each rank, with_prefill, enable_dbo)
-        num_tokens_across_dp = torch.zeros(self.dp_size + 2,
-                                           dtype=torch.int32,
-                                           device="cpu")
-        num_tokens_across_dp[self.dp_rank] = num_tokens
-        num_tokens_across_dp[-2] = int(with_prefill)
-        num_tokens_across_dp[-1] = int(not enable_dbo)
-        dist.all_reduce(num_tokens_across_dp, group=get_dp_group().cpu_group)
-        with_prefill = bool(num_tokens_across_dp[-2])
-        enable_dbo = not bool(num_tokens_across_dp[-1])
-        num_tokens_across_dp = num_tokens_across_dp[:-2]
-        return num_tokens_across_dp, with_prefill, enable_dbo
-
-    def _get_forward_metadata_across_dp_and_pad(
+    def _sync_metadata_across_dp(
         self, num_tokens: int, with_prefill: bool, enable_dbo: bool
     ) -> tuple[int, Optional[torch.Tensor], bool, bool]:
-        if self.dp_size == 1:
+        if self.dp_size == 1 or self.vllm_config.model_config.enforce_eager:
             return num_tokens, None, with_prefill, enable_dbo
-        num_tokens_across_dp, with_prefill, enable_dbo = self._get_forward_metadata_across_dp(
-            num_tokens, with_prefill, enable_dbo)
-        return num_tokens, num_tokens_across_dp, with_prefill, enable_dbo
+
+        # Sync num_tokens, with_prefill, enable_dbo across dp ranks
+        num_tokens_tensor = torch.tensor([
+            num_tokens if i == self.dp_rank else 0 for i in range(self.dp_size)
+        ],
+                                         dtype=torch.int32,
+                                         device="npu")
+
+        flags_tensor = torch.tensor(
+            [int(with_prefill), int(not enable_dbo)],
+            dtype=torch.int32,
+            device="npu")
+
+        packed_tensor = torch.cat([num_tokens_tensor, flags_tensor])
+
+        dist.all_reduce(packed_tensor, group=get_dp_group().device_group)
+
+        # Unpack the results
+        num_tokens_across_dp = packed_tensor[:-2]
+        synced_flags = packed_tensor[-2:]
+
+        max_tokens_across_dp = torch.max(num_tokens_across_dp).item()
+        global_with_prefill = bool(synced_flags[0])
+        global_enable_dbo = not bool(synced_flags[1])
+
+        # Create a tensor for num_tokens_after_padding
+        num_tokens_after_padding = torch.tensor([max_tokens_across_dp] *
+                                                self.dp_size,
+                                                device="npu",
+                                                dtype=torch.int32)
+
+        return max_tokens_across_dp, num_tokens_after_padding, global_with_prefill, global_enable_dbo
 
     def _check_dbo_is_valid(self, query_lens: torch.Tensor,
                             attn_state: AscendAttentionState,
@@ -1025,32 +1031,6 @@ class NPUModelRunner(LoRAModelRunnerMixin):
                 mm_embeds.append(mm_embeds_item)
         return mm_embeds
 
-    def get_dp_padding(self,
-                       num_tokens: int) -> tuple[int, Optional[torch.Tensor]]:
-        """This implementation is derived from vLLM's `GPUModelRunner.get_dp_padding`.
-        Please note that vLLM may refactor or modify this function over time,
-        at present, we are using the version introduced in PR #18935.
-        """
-        dp_size = self.vllm_config.parallel_config.data_parallel_size
-        dp_rank = self.vllm_config.parallel_config.data_parallel_rank
-
-        # For DP: Don't pad when setting enforce_eager.
-        # This lets us set enforce_eager on the prefiller in a P/D setup and
-        # still use ACL graphs (enabled by this padding) on the decoder.
-
-        if dp_size == 1 or self.vllm_config.model_config.enforce_eager:
-            # Early exit.
-            return 0, None
-
-        num_tokens_across_dp = DPMetadata.num_tokens_across_dp(
-            num_tokens, dp_size, dp_rank)
-        max_tokens_across_dp_cpu = torch.max(num_tokens_across_dp).item()
-        num_tokens_after_padding = torch.tensor([max_tokens_across_dp_cpu] *
-                                                dp_size,
-                                                device="cpu",
-                                                dtype=torch.int32)
-        return max_tokens_across_dp_cpu - num_tokens, num_tokens_after_padding
-
     def _prepare_inputs(
         self,
         scheduler_output: "SchedulerOutput",
@@ -1060,24 +1040,10 @@ class NPUModelRunner(LoRAModelRunnerMixin):
                    torch.Tensor, int, torch.Tensor, SpecDecodeMetadata,
                    Optional[torch.Tensor], Optional[torch.Tensor],
                    Optional[torch.Tensor]]:
-        # Check input valid
         total_num_scheduled_tokens = scheduler_output.total_num_scheduled_tokens
         assert total_num_scheduled_tokens > 0
         num_reqs = self.input_batch.num_reqs
         assert num_reqs > 0
-        if (self.use_aclgraph and total_num_scheduled_tokens
-                <= self.aclgraph_batch_sizes[-1]):
-            # Add padding to the batch size.
-            num_input_tokens = self.vllm_config.pad_for_cudagraph(
-                total_num_scheduled_tokens)
-        else:
-            # Eager mode.
-            num_input_tokens = total_num_scheduled_tokens
-
-        # Padding for DP
-        num_pad, num_tokens_across_dp_native = self.get_dp_padding(
-            num_input_tokens)
-        num_input_tokens += num_pad
 
         self.attn_metadata_builder.reorder_batch(self.input_batch,
                                                  scheduler_output)
@@ -1098,6 +1064,41 @@ class NPUModelRunner(LoRAModelRunnerMixin):
             max_num_scheduled_tokens = max(max_num_scheduled_tokens,
                                            num_tokens)
 
+        if (self.use_aclgraph and total_num_scheduled_tokens
+                <= self.aclgraph_batch_sizes[-1]):
+            # Add padding to the batch size.
+            num_input_tokens = self.vllm_config.pad_for_cudagraph(
+                total_num_scheduled_tokens)
+        else:
+            # Eager mode.
+            num_input_tokens = total_num_scheduled_tokens
+
+        # Get the attention state.
+        attn_state = self._build_attn_state(num_reqs, num_scheduled_tokens,
+                                            num_valid_tokens)
+        self.attn_state = attn_state  # type: ignore
+
+        # Determine if it's a splitfuse batch
+        with_prefill = attn_state not in [
+            AscendAttentionState.DecodeOnly, AscendAttentionState.SpecDecoding
+        ]
+
+        self.query_lens = torch.from_numpy(num_scheduled_tokens)
+        enable_dbo = self._check_dbo_is_valid(self.query_lens.tolist(),
+                                              attn_state,
+                                              total_num_scheduled_tokens)
+
+        # Get info across DP ranks.
+        # NOTE: maybe_padded_num_tokens is only used when using TorchAir with DP,
+        # Otherwise, it's just max_tokens_across_dp_cpu
+        (maybe_padded_num_tokens, num_tokens_across_dp, with_prefill,
+         enable_dbo) = self._sync_metadata_across_dp(num_input_tokens,
+                                                     with_prefill, enable_dbo)
+
+        if self.use_aclgraph:
+            # When using TorchAir with DP, we have other plans for padding
+            num_input_tokens = maybe_padded_num_tokens
+
         # Hot-Swap lora model
         if self.lora_config:
             self.set_active_loras(self.input_batch, num_scheduled_tokens)
@@ -1166,20 +1167,9 @@ class NPUModelRunner(LoRAModelRunnerMixin):
         self.seq_lens[num_reqs:].fill_(0)
         self.query_start_loc[num_reqs + 1:].fill_(-1)
 
-        with_prefill = attn_state not in [
-            AscendAttentionState.DecodeOnly, AscendAttentionState.SpecDecoding
-        ]
-
-        enable_dbo = self._check_dbo_is_valid(self.query_lens.tolist(),
-                                              attn_state,
-                                              total_num_scheduled_tokens)
-
-        (padded_num_tokens_across_dp, num_tokens_across_dp, with_prefill,
-         enable_dbo) = self._get_forward_metadata_across_dp_and_pad(
-             total_num_scheduled_tokens, with_prefill, enable_dbo)
         self.with_prefill = with_prefill
         self.num_tokens_across_dp = num_tokens_across_dp
-        self._update_graph_pad_size(with_prefill, padded_num_tokens_across_dp)
+        self._update_graph_pad_size(with_prefill, maybe_padded_num_tokens)
         common_attn_metadata = AscendCommonAttentionMetadata(
             query_start_loc=self.query_start_loc[:num_reqs + 1],
             query_start_loc_cpu=self.query_start_loc_cpu[:num_reqs + 1],
@@ -1247,7 +1237,7 @@ class NPUModelRunner(LoRAModelRunnerMixin):
             positions = self.positions[:num_input_tokens]
         input_ids, positions = self._update_input_ids_and_positions(
             input_ids, positions, num_input_tokens, with_prefill,
-            padded_num_tokens_across_dp)
+            maybe_padded_num_tokens)
 
         if get_pp_group().is_first_rank:
             intermediate_tensors = None
@@ -1262,14 +1252,6 @@ class NPUModelRunner(LoRAModelRunnerMixin):
                     for k, v in self.intermediate_tensors.items()
                 })
 
-        # NOTE: Currently this padding logic is really messy,
-        # MC2 may not be available in eager mode
-        # TODO: Unify the padding logic between TorchAir and ACL Graph ASAP
-        if self.use_aclgraph:
-            num_tokens_across_dp = num_tokens_across_dp_native
-        else:
-            num_input_tokens = padded_num_tokens_across_dp
-
         use_spec_decode = len(
             scheduler_output.scheduled_spec_decode_tokens) > 0
         if not use_spec_decode:
@@ -1297,12 +1279,11 @@ class NPUModelRunner(LoRAModelRunnerMixin):
 
         return (attn_metadata, positions, num_scheduled_tokens,
                 num_input_tokens, num_tokens_across_dp,
-                padded_num_tokens_across_dp, logits_indices,
-                spec_decode_metadata, input_ids, inputs_embeds,
-                intermediate_tensors)
+                maybe_padded_num_tokens, logits_indices, spec_decode_metadata,
+                input_ids, inputs_embeds, intermediate_tensors)
 
     def _generate_process_reqs_hidden_states(self, attn_metadata, with_prefill,
-                                             padded_num_tokens_across_dp,
+                                             maybe_padded_num_tokens,
                                              input_ids, positions,
                                              intermediate_tensors,
                                              inputs_embeds):
@@ -1345,7 +1326,7 @@ class NPUModelRunner(LoRAModelRunnerMixin):
 
     def _update_input_ids_and_positions(self, input_ids, positions,
                                         num_input_tokens, with_prefill,
-                                        padded_num_tokens_across_dp):
+                                        maybe_padded_num_tokens):
         if self.uses_mrope:
             positions = self.mrope_positions[:, :num_input_tokens]
         return input_ids, positions
@@ -1632,6 +1613,10 @@ class NPUModelRunner(LoRAModelRunnerMixin):
             kv_connector_output=kv_connector_output,
         )
 
+    def _select_moe_comm_method(self, num_tokens: int) -> str:
+        return ("mc2"
+                if num_tokens <= self.mc2_tokens_capacity else "allgather")
+
     @torch.inference_mode()
     def execute_model(
         self,
@@ -1649,15 +1634,13 @@ class NPUModelRunner(LoRAModelRunnerMixin):
                     return EMPTY_MODEL_RUNNER_OUTPUT
                 return self.kv_connector_no_forward(scheduler_output)
             (attn_metadata, positions, num_scheduled_tokens_np,
-             num_input_tokens, num_tokens_across_dp,
-             padded_num_tokens_across_dp, logits_indices, spec_decode_metadata,
-             input_ids, inputs_embeds,
+             num_input_tokens, num_tokens_across_dp, maybe_padded_num_tokens,
+             logits_indices, spec_decode_metadata, input_ids, inputs_embeds,
              intermediate_tensors) = (self._prepare_inputs(
                  scheduler_output, intermediate_tensors))
 
-            moe_comm_method = (self.moe_comm_method
-                               if num_input_tokens <= self.mc2_tokens_capacity else
-                               self.fallback_moe_comm_method)
+            moe_comm_method = self._select_moe_comm_method(num_input_tokens)
+
             batch_descriptor = BatchDescriptor(num_tokens=num_input_tokens,
                                                uniform_decode=False)
             aclgraph_runtime_mode, batch_descriptor = \
@@ -1680,9 +1663,8 @@ class NPUModelRunner(LoRAModelRunnerMixin):
                 self.maybe_setup_kv_connector(scheduler_output)
 
                 hidden_states = self._generate_process_reqs_hidden_states(
-                    attn_metadata, self.with_prefill,
-                    padded_num_tokens_across_dp, input_ids, positions,
-                    intermediate_tensors, inputs_embeds)
+                    attn_metadata, self.with_prefill, maybe_padded_num_tokens,
+                    input_ids, positions, intermediate_tensors, inputs_embeds)
 
             self.maybe_wait_for_kv_save()
             finished_sending, finished_recving = self.get_finished_kv_transfer(
@@ -1988,7 +1970,6 @@ class NPUModelRunner(LoRAModelRunnerMixin):
                    num_tokens: int,
                    with_prefill: bool = False,
                    is_torchair_compile: bool = False,
-                   moe_comm_method: str = "dummy",
                    aclgraph_runtime_mode: CUDAGraphMode = CUDAGraphMode.NONE,
                    force_attention: bool = False,
                    uniform_decode: bool = False,
@@ -2003,13 +1984,10 @@ class NPUModelRunner(LoRAModelRunnerMixin):
         )
 
         # Padding for DP
-        num_pad, num_tokens_across_dp_native = self.get_dp_padding(num_tokens)
-        # num_tokens += num_pad  ## Uncomment this after TorchAir is removed
-
-        # Padding for DP (for TorchAir)
         (num_tokens, num_tokens_across_dp, with_prefill,
-         _) = self._get_forward_metadata_across_dp_and_pad(
-             num_tokens, with_prefill, False)
+         _) = self._sync_metadata_across_dp(num_tokens, with_prefill, False)
+
+        moe_comm_method = self._select_moe_comm_method(num_tokens)
 
         # If cudagraph_mode.decode_mode() == FULL and
         # cudagraph_mode.seperate_routine(). This means that we are using
@@ -2518,12 +2496,10 @@ class NPUModelRunner(LoRAModelRunnerMixin):
             self._dummy_run(num_tokens,
                             aclgraph_runtime_mode=CUDAGraphMode.NONE,
                             force_attention=force_attention,
-                            uniform_decode=uniform_decode,
-                            moe_comm_method=self.moe_comm_method)
+                            uniform_decode=uniform_decode)
         self._dummy_run(num_tokens,
                         aclgraph_runtime_mode=aclgraph_runtime_mode,
-                        uniform_decode=uniform_decode,
-                        moe_comm_method=self.moe_comm_method)
+                        uniform_decode=uniform_decode)
 
     def _capture_model(self):
         if not self.use_aclgraph:

diff --git a/vllm_ascend/worker/mtp_proposer_v1.py b/vllm_ascend/worker/mtp_proposer_v1.py
index 1ec1436..120b17a 100644
--- a/vllm_ascend/worker/mtp_proposer_v1.py
+++ b/vllm_ascend/worker/mtp_proposer_v1.py
@@ -194,7 +194,7 @@ class MtpProposer:
                 # torch mode need to update num_tokens_across_dp
                 # TODO: adapt enable_dbo later
                 (num_input_tokens, num_tokens_across_dp, with_prefill,
-                 _) = self.runner._get_forward_metadata_across_dp_and_pad(
+                 _) = self.runner._sync_metadata_across_dp(
                      num_tokens, self.runner.with_prefill, False)
                 attn_metadata.slot_mapping = target_slot_mapping
             else:
@@ -281,8 +281,8 @@ class MtpProposer:
         if not self.torchair_graph_enabled:
             # TODO: adapt enable_dbo later
             (num_tokens, num_tokens_across_dp, with_prefill,
-             _) = self.runner._get_forward_metadata_across_dp_and_pad(
-                 num_tokens, with_prefill, False)
+             _) = self.runner._sync_metadata_across_dp(num_tokens,
+                                                       with_prefill, False)
 
         is_running_torchair = self.torchair_graph_enabled and \
             not with_prefill
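
For reference, below is a minimal, self-contained sketch of the packed metadata synchronization described in the PR description (the pattern implemented by `_sync_metadata_across_dp`). It is illustrative only: the helper name `sync_dp_metadata`, its argument list, and the `device`/`group` defaults are assumptions for the example, not part of this patch; on Ascend the tensor would live on the NPU and the collective would run on the DP device group.

```python
# Illustrative sketch, not part of the patch: pack the per-rank token counts
# and two boolean flags into one tensor and synchronize them with a single
# all-reduce, instead of issuing several separate collectives per step.
import torch
import torch.distributed as dist


def sync_dp_metadata(num_tokens: int,
                     with_prefill: bool,
                     enable_dbo: bool,
                     dp_rank: int,
                     dp_size: int,
                     group=None,
                     device="cpu"):
    # Layout: [tokens_rank0, ..., tokens_rank{n-1}, with_prefill, dbo_disabled]
    packed = torch.zeros(dp_size + 2, dtype=torch.int32, device=device)
    packed[dp_rank] = num_tokens
    packed[-2] = int(with_prefill)
    packed[-1] = int(not enable_dbo)

    # One collective per step: the sum fills in every rank's token slot, and
    # summing 0/1 flags behaves like a logical OR once cast back to bool.
    dist.all_reduce(packed, group=group)

    num_tokens_across_dp = packed[:-2]
    max_tokens = int(num_tokens_across_dp.max().item())
    with_prefill = bool(packed[-2].item())
    enable_dbo = not bool(packed[-1].item())

    # Every rank pads to the global maximum so collective shapes stay uniform;
    # idle ranks (zero scheduled tokens) still participate, which is what
    # avoids the hang when num_requests < dp_size.
    num_tokens_after_padding = torch.full((dp_size, ),
                                          max_tokens,
                                          dtype=torch.int32,
                                          device=device)
    return max_tokens, num_tokens_after_padding, with_prefill, enable_dbo
```

Usage would mirror the model runner: each DP rank calls the helper once per step with its local token count, then pads its real or dummy batch to the returned maximum.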