[Refactor] Formatting output types related to FuseMoE (#5481)
Currently, in the Fused MoE module, methods on classes such as
MoECommMethod and MoETokenDispatcher return their results as plain
dictionaries or tuples, which hampers code maintainability, readability,
and extensibility. This PR introduces dataclasses for these key output
types to address these issues; a caller-side sketch follows the commit
metadata below.
- vLLM version: v0.13.0
- vLLM main: 5326c89803
---------
Signed-off-by: Jade Zheng <zheng.shoujian@outlook.com>
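
For illustration, a minimal sketch of what the new output types mean for call
sites. The TokenDispatchResult definition mirrors the one added in the diff
below; the example tensors and the legacy dict literal are hypothetical
stand-ins for real dispatcher output, not code from this PR.

from dataclasses import dataclass, field

import torch


@dataclass
class TokenDispatchResult:
    # Mirrors the definition introduced in the diff below.
    hidden_states: torch.Tensor
    group_list: torch.Tensor
    group_list_type: int
    dynamic_scale: torch.Tensor | None = field(default=None)
    topk_scales: torch.Tensor | None = field(default=None)
    context_metadata: dict = field(default_factory=dict)


# Before: dispatch returned a bare dict, so callers indexed by string key
# and a typo such as "group_lists" only failed at runtime.
legacy = {
    "hidden_states": torch.zeros(4, 8),
    "group_list": torch.tensor([2, 2]),
    "group_list_type": 0,
    "dynamic_scale": None,
    "context_metadata": {},
}
hidden = legacy["hidden_states"]

# After: the dataclass names every field, supplies defaults for the optional
# ones, and lets type checkers and IDEs verify each access.
result = TokenDispatchResult(hidden_states=torch.zeros(4, 8),
                             group_list=torch.tensor([2, 2]),
                             group_list_type=0)
hidden = result.hidden_states

Combine results follow the same pattern: token_combine now returns a
TokenCombineResult whose routed_out field carries the routed expert output.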
@@ -21,6 +21,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 from abc import ABC, abstractmethod
+from dataclasses import dataclass, field
 from typing import Any, Optional

 import torch
@@ -35,6 +36,21 @@ from vllm_ascend.utils import (AscendDeviceType, get_ascend_device_type,
                                is_hierarchical_communication_enabled)


+@dataclass
+class TokenDispatchResult:
+    hidden_states: torch.Tensor
+    group_list: torch.Tensor
+    group_list_type: int
+    dynamic_scale: torch.Tensor | None = field(default=None)
+    topk_scales: torch.Tensor | None = field(default=None)
+    context_metadata: dict = field(default_factory=dict)
+
+
+@dataclass
+class TokenCombineResult:
+    routed_out: torch.Tensor
+
+
 class MoETokenDispatcher(ABC):

     def __init__(self, **kwargs) -> None:
@@ -74,14 +90,14 @@ class MoETokenDispatcher(ABC):
         with_quant: bool = False,
         dynamic_eplb: bool = False,
         pertoken_scale: Optional[torch.Tensor] = None,
-    ):
+    ) -> TokenDispatchResult:
         raise NotImplementedError("Dispatch function not implemented.")

     @abstractmethod
     def token_combine(self,
                       hidden_states: torch.Tensor,
                       context_metadata: dict,
-                      bias: torch.Tensor = None):
+                      bias: torch.Tensor | None = None) -> TokenCombineResult:
         raise NotImplementedError("Combine function not implemented.")


@@ -207,24 +223,6 @@ class TokenDispatcherWithMC2(MoETokenDispatcher):
         expand_x, dynamic_scale, assist_info_for_combine, expert_token_nums, \
             ep_recv_counts, tp_recv_counts, expand_scales = output[0:7]

-        # Handle shared experts (store intermediate results in local vars, not self)
-        shared_act = None
-        swiglu_out_scale = None
-        if with_quant:
-            if shared_experts is not None:
-                share_up_out, _ = shared_experts.gate_up_proj(
-                    (quantized_x_for_share, dynamic_scale_for_share))
-                shared_gate_up, shared_dequant_scale = share_up_out[
-                    0], share_up_out[1]
-                shared_act_out = shared_experts.act_fn(
-                    (shared_gate_up, shared_dequant_scale))
-                shared_act, swiglu_out_scale = shared_act_out[
-                    0], shared_act_out[1]
-        else:
-            if shared_experts is not None:
-                shared_gate_up, _ = shared_experts.gate_up_proj(hidden_states)
-                shared_act = shared_experts.act_fn(shared_gate_up)
-
         context_metadata = {
             "topk_ids": topk_ids,
             "topk_weights": topk_weights,
@@ -233,20 +231,16 @@ class TokenDispatcherWithMC2(MoETokenDispatcher):
             "tp_recv_counts": tp_recv_counts,
             "assist_info_for_combine": assist_info_for_combine,
-            "shared_experts": shared_experts,
-            "shared_act": shared_act,
-            "swiglu_out_scale": swiglu_out_scale,
             "expand_scales": expand_scales
         }

         group_list_type = 0

-        return {
-            "group_list_type": group_list_type,
-            "hidden_states": expand_x,
-            "group_list": expert_token_nums,
-            "dynamic_scale": dynamic_scale,
-            "context_metadata": context_metadata
-        }
+        return TokenDispatchResult(hidden_states=expand_x,
+                                   dynamic_scale=dynamic_scale,
+                                   group_list=expert_token_nums,
+                                   group_list_type=group_list_type,
+                                   context_metadata=context_metadata)

     def get_combine_mc_kwargs(self, hidden_states: torch.Tensor,
                               context_metadata: dict):
@@ -300,12 +294,7 @@ class TokenDispatcherWithMC2(MoETokenDispatcher):
         kwargs_mc2.update(stage3_kwargs)
         return kwargs_mc2

-    def token_combine(
-        self,
-        hidden_states: torch.Tensor,
-        context_metadata: dict,
-        bias: torch.Tensor = None,
-    ):
+    def token_combine(self, hidden_states, context_metadata, bias=None):
         assert bias is None, "Bias is not supported in MoEAlltoAllvTokenDispatcher."

         kwargs_mc2 = self.get_combine_mc_kwargs(hidden_states,
@@ -313,20 +302,7 @@ class TokenDispatcherWithMC2(MoETokenDispatcher):
         combined_output = torch_npu.npu_moe_distribute_combine_v2(**kwargs_mc2) \
             if self.enable_dispatch_v2 else torch_npu.npu_moe_distribute_combine(**kwargs_mc2)

-        # Handle shared experts from metadata
-        shared_experts = context_metadata["shared_experts"]
-        if shared_experts is None:
-            return combined_output
-
-        shared_act = context_metadata["shared_act"]
-        if self.with_quant:
-            swiglu_out_scale = context_metadata["swiglu_out_scale"]
-            shared_hidden_states, _ = shared_experts.down_proj(
-                (shared_act, swiglu_out_scale))
-        else:
-            shared_hidden_states, _ = shared_experts.down_proj(shared_act)
-
-        return combined_output, shared_hidden_states
+        return TokenCombineResult(routed_out=combined_output)


 class TokenDispatcherWithAllGather(MoETokenDispatcher):
@@ -401,18 +377,16 @@ class TokenDispatcherWithAllGather(MoETokenDispatcher):
             "topk_weights": topk_weights,
             "expanded_row_idx": expanded_row_idx
         }
-        return {
-            "group_list_type": group_list_type,
-            "hidden_states": sorted_hidden_states,
-            "group_list": expert_tokens,
-            "dynamic_scale": pertoken_scale if self.with_quant else None,
-            "context_metadata": context_metadata
-        }
+        return TokenDispatchResult(
+            hidden_states=sorted_hidden_states,
+            dynamic_scale=pertoken_scale if self.with_quant else None,
+            group_list=expert_tokens,
+            group_list_type=group_list_type,
+            context_metadata=context_metadata,
+        )

-    def token_combine(self,
-                      hidden_states: torch.Tensor,
-                      context_metadata: dict,
-                      bias: torch.Tensor = None):
+    def token_combine(self, hidden_states, context_metadata, bias=None):
         assert self.original_shape is not None
         final_hidden_states = torch_npu.npu_moe_token_unpermute(
             permuted_tokens=hidden_states,
@@ -422,7 +396,7 @@ class TokenDispatcherWithAllGather(MoETokenDispatcher):
         final_hidden_states = final_hidden_states.view(self.original_shape)

         # these values are no longer used, so they need to be set to None for memory release.
-        return final_hidden_states
+        return TokenCombineResult(routed_out=final_hidden_states)


 class TokenDispatcherWithAll2AllV(MoETokenDispatcher):
@@ -530,20 +504,15 @@ class TokenDispatcherWithAll2AllV(MoETokenDispatcher):
                 reversed_global_input_permutation_mapping
         }

-        return {
-            "hidden_states": global_input_tokens,
-            "group_list": tokens_per_expert,
-            "group_list_type": 1,
-            "dynamic_scale": dynamic_scale_final,
-            "context_metadata": context_metadata,
-        }
+        return TokenDispatchResult(
+            hidden_states=global_input_tokens,
+            dynamic_scale=dynamic_scale_final,
+            group_list=tokens_per_expert,
+            group_list_type=1,
+            context_metadata=context_metadata,
+        )

-    def token_combine(
-        self,
-        hidden_states: torch.Tensor,
-        context_metadata: dict,
-        bias: torch.Tensor = None,
-    ):
+    def token_combine(self, hidden_states, context_metadata, bias=None):
         assert bias is None, "Bias is not supported in MoEAlltoAllvTokenDispatcher."

         # 1. Preprocess using metadata
@@ -564,7 +533,7 @@ class TokenDispatcherWithAll2AllV(MoETokenDispatcher):
         output = self._combine_postprocess(permutated_local_input_tokens,
                                            context_metadata)

-        return output
+        return TokenCombineResult(routed_out=output)

     def _dispatch_preprocess(self, hidden_states, topk_ids):
         assert self.hidden_shape is not None