[main] [refactor] refactor common_fused_moe.py (#2706)
### What this PR does / why we need it?
1. Move prepare/finalize operation from moe_comm_method to
/ops/moe/fused_moe_prepare_and_finalize
2. Adapt to token_dispatcher in moe_comm_method
3. Move
moe_comm_method/experts_selector/token_dispatcher/fused_moe_prepare_and_finalize
to /ops/moe
### Does this PR introduce _any_ user-facing change?
no
### How was this patch tested?
e2e & ut
- vLLM version: v0.10.1.1
- vLLM main:
f4962a6d55
Signed-off-by: weichen <calvin_zhu0210@outlook.com>
Signed-off-by: Pr0Wh1teGivee <calvin_zhu0210@outlook.com>
Co-authored-by: weijinqian0 <12153182+weijinqian0@users.noreply.github.com>
This commit is contained in:
@@ -20,7 +20,8 @@ from unittest.mock import MagicMock, PropertyMock, patch
|
||||
import torch
|
||||
|
||||
from tests.ut.base import TestBase
|
||||
from vllm_ascend.ops.moe_dispatcher.token_dispatcher import (
|
||||
|
||||
from vllm_ascend.ops.moe.token_dispatcher import ( # isort: skip
|
||||
AscendSocVersion, TokenDispatcherWithAll2AllV,
|
||||
TokenDispatcherWithAllGather, TokenDispatcherWithMC2, _Dispatchers,
|
||||
_register_token_dispatcher, get_token_dispatcher, setup_token_dispatchers)
|
||||
@@ -34,7 +35,7 @@ class TestTokenDispatcherWithMC2(TestBase):
|
||||
self.mc2_group.rank_in_group = 0
|
||||
self.mc2_group.world_size = 8
|
||||
self.mc2_group_patch = patch(
|
||||
"vllm_ascend.ops.moe_dispatcher.token_dispatcher.get_mc2_group",
|
||||
"vllm_ascend.ops.moe.token_dispatcher.get_mc2_group",
|
||||
return_value=self.mc2_group)
|
||||
self.mc2_group_patch.start()
|
||||
|
||||
@@ -52,7 +53,7 @@ class TestTokenDispatcherWithMC2(TestBase):
|
||||
|
||||
# Mock get_ascend_soc_version()
|
||||
self.ascend_soc_version_patch = patch(
|
||||
"vllm_ascend.ops.moe_dispatcher.token_dispatcher.get_ascend_soc_version",
|
||||
"vllm_ascend.ops.moe.token_dispatcher.get_ascend_soc_version",
|
||||
return_value=AscendSocVersion.A3)
|
||||
self.ascend_soc_version_patch.start()
|
||||
|
||||
@@ -329,7 +330,7 @@ class TestTokenDispatcherWithAll2AllV(TestBase):
|
||||
|
||||
# Mock gather_from_sequence_parallel_region
|
||||
patcher7 = patch(
|
||||
'vllm_ascend.ops.moe_dispatcher.token_dispatcher.gather_from_sequence_parallel_region'
|
||||
'vllm_ascend.ops.moe.token_dispatcher.gather_from_sequence_parallel_region'
|
||||
)
|
||||
self.mock_gather_from_sequence_parallel_region = patcher7.start()
|
||||
self.addCleanup(patcher7.stop)
|
||||
@@ -518,12 +519,8 @@ class TestDispatcherRegistry(TestBase):
|
||||
|
||||
self.assertIsNone(get_token_dispatcher("NonExistentDispatcher"))
|
||||
|
||||
@patch(
|
||||
'vllm_ascend.ops.moe_dispatcher.token_dispatcher.TokenDispatcherWithAllGather'
|
||||
)
|
||||
@patch(
|
||||
'vllm_ascend.ops.moe_dispatcher.token_dispatcher._register_token_dispatcher'
|
||||
)
|
||||
@patch('vllm_ascend.ops.moe.token_dispatcher.TokenDispatcherWithAllGather')
|
||||
@patch('vllm_ascend.ops.moe.token_dispatcher._register_token_dispatcher')
|
||||
def test_setup_token_dispatchers_ep_size_1_creates_allgather(
|
||||
self, mock_register, mock_allgather_class):
|
||||
kwargs = {"top_k": 2, "num_experts": 8}
|
||||
@@ -537,12 +534,8 @@ class TestDispatcherRegistry(TestBase):
|
||||
mock_allgather_class.assert_called_once_with(**kwargs)
|
||||
mock_register.assert_called_once_with(mock_instance)
|
||||
|
||||
@patch(
|
||||
'vllm_ascend.ops.moe_dispatcher.token_dispatcher.TokenDispatcherWithAll2AllV'
|
||||
)
|
||||
@patch(
|
||||
'vllm_ascend.ops.moe_dispatcher.token_dispatcher._register_token_dispatcher'
|
||||
)
|
||||
@patch('vllm_ascend.ops.moe.token_dispatcher.TokenDispatcherWithAll2AllV')
|
||||
@patch('vllm_ascend.ops.moe.token_dispatcher._register_token_dispatcher')
|
||||
def test_setup_token_dispatchers_ep_size_2_creates_all2allv(
|
||||
self, mock_register, mock_all2allv_class):
|
||||
kwargs = {"top_k": 2, "num_experts": 16, "num_local_experts": 2}
|
||||
@@ -556,15 +549,9 @@ class TestDispatcherRegistry(TestBase):
|
||||
mock_all2allv_class.assert_called_once_with(**kwargs)
|
||||
mock_register.assert_called_once_with(mock_instance)
|
||||
|
||||
@patch(
|
||||
'vllm_ascend.ops.moe_dispatcher.token_dispatcher.TokenDispatcherWithAll2AllV'
|
||||
)
|
||||
@patch(
|
||||
'vllm_ascend.ops.moe_dispatcher.token_dispatcher.TokenDispatcherWithMC2'
|
||||
)
|
||||
@patch(
|
||||
'vllm_ascend.ops.moe_dispatcher.token_dispatcher._register_token_dispatcher'
|
||||
)
|
||||
@patch('vllm_ascend.ops.moe.token_dispatcher.TokenDispatcherWithAll2AllV')
|
||||
@patch('vllm_ascend.ops.moe.token_dispatcher.TokenDispatcherWithMC2')
|
||||
@patch('vllm_ascend.ops.moe.token_dispatcher._register_token_dispatcher')
|
||||
def test_setup_token_dispatchers_ep_size_16_creates_all2allv_and_mc2(
|
||||
self, mock_register, mock_mc2_class, mock_all2allv_class):
|
||||
kwargs = {"top_k": 2, "num_experts": 32, "num_local_experts": 2}
|
||||
@@ -584,15 +571,9 @@ class TestDispatcherRegistry(TestBase):
|
||||
mock_register.assert_any_call(mock_all2allv_instance)
|
||||
mock_register.assert_any_call(mock_mc2_instance)
|
||||
|
||||
@patch(
|
||||
'vllm_ascend.ops.moe_dispatcher.token_dispatcher.TokenDispatcherWithAll2AllV'
|
||||
)
|
||||
@patch(
|
||||
'vllm_ascend.ops.moe_dispatcher.token_dispatcher.TokenDispatcherWithMC2'
|
||||
)
|
||||
@patch(
|
||||
'vllm_ascend.ops.moe_dispatcher.token_dispatcher._register_token_dispatcher'
|
||||
)
|
||||
@patch('vllm_ascend.ops.moe.token_dispatcher.TokenDispatcherWithAll2AllV')
|
||||
@patch('vllm_ascend.ops.moe.token_dispatcher.TokenDispatcherWithMC2')
|
||||
@patch('vllm_ascend.ops.moe.token_dispatcher._register_token_dispatcher')
|
||||
def test_setup_token_dispatchers_ep_size_16_skips_if_exist(
|
||||
self, mock_register, mock_mc2_class, mock_all2allv_class):
|
||||
kwargs = {"top_k": 2, "num_experts": 32, "num_local_experts": 2}
|
||||
|
||||
Reference in New Issue
Block a user