[main] [refactor] refactor common_fused_moe.py (#2706)

### What this PR does / why we need it?
1. Move prepare/finalize operation from moe_comm_method to
/ops/moe/fused_moe_prepare_and_finalize
2. Adapt to token_dispatcher in moe_comm_method
3. Move
moe_comm_method/experts_selector/token_dispatcher/fused_moe_prepare_and_finalize
to /ops/moe
### Does this PR introduce _any_ user-facing change?
no
### How was this patch tested?
e2e & ut

- vLLM version: v0.10.1.1
- vLLM main:
f4962a6d55

Signed-off-by: weichen <calvin_zhu0210@outlook.com>
Signed-off-by: Pr0Wh1teGivee <calvin_zhu0210@outlook.com>
Co-authored-by: weijinqian0 <12153182+weijinqian0@users.noreply.github.com>
This commit is contained in:
weichen
2025-09-08 20:09:50 +08:00
committed by GitHub
parent 1a82b16355
commit a041d4f328
21 changed files with 1052 additions and 932 deletions

View File

@@ -20,7 +20,8 @@ from unittest.mock import MagicMock, PropertyMock, patch
import torch
from tests.ut.base import TestBase
from vllm_ascend.ops.moe_dispatcher.token_dispatcher import (
from vllm_ascend.ops.moe.token_dispatcher import ( # isort: skip
AscendSocVersion, TokenDispatcherWithAll2AllV,
TokenDispatcherWithAllGather, TokenDispatcherWithMC2, _Dispatchers,
_register_token_dispatcher, get_token_dispatcher, setup_token_dispatchers)
@@ -34,7 +35,7 @@ class TestTokenDispatcherWithMC2(TestBase):
self.mc2_group.rank_in_group = 0
self.mc2_group.world_size = 8
self.mc2_group_patch = patch(
"vllm_ascend.ops.moe_dispatcher.token_dispatcher.get_mc2_group",
"vllm_ascend.ops.moe.token_dispatcher.get_mc2_group",
return_value=self.mc2_group)
self.mc2_group_patch.start()
@@ -52,7 +53,7 @@ class TestTokenDispatcherWithMC2(TestBase):
# Mock get_ascend_soc_version()
self.ascend_soc_version_patch = patch(
"vllm_ascend.ops.moe_dispatcher.token_dispatcher.get_ascend_soc_version",
"vllm_ascend.ops.moe.token_dispatcher.get_ascend_soc_version",
return_value=AscendSocVersion.A3)
self.ascend_soc_version_patch.start()
@@ -329,7 +330,7 @@ class TestTokenDispatcherWithAll2AllV(TestBase):
# Mock gather_from_sequence_parallel_region
patcher7 = patch(
'vllm_ascend.ops.moe_dispatcher.token_dispatcher.gather_from_sequence_parallel_region'
'vllm_ascend.ops.moe.token_dispatcher.gather_from_sequence_parallel_region'
)
self.mock_gather_from_sequence_parallel_region = patcher7.start()
self.addCleanup(patcher7.stop)
@@ -518,12 +519,8 @@ class TestDispatcherRegistry(TestBase):
self.assertIsNone(get_token_dispatcher("NonExistentDispatcher"))
@patch(
'vllm_ascend.ops.moe_dispatcher.token_dispatcher.TokenDispatcherWithAllGather'
)
@patch(
'vllm_ascend.ops.moe_dispatcher.token_dispatcher._register_token_dispatcher'
)
@patch('vllm_ascend.ops.moe.token_dispatcher.TokenDispatcherWithAllGather')
@patch('vllm_ascend.ops.moe.token_dispatcher._register_token_dispatcher')
def test_setup_token_dispatchers_ep_size_1_creates_allgather(
self, mock_register, mock_allgather_class):
kwargs = {"top_k": 2, "num_experts": 8}
@@ -537,12 +534,8 @@ class TestDispatcherRegistry(TestBase):
mock_allgather_class.assert_called_once_with(**kwargs)
mock_register.assert_called_once_with(mock_instance)
@patch(
'vllm_ascend.ops.moe_dispatcher.token_dispatcher.TokenDispatcherWithAll2AllV'
)
@patch(
'vllm_ascend.ops.moe_dispatcher.token_dispatcher._register_token_dispatcher'
)
@patch('vllm_ascend.ops.moe.token_dispatcher.TokenDispatcherWithAll2AllV')
@patch('vllm_ascend.ops.moe.token_dispatcher._register_token_dispatcher')
def test_setup_token_dispatchers_ep_size_2_creates_all2allv(
self, mock_register, mock_all2allv_class):
kwargs = {"top_k": 2, "num_experts": 16, "num_local_experts": 2}
@@ -556,15 +549,9 @@ class TestDispatcherRegistry(TestBase):
mock_all2allv_class.assert_called_once_with(**kwargs)
mock_register.assert_called_once_with(mock_instance)
@patch(
'vllm_ascend.ops.moe_dispatcher.token_dispatcher.TokenDispatcherWithAll2AllV'
)
@patch(
'vllm_ascend.ops.moe_dispatcher.token_dispatcher.TokenDispatcherWithMC2'
)
@patch(
'vllm_ascend.ops.moe_dispatcher.token_dispatcher._register_token_dispatcher'
)
@patch('vllm_ascend.ops.moe.token_dispatcher.TokenDispatcherWithAll2AllV')
@patch('vllm_ascend.ops.moe.token_dispatcher.TokenDispatcherWithMC2')
@patch('vllm_ascend.ops.moe.token_dispatcher._register_token_dispatcher')
def test_setup_token_dispatchers_ep_size_16_creates_all2allv_and_mc2(
self, mock_register, mock_mc2_class, mock_all2allv_class):
kwargs = {"top_k": 2, "num_experts": 32, "num_local_experts": 2}
@@ -584,15 +571,9 @@ class TestDispatcherRegistry(TestBase):
mock_register.assert_any_call(mock_all2allv_instance)
mock_register.assert_any_call(mock_mc2_instance)
@patch(
'vllm_ascend.ops.moe_dispatcher.token_dispatcher.TokenDispatcherWithAll2AllV'
)
@patch(
'vllm_ascend.ops.moe_dispatcher.token_dispatcher.TokenDispatcherWithMC2'
)
@patch(
'vllm_ascend.ops.moe_dispatcher.token_dispatcher._register_token_dispatcher'
)
@patch('vllm_ascend.ops.moe.token_dispatcher.TokenDispatcherWithAll2AllV')
@patch('vllm_ascend.ops.moe.token_dispatcher.TokenDispatcherWithMC2')
@patch('vllm_ascend.ops.moe.token_dispatcher._register_token_dispatcher')
def test_setup_token_dispatchers_ep_size_16_skips_if_exist(
self, mock_register, mock_mc2_class, mock_all2allv_class):
kwargs = {"top_k": 2, "num_experts": 32, "num_local_experts": 2}