[Refactor] Adjustments to moe_comm_method selection process (#3001)
### What this PR does / why we need it?
Fix issues mentioned in
https://github.com/vllm-project/vllm-ascend/pull/2791 and some minor
refactoring.
1. Use Enum instead of string.
2. Avoid setting a new property to forward_context in
AscendFusedMoE.forward().
3. Enabling TokenDispatcherWithMoge.
4. Remove redundant code.
### Does this PR introduce _any_ user-facing change?
No
### How was this patch tested?
Qwen3-30B-A3B/Qwen3-30B-A3B-W8A8/DeepSeek-V3-W4A8-Pruing/deepseek-mtp/pangu-pro-moe-pruing:
1. Enable/Disable EP
2. Aclgraph & eager
- vLLM version: v0.10.2
- vLLM main:
9607d5eb44
Signed-off-by: Pr0Wh1teGivee <calvin_zhu0210@outlook.com>
Co-authored-by: weijinqian0 <12153182+weijinqian0@users.noreply.github.com>
This commit is contained in:
@@ -24,6 +24,7 @@ class TestMoECommMethod(TestBase):
|
||||
self.moe_config.dp_group = MagicMock()
|
||||
self.moe_config.num_global_redundant_experts = 0
|
||||
|
||||
@patch("vllm_ascend.ops.moe.moe_comm_method.get_current_vllm_config")
|
||||
@patch("vllm_ascend.ops.moe.moe_comm_method.get_forward_context")
|
||||
@patch(
|
||||
"vllm_ascend.ops.moe.moe_comm_method.FusedMoEPrepareAndFinalizeWithAllGather"
|
||||
@@ -31,7 +32,11 @@ class TestMoECommMethod(TestBase):
|
||||
@patch("vllm_ascend.ops.moe.moe_comm_method.TokenDispatcherWithAllGather")
|
||||
def test_all_gather_comm_impl(self, mock_token_dispatcher,
|
||||
mock_prepare_finalize,
|
||||
mock_get_forward_context):
|
||||
mock_get_forward_context,
|
||||
mock_get_current_vllm_config):
|
||||
# Mock vLLM config
|
||||
mock_get_current_vllm_config.return_value = MagicMock()
|
||||
|
||||
# Mock forward context
|
||||
mock_context = MagicMock()
|
||||
mock_context.moe_comm_method = "all_gather"
|
||||
@@ -64,13 +69,18 @@ class TestMoECommMethod(TestBase):
|
||||
comm_impl.finalize(h_out, reduce_results=True)
|
||||
mock_pf_instance.finalize.assert_called_once_with(h_out, True)
|
||||
|
||||
@patch("vllm_ascend.ops.moe.moe_comm_method.get_current_vllm_config")
|
||||
@patch("vllm_ascend.ops.moe.moe_comm_method.get_forward_context")
|
||||
@patch(
|
||||
"vllm_ascend.ops.moe.moe_comm_method.FusedMoEPrepareAndFinalizeWithMC2"
|
||||
)
|
||||
@patch("vllm_ascend.ops.moe.moe_comm_method.TokenDispatcherWithMC2")
|
||||
def test_mc2_comm_impl(self, mock_token_dispatcher, mock_prepare_finalize,
|
||||
mock_get_forward_context):
|
||||
mock_get_forward_context,
|
||||
mock_get_current_vllm_config):
|
||||
# Mock vLLM config
|
||||
mock_get_current_vllm_config.return_value = MagicMock()
|
||||
|
||||
# Mock forward context
|
||||
mock_context = MagicMock()
|
||||
mock_context.moe_comm_method = "mc2"
|
||||
@@ -104,6 +114,7 @@ class TestMoECommMethod(TestBase):
|
||||
comm_impl.finalize(h_out, reduce_results=True)
|
||||
mock_pf_instance.finalize.assert_called_once_with(h_out, True)
|
||||
|
||||
@patch("vllm_ascend.ops.moe.moe_comm_method.get_current_vllm_config")
|
||||
@patch("vllm_ascend.ops.moe.moe_comm_method.get_forward_context")
|
||||
@patch(
|
||||
"vllm_ascend.ops.moe.moe_comm_method.FusedMoEPrepareAndFinalizeWithAll2All"
|
||||
@@ -111,7 +122,11 @@ class TestMoECommMethod(TestBase):
|
||||
@patch("vllm_ascend.ops.moe.moe_comm_method.TokenDispatcherWithAll2AllV")
|
||||
def test_alltoall_comm_impl(self, mock_token_dispatcher,
|
||||
mock_prepare_finalize,
|
||||
mock_get_forward_context):
|
||||
mock_get_forward_context,
|
||||
mock_get_current_vllm_config):
|
||||
# Mock vLLM config
|
||||
mock_get_current_vllm_config.return_value = MagicMock()
|
||||
|
||||
# Mock forward context
|
||||
mock_context = MagicMock()
|
||||
mock_context.moe_comm_method = "alltoall"
|
||||
@@ -140,6 +155,7 @@ class TestMoECommMethod(TestBase):
|
||||
mock_pf_instance.prepare.assert_called_once_with(
|
||||
hidden_states, router_logits, False, False, False, None)
|
||||
|
||||
@patch("vllm_ascend.ops.moe.moe_comm_method.get_current_vllm_config")
|
||||
@patch("vllm_ascend.ops.moe.moe_comm_method.get_forward_context")
|
||||
@patch(
|
||||
"vllm_ascend.ops.moe.moe_comm_method.FusedMoEPrepareAndFinalizeWithAllGather"
|
||||
@@ -148,7 +164,11 @@ class TestMoECommMethod(TestBase):
|
||||
@patch("vllm_ascend.ops.moe.moe_comm_method.unified_apply_mlp")
|
||||
def test_fused_experts_method(self, mock_unified_apply_mlp,
|
||||
mock_token_dispatcher, mock_prepare_finalize,
|
||||
mock_get_forward_context):
|
||||
mock_get_forward_context,
|
||||
mock_get_current_vllm_config):
|
||||
# Mock vLLM config
|
||||
mock_get_current_vllm_config.return_value = MagicMock()
|
||||
|
||||
# Mock forward context
|
||||
mock_context = MagicMock()
|
||||
mock_context.moe_comm_method = "all_gather"
|
||||
|
||||
Reference in New Issue
Block a user