[Refactor] Formatting output types related to FuseMoE (#5481)

Currently, in the Fused MoE module, methods of classes such as
MoECommMethod and MoETokenDispatcher return their outputs as dictionaries
or tuples, which hampers code maintainability, readability, and
extensibility. This PR introduces dataclasses for these key output types
to address those issues.
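
For example, a call site changes roughly as follows (a minimal sketch:
`dispatcher`, the dispatch arguments, and `run_experts` are illustrative
placeholders; only the field names come from the new dataclasses):

    # Before: outputs came back as an untyped dict, accessed by string key.
    out = dispatcher.token_dispatch(hidden_states, topk_weights, topk_ids)
    expanded, group_list = out["hidden_states"], out["group_list"]

    # After: outputs are named, type-annotated attributes on a dataclass.
    result = dispatcher.token_dispatch(hidden_states, topk_weights, topk_ids)
    expert_out = run_experts(result.hidden_states, result.group_list)
    combined = dispatcher.token_combine(expert_out, result.context_metadata)
    final_hidden_states = combined.routed_out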

- vLLM version: v0.13.0
- vLLM main: 5326c89803

---------

Signed-off-by: Jade Zheng <zheng.shoujian@outlook.com>
Author: Jade Zheng
Date: 2025-12-31 14:24:37 +08:00 (committed by GitHub)
Parent: 38570cfeb6
Commit: 7d5242faca
6 changed files with 155 additions and 212 deletions


@@ -21,6 +21,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 from abc import ABC, abstractmethod
+from dataclasses import dataclass, field
 from typing import Any, Optional

 import torch
@@ -35,6 +36,21 @@ from vllm_ascend.utils import (AscendDeviceType, get_ascend_device_type,
                               is_hierarchical_communication_enabled)


+@dataclass
+class TokenDispatchResult:
+    hidden_states: torch.Tensor
+    group_list: torch.Tensor
+    group_list_type: int
+    dynamic_scale: torch.Tensor | None = field(default=None)
+    topk_scales: torch.Tensor | None = field(default=None)
+    context_metadata: dict = field(default_factory=dict)
+
+
+@dataclass
+class TokenCombineResult:
+    routed_out: torch.Tensor


 class MoETokenDispatcher(ABC):

     def __init__(self, **kwargs) -> None:
@@ -74,14 +90,14 @@ class MoETokenDispatcher(ABC):
         with_quant: bool = False,
         dynamic_eplb: bool = False,
         pertoken_scale: Optional[torch.Tensor] = None,
-    ):
+    ) -> TokenDispatchResult:
         raise NotImplementedError("Dispatch function not implemented.")

     @abstractmethod
     def token_combine(self,
                       hidden_states: torch.Tensor,
                       context_metadata: dict,
-                      bias: torch.Tensor = None):
+                      bias: torch.Tensor | None = None) -> TokenCombineResult:
         raise NotImplementedError("Combine function not implemented.")
@@ -207,24 +223,6 @@ class TokenDispatcherWithMC2(MoETokenDispatcher):
         expand_x, dynamic_scale, assist_info_for_combine, expert_token_nums, \
             ep_recv_counts, tp_recv_counts, expand_scales = output[0:7]

-        # Handle shared experts (store intermediate results in local vars, not self)
-        shared_act = None
-        swiglu_out_scale = None
-        if with_quant:
-            if shared_experts is not None:
-                share_up_out, _ = shared_experts.gate_up_proj(
-                    (quantized_x_for_share, dynamic_scale_for_share))
-                shared_gate_up, shared_dequant_scale = share_up_out[
-                    0], share_up_out[1]
-                shared_act_out = shared_experts.act_fn(
-                    (shared_gate_up, shared_dequant_scale))
-                shared_act, swiglu_out_scale = shared_act_out[
-                    0], shared_act_out[1]
-        else:
-            if shared_experts is not None:
-                shared_gate_up, _ = shared_experts.gate_up_proj(hidden_states)
-                shared_act = shared_experts.act_fn(shared_gate_up)
-
         context_metadata = {
             "topk_ids": topk_ids,
             "topk_weights": topk_weights,
@@ -233,20 +231,16 @@ class TokenDispatcherWithMC2(MoETokenDispatcher):
"tp_recv_counts": tp_recv_counts,
"assist_info_for_combine": assist_info_for_combine,
"shared_experts": shared_experts,
"shared_act": shared_act,
"swiglu_out_scale": swiglu_out_scale,
"expand_scales": expand_scales
}
group_list_type = 0
return {
"group_list_type": group_list_type,
"hidden_states": expand_x,
"group_list": expert_token_nums,
"dynamic_scale": dynamic_scale,
"context_metadata": context_metadata
}
return TokenDispatchResult(hidden_states=expand_x,
dynamic_scale=dynamic_scale,
group_list=expert_token_nums,
group_list_type=group_list_type,
context_metadata=context_metadata)
def get_combine_mc_kwargs(self, hidden_states: torch.Tensor,
context_metadata: dict):
@@ -300,12 +294,7 @@ class TokenDispatcherWithMC2(MoETokenDispatcher):
             kwargs_mc2.update(stage3_kwargs)
         return kwargs_mc2

-    def token_combine(
-        self,
-        hidden_states: torch.Tensor,
-        context_metadata: dict,
-        bias: torch.Tensor = None,
-    ):
+    def token_combine(self, hidden_states, context_metadata, bias=None):
         assert bias is None, "Bias is not supported in MoEAlltoAllvTokenDispatcher."

         kwargs_mc2 = self.get_combine_mc_kwargs(hidden_states,
@@ -313,20 +302,7 @@ class TokenDispatcherWithMC2(MoETokenDispatcher):
         combined_output = torch_npu.npu_moe_distribute_combine_v2(**kwargs_mc2) \
             if self.enable_dispatch_v2 else torch_npu.npu_moe_distribute_combine(**kwargs_mc2)

-        # Handle shared experts from metadata
-        shared_experts = context_metadata["shared_experts"]
-        if shared_experts is None:
-            return combined_output
-
-        shared_act = context_metadata["shared_act"]
-        if self.with_quant:
-            swiglu_out_scale = context_metadata["swiglu_out_scale"]
-            shared_hidden_states, _ = shared_experts.down_proj(
-                (shared_act, swiglu_out_scale))
-        else:
-            shared_hidden_states, _ = shared_experts.down_proj(shared_act)
-        return combined_output, shared_hidden_states
+        return TokenCombineResult(routed_out=combined_output)


 class TokenDispatcherWithAllGather(MoETokenDispatcher):
@@ -401,18 +377,16 @@ class TokenDispatcherWithAllGather(MoETokenDispatcher):
"topk_weights": topk_weights,
"expanded_row_idx": expanded_row_idx
}
return {
"group_list_type": group_list_type,
"hidden_states": sorted_hidden_states,
"group_list": expert_tokens,
"dynamic_scale": pertoken_scale if self.with_quant else None,
"context_metadata": context_metadata
}
def token_combine(self,
hidden_states: torch.Tensor,
context_metadata: dict,
bias: torch.Tensor = None):
return TokenDispatchResult(
hidden_states=sorted_hidden_states,
dynamic_scale=pertoken_scale if self.with_quant else None,
group_list=expert_tokens,
group_list_type=group_list_type,
context_metadata=context_metadata,
)
def token_combine(self, hidden_states, context_metadata, bias=None):
assert self.original_shape is not None
final_hidden_states = torch_npu.npu_moe_token_unpermute(
permuted_tokens=hidden_states,
@@ -422,7 +396,7 @@ class TokenDispatcherWithAllGather(MoETokenDispatcher):
             final_hidden_states = final_hidden_states.view(self.original_shape)

         # these values are no longer used, so they need to be set to None for memory release.
-        return final_hidden_states
+        return TokenCombineResult(routed_out=final_hidden_states)


 class TokenDispatcherWithAll2AllV(MoETokenDispatcher):
@@ -530,20 +504,15 @@ class TokenDispatcherWithAll2AllV(MoETokenDispatcher):
                 reversed_global_input_permutation_mapping
         }
-        return {
-            "hidden_states": global_input_tokens,
-            "group_list": tokens_per_expert,
-            "group_list_type": 1,
-            "dynamic_scale": dynamic_scale_final,
-            "context_metadata": context_metadata,
-        }
+        return TokenDispatchResult(
+            hidden_states=global_input_tokens,
+            dynamic_scale=dynamic_scale_final,
+            group_list=tokens_per_expert,
+            group_list_type=1,
+            context_metadata=context_metadata,
+        )

-    def token_combine(
-        self,
-        hidden_states: torch.Tensor,
-        context_metadata: dict,
-        bias: torch.Tensor = None,
-    ):
+    def token_combine(self, hidden_states, context_metadata, bias=None):
         assert bias is None, "Bias is not supported in MoEAlltoAllvTokenDispatcher."

         # 1. Preprocess using metadata
@@ -564,7 +533,7 @@ class TokenDispatcherWithAll2AllV(MoETokenDispatcher):
         output = self._combine_postprocess(permutated_local_input_tokens,
                                            context_metadata)
-        return output
+        return TokenCombineResult(routed_out=output)

     def _dispatch_preprocess(self, hidden_states, topk_ids):
         assert self.hidden_shape is not None