diff --git a/vllm_ascend/ops/fused_moe/token_dispatcher.py b/vllm_ascend/ops/fused_moe/token_dispatcher.py
index f5f01a60..0513307a 100644
--- a/vllm_ascend/ops/fused_moe/token_dispatcher.py
+++ b/vllm_ascend/ops/fused_moe/token_dispatcher.py
@@ -32,7 +32,8 @@ from vllm.distributed.parallel_state import get_ep_group
 from vllm_ascend.distributed.parallel_state import get_mc2_group
 from vllm_ascend.ops.fused_moe.comm_utils import (
     async_all_to_all, gather_from_sequence_parallel_region)
-from vllm_ascend.utils import AscendDeviceType, get_ascend_device_type
+from vllm_ascend.utils import (AscendDeviceType, get_ascend_device_type,
+                               is_hierarchical_communication_enabled)
 
 
 @dataclass
@@ -116,6 +117,10 @@ class TokenDispatcherWithMC2(MoETokenDispatcher):
         self.need_extra_args = (
             get_ascend_device_type() == AscendDeviceType.A3)
 
+        # NOTE: On A2 devices, setting the environment variables HCCL_INTRA_PCIE_ENABLE=1 and
+        # HCCL_INTRA_ROCE_ENABLE=0 can reduce cross-machine communication traffic and significantly
+        # improve communication performance.
+        self.need_expert_scale = is_hierarchical_communication_enabled()
         self.with_quant = False
 
         # Here we need to calculate the global_bs = max_bs_per_rank * ep_world_size to execute
@@ -153,7 +158,6 @@ class TokenDispatcherWithMC2(MoETokenDispatcher):
         else:
             quant_mode = 0
         moe_expert_num = len(expert_map)
-
         kwargs_mc2 = {
             "x": hidden_states,
             "expert_ids": topk_ids,
@@ -162,12 +166,8 @@ class TokenDispatcherWithMC2(MoETokenDispatcher):
             "moe_expert_num": moe_expert_num,
             "global_bs": self.global_bs,
             "expert_token_nums_type": 0,
-            "expert_scales": topk_weights.to(torch.float32),
         }
 
-        if get_ascend_device_type() == AscendDeviceType.A2:
-            kwargs_mc2["comm_alg"] = "hierarchy"
-
         stage1_kwargs = {
             "scales": None,
             "quant_mode": quant_mode,
@@ -181,6 +181,11 @@ class TokenDispatcherWithMC2(MoETokenDispatcher):
                 "tp_world_size": 1,
                 "tp_rank_id": 0,
             })
+        if self.need_expert_scale:
+            stage1_kwargs.update({
+                "expert_scales":
+                topk_weights.to(torch.float32),
+            })
         kwargs_mc2.update(stage1_kwargs)
 
         return kwargs_mc2
@@ -258,12 +263,8 @@ class TokenDispatcherWithMC2(MoETokenDispatcher):
             "shared_expert_rank_num": 0,
             "moe_expert_num": moe_expert_num,
             "global_bs": self.global_bs,
-            "expand_scales": expand_scales,
         }
 
-        if get_ascend_device_type() == AscendDeviceType.A2:
-            kwargs_mc2["comm_alg"] = "hierarchy"
-
         if self.with_quant:
             tp_recv_counts = torch.empty(1,
                                          dtype=torch.int32,
@@ -274,6 +275,7 @@
             "group_ep": self.moe_all_to_all_group_name,
             "ep_world_size": self.ep_world_size,
             "ep_rank_id": self.ep_rank_id,
+            "expand_scales": expand_scales,
         }
 
         if self.enable_dispatch_v2:
diff --git a/vllm_ascend/utils.py b/vllm_ascend/utils.py
index cecb88cd..bbe63625 100644
--- a/vllm_ascend/utils.py
+++ b/vllm_ascend/utils.py
@@ -983,6 +983,14 @@ def calculate_dp_buffer_size() -> int:
     return max(dp_buffer_size, _MIN_DP_BUFFER_SIZE)
 
 
+# Currently, on A2 devices, setting the environment variables HCCL_INTRA_PCIE_ENABLE=1
+# and HCCL_INTRA_ROCE_ENABLE=0 can reduce cross-machine communication traffic and
+# significantly improve the communication performance of the MC2 dispatch/combine ops.
+def is_hierarchical_communication_enabled() -> bool:
+    return (os.getenv("HCCL_INTRA_ROCE_ENABLE", "") == "0"
+            and os.getenv("HCCL_INTRA_PCIE_ENABLE", "") == "1")
+
+
 def has_layer_idx(model_instance: torch.nn.Module) -> bool:
     if model_instance is None:
         return False