[Refactor] Formatting output types related to FuseMoE (#5481)
In the Fused MoE module, methods of classes such as MoECommMethod and MoETokenDispatcher currently return their outputs as plain dictionaries or tuples, which hampers code maintainability, readability, and extensibility. This PR introduces dataclasses for these key output types to address those issues.
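For readers skimming the diff, here is a minimal sketch of the kind of output dataclass this refactor introduces. The class name `FusedExpertsResult` and the fields `routed_out`, `expert_tokens`, and `group_list_type` come from the diff below; the type hints, defaults, and docstring are illustrative assumptions, while the actual definition lives in `vllm_ascend.ops.fused_moe.moe_comm_method`.

```python
from dataclasses import dataclass
from typing import Optional

import torch


@dataclass
class FusedExpertsResult:
    """Illustrative sketch of the structured result of the fused-experts apply().

    Callers read named fields instead of unpacking a positional tuple.
    """
    # Hidden states produced by the routed experts.
    routed_out: torch.Tensor
    # Per-expert token counts, populated when dynamic EPLB bookkeeping needs them.
    expert_tokens: Optional[torch.Tensor] = None
    # Layout of expert_tokens; the caller differences the counts when
    # group_list_type != 1 (see the dynamic_eplb block in the diff below).
    group_list_type: Optional[int] = None
```

Call sites such as `AscendFusedMoE.forward_impl` then access `fused_experts_results.routed_out` and the EPLB bookkeeping fields by name, as the diff shows.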
- vLLM version: v0.13.0
- vLLM main:
5326c89803
---------
Signed-off-by: Jade Zheng <zheng.shoujian@outlook.com>
@@ -37,6 +37,7 @@ from vllm_ascend.flash_common3_context import (get_flash_common3_context,
                                                set_flash_common3_context)
 from vllm_ascend.ops.fused_moe.experts_selector import select_experts
 from vllm_ascend.ops.fused_moe.moe_comm_method import (AllGatherCommImpl,
+                                                        FusedExpertsResult,
                                                         setup_moe_comm_method)
 from vllm_ascend.ops.fused_moe.prepare_finalize import QuantType
 from vllm_ascend.quantization.w4a8_dynamic import \
@@ -325,7 +326,7 @@ class AscendFusedMoE(FusedMoE):
             pertoken_scale = None
 
         # Matrix multiply.
-        final_hidden_states = self.quant_method.apply(
+        fused_experts_results: FusedExpertsResult = self.quant_method.apply(
             layer=self,
             x=hidden_states,
             router_logits=router_logits,
@@ -350,25 +351,25 @@ class AscendFusedMoE(FusedMoE):
             global_redundant_expert_num=self.global_redundant_expert_num,
             mc2_mask=mc2_mask)
 
-        if isinstance(final_hidden_states, tuple):
-            final_hidden_states, group_list_type, expert_tokens = final_hidden_states
-            if self.dynamic_eplb:
-                moe_load_stream = moe_load_async_stream()
-                cur_stream = torch.npu.current_stream()
-
-                moe_load_stream.wait_stream(cur_stream)
-                with npu_stream_switch(moe_load_stream):
-                    self.moe_load += expert_tokens if group_list_type == 1 else \
-                        torch.cat([expert_tokens[:1], expert_tokens[1:] - expert_tokens[:-1]])
-                cur_stream.wait_stream(moe_load_stream)
+        if self.dynamic_eplb:
+            expert_tokens = fused_experts_results.expert_tokens
+            group_list_type = fused_experts_results.group_list_type
+            assert expert_tokens is not None and group_list_type is not None, \
+                "expert_tokens and group_list_type should not be None when dynamic_eplb is enabled."
+            moe_load_stream = moe_load_async_stream()
+            cur_stream = torch.npu.current_stream()
+
+            moe_load_stream.wait_stream(cur_stream)
+            with npu_stream_switch(moe_load_stream):
+                self.moe_load += expert_tokens if group_list_type == 1 else \
+                    torch.cat([expert_tokens[:1], expert_tokens[1:] - expert_tokens[:-1]])
+            cur_stream.wait_stream(moe_load_stream)
 
-        final_hidden_states = forward_context.moe_comm_method.finalize(
-            hidden_states=final_hidden_states,
+        routed_out = forward_context.moe_comm_method.finalize(
+            hidden_states=fused_experts_results.routed_out,
             reduce_results=self.reduce_results,
             context_metadata=context_metadata)
 
-        return final_hidden_states
+        return routed_out
 
 
 class AscendSharedFusedMoE(SharedFusedMoE, AscendFusedMoE):
@@ -439,7 +440,7 @@ class AscendSharedFusedMoE(SharedFusedMoE, AscendFusedMoE):
         else:
             set_flash_common3_context(shared_experts=self._shared_experts)
 
-        fused_output = AscendFusedMoE.forward_impl(
+        routed_out = AscendFusedMoE.forward_impl(
             self,
             hidden_states=hidden_states,
             router_logits=router_logits,
@@ -462,4 +463,4 @@ class AscendSharedFusedMoE(SharedFusedMoE, AscendFusedMoE):
         assert fc3_context is not None
         shared_out = fc3_context.shared_out
 
-        return shared_out, fused_output
+        return shared_out, routed_out