[Refactor] Formatting output types related to FuseMoE (#5481)
Currently in the Fused MoE module, functions of classes like
MoECommMethod and MoETokenDispatcher output data in dictionary or tuple
format, which hampers code maintainability, readability, and
extensibility. This PR introduces dataclasses for these key output types
to address these issues.
- vLLM version: v0.13.0
- vLLM main:
5326c89803
---------
Signed-off-by: Jade Zheng <zheng.shoujian@outlook.com>
This commit is contained in:
@@ -8,6 +8,8 @@ from vllm_ascend.ops.fused_moe.moe_comm_method import (AllGatherCommImpl,
|
||||
AlltoAllCommImpl,
|
||||
MC2CommImpl)
|
||||
from vllm_ascend.ops.fused_moe.prepare_finalize import QuantType
|
||||
from vllm_ascend.ops.fused_moe.token_dispatcher import (TokenCombineResult,
|
||||
TokenDispatchResult)
|
||||
|
||||
|
||||
class TestMoECommMethod(TestBase):
|
||||
@@ -178,12 +180,12 @@ class TestMoECommMethod(TestBase):
|
||||
|
||||
# Mock token dispatcher
|
||||
mock_td_instance = MagicMock()
|
||||
mock_td_instance.token_dispatch.return_value = {
|
||||
"hidden_states": torch.randn(6, 8),
|
||||
"group_list": torch.tensor([2, 2, 2]),
|
||||
"group_list_type": 1
|
||||
}
|
||||
mock_td_instance.token_combine.return_value = torch.randn(4, 8)
|
||||
mock_td_instance.token_dispatch.return_value = TokenDispatchResult(
|
||||
hidden_states=torch.randn(6, 8),
|
||||
group_list=torch.tensor([2, 2, 2]),
|
||||
group_list_type=1)
|
||||
mock_td_instance.token_combine.return_value = TokenCombineResult(
|
||||
routed_out=torch.randn(4, 8))
|
||||
mock_token_dispatcher.return_value = mock_td_instance
|
||||
|
||||
# Mock unified_apply_mlp
|
||||
@@ -213,7 +215,7 @@ class TestMoECommMethod(TestBase):
|
||||
activation="silu")
|
||||
|
||||
# Verify result shape
|
||||
self.assertEqual(result.shape, (4, 8))
|
||||
self.assertEqual(result.routed_out.shape, (4, 8))
|
||||
|
||||
# Verify token_dispatch was called
|
||||
mock_td_instance.token_dispatch.assert_called_once()
|
||||
|
||||
@@ -97,8 +97,7 @@ class TestTokenDispatcherWithMC2(TestBase):
|
||||
topk_weights, topk_ids,
|
||||
expert_map)
|
||||
mock_dispatch.assert_called_once()
|
||||
self.assertEqual(output["group_list_type"],
|
||||
0) # group_list_type == 0
|
||||
self.assertEqual(output.group_list_type, 0) # group_list_type == 0
|
||||
|
||||
def test_token_dispatch_with_shared_experts_and_quant(self):
|
||||
self.shared_experts = MagicMock()
|
||||
@@ -149,43 +148,6 @@ class TestTokenDispatcherWithMC2(TestBase):
|
||||
context_metadata)
|
||||
self.assertIn("tp_send_counts", kwargs)
|
||||
|
||||
def test_token_combine_with_shared_experts(self):
|
||||
shared_experts = MagicMock()
|
||||
shared_experts.down_proj.return_value = (torch.randn(10, 128),
|
||||
torch.tensor(1.0))
|
||||
|
||||
topk_ids = torch.randint(0, 8, (10, 1))
|
||||
topk_weights = torch.randn(10, 1)
|
||||
expert_map = torch.tensor([0, 1, 2, 3, 4, 5, 6, 7])
|
||||
ep_recv_counts = torch.tensor([0, 1, 2, 3, 4, 5, 6, 7])
|
||||
assist_info_for_combine = torch.arange(10)
|
||||
tp_recv_counts = torch.tensor([0, 1, 2, 3, 4, 5, 6, 7])
|
||||
|
||||
context_metadata = {
|
||||
"topk_ids": topk_ids,
|
||||
"topk_weights": topk_weights,
|
||||
"expert_map": expert_map,
|
||||
"ep_recv_counts": ep_recv_counts,
|
||||
"mc2_mask": None,
|
||||
"assist_info_for_combine": assist_info_for_combine,
|
||||
"expand_scales": None,
|
||||
"shared_experts": shared_experts,
|
||||
"shared_act": torch.randn(10, 128),
|
||||
"swiglu_out_scale": torch.randn(10, 1),
|
||||
"tp_recv_counts": tp_recv_counts
|
||||
}
|
||||
|
||||
self.dispatcher.with_quant = True
|
||||
self.dispatcher.need_extra_args = True
|
||||
self.dispatcher.enable_dispatch_v2 = True
|
||||
|
||||
hidden_states = torch.randn(10, 128)
|
||||
with patch("torch_npu.npu_moe_distribute_combine_v2",
|
||||
return_value=torch.randn(10, 128)):
|
||||
result = self.dispatcher.token_combine(hidden_states,
|
||||
context_metadata)
|
||||
self.assertIsInstance(result, tuple)
|
||||
|
||||
|
||||
class TestTokenDispatcherWithAllGather(TestBase):
|
||||
|
||||
@@ -233,7 +195,7 @@ class TestTokenDispatcherWithAllGather(TestBase):
|
||||
self.mock_npu_moe_init_routing_v2.assert_called_once()
|
||||
args, kwargs = self.mock_npu_moe_init_routing_v2.call_args
|
||||
|
||||
self.assertEqual(results["group_list_type"], 1)
|
||||
self.assertEqual(results.group_list_type, 1)
|
||||
|
||||
def test_token_dispatch_with_expert_map(self):
|
||||
self.dispatcher.expert_map = torch.tensor([0, 1, 2, 3])
|
||||
@@ -248,7 +210,7 @@ class TestTokenDispatcherWithAllGather(TestBase):
|
||||
self.mock_npu_moe_init_routing_v2.assert_called_once()
|
||||
args, kwargs = self.mock_npu_moe_init_routing_v2.call_args
|
||||
|
||||
self.assertEqual(results["group_list_type"], 1)
|
||||
self.assertEqual(results.group_list_type, 1)
|
||||
|
||||
def test_token_dispatch_without_quant(self):
|
||||
kwargs = {
|
||||
@@ -268,7 +230,7 @@ class TestTokenDispatcherWithAllGather(TestBase):
|
||||
topk_weights, topk_ids,
|
||||
None)
|
||||
|
||||
self.assertEqual(results["group_list_type"], 1)
|
||||
self.assertEqual(results.group_list_type, 1)
|
||||
|
||||
def test_token_dispatch_with_quant(self):
|
||||
kwargs = {
|
||||
@@ -290,10 +252,10 @@ class TestTokenDispatcherWithAllGather(TestBase):
|
||||
None,
|
||||
with_quant=True)
|
||||
|
||||
self.assertIsNotNone(results["hidden_states"])
|
||||
self.assertIsNotNone(results["group_list"])
|
||||
self.assertIsNotNone(results["dynamic_scale"])
|
||||
self.assertEqual(results["group_list_type"], 1)
|
||||
self.assertIsNotNone(results.hidden_states)
|
||||
self.assertIsNotNone(results.group_list)
|
||||
self.assertIsNotNone(results.dynamic_scale)
|
||||
self.assertEqual(results.group_list_type, 1)
|
||||
|
||||
def test_token_combine_with_expert_map(self):
|
||||
hidden_states = torch.randn(6, 128)
|
||||
@@ -303,7 +265,7 @@ class TestTokenDispatcherWithAllGather(TestBase):
|
||||
}
|
||||
self.dispatcher.original_shape = (6, 128)
|
||||
final_hidden_states = self.dispatcher.token_combine(
|
||||
hidden_states, context_metadata)
|
||||
hidden_states, context_metadata).routed_out
|
||||
self.assertEqual(final_hidden_states.shape, (6, 128))
|
||||
|
||||
def test_token_combine_without_expert_map(self):
|
||||
@@ -314,7 +276,7 @@ class TestTokenDispatcherWithAllGather(TestBase):
|
||||
}
|
||||
self.dispatcher.original_shape = (6, 128)
|
||||
final_hidden_states = self.dispatcher.token_combine(
|
||||
hidden_states, context_metadata)
|
||||
hidden_states, context_metadata).routed_out
|
||||
self.mock_npu_moe_token_unpermute.assert_called_once()
|
||||
self.assertEqual(final_hidden_states.shape, (6, 128))
|
||||
|
||||
@@ -326,7 +288,7 @@ class TestTokenDispatcherWithAllGather(TestBase):
|
||||
|
||||
results = self.dispatcher.token_dispatch(hidden_states, topk_weights,
|
||||
topk_ids, None)
|
||||
self.assertEqual(results["hidden_states"].shape, (6, 128))
|
||||
self.assertEqual(results.hidden_states.shape, (6, 128))
|
||||
|
||||
|
||||
class TestTokenDispatcherWithAll2AllV(TestBase):
|
||||
@@ -437,9 +399,9 @@ class TestTokenDispatcherWithAll2AllV(TestBase):
|
||||
topk_ids=topk_ids,
|
||||
expert_map=expert_map)
|
||||
|
||||
self.assertIsNotNone(result["hidden_states"])
|
||||
self.assertIsNotNone(result["group_list"])
|
||||
self.assertEqual(result["group_list_type"], 1)
|
||||
self.assertIsNotNone(result.hidden_states)
|
||||
self.assertIsNotNone(result.group_list)
|
||||
self.assertEqual(result.group_list_type, 1)
|
||||
|
||||
def test_token_combine(self):
|
||||
hidden_states = torch.randn(16, 16)
|
||||
@@ -458,7 +420,7 @@ class TestTokenDispatcherWithAll2AllV(TestBase):
|
||||
|
||||
output = self.dispatcher.token_combine(hidden_states, context_metadata)
|
||||
self.assertIsNotNone(output)
|
||||
self.assertEqual(output.shape, (8, 16))
|
||||
self.assertEqual(output.routed_out.shape, (8, 16))
|
||||
|
||||
def test_token_dispatch_with_quant(self):
|
||||
self.dispatcher = TokenDispatcherWithAll2AllV(top_k=2,
|
||||
@@ -480,10 +442,10 @@ class TestTokenDispatcherWithAll2AllV(TestBase):
|
||||
expert_map=expert_map,
|
||||
with_quant=True)
|
||||
|
||||
self.assertIsNotNone(result["hidden_states"])
|
||||
self.assertIsNotNone(result["group_list"])
|
||||
self.assertIsNotNone(result["dynamic_scale"])
|
||||
self.assertEqual(result["group_list_type"], 1)
|
||||
self.assertIsNotNone(result.hidden_states)
|
||||
self.assertIsNotNone(result.group_list)
|
||||
self.assertIsNotNone(result.dynamic_scale)
|
||||
self.assertEqual(result.group_list_type, 1)
|
||||
|
||||
def test_token_dispatch_with_quant_no_active_tokens(self):
|
||||
self.dispatcher = TokenDispatcherWithAll2AllV(top_k=2,
|
||||
@@ -508,10 +470,10 @@ class TestTokenDispatcherWithAll2AllV(TestBase):
|
||||
expert_map=expert_map,
|
||||
with_quant=True)
|
||||
|
||||
self.assertIsNotNone(result["hidden_states"])
|
||||
self.assertIsNotNone(result["group_list"])
|
||||
self.assertIsNotNone(result["dynamic_scale"])
|
||||
self.assertEqual(result["group_list_type"], 1)
|
||||
self.assertIsNotNone(result.hidden_states)
|
||||
self.assertIsNotNone(result.group_list)
|
||||
self.assertIsNotNone(result.dynamic_scale)
|
||||
self.assertEqual(result.group_list_type, 1)
|
||||
|
||||
def test_token_dispatch_with_log2phy(self):
|
||||
hidden_states = torch.randn(8, 16)
|
||||
@@ -530,6 +492,6 @@ class TestTokenDispatcherWithAll2AllV(TestBase):
|
||||
expert_map=expert_map,
|
||||
log2phy=log2phy)
|
||||
|
||||
self.assertIsNotNone(result["hidden_states"])
|
||||
self.assertIsNotNone(result["group_list"])
|
||||
self.assertEqual(result["group_list_type"], 1)
|
||||
self.assertIsNotNone(result.hidden_states)
|
||||
self.assertIsNotNone(result.group_list)
|
||||
self.assertEqual(result.group_list_type, 1)
|
||||
|
||||
Reference in New Issue
Block a user