[2/N][Feat] Add MC2 communication method for MoE layers (#2469)

### What this PR does / why we need it?
This PR adds the MC2 communication method for MoE layers, which replaces
the previous all-gather approach when the number of input tokens is small.

The key changes include:
- A new `AscendFusedMoE` layer that handles token splitting, local
computation, and final aggregation via all-gather.
- Logic in the model runner to dynamically select between the new MC2
method and the existing all-gather method based on the number of input
tokens (a minimal, illustrative sketch of this dispatch follows the list).
- Sharding the MoE communication mask across tensor-parallel ranks (see the
second sketch below).
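
A minimal sketch of the dispatch idea, in Python. `MC2_TOKEN_THRESHOLD`,
`select_moe_comm_method`, and the string labels are hypothetical names used
only for illustration, and the cut-over value is an assumption rather than the
number used in this PR:

```python
# Hypothetical sketch only; the names and the threshold value below are
# assumptions, not the identifiers introduced by this PR.
MC2_TOKEN_THRESHOLD = 512  # assumed cut-over point between MC2 and all-gather


def select_moe_comm_method(num_input_tokens: int) -> str:
    """Pick the MoE communication method for the current forward pass.

    Small batches take the new MC2 path; larger batches fall back to the
    existing all-gather path.
    """
    return "mc2" if num_input_tokens <= MC2_TOKEN_THRESHOLD else "allgather"


if __name__ == "__main__":
    for n in (16, 128, 4096):
        print(f"{n} tokens -> {select_moe_comm_method(n)}")
```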

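For the mask sharding, the idea is that each tensor-parallel rank keeps only
the slice of the per-token communication mask it is responsible for. A minimal
sketch, assuming an even padded split; `shard_moe_comm_mask` is a hypothetical
helper, not code from this PR:

```python
# Illustrative only: shard_moe_comm_mask is a hypothetical helper, and the
# even padded split is an assumption about how the mask is partitioned.
import torch


def shard_moe_comm_mask(mask: torch.Tensor, tp_size: int,
                        tp_rank: int) -> torch.Tensor:
    """Return the slice of a per-token mask owned by one tensor-parallel rank."""
    num_tokens = mask.shape[0]
    per_rank = (num_tokens + tp_size - 1) // tp_size  # ceil-divide
    # Pad with False so every rank receives the same number of entries.
    padded = torch.cat([mask, mask.new_zeros(per_rank * tp_size - num_tokens)])
    return padded[tp_rank * per_rank:(tp_rank + 1) * per_rank]


if __name__ == "__main__":
    mask = torch.arange(10) % 2 == 0  # toy boolean mask over 10 tokens
    print(shard_moe_comm_mask(mask, tp_size=4, tp_rank=3))
```
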
### Does this PR introduce _any_ user-facing change?
None.

### How was this patch tested?
Test case fixed.


- vLLM version: v0.10.1.1
- vLLM main:
b00e69f8ca

---------

Signed-off-by: Yizhou Liu <liu_yizhou@outlook.com>
Commit: a6bb502e70 (parent 5d8ec28009)
Author: yiz-liu
Date: 2025-08-26 19:05:23 +08:00
Committed by: GitHub
11 changed files with 506 additions and 410 deletions


@@ -18,29 +18,30 @@ from types import SimpleNamespace
import pytest
import torch
from transformers import PretrainedConfig
from vllm import forward_context
from vllm_ascend.distributed import moe_comm_method
from vllm_ascend.distributed.moe_comm_method import (AllGatherCommImpl,
NativeAllGatherCommImpl)
from vllm.model_executor.layers.fused_moe.config import ( # isort: skip
FusedMoEConfig, FusedMoEParallelConfig)
from vllm_ascend.distributed.moe_comm_method import ( # isort: skip
AllGatherCommImpl, NativeAllGatherCommImpl)
@pytest.mark.parametrize("num_tokens", [16, 128])
@pytest.mark.parametrize("hidden_size", [64, 128])
@pytest.mark.parametrize("global_num_experts", [8, 16])
@pytest.mark.parametrize("num_local_experts", [4, 8])
@pytest.mark.parametrize("top_k_num", [2, 4])
@pytest.mark.parametrize("dtype", [torch.bfloat16, torch.float16])
@pytest.mark.parametrize("num_local_experts", [4, 8])
@pytest.mark.parametrize("ep_rank", [0, 1])
def test_all_gather_comm_impl(
num_tokens,
hidden_size,
global_num_experts,
num_local_experts,
top_k_num,
dtype,
num_local_experts,
ep_rank,
mocker,
):
"""
Tests the AllGatherCommImpl against the NativeAllGatherCommImpl.
@@ -56,23 +57,37 @@ def test_all_gather_comm_impl(
"num_local_experts cannot be greater than global_num_experts")
device = torch.device("npu")
hf_config = PretrainedConfig(
num_experts_per_tok=top_k_num,
# mock get_tensor_model_parallel_rank to return ep_rank
mocker.patch(
"vllm.model_executor.layers.fused_moe.config.get_tensor_model_parallel_rank",
return_value=ep_rank,
)
# make moe config
parallel_config = SimpleNamespace(
enable_expert_parallel=num_local_experts < global_num_experts)
moe_parallel_config: FusedMoEParallelConfig = FusedMoEParallelConfig.make(
tp_size_=max(2, global_num_experts // num_local_experts),
dp_size_=1,
vllm_parallel_config=parallel_config,
)
moe_config = FusedMoEConfig(
num_experts=global_num_experts,
experts_per_token=top_k_num,
hidden_dim=hidden_size,
num_local_experts=num_local_experts,
moe_parallel_config=moe_parallel_config,
in_dtype=dtype,
quant_config=None, # No quantization in this test
max_num_tokens=num_tokens,
)
# Instantiate implementations
native_impl = NativeAllGatherCommImpl(device, dtype, hf_config)
native_impl = NativeAllGatherCommImpl(moe_config)
all_gather_impl = AllGatherCommImpl(device, dtype, hf_config)
# TODO: Find out if this is the correct way to mock the forward context and ep group
# Mock get_forward_context to return an object with moe_comm_method
forward_context._forward_context = SimpleNamespace(
moe_comm_method=all_gather_impl)
# Mock get_ep_group to return a fake group with the specified ep_rank
fake_ep_group = SimpleNamespace(rank_in_group=ep_rank)
moe_comm_method.get_ep_group = lambda: fake_ep_group
all_gather_impl = AllGatherCommImpl(moe_config)
# --- Input Data ---
hidden_states = torch.randn(num_tokens,
@@ -103,11 +118,11 @@ def test_all_gather_comm_impl(
native_permuted_hidden,
native_expert_tokens,
_,
) = native_impl._pre_process(hidden_states, topk_ids, topk_weights,
expert_map, num_experts)
) = native_impl.permute(hidden_states, topk_ids, topk_weights, expert_map,
num_experts)
# Simulate MLP output
native_mlp_output = torch.randn_like(native_permuted_hidden)
native_impl._post_process(native_mlp_output, native_hidden_states_out)
native_impl.unpermute(native_mlp_output, native_hidden_states_out)
# --- Run AllGather Implementation ---
all_gather_hidden_states_out = hidden_states.clone()
@@ -115,15 +130,14 @@ def test_all_gather_comm_impl(
all_gather_permuted_hidden,
all_gather_expert_tokens,
_,
) = torch.ops.vllm.moe_comm_pre_process(hidden_states, topk_ids,
topk_weights, expert_map,
num_experts)
) = all_gather_impl.permute(hidden_states, topk_ids, topk_weights,
expert_map, num_experts)
# Use the same simulated MLP output for a fair comparison
all_gather_mlp_output = native_mlp_output.clone()
torch.ops.vllm.moe_comm_post_process(all_gather_mlp_output,
all_gather_hidden_states_out)
all_gather_impl.unpermute(all_gather_mlp_output,
all_gather_hidden_states_out)
# --- Assertions ---
# Define tolerance based on dtype