Files
xc-llm-ascend/tests/e2e/singlecard/ops/test_fused_moe.py
wangxiyuan fef18b60bc Refactor e2e CI (#2276)
Refactor E2E CI to make it clear and faster
1. remove some uesless e2e test
2. remove some uesless function
3. Make sure all test runs with VLLMRunner to avoid oom error
4. Make sure all ops test end with torch.empty_cache to avoid oom error
5. run the test one by one to avoid resource limit error


- vLLM version: v0.10.1.1
- vLLM main:
a344a5aa0a

Signed-off-by: wangxiyuan <wangxiyuan1007@gmail.com>
2025-09-02 09:02:22 +08:00

285 lines
9.6 KiB
Python

# Copyright (c) 2025 Huawei Technologies Co., Ltd. All Rights Reserved.
# Copyright 2023 The vLLM team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# SPDX-License-Identifier: Apache-2.0
# This file is a part of the vllm-ascend project.
# Adapted from vllm/tests/kernels/test_moe.py
"""Tests for the MOE layers.
Run `pytest tests/ops/test_fused_moe.py`.
"""
import gc
from unittest.mock import MagicMock, patch
import pytest
import torch
import torch_npu
from vllm.model_executor.layers.activation import SiluAndMul
from vllm_ascend.ops.layers.experts_selector import select_experts
from vllm_ascend.ops.moe_dispatcher.token_dispatcher import \
TokenDispatcherWithAllGather
NUM_EXPERTS = [8, 64]
EP_SIZE = [1, 4]
TOP_KS = [2, 6]
DEVICE = ["npu"]
def apply_mlp(
hidden_states: torch.Tensor,
w1: torch.Tensor,
w2: torch.Tensor,
group_list: torch.Tensor,
group_list_type: int = 1,
) -> torch.Tensor:
w1 = w1.transpose(1, 2)
hidden_states = torch_npu.npu_grouped_matmul(
x=[hidden_states],
weight=[w1],
split_item=2,
group_list_type=group_list_type,
group_type=0,
group_list=group_list,
)[0]
hidden_states = torch_npu.npu_swiglu(hidden_states)
w2 = w2.transpose(1, 2)
hidden_states = torch_npu.npu_grouped_matmul(
x=[hidden_states],
weight=[w2],
split_item=2,
group_list_type=group_list_type,
group_type=0,
group_list=group_list,
)[0]
return hidden_states
def torch_moe(a, w1, w2, topk_weights, topk_ids, topk, expert_map):
B, D = a.shape
a = a.view(B, -1, D).repeat(1, topk, 1).reshape(-1, D)
out = torch.zeros(B * topk, w2.shape[1], dtype=a.dtype, device=a.device)
topk_weights = topk_weights.view(-1)
topk_ids = topk_ids.view(-1)
if expert_map is not None:
topk_ids = expert_map[topk_ids]
for i in range(w1.shape[0]):
mask = topk_ids == i
if mask.sum():
out[mask] = SiluAndMul()(
a[mask] @ w1[i].transpose(0, 1)) @ w2[i].transpose(0, 1)
return (out.view(B, -1, w2.shape[1]) *
topk_weights.view(B, -1, 1).to(out.dtype)).sum(dim=1)
@pytest.mark.parametrize("m", [1, 33, 64, 222, 1024 * 128])
@pytest.mark.parametrize("n", [128, 1024, 2048])
@pytest.mark.parametrize("k", [128, 511, 1024])
@pytest.mark.parametrize("e", NUM_EXPERTS)
@pytest.mark.parametrize("topk", TOP_KS)
@pytest.mark.parametrize("ep_size", EP_SIZE)
@pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16])
@pytest.mark.parametrize("device", DEVICE)
def test_token_dispatcher_with_all_gather(
m: int,
n: int,
k: int,
e: int,
topk: int,
ep_size: int,
dtype: torch.dtype,
device: str,
):
a = torch.randn((m, k), device=device, dtype=dtype) / 10
w1 = torch.randn((e, 2 * n, k), device=device, dtype=dtype) / 10
w2 = torch.randn((e, k, n), device=device, dtype=dtype) / 10
score = torch.randn((m, e), device=device, dtype=dtype)
expert_map = None
local_e = e
w1_local = w1
w2_local = w2
if ep_size > 1:
local_e = e // ep_size
e_ids = torch.arange(local_e * 0,
local_e * (0 + 1),
device=device,
dtype=torch.int32)
expert_map = torch.full((e, ), -1, device=device, dtype=torch.int32)
expert_map[e_ids] = torch.arange(local_e,
device=device,
dtype=torch.int32)
w1_local = w1[e_ids]
w2_local = w2[e_ids]
score = torch.softmax(score, dim=-1, dtype=dtype)
topk_weights, topk_ids = torch.topk(score, topk)
topk_ids = topk_ids.to(torch.int32)
row_idx = (torch.arange(
0,
m * topk,
device=device,
dtype=torch.int32,
).view(topk, -1).permute(1, 0).contiguous())
dispatcher_kwargs = {
"num_experts": e,
"top_k": topk,
"num_local_experts": local_e,
}
dispatcher = TokenDispatcherWithAllGather(**dispatcher_kwargs)
apply_router_weight_on_input = False
dispatch_output = dispatcher.token_dispatch(
hidden_states=a,
topk_weights=topk_weights,
topk_ids=topk_ids,
row_idx=row_idx,
expert_map=expert_map,
apply_router_weight_on_input=apply_router_weight_on_input)
sorted_hidden_states = dispatch_output["hidden_states"]
group_list = dispatch_output["group_list"]
group_list_type = dispatch_output.get("group_list_type", 1)
expert_output = apply_mlp(hidden_states=sorted_hidden_states,
w1=w1_local,
w2=w2_local,
group_list=group_list,
group_list_type=group_list_type)
combined_output = dispatcher.token_combine(hidden_states=expert_output,
bias=None)
torch_output = torch_moe(a, w1, w2, topk_weights, topk_ids, topk,
expert_map)
torch.testing.assert_close(combined_output,
torch_output,
atol=4e-2,
rtol=1)
gc.collect()
torch.npu.empty_cache()
torch.npu.reset_peak_memory_stats()
@pytest.mark.parametrize("m", [1, 33, 64])
@pytest.mark.parametrize("n", [128, 1024, 2048])
@pytest.mark.parametrize("e", NUM_EXPERTS)
@pytest.mark.parametrize("topk", TOP_KS)
@pytest.mark.parametrize("scoring_func", ["softmax", "sigmoid"])
@pytest.mark.parametrize("use_grouped_topk", [True, False])
@pytest.mark.parametrize("renormalize", [True, False])
@pytest.mark.parametrize("with_e_correction", [True, False])
@pytest.mark.parametrize("custom_routing", [True, False])
@pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16])
@pytest.mark.parametrize("device", DEVICE)
def test_select_experts(
m: int,
n: int,
e: int,
topk: int,
scoring_func: str,
use_grouped_topk: bool,
renormalize: bool,
with_e_correction: bool,
custom_routing: bool,
dtype: torch.dtype,
device: str,
):
topk_group = 4 if use_grouped_topk else None
num_expert_group = e // 4 if use_grouped_topk else None
hidden_states = torch.randn(m, n, device=device, dtype=dtype)
router_logits = torch.randn(m, e, device=device, dtype=dtype)
e_score_correction_bias = (torch.randn(e, device=device, dtype=dtype)
if with_e_correction else None)
custom_routing_function = None
if custom_routing:
custom_routing_function = MagicMock()
mock_weights = torch.randn(m, topk, device=device, dtype=dtype)
mock_ids = torch.randint(0,
e, (m, topk),
device=device,
dtype=torch.int32)
custom_routing_function.return_value = (mock_weights, mock_ids)
with patch("vllm_ascend.ops.layers.experts_selector._native_grouped_topk"
) as mock_native_grouped_topk:
mock_native_grouped_topk.side_effect = lambda x, num_groups, k: torch.randn_like(
x)
topk_weights, topk_ids, row_idx = select_experts(
hidden_states=hidden_states,
router_logits=router_logits,
top_k=topk,
use_grouped_topk=use_grouped_topk,
renormalize=renormalize,
topk_group=topk_group,
num_expert_group=num_expert_group,
custom_routing_function=custom_routing_function,
scoring_func=scoring_func,
e_score_correction_bias=e_score_correction_bias,
)
if use_grouped_topk:
mock_native_grouped_topk.assert_called_once()
else:
mock_native_grouped_topk.assert_not_called()
assert topk_weights.shape == (m, topk)
assert topk_ids.shape == (m, topk)
assert topk_ids.dtype == torch.int32
assert row_idx.shape == (m, topk)
gc.collect()
torch.npu.empty_cache()
torch.npu.reset_peak_memory_stats()
@pytest.mark.parametrize("device", DEVICE)
def test_select_experts_invalid_scoring_func(device: str):
with pytest.raises(ValueError,
match="Unsupported scoring function: invalid"):
select_experts(hidden_states=torch.randn(1, 128, device=device),
router_logits=torch.randn(1, 8, device=device),
top_k=2,
use_grouped_topk=False,
renormalize=False,
scoring_func="invalid")
gc.collect()
torch.npu.empty_cache()
torch.npu.reset_peak_memory_stats()
@pytest.mark.parametrize("device", DEVICE)
def test_select_experts_missing_group_params(device: str):
with pytest.raises(AssertionError):
select_experts(hidden_states=torch.randn(1, 128, device=device),
router_logits=torch.randn(1, 64, device=device),
top_k=2,
use_grouped_topk=True,
renormalize=False,
scoring_func="softmax")
gc.collect()
torch.npu.empty_cache()
torch.npu.reset_peak_memory_stats()