[main] [refactor] refactor common_fused_moe.py (#2706)

### What this PR does / why we need it?
1. Move prepare/finalize operation from moe_comm_method to
/ops/moe/fused_moe_prepare_and_finalize
2. Adapt to token_dispatcher in moe_comm_method
3. Move
moe_comm_method/experts_selector/token_dispatcher/fused_moe_prepare_and_finalize
to /ops/moe
### Does this PR introduce _any_ user-facing change?
no
### How was this patch tested?
e2e & ut

- vLLM version: v0.10.1.1
- vLLM main:
f4962a6d55

Signed-off-by: weichen <calvin_zhu0210@outlook.com>
Signed-off-by: Pr0Wh1teGivee <calvin_zhu0210@outlook.com>
Co-authored-by: weijinqian0 <12153182+weijinqian0@users.noreply.github.com>
This commit is contained in:
weichen
2025-09-08 20:09:50 +08:00
committed by GitHub
parent 1a82b16355
commit a041d4f328
21 changed files with 1052 additions and 932 deletions

View File

View File

@@ -0,0 +1,283 @@
#
# Copyright (c) 2025 Huawei Technologies Co., Ltd. All Rights Reserved.
# This file is a part of the vllm-ascend project.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
from typing import Callable, Optional
import torch
import torch_npu
def return_row_idx(hidden_states, top_k):
num_tokens = hidden_states.shape[0]
row_idx_len = num_tokens * top_k
row_idx = (torch.arange(0,
row_idx_len,
dtype=torch.int32,
device=hidden_states.device).view(
top_k, -1).permute(1, 0).contiguous())
return row_idx
def select_experts(hidden_states: torch.Tensor,
router_logits: torch.Tensor,
top_k: int,
use_grouped_topk: bool,
renormalize: bool,
topk_group: Optional[int] = None,
num_expert_group: Optional[int] = None,
custom_routing_function: Optional[Callable] = None,
scoring_func: str = "softmax",
routed_scaling_factor=1.0,
e_score_correction_bias: Optional[torch.Tensor] = None,
indices_type: Optional[torch.dtype] = None,
global_num_experts: int = -1):
"""
Fused experts with select experts.
Args:
router_logits: router logits of shape (num_tokens, hidden_size).
hidden_states: Hidden states of shape (num_tokens, hidden_size).
top_k: number of top k experts.
use_grouped_topk: Whether to group experts before selecting top-k.
renormalize: Whether to renormalize the routing weights.
topk_group: Number of expert groups to select from.
num_expert_group: Number of experts in each group.
custom_routing_function: Custom routing function.
scoring_func: Scoring function to use.
e_score_correction_bias: Correction bias to apply to expert scores.
indices_type: dtype of indices
global_num_experts: Global number of experts.
Returns:
topk_weights: router weights of shape (num_tokens, top_k).
topk_ids: selected expert IDs of shape (num_tokens, top_k).
"""
topk_weights, topk_ids, row_idx = _select_experts_with_fusion_ops(
hidden_states=hidden_states,
router_logits=router_logits,
top_k=top_k,
use_grouped_topk=use_grouped_topk,
topk_group=topk_group,
renormalize=renormalize,
e_score_correction_bias=e_score_correction_bias,
num_expert_group=num_expert_group,
custom_routing_function=custom_routing_function,
scoring_func=scoring_func,
routed_scaling_factor=routed_scaling_factor,
global_num_experts=global_num_experts)
if topk_weights is None:
topk_weights, topk_ids = _native_select_experts(
hidden_states=hidden_states,
router_logits=router_logits,
top_k=top_k,
use_grouped_topk=use_grouped_topk,
renormalize=renormalize,
topk_group=topk_group,
num_expert_group=num_expert_group,
custom_routing_function=custom_routing_function,
scoring_func=scoring_func,
e_score_correction_bias=e_score_correction_bias,
global_num_experts=global_num_experts,
)
if row_idx is None:
row_idx = return_row_idx(hidden_states, top_k)
return topk_weights, topk_ids, row_idx
def _native_grouped_topk(
topk_weights: torch.Tensor,
num_expert_group: Optional[int],
topk_group: Optional[int],
):
topk_group = 0 if topk_group is None else topk_group
num_expert_group = 0 if num_expert_group is None else num_expert_group
num_token = topk_weights.shape[0]
grouped_weights = topk_weights.view(num_token, num_expert_group,
-1).max(dim=-1).values
topk_group_indices = torch.topk(grouped_weights.to(torch.float32),
k=topk_group,
dim=-1,
sorted=False)[1]
topk_group_mask = torch.zeros_like(grouped_weights)
topk_group_mask.scatter_(1, topk_group_indices, 1)
topk_weight_mask = (topk_group_mask.unsqueeze(-1).expand(
num_token, num_expert_group,
topk_weights.shape[-1] // num_expert_group).reshape(num_token, -1))
topk_weights = topk_weights.masked_fill(~topk_weight_mask.bool(), 0.0)
return topk_weights
def _renormalize_topk_weights(
topk_weights: torch.Tensor,
renormalize: bool,
):
if renormalize:
topk_weights = topk_weights / topk_weights.sum(dim=-1, keepdim=True)
return topk_weights
def _select_expert_use_group_topk(
topk_weights: torch.Tensor, topk_group: Optional[int],
renormalize: bool, top_k: int, num_expert_group: Optional[int],
e_score_correction_bias: Optional[torch.Tensor]):
assert topk_group is not None
assert num_expert_group is not None
if e_score_correction_bias is not None:
# Store original scores before applying correction bias. We use biased
# scores for expert selection but original scores for routing weights
original_weights = topk_weights
topk_weights = topk_weights + e_score_correction_bias.unsqueeze(0)
# TODO: Change to npu_group_topk when the latest CANN and NNAL is available
# >>> torch_npu._npu_group_topk(topk_weights, group_num=num_expert_group, k=topk_group)
topk_weights = _native_grouped_topk(topk_weights, num_expert_group,
topk_group)
# TODO bfloat16 is not supported in torch.topk with ge graph.
if e_score_correction_bias is not None:
topk_ids = torch.topk(topk_weights.to(torch.float32),
k=top_k,
dim=-1,
sorted=False)[1]
# Use original unbiased scores for the routing weights
topk_weights = original_weights.gather(1, topk_ids)
else:
topk_weights, topk_ids = torch.topk(topk_weights.to(torch.float32),
k=top_k,
dim=-1,
sorted=False)
topk_ids = topk_ids.to(torch.int32)
topk_weights = _renormalize_topk_weights(topk_weights, renormalize)
return topk_weights, topk_ids
def _select_experts_with_fusion_ops(
hidden_states: torch.Tensor,
router_logits: torch.Tensor,
top_k: int,
use_grouped_topk: bool,
renormalize: bool,
e_score_correction_bias: Optional[torch.Tensor],
topk_group: Optional[int],
num_expert_group: Optional[int],
custom_routing_function: Optional[Callable] = None,
scoring_func: str = "softmax",
routed_scaling_factor=1.0,
global_num_experts: int = -1):
topk_weights, topk_ids, row_idx = None, None, None
# NOTE: now npu_moe_gating_top_k can only support 'group_count=256' pattern
is_deepseek_v3_r1 = global_num_experts == 256
if is_deepseek_v3_r1:
topk_weights, topk_ids, _ = torch_npu.npu_moe_gating_top_k(
router_logits,
k=top_k, # topk currently 8
bias=e_score_correction_bias,
k_group=topk_group, # fix: 4
group_count=num_expert_group, # fix 8
group_select_mode=
1, # 0: the maximum in the group; 1: topk2.sum(fix)
renorm=0, # 0: softmax->topk(fix); 1: topk->softmax
norm_type=1, # 0: softmax; 1: sigmoid(fix)
# out_flag=False, # todo new api; should the third output be output
# y2_flag=False, # old api; should the third output be output
routed_scaling_factor=1,
eps=float(1e-20))
row_idx = return_row_idx(hidden_states, top_k)
if not use_grouped_topk and custom_routing_function is None and scoring_func == "softmax":
topk_weights, topk_ids, row_idx = torch_npu.npu_moe_gating_top_k_softmax(
x=router_logits, finished=None, k=top_k)
topk_ids = topk_ids.to(torch.int32)
topk_weights = _renormalize_topk_weights(topk_weights, renormalize)
return topk_weights, topk_ids, row_idx
def _native_select_experts(
hidden_states: torch.Tensor,
router_logits: torch.Tensor,
top_k: int,
use_grouped_topk: bool,
renormalize: bool,
topk_group: Optional[int] = None,
num_expert_group: Optional[int] = None,
custom_routing_function: Optional[Callable] = None,
scoring_func: str = "softmax",
e_score_correction_bias: Optional[torch.Tensor] = None,
global_num_experts: Optional[torch.Tensor] = None
) -> tuple[torch.Tensor, torch.Tensor]:
"""
Select top-k experts based on router logits.
Args:
hidden_states: Hidden states of shape (num_tokens, hidden_size).
router_logits: Router logits of shape (num_tokens, num_experts).
top_k: Number of experts to select.
use_grouped_topk: Whether to group experts before selecting top-k.
renormalize: Whether to renormalize the routing weights.
topk_group: Number of expert groups to select from.
num_expert_group: Number of experts in each group.
custom_routing_function: Custom routing function.
scoring_func: Scoring function to use.
e_score_correction_bias: Correction bias to apply to expert scores.
Returns:
topk_weights: Routing weights of shape (num_tokens, top_k).
topk_ids: Selected expert IDs of shape (num_tokens, top_k).
Raises:
ValueError: If an unsupported scoring function is provided.
"""
if scoring_func == "softmax":
topk_weights = router_logits.softmax(dim=-1)
elif scoring_func == "sigmoid":
topk_weights = router_logits.sigmoid()
else:
raise ValueError(f"Unsupported scoring function: {scoring_func}")
if use_grouped_topk:
return _select_expert_use_group_topk(
topk_weights=topk_weights,
top_k=top_k,
renormalize=renormalize,
topk_group=topk_group,
num_expert_group=num_expert_group,
e_score_correction_bias=e_score_correction_bias)
if custom_routing_function is not None:
topk_weights, topk_ids = custom_routing_function(
hidden_states=hidden_states,
gating_output=router_logits,
topk=top_k,
renormalize=renormalize,
global_num_experts=global_num_experts)
# Required by npu_moe_init_routing
topk_ids = topk_ids.to(torch.int32)
return topk_weights, topk_ids
topk_weights, topk_ids = topk_weights.topk(top_k, dim=-1)
topk_weights = topk_weights.to(hidden_states.dtype)
# Required by npu_moe_init_routing
topk_ids = topk_ids.to(torch.int32)
topk_weights = _renormalize_topk_weights(topk_weights, renormalize)
return topk_weights, topk_ids

View File

@@ -0,0 +1,240 @@
# Copyright (c) 2025 Huawei Technologies Co., Ltd. All Rights Reserved.
# Copyright 2023 The vLLM team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# This file is a part of the vllm-ascend project.
from abc import ABC, abstractmethod
import torch
import torch.distributed as dist
import torch.nn as nn
from vllm.distributed import tensor_model_parallel_all_reduce
from vllm.distributed.parallel_state import (
get_dp_group, get_tensor_model_parallel_rank,
get_tensor_model_parallel_world_size)
from vllm.forward_context import get_forward_context
from vllm.model_executor.layers.fused_moe import FusedMoEConfig
class FusedMoEPrepareAndFinalize(ABC):
def __init__(self, moe_config: FusedMoEConfig):
self.moe_config = moe_config
@abstractmethod
def prepare(self,
hidden_states: torch.Tensor,
router_logits: torch.Tensor,
enable_shared_expert_dp: bool = False,
rm_router_logits: bool = False,
replace_allreduce: bool = False,
gate=None) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
raise NotImplementedError("Prepare not implemented.")
def finalize(self, hidden_states: torch.Tensor,
reduce_results: bool) -> torch.Tensor:
raise NotImplementedError("Combine function not implemented.")
class FusedMoEPrepareAndFinalizeWithMC2(FusedMoEPrepareAndFinalize):
def __init__(self, moe_config: FusedMoEConfig):
super().__init__(moe_config)
self._restore_tp_across_dp()
def _restore_tp_across_dp(self):
# NOTE: Since vLLM flatten tp across dp, we need to restore the original
# tp_size and tp_rank.
self.tp_size = get_tensor_model_parallel_world_size()
self.tp_rank = get_tensor_model_parallel_rank()
def prepare(self,
hidden_states: torch.Tensor,
router_logits: torch.Tensor,
enable_shared_expert_dp: bool = False,
rm_router_logits: bool = False,
replace_allreduce: bool = False,
gate=None) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
"""The target_pad_length is calculated in forward_context, here we pad the
hidden states and router logits. And if TP size > 1, we also need to split
the tensors accordingly.
"""
self.replace_allreduce = replace_allreduce
self.enable_shared_expert_dp = enable_shared_expert_dp
if not self.replace_allreduce:
self.num_tokens, _ = hidden_states.shape
forward_context = get_forward_context()
mc2_mask = forward_context.mc2_mask
target_pad_length = forward_context.padded_num_tokens
pad_size = target_pad_length - self.num_tokens
if pad_size > 0 and not self.enable_shared_expert_dp:
hidden_states = nn.functional.pad(hidden_states,
(0, 0, 0, pad_size))
router_logits = nn.functional.pad(router_logits,
(0, 0, 0, pad_size))
if self.tp_size > 1:
if not self.enable_shared_expert_dp:
split_hidden_states = torch.tensor_split(hidden_states,
self.tp_size,
dim=0)
split_router_logits = torch.tensor_split(router_logits,
self.tp_size,
dim=0)
hidden_states = split_hidden_states[self.tp_rank]
router_logits = split_router_logits[self.tp_rank]
self.split_hidden_states = split_hidden_states
split_mc2_mask = torch.tensor_split(mc2_mask,
self.tp_size,
dim=0)
mc2_mask = split_mc2_mask[self.tp_rank]
return hidden_states, router_logits, mc2_mask
def finalize(self, hidden_states: torch.Tensor,
reduce_results: bool) -> torch.Tensor:
"""If TP size > 1, all-gather the hidden states to get the final output.
Also, unpad the hidden states if needed.
"""
if not (self.enable_shared_expert_dp or self.replace_allreduce):
if self.tp_size > 1:
dist.all_gather(list(self.split_hidden_states), hidden_states,
self.moe_config.tp_group.device_group)
hidden_states = torch.cat(self.split_hidden_states, dim=0)
if self.num_tokens < hidden_states.shape[0]:
hidden_states = hidden_states[:self.num_tokens]
return hidden_states
class FusedMoEPrepareAndFinalizeWithAll2All(FusedMoEPrepareAndFinalize):
def __init__(self, moe_config: FusedMoEConfig):
super().__init__(moe_config)
self._restore_tp_across_dp()
def _restore_tp_across_dp(self):
# NOTE: Since vLLM flatten tp across dp, we need to restore the original
# tp_size and tp_rank.
self.tp_size = get_tensor_model_parallel_world_size()
self.tp_rank = get_tensor_model_parallel_rank()
def prepare(self,
hidden_states: torch.Tensor,
router_logits: torch.Tensor,
enable_shared_expert_dp: bool = False,
rm_router_logits: bool = False,
replace_allreduce: bool = False,
gate=None) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
self.replace_allreduce = replace_allreduce
self.enable_shared_expert_dp = enable_shared_expert_dp
if not (self.replace_allreduce or self.enable_shared_expert_dp):
self.num_tokens, _ = hidden_states.shape
pad_size = self.tp_size - self.num_tokens
if pad_size > 0:
hidden_states = nn.functional.pad(hidden_states,
(0, 0, 0, pad_size))
router_logits = nn.functional.pad(router_logits,
(0, 0, 0, pad_size))
if self.tp_size > 1:
split_hidden_states = torch.tensor_split(hidden_states,
self.tp_size,
dim=0)
split_router_logits = torch.tensor_split(router_logits,
self.tp_size,
dim=0)
self.split_hidden_states = split_hidden_states
hidden_states = split_hidden_states[self.tp_rank]
router_logits = split_router_logits[self.tp_rank]
return hidden_states, router_logits, None
def finalize(self, hidden_states: torch.Tensor,
reduce_results: bool) -> torch.Tensor:
"""If TP size > 1, all-gather the hidden states to get the final output.
Also, unpad the hidden states if needed.
"""
if not (self.enable_shared_expert_dp or self.replace_allreduce):
if self.tp_size > 1:
dist.all_gather(list(self.split_hidden_states), hidden_states,
self.moe_config.tp_group.device_group)
hidden_states = torch.cat(self.split_hidden_states, dim=0)
if self.num_tokens < hidden_states.shape[0]:
hidden_states = hidden_states[:self.num_tokens]
return hidden_states
class FusedMoEPrepareAndFinalizeWithAllGather(FusedMoEPrepareAndFinalize):
def prepare(self,
hidden_states: torch.Tensor,
router_logits: torch.Tensor,
enable_shared_expert_dp: bool = False,
rm_router_logits: bool = False,
replace_allreduce: bool = False,
gate=None) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
"""When DP size > 1, pad the hidden states and router logits for communication."""
self.rm_router_logits = rm_router_logits
self.enable_shared_expert_dp = enable_shared_expert_dp
if self.moe_config.dp_size > 1:
forward_context = get_forward_context()
max_tokens_across_dp = forward_context.max_tokens_across_dp
self.num_tokens = hidden_states.shape[0]
pad_size = max_tokens_across_dp - self.num_tokens
if pad_size > 0:
hidden_states = nn.functional.pad(hidden_states,
(0, 0, 0, pad_size))
if not self.rm_router_logits:
router_logits = nn.functional.pad(router_logits,
(0, 0, 0, pad_size))
hidden_states = self.moe_config.dp_group.all_gather(
hidden_states, 0)
if self.rm_router_logits:
router_logits, _ = gate(hidden_states)
else:
router_logits = self.moe_config.dp_group.all_gather(
router_logits, 0)
return hidden_states, router_logits, None
def finalize(self, hidden_states: torch.Tensor,
reduce_results: bool) -> torch.Tensor:
"""When DP size > 1, reduce-scatter the hidden states to get the final output.
When TP size > 1, all-reduce the hidden states to get the final output.
"""
if self.moe_config.dp_size > 1 and not self.enable_shared_expert_dp:
hidden_states = get_dp_group().reduce_scatter(hidden_states, 0)
hidden_states = hidden_states[:self.num_tokens]
if reduce_results and (self.moe_config.tp_size > 1
or self.moe_config.ep_size > 1):
hidden_states = tensor_model_parallel_all_reduce(hidden_states)
return hidden_states

View File

@@ -0,0 +1,298 @@
# Copyright (c) 2025 Huawei Technologies Co., Ltd. All Rights Reserved.
# Copyright 2023 The vLLM team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# This file is a part of the vllm-ascend project.
from abc import ABC, abstractmethod
from typing import Any, Optional
import torch
from vllm.forward_context import get_forward_context
from vllm.model_executor.layers.fused_moe import FusedMoEConfig
from vllm_ascend.ops.moe.fused_moe_prepare_and_finalize import (
FusedMoEPrepareAndFinalizeWithAll2All,
FusedMoEPrepareAndFinalizeWithAllGather, FusedMoEPrepareAndFinalizeWithMC2)
from vllm_ascend.ops.moe.moe_mlp import unified_apply_mlp
from vllm_ascend.ops.moe.token_dispatcher import (TokenDispatcherWithAll2AllV,
TokenDispatcherWithAllGather,
TokenDispatcherWithMC2)
class MoECommMethod(ABC):
"""Base class for MoE communication methods."""
def __init__(self, moe_config: FusedMoEConfig):
self.moe_config = moe_config
self.mc2_mask = None
self.token_dispatcher = self._get_token_dispatcher()
self.fused_moe_prepare_finalize = self._get_fused_moe_prepare_finalize(
)
def prepare(self,
hidden_states: torch.Tensor,
router_logits: torch.Tensor,
enable_shared_expert_dp: bool = False,
rm_router_logits: bool = False,
replace_allreduce: bool = False,
gate=None) -> tuple[torch.Tensor, torch.Tensor]:
hidden_states, router_logits, mc2_mask = self.fused_moe_prepare_finalize.prepare(
hidden_states, router_logits, enable_shared_expert_dp,
rm_router_logits, replace_allreduce, gate)
self.mc2_mask = mc2_mask
return hidden_states, router_logits
def finalize(self, hidden_states: torch.Tensor,
reduce_results: bool) -> torch.Tensor:
hidden_states = self.fused_moe_prepare_finalize.finalize(
hidden_states, reduce_results)
return hidden_states
def fused_experts(
self,
hidden_states: torch.Tensor,
w1: torch.Tensor,
w2: torch.Tensor,
topk_weights: torch.Tensor,
topk_ids: torch.Tensor,
row_idx: torch.Tensor,
activation: str = "silu",
apply_router_weight_on_input: bool = False,
use_int8_w8a8: bool = False,
use_int4_w4a8: bool = False,
global_num_experts: Optional[int] = None,
expert_map: Optional[torch.Tensor] = None,
w1_scale: Optional[torch.Tensor] = None,
w2_scale: Optional[torch.Tensor] = None,
w1_scale_bias: torch.Tensor = None,
w2_scale_bias: torch.Tensor = None,
# For TorchAir graph
is_torchair: bool = False,
# For Cube/Vector parallel
shared_experts: Optional[Any] = None,
shared_gate_up: Optional[Any] = None,
shared_dequant_scale: Optional[Any] = None,
quantized_x_for_share: Optional[Any] = None,
dynamic_scale_for_share: Optional[Any] = None,
# For load balance
log2phy: torch.Tensor = None,
global_redundant_expert_num: int = 0,
need_trans: bool = False) -> torch.Tensor:
# Check constraints
assert hidden_states.shape[1] == w1.shape[1], (
f"Hidden size mismatch {hidden_states.shape[1]} != {w1.shape[1]}")
assert topk_weights.shape == topk_ids.shape, "topk shape mismatch"
assert hidden_states.is_contiguous(
), "Hidden_states must be contiguous"
assert w1.stride(-1) == 1, "Stride of last dimension must be 1"
assert w2.stride(-1) == 1, "Stride of last dimension must be 1"
assert hidden_states.dtype in [
torch.float32, torch.float16, torch.bfloat16
]
moe_comm_method = get_forward_context().moe_comm_method
assert moe_comm_method is not None, "Missing communication context"
results = self.token_dispatcher.token_dispatch(
hidden_states=hidden_states,
topk_weights=topk_weights,
topk_ids=topk_ids,
row_idx=row_idx,
expert_map=expert_map,
log2phy=log2phy,
global_redundant_expert_num=global_redundant_expert_num,
shared_experts=shared_experts,
shared_gate_up=shared_gate_up,
shared_dequant_scale=shared_dequant_scale,
mc2_mask=self.mc2_mask,
apply_router_weight_on_input=apply_router_weight_on_input,
with_quant=use_int8_w8a8 or use_int4_w4a8)
permuted_hidden_states, expert_tokens, dynamic_scale, group_list_type = \
results["hidden_states"], results["group_list"], results.get("dynamic_scale"), results["group_list_type"]
mlp_output = unified_apply_mlp(hidden_states=permuted_hidden_states,
w1=w1,
w1_scale=w1_scale,
w2=w2,
w2_scale=w2_scale,
group_list=expert_tokens,
dynamic_scale=dynamic_scale,
group_list_type=group_list_type,
w1_scale_bias=w1_scale_bias,
w2_scale_bias=w2_scale_bias,
with_quant=use_int8_w8a8
or use_int4_w4a8,
need_trans=need_trans)
hidden_states[:] = self.token_dispatcher.token_combine(
hidden_states=mlp_output)
return hidden_states
@abstractmethod
def _get_token_dispatcher(self):
raise NotImplementedError(
"_get_token_dispatcher function not implemented.")
@abstractmethod
def _get_fused_moe_prepare_finalize(self):
raise NotImplementedError(
"_get_fused_moe_prepare_finalize function not implemented.")
class AllGatherCommImpl(MoECommMethod):
"""This implementation is the same as NativeAllGatherCommImpl,
but uses NPU-specific ops for better performance.
This implementation should be compatible with all scenarios, and
thus it is the default implementation for MoE communication methods.
It uses `torch_npu.npu_moe_init_routing_v2` for pre-processing
and `torch_npu.npu_moe_token_unpermute` for post-processing
to handle the token-to-expert mapping and communication efficiently.
NOTE(Yizhou): TBH, it is really weird that we were supposed to use
`torch_npu.npu_moe_init_routing_v2` and `torch_npu.npu_moe_finalize_routing`
or `torch_npu.npu_moe_token_permute` and `torch_npu.npu_moe_token_unpermute`
for pre-processing and post-processing, respectively.
But `npu_moe_finalize_routing` will lead to accuracy issues so we have to
use `torch_npu.npu_moe_token_unpermute` instead.
This is a workaround and should be removed after the issue is fixed.
"""
def _get_token_dispatcher(self):
return TokenDispatcherWithAllGather(
top_k=self.moe_config.experts_per_token,
num_experts=self.moe_config.num_experts,
num_local_experts=self.moe_config.num_local_experts)
def _get_fused_moe_prepare_finalize(self):
return FusedMoEPrepareAndFinalizeWithAllGather(self.moe_config)
class NativeAllGatherCommImpl(AllGatherCommImpl):
"""This implementation should be compatible with all scenarios.
Note that this implementation purely consists of native PyTorch ops
and does not use any NPU-specific ops. So the performance may not be optimal.
But it is a good fallback for scenarios where NPU-specific ops are not available.
"""
def permute(
self,
hidden_states: torch.Tensor,
topk_ids: torch.Tensor,
topk_weights: torch.Tensor,
expert_map: torch.Tensor,
num_experts: int,
apply_a8_quantization: bool,
) -> tuple[torch.Tensor, torch.Tensor, Optional[torch.Tensor], int]:
num_tokens = hidden_states.shape[0]
# Generate token indices and flatten
token_indices = torch.arange(num_tokens,
device=hidden_states.device,
dtype=torch.int64)
token_indices = (token_indices.unsqueeze(1).expand(
-1, self.moe_config.experts_per_token).reshape(-1))
# Flatten token-to-expert mappings and map to local experts
weights_flat = topk_weights.view(-1)
experts_flat = topk_ids.view(-1)
local_experts_flat = (expert_map[experts_flat]
if expert_map is not None else experts_flat)
# Filter valid token-expert pairs
mask = local_experts_flat != -1
# FIXME: npu_grouped_matmul output random values at [num_valid_tokens:, ...]
# So we need to filter out invalid tokens by zeroing their weights.
# This is a workaround and should be removed after the issue is fixed
filtered_weights = torch.where(mask, weights_flat,
torch.zeros_like(weights_flat)).to(
topk_weights.dtype)
filtered_experts = torch.where(
mask,
local_experts_flat,
torch.full_like(local_experts_flat, num_experts),
).to(topk_ids.dtype)
# Sort by local expert IDs
sort_indices = torch.argsort(filtered_experts.view(torch.float32))
self.sorted_token_indices = token_indices[sort_indices]
self.sorted_weights = filtered_weights[sort_indices]
# Compute token counts with minlength of num_experts
# This is equivalent to but faster than:
# >>> token_counts = torch.bincount(filtered_experts, minlength=num_experts)[:-1]
token_counts = torch.zeros(num_experts + 1,
device=hidden_states.device,
dtype=torch.int64)
ones = torch.ones_like(filtered_experts, dtype=torch.int64)
token_counts.scatter_add_(0, filtered_experts.to(torch.int64), ones)
expert_tokens = token_counts[:num_experts]
# Rearrange hidden_states
permuted_hidden_states = hidden_states[self.sorted_token_indices]
group_list_type = 1 # `count` mode
return permuted_hidden_states, expert_tokens, None, group_list_type
def unpermute(self, mlp_output: torch.Tensor,
hidden_states: torch.Tensor) -> None:
mlp_output = mlp_output * self.sorted_weights.unsqueeze(1)
final_hidden_states = torch.zeros_like(hidden_states)
final_hidden_states.index_add_(0, self.sorted_token_indices,
mlp_output)
hidden_states[:] = final_hidden_states
class MC2CommImpl(MoECommMethod):
"""This implementation is for the scenarios listed below:
1. `enable_expert_parallel=True`.
2. `npu_moe_distribute_dispatch` and `npu_moe_distribute_combine` are available.
3. `enable_expert_parallel=False` is not supported.
This implementation uses the MC2 communication method, which is optimized for
Communication and Computation parallelism on Ascend devices.
"""
def _get_token_dispatcher(self):
return TokenDispatcherWithMC2()
def _get_fused_moe_prepare_finalize(self):
return FusedMoEPrepareAndFinalizeWithMC2(self.moe_config)
class AlltoAllCommImpl(MoECommMethod):
"""This implementation is for the scenarios listed below:
1. `enable_expert_parallel=True`.
2. `npu_grouped_matmul` is available.
This implementation uses all-to-all communication to exchange tokens
between data parallel ranks before and after the MLP computation. It should
have better performance than AllGatherCommImpl when DP size > 1.
"""
def _get_token_dispatcher(self):
return TokenDispatcherWithAll2AllV(
top_k=self.moe_config.experts_per_token,
num_experts=self.moe_config.num_experts,
num_local_experts=self.moe_config.num_local_experts)
def _get_fused_moe_prepare_finalize(self):
return FusedMoEPrepareAndFinalizeWithAll2All(self.moe_config)

View File

@@ -0,0 +1,252 @@
# Copyright (c) 2025 Huawei Technologies Co., Ltd. All Rights Reserved.
# Copyright 2023 The vLLM team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# This file is a part of the vllm-ascend project.
from typing import Optional
import torch
import torch_npu
from torch.nn.functional import pad
from vllm.forward_context import get_forward_context
from vllm_ascend.ascend_forward_context import FusedMoEState
from vllm_ascend.utils import dispose_tensor, is_310p
def cumsum_group_list(group_list: torch.Tensor,
group_list_type: int,
active_num: int = 0,
expert_num: int = 0) -> torch.Tensor:
if group_list_type not in [0, 1, 2]:
raise ValueError(
f"group_list_type should be in [0, 1, 2], but received {group_list_type}"
)
if group_list_type == 0:
return group_list
if group_list_type == 1:
return group_list.cumsum(dim=0)
experts = pad(group_list[:, 0], (1, 0))
tokens = pad(group_list[:, 1].cumsum(dim=0), (1, 0))
cumsum_group_list = torch.full(size=(expert_num, ),
fill_value=active_num,
dtype=group_list.dtype,
device=group_list.device)
for i, (start, end) in enumerate(zip(experts[:-1], experts[1:])):
if end > start:
cumsum_group_list[start:end] = tokens[i]
return cumsum_group_list
def quant_apply_mlp(hidden_states: torch.Tensor,
w1: torch.Tensor,
w1_scale: torch.Tensor,
w2: torch.Tensor,
w2_scale: torch.Tensor,
group_list: torch.Tensor,
group_list_type: int = 1,
dynamic_scale: torch.Tensor = None,
w1_scale_bias: torch.Tensor = None,
w2_scale_bias: torch.Tensor = None,
fusion: bool = False) -> torch.Tensor:
if dynamic_scale is None:
unquantized_hidden_states = hidden_states
hidden_states, pertoken_scale = torch_npu.npu_dynamic_quant(
hidden_states)
# Dispose the original unquantized hidden states
# to save npu memory because they're no longer used.
dispose_tensor(unquantized_hidden_states)
else:
pertoken_scale = dynamic_scale
bias1, bias2 = None, None
_output_dtype = w2_scale.dtype
is_mc2 = get_forward_context().fused_moe_state == FusedMoEState.MC2
if w1_scale_bias is None and is_mc2:
if w1_scale.dtype != torch.float32:
w1_scale = w1_scale.to(torch.float32)
if fusion:
# gmm1: gate_up_proj & act_fn: swiglu
hidden_states, swiglu_out_scale, _ = torch_npu.npu_grouped_matmul_swiglu_quant(
x=hidden_states,
weight=w1,
group_list=cumsum_group_list(group_list, group_list_type),
weight_scale=w1_scale,
x_scale=pertoken_scale)
else:
# gmm1: gate_up_proj
hidden_states = torch_npu.npu_grouped_matmul(
x=[hidden_states],
weight=[w1],
split_item=3,
group_list_type=group_list_type,
group_type=0,
group_list=group_list,
output_dtype=torch.int32)[0]
# act_fn: swiglu
hidden_states, swiglu_out_scale = torch_npu.npu_dequant_swiglu_quant(
x=hidden_states,
weight_scale=w1_scale,
activation_scale=pertoken_scale,
bias=None,
quant_scale=None,
quant_offset=None,
group_index=group_list,
activate_left=True,
quant_mode=1,
)
# gmm2: down_proj
hidden_states = torch_npu.npu_grouped_matmul(
x=[hidden_states],
weight=[w2],
scale=[w2_scale],
per_token_scale=[swiglu_out_scale],
split_item=2,
group_list_type=group_list_type,
group_type=0,
group_list=group_list,
output_dtype=w2_scale.dtype)[0]
else:
if w1_scale_bias is not None:
if group_list_type == 0:
group_list = torch.cat(
[group_list[:1],
torch.diff(group_list, dim=0)])
group_list_type = 1
bias1 = [w1_scale_bias] if not fusion else w1_scale_bias
bias2 = [w2_scale_bias]
# TODO w4a8 scene: dynamic acquisition of dtype in the future
_output_dtype = torch.bfloat16
if fusion:
# gmm1: gate_up_proj & act_fn: swiglu
hidden_states, swiglu_out_scale, _ = torch_npu.npu_grouped_matmul_swiglu_quant(
x=hidden_states,
weight=w1,
bias=bias1,
group_list=cumsum_group_list(group_list, group_list_type),
weight_scale=w1_scale,
x_scale=pertoken_scale)
else:
# gmm1: gate_up_proj
hidden_states = torch_npu.npu_grouped_matmul(
x=[hidden_states],
weight=[w1],
scale=[w1_scale.to(w2_scale.dtype)],
bias=bias1,
per_token_scale=[pertoken_scale],
split_item=2,
group_list_type=group_list_type,
group_type=0,
group_list=group_list,
output_dtype=_output_dtype)[0]
# act_fn: swiglu
hidden_states = torch_npu.npu_swiglu(hidden_states)
hidden_states, swiglu_out_scale = torch_npu.npu_dynamic_quant(
hidden_states)
# gmm2: down_proj
hidden_states = torch_npu.npu_grouped_matmul(
x=[hidden_states],
weight=[w2],
scale=[w2_scale],
bias=bias2,
per_token_scale=[swiglu_out_scale],
split_item=2,
group_list_type=group_list_type,
group_type=0,
group_list=group_list,
output_dtype=_output_dtype)[0]
return hidden_states
def unquant_apply_mlp(hidden_states: torch.Tensor,
w1: torch.Tensor,
w2: torch.Tensor,
group_list: torch.Tensor,
group_list_type: int = 1,
topk_scales: Optional[torch.Tensor] = None,
need_trans: bool = True) -> torch.Tensor:
if need_trans:
w1 = w1.transpose(1, 2)
w2 = w2.transpose(1, 2)
gate_up_out = torch_npu.npu_grouped_matmul(
x=[hidden_states],
weight=[w1],
split_item=2,
group_list_type=group_list_type,
group_type=0,
group_list=group_list,
)[0]
if is_310p():
gate_up_out = torch_npu.npu_swiglu(gate_up_out.to(torch.float32)).to(
torch.float16)
else:
gate_up_out = torch_npu.npu_swiglu(gate_up_out)
if topk_scales is not None:
gate_up_out *= topk_scales
hidden_states = torch_npu.npu_grouped_matmul(
x=[gate_up_out],
weight=[w2],
split_item=2,
group_list_type=group_list_type,
group_type=0,
group_list=group_list,
)[0]
return hidden_states
def unified_apply_mlp(hidden_states: torch.Tensor,
w1: torch.Tensor,
w1_scale: torch.Tensor,
w2: torch.Tensor,
w2_scale: torch.Tensor,
group_list: torch.Tensor,
dynamic_scale: torch.Tensor = None,
group_list_type: int = 1,
w1_scale_bias: torch.Tensor = None,
w2_scale_bias: torch.Tensor = None,
topk_scales: Optional[torch.Tensor] = None,
with_quant: bool = False,
fusion: bool = False,
need_trans: bool = True) -> torch.Tensor:
if with_quant:
return quant_apply_mlp(hidden_states=hidden_states,
w1=w1,
w1_scale=w1_scale,
w2=w2,
w2_scale=w2_scale,
group_list=group_list,
dynamic_scale=dynamic_scale,
group_list_type=group_list_type,
w1_scale_bias=w1_scale_bias,
w2_scale_bias=w2_scale_bias,
fusion=fusion)
else:
return unquant_apply_mlp(hidden_states=hidden_states,
w1=w1,
w2=w2,
group_list=group_list,
group_list_type=group_list_type,
topk_scales=topk_scales,
need_trans=need_trans)

View File

@@ -0,0 +1,726 @@
# SPDX-License-Identifier: Apache-2.0
# Copyright (c) 2024; NVIDIA CORPORATION. All rights reserved.
# Copyright (c) 2025 Huawei Technologies Co., Ltd. All Rights Reserved.
# Copyright 2023 The vLLM team.
# Copyright 2023 DeepSeek-AI and the HuggingFace Inc. team. All rights reserved.
#
# This code is based on EleutherAI's GPT-NeoX library and the GPT-NeoX
# and OPT implementations in this library. It has been modified from its
# original forms to accommodate minor architectural differences compared
# to GPT-NeoX and OPT used by the Meta AI team that trained the model.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from abc import ABC, abstractmethod
from typing import Any, Dict, Optional
import torch
import torch_npu
from vllm.distributed.parallel_state import get_ep_group
import vllm_ascend.envs as envs_ascend
from vllm_ascend.distributed.parallel_state import get_mc2_group
from vllm_ascend.distributed.tensor_parallel import \
gather_from_sequence_parallel_region
from vllm_ascend.ops.comm_utils import async_all_to_all
from vllm_ascend.utils import AscendSocVersion, get_ascend_soc_version
_Dispatchers: Dict[str, Any] = {}
def _register_token_dispatcher(dispatcher: Any):
_Dispatchers[dispatcher.__class__.__name__] = dispatcher
def get_token_dispatcher(name: str):
return _Dispatchers.get(name)
def setup_token_dispatchers(ep_size: int, **kwargs):
existing_dispatchers = set(_Dispatchers.keys())
if ep_size == 1 and "TokenDispatcherWithAllGather" not in existing_dispatchers:
_register_token_dispatcher(TokenDispatcherWithAllGather(**kwargs))
elif envs_ascend.VLLM_ENABLE_FUSED_EXPERTS_ALLGATHER_EP and ep_size > 1 \
and "TokenDispatcherWithAllGather" not in existing_dispatchers:
_register_token_dispatcher(TokenDispatcherWithAllGather(**kwargs))
elif ep_size < 16 and "TokenDispatcherWithAll2AllV" not in existing_dispatchers:
_register_token_dispatcher(TokenDispatcherWithAll2AllV(**kwargs))
elif ep_size >= 16:
if "TokenDispatcherWithAll2AllV" not in existing_dispatchers:
_register_token_dispatcher(TokenDispatcherWithAll2AllV(**kwargs))
if "TokenDispatcherWithMC2" not in existing_dispatchers:
_register_token_dispatcher(TokenDispatcherWithMC2(**kwargs))
class MoETokenDispatcher(ABC):
def __init__(self, **kwargs) -> None:
"""
Initialize the MoE Token Dispatcher.
"""
self.top_k = kwargs.get("top_k", 0)
self.num_experts = kwargs.get("num_experts", 0)
@property
def ep_group(self):
"""Get expert model parallel group."""
return get_ep_group().device_group
@property
def ep_rank(self):
return get_ep_group().rank_in_group
@property
def ep_size(self):
return get_ep_group().world_size
@abstractmethod
def token_dispatch(self,
hidden_states: torch.Tensor,
topk_weights: torch.Tensor,
topk_ids: torch.Tensor,
row_idx: torch.Tensor,
expert_map: Optional[torch.Tensor] = None,
log2phy: Optional[torch.Tensor] = None,
global_redundant_expert_num: int = 0,
shared_experts: Optional[torch.Tensor] = None,
shared_gate_up: Optional[torch.Tensor] = None,
shared_dequant_scale: Optional[torch.Tensor] = None,
mc2_mask: Optional[torch.Tensor] = None,
apply_router_weight_on_input: bool = False,
with_quant: bool = False):
raise NotImplementedError("Dispatch function not implemented.")
@abstractmethod
def token_combine(self,
hidden_states: torch.Tensor,
bias: torch.Tensor = None):
raise NotImplementedError("Combine function not implemented.")
class TokenDispatcherWithMC2(MoETokenDispatcher):
def __init__(self, **kwargs):
super().__init__(**kwargs)
device_group = get_mc2_group().device_group
# TODO: Try local_rank = ep_group.rank_in_group
local_rank = torch.distributed.get_rank(group=device_group)
backend = device_group._get_backend(torch.device("npu"))
self.moe_all_to_all_group_name = backend.get_hccl_comm_name(local_rank)
self.ep_rank_id = get_mc2_group().rank_in_group
self.ep_world_size = get_mc2_group().world_size
self.enable_dispatch_v2 = hasattr(torch_npu,
"npu_moe_distribute_dispatch_v2")
self.need_extra_args = (
get_ascend_soc_version() == AscendSocVersion.A3)
# NOTE: Currently, when in A3, we need to pass in some extra param into dispatch & combine
self.a3_need_extra_args = \
get_ascend_soc_version() == AscendSocVersion.A3
self.output = None
self.assist_info_for_combine = None
self.ep_recv_counts = None
self.shared_act = None
self.topk_ids = None
self.topk_weights = None
self.shared_experts = None
self.mc2_mask = None
self.with_quant = False
def get_dispatch_mc2_kwargs(
self,
hidden_states: torch.Tensor,
topk_weights: torch.Tensor,
topk_ids: torch.Tensor,
expert_map: torch.Tensor,
global_redundant_expert_num: int = 0,
):
if self.with_quant:
quant_mode = 2
if (expert_map is not None):
moe_expert_num = len(expert_map) + global_redundant_expert_num
else:
moe_expert_num = global_redundant_expert_num
else:
quant_mode = 0
moe_expert_num = len(expert_map)
kwargs_mc2 = {
"x": hidden_states,
"expert_ids": topk_ids,
"expert_shard_type": 0,
"shared_expert_rank_num": 0,
"moe_expert_num": moe_expert_num,
"global_bs": 0,
}
stage1_kwargs = {
"scales": None,
"quant_mode": quant_mode,
"group_ep": self.moe_all_to_all_group_name,
"ep_world_size": self.ep_world_size,
"ep_rank_id": self.ep_rank_id,
}
if self.need_extra_args:
stage1_kwargs.update({
"group_tp": self.moe_all_to_all_group_name,
"tp_world_size": 1,
"tp_rank_id": 0,
})
if self.a3_need_extra_args and self.enable_dispatch_v2:
stage1_kwargs.update({
"x_active_mask": self.mc2_mask,
})
kwargs_mc2.update(stage1_kwargs)
return kwargs_mc2
def token_dispatch(self,
hidden_states: torch.Tensor,
topk_weights: torch.Tensor,
topk_ids: torch.Tensor,
row_idx: torch.Tensor,
expert_map: Optional[torch.Tensor] = None,
log2phy: Optional[torch.Tensor] = None,
global_redundant_expert_num: int = 0,
shared_experts: Optional[torch.Tensor] = None,
shared_gate_up: Optional[torch.Tensor] = None,
shared_dequant_scale: Optional[torch.Tensor] = None,
mc2_mask: Optional[torch.Tensor] = None,
apply_router_weight_on_input: bool = False,
with_quant: bool = False):
self.with_quant = with_quant
self.expert_map = expert_map
self.topk_ids = topk_ids
self.topk_weights = topk_weights
self.shared_experts = shared_experts
self.mc2_mask = mc2_mask
kwargs_mc2 = self.get_dispatch_mc2_kwargs(hidden_states, topk_weights,
topk_ids, expert_map,
global_redundant_expert_num)
self.output = torch_npu.npu_moe_distribute_dispatch_v2(
**kwargs_mc2
) if self.enable_dispatch_v2 else torch_npu.npu_moe_distribute_dispatch(
**kwargs_mc2)
# comm_stream.wait_stream(torch.npu.current_stream())
expand_x, dynamic_scale, self.assist_info_for_combine, \
expert_token_nums, self.ep_recv_counts = self.output[0:5]
if self.with_quant:
if shared_experts is not None:
shared_act_out = shared_experts.act_fn(
(shared_gate_up, shared_dequant_scale))
self.shared_act, self.swiglu_out_scale = \
shared_act_out[0], shared_act_out[1]
else:
if shared_experts is not None:
shared_gate_up, _ = shared_experts.gate_up_proj(hidden_states)
self.shared_act = shared_experts.act_fn(shared_gate_up)
group_list_type = 1
return {
"group_list_type": group_list_type,
"hidden_states": expand_x,
"group_list": expert_token_nums,
"dynamic_scale": dynamic_scale,
}
def get_combine_mc_kwargs(self, hidden_states: torch.Tensor):
assert self.expert_map is not None
assert self.topk_weights is not None
assert self.topk_ids is not None
assert self.output is not None
moe_expert_num = len(self.expert_map)
# moeCombine
kwargs_mc2 = {
"expand_x": hidden_states,
"expert_ids": self.topk_ids,
"expert_scales": self.topk_weights.to(torch.float32),
"expert_shard_type": 0,
"shared_expert_rank_num": 0,
"moe_expert_num": moe_expert_num,
"global_bs": 0,
}
if self.with_quant:
tp_recv_counts = torch.empty(1,
dtype=torch.int32,
device=hidden_states.device)
else:
tp_recv_counts = self.output[5]
stage3_kwargs = {
"ep_send_counts": self.ep_recv_counts,
"group_ep": self.moe_all_to_all_group_name,
"ep_world_size": self.ep_world_size,
"ep_rank_id": self.ep_rank_id,
}
if self.enable_dispatch_v2:
stage3_kwargs.update({
"assist_info_for_combine":
self.assist_info_for_combine,
})
else:
stage3_kwargs.update({
"expand_idx": self.assist_info_for_combine,
})
if self.need_extra_args:
stage3_kwargs.update({
"tp_send_counts": tp_recv_counts,
"group_tp": self.moe_all_to_all_group_name,
"tp_world_size": 1,
"tp_rank_id": 0,
})
if self.a3_need_extra_args and self.enable_dispatch_v2:
stage3_kwargs.update({
"x_active_mask": self.mc2_mask,
})
kwargs_mc2.update(stage3_kwargs)
return kwargs_mc2
def token_combine(self,
hidden_states: torch.Tensor,
bias: torch.Tensor = None):
kwargs_mc2 = self.get_combine_mc_kwargs(hidden_states)
hidden_states = torch_npu.npu_moe_distribute_combine_v2(
**kwargs_mc2
) if self.enable_dispatch_v2 else torch_npu.npu_moe_distribute_combine(
**kwargs_mc2)
if self.shared_experts is None:
return hidden_states
else:
if self.with_quant:
shared_hidden_states, _ = self.shared_experts.down_proj(
(self.shared_act, self.swiglu_out_scale))
else:
shared_hidden_states, _ = self.shared_experts.down_proj(
self.shared_act)
return hidden_states, shared_hidden_states
class TokenDispatcherWithAllGather(MoETokenDispatcher):
def __init__(self, **kwargs):
super().__init__(**kwargs)
self.apply_router_weight_on_input = False
self.max_num_tokens = kwargs.get("max_num_tokens")
self.num_experts_local = kwargs.get("num_local_experts", 0)
self.sorted_weights = None
self.expanded_row_idx = None
self.sorted_token_indices = None
self.original_shape = None
self.mask = None
self.expert_map = None
self.topk_weights = None
self.topk_ids = None
self.with_quant = False
def token_dispatch(self,
hidden_states: torch.Tensor,
topk_weights: torch.Tensor,
topk_ids: torch.Tensor,
row_idx: torch.Tensor,
expert_map: Optional[torch.Tensor] = None,
log2phy: Optional[torch.Tensor] = None,
global_redundant_expert_num: int = 0,
shared_experts: Optional[torch.Tensor] = None,
shared_gate_up: Optional[torch.Tensor] = None,
shared_dequant_scale: Optional[torch.Tensor] = None,
mc2_mask: Optional[torch.Tensor] = None,
apply_router_weight_on_input: bool = False,
with_quant: bool = False):
self.with_quant = with_quant
self.original_shape = hidden_states.shape
num_tokens = hidden_states.shape[:-1].numel()
self.expert_map = expert_map
self.topk_weights = topk_weights
self.topk_ids = topk_ids
self.apply_router_weight_on_input = apply_router_weight_on_input
if self.apply_router_weight_on_input:
assert (topk_weights.dim() == 2
), "`topk_weights` should be in shape (num_tokens, topk)"
_, topk = topk_weights.shape
assert (
topk == 1
), "Only support topk=1 when `apply_router_weight_on_input` is True"
hidden_states = hidden_states * \
topk_weights.to(hidden_states.dtype)
if expert_map is not None:
global_num_experts = len(expert_map)
mask = (expert_map[topk_ids] != -1)
self.topk_weights = topk_weights * mask
first_expert_idx = get_ep_group(
).rank_in_group * self.num_experts_local
last_expert_idx = first_expert_idx + self.num_experts_local
else:
first_expert_idx = 0
last_expert_idx = self.num_experts_local
global_num_experts = self.num_experts_local
sorted_hidden_states, self.expanded_row_idx, expert_tokens, _ = (
torch_npu.npu_moe_init_routing_v2(
hidden_states,
topk_ids,
active_num=num_tokens * self.top_k,
expert_num=global_num_experts,
expert_tokens_num_type=1,
expert_tokens_num_flag=True,
active_expert_range=[first_expert_idx, last_expert_idx],
quant_mode=-1,
))
expert_tokens = expert_tokens.to(torch.int64)
group_list_type = 1 # `count` mode
return {
"group_list_type": group_list_type,
"hidden_states": sorted_hidden_states,
"group_list": expert_tokens,
}
def token_combine(self,
hidden_states: torch.Tensor,
bias: torch.Tensor = None):
assert self.original_shape is not None
final_hidden_states = torch_npu.npu_moe_token_unpermute(
permuted_tokens=hidden_states,
sorted_indices=self.expanded_row_idx,
probs=self.topk_weights)
if len(self.original_shape) == 3:
final_hidden_states = final_hidden_states.view(self.original_shape)
return final_hidden_states
# mypy: disable-error-code="override"
class UnquantizedTokenDispatcherWithFusedExpertsMoge(MoETokenDispatcher):
def __init__(self, **kwargs):
super().__init__(**kwargs)
self.apply_router_weight_on_input = False
self.local_ep = 1
self.local_num_experts = self.num_experts // self.local_ep
self.local_num_group = self.top_k // self.local_ep
self.bsz = None
def token_dispatch(self,
hidden_states: torch.Tensor,
topk_weights: torch.Tensor,
topk_ids: torch.Tensor,
row_idx: torch.Tensor,
expert_map: Optional[torch.Tensor] = None,
log2phy: Optional[torch.Tensor] = None,
global_redundant_expert_num: int = 0,
shared_experts: Optional[torch.Tensor] = None,
shared_gate_up: Optional[torch.Tensor] = None,
shared_dequant_scale: Optional[torch.Tensor] = None,
mc2_mask: Optional[torch.Tensor] = None,
apply_router_weight_on_input: bool = False,
with_quant: bool = False):
self.apply_router_weight_on_input = apply_router_weight_on_input
if self.apply_router_weight_on_input:
assert (topk_weights.dim() == 2
), "`topk_weights` should be in shape (num_tokens, topk)"
_, topk = topk_weights.shape
assert (
topk == 1
), "Only support topk=1 when `apply_router_weight_on_input` is True"
hidden_states = hidden_states * \
topk_weights.to(hidden_states.dtype)
self.bsz, _ = hidden_states.shape
flatten_topk_ids = topk_ids.view(-1)
self.sorted_topk_ids = torch.argsort(flatten_topk_ids.float())
self.sorted_topk_ids = self.sorted_topk_ids.to(torch.int32)
sorted_hidden_states = hidden_states.index_select(
0, self.sorted_topk_ids // self.local_num_group)
experts_id = torch.arange(0,
self.local_num_experts,
dtype=topk_ids.dtype,
device=topk_ids.device)
num_tokens_per_expert = (
flatten_topk_ids.unsqueeze(-1) == experts_id).to(
torch.float32).sum(0)
topk_scales = topk_weights.view(-1).index_select(
0, self.sorted_topk_ids).unsqueeze(-1)
group_list = num_tokens_per_expert.cumsum(dim=0).to(torch.int64)
group_list_type = 0
return {
"group_list_type": group_list_type,
"hidden_states": sorted_hidden_states,
"group_list": group_list,
"topk_scales": topk_scales,
}
def token_combine(self,
hidden_states: torch.Tensor,
bias: torch.Tensor = None):
unsorted_topk_ids = torch.argsort(self.sorted_topk_ids.float()).to(
torch.int32)
unsorted_hidden_states = hidden_states.index_select(
0, unsorted_topk_ids)
final_hidden_states = unsorted_hidden_states.reshape(
self.bsz, self.top_k // self.local_ep, -1).sum(1)
return final_hidden_states
class TokenDispatcherWithAll2AllV(MoETokenDispatcher):
"""
The implementation of the AlltoAll-based token dispatcher, which handles token
dispatching on the sequence level instead of token level. The core of this implementation
lies in each device dispatching on the entire sequence, with the hidden state being partitioned.
"""
def __init__(self, **kwargs):
super().__init__(**kwargs)
self.with_quant = False
self.num_local_experts = kwargs.get("num_local_experts", 0)
self.num_global_redundant_experts = kwargs.get(
"num_global_redundant_experts", 0)
self.num_experts = self.num_experts + self.num_global_redundant_experts
self.hidden_shape = None
self.topk_weights = None
self.input_splits = None
self.output_splits = None
self.hidden_shape_before_permute = None
# [tp_ep_size * ep_size, num_local_experts]. Represents the number of tokens sent
# to each local expert by all ranks.
self.num_global_tokens_per_local_expert = None
# cached intermediate tensors.
self.tokens_per_expert = None
self.global_input_tokens_local_experts_indices = None
assert self.num_local_experts > 0, "Expected at least one expert"
if self.num_local_experts > 1:
self.expert_ids_per_ep_rank = torch.tensor(
[i % self.num_local_experts for i in range(self.num_experts)],
dtype=torch.int32,
device=torch.npu.current_device(),
)
local_expert_indices_offset = (self.ep_rank * self.num_local_experts)
self.local_expert_indices = [
local_expert_indices_offset + i
for i in range(self.num_local_experts)
]
assert (len(self.local_expert_indices) == self.num_local_experts
), "Invalid local expert indices"
for i in range(len(self.local_expert_indices) - 1):
assert (self.local_expert_indices[i] ==
self.local_expert_indices[i + 1] -
1), "local_expert_indices must be continuous"
def token_dispatch(self,
hidden_states: torch.Tensor,
topk_weights: torch.Tensor,
topk_ids: torch.Tensor,
row_idx: torch.Tensor,
expert_map: Optional[torch.Tensor] = None,
log2phy: Optional[torch.Tensor] = None,
global_redundant_expert_num: int = 0,
shared_experts: Optional[torch.Tensor] = None,
shared_gate_up: Optional[torch.Tensor] = None,
shared_dequant_scale: Optional[torch.Tensor] = None,
mc2_mask: Optional[torch.Tensor] = None,
apply_router_weight_on_input: bool = False,
with_quant: bool = False):
self.with_quant = with_quant
self.hidden_shape = hidden_states.shape
self.topk_weights = topk_weights
assert topk_weights.dim() == 2, "Expected 2D tensor for topk_weights"
assert topk_ids.dim() == 2, "Expected 2D tensor for routing map"
if log2phy is not None:
topk_ids = log2phy[topk_ids]
permutated_local_input_tokens, reversed_local_input_permutation_mapping, tokens_per_expert = self._dispatch_preprocess(
hidden_states, topk_ids)
self.reversed_local_input_permutation_mapping = reversed_local_input_permutation_mapping
dynamic_scale_after_all2all = None
if self.with_quant:
permutated_local_input_tokens, dynamic_scale = torch_npu.npu_dynamic_quant(
permutated_local_input_tokens)
_, dynamic_scale_after_all2all, permute2_ep_all_to_all_handle = async_all_to_all(
dynamic_scale,
self.output_splits,
self.input_splits,
self.ep_group,
)
permute2_ep_all_to_all_handle.wait()
dynamic_scale.untyped_storage().resize_(0)
_, global_input_tokens, permute1_ep_all_to_all_handle = async_all_to_all(
permutated_local_input_tokens,
self.output_splits,
self.input_splits,
self.ep_group,
)
permute1_ep_all_to_all_handle.wait()
permutated_local_input_tokens.untyped_storage().resize_(0)
global_input_tokens, dynamic_scale = self._dispatch_postprocess(
global_input_tokens, dynamic_scale_after_all2all)
return {
"hidden_states": global_input_tokens,
"group_list": tokens_per_expert,
"dynamic_scale": dynamic_scale,
"group_list_type": 1
}
def token_combine(self,
hidden_states: torch.Tensor,
bias: torch.Tensor = None):
assert bias is None, "Bias is not supported in MoEAlltoAllvTokenDispatcher."
hidden_states = self._combine_preprocess(hidden_states)
# Perform expert parallel AlltoAll communication
# hidden_states: [SEQL, H] -> [SEQL, H/TP]
_, permutated_local_input_tokens, handle = async_all_to_all(
hidden_states, self.input_splits, self.output_splits,
self.ep_group)
handle.wait()
hidden_states.untyped_storage().resize_(0)
output = self._combine_postprocess(permutated_local_input_tokens)
self.input_splits = None
self.output_splits = None
self.num_global_tokens_per_local_expert = None
return output
def _dispatch_preprocess(self, hidden_states, topk_ids):
assert self.hidden_shape is not None
hidden_states = hidden_states.view(-1, self.hidden_shape[-1])
tokens_per_expert = self._preprocess(topk_ids)
self.hidden_shape_before_permute = hidden_states.shape
permutated_local_input_tokens, reversed_local_input_permutation_mapping = torch_npu.npu_moe_token_permute(
tokens=hidden_states,
indices=topk_ids,
num_out_tokens=self.num_out_tokens,
)
return permutated_local_input_tokens, reversed_local_input_permutation_mapping, tokens_per_expert
def _preprocess(self, topk_ids: torch.Tensor) -> torch.Tensor:
num_local_tokens_per_expert = torch.histc(topk_ids,
bins=self.num_experts,
min=0,
max=self.num_experts)
ep_size = self.ep_size
# Dropless
self.num_out_tokens = topk_ids.numel()
# ===================================================
# Calculate input_splits, output_splits for alltoall-v.
# ===================================================
self.input_splits = (num_local_tokens_per_expert.reshape(
ep_size,
self.num_local_experts).sum(axis=1).to(torch.device("cpu"),
non_blocking=True).numpy())
num_global_tokens_per_expert = gather_from_sequence_parallel_region(
num_local_tokens_per_expert,
group=self.ep_group).reshape(ep_size, self.num_experts)
self.num_global_tokens_per_local_expert = num_global_tokens_per_expert[:, self.local_expert_indices[
0]:self.local_expert_indices[-1] + 1]
if self.num_global_tokens_per_local_expert is None:
raise ValueError(
"num_global_tokens_per_local_expert must be set before sum.")
self.output_splits = (self.num_global_tokens_per_local_expert.sum(
axis=-1).to(torch.device("cpu"), non_blocking=True).numpy())
num_tokens_per_local_expert = self.num_global_tokens_per_local_expert.sum(
axis=0)
# ===================================================
# num_global_tokens_per_expert: [ep_size, num_experts]
# num_global_tokens_per_local_expert: [ep_size, num_local_experts]
# num_tokens_per_local_expert: [num_local_experts]
# ===================================================
if self.num_local_experts > 1:
if self.num_global_tokens_per_local_expert is None:
raise ValueError(
"num_global_tokens_per_local_expert must be set before operations."
)
self.global_input_tokens_local_experts_indices = torch.repeat_interleave(
self.expert_ids_per_ep_rank,
self.num_global_tokens_per_local_expert.ravel())
return num_tokens_per_local_expert
def _dispatch_postprocess(self, global_input_tokens, dynamic_scale=None):
# Early return if no local experts or no tokens
if self.num_local_experts <= 1:
return global_input_tokens, None
# Handle quantized case
if self.with_quant:
assert self.global_input_tokens_local_experts_indices is not None, \
"global_input_tokens_local_experts_indices must be initialized before calling _dispatch_postprocess"
expert_idx_2d = self.global_input_tokens_local_experts_indices.unsqueeze(
-1)
active_num = self.global_input_tokens_local_experts_indices.numel()
# Handle case with no active tokens
if active_num <= 0:
self.reversed_global_input_permutation_mapping = self.global_input_tokens_local_experts_indices
return global_input_tokens, dynamic_scale
# Process with active tokens
global_input_tokens, self.reversed_global_input_permutation_mapping, _, expanded_scale = torch_npu.npu_moe_init_routing_v2(
global_input_tokens,
expert_idx_2d,
scale=dynamic_scale,
active_num=active_num,
expert_capacity=0,
expert_num=self.num_local_experts,
expert_tokens_num_type=1,
expert_tokens_num_flag=True,
active_expert_range=[0, self.num_local_experts],
quant_mode=-1,
row_idx_type=0)
return global_input_tokens, expanded_scale
# Handle non-quantized case
global_input_tokens, self.reversed_global_input_permutation_mapping = torch_npu.npu_moe_token_permute(
global_input_tokens,
self.global_input_tokens_local_experts_indices)
return global_input_tokens, None
def _combine_preprocess(self, hidden_states):
# Unpermutation 2: expert output to AlltoAll input
if hidden_states.shape[0] > 0 and self.num_local_experts > 1:
hidden_states = torch_npu.npu_moe_token_unpermute(
hidden_states, self.reversed_global_input_permutation_mapping)
return hidden_states
def _combine_postprocess(self, permutated_local_input_tokens):
# Unpermutation 1: AlltoAll output to output
output = torch_npu.npu_moe_token_unpermute(
permuted_tokens=permutated_local_input_tokens,
sorted_indices=self.reversed_local_input_permutation_mapping.to(
torch.int32),
probs=self.topk_weights,
restore_shape=self.hidden_shape_before_permute)
# Reshape the output tensor
output = output.view(self.hidden_shape)
return output