[1/N] Refactor nightly test structure (#5479)
### What this PR does / why we need it?
This patch is a series of refactoring actions, including clarifying the
directory structure of nightly tests, refactoring the config retrieval
logic, and optimizing the workflow, etc. This is the first step:
refactoring the directory structure of nightly to make it more readable
and logical.
- vLLM version: v0.13.0
- vLLM main:
5326c89803
Signed-off-by: wangli <wangli858794774@gmail.com>
This commit is contained in:
0
tests/e2e/nightly/single_node/ops/__init__.py
Normal file
0
tests/e2e/nightly/single_node/ops/__init__.py
Normal file
@@ -0,0 +1,439 @@
|
||||
import gc
|
||||
import os
|
||||
import sys
|
||||
from pathlib import Path
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
import torch.distributed as dist
|
||||
import torch.multiprocessing as mp
|
||||
import torch_npu
|
||||
import torchair
|
||||
|
||||
from vllm_ascend.utils import enable_custom_op
|
||||
|
||||
torch.manual_seed(42)  # make random data generation reproducible across runs
# Allow NPU-internal tensor formats (e.g. FRACTAL_NZ) to be materialized.
torch_npu.npu.config.allow_internal_format = True
enable_custom_op()

# Directory for optional per-rank log capture (see redirect_output / output_to_file).
LOG_NAME = "dispatch_gmm_combine_decode_test_logs"
# Default test configuration. NOTE: the insertion order of this dict must match
# run_once's positional parameter order, because the tests pass
# tuple(BASE_KWARGS.values()) straight through mp.spawn.
BASE_KWARGS = {
    "batch_size": 64,
    "token_hidden_size": 7168,
    "moe_intermediate_size": 2048,
    "ep_world_size": 16,
    "moe_expert_num": 64,
    "shared_expert_rank_num": 0,
    "top_k": 8,
    "test_bfloat16": True,
    "enable_dynamic_bs": False,
    "test_graph": False,
    "with_mc2_mask": False
}
|
||||
|
||||
|
||||
def redirect_output(log_file_path):
    """Redirect this process's stdout/stderr into a per-rank log file.

    Returns the opened file object; the caller is responsible for closing it.
    """
    log_path = Path(LOG_NAME) / log_file_path
    log_path.parent.mkdir(parents=True, exist_ok=True)
    # Fix: use the Path computed above instead of re-building the same path
    # by string concatenation (original opened LOG_NAME + "/" + log_file_path).
    f = log_path.open("w")
    # dup2 the raw file descriptors so output from C extensions (which bypass
    # Python-level sys.stdout) is captured as well.
    os.dup2(f.fileno(), sys.stdout.fileno())
    os.dup2(f.fileno(), sys.stderr.fileno())
    return f
|
||||
|
||||
|
||||
def permute_weight(w: torch.Tensor, tile_n):
    """Interleave the last dimension of ``w`` tile-wise.

    The trailing axis (size n) is viewed as (2, n // tile_n, tile_n // 2),
    the first two of those sub-axes are swapped, and the result is flattened
    back to size n. All leading dimensions are preserved.
    """
    *lead, last = w.shape
    axes = list(range(len(lead)))
    axes.extend([-2, -3, -1])
    tiled = w.reshape(*lead, 2, last // tile_n, tile_n // 2)
    return tiled.permute(axes).reshape(*lead, last).contiguous()
|
||||
|
||||
|
||||
def from_inclusive_prefix_sum(pref):
    """Convert an inclusive prefix sum back into per-bucket counts.

    Accepts either a 1-D torch.Tensor or a plain Python sequence and returns
    the same kind of container. Empty inputs are returned unchanged.
    """
    if isinstance(pref, torch.Tensor):
        if pref.numel() == 0:
            return pref
        head = pref[:1]
        return torch.cat((head, pref[1:] - pref[:-1]))

    if not pref:
        return []
    # First bucket equals the first prefix value; the rest are adjacent diffs.
    return [pref[0]] + [b - a for a, b in zip(pref, pref[1:])]
|
||||
|
||||
|
||||
def output_to_file(rank_id):
    # Gate for per-rank log redirection; always disabled so worker output goes
    # to the parent's console. Flip to True to capture a rank's logs under
    # LOG_NAME/ when debugging.
    return False
|
||||
|
||||
|
||||
class DecodeMoeOps(torch.nn.Module):
    """Base module for the decode-phase MoE dispatch/GMM/combine test.

    Holds the quantized grouped-matmul weights and expert-parallel metadata
    shared by both the small-ops reference path (SmallOps) and the fused
    kernel path (FusionOp). Subclasses implement ``_apply_ops``.
    """

    def __init__(self,
                 gmm1_weight,
                 gmm1_weight_scale,
                 gmm2_weight,
                 gmm2_weight_scale,
                 ep_hcomm_info,
                 batch_size,
                 token_hidden_size,
                 moe_intermediate_size,
                 ep_world_size,
                 moe_expert_num,
                 global_rank_id,
                 shared_expert_rank_num=0):
        super().__init__()
        self.ep_hcomm_info = ep_hcomm_info
        self.batch_size = batch_size
        self.token_hidden_size = token_hidden_size
        self.moe_intermediate_size = moe_intermediate_size
        self.ep_world_size = ep_world_size
        self.moe_expert_num = moe_expert_num
        self.global_rank_id = global_rank_id
        self.shared_expert_rank_num = shared_expert_rank_num
        # Ranks below shared_expert_rank_num host one shared expert each; the
        # remaining ranks split the routed experts evenly.
        is_shared_expert = global_rank_id < shared_expert_rank_num
        moe_expert_num_per_rank = moe_expert_num // (ep_world_size -
                                                     shared_expert_rank_num)
        self.local_expert_num = 1 if is_shared_expert else moe_expert_num_per_rank
        self.ep_recv_count_size = self.local_expert_num * ep_world_size
        # Fix: the original pre-allocated four torch.empty placeholder tensors
        # here; they were immediately overwritten by
        # _process_weights_after_loading, so the dead allocations are removed.
        self._process_weights_after_loading(gmm1_weight, gmm1_weight_scale,
                                            gmm2_weight, gmm2_weight_scale)

    def _process_weights_after_loading(self, gmm1_weight, gmm1_weight_scale,
                                       gmm2_weight, gmm2_weight_scale):
        """Cast weights to FRACTAL_NZ and register them as frozen Parameters."""
        # FRACTAL_NZ is the layout consumed by the NPU grouped-matmul kernels.
        gmm1_weight = torch_npu.npu_format_cast(gmm1_weight,
                                                torch_npu.Format.FRACTAL_NZ)
        gmm2_weight = torch_npu.npu_format_cast(gmm2_weight,
                                                torch_npu.Format.FRACTAL_NZ)
        self.gmm1_weight = torch.nn.Parameter(gmm1_weight, requires_grad=False)
        self.gmm1_weight_scale = torch.nn.Parameter(gmm1_weight_scale,
                                                    requires_grad=False)
        self.gmm2_weight = torch.nn.Parameter(gmm2_weight, requires_grad=False)
        self.gmm2_weight_scale = torch.nn.Parameter(gmm2_weight_scale,
                                                    requires_grad=False)

        # float32 copies of the scales for the fused kernel, which takes fp32
        # scale inputs (see FusionOp._apply_ops).
        self.gmm1_weight_scale_fp32 = torch.nn.Parameter(
            gmm1_weight_scale.float(), requires_grad=False)
        self.gmm2_weight_scale_fp32 = torch.nn.Parameter(
            gmm2_weight_scale.float(), requires_grad=False)

    def _apply_ops(self, x, expert_ids, smooth_scales, expert_scales,
                   x_active_mask):
        raise NotImplementedError("To be implemented in subclass")

    def forward(self, x, expert_ids, smooth_scales, expert_scales,
                x_active_mask):
        return self._apply_ops(x, expert_ids, smooth_scales, expert_scales,
                               x_active_mask)
|
||||
|
||||
|
||||
class SmallOps(DecodeMoeOps):
    """Reference path: dispatch, grouped-matmul/SwiGLU, and combine executed
    as separate NPU ops. Its output is the golden result the fused kernel
    (FusionOp) is compared against."""

    def __init__(self,
                 gmm1_weight,
                 gmm1_weight_scale,
                 gmm2_weight,
                 gmm2_weight_scale,
                 ep_hcomm_info,
                 batch_size,
                 token_hidden_size,
                 moe_intermediate_size,
                 ep_world_size,
                 moe_expert_num,
                 global_rank_id,
                 shared_expert_rank_num=0):
        super().__init__(gmm1_weight, gmm1_weight_scale, gmm2_weight,
                         gmm2_weight_scale, ep_hcomm_info, batch_size,
                         token_hidden_size, moe_intermediate_size,
                         ep_world_size, moe_expert_num, global_rank_id,
                         shared_expert_rank_num)
        # No tensor parallelism in this test: empty TP communicator name.
        self.tp_hcomm_info = ""

    def _apply_ops(self, x, expert_ids, smooth_scales, expert_scales,
                   x_active_mask):
        """Dispatch tokens to experts, run the quantized expert MLP, then
        combine the results back; returns (tokens, ep_send_counts slice)."""
        outputs = torch_npu.npu_moe_distribute_dispatch_v2(
            x=x,
            expert_ids=expert_ids,
            expert_scales=expert_scales,
            x_active_mask=x_active_mask,
            group_ep=self.ep_hcomm_info,
            ep_world_size=self.ep_world_size,
            ep_rank_id=self.global_rank_id,
            moe_expert_num=self.moe_expert_num,
            group_tp=self.tp_hcomm_info,
            tp_world_size=1,
            tp_rank_id=0,
            expert_shard_type=0,
            shared_expert_num=1,
            shared_expert_rank_num=self.shared_expert_rank_num,
            quant_mode=2,
            global_bs=self.batch_size * self.ep_world_size,
            expert_token_nums_type=1,  # 0 = prefix sums, 1 = per-expert counts
        )
        expand_x, dynamic_scales, assist_info_for_combine, expert_token_nums, ep_send_counts, tp_send_counts, expand_scales = outputs
        output_dtype = x.dtype

        # GMM1 (gate+up projection) in int8, accumulated to int32.
        y1_int32 = torch_npu.npu_grouped_matmul(
            x=[expand_x],
            weight=[self.gmm1_weight],
            split_item=3,
            group_list_type=1,  # default is 0, meaning prefix-sum form
            group_type=0,  # 0 = grouping along the m axis
            group_list=expert_token_nums,
            output_dtype=torch.int32)[0]
        # Dequantize, apply SwiGLU, and re-quantize for GMM2 in one op.
        y1, y1_scale = torch_npu.npu_dequant_swiglu_quant(
            x=y1_int32,
            weight_scale=self.gmm1_weight_scale.to(torch.float32),
            activation_scale=dynamic_scales,
            bias=None,
            quant_scale=None,
            quant_offset=None,
            group_index=expert_token_nums,
            activate_left=True,
            quant_mode=1,
        )
        # GMM2 (down projection), dequantized back to the input dtype.
        y2 = torch_npu.npu_grouped_matmul(x=[y1],
                                          weight=[self.gmm2_weight],
                                          scale=[self.gmm2_weight_scale],
                                          per_token_scale=[y1_scale],
                                          split_item=2,
                                          group_list_type=1,
                                          group_type=0,
                                          group_list=expert_token_nums,
                                          output_dtype=output_dtype)[0]
        combine_output = torch_npu.npu_moe_distribute_combine_v2(
            expand_x=y2,
            expert_ids=expert_ids,
            assist_info_for_combine=assist_info_for_combine,
            ep_send_counts=ep_send_counts,
            expert_scales=expert_scales,
            x_active_mask=x_active_mask,
            group_ep=self.ep_hcomm_info,
            ep_world_size=self.ep_world_size,
            ep_rank_id=self.global_rank_id,
            moe_expert_num=self.moe_expert_num,
            tp_send_counts=tp_send_counts,
            expand_scales=expand_scales,
            group_tp=self.tp_hcomm_info,
            tp_world_size=1,
            tp_rank_id=0,
            expert_shard_type=0,
            shared_expert_num=1,
            shared_expert_rank_num=self.shared_expert_rank_num,
            global_bs=self.batch_size * self.ep_world_size)
        return (combine_output, ep_send_counts[:self.ep_recv_count_size])
|
||||
|
||||
|
||||
class FusionOp(DecodeMoeOps):
    """Fused path: the whole dispatch -> grouped MLP -> combine pipeline is
    executed by a single custom kernel
    (torch.ops._C_ascend.dispatch_gmm_combine_decode)."""

    def __init__(self,
                 gmm1_weight,
                 gmm1_weight_scale,
                 gmm2_weight,
                 gmm2_weight_scale,
                 ep_hcomm_info,
                 batch_size,
                 token_hidden_size,
                 moe_intermediate_size,
                 ep_world_size,
                 moe_expert_num,
                 global_rank_id,
                 shared_expert_rank_num=0):
        super().__init__(gmm1_weight, gmm1_weight_scale, gmm2_weight,
                         gmm2_weight_scale, ep_hcomm_info, batch_size,
                         token_hidden_size, moe_intermediate_size,
                         ep_world_size, moe_expert_num, global_rank_id,
                         shared_expert_rank_num)

    def _apply_ops(self, x, expert_ids, smooth_scales, expert_scales,
                   x_active_mask):
        """Invoke the fused kernel; the fp32 scale copies prepared by the base
        class are passed for the two grouped matmuls."""
        output = torch.ops._C_ascend.dispatch_gmm_combine_decode(
            x=x,
            expert_ids=expert_ids,
            gmm1_permuted_weight=self.gmm1_weight,
            gmm1_permuted_weight_scale=self.gmm1_weight_scale_fp32,
            gmm2_weight=self.gmm2_weight,
            gmm2_weight_scale=self.gmm2_weight_scale_fp32,
            expert_scales=expert_scales,
            expert_smooth_scales=smooth_scales,
            x_active_mask=x_active_mask,
            group_ep=self.ep_hcomm_info,
            ep_rank_size=self.ep_world_size,
            ep_rank_id=self.global_rank_id,
            moe_expert_num=self.moe_expert_num,
            shared_expert_num=1,
            shared_expert_rank_num=self.shared_expert_rank_num,
            quant_mode=0,
            global_bs=self.batch_size * self.ep_world_size)
        return output
|
||||
|
||||
|
||||
def generate_datas(batch_size,
                   token_hidden_size,
                   moe_intermediate_size,
                   ep_world_size,
                   moe_expert_num,
                   global_rank_id,
                   shared_expert_rank_num=0,
                   top_k=8,
                   test_bfloat16=True,
                   enable_dynamic_bs=False,
                   with_mc2_mask=False):
    """Generate random per-rank inputs and quantized weights for the MoE
    decode test.

    Returns:
        ``(x, expert_ids, smooth_scales, expert_scales, x_active_mask)``,
        ``(gmm1_weight, gmm1_weight_scale, gmm2_weight, gmm2_weight_scale)``,
        ``actual_bs`` (number of tokens generated) and ``valid_token_num``
        (tokens marked active in ``x_active_mask``; equals ``actual_bs`` when
        the mask is disabled).
    """
    is_shared_expert = global_rank_id < shared_expert_rank_num
    moe_expert_num_per_rank = moe_expert_num // (ep_world_size -
                                                 shared_expert_rank_num)
    # With a dynamic batch, draw in [2, batch_size) when the mc2 mask is used
    # (so at least one token can be masked out), else [1, batch_size).
    actual_bs = int(
        torch.randint(2 if with_mc2_mask else 1, batch_size, [1]).item(
        ) if enable_dynamic_bs else batch_size)
    local_expert_num = 1 if is_shared_expert else moe_expert_num_per_rank
    gmm1_input_dim = token_hidden_size
    gmm1_output_dim = moe_intermediate_size * 2  # gate + up projections
    gmm2_input_dim = moe_intermediate_size
    gmm2_output_dim = token_hidden_size
    x = torch.rand([actual_bs, token_hidden_size]) * 10 - 5
    # Deterministic rank-dependent routing: consecutive ids modulo the global
    # expert count.
    expert_ids = torch.arange(
        global_rank_id * batch_size * top_k,
        global_rank_id * batch_size * top_k + actual_bs * top_k).to(
            torch.int32).view(actual_bs, top_k)
    expert_ids = expert_ids % moe_expert_num
    if is_shared_expert:
        # Fixed alternating +/-4 int8 weights and constant scales so every
        # path sees identical shared-expert parameters.
        gmm1_weight = torch.ones([
            local_expert_num, gmm1_input_dim, gmm1_output_dim
        ]).to(torch.int8) * 4
        gmm2_weight = torch.ones([
            local_expert_num, gmm2_input_dim, gmm2_output_dim
        ]).to(torch.int8) * 4
        gmm1_weight[:, :, ::2] = gmm1_weight[:, :, ::2] * -1
        gmm2_weight[:, :, ::2] = gmm2_weight[:, :, ::2] * -1
        gmm1_weight_scale = torch.ones([local_expert_num, gmm1_output_dim
                                        ]) * 0.0015
        gmm2_weight_scale = torch.ones([local_expert_num, gmm2_output_dim
                                        ]) * 0.0015
    else:
        gmm1_weight = torch.randint(
            -16, 16,
            [local_expert_num, gmm1_input_dim, gmm1_output_dim]).to(torch.int8)
        gmm2_weight = torch.randint(
            -16, 16,
            [local_expert_num, gmm2_input_dim, gmm2_output_dim]).to(torch.int8)
        gmm1_weight_scale = torch.rand([local_expert_num, gmm1_output_dim
                                        ]) * 0.003 + 0.0015
        gmm2_weight_scale = torch.rand([local_expert_num, gmm2_output_dim
                                        ]) * 0.003 + 0.0015
    expert_scales = torch.rand(actual_bs, top_k)
    if test_bfloat16:
        x = x.bfloat16()
        gmm1_weight_scale = gmm1_weight_scale.bfloat16()
        gmm2_weight_scale = gmm2_weight_scale.bfloat16()
    else:
        x = x.half()
    # Fix: this local was misspelled "smooth_sales" in the original.
    smooth_scales = None
    x_active_mask = None
    valid_token_num = actual_bs
    if with_mc2_mask:
        # Mark a random prefix of tokens as active; requires actual_bs >= 2.
        valid_token_num = int(torch.randint(1, actual_bs, [1]).item())
        x_active_mask = torch.cat(
            (torch.ones(valid_token_num),
             torch.zeros(actual_bs - valid_token_num))).bool()
    return (x, expert_ids, smooth_scales, expert_scales, x_active_mask), \
        (gmm1_weight, gmm1_weight_scale, gmm2_weight, gmm2_weight_scale), \
        actual_bs, valid_token_num
