From 09f71c14a666a74f167a8e5c2e5209e7ead65370 Mon Sep 17 00:00:00 2001 From: realliujiaxu Date: Sat, 27 Dec 2025 17:06:58 +0800 Subject: [PATCH] Revert "[feat] enable hierarchical mc2 ops on A2 by default (#5300)" (#5434) We'll release 0.13.0 soon. The main branch is frozen. Let's revert the newest change and redo it once 0.13.0 is released. - vLLM version: release/v0.13.0 - vLLM main: https://github.com/vllm-project/vllm/commit/81786c87748b0177111dfdc07af5351d8389baa1 Signed-off-by: realliujiaxu --- vllm_ascend/ops/fused_moe/token_dispatcher.py | 22 ++++++++++--------- vllm_ascend/utils.py | 8 +++++++ 2 files changed, 20 insertions(+), 10 deletions(-) diff --git a/vllm_ascend/ops/fused_moe/token_dispatcher.py b/vllm_ascend/ops/fused_moe/token_dispatcher.py index da185118..aeb751d0 100644 --- a/vllm_ascend/ops/fused_moe/token_dispatcher.py +++ b/vllm_ascend/ops/fused_moe/token_dispatcher.py @@ -31,7 +31,8 @@ from vllm.distributed.parallel_state import get_ep_group from vllm_ascend.distributed.parallel_state import get_mc2_group from vllm_ascend.ops.fused_moe.comm_utils import ( async_all_to_all, gather_from_sequence_parallel_region) -from vllm_ascend.utils import AscendDeviceType, get_ascend_device_type +from vllm_ascend.utils import (AscendDeviceType, get_ascend_device_type, + is_hierarchical_communication_enabled) class MoETokenDispatcher(ABC): @@ -100,6 +101,10 @@ class TokenDispatcherWithMC2(MoETokenDispatcher): self.need_extra_args = ( get_ascend_device_type() == AscendDeviceType.A3) + # NOTE: When in A2, setting the environment variables HCCL_INTRA_PCIE_ENABLE=1 and + # HCCL_INTRA_ROCE_ENABLE=0 can reduce cross-machine communication traffic and significantly + # improve communication performance. 
+ self.need_expert_scale = is_hierarchical_communication_enabled() self.with_quant = False # Here we need to calculate the global_bs = max_bs_per_rank * ep_world_size to execute @@ -137,7 +142,6 @@ class TokenDispatcherWithMC2(MoETokenDispatcher): else: quant_mode = 0 moe_expert_num = len(expert_map) - kwargs_mc2 = { "x": hidden_states, "expert_ids": topk_ids, @@ -146,12 +150,8 @@ class TokenDispatcherWithMC2(MoETokenDispatcher): "moe_expert_num": moe_expert_num, "global_bs": self.global_bs, "expert_token_nums_type": 0, - "expert_scales": topk_weights.to(torch.float32), } - if get_ascend_device_type() == AscendDeviceType.A2: - kwargs_mc2["comm_alg"] = "hierarchy" - stage1_kwargs = { "scales": None, "quant_mode": quant_mode, @@ -165,6 +165,11 @@ class TokenDispatcherWithMC2(MoETokenDispatcher): "tp_world_size": 1, "tp_rank_id": 0, }) + if self.need_expert_scale: + stage1_kwargs.update({ + "expert_scales": + topk_weights.to(torch.float32), + }) kwargs_mc2.update(stage1_kwargs) return kwargs_mc2 @@ -264,12 +269,8 @@ class TokenDispatcherWithMC2(MoETokenDispatcher): "shared_expert_rank_num": 0, "moe_expert_num": moe_expert_num, "global_bs": self.global_bs, - "expand_scales": expand_scales, } - if get_ascend_device_type() == AscendDeviceType.A2: - kwargs_mc2["comm_alg"] = "hierarchy" - if self.with_quant: tp_recv_counts = torch.empty(1, dtype=torch.int32, @@ -280,6 +281,7 @@ class TokenDispatcherWithMC2(MoETokenDispatcher): "group_ep": self.moe_all_to_all_group_name, "ep_world_size": self.ep_world_size, "ep_rank_id": self.ep_rank_id, + "expand_scales": expand_scales, } if self.enable_dispatch_v2: diff --git a/vllm_ascend/utils.py b/vllm_ascend/utils.py index 51b87cfe..97f8e2b6 100644 --- a/vllm_ascend/utils.py +++ b/vllm_ascend/utils.py @@ -958,6 +958,14 @@ def calculate_dp_buffer_size() -> int: return max(dp_buffer_size, _MIN_DP_BUFFER_SIZE) +# Currently, when in A2, setting the environment variables HCCL_INTRA_PCIE_ENABLE=1 +# and HCCL_INTRA_ROCE_ENABLE=0 can reduce 
cross-machine communication traffic and +# significantly improve communication performance of MC2 ops dispatch/combine. +def is_hierarchical_communication_enabled(): + return (os.getenv("HCCL_INTRA_ROCE_ENABLE", "") == "0" + and os.getenv("HCCL_INTRA_PCIE_ENABLE", "") == "1") + + def has_layer_idx(model_instance: torch.nn.Module) -> bool: if model_instance is None: return False