[Refactor] Adjustments to moe_comm_method selection process (#3001)

### What this PR does / why we need it?
Fix issues mentioned in
https://github.com/vllm-project/vllm-ascend/pull/2791 and some minor
refactoring.
1. Use Enum instead of string.
2. Avoid setting a new property to forward_context in
AscendFusedMoE.forward().
3. Enabling TokenDispatcherWithMoge.
4. Remove redundant code.

### Does this PR introduce _any_ user-facing change?
No
### How was this patch tested?

Qwen3-30B-A3B/Qwen3-30B-A3B-W8A8/DeepSeek-V3-W4A8-Pruing/deepseek-mtp/pangu-pro-moe-pruing:
1. Enable/Disable EP
2. Aclgraph & eager


- vLLM version: v0.10.2
- vLLM main:
9607d5eb44

Signed-off-by: Pr0Wh1teGivee <calvin_zhu0210@outlook.com>
Co-authored-by: weijinqian0 <12153182+weijinqian0@users.noreply.github.com>
This commit is contained in:
weichen
2025-09-22 19:12:58 +08:00
committed by GitHub
parent bb1f0d5a62
commit 37a0715eda
14 changed files with 170 additions and 351 deletions

View File

@@ -17,56 +17,7 @@ from unittest.mock import patch
import torch
from tests.ut.base import TestBase
from vllm_ascend.ops.common_fused_moe import AscendFusedMoE, fused_experts_moge
class TestFusedExpertsMoGE(TestBase):
def test_fused_experts_moge(self):
with patch('torch_npu.npu_grouped_matmul') as mock_grouped_matmul, \
patch('torch_npu.npu_swiglu') as mock_swiglu, \
patch('vllm_ascend.utils.is_310p') as mock_is_310p:
mock_is_310p.return_value = False
mock_grouped_matmul.side_effect = lambda x, weight, **kwargs: [
torch.randn(x[0].shape[0], weight[0].shape[1])
]
mock_swiglu.side_effect = lambda x: x
hidden_states = torch.randn(4, 128)
w1 = torch.randn(4, 256, 128)
w2 = torch.randn(4, 128, 128)
topk_weights = torch.rand(4, 1)
topk_ids = torch.tensor([[0], [1], [2], [3]], dtype=torch.long)
top_k = 1
global_num_experts = 4
moe_parallel_config = type(
'MockConfig', (), {
'ep_size': 1,
'tp_size': 1,
'dp_size': 1,
'tp_rank': 0,
'dp_rank': 0,
'ep_rank': 0,
'use_ep': True
})()
output = fused_experts_moge(
hidden_states=hidden_states,
w1=w1,
w2=w2,
moe_parallel_config=moe_parallel_config,
topk_weights=topk_weights,
topk_ids=topk_ids,
top_k=top_k,
global_num_experts=global_num_experts,
apply_router_weight_on_input=True,
)
self.assertEqual(output.shape, (4, 128))
from vllm_ascend.ops.common_fused_moe import AscendFusedMoE
class TestLoadWeight(TestBase):

View File

@@ -23,6 +23,7 @@ from pytest_mock import MockerFixture
from vllm.model_executor.layers.fused_moe import FusedMoEMethodBase
from tests.ut.base import TestBase
from vllm_ascend.ascend_forward_context import MoECommType
from vllm_ascend.ops.fused_moe import (AscendFusedMoE,
AscendUnquantizedFusedMoEMethod)
from vllm_ascend.ops.moe.experts_selector import select_experts
@@ -55,6 +56,26 @@ def mock_npu_format_cast(weight_data, format):
return weight_data
@pytest.fixture(autouse=True)
def setup_vllm_config_mock(mocker: MockerFixture):
mock_hf_config = MagicMock()
mock_hf_config.model_type = "llama"
mock_model_config = MagicMock()
mock_model_config.hf_config = mock_hf_config
mock_vllm_config = MagicMock()
mock_vllm_config.model_config = mock_model_config
mock_vllm_config.parallel_config = MagicMock(tensor_parallel_size=2)
mock_vllm_config.scheduler_config = MagicMock(max_num_seqs=4)
mock_vllm_config.model_config.max_model_len = 2048
mocker.patch('vllm_ascend.ops.fused_moe.get_current_vllm_config',
return_value=mock_vllm_config)
mocker.patch('vllm_ascend.ops.moe.moe_comm_method.get_current_vllm_config',
return_value=mock_vllm_config)
@pytest.fixture
def mock_dist_env(mocker: MockerFixture):
mock_moe_comm_method = MagicMock()
@@ -74,7 +95,7 @@ def mock_dist_env(mocker: MockerFixture):
mock_forward_context_obj = MagicMock(
moe_comm_method=mock_moe_comm_method,
moe_comm_method_name="mc2commimpl",
moe_comm_type=MoECommType.MC2,
max_tokens_across_dp=10,
dp_metadata=MagicMock(cu_tokens_across_dp_cpu=[5, 10]),
mc2_mask=torch.zeros(16, dtype=torch.bool),
@@ -104,12 +125,6 @@ def mock_dist_env(mocker: MockerFixture):
return_value=mock_forward_context_obj), \
patch('vllm_ascend.ops.moe.fused_moe_prepare_and_finalize.get_forward_context',
return_value=mock_forward_context_obj), \
patch('vllm_ascend.ops.fused_moe.get_current_vllm_config',
return_value=MagicMock(
parallel_config=MagicMock(tensor_parallel_size=2),
scheduler_config=MagicMock(max_num_seqs=4),
model_config=MagicMock(max_model_len=2048)
)), \
patch("vllm_ascend.utils.get_ascend_soc_version", return_value=AscendSocVersion.A3), \
patch('vllm_ascend.ops.moe.moe_mlp.get_forward_context',
return_value=mock_forward_context_obj), \
@@ -501,7 +516,7 @@ class TestUnifiedApplyMLP(TestBase):
mock_get_forward_context):
mock_forward_context = MagicMock()
mock_forward_context.moe_comm_method_name = "mc2commimpl"
mock_forward_context.moe_comm_type = MoECommType.MC2
mock_get_forward_context.return_value = mock_forward_context
mock_is_310p.return_value = False