|
||||
|
||||
|
||||
def run_once(local_rank_id,
             batch_size,
             token_hidden_size,
             moe_intermediate_size,
             ep_world_size,
             moe_expert_num,
             shared_expert_rank_num=0,
             top_k=8,
             test_bfloat16=True,
             enable_dynamic_bs=False,
             test_graph=False,
             with_mc2_mask=False):
    # Worker entry (spawned once per rank): sets up HCCL, builds inputs, runs
    # both the small-ops reference path and the fused kernel, and compares
    # their token outputs and per-expert receive counts.
    log_file = redirect_output(f"local_rank_{local_rank_id}.log"
                               ) if output_to_file(local_rank_id) else None
    global_rank_id = local_rank_id  # single node
    device_id = local_rank_id % 16
    torch_npu.npu.set_device(device_id)

    # Initialize the distributed environment.
    os.environ["MASTER_ADDR"] = "127.0.0.1"
    os.environ["MASTER_PORT"] = "29500"  # any free port
    dist.init_process_group(backend="hccl",
                            rank=local_rank_id,
                            world_size=ep_world_size)
    ep_ranks_list = list(np.arange(0, ep_world_size))
    # Two groups over the same ranks so each path gets its own communicator.
    ep_group = dist.new_group(backend="hccl", ranks=ep_ranks_list)
    ep_group_small = dist.new_group(backend="hccl", ranks=ep_ranks_list)

    ep_hcomm_info_fused = ep_group._get_backend(
        torch.device("npu")).get_hccl_comm_name(local_rank_id)
    ep_hcomm_info_small = ep_group_small._get_backend(
        torch.device("npu")).get_hccl_comm_name(local_rank_id)
    torch_npu.npu.synchronize(device_id)

    parameter = (batch_size, token_hidden_size, moe_intermediate_size,
                 ep_world_size, moe_expert_num, global_rank_id,
                 shared_expert_rank_num)
    input_datas, weight_datas, actual_bs, valid_token_num = generate_datas(
        *parameter, top_k, test_bfloat16, enable_dynamic_bs, with_mc2_mask)
    # Move everything to the NPU, preserving None placeholders.
    input_datas = [
        data.npu() if data is not None else None for data in input_datas
    ]
    weight_datas = [
        data.npu() if data is not None else None for data in weight_datas
    ]
    small_ops = SmallOps(*weight_datas, ep_hcomm_info_small,
                         *parameter).npu()  # type: ignore
    fused_ops = FusionOp(*weight_datas, ep_hcomm_info_fused,
                         *parameter).npu()  # type: ignore
    if test_graph:
        # Optionally run the fused path through torchair's graph compiler.
        config = torchair.CompilerConfig()
        config.mode = "reduce-overhead"
        npu_backend = torchair.get_npu_backend(compiler_config=config)
        fused_ops = torch.compile(fused_ops, backend=npu_backend)
    small_op_token_output, small_op_count_output = small_ops(*input_datas)
    fused_op_token_output, fused_op_count_output = fused_ops(*input_datas)
    torch_npu.npu.synchronize(device_id)
    dist.destroy_process_group()
    if log_file is not None:
        log_file.close()
    # NOTE(review): the reference path's counts are converted from an
    # inclusive prefix sum here even though dispatch was called with
    # expert_token_nums_type=1 — presumably ep_send_counts is cumulative;
    # confirm against the npu_moe_distribute_dispatch_v2 documentation.
    small_op_count_output = from_inclusive_prefix_sum(small_op_count_output)
    # Only the tokens flagged valid by x_active_mask are compared; loose
    # tolerances account for the int8-quantized pipeline.
    torch.testing.assert_close(small_op_token_output[0:valid_token_num].cpu(),
                               fused_op_token_output[0:valid_token_num].cpu(),
                               atol=2.0,
                               rtol=0.02)
    torch.testing.assert_close(small_op_count_output.cpu(),
                               fused_op_count_output.cpu())
    gc.collect()
    torch.npu.empty_cache()
    torch.npu.reset_peak_memory_stats()
|
||||
|
||||
|
||||
@torch.inference_mode()
def test_dispatch_gmm_combine_decode_base():
    """Run the decode dispatch/GMM/combine comparison with the default config."""
    kwargs = BASE_KWARGS
    world_size = kwargs["ep_world_size"]
    # run_once's positional parameters (after the spawned rank id) match the
    # insertion order of BASE_KWARGS, so the values tuple is passed as-is.
    mp.spawn(run_once,
             args=tuple(kwargs.values()),
             nprocs=world_size,
             join=True)
|
||||
|
||||
|
||||
def test_dispatch_gmm_combine_decode_with_mc2_mask():
    """Same as the base test but with the mc2 token-active mask enabled."""
    # Fix: copy the shared template — the original wrote through BASE_KWARGS
    # itself, leaking with_mc2_mask=True into any test that runs afterwards.
    custom_kwargs = dict(BASE_KWARGS)
    custom_kwargs["with_mc2_mask"] = True
    ep_world_size = custom_kwargs["ep_world_size"]
    # Positional order matches run_once's signature (see BASE_KWARGS note).
    custom_args = tuple(custom_kwargs.values())
    mp.spawn(run_once, args=custom_args, nprocs=ep_world_size, join=True)
|
||||
@@ -0,0 +1,141 @@
|
||||
import random
|
||||
import unittest
|
||||
|
||||
import torch
|
||||
|
||||
from vllm_ascend.utils import enable_custom_op
|
||||
|
||||
enable_custom_op()
|
||||
|
||||
torch.set_printoptions(threshold=float("inf"))
|
||||
|
||||
|
||||
class TestMatrixMultiplication(unittest.TestCase):
    """Checks torch.ops._C_ascend.batch_matmul_transpose against a torch.bmm
    golden reference over boundary, random, and all-zero shapes."""

    def compute_golden(self, a, b, res1, m, n):
        """Compute reference result (golden) into res1 via transposed bmm."""
        torch.bmm(a.transpose(0, 1),
                  b,
                  out=res1.view(-1, m, n).transpose(0, 1))

    def assert_tensors_almost_equal(self, actual, expected, dtype):
        """Check two tensors are approximately equal (shape, NaN/Inf,
        dtype-dependent tolerance on max absolute/relative difference)."""
        self.assertEqual(actual.shape, expected.shape, "Shape mismatch")

        # Check for NaN
        self.assertFalse(
            torch.isnan(actual).any(), "Actual result contains NaN")
        self.assertFalse(
            torch.isnan(expected).any(), "Expected result contains NaN")

        # Check for Inf
        self.assertFalse(
            torch.isinf(actual).any(), "Actual result contains Inf")
        self.assertFalse(
            torch.isinf(expected).any(), "Expected result contains Inf")

        # Set different tolerances based on data type
        if dtype == torch.float16:
            rtol, atol = 1e-5, 1e-5
        else:  # bfloat16
            rtol, atol = 1.5e-5, 1.5e-5

        # Compare values
        diff = torch.abs(actual - expected)
        max_diff = diff.max().item()
        max_expected = torch.abs(expected).max().item()

        # Check relative and absolute errors
        if max_expected > 0:
            relative_diff = max_diff / max_expected
            self.assertLessEqual(
                relative_diff,
                rtol,
                f"Relative error too large: {relative_diff} > {rtol}. Max difference: {max_diff}",
            )

        self.assertLessEqual(max_diff, atol,
                             f"Absolute error too large: {max_diff} > {atol}")

    def _run_case(self, b, m, k, n, dtype, make=torch.randn):
        """Run one (b, m, k, n) case: build inputs with ``make``, compare the
        custom op against the golden result, and return the op's output.

        Extracted because the three tests below repeated this identical
        setup/run/compare sequence verbatim.
        """
        a = make(b, m, k, dtype=dtype, device="npu")
        b_tensor = make(m, k, n, dtype=dtype, device="npu")
        res1 = torch.empty((b, m * n), dtype=dtype, device="npu")
        res2 = torch.empty((b, m, n), dtype=dtype, device="npu")

        self.compute_golden(a, b_tensor, res1, m, n)
        torch.ops._C_ascend.batch_matmul_transpose(a, b_tensor, res2)

        self.assert_tensors_almost_equal(res1.view(-1, m, n), res2, dtype)
        return res2

    def test_boundary_conditions(self):
        """Test boundary conditions"""
        test_cases = [
            # (b, m, k, n)
            (1, 1, 1, 1),  # Minimum size
            (1, 10, 1, 1),  # b=1
            (10, 1, 1, 10),  # m=1
            (5, 5, 1, 5),  # k=1
            (2, 2, 2, 1),  # n=1
            (100, 1, 1, 100),  # Flat case
            (1, 100, 100, 1),  # Flat case
            (2, 3, 4, 5),  # Random small size
            (10, 20, 30, 40),  # Medium size
            (36, 128, 512, 128),  # target case
            (8, 160, 512, 128),
        ]

        dtypes = [torch.float16, torch.bfloat16]

        for dtype in dtypes:
            for b, m, k, n in test_cases:
                with self.subTest(dtype=dtype, shape=f"({b}, {m}, {k}, {n})"):
                    self._run_case(b, m, k, n, dtype)

    def test_random_shapes(self):
        """Test randomly generated shapes"""
        num_tests = 1
        dtypes = [torch.float16, torch.bfloat16]

        for dtype in dtypes:
            for _ in range(num_tests):
                # Generate reasonable random sizes
                b = random.randint(1, 500)
                m = random.randint(1, 500)
                k = random.randint(1, 500)
                n = random.randint(1, 500)

                with self.subTest(dtype=dtype,
                                  shape=f"Random ({b}, {m}, {k}, {n})"):
                    self._run_case(b, m, k, n, dtype)

    def test_zero_values(self):
        """Test zero input values"""
        dtypes = [torch.float16, torch.bfloat16]
        b, m, k, n = 5, 4, 3, 2

        for dtype in dtypes:
            with self.subTest(dtype=dtype):
                res2 = self._run_case(b, m, k, n, dtype, make=torch.zeros)
                # Zero inputs must produce an exactly-zero output.
                self.assertTrue(torch.all(res2 == 0))
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
unittest.main(verbosity=2)
|
||||
@@ -0,0 +1,46 @@
|
||||
import gc
|
||||
|
||||
import torch
|
||||
|
||||
from vllm_ascend.utils import enable_custom_op
|
||||
|
||||
enable_custom_op()
|
||||
|
||||
DEFAULT_ATOL = 1e-3
|
||||
DEFAULT_RTOL = 1e-3
|
||||
|
||||
|
||||
def bgmv_expand_cpu_impl(x: torch.Tensor, w: torch.Tensor,
                         indices: torch.Tensor, y: torch.Tensor,
                         slice_offset: int, slice_size: int) -> torch.Tensor:
    """CPU reference for bgmv_expand: accumulate x @ w[indices].T into a
    column slice of y.

    y[:, slice_offset:slice_offset + slice_size] += x @ w[indices].T, with the
    matmul carried out in float32 for accuracy. Returns the mutated y.

    Fix: the ``y`` annotation was ``torch.tensor`` (a function), now
    ``torch.Tensor`` (the type).
    """
    # Gather each token's LoRA matrix and transpose to (in, out) for bmm.
    W = w[indices, :, :].transpose(-1, -2).to(torch.float32)
    z = torch.bmm(x.unsqueeze(1).to(torch.float32), W).squeeze()
    y[:, slice_offset:slice_offset + slice_size] += z
    return y
|
||||
|
||||
|
||||
@torch.inference_mode()
def test_bgmv_expand():
    """Compare the NPU bgmv_expand custom op against the CPU reference."""
    B = 1  # number of tokens
    x = torch.randn([B, 16], dtype=torch.float)
    w = torch.randn([64, 128, 16], dtype=torch.float16)  # (loras, out, in)
    indices = torch.zeros([B], dtype=torch.int64)  # every token uses lora 0
    y = torch.randn([B, 128 * 3], dtype=torch.float16)

    x_npu = x.npu()
    w_npu = w.npu()
    indices_npu = indices.npu()
    y_npu = y.npu()

    # CPU reference mutates (and returns) y; both paths write slice [0, 128).
    y_out = bgmv_expand_cpu_impl(x, w, indices, y, 0, 128)
    y_out_npu = torch.ops._C_ascend.bgmv_expand(x_npu, w_npu, indices_npu,
                                                y_npu, 0, 128)

    # Compare the results.
    torch.testing.assert_close(y_out_npu.cpu(),
                               y_out,
                               atol=DEFAULT_ATOL,
                               rtol=DEFAULT_RTOL)
    gc.collect()
    torch.npu.empty_cache()
    torch.npu.reset_peak_memory_stats()
|
||||
@@ -0,0 +1,45 @@
|
||||
import gc
|
||||
|
||||
import torch
|
||||
|
||||
from vllm_ascend.utils import enable_custom_op
|
||||
|
||||
enable_custom_op()
|
||||
|
||||
DEFAULT_ATOL = 1e-3
|
||||
DEFAULT_RTOL = 1e-3
|
||||
|
||||
|
||||
def bgmv_shrink_cpu_impl(x: torch.Tensor, w: torch.Tensor,
                         indices: torch.Tensor, y: torch.Tensor,
                         scaling: float) -> torch.Tensor:
    """CPU reference for bgmv_shrink: y += scaling * (x @ w[indices].T).

    The matmul is carried out in float32 for accuracy; returns the mutated y.

    Fix: the ``y`` annotation was ``torch.tensor`` (a function), now
    ``torch.Tensor`` (the type).
    """
    # Gather each token's LoRA matrix and transpose to (in, out) for bmm.
    W = w[indices, :, :].transpose(-1, -2).to(torch.float32)
    z = torch.bmm(x.unsqueeze(1).to(torch.float32), W).squeeze()
    y[:, :] += z * scaling
    return y
|
||||
|
||||
|
||||
@torch.inference_mode()
def test_bgmv_shrink():
    """Compare the NPU bgmv_shrink custom op against the CPU reference."""
    B = 1  # number of tokens
    x = torch.randn([B, 128], dtype=torch.float16)
    w = torch.randn([64, 16, 128], dtype=torch.float16)  # (loras, out, in)
    indices = torch.zeros([B], dtype=torch.int64)  # every token uses lora 0
    y = torch.zeros([B, 16])  # NOTE: float32 accumulator, unlike the fp16 inputs

    x_npu = x.npu()
    w_npu = w.npu()
    indices_npu = indices.npu()
    y_npu = y.npu()

    # CPU reference returns the mutated y; the op accumulates into y_npu in place.
    y = bgmv_shrink_cpu_impl(x, w, indices, y, 0.5)
    torch.ops._C_ascend.bgmv_shrink(x_npu, w_npu, indices_npu, y_npu, 0.5)

    # Compare the results.
    torch.testing.assert_close(y_npu.cpu(),
                               y,
                               atol=DEFAULT_ATOL,
                               rtol=DEFAULT_RTOL)
    gc.collect()
    torch.npu.empty_cache()
    torch.npu.reset_peak_memory_stats()
|
||||
@@ -0,0 +1,168 @@
|
||||
import random
|
||||
|
||||
import torch
|
||||
import torch.distributed as dist
|
||||
import torch.multiprocessing as mp
|
||||
import torch_npu
|
||||
from torch.distributed.distributed_c10d import _get_default_group
|
||||
|
||||
from vllm_ascend.utils import enable_custom_op
|
||||
|
||||
enable_custom_op()
|
||||
|
||||
|
||||
class TestDisptachFFNCombine:
    """Per-process driver that exercises the dispatch_ffn_combine custom op."""

    def __init__(self, rank, world_size, port):
        self.rank = rank
        self.world_size = world_size
        self.master_ip = "127.0.0.1"
        self.port = port

    def get_hcomm(self, comm_group):
        """Return the HCCL communicator name of ``comm_group`` for this rank."""
        hcomm_info = None
        if torch.__version__ > "2.0.1":
            hcomm_info = comm_group._get_backend(
                torch.device("npu")).get_hccl_comm_name(self.rank)
        else:
            # Older torch exposed the comm name directly on the group object.
            hcomm_info = comm_group.get_hccl_comm_name(self.rank)
        return hcomm_info

    def setup_ep_tp(
        self,
        rank,
        tp_size,
        ep_size,
        backend_type,
        ep_ranks_list=None,
        tp_ranks_list=None,
    ):
        # Build EP groups (one per TP slice) and TP groups (one per EP slice).
        # Every rank must create every group, but keeps only the one it is in.
        # NOTE(review): ep_group_tmp/tp_group_tmp are unbound if this rank is
        # in none of the rank lists — callers must pass consistent lists.
        for i in range(tp_size):
            if ep_ranks_list:
                ep_ranks = ep_ranks_list[i]
            else:
                ep_ranks = [x + ep_size * i for x in range(ep_size)]
            ep_group = dist.new_group(backend=backend_type, ranks=ep_ranks)
            if rank in ep_ranks:
                ep_group_tmp = ep_group
        for i in range(ep_size):
            if tp_ranks_list:
                tp_ranks = tp_ranks_list[i]
            else:
                tp_ranks = [x * ep_size + i for x in range(tp_size)]
            tp_group = dist.new_group(backend=backend_type, ranks=tp_ranks)
            if rank in tp_ranks:
                tp_group_tmp = tp_group
        return ep_group_tmp, tp_group_tmp

    def generate_hcom(self):
        """Initialize the process group and record the HCCL comm name to use."""
        torch_npu.npu.set_device(self.rank)
        dist.init_process_group(
            backend="hccl",
            rank=self.rank,
            world_size=self.world_size,
            init_method=f"tcp://127.0.0.1:{self.port}",
        )

        # ep_size is fixed to 0 here, so the default process group branch is
        # always taken below; the EP/TP split is kept for future configs.
        ep_size = 0
        tp_size = self.world_size
        hcomm_info_dist = {
            "default_pg_info": None,
            "ep_hcomm_info": None,
            "group_ep": None,
            "tp_hcomm_info": None,
            "group_tp": None,
        }
        if ep_size and tp_size:
            group_ep, group_tp = self.setup_ep_tp(self.rank, tp_size, ep_size,
                                                  "hccl", None, None)
            hcomm_info_dist["ep_hcomm_info"] = self.get_hcomm(group_ep)
            hcomm_info_dist["tp_hcomm_info"] = self.get_hcomm(group_tp)
            hcomm_info_dist["group_ep"] = group_ep
            hcomm_info_dist["group_tp"] = group_tp
        else:
            if dist.is_available():
                default_pg = _get_default_group()
                hcomm_info_dist["default_pg_info"] = self.get_hcomm(default_pg)
                hcomm_info = hcomm_info_dist["default_pg_info"]
        # NOTE(review): hcomm_info is only bound on the default-group path;
        # the ep_size-and-tp_size branch would leave it undefined.
        self.hcomm_info = hcomm_info

    def run_npu_out(self) -> bool:
        """Build random quantized MoE inputs and invoke dispatch_ffn_combine.

        Returns True when the op call completes without raising.
        """
        torch_npu.npu.set_device(self.rank)
        m = 2  # token-num 32
        k = 4  # hidden_size 7168
        n = 4  # mid-hidden-size 4096
        topk = 2
        e = 2  # expert-num-per-rank 16
        k2 = n // 2
        n2 = k

        torch_npu.npu.config.allow_internal_format = True
        x = self.generate_random_tensor((m, k), dtype=torch.bfloat16).npu()
        weight1 = self.generate_random_tensor((e, k, n),
                                              dtype=torch.int8).npu()
        # 29 is presumably the FRACTAL_NZ ACL format id — confirm against the
        # torch_npu.npu_format_cast documentation.
        weight1 = torch_npu.npu_format_cast(weight1, 29)
        weight2 = self.generate_random_tensor((e, k2, n2),
                                              dtype=torch.int8).npu()
        weight2 = torch_npu.npu_format_cast(weight2, 29)

        expert_idx = torch.randint(0,
                                   self.world_size * e, (m, topk),
                                   dtype=torch.int32).npu()
        scale1 = torch.randint(0, 1, (e, n), dtype=torch.int64).npu()
        scale2 = torch.randint(0, 1, (e, n2), dtype=torch.int64).npu()
        probs = torch.randn(size=(m, topk), dtype=torch.float32).npu()
        out = self.generate_random_tensor((m, k), dtype=torch.bfloat16).npu()

        torch.ops._C_ascend.dispatch_ffn_combine(
            x=x,
            weight1=weight1,
            weight2=weight2,
            expert_idx=expert_idx,
            scale1=scale1,
            scale2=scale2,
            probs=probs,
            group=self.hcomm_info,
            max_output_size=512,
            out=out,
        )
        return True

    def generate_random_tensor(self, size, dtype):
        """Random tensor of ``size`` with a dtype-appropriate distribution."""
        if dtype in [torch.float16, torch.bfloat16, torch.float32]:
            return torch.randn(size=size, dtype=dtype)
        elif dtype is torch.int8:
            return torch.randint(-16, 16, size=size, dtype=dtype)
        elif dtype is torch.int32:
            return torch.randint(-1024, 1024, size=size, dtype=dtype)
        else:
            raise ValueError(f"Invalid dtype: {dtype}")
