[2/N][Feat] Add MC2 communication method for MoE layers (#2469)
### What this PR does / why we need it?
This PR adds the MC2 communication method for MoE layers. MC2 replaces
the previous all-gather approach when the number of input tokens is
small.
The key changes include:
- A new `AscendFusedMoE` layer that handles token splitting, local
computation, and final aggregation via all-gather.
- Logic in the model runner to dynamically select between the new MC2
method and the existing all-gather method, based on the number of input
tokens (see the sketch below).
- Sharding the MoE communication mask across tensor-parallel ranks.
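
As an illustration of the selection logic described above, here is a
minimal sketch; the helper name `select_moe_comm_method` and the
threshold value are hypothetical and do not mirror the actual
vllm-ascend code:

```python
# Hypothetical sketch of the runner-side dispatch described above; the
# real implementation and its threshold live in the model runner.
def select_moe_comm_method(num_input_tokens: int,
                           mc2_token_capacity: int = 512) -> str:
    """Pick the MoE communication backend for one forward pass.

    MC2 fuses dispatch/combine communication with expert computation and
    pays off for small token counts (e.g. decode steps), while all-gather
    amortizes better over large prefill batches.
    """
    if num_input_tokens <= mc2_token_capacity:
        return "mc2"
    return "allgather"
```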
### Does this PR introduce _any_ user-facing change?
None.
### How was this patch tested?
Existing test cases were fixed to cover these changes.
- vLLM version: v0.10.1.1
- vLLM main: b00e69f8ca
---------
Signed-off-by: Yizhou Liu <liu_yizhou@outlook.com>
@@ -18,29 +18,30 @@ from types import SimpleNamespace
 import pytest
 import torch
 from transformers import PretrainedConfig
 from vllm import forward_context

 from vllm_ascend.distributed import moe_comm_method
-from vllm_ascend.distributed.moe_comm_method import (AllGatherCommImpl,
-                                                     NativeAllGatherCommImpl)
+from vllm.model_executor.layers.fused_moe.config import ( # isort: skip
+    FusedMoEConfig, FusedMoEParallelConfig)
+
+from vllm_ascend.distributed.moe_comm_method import ( # isort: skip
+    AllGatherCommImpl, NativeAllGatherCommImpl)


 @pytest.mark.parametrize("num_tokens", [16, 128])
 @pytest.mark.parametrize("hidden_size", [64, 128])
 @pytest.mark.parametrize("global_num_experts", [8, 16])
-@pytest.mark.parametrize("num_local_experts", [4, 8])
 @pytest.mark.parametrize("top_k_num", [2, 4])
 @pytest.mark.parametrize("dtype", [torch.bfloat16, torch.float16])
+@pytest.mark.parametrize("num_local_experts", [4, 8])
+@pytest.mark.parametrize("ep_rank", [0, 1])
 def test_all_gather_comm_impl(
     num_tokens,
     hidden_size,
     global_num_experts,
-    num_local_experts,
     top_k_num,
     dtype,
+    num_local_experts,
+    ep_rank,
+    mocker,
 ):
     """
     Tests the AllGatherCommImpl against the NativeAllGatherCommImpl.
@@ -56,23 +57,37 @@ def test_all_gather_comm_impl(
             "num_local_experts cannot be greater than global_num_experts")

     device = torch.device("npu")
-    hf_config = PretrainedConfig(
-        num_experts_per_tok=top_k_num,
-    )
+    # mock get_tensor_model_parallel_rank to return ep_rank
+    mocker.patch(
+        "vllm.model_executor.layers.fused_moe.config.get_tensor_model_parallel_rank",
+        return_value=ep_rank,
+    )
+
+    # make moe config
+    parallel_config = SimpleNamespace(
+        enable_expert_parallel=num_local_experts < global_num_experts)
+    moe_parallel_config: FusedMoEParallelConfig = FusedMoEParallelConfig.make(
+        tp_size_=max(2, global_num_experts // num_local_experts),
+        dp_size_=1,
+        vllm_parallel_config=parallel_config,
+    )
+
+    moe_config = FusedMoEConfig(
+        num_experts=global_num_experts,
+        experts_per_token=top_k_num,
+        hidden_dim=hidden_size,
+        num_local_experts=num_local_experts,
+        moe_parallel_config=moe_parallel_config,
+        in_dtype=dtype,
+        quant_config=None,  # No quantization in this test
+        max_num_tokens=num_tokens,
+    )

     # Instantiate implementations
-    native_impl = NativeAllGatherCommImpl(device, dtype, hf_config)
+    native_impl = NativeAllGatherCommImpl(moe_config)

-    all_gather_impl = AllGatherCommImpl(device, dtype, hf_config)
-
-    # TODO: Find out if this is the correct way to mock the forward context and ep group
     # Mock get_forward_context to return an object with moe_comm_method
     forward_context._forward_context = SimpleNamespace(
         moe_comm_method=all_gather_impl)
+    # Mock get_ep_group to return a fake group with the specified ep_rank
+    fake_ep_group = SimpleNamespace(rank_in_group=ep_rank)
+    moe_comm_method.get_ep_group = lambda: fake_ep_group
+    all_gather_impl = AllGatherCommImpl(moe_config)

     # --- Input Data ---
     hidden_states = torch.randn(num_tokens,
@@ -103,11 +118,11 @@ def test_all_gather_comm_impl(
         native_permuted_hidden,
         native_expert_tokens,
         _,
-    ) = native_impl._pre_process(hidden_states, topk_ids, topk_weights,
-                                 expert_map, num_experts)
+    ) = native_impl.permute(hidden_states, topk_ids, topk_weights, expert_map,
+                            num_experts)

     # Simulate MLP output
     native_mlp_output = torch.randn_like(native_permuted_hidden)
-    native_impl._post_process(native_mlp_output, native_hidden_states_out)
+    native_impl.unpermute(native_mlp_output, native_hidden_states_out)

     # --- Run AllGather Implementation ---
     all_gather_hidden_states_out = hidden_states.clone()
@@ -115,15 +130,14 @@ def test_all_gather_comm_impl(
         all_gather_permuted_hidden,
         all_gather_expert_tokens,
         _,
-    ) = torch.ops.vllm.moe_comm_pre_process(hidden_states, topk_ids,
-                                            topk_weights, expert_map,
-                                            num_experts)
+    ) = all_gather_impl.permute(hidden_states, topk_ids, topk_weights,
+                                expert_map, num_experts)

     # Use the same simulated MLP output for a fair comparison
     all_gather_mlp_output = native_mlp_output.clone()

-    torch.ops.vllm.moe_comm_post_process(all_gather_mlp_output,
-                                         all_gather_hidden_states_out)
+    all_gather_impl.unpermute(all_gather_mlp_output,
+                              all_gather_hidden_states_out)

     # --- Assertions ---
     # Define tolerance based on dtype
@@ -1,5 +1,5 @@
 import unittest
-from unittest.mock import MagicMock, Mock, patch
+from unittest.mock import MagicMock, patch

 import torch
 import torch.distributed as dist
@@ -87,69 +87,3 @@ class TestNPUCommunicator(unittest.TestCase):
         output = comm.all_to_all(input_, scatter_dim=0, gather_dim=0)

         assert output.tolist() == [[10, 20], [50, 60]]
-
-    @patch("vllm.config.get_current_vllm_config", return_value=None)
-    @patch("torch.npu.current_device", return_value=MagicMock())
-    @patch("torch.npu.set_device", return_value=MagicMock())
-    @patch("torch.distributed.get_process_group_ranks",
-           return_value={
-               0: 0,
-               1: 1
-           })
-    @patch("torch.distributed.get_group_rank", return_value={0: 0, 1: 1})
-    @patch("torch.distributed.is_initialized", return_value=True)
-    @patch("torch.distributed.get_rank", return_value=1)
-    @patch("torch.distributed.is_initialized", return_value=True)
-    @patch("torch.distributed.get_backend", return_value="hccl")
-    @patch("torch.distributed.get_rank", return_value=1)
-    @patch("torch.distributed.get_world_size", return_value=2)
-    @patch("torch.distributed.get_process_group_ranks", return_value=[0, 1])
-    @patch("torch.npu.device")
-    def test_dispatch(self, *_):
-        comm = NPUCommunicator(cpu_group=dist.group.WORLD)
-        comm.all2all_manager = Mock()
-        hidden_states = torch.randn(2, 4, 8)
-        router_logits = torch.randn(2, 4, 2)
-
-        mock_dispatch_result = (torch.randn(2, 4, 8), torch.randn(2, 4, 2))
-        comm.all2all_manager.dispatch.return_value = mock_dispatch_result
-
-        result_hidden, result_logits = comm.dispatch(hidden_states,
-                                                     router_logits)
-
-        assert torch.allclose(result_hidden, mock_dispatch_result[0])
-        assert torch.allclose(result_logits, mock_dispatch_result[1])
-
-        comm.all2all_manager.dispatch.assert_called_once_with(
-            hidden_states, router_logits)
-
-    @patch("vllm.config.get_current_vllm_config", return_value=None)
-    @patch("torch.npu.current_device", return_value=MagicMock())
-    @patch("torch.npu.set_device", return_value=MagicMock())
-    @patch("torch.distributed.get_process_group_ranks",
-           return_value={
-               0: 0,
-               1: 1
-           })
-    @patch("torch.distributed.get_group_rank", return_value={0: 0, 1: 1})
-    @patch("torch.distributed.is_initialized", return_value=True)
-    @patch("torch.distributed.get_rank", return_value=1)
-    @patch("torch.distributed.is_initialized", return_value=True)
-    @patch("torch.distributed.get_backend", return_value="hccl")
-    @patch("torch.distributed.get_rank", return_value=1)
-    @patch("torch.distributed.get_world_size", return_value=2)
-    @patch("torch.distributed.get_process_group_ranks", return_value=[0, 1])
-    @patch("torch.npu.device")
-    def test_combine(self, *_):
-        comm = NPUCommunicator(cpu_group=dist.group.WORLD)
-        comm.all2all_manager = Mock()
-        hidden_states = torch.randn(2, 4, 8)
-
-        mock_combine_result = torch.randn(2, 4, 8)
-        comm.all2all_manager.combine.return_value = mock_combine_result
-
-        result = comm.combine(hidden_states)
-
-        assert torch.allclose(result, mock_combine_result)
-
-        comm.all2all_manager.combine.assert_called_once_with(hidden_states)
@@ -289,13 +289,13 @@ class TestUtils(TestBase):
         # ascend custom op is not registered
         utils.register_ascend_customop()
         # should call register_oot three
-        self.assertEqual(mock_customop.register_oot.call_count, 8)
+        self.assertEqual(mock_customop.register_oot.call_count, 9)
         self.assertTrue(utils._ASCEND_CUSTOMOP_IS_REIGISTERED)

         # ascend custom op is already registered
         utils.register_ascend_customop()
         # should not register_oot again, thus only called three in this ut
-        self.assertEqual(mock_customop.register_oot.call_count, 8)
+        self.assertEqual(mock_customop.register_oot.call_count, 9)


 class TestProfileExecuteDuration(TestBase):