Files
xc-llm-ascend/vllm_ascend/ops/moe/token_dispatcher.py
weichen 2f1b9a7a64 Reapply "[MoE] [Refactor] Remove manual memory cleanup (#3365)" (#3483) (#3512)
### What this PR does / why we need it?
1. Replace manual memory cleanup by passing the required state as parameters.
2. Make FusedMoEPrepareAndFinalizeWithMC2 inherit from the All2All implementation to avoid duplicated
code.
3. Fix MC2 bug introduced in
https://github.com/vllm-project/vllm-ascend/pull/3365
4. Unify the aclgraph and eager code paths in W8A8_dynamic.
### Does this PR introduce _any_ user-facing change?
No
### How was this patch tested?
e2e & ut

- vLLM version: v0.11.0rc3
- vLLM main: https://github.com/vllm-project/vllm/commit/v0.11.0

Signed-off-by: Pr0Wh1teGivee <calvin_zhu0210@outlook.com>
2025-10-22 11:41:30 +08:00

742 lines
30 KiB
Python

# SPDX-License-Identifier: Apache-2.0
# Copyright (c) 2024; NVIDIA CORPORATION. All rights reserved.
# Copyright (c) 2025 Huawei Technologies Co., Ltd. All Rights Reserved.
# Copyright 2023 The vLLM team.
# Copyright 2023 DeepSeek-AI and the HuggingFace Inc. team. All rights reserved.
#
# This code is based on EleutherAI's GPT-NeoX library and the GPT-NeoX
# and OPT implementations in this library. It has been modified from its
# original forms to accommodate minor architectural differences compared
# to GPT-NeoX and OPT used by the Meta AI team that trained the model.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from abc import ABC, abstractmethod
from typing import Any, Optional
import torch
import torch_npu
from vllm.distributed.parallel_state import get_ep_group
from vllm_ascend.distributed.parallel_state import get_mc2_group
from vllm_ascend.ops.moe.comm_utils import (
async_all_to_all, gather_from_sequence_parallel_region)
from vllm_ascend.utils import (AscendSocVersion, get_ascend_soc_version,
is_hierarchical_communication_enabled)
class MoETokenDispatcher(ABC):
    """Abstract interface for MoE token dispatchers.

    ``token_dispatch`` sends tokens to the experts chosen by the router and
    ``token_combine`` brings the expert outputs back. ``token_combine``
    receives a ``context_metadata`` dict; concrete dispatchers are expected
    to produce it during ``token_dispatch``.
    """

    def __init__(self, **kwargs) -> None:
        """
        Initialize the MoE Token Dispatcher.

        Keyword arguments read here (others are ignored by the base class and
        may be consumed by subclasses):
            top_k: number of experts selected per token (default 0).
            num_experts: total number of experts (default 0).
        """
        self.top_k = kwargs.get("top_k", 0)
        self.num_experts = kwargs.get("num_experts", 0)

    @property
    def ep_group(self):
        """Get expert model parallel group."""
        return get_ep_group().device_group

    @property
    def ep_rank(self):
        """Rank of this process within the expert-parallel group."""
        return get_ep_group().rank_in_group

    @property
    def ep_size(self):
        """Number of ranks in the expert-parallel group."""
        return get_ep_group().world_size

    @abstractmethod
    def token_dispatch(self,
                       hidden_states: torch.Tensor,
                       topk_weights: torch.Tensor,
                       topk_ids: torch.Tensor,
                       expert_map: Optional[torch.Tensor] = None,
                       log2phy: Optional[torch.Tensor] = None,
                       global_redundant_expert_num: int = 0,
                       shared_experts: Optional[Any] = None,
                       quantized_x_for_share: Optional[Any] = None,
                       dynamic_scale_for_share: Optional[Any] = None,
                       mc2_mask: Optional[torch.Tensor] = None,
                       apply_router_weight_on_input: bool = False,
                       with_quant: bool = False):
        """Send tokens to their selected experts.

        Args:
            hidden_states: Input token activations.
            topk_weights: Routing weights of the selected experts.
            topk_ids: Ids of the selected experts.
            expert_map: Optional expert id mapping (implementation-defined).
            log2phy: Optional logical-to-physical expert id mapping.
            global_redundant_expert_num: Number of redundant experts.
            shared_experts: Optional shared-expert module.
            quantized_x_for_share: Quantized input for the shared experts.
            dynamic_scale_for_share: Quantization scale for the shared experts.
            mc2_mask: Optional mask used by MC2-based implementations.
            apply_router_weight_on_input: Whether to scale the inputs by the
                routing weights before dispatch.
            with_quant: Whether the quantized path is used.

        Returns:
            An implementation-defined dict describing the dispatched tokens.
        """
        raise NotImplementedError("Dispatch function not implemented.")

    @abstractmethod
    def token_combine(self,
                      hidden_states: torch.Tensor,
                      context_metadata: dict,
                      bias: Optional[torch.Tensor] = None):
        """Gather expert outputs back into the original token order.

        Args:
            hidden_states: Expert outputs in dispatched order.
            context_metadata: State produced by the matching dispatch call.
            bias: Optional bias; support is implementation-defined.
        """
        raise NotImplementedError("Combine function not implemented.")
class TokenDispatcherWithMC2(MoETokenDispatcher):
    """Token dispatcher built on the fused MC2 dispatch/combine kernels.

    Dispatch and combine are each one ``torch_npu`` call
    (``npu_moe_distribute_dispatch[_v2]`` / ``npu_moe_distribute_combine[_v2]``)
    issued on the MC2 communication group. Per-call state needed by
    ``token_combine`` is returned in ``context_metadata``.
    """

    def __init__(self, **kwargs):
        super().__init__(**kwargs)
        mc2_group = get_mc2_group()
        device_group = mc2_group.device_group
        # TODO: Try local_rank = ep_group.rank_in_group
        local_rank = torch.distributed.get_rank(group=device_group)
        backend = device_group._get_backend(torch.device("npu"))
        self.moe_all_to_all_group_name = backend.get_hccl_comm_name(local_rank)
        self.ep_rank_id = mc2_group.rank_in_group
        self.ep_world_size = mc2_group.world_size
        self.enable_dispatch_v2 = hasattr(torch_npu,
                                          "npu_moe_distribute_dispatch_v2")
        # NOTE: Currently, when in A3, we need to pass in some extra param into
        # dispatch & combine. Both attributes hold the same value; they are
        # kept separately for backward compatibility.
        is_a3 = get_ascend_soc_version() == AscendSocVersion.A3
        self.need_extra_args = is_a3
        self.a3_need_extra_args = is_a3
        # NOTE: When in A2, setting the environment variables HCCL_INTRA_PCIE_ENABLE=1 and
        # HCCL_INTRA_ROCE_ENABLE=0 can reduce cross-machine communication traffic and significantly
        # improve communication performance.
        self.need_expert_scale = is_hierarchical_communication_enabled()
        # Quantization flag of the most recent dispatch. token_combine prefers
        # the value carried in context_metadata.
        self.with_quant = False

    def get_dispatch_mc2_kwargs(
        self,
        hidden_states: torch.Tensor,
        topk_weights: torch.Tensor,
        topk_ids: torch.Tensor,
        expert_map: torch.Tensor,
        mc2_mask: torch.Tensor,
        global_redundant_expert_num: int = 0,
    ):
        """Build the keyword arguments for the MC2 dispatch kernel.

        ``global_redundant_expert_num`` is accepted for interface
        compatibility; it does not influence the arguments built here.

        Raises:
            AssertionError: if ``expert_map`` is None (its length is the
                expert count passed to the kernel).
        """
        assert expert_map is not None, "expert_map is required for MC2 dispatch."
        # Kernel-defined quant codes: 2 on the quantized path, 0 otherwise.
        quant_mode = 2 if self.with_quant else 0
        moe_expert_num = len(expert_map)
        kwargs_mc2 = {
            "x": hidden_states,
            "expert_ids": topk_ids,
            "expert_shard_type": 0,
            "shared_expert_rank_num": 0,
            "moe_expert_num": moe_expert_num,
            "global_bs": 0,
            "expert_token_nums_type": 0,
        }
        stage1_kwargs = {
            "scales": None,
            "quant_mode": quant_mode,
            "group_ep": self.moe_all_to_all_group_name,
            "ep_world_size": self.ep_world_size,
            "ep_rank_id": self.ep_rank_id,
        }
        if self.need_extra_args:
            stage1_kwargs.update({
                "group_tp": self.moe_all_to_all_group_name,
                "tp_world_size": 1,
                "tp_rank_id": 0,
            })
        if self.a3_need_extra_args and self.enable_dispatch_v2:
            stage1_kwargs["x_active_mask"] = mc2_mask
        if self.need_expert_scale:
            stage1_kwargs["expert_scales"] = topk_weights.to(torch.float32)
        kwargs_mc2.update(stage1_kwargs)
        return kwargs_mc2

    def token_dispatch(
        self,
        hidden_states: torch.Tensor,
        topk_weights: torch.Tensor,
        topk_ids: torch.Tensor,
        expert_map: Optional[torch.Tensor] = None,
        log2phy: Optional[torch.Tensor] = None,
        global_redundant_expert_num: int = 0,
        shared_experts: Optional[Any] = None,
        quantized_x_for_share: Optional[Any] = None,
        dynamic_scale_for_share: Optional[Any] = None,
        mc2_mask: Optional[torch.Tensor] = None,
        apply_router_weight_on_input: bool = False,
        with_quant: bool = False,
    ):
        """Dispatch tokens with the MC2 kernel and run shared-expert up-proj.

        Returns:
            dict with ``group_list_type`` (0), ``hidden_states`` (kernel
            output 0), ``group_list`` (kernel output 3), ``dynamic_scale``
            (kernel output 1) and ``context_metadata`` for ``token_combine``.
        """
        self.with_quant = with_quant
        # Apply the logical-to-physical expert id mapping if one is given.
        if log2phy is not None:
            topk_ids = log2phy[topk_ids]
        kwargs_mc2 = self.get_dispatch_mc2_kwargs(hidden_states, topk_weights,
                                                  topk_ids, expert_map,
                                                  mc2_mask,
                                                  global_redundant_expert_num)
        if self.enable_dispatch_v2:
            output = torch_npu.npu_moe_distribute_dispatch_v2(**kwargs_mc2)
        else:
            output = torch_npu.npu_moe_distribute_dispatch(**kwargs_mc2)
        (expand_x, dynamic_scale, assist_info_for_combine, expert_token_nums,
         ep_recv_counts, tp_recv_counts, expand_scales) = output[0:7]

        # Shared-expert up-projection + activation; the down-projection runs
        # in token_combine. Results are kept in locals, not on self.
        shared_act = None
        swiglu_out_scale = None
        if shared_experts is not None:
            if with_quant:
                share_up_out, _ = shared_experts.gate_up_proj(
                    (quantized_x_for_share, dynamic_scale_for_share))
                shared_gate_up, shared_dequant_scale = share_up_out[
                    0], share_up_out[1]
                shared_act_out = shared_experts.act_fn(
                    (shared_gate_up, shared_dequant_scale))
                shared_act, swiglu_out_scale = shared_act_out[
                    0], shared_act_out[1]
            else:
                shared_gate_up, _ = shared_experts.gate_up_proj(hidden_states)
                shared_act = shared_experts.act_fn(shared_gate_up)

        context_metadata = {
            "topk_ids": topk_ids,
            "topk_weights": topk_weights,
            "mc2_mask": mc2_mask,
            "expert_map": expert_map,
            "ep_recv_counts": ep_recv_counts,
            "tp_recv_counts": tp_recv_counts,
            "assist_info_for_combine": assist_info_for_combine,
            "shared_experts": shared_experts,
            "shared_act": shared_act,
            "swiglu_out_scale": swiglu_out_scale,
            "expand_scales": expand_scales,
            # Carried so combine does not depend on instance state.
            "with_quant": with_quant,
        }
        return {
            "group_list_type": 0,
            "hidden_states": expand_x,
            "group_list": expert_token_nums,
            "dynamic_scale": dynamic_scale,
            "context_metadata": context_metadata
        }

    def get_combine_mc_kwargs(self, hidden_states: torch.Tensor,
                              context_metadata: dict):
        """Build the keyword arguments for the MC2 combine kernel.

        The quantization flag is read from ``context_metadata["with_quant"]``
        when present, otherwise from ``self.with_quant``.
        """
        expert_map = context_metadata["expert_map"]
        topk_ids = context_metadata["topk_ids"]
        topk_weights = context_metadata["topk_weights"]
        ep_recv_counts = context_metadata["ep_recv_counts"]
        tp_recv_counts = context_metadata["tp_recv_counts"]
        assist_info_for_combine = context_metadata["assist_info_for_combine"]
        mc2_mask = context_metadata["mc2_mask"]
        expand_scales = context_metadata["expand_scales"]
        with_quant = context_metadata.get("with_quant", self.with_quant)

        assert expert_map is not None
        moe_expert_num = len(expert_map)

        kwargs_mc2 = {
            "expand_x": hidden_states,
            "expert_ids": topk_ids,
            "expert_scales": topk_weights.to(torch.float32),
            "expert_shard_type": 0,
            "shared_expert_rank_num": 0,
            "moe_expert_num": moe_expert_num,
            "global_bs": 0,
        }
        stage3_kwargs = {
            "ep_send_counts": ep_recv_counts,
            "group_ep": self.moe_all_to_all_group_name,
            "ep_world_size": self.ep_world_size,
            "ep_rank_id": self.ep_rank_id,
            "expand_scales": expand_scales,
        }
        # v1 and v2 kernels name the same dispatch output differently.
        if self.enable_dispatch_v2:
            stage3_kwargs["assist_info_for_combine"] = assist_info_for_combine
        else:
            stage3_kwargs["expand_idx"] = assist_info_for_combine
        if self.need_extra_args:
            if with_quant:
                # The quantized path passes a 1-element int32 placeholder
                # instead of the dispatch-time tp_recv_counts.
                tp_send_counts = torch.empty(1,
                                             dtype=torch.int32,
                                             device=hidden_states.device)
            else:
                tp_send_counts = tp_recv_counts
            stage3_kwargs.update({
                "tp_send_counts": tp_send_counts,
                "group_tp": self.moe_all_to_all_group_name,
                "tp_world_size": 1,
                "tp_rank_id": 0,
            })
        if self.a3_need_extra_args and self.enable_dispatch_v2:
            stage3_kwargs["x_active_mask"] = mc2_mask
        kwargs_mc2.update(stage3_kwargs)
        return kwargs_mc2

    def token_combine(
        self,
        hidden_states: torch.Tensor,
        context_metadata: dict,
        bias: Optional[torch.Tensor] = None,
    ):
        """Combine expert outputs with the MC2 kernel.

        Returns:
            The combined tensor, or ``(combined, shared_hidden_states)`` when
            shared experts were passed to ``token_dispatch``.
        """
        assert bias is None, "Bias is not supported in TokenDispatcherWithMC2."
        kwargs_mc2 = self.get_combine_mc_kwargs(hidden_states,
                                                context_metadata)
        if self.enable_dispatch_v2:
            combined_output = torch_npu.npu_moe_distribute_combine_v2(
                **kwargs_mc2)
        else:
            combined_output = torch_npu.npu_moe_distribute_combine(
                **kwargs_mc2)

        # Finish the shared-expert branch started in token_dispatch.
        shared_experts = context_metadata["shared_experts"]
        if shared_experts is None:
            return combined_output
        shared_act = context_metadata["shared_act"]
        if context_metadata.get("with_quant", self.with_quant):
            swiglu_out_scale = context_metadata["swiglu_out_scale"]
            shared_hidden_states, _ = shared_experts.down_proj(
                (shared_act, swiglu_out_scale))
        else:
            shared_hidden_states, _ = shared_experts.down_proj(shared_act)
        return combined_output, shared_hidden_states
class TokenDispatcherWithAllGather(MoETokenDispatcher):
    """Token dispatcher that permutes tokens locally by expert.

    Dispatch uses ``npu_moe_init_routing_v2`` restricted to the expert range
    handled by this EP rank. Combine uses ``npu_moe_token_unpermute``. This
    class itself issues no collective communication.
    """

    def __init__(self, **kwargs):
        super().__init__(**kwargs)
        self.apply_router_weight_on_input = False
        self.max_num_tokens = kwargs.get("max_num_tokens")
        self.num_experts_local = kwargs.get("num_local_experts", 0)
        # Input shape of the latest dispatch. token_combine prefers the copy
        # carried in context_metadata.
        self.original_shape = None
        self.with_quant = False

    def token_dispatch(self,
                       hidden_states: torch.Tensor,
                       topk_weights: torch.Tensor,
                       topk_ids: torch.Tensor,
                       expert_map: Optional[torch.Tensor] = None,
                       log2phy: Optional[torch.Tensor] = None,
                       global_redundant_expert_num: int = 0,
                       shared_experts: Optional[Any] = None,
                       quantized_x_for_share: Optional[Any] = None,
                       dynamic_scale_for_share: Optional[Any] = None,
                       mc2_mask: Optional[torch.Tensor] = None,
                       apply_router_weight_on_input: bool = False,
                       with_quant: bool = False):
        """Sort tokens by expert id for the local grouped matmul.

        Returns:
            dict with ``group_list_type`` (1, count mode), ``hidden_states``,
            ``group_list`` (int64 per-expert token counts), ``dynamic_scale``
            (only when ``with_quant``) and ``context_metadata``.
        """
        self.with_quant = with_quant
        self.original_shape = hidden_states.shape

        num_tokens = hidden_states.shape[:-1].numel()
        self.apply_router_weight_on_input = apply_router_weight_on_input
        if self.apply_router_weight_on_input:
            assert (topk_weights.dim() == 2
                    ), "`topk_weights` should be in shape (num_tokens, topk)"
            _, topk = topk_weights.shape
            assert (
                topk == 1
            ), "Only support topk=1 when `apply_router_weight_on_input` is True"
            hidden_states = hidden_states * \
                topk_weights.to(hidden_states.dtype)

        if expert_map is not None:
            global_num_experts = len(expert_map)
            # Zero the weight of selected experts whose expert_map entry is -1
            # (presumably experts not hosted on this rank).
            mask = (expert_map[topk_ids] != -1)
            topk_weights = topk_weights * mask
            first_expert_idx = get_ep_group(
            ).rank_in_group * self.num_experts_local
            last_expert_idx = first_expert_idx + self.num_experts_local
        else:
            first_expert_idx = 0
            last_expert_idx = self.num_experts_local
            global_num_experts = self.num_experts_local

        sorted_hidden_states, expanded_row_idx, expert_tokens, pertoken_scale = (
            torch_npu.npu_moe_init_routing_v2(
                hidden_states,
                topk_ids,
                active_num=num_tokens * self.top_k,
                expert_num=global_num_experts,
                expert_tokens_num_type=1,
                expert_tokens_num_flag=True,
                active_expert_range=[first_expert_idx, last_expert_idx],
                quant_mode=1 if self.with_quant else -1,
            ))
        expert_tokens = expert_tokens.to(torch.int64)
        group_list_type = 1  # `count` mode
        context_metadata = {
            "topk_weights": topk_weights,
            "expanded_row_idx": expanded_row_idx,
            # Carried so combine does not depend on instance state.
            "original_shape": self.original_shape,
        }
        return {
            "group_list_type": group_list_type,
            "hidden_states": sorted_hidden_states,
            "group_list": expert_tokens,
            "dynamic_scale": pertoken_scale if self.with_quant else None,
            "context_metadata": context_metadata
        }

    def token_combine(self,
                      hidden_states: torch.Tensor,
                      context_metadata: dict,
                      bias: Optional[torch.Tensor] = None):
        """Undo the dispatch permutation, weighting rows by ``topk_weights``.

        The output is reshaped back to the dispatch input shape when that
        input was 3-D. ``bias`` is ignored.
        """
        original_shape = context_metadata.get("original_shape",
                                              self.original_shape)
        assert original_shape is not None
        final_hidden_states = torch_npu.npu_moe_token_unpermute(
            permuted_tokens=hidden_states,
            sorted_indices=torch.abs(context_metadata["expanded_row_idx"]),
            probs=context_metadata["topk_weights"])
        if len(original_shape) == 3:
            final_hidden_states = final_hidden_states.view(original_shape)
        return final_hidden_states
# mypy: disable-error-code="override"
class TokenDispatcherWithMoge(MoETokenDispatcher):
    """Token dispatcher that sorts flattened routing slots by expert id.

    Dispatch sorts ``topk_ids`` and gathers the matching token rows. Combine
    undoes the sort and sums the ``top_k // ep_size`` slots of every token.
    """

    def __init__(self, **kwargs):
        super().__init__(**kwargs)
        self.apply_router_weight_on_input = False
        self.local_num_experts = self.num_experts // self.ep_size
        # Routing slots per token; assumes topk_ids has this many columns.
        self.local_num_group = self.top_k // self.ep_size
        # State of the latest dispatch. token_combine prefers the copies
        # carried in context_metadata.
        self.bsz = None
        self.sorted_topk_ids = None

    def token_dispatch(self,
                       hidden_states: torch.Tensor,
                       topk_weights: torch.Tensor,
                       topk_ids: torch.Tensor,
                       expert_map: Optional[torch.Tensor] = None,
                       log2phy: Optional[torch.Tensor] = None,
                       global_redundant_expert_num: int = 0,
                       shared_experts: Optional[Any] = None,
                       quantized_x_for_share: Optional[Any] = None,
                       dynamic_scale_for_share: Optional[Any] = None,
                       mc2_mask: Optional[torch.Tensor] = None,
                       apply_router_weight_on_input: bool = False,
                       with_quant: bool = False):
        """Sort token rows by expert id.

        Returns:
            dict with ``group_list_type`` (0), ``hidden_states`` (sorted
            rows), ``group_list`` (cumulative per-expert counts, int64),
            ``topk_scales`` (sorted weights, column vector) and
            ``context_metadata`` for ``token_combine``.
        """
        bsz, _ = hidden_states.shape
        self.bsz = bsz
        flatten_topk_ids = topk_ids.view(-1)
        sorted_topk_ids = torch.argsort(flatten_topk_ids.float()).to(
            torch.int32)
        self.sorted_topk_ids = sorted_topk_ids
        # Slot i belongs to token i // local_num_group.
        sorted_hidden_states = hidden_states.index_select(
            0, sorted_topk_ids // self.local_num_group)

        experts_id = torch.arange(0,
                                  self.local_num_experts,
                                  dtype=topk_ids.dtype,
                                  device=topk_ids.device)
        num_tokens_per_expert = (
            flatten_topk_ids.unsqueeze(-1) == experts_id).to(
                torch.float32).sum(0)
        topk_scales = topk_weights.view(-1).index_select(
            0, sorted_topk_ids).unsqueeze(-1)
        group_list = num_tokens_per_expert.cumsum(dim=0).to(torch.int64)
        context_metadata = {
            "sorted_topk_ids": sorted_topk_ids,
            "bsz": bsz,
        }
        return {
            "group_list_type": 0,
            "hidden_states": sorted_hidden_states,
            "group_list": group_list,
            "topk_scales": topk_scales,
            "context_metadata": context_metadata,
        }

    def token_combine(self,
                      hidden_states: torch.Tensor,
                      context_metadata: dict,
                      bias: Optional[torch.Tensor] = None):
        """Restore slot order and sum each token's slots. ``bias`` is ignored.

        ``context_metadata`` may be None (for example, from callers written
        against the older dispatch output). In that case the state stored by
        the latest ``token_dispatch`` is used.
        """
        metadata = context_metadata or {}
        sorted_topk_ids = metadata.get("sorted_topk_ids", self.sorted_topk_ids)
        bsz = metadata.get("bsz", self.bsz)
        assert sorted_topk_ids is not None, \
            "token_dispatch must be called before token_combine."
        unsorted_topk_ids = torch.argsort(sorted_topk_ids.float()).to(
            torch.int32)
        unsorted_hidden_states = hidden_states.index_select(
            0, unsorted_topk_ids)
        final_hidden_states = unsorted_hidden_states.reshape(
            bsz, self.local_num_group, -1).sum(1)
        return final_hidden_states
class TokenDispatcherWithAll2AllV(MoETokenDispatcher):
    """
    The implementation of the AlltoAll-based token dispatcher, which handles token
    dispatching on the sequence level instead of token level. The core of this implementation
    lies in each device dispatching on the entire sequence, with the hidden state being partitioned.

    Dispatch flow:
    1. Permute tokens locally by expert id.
    2. Optionally apply dynamic quantization.
    3. Exchange tokens with an all-to-all over the EP group.
    4. Regroup the received tokens by local expert.

    Combine runs these steps in reverse. Shapes and permutation mappings
    needed by ``token_combine`` are returned in ``context_metadata``.
    """

    def __init__(self, **kwargs):
        super().__init__(**kwargs)
        self.with_quant = False
        self.num_local_experts = kwargs.get("num_local_experts", 0)
        # Shapes recorded by the latest dispatch. Combine prefers the copies
        # carried in context_metadata and falls back to these.
        self.hidden_shape = None
        self.hidden_shape_before_permute = None
        # Number of (token, expert) slots of the latest dispatch; set by
        # _preprocess and consumed by _dispatch_preprocess.
        self.num_out_tokens = None

        assert self.num_local_experts > 0, "Expected at least one expert"
        if self.num_local_experts > 1:
            # Local expert index of every global expert. It is used to tag
            # tokens received from each EP rank with their local expert.
            self.expert_ids_per_ep_rank = torch.tensor(
                [i % self.num_local_experts for i in range(self.num_experts)],
                dtype=torch.int32,
                device=torch.npu.current_device(),
            )

        local_expert_indices_offset = (self.ep_rank * self.num_local_experts)
        self.local_expert_indices = [
            local_expert_indices_offset + i
            for i in range(self.num_local_experts)
        ]
        assert (len(self.local_expert_indices) == self.num_local_experts
                ), "Invalid local expert indices"
        for prev_idx, next_idx in zip(self.local_expert_indices,
                                      self.local_expert_indices[1:]):
            assert prev_idx == next_idx - 1, "local_expert_indices must be continuous"

    def token_dispatch(
        self,
        hidden_states: torch.Tensor,
        topk_weights: torch.Tensor,
        topk_ids: torch.Tensor,
        expert_map: Optional[torch.Tensor] = None,
        log2phy: Optional[torch.Tensor] = None,
        global_redundant_expert_num: int = 0,
        shared_experts: Optional[Any] = None,
        quantized_x_for_share: Optional[Any] = None,
        dynamic_scale_for_share: Optional[Any] = None,
        mc2_mask: Optional[torch.Tensor] = None,
        apply_router_weight_on_input: bool = False,
        with_quant: bool = False,
    ):
        """Exchange tokens across the EP group and group them by local expert.

        Returns:
            dict with ``hidden_states``, ``group_list`` (tokens per local
            expert), ``group_list_type`` (1), ``dynamic_scale`` (None unless
            ``with_quant``) and ``context_metadata`` for ``token_combine``.
        """
        self.with_quant = with_quant
        self.hidden_shape = hidden_states.shape

        if log2phy is not None:
            topk_ids = log2phy[topk_ids]

        (
            permutated_local_input_tokens,
            reversed_local_input_permutation_mapping,
            tokens_per_expert,
            input_splits,
            output_splits,
            num_global_tokens_per_local_expert,
            global_input_tokens_local_experts_indices,
        ) = self._dispatch_preprocess(hidden_states, topk_ids)

        dynamic_scale_after_all2all = None
        if self.with_quant:
            permutated_local_input_tokens, dynamic_scale = torch_npu.npu_dynamic_quant(
                permutated_local_input_tokens)
            _, dynamic_scale_after_all2all, permute2_ep_all_to_all_handle = async_all_to_all(
                dynamic_scale, output_splits, input_splits, self.ep_group)
            permute2_ep_all_to_all_handle.wait()
            # Release the send buffer's memory once the exchange has completed.
            dynamic_scale.untyped_storage().resize_(0)

        _, global_input_tokens, permute1_ep_all_to_all_handle = async_all_to_all(
            permutated_local_input_tokens, output_splits, input_splits,
            self.ep_group)
        permute1_ep_all_to_all_handle.wait()
        # Release the send buffer's memory once the exchange has completed.
        permutated_local_input_tokens.untyped_storage().resize_(0)

        global_input_tokens, dynamic_scale_final, reversed_global_input_permutation_mapping = self._dispatch_postprocess(
            global_input_tokens, dynamic_scale_after_all2all,
            global_input_tokens_local_experts_indices)

        context_metadata = {
            "input_splits": input_splits,
            "output_splits": output_splits,
            "topk_weights": topk_weights,
            "reversed_local_input_permutation_mapping":
            reversed_local_input_permutation_mapping,
            "reversed_global_input_permutation_mapping":
            reversed_global_input_permutation_mapping,
            # Carried so combine does not depend on instance state.
            "hidden_shape": self.hidden_shape,
            "hidden_shape_before_permute": self.hidden_shape_before_permute,
        }

        return {
            "hidden_states": global_input_tokens,
            "group_list": tokens_per_expert,
            "group_list_type": 1,
            "dynamic_scale": dynamic_scale_final,
            "context_metadata": context_metadata,
        }

    def token_combine(
        self,
        hidden_states: torch.Tensor,
        context_metadata: dict,
        bias: Optional[torch.Tensor] = None,
    ):
        """Send expert outputs back to their source ranks and unpermute them."""
        assert bias is None, "Bias is not supported in MoEAlltoAllvTokenDispatcher."

        # 1. Undo the local-expert regrouping.
        hidden_states = self._combine_preprocess(hidden_states,
                                                 context_metadata)

        # 2. AllToAll back to the source ranks (splits swapped vs. dispatch).
        _, permutated_local_input_tokens, handle = async_all_to_all(
            hidden_states,
            context_metadata["input_splits"],
            context_metadata["output_splits"],
            self.ep_group,
        )
        handle.wait()
        # NOTE(review): when _combine_preprocess returns its input unchanged
        # (a single local expert or zero rows), this frees the storage of the
        # tensor passed in by the caller; it must not be reused afterwards.
        hidden_states.untyped_storage().resize_(0)

        # 3. Unpermute to the original token order.
        return self._combine_postprocess(permutated_local_input_tokens,
                                         context_metadata)

    def _dispatch_preprocess(self, hidden_states, topk_ids):
        """Compute split sizes, then permute tokens locally by expert id."""
        assert self.hidden_shape is not None
        hidden_states = hidden_states.view(-1, hidden_states.size(-1))
        (
            tokens_per_expert,
            input_splits,
            output_splits,
            num_global_tokens_per_local_expert,
            global_input_tokens_local_experts_indices,
        ) = self._preprocess(topk_ids)

        self.hidden_shape_before_permute = hidden_states.shape

        permutated_local_input_tokens, reversed_local_input_permutation_mapping = torch_npu.npu_moe_token_permute(
            tokens=hidden_states,
            indices=topk_ids,
            num_out_tokens=self.num_out_tokens,
        )
        return (
            permutated_local_input_tokens,
            reversed_local_input_permutation_mapping,
            tokens_per_expert,
            input_splits,
            output_splits,
            num_global_tokens_per_local_expert,
            global_input_tokens_local_experts_indices,
        )

    def _preprocess(self, topk_ids: torch.Tensor):
        """Count tokens per expert and derive the all-to-all split sizes.

        Returns:
            (tokens per local expert summed over ranks, input_splits
            (numpy), output_splits (numpy), per-rank x local-expert counts,
            local expert id of every received token or None when there is
            a single local expert).
        """
        num_local_tokens_per_expert = torch.histc(topk_ids,
                                                  bins=self.num_experts,
                                                  min=0,
                                                  max=self.num_experts)

        ep_size = self.ep_size
        self.num_out_tokens = topk_ids.numel()

        # Tokens this rank sends to each EP rank.
        input_splits = (num_local_tokens_per_expert.reshape(
            ep_size,
            self.num_local_experts).sum(axis=1).to(torch.device("cpu"),
                                                   non_blocking=True).numpy())

        num_global_tokens_per_expert = gather_from_sequence_parallel_region(
            num_local_tokens_per_expert,
            group=self.ep_group).reshape(ep_size, self.num_experts)
        first_local = self.local_expert_indices[0]
        last_local = self.local_expert_indices[-1]
        num_global_tokens_per_local_expert = num_global_tokens_per_expert[:, first_local:last_local + 1]

        # Tokens this rank receives from each EP rank.
        output_splits = (num_global_tokens_per_local_expert.sum(axis=-1).to(
            torch.device("cpu"), non_blocking=True).numpy())
        num_tokens_per_local_expert = num_global_tokens_per_local_expert.sum(
            axis=0)

        global_input_tokens_local_experts_indices = None
        if self.num_local_experts > 1:
            # NOTE(review): no explicit synchronize on this branch; presumably
            # repeat_interleave with device-side repeats forces the host sync
            # that completes the non_blocking copies above - confirm.
            global_input_tokens_local_experts_indices = torch.repeat_interleave(
                self.expert_ids_per_ep_rank,
                num_global_tokens_per_local_expert.ravel())
        else:
            # Ensure the non_blocking device-to-host copies behind
            # input_splits/output_splits finished before they are read.
            torch.npu.synchronize()

        return (
            num_tokens_per_local_expert,
            input_splits,
            output_splits,
            num_global_tokens_per_local_expert,
            global_input_tokens_local_experts_indices,
        )

    def _dispatch_postprocess(self, global_input_tokens,
                              dynamic_scale_after_all2all,
                              global_input_tokens_local_experts_indices):
        """Regroup received tokens (and scales) by local expert.

        Returns (tokens, scale or None, reverse mapping or None).
        """
        # Nothing to regroup with a single local expert.
        if self.num_local_experts <= 1:
            return global_input_tokens, dynamic_scale_after_all2all, None

        # Quantized case: permute the tokens and their per-token scales together.
        if self.with_quant:
            assert global_input_tokens_local_experts_indices is not None, \
                "global_input_tokens_local_experts_indices must be provided"
            expert_idx_2d = global_input_tokens_local_experts_indices.unsqueeze(
                -1)
            active_num = global_input_tokens_local_experts_indices.numel()

            if active_num <= 0:
                reversed_global_input_permutation_mapping = global_input_tokens_local_experts_indices
                return global_input_tokens, dynamic_scale_after_all2all, reversed_global_input_permutation_mapping

            global_input_tokens, reversed_global_input_permutation_mapping, _, expanded_scale = torch_npu.npu_moe_init_routing_v2(
                global_input_tokens,
                expert_idx_2d,
                scale=dynamic_scale_after_all2all,
                active_num=active_num,
                expert_capacity=0,
                expert_num=self.num_local_experts,
                expert_tokens_num_type=1,
                expert_tokens_num_flag=True,
                active_expert_range=[0, self.num_local_experts],
                quant_mode=-1,
                row_idx_type=0,
            )
            return global_input_tokens, expanded_scale, reversed_global_input_permutation_mapping

        # Non-quantized case
        global_input_tokens, reversed_global_input_permutation_mapping = torch_npu.npu_moe_token_permute(
            global_input_tokens, global_input_tokens_local_experts_indices)
        return global_input_tokens, None, reversed_global_input_permutation_mapping

    def _combine_preprocess(self, hidden_states: torch.Tensor,
                            context_metadata: dict) -> torch.Tensor:
        """Undo the local-expert regrouping (returns input unchanged if none)."""
        # Unpermutation 2: expert output to AlltoAll input
        if hidden_states.shape[0] > 0 and self.num_local_experts > 1:
            rev_global = context_metadata[
                "reversed_global_input_permutation_mapping"]
            hidden_states = torch_npu.npu_moe_token_unpermute(
                hidden_states, rev_global)
        return hidden_states

    def _combine_postprocess(self, permutated_local_input_tokens: torch.Tensor,
                             context_metadata: dict) -> torch.Tensor:
        """Unpermute to token order and restore the dispatch input shape.

        Shapes come from ``context_metadata`` when present, otherwise from
        the state recorded by the latest dispatch.
        """
        hidden_shape = context_metadata.get("hidden_shape", self.hidden_shape)
        hidden_shape_before_permute = context_metadata.get(
            "hidden_shape_before_permute", self.hidden_shape_before_permute)
        # Unpermutation 1: AlltoAll output to output
        output = torch_npu.npu_moe_token_unpermute(
            permuted_tokens=permutated_local_input_tokens,
            sorted_indices=context_metadata[
                "reversed_local_input_permutation_mapping"].to(torch.int32),
            probs=context_metadata["topk_weights"],
            restore_shape=hidden_shape_before_permute,
        )
        return output.view(hidden_shape)