Reapply "[MoE] [Refactor] Remove manual memory cleanup (#3365)" (#3483) (#3512)

### What this PR does / why we need it?
1. Replace manual memory cleanup with parameter passing: `token_dispatch` now returns a `context_metadata` dict that `token_combine` consumes (see the sketch below), so intermediates are no longer cached on `self` and nulled out by hand.
2. `FusedMoEPrepareAndFinalizeWithMC2` now inherits from the All2All variant to avoid duplicated code.
3. Fix the MC2 bug introduced in
https://github.com/vllm-project/vllm-ascend/pull/3365.
4. Unify the aclgraph and eager code paths in W8A8_dynamic.
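
For context, a minimal sketch of the new dispatch/combine contract is shown below. The class name and elided bodies are illustrative, not from this PR: `token_dispatch` returns a `context_metadata` dict and `token_combine` consumes it, instead of caching intermediates on `self` and manually setting them to `None`.

```python
from typing import Optional

import torch


class SketchDispatcher:
    """Illustrative only; the real dispatchers live in token_dispatcher.py."""

    def token_dispatch(self, hidden_states: torch.Tensor,
                       topk_weights: torch.Tensor,
                       topk_ids: torch.Tensor) -> dict:
        # ... permutation / dispatch work elided ...
        context_metadata = {
            "topk_ids": topk_ids,
            "topk_weights": topk_weights,
        }
        # Everything the combine step needs is handed back to the caller
        # rather than stored on self.
        return {
            "hidden_states": hidden_states,
            "context_metadata": context_metadata,
        }

    def token_combine(self, hidden_states: torch.Tensor,
                      context_metadata: dict,
                      bias: Optional[torch.Tensor] = None) -> torch.Tensor:
        # Read what dispatch produced, e.g. context_metadata["topk_weights"],
        # instead of self.topk_weights.
        # ... unpermute / combine work elided ...
        # The metadata tensors go out of scope with the dict, so no manual
        # `self.x = None` cleanup is needed for memory release.
        return hidden_states
```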
### Does this PR introduce _any_ user-facing change?
No
### How was this patch tested?
e2e tests and unit tests.

- vLLM version: v0.11.0rc3
- vLLM main: https://github.com/vllm-project/vllm/commit/v0.11.0

Signed-off-by: Pr0Wh1teGivee <calvin_zhu0210@outlook.com>
Authored by weichen, 2025-10-22 11:41:30 +08:00 (committed by GitHub)
Parent: 6ef62cb427
Commit: 2f1b9a7a64
13 changed files with 608 additions and 522 deletions


@@ -75,6 +75,7 @@ class MoETokenDispatcher(ABC):
@abstractmethod
def token_combine(self,
hidden_states: torch.Tensor,
context_metadata: dict,
bias: torch.Tensor = None):
raise NotImplementedError("Combine function not implemented.")
@@ -102,16 +103,7 @@ class TokenDispatcherWithMC2(MoETokenDispatcher):
# HCCL_INTRA_ROCE_ENABLE=0 can reduce cross-machine communication traffic and significantly
# improve communication performance.
self.need_expert_scale = is_hierarchical_communication_enabled()
self.output = None
self.assist_info_for_combine = None
self.ep_recv_counts = None
self.shared_act = None
self.topk_ids = None
self.topk_weights = None
self.shared_experts = None
self.mc2_mask = None
self.with_quant = False
self.expand_scales = None
def get_dispatch_mc2_kwargs(
self,
@@ -119,6 +111,7 @@ class TokenDispatcherWithMC2(MoETokenDispatcher):
topk_weights: torch.Tensor,
topk_ids: torch.Tensor,
expert_map: torch.Tensor,
mc2_mask: torch.Tensor,
global_redundant_expert_num: int = 0,
):
if self.with_quant:
@@ -152,7 +145,7 @@ class TokenDispatcherWithMC2(MoETokenDispatcher):
})
if self.a3_need_extra_args and self.enable_dispatch_v2:
stage1_kwargs.update({
"x_active_mask": self.mc2_mask,
"x_active_mask": mc2_mask,
})
if self.need_expert_scale:
stage1_kwargs.update({
@@ -163,99 +156,121 @@ class TokenDispatcherWithMC2(MoETokenDispatcher):
kwargs_mc2.update(stage1_kwargs)
return kwargs_mc2
def token_dispatch(self,
hidden_states: torch.Tensor,
topk_weights: torch.Tensor,
topk_ids: torch.Tensor,
expert_map: Optional[torch.Tensor] = None,
log2phy: Optional[torch.Tensor] = None,
global_redundant_expert_num: int = 0,
shared_experts: Optional[Any] = None,
quantized_x_for_share: Optional[Any] = None,
dynamic_scale_for_share: Optional[Any] = None,
mc2_mask: Optional[torch.Tensor] = None,
apply_router_weight_on_input: bool = False,
with_quant: bool = False):
def token_dispatch(
self,
hidden_states: torch.Tensor,
topk_weights: torch.Tensor,
topk_ids: torch.Tensor,
expert_map: Optional[torch.Tensor] = None,
log2phy: Optional[torch.Tensor] = None,
global_redundant_expert_num: int = 0,
shared_experts: Optional[Any] = None,
quantized_x_for_share: Optional[Any] = None,
dynamic_scale_for_share: Optional[Any] = None,
mc2_mask: Optional[torch.Tensor] = None,
apply_router_weight_on_input: bool = False,
with_quant: bool = False,
):
self.with_quant = with_quant
self.expert_map = expert_map
self.topk_ids = topk_ids
self.topk_weights = topk_weights
self.shared_experts = shared_experts
self.mc2_mask = mc2_mask
# Apply log2phy if needed
if log2phy is not None:
topk_ids = log2phy[topk_ids]
kwargs_mc2 = self.get_dispatch_mc2_kwargs(hidden_states, topk_weights,
topk_ids, expert_map,
mc2_mask,
global_redundant_expert_num)
self.output = torch_npu.npu_moe_distribute_dispatch_v2(
output = torch_npu.npu_moe_distribute_dispatch_v2(
**kwargs_mc2
) if self.enable_dispatch_v2 else torch_npu.npu_moe_distribute_dispatch(
**kwargs_mc2)
# comm_stream.wait_stream(torch.npu.current_stream())
expand_x, dynamic_scale, self.assist_info_for_combine, expert_token_nums, \
self.ep_recv_counts, _, self.expand_scales = self.output[0:7]
expand_x, dynamic_scale, assist_info_for_combine, expert_token_nums, \
ep_recv_counts, tp_recv_counts, expand_scales = output[0:7]
if self.with_quant:
# Handle shared experts (store intermediate results in local vars, not self)
shared_act = None
swiglu_out_scale = None
if with_quant:
if shared_experts is not None:
share_up_out, _ = shared_experts.gate_up_proj(
(quantized_x_for_share, dynamic_scale_for_share))
shared_gate_up, shared_dequant_scale = share_up_out[
0], share_up_out[1]
shared_act_out = shared_experts.act_fn(
(shared_gate_up, shared_dequant_scale))
self.shared_act, self.swiglu_out_scale = \
shared_act_out[0], shared_act_out[1]
shared_act, swiglu_out_scale = shared_act_out[
0], shared_act_out[1]
else:
if shared_experts is not None:
shared_gate_up, _ = shared_experts.gate_up_proj(hidden_states)
self.shared_act = shared_experts.act_fn(shared_gate_up)
group_list_type = 0
shared_act = shared_experts.act_fn(shared_gate_up)
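# Everything token_combine will need later travels in context_metadata below;
# nothing is cached on self, so no manual cleanup is required after combine.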
context_metadata = {
"topk_ids": topk_ids,
"topk_weights": topk_weights,
"mc2_mask": mc2_mask,
"expert_map": expert_map,
"ep_recv_counts": ep_recv_counts,
"tp_recv_counts": tp_recv_counts,
"assist_info_for_combine": assist_info_for_combine,
"shared_experts": shared_experts,
"shared_act": shared_act,
"swiglu_out_scale": swiglu_out_scale,
"expand_scales": expand_scales
}
return {
"group_list_type": group_list_type,
"group_list_type": 0,
"hidden_states": expand_x,
"group_list": expert_token_nums,
"dynamic_scale": dynamic_scale,
"context_metadata": context_metadata
}
def get_combine_mc_kwargs(self, hidden_states: torch.Tensor):
assert self.expert_map is not None
assert self.topk_weights is not None
assert self.topk_ids is not None
assert self.output is not None
moe_expert_num = len(self.expert_map)
# moeCombine
def get_combine_mc_kwargs(self, hidden_states: torch.Tensor,
context_metadata: dict):
expert_map = context_metadata["expert_map"]
topk_ids = context_metadata["topk_ids"]
topk_weights = context_metadata["topk_weights"]
ep_recv_counts = context_metadata["ep_recv_counts"]
tp_recv_counts = context_metadata["tp_recv_counts"]
assist_info_for_combine = context_metadata["assist_info_for_combine"]
mc2_mask = context_metadata["mc2_mask"]
expand_scales = context_metadata["expand_scales"]
assert expert_map is not None
moe_expert_num = len(expert_map)
kwargs_mc2 = {
"expand_x": hidden_states,
"expert_ids": self.topk_ids,
"expert_scales": self.topk_weights.to(torch.float32),
"expert_ids": topk_ids,
"expert_scales": topk_weights.to(torch.float32),
"expert_shard_type": 0,
"shared_expert_rank_num": 0,
"moe_expert_num": moe_expert_num,
"global_bs": 0,
}
if self.with_quant:
tp_recv_counts = torch.empty(1,
dtype=torch.int32,
device=hidden_states.device)
else:
tp_recv_counts = self.output[5]
stage3_kwargs = {
"ep_send_counts": self.ep_recv_counts,
"ep_send_counts": ep_recv_counts,
"group_ep": self.moe_all_to_all_group_name,
"ep_world_size": self.ep_world_size,
"ep_rank_id": self.ep_rank_id,
"expand_scales": self.expand_scales,
"expand_scales": expand_scales,
}
if self.enable_dispatch_v2:
stage3_kwargs.update({
"assist_info_for_combine":
self.assist_info_for_combine,
})
stage3_kwargs["assist_info_for_combine"] = assist_info_for_combine
else:
stage3_kwargs.update({
"expand_idx": self.assist_info_for_combine,
})
stage3_kwargs["expand_idx"] = assist_info_for_combine
if self.need_extra_args:
stage3_kwargs.update({
"tp_send_counts": tp_recv_counts,
@@ -263,45 +278,40 @@ class TokenDispatcherWithMC2(MoETokenDispatcher):
"tp_world_size": 1,
"tp_rank_id": 0,
})
if self.a3_need_extra_args and self.enable_dispatch_v2:
stage3_kwargs.update({
"x_active_mask": self.mc2_mask,
})
stage3_kwargs["x_active_mask"] = mc2_mask
kwargs_mc2.update(stage3_kwargs)
return kwargs_mc2
def token_combine(self,
hidden_states: torch.Tensor,
bias: torch.Tensor = None):
kwargs_mc2 = self.get_combine_mc_kwargs(hidden_states)
hidden_states = torch_npu.npu_moe_distribute_combine_v2(
**kwargs_mc2
) if self.enable_dispatch_v2 else torch_npu.npu_moe_distribute_combine(
**kwargs_mc2)
def token_combine(
self,
hidden_states: torch.Tensor,
context_metadata: dict,
bias: torch.Tensor = None,
):
assert bias is None, "Bias is not supported in MoEAlltoAllvTokenDispatcher."
# these values are no longer used, so they need to be set to None for memory release.
self.output = None
self.assist_info_for_combine = None
self.ep_recv_counts = None
self.topk_ids = None
self.topk_weights = None
self.mc2_mask = None
self.expert_map = None
self.expand_scales = None
kwargs_mc2 = self.get_combine_mc_kwargs(hidden_states,
context_metadata)
combined_output = torch_npu.npu_moe_distribute_combine_v2(**kwargs_mc2) \
if self.enable_dispatch_v2 else torch_npu.npu_moe_distribute_combine(**kwargs_mc2)
if self.shared_experts is None:
return hidden_states
# Handle shared experts from metadata
shared_experts = context_metadata["shared_experts"]
if shared_experts is None:
return combined_output
shared_act = context_metadata["shared_act"]
if self.with_quant:
swiglu_out_scale = context_metadata["swiglu_out_scale"]
shared_hidden_states, _ = shared_experts.down_proj(
(shared_act, swiglu_out_scale))
else:
if self.with_quant:
shared_hidden_states, _ = self.shared_experts.down_proj(
(self.shared_act, self.swiglu_out_scale))
else:
shared_hidden_states, _ = self.shared_experts.down_proj(
self.shared_act)
self.shared_act = None
self.shared_experts = None
self.swiglu_out_scale = None
return hidden_states, shared_hidden_states
shared_hidden_states, _ = shared_experts.down_proj(shared_act)
return combined_output, shared_hidden_states
class TokenDispatcherWithAllGather(MoETokenDispatcher):
@@ -311,14 +321,7 @@ class TokenDispatcherWithAllGather(MoETokenDispatcher):
self.apply_router_weight_on_input = False
self.max_num_tokens = kwargs.get("max_num_tokens")
self.num_experts_local = kwargs.get("num_local_experts", 0)
self.sorted_weights = None
self.expanded_row_idx = None
self.sorted_token_indices = None
self.original_shape = None
self.mask = None
self.expert_map = None
self.topk_weights = None
self.topk_ids = None
self.with_quant = False
def token_dispatch(self,
@@ -338,9 +341,6 @@ class TokenDispatcherWithAllGather(MoETokenDispatcher):
self.original_shape = hidden_states.shape
num_tokens = hidden_states.shape[:-1].numel()
self.expert_map = expert_map
self.topk_weights = topk_weights
self.topk_ids = topk_ids
self.apply_router_weight_on_input = apply_router_weight_on_input
if self.apply_router_weight_on_input:
assert (topk_weights.dim() == 2
@@ -354,7 +354,7 @@ class TokenDispatcherWithAllGather(MoETokenDispatcher):
if expert_map is not None:
global_num_experts = len(expert_map)
mask = (expert_map[topk_ids] != -1)
self.topk_weights = topk_weights * mask
topk_weights = topk_weights * mask
first_expert_idx = get_ep_group(
).rank_in_group * self.num_experts_local
last_expert_idx = first_expert_idx + self.num_experts_local
@@ -363,7 +363,7 @@ class TokenDispatcherWithAllGather(MoETokenDispatcher):
last_expert_idx = self.num_experts_local
global_num_experts = self.num_experts_local
sorted_hidden_states, self.expanded_row_idx, expert_tokens, pertoken_scale = (
sorted_hidden_states, expanded_row_idx, expert_tokens, pertoken_scale = (
torch_npu.npu_moe_init_routing_v2(
hidden_states,
topk_ids,
@@ -376,29 +376,31 @@ class TokenDispatcherWithAllGather(MoETokenDispatcher):
))
expert_tokens = expert_tokens.to(torch.int64)
group_list_type = 1 # `count` mode
context_metadata = {
"topk_weights": topk_weights,
"expanded_row_idx": expanded_row_idx
}
return {
"group_list_type": group_list_type,
"hidden_states": sorted_hidden_states,
"group_list": expert_tokens,
"dynamic_scale": pertoken_scale if self.with_quant else None,
"context_metadata": context_metadata
}
def token_combine(self,
hidden_states: torch.Tensor,
context_metadata: dict,
bias: torch.Tensor = None):
assert self.original_shape is not None
final_hidden_states = torch_npu.npu_moe_token_unpermute(
permuted_tokens=hidden_states,
sorted_indices=torch.abs(self.expanded_row_idx),
probs=self.topk_weights)
sorted_indices=torch.abs(context_metadata["expanded_row_idx"]),
probs=context_metadata["topk_weights"])
if len(self.original_shape) == 3:
final_hidden_states = final_hidden_states.view(self.original_shape)
# these values are no longer used, so they need to be set to None for memory release.
self.expert_map = None
self.topk_weights = None
self.topk_ids = None
self.expanded_row_idx = None
return final_hidden_states
@@ -447,11 +449,12 @@ class TokenDispatcherWithMoge(MoETokenDispatcher):
"group_list_type": group_list_type,
"hidden_states": sorted_hidden_states,
"group_list": group_list,
"topk_scales": topk_scales,
"topk_scales": topk_scales
}
def token_combine(self,
hidden_states: torch.Tensor,
context_metadata: dict,
bias: torch.Tensor = None):
unsorted_topk_ids = torch.argsort(self.sorted_topk_ids.float()).to(
torch.int32)
@@ -475,19 +478,8 @@ class TokenDispatcherWithAll2AllV(MoETokenDispatcher):
self.num_local_experts = kwargs.get("num_local_experts", 0)
self.hidden_shape = None
self.topk_weights = None
self.input_splits = None
self.output_splits = None
self.hidden_shape_before_permute = None
# [tp_ep_size * ep_size, num_local_experts]. Represents the number of tokens sent
# to each local expert by all ranks.
self.num_global_tokens_per_local_expert = None
# cached intermediate tensors.
self.tokens_per_expert = None
self.global_input_tokens_local_experts_indices = None
assert self.num_local_experts > 0, "Expected at least one expert"
if self.num_local_experts > 1:
self.expert_ids_per_ep_rank = torch.tensor(
@@ -509,96 +501,116 @@ class TokenDispatcherWithAll2AllV(MoETokenDispatcher):
self.local_expert_indices[i + 1] -
1), "local_expert_indices must be continuous"
def token_dispatch(self,
hidden_states: torch.Tensor,
topk_weights: torch.Tensor,
topk_ids: torch.Tensor,
expert_map: Optional[torch.Tensor] = None,
log2phy: Optional[torch.Tensor] = None,
global_redundant_expert_num: int = 0,
shared_experts: Optional[Any] = None,
quantized_x_for_share: Optional[Any] = None,
dynamic_scale_for_share: Optional[Any] = None,
mc2_mask: Optional[torch.Tensor] = None,
apply_router_weight_on_input: bool = False,
with_quant: bool = False):
def token_dispatch(
self,
hidden_states: torch.Tensor,
topk_weights: torch.Tensor,
topk_ids: torch.Tensor,
expert_map: Optional[torch.Tensor] = None,
log2phy: Optional[torch.Tensor] = None,
global_redundant_expert_num: int = 0,
shared_experts: Optional[Any] = None,
quantized_x_for_share: Optional[Any] = None,
dynamic_scale_for_share: Optional[Any] = None,
mc2_mask: Optional[torch.Tensor] = None,
apply_router_weight_on_input: bool = False,
with_quant: bool = False,
):
self.with_quant = with_quant
self.hidden_shape = hidden_states.shape
self.topk_weights = topk_weights
assert topk_weights.dim() == 2, "Expected 2D tensor for topk_weights"
assert topk_ids.dim() == 2, "Expected 2D tensor for routing map"
if log2phy is not None:
topk_ids = log2phy[topk_ids]
permutated_local_input_tokens, reversed_local_input_permutation_mapping, tokens_per_expert = self._dispatch_preprocess(
hidden_states, topk_ids)
self.reversed_local_input_permutation_mapping = reversed_local_input_permutation_mapping
(
permutated_local_input_tokens,
reversed_local_input_permutation_mapping,
tokens_per_expert,
input_splits,
output_splits,
num_global_tokens_per_local_expert,
global_input_tokens_local_experts_indices,
) = self._dispatch_preprocess(hidden_states, topk_ids)
dynamic_scale_after_all2all = None
if self.with_quant:
permutated_local_input_tokens, dynamic_scale = torch_npu.npu_dynamic_quant(
permutated_local_input_tokens)
_, dynamic_scale_after_all2all, permute2_ep_all_to_all_handle = async_all_to_all(
dynamic_scale,
self.output_splits,
self.input_splits,
self.ep_group,
)
dynamic_scale, output_splits, input_splits, self.ep_group)
permute2_ep_all_to_all_handle.wait()
dynamic_scale.untyped_storage().resize_(0)
_, global_input_tokens, permute1_ep_all_to_all_handle = async_all_to_all(
permutated_local_input_tokens,
self.output_splits,
self.input_splits,
self.ep_group,
)
permutated_local_input_tokens, output_splits, input_splits,
self.ep_group)
permute1_ep_all_to_all_handle.wait()
permutated_local_input_tokens.untyped_storage().resize_(0)
global_input_tokens, dynamic_scale = self._dispatch_postprocess(
global_input_tokens, dynamic_scale_after_all2all)
# Postprocess
global_input_tokens, dynamic_scale_final, reversed_global_input_permutation_mapping = self._dispatch_postprocess(
global_input_tokens, dynamic_scale_after_all2all,
global_input_tokens_local_experts_indices)
context_metadata = {
"input_splits":
input_splits,
"output_splits":
output_splits,
"topk_weights":
topk_weights,
"reversed_local_input_permutation_mapping":
reversed_local_input_permutation_mapping,
"reversed_global_input_permutation_mapping":
reversed_global_input_permutation_mapping
}
return {
"hidden_states": global_input_tokens,
"group_list": tokens_per_expert,
"dynamic_scale": dynamic_scale,
"group_list_type": 1
"group_list_type": 1,
"dynamic_scale": dynamic_scale_final,
"context_metadata": context_metadata,
}
def token_combine(self,
hidden_states: torch.Tensor,
bias: torch.Tensor = None):
def token_combine(
self,
hidden_states: torch.Tensor,
context_metadata: dict,
bias: torch.Tensor = None,
):
assert bias is None, "Bias is not supported in MoEAlltoAllvTokenDispatcher."
hidden_states = self._combine_preprocess(hidden_states)
# 1. Preprocess using metadata
hidden_states = self._combine_preprocess(hidden_states,
context_metadata)
# Perform expert parallel AlltoAll communication
# hidden_states: [SEQL, H] -> [SEQL, H/TP]
# 2. AllToAll
_, permutated_local_input_tokens, handle = async_all_to_all(
hidden_states, self.input_splits, self.output_splits,
self.ep_group)
hidden_states,
context_metadata["input_splits"],
context_metadata["output_splits"],
self.ep_group,
)
handle.wait()
hidden_states.untyped_storage().resize_(0)
output = self._combine_postprocess(permutated_local_input_tokens)
# these values are no longer used, so they need to be set to None for memory release.
self.input_splits = None
self.output_splits = None
self.num_global_tokens_per_local_expert = None
self.topk_weights = None
self.reversed_local_input_permutation_mapping = None
self.reversed_global_input_permutation_mapping = None
self.global_input_tokens_local_experts_indices = None
# 3. Postprocess using metadata
output = self._combine_postprocess(permutated_local_input_tokens,
context_metadata)
return output
def _dispatch_preprocess(self, hidden_states, topk_ids):
assert self.hidden_shape is not None
hidden_states = hidden_states.view(-1, self.hidden_shape[-1])
tokens_per_expert = self._preprocess(topk_ids)
hidden_states = hidden_states.view(-1, hidden_states.size(-1))
(
tokens_per_expert,
input_splits,
output_splits,
num_global_tokens_per_local_expert,
global_input_tokens_local_experts_indices,
) = self._preprocess(topk_ids)
self.hidden_shape_before_permute = hidden_states.shape
@@ -607,82 +619,88 @@ class TokenDispatcherWithAll2AllV(MoETokenDispatcher):
indices=topk_ids,
num_out_tokens=self.num_out_tokens,
)
return permutated_local_input_tokens, reversed_local_input_permutation_mapping, tokens_per_expert
def _preprocess(self, topk_ids: torch.Tensor) -> torch.Tensor:
return (
permutated_local_input_tokens,
reversed_local_input_permutation_mapping,
tokens_per_expert,
input_splits,
output_splits,
num_global_tokens_per_local_expert,
global_input_tokens_local_experts_indices,
)
def _preprocess(self, topk_ids: torch.Tensor):
num_local_tokens_per_expert = torch.histc(topk_ids,
bins=self.num_experts,
min=0,
max=self.num_experts)
ep_size = self.ep_size
# Dropless
self.num_out_tokens = topk_ids.numel()
# ===================================================
# Calculate input_splits, output_splits for alltoall-v.
# ===================================================
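# input_splits[i] = tokens this rank sends to EP rank i (summed over that
# rank's local experts); output_splits[i], computed below from the gathered
# per-expert counts, = tokens this rank receives from EP rank i.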
self.input_splits = (num_local_tokens_per_expert.reshape(
input_splits = (num_local_tokens_per_expert.reshape(
ep_size,
self.num_local_experts).sum(axis=1).to(torch.device("cpu"),
non_blocking=True).numpy())
num_global_tokens_per_expert = gather_from_sequence_parallel_region(
num_local_tokens_per_expert,
group=self.ep_group).reshape(ep_size, self.num_experts)
self.num_global_tokens_per_local_expert = num_global_tokens_per_expert[:, self.local_expert_indices[
num_global_tokens_per_local_expert = num_global_tokens_per_expert[:, self.local_expert_indices[
0]:self.local_expert_indices[-1] + 1]
if self.num_global_tokens_per_local_expert is None:
if num_global_tokens_per_local_expert is None:
raise ValueError(
"num_global_tokens_per_local_expert must be set before sum.")
self.output_splits = (self.num_global_tokens_per_local_expert.sum(
axis=-1).to(torch.device("cpu"), non_blocking=True).numpy())
num_tokens_per_local_expert = self.num_global_tokens_per_local_expert.sum(
axis=0)
# ===================================================
# num_global_tokens_per_expert: [ep_size, num_experts]
# num_global_tokens_per_local_expert: [ep_size, num_local_experts]
# num_tokens_per_local_expert: [num_local_experts]
# ===================================================
output_splits = (num_global_tokens_per_local_expert.sum(axis=-1).to(
torch.device("cpu"), non_blocking=True).numpy())
num_tokens_per_local_expert = num_global_tokens_per_local_expert.sum(
axis=0)
global_input_tokens_local_experts_indices = None
if self.num_local_experts > 1:
if self.num_global_tokens_per_local_expert is None:
if num_global_tokens_per_local_expert is None:
raise ValueError(
"num_global_tokens_per_local_expert must be set before operations."
)
self.global_input_tokens_local_experts_indices = torch.repeat_interleave(
global_input_tokens_local_experts_indices = torch.repeat_interleave(
self.expert_ids_per_ep_rank,
self.num_global_tokens_per_local_expert.ravel())
num_global_tokens_per_local_expert.ravel())
else:
# TODO: This full synchronization can be a performance bottleneck.
# A more granular sync (e.g., blocking D2H copies) should be investigated.
torch.npu.synchronize()
return num_tokens_per_local_expert
return (
num_tokens_per_local_expert,
input_splits,
output_splits,
num_global_tokens_per_local_expert,
global_input_tokens_local_experts_indices,
)
def _dispatch_postprocess(self, global_input_tokens, dynamic_scale=None):
def _dispatch_postprocess(self, global_input_tokens,
dynamic_scale_after_all2all,
global_input_tokens_local_experts_indices):
# Early return if no local experts or no tokens
if self.num_local_experts <= 1:
return global_input_tokens, None
return global_input_tokens, dynamic_scale_after_all2all, None
# Handle quantized case
if self.with_quant:
assert self.global_input_tokens_local_experts_indices is not None, \
"global_input_tokens_local_experts_indices must be initialized before calling _dispatch_postprocess"
expert_idx_2d = self.global_input_tokens_local_experts_indices.unsqueeze(
assert global_input_tokens_local_experts_indices is not None, \
"global_input_tokens_local_experts_indices must be provided"
expert_idx_2d = global_input_tokens_local_experts_indices.unsqueeze(
-1)
active_num = self.global_input_tokens_local_experts_indices.numel()
active_num = global_input_tokens_local_experts_indices.numel()
# Handle case with no active tokens
if active_num <= 0:
self.reversed_global_input_permutation_mapping = self.global_input_tokens_local_experts_indices
return global_input_tokens, dynamic_scale
reversed_global_input_permutation_mapping = global_input_tokens_local_experts_indices
return global_input_tokens, dynamic_scale_after_all2all, reversed_global_input_permutation_mapping
# Process with active tokens
global_input_tokens, self.reversed_global_input_permutation_mapping, _, expanded_scale = torch_npu.npu_moe_init_routing_v2(
global_input_tokens, reversed_global_input_permutation_mapping, _, expanded_scale = torch_npu.npu_moe_init_routing_v2(
global_input_tokens,
expert_idx_2d,
scale=dynamic_scale,
scale=dynamic_scale_after_all2all,
active_num=active_num,
expert_capacity=0,
expert_num=self.num_local_experts,
@@ -690,32 +708,34 @@ class TokenDispatcherWithAll2AllV(MoETokenDispatcher):
expert_tokens_num_flag=True,
active_expert_range=[0, self.num_local_experts],
quant_mode=-1,
row_idx_type=0)
return global_input_tokens, expanded_scale
row_idx_type=0,
)
return global_input_tokens, expanded_scale, reversed_global_input_permutation_mapping
# Handle non-quantized case
global_input_tokens, self.reversed_global_input_permutation_mapping = torch_npu.npu_moe_token_permute(
global_input_tokens,
self.global_input_tokens_local_experts_indices)
return global_input_tokens, None
# Non-quantized case
global_input_tokens, reversed_global_input_permutation_mapping = torch_npu.npu_moe_token_permute(
global_input_tokens, global_input_tokens_local_experts_indices)
return global_input_tokens, None, reversed_global_input_permutation_mapping
def _combine_preprocess(self, hidden_states):
def _combine_preprocess(self, hidden_states: torch.Tensor,
context_metadata: dict) -> torch.Tensor:
# Unpermutation 2: expert output to AlltoAll input
if hidden_states.shape[0] > 0 and self.num_local_experts > 1:
rev_global = context_metadata[
"reversed_global_input_permutation_mapping"]
hidden_states = torch_npu.npu_moe_token_unpermute(
hidden_states, self.reversed_global_input_permutation_mapping)
hidden_states, rev_global)
return hidden_states
def _combine_postprocess(self, permutated_local_input_tokens):
def _combine_postprocess(self, permutated_local_input_tokens: torch.Tensor,
context_metadata: dict) -> torch.Tensor:
# Unpermutation 1: AlltoAll output to output
output = torch_npu.npu_moe_token_unpermute(
permuted_tokens=permutated_local_input_tokens,
sorted_indices=self.reversed_local_input_permutation_mapping.to(
torch.int32),
probs=self.topk_weights,
restore_shape=self.hidden_shape_before_permute)
# Reshape the output tensor
sorted_indices=context_metadata[
"reversed_local_input_permutation_mapping"].to(torch.int32),
probs=context_metadata["topk_weights"],
restore_shape=self.hidden_shape_before_permute,
)
output = output.view(self.hidden_shape)
return output