|
||||
|
||||
|
||||
def worker(rank: int, world_size: int, port: int, q: mp.SimpleQueue):
    """Per-process entry: set up comms, run the fused op, report success."""
    runner = TestDisptachFFNCombine(rank, world_size, port)
    runner.generate_hcom()
    q.put(runner.run_npu_out())
|
||||
|
||||
|
||||
@torch.inference_mode()
def test_dispatch_ffn_combine_kernel():
    """Spawn two ranks, run dispatch_ffn_combine on each, expect all to pass."""
    world_size = 2
    mp.set_start_method("fork", force=True)

    queue = mp.SimpleQueue()
    procs = []
    # Randomize the rendezvous port to avoid collisions between test runs.
    port = 29501 + random.randint(0, 10000)

    for rank in range(world_size):
        proc = mp.Process(target=worker, args=(rank, world_size, port, queue))
        proc.start()
        procs.append(proc)

    # Drain results before joining so children never block on the queue.
    results = [queue.get() for _ in range(world_size)]

    for proc in procs:
        proc.join()

    assert all(results)
|
||||
@@ -0,0 +1,338 @@
|
||||
# Copyright (c) 2025 Huawei Technologies Co., Ltd. All Rights Reserved.
|
||||
# Copyright 2023 The vLLM team.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# This file is a part of the vllm-ascend project.
|
||||
# Adapted from vllm/tests/kernels/test_moe.py
|
||||
"""Tests for the MOE layers.
|
||||
|
||||
Run `pytest tests/ops/test_fused_moe.py`.
|
||||
"""
|
||||
|
||||
import gc
|
||||
from unittest.mock import MagicMock, patch
|
||||
|
||||
import pytest
|
||||
import torch
|
||||
import torch_npu
|
||||
from vllm.model_executor.layers.activation import SiluAndMul
|
||||
|
||||
from vllm_ascend.ops.fused_moe.experts_selector import (
|
||||
check_npu_moe_gating_top_k, select_experts)
|
||||
from vllm_ascend.ops.fused_moe.moe_mlp import unified_apply_mlp
|
||||
from vllm_ascend.ops.fused_moe.token_dispatcher import \
|
||||
TokenDispatcherWithAllGather
|
||||
|
||||
NUM_EXPERTS = [8, 64]
|
||||
EP_SIZE = [1]
|
||||
TOP_KS = [2, 6]
|
||||
DEVICE = ["npu"]
|
||||
|
||||
|
||||
def apply_mlp(
    hidden_states: torch.Tensor,
    w1: torch.Tensor,
    w2: torch.Tensor,
    group_list: torch.Tensor,
    group_list_type: int = 1,
) -> torch.Tensor:
    """Grouped expert MLP: grouped matmul -> SwiGLU -> grouped matmul.

    Args:
        hidden_states: tokens already sorted by expert.
        w1: up-projection weights, one (2n, k) matrix per expert.
        w2: down-projection weights, one (k, n) matrix per expert.
        group_list: per-expert token counts/offsets for the grouped matmul.
        group_list_type: interpretation of ``group_list`` expected by
            ``npu_grouped_matmul``.

    Returns:
        The MLP output for every token, same ordering as the input.
    """

    def _grouped_mm(x: torch.Tensor, w: torch.Tensor) -> torch.Tensor:
        # npu_grouped_matmul takes lists of operands; split_item=2 yields a
        # single concatenated output, group_type=0 groups along the M axis.
        return torch_npu.npu_grouped_matmul(
            x=[x],
            weight=[w.transpose(1, 2)],
            split_item=2,
            group_list_type=group_list_type,
            group_type=0,
            group_list=group_list,
        )[0]

    up_proj = _grouped_mm(hidden_states, w1)
    activated = torch_npu.npu_swiglu(up_proj)
    return _grouped_mm(activated, w2)
|
||||
|
||||
|
||||
def torch_moe(a, w1, w2, topk_weights, topk_ids, topk, expert_map):
    """Eager reference MoE: route each token copy through its expert's MLP
    and combine the expert outputs weighted by the router probabilities."""
    num_tokens, hidden = a.shape
    out_dim = w2.shape[1]
    # Replicate every token once per selected expert, then flatten.
    flat_in = a.view(num_tokens, -1, hidden).repeat(1, topk,
                                                    1).reshape(-1, hidden)
    flat_out = torch.zeros(num_tokens * topk,
                           out_dim,
                           dtype=a.dtype,
                           device=a.device)
    weights_flat = topk_weights.view(-1)
    ids_flat = topk_ids.view(-1)
    if expert_map is not None:
        # Translate global expert ids to local expert ids.
        ids_flat = expert_map[ids_flat]
    for expert in range(w1.shape[0]):
        selected = ids_flat == expert
        if selected.sum():
            gate_up = flat_in[selected] @ w1[expert].transpose(0, 1)
            flat_out[selected] = SiluAndMul()(gate_up) @ w2[expert].transpose(
                0, 1)
    # Weight each expert contribution by its router probability and sum.
    weighted = flat_out.view(num_tokens, -1, out_dim) * weights_flat.view(
        num_tokens, -1, 1).to(flat_out.dtype)
    return weighted.sum(dim=1)
|
||||
|
||||
|
||||
@pytest.mark.parametrize("m", [1, 33, 64, 222, 1024 * 128])
@pytest.mark.parametrize("n", [128, 1024, 2048])
@pytest.mark.parametrize("k", [128, 511, 1024])
@pytest.mark.parametrize("e", NUM_EXPERTS)
@pytest.mark.parametrize("topk", TOP_KS)
@pytest.mark.parametrize("ep_size", EP_SIZE)
@pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16])
@pytest.mark.parametrize("device", DEVICE)
def test_token_dispatcher_with_all_gather(
    m: int,
    n: int,
    k: int,
    e: int,
    topk: int,
    ep_size: int,
    dtype: torch.dtype,
    device: str,
):
    """End-to-end check of TokenDispatcherWithAllGather (unquantized path):
    dispatch -> grouped MLP -> combine must match the eager torch_moe
    reference within tolerance.
    """
    # Scale activations/weights down to keep fp16/bf16 values well-behaved.
    a = torch.randn((m, k), device=device, dtype=dtype) / 10
    w1 = torch.randn((e, 2 * n, k), device=device, dtype=dtype) / 10
    w2 = torch.randn((e, k, n), device=device, dtype=dtype) / 10

    score = torch.randn((m, e), device=device, dtype=dtype)
    # Single-rank setup: all experts are local and no id remapping is needed.
    expert_map = None
    local_e = e
    w1_local = w1
    w2_local = w2

    # Router: softmax over expert logits, then top-k selection.
    score = torch.softmax(score, dim=-1, dtype=dtype)
    topk_weights, topk_ids = torch.topk(score, topk)
    topk_ids = topk_ids.to(torch.int32)

    dispatcher_kwargs = {
        "num_experts": e,
        "top_k": topk,
        "num_local_experts": local_e,
    }
    dispatcher = TokenDispatcherWithAllGather(**dispatcher_kwargs)

    apply_router_weight_on_input = False
    dispatch_output = dispatcher.token_dispatch(
        hidden_states=a,
        topk_weights=topk_weights,
        topk_ids=topk_ids,
        expert_map=expert_map,
        apply_router_weight_on_input=apply_router_weight_on_input)

    # Tokens sorted by expert plus grouped-matmul metadata.
    sorted_hidden_states = dispatch_output["hidden_states"]
    group_list = dispatch_output["group_list"]
    group_list_type = dispatch_output.get("group_list_type", 1)
    context_metadata = dispatch_output["context_metadata"]

    expert_output = apply_mlp(hidden_states=sorted_hidden_states,
                              w1=w1_local,
                              w2=w2_local,
                              group_list=group_list,
                              group_list_type=group_list_type)

    # Un-permute and weight-combine the expert outputs back to token order.
    combined_output = dispatcher.token_combine(
        hidden_states=expert_output,
        context_metadata=context_metadata,
        bias=None)

    torch_output = torch_moe(a, w1, w2, topk_weights, topk_ids, topk,
                             expert_map)

    # Loose tolerances: fp16/bf16 grouped matmul vs. eager reference.
    torch.testing.assert_close(combined_output,
                               torch_output,
                               atol=4e-2,
                               rtol=1)
    gc.collect()
    torch.npu.empty_cache()
    torch.npu.reset_peak_memory_stats()
|
||||
|
||||
|
||||
@pytest.mark.parametrize("m", [1, 33, 64])
@pytest.mark.parametrize("n", [128, 1024, 2048])
@pytest.mark.parametrize("k", [128, 511, 1024])
@pytest.mark.parametrize("e", NUM_EXPERTS)
@pytest.mark.parametrize("topk", TOP_KS)
@pytest.mark.parametrize("ep_size", EP_SIZE)
@pytest.mark.parametrize("dtype", [torch.bfloat16])
@pytest.mark.parametrize("device", DEVICE)
def test_token_dispatcher_with_all_gather_quant(
    m: int,
    n: int,
    k: int,
    e: int,
    topk: int,
    ep_size: int,
    dtype: torch.dtype,
    device: str,
):
    """Quantized (w8a8) variant of the all-gather dispatcher test.

    Only shape correctness of the combined output is asserted; there is no
    numeric golden for the quantized path here.
    """
    # unified_apply_mlp reads the forward context; stub it out.
    context_mock = MagicMock()
    context_mock.fused_moe_state = 0
    with patch("vllm_ascend.ops.fused_moe.moe_mlp.get_forward_context",
               return_value=context_mock):
        a = torch.randn((m, k), device=device, dtype=dtype) / 10
        # NOTE(review): torch.randn with an int8 dtype fails on CPU;
        # presumably accepted on NPU — confirm on target hardware.
        w1 = torch.randn((e, k, 2 * n), device=device, dtype=torch.int8)
        w1_scale = torch.empty((e, 2 * n), device=device, dtype=dtype)
        w2 = torch.randn((e, n, k), device=device, dtype=torch.int8)
        w2_scale = torch.empty((e, k), device=device, dtype=dtype)

        score = torch.randn((m, e), device=device, dtype=dtype)
        # Single-rank setup: all experts local, no id remapping.
        expert_map = None
        local_e = e

        # Router: softmax then top-k.
        score = torch.softmax(score, dim=-1, dtype=dtype)
        topk_weights, topk_ids = torch.topk(score, topk)
        topk_ids = topk_ids.to(torch.int32)

        dispatcher_kwargs = {
            "num_experts": e,
            "top_k": topk,
            "num_local_experts": local_e,
        }
        dispatcher = TokenDispatcherWithAllGather(**dispatcher_kwargs)

        apply_router_weight_on_input = False
        dispatch_output = dispatcher.token_dispatch(
            hidden_states=a,
            topk_weights=topk_weights,
            topk_ids=topk_ids,
            expert_map=expert_map,
            apply_router_weight_on_input=apply_router_weight_on_input,
            with_quant=True)

        sorted_hidden_states = dispatch_output["hidden_states"]
        group_list = dispatch_output["group_list"]
        group_list_type = dispatch_output.get("group_list_type", 1)
        # Per-token quantization scale produced during dispatch.
        dynamic_scale = dispatch_output["dynamic_scale"]
        context_metadata = dispatch_output["context_metadata"]

        expert_output = unified_apply_mlp(hidden_states=sorted_hidden_states,
                                          w1=w1,
                                          w1_scale=w1_scale,
                                          w2=w2,
                                          w2_scale=w2_scale,
                                          group_list=group_list,
                                          group_list_type=group_list_type,
                                          dynamic_scale=dynamic_scale,
                                          with_quant=True)
        combined_output = dispatcher.token_combine(
            hidden_states=expert_output,
            context_metadata=context_metadata,
            bias=None)
        # Combine must restore the original (tokens, hidden) shape.
        assert combined_output.shape == (m, k)
    gc.collect()
    torch.npu.empty_cache()
    torch.npu.reset_peak_memory_stats()
|
||||
|
||||
|
||||
@pytest.mark.parametrize("m", [1, 33, 64])
@pytest.mark.parametrize("n", [128, 1024, 2048])
@pytest.mark.parametrize("e", NUM_EXPERTS)
@pytest.mark.parametrize("topk", TOP_KS)
@pytest.mark.parametrize("scoring_func", ["softmax", "sigmoid"])
@pytest.mark.parametrize("use_grouped_topk", [True, False])
@pytest.mark.parametrize("renormalize", [True, False])
@pytest.mark.parametrize("with_e_correction", [True, False])
@pytest.mark.parametrize("custom_routing", [True, False])
@pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16])
@pytest.mark.parametrize("device", DEVICE)
def test_select_experts(
    m: int,
    n: int,
    e: int,
    topk: int,
    scoring_func: str,
    use_grouped_topk: bool,
    renormalize: bool,
    with_e_correction: bool,
    custom_routing: bool,
    dtype: torch.dtype,
    device: str,
):
    """Exercise select_experts across routing configurations and verify the
    output shapes/dtype plus which grouped-topk implementation was used."""
    # Grouped top-k parameters only apply when grouped routing is enabled.
    topk_group = 4 if use_grouped_topk else None
    num_expert_group = e // 4 if use_grouped_topk else None

    hidden_states = torch.randn(m, n, device=device, dtype=dtype)
    router_logits = torch.randn(m, e, device=device, dtype=dtype)

    e_score_correction_bias = (torch.randn(e, device=device, dtype=dtype)
                               if with_e_correction else None)

    custom_routing_function = None
    if custom_routing:
        # A stub routing function returning fixed weights/ids; when set,
        # select_experts should bypass its built-in routing.
        custom_routing_function = MagicMock()
        mock_weights = torch.randn(m, topk, device=device, dtype=dtype)
        mock_ids = torch.randint(0,
                                 e, (m, topk),
                                 device=device,
                                 dtype=torch.int32)
        custom_routing_function.return_value = (mock_weights, mock_ids)

    # Patch the native grouped-topk fallback so we can observe whether it is
    # called, and stub the weight-prefetch hook.
    with patch("vllm_ascend.ops.fused_moe.experts_selector._native_grouped_topk"
               ) as mock_native_grouped_topk, \
        patch('vllm_ascend.ops.fused_moe.experts_selector.get_weight_prefetch_method',
            return_value=MagicMock()):
        mock_native_grouped_topk.side_effect = lambda x, num_groups, k: torch.randn_like(
            x)

        topk_weights, topk_ids = select_experts(
            hidden_states=hidden_states,
            router_logits=router_logits,
            top_k=topk,
            use_grouped_topk=use_grouped_topk,
            renormalize=renormalize,
            topk_group=topk_group,
            num_expert_group=num_expert_group,
            custom_routing_function=custom_routing_function,
            scoring_func=scoring_func,
            e_score_correction_bias=e_score_correction_bias,
        )

        # The native fallback must run exactly when the fused NPU gating op
        # is not applicable but grouped top-k was requested.
        call_moe_gatingtopk = check_npu_moe_gating_top_k(
            hidden_states, topk, topk_group, num_expert_group, scoring_func,
            custom_routing_function)
        if not call_moe_gatingtopk and use_grouped_topk:
            mock_native_grouped_topk.assert_called_once()
        else:
            mock_native_grouped_topk.assert_not_called()

    assert topk_weights.shape == (m, topk)
    assert topk_ids.shape == (m, topk)
    assert topk_ids.dtype == torch.int32

    gc.collect()
    torch.npu.empty_cache()
    torch.npu.reset_peak_memory_stats()
|
||||
|
||||
|
||||
@pytest.mark.parametrize("device", DEVICE)
def test_select_experts_invalid_scoring_func(device: str):
    """An unknown scoring_func must be rejected with a ValueError."""
    kwargs = {
        "hidden_states": torch.randn(1, 128, device=device),
        "router_logits": torch.randn(1, 8, device=device),
        "top_k": 2,
        "use_grouped_topk": False,
        "renormalize": False,
        "scoring_func": "invalid",
    }
    # Stub the weight-prefetch hook so only the scoring_func check can fail.
    with patch(
            'vllm_ascend.ops.fused_moe.experts_selector.get_weight_prefetch_method',
            return_value=MagicMock()):
        with pytest.raises(ValueError,
                           match="Unsupported scoring function: invalid"):
            select_experts(**kwargs)
    gc.collect()
    torch.npu.empty_cache()
    torch.npu.reset_peak_memory_stats()
|
||||
@@ -0,0 +1,37 @@
|
||||
import pytest
|
||||
import torch
|
||||
import torch_npu
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
    'B',
    [1, 16, 64, 128, 32768],
)
@pytest.mark.parametrize(
    'D',
    [8, 16, 32, 64, 128],
)
@pytest.mark.parametrize(
    'top_k',
    [1, 2, 4, 8],
)
@pytest.mark.parametrize(
    "dtype, atol, rtol",
    [
        (torch.float16, 1e-3, 1e-3),
        (torch.bfloat16, 1e-3, 1e-3),
    ],
)
def test_quant_fpx_linear(B: int, D: int, top_k: int, dtype, atol, rtol):
    """npu_moe_gating_top_k_softmax must agree with a softmax + topk
    reference on random gating inputs.

    Bug fix: the two ``torch.allclose`` results were previously discarded,
    so the comparison never failed — they are now asserted.
    """
    x = torch.rand((B, D), dtype=dtype).to("npu")
    # No tokens are marked finished; the op scores every row.
    finished = None
    y, expert_idx, row_idx = torch_npu.npu_moe_gating_top_k_softmax(x,
                                                                    finished,
                                                                    k=top_k)

    # Eager reference: softmax over the expert axis, then top-k.
    topk_weights = x.softmax(dim=-1)
    topk_weights, topk_ids = topk_weights.topk(top_k, dim=-1)
    topk_ids = topk_ids.to(torch.int32)
    assert torch.allclose(y, topk_weights, atol=atol, rtol=rtol)
    assert torch.allclose(expert_idx, topk_ids, atol=atol, rtol=rtol)
