[Refactor] Formatting output types related to FuseMoE (#5481)

Currently in the Fused MoE module, methods of classes such as
MoECommMethod and MoETokenDispatcher return their outputs as
dictionaries or tuples, which hampers code maintainability,
readability, and extensibility. This PR introduces dataclasses for
these key output types to address these issues.

- vLLM version: v0.13.0
- vLLM main:
5326c89803

---------

Signed-off-by: Jade Zheng <zheng.shoujian@outlook.com>
This commit is contained in:
Jade Zheng
2025-12-31 14:24:37 +08:00
committed by GitHub
parent 38570cfeb6
commit 7d5242faca
6 changed files with 155 additions and 212 deletions

View File

@@ -136,10 +136,10 @@ def test_token_dispatcher_with_all_gather(
expert_map=expert_map,
apply_router_weight_on_input=apply_router_weight_on_input)
sorted_hidden_states = dispatch_output["hidden_states"]
group_list = dispatch_output["group_list"]
group_list_type = dispatch_output.get("group_list_type", 1)
context_metadata = dispatch_output["context_metadata"]
sorted_hidden_states = dispatch_output.hidden_states
group_list = dispatch_output.group_list
group_list_type = dispatch_output.group_list_type
context_metadata = dispatch_output.context_metadata
expert_output = apply_mlp(hidden_states=sorted_hidden_states,
w1=w1_local,
@@ -155,7 +155,7 @@ def test_token_dispatcher_with_all_gather(
torch_output = torch_moe(a, w1, w2, topk_weights, topk_ids, topk,
expert_map)
torch.testing.assert_close(combined_output,
torch.testing.assert_close(combined_output.routed_out,
torch_output,
atol=4e-2,
rtol=1)
@@ -216,11 +216,11 @@ def test_token_dispatcher_with_all_gather_quant(
apply_router_weight_on_input=apply_router_weight_on_input,
with_quant=True)
sorted_hidden_states = dispatch_output["hidden_states"]
group_list = dispatch_output["group_list"]
group_list_type = dispatch_output.get("group_list_type", 1)
dynamic_scale = dispatch_output["dynamic_scale"]
context_metadata = dispatch_output["context_metadata"]
sorted_hidden_states = dispatch_output.hidden_states
group_list = dispatch_output.group_list
group_list_type = dispatch_output.group_list_type
dynamic_scale = dispatch_output.dynamic_scale
context_metadata = dispatch_output.context_metadata
expert_output = unified_apply_mlp(hidden_states=sorted_hidden_states,
w1=w1,
@@ -235,7 +235,7 @@ def test_token_dispatcher_with_all_gather_quant(
hidden_states=expert_output,
context_metadata=context_metadata,
bias=None)
assert combined_output.shape == (m, k)
assert combined_output.routed_out.shape == (m, k)
gc.collect()
torch.npu.empty_cache()
torch.npu.reset_peak_memory_stats()