[2/N][Feat] Add MC2 communication method for MoE layers (#2469)

### What this PR does / why we need it?
This PR introduces the MC2 communication method for MoE layers, which replaces
the previous all-gather approach for small numbers of tokens.

The key changes include:
- A new `AscendFusedMoE` layer that handles token splitting, local
computation, and final aggregation via all-gather.
- Logic in the model runner to dynamically select between the new MC2
method and the existing all-gather method based on the number of input
tokens.
- Sharding the MoE communication mask across tensor-parallel ranks.

### Does this PR introduce _any_ user-facing change?
None.

### How was this patch tested?
Test case fixed.


- vLLM version: v0.10.1.1
- vLLM main:
b00e69f8ca

---------

Signed-off-by: Yizhou Liu <liu_yizhou@outlook.com>
This commit is contained in:
yiz-liu
2025-08-26 19:05:23 +08:00
committed by GitHub
parent 5d8ec28009
commit a6bb502e70
11 changed files with 506 additions and 410 deletions

View File

@@ -1,5 +1,5 @@
import unittest
from unittest.mock import MagicMock, Mock, patch
from unittest.mock import MagicMock, patch
import torch
import torch.distributed as dist
@@ -87,69 +87,3 @@ class TestNPUCommunicator(unittest.TestCase):
output = comm.all_to_all(input_, scatter_dim=0, gather_dim=0)
assert output.tolist() == [[10, 20], [50, 60]]
@patch("vllm.config.get_current_vllm_config", return_value=None)
@patch("torch.npu.current_device", return_value=MagicMock())
@patch("torch.npu.set_device", return_value=MagicMock())
@patch("torch.distributed.get_group_rank", return_value={0: 0, 1: 1})
@patch("torch.distributed.is_initialized", return_value=True)
@patch("torch.distributed.get_backend", return_value="hccl")
@patch("torch.distributed.get_rank", return_value=1)
@patch("torch.distributed.get_world_size", return_value=2)
# NOTE: the original stacked a second, dict-valued get_process_group_ranks
# patch above this one; stacked patches on the same target apply bottom-up,
# so this list-valued patch was the one active during the test. The
# duplicate is_initialized / get_rank / get_process_group_ranks patches
# were removed — behavior is unchanged.
@patch("torch.distributed.get_process_group_ranks", return_value=[0, 1])
@patch("torch.npu.device")
def test_dispatch(self, *_):
    """dispatch() must delegate to the all2all manager exactly once and
    return the manager's (hidden_states, router_logits) pair unchanged.

    All torch.distributed / torch.npu entry points are patched out so the
    test runs without an initialized process group or NPU device.
    """
    comm = NPUCommunicator(cpu_group=dist.group.WORLD)
    # Stub the manager so no real all-to-all communication happens.
    # (MagicMock, not Mock: Mock is no longer imported in this module.)
    comm.all2all_manager = MagicMock()
    hidden_states = torch.randn(2, 4, 8)
    router_logits = torch.randn(2, 4, 2)
    mock_dispatch_result = (torch.randn(2, 4, 8), torch.randn(2, 4, 2))
    comm.all2all_manager.dispatch.return_value = mock_dispatch_result
    result_hidden, result_logits = comm.dispatch(hidden_states,
                                                 router_logits)
    # The communicator must pass the manager's tensors through untouched.
    assert torch.allclose(result_hidden, mock_dispatch_result[0])
    assert torch.allclose(result_logits, mock_dispatch_result[1])
    # And it must forward exactly the tensors it was given, exactly once.
    comm.all2all_manager.dispatch.assert_called_once_with(
        hidden_states, router_logits)
@patch("vllm.config.get_current_vllm_config", return_value=None)
@patch("torch.npu.current_device", return_value=MagicMock())
@patch("torch.npu.set_device", return_value=MagicMock())
@patch("torch.distributed.get_group_rank", return_value={0: 0, 1: 1})
@patch("torch.distributed.is_initialized", return_value=True)
@patch("torch.distributed.get_backend", return_value="hccl")
@patch("torch.distributed.get_rank", return_value=1)
@patch("torch.distributed.get_world_size", return_value=2)
# NOTE: the original stacked a second, dict-valued get_process_group_ranks
# patch above this one; stacked patches on the same target apply bottom-up,
# so this list-valued patch was the one active during the test. The
# duplicate is_initialized / get_rank / get_process_group_ranks patches
# were removed — behavior is unchanged.
@patch("torch.distributed.get_process_group_ranks", return_value=[0, 1])
@patch("torch.npu.device")
def test_combine(self, *_):
    """combine() must delegate to the all2all manager exactly once and
    return the manager's tensor unchanged.

    All torch.distributed / torch.npu entry points are patched out so the
    test runs without an initialized process group or NPU device.
    """
    comm = NPUCommunicator(cpu_group=dist.group.WORLD)
    # Stub the manager so no real collective communication happens.
    # (MagicMock, not Mock: Mock is no longer imported in this module.)
    comm.all2all_manager = MagicMock()
    hidden_states = torch.randn(2, 4, 8)
    mock_combine_result = torch.randn(2, 4, 8)
    comm.all2all_manager.combine.return_value = mock_combine_result
    result = comm.combine(hidden_states)
    # The communicator must pass the manager's tensor through untouched.
    assert torch.allclose(result, mock_combine_result)
    # And it must forward exactly the tensor it was given, exactly once.
    comm.all2all_manager.combine.assert_called_once_with(hidden_states)