[main] [refactor] refactor fused_moe.py to enable token_dispatchers (#2570)

### What this PR does / why we need it?
Enable `token_dispatcher` to replace the `fused_experts_with_xxx` code paths in eager mode.
### Does this PR introduce _any_ user-facing change?
No.
### How was this patch tested?
E2E tests and unit tests.
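
As a reader's sketch of the mechanism this PR introduces (not the actual implementation), the registry pattern exercised by the new unit tests looks roughly like the following. The stub dispatcher classes and the ep_size cutoff are assumptions inferred from the test expectations in the diff below:

# Minimal sketch of the registry pattern exercised by the new tests.
# The real code lives in vllm_ascend/ops/moe_dispatcher/token_dispatcher.py;
# the stub classes and the ep_size thresholds are assumptions inferred
# from the test expectations, not the actual implementation.
from typing import Any, Dict, Optional


class TokenDispatcherWithAllGather:  # stand-in for the real dispatcher
    def __init__(self, **kwargs: Any) -> None:
        self.kwargs = kwargs


class TokenDispatcherWithAll2AllV:  # stand-in for the real dispatcher
    def __init__(self, **kwargs: Any) -> None:
        self.kwargs = kwargs


class TokenDispatcherWithMC2:  # stand-in for the real dispatcher
    def __init__(self, **kwargs: Any) -> None:
        self.kwargs = kwargs


_Dispatchers: Dict[str, Any] = {}  # keyed by dispatcher class name


def _register_token_dispatcher(dispatcher: Any) -> None:
    _Dispatchers[dispatcher.__class__.__name__] = dispatcher


def get_token_dispatcher(name: str) -> Optional[Any]:
    # Unknown names return None (see test_register_and_get_token_dispatcher).
    return _Dispatchers.get(name)


def setup_token_dispatchers(ep_size: int, **kwargs: Any) -> None:
    # Policy inferred from the tests: AllGather when ep_size == 1,
    # All2AllV otherwise, plus MC2 for large EP groups (observed at
    # ep_size == 16; the exact cutoff is an assumption). Dispatchers
    # already in the registry are reused, not recreated.
    if ep_size == 1:
        if "TokenDispatcherWithAllGather" not in _Dispatchers:
            _register_token_dispatcher(TokenDispatcherWithAllGather(**kwargs))
        return
    if "TokenDispatcherWithAll2AllV" not in _Dispatchers:
        _register_token_dispatcher(TokenDispatcherWithAll2AllV(**kwargs))
    if ep_size >= 16 and "TokenDispatcherWithMC2" not in _Dispatchers:
        _register_token_dispatcher(TokenDispatcherWithMC2(**kwargs))

Call sites can then look up a dispatcher by class name via get_token_dispatcher(...) instead of branching into the fused_experts_with_xxx helpers.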


- vLLM version: v0.10.1.1
- vLLM main: 704432af3c

Signed-off-by: Pr0Wh1teGivee <calvin_zhu0210@outlook.com>
Co-authored-by: sherie <963372609@qq.com>
Co-authored-by: weijinqian0 <12153182+weijinqian0@users.noreply.github.com>
Co-authored-by: shiyuan680 <72335504+shiyuan680@users.noreply.github.com>
Author: weichen
Date: 2025-08-28 10:13:35 +08:00 (committed by GitHub)
Parent: 936c102105
Commit: 320edde2df
10 changed files with 1066 additions and 1639 deletions


@@ -25,8 +25,8 @@ from tests.ut.base import PytestBase, TestBase
from vllm_ascend.ops.moe_dispatcher.token_dispatcher import (
AscendSocVersion, MoEAlltoAllSeqOverLapDispatcher, MoEDispatcherConfig,
TokenDispatcherWithAll2AllV, TokenDispatcherWithAllGather,
- TokenDispatcherWithMC2)
- from vllm_ascend.utils import adapt_patch # noqa E402
+ TokenDispatcherWithMC2, _Dispatchers, _register_token_dispatcher,
+ get_token_dispatcher, setup_token_dispatchers)
class TestMoEAlltoAllSeqOverLapDispatcher(PytestBase):
@@ -90,7 +90,7 @@ class TestTokenDispatcherWithMC2(TestBase):
self.forward_context = MagicMock()
self.forward_context.mc2_mask = torch.tensor([1, 0, 1])
self.forward_context_patch = patch(
"vllm_ascend.ops.moe_dispatcher.token_dispatcher.get_forward_context",
"vllm.forward_context.get_forward_context",
return_value=self.forward_context)
self.forward_context_patch.start()
@@ -100,28 +100,18 @@ class TestTokenDispatcherWithMC2(TestBase):
return_value=AscendSocVersion.A3)
self.ascend_soc_version_patch.start()
- # Mock get_ascend_config()
- self.ascend_config = MagicMock()
- self.ascend_config.torchair_graph_config.enabled = False
- self.ascend_config_patch = patch(
- "vllm_ascend.ops.moe_dispatcher.token_dispatcher.get_ascend_config",
- return_value=self.ascend_config)
- self.ascend_config_patch.start()
kwargs = {"with_quant": False, "top_k": 8, "num_experts": 128}
self.dispatcher = TokenDispatcherWithMC2(**kwargs)
+ self.row_idx = torch.arange(10, dtype=torch.int32)
def tearDown(self):
self.mc2_group_patch.stop()
self.forward_context_patch.stop()
self.ascend_soc_version_patch.stop()
- self.ascend_config_patch.stop()
def test_init(self):
# self.assertEqual(self.dispatcher.moe_all_to_all_group_name, "hccl_123")
self.assertEqual(self.dispatcher.ep_rank_id, 0)
self.assertEqual(self.dispatcher.ep_world_size, 8)
- self.assertFalse(self.dispatcher.torchair_graph_enabled)
self.assertFalse(self.dispatcher.with_quant)
self.assertTrue(self.dispatcher.enable_dispatch_v2)
self.assertTrue(self.dispatcher.need_extra_args)
@@ -149,9 +139,10 @@ class TestTokenDispatcherWithMC2(TestBase):
return_value=(torch.randn(10, 128), ) * 5) as mock_dispatch:
output = self.dispatcher.token_dispatch(hidden_states,
topk_weights, topk_ids,
- expert_map)
+ self.row_idx, expert_map)
mock_dispatch.assert_called_once()
- self.assertEqual(output[0], 1) # group_list_type == 1
+ self.assertEqual(output["group_list_type"],
+ 1) # group_list_type == 1
def test_token_dispatch_with_shared_experts_and_quant(self):
self.shared_experts = MagicMock()
@@ -166,20 +157,13 @@ class TestTokenDispatcherWithMC2(TestBase):
with patch("torch_npu.npu_moe_distribute_dispatch_v2",
return_value=(torch.randn(10, 128), ) * 5):
- with patch(
- "vllm_ascend.ops.moe_dispatcher.token_dispatcher.npu_stream_switch",
- autospec=True):
- with patch(
- "vllm_ascend.ops.moe_dispatcher.token_dispatcher.npu_wait_tensor",
- autospec=True) as mock_wait:
- self.dispatcher.token_dispatch(
- self.hidden_states,
- self.topk_weights,
- torch.randint(0, 8, (10, 1)),
- torch.tensor([0, 1, 2, 3, 4, 5, 6, 7]),
- shared_experts=self.shared_experts)
- mock_wait.assert_any_call(self.hidden_states,
- self.topk_weights)
+ self.dispatcher.token_dispatch(self.hidden_states,
+ self.topk_weights,
+ torch.randint(0, 8, (10, 1)),
+ self.row_idx,
+ torch.tensor(
+ [0, 1, 2, 3, 4, 5, 6, 7]),
+ shared_experts=self.shared_experts)
def test_get_combine_mc_kwargs_with_quant(self):
self.dispatcher.with_quant = True
@@ -213,13 +197,7 @@ class TestTokenDispatcherWithMC2(TestBase):
with patch("torch_npu.npu_moe_distribute_combine_v2",
return_value=torch.randn(10, 128)):
- with patch(
- "vllm_ascend.ops.moe_dispatcher.token_dispatcher.npu_stream_switch",
- autospec=True):
- with patch(
- "vllm_ascend.ops.moe_dispatcher.token_dispatcher.npu_wait_tensor",
- autospec=True):
- self.dispatcher.token_combine(self.hidden_states)
+ self.dispatcher.token_combine(self.hidden_states)
class TestTokenDispatcherWithAllGather(TestBase):
@@ -257,6 +235,7 @@ class TestTokenDispatcherWithAllGather(TestBase):
self.mock_moe_finalize_routing = self.patcher_moe_finalize_routing.start(
)
self.mock_moe_finalize_routing.return_value = torch.randn(3, 128)
+ self.row_idx = torch.arange(10, dtype=torch.int32)
def tearDown(self):
self.patcher_moe_init_routing.stop()
@@ -268,14 +247,14 @@ class TestTokenDispatcherWithAllGather(TestBase):
topk_weights = torch.tensor([[0.7, 0.3], [0.6, 0.4], [0.5, 0.5]])
topk_ids = torch.tensor([[0, 1], [1, 2], [2, 3]])
- group_list_type, sorted_hidden_states, expert_tokens = self.dispatcher.token_dispatch(
- hidden_states, topk_weights, topk_ids, None)
+ results = self.dispatcher.token_dispatch(hidden_states, topk_weights,
+ topk_ids, self.row_idx, None)
# Verify npu_moe_init_routing is called
self.mock_moe_init_routing.assert_called_once()
args, kwargs = self.mock_moe_init_routing.call_args
- self.assertEqual(group_list_type, 0)
+ self.assertEqual(results["group_list_type"], 0)
def test_token_dispatch_with_quant(self):
kwargs = {
@@ -292,11 +271,11 @@ class TestTokenDispatcherWithAllGather(TestBase):
topk_weights = torch.tensor([[0.7, 0.3], [0.6, 0.4], [0.5, 0.5]])
topk_ids = torch.tensor([[0, 1], [1, 2], [2, 3]])
- group_list_type, sorted_hidden_states, expert_tokens = self.dispatcher_quant.token_dispatch(
- hidden_states, topk_weights, topk_ids, None)
+ results = self.dispatcher_quant.token_dispatch(hidden_states,
+ topk_weights, topk_ids,
+ self.row_idx, None)
# Verify quant mode returns group_list_type=1
- self.assertEqual(group_list_type, 0)
+ self.assertEqual(results["group_list_type"], 0)
def test_token_combine_with_expert_map(self):
self.dispatcher.expert_map = torch.tensor([0, 1, 2, 3])
@@ -337,19 +316,9 @@ class TestTokenDispatcherWithAllGather(TestBase):
topk_weights = torch.tensor([[0.7], [0.6], [0.5]]) # topk=1
topk_ids = torch.tensor([[0], [1], [2]])
- group_list_type, sorted_hidden_states, expert_tokens = self.dispatcher.token_dispatch(
- hidden_states, topk_weights, topk_ids, None)
- self.assertEqual(sorted_hidden_states.shape, (6, 128))
- def test_token_dispatch_invalid_topk_when_router_weight(self):
- self.dispatcher.apply_router_weight_on_input = True
- hidden_states = torch.randn(3, 128)
- topk_weights = torch.tensor([[0.7, 0.3], [0.6, 0.4], [0.5, 0.5]])
- with self.assertRaises(AssertionError):
- self.dispatcher.token_dispatch(
- hidden_states, topk_weights,
- torch.tensor([[0, 1], [1, 2], [2, 3]]), None)
+ results = self.dispatcher.token_dispatch(hidden_states, topk_weights,
+ topk_ids, None)
+ self.assertEqual(results["hidden_states"].shape, (6, 128))
class TestTokenDispatcherWithAll2AllV(TestBase):
@@ -443,6 +412,7 @@ class TestTokenDispatcherWithAll2AllV(TestBase):
num_experts=4,
num_local_experts=2,
with_quant=False)
+ self.row_idx = torch.arange(10, dtype=torch.int32)
def test_token_dispatch(self):
hidden_states = torch.randn(8, 16)
@@ -457,6 +427,7 @@ class TestTokenDispatcherWithAll2AllV(TestBase):
result = self.dispatcher.token_dispatch(hidden_states=hidden_states,
topk_weights=topk_weights,
topk_ids=topk_ids,
+ row_idx=self.row_idx,
expert_map=expert_map)
self.assertIsNotNone(result["hidden_states"])
@@ -504,6 +475,7 @@ class TestTokenDispatcherWithAll2AllV(TestBase):
result = self.dispatcher.token_dispatch(hidden_states=hidden_states,
topk_weights=topk_weights,
topk_ids=topk_ids,
+ row_idx=self.row_idx,
expert_map=expert_map)
self.assertIsNotNone(result["hidden_states"])
@@ -532,6 +504,7 @@ class TestTokenDispatcherWithAll2AllV(TestBase):
result = self.dispatcher.token_dispatch(hidden_states=hidden_states,
topk_weights=topk_weights,
topk_ids=topk_ids,
+ row_idx=self.row_idx,
expert_map=expert_map)
self.assertIsNotNone(result["hidden_states"])
@@ -553,9 +526,126 @@ class TestTokenDispatcherWithAll2AllV(TestBase):
result = self.dispatcher.token_dispatch(hidden_states=hidden_states,
topk_weights=topk_weights,
topk_ids=topk_ids,
+ row_idx=self.row_idx,
expert_map=expert_map,
log2phy=log2phy)
self.assertIsNotNone(result["hidden_states"])
self.assertIsNotNone(result["group_list"])
self.assertEqual(result["group_list_type"], 1)
+ class TestDispatcherRegistry(TestBase):
+ def setUp(self):
+ _Dispatchers.clear()
+ def tearDown(self):
+ _Dispatchers.clear()
+ def test_register_and_get_token_dispatcher(self):
+ mock_dispatcher = MagicMock()
+ mock_dispatcher.__class__.__name__ = "MockDispatcher"
+ _register_token_dispatcher(mock_dispatcher)
+ self.assertIn("MockDispatcher", _Dispatchers)
+ self.assertIs(_Dispatchers["MockDispatcher"], mock_dispatcher)
+ retrieved_dispatcher = get_token_dispatcher("MockDispatcher")
+ self.assertIs(retrieved_dispatcher, mock_dispatcher)
+ self.assertIsNone(get_token_dispatcher("NonExistentDispatcher"))
+ @patch(
+ 'vllm_ascend.ops.moe_dispatcher.token_dispatcher.TokenDispatcherWithAllGather'
+ )
+ @patch(
+ 'vllm_ascend.ops.moe_dispatcher.token_dispatcher._register_token_dispatcher'
+ )
+ def test_setup_token_dispatchers_ep_size_1_creates_allgather(
+ self, mock_register, mock_allgather_class):
+ kwargs = {"top_k": 2, "num_experts": 8}
+ mock_instance = MagicMock()
+ mock_allgather_class.return_value = mock_instance
+ self.assertNotIn("TokenDispatcherWithAllGather", _Dispatchers)
+ setup_token_dispatchers(ep_size=1, **kwargs)
+ mock_allgather_class.assert_called_once_with(**kwargs)
+ mock_register.assert_called_once_with(mock_instance)
+ @patch(
+ 'vllm_ascend.ops.moe_dispatcher.token_dispatcher.TokenDispatcherWithAll2AllV'
+ )
+ @patch(
+ 'vllm_ascend.ops.moe_dispatcher.token_dispatcher._register_token_dispatcher'
+ )
+ def test_setup_token_dispatchers_ep_size_2_creates_all2allv(
+ self, mock_register, mock_all2allv_class):
+ kwargs = {"top_k": 2, "num_experts": 16, "num_local_experts": 2}
+ mock_instance = MagicMock()
+ mock_all2allv_class.return_value = mock_instance
+ self.assertNotIn("TokenDispatcherWithAll2AllV", _Dispatchers)
+ setup_token_dispatchers(ep_size=2, **kwargs)
+ mock_all2allv_class.assert_called_once_with(**kwargs)
+ mock_register.assert_called_once_with(mock_instance)
+ @patch(
+ 'vllm_ascend.ops.moe_dispatcher.token_dispatcher.TokenDispatcherWithAll2AllV'
+ )
+ @patch(
+ 'vllm_ascend.ops.moe_dispatcher.token_dispatcher.TokenDispatcherWithMC2'
+ )
+ @patch(
+ 'vllm_ascend.ops.moe_dispatcher.token_dispatcher._register_token_dispatcher'
+ )
+ def test_setup_token_dispatchers_ep_size_16_creates_all2allv_and_mc2(
+ self, mock_register, mock_mc2_class, mock_all2allv_class):
+ kwargs = {"top_k": 2, "num_experts": 32, "num_local_experts": 2}
+ mock_all2allv_instance = MagicMock()
+ mock_mc2_instance = MagicMock()
+ mock_all2allv_class.return_value = mock_all2allv_instance
+ mock_mc2_class.return_value = mock_mc2_instance
+ self.assertNotIn("TokenDispatcherWithAll2AllV", _Dispatchers)
+ self.assertNotIn("TokenDispatcherWithMC2", _Dispatchers)
+ setup_token_dispatchers(ep_size=16, **kwargs)
+ mock_all2allv_class.assert_called_once_with(**kwargs)
+ mock_mc2_class.assert_called_once_with(**kwargs)
+ self.assertEqual(mock_register.call_count, 2)
+ mock_register.assert_any_call(mock_all2allv_instance)
+ mock_register.assert_any_call(mock_mc2_instance)
+ @patch(
+ 'vllm_ascend.ops.moe_dispatcher.token_dispatcher.TokenDispatcherWithAll2AllV'
+ )
+ @patch(
+ 'vllm_ascend.ops.moe_dispatcher.token_dispatcher.TokenDispatcherWithMC2'
+ )
+ @patch(
+ 'vllm_ascend.ops.moe_dispatcher.token_dispatcher._register_token_dispatcher'
+ )
+ def test_setup_token_dispatchers_ep_size_16_skips_if_exist(
+ self, mock_register, mock_mc2_class, mock_all2allv_class):
+ kwargs = {"top_k": 2, "num_experts": 32, "num_local_experts": 2}
+ mock_existing_all2allv = MagicMock()
+ mock_existing_mc2 = MagicMock()
+ _Dispatchers["TokenDispatcherWithAll2AllV"] = mock_existing_all2allv
+ _Dispatchers["TokenDispatcherWithMC2"] = mock_existing_mc2
+ setup_token_dispatchers(ep_size=16, **kwargs)
+ mock_all2allv_class.assert_not_called()
+ mock_mc2_class.assert_not_called()
+ mock_register.assert_not_called()
+ self.assertIs(_Dispatchers["TokenDispatcherWithAll2AllV"],
+ mock_existing_all2allv)
+ self.assertIs(_Dispatchers["TokenDispatcherWithMC2"],
+ mock_existing_mc2)
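
To make the new return contract concrete, here is a hypothetical call-site sketch. The dict keys ("group_list_type", "hidden_states", "group_list") come from the assertions above; the _EchoDispatcher stand-in and the tensor shapes are illustrative only:

import torch


class _EchoDispatcher:
    # Stand-in illustrating the return contract only; the real dispatchers
    # perform NPU communication (AllGather / All2AllV / MC2).
    def token_dispatch(self, hidden_states, topk_weights, topk_ids, row_idx,
                       expert_map, **kwargs):
        return {
            "group_list_type": 0,            # 0 or 1 depending on the path
            "hidden_states": hidden_states,  # dispatched tokens
            "group_list": torch.zeros(1, dtype=torch.int64),
        }


hidden_states = torch.randn(3, 128)
topk_weights = torch.tensor([[0.7, 0.3], [0.6, 0.4], [0.5, 0.5]])
topk_ids = torch.tensor([[0, 1], [1, 2], [2, 3]])
row_idx = torch.arange(3, dtype=torch.int32)  # now threaded through dispatch

results = _EchoDispatcher().token_dispatch(hidden_states, topk_weights,
                                           topk_ids, row_idx, None)
assert results["group_list_type"] in (0, 1)
assert results["hidden_states"].shape == (3, 128)

Returning a dict instead of a positional tuple lets each dispatcher add fields (such as "group_list") without breaking the callers that the old fused_experts_with_xxx tuple unpacking tied to a fixed arity.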