[main] [refactor] refactor fused_moe.py to enable token_dispatchers (#2570)
### What this PR does / why we need it?
Enable token_dispatcher to replace fused_experts_with_xxx in eager mode
### Does this PR introduce _any_ user-facing change?
no
### How was this patch tested?
e2e & ut
- vLLM version: v0.10.1.1
- vLLM main:
704432af3c
Signed-off-by: Pr0Wh1teGivee <calvin_zhu0210@outlook.com>
Co-authored-by: sherie <963372609@qq.com>
Co-authored-by: weijinqian0 <12153182+weijinqian0@users.noreply.github.com>
Co-authored-by: shiyuan680 <72335504+shiyuan680@users.noreply.github.com>
This commit is contained in:
@@ -24,10 +24,12 @@ from unittest.mock import MagicMock, patch
|
||||
|
||||
import pytest
|
||||
import torch
|
||||
import torch_npu
|
||||
from vllm.model_executor.layers.activation import SiluAndMul
|
||||
|
||||
from vllm_ascend.ops.fused_moe import fused_experts
|
||||
from vllm_ascend.ops.layers.experts_selector import select_experts
|
||||
from vllm_ascend.ops.moe_dispatcher.token_dispatcher import \
|
||||
TokenDispatcherWithAllGather
|
||||
|
||||
NUM_EXPERTS = [8, 64]
|
||||
EP_SIZE = [1, 4]
|
||||
@@ -35,6 +37,38 @@ TOP_KS = [2, 6]
|
||||
DEVICE = ["npu"]
|
||||
|
||||
|
||||
def apply_mlp(
|
||||
hidden_states: torch.Tensor,
|
||||
w1: torch.Tensor,
|
||||
w2: torch.Tensor,
|
||||
group_list: torch.Tensor,
|
||||
group_list_type: int = 1,
|
||||
) -> torch.Tensor:
|
||||
w1 = w1.transpose(1, 2)
|
||||
hidden_states = torch_npu.npu_grouped_matmul(
|
||||
x=[hidden_states],
|
||||
weight=[w1],
|
||||
split_item=2,
|
||||
group_list_type=group_list_type,
|
||||
group_type=0,
|
||||
group_list=group_list,
|
||||
)[0]
|
||||
|
||||
hidden_states = torch_npu.npu_swiglu(hidden_states)
|
||||
|
||||
w2 = w2.transpose(1, 2)
|
||||
hidden_states = torch_npu.npu_grouped_matmul(
|
||||
x=[hidden_states],
|
||||
weight=[w2],
|
||||
split_item=2,
|
||||
group_list_type=group_list_type,
|
||||
group_type=0,
|
||||
group_list=group_list,
|
||||
)[0]
|
||||
|
||||
return hidden_states
|
||||
|
||||
|
||||
def torch_moe(a, w1, w2, topk_weights, topk_ids, topk, expert_map):
|
||||
B, D = a.shape
|
||||
a = a.view(B, -1, D).repeat(1, topk, 1).reshape(-1, D)
|
||||
@@ -60,7 +94,7 @@ def torch_moe(a, w1, w2, topk_weights, topk_ids, topk, expert_map):
|
||||
@pytest.mark.parametrize("ep_size", EP_SIZE)
|
||||
@pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16])
|
||||
@pytest.mark.parametrize("device", DEVICE)
|
||||
def test_fused_experts(
|
||||
def test_token_dispatcher_with_all_gather(
|
||||
m: int,
|
||||
n: int,
|
||||
k: int,
|
||||
@@ -75,19 +109,23 @@ def test_fused_experts(
|
||||
w2 = torch.randn((e, k, n), device=device, dtype=dtype) / 10
|
||||
|
||||
score = torch.randn((m, e), device=device, dtype=dtype)
|
||||
expert_map = None
|
||||
local_e = e
|
||||
w1_local = w1
|
||||
w2_local = w2
|
||||
|
||||
if ep_size > 1:
|
||||
local_e = e // ep_size
|
||||
e_ids = torch.randint(0,
|
||||
e, (local_e, ),
|
||||
device=device,
|
||||
dtype=torch.int32)
|
||||
e_map = torch.full((e, ), -1, device=device, dtype=torch.int32)
|
||||
e_map[e_ids] = torch.arange(local_e, device=device, dtype=torch.int32)
|
||||
w1 = w1[e_ids]
|
||||
w2 = w2[e_ids]
|
||||
else:
|
||||
e_map = None
|
||||
e_ids = torch.arange(local_e * 0,
|
||||
local_e * (0 + 1),
|
||||
device=device,
|
||||
dtype=torch.int32)
|
||||
expert_map = torch.full((e, ), -1, device=device, dtype=torch.int32)
|
||||
expert_map[e_ids] = torch.arange(local_e,
|
||||
device=device,
|
||||
dtype=torch.int32)
|
||||
w1_local = w1[e_ids]
|
||||
w2_local = w2[e_ids]
|
||||
|
||||
score = torch.softmax(score, dim=-1, dtype=dtype)
|
||||
topk_weights, topk_ids = torch.topk(score, topk)
|
||||
@@ -99,11 +137,42 @@ def test_fused_experts(
|
||||
dtype=torch.int32,
|
||||
).view(topk, -1).permute(1, 0).contiguous())
|
||||
|
||||
output = fused_experts(a, w1, w2, topk_weights, topk_ids, row_idx, topk,
|
||||
e_map)
|
||||
torch_output = torch_moe(a, w1, w2, topk_weights, topk_ids, topk, e_map)
|
||||
# TODO: The native params are: atol=2e-2, rtol=0, maybe related to the nan problem
|
||||
torch.testing.assert_close(output, torch_output, atol=4e-2, rtol=1)
|
||||
dispatcher_kwargs = {
|
||||
"num_experts": e,
|
||||
"top_k": topk,
|
||||
"num_local_experts": local_e,
|
||||
}
|
||||
dispatcher = TokenDispatcherWithAllGather(**dispatcher_kwargs)
|
||||
|
||||
apply_router_weight_on_input = False
|
||||
dispatch_output = dispatcher.token_dispatch(
|
||||
hidden_states=a,
|
||||
topk_weights=topk_weights,
|
||||
topk_ids=topk_ids,
|
||||
row_idx=row_idx,
|
||||
expert_map=expert_map,
|
||||
apply_router_weight_on_input=apply_router_weight_on_input)
|
||||
|
||||
sorted_hidden_states = dispatch_output["hidden_states"]
|
||||
group_list = dispatch_output["group_list"]
|
||||
group_list_type = dispatch_output.get("group_list_type", 1)
|
||||
|
||||
expert_output = apply_mlp(hidden_states=sorted_hidden_states,
|
||||
w1=w1_local,
|
||||
w2=w2_local,
|
||||
group_list=group_list,
|
||||
group_list_type=group_list_type)
|
||||
|
||||
combined_output = dispatcher.token_combine(hidden_states=expert_output,
|
||||
bias=None)
|
||||
|
||||
torch_output = torch_moe(a, w1, w2, topk_weights, topk_ids, topk,
|
||||
expert_map)
|
||||
|
||||
torch.testing.assert_close(combined_output,
|
||||
torch_output,
|
||||
atol=4e-2,
|
||||
rtol=1)
|
||||
torch.npu.empty_cache()
|
||||
|
||||
|
||||
|
||||
Reference in New Issue
Block a user