[Misc][Test] add e2e test for apply_top_k_top_p_custom kernel (#6348)

### What this PR does / why we need it?
Add e2e test case for apply_top_k_top_p_custom kernel and eliminate
chinese comments.

### Does this PR introduce _any_ user-facing change?
No.

### How was this patch tested?
pytest passed.

- vLLM version: v0.14.1
- vLLM main:
dc917cceb8

---------

Signed-off-by: linfeng-yuan <1102311262@qq.com>
This commit is contained in:
linfeng-yuan
2026-01-28 17:25:57 +08:00
committed by GitHub
parent 857c533e27
commit e25ee65729
2 changed files with 155 additions and 17 deletions

View File

@@ -38,7 +38,7 @@ static const int64_t EXPECTED_DIM_ONE = 1;
static const int64_t EXPECTED_DIM_TWO = 2;
static constexpr size_t DIM_ONE = 1;
// 根据API定义需要列出所能支持的所有dtype
// According to the API definition, all supported dtypes must be enumerated.
static const std::initializer_list<op::DataType> DTYPE_SUPPORT_LIST = {
op::DataType::DT_FLOAT, op::DataType::DT_FLOAT16, op::DataType::DT_BF16};
@@ -57,7 +57,7 @@ static bool CheckNotNull(const aclTensor* logits, const aclTensor* p, const aclT
static bool CheckDtypeValid(const aclTensor* logits, const aclTensor* p, const aclTensor *k, const aclTensor* out)
{
// 检查数据类型是否在支持列表内
// Check whether the data type is within the supported list.
OP_CHECK_DTYPE_NOT_SUPPORT(logits, DTYPE_SUPPORT_LIST, return false);
if (p != nullptr) {
OP_CHECK_DTYPE_NOT_SUPPORT(p, DTYPE_SUPPORT_LIST, return false);
@@ -67,7 +67,7 @@ static bool CheckDtypeValid(const aclTensor* logits, const aclTensor* p, const a
}
OP_CHECK_DTYPE_NOT_SUPPORT(out, DTYPE_SUPPORT_LIST, return false);
// 检查数据类型是否相同
// Check whether the data types are identical.
if (p != nullptr) {
OP_CHECK_DTYPE_NOT_MATCH(p, logits->GetDataType(), return false);
}
@@ -121,17 +121,17 @@ static bool CheckFormatValid(const aclTensor* logits, const aclTensor* p, const
static aclnnStatus CheckParams(const aclTensor* logits, const aclTensor* p, const aclTensor *k, const aclTensor* out)
{
// 错误码等DFX方案细化后刷新错误日志在check接口内打印
// 1. 检查参数是否为空指针
// Refresh after the DFX scheme for error codes, etc. is refined; error logs are printed inside the check interfaces.
// 1. Check whether any parameters are null pointers.
CHECK_RET(CheckNotNull(logits, p, k, out), ACLNN_ERR_PARAM_NULLPTR);
// 2. 检查输入的数据类型是否在API支持的数据类型范围之内需要根据api定义校验
// 2. Check whether the input data types are within the range supported by the API; validate according to the API definition.
CHECK_RET(CheckDtypeValid(logits, p, k, out), ACLNN_ERR_PARAM_INVALID);
// 3. 检查shape是否满足约束
// 3. Check whether shapes satisfy the constraints.
CHECK_RET(CheckShapeValid(logits, p, k, out), ACLNN_ERR_PARAM_INVALID);
// 4. 检查format是否满足约束
// 4. Check whether formats satisfy the constraints.
CHECK_RET(CheckFormatValid(logits, p, k, out), ACLNN_ERR_PARAM_INVALID);
return ACLNN_SUCCESS;
@@ -144,11 +144,11 @@ aclnnStatus aclnnApplyTopKTopPCustomGetWorkspaceSize(
{
OP_CHECK_COMM_INPUT(workspaceSize, executor);
L2_DFX_PHASE_1(aclnnApplyTopKTopPCustom, DFX_IN(logits, p, k), DFX_OUT(out));
// 固定写法,创建OpExecutor
// Fixed boilerplate: create OpExecutor.
auto uniqueExecutor = CREATE_EXECUTOR();
CHECK_RET(uniqueExecutor.get() != nullptr, ACLNN_ERR_INNER_CREATE_EXECUTOR);
// 固定写法,参数检查
// Fixed boilerplate: parameter validation.
auto ret = CheckParams(logits, p, k, out);
CHECK_RET(ret == ACLNN_SUCCESS, ret);
bool pIsEmpty = false;
@@ -160,12 +160,12 @@ aclnnStatus aclnnApplyTopKTopPCustomGetWorkspaceSize(
kIsEmpty = k->IsEmpty();
}
if (logits->IsEmpty() || pIsEmpty || kIsEmpty) {
// 根据实际支持情况补充
// Supplement according to actual support status.
*workspaceSize = 0;
uniqueExecutor.ReleaseTo(executor);
return ACLNN_SUCCESS;
}
// 固定写法将输入selfRef转换成连续的tensor
// Fixed boilerplate: convert the input selfRef to a contiguous tensor.
auto logitsContiguous = l0op::Contiguous(logits, uniqueExecutor.get());
CHECK_RET(logitsContiguous != nullptr, ACLNN_ERR_INNER_NULLPTR);
const aclTensor* pContiguous = nullptr;
@@ -190,24 +190,23 @@ aclnnStatus aclnnApplyTopKTopPCustomGetWorkspaceSize(
CHECK_RET(sortedIndices != nullptr, ACLNN_ERR_INNER_NULLPTR);
auto res = l0op::ApplyTopKTopPCustom(sortedValue, sortedIndices, pContiguous, kContiguous, uniqueExecutor.get());
CHECK_RET(res != nullptr, ACLNN_ERR_INNER_NULLPTR);
// 固定写法将计算结果拷贝到输出out上out可能是非连续的tensor
// Fixed boilerplate: copy the computed result to the output 'out'; 'out' may be a non-contiguous tensor.
viewCopyResult = l0op::ViewCopy(res, out, uniqueExecutor.get());
}
CHECK_RET(viewCopyResult != nullptr, ACLNN_ERR_INNER_NULLPTR);
// 固定写法获取计算过程中需要使用的workspace大小
// Fixed boilerplate: obtain the workspace size required during computation.
*workspaceSize = uniqueExecutor->GetWorkspaceSize();
uniqueExecutor.ReleaseTo(executor); // 需要把 uniqueExecutor持有executor转移给executor
uniqueExecutor.ReleaseTo(executor); // Transfer ownership of the executor held by uniqueExecutor to executor.
return ACLNN_SUCCESS;
}
aclnnStatus aclnnApplyTopKTopPCustom(void* workspace, uint64_t workspaceSize, aclOpExecutor* executor, aclrtStream stream)
{
L2_DFX_PHASE_2(aclnnApplyTopKTopPCustom);
// 固定写法,调用框架能力,完成计算
// Fixed boilerplate: invoke framework capability to complete the computation.
return CommonOpExecutorRun(workspace, workspaceSize, executor, stream);
}
#ifdef __cplusplus
}
#endif