v0.10.1rc1
tests/e2e/singlecard/ops/__init__.py (new file, 0 lines)
tests/e2e/singlecard/ops/test_bgmv_expand.py (new file, 46 lines)
@@ -0,0 +1,46 @@
import gc

import torch

from vllm_ascend.utils import enable_custom_op

enable_custom_op()

DEFAULT_ATOL = 1e-3
DEFAULT_RTOL = 1e-3


def bgmv_expand_cpu_impl(x: torch.Tensor, w: torch.Tensor,
                         indices: torch.Tensor, y: torch.Tensor,
                         slice_offset: int, slice_size: int) -> torch.Tensor:
    # Gather each token's LoRA weight by index, batch-multiply in float32,
    # and accumulate the result into the requested output slice.
    W = w[indices, :, :].transpose(-1, -2).to(torch.float32)
    z = torch.bmm(x.unsqueeze(1).to(torch.float32), W).squeeze()
    y[:, slice_offset:slice_offset + slice_size] += z
    return y


@torch.inference_mode()
def test_bgmv_expand():
    B = 1
    x = torch.randn([B, 16], dtype=torch.float)
    w = torch.randn([64, 128, 16], dtype=torch.float16)
    indices = torch.zeros([B], dtype=torch.int64)
    y = torch.randn([B, 128 * 3], dtype=torch.float16)

    x_npu = x.npu()
    w_npu = w.npu()
    indices_npu = indices.npu()
    y_npu = y.npu()

    y_out = bgmv_expand_cpu_impl(x, w, indices, y, 0, 128)
    y_out_npu = torch.ops._C.bgmv_expand(x_npu, w_npu, indices_npu, y_npu, 0,
                                         128)

    # Compare the results.
    torch.testing.assert_close(y_out_npu.cpu(),
                               y_out,
                               atol=DEFAULT_ATOL,
                               rtol=DEFAULT_RTOL)
    gc.collect()
    torch.npu.empty_cache()
    torch.npu.reset_peak_memory_stats()
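For intuition, the CPU reference above is a gather-then-contract: each token selects its LoRA matrix via `indices`, computes `x @ W[idx].T` in float32, and accumulates the result into a column slice of `y`. A minimal einsum-based equivalent, for illustration only (not part of this commit; the helper name is hypothetical):

def bgmv_expand_einsum(x, w, indices, y, slice_offset, slice_size):
    # w[indices]: [B, slice_size, rank]; x: [B, rank] -> z: [B, slice_size]
    z = torch.einsum("bi,boi->bo", x.float(), w[indices].float())
    y[:, slice_offset:slice_offset + slice_size] += z
    return y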
tests/e2e/singlecard/ops/test_bgmv_shrink.py (new file, 45 lines)
@@ -0,0 +1,45 @@
import gc

import torch

from vllm_ascend.utils import enable_custom_op

enable_custom_op()

DEFAULT_ATOL = 1e-3
DEFAULT_RTOL = 1e-3


def bgmv_shrink_cpu_impl(x: torch.Tensor, w: torch.Tensor,
                         indices: torch.Tensor, y: torch.Tensor,
                         scaling: float) -> torch.Tensor:
    # Same gather-and-multiply pattern as the expand reference, but the
    # projection goes down to the LoRA rank and is scaled before accumulation.
    W = w[indices, :, :].transpose(-1, -2).to(torch.float32)
    z = torch.bmm(x.unsqueeze(1).to(torch.float32), W).squeeze()
    y[:, :] += z * scaling
    return y


@torch.inference_mode()
def test_bgmv_shrink():
    B = 1
    x = torch.randn([B, 128], dtype=torch.float16)
    w = torch.randn([64, 16, 128], dtype=torch.float16)
    indices = torch.zeros([B], dtype=torch.int64)
    y = torch.zeros([B, 16])

    x_npu = x.npu()
    w_npu = w.npu()
    indices_npu = indices.npu()
    y_npu = y.npu()

    y = bgmv_shrink_cpu_impl(x, w, indices, y, 0.5)
    torch.ops._C.bgmv_shrink(x_npu, w_npu, indices_npu, y_npu, 0.5)

    # Compare the results.
    torch.testing.assert_close(y_npu.cpu(),
                               y,
                               atol=DEFAULT_ATOL,
                               rtol=DEFAULT_RTOL)
    gc.collect()
    torch.npu.empty_cache()
    torch.npu.reset_peak_memory_stats()
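The shrink and expand ops are duals of each other: shrink projects the hidden size down to the LoRA rank and applies a scaling factor, while expand projects the rank back up into a slice of the output. A shape trace under this test's sizes, for illustration only:

# shrink: x [B, 128] @ w[idx].T [128, 16] -> z [B, 16],  y += 0.5 * z
# expand: x [B, 16]  @ w[idx].T [16, 128] -> z [B, 128], y[:, 0:128] += z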
tests/e2e/singlecard/ops/test_fused_moe.py (new file, 284 lines)
@@ -0,0 +1,284 @@
# Copyright (c) 2025 Huawei Technologies Co., Ltd. All Rights Reserved.
# Copyright 2023 The vLLM team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# SPDX-License-Identifier: Apache-2.0
# This file is a part of the vllm-ascend project.
# Adapted from vllm/tests/kernels/test_moe.py
"""Tests for the MOE layers.

Run `pytest tests/e2e/singlecard/ops/test_fused_moe.py`.
"""

import gc
from unittest.mock import MagicMock, patch

import pytest
import torch
import torch_npu
from vllm.model_executor.layers.activation import SiluAndMul

from vllm_ascend.ops.layers.experts_selector import select_experts
from vllm_ascend.ops.moe_dispatcher.token_dispatcher import \
    TokenDispatcherWithAllGather

NUM_EXPERTS = [8, 64]
EP_SIZE = [1, 4]
TOP_KS = [2, 6]
DEVICE = ["npu"]


def apply_mlp(
    hidden_states: torch.Tensor,
    w1: torch.Tensor,
    w2: torch.Tensor,
    group_list: torch.Tensor,
    group_list_type: int = 1,
) -> torch.Tensor:
    # Up-projection over per-expert token groups, SwiGLU activation, then
    # down-projection back to the hidden size.
    w1 = w1.transpose(1, 2)
    hidden_states = torch_npu.npu_grouped_matmul(
        x=[hidden_states],
        weight=[w1],
        split_item=2,
        group_list_type=group_list_type,
        group_type=0,
        group_list=group_list,
    )[0]

    hidden_states = torch_npu.npu_swiglu(hidden_states)

    w2 = w2.transpose(1, 2)
    hidden_states = torch_npu.npu_grouped_matmul(
        x=[hidden_states],
        weight=[w2],
        split_item=2,
        group_list_type=group_list_type,
        group_type=0,
        group_list=group_list,
    )[0]

    return hidden_states


def torch_moe(a, w1, w2, topk_weights, topk_ids, topk, expert_map):
    # Dense PyTorch reference: run every (token, expert) pair through its
    # expert MLP, then take the weighted sum over each token's top-k experts.
    B, D = a.shape
    a = a.view(B, -1, D).repeat(1, topk, 1).reshape(-1, D)
    out = torch.zeros(B * topk, w2.shape[1], dtype=a.dtype, device=a.device)
    topk_weights = topk_weights.view(-1)
    topk_ids = topk_ids.view(-1)
    if expert_map is not None:
        topk_ids = expert_map[topk_ids]
    for i in range(w1.shape[0]):
        mask = topk_ids == i
        if mask.sum():
            out[mask] = SiluAndMul()(
                a[mask] @ w1[i].transpose(0, 1)) @ w2[i].transpose(0, 1)
    return (out.view(B, -1, w2.shape[1]) *
            topk_weights.view(B, -1, 1).to(out.dtype)).sum(dim=1)


@pytest.mark.parametrize("m", [1, 33, 64, 222, 1024 * 128])
@pytest.mark.parametrize("n", [128, 1024, 2048])
@pytest.mark.parametrize("k", [128, 511, 1024])
@pytest.mark.parametrize("e", NUM_EXPERTS)
@pytest.mark.parametrize("topk", TOP_KS)
@pytest.mark.parametrize("ep_size", EP_SIZE)
@pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16])
@pytest.mark.parametrize("device", DEVICE)
def test_token_dispatcher_with_all_gather(
    m: int,
    n: int,
    k: int,
    e: int,
    topk: int,
    ep_size: int,
    dtype: torch.dtype,
    device: str,
):
    a = torch.randn((m, k), device=device, dtype=dtype) / 10
    w1 = torch.randn((e, 2 * n, k), device=device, dtype=dtype) / 10
    w2 = torch.randn((e, k, n), device=device, dtype=dtype) / 10

    score = torch.randn((m, e), device=device, dtype=dtype)
    expert_map = None
    local_e = e
    w1_local = w1
    w2_local = w2

    if ep_size > 1:
        # Simulate EP rank 0: keep the first e // ep_size experts locally
        # and map all other global expert ids to -1.
        local_e = e // ep_size
        e_ids = torch.arange(local_e * 0,
                             local_e * (0 + 1),
                             device=device,
                             dtype=torch.int32)
        expert_map = torch.full((e, ), -1, device=device, dtype=torch.int32)
        expert_map[e_ids] = torch.arange(local_e,
                                         device=device,
                                         dtype=torch.int32)
        w1_local = w1[e_ids]
        w2_local = w2[e_ids]

    score = torch.softmax(score, dim=-1, dtype=dtype)
    topk_weights, topk_ids = torch.topk(score, topk)
    topk_ids = topk_ids.to(torch.int32)
    row_idx = (torch.arange(
        0,
        m * topk,
        device=device,
        dtype=torch.int32,
    ).view(topk, -1).permute(1, 0).contiguous())

    dispatcher_kwargs = {
        "num_experts": e,
        "top_k": topk,
        "num_local_experts": local_e,
    }
    dispatcher = TokenDispatcherWithAllGather(**dispatcher_kwargs)

    apply_router_weight_on_input = False
    dispatch_output = dispatcher.token_dispatch(
        hidden_states=a,
        topk_weights=topk_weights,
        topk_ids=topk_ids,
        row_idx=row_idx,
        expert_map=expert_map,
        apply_router_weight_on_input=apply_router_weight_on_input)

    sorted_hidden_states = dispatch_output["hidden_states"]
    group_list = dispatch_output["group_list"]
    group_list_type = dispatch_output.get("group_list_type", 1)

    expert_output = apply_mlp(hidden_states=sorted_hidden_states,
                              w1=w1_local,
                              w2=w2_local,
                              group_list=group_list,
                              group_list_type=group_list_type)

    combined_output = dispatcher.token_combine(hidden_states=expert_output,
                                               bias=None)

    torch_output = torch_moe(a, w1, w2, topk_weights, topk_ids, topk,
                             expert_map)

    torch.testing.assert_close(combined_output,
                               torch_output,
                               atol=4e-2,
                               rtol=1)
    gc.collect()
    torch.npu.empty_cache()
    torch.npu.reset_peak_memory_stats()


@pytest.mark.parametrize("m", [1, 33, 64])
@pytest.mark.parametrize("n", [128, 1024, 2048])
@pytest.mark.parametrize("e", NUM_EXPERTS)
@pytest.mark.parametrize("topk", TOP_KS)
@pytest.mark.parametrize("scoring_func", ["softmax", "sigmoid"])
@pytest.mark.parametrize("use_grouped_topk", [True, False])
@pytest.mark.parametrize("renormalize", [True, False])
@pytest.mark.parametrize("with_e_correction", [True, False])
@pytest.mark.parametrize("custom_routing", [True, False])
@pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16])
@pytest.mark.parametrize("device", DEVICE)
def test_select_experts(
    m: int,
    n: int,
    e: int,
    topk: int,
    scoring_func: str,
    use_grouped_topk: bool,
    renormalize: bool,
    with_e_correction: bool,
    custom_routing: bool,
    dtype: torch.dtype,
    device: str,
):
    topk_group = 4 if use_grouped_topk else None
    num_expert_group = e // 4 if use_grouped_topk else None

    hidden_states = torch.randn(m, n, device=device, dtype=dtype)
    router_logits = torch.randn(m, e, device=device, dtype=dtype)

    e_score_correction_bias = (torch.randn(e, device=device, dtype=dtype)
                               if with_e_correction else None)

    custom_routing_function = None
    if custom_routing:
        custom_routing_function = MagicMock()
        mock_weights = torch.randn(m, topk, device=device, dtype=dtype)
        mock_ids = torch.randint(0,
                                 e, (m, topk),
                                 device=device,
                                 dtype=torch.int32)
        custom_routing_function.return_value = (mock_weights, mock_ids)

    with patch("vllm_ascend.ops.layers.experts_selector._native_grouped_topk"
               ) as mock_native_grouped_topk:
        mock_native_grouped_topk.side_effect = lambda x, num_groups, k: torch.randn_like(
            x)

        topk_weights, topk_ids, row_idx = select_experts(
            hidden_states=hidden_states,
            router_logits=router_logits,
            top_k=topk,
            use_grouped_topk=use_grouped_topk,
            renormalize=renormalize,
            topk_group=topk_group,
            num_expert_group=num_expert_group,
            custom_routing_function=custom_routing_function,
            scoring_func=scoring_func,
            e_score_correction_bias=e_score_correction_bias,
        )

        if use_grouped_topk:
            mock_native_grouped_topk.assert_called_once()
        else:
            mock_native_grouped_topk.assert_not_called()

    assert topk_weights.shape == (m, topk)
    assert topk_ids.shape == (m, topk)
    assert topk_ids.dtype == torch.int32
    assert row_idx.shape == (m, topk)

    gc.collect()
    torch.npu.empty_cache()
    torch.npu.reset_peak_memory_stats()


@pytest.mark.parametrize("device", DEVICE)
def test_select_experts_invalid_scoring_func(device: str):
    with pytest.raises(ValueError,
                       match="Unsupported scoring function: invalid"):
        select_experts(hidden_states=torch.randn(1, 128, device=device),
                       router_logits=torch.randn(1, 8, device=device),
                       top_k=2,
                       use_grouped_topk=False,
                       renormalize=False,
                       scoring_func="invalid")
    gc.collect()
    torch.npu.empty_cache()
    torch.npu.reset_peak_memory_stats()


@pytest.mark.parametrize("device", DEVICE)
def test_select_experts_missing_group_params(device: str):
    with pytest.raises(AssertionError):
        select_experts(hidden_states=torch.randn(1, 128, device=device),
                       router_logits=torch.randn(1, 64, device=device),
                       top_k=2,
                       use_grouped_topk=True,
                       renormalize=False,
                       scoring_func="softmax")
    gc.collect()
    torch.npu.empty_cache()
    torch.npu.reset_peak_memory_stats()
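The `expert_map` built in the dispatcher test remaps global expert ids to local ones: experts owned by the simulated rank keep a local index, everything else maps to -1. A worked instance for e = 8, ep_size = 4 on rank 0, for illustration only:

# global expert id:  0  1   2   3   4   5   6   7
# expert_map:       [0, 1, -1, -1, -1, -1, -1, -1]
# topk_ids [1, 5] therefore become local ids [1, -1];
# tokens routed to -1 are not processed by this rank.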
tests/e2e/singlecard/ops/test_gating_top_k_softmax.py (new file, 37 lines)
@@ -0,0 +1,37 @@
import pytest
import torch
import torch_npu


@pytest.mark.parametrize(
    'B',
    [1, 16, 64, 128, 32768],
)
@pytest.mark.parametrize(
    'D',
    [8, 16, 32, 64, 128],
)
@pytest.mark.parametrize(
    'top_k',
    [1, 2, 4, 8],
)
@pytest.mark.parametrize(
    "dtype, atol, rtol",
    [
        (torch.float16, 1e-3, 1e-3),
        (torch.bfloat16, 1e-3, 1e-3),
    ],
)
def test_gating_top_k_softmax(B: int, D: int, top_k: int, dtype, atol, rtol):
    x = torch.rand((B, D), dtype=dtype).to("npu")
    # finished = torch.randint(1, size=(B,), dtype=torch.bool).to("npu")
    finished = None
    y, expert_idx, row_idx = torch_npu.npu_moe_gating_top_k_softmax(x,
                                                                    finished,
                                                                    k=top_k)

    # Reference: plain softmax followed by top-k over the same logits.
    topk_weights = x.softmax(dim=-1)
    topk_weights, topk_ids = topk_weights.topk(top_k, dim=-1)
    topk_ids = topk_ids.to(torch.int32)
    assert torch.allclose(y, topk_weights, atol=atol, rtol=rtol)
    assert torch.allclose(expert_idx, topk_ids, atol=atol, rtol=rtol)
tests/e2e/singlecard/ops/test_moe_comm.py (new file, 175 lines)
@@ -0,0 +1,175 @@
# Copyright (c) 2025 Huawei Technologies Co., Ltd. All Rights Reserved.
# Copyright 2023 The vLLM team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# This file is a part of the vllm-ascend project.

import gc
from types import SimpleNamespace

import pytest
import torch

from vllm.model_executor.layers.fused_moe.config import (  # isort: skip
    FusedMoEConfig, FusedMoEParallelConfig)

from vllm_ascend.distributed.moe_comm_method import (  # isort: skip
    AllGatherCommImpl, NativeAllGatherCommImpl)


@pytest.mark.parametrize("num_tokens", [16, 128])
@pytest.mark.parametrize("hidden_size", [64, 128])
@pytest.mark.parametrize("global_num_experts", [8, 16])
@pytest.mark.parametrize("num_local_experts", [4, 8])
@pytest.mark.parametrize("top_k_num", [2, 4])
@pytest.mark.parametrize("dtype", [torch.bfloat16, torch.float16])
@pytest.mark.parametrize("ep_rank", [0, 1])
@pytest.mark.parametrize("apply_a8_quantization", [False])
def test_all_gather_comm_impl(
    num_tokens,
    hidden_size,
    global_num_experts,
    num_local_experts,
    top_k_num,
    dtype,
    ep_rank,
    apply_a8_quantization,
    mocker,
):
    """
    Tests the AllGatherCommImpl against the NativeAllGatherCommImpl.

    This test compares the outputs of the NPU-optimized AllGatherCommImpl
    with a native PyTorch implementation (NativeAllGatherCommImpl) to ensure
    correctness across various configurations.
    """
    if top_k_num > global_num_experts:
        pytest.skip("top_k_num cannot be greater than global_num_experts")
    if num_local_experts > global_num_experts:
        pytest.skip(
            "num_local_experts cannot be greater than global_num_experts")

    device = torch.device("npu")

    # Mock get_tensor_model_parallel_rank to return ep_rank.
    mocker.patch(
        "vllm.model_executor.layers.fused_moe.config.get_tensor_model_parallel_rank",
        return_value=ep_rank,
    )

    # Build the MoE config.
    parallel_config = SimpleNamespace(
        enable_expert_parallel=num_local_experts < global_num_experts)
    moe_parallel_config: FusedMoEParallelConfig = FusedMoEParallelConfig.make(
        tp_size_=max(2, global_num_experts // num_local_experts),
        dp_size_=1,
        vllm_parallel_config=parallel_config,
    )

    moe_config = FusedMoEConfig(
        num_experts=global_num_experts,
        experts_per_token=top_k_num,
        hidden_dim=hidden_size,
        num_local_experts=num_local_experts,
        moe_parallel_config=moe_parallel_config,
        in_dtype=dtype,
        quant_config=None,  # No quantization in this test
        max_num_tokens=num_tokens,
    )

    # Instantiate implementations
    native_impl = NativeAllGatherCommImpl(moe_config)

    all_gather_impl = AllGatherCommImpl(moe_config)

    # --- Input Data ---
    hidden_states = torch.randn(num_tokens,
                                hidden_size,
                                device=device,
                                dtype=dtype)
    topk_ids = torch.randint(0,
                             global_num_experts, (num_tokens, top_k_num),
                             device=device,
                             dtype=torch.int32)
    topk_weights = torch.rand(num_tokens, top_k_num, device=device).to(dtype)
    topk_weights = torch.nn.functional.softmax(topk_weights, dim=1)

    num_experts = global_num_experts

    expert_map = None
    if num_local_experts < global_num_experts:
        # Create a map where some experts are local and some are not
        expert_map = torch.full((global_num_experts, ), -1, device=device)
        expert_map[ep_rank * num_local_experts:(ep_rank + 1) *
                   num_local_experts] = torch.arange(num_local_experts,
                                                     device=device)
        num_experts = num_local_experts

    # --- Run Native Implementation (Golden Reference) ---
    native_hidden_states_out = hidden_states.clone()
    (
        native_permuted_hidden,
        native_expert_tokens,
        _,
        _,
    ) = native_impl.permute(hidden_states, topk_ids, topk_weights, expert_map,
                            num_experts, apply_a8_quantization)
    # Simulate MLP output
    native_mlp_output = torch.randn_like(native_permuted_hidden)
    native_impl.unpermute(native_mlp_output, native_hidden_states_out)

    # --- Run AllGather Implementation ---
    all_gather_hidden_states_out = hidden_states.clone()
    (
        all_gather_permuted_hidden,
        all_gather_expert_tokens,
        _,
        _,
    ) = all_gather_impl.permute(hidden_states, topk_ids, topk_weights,
                                expert_map, num_experts, apply_a8_quantization)

    # Use the same simulated MLP output for a fair comparison
    all_gather_mlp_output = native_mlp_output.clone()

    all_gather_impl.unpermute(all_gather_mlp_output,
                              all_gather_hidden_states_out)

    # --- Assertions ---
    # Define tolerance based on dtype
    atol = 1e-3 if dtype == torch.float16 else 1e-2
    rtol = 1e-3 if dtype == torch.float16 else 1e-2

    # 1. Compare expert_tokens from pre_process
    assert torch.allclose(native_expert_tokens.to(
        all_gather_expert_tokens.device),
                          all_gather_expert_tokens,
                          atol=atol,
                          rtol=rtol), "Expert tokens do not match."

    # 2. Compare permuted_hidden_states from pre_process
    num_valid_tokens = native_expert_tokens.sum()
    assert torch.allclose(native_permuted_hidden[:num_valid_tokens].to(
        all_gather_permuted_hidden.device),
                          all_gather_permuted_hidden[:num_valid_tokens],
                          atol=atol,
                          rtol=rtol), "Permuted hidden states do not match."

    # 3. Compare final hidden_states from post_process
    assert torch.allclose(native_hidden_states_out.to(
        all_gather_hidden_states_out.device),
                          all_gather_hidden_states_out,
                          atol=atol,
                          rtol=rtol), "Final hidden states do not match."
    gc.collect()
    torch.npu.empty_cache()
    torch.npu.reset_peak_memory_stats()
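The permute step being verified here follows the standard MoE token-sorting pattern: replicate each token once per selected expert, sort the copies by expert id, and count the tokens per expert. A generic, framework-free sketch of that contract, for illustration only (the actual NativeAllGatherCommImpl may differ in details such as expert_map and quantization handling):

def permute_by_expert(hidden_states, topk_ids, num_experts):
    # hidden_states: [T, H]; topk_ids: [T, top_k]
    top_k = topk_ids.shape[1]
    flat_ids = topk_ids.reshape(-1)  # token-major, [T * top_k]
    order = torch.argsort(flat_ids, stable=True)  # group copies by expert
    permuted = hidden_states.repeat_interleave(top_k, dim=0)[order]
    expert_tokens = torch.bincount(flat_ids.long(), minlength=num_experts)
    return permuted, expert_tokens, order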
tests/e2e/singlecard/ops/test_rotary_embedding.py (new file, 351 lines)
@@ -0,0 +1,351 @@
# Copyright 2023 The vLLM team.
# Copyright (c) Huawei Technologies Co., Ltd. 2024-2025. All rights reserved.
# Adapted from
# https://github.com/vllm-project/vllm/blob/main/vllm/tests/kernels/test_rotary_embedding.py

import gc
from typing import Optional, Tuple, Union

import pytest
import torch
import torch.nn as nn

from vllm_ascend.utils import enable_custom_op

enable_custom_op()

# Only the Neox-style (is_neox_style=True) scenario is supported for now
IS_NEOX_STYLE = [True]
DTYPES = [torch.half]
HEAD_SIZES = [64, 96, 128, 256]
ROTARY_DIMS = [None, 32]  # None means rotary dim == head size
NUM_HEADS = [17]  # Arbitrary values for testing
BATCH_SIZES = [5]  # Arbitrary values for testing
SEQ_LENS = [11, 4096]  # Arbitrary values for testing
NUM_TOKENS = [10, 21]
SEEDS = [0]
DEVICES = [f"npu:{0}"]
# Tolerances for comparing the custom op against the native reference
DEFAULT_ATOL = 1e-3
DEFAULT_RTOL = 1e-3


def _apply_rotary_emb(
    x: torch.Tensor,
    cos: torch.Tensor,
    sin: torch.Tensor,
    is_neox_style: bool,
) -> torch.Tensor:
    """
    Args:
        x: [num_tokens, num_heads, head_size]
        cos: [num_tokens, head_size // 2]
        sin: [num_tokens, head_size // 2]
        is_neox_style: Whether to use the Neox-style or GPT-J-style rotary
            positional embeddings.
    """
    cos = cos.unsqueeze(-2).to(x.dtype)
    sin = sin.unsqueeze(-2).to(x.dtype)
    if is_neox_style:
        x1, x2 = torch.chunk(x, 2, dim=-1)
    else:
        x1 = x[..., ::2]
        x2 = x[..., 1::2]
    o1 = x1 * cos - x2 * sin
    o2 = x2 * cos + x1 * sin
    if is_neox_style:
        return torch.cat((o1, o2), dim=-1)
    else:
        return torch.stack((o1, o2), dim=-1).flatten(-2)


# Adapted from https://github.com/vllm-project/vllm/vllm/model_executor/layers/rotary_embedding.py
class RotaryEmbedding(nn.Module):
    """Original rotary positional embedding."""

    def __init__(
        self,
        head_size: int,
        rotary_dim: int,
        max_position_embeddings: int,
        base: int,
        is_neox_style: bool,
        dtype: torch.dtype,
    ) -> None:
        super().__init__()
        self.head_size = head_size
        self.rotary_dim = rotary_dim
        self.max_position_embeddings = max_position_embeddings
        self.base = base
        self.is_neox_style = is_neox_style
        self.dtype = dtype

        cache = self._compute_cos_sin_cache()
        cache = cache.to(dtype)
        self.cos_sin_cache: torch.Tensor
        self.register_buffer("cos_sin_cache", cache, persistent=False)

    def _compute_inv_freq(self, base: Union[int, float]) -> torch.Tensor:
        """Compute the inverse frequency."""
        # NOTE(woosuk): To exactly match the HF implementation, we need to
        # use CPU to compute the cache and then move it to GPU. However, we
        # create the cache on GPU for faster initialization. This may cause
        # a slight numerical difference between the HF implementation and ours.
        inv_freq = 1.0 / (base**(torch.arange(
            0, self.rotary_dim, 2, dtype=torch.float) / self.rotary_dim))
        return inv_freq

    def _compute_cos_sin_cache(self) -> torch.Tensor:
        """Compute the cos and sin cache."""
        inv_freq = self._compute_inv_freq(self.base)
        t = torch.arange(self.max_position_embeddings, dtype=torch.float)

        freqs = torch.einsum("i,j -> ij", t, inv_freq)
        cos = freqs.cos()
        sin = freqs.sin()
        cache = torch.cat((cos, sin), dim=-1)
        return cache

    def forward_native(
        self,
        positions: torch.Tensor,
        query: torch.Tensor,
        key: torch.Tensor,
        offsets: Optional[torch.Tensor] = None,
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """A PyTorch-native implementation of forward()."""
        if offsets is not None:
            positions = positions + offsets
        positions = positions.flatten()
        num_tokens = positions.shape[0]
        cos_sin = self.cos_sin_cache.index_select(0, positions)
        cos, sin = cos_sin.chunk(2, dim=-1)

        query_shape = query.shape
        query = query.view(num_tokens, -1, self.head_size)
        query_rot = query[..., :self.rotary_dim]
        query_pass = query[..., self.rotary_dim:]
        query_rot = _apply_rotary_emb(query_rot, cos, sin, self.is_neox_style)
        query = torch.cat((query_rot, query_pass), dim=-1).reshape(query_shape)

        key_shape = key.shape
        key = key.view(num_tokens, -1, self.head_size)
        key_rot = key[..., :self.rotary_dim]
        key_pass = key[..., self.rotary_dim:]
        key_rot = _apply_rotary_emb(key_rot, cos, sin, self.is_neox_style)
        key = torch.cat((key_rot, key_pass), dim=-1).reshape(key_shape)
        return query, key


# Test with a leading dimension, merging batch_size and seq_len into num_tokens.
@pytest.mark.parametrize("is_neox_style", IS_NEOX_STYLE)
@pytest.mark.parametrize("batch_size", BATCH_SIZES)
@pytest.mark.parametrize("seq_len", SEQ_LENS)
@pytest.mark.parametrize("num_heads", NUM_HEADS)
@pytest.mark.parametrize("head_size", HEAD_SIZES)
@pytest.mark.parametrize("rotary_dim", ROTARY_DIMS)
@pytest.mark.parametrize("dtype", DTYPES)
@pytest.mark.parametrize("seed", SEEDS)
@pytest.mark.parametrize("device", DEVICES)
@torch.inference_mode()
def test_rotary_embedding_quant_with_leading_dim(
    is_neox_style: bool,
    batch_size: int,
    seq_len: int,
    num_heads: int,
    head_size: int,
    rotary_dim: Optional[int],
    dtype: torch.dtype,
    seed: int,
    device: str,
    max_position: int = 8192,
    base: int = 10000,
) -> None:
    if rotary_dim is None:
        rotary_dim = head_size

    torch.set_default_device(device)
    rope = RotaryEmbedding(head_size, rotary_dim, max_position, base,
                           is_neox_style, dtype)
    rope = rope.to(dtype=dtype)
    num_tokens = batch_size * seq_len
    positions = torch.randint(0, max_position, (batch_size * seq_len, ))
    qkv_tensor = torch.randn(num_tokens,
                             num_heads * head_size * 3,
                             dtype=dtype)
    query, key, _ = qkv_tensor.split(
        [num_heads * head_size, num_heads * head_size, num_heads * head_size],
        dim=-1,
    )

    ref_query, ref_key = rope.forward_native(positions, query, key)
    query, key = torch.ops._C.rotary_embedding(
        positions,
        query,
        key,
        rope.head_size,
        rope.cos_sin_cache,
        rope.is_neox_style,
    )

    # Compare the results.
    torch.testing.assert_close(query.view(ref_query.size()),
                               ref_query,
                               atol=DEFAULT_ATOL,
                               rtol=DEFAULT_RTOL)
    torch.testing.assert_close(key.view(ref_key.size()),
                               ref_key,
                               atol=DEFAULT_ATOL,
                               rtol=DEFAULT_RTOL)
    gc.collect()
    torch.npu.empty_cache()
    torch.npu.reset_peak_memory_stats()


class ModelwithRotaryEmbedding(nn.Module):

    def __init__(
        self,
        hidden_size: int,
        num_heads: int,
        head_size: int,
        rotary_dim: int,
        max_position_embeddings: int,
        base: int,
        is_neox_style: bool,
        dtype: torch.dtype,
    ) -> None:
        super().__init__()
        self.qkv_proj = nn.Linear(hidden_size, num_heads * head_size * 3)
        self.rope = RotaryEmbedding(
            head_size=head_size,
            rotary_dim=rotary_dim,
            max_position_embeddings=max_position_embeddings,
            base=base,
            is_neox_style=is_neox_style,
            dtype=dtype,
        )
        self.o_proj = nn.Linear(num_heads * head_size, hidden_size)

    def forward(
        self,
        positions: torch.Tensor,
        hidden_states: torch.Tensor,
        offsets: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        # We simulate a simple attention layer to test whether it can be
        # seamlessly captured into an aclgraph (v is unused in this
        # simplified layer).
        qkv = self.qkv_proj(hidden_states)
        q, k, v = qkv.chunk(3, dim=-1)
        query, key = torch.ops._C.rotary_embedding(
            positions,
            q,
            k,
            self.rope.head_size,
            self.rope.cos_sin_cache,
            self.rope.is_neox_style,
        )
        query = query.view(q.shape)
        key = key.view(k.shape)
        o = self.o_proj(query)
        return o


# The first captured graph can show accuracy issues when pytest is run
# directly on the ops folder; as a workaround, treat the first graph replay
# as a warmup and skip its verification.
ACL_GRAPH_FIRST_RUN = True


@pytest.mark.parametrize("is_neox_style", IS_NEOX_STYLE)
@pytest.mark.parametrize("num_tokens", BATCH_SIZES)
@pytest.mark.parametrize("num_heads", NUM_HEADS)
@pytest.mark.parametrize("head_size", HEAD_SIZES)
@pytest.mark.parametrize("rotary_dim", ROTARY_DIMS)
@pytest.mark.parametrize("dtype", DTYPES)
@pytest.mark.parametrize("seed", SEEDS)
@pytest.mark.parametrize("device", DEVICES)
@torch.inference_mode()
def test_capture_rotary_embedding_in_aclgraph(
    is_neox_style: bool,
    num_tokens: int,
    num_heads: int,
    head_size: int,
    rotary_dim: int,
    dtype: torch.dtype,
    seed: int,
    device: str,
    max_position_embeddings: int = 8192,
    base: int = 10000,
):
    """Test whether the rotary embedding can be captured in an aclgraph."""
    torch.manual_seed(seed)
    torch.set_default_device(device)
    if rotary_dim is None:
        rotary_dim = head_size
    model = ModelwithRotaryEmbedding(
        hidden_size=num_heads * head_size,
        num_heads=num_heads,
        head_size=head_size,
        rotary_dim=rotary_dim,
        max_position_embeddings=max_position_embeddings,
        base=base,
        is_neox_style=is_neox_style,
        dtype=dtype,
    )

    def custom_op_checking_backend(gm: torch.fx.GraphModule, example_input):
        # Validate that the rotary_embedding custom kernel is indeed inside
        # the graph by string matching.
        graph = str(gm.graph)
        assert "_C.rotary_embedding" in graph
        return gm

    static_positions = torch.randint(0, max_position_embeddings,
                                     (num_tokens, ))
    static_hidden_states = torch.randn(num_tokens,
                                       num_heads * head_size,
                                       dtype=dtype,
                                       device="npu")
    compiled_model = torch.compile(model, backend=custom_op_checking_backend)
    stream = torch.npu.Stream()
    stream.wait_stream(torch.npu.current_stream())
    with torch.npu.stream(stream):
        # Warm up the fx graph before capture.
        for i in range(3):
            static_output = compiled_model(static_positions,
                                           static_hidden_states,
                                           offsets=None)
    stream.wait_stream(torch.npu.current_stream())

    aclgraph = torch.npu.NPUGraph()

    with torch.npu.graph(aclgraph):
        # Capture the model in aclgraph.
        static_output = compiled_model(static_positions, static_hidden_states)
    # Refill the static input buffers with fresh random data before replay.
    random_filled_positions = torch.randint(0,
                                            max_position_embeddings,
                                            (num_tokens, ),
                                            device="npu")
    random_filled_hidden_states = torch.randn(num_tokens,
                                              num_heads * head_size,
                                              dtype=dtype,
                                              device="npu")
    static_positions.copy_(random_filled_positions)
    static_hidden_states.copy_(random_filled_hidden_states)

    aclgraph.replay()
    global ACL_GRAPH_FIRST_RUN
    if ACL_GRAPH_FIRST_RUN:
        ACL_GRAPH_FIRST_RUN = False
        return
    output_reference = model(static_positions, static_hidden_states)
    torch.testing.assert_close(static_output,
                               output_reference,
                               atol=DEFAULT_ATOL,
                               rtol=DEFAULT_RTOL)
    gc.collect()
    torch.npu.empty_cache()
    torch.npu.reset_peak_memory_stats()
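The two rotary conventions handled by `_apply_rotary_emb` differ only in how the last dimension is split into (x1, x2) pairs. A worked example for head_size = 4 with per-pair frequencies (cos1, sin1) and (cos2, sin2), derived from the code above:

# Neox style (chunk in halves): x1 = [a, b], x2 = [c, d]
#   out = [a*cos1 - c*sin1, b*cos2 - d*sin2,
#          c*cos1 + a*sin1, d*cos2 + b*sin2]
# GPT-J style (even/odd interleave): x1 = [a, c], x2 = [b, d]
#   out = [a*cos1 - b*sin1, b*cos1 + a*sin1,
#          c*cos2 - d*sin2, d*cos2 + c*sin2]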
tests/e2e/singlecard/ops/test_vocabparallelembedding.py (new file, 98 lines)
@@ -0,0 +1,98 @@
import gc
from typing import Tuple

import pytest
import torch
import torch_npu  # noqa: F401

import vllm_ascend.platform  # noqa: F401
from vllm_ascend.utils import enable_custom_op

enable_custom_op()

# Test parameters
DTYPES = [torch.int32]
# SHAPES = [(100,), (5, 20), (3, 4, 5)]  # Various tensor shapes
# SHAPES = [(3, 4, 8), (3, 4, 5)]  # Various tensor shapes
SHAPES = [(3, 4, 3)]
DEVICES = [f"npu:{0}"]
SEEDS = [0]


def get_masked_input_and_mask_ref(
        input_: torch.Tensor, org_vocab_start_index: int,
        org_vocab_end_index: int, num_org_vocab_padding: int,
        added_vocab_start_index: int,
        added_vocab_end_index: int) -> Tuple[torch.Tensor, torch.Tensor]:
    """Reference implementation for verification"""
    org_vocab_mask = (input_ >= org_vocab_start_index) & (
        input_ < org_vocab_end_index)
    added_vocab_mask = (input_ >= added_vocab_start_index) & (
        input_ < added_vocab_end_index)
    added_offset = added_vocab_start_index - (
        org_vocab_end_index - org_vocab_start_index) - num_org_vocab_padding
    valid_offset = (org_vocab_start_index *
                    org_vocab_mask) + (added_offset * added_vocab_mask)
    vocab_mask = org_vocab_mask | added_vocab_mask
    masked_input = vocab_mask * (input_ - valid_offset)
    return masked_input, ~vocab_mask


@pytest.mark.parametrize("shape", SHAPES)
@pytest.mark.parametrize("dtype", DTYPES)
@pytest.mark.parametrize("device", DEVICES)
@pytest.mark.parametrize("seed", SEEDS)
@torch.inference_mode()
def test_get_masked_input_and_mask(
    shape: Tuple[int, ...],
    dtype: torch.dtype,
    device: str,
    seed: int,
) -> None:
    # Set random seed
    torch.manual_seed(seed)
    torch.set_default_device(device)

    # Generate random input tensor
    input_tensor = torch.randint(0, 1000, shape, dtype=dtype)

    # Test parameters
    test_case = {
        "org_start": 100,
        "org_end": 200,
        "padding": 0,
        "added_start": 300,
        "added_end": 400,
    }

    # Get reference result
    ref_masked_input, ref_mask = get_masked_input_and_mask_ref(
        input_tensor, test_case["org_start"], test_case["org_end"],
        test_case["padding"], test_case["added_start"], test_case["added_end"])

    # Get custom op result
    print("input_tensor:", input_tensor)
    custom_masked_input, custom_mask = torch.ops._C.get_masked_input_and_mask(
        input_tensor, test_case["org_start"], test_case["org_end"],
        test_case["padding"], test_case["added_start"], test_case["added_end"])

    ref_masked_input = ref_masked_input.to(dtype)
    print("custom_masked_input:", custom_masked_input)
    print("ref_masked_input:", ref_masked_input)
    print("custom_mask:", custom_mask)
    print("ref_mask:", ref_mask)
    # Compare results
    torch.testing.assert_close(
        custom_masked_input,
        ref_masked_input,
        rtol=1e-5,
        atol=1e-5,
        msg=f"Masked input mismatch for case: {test_case}")
    torch.testing.assert_close(custom_mask,
                               ref_mask,
                               rtol=1e-5,
                               atol=1e-5,
                               msg=f"Mask mismatch for case: {test_case}")
    gc.collect()
    torch.npu.empty_cache()
    torch.npu.reset_peak_memory_stats()
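To make the remapping concrete: with org vocab [100, 200), zero padding, and added vocab [300, 400) as in the test case above, a few sample ids behave as follows (worked from the reference implementation):

# id 150 (org vocab):    masked_input = 150 - 100 = 50,   mask = False
# id 350 (added vocab):  offset = 300 - (200 - 100) - 0 = 200,
#                        masked_input = 350 - 200 = 150,  mask = False
# id 50  (out of range): masked_input = 0,                mask = True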