From a40eee2ba1714c102750f4f723a3d2c1b2c7f238 Mon Sep 17 00:00:00 2001 From: zzzzwwjj <34335947+zzzzwwjj@users.noreply.github.com> Date: Fri, 27 Mar 2026 09:20:16 +0800 Subject: [PATCH] [feat] support dispatch_v2/combine_v2 hierarchy communication (#7698) ### What this PR does / why we need it? This PR adds support for hierarchical communication for `dispatch_v2` and `combine_v2` MoE operations. This is achieved by introducing a new configuration `enable_mc2_hierarchy_comm`. When enabled, the communication algorithm is set to "hierarchy", which support mc2 op comm between two super pod. The changes include: - Adding `enable_mc2_hierarchy_comm` to `AscendConfig`. - Modifying `TokenDispatcherWithMC2` to pass `comm_alg: "hierarchy"` to the underlying `torch_npu` ops when the new config is enabled. - Adding validation to ensure that this feature is only used with compatible PTA/CANN versions and is not used with the conflicting `fused_mc2` op. - Updating `is_hierarchical_communication_enabled` to respect the new configuration flag. ### Does this PR introduce _any_ user-facing change? Yes, this PR introduces a new user-facing configuration option `enable_mc2_hierarchy_comm` in `additional_config` to enable hierarchical communication for MoE. ### How was this patch tested? - vLLM version: v0.18.0 Signed-off-by: zzzzwwjj <1183291235@qq.com> --- .../user_guide/configuration/additional_config.md | 5 +++-- vllm_ascend/ascend_config.py | 3 +++ vllm_ascend/ops/fused_moe/token_dispatcher.py | 14 ++++++++++++++ vllm_ascend/platform.py | 7 +++++++ vllm_ascend/utils.py | 4 +++- 5 files changed, 30 insertions(+), 3 deletions(-) diff --git a/docs/source/user_guide/configuration/additional_config.md b/docs/source/user_guide/configuration/additional_config.md index f0a7a8dd..0375fdd4 100644 --- a/docs/source/user_guide/configuration/additional_config.md +++ b/docs/source/user_guide/configuration/additional_config.md @@ -43,8 +43,9 @@ The following table lists additional configuration options available in vLLM Asc | `enable_npugraph_ex` | bool | `False` | Whether to enable npugraph_ex graph mode. | | `pa_shape_list` | list | `[]` | The custom shape list of page attention ops. | | `enable_kv_nz` | bool | `False` | Whether to enable KV cache NZ layout. This option only takes effects on models using MLA (e.g., DeepSeek). | -| `layer_sharding` | dict | `{}` | Configuration options for Layer Sharding Linear | -| `enable_sparse_c8` | bool | `False` | Whether to enable KV cache C8 in DSA models (e.g., DeepSeekV3.2 and GLM5). Not supported on A5 devices now | +| `layer_sharding` | dict | `{}` | Configuration options for Layer Sharding Linear | +| `enable_sparse_c8` | bool | `False` | Whether to enable KV cache C8 in DSA models (e.g., DeepSeekV3.2 and GLM5). Not supported on A5 devices now | +| `enable_mc2_hierarchy_comm` | bool | `False` | Enable dispatch/combine op inter-node communication by ROCE. | The details of each configuration option are as follows: diff --git a/vllm_ascend/ascend_config.py b/vllm_ascend/ascend_config.py index 53423b4d..a18e6549 100644 --- a/vllm_ascend/ascend_config.py +++ b/vllm_ascend/ascend_config.py @@ -165,6 +165,9 @@ class AscendConfig: and vllm_config.compilation_config.pass_config.enable_sp ) + # Enable dispatch/combine op inter-node communication by ROCE + self.enable_mc2_hierarchy_comm = additional_config.get("enable_mc2_hierarchy_comm", False) + @staticmethod def _get_compile_ranges(compilation_config): return compilation_config.compile_ranges_endpoints or [] diff --git a/vllm_ascend/ops/fused_moe/token_dispatcher.py b/vllm_ascend/ops/fused_moe/token_dispatcher.py index 152f6e89..6112ec9e 100644 --- a/vllm_ascend/ops/fused_moe/token_dispatcher.py +++ b/vllm_ascend/ops/fused_moe/token_dispatcher.py @@ -28,6 +28,7 @@ import torch_npu from vllm.config import get_current_vllm_config from vllm.distributed.parallel_state import get_ep_group +from vllm_ascend.ascend_config import get_ascend_config from vllm_ascend.device.device_op import DeviceOperator from vllm_ascend.distributed.parallel_state import get_mc2_group from vllm_ascend.ops.fused_moe.comm_utils import async_all_to_all, gather_from_sequence_parallel_region @@ -96,6 +97,7 @@ class TokenDispatcherWithMC2(MoETokenDispatcher[MoEMC2CombineMetadata]): # NOTE: When in A2, setting the environment variables HCCL_INTRA_PCIE_ENABLE=1 and # HCCL_INTRA_ROCE_ENABLE=0 can reduce cross-machine communication traffic and significantly # improve communication performance. + # When enable hierarchical communication, param `expert_scales` need to be passed in. self.need_expert_scale = is_hierarchical_communication_enabled() # Here we need to calculate the global_bs = max_bs_per_rank * ep_world_size to execute @@ -115,6 +117,14 @@ class TokenDispatcherWithMC2(MoETokenDispatcher[MoEMC2CombineMetadata]): num_tokens_per_tp_rank = (max_num_tokens + tp_size - 1) // tp_size self.global_bs = num_tokens_per_tp_rank * self.ep_world_size + # NOTE: When enable_mc2_hierarchy_comm is true, we need pass in `comm_alg` to mc2 op. + self.need_comm_alg = get_ascend_config().enable_mc2_hierarchy_comm + + if not self.enable_dispatch_v2 and self.need_comm_alg: + raise RuntimeError( + "PTA and CANN version is too old to support mc2 hierarchy comm, please upgrade your version." + ) + def get_dispatch_mc2_kwargs( self, token_dispatch_input: MoETokenDispatchInput, @@ -176,6 +186,8 @@ class TokenDispatcherWithMC2(MoETokenDispatcher[MoEMC2CombineMetadata]): "expert_scales": topk_weights.to(torch.float32), } ) + if self.need_comm_alg: + stage1_kwargs.update({"comm_alg": "hierarchy"}) kwargs_mc2.update(stage1_kwargs) return kwargs_mc2 @@ -265,6 +277,8 @@ class TokenDispatcherWithMC2(MoETokenDispatcher[MoEMC2CombineMetadata]): "tp_rank_id": 0, } ) + if self.need_comm_alg: + stage3_kwargs.update({"comm_alg": "hierarchy"}) kwargs_mc2.update(stage3_kwargs) return kwargs_mc2 diff --git a/vllm_ascend/platform.py b/vllm_ascend/platform.py index 4efa1972..f09c524e 100644 --- a/vllm_ascend/platform.py +++ b/vllm_ascend/platform.py @@ -30,6 +30,7 @@ from vllm.platforms import Platform, PlatformEnum # todo: please remove it when solve cuda hard code in vllm os.environ["VLLM_DISABLE_SHARED_EXPERTS_STREAM"] = "1" +import vllm_ascend.envs as envs_ascend from vllm_ascend.ascend_config import init_ascend_config # isort: off @@ -511,6 +512,12 @@ class NPUPlatform(Platform): ): speculative_config.enforce_eager = False + if ascend_config.enable_mc2_hierarchy_comm and envs_ascend.VLLM_ASCEND_ENABLE_FUSED_MC2: + raise ValueError( + "fused mc2 op cannot be used with hierarchy communication." + "Please disable VLLM_ASCEND_ENABLE_FUSED_MC2 by setting it to 0." + ) + @classmethod def import_kernels(cls) -> None: # Directly importing vllm_ascend_C prevents ASCEND_RT_VISIBLE_DEVICES diff --git a/vllm_ascend/utils.py b/vllm_ascend/utils.py index 37658f91..ce23893e 100644 --- a/vllm_ascend/utils.py +++ b/vllm_ascend/utils.py @@ -991,7 +991,9 @@ def calculate_dp_buffer_size() -> int: # and HCCL_INTRA_ROCE_ENABLE=0 can reduce cross-machine communication traffic and # significantly improve communication performance of MC2 ops dispatch/combine. def is_hierarchical_communication_enabled(): - return os.getenv("HCCL_INTRA_ROCE_ENABLE", "") == "0" and os.getenv("HCCL_INTRA_PCIE_ENABLE", "") == "1" + return ( + os.getenv("HCCL_INTRA_ROCE_ENABLE", "") == "0" and os.getenv("HCCL_INTRA_PCIE_ENABLE", "") == "1" + ) or get_ascend_config().enable_mc2_hierarchy_comm def has_layer_idx(model_instance: torch.nn.Module) -> bool: