[Refactor] Formatting output types related to FuseMoE (#5481)

Currently in the Fused MoE module, functions of classes like
MoECommMethod and MoETokenDispatcher output data in dictionary or tuple
format, which hampers code maintainability, readability, and
extensibility. This PR introduces dataclasses for these key output types
to address these issues.

- vLLM version: v0.13.0
- vLLM main:
5326c89803

---------

Signed-off-by: Jade Zheng <zheng.shoujian@outlook.com>
This commit is contained in:
Jade Zheng
2025-12-31 14:24:37 +08:00
committed by GitHub
parent 38570cfeb6
commit 7d5242faca
6 changed files with 155 additions and 212 deletions

View File

@@ -8,6 +8,8 @@ from vllm_ascend.ops.fused_moe.moe_comm_method import (AllGatherCommImpl,
AlltoAllCommImpl,
MC2CommImpl)
from vllm_ascend.ops.fused_moe.prepare_finalize import QuantType
from vllm_ascend.ops.fused_moe.token_dispatcher import (TokenCombineResult,
TokenDispatchResult)
class TestMoECommMethod(TestBase):
@@ -178,12 +180,12 @@ class TestMoECommMethod(TestBase):
# Mock token dispatcher
mock_td_instance = MagicMock()
mock_td_instance.token_dispatch.return_value = {
"hidden_states": torch.randn(6, 8),
"group_list": torch.tensor([2, 2, 2]),
"group_list_type": 1
}
mock_td_instance.token_combine.return_value = torch.randn(4, 8)
mock_td_instance.token_dispatch.return_value = TokenDispatchResult(
hidden_states=torch.randn(6, 8),
group_list=torch.tensor([2, 2, 2]),
group_list_type=1)
mock_td_instance.token_combine.return_value = TokenCombineResult(
routed_out=torch.randn(4, 8))
mock_token_dispatcher.return_value = mock_td_instance
# Mock unified_apply_mlp
@@ -213,7 +215,7 @@ class TestMoECommMethod(TestBase):
activation="silu")
# Verify result shape
self.assertEqual(result.shape, (4, 8))
self.assertEqual(result.routed_out.shape, (4, 8))
# Verify token_dispatch was called
mock_td_instance.token_dispatch.assert_called_once()

View File

