[ops] support advanced apply_top_k_top_p without top_k constraint (#6098)

### What this PR does / why we need it?
Implement `apply_top_k_top_p` via ascendC to eliminate the constraint of
k [1,1024]. It enables high performance TopKTopP calculation and avoid
D2H synchronization introduced by k validation.

### Does this PR introduce _any_ user-facing change?
No.

### How was this patch tested?
E2E serving with `k=4096` and  `p=0.95`
- vLLM version: v0.13.0
- vLLM main:
d68209402d

---------

Signed-off-by: linfeng-yuan <1102311262@qq.com>
Signed-off-by: SlightwindSec <slightwindsec@gmail.com>
Co-authored-by: SlightwindSec <slightwindsec@gmail.com>
This commit is contained in:
linfeng-yuan
2026-01-26 09:08:42 +08:00
committed by GitHub
parent 4e3919e965
commit 96309e2b79
16 changed files with 2208 additions and 3 deletions

View File

@@ -0,0 +1,41 @@
add_ops_compile_options(
OP_NAME ApplyTopKTopPCustom
OPTIONS --cce-auto-sync=on
-Wno-deprecated-declarations
-Werror
)
target_sources(op_host_aclnnExc PRIVATE
apply_top_k_top_p_custom_def.cpp
)
target_sources(opapi PRIVATE
apply_top_k_top_p_custom.cpp
aclnn_apply_top_k_top_p_custom.cpp
)
if (NOT BUILD_OPEN_PROJECT)
target_sources(aclnn_ops_train PRIVATE
apply_top_k_top_p_custom.cpp
aclnn_apply_top_k_top_p_custom.cpp
)
target_sources(aclnn_ops_infer PRIVATE
apply_top_k_top_p_custom.cpp
aclnn_apply_top_k_top_p_custom.cpp
)
endif ()
target_sources(optiling PRIVATE
apply_top_k_top_p_custom_tiling.cpp
)
target_include_directories(optiling PRIVATE
${CMAKE_CURRENT_SOURCE_DIR}
)
file(GLOB _GMM_Aclnn_header "${CMAKE_CURRENT_SOURCE_DIR}/aclnn_apply_top_k_top_p_custom.h")
install(FILES ${_GMM_Aclnn_header}
DESTINATION ${ACLNN_INC_INSTALL_DIR} OPTIONAL
)

View File

@@ -0,0 +1,213 @@
/**
* Copyright (c) 2025 Huawei Technologies Co., Ltd.
* This program is free software, you can redistribute it and/or modify it under the terms and conditions of
* CANN Open Software License Agreement Version 2.0 (the "License").
* Please refer to the License for details. You may not use this file except in compliance with the License.
* THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED,
* INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE.
* See LICENSE in the root of the software repository for the full text of the License.
*/
/*!
* \file aclnn_apply_top_k_top_p_custom.cpp
* \brief
*/
#include "aclnn_apply_top_k_top_p_custom.h"
#include "apply_top_k_top_p_custom.h"
#include "sort.h"
#include "aclnn_kernels/contiguous.h"
#include "aclnn_kernels/common/op_error_check.h"
#include "aclnn/aclnn_base.h"
#include "opdev/common_types.h"
#include "opdev/data_type_utils.h"
#include "opdev/make_op_executor.h"
#include "opdev/format_utils.h"
#include "opdev/op_dfx.h"
#include "opdev/op_executor.h"
#include "opdev/op_log.h"
#include "opdev/tensor_view_utils.h"
#include "opdev/shape_utils.h"
#include "opdev/platform.h"
using namespace op;
#ifdef __cplusplus
extern "C" {
#endif
namespace {
static const int64_t EXPECTED_DIM_ONE = 1;
static const int64_t EXPECTED_DIM_TWO = 2;
static constexpr size_t DIM_ONE = 1;
// 根据API定义需要列出所能支持的所有dtype
static const std::initializer_list<op::DataType> DTYPE_SUPPORT_LIST = {
op::DataType::DT_FLOAT, op::DataType::DT_FLOAT16, op::DataType::DT_BF16};
static const std::initializer_list<op::DataType> INT_DTYPE_SUPPORT_LIST = {
op::DataType::DT_INT32};
static bool CheckNotNull(const aclTensor* logits, const aclTensor* p, const aclTensor *k, const aclTensor* out)
{
OP_CHECK_NULL(logits, return false);
if (p == nullptr && k == nullptr) {
OP_LOGE(ACLNN_ERR_PARAM_INVALID, "The inputs, p and k, should not be nullptr at the same time.");
}
OP_CHECK_NULL(out, return false);
return true;
}
static bool CheckDtypeValid(const aclTensor* logits, const aclTensor* p, const aclTensor *k, const aclTensor* out)
{
// 检查数据类型是否在支持列表内
OP_CHECK_DTYPE_NOT_SUPPORT(logits, DTYPE_SUPPORT_LIST, return false);
if (p != nullptr) {
OP_CHECK_DTYPE_NOT_SUPPORT(p, DTYPE_SUPPORT_LIST, return false);
}
if (k != nullptr) {
OP_CHECK_DTYPE_NOT_SUPPORT(k, INT_DTYPE_SUPPORT_LIST, return false);
}
OP_CHECK_DTYPE_NOT_SUPPORT(out, DTYPE_SUPPORT_LIST, return false);
// 检查数据类型是否相同
if (p != nullptr) {
OP_CHECK_DTYPE_NOT_MATCH(p, logits->GetDataType(), return false);
}
OP_CHECK_DTYPE_NOT_MATCH(out, logits->GetDataType(), return false);
return true;
}
static bool CheckShapeValid(const aclTensor* logits, const aclTensor* p, const aclTensor *k, const aclTensor* out)
{
OP_CHECK_WRONG_DIMENSION(logits, EXPECTED_DIM_TWO, return false);
OP_CHECK_SHAPE_NOT_EQUAL(out, logits, return false);
if (p != nullptr) {
OP_CHECK_WRONG_DIMENSION(p, EXPECTED_DIM_ONE, return false);
}
if (k != nullptr) {
OP_CHECK_WRONG_DIMENSION(k, EXPECTED_DIM_ONE, return false);
}
if (p != nullptr && p->GetViewShape().GetDim(0) != logits->GetViewShape().GetDim(0)) {
OP_LOGE(ACLNN_ERR_PARAM_INVALID, "expected p.size(0) is equal to logits.size(0), but got %ld.",
p->GetViewShape().GetDim(0));
return false;
}
if (k != nullptr && k->GetViewShape().GetDim(0) != logits->GetViewShape().GetDim(0)) {
OP_LOGE(ACLNN_ERR_PARAM_INVALID, "expected k.size(0) is equal to logits.size(0), but got %ld.",
k->GetViewShape().GetDim(0));
return false;
}
return true;
}
static bool CheckFormatValid(const aclTensor* logits, const aclTensor* p, const aclTensor *k, const aclTensor* out)
{
if (logits->GetStorageFormat() != Format::FORMAT_ND) {
OP_LOGE(ACLNN_ERR_PARAM_INVALID, "logits format only support ND");
return false;
}
if (p != nullptr && p->GetStorageFormat() != Format::FORMAT_ND) {
OP_LOGE(ACLNN_ERR_PARAM_INVALID, "p format only support ND");
return false;
}
if (k != nullptr && k->GetStorageFormat() != Format::FORMAT_ND) {
OP_LOGE(ACLNN_ERR_PARAM_INVALID, "k format only support ND");
return false;
}
if (out->GetStorageFormat() != Format::FORMAT_ND) {
OP_LOGE(ACLNN_ERR_PARAM_INVALID, "out format only support ND");
return false;
}
return true;
}
static aclnnStatus CheckParams(const aclTensor* logits, const aclTensor* p, const aclTensor *k, const aclTensor* out)
{
// 错误码等DFX方案细化后刷新错误日志在check接口内打印
// 1. 检查参数是否为空指针
CHECK_RET(CheckNotNull(logits, p, k, out), ACLNN_ERR_PARAM_NULLPTR);
// 2. 检查输入的数据类型是否在API支持的数据类型范围之内需要根据api定义校验
CHECK_RET(CheckDtypeValid(logits, p, k, out), ACLNN_ERR_PARAM_INVALID);
// 3. 检查shape是否满足约束
CHECK_RET(CheckShapeValid(logits, p, k, out), ACLNN_ERR_PARAM_INVALID);
// 4. 检查format是否满足约束
CHECK_RET(CheckFormatValid(logits, p, k, out), ACLNN_ERR_PARAM_INVALID);
return ACLNN_SUCCESS;
}
} // namespace
aclnnStatus aclnnApplyTopKTopPCustomGetWorkspaceSize(
const aclTensor* logits, const aclTensor* p, const aclTensor* k, aclTensor* out, uint64_t* workspaceSize,
aclOpExecutor** executor)
{
OP_CHECK_COMM_INPUT(workspaceSize, executor);
L2_DFX_PHASE_1(aclnnApplyTopKTopPCustom, DFX_IN(logits, p, k), DFX_OUT(out));
// 固定写法创建OpExecutor
auto uniqueExecutor = CREATE_EXECUTOR();
CHECK_RET(uniqueExecutor.get() != nullptr, ACLNN_ERR_INNER_CREATE_EXECUTOR);
// 固定写法,参数检查
auto ret = CheckParams(logits, p, k, out);
CHECK_RET(ret == ACLNN_SUCCESS, ret);
bool pIsEmpty = false;
bool kIsEmpty = false;
if (p != nullptr) {
pIsEmpty = p->IsEmpty();
}
if (k != nullptr) {
kIsEmpty = k->IsEmpty();
}
if (logits->IsEmpty() || pIsEmpty || kIsEmpty) {
// 根据实际支持情况补充
*workspaceSize = 0;
uniqueExecutor.ReleaseTo(executor);
return ACLNN_SUCCESS;
}
// 固定写法将输入selfRef转换成连续的tensor
auto logitsContiguous = l0op::Contiguous(logits, uniqueExecutor.get());
CHECK_RET(logitsContiguous != nullptr, ACLNN_ERR_INNER_NULLPTR);
const aclTensor* pContiguous = nullptr;
const aclTensor* kContiguous = nullptr;
if (p != nullptr) {
pContiguous = l0op::Contiguous(p, uniqueExecutor.get());
CHECK_RET(pContiguous != nullptr, ACLNN_ERR_INNER_NULLPTR);
}
if (k != nullptr) {
kContiguous = l0op::Contiguous(k, uniqueExecutor.get());
CHECK_RET(kContiguous != nullptr, ACLNN_ERR_INNER_NULLPTR);
}
bool isLastDimSizeOne = logits->GetViewShape()[DIM_ONE] == 1;
auto viewCopyResult = logitsContiguous;
if (isLastDimSizeOne) {
viewCopyResult = l0op::ViewCopy(logitsContiguous, out, uniqueExecutor.get());
} else {
auto sortResult = l0op::Sort(logitsContiguous, -1, false, true, op::DataType::DT_INT32, uniqueExecutor.get());
const aclTensor* sortedValue = std::get<0>(sortResult);
CHECK_RET(sortedValue != nullptr, ACLNN_ERR_INNER_NULLPTR);
const aclTensor* sortedIndices = std::get<1>(sortResult);
CHECK_RET(sortedIndices != nullptr, ACLNN_ERR_INNER_NULLPTR);
auto res = l0op::ApplyTopKTopPCustom(sortedValue, sortedIndices, pContiguous, kContiguous, uniqueExecutor.get());
CHECK_RET(res != nullptr, ACLNN_ERR_INNER_NULLPTR);
// 固定写法将计算结果拷贝到输出out上out可能是非连续的tensor
viewCopyResult = l0op::ViewCopy(res, out, uniqueExecutor.get());
}
CHECK_RET(viewCopyResult != nullptr, ACLNN_ERR_INNER_NULLPTR);
// 固定写法获取计算过程中需要使用的workspace大小
*workspaceSize = uniqueExecutor->GetWorkspaceSize();
uniqueExecutor.ReleaseTo(executor); // 需要把 uniqueExecutor持有executor转移给executor
return ACLNN_SUCCESS;
}
aclnnStatus aclnnApplyTopKTopPCustom(void* workspace, uint64_t workspaceSize, aclOpExecutor* executor, aclrtStream stream)
{
L2_DFX_PHASE_2(aclnnApplyTopKTopPCustom);
// 固定写法,调用框架能力,完成计算
return CommonOpExecutorRun(workspace, workspaceSize, executor, stream);
}
#ifdef __cplusplus
}
#endif

View File

@@ -0,0 +1,54 @@
/**
* Copyright (c) 2025 Huawei Technologies Co., Ltd.
* This program is free software, you can redistribute it and/or modify it under the terms and conditions of
* CANN Open Software License Agreement Version 2.0 (the "License").
* Please refer to the License for details. You may not use this file except in compliance with the License.
* THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED,
* INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE.
* See LICENSE in the root of the software repository for the full text of the License.
*/
/*!
* \file aclnn_apply_top_k_top_p_custom.h
* \brief
*/
#ifndef OP_API_INC_APPLY_TOP_K_TOP_P_CUSTOM_H_
#define OP_API_INC_APPLY_TOP_K_TOP_P_CUSTOM_H_
#include "aclnn/aclnn_base.h"
#ifdef __cplusplus
extern "C" {
#endif
/**
* @brief aclnnApplyTopKTopPCustom的第一段接口根据具体的计算流程计算workspace大小。
* @domain aclnn_ops_infer
* @param [in] logits: npu device侧的aclTensor数据类型支持FLOAT、FLOAT16、BFLOAT16支持非连续的Tensor数据格式支持ND。
* @param [in] p: npu device侧的aclTensor数据类型支持FLOAT、FLOAT16、BFLOAT16支持非连续的Tensor数据格式支持ND。
* @param [in] k: npu device侧的aclTensor数据类型支持INT32支持非连续的Tensor数据格式支持ND。
* @param [in] out: npu device侧的aclTensor数据类型支持FLOAT、FLOAT16、BFLOAT16支持非连续的Tensor数据格式支持ND。
* @param [out] workspaceSize: 返回用户需要在npu device侧申请的workspace大小。
* @param [out] executor: 返回op执行器包含算子计算流程。
* @return aclnnStatus: 返回状态码。
*/
aclnnStatus aclnnApplyTopKTopPCustomGetWorkspaceSize(const aclTensor* logits, const aclTensor* p,
const aclTensor* k, aclTensor* out, uint64_t* workspaceSize,
aclOpExecutor** executor);
/**
* @brief aclnnApplyTopKTopPCustom的第二段接口用于执行计算。
* @param [in] workspace: 在npu device侧申请的workspace内存起址。
* @param [in] workspaceSize: 在npu device侧申请的workspace大小由第一段接口aaclnnApplyTopKTopPCustomGetWorkspaceSize获取。
* @param [in] stream: acl stream流。
* @param [in] executor: op执行器包含了算子计算流程。
* @return aclnnStatus: 返回状态码。
*/
aclnnStatus aclnnApplyTopKTopPCustom(void* workspace, uint64_t workspaceSize, aclOpExecutor* executor,
aclrtStream stream);
#ifdef __cplusplus
}
#endif
#endif // OP_API_INC_APPLY_TOP_K_TOP_P_CUSTOM_H_

View File

@@ -0,0 +1,46 @@
/**
* Copyright (c) 2025 Huawei Technologies Co., Ltd.
* This program is free software, you can redistribute it and/or modify it under the terms and conditions of
* CANN Open Software License Agreement Version 2.0 (the "License").
* Please refer to the License for details. You may not use this file except in compliance with the License.
* THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED,
* INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE.
* See LICENSE in the root of the software repository for the full text of the License.
*/
/*!
* \file apply_top_k_top_p_custom.cpp
* \brief
*/
#include "apply_top_k_top_p_custom.h"
#include "opdev/data_type_utils.h"
#include "opdev/format_utils.h"
#include "opdev/make_op_executor.h"
#include "opdev/op_def.h"
#include "opdev/op_dfx.h"
#include "opdev/op_executor.h"
#include "opdev/op_log.h"
#include "opdev/shape_utils.h"
using namespace op;
namespace l0op {
OP_TYPE_REGISTER(ApplyTopKTopPCustom);
const aclTensor* ApplyTopKTopPCustom(
const aclTensor* sortedValue, const aclTensor* sortedIndices, const aclTensor* p, const aclTensor* k,
aclOpExecutor* executor)
{
L0_DFX(ApplyTopKTopPCustom, sortedValue, sortedIndices, p, k);
auto output = executor->AllocTensor(sortedValue->GetViewShape(), sortedValue->GetDataType());
if (p == nullptr) {
p = executor->AllocTensor(sortedValue->GetDataType(), Format::FORMAT_ND, Format::FORMAT_ND);
}
if (k == nullptr) {
k = executor->AllocTensor(DataType::DT_INT32, Format::FORMAT_ND, Format::FORMAT_ND);
}
ADD_TO_LAUNCHER_LIST_AICORE(ApplyTopKTopPCustom, OP_INPUT(sortedValue, sortedIndices, p, k), OP_OUTPUT(output));
return output;
}
} // namespace l0op

View File

@@ -0,0 +1,24 @@
/**
* Copyright (c) 2025 Huawei Technologies Co., Ltd.
* This program is free software, you can redistribute it and/or modify it under the terms and conditions of
* CANN Open Software License Agreement Version 2.0 (the "License").
* Please refer to the License for details. You may not use this file except in compliance with the License.
* THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED,
* INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE.
* See LICENSE in the root of the software repository for the full text of the License.
*/
/*!
* \file apply_top_k_top_p_custom.h
* \brief
*/
#ifndef OP_API_INC_LEVEL0_OP_APPLY_TOP_K_TOP_P_CUSTOM_OP_H_
#define OP_API_INC_LEVEL0_OP_APPLY_TOP_K_TOP_P_CUSTOM_OP_H_
#include "opdev/op_executor.h"
namespace l0op {
const aclTensor* ApplyTopKTopPCustom(const aclTensor* sortedValue, const aclTensor* sortedIndices,
const aclTensor* p, const aclTensor* k, aclOpExecutor* executor);
}
#endif // OP_API_INC_LEVEL0_OP_APPLY_TOP_K_TOP_P_CUSTOM_OP_H_

View File

@@ -0,0 +1,100 @@
/**
* Copyright (c) 2025 Huawei Technologies Co., Ltd.
* This program is free software, you can redistribute it and/or modify it under the terms and conditions of
* CANN Open Software License Agreement Version 2.0 (the "License").
* Please refer to the License for details. You may not use this file except in compliance with the License.
* THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED,
* INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE.
* See LICENSE in the root of the software repository for the full text of the License.
*/
/*!
* \file apply_top_k_top_p_custom_def.cpp
* \brief
*/
#include "register/op_def_registry.h"
namespace ops {
class ApplyTopKTopPCustom : public OpDef {
public:
explicit ApplyTopKTopPCustom(const char *name) : OpDef(name)
{
this->Input("sorted_value")
.ParamType(REQUIRED)
.DataType({ge::DT_FLOAT, ge::DT_FLOAT16, ge::DT_BF16})
.Format({ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND})
.UnknownShapeFormat({ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND});
this->Input("sorted_indices")
.ParamType(REQUIRED)
.DataType({ge::DT_INT32, ge::DT_INT32, ge::DT_INT32})
.Format({ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND})
.UnknownShapeFormat({ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND});
this->Input("p")
.ParamType(OPTIONAL)
.DataType({ge::DT_FLOAT, ge::DT_FLOAT16, ge::DT_BF16})
.Format({ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND})
.UnknownShapeFormat({ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND});
this->Input("k")
.ParamType(OPTIONAL)
.DataType({ge::DT_INT32, ge::DT_INT32, ge::DT_INT32})
.Format({ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND})
.UnknownShapeFormat({ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND});
this->Output("out")
.ParamType(REQUIRED)
.DataType({ge::DT_FLOAT, ge::DT_FLOAT16, ge::DT_BF16})
.Format({ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND})
.UnknownShapeFormat({ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND});
OpAICoreConfig aicore_config;
aicore_config.DynamicCompileStaticFlag(true)
.DynamicFormatFlag(false)
.DynamicRankSupportFlag(true)
.DynamicShapeSupportFlag(true);
this->AICore().AddConfig("ascend910b");
this->AICore().AddConfig("ascend910_93");
OpAICoreConfig config_kirin = GetKirinCoreConfig();
this->AICore().AddConfig("kirinx90", config_kirin);
}
private:
OpAICoreConfig GetKirinCoreConfig() const
{
OpAICoreConfig config_kirin;
config_kirin.DynamicCompileStaticFlag(true)
.DynamicFormatFlag(true)
.DynamicRankSupportFlag(true)
.DynamicShapeSupportFlag(true)
.NeedCheckSupportFlag(false)
.PrecisionReduceFlag(true);
config_kirin.Input("sorted_value")
.ParamType(REQUIRED)
.DataType({ge::DT_FLOAT, ge::DT_FLOAT16})
.Format({ge::FORMAT_ND, ge::FORMAT_ND})
.UnknownShapeFormat({ge::FORMAT_ND, ge::FORMAT_ND});
config_kirin.Input("sorted_indices")
.ParamType(REQUIRED)
.DataType({ge::DT_INT32, ge::DT_INT32})
.Format({ge::FORMAT_ND, ge::FORMAT_ND})
.UnknownShapeFormat({ge::FORMAT_ND, ge::FORMAT_ND});
config_kirin.Input("p")
.ParamType(OPTIONAL)
.DataType({ge::DT_FLOAT, ge::DT_FLOAT16})
.Format({ge::FORMAT_ND, ge::FORMAT_ND})
.UnknownShapeFormat({ge::FORMAT_ND, ge::FORMAT_ND});
config_kirin.Input("k")
.ParamType(OPTIONAL)
.DataType({ge::DT_INT32, ge::DT_INT32})
.Format({ge::FORMAT_ND, ge::FORMAT_ND})
.UnknownShapeFormat({ge::FORMAT_ND, ge::FORMAT_ND});
config_kirin.Output("out")
.ParamType(REQUIRED)
.DataType({ge::DT_FLOAT, ge::DT_FLOAT16})
.Format({ge::FORMAT_ND, ge::FORMAT_ND})
.UnknownShapeFormat({ge::FORMAT_ND, ge::FORMAT_ND});
return config_kirin;
}
};
OP_ADD(ApplyTopKTopPCustom);
} // namespace ops

View File

@@ -0,0 +1,314 @@
/**
* Copyright (c) 2025 Huawei Technologies Co., Ltd.
* This program is free software, you can redistribute it and/or modify it under the terms and conditions of
* CANN Open Software License Agreement Version 2.0 (the "License").
* Please refer to the License for details. You may not use this file except in compliance with the License.
* THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED,
* INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE.
* See LICENSE in the root of the software repository for the full text of the License.
*/
/*!
* \file apply_top_k_top_p_custom_tiling.cpp
* \brief
*/
#include <iostream>
#include <map>
#include "error_log.h"
#include "tiling/platform/platform_ascendc.h"
#include "platform/platform_infos_def.h"
#include "register/op_def_registry.h"
#include "register/op_impl_registry.h"
#include "tiling/tiling_api.h"
#include "apply_top_k_top_p_custom_tiling.h"
namespace {
constexpr uint32_t SYS_RESERVED_UB = uint32_t(16 * 1024);
constexpr uint32_t SELECT_RESERVED_UB = uint32_t(8 * 1024);
constexpr uint32_t DIM_ONE = 1;
constexpr uint32_t DIM_TWO = 2;
constexpr int32_t SORTED_VALUE_INPUT_INDEX = 0;
constexpr int32_t SORTED_INDICES_INPUT_INDEX = 1;
constexpr int32_t P_INPUT_INDEX = 2;
constexpr int32_t K_INPUT_INDEX = 3;
constexpr uint32_t DIM_INDEX0 = 0;
constexpr uint32_t FLOAT_BYTES = 4;
static std::map<ge::DataType, uint32_t> DTYPE_MAP = {{ge::DT_BF16, 2}, {ge::DT_FLOAT16, 1}, {ge::DT_FLOAT, 0}};
static std::map<ge::DataType, uint32_t> DATATYPE_LEN_MAP = {
{ge::DT_FLOAT16, 2}, {ge::DT_BF16, 2}, {ge::DT_FLOAT, 4}};
const static uint32_t SYS_WORKSPACESIZE = uint32_t(16 * 1024 * 1024);
constexpr uint32_t DATA_PER_BLOCK_B32 = 8;
constexpr uint32_t BYTES_B32 = 4;
constexpr uint32_t BLOCK_BYTES = 32;
constexpr uint32_t K_VALUE_MAX = 1024;
constexpr uint32_t ONLY_TOP_P_KEY = 2;
constexpr uint32_t ONLY_TOP_K_KEY = 1;
constexpr uint32_t BATCH_MODE = 1;
} // namespace
namespace optiling {
class ApplyTopKTopPCustomTiling {
public:
explicit ApplyTopKTopPCustomTiling(gert::TilingContext* context) : tilingcontext(context){};
ge::graphStatus Init();
ge::graphStatus RunKernelTiling();
private:
ApplyTopKTopPCustomTilingData tilingData;
gert::TilingContext* tilingcontext = nullptr;
ge::graphStatus CheckShape();
void SetTilingKey();
void GetUsedCore();
void CalDataPerCore();
void FillTilingData();
void PrintTilingData();
template <typename T1>
inline auto CeilAlign(T1 a, T1 b) const -> T1
{
return b == 0 ? a : (a + b - 1) / b * b;
}
template <typename T1>
inline auto FloorAlign(T1 a, T1 b) const -> T1
{
return b == 0 ? a : a / b * b;
}
const char *opName_ = nullptr;
uint32_t coreNum_ = 0;
uint32_t calUbSize_ = 0;
uint32_t batchSize_ = 0;
uint32_t vocabSize_ = 0;
uint32_t tilingKey_ = 0;
uint32_t usedCoreNum_ = 0;
uint32_t batchPerCore_ = 1;
uint32_t tailBatch_ = 0;
uint32_t dataNumInit_ = 0;
uint32_t dataNumInitAligned_ = 0;
uint32_t ubFactorElement_ = 0;
uint32_t ubFactorElementAligned_ = 0;
uint32_t tailUbFactorElement_ = 0;
uint32_t tailUbFactorElementAligned_ = 0;
uint32_t iterateTimes_ = 0;
uint32_t onlyTopK_ = 0;
uint32_t onlyTopP_ = 0;
uint64_t platformUbSize_ = 0;
};
ge::graphStatus ApplyTopKTopPCustomTiling::CheckShape() {
auto sortedValueShapePtr = tilingcontext->GetInputShape(SORTED_VALUE_INPUT_INDEX);
OP_CHECK_NULL_WITH_CONTEXT(tilingcontext, sortedValueShapePtr);
auto sortedValueShape = sortedValueShapePtr->GetStorageShape();
if (sortedValueShape.GetDimNum() != DIM_TWO) {
OP_LOGE(opName_, "the dimNum of sorted_value should be 2, but got %u.", sortedValueShape.GetDimNum());
return ge::GRAPH_FAILED;
}
auto sortedIndicesShapePtr = tilingcontext->GetInputShape(SORTED_INDICES_INPUT_INDEX);
OP_CHECK_NULL_WITH_CONTEXT(tilingcontext, sortedIndicesShapePtr);
auto sortedIndicesShape = sortedIndicesShapePtr->GetStorageShape();
if (sortedIndicesShape.GetDimNum() != DIM_TWO) {
OP_LOGE(opName_, "the dimNum of sorted_indices should be 2, but got %u.", sortedIndicesShape.GetDimNum());
return ge::GRAPH_FAILED;
}
batchSize_ = sortedValueShape.GetDim(DIM_INDEX0);
vocabSize_ = sortedValueShape.GetDim(DIM_ONE);
if (sortedIndicesShape.GetDim(DIM_INDEX0) != batchSize_ || sortedIndicesShape.GetDim(DIM_ONE) != vocabSize_) {
OP_LOGE(opName_, "the shape of sorted_indices should be equal to sorted_value.");
return ge::GRAPH_FAILED;
}
auto pShapePtr = tilingcontext->GetOptionalInputShape(P_INPUT_INDEX);
OP_CHECK_NULL_WITH_CONTEXT(tilingcontext, pShapePtr);
auto pShape = pShapePtr->GetStorageShape();
auto pDimNum = pShape.GetDimNum();
if (pDimNum != DIM_ONE && pDimNum != 0) {
OP_LOGE(opName_, "the dimNum of p should be 1 or 0, but got %u.", pDimNum);
return ge::GRAPH_FAILED;
}
if (pDimNum != 0 && batchSize_ != pShape.GetDim(DIM_INDEX0)) {
OP_LOGE(opName_, "p.shape[0] should be equal to logits.shape[0].");
return ge::GRAPH_FAILED;
}
auto kShapePtr = tilingcontext->GetOptionalInputShape(K_INPUT_INDEX);
OP_CHECK_NULL_WITH_CONTEXT(tilingcontext, kShapePtr);
auto kShape = kShapePtr->GetStorageShape();
auto kDimNum = kShape.GetDimNum();
if (kDimNum != DIM_ONE && kDimNum != 0) {
OP_LOGE(opName_, "the dimNum of k should be 1 or 0, but got %u.", kShape.GetDimNum());
return ge::GRAPH_FAILED;
}
if (kDimNum != 0 && batchSize_ != kShape.GetDim(DIM_INDEX0)) {
OP_LOGE(opName_, "k.shape[0] should be equal to logits.shape[0].");
return ge::GRAPH_FAILED;
}
if (kDimNum == 0 && pDimNum == 0) {
OP_LOGE(opName_, "the dimNum of q and k should be 0 at the same time.");
return ge::GRAPH_FAILED;
}
onlyTopK_ = (kDimNum != 0 && pDimNum == 0) ? ONLY_TOP_K_KEY : 0;
onlyTopP_ = (pDimNum != 0 && kDimNum == 0) ? ONLY_TOP_P_KEY : 0;
return ge::GRAPH_SUCCESS;
}
ge::graphStatus ApplyTopKTopPCustomTiling::Init() {
opName_ = tilingcontext->GetNodeName();
OP_LOGD(opName_, "TilingForApplyTopKTopPCustom init.");
auto platformInfo = platform_ascendc::PlatformAscendC(tilingcontext->GetPlatformInfo());
coreNum_ = platformInfo.GetCoreNumAiv();
platformInfo.GetCoreMemSize(platform_ascendc::CoreMemType::UB, platformUbSize_);
OP_LOGD(opName_, "platformUbSize: %lu.", platformUbSize_);
uint32_t avaliableUb = static_cast<uint32_t>(platformUbSize_) - SYS_RESERVED_UB - SELECT_RESERVED_UB;
calUbSize_ = FloorAlign(avaliableUb, BLOCK_BYTES);
if (CheckShape() == ge::GRAPH_FAILED) {
OP_LOGE(opName_, "check shape failed.");
return ge::GRAPH_FAILED;
}
uint32_t tempValue = 1;
while (tempValue < vocabSize_) {
tempValue <<= 1;
iterateTimes_++;
} // ceil(log2(vocabSize_))
return ge::GRAPH_SUCCESS;
}
void ApplyTopKTopPCustomTiling::SetTilingKey() {
tilingKey_ += onlyTopK_;
tilingKey_ += onlyTopP_;
tilingcontext->SetTilingKey(tilingKey_);
if (tilingKey_ == ONLY_TOP_P_KEY){
tilingcontext->SetScheduleMode(BATCH_MODE);
}
}
void ApplyTopKTopPCustomTiling::GetUsedCore()
{
if (coreNum_ > 0) {
batchPerCore_ = coreNum_ == uint32_t(0) ? batchSize_ : batchSize_ / coreNum_;
tailBatch_ = batchSize_ % coreNum_;
usedCoreNum_ = coreNum_;
}
}
void ApplyTopKTopPCustomTiling::CalDataPerCore()
{
uint32_t inputDataTypeByte = DATATYPE_LEN_MAP[tilingcontext->GetInputDesc(SORTED_VALUE_INPUT_INDEX)->GetDataType()];
uint32_t dataPerBlock = BLOCK_BYTES / inputDataTypeByte;
dataNumInit_ = vocabSize_ < K_VALUE_MAX ? vocabSize_ : K_VALUE_MAX;
dataNumInitAligned_ = vocabSize_ < K_VALUE_MAX ? vocabSize_ : K_VALUE_MAX;
ubFactorElement_ = vocabSize_ < K_VALUE_MAX ? vocabSize_ : K_VALUE_MAX;
ubFactorElementAligned_ = CeilAlign(ubFactorElement_, dataPerBlock);
tailUbFactorElement_ = vocabSize_ % ubFactorElement_;
tailUbFactorElement_ = tailUbFactorElement_ == uint32_t(0) ? ubFactorElement_ : tailUbFactorElement_;
tailUbFactorElementAligned_ = CeilAlign(tailUbFactorElement_, dataPerBlock);
uint32_t sortedValueBytes = ubFactorElementAligned_ * inputDataTypeByte + K_VALUE_MAX * inputDataTypeByte;
uint32_t sortedIndicesBytes = ubFactorElementAligned_ * BYTES_B32 + K_VALUE_MAX * BYTES_B32;
uint32_t pBytes = dataPerBlock * inputDataTypeByte;
uint32_t kBytes = DATA_PER_BLOCK_B32 * BYTES_B32;
uint32_t outTensorBytes = ubFactorElementAligned_ * inputDataTypeByte;
calUbSize_ = calUbSize_ - sortedValueBytes - sortedIndicesBytes - pBytes - kBytes - outTensorBytes;
if (onlyTopP_ > 0) {
calUbSize_ = static_cast<uint32_t>(platformUbSize_);
}
}
void ApplyTopKTopPCustomTiling::FillTilingData()
{
tilingData.set_batchSize(batchSize_);
tilingData.set_vocabSize(vocabSize_);
tilingData.set_batchPerCore(batchPerCore_);
tilingData.set_tailBatch(tailBatch_);
tilingData.set_blockNum(usedCoreNum_);
tilingData.set_dataNumInit(dataNumInit_);
tilingData.set_dataNumInitAligned(dataNumInitAligned_);
tilingData.set_ubFactorElement(ubFactorElement_);
tilingData.set_ubFactorElementAligned(ubFactorElementAligned_);
tilingData.set_tailUbFactorElement(tailUbFactorElement_);
tilingData.set_tailUbFactorElementAligned(tailUbFactorElementAligned_);
tilingData.set_calUbSize(calUbSize_);
tilingData.set_iterateTimes(iterateTimes_);
}
void ApplyTopKTopPCustomTiling::PrintTilingData()
{
OP_LOGD(opName_, "batchSize: %u.", tilingData.get_batchSize());
OP_LOGD(opName_, "vocabSize: %u.", tilingData.get_vocabSize());
OP_LOGD(opName_, "batchPerCore: %u.", tilingData.get_batchPerCore());
OP_LOGD(opName_, "tailBatch: %u.", tilingData.get_tailBatch());
OP_LOGD(opName_, "usedCoreNum: %u.", tilingData.get_blockNum());
OP_LOGD(opName_, "dataNumInit_: %u.", tilingData.get_dataNumInit());
OP_LOGD(opName_, "dataNumInitAligned_: %u.", tilingData.get_dataNumInitAligned());
OP_LOGD(opName_, "ubFactorElement: %u.", tilingData.get_ubFactorElement());
OP_LOGD(opName_, "ubFactorElementAligned: %u.", tilingData.get_ubFactorElementAligned());
OP_LOGD(opName_, "tailUbFactorElement: %u.", tilingData.get_tailUbFactorElement());
OP_LOGD(opName_, "tailUbFactorElementAligned: %u.", tilingData.get_tailUbFactorElementAligned());
OP_LOGD(opName_, "calUbSize: %u.", tilingData.get_calUbSize());
OP_LOGD(opName_, "iterateTimes: %u.", tilingData.get_iterateTimes());
}
ge::graphStatus ApplyTopKTopPCustomTiling::RunKernelTiling()
{
OP_LOGD(opName_, "TilingForApplyTopKTopPCustom start.");
SetTilingKey();
GetUsedCore();
CalDataPerCore();
FillTilingData();
PrintTilingData();
OP_LOGD(opName_, "tilingKey: %u.", tilingKey_);
uint32_t syncWorkspaceSize = SYS_WORKSPACESIZE;
size_t* currentWorkspace = tilingcontext->GetWorkspaceSizes(1);
currentWorkspace[0] = onlyTopP_ > 0 ? syncWorkspaceSize + batchSize_ * vocabSize_ * FLOAT_BYTES : syncWorkspaceSize;
tilingData.SaveToBuffer(tilingcontext->GetRawTilingData()->GetData(),
tilingcontext->GetRawTilingData()->GetCapacity());
tilingcontext->GetRawTilingData()->SetDataSize(tilingData.GetDataSize());
tilingcontext->SetBlockDim(usedCoreNum_);
OP_LOGD(opName_, "TilingForApplyTopKTopPCustom end.");
return ge::GRAPH_SUCCESS;
}
static ge::graphStatus TilingForApplyTopKTopPCustom(gert::TilingContext* context)
{
ApplyTopKTopPCustomTiling tilingObject(context);
auto ret = tilingObject.Init();
if (ret != ge::GRAPH_SUCCESS) {
OP_LOGE(context->GetNodeName(), "tiling Init failed.");
return ge::GRAPH_FAILED;
}
ret = tilingObject.RunKernelTiling();
OP_LOGD(context->GetNodeName(), "TilingForApplyTopKTopPCustom end.");
return ret;
}
static ge::graphStatus TilingPrepareForApplyTopKTopPCustom(gert::TilingParseContext* context)
{
OP_LOGD(context->GetNodeName(), "TilingPrepareForApplyTopKTopPCustom start");
auto compileInfo = context->GetCompiledInfo<TilingForApplyTopKTopPCustomCompileInfo>();
OP_CHECK_NULL_WITH_CONTEXT(context, compileInfo);
auto platformInfo = context->GetPlatformInfo();
OP_CHECK_NULL_WITH_CONTEXT(context, platformInfo);
auto ascendcPlatform = platform_ascendc::PlatformAscendC(platformInfo);
compileInfo->totalCoreNum = ascendcPlatform.GetCoreNumAiv();
uint64_t ubSizePlatForm;
ascendcPlatform.GetCoreMemSize(platform_ascendc::CoreMemType::UB, ubSizePlatForm);
compileInfo->ubSizePlatForm = static_cast<int64_t>(ubSizePlatForm);
OP_CHECK_IF(compileInfo->ubSizePlatForm <= 0,
OP_LOGE(context->GetNodeName(), "Failed to get ub size"),
return ge::GRAPH_FAILED);
OP_LOGD(context->GetNodeName(), "ub_size_platform is %lu", compileInfo->ubSizePlatForm);
uint64_t totalUbSize = 0;
platformInfo->GetLocalMemSize(fe::LocalMemType::UB, totalUbSize);
OP_LOGD(context->GetNodeName(), "total ub size is %lu", totalUbSize);
OP_LOGD(context->GetNodeName(), "TilingPrepareForApplyTopKTopPCustom end");
return ge::GRAPH_SUCCESS;
}
IMPL_OP_OPTILING(ApplyTopKTopPCustom)
.Tiling(TilingForApplyTopKTopPCustom)
.TilingParse<TilingForApplyTopKTopPCustomCompileInfo>(TilingPrepareForApplyTopKTopPCustom);
} // namespace optiling

View File

@@ -0,0 +1,51 @@
/**
* Copyright (c) 2025 Huawei Technologies Co., Ltd.
* This program is free software, you can redistribute it and/or modify it under the terms and conditions of
* CANN Open Software License Agreement Version 2.0 (the "License").
* Please refer to the License for details. You may not use this file except in compliance with the License.
* THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED,
* INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE.
* See LICENSE in the root of the software repository for the full text of the License.
*/
/*!
* \file apply_top_k_top_p_custom_tiling.h
* \brief
* ATTENTION: MAKE SURE 'BEGIN_TILING_DATA_DEF' STAY IN THE SAME LINE (28) USING BLANK LINES.
*
*
*
*
*
*/
#ifndef __APPLY_TOP_K_TOP_P_CUSTOM_TILINGDATA_H__
#define __APPLY_TOP_K_TOP_P_CUSTOM_TILINGDATA_H__
#include "register/tilingdata_base.h"
namespace optiling {
BEGIN_TILING_DATA_DEF(ApplyTopKTopPCustomTilingData)
TILING_DATA_FIELD_DEF(uint32_t, batchSize);
TILING_DATA_FIELD_DEF(uint32_t, vocabSize);
TILING_DATA_FIELD_DEF(uint32_t, batchPerCore);
TILING_DATA_FIELD_DEF(uint32_t, tailBatch);
TILING_DATA_FIELD_DEF(uint32_t, blockNum);
TILING_DATA_FIELD_DEF(uint32_t, dataNumInit);
TILING_DATA_FIELD_DEF(uint32_t, dataNumInitAligned);
TILING_DATA_FIELD_DEF(uint32_t, ubFactorElement);
TILING_DATA_FIELD_DEF(uint32_t, ubFactorElementAligned);
TILING_DATA_FIELD_DEF(uint32_t, tailUbFactorElement);
TILING_DATA_FIELD_DEF(uint32_t, tailUbFactorElementAligned);
TILING_DATA_FIELD_DEF(uint32_t, calUbSize);
TILING_DATA_FIELD_DEF(uint32_t, iterateTimes);
END_TILING_DATA_DEF;
REGISTER_TILING_DATA_CLASS(ApplyTopKTopPCustom, ApplyTopKTopPCustomTilingData)
struct TilingForApplyTopKTopPCustomCompileInfo {
uint32_t totalCoreNum = 0;
uint64_t ubSizePlatForm = 0;
};
} // namespace optiling
#endif // __APPLY_TOP_K_TOP_P_CUSTOM_TILINGDATA_H__

View File

@@ -0,0 +1,71 @@
#ifndef OPS_BUILT_IN_OP_TILING_ERROR_LOG_H_
#define OPS_BUILT_IN_OP_TILING_ERROR_LOG_H_
#include <string>
#include "toolchain/slog.h"
#define OP_LOGI(opname, ...)
#define OP_LOGW(opname, ...) \
do { \
printf("[WARN][%s] ", (opname), ##__VA_ARGS__); \
printf("\n"); \
} while (0)
#define OP_LOGE_WITHOUT_REPORT(opname, ...) \
do { \
printf("[ERRORx][%s] ", (opname), ##__VA_ARGS__); \
printf("\n"); \
} while (0)
#define OP_LOGE(opname, ...) \
do { \
printf("[ERROR][%s] ", (opname), ##__VA_ARGS__); \
printf("\n"); \
} while (0)
#define OP_LOGD(opname, ...)
namespace optiling {
#define VECTOR_INNER_ERR_REPORT_TILIING(op_name, err_msg, ...) \
do { \
OP_LOGE_WITHOUT_REPORT(op_name, err_msg, ##__VA_ARGS__); \
} while (0)
#define OP_CHECK_IF(cond, log_func, expr) \
do { \
if (cond) { \
log_func; \
expr; \
} \
} while (0)
#define OP_CHECK_NULL_WITH_CONTEXT(context, ptr) \
do { \
if ((ptr) == nullptr) { \
OP_LOGE(context->GetNodeType(), "%s is null", #ptr); \
return ge::GRAPH_FAILED; \
} \
} while (0)
} // namespace optiling
template <typename T>
T CeilAlign(T a, T b)
{
return (a + b - 1) / b * b;
}
template <typename T>
T CeilDiv(T a, T b)
{
if (b == 0) {
return a;
}
return (a + b - 1) / b;
}
#endif // OPS_BUILT_IN_OP_TILING_ERROR_LOG_H_

View File

@@ -0,0 +1,26 @@
/**
* Copyright (c) 2025 Huawei Technologies Co., Ltd.
* This program is free software, you can redistribute it and/or modify it under the terms and conditions of
* CANN Open Software License Agreement Version 2.0 (the "License").
* Please refer to the License for details. You may not use this file except in compliance with the License.
* THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED,
* INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE.
* See LICENSE in the root of the software repository for the full text of the License.
*/
/*!
* \file sort.h
* \brief
*/
#ifndef PTA_NPU_OP_API_INC_LEVEL0_OP_SORT_OP_H_
#define PTA_NPU_OP_API_INC_LEVEL0_OP_SORT_OP_H_
#include "opdev/op_executor.h"
#include "opdev/fast_vector.h"
namespace l0op {
const std::tuple<aclTensor*, aclTensor*> Sort(const aclTensor* self, int64_t dim, bool descending, bool stable,
op::DataType indicesType, aclOpExecutor* executor);
}
#endif // PTA_NPU_OP_API_INC_LEVEL0_OP_SORT_OP_H_

View File

@@ -0,0 +1,42 @@
/**
* Copyright (c) 2025 Huawei Technologies Co., Ltd.
* This program is free software, you can redistribute it and/or modify it under the terms and conditions of
* CANN Open Software License Agreement Version 2.0 (the "License").
* Please refer to the License for details. You may not use this file except in compliance with the License.
* THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED,
* INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE.
* See LICENSE in the root of the software repository for the full text of the License.
*/
/*!
* \file apply_top_k_top_p_custom.cpp
* \brief
*/
#include "apply_top_k_top_p_custom.h"
#include "apply_top_p_custom.h"
using namespace AscendC;
using namespace ApplyTopKTopPCustomOp;
using namespace ApplyTopPCustomOp;
extern "C" __global__ __aicore__ void apply_top_k_top_p_custom(GM_ADDR sorted_value, GM_ADDR sorted_indices,
GM_ADDR p, GM_ADDR k, GM_ADDR out, GM_ADDR workSpace, GM_ADDR tiling) {
TPipe pipe;
GET_TILING_DATA(tilingData, tiling);
if (TILING_KEY_IS(0)) {
ApplyTopKTopPCustomOp::ApplyTopKTopPCustom<DTYPE_OUT, float, DTYPE_OUT> op;
op.InitTilingData(tilingData, sorted_value, sorted_indices, p, k, out);
op.InitBuffer(&pipe);
op.Process();
} else if (TILING_KEY_IS(1)) {
ApplyTopKTopPCustomOp::ApplyTopKTopPCustom<DTYPE_OUT, float, DTYPE_OUT> op;
op.InitTilingData(tilingData, sorted_value, sorted_indices, p, k, out);
op.InitBuffer(&pipe);
op.ProcessTopK();
} else if (TILING_KEY_IS(2)) {
ApplyTopPCustomOp::ApplyTopPCustom<DTYPE_OUT, float, DTYPE_OUT> op;
op.InitTilingData(tilingData, sorted_value, sorted_indices, p, k, out, workSpace);
op.InitBuffer(&pipe);
op.ProcessTopP();
}
}

View File

@@ -0,0 +1,719 @@
/**
* Copyright (c) 2025 Huawei Technologies Co., Ltd.
* This program is free software, you can redistribute it and/or modify it under the terms and conditions of
* CANN Open Software License Agreement Version 2.0 (the "License").
* Please refer to the License for details. You may not use this file except in compliance with the License.
* THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED,
* INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE.
* See LICENSE in the root of the software repository for the full text of the License.
*/
/*!
* \file apply_top_k_top_p_custom.h
* \brief
*/
#ifndef APPLY_TOP_K_TOP_P_CUSTOM_H_KERNEL
#define APPLY_TOP_K_TOP_P_CUSTOM_H_KERNEL
#include "kernel_operator.h"
using namespace AscendC;
namespace ApplyTopKTopPCustomOp {
constexpr uint32_t BUFFER_NUM = 1;
constexpr uint16_t FLOAT16_NEG_INF = 0xFC00; // -inf 64512
constexpr uint16_t BF16_NEG_INF = 0xFF80; // -inf 65408
constexpr int32_t FLOAT32_NEG_INF = 0xFF800000; // -inf -2139095040
constexpr uint32_t BLOCK_BYTES = 32;
constexpr uint32_t DATA_PER_BLOCK_B32 = 8;
constexpr uint32_t DATA_PER_REPEAT_B32 = 64;
constexpr uint32_t K_MAX = 1024;
constexpr uint64_t MASK_64 = 64;
constexpr CumSumConfig CUMSUM_CONFIG{true, true, false};
template <typename inputT, typename calT, typename outputT>
class ApplyTopKTopPCustom {
public:
__aicore__ inline ApplyTopKTopPCustom(){};
__aicore__ inline void InitTilingData(
const ApplyTopKTopPCustomTilingData &__restrict tilingData, GM_ADDR sorted_value, GM_ADDR sorted_indices,
GM_ADDR p, GM_ADDR k, GM_ADDR out);
__aicore__ inline void InitBuffer(TPipe *inputPipe);
__aicore__ inline void Process();
__aicore__ inline void ProcessTopK();
private:
__aicore__ inline void InitCopyIn(uint32_t loopBatch, int64_t currentGmIdx);
__aicore__ inline void InitProcess(uint32_t loopBatch);
__aicore__ inline void ProcessKLtKMax(uint32_t loopBatch);
__aicore__ inline void ScatterCumtomImpl(uint32_t loopBatch, uint32_t loopProbNum, uint32_t offset);
__aicore__ inline void ProcessRemain(uint32_t loopBatch);
__aicore__ inline void GetKthResult(uint32_t loopBatch, uint32_t offset, uint8_t repeatTimes);
__aicore__ inline void GetFirstKLoop(uint32_t loopBatch, int32_t &firstKLoop);
__aicore__ inline void ScatterFromFirstKLoop(uint32_t loopBatch, int32_t firstKLoop, float &cumsumData);
__aicore__ inline void ReduceSumWithAddsAndExpImpl(uint32_t offset, uint32_t loopDataNum);
__aicore__ inline void CumSumWithAddsAndExpImpl(
uint32_t offset, uint32_t loopDataNum, uint32_t cumsumInner, float cumsumData);
// topk func
__aicore__ inline void InitProcessTopK(uint32_t loopBatch);
__aicore__ inline void ProcessKLtKMaxTopK(uint32_t loopBatch);
__aicore__ inline void ProcessRemainTopK(uint32_t loopBatch);
__aicore__ inline void GetFirstKLoopTopK(uint32_t loopBatch, int32_t &firstKLoop);
__aicore__ inline void ScatterFromFirstKLoopTopK(uint32_t loopBatch, int32_t firstKLoop);
__aicore__ inline void ScatterCumtomImplTopK(uint32_t loopBatch, uint32_t loopProbNum, uint32_t offset);
__aicore__ inline void SToMTE3Sync() {
event_t eventIDSToMTE3 = static_cast<event_t>(GetTPipePtr()->FetchEventID(HardEvent::S_MTE3));
SetFlag<HardEvent::S_MTE3>(eventIDSToMTE3);
WaitFlag<HardEvent::S_MTE3>(eventIDSToMTE3);
}
__aicore__ inline void MTE3ToSSync() {
event_t eventIDMTE3ToS = static_cast<event_t>(GetTPipePtr()->FetchEventID(HardEvent::MTE3_S));
SetFlag<HardEvent::MTE3_S>(eventIDMTE3ToS);
WaitFlag<HardEvent::MTE3_S>(eventIDMTE3ToS);
}
__aicore__ inline void VToSSync() {
event_t eventIdVToS = static_cast<event_t>(GetTPipePtr()->FetchEventID(HardEvent::V_S));
SetFlag<HardEvent::V_S>(eventIdVToS);
WaitFlag<HardEvent::V_S>(eventIdVToS);
}
__aicore__ inline void MTE2ToVSync() {
event_t eventIdMte2ToV = static_cast<event_t>(GetTPipePtr()->FetchEventID(HardEvent::MTE2_V));
SetFlag<HardEvent::MTE2_V>(eventIdMte2ToV);
WaitFlag<HardEvent::MTE2_V>(eventIdMte2ToV);
}
__aicore__ inline void MTE2ToSSync() {
event_t eventIdMte2ToS = static_cast<event_t>(GetTPipePtr()->FetchEventID(HardEvent::MTE2_S));
SetFlag<HardEvent::MTE2_S>(eventIdMte2ToS);
WaitFlag<HardEvent::MTE2_S>(eventIdMte2ToS);
}
__aicore__ inline void MTE3ToVSync() {
event_t eventIdMte3ToV = static_cast<event_t>(GetTPipePtr()->FetchEventID(HardEvent::MTE3_V));
SetFlag<HardEvent::MTE3_V>(eventIdMte3ToV);
WaitFlag<HardEvent::MTE3_V>(eventIdMte3ToV);
}
__aicore__ inline void SToMTE2Sync()
{
event_t eventIDSToMTE2 = static_cast<event_t>(GetTPipePtr()->FetchEventID(HardEvent::S_MTE2));
SetFlag<HardEvent::S_MTE2>(eventIDSToMTE2);
WaitFlag<HardEvent::S_MTE2>(eventIDSToMTE2);
}
__aicore__ inline void VToMTE3Sync() {
event_t eventIDVToMTE3 = static_cast<event_t>(GetTPipePtr()->FetchEventID(HardEvent::V_MTE3));
SetFlag<HardEvent::V_MTE3>(eventIDVToMTE3);
WaitFlag<HardEvent::V_MTE3>(eventIDVToMTE3);
}
private:
TPipe *pipe_;
// create queues for input, in this case depth is equal to buffer num
TQue<QuePosition::VECIN, BUFFER_NUM> sortedValueInQueue_;
TQue<QuePosition::VECIN, BUFFER_NUM> sortedIndicesInQueue_;
TQue<QuePosition::VECIN, BUFFER_NUM> pInQueue_;
TQue<QuePosition::VECIN, BUFFER_NUM> kInQueue_;
TQue<QuePosition::VECOUT, BUFFER_NUM> outQueue_;
TBuf<TPosition::VECCALC> calBuf_;
// tilingData
uint32_t batchSize_ = 0;
uint32_t vocabSize_ = 0;
uint32_t batchPerCore_ = 0;
uint32_t tailBatch_ = 0;
uint32_t blockNum_ = 0;
uint32_t dataNumInit_ = 0;
uint32_t dataNumInitAligned_ = 0;
uint32_t ubFactorElement_ = 0;
uint32_t ubFactorElementAligned_ = 0;
uint32_t tailUbFactorElement_ = 0;
uint32_t tailUbFactorElementAligned_ = 0;
uint32_t calUbSize_ = 0;
uint32_t blockIdx_ = 0;
uint32_t loopBatch_ = 0;
uint32_t batchOffset_ = 0;
uint32_t bufOffsetLoop = 0;
uint32_t loopInner_ = 0;
uint32_t loopInnerOnlyP_ = 0;
int64_t baseGmIdx_ = 0;
GlobalTensor<inputT> mGmSortedValue_;
GlobalTensor<int32_t> mGmSortedIndices_;
GlobalTensor<inputT> mGmP_;
GlobalTensor<int32_t> mGmK_;
GlobalTensor<outputT> mGmOut_;
LocalTensor<int32_t> kLocal;
LocalTensor<inputT> pLocal;
LocalTensor<outputT> outTensor;
LocalTensor<inputT> sortedValueLocal;
LocalTensor<int32_t> sortedIndicesLocal;
LocalTensor<float> sortedValueLocalFp32;
LocalTensor<float> negInfLocal;
LocalTensor<float> calLocalFp32;
LocalTensor<float> kthValueLocal;
LocalTensor<float> tmpLocal;
LocalTensor<float> cumSumRes;
LocalTensor<float> cumSumTmp;
LocalTensor<float> reduceLocal;
LocalTensor<float> softMaxRes;
LocalTensor<inputT> scatterTensor;
LocalTensor<uint8_t> sharedTmpBuffer;
float kthValue = 0;
float pValue = 0;
float maxValue = 0;
float reduceSumValueInvert = 0;
float reduceSumValue = 0;
inputT kthTopKValue = 0;
BinaryRepeatParams repeatParams = {1, 0, 1, 8, 0, 8};
DataCopyExtParams scatterCopyParams{1, (uint32_t)(sizeof(outputT)), 0, 0, 0};
};
template <typename inputT, typename calT, typename outputT>
__aicore__ inline void ApplyTopKTopPCustom<inputT, calT, outputT>::InitTilingData(
const ApplyTopKTopPCustomTilingData &__restrict tilingData, GM_ADDR sorted_value, GM_ADDR sorted_indices,
GM_ADDR p, GM_ADDR k, GM_ADDR out) {
batchSize_ = tilingData.batchSize;
vocabSize_ = tilingData.vocabSize;
batchPerCore_ = tilingData.batchPerCore;
tailBatch_ = tilingData.tailBatch;
blockNum_ = tilingData.blockNum;
dataNumInit_ = tilingData.dataNumInit;
dataNumInitAligned_ = AscendC::AlignUp(dataNumInit_, DATA_PER_BLOCK_B32);
ubFactorElement_ = tilingData.ubFactorElement;
ubFactorElementAligned_ = tilingData.ubFactorElementAligned;
tailUbFactorElement_ = tilingData.tailUbFactorElement;
tailUbFactorElementAligned_ = tilingData.tailUbFactorElementAligned;
calUbSize_ = tilingData.calUbSize;
blockIdx_ = GetBlockIdx();
if (blockIdx_ < tailBatch_)
{
loopBatch_ = batchPerCore_ + 1;
batchOffset_ = blockIdx_ * loopBatch_;
}
else
{
loopBatch_ = batchPerCore_;
batchOffset_ = blockIdx_ * batchPerCore_ + tailBatch_;
}
loopInner_ = (vocabSize_ - dataNumInit_ + ubFactorElementAligned_ - 1) / ubFactorElementAligned_;
loopInnerOnlyP_ = (vocabSize_ + ubFactorElementAligned_ - 1) / ubFactorElementAligned_;
mGmSortedValue_.SetGlobalBuffer(reinterpret_cast<__gm__ inputT *>(sorted_value));
mGmSortedIndices_.SetGlobalBuffer(reinterpret_cast<__gm__ int32_t *>(sorted_indices));
mGmP_.SetGlobalBuffer(reinterpret_cast<__gm__ inputT *>(p));
mGmK_.SetGlobalBuffer(reinterpret_cast<__gm__ int32_t *>(k));
mGmOut_.SetGlobalBuffer(reinterpret_cast<__gm__ outputT *>(out));
}
// init used buffer
template <typename inputT, typename calT, typename outputT>
__aicore__ inline void ApplyTopKTopPCustom<inputT, calT, outputT>::InitBuffer(TPipe *inputPipe) {
pipe_ = inputPipe;
pipe_->InitBuffer(sortedValueInQueue_, BUFFER_NUM, sizeof(inputT) * (ubFactorElementAligned_ + K_MAX));
pipe_->InitBuffer(sortedIndicesInQueue_, BUFFER_NUM, sizeof(int32_t) * (ubFactorElementAligned_ + K_MAX));
pipe_->InitBuffer(pInQueue_, BUFFER_NUM, BLOCK_BYTES);
pipe_->InitBuffer(kInQueue_, BUFFER_NUM, BLOCK_BYTES);
pipe_->InitBuffer(outQueue_, BUFFER_NUM, sizeof(outputT) * ubFactorElementAligned_);
pipe_->InitBuffer(calBuf_, calUbSize_);
if constexpr (!IsSameType<inputT, float>::value) {
sortedValueLocalFp32 = calBuf_.GetWithOffset<float>(ubFactorElementAligned_ + K_MAX, bufOffsetLoop);
bufOffsetLoop = bufOffsetLoop + (ubFactorElementAligned_ + K_MAX) * sizeof(float);
}
kthValueLocal = calBuf_.GetWithOffset<float>(DATA_PER_BLOCK_B32, bufOffsetLoop);
bufOffsetLoop = bufOffsetLoop + BLOCK_BYTES;
negInfLocal = calBuf_.GetWithOffset<float>(DATA_PER_BLOCK_B32, bufOffsetLoop);
bufOffsetLoop = bufOffsetLoop + BLOCK_BYTES;
tmpLocal = calBuf_.GetWithOffset<float>(ubFactorElementAligned_, bufOffsetLoop);
bufOffsetLoop = bufOffsetLoop + ubFactorElementAligned_ * sizeof(float);
cumSumRes = calBuf_.GetWithOffset<float>(ubFactorElementAligned_, bufOffsetLoop);
bufOffsetLoop = bufOffsetLoop + ubFactorElementAligned_ * sizeof(float);
cumSumTmp = calBuf_.GetWithOffset<float>(ubFactorElementAligned_, bufOffsetLoop);
bufOffsetLoop = bufOffsetLoop + ubFactorElementAligned_ * sizeof(float);
reduceLocal = calBuf_.GetWithOffset<float>(ubFactorElementAligned_ * BLOCK_BYTES, bufOffsetLoop);
bufOffsetLoop = bufOffsetLoop + ubFactorElementAligned_ * BLOCK_BYTES * sizeof(float);
softMaxRes = tmpLocal.template ReinterpretCast<float>();
scatterTensor = reduceLocal.template ReinterpretCast<inputT>();
sharedTmpBuffer = reduceLocal.template ReinterpretCast<uint8_t>();
}
template <typename inputT, typename calT, typename outputT>
__aicore__ inline void ApplyTopKTopPCustom<inputT, calT, outputT>::Process() {
kLocal = kInQueue_.AllocTensor<int32_t>();
pLocal = pInQueue_.AllocTensor<inputT>();
outTensor = outQueue_.AllocTensor<outputT>();
sortedValueLocal = sortedValueInQueue_.AllocTensor<inputT>();
sortedIndicesLocal = sortedIndicesInQueue_.AllocTensor<int32_t>();
Duplicate(negInfLocal.template ReinterpretCast<int32_t>(), FLOAT32_NEG_INF, DATA_PER_BLOCK_B32);
if constexpr (IsSameType<inputT, float>::value) {
calLocalFp32 = sortedValueLocal;
Duplicate(outTensor.template ReinterpretCast<int32_t>(), FLOAT32_NEG_INF, ubFactorElementAligned_);
} else if constexpr (IsSameType<inputT, half>::value) {
calLocalFp32 = sortedValueLocalFp32;
Duplicate(outTensor.template ReinterpretCast<uint16_t>(), FLOAT16_NEG_INF, ubFactorElementAligned_);
} else {
calLocalFp32 = sortedValueLocalFp32;
Duplicate(outTensor.template ReinterpretCast<uint16_t>(), BF16_NEG_INF, ubFactorElementAligned_);
}
VToMTE3Sync();
for (uint32_t loopBatch = 0; loopBatch < loopBatch_; loopBatch++) {
baseGmIdx_ = batchOffset_ * vocabSize_ + loopBatch * vocabSize_;
InitProcess(loopBatch);
if (calLocalFp32.GetValue(ubFactorElementAligned_) < kthValue) {
ProcessKLtKMax(loopBatch);
} else {
ProcessRemain(loopBatch);
}
}
kInQueue_.FreeTensor(kLocal);
pInQueue_.FreeTensor(pLocal);
sortedValueInQueue_.FreeTensor(sortedValueLocal);
sortedIndicesInQueue_.FreeTensor(sortedIndicesLocal);
outQueue_.FreeTensor(outTensor);
}
template <typename inputT, typename calT, typename outputT>
__aicore__ inline void ApplyTopKTopPCustom<inputT, calT, outputT>::InitCopyIn(uint32_t loopBatch,
int64_t currentGmIdx) {
DataCopyPad(mGmOut_[currentGmIdx], outTensor, {1, (uint32_t)(dataNumInit_ * sizeof(outputT)), 0, 0, 0});
DataCopyPad(sortedValueLocal[ubFactorElementAligned_], mGmSortedValue_[currentGmIdx],
{1, static_cast<uint32_t>(dataNumInit_ * sizeof(inputT)), 0, 0, 0},
{false, 0, 0, 0});
DataCopyPad(pLocal, mGmP_[batchOffset_ + loopBatch], {1, static_cast<uint32_t>(sizeof(inputT)), 0, 0, 0},
{false, 0, 0, 0});
if constexpr (!IsSameType<inputT, float>::value) {
MTE2ToVSync();
Cast(sortedValueLocalFp32[ubFactorElementAligned_], sortedValueLocal[ubFactorElementAligned_],
RoundMode::CAST_NONE, dataNumInit_);
Cast(tmpLocal, pLocal, RoundMode::CAST_NONE, DATA_PER_BLOCK_B32);
}
DataCopyPad(sortedIndicesLocal[ubFactorElementAligned_], mGmSortedIndices_[currentGmIdx],
{1, static_cast<uint32_t>(dataNumInit_ * sizeof(int32_t)), 0, 0, 0},
{false, 0, 0, 0});
DataCopyPad(kLocal, mGmK_[batchOffset_ + loopBatch], {1, static_cast<uint32_t>(sizeof(int32_t)), 0, 0, 0},
{false, 0, 0, 0});
}
template <typename inputT, typename calT, typename outputT>
__aicore__ inline void ApplyTopKTopPCustom<inputT, calT, outputT>::GetKthResult(uint32_t loopBatch,
uint32_t offset, uint8_t repeatTimes){
Compare(tmpLocal.template ReinterpretCast<uint8_t>(), kthValueLocal, calLocalFp32[offset],
CMPMODE::GT, MASK_64, repeatTimes, repeatParams);
PipeBarrier<PIPE_V>();
Select(calLocalFp32[offset], tmpLocal.template ReinterpretCast<uint8_t>(),
negInfLocal, calLocalFp32[offset], SELMODE::VSEL_TENSOR_TENSOR_MODE, MASK_64,
repeatTimes, repeatParams);
}
template <typename inputT, typename calT, typename outputT>
__aicore__ inline void ApplyTopKTopPCustom<inputT, calT, outputT>::ReduceSumWithAddsAndExpImpl(uint32_t offset,
uint32_t loopDataNum) {
Adds(softMaxRes, calLocalFp32[offset], maxValue, loopDataNum);
PipeBarrier<PIPE_V>();
Exp(softMaxRes, softMaxRes, loopDataNum);
PipeBarrier<PIPE_V>();
ReduceSum(reduceLocal, softMaxRes, reduceLocal, loopDataNum);
}
template <typename inputT, typename calT, typename outputT>
__aicore__ inline void ApplyTopKTopPCustom<inputT, calT, outputT>::InitProcess(uint32_t loopBatch) {
int64_t initGmIdx = baseGmIdx_ + vocabSize_ - dataNumInit_;
InitCopyIn(loopBatch, initGmIdx);
MTE2ToSSync();
int32_t kValue = kLocal.GetValue(0);
if constexpr (IsSameType<inputT, float>::value) {
pValue = float(1.0) - pLocal.GetValue(0);
} else {
pValue = float(1.0) - tmpLocal.GetValue(0);
}
maxValue = -calLocalFp32[ubFactorElementAligned_].GetValue(dataNumInit_ - 1);
if constexpr (IsSameType<inputT, float>::value) {
kthValue = mGmSortedValue_[baseGmIdx_ + vocabSize_ - kValue].GetValue(0);
} else if constexpr (IsSameType<inputT, half>::value) {
kthValue = static_cast<float>(mGmSortedValue_[baseGmIdx_ + vocabSize_ - kValue].GetValue(0));
} else {
kthValue = ToFloat(mGmSortedValue_[baseGmIdx_ + vocabSize_ - kValue].GetValue(0));
}
Duplicate(kthValueLocal, kthValue, 8);
PipeBarrier<PIPE_V>();
uint8_t repeatTimes = (dataNumInit_ + DATA_PER_REPEAT_B32 - 1) / DATA_PER_REPEAT_B32;
GetKthResult(loopBatch, ubFactorElementAligned_, repeatTimes);
PipeBarrier<PIPE_V>();
DataCopyExtParams copyParams{1, (uint32_t)(ubFactorElementAligned_ * sizeof(outputT)), 0, 0, 0};
ReduceSumWithAddsAndExpImpl(ubFactorElementAligned_, dataNumInit_);
VToSSync();
reduceSumValue = reduceLocal.GetValue(0);
reduceSumValueInvert = 1 / reduceSumValue;
}
template <typename inputT, typename calT, typename outputT>
__aicore__ inline void ApplyTopKTopPCustom<inputT, calT, outputT>::ProcessKLtKMax(uint32_t loopBatch) {
DataCopyExtParams copyParams{1, (uint32_t)(ubFactorElementAligned_ * sizeof(outputT)), 0, 0, 0};
for (int32_t loopInner = 0; loopInner < loopInner_; loopInner++) {
int64_t currentGmIdxInner = baseGmIdx_ + loopInner * ubFactorElementAligned_;
if (loopInner == loopInner_ - 1) {
DataCopyPad(mGmOut_[currentGmIdxInner], outTensor,
{1, (uint32_t)(tailUbFactorElement_ * sizeof(outputT)), 0, 0, 0});
} else {
DataCopyPad(mGmOut_[currentGmIdxInner], outTensor, copyParams);
}
}
Muls(softMaxRes, softMaxRes, reduceSumValueInvert, dataNumInit_);
PipeBarrier<PIPE_V>();
const CumSumInfo cumSumInfo{1, dataNumInitAligned_};
CumSum<float, CUMSUM_CONFIG>(cumSumRes, cumSumTmp, softMaxRes, sharedTmpBuffer, cumSumInfo);
VToSSync();
int32_t loopProb = dataNumInit_ - 1;
scatterTensor.SetValue(0, sortedValueLocal[ubFactorElementAligned_].GetValue(loopProb));
SToMTE3Sync();
int32_t gmIndex = sortedIndicesLocal[ubFactorElementAligned_].GetValue(loopProb);
PipeBarrier<PIPE_MTE3>();
DataCopyPad(mGmOut_[baseGmIdx_ + gmIndex], scatterTensor.template ReinterpretCast<outputT>(), scatterCopyParams);
MTE3ToSSync();
loopProb = loopProb - 1;
for (; loopProb >= 0; loopProb--) {
float cumsumData = cumSumRes.GetValue(loopProb);
if (cumsumData <= pValue) {
break;
}
scatterTensor.SetValue(0, sortedValueLocal[ubFactorElementAligned_].GetValue(loopProb));
gmIndex = sortedIndicesLocal[ubFactorElementAligned_].GetValue(loopProb);
SToMTE3Sync();
DataCopyPad(mGmOut_[baseGmIdx_ + gmIndex],
scatterTensor.template ReinterpretCast<outputT>(), scatterCopyParams);
MTE3ToSSync();
}
}
template <typename inputT, typename calT, typename outputT>
__aicore__ inline void ApplyTopKTopPCustom<inputT, calT, outputT>::ScatterCumtomImpl(uint32_t loopBatch,
uint32_t loopProbNum, uint32_t offset) {
for (int32_t loopProb = 0; loopProb < static_cast<int32_t>(loopProbNum); loopProb++) {
float cumsumDataTmp = cumSumRes.GetValue(loopProb);
if (cumsumDataTmp <= pValue) {
continue;
}
scatterTensor.SetValue(0, sortedValueLocal[offset].GetValue(loopProb));
int32_t gmIndex = sortedIndicesLocal[offset].GetValue(loopProb);
SToMTE3Sync();
DataCopyPad(mGmOut_[baseGmIdx_ + gmIndex], scatterTensor.template ReinterpretCast<outputT>(),
{1, (uint32_t)(1 * sizeof(outputT)), 0, 0, 0});
MTE3ToSSync();
}
}
template <typename inputT, typename calT, typename outputT>
__aicore__ inline void ApplyTopKTopPCustom<inputT, calT, outputT>::GetFirstKLoop(uint32_t loopBatch,
int32_t &firstKLoop) {
uint8_t repeatTimes = (dataNumInit_ + DATA_PER_REPEAT_B32 - 1) / DATA_PER_REPEAT_B32;
uint32_t loopDataNum = ubFactorElementAligned_;
for (int32_t loopInner = 0; loopInner < loopInner_; loopInner++) {
int64_t currentGmIdx = baseGmIdx_ + loopInner * ubFactorElementAligned_;
if (loopInner == (loopInner_ - 1)) {
repeatTimes = ((tailUbFactorElement_) + DATA_PER_REPEAT_B32 - 1) / DATA_PER_REPEAT_B32;
loopDataNum = tailUbFactorElement_;
}
DataCopyPad(mGmOut_[currentGmIdx], outTensor, {1, (uint32_t)(loopDataNum * sizeof(outputT)), 0, 0, 0});
DataCopyPad(sortedValueLocal.template ReinterpretCast<inputT>(), mGmSortedValue_[currentGmIdx],
{1, static_cast<uint32_t>(loopDataNum * sizeof(inputT)), 0, 0, 0},
{false, 0, 0, 0});
if constexpr (!IsSameType<inputT, float>::value) {
MTE2ToVSync();
Cast(sortedValueLocalFp32, sortedValueLocal, RoundMode::CAST_NONE, loopDataNum);
VToSSync();
} else {
MTE2ToSSync();
}
if (calLocalFp32.GetValue(loopDataNum - 1) < kthValue) {
firstKLoop += 1;
continue;
}
GetKthResult(loopBatch, 0, repeatTimes);
PipeBarrier<PIPE_V>();
ReduceSumWithAddsAndExpImpl(0, loopDataNum);
VToSSync();
reduceSumValue += reduceLocal.GetValue(0);
}
}
template <typename inputT, typename calT, typename outputT>
__aicore__ inline void ApplyTopKTopPCustom<inputT, calT, outputT>::CumSumWithAddsAndExpImpl(uint32_t offset,
uint32_t loopDataNum, uint32_t cumsumInner, float cumsumData) {
Adds(softMaxRes, calLocalFp32[offset], maxValue, loopDataNum);
PipeBarrier<PIPE_V>();
Exp(softMaxRes, softMaxRes, loopDataNum);
PipeBarrier<PIPE_V>();
Muls(softMaxRes, softMaxRes, reduceSumValueInvert, loopDataNum);
PipeBarrier<PIPE_V>();
const CumSumInfo cumSumInfo{1, cumsumInner};
CumSum<float, CUMSUM_CONFIG>(cumSumRes, cumSumTmp, softMaxRes, sharedTmpBuffer, cumSumInfo);
PipeBarrier<PIPE_V>();
Adds(cumSumRes, cumSumRes, cumsumData, loopDataNum);
}
template <typename inputT, typename calT, typename outputT>
__aicore__ inline void ApplyTopKTopPCustom<inputT, calT, outputT>::ProcessRemain(uint32_t loopBatch) {
int32_t firstKLoop = 0;
GetFirstKLoop(loopBatch, firstKLoop);
reduceSumValueInvert = 1 / reduceSumValue;
float cumsumData = 0;
ScatterFromFirstKLoop(loopBatch, firstKLoop, cumsumData);
uint32_t loopProb = dataNumInit_ - 1;
scatterTensor.SetValue(0, sortedValueLocal[ubFactorElementAligned_].GetValue(loopProb));
int32_t gmIndex = sortedIndicesLocal[ubFactorElementAligned_].GetValue(loopProb);
SToMTE3Sync();
DataCopyPad(mGmOut_[baseGmIdx_ + gmIndex],
scatterTensor.template ReinterpretCast<outputT>(), scatterCopyParams);
MTE3ToVSync();
CumSumWithAddsAndExpImpl(ubFactorElementAligned_, dataNumInit_, dataNumInitAligned_, cumsumData);
VToSSync();
ScatterCumtomImpl(loopBatch, dataNumInit_ - 1, ubFactorElementAligned_);
}
template <typename inputT, typename calT, typename outputT>
__aicore__ inline void ApplyTopKTopPCustom<inputT, calT, outputT>::ScatterFromFirstKLoop(uint32_t loopBatch,
int32_t firstKLoop, float &cumsumData) {
uint32_t loopDataNum = ubFactorElementAligned_;
uint32_t cumsumInner = ubFactorElementAligned_;
uint8_t repeatTimes = ((ubFactorElementAligned_) + DATA_PER_REPEAT_B32 - 1) / DATA_PER_REPEAT_B32;
for (int32_t loopInner = firstKLoop; loopInner < loopInner_; loopInner++) {
int64_t currentGmIdx = baseGmIdx_ + loopInner * ubFactorElementAligned_;
if (loopInner == (loopInner_ - 1)) {
repeatTimes = (tailUbFactorElement_ + DATA_PER_REPEAT_B32 - 1) / DATA_PER_REPEAT_B32;
loopDataNum = tailUbFactorElement_;
cumsumInner = tailUbFactorElementAligned_;
}
DataCopyPad(sortedValueLocal.template ReinterpretCast<inputT>(), mGmSortedValue_[currentGmIdx],
{1, static_cast<uint32_t>(loopDataNum * sizeof(inputT)), 0, 0, 0},
{false, 0, 0, 0});
DataCopyPad(sortedIndicesLocal, mGmSortedIndices_[currentGmIdx],
{1, static_cast<uint32_t>(loopDataNum * sizeof(int32_t)), 0, 0, 0},
{false, 0, 0, 0});
if constexpr (!IsSameType<inputT, float>::value) {
MTE2ToVSync();
Cast(sortedValueLocalFp32, sortedValueLocal, RoundMode::CAST_NONE, loopDataNum);
PipeBarrier<PIPE_V>();
} else {
MTE2ToVSync();
}
GetKthResult(loopBatch, 0, repeatTimes);
PipeBarrier<PIPE_V>();
CumSumWithAddsAndExpImpl(0, loopDataNum, cumsumInner, cumsumData);
VToSSync();
float cumsumDataTmp = cumSumRes.GetValue(loopDataNum - 1);
cumsumData = cumsumDataTmp;
if (cumsumDataTmp <= pValue) {
continue;
}
ScatterCumtomImpl(loopBatch, loopDataNum, 0);
}
}
template <typename inputT, typename calT, typename outputT>
__aicore__ inline void ApplyTopKTopPCustom<inputT, calT, outputT>::ProcessTopK() {
kLocal = kInQueue_.AllocTensor<int32_t>();
outTensor = outQueue_.AllocTensor<outputT>();
sortedValueLocal = sortedValueInQueue_.AllocTensor<inputT>();
sortedIndicesLocal = sortedIndicesInQueue_.AllocTensor<int32_t>();
Duplicate(negInfLocal.template ReinterpretCast<int32_t>(), FLOAT32_NEG_INF, DATA_PER_BLOCK_B32);
if constexpr (IsSameType<inputT, float>::value) {
calLocalFp32 = sortedValueLocal;
Duplicate(outTensor.template ReinterpretCast<int32_t>(), FLOAT32_NEG_INF, ubFactorElementAligned_);
} else if constexpr (IsSameType<inputT, half>::value) {
calLocalFp32 = sortedValueLocalFp32;
Duplicate(outTensor.template ReinterpretCast<uint16_t>(), FLOAT16_NEG_INF, ubFactorElementAligned_);
} else {
calLocalFp32 = sortedValueLocalFp32;
Duplicate(outTensor.template ReinterpretCast<uint16_t>(), BF16_NEG_INF, ubFactorElementAligned_);
}
VToMTE3Sync();
for (uint32_t loopBatch = 0; loopBatch < loopBatch_; loopBatch++) {
baseGmIdx_ = batchOffset_ * vocabSize_ + loopBatch * vocabSize_;
InitProcessTopK(loopBatch);
/* The difference lies in that for the max branch, some data is less than the kthvalue,
so part of the data can be filtered out in advance;
while for the remain branch, all data must undergo the topk calculation.*/
if (calLocalFp32.GetValue(ubFactorElementAligned_) < kthValue) {
ProcessKLtKMaxTopK(loopBatch);
} else {
ProcessRemainTopK(loopBatch);
}
}
kInQueue_.FreeTensor(kLocal);
sortedValueInQueue_.FreeTensor(sortedValueLocal);
sortedIndicesInQueue_.FreeTensor(sortedIndicesLocal);
outQueue_.FreeTensor(outTensor);
}
template <typename inputT, typename calT, typename outputT>
__aicore__ inline void ApplyTopKTopPCustom<inputT, calT, outputT>::ProcessRemainTopK(uint32_t loopBatch) {
int32_t firstKLoop = 0;
GetFirstKLoopTopK(loopBatch, firstKLoop);
// Start the scatter calculation from the first loop in the row where the value is ≥ kthValue.
ScatterFromFirstKLoopTopK(loopBatch, firstKLoop);
/* Perform scatter calculation on the maximum number of ubFactorElementAligned_,
which does not overlap with the previous ones.*/
uint32_t loopProb = dataNumInit_ - 1;
scatterTensor.SetValue(0, sortedValueLocal[ubFactorElementAligned_].GetValue(loopProb));
SToMTE3Sync();
int32_t gmIndex = sortedIndicesLocal[ubFactorElementAligned_].GetValue(loopProb);
DataCopyPad(mGmOut_[baseGmIdx_ + gmIndex],
scatterTensor.template ReinterpretCast<outputT>(), scatterCopyParams);
MTE3ToSSync();
ScatterCumtomImplTopK(loopBatch, dataNumInit_ - 1, ubFactorElementAligned_);
}
template <typename inputT, typename calT, typename outputT>
__aicore__ inline void ApplyTopKTopPCustom<inputT, calT, outputT>::GetFirstKLoopTopK(uint32_t loopBatch,
int32_t &firstKLoop) {
uint8_t repeatTimes = (dataNumInit_ + DATA_PER_REPEAT_B32 - 1) / DATA_PER_REPEAT_B32;
uint32_t loopDataNum = ubFactorElementAligned_;
for (int32_t loopInner = 0; loopInner < loopInner_; loopInner++) {
int64_t currentGmIdx = baseGmIdx_ + loopInner * ubFactorElementAligned_;
if (loopInner == (loopInner_ - 1)) {
repeatTimes = ((tailUbFactorElement_) + DATA_PER_REPEAT_B32 - 1) / DATA_PER_REPEAT_B32;
loopDataNum = tailUbFactorElement_;
}
DataCopyPad(mGmOut_[currentGmIdx], outTensor, {1, (uint32_t)(loopDataNum * sizeof(outputT)), 0, 0, 0});
DataCopyPad(sortedValueLocal.template ReinterpretCast<inputT>(), mGmSortedValue_[currentGmIdx],
{1, static_cast<uint32_t>(loopDataNum * sizeof(inputT)), 0, 0, 0},
{false, 0, 0, 0});
MTE2ToSSync();
float rightVlaue = 0;
// Make a judgment on the rightmost value of each loop to filter the data.
if constexpr (IsSameType<inputT, bfloat16_t>::value) {
rightVlaue = ToFloat(sortedValueLocal.GetValue(loopDataNum - 1));
} else {
rightVlaue = static_cast<float>(sortedValueLocal.GetValue(loopDataNum - 1));
}
SToMTE2Sync();
if (rightVlaue < kthValue) {
firstKLoop += 1;
continue;
}
}
}
template <typename inputT, typename calT, typename outputT>
__aicore__ inline void ApplyTopKTopPCustom<inputT, calT, outputT>::ScatterFromFirstKLoopTopK(uint32_t loopBatch,
int32_t firstKLoop) {
uint32_t loopDataNum = ubFactorElementAligned_;
uint32_t cumsumInner = ubFactorElementAligned_;
uint8_t repeatTimes = ((ubFactorElementAligned_) + DATA_PER_REPEAT_B32 - 1) / DATA_PER_REPEAT_B32;
for (int32_t loopInner = firstKLoop; loopInner < loopInner_; loopInner++) {
int64_t currentGmIdx = baseGmIdx_ + loopInner * ubFactorElementAligned_;
if (loopInner == (loopInner_ - 1)) {
repeatTimes = (tailUbFactorElement_ + DATA_PER_REPEAT_B32 - 1) / DATA_PER_REPEAT_B32;
loopDataNum = tailUbFactorElement_;
cumsumInner = tailUbFactorElementAligned_;
}
DataCopyPad(sortedValueLocal.template ReinterpretCast<inputT>(), mGmSortedValue_[currentGmIdx],
{1, static_cast<uint32_t>(loopDataNum * sizeof(inputT)), 0, 0, 0},
{false, 0, 0, 0});
if constexpr (!IsSameType<inputT, float>::value) {
MTE2ToVSync();
Cast(sortedValueLocalFp32, sortedValueLocal, RoundMode::CAST_NONE, loopDataNum);
VToSSync();
}
DataCopyPad(sortedIndicesLocal, mGmSortedIndices_[currentGmIdx],
{1, static_cast<uint32_t>(loopDataNum * sizeof(int32_t)), 0, 0, 0},
{false, 0, 0, 0});
MTE2ToSSync();
ScatterCumtomImplTopK(loopBatch, loopDataNum, 0);
}
}
template <typename inputT, typename calT, typename outputT>
__aicore__ inline void ApplyTopKTopPCustom<inputT, calT, outputT>::ScatterCumtomImplTopK(uint32_t loopBatch,
uint32_t loopProbNum, uint32_t offset) {
// Reverse traversal, returning early to improve performance.
for (int32_t loopProb = static_cast<int32_t>(loopProbNum) - 1; loopProb >= 0; loopProb--) {
float curValue = calLocalFp32[offset].GetValue(loopProb);
if (curValue < kthValue) {
break;
}
scatterTensor.SetValue(0, sortedValueLocal[offset].GetValue(loopProb));
int32_t gmIndex = sortedIndicesLocal[offset].GetValue(loopProb);
SToMTE3Sync();
DataCopyPad(mGmOut_[baseGmIdx_ + gmIndex], scatterTensor.template ReinterpretCast<outputT>(),
{1, (uint32_t)(1 * sizeof(outputT)), 0, 0, 0});
MTE3ToSSync();
}
}
template <typename inputT, typename calT, typename outputT>
__aicore__ inline void ApplyTopKTopPCustom<inputT, calT, outputT>::InitProcessTopK(uint32_t loopBatch) {
int64_t initGmIdx = baseGmIdx_ + vocabSize_ - dataNumInit_;
DataCopyPad(mGmOut_[initGmIdx], outTensor, {1, (uint32_t)(dataNumInit_ * sizeof(outputT)), 0, 0, 0});
DataCopyPad(sortedValueLocal[ubFactorElementAligned_], mGmSortedValue_[initGmIdx],
{1, static_cast<uint32_t>(dataNumInit_ * sizeof(inputT)), 0, 0, 0},
{false, 0, 0, 0});
if constexpr (!IsSameType<inputT, float>::value) {
MTE2ToVSync();
Cast(sortedValueLocalFp32[ubFactorElementAligned_], sortedValueLocal[ubFactorElementAligned_],
RoundMode::CAST_NONE, dataNumInit_);
}
DataCopyPad(sortedIndicesLocal[ubFactorElementAligned_], mGmSortedIndices_[initGmIdx],
{1, static_cast<uint32_t>(dataNumInit_ * sizeof(int32_t)), 0, 0, 0},
{false, 0, 0, 0});
DataCopyPad(kLocal, mGmK_[batchOffset_ + loopBatch], {1, static_cast<uint32_t>(sizeof(int32_t)), 0, 0, 0},
{false, 0, 0, 0});
MTE2ToSSync();
int32_t kValue = mGmK_.GetValue(batchOffset_ + loopBatch);
maxValue = -calLocalFp32[ubFactorElementAligned_].GetValue(dataNumInit_ - 1);
if constexpr (IsSameType<inputT, float>::value) {
kthValue = mGmSortedValue_[baseGmIdx_ + vocabSize_ - kValue].GetValue(0);
} else if constexpr (IsSameType<inputT, half>::value) {
kthValue = static_cast<float>(mGmSortedValue_[baseGmIdx_ + vocabSize_ - kValue].GetValue(0));
} else {
kthValue = ToFloat(mGmSortedValue_[baseGmIdx_ + vocabSize_ - kValue].GetValue(0));
}
}
template <typename inputT, typename calT, typename outputT>
__aicore__ inline void ApplyTopKTopPCustom<inputT, calT, outputT>::ProcessKLtKMaxTopK(uint32_t loopBatch) {
DataCopyExtParams copyParams{1, (uint32_t)(ubFactorElementAligned_ * sizeof(outputT)), 0, 0, 0};
// Move out -infinity to fill GM
for (int32_t loopInner = 0; loopInner < loopInner_; loopInner++) {
int64_t currentGmIdxInner = baseGmIdx_ + loopInner * ubFactorElementAligned_;
if (loopInner == loopInner_ - 1) {
DataCopyPad(mGmOut_[currentGmIdxInner], outTensor,
{1, (uint32_t)(tailUbFactorElement_ * sizeof(outputT)), 0, 0, 0});
} else {
DataCopyPad(mGmOut_[currentGmIdxInner], outTensor, copyParams);
}
}
// Scatter calculation
int32_t loopProb = dataNumInit_ - 1;
scatterTensor.SetValue(0, sortedValueLocal[ubFactorElementAligned_].GetValue(loopProb));
int32_t gmIndex = sortedIndicesLocal[ubFactorElementAligned_].GetValue(loopProb);
SToMTE3Sync();
PipeBarrier<PIPE_MTE3>();
DataCopyPad(mGmOut_[baseGmIdx_ + gmIndex], scatterTensor.template ReinterpretCast<outputT>(), scatterCopyParams);
loopProb = loopProb - 1;
for (; loopProb >= 0; loopProb--) {
float curValue = calLocalFp32[ubFactorElementAligned_].GetValue(loopProb);
if (curValue < kthValue) {
break;
}
MTE3ToSSync();
scatterTensor.SetValue(0, sortedValueLocal[ubFactorElementAligned_].GetValue(loopProb));
gmIndex = sortedIndicesLocal[ubFactorElementAligned_].GetValue(loopProb);
SToMTE3Sync();
DataCopyPad(mGmOut_[baseGmIdx_ + gmIndex],
scatterTensor.template ReinterpretCast<outputT>(), scatterCopyParams);
}
}
} // namespace
#endif // APPLY_TOP_K_TOP_P_CUSTOM_H_KERNEL

View File

@@ -0,0 +1,468 @@
/**
* Copyright (c) 2025 Huawei Technologies Co., Ltd.
* This program is free software, you can redistribute it and/or modify it under the terms and conditions of
* CANN Open Software License Agreement Version 2.0 (the "License").
* Please refer to the License for details. You may not use this file except in compliance with the License.
* THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED,
* INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE.
* See LICENSE in the root of the software repository for the full text of the License.
*/
/*!
* \file apply_top_p_custom.h
* \brief
*/
#ifndef APPLY_TOP_P_CUSTOM_H_KERNEL
#define APPLY_TOP_P_CUSTOM_H_KERNEL
#include "kernel_operator.h"
using namespace AscendC;
namespace ApplyTopPCustomOp {
constexpr uint16_t FLOAT16_NEG_INF = 0xFC00; // -inf 64512
constexpr uint16_t BF16_NEG_INF = 0xFF80; // -inf 65408
constexpr int32_t FLOAT32_NEG_INF = 0xFF800000; // -inf -2139095040
constexpr uint32_t BLOCK_BYTES = 32;
constexpr uint32_t DATA_PER_BLOCK_B32 = 8;
constexpr uint32_t DATA_PER_REPEAT_B32 = 64;
constexpr uint32_t SCATTER_PART_LENGTH = 1024;
constexpr uint32_t RESERVED_UB = 1024;
constexpr uint32_t FLOAT_BYTES = 4;
constexpr uint32_t SOFTMAX_UB_NUM = 2;
template <typename inputT, typename calT, typename outputT>
class ApplyTopPCustom {
public:
__aicore__ inline ApplyTopPCustom(){};
__aicore__ inline void InitTilingData(
const ApplyTopKTopPCustomTilingData &__restrict tilingData, GM_ADDR sorted_value, GM_ADDR sorted_indices, GM_ADDR p, GM_ADDR k, GM_ADDR out, GM_ADDR workspace);
__aicore__ inline void InitBuffer(TPipe *inputPipe);
__aicore__ inline void ProcessTopP();
private:
__aicore__ inline void ReduceSumWithAddsAndExpImpl(uint32_t loopDataNum);
// topp func
__aicore__ inline void ProcessPreSingleBatch(uint32_t loopBatch);
__aicore__ inline void GetSoftmaxSum(uint32_t loopBatch);
__aicore__ inline void CumsumKoggleStone(uint32_t loopBatch);
__aicore__ inline void GetPValue(uint32_t batchOffset);
__aicore__ inline void CumsumParamCompute(uint32_t iterateTime);
__aicore__ inline void GetMaxValue(int64_t baseGmIdx);
__aicore__ inline void GetSoftMaxRes(uint32_t loopBatch);
__aicore__ inline void ScatterSingleTask(uint32_t taskIndex);
__aicore__ inline void SToMTE3Sync() {
event_t eventIDSToMTE3 = static_cast<event_t>(GetTPipePtr()->FetchEventID(HardEvent::S_MTE3));
SetFlag<HardEvent::S_MTE3>(eventIDSToMTE3);
WaitFlag<HardEvent::S_MTE3>(eventIDSToMTE3);
}
__aicore__ inline void VToMTE3Sync() {
event_t eventIDVToMTE3 = static_cast<event_t>(GetTPipePtr()->FetchEventID(HardEvent::V_MTE3));
SetFlag<HardEvent::V_MTE3>(eventIDVToMTE3);
WaitFlag<HardEvent::V_MTE3>(eventIDVToMTE3);
}
__aicore__ inline void VToMTE2Sync() {
event_t eventIDVToMTE2 = static_cast<event_t>(GetTPipePtr()->FetchEventID(HardEvent::V_MTE2));
SetFlag<HardEvent::V_MTE2>(eventIDVToMTE2);
WaitFlag<HardEvent::V_MTE2>(eventIDVToMTE2);
}
__aicore__ inline void MTE3ToSSync() {
event_t eventIDMTE3ToS = static_cast<event_t>(GetTPipePtr()->FetchEventID(HardEvent::MTE3_S));
SetFlag<HardEvent::MTE3_S>(eventIDMTE3ToS);
WaitFlag<HardEvent::MTE3_S>(eventIDMTE3ToS);
}
__aicore__ inline void VToSSync() {
event_t eventIdVToS = static_cast<event_t>(GetTPipePtr()->FetchEventID(HardEvent::V_S));
SetFlag<HardEvent::V_S>(eventIdVToS);
WaitFlag<HardEvent::V_S>(eventIdVToS);
}
__aicore__ inline void SToVSync() {
event_t eventIdSToV = static_cast<event_t>(GetTPipePtr()->FetchEventID(HardEvent::S_V));
SetFlag<HardEvent::S_V>(eventIdSToV);
WaitFlag<HardEvent::S_V>(eventIdSToV);
}
__aicore__ inline void MTE2ToVSync() {
event_t eventIdMte2ToV = static_cast<event_t>(GetTPipePtr()->FetchEventID(HardEvent::MTE2_V));
SetFlag<HardEvent::MTE2_V>(eventIdMte2ToV);
WaitFlag<HardEvent::MTE2_V>(eventIdMte2ToV);
}
__aicore__ inline void MTE2ToSSync() {
event_t eventIdMte2ToS = static_cast<event_t>(GetTPipePtr()->FetchEventID(HardEvent::MTE2_S));
SetFlag<HardEvent::MTE2_S>(eventIdMte2ToS);
WaitFlag<HardEvent::MTE2_S>(eventIdMte2ToS);
}
__aicore__ inline void MTE3ToMTE2Sync() {
event_t eventIdMte3ToMte2 = static_cast<event_t>(GetTPipePtr()->FetchEventID(HardEvent::MTE3_MTE2));
SetFlag<HardEvent::MTE3_MTE2>(eventIdMte3ToMte2);
WaitFlag<HardEvent::MTE3_MTE2>(eventIdMte3ToMte2);
}
__aicore__ inline void MTE3ToVSync() {
event_t eventIdMte3ToV = static_cast<event_t>(GetTPipePtr()->FetchEventID(HardEvent::MTE3_V));
SetFlag<HardEvent::MTE3_V>(eventIdMte3ToV);
WaitFlag<HardEvent::MTE3_V>(eventIdMte3ToV);
}
__aicore__ inline uint32_t CeilDiv(uint32_t x, uint32_t y) {
return y == 0 ? x : (x + y - 1) / y;
}
private:
TPipe *pipe_;
TBuf<TPosition::VECCALC> calBuf_;
// tilingData
uint32_t batchSize_ = 0;
uint32_t vocabSize_ = 0;
uint32_t batchPerCore_ = 0;
uint32_t tailBatch_ = 0;
uint32_t blockNum_ = 0;
uint32_t dataNumInitAligned_ = 0;
uint32_t calUbSize_ = 0;
uint32_t blockIdx_ = 0;
uint32_t loopBatch_ = 0;
uint32_t batchOffset_ = 0;
uint32_t bufOffsetLoop = 0;
int64_t baseGmIdx_ = 0;
// topp scalar
uint32_t maxSoftmaxLength = 1;
uint32_t softmaxLength = 1;
uint32_t lineSfLoopTimes = 1;
uint32_t softmaxLengthTail = 1;
uint32_t scatterLength = 1;
uint32_t singleCoreB = 0;
uint32_t singleCoreBTail = 0;
uint32_t vCnt = 0;
uint32_t bCnt = 0;
uint32_t singleCoreV = 1;
uint32_t singleCoreVTail = 1;
uint32_t iterateTimes = 1;
GlobalTensor<inputT> mGmSortedValue_;
GlobalTensor<int32_t> mGmSortedIndices_;
GlobalTensor<inputT> mGmP_;
GlobalTensor<int32_t> mGmK_;
GlobalTensor<outputT> mGmOut_;
GlobalTensor<float> softMaxGm;
LocalTensor<uint8_t> totalUb;
// softmax tensor
LocalTensor<float> softMaxLocalFp32;
LocalTensor<inputT> softMaxLocal;
LocalTensor<float> softMaxResLocal;
LocalTensor<float> reduceLocal;
LocalTensor<inputT> outInfLocal;
// cumsum tensor
LocalTensor<float> cumSumInput1Local;
LocalTensor<float> cumSumInput2Local;
// scatter tensor
LocalTensor<inputT> sortedValueLocal;
LocalTensor<int32_t> sortedIndicesLocal;
LocalTensor<float> sortedValueLocalFp32;
LocalTensor<outputT> scatterLocal;
LocalTensor<float> cumsumLocal;
float pValue = 0;
float maxValue = 0;
float reduceSumValueInvert = 0;
float reduceSumValue = 0;
BinaryRepeatParams repeatParams = {1, 0, 1, 8, 0, 8};
DataCopyExtParams scatterCopyParams{1, (uint32_t)(sizeof(outputT)), 0, 0, 0};
};
template <typename inputT, typename calT, typename outputT>
__aicore__ inline void ApplyTopPCustom<inputT, calT, outputT>::InitTilingData(
const ApplyTopKTopPCustomTilingData &__restrict tilingData, GM_ADDR sorted_value, GM_ADDR sorted_indices,
GM_ADDR p, GM_ADDR k, GM_ADDR out, GM_ADDR workspace) {
batchSize_ = tilingData.batchSize;
vocabSize_ = tilingData.vocabSize;
batchPerCore_ = tilingData.batchPerCore;
tailBatch_ = tilingData.tailBatch;
blockNum_ = tilingData.blockNum;
calUbSize_ = tilingData.calUbSize;
iterateTimes = tilingData.iterateTimes;
blockIdx_ = GetBlockIdx();
if (blockIdx_ < tailBatch_) {
loopBatch_ = batchPerCore_ + 1;
batchOffset_ = blockIdx_ * loopBatch_;
} else {
loopBatch_ = batchPerCore_;
batchOffset_ = blockIdx_ * batchPerCore_ + tailBatch_;
}
maxSoftmaxLength = (calUbSize_ - RESERVED_UB) / SOFTMAX_UB_NUM / FLOAT_BYTES;
softmaxLength = maxSoftmaxLength < vocabSize_ ? maxSoftmaxLength : vocabSize_;
lineSfLoopTimes = (vocabSize_ + softmaxLength - 1) / softmaxLength;
softmaxLengthTail = vocabSize_ - (lineSfLoopTimes - 1) * softmaxLength;
scatterLength = (calUbSize_ - RESERVED_UB - BLOCK_BYTES) / (SOFTMAX_UB_NUM * FLOAT_BYTES + sizeof(inputT)) /
SCATTER_PART_LENGTH * SCATTER_PART_LENGTH;
singleCoreB = CeilDiv(batchSize_, blockNum_);
vCnt = batchSize_ < blockNum_ ? blockNum_ / batchSize_ : 1;
bCnt = batchSize_;
singleCoreB = 1;
singleCoreBTail = 1;
singleCoreV = vocabSize_ / vCnt;
singleCoreVTail = vocabSize_ - vCnt * singleCoreV;
mGmSortedValue_.SetGlobalBuffer(reinterpret_cast<__gm__ inputT *>(sorted_value));
mGmSortedIndices_.SetGlobalBuffer(reinterpret_cast<__gm__ int32_t *>(sorted_indices));
mGmP_.SetGlobalBuffer(reinterpret_cast<__gm__ inputT *>(p));
mGmK_.SetGlobalBuffer(reinterpret_cast<__gm__ int32_t *>(k));
mGmOut_.SetGlobalBuffer(reinterpret_cast<__gm__ outputT *>(out));
softMaxGm.SetGlobalBuffer((__gm__ float*)workspace, batchSize_ * vocabSize_);
}
// init used buffer
template <typename inputT, typename calT, typename outputT>
__aicore__ inline void ApplyTopPCustom<inputT, calT, outputT>::InitBuffer(TPipe *inputPipe) {
pipe_ = inputPipe;
pipe_->InitBuffer(calBuf_, calUbSize_);
totalUb = calBuf_.Get<uint8_t>();
// softmax ub
uint32_t softmaxLengthAligned = CeilDiv(softmaxLength, BLOCK_BYTES / sizeof(inputT)) * BLOCK_BYTES / sizeof(inputT);
softMaxLocalFp32 = totalUb.ReinterpretCast<float>();
softMaxLocal = totalUb[softmaxLengthAligned * sizeof(inputT)].ReinterpretCast<inputT>();
softMaxResLocal = totalUb[softmaxLengthAligned * sizeof(float)].ReinterpretCast<float>();
reduceLocal = totalUb[softmaxLengthAligned * sizeof(float) * 2].ReinterpretCast<float>(); // 32 bytes
outInfLocal = totalUb.ReinterpretCast<inputT>(); // Take softmax ub
// cumsum ub
cumSumInput1Local = totalUb.ReinterpretCast<float>(); // Take softmax local
cumSumInput2Local = totalUb[softmaxLengthAligned * sizeof(float)].ReinterpretCast<float>(); // Take softmax res ub
// scatter ub
sortedValueLocal = totalUb[0].ReinterpretCast<inputT>();
sortedIndicesLocal = totalUb[scatterLength * sizeof(inputT)].ReinterpretCast<int32_t>();
cumsumLocal = totalUb[scatterLength * (FLOAT_BYTES + sizeof(inputT))].ReinterpretCast<float>();
scatterLocal = totalUb[calUbSize_ - RESERVED_UB + BLOCK_BYTES].ReinterpretCast<outputT>(); // 32 bytes
}
template <typename inputT, typename calT, typename outputT>
__aicore__ inline void ApplyTopPCustom<inputT, calT, outputT>::GetMaxValue(int64_t baseGmIdx) {
int64_t initGmIdx = baseGmIdx + vocabSize_ - 1;
if constexpr (IsSameType<inputT, float>::value) {
maxValue = -mGmSortedValue_[initGmIdx].GetValue(0);
} else if constexpr (IsSameType<inputT, half>::value) {
maxValue = -static_cast<float>(mGmSortedValue_[initGmIdx].GetValue(0));
} else {
maxValue = -ToFloat(mGmSortedValue_[initGmIdx].GetValue(0));
}
}
template <typename inputT, typename calT, typename outputT>
__aicore__ inline void ApplyTopPCustom<inputT, calT, outputT>::GetPValue(uint32_t batchOffset) {
if constexpr (IsSameType<inputT, float>::value) {
pValue = float(1.0) - mGmP_[batchOffset].GetValue(0);
} else if constexpr (IsSameType<inputT, half>::value) {
pValue = float(1.0) - static_cast<float>(mGmP_[batchOffset].GetValue(0));
} else {
pValue = float(1.0) - ToFloat(mGmP_[batchOffset].GetValue(0));
}
}
template <typename inputT, typename calT, typename outputT>
__aicore__ inline void ApplyTopPCustom<inputT, calT, outputT>::ProcessPreSingleBatch(uint32_t loopBatch) {
reduceSumValue = 0;
GetSoftmaxSum(loopBatch);
GetSoftMaxRes(loopBatch);
CumsumKoggleStone(loopBatch);
}
template <typename inputT, typename calT, typename outputT>
__aicore__ inline void ApplyTopPCustom<inputT, calT, outputT>::ProcessTopP() {
for (uint32_t loopBatch = 0; loopBatch < loopBatch_; loopBatch++) {
baseGmIdx_ = batchOffset_ * vocabSize_ + loopBatch * vocabSize_;
GetMaxValue(baseGmIdx_); // Get max value in softmax.
ProcessPreSingleBatch(loopBatch); // Softmax and cumsum.
}
SyncAll();
for (uint32_t taskIndex = 0; taskIndex < bCnt * vCnt; taskIndex++) {
ScatterSingleTask(taskIndex);
}
}
template <typename inputT, typename calT, typename outputT>
__aicore__ inline void ApplyTopPCustom<inputT, calT, outputT>::ScatterSingleTask(uint32_t taskIndex) {
if (GetBlockIdx() == taskIndex % blockNum_) {
uint32_t bCntIndex = taskIndex / vCnt;
uint32_t vCntIndex = taskIndex % vCnt;
uint32_t vCurSingleCore = vCntIndex < singleCoreVTail ? (singleCoreV + 1) : singleCoreV;
uint32_t copyTimes = CeilDiv(vCurSingleCore, scatterLength);
uint32_t copyLength = scatterLength;
uint32_t copyLengthTail = vCurSingleCore - (copyTimes - 1) * scatterLength;
GetPValue(bCntIndex); // Get maxPValue.
for (uint32_t cpIndex = 0; cpIndex < copyTimes; cpIndex++) {
uint32_t curCopyLength = cpIndex == (copyTimes - 1) ? copyLengthTail : copyLength;
int64_t gmOffset = vCntIndex < singleCoreVTail ?
bCntIndex * vocabSize_ + vCntIndex * (singleCoreV + 1) + cpIndex * copyLength :
bCntIndex * vocabSize_ + vCntIndex * singleCoreV + singleCoreVTail + cpIndex * copyLength;
DataCopyPad(cumsumLocal, softMaxGm[gmOffset],
{1, static_cast<uint32_t>(curCopyLength * sizeof(float)), 0, 0, 0}, {false, 0, 0, 0});
DataCopyPad(sortedIndicesLocal, mGmSortedIndices_[gmOffset],
{1, static_cast<uint32_t>(curCopyLength * sizeof(float)), 0, 0, 0}, {false, 0, 0, 0});
DataCopyPad(sortedValueLocal, mGmSortedValue_[gmOffset],
{1, static_cast<uint32_t>(curCopyLength * sizeof(inputT)), 0, 0, 0}, {false, 0, 0, 0});
MTE2ToSSync();
if (cumsumLocal.GetValue(curCopyLength - 1) <= pValue) {
continue;
}
uint32_t scatterLoop = CeilDiv(curCopyLength, SCATTER_PART_LENGTH);
uint32_t scatterNumsTail = curCopyLength - (scatterLoop - 1) * SCATTER_PART_LENGTH;
for (uint32_t scatterLoopIndex = 0; scatterLoopIndex < scatterLoop; scatterLoopIndex++) {
uint32_t curScatterNums = scatterLoopIndex == (scatterLoop - 1) ? scatterNumsTail : SCATTER_PART_LENGTH;
if (cumsumLocal.GetValue(scatterLoopIndex * SCATTER_PART_LENGTH + curScatterNums - 1) <= pValue) {
continue;
}
for (uint32_t scatterIndex = 0; scatterIndex < curScatterNums; scatterIndex++) {
int64_t scatterOffset = scatterLoopIndex * SCATTER_PART_LENGTH + scatterIndex;
if (cumsumLocal.GetValue(scatterOffset) <= pValue) {
continue;
}
scatterLocal.SetValue(0, sortedValueLocal.GetValue(scatterOffset));
int32_t lineIndex = sortedIndicesLocal.GetValue(scatterOffset);
SToMTE3Sync();
DataCopyPad(mGmOut_[bCntIndex * vocabSize_ + lineIndex], scatterLocal.template ReinterpretCast<outputT>(),
{1, (uint32_t)(1 * sizeof(outputT)), 0, 0, 0});
MTE3ToSSync();
}
}
}
}
}
template <typename inputT, typename calT, typename outputT>
__aicore__ inline void ApplyTopPCustom<inputT, calT, outputT>::GetSoftMaxRes(uint32_t loopBatch) {
uint32_t loopDataNum = softmaxLength;
for (int32_t loopInner = 0; loopInner < lineSfLoopTimes; loopInner++) {
int64_t currentGmIdx = baseGmIdx_ + loopInner * softmaxLength;
if (loopInner == (lineSfLoopTimes - 1)) {
loopDataNum = softmaxLengthTail;
}
if constexpr (!IsSameType<inputT, float>::value) {
DataCopyPad(softMaxLocal, mGmSortedValue_[currentGmIdx],
{1, static_cast<uint32_t>(loopDataNum * sizeof(inputT)), 0, 0, 0},
{false, 0, 0, 0});
MTE2ToVSync();
Cast(softMaxLocalFp32, softMaxLocal, RoundMode::CAST_NONE, loopDataNum);
PipeBarrier<PIPE_V>();
} else {
DataCopyPad(softMaxLocalFp32, mGmSortedValue_[currentGmIdx],
{1, static_cast<uint32_t>(loopDataNum * sizeof(float)), 0, 0, 0},
{false, 0, 0, 0});
MTE2ToVSync();
}
Adds(softMaxResLocal, softMaxLocalFp32, maxValue, loopDataNum);
VToMTE2Sync();
PipeBarrier<PIPE_V>();
Exp(softMaxResLocal, softMaxResLocal, loopDataNum);
PipeBarrier<PIPE_V>();
Muls(softMaxResLocal, softMaxResLocal, reduceSumValueInvert, loopDataNum);
VToMTE3Sync();
DataCopyPad(softMaxGm[currentGmIdx], softMaxResLocal,
{1, static_cast<uint32_t>(loopDataNum * sizeof(float)), 0, 0, 0});
MTE3ToMTE2Sync();
}
}
template <typename inputT, typename calT, typename outputT>
__aicore__ inline void ApplyTopPCustom<inputT, calT, outputT>::CumsumKoggleStone(uint32_t loopBatch) {
uint32_t loopDataNum = softmaxLength;
for (uint32_t iterateTime = 0; iterateTime < iterateTimes; iterateTime++) {
int64_t iteratOffset = 1;
for (uint32_t powerIdx = 0; powerIdx < iterateTime; powerIdx++) {
iteratOffset = iteratOffset * 2;
}
uint32_t addLength = vocabSize_ - iteratOffset;
uint32_t innerLoopNum = addLength / softmaxLength;
uint32_t dataTail = addLength - innerLoopNum * softmaxLength;
loopDataNum = softmaxLength;
for (uint32_t innerLoopIdx = 0; innerLoopIdx < innerLoopNum; innerLoopIdx++) {
// Copy data from right
int64_t loopInnerOffset = dataTail + (innerLoopNum - 1 - innerLoopIdx) * softmaxLength;
DataCopyPad(cumSumInput1Local, softMaxGm[baseGmIdx_ + loopInnerOffset],
{1, static_cast<uint32_t>(loopDataNum * sizeof(float)), 0, 0, 0}, {false, 0, 0, 0});
DataCopyPad(cumSumInput2Local, softMaxGm[baseGmIdx_ + loopInnerOffset + iteratOffset],
{1, static_cast<uint32_t>(loopDataNum * sizeof(float)), 0, 0, 0}, {false, 0, 0, 0});
MTE2ToVSync();
Add(cumSumInput1Local, cumSumInput1Local, cumSumInput2Local, loopDataNum);
VToMTE3Sync();
DataCopyPad(softMaxGm[baseGmIdx_ + loopInnerOffset + iteratOffset], cumSumInput1Local,
{1, static_cast<uint32_t>(loopDataNum * sizeof(float)), 0, 0, 0});
MTE3ToMTE2Sync();
}
if (dataTail > 0) {
loopDataNum = dataTail;
DataCopyPad(cumSumInput1Local, softMaxGm[baseGmIdx_],
{1, static_cast<uint32_t>(loopDataNum * sizeof(float)), 0, 0, 0}, {false, 0, 0, 0});
DataCopyPad(cumSumInput2Local, softMaxGm[baseGmIdx_ + iteratOffset],
{1, static_cast<uint32_t>(loopDataNum * sizeof(float)), 0, 0, 0}, {false, 0, 0, 0});
MTE2ToVSync();
Add(cumSumInput1Local, cumSumInput1Local, cumSumInput2Local, loopDataNum);
VToMTE3Sync();
DataCopyPad(softMaxGm[baseGmIdx_ + iteratOffset], cumSumInput1Local,
{1, static_cast<uint32_t>(loopDataNum * sizeof(float)), 0, 0, 0});
MTE3ToMTE2Sync();
}
}
MTE3ToVSync();
}
template <typename inputT, typename calT, typename outputT>
__aicore__ inline void ApplyTopPCustom<inputT, calT, outputT>::ReduceSumWithAddsAndExpImpl(
uint32_t loopDataNum) {
Adds(softMaxResLocal, softMaxLocalFp32, maxValue, loopDataNum);
PipeBarrier<PIPE_V>();
Exp(softMaxResLocal, softMaxResLocal, loopDataNum);
PipeBarrier<PIPE_V>();
ReduceSum(reduceLocal, softMaxResLocal, reduceLocal, loopDataNum);
}
template <typename inputT, typename calT, typename outputT>
__aicore__ inline void ApplyTopPCustom<inputT, calT, outputT>::GetSoftmaxSum(uint32_t loopBatch) {
uint32_t loopDataNum = softmaxLength;
for (int32_t loopInner = 0; loopInner < lineSfLoopTimes; loopInner++) {
int64_t currentGmIdx = baseGmIdx_ + loopInner * softmaxLength;
if (loopInner == (lineSfLoopTimes - 1)) {
loopDataNum = softmaxLengthTail;
}
if constexpr (IsSameType<inputT, float>::value) {
Duplicate(outInfLocal.template ReinterpretCast<int32_t>(), FLOAT32_NEG_INF, loopDataNum);
} else if constexpr (IsSameType<inputT, half>::value) {
Duplicate(outInfLocal.template ReinterpretCast<uint16_t>(), FLOAT16_NEG_INF, loopDataNum);
} else {
Duplicate(outInfLocal.template ReinterpretCast<uint16_t>(), BF16_NEG_INF, loopDataNum);
}
VToMTE3Sync();
DataCopyPad(mGmOut_[currentGmIdx], outInfLocal,
{1, static_cast<uint32_t>(loopDataNum * sizeof(inputT)), 0, 0, 0});
MTE3ToMTE2Sync();
if constexpr (!IsSameType<inputT, float>::value) {
DataCopyPad(softMaxLocal, mGmSortedValue_[currentGmIdx],
{1, static_cast<uint32_t>(loopDataNum * sizeof(inputT)), 0, 0, 0},
{false, 0, 0, 0});
MTE2ToVSync();
Cast(softMaxLocalFp32, softMaxLocal, RoundMode::CAST_NONE, loopDataNum);
PipeBarrier<PIPE_V>();
} else {
DataCopyPad(softMaxLocalFp32, mGmSortedValue_[currentGmIdx],
{1, static_cast<uint32_t>(loopDataNum * sizeof(inputT)), 0, 0, 0},
{false, 0, 0, 0});
MTE2ToVSync();
}
ReduceSumWithAddsAndExpImpl(loopDataNum);
VToSSync();
// Sum up to obtain the sum of exp reduce for the first x loops in the row.
reduceSumValue += reduceLocal.GetValue(0);
SToVSync();
}
reduceSumValueInvert = 1 / reduceSumValue;
SToVSync();
}
} // namespace
#endif // APPLY_TOP_P_CUSTOM_H_KERNEL

View File

@@ -24,7 +24,7 @@ elif [[ "$SOC_VERSION" =~ ^ascend910b ]]; then
ABSOLUTE_CATLASS_PATH=$(cd "${CATLASS_PATH}" && pwd)
export CPATH=${ABSOLUTE_CATLASS_PATH}:${CPATH}
CUSTOM_OPS="grouped_matmul_swiglu_quant_weight_nz_tensor_list;lightning_indexer;sparse_flash_attention;matmul_allreduce_add_rmsnorm;moe_init_routing_custom;moe_gating_top_k;add_rms_norm_bias;"
CUSTOM_OPS="grouped_matmul_swiglu_quant_weight_nz_tensor_list;lightning_indexer;sparse_flash_attention;matmul_allreduce_add_rmsnorm;moe_init_routing_custom;moe_gating_top_k;add_rms_norm_bias;apply_top_k_top_p_custom;"
SOC_ARG="ascend910b"
elif [[ "$SOC_VERSION" =~ ^ascend910_93 ]]; then
# ASCEND910C (A3) series
@@ -80,6 +80,7 @@ elif [[ "$SOC_VERSION" =~ ^ascend910_93 ]]; then
"moe_init_routing_custom"
"moe_gating_top_k"
"add_rms_norm_bias"
"apply_top_k_top_p_custom"
)
CUSTOM_OPS=$(IFS=';'; echo "${CUSTOM_OPS_ARRAY[*]}")
SOC_ARG="ascend910_93"

View File

@@ -1234,6 +1234,26 @@ std::tuple<at::Tensor, at::Tensor, at::Tensor, at::Tensor> npu_moe_init_routing_
return std::tie(expanded_x, expanded_row_idx, expert_tokens_count_or_cumsum, expanded_scale);
}
at::Tensor npu_apply_top_k_top_p(
const at::Tensor& logits,
const c10::optional<at::Tensor>& p,
const c10::optional<at::Tensor>& k)
{
TORCH_CHECK(p.has_value() || k.has_value(),
"apply_top_k_top_p: p and k cannot be None at the same time.");
at::Tensor out = at::empty_like(logits);
EXEC_NPU_CMD(
aclnnApplyTopKTopPCustom,
logits,
p,
k,
out);
return out;
}
std::tuple<at::Tensor, at::Tensor, at::Tensor> moe_gating_top_k(
const at::Tensor& x,
int64_t k,
@@ -1495,4 +1515,7 @@ TORCH_LIBRARY_EXPAND(CONCAT(_C, _ascend), ops)
"-> (Tensor y ,Tensor rstd, Tensor x)"
);
ops.impl("npu_add_rms_norm_bias", torch::kPrivateUse1, &vllm_ascend::npu_add_rms_norm_bias);
ops.def("npu_apply_top_k_top_p(Tensor logits, Tensor? p=None, Tensor? k=None) -> Tensor");
ops.impl("npu_apply_top_k_top_p", torch::kPrivateUse1, &vllm_ascend::npu_apply_top_k_top_p);
}