[Refactor] Formatting output types related to FuseMoE (#5481)

Currently in the Fused MoE module, functions of classes like
MoECommMethod and MoETokenDispatcher output data in dictionary or tuple
format, which hampers code maintainability, readability, and
extensibility. This PR introduces dataclasses for these key output types
to address these issues.

- vLLM version: v0.13.0
- vLLM main:
5326c89803

---------

Signed-off-by: Jade Zheng <zheng.shoujian@outlook.com>
This commit is contained in:
Jade Zheng
2025-12-31 14:24:37 +08:00
committed by GitHub
parent 38570cfeb6
commit 7d5242faca
6 changed files with 155 additions and 212 deletions

View File

@@ -8,6 +8,8 @@ from vllm_ascend.ops.fused_moe.moe_comm_method import (AllGatherCommImpl,
AlltoAllCommImpl,
MC2CommImpl)
from vllm_ascend.ops.fused_moe.prepare_finalize import QuantType
from vllm_ascend.ops.fused_moe.token_dispatcher import (TokenCombineResult,
TokenDispatchResult)
class TestMoECommMethod(TestBase):
@@ -178,12 +180,12 @@ class TestMoECommMethod(TestBase):
# Mock token dispatcher
mock_td_instance = MagicMock()
mock_td_instance.token_dispatch.return_value = {
"hidden_states": torch.randn(6, 8),
"group_list": torch.tensor([2, 2, 2]),
"group_list_type": 1
}
mock_td_instance.token_combine.return_value = torch.randn(4, 8)
mock_td_instance.token_dispatch.return_value = TokenDispatchResult(
hidden_states=torch.randn(6, 8),
group_list=torch.tensor([2, 2, 2]),
group_list_type=1)
mock_td_instance.token_combine.return_value = TokenCombineResult(
routed_out=torch.randn(4, 8))
mock_token_dispatcher.return_value = mock_td_instance
# Mock unified_apply_mlp
@@ -213,7 +215,7 @@ class TestMoECommMethod(TestBase):
activation="silu")
# Verify result shape
self.assertEqual(result.shape, (4, 8))
self.assertEqual(result.routed_out.shape, (4, 8))
# Verify token_dispatch was called
mock_td_instance.token_dispatch.assert_called_once()