[refactor] replace scattered business kwargs with typed request objects and explicit stage boundaries (#7024)
### What this PR does / why we need it? Refactor `vllm_ascend/ops/fused_moe` to replace scattered MoE business `**kwargs` with typed request objects and explicit stage boundaries. - Prepare, dispatch, MLP, and quant stages now have clearer ownership. - Main MoE path no longer depends on business `kwargs.get(...)` lookups. - Comm and dispatcher interfaces are request-only on the main path. - UTs can assert stage-level fields directly instead of inferring behavior indirectly. ### Does this PR introduce _any_ user-facing change? No. ### How was this patch tested? CI passed. --------- Signed-off-by: linfeng-yuan <1102311262@qq.com>
This commit is contained in:
@@ -21,7 +21,7 @@
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
from abc import ABC, abstractmethod
|
||||
from dataclasses import dataclass, field
|
||||
from typing import Generic
|
||||
|
||||
import torch
|
||||
import torch_npu
|
||||
@@ -31,25 +31,18 @@ from vllm.distributed.parallel_state import get_ep_group
|
||||
from vllm_ascend.device.device_op import DeviceOperator
|
||||
from vllm_ascend.distributed.parallel_state import get_mc2_group
|
||||
from vllm_ascend.ops.fused_moe.comm_utils import async_all_to_all, gather_from_sequence_parallel_region
|
||||
from vllm_ascend.ops.fused_moe.moe_runtime_args import (
|
||||
MoEAllGatherCombineMetadata,
|
||||
MoEAllToAllCombineMetadata,
|
||||
MoEMC2CombineMetadata,
|
||||
MoETokenDispatchInput,
|
||||
MoETokenDispatchOutput,
|
||||
TMoECombineMetadata,
|
||||
)
|
||||
from vllm_ascend.utils import AscendDeviceType, get_ascend_device_type, is_hierarchical_communication_enabled
|
||||
|
||||
|
||||
@dataclass
|
||||
class TokenDispatchResult:
|
||||
hidden_states: torch.Tensor
|
||||
group_list: torch.Tensor
|
||||
group_list_type: int
|
||||
dynamic_scale: torch.Tensor | None = field(default=None)
|
||||
topk_scales: torch.Tensor | None = field(default=None)
|
||||
context_metadata: dict = field(default_factory=dict)
|
||||
|
||||
|
||||
@dataclass
|
||||
class TokenCombineResult:
|
||||
routed_out: torch.Tensor
|
||||
|
||||
|
||||
class MoETokenDispatcher(ABC):
|
||||
class MoETokenDispatcher(ABC, Generic[TMoECombineMetadata]):
|
||||
def __init__(self, **kwargs) -> None:
|
||||
"""
|
||||
Initialize the MoE Token Dispatcher.
|
||||
@@ -73,27 +66,21 @@ class MoETokenDispatcher(ABC):
|
||||
@abstractmethod
|
||||
def token_dispatch(
|
||||
self,
|
||||
hidden_states: torch.Tensor,
|
||||
topk_weights: torch.Tensor,
|
||||
topk_ids: torch.Tensor,
|
||||
expert_map: torch.Tensor | None = None,
|
||||
global_redundant_expert_num: int = 0,
|
||||
mc2_mask: torch.Tensor | None = None,
|
||||
apply_router_weight_on_input: bool = False,
|
||||
with_quant: bool = False,
|
||||
dynamic_eplb: bool = False,
|
||||
pertoken_scale: torch.Tensor | None = None,
|
||||
) -> TokenDispatchResult:
|
||||
token_dispatch_input: MoETokenDispatchInput,
|
||||
) -> MoETokenDispatchOutput[TMoECombineMetadata]:
|
||||
raise NotImplementedError("Dispatch function not implemented.")
|
||||
|
||||
@abstractmethod
|
||||
def token_combine(
|
||||
self, hidden_states: torch.Tensor, context_metadata: dict, bias: torch.Tensor | None = None
|
||||
) -> TokenCombineResult:
|
||||
self,
|
||||
hidden_states: torch.Tensor,
|
||||
combine_metadata: TMoECombineMetadata,
|
||||
bias: torch.Tensor | None = None,
|
||||
) -> torch.Tensor:
|
||||
raise NotImplementedError("Combine function not implemented.")
|
||||
|
||||
|
||||
class TokenDispatcherWithMC2(MoETokenDispatcher):
|
||||
class TokenDispatcherWithMC2(MoETokenDispatcher[MoEMC2CombineMetadata]):
|
||||
def __init__(self, **kwargs):
|
||||
super().__init__(**kwargs)
|
||||
device_group = get_mc2_group().device_group
|
||||
@@ -110,7 +97,6 @@ class TokenDispatcherWithMC2(MoETokenDispatcher):
|
||||
# HCCL_INTRA_ROCE_ENABLE=0 can reduce cross-machine communication traffic and significantly
|
||||
# improve communication performance.
|
||||
self.need_expert_scale = is_hierarchical_communication_enabled()
|
||||
self.with_quant = False
|
||||
|
||||
# Here we need to calculate the global_bs = max_bs_per_rank * ep_world_size to execute
|
||||
# dispatch & combine operators with different input num_tokens per rank.
|
||||
@@ -131,25 +117,23 @@ class TokenDispatcherWithMC2(MoETokenDispatcher):
|
||||
|
||||
def get_dispatch_mc2_kwargs(
|
||||
self,
|
||||
hidden_states: torch.Tensor,
|
||||
topk_weights: torch.Tensor,
|
||||
topk_ids: torch.Tensor,
|
||||
expert_map: torch.Tensor,
|
||||
mc2_mask: torch.Tensor,
|
||||
global_redundant_expert_num: int = 0,
|
||||
**kwargs,
|
||||
token_dispatch_input: MoETokenDispatchInput,
|
||||
):
|
||||
use_mxfp_quant = kwargs.get("use_mxfp_quant", False)
|
||||
comm_quant_mode = kwargs.get("comm_quant_mode")
|
||||
hidden_states = token_dispatch_input.hidden_states
|
||||
topk_weights = token_dispatch_input.topk_weights
|
||||
topk_ids = token_dispatch_input.topk_ids
|
||||
expert_map = token_dispatch_input.routing.expert_map
|
||||
global_redundant_expert_num = token_dispatch_input.routing.global_redundant_expert_num
|
||||
comm_quant_mode = token_dispatch_input.quant.comm_quant_mode
|
||||
|
||||
assert expert_map is not None, "expert_map is required for MC2 token dispatch."
|
||||
# NOTE: quant_mode differs by quant feature:
|
||||
# - Legacy int communication quantization uses quant_mode=2.
|
||||
# - A5 MXFP8 communication uses quant_mode=4.
|
||||
# TODO(linfeng): The quantization-related parameters need to be consolidated into a single
|
||||
# dataclass, and the FP8 MoE code path should be integrated into it going forward.
|
||||
if comm_quant_mode is not None:
|
||||
quant_mode = comm_quant_mode
|
||||
elif self.with_quant:
|
||||
quant_mode = 4 if self.a5_need_extra_args and use_mxfp_quant else 2
|
||||
elif token_dispatch_input.quant.dispatch_with_quant:
|
||||
quant_mode = 4 if self.a5_need_extra_args and token_dispatch_input.quant.is_mxfp else 2
|
||||
else:
|
||||
quant_mode = 0
|
||||
self.moe_expert_num = len(expert_map) + global_redundant_expert_num
|
||||
@@ -178,10 +162,13 @@ class TokenDispatcherWithMC2(MoETokenDispatcher):
|
||||
"tp_rank_id": 0,
|
||||
}
|
||||
)
|
||||
if self.a5_need_extra_args and use_mxfp_quant:
|
||||
y_dtype = kwargs.get("y_dtype")
|
||||
if self.with_quant:
|
||||
y_dtype = torch.float8_e4m3fn if y_dtype is None else y_dtype
|
||||
if self.a5_need_extra_args and token_dispatch_input.quant.is_mxfp:
|
||||
y_dtype = torch.float8_e4m3fn
|
||||
if (
|
||||
token_dispatch_input.quant.mxfp is not None
|
||||
and token_dispatch_input.quant.mxfp.act_quant_type is not None
|
||||
):
|
||||
y_dtype = token_dispatch_input.quant.mxfp.act_quant_type
|
||||
stage1_kwargs.update({"tp_world_size": 1, "tp_rank_id": 0, "y_dtype": y_dtype})
|
||||
if self.need_expert_scale or self.a5_need_extra_args:
|
||||
stage1_kwargs.update(
|
||||
@@ -195,22 +182,9 @@ class TokenDispatcherWithMC2(MoETokenDispatcher):
|
||||
|
||||
def token_dispatch(
|
||||
self,
|
||||
hidden_states: torch.Tensor,
|
||||
topk_weights: torch.Tensor,
|
||||
topk_ids: torch.Tensor,
|
||||
expert_map: torch.Tensor | None = None,
|
||||
global_redundant_expert_num: int = 0,
|
||||
mc2_mask: torch.Tensor | None = None,
|
||||
apply_router_weight_on_input: bool = False,
|
||||
with_quant: bool = False,
|
||||
dynamic_eplb: bool = False,
|
||||
pertoken_scale: torch.Tensor | None = None,
|
||||
**kwargs,
|
||||
token_dispatch_input: MoETokenDispatchInput,
|
||||
):
|
||||
self.with_quant = with_quant
|
||||
kwargs_mc2 = self.get_dispatch_mc2_kwargs(
|
||||
hidden_states, topk_weights, topk_ids, expert_map, mc2_mask, global_redundant_expert_num, **kwargs
|
||||
)
|
||||
kwargs_mc2 = self.get_dispatch_mc2_kwargs(token_dispatch_input)
|
||||
output = (
|
||||
torch_npu.npu_moe_distribute_dispatch_v2(**kwargs_mc2)
|
||||
if self.enable_dispatch_v2
|
||||
@@ -227,33 +201,32 @@ class TokenDispatcherWithMC2(MoETokenDispatcher):
|
||||
expand_scales,
|
||||
) = output[0:7]
|
||||
|
||||
context_metadata = {
|
||||
"topk_ids": topk_ids,
|
||||
"topk_weights": topk_weights,
|
||||
"expert_map": expert_map,
|
||||
"ep_recv_counts": ep_recv_counts,
|
||||
"tp_recv_counts": tp_recv_counts,
|
||||
"assist_info_for_combine": assist_info_for_combine,
|
||||
"expand_scales": expand_scales,
|
||||
}
|
||||
|
||||
group_list_type = 0
|
||||
return TokenDispatchResult(
|
||||
return MoETokenDispatchOutput(
|
||||
hidden_states=expand_x,
|
||||
dynamic_scale=dynamic_scale,
|
||||
group_list=expert_token_nums,
|
||||
group_list_type=group_list_type,
|
||||
context_metadata=context_metadata,
|
||||
combine_metadata=MoEMC2CombineMetadata(
|
||||
topk_ids=token_dispatch_input.topk_ids,
|
||||
topk_weights=token_dispatch_input.topk_weights,
|
||||
expert_map=token_dispatch_input.routing.expert_map,
|
||||
ep_recv_counts=ep_recv_counts,
|
||||
tp_recv_counts=tp_recv_counts,
|
||||
assist_info_for_combine=assist_info_for_combine,
|
||||
expand_scales=expand_scales,
|
||||
dispatch_with_quant=token_dispatch_input.quant.dispatch_with_quant,
|
||||
),
|
||||
)
|
||||
|
||||
def get_combine_mc_kwargs(self, hidden_states: torch.Tensor, context_metadata: dict):
|
||||
expert_map = context_metadata["expert_map"]
|
||||
topk_ids = context_metadata["topk_ids"]
|
||||
topk_weights = context_metadata["topk_weights"]
|
||||
ep_recv_counts = context_metadata["ep_recv_counts"]
|
||||
tp_recv_counts = context_metadata["tp_recv_counts"]
|
||||
assist_info_for_combine = context_metadata["assist_info_for_combine"]
|
||||
expand_scales = context_metadata["expand_scales"]
|
||||
def get_combine_mc_kwargs(self, hidden_states: torch.Tensor, combine_metadata: MoEMC2CombineMetadata):
|
||||
expert_map = combine_metadata.expert_map
|
||||
topk_ids = combine_metadata.topk_ids
|
||||
topk_weights = combine_metadata.topk_weights
|
||||
ep_recv_counts = combine_metadata.ep_recv_counts
|
||||
tp_recv_counts = combine_metadata.tp_recv_counts
|
||||
assist_info_for_combine = combine_metadata.assist_info_for_combine
|
||||
expand_scales = combine_metadata.expand_scales
|
||||
|
||||
assert expert_map is not None
|
||||
|
||||
@@ -267,7 +240,7 @@ class TokenDispatcherWithMC2(MoETokenDispatcher):
|
||||
"global_bs": self.global_bs,
|
||||
}
|
||||
|
||||
if self.with_quant:
|
||||
if combine_metadata.dispatch_with_quant:
|
||||
tp_recv_counts = torch.empty(1, dtype=torch.int32, device=hidden_states.device)
|
||||
|
||||
stage3_kwargs = {
|
||||
@@ -296,52 +269,44 @@ class TokenDispatcherWithMC2(MoETokenDispatcher):
|
||||
kwargs_mc2.update(stage3_kwargs)
|
||||
return kwargs_mc2
|
||||
|
||||
def token_combine(self, hidden_states, context_metadata, bias=None):
|
||||
def token_combine(self, hidden_states, combine_metadata, bias=None):
|
||||
assert bias is None, "Bias is not supported in MoEAlltoAllvTokenDispatcher."
|
||||
|
||||
kwargs_mc2 = self.get_combine_mc_kwargs(hidden_states, context_metadata)
|
||||
kwargs_mc2 = self.get_combine_mc_kwargs(hidden_states, combine_metadata)
|
||||
combined_output = (
|
||||
torch_npu.npu_moe_distribute_combine_v2(**kwargs_mc2)
|
||||
if self.enable_dispatch_v2
|
||||
else torch_npu.npu_moe_distribute_combine(**kwargs_mc2)
|
||||
)
|
||||
|
||||
return TokenCombineResult(
|
||||
routed_out=combined_output,
|
||||
)
|
||||
return combined_output
|
||||
|
||||
|
||||
class TokenDispatcherWithAllGather(MoETokenDispatcher):
|
||||
class TokenDispatcherWithAllGather(MoETokenDispatcher[MoEAllGatherCombineMetadata]):
|
||||
def __init__(self, **kwargs):
|
||||
super().__init__(**kwargs)
|
||||
self.apply_router_weight_on_input = False
|
||||
self.max_num_tokens = kwargs.get("max_num_tokens")
|
||||
num_experts_local = kwargs.get("num_local_experts", 0)
|
||||
self.num_experts_local = (
|
||||
num_experts_local.item() if torch.is_tensor(num_experts_local) else int(num_experts_local)
|
||||
)
|
||||
self.original_shape = None
|
||||
self.with_quant = False
|
||||
|
||||
def token_dispatch(
|
||||
self,
|
||||
hidden_states: torch.Tensor,
|
||||
topk_weights: torch.Tensor,
|
||||
topk_ids: torch.Tensor,
|
||||
expert_map: torch.Tensor | None = None,
|
||||
global_redundant_expert_num: int = 0,
|
||||
mc2_mask: torch.Tensor | None = None,
|
||||
apply_router_weight_on_input: bool = False,
|
||||
with_quant: bool = False,
|
||||
dynamic_eplb: bool = False,
|
||||
pertoken_scale: torch.Tensor | None = None,
|
||||
token_dispatch_input: MoETokenDispatchInput,
|
||||
):
|
||||
self.with_quant = with_quant
|
||||
self.original_shape = hidden_states.shape
|
||||
with_quant = token_dispatch_input.quant.is_int_quant
|
||||
hidden_states = token_dispatch_input.hidden_states
|
||||
topk_weights = token_dispatch_input.topk_weights
|
||||
topk_ids = token_dispatch_input.topk_ids
|
||||
expert_map = token_dispatch_input.routing.expert_map
|
||||
pertoken_scale = token_dispatch_input.routing.pertoken_scale
|
||||
global_redundant_expert_num = token_dispatch_input.routing.global_redundant_expert_num
|
||||
restore_shape = hidden_states.shape
|
||||
|
||||
num_tokens = hidden_states.shape[:-1].numel()
|
||||
self.apply_router_weight_on_input = apply_router_weight_on_input
|
||||
if self.apply_router_weight_on_input:
|
||||
apply_router_weight_on_input = token_dispatch_input.routing.apply_router_weight_on_input
|
||||
if apply_router_weight_on_input:
|
||||
assert topk_weights.dim() == 2, "`topk_weights` should be in shape (num_tokens, topk)"
|
||||
_, topk = topk_weights.shape
|
||||
assert topk == 1, "Only support topk=1 when `apply_router_weight_on_input` is True"
|
||||
@@ -365,35 +330,37 @@ class TokenDispatcherWithAllGather(MoETokenDispatcher):
|
||||
expert_tokens_num_type=1,
|
||||
expert_tokens_num_flag=True,
|
||||
active_expert_range=[first_expert_idx, last_expert_idx],
|
||||
quant_mode=1 if self.with_quant and pertoken_scale is None else -1,
|
||||
quant_mode=1 if with_quant and pertoken_scale is None else -1,
|
||||
)
|
||||
expert_tokens = expert_tokens.to(torch.int64)
|
||||
group_list_type = 1 # `count` mode
|
||||
context_metadata = {"topk_weights": topk_weights, "expanded_row_idx": expanded_row_idx}
|
||||
|
||||
return TokenDispatchResult(
|
||||
return MoETokenDispatchOutput(
|
||||
hidden_states=sorted_hidden_states,
|
||||
dynamic_scale=pertoken_scale if self.with_quant else None,
|
||||
dynamic_scale=pertoken_scale if with_quant else None,
|
||||
group_list=expert_tokens,
|
||||
group_list_type=group_list_type,
|
||||
context_metadata=context_metadata,
|
||||
combine_metadata=MoEAllGatherCombineMetadata(
|
||||
topk_weights=topk_weights,
|
||||
expanded_row_idx=expanded_row_idx,
|
||||
restore_shape=restore_shape,
|
||||
),
|
||||
)
|
||||
|
||||
def token_combine(self, hidden_states, context_metadata, bias=None):
|
||||
assert self.original_shape is not None
|
||||
def token_combine(self, hidden_states, combine_metadata, bias=None):
|
||||
final_hidden_states = torch_npu.npu_moe_token_unpermute(
|
||||
permuted_tokens=hidden_states,
|
||||
sorted_indices=torch.abs(context_metadata["expanded_row_idx"]),
|
||||
probs=context_metadata["topk_weights"],
|
||||
sorted_indices=torch.abs(combine_metadata.expanded_row_idx),
|
||||
probs=combine_metadata.topk_weights,
|
||||
)
|
||||
if len(self.original_shape) == 3:
|
||||
final_hidden_states = final_hidden_states.view(self.original_shape)
|
||||
if len(combine_metadata.restore_shape) == 3:
|
||||
final_hidden_states = final_hidden_states.view(combine_metadata.restore_shape)
|
||||
|
||||
# these values are no longer used, so they need to be set to None for memory release.
|
||||
return TokenCombineResult(routed_out=final_hidden_states)
|
||||
return final_hidden_states
|
||||
|
||||
|
||||
class TokenDispatcherWithAll2AllV(MoETokenDispatcher):
|
||||
class TokenDispatcherWithAll2AllV(MoETokenDispatcher[MoEAllToAllCombineMetadata]):
|
||||
"""
|
||||
The implementation of the AlltoAll-based token dispatcher, which handles token
|
||||
dispatching on the sequence level instead of token level. The core of this implementation
|
||||
@@ -402,12 +369,8 @@ class TokenDispatcherWithAll2AllV(MoETokenDispatcher):
|
||||
|
||||
def __init__(self, **kwargs):
|
||||
super().__init__(**kwargs)
|
||||
self.with_quant = False
|
||||
self.num_local_experts = kwargs.get("num_local_experts", 0)
|
||||
|
||||
self.hidden_shape = None
|
||||
self.hidden_shape_before_permute = None
|
||||
|
||||
assert self.num_local_experts > 0, "Expected at least one expert"
|
||||
if self.num_local_experts > 1:
|
||||
self.expert_ids_per_ep_rank = torch.tensor(
|
||||
@@ -432,19 +395,12 @@ class TokenDispatcherWithAll2AllV(MoETokenDispatcher):
|
||||
|
||||
def token_dispatch(
|
||||
self,
|
||||
hidden_states: torch.Tensor,
|
||||
topk_weights: torch.Tensor,
|
||||
topk_ids: torch.Tensor,
|
||||
expert_map: torch.Tensor | None = None,
|
||||
global_redundant_expert_num: int = 0,
|
||||
mc2_mask: torch.Tensor | None = None,
|
||||
apply_router_weight_on_input: bool = False,
|
||||
with_quant: bool = False,
|
||||
dynamic_eplb: bool = False,
|
||||
pertoken_scale: torch.Tensor | None = None,
|
||||
token_dispatch_input: MoETokenDispatchInput,
|
||||
):
|
||||
self.with_quant = with_quant
|
||||
self.hidden_shape = hidden_states.shape
|
||||
with_quant = token_dispatch_input.quant.is_int_quant
|
||||
hidden_states = token_dispatch_input.hidden_states
|
||||
topk_weights = token_dispatch_input.topk_weights
|
||||
topk_ids = token_dispatch_input.topk_ids
|
||||
|
||||
(
|
||||
permutated_local_input_tokens,
|
||||
@@ -452,12 +408,13 @@ class TokenDispatcherWithAll2AllV(MoETokenDispatcher):
|
||||
tokens_per_expert,
|
||||
input_splits,
|
||||
output_splits,
|
||||
num_global_tokens_per_local_expert,
|
||||
global_input_tokens_local_experts_indices,
|
||||
hidden_shape,
|
||||
hidden_shape_before_permute,
|
||||
) = self._dispatch_preprocess(hidden_states, topk_ids)
|
||||
|
||||
dynamic_scale_after_all2all = None
|
||||
if self.with_quant:
|
||||
if with_quant:
|
||||
permutated_local_input_tokens, dynamic_scale = torch_npu.npu_dynamic_quant(permutated_local_input_tokens)
|
||||
_, dynamic_scale_after_all2all, permute2_ep_all_to_all_handle = async_all_to_all(
|
||||
dynamic_scale, output_splits, input_splits, self.ep_group
|
||||
@@ -474,64 +431,66 @@ class TokenDispatcherWithAll2AllV(MoETokenDispatcher):
|
||||
# Postprocess
|
||||
global_input_tokens, dynamic_scale_final, reversed_global_input_permutation_mapping = (
|
||||
self._dispatch_postprocess(
|
||||
global_input_tokens, dynamic_scale_after_all2all, global_input_tokens_local_experts_indices
|
||||
global_input_tokens,
|
||||
dynamic_scale_after_all2all,
|
||||
global_input_tokens_local_experts_indices,
|
||||
with_quant,
|
||||
)
|
||||
)
|
||||
|
||||
context_metadata = {
|
||||
"input_splits": input_splits,
|
||||
"output_splits": output_splits,
|
||||
"topk_weights": topk_weights,
|
||||
"reversed_local_input_permutation_mapping": reversed_local_input_permutation_mapping,
|
||||
"reversed_global_input_permutation_mapping": reversed_global_input_permutation_mapping,
|
||||
}
|
||||
|
||||
return TokenDispatchResult(
|
||||
return MoETokenDispatchOutput(
|
||||
hidden_states=global_input_tokens,
|
||||
dynamic_scale=dynamic_scale_final,
|
||||
group_list=tokens_per_expert,
|
||||
group_list_type=1,
|
||||
context_metadata=context_metadata,
|
||||
combine_metadata=MoEAllToAllCombineMetadata(
|
||||
input_splits=input_splits,
|
||||
output_splits=output_splits,
|
||||
topk_weights=topk_weights,
|
||||
reversed_local_input_permutation_mapping=reversed_local_input_permutation_mapping,
|
||||
reversed_global_input_permutation_mapping=reversed_global_input_permutation_mapping,
|
||||
hidden_shape=hidden_shape,
|
||||
hidden_shape_before_permute=hidden_shape_before_permute,
|
||||
),
|
||||
)
|
||||
|
||||
def token_combine(self, hidden_states, context_metadata, bias=None):
|
||||
def token_combine(self, hidden_states, combine_metadata, bias=None):
|
||||
assert bias is None, "Bias is not supported in MoEAlltoAllvTokenDispatcher."
|
||||
|
||||
# 1. Preprocess using metadata
|
||||
hidden_states = self._combine_preprocess(hidden_states, context_metadata)
|
||||
hidden_states = self._combine_preprocess(hidden_states, combine_metadata)
|
||||
|
||||
# 2. AllToAll
|
||||
_, permutated_local_input_tokens, handle = async_all_to_all(
|
||||
hidden_states,
|
||||
context_metadata["input_splits"],
|
||||
context_metadata["output_splits"],
|
||||
combine_metadata.input_splits,
|
||||
combine_metadata.output_splits,
|
||||
self.ep_group,
|
||||
)
|
||||
handle.wait()
|
||||
hidden_states.untyped_storage().resize_(0)
|
||||
|
||||
# 3. Postprocess using metadata
|
||||
output = self._combine_postprocess(permutated_local_input_tokens, context_metadata)
|
||||
output = self._combine_postprocess(permutated_local_input_tokens, combine_metadata)
|
||||
|
||||
return TokenCombineResult(routed_out=output)
|
||||
return output
|
||||
|
||||
def _dispatch_preprocess(self, hidden_states, topk_ids):
|
||||
assert self.hidden_shape is not None
|
||||
hidden_shape = hidden_states.shape
|
||||
hidden_states = hidden_states.view(-1, hidden_states.size(-1))
|
||||
(
|
||||
tokens_per_expert,
|
||||
input_splits,
|
||||
output_splits,
|
||||
num_global_tokens_per_local_expert,
|
||||
global_input_tokens_local_experts_indices,
|
||||
num_out_tokens,
|
||||
) = self._preprocess(topk_ids)
|
||||
|
||||
self.hidden_shape_before_permute = hidden_states.shape
|
||||
hidden_shape_before_permute = hidden_states.shape
|
||||
|
||||
permutated_local_input_tokens, reversed_local_input_permutation_mapping = torch_npu.npu_moe_token_permute(
|
||||
tokens=hidden_states,
|
||||
indices=topk_ids,
|
||||
num_out_tokens=self.num_out_tokens,
|
||||
num_out_tokens=num_out_tokens,
|
||||
)
|
||||
|
||||
return (
|
||||
@@ -540,15 +499,16 @@ class TokenDispatcherWithAll2AllV(MoETokenDispatcher):
|
||||
tokens_per_expert,
|
||||
input_splits,
|
||||
output_splits,
|
||||
num_global_tokens_per_local_expert,
|
||||
global_input_tokens_local_experts_indices,
|
||||
hidden_shape,
|
||||
hidden_shape_before_permute,
|
||||
)
|
||||
|
||||
def _preprocess(self, topk_ids: torch.Tensor):
|
||||
num_local_tokens_per_expert = torch.histc(topk_ids, bins=self.num_experts, min=0, max=self.num_experts)
|
||||
|
||||
ep_size = self.ep_size
|
||||
self.num_out_tokens = topk_ids.numel()
|
||||
num_out_tokens = topk_ids.numel()
|
||||
|
||||
input_splits = (
|
||||
num_local_tokens_per_expert.reshape(ep_size, self.num_local_experts)
|
||||
@@ -585,19 +545,19 @@ class TokenDispatcherWithAll2AllV(MoETokenDispatcher):
|
||||
num_tokens_per_local_expert,
|
||||
input_splits,
|
||||
output_splits,
|
||||
num_global_tokens_per_local_expert,
|
||||
global_input_tokens_local_experts_indices,
|
||||
num_out_tokens,
|
||||
)
|
||||
|
||||
def _dispatch_postprocess(
|
||||
self, global_input_tokens, dynamic_scale_after_all2all, global_input_tokens_local_experts_indices
|
||||
self, global_input_tokens, dynamic_scale_after_all2all, global_input_tokens_local_experts_indices, with_quant
|
||||
):
|
||||
# Early return if no local experts or no tokens
|
||||
if self.num_local_experts <= 1:
|
||||
return global_input_tokens, dynamic_scale_after_all2all, None
|
||||
|
||||
# Handle quantized case
|
||||
if self.with_quant:
|
||||
if with_quant:
|
||||
assert global_input_tokens_local_experts_indices is not None, (
|
||||
"global_input_tokens_local_experts_indices must be provided"
|
||||
)
|
||||
@@ -612,20 +572,26 @@ class TokenDispatcherWithAll2AllV(MoETokenDispatcher):
|
||||
)
|
||||
return global_input_tokens, dynamic_scale_after_all2all, reversed_global_input_permutation_mapping
|
||||
|
||||
def _combine_preprocess(self, hidden_states: torch.Tensor, context_metadata: dict) -> torch.Tensor:
|
||||
def _combine_preprocess(
|
||||
self, hidden_states: torch.Tensor, combine_metadata: MoEAllToAllCombineMetadata
|
||||
) -> torch.Tensor:
|
||||
# Unpermutation 2: expert output to AlltoAll input
|
||||
if hidden_states.shape[0] > 0 and self.num_local_experts > 1:
|
||||
rev_global = context_metadata["reversed_global_input_permutation_mapping"]
|
||||
rev_global = combine_metadata.reversed_global_input_permutation_mapping
|
||||
if hidden_states.shape[0] > 0 and self.num_local_experts > 1 and rev_global is not None:
|
||||
hidden_states = torch_npu.npu_moe_token_unpermute(hidden_states, rev_global)
|
||||
return hidden_states
|
||||
|
||||
def _combine_postprocess(self, permutated_local_input_tokens: torch.Tensor, context_metadata: dict) -> torch.Tensor:
|
||||
def _combine_postprocess(
|
||||
self,
|
||||
permutated_local_input_tokens: torch.Tensor,
|
||||
combine_metadata: MoEAllToAllCombineMetadata,
|
||||
) -> torch.Tensor:
|
||||
# Unpermutation 1: AlltoAll output to output
|
||||
output = torch_npu.npu_moe_token_unpermute(
|
||||
permuted_tokens=permutated_local_input_tokens,
|
||||
sorted_indices=context_metadata["reversed_local_input_permutation_mapping"].to(torch.int32),
|
||||
probs=context_metadata["topk_weights"],
|
||||
restore_shape=self.hidden_shape_before_permute,
|
||||
sorted_indices=combine_metadata.reversed_local_input_permutation_mapping.to(torch.int32),
|
||||
probs=combine_metadata.topk_weights,
|
||||
restore_shape=combine_metadata.hidden_shape_before_permute,
|
||||
)
|
||||
output = output.view(self.hidden_shape)
|
||||
output = output.view(combine_metadata.hidden_shape)
|
||||
return output
|
||||
|
||||
Reference in New Issue
Block a user