[Refactor] Adjustments to moe_comm_method selection process (#3001)

### What this PR does / why we need it?
Fix issues mentioned in
https://github.com/vllm-project/vllm-ascend/pull/2791 and some minor
refactoring.
1. Use Enum instead of string.
2. Avoid setting a new property to forward_context in
AscendFusedMoE.forward().
3. Enabling TokenDispatcherWithMoge.
4. Remove redundant code.

### Does this PR introduce _any_ user-facing change?
No
### How was this patch tested?

Qwen3-30B-A3B/Qwen3-30B-A3B-W8A8/DeepSeek-V3-W4A8-Pruing/deepseek-mtp/pangu-pro-moe-pruing:
1. Enable/Disable EP
2. Aclgraph & eager


- vLLM version: v0.10.2
- vLLM main:
9607d5eb44

Signed-off-by: Pr0Wh1teGivee <calvin_zhu0210@outlook.com>
Co-authored-by: weijinqian0 <12153182+weijinqian0@users.noreply.github.com>
This commit is contained in:
weichen
2025-09-22 19:12:58 +08:00
committed by GitHub
parent bb1f0d5a62
commit 37a0715eda
14 changed files with 170 additions and 351 deletions

View File

@@ -17,56 +17,7 @@ from unittest.mock import patch
import torch
from tests.ut.base import TestBase
from vllm_ascend.ops.common_fused_moe import AscendFusedMoE, fused_experts_moge
class TestFusedExpertsMoGE(TestBase):
def test_fused_experts_moge(self):
with patch('torch_npu.npu_grouped_matmul') as mock_grouped_matmul, \
patch('torch_npu.npu_swiglu') as mock_swiglu, \
patch('vllm_ascend.utils.is_310p') as mock_is_310p:
mock_is_310p.return_value = False
mock_grouped_matmul.side_effect = lambda x, weight, **kwargs: [
torch.randn(x[0].shape[0], weight[0].shape[1])
]
mock_swiglu.side_effect = lambda x: x
hidden_states = torch.randn(4, 128)
w1 = torch.randn(4, 256, 128)
w2 = torch.randn(4, 128, 128)
topk_weights = torch.rand(4, 1)
topk_ids = torch.tensor([[0], [1], [2], [3]], dtype=torch.long)
top_k = 1
global_num_experts = 4
moe_parallel_config = type(
'MockConfig', (), {
'ep_size': 1,
'tp_size': 1,
'dp_size': 1,
'tp_rank': 0,
'dp_rank': 0,
'ep_rank': 0,
'use_ep': True
})()
output = fused_experts_moge(
hidden_states=hidden_states,
w1=w1,
w2=w2,
moe_parallel_config=moe_parallel_config,
topk_weights=topk_weights,
topk_ids=topk_ids,
top_k=top_k,
global_num_experts=global_num_experts,
apply_router_weight_on_input=True,
)
self.assertEqual(output.shape, (4, 128))
from vllm_ascend.ops.common_fused_moe import AscendFusedMoE
class TestLoadWeight(TestBase):