[Refactor][MoE] remove redundant code after refactoring fused_moe (#2612)
### What this PR does / why we need it?
There is a lot of redundant MoE-related code here, and the structure is not very clear.
This PR does the following:
- Moves the relatively independent `apply_mlp` code into a separate file.
- Removes the `alltoall_buffer` and `alltoall_seq` environment variables.
- Removes the code related to `alltoall_buffer` and `alltoall_seq`, keeping only a single `TokenDispatcher` subclass.
### Does this PR introduce _any_ user-facing change?
No
### How was this patch tested?
E2E and unit tests.
- vLLM version: v0.10.1.1
- vLLM main: 4071c76cf3
---------
Signed-off-by: Pr0Wh1teGivee <calvin_zhu0210@outlook.com>
Signed-off-by: weijinqian_v1 <weijinqian@huawei.com>
Co-authored-by: weijinqian0 <12153182+weijinqian0@users.noreply.github.com>
@@ -22,6 +22,8 @@ import torch_npu

```python
from vllm.config import CompilationLevel, get_current_vllm_config
from vllm.distributed import get_dp_group, get_ep_group, get_tp_group
from vllm.forward_context import get_forward_context
from vllm.model_executor.layers.fused_moe.config import \
    FusedMoEParallelConfig  # isort: skip
from vllm.model_executor.layers.fused_moe.layer import (
    FusedMoE, UnquantizedFusedMoEMethod)
```
@@ -30,7 +32,6 @@ from vllm_ascend.distributed.moe_comm_method import (AllGatherCommImpl,

```python
                                                     AlltoAllCommImpl,
                                                     MC2CommImpl)
from vllm_ascend.distributed.parallel_state import get_mc2_group
from vllm_ascend.ops.fused_moe import fused_experts_moge
from vllm_ascend.ops.layers.experts_selector import select_experts
from vllm_ascend.ops.moe_dispatcher.token_dispatcher import \
    setup_token_dispatchers
```
@@ -139,6 +140,95 @@ def fused_experts(

```python
    return hidden_states


def fused_experts_moge(
    hidden_states: torch.Tensor,
    w1: torch.Tensor,
    w2: torch.Tensor,
    moe_parallel_config: FusedMoEParallelConfig,
    topk_weights: torch.Tensor,
    topk_ids: torch.Tensor,
    top_k: int,
    global_num_experts: int,
    expert_map: torch.Tensor = None,
    apply_router_weight_on_input: bool = False,
) -> torch.Tensor:
    """

    Args:
        hidden_states: Hidden states of shape (num_tokens, hidden_size).
        w1: Expert weights1 of shape (num_experts, intermediate_size * 2, hidden_size).
        w2: Expert weights2 of shape (num_experts, hidden_size, intermediate_size).
        topk_weights: Routing weights of shape (num_tokens, top_k).
        topk_ids: Selected expert IDs of shape (num_tokens, top_k).
        top_k: Number of experts to select.
        expert_map: Expert mapping of shape (num_experts,).

    Returns:
        hidden_states: Hidden states after routing.
    """
```
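For reference, the shape contract in the docstring can be made concrete with tiny dummy tensors. This is a purely illustrative sketch in plain PyTorch on CPU with made-up sizes; the real kernel runs on NPU tensors through `torch_npu`:

```python
import torch

num_tokens, hidden_size, intermediate_size = 4, 8, 16
global_num_experts, top_k = 4, 2

hidden_states = torch.randn(num_tokens, hidden_size)
# w1 packs the gate and up projections, hence the `intermediate_size * 2` dimension.
w1 = torch.randn(global_num_experts, intermediate_size * 2, hidden_size)
w2 = torch.randn(global_num_experts, hidden_size, intermediate_size)
topk_weights = torch.rand(num_tokens, top_k)
topk_ids = torch.randint(0, global_num_experts, (num_tokens, top_k))
```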
```python
    ep_size = moe_parallel_config.ep_size
    local_num_experts = global_num_experts // ep_size
    local_num_group = top_k // ep_size

    if apply_router_weight_on_input:
        assert (topk_weights.dim() == 2
                ), "`topk_weights` should be in shape (num_tokens, topk)"
        _, topk = topk_weights.shape
        assert (
            topk == 1
        ), "Only support topk=1 when `apply_router_weight_on_input` is True"
        hidden_states = hidden_states * topk_weights.to(hidden_states.dtype)

    bsz, _ = hidden_states.shape
    flatten_topk_ids = topk_ids.view(-1)
    sorted_topk_ids = torch.argsort(flatten_topk_ids.float())
    sorted_topk_ids = sorted_topk_ids.to(torch.int32)
    sorted_hidden_states = hidden_states.index_select(
        0, sorted_topk_ids // local_num_group)

    experts_id = torch.arange(0,
                              local_num_experts,
                              dtype=topk_ids.dtype,
                              device=topk_ids.device)
    num_tokens_per_expert = (flatten_topk_ids.unsqueeze(-1) == experts_id).to(
        torch.float32).sum(0)
    topk_scales = topk_weights.view(-1).index_select(
        0, sorted_topk_ids).unsqueeze(-1)
    group_list = num_tokens_per_expert.cumsum(dim=0).to(torch.int64)
```
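To make the bookkeeping above concrete, here is a tiny worked example in plain PyTorch on CPU (illustrative only, with made-up values): the (token, expert) pairs are flattened, sorted by expert ID, and `group_list` ends up holding the cumulative number of rows assigned to each expert, which is the group-boundary format the grouped matmul below consumes.

```python
import torch

topk_ids = torch.tensor([[0, 2], [1, 2], [0, 3]])   # 3 tokens, top_k = 2
local_num_experts = 4

flatten_topk_ids = topk_ids.view(-1)                  # tensor([0, 2, 1, 2, 0, 3])
sorted_topk_ids = torch.argsort(flatten_topk_ids.float()).to(torch.int32)
# rows grouped by expert: expert 0 gets rows 0 and 4, expert 1 gets row 2, expert 2 gets rows 1 and 3, ...

experts_id = torch.arange(local_num_experts)
num_tokens_per_expert = (flatten_topk_ids.unsqueeze(-1) == experts_id).to(torch.float32).sum(0)
# tensor([2., 1., 2., 1.]) -- two rows routed to expert 0, one to expert 1, ...

group_list = num_tokens_per_expert.cumsum(dim=0).to(torch.int64)
# tensor([2, 3, 5, 6]) -- offsets where each expert's block of sorted rows ends
```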
```python
    gate_up_out = torch_npu.npu_grouped_matmul(
        x=[sorted_hidden_states],
        weight=[w1],
        split_item=2,
        group_list_type=0,
        group_type=0,
        group_list=group_list,
    )[0]

    if is_310p():
        gate_up_out = torch_npu.npu_swiglu(gate_up_out.to(torch.float32)).to(
            torch.float16)
    else:
        gate_up_out = torch_npu.npu_swiglu(gate_up_out)
    gate_up_out *= topk_scales

    down_out_list = torch_npu.npu_grouped_matmul(
        x=[gate_up_out],
        weight=[w2],
        split_item=2,
        group_list_type=0,
        group_type=0,
        group_list=group_list,
    )[0]

    unsorted_topk_ids = torch.argsort(sorted_topk_ids.float()).to(torch.int32)
    unsorted_hidden_states = down_out_list.index_select(0, unsorted_topk_ids)
    final_hidden_states = unsorted_hidden_states.reshape(
        bsz, top_k // ep_size, -1).sum(1)

    return final_hidden_states
```
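As a cross-check of what the two `npu_grouped_matmul` calls plus `npu_swiglu` compute, a naive per-(token, expert) loop in plain PyTorch would look roughly like the sketch below. It is illustrative only, is not part of the PR, ignores the expert-parallel splitting (`ep_size`), and assumes the usual SwiGLU convention of splitting the gate/up projection in half:

```python
import torch


def moge_reference(hidden_states, w1, w2, topk_weights, topk_ids):
    """Naive per-(token, expert) loop mirroring the grouped-matmul path above."""
    num_tokens, top_k = topk_ids.shape
    out = torch.zeros_like(hidden_states)
    for t in range(num_tokens):
        for k in range(top_k):
            e = int(topk_ids[t, k])
            x = hidden_states[t]                      # (hidden_size,)
            gate_up = w1[e] @ x                       # (2 * intermediate_size,)
            gate, up = gate_up.chunk(2)               # SwiGLU: silu(gate) * up
            h = torch.nn.functional.silu(gate) * up
            h = h * topk_weights[t, k]                # routing weight (topk_scales)
            out[t] += w2[e] @ h                       # down projection, summed over top_k
    return out
```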
```python
def unquantized_fused_moe_init_func(self, *args, **kwargs):
    original_unquantized_fused_moe_init_func(self, *args, **kwargs)
```
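This wrapper delegates to the stock `UnquantizedFusedMoEMethod.__init__`. The diff does not show where `original_unquantized_fused_moe_init_func` is captured or where the wrapper is installed; a typical monkey-patching pattern, sketched here purely as an assumption, is:

```python
from vllm.model_executor.layers.fused_moe.layer import UnquantizedFusedMoEMethod

# Save the upstream initializer once, then install the wrapper so that any
# Ascend-specific setup can run alongside the original initialization.
original_unquantized_fused_moe_init_func = UnquantizedFusedMoEMethod.__init__


def unquantized_fused_moe_init_func(self, *args, **kwargs):
    original_unquantized_fused_moe_init_func(self, *args, **kwargs)
    # Ascend-specific initialization (hypothetical) would go here.


UnquantizedFusedMoEMethod.__init__ = unquantized_fused_moe_init_func
```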