[main] [refactor] refactor fused_moe.py to enable token_dispatchers (#2570)
### What this PR does / why we need it?
Enable token_dispatcher to replace fused_experts_with_xxx in eager mode
### Does this PR introduce _any_ user-facing change?
no
### How was this patch tested?
e2e & ut
- vLLM version: v0.10.1.1
- vLLM main:
704432af3c
Signed-off-by: Pr0Wh1teGivee <calvin_zhu0210@outlook.com>
Co-authored-by: sherie <963372609@qq.com>
Co-authored-by: weijinqian0 <12153182+weijinqian0@users.noreply.github.com>
Co-authored-by: shiyuan680 <72335504+shiyuan680@users.noreply.github.com>
This commit is contained in:
@@ -25,8 +25,8 @@ from tests.ut.base import PytestBase, TestBase
|
||||
from vllm_ascend.ops.moe_dispatcher.token_dispatcher import (
|
||||
AscendSocVersion, MoEAlltoAllSeqOverLapDispatcher, MoEDispatcherConfig,
|
||||
TokenDispatcherWithAll2AllV, TokenDispatcherWithAllGather,
|
||||
TokenDispatcherWithMC2)
|
||||
from vllm_ascend.utils import adapt_patch # noqa E402
|
||||
TokenDispatcherWithMC2, _Dispatchers, _register_token_dispatcher,
|
||||
get_token_dispatcher, setup_token_dispatchers)
|
||||
|
||||
|
||||
class TestMoEAlltoAllSeqOverLapDispatcher(PytestBase):
|
||||
@@ -90,7 +90,7 @@ class TestTokenDispatcherWithMC2(TestBase):
|
||||
self.forward_context = MagicMock()
|
||||
self.forward_context.mc2_mask = torch.tensor([1, 0, 1])
|
||||
self.forward_context_patch = patch(
|
||||
"vllm_ascend.ops.moe_dispatcher.token_dispatcher.get_forward_context",
|
||||
"vllm.forward_context.get_forward_context",
|
||||
return_value=self.forward_context)
|
||||
self.forward_context_patch.start()
|
||||
|
||||
@@ -100,28 +100,18 @@ class TestTokenDispatcherWithMC2(TestBase):
|
||||
return_value=AscendSocVersion.A3)
|
||||
self.ascend_soc_version_patch.start()
|
||||
|
||||
# Mock get_ascend_config()
|
||||
self.ascend_config = MagicMock()
|
||||
self.ascend_config.torchair_graph_config.enabled = False
|
||||
self.ascend_config_patch = patch(
|
||||
"vllm_ascend.ops.moe_dispatcher.token_dispatcher.get_ascend_config",
|
||||
return_value=self.ascend_config)
|
||||
self.ascend_config_patch.start()
|
||||
|
||||
kwargs = {"with_quant": False, "top_k": 8, "num_experts": 128}
|
||||
self.dispatcher = TokenDispatcherWithMC2(**kwargs)
|
||||
self.row_idx = torch.arange(10, dtype=torch.int32)
|
||||
|
||||
def tearDown(self):
|
||||
self.mc2_group_patch.stop()
|
||||
self.forward_context_patch.stop()
|
||||
self.ascend_soc_version_patch.stop()
|
||||
self.ascend_config_patch.stop()
|
||||
|
||||
def test_init(self):
|
||||
# self.assertEqual(self.dispatcher.moe_all_to_all_group_name, "hccl_123")
|
||||
self.assertEqual(self.dispatcher.ep_rank_id, 0)
|
||||
self.assertEqual(self.dispatcher.ep_world_size, 8)
|
||||
self.assertFalse(self.dispatcher.torchair_graph_enabled)
|
||||
self.assertFalse(self.dispatcher.with_quant)
|
||||
self.assertTrue(self.dispatcher.enable_dispatch_v2)
|
||||
self.assertTrue(self.dispatcher.need_extra_args)
|
||||
@@ -149,9 +139,10 @@ class TestTokenDispatcherWithMC2(TestBase):
|
||||
return_value=(torch.randn(10, 128), ) * 5) as mock_dispatch:
|
||||
output = self.dispatcher.token_dispatch(hidden_states,
|
||||
topk_weights, topk_ids,
|
||||
expert_map)
|
||||
self.row_idx, expert_map)
|
||||
mock_dispatch.assert_called_once()
|
||||
self.assertEqual(output[0], 1) # group_list_type == 1
|
||||
self.assertEqual(output["group_list_type"],
|
||||
1) # group_list_type == 1
|
||||
|
||||
def test_token_dispatch_with_shared_experts_and_quant(self):
|
||||
self.shared_experts = MagicMock()
|
||||
@@ -166,20 +157,13 @@ class TestTokenDispatcherWithMC2(TestBase):
|
||||
|
||||
with patch("torch_npu.npu_moe_distribute_dispatch_v2",
|
||||
return_value=(torch.randn(10, 128), ) * 5):
|
||||
with patch(
|
||||
"vllm_ascend.ops.moe_dispatcher.token_dispatcher.npu_stream_switch",
|
||||
autospec=True):
|
||||
with patch(
|
||||
"vllm_ascend.ops.moe_dispatcher.token_dispatcher.npu_wait_tensor",
|
||||
autospec=True) as mock_wait:
|
||||
self.dispatcher.token_dispatch(
|
||||
self.hidden_states,
|
||||
self.topk_weights,
|
||||
torch.randint(0, 8, (10, 1)),
|
||||
torch.tensor([0, 1, 2, 3, 4, 5, 6, 7]),
|
||||
shared_experts=self.shared_experts)
|
||||
mock_wait.assert_any_call(self.hidden_states,
|
||||
self.topk_weights)
|
||||
self.dispatcher.token_dispatch(self.hidden_states,
|
||||
self.topk_weights,
|
||||
torch.randint(0, 8, (10, 1)),
|
||||
self.row_idx,
|
||||
torch.tensor(
|
||||
[0, 1, 2, 3, 4, 5, 6, 7]),
|
||||
shared_experts=self.shared_experts)
|
||||
|
||||
def test_get_combine_mc_kwargs_with_quant(self):
|
||||
self.dispatcher.with_quant = True
|
||||
@@ -213,13 +197,7 @@ class TestTokenDispatcherWithMC2(TestBase):
|
||||
|
||||
with patch("torch_npu.npu_moe_distribute_combine_v2",
|
||||
return_value=torch.randn(10, 128)):
|
||||
with patch(
|
||||
"vllm_ascend.ops.moe_dispatcher.token_dispatcher.npu_stream_switch",
|
||||
autospec=True):
|
||||
with patch(
|
||||
"vllm_ascend.ops.moe_dispatcher.token_dispatcher.npu_wait_tensor",
|
||||
autospec=True):
|
||||
self.dispatcher.token_combine(self.hidden_states)
|
||||
self.dispatcher.token_combine(self.hidden_states)
|
||||
|
||||
|
||||
class TestTokenDispatcherWithAllGather(TestBase):
|
||||
@@ -257,6 +235,7 @@ class TestTokenDispatcherWithAllGather(TestBase):
|
||||
self.mock_moe_finalize_routing = self.patcher_moe_finalize_routing.start(
|
||||
)
|
||||
self.mock_moe_finalize_routing.return_value = torch.randn(3, 128)
|
||||
self.row_idx = torch.arange(10, dtype=torch.int32)
|
||||
|
||||
def tearDown(self):
|
||||
self.patcher_moe_init_routing.stop()
|
||||
@@ -268,14 +247,14 @@ class TestTokenDispatcherWithAllGather(TestBase):
|
||||
topk_weights = torch.tensor([[0.7, 0.3], [0.6, 0.4], [0.5, 0.5]])
|
||||
topk_ids = torch.tensor([[0, 1], [1, 2], [2, 3]])
|
||||
|
||||
group_list_type, sorted_hidden_states, expert_tokens = self.dispatcher.token_dispatch(
|
||||
hidden_states, topk_weights, topk_ids, None)
|
||||
results = self.dispatcher.token_dispatch(hidden_states, topk_weights,
|
||||
topk_ids, self.row_idx, None)
|
||||
|
||||
# Verify npu_moe_init_routing is called
|
||||
self.mock_moe_init_routing.assert_called_once()
|
||||
args, kwargs = self.mock_moe_init_routing.call_args
|
||||
|
||||
self.assertEqual(group_list_type, 0)
|
||||
self.assertEqual(results["group_list_type"], 0)
|
||||
|
||||
def test_token_dispatch_with_quant(self):
|
||||
kwargs = {
|
||||
@@ -292,11 +271,11 @@ class TestTokenDispatcherWithAllGather(TestBase):
|
||||
topk_weights = torch.tensor([[0.7, 0.3], [0.6, 0.4], [0.5, 0.5]])
|
||||
topk_ids = torch.tensor([[0, 1], [1, 2], [2, 3]])
|
||||
|
||||
group_list_type, sorted_hidden_states, expert_tokens = self.dispatcher_quant.token_dispatch(
|
||||
hidden_states, topk_weights, topk_ids, None)
|
||||
results = self.dispatcher_quant.token_dispatch(hidden_states,
|
||||
topk_weights, topk_ids,
|
||||
self.row_idx, None)
|
||||
|
||||
# Verify quant mode returns group_list_type=1
|
||||
self.assertEqual(group_list_type, 0)
|
||||
self.assertEqual(results["group_list_type"], 0)
|
||||
|
||||
def test_token_combine_with_expert_map(self):
|
||||
self.dispatcher.expert_map = torch.tensor([0, 1, 2, 3])
|
||||
@@ -337,19 +316,9 @@ class TestTokenDispatcherWithAllGather(TestBase):
|
||||
topk_weights = torch.tensor([[0.7], [0.6], [0.5]]) # topk=1
|
||||
topk_ids = torch.tensor([[0], [1], [2]])
|
||||
|
||||
group_list_type, sorted_hidden_states, expert_tokens = self.dispatcher.token_dispatch(
|
||||
hidden_states, topk_weights, topk_ids, None)
|
||||
self.assertEqual(sorted_hidden_states.shape, (6, 128))
|
||||
|
||||
def test_token_dispatch_invalid_topk_when_router_weight(self):
|
||||
self.dispatcher.apply_router_weight_on_input = True
|
||||
hidden_states = torch.randn(3, 128)
|
||||
topk_weights = torch.tensor([[0.7, 0.3], [0.6, 0.4], [0.5, 0.5]])
|
||||
|
||||
with self.assertRaises(AssertionError):
|
||||
self.dispatcher.token_dispatch(
|
||||
hidden_states, topk_weights,
|
||||
torch.tensor([[0, 1], [1, 2], [2, 3]]), None)
|
||||
results = self.dispatcher.token_dispatch(hidden_states, topk_weights,
|
||||
topk_ids, None)
|
||||
self.assertEqual(results["hidden_states"].shape, (6, 128))
|
||||
|
||||
|
||||
class TestTokenDispatcherWithAll2AllV(TestBase):
|
||||
@@ -443,6 +412,7 @@ class TestTokenDispatcherWithAll2AllV(TestBase):
|
||||
num_experts=4,
|
||||
num_local_experts=2,
|
||||
with_quant=False)
|
||||
self.row_idx = torch.arange(10, dtype=torch.int32)
|
||||
|
||||
def test_token_dispatch(self):
|
||||
hidden_states = torch.randn(8, 16)
|
||||
@@ -457,6 +427,7 @@ class TestTokenDispatcherWithAll2AllV(TestBase):
|
||||
result = self.dispatcher.token_dispatch(hidden_states=hidden_states,
|
||||
topk_weights=topk_weights,
|
||||
topk_ids=topk_ids,
|
||||
row_idx=self.row_idx,
|
||||
expert_map=expert_map)
|
||||
|
||||
self.assertIsNotNone(result["hidden_states"])
|
||||
@@ -504,6 +475,7 @@ class TestTokenDispatcherWithAll2AllV(TestBase):
|
||||
result = self.dispatcher.token_dispatch(hidden_states=hidden_states,
|
||||
topk_weights=topk_weights,
|
||||
topk_ids=topk_ids,
|
||||
row_idx=self.row_idx,
|
||||
expert_map=expert_map)
|
||||
|
||||
self.assertIsNotNone(result["hidden_states"])
|
||||
@@ -532,6 +504,7 @@ class TestTokenDispatcherWithAll2AllV(TestBase):
|
||||
result = self.dispatcher.token_dispatch(hidden_states=hidden_states,
|
||||
topk_weights=topk_weights,
|
||||
topk_ids=topk_ids,
|
||||
row_idx=self.row_idx,
|
||||
expert_map=expert_map)
|
||||
|
||||
self.assertIsNotNone(result["hidden_states"])
|
||||
@@ -553,9 +526,126 @@ class TestTokenDispatcherWithAll2AllV(TestBase):
|
||||
result = self.dispatcher.token_dispatch(hidden_states=hidden_states,
|
||||
topk_weights=topk_weights,
|
||||
topk_ids=topk_ids,
|
||||
row_idx=self.row_idx,
|
||||
expert_map=expert_map,
|
||||
log2phy=log2phy)
|
||||
|
||||
self.assertIsNotNone(result["hidden_states"])
|
||||
self.assertIsNotNone(result["group_list"])
|
||||
self.assertEqual(result["group_list_type"], 1)
|
||||
|
||||
|
||||
class TestDispatcherRegistry(TestBase):
|
||||
|
||||
def setUp(self):
|
||||
_Dispatchers.clear()
|
||||
|
||||
def tearDown(self):
|
||||
_Dispatchers.clear()
|
||||
|
||||
def test_register_and_get_token_dispatcher(self):
|
||||
mock_dispatcher = MagicMock()
|
||||
mock_dispatcher.__class__.__name__ = "MockDispatcher"
|
||||
|
||||
_register_token_dispatcher(mock_dispatcher)
|
||||
|
||||
self.assertIn("MockDispatcher", _Dispatchers)
|
||||
self.assertIs(_Dispatchers["MockDispatcher"], mock_dispatcher)
|
||||
|
||||
retrieved_dispatcher = get_token_dispatcher("MockDispatcher")
|
||||
self.assertIs(retrieved_dispatcher, mock_dispatcher)
|
||||
|
||||
self.assertIsNone(get_token_dispatcher("NonExistentDispatcher"))
|
||||
|
||||
@patch(
|
||||
'vllm_ascend.ops.moe_dispatcher.token_dispatcher.TokenDispatcherWithAllGather'
|
||||
)
|
||||
@patch(
|
||||
'vllm_ascend.ops.moe_dispatcher.token_dispatcher._register_token_dispatcher'
|
||||
)
|
||||
def test_setup_token_dispatchers_ep_size_1_creates_allgather(
|
||||
self, mock_register, mock_allgather_class):
|
||||
kwargs = {"top_k": 2, "num_experts": 8}
|
||||
mock_instance = MagicMock()
|
||||
mock_allgather_class.return_value = mock_instance
|
||||
|
||||
self.assertNotIn("TokenDispatcherWithAllGather", _Dispatchers)
|
||||
|
||||
setup_token_dispatchers(ep_size=1, **kwargs)
|
||||
|
||||
mock_allgather_class.assert_called_once_with(**kwargs)
|
||||
mock_register.assert_called_once_with(mock_instance)
|
||||
|
||||
@patch(
|
||||
'vllm_ascend.ops.moe_dispatcher.token_dispatcher.TokenDispatcherWithAll2AllV'
|
||||
)
|
||||
@patch(
|
||||
'vllm_ascend.ops.moe_dispatcher.token_dispatcher._register_token_dispatcher'
|
||||
)
|
||||
def test_setup_token_dispatchers_ep_size_2_creates_all2allv(
|
||||
self, mock_register, mock_all2allv_class):
|
||||
kwargs = {"top_k": 2, "num_experts": 16, "num_local_experts": 2}
|
||||
mock_instance = MagicMock()
|
||||
mock_all2allv_class.return_value = mock_instance
|
||||
|
||||
self.assertNotIn("TokenDispatcherWithAll2AllV", _Dispatchers)
|
||||
|
||||
setup_token_dispatchers(ep_size=2, **kwargs)
|
||||
|
||||
mock_all2allv_class.assert_called_once_with(**kwargs)
|
||||
mock_register.assert_called_once_with(mock_instance)
|
||||
|
||||
@patch(
|
||||
'vllm_ascend.ops.moe_dispatcher.token_dispatcher.TokenDispatcherWithAll2AllV'
|
||||
)
|
||||
@patch(
|
||||
'vllm_ascend.ops.moe_dispatcher.token_dispatcher.TokenDispatcherWithMC2'
|
||||
)
|
||||
@patch(
|
||||
'vllm_ascend.ops.moe_dispatcher.token_dispatcher._register_token_dispatcher'
|
||||
)
|
||||
def test_setup_token_dispatchers_ep_size_16_creates_all2allv_and_mc2(
|
||||
self, mock_register, mock_mc2_class, mock_all2allv_class):
|
||||
kwargs = {"top_k": 2, "num_experts": 32, "num_local_experts": 2}
|
||||
mock_all2allv_instance = MagicMock()
|
||||
mock_mc2_instance = MagicMock()
|
||||
mock_all2allv_class.return_value = mock_all2allv_instance
|
||||
mock_mc2_class.return_value = mock_mc2_instance
|
||||
|
||||
self.assertNotIn("TokenDispatcherWithAll2AllV", _Dispatchers)
|
||||
self.assertNotIn("TokenDispatcherWithMC2", _Dispatchers)
|
||||
|
||||
setup_token_dispatchers(ep_size=16, **kwargs)
|
||||
|
||||
mock_all2allv_class.assert_called_once_with(**kwargs)
|
||||
mock_mc2_class.assert_called_once_with(**kwargs)
|
||||
self.assertEqual(mock_register.call_count, 2)
|
||||
mock_register.assert_any_call(mock_all2allv_instance)
|
||||
mock_register.assert_any_call(mock_mc2_instance)
|
||||
|
||||
@patch(
|
||||
'vllm_ascend.ops.moe_dispatcher.token_dispatcher.TokenDispatcherWithAll2AllV'
|
||||
)
|
||||
@patch(
|
||||
'vllm_ascend.ops.moe_dispatcher.token_dispatcher.TokenDispatcherWithMC2'
|
||||
)
|
||||
@patch(
|
||||
'vllm_ascend.ops.moe_dispatcher.token_dispatcher._register_token_dispatcher'
|
||||
)
|
||||
def test_setup_token_dispatchers_ep_size_16_skips_if_exist(
|
||||
self, mock_register, mock_mc2_class, mock_all2allv_class):
|
||||
kwargs = {"top_k": 2, "num_experts": 32, "num_local_experts": 2}
|
||||
mock_existing_all2allv = MagicMock()
|
||||
mock_existing_mc2 = MagicMock()
|
||||
_Dispatchers["TokenDispatcherWithAll2AllV"] = mock_existing_all2allv
|
||||
_Dispatchers["TokenDispatcherWithMC2"] = mock_existing_mc2
|
||||
|
||||
setup_token_dispatchers(ep_size=16, **kwargs)
|
||||
|
||||
mock_all2allv_class.assert_not_called()
|
||||
mock_mc2_class.assert_not_called()
|
||||
mock_register.assert_not_called()
|
||||
self.assertIs(_Dispatchers["TokenDispatcherWithAll2AllV"],
|
||||
mock_existing_all2allv)
|
||||
self.assertIs(_Dispatchers["TokenDispatcherWithMC2"],
|
||||
mock_existing_mc2)
|
||||
|
||||
Reference in New Issue
Block a user