[Main] [Refactor] Enable MoECommMethod in Eager Mode (#2791)
### What this PR does / why we need it?
1. Replace prepare/finalize operation in fused_moe.py by
moe_comm_method.prepare()/finalize()
2. Replace unified_fused_experts by moe_comm_method.fused_experts() in
fused_moe.py/w8a8_dynamic.py/w4a8_dynamic.py
3. Add calling _select_moe_comm_method in spec-decode proposers.
4. Currently, w4a8_dynamic does not support gatherep, use all2allv
instead.
5. Remove redundant code.
### Does this PR introduce _any_ user-facing change?
AllgatherEP switch is disabled in aclgraph/eager mode, just follow the
rules in modelrunner_v1._select_moe_comm_method()
### How was this patch tested?
e2e & ut
- vLLM version: v0.10.2
- vLLM main:
7f6f2c1182
Signed-off-by: Pr0Wh1teGivee <calvin_zhu0210@outlook.com>
Co-authored-by: weijinqian0 <12153182+weijinqian0@users.noreply.github.com>
This commit is contained in:
@@ -23,8 +23,7 @@ from tests.ut.base import TestBase
|
||||
|
||||
from vllm_ascend.ops.moe.token_dispatcher import ( # isort: skip
|
||||
AscendSocVersion, TokenDispatcherWithAll2AllV,
|
||||
TokenDispatcherWithAllGather, TokenDispatcherWithMC2, _Dispatchers,
|
||||
_register_token_dispatcher, get_token_dispatcher, setup_token_dispatchers)
|
||||
TokenDispatcherWithAllGather, TokenDispatcherWithMC2)
|
||||
|
||||
|
||||
class TestTokenDispatcherWithMC2(TestBase):
|
||||
@@ -521,99 +520,3 @@ class TestTokenDispatcherWithAll2AllV(TestBase):
|
||||
self.assertIsNotNone(result["hidden_states"])
|
||||
self.assertIsNotNone(result["group_list"])
|
||||
self.assertEqual(result["group_list_type"], 1)
|
||||
|
||||
|
||||
class TestDispatcherRegistry(TestBase):
|
||||
|
||||
def setUp(self):
|
||||
_Dispatchers.clear()
|
||||
|
||||
def tearDown(self):
|
||||
_Dispatchers.clear()
|
||||
|
||||
def test_register_and_get_token_dispatcher(self):
|
||||
mock_dispatcher = MagicMock()
|
||||
mock_dispatcher.__class__.__name__ = "MockDispatcher"
|
||||
|
||||
_register_token_dispatcher(mock_dispatcher)
|
||||
|
||||
self.assertIn("MockDispatcher", _Dispatchers)
|
||||
self.assertIs(_Dispatchers["MockDispatcher"], mock_dispatcher)
|
||||
|
||||
retrieved_dispatcher = get_token_dispatcher("MockDispatcher")
|
||||
self.assertIs(retrieved_dispatcher, mock_dispatcher)
|
||||
|
||||
self.assertIsNone(get_token_dispatcher("NonExistentDispatcher"))
|
||||
|
||||
@patch('vllm_ascend.ops.moe.token_dispatcher.TokenDispatcherWithAllGather')
|
||||
@patch('vllm_ascend.ops.moe.token_dispatcher._register_token_dispatcher')
|
||||
def test_setup_token_dispatchers_ep_size_1_creates_allgather(
|
||||
self, mock_register, mock_allgather_class):
|
||||
kwargs = {"top_k": 2, "num_experts": 8}
|
||||
mock_instance = MagicMock()
|
||||
mock_allgather_class.return_value = mock_instance
|
||||
|
||||
self.assertNotIn("TokenDispatcherWithAllGather", _Dispatchers)
|
||||
|
||||
setup_token_dispatchers(ep_size=1, **kwargs)
|
||||
|
||||
mock_allgather_class.assert_called_once_with(**kwargs)
|
||||
mock_register.assert_called_once_with(mock_instance)
|
||||
|
||||
@patch('vllm_ascend.ops.moe.token_dispatcher.TokenDispatcherWithAll2AllV')
|
||||
@patch('vllm_ascend.ops.moe.token_dispatcher._register_token_dispatcher')
|
||||
def test_setup_token_dispatchers_ep_size_2_creates_all2allv(
|
||||
self, mock_register, mock_all2allv_class):
|
||||
kwargs = {"top_k": 2, "num_experts": 16, "num_local_experts": 2}
|
||||
mock_instance = MagicMock()
|
||||
mock_all2allv_class.return_value = mock_instance
|
||||
|
||||
self.assertNotIn("TokenDispatcherWithAll2AllV", _Dispatchers)
|
||||
|
||||
setup_token_dispatchers(ep_size=2, **kwargs)
|
||||
|
||||
mock_all2allv_class.assert_called_once_with(**kwargs)
|
||||
mock_register.assert_called_once_with(mock_instance)
|
||||
|
||||
@patch('vllm_ascend.ops.moe.token_dispatcher.TokenDispatcherWithAll2AllV')
|
||||
@patch('vllm_ascend.ops.moe.token_dispatcher.TokenDispatcherWithMC2')
|
||||
@patch('vllm_ascend.ops.moe.token_dispatcher._register_token_dispatcher')
|
||||
def test_setup_token_dispatchers_ep_size_16_creates_all2allv_and_mc2(
|
||||
self, mock_register, mock_mc2_class, mock_all2allv_class):
|
||||
kwargs = {"top_k": 2, "num_experts": 32, "num_local_experts": 2}
|
||||
mock_all2allv_instance = MagicMock()
|
||||
mock_mc2_instance = MagicMock()
|
||||
mock_all2allv_class.return_value = mock_all2allv_instance
|
||||
mock_mc2_class.return_value = mock_mc2_instance
|
||||
|
||||
self.assertNotIn("TokenDispatcherWithAll2AllV", _Dispatchers)
|
||||
self.assertNotIn("TokenDispatcherWithMC2", _Dispatchers)
|
||||
|
||||
setup_token_dispatchers(ep_size=16, **kwargs)
|
||||
|
||||
mock_all2allv_class.assert_called_once_with(**kwargs)
|
||||
mock_mc2_class.assert_called_once_with(**kwargs)
|
||||
self.assertEqual(mock_register.call_count, 2)
|
||||
mock_register.assert_any_call(mock_all2allv_instance)
|
||||
mock_register.assert_any_call(mock_mc2_instance)
|
||||
|
||||
@patch('vllm_ascend.ops.moe.token_dispatcher.TokenDispatcherWithAll2AllV')
|
||||
@patch('vllm_ascend.ops.moe.token_dispatcher.TokenDispatcherWithMC2')
|
||||
@patch('vllm_ascend.ops.moe.token_dispatcher._register_token_dispatcher')
|
||||
def test_setup_token_dispatchers_ep_size_16_skips_if_exist(
|
||||
self, mock_register, mock_mc2_class, mock_all2allv_class):
|
||||
kwargs = {"top_k": 2, "num_experts": 32, "num_local_experts": 2}
|
||||
mock_existing_all2allv = MagicMock()
|
||||
mock_existing_mc2 = MagicMock()
|
||||
_Dispatchers["TokenDispatcherWithAll2AllV"] = mock_existing_all2allv
|
||||
_Dispatchers["TokenDispatcherWithMC2"] = mock_existing_mc2
|
||||
|
||||
setup_token_dispatchers(ep_size=16, **kwargs)
|
||||
|
||||
mock_all2allv_class.assert_not_called()
|
||||
mock_mc2_class.assert_not_called()
|
||||
mock_register.assert_not_called()
|
||||
self.assertIs(_Dispatchers["TokenDispatcherWithAll2AllV"],
|
||||
mock_existing_all2allv)
|
||||
self.assertIs(_Dispatchers["TokenDispatcherWithMC2"],
|
||||
mock_existing_mc2)
|
||||
|
||||
Reference in New Issue
Block a user