@@ -97,8 +97,7 @@ class TestTokenDispatcherWithMC2(TestBase):
topk_weights, topk_ids,
expert_map)
mock_dispatch.assert_called_once()
self.assertEqual(output["group_list_type"],
0) # group_list_type == 0
self.assertEqual(output.group_list_type, 0) # group_list_type == 0
def test_token_dispatch_with_shared_experts_and_quant(self):
self.shared_experts = MagicMock()
@@ -149,43 +148,6 @@ class TestTokenDispatcherWithMC2(TestBase):
context_metadata)
self.assertIn("tp_send_counts", kwargs)
def test_token_combine_with_shared_experts(self):
shared_experts = MagicMock()
shared_experts.down_proj.return_value = (torch.randn(10, 128),
torch.tensor(1.0))
topk_ids = torch.randint(0, 8, (10, 1))
topk_weights = torch.randn(10, 1)
expert_map = torch.tensor([0, 1, 2, 3, 4, 5, 6, 7])
ep_recv_counts = torch.tensor([0, 1, 2, 3, 4, 5, 6, 7])
assist_info_for_combine = torch.arange(10)
tp_recv_counts = torch.tensor([0, 1, 2, 3, 4, 5, 6, 7])
context_metadata = {
"topk_ids": topk_ids,
"topk_weights": topk_weights,
"expert_map": expert_map,
"ep_recv_counts": ep_recv_counts,
"mc2_mask": None,
"assist_info_for_combine": assist_info_for_combine,
"expand_scales": None,
"shared_experts": shared_experts,
"shared_act": torch.randn(10, 128),
"swiglu_out_scale": torch.randn(10, 1),
"tp_recv_counts": tp_recv_counts
}
self.dispatcher.with_quant = True
self.dispatcher.need_extra_args = True
self.dispatcher.enable_dispatch_v2 = True
hidden_states = torch.randn(10, 128)
with patch("torch_npu.npu_moe_distribute_combine_v2",
return_value=torch.randn(10, 128)):
result = self.dispatcher.token_combine(hidden_states,
context_metadata)
self.assertIsInstance(result, tuple)
class TestTokenDispatcherWithAllGather(TestBase):
@@ -233,7 +195,7 @@ class TestTokenDispatcherWithAllGather(TestBase):
self.mock_npu_moe_init_routing_v2.assert_called_once()
args, kwargs = self.mock_npu_moe_init_routing_v2.call_args
self.assertEqual(results["group_list_type"], 1)
self.assertEqual(results.group_list_type, 1)
def test_token_dispatch_with_expert_map(self):
self.dispatcher.expert_map = torch.tensor([0, 1, 2, 3])
@@ -248,7 +210,7 @@ class TestTokenDispatcherWithAllGather(TestBase):
self.mock_npu_moe_init_routing_v2.assert_called_once()
args, kwargs = self.mock_npu_moe_init_routing_v2.call_args
self.assertEqual(results["group_list_type"], 1)
self.assertEqual(results.group_list_type, 1)
def test_token_dispatch_without_quant(self):
kwargs = {
@@ -268,7 +230,7 @@ class TestTokenDispatcherWithAllGather(TestBase):
topk_weights, topk_ids,
None)
self.assertEqual(results["group_list_type"], 1)
self.assertEqual(results.group_list_type, 1)
def test_token_dispatch_with_quant(self):
kwargs = {
@@ -290,10 +252,10 @@ class TestTokenDispatcherWithAllGather(TestBase):
None,
with_quant=True)
self.assertIsNotNone(results["hidden_states"])
self.assertIsNotNone(results["group_list"])
self.assertIsNotNone(results["dynamic_scale"])
self.assertEqual(results["group_list_type"], 1)
self.assertIsNotNone(results.hidden_states)
self.assertIsNotNone(results.group_list)
self.assertIsNotNone(results.dynamic_scale)
self.assertEqual(results.group_list_type, 1)
def test_token_combine_with_expert_map(self):
hidden_states = torch.randn(6, 128)
@@ -303,7 +265,7 @@ class TestTokenDispatcherWithAllGather(TestBase):
}
self.dispatcher.original_shape = (6, 128)
final_hidden_states = self.dispatcher.token_combine(
hidden_states, context_metadata)
hidden_states, context_metadata).routed_out
self.assertEqual(final_hidden_states.shape, (6, 128))
def test_token_combine_without_expert_map(self):
@@ -314,7 +276,7 @@ class TestTokenDispatcherWithAllGather(TestBase):
}
self.dispatcher.original_shape = (6, 128)
final_hidden_states = self.dispatcher.token_combine(
hidden_states, context_metadata)
hidden_states, context_metadata).routed_out
self.mock_npu_moe_token_unpermute.assert_called_once()
self.assertEqual(final_hidden_states.shape, (6, 128))
@@ -326,7 +288,7 @@ class TestTokenDispatcherWithAllGather(TestBase):
results = self.dispatcher.token_dispatch(hidden_states, topk_weights,
topk_ids, None)
self.assertEqual(results["hidden_states"].shape, (6, 128))
self.assertEqual(results.hidden_states.shape, (6, 128))
class TestTokenDispatcherWithAll2AllV(TestBase):
@@ -437,9 +399,9 @@ class TestTokenDispatcherWithAll2AllV(TestBase):
topk_ids=topk_ids,
expert_map=expert_map)
self.assertIsNotNone(result["hidden_states"])
self.assertIsNotNone(result["group_list"])
self.assertEqual(result["group_list_type"], 1)
self.assertIsNotNone(result.hidden_states)
self.assertIsNotNone(result.group_list)
self.assertEqual(result.group_list_type, 1)
def test_token_combine(self):
hidden_states = torch.randn(16, 16)
@@ -458,7 +420,7 @@ class TestTokenDispatcherWithAll2AllV(TestBase):
output = self.dispatcher.token_combine(hidden_states, context_metadata)
self.assertIsNotNone(output)
self.assertEqual(output.shape, (8, 16))
self.assertEqual(output.routed_out.shape, (8, 16))
def test_token_dispatch_with_quant(self):
self.dispatcher = TokenDispatcherWithAll2AllV(top_k=2,
@@ -480,10 +442,10 @@ class TestTokenDispatcherWithAll2AllV(TestBase):
expert_map=expert_map,
with_quant=True)
self.assertIsNotNone(result["hidden_states"])
self.assertIsNotNone(result["group_list"])
self.assertIsNotNone(result["dynamic_scale"])
self.assertEqual(result["group_list_type"], 1)
self.assertIsNotNone(result.hidden_states)
self.assertIsNotNone(result.group_list)
self.assertIsNotNone(result.dynamic_scale)
self.assertEqual(result.group_list_type, 1)
def test_token_dispatch_with_quant_no_active_tokens(self):
self.dispatcher = TokenDispatcherWithAll2AllV(top_k=2,
@@ -508,10 +470,10 @@ class TestTokenDispatcherWithAll2AllV(TestBase):
expert_map=expert_map,
with_quant=True)
self.assertIsNotNone(result["hidden_states"])
self.assertIsNotNone(result["group_list"])
self.assertIsNotNone(result["dynamic_scale"])
self.assertEqual(result["group_list_type"], 1)
self.assertIsNotNone(result.hidden_states)
self.assertIsNotNone(result.group_list)
self.assertIsNotNone(result.dynamic_scale)
self.assertEqual(result.group_list_type, 1)
def test_token_dispatch_with_log2phy(self):
hidden_states = torch.randn(8, 16)
@@ -530,6 +492,6 @@ class TestTokenDispatcherWithAll2AllV(TestBase):
expert_map=expert_map,
log2phy=log2phy)
self.assertIsNotNone(result["hidden_states"])
self.assertIsNotNone(result["group_list"])
self.assertEqual(result["group_list_type"], 1)
self.assertIsNotNone(result.hidden_states)
self.assertIsNotNone(result.group_list)
self.assertEqual(result.group_list_type, 1)