[Refactor][MoE] remove redundant code after refactoring fused_moe (#2612)

### What this PR does / why we need it?
There are a lot of redundant codes related to moe here, and the
structure is not very clear.
We did the following things:

we have placed the relatively independent code related to apply_mlp into
a separate file;
removed the environment variables of alltoall_buffer and alltoall_seq.
Remove the code related to alltoall_buffer and alltoall_seq, and retain
the sole TokenDispatcher inheritance class.
### Does this PR introduce _any_ user-facing change?
No
### How was this patch tested?
e2e&ut

- vLLM version: v0.10.1.1
- vLLM main:
4071c76cf3

---------

Signed-off-by: Pr0Wh1teGivee <calvin_zhu0210@outlook.com>
Signed-off-by: weijinqian_v1 <weijinqian@huawei.com>
Co-authored-by: weijinqian0 <12153182+weijinqian0@users.noreply.github.com>
This commit is contained in:
weichen
2025-08-30 22:28:50 +08:00
committed by GitHub
parent 20ae71291d
commit 3a5fc5ee01
13 changed files with 417 additions and 1237 deletions

View File

@@ -22,6 +22,8 @@ import torch_npu
from vllm.config import CompilationLevel, get_current_vllm_config
from vllm.distributed import get_dp_group, get_ep_group, get_tp_group
from vllm.forward_context import get_forward_context
from vllm.model_executor.layers.fused_moe.config import \
FusedMoEParallelConfig # isort: skip
from vllm.model_executor.layers.fused_moe.layer import (
FusedMoE, UnquantizedFusedMoEMethod)
@@ -30,7 +32,6 @@ from vllm_ascend.distributed.moe_comm_method import (AllGatherCommImpl,
AlltoAllCommImpl,
MC2CommImpl)
from vllm_ascend.distributed.parallel_state import get_mc2_group
from vllm_ascend.ops.fused_moe import fused_experts_moge
from vllm_ascend.ops.layers.experts_selector import select_experts
from vllm_ascend.ops.moe_dispatcher.token_dispatcher import \
setup_token_dispatchers
@@ -139,6 +140,95 @@ def fused_experts(
return hidden_states
def fused_experts_moge(
hidden_states: torch.Tensor,
w1: torch.Tensor,
w2: torch.Tensor,
moe_parallel_config: FusedMoEParallelConfig,
topk_weights: torch.Tensor,
topk_ids: torch.Tensor,
top_k: int,
global_num_experts: int,
expert_map: torch.Tensor = None,
apply_router_weight_on_input: bool = False,
) -> torch.Tensor:
"""
Args:
hidden_states: Hidden states of shape (num_tokens, hidden_size).
w1: Expert weights1 of shape (num_experts, intermediate_size * 2, hidden_size).
w2: Expert weights2 of shape (num_experts, hidden_size, intermediate_size).
topk_weights: Routing weights of shape (num_tokens, top_k).
topk_ids: Selected expert IDs of shape (num_tokens, top_k).
top_k: Number of experts to select.
expert_map: Expert mapping of shape (num_experts,).
Returns:
hidden_states: Hidden states after routing.
"""
ep_size = moe_parallel_config.ep_size
local_num_experts = global_num_experts // ep_size
local_num_group = top_k // ep_size
if apply_router_weight_on_input:
assert (topk_weights.dim() == 2
), "`topk_weights` should be in shape (num_tokens, topk)"
_, topk = topk_weights.shape
assert (
topk == 1
), "Only support topk=1 when `apply_router_weight_on_input` is True"
hidden_states = hidden_states * topk_weights.to(hidden_states.dtype)
bsz, _ = hidden_states.shape
flatten_topk_ids = topk_ids.view(-1)
sorted_topk_ids = torch.argsort(flatten_topk_ids.float())
sorted_topk_ids = sorted_topk_ids.to(torch.int32)
sorted_hidden_states = hidden_states.index_select(
0, sorted_topk_ids // local_num_group)
experts_id = torch.arange(0,
local_num_experts,
dtype=topk_ids.dtype,
device=topk_ids.device)
num_tokens_per_expert = (flatten_topk_ids.unsqueeze(-1) == experts_id).to(
torch.float32).sum(0)
topk_scales = topk_weights.view(-1).index_select(
0, sorted_topk_ids).unsqueeze(-1)
group_list = num_tokens_per_expert.cumsum(dim=0).to(torch.int64)
gate_up_out = torch_npu.npu_grouped_matmul(
x=[sorted_hidden_states],
weight=[w1],
split_item=2,
group_list_type=0,
group_type=0,
group_list=group_list,
)[0]
if is_310p():
gate_up_out = torch_npu.npu_swiglu(gate_up_out.to(torch.float32)).to(
torch.float16)
else:
gate_up_out = torch_npu.npu_swiglu(gate_up_out)
gate_up_out *= topk_scales
down_out_list = torch_npu.npu_grouped_matmul(
x=[gate_up_out],
weight=[w2],
split_item=2,
group_list_type=0,
group_type=0,
group_list=group_list,
)[0]
unsorted_topk_ids = torch.argsort(sorted_topk_ids.float()).to(torch.int32)
unsorted_hidden_states = down_out_list.index_select(0, unsorted_topk_ids)
final_hidden_states = unsorted_hidden_states.reshape(
bsz, top_k // ep_size, -1).sum(1)
return final_hidden_states
def unquantized_fused_moe_init_func(self, *args, **kwargs):
original_unquantized_fused_moe_init_func(self, *args, **kwargs)