[Feat] enable hierarchical communication for mc2 ops on A2 (#3015)

Currently, on A2 hardware, setting the environment variables
`HCCL_INTRA_PCIE_ENABLE=1` and `HCCL_INTRA_ROCE_ENABLE=0` reduces
cross-machine communication traffic and significantly improves
communication performance.

For more details, please refer to
[document](https://www.hiascend.com/document/detail/zh/Pytorch/710/apiref/torchnpuCustomsapi/context/torch_npu-npu_moe_distribute_dispatch_v2.md)

- vLLM version: v0.11.0rc3
- vLLM main: https://github.com/vllm-project/vllm/commit/v0.11.0

---------

Signed-off-by: realliujiaxu <realliujiaxu@163.com>
This commit is contained in:
realliujiaxu
2025-10-13 16:13:17 +08:00
committed by GitHub
parent 0563106477
commit 31682961af
6 changed files with 112 additions and 17 deletions

View File

@@ -3,8 +3,9 @@ from unittest.mock import MagicMock, patch
import torch
from tests.ut.base import TestBase
from vllm_ascend.torchair.quantization.torchair_w8a8_dynamic import \
torchair_fused_experts_with_all2all
from vllm_ascend.torchair.quantization.torchair_w8a8_dynamic import (
torchair_fused_experts_with_all2all, torchair_fused_experts_with_mc2)
from vllm_ascend.utils import AscendSocVersion
class TestAscendW8A8FusedMoEMethod(TestBase):
@@ -73,3 +74,57 @@ class TestAscendW8A8FusedMoEMethod(TestBase):
self.assertIsNotNone(result)
self.assertEqual(result.dtype, torch.bfloat16)
self.assertEqual(result.shape, (128, 128))
@patch.dict('os.environ', {
    'HCCL_INTRA_ROCE_ENABLE': '0',
    'HCCL_INTRA_PCIE_ENABLE': '1'
})
@patch(
    "vllm_ascend.torchair.quantization.torchair_w8a8_dynamic.get_ascend_soc_version"
)
@patch(
    'vllm_ascend.torchair.quantization.torchair_w8a8_dynamic.get_mc2_group'
)
@patch('torch_npu.npu_moe_distribute_combine_v2')
@patch('torch_npu.npu_moe_distribute_dispatch_v2')
@patch(
    'vllm_ascend.torchair.quantization.torchair_w8a8_dynamic.torchair_apply_mlp_decode'
)
def test_torchair_fused_experts_with_mc2_a2_optimization(
        self, mock_mlp_decode, mock_dispatch, mock_combine, mock_get_group,
        mock_ascend_soc_version):
    """Check that `expert_scales` reaches the dispatch op on A2.

    With `HCCL_INTRA_PCIE_ENABLE=1` / `HCCL_INTRA_ROCE_ENABLE=0` set and the
    SOC reported as A2, `torchair_fused_experts_with_mc2` is expected to
    include an `expert_scales` keyword in its call to
    `npu_moe_distribute_dispatch_v2`.
    """
    # Report an A2 SOC so the hierarchical-communication path is taken.
    mock_ascend_soc_version.return_value = AscendSocVersion.A2

    # Fake a 4-rank MC2 communication group; this test acts as rank 0.
    fake_group = MagicMock()
    fake_group.rank_in_group = 0
    fake_group.world_size = 4
    mock_get_group.return_value = fake_group

    # Stub the NPU ops. The dispatch stub returns a 7-tuple whose element
    # shapes mimic the real op's outputs (see torch_npu docs for the exact
    # semantics of each slot).
    mock_combine.return_value = self.placeholder
    mock_mlp_decode.return_value = self.placeholder
    mock_dispatch.return_value = (
        torch.randn(32, 1024),
        torch.randn(1),
        torch.randint(0, 32, (32, )),
        torch.randint(1, 5, (8, )),
        torch.randint(1, 5, (4, )),
        None,
        torch.randn(32),
    )

    result = torchair_fused_experts_with_mc2(
        hidden_states=self.placeholder,
        w1=self.placeholder,
        w2=self.placeholder,
        w1_scale=self.placeholder,
        w2_scale=self.placeholder,
        topk_weights=self.placeholder,
        topk_ids=self.placeholder,
        top_k=2,
        mc2_mask=self.placeholder)

    # The A2 optimization must forward `expert_scales` to dispatch.
    dispatch_kwargs = mock_dispatch.call_args[1]
    self.assertIn('expert_scales', dispatch_kwargs)
    self.assertIsInstance(result, torch.Tensor)
    self.assertEqual(result.shape, self.placeholder.shape)