### What this PR does / why we need it?
Bump the torch version to 2.7.1 and clean up the infer schema patch
https://github.com/vllm-project/vllm-ascend/commit/857f489
(https://github.com/vllm-project/vllm-ascend/pull/837). This patch
also depends on: https://github.com/vllm-project/vllm-ascend/pull/1974
### Does this PR introduce any user-facing change?
No
#### How was this patch tested?
CI passed
torch-npu 2.7.1rc1 install guide:
https://gitee.com/ascend/pytorch/tree/v2.7.1/
Install dependencies:
```
pip3 install pyyaml
pip3 install setuptools
```
install torch-npu:
Closes: https://github.com/vllm-project/vllm-ascend/issues/1866
Closes: https://github.com/vllm-project/vllm-ascend/issues/1390
- vLLM version: v0.10.0
- vLLM main:
9af654cc38
---------
Signed-off-by: Yikun Jiang <yikunkero@gmail.com>
Signed-off-by: leo-pony <nengjunma@outlook.com>
Co-authored-by: Yikun Jiang <yikunkero@gmail.com>
194 lines
7.0 KiB
Python
194 lines
7.0 KiB
Python
# Copyright (c) 2025 Huawei Technologies Co., Ltd. All Rights Reserved.
|
|
# Copyright 2023 The vLLM team.
|
|
#
|
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
# you may not use this file except in compliance with the License.
|
|
# You may obtain a copy of the License at
|
|
#
|
|
# http://www.apache.org/licenses/LICENSE-2.0
|
|
#
|
|
# Unless required by applicable law or agreed to in writing, software
|
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
# See the License for the specific language governing permissions and
|
|
# limitations under the License.
|
|
# SPDX-License-Identifier: Apache-2.0
|
|
# This file is a part of the vllm-ascend project.
|
|
# Adapted from vllm/tests/kernels/test_moe.py
|
|
"""Tests for the MOE layers.
|
|
|
|
Run `pytest tests/ops/test_fused_moe.py`.
|
|
"""
|
|
|
|
from unittest.mock import MagicMock, patch
|
|
|
|
import pytest
|
|
import torch
|
|
from vllm.model_executor.layers.activation import SiluAndMul
|
|
|
|
from vllm_ascend.ops.fused_moe import fused_experts, select_experts
|
|
|
|
# Parametrization grids shared by the tests in this module.

# Values fed to the tests' ``e`` parameter.
NUM_EXPERTS: list[int] = [8, 64]

# Values fed to ``ep_size``; sizes above 1 make test_fused_experts build an
# expert map.
EP_SIZE: list[int] = [1, 4]

# Values fed to the tests' ``topk`` parameter.
TOP_KS: list[int] = [2, 6]

# Device strings the tests allocate their tensors on.
DEVICE: list[str] = ["npu"]
|
|
|
|
|
|
def torch_moe(a, w1, w2, topk_weights, topk_ids, topk, expert_map):
    """Reference mixture-of-experts forward pass built from plain torch ops.

    Every row of ``a`` is replicated ``topk`` times, one copy per routing
    slot. For each expert index ``i`` in ``range(w1.shape[0])`` the copies
    routed to ``i`` are pushed through
    ``SiluAndMul()(x @ w1[i].T) @ w2[i].T``. Copies whose id matches no
    index in that range (for example ids remapped to -1 by ``expert_map``)
    keep a zero output. The per-slot outputs are then scaled by
    ``topk_weights`` and summed over the slot axis.

    Args:
        a: 2-D input; its shape is unpacked as ``(rows, hidden)``.
        w1: Stacked first-layer expert weights, indexed by expert on dim 0.
        w2: Stacked second-layer expert weights; ``w2.shape[1]`` is the
            output width.
        topk_weights: Routing weights, flattened with ``view(-1)``.
        topk_ids: Routed expert ids, flattened with ``view(-1)``.
        topk: Number of routing slots per input row.
        expert_map: Optional lookup tensor applied to the flattened ids
            before matching; ``None`` leaves the ids untouched.

    Returns:
        Tensor of shape ``(rows, w2.shape[1])``.
    """
    num_rows, hidden = a.shape
    out_width = w2.shape[1]

    # One copy of each input row per routing slot.
    expanded = a.view(num_rows, -1, hidden).repeat(1, topk, 1).reshape(
        -1, hidden)
    per_slot_out = torch.zeros(num_rows * topk,
                               out_width,
                               dtype=expanded.dtype,
                               device=expanded.device)

    flat_weights = topk_weights.view(-1)
    flat_ids = topk_ids.view(-1)
    if expert_map is not None:
        flat_ids = expert_map[flat_ids]

    for expert_idx in range(w1.shape[0]):
        routed_here = flat_ids == expert_idx
        if not routed_here.sum():
            # Nothing was routed to this expert.
            continue
        gate_up = expanded[routed_here] @ w1[expert_idx].transpose(0, 1)
        per_slot_out[routed_here] = (SiluAndMul()(gate_up)
                                     @ w2[expert_idx].transpose(0, 1))

    slot_scale = flat_weights.view(num_rows, -1, 1).to(per_slot_out.dtype)
    weighted = per_slot_out.view(num_rows, -1, out_width) * slot_scale
    return weighted.sum(dim=1)
|
|
|
|
|
|
@pytest.mark.parametrize("m", [1, 33, 64, 222, 1024 * 128])
@pytest.mark.parametrize("n", [128, 1024, 2048])
@pytest.mark.parametrize("k", [128, 511, 1024])
@pytest.mark.parametrize("e", NUM_EXPERTS)
@pytest.mark.parametrize("topk", TOP_KS)
@pytest.mark.parametrize("ep_size", EP_SIZE)
@pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16])
@pytest.mark.parametrize("device", DEVICE)
def test_fused_experts(
    m: int,
    n: int,
    k: int,
    e: int,
    topk: int,
    ep_size: int,
    dtype: torch.dtype,
    device: str,
):
    """Compare ``fused_experts`` with the ``torch_moe`` reference.

    Random activations ``(m, k)`` and expert weights ``(e, 2 * n, k)`` /
    ``(e, k, n)`` are generated on ``device``. When ``ep_size > 1`` only
    ``e // ep_size`` experts are kept and an expert map (-1 for experts that
    were not kept) is passed to both implementations. Routing comes from a
    softmax over random scores followed by ``torch.topk``.

    The device cache is released in a ``finally`` clause so that a failing
    case still frees cached memory before the next parametrized case runs.
    """
    try:
        a = torch.randn((m, k), device=device, dtype=dtype) / 10
        w1 = torch.randn((e, 2 * n, k), device=device, dtype=dtype) / 10
        w2 = torch.randn((e, k, n), device=device, dtype=dtype) / 10

        score = torch.randn((m, e), device=device, dtype=dtype)

        if ep_size > 1:
            local_e = e // ep_size
            # NOTE(review): torch.randint samples with replacement, so e_ids
            # may contain repeated expert ids -- confirm this is intended.
            e_ids = torch.randint(0,
                                  e, (local_e, ),
                                  device=device,
                                  dtype=torch.int32)
            e_map = torch.full((e, ), -1, device=device, dtype=torch.int32)
            e_map[e_ids] = torch.arange(local_e,
                                        device=device,
                                        dtype=torch.int32)
            w1 = w1[e_ids]
            w2 = w2[e_ids]
        else:
            e_map = None

        score = torch.softmax(score, dim=-1, dtype=dtype)
        topk_weights, topk_ids = torch.topk(score, topk)
        topk_ids = topk_ids.to(torch.int32)

        output = fused_experts(a, w1, w2, topk_weights, topk_ids, topk, e_map)
        torch_output = torch_moe(a, w1, w2, topk_weights, topk_ids, topk,
                                 e_map)
        # TODO: The native params are: atol=2e-2, rtol=0, maybe related to the nan problem
        torch.testing.assert_close(output, torch_output, atol=4e-2, rtol=1)
    finally:
        # Runs on failure too, unlike a trailing call after the assertion.
        torch.npu.empty_cache()
