[Refactor] Formatting output types related to FusedMoE (#5481)

Currently, in the fused MoE module, methods of classes such as
MoECommMethod and MoETokenDispatcher return their results as
dictionaries or tuples, which hampers code maintainability,
readability, and extensibility. This PR introduces dataclasses for
these key output types to address those issues.
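
For readers skimming the diff, the call sites below read routed_out, expert_tokens, and group_list_type off the new result object. A minimal sketch of such a dataclass (field types and defaults are assumptions for illustration, not copied from this PR) could look like:

```python
from dataclasses import dataclass
from typing import Optional

import torch


@dataclass
class FusedExpertsResult:
    # Hidden states produced by the routed experts (previously the first
    # element of the returned tuple).
    routed_out: torch.Tensor
    # Per-expert token statistics; only needed when dynamic EPLB is enabled.
    expert_tokens: Optional[torch.Tensor] = None
    group_list_type: Optional[int] = None
```

Compared with unpacking `final_hidden_states, group_list_type, expert_tokens = ...`, named fields make the optional statistics explicit and allow new outputs to be added without touching every call site.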

- vLLM version: v0.13.0
- vLLM main: 5326c89803

---------

Signed-off-by: Jade Zheng <zheng.shoujian@outlook.com>
Author: Jade Zheng
Date: 2025-12-31 14:24:37 +08:00
Committed by: GitHub
Parent: 38570cfeb6
Commit: 7d5242faca
6 changed files with 155 additions and 212 deletions


@@ -37,6 +37,7 @@ from vllm_ascend.flash_common3_context import (get_flash_common3_context,
                                                set_flash_common3_context)
 from vllm_ascend.ops.fused_moe.experts_selector import select_experts
 from vllm_ascend.ops.fused_moe.moe_comm_method import (AllGatherCommImpl,
+                                                        FusedExpertsResult,
                                                         setup_moe_comm_method)
 from vllm_ascend.ops.fused_moe.prepare_finalize import QuantType
 from vllm_ascend.quantization.w4a8_dynamic import \
@@ -325,7 +326,7 @@ class AscendFusedMoE(FusedMoE):
             pertoken_scale = None
 
         # Matrix multiply.
-        final_hidden_states = self.quant_method.apply(
+        fused_experts_results: FusedExpertsResult = self.quant_method.apply(
             layer=self,
             x=hidden_states,
             router_logits=router_logits,
@@ -350,25 +351,25 @@ class AscendFusedMoE(FusedMoE):
             global_redundant_expert_num=self.global_redundant_expert_num,
             mc2_mask=mc2_mask)
 
-        if isinstance(final_hidden_states, tuple):
-            final_hidden_states, group_list_type, expert_tokens = final_hidden_states
-            if self.dynamic_eplb:
+        if self.dynamic_eplb:
+            expert_tokens = fused_experts_results.expert_tokens
+            group_list_type = fused_experts_results.group_list_type
+            assert expert_tokens is not None and group_list_type is not None, \
+                "expert_tokens and group_list_type should not be None when dynamic_eplb is enabled."
-                moe_load_stream = moe_load_async_stream()
-                cur_stream = torch.npu.current_stream()
-                moe_load_stream.wait_stream(cur_stream)
-                with npu_stream_switch(moe_load_stream):
-                    self.moe_load += expert_tokens if group_list_type == 1 else \
-                        torch.cat([expert_tokens[:1], expert_tokens[1:] - expert_tokens[:-1]])
-                cur_stream.wait_stream(moe_load_stream)
+            moe_load_stream = moe_load_async_stream()
+            cur_stream = torch.npu.current_stream()
+            moe_load_stream.wait_stream(cur_stream)
+            with npu_stream_switch(moe_load_stream):
+                self.moe_load += expert_tokens if group_list_type == 1 else \
+                    torch.cat([expert_tokens[:1], expert_tokens[1:] - expert_tokens[:-1]])
+            cur_stream.wait_stream(moe_load_stream)
 
-        final_hidden_states = forward_context.moe_comm_method.finalize(
-            hidden_states=final_hidden_states,
+        routed_out = forward_context.moe_comm_method.finalize(
+            hidden_states=fused_experts_results.routed_out,
             reduce_results=self.reduce_results,
             context_metadata=context_metadata)
-        return final_hidden_states
+        return routed_out
 
 
 class AscendSharedFusedMoE(SharedFusedMoE, AscendFusedMoE):
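
As an aside on the moe_load update kept above: when group_list_type is not 1, expert_tokens appears to hold cumulative per-expert token counts, and the torch.cat expression differences them back into per-expert increments. A small standalone illustration (the values are made up):

```python
import torch

# Hypothetical cumulative counts: experts received 3, 2 and 4 tokens.
expert_tokens = torch.tensor([3, 5, 9])

# Same expression as in the diff: keep the first entry, then take adjacent differences.
per_expert = torch.cat([expert_tokens[:1], expert_tokens[1:] - expert_tokens[:-1]])
print(per_expert)  # tensor([3, 2, 4])
```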
@@ -439,7 +440,7 @@ class AscendSharedFusedMoE(SharedFusedMoE, AscendFusedMoE):
         else:
             set_flash_common3_context(shared_experts=self._shared_experts)
 
-        fused_output = AscendFusedMoE.forward_impl(
+        routed_out = AscendFusedMoE.forward_impl(
             self,
             hidden_states=hidden_states,
             router_logits=router_logits,
@@ -462,4 +463,4 @@ class AscendSharedFusedMoE(SharedFusedMoE, AscendFusedMoE):
         assert fc3_context is not None
         shared_out = fc3_context.shared_out
 
-        return shared_out, fused_output
+        return shared_out, routed_out