From 7c9aa498d6379016a123ba5737c4f9e76211dd39 Mon Sep 17 00:00:00 2001
From: linfeng-yuan <1102311262@qq.com>
Date: Thu, 9 Apr 2026 16:51:17 +0800
Subject: [PATCH] [releases/v0.18.0][BugFix] Restore global_bs=0 and mc2_mask
 for uniform-token dispatching and support inter-node RoCE hierarchical MC2
 communication (#8040)
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

### What this PR does / why we need it?

Cherry-picked from #8039

Restore the `global_bs=0` setting and the `mc2_mask` handling for MC2 when the
`all_reduce` across the DP group cannot be skipped. Ascend MC2 ops require
`global_bs=0` together with `mc2_mask` when inter-node RoCE hierarchical
communication is enabled. PR #4983 always passed a non-zero `global_bs` without
`mc2_mask`, which is incompatible with the hierarchical communication
introduced in PR #7583.

**Changes:**
- Add `should_skip_allreduce_across_dp_group()` to `utils.py` with the
  hierarchy constraint
- Set `global_bs=0` when the all-reduce is not skipped; pass `mc2_mask`
  accordingly
- Add an `mc2_mask` field to `MoEMC2CombineMetadata` for dispatch→combine
  propagation

### Does this PR introduce _any_ user-facing change?

No, but this PR fixes the cross-super-node communication path on A3 with
`enable_mc2_hierarchy_comm=True` in `additional_config` and
`export HCCL_INTRA_ROCE_ENABLE=1`.

### How was this patch tested?

E2E serving succeeded and CI passed.

- vLLM version: v0.18.0
- vLLM main: https://github.com/vllm-project/vllm/commit/14acf429ac08b6d538ca6feb3e06b6d13895804d

---------

Signed-off-by: linfeng-yuan <1102311262@qq.com>
---
 .../ops/fused_moe/moe_stage_contracts.py      |  1 +
 vllm_ascend/ops/fused_moe/token_dispatcher.py | 20 ++++++-
 vllm_ascend/utils.py                          | 56 +++++++++++++++++++
 vllm_ascend/worker/model_runner_v1.py         | 47 +---------------
 4 files changed, 78 insertions(+), 46 deletions(-)

diff --git a/vllm_ascend/ops/fused_moe/moe_stage_contracts.py b/vllm_ascend/ops/fused_moe/moe_stage_contracts.py
index 1e137498..b1392c73 100644
--- a/vllm_ascend/ops/fused_moe/moe_stage_contracts.py
+++ b/vllm_ascend/ops/fused_moe/moe_stage_contracts.py
@@ -92,6 +92,7 @@ class MoEMC2CombineMetadata:
     assist_info_for_combine: torch.Tensor
     expand_scales: torch.Tensor | None
     dispatch_with_quant: bool
+    mc2_mask: torch.Tensor | None = None
 
 
 @dataclass(frozen=True, slots=True)
diff --git a/vllm_ascend/ops/fused_moe/token_dispatcher.py b/vllm_ascend/ops/fused_moe/token_dispatcher.py
index 6112ec9e..4c0ef995 100644
--- a/vllm_ascend/ops/fused_moe/token_dispatcher.py
+++ b/vllm_ascend/ops/fused_moe/token_dispatcher.py
@@ -40,7 +40,12 @@ from vllm_ascend.ops.fused_moe.moe_runtime_args import (
     MoETokenDispatchOutput,
     TMoECombineMetadata,
 )
-from vllm_ascend.utils import AscendDeviceType, get_ascend_device_type, is_hierarchical_communication_enabled
+from vllm_ascend.utils import (
+    AscendDeviceType,
+    get_ascend_device_type,
+    is_hierarchical_communication_enabled,
+    should_skip_allreduce_across_dp_group,
+)
 
 
 class MoETokenDispatcher(ABC, Generic[TMoECombineMetadata]):
@@ -115,7 +120,13 @@ class TokenDispatcherWithMC2(MoETokenDispatcher[MoEMC2CombineMetadata]):
         else:
             max_num_tokens = min(max_num_reqs * uniform_decode_query_len, 512)
             num_tokens_per_tp_rank = (max_num_tokens + tp_size - 1) // tp_size
-        self.global_bs = num_tokens_per_tp_rank * self.ep_world_size
+        _max_global_bs = num_tokens_per_tp_rank * self.ep_world_size
+
+        # When allreduce across DP is not skipped, tokens are uniform across ranks:
+        # use global_bs=0 (uniform mode) and pass mc2_mask.
+        # When allreduce is skipped, tokens may differ per rank:
+        # use the real global_bs and do NOT pass mc2_mask.
+        self.global_bs = _max_global_bs if should_skip_allreduce_across_dp_group(vllm_config) else 0
 
         # NOTE: When enable_mc2_hierarchy_comm is true, we need pass in `comm_alg` to mc2 op.
         self.need_comm_alg = get_ascend_config().enable_mc2_hierarchy_comm
@@ -156,6 +167,8 @@ class TokenDispatcherWithMC2(MoETokenDispatcher[MoEMC2CombineMetadata]):
             "global_bs": self.global_bs,
             "expert_token_nums_type": 0,
         }
+        if self.global_bs == 0:
+            kwargs_mc2["x_active_mask"] = token_dispatch_input.routing.mc2_mask
 
         stage1_kwargs = {
             "scales": None,
@@ -228,6 +241,7 @@ class TokenDispatcherWithMC2(MoETokenDispatcher[MoEMC2CombineMetadata]):
                 assist_info_for_combine=assist_info_for_combine,
                 expand_scales=expand_scales,
                 dispatch_with_quant=token_dispatch_input.quant.dispatch_with_quant,
+                mc2_mask=token_dispatch_input.routing.mc2_mask if self.global_bs == 0 else None,
             ),
         )
 
@@ -251,6 +265,8 @@ class TokenDispatcherWithMC2(MoETokenDispatcher[MoEMC2CombineMetadata]):
             "moe_expert_num": self.moe_expert_num,
             "global_bs": self.global_bs,
         }
+        if self.global_bs == 0:
+            kwargs_mc2["x_active_mask"] = combine_metadata.mc2_mask
 
         if combine_metadata.dispatch_with_quant:
             tp_recv_counts = torch.empty(1, dtype=torch.int32, device=hidden_states.device)
diff --git a/vllm_ascend/utils.py b/vllm_ascend/utils.py
index 76b3802a..766bd526 100644
--- a/vllm_ascend/utils.py
+++ b/vllm_ascend/utils.py
@@ -997,6 +997,62 @@ def is_hierarchical_communication_enabled():
     ) or get_ascend_config().enable_mc2_hierarchy_comm
 
 
+def should_skip_allreduce_across_dp_group(vllm_config, is_draft_model: bool = False) -> bool:
+    """Decide whether to skip the all-reduce across the DP group.
+
+    Skipping is applicable for all dense models and for MoE models only on ranks
+    that act as KV consumers. We skip the DP all-reduce when either:
+    - Both the prefill and decode communication methods are MC2 (or FUSED_MC2), or
+    - Decode requires MC2 and ascend_config.recompute_scheduler_enable is True.
+
+    Skipping means each rank may have a different number of tokens, so MC2 needs
+    a non-zero global_bs and must NOT receive mc2_mask.
+
+    Returns False when hierarchical communication is enabled, because it requires
+    global_bs=0 (uniform tokens), which is incompatible with skipping the allreduce.
+    """
+    if is_hierarchical_communication_enabled():
+        return False
+
+    # For dense models, since we don't actually need dp communication, we simply skip it.
+    # This usually happens when the main model is MoE while the eagle draft model is dense.
+    is_context_moe_model = is_drafter_moe_model(vllm_config) if is_draft_model else is_moe_model(vllm_config)
+    if not is_context_moe_model:
+        return True
+
+    # Only applicable to MoE models on KV consumer ranks.
+    is_kv_consumer = vllm_config.kv_transfer_config is not None and vllm_config.kv_transfer_config.is_kv_consumer
+    if not is_kv_consumer:
+        return False
+
+    from vllm_ascend.ascend_forward_context import select_moe_comm_method
+    from vllm_ascend.ops.fused_moe.moe_comm_method import MoECommType
+
+    def needs_mc2(num_tokens: int) -> bool:
+        return select_moe_comm_method(num_tokens, vllm_config) in {MoECommType.MC2, MoECommType.FUSED_MC2}
+
+    compilation_config = vllm_config.compilation_config
+    scheduler_config = vllm_config.scheduler_config
+    speculative_config = vllm_config.speculative_config
+    uniform_decode_query_len = 1 if not speculative_config else 1 + speculative_config.num_speculative_tokens
+    decode_max_num_seqs = getattr(scheduler_config, "decode_max_num_seqs", 0)
+    max_num_reqs = max(scheduler_config.max_num_seqs, decode_max_num_seqs)
+
+    # Determine whether decode must use MC2. Use the max cudagraph capture size
+    # if available, otherwise use the maximal uniform decode token count.
+    if compilation_config.cudagraph_capture_sizes:
+        potential_max_tokens = compilation_config.max_cudagraph_capture_size
+    else:
+        potential_max_tokens = min(max_num_reqs * uniform_decode_query_len, 512)
+
+    decode_must_use_mc2 = needs_mc2(potential_max_tokens)
+    # For prefill, use the scheduler's max_num_batched_tokens for a single batch.
+    prefill_must_use_mc2 = needs_mc2(scheduler_config.max_num_batched_tokens)
+    # Skip the all-reduce if decode requires MC2 and either prefill also
+    # requires MC2 or the recompute-based scheduler is enabled.
+    return decode_must_use_mc2 and (prefill_must_use_mc2 or get_ascend_config().recompute_scheduler_enable)
+
+
 def has_layer_idx(model_instance: torch.nn.Module) -> bool:
     if model_instance is None:
         return False
diff --git a/vllm_ascend/worker/model_runner_v1.py b/vllm_ascend/worker/model_runner_v1.py
index 7b546ec1..0c600477 100644
--- a/vllm_ascend/worker/model_runner_v1.py
+++ b/vllm_ascend/worker/model_runner_v1.py
@@ -123,10 +123,9 @@ from vllm_ascend.utils import (
     enable_sp,
     enable_sp_by_pass,
     global_stream,
-    is_drafter_moe_model,
-    is_moe_model,
     lmhead_tp_enable,
     set_weight_prefetch_method,
+    should_skip_allreduce_across_dp_group,
 )
 from vllm_ascend.worker.npu_input_batch import NPUInputBatch
 from vllm_ascend.worker.pcp_utils import PCPManager
@@ -469,46 +468,6 @@ class NPUModelRunner(GPUModelRunner):
             and not self.model_config.enforce_eager
         )
 
-    def _skip_all_reduce_across_dp_group(self, is_draft_model=False) -> bool:
-        """
-        Decide whether to skip the all-reduce across the data-parallel (DP) group.
-
-        Skipping is applicable for all dense models and for moe models only on ranks
-        that act as KV consumers. We skip the DP all-reduce when either:
-        - Both the prefill and decode communication methods are MC2 (or FUSED_MC2), or
-        - Decode requires MC2 and ascend_config.recompute_scheduler_enable is True.
-        """
-        # For dense models, since we don't actually need dp communication, we simply skip it.
-        # This usually happens when main model is moe while eagle draft model is dense.
-        is_context_moe_model = (
-            is_drafter_moe_model(self.vllm_config) if is_draft_model else is_moe_model(self.vllm_config)
-        )
-        if not is_context_moe_model:
-            return True
-
-        # Only applicable to MoE models on KV consumer ranks.
-        if not self.is_kv_consumer:
-            return False
-
-        def needs_mc2(num_tokens: int) -> bool:
-            return select_moe_comm_method(num_tokens, self.vllm_config) in {MoECommType.MC2, MoECommType.FUSED_MC2}
-
-        # Determine whether decode must use MC2. Use max cudagraph capture size
-        # if available, otherwise use the maximal uniform decode token count.
-        if self.compilation_config.cudagraph_capture_sizes:
-            potential_max_tokens = self.compilation_config.max_cudagraph_capture_size
-        else:
-            potential_max_tokens = self.max_num_reqs * self.uniform_decode_query_len
-        decode_must_use_mc2 = needs_mc2(potential_max_tokens)
-
-        # For prefill, use the scheduler's max_num_batched_tokens for a single
-        # batch.
-        prefill_must_use_mc2 = needs_mc2(self.vllm_config.scheduler_config.max_num_batched_tokens)
-
-        # Skip all-reduce if decode requires MC2 and either prefill also
-        # requires MC2 or recompute-based scheduler is enabled.
-        return decode_must_use_mc2 and (prefill_must_use_mc2 or self.ascend_config.recompute_scheduler_enable)
-
     def _sync_metadata_across_dp(
         self, num_tokens: int, with_prefill: bool = False, is_draft_model: bool = False
     ) -> tuple[int, torch.Tensor | None, bool]:
@@ -521,7 +480,7 @@ class NPUModelRunner(GPUModelRunner):
 
         if self.dp_size == 1:
             return num_tokens, None, with_prefill
 
-        if self._skip_all_reduce_across_dp_group(is_draft_model):
+        if should_skip_allreduce_across_dp_group(self.vllm_config, is_draft_model):
             num_tokens_after_padding = torch.tensor([num_tokens] * self.dp_size, device="cpu", dtype=torch.int32)
             return num_tokens, num_tokens_after_padding, with_prefill
@@ -1891,7 +1850,7 @@ class NPUModelRunner(GPUModelRunner):
 
         if self.dp_size == 1:
             return False, None, cudagraph_mode
 
-        if self._skip_all_reduce_across_dp_group():
+        if should_skip_allreduce_across_dp_group(self.vllm_config):
             num_tokens_after_padding = torch.tensor([num_tokens_padded] * self.dp_size, device="cpu", dtype=torch.int32)
             return False, num_tokens_after_padding, cudagraph_mode
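
Reviewer note (editor's addition, not part of the patch): the dispatch rule the patch restores can be summarized by the minimal standalone sketch below. The names `MC2DispatchArgs`, `pick_mc2_dispatch_args`, and `skip_dp_allreduce` are hypothetical illustrations; in the actual code the decision is made by `should_skip_allreduce_across_dp_group()`, consumed by `TokenDispatcherWithMC2`, and `mc2_mask` is forwarded to the combine stage through `MoEMC2CombineMetadata`.

```python
# Illustrative sketch of the global_bs / mc2_mask selection rule restored by
# this patch. All names here are hypothetical stand-ins for the real logic in
# TokenDispatcherWithMC2 and should_skip_allreduce_across_dp_group().
from dataclasses import dataclass


@dataclass
class MC2DispatchArgs:
    # global_bs == 0 selects the uniform-token mode required by hierarchical
    # RoCE communication; a per-token active mask (mc2_mask) must then be set.
    global_bs: int
    x_active_mask: list[bool] | None


def pick_mc2_dispatch_args(
    skip_dp_allreduce: bool, max_global_bs: int, mc2_mask: list[bool]
) -> MC2DispatchArgs:
    if skip_dp_allreduce:
        # Ranks may hold different token counts: pass the real upper bound
        # and no mask.
        return MC2DispatchArgs(global_bs=max_global_bs, x_active_mask=None)
    # Token counts are uniform across ranks: use global_bs=0 plus mc2_mask.
    return MC2DispatchArgs(global_bs=0, x_active_mask=mc2_mask)


if __name__ == "__main__":
    mask = [True, True, False, False]
    print(pick_mc2_dispatch_args(False, 512, mask))  # global_bs=0, mask passed
    print(pick_mc2_dispatch_args(True, 512, mask))   # global_bs=512, no mask
```

This also shows why the patch stores `mc2_mask` in `MoEMC2CombineMetadata` only when `global_bs == 0`: dispatch and combine must agree on the mode, so the mask travels with the combine metadata exactly when the uniform-token mode is in effect.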