View File

@@ -24,6 +24,7 @@ class TestMoECommMethod(TestBase):
self.moe_config.dp_group = MagicMock()
self.moe_config.num_global_redundant_experts = 0
@patch("vllm_ascend.ops.moe.moe_comm_method.get_current_vllm_config")
@patch("vllm_ascend.ops.moe.moe_comm_method.get_forward_context")
@patch(
"vllm_ascend.ops.moe.moe_comm_method.FusedMoEPrepareAndFinalizeWithAllGather"
@@ -31,7 +32,11 @@ class TestMoECommMethod(TestBase):
@patch("vllm_ascend.ops.moe.moe_comm_method.TokenDispatcherWithAllGather")
def test_all_gather_comm_impl(self, mock_token_dispatcher,
mock_prepare_finalize,
mock_get_forward_context):
mock_get_forward_context,
mock_get_current_vllm_config):
# Mock vLLM config
mock_get_current_vllm_config.return_value = MagicMock()
# Mock forward context
mock_context = MagicMock()
mock_context.moe_comm_method = "all_gather"
@@ -64,13 +69,18 @@ class TestMoECommMethod(TestBase):
comm_impl.finalize(h_out, reduce_results=True)
mock_pf_instance.finalize.assert_called_once_with(h_out, True)
@patch("vllm_ascend.ops.moe.moe_comm_method.get_current_vllm_config")
@patch("vllm_ascend.ops.moe.moe_comm_method.get_forward_context")
@patch(
"vllm_ascend.ops.moe.moe_comm_method.FusedMoEPrepareAndFinalizeWithMC2"
)
@patch("vllm_ascend.ops.moe.moe_comm_method.TokenDispatcherWithMC2")
def test_mc2_comm_impl(self, mock_token_dispatcher, mock_prepare_finalize,
mock_get_forward_context):
mock_get_forward_context,
mock_get_current_vllm_config):
# Mock vLLM config
mock_get_current_vllm_config.return_value = MagicMock()
# Mock forward context
mock_context = MagicMock()
mock_context.moe_comm_method = "mc2"
@@ -104,6 +114,7 @@ class TestMoECommMethod(TestBase):
comm_impl.finalize(h_out, reduce_results=True)
mock_pf_instance.finalize.assert_called_once_with(h_out, True)
@patch("vllm_ascend.ops.moe.moe_comm_method.get_current_vllm_config")
@patch("vllm_ascend.ops.moe.moe_comm_method.get_forward_context")
@patch(
"vllm_ascend.ops.moe.moe_comm_method.FusedMoEPrepareAndFinalizeWithAll2All"
@@ -111,7 +122,11 @@ class TestMoECommMethod(TestBase):
@patch("vllm_ascend.ops.moe.moe_comm_method.TokenDispatcherWithAll2AllV")
def test_alltoall_comm_impl(self, mock_token_dispatcher,
mock_prepare_finalize,
mock_get_forward_context):
mock_get_forward_context,
mock_get_current_vllm_config):
# Mock vLLM config
mock_get_current_vllm_config.return_value = MagicMock()
# Mock forward context
mock_context = MagicMock()
mock_context.moe_comm_method = "alltoall"
@@ -140,6 +155,7 @@ class TestMoECommMethod(TestBase):
mock_pf_instance.prepare.assert_called_once_with(
hidden_states, router_logits, False, False, False, None)
@patch("vllm_ascend.ops.moe.moe_comm_method.get_current_vllm_config")
@patch("vllm_ascend.ops.moe.moe_comm_method.get_forward_context")
@patch(
"vllm_ascend.ops.moe.moe_comm_method.FusedMoEPrepareAndFinalizeWithAllGather"
@@ -148,7 +164,11 @@ class TestMoECommMethod(TestBase):
@patch("vllm_ascend.ops.moe.moe_comm_method.unified_apply_mlp")
def test_fused_experts_method(self, mock_unified_apply_mlp,
mock_token_dispatcher, mock_prepare_finalize,
mock_get_forward_context):
mock_get_forward_context,
mock_get_current_vllm_config):
# Mock vLLM config
mock_get_current_vllm_config.return_value = MagicMock()
# Mock forward context
mock_context = MagicMock()
mock_context.moe_comm_method = "all_gather"