[feat] support dispatch_v2/combine_v2 hierarchy communication (#7698)

### What this PR does / why we need it?

This PR adds support for hierarchical communication in the `dispatch_v2`
and `combine_v2` MoE operations. This is achieved by introducing a new
configuration option, `enable_mc2_hierarchy_comm`. When it is enabled, the
communication algorithm is set to `"hierarchy"`, which supports MC2 op
communication between two super pods.

The changes include:
- Adding `enable_mc2_hierarchy_comm` to `AscendConfig`.
- Modifying `TokenDispatcherWithMC2` to pass `comm_alg: "hierarchy"` to
the underlying `torch_npu` ops when the new config is enabled.
- Adding validation to ensure that this feature is only used with
compatible PTA/CANN versions and is not used with the conflicting
`fused_mc2` op.
- Updating `is_hierarchical_communication_enabled` to respect the new
configuration flag.

### Does this PR introduce _any_ user-facing change?

Yes, this PR introduces a new user-facing configuration option
`enable_mc2_hierarchy_comm` in `additional_config` to enable
hierarchical communication for MoE.
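
For reference, a minimal usage sketch (the model name, parallel sizes, and the CLI form in the comment are illustrative assumptions, not taken from this PR):

```python
# Hypothetical usage sketch: turn on hierarchical MC2 communication through
# additional_config. Model name and parallel settings are placeholders.
from vllm import LLM

llm = LLM(
    model="deepseek-ai/DeepSeek-V3",                        # placeholder MoE model
    tensor_parallel_size=8,                                 # placeholder parallelism
    enable_expert_parallel=True,                            # MC2 dispatch/combine runs on the EP path
    additional_config={"enable_mc2_hierarchy_comm": True},
)

# Assumed server-mode equivalent:
#   vllm serve <model> --additional-config '{"enable_mc2_hierarchy_comm": true}'
```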

### How was this patch tested?

- vLLM version: v0.18.0

Signed-off-by: zzzzwwjj <1183291235@qq.com>


@@ -43,8 +43,9 @@ The following table lists additional configuration options available in vLLM Asc
| `enable_npugraph_ex` | bool | `False` | Whether to enable npugraph_ex graph mode. |
| `pa_shape_list` | list | `[]` | The custom shape list of page attention ops. |
| `enable_kv_nz` | bool | `False` | Whether to enable KV cache NZ layout. This option only takes effects on models using MLA (e.g., DeepSeek). |
| `layer_sharding` | dict | `{}` | Configuration options for Layer Sharding Linear |
| `enable_sparse_c8` | bool | `False` | Whether to enable KV cache C8 in DSA models (e.g., DeepSeekV3.2 and GLM5). Not supported on A5 devices now |
| `enable_mc2_hierarchy_comm` | bool | `False` | Whether to enable inter-node communication over RoCE for the MC2 dispatch/combine ops. |
The details of each configuration option are as follows:


@@ -165,6 +165,9 @@ class AscendConfig:
and vllm_config.compilation_config.pass_config.enable_sp
)
# Enable dispatch/combine op inter-node communication by ROCE
self.enable_mc2_hierarchy_comm = additional_config.get("enable_mc2_hierarchy_comm", False)
@staticmethod
def _get_compile_ranges(compilation_config):
return compilation_config.compile_ranges_endpoints or []
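
As a hedged illustration (the helper name below is hypothetical), downstream modules are expected to read the flag back through `get_ascend_config()` rather than re-parsing `additional_config`:

```python
# Illustrative helper, not part of this PR: consumers query the parsed AscendConfig.
from vllm_ascend.ascend_config import get_ascend_config

def hierarchy_comm_requested() -> bool:
    # True only when enable_mc2_hierarchy_comm was set in additional_config.
    return get_ascend_config().enable_mc2_hierarchy_comm
```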


@@ -28,6 +28,7 @@ import torch_npu
from vllm.config import get_current_vllm_config
from vllm.distributed.parallel_state import get_ep_group
from vllm_ascend.ascend_config import get_ascend_config
from vllm_ascend.device.device_op import DeviceOperator
from vllm_ascend.distributed.parallel_state import get_mc2_group
from vllm_ascend.ops.fused_moe.comm_utils import async_all_to_all, gather_from_sequence_parallel_region
@@ -96,6 +97,7 @@ class TokenDispatcherWithMC2(MoETokenDispatcher[MoEMC2CombineMetadata]):
# NOTE: When in A2, setting the environment variables HCCL_INTRA_PCIE_ENABLE=1 and
# HCCL_INTRA_ROCE_ENABLE=0 can reduce cross-machine communication traffic and significantly
# improve communication performance.
# When hierarchical communication is enabled, the `expert_scales` param needs to be passed in.
self.need_expert_scale = is_hierarchical_communication_enabled()
# Here we need to calculate the global_bs = max_bs_per_rank * ep_world_size to execute
@@ -115,6 +117,14 @@ class TokenDispatcherWithMC2(MoETokenDispatcher[MoEMC2CombineMetadata]):
num_tokens_per_tp_rank = (max_num_tokens + tp_size - 1) // tp_size
self.global_bs = num_tokens_per_tp_rank * self.ep_world_size
# NOTE: When enable_mc2_hierarchy_comm is true, we need to pass `comm_alg` to the mc2 op.
self.need_comm_alg = get_ascend_config().enable_mc2_hierarchy_comm
if not self.enable_dispatch_v2 and self.need_comm_alg:
raise RuntimeError(
"The installed PTA and CANN versions are too old to support mc2 hierarchy comm, please upgrade them."
)
def get_dispatch_mc2_kwargs(
self,
token_dispatch_input: MoETokenDispatchInput,
@@ -176,6 +186,8 @@ class TokenDispatcherWithMC2(MoETokenDispatcher[MoEMC2CombineMetadata]):
"expert_scales": topk_weights.to(torch.float32),
}
)
if self.need_comm_alg:
stage1_kwargs.update({"comm_alg": "hierarchy"})
kwargs_mc2.update(stage1_kwargs)
return kwargs_mc2
@@ -265,6 +277,8 @@ class TokenDispatcherWithMC2(MoETokenDispatcher[MoEMC2CombineMetadata]):
"tp_rank_id": 0,
}
)
if self.need_comm_alg:
stage3_kwargs.update({"comm_alg": "hierarchy"})
kwargs_mc2.update(stage3_kwargs)
return kwargs_mc2
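
Isolated as a hedged sketch (the helper name and dict shapes are illustrative, not this PR's actual code), the pattern used in both the dispatch (stage1) and combine (stage3) paths is to attach `comm_alg` only when hierarchy communication is requested, so older PTA/CANN builds whose MC2 ops do not accept the argument never receive it:

```python
# Illustrative sketch of the conditional-kwarg pattern above. Names are hypothetical.
from typing import Any

def with_comm_alg(stage_kwargs: dict[str, Any], need_comm_alg: bool) -> dict[str, Any]:
    kwargs = dict(stage_kwargs)
    if need_comm_alg:
        # "hierarchy" selects the hierarchical (inter-super-pod) algorithm
        # in the dispatch_v2/combine_v2 MC2 ops.
        kwargs["comm_alg"] = "hierarchy"
    return kwargs
```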


@@ -30,6 +30,7 @@ from vllm.platforms import Platform, PlatformEnum
# TODO: please remove this once the CUDA hard-coding in vLLM is resolved
os.environ["VLLM_DISABLE_SHARED_EXPERTS_STREAM"] = "1"
import vllm_ascend.envs as envs_ascend
from vllm_ascend.ascend_config import init_ascend_config
# isort: off
@@ -511,6 +512,12 @@ class NPUPlatform(Platform):
):
speculative_config.enforce_eager = False
if ascend_config.enable_mc2_hierarchy_comm and envs_ascend.VLLM_ASCEND_ENABLE_FUSED_MC2:
raise ValueError(
"fused mc2 op cannot be used with hierarchy communication."
"Please disable VLLM_ASCEND_ENABLE_FUSED_MC2 by setting it to 0."
)
@classmethod
def import_kernels(cls) -> None:
# Directly importing vllm_ascend_C prevents ASCEND_RT_VISIBLE_DEVICES


@@ -991,7 +991,9 @@ def calculate_dp_buffer_size() -> int:
# and HCCL_INTRA_ROCE_ENABLE=0 can reduce cross-machine communication traffic and
# significantly improve communication performance of MC2 ops dispatch/combine.
def is_hierarchical_communication_enabled():
return (
os.getenv("HCCL_INTRA_ROCE_ENABLE", "") == "0" and os.getenv("HCCL_INTRA_PCIE_ENABLE", "") == "1"
) or get_ascend_config().enable_mc2_hierarchy_comm
def has_layer_idx(model_instance: torch.nn.Module) -> bool:
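
In effect there are now two ways to turn hierarchical MC2 communication on; the sketch below is illustrative (the CLI line is an assumption about how `additional_config` is typically passed):

```python
# Option 1 (pre-existing): HCCL environment variables, exported before the engine
# starts; shown via os.environ purely for illustration.
import os

os.environ["HCCL_INTRA_PCIE_ENABLE"] = "1"
os.environ["HCCL_INTRA_ROCE_ENABLE"] = "0"

# Option 2 (this PR): the explicit flag, independent of the HCCL variables, e.g.
#   vllm serve <model> --additional-config '{"enable_mc2_hierarchy_comm": true}'
```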