diff --git a/csrc/apply_top_k_top_p_custom/op_host/CMakeLists.txt b/csrc/apply_top_k_top_p_custom/op_host/CMakeLists.txt
new file mode 100644
index 00000000..634f6a07
--- /dev/null
+++ b/csrc/apply_top_k_top_p_custom/op_host/CMakeLists.txt
@@ -0,0 +1,41 @@
+add_ops_compile_options(  # per-op compile options for ApplyTopKTopPCustom
+        OP_NAME ApplyTopKTopPCustom
+        OPTIONS --cce-auto-sync=on
+                -Wno-deprecated-declarations
+                -Werror
+)
+
+target_sources(op_host_aclnnExc PRIVATE  # op definition source
+        apply_top_k_top_p_custom_def.cpp
+)
+
+target_sources(opapi PRIVATE  # L0 launcher and aclnn entry points
+        apply_top_k_top_p_custom.cpp
+        aclnn_apply_top_k_top_p_custom.cpp
+)
+
+if (NOT BUILD_OPEN_PROJECT)  # the same two sources also go into the train/infer targets in this mode
+    target_sources(aclnn_ops_train PRIVATE
+        apply_top_k_top_p_custom.cpp
+        aclnn_apply_top_k_top_p_custom.cpp
+    )
+
+    target_sources(aclnn_ops_infer PRIVATE
+        apply_top_k_top_p_custom.cpp
+        aclnn_apply_top_k_top_p_custom.cpp
+    )
+endif ()
+
+target_sources(optiling PRIVATE  # tiling implementation
+        apply_top_k_top_p_custom_tiling.cpp
+)
+
+target_include_directories(optiling PRIVATE  # lets the tiling source include headers from this directory
+        ${CMAKE_CURRENT_SOURCE_DIR}
+)
+
+file(GLOB _GMM_Aclnn_header "${CMAKE_CURRENT_SOURCE_DIR}/aclnn_apply_top_k_top_p_custom.h")  # NOTE(review): the "_GMM_" prefix looks copied from another op - confirm
+
+install(FILES ${_GMM_Aclnn_header}  # OPTIONAL: no error if the header is missing
+        DESTINATION ${ACLNN_INC_INSTALL_DIR} OPTIONAL
+)
\ No newline at end of file
diff --git a/csrc/apply_top_k_top_p_custom/op_host/aclnn_apply_top_k_top_p_custom.cpp b/csrc/apply_top_k_top_p_custom/op_host/aclnn_apply_top_k_top_p_custom.cpp
new file mode 100644
index 00000000..d9683524
--- /dev/null
+++ b/csrc/apply_top_k_top_p_custom/op_host/aclnn_apply_top_k_top_p_custom.cpp
@@ -0,0 +1,213 @@
+/**
+ * Copyright (c) 2025 Huawei Technologies Co., Ltd.
+ * This program is free software, you can redistribute it and/or modify it under the terms and conditions of
+ * CANN Open Software License Agreement Version 2.0 (the "License").
+ * Please refer to the License for details. You may not use this file except in compliance with the License.
+ * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED,
+ * INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE.
+ * See LICENSE in the root of the software repository for the full text of the License.
+ */
+
+/*!
+ * \file aclnn_apply_top_k_top_p_custom.cpp
+ * \brief Two-phase aclnn entry points for ApplyTopKTopPCustom: parameter checks, sort, L0 launch, copy to out.
+ */
+#include "aclnn_apply_top_k_top_p_custom.h"
+#include "apply_top_k_top_p_custom.h"
+#include "sort.h"
+#include "aclnn_kernels/contiguous.h"
+#include "aclnn_kernels/common/op_error_check.h"
+#include "aclnn/aclnn_base.h"
+#include "opdev/common_types.h"
+#include "opdev/data_type_utils.h"
+#include "opdev/make_op_executor.h"
+#include "opdev/format_utils.h"
+#include "opdev/op_dfx.h"
+#include "opdev/op_executor.h"
+#include "opdev/op_log.h"
+#include "opdev/tensor_view_utils.h"
+#include "opdev/shape_utils.h"
+#include "opdev/platform.h"
+
+using namespace op;
+#ifdef __cplusplus
+extern "C" {
+#endif
+namespace {
+static const int64_t EXPECTED_DIM_ONE = 1;
+static const int64_t EXPECTED_DIM_TWO = 2;
+static constexpr size_t DIM_ONE = 1;
+
+// Per the API definition, every supported dtype must be listed explicitly.
+static const std::initializer_list<op::DataType> DTYPE_SUPPORT_LIST = {
+    op::DataType::DT_FLOAT, op::DataType::DT_FLOAT16, op::DataType::DT_BF16};
+
+static const std::initializer_list<op::DataType> INT_DTYPE_SUPPORT_LIST = {
+    op::DataType::DT_INT32};
+
+// Returns false when logits or out is null, or when p and k are both null (at least one filter is required).
+static bool CheckNotNull(const aclTensor* logits, const aclTensor* p, const aclTensor *k, const aclTensor* out)
+{
+    OP_CHECK_NULL(logits, return false);
+    if (p == nullptr && k == nullptr) {
+        OP_LOGE(ACLNN_ERR_PARAM_INVALID, "The inputs, p and k, should not be nullptr at the same time.");
+        // Previously only logged; without this return the invalid call was accepted.
+        return false;
+    }
+    OP_CHECK_NULL(out, return false);
+    return true;
+}
+
+// Checks that every present tensor has a supported dtype and that p and out match the dtype of logits.
+static bool CheckDtypeValid(const aclTensor* logits, const aclTensor* p, const aclTensor *k, const aclTensor* out)
+{
+    // Check that the dtypes are in the supported lists.
+    OP_CHECK_DTYPE_NOT_SUPPORT(logits, DTYPE_SUPPORT_LIST, return false);
+    if (p != nullptr) {
+        OP_CHECK_DTYPE_NOT_SUPPORT(p, DTYPE_SUPPORT_LIST, return false);
+    }
+    if (k != nullptr) {
+        OP_CHECK_DTYPE_NOT_SUPPORT(k, INT_DTYPE_SUPPORT_LIST, return false);
+    }
+    OP_CHECK_DTYPE_NOT_SUPPORT(out, DTYPE_SUPPORT_LIST, return false);
+
+    // Check that the dtypes agree with each other.
+    if (p != nullptr) {
+        OP_CHECK_DTYPE_NOT_MATCH(p, logits->GetDataType(), return false);
+    }
+    OP_CHECK_DTYPE_NOT_MATCH(out, logits->GetDataType(), return false);
+    return true;
+}
+
+// Checks ranks (logits: 2-D, p/k: 1-D), that out has the shape of logits, and that p/k have logits.size(0) elements.
+static bool CheckShapeValid(const aclTensor* logits, const aclTensor* p, const aclTensor *k, const aclTensor* out)
+{
+    OP_CHECK_WRONG_DIMENSION(logits, EXPECTED_DIM_TWO, return false);
+    OP_CHECK_SHAPE_NOT_EQUAL(out, logits, return false);
+    if (p != nullptr) {
+        OP_CHECK_WRONG_DIMENSION(p, EXPECTED_DIM_ONE, return false);
+    }
+    if (k != nullptr) {
+        OP_CHECK_WRONG_DIMENSION(k, EXPECTED_DIM_ONE, return false);
+    }
+    if (p != nullptr && p->GetViewShape().GetDim(0) != logits->GetViewShape().GetDim(0)) {
+        OP_LOGE(ACLNN_ERR_PARAM_INVALID, "expected p.size(0) is equal to logits.size(0), but got %ld.",
+            p->GetViewShape().GetDim(0));
+        return false;
+    }
+    if (k != nullptr && k->GetViewShape().GetDim(0) != logits->GetViewShape().GetDim(0)) {
+        OP_LOGE(ACLNN_ERR_PARAM_INVALID, "expected k.size(0) is equal to logits.size(0), but got %ld.",
+            k->GetViewShape().GetDim(0));
+        return false;
+    }
+    return true;
+}
+
+// Checks that every present tensor uses the ND storage format.
+static bool CheckFormatValid(const aclTensor* logits, const aclTensor* p, const aclTensor *k, const aclTensor* out)
+{
+    if (logits->GetStorageFormat() != Format::FORMAT_ND) {
+        OP_LOGE(ACLNN_ERR_PARAM_INVALID, "logits format only support ND");
+        return false;
+    }
+    if (p != nullptr && p->GetStorageFormat() != Format::FORMAT_ND) {
+        OP_LOGE(ACLNN_ERR_PARAM_INVALID, "p format only support ND");
+        return false;
+    }
+    if (k != nullptr && k->GetStorageFormat() != Format::FORMAT_ND) {
+        OP_LOGE(ACLNN_ERR_PARAM_INVALID, "k format only support ND");
+        return false;
+    }
+    if (out->GetStorageFormat() != Format::FORMAT_ND) {
+        OP_LOGE(ACLNN_ERR_PARAM_INVALID, "out format only support ND");
+        return false;
+    }
+    return true;
+}
+
+// Runs the null, dtype, shape and format checks in that order and maps the first failure to an aclnn error code.
+static aclnnStatus CheckParams(const aclTensor* logits, const aclTensor* p, const aclTensor *k, const aclTensor* out)
+{
+    // Error codes will be refreshed once the DFX scheme is refined; error logs are printed inside the checks.
+    // 1. Check for null pointers.
+    CHECK_RET(CheckNotNull(logits, p, k, out), ACLNN_ERR_PARAM_NULLPTR);
+
+    // 2. Check that the input dtypes are within the range supported by the API.
+    CHECK_RET(CheckDtypeValid(logits, p, k, out), ACLNN_ERR_PARAM_INVALID);
+
+    // 3. Check the shape constraints.
+    CHECK_RET(CheckShapeValid(logits, p, k, out), ACLNN_ERR_PARAM_INVALID);
+
+    // 4. Check the format constraints.
+    CHECK_RET(CheckFormatValid(logits, p, k, out), ACLNN_ERR_PARAM_INVALID);
+
+    return ACLNN_SUCCESS;
+}
+} // namespace
+
+// Phase 1: validates the arguments, builds the executor graph (contiguous -> sort -> ApplyTopKTopPCustom -> copy
+// to out) and reports the workspace size. Empty inputs and a last dimension of size 1 skip the kernel.
+aclnnStatus aclnnApplyTopKTopPCustomGetWorkspaceSize(
+    const aclTensor* logits, const aclTensor* p, const aclTensor* k, aclTensor* out, uint64_t* workspaceSize,
+    aclOpExecutor** executor)
+{
+    OP_CHECK_COMM_INPUT(workspaceSize, executor);
+    L2_DFX_PHASE_1(aclnnApplyTopKTopPCustom, DFX_IN(logits, p, k), DFX_OUT(out));
+    // Boilerplate: create the OpExecutor.
+    auto uniqueExecutor = CREATE_EXECUTOR();
+    CHECK_RET(uniqueExecutor.get() != nullptr, ACLNN_ERR_INNER_CREATE_EXECUTOR);
+
+    // Boilerplate: parameter checks.
+    auto ret = CheckParams(logits, p, k, out);
+    CHECK_RET(ret == ACLNN_SUCCESS, ret);
+    bool pIsEmpty = false;
+    bool kIsEmpty = false;
+    if (p != nullptr) {
+        pIsEmpty = p->IsEmpty();
+    }
+    if (k != nullptr) {
+        kIsEmpty = k->IsEmpty();
+    }
+    if (logits->IsEmpty() || pIsEmpty || kIsEmpty) {
+        // Empty tensor: nothing to compute, no workspace needed.
+        *workspaceSize = 0;
+        uniqueExecutor.ReleaseTo(executor);
+        return ACLNN_SUCCESS;
+    }
+    // Boilerplate: convert the inputs to contiguous tensors.
+    auto logitsContiguous = l0op::Contiguous(logits, uniqueExecutor.get());
+    CHECK_RET(logitsContiguous != nullptr, ACLNN_ERR_INNER_NULLPTR);
+    const aclTensor* pContiguous = nullptr;
+    const aclTensor* kContiguous = nullptr;
+    if (p != nullptr) {
+        pContiguous = l0op::Contiguous(p, uniqueExecutor.get());
+        CHECK_RET(pContiguous != nullptr, ACLNN_ERR_INNER_NULLPTR);
+    }
+    if (k != nullptr) {
+        kContiguous = l0op::Contiguous(k, uniqueExecutor.get());
+        CHECK_RET(kContiguous != nullptr, ACLNN_ERR_INNER_NULLPTR);
+    }
+    bool isLastDimSizeOne = logits->GetViewShape()[DIM_ONE] == 1;
+    auto viewCopyResult = logitsContiguous;
+    if (isLastDimSizeOne) {
+        // A single candidate per row: logits are copied to out unchanged.
+        viewCopyResult = l0op::ViewCopy(logitsContiguous, out, uniqueExecutor.get());
+    } else {
+        auto sortResult = l0op::Sort(logitsContiguous, -1, false, true, op::DataType::DT_INT32, uniqueExecutor.get());
+        const aclTensor* sortedValue = std::get<0>(sortResult);
+        CHECK_RET(sortedValue != nullptr, ACLNN_ERR_INNER_NULLPTR);
+        const aclTensor* sortedIndices = std::get<1>(sortResult);
+        CHECK_RET(sortedIndices != nullptr, ACLNN_ERR_INNER_NULLPTR);
+        auto res = l0op::ApplyTopKTopPCustom(sortedValue, sortedIndices, pContiguous, kContiguous, uniqueExecutor.get());
+        CHECK_RET(res != nullptr, ACLNN_ERR_INNER_NULLPTR);
+        // Boilerplate: copy the result into out, which may be a non-contiguous tensor.
+        viewCopyResult = l0op::ViewCopy(res, out, uniqueExecutor.get());
+    }
+    CHECK_RET(viewCopyResult != nullptr, ACLNN_ERR_INNER_NULLPTR);
+    // Boilerplate: query the workspace size needed by the computation.
+    *workspaceSize = uniqueExecutor->GetWorkspaceSize();
+    uniqueExecutor.ReleaseTo(executor); // hand ownership of the executor over to the caller
+    return ACLNN_SUCCESS;
+}
+
+// Phase 2: runs the executor built in phase 1 on the given stream.
+aclnnStatus aclnnApplyTopKTopPCustom(void* workspace, uint64_t workspaceSize, aclOpExecutor* executor, aclrtStream stream)
+{
+    L2_DFX_PHASE_2(aclnnApplyTopKTopPCustom);
+    // Boilerplate: let the framework execute the computation.
+    return CommonOpExecutorRun(workspace, workspaceSize, executor, stream);
+}
+
+#ifdef __cplusplus
+}
+#endif
+
diff --git a/csrc/apply_top_k_top_p_custom/op_host/aclnn_apply_top_k_top_p_custom.h b/csrc/apply_top_k_top_p_custom/op_host/aclnn_apply_top_k_top_p_custom.h
new file mode 100644
index 00000000..95c7a1b9
--- /dev/null
+++ b/csrc/apply_top_k_top_p_custom/op_host/aclnn_apply_top_k_top_p_custom.h
@@ -0,0 +1,54 @@
+/**
+ * Copyright (c) 2025 Huawei Technologies Co., Ltd.
+ * This program is free software, you can redistribute it and/or modify it under the terms and conditions of
+ * CANN Open Software License Agreement Version 2.0 (the "License").
+ * Please refer to the License for details. You may not use this file except in compliance with the License.
+ * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED,
+ * INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE.
+ * See LICENSE in the root of the software repository for the full text of the License.
+ */
+
+/*!
+ * \file aclnn_apply_top_k_top_p_custom.h
+ * \brief Public C declarations of the two-phase aclnnApplyTopKTopPCustom API.
+ */
+#ifndef OP_API_INC_APPLY_TOP_K_TOP_P_CUSTOM_H_
+#define OP_API_INC_APPLY_TOP_K_TOP_P_CUSTOM_H_
+
+#include "aclnn/aclnn_base.h"
+
+#ifdef __cplusplus
+extern "C" {
+#endif
+
+/**
+ * @brief First-phase interface of aclnnApplyTopKTopPCustom: computes the workspace size needed by the computation flow.
+ * @domain aclnn_ops_infer
+ * @param [in] logits: aclTensor on the NPU device; dtypes FLOAT, FLOAT16, BFLOAT16; non-contiguous tensors supported; format ND.
+ * @param [in] p: aclTensor on the NPU device; dtypes FLOAT, FLOAT16, BFLOAT16; non-contiguous tensors supported; format ND.
+ * @param [in] k: aclTensor on the NPU device; dtype INT32; non-contiguous tensors supported; format ND.
+ * @param [out] out: aclTensor on the NPU device; dtypes FLOAT, FLOAT16, BFLOAT16; non-contiguous tensors supported; format ND.
+ * @param [out] workspaceSize: returns the workspace size the caller has to allocate on the NPU device.
+ * @param [out] executor: returns the op executor, which holds the operator computation flow.
+ * @return aclnnStatus: status code.
+ */
+aclnnStatus aclnnApplyTopKTopPCustomGetWorkspaceSize(const aclTensor* logits, const aclTensor* p,
+    const aclTensor* k, aclTensor* out, uint64_t* workspaceSize,
+    aclOpExecutor** executor);
+
+/**
+ * @brief Second-phase interface of aclnnApplyTopKTopPCustom: executes the computation.
+ * @param [in] workspace: start address of the workspace memory allocated on the NPU device.
+ * @param [in] workspaceSize: size of that workspace, obtained from the first-phase interface aclnnApplyTopKTopPCustomGetWorkspaceSize.
+ * @param [in] stream: acl stream.
+ * @param [in] executor: op executor, which holds the operator computation flow.
+ * @return aclnnStatus: status code.
+ */
+aclnnStatus aclnnApplyTopKTopPCustom(void* workspace, uint64_t workspaceSize, aclOpExecutor* executor,
+    aclrtStream stream);
+
+#ifdef __cplusplus
+}
+#endif
+
+#endif // OP_API_INC_APPLY_TOP_K_TOP_P_CUSTOM_H_
diff --git a/csrc/apply_top_k_top_p_custom/op_host/apply_top_k_top_p_custom.cpp b/csrc/apply_top_k_top_p_custom/op_host/apply_top_k_top_p_custom.cpp
new file mode 100644
index 00000000..4fd7d3b7
--- /dev/null
+++ b/csrc/apply_top_k_top_p_custom/op_host/apply_top_k_top_p_custom.cpp
@@ -0,0 +1,46 @@
+/**
+ * Copyright (c) 2025 Huawei Technologies Co., Ltd.
+ * This program is free software, you can redistribute it and/or modify it under the terms and conditions of
+ * CANN Open Software License Agreement Version 2.0 (the "License").
+ * Please refer to the License for details. You may not use this file except in compliance with the License.
+ * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED,
+ * INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE.
+ * See LICENSE in the root of the software repository for the full text of the License.
+ */
+
+/*!
+ * \file apply_top_k_top_p_custom.cpp
+ * \brief L0 launcher that queues the ApplyTopKTopPCustom AI Core kernel on an executor.
+ */
+#include "apply_top_k_top_p_custom.h"
+#include "opdev/data_type_utils.h"
+#include "opdev/format_utils.h"
+#include "opdev/make_op_executor.h"
+#include "opdev/op_def.h"
+#include "opdev/op_dfx.h"
+#include "opdev/op_executor.h"
+#include "opdev/op_log.h"
+#include "opdev/shape_utils.h"
+using namespace op;
+
+namespace l0op {
+
+OP_TYPE_REGISTER(ApplyTopKTopPCustom);
+
+// Allocates an output with the shape and dtype of sortedValue and adds the kernel to the executor's launch list.
+// A null p or k is replaced by a placeholder tensor allocated without a shape.
+// Returns the output tensor, or nullptr when a tensor allocation fails.
+const aclTensor* ApplyTopKTopPCustom(
+    const aclTensor* sortedValue, const aclTensor* sortedIndices, const aclTensor* p, const aclTensor* k,
+    aclOpExecutor* executor)
+{
+    L0_DFX(ApplyTopKTopPCustom, sortedValue, sortedIndices, p, k);
+    auto output = executor->AllocTensor(sortedValue->GetViewShape(), sortedValue->GetDataType());
+    if (output == nullptr) {
+        return nullptr;
+    }
+    if (p == nullptr) {
+        p = executor->AllocTensor(sortedValue->GetDataType(), Format::FORMAT_ND, Format::FORMAT_ND);
+    }
+    if (k == nullptr) {
+        k = executor->AllocTensor(DataType::DT_INT32, Format::FORMAT_ND, Format::FORMAT_ND);
+    }
+    // Do not launch with a missing placeholder; the caller treats nullptr as an internal error.
+    if (p == nullptr || k == nullptr) {
+        return nullptr;
+    }
+    ADD_TO_LAUNCHER_LIST_AICORE(ApplyTopKTopPCustom, OP_INPUT(sortedValue, sortedIndices, p, k), OP_OUTPUT(output));
+
+    return output;
+}
+} // namespace l0op
diff --git a/csrc/apply_top_k_top_p_custom/op_host/apply_top_k_top_p_custom.h b/csrc/apply_top_k_top_p_custom/op_host/apply_top_k_top_p_custom.h
new file mode 100644
index 00000000..87fd50b5
--- /dev/null
+++ b/csrc/apply_top_k_top_p_custom/op_host/apply_top_k_top_p_custom.h
@@ -0,0 +1,24 @@
+/**
+ * Copyright (c) 2025 Huawei Technologies Co., Ltd.
+ * This program is free software, you can redistribute it and/or modify it under the terms and conditions of
+ * CANN Open Software License Agreement Version 2.0 (the "License").
+ * Please refer to the License for details. You may not use this file except in compliance with the License.
+ * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED,
+ * INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE.
+ * See LICENSE in the root of the software repository for the full text of the License.
+ */
+
+/*!
+ * \file apply_top_k_top_p_custom.h
+ * \brief Declaration of the L0 launcher l0op::ApplyTopKTopPCustom.
+ */
+#ifndef OP_API_INC_LEVEL0_OP_APPLY_TOP_K_TOP_P_CUSTOM_OP_H_
+#define OP_API_INC_LEVEL0_OP_APPLY_TOP_K_TOP_P_CUSTOM_OP_H_
+
+#include "opdev/op_executor.h"
+
+namespace l0op {
+// Queues the ApplyTopKTopPCustom kernel on `executor` and returns its output tensor; p and k may be nullptr.
+const aclTensor* ApplyTopKTopPCustom(const aclTensor* sortedValue, const aclTensor* sortedIndices,
+    const aclTensor* p, const aclTensor* k, aclOpExecutor* executor);
+}
+#endif // OP_API_INC_LEVEL0_OP_APPLY_TOP_K_TOP_P_CUSTOM_OP_H_
\ No newline at end of file
diff --git a/csrc/apply_top_k_top_p_custom/op_host/apply_top_k_top_p_custom_def.cpp b/csrc/apply_top_k_top_p_custom/op_host/apply_top_k_top_p_custom_def.cpp
new file mode 100644
index 00000000..319953b3
--- /dev/null
+++ b/csrc/apply_top_k_top_p_custom/op_host/apply_top_k_top_p_custom_def.cpp
@@ -0,0 +1,100 @@
+/**
+ * Copyright (c) 2025 Huawei Technologies Co., Ltd.
+ * This program is free software, you can redistribute it and/or modify it under the terms and conditions of
+ * CANN Open Software License Agreement Version 2.0 (the "License").
+ * Please refer to the License for details. You may not use this file except in compliance with the License.
+ * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED,
+ * INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE.
+ * See LICENSE in the root of the software repository for the full text of the License.
+ */
+
+/*!
+ * \file apply_top_k_top_p_custom_def.cpp
+ * \brief Operator definition (inputs, output, supported dtypes/formats, target SoCs) of ApplyTopKTopPCustom.
+ */
+#include "register/op_def_registry.h"
+
+namespace ops {
+// Declares the operator prototype: required sorted_value / sorted_indices, optional p / k, and one output.
+// The i-th entry of each DataType/Format list belongs to the i-th supported combination.
+class ApplyTopKTopPCustom : public OpDef {
+public:
+    explicit ApplyTopKTopPCustom(const char *name) : OpDef(name)
+    {
+        this->Input("sorted_value")
+            .ParamType(REQUIRED)
+            .DataType({ge::DT_FLOAT, ge::DT_FLOAT16, ge::DT_BF16})
+            .Format({ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND})
+            .UnknownShapeFormat({ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND});
+        this->Input("sorted_indices")
+            .ParamType(REQUIRED)
+            .DataType({ge::DT_INT32, ge::DT_INT32, ge::DT_INT32})
+            .Format({ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND})
+            .UnknownShapeFormat({ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND});
+        this->Input("p")
+            .ParamType(OPTIONAL)
+            .DataType({ge::DT_FLOAT, ge::DT_FLOAT16, ge::DT_BF16})
+            .Format({ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND})
+            .UnknownShapeFormat({ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND});
+        this->Input("k")
+            .ParamType(OPTIONAL)
+            .DataType({ge::DT_INT32, ge::DT_INT32, ge::DT_INT32})
+            .Format({ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND})
+            .UnknownShapeFormat({ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND});
+        this->Output("out")
+            .ParamType(REQUIRED)
+            .DataType({ge::DT_FLOAT, ge::DT_FLOAT16, ge::DT_BF16})
+            .Format({ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND})
+            .UnknownShapeFormat({ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND});
+
+        // NOTE(review): aicore_config is configured here but never passed to AddConfig below, so these flags
+        // have no effect on the two Ascend registrations. Confirm whether AddConfig(soc, aicore_config) was intended.
+        OpAICoreConfig aicore_config;
+        aicore_config.DynamicCompileStaticFlag(true)
+            .DynamicFormatFlag(false)
+            .DynamicRankSupportFlag(true)
+            .DynamicShapeSupportFlag(true);
+        this->AICore().AddConfig("ascend910b");
+        this->AICore().AddConfig("ascend910_93");
+
+        // The Kirin target gets its own config, which re-declares the prototype without BF16.
+        OpAICoreConfig config_kirin = GetKirinCoreConfig();
+        this->AICore().AddConfig("kirinx90", config_kirin);
+    }
+
+private:
+    // Builds the AI Core config registered for "kirinx90": FLOAT / FLOAT16 only, ND format.
+    OpAICoreConfig GetKirinCoreConfig() const
+    {
+        OpAICoreConfig config_kirin;
+        config_kirin.DynamicCompileStaticFlag(true)
+            .DynamicFormatFlag(true)
+            .DynamicRankSupportFlag(true)
+            .DynamicShapeSupportFlag(true)
+            .NeedCheckSupportFlag(false)
+            .PrecisionReduceFlag(true);
+        config_kirin.Input("sorted_value")
+            .ParamType(REQUIRED)
+            .DataType({ge::DT_FLOAT, ge::DT_FLOAT16})
+            .Format({ge::FORMAT_ND, ge::FORMAT_ND})
+            .UnknownShapeFormat({ge::FORMAT_ND, ge::FORMAT_ND});
+        config_kirin.Input("sorted_indices")
+            .ParamType(REQUIRED)
+            .DataType({ge::DT_INT32, ge::DT_INT32})
+            .Format({ge::FORMAT_ND, ge::FORMAT_ND})
+            .UnknownShapeFormat({ge::FORMAT_ND, ge::FORMAT_ND});
+        config_kirin.Input("p")
+            .ParamType(OPTIONAL)
+            .DataType({ge::DT_FLOAT, ge::DT_FLOAT16})
+            .Format({ge::FORMAT_ND, ge::FORMAT_ND})
+            .UnknownShapeFormat({ge::FORMAT_ND, ge::FORMAT_ND});
+        config_kirin.Input("k")
+            .ParamType(OPTIONAL)
+            .DataType({ge::DT_INT32, ge::DT_INT32})
+            .Format({ge::FORMAT_ND, ge::FORMAT_ND})
+            .UnknownShapeFormat({ge::FORMAT_ND, ge::FORMAT_ND});
+        config_kirin.Output("out")
+            .ParamType(REQUIRED)
+            .DataType({ge::DT_FLOAT, ge::DT_FLOAT16})
+            .Format({ge::FORMAT_ND, ge::FORMAT_ND})
+            .UnknownShapeFormat({ge::FORMAT_ND, ge::FORMAT_ND});
+        return config_kirin;
+    }
+};
+
+OP_ADD(ApplyTopKTopPCustom);
+} // namespace ops
\ No newline at end of file
diff --git a/csrc/apply_top_k_top_p_custom/op_host/apply_top_k_top_p_custom_tiling.cpp b/csrc/apply_top_k_top_p_custom/op_host/apply_top_k_top_p_custom_tiling.cpp
new file mode 100644
index 00000000..03282f3d
--- /dev/null
+++ b/csrc/apply_top_k_top_p_custom/op_host/apply_top_k_top_p_custom_tiling.cpp
@@ -0,0 +1,314 @@
+/**
+ * Copyright (c) 2025 Huawei Technologies Co., Ltd.
+ * This program is free software, you can redistribute it and/or modify it under the terms and conditions of
+ * CANN Open Software License Agreement Version 2.0 (the "License").
+ * Please refer to the License for details. You may not use this file except in compliance with the License.
+ * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED,
+ * INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE.
+ * See LICENSE in the root of the software repository for the full text of the License.
+ */
+
+/*!
+ * \file apply_top_k_top_p_custom_tiling.cpp
+ * \brief Host-side tiling for ApplyTopKTopPCustom: shape checks, core split, UB budget and workspace size.
+ */
+
+#include <cstdint>
+#include <map>
+#include "error_log.h"
+#include "tiling/platform/platform_ascendc.h"
+#include "platform/platform_infos_def.h"
+#include "register/op_def_registry.h"
+#include "register/op_impl_registry.h"
+#include "tiling/tiling_api.h"
+#include "apply_top_k_top_p_custom_tiling.h"
+
+namespace {
+    constexpr uint32_t SYS_RESERVED_UB = uint32_t(16 * 1024);
+    constexpr uint32_t SELECT_RESERVED_UB = uint32_t(8 * 1024);
+    constexpr uint32_t DIM_ONE = 1;
+    constexpr uint32_t DIM_TWO = 2;
+    constexpr int32_t SORTED_VALUE_INPUT_INDEX = 0;
+    constexpr int32_t SORTED_INDICES_INPUT_INDEX = 1;
+    constexpr int32_t P_INPUT_INDEX = 2;
+    constexpr int32_t K_INPUT_INDEX = 3;
+    constexpr uint32_t DIM_INDEX0 = 0;
+    constexpr uint32_t FLOAT_BYTES = 4;
+    static std::map<ge::DataType, uint32_t> DTYPE_MAP = {{ge::DT_BF16, 2}, {ge::DT_FLOAT16, 1}, {ge::DT_FLOAT, 0}};
+    // Element size in bytes for each supported value dtype.
+    static std::map<ge::DataType, uint32_t> DATATYPE_LEN_MAP = {
+        {ge::DT_FLOAT16, 2}, {ge::DT_BF16, 2}, {ge::DT_FLOAT, 4}};
+    const static uint32_t SYS_WORKSPACESIZE = uint32_t(16 * 1024 * 1024);
+
+    constexpr uint32_t DATA_PER_BLOCK_B32 = 8;
+    constexpr uint32_t BYTES_B32 = 4;
+    constexpr uint32_t BLOCK_BYTES = 32;
+    constexpr uint32_t K_VALUE_MAX = 1024;
+    constexpr uint32_t ONLY_TOP_P_KEY = 2;
+    constexpr uint32_t ONLY_TOP_K_KEY = 1;
+    constexpr uint32_t BATCH_MODE = 1;
+} // namespace
+
+namespace optiling {
+// Computes and publishes the tiling data for one ApplyTopKTopPCustom launch.
+// Usage: construct with the tiling context, call Init(), then RunKernelTiling().
+class ApplyTopKTopPCustomTiling {
+public:
+    explicit ApplyTopKTopPCustomTiling(gert::TilingContext* context) : tilingcontext(context){}
+    ge::graphStatus Init();
+    ge::graphStatus RunKernelTiling();
+private:
+    ApplyTopKTopPCustomTilingData tilingData;
+    gert::TilingContext* tilingcontext = nullptr;
+    ge::graphStatus CheckShape();
+    void SetTilingKey();
+    void GetUsedCore();
+    void CalDataPerCore();
+    void FillTilingData();
+    void PrintTilingData();
+    // Rounds a up to a multiple of b (returns a unchanged when b is 0).
+    template <typename T1>
+    inline auto CeilAlign(T1 a, T1 b) const -> T1
+    {
+        return b == 0 ? a : (a + b - 1) / b * b;
+    }
+    // Rounds a down to a multiple of b (returns a unchanged when b is 0).
+    template <typename T1>
+    inline auto FloorAlign(T1 a, T1 b) const -> T1
+    {
+        return b == 0 ? a : a / b * b;
+    }
+
+    const char *opName_ = nullptr;
+    uint32_t coreNum_ = 0;
+    uint32_t calUbSize_ = 0;
+    uint32_t batchSize_ = 0;
+    uint32_t vocabSize_ = 0;
+    uint32_t tilingKey_ = 0;
+    uint32_t usedCoreNum_ = 0;
+    uint32_t batchPerCore_ = 1;
+    uint32_t tailBatch_ = 0;
+    uint32_t dataNumInit_ = 0;
+    uint32_t dataNumInitAligned_ = 0;
+    uint32_t ubFactorElement_ = 0;
+    uint32_t ubFactorElementAligned_ = 0;
+    uint32_t tailUbFactorElement_ = 0;
+    uint32_t tailUbFactorElementAligned_ = 0;
+    uint32_t iterateTimes_ = 0;
+    uint32_t onlyTopK_ = 0;
+    uint32_t onlyTopP_ = 0;
+    uint64_t platformUbSize_ = 0;
+};
+
+// Validates the input shapes, records batchSize_/vocabSize_, and derives the only-top-k / only-top-p flags.
+// An optional input counts as absent when its dimNum is 0.
+ge::graphStatus ApplyTopKTopPCustomTiling::CheckShape() {
+    auto sortedValueShapePtr = tilingcontext->GetInputShape(SORTED_VALUE_INPUT_INDEX);
+    OP_CHECK_NULL_WITH_CONTEXT(tilingcontext, sortedValueShapePtr);
+    auto sortedValueShape = sortedValueShapePtr->GetStorageShape();
+    if (sortedValueShape.GetDimNum() != DIM_TWO) {
+        OP_LOGE(opName_, "the dimNum of sorted_value should be 2, but got %zu.", sortedValueShape.GetDimNum());
+        return ge::GRAPH_FAILED;
+    }
+    auto sortedIndicesShapePtr = tilingcontext->GetInputShape(SORTED_INDICES_INPUT_INDEX);
+    OP_CHECK_NULL_WITH_CONTEXT(tilingcontext, sortedIndicesShapePtr);
+    auto sortedIndicesShape = sortedIndicesShapePtr->GetStorageShape();
+    if (sortedIndicesShape.GetDimNum() != DIM_TWO) {
+        OP_LOGE(opName_, "the dimNum of sorted_indices should be 2, but got %zu.", sortedIndicesShape.GetDimNum());
+        return ge::GRAPH_FAILED;
+    }
+    batchSize_ = sortedValueShape.GetDim(DIM_INDEX0);
+    vocabSize_ = sortedValueShape.GetDim(DIM_ONE);
+    if (sortedIndicesShape.GetDim(DIM_INDEX0) != batchSize_ || sortedIndicesShape.GetDim(DIM_ONE) != vocabSize_) {
+        OP_LOGE(opName_, "the shape of sorted_indices should be equal to sorted_value.");
+        return ge::GRAPH_FAILED;
+    }
+
+    auto pShapePtr = tilingcontext->GetOptionalInputShape(P_INPUT_INDEX);
+    OP_CHECK_NULL_WITH_CONTEXT(tilingcontext, pShapePtr);
+    auto pShape = pShapePtr->GetStorageShape();
+    auto pDimNum = pShape.GetDimNum();
+    if (pDimNum != DIM_ONE && pDimNum != 0) {
+        OP_LOGE(opName_, "the dimNum of p should be 1 or 0, but got %zu.", pDimNum);
+        return ge::GRAPH_FAILED;
+    }
+    if (pDimNum != 0 && batchSize_ != pShape.GetDim(DIM_INDEX0)) {
+        OP_LOGE(opName_, "p.shape[0] should be equal to logits.shape[0].");
+        return ge::GRAPH_FAILED;
+    }
+
+    auto kShapePtr = tilingcontext->GetOptionalInputShape(K_INPUT_INDEX);
+    OP_CHECK_NULL_WITH_CONTEXT(tilingcontext, kShapePtr);
+    auto kShape = kShapePtr->GetStorageShape();
+    auto kDimNum = kShape.GetDimNum();
+    if (kDimNum != DIM_ONE && kDimNum != 0) {
+        OP_LOGE(opName_, "the dimNum of k should be 1 or 0, but got %zu.", kShape.GetDimNum());
+        return ge::GRAPH_FAILED;
+    }
+    if (kDimNum != 0 && batchSize_ != kShape.GetDim(DIM_INDEX0)) {
+        OP_LOGE(opName_, "k.shape[0] should be equal to logits.shape[0].");
+        return ge::GRAPH_FAILED;
+    }
+    if (kDimNum == 0 && pDimNum == 0) {
+        // The previous message named "q" and stated the opposite of the condition being rejected.
+        OP_LOGE(opName_, "the dimNum of p and k should not be 0 at the same time.");
+        return ge::GRAPH_FAILED;
+    }
+    onlyTopK_ = (kDimNum != 0 && pDimNum == 0) ? ONLY_TOP_K_KEY : 0;
+    onlyTopP_ = (pDimNum != 0 && kDimNum == 0) ? ONLY_TOP_P_KEY : 0;
+    return ge::GRAPH_SUCCESS;
+}
+
+// Reads core count and UB size from the platform, validates shapes, and computes ceil(log2(vocabSize_)).
+ge::graphStatus ApplyTopKTopPCustomTiling::Init() {
+    opName_ = tilingcontext->GetNodeName();
+    OP_LOGD(opName_, "TilingForApplyTopKTopPCustom init.");
+    auto platformInfoPtr = tilingcontext->GetPlatformInfo();
+    OP_CHECK_NULL_WITH_CONTEXT(tilingcontext, platformInfoPtr);
+    auto platformInfo = platform_ascendc::PlatformAscendC(platformInfoPtr);
+    coreNum_ = platformInfo.GetCoreNumAiv();
+    platformInfo.GetCoreMemSize(platform_ascendc::CoreMemType::UB, platformUbSize_);
+    OP_LOGD(opName_, "platformUbSize: %lu.", platformUbSize_);
+    uint32_t avaliableUb = static_cast<uint32_t>(platformUbSize_) - SYS_RESERVED_UB - SELECT_RESERVED_UB;
+    calUbSize_ = FloorAlign(avaliableUb, BLOCK_BYTES);
+    if (CheckShape() == ge::GRAPH_FAILED) {
+        OP_LOGE(opName_, "check shape failed.");
+        return ge::GRAPH_FAILED;
+    }
+    // 64-bit counter: a 32-bit one wraps to 0 and never terminates once vocabSize_ exceeds 2^31.
+    uint64_t tempValue = 1;
+    while (tempValue < vocabSize_) {
+        tempValue <<= 1;
+        iterateTimes_++;
+    } // ceil(log2(vocabSize_))
+    return ge::GRAPH_SUCCESS;
+}
+
+// Tiling key: 0 = both p and k, 1 = only top-k, 2 = only top-p. Only-top-p also selects schedule mode BATCH_MODE.
+void ApplyTopKTopPCustomTiling::SetTilingKey() {
+    tilingKey_ += onlyTopK_;
+    tilingKey_ += onlyTopP_;
+    tilingcontext->SetTilingKey(tilingKey_);
+    if (tilingKey_ == ONLY_TOP_P_KEY){
+        tilingcontext->SetScheduleMode(BATCH_MODE);
+    }
+}
+
+// Splits the batch over all vector cores: batchPerCore_ rows each, tailBatch_ rows left over.
+// When no core is reported, the member defaults are kept (usedCoreNum_ stays 0).
+void ApplyTopKTopPCustomTiling::GetUsedCore()
+{
+    if (coreNum_ > 0) {
+        batchPerCore_ = batchSize_ / coreNum_;
+        tailBatch_ = batchSize_ % coreNum_;
+        usedCoreNum_ = coreNum_;
+    }
+}
+
+// Derives the per-iteration element counts (capped at K_VALUE_MAX) and the UB bytes left for computation.
+void ApplyTopKTopPCustomTiling::CalDataPerCore()
+{
+    // NOTE(review): an input dtype missing from DATATYPE_LEN_MAP yields 0 here and a division by zero below;
+    // this relies on the op definition restricting the dtype to FLOAT / FLOAT16 / BF16.
+    uint32_t inputDataTypeByte = DATATYPE_LEN_MAP[tilingcontext->GetInputDesc(SORTED_VALUE_INPUT_INDEX)->GetDataType()];
+    uint32_t dataPerBlock = BLOCK_BYTES / inputDataTypeByte;
+    dataNumInit_ = vocabSize_ < K_VALUE_MAX ? vocabSize_ : K_VALUE_MAX;
+    // NOTE(review): despite its name this value is not block-aligned; confirm what the kernel expects.
+    dataNumInitAligned_ = vocabSize_ < K_VALUE_MAX ? vocabSize_ : K_VALUE_MAX;
+    ubFactorElement_ = vocabSize_ < K_VALUE_MAX ? vocabSize_ : K_VALUE_MAX;
+    ubFactorElementAligned_ = CeilAlign(ubFactorElement_, dataPerBlock);
+    // Guard the modulo: ubFactorElement_ is 0 when vocabSize_ is 0.
+    tailUbFactorElement_ = ubFactorElement_ == uint32_t(0) ? uint32_t(0) : vocabSize_ % ubFactorElement_;
+    tailUbFactorElement_ = tailUbFactorElement_ == uint32_t(0) ? ubFactorElement_ : tailUbFactorElement_;
+    tailUbFactorElementAligned_ = CeilAlign(tailUbFactorElement_, dataPerBlock);
+
+    uint32_t sortedValueBytes = ubFactorElementAligned_ * inputDataTypeByte + K_VALUE_MAX * inputDataTypeByte;
+    uint32_t sortedIndicesBytes = ubFactorElementAligned_ * BYTES_B32 + K_VALUE_MAX * BYTES_B32;
+    uint32_t pBytes = dataPerBlock * inputDataTypeByte;
+    uint32_t kBytes = DATA_PER_BLOCK_B32 * BYTES_B32;
+    uint32_t outTensorBytes = ubFactorElementAligned_ * inputDataTypeByte;
+
+    calUbSize_ = calUbSize_ - sortedValueBytes - sortedIndicesBytes - pBytes - kBytes - outTensorBytes;
+    if (onlyTopP_ > 0) {
+        // The only-top-p path reports the full platform UB size instead of the remainder.
+        calUbSize_ = static_cast<uint32_t>(platformUbSize_);
+    }
+}
+
+// Copies the computed members into the tiling data structure.
+void ApplyTopKTopPCustomTiling::FillTilingData()
+{
+    tilingData.set_batchSize(batchSize_);
+    tilingData.set_vocabSize(vocabSize_);
+    tilingData.set_batchPerCore(batchPerCore_);
+    tilingData.set_tailBatch(tailBatch_);
+    tilingData.set_blockNum(usedCoreNum_);
+    tilingData.set_dataNumInit(dataNumInit_);
+    tilingData.set_dataNumInitAligned(dataNumInitAligned_);
+    tilingData.set_ubFactorElement(ubFactorElement_);
+    tilingData.set_ubFactorElementAligned(ubFactorElementAligned_);
+    tilingData.set_tailUbFactorElement(tailUbFactorElement_);
+    tilingData.set_tailUbFactorElementAligned(tailUbFactorElementAligned_);
+    tilingData.set_calUbSize(calUbSize_);
+    tilingData.set_iterateTimes(iterateTimes_);
+}
+
+// Logs every tiling field at debug level.
+void ApplyTopKTopPCustomTiling::PrintTilingData()
+{
+    OP_LOGD(opName_, "batchSize: %u.", tilingData.get_batchSize());
+    OP_LOGD(opName_, "vocabSize: %u.", tilingData.get_vocabSize());
+    OP_LOGD(opName_, "batchPerCore: %u.", tilingData.get_batchPerCore());
+    OP_LOGD(opName_, "tailBatch: %u.", tilingData.get_tailBatch());
+    OP_LOGD(opName_, "usedCoreNum: %u.", tilingData.get_blockNum());
+    OP_LOGD(opName_, "dataNumInit_: %u.", tilingData.get_dataNumInit());
+    OP_LOGD(opName_, "dataNumInitAligned_: %u.", tilingData.get_dataNumInitAligned());
+    OP_LOGD(opName_, "ubFactorElement: %u.", tilingData.get_ubFactorElement());
+    OP_LOGD(opName_, "ubFactorElementAligned: %u.", tilingData.get_ubFactorElementAligned());
+    OP_LOGD(opName_, "tailUbFactorElement: %u.", tilingData.get_tailUbFactorElement());
+    OP_LOGD(opName_, "tailUbFactorElementAligned: %u.", tilingData.get_tailUbFactorElementAligned());
+    OP_LOGD(opName_, "calUbSize: %u.", tilingData.get_calUbSize());
+    OP_LOGD(opName_, "iterateTimes: %u.", tilingData.get_iterateTimes());
+}
+
+// Computes all tiling values, then publishes tiling key, workspace size, tiling buffer and block dim.
+ge::graphStatus ApplyTopKTopPCustomTiling::RunKernelTiling()
+{
+    OP_LOGD(opName_, "TilingForApplyTopKTopPCustom start.");
+
+    SetTilingKey();
+    GetUsedCore();
+    CalDataPerCore();
+    FillTilingData();
+    PrintTilingData();
+
+    OP_LOGD(opName_, "tilingKey: %u.", tilingKey_);
+    size_t* currentWorkspace = tilingcontext->GetWorkspaceSizes(1);
+    OP_CHECK_NULL_WITH_CONTEXT(tilingcontext, currentWorkspace);
+    // Computed in size_t: batchSize_ * vocabSize_ * FLOAT_BYTES overflows 32 bits for large batch x vocab.
+    size_t workspaceSize = static_cast<size_t>(SYS_WORKSPACESIZE);
+    if (onlyTopP_ > 0) {
+        // The only-top-p path additionally needs one float per logit.
+        workspaceSize += static_cast<size_t>(batchSize_) * static_cast<size_t>(vocabSize_) * FLOAT_BYTES;
+    }
+    currentWorkspace[0] = workspaceSize;
+
+    tilingData.SaveToBuffer(tilingcontext->GetRawTilingData()->GetData(),
+        tilingcontext->GetRawTilingData()->GetCapacity());
+    tilingcontext->GetRawTilingData()->SetDataSize(tilingData.GetDataSize());
+    tilingcontext->SetBlockDim(usedCoreNum_);
+
+    OP_LOGD(opName_, "TilingForApplyTopKTopPCustom end.");
+    return ge::GRAPH_SUCCESS;
+}
+
+// Tiling entry point registered for the op.
+static ge::graphStatus TilingForApplyTopKTopPCustom(gert::TilingContext* context)
+{
+    ApplyTopKTopPCustomTiling tilingObject(context);
+    auto ret = tilingObject.Init();
+    if (ret != ge::GRAPH_SUCCESS) {
+        OP_LOGE(context->GetNodeName(), "tiling Init failed.");
+        return ge::GRAPH_FAILED;
+    }
+    ret = tilingObject.RunKernelTiling();
+    OP_LOGD(context->GetNodeName(), "TilingForApplyTopKTopPCustom end.");
+    return ret;
+}
+
+// Tiling-parse entry point: stores the vector core count and UB size in the compile info.
+static ge::graphStatus TilingPrepareForApplyTopKTopPCustom(gert::TilingParseContext* context)
+{
+    OP_LOGD(context->GetNodeName(), "TilingPrepareForApplyTopKTopPCustom start");
+    auto compileInfo = context->GetCompiledInfo<TilingForApplyTopKTopPCustomCompileInfo>();
+    OP_CHECK_NULL_WITH_CONTEXT(context, compileInfo);
+    auto platformInfo = context->GetPlatformInfo();
+    OP_CHECK_NULL_WITH_CONTEXT(context, platformInfo);
+    auto ascendcPlatform = platform_ascendc::PlatformAscendC(platformInfo);
+    compileInfo->totalCoreNum = ascendcPlatform.GetCoreNumAiv();
+    // Initialised so the check below sees 0, not garbage, if the query leaves it untouched.
+    uint64_t ubSizePlatForm = 0;
+    ascendcPlatform.GetCoreMemSize(platform_ascendc::CoreMemType::UB, ubSizePlatForm);
+    compileInfo->ubSizePlatForm = static_cast<uint64_t>(ubSizePlatForm);
+    OP_CHECK_IF(compileInfo->ubSizePlatForm <= 0,
+        OP_LOGE(context->GetNodeName(), "Failed to get ub size"),
+        return ge::GRAPH_FAILED);
+    OP_LOGD(context->GetNodeName(), "ub_size_platform is %lu", compileInfo->ubSizePlatForm);
+    uint64_t totalUbSize = 0;
+    platformInfo->GetLocalMemSize(fe::LocalMemType::UB, totalUbSize);
+    OP_LOGD(context->GetNodeName(), "total ub size is %lu", totalUbSize);
+    OP_LOGD(context->GetNodeName(), "TilingPrepareForApplyTopKTopPCustom end");
+    return ge::GRAPH_SUCCESS;
+}
+
+IMPL_OP_OPTILING(ApplyTopKTopPCustom)
+    .Tiling(TilingForApplyTopKTopPCustom)
+    .TilingParse<TilingForApplyTopKTopPCustomCompileInfo>(TilingPrepareForApplyTopKTopPCustom);
+} // namespace optiling
\ No newline at end of file
diff --git a/csrc/apply_top_k_top_p_custom/op_host/apply_top_k_top_p_custom_tiling.h b/csrc/apply_top_k_top_p_custom/op_host/apply_top_k_top_p_custom_tiling.h
new file mode 100644
index 00000000..baf9e9ba
--- /dev/null
+++ b/csrc/apply_top_k_top_p_custom/op_host/apply_top_k_top_p_custom_tiling.h
@@ -0,0 +1,51 @@
+/**
+ * Copyright (c) 2025 Huawei Technologies Co., Ltd.
+ * This program is free software, you can redistribute it and/or modify it under the terms and conditions of
+ * CANN Open Software License Agreement Version 2.0 (the "License").
+ * Please refer to the License for details. You may not use this file except in compliance with the License.
+ * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED,
+ * INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE.
+ * See LICENSE in the root of the software repository for the full text of the License.
+ */
+
+/*!
+ * \file apply_top_k_top_p_custom_tiling.h
+ * \brief
+ * ATTENTION: MAKE SURE 'BEGIN_TILING_DATA_DEF' STAY IN THE SAME LINE (28) USING BLANK LINES.
+ * + * + * + * + * + */ +#ifndef __APPLY_TOP_K_TOP_P_CUSTOM_TILINGDATA_H__ +#define __APPLY_TOP_K_TOP_P_CUSTOM_TILINGDATA_H__ + +#include "register/tilingdata_base.h" + +namespace optiling { + +BEGIN_TILING_DATA_DEF(ApplyTopKTopPCustomTilingData) + TILING_DATA_FIELD_DEF(uint32_t, batchSize); + TILING_DATA_FIELD_DEF(uint32_t, vocabSize); + TILING_DATA_FIELD_DEF(uint32_t, batchPerCore); + TILING_DATA_FIELD_DEF(uint32_t, tailBatch); + TILING_DATA_FIELD_DEF(uint32_t, blockNum); + TILING_DATA_FIELD_DEF(uint32_t, dataNumInit); + TILING_DATA_FIELD_DEF(uint32_t, dataNumInitAligned); + TILING_DATA_FIELD_DEF(uint32_t, ubFactorElement); + TILING_DATA_FIELD_DEF(uint32_t, ubFactorElementAligned); + TILING_DATA_FIELD_DEF(uint32_t, tailUbFactorElement); + TILING_DATA_FIELD_DEF(uint32_t, tailUbFactorElementAligned); + TILING_DATA_FIELD_DEF(uint32_t, calUbSize); + TILING_DATA_FIELD_DEF(uint32_t, iterateTimes); +END_TILING_DATA_DEF; +REGISTER_TILING_DATA_CLASS(ApplyTopKTopPCustom, ApplyTopKTopPCustomTilingData) + +struct TilingForApplyTopKTopPCustomCompileInfo { + uint32_t totalCoreNum = 0; + uint64_t ubSizePlatForm = 0; +}; + +} // namespace optiling +#endif // __APPLY_TOP_K_TOP_P_CUSTOM_TILINGDATA_H__ \ No newline at end of file diff --git a/csrc/apply_top_k_top_p_custom/op_host/error_log.h b/csrc/apply_top_k_top_p_custom/op_host/error_log.h new file mode 100644 index 00000000..6cbaee24 --- /dev/null +++ b/csrc/apply_top_k_top_p_custom/op_host/error_log.h @@ -0,0 +1,71 @@ +#ifndef OPS_BUILT_IN_OP_TILING_ERROR_LOG_H_ +#define OPS_BUILT_IN_OP_TILING_ERROR_LOG_H_ + +#include +#include "toolchain/slog.h" + +#define OP_LOGI(opname, ...) +#define OP_LOGW(opname, ...) \ + do { \ + printf("[WARN][%s] ", (opname), ##__VA_ARGS__); \ + printf("\n"); \ + } while (0) + +#define OP_LOGE_WITHOUT_REPORT(opname, ...) \ + do { \ + printf("[ERRORx][%s] ", (opname), ##__VA_ARGS__); \ + printf("\n"); \ + } while (0) + +#define OP_LOGE(opname, ...) 
\ + do { \ + printf("[ERROR][%s] ", (opname), ##__VA_ARGS__); \ + printf("\n"); \ + } while (0) + +#define OP_LOGD(opname, ...) + +namespace optiling { + +#define VECTOR_INNER_ERR_REPORT_TILIING(op_name, err_msg, ...) \ + do { \ + OP_LOGE_WITHOUT_REPORT(op_name, err_msg, ##__VA_ARGS__); \ + } while (0) + + +#define OP_CHECK_IF(cond, log_func, expr) \ + do { \ + if (cond) { \ + log_func; \ + expr; \ + } \ + } while (0) + + + +#define OP_CHECK_NULL_WITH_CONTEXT(context, ptr) \ + do { \ + if ((ptr) == nullptr) { \ + OP_LOGE(context->GetNodeType(), "%s is null", #ptr); \ + return ge::GRAPH_FAILED; \ + } \ + } while (0) + +} // namespace optiling + +template +T CeilAlign(T a, T b) +{ + return (a + b - 1) / b * b; +} + +template +T CeilDiv(T a, T b) +{ + if (b == 0) { + return a; + } + return (a + b - 1) / b; +} + +#endif // OPS_BUILT_IN_OP_TILING_ERROR_LOG_H_ \ No newline at end of file diff --git a/csrc/apply_top_k_top_p_custom/op_host/sort.h b/csrc/apply_top_k_top_p_custom/op_host/sort.h new file mode 100644 index 00000000..6ca937fc --- /dev/null +++ b/csrc/apply_top_k_top_p_custom/op_host/sort.h @@ -0,0 +1,26 @@ +/** + * Copyright (c) 2025 Huawei Technologies Co., Ltd. + * This program is free software, you can redistribute it and/or modify it under the terms and conditions of + * CANN Open Software License Agreement Version 2.0 (the "License"). + * Please refer to the License for details. You may not use this file except in compliance with the License. + * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, + * INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. + * See LICENSE in the root of the software repository for the full text of the License. + */ + +/*! 
+ * \file sort.h + * \brief + */ +#ifndef PTA_NPU_OP_API_INC_LEVEL0_OP_SORT_OP_H_ +#define PTA_NPU_OP_API_INC_LEVEL0_OP_SORT_OP_H_ + +#include "opdev/op_executor.h" +#include "opdev/fast_vector.h" + +namespace l0op { +const std::tuple Sort(const aclTensor* self, int64_t dim, bool descending, bool stable, + op::DataType indicesType, aclOpExecutor* executor); +} + +#endif // PTA_NPU_OP_API_INC_LEVEL0_OP_SORT_OP_H_ diff --git a/csrc/apply_top_k_top_p_custom/op_kernel/apply_top_k_top_p_custom.cpp b/csrc/apply_top_k_top_p_custom/op_kernel/apply_top_k_top_p_custom.cpp new file mode 100644 index 00000000..a98f6cec --- /dev/null +++ b/csrc/apply_top_k_top_p_custom/op_kernel/apply_top_k_top_p_custom.cpp @@ -0,0 +1,42 @@ +/** + * Copyright (c) 2025 Huawei Technologies Co., Ltd. + * This program is free software, you can redistribute it and/or modify it under the terms and conditions of + * CANN Open Software License Agreement Version 2.0 (the "License"). + * Please refer to the License for details. You may not use this file except in compliance with the License. + * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, + * INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. + * See LICENSE in the root of the software repository for the full text of the License. + */ + +/*! 
+ * \file apply_top_k_top_p_custom.cpp + * \brief + */ + +#include "apply_top_k_top_p_custom.h" +#include "apply_top_p_custom.h" +using namespace AscendC; +using namespace ApplyTopKTopPCustomOp; +using namespace ApplyTopPCustomOp; + +extern "C" __global__ __aicore__ void apply_top_k_top_p_custom(GM_ADDR sorted_value, GM_ADDR sorted_indices, + GM_ADDR p, GM_ADDR k, GM_ADDR out, GM_ADDR workSpace, GM_ADDR tiling) { + TPipe pipe; + GET_TILING_DATA(tilingData, tiling); + if (TILING_KEY_IS(0)) { + ApplyTopKTopPCustomOp::ApplyTopKTopPCustom op; + op.InitTilingData(tilingData, sorted_value, sorted_indices, p, k, out); + op.InitBuffer(&pipe); + op.Process(); + } else if (TILING_KEY_IS(1)) { + ApplyTopKTopPCustomOp::ApplyTopKTopPCustom op; + op.InitTilingData(tilingData, sorted_value, sorted_indices, p, k, out); + op.InitBuffer(&pipe); + op.ProcessTopK(); + } else if (TILING_KEY_IS(2)) { + ApplyTopPCustomOp::ApplyTopPCustom op; + op.InitTilingData(tilingData, sorted_value, sorted_indices, p, k, out, workSpace); + op.InitBuffer(&pipe); + op.ProcessTopP(); + } +} \ No newline at end of file diff --git a/csrc/apply_top_k_top_p_custom/op_kernel/apply_top_k_top_p_custom.h b/csrc/apply_top_k_top_p_custom/op_kernel/apply_top_k_top_p_custom.h new file mode 100644 index 00000000..e911399e --- /dev/null +++ b/csrc/apply_top_k_top_p_custom/op_kernel/apply_top_k_top_p_custom.h @@ -0,0 +1,719 @@ +/** + * Copyright (c) 2025 Huawei Technologies Co., Ltd. + * This program is free software, you can redistribute it and/or modify it under the terms and conditions of + * CANN Open Software License Agreement Version 2.0 (the "License"). + * Please refer to the License for details. You may not use this file except in compliance with the License. + * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, + * INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. 
+ * See LICENSE in the root of the software repository for the full text of the License. + */ + +/*! + * \file apply_top_k_top_p_custom.h + * \brief + */ +#ifndef APPLY_TOP_K_TOP_P_CUSTOM_H_KERNEL +#define APPLY_TOP_K_TOP_P_CUSTOM_H_KERNEL + +#include "kernel_operator.h" + +using namespace AscendC; +namespace ApplyTopKTopPCustomOp { +constexpr uint32_t BUFFER_NUM = 1; +constexpr uint16_t FLOAT16_NEG_INF = 0xFC00; // -inf 64512 +constexpr uint16_t BF16_NEG_INF = 0xFF80; // -inf 65408 +constexpr int32_t FLOAT32_NEG_INF = 0xFF800000; // -inf -2139095040 + +constexpr uint32_t BLOCK_BYTES = 32; +constexpr uint32_t DATA_PER_BLOCK_B32 = 8; +constexpr uint32_t DATA_PER_REPEAT_B32 = 64; +constexpr uint32_t K_MAX = 1024; +constexpr uint64_t MASK_64 = 64; +constexpr CumSumConfig CUMSUM_CONFIG{true, true, false}; + +template +class ApplyTopKTopPCustom { +public: + __aicore__ inline ApplyTopKTopPCustom(){}; + __aicore__ inline void InitTilingData( + const ApplyTopKTopPCustomTilingData &__restrict tilingData, GM_ADDR sorted_value, GM_ADDR sorted_indices, + GM_ADDR p, GM_ADDR k, GM_ADDR out); + __aicore__ inline void InitBuffer(TPipe *inputPipe); + __aicore__ inline void Process(); + __aicore__ inline void ProcessTopK(); +private: + __aicore__ inline void InitCopyIn(uint32_t loopBatch, int64_t currentGmIdx); + __aicore__ inline void InitProcess(uint32_t loopBatch); + __aicore__ inline void ProcessKLtKMax(uint32_t loopBatch); + __aicore__ inline void ScatterCumtomImpl(uint32_t loopBatch, uint32_t loopProbNum, uint32_t offset); + __aicore__ inline void ProcessRemain(uint32_t loopBatch); + __aicore__ inline void GetKthResult(uint32_t loopBatch, uint32_t offset, uint8_t repeatTimes); + __aicore__ inline void GetFirstKLoop(uint32_t loopBatch, int32_t &firstKLoop); + __aicore__ inline void ScatterFromFirstKLoop(uint32_t loopBatch, int32_t firstKLoop, float &cumsumData); + __aicore__ inline void ReduceSumWithAddsAndExpImpl(uint32_t offset, uint32_t loopDataNum); + __aicore__ inline void 
CumSumWithAddsAndExpImpl( + uint32_t offset, uint32_t loopDataNum, uint32_t cumsumInner, float cumsumData); + // topk func + __aicore__ inline void InitProcessTopK(uint32_t loopBatch); + __aicore__ inline void ProcessKLtKMaxTopK(uint32_t loopBatch); + __aicore__ inline void ProcessRemainTopK(uint32_t loopBatch); + __aicore__ inline void GetFirstKLoopTopK(uint32_t loopBatch, int32_t &firstKLoop); + __aicore__ inline void ScatterFromFirstKLoopTopK(uint32_t loopBatch, int32_t firstKLoop); + __aicore__ inline void ScatterCumtomImplTopK(uint32_t loopBatch, uint32_t loopProbNum, uint32_t offset); + __aicore__ inline void SToMTE3Sync() { + event_t eventIDSToMTE3 = static_cast(GetTPipePtr()->FetchEventID(HardEvent::S_MTE3)); + SetFlag(eventIDSToMTE3); + WaitFlag(eventIDSToMTE3); + } + __aicore__ inline void MTE3ToSSync() { + event_t eventIDMTE3ToS = static_cast(GetTPipePtr()->FetchEventID(HardEvent::MTE3_S)); + SetFlag(eventIDMTE3ToS); + WaitFlag(eventIDMTE3ToS); + } + __aicore__ inline void VToSSync() { + event_t eventIdVToS = static_cast(GetTPipePtr()->FetchEventID(HardEvent::V_S)); + SetFlag(eventIdVToS); + WaitFlag(eventIdVToS); + } + __aicore__ inline void MTE2ToVSync() { + event_t eventIdMte2ToV = static_cast(GetTPipePtr()->FetchEventID(HardEvent::MTE2_V)); + SetFlag(eventIdMte2ToV); + WaitFlag(eventIdMte2ToV); + } + __aicore__ inline void MTE2ToSSync() { + event_t eventIdMte2ToS = static_cast(GetTPipePtr()->FetchEventID(HardEvent::MTE2_S)); + SetFlag(eventIdMte2ToS); + WaitFlag(eventIdMte2ToS); + } + __aicore__ inline void MTE3ToVSync() { + event_t eventIdMte3ToV = static_cast(GetTPipePtr()->FetchEventID(HardEvent::MTE3_V)); + SetFlag(eventIdMte3ToV); + WaitFlag(eventIdMte3ToV); + } + __aicore__ inline void SToMTE2Sync() + { + event_t eventIDSToMTE2 = static_cast(GetTPipePtr()->FetchEventID(HardEvent::S_MTE2)); + SetFlag(eventIDSToMTE2); + WaitFlag(eventIDSToMTE2); + } + __aicore__ inline void VToMTE3Sync() { + event_t eventIDVToMTE3 = 
static_cast(GetTPipePtr()->FetchEventID(HardEvent::V_MTE3)); + SetFlag(eventIDVToMTE3); + WaitFlag(eventIDVToMTE3); + } +private: + TPipe *pipe_; + // create queues for input, in this case depth is equal to buffer num + TQue sortedValueInQueue_; + TQue sortedIndicesInQueue_; + TQue pInQueue_; + TQue kInQueue_; + TQue outQueue_; + TBuf calBuf_; + + // tilingData + uint32_t batchSize_ = 0; + uint32_t vocabSize_ = 0; + uint32_t batchPerCore_ = 0; + uint32_t tailBatch_ = 0; + uint32_t blockNum_ = 0; + uint32_t dataNumInit_ = 0; + uint32_t dataNumInitAligned_ = 0; + uint32_t ubFactorElement_ = 0; + uint32_t ubFactorElementAligned_ = 0; + uint32_t tailUbFactorElement_ = 0; + uint32_t tailUbFactorElementAligned_ = 0; + uint32_t calUbSize_ = 0; + + uint32_t blockIdx_ = 0; + uint32_t loopBatch_ = 0; + uint32_t batchOffset_ = 0; + uint32_t bufOffsetLoop = 0; + uint32_t loopInner_ = 0; + uint32_t loopInnerOnlyP_ = 0; + int64_t baseGmIdx_ = 0; + + GlobalTensor mGmSortedValue_; + GlobalTensor mGmSortedIndices_; + GlobalTensor mGmP_; + GlobalTensor mGmK_; + GlobalTensor mGmOut_; + + LocalTensor kLocal; + LocalTensor pLocal; + LocalTensor outTensor; + LocalTensor sortedValueLocal; + LocalTensor sortedIndicesLocal; + + LocalTensor sortedValueLocalFp32; + LocalTensor negInfLocal; + + LocalTensor calLocalFp32; + LocalTensor kthValueLocal; + LocalTensor tmpLocal; + LocalTensor cumSumRes; + LocalTensor cumSumTmp; + LocalTensor reduceLocal; + LocalTensor softMaxRes; + LocalTensor scatterTensor; + LocalTensor sharedTmpBuffer; + + float kthValue = 0; + float pValue = 0; + float maxValue = 0; + float reduceSumValueInvert = 0; + float reduceSumValue = 0; + inputT kthTopKValue = 0; + BinaryRepeatParams repeatParams = {1, 0, 1, 8, 0, 8}; + DataCopyExtParams scatterCopyParams{1, (uint32_t)(sizeof(outputT)), 0, 0, 0}; +}; + +template +__aicore__ inline void ApplyTopKTopPCustom::InitTilingData( + const ApplyTopKTopPCustomTilingData &__restrict tilingData, GM_ADDR sorted_value, GM_ADDR 
sorted_indices, + GM_ADDR p, GM_ADDR k, GM_ADDR out) { + batchSize_ = tilingData.batchSize; + vocabSize_ = tilingData.vocabSize; + batchPerCore_ = tilingData.batchPerCore; + tailBatch_ = tilingData.tailBatch; + blockNum_ = tilingData.blockNum; + dataNumInit_ = tilingData.dataNumInit; + dataNumInitAligned_ = AscendC::AlignUp(dataNumInit_, DATA_PER_BLOCK_B32); + ubFactorElement_ = tilingData.ubFactorElement; + ubFactorElementAligned_ = tilingData.ubFactorElementAligned; + tailUbFactorElement_ = tilingData.tailUbFactorElement; + tailUbFactorElementAligned_ = tilingData.tailUbFactorElementAligned; + calUbSize_ = tilingData.calUbSize; + blockIdx_ = GetBlockIdx(); + + if (blockIdx_ < tailBatch_) + { + loopBatch_ = batchPerCore_ + 1; + batchOffset_ = blockIdx_ * loopBatch_; + } + else + { + loopBatch_ = batchPerCore_; + batchOffset_ = blockIdx_ * batchPerCore_ + tailBatch_; + } + loopInner_ = (vocabSize_ - dataNumInit_ + ubFactorElementAligned_ - 1) / ubFactorElementAligned_; + loopInnerOnlyP_ = (vocabSize_ + ubFactorElementAligned_ - 1) / ubFactorElementAligned_; + mGmSortedValue_.SetGlobalBuffer(reinterpret_cast<__gm__ inputT *>(sorted_value)); + mGmSortedIndices_.SetGlobalBuffer(reinterpret_cast<__gm__ int32_t *>(sorted_indices)); + mGmP_.SetGlobalBuffer(reinterpret_cast<__gm__ inputT *>(p)); + mGmK_.SetGlobalBuffer(reinterpret_cast<__gm__ int32_t *>(k)); + mGmOut_.SetGlobalBuffer(reinterpret_cast<__gm__ outputT *>(out)); +} + +// init used buffer +template +__aicore__ inline void ApplyTopKTopPCustom::InitBuffer(TPipe *inputPipe) { + pipe_ = inputPipe; + pipe_->InitBuffer(sortedValueInQueue_, BUFFER_NUM, sizeof(inputT) * (ubFactorElementAligned_ + K_MAX)); + pipe_->InitBuffer(sortedIndicesInQueue_, BUFFER_NUM, sizeof(int32_t) * (ubFactorElementAligned_ + K_MAX)); + pipe_->InitBuffer(pInQueue_, BUFFER_NUM, BLOCK_BYTES); + pipe_->InitBuffer(kInQueue_, BUFFER_NUM, BLOCK_BYTES); + pipe_->InitBuffer(outQueue_, BUFFER_NUM, sizeof(outputT) * ubFactorElementAligned_); + 
pipe_->InitBuffer(calBuf_, calUbSize_); + if constexpr (!IsSameType::value) { + sortedValueLocalFp32 = calBuf_.GetWithOffset(ubFactorElementAligned_ + K_MAX, bufOffsetLoop); + bufOffsetLoop = bufOffsetLoop + (ubFactorElementAligned_ + K_MAX) * sizeof(float); + } + kthValueLocal = calBuf_.GetWithOffset(DATA_PER_BLOCK_B32, bufOffsetLoop); + bufOffsetLoop = bufOffsetLoop + BLOCK_BYTES; + + negInfLocal = calBuf_.GetWithOffset(DATA_PER_BLOCK_B32, bufOffsetLoop); + bufOffsetLoop = bufOffsetLoop + BLOCK_BYTES; + + tmpLocal = calBuf_.GetWithOffset(ubFactorElementAligned_, bufOffsetLoop); + bufOffsetLoop = bufOffsetLoop + ubFactorElementAligned_ * sizeof(float); + cumSumRes = calBuf_.GetWithOffset(ubFactorElementAligned_, bufOffsetLoop); + bufOffsetLoop = bufOffsetLoop + ubFactorElementAligned_ * sizeof(float); + cumSumTmp = calBuf_.GetWithOffset(ubFactorElementAligned_, bufOffsetLoop); + bufOffsetLoop = bufOffsetLoop + ubFactorElementAligned_ * sizeof(float); + reduceLocal = calBuf_.GetWithOffset(ubFactorElementAligned_ * BLOCK_BYTES, bufOffsetLoop); + bufOffsetLoop = bufOffsetLoop + ubFactorElementAligned_ * BLOCK_BYTES * sizeof(float); + + softMaxRes = tmpLocal.template ReinterpretCast(); + scatterTensor = reduceLocal.template ReinterpretCast(); + sharedTmpBuffer = reduceLocal.template ReinterpretCast(); +} + +template +__aicore__ inline void ApplyTopKTopPCustom::Process() { + kLocal = kInQueue_.AllocTensor(); + pLocal = pInQueue_.AllocTensor(); + outTensor = outQueue_.AllocTensor(); + sortedValueLocal = sortedValueInQueue_.AllocTensor(); + sortedIndicesLocal = sortedIndicesInQueue_.AllocTensor(); + Duplicate(negInfLocal.template ReinterpretCast(), FLOAT32_NEG_INF, DATA_PER_BLOCK_B32); + if constexpr (IsSameType::value) { + calLocalFp32 = sortedValueLocal; + Duplicate(outTensor.template ReinterpretCast(), FLOAT32_NEG_INF, ubFactorElementAligned_); + } else if constexpr (IsSameType::value) { + calLocalFp32 = sortedValueLocalFp32; + Duplicate(outTensor.template 
ReinterpretCast(), FLOAT16_NEG_INF, ubFactorElementAligned_); + } else { + calLocalFp32 = sortedValueLocalFp32; + Duplicate(outTensor.template ReinterpretCast(), BF16_NEG_INF, ubFactorElementAligned_); + } + VToMTE3Sync(); + for (uint32_t loopBatch = 0; loopBatch < loopBatch_; loopBatch++) { + baseGmIdx_ = batchOffset_ * vocabSize_ + loopBatch * vocabSize_; + InitProcess(loopBatch); + if (calLocalFp32.GetValue(ubFactorElementAligned_) < kthValue) { + ProcessKLtKMax(loopBatch); + } else { + ProcessRemain(loopBatch); + } + } + kInQueue_.FreeTensor(kLocal); + pInQueue_.FreeTensor(pLocal); + sortedValueInQueue_.FreeTensor(sortedValueLocal); + sortedIndicesInQueue_.FreeTensor(sortedIndicesLocal); + outQueue_.FreeTensor(outTensor); +} + +template +__aicore__ inline void ApplyTopKTopPCustom::InitCopyIn(uint32_t loopBatch, + int64_t currentGmIdx) { + DataCopyPad(mGmOut_[currentGmIdx], outTensor, {1, (uint32_t)(dataNumInit_ * sizeof(outputT)), 0, 0, 0}); + DataCopyPad(sortedValueLocal[ubFactorElementAligned_], mGmSortedValue_[currentGmIdx], + {1, static_cast(dataNumInit_ * sizeof(inputT)), 0, 0, 0}, + {false, 0, 0, 0}); + DataCopyPad(pLocal, mGmP_[batchOffset_ + loopBatch], {1, static_cast(sizeof(inputT)), 0, 0, 0}, + {false, 0, 0, 0}); + if constexpr (!IsSameType::value) { + MTE2ToVSync(); + Cast(sortedValueLocalFp32[ubFactorElementAligned_], sortedValueLocal[ubFactorElementAligned_], + RoundMode::CAST_NONE, dataNumInit_); + Cast(tmpLocal, pLocal, RoundMode::CAST_NONE, DATA_PER_BLOCK_B32); + } + DataCopyPad(sortedIndicesLocal[ubFactorElementAligned_], mGmSortedIndices_[currentGmIdx], + {1, static_cast(dataNumInit_ * sizeof(int32_t)), 0, 0, 0}, + {false, 0, 0, 0}); + DataCopyPad(kLocal, mGmK_[batchOffset_ + loopBatch], {1, static_cast(sizeof(int32_t)), 0, 0, 0}, + {false, 0, 0, 0}); +} + +template +__aicore__ inline void ApplyTopKTopPCustom::GetKthResult(uint32_t loopBatch, + uint32_t offset, uint8_t repeatTimes){ + Compare(tmpLocal.template ReinterpretCast(), 
kthValueLocal, calLocalFp32[offset], + CMPMODE::GT, MASK_64, repeatTimes, repeatParams); + PipeBarrier(); + Select(calLocalFp32[offset], tmpLocal.template ReinterpretCast(), + negInfLocal, calLocalFp32[offset], SELMODE::VSEL_TENSOR_TENSOR_MODE, MASK_64, + repeatTimes, repeatParams); +} + +template +__aicore__ inline void ApplyTopKTopPCustom::ReduceSumWithAddsAndExpImpl(uint32_t offset, + uint32_t loopDataNum) { + Adds(softMaxRes, calLocalFp32[offset], maxValue, loopDataNum); + PipeBarrier(); + Exp(softMaxRes, softMaxRes, loopDataNum); + PipeBarrier(); + ReduceSum(reduceLocal, softMaxRes, reduceLocal, loopDataNum); +} + +template +__aicore__ inline void ApplyTopKTopPCustom::InitProcess(uint32_t loopBatch) { + int64_t initGmIdx = baseGmIdx_ + vocabSize_ - dataNumInit_; + InitCopyIn(loopBatch, initGmIdx); + MTE2ToSSync(); + + int32_t kValue = kLocal.GetValue(0); + if constexpr (IsSameType::value) { + pValue = float(1.0) - pLocal.GetValue(0); + } else { + pValue = float(1.0) - tmpLocal.GetValue(0); + } + maxValue = -calLocalFp32[ubFactorElementAligned_].GetValue(dataNumInit_ - 1); + if constexpr (IsSameType::value) { + kthValue = mGmSortedValue_[baseGmIdx_ + vocabSize_ - kValue].GetValue(0); + } else if constexpr (IsSameType::value) { + kthValue = static_cast(mGmSortedValue_[baseGmIdx_ + vocabSize_ - kValue].GetValue(0)); + } else { + kthValue = ToFloat(mGmSortedValue_[baseGmIdx_ + vocabSize_ - kValue].GetValue(0)); + } + + Duplicate(kthValueLocal, kthValue, 8); + PipeBarrier(); + + uint8_t repeatTimes = (dataNumInit_ + DATA_PER_REPEAT_B32 - 1) / DATA_PER_REPEAT_B32; + GetKthResult(loopBatch, ubFactorElementAligned_, repeatTimes); + PipeBarrier(); + DataCopyExtParams copyParams{1, (uint32_t)(ubFactorElementAligned_ * sizeof(outputT)), 0, 0, 0}; + + ReduceSumWithAddsAndExpImpl(ubFactorElementAligned_, dataNumInit_); + VToSSync(); + reduceSumValue = reduceLocal.GetValue(0); + reduceSumValueInvert = 1 / reduceSumValue; +} + +template +__aicore__ inline void 
ApplyTopKTopPCustom::ProcessKLtKMax(uint32_t loopBatch) { + DataCopyExtParams copyParams{1, (uint32_t)(ubFactorElementAligned_ * sizeof(outputT)), 0, 0, 0}; + for (int32_t loopInner = 0; loopInner < loopInner_; loopInner++) { + int64_t currentGmIdxInner = baseGmIdx_ + loopInner * ubFactorElementAligned_; + if (loopInner == loopInner_ - 1) { + DataCopyPad(mGmOut_[currentGmIdxInner], outTensor, + {1, (uint32_t)(tailUbFactorElement_ * sizeof(outputT)), 0, 0, 0}); + } else { + DataCopyPad(mGmOut_[currentGmIdxInner], outTensor, copyParams); + } + } + Muls(softMaxRes, softMaxRes, reduceSumValueInvert, dataNumInit_); + PipeBarrier(); + const CumSumInfo cumSumInfo{1, dataNumInitAligned_}; + CumSum(cumSumRes, cumSumTmp, softMaxRes, sharedTmpBuffer, cumSumInfo); + VToSSync(); + int32_t loopProb = dataNumInit_ - 1; + scatterTensor.SetValue(0, sortedValueLocal[ubFactorElementAligned_].GetValue(loopProb)); + SToMTE3Sync(); + int32_t gmIndex = sortedIndicesLocal[ubFactorElementAligned_].GetValue(loopProb); + PipeBarrier(); + DataCopyPad(mGmOut_[baseGmIdx_ + gmIndex], scatterTensor.template ReinterpretCast(), scatterCopyParams); + MTE3ToSSync(); + loopProb = loopProb - 1; + for (; loopProb >= 0; loopProb--) { + float cumsumData = cumSumRes.GetValue(loopProb); + if (cumsumData <= pValue) { + break; + } + scatterTensor.SetValue(0, sortedValueLocal[ubFactorElementAligned_].GetValue(loopProb)); + gmIndex = sortedIndicesLocal[ubFactorElementAligned_].GetValue(loopProb); + SToMTE3Sync(); + DataCopyPad(mGmOut_[baseGmIdx_ + gmIndex], + scatterTensor.template ReinterpretCast(), scatterCopyParams); + MTE3ToSSync(); + } +} + +template +__aicore__ inline void ApplyTopKTopPCustom::ScatterCumtomImpl(uint32_t loopBatch, + uint32_t loopProbNum, uint32_t offset) { + for (int32_t loopProb = 0; loopProb < static_cast(loopProbNum); loopProb++) { + float cumsumDataTmp = cumSumRes.GetValue(loopProb); + if (cumsumDataTmp <= pValue) { + continue; + } + scatterTensor.SetValue(0, 
sortedValueLocal[offset].GetValue(loopProb)); + int32_t gmIndex = sortedIndicesLocal[offset].GetValue(loopProb); + SToMTE3Sync(); + DataCopyPad(mGmOut_[baseGmIdx_ + gmIndex], scatterTensor.template ReinterpretCast(), + {1, (uint32_t)(1 * sizeof(outputT)), 0, 0, 0}); + MTE3ToSSync(); + } +} + +template +__aicore__ inline void ApplyTopKTopPCustom::GetFirstKLoop(uint32_t loopBatch, + int32_t &firstKLoop) { + uint8_t repeatTimes = (dataNumInit_ + DATA_PER_REPEAT_B32 - 1) / DATA_PER_REPEAT_B32; + uint32_t loopDataNum = ubFactorElementAligned_; + for (int32_t loopInner = 0; loopInner < loopInner_; loopInner++) { + int64_t currentGmIdx = baseGmIdx_ + loopInner * ubFactorElementAligned_; + if (loopInner == (loopInner_ - 1)) { + repeatTimes = ((tailUbFactorElement_) + DATA_PER_REPEAT_B32 - 1) / DATA_PER_REPEAT_B32; + loopDataNum = tailUbFactorElement_; + } + DataCopyPad(mGmOut_[currentGmIdx], outTensor, {1, (uint32_t)(loopDataNum * sizeof(outputT)), 0, 0, 0}); + DataCopyPad(sortedValueLocal.template ReinterpretCast(), mGmSortedValue_[currentGmIdx], + {1, static_cast(loopDataNum * sizeof(inputT)), 0, 0, 0}, + {false, 0, 0, 0}); + if constexpr (!IsSameType::value) { + MTE2ToVSync(); + Cast(sortedValueLocalFp32, sortedValueLocal, RoundMode::CAST_NONE, loopDataNum); + VToSSync(); + } else { + MTE2ToSSync(); + } + if (calLocalFp32.GetValue(loopDataNum - 1) < kthValue) { + firstKLoop += 1; + continue; + } + + GetKthResult(loopBatch, 0, repeatTimes); + PipeBarrier(); + + ReduceSumWithAddsAndExpImpl(0, loopDataNum); + VToSSync(); + reduceSumValue += reduceLocal.GetValue(0); + } +} + +template +__aicore__ inline void ApplyTopKTopPCustom::CumSumWithAddsAndExpImpl(uint32_t offset, + uint32_t loopDataNum, uint32_t cumsumInner, float cumsumData) { + Adds(softMaxRes, calLocalFp32[offset], maxValue, loopDataNum); + PipeBarrier(); + Exp(softMaxRes, softMaxRes, loopDataNum); + PipeBarrier(); + Muls(softMaxRes, softMaxRes, reduceSumValueInvert, loopDataNum); + PipeBarrier(); + const 
CumSumInfo cumSumInfo{1, cumsumInner}; + CumSum(cumSumRes, cumSumTmp, softMaxRes, sharedTmpBuffer, cumSumInfo); + PipeBarrier(); + Adds(cumSumRes, cumSumRes, cumsumData, loopDataNum); +} + +template +__aicore__ inline void ApplyTopKTopPCustom::ProcessRemain(uint32_t loopBatch) { + int32_t firstKLoop = 0; + GetFirstKLoop(loopBatch, firstKLoop); + reduceSumValueInvert = 1 / reduceSumValue; + float cumsumData = 0; + ScatterFromFirstKLoop(loopBatch, firstKLoop, cumsumData); + uint32_t loopProb = dataNumInit_ - 1; + scatterTensor.SetValue(0, sortedValueLocal[ubFactorElementAligned_].GetValue(loopProb)); + int32_t gmIndex = sortedIndicesLocal[ubFactorElementAligned_].GetValue(loopProb); + SToMTE3Sync(); + DataCopyPad(mGmOut_[baseGmIdx_ + gmIndex], + scatterTensor.template ReinterpretCast(), scatterCopyParams); + MTE3ToVSync(); + CumSumWithAddsAndExpImpl(ubFactorElementAligned_, dataNumInit_, dataNumInitAligned_, cumsumData); + VToSSync(); + ScatterCumtomImpl(loopBatch, dataNumInit_ - 1, ubFactorElementAligned_); +} + +template +__aicore__ inline void ApplyTopKTopPCustom::ScatterFromFirstKLoop(uint32_t loopBatch, + int32_t firstKLoop, float &cumsumData) { + uint32_t loopDataNum = ubFactorElementAligned_; + uint32_t cumsumInner = ubFactorElementAligned_; + uint8_t repeatTimes = ((ubFactorElementAligned_) + DATA_PER_REPEAT_B32 - 1) / DATA_PER_REPEAT_B32; + for (int32_t loopInner = firstKLoop; loopInner < loopInner_; loopInner++) { + int64_t currentGmIdx = baseGmIdx_ + loopInner * ubFactorElementAligned_; + if (loopInner == (loopInner_ - 1)) { + repeatTimes = (tailUbFactorElement_ + DATA_PER_REPEAT_B32 - 1) / DATA_PER_REPEAT_B32; + loopDataNum = tailUbFactorElement_; + cumsumInner = tailUbFactorElementAligned_; + } + DataCopyPad(sortedValueLocal.template ReinterpretCast(), mGmSortedValue_[currentGmIdx], + {1, static_cast(loopDataNum * sizeof(inputT)), 0, 0, 0}, + {false, 0, 0, 0}); + DataCopyPad(sortedIndicesLocal, mGmSortedIndices_[currentGmIdx], + {1, 
static_cast(loopDataNum * sizeof(int32_t)), 0, 0, 0}, + {false, 0, 0, 0}); + if constexpr (!IsSameType::value) { + MTE2ToVSync(); + Cast(sortedValueLocalFp32, sortedValueLocal, RoundMode::CAST_NONE, loopDataNum); + PipeBarrier(); + } else { + MTE2ToVSync(); + } + GetKthResult(loopBatch, 0, repeatTimes); + PipeBarrier(); + CumSumWithAddsAndExpImpl(0, loopDataNum, cumsumInner, cumsumData); + VToSSync(); + float cumsumDataTmp = cumSumRes.GetValue(loopDataNum - 1); + cumsumData = cumsumDataTmp; + if (cumsumDataTmp <= pValue) { + continue; + } + ScatterCumtomImpl(loopBatch, loopDataNum, 0); + } +} + +template +__aicore__ inline void ApplyTopKTopPCustom::ProcessTopK() { + kLocal = kInQueue_.AllocTensor(); + outTensor = outQueue_.AllocTensor(); + sortedValueLocal = sortedValueInQueue_.AllocTensor(); + sortedIndicesLocal = sortedIndicesInQueue_.AllocTensor(); + Duplicate(negInfLocal.template ReinterpretCast(), FLOAT32_NEG_INF, DATA_PER_BLOCK_B32); + if constexpr (IsSameType::value) { + calLocalFp32 = sortedValueLocal; + Duplicate(outTensor.template ReinterpretCast(), FLOAT32_NEG_INF, ubFactorElementAligned_); + } else if constexpr (IsSameType::value) { + calLocalFp32 = sortedValueLocalFp32; + Duplicate(outTensor.template ReinterpretCast(), FLOAT16_NEG_INF, ubFactorElementAligned_); + } else { + calLocalFp32 = sortedValueLocalFp32; + Duplicate(outTensor.template ReinterpretCast(), BF16_NEG_INF, ubFactorElementAligned_); + } + VToMTE3Sync(); + for (uint32_t loopBatch = 0; loopBatch < loopBatch_; loopBatch++) { + baseGmIdx_ = batchOffset_ * vocabSize_ + loopBatch * vocabSize_; + InitProcessTopK(loopBatch); + /* The difference lies in that for the max branch, some data is less than the kthvalue, + so part of the data can be filtered out in advance; + while for the remain branch, all data must undergo the topk calculation.*/ + if (calLocalFp32.GetValue(ubFactorElementAligned_) < kthValue) { + ProcessKLtKMaxTopK(loopBatch); + } else { + ProcessRemainTopK(loopBatch); + } + } + 
kInQueue_.FreeTensor(kLocal); + sortedValueInQueue_.FreeTensor(sortedValueLocal); + sortedIndicesInQueue_.FreeTensor(sortedIndicesLocal); + outQueue_.FreeTensor(outTensor); +} + +template +__aicore__ inline void ApplyTopKTopPCustom::ProcessRemainTopK(uint32_t loopBatch) { + int32_t firstKLoop = 0; + GetFirstKLoopTopK(loopBatch, firstKLoop); + // Start the scatter calculation from the first loop in the row where the value is ≥ kthValue. + ScatterFromFirstKLoopTopK(loopBatch, firstKLoop); + /* Perform scatter calculation on the maximum number of ubFactorElementAligned_, + which does not overlap with the previous ones.*/ + uint32_t loopProb = dataNumInit_ - 1; + scatterTensor.SetValue(0, sortedValueLocal[ubFactorElementAligned_].GetValue(loopProb)); + SToMTE3Sync(); + int32_t gmIndex = sortedIndicesLocal[ubFactorElementAligned_].GetValue(loopProb); + DataCopyPad(mGmOut_[baseGmIdx_ + gmIndex], + scatterTensor.template ReinterpretCast(), scatterCopyParams); + MTE3ToSSync(); + ScatterCumtomImplTopK(loopBatch, dataNumInit_ - 1, ubFactorElementAligned_); +} + +template +__aicore__ inline void ApplyTopKTopPCustom::GetFirstKLoopTopK(uint32_t loopBatch, + int32_t &firstKLoop) { + uint8_t repeatTimes = (dataNumInit_ + DATA_PER_REPEAT_B32 - 1) / DATA_PER_REPEAT_B32; + uint32_t loopDataNum = ubFactorElementAligned_; + for (int32_t loopInner = 0; loopInner < loopInner_; loopInner++) { + int64_t currentGmIdx = baseGmIdx_ + loopInner * ubFactorElementAligned_; + if (loopInner == (loopInner_ - 1)) { + repeatTimes = ((tailUbFactorElement_) + DATA_PER_REPEAT_B32 - 1) / DATA_PER_REPEAT_B32; + loopDataNum = tailUbFactorElement_; + } + DataCopyPad(mGmOut_[currentGmIdx], outTensor, {1, (uint32_t)(loopDataNum * sizeof(outputT)), 0, 0, 0}); + DataCopyPad(sortedValueLocal.template ReinterpretCast(), mGmSortedValue_[currentGmIdx], + {1, static_cast(loopDataNum * sizeof(inputT)), 0, 0, 0}, + {false, 0, 0, 0}); + MTE2ToSSync(); + float rightVlaue = 0; + // Make a judgment on the rightmost value 
of each loop to filter the data. + if constexpr (IsSameType::value) { + rightVlaue = ToFloat(sortedValueLocal.GetValue(loopDataNum - 1)); + } else { + rightVlaue = static_cast(sortedValueLocal.GetValue(loopDataNum - 1)); + } + SToMTE2Sync(); + if (rightVlaue < kthValue) { + firstKLoop += 1; + continue; + } + } +} + +template +__aicore__ inline void ApplyTopKTopPCustom::ScatterFromFirstKLoopTopK(uint32_t loopBatch, + int32_t firstKLoop) { + uint32_t loopDataNum = ubFactorElementAligned_; + uint32_t cumsumInner = ubFactorElementAligned_; + uint8_t repeatTimes = ((ubFactorElementAligned_) + DATA_PER_REPEAT_B32 - 1) / DATA_PER_REPEAT_B32; + for (int32_t loopInner = firstKLoop; loopInner < loopInner_; loopInner++) { + int64_t currentGmIdx = baseGmIdx_ + loopInner * ubFactorElementAligned_; + if (loopInner == (loopInner_ - 1)) { + repeatTimes = (tailUbFactorElement_ + DATA_PER_REPEAT_B32 - 1) / DATA_PER_REPEAT_B32; + loopDataNum = tailUbFactorElement_; + cumsumInner = tailUbFactorElementAligned_; + } + DataCopyPad(sortedValueLocal.template ReinterpretCast(), mGmSortedValue_[currentGmIdx], + {1, static_cast(loopDataNum * sizeof(inputT)), 0, 0, 0}, + {false, 0, 0, 0}); + if constexpr (!IsSameType::value) { + MTE2ToVSync(); + Cast(sortedValueLocalFp32, sortedValueLocal, RoundMode::CAST_NONE, loopDataNum); + VToSSync(); + } + DataCopyPad(sortedIndicesLocal, mGmSortedIndices_[currentGmIdx], + {1, static_cast(loopDataNum * sizeof(int32_t)), 0, 0, 0}, + {false, 0, 0, 0}); + MTE2ToSSync(); + ScatterCumtomImplTopK(loopBatch, loopDataNum, 0); + } +} + +template +__aicore__ inline void ApplyTopKTopPCustom::ScatterCumtomImplTopK(uint32_t loopBatch, + uint32_t loopProbNum, uint32_t offset) { + // Reverse traversal, returning early to improve performance. 
+ for (int32_t loopProb = static_cast(loopProbNum) - 1; loopProb >= 0; loopProb--) { + float curValue = calLocalFp32[offset].GetValue(loopProb); + if (curValue < kthValue) { + break; + } + scatterTensor.SetValue(0, sortedValueLocal[offset].GetValue(loopProb)); + int32_t gmIndex = sortedIndicesLocal[offset].GetValue(loopProb); + SToMTE3Sync(); + DataCopyPad(mGmOut_[baseGmIdx_ + gmIndex], scatterTensor.template ReinterpretCast(), + {1, (uint32_t)(1 * sizeof(outputT)), 0, 0, 0}); + MTE3ToSSync(); + } +} + +template +__aicore__ inline void ApplyTopKTopPCustom::InitProcessTopK(uint32_t loopBatch) { + int64_t initGmIdx = baseGmIdx_ + vocabSize_ - dataNumInit_; + DataCopyPad(mGmOut_[initGmIdx], outTensor, {1, (uint32_t)(dataNumInit_ * sizeof(outputT)), 0, 0, 0}); + DataCopyPad(sortedValueLocal[ubFactorElementAligned_], mGmSortedValue_[initGmIdx], + {1, static_cast(dataNumInit_ * sizeof(inputT)), 0, 0, 0}, + {false, 0, 0, 0}); + if constexpr (!IsSameType::value) { + MTE2ToVSync(); + Cast(sortedValueLocalFp32[ubFactorElementAligned_], sortedValueLocal[ubFactorElementAligned_], + RoundMode::CAST_NONE, dataNumInit_); + } + DataCopyPad(sortedIndicesLocal[ubFactorElementAligned_], mGmSortedIndices_[initGmIdx], + {1, static_cast(dataNumInit_ * sizeof(int32_t)), 0, 0, 0}, + {false, 0, 0, 0}); + DataCopyPad(kLocal, mGmK_[batchOffset_ + loopBatch], {1, static_cast(sizeof(int32_t)), 0, 0, 0}, + {false, 0, 0, 0}); + MTE2ToSSync(); + int32_t kValue = mGmK_.GetValue(batchOffset_ + loopBatch); + maxValue = -calLocalFp32[ubFactorElementAligned_].GetValue(dataNumInit_ - 1); + if constexpr (IsSameType::value) { + kthValue = mGmSortedValue_[baseGmIdx_ + vocabSize_ - kValue].GetValue(0); + } else if constexpr (IsSameType::value) { + kthValue = static_cast(mGmSortedValue_[baseGmIdx_ + vocabSize_ - kValue].GetValue(0)); + } else { + kthValue = ToFloat(mGmSortedValue_[baseGmIdx_ + vocabSize_ - kValue].GetValue(0)); + } +} + +template +__aicore__ inline void 
ApplyTopKTopPCustom::ProcessKLtKMaxTopK(uint32_t loopBatch) { + DataCopyExtParams copyParams{1, (uint32_t)(ubFactorElementAligned_ * sizeof(outputT)), 0, 0, 0}; + // Move out -infinity to fill GM + for (int32_t loopInner = 0; loopInner < loopInner_; loopInner++) { + int64_t currentGmIdxInner = baseGmIdx_ + loopInner * ubFactorElementAligned_; + if (loopInner == loopInner_ - 1) { + DataCopyPad(mGmOut_[currentGmIdxInner], outTensor, + {1, (uint32_t)(tailUbFactorElement_ * sizeof(outputT)), 0, 0, 0}); + } else { + DataCopyPad(mGmOut_[currentGmIdxInner], outTensor, copyParams); + } + } + // Scatter calculation + int32_t loopProb = dataNumInit_ - 1; + scatterTensor.SetValue(0, sortedValueLocal[ubFactorElementAligned_].GetValue(loopProb)); + int32_t gmIndex = sortedIndicesLocal[ubFactorElementAligned_].GetValue(loopProb); + SToMTE3Sync(); + PipeBarrier(); + DataCopyPad(mGmOut_[baseGmIdx_ + gmIndex], scatterTensor.template ReinterpretCast(), scatterCopyParams); + loopProb = loopProb - 1; + + for (; loopProb >= 0; loopProb--) { + float curValue = calLocalFp32[ubFactorElementAligned_].GetValue(loopProb); + if (curValue < kthValue) { + break; + } + MTE3ToSSync(); + scatterTensor.SetValue(0, sortedValueLocal[ubFactorElementAligned_].GetValue(loopProb)); + gmIndex = sortedIndicesLocal[ubFactorElementAligned_].GetValue(loopProb); + SToMTE3Sync(); + DataCopyPad(mGmOut_[baseGmIdx_ + gmIndex], + scatterTensor.template ReinterpretCast(), scatterCopyParams); + } +} + +} // namespace + +#endif // APPLY_TOP_K_TOP_P_CUSTOM_H_KERNEL \ No newline at end of file diff --git a/csrc/apply_top_k_top_p_custom/op_kernel/apply_top_p_custom.h b/csrc/apply_top_k_top_p_custom/op_kernel/apply_top_p_custom.h new file mode 100644 index 00000000..f5861987 --- /dev/null +++ b/csrc/apply_top_k_top_p_custom/op_kernel/apply_top_p_custom.h @@ -0,0 +1,468 @@ +/** + * Copyright (c) 2025 Huawei Technologies Co., Ltd. 
+ * This program is free software, you can redistribute it and/or modify it under the terms and conditions of + * CANN Open Software License Agreement Version 2.0 (the "License"). + * Please refer to the License for details. You may not use this file except in compliance with the License. + * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, + * INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. + * See LICENSE in the root of the software repository for the full text of the License. + */ + +/*! + * \file apply_top_p_custom.h + * \brief + */ +#ifndef APPLY_TOP_P_CUSTOM_H_KERNEL +#define APPLY_TOP_P_CUSTOM_H_KERNEL + +#include "kernel_operator.h" + +using namespace AscendC; +namespace ApplyTopPCustomOp { +constexpr uint16_t FLOAT16_NEG_INF = 0xFC00; // -inf 64512 +constexpr uint16_t BF16_NEG_INF = 0xFF80; // -inf 65408 +constexpr int32_t FLOAT32_NEG_INF = 0xFF800000; // -inf -2139095040 + +constexpr uint32_t BLOCK_BYTES = 32; +constexpr uint32_t DATA_PER_BLOCK_B32 = 8; +constexpr uint32_t DATA_PER_REPEAT_B32 = 64; +constexpr uint32_t SCATTER_PART_LENGTH = 1024; +constexpr uint32_t RESERVED_UB = 1024; +constexpr uint32_t FLOAT_BYTES = 4; +constexpr uint32_t SOFTMAX_UB_NUM = 2; + +template +class ApplyTopPCustom { +public: + __aicore__ inline ApplyTopPCustom(){}; + __aicore__ inline void InitTilingData( + const ApplyTopKTopPCustomTilingData &__restrict tilingData, GM_ADDR sorted_value, GM_ADDR sorted_indices, GM_ADDR p, GM_ADDR k, GM_ADDR out, GM_ADDR workspace); + __aicore__ inline void InitBuffer(TPipe *inputPipe); + __aicore__ inline void ProcessTopP(); +private: + __aicore__ inline void ReduceSumWithAddsAndExpImpl(uint32_t loopDataNum); + // topp func + __aicore__ inline void ProcessPreSingleBatch(uint32_t loopBatch); + __aicore__ inline void GetSoftmaxSum(uint32_t loopBatch); + __aicore__ inline void CumsumKoggleStone(uint32_t loopBatch); + __aicore__ inline 
void GetPValue(uint32_t batchOffset); + __aicore__ inline void CumsumParamCompute(uint32_t iterateTime); + __aicore__ inline void GetMaxValue(int64_t baseGmIdx); + __aicore__ inline void GetSoftMaxRes(uint32_t loopBatch); + __aicore__ inline void ScatterSingleTask(uint32_t taskIndex); + + __aicore__ inline void SToMTE3Sync() { + event_t eventIDSToMTE3 = static_cast(GetTPipePtr()->FetchEventID(HardEvent::S_MTE3)); + SetFlag(eventIDSToMTE3); + WaitFlag(eventIDSToMTE3); + } + __aicore__ inline void VToMTE3Sync() { + event_t eventIDVToMTE3 = static_cast(GetTPipePtr()->FetchEventID(HardEvent::V_MTE3)); + SetFlag(eventIDVToMTE3); + WaitFlag(eventIDVToMTE3); + } + __aicore__ inline void VToMTE2Sync() { + event_t eventIDVToMTE2 = static_cast(GetTPipePtr()->FetchEventID(HardEvent::V_MTE2)); + SetFlag(eventIDVToMTE2); + WaitFlag(eventIDVToMTE2); + } + __aicore__ inline void MTE3ToSSync() { + event_t eventIDMTE3ToS = static_cast(GetTPipePtr()->FetchEventID(HardEvent::MTE3_S)); + SetFlag(eventIDMTE3ToS); + WaitFlag(eventIDMTE3ToS); + } + __aicore__ inline void VToSSync() { + event_t eventIdVToS = static_cast(GetTPipePtr()->FetchEventID(HardEvent::V_S)); + SetFlag(eventIdVToS); + WaitFlag(eventIdVToS); + } + __aicore__ inline void SToVSync() { + event_t eventIdSToV = static_cast(GetTPipePtr()->FetchEventID(HardEvent::S_V)); + SetFlag(eventIdSToV); + WaitFlag(eventIdSToV); + } + __aicore__ inline void MTE2ToVSync() { + event_t eventIdMte2ToV = static_cast(GetTPipePtr()->FetchEventID(HardEvent::MTE2_V)); + SetFlag(eventIdMte2ToV); + WaitFlag(eventIdMte2ToV); + } + __aicore__ inline void MTE2ToSSync() { + event_t eventIdMte2ToS = static_cast(GetTPipePtr()->FetchEventID(HardEvent::MTE2_S)); + SetFlag(eventIdMte2ToS); + WaitFlag(eventIdMte2ToS); + } + + __aicore__ inline void MTE3ToMTE2Sync() { + event_t eventIdMte3ToMte2 = static_cast(GetTPipePtr()->FetchEventID(HardEvent::MTE3_MTE2)); + SetFlag(eventIdMte3ToMte2); + WaitFlag(eventIdMte3ToMte2); + } + + __aicore__ inline void 
MTE3ToVSync() { + event_t eventIdMte3ToV = static_cast(GetTPipePtr()->FetchEventID(HardEvent::MTE3_V)); + SetFlag(eventIdMte3ToV); + WaitFlag(eventIdMte3ToV); + } + + __aicore__ inline uint32_t CeilDiv(uint32_t x, uint32_t y) { + return y == 0 ? x : (x + y - 1) / y; + } +private: + TPipe *pipe_; + TBuf calBuf_; + + // tilingData + uint32_t batchSize_ = 0; + uint32_t vocabSize_ = 0; + uint32_t batchPerCore_ = 0; + uint32_t tailBatch_ = 0; + uint32_t blockNum_ = 0; + uint32_t dataNumInitAligned_ = 0; + uint32_t calUbSize_ = 0; + uint32_t blockIdx_ = 0; + uint32_t loopBatch_ = 0; + uint32_t batchOffset_ = 0; + uint32_t bufOffsetLoop = 0; + int64_t baseGmIdx_ = 0; + + // topp scalar + uint32_t maxSoftmaxLength = 1; + uint32_t softmaxLength = 1; + uint32_t lineSfLoopTimes = 1; + uint32_t softmaxLengthTail = 1; + uint32_t scatterLength = 1; + + uint32_t singleCoreB = 0; + uint32_t singleCoreBTail = 0; + uint32_t vCnt = 0; + uint32_t bCnt = 0; + uint32_t singleCoreV = 1; + uint32_t singleCoreVTail = 1; + uint32_t iterateTimes = 1; + + GlobalTensor mGmSortedValue_; + GlobalTensor mGmSortedIndices_; + GlobalTensor mGmP_; + GlobalTensor mGmK_; + GlobalTensor mGmOut_; + GlobalTensor softMaxGm; + + LocalTensor totalUb; + + // softmax tensor + LocalTensor softMaxLocalFp32; + LocalTensor softMaxLocal; + LocalTensor softMaxResLocal; + LocalTensor reduceLocal; + LocalTensor outInfLocal; + + // cumsum tensor + LocalTensor cumSumInput1Local; + LocalTensor cumSumInput2Local; + + // scatter tensor + LocalTensor sortedValueLocal; + LocalTensor sortedIndicesLocal; + LocalTensor sortedValueLocalFp32; + LocalTensor scatterLocal; + LocalTensor cumsumLocal; + + float pValue = 0; + float maxValue = 0; + float reduceSumValueInvert = 0; + float reduceSumValue = 0; + BinaryRepeatParams repeatParams = {1, 0, 1, 8, 0, 8}; + DataCopyExtParams scatterCopyParams{1, (uint32_t)(sizeof(outputT)), 0, 0, 0}; +}; + +template +__aicore__ inline void ApplyTopPCustom::InitTilingData( + const 
ApplyTopKTopPCustomTilingData &__restrict tilingData, GM_ADDR sorted_value, GM_ADDR sorted_indices, + GM_ADDR p, GM_ADDR k, GM_ADDR out, GM_ADDR workspace) { + batchSize_ = tilingData.batchSize; + vocabSize_ = tilingData.vocabSize; + batchPerCore_ = tilingData.batchPerCore; + tailBatch_ = tilingData.tailBatch; + blockNum_ = tilingData.blockNum; + calUbSize_ = tilingData.calUbSize; + iterateTimes = tilingData.iterateTimes; + blockIdx_ = GetBlockIdx(); + if (blockIdx_ < tailBatch_) { + loopBatch_ = batchPerCore_ + 1; + batchOffset_ = blockIdx_ * loopBatch_; + } else { + loopBatch_ = batchPerCore_; + batchOffset_ = blockIdx_ * batchPerCore_ + tailBatch_; + } + maxSoftmaxLength = (calUbSize_ - RESERVED_UB) / SOFTMAX_UB_NUM / FLOAT_BYTES; + softmaxLength = maxSoftmaxLength < vocabSize_ ? maxSoftmaxLength : vocabSize_; + + lineSfLoopTimes = (vocabSize_ + softmaxLength - 1) / softmaxLength; + softmaxLengthTail = vocabSize_ - (lineSfLoopTimes - 1) * softmaxLength; + scatterLength = (calUbSize_ - RESERVED_UB - BLOCK_BYTES) / (SOFTMAX_UB_NUM * FLOAT_BYTES + sizeof(inputT)) / + SCATTER_PART_LENGTH * SCATTER_PART_LENGTH; + singleCoreB = CeilDiv(batchSize_, blockNum_); + vCnt = batchSize_ < blockNum_ ? 
blockNum_ / batchSize_ : 1; + bCnt = batchSize_; + singleCoreB = 1; + singleCoreBTail = 1; + singleCoreV = vocabSize_ / vCnt; + singleCoreVTail = vocabSize_ - vCnt * singleCoreV; + mGmSortedValue_.SetGlobalBuffer(reinterpret_cast<__gm__ inputT *>(sorted_value)); + mGmSortedIndices_.SetGlobalBuffer(reinterpret_cast<__gm__ int32_t *>(sorted_indices)); + mGmP_.SetGlobalBuffer(reinterpret_cast<__gm__ inputT *>(p)); + mGmK_.SetGlobalBuffer(reinterpret_cast<__gm__ int32_t *>(k)); + mGmOut_.SetGlobalBuffer(reinterpret_cast<__gm__ outputT *>(out)); + softMaxGm.SetGlobalBuffer((__gm__ float*)workspace, batchSize_ * vocabSize_); +} + +// init used buffer +template +__aicore__ inline void ApplyTopPCustom::InitBuffer(TPipe *inputPipe) { + pipe_ = inputPipe; + pipe_->InitBuffer(calBuf_, calUbSize_); + totalUb = calBuf_.Get(); + // softmax ub + uint32_t softmaxLengthAligned = CeilDiv(softmaxLength, BLOCK_BYTES / sizeof(inputT)) * BLOCK_BYTES / sizeof(inputT); + softMaxLocalFp32 = totalUb.ReinterpretCast(); + softMaxLocal = totalUb[softmaxLengthAligned * sizeof(inputT)].ReinterpretCast(); + softMaxResLocal = totalUb[softmaxLengthAligned * sizeof(float)].ReinterpretCast(); + reduceLocal = totalUb[softmaxLengthAligned * sizeof(float) * 2].ReinterpretCast(); // 32 bytes + outInfLocal = totalUb.ReinterpretCast(); // Take softmax ub + + // cumsum ub + cumSumInput1Local = totalUb.ReinterpretCast(); // Take softmax local + cumSumInput2Local = totalUb[softmaxLengthAligned * sizeof(float)].ReinterpretCast(); // Take softmax res ub + + // scatter ub + sortedValueLocal = totalUb[0].ReinterpretCast(); + sortedIndicesLocal = totalUb[scatterLength * sizeof(inputT)].ReinterpretCast(); + cumsumLocal = totalUb[scatterLength * (FLOAT_BYTES + sizeof(inputT))].ReinterpretCast(); + scatterLocal = totalUb[calUbSize_ - RESERVED_UB + BLOCK_BYTES].ReinterpretCast(); // 32 bytes +} + +template +__aicore__ inline void ApplyTopPCustom::GetMaxValue(int64_t baseGmIdx) { + int64_t initGmIdx = baseGmIdx + 
vocabSize_ - 1; + if constexpr (IsSameType::value) { + maxValue = -mGmSortedValue_[initGmIdx].GetValue(0); + } else if constexpr (IsSameType::value) { + maxValue = -static_cast(mGmSortedValue_[initGmIdx].GetValue(0)); + } else { + maxValue = -ToFloat(mGmSortedValue_[initGmIdx].GetValue(0)); + } +} + +template +__aicore__ inline void ApplyTopPCustom::GetPValue(uint32_t batchOffset) { + if constexpr (IsSameType::value) { + pValue = float(1.0) - mGmP_[batchOffset].GetValue(0); + } else if constexpr (IsSameType::value) { + pValue = float(1.0) - static_cast(mGmP_[batchOffset].GetValue(0)); + } else { + pValue = float(1.0) - ToFloat(mGmP_[batchOffset].GetValue(0)); + } +} + +template +__aicore__ inline void ApplyTopPCustom::ProcessPreSingleBatch(uint32_t loopBatch) { + reduceSumValue = 0; + GetSoftmaxSum(loopBatch); + GetSoftMaxRes(loopBatch); + CumsumKoggleStone(loopBatch); +} + +template +__aicore__ inline void ApplyTopPCustom::ProcessTopP() { + for (uint32_t loopBatch = 0; loopBatch < loopBatch_; loopBatch++) { + baseGmIdx_ = batchOffset_ * vocabSize_ + loopBatch * vocabSize_; + GetMaxValue(baseGmIdx_); // Get max value in softmax. + ProcessPreSingleBatch(loopBatch); // Softmax and cumsum. + } + SyncAll(); + for (uint32_t taskIndex = 0; taskIndex < bCnt * vCnt; taskIndex++) { + ScatterSingleTask(taskIndex); + } +} + +template +__aicore__ inline void ApplyTopPCustom::ScatterSingleTask(uint32_t taskIndex) { + if (GetBlockIdx() == taskIndex % blockNum_) { + uint32_t bCntIndex = taskIndex / vCnt; + uint32_t vCntIndex = taskIndex % vCnt; + uint32_t vCurSingleCore = vCntIndex < singleCoreVTail ? (singleCoreV + 1) : singleCoreV; + uint32_t copyTimes = CeilDiv(vCurSingleCore, scatterLength); + uint32_t copyLength = scatterLength; + uint32_t copyLengthTail = vCurSingleCore - (copyTimes - 1) * scatterLength; + GetPValue(bCntIndex); // Get maxPValue. + for (uint32_t cpIndex = 0; cpIndex < copyTimes; cpIndex++) { + uint32_t curCopyLength = cpIndex == (copyTimes - 1) ? 
copyLengthTail : copyLength; + int64_t gmOffset = vCntIndex < singleCoreVTail ? + bCntIndex * vocabSize_ + vCntIndex * (singleCoreV + 1) + cpIndex * copyLength : + bCntIndex * vocabSize_ + vCntIndex * singleCoreV + singleCoreVTail + cpIndex * copyLength; + DataCopyPad(cumsumLocal, softMaxGm[gmOffset], + {1, static_cast(curCopyLength * sizeof(float)), 0, 0, 0}, {false, 0, 0, 0}); + DataCopyPad(sortedIndicesLocal, mGmSortedIndices_[gmOffset], + {1, static_cast(curCopyLength * sizeof(float)), 0, 0, 0}, {false, 0, 0, 0}); + DataCopyPad(sortedValueLocal, mGmSortedValue_[gmOffset], + {1, static_cast(curCopyLength * sizeof(inputT)), 0, 0, 0}, {false, 0, 0, 0}); + MTE2ToSSync(); + + if (cumsumLocal.GetValue(curCopyLength - 1) <= pValue) { + continue; + } + uint32_t scatterLoop = CeilDiv(curCopyLength, SCATTER_PART_LENGTH); + uint32_t scatterNumsTail = curCopyLength - (scatterLoop - 1) * SCATTER_PART_LENGTH; + for (uint32_t scatterLoopIndex = 0; scatterLoopIndex < scatterLoop; scatterLoopIndex++) { + uint32_t curScatterNums = scatterLoopIndex == (scatterLoop - 1) ? 
scatterNumsTail : SCATTER_PART_LENGTH; + if (cumsumLocal.GetValue(scatterLoopIndex * SCATTER_PART_LENGTH + curScatterNums - 1) <= pValue) { + continue; + } + for (uint32_t scatterIndex = 0; scatterIndex < curScatterNums; scatterIndex++) { + int64_t scatterOffset = scatterLoopIndex * SCATTER_PART_LENGTH + scatterIndex; + if (cumsumLocal.GetValue(scatterOffset) <= pValue) { + continue; + } + scatterLocal.SetValue(0, sortedValueLocal.GetValue(scatterOffset)); + int32_t lineIndex = sortedIndicesLocal.GetValue(scatterOffset); + SToMTE3Sync(); + DataCopyPad(mGmOut_[bCntIndex * vocabSize_ + lineIndex], scatterLocal.template ReinterpretCast(), + {1, (uint32_t)(1 * sizeof(outputT)), 0, 0, 0}); + MTE3ToSSync(); + } + } + } + } +} +template +__aicore__ inline void ApplyTopPCustom::GetSoftMaxRes(uint32_t loopBatch) { + uint32_t loopDataNum = softmaxLength; + for (int32_t loopInner = 0; loopInner < lineSfLoopTimes; loopInner++) { + int64_t currentGmIdx = baseGmIdx_ + loopInner * softmaxLength; + if (loopInner == (lineSfLoopTimes - 1)) { + loopDataNum = softmaxLengthTail; + } + if constexpr (!IsSameType::value) { + DataCopyPad(softMaxLocal, mGmSortedValue_[currentGmIdx], + {1, static_cast(loopDataNum * sizeof(inputT)), 0, 0, 0}, + {false, 0, 0, 0}); + MTE2ToVSync(); + Cast(softMaxLocalFp32, softMaxLocal, RoundMode::CAST_NONE, loopDataNum); + PipeBarrier(); + } else { + DataCopyPad(softMaxLocalFp32, mGmSortedValue_[currentGmIdx], + {1, static_cast(loopDataNum * sizeof(float)), 0, 0, 0}, + {false, 0, 0, 0}); + MTE2ToVSync(); + } + Adds(softMaxResLocal, softMaxLocalFp32, maxValue, loopDataNum); + VToMTE2Sync(); + PipeBarrier(); + Exp(softMaxResLocal, softMaxResLocal, loopDataNum); + PipeBarrier(); + Muls(softMaxResLocal, softMaxResLocal, reduceSumValueInvert, loopDataNum); + VToMTE3Sync(); + DataCopyPad(softMaxGm[currentGmIdx], softMaxResLocal, + {1, static_cast(loopDataNum * sizeof(float)), 0, 0, 0}); + MTE3ToMTE2Sync(); + } +} + +template +__aicore__ inline void 
ApplyTopPCustom::CumsumKoggleStone(uint32_t loopBatch) { + uint32_t loopDataNum = softmaxLength; + for (uint32_t iterateTime = 0; iterateTime < iterateTimes; iterateTime++) { + int64_t iteratOffset = 1; + for (uint32_t powerIdx = 0; powerIdx < iterateTime; powerIdx++) { + iteratOffset = iteratOffset * 2; + } + uint32_t addLength = vocabSize_ - iteratOffset; + uint32_t innerLoopNum = addLength / softmaxLength; + uint32_t dataTail = addLength - innerLoopNum * softmaxLength; + loopDataNum = softmaxLength; + for (uint32_t innerLoopIdx = 0; innerLoopIdx < innerLoopNum; innerLoopIdx++) { + // Copy data from right + int64_t loopInnerOffset = dataTail + (innerLoopNum - 1 - innerLoopIdx) * softmaxLength; + DataCopyPad(cumSumInput1Local, softMaxGm[baseGmIdx_ + loopInnerOffset], + {1, static_cast(loopDataNum * sizeof(float)), 0, 0, 0}, {false, 0, 0, 0}); + DataCopyPad(cumSumInput2Local, softMaxGm[baseGmIdx_ + loopInnerOffset + iteratOffset], + {1, static_cast(loopDataNum * sizeof(float)), 0, 0, 0}, {false, 0, 0, 0}); + MTE2ToVSync(); + Add(cumSumInput1Local, cumSumInput1Local, cumSumInput2Local, loopDataNum); + VToMTE3Sync(); + DataCopyPad(softMaxGm[baseGmIdx_ + loopInnerOffset + iteratOffset], cumSumInput1Local, + {1, static_cast(loopDataNum * sizeof(float)), 0, 0, 0}); + MTE3ToMTE2Sync(); + } + if (dataTail > 0) { + loopDataNum = dataTail; + DataCopyPad(cumSumInput1Local, softMaxGm[baseGmIdx_], + {1, static_cast(loopDataNum * sizeof(float)), 0, 0, 0}, {false, 0, 0, 0}); + DataCopyPad(cumSumInput2Local, softMaxGm[baseGmIdx_ + iteratOffset], + {1, static_cast(loopDataNum * sizeof(float)), 0, 0, 0}, {false, 0, 0, 0}); + MTE2ToVSync(); + Add(cumSumInput1Local, cumSumInput1Local, cumSumInput2Local, loopDataNum); + VToMTE3Sync(); + DataCopyPad(softMaxGm[baseGmIdx_ + iteratOffset], cumSumInput1Local, + {1, static_cast(loopDataNum * sizeof(float)), 0, 0, 0}); + MTE3ToMTE2Sync(); + } + } + MTE3ToVSync(); +} + +template +__aicore__ inline void 
ApplyTopPCustom::ReduceSumWithAddsAndExpImpl( + uint32_t loopDataNum) { + Adds(softMaxResLocal, softMaxLocalFp32, maxValue, loopDataNum); + PipeBarrier(); + Exp(softMaxResLocal, softMaxResLocal, loopDataNum); + PipeBarrier(); + ReduceSum(reduceLocal, softMaxResLocal, reduceLocal, loopDataNum); +} + +template +__aicore__ inline void ApplyTopPCustom::GetSoftmaxSum(uint32_t loopBatch) { + uint32_t loopDataNum = softmaxLength; + for (int32_t loopInner = 0; loopInner < lineSfLoopTimes; loopInner++) { + int64_t currentGmIdx = baseGmIdx_ + loopInner * softmaxLength; + if (loopInner == (lineSfLoopTimes - 1)) { + loopDataNum = softmaxLengthTail; + } + if constexpr (IsSameType::value) { + Duplicate(outInfLocal.template ReinterpretCast(), FLOAT32_NEG_INF, loopDataNum); + } else if constexpr (IsSameType::value) { + Duplicate(outInfLocal.template ReinterpretCast(), FLOAT16_NEG_INF, loopDataNum); + } else { + Duplicate(outInfLocal.template ReinterpretCast(), BF16_NEG_INF, loopDataNum); + } + VToMTE3Sync(); + DataCopyPad(mGmOut_[currentGmIdx], outInfLocal, + {1, static_cast(loopDataNum * sizeof(inputT)), 0, 0, 0}); + MTE3ToMTE2Sync(); + if constexpr (!IsSameType::value) { + DataCopyPad(softMaxLocal, mGmSortedValue_[currentGmIdx], + {1, static_cast(loopDataNum * sizeof(inputT)), 0, 0, 0}, + {false, 0, 0, 0}); + MTE2ToVSync(); + Cast(softMaxLocalFp32, softMaxLocal, RoundMode::CAST_NONE, loopDataNum); + PipeBarrier(); + } else { + DataCopyPad(softMaxLocalFp32, mGmSortedValue_[currentGmIdx], + {1, static_cast(loopDataNum * sizeof(inputT)), 0, 0, 0}, + {false, 0, 0, 0}); + MTE2ToVSync(); + } + + ReduceSumWithAddsAndExpImpl(loopDataNum); + VToSSync(); + // Sum up to obtain the sum of exp reduce for the first x loops in the row. 
+ reduceSumValue += reduceLocal.GetValue(0); + SToVSync(); + } + reduceSumValueInvert = 1 / reduceSumValue; + SToVSync(); +} +} // namespace + +#endif // APPLY_TOP_P_CUSTOM_H_KERNEL \ No newline at end of file diff --git a/csrc/build_aclnn.sh b/csrc/build_aclnn.sh index c6cf1bb8..9ed497f0 100644 --- a/csrc/build_aclnn.sh +++ b/csrc/build_aclnn.sh @@ -24,7 +24,7 @@ elif [[ "$SOC_VERSION" =~ ^ascend910b ]]; then ABSOLUTE_CATLASS_PATH=$(cd "${CATLASS_PATH}" && pwd) export CPATH=${ABSOLUTE_CATLASS_PATH}:${CPATH} - CUSTOM_OPS="grouped_matmul_swiglu_quant_weight_nz_tensor_list;lightning_indexer;sparse_flash_attention;matmul_allreduce_add_rmsnorm;moe_init_routing_custom;moe_gating_top_k;add_rms_norm_bias;" + CUSTOM_OPS="grouped_matmul_swiglu_quant_weight_nz_tensor_list;lightning_indexer;sparse_flash_attention;matmul_allreduce_add_rmsnorm;moe_init_routing_custom;moe_gating_top_k;add_rms_norm_bias;apply_top_k_top_p_custom;" SOC_ARG="ascend910b" elif [[ "$SOC_VERSION" =~ ^ascend910_93 ]]; then # ASCEND910C (A3) series @@ -80,6 +80,7 @@ elif [[ "$SOC_VERSION" =~ ^ascend910_93 ]]; then "moe_init_routing_custom" "moe_gating_top_k" "add_rms_norm_bias" + "apply_top_k_top_p_custom" ) CUSTOM_OPS=$(IFS=';'; echo "${CUSTOM_OPS_ARRAY[*]}") SOC_ARG="ascend910_93" diff --git a/csrc/torch_binding.cpp b/csrc/torch_binding.cpp index 80751bae..b08bb92f 100644 --- a/csrc/torch_binding.cpp +++ b/csrc/torch_binding.cpp @@ -1234,6 +1234,26 @@ std::tuple npu_moe_init_routing_ return std::tie(expanded_x, expanded_row_idx, expert_tokens_count_or_cumsum, expanded_scale); } +at::Tensor npu_apply_top_k_top_p( + const at::Tensor& logits, + const c10::optional& p, + const c10::optional& k) +{ + TORCH_CHECK(p.has_value() || k.has_value(), + "apply_top_k_top_p: p and k cannot be None at the same time."); + + at::Tensor out = at::empty_like(logits); + + EXEC_NPU_CMD( + aclnnApplyTopKTopPCustom, + logits, + p, + k, + out); + + return out; +} + std::tuple moe_gating_top_k( const at::Tensor& x, int64_t k, @@ 
-1495,4 +1515,7 @@ TORCH_LIBRARY_EXPAND(CONCAT(_C, _ascend), ops) "-> (Tensor y ,Tensor rstd, Tensor x)" ); ops.impl("npu_add_rms_norm_bias", torch::kPrivateUse1, &vllm_ascend::npu_add_rms_norm_bias); + + ops.def("npu_apply_top_k_top_p(Tensor logits, Tensor? p=None, Tensor? k=None) -> Tensor"); + ops.impl("npu_apply_top_k_top_p", torch::kPrivateUse1, &vllm_ascend::npu_apply_top_k_top_p); } diff --git a/vllm_ascend/sample/sampler.py b/vllm_ascend/sample/sampler.py index 9363aa49..62410cdb 100644 --- a/vllm_ascend/sample/sampler.py +++ b/vllm_ascend/sample/sampler.py @@ -3,7 +3,7 @@ from vllm.v1.sample.ops.topk_topp_sampler import TopKTopPSampler from vllm.v1.sample.sampler import Sampler from vllm_ascend.ascend_config import get_ascend_config -from vllm_ascend.utils import global_stream, npu_stream_switch +from vllm_ascend.utils import AscendDeviceType, get_ascend_device_type, global_stream, npu_stream_switch DEFAULT_LOGPROBS_MODE = "raw_logprobs" @@ -90,7 +90,7 @@ class AscendTopKTopPSampler(TopKTopPSampler): return random_sample(probs, generators), logits_to_return -def apply_top_k_top_p( +def _apply_top_k_top_p_pytorch( logits: torch.Tensor, k: torch.Tensor, p: torch.Tensor, @@ -124,3 +124,15 @@ def apply_top_k_top_p( logits.masked_fill_(elements_to_discard, -float("inf")) return logits + + +def _apply_top_k_top_p_ascendc( + logits: torch.Tensor, + k: torch.Tensor, + p: torch.Tensor, +) -> torch.Tensor: + if p is None and k is None: + return logits + return torch.ops._C_ascend.npu_apply_top_k_top_p(logits, k=k, p=p) + +apply_top_k_top_p = _apply_top_k_top_p_ascendc if get_ascend_device_type() in [AscendDeviceType.A2, AscendDeviceType.A3] else _apply_top_k_top_p_pytorch \ No newline at end of file