|
|
|
|
|
|
@pytest.mark.parametrize("m", [1, 33, 64])
@pytest.mark.parametrize("n", [128, 1024, 2048])
@pytest.mark.parametrize("e", NUM_EXPERTS)
@pytest.mark.parametrize("topk", TOP_KS)
@pytest.mark.parametrize("scoring_func", ["softmax", "sigmoid"])
@pytest.mark.parametrize("use_grouped_topk", [True, False])
@pytest.mark.parametrize("renormalize", [True, False])
@pytest.mark.parametrize("with_e_correction", [True, False])
@pytest.mark.parametrize("custom_routing", [True, False])
@pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16])
@pytest.mark.parametrize("device", DEVICE)
def test_select_experts(
    m: int,
    n: int,
    e: int,
    topk: int,
    scoring_func: str,
    use_grouped_topk: bool,
    renormalize: bool,
    with_e_correction: bool,
    custom_routing: bool,
    dtype: torch.dtype,
    device: str,
):
    """Exercise ``select_experts`` across routing configurations.

    ``native_grouped_topk`` is patched with a stub that returns random scores
    shaped like its input. The test checks that the stub is called exactly
    once when ``use_grouped_topk`` is set and never otherwise, and that the
    returned weights and ids have shape ``(m, topk)`` with int32 ids.
    """
    if use_grouped_topk:
        topk_group = 4
        num_expert_group = e // 4
    else:
        topk_group = None
        num_expert_group = None

    hidden_states = torch.randn(m, n, device=device, dtype=dtype)
    router_logits = torch.randn(m, e, device=device, dtype=dtype)

    e_score_correction_bias = None
    if with_e_correction:
        e_score_correction_bias = torch.randn(e, device=device, dtype=dtype)

    custom_routing_function = None
    if custom_routing:
        mock_weights = torch.randn(m, topk, device=device, dtype=dtype)
        mock_ids = torch.randint(0,
                                 e, (m, topk),
                                 device=device,
                                 dtype=torch.int32)
        custom_routing_function = MagicMock(return_value=(mock_weights,
                                                          mock_ids))

    def _fake_grouped_topk(x, num_groups, k):
        # Stand-in for native_grouped_topk: random scores shaped like x.
        return torch.randn_like(x)

    with patch("vllm_ascend.ops.fused_moe.native_grouped_topk",
               side_effect=_fake_grouped_topk) as mock_native_grouped_topk:
        topk_weights, topk_ids = select_experts(
            hidden_states=hidden_states,
            router_logits=router_logits,
            top_k=topk,
            use_grouped_topk=use_grouped_topk,
            renormalize=renormalize,
            topk_group=topk_group,
            num_expert_group=num_expert_group,
            custom_routing_function=custom_routing_function,
            scoring_func=scoring_func,
            e_score_correction_bias=e_score_correction_bias,
        )

        if not use_grouped_topk:
            mock_native_grouped_topk.assert_not_called()
        else:
            mock_native_grouped_topk.assert_called_once()

    expected_shape = (m, topk)
    assert topk_weights.shape == expected_shape
    assert topk_ids.shape == expected_shape
    assert topk_ids.dtype == torch.int32
