[feat] support dispatch_v2/combine_v2 hierarchy communication (#7698)
### What this PR does / why we need it?
This PR adds support for hierarchical communication for the `dispatch_v2` and `combine_v2` MoE operations. This is achieved by introducing a new configuration option, `enable_mc2_hierarchy_comm`. When enabled, the communication algorithm is set to "hierarchy", which supports MC2 op communication across two super pods. The changes include:
- Adding `enable_mc2_hierarchy_comm` to `AscendConfig`.
- Modifying `TokenDispatcherWithMC2` to pass `comm_alg: "hierarchy"` to the underlying `torch_npu` ops when the new config is enabled.
- Adding validation to ensure that this feature is only used with compatible PTA/CANN versions and is not used with the conflicting `fused_mc2` op.
- Updating `is_hierarchical_communication_enabled` to respect the new configuration flag.

### Does this PR introduce _any_ user-facing change?
Yes, this PR introduces a new user-facing configuration option, `enable_mc2_hierarchy_comm`, in `additional_config` to enable hierarchical communication for MoE (see the usage sketch below).

### How was this patch tested?
- vLLM version: v0.18.0

Signed-off-by: zzzzwwjj <1183291235@qq.com>
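For reference, a minimal sketch of how a user might turn the new option on through `additional_config`. The engine construction and model path below are placeholders, not something taken from this PR; only the flag name comes from the change:

```python
from vllm import LLM

# Hypothetical usage: enable hierarchical MC2 communication for MoE via
# vllm-ascend's additional_config. The model path and other arguments are
# placeholders for illustration only.
llm = LLM(
    model="/path/to/moe-model",
    additional_config={
        "enable_mc2_hierarchy_comm": True,
    },
)
```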
@@ -28,6 +28,7 @@ import torch_npu
from vllm.config import get_current_vllm_config
from vllm.distributed.parallel_state import get_ep_group

from vllm_ascend.ascend_config import get_ascend_config
from vllm_ascend.device.device_op import DeviceOperator
from vllm_ascend.distributed.parallel_state import get_mc2_group
from vllm_ascend.ops.fused_moe.comm_utils import async_all_to_all, gather_from_sequence_parallel_region
@@ -96,6 +97,7 @@ class TokenDispatcherWithMC2(MoETokenDispatcher[MoEMC2CombineMetadata]):
        # NOTE: When in A2, setting the environment variables HCCL_INTRA_PCIE_ENABLE=1 and
        # HCCL_INTRA_ROCE_ENABLE=0 can reduce cross-machine communication traffic and significantly
        # improve communication performance.
        # When hierarchical communication is enabled, the `expert_scales` param needs to be passed in.
        self.need_expert_scale = is_hierarchical_communication_enabled()

        # Here we need to calculate the global_bs = max_bs_per_rank * ep_world_size to execute
@@ -115,6 +117,14 @@ class TokenDispatcherWithMC2(MoETokenDispatcher[MoEMC2CombineMetadata]):
        num_tokens_per_tp_rank = (max_num_tokens + tp_size - 1) // tp_size
        self.global_bs = num_tokens_per_tp_rank * self.ep_world_size

        # NOTE: When enable_mc2_hierarchy_comm is true, we need to pass `comm_alg` to the mc2 op.
        self.need_comm_alg = get_ascend_config().enable_mc2_hierarchy_comm

        if not self.enable_dispatch_v2 and self.need_comm_alg:
            raise RuntimeError(
                "PTA and CANN version is too old to support mc2 hierarchy comm, please upgrade your version."
            )

    def get_dispatch_mc2_kwargs(
        self,
        token_dispatch_input: MoETokenDispatchInput,
@@ -176,6 +186,8 @@ class TokenDispatcherWithMC2(MoETokenDispatcher[MoEMC2CombineMetadata]):
                    "expert_scales": topk_weights.to(torch.float32),
                }
            )
        if self.need_comm_alg:
            stage1_kwargs.update({"comm_alg": "hierarchy"})

        kwargs_mc2.update(stage1_kwargs)
        return kwargs_mc2
@@ -265,6 +277,8 @@ class TokenDispatcherWithMC2(MoETokenDispatcher[MoEMC2CombineMetadata]):
                    "tp_rank_id": 0,
                }
            )
        if self.need_comm_alg:
            stage3_kwargs.update({"comm_alg": "hierarchy"})

        kwargs_mc2.update(stage3_kwargs)
        return kwargs_mc2
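The kwarg handling added above reduces to a small gating pattern: validate the version capability once, then attach `comm_alg: "hierarchy"` to both the dispatch (stage1) and combine (stage3) kwargs. A standalone sketch, where `build_mc2_kwargs` is an illustrative helper name rather than actual vllm-ascend code:

```python
def build_mc2_kwargs(base_kwargs: dict, *, enable_dispatch_v2: bool,
                     enable_mc2_hierarchy_comm: bool) -> dict:
    """Illustrative helper mirroring how the dispatcher gates `comm_alg`."""
    # Hierarchical communication relies on the dispatch_v2/combine_v2 ops,
    # which are only available with new enough PTA/CANN versions.
    if enable_mc2_hierarchy_comm and not enable_dispatch_v2:
        raise RuntimeError(
            "PTA and CANN version is too old to support mc2 hierarchy comm, "
            "please upgrade your version.")
    kwargs = dict(base_kwargs)
    if enable_mc2_hierarchy_comm:
        # Forwarded to the underlying torch_npu MC2 dispatch/combine ops.
        kwargs["comm_alg"] = "hierarchy"
    return kwargs
```

The diff applies the same `if self.need_comm_alg:` check to both `stage1_kwargs` and `stage3_kwargs`, so dispatch and combine always agree on the communication algorithm.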