|
||||
@@ -0,0 +1,148 @@
|
||||
import gc
|
||||
|
||||
import torch
|
||||
import torch_npu
|
||||
|
||||
from vllm_ascend.utils import enable_custom_op
|
||||
|
||||
# enable internal format
|
||||
torch_npu.npu.config.allow_internal_format = True
|
||||
# enable vllm-ascend custom ops
|
||||
enable_custom_op()
|
||||
|
||||
|
||||
def gmm_swiglu_quant(x: torch.Tensor, weight: torch.Tensor,
                     perChannelScale: torch.Tensor,
                     perTokenScale: torch.Tensor, m: int):
    """CPU golden for one expert group: int8 matmul, dequant, SwiGLU, and
    symmetric per-token int8 re-quantization.

    Args:
        x: (m, k) int8 activations.
        weight: (k, n) int8 weights.
        perChannelScale: (n,) per-output-channel dequantization scale.
        perTokenScale: (m,) per-token dequantization scale.
        m: number of tokens (rows of x).

    Returns:
        Tuple of the int8 quantized output with shape (m, n // 2) and the
        float per-token dequantization scale with shape (m,).
    """
    # Accumulate the matmul in int32, then move to float32 for scaling.
    acc = torch.matmul(x.to(torch.int32),
                       weight.to(torch.int32)).to(torch.float32)
    dequant = acc * perChannelScale * perTokenScale.reshape(m, 1)

    # SwiGLU: first half is the activation input, second half the gate.
    act_in, gate = dequant.chunk(2, dim=-1)
    activated = act_in * torch.sigmoid(act_in) * gate

    # Symmetric per-token int8 quantization: scale so the row max maps to 127.
    max_abs = torch.max(torch.abs(activated), -1).values
    inv_scale = 127 / max_abs
    quant_out = torch.round(activated * inv_scale.reshape(m, 1)).to(torch.int8)
    # Return the dequantization scale (reciprocal of the quant scale).
    return quant_out, 1 / inv_scale
