[Refactor] Formatting output types related to FuseMoE (#5481)
Currently in the Fused MoE module, functions of classes like
MoECommMethod and MoETokenDispatcher output data in dictionary or tuple
format, which hampers code maintainability, readability, and
extensibility. This PR introduces dataclasses for these key output types
to address these issues.
- vLLM version: v0.13.0
- vLLM main:
5326c89803
---------
Signed-off-by: Jade Zheng <zheng.shoujian@outlook.com>
This commit is contained in:
@@ -8,6 +8,8 @@ from vllm_ascend.ops.fused_moe.moe_comm_method import (AllGatherCommImpl,
|
||||
AlltoAllCommImpl,
|
||||
MC2CommImpl)
|
||||
from vllm_ascend.ops.fused_moe.prepare_finalize import QuantType
|
||||
from vllm_ascend.ops.fused_moe.token_dispatcher import (TokenCombineResult,
|
||||
TokenDispatchResult)
|
||||
|
||||
|
||||
class TestMoECommMethod(TestBase):
|
||||
@@ -178,12 +180,12 @@ class TestMoECommMethod(TestBase):
|
||||
|
||||
# Mock token dispatcher
|
||||
mock_td_instance = MagicMock()
|
||||
mock_td_instance.token_dispatch.return_value = {
|
||||
"hidden_states": torch.randn(6, 8),
|
||||
"group_list": torch.tensor([2, 2, 2]),
|
||||
"group_list_type": 1
|
||||
}
|
||||
mock_td_instance.token_combine.return_value = torch.randn(4, 8)
|
||||
mock_td_instance.token_dispatch.return_value = TokenDispatchResult(
|
||||
hidden_states=torch.randn(6, 8),
|
||||
group_list=torch.tensor([2, 2, 2]),
|
||||
group_list_type=1)
|
||||
mock_td_instance.token_combine.return_value = TokenCombineResult(
|
||||
routed_out=torch.randn(4, 8))
|
||||
mock_token_dispatcher.return_value = mock_td_instance
|
||||
|
||||
# Mock unified_apply_mlp
|
||||
@@ -213,7 +215,7 @@ class TestMoECommMethod(TestBase):
|
||||
activation="silu")
|
||||
|
||||
# Verify result shape
|
||||
self.assertEqual(result.shape, (4, 8))
|
||||
self.assertEqual(result.routed_out.shape, (4, 8))
|
||||
|
||||
# Verify token_dispatch was called
|
||||
mock_td_instance.token_dispatch.assert_called_once()
|
||||
|
||||
Reference in New Issue
Block a user