[Refactor] Adjustments to moe_comm_method selection process (#3001)

### What this PR does / why we need it?
Fix issues mentioned in
https://github.com/vllm-project/vllm-ascend/pull/2791 and some minor
refactoring.
1. Use Enum instead of string.
2. Avoid setting a new property to forward_context in
AscendFusedMoE.forward().
3. Enabling TokenDispatcherWithMoge.
4. Remove redundant code.

### Does this PR introduce _any_ user-facing change?
No
### How was this patch tested?

Qwen3-30B-A3B/Qwen3-30B-A3B-W8A8/DeepSeek-V3-W4A8-Pruing/deepseek-mtp/pangu-pro-moe-pruing:
1. Enable/Disable EP
2. Aclgraph & eager


- vLLM version: v0.10.2
- vLLM main:
9607d5eb44

Signed-off-by: Pr0Wh1teGivee <calvin_zhu0210@outlook.com>
Co-authored-by: weijinqian0 <12153182+weijinqian0@users.noreply.github.com>
This commit is contained in:
weichen
2025-09-22 19:12:58 +08:00
committed by GitHub
parent bb1f0d5a62
commit 37a0715eda
14 changed files with 170 additions and 351 deletions

View File

@@ -24,6 +24,7 @@ class TestMoECommMethod(TestBase):
self.moe_config.dp_group = MagicMock()
self.moe_config.num_global_redundant_experts = 0
@patch("vllm_ascend.ops.moe.moe_comm_method.get_current_vllm_config")
@patch("vllm_ascend.ops.moe.moe_comm_method.get_forward_context")
@patch(
"vllm_ascend.ops.moe.moe_comm_method.FusedMoEPrepareAndFinalizeWithAllGather"
@@ -31,7 +32,11 @@ class TestMoECommMethod(TestBase):
@patch("vllm_ascend.ops.moe.moe_comm_method.TokenDispatcherWithAllGather")
def test_all_gather_comm_impl(self, mock_token_dispatcher,
mock_prepare_finalize,
mock_get_forward_context):
mock_get_forward_context,
mock_get_current_vllm_config):
# Mock vLLM config
mock_get_current_vllm_config.return_value = MagicMock()
# Mock forward context
mock_context = MagicMock()
mock_context.moe_comm_method = "all_gather"
@@ -64,13 +69,18 @@ class TestMoECommMethod(TestBase):
comm_impl.finalize(h_out, reduce_results=True)
mock_pf_instance.finalize.assert_called_once_with(h_out, True)
@patch("vllm_ascend.ops.moe.moe_comm_method.get_current_vllm_config")
@patch("vllm_ascend.ops.moe.moe_comm_method.get_forward_context")
@patch(
"vllm_ascend.ops.moe.moe_comm_method.FusedMoEPrepareAndFinalizeWithMC2"
)
@patch("vllm_ascend.ops.moe.moe_comm_method.TokenDispatcherWithMC2")
def test_mc2_comm_impl(self, mock_token_dispatcher, mock_prepare_finalize,
mock_get_forward_context):
mock_get_forward_context,
mock_get_current_vllm_config):
# Mock vLLM config
mock_get_current_vllm_config.return_value = MagicMock()
# Mock forward context
mock_context = MagicMock()
mock_context.moe_comm_method = "mc2"
@@ -104,6 +114,7 @@ class TestMoECommMethod(TestBase):
comm_impl.finalize(h_out, reduce_results=True)
mock_pf_instance.finalize.assert_called_once_with(h_out, True)
@patch("vllm_ascend.ops.moe.moe_comm_method.get_current_vllm_config")
@patch("vllm_ascend.ops.moe.moe_comm_method.get_forward_context")
@patch(
"vllm_ascend.ops.moe.moe_comm_method.FusedMoEPrepareAndFinalizeWithAll2All"
@@ -111,7 +122,11 @@ class TestMoECommMethod(TestBase):
@patch("vllm_ascend.ops.moe.moe_comm_method.TokenDispatcherWithAll2AllV")
def test_alltoall_comm_impl(self, mock_token_dispatcher,
mock_prepare_finalize,
mock_get_forward_context):
mock_get_forward_context,
mock_get_current_vllm_config):
# Mock vLLM config
mock_get_current_vllm_config.return_value = MagicMock()
# Mock forward context
mock_context = MagicMock()
mock_context.moe_comm_method = "alltoall"
@@ -140,6 +155,7 @@ class TestMoECommMethod(TestBase):
mock_pf_instance.prepare.assert_called_once_with(
hidden_states, router_logits, False, False, False, None)
@patch("vllm_ascend.ops.moe.moe_comm_method.get_current_vllm_config")
@patch("vllm_ascend.ops.moe.moe_comm_method.get_forward_context")
@patch(
"vllm_ascend.ops.moe.moe_comm_method.FusedMoEPrepareAndFinalizeWithAllGather"
@@ -148,7 +164,11 @@ class TestMoECommMethod(TestBase):
@patch("vllm_ascend.ops.moe.moe_comm_method.unified_apply_mlp")
def test_fused_experts_method(self, mock_unified_apply_mlp,
mock_token_dispatcher, mock_prepare_finalize,
mock_get_forward_context):
mock_get_forward_context,
mock_get_current_vllm_config):
# Mock vLLM config
mock_get_current_vllm_config.return_value = MagicMock()
# Mock forward context
mock_context = MagicMock()
mock_context.moe_comm_method = "all_gather"