|
||||
|
||||
|
||||
def process_groups(x: torch.Tensor, weight: torch.Tensor,
                   perChannelScale: torch.Tensor, perTokenScale: torch.Tensor,
                   groupList: torch.Tensor):
    """Run gmm_swiglu_quant expert-by-expert over row groups of *x*.

    Args:
        x: (M, K) int8 activations, rows already sorted by expert.
        weight: (E, K, N) per-expert int8 weights.
        perChannelScale: (E, N) per-expert channel scales.
        perTokenScale: (M,) per-token scales.
        groupList: cumulative row counts — entry i is the (inclusive) end
            row of expert i's token block.

    Returns:
        Tuple of the int8 quantized output (M, N // 2) and the float32
        per-token dequantization scale (M,).
    """
    total_rows, n_cols = x.shape[0], weight.shape[2]
    quant_out = torch.zeros(total_rows, n_cols // 2).to(torch.int8)
    scale_out = torch.zeros(total_rows).to(torch.float32)

    start = 0
    # groupList holds cumulative counts, so consecutive entries delimit
    # each expert's slice of rows.
    for expert_idx, end in enumerate(groupList.tolist()):
        rows = end - start
        if rows > 0:
            sl = slice(start, end)
            quant_out[sl], scale_out[sl] = gmm_swiglu_quant(
                x[sl], weight[expert_idx], perChannelScale[expert_idx],
                perTokenScale[sl], rows)
        start = end
    return quant_out, scale_out
|
||||
|
||||
|
||||
@torch.inference_mode()
def test_gmm_swiglu_quant_weight_nz_tensor_list():
    """Compare the fused NZ-weight grouped matmul + SwiGLU + quant kernel
    against the CPU golden (process_groups) on random int8 data."""
    M, K, E, N = 8192, 7168, 4, 4096

    # x (M, K) - int8
    x = torch.randint(-128, 127, (M, K), dtype=torch.int8)

    # weight (E, N, K) - int8
    weight = torch.randint(-128, 127, size=(E, K, N), dtype=torch.int8)

    # weight_scale (E, N) - float32; kept strictly positive.
    weight_scale = torch.rand(E, N) * 0.9 + 0.1  # uniform(0.1, 1.0)
    weight_scale = weight_scale.to(torch.float32)

    # Per-expert weights cast to the NPU-private NZ format (format id 29),
    # passed to the kernel as tensor lists.
    weight_nz_npu = []
    weight_scale_npu = []
    for i in range(E):
        weight_nz_npu.append(torch_npu.npu_format_cast(weight[i].npu(), 29))
        weight_scale_npu.append(weight_scale[i].npu())

    # x_scale (M,) - float32
    x_scale = torch.rand(M) * 0.9 + 0.1  # uniform(0.1, 1.0)
    x_scale = x_scale.to(torch.float32)

    # Cumulative token counts: 2048 tokens per expert, covering all M rows.
    group_list = torch.tensor([2048, 4096, 6144, 8192], dtype=torch.int64)

    output_cpu, output_scale_cpu = process_groups(x, weight, weight_scale,
                                                  x_scale, group_list)
    output_npu, output_scale_npu, _ = \
        torch.ops._C_ascend.grouped_matmul_swiglu_quant_weight_nz_tensor_list(x.npu(),
                                                                              weight_nz_npu,
                                                                              weight_scale_npu,
                                                                              x_scale.npu(),
                                                                              group_list.npu())
    # Only rows covered by group_list are defined kernel output.
    output_npu_valid = output_npu[:group_list[-1], :]
    output_scale_npu_valid = output_scale_npu[:group_list[-1]]

    # int8 outputs may differ by one rounding step, hence atol=1.
    torch.testing.assert_close(output_npu_valid.cpu(),
                               output_cpu,
                               atol=1,
                               rtol=2**-13)
    torch.testing.assert_close(output_scale_npu_valid.cpu(),
                               output_scale_cpu,
                               atol=1e-9,
                               rtol=1e-6)

    gc.collect()
    torch.npu.empty_cache()
    torch.npu.reset_peak_memory_stats()
|
||||
@@ -0,0 +1,175 @@
|
||||
import gc
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
import torch_npu
|
||||
|
||||
from vllm_ascend.utils import enable_custom_op
|
||||
|
||||
enable_custom_op()
|
||||
|
||||
|
||||
def x_int8_to_x_int4(x: torch.Tensor):
    """Unpack each int8 row of *x* into two int4-valued rows (stored as int8).

    Row ``2*i`` holds the signed high nibbles of input row ``i`` and row
    ``2*i + 1`` holds the low nibbles shifted by -8, matching the A8W4
    kernel's expected int4 input layout. Requires an even column count
    (the int16 reinterpretation pairs adjacent bytes).
    """
    rows, cols = x.shape
    # Floor division in float emulates an arithmetic shift right by 4,
    # yielding the signed high nibble.
    high_nibbles = (x.to(torch.float16) // 16).to(torch.int8)
    # Mask the low nibble of every byte via an int16 reinterpretation,
    # then re-center by -8.
    low_nibbles = (
        torch.bitwise_and(x.view(torch.int16), 0x0f0f).view(torch.int8) - 8)
    unpacked = torch.empty((2 * rows, cols), dtype=torch.int8)
    unpacked[0::2, :] = high_nibbles
    unpacked[1::2, :] = low_nibbles
    return unpacked
|
||||
|
||||
|
||||
def custom_mm(x: torch.Tensor, weight: torch.Tensor,
              weight_scale: torch.Tensor, m: int):
    """CPU golden for an integer matmul with per-group or per-channel dequant.

    Args:
        x: (m, k) integer activations.
        weight: (k, n) integer weights.
        weight_scale: dequantization scale. Shape (k_group, n) selects the
            perGroup path (k_group > 1); shape (n,) or (1, n) selects the
            perChannel path.
        m: number of tokens (rows of x).

    Returns:
        Dequantized matmul result as a float32 (m, n) tensor.
    """
    k, n = weight.shape
    result = torch.zeros((m, n), dtype=torch.float16)
    scale_is_2d = len(weight_scale.shape) == 2
    if scale_is_2d and weight_scale.shape[0] != 1:
        # perGroup: split K into k_group contiguous chunks, each with its own
        # (n,) scale, and accumulate the dequantized partial products in fp16.
        k_group = weight_scale.shape[0]
        chunk = k // k_group
        lhs = x.view(-1, k_group, chunk).transpose(0, 1)
        rhs = weight.view(k_group, chunk, n)
        partial = torch.bmm(lhs.to(torch.int32),
                            rhs.to(torch.int32)).to(torch.float16)
        for g in range(k_group):
            result += (partial[g] *
                       weight_scale[g].view(1, -1).to(torch.float16)).to(
                           torch.float16)
    elif not scale_is_2d or weight_scale.shape[0] == 1:
        # perChannel: one (n,) scale applied after a single int32 matmul.
        acc = torch.matmul(x.to(torch.int32),
                           weight.to(torch.int32)).to(torch.float32)
        result = acc * weight_scale.view(1, -1).to(torch.float16)
    return result.to(torch.float32)
|
||||
|
||||
|
||||
def gmm_swiglu_quant_golden_a8_w4(x: torch.Tensor, weight: torch.Tensor,
                                  weight_scale: torch.Tensor,
                                  per_token_scale: torch.Tensor,
                                  bias: torch.Tensor,
                                  group_list: torch.Tensor):
    """
    Process the input data by group and call the GMM_Swiglu_quant function for quantization computation.
    Parameters:
        x (torch.Tensor): Input tensor with shape (M, K), type INT8.
        weight (torch.Tensor): List of weight tensors, each with shape (E, K, N), data type INT8 but data range INT4, representing INT4 values.
        weight_scale (torch.Tensor): Scaling factor for each channel.
            - In perGroup scenario: shape (E, k_group_num, N).
            - In perChannel scenario: shape (E, N).
        per_token_scale (torch.Tensor): Scaling factor for each token, shape (M, ).
        bias: torch.Tensor,
        group_list (list): List defining the number of tokens in each group.
    Returns:
        quant_output (torch.Tensor): Quantized output tensor with shape (M, N // 2).
        quant_scale_output (torch.Tensor): Quantization scaling factor, shape (M, ).
    """
    M, N = x.shape[0], weight.shape[2]
    quant_output = torch.zeros(M, N // 2).to(torch.int8)
    quant_scale_output = torch.zeros(M).to(torch.float32)
    # Preprocessing X_INT8 -> X_INT4: each int8 row becomes two int4 rows
    # (high nibbles / low nibbles interleaved), so row indices below are
    # doubled relative to the original x.
    x_int4 = x_int8_to_x_int4(x)
    start_idx = 0
    # Number of tokens in the previous group
    pre_v = 0
    group_list = group_list.tolist()
    # Traverse group_list (cumulative counts) and process data by group
    for i, v in enumerate(group_list):
        curr_v = v
        # Calculate the number of tokens in the current group " * 2 " because 1 row of Int8--> 2 rows of Int4
        temp_v = int((curr_v - pre_v) * 2)
        # Update the number of tokens in the previous group
        pre_v = curr_v
        if (temp_v > 0):
            # Integer matmul + channel dequant on the unpacked int4 rows.
            mm_out = custom_mm(x_int4[int(start_idx):int(start_idx + temp_v)],
                               weight[i], weight_scale[i], temp_v)
            # Recombine nibble-row pairs: high * 16 + low, plus bias.
            mm_num_concat = ((mm_out[::2] * 16 + mm_out[1::2]) +
                             bias[i].view(1, -1))
            # Indices halved back to original token rows for per-token scale.
            per_token_quant = mm_num_concat * per_token_scale[start_idx // 2:(
                start_idx + temp_v) // 2].view(-1, 1)
            # SwiGLU: first half activated by silu, gated by second half.
            swiglu, gate = per_token_quant.chunk(2, dim=-1)
            temp = swiglu * torch.sigmoid(swiglu)
            temp = temp * gate
            # Symmetric per-token int8 quantization (row max maps to 127).
            max_value = torch.max(torch.abs(temp), dim=-1).values
            quant_scale_output_temp = 127 / max_value
            quant_output[start_idx // 2:(start_idx + temp_v) //
                         2] = torch.round(temp *
                                          quant_scale_output_temp.reshape(
                                              temp_v // 2, 1)).to(torch.int8)
            quant_scale_output[start_idx // 2:(start_idx + temp_v) //
                               2] = 1 / quant_scale_output_temp
        start_idx += temp_v
    return quant_output, quant_scale_output
|
||||
|
||||
|
||||
def generate_non_decreasing_sequence(length, upper_limit):
    """Return a non-decreasing int64 tensor of *length* random prefix sums,
    rescaled (with truncation) so the final value does not exceed
    *upper_limit*. Used to build cumulative group lists for grouped matmuls.
    """
    steps = torch.randint(0, 128, (length, ))
    prefix = torch.cumsum(steps, dim=0)

    # Rescale proportionally when the sequence overshoots the limit;
    # truncation to int64 keeps it non-decreasing.
    if prefix[-1] >= upper_limit:
        prefix = (prefix * (upper_limit / prefix[-1])).to(torch.int64)
    return prefix
|
||||
|
||||
|
||||
@torch.inference_mode()
def test_grouped_matmul_swiglu_quant_kernel():
    """Compare the fused A8W4 grouped matmul + SwiGLU + quant kernel against
    the CPU golden (gmm_swiglu_quant_golden_a8_w4)."""
    E = 16
    M = 512
    K = 7168
    N = 4096
    torch.npu.config.allow_internal_format = True
    x = torch.randint(-5, 5, (M, K), dtype=torch.int8).npu()
    weight_ori = torch.randint(-5, 5, (E, K, N), dtype=torch.int8)
    # Cast to the NPU-private NZ layout (format id 29), then pack two int4
    # values per byte via npu_quantize with a unit scale.
    weight_nz = torch_npu.npu_format_cast(weight_ori.npu().to(torch.float32),
                                          29)
    pack_weight = torch_npu.npu_quantize(weight_nz,
                                         torch.tensor([1.], device='npu'),
                                         None, torch.quint4x2, -1, False)

    weight_scale = torch.randn(E, 1, N)
    # Reinterpret the float32 scale bits as uint32, then widen to int64 —
    # the kernel expects the scale bit pattern in a 64-bit integer tensor.
    scale_np = weight_scale.cpu().numpy()
    scale_np.dtype = np.uint32
    scale_uint64_tensor = torch.from_numpy(scale_np.astype(np.int64)).npu()
    pertoken_scale = torch.randn(M).to(torch.float32).npu()
    # Cumulative per-expert token counts, capped at M.
    group_list = generate_non_decreasing_sequence(E, M).npu()
    bias = torch.zeros((E, N), dtype=torch.float32,
                       device="npu").uniform_(-5, 5)

    output_golden, output_scale_golden = gmm_swiglu_quant_golden_a8_w4(
        x.cpu(), weight_ori, weight_scale, pertoken_scale.cpu(), bias.cpu(),
        group_list.cpu())

    output, output_scale, _ = torch.ops._C_ascend.grouped_matmul_swiglu_quant(
        x=x,
        weight=pack_weight,
        bias=bias,
        group_list=group_list,
        weight_scale=scale_uint64_tensor,
        x_scale=pertoken_scale)
    # int8 outputs may differ by one rounding step, hence atol=1.
    torch.testing.assert_close(output_golden, output.cpu(), atol=1, rtol=0.005)
    torch.testing.assert_close(output_scale_golden,
                               output_scale.cpu(),
                               atol=1,
                               rtol=0.005)

    gc.collect()
    torch.npu.empty_cache()
    torch.npu.reset_peak_memory_stats()
|
||||
@@ -0,0 +1,135 @@
|
||||
import gc
|
||||
import os
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
import torch.distributed as dist
|
||||
import torch.multiprocessing as mp
|
||||
import torch_npu
|
||||
import torchair
|
||||
|
||||
from vllm_ascend.utils import enable_custom_op
|
||||
|
||||
# Compile NPU graphs with torchair in "reduce-overhead" mode; the resulting
# backend is handed to torch.compile inside each worker process.
config = torchair.CompilerConfig()
config.mode = "reduce-overhead"
npu_backend = torchair.get_npu_backend(compiler_config=config)
# Allow NPU-private tensor formats required by the fused kernel.
torch_npu.npu.config.allow_internal_format = True
# Enable vllm-ascend custom ops (provides torch.ops._C_ascend used below).
enable_custom_op()

# Rank of the current process; assigned by each spawned worker and read
# inside Module.forward, so it must live at module level.
global_rank_id = 0
|
||||
|
||||
|
||||
def golden_op_matmul_allreduce_add_rmsnorm(a, b, residual, gamma, epsilon):
    """Eager reference path: linear -> all-reduce -> fused add + RMSNorm.

    Returns the normalized output and the post-add residual, mirroring the
    fused matmul_allreduce_add_rmsnorm kernel's two outputs.
    """
    partial = torch.nn.functional.linear(a, b)
    # Sum the per-rank partial matmul results in place across the group.
    dist.all_reduce(partial)
    normed, _, residual_out = torch_npu.npu_add_rms_norm(
        partial, residual, gamma, epsilon)
    return normed, residual_out
|
||||
|
||||
|
||||
def worker(rank, ep_world_size, batch_size, m, k, n):
    """Per-rank body spawned by the test: runs the golden eager path and the
    compiled fused kernel on identical inputs and compares their outputs.

    Args:
        rank: this process's rank in the expert-parallel group.
        ep_world_size: total number of spawned ranks.
        batch_size: unused here; kept to match the spawn signature.
        m, k, n: matmul dimensions (a is (m, k), b is (n, k)).
    """
    # Publish the rank at module level so Module.forward can read it.
    global global_rank_id
    global_rank_id = rank
    rank = rank

    os.environ["MASTER_ADDR"] = "127.0.0.1"
    os.environ["MASTER_PORT"] = "29500"
    dist.init_process_group(backend="hccl",
                            rank=rank,
                            world_size=ep_world_size)

    ep_ranks_list = list(np.arange(0, ep_world_size))

    ep_group = dist.new_group(backend="hccl", ranks=ep_ranks_list)

    torch_npu.npu.set_device(rank)
    # The fused op needs the HCCL communicator name of the group.
    ep_hcomm_info = ep_group._get_backend(
        torch.device("npu")).get_hccl_comm_name(rank)

    torch_npu.npu.synchronize(rank)

    # Thin wrapper so the custom op can be traced by torch.compile.
    class Module(torch.nn.Module):

        def __init__(self) -> None:
            super().__init__()

        def forward(self, x1, x2, residual, gamma, ep_hcomm_info, epsilon,
                    is_trans_b, is_allgather_add_out):
            out1, add_out1 = torch.ops._C_ascend.matmul_allreduce_add_rmsnorm(
                x1, x2, residual, gamma, ep_hcomm_info, ep_world_size,
                global_rank_id, epsilon, is_trans_b, is_allgather_add_out)
            return out1, add_out1

    DTYPE = torch.bfloat16
    # Debug switch: all-ones inputs make mismatches easy to inspect.
    USE_ONES = False

    # Same seed on every rank so all ranks hold identical inputs.
    torch.manual_seed(42)

    if USE_ONES:
        x1 = torch.ones([m, k], dtype=DTYPE).npu(rank)
        x2 = torch.ones([n, k], dtype=DTYPE).npu(rank)
    else:
        x1 = torch.normal(0, 0.1, [m, k], dtype=DTYPE).npu(rank)
        x2 = torch.normal(0, 0.1, [n, k], dtype=DTYPE).npu(rank)

    if USE_ONES:
        residual = torch.full([m, n], 2048, dtype=DTYPE).npu(rank)
    else:
        residual = torch.full([m, n], 0, dtype=DTYPE).npu(rank)

    gamma = torch.full([n], 1, dtype=DTYPE).npu(rank)

    epsilon = 1e-5
    is_trans_b = True
    is_allgather_add_out = True
    warnup_cnt = 5
    repeat_cnt = 10

    def run_golden_case(loop_cnt):
        # Repeat the eager path; only the final iteration's result is kept.
        for _ in range(loop_cnt):
            golden_out, golden_add_out = golden_op_matmul_allreduce_add_rmsnorm(
                x1, x2, residual, gamma, epsilon)
        torch_npu.npu.synchronize(rank)
        return golden_out, golden_add_out

    run_golden_case(warnup_cnt)

    golden_out, golden_add_out = run_golden_case(repeat_cnt)
    golden_out = golden_out.detach().cpu()
    golden_add_out = golden_add_out.detach().cpu()

    mod = Module().npu()
    opt_model = torch.compile(mod, backend=npu_backend)

    def run_custom_case(loop_cnt):
        # Repeat the compiled fused kernel; keep the final iteration only.
        for _ in range(loop_cnt):
            out, add_out = opt_model(x1, x2, residual, gamma, ep_hcomm_info,
                                     epsilon, is_trans_b, is_allgather_add_out)
        torch_npu.npu.synchronize(rank)
        return out, add_out

    # warn up
    run_custom_case(warnup_cnt)

    out, add_out = run_custom_case(repeat_cnt)
    out = out.detach().cpu()
    add_out = add_out.detach().cpu()

    dist.destroy_process_group()

    torch.testing.assert_close(golden_out, out, atol=0.1, rtol=0.005)
    torch.testing.assert_close(golden_add_out, add_out, atol=0.1, rtol=0.005)
    gc.collect()
    torch.npu.empty_cache()
    torch.npu.reset_peak_memory_stats()
|
||||
|
||||
|
||||
@torch.inference_mode()
def test_matmul_allreduce_add_rmsnorm_kernel():
    """Spawn one process per rank; each worker compares the fused
    matmul_allreduce_add_rmsnorm kernel against the eager golden path."""
    ep_world_size = 8
    batch_size = 1
    m, k, n = 10000, 1024, 5120
    mp.spawn(worker,
             args=(ep_world_size, batch_size, m, k, n),
             nprocs=ep_world_size,
             join=True)
|
||||
@@ -0,0 +1,115 @@
|
||||
import gc
|
||||
|
||||
import torch
|
||||
import torch_npu
|
||||
|
||||
from vllm_ascend.utils import enable_custom_op
|
||||
|
||||
enable_custom_op()
|
||||
|
||||
|
||||
@torch.inference_mode()
def test_mla_preprocess_kernel():
    """Smoke-test the fused MLA-preprocess custom op, quantized path.

    Builds random inputs for the "per_tensor_quant_asymm" mode (int8
    projection weights plus per-tensor quant scales/offsets) and asserts
    that the op actually wrote into the pre-allocated query outputs.
    """
    # Minimal problem size: one token, two heads, one KV-cache block.
    token_num = 1
    head_num = 2
    N_7168 = 7168  # hidden size of the input activations
    block_num = 1
    block_size = 128
    dtype = torch.bfloat16

    hidden_states = torch.randn((token_num, N_7168), dtype=dtype).npu()
    quant_scale0 = torch.randn((1, ), dtype=dtype).npu()
    quant_offset0 = torch.randint(0, 7, (1, ), dtype=torch.int8).npu()

    # int8 down-projection weight; format 29 is the NPU fractal-NZ layout.
    wdqkv = torch.randint(0, 7, (1, 224, 2112, 32), dtype=torch.int8).npu()
    wdqkv = torch_npu.npu_format_cast(wdqkv.contiguous(), 29)

    de_scale0 = torch.rand((2112, ), dtype=torch.float).npu()
    bias0 = torch.randint(0, 7, (2112, ), dtype=torch.int32).npu()
    gamma1 = torch.randn((1536), dtype=dtype).npu()
    beta1 = torch.randn((1536), dtype=dtype).npu()
    quant_scale1 = torch.randn((1, ), dtype=dtype).npu()
    quant_offset1 = torch.randint(0, 7, (1, ), dtype=torch.int8).npu()

    wuq = torch.randint(0, 7, (1, 48, head_num * 192, 32),
                        dtype=torch.int8).npu()
    wuq = torch_npu.npu_format_cast(wuq.contiguous(), 29)

    de_scale1 = torch.rand((head_num * 192, ), dtype=torch.float).npu()
    bias1 = torch.randint(0, 7, (head_num * 192, ), dtype=torch.int32).npu()

    gamma2 = torch.randn((512), dtype=dtype).npu()

    # Rotary tables for the single token.
    cos = torch.randn((token_num, 64), dtype=dtype).npu()
    sin = torch.randn((token_num, 64), dtype=dtype).npu()

    wuk = torch.randn((head_num, 128, 512), dtype=dtype).npu()
    wuk = torch_npu.npu_format_cast(wuk, 29)
    kv_cache = torch.randint(0,
                             7,
                             (block_num, head_num * 512 // 32, block_size, 32),
                             dtype=dtype).npu()
    kv_cache_rope = torch.randn(
        (block_num, head_num * 64 // 16, block_size, 16), dtype=dtype).npu()

    slotmapping = torch.randint(0, 7, (token_num, ), dtype=torch.int32).npu()

    ctkv_scale = torch.randn((1, ), dtype=dtype).npu()
    qnope_scale = torch.randn((head_num), dtype=dtype).npu()

    # Pre-allocated outputs; the op is expected to fill them in place.
    q_nope_out = torch.empty(
        (hidden_states.shape[0], wuk.shape[0], kv_cache.shape[-1]),
        dtype=hidden_states.dtype,
        device=hidden_states.device,
    )
    q_rope_out = torch.empty(
        (hidden_states.shape[0], wuk.shape[0], kv_cache_rope.shape[-1]),
        dtype=hidden_states.dtype,
        device=hidden_states.device,
    )
    q_down = torch.empty(
        (hidden_states.shape[0], 1536),
        dtype=hidden_states.dtype,
        device=hidden_states.device,
    )
    # Snapshots taken before the call so mutation can be detected below.
    q_nope_old = q_nope_out.clone()
    q_rope_old = q_rope_out.clone()

    torch.ops._C_ascend.mla_preprocess(
        hidden_states,
        wdqkv,
        de_scale0,
        gamma1,
        beta1,
        wuq,
        de_scale1,
        gamma2,
        cos,
        sin,
        wuk,
        kv_cache,
        kv_cache_rope,
        slotmapping,
        quant_scale0=quant_scale0,
        quant_offset0=quant_offset0,
        bias0=bias0,
        quant_scale1=quant_scale1,
        quant_offset1=quant_offset1,
        bias1=bias1,
        ctkv_scale=ctkv_scale,
        q_nope_scale=qnope_scale,
        cache_mode="krope_ctkv",
        quant_mode="per_tensor_quant_asymm",
        enable_inner_out=False,
        q_out0=q_nope_out,
        kv_cache_out0=kv_cache,
        q_out1=q_rope_out,
        kv_cache_out1=kv_cache_rope,
        inner_out=q_down,
    )
    # Only checks that the outputs changed, not numerical correctness.
    assert not torch.equal(q_nope_out, q_nope_old)
    assert not torch.equal(q_rope_out, q_rope_old)

    gc.collect()
    torch.npu.empty_cache()
    torch.npu.reset_peak_memory_stats()
|
||||
@@ -0,0 +1,99 @@
|
||||
import gc
|
||||
|
||||
import torch
|
||||
import torch_npu
|
||||
|
||||
from vllm_ascend.utils import enable_custom_op
|
||||
|
||||
enable_custom_op()
|
||||
|
||||
|
||||
@torch.inference_mode()
def test_mla_preprocess_kernel():
    """Smoke-test the fused MLA-preprocess custom op, "no_quant" path.

    All quantization-related arguments are passed as None and the weights
    are floating point; asserts the op wrote into the query outputs.
    """
    token_num = 1
    head_num = 2
    N_7168 = 7168  # hidden size of the input activations
    block_num = 1
    block_size = 128
    dtype = torch.bfloat16

    hidden_states = torch.randn((token_num, N_7168), dtype=dtype).npu()

    # NOTE(review): randint with a float dtype yields integer-valued bf16
    # weights here (no-quant path) -- presumably intentional for stability.
    wdqkv = torch.randint(0, 7, (1, 448, 2112, 16), dtype=dtype).npu()
    wdqkv = torch_npu.npu_format_cast(wdqkv.contiguous(), 29)
    gamma1 = torch.randn((1536), dtype=dtype).npu()

    wuq = torch.randint(0, 7, (1, 96, head_num * 192, 16), dtype=dtype).npu()
    wuq = torch_npu.npu_format_cast(wuq.contiguous(), 29)
    gamma2 = torch.randn((512), dtype=dtype).npu()

    cos = torch.randn((token_num, 64), dtype=dtype).npu()
    sin = torch.randn((token_num, 64), dtype=dtype).npu()

    wuk = torch.randn((head_num, 128, 512), dtype=dtype).npu()
    # wuk = torch_npu.npu_format_cast(wuk, 29)
    kv_cache = torch.randint(0,
                             7,
                             (block_num, head_num * 512 // 32, block_size, 32),
                             dtype=dtype).npu()
    kv_cache_rope = torch.randn(
        (block_num, head_num * 64 // 16, block_size, 16), dtype=dtype).npu()

    slotmapping = torch.randint(0, 7, (token_num, ), dtype=torch.int32).npu()

    # Pre-allocated outputs; the op is expected to fill them in place.
    q_nope_out = torch.empty(
        (hidden_states.shape[0], wuk.shape[0], kv_cache.shape[-1]),
        dtype=hidden_states.dtype,
        device=hidden_states.device,
    )
    q_rope_out = torch.empty(
        (hidden_states.shape[0], wuk.shape[0], kv_cache_rope.shape[-1]),
        dtype=hidden_states.dtype,
        device=hidden_states.device,
    )
    q_down = torch.empty(
        (hidden_states.shape[0], 1536),
        dtype=hidden_states.dtype,
        device=hidden_states.device,
    )
    # Snapshots taken before the call so mutation can be detected below.
    q_nope_old = q_nope_out.clone()
    q_rope_old = q_rope_out.clone()

    # The leading Nones are the quantization parameters, unused in no_quant.
    torch.ops._C_ascend.mla_preprocess(
        hidden_states,
        wdqkv,
        None,
        gamma1,
        None,
        wuq,
        None,
        gamma2,
        cos,
        sin,
        wuk,
        kv_cache,
        kv_cache_rope,
        slotmapping,
        None,
        None,
        None,
        None,
        None,
        None,
        None,
        None,
        cache_mode="krope_ctkv",
        quant_mode="no_quant",
        enable_inner_out=False,
        q_out0=q_nope_out,
        kv_cache_out0=kv_cache,
        q_out1=q_rope_out,
        kv_cache_out1=kv_cache_rope,
        inner_out=q_down,
    )
    # Only checks that the outputs changed, not numerical correctness.
    assert not torch.equal(q_nope_out, q_nope_old)
    assert not torch.equal(q_rope_out, q_rope_old)

    gc.collect()
    torch.npu.empty_cache()
    torch.npu.reset_peak_memory_stats()
|
||||
@@ -0,0 +1,116 @@
|
||||
import gc
|
||||
|
||||
import torch
|
||||
import torch_npu
|
||||
|
||||
from vllm_ascend.utils import enable_custom_op
|
||||
|
||||
enable_custom_op()
|
||||
|
||||
|
||||
@torch.inference_mode()
def test_mla_preprocess_kernel():
    """Smoke-test the fused MLA-preprocess custom op with inner_out enabled.

    Same quantized setup as the per-tensor test, but enable_inner_out=True,
    so the intermediate q_down buffer must also be written by the op.
    """
    token_num = 1
    head_num = 2
    N_7168 = 7168  # hidden size of the input activations
    block_num = 1
    block_size = 128
    dtype = torch.bfloat16

    hidden_states = torch.randn((token_num, N_7168), dtype=dtype).npu()
    quant_scale0 = torch.randn((1, ), dtype=dtype).npu()
    quant_offset0 = torch.randint(0, 7, (1, ), dtype=torch.int8).npu()

    # int8 down-projection weight; format 29 is the NPU fractal-NZ layout.
    wdqkv = torch.randint(0, 7, (1, 224, 2112, 32), dtype=torch.int8).npu()
    wdqkv = torch_npu.npu_format_cast(wdqkv.contiguous(), 29)

    de_scale0 = torch.rand((2112, ), dtype=torch.float).npu()
    bias0 = torch.randint(0, 7, (2112, ), dtype=torch.int32).npu()
    gamma1 = torch.randn((1536), dtype=dtype).npu()
    beta1 = torch.randn((1536), dtype=dtype).npu()
    quant_scale1 = torch.randn((1, ), dtype=dtype).npu()
    quant_offset1 = torch.randint(0, 7, (1, ), dtype=torch.int8).npu()

    wuq = torch.randint(0, 7, (1, 48, head_num * 192, 32),
                        dtype=torch.int8).npu()
    wuq = torch_npu.npu_format_cast(wuq.contiguous(), 29)

    de_scale1 = torch.rand((head_num * 192, ), dtype=torch.float).npu()
    bias1 = torch.randint(0, 7, (head_num * 192, ), dtype=torch.int32).npu()

    gamma2 = torch.randn((512), dtype=dtype).npu()

    cos = torch.randn((token_num, 64), dtype=dtype).npu()
    sin = torch.randn((token_num, 64), dtype=dtype).npu()

    wuk = torch.randn((head_num, 128, 512), dtype=dtype).npu()
    wuk = torch_npu.npu_format_cast(wuk, 29)
    kv_cache = torch.randint(0,
                             7,
                             (block_num, head_num * 512 // 32, block_size, 32),
                             dtype=dtype).npu()
    kv_cache_rope = torch.randn(
        (block_num, head_num * 64 // 16, block_size, 16), dtype=dtype).npu()

    slotmapping = torch.randint(0, 7, (token_num, ), dtype=torch.int32).npu()

    ctkv_scale = torch.randn((1, ), dtype=dtype).npu()
    qnope_scale = torch.randn((head_num), dtype=dtype).npu()

    # Pre-allocated outputs; the op is expected to fill them in place.
    q_nope_out = torch.empty(
        (hidden_states.shape[0], wuk.shape[0], kv_cache.shape[-1]),
        dtype=hidden_states.dtype,
        device=hidden_states.device,
    )
    q_rope_out = torch.empty(
        (hidden_states.shape[0], wuk.shape[0], kv_cache_rope.shape[-1]),
        dtype=hidden_states.dtype,
        device=hidden_states.device,
    )
    q_down = torch.empty(
        (hidden_states.shape[0], 1536),
        dtype=hidden_states.dtype,
        device=hidden_states.device,
    )

    # Snapshots taken before the call so mutation can be detected below;
    # q_down is also checked because enable_inner_out=True.
    q_nope_old = q_nope_out.clone()
    q_rope_old = q_rope_out.clone()
    q_down_old = q_down.clone()

    torch.ops._C_ascend.mla_preprocess(hidden_states,
                                       wdqkv,
                                       de_scale0,
                                       gamma1,
                                       beta1,
                                       wuq,
                                       de_scale1,
                                       gamma2,
                                       cos,
                                       sin,
                                       wuk,
                                       kv_cache,
                                       kv_cache_rope,
                                       slotmapping,
                                       quant_scale0=quant_scale0,
                                       quant_offset0=quant_offset0,
                                       bias0=bias0,
                                       quant_scale1=quant_scale1,
                                       quant_offset1=quant_offset1,
                                       bias1=bias1,
                                       ctkv_scale=ctkv_scale,
                                       q_nope_scale=qnope_scale,
                                       cache_mode="krope_ctkv",
                                       quant_mode="per_tensor_quant_asymm",
                                       enable_inner_out=True,
                                       q_out0=q_nope_out,
                                       kv_cache_out0=kv_cache,
                                       q_out1=q_rope_out,
                                       kv_cache_out1=kv_cache_rope,
                                       inner_out=q_down)
    # Only checks that the outputs changed, not numerical correctness.
    assert not torch.equal(q_nope_out, q_nope_old)
    assert not torch.equal(q_rope_out, q_rope_old)
    assert not torch.equal(q_down, q_down_old)

    gc.collect()
    torch.npu.empty_cache()
    torch.npu.reset_peak_memory_stats()
|
||||
@@ -0,0 +1,349 @@
|
||||
import itertools
|
||||
import random
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
|
||||
from vllm_ascend.utils import enable_custom_op
|
||||
|
||||
enable_custom_op()
|
||||
|
||||
|
||||
def adapter_capacity(sorted_row_idx, sorted_expert_idx, capacity):
    """Drop tokens exceeding each expert's capacity, in place.

    Scans ``sorted_expert_idx`` (expert ids sorted ascending, aligned with
    ``sorted_row_idx``) and, once more than ``capacity`` consecutive entries
    belong to the same expert, marks the overflowing entries as dropped by
    writing -1 into both arrays.

    Args:
        sorted_row_idx: source row indices aligned with ``sorted_expert_idx``;
            mutated in place.
        sorted_expert_idx: sorted expert ids; mutated in place.
        capacity: maximum number of tokens kept per expert.
    """
    # Guard: the original indexed [0] unconditionally and crashed on empty
    # input; an empty routing simply has nothing to drop.
    if len(sorted_expert_idx) == 0:
        return
    count = 0
    last = sorted_expert_idx[0]
    for i, val in enumerate(sorted_expert_idx):
        if last != val:
            # New expert run starts: reset the per-expert counter.
            count = 1
            last = val
        else:
            count += 1
        if count > capacity:
            sorted_expert_idx[i] = -1
            sorted_row_idx[i] = -1
|
||||
|
||||
|
||||
def moe_init_routing_golden(x, expert_idx, scale, offset, active_num,
                            expert_capacity, expert_num, drop_pad_mode,
                            expert_tokens_num_type, expert_tokens_num_flag,
                            active_expert_range, quant_mode, row_idx_type):
    """CPU/NumPy golden reference for npu_moe_init_routing_custom.

    Args:
        x: [num_rows, h] activations (numpy array).
        expert_idx: [num_rows, k] selected expert ids per token.
        scale / offset: optional quantization parameters.
        active_num: row budget in dropless mode (0 means "all active rows").
        expert_capacity: per-expert token budget in drop/pad mode.
        expert_num: total number of experts (used in drop/pad mode).
        drop_pad_mode: 0 = dropless gather, 1 = drop/pad to fixed capacity.
        expert_tokens_num_type: 0 cumulative sums, 1 raw counts,
            2 (expert_id, count) pairs.
        expert_tokens_num_flag: whether per-expert token stats are produced.
        active_expert_range: [start, end) expert ids handled by this rank.
        quant_mode: -1 no quant, 0 static per-tensor int8, 1 dynamic per-row.
        row_idx_type: 0 scatter mapping (source -> sorted rank), 1 direct
            sorted source indices.

    Returns:
        (expanded_x, expanded_row_idx, expert_tokens_count, expanded_scale).
    """
    if drop_pad_mode == 1:
        if expert_num <= 0:
            print("expert num can not be 0")
            return
    # Drop/pad mode always routes over the full expert table.
    expert_start = active_expert_range[0] if drop_pad_mode == 0 else 0
    expert_end = active_expert_range[1] if drop_pad_mode == 0 else expert_num
    num_rows = x.shape[0]
    h = x.shape[1]
    k = expert_idx.shape[-1]
    expert_idx_in = expert_idx.copy().reshape(-1)
    # Number of (token, expert) pairs inside the active range.
    actual_expert_total_num: int = np.sum((expert_idx_in >= expert_start)
                                          & (expert_idx_in < expert_end))

    # Push below-range experts to the tail of the stable sort; ids at or
    # above expert_end already sort after the active range.
    expert_idx_in[(expert_idx_in
                   < expert_start)] = np.int32(np.iinfo(np.int32).max)
    sorted_expert_indices = np.argsort(expert_idx_in, axis=-1, kind="stable")
    sorted_expert_idx = expert_idx_in[sorted_expert_indices]
    if row_idx_type == 1:
        expanded_row_idx = sorted_expert_indices[:actual_expert_total_num]
    else:
        # Scatter form: original flat position -> rank in the sorted order;
        # inactive pairs keep -1.
        expanded_row_idx = np.ones(num_rows * k).astype(np.int32) * -1
        tmp_indices = np.arange(actual_expert_total_num)
        expanded_row_idx[
            sorted_expert_indices[:actual_expert_total_num]] = tmp_indices

    if not expert_tokens_num_flag:
        expert_tokens_count = torch.tensor([0])  # placeholder, never compared
    else:
        if drop_pad_mode == 0:
            if expert_tokens_num_type == 1:
                # Per-expert counts, zero-padded to the active range width.
                expert_tokens_count = np.bincount(
                    sorted_expert_idx[:actual_expert_total_num] - expert_start)
                expert_tokens_count = np.concatenate([
                    expert_tokens_count,
                    np.zeros((expert_end - expert_start) -
                             len(expert_tokens_count)).astype(np.int64)
                ])
            elif expert_tokens_num_type == 0:
                # Same counts, reported as a running cumulative sum.
                expert_tokens_count = np.bincount(
                    sorted_expert_idx[:actual_expert_total_num] - expert_start)
                expert_tokens_count = np.concatenate([
                    expert_tokens_count,
                    np.zeros((expert_end - expert_start) -
                             len(expert_tokens_count)).astype(np.int64)
                ])
                expert_tokens_count = np.cumsum(expert_tokens_count)
            elif expert_tokens_num_type == 2:
                # (expert_id, count) pairs for experts that received tokens.
                expert_id, counts = np.unique(
                    sorted_expert_idx[:actual_expert_total_num],
                    return_counts=True)
                expert_tokens_count = np.column_stack((expert_id, counts))
                if expert_tokens_count.shape[0] < expert_num:
                    # NOTE(review): pads with a single (0, 0) row even when
                    # more than one expert is missing -- confirm intended.
                    expert_tokens_count = np.concatenate(
                        (expert_tokens_count, [
                            [0, 0],
                        ]), axis=0)
        else:
            expert_tokens_count = np.bincount(
                sorted_expert_idx[:actual_expert_total_num] - expert_start)
            zeros_array = np.zeros(
                (expert_end - expert_start) - len(expert_tokens_count),
                dtype=np.int64)
            expert_tokens_count = np.concatenate(
                [expert_tokens_count, zeros_array])
        # Comparison against the NPU output is done in int64.
        expert_tokens_count = expert_tokens_count.astype(np.int64)

    if drop_pad_mode == 0:
        # Dropless: gather the first active_num routed rows; // k converts a
        # flat (row, k) pair index back to the source row.
        if active_num == 0:
            active_num = actual_expert_total_num
        else:
            active_num = min(active_num, actual_expert_total_num)
        expanded_scale = None
        expanded_x = x[sorted_expert_indices[:active_num] // k, :]
        if scale is not None and quant_mode == -1:
            expanded_scale = scale[sorted_expert_indices[:active_num] // k]
    else:
        # Drop/pad: keep at most expert_capacity tokens per expert, laid out
        # in fixed slots of size expert_capacity per expert.
        adapter_capacity(sorted_expert_indices, sorted_expert_idx,
                         expert_capacity)

        # sort_row_tmp[e * capacity + slot] = flat source pair index, or -1.
        sort_row_tmp = np.full((expert_num * expert_capacity), -1, dtype=int)
        offset_tmp = 0
        lastExpertId = 0
        for i, val in enumerate(sorted_expert_indices):
            if val != -1:
                if lastExpertId != sorted_expert_idx[i]:
                    offset_tmp = 0
                    lastExpertId = sorted_expert_idx[i]
                sort_row_tmp[sorted_expert_idx[i] * expert_capacity +
                             offset_tmp] = sorted_expert_indices[i]
                offset_tmp = offset_tmp + 1

        # Invert the slot mapping: source pair index -> destination slot.
        expanded_row_idx = np.full(sorted_expert_indices.shape, -1)
        for i, val in enumerate(sort_row_tmp):
            if val != -1:
                expanded_row_idx[val] = i

        # Mask is 1 for padded (empty) slots, 0 for occupied ones.
        expanded_x_mask = np.full((expert_num * expert_capacity, h),
                                  1,
                                  dtype=int)
        expanded_x = np.full((expert_num * expert_capacity, h),
                             0,
                             dtype=x.dtype)
        for i, val in enumerate(sort_row_tmp):
            if val != -1:
                expanded_x[i] = x[val // k]
                expanded_x_mask[i] = np.full((h, ), 0, dtype=int)

    if quant_mode == -1:
        # No quantization: activations pass through unchanged.
        expanded_x = expanded_x
        expanded_row_idx = expanded_row_idx
        if scale is not None and drop_pad_mode == 1:
            expanded_scale = np.full((expert_num * expert_capacity, ),
                                     0,
                                     dtype=scale.dtype)
            for i, val in enumerate(sort_row_tmp):
                if val != -1:
                    expanded_scale[i] = scale[val // k]
        if scale is None:
            expanded_scale = None

    if quant_mode == 0:
        # Static per-tensor int8 quantization computed in fp16.
        expanded_scale = None
        expanded_x_fp16 = expanded_x.astype(np.float16)
        if scale is not None:
            scale_val = scale.astype(np.float16)
        else:
            raise ValueError("scale cannot be None when quant_mode is 0")
        if offset is not None:
            offset_val = offset.astype(np.float16)
        else:
            raise ValueError("offset cannot be None when quant_mode is 0")
        scale_rst = expanded_x_fp16 * scale_val[0]
        add_offset = scale_rst + offset_val[0]
        round_data = np.rint(add_offset)
        round_data = np.clip(round_data, -128, 127)
        expanded_x = round_data.astype(np.int8)

    if quant_mode == 1:
        # Dynamic per-row quantization: scale = row_max / 127.
        x_final = expanded_x.astype(np.float32)
        if scale is None:
            x_abs = np.abs(x_final)
            x_max = np.max(x_abs, axis=-1, keepdims=True)
            expanded_scale = x_max / 127
            expanded_x = x_final / expanded_scale
            expanded_x = np.round(expanded_x).astype(np.int8)
        else:
            # Optional smoothing scale applied before quantization; either a
            # single shared row or one row per active expert.
            if scale.shape[0] == 1:
                x_final = x_final * scale
            else:
                if drop_pad_mode == 0:
                    x_final = x_final * scale[sorted_expert_idx[:active_num] -
                                              expert_start]

                else:
                    for i, val in enumerate(sort_row_tmp):
                        if val != -1:
                            x_final[i] = x_final[i] * scale[i //
                                                            expert_capacity]
            x_abs = np.abs(x_final)
            x_max = np.max(x_abs, axis=-1, keepdims=True)
            expanded_scale = x_max / 127
            expanded_x = x_final / expanded_scale
            expanded_x = np.round(expanded_x).astype(np.int8)
    if x.dtype == np.int8:
        expanded_scale = None
    if drop_pad_mode == 1:
        # Zero out padded slots and reshape to [expert, capacity, h].
        expanded_x = np.ma.array(expanded_x, mask=expanded_x_mask).filled(0)
        expanded_x = expanded_x.reshape(expert_num, expert_capacity, h)

    return expanded_x, expanded_row_idx, expert_tokens_count, expanded_scale
|
||||
|
||||
|
||||
def npu_pta(x, expert_idx, scale, offset, active_num, expert_capacity,
            expert_num, drop_pad_mode, expert_tokens_num_type,
            expert_tokens_num_flag, quant_mode, active_expert_range,
            row_idx_type):
    """Run the moe_init_routing custom NPU op with the given configuration.

    Thin wrapper that forwards every routing knob to the kernel and returns
    its four outputs so they can be compared with the golden reference.
    """
    op_outputs = torch.ops._C_ascend.npu_moe_init_routing_custom(
        x,
        expert_idx,
        scale=scale,
        offset=offset,
        active_num=active_num,
        expert_capacity=expert_capacity,
        expert_num=expert_num,
        drop_pad_mode=drop_pad_mode,
        expert_tokens_num_type=expert_tokens_num_type,
        expert_tokens_num_flag=expert_tokens_num_flag,
        quant_mode=quant_mode,
        active_expert_range=active_expert_range,
        row_idx_type=row_idx_type)
    routed_x, routed_row_idx, token_count_or_cumsum, routed_scale = op_outputs
    return routed_x, routed_row_idx, token_count_or_cumsum, routed_scale
|
||||
|
||||
|
||||
def cmp_out_golden(x_golden, x_out, dtype):
    """Compare a device output against the golden reference.

    The device tensor is moved to host and trimmed to the golden length.
    Integer ('int8') outputs tolerate an absolute error of 1 (quantization
    rounding); everything else uses tight float tolerances.
    """
    actual = x_out.cpu().numpy()[:len(x_golden)]
    if dtype == 'int8':
        close = np.isclose(actual, x_golden, atol=1)
    else:
        close = np.isclose(actual, x_golden, rtol=1e-05, atol=1e-05)
    return np.all(close)
|
||||
|
||||
|
||||
def test_moe_npu(x, expert_idx, scale, offset, active_num, expert_capacity,
                 expert_num, drop_pad_mode, expert_tokens_num_type,
                 expert_tokens_num_flag, quant_mode, active_expert_range,
                 row_idx_type):
    """Run one moe_init_routing case on NPU and compare with the golden impl.

    Returns True when every produced output matches the golden reference.
    """

    def _to_npu(t):
        # Optional tensors stay None on the device side as well.
        return None if t is None else t.npu()

    def _to_numpy(t):
        return None if t is None else t.numpy()

    golden_x, golden_row_idx, golden_token_stats, golden_scale = \
        moe_init_routing_golden(x.numpy(), expert_idx.numpy(),
                                _to_numpy(scale), _to_numpy(offset),
                                active_num, expert_capacity, expert_num,
                                drop_pad_mode, expert_tokens_num_type,
                                expert_tokens_num_flag, active_expert_range,
                                quant_mode, row_idx_type)

    out_x, out_row_idx, out_token_stats, out_scale = npu_pta(
        x.npu(), expert_idx.npu(), _to_npu(scale), _to_npu(offset),
        active_num, expert_capacity, expert_num, drop_pad_mode,
        expert_tokens_num_type, expert_tokens_num_flag, quant_mode,
        active_expert_range, row_idx_type)

    # Activations are float when unquantized, int8 otherwise.
    x_dtype = "float32" if quant_mode == -1 else "int8"
    x_ok = cmp_out_golden(golden_x, out_x, x_dtype)

    row_idx_ok = cmp_out_golden(golden_row_idx, out_row_idx, "int32")

    token_stats_ok = True
    if expert_tokens_num_flag:
        token_stats_ok = cmp_out_golden(golden_token_stats, out_token_stats,
                                        "int64")

    # A meaningful dynamic scale only exists for these configurations.
    scale_ok = True
    if quant_mode == 1 or (quant_mode == -1 and scale is not None):
        scale_ok = cmp_out_golden(golden_scale.flatten(), out_scale,
                                  "float32")

    return x_ok and row_idx_ok and token_stats_ok and scale_ok
|
||||
|
||||
|
||||
def test_moe_init_routing_custom():
    """Sweep every routing-mode combination with random shapes.

    Iterates the cartesian product of all routing knobs, draws random
    problem sizes for each combination, and counts the cases in which the
    NPU op disagrees with the golden reference.
    """
    # Seed so failures are reproducible; the original drew fresh random
    # shapes every run, making CI failures impossible to replay.
    random.seed(0)
    torch.manual_seed(0)
    failed_test_cnt = 0
    drop_pad_mode = [0, 1]
    expert_tokens_num_type = [0, 1, 2]
    expert_tokens_num_flag = [True, False]
    quant_mode = [0, 1, -1]
    row_idx_type = [0, 1]
    scale_type = [0, 1, 2]
    product_result = itertools.product(drop_pad_mode, expert_tokens_num_type,
                                       expert_tokens_num_flag, quant_mode,
                                       row_idx_type, scale_type)

    # (the original wrapped this in enumerate(..., 5) with an unused index)
    for (drop_pad_mode_, expert_tokens_num_type_, expert_tokens_num_flag_,
         quant_mode_, row_idx_type_, scale_type_) in product_result:
        expert_num_ = random.randint(2, 500)
        expert_start = random.randint(0, expert_num_ - 1)
        expert_end = random.randint(expert_start + 1, expert_num_)
        active_expert_range_ = [expert_start, expert_end]

        N = random.randint(1, 100)
        H = random.randint(12, 100)
        K = random.randint(1, 12)
        x_ = torch.randn(N, H, dtype=torch.float16) * 5
        expert_capacity_ = random.randint(1, N - 1) if N > 1 else 1
        expert_idx_ = torch.randint(0,
                                    expert_num_ - 1, (N, K),
                                    dtype=torch.int32)
        active_num_ = N * K

        # Drop/pad mode only supports count stats, scatter row index, and
        # the full expert range.
        if drop_pad_mode_ == 1:
            active_expert_range_ = [0, expert_num_]
            expert_tokens_num_type_ = 1
            row_idx_type_ = 0

        if quant_mode_ == 0:
            # Static quantization needs both scale and offset.
            scale_ = torch.randn(1, dtype=torch.float)
            offset_ = torch.randn(1, dtype=torch.float)
        elif quant_mode_ == -1:
            scale_ = None
            offset_ = None
        else:
            # Dynamic quantization: optionally a shared or per-expert scale.
            if scale_type_ == 0:
                scale_ = None
                offset_ = None
            elif scale_type_ == 1:
                scale_ = torch.randn(1, H, dtype=torch.float)
                offset_ = None
            else:
                scale_ = torch.randn(active_expert_range_[1] -
                                     active_expert_range_[0],
                                     H,
                                     dtype=torch.float)
                offset_ = None

        result_pta = test_moe_npu(x_, expert_idx_, scale_, offset_,
                                  active_num_, expert_capacity_, expert_num_,
                                  drop_pad_mode_, expert_tokens_num_type_,
                                  expert_tokens_num_flag_, quant_mode_,
                                  active_expert_range_, row_idx_type_)
        if not result_pta:
            failed_test_cnt += 1

    assert failed_test_cnt == 0
|
||||
@@ -0,0 +1,351 @@
|
||||
# Copyright 2023 The vLLM team.
|
||||
|
||||
# Copyright (c) Huawei Technologies Co., Ltd. 2024-2025. All rights reserved.
|
||||
# Adapted from
|
||||
# https://github.com/vllm-project/vllm/blob/main/vllm/tests/kernels/test_rotary_embedding.py
|
||||
|
||||
import gc
|
||||
from typing import Optional, Tuple, Union
|
||||
|
||||
import pytest
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
|
||||
from vllm_ascend.utils import enable_custom_op
|
||||
|
||||
enable_custom_op()
|
||||
|
||||
# Only Neox style true scenario is supported for now
IS_NEOX_STYLE = [True]
DTYPES = [torch.half]
# Deduplicated: 64 was listed twice, doubling every head_size=64 run.
HEAD_SIZES = [64, 96, 128, 256]
ROTARY_DIMS = [None, 32]  # None means rotary dim == head size
NUM_HEADS = [17]  # Arbitrary values for testing
BATCH_SIZES = [5]  # Arbitrary values for testing
SEQ_LENS = [11, 4096]  # Arbitrary values for testing
NUM_TOKENS = [10, 21]
SEEDS = [0]
DEVICES = [f"npu:{0}"]
# Tolerances for comparing fused output with the native reference.
DEFAULT_ATOL = 1e-3
DEFAULT_RTOL = 1e-3
|
||||
|
||||
|
||||
def _apply_rotary_emb(
|
||||
x: torch.Tensor,
|
||||
cos: torch.Tensor,
|
||||
sin: torch.Tensor,
|
||||
is_neox_style: bool,
|
||||
) -> torch.Tensor:
|
||||
"""
|
||||
Args:
|
||||
x: [num_tokens, num_heads, head_size]
|
||||
cos: [num_tokens, head_size // 2]
|
||||
sin: [num_tokens, head_size // 2]
|
||||
is_neox_style: Whether to use the Neox-style or GPT-J-style rotary
|
||||
positional embeddings.
|
||||
"""
|
||||
cos = cos.unsqueeze(-2).to(x.dtype)
|
||||
sin = sin.unsqueeze(-2).to(x.dtype)
|
||||
if is_neox_style:
|
||||
x1, x2 = torch.chunk(x, 2, dim=-1)
|
||||
else:
|
||||
x1 = x[..., ::2]
|
||||
x2 = x[..., 1::2]
|
||||
o1 = x1 * cos - x2 * sin
|
||||
o2 = x2 * cos + x1 * sin
|
||||
if is_neox_style:
|
||||
return torch.cat((o1, o2), dim=-1)
|
||||
else:
|
||||
return torch.stack((o1, o2), dim=-1).flatten(-2)
|
||||
|
||||
|
||||
# adapted from https://github.com/vllm-project/vllm/vllm/model_executor/layers/rotary_embedding.py
|
||||
class RotaryEmbedding(nn.Module):
    """Reference (non-fused) rotary positional embedding."""

    def __init__(
        self,
        head_size: int,
        rotary_dim: int,
        max_position_embeddings: int,
        base: int,
        is_neox_style: bool,
        dtype: torch.dtype,
    ) -> None:
        super().__init__()
        self.head_size = head_size
        self.rotary_dim = rotary_dim
        self.max_position_embeddings = max_position_embeddings
        self.base = base
        self.is_neox_style = is_neox_style
        self.dtype = dtype

        # Precompute the [max_position, rotary_dim] cos||sin table once.
        self.cos_sin_cache: torch.Tensor
        self.register_buffer("cos_sin_cache",
                             self._compute_cos_sin_cache().to(dtype),
                             persistent=False)

    def _compute_inv_freq(self, base: Union[int, float]) -> torch.Tensor:
        """Compute the inverse frequency of each rotary channel pair."""
        # NOTE(woosuk): To exactly match the HF implementation, we need to
        # use CPU to compute the cache and then move it to GPU. However, we
        # create the cache on GPU for faster initialization. This may cause
        # a slight numerical difference between the HF implementation and ours.
        exponents = torch.arange(0, self.rotary_dim, 2, dtype=torch.float)
        return 1.0 / (base**(exponents / self.rotary_dim))

    def _compute_cos_sin_cache(self) -> torch.Tensor:
        """Build the [max_position, rotary_dim] table of cos||sin values."""
        inv_freq = self._compute_inv_freq(self.base)
        position_ids = torch.arange(self.max_position_embeddings,
                                    dtype=torch.float)

        angles = torch.einsum("i,j -> ij", position_ids, inv_freq)
        return torch.cat((angles.cos(), angles.sin()), dim=-1)

    def forward_native(
        self,
        positions: torch.Tensor,
        query: torch.Tensor,
        key: torch.Tensor,
        offsets: Optional[torch.Tensor] = None,
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """A PyTorch-native implementation of forward()."""
        if offsets is not None:
            positions = positions + offsets
        flat_positions = positions.flatten()
        num_tokens = flat_positions.shape[0]
        cos, sin = self.cos_sin_cache.index_select(
            0, flat_positions).chunk(2, dim=-1)

        def _rotate(t: torch.Tensor) -> torch.Tensor:
            # Rotate the first rotary_dim channels of each head and keep
            # the remainder untouched.
            original_shape = t.shape
            heads = t.view(num_tokens, -1, self.head_size)
            rot = heads[..., :self.rotary_dim]
            passthrough = heads[..., self.rotary_dim:]
            rot = _apply_rotary_emb(rot, cos, sin, self.is_neox_style)
            return torch.cat((rot, passthrough),
                             dim=-1).reshape(original_shape)

        return _rotate(query), _rotate(key)
|
||||
|
||||
|
||||
# test with leading dimension and merge seqlen and batch_size as num_tokens
|
||||
# test with leading dimension and merge seqlen and batch_size as num_tokens
@pytest.mark.parametrize("is_neox_style", IS_NEOX_STYLE)
@pytest.mark.parametrize("batch_size", BATCH_SIZES)
@pytest.mark.parametrize("seq_len", SEQ_LENS)
@pytest.mark.parametrize("num_heads", NUM_HEADS)
@pytest.mark.parametrize("head_size", HEAD_SIZES)
@pytest.mark.parametrize("rotary_dim", ROTARY_DIMS)
@pytest.mark.parametrize("dtype", DTYPES)
@pytest.mark.parametrize("seed", SEEDS)
@pytest.mark.parametrize("device", DEVICES)
@torch.inference_mode()
def test_rotary_embedding_quant_with_leading_dim(
    is_neox_style: bool,
    batch_size: int,
    seq_len: int,
    num_heads: int,
    head_size: int,
    rotary_dim: Optional[int],
    dtype: torch.dtype,
    seed: int,
    device: str,
    max_position: int = 8192,
    base: int = 10000,
) -> None:
    """Check the fused NPU rotary op against the native reference.

    Query/key keep a flattened leading dimension
    (num_tokens = batch_size * seq_len) to exercise the leading-dim path.
    """
    # Use the parametrized seed (it was accepted but never applied before;
    # the aclgraph test in this file does seed).
    torch.manual_seed(seed)
    torch.set_default_device(device)
    # None means "rotate the full head" (duplicate check removed).
    if rotary_dim is None:
        rotary_dim = head_size
    rope = RotaryEmbedding(head_size, rotary_dim, max_position, base,
                           is_neox_style, dtype)
    rope = rope.to(dtype=dtype)
    num_tokens = batch_size * seq_len
    positions = torch.randint(0, max_position, (batch_size * seq_len, ))
    qkv_tensor = torch.randn(num_tokens,
                             num_heads * head_size * 3,
                             dtype=dtype)
    query, key, _ = qkv_tensor.split(
        [num_heads * head_size, num_heads * head_size, num_heads * head_size],
        dim=-1,
    )

    ref_query, ref_key = rope.forward_native(positions, query, key)
    query, key = torch.ops._C_ascend.rotary_embedding(
        positions,
        query,
        key,
        rope.head_size,
        rope.cos_sin_cache,
        rope.is_neox_style,
    )

    # Compare the fused results with the native reference.
    torch.testing.assert_close(query.view(ref_query.size()),
                               ref_query,
                               atol=DEFAULT_ATOL,
                               rtol=DEFAULT_RTOL)
    torch.testing.assert_close(key.view(ref_key.size()),
                               ref_key,
                               atol=DEFAULT_ATOL,
                               rtol=DEFAULT_RTOL)
    gc.collect()
    torch.npu.empty_cache()
    torch.npu.reset_peak_memory_stats()
|
||||
|
||||
|
||||
class ModelwithRotaryEmbedding(nn.Module):
    """Minimal attention-like module for capturing the fused rotary op.

    Chains qkv projection -> fused rotary embedding -> output projection so
    the custom op can be exercised inside an ACL graph capture.
    """

    def __init__(
        self,
        hidden_size: int,
        num_heads: int,
        head_size: int,
        rotary_dim: int,
        max_position_embeddings: int,
        base: int,
        is_neox_style: bool,
        dtype: torch.dtype,
    ) -> None:
        super().__init__()
        # Fused projection producing query, key and value side by side.
        self.qkv_proj = nn.Linear(hidden_size, num_heads * head_size * 3)
        self.rope = RotaryEmbedding(
            head_size=head_size,
            rotary_dim=rotary_dim,
            max_position_embeddings=max_position_embeddings,
            base=base,
            is_neox_style=is_neox_style,
            dtype=dtype,
        )
        self.o_proj = nn.Linear(num_heads * head_size, hidden_size)

    def forward(
        self,
        positions: torch.Tensor,
        hidden_states: torch.Tensor,
        offsets: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        # we simulated a simple attention layer to test if it can be
        # seamlessly captured into aclgraph
        qkv = self.qkv_proj(hidden_states)
        # NOTE(review): v (and the rotated key) are intentionally unused --
        # no attention score is computed; only the rotary-op path matters
        # for graph capture.
        q, k, v = qkv.chunk(3, dim=-1)
        query, key = torch.ops._C_ascend.rotary_embedding(
            positions,
            q,
            k,
            self.rope.head_size,
            self.rope.cos_sin_cache,
            self.rope.is_neox_style,
        )
        query = query.view(q.shape)
        key = key.view(k.shape)
        o = self.o_proj(query)
        return o
|
||||
|
||||
|
||||
# The first graph seems will have some accuracy issue when directly run pytest on the ops folder,
# add a warmup graph replay for workaround.
# NOTE(review): the name is misspelled ("GRPAH" -> "GRAPH"); renaming it
# also requires updating the `global` reference in the test below.
ACL_GRPAH_FIRST_RUN = True
|
||||
|
||||
|
||||
@pytest.mark.parametrize("is_neox_style", IS_NEOX_STYLE)
@pytest.mark.parametrize("num_tokens", BATCH_SIZES)
@pytest.mark.parametrize("num_heads", NUM_HEADS)
@pytest.mark.parametrize("head_size", HEAD_SIZES)
@pytest.mark.parametrize("rotary_dim", ROTARY_DIMS)
@pytest.mark.parametrize("dtype", DTYPES)
@pytest.mark.parametrize("seed", SEEDS)
@pytest.mark.parametrize("device", DEVICES)
@torch.inference_mode()
def test_capture_rotary_embedding_in_aclgraph(
    is_neox_style: bool,
    num_tokens: int,
    num_heads: int,
    head_size: int,
    rotary_dim: int,
    dtype: torch.dtype,
    seed: int,
    device: str,
    max_position_embeddings: int = 8192,
    base: int = 10000,
):
    """Test if the rotary embedding can be captured in aclgraph."""
    torch.manual_seed(seed)
    torch.set_default_device(device)
    if rotary_dim is None:
        rotary_dim = head_size
    model = ModelwithRotaryEmbedding(
        hidden_size=num_heads * head_size,
        num_heads=num_heads,
        head_size=head_size,
        rotary_dim=rotary_dim,
        max_position_embeddings=max_position_embeddings,
        base=base,
        is_neox_style=is_neox_style,
        dtype=dtype,
    )

    def custom_op_checking_backend(gm: torch.fx.GraphModule, example_input):
        # Validate if the rotary_embedding custom kernel is indeed inside
        # the compiled FX graph by string match.
        graph = str(gm.graph)
        assert "_C_ascend.rotary_embedding" in graph
        return gm

    # Static input buffers: graph replay re-reads fixed memory, so new inputs
    # are written with copy_() below instead of reallocating tensors.
    static_positions = torch.randint(0, max_position_embeddings,
                                     (num_tokens, ))
    static_hidden_states = torch.randn(num_tokens,
                                       num_heads * head_size,
                                       dtype=dtype,
                                       device="npu")
    compiled_model = torch.compile(model, backend=custom_op_checking_backend)
    stream = torch.npu.Stream()
    stream.wait_stream(torch.npu.current_stream())
    with torch.npu.stream(stream):
        # warmup the fx graph before capture
        for i in range(3):
            static_output = compiled_model(static_positions,
                                           static_hidden_states,
                                           offsets=None)
    stream.wait_stream(torch.npu.current_stream())

    aclgraph = torch.npu.NPUGraph()

    with torch.npu.graph(aclgraph):
        # Capture the model in aclgraph.
        static_output = compiled_model(static_positions, static_hidden_states)
    # Refill the static input buffers with fresh random data, then replay.
    random_filled_positions = torch.randint(0,
                                            max_position_embeddings,
                                            (num_tokens, ),
                                            device="npu")
    random_filled_hidden_states = torch.randn(num_tokens,
                                              num_heads * head_size,
                                              dtype=dtype,
                                              device="npu")
    static_positions.copy_(random_filled_positions)
    static_hidden_states.copy_(random_filled_hidden_states)

    aclgraph.replay()
    # Skip the accuracy comparison on the very first replay (see the
    # module-level comment next to ACL_GRPAH_FIRST_RUN).
    global ACL_GRPAH_FIRST_RUN
    if ACL_GRPAH_FIRST_RUN:
        ACL_GRPAH_FIRST_RUN = False
        return
    output_reference = model(static_positions, static_hidden_states)
    torch.testing.assert_close(static_output,
                               output_reference,
                               atol=DEFAULT_ATOL,
                               rtol=DEFAULT_RTOL)
    gc.collect()
    torch.npu.empty_cache()
    torch.npu.reset_peak_memory_stats()
|
||||
@@ -0,0 +1,98 @@
|
||||
import gc
|
||||
from typing import Tuple
|
||||
|
||||
import pytest
|
||||
import torch
|
||||
import torch_npu # noqa: F401
|
||||
|
||||
import vllm_ascend.platform # noqa: F401
|
||||
from vllm_ascend.utils import enable_custom_op
|
||||
|
||||
enable_custom_op()
|
||||
|
||||
# Test parameters for the get_masked_input_and_mask test.
DTYPES = [torch.int32]
SHAPES = [(3, 4, 3)]  # tensor shapes under test
DEVICES = [f"npu:{0}"]
SEEDS = [0]
|
||||
|
||||
|
||||
def get_masked_input_and_mask_ref(
        input_: torch.Tensor, org_vocab_start_index: int,
        org_vocab_end_index: int, num_org_vocab_padding: int,
        added_vocab_start_index: int,
        added_vocab_end_index: int) -> Tuple[torch.Tensor, torch.Tensor]:
    """Reference implementation for verification.

    Remaps ids from the original vocab range and the added vocab range onto
    one contiguous index space; ids outside both ranges map to 0. Returns
    the remapped ids and an invalid-mask (True where the id was out of range).
    """
    in_org = (input_ >= org_vocab_start_index) & (input_ < org_vocab_end_index)
    in_added = (input_ >= added_vocab_start_index) & (input_ <
                                                      added_vocab_end_index)
    # Added-vocab ids land right after the (padded) original vocab block.
    org_size = org_vocab_end_index - org_vocab_start_index
    added_shift = added_vocab_start_index - org_size - num_org_vocab_padding
    shift = (org_vocab_start_index * in_org) + (added_shift * in_added)
    valid = in_org | in_added
    remapped = valid * (input_ - shift)
    return remapped, ~valid
|
||||
|
||||
|
||||
@pytest.mark.parametrize("shape", SHAPES)
@pytest.mark.parametrize("dtype", DTYPES)
@pytest.mark.parametrize("device", DEVICES)
@pytest.mark.parametrize("seed", SEEDS)
@torch.inference_mode()
def test_get_masked_input_and_mask(
    shape: Tuple[int, ...],
    dtype: torch.dtype,
    device: str,
    seed: int,
) -> None:
    """Compare the custom get_masked_input_and_mask op with the reference.

    Fix: removed leftover debug print() statements that spammed test output.
    """
    # Set random seed
    torch.manual_seed(seed)
    torch.set_default_device(device)

    # Generate random input tensor
    input_tensor = torch.randint(0, 1000, shape, dtype=dtype)

    # Vocab partition: original ids in [100, 200), added ids in [300, 400).
    test_case = {
        "org_start": 100,
        "org_end": 200,
        "padding": 0,
        "added_start": 300,
        "added_end": 400,
    }

    # Get reference result
    ref_masked_input, ref_mask = get_masked_input_and_mask_ref(
        input_tensor, test_case["org_start"], test_case["org_end"],
        test_case["padding"], test_case["added_start"], test_case["added_end"])

    # Get custom op result
    custom_masked_input, custom_mask = torch.ops._C_ascend.get_masked_input_and_mask(
        input_tensor, test_case["org_start"], test_case["org_end"],
        test_case["padding"], test_case["added_start"], test_case["added_end"])

    # Align dtypes before comparing (the reference result may differ).
    ref_masked_input = ref_masked_input.to(dtype)
    # Compare results
    torch.testing.assert_close(
        custom_masked_input,
        ref_masked_input,
        rtol=1e-5,
        atol=1e-5,
        msg=f"Masked input mismatch for case: {test_case}")
    torch.testing.assert_close(custom_mask,
                               ref_mask,
                               rtol=1e-5,
                               atol=1e-5,
                               msg=f"Mask mismatch for case: {test_case}")
    gc.collect()
    torch.npu.empty_cache()
    torch.npu.reset_peak_memory_stats()
|
||||
@@ -0,0 +1,361 @@
|
||||
from typing import Optional
|
||||
|
||||
import pytest
|
||||
import torch
|
||||
import torch.nn.functional as F
|
||||
|
||||
from vllm_ascend.ops.triton.mamba.causal_conv1d import (PAD_SLOT_ID,
|
||||
causal_conv1d_fn)
|
||||
from vllm_ascend.ops.triton.mamba.causal_conv1d import \
|
||||
causal_conv1d_update_npu as causal_conv1d_update
|
||||
|
||||
|
||||
def validate_cmp(y_cal, y_ref, dtype, device='npu'):
|
||||
y_cal = y_cal.to(device)
|
||||
y_ref = y_ref.to(device)
|
||||
if dtype == torch.float16:
|
||||
torch.testing.assert_close(y_ref,
|
||||
y_cal,
|
||||
rtol=3e-03,
|
||||
atol=1e-02,
|
||||
equal_nan=True)
|
||||
elif dtype == torch.bfloat16:
|
||||
torch.testing.assert_close(y_ref,
|
||||
y_cal,
|
||||
rtol=1e-02,
|
||||
atol=1e-02,
|
||||
equal_nan=True)
|
||||
elif dtype == torch.float32:
|
||||
torch.testing.assert_close(y_ref,
|
||||
y_cal,
|
||||
rtol=1e-03,
|
||||
atol=4e-03,
|
||||
equal_nan=True)
|
||||
elif dtype == torch.int32 or dtype == torch.int64 or dtype == torch.int16 or dtype == torch.int8 or dtype == torch.uint32:
|
||||
assert torch.equal(y_cal, y_ref)
|
||||
elif dtype == torch.bool:
|
||||
assert torch.equal(y_cal, y_ref)
|
||||
else:
|
||||
raise ValueError(
|
||||
'Invalid parameter \"dtype\" is found : {}'.format(dtype))
|
||||
|
||||
|
||||
def causal_conv1d_ref(
    x: torch.Tensor,
    weight: torch.Tensor,
    bias: Optional[torch.Tensor] = None,
    initial_states: Optional[torch.Tensor] = None,
    return_final_states: bool = False,
    final_states_out: Optional[torch.Tensor] = None,
    activation: Optional[str] = "silu",
):
    """Reference depthwise causal 1-D convolution.

    x: (batch, dim, seqlen)
    weight: (dim, width)
    bias: (dim,)
    initial_states: (batch, dim, width - 1)
    final_states_out: (batch, dim, width - 1), written in place if given
    out: (batch, dim, seqlen)
    """
    if activation not in (None, "silu", "swish"):
        raise NotImplementedError("activation must be None, silu, or swish")
    in_dtype = x.dtype
    x = x.to(weight.dtype)
    seqlen = x.shape[-1]
    dim, width = weight.shape
    kernel = weight.unsqueeze(1)  # (dim, 1, width) for grouped conv

    if initial_states is None:
        # Zero-pad on the left so each output position sees only the past.
        raw = F.conv1d(x, kernel, bias, padding=width - 1, groups=dim)
    else:
        x = torch.cat([initial_states, x], dim=-1)
        raw = F.conv1d(x, kernel, bias, padding=0, groups=dim)
    raw = raw[..., :seqlen]

    if return_final_states:
        # The last (width - 1) inputs form the next segment's initial state.
        # A negative pad amount crops; a positive one zero-fills on the left.
        tail = F.pad(x, (width - 1 - x.shape[-1], 0)).to(
            in_dtype)  # (batch, dim, width - 1)
        if final_states_out is None:
            final_states_out = tail
        else:
            final_states_out.copy_(tail)
    result = raw if activation is None else F.silu(raw)
    result = result.to(dtype=in_dtype)
    if return_final_states:
        return result, final_states_out
    return result, None
|
||||
|
||||
|
||||
def causal_conv1d_fn_pytorch(
    x: torch.Tensor,
    weight: torch.Tensor,
    query_start_loc: torch.Tensor,
    cache_indices: torch.Tensor,
    has_initial_state: torch.Tensor,
    conv_states: torch.Tensor,
    bias: Optional[torch.Tensor] = None,
    activation: Optional[str] = "silu",
    pad_slot_id: int = PAD_SLOT_ID,
):
    """
    x: (batch, dim, seqlen) or (dim,cu_seq_len) for varlen
        sequences are concatenated from left to right for varlen
    weight: (dim, width)
    bias: (dim,)
    query_start_loc: (batch + 1) int32
        The cumulative sequence lengths of the sequences in
        the batch, used to index into sequence. prepended by 0.
        for example: query_start_loc = torch.Tensor([0,10,16,17]),
        x.shape=(dim,17)
    cache_indices: (batch) int32
        indicates the corresponding state index,
        like so: conv_state = conv_states[cache_indices[batch_id]]
    has_initial_state: (batch) bool
        indicates whether should the kernel take the current state as initial
        state for the calculations
    conv_states: (...,dim,width - 1) itype
        updated inplace if provided
    activation: either None or "silu" or "swish"
    pad_slot_id: int
        if cache_indices is passed, lets the kernel identify padded
        entries that will not be processed,
        for example: cache_indices = [pad_slot_id, 1, 20, pad_slot_id]
        in this case, the kernel will not process entries at
        indices 0 and 3
    out: (batch, dim, seqlen)
    """
    if activation not in [None, "silu", "swish"]:
        raise NotImplementedError("activation must be None, silu, or swish")
    if x.stride(-1) != 1:
        x = x.contiguous()
    bias = bias.contiguous() if bias is not None else None

    out_ref = []
    out_ref_b = []
    # Recover per-sequence lengths from the cumulative offsets, then split
    # the packed varlen input back into individual sequences.
    seqlens = query_start_loc[1:] - query_start_loc[:-1]
    seqlens = seqlens.tolist()
    splits = torch.split(x, seqlens, dim=-1)
    width = weight.shape[1]

    for i in range(len(seqlens)):
        x_s = splits[i]
        if cache_indices[i] == PAD_SLOT_ID:
            # Padded entry: skip it, leaving its cache line untouched.
            continue
        # Run the single-sequence reference; the per-sequence final state is
        # written in place into the matching conv_states cache line.
        out_ref_b.append(
            causal_conv1d_ref(
                x_s,
                weight,
                bias,
                activation=activation,
                return_final_states=True,
                final_states_out=conv_states[cache_indices[i]][..., :(
                    width - 1)].unsqueeze(0),
                initial_states=conv_states[cache_indices[i]][..., :(width - 1)]
                if has_initial_state[i] else None))
    # NOTE(review): indentation reconstructed — this append is taken to run
    # once after the loop, re-concatenating all per-sequence outputs along
    # the token axis; confirm against the upstream vLLM test it mirrors.
    out_ref.append(torch.cat([t[0] for t in out_ref_b], dim=-1))
    out_ref_tensor = torch.cat(out_ref, dim=0)
    return out_ref_tensor
|
||||
|
||||
|
||||
@pytest.mark.parametrize('has_initial_state', [False, True])
@pytest.mark.parametrize('itype', [torch.bfloat16])
@pytest.mark.parametrize('silu_activation', [True])
@pytest.mark.parametrize('has_bias', [True])
@pytest.mark.parametrize('seq_len', [[128, 1024, 2048, 4096]])
@pytest.mark.parametrize('extra_state_len', [0, 2])
@pytest.mark.parametrize('width', [2, 4])
@pytest.mark.parametrize('dim', [4160])
def test_causal_conv1d(dim, width, extra_state_len, seq_len, has_bias,
                       silu_activation, itype, has_initial_state):
    # Compare the fused varlen causal_conv1d_fn kernel against the
    # pure-PyTorch reference on packed variable-length sequences, checking
    # both the outputs and the in-place-updated conv states.

    torch.random.manual_seed(0)

    device = "npu"
    cu_seqlen, num_seq = sum(seq_len), len(seq_len)
    # Cache lines may be longer than the kernel strictly needs
    # (extra_state_len columns beyond width - 1).
    state_len = width - 1 + extra_state_len

    # Packed input laid out as (dim, total_tokens), contiguous along tokens.
    x = torch.randn(cu_seqlen, dim, device=device, dtype=itype).transpose(0, 1)
    weight = torch.randn(dim, width, device=device, dtype=itype)
    # Cumulative sequence offsets, prepended by 0.
    query_start_loc = torch.cumsum(torch.tensor([0] + seq_len,
                                                device=device,
                                                dtype=torch.int32),
                                   dim=0)
    cache_indices = torch.arange(num_seq, device=device, dtype=torch.int32)
    has_initial_state_tensor = torch.tensor([has_initial_state] * num_seq,
                                            device=device,
                                            dtype=torch.bool)
    activation = None if not silu_activation else "silu"

    if has_initial_state:
        conv_states = torch.randn((num_seq, state_len, dim),
                                  device=device,
                                  dtype=itype).transpose(-1, -2)
        # Reference states start as an identical copy so the in-place updates
        # made by each implementation can be compared afterwards.
        conv_states_ref = torch.randn(
            (num_seq, state_len, dim), device=device,
            dtype=itype).transpose(-1, -2).copy_(conv_states)
    else:
        conv_states = torch.zeros((num_seq, state_len, dim),
                                  device=device,
                                  dtype=itype).transpose(-1, -2)
        conv_states_ref = torch.zeros((num_seq, state_len, dim),
                                      device=device,
                                      dtype=itype).transpose(-1, -2)

    if has_bias:
        bias = torch.randn(dim, device=device, dtype=itype)
    else:
        bias = None

    out_ref = causal_conv1d_fn_pytorch(
        x,
        weight,
        bias=bias,
        activation=activation,
        conv_states=conv_states_ref,
        has_initial_state=has_initial_state_tensor,
        cache_indices=cache_indices,
        query_start_loc=query_start_loc)
    out = causal_conv1d_fn(x,
                           weight,
                           bias=bias,
                           activation=activation,
                           conv_states=conv_states,
                           has_initial_state=has_initial_state_tensor,
                           cache_indices=cache_indices,
                           query_start_loc=query_start_loc)

    # Outputs and updated conv states must both agree with the reference.
    validate_cmp(out, out_ref, itype)
    validate_cmp(conv_states, conv_states_ref, itype)
|
||||
|
||||
|
||||
def causal_conv1d_update_ref(x,
                             conv_state,
                             weight,
                             bias=None,
                             activation=None,
                             cache_seqlens=None):
    """Reference single-step (decode) causal conv1d update.

    x: (batch, dim) or (batch, dim, seqlen)
    conv_state: (batch, dim, state_len), where state_len >= width - 1;
        updated in place.
    weight: (dim, width)
    bias: (dim,)
    cache_seqlens: (batch,), dtype int32.
        If not None, the conv_state is treated as a circular buffer.
        The conv_state will be updated by copying x to the conv_state
        starting at the index @cache_seqlens % state_len before
        performing the convolution.
    out: (batch, dim) or (batch, dim, seqlen)
    """
    if activation not in [None, "silu", "swish"]:
        raise NotImplementedError("activation must be None, silu, or swish")
    orig_dtype = x.dtype
    squeeze_out = x.dim() == 2
    if squeeze_out:
        x = x.unsqueeze(-1)
    batch, dim, seqlen = x.shape
    width = weight.shape[1]
    state_len = conv_state.shape[-1]
    assert conv_state.shape == (batch, dim, state_len)
    assert weight.shape == (dim, width)

    if cache_seqlens is None:
        # Linear buffer: prepend the cached state and keep the newest
        # state_len columns as the updated state.
        window = torch.cat([conv_state, x], dim=-1).to(
            weight.dtype)  # (batch, dim, state_len + seqlen)
        conv_state.copy_(window[:, :, -state_len:])
    else:
        # Circular buffer: gather the (width - 1) most recent columns, then
        # scatter the new inputs back in at cache_seqlens % state_len.
        gather_cols = torch.arange(
            -(width - 1), 0, dtype=torch.long,
            device=x.device).unsqueeze(0) + cache_seqlens.unsqueeze(1)
        gather_cols = (torch.remainder(gather_cols,
                                       state_len).unsqueeze(1).expand(
                                           -1, dim, -1))
        window = torch.cat([conv_state.gather(2, gather_cols), x],
                           dim=-1).to(weight.dtype)
        scatter_cols = torch.arange(
            seqlen, dtype=torch.long,
            device=x.device).unsqueeze(0) + cache_seqlens.unsqueeze(1)
        scatter_cols = torch.remainder(scatter_cols,
                                       state_len).unsqueeze(1).expand(
                                           -1, dim, -1)
        conv_state.scatter_(2, scatter_cols, x)

    result = F.conv1d(window, weight.unsqueeze(1), bias, padding=0,
                      groups=dim)[:, :, -seqlen:]
    if squeeze_out:
        result = result.squeeze(-1)
    if activation is not None:
        result = F.silu(result)
    return result.to(dtype=orig_dtype)
|
||||
|
||||
|
||||
@pytest.mark.parametrize("itype", [torch.bfloat16])
@pytest.mark.parametrize("silu_activation", [True])
@pytest.mark.parametrize("has_bias", [False, True])
@pytest.mark.parametrize("seqlen", [1, 3])
@pytest.mark.parametrize("width", [3, 4])
@pytest.mark.parametrize("dim", [2048 + 16, 4096])
# tests correctness in case subset of the sequences are padded
@pytest.mark.parametrize("with_padding", [True, False])
@pytest.mark.parametrize("batch_size", [3, 64])
def test_causal_conv1d_update_with_batch_gather(batch_size, with_padding, dim,
                                                width, seqlen, has_bias,
                                                silu_activation, itype):
    # Decode-path update where conv states are gathered from scattered cache
    # lines via conv_state_indices; entries marked PAD_SLOT_ID must be
    # skipped and their cache lines left untouched.
    device = "npu"
    rtol, atol = (3e-4, 1e-3) if itype == torch.float32 else (3e-3, 5e-3)
    if itype == torch.bfloat16:
        rtol, atol = 1e-2, 5e-2

    padding = 5 if with_padding else 0
    padded_batch_size = batch_size + padding
    # total_entries = number of cache line
    total_entries = 10 * batch_size

    # x will be (batch, dim, seqlen) with contiguous along dim-axis
    x = torch.randn(padded_batch_size, seqlen, dim, device=device,
                    dtype=itype).transpose(1, 2)

    x_ref = x.clone()

    # Pick batch_size distinct cache lines at random; the rest stay unused
    # and must not be modified by the kernel.
    conv_state_indices = torch.randperm(total_entries)[:batch_size].to(
        dtype=torch.int32, device=device)
    unused_states_bool = torch.ones(total_entries,
                                    dtype=torch.bool,
                                    device=device)
    unused_states_bool[conv_state_indices] = False
    # Append PAD_SLOT_ID entries to simulate padded batch slots.
    padded_state_indices = torch.concat(
        [
            conv_state_indices,
            torch.as_tensor(
                [PAD_SLOT_ID] * padding, dtype=torch.int32, device=device),
        ],
        dim=0,
    )

    # conv_state will be (cache_lines, dim, state_len)
    # with contiguous along dim-axis
    conv_state = torch.randn(total_entries,
                             width - 1,
                             dim,
                             device=device,
                             dtype=itype).transpose(1, 2)

    # Snapshot so untouched cache lines can be verified afterwards.
    conv_state_for_padding_test = conv_state.clone()

    weight = torch.randn(dim, width, device=device, dtype=itype)
    bias = torch.randn(dim, device=device, dtype=itype) if has_bias else None
    conv_state_ref = conv_state[conv_state_indices, :].detach().clone()
    activation = None if not silu_activation else "silu"

    out = causal_conv1d_update(
        x,
        conv_state,
        weight,
        bias,
        activation=activation,
        conv_state_indices=padded_state_indices,
        pad_slot_id=PAD_SLOT_ID,
    )
    out_ref = causal_conv1d_update_ref(x_ref[:batch_size],
                                       conv_state_ref,
                                       weight,
                                       bias,
                                       activation=activation)

    # Gathered states must match the reference; unused cache lines must be
    # byte-identical to the snapshot taken before the kernel ran.
    assert torch.equal(conv_state[conv_state_indices, :], conv_state_ref)
    assert torch.equal(conv_state[unused_states_bool],
                       conv_state_for_padding_test[unused_states_bool])
    assert torch.allclose(out[:batch_size], out_ref, rtol=rtol, atol=atol)
|
||||
@@ -0,0 +1,34 @@
|
||||
import pytest
|
||||
import torch
|
||||
import torch.nn.functional as F
|
||||
|
||||
from vllm_ascend.ops.triton.fla.l2norm import l2norm_fwd
|
||||
from vllm_ascend.ops.triton.triton_utils import init_device_properties_triton
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
    ('B', 'T', 'H', 'D', 'dtype'),
    [
        pytest.param(*test, id="B{}-T{}-H{}-D{}-{}".format(*test))
        for test in [
            (1, 63, 1, 60, torch.float),
            (2, 500, 4, 64, torch.float),
            (2, 1000, 2, 100, torch.float),
            (3, 1024, 4, 128, torch.float),
        ]
    ],
)
def test_l2norm(B: int, T: int, H: int, D: int, dtype: torch.dtype):
    # Compare the Triton l2norm_fwd kernel against F.normalize (L2 norm
    # along the last dimension).
    torch.manual_seed(42)
    init_device_properties_triton()
    device = "npu"
    # Tolerances loosened for lower-precision dtypes.
    rtol, atol = (3e-4, 1e-3) if dtype == torch.float32 else (3e-3, 5e-3)
    if dtype == torch.bfloat16:
        rtol, atol = 1e-2, 5e-2
    x = torch.randn(B, T, H, D, dtype=dtype).to(device).requires_grad_(True)
    # Affine shift of the random input; presumably to move values away from
    # zero and avoid tiny norms — TODO confirm.
    x = x * 0.5 + 0.3

    ref = F.normalize(x, dim=-1, p=2)
    tri = l2norm_fwd(x)

    assert torch.allclose(tri, ref, rtol=rtol, atol=atol)
|
||||
@@ -0,0 +1,141 @@
|
||||
import gc
|
||||
|
||||
import pytest
|
||||
import torch
|
||||
|
||||
from vllm_ascend.ops.triton.rope import rope_forward_triton
|
||||
from vllm_ascend.ops.triton.triton_utils import init_device_properties_triton
|
||||
|
||||
# Parametrization space for the rope Triton-kernel test below.
IS_NEOX_STYLE = [True, False]
DTYPES = [torch.bfloat16, torch.float16]
HEAD_SIZES = [64, 128]
ROTARY_DIMS = [32, 64]
NUM_Q_HEADS = [64]
NUM_K_HEADS = [1]
NUM_TOKENS = [1, 4, 8, 16, 1024]
SEEDS = [0]
DEVICES = [f"npu:{0}"]
# Comparison tolerances for torch.testing.assert_close.
DEFAULT_ATOL = 1e-3
DEFAULT_RTOL = 1e-3
|
||||
|
||||
|
||||
def rotate_neox(x: torch.Tensor) -> torch.Tensor:
    """Rotate-half (NeoX style): split the last dim in two halves (a, b)
    and return (-b, a)."""
    half = x.shape[-1] // 2
    first, second = x[..., :half], x[..., half:]
    return torch.cat((-second, first), dim=-1)
|
||||
|
||||
|
||||
def rotate_gptj(x: torch.Tensor) -> torch.Tensor:
    """Rotate interleaved pairs (GPT-J style): each pair (x0, x1) along the
    last dim becomes (-x1, x0)."""
    even = x[..., ::2]
    odd = x[..., 1::2]
    return torch.stack((-odd, even), dim=-1).flatten(-2)
|
||||
|
||||
|
||||
def _rope_pytorch_native(
        query, key, cos, sin, rope_dim,
        is_neox_style) -> tuple[torch.Tensor, torch.Tensor | None]:
    """PyTorch-native implementation equivalent to forward().

    Applies rotary embedding to the first rope_dim features of query/key in
    float32, passing any remaining features through unchanged, and casts the
    result back to the input dtype.
    """
    assert key is not None
    orig_dtype = query.dtype
    head_size = query.shape[-1]
    q_rot = query[..., :rope_dim].to(torch.float32)
    k_rot = key[..., :rope_dim].to(torch.float32)
    partial = rope_dim < head_size
    if partial:
        q_pass = query[..., rope_dim:]
        k_pass = key[..., rope_dim:]

    # Expand the half-dim cos/sin tables to rope_dim: NeoX duplicates the
    # halves; GPT-J interleaves them.
    if is_neox_style:
        cos_f = cos.repeat(1, 2).unsqueeze(-2).to(torch.float32)
        sin_f = sin.repeat(1, 2).unsqueeze(-2).to(torch.float32)
    else:
        cos_f = cos.repeat_interleave(2, dim=-1).unsqueeze(-2).to(
            torch.float32)
        sin_f = sin.repeat_interleave(2, dim=-1).unsqueeze(-2).to(
            torch.float32)

    rotate = rotate_neox if is_neox_style else rotate_gptj
    q_rot = q_rot * cos_f + rotate(q_rot) * sin_f
    k_rot = k_rot * cos_f + rotate(k_rot) * sin_f

    if partial:
        return (torch.cat((q_rot.to(orig_dtype), q_pass), dim=-1),
                torch.cat((k_rot.to(orig_dtype), k_pass), dim=-1))
    return q_rot.to(orig_dtype), k_rot.to(orig_dtype)
|
||||
|
||||
|
||||
@pytest.mark.parametrize("is_neox_style", IS_NEOX_STYLE)
@pytest.mark.parametrize("num_tokens", NUM_TOKENS)
@pytest.mark.parametrize("num_q_heads", NUM_Q_HEADS)
@pytest.mark.parametrize("num_k_heads", NUM_K_HEADS)
@pytest.mark.parametrize("head_size", HEAD_SIZES)
@pytest.mark.parametrize("rotary_dim", ROTARY_DIMS)
@pytest.mark.parametrize("dtype", DTYPES)
@pytest.mark.parametrize("seed", SEEDS)
@pytest.mark.parametrize("device", DEVICES)
@torch.inference_mode()
def test_rotary_embedding_triton_kernel(
    is_neox_style: bool,
    num_tokens: int,
    num_q_heads: int,
    num_k_heads: int,
    head_size: int,
    rotary_dim: int,
    dtype: torch.dtype,
    seed: int,
    device: str,
) -> None:
    # Compare the Triton rope kernel against the PyTorch-native reference
    # on identical random q/k inputs.
    torch.manual_seed(seed)
    torch.set_default_device(device)
    init_device_properties_triton()
    if rotary_dim == -1:
        rotary_dim = head_size
    # Half-dim cos/sin tables; the reference expands them to rotary_dim.
    sin = torch.randn(num_tokens, rotary_dim // 2, dtype=dtype, device=device)
    cos = torch.randn(num_tokens, rotary_dim // 2, dtype=dtype, device=device)
    q_trt = torch.randn(num_tokens,
                        num_q_heads,
                        head_size,
                        dtype=dtype,
                        device=device)
    k_trt = torch.randn(num_tokens,
                        num_k_heads,
                        head_size,
                        dtype=dtype,
                        device=device)
    q_gold = torch.randn(num_tokens,
                         num_q_heads,
                         head_size,
                         dtype=dtype,
                         device=device)
    k_gold = torch.randn(num_tokens,
                         num_k_heads,
                         head_size,
                         dtype=dtype,
                         device=device)
    # Make both paths operate on identical copies of q/k.
    q_trt.copy_(q_gold)
    k_trt.copy_(k_gold)
    q_trt, k_trt = rope_forward_triton(q_trt,
                                       k_trt,
                                       cos,
                                       sin,
                                       rope_dim=rotary_dim,
                                       is_neox_style=is_neox_style)
    q_gold, k_gold = _rope_pytorch_native(q_gold,
                                          k_gold,
                                          cos,
                                          sin,
                                          rope_dim=rotary_dim,
                                          is_neox_style=is_neox_style)
    # Compare the results.
    torch.testing.assert_close(q_trt.view(q_gold.size()),
                               q_gold,
                               atol=DEFAULT_ATOL,
                               rtol=DEFAULT_RTOL)
    torch.testing.assert_close(k_trt.view(k_gold.size()),
                               k_gold,
                               atol=DEFAULT_ATOL,
                               rtol=DEFAULT_RTOL)
    gc.collect()
    torch.npu.empty_cache()
    torch.npu.reset_peak_memory_stats()
|
||||
Reference in New Issue
Block a user