[feat] support dispatch_v2/combine_v2 hierarchy communication (#7698)
### What this PR does / why we need it? This PR adds support for hierarchical communication for `dispatch_v2` and `combine_v2` MoE operations. This is achieved by introducing a new configuration `enable_mc2_hierarchy_comm`. When enabled, the communication algorithm is set to "hierarchy", which support mc2 op comm between two super pod. The changes include: - Adding `enable_mc2_hierarchy_comm` to `AscendConfig`. - Modifying `TokenDispatcherWithMC2` to pass `comm_alg: "hierarchy"` to the underlying `torch_npu` ops when the new config is enabled. - Adding validation to ensure that this feature is only used with compatible PTA/CANN versions and is not used with the conflicting `fused_mc2` op. - Updating `is_hierarchical_communication_enabled` to respect the new configuration flag. ### Does this PR introduce _any_ user-facing change? Yes, this PR introduces a new user-facing configuration option `enable_mc2_hierarchy_comm` in `additional_config` to enable hierarchical communication for MoE. ### How was this patch tested? - vLLM version: v0.18.0 Signed-off-by: zzzzwwjj <1183291235@qq.com>
This commit is contained in:
@@ -43,8 +43,9 @@ The following table lists additional configuration options available in vLLM Asc
|
|||||||
| `enable_npugraph_ex` | bool | `False` | Whether to enable npugraph_ex graph mode. |
|
| `enable_npugraph_ex` | bool | `False` | Whether to enable npugraph_ex graph mode. |
|
||||||
| `pa_shape_list` | list | `[]` | The custom shape list of page attention ops. |
|
| `pa_shape_list` | list | `[]` | The custom shape list of page attention ops. |
|
||||||
| `enable_kv_nz` | bool | `False` | Whether to enable KV cache NZ layout. This option only takes effects on models using MLA (e.g., DeepSeek). |
|
| `enable_kv_nz` | bool | `False` | Whether to enable KV cache NZ layout. This option only takes effects on models using MLA (e.g., DeepSeek). |
|
||||||
| `layer_sharding` | dict | `{}` | Configuration options for Layer Sharding Linear |
|
| `layer_sharding` | dict | `{}` | Configuration options for Layer Sharding Linear |
|
||||||
| `enable_sparse_c8` | bool | `False` | Whether to enable KV cache C8 in DSA models (e.g., DeepSeekV3.2 and GLM5). Not supported on A5 devices now |
|
| `enable_sparse_c8` | bool | `False` | Whether to enable KV cache C8 in DSA models (e.g., DeepSeekV3.2 and GLM5). Not supported on A5 devices now |
|
||||||
|
| `enable_mc2_hierarchy_comm` | bool | `False` | Enable dispatch/combine op inter-node communication by ROCE. |
|
||||||
|
|
||||||
The details of each configuration option are as follows:
|
The details of each configuration option are as follows:
|
||||||
|
|
||||||
|
|||||||
@@ -165,6 +165,9 @@ class AscendConfig:
|
|||||||
and vllm_config.compilation_config.pass_config.enable_sp
|
and vllm_config.compilation_config.pass_config.enable_sp
|
||||||
)
|
)
|
||||||
|
|
||||||
|
# Enable dispatch/combine op inter-node communication by ROCE
|
||||||
|
self.enable_mc2_hierarchy_comm = additional_config.get("enable_mc2_hierarchy_comm", False)
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def _get_compile_ranges(compilation_config):
|
def _get_compile_ranges(compilation_config):
|
||||||
return compilation_config.compile_ranges_endpoints or []
|
return compilation_config.compile_ranges_endpoints or []
|
||||||
|
|||||||
@@ -28,6 +28,7 @@ import torch_npu
|
|||||||
from vllm.config import get_current_vllm_config
|
from vllm.config import get_current_vllm_config
|
||||||
from vllm.distributed.parallel_state import get_ep_group
|
from vllm.distributed.parallel_state import get_ep_group
|
||||||
|
|
||||||
|
from vllm_ascend.ascend_config import get_ascend_config
|
||||||
from vllm_ascend.device.device_op import DeviceOperator
|
from vllm_ascend.device.device_op import DeviceOperator
|
||||||
from vllm_ascend.distributed.parallel_state import get_mc2_group
|
from vllm_ascend.distributed.parallel_state import get_mc2_group
|
||||||
from vllm_ascend.ops.fused_moe.comm_utils import async_all_to_all, gather_from_sequence_parallel_region
|
from vllm_ascend.ops.fused_moe.comm_utils import async_all_to_all, gather_from_sequence_parallel_region
|
||||||
@@ -96,6 +97,7 @@ class TokenDispatcherWithMC2(MoETokenDispatcher[MoEMC2CombineMetadata]):
|
|||||||
# NOTE: When in A2, setting the environment variables HCCL_INTRA_PCIE_ENABLE=1 and
|
# NOTE: When in A2, setting the environment variables HCCL_INTRA_PCIE_ENABLE=1 and
|
||||||
# HCCL_INTRA_ROCE_ENABLE=0 can reduce cross-machine communication traffic and significantly
|
# HCCL_INTRA_ROCE_ENABLE=0 can reduce cross-machine communication traffic and significantly
|
||||||
# improve communication performance.
|
# improve communication performance.
|
||||||
|
# When enable hierarchical communication, param `expert_scales` need to be passed in.
|
||||||
self.need_expert_scale = is_hierarchical_communication_enabled()
|
self.need_expert_scale = is_hierarchical_communication_enabled()
|
||||||
|
|
||||||
# Here we need to calculate the global_bs = max_bs_per_rank * ep_world_size to execute
|
# Here we need to calculate the global_bs = max_bs_per_rank * ep_world_size to execute
|
||||||
@@ -115,6 +117,14 @@ class TokenDispatcherWithMC2(MoETokenDispatcher[MoEMC2CombineMetadata]):
|
|||||||
num_tokens_per_tp_rank = (max_num_tokens + tp_size - 1) // tp_size
|
num_tokens_per_tp_rank = (max_num_tokens + tp_size - 1) // tp_size
|
||||||
self.global_bs = num_tokens_per_tp_rank * self.ep_world_size
|
self.global_bs = num_tokens_per_tp_rank * self.ep_world_size
|
||||||
|
|
||||||
|
# NOTE: When enable_mc2_hierarchy_comm is true, we need pass in `comm_alg` to mc2 op.
|
||||||
|
self.need_comm_alg = get_ascend_config().enable_mc2_hierarchy_comm
|
||||||
|
|
||||||
|
if not self.enable_dispatch_v2 and self.need_comm_alg:
|
||||||
|
raise RuntimeError(
|
||||||
|
"PTA and CANN version is too old to support mc2 hierarchy comm, please upgrade your version."
|
||||||
|
)
|
||||||
|
|
||||||
def get_dispatch_mc2_kwargs(
|
def get_dispatch_mc2_kwargs(
|
||||||
self,
|
self,
|
||||||
token_dispatch_input: MoETokenDispatchInput,
|
token_dispatch_input: MoETokenDispatchInput,
|
||||||
@@ -176,6 +186,8 @@ class TokenDispatcherWithMC2(MoETokenDispatcher[MoEMC2CombineMetadata]):
|
|||||||
"expert_scales": topk_weights.to(torch.float32),
|
"expert_scales": topk_weights.to(torch.float32),
|
||||||
}
|
}
|
||||||
)
|
)
|
||||||
|
if self.need_comm_alg:
|
||||||
|
stage1_kwargs.update({"comm_alg": "hierarchy"})
|
||||||
|
|
||||||
kwargs_mc2.update(stage1_kwargs)
|
kwargs_mc2.update(stage1_kwargs)
|
||||||
return kwargs_mc2
|
return kwargs_mc2
|
||||||
@@ -265,6 +277,8 @@ class TokenDispatcherWithMC2(MoETokenDispatcher[MoEMC2CombineMetadata]):
|
|||||||
"tp_rank_id": 0,
|
"tp_rank_id": 0,
|
||||||
}
|
}
|
||||||
)
|
)
|
||||||
|
if self.need_comm_alg:
|
||||||
|
stage3_kwargs.update({"comm_alg": "hierarchy"})
|
||||||
|
|
||||||
kwargs_mc2.update(stage3_kwargs)
|
kwargs_mc2.update(stage3_kwargs)
|
||||||
return kwargs_mc2
|
return kwargs_mc2
|
||||||
|
|||||||
@@ -30,6 +30,7 @@ from vllm.platforms import Platform, PlatformEnum
|
|||||||
# todo: please remove it when solve cuda hard code in vllm
|
# todo: please remove it when solve cuda hard code in vllm
|
||||||
os.environ["VLLM_DISABLE_SHARED_EXPERTS_STREAM"] = "1"
|
os.environ["VLLM_DISABLE_SHARED_EXPERTS_STREAM"] = "1"
|
||||||
|
|
||||||
|
import vllm_ascend.envs as envs_ascend
|
||||||
from vllm_ascend.ascend_config import init_ascend_config
|
from vllm_ascend.ascend_config import init_ascend_config
|
||||||
|
|
||||||
# isort: off
|
# isort: off
|
||||||
@@ -511,6 +512,12 @@ class NPUPlatform(Platform):
|
|||||||
):
|
):
|
||||||
speculative_config.enforce_eager = False
|
speculative_config.enforce_eager = False
|
||||||
|
|
||||||
|
if ascend_config.enable_mc2_hierarchy_comm and envs_ascend.VLLM_ASCEND_ENABLE_FUSED_MC2:
|
||||||
|
raise ValueError(
|
||||||
|
"fused mc2 op cannot be used with hierarchy communication."
|
||||||
|
"Please disable VLLM_ASCEND_ENABLE_FUSED_MC2 by setting it to 0."
|
||||||
|
)
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def import_kernels(cls) -> None:
|
def import_kernels(cls) -> None:
|
||||||
# Directly importing vllm_ascend_C prevents ASCEND_RT_VISIBLE_DEVICES
|
# Directly importing vllm_ascend_C prevents ASCEND_RT_VISIBLE_DEVICES
|
||||||
|
|||||||
@@ -991,7 +991,9 @@ def calculate_dp_buffer_size() -> int:
|
|||||||
# and HCCL_INTRA_ROCE_ENABLE=0 can reduce cross-machine communication traffic and
|
# and HCCL_INTRA_ROCE_ENABLE=0 can reduce cross-machine communication traffic and
|
||||||
# significantly improve communication performance of MC2 ops dispatch/combine.
|
# significantly improve communication performance of MC2 ops dispatch/combine.
|
||||||
def is_hierarchical_communication_enabled():
|
def is_hierarchical_communication_enabled():
|
||||||
return os.getenv("HCCL_INTRA_ROCE_ENABLE", "") == "0" and os.getenv("HCCL_INTRA_PCIE_ENABLE", "") == "1"
|
return (
|
||||||
|
os.getenv("HCCL_INTRA_ROCE_ENABLE", "") == "0" and os.getenv("HCCL_INTRA_PCIE_ENABLE", "") == "1"
|
||||||
|
) or get_ascend_config().enable_mc2_hierarchy_comm
|
||||||
|
|
||||||
|
|
||||||
def has_layer_idx(model_instance: torch.nn.Module) -> bool:
|
def has_layer_idx(model_instance: torch.nn.Module) -> bool:
|
||||||
|
|||||||
Reference in New Issue
Block a user