|
|
|
|
|
|
@pytest.mark.parametrize("device", DEVICE)
def test_select_experts_invalid_scoring_func(device: str):
    """``select_experts`` must raise ValueError for an unknown scoring_func."""
    expected_message = "Unsupported scoring function: invalid"
    with pytest.raises(ValueError, match=expected_message):
        hidden_states = torch.randn(1, 128, device=device)
        router_logits = torch.randn(1, 8, device=device)
        select_experts(
            hidden_states=hidden_states,
            router_logits=router_logits,
            top_k=2,
            use_grouped_topk=False,
            renormalize=False,
            scoring_func="invalid",
        )
|
|
|
|
|
|
@pytest.mark.parametrize("device", DEVICE)
def test_select_experts_missing_group_params(device: str):
    """Grouped top-k without group parameters must fail an assertion.

    ``select_experts`` is called with ``use_grouped_topk=True`` while
    ``topk_group`` and ``num_expert_group`` are left at their defaults, and
    the call is expected to raise ``AssertionError``.
    """
    # Build the inputs outside the ``raises`` block so that an AssertionError
    # from tensor creation cannot be mistaken for the expected failure.
    hidden_states = torch.randn(1, 128, device=device)
    router_logits = torch.randn(1, 64, device=device)

    with pytest.raises(AssertionError):
        select_experts(hidden_states=hidden_states,
                       router_logits=router_logits,
                       top_k=2,
                       use_grouped_topk=True,
                       renormalize=False,
                       scoring_func="softmax")
|