[1/N] Refactor nightly test structure (#5479)

### What this PR does / why we need it?
This patch is the first in a series of refactoring changes that clarify the
directory structure of the nightly tests, refactor the config retrieval
logic, and optimize the workflow. This first step restructures the nightly
test directory to make it more readable and logical.

- vLLM version: v0.13.0
- vLLM main:
5326c89803

Signed-off-by: wangli <wangli858794774@gmail.com>
Author: Li Wang
Date: 2025-12-30 19:03:02 +08:00
Committed by: GitHub
Parent: c85cc045f8
Commit: e760aae1df
59 changed files with 475 additions and 471 deletions


@@ -0,0 +1,439 @@
import gc
import os
import sys
from pathlib import Path
import numpy as np
import torch
import torch.distributed as dist
import torch.multiprocessing as mp
import torch_npu
import torchair
from vllm_ascend.utils import enable_custom_op
torch.manual_seed(42)
torch_npu.npu.config.allow_internal_format = True
enable_custom_op()
LOG_NAME = "dispatch_gmm_combine_decode_test_logs"
BASE_KWARGS = {
"batch_size": 64,
"token_hidden_size": 7168,
"moe_intermediate_size": 2048,
"ep_world_size": 16,
"moe_expert_num": 64,
"shared_expert_rank_num": 0,
"top_k": 8,
"test_bfloat16": True,
"enable_dynamic_bs": False,
"test_graph": False,
"with_mc2_mask": False
}
def redirect_output(log_file_path):
log_path = Path(LOG_NAME) / log_file_path
log_path.parent.mkdir(parents=True, exist_ok=True)
    f = open(log_path, "w")
os.dup2(f.fileno(), sys.stdout.fileno())
os.dup2(f.fileno(), sys.stderr.fileno())
return f
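# permute_weight interleaves the two halves of the weight's last dimension in
# blocks of tile_n // 2 (e.g. the gate/up halves of a fused SwiGLU weight).
# Illustrative example: with n = 8 and tile_n = 4, the last-dim order
# [0, 1, 2, 3, 4, 5, 6, 7] becomes [0, 1, 4, 5, 2, 3, 6, 7].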
def permute_weight(w: torch.Tensor, tile_n):
*dims, n = w.shape
order = list(range(len(dims))) + [-2, -3, -1]
return w.reshape(*dims, 2, n // tile_n,
tile_n // 2).permute(order).reshape(*dims,
n).contiguous()
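# from_inclusive_prefix_sum converts an inclusive prefix sum into per-group
# counts, e.g. [2, 5, 9] -> [2, 3, 4]. It is applied to the small-op count
# output below before it is compared with the fused op's counts.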
def from_inclusive_prefix_sum(pref):
if isinstance(pref, torch.Tensor):
if pref.numel() == 0:
return pref
return torch.cat([pref[:1], pref[1:] - pref[:-1]])
if not pref:
return []
out = [pref[0]]
for i in range(1, len(pref)):
out.append(pref[i] - pref[i - 1])
return out
def output_to_file(rank_id):
return False
class DecodeMoeOps(torch.nn.Module):
def __init__(self,
gmm1_weight,
gmm1_weight_scale,
gmm2_weight,
gmm2_weight_scale,
ep_hcomm_info,
batch_size,
token_hidden_size,
moe_intermediate_size,
ep_world_size,
moe_expert_num,
global_rank_id,
shared_expert_rank_num=0):
super().__init__()
self.ep_hcomm_info = ep_hcomm_info
self.batch_size = batch_size
self.token_hidden_size = token_hidden_size
self.moe_intermediate_size = moe_intermediate_size
self.ep_world_size = ep_world_size
self.moe_expert_num = moe_expert_num
self.global_rank_id = global_rank_id
self.shared_expert_rank_num = shared_expert_rank_num
is_shared_expert = global_rank_id < shared_expert_rank_num
moe_expert_num_per_rank = moe_expert_num // (ep_world_size -
shared_expert_rank_num)
self.local_expert_num = 1 if is_shared_expert else moe_expert_num_per_rank
self.ep_recv_count_size = self.local_expert_num * ep_world_size
self.gmm1_weight = torch.empty([
self.local_expert_num, self.token_hidden_size,
self.moe_intermediate_size * 2
])
self.gmm1_weight_scale = torch.empty(
[self.local_expert_num, self.moe_intermediate_size * 2])
self.gmm2_weight = torch.empty([
self.local_expert_num, self.moe_intermediate_size,
self.token_hidden_size
])
self.gmm2_weight_scale = torch.empty(
[self.local_expert_num, self.token_hidden_size])
self._process_weights_after_loading(gmm1_weight, gmm1_weight_scale,
gmm2_weight, gmm2_weight_scale)
def _process_weights_after_loading(self, gmm1_weight, gmm1_weight_scale,
gmm2_weight, gmm2_weight_scale):
gmm1_weight = torch_npu.npu_format_cast(gmm1_weight,
torch_npu.Format.FRACTAL_NZ)
gmm2_weight = torch_npu.npu_format_cast(gmm2_weight,
torch_npu.Format.FRACTAL_NZ)
self.gmm1_weight = torch.nn.Parameter(gmm1_weight, requires_grad=False)
self.gmm1_weight_scale = torch.nn.Parameter(gmm1_weight_scale,
requires_grad=False)
self.gmm2_weight = torch.nn.Parameter(gmm2_weight, requires_grad=False)
self.gmm2_weight_scale = torch.nn.Parameter(gmm2_weight_scale,
requires_grad=False)
self.gmm1_weight_scale_fp32 = torch.nn.Parameter(
gmm1_weight_scale.float(), requires_grad=False)
self.gmm2_weight_scale_fp32 = torch.nn.Parameter(
gmm2_weight_scale.float(), requires_grad=False)
def _apply_ops(self, x, expert_ids, smooth_scales, expert_scales,
x_active_mask):
raise NotImplementedError("To be implemented in subclass")
def forward(self, x, expert_ids, smooth_scales, expert_scales,
x_active_mask):
return self._apply_ops(x, expert_ids, smooth_scales, expert_scales,
x_active_mask)
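# SmallOps is the reference path assembled from individual NPU ops:
# npu_moe_distribute_dispatch_v2 -> npu_grouped_matmul (gmm1) ->
# npu_dequant_swiglu_quant -> npu_grouped_matmul (gmm2) ->
# npu_moe_distribute_combine_v2. Its outputs are compared against the fused
# dispatch_gmm_combine_decode custom op wrapped by FusionOp below.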
class SmallOps(DecodeMoeOps):
def __init__(self,
gmm1_weight,
gmm1_weight_scale,
gmm2_weight,
gmm2_weight_scale,
ep_hcomm_info,
batch_size,
token_hidden_size,
moe_intermediate_size,
ep_world_size,
moe_expert_num,
global_rank_id,
shared_expert_rank_num=0):
super().__init__(gmm1_weight, gmm1_weight_scale, gmm2_weight,
gmm2_weight_scale, ep_hcomm_info, batch_size,
token_hidden_size, moe_intermediate_size,
ep_world_size, moe_expert_num, global_rank_id,
shared_expert_rank_num)
self.tp_hcomm_info = ""
def _apply_ops(self, x, expert_ids, smooth_scales, expert_scales,
x_active_mask):
outputs = torch_npu.npu_moe_distribute_dispatch_v2(
x=x,
expert_ids=expert_ids,
expert_scales=expert_scales,
x_active_mask=x_active_mask,
group_ep=self.ep_hcomm_info,
ep_world_size=self.ep_world_size,
ep_rank_id=self.global_rank_id,
moe_expert_num=self.moe_expert_num,
group_tp=self.tp_hcomm_info,
tp_world_size=1,
tp_rank_id=0,
expert_shard_type=0,
shared_expert_num=1,
shared_expert_rank_num=self.shared_expert_rank_num,
quant_mode=2,
global_bs=self.batch_size * self.ep_world_size,
            expert_token_nums_type=1,  # 0: prefix-sum form, 1: per-expert token counts
)
expand_x, dynamic_scales, assist_info_for_combine, expert_token_nums, ep_send_counts, tp_send_counts, expand_scales = outputs
output_dtype = x.dtype
y1_int32 = torch_npu.npu_grouped_matmul(
x=[expand_x],
weight=[self.gmm1_weight],
split_item=3,
            group_list_type=1,  # default 0 means prefix-sum form
            group_type=0,  # 0 means grouping along the m axis
group_list=expert_token_nums,
output_dtype=torch.int32)[0]
y1, y1_scale = torch_npu.npu_dequant_swiglu_quant(
x=y1_int32,
weight_scale=self.gmm1_weight_scale.to(torch.float32),
activation_scale=dynamic_scales,
bias=None,
quant_scale=None,
quant_offset=None,
group_index=expert_token_nums,
activate_left=True,
quant_mode=1,
)
y2 = torch_npu.npu_grouped_matmul(x=[y1],
weight=[self.gmm2_weight],
scale=[self.gmm2_weight_scale],
per_token_scale=[y1_scale],
split_item=2,
group_list_type=1,
group_type=0,
group_list=expert_token_nums,
output_dtype=output_dtype)[0]
combine_output = torch_npu.npu_moe_distribute_combine_v2(
expand_x=y2,
expert_ids=expert_ids,
assist_info_for_combine=assist_info_for_combine,
ep_send_counts=ep_send_counts,
expert_scales=expert_scales,
x_active_mask=x_active_mask,
group_ep=self.ep_hcomm_info,
ep_world_size=self.ep_world_size,
ep_rank_id=self.global_rank_id,
moe_expert_num=self.moe_expert_num,
tp_send_counts=tp_send_counts,
expand_scales=expand_scales,
group_tp=self.tp_hcomm_info,
tp_world_size=1,
tp_rank_id=0,
expert_shard_type=0,
shared_expert_num=1,
shared_expert_rank_num=self.shared_expert_rank_num,
global_bs=self.batch_size * self.ep_world_size)
return (combine_output, ep_send_counts[:self.ep_recv_count_size])
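# FusionOp runs the same decode MoE path through the single fused custom op
# torch.ops._C_ascend.dispatch_gmm_combine_decode.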
class FusionOp(DecodeMoeOps):
def __init__(self,
gmm1_weight,
gmm1_weight_scale,
gmm2_weight,
gmm2_weight_scale,
ep_hcomm_info,
batch_size,
token_hidden_size,
moe_intermediate_size,
ep_world_size,
moe_expert_num,
global_rank_id,
shared_expert_rank_num=0):
super().__init__(gmm1_weight, gmm1_weight_scale, gmm2_weight,
gmm2_weight_scale, ep_hcomm_info, batch_size,
token_hidden_size, moe_intermediate_size,
ep_world_size, moe_expert_num, global_rank_id,
shared_expert_rank_num)
def _apply_ops(self, x, expert_ids, smooth_scales, expert_scales,
x_active_mask):
output = torch.ops._C_ascend.dispatch_gmm_combine_decode(
x=x,
expert_ids=expert_ids,
gmm1_permuted_weight=self.gmm1_weight,
gmm1_permuted_weight_scale=self.gmm1_weight_scale_fp32,
gmm2_weight=self.gmm2_weight,
gmm2_weight_scale=self.gmm2_weight_scale_fp32,
expert_scales=expert_scales,
expert_smooth_scales=smooth_scales,
x_active_mask=x_active_mask,
group_ep=self.ep_hcomm_info,
ep_rank_size=self.ep_world_size,
ep_rank_id=self.global_rank_id,
moe_expert_num=self.moe_expert_num,
shared_expert_num=1,
shared_expert_rank_num=self.shared_expert_rank_num,
quant_mode=0,
global_bs=self.batch_size * self.ep_world_size)
return output
def generate_datas(batch_size,
token_hidden_size,
moe_intermediate_size,
ep_world_size,
moe_expert_num,
global_rank_id,
shared_expert_rank_num=0,
top_k=8,
test_bfloat16=True,
enable_dynamic_bs=False,
with_mc2_mask=False):
is_shared_expert = global_rank_id < shared_expert_rank_num
moe_expert_num_per_rank = moe_expert_num // (ep_world_size -
shared_expert_rank_num)
actual_bs = int(
torch.randint(2 if with_mc2_mask else 1, batch_size, [1]).item(
) if enable_dynamic_bs else batch_size)
local_expert_num = 1 if is_shared_expert else moe_expert_num_per_rank
gmm1_input_dim = token_hidden_size
gmm1_output_dim = moe_intermediate_size * 2
gmm2_input_dim = moe_intermediate_size
gmm2_output_dim = token_hidden_size
x = torch.rand([actual_bs, token_hidden_size]) * 10 - 5
expert_ids = torch.arange(
global_rank_id * batch_size * top_k,
global_rank_id * batch_size * top_k + actual_bs * top_k).to(
torch.int32).view(actual_bs, top_k)
expert_ids = expert_ids % moe_expert_num
if is_shared_expert:
gmm1_weight = torch.ones([
local_expert_num, gmm1_input_dim, gmm1_output_dim
]).to(torch.int8) * 4
gmm2_weight = torch.ones([
local_expert_num, gmm2_input_dim, gmm2_output_dim
]).to(torch.int8) * 4
gmm1_weight[:, :, ::2] = gmm1_weight[:, :, ::2] * -1
gmm2_weight[:, :, ::2] = gmm2_weight[:, :, ::2] * -1
gmm1_weight_scale = torch.ones([local_expert_num, gmm1_output_dim
]) * 0.0015
gmm2_weight_scale = torch.ones([local_expert_num, gmm2_output_dim
]) * 0.0015
else:
gmm1_weight = torch.randint(
-16, 16,
[local_expert_num, gmm1_input_dim, gmm1_output_dim]).to(torch.int8)
gmm2_weight = torch.randint(
-16, 16,
[local_expert_num, gmm2_input_dim, gmm2_output_dim]).to(torch.int8)
gmm1_weight_scale = torch.rand([local_expert_num, gmm1_output_dim
]) * 0.003 + 0.0015
gmm2_weight_scale = torch.rand([local_expert_num, gmm2_output_dim
]) * 0.003 + 0.0015
expert_scales = torch.rand(actual_bs, top_k)
if test_bfloat16:
x = x.bfloat16()
gmm1_weight_scale = gmm1_weight_scale.bfloat16()
gmm2_weight_scale = gmm2_weight_scale.bfloat16()
else:
x = x.half()
    smooth_scales = None
x_active_mask = None
valid_token_num = actual_bs
if with_mc2_mask:
valid_token_num = int(torch.randint(1, actual_bs, [1]).item())
x_active_mask = torch.cat(
(torch.ones(valid_token_num),
torch.zeros(actual_bs - valid_token_num))).bool()
    return (x, expert_ids, smooth_scales, expert_scales, x_active_mask), \
(gmm1_weight, gmm1_weight_scale, gmm2_weight, gmm2_weight_scale), \
actual_bs, valid_token_num
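# run_once is spawned once per rank: it sets up the HCCL process group,
# generates one set of inputs shared by both implementations, runs SmallOps
# and FusionOp, converts the small-op counts with from_inclusive_prefix_sum,
# and compares the token outputs (atol=2.0, rtol=0.02) and the expert token
# counts.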
def run_once(local_rank_id,
batch_size,
token_hidden_size,
moe_intermediate_size,
ep_world_size,
moe_expert_num,
shared_expert_rank_num=0,
top_k=8,
test_bfloat16=True,
enable_dynamic_bs=False,
test_graph=False,
with_mc2_mask=False):
log_file = redirect_output(f"local_rank_{local_rank_id}.log"
) if output_to_file(local_rank_id) else None
    global_rank_id = local_rank_id  # single node
device_id = local_rank_id % 16
torch_npu.npu.set_device(device_id)
    # initialize the distributed environment
os.environ["MASTER_ADDR"] = "127.0.0.1"
os.environ["MASTER_PORT"] = "29500" # 端口号随意
dist.init_process_group(backend="hccl",
rank=local_rank_id,
world_size=ep_world_size)
ep_ranks_list = list(np.arange(0, ep_world_size))
ep_group = dist.new_group(backend="hccl", ranks=ep_ranks_list)
ep_group_small = dist.new_group(backend="hccl", ranks=ep_ranks_list)
ep_hcomm_info_fused = ep_group._get_backend(
torch.device("npu")).get_hccl_comm_name(local_rank_id)
ep_hcomm_info_small = ep_group_small._get_backend(
torch.device("npu")).get_hccl_comm_name(local_rank_id)
torch_npu.npu.synchronize(device_id)
parameter = (batch_size, token_hidden_size, moe_intermediate_size,
ep_world_size, moe_expert_num, global_rank_id,
shared_expert_rank_num)
input_datas, weight_datas, actual_bs, valid_token_num = generate_datas(
*parameter, top_k, test_bfloat16, enable_dynamic_bs, with_mc2_mask)
input_datas = [
data.npu() if data is not None else None for data in input_datas
]
weight_datas = [
data.npu() if data is not None else None for data in weight_datas
]
small_ops = SmallOps(*weight_datas, ep_hcomm_info_small,
*parameter).npu() # type: ignore
fused_ops = FusionOp(*weight_datas, ep_hcomm_info_fused,
*parameter).npu() # type: ignore
if test_graph:
config = torchair.CompilerConfig()
config.mode = "reduce-overhead"
npu_backend = torchair.get_npu_backend(compiler_config=config)
fused_ops = torch.compile(fused_ops, backend=npu_backend)
small_op_token_output, small_op_count_output = small_ops(*input_datas)
fused_op_token_output, fused_op_count_output = fused_ops(*input_datas)
torch_npu.npu.synchronize(device_id)
dist.destroy_process_group()
if log_file is not None:
log_file.close()
small_op_count_output = from_inclusive_prefix_sum(small_op_count_output)
torch.testing.assert_close(small_op_token_output[0:valid_token_num].cpu(),
fused_op_token_output[0:valid_token_num].cpu(),
atol=2.0,
rtol=0.02)
torch.testing.assert_close(small_op_count_output.cpu(),
fused_op_count_output.cpu())
gc.collect()
torch.npu.empty_cache()
torch.npu.reset_peak_memory_stats()
@torch.inference_mode()
def test_dispatch_gmm_combine_decode_base():
custom_kwargs = BASE_KWARGS
ep_world_size = custom_kwargs["ep_world_size"]
custom_args = tuple(custom_kwargs.values())
mp.spawn(run_once, args=custom_args, nprocs=ep_world_size, join=True)
def test_dispatch_gmm_combine_decode_with_mc2_mask():
    custom_kwargs = BASE_KWARGS.copy()
custom_kwargs["with_mc2_mask"] = True
ep_world_size = custom_kwargs["ep_world_size"]
custom_args = tuple(custom_kwargs.values())
mp.spawn(run_once, args=custom_args, nprocs=ep_world_size, join=True)


@@ -0,0 +1,141 @@
import random
import unittest
import torch
from vllm_ascend.utils import enable_custom_op
enable_custom_op()
torch.set_printoptions(threshold=float("inf"))
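# The golden reference computes res[i, j, :] = a[i, j, :] @ b_tensor[j]
# (a per-row weight matrix) via torch.bmm on transposed views; the result is
# compared against the custom torch.ops._C_ascend.batch_matmul_transpose op.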
class TestMatrixMultiplication(unittest.TestCase):
def compute_golden(self, a, b, res1, m, n):
"""Compute reference result (golden)"""
torch.bmm(a.transpose(0, 1),
b,
out=res1.view(-1, m, n).transpose(0, 1))
def assert_tensors_almost_equal(self, actual, expected, dtype):
"""Check if two tensors are approximately equal (considering floating point errors)"""
self.assertEqual(actual.shape, expected.shape, "Shape mismatch")
# Check for NaN
self.assertFalse(
torch.isnan(actual).any(), "Actual result contains NaN")
self.assertFalse(
torch.isnan(expected).any(), "Expected result contains NaN")
# Check for Inf
self.assertFalse(
torch.isinf(actual).any(), "Actual result contains Inf")
self.assertFalse(
torch.isinf(expected).any(), "Expected result contains Inf")
# Set different tolerances based on data type
if dtype == torch.float16:
rtol, atol = 1e-5, 1e-5
else: # bfloat16
rtol, atol = 1.5e-5, 1.5e-5
# Compare values
diff = torch.abs(actual - expected)
max_diff = diff.max().item()
max_expected = torch.abs(expected).max().item()
# Check relative and absolute errors
if max_expected > 0:
relative_diff = max_diff / max_expected
self.assertLessEqual(
relative_diff,
rtol,
f"Relative error too large: {relative_diff} > {rtol}. Max difference: {max_diff}",
)
self.assertLessEqual(max_diff, atol,
f"Absolute error too large: {max_diff} > {atol}")
def test_boundary_conditions(self):
"""Test boundary conditions"""
test_cases = [
# (b, m, k, n)
(1, 1, 1, 1), # Minimum size
(1, 10, 1, 1), # b=1
(10, 1, 1, 10), # m=1
(5, 5, 1, 5), # k=1
(2, 2, 2, 1), # n=1
(100, 1, 1, 100), # Flat case
(1, 100, 100, 1), # Flat case
(2, 3, 4, 5), # Random small size
(10, 20, 30, 40), # Medium size
(36, 128, 512, 128), # target case
(8, 160, 512, 128),
]
dtypes = [torch.float16, torch.bfloat16]
for dtype in dtypes:
for b, m, k, n in test_cases:
with self.subTest(dtype=dtype, shape=f"({b}, {m}, {k}, {n})"):
a = torch.randn(b, m, k, dtype=dtype, device="npu")
b_tensor = torch.randn(m, k, n, dtype=dtype, device="npu")
res1 = torch.empty((b, m * n), dtype=dtype, device="npu")
res2 = torch.empty((b, m, n), dtype=dtype, device="npu")
self.compute_golden(a, b_tensor, res1, m, n)
torch.ops._C_ascend.batch_matmul_transpose(
a, b_tensor, res2)
self.assert_tensors_almost_equal(res1.view(-1, m, n), res2,
dtype)
def test_random_shapes(self):
"""Test randomly generated shapes"""
num_tests = 1
dtypes = [torch.float16, torch.bfloat16]
for dtype in dtypes:
for _ in range(num_tests):
# Generate reasonable random sizes
b = random.randint(1, 500)
m = random.randint(1, 500)
k = random.randint(1, 500)
n = random.randint(1, 500)
with self.subTest(dtype=dtype,
shape=f"Random ({b}, {m}, {k}, {n})"):
a = torch.randn(b, m, k, dtype=dtype, device="npu")
b_tensor = torch.randn(m, k, n, dtype=dtype, device="npu")
res1 = torch.empty((b, m * n), dtype=dtype, device="npu")
res2 = torch.empty((b, m, n), dtype=dtype, device="npu")
self.compute_golden(a, b_tensor, res1, m, n)
torch.ops._C_ascend.batch_matmul_transpose(
a, b_tensor, res2)
self.assert_tensors_almost_equal(res1.view(-1, m, n), res2,
dtype)
def test_zero_values(self):
"""Test zero input values"""
dtypes = [torch.float16, torch.bfloat16]
b, m, k, n = 5, 4, 3, 2
for dtype in dtypes:
with self.subTest(dtype=dtype):
a = torch.zeros(b, m, k, dtype=dtype, device="npu")
b_tensor = torch.zeros(m, k, n, dtype=dtype, device="npu")
res1 = torch.empty((b, m * n), dtype=dtype, device="npu")
res2 = torch.empty((b, m, n), dtype=dtype, device="npu")
self.compute_golden(a, b_tensor, res1, m, n)
torch.ops._C_ascend.batch_matmul_transpose(a, b_tensor, res2)
self.assert_tensors_almost_equal(res1.view(-1, m, n), res2,
dtype)
self.assertTrue(torch.all(res2 == 0))
if __name__ == "__main__":
unittest.main(verbosity=2)


@@ -0,0 +1,46 @@
import gc
import torch
from vllm_ascend.utils import enable_custom_op
enable_custom_op()
DEFAULT_ATOL = 1e-3
DEFAULT_RTOL = 1e-3
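# CPU reference for the bgmv expand op: for each token i, select its LoRA
# matrix w[indices[i]], compute x[i] @ w[indices[i]].T, and accumulate the
# result into the slice y[:, slice_offset:slice_offset + slice_size].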
def bgmv_expand_cpu_impl(x: torch.Tensor, w: torch.Tensor,
                         indices: torch.Tensor, y: torch.Tensor,
slice_offset: int, slice_size: int) -> torch.Tensor:
W = w[indices, :, :].transpose(-1, -2).to(torch.float32)
z = torch.bmm(x.unsqueeze(1).to(torch.float32), W).squeeze()
y[:, slice_offset:slice_offset + slice_size] += z
return y
@torch.inference_mode()
def test_bgmv_expand():
B = 1
x = torch.randn([B, 16], dtype=torch.float)
w = torch.randn([64, 128, 16], dtype=torch.float16)
indices = torch.zeros([B], dtype=torch.int64)
y = torch.randn([B, 128 * 3], dtype=torch.float16)
x_npu = x.npu()
w_npu = w.npu()
indices_npu = indices.npu()
y_npu = y.npu()
y_out = bgmv_expand_cpu_impl(x, w, indices, y, 0, 128)
y_out_npu = torch.ops._C_ascend.bgmv_expand(x_npu, w_npu, indices_npu,
y_npu, 0, 128)
# Compare the results.
torch.testing.assert_close(y_out_npu.cpu(),
y_out,
atol=DEFAULT_ATOL,
rtol=DEFAULT_RTOL)
gc.collect()
torch.npu.empty_cache()
torch.npu.reset_peak_memory_stats()


@@ -0,0 +1,45 @@
import gc
import torch
from vllm_ascend.utils import enable_custom_op
enable_custom_op()
DEFAULT_ATOL = 1e-3
DEFAULT_RTOL = 1e-3
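# CPU reference for the bgmv shrink op: for each token i, accumulate
# scaling * (x[i] @ w[indices[i]].T) into y, with the matmul done in float32.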
def bgmv_shrink_cpu_impl(x: torch.Tensor, w: torch.Tensor,
                         indices: torch.Tensor, y: torch.Tensor,
scaling: float) -> torch.Tensor:
W = w[indices, :, :].transpose(-1, -2).to(torch.float32)
z = torch.bmm(x.unsqueeze(1).to(torch.float32), W).squeeze()
y[:, :] += z * scaling
return y
@torch.inference_mode()
def test_bgmv_shrink():
B = 1
x = torch.randn([B, 128], dtype=torch.float16)
w = torch.randn([64, 16, 128], dtype=torch.float16)
indices = torch.zeros([B], dtype=torch.int64)
y = torch.zeros([B, 16])
x_npu = x.npu()
w_npu = w.npu()
indices_npu = indices.npu()
y_npu = y.npu()
y = bgmv_shrink_cpu_impl(x, w, indices, y, 0.5)
torch.ops._C_ascend.bgmv_shrink(x_npu, w_npu, indices_npu, y_npu, 0.5)
# Compare the results.
torch.testing.assert_close(y_npu.cpu(),
y,
atol=DEFAULT_ATOL,
rtol=DEFAULT_RTOL)
gc.collect()
torch.npu.empty_cache()
torch.npu.reset_peak_memory_stats()


@@ -0,0 +1,168 @@
import random
import torch
import torch.distributed as dist
import torch.multiprocessing as mp
import torch_npu
from torch.distributed.distributed_c10d import _get_default_group
from vllm_ascend.utils import enable_custom_op
enable_custom_op()
class TestDispatchFFNCombine:
def __init__(self, rank, world_size, port):
self.rank = rank
self.world_size = world_size
self.master_ip = "127.0.0.1"
self.port = port
def get_hcomm(self, comm_group):
hcomm_info = None
if torch.__version__ > "2.0.1":
hcomm_info = comm_group._get_backend(
torch.device("npu")).get_hccl_comm_name(self.rank)
else:
hcomm_info = comm_group.get_hccl_comm_name(self.rank)
return hcomm_info
def setup_ep_tp(
self,
rank,
tp_size,
ep_size,
backend_type,
ep_ranks_list=None,
tp_ranks_list=None,
):
for i in range(tp_size):
if ep_ranks_list:
ep_ranks = ep_ranks_list[i]
else:
ep_ranks = [x + ep_size * i for x in range(ep_size)]
ep_group = dist.new_group(backend=backend_type, ranks=ep_ranks)
if rank in ep_ranks:
ep_group_tmp = ep_group
for i in range(ep_size):
if tp_ranks_list:
tp_ranks = tp_ranks_list[i]
else:
tp_ranks = [x * ep_size + i for x in range(tp_size)]
tp_group = dist.new_group(backend=backend_type, ranks=tp_ranks)
if rank in tp_ranks:
tp_group_tmp = tp_group
return ep_group_tmp, tp_group_tmp
def generate_hcom(self):
torch_npu.npu.set_device(self.rank)
dist.init_process_group(
backend="hccl",
rank=self.rank,
world_size=self.world_size,
init_method=f"tcp://127.0.0.1:{self.port}",
)
ep_size = 0
tp_size = self.world_size
hcomm_info_dist = {
"default_pg_info": None,
"ep_hcomm_info": None,
"group_ep": None,
"tp_hcomm_info": None,
"group_tp": None,
}
if ep_size and tp_size:
group_ep, group_tp = self.setup_ep_tp(self.rank, tp_size, ep_size,
"hccl", None, None)
hcomm_info_dist["ep_hcomm_info"] = self.get_hcomm(group_ep)
hcomm_info_dist["tp_hcomm_info"] = self.get_hcomm(group_tp)
hcomm_info_dist["group_ep"] = group_ep
hcomm_info_dist["group_tp"] = group_tp
else:
if dist.is_available():
default_pg = _get_default_group()
hcomm_info_dist["default_pg_info"] = self.get_hcomm(default_pg)
hcomm_info = hcomm_info_dist["default_pg_info"]
self.hcomm_info = hcomm_info
def run_npu_out(self) -> bool:
torch_npu.npu.set_device(self.rank)
m = 2 # token-num 32
k = 4 # hidden_size 7168
n = 4 # mid-hidden-size 4096
topk = 2
e = 2 # expert-num-per-rank 16
k2 = n // 2
n2 = k
torch_npu.npu.config.allow_internal_format = True
x = self.generate_random_tensor((m, k), dtype=torch.bfloat16).npu()
weight1 = self.generate_random_tensor((e, k, n),
dtype=torch.int8).npu()
weight1 = torch_npu.npu_format_cast(weight1, 29)
weight2 = self.generate_random_tensor((e, k2, n2),
dtype=torch.int8).npu()
weight2 = torch_npu.npu_format_cast(weight2, 29)
expert_idx = torch.randint(0,
self.world_size * e, (m, topk),
dtype=torch.int32).npu()
scale1 = torch.randint(0, 1, (e, n), dtype=torch.int64).npu()
scale2 = torch.randint(0, 1, (e, n2), dtype=torch.int64).npu()
probs = torch.randn(size=(m, topk), dtype=torch.float32).npu()
out = self.generate_random_tensor((m, k), dtype=torch.bfloat16).npu()
torch.ops._C_ascend.dispatch_ffn_combine(
x=x,
weight1=weight1,
weight2=weight2,
expert_idx=expert_idx,
scale1=scale1,
scale2=scale2,
probs=probs,
group=self.hcomm_info,
max_output_size=512,
out=out,
)
return True
def generate_random_tensor(self, size, dtype):
if dtype in [torch.float16, torch.bfloat16, torch.float32]:
return torch.randn(size=size, dtype=dtype)
elif dtype is torch.int8:
return torch.randint(-16, 16, size=size, dtype=dtype)
elif dtype is torch.int32:
return torch.randint(-1024, 1024, size=size, dtype=dtype)
else:
raise ValueError(f"Invalid dtype: {dtype}")
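# The test below forks world_size processes; each rank initializes HCCL,
# resolves its communicator name, and runs the fused dispatch_ffn_combine op
# once. This is a smoke test: only successful completion is asserted, there is
# no numeric golden comparison.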
def worker(rank: int, world_size: int, port: int, q: mp.SimpleQueue):
    op = TestDispatchFFNCombine(rank, world_size, port)
op.generate_hcom()
out = op.run_npu_out()
q.put(out)
@torch.inference_mode()
def test_dispatch_ffn_combine_kernel():
world_size = 2
mp.set_start_method("fork", force=True)
q = mp.SimpleQueue()
p_list = []
port = 29501 + random.randint(0, 10000)
for rank in range(world_size):
p = mp.Process(target=worker, args=(rank, world_size, port, q))
p.start()
p_list.append(p)
results = [q.get() for _ in range(world_size)]
for p in p_list:
p.join()
assert all(results)


@@ -0,0 +1,338 @@
# Copyright (c) 2025 Huawei Technologies Co., Ltd. All Rights Reserved.
# Copyright 2023 The vLLM team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# SPDX-License-Identifier: Apache-2.0
# This file is a part of the vllm-ascend project.
# Adapted from vllm/tests/kernels/test_moe.py
"""Tests for the MOE layers.
Run `pytest tests/ops/test_fused_moe.py`.
"""
import gc
from unittest.mock import MagicMock, patch
import pytest
import torch
import torch_npu
from vllm.model_executor.layers.activation import SiluAndMul
from vllm_ascend.ops.fused_moe.experts_selector import (
check_npu_moe_gating_top_k, select_experts)
from vllm_ascend.ops.fused_moe.moe_mlp import unified_apply_mlp
from vllm_ascend.ops.fused_moe.token_dispatcher import \
TokenDispatcherWithAllGather
NUM_EXPERTS = [8, 64]
EP_SIZE = [1]
TOP_KS = [2, 6]
DEVICE = ["npu"]
def apply_mlp(
hidden_states: torch.Tensor,
w1: torch.Tensor,
w2: torch.Tensor,
group_list: torch.Tensor,
group_list_type: int = 1,
) -> torch.Tensor:
w1 = w1.transpose(1, 2)
hidden_states = torch_npu.npu_grouped_matmul(
x=[hidden_states],
weight=[w1],
split_item=2,
group_list_type=group_list_type,
group_type=0,
group_list=group_list,
)[0]
hidden_states = torch_npu.npu_swiglu(hidden_states)
w2 = w2.transpose(1, 2)
hidden_states = torch_npu.npu_grouped_matmul(
x=[hidden_states],
weight=[w2],
split_item=2,
group_list_type=group_list_type,
group_type=0,
group_list=group_list,
)[0]
return hidden_states
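# torch_moe is the dense reference implementation: every token is replicated
# top-k times, routed through its expert's MLP (w1 -> SiluAndMul -> w2), and
# the expert outputs are summed weighted by the router weights.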
def torch_moe(a, w1, w2, topk_weights, topk_ids, topk, expert_map):
B, D = a.shape
a = a.view(B, -1, D).repeat(1, topk, 1).reshape(-1, D)
out = torch.zeros(B * topk, w2.shape[1], dtype=a.dtype, device=a.device)
topk_weights = topk_weights.view(-1)
topk_ids = topk_ids.view(-1)
if expert_map is not None:
topk_ids = expert_map[topk_ids]
for i in range(w1.shape[0]):
mask = topk_ids == i
if mask.sum():
out[mask] = SiluAndMul()(
a[mask] @ w1[i].transpose(0, 1)) @ w2[i].transpose(0, 1)
return (out.view(B, -1, w2.shape[1]) *
topk_weights.view(B, -1, 1).to(out.dtype)).sum(dim=1)
@pytest.mark.parametrize("m", [1, 33, 64, 222, 1024 * 128])
@pytest.mark.parametrize("n", [128, 1024, 2048])
@pytest.mark.parametrize("k", [128, 511, 1024])
@pytest.mark.parametrize("e", NUM_EXPERTS)
@pytest.mark.parametrize("topk", TOP_KS)
@pytest.mark.parametrize("ep_size", EP_SIZE)
@pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16])
@pytest.mark.parametrize("device", DEVICE)
def test_token_dispatcher_with_all_gather(
m: int,
n: int,
k: int,
e: int,
topk: int,
ep_size: int,
dtype: torch.dtype,
device: str,
):
a = torch.randn((m, k), device=device, dtype=dtype) / 10
w1 = torch.randn((e, 2 * n, k), device=device, dtype=dtype) / 10
w2 = torch.randn((e, k, n), device=device, dtype=dtype) / 10
score = torch.randn((m, e), device=device, dtype=dtype)
expert_map = None
local_e = e
w1_local = w1
w2_local = w2
score = torch.softmax(score, dim=-1, dtype=dtype)
topk_weights, topk_ids = torch.topk(score, topk)
topk_ids = topk_ids.to(torch.int32)
dispatcher_kwargs = {
"num_experts": e,
"top_k": topk,
"num_local_experts": local_e,
}
dispatcher = TokenDispatcherWithAllGather(**dispatcher_kwargs)
apply_router_weight_on_input = False
dispatch_output = dispatcher.token_dispatch(
hidden_states=a,
topk_weights=topk_weights,
topk_ids=topk_ids,
expert_map=expert_map,
apply_router_weight_on_input=apply_router_weight_on_input)
sorted_hidden_states = dispatch_output["hidden_states"]
group_list = dispatch_output["group_list"]
group_list_type = dispatch_output.get("group_list_type", 1)
context_metadata = dispatch_output["context_metadata"]
expert_output = apply_mlp(hidden_states=sorted_hidden_states,
w1=w1_local,
w2=w2_local,
group_list=group_list,
group_list_type=group_list_type)
combined_output = dispatcher.token_combine(
hidden_states=expert_output,
context_metadata=context_metadata,
bias=None)
torch_output = torch_moe(a, w1, w2, topk_weights, topk_ids, topk,
expert_map)
torch.testing.assert_close(combined_output,
torch_output,
atol=4e-2,
rtol=1)
gc.collect()
torch.npu.empty_cache()
torch.npu.reset_peak_memory_stats()
@pytest.mark.parametrize("m", [1, 33, 64])
@pytest.mark.parametrize("n", [128, 1024, 2048])
@pytest.mark.parametrize("k", [128, 511, 1024])
@pytest.mark.parametrize("e", NUM_EXPERTS)
@pytest.mark.parametrize("topk", TOP_KS)
@pytest.mark.parametrize("ep_size", EP_SIZE)
@pytest.mark.parametrize("dtype", [torch.bfloat16])
@pytest.mark.parametrize("device", DEVICE)
def test_token_dispatcher_with_all_gather_quant(
m: int,
n: int,
k: int,
e: int,
topk: int,
ep_size: int,
dtype: torch.dtype,
device: str,
):
context_mock = MagicMock()
context_mock.fused_moe_state = 0
with patch("vllm_ascend.ops.fused_moe.moe_mlp.get_forward_context",
return_value=context_mock):
a = torch.randn((m, k), device=device, dtype=dtype) / 10
w1 = torch.randn((e, k, 2 * n), device=device, dtype=torch.int8)
w1_scale = torch.empty((e, 2 * n), device=device, dtype=dtype)
w2 = torch.randn((e, n, k), device=device, dtype=torch.int8)
w2_scale = torch.empty((e, k), device=device, dtype=dtype)
score = torch.randn((m, e), device=device, dtype=dtype)
expert_map = None
local_e = e
score = torch.softmax(score, dim=-1, dtype=dtype)
topk_weights, topk_ids = torch.topk(score, topk)
topk_ids = topk_ids.to(torch.int32)
dispatcher_kwargs = {
"num_experts": e,
"top_k": topk,
"num_local_experts": local_e,
}
dispatcher = TokenDispatcherWithAllGather(**dispatcher_kwargs)
apply_router_weight_on_input = False
dispatch_output = dispatcher.token_dispatch(
hidden_states=a,
topk_weights=topk_weights,
topk_ids=topk_ids,
expert_map=expert_map,
apply_router_weight_on_input=apply_router_weight_on_input,
with_quant=True)
sorted_hidden_states = dispatch_output["hidden_states"]
group_list = dispatch_output["group_list"]
group_list_type = dispatch_output.get("group_list_type", 1)
dynamic_scale = dispatch_output["dynamic_scale"]
context_metadata = dispatch_output["context_metadata"]
expert_output = unified_apply_mlp(hidden_states=sorted_hidden_states,
w1=w1,
w1_scale=w1_scale,
w2=w2,
w2_scale=w2_scale,
group_list=group_list,
group_list_type=group_list_type,
dynamic_scale=dynamic_scale,
with_quant=True)
combined_output = dispatcher.token_combine(
hidden_states=expert_output,
context_metadata=context_metadata,
bias=None)
assert combined_output.shape == (m, k)
gc.collect()
torch.npu.empty_cache()
torch.npu.reset_peak_memory_stats()
@pytest.mark.parametrize("m", [1, 33, 64])
@pytest.mark.parametrize("n", [128, 1024, 2048])
@pytest.mark.parametrize("e", NUM_EXPERTS)
@pytest.mark.parametrize("topk", TOP_KS)
@pytest.mark.parametrize("scoring_func", ["softmax", "sigmoid"])
@pytest.mark.parametrize("use_grouped_topk", [True, False])
@pytest.mark.parametrize("renormalize", [True, False])
@pytest.mark.parametrize("with_e_correction", [True, False])
@pytest.mark.parametrize("custom_routing", [True, False])
@pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16])
@pytest.mark.parametrize("device", DEVICE)
def test_select_experts(
m: int,
n: int,
e: int,
topk: int,
scoring_func: str,
use_grouped_topk: bool,
renormalize: bool,
with_e_correction: bool,
custom_routing: bool,
dtype: torch.dtype,
device: str,
):
topk_group = 4 if use_grouped_topk else None
num_expert_group = e // 4 if use_grouped_topk else None
hidden_states = torch.randn(m, n, device=device, dtype=dtype)
router_logits = torch.randn(m, e, device=device, dtype=dtype)
e_score_correction_bias = (torch.randn(e, device=device, dtype=dtype)
if with_e_correction else None)
custom_routing_function = None
if custom_routing:
custom_routing_function = MagicMock()
mock_weights = torch.randn(m, topk, device=device, dtype=dtype)
mock_ids = torch.randint(0,
e, (m, topk),
device=device,
dtype=torch.int32)
custom_routing_function.return_value = (mock_weights, mock_ids)
with patch("vllm_ascend.ops.fused_moe.experts_selector._native_grouped_topk"
) as mock_native_grouped_topk, \
patch('vllm_ascend.ops.fused_moe.experts_selector.get_weight_prefetch_method',
return_value=MagicMock()):
mock_native_grouped_topk.side_effect = lambda x, num_groups, k: torch.randn_like(
x)
topk_weights, topk_ids = select_experts(
hidden_states=hidden_states,
router_logits=router_logits,
top_k=topk,
use_grouped_topk=use_grouped_topk,
renormalize=renormalize,
topk_group=topk_group,
num_expert_group=num_expert_group,
custom_routing_function=custom_routing_function,
scoring_func=scoring_func,
e_score_correction_bias=e_score_correction_bias,
)
call_moe_gatingtopk = check_npu_moe_gating_top_k(
hidden_states, topk, topk_group, num_expert_group, scoring_func,
custom_routing_function)
if not call_moe_gatingtopk and use_grouped_topk:
mock_native_grouped_topk.assert_called_once()
else:
mock_native_grouped_topk.assert_not_called()
assert topk_weights.shape == (m, topk)
assert topk_ids.shape == (m, topk)
assert topk_ids.dtype == torch.int32
gc.collect()
torch.npu.empty_cache()
torch.npu.reset_peak_memory_stats()
@pytest.mark.parametrize("device", DEVICE)
def test_select_experts_invalid_scoring_func(device: str):
with patch('vllm_ascend.ops.fused_moe.experts_selector.get_weight_prefetch_method',
return_value=MagicMock()), \
pytest.raises(ValueError,
match="Unsupported scoring function: invalid"):
select_experts(hidden_states=torch.randn(1, 128, device=device),
router_logits=torch.randn(1, 8, device=device),
top_k=2,
use_grouped_topk=False,
renormalize=False,
scoring_func="invalid")
gc.collect()
torch.npu.empty_cache()
torch.npu.reset_peak_memory_stats()


@@ -0,0 +1,37 @@
import pytest
import torch
import torch_npu
@pytest.mark.parametrize(
'B',
[1, 16, 64, 128, 32768],
)
@pytest.mark.parametrize(
'D',
[8, 16, 32, 64, 128],
)
@pytest.mark.parametrize(
'top_k',
[1, 2, 4, 8],
)
@pytest.mark.parametrize(
"dtype, atol, rtol",
[
(torch.float16, 1e-3, 1e-3),
(torch.bfloat16, 1e-3, 1e-3),
],
)
def test_quant_fpx_linear(B: int, D: int, top_k: int, dtype, atol, rtol):
x = torch.rand((B, D), dtype=dtype).to("npu")
# finished = torch.randint(1, size=(B,), dtype=torch.bool).to("npu")
finished = None
y, expert_idx, row_idx = torch_npu.npu_moe_gating_top_k_softmax(x,
finished,
k=top_k)
topk_weights = x.softmax(dim=-1)
topk_weights, topk_ids = topk_weights.topk(top_k, dim=-1)
topk_ids = topk_ids.to(torch.int32)
    assert torch.allclose(y, topk_weights, atol=atol, rtol=rtol)
    assert torch.allclose(expert_idx, topk_ids, atol=atol, rtol=rtol)


@@ -0,0 +1,148 @@
import gc
import torch
import torch_npu
from vllm_ascend.utils import enable_custom_op
# enable internal format
torch_npu.npu.config.allow_internal_format = True
# enable vllm-ascend custom ops
enable_custom_op()
def gmm_swiglu_quant(x: torch.Tensor, weight: torch.Tensor,
perChannelScale: torch.Tensor,
perTokenScale: torch.Tensor, m: int):
"""
Perform quantized GMM (Grouped Matrix Multiplication) operation with SwiGLU activation function.
Parameters:
x (torch.Tensor): Input tensor with shape (m, k).
weight (torch.Tensor): Weight tensor with shape (k, n).
perChannelScale (torch.Tensor): Per-channel scaling factor with shape (n,).
perTokenScale (torch.Tensor): Per-token scaling factor with shape (m,).
m (int): Number of tokens (rows of x).
Returns:
        quantOutput (torch.Tensor): Quantized output tensor with shape (m, n // 2).
quantScaleOutput (torch.Tensor): Quantization scaling factor with shape (m,).
"""
# Perform matrix multiplication with int32 precision
c_temp1 = torch.matmul(x.to(torch.int32), weight.to(torch.int32))
c_temp1 = c_temp1.to(torch.float32) # Convert back to float32 for scaling
# Apply per-channel and per-token scaling
c_temp2 = torch.mul(c_temp1, perChannelScale)
c_temp3 = torch.mul(c_temp2, perTokenScale.reshape(m, 1))
# Split the result into two parts to apply SwiGLU activation function
c_temp4, gate = c_temp3.chunk(2, dim=-1)
c_temp5 = c_temp4 * torch.sigmoid(c_temp4) # SwiGLU activation
c_temp6 = c_temp5 * gate # Element-wise multiplication with gating values
# Quantize the output
max = torch.max(
torch.abs(c_temp6),
-1).values # Find maximum absolute value to calculate scaling factor
quantScaleOutput = 127 / max # Calculate quantization scaling factor
quantOutput = torch.round(c_temp6 * quantScaleOutput.reshape(m, 1)).to(
torch.int8) # Quantize to int8
quantScaleOutput = 1 / quantScaleOutput # Inverse quantization scaling factor for subsequent dequantization
return quantOutput, quantScaleOutput
def process_groups(x: torch.Tensor, weight: torch.Tensor,
perChannelScale: torch.Tensor, perTokenScale: torch.Tensor,
groupList: torch.Tensor):
"""
Process input data by groups and call GMM_Swiglu_quant function for quantized computation.
Parameters:
x (torch.Tensor): Input tensor with shape (M, K).
        weight (torch.Tensor): Weight tensor with shape (E, K, N).
        perChannelScale (torch.Tensor): Per-channel scaling factors with shape (E, N).
        perTokenScale (torch.Tensor): Per-token scaling factor with shape (M,).
        groupList (torch.Tensor): Inclusive prefix sums of the number of tokens per group.
Returns:
quantOutput (torch.Tensor): Quantized output tensor with shape (M, N // 2).
quantScaleOutput (torch.Tensor): Quantization scaling factor with shape (M,).
"""
M, N = x.shape[0], weight.shape[2] # Get the shape of the input tensor
quantOutput = torch.zeros(M, N // 2).to(
torch.int8) # Initialize quantized output tensor
quantScaleOutput = torch.zeros(M).to(
torch.float32) # Initialize quantization scaling factor tensor
start_idx = 0 # Starting index
preV = 0 # Number of tokens in the previous group
groupList = groupList.tolist()
# Iterate through groupList to process data by groups
for i, v in enumerate(groupList):
currV = v
tempV = currV - preV # Calculate number of tokens in the current group
preV = currV # Update number of tokens in the previous group
if tempV > 0:
# Call GMM_Swiglu_quant to process the current group
quantOutput[start_idx:start_idx + tempV], quantScaleOutput[start_idx:start_idx + tempV] = \
gmm_swiglu_quant(x[start_idx:start_idx + tempV],
weight[i],
perChannelScale[i],
perTokenScale[start_idx:start_idx + tempV],
tempV)
start_idx += tempV # Update starting index to process the next group
return quantOutput, quantScaleOutput
@torch.inference_mode()
def test_gmm_swiglu_quant_weight_nz_tensor_list():
M, K, E, N = 8192, 7168, 4, 4096
# x (M, K) - int8
x = torch.randint(-128, 127, (M, K), dtype=torch.int8)
    # weight (E, K, N) - int8
weight = torch.randint(-128, 127, size=(E, K, N), dtype=torch.int8)
# weight_scale (E, N) - float32
weight_scale = torch.rand(E, N) * 0.9 + 0.1 # uniform(0.1, 1.0)
weight_scale = weight_scale.to(torch.float32)
weight_nz_npu = []
weight_scale_npu = []
for i in range(E):
weight_nz_npu.append(torch_npu.npu_format_cast(weight[i].npu(), 29))
weight_scale_npu.append(weight_scale[i].npu())
# x_scale (M,) - float32
x_scale = torch.rand(M) * 0.9 + 0.1 # uniform(0.1, 1.0)
x_scale = x_scale.to(torch.float32)
group_list = torch.tensor([2048, 4096, 6144, 8192], dtype=torch.int64)
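    # group_list holds inclusive prefix sums: each of the E = 4 experts
    # receives 2048 of the M = 8192 tokens.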
output_cpu, output_scale_cpu = process_groups(x, weight, weight_scale,
x_scale, group_list)
output_npu, output_scale_npu, _ = \
torch.ops._C_ascend.grouped_matmul_swiglu_quant_weight_nz_tensor_list(x.npu(),
weight_nz_npu,
weight_scale_npu,
x_scale.npu(),
group_list.npu())
output_npu_valid = output_npu[:group_list[-1], :]
output_scale_npu_valid = output_scale_npu[:group_list[-1]]
torch.testing.assert_close(output_npu_valid.cpu(),
output_cpu,
atol=1,
rtol=2**-13)
torch.testing.assert_close(output_scale_npu_valid.cpu(),
output_scale_cpu,
atol=1e-9,
rtol=1e-6)
gc.collect()
torch.npu.empty_cache()
torch.npu.reset_peak_memory_stats()


@@ -0,0 +1,175 @@
import gc
import numpy as np
import torch
import torch_npu
from vllm_ascend.utils import enable_custom_op
enable_custom_op()
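# x_int8_to_x_int4 splits every int8 entry into its high nibble (floor
# division by 16) and its low nibble minus 8, interleaving them as the
# even/odd rows of a (2m, k) tensor consumed by the A8W4 golden path below.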
def x_int8_to_x_int4(x: torch.Tensor):
m, k = x.shape
x_high_4bit = torch.floor(x.to(torch.float16) // 16).to(torch.int8)
x_low_4bit = (
torch.bitwise_and(x.view(torch.int16), 0x0f0f).view(torch.int8) - 8)
x_int4 = torch.empty((2 * m, k), dtype=torch.int8)
x_int4[::2, :] = x_high_4bit
x_int4[1::2, :] = x_low_4bit
return x_int4
def custom_mm(x: torch.Tensor, weight: torch.Tensor,
weight_scale: torch.Tensor, m: int):
"""
Performing Quantized GMM (General Matrix Multiplication) Operation
Parameters:
x (torch.Tensor): Input tensor with shape (m, k).
weight (torch.Tensor): Weight tensor with shape (k, n).
weight_scale (torch.Tensor): Scaling factor for each channel.
- In perGroup scenario: Shape is (k_group_num, n). Note: When k_group_num == 1, it is a perChannel scenario.
- In perChannel scenario: Shape is (n).
m (int): Number of tokens (number of rows in x).
Returns:
        mm_out (torch.Tensor): float32 result of MatMul with per-group or per-channel dequantization.
"""
# Perform matrix multiplication with int32 precision
k, n = weight.shape
mm_out = torch.zeros((m, n), dtype=torch.float16)
# perGroup scenario
if len(weight_scale.shape) == 2 and weight_scale.shape[0] != 1:
k_group = weight_scale.shape[0]
per_group_ele = k // k_group
x_grouped = x.view(-1, k_group, per_group_ele).transpose(0, 1)
weight_grouped = weight.view(k_group, per_group_ele, n)
c_temp = torch.bmm(x_grouped.to(torch.int32),
weight_grouped.to(torch.int32)).to(torch.float16)
for k_idx in range(k_group):
mm_out += (c_temp[k_idx] *
weight_scale[k_idx].view(1, -1).to(torch.float16)).to(
torch.float16)
# perChannel scenario
elif len(weight_scale.shape) == 1 or (len(weight_scale.shape) == 2
and weight_scale.shape[0] == 1):
c_temp = torch.matmul(x.to(torch.int32),
weight.to(torch.int32)).to(torch.float32)
mm_out = c_temp * weight_scale.view(1, -1).to(torch.float16)
return mm_out.to(torch.float32)
def gmm_swiglu_quant_golden_a8_w4(x: torch.Tensor, weight: torch.Tensor,
weight_scale: torch.Tensor,
per_token_scale: torch.Tensor,
bias: torch.Tensor,
group_list: torch.Tensor):
"""
Process the input data by group and call the GMM_Swiglu_quant function for quantization computation.
Parameters:
x (torch.Tensor): Input tensor with shape (M, K), type INT8.
        weight (torch.Tensor): Weight tensor with shape (E, K, N); stored as INT8 but each value is in the INT4 range.
weight_scale (torch.Tensor): Scaling factor for each channel.
- In perGroup scenario: shape (E, k_group_num, N).
- In perChannel scenario: shape (E, N).
per_token_scale (torch.Tensor): Scaling factor for each token, shape (M, ).
bias: torch.Tensor,
        group_list (torch.Tensor): Inclusive prefix sums of the number of tokens per group.
Returns:
quant_output (torch.Tensor): Quantized output tensor with shape (M, N // 2).
quant_scale_output (torch.Tensor): Quantization scaling factor, shape (M, ).
"""
M, N = x.shape[0], weight.shape[2]
quant_output = torch.zeros(M, N // 2).to(torch.int8)
quant_scale_output = torch.zeros(M).to(torch.float32)
# Preprocessing X_INT8 -> X_INT4
x_int4 = x_int8_to_x_int4(x)
start_idx = 0
# Number of tokens in the previous group
pre_v = 0
group_list = group_list.tolist()
# Traverse group_list and process data by group
for i, v in enumerate(group_list):
curr_v = v
        # Rows in the current group; "* 2" because one int8 row expands to two int4 rows
temp_v = int((curr_v - pre_v) * 2)
# Update the number of tokens in the previous group
pre_v = curr_v
if (temp_v > 0):
mm_out = custom_mm(x_int4[int(start_idx):int(start_idx + temp_v)],
weight[i], weight_scale[i], temp_v)
mm_num_concat = ((mm_out[::2] * 16 + mm_out[1::2]) +
bias[i].view(1, -1))
per_token_quant = mm_num_concat * per_token_scale[start_idx // 2:(
start_idx + temp_v) // 2].view(-1, 1)
swiglu, gate = per_token_quant.chunk(2, dim=-1)
temp = swiglu * torch.sigmoid(swiglu)
temp = temp * gate
max_value = torch.max(torch.abs(temp), dim=-1).values
quant_scale_output_temp = 127 / max_value
quant_output[start_idx // 2:(start_idx + temp_v) //
2] = torch.round(temp *
quant_scale_output_temp.reshape(
temp_v // 2, 1)).to(torch.int8)
quant_scale_output[start_idx // 2:(start_idx + temp_v) //
2] = 1 / quant_scale_output_temp
start_idx += temp_v
return quant_output, quant_scale_output
def generate_non_decreasing_sequence(length, upper_limit):
    # Generate a random non-decreasing sequence
random_increments = torch.randint(0, 128, (length, ))
sequence = torch.cumsum(random_increments, dim=0)
# Make sure the last value is less than the upper limit
if sequence[-1] >= upper_limit:
scale_factor = upper_limit / sequence[-1]
sequence = (sequence * scale_factor).to(torch.int64)
return sequence
@torch.inference_mode()
def test_grouped_matmul_swiglu_quant_kernel():
E = 16
M = 512
K = 7168
N = 4096
torch.npu.config.allow_internal_format = True
x = torch.randint(-5, 5, (M, K), dtype=torch.int8).npu()
weight_ori = torch.randint(-5, 5, (E, K, N), dtype=torch.int8)
weight_nz = torch_npu.npu_format_cast(weight_ori.npu().to(torch.float32),
29)
pack_weight = torch_npu.npu_quantize(weight_nz,
torch.tensor([1.], device='npu'),
None, torch.quint4x2, -1, False)
weight_scale = torch.randn(E, 1, N)
scale_np = weight_scale.cpu().numpy()
scale_np.dtype = np.uint32
scale_uint64_tensor = torch.from_numpy(scale_np.astype(np.int64)).npu()
pertoken_scale = torch.randn(M).to(torch.float32).npu()
group_list = generate_non_decreasing_sequence(E, M).npu()
bias = torch.zeros((E, N), dtype=torch.float32,
device="npu").uniform_(-5, 5)
output_golden, output_scale_golden = gmm_swiglu_quant_golden_a8_w4(
x.cpu(), weight_ori, weight_scale, pertoken_scale.cpu(), bias.cpu(),
group_list.cpu())
output, output_scale, _ = torch.ops._C_ascend.grouped_matmul_swiglu_quant(
x=x,
weight=pack_weight,
bias=bias,
group_list=group_list,
weight_scale=scale_uint64_tensor,
x_scale=pertoken_scale)
torch.testing.assert_close(output_golden, output.cpu(), atol=1, rtol=0.005)
torch.testing.assert_close(output_scale_golden,
output_scale.cpu(),
atol=1,
rtol=0.005)
gc.collect()
torch.npu.empty_cache()
torch.npu.reset_peak_memory_stats()


@@ -0,0 +1,135 @@
import gc
import os
import numpy as np
import torch
import torch.distributed as dist
import torch.multiprocessing as mp
import torch_npu
import torchair
from vllm_ascend.utils import enable_custom_op
config = torchair.CompilerConfig()
config.mode = "reduce-overhead"
npu_backend = torchair.get_npu_backend(compiler_config=config)
torch_npu.npu.config.allow_internal_format = True
enable_custom_op()
global_rank_id = 0
def golden_op_matmul_allreduce_add_rmsnorm(a, b, residual, gamma, epsilon):
c_ret = torch.nn.functional.linear(a, b)
dist.all_reduce(c_ret)
rmsnorm_ret, _, add_ret = torch_npu.npu_add_rms_norm(
c_ret, residual, gamma, epsilon)
return rmsnorm_ret, add_ret
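# Each spawned rank runs matmul + all-reduce + add-RMSNorm twice: once through
# the golden composition above (linear -> dist.all_reduce -> npu_add_rms_norm)
# and once through the fused matmul_allreduce_add_rmsnorm custom op compiled
# with torchair, then compares both results (atol=0.1, rtol=0.005).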
def worker(rank, ep_world_size, batch_size, m, k, n):
global global_rank_id
global_rank_id = rank
os.environ["MASTER_ADDR"] = "127.0.0.1"
os.environ["MASTER_PORT"] = "29500"
dist.init_process_group(backend="hccl",
rank=rank,
world_size=ep_world_size)
ep_ranks_list = list(np.arange(0, ep_world_size))
ep_group = dist.new_group(backend="hccl", ranks=ep_ranks_list)
torch_npu.npu.set_device(rank)
ep_hcomm_info = ep_group._get_backend(
torch.device("npu")).get_hccl_comm_name(rank)
torch_npu.npu.synchronize(rank)
class Module(torch.nn.Module):
def __init__(self) -> None:
super().__init__()
def forward(self, x1, x2, residual, gamma, ep_hcomm_info, epsilon,
is_trans_b, is_allgather_add_out):
out1, add_out1 = torch.ops._C_ascend.matmul_allreduce_add_rmsnorm(
x1, x2, residual, gamma, ep_hcomm_info, ep_world_size,
global_rank_id, epsilon, is_trans_b, is_allgather_add_out)
return out1, add_out1
DTYPE = torch.bfloat16
USE_ONES = False
torch.manual_seed(42)
if USE_ONES:
x1 = torch.ones([m, k], dtype=DTYPE).npu(rank)
x2 = torch.ones([n, k], dtype=DTYPE).npu(rank)
else:
x1 = torch.normal(0, 0.1, [m, k], dtype=DTYPE).npu(rank)
x2 = torch.normal(0, 0.1, [n, k], dtype=DTYPE).npu(rank)
if USE_ONES:
residual = torch.full([m, n], 2048, dtype=DTYPE).npu(rank)
else:
residual = torch.full([m, n], 0, dtype=DTYPE).npu(rank)
gamma = torch.full([n], 1, dtype=DTYPE).npu(rank)
epsilon = 1e-5
is_trans_b = True
is_allgather_add_out = True
    warmup_cnt = 5
repeat_cnt = 10
def run_golden_case(loop_cnt):
for _ in range(loop_cnt):
golden_out, golden_add_out = golden_op_matmul_allreduce_add_rmsnorm(
x1, x2, residual, gamma, epsilon)
torch_npu.npu.synchronize(rank)
return golden_out, golden_add_out
    run_golden_case(warmup_cnt)
golden_out, golden_add_out = run_golden_case(repeat_cnt)
golden_out = golden_out.detach().cpu()
golden_add_out = golden_add_out.detach().cpu()
mod = Module().npu()
opt_model = torch.compile(mod, backend=npu_backend)
def run_custom_case(loop_cnt):
for _ in range(loop_cnt):
out, add_out = opt_model(x1, x2, residual, gamma, ep_hcomm_info,
epsilon, is_trans_b, is_allgather_add_out)
torch_npu.npu.synchronize(rank)
return out, add_out
    # warm up
    run_custom_case(warmup_cnt)
out, add_out = run_custom_case(repeat_cnt)
out = out.detach().cpu()
add_out = add_out.detach().cpu()
dist.destroy_process_group()
torch.testing.assert_close(golden_out, out, atol=0.1, rtol=0.005)
torch.testing.assert_close(golden_add_out, add_out, atol=0.1, rtol=0.005)
gc.collect()
torch.npu.empty_cache()
torch.npu.reset_peak_memory_stats()
@torch.inference_mode()
def test_matmul_allreduce_add_rmsnorm_kernel():
ep_world_size = 8
batch_size = 1
m = 10000
k = 1024
n = 5120
args = (ep_world_size, batch_size, m, k, n)
mp.spawn(worker, args=args, nprocs=ep_world_size, join=True)


@@ -0,0 +1,115 @@
import gc
import torch
import torch_npu
from vllm_ascend.utils import enable_custom_op
enable_custom_op()
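# Smoke test for the fused mla_preprocess kernel in per-tensor asymmetric
# quant mode: the projection weights are cast to FRACTAL_NZ (format 29), the
# op is run once, and the test only checks that q_out0/q_out1 were actually
# written (they must differ from their pre-call clones).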
@torch.inference_mode()
def test_mla_preprocess_kernel():
token_num = 1
head_num = 2
N_7168 = 7168
block_num = 1
block_size = 128
dtype = torch.bfloat16
hidden_states = torch.randn((token_num, N_7168), dtype=dtype).npu()
quant_scale0 = torch.randn((1, ), dtype=dtype).npu()
quant_offset0 = torch.randint(0, 7, (1, ), dtype=torch.int8).npu()
wdqkv = torch.randint(0, 7, (1, 224, 2112, 32), dtype=torch.int8).npu()
wdqkv = torch_npu.npu_format_cast(wdqkv.contiguous(), 29)
de_scale0 = torch.rand((2112, ), dtype=torch.float).npu()
bias0 = torch.randint(0, 7, (2112, ), dtype=torch.int32).npu()
gamma1 = torch.randn((1536), dtype=dtype).npu()
beta1 = torch.randn((1536), dtype=dtype).npu()
quant_scale1 = torch.randn((1, ), dtype=dtype).npu()
quant_offset1 = torch.randint(0, 7, (1, ), dtype=torch.int8).npu()
wuq = torch.randint(0, 7, (1, 48, head_num * 192, 32),
dtype=torch.int8).npu()
wuq = torch_npu.npu_format_cast(wuq.contiguous(), 29)
de_scale1 = torch.rand((head_num * 192, ), dtype=torch.float).npu()
bias1 = torch.randint(0, 7, (head_num * 192, ), dtype=torch.int32).npu()
gamma2 = torch.randn((512), dtype=dtype).npu()
cos = torch.randn((token_num, 64), dtype=dtype).npu()
sin = torch.randn((token_num, 64), dtype=dtype).npu()
wuk = torch.randn((head_num, 128, 512), dtype=dtype).npu()
wuk = torch_npu.npu_format_cast(wuk, 29)
kv_cache = torch.randint(0,
7,
(block_num, head_num * 512 // 32, block_size, 32),
dtype=dtype).npu()
kv_cache_rope = torch.randn(
(block_num, head_num * 64 // 16, block_size, 16), dtype=dtype).npu()
slotmapping = torch.randint(0, 7, (token_num, ), dtype=torch.int32).npu()
ctkv_scale = torch.randn((1, ), dtype=dtype).npu()
qnope_scale = torch.randn((head_num), dtype=dtype).npu()
q_nope_out = torch.empty(
(hidden_states.shape[0], wuk.shape[0], kv_cache.shape[-1]),
dtype=hidden_states.dtype,
device=hidden_states.device,
)
q_rope_out = torch.empty(
(hidden_states.shape[0], wuk.shape[0], kv_cache_rope.shape[-1]),
dtype=hidden_states.dtype,
device=hidden_states.device,
)
q_down = torch.empty(
(hidden_states.shape[0], 1536),
dtype=hidden_states.dtype,
device=hidden_states.device,
)
q_nope_old = q_nope_out.clone()
q_rope_old = q_rope_out.clone()
torch.ops._C_ascend.mla_preprocess(
hidden_states,
wdqkv,
de_scale0,
gamma1,
beta1,
wuq,
de_scale1,
gamma2,
cos,
sin,
wuk,
kv_cache,
kv_cache_rope,
slotmapping,
quant_scale0=quant_scale0,
quant_offset0=quant_offset0,
bias0=bias0,
quant_scale1=quant_scale1,
quant_offset1=quant_offset1,
bias1=bias1,
ctkv_scale=ctkv_scale,
q_nope_scale=qnope_scale,
cache_mode="krope_ctkv",
quant_mode="per_tensor_quant_asymm",
enable_inner_out=False,
q_out0=q_nope_out,
kv_cache_out0=kv_cache,
q_out1=q_rope_out,
kv_cache_out1=kv_cache_rope,
inner_out=q_down,
)
assert not torch.equal(q_nope_out, q_nope_old)
assert not torch.equal(q_rope_out, q_rope_old)
gc.collect()
torch.npu.empty_cache()
torch.npu.reset_peak_memory_stats()


@@ -0,0 +1,99 @@
import gc
import torch
import torch_npu
from vllm_ascend.utils import enable_custom_op
enable_custom_op()
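# Variant of the mla_preprocess smoke test exercising the no_quant path: all
# quantization-related arguments are passed as None and the weights stay in
# bfloat16.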
@torch.inference_mode()
def test_mla_preprocess_kernel():
token_num = 1
head_num = 2
N_7168 = 7168
block_num = 1
block_size = 128
dtype = torch.bfloat16
hidden_states = torch.randn((token_num, N_7168), dtype=dtype).npu()
wdqkv = torch.randint(0, 7, (1, 448, 2112, 16), dtype=dtype).npu()
wdqkv = torch_npu.npu_format_cast(wdqkv.contiguous(), 29)
gamma1 = torch.randn((1536), dtype=dtype).npu()
wuq = torch.randint(0, 7, (1, 96, head_num * 192, 16), dtype=dtype).npu()
wuq = torch_npu.npu_format_cast(wuq.contiguous(), 29)
gamma2 = torch.randn((512), dtype=dtype).npu()
cos = torch.randn((token_num, 64), dtype=dtype).npu()
sin = torch.randn((token_num, 64), dtype=dtype).npu()
wuk = torch.randn((head_num, 128, 512), dtype=dtype).npu()
# wuk = torch_npu.npu_format_cast(wuk, 29)
kv_cache = torch.randint(0,
7,
(block_num, head_num * 512 // 32, block_size, 32),
dtype=dtype).npu()
kv_cache_rope = torch.randn(
(block_num, head_num * 64 // 16, block_size, 16), dtype=dtype).npu()
slotmapping = torch.randint(0, 7, (token_num, ), dtype=torch.int32).npu()
q_nope_out = torch.empty(
(hidden_states.shape[0], wuk.shape[0], kv_cache.shape[-1]),
dtype=hidden_states.dtype,
device=hidden_states.device,
)
q_rope_out = torch.empty(
(hidden_states.shape[0], wuk.shape[0], kv_cache_rope.shape[-1]),
dtype=hidden_states.dtype,
device=hidden_states.device,
)
q_down = torch.empty(
(hidden_states.shape[0], 1536),
dtype=hidden_states.dtype,
device=hidden_states.device,
)
q_nope_old = q_nope_out.clone()
q_rope_old = q_rope_out.clone()
torch.ops._C_ascend.mla_preprocess(
hidden_states,
wdqkv,
None,
gamma1,
None,
wuq,
None,
gamma2,
cos,
sin,
wuk,
kv_cache,
kv_cache_rope,
slotmapping,
None,
None,
None,
None,
None,
None,
None,
None,
cache_mode="krope_ctkv",
quant_mode="no_quant",
enable_inner_out=False,
q_out0=q_nope_out,
kv_cache_out0=kv_cache,
q_out1=q_rope_out,
kv_cache_out1=kv_cache_rope,
inner_out=q_down,
)
assert not torch.equal(q_nope_out, q_nope_old)
assert not torch.equal(q_rope_out, q_rope_old)
gc.collect()
torch.npu.empty_cache()
torch.npu.reset_peak_memory_stats()


@@ -0,0 +1,116 @@
import gc
import torch
import torch_npu
from vllm_ascend.utils import enable_custom_op
enable_custom_op()
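# Variant of the per-tensor quant smoke test with enable_inner_out=True; in
# addition to q_out0/q_out1 it also checks that inner_out (q_down) was
# written.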
@torch.inference_mode()
def test_mla_preprocess_kernel():
token_num = 1
head_num = 2
N_7168 = 7168
block_num = 1
block_size = 128
dtype = torch.bfloat16
hidden_states = torch.randn((token_num, N_7168), dtype=dtype).npu()
quant_scale0 = torch.randn((1, ), dtype=dtype).npu()
quant_offset0 = torch.randint(0, 7, (1, ), dtype=torch.int8).npu()
wdqkv = torch.randint(0, 7, (1, 224, 2112, 32), dtype=torch.int8).npu()
wdqkv = torch_npu.npu_format_cast(wdqkv.contiguous(), 29)
de_scale0 = torch.rand((2112, ), dtype=torch.float).npu()
bias0 = torch.randint(0, 7, (2112, ), dtype=torch.int32).npu()
gamma1 = torch.randn((1536), dtype=dtype).npu()
beta1 = torch.randn((1536), dtype=dtype).npu()
quant_scale1 = torch.randn((1, ), dtype=dtype).npu()
quant_offset1 = torch.randint(0, 7, (1, ), dtype=torch.int8).npu()
wuq = torch.randint(0, 7, (1, 48, head_num * 192, 32),
dtype=torch.int8).npu()
wuq = torch_npu.npu_format_cast(wuq.contiguous(), 29)
de_scale1 = torch.rand((head_num * 192, ), dtype=torch.float).npu()
bias1 = torch.randint(0, 7, (head_num * 192, ), dtype=torch.int32).npu()
gamma2 = torch.randn((512), dtype=dtype).npu()
cos = torch.randn((token_num, 64), dtype=dtype).npu()
sin = torch.randn((token_num, 64), dtype=dtype).npu()
wuk = torch.randn((head_num, 128, 512), dtype=dtype).npu()
wuk = torch_npu.npu_format_cast(wuk, 29)
kv_cache = torch.randint(0,
7,
(block_num, head_num * 512 // 32, block_size, 32),
dtype=dtype).npu()
kv_cache_rope = torch.randn(
(block_num, head_num * 64 // 16, block_size, 16), dtype=dtype).npu()
slotmapping = torch.randint(0, 7, (token_num, ), dtype=torch.int32).npu()
ctkv_scale = torch.randn((1, ), dtype=dtype).npu()
qnope_scale = torch.randn((head_num), dtype=dtype).npu()
q_nope_out = torch.empty(
(hidden_states.shape[0], wuk.shape[0], kv_cache.shape[-1]),
dtype=hidden_states.dtype,
device=hidden_states.device,
)
q_rope_out = torch.empty(
(hidden_states.shape[0], wuk.shape[0], kv_cache_rope.shape[-1]),
dtype=hidden_states.dtype,
device=hidden_states.device,
)
q_down = torch.empty(
(hidden_states.shape[0], 1536),
dtype=hidden_states.dtype,
device=hidden_states.device,
)
q_nope_old = q_nope_out.clone()
q_rope_old = q_rope_out.clone()
q_down_old = q_down.clone()
torch.ops._C_ascend.mla_preprocess(hidden_states,
wdqkv,
de_scale0,
gamma1,
beta1,
wuq,
de_scale1,
gamma2,
cos,
sin,
wuk,
kv_cache,
kv_cache_rope,
slotmapping,
quant_scale0=quant_scale0,
quant_offset0=quant_offset0,
bias0=bias0,
quant_scale1=quant_scale1,
quant_offset1=quant_offset1,
bias1=bias1,
ctkv_scale=ctkv_scale,
q_nope_scale=qnope_scale,
cache_mode="krope_ctkv",
quant_mode="per_tensor_quant_asymm",
enable_inner_out=True,
q_out0=q_nope_out,
kv_cache_out0=kv_cache,
q_out1=q_rope_out,
kv_cache_out1=kv_cache_rope,
inner_out=q_down)
assert not torch.equal(q_nope_out, q_nope_old)
assert not torch.equal(q_rope_out, q_rope_old)
assert not torch.equal(q_down, q_down_old)
gc.collect()
torch.npu.empty_cache()
torch.npu.reset_peak_memory_stats()

View File

@@ -0,0 +1,349 @@
import itertools
import random
import numpy as np
import torch
from vllm_ascend.utils import enable_custom_op
enable_custom_op()
def adapter_capacity(sorted_row_idx, sorted_expert_idx, capacity):
count = 0
last = sorted_expert_idx[0]
for i, val in enumerate(sorted_expert_idx):
if last != val:
count = 1
last = val
else:
count += 1
if count > capacity:
sorted_expert_idx[i] = -1
sorted_row_idx[i] = -1
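# Illustrative sketch (added for readability, not exercised by the tests): how
# adapter_capacity marks tokens that exceed the per-expert capacity. The toy
# row indices and expert ids below are assumptions chosen only for demonstration.
def _adapter_capacity_example():
    sorted_row_idx = [10, 11, 12, 13, 14]
    sorted_expert_idx = [0, 0, 0, 1, 1]
    # With capacity 2, the third token routed to expert 0 is dropped (-1).
    adapter_capacity(sorted_row_idx, sorted_expert_idx, capacity=2)
    assert sorted_expert_idx == [0, 0, -1, 1, 1]
    assert sorted_row_idx == [10, 11, -1, 13, 14]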
def moe_init_routing_golden(x, expert_idx, scale, offset, active_num,
expert_capacity, expert_num, drop_pad_mode,
expert_tokens_num_type, expert_tokens_num_flag,
active_expert_range, quant_mode, row_idx_type):
if drop_pad_mode == 1:
if expert_num <= 0:
print("expert num can not be 0")
return
expert_start = active_expert_range[0] if drop_pad_mode == 0 else 0
expert_end = active_expert_range[1] if drop_pad_mode == 0 else expert_num
num_rows = x.shape[0]
h = x.shape[1]
k = expert_idx.shape[-1]
expert_idx_in = expert_idx.copy().reshape(-1)
actual_expert_total_num: int = np.sum((expert_idx_in >= expert_start)
& (expert_idx_in < expert_end))
expert_idx_in[(expert_idx_in
< expert_start)] = np.int32(np.iinfo(np.int32).max)
sorted_expert_indices = np.argsort(expert_idx_in, axis=-1, kind="stable")
sorted_expert_idx = expert_idx_in[sorted_expert_indices]
if row_idx_type == 1:
expanded_row_idx = sorted_expert_indices[:actual_expert_total_num]
else:
expanded_row_idx = np.ones(num_rows * k).astype(np.int32) * -1
tmp_indices = np.arange(actual_expert_total_num)
expanded_row_idx[
sorted_expert_indices[:actual_expert_total_num]] = tmp_indices
if not expert_tokens_num_flag:
expert_tokens_count = torch.tensor([0])
else:
if drop_pad_mode == 0:
if expert_tokens_num_type == 1:
expert_tokens_count = np.bincount(
sorted_expert_idx[:actual_expert_total_num] - expert_start)
expert_tokens_count = np.concatenate([
expert_tokens_count,
np.zeros((expert_end - expert_start) -
len(expert_tokens_count)).astype(np.int64)
])
elif expert_tokens_num_type == 0:
expert_tokens_count = np.bincount(
sorted_expert_idx[:actual_expert_total_num] - expert_start)
expert_tokens_count = np.concatenate([
expert_tokens_count,
np.zeros((expert_end - expert_start) -
len(expert_tokens_count)).astype(np.int64)
])
expert_tokens_count = np.cumsum(expert_tokens_count)
elif expert_tokens_num_type == 2:
expert_id, counts = np.unique(
sorted_expert_idx[:actual_expert_total_num],
return_counts=True)
expert_tokens_count = np.column_stack((expert_id, counts))
if expert_tokens_count.shape[0] < expert_num:
expert_tokens_count = np.concatenate(
(expert_tokens_count, [
[0, 0],
]), axis=0)
else:
expert_tokens_count = np.bincount(
sorted_expert_idx[:actual_expert_total_num] - expert_start)
zeros_array = np.zeros(
(expert_end - expert_start) - len(expert_tokens_count),
dtype=np.int64)
expert_tokens_count = np.concatenate(
[expert_tokens_count, zeros_array])
expert_tokens_count = expert_tokens_count.astype(np.int64)
if drop_pad_mode == 0:
if active_num == 0:
active_num = actual_expert_total_num
else:
active_num = min(active_num, actual_expert_total_num)
expanded_scale = None
expanded_x = x[sorted_expert_indices[:active_num] // k, :]
if scale is not None and quant_mode == -1:
expanded_scale = scale[sorted_expert_indices[:active_num] // k]
else:
adapter_capacity(sorted_expert_indices, sorted_expert_idx,
expert_capacity)
sort_row_tmp = np.full((expert_num * expert_capacity), -1, dtype=int)
offset_tmp = 0
        last_expert_id = 0
        for i, val in enumerate(sorted_expert_indices):
            if val != -1:
                if last_expert_id != sorted_expert_idx[i]:
                    offset_tmp = 0
                    last_expert_id = sorted_expert_idx[i]
sort_row_tmp[sorted_expert_idx[i] * expert_capacity +
offset_tmp] = sorted_expert_indices[i]
offset_tmp = offset_tmp + 1
expanded_row_idx = np.full(sorted_expert_indices.shape, -1)
for i, val in enumerate(sort_row_tmp):
if val != -1:
expanded_row_idx[val] = i
expanded_x_mask = np.full((expert_num * expert_capacity, h),
1,
dtype=int)
expanded_x = np.full((expert_num * expert_capacity, h),
0,
dtype=x.dtype)
for i, val in enumerate(sort_row_tmp):
if val != -1:
expanded_x[i] = x[val // k]
expanded_x_mask[i] = np.full((h, ), 0, dtype=int)
    if quant_mode == -1:
        # No quantization in this mode; expanded_x and expanded_row_idx are
        # returned unchanged, only expanded_scale is gathered below when needed.
if scale is not None and drop_pad_mode == 1:
expanded_scale = np.full((expert_num * expert_capacity, ),
0,
dtype=scale.dtype)
for i, val in enumerate(sort_row_tmp):
if val != -1:
expanded_scale[i] = scale[val // k]
if scale is None:
expanded_scale = None
if quant_mode == 0:
expanded_scale = None
expanded_x_fp16 = expanded_x.astype(np.float16)
if scale is not None:
scale_val = scale.astype(np.float16)
else:
raise ValueError("scale cannot be None when quant_mode is 0")
if offset is not None:
offset_val = offset.astype(np.float16)
else:
raise ValueError("offset cannot be None when quant_mode is 0")
scale_rst = expanded_x_fp16 * scale_val[0]
add_offset = scale_rst + offset_val[0]
round_data = np.rint(add_offset)
round_data = np.clip(round_data, -128, 127)
expanded_x = round_data.astype(np.int8)
if quant_mode == 1:
x_final = expanded_x.astype(np.float32)
if scale is None:
x_abs = np.abs(x_final)
x_max = np.max(x_abs, axis=-1, keepdims=True)
expanded_scale = x_max / 127
expanded_x = x_final / expanded_scale
expanded_x = np.round(expanded_x).astype(np.int8)
else:
if scale.shape[0] == 1:
x_final = x_final * scale
else:
if drop_pad_mode == 0:
x_final = x_final * scale[sorted_expert_idx[:active_num] -
expert_start]
else:
for i, val in enumerate(sort_row_tmp):
if val != -1:
x_final[i] = x_final[i] * scale[i //
expert_capacity]
x_abs = np.abs(x_final)
x_max = np.max(x_abs, axis=-1, keepdims=True)
expanded_scale = x_max / 127
expanded_x = x_final / expanded_scale
expanded_x = np.round(expanded_x).astype(np.int8)
if x.dtype == np.int8:
expanded_scale = None
if drop_pad_mode == 1:
expanded_x = np.ma.array(expanded_x, mask=expanded_x_mask).filled(0)
expanded_x = expanded_x.reshape(expert_num, expert_capacity, h)
return expanded_x, expanded_row_idx, expert_tokens_count, expanded_scale
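# Illustrative sketch (added for readability, not exercised by the tests) of the
# three expert_tokens_num_type layouts produced by the golden above; the routed
# expert ids below are assumptions chosen only for demonstration.
def _expert_tokens_layout_example():
    sorted_expert_idx = np.array([0, 0, 2, 2, 2], dtype=np.int64)
    counts = np.bincount(sorted_expert_idx, minlength=3)  # type 1: [2, 0, 3]
    cumsum = np.cumsum(counts)                            # type 0: [2, 2, 5]
    ids, per_id = np.unique(sorted_expert_idx, return_counts=True)
    pairs = np.column_stack((ids, per_id))                # type 2: [[0, 2], [2, 3]]
    return counts, cumsum, pairs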
def npu_pta(x, expert_idx, scale, offset, active_num, expert_capacity,
expert_num, drop_pad_mode, expert_tokens_num_type,
expert_tokens_num_flag, quant_mode, active_expert_range,
row_idx_type):
expanded_x, expanded_row_idx, expert_token_cumsum_or_count, expanded_scale = torch.ops._C_ascend.npu_moe_init_routing_custom(
x,
expert_idx,
scale=scale,
offset=offset,
active_num=active_num,
expert_capacity=expert_capacity,
expert_num=expert_num,
drop_pad_mode=drop_pad_mode,
expert_tokens_num_type=expert_tokens_num_type,
expert_tokens_num_flag=expert_tokens_num_flag,
quant_mode=quant_mode,
active_expert_range=active_expert_range,
row_idx_type=row_idx_type)
return expanded_x, expanded_row_idx, expert_token_cumsum_or_count, expanded_scale
def cmp_out_golden(x_golden, x_out, dtype):
if dtype == 'int8':
cmp = np.isclose(x_out.cpu().numpy()[:len(x_golden)], x_golden, atol=1)
else:
cmp = np.isclose(x_out.cpu().numpy()[:len(x_golden)],
x_golden,
rtol=1e-05,
atol=1e-05)
return np.all(cmp)
def test_moe_npu(x, expert_idx, scale, offset, active_num, expert_capacity,
expert_num, drop_pad_mode, expert_tokens_num_type,
expert_tokens_num_flag, quant_mode, active_expert_range,
row_idx_type):
x_npu = x.npu()
expert_idx_npu = expert_idx.npu()
scale_npu = scale.npu() if scale is not None else None
offset_npu = offset.npu() if offset is not None else None
x_numpy = x.numpy()
expert_idx_numpy = expert_idx.numpy()
scale_numpy = scale.numpy() if scale is not None else None
offset_numpy = offset.numpy() if offset is not None else None
expanded_x_golden, expanded_row_idx_golden, expert_token_cumsum_or_count_golden, expanded_scale_golden = moe_init_routing_golden(
x_numpy, expert_idx_numpy, scale_numpy, offset_numpy, active_num,
expert_capacity, expert_num, drop_pad_mode, expert_tokens_num_type,
expert_tokens_num_flag, active_expert_range, quant_mode, row_idx_type)
expanded_x, expanded_row_idx, expert_token_cumsum_or_count, expanded_scale = npu_pta(
x_npu, expert_idx_npu, scale_npu, offset_npu, active_num,
expert_capacity, expert_num, drop_pad_mode, expert_tokens_num_type,
expert_tokens_num_flag, quant_mode, active_expert_range, row_idx_type)
if quant_mode == -1:
expanded_x_result = cmp_out_golden(expanded_x_golden, expanded_x,
"float32")
else:
expanded_x_result = cmp_out_golden(expanded_x_golden, expanded_x,
"int8")
expanded_row_idx_result = cmp_out_golden(expanded_row_idx_golden,
expanded_row_idx, "int32")
if expert_tokens_num_flag:
expert_tokens_result = cmp_out_golden(
expert_token_cumsum_or_count_golden, expert_token_cumsum_or_count,
"int64")
else:
expert_tokens_result = True
if quant_mode == 1 or (quant_mode == -1 and scale is not None):
expand_scale_result = cmp_out_golden(expanded_scale_golden.flatten(),
expanded_scale, "float32")
else:
expand_scale_result = True
compare_result = expanded_x_result and expanded_row_idx_result and expert_tokens_result and expand_scale_result
# print('=======case result=======: ', compare_result)
return compare_result
def test_moe_init_routing_custom():
failed_test_cnt = 0
drop_pad_mode = [0, 1]
expert_tokens_num_type = [0, 1, 2]
expert_tokens_num_flag = [True, False]
quant_mode = [0, 1, -1]
row_idx_type = [0, 1]
scale_type = [0, 1, 2]
product_result = itertools.product(drop_pad_mode, expert_tokens_num_type,
expert_tokens_num_flag, quant_mode,
row_idx_type, scale_type)
    for (drop_pad_mode_, expert_tokens_num_type_, expert_tokens_num_flag_,
         quant_mode_, row_idx_type_, scale_type_) in product_result:
expert_num_ = random.randint(2, 500)
expert_start = random.randint(0, expert_num_ - 1)
expert_end = random.randint(expert_start + 1, expert_num_)
active_expert_range_ = [expert_start, expert_end]
N = random.randint(1, 100)
H = random.randint(12, 100)
K = random.randint(1, 12)
x_ = torch.randn(N, H, dtype=torch.float16) * 5
expert_capacity_ = random.randint(1, N - 1) if N > 1 else 1
expert_idx_ = torch.randint(0,
expert_num_ - 1, (N, K),
dtype=torch.int32)
active_num_ = N * K
if drop_pad_mode_ == 1:
active_expert_range_ = [0, expert_num_]
expert_tokens_num_type_ = 1
row_idx_type_ = 0
if quant_mode_ == 0:
scale_ = torch.randn(1, dtype=torch.float)
offset_ = torch.randn(1, dtype=torch.float)
elif quant_mode_ == -1:
scale_ = None
offset_ = None
else:
if scale_type_ == 0:
scale_ = None
offset_ = None
elif scale_type_ == 1:
scale_ = torch.randn(1, H, dtype=torch.float)
offset_ = None
else:
scale_ = torch.randn(active_expert_range_[1] -
active_expert_range_[0],
H,
dtype=torch.float)
offset_ = None
result_pta = test_moe_npu(x_, expert_idx_, scale_, offset_,
active_num_, expert_capacity_, expert_num_,
drop_pad_mode_, expert_tokens_num_type_,
expert_tokens_num_flag_, quant_mode_,
active_expert_range_, row_idx_type_)
if not result_pta:
failed_test_cnt += 1
    assert failed_test_cnt == 0

View File

@@ -0,0 +1,351 @@
# Copyright 2023 The vLLM team.
# Copyright (c) Huawei Technologies Co., Ltd. 2024-2025. All rights reserved.
# Adapted from
# https://github.com/vllm-project/vllm/blob/main/vllm/tests/kernels/test_rotary_embedding.py
import gc
from typing import Optional, Tuple, Union
import pytest
import torch
import torch.nn as nn
from vllm_ascend.utils import enable_custom_op
enable_custom_op()
# Only the Neox-style (is_neox_style=True) scenario is supported for now
IS_NEOX_STYLE = [True]
DTYPES = [torch.half]
HEAD_SIZES = [64, 96, 128, 256]
ROTARY_DIMS = [None, 32] # None means rotary dim == head size
NUM_HEADS = [17] # Arbitrary values for testing
BATCH_SIZES = [5] # Arbitrary values for testing
SEQ_LENS = [11, 4096] # Arbitrary values for testing
NUM_TOKENS = [10, 21]
SEEDS = [0]
DEVICES = [f"npu:{0}"]
# Set tolerance to 1e-3 for quant ops
DEFAULT_ATOL = 1e-3
DEFAULT_RTOL = 1e-3
def _apply_rotary_emb(
x: torch.Tensor,
cos: torch.Tensor,
sin: torch.Tensor,
is_neox_style: bool,
) -> torch.Tensor:
"""
Args:
x: [num_tokens, num_heads, head_size]
cos: [num_tokens, head_size // 2]
sin: [num_tokens, head_size // 2]
is_neox_style: Whether to use the Neox-style or GPT-J-style rotary
positional embeddings.
"""
cos = cos.unsqueeze(-2).to(x.dtype)
sin = sin.unsqueeze(-2).to(x.dtype)
if is_neox_style:
x1, x2 = torch.chunk(x, 2, dim=-1)
else:
x1 = x[..., ::2]
x2 = x[..., 1::2]
o1 = x1 * cos - x2 * sin
o2 = x2 * cos + x1 * sin
if is_neox_style:
return torch.cat((o1, o2), dim=-1)
else:
return torch.stack((o1, o2), dim=-1).flatten(-2)
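# Minimal usage sketch for _apply_rotary_emb (added for readability, not part of
# the parametrized tests); the tiny shapes below are assumptions for illustration.
# With cos == 1 and sin == 0 the rotation is the identity for both styles.
def _apply_rotary_emb_example():
    x = torch.arange(4, dtype=torch.float).view(1, 1, 4)  # [tokens, heads, head_size]
    cos = torch.ones(1, 2)
    sin = torch.zeros(1, 2)
    assert torch.equal(_apply_rotary_emb(x, cos, sin, is_neox_style=True), x)
    assert torch.equal(_apply_rotary_emb(x, cos, sin, is_neox_style=False), x)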
# adapted from https://github.com/vllm-project/vllm/vllm/model_executor/layers/rotary_embedding.py
class RotaryEmbedding(nn.Module):
"""Original rotary positional embedding."""
def __init__(
self,
head_size: int,
rotary_dim: int,
max_position_embeddings: int,
base: int,
is_neox_style: bool,
dtype: torch.dtype,
) -> None:
super().__init__()
self.head_size = head_size
self.rotary_dim = rotary_dim
self.max_position_embeddings = max_position_embeddings
self.base = base
self.is_neox_style = is_neox_style
self.dtype = dtype
cache = self._compute_cos_sin_cache()
cache = cache.to(dtype)
self.cos_sin_cache: torch.Tensor
self.register_buffer("cos_sin_cache", cache, persistent=False)
def _compute_inv_freq(self, base: Union[int, float]) -> torch.Tensor:
"""Compute the inverse frequency."""
# NOTE(woosuk): To exactly match the HF implementation, we need to
# use CPU to compute the cache and then move it to GPU. However, we
# create the cache on GPU for faster initialization. This may cause
# a slight numerical difference between the HF implementation and ours.
inv_freq = 1.0 / (base**(torch.arange(
0, self.rotary_dim, 2, dtype=torch.float) / self.rotary_dim))
return inv_freq
def _compute_cos_sin_cache(self) -> torch.Tensor:
"""Compute the cos and sin cache."""
inv_freq = self._compute_inv_freq(self.base)
t = torch.arange(self.max_position_embeddings, dtype=torch.float)
freqs = torch.einsum("i,j -> ij", t, inv_freq)
cos = freqs.cos()
sin = freqs.sin()
cache = torch.cat((cos, sin), dim=-1)
return cache
def forward_native(
self,
positions: torch.Tensor,
query: torch.Tensor,
key: torch.Tensor,
offsets: Optional[torch.Tensor] = None,
) -> Tuple[torch.Tensor, torch.Tensor]:
"""A PyTorch-native implementation of forward()."""
if offsets is not None:
positions = positions + offsets
positions = positions.flatten()
num_tokens = positions.shape[0]
cos_sin = self.cos_sin_cache.index_select(0, positions)
cos, sin = cos_sin.chunk(2, dim=-1)
query_shape = query.shape
query = query.view(num_tokens, -1, self.head_size)
query_rot = query[..., :self.rotary_dim]
query_pass = query[..., self.rotary_dim:]
query_rot = _apply_rotary_emb(query_rot, cos, sin, self.is_neox_style)
query = torch.cat((query_rot, query_pass), dim=-1).reshape(query_shape)
key_shape = key.shape
key = key.view(num_tokens, -1, self.head_size)
key_rot = key[..., :self.rotary_dim]
key_pass = key[..., self.rotary_dim:]
key_rot = _apply_rotary_emb(key_rot, cos, sin, self.is_neox_style)
key = torch.cat((key_rot, key_pass), dim=-1).reshape(key_shape)
return query, key
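# Minimal usage sketch of the reference module above (added for readability, not
# part of the parametrized tests); the toy sizes are assumptions for illustration.
# The cos/sin cache has shape (max_position, rotary_dim), with cos in the first
# half of the last dim and sin in the second, so forward_native splits it with chunk(2).
def _rotary_embedding_cache_example():
    rope = RotaryEmbedding(head_size=8,
                           rotary_dim=8,
                           max_position_embeddings=16,
                           base=10000,
                           is_neox_style=True,
                           dtype=torch.float32)
    assert rope.cos_sin_cache.shape == (16, 8)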
# Test with a leading dimension, merging batch_size and seq_len into num_tokens
@pytest.mark.parametrize("is_neox_style", IS_NEOX_STYLE)
@pytest.mark.parametrize("batch_size", BATCH_SIZES)
@pytest.mark.parametrize("seq_len", SEQ_LENS)
@pytest.mark.parametrize("num_heads", NUM_HEADS)
@pytest.mark.parametrize("head_size", HEAD_SIZES)
@pytest.mark.parametrize("rotary_dim", ROTARY_DIMS)
@pytest.mark.parametrize("dtype", DTYPES)
@pytest.mark.parametrize("seed", SEEDS)
@pytest.mark.parametrize("device", DEVICES)
@torch.inference_mode()
def test_rotary_embedding_quant_with_leading_dim(
is_neox_style: bool,
batch_size: int,
seq_len: int,
num_heads: int,
head_size: int,
rotary_dim: Optional[int],
dtype: torch.dtype,
seed: int,
device: str,
max_position: int = 8192,
base: int = 10000,
) -> None:
    if rotary_dim is None:
        rotary_dim = head_size
    torch.manual_seed(seed)
    torch.set_default_device(device)
rope = RotaryEmbedding(head_size, rotary_dim, max_position, base,
is_neox_style, dtype)
rope = rope.to(dtype=dtype)
num_tokens = batch_size * seq_len
positions = torch.randint(0, max_position, (batch_size * seq_len, ))
qkv_tensor = torch.randn(num_tokens,
num_heads * head_size * 3,
dtype=dtype)
query, key, _ = qkv_tensor.split(
[num_heads * head_size, num_heads * head_size, num_heads * head_size],
dim=-1,
)
ref_query, ref_key = rope.forward_native(positions, query, key)
query, key = torch.ops._C_ascend.rotary_embedding(
positions,
query,
key,
rope.head_size,
rope.cos_sin_cache,
rope.is_neox_style,
)
# Compare the results.
torch.testing.assert_close(query.view(ref_query.size()),
ref_query,
atol=DEFAULT_ATOL,
rtol=DEFAULT_RTOL)
torch.testing.assert_close(key.view(ref_key.size()),
ref_key,
atol=DEFAULT_ATOL,
rtol=DEFAULT_RTOL)
gc.collect()
torch.npu.empty_cache()
torch.npu.reset_peak_memory_stats()
class ModelwithRotaryEmbedding(nn.Module):
def __init__(
self,
hidden_size: int,
num_heads: int,
head_size: int,
rotary_dim: int,
max_position_embeddings: int,
base: int,
is_neox_style: bool,
dtype: torch.dtype,
) -> None:
super().__init__()
self.qkv_proj = nn.Linear(hidden_size, num_heads * head_size * 3)
self.rope = RotaryEmbedding(
head_size=head_size,
rotary_dim=rotary_dim,
max_position_embeddings=max_position_embeddings,
base=base,
is_neox_style=is_neox_style,
dtype=dtype,
)
self.o_proj = nn.Linear(num_heads * head_size, hidden_size)
def forward(
self,
positions: torch.Tensor,
hidden_states: torch.Tensor,
offsets: Optional[torch.Tensor] = None,
) -> torch.Tensor:
        # We simulate a simple attention layer to check that it can be seamlessly captured into an aclgraph.
qkv = self.qkv_proj(hidden_states)
q, k, v = qkv.chunk(3, dim=-1)
query, key = torch.ops._C_ascend.rotary_embedding(
positions,
q,
k,
self.rope.head_size,
self.rope.cos_sin_cache,
self.rope.is_neox_style,
)
query = query.view(q.shape)
key = key.view(k.shape)
o = self.o_proj(query)
return o
# The first graph replay seems to have accuracy issues when pytest is run directly
# on the ops folder, so add a warmup graph replay as a workaround.
ACL_GRAPH_FIRST_RUN = True
@pytest.mark.parametrize("is_neox_style", IS_NEOX_STYLE)
@pytest.mark.parametrize("num_tokens", BATCH_SIZES)
@pytest.mark.parametrize("num_heads", NUM_HEADS)
@pytest.mark.parametrize("head_size", HEAD_SIZES)
@pytest.mark.parametrize("rotary_dim", ROTARY_DIMS)
@pytest.mark.parametrize("dtype", DTYPES)
@pytest.mark.parametrize("seed", SEEDS)
@pytest.mark.parametrize("device", DEVICES)
@torch.inference_mode()
def test_capture_rotary_embedding_in_aclgraph(
is_neox_style: bool,
num_tokens: int,
num_heads: int,
head_size: int,
rotary_dim: int,
dtype: torch.dtype,
seed: int,
device: str,
max_position_embeddings: int = 8192,
base: int = 10000,
):
"""Test if the rotary embedding can be captured in aclgraph."""
torch.manual_seed(seed)
torch.set_default_device(device)
if rotary_dim is None:
rotary_dim = head_size
model = ModelwithRotaryEmbedding(
hidden_size=num_heads * head_size,
num_heads=num_heads,
head_size=head_size,
rotary_dim=rotary_dim,
max_position_embeddings=max_position_embeddings,
base=base,
is_neox_style=is_neox_style,
dtype=dtype,
)
def custom_op_checking_backend(gm: torch.fx.GraphModule, example_input):
        # Validate that the rotary_embedding custom kernel is indeed inside the
        # graph via a string match.
graph = str(gm.graph)
assert "_C_ascend.rotary_embedding" in graph
return gm
static_positions = torch.randint(0, max_position_embeddings,
(num_tokens, ))
static_hidden_states = torch.randn(num_tokens,
num_heads * head_size,
dtype=dtype,
device="npu")
compiled_model = torch.compile(model, backend=custom_op_checking_backend)
stream = torch.npu.Stream()
stream.wait_stream(torch.npu.current_stream())
with torch.npu.stream(stream):
# warmup the fx graph before capture
for i in range(3):
static_output = compiled_model(static_positions,
static_hidden_states,
offsets=None)
stream.wait_stream(torch.npu.current_stream())
aclgraph = torch.npu.NPUGraph()
with torch.npu.graph(aclgraph):
# Capture the model in aclgraph.
static_output = compiled_model(static_positions, static_hidden_states)
    # Fill the static inputs with new random data, then replay the captured graph.
random_filled_positions = torch.randint(0,
max_position_embeddings,
(num_tokens, ),
device="npu")
random_filled_hidden_states = torch.randn(num_tokens,
num_heads * head_size,
dtype=dtype,
device="npu")
static_positions.copy_(random_filled_positions)
static_hidden_states.copy_(random_filled_hidden_states)
aclgraph.replay()
    global ACL_GRAPH_FIRST_RUN
    if ACL_GRAPH_FIRST_RUN:
        ACL_GRAPH_FIRST_RUN = False
return
output_reference = model(static_positions, static_hidden_states)
torch.testing.assert_close(static_output,
output_reference,
atol=DEFAULT_ATOL,
rtol=DEFAULT_RTOL)
gc.collect()
torch.npu.empty_cache()
torch.npu.reset_peak_memory_stats()

View File

@@ -0,0 +1,98 @@
import gc
from typing import Tuple
import pytest
import torch
import torch_npu # noqa: F401
import vllm_ascend.platform # noqa: F401
from vllm_ascend.utils import enable_custom_op
enable_custom_op()
# Test parameters
DTYPES = [torch.int32]
#SHAPES = [(100,), (5, 20), (3, 4, 5)] # Various tensor shapes
#SHAPES = [(3, 4, 8), (3, 4, 5)] # Various tensor shapes
SHAPES = [(3, 4, 3)]
DEVICES = [f"npu:{0}"]
SEEDS = [0]
def get_masked_input_and_mask_ref(
input_: torch.Tensor, org_vocab_start_index: int,
org_vocab_end_index: int, num_org_vocab_padding: int,
added_vocab_start_index: int,
added_vocab_end_index: int) -> Tuple[torch.Tensor, torch.Tensor]:
"""Reference implementation for verification"""
org_vocab_mask = (input_ >= org_vocab_start_index) & (
input_ < org_vocab_end_index)
added_vocab_mask = (input_ >= added_vocab_start_index) & (
input_ < added_vocab_end_index)
added_offset = added_vocab_start_index - (
org_vocab_end_index - org_vocab_start_index) - num_org_vocab_padding
valid_offset = (org_vocab_start_index *
org_vocab_mask) + (added_offset * added_vocab_mask)
vocab_mask = org_vocab_mask | added_vocab_mask
masked_input = vocab_mask * (input_ - valid_offset)
return masked_input, ~vocab_mask
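# Worked example for the reference above (added for readability, not exercised by
# the parametrized test); the token ids and vocab ranges are assumptions chosen
# only for illustration: org vocab [100, 200), no padding, added vocab [300, 400).
def _masked_input_example():
    ids = torch.tensor([150, 350, 50], dtype=torch.int32)
    masked, oov = get_masked_input_and_mask_ref(ids, 100, 200, 0, 300, 400)
    # 150 maps into the org slice, 350 is appended after the 100 org rows,
    # 50 is out of both ranges and therefore masked to 0.
    assert masked.tolist() == [50, 150, 0]
    assert oov.tolist() == [False, False, True]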
@pytest.mark.parametrize("shape", SHAPES)
@pytest.mark.parametrize("dtype", DTYPES)
@pytest.mark.parametrize("device", DEVICES)
@pytest.mark.parametrize("seed", SEEDS)
@torch.inference_mode()
def test_get_masked_input_and_mask(
shape: Tuple[int, ...],
dtype: torch.dtype,
device: str,
seed: int,
) -> None:
# Set random seed
torch.manual_seed(seed)
torch.set_default_device(device)
# Generate random input tensor
input_tensor = torch.randint(0, 1000, shape, dtype=dtype)
# Test parameters
test_case = {
"org_start": 100,
"org_end": 200,
"padding": 0,
"added_start": 300,
"added_end": 400,
}
# Get reference result
ref_masked_input, ref_mask = get_masked_input_and_mask_ref(
input_tensor, test_case["org_start"], test_case["org_end"],
test_case["padding"], test_case["added_start"], test_case["added_end"])
# Get custom op result
print("input_tensor:", input_tensor)
custom_masked_input, custom_mask = torch.ops._C_ascend.get_masked_input_and_mask(
input_tensor, test_case["org_start"], test_case["org_end"],
test_case["padding"], test_case["added_start"], test_case["added_end"])
ref_masked_input = ref_masked_input.to(dtype)
print("custom_masked_input:", custom_masked_input)
print("ref_masked_input:", ref_masked_input)
print("custom_mask:", custom_mask)
print("ref_mask:", ref_mask)
# Compare results
torch.testing.assert_close(
custom_masked_input,
ref_masked_input,
rtol=1e-5,
atol=1e-5,
msg=f"Masked input mismatch for case: {test_case}")
torch.testing.assert_close(custom_mask,
ref_mask,
rtol=1e-5,
atol=1e-5,
msg=f"Mask mismatch for case: {test_case}")
gc.collect()
torch.npu.empty_cache()
torch.npu.reset_peak_memory_stats()

View File

@@ -0,0 +1,361 @@
from typing import Optional
import pytest
import torch
import torch.nn.functional as F
from vllm_ascend.ops.triton.mamba.causal_conv1d import (PAD_SLOT_ID,
causal_conv1d_fn)
from vllm_ascend.ops.triton.mamba.causal_conv1d import \
causal_conv1d_update_npu as causal_conv1d_update
def validate_cmp(y_cal, y_ref, dtype, device='npu'):
y_cal = y_cal.to(device)
y_ref = y_ref.to(device)
if dtype == torch.float16:
torch.testing.assert_close(y_ref,
y_cal,
rtol=3e-03,
atol=1e-02,
equal_nan=True)
elif dtype == torch.bfloat16:
torch.testing.assert_close(y_ref,
y_cal,
rtol=1e-02,
atol=1e-02,
equal_nan=True)
elif dtype == torch.float32:
torch.testing.assert_close(y_ref,
y_cal,
rtol=1e-03,
atol=4e-03,
equal_nan=True)
    elif dtype in (torch.int32, torch.int64, torch.int16, torch.int8,
                   torch.uint32, torch.bool):
        assert torch.equal(y_cal, y_ref)
else:
        raise ValueError(
            'Unsupported dtype for comparison: {}'.format(dtype))
def causal_conv1d_ref(
x: torch.Tensor,
weight: torch.Tensor,
bias: Optional[torch.Tensor] = None,
initial_states: Optional[torch.Tensor] = None,
return_final_states: bool = False,
final_states_out: Optional[torch.Tensor] = None,
activation: Optional[str] = "silu",
):
"""
x: (batch, dim, seqlen)
weight: (dim, width)
bias: (dim,)
initial_states: (batch, dim, width - 1)
final_states_out: (batch, dim, width - 1)
out: (batch, dim, seqlen)
"""
if activation not in [None, "silu", "swish"]:
raise NotImplementedError("activation must be None, silu, or swish")
dtype_in = x.dtype
x = x.to(weight.dtype)
seqlen = x.shape[-1]
dim, width = weight.shape
if initial_states is None:
out = F.conv1d(x,
weight.unsqueeze(1),
bias,
padding=width - 1,
groups=dim)
else:
x = torch.cat([initial_states, x], dim=-1)
out = F.conv1d(x, weight.unsqueeze(1), bias, padding=0, groups=dim)
out = out[..., :seqlen]
if return_final_states:
final_states = F.pad(x, (width - 1 - x.shape[-1], 0)).to(
dtype_in) # (batch, dim, width - 1)
if final_states_out is not None:
final_states_out.copy_(final_states)
else:
final_states_out = final_states
out = (out if activation is None else F.silu(out)).to(dtype=dtype_in)
return (out, None) if not return_final_states else (out, final_states_out)
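# Minimal usage sketch for causal_conv1d_ref (added for readability, not exercised
# by the parametrized tests); the shapes are assumptions chosen for illustration.
def _causal_conv1d_ref_example():
    x = torch.randn(1, 3, 5)    # (batch, dim, seqlen)
    weight = torch.randn(3, 3)  # (dim, width)
    out, final_states = causal_conv1d_ref(x,
                                          weight,
                                          bias=None,
                                          return_final_states=True,
                                          activation="silu")
    assert out.shape == (1, 3, 5)
    assert final_states.shape == (1, 3, 2)  # last width - 1 inputs per channel
    return out, final_states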
def causal_conv1d_fn_pytorch(
x: torch.Tensor,
weight: torch.Tensor,
query_start_loc: torch.Tensor,
cache_indices: torch.Tensor,
has_initial_state: torch.Tensor,
conv_states: torch.Tensor,
bias: Optional[torch.Tensor] = None,
activation: Optional[str] = "silu",
pad_slot_id: int = PAD_SLOT_ID,
):
"""
x: (batch, dim, seqlen) or (dim,cu_seq_len) for varlen
sequences are concatenated from left to right for varlen
weight: (dim, width)
bias: (dim,)
query_start_loc: (batch + 1) int32
        The cumulative sequence lengths of the sequences in
        the batch, used to index into the sequences; prepended with 0.
for example: query_start_loc = torch.Tensor([0,10,16,17]),
x.shape=(dim,17)
cache_indices: (batch) int32
indicates the corresponding state index,
like so: conv_state = conv_states[cache_indices[batch_id]]
has_initial_state: (batch) bool
        indicates whether the kernel should take the current state as the
        initial state for the calculation
conv_states: (...,dim,width - 1) itype
updated inplace if provided
activation: either None or "silu" or "swish"
pad_slot_id: int
if cache_indices is passed, lets the kernel identify padded
entries that will not be processed,
for example: cache_indices = [pad_slot_id, 1, 20, pad_slot_id]
in this case, the kernel will not process entries at
indices 0 and 3
out: (batch, dim, seqlen)
"""
if activation not in [None, "silu", "swish"]:
raise NotImplementedError("activation must be None, silu, or swish")
if x.stride(-1) != 1:
x = x.contiguous()
bias = bias.contiguous() if bias is not None else None
out_ref = []
out_ref_b = []
seqlens = query_start_loc[1:] - query_start_loc[:-1]
seqlens = seqlens.tolist()
splits = torch.split(x, seqlens, dim=-1)
width = weight.shape[1]
for i in range(len(seqlens)):
x_s = splits[i]
if cache_indices[i] == PAD_SLOT_ID:
continue
out_ref_b.append(
causal_conv1d_ref(
x_s,
weight,
bias,
activation=activation,
return_final_states=True,
final_states_out=conv_states[cache_indices[i]][..., :(
width - 1)].unsqueeze(0),
initial_states=conv_states[cache_indices[i]][..., :(width - 1)]
if has_initial_state[i] else None))
out_ref.append(torch.cat([t[0] for t in out_ref_b], dim=-1))
out_ref_tensor = torch.cat(out_ref, dim=0)
return out_ref_tensor
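# Sketch of the varlen layout consumed above (added for readability, not exercised
# by the parametrized tests); the toy sizes are assumptions for illustration: two
# sequences of lengths 3 and 2 are concatenated along the last axis, and
# query_start_loc stores their cumulative boundaries prepended with 0.
def _varlen_layout_example():
    dim, width = 4, 3
    x = torch.randn(dim, 5)  # (dim, cu_seq_len)
    weight = torch.randn(dim, width)
    query_start_loc = torch.tensor([0, 3, 5], dtype=torch.int32)
    cache_indices = torch.tensor([0, 1], dtype=torch.int32)
    has_initial_state = torch.tensor([False, False])
    conv_states = torch.zeros(2, dim, width - 1)
    out = causal_conv1d_fn_pytorch(x, weight, query_start_loc, cache_indices,
                                   has_initial_state, conv_states)
    assert out.shape == (dim, 5)
    return out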
@pytest.mark.parametrize('has_initial_state', [False, True])
@pytest.mark.parametrize('itype', [torch.bfloat16])
@pytest.mark.parametrize('silu_activation', [True])
@pytest.mark.parametrize('has_bias', [True])
@pytest.mark.parametrize('seq_len', [[128, 1024, 2048, 4096]])
@pytest.mark.parametrize('extra_state_len', [0, 2])
@pytest.mark.parametrize('width', [2, 4])
@pytest.mark.parametrize('dim', [4160])
def test_causal_conv1d(dim, width, extra_state_len, seq_len, has_bias,
silu_activation, itype, has_initial_state):
torch.random.manual_seed(0)
device = "npu"
cu_seqlen, num_seq = sum(seq_len), len(seq_len)
state_len = width - 1 + extra_state_len
x = torch.randn(cu_seqlen, dim, device=device, dtype=itype).transpose(0, 1)
weight = torch.randn(dim, width, device=device, dtype=itype)
query_start_loc = torch.cumsum(torch.tensor([0] + seq_len,
device=device,
dtype=torch.int32),
dim=0)
cache_indices = torch.arange(num_seq, device=device, dtype=torch.int32)
has_initial_state_tensor = torch.tensor([has_initial_state] * num_seq,
device=device,
dtype=torch.bool)
activation = None if not silu_activation else "silu"
if has_initial_state:
conv_states = torch.randn((num_seq, state_len, dim),
device=device,
dtype=itype).transpose(-1, -2)
conv_states_ref = torch.randn(
(num_seq, state_len, dim), device=device,
dtype=itype).transpose(-1, -2).copy_(conv_states)
else:
conv_states = torch.zeros((num_seq, state_len, dim),
device=device,
dtype=itype).transpose(-1, -2)
conv_states_ref = torch.zeros((num_seq, state_len, dim),
device=device,
dtype=itype).transpose(-1, -2)
if has_bias:
bias = torch.randn(dim, device=device, dtype=itype)
else:
bias = None
out_ref = causal_conv1d_fn_pytorch(
x,
weight,
bias=bias,
activation=activation,
conv_states=conv_states_ref,
has_initial_state=has_initial_state_tensor,
cache_indices=cache_indices,
query_start_loc=query_start_loc)
out = causal_conv1d_fn(x,
weight,
bias=bias,
activation=activation,
conv_states=conv_states,
has_initial_state=has_initial_state_tensor,
cache_indices=cache_indices,
query_start_loc=query_start_loc)
validate_cmp(out, out_ref, itype)
validate_cmp(conv_states, conv_states_ref, itype)
def causal_conv1d_update_ref(x,
conv_state,
weight,
bias=None,
activation=None,
cache_seqlens=None):
"""
x: (batch, dim) or (batch, dim, seqlen)
conv_state: (batch, dim, state_len), where state_len >= width - 1
weight: (dim, width)
bias: (dim,)
cache_seqlens: (batch,), dtype int32.
If not None, the conv_state is treated as a circular buffer.
The conv_state will be updated by copying x to the
conv_state starting at the index
@cache_seqlens % state_len before performing the convolution.
out: (batch, dim) or (batch, dim, seqlen)
"""
if activation not in [None, "silu", "swish"]:
raise NotImplementedError("activation must be None, silu, or swish")
dtype_in = x.dtype
unsqueeze = x.dim() == 2
if unsqueeze:
x = x.unsqueeze(-1)
batch, dim, seqlen = x.shape
width = weight.shape[1]
state_len = conv_state.shape[-1]
assert conv_state.shape == (batch, dim, state_len)
assert weight.shape == (dim, width)
if cache_seqlens is None:
x_new = torch.cat([conv_state, x], dim=-1).to(
weight.dtype) # (batch, dim, state_len + seqlen)
conv_state.copy_(x_new[:, :, -state_len:])
else:
width_idx = torch.arange(
-(width - 1), 0, dtype=torch.long,
device=x.device).unsqueeze(0) + cache_seqlens.unsqueeze(1)
width_idx = (torch.remainder(width_idx, state_len).unsqueeze(1).expand(
-1, dim, -1))
x_new = torch.cat([conv_state.gather(2, width_idx), x],
dim=-1).to(weight.dtype)
copy_idx = torch.arange(
seqlen, dtype=torch.long,
device=x.device).unsqueeze(0) + cache_seqlens.unsqueeze(1)
copy_idx = torch.remainder(copy_idx,
state_len).unsqueeze(1).expand(-1, dim, -1)
conv_state.scatter_(2, copy_idx, x)
out = F.conv1d(x_new, weight.unsqueeze(1), bias, padding=0,
groups=dim)[:, :, -seqlen:]
if unsqueeze:
out = out.squeeze(-1)
return (out if activation is None else F.silu(out)).to(dtype=dtype_in)
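# Sketch of a single-token decode update against the reference above (added for
# readability, not exercised by the parametrized tests); toy sizes are assumptions.
# The conv_state window slides left by one and the new token is appended at the end.
def _causal_conv1d_update_ref_example():
    batch, dim, width = 2, 3, 4
    x = torch.randn(batch, dim)  # one new token per sequence
    conv_state = torch.randn(batch, dim, width - 1)
    weight = torch.randn(dim, width)
    old_state = conv_state.clone()
    out = causal_conv1d_update_ref(x, conv_state, weight, activation="silu")
    assert out.shape == (batch, dim)
    assert torch.equal(conv_state[..., -1], x)
    assert torch.equal(conv_state[..., :-1], old_state[..., 1:])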
@pytest.mark.parametrize("itype", [torch.bfloat16])
@pytest.mark.parametrize("silu_activation", [True])
@pytest.mark.parametrize("has_bias", [False, True])
@pytest.mark.parametrize("seqlen", [1, 3])
@pytest.mark.parametrize("width", [3, 4])
@pytest.mark.parametrize("dim", [2048 + 16, 4096])
# tests correctness in case subset of the sequences are padded
@pytest.mark.parametrize("with_padding", [True, False])
@pytest.mark.parametrize("batch_size", [3, 64])
def test_causal_conv1d_update_with_batch_gather(batch_size, with_padding, dim,
width, seqlen, has_bias,
silu_activation, itype):
device = "npu"
rtol, atol = (3e-4, 1e-3) if itype == torch.float32 else (3e-3, 5e-3)
if itype == torch.bfloat16:
rtol, atol = 1e-2, 5e-2
padding = 5 if with_padding else 0
padded_batch_size = batch_size + padding
    # total_entries = number of cache lines
total_entries = 10 * batch_size
# x will be (batch, dim, seqlen) with contiguous along dim-axis
x = torch.randn(padded_batch_size, seqlen, dim, device=device,
dtype=itype).transpose(1, 2)
x_ref = x.clone()
conv_state_indices = torch.randperm(total_entries)[:batch_size].to(
dtype=torch.int32, device=device)
unused_states_bool = torch.ones(total_entries,
dtype=torch.bool,
device=device)
unused_states_bool[conv_state_indices] = False
padded_state_indices = torch.concat(
[
conv_state_indices,
torch.as_tensor(
[PAD_SLOT_ID] * padding, dtype=torch.int32, device=device),
],
dim=0,
)
# conv_state will be (cache_lines, dim, state_len)
# with contiguous along dim-axis
conv_state = torch.randn(total_entries,
width - 1,
dim,
device=device,
dtype=itype).transpose(1, 2)
conv_state_for_padding_test = conv_state.clone()
weight = torch.randn(dim, width, device=device, dtype=itype)
bias = torch.randn(dim, device=device, dtype=itype) if has_bias else None
conv_state_ref = conv_state[conv_state_indices, :].detach().clone()
activation = None if not silu_activation else "silu"
out = causal_conv1d_update(
x,
conv_state,
weight,
bias,
activation=activation,
conv_state_indices=padded_state_indices,
pad_slot_id=PAD_SLOT_ID,
)
out_ref = causal_conv1d_update_ref(x_ref[:batch_size],
conv_state_ref,
weight,
bias,
activation=activation)
assert torch.equal(conv_state[conv_state_indices, :], conv_state_ref)
assert torch.equal(conv_state[unused_states_bool],
conv_state_for_padding_test[unused_states_bool])
assert torch.allclose(out[:batch_size], out_ref, rtol=rtol, atol=atol)

View File

@@ -0,0 +1,34 @@
import pytest
import torch
import torch.nn.functional as F
from vllm_ascend.ops.triton.fla.l2norm import l2norm_fwd
from vllm_ascend.ops.triton.triton_utils import init_device_properties_triton
@pytest.mark.parametrize(
('B', 'T', 'H', 'D', 'dtype'),
[
pytest.param(*test, id="B{}-T{}-H{}-D{}-{}".format(*test))
for test in [
(1, 63, 1, 60, torch.float),
(2, 500, 4, 64, torch.float),
(2, 1000, 2, 100, torch.float),
(3, 1024, 4, 128, torch.float),
]
],
)
def test_l2norm(B: int, T: int, H: int, D: int, dtype: torch.dtype):
torch.manual_seed(42)
init_device_properties_triton()
device = "npu"
rtol, atol = (3e-4, 1e-3) if dtype == torch.float32 else (3e-3, 5e-3)
if dtype == torch.bfloat16:
rtol, atol = 1e-2, 5e-2
x = torch.randn(B, T, H, D, dtype=dtype).to(device).requires_grad_(True)
x = x * 0.5 + 0.3
ref = F.normalize(x, dim=-1, p=2)
tri = l2norm_fwd(x)
assert torch.allclose(tri, ref, rtol=rtol, atol=atol)

View File

@@ -0,0 +1,141 @@
import gc
import pytest
import torch
from vllm_ascend.ops.triton.rope import rope_forward_triton
from vllm_ascend.ops.triton.triton_utils import init_device_properties_triton
IS_NEOX_STYLE = [True, False]
DTYPES = [torch.bfloat16, torch.float16]
HEAD_SIZES = [64, 128]
ROTARY_DIMS = [32, 64]
NUM_Q_HEADS = [64]
NUM_K_HEADS = [1]
NUM_TOKENS = [1, 4, 8, 16, 1024]
SEEDS = [0]
DEVICES = [f"npu:{0}"]
DEFAULT_ATOL = 1e-3
DEFAULT_RTOL = 1e-3
def rotate_neox(x: torch.Tensor) -> torch.Tensor:
x1 = x[..., :x.shape[-1] // 2]
x2 = x[..., x.shape[-1] // 2:]
return torch.cat((-x2, x1), dim=-1)
def rotate_gptj(x: torch.Tensor) -> torch.Tensor:
x1 = x[..., ::2]
x2 = x[..., 1::2]
x = torch.stack((-x2, x1), dim=-1)
return x.flatten(-2)
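# Worked example for the two rotation helpers above (added for readability, not
# part of the parametrized tests); the 4-element vector is an assumption chosen
# for illustration: Neox rotates half-blocks, GPT-J rotates interleaved pairs.
def _rotate_helpers_example():
    x = torch.tensor([1., 2., 3., 4.])
    assert torch.equal(rotate_neox(x), torch.tensor([-3., -4., 1., 2.]))
    assert torch.equal(rotate_gptj(x), torch.tensor([-2., 1., -4., 3.]))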
def _rope_pytorch_native(
query, key, cos, sin, rope_dim,
is_neox_style) -> tuple[torch.Tensor, torch.Tensor | None]:
"""PyTorch-native implementation equivalent to forward()."""
assert key is not None
orig_dtype = query.dtype
query_rot = query[..., :rope_dim].to(torch.float32)
key_rot = key[..., :rope_dim].to(torch.float32)
head_size = query.shape[-1]
if rope_dim < head_size:
query_pass = query[..., rope_dim:]
key_pass = key[..., rope_dim:]
if is_neox_style:
cos = cos.repeat(1, 2).unsqueeze(-2).to(torch.float32)
sin = sin.repeat(1, 2).unsqueeze(-2).to(torch.float32)
else:
cos = cos.repeat_interleave(2, dim=-1).unsqueeze(-2).to(torch.float32)
sin = sin.repeat_interleave(2, dim=-1).unsqueeze(-2).to(torch.float32)
rotate_fn = rotate_neox if is_neox_style else rotate_gptj
query_rot = query_rot * cos + rotate_fn(query_rot) * sin
key_rot = key_rot * cos + rotate_fn(key_rot) * sin
if rope_dim < head_size:
query = torch.cat((query_rot.to(orig_dtype), query_pass), dim=-1)
key = torch.cat((key_rot.to(orig_dtype), key_pass), dim=-1)
else:
query = query_rot.to(orig_dtype)
key = key_rot.to(orig_dtype)
return query, key
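# Minimal sketch of the native reference above (added for readability, not part of
# the parametrized tests); toy shapes are assumptions. With cos == 1 and sin == 0
# the rotation is the identity, so outputs match the inputs.
def _rope_native_identity_example():
    num_tokens, heads, head_size, rope_dim = 2, 1, 8, 4
    q = torch.randn(num_tokens, heads, head_size)
    k = torch.randn(num_tokens, heads, head_size)
    cos = torch.ones(num_tokens, rope_dim // 2)
    sin = torch.zeros(num_tokens, rope_dim // 2)
    q_out, k_out = _rope_pytorch_native(q, k, cos, sin, rope_dim,
                                        is_neox_style=True)
    assert torch.allclose(q_out, q) and torch.allclose(k_out, k)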
@pytest.mark.parametrize("is_neox_style", IS_NEOX_STYLE)
@pytest.mark.parametrize("num_tokens", NUM_TOKENS)
@pytest.mark.parametrize("num_q_heads", NUM_Q_HEADS)
@pytest.mark.parametrize("num_k_heads", NUM_K_HEADS)
@pytest.mark.parametrize("head_size", HEAD_SIZES)
@pytest.mark.parametrize("rotary_dim", ROTARY_DIMS)
@pytest.mark.parametrize("dtype", DTYPES)
@pytest.mark.parametrize("seed", SEEDS)
@pytest.mark.parametrize("device", DEVICES)
@torch.inference_mode()
def test_rotary_embedding_triton_kernel(
is_neox_style: bool,
num_tokens: int,
num_q_heads: int,
num_k_heads: int,
head_size: int,
rotary_dim: int,
dtype: torch.dtype,
seed: int,
device: str,
) -> None:
torch.manual_seed(seed)
torch.set_default_device(device)
init_device_properties_triton()
if rotary_dim == -1:
rotary_dim = head_size
sin = torch.randn(num_tokens, rotary_dim // 2, dtype=dtype, device=device)
cos = torch.randn(num_tokens, rotary_dim // 2, dtype=dtype, device=device)
q_trt = torch.randn(num_tokens,
num_q_heads,
head_size,
dtype=dtype,
device=device)
k_trt = torch.randn(num_tokens,
num_k_heads,
head_size,
dtype=dtype,
device=device)
q_gold = torch.randn(num_tokens,
num_q_heads,
head_size,
dtype=dtype,
device=device)
k_gold = torch.randn(num_tokens,
num_k_heads,
head_size,
dtype=dtype,
device=device)
q_trt.copy_(q_gold)
k_trt.copy_(k_gold)
q_trt, k_trt = rope_forward_triton(q_trt,
k_trt,
cos,
sin,
rope_dim=rotary_dim,
is_neox_style=is_neox_style)
q_gold, k_gold = _rope_pytorch_native(q_gold,
k_gold,
cos,
sin,
rope_dim=rotary_dim,
is_neox_style=is_neox_style)
# Compare the results.
torch.testing.assert_close(q_trt.view(q_gold.size()),
q_gold,
atol=DEFAULT_ATOL,
rtol=DEFAULT_RTOL)
torch.testing.assert_close(k_trt.view(k_gold.size()),
k_gold,
atol=DEFAULT_ATOL,
rtol=DEFAULT_RTOL)
gc.collect()
torch.npu.empty_cache()
torch.npu.reset_peak_memory_stats()