[ops] support advanced apply_top_k_top_p without top_k constraint (#6098)
### What this PR does / why we need it?
Implement `apply_top_k_top_p` via ascendC to eliminate the constraint of
k [1,1024]. It enables high performance TopKTopP calculation and avoid
D2H synchronization introduced by k validation.
### Does this PR introduce _any_ user-facing change?
No.
### How was this patch tested?
E2E serving with `k=4096` and `p=0.95`
- vLLM version: v0.13.0
- vLLM main:
d68209402d
---------
Signed-off-by: linfeng-yuan <1102311262@qq.com>
Signed-off-by: SlightwindSec <slightwindsec@gmail.com>
Co-authored-by: SlightwindSec <slightwindsec@gmail.com>
This commit is contained in:
41
csrc/apply_top_k_top_p_custom/op_host/CMakeLists.txt
Normal file
41
csrc/apply_top_k_top_p_custom/op_host/CMakeLists.txt
Normal file
@@ -0,0 +1,41 @@
|
||||
add_ops_compile_options(
|
||||
OP_NAME ApplyTopKTopPCustom
|
||||
OPTIONS --cce-auto-sync=on
|
||||
-Wno-deprecated-declarations
|
||||
-Werror
|
||||
)
|
||||
|
||||
target_sources(op_host_aclnnExc PRIVATE
|
||||
apply_top_k_top_p_custom_def.cpp
|
||||
)
|
||||
|
||||
target_sources(opapi PRIVATE
|
||||
apply_top_k_top_p_custom.cpp
|
||||
aclnn_apply_top_k_top_p_custom.cpp
|
||||
)
|
||||
|
||||
if (NOT BUILD_OPEN_PROJECT)
|
||||
target_sources(aclnn_ops_train PRIVATE
|
||||
apply_top_k_top_p_custom.cpp
|
||||
aclnn_apply_top_k_top_p_custom.cpp
|
||||
)
|
||||
|
||||
target_sources(aclnn_ops_infer PRIVATE
|
||||
apply_top_k_top_p_custom.cpp
|
||||
aclnn_apply_top_k_top_p_custom.cpp
|
||||
)
|
||||
endif ()
|
||||
|
||||
target_sources(optiling PRIVATE
|
||||
apply_top_k_top_p_custom_tiling.cpp
|
||||
)
|
||||
|
||||
target_include_directories(optiling PRIVATE
|
||||
${CMAKE_CURRENT_SOURCE_DIR}
|
||||
)
|
||||
|
||||
file(GLOB _GMM_Aclnn_header "${CMAKE_CURRENT_SOURCE_DIR}/aclnn_apply_top_k_top_p_custom.h")
|
||||
|
||||
install(FILES ${_GMM_Aclnn_header}
|
||||
DESTINATION ${ACLNN_INC_INSTALL_DIR} OPTIONAL
|
||||
)
|
||||
@@ -0,0 +1,213 @@
|
||||
/**
|
||||
* Copyright (c) 2025 Huawei Technologies Co., Ltd.
|
||||
* This program is free software, you can redistribute it and/or modify it under the terms and conditions of
|
||||
* CANN Open Software License Agreement Version 2.0 (the "License").
|
||||
* Please refer to the License for details. You may not use this file except in compliance with the License.
|
||||
* THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED,
|
||||
* INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE.
|
||||
* See LICENSE in the root of the software repository for the full text of the License.
|
||||
*/
|
||||
|
||||
/*!
|
||||
* \file aclnn_apply_top_k_top_p_custom.cpp
|
||||
* \brief
|
||||
*/
|
||||
#include "aclnn_apply_top_k_top_p_custom.h"
|
||||
#include "apply_top_k_top_p_custom.h"
|
||||
#include "sort.h"
|
||||
#include "aclnn_kernels/contiguous.h"
|
||||
#include "aclnn_kernels/common/op_error_check.h"
|
||||
#include "aclnn/aclnn_base.h"
|
||||
#include "opdev/common_types.h"
|
||||
#include "opdev/data_type_utils.h"
|
||||
#include "opdev/make_op_executor.h"
|
||||
#include "opdev/format_utils.h"
|
||||
#include "opdev/op_dfx.h"
|
||||
#include "opdev/op_executor.h"
|
||||
#include "opdev/op_log.h"
|
||||
#include "opdev/tensor_view_utils.h"
|
||||
#include "opdev/shape_utils.h"
|
||||
#include "opdev/platform.h"
|
||||
|
||||
using namespace op;
|
||||
#ifdef __cplusplus
|
||||
extern "C" {
|
||||
#endif
|
||||
namespace {
|
||||
static const int64_t EXPECTED_DIM_ONE = 1;
|
||||
static const int64_t EXPECTED_DIM_TWO = 2;
|
||||
static constexpr size_t DIM_ONE = 1;
|
||||
|
||||
// 根据API定义,需要列出所能支持的所有dtype
|
||||
static const std::initializer_list<op::DataType> DTYPE_SUPPORT_LIST = {
|
||||
op::DataType::DT_FLOAT, op::DataType::DT_FLOAT16, op::DataType::DT_BF16};
|
||||
|
||||
static const std::initializer_list<op::DataType> INT_DTYPE_SUPPORT_LIST = {
|
||||
op::DataType::DT_INT32};
|
||||
|
||||
static bool CheckNotNull(const aclTensor* logits, const aclTensor* p, const aclTensor *k, const aclTensor* out)
|
||||
{
|
||||
OP_CHECK_NULL(logits, return false);
|
||||
if (p == nullptr && k == nullptr) {
|
||||
OP_LOGE(ACLNN_ERR_PARAM_INVALID, "The inputs, p and k, should not be nullptr at the same time.");
|
||||
}
|
||||
OP_CHECK_NULL(out, return false);
|
||||
return true;
|
||||
}
|
||||
|
||||
static bool CheckDtypeValid(const aclTensor* logits, const aclTensor* p, const aclTensor *k, const aclTensor* out)
|
||||
{
|
||||
// 检查数据类型是否在支持列表内
|
||||
OP_CHECK_DTYPE_NOT_SUPPORT(logits, DTYPE_SUPPORT_LIST, return false);
|
||||
if (p != nullptr) {
|
||||
OP_CHECK_DTYPE_NOT_SUPPORT(p, DTYPE_SUPPORT_LIST, return false);
|
||||
}
|
||||
if (k != nullptr) {
|
||||
OP_CHECK_DTYPE_NOT_SUPPORT(k, INT_DTYPE_SUPPORT_LIST, return false);
|
||||
}
|
||||
OP_CHECK_DTYPE_NOT_SUPPORT(out, DTYPE_SUPPORT_LIST, return false);
|
||||
|
||||
// 检查数据类型是否相同
|
||||
if (p != nullptr) {
|
||||
OP_CHECK_DTYPE_NOT_MATCH(p, logits->GetDataType(), return false);
|
||||
}
|
||||
OP_CHECK_DTYPE_NOT_MATCH(out, logits->GetDataType(), return false);
|
||||
return true;
|
||||
}
|
||||
|
||||
static bool CheckShapeValid(const aclTensor* logits, const aclTensor* p, const aclTensor *k, const aclTensor* out)
|
||||
{
|
||||
OP_CHECK_WRONG_DIMENSION(logits, EXPECTED_DIM_TWO, return false);
|
||||
OP_CHECK_SHAPE_NOT_EQUAL(out, logits, return false);
|
||||
if (p != nullptr) {
|
||||
OP_CHECK_WRONG_DIMENSION(p, EXPECTED_DIM_ONE, return false);
|
||||
}
|
||||
if (k != nullptr) {
|
||||
OP_CHECK_WRONG_DIMENSION(k, EXPECTED_DIM_ONE, return false);
|
||||
}
|
||||
if (p != nullptr && p->GetViewShape().GetDim(0) != logits->GetViewShape().GetDim(0)) {
|
||||
OP_LOGE(ACLNN_ERR_PARAM_INVALID, "expected p.size(0) is equal to logits.size(0), but got %ld.",
|
||||
p->GetViewShape().GetDim(0));
|
||||
return false;
|
||||
}
|
||||
if (k != nullptr && k->GetViewShape().GetDim(0) != logits->GetViewShape().GetDim(0)) {
|
||||
OP_LOGE(ACLNN_ERR_PARAM_INVALID, "expected k.size(0) is equal to logits.size(0), but got %ld.",
|
||||
k->GetViewShape().GetDim(0));
|
||||
return false;
|
||||
}
|
||||
return true;
|
||||
}
|
||||
|
||||
static bool CheckFormatValid(const aclTensor* logits, const aclTensor* p, const aclTensor *k, const aclTensor* out)
|
||||
{
|
||||
if (logits->GetStorageFormat() != Format::FORMAT_ND) {
|
||||
OP_LOGE(ACLNN_ERR_PARAM_INVALID, "logits format only support ND");
|
||||
return false;
|
||||
}
|
||||
if (p != nullptr && p->GetStorageFormat() != Format::FORMAT_ND) {
|
||||
OP_LOGE(ACLNN_ERR_PARAM_INVALID, "p format only support ND");
|
||||
return false;
|
||||
}
|
||||
if (k != nullptr && k->GetStorageFormat() != Format::FORMAT_ND) {
|
||||
OP_LOGE(ACLNN_ERR_PARAM_INVALID, "k format only support ND");
|
||||
return false;
|
||||
}
|
||||
if (out->GetStorageFormat() != Format::FORMAT_ND) {
|
||||
OP_LOGE(ACLNN_ERR_PARAM_INVALID, "out format only support ND");
|
||||
return false;
|
||||
}
|
||||
return true;
|
||||
}
|
||||
|
||||
static aclnnStatus CheckParams(const aclTensor* logits, const aclTensor* p, const aclTensor *k, const aclTensor* out)
|
||||
{
|
||||
// 错误码等DFX方案细化后刷新,错误日志在check接口内打印
|
||||
// 1. 检查参数是否为空指针
|
||||
CHECK_RET(CheckNotNull(logits, p, k, out), ACLNN_ERR_PARAM_NULLPTR);
|
||||
|
||||
// 2. 检查输入的数据类型是否在API支持的数据类型范围之内,需要根据api定义校验
|
||||
CHECK_RET(CheckDtypeValid(logits, p, k, out), ACLNN_ERR_PARAM_INVALID);
|
||||
|
||||
// 3. 检查shape是否满足约束
|
||||
CHECK_RET(CheckShapeValid(logits, p, k, out), ACLNN_ERR_PARAM_INVALID);
|
||||
|
||||
// 4. 检查format是否满足约束
|
||||
CHECK_RET(CheckFormatValid(logits, p, k, out), ACLNN_ERR_PARAM_INVALID);
|
||||
|
||||
return ACLNN_SUCCESS;
|
||||
}
|
||||
} // namespace
|
||||
|
||||
aclnnStatus aclnnApplyTopKTopPCustomGetWorkspaceSize(
|
||||
const aclTensor* logits, const aclTensor* p, const aclTensor* k, aclTensor* out, uint64_t* workspaceSize,
|
||||
aclOpExecutor** executor)
|
||||
{
|
||||
OP_CHECK_COMM_INPUT(workspaceSize, executor);
|
||||
L2_DFX_PHASE_1(aclnnApplyTopKTopPCustom, DFX_IN(logits, p, k), DFX_OUT(out));
|
||||
// 固定写法,创建OpExecutor
|
||||
auto uniqueExecutor = CREATE_EXECUTOR();
|
||||
CHECK_RET(uniqueExecutor.get() != nullptr, ACLNN_ERR_INNER_CREATE_EXECUTOR);
|
||||
|
||||
// 固定写法,参数检查
|
||||
auto ret = CheckParams(logits, p, k, out);
|
||||
CHECK_RET(ret == ACLNN_SUCCESS, ret);
|
||||
bool pIsEmpty = false;
|
||||
bool kIsEmpty = false;
|
||||
if (p != nullptr) {
|
||||
pIsEmpty = p->IsEmpty();
|
||||
}
|
||||
if (k != nullptr) {
|
||||
kIsEmpty = k->IsEmpty();
|
||||
}
|
||||
if (logits->IsEmpty() || pIsEmpty || kIsEmpty) {
|
||||
// 根据实际支持情况补充
|
||||
*workspaceSize = 0;
|
||||
uniqueExecutor.ReleaseTo(executor);
|
||||
return ACLNN_SUCCESS;
|
||||
}
|
||||
// 固定写法,将输入selfRef转换成连续的tensor
|
||||
auto logitsContiguous = l0op::Contiguous(logits, uniqueExecutor.get());
|
||||
CHECK_RET(logitsContiguous != nullptr, ACLNN_ERR_INNER_NULLPTR);
|
||||
const aclTensor* pContiguous = nullptr;
|
||||
const aclTensor* kContiguous = nullptr;
|
||||
if (p != nullptr) {
|
||||
pContiguous = l0op::Contiguous(p, uniqueExecutor.get());
|
||||
CHECK_RET(pContiguous != nullptr, ACLNN_ERR_INNER_NULLPTR);
|
||||
}
|
||||
if (k != nullptr) {
|
||||
kContiguous = l0op::Contiguous(k, uniqueExecutor.get());
|
||||
CHECK_RET(kContiguous != nullptr, ACLNN_ERR_INNER_NULLPTR);
|
||||
}
|
||||
bool isLastDimSizeOne = logits->GetViewShape()[DIM_ONE] == 1;
|
||||
auto viewCopyResult = logitsContiguous;
|
||||
if (isLastDimSizeOne) {
|
||||
viewCopyResult = l0op::ViewCopy(logitsContiguous, out, uniqueExecutor.get());
|
||||
} else {
|
||||
auto sortResult = l0op::Sort(logitsContiguous, -1, false, true, op::DataType::DT_INT32, uniqueExecutor.get());
|
||||
const aclTensor* sortedValue = std::get<0>(sortResult);
|
||||
CHECK_RET(sortedValue != nullptr, ACLNN_ERR_INNER_NULLPTR);
|
||||
const aclTensor* sortedIndices = std::get<1>(sortResult);
|
||||
CHECK_RET(sortedIndices != nullptr, ACLNN_ERR_INNER_NULLPTR);
|
||||
auto res = l0op::ApplyTopKTopPCustom(sortedValue, sortedIndices, pContiguous, kContiguous, uniqueExecutor.get());
|
||||
CHECK_RET(res != nullptr, ACLNN_ERR_INNER_NULLPTR);
|
||||
// 固定写法,将计算结果拷贝到输出out上,out可能是非连续的tensor
|
||||
viewCopyResult = l0op::ViewCopy(res, out, uniqueExecutor.get());
|
||||
}
|
||||
CHECK_RET(viewCopyResult != nullptr, ACLNN_ERR_INNER_NULLPTR);
|
||||
// 固定写法,获取计算过程中需要使用的workspace大小
|
||||
*workspaceSize = uniqueExecutor->GetWorkspaceSize();
|
||||
uniqueExecutor.ReleaseTo(executor); // 需要把 uniqueExecutor持有executor转移给executor
|
||||
return ACLNN_SUCCESS;
|
||||
}
|
||||
|
||||
aclnnStatus aclnnApplyTopKTopPCustom(void* workspace, uint64_t workspaceSize, aclOpExecutor* executor, aclrtStream stream)
|
||||
{
|
||||
L2_DFX_PHASE_2(aclnnApplyTopKTopPCustom);
|
||||
// 固定写法,调用框架能力,完成计算
|
||||
return CommonOpExecutorRun(workspace, workspaceSize, executor, stream);
|
||||
}
|
||||
|
||||
#ifdef __cplusplus
|
||||
}
|
||||
#endif
|
||||
|
||||
@@ -0,0 +1,54 @@
|
||||
/**
|
||||
* Copyright (c) 2025 Huawei Technologies Co., Ltd.
|
||||
* This program is free software, you can redistribute it and/or modify it under the terms and conditions of
|
||||
* CANN Open Software License Agreement Version 2.0 (the "License").
|
||||
* Please refer to the License for details. You may not use this file except in compliance with the License.
|
||||
* THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED,
|
||||
* INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE.
|
||||
* See LICENSE in the root of the software repository for the full text of the License.
|
||||
*/
|
||||
|
||||
/*!
|
||||
* \file aclnn_apply_top_k_top_p_custom.h
|
||||
* \brief
|
||||
*/
|
||||
#ifndef OP_API_INC_APPLY_TOP_K_TOP_P_CUSTOM_H_
|
||||
#define OP_API_INC_APPLY_TOP_K_TOP_P_CUSTOM_H_
|
||||
|
||||
#include "aclnn/aclnn_base.h"
|
||||
|
||||
#ifdef __cplusplus
|
||||
extern "C" {
|
||||
#endif
|
||||
|
||||
/**
|
||||
* @brief aclnnApplyTopKTopPCustom的第一段接口,根据具体的计算流程,计算workspace大小。
|
||||
* @domain aclnn_ops_infer
|
||||
* @param [in] logits: npu device侧的aclTensor,数据类型支持FLOAT、FLOAT16、BFLOAT16,支持非连续的Tensor,数据格式支持ND。
|
||||
* @param [in] p: npu device侧的aclTensor,数据类型支持FLOAT、FLOAT16、BFLOAT16,支持非连续的Tensor,数据格式支持ND。
|
||||
* @param [in] k: npu device侧的aclTensor,数据类型支持INT32,支持非连续的Tensor,数据格式支持ND。
|
||||
* @param [in] out: npu device侧的aclTensor,数据类型支持FLOAT、FLOAT16、BFLOAT16,支持非连续的Tensor,数据格式支持ND。
|
||||
* @param [out] workspaceSize: 返回用户需要在npu device侧申请的workspace大小。
|
||||
* @param [out] executor: 返回op执行器,包含算子计算流程。
|
||||
* @return aclnnStatus: 返回状态码。
|
||||
*/
|
||||
aclnnStatus aclnnApplyTopKTopPCustomGetWorkspaceSize(const aclTensor* logits, const aclTensor* p,
|
||||
const aclTensor* k, aclTensor* out, uint64_t* workspaceSize,
|
||||
aclOpExecutor** executor);
|
||||
|
||||
/**
|
||||
* @brief aclnnApplyTopKTopPCustom的第二段接口,用于执行计算。
|
||||
* @param [in] workspace: 在npu device侧申请的workspace内存起址。
|
||||
* @param [in] workspaceSize: 在npu device侧申请的workspace大小,由第一段接口aaclnnApplyTopKTopPCustomGetWorkspaceSize获取。
|
||||
* @param [in] stream: acl stream流。
|
||||
* @param [in] executor: op执行器,包含了算子计算流程。
|
||||
* @return aclnnStatus: 返回状态码。
|
||||
*/
|
||||
aclnnStatus aclnnApplyTopKTopPCustom(void* workspace, uint64_t workspaceSize, aclOpExecutor* executor,
|
||||
aclrtStream stream);
|
||||
|
||||
#ifdef __cplusplus
|
||||
}
|
||||
#endif
|
||||
|
||||
#endif // OP_API_INC_APPLY_TOP_K_TOP_P_CUSTOM_H_
|
||||
@@ -0,0 +1,46 @@
|
||||
/**
|
||||
* Copyright (c) 2025 Huawei Technologies Co., Ltd.
|
||||
* This program is free software, you can redistribute it and/or modify it under the terms and conditions of
|
||||
* CANN Open Software License Agreement Version 2.0 (the "License").
|
||||
* Please refer to the License for details. You may not use this file except in compliance with the License.
|
||||
* THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED,
|
||||
* INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE.
|
||||
* See LICENSE in the root of the software repository for the full text of the License.
|
||||
*/
|
||||
|
||||
/*!
|
||||
* \file apply_top_k_top_p_custom.cpp
|
||||
* \brief
|
||||
*/
|
||||
#include "apply_top_k_top_p_custom.h"
|
||||
#include "opdev/data_type_utils.h"
|
||||
#include "opdev/format_utils.h"
|
||||
#include "opdev/make_op_executor.h"
|
||||
#include "opdev/op_def.h"
|
||||
#include "opdev/op_dfx.h"
|
||||
#include "opdev/op_executor.h"
|
||||
#include "opdev/op_log.h"
|
||||
#include "opdev/shape_utils.h"
|
||||
using namespace op;
|
||||
|
||||
namespace l0op {
|
||||
|
||||
OP_TYPE_REGISTER(ApplyTopKTopPCustom);
|
||||
|
||||
const aclTensor* ApplyTopKTopPCustom(
|
||||
const aclTensor* sortedValue, const aclTensor* sortedIndices, const aclTensor* p, const aclTensor* k,
|
||||
aclOpExecutor* executor)
|
||||
{
|
||||
L0_DFX(ApplyTopKTopPCustom, sortedValue, sortedIndices, p, k);
|
||||
auto output = executor->AllocTensor(sortedValue->GetViewShape(), sortedValue->GetDataType());
|
||||
if (p == nullptr) {
|
||||
p = executor->AllocTensor(sortedValue->GetDataType(), Format::FORMAT_ND, Format::FORMAT_ND);
|
||||
}
|
||||
if (k == nullptr) {
|
||||
k = executor->AllocTensor(DataType::DT_INT32, Format::FORMAT_ND, Format::FORMAT_ND);
|
||||
}
|
||||
ADD_TO_LAUNCHER_LIST_AICORE(ApplyTopKTopPCustom, OP_INPUT(sortedValue, sortedIndices, p, k), OP_OUTPUT(output));
|
||||
|
||||
return output;
|
||||
}
|
||||
} // namespace l0op
|
||||
@@ -0,0 +1,24 @@
|
||||
/**
|
||||
* Copyright (c) 2025 Huawei Technologies Co., Ltd.
|
||||
* This program is free software, you can redistribute it and/or modify it under the terms and conditions of
|
||||
* CANN Open Software License Agreement Version 2.0 (the "License").
|
||||
* Please refer to the License for details. You may not use this file except in compliance with the License.
|
||||
* THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED,
|
||||
* INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE.
|
||||
* See LICENSE in the root of the software repository for the full text of the License.
|
||||
*/
|
||||
|
||||
/*!
|
||||
* \file apply_top_k_top_p_custom.h
|
||||
* \brief
|
||||
*/
|
||||
#ifndef OP_API_INC_LEVEL0_OP_APPLY_TOP_K_TOP_P_CUSTOM_OP_H_
|
||||
#define OP_API_INC_LEVEL0_OP_APPLY_TOP_K_TOP_P_CUSTOM_OP_H_
|
||||
|
||||
#include "opdev/op_executor.h"
|
||||
|
||||
namespace l0op {
|
||||
const aclTensor* ApplyTopKTopPCustom(const aclTensor* sortedValue, const aclTensor* sortedIndices,
|
||||
const aclTensor* p, const aclTensor* k, aclOpExecutor* executor);
|
||||
}
|
||||
#endif // OP_API_INC_LEVEL0_OP_APPLY_TOP_K_TOP_P_CUSTOM_OP_H_
|
||||
@@ -0,0 +1,100 @@
|
||||
/**
|
||||
* Copyright (c) 2025 Huawei Technologies Co., Ltd.
|
||||
* This program is free software, you can redistribute it and/or modify it under the terms and conditions of
|
||||
* CANN Open Software License Agreement Version 2.0 (the "License").
|
||||
* Please refer to the License for details. You may not use this file except in compliance with the License.
|
||||
* THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED,
|
||||
* INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE.
|
||||
* See LICENSE in the root of the software repository for the full text of the License.
|
||||
*/
|
||||
|
||||
/*!
|
||||
* \file apply_top_k_top_p_custom_def.cpp
|
||||
* \brief
|
||||
*/
|
||||
#include "register/op_def_registry.h"
|
||||
|
||||
namespace ops {
|
||||
class ApplyTopKTopPCustom : public OpDef {
|
||||
public:
|
||||
explicit ApplyTopKTopPCustom(const char *name) : OpDef(name)
|
||||
{
|
||||
this->Input("sorted_value")
|
||||
.ParamType(REQUIRED)
|
||||
.DataType({ge::DT_FLOAT, ge::DT_FLOAT16, ge::DT_BF16})
|
||||
.Format({ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND})
|
||||
.UnknownShapeFormat({ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND});
|
||||
this->Input("sorted_indices")
|
||||
.ParamType(REQUIRED)
|
||||
.DataType({ge::DT_INT32, ge::DT_INT32, ge::DT_INT32})
|
||||
.Format({ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND})
|
||||
.UnknownShapeFormat({ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND});
|
||||
this->Input("p")
|
||||
.ParamType(OPTIONAL)
|
||||
.DataType({ge::DT_FLOAT, ge::DT_FLOAT16, ge::DT_BF16})
|
||||
.Format({ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND})
|
||||
.UnknownShapeFormat({ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND});
|
||||
this->Input("k")
|
||||
.ParamType(OPTIONAL)
|
||||
.DataType({ge::DT_INT32, ge::DT_INT32, ge::DT_INT32})
|
||||
.Format({ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND})
|
||||
.UnknownShapeFormat({ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND});
|
||||
this->Output("out")
|
||||
.ParamType(REQUIRED)
|
||||
.DataType({ge::DT_FLOAT, ge::DT_FLOAT16, ge::DT_BF16})
|
||||
.Format({ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND})
|
||||
.UnknownShapeFormat({ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND});
|
||||
|
||||
OpAICoreConfig aicore_config;
|
||||
aicore_config.DynamicCompileStaticFlag(true)
|
||||
.DynamicFormatFlag(false)
|
||||
.DynamicRankSupportFlag(true)
|
||||
.DynamicShapeSupportFlag(true);
|
||||
this->AICore().AddConfig("ascend910b");
|
||||
this->AICore().AddConfig("ascend910_93");
|
||||
|
||||
OpAICoreConfig config_kirin = GetKirinCoreConfig();
|
||||
this->AICore().AddConfig("kirinx90", config_kirin);
|
||||
}
|
||||
|
||||
private:
|
||||
OpAICoreConfig GetKirinCoreConfig() const
|
||||
{
|
||||
OpAICoreConfig config_kirin;
|
||||
config_kirin.DynamicCompileStaticFlag(true)
|
||||
.DynamicFormatFlag(true)
|
||||
.DynamicRankSupportFlag(true)
|
||||
.DynamicShapeSupportFlag(true)
|
||||
.NeedCheckSupportFlag(false)
|
||||
.PrecisionReduceFlag(true);
|
||||
config_kirin.Input("sorted_value")
|
||||
.ParamType(REQUIRED)
|
||||
.DataType({ge::DT_FLOAT, ge::DT_FLOAT16})
|
||||
.Format({ge::FORMAT_ND, ge::FORMAT_ND})
|
||||
.UnknownShapeFormat({ge::FORMAT_ND, ge::FORMAT_ND});
|
||||
config_kirin.Input("sorted_indices")
|
||||
.ParamType(REQUIRED)
|
||||
.DataType({ge::DT_INT32, ge::DT_INT32})
|
||||
.Format({ge::FORMAT_ND, ge::FORMAT_ND})
|
||||
.UnknownShapeFormat({ge::FORMAT_ND, ge::FORMAT_ND});
|
||||
config_kirin.Input("p")
|
||||
.ParamType(OPTIONAL)
|
||||
.DataType({ge::DT_FLOAT, ge::DT_FLOAT16})
|
||||
.Format({ge::FORMAT_ND, ge::FORMAT_ND})
|
||||
.UnknownShapeFormat({ge::FORMAT_ND, ge::FORMAT_ND});
|
||||
config_kirin.Input("k")
|
||||
.ParamType(OPTIONAL)
|
||||
.DataType({ge::DT_INT32, ge::DT_INT32})
|
||||
.Format({ge::FORMAT_ND, ge::FORMAT_ND})
|
||||
.UnknownShapeFormat({ge::FORMAT_ND, ge::FORMAT_ND});
|
||||
config_kirin.Output("out")
|
||||
.ParamType(REQUIRED)
|
||||
.DataType({ge::DT_FLOAT, ge::DT_FLOAT16})
|
||||
.Format({ge::FORMAT_ND, ge::FORMAT_ND})
|
||||
.UnknownShapeFormat({ge::FORMAT_ND, ge::FORMAT_ND});
|
||||
return config_kirin;
|
||||
}
|
||||
};
|
||||
|
||||
OP_ADD(ApplyTopKTopPCustom);
|
||||
} // namespace ops
|
||||
@@ -0,0 +1,314 @@
|
||||
/**
|
||||
* Copyright (c) 2025 Huawei Technologies Co., Ltd.
|
||||
* This program is free software, you can redistribute it and/or modify it under the terms and conditions of
|
||||
* CANN Open Software License Agreement Version 2.0 (the "License").
|
||||
* Please refer to the License for details. You may not use this file except in compliance with the License.
|
||||
* THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED,
|
||||
* INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE.
|
||||
* See LICENSE in the root of the software repository for the full text of the License.
|
||||
*/
|
||||
|
||||
/*!
|
||||
* \file apply_top_k_top_p_custom_tiling.cpp
|
||||
* \brief
|
||||
*/
|
||||
|
||||
#include <iostream>
|
||||
#include <map>
|
||||
#include "error_log.h"
|
||||
#include "tiling/platform/platform_ascendc.h"
|
||||
#include "platform/platform_infos_def.h"
|
||||
#include "register/op_def_registry.h"
|
||||
#include "register/op_impl_registry.h"
|
||||
#include "tiling/tiling_api.h"
|
||||
#include "apply_top_k_top_p_custom_tiling.h"
|
||||
|
||||
namespace {
|
||||
constexpr uint32_t SYS_RESERVED_UB = uint32_t(16 * 1024);
|
||||
constexpr uint32_t SELECT_RESERVED_UB = uint32_t(8 * 1024);
|
||||
constexpr uint32_t DIM_ONE = 1;
|
||||
constexpr uint32_t DIM_TWO = 2;
|
||||
constexpr int32_t SORTED_VALUE_INPUT_INDEX = 0;
|
||||
constexpr int32_t SORTED_INDICES_INPUT_INDEX = 1;
|
||||
constexpr int32_t P_INPUT_INDEX = 2;
|
||||
constexpr int32_t K_INPUT_INDEX = 3;
|
||||
constexpr uint32_t DIM_INDEX0 = 0;
|
||||
constexpr uint32_t FLOAT_BYTES = 4;
|
||||
static std::map<ge::DataType, uint32_t> DTYPE_MAP = {{ge::DT_BF16, 2}, {ge::DT_FLOAT16, 1}, {ge::DT_FLOAT, 0}};
|
||||
static std::map<ge::DataType, uint32_t> DATATYPE_LEN_MAP = {
|
||||
{ge::DT_FLOAT16, 2}, {ge::DT_BF16, 2}, {ge::DT_FLOAT, 4}};
|
||||
const static uint32_t SYS_WORKSPACESIZE = uint32_t(16 * 1024 * 1024);
|
||||
|
||||
constexpr uint32_t DATA_PER_BLOCK_B32 = 8;
|
||||
constexpr uint32_t BYTES_B32 = 4;
|
||||
constexpr uint32_t BLOCK_BYTES = 32;
|
||||
constexpr uint32_t K_VALUE_MAX = 1024;
|
||||
constexpr uint32_t ONLY_TOP_P_KEY = 2;
|
||||
constexpr uint32_t ONLY_TOP_K_KEY = 1;
|
||||
constexpr uint32_t BATCH_MODE = 1;
|
||||
} // namespace
|
||||
|
||||
namespace optiling {
|
||||
class ApplyTopKTopPCustomTiling {
|
||||
public:
|
||||
explicit ApplyTopKTopPCustomTiling(gert::TilingContext* context) : tilingcontext(context){};
|
||||
ge::graphStatus Init();
|
||||
ge::graphStatus RunKernelTiling();
|
||||
private:
|
||||
ApplyTopKTopPCustomTilingData tilingData;
|
||||
gert::TilingContext* tilingcontext = nullptr;
|
||||
ge::graphStatus CheckShape();
|
||||
void SetTilingKey();
|
||||
void GetUsedCore();
|
||||
void CalDataPerCore();
|
||||
void FillTilingData();
|
||||
void PrintTilingData();
|
||||
template <typename T1>
|
||||
inline auto CeilAlign(T1 a, T1 b) const -> T1
|
||||
{
|
||||
return b == 0 ? a : (a + b - 1) / b * b;
|
||||
}
|
||||
template <typename T1>
|
||||
inline auto FloorAlign(T1 a, T1 b) const -> T1
|
||||
{
|
||||
return b == 0 ? a : a / b * b;
|
||||
}
|
||||
|
||||
const char *opName_ = nullptr;
|
||||
uint32_t coreNum_ = 0;
|
||||
uint32_t calUbSize_ = 0;
|
||||
uint32_t batchSize_ = 0;
|
||||
uint32_t vocabSize_ = 0;
|
||||
uint32_t tilingKey_ = 0;
|
||||
uint32_t usedCoreNum_ = 0;
|
||||
uint32_t batchPerCore_ = 1;
|
||||
uint32_t tailBatch_ = 0;
|
||||
uint32_t dataNumInit_ = 0;
|
||||
uint32_t dataNumInitAligned_ = 0;
|
||||
uint32_t ubFactorElement_ = 0;
|
||||
uint32_t ubFactorElementAligned_ = 0;
|
||||
uint32_t tailUbFactorElement_ = 0;
|
||||
uint32_t tailUbFactorElementAligned_ = 0;
|
||||
uint32_t iterateTimes_ = 0;
|
||||
uint32_t onlyTopK_ = 0;
|
||||
uint32_t onlyTopP_ = 0;
|
||||
uint64_t platformUbSize_ = 0;
|
||||
};
|
||||
|
||||
ge::graphStatus ApplyTopKTopPCustomTiling::CheckShape() {
|
||||
auto sortedValueShapePtr = tilingcontext->GetInputShape(SORTED_VALUE_INPUT_INDEX);
|
||||
OP_CHECK_NULL_WITH_CONTEXT(tilingcontext, sortedValueShapePtr);
|
||||
auto sortedValueShape = sortedValueShapePtr->GetStorageShape();
|
||||
if (sortedValueShape.GetDimNum() != DIM_TWO) {
|
||||
OP_LOGE(opName_, "the dimNum of sorted_value should be 2, but got %u.", sortedValueShape.GetDimNum());
|
||||
return ge::GRAPH_FAILED;
|
||||
}
|
||||
auto sortedIndicesShapePtr = tilingcontext->GetInputShape(SORTED_INDICES_INPUT_INDEX);
|
||||
OP_CHECK_NULL_WITH_CONTEXT(tilingcontext, sortedIndicesShapePtr);
|
||||
auto sortedIndicesShape = sortedIndicesShapePtr->GetStorageShape();
|
||||
if (sortedIndicesShape.GetDimNum() != DIM_TWO) {
|
||||
OP_LOGE(opName_, "the dimNum of sorted_indices should be 2, but got %u.", sortedIndicesShape.GetDimNum());
|
||||
return ge::GRAPH_FAILED;
|
||||
}
|
||||
batchSize_ = sortedValueShape.GetDim(DIM_INDEX0);
|
||||
vocabSize_ = sortedValueShape.GetDim(DIM_ONE);
|
||||
if (sortedIndicesShape.GetDim(DIM_INDEX0) != batchSize_ || sortedIndicesShape.GetDim(DIM_ONE) != vocabSize_) {
|
||||
OP_LOGE(opName_, "the shape of sorted_indices should be equal to sorted_value.");
|
||||
return ge::GRAPH_FAILED;
|
||||
}
|
||||
|
||||
auto pShapePtr = tilingcontext->GetOptionalInputShape(P_INPUT_INDEX);
|
||||
OP_CHECK_NULL_WITH_CONTEXT(tilingcontext, pShapePtr);
|
||||
auto pShape = pShapePtr->GetStorageShape();
|
||||
auto pDimNum = pShape.GetDimNum();
|
||||
if (pDimNum != DIM_ONE && pDimNum != 0) {
|
||||
OP_LOGE(opName_, "the dimNum of p should be 1 or 0, but got %u.", pDimNum);
|
||||
return ge::GRAPH_FAILED;
|
||||
}
|
||||
if (pDimNum != 0 && batchSize_ != pShape.GetDim(DIM_INDEX0)) {
|
||||
OP_LOGE(opName_, "p.shape[0] should be equal to logits.shape[0].");
|
||||
return ge::GRAPH_FAILED;
|
||||
}
|
||||
|
||||
auto kShapePtr = tilingcontext->GetOptionalInputShape(K_INPUT_INDEX);
|
||||
OP_CHECK_NULL_WITH_CONTEXT(tilingcontext, kShapePtr);
|
||||
auto kShape = kShapePtr->GetStorageShape();
|
||||
auto kDimNum = kShape.GetDimNum();
|
||||
if (kDimNum != DIM_ONE && kDimNum != 0) {
|
||||
OP_LOGE(opName_, "the dimNum of k should be 1 or 0, but got %u.", kShape.GetDimNum());
|
||||
return ge::GRAPH_FAILED;
|
||||
}
|
||||
if (kDimNum != 0 && batchSize_ != kShape.GetDim(DIM_INDEX0)) {
|
||||
OP_LOGE(opName_, "k.shape[0] should be equal to logits.shape[0].");
|
||||
return ge::GRAPH_FAILED;
|
||||
}
|
||||
if (kDimNum == 0 && pDimNum == 0) {
|
||||
OP_LOGE(opName_, "the dimNum of q and k should be 0 at the same time.");
|
||||
return ge::GRAPH_FAILED;
|
||||
}
|
||||
onlyTopK_ = (kDimNum != 0 && pDimNum == 0) ? ONLY_TOP_K_KEY : 0;
|
||||
onlyTopP_ = (pDimNum != 0 && kDimNum == 0) ? ONLY_TOP_P_KEY : 0;
|
||||
return ge::GRAPH_SUCCESS;
|
||||
}
|
||||
|
||||
ge::graphStatus ApplyTopKTopPCustomTiling::Init() {
|
||||
opName_ = tilingcontext->GetNodeName();
|
||||
OP_LOGD(opName_, "TilingForApplyTopKTopPCustom init.");
|
||||
auto platformInfo = platform_ascendc::PlatformAscendC(tilingcontext->GetPlatformInfo());
|
||||
coreNum_ = platformInfo.GetCoreNumAiv();
|
||||
platformInfo.GetCoreMemSize(platform_ascendc::CoreMemType::UB, platformUbSize_);
|
||||
OP_LOGD(opName_, "platformUbSize: %lu.", platformUbSize_);
|
||||
uint32_t avaliableUb = static_cast<uint32_t>(platformUbSize_) - SYS_RESERVED_UB - SELECT_RESERVED_UB;
|
||||
calUbSize_ = FloorAlign(avaliableUb, BLOCK_BYTES);
|
||||
if (CheckShape() == ge::GRAPH_FAILED) {
|
||||
OP_LOGE(opName_, "check shape failed.");
|
||||
return ge::GRAPH_FAILED;
|
||||
}
|
||||
uint32_t tempValue = 1;
|
||||
while (tempValue < vocabSize_) {
|
||||
tempValue <<= 1;
|
||||
iterateTimes_++;
|
||||
} // ceil(log2(vocabSize_))
|
||||
return ge::GRAPH_SUCCESS;
|
||||
}
|
||||
|
||||
void ApplyTopKTopPCustomTiling::SetTilingKey() {
|
||||
tilingKey_ += onlyTopK_;
|
||||
tilingKey_ += onlyTopP_;
|
||||
tilingcontext->SetTilingKey(tilingKey_);
|
||||
if (tilingKey_ == ONLY_TOP_P_KEY){
|
||||
tilingcontext->SetScheduleMode(BATCH_MODE);
|
||||
}
|
||||
}
|
||||
|
||||
void ApplyTopKTopPCustomTiling::GetUsedCore()
|
||||
{
|
||||
if (coreNum_ > 0) {
|
||||
batchPerCore_ = coreNum_ == uint32_t(0) ? batchSize_ : batchSize_ / coreNum_;
|
||||
tailBatch_ = batchSize_ % coreNum_;
|
||||
usedCoreNum_ = coreNum_;
|
||||
}
|
||||
}
|
||||
|
||||
void ApplyTopKTopPCustomTiling::CalDataPerCore()
|
||||
{
|
||||
uint32_t inputDataTypeByte = DATATYPE_LEN_MAP[tilingcontext->GetInputDesc(SORTED_VALUE_INPUT_INDEX)->GetDataType()];
|
||||
uint32_t dataPerBlock = BLOCK_BYTES / inputDataTypeByte;
|
||||
dataNumInit_ = vocabSize_ < K_VALUE_MAX ? vocabSize_ : K_VALUE_MAX;
|
||||
dataNumInitAligned_ = vocabSize_ < K_VALUE_MAX ? vocabSize_ : K_VALUE_MAX;
|
||||
ubFactorElement_ = vocabSize_ < K_VALUE_MAX ? vocabSize_ : K_VALUE_MAX;
|
||||
ubFactorElementAligned_ = CeilAlign(ubFactorElement_, dataPerBlock);
|
||||
tailUbFactorElement_ = vocabSize_ % ubFactorElement_;
|
||||
tailUbFactorElement_ = tailUbFactorElement_ == uint32_t(0) ? ubFactorElement_ : tailUbFactorElement_;
|
||||
tailUbFactorElementAligned_ = CeilAlign(tailUbFactorElement_, dataPerBlock);
|
||||
|
||||
uint32_t sortedValueBytes = ubFactorElementAligned_ * inputDataTypeByte + K_VALUE_MAX * inputDataTypeByte;
|
||||
uint32_t sortedIndicesBytes = ubFactorElementAligned_ * BYTES_B32 + K_VALUE_MAX * BYTES_B32;
|
||||
uint32_t pBytes = dataPerBlock * inputDataTypeByte;
|
||||
uint32_t kBytes = DATA_PER_BLOCK_B32 * BYTES_B32;
|
||||
uint32_t outTensorBytes = ubFactorElementAligned_ * inputDataTypeByte;
|
||||
|
||||
calUbSize_ = calUbSize_ - sortedValueBytes - sortedIndicesBytes - pBytes - kBytes - outTensorBytes;
|
||||
if (onlyTopP_ > 0) {
|
||||
calUbSize_ = static_cast<uint32_t>(platformUbSize_);
|
||||
}
|
||||
}
|
||||
|
||||
void ApplyTopKTopPCustomTiling::FillTilingData()
|
||||
{
|
||||
tilingData.set_batchSize(batchSize_);
|
||||
tilingData.set_vocabSize(vocabSize_);
|
||||
tilingData.set_batchPerCore(batchPerCore_);
|
||||
tilingData.set_tailBatch(tailBatch_);
|
||||
tilingData.set_blockNum(usedCoreNum_);
|
||||
tilingData.set_dataNumInit(dataNumInit_);
|
||||
tilingData.set_dataNumInitAligned(dataNumInitAligned_);
|
||||
tilingData.set_ubFactorElement(ubFactorElement_);
|
||||
tilingData.set_ubFactorElementAligned(ubFactorElementAligned_);
|
||||
tilingData.set_tailUbFactorElement(tailUbFactorElement_);
|
||||
tilingData.set_tailUbFactorElementAligned(tailUbFactorElementAligned_);
|
||||
tilingData.set_calUbSize(calUbSize_);
|
||||
tilingData.set_iterateTimes(iterateTimes_);
|
||||
}
|
||||
|
||||
void ApplyTopKTopPCustomTiling::PrintTilingData()
|
||||
{
|
||||
OP_LOGD(opName_, "batchSize: %u.", tilingData.get_batchSize());
|
||||
OP_LOGD(opName_, "vocabSize: %u.", tilingData.get_vocabSize());
|
||||
OP_LOGD(opName_, "batchPerCore: %u.", tilingData.get_batchPerCore());
|
||||
OP_LOGD(opName_, "tailBatch: %u.", tilingData.get_tailBatch());
|
||||
OP_LOGD(opName_, "usedCoreNum: %u.", tilingData.get_blockNum());
|
||||
OP_LOGD(opName_, "dataNumInit_: %u.", tilingData.get_dataNumInit());
|
||||
OP_LOGD(opName_, "dataNumInitAligned_: %u.", tilingData.get_dataNumInitAligned());
|
||||
OP_LOGD(opName_, "ubFactorElement: %u.", tilingData.get_ubFactorElement());
|
||||
OP_LOGD(opName_, "ubFactorElementAligned: %u.", tilingData.get_ubFactorElementAligned());
|
||||
OP_LOGD(opName_, "tailUbFactorElement: %u.", tilingData.get_tailUbFactorElement());
|
||||
OP_LOGD(opName_, "tailUbFactorElementAligned: %u.", tilingData.get_tailUbFactorElementAligned());
|
||||
OP_LOGD(opName_, "calUbSize: %u.", tilingData.get_calUbSize());
|
||||
OP_LOGD(opName_, "iterateTimes: %u.", tilingData.get_iterateTimes());
|
||||
}
|
||||
|
||||
ge::graphStatus ApplyTopKTopPCustomTiling::RunKernelTiling()
|
||||
{
|
||||
OP_LOGD(opName_, "TilingForApplyTopKTopPCustom start.");
|
||||
|
||||
SetTilingKey();
|
||||
GetUsedCore();
|
||||
CalDataPerCore();
|
||||
FillTilingData();
|
||||
PrintTilingData();
|
||||
|
||||
OP_LOGD(opName_, "tilingKey: %u.", tilingKey_);
|
||||
uint32_t syncWorkspaceSize = SYS_WORKSPACESIZE;
|
||||
size_t* currentWorkspace = tilingcontext->GetWorkspaceSizes(1);
|
||||
currentWorkspace[0] = onlyTopP_ > 0 ? syncWorkspaceSize + batchSize_ * vocabSize_ * FLOAT_BYTES : syncWorkspaceSize;
|
||||
|
||||
tilingData.SaveToBuffer(tilingcontext->GetRawTilingData()->GetData(),
|
||||
tilingcontext->GetRawTilingData()->GetCapacity());
|
||||
tilingcontext->GetRawTilingData()->SetDataSize(tilingData.GetDataSize());
|
||||
tilingcontext->SetBlockDim(usedCoreNum_);
|
||||
|
||||
OP_LOGD(opName_, "TilingForApplyTopKTopPCustom end.");
|
||||
return ge::GRAPH_SUCCESS;
|
||||
}
|
||||
|
||||
static ge::graphStatus TilingForApplyTopKTopPCustom(gert::TilingContext* context)
|
||||
{
|
||||
ApplyTopKTopPCustomTiling tilingObject(context);
|
||||
auto ret = tilingObject.Init();
|
||||
if (ret != ge::GRAPH_SUCCESS) {
|
||||
OP_LOGE(context->GetNodeName(), "tiling Init failed.");
|
||||
return ge::GRAPH_FAILED;
|
||||
}
|
||||
ret = tilingObject.RunKernelTiling();
|
||||
OP_LOGD(context->GetNodeName(), "TilingForApplyTopKTopPCustom end.");
|
||||
return ret;
|
||||
}
|
||||
|
||||
static ge::graphStatus TilingPrepareForApplyTopKTopPCustom(gert::TilingParseContext* context)
|
||||
{
|
||||
OP_LOGD(context->GetNodeName(), "TilingPrepareForApplyTopKTopPCustom start");
|
||||
auto compileInfo = context->GetCompiledInfo<TilingForApplyTopKTopPCustomCompileInfo>();
|
||||
OP_CHECK_NULL_WITH_CONTEXT(context, compileInfo);
|
||||
auto platformInfo = context->GetPlatformInfo();
|
||||
OP_CHECK_NULL_WITH_CONTEXT(context, platformInfo);
|
||||
auto ascendcPlatform = platform_ascendc::PlatformAscendC(platformInfo);
|
||||
compileInfo->totalCoreNum = ascendcPlatform.GetCoreNumAiv();
|
||||
uint64_t ubSizePlatForm;
|
||||
ascendcPlatform.GetCoreMemSize(platform_ascendc::CoreMemType::UB, ubSizePlatForm);
|
||||
compileInfo->ubSizePlatForm = static_cast<int64_t>(ubSizePlatForm);
|
||||
OP_CHECK_IF(compileInfo->ubSizePlatForm <= 0,
|
||||
OP_LOGE(context->GetNodeName(), "Failed to get ub size"),
|
||||
return ge::GRAPH_FAILED);
|
||||
OP_LOGD(context->GetNodeName(), "ub_size_platform is %lu", compileInfo->ubSizePlatForm);
|
||||
uint64_t totalUbSize = 0;
|
||||
platformInfo->GetLocalMemSize(fe::LocalMemType::UB, totalUbSize);
|
||||
OP_LOGD(context->GetNodeName(), "total ub size is %lu", totalUbSize);
|
||||
OP_LOGD(context->GetNodeName(), "TilingPrepareForApplyTopKTopPCustom end");
|
||||
return ge::GRAPH_SUCCESS;
|
||||
}
|
||||
|
||||
IMPL_OP_OPTILING(ApplyTopKTopPCustom)
|
||||
.Tiling(TilingForApplyTopKTopPCustom)
|
||||
.TilingParse<TilingForApplyTopKTopPCustomCompileInfo>(TilingPrepareForApplyTopKTopPCustom);
|
||||
} // namespace optiling
|
||||
@@ -0,0 +1,51 @@
|
||||
/**
|
||||
* Copyright (c) 2025 Huawei Technologies Co., Ltd.
|
||||
* This program is free software, you can redistribute it and/or modify it under the terms and conditions of
|
||||
* CANN Open Software License Agreement Version 2.0 (the "License").
|
||||
* Please refer to the License for details. You may not use this file except in compliance with the License.
|
||||
* THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED,
|
||||
* INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE.
|
||||
* See LICENSE in the root of the software repository for the full text of the License.
|
||||
*/
|
||||
|
||||
/*!
|
||||
* \file apply_top_k_top_p_custom_tiling.h
|
||||
* \brief
|
||||
* ATTENTION: MAKE SURE 'BEGIN_TILING_DATA_DEF' STAY IN THE SAME LINE (28) USING BLANK LINES.
|
||||
*
|
||||
*
|
||||
*
|
||||
*
|
||||
*
|
||||
*/
|
||||
#ifndef __APPLY_TOP_K_TOP_P_CUSTOM_TILINGDATA_H__
|
||||
#define __APPLY_TOP_K_TOP_P_CUSTOM_TILINGDATA_H__
|
||||
|
||||
#include "register/tilingdata_base.h"
|
||||
|
||||
namespace optiling {
|
||||
|
||||
BEGIN_TILING_DATA_DEF(ApplyTopKTopPCustomTilingData)
|
||||
TILING_DATA_FIELD_DEF(uint32_t, batchSize);
|
||||
TILING_DATA_FIELD_DEF(uint32_t, vocabSize);
|
||||
TILING_DATA_FIELD_DEF(uint32_t, batchPerCore);
|
||||
TILING_DATA_FIELD_DEF(uint32_t, tailBatch);
|
||||
TILING_DATA_FIELD_DEF(uint32_t, blockNum);
|
||||
TILING_DATA_FIELD_DEF(uint32_t, dataNumInit);
|
||||
TILING_DATA_FIELD_DEF(uint32_t, dataNumInitAligned);
|
||||
TILING_DATA_FIELD_DEF(uint32_t, ubFactorElement);
|
||||
TILING_DATA_FIELD_DEF(uint32_t, ubFactorElementAligned);
|
||||
TILING_DATA_FIELD_DEF(uint32_t, tailUbFactorElement);
|
||||
TILING_DATA_FIELD_DEF(uint32_t, tailUbFactorElementAligned);
|
||||
TILING_DATA_FIELD_DEF(uint32_t, calUbSize);
|
||||
TILING_DATA_FIELD_DEF(uint32_t, iterateTimes);
|
||||
END_TILING_DATA_DEF;
|
||||
REGISTER_TILING_DATA_CLASS(ApplyTopKTopPCustom, ApplyTopKTopPCustomTilingData)
|
||||
|
||||
struct TilingForApplyTopKTopPCustomCompileInfo {
|
||||
uint32_t totalCoreNum = 0;
|
||||
uint64_t ubSizePlatForm = 0;
|
||||
};
|
||||
|
||||
} // namespace optiling
|
||||
#endif // __APPLY_TOP_K_TOP_P_CUSTOM_TILINGDATA_H__
|
||||
71
csrc/apply_top_k_top_p_custom/op_host/error_log.h
Normal file
71
csrc/apply_top_k_top_p_custom/op_host/error_log.h
Normal file
@@ -0,0 +1,71 @@
|
||||
#ifndef OPS_BUILT_IN_OP_TILING_ERROR_LOG_H_
|
||||
#define OPS_BUILT_IN_OP_TILING_ERROR_LOG_H_
|
||||
|
||||
#include <string>
|
||||
#include "toolchain/slog.h"
|
||||
|
||||
#define OP_LOGI(opname, ...)
|
||||
#define OP_LOGW(opname, ...) \
|
||||
do { \
|
||||
printf("[WARN][%s] ", (opname), ##__VA_ARGS__); \
|
||||
printf("\n"); \
|
||||
} while (0)
|
||||
|
||||
#define OP_LOGE_WITHOUT_REPORT(opname, ...) \
|
||||
do { \
|
||||
printf("[ERRORx][%s] ", (opname), ##__VA_ARGS__); \
|
||||
printf("\n"); \
|
||||
} while (0)
|
||||
|
||||
#define OP_LOGE(opname, ...) \
|
||||
do { \
|
||||
printf("[ERROR][%s] ", (opname), ##__VA_ARGS__); \
|
||||
printf("\n"); \
|
||||
} while (0)
|
||||
|
||||
#define OP_LOGD(opname, ...)
|
||||
|
||||
namespace optiling {
|
||||
|
||||
#define VECTOR_INNER_ERR_REPORT_TILIING(op_name, err_msg, ...) \
|
||||
do { \
|
||||
OP_LOGE_WITHOUT_REPORT(op_name, err_msg, ##__VA_ARGS__); \
|
||||
} while (0)
|
||||
|
||||
|
||||
#define OP_CHECK_IF(cond, log_func, expr) \
|
||||
do { \
|
||||
if (cond) { \
|
||||
log_func; \
|
||||
expr; \
|
||||
} \
|
||||
} while (0)
|
||||
|
||||
|
||||
|
||||
#define OP_CHECK_NULL_WITH_CONTEXT(context, ptr) \
|
||||
do { \
|
||||
if ((ptr) == nullptr) { \
|
||||
OP_LOGE(context->GetNodeType(), "%s is null", #ptr); \
|
||||
return ge::GRAPH_FAILED; \
|
||||
} \
|
||||
} while (0)
|
||||
|
||||
} // namespace optiling
|
||||
|
||||
template <typename T>
|
||||
T CeilAlign(T a, T b)
|
||||
{
|
||||
return (a + b - 1) / b * b;
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
T CeilDiv(T a, T b)
|
||||
{
|
||||
if (b == 0) {
|
||||
return a;
|
||||
}
|
||||
return (a + b - 1) / b;
|
||||
}
|
||||
|
||||
#endif // OPS_BUILT_IN_OP_TILING_ERROR_LOG_H_
|
||||
26
csrc/apply_top_k_top_p_custom/op_host/sort.h
Normal file
26
csrc/apply_top_k_top_p_custom/op_host/sort.h
Normal file
@@ -0,0 +1,26 @@
|
||||
/**
|
||||
* Copyright (c) 2025 Huawei Technologies Co., Ltd.
|
||||
* This program is free software, you can redistribute it and/or modify it under the terms and conditions of
|
||||
* CANN Open Software License Agreement Version 2.0 (the "License").
|
||||
* Please refer to the License for details. You may not use this file except in compliance with the License.
|
||||
* THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED,
|
||||
* INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE.
|
||||
* See LICENSE in the root of the software repository for the full text of the License.
|
||||
*/
|
||||
|
||||
/*!
|
||||
* \file sort.h
|
||||
* \brief
|
||||
*/
|
||||
#ifndef PTA_NPU_OP_API_INC_LEVEL0_OP_SORT_OP_H_
|
||||
#define PTA_NPU_OP_API_INC_LEVEL0_OP_SORT_OP_H_
|
||||
|
||||
#include "opdev/op_executor.h"
|
||||
#include "opdev/fast_vector.h"
|
||||
|
||||
namespace l0op {
|
||||
const std::tuple<aclTensor*, aclTensor*> Sort(const aclTensor* self, int64_t dim, bool descending, bool stable,
|
||||
op::DataType indicesType, aclOpExecutor* executor);
|
||||
}
|
||||
|
||||
#endif // PTA_NPU_OP_API_INC_LEVEL0_OP_SORT_OP_H_
|
||||
@@ -0,0 +1,42 @@
|
||||
/**
|
||||
* Copyright (c) 2025 Huawei Technologies Co., Ltd.
|
||||
* This program is free software, you can redistribute it and/or modify it under the terms and conditions of
|
||||
* CANN Open Software License Agreement Version 2.0 (the "License").
|
||||
* Please refer to the License for details. You may not use this file except in compliance with the License.
|
||||
* THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED,
|
||||
* INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE.
|
||||
* See LICENSE in the root of the software repository for the full text of the License.
|
||||
*/
|
||||
|
||||
/*!
|
||||
* \file apply_top_k_top_p_custom.cpp
|
||||
* \brief
|
||||
*/
|
||||
|
||||
#include "apply_top_k_top_p_custom.h"
|
||||
#include "apply_top_p_custom.h"
|
||||
using namespace AscendC;
|
||||
using namespace ApplyTopKTopPCustomOp;
|
||||
using namespace ApplyTopPCustomOp;
|
||||
|
||||
extern "C" __global__ __aicore__ void apply_top_k_top_p_custom(GM_ADDR sorted_value, GM_ADDR sorted_indices,
|
||||
GM_ADDR p, GM_ADDR k, GM_ADDR out, GM_ADDR workSpace, GM_ADDR tiling) {
|
||||
TPipe pipe;
|
||||
GET_TILING_DATA(tilingData, tiling);
|
||||
if (TILING_KEY_IS(0)) {
|
||||
ApplyTopKTopPCustomOp::ApplyTopKTopPCustom<DTYPE_OUT, float, DTYPE_OUT> op;
|
||||
op.InitTilingData(tilingData, sorted_value, sorted_indices, p, k, out);
|
||||
op.InitBuffer(&pipe);
|
||||
op.Process();
|
||||
} else if (TILING_KEY_IS(1)) {
|
||||
ApplyTopKTopPCustomOp::ApplyTopKTopPCustom<DTYPE_OUT, float, DTYPE_OUT> op;
|
||||
op.InitTilingData(tilingData, sorted_value, sorted_indices, p, k, out);
|
||||
op.InitBuffer(&pipe);
|
||||
op.ProcessTopK();
|
||||
} else if (TILING_KEY_IS(2)) {
|
||||
ApplyTopPCustomOp::ApplyTopPCustom<DTYPE_OUT, float, DTYPE_OUT> op;
|
||||
op.InitTilingData(tilingData, sorted_value, sorted_indices, p, k, out, workSpace);
|
||||
op.InitBuffer(&pipe);
|
||||
op.ProcessTopP();
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,719 @@
|
||||
/**
|
||||
* Copyright (c) 2025 Huawei Technologies Co., Ltd.
|
||||
* This program is free software, you can redistribute it and/or modify it under the terms and conditions of
|
||||
* CANN Open Software License Agreement Version 2.0 (the "License").
|
||||
* Please refer to the License for details. You may not use this file except in compliance with the License.
|
||||
* THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED,
|
||||
* INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE.
|
||||
* See LICENSE in the root of the software repository for the full text of the License.
|
||||
*/
|
||||
|
||||
/*!
|
||||
* \file apply_top_k_top_p_custom.h
|
||||
* \brief
|
||||
*/
|
||||
#ifndef APPLY_TOP_K_TOP_P_CUSTOM_H_KERNEL
|
||||
#define APPLY_TOP_K_TOP_P_CUSTOM_H_KERNEL
|
||||
|
||||
#include "kernel_operator.h"
|
||||
|
||||
using namespace AscendC;
|
||||
namespace ApplyTopKTopPCustomOp {
|
||||
constexpr uint32_t BUFFER_NUM = 1;
|
||||
constexpr uint16_t FLOAT16_NEG_INF = 0xFC00; // -inf 64512
|
||||
constexpr uint16_t BF16_NEG_INF = 0xFF80; // -inf 65408
|
||||
constexpr int32_t FLOAT32_NEG_INF = 0xFF800000; // -inf -2139095040
|
||||
|
||||
constexpr uint32_t BLOCK_BYTES = 32;
|
||||
constexpr uint32_t DATA_PER_BLOCK_B32 = 8;
|
||||
constexpr uint32_t DATA_PER_REPEAT_B32 = 64;
|
||||
constexpr uint32_t K_MAX = 1024;
|
||||
constexpr uint64_t MASK_64 = 64;
|
||||
constexpr CumSumConfig CUMSUM_CONFIG{true, true, false};
|
||||
|
||||
template <typename inputT, typename calT, typename outputT>
|
||||
class ApplyTopKTopPCustom {
|
||||
public:
|
||||
__aicore__ inline ApplyTopKTopPCustom(){};
|
||||
__aicore__ inline void InitTilingData(
|
||||
const ApplyTopKTopPCustomTilingData &__restrict tilingData, GM_ADDR sorted_value, GM_ADDR sorted_indices,
|
||||
GM_ADDR p, GM_ADDR k, GM_ADDR out);
|
||||
__aicore__ inline void InitBuffer(TPipe *inputPipe);
|
||||
__aicore__ inline void Process();
|
||||
__aicore__ inline void ProcessTopK();
|
||||
private:
|
||||
__aicore__ inline void InitCopyIn(uint32_t loopBatch, int64_t currentGmIdx);
|
||||
__aicore__ inline void InitProcess(uint32_t loopBatch);
|
||||
__aicore__ inline void ProcessKLtKMax(uint32_t loopBatch);
|
||||
__aicore__ inline void ScatterCumtomImpl(uint32_t loopBatch, uint32_t loopProbNum, uint32_t offset);
|
||||
__aicore__ inline void ProcessRemain(uint32_t loopBatch);
|
||||
__aicore__ inline void GetKthResult(uint32_t loopBatch, uint32_t offset, uint8_t repeatTimes);
|
||||
__aicore__ inline void GetFirstKLoop(uint32_t loopBatch, int32_t &firstKLoop);
|
||||
__aicore__ inline void ScatterFromFirstKLoop(uint32_t loopBatch, int32_t firstKLoop, float &cumsumData);
|
||||
__aicore__ inline void ReduceSumWithAddsAndExpImpl(uint32_t offset, uint32_t loopDataNum);
|
||||
__aicore__ inline void CumSumWithAddsAndExpImpl(
|
||||
uint32_t offset, uint32_t loopDataNum, uint32_t cumsumInner, float cumsumData);
|
||||
// topk func
|
||||
__aicore__ inline void InitProcessTopK(uint32_t loopBatch);
|
||||
__aicore__ inline void ProcessKLtKMaxTopK(uint32_t loopBatch);
|
||||
__aicore__ inline void ProcessRemainTopK(uint32_t loopBatch);
|
||||
__aicore__ inline void GetFirstKLoopTopK(uint32_t loopBatch, int32_t &firstKLoop);
|
||||
__aicore__ inline void ScatterFromFirstKLoopTopK(uint32_t loopBatch, int32_t firstKLoop);
|
||||
__aicore__ inline void ScatterCumtomImplTopK(uint32_t loopBatch, uint32_t loopProbNum, uint32_t offset);
|
||||
__aicore__ inline void SToMTE3Sync() {
|
||||
event_t eventIDSToMTE3 = static_cast<event_t>(GetTPipePtr()->FetchEventID(HardEvent::S_MTE3));
|
||||
SetFlag<HardEvent::S_MTE3>(eventIDSToMTE3);
|
||||
WaitFlag<HardEvent::S_MTE3>(eventIDSToMTE3);
|
||||
}
|
||||
__aicore__ inline void MTE3ToSSync() {
|
||||
event_t eventIDMTE3ToS = static_cast<event_t>(GetTPipePtr()->FetchEventID(HardEvent::MTE3_S));
|
||||
SetFlag<HardEvent::MTE3_S>(eventIDMTE3ToS);
|
||||
WaitFlag<HardEvent::MTE3_S>(eventIDMTE3ToS);
|
||||
}
|
||||
__aicore__ inline void VToSSync() {
|
||||
event_t eventIdVToS = static_cast<event_t>(GetTPipePtr()->FetchEventID(HardEvent::V_S));
|
||||
SetFlag<HardEvent::V_S>(eventIdVToS);
|
||||
WaitFlag<HardEvent::V_S>(eventIdVToS);
|
||||
}
|
||||
__aicore__ inline void MTE2ToVSync() {
|
||||
event_t eventIdMte2ToV = static_cast<event_t>(GetTPipePtr()->FetchEventID(HardEvent::MTE2_V));
|
||||
SetFlag<HardEvent::MTE2_V>(eventIdMte2ToV);
|
||||
WaitFlag<HardEvent::MTE2_V>(eventIdMte2ToV);
|
||||
}
|
||||
__aicore__ inline void MTE2ToSSync() {
|
||||
event_t eventIdMte2ToS = static_cast<event_t>(GetTPipePtr()->FetchEventID(HardEvent::MTE2_S));
|
||||
SetFlag<HardEvent::MTE2_S>(eventIdMte2ToS);
|
||||
WaitFlag<HardEvent::MTE2_S>(eventIdMte2ToS);
|
||||
}
|
||||
__aicore__ inline void MTE3ToVSync() {
|
||||
event_t eventIdMte3ToV = static_cast<event_t>(GetTPipePtr()->FetchEventID(HardEvent::MTE3_V));
|
||||
SetFlag<HardEvent::MTE3_V>(eventIdMte3ToV);
|
||||
WaitFlag<HardEvent::MTE3_V>(eventIdMte3ToV);
|
||||
}
|
||||
__aicore__ inline void SToMTE2Sync()
|
||||
{
|
||||
event_t eventIDSToMTE2 = static_cast<event_t>(GetTPipePtr()->FetchEventID(HardEvent::S_MTE2));
|
||||
SetFlag<HardEvent::S_MTE2>(eventIDSToMTE2);
|
||||
WaitFlag<HardEvent::S_MTE2>(eventIDSToMTE2);
|
||||
}
|
||||
__aicore__ inline void VToMTE3Sync() {
|
||||
event_t eventIDVToMTE3 = static_cast<event_t>(GetTPipePtr()->FetchEventID(HardEvent::V_MTE3));
|
||||
SetFlag<HardEvent::V_MTE3>(eventIDVToMTE3);
|
||||
WaitFlag<HardEvent::V_MTE3>(eventIDVToMTE3);
|
||||
}
|
||||
private:
|
||||
TPipe *pipe_;
|
||||
// create queues for input, in this case depth is equal to buffer num
|
||||
TQue<QuePosition::VECIN, BUFFER_NUM> sortedValueInQueue_;
|
||||
TQue<QuePosition::VECIN, BUFFER_NUM> sortedIndicesInQueue_;
|
||||
TQue<QuePosition::VECIN, BUFFER_NUM> pInQueue_;
|
||||
TQue<QuePosition::VECIN, BUFFER_NUM> kInQueue_;
|
||||
TQue<QuePosition::VECOUT, BUFFER_NUM> outQueue_;
|
||||
TBuf<TPosition::VECCALC> calBuf_;
|
||||
|
||||
// tilingData
|
||||
uint32_t batchSize_ = 0;
|
||||
uint32_t vocabSize_ = 0;
|
||||
uint32_t batchPerCore_ = 0;
|
||||
uint32_t tailBatch_ = 0;
|
||||
uint32_t blockNum_ = 0;
|
||||
uint32_t dataNumInit_ = 0;
|
||||
uint32_t dataNumInitAligned_ = 0;
|
||||
uint32_t ubFactorElement_ = 0;
|
||||
uint32_t ubFactorElementAligned_ = 0;
|
||||
uint32_t tailUbFactorElement_ = 0;
|
||||
uint32_t tailUbFactorElementAligned_ = 0;
|
||||
uint32_t calUbSize_ = 0;
|
||||
|
||||
uint32_t blockIdx_ = 0;
|
||||
uint32_t loopBatch_ = 0;
|
||||
uint32_t batchOffset_ = 0;
|
||||
uint32_t bufOffsetLoop = 0;
|
||||
uint32_t loopInner_ = 0;
|
||||
uint32_t loopInnerOnlyP_ = 0;
|
||||
int64_t baseGmIdx_ = 0;
|
||||
|
||||
GlobalTensor<inputT> mGmSortedValue_;
|
||||
GlobalTensor<int32_t> mGmSortedIndices_;
|
||||
GlobalTensor<inputT> mGmP_;
|
||||
GlobalTensor<int32_t> mGmK_;
|
||||
GlobalTensor<outputT> mGmOut_;
|
||||
|
||||
LocalTensor<int32_t> kLocal;
|
||||
LocalTensor<inputT> pLocal;
|
||||
LocalTensor<outputT> outTensor;
|
||||
LocalTensor<inputT> sortedValueLocal;
|
||||
LocalTensor<int32_t> sortedIndicesLocal;
|
||||
|
||||
LocalTensor<float> sortedValueLocalFp32;
|
||||
LocalTensor<float> negInfLocal;
|
||||
|
||||
LocalTensor<float> calLocalFp32;
|
||||
LocalTensor<float> kthValueLocal;
|
||||
LocalTensor<float> tmpLocal;
|
||||
LocalTensor<float> cumSumRes;
|
||||
LocalTensor<float> cumSumTmp;
|
||||
LocalTensor<float> reduceLocal;
|
||||
LocalTensor<float> softMaxRes;
|
||||
LocalTensor<inputT> scatterTensor;
|
||||
LocalTensor<uint8_t> sharedTmpBuffer;
|
||||
|
||||
float kthValue = 0;
|
||||
float pValue = 0;
|
||||
float maxValue = 0;
|
||||
float reduceSumValueInvert = 0;
|
||||
float reduceSumValue = 0;
|
||||
inputT kthTopKValue = 0;
|
||||
BinaryRepeatParams repeatParams = {1, 0, 1, 8, 0, 8};
|
||||
DataCopyExtParams scatterCopyParams{1, (uint32_t)(sizeof(outputT)), 0, 0, 0};
|
||||
};
|
||||
|
||||
template <typename inputT, typename calT, typename outputT>
|
||||
__aicore__ inline void ApplyTopKTopPCustom<inputT, calT, outputT>::InitTilingData(
|
||||
const ApplyTopKTopPCustomTilingData &__restrict tilingData, GM_ADDR sorted_value, GM_ADDR sorted_indices,
|
||||
GM_ADDR p, GM_ADDR k, GM_ADDR out) {
|
||||
batchSize_ = tilingData.batchSize;
|
||||
vocabSize_ = tilingData.vocabSize;
|
||||
batchPerCore_ = tilingData.batchPerCore;
|
||||
tailBatch_ = tilingData.tailBatch;
|
||||
blockNum_ = tilingData.blockNum;
|
||||
dataNumInit_ = tilingData.dataNumInit;
|
||||
dataNumInitAligned_ = AscendC::AlignUp(dataNumInit_, DATA_PER_BLOCK_B32);
|
||||
ubFactorElement_ = tilingData.ubFactorElement;
|
||||
ubFactorElementAligned_ = tilingData.ubFactorElementAligned;
|
||||
tailUbFactorElement_ = tilingData.tailUbFactorElement;
|
||||
tailUbFactorElementAligned_ = tilingData.tailUbFactorElementAligned;
|
||||
calUbSize_ = tilingData.calUbSize;
|
||||
blockIdx_ = GetBlockIdx();
|
||||
|
||||
if (blockIdx_ < tailBatch_)
|
||||
{
|
||||
loopBatch_ = batchPerCore_ + 1;
|
||||
batchOffset_ = blockIdx_ * loopBatch_;
|
||||
}
|
||||
else
|
||||
{
|
||||
loopBatch_ = batchPerCore_;
|
||||
batchOffset_ = blockIdx_ * batchPerCore_ + tailBatch_;
|
||||
}
|
||||
loopInner_ = (vocabSize_ - dataNumInit_ + ubFactorElementAligned_ - 1) / ubFactorElementAligned_;
|
||||
loopInnerOnlyP_ = (vocabSize_ + ubFactorElementAligned_ - 1) / ubFactorElementAligned_;
|
||||
mGmSortedValue_.SetGlobalBuffer(reinterpret_cast<__gm__ inputT *>(sorted_value));
|
||||
mGmSortedIndices_.SetGlobalBuffer(reinterpret_cast<__gm__ int32_t *>(sorted_indices));
|
||||
mGmP_.SetGlobalBuffer(reinterpret_cast<__gm__ inputT *>(p));
|
||||
mGmK_.SetGlobalBuffer(reinterpret_cast<__gm__ int32_t *>(k));
|
||||
mGmOut_.SetGlobalBuffer(reinterpret_cast<__gm__ outputT *>(out));
|
||||
}
|
||||
|
||||
// init used buffer
|
||||
template <typename inputT, typename calT, typename outputT>
|
||||
__aicore__ inline void ApplyTopKTopPCustom<inputT, calT, outputT>::InitBuffer(TPipe *inputPipe) {
|
||||
pipe_ = inputPipe;
|
||||
pipe_->InitBuffer(sortedValueInQueue_, BUFFER_NUM, sizeof(inputT) * (ubFactorElementAligned_ + K_MAX));
|
||||
pipe_->InitBuffer(sortedIndicesInQueue_, BUFFER_NUM, sizeof(int32_t) * (ubFactorElementAligned_ + K_MAX));
|
||||
pipe_->InitBuffer(pInQueue_, BUFFER_NUM, BLOCK_BYTES);
|
||||
pipe_->InitBuffer(kInQueue_, BUFFER_NUM, BLOCK_BYTES);
|
||||
pipe_->InitBuffer(outQueue_, BUFFER_NUM, sizeof(outputT) * ubFactorElementAligned_);
|
||||
pipe_->InitBuffer(calBuf_, calUbSize_);
|
||||
if constexpr (!IsSameType<inputT, float>::value) {
|
||||
sortedValueLocalFp32 = calBuf_.GetWithOffset<float>(ubFactorElementAligned_ + K_MAX, bufOffsetLoop);
|
||||
bufOffsetLoop = bufOffsetLoop + (ubFactorElementAligned_ + K_MAX) * sizeof(float);
|
||||
}
|
||||
kthValueLocal = calBuf_.GetWithOffset<float>(DATA_PER_BLOCK_B32, bufOffsetLoop);
|
||||
bufOffsetLoop = bufOffsetLoop + BLOCK_BYTES;
|
||||
|
||||
negInfLocal = calBuf_.GetWithOffset<float>(DATA_PER_BLOCK_B32, bufOffsetLoop);
|
||||
bufOffsetLoop = bufOffsetLoop + BLOCK_BYTES;
|
||||
|
||||
tmpLocal = calBuf_.GetWithOffset<float>(ubFactorElementAligned_, bufOffsetLoop);
|
||||
bufOffsetLoop = bufOffsetLoop + ubFactorElementAligned_ * sizeof(float);
|
||||
cumSumRes = calBuf_.GetWithOffset<float>(ubFactorElementAligned_, bufOffsetLoop);
|
||||
bufOffsetLoop = bufOffsetLoop + ubFactorElementAligned_ * sizeof(float);
|
||||
cumSumTmp = calBuf_.GetWithOffset<float>(ubFactorElementAligned_, bufOffsetLoop);
|
||||
bufOffsetLoop = bufOffsetLoop + ubFactorElementAligned_ * sizeof(float);
|
||||
reduceLocal = calBuf_.GetWithOffset<float>(ubFactorElementAligned_ * BLOCK_BYTES, bufOffsetLoop);
|
||||
bufOffsetLoop = bufOffsetLoop + ubFactorElementAligned_ * BLOCK_BYTES * sizeof(float);
|
||||
|
||||
softMaxRes = tmpLocal.template ReinterpretCast<float>();
|
||||
scatterTensor = reduceLocal.template ReinterpretCast<inputT>();
|
||||
sharedTmpBuffer = reduceLocal.template ReinterpretCast<uint8_t>();
|
||||
}
|
||||
|
||||
template <typename inputT, typename calT, typename outputT>
|
||||
__aicore__ inline void ApplyTopKTopPCustom<inputT, calT, outputT>::Process() {
|
||||
kLocal = kInQueue_.AllocTensor<int32_t>();
|
||||
pLocal = pInQueue_.AllocTensor<inputT>();
|
||||
outTensor = outQueue_.AllocTensor<outputT>();
|
||||
sortedValueLocal = sortedValueInQueue_.AllocTensor<inputT>();
|
||||
sortedIndicesLocal = sortedIndicesInQueue_.AllocTensor<int32_t>();
|
||||
Duplicate(negInfLocal.template ReinterpretCast<int32_t>(), FLOAT32_NEG_INF, DATA_PER_BLOCK_B32);
|
||||
if constexpr (IsSameType<inputT, float>::value) {
|
||||
calLocalFp32 = sortedValueLocal;
|
||||
Duplicate(outTensor.template ReinterpretCast<int32_t>(), FLOAT32_NEG_INF, ubFactorElementAligned_);
|
||||
} else if constexpr (IsSameType<inputT, half>::value) {
|
||||
calLocalFp32 = sortedValueLocalFp32;
|
||||
Duplicate(outTensor.template ReinterpretCast<uint16_t>(), FLOAT16_NEG_INF, ubFactorElementAligned_);
|
||||
} else {
|
||||
calLocalFp32 = sortedValueLocalFp32;
|
||||
Duplicate(outTensor.template ReinterpretCast<uint16_t>(), BF16_NEG_INF, ubFactorElementAligned_);
|
||||
}
|
||||
VToMTE3Sync();
|
||||
for (uint32_t loopBatch = 0; loopBatch < loopBatch_; loopBatch++) {
|
||||
baseGmIdx_ = batchOffset_ * vocabSize_ + loopBatch * vocabSize_;
|
||||
InitProcess(loopBatch);
|
||||
if (calLocalFp32.GetValue(ubFactorElementAligned_) < kthValue) {
|
||||
ProcessKLtKMax(loopBatch);
|
||||
} else {
|
||||
ProcessRemain(loopBatch);
|
||||
}
|
||||
}
|
||||
kInQueue_.FreeTensor(kLocal);
|
||||
pInQueue_.FreeTensor(pLocal);
|
||||
sortedValueInQueue_.FreeTensor(sortedValueLocal);
|
||||
sortedIndicesInQueue_.FreeTensor(sortedIndicesLocal);
|
||||
outQueue_.FreeTensor(outTensor);
|
||||
}
|
||||
|
||||
template <typename inputT, typename calT, typename outputT>
|
||||
__aicore__ inline void ApplyTopKTopPCustom<inputT, calT, outputT>::InitCopyIn(uint32_t loopBatch,
|
||||
int64_t currentGmIdx) {
|
||||
DataCopyPad(mGmOut_[currentGmIdx], outTensor, {1, (uint32_t)(dataNumInit_ * sizeof(outputT)), 0, 0, 0});
|
||||
DataCopyPad(sortedValueLocal[ubFactorElementAligned_], mGmSortedValue_[currentGmIdx],
|
||||
{1, static_cast<uint32_t>(dataNumInit_ * sizeof(inputT)), 0, 0, 0},
|
||||
{false, 0, 0, 0});
|
||||
DataCopyPad(pLocal, mGmP_[batchOffset_ + loopBatch], {1, static_cast<uint32_t>(sizeof(inputT)), 0, 0, 0},
|
||||
{false, 0, 0, 0});
|
||||
if constexpr (!IsSameType<inputT, float>::value) {
|
||||
MTE2ToVSync();
|
||||
Cast(sortedValueLocalFp32[ubFactorElementAligned_], sortedValueLocal[ubFactorElementAligned_],
|
||||
RoundMode::CAST_NONE, dataNumInit_);
|
||||
Cast(tmpLocal, pLocal, RoundMode::CAST_NONE, DATA_PER_BLOCK_B32);
|
||||
}
|
||||
DataCopyPad(sortedIndicesLocal[ubFactorElementAligned_], mGmSortedIndices_[currentGmIdx],
|
||||
{1, static_cast<uint32_t>(dataNumInit_ * sizeof(int32_t)), 0, 0, 0},
|
||||
{false, 0, 0, 0});
|
||||
DataCopyPad(kLocal, mGmK_[batchOffset_ + loopBatch], {1, static_cast<uint32_t>(sizeof(int32_t)), 0, 0, 0},
|
||||
{false, 0, 0, 0});
|
||||
}
|
||||
|
||||
template <typename inputT, typename calT, typename outputT>
|
||||
__aicore__ inline void ApplyTopKTopPCustom<inputT, calT, outputT>::GetKthResult(uint32_t loopBatch,
|
||||
uint32_t offset, uint8_t repeatTimes){
|
||||
Compare(tmpLocal.template ReinterpretCast<uint8_t>(), kthValueLocal, calLocalFp32[offset],
|
||||
CMPMODE::GT, MASK_64, repeatTimes, repeatParams);
|
||||
PipeBarrier<PIPE_V>();
|
||||
Select(calLocalFp32[offset], tmpLocal.template ReinterpretCast<uint8_t>(),
|
||||
negInfLocal, calLocalFp32[offset], SELMODE::VSEL_TENSOR_TENSOR_MODE, MASK_64,
|
||||
repeatTimes, repeatParams);
|
||||
}
|
||||
|
||||
template <typename inputT, typename calT, typename outputT>
|
||||
__aicore__ inline void ApplyTopKTopPCustom<inputT, calT, outputT>::ReduceSumWithAddsAndExpImpl(uint32_t offset,
|
||||
uint32_t loopDataNum) {
|
||||
Adds(softMaxRes, calLocalFp32[offset], maxValue, loopDataNum);
|
||||
PipeBarrier<PIPE_V>();
|
||||
Exp(softMaxRes, softMaxRes, loopDataNum);
|
||||
PipeBarrier<PIPE_V>();
|
||||
ReduceSum(reduceLocal, softMaxRes, reduceLocal, loopDataNum);
|
||||
}
|
||||
|
||||
template <typename inputT, typename calT, typename outputT>
|
||||
__aicore__ inline void ApplyTopKTopPCustom<inputT, calT, outputT>::InitProcess(uint32_t loopBatch) {
|
||||
int64_t initGmIdx = baseGmIdx_ + vocabSize_ - dataNumInit_;
|
||||
InitCopyIn(loopBatch, initGmIdx);
|
||||
MTE2ToSSync();
|
||||
|
||||
int32_t kValue = kLocal.GetValue(0);
|
||||
if constexpr (IsSameType<inputT, float>::value) {
|
||||
pValue = float(1.0) - pLocal.GetValue(0);
|
||||
} else {
|
||||
pValue = float(1.0) - tmpLocal.GetValue(0);
|
||||
}
|
||||
maxValue = -calLocalFp32[ubFactorElementAligned_].GetValue(dataNumInit_ - 1);
|
||||
if constexpr (IsSameType<inputT, float>::value) {
|
||||
kthValue = mGmSortedValue_[baseGmIdx_ + vocabSize_ - kValue].GetValue(0);
|
||||
} else if constexpr (IsSameType<inputT, half>::value) {
|
||||
kthValue = static_cast<float>(mGmSortedValue_[baseGmIdx_ + vocabSize_ - kValue].GetValue(0));
|
||||
} else {
|
||||
kthValue = ToFloat(mGmSortedValue_[baseGmIdx_ + vocabSize_ - kValue].GetValue(0));
|
||||
}
|
||||
|
||||
Duplicate(kthValueLocal, kthValue, 8);
|
||||
PipeBarrier<PIPE_V>();
|
||||
|
||||
uint8_t repeatTimes = (dataNumInit_ + DATA_PER_REPEAT_B32 - 1) / DATA_PER_REPEAT_B32;
|
||||
GetKthResult(loopBatch, ubFactorElementAligned_, repeatTimes);
|
||||
PipeBarrier<PIPE_V>();
|
||||
DataCopyExtParams copyParams{1, (uint32_t)(ubFactorElementAligned_ * sizeof(outputT)), 0, 0, 0};
|
||||
|
||||
ReduceSumWithAddsAndExpImpl(ubFactorElementAligned_, dataNumInit_);
|
||||
VToSSync();
|
||||
reduceSumValue = reduceLocal.GetValue(0);
|
||||
reduceSumValueInvert = 1 / reduceSumValue;
|
||||
}
|
||||
|
||||
template <typename inputT, typename calT, typename outputT>
|
||||
__aicore__ inline void ApplyTopKTopPCustom<inputT, calT, outputT>::ProcessKLtKMax(uint32_t loopBatch) {
|
||||
DataCopyExtParams copyParams{1, (uint32_t)(ubFactorElementAligned_ * sizeof(outputT)), 0, 0, 0};
|
||||
for (int32_t loopInner = 0; loopInner < loopInner_; loopInner++) {
|
||||
int64_t currentGmIdxInner = baseGmIdx_ + loopInner * ubFactorElementAligned_;
|
||||
if (loopInner == loopInner_ - 1) {
|
||||
DataCopyPad(mGmOut_[currentGmIdxInner], outTensor,
|
||||
{1, (uint32_t)(tailUbFactorElement_ * sizeof(outputT)), 0, 0, 0});
|
||||
} else {
|
||||
DataCopyPad(mGmOut_[currentGmIdxInner], outTensor, copyParams);
|
||||
}
|
||||
}
|
||||
Muls(softMaxRes, softMaxRes, reduceSumValueInvert, dataNumInit_);
|
||||
PipeBarrier<PIPE_V>();
|
||||
const CumSumInfo cumSumInfo{1, dataNumInitAligned_};
|
||||
CumSum<float, CUMSUM_CONFIG>(cumSumRes, cumSumTmp, softMaxRes, sharedTmpBuffer, cumSumInfo);
|
||||
VToSSync();
|
||||
int32_t loopProb = dataNumInit_ - 1;
|
||||
scatterTensor.SetValue(0, sortedValueLocal[ubFactorElementAligned_].GetValue(loopProb));
|
||||
SToMTE3Sync();
|
||||
int32_t gmIndex = sortedIndicesLocal[ubFactorElementAligned_].GetValue(loopProb);
|
||||
PipeBarrier<PIPE_MTE3>();
|
||||
DataCopyPad(mGmOut_[baseGmIdx_ + gmIndex], scatterTensor.template ReinterpretCast<outputT>(), scatterCopyParams);
|
||||
MTE3ToSSync();
|
||||
loopProb = loopProb - 1;
|
||||
for (; loopProb >= 0; loopProb--) {
|
||||
float cumsumData = cumSumRes.GetValue(loopProb);
|
||||
if (cumsumData <= pValue) {
|
||||
break;
|
||||
}
|
||||
scatterTensor.SetValue(0, sortedValueLocal[ubFactorElementAligned_].GetValue(loopProb));
|
||||
gmIndex = sortedIndicesLocal[ubFactorElementAligned_].GetValue(loopProb);
|
||||
SToMTE3Sync();
|
||||
DataCopyPad(mGmOut_[baseGmIdx_ + gmIndex],
|
||||
scatterTensor.template ReinterpretCast<outputT>(), scatterCopyParams);
|
||||
MTE3ToSSync();
|
||||
}
|
||||
}
|
||||
|
||||
template <typename inputT, typename calT, typename outputT>
|
||||
__aicore__ inline void ApplyTopKTopPCustom<inputT, calT, outputT>::ScatterCumtomImpl(uint32_t loopBatch,
|
||||
uint32_t loopProbNum, uint32_t offset) {
|
||||
for (int32_t loopProb = 0; loopProb < static_cast<int32_t>(loopProbNum); loopProb++) {
|
||||
float cumsumDataTmp = cumSumRes.GetValue(loopProb);
|
||||
if (cumsumDataTmp <= pValue) {
|
||||
continue;
|
||||
}
|
||||
scatterTensor.SetValue(0, sortedValueLocal[offset].GetValue(loopProb));
|
||||
int32_t gmIndex = sortedIndicesLocal[offset].GetValue(loopProb);
|
||||
SToMTE3Sync();
|
||||
DataCopyPad(mGmOut_[baseGmIdx_ + gmIndex], scatterTensor.template ReinterpretCast<outputT>(),
|
||||
{1, (uint32_t)(1 * sizeof(outputT)), 0, 0, 0});
|
||||
MTE3ToSSync();
|
||||
}
|
||||
}
|
||||
|
||||
template <typename inputT, typename calT, typename outputT>
|
||||
__aicore__ inline void ApplyTopKTopPCustom<inputT, calT, outputT>::GetFirstKLoop(uint32_t loopBatch,
|
||||
int32_t &firstKLoop) {
|
||||
uint8_t repeatTimes = (dataNumInit_ + DATA_PER_REPEAT_B32 - 1) / DATA_PER_REPEAT_B32;
|
||||
uint32_t loopDataNum = ubFactorElementAligned_;
|
||||
for (int32_t loopInner = 0; loopInner < loopInner_; loopInner++) {
|
||||
int64_t currentGmIdx = baseGmIdx_ + loopInner * ubFactorElementAligned_;
|
||||
if (loopInner == (loopInner_ - 1)) {
|
||||
repeatTimes = ((tailUbFactorElement_) + DATA_PER_REPEAT_B32 - 1) / DATA_PER_REPEAT_B32;
|
||||
loopDataNum = tailUbFactorElement_;
|
||||
}
|
||||
DataCopyPad(mGmOut_[currentGmIdx], outTensor, {1, (uint32_t)(loopDataNum * sizeof(outputT)), 0, 0, 0});
|
||||
DataCopyPad(sortedValueLocal.template ReinterpretCast<inputT>(), mGmSortedValue_[currentGmIdx],
|
||||
{1, static_cast<uint32_t>(loopDataNum * sizeof(inputT)), 0, 0, 0},
|
||||
{false, 0, 0, 0});
|
||||
if constexpr (!IsSameType<inputT, float>::value) {
|
||||
MTE2ToVSync();
|
||||
Cast(sortedValueLocalFp32, sortedValueLocal, RoundMode::CAST_NONE, loopDataNum);
|
||||
VToSSync();
|
||||
} else {
|
||||
MTE2ToSSync();
|
||||
}
|
||||
if (calLocalFp32.GetValue(loopDataNum - 1) < kthValue) {
|
||||
firstKLoop += 1;
|
||||
continue;
|
||||
}
|
||||
|
||||
GetKthResult(loopBatch, 0, repeatTimes);
|
||||
PipeBarrier<PIPE_V>();
|
||||
|
||||
ReduceSumWithAddsAndExpImpl(0, loopDataNum);
|
||||
VToSSync();
|
||||
reduceSumValue += reduceLocal.GetValue(0);
|
||||
}
|
||||
}
|
||||
|
||||
template <typename inputT, typename calT, typename outputT>
|
||||
__aicore__ inline void ApplyTopKTopPCustom<inputT, calT, outputT>::CumSumWithAddsAndExpImpl(uint32_t offset,
|
||||
uint32_t loopDataNum, uint32_t cumsumInner, float cumsumData) {
|
||||
Adds(softMaxRes, calLocalFp32[offset], maxValue, loopDataNum);
|
||||
PipeBarrier<PIPE_V>();
|
||||
Exp(softMaxRes, softMaxRes, loopDataNum);
|
||||
PipeBarrier<PIPE_V>();
|
||||
Muls(softMaxRes, softMaxRes, reduceSumValueInvert, loopDataNum);
|
||||
PipeBarrier<PIPE_V>();
|
||||
const CumSumInfo cumSumInfo{1, cumsumInner};
|
||||
CumSum<float, CUMSUM_CONFIG>(cumSumRes, cumSumTmp, softMaxRes, sharedTmpBuffer, cumSumInfo);
|
||||
PipeBarrier<PIPE_V>();
|
||||
Adds(cumSumRes, cumSumRes, cumsumData, loopDataNum);
|
||||
}
|
||||
|
||||
template <typename inputT, typename calT, typename outputT>
|
||||
__aicore__ inline void ApplyTopKTopPCustom<inputT, calT, outputT>::ProcessRemain(uint32_t loopBatch) {
|
||||
int32_t firstKLoop = 0;
|
||||
GetFirstKLoop(loopBatch, firstKLoop);
|
||||
reduceSumValueInvert = 1 / reduceSumValue;
|
||||
float cumsumData = 0;
|
||||
ScatterFromFirstKLoop(loopBatch, firstKLoop, cumsumData);
|
||||
uint32_t loopProb = dataNumInit_ - 1;
|
||||
scatterTensor.SetValue(0, sortedValueLocal[ubFactorElementAligned_].GetValue(loopProb));
|
||||
int32_t gmIndex = sortedIndicesLocal[ubFactorElementAligned_].GetValue(loopProb);
|
||||
SToMTE3Sync();
|
||||
DataCopyPad(mGmOut_[baseGmIdx_ + gmIndex],
|
||||
scatterTensor.template ReinterpretCast<outputT>(), scatterCopyParams);
|
||||
MTE3ToVSync();
|
||||
CumSumWithAddsAndExpImpl(ubFactorElementAligned_, dataNumInit_, dataNumInitAligned_, cumsumData);
|
||||
VToSSync();
|
||||
ScatterCumtomImpl(loopBatch, dataNumInit_ - 1, ubFactorElementAligned_);
|
||||
}
|
||||
|
||||
template <typename inputT, typename calT, typename outputT>
|
||||
__aicore__ inline void ApplyTopKTopPCustom<inputT, calT, outputT>::ScatterFromFirstKLoop(uint32_t loopBatch,
|
||||
int32_t firstKLoop, float &cumsumData) {
|
||||
uint32_t loopDataNum = ubFactorElementAligned_;
|
||||
uint32_t cumsumInner = ubFactorElementAligned_;
|
||||
uint8_t repeatTimes = ((ubFactorElementAligned_) + DATA_PER_REPEAT_B32 - 1) / DATA_PER_REPEAT_B32;
|
||||
for (int32_t loopInner = firstKLoop; loopInner < loopInner_; loopInner++) {
|
||||
int64_t currentGmIdx = baseGmIdx_ + loopInner * ubFactorElementAligned_;
|
||||
if (loopInner == (loopInner_ - 1)) {
|
||||
repeatTimes = (tailUbFactorElement_ + DATA_PER_REPEAT_B32 - 1) / DATA_PER_REPEAT_B32;
|
||||
loopDataNum = tailUbFactorElement_;
|
||||
cumsumInner = tailUbFactorElementAligned_;
|
||||
}
|
||||
DataCopyPad(sortedValueLocal.template ReinterpretCast<inputT>(), mGmSortedValue_[currentGmIdx],
|
||||
{1, static_cast<uint32_t>(loopDataNum * sizeof(inputT)), 0, 0, 0},
|
||||
{false, 0, 0, 0});
|
||||
DataCopyPad(sortedIndicesLocal, mGmSortedIndices_[currentGmIdx],
|
||||
{1, static_cast<uint32_t>(loopDataNum * sizeof(int32_t)), 0, 0, 0},
|
||||
{false, 0, 0, 0});
|
||||
if constexpr (!IsSameType<inputT, float>::value) {
|
||||
MTE2ToVSync();
|
||||
Cast(sortedValueLocalFp32, sortedValueLocal, RoundMode::CAST_NONE, loopDataNum);
|
||||
PipeBarrier<PIPE_V>();
|
||||
} else {
|
||||
MTE2ToVSync();
|
||||
}
|
||||
GetKthResult(loopBatch, 0, repeatTimes);
|
||||
PipeBarrier<PIPE_V>();
|
||||
CumSumWithAddsAndExpImpl(0, loopDataNum, cumsumInner, cumsumData);
|
||||
VToSSync();
|
||||
float cumsumDataTmp = cumSumRes.GetValue(loopDataNum - 1);
|
||||
cumsumData = cumsumDataTmp;
|
||||
if (cumsumDataTmp <= pValue) {
|
||||
continue;
|
||||
}
|
||||
ScatterCumtomImpl(loopBatch, loopDataNum, 0);
|
||||
}
|
||||
}
|
||||
|
||||
template <typename inputT, typename calT, typename outputT>
|
||||
__aicore__ inline void ApplyTopKTopPCustom<inputT, calT, outputT>::ProcessTopK() {
|
||||
kLocal = kInQueue_.AllocTensor<int32_t>();
|
||||
outTensor = outQueue_.AllocTensor<outputT>();
|
||||
sortedValueLocal = sortedValueInQueue_.AllocTensor<inputT>();
|
||||
sortedIndicesLocal = sortedIndicesInQueue_.AllocTensor<int32_t>();
|
||||
Duplicate(negInfLocal.template ReinterpretCast<int32_t>(), FLOAT32_NEG_INF, DATA_PER_BLOCK_B32);
|
||||
if constexpr (IsSameType<inputT, float>::value) {
|
||||
calLocalFp32 = sortedValueLocal;
|
||||
Duplicate(outTensor.template ReinterpretCast<int32_t>(), FLOAT32_NEG_INF, ubFactorElementAligned_);
|
||||
} else if constexpr (IsSameType<inputT, half>::value) {
|
||||
calLocalFp32 = sortedValueLocalFp32;
|
||||
Duplicate(outTensor.template ReinterpretCast<uint16_t>(), FLOAT16_NEG_INF, ubFactorElementAligned_);
|
||||
} else {
|
||||
calLocalFp32 = sortedValueLocalFp32;
|
||||
Duplicate(outTensor.template ReinterpretCast<uint16_t>(), BF16_NEG_INF, ubFactorElementAligned_);
|
||||
}
|
||||
VToMTE3Sync();
|
||||
for (uint32_t loopBatch = 0; loopBatch < loopBatch_; loopBatch++) {
|
||||
baseGmIdx_ = batchOffset_ * vocabSize_ + loopBatch * vocabSize_;
|
||||
InitProcessTopK(loopBatch);
|
||||
/* The difference lies in that for the max branch, some data is less than the kthvalue,
|
||||
so part of the data can be filtered out in advance;
|
||||
while for the remain branch, all data must undergo the topk calculation.*/
|
||||
if (calLocalFp32.GetValue(ubFactorElementAligned_) < kthValue) {
|
||||
ProcessKLtKMaxTopK(loopBatch);
|
||||
} else {
|
||||
ProcessRemainTopK(loopBatch);
|
||||
}
|
||||
}
|
||||
kInQueue_.FreeTensor(kLocal);
|
||||
sortedValueInQueue_.FreeTensor(sortedValueLocal);
|
||||
sortedIndicesInQueue_.FreeTensor(sortedIndicesLocal);
|
||||
outQueue_.FreeTensor(outTensor);
|
||||
}
|
||||
|
||||
template <typename inputT, typename calT, typename outputT>
|
||||
__aicore__ inline void ApplyTopKTopPCustom<inputT, calT, outputT>::ProcessRemainTopK(uint32_t loopBatch) {
|
||||
int32_t firstKLoop = 0;
|
||||
GetFirstKLoopTopK(loopBatch, firstKLoop);
|
||||
// Start the scatter calculation from the first loop in the row where the value is ≥ kthValue.
|
||||
ScatterFromFirstKLoopTopK(loopBatch, firstKLoop);
|
||||
/* Perform scatter calculation on the maximum number of ubFactorElementAligned_,
|
||||
which does not overlap with the previous ones.*/
|
||||
uint32_t loopProb = dataNumInit_ - 1;
|
||||
scatterTensor.SetValue(0, sortedValueLocal[ubFactorElementAligned_].GetValue(loopProb));
|
||||
SToMTE3Sync();
|
||||
int32_t gmIndex = sortedIndicesLocal[ubFactorElementAligned_].GetValue(loopProb);
|
||||
DataCopyPad(mGmOut_[baseGmIdx_ + gmIndex],
|
||||
scatterTensor.template ReinterpretCast<outputT>(), scatterCopyParams);
|
||||
MTE3ToSSync();
|
||||
ScatterCumtomImplTopK(loopBatch, dataNumInit_ - 1, ubFactorElementAligned_);
|
||||
}
|
||||
|
||||
template <typename inputT, typename calT, typename outputT>
|
||||
__aicore__ inline void ApplyTopKTopPCustom<inputT, calT, outputT>::GetFirstKLoopTopK(uint32_t loopBatch,
|
||||
int32_t &firstKLoop) {
|
||||
uint8_t repeatTimes = (dataNumInit_ + DATA_PER_REPEAT_B32 - 1) / DATA_PER_REPEAT_B32;
|
||||
uint32_t loopDataNum = ubFactorElementAligned_;
|
||||
for (int32_t loopInner = 0; loopInner < loopInner_; loopInner++) {
|
||||
int64_t currentGmIdx = baseGmIdx_ + loopInner * ubFactorElementAligned_;
|
||||
if (loopInner == (loopInner_ - 1)) {
|
||||
repeatTimes = ((tailUbFactorElement_) + DATA_PER_REPEAT_B32 - 1) / DATA_PER_REPEAT_B32;
|
||||
loopDataNum = tailUbFactorElement_;
|
||||
}
|
||||
DataCopyPad(mGmOut_[currentGmIdx], outTensor, {1, (uint32_t)(loopDataNum * sizeof(outputT)), 0, 0, 0});
|
||||
DataCopyPad(sortedValueLocal.template ReinterpretCast<inputT>(), mGmSortedValue_[currentGmIdx],
|
||||
{1, static_cast<uint32_t>(loopDataNum * sizeof(inputT)), 0, 0, 0},
|
||||
{false, 0, 0, 0});
|
||||
MTE2ToSSync();
|
||||
float rightVlaue = 0;
|
||||
// Make a judgment on the rightmost value of each loop to filter the data.
|
||||
if constexpr (IsSameType<inputT, bfloat16_t>::value) {
|
||||
rightVlaue = ToFloat(sortedValueLocal.GetValue(loopDataNum - 1));
|
||||
} else {
|
||||
rightVlaue = static_cast<float>(sortedValueLocal.GetValue(loopDataNum - 1));
|
||||
}
|
||||
SToMTE2Sync();
|
||||
if (rightVlaue < kthValue) {
|
||||
firstKLoop += 1;
|
||||
continue;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
template <typename inputT, typename calT, typename outputT>
|
||||
__aicore__ inline void ApplyTopKTopPCustom<inputT, calT, outputT>::ScatterFromFirstKLoopTopK(uint32_t loopBatch,
|
||||
int32_t firstKLoop) {
|
||||
uint32_t loopDataNum = ubFactorElementAligned_;
|
||||
uint32_t cumsumInner = ubFactorElementAligned_;
|
||||
uint8_t repeatTimes = ((ubFactorElementAligned_) + DATA_PER_REPEAT_B32 - 1) / DATA_PER_REPEAT_B32;
|
||||
for (int32_t loopInner = firstKLoop; loopInner < loopInner_; loopInner++) {
|
||||
int64_t currentGmIdx = baseGmIdx_ + loopInner * ubFactorElementAligned_;
|
||||
if (loopInner == (loopInner_ - 1)) {
|
||||
repeatTimes = (tailUbFactorElement_ + DATA_PER_REPEAT_B32 - 1) / DATA_PER_REPEAT_B32;
|
||||
loopDataNum = tailUbFactorElement_;
|
||||
cumsumInner = tailUbFactorElementAligned_;
|
||||
}
|
||||
DataCopyPad(sortedValueLocal.template ReinterpretCast<inputT>(), mGmSortedValue_[currentGmIdx],
|
||||
{1, static_cast<uint32_t>(loopDataNum * sizeof(inputT)), 0, 0, 0},
|
||||
{false, 0, 0, 0});
|
||||
if constexpr (!IsSameType<inputT, float>::value) {
|
||||
MTE2ToVSync();
|
||||
Cast(sortedValueLocalFp32, sortedValueLocal, RoundMode::CAST_NONE, loopDataNum);
|
||||
VToSSync();
|
||||
}
|
||||
DataCopyPad(sortedIndicesLocal, mGmSortedIndices_[currentGmIdx],
|
||||
{1, static_cast<uint32_t>(loopDataNum * sizeof(int32_t)), 0, 0, 0},
|
||||
{false, 0, 0, 0});
|
||||
MTE2ToSSync();
|
||||
ScatterCumtomImplTopK(loopBatch, loopDataNum, 0);
|
||||
}
|
||||
}
|
||||
|
||||
template <typename inputT, typename calT, typename outputT>
|
||||
__aicore__ inline void ApplyTopKTopPCustom<inputT, calT, outputT>::ScatterCumtomImplTopK(uint32_t loopBatch,
|
||||
uint32_t loopProbNum, uint32_t offset) {
|
||||
// Reverse traversal, returning early to improve performance.
|
||||
for (int32_t loopProb = static_cast<int32_t>(loopProbNum) - 1; loopProb >= 0; loopProb--) {
|
||||
float curValue = calLocalFp32[offset].GetValue(loopProb);
|
||||
if (curValue < kthValue) {
|
||||
break;
|
||||
}
|
||||
scatterTensor.SetValue(0, sortedValueLocal[offset].GetValue(loopProb));
|
||||
int32_t gmIndex = sortedIndicesLocal[offset].GetValue(loopProb);
|
||||
SToMTE3Sync();
|
||||
DataCopyPad(mGmOut_[baseGmIdx_ + gmIndex], scatterTensor.template ReinterpretCast<outputT>(),
|
||||
{1, (uint32_t)(1 * sizeof(outputT)), 0, 0, 0});
|
||||
MTE3ToSSync();
|
||||
}
|
||||
}
|
||||
|
||||
template <typename inputT, typename calT, typename outputT>
|
||||
__aicore__ inline void ApplyTopKTopPCustom<inputT, calT, outputT>::InitProcessTopK(uint32_t loopBatch) {
|
||||
int64_t initGmIdx = baseGmIdx_ + vocabSize_ - dataNumInit_;
|
||||
DataCopyPad(mGmOut_[initGmIdx], outTensor, {1, (uint32_t)(dataNumInit_ * sizeof(outputT)), 0, 0, 0});
|
||||
DataCopyPad(sortedValueLocal[ubFactorElementAligned_], mGmSortedValue_[initGmIdx],
|
||||
{1, static_cast<uint32_t>(dataNumInit_ * sizeof(inputT)), 0, 0, 0},
|
||||
{false, 0, 0, 0});
|
||||
if constexpr (!IsSameType<inputT, float>::value) {
|
||||
MTE2ToVSync();
|
||||
Cast(sortedValueLocalFp32[ubFactorElementAligned_], sortedValueLocal[ubFactorElementAligned_],
|
||||
RoundMode::CAST_NONE, dataNumInit_);
|
||||
}
|
||||
DataCopyPad(sortedIndicesLocal[ubFactorElementAligned_], mGmSortedIndices_[initGmIdx],
|
||||
{1, static_cast<uint32_t>(dataNumInit_ * sizeof(int32_t)), 0, 0, 0},
|
||||
{false, 0, 0, 0});
|
||||
DataCopyPad(kLocal, mGmK_[batchOffset_ + loopBatch], {1, static_cast<uint32_t>(sizeof(int32_t)), 0, 0, 0},
|
||||
{false, 0, 0, 0});
|
||||
MTE2ToSSync();
|
||||
int32_t kValue = mGmK_.GetValue(batchOffset_ + loopBatch);
|
||||
maxValue = -calLocalFp32[ubFactorElementAligned_].GetValue(dataNumInit_ - 1);
|
||||
if constexpr (IsSameType<inputT, float>::value) {
|
||||
kthValue = mGmSortedValue_[baseGmIdx_ + vocabSize_ - kValue].GetValue(0);
|
||||
} else if constexpr (IsSameType<inputT, half>::value) {
|
||||
kthValue = static_cast<float>(mGmSortedValue_[baseGmIdx_ + vocabSize_ - kValue].GetValue(0));
|
||||
} else {
|
||||
kthValue = ToFloat(mGmSortedValue_[baseGmIdx_ + vocabSize_ - kValue].GetValue(0));
|
||||
}
|
||||
}
|
||||
|
||||
template <typename inputT, typename calT, typename outputT>
|
||||
__aicore__ inline void ApplyTopKTopPCustom<inputT, calT, outputT>::ProcessKLtKMaxTopK(uint32_t loopBatch) {
|
||||
DataCopyExtParams copyParams{1, (uint32_t)(ubFactorElementAligned_ * sizeof(outputT)), 0, 0, 0};
|
||||
// Move out -infinity to fill GM
|
||||
for (int32_t loopInner = 0; loopInner < loopInner_; loopInner++) {
|
||||
int64_t currentGmIdxInner = baseGmIdx_ + loopInner * ubFactorElementAligned_;
|
||||
if (loopInner == loopInner_ - 1) {
|
||||
DataCopyPad(mGmOut_[currentGmIdxInner], outTensor,
|
||||
{1, (uint32_t)(tailUbFactorElement_ * sizeof(outputT)), 0, 0, 0});
|
||||
} else {
|
||||
DataCopyPad(mGmOut_[currentGmIdxInner], outTensor, copyParams);
|
||||
}
|
||||
}
|
||||
// Scatter calculation
|
||||
int32_t loopProb = dataNumInit_ - 1;
|
||||
scatterTensor.SetValue(0, sortedValueLocal[ubFactorElementAligned_].GetValue(loopProb));
|
||||
int32_t gmIndex = sortedIndicesLocal[ubFactorElementAligned_].GetValue(loopProb);
|
||||
SToMTE3Sync();
|
||||
PipeBarrier<PIPE_MTE3>();
|
||||
DataCopyPad(mGmOut_[baseGmIdx_ + gmIndex], scatterTensor.template ReinterpretCast<outputT>(), scatterCopyParams);
|
||||
loopProb = loopProb - 1;
|
||||
|
||||
for (; loopProb >= 0; loopProb--) {
|
||||
float curValue = calLocalFp32[ubFactorElementAligned_].GetValue(loopProb);
|
||||
if (curValue < kthValue) {
|
||||
break;
|
||||
}
|
||||
MTE3ToSSync();
|
||||
scatterTensor.SetValue(0, sortedValueLocal[ubFactorElementAligned_].GetValue(loopProb));
|
||||
gmIndex = sortedIndicesLocal[ubFactorElementAligned_].GetValue(loopProb);
|
||||
SToMTE3Sync();
|
||||
DataCopyPad(mGmOut_[baseGmIdx_ + gmIndex],
|
||||
scatterTensor.template ReinterpretCast<outputT>(), scatterCopyParams);
|
||||
}
|
||||
}
|
||||
|
||||
} // namespace
|
||||
|
||||
#endif // APPLY_TOP_K_TOP_P_CUSTOM_H_KERNEL
|
||||
468
csrc/apply_top_k_top_p_custom/op_kernel/apply_top_p_custom.h
Normal file
468
csrc/apply_top_k_top_p_custom/op_kernel/apply_top_p_custom.h
Normal file
@@ -0,0 +1,468 @@
|
||||
/**
|
||||
* Copyright (c) 2025 Huawei Technologies Co., Ltd.
|
||||
* This program is free software, you can redistribute it and/or modify it under the terms and conditions of
|
||||
* CANN Open Software License Agreement Version 2.0 (the "License").
|
||||
* Please refer to the License for details. You may not use this file except in compliance with the License.
|
||||
* THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED,
|
||||
* INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE.
|
||||
* See LICENSE in the root of the software repository for the full text of the License.
|
||||
*/
|
||||
|
||||
/*!
|
||||
* \file apply_top_p_custom.h
|
||||
* \brief
|
||||
*/
|
||||
#ifndef APPLY_TOP_P_CUSTOM_H_KERNEL
|
||||
#define APPLY_TOP_P_CUSTOM_H_KERNEL
|
||||
|
||||
#include "kernel_operator.h"
|
||||
|
||||
using namespace AscendC;
|
||||
namespace ApplyTopPCustomOp {
|
||||
constexpr uint16_t FLOAT16_NEG_INF = 0xFC00; // -inf 64512
|
||||
constexpr uint16_t BF16_NEG_INF = 0xFF80; // -inf 65408
|
||||
constexpr int32_t FLOAT32_NEG_INF = 0xFF800000; // -inf -2139095040
|
||||
|
||||
constexpr uint32_t BLOCK_BYTES = 32;
|
||||
constexpr uint32_t DATA_PER_BLOCK_B32 = 8;
|
||||
constexpr uint32_t DATA_PER_REPEAT_B32 = 64;
|
||||
constexpr uint32_t SCATTER_PART_LENGTH = 1024;
|
||||
constexpr uint32_t RESERVED_UB = 1024;
|
||||
constexpr uint32_t FLOAT_BYTES = 4;
|
||||
constexpr uint32_t SOFTMAX_UB_NUM = 2;
|
||||
|
||||
template <typename inputT, typename calT, typename outputT>
|
||||
class ApplyTopPCustom {
|
||||
public:
|
||||
__aicore__ inline ApplyTopPCustom(){};
|
||||
__aicore__ inline void InitTilingData(
|
||||
const ApplyTopKTopPCustomTilingData &__restrict tilingData, GM_ADDR sorted_value, GM_ADDR sorted_indices, GM_ADDR p, GM_ADDR k, GM_ADDR out, GM_ADDR workspace);
|
||||
__aicore__ inline void InitBuffer(TPipe *inputPipe);
|
||||
__aicore__ inline void ProcessTopP();
|
||||
private:
|
||||
__aicore__ inline void ReduceSumWithAddsAndExpImpl(uint32_t loopDataNum);
|
||||
// topp func
|
||||
__aicore__ inline void ProcessPreSingleBatch(uint32_t loopBatch);
|
||||
__aicore__ inline void GetSoftmaxSum(uint32_t loopBatch);
|
||||
__aicore__ inline void CumsumKoggleStone(uint32_t loopBatch);
|
||||
__aicore__ inline void GetPValue(uint32_t batchOffset);
|
||||
__aicore__ inline void CumsumParamCompute(uint32_t iterateTime);
|
||||
__aicore__ inline void GetMaxValue(int64_t baseGmIdx);
|
||||
__aicore__ inline void GetSoftMaxRes(uint32_t loopBatch);
|
||||
__aicore__ inline void ScatterSingleTask(uint32_t taskIndex);
|
||||
|
||||
__aicore__ inline void SToMTE3Sync() {
|
||||
event_t eventIDSToMTE3 = static_cast<event_t>(GetTPipePtr()->FetchEventID(HardEvent::S_MTE3));
|
||||
SetFlag<HardEvent::S_MTE3>(eventIDSToMTE3);
|
||||
WaitFlag<HardEvent::S_MTE3>(eventIDSToMTE3);
|
||||
}
|
||||
__aicore__ inline void VToMTE3Sync() {
|
||||
event_t eventIDVToMTE3 = static_cast<event_t>(GetTPipePtr()->FetchEventID(HardEvent::V_MTE3));
|
||||
SetFlag<HardEvent::V_MTE3>(eventIDVToMTE3);
|
||||
WaitFlag<HardEvent::V_MTE3>(eventIDVToMTE3);
|
||||
}
|
||||
__aicore__ inline void VToMTE2Sync() {
|
||||
event_t eventIDVToMTE2 = static_cast<event_t>(GetTPipePtr()->FetchEventID(HardEvent::V_MTE2));
|
||||
SetFlag<HardEvent::V_MTE2>(eventIDVToMTE2);
|
||||
WaitFlag<HardEvent::V_MTE2>(eventIDVToMTE2);
|
||||
}
|
||||
__aicore__ inline void MTE3ToSSync() {
|
||||
event_t eventIDMTE3ToS = static_cast<event_t>(GetTPipePtr()->FetchEventID(HardEvent::MTE3_S));
|
||||
SetFlag<HardEvent::MTE3_S>(eventIDMTE3ToS);
|
||||
WaitFlag<HardEvent::MTE3_S>(eventIDMTE3ToS);
|
||||
}
|
||||
__aicore__ inline void VToSSync() {
|
||||
event_t eventIdVToS = static_cast<event_t>(GetTPipePtr()->FetchEventID(HardEvent::V_S));
|
||||
SetFlag<HardEvent::V_S>(eventIdVToS);
|
||||
WaitFlag<HardEvent::V_S>(eventIdVToS);
|
||||
}
|
||||
__aicore__ inline void SToVSync() {
|
||||
event_t eventIdSToV = static_cast<event_t>(GetTPipePtr()->FetchEventID(HardEvent::S_V));
|
||||
SetFlag<HardEvent::S_V>(eventIdSToV);
|
||||
WaitFlag<HardEvent::S_V>(eventIdSToV);
|
||||
}
|
||||
__aicore__ inline void MTE2ToVSync() {
|
||||
event_t eventIdMte2ToV = static_cast<event_t>(GetTPipePtr()->FetchEventID(HardEvent::MTE2_V));
|
||||
SetFlag<HardEvent::MTE2_V>(eventIdMte2ToV);
|
||||
WaitFlag<HardEvent::MTE2_V>(eventIdMte2ToV);
|
||||
}
|
||||
__aicore__ inline void MTE2ToSSync() {
|
||||
event_t eventIdMte2ToS = static_cast<event_t>(GetTPipePtr()->FetchEventID(HardEvent::MTE2_S));
|
||||
SetFlag<HardEvent::MTE2_S>(eventIdMte2ToS);
|
||||
WaitFlag<HardEvent::MTE2_S>(eventIdMte2ToS);
|
||||
}
|
||||
|
||||
__aicore__ inline void MTE3ToMTE2Sync() {
|
||||
event_t eventIdMte3ToMte2 = static_cast<event_t>(GetTPipePtr()->FetchEventID(HardEvent::MTE3_MTE2));
|
||||
SetFlag<HardEvent::MTE3_MTE2>(eventIdMte3ToMte2);
|
||||
WaitFlag<HardEvent::MTE3_MTE2>(eventIdMte3ToMte2);
|
||||
}
|
||||
|
||||
__aicore__ inline void MTE3ToVSync() {
|
||||
event_t eventIdMte3ToV = static_cast<event_t>(GetTPipePtr()->FetchEventID(HardEvent::MTE3_V));
|
||||
SetFlag<HardEvent::MTE3_V>(eventIdMte3ToV);
|
||||
WaitFlag<HardEvent::MTE3_V>(eventIdMte3ToV);
|
||||
}
|
||||
|
||||
__aicore__ inline uint32_t CeilDiv(uint32_t x, uint32_t y) {
|
||||
return y == 0 ? x : (x + y - 1) / y;
|
||||
}
|
||||
private:
|
||||
TPipe *pipe_;
|
||||
TBuf<TPosition::VECCALC> calBuf_;
|
||||
|
||||
// tilingData
|
||||
uint32_t batchSize_ = 0;
|
||||
uint32_t vocabSize_ = 0;
|
||||
uint32_t batchPerCore_ = 0;
|
||||
uint32_t tailBatch_ = 0;
|
||||
uint32_t blockNum_ = 0;
|
||||
uint32_t dataNumInitAligned_ = 0;
|
||||
uint32_t calUbSize_ = 0;
|
||||
uint32_t blockIdx_ = 0;
|
||||
uint32_t loopBatch_ = 0;
|
||||
uint32_t batchOffset_ = 0;
|
||||
uint32_t bufOffsetLoop = 0;
|
||||
int64_t baseGmIdx_ = 0;
|
||||
|
||||
// topp scalar
|
||||
uint32_t maxSoftmaxLength = 1;
|
||||
uint32_t softmaxLength = 1;
|
||||
uint32_t lineSfLoopTimes = 1;
|
||||
uint32_t softmaxLengthTail = 1;
|
||||
uint32_t scatterLength = 1;
|
||||
|
||||
uint32_t singleCoreB = 0;
|
||||
uint32_t singleCoreBTail = 0;
|
||||
uint32_t vCnt = 0;
|
||||
uint32_t bCnt = 0;
|
||||
uint32_t singleCoreV = 1;
|
||||
uint32_t singleCoreVTail = 1;
|
||||
uint32_t iterateTimes = 1;
|
||||
|
||||
GlobalTensor<inputT> mGmSortedValue_;
|
||||
GlobalTensor<int32_t> mGmSortedIndices_;
|
||||
GlobalTensor<inputT> mGmP_;
|
||||
GlobalTensor<int32_t> mGmK_;
|
||||
GlobalTensor<outputT> mGmOut_;
|
||||
GlobalTensor<float> softMaxGm;
|
||||
|
||||
LocalTensor<uint8_t> totalUb;
|
||||
|
||||
// softmax tensor
|
||||
LocalTensor<float> softMaxLocalFp32;
|
||||
LocalTensor<inputT> softMaxLocal;
|
||||
LocalTensor<float> softMaxResLocal;
|
||||
LocalTensor<float> reduceLocal;
|
||||
LocalTensor<inputT> outInfLocal;
|
||||
|
||||
// cumsum tensor
|
||||
LocalTensor<float> cumSumInput1Local;
|
||||
LocalTensor<float> cumSumInput2Local;
|
||||
|
||||
// scatter tensor
|
||||
LocalTensor<inputT> sortedValueLocal;
|
||||
LocalTensor<int32_t> sortedIndicesLocal;
|
||||
LocalTensor<float> sortedValueLocalFp32;
|
||||
LocalTensor<outputT> scatterLocal;
|
||||
LocalTensor<float> cumsumLocal;
|
||||
|
||||
float pValue = 0;
|
||||
float maxValue = 0;
|
||||
float reduceSumValueInvert = 0;
|
||||
float reduceSumValue = 0;
|
||||
BinaryRepeatParams repeatParams = {1, 0, 1, 8, 0, 8};
|
||||
DataCopyExtParams scatterCopyParams{1, (uint32_t)(sizeof(outputT)), 0, 0, 0};
|
||||
};
|
||||
|
||||
template <typename inputT, typename calT, typename outputT>
|
||||
__aicore__ inline void ApplyTopPCustom<inputT, calT, outputT>::InitTilingData(
|
||||
const ApplyTopKTopPCustomTilingData &__restrict tilingData, GM_ADDR sorted_value, GM_ADDR sorted_indices,
|
||||
GM_ADDR p, GM_ADDR k, GM_ADDR out, GM_ADDR workspace) {
|
||||
batchSize_ = tilingData.batchSize;
|
||||
vocabSize_ = tilingData.vocabSize;
|
||||
batchPerCore_ = tilingData.batchPerCore;
|
||||
tailBatch_ = tilingData.tailBatch;
|
||||
blockNum_ = tilingData.blockNum;
|
||||
calUbSize_ = tilingData.calUbSize;
|
||||
iterateTimes = tilingData.iterateTimes;
|
||||
blockIdx_ = GetBlockIdx();
|
||||
if (blockIdx_ < tailBatch_) {
|
||||
loopBatch_ = batchPerCore_ + 1;
|
||||
batchOffset_ = blockIdx_ * loopBatch_;
|
||||
} else {
|
||||
loopBatch_ = batchPerCore_;
|
||||
batchOffset_ = blockIdx_ * batchPerCore_ + tailBatch_;
|
||||
}
|
||||
maxSoftmaxLength = (calUbSize_ - RESERVED_UB) / SOFTMAX_UB_NUM / FLOAT_BYTES;
|
||||
softmaxLength = maxSoftmaxLength < vocabSize_ ? maxSoftmaxLength : vocabSize_;
|
||||
|
||||
lineSfLoopTimes = (vocabSize_ + softmaxLength - 1) / softmaxLength;
|
||||
softmaxLengthTail = vocabSize_ - (lineSfLoopTimes - 1) * softmaxLength;
|
||||
scatterLength = (calUbSize_ - RESERVED_UB - BLOCK_BYTES) / (SOFTMAX_UB_NUM * FLOAT_BYTES + sizeof(inputT)) /
|
||||
SCATTER_PART_LENGTH * SCATTER_PART_LENGTH;
|
||||
singleCoreB = CeilDiv(batchSize_, blockNum_);
|
||||
vCnt = batchSize_ < blockNum_ ? blockNum_ / batchSize_ : 1;
|
||||
bCnt = batchSize_;
|
||||
singleCoreB = 1;
|
||||
singleCoreBTail = 1;
|
||||
singleCoreV = vocabSize_ / vCnt;
|
||||
singleCoreVTail = vocabSize_ - vCnt * singleCoreV;
|
||||
mGmSortedValue_.SetGlobalBuffer(reinterpret_cast<__gm__ inputT *>(sorted_value));
|
||||
mGmSortedIndices_.SetGlobalBuffer(reinterpret_cast<__gm__ int32_t *>(sorted_indices));
|
||||
mGmP_.SetGlobalBuffer(reinterpret_cast<__gm__ inputT *>(p));
|
||||
mGmK_.SetGlobalBuffer(reinterpret_cast<__gm__ int32_t *>(k));
|
||||
mGmOut_.SetGlobalBuffer(reinterpret_cast<__gm__ outputT *>(out));
|
||||
softMaxGm.SetGlobalBuffer((__gm__ float*)workspace, batchSize_ * vocabSize_);
|
||||
}
|
||||
|
||||
// init used buffer
|
||||
template <typename inputT, typename calT, typename outputT>
|
||||
__aicore__ inline void ApplyTopPCustom<inputT, calT, outputT>::InitBuffer(TPipe *inputPipe) {
|
||||
pipe_ = inputPipe;
|
||||
pipe_->InitBuffer(calBuf_, calUbSize_);
|
||||
totalUb = calBuf_.Get<uint8_t>();
|
||||
// softmax ub
|
||||
uint32_t softmaxLengthAligned = CeilDiv(softmaxLength, BLOCK_BYTES / sizeof(inputT)) * BLOCK_BYTES / sizeof(inputT);
|
||||
softMaxLocalFp32 = totalUb.ReinterpretCast<float>();
|
||||
softMaxLocal = totalUb[softmaxLengthAligned * sizeof(inputT)].ReinterpretCast<inputT>();
|
||||
softMaxResLocal = totalUb[softmaxLengthAligned * sizeof(float)].ReinterpretCast<float>();
|
||||
reduceLocal = totalUb[softmaxLengthAligned * sizeof(float) * 2].ReinterpretCast<float>(); // 32 bytes
|
||||
outInfLocal = totalUb.ReinterpretCast<inputT>(); // Take softmax ub
|
||||
|
||||
// cumsum ub
|
||||
cumSumInput1Local = totalUb.ReinterpretCast<float>(); // Take softmax local
|
||||
cumSumInput2Local = totalUb[softmaxLengthAligned * sizeof(float)].ReinterpretCast<float>(); // Take softmax res ub
|
||||
|
||||
// scatter ub
|
||||
sortedValueLocal = totalUb[0].ReinterpretCast<inputT>();
|
||||
sortedIndicesLocal = totalUb[scatterLength * sizeof(inputT)].ReinterpretCast<int32_t>();
|
||||
cumsumLocal = totalUb[scatterLength * (FLOAT_BYTES + sizeof(inputT))].ReinterpretCast<float>();
|
||||
scatterLocal = totalUb[calUbSize_ - RESERVED_UB + BLOCK_BYTES].ReinterpretCast<outputT>(); // 32 bytes
|
||||
}
|
||||
|
||||
template <typename inputT, typename calT, typename outputT>
|
||||
__aicore__ inline void ApplyTopPCustom<inputT, calT, outputT>::GetMaxValue(int64_t baseGmIdx) {
|
||||
int64_t initGmIdx = baseGmIdx + vocabSize_ - 1;
|
||||
if constexpr (IsSameType<inputT, float>::value) {
|
||||
maxValue = -mGmSortedValue_[initGmIdx].GetValue(0);
|
||||
} else if constexpr (IsSameType<inputT, half>::value) {
|
||||
maxValue = -static_cast<float>(mGmSortedValue_[initGmIdx].GetValue(0));
|
||||
} else {
|
||||
maxValue = -ToFloat(mGmSortedValue_[initGmIdx].GetValue(0));
|
||||
}
|
||||
}
|
||||
|
||||
template <typename inputT, typename calT, typename outputT>
|
||||
__aicore__ inline void ApplyTopPCustom<inputT, calT, outputT>::GetPValue(uint32_t batchOffset) {
|
||||
if constexpr (IsSameType<inputT, float>::value) {
|
||||
pValue = float(1.0) - mGmP_[batchOffset].GetValue(0);
|
||||
} else if constexpr (IsSameType<inputT, half>::value) {
|
||||
pValue = float(1.0) - static_cast<float>(mGmP_[batchOffset].GetValue(0));
|
||||
} else {
|
||||
pValue = float(1.0) - ToFloat(mGmP_[batchOffset].GetValue(0));
|
||||
}
|
||||
}
|
||||
|
||||
template <typename inputT, typename calT, typename outputT>
|
||||
__aicore__ inline void ApplyTopPCustom<inputT, calT, outputT>::ProcessPreSingleBatch(uint32_t loopBatch) {
|
||||
reduceSumValue = 0;
|
||||
GetSoftmaxSum(loopBatch);
|
||||
GetSoftMaxRes(loopBatch);
|
||||
CumsumKoggleStone(loopBatch);
|
||||
}
|
||||
|
||||
template <typename inputT, typename calT, typename outputT>
|
||||
__aicore__ inline void ApplyTopPCustom<inputT, calT, outputT>::ProcessTopP() {
|
||||
for (uint32_t loopBatch = 0; loopBatch < loopBatch_; loopBatch++) {
|
||||
baseGmIdx_ = batchOffset_ * vocabSize_ + loopBatch * vocabSize_;
|
||||
GetMaxValue(baseGmIdx_); // Get max value in softmax.
|
||||
ProcessPreSingleBatch(loopBatch); // Softmax and cumsum.
|
||||
}
|
||||
SyncAll();
|
||||
for (uint32_t taskIndex = 0; taskIndex < bCnt * vCnt; taskIndex++) {
|
||||
ScatterSingleTask(taskIndex);
|
||||
}
|
||||
}
|
||||
|
||||
template <typename inputT, typename calT, typename outputT>
|
||||
__aicore__ inline void ApplyTopPCustom<inputT, calT, outputT>::ScatterSingleTask(uint32_t taskIndex) {
|
||||
if (GetBlockIdx() == taskIndex % blockNum_) {
|
||||
uint32_t bCntIndex = taskIndex / vCnt;
|
||||
uint32_t vCntIndex = taskIndex % vCnt;
|
||||
uint32_t vCurSingleCore = vCntIndex < singleCoreVTail ? (singleCoreV + 1) : singleCoreV;
|
||||
uint32_t copyTimes = CeilDiv(vCurSingleCore, scatterLength);
|
||||
uint32_t copyLength = scatterLength;
|
||||
uint32_t copyLengthTail = vCurSingleCore - (copyTimes - 1) * scatterLength;
|
||||
GetPValue(bCntIndex); // Get maxPValue.
|
||||
for (uint32_t cpIndex = 0; cpIndex < copyTimes; cpIndex++) {
|
||||
uint32_t curCopyLength = cpIndex == (copyTimes - 1) ? copyLengthTail : copyLength;
|
||||
int64_t gmOffset = vCntIndex < singleCoreVTail ?
|
||||
bCntIndex * vocabSize_ + vCntIndex * (singleCoreV + 1) + cpIndex * copyLength :
|
||||
bCntIndex * vocabSize_ + vCntIndex * singleCoreV + singleCoreVTail + cpIndex * copyLength;
|
||||
DataCopyPad(cumsumLocal, softMaxGm[gmOffset],
|
||||
{1, static_cast<uint32_t>(curCopyLength * sizeof(float)), 0, 0, 0}, {false, 0, 0, 0});
|
||||
DataCopyPad(sortedIndicesLocal, mGmSortedIndices_[gmOffset],
|
||||
{1, static_cast<uint32_t>(curCopyLength * sizeof(float)), 0, 0, 0}, {false, 0, 0, 0});
|
||||
DataCopyPad(sortedValueLocal, mGmSortedValue_[gmOffset],
|
||||
{1, static_cast<uint32_t>(curCopyLength * sizeof(inputT)), 0, 0, 0}, {false, 0, 0, 0});
|
||||
MTE2ToSSync();
|
||||
|
||||
if (cumsumLocal.GetValue(curCopyLength - 1) <= pValue) {
|
||||
continue;
|
||||
}
|
||||
uint32_t scatterLoop = CeilDiv(curCopyLength, SCATTER_PART_LENGTH);
|
||||
uint32_t scatterNumsTail = curCopyLength - (scatterLoop - 1) * SCATTER_PART_LENGTH;
|
||||
for (uint32_t scatterLoopIndex = 0; scatterLoopIndex < scatterLoop; scatterLoopIndex++) {
|
||||
uint32_t curScatterNums = scatterLoopIndex == (scatterLoop - 1) ? scatterNumsTail : SCATTER_PART_LENGTH;
|
||||
if (cumsumLocal.GetValue(scatterLoopIndex * SCATTER_PART_LENGTH + curScatterNums - 1) <= pValue) {
|
||||
continue;
|
||||
}
|
||||
for (uint32_t scatterIndex = 0; scatterIndex < curScatterNums; scatterIndex++) {
|
||||
int64_t scatterOffset = scatterLoopIndex * SCATTER_PART_LENGTH + scatterIndex;
|
||||
if (cumsumLocal.GetValue(scatterOffset) <= pValue) {
|
||||
continue;
|
||||
}
|
||||
scatterLocal.SetValue(0, sortedValueLocal.GetValue(scatterOffset));
|
||||
int32_t lineIndex = sortedIndicesLocal.GetValue(scatterOffset);
|
||||
SToMTE3Sync();
|
||||
DataCopyPad(mGmOut_[bCntIndex * vocabSize_ + lineIndex], scatterLocal.template ReinterpretCast<outputT>(),
|
||||
{1, (uint32_t)(1 * sizeof(outputT)), 0, 0, 0});
|
||||
MTE3ToSSync();
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
template <typename inputT, typename calT, typename outputT>
|
||||
__aicore__ inline void ApplyTopPCustom<inputT, calT, outputT>::GetSoftMaxRes(uint32_t loopBatch) {
|
||||
uint32_t loopDataNum = softmaxLength;
|
||||
for (int32_t loopInner = 0; loopInner < lineSfLoopTimes; loopInner++) {
|
||||
int64_t currentGmIdx = baseGmIdx_ + loopInner * softmaxLength;
|
||||
if (loopInner == (lineSfLoopTimes - 1)) {
|
||||
loopDataNum = softmaxLengthTail;
|
||||
}
|
||||
if constexpr (!IsSameType<inputT, float>::value) {
|
||||
DataCopyPad(softMaxLocal, mGmSortedValue_[currentGmIdx],
|
||||
{1, static_cast<uint32_t>(loopDataNum * sizeof(inputT)), 0, 0, 0},
|
||||
{false, 0, 0, 0});
|
||||
MTE2ToVSync();
|
||||
Cast(softMaxLocalFp32, softMaxLocal, RoundMode::CAST_NONE, loopDataNum);
|
||||
PipeBarrier<PIPE_V>();
|
||||
} else {
|
||||
DataCopyPad(softMaxLocalFp32, mGmSortedValue_[currentGmIdx],
|
||||
{1, static_cast<uint32_t>(loopDataNum * sizeof(float)), 0, 0, 0},
|
||||
{false, 0, 0, 0});
|
||||
MTE2ToVSync();
|
||||
}
|
||||
Adds(softMaxResLocal, softMaxLocalFp32, maxValue, loopDataNum);
|
||||
VToMTE2Sync();
|
||||
PipeBarrier<PIPE_V>();
|
||||
Exp(softMaxResLocal, softMaxResLocal, loopDataNum);
|
||||
PipeBarrier<PIPE_V>();
|
||||
Muls(softMaxResLocal, softMaxResLocal, reduceSumValueInvert, loopDataNum);
|
||||
VToMTE3Sync();
|
||||
DataCopyPad(softMaxGm[currentGmIdx], softMaxResLocal,
|
||||
{1, static_cast<uint32_t>(loopDataNum * sizeof(float)), 0, 0, 0});
|
||||
MTE3ToMTE2Sync();
|
||||
}
|
||||
}
|
||||
|
||||
template <typename inputT, typename calT, typename outputT>
|
||||
__aicore__ inline void ApplyTopPCustom<inputT, calT, outputT>::CumsumKoggleStone(uint32_t loopBatch) {
|
||||
uint32_t loopDataNum = softmaxLength;
|
||||
for (uint32_t iterateTime = 0; iterateTime < iterateTimes; iterateTime++) {
|
||||
int64_t iteratOffset = 1;
|
||||
for (uint32_t powerIdx = 0; powerIdx < iterateTime; powerIdx++) {
|
||||
iteratOffset = iteratOffset * 2;
|
||||
}
|
||||
uint32_t addLength = vocabSize_ - iteratOffset;
|
||||
uint32_t innerLoopNum = addLength / softmaxLength;
|
||||
uint32_t dataTail = addLength - innerLoopNum * softmaxLength;
|
||||
loopDataNum = softmaxLength;
|
||||
for (uint32_t innerLoopIdx = 0; innerLoopIdx < innerLoopNum; innerLoopIdx++) {
|
||||
// Copy data from right
|
||||
int64_t loopInnerOffset = dataTail + (innerLoopNum - 1 - innerLoopIdx) * softmaxLength;
|
||||
DataCopyPad(cumSumInput1Local, softMaxGm[baseGmIdx_ + loopInnerOffset],
|
||||
{1, static_cast<uint32_t>(loopDataNum * sizeof(float)), 0, 0, 0}, {false, 0, 0, 0});
|
||||
DataCopyPad(cumSumInput2Local, softMaxGm[baseGmIdx_ + loopInnerOffset + iteratOffset],
|
||||
{1, static_cast<uint32_t>(loopDataNum * sizeof(float)), 0, 0, 0}, {false, 0, 0, 0});
|
||||
MTE2ToVSync();
|
||||
Add(cumSumInput1Local, cumSumInput1Local, cumSumInput2Local, loopDataNum);
|
||||
VToMTE3Sync();
|
||||
DataCopyPad(softMaxGm[baseGmIdx_ + loopInnerOffset + iteratOffset], cumSumInput1Local,
|
||||
{1, static_cast<uint32_t>(loopDataNum * sizeof(float)), 0, 0, 0});
|
||||
MTE3ToMTE2Sync();
|
||||
}
|
||||
if (dataTail > 0) {
|
||||
loopDataNum = dataTail;
|
||||
DataCopyPad(cumSumInput1Local, softMaxGm[baseGmIdx_],
|
||||
{1, static_cast<uint32_t>(loopDataNum * sizeof(float)), 0, 0, 0}, {false, 0, 0, 0});
|
||||
DataCopyPad(cumSumInput2Local, softMaxGm[baseGmIdx_ + iteratOffset],
|
||||
{1, static_cast<uint32_t>(loopDataNum * sizeof(float)), 0, 0, 0}, {false, 0, 0, 0});
|
||||
MTE2ToVSync();
|
||||
Add(cumSumInput1Local, cumSumInput1Local, cumSumInput2Local, loopDataNum);
|
||||
VToMTE3Sync();
|
||||
DataCopyPad(softMaxGm[baseGmIdx_ + iteratOffset], cumSumInput1Local,
|
||||
{1, static_cast<uint32_t>(loopDataNum * sizeof(float)), 0, 0, 0});
|
||||
MTE3ToMTE2Sync();
|
||||
}
|
||||
}
|
||||
MTE3ToVSync();
|
||||
}
|
||||
|
||||
template <typename inputT, typename calT, typename outputT>
|
||||
__aicore__ inline void ApplyTopPCustom<inputT, calT, outputT>::ReduceSumWithAddsAndExpImpl(
|
||||
uint32_t loopDataNum) {
|
||||
Adds(softMaxResLocal, softMaxLocalFp32, maxValue, loopDataNum);
|
||||
PipeBarrier<PIPE_V>();
|
||||
Exp(softMaxResLocal, softMaxResLocal, loopDataNum);
|
||||
PipeBarrier<PIPE_V>();
|
||||
ReduceSum(reduceLocal, softMaxResLocal, reduceLocal, loopDataNum);
|
||||
}
|
||||
|
||||
template <typename inputT, typename calT, typename outputT>
|
||||
__aicore__ inline void ApplyTopPCustom<inputT, calT, outputT>::GetSoftmaxSum(uint32_t loopBatch) {
|
||||
uint32_t loopDataNum = softmaxLength;
|
||||
for (int32_t loopInner = 0; loopInner < lineSfLoopTimes; loopInner++) {
|
||||
int64_t currentGmIdx = baseGmIdx_ + loopInner * softmaxLength;
|
||||
if (loopInner == (lineSfLoopTimes - 1)) {
|
||||
loopDataNum = softmaxLengthTail;
|
||||
}
|
||||
if constexpr (IsSameType<inputT, float>::value) {
|
||||
Duplicate(outInfLocal.template ReinterpretCast<int32_t>(), FLOAT32_NEG_INF, loopDataNum);
|
||||
} else if constexpr (IsSameType<inputT, half>::value) {
|
||||
Duplicate(outInfLocal.template ReinterpretCast<uint16_t>(), FLOAT16_NEG_INF, loopDataNum);
|
||||
} else {
|
||||
Duplicate(outInfLocal.template ReinterpretCast<uint16_t>(), BF16_NEG_INF, loopDataNum);
|
||||
}
|
||||
VToMTE3Sync();
|
||||
DataCopyPad(mGmOut_[currentGmIdx], outInfLocal,
|
||||
{1, static_cast<uint32_t>(loopDataNum * sizeof(inputT)), 0, 0, 0});
|
||||
MTE3ToMTE2Sync();
|
||||
if constexpr (!IsSameType<inputT, float>::value) {
|
||||
DataCopyPad(softMaxLocal, mGmSortedValue_[currentGmIdx],
|
||||
{1, static_cast<uint32_t>(loopDataNum * sizeof(inputT)), 0, 0, 0},
|
||||
{false, 0, 0, 0});
|
||||
MTE2ToVSync();
|
||||
Cast(softMaxLocalFp32, softMaxLocal, RoundMode::CAST_NONE, loopDataNum);
|
||||
PipeBarrier<PIPE_V>();
|
||||
} else {
|
||||
DataCopyPad(softMaxLocalFp32, mGmSortedValue_[currentGmIdx],
|
||||
{1, static_cast<uint32_t>(loopDataNum * sizeof(inputT)), 0, 0, 0},
|
||||
{false, 0, 0, 0});
|
||||
MTE2ToVSync();
|
||||
}
|
||||
|
||||
ReduceSumWithAddsAndExpImpl(loopDataNum);
|
||||
VToSSync();
|
||||
// Sum up to obtain the sum of exp reduce for the first x loops in the row.
|
||||
reduceSumValue += reduceLocal.GetValue(0);
|
||||
SToVSync();
|
||||
}
|
||||
reduceSumValueInvert = 1 / reduceSumValue;
|
||||
SToVSync();
|
||||
}
|
||||
} // namespace
|
||||
|
||||
#endif // APPLY_TOP_P_CUSTOM_H_KERNEL
|
||||
@@ -24,7 +24,7 @@ elif [[ "$SOC_VERSION" =~ ^ascend910b ]]; then
|
||||
ABSOLUTE_CATLASS_PATH=$(cd "${CATLASS_PATH}" && pwd)
|
||||
export CPATH=${ABSOLUTE_CATLASS_PATH}:${CPATH}
|
||||
|
||||
CUSTOM_OPS="grouped_matmul_swiglu_quant_weight_nz_tensor_list;lightning_indexer;sparse_flash_attention;matmul_allreduce_add_rmsnorm;moe_init_routing_custom;moe_gating_top_k;add_rms_norm_bias;"
|
||||
CUSTOM_OPS="grouped_matmul_swiglu_quant_weight_nz_tensor_list;lightning_indexer;sparse_flash_attention;matmul_allreduce_add_rmsnorm;moe_init_routing_custom;moe_gating_top_k;add_rms_norm_bias;apply_top_k_top_p_custom;"
|
||||
SOC_ARG="ascend910b"
|
||||
elif [[ "$SOC_VERSION" =~ ^ascend910_93 ]]; then
|
||||
# ASCEND910C (A3) series
|
||||
@@ -80,6 +80,7 @@ elif [[ "$SOC_VERSION" =~ ^ascend910_93 ]]; then
|
||||
"moe_init_routing_custom"
|
||||
"moe_gating_top_k"
|
||||
"add_rms_norm_bias"
|
||||
"apply_top_k_top_p_custom"
|
||||
)
|
||||
CUSTOM_OPS=$(IFS=';'; echo "${CUSTOM_OPS_ARRAY[*]}")
|
||||
SOC_ARG="ascend910_93"
|
||||
|
||||
@@ -1234,6 +1234,26 @@ std::tuple<at::Tensor, at::Tensor, at::Tensor, at::Tensor> npu_moe_init_routing_
|
||||
return std::tie(expanded_x, expanded_row_idx, expert_tokens_count_or_cumsum, expanded_scale);
|
||||
}
|
||||
|
||||
at::Tensor npu_apply_top_k_top_p(
|
||||
const at::Tensor& logits,
|
||||
const c10::optional<at::Tensor>& p,
|
||||
const c10::optional<at::Tensor>& k)
|
||||
{
|
||||
TORCH_CHECK(p.has_value() || k.has_value(),
|
||||
"apply_top_k_top_p: p and k cannot be None at the same time.");
|
||||
|
||||
at::Tensor out = at::empty_like(logits);
|
||||
|
||||
EXEC_NPU_CMD(
|
||||
aclnnApplyTopKTopPCustom,
|
||||
logits,
|
||||
p,
|
||||
k,
|
||||
out);
|
||||
|
||||
return out;
|
||||
}
|
||||
|
||||
std::tuple<at::Tensor, at::Tensor, at::Tensor> moe_gating_top_k(
|
||||
const at::Tensor& x,
|
||||
int64_t k,
|
||||
@@ -1495,4 +1515,7 @@ TORCH_LIBRARY_EXPAND(CONCAT(_C, _ascend), ops)
|
||||
"-> (Tensor y ,Tensor rstd, Tensor x)"
|
||||
);
|
||||
ops.impl("npu_add_rms_norm_bias", torch::kPrivateUse1, &vllm_ascend::npu_add_rms_norm_bias);
|
||||
|
||||
ops.def("npu_apply_top_k_top_p(Tensor logits, Tensor? p=None, Tensor? k=None) -> Tensor");
|
||||
ops.impl("npu_apply_top_k_top_p", torch::kPrivateUse1, &vllm_ascend::npu_apply_top_k_top_p);
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user