[Misc][Test] add e2e test for apply_top_k_top_p_custom kernel (#6348)
### What this PR does / why we need it?
Add e2e test case for apply_top_k_top_p_custom kernel and eliminate
chinese comments.
### Does this PR introduce _any_ user-facing change?
No.
### How was this patch tested?
pytest passed.
- vLLM version: v0.14.1
- vLLM main:
dc917cceb8
---------
Signed-off-by: linfeng-yuan <1102311262@qq.com>
This commit is contained in:
@@ -38,7 +38,7 @@ static const int64_t EXPECTED_DIM_ONE = 1;
|
||||
static const int64_t EXPECTED_DIM_TWO = 2;
|
||||
static constexpr size_t DIM_ONE = 1;
|
||||
|
||||
// 根据API定义,需要列出所能支持的所有dtype
|
||||
// According to the API definition, all supported dtypes must be enumerated.
|
||||
static const std::initializer_list<op::DataType> DTYPE_SUPPORT_LIST = {
|
||||
op::DataType::DT_FLOAT, op::DataType::DT_FLOAT16, op::DataType::DT_BF16};
|
||||
|
||||
@@ -57,7 +57,7 @@ static bool CheckNotNull(const aclTensor* logits, const aclTensor* p, const aclT
|
||||
|
||||
static bool CheckDtypeValid(const aclTensor* logits, const aclTensor* p, const aclTensor *k, const aclTensor* out)
|
||||
{
|
||||
// 检查数据类型是否在支持列表内
|
||||
// Check whether the data type is within the supported list.
|
||||
OP_CHECK_DTYPE_NOT_SUPPORT(logits, DTYPE_SUPPORT_LIST, return false);
|
||||
if (p != nullptr) {
|
||||
OP_CHECK_DTYPE_NOT_SUPPORT(p, DTYPE_SUPPORT_LIST, return false);
|
||||
@@ -67,7 +67,7 @@ static bool CheckDtypeValid(const aclTensor* logits, const aclTensor* p, const a
|
||||
}
|
||||
OP_CHECK_DTYPE_NOT_SUPPORT(out, DTYPE_SUPPORT_LIST, return false);
|
||||
|
||||
// 检查数据类型是否相同
|
||||
// Check whether the data types are identical.
|
||||
if (p != nullptr) {
|
||||
OP_CHECK_DTYPE_NOT_MATCH(p, logits->GetDataType(), return false);
|
||||
}
|
||||
@@ -121,17 +121,17 @@ static bool CheckFormatValid(const aclTensor* logits, const aclTensor* p, const
|
||||
|
||||
static aclnnStatus CheckParams(const aclTensor* logits, const aclTensor* p, const aclTensor *k, const aclTensor* out)
|
||||
{
|
||||
// 错误码等DFX方案细化后刷新,错误日志在check接口内打印
|
||||
// 1. 检查参数是否为空指针
|
||||
// Refresh after the DFX scheme for error codes, etc. is refined; error logs are printed inside the check interfaces.
|
||||
// 1. Check whether any parameters are null pointers.
|
||||
CHECK_RET(CheckNotNull(logits, p, k, out), ACLNN_ERR_PARAM_NULLPTR);
|
||||
|
||||
// 2. 检查输入的数据类型是否在API支持的数据类型范围之内,需要根据api定义校验
|
||||
// 2. Check whether the input data types are within the range supported by the API; validate according to the API definition.
|
||||
CHECK_RET(CheckDtypeValid(logits, p, k, out), ACLNN_ERR_PARAM_INVALID);
|
||||
|
||||
// 3. 检查shape是否满足约束
|
||||
// 3. Check whether shapes satisfy the constraints.
|
||||
CHECK_RET(CheckShapeValid(logits, p, k, out), ACLNN_ERR_PARAM_INVALID);
|
||||
|
||||
// 4. 检查format是否满足约束
|
||||
// 4. Check whether formats satisfy the constraints.
|
||||
CHECK_RET(CheckFormatValid(logits, p, k, out), ACLNN_ERR_PARAM_INVALID);
|
||||
|
||||
return ACLNN_SUCCESS;
|
||||
@@ -144,11 +144,11 @@ aclnnStatus aclnnApplyTopKTopPCustomGetWorkspaceSize(
|
||||
{
|
||||
OP_CHECK_COMM_INPUT(workspaceSize, executor);
|
||||
L2_DFX_PHASE_1(aclnnApplyTopKTopPCustom, DFX_IN(logits, p, k), DFX_OUT(out));
|
||||
// 固定写法,创建OpExecutor
|
||||
// Fixed boilerplate: create OpExecutor.
|
||||
auto uniqueExecutor = CREATE_EXECUTOR();
|
||||
CHECK_RET(uniqueExecutor.get() != nullptr, ACLNN_ERR_INNER_CREATE_EXECUTOR);
|
||||
|
||||
// 固定写法,参数检查
|
||||
// Fixed boilerplate: parameter validation.
|
||||
auto ret = CheckParams(logits, p, k, out);
|
||||
CHECK_RET(ret == ACLNN_SUCCESS, ret);
|
||||
bool pIsEmpty = false;
|
||||
@@ -160,12 +160,12 @@ aclnnStatus aclnnApplyTopKTopPCustomGetWorkspaceSize(
|
||||
kIsEmpty = k->IsEmpty();
|
||||
}
|
||||
if (logits->IsEmpty() || pIsEmpty || kIsEmpty) {
|
||||
// 根据实际支持情况补充
|
||||
// Supplement according to actual support status.
|
||||
*workspaceSize = 0;
|
||||
uniqueExecutor.ReleaseTo(executor);
|
||||
return ACLNN_SUCCESS;
|
||||
}
|
||||
// 固定写法,将输入selfRef转换成连续的tensor
|
||||
// Fixed boilerplate: convert the input selfRef to a contiguous tensor.
|
||||
auto logitsContiguous = l0op::Contiguous(logits, uniqueExecutor.get());
|
||||
CHECK_RET(logitsContiguous != nullptr, ACLNN_ERR_INNER_NULLPTR);
|
||||
const aclTensor* pContiguous = nullptr;
|
||||
@@ -190,24 +190,23 @@ aclnnStatus aclnnApplyTopKTopPCustomGetWorkspaceSize(
|
||||
CHECK_RET(sortedIndices != nullptr, ACLNN_ERR_INNER_NULLPTR);
|
||||
auto res = l0op::ApplyTopKTopPCustom(sortedValue, sortedIndices, pContiguous, kContiguous, uniqueExecutor.get());
|
||||
CHECK_RET(res != nullptr, ACLNN_ERR_INNER_NULLPTR);
|
||||
// 固定写法,将计算结果拷贝到输出out上,out可能是非连续的tensor
|
||||
// Fixed boilerplate: copy the computed result to the output 'out'; 'out' may be a non-contiguous tensor.
|
||||
viewCopyResult = l0op::ViewCopy(res, out, uniqueExecutor.get());
|
||||
}
|
||||
CHECK_RET(viewCopyResult != nullptr, ACLNN_ERR_INNER_NULLPTR);
|
||||
// 固定写法,获取计算过程中需要使用的workspace大小
|
||||
// Fixed boilerplate: obtain the workspace size required during computation.
|
||||
*workspaceSize = uniqueExecutor->GetWorkspaceSize();
|
||||
uniqueExecutor.ReleaseTo(executor); // 需要把 uniqueExecutor持有executor转移给executor
|
||||
uniqueExecutor.ReleaseTo(executor); // Transfer ownership of the executor held by uniqueExecutor to executor.
|
||||
return ACLNN_SUCCESS;
|
||||
}
|
||||
|
||||
aclnnStatus aclnnApplyTopKTopPCustom(void* workspace, uint64_t workspaceSize, aclOpExecutor* executor, aclrtStream stream)
|
||||
{
|
||||
L2_DFX_PHASE_2(aclnnApplyTopKTopPCustom);
|
||||
// 固定写法,调用框架能力,完成计算
|
||||
// Fixed boilerplate: invoke framework capability to complete the computation.
|
||||
return CommonOpExecutorRun(workspace, workspaceSize, executor, stream);
|
||||
}
|
||||
|
||||
#ifdef __cplusplus
|
||||
}
|
||||
#endif
|
||||
|
||||
|
||||
@@ -0,0 +1,139 @@
|
||||
import numpy as np
|
||||
import pytest
|
||||
import torch
|
||||
from vllm_ascend.utils import enable_custom_op
|
||||
|
||||
enable_custom_op()
|
||||
|
||||
|
||||
def cpu_op_exec(logits, p, k):
|
||||
"""
|
||||
Apply top-k and top-p sampling filtering.
|
||||
"""
|
||||
# Sort logits in ascending order
|
||||
logits_sort, logits_idx = logits.sort(dim=-1, descending=False, stable=True)
|
||||
|
||||
# 1. Apply top-k filtering
|
||||
if k is not None:
|
||||
# Ensure k does not exceed vocab_size
|
||||
k = torch.minimum(k, torch.tensor(logits.size(-1), device=k.device))
|
||||
top_k_mask_idx = logits_sort.size(1) - k.to(torch.long)
|
||||
top_k_threshold = logits_sort.gather(1, top_k_mask_idx.unsqueeze(dim=1))
|
||||
top_k_mask = logits_sort < top_k_threshold
|
||||
logits_sort.masked_fill_(top_k_mask, -float("inf"))
|
||||
|
||||
# 2. Apply top-p (nucleus) filtering
|
||||
if p is not None:
|
||||
probs_sort = logits_sort.to(torch.float32).softmax(dim=-1)
|
||||
probs_sum = probs_sort.cumsum(dim=-1)
|
||||
top_p_mask = probs_sum <= 1 - p.unsqueeze(dim=1)
|
||||
top_p_mask[:, -1] = False
|
||||
logits_sort.masked_fill_(top_p_mask, -float("inf"))
|
||||
|
||||
# 3. Restore original order
|
||||
logits = torch.empty_like(logits_sort).scatter_(dim=-1, index=logits_idx, src=logits_sort)
|
||||
return logits
|
||||
|
||||
|
||||
def cpu_op_exec_top_k(logits, p, k):
|
||||
return cpu_op_exec(logits, None, k)
|
||||
|
||||
|
||||
def cpu_op_exec_top_p(logits, p, k):
|
||||
return cpu_op_exec(logits, p, None)
|
||||
|
||||
|
||||
def ascendc_op_exec(logits, p, k):
|
||||
"""
|
||||
Execute the custom Ascend NPU operator.
|
||||
"""
|
||||
logits_npu = logits.npu()
|
||||
p_npu = p.npu() if p is not None else None
|
||||
k_npu = k.npu() if k is not None else None
|
||||
|
||||
return torch.ops._C_ascend.npu_apply_top_k_top_p(logits_npu, k=k_npu, p=p_npu).cpu()
|
||||
|
||||
|
||||
def assert_output_close(out_cpu, out_npu, rtol=1e-4, atol=1e-4):
|
||||
"""
|
||||
Custom assertion to handle Top-P boundary precision issues.
|
||||
"""
|
||||
# 1. Check mask consistency (inf vs finite)
|
||||
mask_cpu = torch.isinf(out_cpu) & (out_cpu < 0)
|
||||
mask_npu = torch.isinf(out_npu) & (out_npu < 0)
|
||||
|
||||
mismatch_mask = mask_cpu ^ mask_npu
|
||||
mismatch_count = mismatch_mask.sum().item()
|
||||
total_elements = out_cpu.numel()
|
||||
|
||||
# Allow 0.1% mismatch for boundary floating point precision differences
|
||||
mismatch_ratio = mismatch_count / total_elements
|
||||
if mismatch_ratio > 0.001:
|
||||
pytest.fail(f"Mask mismatch ratio too high: {mismatch_ratio:.6f} ({mismatch_count}/{total_elements})")
|
||||
|
||||
# 2. Check value consistency for valid elements
|
||||
valid_mask = (~mask_cpu) & (~mask_npu)
|
||||
if valid_mask.any():
|
||||
torch.testing.assert_close(
|
||||
out_cpu[valid_mask],
|
||||
out_npu[valid_mask],
|
||||
rtol=rtol,
|
||||
atol=atol
|
||||
)
|
||||
|
||||
|
||||
# -----------------------------------------------------------------------------
|
||||
# Tests
|
||||
# -----------------------------------------------------------------------------
|
||||
|
||||
|
||||
@pytest.mark.parametrize('vocab_size', [15206, 152064])
|
||||
@pytest.mark.parametrize('batch_size', [4, 8, 16, 32, 64, 96, 128, 256])
|
||||
@pytest.mark.parametrize('p_val', [0.5, 0.9, 0.99])
|
||||
@pytest.mark.parametrize('k_val', [50, 200, 1024, 4096, 8192])
|
||||
def test_npu_apply_top_k_top_p(vocab_size, batch_size, p_val, k_val):
|
||||
shape = [batch_size, vocab_size]
|
||||
dtype = torch.float32
|
||||
|
||||
logits = torch.from_numpy(np.random.uniform(-5, 5, shape)).to(dtype)
|
||||
p = torch.full((batch_size,), p_val, dtype=dtype)
|
||||
k = torch.full((batch_size,), k_val, dtype=torch.int32)
|
||||
|
||||
out_cpu = cpu_op_exec(logits.clone(), p, k)
|
||||
out_npu = ascendc_op_exec(logits, p, k)
|
||||
|
||||
assert_output_close(out_cpu, out_npu)
|
||||
|
||||
|
||||
@pytest.mark.parametrize('vocab_size', [15206, 152064])
|
||||
@pytest.mark.parametrize('batch_size', [4, 8, 16, 32, 64, 96, 128, 256])
|
||||
@pytest.mark.parametrize('k_val', [50, 200, 1024, 4096, 8192])
|
||||
def test_npu_apply_top_k(vocab_size, batch_size, k_val):
|
||||
shape = [batch_size, vocab_size]
|
||||
dtype = torch.float32
|
||||
|
||||
logits = torch.from_numpy(np.random.uniform(-5, 5, shape)).to(dtype)
|
||||
p = None
|
||||
k = torch.full((batch_size,), k_val, dtype=torch.int32)
|
||||
|
||||
out_cpu = cpu_op_exec_top_k(logits.clone(), p, k)
|
||||
out_npu = ascendc_op_exec(logits, p, k)
|
||||
|
||||
assert_output_close(out_cpu, out_npu)
|
||||
|
||||
|
||||
@pytest.mark.parametrize('vocab_size', [15206, 152064])
|
||||
@pytest.mark.parametrize('batch_size', [4, 8, 16, 32, 64, 96, 128, 256])
|
||||
@pytest.mark.parametrize('p_val', [0.5, 0.9, 0.99])
|
||||
def test_npu_apply_top_p(vocab_size, batch_size, p_val):
|
||||
shape = [batch_size, vocab_size]
|
||||
dtype = torch.float32
|
||||
|
||||
logits = torch.from_numpy(np.random.uniform(-5, 5, shape)).to(dtype)
|
||||
p = torch.full((batch_size,), p_val, dtype=dtype)
|
||||
k = None
|
||||
|
||||
out_cpu = cpu_op_exec_top_p(logits.clone(), p, k)
|
||||
out_npu = ascendc_op_exec(logits, p, k)
|
||||
|
||||
assert_output_close(out_cpu, out_npu)
|
||||
Reference in New Issue
Block a user