[Refactor] Formatting output types related to FusedMoE (#5481)
Currently in the Fused MoE module, methods of classes such as
MoECommMethod and MoETokenDispatcher return data as dictionaries or
tuples, which hampers code maintainability, readability, and
extensibility. This PR introduces dataclasses for these key output types
to address those issues.
- vLLM version: v0.13.0
- vLLM main:
5326c89803
---------
Signed-off-by: Jade Zheng <zheng.shoujian@outlook.com>
This commit is contained in:
@@ -136,10 +136,10 @@ def test_token_dispatcher_with_all_gather(
|
||||
expert_map=expert_map,
|
||||
apply_router_weight_on_input=apply_router_weight_on_input)
|
||||
|
||||
sorted_hidden_states = dispatch_output["hidden_states"]
|
||||
group_list = dispatch_output["group_list"]
|
||||
group_list_type = dispatch_output.get("group_list_type", 1)
|
||||
context_metadata = dispatch_output["context_metadata"]
|
||||
sorted_hidden_states = dispatch_output.hidden_states
|
||||
group_list = dispatch_output.group_list
|
||||
group_list_type = dispatch_output.group_list_type
|
||||
context_metadata = dispatch_output.context_metadata
|
||||
|
||||
expert_output = apply_mlp(hidden_states=sorted_hidden_states,
|
||||
w1=w1_local,
|
||||
@@ -155,7 +155,7 @@ def test_token_dispatcher_with_all_gather(
|
||||
torch_output = torch_moe(a, w1, w2, topk_weights, topk_ids, topk,
|
||||
expert_map)
|
||||
|
||||
torch.testing.assert_close(combined_output,
|
||||
torch.testing.assert_close(combined_output.routed_out,
|
||||
torch_output,
|
||||
atol=4e-2,
|
||||
rtol=1)
|
||||
@@ -216,11 +216,11 @@ def test_token_dispatcher_with_all_gather_quant(
|
||||
apply_router_weight_on_input=apply_router_weight_on_input,
|
||||
with_quant=True)
|
||||
|
||||
sorted_hidden_states = dispatch_output["hidden_states"]
|
||||
group_list = dispatch_output["group_list"]
|
||||
group_list_type = dispatch_output.get("group_list_type", 1)
|
||||
dynamic_scale = dispatch_output["dynamic_scale"]
|
||||
context_metadata = dispatch_output["context_metadata"]
|
||||
sorted_hidden_states = dispatch_output.hidden_states
|
||||
group_list = dispatch_output.group_list
|
||||
group_list_type = dispatch_output.group_list_type
|
||||
dynamic_scale = dispatch_output.dynamic_scale
|
||||
context_metadata = dispatch_output.context_metadata
|
||||
|
||||
expert_output = unified_apply_mlp(hidden_states=sorted_hidden_states,
|
||||
w1=w1,
|
||||
@@ -235,7 +235,7 @@ def test_token_dispatcher_with_all_gather_quant(
|
||||
hidden_states=expert_output,
|
||||
context_metadata=context_metadata,
|
||||
bias=None)
|
||||
assert combined_output.shape == (m, k)
|
||||
assert combined_output.routed_out.shape == (m, k)
|
||||
gc.collect()
|
||||
torch.npu.empty_cache()
|
||||
torch.npu.reset_peak_memory_stats()
|
||||
|
||||
Reference in New Issue
Block a user