[feature] add_rms_norm support bias (#5790)
### What this PR does / why we need it?
This PR replaces the separate addRmsNorm and Add operations with a single fused addRmsNormBias operation, which yields a more efficient result.
### Does this PR introduce _any_ user-facing change?
No
### How was this patch tested?
Full Test Pass
- vLLM version: v0.13.0
- vLLM main:
2f4e6548ef
Signed-off-by: Chen_HaoWen <chenhaowen12@huawei.com>
Co-authored-by: Chen_HaoWen <chenhaowen12@huawei.com>
This commit is contained in:
72
csrc/add_rms_norm_bias/op_kernel/add_rms_norm_bias.cpp
Normal file
72
csrc/add_rms_norm_bias/op_kernel/add_rms_norm_bias.cpp
Normal file
@@ -0,0 +1,72 @@
|
||||
/**
|
||||
* This program is free software, you can redistribute it and/or modify.
|
||||
* Copyright (c) 2025 Huawei Technologies Co., Ltd.
|
||||
* This file is a part of the CANN Open Software.
|
||||
* Licensed under CANN Open Software License Agreement Version 2.0 (the "License").
|
||||
* Please refer to the License for details. You may not use this file except in compliance with the License.
|
||||
* THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE.
|
||||
* See LICENSE in the root of the software repository for the full text of the License.
|
||||
*/
|
||||
|
||||
/*!
|
||||
* \file add_rms_norm_bias.cpp
|
||||
* \brief
|
||||
*/
|
||||
#include "add_rms_norm_bias.h"
|
||||
#include "add_rms_norm_bias_split_d.h"
|
||||
#include "add_rms_norm_bias_merge_n.h"
|
||||
#include "add_rms_norm_bias_multi_n.h"
|
||||
#include "add_rms_norm_bias_single_n.h"
|
||||
|
||||
using namespace AscendC;
|
||||
|
||||
// Instantiate the requested kernel template with the given dtype argument(s),
// bind it to the shared TPipe, then run Init + Process.
// Wrapped in do { } while (0) so the macro expands to a single statement
// inside the if/else dispatch chain below.
// Relies on names from the enclosing kernel scope: pipe, x1, x2, gamma, beta,
// y, rstd, x and tilingData.
#define GENERAL_OP_IMPL(templateClass, ...)                        \
    do {                                                           \
        templateClass<__VA_ARGS__> op(&pipe);                      \
        op.Init(x1, x2, gamma, beta, y, rstd, x, &tilingData);     \
        op.Process();                                              \
    } while (0)
|
||||
|
||||
// Kernel entry for AddRmsNormBias.
// Computes x = x1 + x2, rstd = 1/sqrt(mean(x^2) + epsilon) and
// y = x * rstd * gamma (+ beta when beta is present); x, rstd and y are all
// written back to global memory (see the kernel classes' Process()).
//
// The tiling key encodes dtype and kernel variant:
//   tens digit — dtype:   1 = half, 2 = float, 3 = bfloat16
//   ones digit — variant: 0 = base (row-at-a-time), 1 = SplitD, 2 = MergeN,
//                         3 = SingleN, 4 = MultiN
// bfloat16 branches are compiled out on __NPU_ARCH__ == 3003.
// NOTE(review): there is no float MultiN case (key 24) — presumably the tiling
// never emits that key; confirm against the tiling implementation.
extern "C" __global__ __aicore__ void add_rms_norm_bias(
    GM_ADDR x1, GM_ADDR x2, GM_ADDR gamma, GM_ADDR beta, GM_ADDR y, GM_ADDR rstd, GM_ADDR x, GM_ADDR workspace, GM_ADDR tiling)
{
    TPipe pipe;
    GET_TILING_DATA(tilingData, tiling);
    if (TILING_KEY_IS(10)) {
        GENERAL_OP_IMPL(KernelAddRmsNormBias, half);
    } else if (TILING_KEY_IS(20)) {
        GENERAL_OP_IMPL(KernelAddRmsNormBias, float);
    } else if (TILING_KEY_IS(30)) {
#if !(defined(__NPU_ARCH__) && __NPU_ARCH__ == 3003)
        GENERAL_OP_IMPL(KernelAddRmsNormBias, bfloat16_t);
#endif
    } else if (TILING_KEY_IS(11)) {
        GENERAL_OP_IMPL(KernelAddRmsNormBiasSplitD, half);
    } else if (TILING_KEY_IS(21)) {
        GENERAL_OP_IMPL(KernelAddRmsNormBiasSplitD, float);
    } else if (TILING_KEY_IS(31)) {
#if !(defined(__NPU_ARCH__) && __NPU_ARCH__ == 3003)
        GENERAL_OP_IMPL(KernelAddRmsNormBiasSplitD, bfloat16_t);
#endif
    } else if (TILING_KEY_IS(12)) {
        GENERAL_OP_IMPL(KernelAddRmsNormBiasMergeN, half);
    } else if (TILING_KEY_IS(22)) {
        GENERAL_OP_IMPL(KernelAddRmsNormBiasMergeN, float);
    } else if (TILING_KEY_IS(32)) {
#if !(defined(__NPU_ARCH__) && __NPU_ARCH__ == 3003)
        GENERAL_OP_IMPL(KernelAddRmsNormBiasMergeN, bfloat16_t);
#endif
    } else if (TILING_KEY_IS(13)) {
        GENERAL_OP_IMPL(KernelAddRmsNormBiasSingleN, half);
    } else if (TILING_KEY_IS(23)) {
        GENERAL_OP_IMPL(KernelAddRmsNormBiasSingleN, float);
    } else if (TILING_KEY_IS(33)) {
#if !(defined(__NPU_ARCH__) && __NPU_ARCH__ == 3003)
        GENERAL_OP_IMPL(KernelAddRmsNormBiasSingleN, bfloat16_t);
#endif
    } else if (TILING_KEY_IS(14)) {
        GENERAL_OP_IMPL(KernelAddRmsNormBiasMultiN, half);
    } else if (TILING_KEY_IS(34)) {
        GENERAL_OP_IMPL(KernelAddRmsNormBiasMultiN, bfloat16_t);
    }
}
|
||||
368
csrc/add_rms_norm_bias/op_kernel/add_rms_norm_bias.h
Normal file
368
csrc/add_rms_norm_bias/op_kernel/add_rms_norm_bias.h
Normal file
@@ -0,0 +1,368 @@
|
||||
/**
|
||||
* This program is free software, you can redistribute it and/or modify.
|
||||
* Copyright (c) 2025 Huawei Technologies Co., Ltd.
|
||||
* This file is a part of the CANN Open Software.
|
||||
* Licensed under CANN Open Software License Agreement Version 2.0 (the "License").
|
||||
* Please refer to the License for details. You may not use this file except in compliance with the License.
|
||||
* THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE.
|
||||
* See LICENSE in the root of the software repository for the full text of the License.
|
||||
*/
|
||||
|
||||
/*!
|
||||
* \file add_rms_norm_bias.h
|
||||
* \brief add rms norm bias file
|
||||
*/
|
||||
// Include guard renamed from ADD_RMS_NORM_H_ to ADD_RMS_NORM_BIAS_H_ so it
// matches this file (add_rms_norm_bias.h) and the closing
// "#endif  // ADD_RMS_NORM_BIAS_H_" comment. The old name can collide with the
// guard of the original (non-bias) add_rms_norm header, silently skipping one
// of the two headers when both are included.
#ifndef ADD_RMS_NORM_BIAS_H_
#define ADD_RMS_NORM_BIAS_H_
#include "./rms_norm_base.h"
|
||||
|
||||
using namespace AscendC;
|
||||
using namespace RmsNorm;
|
||||
|
||||
/**
 * Base AddRmsNormBias kernel: processes this core's slice of rows one row at
 * a time. For each row it computes x = x1 + x2,
 * rstd = 1 / sqrt(mean(x^2) + epsilon) and y = x * rstd * gamma (+ beta when
 * present), writing x, rstd and y back to global memory.
 * T is the I/O dtype (half / float / bfloat16_t); the 16-bit paths do the
 * normalization arithmetic in fp32 scratch buffers.
 */
template <typename T>
class KernelAddRmsNormBias {
public:
    // The TPipe is owned by the kernel entry function and shared by all buffers.
    __aicore__ inline KernelAddRmsNormBias(TPipe* pipe)
    {
        Ppipe = pipe;
    }
    // Bind global-memory views for this core's row slice and allocate the UB
    // queues/scratch buffers. All sizes and the per-core row split come from
    // the tiling data.
    __aicore__ inline void Init(
        GM_ADDR x1, GM_ADDR x2, GM_ADDR gamma, GM_ADDR beta, GM_ADDR y, GM_ADDR rstd, GM_ADDR x, const AddRMSNormBiasTilingData* tiling)
    {
        ASSERT(GetBlockNum() != 0 && "Block dim can not be zero!");
        this->numRow = tiling->num_row;
        this->numCol = tiling->num_col;
        this->blockFactor = tiling->block_factor;
        this->rowFactor = tiling->row_factor;
        this->ubFactor = tiling->ub_factor;
        this->epsilon = tiling->epsilon;
        // 1/numCol turns the per-row sum of squares into a mean; guarded so a
        // zero column count from tiling cannot divide by zero.
        this->avgFactor = (numCol != 0) ? (float)1.0 / numCol : 0;
        // Non-zero when the optional beta input is absent; all beta work is
        // then skipped.
        this->nullptrBeta = tiling->nullptr_beta;

        blockIdx_ = GetBlockIdx();
        // Last core takes the remainder rows; every other core takes blockFactor.
        if (blockIdx_ < GetBlockNum() - 1) {
            this->rowWork = blockFactor;
        } else if (blockIdx_ == GetBlockNum() - 1) {
            this->rowWork = numRow - (GetBlockNum() - 1) * blockFactor;
        }
        // get start index for current core, core parallel
        x1Gm.SetGlobalBuffer((__gm__ T*)x1 + blockIdx_ * blockFactor * numCol, rowWork * numCol);
        x2Gm.SetGlobalBuffer((__gm__ T*)x2 + blockIdx_ * blockFactor * numCol, rowWork * numCol);
        gammaGm.SetGlobalBuffer((__gm__ T*)gamma, numCol);
        if (!this->nullptrBeta) {
            betaGm.SetGlobalBuffer((__gm__ T*)beta, numCol);
        }
        yGm.SetGlobalBuffer((__gm__ T*)y + blockIdx_ * blockFactor * numCol, rowWork * numCol);
        // NOTE(review): the rstd view is sized with blockFactor, not rowWork —
        // on the last core the view can extend past its slice; confirm the rstd
        // GM tensor is allocated with a full blockFactor per core.
        rstdGm.SetGlobalBuffer((__gm__ float*)rstd + blockIdx_ * blockFactor, blockFactor);
        xGm.SetGlobalBuffer((__gm__ T*)x + blockIdx_ * blockFactor * numCol, rowWork * numCol);

        // pipe alloc memory to queue, the unit is Bytes
        Ppipe->InitBuffer(inQueueX, BUFFER_NUM, ubFactor * sizeof(T));
        Ppipe->InitBuffer(inQueueGamma, BUFFER_NUM, ubFactor * sizeof(T));
        if (!this->nullptrBeta) {
            Ppipe->InitBuffer(inQueueBeta, BUFFER_NUM, ubFactor * sizeof(T));
        }
        Ppipe->InitBuffer(outQueueY, BUFFER_NUM, ubFactor * sizeof(T));
        Ppipe->InitBuffer(outQueueRstd, BUFFER_NUM, rowFactor * sizeof(float));

        // fp32 scratch for the row data is only needed for 16-bit I/O dtypes.
        if constexpr (is_same<T, half>::value || is_same<T, bfloat16_t>::value) {
            Ppipe->InitBuffer(xFp32Buf, ubFactor * sizeof(float));
        }
        Ppipe->InitBuffer(sqxBuf, ubFactor * sizeof(float));
        Ppipe->InitBuffer(reduceFp32Buf, NUM_PER_REP_FP32 * sizeof(float));
    }

    // Main loop: stage gamma/beta once, then walk this core's rows in chunks of
    // rowFactor; the final chunk holds the remainder rows.
    __aicore__ inline void Process()
    {
        CopyInGammaBeta();
        LocalTensor<T> gammaLocal = inQueueGamma.DeQue<T>();
        // Left unbound when beta is absent; SubProcess/Compute only touch it
        // when nullptrBeta is false.
        LocalTensor<T> betaLocal;
        if (!this->nullptrBeta) {
            betaLocal = inQueueBeta.DeQue<T>();
        }
        uint32_t i_o_max = RmsNorm::CeilDiv(rowWork, rowFactor);
        uint32_t row_tail = rowWork - (i_o_max - 1) * rowFactor;

        for (uint32_t i_o = 0; i_o < i_o_max - 1; i_o++) {
            SubProcess(i_o, rowFactor, gammaLocal, betaLocal);
        }
        SubProcess(i_o_max - 1, row_tail, gammaLocal, betaLocal);
        inQueueGamma.FreeTensor(gammaLocal);
        if (!this->nullptrBeta) {
            inQueueBeta.FreeTensor(betaLocal);
        }
    }

    // Process calc_row_num rows of chunk i_o one row at a time; per-row rstd
    // values are batched in one UB tensor and copied out once per chunk.
    __aicore__ inline void SubProcess(uint32_t i_o, uint32_t calc_row_num, LocalTensor<T>& gammaLocal, LocalTensor<T>& betaLocal)
    {
        LocalTensor<float> rstdLocal = outQueueRstd.AllocTensor<float>();
        for (uint32_t i_i = 0; i_i < calc_row_num; i_i++) {
            uint32_t gm_bias = (i_o * rowFactor + i_i) * numCol;
            CopyIn(gm_bias);
            Compute(i_i, gammaLocal, betaLocal, rstdLocal);
            CopyOutY(gm_bias);
        }
        outQueueRstd.EnQue<float>(rstdLocal);
        CopyOutRstd(i_o, calc_row_num);
    }

private:
    // Load one row of x1 and x2, form x = x1 + x2 (in fp32 for the 16-bit
    // dtypes) and immediately copy the sum out through xGm.
    __aicore__ inline void CopyIn(uint32_t gm_bias)
    {
        LocalTensor<T> x1Local_in = inQueueX.AllocTensor<T>();
        LocalTensor<T> x2Local = sqxBuf.Get<T>();
        LocalTensor<T> xLocal = outQueueY.AllocTensor<T>();

        // For 16-bit dtypes, x2 is staged in the upper half of sqxBuf so the
        // lower half remains free for its fp32 copy.
        if constexpr (is_same<T, half>::value || is_same<T, bfloat16_t>::value) {
            x2Local = x2Local[ubFactor];
        }

        DataCopyCustom<T>(x1Local_in, x1Gm[gm_bias], numCol);
        DataCopyCustom<T>(x2Local, x2Gm[gm_bias], numCol);
        inQueueX.EnQue(x1Local_in);
        auto x1Local = inQueueX.DeQue<T>();

        if constexpr (is_same<T, half>::value) {
            // half: add in half precision, then keep an fp32 copy of the sum in
            // xFp32Buf for the normalization math in Compute().
            LocalTensor<float> x1_fp32 = xFp32Buf.Get<float>();
            Add(xLocal, x1Local, x2Local, numCol);
            PipeBarrier<PIPE_V>();
            Cast(x1_fp32, xLocal, RoundMode::CAST_NONE, numCol);
            PipeBarrier<PIPE_V>();
        } else if constexpr (is_same<T, bfloat16_t>::value) {
            // bfloat16: cast both inputs up and add in fp32 to limit rounding
            // error; the bf16 x output is round-to-nearest-even of that sum.
            LocalTensor<float> x1_fp32 = xFp32Buf.Get<float>();
            LocalTensor<float> x2_fp32 = sqxBuf.Get<float>();
            Cast(x1_fp32, x1Local, RoundMode::CAST_NONE, numCol);
            Cast(x2_fp32, x2Local, RoundMode::CAST_NONE, numCol);
            PipeBarrier<PIPE_V>();
            Add(x1_fp32, x1_fp32, x2_fp32, numCol);
            PipeBarrier<PIPE_V>();
            Cast(xLocal, x1_fp32, RoundMode::CAST_RINT, numCol);
            PipeBarrier<PIPE_V>();
        } else {
            // float: keep the sum in x1Local's buffer; Adds(+0) duplicates it
            // into xLocal for the x output.
            Add(x1Local, x1Local, x2Local, numCol);
            PipeBarrier<PIPE_V>();
            Adds(xLocal, x1Local, (float)0, numCol);
        }
        inQueueX.FreeTensor(x1Local);

        // CopyOut x1 + x2
        outQueueY.EnQue(xLocal);
        auto x_out = outQueueY.DeQue<T>();
        DataCopyCustom<T>(xGm[gm_bias], x_out, numCol);
        outQueueY.FreeTensor(x_out);
    }

    // Stage gamma (and beta, when present) once for the whole kernel run.
    __aicore__ inline void CopyInGammaBeta()
    {
        LocalTensor<T> gammaLocal = inQueueGamma.AllocTensor<T>();
        DataCopyCustom<T>(gammaLocal, gammaGm, numCol);
        inQueueGamma.EnQue(gammaLocal);
        if (!this->nullptrBeta) {
            LocalTensor<T> betaLocal = inQueueBeta.AllocTensor<T>();
            DataCopyCustom<T>(betaLocal, betaGm, numCol);
            inQueueBeta.EnQue(betaLocal);
        }
    }

    // float path. Normalizes the row left behind in the inQueueX buffer by
    // CopyIn. NOTE(review): the AllocTensor here appears to rely on getting
    // back the same UB block freed at the end of CopyIn, which still holds
    // x1 + x2 — confirm this queue-reuse guarantee.
    __aicore__ inline void Compute(uint32_t inner_progress, LocalTensor<float> gammaLocal, LocalTensor<float> betaLocal, LocalTensor<float> rstdLocal)
    {
        LocalTensor<float> xLocal = inQueueX.AllocTensor<float>();
        LocalTensor<float> sqx = sqxBuf.Get<float>();
        LocalTensor<float> reduce_buf_local = reduceFp32Buf.Get<float>();
        // mean(x^2): square, scale by 1/numCol, then reduce to a scalar.
        Mul(sqx, xLocal, xLocal, numCol);
        PipeBarrier<PIPE_V>();

        Muls(sqx, sqx, avgFactor, numCol);
        PipeBarrier<PIPE_V>();

        ReduceSumCustom(sqx, sqx, reduce_buf_local, numCol);
        PipeBarrier<PIPE_V>();
        Adds(sqx, sqx, epsilon, 1);
        PipeBarrier<PIPE_V>();

        // rstd = 1 / sqrt(mean + epsilon)
        Sqrt(sqx, sqx, 1);
        Duplicate(reduce_buf_local, ONE, 1);
        PipeBarrier<PIPE_V>();
        Div(sqx, reduce_buf_local, sqx, 1);
        PipeBarrier<PIPE_V>();
        // V->S then S->V handshake around reading the scalar on the scalar unit.
        event_t event_v_s = static_cast<event_t>(GetTPipePtr()->FetchEventID(HardEvent::V_S));
        SetFlag<HardEvent::V_S>(event_v_s);
        WaitFlag<HardEvent::V_S>(event_v_s);
        float rstdValue = sqx.GetValue(0);
        event_t event_s_v = static_cast<event_t>(GetTPipePtr()->FetchEventID(HardEvent::S_V));
        SetFlag<HardEvent::S_V>(event_s_v);
        WaitFlag<HardEvent::S_V>(event_s_v);
        rstdLocal.SetValue(inner_progress, rstdValue);
        PipeBarrier<PIPE_V>();
        LocalTensor<float> yLocal = outQueueY.AllocTensor<float>();
        // y = x * rstd * gamma (+ beta)
        Muls(yLocal, xLocal, rstdValue, numCol);
        inQueueX.FreeTensor(xLocal);
        PipeBarrier<PIPE_V>();
        Mul(yLocal, gammaLocal, yLocal, numCol);
        if (!this->nullptrBeta) {
            PipeBarrier<PIPE_V>();
            Add(yLocal, betaLocal, yLocal, numCol);
        }
        PipeBarrier<PIPE_V>();
        outQueueY.EnQue<float>(yLocal);
    }

    // bfloat16 path. The fp32 copy of x = x1 + x2 already lives in xFp32Buf
    // (written by CopyIn).
    __aicore__ inline void Compute(
        uint32_t inner_progress, LocalTensor<bfloat16_t> gammaLocal, LocalTensor<bfloat16_t> betaLocal, LocalTensor<float> rstdLocal)
    {
        LocalTensor<float> x_fp32 = xFp32Buf.Get<float>();
        LocalTensor<float> sqx = sqxBuf.Get<float>();
        LocalTensor<float> reduce_buf_local = reduceFp32Buf.Get<float>();

        // mean(x^2): square, scale by 1/numCol, reduce to a scalar.
        Mul(sqx, x_fp32, x_fp32, numCol);
        PipeBarrier<PIPE_V>();

        Muls(sqx, sqx, avgFactor, numCol);
        PipeBarrier<PIPE_V>();
        ReduceSumCustom(sqx, sqx, reduce_buf_local, numCol);
        PipeBarrier<PIPE_V>();

        Adds(sqx, sqx, epsilon, 1);
        PipeBarrier<PIPE_V>();

        // rstd = 1 / sqrt(mean + epsilon)
        Sqrt(sqx, sqx, 1);
        Duplicate(reduce_buf_local, ONE, 1);
        PipeBarrier<PIPE_V>();
        Div(sqx, reduce_buf_local, sqx, 1);
        PipeBarrier<PIPE_V>();
        // V->S then S->V handshake around reading the scalar on the scalar unit.
        event_t event_v_s = static_cast<event_t>(GetTPipePtr()->FetchEventID(HardEvent::V_S));
        SetFlag<HardEvent::V_S>(event_v_s);
        WaitFlag<HardEvent::V_S>(event_v_s);
        float rstdValue = sqx.GetValue(0);
        event_t event_s_v = static_cast<event_t>(GetTPipePtr()->FetchEventID(HardEvent::S_V));
        SetFlag<HardEvent::S_V>(event_s_v);
        WaitFlag<HardEvent::S_V>(event_s_v);
        rstdLocal.SetValue(inner_progress, rstdValue);
        PipeBarrier<PIPE_V>();
        Muls(x_fp32, x_fp32, rstdValue, numCol);
        PipeBarrier<PIPE_V>();
        LocalTensor<bfloat16_t> yLocal = outQueueY.AllocTensor<bfloat16_t>();
        // Round-trip x*rstd through bf16 before applying gamma/beta, so the
        // result matches a bf16 store of the normalized value.
        Cast(yLocal, x_fp32, RoundMode::CAST_RINT, numCol);
        PipeBarrier<PIPE_V>();
        Cast(x_fp32, yLocal, RoundMode::CAST_NONE, numCol);
        PipeBarrier<PIPE_V>();
        Cast(sqx, gammaLocal, RoundMode::CAST_NONE, numCol); // gamma_fp32 reuse sqx
        PipeBarrier<PIPE_V>();
        Mul(x_fp32, x_fp32, sqx, numCol);
        if (!this->nullptrBeta) {
            PipeBarrier<PIPE_V>();
            Cast(sqx, betaLocal, RoundMode::CAST_NONE, numCol);
            PipeBarrier<PIPE_V>();
            Add(x_fp32, x_fp32, sqx, numCol);
        }
        PipeBarrier<PIPE_V>();
        Cast(yLocal, x_fp32, RoundMode::CAST_RINT, numCol);
        PipeBarrier<PIPE_V>();

        // Ensure vector work completes before the next row's MTE2 loads reuse UB.
        event_t event_v_mte = static_cast<event_t>(GetTPipePtr()->FetchEventID(HardEvent::V_MTE2));
        SetFlag<HardEvent::V_MTE2>(event_v_mte);
        WaitFlag<HardEvent::V_MTE2>(event_v_mte);

        outQueueY.EnQue<bfloat16_t>(yLocal);
    }

    // half path. The fp32 copy of x = x1 + x2 already lives in xFp32Buf
    // (written by CopyIn); gamma/beta are applied in half precision.
    __aicore__ inline void Compute(uint32_t inner_progress, LocalTensor<half> gammaLocal, LocalTensor<half> betaLocal, LocalTensor<float> rstdLocal)
    {
        LocalTensor<float> x_fp32 = xFp32Buf.Get<float>();
        LocalTensor<float> sqx = sqxBuf.Get<float>();
        LocalTensor<float> reduce_buf_local = reduceFp32Buf.Get<float>();

        // mean(x^2): square, scale by 1/numCol, reduce to a scalar.
        Mul(sqx, x_fp32, x_fp32, numCol);
        PipeBarrier<PIPE_V>();

        Muls(sqx, sqx, avgFactor, numCol);
        PipeBarrier<PIPE_V>();

        ReduceSumCustom(sqx, sqx, reduce_buf_local, numCol);
        PipeBarrier<PIPE_V>();

        Adds(sqx, sqx, epsilon, 1);
        PipeBarrier<PIPE_V>();

        // rstd = 1 / sqrt(mean + epsilon)
        Sqrt(sqx, sqx, 1);
        Duplicate(reduce_buf_local, ONE, 1);
        PipeBarrier<PIPE_V>();
        Div(sqx, reduce_buf_local, sqx, 1);
        PipeBarrier<PIPE_V>();
        // V->S then S->V handshake around reading the scalar on the scalar unit.
        event_t event_v_s = static_cast<event_t>(GetTPipePtr()->FetchEventID(HardEvent::V_S));
        SetFlag<HardEvent::V_S>(event_v_s);
        WaitFlag<HardEvent::V_S>(event_v_s);
        float rstdValue = sqx.GetValue(0);
        event_t event_s_v = static_cast<event_t>(GetTPipePtr()->FetchEventID(HardEvent::S_V));
        SetFlag<HardEvent::S_V>(event_s_v);
        WaitFlag<HardEvent::S_V>(event_s_v);
        rstdLocal.SetValue(inner_progress, rstdValue);
        PipeBarrier<PIPE_V>();
        Muls(x_fp32, x_fp32, rstdValue, numCol);
        PipeBarrier<PIPE_V>();
        LocalTensor<half> yLocal = outQueueY.AllocTensor<half>();
        Cast(yLocal, x_fp32, RoundMode::CAST_NONE, numCol);

        // Ensure vector work completes before the next row's MTE2 loads reuse UB.
        event_t event_v_mte = static_cast<event_t>(GetTPipePtr()->FetchEventID(HardEvent::V_MTE2));
        SetFlag<HardEvent::V_MTE2>(event_v_mte);
        WaitFlag<HardEvent::V_MTE2>(event_v_mte);

        PipeBarrier<PIPE_V>();
        Mul(yLocal, gammaLocal, yLocal, numCol);
        if (!this->nullptrBeta) {
            PipeBarrier<PIPE_V>();
            Add(yLocal, betaLocal, yLocal, numCol);
        }
        PipeBarrier<PIPE_V>();
        outQueueY.EnQue<half>(yLocal);
    }

    // Drain the y queue for one row and copy it to global memory.
    __aicore__ inline void CopyOutY(uint32_t progress)
    {
        LocalTensor<T> yLocal = outQueueY.DeQue<T>();
        DataCopyCustom<T>(yGm[progress], yLocal, numCol);
        outQueueY.FreeTensor(yLocal);
    }

    // Copy the batched per-row rstd values for chunk outer_progress to global
    // memory. The copy is compiled only on 220 / 3003 targets; on other
    // targets rstd is not written (presumably not required there — confirm).
    __aicore__ inline void CopyOutRstd(uint32_t outer_progress, uint32_t num)
    {
        LocalTensor<float> rstdLocal = outQueueRstd.DeQue<float>();
#if __CCE_AICORE__ == 220 || (defined(__NPU_ARCH__) && __NPU_ARCH__ == 3003)
        DataCopyCustom<float>(rstdGm[outer_progress * rowFactor], rstdLocal, num);
#endif
        outQueueRstd.FreeTensor(rstdLocal);
    }

private:
    TPipe* Ppipe = nullptr;
    // create queues for input, in this case depth is equal to buffer num
    TQue<QuePosition::VECIN, BUFFER_NUM> inQueueX;
    TQue<QuePosition::VECIN, BUFFER_NUM> inQueueGamma;
    TQue<QuePosition::VECIN, BUFFER_NUM> inQueueBeta;
    // create queues for output, in this case depth is equal to buffer num
    TQue<QuePosition::VECOUT, BUFFER_NUM> outQueueY;
    TQue<QuePosition::VECOUT, BUFFER_NUM> outQueueRstd;

    TBuf<TPosition::VECCALC> xFp32Buf;      // fp32 copy of x for 16-bit dtypes
    TBuf<TPosition::VECCALC> sqxBuf;        // x^2 scratch; also stages x2 on load
    TBuf<TPosition::VECCALC> reduceFp32Buf; // reduction scratch
    GlobalTensor<T> x1Gm;
    GlobalTensor<T> x2Gm;
    GlobalTensor<T> gammaGm;
    GlobalTensor<T> betaGm;
    GlobalTensor<T> yGm;
    GlobalTensor<float> rstdGm;
    GlobalTensor<T> xGm;

    uint32_t numRow;
    uint32_t numCol;
    uint32_t blockFactor; // number of calculations rows on each core
    uint32_t rowFactor;   // rows per SubProcess chunk (rstd batch size)
    uint32_t ubFactor;    // per-buffer UB capacity in elements
    float epsilon;
    float avgFactor;      // 1 / numCol
    int32_t blockIdx_;
    uint32_t rowWork = 1; // rows assigned to this core
    uint32_t nullptrBeta = 0; // non-zero when beta is absent
};
|
||||
#endif // ADD_RMS_NORM_BIAS_H_
|
||||
471
csrc/add_rms_norm_bias/op_kernel/add_rms_norm_bias_merge_n.h
Normal file
471
csrc/add_rms_norm_bias/op_kernel/add_rms_norm_bias_merge_n.h
Normal file
@@ -0,0 +1,471 @@
|
||||
/**
|
||||
* This program is free software, you can redistribute it and/or modify.
|
||||
* Copyright (c) 2025 Huawei Technologies Co., Ltd.
|
||||
* This file is a part of the CANN Open Software.
|
||||
* Licensed under CANN Open Software License Agreement Version 2.0 (the "License").
|
||||
* Please refer to the License for details. You may not use this file except in compliance with the License.
|
||||
* THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE.
|
||||
* See LICENSE in the root of the software repository for the full text of the License.
|
||||
*/
|
||||
|
||||
/*!
|
||||
* \file add_rms_norm_bias_merge_n.h
|
||||
* \brief add rms norm bias merge n file
|
||||
*/
|
||||
#ifndef ADD_RMS_NORM_BIAS_MERGE_N_H_
|
||||
#define ADD_RMS_NORM_BIAS_MERGE_N_H_
|
||||
#include "./rms_norm_base.h"
|
||||
|
||||
using namespace AscendC;
|
||||
using namespace RmsNorm;
|
||||
|
||||
template <typename T>
|
||||
class KernelAddRmsNormBiasMergeN {
|
||||
public:
|
||||
// Bind the externally owned TPipe used for all subsequent UB buffer allocations.
__aicore__ inline KernelAddRmsNormBiasMergeN(TPipe* pipe) : Ppipe(pipe) {}
|
||||
// Bind global-memory views for this core's row slice and allocate UB queues.
// MergeN variant: rows are processed in batches (rowFactor per loop), so the
// tiling additionally supplies per-core loop/tail counts and the stride
// parameters consumed by repeatByRow/mulRepeat.
__aicore__ inline void Init(
    GM_ADDR x1, GM_ADDR x2, GM_ADDR gamma, GM_ADDR beta, GM_ADDR y, GM_ADDR rstd, GM_ADDR x, const AddRMSNormBiasTilingData* tiling)
{
    ASSERT(GetBlockNum() != 0 && "Block dim can not be zero!");
    this->numRow = tiling->num_row;
    this->numCol = tiling->num_col;
    this->numColAlign = tiling->num_col_align;
    this->blockFactor = tiling->block_factor;
    this->rowFactor = tiling->row_factor;
    this->ubFactor = tiling->ub_factor;
    this->epsilon = tiling->epsilon;
    this->avgFactor = tiling->avg_factor;

    blockIdx_ = GetBlockIdx();
    // Last core takes the tiling's dedicated remainder split; all other cores
    // share the common split.
    if (blockIdx_ < GetBlockNum() - 1) {
        this->rowWork = blockFactor;
        this->rowLoop = tiling->row_loop;
        this->rowTail = tiling->row_tail;
    } else if (blockIdx_ == GetBlockNum() - 1) {
        this->rowWork = tiling->last_block_factor;
        this->rowLoop = tiling->last_block_row_loop;
        this->rowTail = tiling->last_block_row_tail;
    }
    // Repeat/stride parameters for the row-broadcast multiplies (repeatByRow).
    this->mulLoopFp32 = tiling->mul_loop_fp32;
    this->mulTailFp32 = tiling->mul_tail_fp32;
    this->dstRepStrideFp32 = tiling->dst_rep_stride_fp32;
    this->mulLoopFp16 = tiling->mul_loop_fp16;
    this->mulTailFp16 = tiling->mul_tail_fp16;
    this->dstRepStrideFp16 = tiling->dst_rep_stride_fp16;
    this->isPerformance = tiling->is_performance;
    // Non-zero when the optional beta input is absent; all beta work is skipped.
    this->nullptrBeta = tiling->nullptr_beta;
    // get start index for current core, core parallel
    x1Gm.SetGlobalBuffer((__gm__ T*)x1 + blockIdx_ * blockFactor * numCol, rowWork * numCol);
    x2Gm.SetGlobalBuffer((__gm__ T*)x2 + blockIdx_ * blockFactor * numCol, rowWork * numCol);
    gammaGm.SetGlobalBuffer((__gm__ T*)gamma, numCol);
    if (!this->nullptrBeta) {
        betaGm.SetGlobalBuffer((__gm__ T*)beta, numCol);
    }
    yGm.SetGlobalBuffer((__gm__ T*)y + blockIdx_ * blockFactor * numCol, rowWork * numCol);
    rstdGm.SetGlobalBuffer((__gm__ float*)rstd + blockIdx_ * blockFactor, blockFactor);
    xGm.SetGlobalBuffer((__gm__ T*)x + blockIdx_ * blockFactor * numCol, rowWork * numCol);

    // pipe alloc memory to queue, the unit is Bytes
    Ppipe->InitBuffer(inQueueX, DOUBLE_BUFFER_NUM, ubFactor * sizeof(T));
    Ppipe->InitBuffer(inQueueGamma, BUFFER_NUM, ubFactor * sizeof(T));
    if (!this->nullptrBeta) {
        Ppipe->InitBuffer(inQueueBeta, BUFFER_NUM, ubFactor * sizeof(T));
    }
    Ppipe->InitBuffer(outQueueY, DOUBLE_BUFFER_NUM, ubFactor * sizeof(T));
    // rstd is queued (and copied out) only on 220 / 3003 targets; elsewhere it
    // stays in a plain UB scratch buffer.
#if defined(__CCE_AICORE__) && __CCE_AICORE__ == 220 || (defined(__NPU_ARCH__) && __NPU_ARCH__ == 3003)
    Ppipe->InitBuffer(outQueueRstd, BUFFER_NUM, rowFactor * sizeof(float));
#else
    Ppipe->InitBuffer(rstdBuf, rowFactor * sizeof(float));
#endif
    // fp32 scratch for the row data is only needed for 16-bit I/O dtypes.
    if constexpr (is_same<T, half>::value || is_same<T, bfloat16_t>::value) {
        Ppipe->InitBuffer(xFp32Buf, ubFactor * sizeof(float));
    }
    Ppipe->InitBuffer(sqxBuf, ubFactor * sizeof(float));
    Ppipe->InitBuffer(tmpBuf, rowFactor * NUM_PER_REP_FP32 * sizeof(float));
}
|
||||
|
||||
// Main loop: stage gamma/beta once, then process this core's rows in rowLoop
// batches of rowFactor rows each; the final batch holds rowTail rows.
// Assumes rowLoop >= 1 (guaranteed by tiling).
__aicore__ inline void Process()
{
    CopyInGammaBeta();
    LocalTensor<T> gammaLocal = inQueueGamma.DeQue<T>();
    // Left unbound when beta is absent; MainCompute only touches it when
    // nullptrBeta is false.
    LocalTensor<T> betaLocal;
    if (!this->nullptrBeta) {
        betaLocal = inQueueBeta.DeQue<T>();
    }
    for (uint32_t i_o = 0; i_o < rowLoop - 1; i_o++) {
        MainCompute(i_o, rowFactor, gammaLocal, betaLocal);
    }
    MainCompute(rowLoop - 1, rowTail, gammaLocal, betaLocal);
    inQueueGamma.FreeTensor(gammaLocal);
    if (!this->nullptrBeta) {
        inQueueBeta.FreeTensor(betaLocal);
    }
}
|
||||
|
||||
// Process one batch of calc_row_num rows: load x1/x2, form x = x1 + x2 and
// copy it out, compute per-row rstd, then apply gamma/beta and copy y out.
// On 220 / 3003 targets rstd also goes out to global memory; elsewhere it only
// lives in UB scratch for the y computation.
__aicore__ inline void MainCompute(uint32_t i_o, uint32_t calc_row_num, LocalTensor<T>& gammaLocal, LocalTensor<T>& betaLocal)
{
    uint32_t gm_bias = i_o * rowFactor * numCol;
    uint32_t elementNum = calc_row_num * numColAlign;
    CopyInX(gm_bias, calc_row_num);
    LocalTensor<T> xLocal = ComputeX(elementNum);
    CopyOutX(gm_bias, calc_row_num);
#if defined(__CCE_AICORE__) && __CCE_AICORE__ == 220 || (defined(__NPU_ARCH__) && __NPU_ARCH__ == 3003)
    LocalTensor<float> rstdLocal = outQueueRstd.AllocTensor<float>();
    ComputeRstd(xLocal, rstdLocal, calc_row_num, elementNum);
    outQueueRstd.EnQue<float>(rstdLocal);
    CopyOutRstd(i_o, calc_row_num);
#else
    LocalTensor<float> rstdLocal = rstdBuf.Get<float>();
    ComputeRstd(xLocal, rstdLocal, calc_row_num, elementNum);
#endif
    ComputeY(xLocal, gammaLocal, betaLocal, rstdLocal, calc_row_num, elementNum);
    CopyOutY(gm_bias, calc_row_num);
}
|
||||
|
||||
private:
|
||||
// Load calc_row_num rows of x1 and x2 into the (double-buffered) input queue.
// When numCol is not UB-aligned the row-wise DataCopyCustom overload is used.
// NOTE(review): isNumColAlign is a member not set in this view — presumably
// initialized from tiling or derived elsewhere in the class; confirm.
__aicore__ inline void CopyInX(uint32_t gm_bias, uint32_t calc_row_num)
{
    LocalTensor<T> x1Local = inQueueX.AllocTensor<T>();
    if (isNumColAlign) {
        DataCopyCustom<T>(x1Local, x1Gm[gm_bias], calc_row_num * numCol);
    } else {
        DataCopyCustom<T>(x1Local, x1Gm[gm_bias], calc_row_num, numCol);
    }
    inQueueX.EnQue(x1Local);
    LocalTensor<T> x2Local = inQueueX.AllocTensor<T>();
    if (isNumColAlign) {
        DataCopyCustom<T>(x2Local, x2Gm[gm_bias], calc_row_num * numCol);
    } else {
        DataCopyCustom<T>(x2Local, x2Gm[gm_bias], calc_row_num, numCol);
    }
    inQueueX.EnQue(x2Local);
}
|
||||
|
||||
// Form x = x1 + x2 for the whole batch and return the UB tensor holding it.
// half/float add natively; bfloat16 is cast up to fp32 for the add to limit
// rounding error, then rounded back with rint. The returned tensor is also
// enqueued on outQueueY so CopyOutX can stream it to the x output.
__aicore__ inline LocalTensor<T> ComputeX(uint32_t elementNum)
{
    LocalTensor<T> x1Local = inQueueX.DeQue<T>();
    LocalTensor<T> x2Local = inQueueX.DeQue<T>();
    LocalTensor<T> xLocal = outQueueY.AllocTensor<T>();
    if constexpr (!is_same<T, bfloat16_t>::value) {
        Add(xLocal, x1Local, x2Local, elementNum);
    } else {
        LocalTensor<float> x1Fp32 = xFp32Buf.Get<float>();
        LocalTensor<float> x2Fp32 = sqxBuf.Get<float>();
        Cast(x1Fp32, x1Local, RoundMode::CAST_NONE, elementNum);
        Cast(x2Fp32, x2Local, RoundMode::CAST_NONE, elementNum);
        PipeBarrier<PIPE_V>();
        Add(x1Fp32, x1Fp32, x2Fp32, elementNum);
        PipeBarrier<PIPE_V>();
        Cast(xLocal, x1Fp32, RoundMode::CAST_RINT, elementNum);
    }
    inQueueX.FreeTensor(x1Local);
    inQueueX.FreeTensor(x2Local);
    outQueueY.EnQue(xLocal);
    PipeBarrier<PIPE_V>();
    // NOTE(review): xLocal is returned while also enqueued/freed through
    // outQueueY (see CopyOutX) — relies on the queue not recycling the buffer
    // before ComputeRstd/ComputeY read it; confirm.
    return xLocal;
}
|
||||
|
||||
// Stream the batch's x = x1 + x2 result (enqueued by ComputeX) out through xGm.
__aicore__ inline void CopyOutX(uint32_t gm_bias, uint32_t calc_row_num)
{
    // CopyOut x1 + x2
    auto xOut = outQueueY.DeQue<T>();
    if (isNumColAlign) {
        DataCopyCustom<T>(xGm[gm_bias], xOut, calc_row_num * numCol);
    } else {
        DataCopyCustom<T>(xGm[gm_bias], xOut, calc_row_num, numCol);
    }
    outQueueY.FreeTensor(xOut);
}
|
||||
|
||||
// Stage gamma (and beta, when present) once for the whole kernel run.
__aicore__ inline void CopyInGammaBeta()
{
    LocalTensor<T> gammaLocal = inQueueGamma.AllocTensor<T>();
    DataCopyCustom<T>(gammaLocal, gammaGm, numCol);
    inQueueGamma.EnQue(gammaLocal);
    if (!this->nullptrBeta) {
        LocalTensor<T> betaLocal = inQueueBeta.AllocTensor<T>();
        DataCopyCustom<T>(betaLocal, betaGm, numCol);
        inQueueBeta.EnQue(betaLocal);
    }
}
|
||||
|
||||
// Compute per-row rstd = 1 / sqrt(mean(x^2) + epsilon) for the batch.
// 16-bit dtypes are cast up to fp32 first; the row-wise reduction is done by
// ReduceSumMultiN, leaving one scalar per row in rstdLocal.
__aicore__ inline void ComputeRstd(LocalTensor<T> xLocal, LocalTensor<float> rstdLocal, uint32_t calc_row_num, uint32_t elementNum)
{
    LocalTensor<float> sqx = sqxBuf.Get<float>();
    LocalTensor<float> tmpLocal = tmpBuf.Get<float>();
    if constexpr (!is_same<T, float>::value) {
        LocalTensor<float> x_fp32 = xFp32Buf.Get<float>();
        Cast(x_fp32, xLocal, RoundMode::CAST_NONE, elementNum);
        PipeBarrier<PIPE_V>();
        Mul(sqx, x_fp32, x_fp32, elementNum);
    } else {
        Mul(sqx, xLocal, xLocal, elementNum);
    }
    PipeBarrier<PIPE_V>();

    // Scale by 1/numCol so the row reduction yields mean(x^2) directly.
    Muls(sqx, sqx, avgFactor, elementNum);
    PipeBarrier<PIPE_V>();

    ReduceSumMultiN(rstdLocal, sqx, tmpLocal, calc_row_num, numCol, numColAlign);
    PipeBarrier<PIPE_V>();
    Adds(rstdLocal, rstdLocal, epsilon, calc_row_num);
    PipeBarrier<PIPE_V>();

    // rstd = 1 / sqrt(mean + epsilon), computed element-wise per row.
    Sqrt(rstdLocal, rstdLocal, calc_row_num);
    Duplicate(tmpLocal, ONE, calc_row_num);
    PipeBarrier<PIPE_V>();

    Div(rstdLocal, tmpLocal, rstdLocal, calc_row_num);
    PipeBarrier<PIPE_V>();
}
|
||||
|
||||
// Apply the normalization to the batch: y = x * rstd * gamma (+ beta).
// Per-row rstd scalars are first broadcast to 8-lane blocks with Brcb, then
// repeatByRow performs the row-wise broadcast multiplies; addRepeatByRow adds
// the broadcast beta. bfloat16 applies gamma/beta in fp32 scratch.
__aicore__ inline void ComputeY(
    LocalTensor<T> xLocal, LocalTensor<T> gammaLocal, LocalTensor<T> betaLocal, LocalTensor<float> rstdLocal, uint32_t calc_row_num, uint32_t elementNum)
{
    LocalTensor<float> tmpLocal = tmpBuf.Get<float>();
    // Brcb is issued in chunks of 240 rows — presumably to stay under the
    // hardware's 255-repeat limit per instruction; confirm the choice of 240.
    uint32_t splidRow = 240;
    uint32_t rowRepeatLoop1 = calc_row_num / splidRow;
    uint32_t rowRepeatTail1 = calc_row_num - rowRepeatLoop1 * splidRow;
    for (uint32_t r_i = 0; r_i < rowRepeatLoop1; r_i++) {
        Brcb(tmpLocal[r_i * splidRow * MOV_8], rstdLocal[r_i * splidRow], splidRow, {1, 8});
    }
    PipeBarrier<PIPE_V>();

    if (rowRepeatTail1 > 0) {
        Brcb(tmpLocal[rowRepeatLoop1 * splidRow * MOV_8], rstdLocal[rowRepeatLoop1 * splidRow], rowRepeatTail1, {1, 8});
        PipeBarrier<PIPE_V>();
    }
    LocalTensor<T> yLocal = outQueueY.AllocTensor<T>();
    // Step 1: x * rstd (16-bit dtypes use the fp32 copy left by ComputeRstd).
    if constexpr (!is_same<T, float>::value) {
        LocalTensor<float> x_fp32 = xFp32Buf.Get<float>();
        repeatByRow<float>(x_fp32, x_fp32, tmpLocal, calc_row_num, ONE_UINT);
        if constexpr (is_same<T, half>::value) {
            Cast(yLocal, x_fp32, RoundMode::CAST_NONE, elementNum);
        } else {
            Cast(yLocal, x_fp32, RoundMode::CAST_RINT, elementNum);
        }
    } else {
        repeatByRow<float>(yLocal, xLocal, tmpLocal, calc_row_num, ONE_UINT);
    }
    PipeBarrier<PIPE_V>();
    // Step 2: * gamma (+ beta). half works in half precision, bfloat16 in fp32.
    if constexpr (is_same<T, half>::value) {
        repeatByRow<half>(yLocal, yLocal, gammaLocal, calc_row_num, TWO_UINT);
        if (!this->nullptrBeta) {
            addRepeatByRow<half>(yLocal, yLocal, betaLocal, calc_row_num, TWO_UINT);
        }
    } else if constexpr (is_same<T, bfloat16_t>::value) {
        LocalTensor<float> sqx = sqxBuf.Get<float>();
        LocalTensor<float> x_fp32 = xFp32Buf.Get<float>();
        Cast(x_fp32, yLocal, RoundMode::CAST_NONE, elementNum);
        Cast(sqx, gammaLocal, RoundMode::CAST_NONE, elementNum);
        PipeBarrier<PIPE_V>();
        repeatByRow<float>(x_fp32, x_fp32, sqx, calc_row_num, THREE_UINT);
        if (!this->nullptrBeta) {
            Cast(sqx, betaLocal, RoundMode::CAST_NONE, elementNum);
            PipeBarrier<PIPE_V>();
            addRepeatByRow<float>(x_fp32, x_fp32, sqx, calc_row_num, THREE_UINT);
        }
        Cast(yLocal, x_fp32, RoundMode::CAST_RINT, elementNum);
    } else {
        repeatByRow<float>(yLocal, yLocal, gammaLocal, calc_row_num, THREE_UINT);
        if (!this->nullptrBeta) {
            addRepeatByRow<float>(yLocal, yLocal, betaLocal, calc_row_num, THREE_UINT);
        }
    }
    PipeBarrier<PIPE_V>();
    outQueueY.EnQue<T>(yLocal);
}
|
||||
|
||||
// Drain the y queue for one batch and copy it to global memory; uses the
// row-wise copy overload when numCol is not UB-aligned.
__aicore__ inline void CopyOutY(uint32_t progress, uint32_t calc_row_num)
{
    LocalTensor<T> yLocal = outQueueY.DeQue<T>();
    if (isNumColAlign) {
        DataCopyCustom<T>(yGm[progress], yLocal, calc_row_num * numCol);
    } else {
        DataCopyCustom<T>(yGm[progress], yLocal, calc_row_num, numCol);
    }
    outQueueY.FreeTensor(yLocal);
}
|
||||
|
||||
#if defined(__CCE_AICORE__) && __CCE_AICORE__ == 220 || (defined(__NPU_ARCH__) && __NPU_ARCH__ == 3003)
|
||||
// Copy `num` per-row reciprocal-std values for outer-loop tile `outer_progress`
// back to global memory. Only compiled on targets that materialize the rstd
// output (AICORE 220 / NPU_ARCH 3003); other targets keep rstd in a scratch
// TBuf and never write it out.
__aicore__ inline void CopyOutRstd(uint32_t outer_progress, uint32_t num)
{
    LocalTensor<float> rstdLocal = outQueueRstd.DeQue<float>();
    // NOTE(review): assumes the compute stage left the rstd values packed so
    // that `num` contiguous floats are valid here -- confirm against ComputeRstd.
    DataCopyCustom<float>(rstdGm[outer_progress * rowFactor], rstdLocal, num);
    outQueueRstd.FreeTensor(rstdLocal);
}
|
||||
#endif
|
||||
|
||||
// Multiply `calc_row_num` rows of src1 element-wise by an operand in src2,
// writing into dst. `type` selects the stride profile of src2:
//   ONE_UINT - src2 holds one value per row (the rstd scalar), broadcast across
//              the whole row (src2 block stride 0, advancing one block per row);
//   TWO_UINT - src2 is a half-precision row vector (gamma), reused for every row;
//   otherwise (fp32 profile) - src2 is a float row vector, reused for every row.
// Rows are issued in chunks of at most 255 because that is the per-instruction
// repeat-count limit of the vector unit.
template <typename U>
__aicore__ inline void repeatByRow(const LocalTensor<U>& dstLocal, const LocalTensor<U>& src1Local, const LocalTensor<U>& src2Local, uint32_t calc_row_num, uint32_t type)
{
    // TWO_UINT=gammaFp16 ONE_UINT=rstd
    // strideParams = {loop count, tail count, elements per repeat,
    //                 src2 block stride, dst/src1 repeat stride, src2 repeat stride}
    uint32_t strideParams[6] = {mulLoopFp32, mulTailFp32, 64, 1, dstRepStrideFp32, 0};
    if (type == TWO_UINT) {
        // fp16 profile: 128 elements fit in one repeat instead of 64.
        strideParams[0] = mulLoopFp16;
        strideParams[1] = mulTailFp16;
        strideParams[2] = 128;
        strideParams[4] = dstRepStrideFp16;
    } else if (type == ONE_UINT) {
        // Broadcast profile: block stride 0 re-reads the same src2 block for
        // every column chunk; repeat stride 1 advances one block per row.
        strideParams[3] = 0;
        strideParams[5] = 1;
    }
    uint32_t singlT = 255;  // hardware ceiling on a single instruction's repeat count
    uint32_t rowRepeatLoop = calc_row_num / singlT;
    uint32_t rowRepeatTail = calc_row_num - rowRepeatLoop * singlT;
    uint32_t offset2 = 0;
    for(uint32_t r_i = 0; r_i < rowRepeatLoop; r_i ++) {
        // NOTE(review): `type == 1` relies on ONE_UINT being 1; the named
        // constant (as used above) would be clearer and safer.
        offset2 = type == 1 ? (r_i * singlT * MOV_8) : 0;  // rstd advances 8 floats/row
        mulRepeat<U>(dstLocal[r_i * singlT * numColAlign], src1Local[r_i * singlT * numColAlign], src2Local[offset2], singlT, strideParams);
    }
    if(rowRepeatTail > 0) {
        // Remaining rows (< 255) after the full 255-row chunks.
        offset2 = type == 1 ? (rowRepeatLoop * singlT * MOV_8) : 0;
        uint32_t offset1 = rowRepeatLoop * singlT * numColAlign;
        mulRepeat<U>(dstLocal[offset1], src1Local[offset1], src2Local[offset2], rowRepeatTail, strideParams);
    }
}
|
||||
|
||||
// Issue the strided vector Mul instructions for one chunk of up to 255 rows.
// strideParams layout: {loop count, tail count, elements per repeat,
// src2 block stride, dst/src1 repeat stride, src2 repeat stride}.
// A zero src2 block stride means src2 is broadcast (same address for every
// column chunk of a row); otherwise src2 is indexed in lockstep with src1.
template <typename U>
__aicore__ inline void mulRepeat(const LocalTensor<U>& dstLocal, const LocalTensor<U>& src1Local, const LocalTensor<U>& src2Local, uint32_t calcRowNum, uint32_t strideParams[6])
{
    uint32_t mulLoop = strideParams[0];    // full column chunks per row
    uint32_t mulTail = strideParams[1];    // leftover elements in the last chunk
    uint32_t strideNum = strideParams[2];  // elements covered by one repeat
    uint8_t src1BlkStride = static_cast<uint8_t>(strideParams[3]);
    uint8_t dstRepStride = static_cast<uint8_t>(strideParams[4]);
    uint8_t src1RepStride = static_cast<uint8_t>(strideParams[5]);
    if(src1BlkStride == 0) {
        // Broadcast path: src2 is not advanced per column chunk (rstd scalar).
        for (uint32_t m_i = 0; m_i < mulLoop; m_i++) {
            Mul(dstLocal[m_i * strideNum], src1Local[m_i * strideNum], src2Local, strideNum, calcRowNum, {1, 1, src1BlkStride, dstRepStride, dstRepStride, src1RepStride});
        }
        PipeBarrier<PIPE_V>();
        if(mulTail > 0) {
            Mul(dstLocal[mulLoop * strideNum], src1Local[mulLoop * strideNum], src2Local, mulTail, calcRowNum, {1, 1, src1BlkStride, dstRepStride, dstRepStride, src1RepStride});
        }
        PipeBarrier<PIPE_V>();
    } else {
        // Vector path: src2 (gamma) advances with the column chunk.
        for (uint32_t m_i = 0; m_i < mulLoop; m_i++) {
            Mul(dstLocal[m_i * strideNum], src1Local[m_i * strideNum], src2Local[m_i * strideNum], strideNum, calcRowNum, {1, 1, src1BlkStride, dstRepStride, dstRepStride, src1RepStride});
        }
        PipeBarrier<PIPE_V>();
        if(mulTail > 0) {
            Mul(dstLocal[mulLoop * strideNum], src1Local[mulLoop * strideNum], src2Local[mulLoop * strideNum], mulTail, calcRowNum, {1, 1, src1BlkStride, dstRepStride, dstRepStride, src1RepStride});
        }
        PipeBarrier<PIPE_V>();
    }
}
|
||||
|
||||
// Add an operand in src2 to `calc_row_num` rows of src1, writing into dst.
// Mirror image of repeatByRow (same stride profiles, Add instead of Mul);
// used to apply the beta bias after the gamma scaling.
//   ONE_UINT - src2 holds one value per row, broadcast across the row;
//   TWO_UINT - src2 is a half-precision row vector (beta), reused every row;
//   otherwise - src2 is a float row vector, reused every row.
template <typename U>
__aicore__ inline void addRepeatByRow(const LocalTensor<U>& dstLocal, const LocalTensor<U>& src1Local, const LocalTensor<U>& src2Local, uint32_t calc_row_num, uint32_t type)
{
    // TWO_UINT=gammaFp16 ONE_UINT=rstd
    // strideParams = {loop count, tail count, elements per repeat,
    //                 src2 block stride, dst/src1 repeat stride, src2 repeat stride}
    uint32_t strideParams[6] = {mulLoopFp32, mulTailFp32, 64, 1, dstRepStrideFp32, 0};
    if (type == TWO_UINT) {
        strideParams[0] = mulLoopFp16;
        strideParams[1] = mulTailFp16;
        strideParams[2] = 128;
        strideParams[4] = dstRepStrideFp16;
    } else if (type == ONE_UINT) {
        strideParams[3] = 0;
        strideParams[5] = 1;
    }
    uint32_t singlT = 255;  // per-instruction repeat-count ceiling
    uint32_t rowRepeatLoop = calc_row_num / singlT;
    uint32_t rowRepeatTail = calc_row_num - rowRepeatLoop * singlT;
    uint32_t offset2 = 0;
    for(uint32_t r_i = 0; r_i < rowRepeatLoop; r_i ++) {
        // NOTE(review): `type == 1` relies on ONE_UINT being 1; prefer the
        // named constant for consistency.
        offset2 = type == 1 ? (r_i * singlT * MOV_8) : 0;
        addRepeat<U>(dstLocal[r_i * singlT * numColAlign], src1Local[r_i * singlT * numColAlign], src2Local[offset2], singlT, strideParams);
    }
    if(rowRepeatTail > 0) {
        offset2 = type == 1 ? (rowRepeatLoop * singlT * MOV_8) : 0;
        uint32_t offset1 = rowRepeatLoop * singlT * numColAlign;
        addRepeat<U>(dstLocal[offset1], src1Local[offset1], src2Local[offset2], rowRepeatTail, strideParams);
    }
}
|
||||
|
||||
// Issue the strided vector Add instructions for one chunk of up to 255 rows.
// Same stride-parameter contract as mulRepeat: a zero src2 block stride means
// src2 is broadcast per row; otherwise src2 advances with the column chunk.
template <typename U>
__aicore__ inline void addRepeat(const LocalTensor<U>& dstLocal, const LocalTensor<U>& src1Local, const LocalTensor<U>& src2Local, uint32_t calcRowNum, uint32_t strideParams[6])
{
    uint32_t addLoop = strideParams[0];    // full column chunks per row
    uint32_t addTail = strideParams[1];    // leftover elements in the last chunk
    uint32_t strideNum = strideParams[2];  // elements covered by one repeat
    uint8_t src1BlkStride = static_cast<uint8_t>(strideParams[3]);
    uint8_t dstRepStride = static_cast<uint8_t>(strideParams[4]);
    uint8_t src1RepStride = static_cast<uint8_t>(strideParams[5]);
    if(src1BlkStride == 0) {
        // Broadcast path: same src2 address for every column chunk.
        for (uint32_t m_i = 0; m_i < addLoop; m_i++) {
            Add(dstLocal[m_i * strideNum], src1Local[m_i * strideNum], src2Local, strideNum, calcRowNum, {1, 1, src1BlkStride, dstRepStride, dstRepStride, src1RepStride});
        }
        PipeBarrier<PIPE_V>();
        if(addTail > 0) {
            Add(dstLocal[addLoop * strideNum], src1Local[addLoop * strideNum], src2Local, addTail, calcRowNum, {1, 1, src1BlkStride, dstRepStride, dstRepStride, src1RepStride});
        }
        PipeBarrier<PIPE_V>();
    } else {
        // Vector path: src2 (beta) advances with the column chunk.
        for (uint32_t m_i = 0; m_i < addLoop; m_i++) {
            Add(dstLocal[m_i * strideNum], src1Local[m_i * strideNum], src2Local[m_i * strideNum], strideNum, calcRowNum, {1, 1, src1BlkStride, dstRepStride, dstRepStride, src1RepStride});
        }
        PipeBarrier<PIPE_V>();
        if(addTail > 0) {
            Add(dstLocal[addLoop * strideNum], src1Local[addLoop * strideNum], src2Local[addLoop * strideNum], addTail, calcRowNum, {1, 1, src1BlkStride, dstRepStride, dstRepStride, src1RepStride});
        }
        PipeBarrier<PIPE_V>();
    }
}
|
||||
|
||||
private:
    TPipe* Ppipe = nullptr;  // pipeline handle supplied by the caller; not owned
    // create queues for input, in this case depth is equal to buffer num
    TQue<QuePosition::VECIN, BUFFER_NUM> inQueueGamma;
    TQue<QuePosition::VECIN, BUFFER_NUM> inQueueBeta;   // only initialized when beta is present
    TQue<QuePosition::VECIN, DOUBLE_BUFFER_NUM> inQueueX;
    // create queues for output, in this case depth is equal to buffer num
#if defined(__CCE_AICORE__) && __CCE_AICORE__ == 220 || (defined(__NPU_ARCH__) && __NPU_ARCH__ == 3003)
    TQue<QuePosition::VECOUT, BUFFER_NUM> outQueueRstd;  // rstd written to GM on these SoCs
#else
    TBuf<TPosition::VECCALC> rstdBuf;  // rstd kept on-chip only
#endif
    TQue<QuePosition::VECOUT, DOUBLE_BUFFER_NUM> outQueueY;

    // Scratch buffers for fp32 intermediates (half/bf16 inputs are upcast).
    TBuf<TPosition::VECCALC> xFp32Buf;
    TBuf<TPosition::VECCALC> sqxBuf;
    TBuf<TPosition::VECCALC> tmpBuf;
    // Global-memory views, offset per core in Init.
    GlobalTensor<T> x1Gm;
    GlobalTensor<T> x2Gm;
    GlobalTensor<T> gammaGm;
    GlobalTensor<T> betaGm;   // only bound when beta is present
    GlobalTensor<T> yGm;
    GlobalTensor<float> rstdGm;
    GlobalTensor<T> xGm;      // x1 + x2 residual output

    // Tiling-derived shape parameters.
    uint32_t numRow;
    uint32_t numCol;
    uint32_t numColAlign;     // numCol rounded up to the UB alignment unit
    uint32_t blockFactor; // number of calculations rows on each core
    uint32_t rowFactor;
    uint32_t ubFactor;
    float epsilon;            // added under the sqrt for numerical stability
    float avgFactor;          // 1 / numCol
    int32_t blockIdx_;
    uint32_t rowWork = 1;     // rows actually processed by this core
#if (defined(__NPU_ARCH__) && __NPU_ARCH__ == 3003)
    bool isNumColAlign = true;
#else
    bool isNumColAlign = false;
#endif
    uint8_t isPerformance = 0;
    uint32_t rowLoop = 1;     // outer tiles per core
    uint32_t rowTail = 0;     // rows in the final tile
    // Precomputed loop/tail/stride parameters for the strided Mul/Add helpers.
    uint32_t mulLoopFp32;
    uint32_t mulTailFp32;
    uint8_t dstRepStrideFp32;
    uint32_t mulLoopFp16;
    uint32_t mulTailFp16;
    uint8_t dstRepStrideFp16;
    uint32_t nullptrBeta = 0;  // nonzero when no beta tensor was supplied
};
|
||||
#endif // _ADD_RMS_NORM_BIAS_MERGE_N_H_
|
||||
339
csrc/add_rms_norm_bias/op_kernel/add_rms_norm_bias_multi_n.h
Normal file
339
csrc/add_rms_norm_bias/op_kernel/add_rms_norm_bias_multi_n.h
Normal file
@@ -0,0 +1,339 @@
|
||||
/**
|
||||
* This program is free software, you can redistribute it and/or modify.
|
||||
* Copyright (c) 2025 Huawei Technologies Co., Ltd.
|
||||
* This file is a part of the CANN Open Software.
|
||||
* Licensed under CANN Open Software License Agreement Version 2.0 (the "License").
|
||||
* Please refer to the License for details. You may not use this file except in compliance with the License.
|
||||
* THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE.
|
||||
* See LICENSE in the root of the software repository for the full text of the License.
|
||||
*/
|
||||
|
||||
/*!
|
||||
* \file add_rms_norm_bias_multi_n.h
|
||||
* \brief add rms norm bias multi n file
|
||||
*/
|
||||
#ifndef ADD_RMS_NORM_BIAS_MULTI_N_H_
|
||||
#define ADD_RMS_NORM_BIAS_MULTI_N_H_
|
||||
#include "./rms_norm_base.h"
|
||||
|
||||
using namespace AscendC;
|
||||
using namespace RmsNorm;
|
||||
|
||||
// Fused AddRmsNorm(+bias) kernel for the "multi-N" tiling: each core owns
// blockFactor rows and processes them in outer tiles of rowFactor rows.
// Per tile: x = x1 + x2 is written back to GM, rstd = 1/sqrt(mean(x^2)+eps)
// is computed per row, and y = x * rstd * gamma (+ beta, when supplied).
template <typename T>
class KernelAddRmsNormBiasMultiN {
public:
    __aicore__ inline KernelAddRmsNormBiasMultiN(TPipe* pipe)
    {
        Ppipe = pipe;
    }
    // Bind GM addresses, apply the tiling, and allocate all UB queues/buffers.
    // The last core may get a smaller row count (last_block_* tiling fields).
    __aicore__ inline void Init(
        GM_ADDR x1, GM_ADDR x2, GM_ADDR gamma, GM_ADDR beta, GM_ADDR y, GM_ADDR rstd, GM_ADDR x, const AddRMSNormBiasTilingData* tiling)
    {
        ASSERT(GetBlockNum() != 0 && "Block dim can not be zero!");
        this->numRow = tiling->num_row;
        this->numCol = tiling->num_col;
        this->numColAlign = tiling->num_col_align;
        this->blockFactor = tiling->block_factor;
        this->rowFactor = tiling->row_factor;
        this->ubFactor = tiling->ub_factor;
        this->epsilon = tiling->epsilon;
        this->avgFactor = tiling->avg_factor;
        this->nullptrBeta = tiling->nullptr_beta;

        blockIdx_ = GetBlockIdx();
        if (blockIdx_ < GetBlockNum() - 1) {
            this->rowWork = blockFactor;
            this->rowLoop = tiling->row_loop;
            this->rowTail = tiling->row_tail;
        } else if (blockIdx_ == GetBlockNum() - 1) {
            // Last core: possibly fewer rows than the others.
            this->rowWork = tiling->last_block_factor;
            this->rowLoop = tiling->last_block_row_loop;
            this->rowTail = tiling->last_block_row_tail;
        }
        // get start index for current core, core parallel
        x1Gm.SetGlobalBuffer((__gm__ T*)x1 + blockIdx_ * blockFactor * numCol, rowWork * numCol);
        x2Gm.SetGlobalBuffer((__gm__ T*)x2 + blockIdx_ * blockFactor * numCol, rowWork * numCol);
        gammaGm.SetGlobalBuffer((__gm__ T*)gamma, numCol);
        if (!this->nullptrBeta) {
            betaGm.SetGlobalBuffer((__gm__ T*)beta, numCol);
        }
        yGm.SetGlobalBuffer((__gm__ T*)y + blockIdx_ * blockFactor * numCol, rowWork * numCol);
        // NOTE(review): size is blockFactor, not rowWork -- on the last core this
        // over-declares the rstd view; confirm that is intentional/harmless.
        rstdGm.SetGlobalBuffer((__gm__ float*)rstd + blockIdx_ * blockFactor, blockFactor);
        xGm.SetGlobalBuffer((__gm__ T*)x + blockIdx_ * blockFactor * numCol, rowWork * numCol);

        // pipe alloc memory to queue, the unit is Bytes
        Ppipe->InitBuffer(inQueueX, DOUBLE_BUFFER_NUM, ubFactor * sizeof(T));
        Ppipe->InitBuffer(inQueueGamma, BUFFER_NUM, numColAlign * sizeof(T));
        if (!this->nullptrBeta) {
            Ppipe->InitBuffer(inQueueBeta, BUFFER_NUM, numColAlign * sizeof(T));
        }
        Ppipe->InitBuffer(outQueueY, DOUBLE_BUFFER_NUM, ubFactor * sizeof(T));
#if __CCE_AICORE__ == 220 || (defined(__NPU_ARCH__) && __NPU_ARCH__ == 3003)
        Ppipe->InitBuffer(outQueueRstd, BUFFER_NUM, rowFactor * NUM_PER_BLK_FP32 * sizeof(float));
#else
        Ppipe->InitBuffer(rstdBuf, rowFactor * NUM_PER_BLK_FP32 * sizeof(float));
#endif
        if constexpr (is_same<T, half>::value || is_same<T, bfloat16_t>::value) {
            // fp32 staging buffer needed only when the element type is 16-bit.
            Ppipe->InitBuffer(xFp32Buf, ubFactor * sizeof(float));
        }
        Ppipe->InitBuffer(sqxBuf, ubFactor * sizeof(float));
        Ppipe->InitBuffer(reduceFp32Buf, NUM_PER_REP_FP32 * sizeof(float));
        Ppipe->InitBuffer(offsetBuf, rowFactor * NUM_PER_BLK_FP32 * sizeof(uint32_t));
    }
    // Load gamma/beta once, build the Gather offset table, then walk the outer
    // tiles: rowLoop-1 full tiles of rowFactor rows plus a tail of rowTail rows.
    __aicore__ inline void Process()
    {
        CopyInGammaBeta();
        LocalTensor<T> betaLocal;
        if (!this->nullptrBeta) {
            betaLocal = inQueueBeta.DeQue<T>();
        }
        LocalTensor<T> gammaLocal = inQueueGamma.DeQue<T>();
        LocalTensor<uint32_t> offsetLocal = offsetBuf.Get<uint32_t>();
        for (uint32_t i = 0; i < rowFactor; i++) {
            // Byte offsets used by Gather in ComputeY to compact one rstd
            // value per 8-float block.
            Duplicate(offsetLocal[i * NUM_PER_BLK_FP32], i * ONE_BLK_SIZE, NUM_PER_BLK_FP32);
        }
        for (uint32_t i_o = 0; i_o < rowLoop - 1; i_o++) {
            SubProcessHalf(i_o, rowFactor, gammaLocal, betaLocal);
        }
        SubProcessHalf(rowLoop - 1, rowTail, gammaLocal, betaLocal);
        inQueueGamma.FreeTensor(gammaLocal);
        if (!this->nullptrBeta) {
            inQueueBeta.FreeTensor(betaLocal);
        }
    }

    // Process one outer tile of calc_row_num rows:
    // copy in -> x = x1+x2 -> write x -> rstd -> y -> write y.
    __aicore__ inline void SubProcessHalf(uint32_t i_o, uint32_t calc_row_num, LocalTensor<T>& gammaLocal, LocalTensor<T>& betaLocal)
    {
        uint32_t gm_bias = i_o * rowFactor * numCol;
        CopyInX(gm_bias, calc_row_num);
        LocalTensor<T> xLocal = ComputeX(calc_row_num);
        CopyOutX(gm_bias, calc_row_num);
#if __CCE_AICORE__ == 220 || (defined(__NPU_ARCH__) && __NPU_ARCH__ == 3003)
        LocalTensor<float> rstdLocal = outQueueRstd.AllocTensor<float>();
        ComputeRstd(xLocal, rstdLocal, calc_row_num);
        outQueueRstd.EnQue<float>(rstdLocal);
        CopyOutRstd(i_o * rowFactor, calc_row_num);
        // NOTE(review): rstdLocal is freed by CopyOutRstd yet still read (and
        // Gather-ed in place) by ComputeY below; this relies on the freed UB
        // region staying intact until then -- confirm against the queue model.
#else
        LocalTensor<float> rstdLocal = rstdBuf.Get<float>();
        ComputeRstd(xLocal, rstdLocal, calc_row_num);
#endif
        ComputeY(xLocal, gammaLocal, betaLocal, rstdLocal, calc_row_num);
        CopyOutY(gm_bias, calc_row_num);
    }

private:
    // Stage x1 and x2 tiles from GM into the (double-buffered) input queue.
    __aicore__ inline void CopyInX(uint32_t gm_bias, uint32_t calc_row_num)
    {
        LocalTensor<T> x1Local = inQueueX.AllocTensor<T>();
        DataCopyCustom<T>(x1Local, x1Gm[gm_bias], calc_row_num * numCol);
        inQueueX.EnQue(x1Local);
        LocalTensor<T> x2Local = inQueueX.AllocTensor<T>();
        DataCopyCustom<T>(x2Local, x2Gm[gm_bias], calc_row_num * numCol);
        inQueueX.EnQue(x2Local);
    }

    // x = x1 + x2. bf16 goes through fp32 (bf16 Add would lose precision /
    // has no direct path); half and float add natively. The sum is enqueued
    // on outQueueY for the residual write-out, and the same tensor handle is
    // returned for the subsequent rstd/y computation.
    __aicore__ inline LocalTensor<T> ComputeX(uint32_t calc_row_num)
    {
        uint32_t calc_num = calc_row_num * numColAlign;
        LocalTensor<T> x1Local = inQueueX.DeQue<T>();
        LocalTensor<T> x2Local = inQueueX.DeQue<T>();
        LocalTensor<T> xLocal = outQueueY.AllocTensor<T>();
        if constexpr (!is_same<T, bfloat16_t>::value) {
            Add(xLocal, x1Local, x2Local, calc_num);
        } else {
            LocalTensor<float> x1Fp32 = xFp32Buf.Get<float>();
            LocalTensor<float> x2Fp32 = sqxBuf.Get<float>();
            Cast(x1Fp32, x1Local, RoundMode::CAST_NONE, calc_num);
            Cast(x2Fp32, x2Local, RoundMode::CAST_NONE, calc_num);
            PipeBarrier<PIPE_V>();
            Add(x1Fp32, x1Fp32, x2Fp32, calc_num);
            PipeBarrier<PIPE_V>();
            Cast(xLocal, x1Fp32, RoundMode::CAST_RINT, calc_num);
        }
        inQueueX.FreeTensor(x1Local);
        inQueueX.FreeTensor(x2Local);
        outQueueY.EnQue(xLocal);
        PipeBarrier<PIPE_V>();
        return xLocal;
    }

    // Write the residual sum (x1 + x2) back to the x output tensor.
    __aicore__ inline void CopyOutX(uint32_t gm_bias, uint32_t calc_row_num)
    {
        // CopyOut x1 + x2
        auto x_out = outQueueY.DeQue<T>();
        DataCopyCustom<T>(xGm[gm_bias], x_out, calc_row_num * numCol);
        outQueueY.FreeTensor(x_out);
    }

    // Load gamma (and beta, when present) once for the whole kernel run.
    __aicore__ inline void CopyInGammaBeta()
    {
        LocalTensor<T> gammaLocal = inQueueGamma.AllocTensor<T>();
        DataCopyCustom<T>(gammaLocal, gammaGm, numCol);
        inQueueGamma.EnQue(gammaLocal);
        if (!this->nullptrBeta) {
            LocalTensor<T> betaLocal = inQueueBeta.AllocTensor<T>();
            DataCopyCustom<T>(betaLocal, betaGm, numCol);
            inQueueBeta.EnQue(betaLocal);
        }
    }

    // Per row i: rstd[i] = 1 / sqrt(mean(x[i]^2) + epsilon).
    // Each row's value lands at rstdLocal[i * NUM_PER_BLK_FP32] (one per
    // 32-byte block); also leaves the fp32 copy of x in xFp32Buf for ComputeY.
    __aicore__ inline void ComputeRstd(LocalTensor<T> xLocal, LocalTensor<float> rstdLocal, uint32_t calc_row_num)
    {
        LocalTensor<float> x_fp32 = xFp32Buf.Get<float>();
        LocalTensor<float> sqx = sqxBuf.Get<float>();
        LocalTensor<float> reduce_buf_local = reduceFp32Buf.Get<float>();
        Cast(x_fp32, xLocal, RoundMode::CAST_NONE, calc_row_num * numColAlign);
        PipeBarrier<PIPE_V>();

        Mul(sqx, x_fp32, x_fp32, calc_row_num * numColAlign);
        PipeBarrier<PIPE_V>();

        // Scale by 1/numCol before the reduction (mean of squares).
        Muls(sqx, sqx, avgFactor, calc_row_num * numColAlign);
        PipeBarrier<PIPE_V>();

        for (uint32_t i_i = 0; i_i < calc_row_num; i_i++) {
            ReduceSumCustom(rstdLocal[i_i * NUM_PER_BLK_FP32], sqx[i_i * numColAlign], reduce_buf_local, numCol);
        }
        Adds(rstdLocal, rstdLocal, epsilon, calc_row_num * NUM_PER_BLK_FP32);
        PipeBarrier<PIPE_V>();

        // rstd = 1 / sqrt(...), via Sqrt followed by Div(1, .).
        Sqrt(rstdLocal, rstdLocal, calc_row_num * NUM_PER_BLK_FP32);
        Duplicate(reduce_buf_local, ONE, NUM_PER_BLK_FP32);
        PipeBarrier<PIPE_V>();

        int32_t repeatTimes = calc_row_num * NUM_PER_BLK_FP32 / NUM_PER_REP_FP32;
        int32_t tailCount = calc_row_num * NUM_PER_BLK_FP32 % NUM_PER_REP_FP32;
        int32_t bodyCount = repeatTimes * NUM_PER_REP_FP32;

        if (likely(repeatTimes > 0)) {
            // Numerator block stride 0 broadcasts the constant 1.0 vector.
            Div(rstdLocal, reduce_buf_local, rstdLocal, NUM_PER_REP_FP32, repeatTimes, {1, 0, 1, DEFAULT_REPEAT_STRIDE, 0, DEFAULT_REPEAT_STRIDE});
        }
        if (unlikely(tailCount != 0)) {
            Div(rstdLocal[bodyCount], reduce_buf_local, rstdLocal[bodyCount], tailCount, 1, {1, 0, 1, DEFAULT_REPEAT_STRIDE, 0, DEFAULT_REPEAT_STRIDE});
        }
        PipeBarrier<PIPE_V>();
    }

    // y = x * rstd * gamma (+ beta). half applies gamma/beta in half precision;
    // bf16/float go through fp32 intermediates.
    __aicore__ inline void ComputeY(
        LocalTensor<T> xLocal, LocalTensor<T> gammaLocal, LocalTensor<T> betaLocal, LocalTensor<float> rstdLocal, uint32_t calc_row_num)
    {
        LocalTensor<float> x_fp32 = xFp32Buf.Get<float>();
        LocalTensor<uint32_t> offsetLocal = offsetBuf.Get<uint32_t>();
        // Compact the per-row rstd scalars (one per 8-float block) in place
        // using the precomputed byte-offset table.
        Gather(rstdLocal, rstdLocal, offsetLocal, ZERO_UINT, calc_row_num * NUM_PER_BLK_FP32);
        PipeBarrier<PIPE_V>();
        int32_t repeatTimes = numCol / NUM_PER_REP_FP32;
        int32_t tailCount = numCol % NUM_PER_REP_FP32;
        int32_t bodyCount = repeatTimes * NUM_PER_REP_FP32;
        for (uint32_t i_i = 0; i_i < calc_row_num; i_i++) {
            // Broadcast-multiply row i of x_fp32 by its rstd scalar
            // (src1 block stride 0 repeats the scalar's block).
            if (likely(repeatTimes > 0)) {
                Mul(x_fp32[i_i * numColAlign], x_fp32[i_i * numColAlign], rstdLocal[i_i * NUM_PER_BLK_FP32],
                    NUM_PER_REP_FP32, repeatTimes, {1, 1, 0, DEFAULT_REPEAT_STRIDE, DEFAULT_REPEAT_STRIDE, 0});
            }
            if (unlikely(tailCount != 0)) {
                Mul(x_fp32[i_i * numColAlign + bodyCount], x_fp32[i_i * numColAlign + bodyCount],
                    rstdLocal[i_i * NUM_PER_BLK_FP32], tailCount, 1,
                    {1, 1, 0, DEFAULT_REPEAT_STRIDE, DEFAULT_REPEAT_STRIDE, 0});
            }
        }
        PipeBarrier<PIPE_V>();
        LocalTensor<T> yLocal = outQueueY.AllocTensor<T>();
        if constexpr (is_same<T, half>::value) {
            Cast(yLocal, x_fp32, RoundMode::CAST_NONE, calc_row_num * numColAlign);
            PipeBarrier<PIPE_V>();

            for (uint32_t i_i = 0; i_i < calc_row_num; i_i++) {
                Mul(yLocal[i_i * numColAlign], gammaLocal, yLocal[i_i * numColAlign], numCol);
            }
            if (!this->nullptrBeta) {
                PipeBarrier<PIPE_V>();
                for (uint32_t i_i = 0; i_i < calc_row_num; i_i++) {
                    Add(yLocal[i_i * numColAlign], betaLocal, yLocal[i_i * numColAlign], numCol);
                }
            }
        } else {
            // bf16 / fp32 path: round-trip through T then do gamma/beta in fp32.
            Cast(yLocal, x_fp32, RoundMode::CAST_RINT, calc_row_num * numColAlign);
            PipeBarrier<PIPE_V>();
            LocalTensor<float> yfp32 = xFp32Buf.Get<float>();
            Cast(yfp32, yLocal, RoundMode::CAST_NONE, calc_row_num * numColAlign);
            PipeBarrier<PIPE_V>();
            LocalTensor<float> gammaFp32 = sqxBuf.Get<float>();
            Cast(gammaFp32, gammaLocal, RoundMode::CAST_NONE, numCol);
            PipeBarrier<PIPE_V>();
            for (uint32_t i_i = 0; i_i < calc_row_num; i_i++) {
                Mul(yfp32[i_i * numColAlign], gammaFp32, yfp32[i_i * numColAlign], numCol);
            }
            PipeBarrier<PIPE_V>();
            if (!this->nullptrBeta) {
                // Reuse the gamma staging buffer for beta.
                Cast(gammaFp32, betaLocal, RoundMode::CAST_NONE, numCol);
                PipeBarrier<PIPE_V>();
                for (uint32_t i_i = 0; i_i < calc_row_num; i_i++) {
                    Add(yfp32[i_i * numColAlign], gammaFp32, yfp32[i_i * numColAlign], numCol);
                }
                PipeBarrier<PIPE_V>();
            }
            Cast(yLocal, yfp32, RoundMode::CAST_RINT, calc_row_num * numColAlign);
        }
        PipeBarrier<PIPE_V>();
        outQueueY.EnQue<T>(yLocal);
    }

    // Write one tile of y back to global memory.
    __aicore__ inline void CopyOutY(uint32_t progress, uint32_t calc_row_num)
    {
        LocalTensor<T> yLocal = outQueueY.DeQue<T>();
        DataCopyCustom<T>(yGm[progress], yLocal, calc_row_num * numCol);
        outQueueY.FreeTensor(yLocal);
    }

#if __CCE_AICORE__ == 220 || (defined(__NPU_ARCH__) && __NPU_ARCH__ == 3003)
    // Write `num` rstd values starting at row `outer_progress`.
    __aicore__ inline void CopyOutRstd(uint32_t outer_progress, uint32_t num)
    {
        LocalTensor<float> rstdLocal = outQueueRstd.DeQue<float>();
        DataCopyParams copyParams;
        // blockLen = 4 bytes: presumably picks the single valid float out of
        // each 32-byte UB block (one rstd per block) -- TODO(review): confirm
        // against the DataCopyPad UB-source stride semantics.
        copyParams.blockLen = sizeof(float);
        copyParams.blockCount = num;
        DataCopyPad(rstdGm[outer_progress], rstdLocal, copyParams);
        outQueueRstd.FreeTensor(rstdLocal);
    }
#endif

private:
    TPipe* Ppipe = nullptr;  // pipeline handle supplied by the caller; not owned
    // create queues for input, in this case depth is equal to buffer num
    TQue<QuePosition::VECIN, BUFFER_NUM> inQueueGamma;
    TQue<QuePosition::VECIN, BUFFER_NUM> inQueueBeta;   // only initialized when beta is present
    TQue<QuePosition::VECIN, DOUBLE_BUFFER_NUM> inQueueX;
    // create queues for output, in this case depth is equal to buffer num
#if __CCE_AICORE__ == 220 || (defined(__NPU_ARCH__) && __NPU_ARCH__ == 3003)
    TQue<QuePosition::VECOUT, BUFFER_NUM> outQueueRstd;
#else
    TBuf<TPosition::VECCALC> rstdBuf;
#endif
    TQue<QuePosition::VECOUT, DOUBLE_BUFFER_NUM> outQueueY;

    TBuf<TPosition::VECCALC> xFp32Buf;     // fp32 staging for 16-bit element types
    TBuf<TPosition::VECCALC> sqxBuf;       // squared-x / gamma-fp32 scratch
    TBuf<TPosition::VECCALC> reduceFp32Buf;// ReduceSum scratch + constant-one vector
    TBuf<TPosition::VECCALC> offsetBuf;    // Gather byte-offset table
    GlobalTensor<T> x1Gm;
    GlobalTensor<T> x2Gm;
    GlobalTensor<T> gammaGm;
    GlobalTensor<T> betaGm;   // only bound when beta is present
    GlobalTensor<T> yGm;
    GlobalTensor<float> rstdGm;
    GlobalTensor<T> xGm;      // residual (x1 + x2) output

    uint32_t numRow;
    uint32_t numCol;
    uint32_t blockFactor; // number of calculations rows on each core
    uint32_t rowFactor;
    uint32_t ubFactor;
    float epsilon;
    float avgFactor;          // 1 / numCol
    uint32_t numColAlign;     // numCol rounded up to the UB alignment unit
    int32_t blockIdx_;
    uint32_t rowWork = 1;     // rows processed by this core
    uint32_t rowLoop = 1;     // outer tiles per core
    uint32_t rowTail = 0;     // rows in the final tile
    uint32_t nullptrBeta = 0; // nonzero when no beta tensor was supplied
};
|
||||
#endif // ADD_RMS_NORM_BIAS_MULTI_N_H_
|
||||
376
csrc/add_rms_norm_bias/op_kernel/add_rms_norm_bias_single_n.h
Normal file
376
csrc/add_rms_norm_bias/op_kernel/add_rms_norm_bias_single_n.h
Normal file
@@ -0,0 +1,376 @@
|
||||
/**
|
||||
* This program is free software, you can redistribute it and/or modify.
|
||||
* Copyright (c) 2025 Huawei Technologies Co., Ltd.
|
||||
* This file is a part of the CANN Open Software.
|
||||
* Licensed under CANN Open Software License Agreement Version 2.0 (the "License").
|
||||
* Please refer to the License for details. You may not use this file except in compliance with the License.
|
||||
* THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE.
|
||||
* See LICENSE in the root of the software repository for the full text of the License.
|
||||
*/
|
||||
|
||||
/*!
|
||||
* \file add_rms_norm_bias_single_n.h
|
||||
* \brief add rms norm bias single n file
|
||||
*/
|
||||
#ifndef ADD_RMS_NORM_BIAS_SINGLE_N_H_
|
||||
#define ADD_RMS_NORM_BIAS_SINGLE_N_H_
|
||||
#include "./rms_norm_base.h"
|
||||
|
||||
using namespace AscendC;
|
||||
using namespace RmsNorm;
|
||||
|
||||
template <typename T>
|
||||
class KernelAddRmsNormBiasSingleN {
|
||||
static constexpr int32_t MAXBUFFER = 195584;
|
||||
public:
|
||||
__aicore__ inline KernelAddRmsNormBiasSingleN(TPipe* pipe)
|
||||
{
|
||||
Ppipe = pipe;
|
||||
}
|
||||
__aicore__ inline void Init(
|
||||
GM_ADDR x1, GM_ADDR x2, GM_ADDR gamma, GM_ADDR beta, GM_ADDR y, GM_ADDR rstd, GM_ADDR x, const AddRMSNormBiasTilingData* tiling)
|
||||
{
|
||||
ASSERT(GetBlockNum() != 0 && "Block dim can not be zero!");
|
||||
|
||||
this->numCol = tiling->num_col;
|
||||
this->blockFactor = 1; // in this case, blockFactor = 1
|
||||
this->ubFactor = tiling->ub_factor;
|
||||
this->epsilon = tiling->epsilon;
|
||||
this->avgFactor = (numCol != 0) ? (float)1.0 / numCol : 0;
|
||||
this->nullptrBeta = tiling->nullptr_beta;
|
||||
|
||||
this->rowWork = 1;
|
||||
blockIdx_ = GetBlockIdx();
|
||||
// get start index for current core, core parallel
|
||||
x1Gm.SetGlobalBuffer((__gm__ T*)x1 + blockIdx_ * numCol, numCol);
|
||||
x2Gm.SetGlobalBuffer((__gm__ T*)x2 + blockIdx_ * numCol, numCol);
|
||||
gammaGm.SetGlobalBuffer((__gm__ T*)gamma, numCol);
|
||||
if (!this->nullptrBeta) {
|
||||
betaGm.SetGlobalBuffer((__gm__ T*)beta, numCol);
|
||||
}
|
||||
yGm.SetGlobalBuffer((__gm__ T*)y + blockIdx_ * numCol, numCol);
|
||||
rstdGm.SetGlobalBuffer((__gm__ float*)rstd + blockIdx_, 1);
|
||||
xGm.SetGlobalBuffer((__gm__ T*)x + blockIdx_ * numCol, numCol);
|
||||
|
||||
Ppipe->InitBuffer(unitBuf, MAXBUFFER); // (192 - 1) * 1024 byte
|
||||
}
|
||||
|
||||
__aicore__ inline void Process()
|
||||
{
|
||||
if constexpr (is_same<T, half>::value) {
|
||||
ProcessFp16();
|
||||
} else if constexpr (is_same<T, float>::value) {
|
||||
ProcessFp32();
|
||||
} else {
|
||||
ProcessBf16();
|
||||
}
|
||||
}
|
||||
|
||||
private:
|
||||
__aicore__ inline void ProcessFp16()
|
||||
{
|
||||
LocalTensor<float> ubLocal = unitBuf.Get<float>();
|
||||
LocalTensor<T> xLocal = ubLocal.template ReinterpretCast<T>();
|
||||
LocalTensor<T> x1Local = xLocal[0];
|
||||
LocalTensor<T> x2Local = xLocal[ubFactor];
|
||||
LocalTensor<float> xFp32Local = ubLocal[ubFactor];
|
||||
LocalTensor<float> sqxLocal = ubLocal[ubFactor * 2];
|
||||
LocalTensor<float> tmpLocal = ubLocal[ubFactor * 3];
|
||||
|
||||
DataCopyCustom<T>(x1Local, x1Gm, numCol);
|
||||
event_t eventMTE2V1 = static_cast<event_t>(GetTPipePtr()->FetchEventID(HardEvent::MTE2_V));
|
||||
SetFlag<HardEvent::MTE2_V>(eventMTE2V1);
|
||||
DataCopyCustom<T>(x2Local, x2Gm, numCol);
|
||||
event_t eventMTE2V2 = static_cast<event_t>(GetTPipePtr()->FetchEventID(HardEvent::MTE2_V));
|
||||
SetFlag<HardEvent::MTE2_V>(eventMTE2V2);
|
||||
WaitFlag<HardEvent::MTE2_V>(eventMTE2V1);
|
||||
WaitFlag<HardEvent::MTE2_V>(eventMTE2V2);
|
||||
Add(x1Local, x1Local, x2Local, numCol);
|
||||
PipeBarrier<PIPE_V>();
|
||||
|
||||
// copy gamma
|
||||
event_t eventVMTE2 = static_cast<event_t>(GetTPipePtr()->FetchEventID(HardEvent::V_MTE2));
|
||||
SetFlag<HardEvent::V_MTE2>(eventVMTE2);
|
||||
WaitFlag<HardEvent::V_MTE2>(eventVMTE2);
|
||||
|
||||
DataCopyCustom<T>(x2Local, gammaGm, numCol); // gammaLocal use x2Local
|
||||
SetFlag<HardEvent::MTE2_V>(eventMTE2V2);
|
||||
|
||||
// copy x out
|
||||
event_t eventVMTE3 = static_cast<event_t>(GetTPipePtr()->FetchEventID(HardEvent::V_MTE3));
|
||||
SetFlag<HardEvent::V_MTE3>(eventVMTE3);
|
||||
WaitFlag<HardEvent::V_MTE3>(eventVMTE3);
|
||||
DataCopyCustom<T>(xGm, x1Local, numCol);
|
||||
event_t eventMTE3V = static_cast<event_t>(GetTPipePtr()->FetchEventID(HardEvent::MTE3_V));
|
||||
SetFlag<HardEvent::MTE3_V>(eventMTE3V);
|
||||
|
||||
Cast(xFp32Local, x1Local, RoundMode::CAST_NONE, numCol);
|
||||
PipeBarrier<PIPE_V>();
|
||||
Mul(sqxLocal, xFp32Local, xFp32Local, numCol);
|
||||
PipeBarrier<PIPE_V>();
|
||||
Muls(sqxLocal, sqxLocal, avgFactor, numCol);
|
||||
PipeBarrier<PIPE_V>();
|
||||
ReduceSumCustom(sqxLocal, sqxLocal, tmpLocal, numCol);
|
||||
PipeBarrier<PIPE_V>();
|
||||
Adds(sqxLocal, sqxLocal, epsilon, 1);
|
||||
PipeBarrier<PIPE_V>();
|
||||
Sqrt(sqxLocal, sqxLocal, 1);
|
||||
Duplicate(tmpLocal, ONE, 1);
|
||||
PipeBarrier<PIPE_V>();
|
||||
Div(sqxLocal, tmpLocal, sqxLocal, 1);
|
||||
PipeBarrier<PIPE_V>();
|
||||
|
||||
// copyout rstd
|
||||
#if (defined(__CCE_AICORE__) && __CCE_AICORE__ == 220) || (defined(__NPU_ARCH__) && __NPU_ARCH__ == 3003)
|
||||
SetFlag<HardEvent::V_MTE3>(eventVMTE3);
|
||||
WaitFlag<HardEvent::V_MTE3>(eventVMTE3);
|
||||
DataCopyCustom<float>(rstdGm, sqxLocal, 1);
|
||||
#endif
|
||||
event_t eventVS = static_cast<event_t>(GetTPipePtr()->FetchEventID(HardEvent::V_S));
|
||||
SetFlag<HardEvent::V_S>(eventVS);
|
||||
WaitFlag<HardEvent::V_S>(eventVS);
|
||||
float rstdValue = sqxLocal.GetValue(0);
|
||||
event_t eventSV = static_cast<event_t>(GetTPipePtr()->FetchEventID(HardEvent::S_V));
|
||||
SetFlag<HardEvent::S_V>(eventSV);
|
||||
WaitFlag<HardEvent::S_V>(eventSV);
|
||||
|
||||
Muls(xFp32Local, xFp32Local, rstdValue, numCol);
|
||||
PipeBarrier<PIPE_V>();
|
||||
WaitFlag<HardEvent::MTE3_V>(eventMTE3V);
|
||||
Cast(x1Local, xFp32Local, RoundMode::CAST_NONE, numCol);
|
||||
PipeBarrier<PIPE_V>();
|
||||
WaitFlag<HardEvent::MTE2_V>(eventMTE2V2);
|
||||
Mul(x1Local, x1Local, x2Local, numCol);
|
||||
|
||||
if (!this->nullptrBeta) {
|
||||
event_t eventVMTE2Beta = static_cast<event_t>(GetTPipePtr()->FetchEventID(HardEvent::V_MTE2));
|
||||
SetFlag<HardEvent::V_MTE2>(eventVMTE2Beta);
|
||||
WaitFlag<HardEvent::V_MTE2>(eventVMTE2Beta);
|
||||
DataCopyCustom<T>(x2Local, betaGm, numCol);
|
||||
event_t eventMTE2XBeta = static_cast<event_t>(GetTPipePtr()->FetchEventID(HardEvent::MTE2_V));
|
||||
SetFlag<HardEvent::MTE2_V>(eventMTE2XBeta);
|
||||
WaitFlag<HardEvent::MTE2_V>(eventMTE2XBeta);
|
||||
Add(x1Local, x1Local, x2Local, numCol);
|
||||
}
|
||||
SetFlag<HardEvent::V_MTE3>(eventVMTE3);
|
||||
WaitFlag<HardEvent::V_MTE3>(eventVMTE3);
|
||||
DataCopyCustom<T>(yGm, x1Local, numCol);
|
||||
}
|
||||
|
||||
// Single-row fp32 path: y = ((x1 + x2) * rstd) * gamma [+ beta], x = x1 + x2.
// The UB scratch (unitBuf) is partitioned manually into four ubFactor-sized
// regions; x2Local is reused first for x2, then gamma, then beta.
// NOTE: the SetFlag/WaitFlag pairs below order work across hardware pipes
// (MTE2 = GM->UB copy-in, MTE3 = UB->GM copy-out, V = vector, S = scalar);
// their exact placement is load-bearing — do not reorder statements.
__aicore__ inline void ProcessFp32()
{
    LocalTensor<float> ubLocal = unitBuf.Get<float>();
    LocalTensor<T> x1Local = ubLocal[0];
    LocalTensor<T> x2Local = ubLocal[ubFactor];
    LocalTensor<float> sqxLocal = ubLocal[ubFactor * 2];
    LocalTensor<float> tmpLocal = ubLocal[ubFactor * 3];

    // Copy in x1 and x2, then wait for both before the vector Add.
    DataCopyCustom<T>(x1Local, x1Gm, numCol);
    event_t eventMTE2V1 = static_cast<event_t>(GetTPipePtr()->FetchEventID(HardEvent::MTE2_V));
    SetFlag<HardEvent::MTE2_V>(eventMTE2V1);
    DataCopyCustom<T>(x2Local, x2Gm, numCol);
    event_t eventMTE2V2 = static_cast<event_t>(GetTPipePtr()->FetchEventID(HardEvent::MTE2_V));
    SetFlag<HardEvent::MTE2_V>(eventMTE2V2);
    WaitFlag<HardEvent::MTE2_V>(eventMTE2V1);
    WaitFlag<HardEvent::MTE2_V>(eventMTE2V2);
    Add(x1Local, x1Local, x2Local, numCol);
    PipeBarrier<PIPE_V>();

    // copy gamma — x2Local is free after the Add, but the copy-in must wait
    // for the vector read of x2Local to finish (V_MTE2).
    event_t eventVMTE2 = static_cast<event_t>(GetTPipePtr()->FetchEventID(HardEvent::V_MTE2));
    SetFlag<HardEvent::V_MTE2>(eventVMTE2);
    WaitFlag<HardEvent::V_MTE2>(eventVMTE2);

    DataCopyCustom<T>(x2Local, gammaGm, numCol); // gammaLocal use x2Local
    SetFlag<HardEvent::MTE2_V>(eventMTE2V2);     // consumed just before the gamma Mul below

    // copy x out (x1Local now holds x1 + x2)
    event_t eventVMTE3 = static_cast<event_t>(GetTPipePtr()->FetchEventID(HardEvent::V_MTE3));
    SetFlag<HardEvent::V_MTE3>(eventVMTE3);
    WaitFlag<HardEvent::V_MTE3>(eventVMTE3);
    DataCopyCustom<T>(xGm, x1Local, numCol);
    event_t eventMTE3V = static_cast<event_t>(GetTPipePtr()->FetchEventID(HardEvent::MTE3_V));
    SetFlag<HardEvent::MTE3_V>(eventMTE3V);      // guards x1Local until copy-out completes

    // rstd = 1 / sqrt(mean(x^2) + epsilon), reduced into sqxLocal[0].
    Mul(sqxLocal, x1Local, x1Local, numCol);
    PipeBarrier<PIPE_V>();
    Muls(sqxLocal, sqxLocal, avgFactor, numCol);
    PipeBarrier<PIPE_V>();
    ReduceSumCustom(sqxLocal, sqxLocal, tmpLocal, numCol);
    PipeBarrier<PIPE_V>();
    Adds(sqxLocal, sqxLocal, epsilon, 1);
    PipeBarrier<PIPE_V>();
    Sqrt(sqxLocal, sqxLocal, 1);
    Duplicate(tmpLocal, ONE, 1);
    PipeBarrier<PIPE_V>();
    Div(sqxLocal, tmpLocal, sqxLocal, 1);
    PipeBarrier<PIPE_V>();

    // copyout rstd (only on arch variants that expose the rstd output)
#if defined(__CCE_AICORE__) && __CCE_AICORE__ == 220 || (defined(__NPU_ARCH__) && __NPU_ARCH__ == 3003)
    SetFlag<HardEvent::V_MTE3>(eventVMTE3);
    WaitFlag<HardEvent::V_MTE3>(eventVMTE3);
    DataCopyCustom<float>(rstdGm, sqxLocal, 1);
#endif
    // Read the scalar rstd on the scalar pipe (V_S before read, S_V after).
    event_t eventVS = static_cast<event_t>(GetTPipePtr()->FetchEventID(HardEvent::V_S));
    SetFlag<HardEvent::V_S>(eventVS);
    WaitFlag<HardEvent::V_S>(eventVS);
    float rstdValue = sqxLocal.GetValue(0);
    event_t eventSV = static_cast<event_t>(GetTPipePtr()->FetchEventID(HardEvent::S_V));
    SetFlag<HardEvent::S_V>(eventSV);
    WaitFlag<HardEvent::S_V>(eventSV);
    // x1Local may be overwritten only after its copy-out to xGm finished.
    WaitFlag<HardEvent::MTE3_V>(eventMTE3V);
    Muls(x1Local, x1Local, rstdValue, numCol);
    PipeBarrier<PIPE_V>();
    WaitFlag<HardEvent::MTE2_V>(eventMTE2V2); // gamma copy-in done
    Mul(x1Local, x1Local, x2Local, numCol);
    if (!this->nullptrBeta) {
        // Reuse x2Local for beta; V_MTE2 orders the copy-in after the gamma Mul.
        event_t eventVMTE2Beta = static_cast<event_t>(GetTPipePtr()->FetchEventID(HardEvent::V_MTE2));
        SetFlag<HardEvent::V_MTE2>(eventVMTE2Beta);
        WaitFlag<HardEvent::V_MTE2>(eventVMTE2Beta);
        DataCopyCustom<T>(x2Local, betaGm, numCol);
        event_t eventMTE2XBeta = static_cast<event_t>(GetTPipePtr()->FetchEventID(HardEvent::MTE2_V));
        SetFlag<HardEvent::MTE2_V>(eventMTE2XBeta);
        WaitFlag<HardEvent::MTE2_V>(eventMTE2XBeta);
        Add(x1Local, x1Local, x2Local, numCol);
    }
    // Final result copy-out.
    SetFlag<HardEvent::V_MTE3>(eventVMTE3);
    WaitFlag<HardEvent::V_MTE3>(eventVMTE3);
    DataCopyCustom<T>(yGm, x1Local, numCol);
}
|
||||
|
||||
// Single-row bf16 path: computes in fp32 (xFp32Local) and casts back to bf16
// at each store point. The bf16 view (xLocal) and the fp32 view (xFp32Local)
// alias the same unitBuf storage — xFp32Local overlays x2Local's region.
// Events with explicit AllocEventID MUST be released with ReleaseEventID after
// their Wait; the pairing below is load-bearing — do not reorder statements.
__aicore__ inline void ProcessBf16()
{
    LocalTensor<float> ubLocal = unitBuf.Get<float>();
    LocalTensor<T> xLocal = ubLocal.template ReinterpretCast<T>();
    LocalTensor<T> x1Local = xLocal[0];
    LocalTensor<T> x2Local = xLocal[ubFactor];
    LocalTensor<float> xFp32Local = ubLocal[ubFactor];
    LocalTensor<float> sqxLocal = ubLocal[ubFactor * 2];
    LocalTensor<float> tmpLocal = ubLocal[ubFactor * 3];

    // Copy in x1/x2; each copy gets its own allocated MTE2_V event so the
    // dependent Cast can start as soon as its own input has landed.
    DataCopyCustom<T>(x1Local, x1Gm, numCol);
    event_t eventMTE2V1_BF16_0 = static_cast<event_t>(GetTPipePtr()->AllocEventID<HardEvent::MTE2_V>());
    SetFlag<HardEvent::MTE2_V>(eventMTE2V1_BF16_0);
    DataCopyCustom<T>(x2Local, x2Gm, numCol);
    event_t eventMTE2V2_BF16_0 = static_cast<event_t>(GetTPipePtr()->AllocEventID<HardEvent::MTE2_V>());
    SetFlag<HardEvent::MTE2_V>(eventMTE2V2_BF16_0);
    WaitFlag<HardEvent::MTE2_V>(eventMTE2V1_BF16_0);
    GetTPipePtr()->ReleaseEventID<HardEvent::MTE2_V>(eventMTE2V1_BF16_0);
    Cast(xFp32Local, x1Local, RoundMode::CAST_NONE, numCol);
    WaitFlag<HardEvent::MTE2_V>(eventMTE2V2_BF16_0);
    GetTPipePtr()->ReleaseEventID<HardEvent::MTE2_V>(eventMTE2V2_BF16_0);
    Cast(sqxLocal, x2Local, RoundMode::CAST_NONE, numCol);
    PipeBarrier<PIPE_V>();
    // x = x1 + x2 in fp32, then round back to bf16 for the x output.
    Add(xFp32Local, xFp32Local, sqxLocal, numCol);
    PipeBarrier<PIPE_V>();
    Cast(x1Local, xFp32Local, RoundMode::CAST_RINT, numCol);
    PipeBarrier<PIPE_V>();
    // copy gamma — reuses x2Local; V_MTE2 orders it after vector reads of x2Local.
    event_t eventVMTE2_BF16_0 = static_cast<event_t>(GetTPipePtr()->FetchEventID(HardEvent::V_MTE2));
    SetFlag<HardEvent::V_MTE2>(eventVMTE2_BF16_0);
    WaitFlag<HardEvent::V_MTE2>(eventVMTE2_BF16_0);

    DataCopyCustom<T>(x2Local, gammaGm, numCol); // gammaLocal use x2Local
    event_t eventMTE2V2_BF16_1 = static_cast<event_t>(GetTPipePtr()->FetchEventID(HardEvent::MTE2_V));
    SetFlag<HardEvent::MTE2_V>(eventMTE2V2_BF16_1); // consumed before the gamma Cast below

    // copy x out
    event_t eventVMTE3_BF16_0 = static_cast<event_t>(GetTPipePtr()->FetchEventID(HardEvent::V_MTE3));
    SetFlag<HardEvent::V_MTE3>(eventVMTE3_BF16_0);
    WaitFlag<HardEvent::V_MTE3>(eventVMTE3_BF16_0);
    DataCopyCustom<T>(xGm, x1Local, numCol);
    event_t eventMTE3V_BF16_0 = static_cast<event_t>(GetTPipePtr()->AllocEventID<HardEvent::MTE3_V>());
    SetFlag<HardEvent::MTE3_V>(eventMTE3V_BF16_0); // guards x1Local until copy-out completes

    // rstd = 1 / sqrt(mean(x^2) + epsilon); x is re-derived from the rounded
    // bf16 value so y matches the stored x exactly.
    Cast(xFp32Local, x1Local, RoundMode::CAST_NONE, numCol);
    PipeBarrier<PIPE_V>();
    Mul(sqxLocal, xFp32Local, xFp32Local, numCol);
    PipeBarrier<PIPE_V>();
    Muls(sqxLocal, sqxLocal, avgFactor, numCol);
    PipeBarrier<PIPE_V>();
    ReduceSumCustom(sqxLocal, sqxLocal, tmpLocal, numCol);
    PipeBarrier<PIPE_V>();
    Adds(sqxLocal, sqxLocal, epsilon, 1);
    PipeBarrier<PIPE_V>();
    Sqrt(sqxLocal, sqxLocal, 1);
    Duplicate(tmpLocal, ONE, 1);
    PipeBarrier<PIPE_V>();
    Div(sqxLocal, tmpLocal, sqxLocal, 1);
    PipeBarrier<PIPE_V>();
    // Read the scalar rstd on the scalar pipe (V_S before read, S_V after).
    event_t eventVS_BF16_0 = static_cast<event_t>(GetTPipePtr()->FetchEventID(HardEvent::V_S));
    SetFlag<HardEvent::V_S>(eventVS_BF16_0);
    WaitFlag<HardEvent::V_S>(eventVS_BF16_0);
    float rstdValue = sqxLocal.GetValue(0);
    event_t eventSV_BF16_0 = static_cast<event_t>(GetTPipePtr()->FetchEventID(HardEvent::S_V));
    SetFlag<HardEvent::S_V>(eventSV_BF16_0);
    WaitFlag<HardEvent::S_V>(eventSV_BF16_0);
    // copyout rstd (only on arch variants that expose the rstd output)
#if defined(__CCE_AICORE__) && __CCE_AICORE__ == 220 || (defined(__NPU_ARCH__) && __NPU_ARCH__ == 3003)
    event_t eventVMTE3_BF16_1 = static_cast<event_t>(GetTPipePtr()->FetchEventID(HardEvent::V_MTE3));
    SetFlag<HardEvent::V_MTE3>(eventVMTE3_BF16_1);
    WaitFlag<HardEvent::V_MTE3>(eventVMTE3_BF16_1);
    DataCopyCustom<float>(rstdGm, sqxLocal, 1);
    event_t eventMTE3V2_BF16_0 = static_cast<event_t>(GetTPipePtr()->AllocEventID<HardEvent::MTE3_V>());
    SetFlag<HardEvent::MTE3_V>(eventMTE3V2_BF16_0); // guards sqxLocal until rstd copy-out completes
#endif

    Muls(xFp32Local, xFp32Local, rstdValue, numCol);
    PipeBarrier<PIPE_V>();
    WaitFlag<HardEvent::MTE3_V>(eventMTE3V_BF16_0); // x copy-out done; x1Local reusable
    GetTPipePtr()->ReleaseEventID<HardEvent::MTE3_V>(eventMTE3V_BF16_0);
    // Round-trip through bf16 so the normalized value matches output precision.
    Cast(x1Local, xFp32Local, RoundMode::CAST_RINT, numCol);
    PipeBarrier<PIPE_V>();
    Cast(xFp32Local, x1Local, RoundMode::CAST_NONE, numCol);
    PipeBarrier<PIPE_V>();
    WaitFlag<HardEvent::MTE2_V>(eventMTE2V2_BF16_1); // gamma copy-in done
#if defined(__CCE_AICORE__) && __CCE_AICORE__ == 220 || (defined(__NPU_ARCH__) && __NPU_ARCH__ == 3003)
    WaitFlag<HardEvent::MTE3_V>(eventMTE3V2_BF16_0); // rstd copy-out done; sqxLocal reusable
    GetTPipePtr()->ReleaseEventID<HardEvent::MTE3_V>(eventMTE3V2_BF16_0);
#endif
    // Apply gamma (cast to fp32 via sqxLocal).
    Cast(sqxLocal, x2Local, RoundMode::CAST_NONE, numCol);
    PipeBarrier<PIPE_V>();
    Mul(xFp32Local, xFp32Local, sqxLocal, numCol);
    if (!this->nullptrBeta) {
        // Reuse x2Local for beta, then add it in fp32.
        event_t eventVMTE2Beta = static_cast<event_t>(GetTPipePtr()->FetchEventID(HardEvent::V_MTE2));
        SetFlag<HardEvent::V_MTE2>(eventVMTE2Beta);
        WaitFlag<HardEvent::V_MTE2>(eventVMTE2Beta);
        DataCopyCustom<T>(x2Local, betaGm, numCol);
        event_t eventMTE2XBeta = static_cast<event_t>(GetTPipePtr()->FetchEventID(HardEvent::MTE2_V));
        SetFlag<HardEvent::MTE2_V>(eventMTE2XBeta);
        WaitFlag<HardEvent::MTE2_V>(eventMTE2XBeta);
        Cast(sqxLocal, x2Local, RoundMode::CAST_NONE, numCol);
        PipeBarrier<PIPE_V>();
        Add(xFp32Local, xFp32Local, sqxLocal, numCol);
    }
    PipeBarrier<PIPE_V>();
    // Final bf16 result and copy-out.
    Cast(x1Local, xFp32Local, RoundMode::CAST_RINT, numCol);
    event_t eventVMTE3_BF16_2 = static_cast<event_t>(GetTPipePtr()->FetchEventID(HardEvent::V_MTE3));
    SetFlag<HardEvent::V_MTE3>(eventVMTE3_BF16_2);
    WaitFlag<HardEvent::V_MTE3>(eventVMTE3_BF16_2);
    DataCopyCustom<T>(yGm, x1Local, numCol);
}
|
||||
|
||||
private:
|
||||
TPipe* Ppipe = nullptr;  // borrowed pipe, owned by the caller

TBuf<TPosition::VECCALC> unitBuf;  // one UB scratch buffer, partitioned manually in Process*()
GlobalTensor<T> x1Gm;              // input x1
GlobalTensor<T> x2Gm;              // input x2 (residual)
GlobalTensor<T> gammaGm;           // RMSNorm weight
GlobalTensor<T> betaGm;            // optional RMSNorm bias; used only when nullptrBeta == 0
GlobalTensor<T> yGm;               // normalized output
GlobalTensor<float> rstdGm;        // reciprocal standard deviation output
GlobalTensor<T> xGm;               // x1 + x2 output

uint32_t numRow;
uint32_t numCol;
uint32_t blockFactor; // number of calculations rows on each core
uint32_t ubFactor;    // per-region element count of the manual UB partition
float epsilon;
float avgFactor;      // presumably 1 / numCol, set in Init (outside this view) — TODO confirm
int32_t blockIdx_;
uint32_t rowWork = 1;
uint32_t nullptrBeta = 0; // nonzero when beta was not supplied
|
||||
};
|
||||
#endif // _ADD_RMS_NORM_BIAS_SINGLE_N_H_
|
||||
395
csrc/add_rms_norm_bias/op_kernel/add_rms_norm_bias_split_d.h
Normal file
395
csrc/add_rms_norm_bias/op_kernel/add_rms_norm_bias_split_d.h
Normal file
@@ -0,0 +1,395 @@
|
||||
/**
|
||||
* This program is free software, you can redistribute it and/or modify.
|
||||
* Copyright (c) 2025 Huawei Technologies Co., Ltd.
|
||||
* This file is a part of the CANN Open Software.
|
||||
* Licensed under CANN Open Software License Agreement Version 2.0 (the "License").
|
||||
* Please refer to the License for details. You may not use this file except in compliance with the License.
|
||||
* THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE.
|
||||
* See LICENSE in the root of the software repository for the full text of the License.
|
||||
*/
|
||||
|
||||
/*!
|
||||
* \file add_rms_norm_bias_split_d.h
|
||||
* \brief add rms norm bias split d file
|
||||
*/
|
||||
#ifndef ADD_RMS_NORM_BIAS_SPLIT_D_H_
|
||||
#define ADD_RMS_NORM_BIAS_SPLIT_D_H_
|
||||
#include "./rms_norm_base.h"
|
||||
|
||||
using namespace AscendC;
|
||||
using namespace RmsNorm;
|
||||
|
||||
// AddRmsNormBias kernel for the "split D" tiling: the hidden dimension (D)
// does not fit in UB, so each row is processed in j_max column chunks of
// ubFactor elements — one pass accumulates sum(x^2) per row, a second pass
// applies rstd * gamma (+ beta) chunk by chunk.
template <typename T>
class KernelAddRmsNormBiasSplitD {
public:
    // Stores the caller-owned pipe; all buffer setup happens in Init().
    __aicore__ inline KernelAddRmsNormBiasSplitD(TPipe* pipe)
    {
        Ppipe = pipe;
    }
    // Binds GM buffers for this core's row slice and initializes UB queues.
    // beta may be absent (tiling->nullptr_beta != 0), in which case betaGm and
    // inQueueBeta are left unbound.
    __aicore__ inline void Init(
        GM_ADDR x1, GM_ADDR x2, GM_ADDR gamma, GM_ADDR beta, GM_ADDR y, GM_ADDR rstd, GM_ADDR x, const AddRMSNormBiasTilingData* tiling)
    {
        ASSERT(GetBlockNum() != 0 && "Block dim can not be zero!");
        this->numRow = tiling->num_row;
        this->numCol = tiling->num_col;
        this->blockFactor = tiling->block_factor;
        this->rowFactor = tiling->row_factor;
        this->ubFactor = tiling->ub_factor;
        this->epsilon = tiling->epsilon;
        this->avgFactor = (numCol != 0) ? (float)1.0 / numCol : 0;
        this->nullptrBeta = tiling->nullptr_beta;

        blockIdx_ = GetBlockIdx();
        // Last core takes the remainder rows; other cores take blockFactor rows.
        if (blockIdx_ < GetBlockNum() - 1) {
            this->rowWork = blockFactor;
        } else if (blockIdx_ == GetBlockNum() - 1) {
            this->rowWork = numRow - (GetBlockNum() - 1) * blockFactor;
        } else {
        }
        // get start index for current core, core parallel
        x1Gm.SetGlobalBuffer((__gm__ T*)x1 + blockIdx_ * blockFactor * numCol, rowWork * numCol);
        x2Gm.SetGlobalBuffer((__gm__ T*)x2 + blockIdx_ * blockFactor * numCol, rowWork * numCol);
        gammaGm.SetGlobalBuffer((__gm__ T*)gamma, numCol);
        if (!this->nullptrBeta) {
            betaGm.SetGlobalBuffer((__gm__ T*)beta, numCol);
        }
        yGm.SetGlobalBuffer((__gm__ T*)y + blockIdx_ * blockFactor * numCol, rowWork * numCol);
        rstdGm.SetGlobalBuffer((__gm__ float*)rstd + blockIdx_ * blockFactor, blockFactor);
        xGm.SetGlobalBuffer((__gm__ T*)x + blockIdx_ * blockFactor * numCol, rowWork * numCol);

        // pipe alloc memory to queue, the unit is Bytes.
        // We need 2 buffers here for both x1 and x2.
        Ppipe->InitBuffer(inQueueX, BUFFER_NUM, 2 * ubFactor * sizeof(T));
        Ppipe->InitBuffer(inQueueGamma, BUFFER_NUM, ubFactor * sizeof(T));
        if (!this->nullptrBeta) {
            Ppipe->InitBuffer(inQueueBeta, BUFFER_NUM, ubFactor * sizeof(T));
        }
        Ppipe->InitBuffer(outQueueY, BUFFER_NUM, ubFactor * sizeof(T));
        Ppipe->InitBuffer(outQueueRstd, BUFFER_NUM, rowFactor * sizeof(float));

        // Half/bf16 inputs compute in fp32 and need a widening scratch buffer.
        if constexpr (is_same<T, half>::value || is_same<T, bfloat16_t>::value) {
            Ppipe->InitBuffer(xFp32Buf, ubFactor * sizeof(float));
        }
        Ppipe->InitBuffer(sqxBuf, ubFactor * sizeof(float));
        Ppipe->InitBuffer(sumBuf, rowFactor * NUM_PER_BLK_FP32 * sizeof(float));
        Ppipe->InitBuffer(reduceFp32Buf, NUM_PER_REP_FP32 * sizeof(float));
    }

    // Iterates over this core's rows in groups of rowFactor, with a shorter
    // tail group; columns are tiled inside SubProcess.
    __aicore__ inline void Process()
    {
        uint32_t i_o_max = RmsNorm::CeilDiv(rowWork, rowFactor);
        uint32_t row_tail = rowWork - (i_o_max - 1) * rowFactor;
        uint32_t j_max = RmsNorm::CeilDiv(numCol, ubFactor);
        uint32_t col_tail = numCol - (j_max - 1) * ubFactor;
        for (uint32_t i_o = 0; i_o < i_o_max - 1; i_o++) {
            SubProcess(i_o, rowFactor, j_max, col_tail);
        }
        SubProcess(i_o_max - 1, row_tail, j_max, col_tail);
    }

    // Handles one group of calc_row_num rows: pass 1 (ComputeFormer) writes
    // x = x1 + x2 to GM and accumulates per-row sum(x^2)/D into rstdLocal,
    // then ComputeRstd finishes rstd; pass 2 (ComputeLatter) re-reads x and
    // produces y chunk by chunk.
    __aicore__ inline void SubProcess(uint32_t i_o, uint32_t calc_row_num, uint32_t j_max, uint32_t col_tail)
    {
        LocalTensor<float> sumLocal = sumBuf.Get<float>();

        LocalTensor<float> rstdLocal = outQueueRstd.AllocTensor<float>();
        Duplicate(rstdLocal, (float)0.0, calc_row_num);
        PipeBarrier<PIPE_V>();
        for (uint32_t j = 0; j < j_max - 1; j++) {
            ComputeFormer(i_o, calc_row_num, j, rstdLocal, sumLocal, ubFactor);
        }
        // do tail
        ComputeFormer(i_o, calc_row_num, j_max - 1, rstdLocal, sumLocal, col_tail);
        ComputeRstd(rstdLocal, calc_row_num);

        for (uint32_t j = 0; j < j_max - 1; j++) {
            ComputeLatter(i_o, calc_row_num, j, rstdLocal, ubFactor);
        }
        ComputeLatter(i_o, calc_row_num, j_max - 1, rstdLocal, col_tail);
        outQueueRstd.EnQue<float>(rstdLocal);
        CopyOutRstd(i_o, calc_row_num);
    }

private:
    // Loads one (row, column-chunk) of x1 and x2, forms x = x1 + x2, writes x
    // to GM, and leaves the fp32 sum in xFp32Buf (half/bf16) or in the
    // still-allocated inQueueX buffer (fp32) for ComputeSum to consume.
    __aicore__ inline void CopyInAndAdd(uint32_t i_idx, uint32_t j_idx, uint32_t num)
    {
        LocalTensor<T> x1x2_in = inQueueX.AllocTensor<T>();
        LocalTensor<T> x1_in = x1x2_in[0];
        LocalTensor<T> x2_in = x1x2_in[ubFactor];
        DataCopyCustom<T>(x1_in, x1Gm[i_idx * numCol + j_idx * ubFactor], num);
        DataCopyCustom<T>(x2_in, x2Gm[i_idx * numCol + j_idx * ubFactor], num);
        inQueueX.EnQue(x1x2_in);
        LocalTensor<T> x1x2Local = inQueueX.DeQue<T>();

        auto x1Local = x1x2Local[0];
        auto x2Local = x1x2Local[ubFactor];

        LocalTensor<T> xLocal = outQueueY.AllocTensor<T>();

        if constexpr (is_same<T, half>::value) {
            LocalTensor<float> x1_fp32 = xFp32Buf.Get<float>();

            Add(xLocal, x1Local, x2Local, num);
            PipeBarrier<PIPE_V>();
            Cast(x1_fp32, xLocal, RoundMode::CAST_NONE, num);
            PipeBarrier<PIPE_V>();
            // x1+x2 saved in x1_fp32
        } else if constexpr (is_same<T, bfloat16_t>::value) {
            LocalTensor<float> x1_fp32 = xFp32Buf.Get<float>();
            // NOTE(review): x2_fp32 overlays the bf16 pair buffer in place;
            // x1 must already be widened into x1_fp32 before this write.
            LocalTensor<float> x2_fp32 = x1x2Local.template ReinterpretCast<float>();

            Cast(x1_fp32, x1Local, RoundMode::CAST_NONE, num);
            PipeBarrier<PIPE_V>();
            Cast(x2_fp32, x2Local, RoundMode::CAST_NONE, num);
            PipeBarrier<PIPE_V>();

            Add(x1_fp32, x1_fp32, x2_fp32, num);
            PipeBarrier<PIPE_V>();
            Cast(xLocal, x1_fp32, RoundMode::CAST_RINT, num);
            PipeBarrier<PIPE_V>();
            // x1+x2 saved in x1_fp32
        } else {
            Add(x1Local, x1Local, x2Local, num);
            PipeBarrier<PIPE_V>();
            Adds(xLocal, x1Local, (float)0.0, num);
            // x1+x2 saved in inQueueX
        }
        inQueueX.FreeTensor(x1x2Local);

        // copy out to workspace && x_out
        outQueueY.EnQue(xLocal);
        auto x_out = outQueueY.DeQue<T>();
        DataCopyCustom<T>(xGm[i_idx * numCol + j_idx * ubFactor], x_out, num);
        outQueueY.FreeTensor(x_out);
    }

    // Pass 1 over one column chunk: per-row partial sums of x^2/D land in
    // sumLocal and are folded into rstdLocal.
    __aicore__ inline void ComputeFormer(
        uint32_t i_o_idx, uint32_t calc_row_num, uint32_t j_idx, LocalTensor<float>& rstdLocal,
        LocalTensor<float>& sumLocal, uint32_t num)
    {
        for (uint32_t i_i = 0; i_i < calc_row_num; i_i++) {
            CopyInAndAdd(i_o_idx * rowFactor + i_i, j_idx, num);
            ComputeSum(i_i, sumLocal, num);
        }
        BlockReduceSumFP32(sumLocal, sumLocal, calc_row_num * NUM_PER_BLK_FP32);
        Add(rstdLocal, rstdLocal, sumLocal, calc_row_num);
        PipeBarrier<PIPE_V>();
    }

    // Reduces one row chunk: sum(x^2) * avgFactor into sumLocal[i_i_idx * 8].
    __aicore__ inline void ComputeSum(uint32_t i_i_idx, LocalTensor<float>& sumLocal, uint32_t num)
    {
        LocalTensor<float> sqx = sqxBuf.Get<float>();
        LocalTensor<float> reduce_buf_local = reduceFp32Buf.Get<float>();
        if constexpr (is_same<T, half>::value || is_same<T, bfloat16_t>::value) {
            LocalTensor<float> x_fp32 = xFp32Buf.Get<float>();
            PipeBarrier<PIPE_V>();
            Mul(sqx, x_fp32, x_fp32, num);
        } else {
            // NOTE(review): re-allocates the queue buffer to reach the x data
            // CopyInAndAdd left behind for the fp32 path — TODO confirm this
            // Alloc/Free peek is the intended queue usage.
            LocalTensor<T> xLocal = inQueueX.AllocTensor<float>();
            PipeBarrier<PIPE_V>();
            Mul(sqx, xLocal, xLocal, num);
            inQueueX.FreeTensor(xLocal);
        }
        PipeBarrier<PIPE_V>();
        Muls(sqx, sqx, avgFactor, num);
        PipeBarrier<PIPE_V>();
        // 8 means 8 fp32 pre block
        ReduceSumFP32ToBlock(sumLocal[i_i_idx * 8], sqx, reduce_buf_local, num);
    }

    // Finishes rstd in place: rstd = 1 / sqrt(acc + epsilon), per row.
    __aicore__ inline void ComputeRstd(LocalTensor<float> rstdLocal, uint32_t num)
    {
        LocalTensor<float> reduce_buf_local = reduceFp32Buf.Get<float>();
        Adds(rstdLocal, rstdLocal, epsilon, num);
        PipeBarrier<PIPE_V>();
        Sqrt(rstdLocal, rstdLocal, num);
        Duplicate(reduce_buf_local, ONE, num);
        PipeBarrier<PIPE_V>();
        Div(rstdLocal, reduce_buf_local, rstdLocal, num);
        PipeBarrier<PIPE_V>();
    }

    // Pass 2 over one column chunk: loads gamma/beta once, then re-reads x and
    // emits y = x * rstd * gamma (+ beta) for each row of the group.
    __aicore__ inline void ComputeLatter(
        uint32_t i_o_idx, uint32_t calc_row_num, uint32_t j_idx, LocalTensor<float>& rstdLocal, uint32_t num)
    {
        CopyInGammaBeta(j_idx, num);
        LocalTensor<T> gammaLocal = inQueueGamma.DeQue<T>();
        LocalTensor<T> betaLocal;
        if (!this->nullptrBeta) {
            betaLocal = inQueueBeta.DeQue<T>();
        }
        for (uint32_t i_i = 0; i_i < calc_row_num; i_i++) {
            CopyInX(i_o_idx * rowFactor + i_i, j_idx, num);
            ComputeY(i_i, gammaLocal, betaLocal, rstdLocal, num);
            CopyOutY(i_o_idx * rowFactor + i_i, j_idx, num);
        }
        inQueueGamma.FreeTensor(gammaLocal);
        if (!this->nullptrBeta) {
            inQueueBeta.FreeTensor(betaLocal);
        }
    }

    // Loads the gamma (and optional beta) slice for one column chunk.
    __aicore__ inline void CopyInGammaBeta(uint32_t j_idx, uint32_t num)
    {
        LocalTensor<T> gammaLocal = inQueueGamma.AllocTensor<T>();
        DataCopyCustom<T>(gammaLocal, gammaGm[j_idx * ubFactor], num);
        inQueueGamma.EnQue(gammaLocal);
        if (!this->nullptrBeta) {
            LocalTensor<T> betaLocal = inQueueBeta.AllocTensor<T>();
            DataCopyCustom<T>(betaLocal, betaGm[j_idx * ubFactor], num);
            inQueueBeta.EnQue(betaLocal);
        }
    }

    // Re-loads one (row, column-chunk) of x; half/bf16 additionally widen it
    // into xFp32Buf and release the queue buffer, fp32 leaves it enqueued for
    // the fp32 ComputeY overload to dequeue.
    __aicore__ inline void CopyInX(uint32_t i_idx, uint32_t j_idx, uint32_t num)
    {
        LocalTensor<T> xLocal = inQueueX.AllocTensor<T>();
        DataCopyCustom<T>(xLocal, xGm[i_idx * numCol + j_idx * ubFactor], num);
        inQueueX.EnQue<T>(xLocal);
        if constexpr (is_same<T, half>::value || is_same<T, bfloat16_t>::value) {
            LocalTensor<float> x_fp32 = xFp32Buf.Get<float>();
            LocalTensor<T> xLocal = inQueueX.DeQue<T>();
            Cast(x_fp32, xLocal, RoundMode::CAST_NONE, num);
            PipeBarrier<PIPE_V>();
            inQueueX.FreeTensor(xLocal);
        }
    }

    // half overload: y = cast(x_fp32 * rstd) * gamma (+ beta), all in half.
    __aicore__ inline void ComputeY(
        uint32_t i_i_idx, LocalTensor<half>& gammaLocal, LocalTensor<half>& betaLocal, LocalTensor<float>& rstdLocal, uint32_t num)
    {
        LocalTensor<float> x_fp32 = xFp32Buf.Get<float>();
        LocalTensor<float> sqx = sqxBuf.Get<float>();
        // V_S / S_V fence the scalar read of this row's rstd.
        event_t event_v_s = static_cast<event_t>(GetTPipePtr()->FetchEventID(HardEvent::V_S));
        SetFlag<HardEvent::V_S>(event_v_s);
        WaitFlag<HardEvent::V_S>(event_v_s);
        float rstdValue = rstdLocal.GetValue(i_i_idx);
        event_t event_s_v = static_cast<event_t>(GetTPipePtr()->FetchEventID(HardEvent::S_V));
        SetFlag<HardEvent::S_V>(event_s_v);
        WaitFlag<HardEvent::S_V>(event_s_v);
        PipeBarrier<PIPE_V>();
        Muls(x_fp32, x_fp32, rstdValue, num);
        PipeBarrier<PIPE_V>();
        LocalTensor<half> yLocal = outQueueY.AllocTensor<half>();
        Cast(yLocal, x_fp32, RoundMode::CAST_NONE, num);
        PipeBarrier<PIPE_V>();
        Mul(yLocal, gammaLocal, yLocal, num);
        PipeBarrier<PIPE_V>();
        if (!this->nullptrBeta) {
            Add(yLocal, betaLocal, yLocal, num);
            PipeBarrier<PIPE_V>();
        }
        outQueueY.EnQue<half>(yLocal);
    }

    // fp32 overload: consumes the x chunk CopyInX left enqueued.
    __aicore__ inline void ComputeY(
        uint32_t i_i_idx, LocalTensor<float>& gammaLocal, LocalTensor<float>& betaLocal, LocalTensor<float>& rstdLocal, uint32_t num)
    {
        LocalTensor<float> xLocal = inQueueX.DeQue<float>();
        LocalTensor<float> sqx = sqxBuf.Get<float>();
        // V_S / S_V fence the scalar read of this row's rstd.
        event_t event_v_s = static_cast<event_t>(GetTPipePtr()->FetchEventID(HardEvent::V_S));
        SetFlag<HardEvent::V_S>(event_v_s);
        WaitFlag<HardEvent::V_S>(event_v_s);
        float rstdValue = rstdLocal.GetValue(i_i_idx);
        event_t event_s_v = static_cast<event_t>(GetTPipePtr()->FetchEventID(HardEvent::S_V));
        SetFlag<HardEvent::S_V>(event_s_v);
        WaitFlag<HardEvent::S_V>(event_s_v);
        LocalTensor<float> yLocal = outQueueY.AllocTensor<float>();
        Muls(yLocal, xLocal, rstdValue, num);
        inQueueX.FreeTensor(xLocal);
        PipeBarrier<PIPE_V>();
        Mul(yLocal, gammaLocal, yLocal, num);
        PipeBarrier<PIPE_V>();
        if (!this->nullptrBeta) {
            Add(yLocal, betaLocal, yLocal, num);
            PipeBarrier<PIPE_V>();
        }
        outQueueY.EnQue<float>(yLocal);
    }

    // bf16 overload: rounds x * rstd to bf16 first, then applies gamma/beta
    // in fp32 so y matches bf16 output precision at each step.
    __aicore__ inline void ComputeY(
        uint32_t i_i_idx, LocalTensor<bfloat16_t>& gammaLocal, LocalTensor<bfloat16_t>& betaLocal, LocalTensor<float>& rstdLocal, uint32_t num)
    {
        LocalTensor<float> x_fp32 = xFp32Buf.Get<float>();
        LocalTensor<float> sqx = sqxBuf.Get<float>();
        // V_S / S_V fence the scalar read of this row's rstd.
        event_t event_v_s = static_cast<event_t>(GetTPipePtr()->FetchEventID(HardEvent::V_S));
        SetFlag<HardEvent::V_S>(event_v_s);
        WaitFlag<HardEvent::V_S>(event_v_s);
        float rstdValue = rstdLocal.GetValue(i_i_idx);
        event_t event_s_v = static_cast<event_t>(GetTPipePtr()->FetchEventID(HardEvent::S_V));
        SetFlag<HardEvent::S_V>(event_s_v);
        WaitFlag<HardEvent::S_V>(event_s_v);
        PipeBarrier<PIPE_V>();
        Muls(x_fp32, x_fp32, rstdValue, num);
        PipeBarrier<PIPE_V>();
        LocalTensor<bfloat16_t> yLocal = outQueueY.AllocTensor<bfloat16_t>();
        Cast(yLocal, x_fp32, RoundMode::CAST_RINT, num);
        PipeBarrier<PIPE_V>();
        Cast(x_fp32, yLocal, RoundMode::CAST_NONE, num);
        PipeBarrier<PIPE_V>();
        Cast(sqx, gammaLocal, RoundMode::CAST_NONE, num);
        PipeBarrier<PIPE_V>();
        Mul(x_fp32, x_fp32, sqx, num);
        PipeBarrier<PIPE_V>();
        if (!this->nullptrBeta) {
            Cast(sqx, betaLocal, RoundMode::CAST_NONE, num);
            PipeBarrier<PIPE_V>();
            Add(x_fp32, x_fp32, sqx, num);
            PipeBarrier<PIPE_V>();
        }
        Cast(yLocal, x_fp32, RoundMode::CAST_RINT, num);
        PipeBarrier<PIPE_V>();
        outQueueY.EnQue<bfloat16_t>(yLocal);
    }

    // Copies one finished y chunk back to GM.
    __aicore__ inline void CopyOutY(uint32_t i_idx, uint32_t j_idx, uint32_t num)
    {
        LocalTensor<T> yLocal = outQueueY.DeQue<T>();
        DataCopyCustom<T>(yGm[i_idx * numCol + j_idx * ubFactor], yLocal, num);
        outQueueY.FreeTensor(yLocal);
    }

    // Copies the group's rstd values to GM (on arch variants that expose rstd).
    __aicore__ inline void CopyOutRstd(uint32_t i_o_idx, uint32_t num)
    {
        LocalTensor<float> rstdLocal = outQueueRstd.DeQue<float>();
#if __CCE_AICORE__ == 220 || (defined(__NPU_ARCH__) && __NPU_ARCH__ == 3003)
        DataCopyCustom<float>(rstdGm[i_o_idx * rowFactor], rstdLocal, num);
#endif
        outQueueRstd.FreeTensor(rstdLocal);
    }

private:
    TPipe* Ppipe = nullptr;  // borrowed pipe, owned by the caller
    // create queues for input, in this case depth is equal to buffer num
    TQue<QuePosition::VECIN, BUFFER_NUM> inQueueX;
    TQue<QuePosition::VECIN, BUFFER_NUM> inQueueGamma;
    TQue<QuePosition::VECIN, BUFFER_NUM> inQueueBeta;
    // create queues for output, in this case depth is equal to buffer num
    TQue<QuePosition::VECOUT, BUFFER_NUM> outQueueY;
    TQue<QuePosition::VECOUT, BUFFER_NUM> outQueueRstd;
    TBuf<TPosition::VECCALC> xFp32Buf;      // fp32 widening scratch (half/bf16 only)
    TBuf<TPosition::VECCALC> sqxBuf;        // x^2 / gamma / beta fp32 scratch
    TBuf<TPosition::VECCALC> sumBuf;        // per-row partial sums, 8 fp32 per row
    TBuf<TPosition::VECCALC> reduceFp32Buf; // reduce scratch / constant-one buffer

    GlobalTensor<T> x1Gm;       // input x1
    GlobalTensor<T> x2Gm;       // input x2 (residual)
    GlobalTensor<T> gammaGm;    // RMSNorm weight
    GlobalTensor<T> betaGm;     // optional RMSNorm bias; bound only when nullptrBeta == 0
    GlobalTensor<T> yGm;        // normalized output
    GlobalTensor<float> rstdGm; // reciprocal standard deviation output
    GlobalTensor<T> xGm;        // x1 + x2 output (also re-read in pass 2)

    uint32_t numRow;
    uint32_t numCol;
    uint32_t blockFactor; // number of calculations rows on each core
    uint32_t rowFactor;   // rows per SubProcess group
    uint32_t ubFactor;    // elements per column chunk
    float epsilon;
    float avgFactor;      // 1 / numCol (0 when numCol == 0)
    int32_t blockIdx_;
    uint32_t rowWork = 1;
    uint32_t nullptrBeta = 0; // nonzero when beta was not supplied

    int tempbufNum; // NOTE(review): appears unused in this file — TODO confirm and remove
};
|
||||
#endif // ADD_RMS_NORM_BIAS_SPLIT_D_H_
|
||||
179
csrc/add_rms_norm_bias/op_kernel/reduce_common.h
Normal file
179
csrc/add_rms_norm_bias/op_kernel/reduce_common.h
Normal file
@@ -0,0 +1,179 @@
|
||||
/**
|
||||
* This program is free software, you can redistribute it and/or modify.
|
||||
* Copyright (c) 2025 Huawei Technologies Co., Ltd.
|
||||
* This file is a part of the CANN Open Software.
|
||||
* Licensed under CANN Open Software License Agreement Version 2.0 (the "License").
|
||||
* Please refer to the License for details. You may not use this file except in compliance with the License.
|
||||
* THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE.
|
||||
* See LICENSE in the root of the software repository for the full text of the License.
|
||||
*/
|
||||
/*!
|
||||
* \file reduce_common.h
|
||||
*/
|
||||
#ifndef REDUCE_COMMON_H_RMS_NORM
|
||||
#define REDUCE_COMMON_H_RMS_NORM
|
||||
#include "kernel_operator.h"
|
||||
using namespace AscendC;
|
||||
|
||||
constexpr uint32_t MAX_REP_NUM = 255;       // hardware limit on vector-instruction repeat count
constexpr uint32_t ELEM_PER_REP_FP32 = 64;  // fp32 lanes per vector repeat (256B)
constexpr uint32_t ELEM_PER_BLK_FP32 = 8;   // fp32 elements per 32B block
constexpr float ZERO = 0;
// Divisor for the halving reduction. NOTE(review): mixed-case name is a typo
// for HALF_INTERVAL; kept as-is because other definitions reference it.
constexpr int32_t HALf_INTERVAL = 2;
// Shift distances used to smear the MSB in findPowerTwo (1 is used literally).
constexpr int32_t INDEX_TWO = 2;
constexpr int32_t INDEX_FOUR = 4;
constexpr int32_t INDEX_EIGHT = 8;
constexpr int32_t INDEX_SIXTEEN = 16;
|
||||
|
||||
// Accumulates each row's 64-wide column chunks of srcLocal into tmpLocal
// (one 64-lane accumulator per row, `repeat` rows at a time, row pitch
// `repStride` blocks), then collapses each accumulator to a scalar per row
// with WholeReduceSum into dstLocal.
// elemNum: lanes handled per Add repeat (64 for full chunks);
// tailCount: lanes in the final partial chunk (0 when numLastDim % 64 == 0).
__aicore__ inline void ReduceSumForSmallReduceDimPreRepeat(
    const LocalTensor<float>& dstLocal, const LocalTensor<float>& srcLocal, const LocalTensor<float>& tmpLocal,
    const uint32_t elemNum, const uint32_t numLastDim, const uint32_t tailCount, const uint32_t repeat,
    const uint8_t repStride)
{
    uint32_t elemIndex = 0;
    // Full 64-lane chunks: strided Add folds chunk `elemIndex` of every row
    // into that row's accumulator.
    for (; elemIndex + ELEM_PER_REP_FP32 <= numLastDim; elemIndex += ELEM_PER_REP_FP32) {
        Add(tmpLocal, srcLocal[elemIndex], tmpLocal, elemNum, repeat,
            {1, 1, 1, ELEM_PER_BLK_FP32, repStride, ELEM_PER_BLK_FP32});
        PipeBarrier<PIPE_V>();
    }
    // Partial final chunk, if any.
    if (unlikely(tailCount != 0)) {
        Add(tmpLocal, srcLocal[elemIndex], tmpLocal, tailCount, repeat,
            {1, 1, 1, ELEM_PER_BLK_FP32, repStride, ELEM_PER_BLK_FP32});
    }
    PipeBarrier<PIPE_V>();
    AscendCUtils::SetMask<float>(ELEM_PER_REP_FP32); // set mask = 64
    // Arch-specific final reduction of each 64-lane accumulator to one float.
#if defined(__CCE_AICORE__) && __CCE_AICORE__ == 220
    if ASCEND_IS_AIV {
        WholeReduceSum<float, false>(dstLocal, tmpLocal, elemNum, repeat, 1, 1, ELEM_PER_BLK_FP32);
    }
#elif defined(__NPU_ARCH__) && __NPU_ARCH__ == 3003
    WholeReduceSum(dstLocal, tmpLocal, elemNum, repeat, 1, 1, ELEM_PER_BLK_FP32);
#else
    WholeReduceSum<float, false>(dstLocal, tmpLocal, elemNum, repeat, 1, 1, ELEM_PER_BLK_FP32);
#endif
}
|
||||
|
||||
/*
|
||||
* reduce dim form (N, D) to (N, 1)
|
||||
* this reduce sum is for small reduce dim.
|
||||
*/
|
||||
/*
 * reduce dim from (N, D) to (N, 1)
 * this reduce sum is for small reduce dim.
 *
 * The hardware repeat field caps at MAX_REP_NUM (255) rows per instruction,
 * so when `repeat` (= N) exceeds it the rows are processed in batches of 255
 * plus a remainder batch; otherwise a single call suffices.
 * numLastDimAligned is the row pitch of srcLocal in elements; tmpLocal holds
 * one 64-lane fp32 accumulator per row.
 */
__aicore__ inline void ReduceSumForSmallReduceDim(
    const LocalTensor<float>& dstLocal, const LocalTensor<float>& srcLocal, const LocalTensor<float>& tmpLocal,
    const uint32_t numLastDimAligned, const uint32_t numLastDim, const uint32_t tailCount, const uint32_t repeat,
    const uint8_t repStride)
{
    uint32_t repeatTimes = repeat / MAX_REP_NUM;
    if (repeatTimes == 0) {
        // All rows fit in one instruction's repeat field.
        ReduceSumForSmallReduceDimPreRepeat(
            dstLocal, srcLocal, tmpLocal, ELEM_PER_REP_FP32, numLastDim, tailCount, repeat, repStride);
    } else {
        uint32_t repTailNum = repeat % MAX_REP_NUM;
        uint32_t repIndex = 0;
        // Full batches of MAX_REP_NUM rows; offsets advance by rows * pitch.
        for (; repIndex + MAX_REP_NUM <= repeat; repIndex += MAX_REP_NUM) {
            ReduceSumForSmallReduceDimPreRepeat(
                dstLocal[repIndex], srcLocal[repIndex * numLastDimAligned], tmpLocal[repIndex * ELEM_PER_REP_FP32],
                ELEM_PER_REP_FP32, numLastDim, tailCount, MAX_REP_NUM, repStride);
        }
        // Remainder rows, if N is not a multiple of MAX_REP_NUM.
        if (repTailNum != 0) {
            ReduceSumForSmallReduceDimPreRepeat(
                dstLocal[repIndex], srcLocal[repIndex * numLastDimAligned], tmpLocal[repIndex * ELEM_PER_REP_FP32],
                ELEM_PER_REP_FP32, numLastDim, tailCount, repTailNum, repStride);
        }
    }
}
|
||||
|
||||
/*
|
||||
* reduce dim form (N, D) to (N, 1)
|
||||
* this reduce sum is for small reduce dim, require D < 255 * 8.
|
||||
* size of tmpLocal: (N, 64)
|
||||
*/
|
||||
/*
 * reduce dim from (N, D) to (N, 1)
 * this reduce sum is for small reduce dim, require D < 255 * 8.
 * size of tmpLocal: (N, 64)
 * tmpLocal is zeroed here and then used as the per-row accumulation scratch
 * by ReduceSumForSmallReduceDim.
 */
__aicore__ inline void ReduceSumMultiN(
    const LocalTensor<float>& dstLocal, const LocalTensor<float>& srcLocal, const LocalTensor<float>& tmpLocal,
    const uint32_t numRow, const uint32_t numCol, const uint32_t numColAlign)
{
    // Clear the (numRow x 64) accumulator before the strided adds begin.
    Duplicate(tmpLocal, ZERO, numRow * ELEM_PER_REP_FP32);
    PipeBarrier<PIPE_V>();

    const uint32_t lastChunkLanes = numCol % ELEM_PER_REP_FP32;          // lanes in the final partial 64-wide chunk
    const uint32_t rowRepeats = numRow;                                  // one hardware repeat per row
    const uint8_t srcRowStride = numColAlign / ELEM_PER_BLK_FP32;        // row pitch expressed in 32B blocks
    ReduceSumForSmallReduceDim(
        dstLocal, srcLocal, tmpLocal, numColAlign, numCol, lastChunkLanes, rowRepeats, srcRowStride);
}
|
||||
|
||||
/*
 * Returns the largest power of two that does not exceed n, treating n as a
 * 32-bit value (returns 0 for n == 0).
 */
__aicore__ inline int32_t findPowerTwo(int32_t n)
{
    // Smear the most-significant set bit into every lower position; shift
    // distances 1, 2, 4, 8, 16 cover all 32 bits, leaving n as a mask of the
    // form 0b0...01...1.
    for (int32_t shift = 1; shift <= INDEX_SIXTEEN; shift <<= 1) {
        n |= n >> shift;
    }
    // mask + 1 is the next power of two above the original n; halving it
    // yields the largest power of two <= n.
    return (n + 1) >> 1;
}
|
||||
|
||||
// Sums `count` fp32 values of src_local into dst_local[0] by repeated
// interval halving: the tail beyond the largest power of two is folded in
// first, then the power-of-two body is halved in place down to one 64-lane
// repeat, which WholeReduceSum collapses to a scalar.
// NOTE: src_local is clobbered by the in-place Adds.
__aicore__ inline void ReduceSumHalfInterval(
    const LocalTensor<float>& dst_local, const LocalTensor<float>& src_local, int32_t count)
{
    if (likely(count > ELEM_PER_REP_FP32)) {
        int32_t bodyCount = findPowerTwo(count);
        int32_t tailCount = count - bodyCount;
        // Fold the non-power-of-two tail into the front of the body.
        if (tailCount > 0) {
            Add(src_local, src_local, src_local[bodyCount], tailCount);
            PipeBarrier<PIPE_V>();
        }
        // Halve the body until a single 64-lane repeat remains.
        while (bodyCount > ELEM_PER_REP_FP32) {
            bodyCount = bodyCount / HALf_INTERVAL;
            Add(src_local, src_local, src_local[bodyCount], bodyCount);
            PipeBarrier<PIPE_V>();
        }

        AscendCUtils::SetMask<float>(ELEM_PER_REP_FP32);
    } else {
        // Small input: reduce all `count` lanes directly.
        AscendCUtils::SetMask<float>(count);
    }
    // Arch-specific final 64-lane (or `count`-lane) reduction to a scalar.
#if defined(__CCE_AICORE__) && __CCE_AICORE__ == 220
    if (g_coreType == AIV) {
        WholeReduceSum<float, false>(dst_local, src_local, ELEM_PER_REP_FP32, 1, 0, 1, 0);
    }
#elif defined(__NPU_ARCH__) && __NPU_ARCH__ == 3003
    WholeReduceSum(dst_local, src_local, ELEM_PER_REP_FP32, 1, 1, 1, ELEM_PER_BLK_FP32);
#else
    WholeReduceSum<float, false>(dst_local, src_local, ELEM_PER_REP_FP32, 1, 1, 1, DEFAULT_REPEAT_STRIDE);
#endif
    PipeBarrier<PIPE_V>();
}
|
||||
|
||||
/*
 * Scalar-returning overload: same half-interval fold as the two-tensor
 * version above, but reduces in place and reads the final sum back to the
 * scalar pipeline via a V->S event before returning it.
 * WARNING: src_local is clobbered (used as its own workspace and output).
 */
__aicore__ inline float ReduceSumHalfInterval(const LocalTensor<float>& src_local, int32_t count)
{
    if (likely(count > ELEM_PER_REP_FP32)) {
        // Split into power-of-two body + tail, fold the tail in first.
        int32_t bodyCount = findPowerTwo(count);
        int32_t tailCount = count - bodyCount;
        if (tailCount > 0) {
            Add(src_local, src_local, src_local[bodyCount], tailCount);
            PipeBarrier<PIPE_V>();
        }
        // Binary fold down to a single 64-element repeat.
        while (bodyCount > ELEM_PER_REP_FP32) {
            bodyCount = bodyCount / HALf_INTERVAL;
            Add(src_local, src_local, src_local[bodyCount], bodyCount);
            PipeBarrier<PIPE_V>();
        }

        AscendCUtils::SetMask<float>(ELEM_PER_REP_FP32);
    } else {
        AscendCUtils::SetMask<float>(count);
    }
    // Arch-specific final reduce; result lands in src_local[0].
#if defined(__CCE_AICORE__) && __CCE_AICORE__ == 220
    if (g_coreType == AIV) {
        WholeReduceSum<float, false>(src_local, src_local, ELEM_PER_REP_FP32, 1, 0, 1, 0);
    }
#elif defined(__NPU_ARCH__) && __NPU_ARCH__ == 3003
    WholeReduceSum(src_local, src_local, ELEM_PER_REP_FP32, 1, 1, 1, ELEM_PER_BLK_FP32);
#else
    WholeReduceSum<float, false>(src_local, src_local, ELEM_PER_REP_FP32, 1, 1, 1, DEFAULT_REPEAT_STRIDE);
#endif
    // Vector->scalar handoff: wait until the reduce has retired before the
    // scalar unit reads element 0.
    event_t event_v_s = static_cast<event_t>(GetTPipePtr()->FetchEventID(HardEvent::V_S));
    SetFlag<HardEvent::V_S>(event_v_s);
    WaitFlag<HardEvent::V_S>(event_v_s);
    return src_local.GetValue(0);
}
|
||||
#endif // _REDUCE_COMMON_H_
|
||||
316
csrc/add_rms_norm_bias/op_kernel/rms_norm_base.h
Normal file
316
csrc/add_rms_norm_bias/op_kernel/rms_norm_base.h
Normal file
@@ -0,0 +1,316 @@
|
||||
/**
|
||||
* This program is free software, you can redistribute it and/or modify.
|
||||
* Copyright (c) 2025 Huawei Technologies Co., Ltd.
|
||||
* This file is a part of the CANN Open Software.
|
||||
* Licensed under CANN Open Software License Agreement Version 2.0 (the "License").
|
||||
* Please refer to the License for details. You may not use this file except in compliance with the License.
|
||||
* THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE.
|
||||
* See LICENSE in the root of the software repository for the full text of the License.
|
||||
*/
|
||||
|
||||
#ifndef RMS_NORM_BASE_H_
|
||||
#define RMS_NORM_BASE_H_
|
||||
#include "kernel_operator.h"
|
||||
#include "reduce_common.h"
|
||||
|
||||
namespace RmsNorm {
|
||||
using namespace AscendC;
|
||||
|
||||
|
||||
/**
|
||||
* Get the block size of unified buffer in bytes
|
||||
*/
|
||||
__aicore__ inline constexpr uint32_t GetUbBlockSize()
|
||||
{
|
||||
return 32U;
|
||||
}
|
||||
|
||||
/**
 * @brief Size of one vector register, in bytes.
 * @return On arch 310, the toolkit-provided AscendC::VECTOR_REG_WIDTH;
 *         otherwise a fixed 256 bytes (one vector repeat).
 */
__aicore__ inline constexpr uint32_t GetVRegSize()
{
#if __CCE_AICORE__ == 310
    // 310 exposes its register width through the AscendC headers.
    return AscendC::VECTOR_REG_WIDTH;
#else
    // Other targets use the classic 256-byte repeat width.
    return 256U;
#endif
}
|
||||
|
||||
// On archs without native bf16 support, alias bfloat16_t to a same-width
// integer so the templates still instantiate (data is moved, not computed on).
#if defined(__CCE_AICORE__) && __CCE_AICORE__ != 220 && __CCE_AICORE__ != 310
#define bfloat16_t int16_t
#endif
constexpr int32_t BUFFER_NUM = 1; // tensor num for each queue
constexpr int32_t DOUBLE_BUFFER_NUM = 2; // tensor num per queue when double buffering
constexpr int32_t UNROLL_NUM = 2; // manual loop-unroll factor
constexpr int32_t NUM_PER_REP_FP32 = 64; // ONE_REPEAT_BYTE_SIZE / sizeof(float);
constexpr int32_t NUM_PER_BLK_FP32 = 8; // fp32 elements per 32-byte block
constexpr int32_t FLOAT_BTYPE_SIZE = 4; // sizeof(float); NOTE: name typo ("BTYPE"), kept for compatibility
constexpr int32_t NUM_PER_BLK_FP16 = 16; // fp16 elements per 32-byte block
constexpr int32_t CONTINUE_STRIDE = 8; // contiguous repeat stride in blocks
constexpr int32_t BLOCK_SIZE = 32; // unified-buffer block size in bytes
constexpr uint32_t ONCE_VECTOR_SIZE = 256; // bytes processed by one vector repeat
constexpr float MINUS_HALF = -0.5f; // exponent for rsqrt via power
constexpr uint32_t ZERO_UINT = 0;
constexpr uint32_t ONE_UINT = 1;
constexpr uint32_t TWO_UINT = 2;
constexpr uint32_t THREE_UINT = 3;
constexpr float ONE = 1;
constexpr int32_t SECOND_LOOP = 2;
constexpr int32_t HALf_INTERVAL = 2; // halving factor for binary-fold reductions; NOTE: name typo ("HALf"), kept for compatibility
constexpr int32_t MAX_REAPEAT = 255; // max hardware repeat count; NOTE: name typo ("REAPEAT"), kept for compatibility
constexpr int32_t DIM_NUM = 2; // expected input rank (N, D)
constexpr int32_t NDDMA_DIM = 5; // dimension count for ND-DMA descriptors
// Vector width in fp32 lanes, derived from the arch-specific register size.
constexpr uint32_t V_LENGTH = GetVRegSize() / sizeof(float);
constexpr uint64_t ALIGN_512_FACTOR = 512; // 512-byte alignment granule
constexpr uint64_t ALIGN_32_FACTOR = 32; // 32-byte alignment granule
constexpr int32_t CONST_FACTOR_2 = 2;
constexpr uint32_t SUM_COUNT = 2;
// Shift distances used by findPowerTwo's bit-smearing below.
constexpr int32_t MOV_2 = 2;
constexpr int32_t MOV_4 = 4;
constexpr int32_t MOV_8 = 8;
constexpr int32_t MOV_16 = 16;
|
||||
|
||||
template <typename T>
|
||||
__aicore__ inline T CeilDiv(T x, T y)
|
||||
{
|
||||
return y == 0 ? x : (x + y - 1) / y;
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
__aicore__ inline T Min(T left, T right)
|
||||
{
|
||||
return (left < right ? left : right);
|
||||
}
|
||||
|
||||
// Minimal re-implementation of <type_traits> machinery: device kernels avoid
// host standard-library headers, so the pieces needed for compile-time dtype
// dispatch are declared locally.
template <typename Tp, Tp v>
struct integral_constant {
    static constexpr Tp value = v;
};
using true_type = integral_constant<bool, true>;
using false_type = integral_constant<bool, false>;
// Primary template: any two distinct types compare unequal.
template <typename, typename>
struct is_same : public false_type {};
// Partial specialization: a type always equals itself.
template <typename Tp>
struct is_same<Tp, Tp> : public true_type {};
|
||||
|
||||
// Base class whose template parameters name the input dtype (T) and gamma
// dtype (T_GAMMA). Its body only hosts dtype-dispatch macros for derived
// kernels.
// NOTE(review): #define has no class scope — these macros leak file-wide and
// expand against whatever T / T_GAMMA are visible at the use site; they are
// intended for use inside classes templated on the same parameter names.
template <typename T, typename T_GAMMA>
class KernelRmsNormBase {
// True when the activation dtype is fp32.
#define IS_X_FP32 (is_same<T, float>::value)
// True when the gamma (scale) dtype is fp32.
#define IS_GAMMA_FP32 (is_same<T_GAMMA, float>::value)
// Mixed-precision case: fp16/bf16 activations with fp32 gamma.
#define IS_MIX_DTYPE ((!IS_X_FP32) && IS_GAMMA_FP32)
};
|
||||
|
||||
__aicore__ inline int32_t findPowerTwo(int32_t n)
{
    // Largest power of two not exceeding n (32-bit).
    // Propagate the most significant set bit into all lower positions
    // (shift distances 1, 2, 4, 8, 16 cover the full word), then round the
    // all-ones pattern back down to a single bit.
    for (int32_t shift = 1; shift < 32; shift <<= 1) {
        n |= n >> shift;
    }
    return (n + 1) >> 1;
}
|
||||
|
||||
/*
 * Binary-fold reduction that stops at one block (8 fp32) rather than a
 * single scalar: folds src_local down and writes the final 8-element partial
 * sums into dst_local.
 * `left` is the caller-precomputed tail beyond the power-of-two body
 * (i.e. count - findPowerTwo(count)).
 * WARNING: src_local is clobbered.
 * NOTE(review): when count <= NUM_PER_BLK_FP32 nothing is written to
 * dst_local — callers appear responsible for that case; confirm at call sites.
 */
__aicore__ inline void ReduceSumHalfIntervalToRepeat(
    const LocalTensor<float>& dst_local, const LocalTensor<float>& src_local, int32_t count, int32_t left)
{
    // count need smaller than 255 repeat
    if (likely(count > NUM_PER_BLK_FP32)) {
        int32_t bodyCount = count - left;
        int32_t tailCount = left;
        if (tailCount > 0) {
            // Fold the non-power-of-two tail onto the front first.
            Add(src_local, src_local, src_local[bodyCount], tailCount);
            PipeBarrier<PIPE_V>();
        }
        // Halve in place until only two blocks (16 elements) remain.
        while (bodyCount > SECOND_LOOP * NUM_PER_BLK_FP32) {
            bodyCount = bodyCount / HALf_INTERVAL;
            Add(src_local, src_local, src_local[bodyCount], bodyCount);
            PipeBarrier<PIPE_V>();
        }
        // Final halving writes the surviving 8-element block to dst_local.
        bodyCount = bodyCount / HALf_INTERVAL;
        Add(dst_local, src_local, src_local[bodyCount], bodyCount);
        PipeBarrier<PIPE_V>();
    }
}
|
||||
|
||||
/*
 * Sums `count` fp32 elements of src_local into dst_local[0].
 * Strategy: strided Add accumulates every 64-element repeat of src_local
 * into a single 64-element accumulator (work_local), then one
 * WholeReduceSum collapses the accumulator to a scalar.
 * work_local must hold at least 64 floats; src_local is NOT modified.
 */
__aicore__ inline void ReduceSumFP32(
    const LocalTensor<float>& dst_local, const LocalTensor<float>& src_local, const LocalTensor<float>& work_local,
    int32_t count)
{
    // count need smaller than 255 repeat
    uint64_t mask = NUM_PER_REP_FP32;
    int32_t repeatTimes = count / NUM_PER_REP_FP32;
    int32_t tailCount = count % NUM_PER_REP_FP32;
    int32_t bodyCount = repeatTimes * NUM_PER_REP_FP32;
    // src0 walks forward one repeat (256 B) per iteration; src1/dst stay
    // pinned on work_local (rep stride 0), so repeats accumulate in place.
    BinaryRepeatParams repeatParams;
    repeatParams.src0RepStride = ONE_REPEAT_BYTE_SIZE / ONE_BLK_SIZE;
    repeatParams.src0BlkStride = 1;
    repeatParams.src1RepStride = 0;
    repeatParams.src1BlkStride = 1;
    repeatParams.dstRepStride = 0;
    repeatParams.dstBlkStride = 1;
    // Zero the 64-element accumulator.
    Duplicate(work_local, ZERO, NUM_PER_REP_FP32);
    PipeBarrier<PIPE_V>();
    if (likely(repeatTimes > 0)) {
        // Accumulate all whole repeats: work += src[rep] for each repeat.
        Add(work_local, src_local, work_local, mask, repeatTimes, repeatParams);
        PipeBarrier<PIPE_V>();
    }
    if (unlikely(tailCount != 0)) {
        // Fold the partial last repeat in with a narrowed element count.
        Add(work_local, src_local[bodyCount], work_local, tailCount, 1, repeatParams);
        PipeBarrier<PIPE_V>();
    }
    // Collapse the 64-lane accumulator to a single scalar in dst_local[0].
    AscendCUtils::SetMask<float>(NUM_PER_REP_FP32);
#if defined(__CCE_AICORE__) && __CCE_AICORE__ == 220
    if (g_coreType == AIV) {
        WholeReduceSum<float, false>(dst_local, work_local, MASK_PLACEHOLDER, 1, 0, 1, 0);
    }
#elif !(defined(__NPU_ARCH__) && __NPU_ARCH__ == 3003)
    WholeReduceSum<float, false>(dst_local, work_local, MASK_PLACEHOLDER, 1, 1, 1, DEFAULT_REPEAT_STRIDE);
#endif
    PipeBarrier<PIPE_V>();
}
|
||||
|
||||
// Compatibility alias retained for existing call sites: forwards directly to
// ReduceSumFP32 with identical argument order and semantics.
__aicore__ inline void ReduceSumCustom(
    const LocalTensor<float>& dstLocal, const LocalTensor<float>& srcLocal, const LocalTensor<float>& workLocal,
    int32_t elemCount)
{
    ReduceSumFP32(dstLocal, srcLocal, workLocal, elemCount);
}
|
||||
/*
 * Like ReduceSumFP32, but the final step is BlockReduceSum, so dst_local
 * receives 8 per-block partial sums (one 32-byte block) instead of one
 * scalar. work_local must hold at least 64 floats; src_local is NOT modified.
 */
__aicore__ inline void ReduceSumFP32ToBlock(
    const LocalTensor<float>& dst_local, const LocalTensor<float>& src_local, const LocalTensor<float>& work_local,
    int32_t count)
{
    // count need smaller than 255 repeat
    uint64_t mask = NUM_PER_REP_FP32;
    int32_t repeatTimes = count / NUM_PER_REP_FP32;
    int32_t tailCount = count % NUM_PER_REP_FP32;
    int32_t bodyCount = repeatTimes * NUM_PER_REP_FP32;
    // src0 advances one 256-byte repeat per iteration; src1/dst stay pinned
    // on work_local so every repeat accumulates into the same 64 lanes.
    BinaryRepeatParams repeatParams;
    repeatParams.src0RepStride = ONCE_VECTOR_SIZE / BLOCK_SIZE;
    repeatParams.src0BlkStride = 1;
    repeatParams.src1RepStride = 0;
    repeatParams.src1BlkStride = 1;
    repeatParams.dstRepStride = 0;
    repeatParams.dstBlkStride = 1;
    // Zero the 64-element accumulator.
    Duplicate(work_local, ZERO, NUM_PER_REP_FP32);
    PipeBarrier<PIPE_V>();
    if (likely(repeatTimes > 0)) {
        // Accumulate all whole 64-element repeats.
        Add(work_local, src_local, work_local, mask, repeatTimes, repeatParams);
        PipeBarrier<PIPE_V>();
    }
    if (unlikely(tailCount != 0)) {
        // Fold the partial final repeat with a narrowed element count.
        Add(work_local, src_local[bodyCount], work_local, tailCount, 1, repeatParams);
        PipeBarrier<PIPE_V>();
    }
    // Reduce the 64-lane accumulator to 8 block sums in dst_local.
    BlockReduceSum(dst_local, work_local, 1, mask, 1, 1, DEFAULT_REPEAT_STRIDE);
    PipeBarrier<PIPE_V>();
}
|
||||
|
||||
/*
 * Block-wise reduction: every 8 consecutive fp32 elements of src_local are
 * summed into one element of dst_local, so `count` inputs yield count / 8
 * outputs. count must be a multiple of 8 (one 32-byte block).
 * Whole 64-element repeats are handled first; a shorter masked repeat
 * handles the remainder.
 */
__aicore__ inline void BlockReduceSumFP32(
    const LocalTensor<float>& dst_local, const LocalTensor<float>& src_local, int32_t count)
{
    // count need multiple of 8
    int32_t repeatTimes = count / NUM_PER_REP_FP32;
    int32_t tailCount = count % NUM_PER_REP_FP32;
    // Each full repeat consumes 64 source floats and emits 8 block sums,
    // so the tail starts NUM_PER_BLK_FP32 outputs per repeat into dst.
    // (Was a magic literal 8 — now the file's named constant.)
    int32_t dstAddr = repeatTimes * NUM_PER_BLK_FP32;
    int32_t srcAddr = repeatTimes * NUM_PER_REP_FP32;
    if (likely(repeatTimes > 0)) {
        // Reduce all whole 64-element repeats.
        BlockReduceSum(dst_local, src_local, repeatTimes, NUM_PER_REP_FP32, 1, 1, DEFAULT_REPEAT_STRIDE);
        PipeBarrier<PIPE_V>();
    }
    if (tailCount != 0) {
        // One final masked repeat covers the remaining (multiple-of-8) tail.
        BlockReduceSum(dst_local[dstAddr], src_local[srcAddr], 1, tailCount, 1, 1, DEFAULT_REPEAT_STRIDE);
        PipeBarrier<PIPE_V>();
    }
}
|
||||
|
||||
/*
 * 1-D copy of `count` elements of type T between global and local memory,
 * papering over arch differences in unaligned-copy support.
 * U/R are the destination/source tensor types (LocalTensor or GlobalTensor);
 * direction is inferred from U.
 * On arch 220 / 3003, DataCopyPad handles arbitrary byte lengths directly.
 * On older archs only 32-byte-granular DataCopy exists, so unaligned sizes
 * need the workaround below.
 */
template <typename T, typename U, typename R>
__aicore__ inline void DataCopyCustom(const U& dstTensor, const R& srcTensor, const uint32_t count)
{
#if (defined(__CCE_AICORE__) && __CCE_AICORE__ == 220) || (defined(__NPU_ARCH__) && __NPU_ARCH__ == 3003)
    // Pad-aware copy: exact byte length in a single burst.
    DataCopyParams copyParams;
    copyParams.blockLen = count * sizeof(T);
    copyParams.blockCount = 1;
    if constexpr (is_same<U, AscendC::LocalTensor<T>>::value) {
        // GM -> UB direction needs pad parameters for the trailing partial block.
        DataCopyPadParams padParams;
        DataCopyPad(dstTensor, srcTensor, copyParams, padParams);
    } else {
        DataCopyPad(dstTensor, srcTensor, copyParams);
    }
#else
    // only support count greater than 32byte
    int32_t numPerBlock = ONE_BLK_SIZE / sizeof(T);
    if (count % numPerBlock == 0) {
        // Aligned: plain block copy.
        DataCopy(dstTensor, srcTensor, count);
    } else {
        if constexpr (is_same<U, AscendC::LocalTensor<T>>::value) {
            // GM -> UB: over-read up to the next block boundary; the extra
            // elements land in UB padding.
            int32_t num = AlignUp(count, numPerBlock);
            DataCopy(dstTensor, srcTensor, num);
        } else {
            // UB -> GM: cannot over-write GM past `count`.
            if (count < numPerBlock) {
                // NOTE(review): sub-block output still writes a whole block
                // to GM — matches the "count greater than 32byte" caveat above.
                DataCopy(dstTensor, srcTensor, numPerBlock);
            } else {
                // Copy the aligned body, then rewrite the last full block
                // shifted so it ends exactly at `count`.
                int32_t num = count / numPerBlock * numPerBlock;
                DataCopy(dstTensor, srcTensor, num);
                // Wait for MTE3 to finish before the scalar unit rereads UB.
                SetFlag<HardEvent::MTE3_S>(EVENT_ID0);
                WaitFlag<HardEvent::MTE3_S>(EVENT_ID0);
                // Stage the final `numPerBlock` elements at the front of the
                // source buffer (clobbers srcTensor[0..numPerBlock)).
                for (int32_t i = 0; i < numPerBlock; i++) {
                    T tensorValue = srcTensor.GetValue(count - numPerBlock + i);
                    srcTensor.SetValue(i, tensorValue);
                }
                // Scalar writes must retire before MTE3 reads them back.
                SetFlag<HardEvent::S_MTE3>(EVENT_ID0);
                WaitFlag<HardEvent::S_MTE3>(EVENT_ID0);
                DataCopy(dstTensor[count - numPerBlock], srcTensor, numPerBlock);
            }
        }
    }
#endif
}
|
||||
|
||||
/*
 * 2-D copy-in: moves numRow rows of numCol elements each from global to
 * local memory with per-row padding handled by DataCopyPad.
 * NOTE(review): only implemented for __CCE_AICORE__ == 220 — on any other
 * arch this compiles to a silent no-op; confirm callers are 220-only.
 */
template <typename T>
__aicore__ inline void DataCopyCustom(
    const LocalTensor<T>& dstTensor, const GlobalTensor<T>& srcTensor, const uint32_t numRow, const uint32_t numCol)
{
#if defined(__CCE_AICORE__) && __CCE_AICORE__ == 220
    DataCopyParams copyParams;
    copyParams.blockLen = numCol * sizeof(T); // bytes per row
    copyParams.blockCount = numRow;           // one burst per row
    DataCopyPadParams padParams;
    DataCopyPad(dstTensor, srcTensor, copyParams, padParams);
#endif
}
|
||||
|
||||
/*
 * 2-D copy-out: moves numRow rows of numCol elements each from local to
 * global memory via DataCopyPad (exact byte length, no GM over-write).
 * NOTE(review): only implemented for __CCE_AICORE__ == 220 — on any other
 * arch this compiles to a silent no-op; confirm callers are 220-only.
 */
template <typename T>
__aicore__ inline void DataCopyCustom(
    const GlobalTensor<T>& dstTensor, const LocalTensor<T>& srcTensor, const uint32_t numRow, const uint32_t numCol)
{
#if defined(__CCE_AICORE__) && __CCE_AICORE__ == 220
    DataCopyParams copyParams;
    copyParams.blockLen = numCol * sizeof(T); // bytes per row
    copyParams.blockCount = numRow;           // one burst per row
    DataCopyPad(dstTensor, srcTensor, copyParams);
#endif
}
|
||||
|
||||
/*
 * Converts `size` fp32 values to int8 via a three-stage cast chain, since
 * there is no direct fp32 -> int8 vector cast:
 *   fp32 --rint--> int32 --(deq-scale 1.0)--> half --trunc--> int8.
 * srcTensor is clobbered (each stage reinterprets/overwrites it in place).
 * NOTE(review): relies on the int32 -> half dequant path being exact for the
 * int8 value range — presumably values already fit in [-128, 127]; confirm
 * at call sites.
 */
__aicore__ inline void RoundFloat2Int8(LocalTensor<int8_t>& dstTensor, LocalTensor<float>& srcTensor, int32_t size)
{
    // Round-to-nearest-even into int32, in place.
    Cast(srcTensor.ReinterpretCast<int32_t>(), srcTensor, RoundMode::CAST_RINT, size);
    PipeBarrier<PIPE_V>();
    // Identity dequant scale so the int32 -> half conversion is value-preserving.
    SetDeqScale((half)1.000000e+00f);
    PipeBarrier<PIPE_V>();
    Cast(srcTensor.ReinterpretCast<half>(), srcTensor.ReinterpretCast<int32_t>(), RoundMode::CAST_NONE, size);
    PipeBarrier<PIPE_V>();
    // Truncating half -> int8 finishes the conversion.
    Cast(dstTensor, srcTensor.ReinterpretCast<half>(), RoundMode::CAST_TRUNC, size);
}
|
||||
|
||||
// Aligns x upward to the next multiple of block_number; a zero block size
// yields 0 instead of dividing by zero.
__aicore__ inline uint32_t ROUND_UP(uint32_t x, uint32_t block_number)
{
    return (block_number > 0) ? ((x + block_number - 1) / block_number) * block_number : 0;
}
|
||||
} // namespace RmsNorm
|
||||
#endif // RMS_NORM_BASE_H_
|
||||
Reference in New Issue
Block a user