[Feat] enable hierarchical communication for mc2 ops on A2 (#3015)

Currently, on A2 hardware, setting the environment variables
`HCCL_INTRA_PCIE_ENABLE=1` and `HCCL_INTRA_ROCE_ENABLE=0` enables
hierarchical communication, which can reduce cross-machine communication
traffic and significantly improve communication performance.

For more details, please refer to the
[torch_npu-npu_moe_distribute_dispatch_v2 documentation](https://www.hiascend.com/document/detail/zh/Pytorch/710/apiref/torchnpuCustomsapi/context/torch_npu-npu_moe_distribute_dispatch_v2.md).
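
As a rough sketch of the gate this change relies on: `is_hierarchical_communication_enabled()` is imported from `vllm_ascend.utils`, and its implementation is not part of this diff, so the body below (including the default values) is only an assumption about how the two environment variables might be checked:

```python
import os

def is_hierarchical_communication_enabled() -> bool:
    # Hypothetical sketch: treat hierarchical communication as enabled when
    # intra-node traffic is routed over PCIe rather than RoCE. The defaults
    # used here are illustrative, not taken from vllm_ascend.utils.
    return (os.environ.get("HCCL_INTRA_PCIE_ENABLE", "0") == "1"
            and os.environ.get("HCCL_INTRA_ROCE_ENABLE", "1") == "0")
```

When the gate is enabled, the diff below passes `expert_scales` into the dispatch kernel and threads the `expand_scales` returned by dispatch into the combine kernel.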

- vLLM version: v0.11.0rc3
- vLLM main: https://github.com/vllm-project/vllm/commit/v0.11.0

---------

Signed-off-by: realliujiaxu <realliujiaxu@163.com>
Committed by realliujiaxu on 2025-10-13 16:13:17 +08:00 (via GitHub)
commit 31682961af (parent 0563106477)
6 changed files with 112 additions and 17 deletions

@@ -14,7 +14,6 @@
# See the License for the specific language governing permissions and
# limitations under the License.
#
from typing import Any, Callable, Dict, Optional, Tuple, Union
import torch
@@ -29,7 +28,8 @@ from vllm_ascend.distributed.parallel_state import get_mc2_group
from vllm_ascend.torchair.ops.torchair_fused_moe import torchair_select_experts
from vllm_ascend.torchair.utils import npu_stream_switch, npu_wait_tensor
from vllm_ascend.utils import (ACL_FORMAT_FRACTAL_NZ, AscendSocVersion,
- dispose_tensor, get_ascend_soc_version)
+ dispose_tensor, get_ascend_soc_version,
+ is_hierarchical_communication_enabled)
def torchair_apply_mlp_decode(hidden_states: torch.Tensor,
@@ -237,6 +237,10 @@ def torchair_fused_experts_with_mc2(
# NOTE: Currently, when in A3, we need to pass in some extra param into dispatch & combine
a3_need_extra_args = get_ascend_soc_version() == AscendSocVersion.A3
+ # NOTE: When in A2, setting the environment variables HCCL_INTRA_PCIE_ENABLE=1 and
+ # HCCL_INTRA_ROCE_ENABLE=0 can reduce cross-machine communication traffic and significantly
+ # improve communication performance.
+ need_expert_scale = is_hierarchical_communication_enabled()
enable_dispatch_v2 = hasattr(torch_npu, "npu_moe_distribute_dispatch_v2")
@@ -271,6 +275,10 @@ def torchair_fused_experts_with_mc2(
stage1_kwargs.update({
"x_active_mask": mc2_mask,
})
+ if need_expert_scale:
+     stage1_kwargs.update({
+         "expert_scales": topk_weights.to(torch.float32),
+     })
kwargs_mc2.update(stage1_kwargs)
output = torch_npu.npu_moe_distribute_dispatch_v2(
@@ -278,8 +286,8 @@ def torchair_fused_experts_with_mc2(
) if enable_dispatch_v2 else torch_npu.npu_moe_distribute_dispatch(
**kwargs_mc2)
# comm_stream.wait_stream(torch.npu.current_stream())
- expand_x, dynamic_scale, assist_info_for_combine, expert_token_nums, ep_recv_counts = output[
-     0:5]
+ expand_x, dynamic_scale, assist_info_for_combine, expert_token_nums, \
+     ep_recv_counts, _, expand_scales = output[0:7]
if shared_experts is not None:
with npu_stream_switch("moe_secondary", 0):
@@ -327,6 +335,7 @@ def torchair_fused_experts_with_mc2(
"group_ep": moe_all_to_all_group_name,
"ep_world_size": ep_world_size,
"ep_rank_id": ep_rank_id,
"expand_scales": expand_scales,
}
if enable_dispatch_v2:
stage3_kwargs.update({