[ops] support advanced apply_top_k_top_p without top_k constraint (#6098)

### What this PR does / why we need it?
Implement `apply_top_k_top_p` via AscendC to eliminate the constraint that
k must lie in [1, 1024]. It enables high-performance top-k/top-p calculation
and avoids the D2H synchronization introduced by k validation.

### Does this PR introduce _any_ user-facing change?
No.

### How was this patch tested?
E2E serving with `k=4096` and `p=0.95`
- vLLM version: v0.13.0
- vLLM main:
d68209402d

---------

Signed-off-by: linfeng-yuan <1102311262@qq.com>
Signed-off-by: SlightwindSec <slightwindsec@gmail.com>
Co-authored-by: SlightwindSec <slightwindsec@gmail.com>
This commit is contained in:
linfeng-yuan
2026-01-26 09:08:42 +08:00
committed by GitHub
parent 4e3919e965
commit 96309e2b79
16 changed files with 2208 additions and 3 deletions

View File

@@ -0,0 +1,41 @@
# Build configuration for the ApplyTopKTopPCustom op: compile options,
# host-side op definition, aclnn API sources, tiling sources, and
# installation of the public aclnn header.
add_ops_compile_options(
    OP_NAME ApplyTopKTopPCustom
    OPTIONS --cce-auto-sync=on
            -Wno-deprecated-declarations
            -Werror
)
target_sources(op_host_aclnnExc PRIVATE
    apply_top_k_top_p_custom_def.cpp
)
target_sources(opapi PRIVATE
    apply_top_k_top_p_custom.cpp
    aclnn_apply_top_k_top_p_custom.cpp
)
if (NOT BUILD_OPEN_PROJECT)
    target_sources(aclnn_ops_train PRIVATE
        apply_top_k_top_p_custom.cpp
        aclnn_apply_top_k_top_p_custom.cpp
    )
    target_sources(aclnn_ops_infer PRIVATE
        apply_top_k_top_p_custom.cpp
        aclnn_apply_top_k_top_p_custom.cpp
    )
endif ()
target_sources(optiling PRIVATE
    apply_top_k_top_p_custom_tiling.cpp
)
target_include_directories(optiling PRIVATE
    ${CMAKE_CURRENT_SOURCE_DIR}
)
# Variable renamed from _GMM_Aclnn_header (stale copy-paste from the GMM op)
# so it matches the op it actually installs.
file(GLOB _ApplyTopKTopP_Aclnn_header "${CMAKE_CURRENT_SOURCE_DIR}/aclnn_apply_top_k_top_p_custom.h")
install(FILES ${_ApplyTopKTopP_Aclnn_header}
    DESTINATION ${ACLNN_INC_INSTALL_DIR} OPTIONAL
)

View File

@@ -0,0 +1,213 @@
/**
* Copyright (c) 2025 Huawei Technologies Co., Ltd.
* This program is free software, you can redistribute it and/or modify it under the terms and conditions of
* CANN Open Software License Agreement Version 2.0 (the "License").
* Please refer to the License for details. You may not use this file except in compliance with the License.
* THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED,
* INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE.
* See LICENSE in the root of the software repository for the full text of the License.
*/
/*!
* \file aclnn_apply_top_k_top_p_custom.cpp
* \brief
*/
#include "aclnn_apply_top_k_top_p_custom.h"
#include "apply_top_k_top_p_custom.h"
#include "sort.h"
#include "aclnn_kernels/contiguous.h"
#include "aclnn_kernels/common/op_error_check.h"
#include "aclnn/aclnn_base.h"
#include "opdev/common_types.h"
#include "opdev/data_type_utils.h"
#include "opdev/make_op_executor.h"
#include "opdev/format_utils.h"
#include "opdev/op_dfx.h"
#include "opdev/op_executor.h"
#include "opdev/op_log.h"
#include "opdev/tensor_view_utils.h"
#include "opdev/shape_utils.h"
#include "opdev/platform.h"
using namespace op;
#ifdef __cplusplus
extern "C" {
#endif
namespace {
// Expected tensor ranks: p/k are 1-D, logits/out are 2-D [batch, vocab].
static const int64_t EXPECTED_DIM_ONE = 1;
static const int64_t EXPECTED_DIM_TWO = 2;
// Index of the last (vocabulary) dimension of logits.
static constexpr size_t DIM_ONE = 1;
// All dtypes supported per the API definition.
static const std::initializer_list<op::DataType> DTYPE_SUPPORT_LIST = {
    op::DataType::DT_FLOAT, op::DataType::DT_FLOAT16, op::DataType::DT_BF16};
// k is restricted to int32.
static const std::initializer_list<op::DataType> INT_DTYPE_SUPPORT_LIST = {
    op::DataType::DT_INT32};
// Validate mandatory pointers: logits and out must be non-null, and at least
// one of p / k must be provided.
static bool CheckNotNull(const aclTensor* logits, const aclTensor* p, const aclTensor *k, const aclTensor* out)
{
    OP_CHECK_NULL(logits, return false);
    if (p == nullptr && k == nullptr) {
        OP_LOGE(ACLNN_ERR_PARAM_INVALID, "The inputs, p and k, should not be nullptr at the same time.");
        // Fixed: previously this branch only logged and fell through, so the
        // invalid p==k==nullptr combination was reported as success.
        return false;
    }
    OP_CHECK_NULL(out, return false);
    return true;
}
// Validate dtypes: logits/p/out must be fp32/fp16/bf16 and k int32; p and out
// must additionally match logits' dtype exactly.
static bool CheckDtypeValid(const aclTensor* logits, const aclTensor* p, const aclTensor *k, const aclTensor* out)
{
    // Check that every dtype is in the supported list required by the API definition.
    OP_CHECK_DTYPE_NOT_SUPPORT(logits, DTYPE_SUPPORT_LIST, return false);
    if (p != nullptr) {
        OP_CHECK_DTYPE_NOT_SUPPORT(p, DTYPE_SUPPORT_LIST, return false);
    }
    if (k != nullptr) {
        OP_CHECK_DTYPE_NOT_SUPPORT(k, INT_DTYPE_SUPPORT_LIST, return false);
    }
    OP_CHECK_DTYPE_NOT_SUPPORT(out, DTYPE_SUPPORT_LIST, return false);
    // Check that dtypes agree across the floating-point tensors.
    if (p != nullptr) {
        OP_CHECK_DTYPE_NOT_MATCH(p, logits->GetDataType(), return false);
    }
    OP_CHECK_DTYPE_NOT_MATCH(out, logits->GetDataType(), return false);
    return true;
}
// Validate shapes: logits/out are identical 2-D [batch, vocab] tensors; p and
// k, when present, are 1-D with length equal to logits.size(0).
static bool CheckShapeValid(const aclTensor* logits, const aclTensor* p, const aclTensor *k, const aclTensor* out)
{
    OP_CHECK_WRONG_DIMENSION(logits, EXPECTED_DIM_TWO, return false);
    OP_CHECK_SHAPE_NOT_EQUAL(out, logits, return false);
    if (p != nullptr) {
        OP_CHECK_WRONG_DIMENSION(p, EXPECTED_DIM_ONE, return false);
    }
    if (k != nullptr) {
        OP_CHECK_WRONG_DIMENSION(k, EXPECTED_DIM_ONE, return false);
    }
    // Per-batch thresholds: one p / one k per logits row.
    if (p != nullptr && p->GetViewShape().GetDim(0) != logits->GetViewShape().GetDim(0)) {
        OP_LOGE(ACLNN_ERR_PARAM_INVALID, "expected p.size(0) is equal to logits.size(0), but got %ld.",
                p->GetViewShape().GetDim(0));
        return false;
    }
    if (k != nullptr && k->GetViewShape().GetDim(0) != logits->GetViewShape().GetDim(0)) {
        OP_LOGE(ACLNN_ERR_PARAM_INVALID, "expected k.size(0) is equal to logits.size(0), but got %ld.",
                k->GetViewShape().GetDim(0));
        return false;
    }
    return true;
}
static bool CheckFormatValid(const aclTensor* logits, const aclTensor* p, const aclTensor *k, const aclTensor* out)
{
if (logits->GetStorageFormat() != Format::FORMAT_ND) {
OP_LOGE(ACLNN_ERR_PARAM_INVALID, "logits format only support ND");
return false;
}
if (p != nullptr && p->GetStorageFormat() != Format::FORMAT_ND) {
OP_LOGE(ACLNN_ERR_PARAM_INVALID, "p format only support ND");
return false;
}
if (k != nullptr && k->GetStorageFormat() != Format::FORMAT_ND) {
OP_LOGE(ACLNN_ERR_PARAM_INVALID, "k format only support ND");
return false;
}
if (out->GetStorageFormat() != Format::FORMAT_ND) {
OP_LOGE(ACLNN_ERR_PARAM_INVALID, "out format only support ND");
return false;
}
return true;
}
// Run all parameter checks in order and return the first failing status.
// Detailed error logs are printed inside each check helper (error codes to be
// refined once the DFX plan is finalized).
static aclnnStatus CheckParams(const aclTensor* logits, const aclTensor* p, const aclTensor *k, const aclTensor* out)
{
    // 1. Null-pointer checks.
    CHECK_RET(CheckNotNull(logits, p, k, out), ACLNN_ERR_PARAM_NULLPTR);
    // 2. Dtype checks against the API's supported set.
    CHECK_RET(CheckDtypeValid(logits, p, k, out), ACLNN_ERR_PARAM_INVALID);
    // 3. Shape constraints.
    CHECK_RET(CheckShapeValid(logits, p, k, out), ACLNN_ERR_PARAM_INVALID);
    // 4. Format constraints.
    CHECK_RET(CheckFormatValid(logits, p, k, out), ACLNN_ERR_PARAM_INVALID);
    return ACLNN_SUCCESS;
}
} // namespace
// First-phase aclnn entry: validates inputs, records the compute graph
// (Contiguous -> Sort -> ApplyTopKTopPCustom -> ViewCopy) onto an executor,
// and reports the workspace size required to run it.
aclnnStatus aclnnApplyTopKTopPCustomGetWorkspaceSize(
    const aclTensor* logits, const aclTensor* p, const aclTensor* k, aclTensor* out, uint64_t* workspaceSize,
    aclOpExecutor** executor)
{
    OP_CHECK_COMM_INPUT(workspaceSize, executor);
    L2_DFX_PHASE_1(aclnnApplyTopKTopPCustom, DFX_IN(logits, p, k), DFX_OUT(out));
    // Boilerplate: create the OpExecutor.
    auto uniqueExecutor = CREATE_EXECUTOR();
    CHECK_RET(uniqueExecutor.get() != nullptr, ACLNN_ERR_INNER_CREATE_EXECUTOR);
    // Boilerplate: parameter validation.
    auto ret = CheckParams(logits, p, k, out);
    CHECK_RET(ret == ACLNN_SUCCESS, ret);
    bool pIsEmpty = false;
    bool kIsEmpty = false;
    if (p != nullptr) {
        pIsEmpty = p->IsEmpty();
    }
    if (k != nullptr) {
        kIsEmpty = k->IsEmpty();
    }
    // Empty-tensor fast path: nothing to compute, no workspace needed.
    if (logits->IsEmpty() || pIsEmpty || kIsEmpty) {
        *workspaceSize = 0;
        uniqueExecutor.ReleaseTo(executor);
        return ACLNN_SUCCESS;
    }
    // Boilerplate: make the inputs contiguous.
    auto logitsContiguous = l0op::Contiguous(logits, uniqueExecutor.get());
    CHECK_RET(logitsContiguous != nullptr, ACLNN_ERR_INNER_NULLPTR);
    const aclTensor* pContiguous = nullptr;
    const aclTensor* kContiguous = nullptr;
    if (p != nullptr) {
        pContiguous = l0op::Contiguous(p, uniqueExecutor.get());
        CHECK_RET(pContiguous != nullptr, ACLNN_ERR_INNER_NULLPTR);
    }
    if (k != nullptr) {
        kContiguous = l0op::Contiguous(k, uniqueExecutor.get());
        CHECK_RET(kContiguous != nullptr, ACLNN_ERR_INNER_NULLPTR);
    }
    // With a single vocabulary entry, filtering is a no-op: copy logits to out.
    bool isLastDimSizeOne = logits->GetViewShape()[DIM_ONE] == 1;
    auto viewCopyResult = logitsContiguous;
    if (isLastDimSizeOne) {
        viewCopyResult = l0op::ViewCopy(logitsContiguous, out, uniqueExecutor.get());
    } else {
        // Sort along the last dim; the kernel consumes sorted values + their
        // int32 original-position indices.
        auto sortResult = l0op::Sort(logitsContiguous, -1, false, true, op::DataType::DT_INT32, uniqueExecutor.get());
        const aclTensor* sortedValue = std::get<0>(sortResult);
        CHECK_RET(sortedValue != nullptr, ACLNN_ERR_INNER_NULLPTR);
        const aclTensor* sortedIndices = std::get<1>(sortResult);
        CHECK_RET(sortedIndices != nullptr, ACLNN_ERR_INNER_NULLPTR);
        auto res = l0op::ApplyTopKTopPCustom(sortedValue, sortedIndices, pContiguous, kContiguous, uniqueExecutor.get());
        CHECK_RET(res != nullptr, ACLNN_ERR_INNER_NULLPTR);
        // Boilerplate: copy the result into out (out may be non-contiguous).
        viewCopyResult = l0op::ViewCopy(res, out, uniqueExecutor.get());
    }
    CHECK_RET(viewCopyResult != nullptr, ACLNN_ERR_INNER_NULLPTR);
    // Boilerplate: query the workspace size accumulated during graph capture.
    *workspaceSize = uniqueExecutor->GetWorkspaceSize();
    uniqueExecutor.ReleaseTo(executor); // transfer executor ownership to the caller
    return ACLNN_SUCCESS;
}
// Second-phase aclnn entry: executes the graph recorded by
// aclnnApplyTopKTopPCustomGetWorkspaceSize on the given stream.
aclnnStatus aclnnApplyTopKTopPCustom(void* workspace, uint64_t workspaceSize, aclOpExecutor* executor, aclrtStream stream)
{
    L2_DFX_PHASE_2(aclnnApplyTopKTopPCustom);
    // Delegate to the framework runner, which replays the captured ops.
    const aclnnStatus runStatus = CommonOpExecutorRun(workspace, workspaceSize, executor, stream);
    return runStatus;
}
#ifdef __cplusplus
}
#endif

View File

@@ -0,0 +1,54 @@
/**
* Copyright (c) 2025 Huawei Technologies Co., Ltd.
* This program is free software, you can redistribute it and/or modify it under the terms and conditions of
* CANN Open Software License Agreement Version 2.0 (the "License").
* Please refer to the License for details. You may not use this file except in compliance with the License.
* THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED,
* INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE.
* See LICENSE in the root of the software repository for the full text of the License.
*/
/*!
* \file aclnn_apply_top_k_top_p_custom.h
* \brief
*/
#ifndef OP_API_INC_APPLY_TOP_K_TOP_P_CUSTOM_H_
#define OP_API_INC_APPLY_TOP_K_TOP_P_CUSTOM_H_
#include "aclnn/aclnn_base.h"
#ifdef __cplusplus
extern "C" {
#endif
/**
 * @brief First-phase interface of aclnnApplyTopKTopPCustom: computes the workspace size for the calculation.
 * @domain aclnn_ops_infer
 * @param [in] logits: NPU device-side aclTensor. Dtypes: FLOAT, FLOAT16, BFLOAT16; non-contiguous tensors supported; format: ND.
 * @param [in] p: NPU device-side aclTensor (optional top-p thresholds). Dtypes: FLOAT, FLOAT16, BFLOAT16; non-contiguous tensors supported; format: ND.
 * @param [in] k: NPU device-side aclTensor (optional top-k thresholds). Dtype: INT32; non-contiguous tensors supported; format: ND.
 * @param [in] out: NPU device-side aclTensor. Dtypes: FLOAT, FLOAT16, BFLOAT16; non-contiguous tensors supported; format: ND.
 * @param [out] workspaceSize: returns the workspace size the caller must allocate on the NPU device.
 * @param [out] executor: returns the op executor holding the captured compute graph.
 * @return aclnnStatus: status code.
 */
aclnnStatus aclnnApplyTopKTopPCustomGetWorkspaceSize(const aclTensor* logits, const aclTensor* p,
                                                     const aclTensor* k, aclTensor* out, uint64_t* workspaceSize,
                                                     aclOpExecutor** executor);
/**
 * @brief Second-phase interface of aclnnApplyTopKTopPCustom: performs the calculation.
 * @param [in] workspace: start address of the workspace allocated on the NPU device.
 * @param [in] workspaceSize: workspace size, obtained from the first-phase
 *             interface aclnnApplyTopKTopPCustomGetWorkspaceSize. (Typo "aaclnn..." fixed.)
 * @param [in] executor: op executor holding the captured compute graph.
 * @param [in] stream: acl stream.
 * @return aclnnStatus: status code.
 */
aclnnStatus aclnnApplyTopKTopPCustom(void* workspace, uint64_t workspaceSize, aclOpExecutor* executor,
                                     aclrtStream stream);
#ifdef __cplusplus
}
#endif
#endif // OP_API_INC_APPLY_TOP_K_TOP_P_CUSTOM_H_

View File

@@ -0,0 +1,46 @@
/**
* Copyright (c) 2025 Huawei Technologies Co., Ltd.
* This program is free software, you can redistribute it and/or modify it under the terms and conditions of
* CANN Open Software License Agreement Version 2.0 (the "License").
* Please refer to the License for details. You may not use this file except in compliance with the License.
* THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED,
* INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE.
* See LICENSE in the root of the software repository for the full text of the License.
*/
/*!
* \file apply_top_k_top_p_custom.cpp
* \brief
*/
#include "apply_top_k_top_p_custom.h"
#include "opdev/data_type_utils.h"
#include "opdev/format_utils.h"
#include "opdev/make_op_executor.h"
#include "opdev/op_def.h"
#include "opdev/op_dfx.h"
#include "opdev/op_executor.h"
#include "opdev/op_log.h"
#include "opdev/shape_utils.h"
using namespace op;
namespace l0op {
OP_TYPE_REGISTER(ApplyTopKTopPCustom);
// Level-0 launcher: schedules the ApplyTopKTopPCustom AICore kernel on the
// pre-sorted logits. Missing optional inputs are replaced with placeholder
// tensors so the kernel launch always receives four input slots.
const aclTensor* ApplyTopKTopPCustom(
    const aclTensor* sortedValue, const aclTensor* sortedIndices, const aclTensor* p, const aclTensor* k,
    aclOpExecutor* executor)
{
    L0_DFX(ApplyTopKTopPCustom, sortedValue, sortedIndices, p, k);
    // Output mirrors the sorted-value shape and dtype.
    auto output = executor->AllocTensor(sortedValue->GetViewShape(), sortedValue->GetDataType());
    if (p == nullptr) {
        // Shapeless placeholder for the absent top-p input.
        p = executor->AllocTensor(sortedValue->GetDataType(), Format::FORMAT_ND, Format::FORMAT_ND);
    }
    if (k == nullptr) {
        // Shapeless placeholder for the absent top-k input.
        k = executor->AllocTensor(DataType::DT_INT32, Format::FORMAT_ND, Format::FORMAT_ND);
    }
    ADD_TO_LAUNCHER_LIST_AICORE(ApplyTopKTopPCustom, OP_INPUT(sortedValue, sortedIndices, p, k), OP_OUTPUT(output));
    return output;
}
} // namespace l0op

View File

@@ -0,0 +1,24 @@
/**
* Copyright (c) 2025 Huawei Technologies Co., Ltd.
* This program is free software, you can redistribute it and/or modify it under the terms and conditions of
* CANN Open Software License Agreement Version 2.0 (the "License").
* Please refer to the License for details. You may not use this file except in compliance with the License.
* THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED,
* INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE.
* See LICENSE in the root of the software repository for the full text of the License.
*/
/*!
* \file apply_top_k_top_p_custom.h
* \brief
*/
#ifndef OP_API_INC_LEVEL0_OP_APPLY_TOP_K_TOP_P_CUSTOM_OP_H_
#define OP_API_INC_LEVEL0_OP_APPLY_TOP_K_TOP_P_CUSTOM_OP_H_
#include "opdev/op_executor.h"
namespace l0op {
const aclTensor* ApplyTopKTopPCustom(const aclTensor* sortedValue, const aclTensor* sortedIndices,
const aclTensor* p, const aclTensor* k, aclOpExecutor* executor);
}
#endif // OP_API_INC_LEVEL0_OP_APPLY_TOP_K_TOP_P_CUSTOM_OP_H_

View File

@@ -0,0 +1,100 @@
/**
* Copyright (c) 2025 Huawei Technologies Co., Ltd.
* This program is free software, you can redistribute it and/or modify it under the terms and conditions of
* CANN Open Software License Agreement Version 2.0 (the "License").
* Please refer to the License for details. You may not use this file except in compliance with the License.
* THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED,
* INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE.
* See LICENSE in the root of the software repository for the full text of the License.
*/
/*!
* \file apply_top_k_top_p_custom_def.cpp
* \brief
*/
#include "register/op_def_registry.h"
namespace ops {
// Host-side prototype registration for ApplyTopKTopPCustom.
// Inputs: pre-sorted logits and their int32 indices, plus optional per-batch
// p (top-p) and k (top-k) thresholds. Output matches the logits dtype.
class ApplyTopKTopPCustom : public OpDef {
public:
    explicit ApplyTopKTopPCustom(const char *name) : OpDef(name)
    {
        this->Input("sorted_value")
            .ParamType(REQUIRED)
            .DataType({ge::DT_FLOAT, ge::DT_FLOAT16, ge::DT_BF16})
            .Format({ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND})
            .UnknownShapeFormat({ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND});
        this->Input("sorted_indices")
            .ParamType(REQUIRED)
            .DataType({ge::DT_INT32, ge::DT_INT32, ge::DT_INT32})
            .Format({ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND})
            .UnknownShapeFormat({ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND});
        this->Input("p")
            .ParamType(OPTIONAL)
            .DataType({ge::DT_FLOAT, ge::DT_FLOAT16, ge::DT_BF16})
            .Format({ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND})
            .UnknownShapeFormat({ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND});
        this->Input("k")
            .ParamType(OPTIONAL)
            .DataType({ge::DT_INT32, ge::DT_INT32, ge::DT_INT32})
            .Format({ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND})
            .UnknownShapeFormat({ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND});
        this->Output("out")
            .ParamType(REQUIRED)
            .DataType({ge::DT_FLOAT, ge::DT_FLOAT16, ge::DT_BF16})
            .Format({ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND})
            .UnknownShapeFormat({ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND});
        // NOTE(review): the original code built a local OpAICoreConfig with
        // DynamicCompileStaticFlag(true)/DynamicFormatFlag(false)/
        // DynamicRankSupportFlag(true)/DynamicShapeSupportFlag(true) but never
        // passed it to AddConfig, so those flags had no effect. The dead local
        // is removed here to keep the registered behavior unchanged; if the
        // flags were intended for these SOCs, pass a config explicitly as is
        // done for kirinx90 below.
        this->AICore().AddConfig("ascend910b");
        this->AICore().AddConfig("ascend910_93");
        OpAICoreConfig config_kirin = GetKirinCoreConfig();
        this->AICore().AddConfig("kirinx90", config_kirin);
    }
private:
    // Builds the reduced-capability config for the Kirin SOC: float/float16
    // only (no bf16), precision reduction allowed, support check disabled.
    OpAICoreConfig GetKirinCoreConfig() const
    {
        OpAICoreConfig config_kirin;
        config_kirin.DynamicCompileStaticFlag(true)
            .DynamicFormatFlag(true)
            .DynamicRankSupportFlag(true)
            .DynamicShapeSupportFlag(true)
            .NeedCheckSupportFlag(false)
            .PrecisionReduceFlag(true);
        config_kirin.Input("sorted_value")
            .ParamType(REQUIRED)
            .DataType({ge::DT_FLOAT, ge::DT_FLOAT16})
            .Format({ge::FORMAT_ND, ge::FORMAT_ND})
            .UnknownShapeFormat({ge::FORMAT_ND, ge::FORMAT_ND});
        config_kirin.Input("sorted_indices")
            .ParamType(REQUIRED)
            .DataType({ge::DT_INT32, ge::DT_INT32})
            .Format({ge::FORMAT_ND, ge::FORMAT_ND})
            .UnknownShapeFormat({ge::FORMAT_ND, ge::FORMAT_ND});
        config_kirin.Input("p")
            .ParamType(OPTIONAL)
            .DataType({ge::DT_FLOAT, ge::DT_FLOAT16})
            .Format({ge::FORMAT_ND, ge::FORMAT_ND})
            .UnknownShapeFormat({ge::FORMAT_ND, ge::FORMAT_ND});
        config_kirin.Input("k")
            .ParamType(OPTIONAL)
            .DataType({ge::DT_INT32, ge::DT_INT32})
            .Format({ge::FORMAT_ND, ge::FORMAT_ND})
            .UnknownShapeFormat({ge::FORMAT_ND, ge::FORMAT_ND});
        config_kirin.Output("out")
            .ParamType(REQUIRED)
            .DataType({ge::DT_FLOAT, ge::DT_FLOAT16})
            .Format({ge::FORMAT_ND, ge::FORMAT_ND})
            .UnknownShapeFormat({ge::FORMAT_ND, ge::FORMAT_ND});
        return config_kirin;
    }
};
OP_ADD(ApplyTopKTopPCustom);
} // namespace ops

View File

@@ -0,0 +1,314 @@
/**
* Copyright (c) 2025 Huawei Technologies Co., Ltd.
* This program is free software, you can redistribute it and/or modify it under the terms and conditions of
* CANN Open Software License Agreement Version 2.0 (the "License").
* Please refer to the License for details. You may not use this file except in compliance with the License.
* THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED,
* INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE.
* See LICENSE in the root of the software repository for the full text of the License.
*/
/*!
* \file apply_top_k_top_p_custom_tiling.cpp
* \brief
*/
#include <iostream>
#include <map>
#include "error_log.h"
#include "tiling/platform/platform_ascendc.h"
#include "platform/platform_infos_def.h"
#include "register/op_def_registry.h"
#include "register/op_impl_registry.h"
#include "tiling/tiling_api.h"
#include "apply_top_k_top_p_custom_tiling.h"
namespace {
// UB headroom reserved before computing the usable working-set size (see Init).
constexpr uint32_t SYS_RESERVED_UB = uint32_t(16 * 1024);
constexpr uint32_t SELECT_RESERVED_UB = uint32_t(8 * 1024);
constexpr uint32_t DIM_ONE = 1;
constexpr uint32_t DIM_TWO = 2;
// Input slot indices as declared in the op prototype.
constexpr int32_t SORTED_VALUE_INPUT_INDEX = 0;
constexpr int32_t SORTED_INDICES_INPUT_INDEX = 1;
constexpr int32_t P_INPUT_INDEX = 2;
constexpr int32_t K_INPUT_INDEX = 3;
constexpr uint32_t DIM_INDEX0 = 0;
constexpr uint32_t FLOAT_BYTES = 4;
// NOTE(review): DTYPE_MAP appears unused in this translation unit — confirm
// whether it can be removed.
static std::map<ge::DataType, uint32_t> DTYPE_MAP = {{ge::DT_BF16, 2}, {ge::DT_FLOAT16, 1}, {ge::DT_FLOAT, 0}};
// Element byte width per supported input dtype.
static std::map<ge::DataType, uint32_t> DATATYPE_LEN_MAP = {
    {ge::DT_FLOAT16, 2}, {ge::DT_BF16, 2}, {ge::DT_FLOAT, 4}};
const static uint32_t SYS_WORKSPACESIZE = uint32_t(16 * 1024 * 1024);
constexpr uint32_t DATA_PER_BLOCK_B32 = 8;
constexpr uint32_t BYTES_B32 = 4;
constexpr uint32_t BLOCK_BYTES = 32;
// Per-iteration vocabulary chunk size held in UB.
constexpr uint32_t K_VALUE_MAX = 1024;
// Tiling-key components for the single-filter modes.
constexpr uint32_t ONLY_TOP_P_KEY = 2;
constexpr uint32_t ONLY_TOP_K_KEY = 1;
constexpr uint32_t BATCH_MODE = 1;
} // namespace
namespace optiling {
// Tiling calculator for ApplyTopKTopPCustom: validates shapes, splits the
// [batch, vocab] workload across AIV cores, and sizes the UB working set.
class ApplyTopKTopPCustomTiling {
public:
    explicit ApplyTopKTopPCustomTiling(gert::TilingContext* context) : tilingcontext(context){};
    ge::graphStatus Init();
    ge::graphStatus RunKernelTiling();
private:
    ApplyTopKTopPCustomTilingData tilingData;
    gert::TilingContext* tilingcontext = nullptr;
    ge::graphStatus CheckShape();  // validates dims, derives batch/vocab and mode flags
    void SetTilingKey();           // derives the tiling key from the mode flags
    void GetUsedCore();            // distributes batches across cores
    void CalDataPerCore();         // sizes the per-iteration UB working set
    void FillTilingData();         // copies computed fields into tilingData
    void PrintTilingData();        // debug-logs the tiling fields
    // Round a up to the nearest multiple of b (returns a when b == 0).
    template <typename T1>
    inline auto CeilAlign(T1 a, T1 b) const -> T1
    {
        return b == 0 ? a : (a + b - 1) / b * b;
    }
    // Round a down to the nearest multiple of b (returns a when b == 0).
    template <typename T1>
    inline auto FloorAlign(T1 a, T1 b) const -> T1
    {
        return b == 0 ? a : a / b * b;
    }
    const char *opName_ = nullptr;
    uint32_t coreNum_ = 0;                     // available AIV cores
    uint32_t calUbSize_ = 0;                   // UB bytes available for computation
    uint32_t batchSize_ = 0;                   // logits.shape[0]
    uint32_t vocabSize_ = 0;                   // logits.shape[1]
    uint32_t tilingKey_ = 0;
    uint32_t usedCoreNum_ = 0;
    uint32_t batchPerCore_ = 1;
    uint32_t tailBatch_ = 0;                   // remainder batches after even split
    uint32_t dataNumInit_ = 0;
    uint32_t dataNumInitAligned_ = 0;
    uint32_t ubFactorElement_ = 0;             // elements processed per iteration
    uint32_t ubFactorElementAligned_ = 0;
    uint32_t tailUbFactorElement_ = 0;         // elements in the final partial iteration
    uint32_t tailUbFactorElementAligned_ = 0;
    uint32_t iterateTimes_ = 0;                // ceil(log2(vocabSize_)), set in Init
    uint32_t onlyTopK_ = 0;                    // ONLY_TOP_K_KEY when k present, p absent
    uint32_t onlyTopP_ = 0;                    // ONLY_TOP_P_KEY when p present, k absent
    uint64_t platformUbSize_ = 0;
};
// Validate input shapes and cache batch/vocab sizes plus the
// only-top-k / only-top-p mode flags.
ge::graphStatus ApplyTopKTopPCustomTiling::CheckShape() {
    auto sortedValueShapePtr = tilingcontext->GetInputShape(SORTED_VALUE_INPUT_INDEX);
    OP_CHECK_NULL_WITH_CONTEXT(tilingcontext, sortedValueShapePtr);
    auto sortedValueShape = sortedValueShapePtr->GetStorageShape();
    if (sortedValueShape.GetDimNum() != DIM_TWO) {
        // %zu: GetDimNum() returns size_t (was logged with %u).
        OP_LOGE(opName_, "the dimNum of sorted_value should be 2, but got %zu.", sortedValueShape.GetDimNum());
        return ge::GRAPH_FAILED;
    }
    auto sortedIndicesShapePtr = tilingcontext->GetInputShape(SORTED_INDICES_INPUT_INDEX);
    OP_CHECK_NULL_WITH_CONTEXT(tilingcontext, sortedIndicesShapePtr);
    auto sortedIndicesShape = sortedIndicesShapePtr->GetStorageShape();
    if (sortedIndicesShape.GetDimNum() != DIM_TWO) {
        OP_LOGE(opName_, "the dimNum of sorted_indices should be 2, but got %zu.", sortedIndicesShape.GetDimNum());
        return ge::GRAPH_FAILED;
    }
    batchSize_ = sortedValueShape.GetDim(DIM_INDEX0);
    vocabSize_ = sortedValueShape.GetDim(DIM_ONE);
    if (sortedIndicesShape.GetDim(DIM_INDEX0) != batchSize_ || sortedIndicesShape.GetDim(DIM_ONE) != vocabSize_) {
        OP_LOGE(opName_, "the shape of sorted_indices should be equal to sorted_value.");
        return ge::GRAPH_FAILED;
    }
    // Optional p/k arrive as 0-D placeholder tensors when omitted by the caller.
    auto pShapePtr = tilingcontext->GetOptionalInputShape(P_INPUT_INDEX);
    OP_CHECK_NULL_WITH_CONTEXT(tilingcontext, pShapePtr);
    auto pShape = pShapePtr->GetStorageShape();
    auto pDimNum = pShape.GetDimNum();
    if (pDimNum != DIM_ONE && pDimNum != 0) {
        OP_LOGE(opName_, "the dimNum of p should be 1 or 0, but got %zu.", pDimNum);
        return ge::GRAPH_FAILED;
    }
    if (pDimNum != 0 && batchSize_ != pShape.GetDim(DIM_INDEX0)) {
        OP_LOGE(opName_, "p.shape[0] should be equal to logits.shape[0].");
        return ge::GRAPH_FAILED;
    }
    auto kShapePtr = tilingcontext->GetOptionalInputShape(K_INPUT_INDEX);
    OP_CHECK_NULL_WITH_CONTEXT(tilingcontext, kShapePtr);
    auto kShape = kShapePtr->GetStorageShape();
    auto kDimNum = kShape.GetDimNum();
    if (kDimNum != DIM_ONE && kDimNum != 0) {
        // Log the cached kDimNum (was re-querying GetDimNum()).
        OP_LOGE(opName_, "the dimNum of k should be 1 or 0, but got %zu.", kDimNum);
        return ge::GRAPH_FAILED;
    }
    if (kDimNum != 0 && batchSize_ != kShape.GetDim(DIM_INDEX0)) {
        OP_LOGE(opName_, "k.shape[0] should be equal to logits.shape[0].");
        return ge::GRAPH_FAILED;
    }
    if (kDimNum == 0 && pDimNum == 0) {
        // Fixed message: previously said "the dimNum of q and k should be 0 at
        // the same time", which misnamed p and inverted the condition.
        OP_LOGE(opName_, "the dimNum of p and k should not be 0 at the same time.");
        return ge::GRAPH_FAILED;
    }
    onlyTopK_ = (kDimNum != 0 && pDimNum == 0) ? ONLY_TOP_K_KEY : 0;
    onlyTopP_ = (pDimNum != 0 && kDimNum == 0) ? ONLY_TOP_P_KEY : 0;
    return ge::GRAPH_SUCCESS;
}
// Query platform resources, validate shapes, and precompute ceil(log2(vocab)).
ge::graphStatus ApplyTopKTopPCustomTiling::Init() {
    opName_ = tilingcontext->GetNodeName();
    OP_LOGD(opName_, "TilingForApplyTopKTopPCustom init.");
    auto platformInfo = platform_ascendc::PlatformAscendC(tilingcontext->GetPlatformInfo());
    coreNum_ = platformInfo.GetCoreNumAiv();
    platformInfo.GetCoreMemSize(platform_ascendc::CoreMemType::UB, platformUbSize_);
    OP_LOGD(opName_, "platformUbSize: %lu.", platformUbSize_);
    // Guard against unsigned underflow if the platform reports less UB than the
    // reserved headroom (should not happen on supported SOCs).
    if (platformUbSize_ <= SYS_RESERVED_UB + SELECT_RESERVED_UB) {
        OP_LOGE(opName_, "UB size %lu is too small.", platformUbSize_);
        return ge::GRAPH_FAILED;
    }
    // Typo fixed: "avaliableUb" -> "availableUb".
    uint32_t availableUb = static_cast<uint32_t>(platformUbSize_) - SYS_RESERVED_UB - SELECT_RESERVED_UB;
    calUbSize_ = FloorAlign(availableUb, BLOCK_BYTES);
    if (CheckShape() == ge::GRAPH_FAILED) {
        OP_LOGE(opName_, "check shape failed.");
        return ge::GRAPH_FAILED;
    }
    // iterateTimes_ = ceil(log2(vocabSize_)).
    uint32_t tempValue = 1;
    while (tempValue < vocabSize_) {
        tempValue <<= 1;
        iterateTimes_++;
    }
    return ge::GRAPH_SUCCESS;
}
// Derive the tiling key from which optional inputs are present, and enable
// batch scheduling for the pure top-p path.
void ApplyTopKTopPCustomTiling::SetTilingKey() {
    tilingKey_ += onlyTopK_ + onlyTopP_;
    tilingcontext->SetTilingKey(tilingKey_);
    if (tilingKey_ == ONLY_TOP_P_KEY) {
        tilingcontext->SetScheduleMode(BATCH_MODE);
    }
}
// Split batches evenly across all AIV cores; the remainder is recorded in
// tailBatch_. Leaves the defaults untouched when no cores are reported.
void ApplyTopKTopPCustomTiling::GetUsedCore()
{
    if (coreNum_ > 0) {
        // Simplified: the old `coreNum_ == 0 ? batchSize_ : ...` ternary was
        // dead code inside this branch (coreNum_ > 0 is guaranteed here).
        batchPerCore_ = batchSize_ / coreNum_;
        tailBatch_ = batchSize_ % coreNum_;
        usedCoreNum_ = coreNum_;
    }
}
// Size the per-iteration UB working set: each step processes up to
// K_VALUE_MAX (1024) vocabulary entries per batch row, and the remaining UB
// after the fixed buffers is published as calUbSize_.
void ApplyTopKTopPCustomTiling::CalDataPerCore()
{
    // NOTE(review): operator[] default-inserts 0 for a dtype missing from
    // DATATYPE_LEN_MAP, which would divide by zero below; assumes dtype was
    // validated by the aclnn entry — confirm.
    uint32_t inputDataTypeByte = DATATYPE_LEN_MAP[tilingcontext->GetInputDesc(SORTED_VALUE_INPUT_INDEX)->GetDataType()];
    uint32_t dataPerBlock = BLOCK_BYTES / inputDataTypeByte;
    // Initial chunk and per-iteration chunk are both capped at K_VALUE_MAX.
    dataNumInit_ = vocabSize_ < K_VALUE_MAX ? vocabSize_ : K_VALUE_MAX;
    dataNumInitAligned_ = vocabSize_ < K_VALUE_MAX ? vocabSize_ : K_VALUE_MAX;
    ubFactorElement_ = vocabSize_ < K_VALUE_MAX ? vocabSize_ : K_VALUE_MAX;
    ubFactorElementAligned_ = CeilAlign(ubFactorElement_, dataPerBlock);
    // Final partial chunk (falls back to a full chunk when vocab divides evenly).
    tailUbFactorElement_ = vocabSize_ % ubFactorElement_;
    tailUbFactorElement_ = tailUbFactorElement_ == uint32_t(0) ? ubFactorElement_ : tailUbFactorElement_;
    tailUbFactorElementAligned_ = CeilAlign(tailUbFactorElement_, dataPerBlock);
    // Fixed UB buffers: chunk + K_VALUE_MAX scratch for values/indices, one
    // block each for p and k, and the output chunk.
    uint32_t sortedValueBytes = ubFactorElementAligned_ * inputDataTypeByte + K_VALUE_MAX * inputDataTypeByte;
    uint32_t sortedIndicesBytes = ubFactorElementAligned_ * BYTES_B32 + K_VALUE_MAX * BYTES_B32;
    uint32_t pBytes = dataPerBlock * inputDataTypeByte;
    uint32_t kBytes = DATA_PER_BLOCK_B32 * BYTES_B32;
    uint32_t outTensorBytes = ubFactorElementAligned_ * inputDataTypeByte;
    calUbSize_ = calUbSize_ - sortedValueBytes - sortedIndicesBytes - pBytes - kBytes - outTensorBytes;
    // NOTE(review): the top-p-only path advertises the entire UB with no
    // reserved headroom or fixed buffers subtracted — confirm this matches the
    // kernel's expectation for that tiling key.
    if (onlyTopP_ > 0) {
        calUbSize_ = static_cast<uint32_t>(platformUbSize_);
    }
}
// Copy the computed tiling fields into the serialized tiling-data struct
// consumed by the kernel.
void ApplyTopKTopPCustomTiling::FillTilingData()
{
    tilingData.set_batchSize(batchSize_);
    tilingData.set_vocabSize(vocabSize_);
    tilingData.set_batchPerCore(batchPerCore_);
    tilingData.set_tailBatch(tailBatch_);
    tilingData.set_blockNum(usedCoreNum_);
    tilingData.set_dataNumInit(dataNumInit_);
    tilingData.set_dataNumInitAligned(dataNumInitAligned_);
    tilingData.set_ubFactorElement(ubFactorElement_);
    tilingData.set_ubFactorElementAligned(ubFactorElementAligned_);
    tilingData.set_tailUbFactorElement(tailUbFactorElement_);
    tilingData.set_tailUbFactorElementAligned(tailUbFactorElementAligned_);
    tilingData.set_calUbSize(calUbSize_);
    tilingData.set_iterateTimes(iterateTimes_);
}
// Debug-log every tiling field (compiled out unless debug logging is enabled).
void ApplyTopKTopPCustomTiling::PrintTilingData()
{
    OP_LOGD(opName_, "batchSize: %u.", tilingData.get_batchSize());
    OP_LOGD(opName_, "vocabSize: %u.", tilingData.get_vocabSize());
    OP_LOGD(opName_, "batchPerCore: %u.", tilingData.get_batchPerCore());
    OP_LOGD(opName_, "tailBatch: %u.", tilingData.get_tailBatch());
    OP_LOGD(opName_, "usedCoreNum: %u.", tilingData.get_blockNum());
    OP_LOGD(opName_, "dataNumInit_: %u.", tilingData.get_dataNumInit());
    OP_LOGD(opName_, "dataNumInitAligned_: %u.", tilingData.get_dataNumInitAligned());
    OP_LOGD(opName_, "ubFactorElement: %u.", tilingData.get_ubFactorElement());
    OP_LOGD(opName_, "ubFactorElementAligned: %u.", tilingData.get_ubFactorElementAligned());
    OP_LOGD(opName_, "tailUbFactorElement: %u.", tilingData.get_tailUbFactorElement());
    OP_LOGD(opName_, "tailUbFactorElementAligned: %u.", tilingData.get_tailUbFactorElementAligned());
    OP_LOGD(opName_, "calUbSize: %u.", tilingData.get_calUbSize());
    OP_LOGD(opName_, "iterateTimes: %u.", tilingData.get_iterateTimes());
}
// Compute all tiling fields, then publish the tiling key, serialized tiling
// data, block dim, and workspace requirement on the context.
ge::graphStatus ApplyTopKTopPCustomTiling::RunKernelTiling()
{
    OP_LOGD(opName_, "TilingForApplyTopKTopPCustom start.");
    SetTilingKey();
    GetUsedCore();
    CalDataPerCore();
    FillTilingData();
    PrintTilingData();
    OP_LOGD(opName_, "tilingKey: %u.", tilingKey_);
    uint32_t syncWorkspaceSize = SYS_WORKSPACESIZE;
    size_t* currentWorkspace = tilingcontext->GetWorkspaceSizes(1);
    // The top-p-only path needs extra workspace: batch * vocab * 4 bytes on
    // top of the fixed system workspace.
    // NOTE(review): the product is computed in uint32_t and could overflow for
    // very large batch * vocab — confirm upstream limits.
    currentWorkspace[0] = onlyTopP_ > 0 ? syncWorkspaceSize + batchSize_ * vocabSize_ * FLOAT_BYTES : syncWorkspaceSize;
    tilingData.SaveToBuffer(tilingcontext->GetRawTilingData()->GetData(),
                            tilingcontext->GetRawTilingData()->GetCapacity());
    tilingcontext->GetRawTilingData()->SetDataSize(tilingData.GetDataSize());
    tilingcontext->SetBlockDim(usedCoreNum_);
    OP_LOGD(opName_, "TilingForApplyTopKTopPCustom end.");
    return ge::GRAPH_SUCCESS;
}
// Registered tiling entry point: initialize the calculator, then run it.
static ge::graphStatus TilingForApplyTopKTopPCustom(gert::TilingContext* context)
{
    ApplyTopKTopPCustomTiling tilingObject(context);
    if (tilingObject.Init() != ge::GRAPH_SUCCESS) {
        OP_LOGE(context->GetNodeName(), "tiling Init failed.");
        return ge::GRAPH_FAILED;
    }
    const auto status = tilingObject.RunKernelTiling();
    OP_LOGD(context->GetNodeName(), "TilingForApplyTopKTopPCustom end.");
    return status;
}
// Compile-time preparation: caches core count and UB size into the compile
// info consumed later by the tiling function.
static ge::graphStatus TilingPrepareForApplyTopKTopPCustom(gert::TilingParseContext* context)
{
    OP_LOGD(context->GetNodeName(), "TilingPrepareForApplyTopKTopPCustom start");
    auto compileInfo = context->GetCompiledInfo<TilingForApplyTopKTopPCustomCompileInfo>();
    OP_CHECK_NULL_WITH_CONTEXT(context, compileInfo);
    auto platformInfo = context->GetPlatformInfo();
    OP_CHECK_NULL_WITH_CONTEXT(context, platformInfo);
    auto ascendcPlatform = platform_ascendc::PlatformAscendC(platformInfo);
    compileInfo->totalCoreNum = ascendcPlatform.GetCoreNumAiv();
    // Initialize the out-parameter so a failed query is detected as 0 below
    // (previously read uninitialized).
    uint64_t ubSizePlatForm = 0;
    ascendcPlatform.GetCoreMemSize(platform_ascendc::CoreMemType::UB, ubSizePlatForm);
    // The field is uint64_t; the old static_cast<int64_t> round-trip was a no-op.
    compileInfo->ubSizePlatForm = ubSizePlatForm;
    // `== 0` instead of `<= 0`: the field is unsigned, so only zero signals failure.
    OP_CHECK_IF(compileInfo->ubSizePlatForm == 0,
                OP_LOGE(context->GetNodeName(), "Failed to get ub size"),
                return ge::GRAPH_FAILED);
    OP_LOGD(context->GetNodeName(), "ub_size_platform is %lu", compileInfo->ubSizePlatForm);
    uint64_t totalUbSize = 0;
    platformInfo->GetLocalMemSize(fe::LocalMemType::UB, totalUbSize);
    OP_LOGD(context->GetNodeName(), "total ub size is %lu", totalUbSize);
    OP_LOGD(context->GetNodeName(), "TilingPrepareForApplyTopKTopPCustom end");
    return ge::GRAPH_SUCCESS;
}
// Register the tiling and compile-info-parse callbacks for ApplyTopKTopPCustom.
IMPL_OP_OPTILING(ApplyTopKTopPCustom)
    .Tiling(TilingForApplyTopKTopPCustom)
    .TilingParse<TilingForApplyTopKTopPCustomCompileInfo>(TilingPrepareForApplyTopKTopPCustom);
} // namespace optiling

View File

@@ -0,0 +1,51 @@
/**
* Copyright (c) 2025 Huawei Technologies Co., Ltd.
* This program is free software, you can redistribute it and/or modify it under the terms and conditions of
* CANN Open Software License Agreement Version 2.0 (the "License").
* Please refer to the License for details. You may not use this file except in compliance with the License.
* THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED,
* INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE.
* See LICENSE in the root of the software repository for the full text of the License.
*/
/*!
* \file apply_top_k_top_p_custom_tiling.h
* \brief
* ATTENTION: MAKE SURE 'BEGIN_TILING_DATA_DEF' STAY IN THE SAME LINE (28) USING BLANK LINES.
*
*
*
*
*
*/
#ifndef __APPLY_TOP_K_TOP_P_CUSTOM_TILINGDATA_H__
#define __APPLY_TOP_K_TOP_P_CUSTOM_TILINGDATA_H__
#include "register/tilingdata_base.h"
namespace optiling {
BEGIN_TILING_DATA_DEF(ApplyTopKTopPCustomTilingData)
TILING_DATA_FIELD_DEF(uint32_t, batchSize);                  // rows (batch) in logits
TILING_DATA_FIELD_DEF(uint32_t, vocabSize);                  // row length (vocabulary size)
TILING_DATA_FIELD_DEF(uint32_t, batchPerCore);               // whole batches per AIV core
TILING_DATA_FIELD_DEF(uint32_t, tailBatch);                  // remainder batches after even split
TILING_DATA_FIELD_DEF(uint32_t, blockNum);                   // number of cores actually used
TILING_DATA_FIELD_DEF(uint32_t, dataNumInit);                // initial chunk: min(vocabSize, 1024)
TILING_DATA_FIELD_DEF(uint32_t, dataNumInitAligned);         // initial chunk, block-aligned
TILING_DATA_FIELD_DEF(uint32_t, ubFactorElement);            // elements processed per iteration
TILING_DATA_FIELD_DEF(uint32_t, ubFactorElementAligned);     // per-iteration count, block-aligned
TILING_DATA_FIELD_DEF(uint32_t, tailUbFactorElement);        // elements in the final partial iteration
TILING_DATA_FIELD_DEF(uint32_t, tailUbFactorElementAligned); // tail count, block-aligned
TILING_DATA_FIELD_DEF(uint32_t, calUbSize);                  // UB bytes left for computation
TILING_DATA_FIELD_DEF(uint32_t, iterateTimes);               // ceil(log2(vocabSize))
END_TILING_DATA_DEF;
REGISTER_TILING_DATA_CLASS(ApplyTopKTopPCustom, ApplyTopKTopPCustomTilingData)
// Platform info cached at compile-parse time by TilingPrepareForApplyTopKTopPCustom.
struct TilingForApplyTopKTopPCustomCompileInfo {
    uint32_t totalCoreNum = 0;    // AIV core count
    uint64_t ubSizePlatForm = 0;  // UB size in bytes
};
} // namespace optiling
#endif // __APPLY_TOP_K_TOP_P_CUSTOM_TILINGDATA_H__

View File

@@ -0,0 +1,71 @@
#ifndef OPS_BUILT_IN_OP_TILING_ERROR_LOG_H_
#define OPS_BUILT_IN_OP_TILING_ERROR_LOG_H_
#include <string>
#include "toolchain/slog.h"
// Lightweight host-side logging shims; info/debug levels compile to nothing.
#define OP_LOGI(opname, ...)
// Fixed: the original macros passed the caller's format string and arguments
// as extra (ignored) printf arguments after the "[LEVEL][%s] " prefix, so the
// actual message was never printed. Print the prefix and message separately.
#define OP_LOGW(opname, ...)             \
    do {                                 \
        printf("[WARN][%s] ", (opname)); \
        printf(__VA_ARGS__);             \
        printf("\n");                    \
    } while (0)
#define OP_LOGE_WITHOUT_REPORT(opname, ...)  \
    do {                                     \
        printf("[ERRORx][%s] ", (opname));   \
        printf(__VA_ARGS__);                 \
        printf("\n");                        \
    } while (0)
#define OP_LOGE(opname, ...)              \
    do {                                  \
        printf("[ERROR][%s] ", (opname)); \
        printf(__VA_ARGS__);              \
        printf("\n");                     \
    } while (0)
#define OP_LOGD(opname, ...)
namespace optiling {
// Report a tiling error without triggering the framework's error-report path.
#define VECTOR_INNER_ERR_REPORT_TILIING(op_name, err_msg, ...) \
    do { \
        OP_LOGE_WITHOUT_REPORT(op_name, err_msg, ##__VA_ARGS__); \
    } while (0)
// If `cond` holds, execute `log_func` and then `expr` (typically a return).
#define OP_CHECK_IF(cond, log_func, expr) \
    do { \
        if (cond) { \
            log_func; \
            expr; \
        } \
    } while (0)
// Null-check for tiling contexts: logs the pointer name and fails the graph.
#define OP_CHECK_NULL_WITH_CONTEXT(context, ptr) \
    do { \
        if ((ptr) == nullptr) { \
            OP_LOGE(context->GetNodeType(), "%s is null", #ptr); \
            return ge::GRAPH_FAILED; \
        } \
    } while (0)
} // namespace optiling
// Round `a` up to the nearest multiple of `b`. Returns `a` unchanged when
// `b` is 0 — mirroring CeilDiv below and avoiding the division by zero the
// original (unguarded) version performed.
template <typename T>
T CeilAlign(T a, T b)
{
    if (b == 0) {
        return a;
    }
    return (a + b - 1) / b * b;
}
// Integer ceiling division; returns `a` unchanged when `b` is 0.
template <typename T>
T CeilDiv(T a, T b)
{
    return (b == 0) ? a : (a + b - 1) / b;
}
#endif // OPS_BUILT_IN_OP_TILING_ERROR_LOG_H_

View File

@@ -0,0 +1,26 @@
/**
* Copyright (c) 2025 Huawei Technologies Co., Ltd.
* This program is free software, you can redistribute it and/or modify it under the terms and conditions of
* CANN Open Software License Agreement Version 2.0 (the "License").
* Please refer to the License for details. You may not use this file except in compliance with the License.
* THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED,
* INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE.
* See LICENSE in the root of the software repository for the full text of the License.
*/
/*!
* \file sort.h
* \brief
*/
#ifndef PTA_NPU_OP_API_INC_LEVEL0_OP_SORT_OP_H_
#define PTA_NPU_OP_API_INC_LEVEL0_OP_SORT_OP_H_
#include "opdev/op_executor.h"
#include "opdev/fast_vector.h"
namespace l0op {
// Level-0 sort of `self` along `dim`, returning (sorted values, indices of
// dtype `indicesType`). The exact semantics of the `descending`/`stable`
// flags live in the corresponding implementation file — not visible here.
const std::tuple<aclTensor*, aclTensor*> Sort(const aclTensor* self, int64_t dim, bool descending, bool stable,
                                              op::DataType indicesType, aclOpExecutor* executor);
}
#endif // PTA_NPU_OP_API_INC_LEVEL0_OP_SORT_OP_H_