[feature] add_rms_norm support bias (#5790)

### What this PR does / why we need it?
This PR replaces the separate addRmsNorm and Add operations with a fused addRmsNormBias kernel, which is more efficient.
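
For reference, the fused kernel performs the residual add, the RMS normalization, the gamma scale, and the optional beta bias in one pass, and also writes out the added residual and the per-row rstd. Below is a minimal host-side sketch of the per-row semantics; the function name and signature are illustrative only and are not part of this PR:

```cpp
#include <cmath>
#include <cstddef>
#include <vector>

// Illustrative reference only; not part of the kernel sources in this PR.
// Per-row semantics of AddRmsNormBias for a row of length d:
//   x    = x1 + x2
//   rstd = 1 / sqrt(mean(x * x) + epsilon)
//   y    = x * rstd * gamma            (+ beta, when beta is provided)
void add_rms_norm_bias_ref(const std::vector<float>& x1, const std::vector<float>& x2,
                           const std::vector<float>& gamma, const float* beta, float epsilon,
                           std::vector<float>& y, std::vector<float>& x, float& rstd)
{
    const std::size_t d = x1.size();
    float sum_sq = 0.0f;
    for (std::size_t i = 0; i < d; ++i) {
        x[i] = x1[i] + x2[i];              // residual add, also written to the x output
        sum_sq += x[i] * x[i];
    }
    rstd = 1.0f / std::sqrt(sum_sq / d + epsilon);
    for (std::size_t i = 0; i < d; ++i) {
        y[i] = x[i] * rstd * gamma[i];
        if (beta != nullptr) {             // bias is optional (nullptr_beta in the tiling data)
            y[i] += beta[i];
        }
    }
}
```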

### Does this PR introduce _any_ user-facing change?
No

### How was this patch tested?
Full Test Pass

- vLLM version: v0.13.0
- vLLM main: 2f4e6548ef

Signed-off-by: Chen_HaoWen <chenhaowen12@huawei.com>
Co-authored-by: Chen_HaoWen <chenhaowen12@huawei.com>
yjmyl
2026-01-23 21:09:54 +08:00
committed by GitHub
parent 6c73b88dd6
commit e90b14140b
24 changed files with 3537 additions and 13 deletions


@@ -0,0 +1,72 @@
/**
* This program is free software, you can redistribute it and/or modify.
* Copyright (c) 2025 Huawei Technologies Co., Ltd.
* This file is a part of the CANN Open Software.
* Licensed under CANN Open Software License Agreement Version 2.0 (the "License").
* Please refer to the License for details. You may not use this file except in compliance with the License.
* THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE.
* See LICENSE in the root of the software repository for the full text of the License.
*/
/*!
* \file add_rms_norm_bias.cpp
* \brief
*/
#include "add_rms_norm_bias.h"
#include "add_rms_norm_bias_split_d.h"
#include "add_rms_norm_bias_merge_n.h"
#include "add_rms_norm_bias_multi_n.h"
#include "add_rms_norm_bias_single_n.h"
using namespace AscendC;
#define GENERAL_OP_IMPL(templateClass, ...) \
do { \
templateClass<__VA_ARGS__> op(&pipe); \
op.Init(x1, x2, gamma, beta, y, rstd, x, &tilingData); \
op.Process(); \
} while (0)
extern "C" __global__ __aicore__ void add_rms_norm_bias(
GM_ADDR x1, GM_ADDR x2, GM_ADDR gamma, GM_ADDR beta, GM_ADDR y, GM_ADDR rstd, GM_ADDR x, GM_ADDR workspace, GM_ADDR tiling)
{
TPipe pipe;
GET_TILING_DATA(tilingData, tiling);
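// Tiling key encoding: tens digit selects the dtype (1 = half, 2 = float, 3 = bfloat16),
// units digit selects the kernel variant (0 = base, 1 = split-D, 2 = merge-N, 3 = single-N, 4 = multi-N).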
if (TILING_KEY_IS(10)) {
GENERAL_OP_IMPL(KernelAddRmsNormBias, half);
} else if (TILING_KEY_IS(20)) {
GENERAL_OP_IMPL(KernelAddRmsNormBias, float);
} else if (TILING_KEY_IS(30)) {
#if !(defined(__NPU_ARCH__) && __NPU_ARCH__ == 3003)
GENERAL_OP_IMPL(KernelAddRmsNormBias, bfloat16_t);
#endif
} else if (TILING_KEY_IS(11)) {
GENERAL_OP_IMPL(KernelAddRmsNormBiasSplitD, half);
} else if (TILING_KEY_IS(21)) {
GENERAL_OP_IMPL(KernelAddRmsNormBiasSplitD, float);
} else if (TILING_KEY_IS(31)) {
#if !(defined(__NPU_ARCH__) && __NPU_ARCH__ == 3003)
GENERAL_OP_IMPL(KernelAddRmsNormBiasSplitD, bfloat16_t);
#endif
} else if (TILING_KEY_IS(12)) {
GENERAL_OP_IMPL(KernelAddRmsNormBiasMergeN, half);
} else if (TILING_KEY_IS(22)) {
GENERAL_OP_IMPL(KernelAddRmsNormBiasMergeN, float);
} else if (TILING_KEY_IS(32)) {
#if !(defined(__NPU_ARCH__) && __NPU_ARCH__ == 3003)
GENERAL_OP_IMPL(KernelAddRmsNormBiasMergeN, bfloat16_t);
#endif
} else if (TILING_KEY_IS(13)) {
GENERAL_OP_IMPL(KernelAddRmsNormBiasSingleN, half);
} else if (TILING_KEY_IS(23)) {
GENERAL_OP_IMPL(KernelAddRmsNormBiasSingleN, float);
} else if (TILING_KEY_IS(33)) {
#if !(defined(__NPU_ARCH__) && __NPU_ARCH__ == 3003)
GENERAL_OP_IMPL(KernelAddRmsNormBiasSingleN, bfloat16_t);
#endif
} else if (TILING_KEY_IS(14)) {
GENERAL_OP_IMPL(KernelAddRmsNormBiasMultiN, half);
} else if (TILING_KEY_IS(34)) {
GENERAL_OP_IMPL(KernelAddRmsNormBiasMultiN, bfloat16_t);
}
}


@@ -0,0 +1,368 @@
/**
* This program is free software, you can redistribute it and/or modify.
* Copyright (c) 2025 Huawei Technologies Co., Ltd.
* This file is a part of the CANN Open Software.
* Licensed under CANN Open Software License Agreement Version 2.0 (the "License").
* Please refer to the License for details. You may not use this file except in compliance with the License.
* THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE.
* See LICENSE in the root of the software repository for the full text of the License.
*/
/*!
* \file add_rms_norm_bias.h
* \brief add rms norm bias file
*/
#ifndef ADD_RMS_NORM_BIAS_H_
#define ADD_RMS_NORM_BIAS_H_
#include "./rms_norm_base.h"
using namespace AscendC;
using namespace RmsNorm;
template <typename T>
class KernelAddRmsNormBias {
public:
__aicore__ inline KernelAddRmsNormBias(TPipe* pipe)
{
Ppipe = pipe;
}
__aicore__ inline void Init(
GM_ADDR x1, GM_ADDR x2, GM_ADDR gamma, GM_ADDR beta, GM_ADDR y, GM_ADDR rstd, GM_ADDR x, const AddRMSNormBiasTilingData* tiling)
{
ASSERT(GetBlockNum() != 0 && "Block dim can not be zero!");
this->numRow = tiling->num_row;
this->numCol = tiling->num_col;
this->blockFactor = tiling->block_factor;
this->rowFactor = tiling->row_factor;
this->ubFactor = tiling->ub_factor;
this->epsilon = tiling->epsilon;
this->avgFactor = (numCol != 0) ? (float)1.0 / numCol : 0;
this->nullptrBeta = tiling->nullptr_beta;
blockIdx_ = GetBlockIdx();
if (blockIdx_ < GetBlockNum() - 1) {
this->rowWork = blockFactor;
} else if (blockIdx_ == GetBlockNum() - 1) {
this->rowWork = numRow - (GetBlockNum() - 1) * blockFactor;
}
// get start index for current core, core parallel
x1Gm.SetGlobalBuffer((__gm__ T*)x1 + blockIdx_ * blockFactor * numCol, rowWork * numCol);
x2Gm.SetGlobalBuffer((__gm__ T*)x2 + blockIdx_ * blockFactor * numCol, rowWork * numCol);
gammaGm.SetGlobalBuffer((__gm__ T*)gamma, numCol);
if (!this->nullptrBeta) {
betaGm.SetGlobalBuffer((__gm__ T*)beta, numCol);
}
yGm.SetGlobalBuffer((__gm__ T*)y + blockIdx_ * blockFactor * numCol, rowWork * numCol);
rstdGm.SetGlobalBuffer((__gm__ float*)rstd + blockIdx_ * blockFactor, blockFactor);
xGm.SetGlobalBuffer((__gm__ T*)x + blockIdx_ * blockFactor * numCol, rowWork * numCol);
// pipe alloc memory to queue, the unit is Bytes
Ppipe->InitBuffer(inQueueX, BUFFER_NUM, ubFactor * sizeof(T));
Ppipe->InitBuffer(inQueueGamma, BUFFER_NUM, ubFactor * sizeof(T));
if (!this->nullptrBeta) {
Ppipe->InitBuffer(inQueueBeta, BUFFER_NUM, ubFactor * sizeof(T));
}
Ppipe->InitBuffer(outQueueY, BUFFER_NUM, ubFactor * sizeof(T));
Ppipe->InitBuffer(outQueueRstd, BUFFER_NUM, rowFactor * sizeof(float));
if constexpr (is_same<T, half>::value || is_same<T, bfloat16_t>::value) {
Ppipe->InitBuffer(xFp32Buf, ubFactor * sizeof(float));
}
Ppipe->InitBuffer(sqxBuf, ubFactor * sizeof(float));
Ppipe->InitBuffer(reduceFp32Buf, NUM_PER_REP_FP32 * sizeof(float));
}
__aicore__ inline void Process()
{
CopyInGammaBeta();
LocalTensor<T> gammaLocal = inQueueGamma.DeQue<T>();
LocalTensor<T> betaLocal;
if (!this->nullptrBeta) {
betaLocal = inQueueBeta.DeQue<T>();
}
uint32_t i_o_max = RmsNorm::CeilDiv(rowWork, rowFactor);
uint32_t row_tail = rowWork - (i_o_max - 1) * rowFactor;
for (uint32_t i_o = 0; i_o < i_o_max - 1; i_o++) {
SubProcess(i_o, rowFactor, gammaLocal, betaLocal);
}
SubProcess(i_o_max - 1, row_tail, gammaLocal, betaLocal);
inQueueGamma.FreeTensor(gammaLocal);
if (!this->nullptrBeta) {
inQueueBeta.FreeTensor(betaLocal);
}
}
__aicore__ inline void SubProcess(uint32_t i_o, uint32_t calc_row_num, LocalTensor<T>& gammaLocal, LocalTensor<T>& betaLocal)
{
LocalTensor<float> rstdLocal = outQueueRstd.AllocTensor<float>();
for (uint32_t i_i = 0; i_i < calc_row_num; i_i++) {
uint32_t gm_bias = (i_o * rowFactor + i_i) * numCol;
CopyIn(gm_bias);
Compute(i_i, gammaLocal, betaLocal, rstdLocal);
CopyOutY(gm_bias);
}
outQueueRstd.EnQue<float>(rstdLocal);
CopyOutRstd(i_o, calc_row_num);
}
private:
__aicore__ inline void CopyIn(uint32_t gm_bias)
{
LocalTensor<T> x1Local_in = inQueueX.AllocTensor<T>();
LocalTensor<T> x2Local = sqxBuf.Get<T>();
LocalTensor<T> xLocal = outQueueY.AllocTensor<T>();
if constexpr (is_same<T, half>::value || is_same<T, bfloat16_t>::value) {
x2Local = x2Local[ubFactor];
}
DataCopyCustom<T>(x1Local_in, x1Gm[gm_bias], numCol);
DataCopyCustom<T>(x2Local, x2Gm[gm_bias], numCol);
inQueueX.EnQue(x1Local_in);
auto x1Local = inQueueX.DeQue<T>();
if constexpr (is_same<T, half>::value) {
LocalTensor<float> x1_fp32 = xFp32Buf.Get<float>();
Add(xLocal, x1Local, x2Local, numCol);
PipeBarrier<PIPE_V>();
Cast(x1_fp32, xLocal, RoundMode::CAST_NONE, numCol);
PipeBarrier<PIPE_V>();
} else if constexpr (is_same<T, bfloat16_t>::value) {
LocalTensor<float> x1_fp32 = xFp32Buf.Get<float>();
LocalTensor<float> x2_fp32 = sqxBuf.Get<float>();
Cast(x1_fp32, x1Local, RoundMode::CAST_NONE, numCol);
Cast(x2_fp32, x2Local, RoundMode::CAST_NONE, numCol);
PipeBarrier<PIPE_V>();
Add(x1_fp32, x1_fp32, x2_fp32, numCol);
PipeBarrier<PIPE_V>();
Cast(xLocal, x1_fp32, RoundMode::CAST_RINT, numCol);
PipeBarrier<PIPE_V>();
} else {
Add(x1Local, x1Local, x2Local, numCol);
PipeBarrier<PIPE_V>();
Adds(xLocal, x1Local, (float)0, numCol);
}
inQueueX.FreeTensor(x1Local);
// CopyOut x1 + x2
outQueueY.EnQue(xLocal);
auto x_out = outQueueY.DeQue<T>();
DataCopyCustom<T>(xGm[gm_bias], x_out, numCol);
outQueueY.FreeTensor(x_out);
}
__aicore__ inline void CopyInGammaBeta()
{
LocalTensor<T> gammaLocal = inQueueGamma.AllocTensor<T>();
DataCopyCustom<T>(gammaLocal, gammaGm, numCol);
inQueueGamma.EnQue(gammaLocal);
if (!this->nullptrBeta) {
LocalTensor<T> betaLocal = inQueueBeta.AllocTensor<T>();
DataCopyCustom<T>(betaLocal, betaGm, numCol);
inQueueBeta.EnQue(betaLocal);
}
}
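// Compute overloads below: the float path normalizes directly in fp32, while the half and
// bfloat16 paths square-and-reduce in fp32 (via xFp32Buf) and cast the normalized result back to T.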
__aicore__ inline void Compute(uint32_t inner_progress, LocalTensor<float> gammaLocal, LocalTensor<float> betaLocal, LocalTensor<float> rstdLocal)
{
LocalTensor<float> xLocal = inQueueX.AllocTensor<float>();
LocalTensor<float> sqx = sqxBuf.Get<float>();
LocalTensor<float> reduce_buf_local = reduceFp32Buf.Get<float>();
Mul(sqx, xLocal, xLocal, numCol);
PipeBarrier<PIPE_V>();
Muls(sqx, sqx, avgFactor, numCol);
PipeBarrier<PIPE_V>();
ReduceSumCustom(sqx, sqx, reduce_buf_local, numCol);
PipeBarrier<PIPE_V>();
Adds(sqx, sqx, epsilon, 1);
PipeBarrier<PIPE_V>();
Sqrt(sqx, sqx, 1);
Duplicate(reduce_buf_local, ONE, 1);
PipeBarrier<PIPE_V>();
Div(sqx, reduce_buf_local, sqx, 1);
PipeBarrier<PIPE_V>();
event_t event_v_s = static_cast<event_t>(GetTPipePtr()->FetchEventID(HardEvent::V_S));
SetFlag<HardEvent::V_S>(event_v_s);
WaitFlag<HardEvent::V_S>(event_v_s);
float rstdValue = sqx.GetValue(0);
event_t event_s_v = static_cast<event_t>(GetTPipePtr()->FetchEventID(HardEvent::S_V));
SetFlag<HardEvent::S_V>(event_s_v);
WaitFlag<HardEvent::S_V>(event_s_v);
rstdLocal.SetValue(inner_progress, rstdValue);
PipeBarrier<PIPE_V>();
LocalTensor<float> yLocal = outQueueY.AllocTensor<float>();
Muls(yLocal, xLocal, rstdValue, numCol);
inQueueX.FreeTensor(xLocal);
PipeBarrier<PIPE_V>();
Mul(yLocal, gammaLocal, yLocal, numCol);
if (!this->nullptrBeta) {
PipeBarrier<PIPE_V>();
Add(yLocal, betaLocal, yLocal, numCol);
}
PipeBarrier<PIPE_V>();
outQueueY.EnQue<float>(yLocal);
}
__aicore__ inline void Compute(
uint32_t inner_progress, LocalTensor<bfloat16_t> gammaLocal, LocalTensor<bfloat16_t> betaLocal, LocalTensor<float> rstdLocal)
{
LocalTensor<float> x_fp32 = xFp32Buf.Get<float>();
LocalTensor<float> sqx = sqxBuf.Get<float>();
LocalTensor<float> reduce_buf_local = reduceFp32Buf.Get<float>();
Mul(sqx, x_fp32, x_fp32, numCol);
PipeBarrier<PIPE_V>();
Muls(sqx, sqx, avgFactor, numCol);
PipeBarrier<PIPE_V>();
ReduceSumCustom(sqx, sqx, reduce_buf_local, numCol);
PipeBarrier<PIPE_V>();
Adds(sqx, sqx, epsilon, 1);
PipeBarrier<PIPE_V>();
Sqrt(sqx, sqx, 1);
Duplicate(reduce_buf_local, ONE, 1);
PipeBarrier<PIPE_V>();
Div(sqx, reduce_buf_local, sqx, 1);
PipeBarrier<PIPE_V>();
event_t event_v_s = static_cast<event_t>(GetTPipePtr()->FetchEventID(HardEvent::V_S));
SetFlag<HardEvent::V_S>(event_v_s);
WaitFlag<HardEvent::V_S>(event_v_s);
float rstdValue = sqx.GetValue(0);
event_t event_s_v = static_cast<event_t>(GetTPipePtr()->FetchEventID(HardEvent::S_V));
SetFlag<HardEvent::S_V>(event_s_v);
WaitFlag<HardEvent::S_V>(event_s_v);
rstdLocal.SetValue(inner_progress, rstdValue);
PipeBarrier<PIPE_V>();
Muls(x_fp32, x_fp32, rstdValue, numCol);
PipeBarrier<PIPE_V>();
LocalTensor<bfloat16_t> yLocal = outQueueY.AllocTensor<bfloat16_t>();
Cast(yLocal, x_fp32, RoundMode::CAST_RINT, numCol);
PipeBarrier<PIPE_V>();
Cast(x_fp32, yLocal, RoundMode::CAST_NONE, numCol);
PipeBarrier<PIPE_V>();
Cast(sqx, gammaLocal, RoundMode::CAST_NONE, numCol); // gamma_fp32 reuse sqx
PipeBarrier<PIPE_V>();
Mul(x_fp32, x_fp32, sqx, numCol);
if (!this->nullptrBeta) {
PipeBarrier<PIPE_V>();
Cast(sqx, betaLocal, RoundMode::CAST_NONE, numCol);
PipeBarrier<PIPE_V>();
Add(x_fp32, x_fp32, sqx, numCol);
}
PipeBarrier<PIPE_V>();
Cast(yLocal, x_fp32, RoundMode::CAST_RINT, numCol);
PipeBarrier<PIPE_V>();
event_t event_v_mte = static_cast<event_t>(GetTPipePtr()->FetchEventID(HardEvent::V_MTE2));
SetFlag<HardEvent::V_MTE2>(event_v_mte);
WaitFlag<HardEvent::V_MTE2>(event_v_mte);
outQueueY.EnQue<bfloat16_t>(yLocal);
}
__aicore__ inline void Compute(uint32_t inner_progress, LocalTensor<half> gammaLocal, LocalTensor<half> betaLocal, LocalTensor<float> rstdLocal)
{
LocalTensor<float> x_fp32 = xFp32Buf.Get<float>();
LocalTensor<float> sqx = sqxBuf.Get<float>();
LocalTensor<float> reduce_buf_local = reduceFp32Buf.Get<float>();
Mul(sqx, x_fp32, x_fp32, numCol);
PipeBarrier<PIPE_V>();
Muls(sqx, sqx, avgFactor, numCol);
PipeBarrier<PIPE_V>();
ReduceSumCustom(sqx, sqx, reduce_buf_local, numCol);
PipeBarrier<PIPE_V>();
Adds(sqx, sqx, epsilon, 1);
PipeBarrier<PIPE_V>();
Sqrt(sqx, sqx, 1);
Duplicate(reduce_buf_local, ONE, 1);
PipeBarrier<PIPE_V>();
Div(sqx, reduce_buf_local, sqx, 1);
PipeBarrier<PIPE_V>();
event_t event_v_s = static_cast<event_t>(GetTPipePtr()->FetchEventID(HardEvent::V_S));
SetFlag<HardEvent::V_S>(event_v_s);
WaitFlag<HardEvent::V_S>(event_v_s);
float rstdValue = sqx.GetValue(0);
event_t event_s_v = static_cast<event_t>(GetTPipePtr()->FetchEventID(HardEvent::S_V));
SetFlag<HardEvent::S_V>(event_s_v);
WaitFlag<HardEvent::S_V>(event_s_v);
rstdLocal.SetValue(inner_progress, rstdValue);
PipeBarrier<PIPE_V>();
Muls(x_fp32, x_fp32, rstdValue, numCol);
PipeBarrier<PIPE_V>();
LocalTensor<half> yLocal = outQueueY.AllocTensor<half>();
Cast(yLocal, x_fp32, RoundMode::CAST_NONE, numCol);
event_t event_v_mte = static_cast<event_t>(GetTPipePtr()->FetchEventID(HardEvent::V_MTE2));
SetFlag<HardEvent::V_MTE2>(event_v_mte);
WaitFlag<HardEvent::V_MTE2>(event_v_mte);
PipeBarrier<PIPE_V>();
Mul(yLocal, gammaLocal, yLocal, numCol);
if (!this->nullptrBeta) {
PipeBarrier<PIPE_V>();
Add(yLocal, betaLocal, yLocal, numCol);
}
PipeBarrier<PIPE_V>();
outQueueY.EnQue<half>(yLocal);
}
__aicore__ inline void CopyOutY(uint32_t progress)
{
LocalTensor<T> yLocal = outQueueY.DeQue<T>();
DataCopyCustom<T>(yGm[progress], yLocal, numCol);
outQueueY.FreeTensor(yLocal);
}
__aicore__ inline void CopyOutRstd(uint32_t outer_progress, uint32_t num)
{
LocalTensor<float> rstdLocal = outQueueRstd.DeQue<float>();
#if __CCE_AICORE__ == 220 || (defined(__NPU_ARCH__) && __NPU_ARCH__ == 3003)
DataCopyCustom<float>(rstdGm[outer_progress * rowFactor], rstdLocal, num);
#endif
outQueueRstd.FreeTensor(rstdLocal);
}
private:
TPipe* Ppipe = nullptr;
// create queues for input, in this case depth is equal to buffer num
TQue<QuePosition::VECIN, BUFFER_NUM> inQueueX;
TQue<QuePosition::VECIN, BUFFER_NUM> inQueueGamma;
TQue<QuePosition::VECIN, BUFFER_NUM> inQueueBeta;
// create queues for output, in this case depth is equal to buffer num
TQue<QuePosition::VECOUT, BUFFER_NUM> outQueueY;
TQue<QuePosition::VECOUT, BUFFER_NUM> outQueueRstd;
TBuf<TPosition::VECCALC> xFp32Buf;
TBuf<TPosition::VECCALC> sqxBuf;
TBuf<TPosition::VECCALC> reduceFp32Buf;
GlobalTensor<T> x1Gm;
GlobalTensor<T> x2Gm;
GlobalTensor<T> gammaGm;
GlobalTensor<T> betaGm;
GlobalTensor<T> yGm;
GlobalTensor<float> rstdGm;
GlobalTensor<T> xGm;
uint32_t numRow;
uint32_t numCol;
uint32_t blockFactor; // number of rows processed by each core
uint32_t rowFactor;
uint32_t ubFactor;
float epsilon;
float avgFactor;
int32_t blockIdx_;
uint32_t rowWork = 1;
uint32_t nullptrBeta = 0;
};
#endif // ADD_RMS_NORM_BIAS_H_


@@ -0,0 +1,471 @@
/**
* This program is free software, you can redistribute it and/or modify.
* Copyright (c) 2025 Huawei Technologies Co., Ltd.
* This file is a part of the CANN Open Software.
* Licensed under CANN Open Software License Agreement Version 2.0 (the "License").
* Please refer to the License for details. You may not use this file except in compliance with the License.
* THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE.
* See LICENSE in the root of the software repository for the full text of the License.
*/
/*!
* \file add_rms_norm_bias_merge_n.h
* \brief add rms norm bias merge n file
*/
#ifndef ADD_RMS_NORM_BIAS_MERGE_N_H_
#define ADD_RMS_NORM_BIAS_MERGE_N_H_
#include "./rms_norm_base.h"
using namespace AscendC;
using namespace RmsNorm;
template <typename T>
class KernelAddRmsNormBiasMergeN {
public:
__aicore__ inline KernelAddRmsNormBiasMergeN(TPipe* pipe)
{
Ppipe = pipe;
}
__aicore__ inline void Init(
GM_ADDR x1, GM_ADDR x2, GM_ADDR gamma, GM_ADDR beta, GM_ADDR y, GM_ADDR rstd, GM_ADDR x, const AddRMSNormBiasTilingData* tiling)
{
ASSERT(GetBlockNum() != 0 && "Block dim can not be zero!");
this->numRow = tiling->num_row;
this->numCol = tiling->num_col;
this->numColAlign = tiling->num_col_align;
this->blockFactor = tiling->block_factor;
this->rowFactor = tiling->row_factor;
this->ubFactor = tiling->ub_factor;
this->epsilon = tiling->epsilon;
this->avgFactor = tiling->avg_factor;
blockIdx_ = GetBlockIdx();
if (blockIdx_ < GetBlockNum() - 1) {
this->rowWork = blockFactor;
this->rowLoop = tiling->row_loop;
this->rowTail = tiling->row_tail;
} else if (blockIdx_ == GetBlockNum() - 1) {
this->rowWork = tiling->last_block_factor;
this->rowLoop = tiling->last_block_row_loop;
this->rowTail = tiling->last_block_row_tail;
}
this->mulLoopFp32 = tiling->mul_loop_fp32;
this->mulTailFp32 = tiling->mul_tail_fp32;
this->dstRepStrideFp32 = tiling->dst_rep_stride_fp32;
this->mulLoopFp16 = tiling->mul_loop_fp16;
this->mulTailFp16 = tiling->mul_tail_fp16;
this->dstRepStrideFp16 = tiling->dst_rep_stride_fp16;
this->isPerformance = tiling->is_performance;
this->nullptrBeta = tiling->nullptr_beta;
// get start index for current core, core parallel
x1Gm.SetGlobalBuffer((__gm__ T*)x1 + blockIdx_ * blockFactor * numCol, rowWork * numCol);
x2Gm.SetGlobalBuffer((__gm__ T*)x2 + blockIdx_ * blockFactor * numCol, rowWork * numCol);
gammaGm.SetGlobalBuffer((__gm__ T*)gamma, numCol);
if (!this->nullptrBeta) {
betaGm.SetGlobalBuffer((__gm__ T*)beta, numCol);
}
yGm.SetGlobalBuffer((__gm__ T*)y + blockIdx_ * blockFactor * numCol, rowWork * numCol);
rstdGm.SetGlobalBuffer((__gm__ float*)rstd + blockIdx_ * blockFactor, blockFactor);
xGm.SetGlobalBuffer((__gm__ T*)x + blockIdx_ * blockFactor * numCol, rowWork * numCol);
// pipe alloc memory to queue, the unit is Bytes
Ppipe->InitBuffer(inQueueX, DOUBLE_BUFFER_NUM, ubFactor * sizeof(T));
Ppipe->InitBuffer(inQueueGamma, BUFFER_NUM, ubFactor * sizeof(T));
if (!this->nullptrBeta) {
Ppipe->InitBuffer(inQueueBeta, BUFFER_NUM, ubFactor * sizeof(T));
}
Ppipe->InitBuffer(outQueueY, DOUBLE_BUFFER_NUM, ubFactor * sizeof(T));
#if defined(__CCE_AICORE__) && __CCE_AICORE__ == 220 || (defined(__NPU_ARCH__) && __NPU_ARCH__ == 3003)
Ppipe->InitBuffer(outQueueRstd, BUFFER_NUM, rowFactor * sizeof(float));
#else
Ppipe->InitBuffer(rstdBuf, rowFactor * sizeof(float));
#endif
if constexpr (is_same<T, half>::value || is_same<T, bfloat16_t>::value) {
Ppipe->InitBuffer(xFp32Buf, ubFactor * sizeof(float));
}
Ppipe->InitBuffer(sqxBuf, ubFactor * sizeof(float));
Ppipe->InitBuffer(tmpBuf, rowFactor * NUM_PER_REP_FP32 * sizeof(float));
}
__aicore__ inline void Process()
{
CopyInGammaBeta();
LocalTensor<T> gammaLocal = inQueueGamma.DeQue<T>();
LocalTensor<T> betaLocal;
if (!this->nullptrBeta) {
betaLocal = inQueueBeta.DeQue<T>();
}
for (uint32_t i_o = 0; i_o < rowLoop - 1; i_o++) {
MainCompute(i_o, rowFactor, gammaLocal, betaLocal);
}
MainCompute(rowLoop - 1, rowTail, gammaLocal, betaLocal);
inQueueGamma.FreeTensor(gammaLocal);
if (!this->nullptrBeta) {
inQueueBeta.FreeTensor(betaLocal);
}
}
__aicore__ inline void MainCompute(uint32_t i_o, uint32_t calc_row_num, LocalTensor<T>& gammaLocal, LocalTensor<T>& betaLocal)
{
uint32_t gm_bias = i_o * rowFactor * numCol;
uint32_t elementNum = calc_row_num * numColAlign;
CopyInX(gm_bias, calc_row_num);
LocalTensor<T> xLocal = ComputeX(elementNum);
CopyOutX(gm_bias, calc_row_num);
#if defined(__CCE_AICORE__) && __CCE_AICORE__ == 220 || (defined(__NPU_ARCH__) && __NPU_ARCH__ == 3003)
LocalTensor<float> rstdLocal = outQueueRstd.AllocTensor<float>();
ComputeRstd(xLocal, rstdLocal, calc_row_num, elementNum);
outQueueRstd.EnQue<float>(rstdLocal);
CopyOutRstd(i_o, calc_row_num);
#else
LocalTensor<float> rstdLocal = rstdBuf.Get<float>();
ComputeRstd(xLocal, rstdLocal, calc_row_num, elementNum);
#endif
ComputeY(xLocal, gammaLocal, betaLocal, rstdLocal, calc_row_num, elementNum);
CopyOutY(gm_bias, calc_row_num);
}
private:
__aicore__ inline void CopyInX(uint32_t gm_bias, uint32_t calc_row_num)
{
LocalTensor<T> x1Local = inQueueX.AllocTensor<T>();
if (isNumColAlign) {
DataCopyCustom<T>(x1Local, x1Gm[gm_bias], calc_row_num * numCol);
} else {
DataCopyCustom<T>(x1Local, x1Gm[gm_bias], calc_row_num, numCol);
}
inQueueX.EnQue(x1Local);
LocalTensor<T> x2Local = inQueueX.AllocTensor<T>();
if (isNumColAlign) {
DataCopyCustom<T>(x2Local, x2Gm[gm_bias], calc_row_num * numCol);
} else {
DataCopyCustom<T>(x2Local, x2Gm[gm_bias], calc_row_num, numCol);
}
inQueueX.EnQue(x2Local);
}
__aicore__ inline LocalTensor<T> ComputeX(uint32_t elementNum)
{
LocalTensor<T> x1Local = inQueueX.DeQue<T>();
LocalTensor<T> x2Local = inQueueX.DeQue<T>();
LocalTensor<T> xLocal = outQueueY.AllocTensor<T>();
if constexpr (!is_same<T, bfloat16_t>::value) {
Add(xLocal, x1Local, x2Local, elementNum);
} else {
LocalTensor<float> x1Fp32 = xFp32Buf.Get<float>();
LocalTensor<float> x2Fp32 = sqxBuf.Get<float>();
Cast(x1Fp32, x1Local, RoundMode::CAST_NONE, elementNum);
Cast(x2Fp32, x2Local, RoundMode::CAST_NONE, elementNum);
PipeBarrier<PIPE_V>();
Add(x1Fp32, x1Fp32, x2Fp32, elementNum);
PipeBarrier<PIPE_V>();
Cast(xLocal, x1Fp32, RoundMode::CAST_RINT, elementNum);
}
inQueueX.FreeTensor(x1Local);
inQueueX.FreeTensor(x2Local);
outQueueY.EnQue(xLocal);
PipeBarrier<PIPE_V>();
return xLocal;
}
__aicore__ inline void CopyOutX(uint32_t gm_bias, uint32_t calc_row_num)
{
// CopyOut x1 + x2
auto xOut = outQueueY.DeQue<T>();
if (isNumColAlign) {
DataCopyCustom<T>(xGm[gm_bias], xOut, calc_row_num * numCol);
} else {
DataCopyCustom<T>(xGm[gm_bias], xOut, calc_row_num, numCol);
}
outQueueY.FreeTensor(xOut);
}
__aicore__ inline void CopyInGammaBeta()
{
LocalTensor<T> gammaLocal = inQueueGamma.AllocTensor<T>();
DataCopyCustom<T>(gammaLocal, gammaGm, numCol);
inQueueGamma.EnQue(gammaLocal);
if (!this->nullptrBeta) {
LocalTensor<T> betaLocal = inQueueBeta.AllocTensor<T>();
DataCopyCustom<T>(betaLocal, betaGm, numCol);
inQueueBeta.EnQue(betaLocal);
}
}
__aicore__ inline void ComputeRstd(LocalTensor<T> xLocal, LocalTensor<float> rstdLocal, uint32_t calc_row_num, uint32_t elementNum)
{
LocalTensor<float> sqx = sqxBuf.Get<float>();
LocalTensor<float> tmpLocal = tmpBuf.Get<float>();
if constexpr (!is_same<T, float>::value) {
LocalTensor<float> x_fp32 = xFp32Buf.Get<float>();
Cast(x_fp32, xLocal, RoundMode::CAST_NONE, elementNum);
PipeBarrier<PIPE_V>();
Mul(sqx, x_fp32, x_fp32, elementNum);
} else {
Mul(sqx, xLocal, xLocal, elementNum);
}
PipeBarrier<PIPE_V>();
Muls(sqx, sqx, avgFactor, elementNum);
PipeBarrier<PIPE_V>();
ReduceSumMultiN(rstdLocal, sqx, tmpLocal, calc_row_num, numCol, numColAlign);
PipeBarrier<PIPE_V>();
Adds(rstdLocal, rstdLocal, epsilon, calc_row_num);
PipeBarrier<PIPE_V>();
Sqrt(rstdLocal, rstdLocal, calc_row_num);
Duplicate(tmpLocal, ONE, calc_row_num);
PipeBarrier<PIPE_V>();
Div(rstdLocal, tmpLocal, rstdLocal, calc_row_num);
PipeBarrier<PIPE_V>();
}
__aicore__ inline void ComputeY(
LocalTensor<T> xLocal, LocalTensor<T> gammaLocal, LocalTensor<T> betaLocal, LocalTensor<float> rstdLocal, uint32_t calc_row_num, uint32_t elementNum)
{
LocalTensor<float> tmpLocal = tmpBuf.Get<float>();
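// Brcb expands each per-row rstd scalar into an 8-element block so the block-stride-0 Mul in
// repeatByRow (ONE_UINT) can broadcast it across the whole row.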
uint32_t splidRow = 240;
uint32_t rowRepeatLoop1 = calc_row_num / splidRow;
uint32_t rowRepeatTail1 = calc_row_num - rowRepeatLoop1 * splidRow;
for(uint32_t r_i = 0; r_i < rowRepeatLoop1; r_i ++) {
Brcb(tmpLocal[r_i * splidRow * MOV_8], rstdLocal[r_i * splidRow], splidRow, {1, 8});
}
PipeBarrier<PIPE_V>();
if(rowRepeatTail1 > 0) {
Brcb(tmpLocal[rowRepeatLoop1 * splidRow * MOV_8], rstdLocal[rowRepeatLoop1 * splidRow], rowRepeatTail1, {1, 8});
PipeBarrier<PIPE_V>();
}
LocalTensor<T> yLocal = outQueueY.AllocTensor<T>();
if constexpr (!is_same<T, float>::value) {
LocalTensor<float> x_fp32 = xFp32Buf.Get<float>();
repeatByRow<float>(x_fp32, x_fp32, tmpLocal, calc_row_num, ONE_UINT);
if constexpr (is_same<T, half>::value) {
Cast(yLocal, x_fp32, RoundMode::CAST_NONE, elementNum);
} else {
Cast(yLocal, x_fp32, RoundMode::CAST_RINT, elementNum);
}
} else {
repeatByRow<float>(yLocal, xLocal, tmpLocal, calc_row_num, ONE_UINT);
}
PipeBarrier<PIPE_V>();
if constexpr (is_same<T, half>::value) {
repeatByRow<half>(yLocal, yLocal, gammaLocal, calc_row_num, TWO_UINT);
if (!this->nullptrBeta) {
addRepeatByRow<half>(yLocal, yLocal, betaLocal, calc_row_num, TWO_UINT);
}
} else if constexpr (is_same<T, bfloat16_t>::value) {
LocalTensor<float> sqx = sqxBuf.Get<float>();
LocalTensor<float> x_fp32 = xFp32Buf.Get<float>();
Cast(x_fp32, yLocal, RoundMode::CAST_NONE, elementNum);
Cast(sqx, gammaLocal, RoundMode::CAST_NONE, elementNum);
PipeBarrier<PIPE_V>();
repeatByRow<float>(x_fp32, x_fp32, sqx, calc_row_num, THREE_UINT);
if (!this->nullptrBeta) {
Cast(sqx, betaLocal, RoundMode::CAST_NONE, elementNum);
PipeBarrier<PIPE_V>();
addRepeatByRow<float>(x_fp32, x_fp32, sqx, calc_row_num, THREE_UINT);
}
Cast(yLocal, x_fp32, RoundMode::CAST_RINT, elementNum);
} else {
repeatByRow<float>(yLocal, yLocal, gammaLocal, calc_row_num, THREE_UINT);
if (!this->nullptrBeta) {
addRepeatByRow<float>(yLocal, yLocal, betaLocal, calc_row_num, THREE_UINT);
}
}
PipeBarrier<PIPE_V>();
outQueueY.EnQue<T>(yLocal);
}
__aicore__ inline void CopyOutY(uint32_t progress, uint32_t calc_row_num)
{
LocalTensor<T> yLocal = outQueueY.DeQue<T>();
if (isNumColAlign) {
DataCopyCustom<T>(yGm[progress], yLocal, calc_row_num * numCol);
} else {
DataCopyCustom<T>(yGm[progress], yLocal, calc_row_num, numCol);
}
outQueueY.FreeTensor(yLocal);
}
#if defined(__CCE_AICORE__) && __CCE_AICORE__ == 220 || (defined(__NPU_ARCH__) && __NPU_ARCH__ == 3003)
__aicore__ inline void CopyOutRstd(uint32_t outer_progress, uint32_t num)
{
LocalTensor<float> rstdLocal = outQueueRstd.DeQue<float>();
DataCopyCustom<float>(rstdGm[outer_progress * rowFactor], rstdLocal, num);
outQueueRstd.FreeTensor(rstdLocal);
}
#endif
template <typename U>
__aicore__ inline void repeatByRow(const LocalTensor<U>& dstLocal, const LocalTensor<U>& src1Local, const LocalTensor<U>& src2Local, uint32_t calc_row_num, uint32_t type)
{
// type selects the src2 layout: ONE_UINT = per-row rstd blocks (from Brcb), TWO_UINT = fp16 gamma/beta, THREE_UINT = fp32 gamma/beta reused for every row
uint32_t strideParams[6] = {mulLoopFp32, mulTailFp32, 64, 1, dstRepStrideFp32, 0};
if (type == TWO_UINT) {
strideParams[0] = mulLoopFp16;
strideParams[1] = mulTailFp16;
strideParams[2] = 128;
strideParams[4] = dstRepStrideFp16;
} else if (type == ONE_UINT) {
strideParams[3] = 0;
strideParams[5] = 1;
}
uint32_t singlT = 255;
uint32_t rowRepeatLoop = calc_row_num / singlT;
uint32_t rowRepeatTail = calc_row_num - rowRepeatLoop * singlT;
uint32_t offset2 = 0;
for(uint32_t r_i = 0; r_i < rowRepeatLoop; r_i ++) {
offset2 = type == 1 ? (r_i * singlT * MOV_8) : 0;
mulRepeat<U>(dstLocal[r_i * singlT * numColAlign], src1Local[r_i * singlT * numColAlign], src2Local[offset2], singlT, strideParams);
}
if(rowRepeatTail > 0) {
offset2 = type == 1 ? (rowRepeatLoop * singlT * MOV_8) : 0;
uint32_t offset1 = rowRepeatLoop * singlT * numColAlign;
mulRepeat<U>(dstLocal[offset1], src1Local[offset1], src2Local[offset2], rowRepeatTail, strideParams);
}
}
template <typename U>
__aicore__ inline void mulRepeat(const LocalTensor<U>& dstLocal, const LocalTensor<U>& src1Local, const LocalTensor<U>& src2Local, uint32_t calcRowNum, uint32_t strideParams[6])
{
uint32_t mulLoop = strideParams[0];
uint32_t mulTail = strideParams[1];
uint32_t strideNum = strideParams[2];
uint8_t src1BlkStride = static_cast<uint8_t>(strideParams[3]);
uint8_t dstRepStride = static_cast<uint8_t>(strideParams[4]);
uint8_t src1RepStride = static_cast<uint8_t>(strideParams[5]);
if(src1BlkStride == 0) {
for (uint32_t m_i = 0; m_i < mulLoop; m_i++) {
Mul(dstLocal[m_i * strideNum], src1Local[m_i * strideNum], src2Local, strideNum, calcRowNum, {1, 1, src1BlkStride, dstRepStride, dstRepStride, src1RepStride});
}
PipeBarrier<PIPE_V>();
if(mulTail > 0) {
Mul(dstLocal[mulLoop * strideNum], src1Local[mulLoop * strideNum], src2Local, mulTail, calcRowNum, {1, 1, src1BlkStride, dstRepStride, dstRepStride, src1RepStride});
}
PipeBarrier<PIPE_V>();
} else {
for (uint32_t m_i = 0; m_i < mulLoop; m_i++) {
Mul(dstLocal[m_i * strideNum], src1Local[m_i * strideNum], src2Local[m_i * strideNum], strideNum, calcRowNum, {1, 1, src1BlkStride, dstRepStride, dstRepStride, src1RepStride});
}
PipeBarrier<PIPE_V>();
if(mulTail > 0) {
Mul(dstLocal[mulLoop * strideNum], src1Local[mulLoop * strideNum], src2Local[mulLoop * strideNum], mulTail, calcRowNum, {1, 1, src1BlkStride, dstRepStride, dstRepStride, src1RepStride});
}
PipeBarrier<PIPE_V>();
}
}
template <typename U>
__aicore__ inline void addRepeatByRow(const LocalTensor<U>& dstLocal, const LocalTensor<U>& src1Local, const LocalTensor<U>& src2Local, uint32_t calc_row_num, uint32_t type)
{
// type selects the src2 layout (same convention as repeatByRow): ONE_UINT = per-row rstd blocks, TWO_UINT = fp16 gamma/beta, THREE_UINT = fp32 gamma/beta reused for every row
uint32_t strideParams[6] = {mulLoopFp32, mulTailFp32, 64, 1, dstRepStrideFp32, 0};
if (type == TWO_UINT) {
strideParams[0] = mulLoopFp16;
strideParams[1] = mulTailFp16;
strideParams[2] = 128;
strideParams[4] = dstRepStrideFp16;
} else if (type == ONE_UINT) {
strideParams[3] = 0;
strideParams[5] = 1;
}
uint32_t singlT = 255;
uint32_t rowRepeatLoop = calc_row_num / singlT;
uint32_t rowRepeatTail = calc_row_num - rowRepeatLoop * singlT;
uint32_t offset2 = 0;
for(uint32_t r_i = 0; r_i < rowRepeatLoop; r_i ++) {
offset2 = type == 1 ? (r_i * singlT * MOV_8) : 0;
addRepeat<U>(dstLocal[r_i * singlT * numColAlign], src1Local[r_i * singlT * numColAlign], src2Local[offset2], singlT, strideParams);
}
if(rowRepeatTail > 0) {
offset2 = type == 1 ? (rowRepeatLoop * singlT * MOV_8) : 0;
uint32_t offset1 = rowRepeatLoop * singlT * numColAlign;
addRepeat<U>(dstLocal[offset1], src1Local[offset1], src2Local[offset2], rowRepeatTail, strideParams);
}
}
template <typename U>
__aicore__ inline void addRepeat(const LocalTensor<U>& dstLocal, const LocalTensor<U>& src1Local, const LocalTensor<U>& src2Local, uint32_t calcRowNum, uint32_t strideParams[6])
{
uint32_t addLoop = strideParams[0];
uint32_t addTail = strideParams[1];
uint32_t strideNum = strideParams[2];
uint8_t src1BlkStride = static_cast<uint8_t>(strideParams[3]);
uint8_t dstRepStride = static_cast<uint8_t>(strideParams[4]);
uint8_t src1RepStride = static_cast<uint8_t>(strideParams[5]);
if(src1BlkStride == 0) {
for (uint32_t m_i = 0; m_i < addLoop; m_i++) {
Add(dstLocal[m_i * strideNum], src1Local[m_i * strideNum], src2Local, strideNum, calcRowNum, {1, 1, src1BlkStride, dstRepStride, dstRepStride, src1RepStride});
}
PipeBarrier<PIPE_V>();
if(addTail > 0) {
Add(dstLocal[addLoop * strideNum], src1Local[addLoop * strideNum], src2Local, addTail, calcRowNum, {1, 1, src1BlkStride, dstRepStride, dstRepStride, src1RepStride});
}
PipeBarrier<PIPE_V>();
} else {
for (uint32_t m_i = 0; m_i < addLoop; m_i++) {
Add(dstLocal[m_i * strideNum], src1Local[m_i * strideNum], src2Local[m_i * strideNum], strideNum, calcRowNum, {1, 1, src1BlkStride, dstRepStride, dstRepStride, src1RepStride});
}
PipeBarrier<PIPE_V>();
if(addTail > 0) {
Add(dstLocal[addLoop * strideNum], src1Local[addLoop * strideNum], src2Local[addLoop * strideNum], addTail, calcRowNum, {1, 1, src1BlkStride, dstRepStride, dstRepStride, src1RepStride});
}
PipeBarrier<PIPE_V>();
}
}
private:
TPipe* Ppipe = nullptr;
// create queues for input, in this case depth is equal to buffer num
TQue<QuePosition::VECIN, BUFFER_NUM> inQueueGamma;
TQue<QuePosition::VECIN, BUFFER_NUM> inQueueBeta;
TQue<QuePosition::VECIN, DOUBLE_BUFFER_NUM> inQueueX;
// create queues for output, in this case depth is equal to buffer num
#if defined(__CCE_AICORE__) && __CCE_AICORE__ == 220 || (defined(__NPU_ARCH__) && __NPU_ARCH__ == 3003)
TQue<QuePosition::VECOUT, BUFFER_NUM> outQueueRstd;
#else
TBuf<TPosition::VECCALC> rstdBuf;
#endif
TQue<QuePosition::VECOUT, DOUBLE_BUFFER_NUM> outQueueY;
TBuf<TPosition::VECCALC> xFp32Buf;
TBuf<TPosition::VECCALC> sqxBuf;
TBuf<TPosition::VECCALC> tmpBuf;
GlobalTensor<T> x1Gm;
GlobalTensor<T> x2Gm;
GlobalTensor<T> gammaGm;
GlobalTensor<T> betaGm;
GlobalTensor<T> yGm;
GlobalTensor<float> rstdGm;
GlobalTensor<T> xGm;
uint32_t numRow;
uint32_t numCol;
uint32_t numColAlign;
uint32_t blockFactor; // number of rows processed by each core
uint32_t rowFactor;
uint32_t ubFactor;
float epsilon;
float avgFactor;
int32_t blockIdx_;
uint32_t rowWork = 1;
#if (defined(__NPU_ARCH__) && __NPU_ARCH__ == 3003)
bool isNumColAlign = true;
#else
bool isNumColAlign = false;
#endif
uint8_t isPerformance = 0;
uint32_t rowLoop = 1;
uint32_t rowTail = 0;
uint32_t mulLoopFp32;
uint32_t mulTailFp32;
uint8_t dstRepStrideFp32;
uint32_t mulLoopFp16;
uint32_t mulTailFp16;
uint8_t dstRepStrideFp16;
uint32_t nullptrBeta = 0;
};
#endif // ADD_RMS_NORM_BIAS_MERGE_N_H_


@@ -0,0 +1,339 @@
/**
* This program is free software, you can redistribute it and/or modify.
* Copyright (c) 2025 Huawei Technologies Co., Ltd.
* This file is a part of the CANN Open Software.
* Licensed under CANN Open Software License Agreement Version 2.0 (the "License").
* Please refer to the License for details. You may not use this file except in compliance with the License.
* THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE.
* See LICENSE in the root of the software repository for the full text of the License.
*/
/*!
* \file add_rms_norm_bias_multi_n.h
* \brief add rms norm bias multi n file
*/
#ifndef ADD_RMS_NORM_BIAS_MULTI_N_H_
#define ADD_RMS_NORM_BIAS_MULTI_N_H_
#include "./rms_norm_base.h"
using namespace AscendC;
using namespace RmsNorm;
template <typename T>
class KernelAddRmsNormBiasMultiN {
public:
__aicore__ inline KernelAddRmsNormBiasMultiN(TPipe* pipe)
{
Ppipe = pipe;
}
__aicore__ inline void Init(
GM_ADDR x1, GM_ADDR x2, GM_ADDR gamma, GM_ADDR beta, GM_ADDR y, GM_ADDR rstd, GM_ADDR x, const AddRMSNormBiasTilingData* tiling)
{
ASSERT(GetBlockNum() != 0 && "Block dim can not be zero!");
this->numRow = tiling->num_row;
this->numCol = tiling->num_col;
this->numColAlign = tiling->num_col_align;
this->blockFactor = tiling->block_factor;
this->rowFactor = tiling->row_factor;
this->ubFactor = tiling->ub_factor;
this->epsilon = tiling->epsilon;
this->avgFactor = tiling->avg_factor;
this->nullptrBeta = tiling->nullptr_beta;
blockIdx_ = GetBlockIdx();
if (blockIdx_ < GetBlockNum() - 1) {
this->rowWork = blockFactor;
this->rowLoop = tiling->row_loop;
this->rowTail = tiling->row_tail;
} else if (blockIdx_ == GetBlockNum() - 1) {
this->rowWork = tiling->last_block_factor;
this->rowLoop = tiling->last_block_row_loop;
this->rowTail = tiling->last_block_row_tail;
}
// get start index for current core, core parallel
x1Gm.SetGlobalBuffer((__gm__ T*)x1 + blockIdx_ * blockFactor * numCol, rowWork * numCol);
x2Gm.SetGlobalBuffer((__gm__ T*)x2 + blockIdx_ * blockFactor * numCol, rowWork * numCol);
gammaGm.SetGlobalBuffer((__gm__ T*)gamma, numCol);
if (!this->nullptrBeta) {
betaGm.SetGlobalBuffer((__gm__ T*)beta, numCol);
}
yGm.SetGlobalBuffer((__gm__ T*)y + blockIdx_ * blockFactor * numCol, rowWork * numCol);
rstdGm.SetGlobalBuffer((__gm__ float*)rstd + blockIdx_ * blockFactor, blockFactor);
xGm.SetGlobalBuffer((__gm__ T*)x + blockIdx_ * blockFactor * numCol, rowWork * numCol);
// pipe alloc memory to queue, the unit is Bytes
Ppipe->InitBuffer(inQueueX, DOUBLE_BUFFER_NUM, ubFactor * sizeof(T));
Ppipe->InitBuffer(inQueueGamma, BUFFER_NUM, numColAlign * sizeof(T));
if (!this->nullptrBeta) {
Ppipe->InitBuffer(inQueueBeta, BUFFER_NUM, numColAlign * sizeof(T));
}
Ppipe->InitBuffer(outQueueY, DOUBLE_BUFFER_NUM, ubFactor * sizeof(T));
#if __CCE_AICORE__ == 220 || (defined(__NPU_ARCH__) && __NPU_ARCH__ == 3003)
Ppipe->InitBuffer(outQueueRstd, BUFFER_NUM, rowFactor * NUM_PER_BLK_FP32 * sizeof(float));
#else
Ppipe->InitBuffer(rstdBuf, rowFactor * NUM_PER_BLK_FP32 * sizeof(float));
#endif
if constexpr (is_same<T, half>::value || is_same<T, bfloat16_t>::value) {
Ppipe->InitBuffer(xFp32Buf, ubFactor * sizeof(float));
}
Ppipe->InitBuffer(sqxBuf, ubFactor * sizeof(float));
Ppipe->InitBuffer(reduceFp32Buf, NUM_PER_REP_FP32 * sizeof(float));
Ppipe->InitBuffer(offsetBuf, rowFactor * NUM_PER_BLK_FP32 * sizeof(uint32_t));
}
__aicore__ inline void Process()
{
CopyInGammaBeta();
LocalTensor<T> betaLocal;
if (!this->nullptrBeta) {
betaLocal = inQueueBeta.DeQue<T>();
}
LocalTensor<T> gammaLocal = inQueueGamma.DeQue<T>();
LocalTensor<uint32_t> offsetLocal = offsetBuf.Get<uint32_t>();
for (uint32_t i = 0; i < rowFactor; i++) {
Duplicate(offsetLocal[i * NUM_PER_BLK_FP32], i * ONE_BLK_SIZE, NUM_PER_BLK_FP32);
}
for (uint32_t i_o = 0; i_o < rowLoop - 1; i_o++) {
SubProcessHalf(i_o, rowFactor, gammaLocal, betaLocal);
}
SubProcessHalf(rowLoop - 1, rowTail, gammaLocal, betaLocal);
inQueueGamma.FreeTensor(gammaLocal);
if (!this->nullptrBeta) {
inQueueBeta.FreeTensor(betaLocal);
}
}
__aicore__ inline void SubProcessHalf(uint32_t i_o, uint32_t calc_row_num, LocalTensor<T>& gammaLocal, LocalTensor<T>& betaLocal)
{
uint32_t gm_bias = i_o * rowFactor * numCol;
CopyInX(gm_bias, calc_row_num);
LocalTensor<T> xLocal = ComputeX(calc_row_num);
CopyOutX(gm_bias, calc_row_num);
#if __CCE_AICORE__ == 220 || (defined(__NPU_ARCH__) && __NPU_ARCH__ == 3003)
LocalTensor<float> rstdLocal = outQueueRstd.AllocTensor<float>();
ComputeRstd(xLocal, rstdLocal, calc_row_num);
outQueueRstd.EnQue<float>(rstdLocal);
CopyOutRstd(i_o * rowFactor, calc_row_num);
#else
LocalTensor<float> rstdLocal = rstdBuf.Get<float>();
ComputeRstd(xLocal, rstdLocal, calc_row_num);
#endif
ComputeY(xLocal, gammaLocal, betaLocal, rstdLocal, calc_row_num);
CopyOutY(gm_bias, calc_row_num);
}
private:
__aicore__ inline void CopyInX(uint32_t gm_bias, uint32_t calc_row_num)
{
LocalTensor<T> x1Local = inQueueX.AllocTensor<T>();
DataCopyCustom<T>(x1Local, x1Gm[gm_bias], calc_row_num * numCol);
inQueueX.EnQue(x1Local);
LocalTensor<T> x2Local = inQueueX.AllocTensor<T>();
DataCopyCustom<T>(x2Local, x2Gm[gm_bias], calc_row_num * numCol);
inQueueX.EnQue(x2Local);
}
__aicore__ inline LocalTensor<T> ComputeX(uint32_t calc_row_num)
{
uint32_t calc_num = calc_row_num * numColAlign;
LocalTensor<T> x1Local = inQueueX.DeQue<T>();
LocalTensor<T> x2Local = inQueueX.DeQue<T>();
LocalTensor<T> xLocal = outQueueY.AllocTensor<T>();
if constexpr (!is_same<T, bfloat16_t>::value) {
Add(xLocal, x1Local, x2Local, calc_num);
} else {
LocalTensor<float> x1Fp32 = xFp32Buf.Get<float>();
LocalTensor<float> x2Fp32 = sqxBuf.Get<float>();
Cast(x1Fp32, x1Local, RoundMode::CAST_NONE, calc_num);
Cast(x2Fp32, x2Local, RoundMode::CAST_NONE, calc_num);
PipeBarrier<PIPE_V>();
Add(x1Fp32, x1Fp32, x2Fp32, calc_num);
PipeBarrier<PIPE_V>();
Cast(xLocal, x1Fp32, RoundMode::CAST_RINT, calc_num);
}
inQueueX.FreeTensor(x1Local);
inQueueX.FreeTensor(x2Local);
outQueueY.EnQue(xLocal);
PipeBarrier<PIPE_V>();
return xLocal;
}
__aicore__ inline void CopyOutX(uint32_t gm_bias, uint32_t calc_row_num)
{
// CopyOut x1 + x2
auto x_out = outQueueY.DeQue<T>();
DataCopyCustom<T>(xGm[gm_bias], x_out, calc_row_num * numCol);
outQueueY.FreeTensor(x_out);
}
__aicore__ inline void CopyInGammaBeta()
{
LocalTensor<T> gammaLocal = inQueueGamma.AllocTensor<T>();
DataCopyCustom<T>(gammaLocal, gammaGm, numCol);
inQueueGamma.EnQue(gammaLocal);
if (!this->nullptrBeta) {
LocalTensor<T> betaLocal = inQueueBeta.AllocTensor<T>();
DataCopyCustom<T>(betaLocal, betaGm, numCol);
inQueueBeta.EnQue(betaLocal);
}
}
__aicore__ inline void ComputeRstd(LocalTensor<T> xLocal, LocalTensor<float> rstdLocal, uint32_t calc_row_num)
{
LocalTensor<float> x_fp32 = xFp32Buf.Get<float>();
LocalTensor<float> sqx = sqxBuf.Get<float>();
LocalTensor<float> reduce_buf_local = reduceFp32Buf.Get<float>();
Cast(x_fp32, xLocal, RoundMode::CAST_NONE, calc_row_num * numColAlign);
PipeBarrier<PIPE_V>();
Mul(sqx, x_fp32, x_fp32, calc_row_num * numColAlign);
PipeBarrier<PIPE_V>();
Muls(sqx, sqx, avgFactor, calc_row_num * numColAlign);
PipeBarrier<PIPE_V>();
for (uint32_t i_i = 0; i_i < calc_row_num; i_i++) {
ReduceSumCustom(rstdLocal[i_i * NUM_PER_BLK_FP32], sqx[i_i * numColAlign], reduce_buf_local, numCol);
}
Adds(rstdLocal, rstdLocal, epsilon, calc_row_num * NUM_PER_BLK_FP32);
PipeBarrier<PIPE_V>();
Sqrt(rstdLocal, rstdLocal, calc_row_num * NUM_PER_BLK_FP32);
Duplicate(reduce_buf_local, ONE, NUM_PER_BLK_FP32);
PipeBarrier<PIPE_V>();
int32_t repeatTimes = calc_row_num * NUM_PER_BLK_FP32 / NUM_PER_REP_FP32;
int32_t tailCount = calc_row_num * NUM_PER_BLK_FP32 % NUM_PER_REP_FP32;
int32_t bodyCount = repeatTimes * NUM_PER_REP_FP32;
if (likely(repeatTimes > 0)) {
Div(rstdLocal, reduce_buf_local, rstdLocal, NUM_PER_REP_FP32, repeatTimes, {1, 0, 1, DEFAULT_REPEAT_STRIDE, 0, DEFAULT_REPEAT_STRIDE});
}
if (unlikely(tailCount != 0)) {
Div(rstdLocal[bodyCount], reduce_buf_local, rstdLocal[bodyCount], tailCount, 1, {1, 0, 1, DEFAULT_REPEAT_STRIDE, 0, DEFAULT_REPEAT_STRIDE});
}
PipeBarrier<PIPE_V>();
}
__aicore__ inline void ComputeY(
LocalTensor<T> xLocal, LocalTensor<T> gammaLocal, LocalTensor<T> betaLocal, LocalTensor<float> rstdLocal, uint32_t calc_row_num)
{
LocalTensor<float> x_fp32 = xFp32Buf.Get<float>();
LocalTensor<uint32_t> offsetLocal = offsetBuf.Get<uint32_t>();
Gather(rstdLocal, rstdLocal, offsetLocal, ZERO_UINT, calc_row_num * NUM_PER_BLK_FP32);
PipeBarrier<PIPE_V>();
int32_t repeatTimes = numCol / NUM_PER_REP_FP32;
int32_t tailCount = numCol % NUM_PER_REP_FP32;
int32_t bodyCount = repeatTimes * NUM_PER_REP_FP32;
for (uint32_t i_i = 0; i_i < calc_row_num; i_i++) {
if (likely(repeatTimes > 0)) {
Mul(x_fp32[i_i * numColAlign], x_fp32[i_i * numColAlign], rstdLocal[i_i * NUM_PER_BLK_FP32],
NUM_PER_REP_FP32, repeatTimes, {1, 1, 0, DEFAULT_REPEAT_STRIDE, DEFAULT_REPEAT_STRIDE, 0});
}
if (unlikely(tailCount != 0)) {
Mul(x_fp32[i_i * numColAlign + bodyCount], x_fp32[i_i * numColAlign + bodyCount],
rstdLocal[i_i * NUM_PER_BLK_FP32], tailCount, 1,
{1, 1, 0, DEFAULT_REPEAT_STRIDE, DEFAULT_REPEAT_STRIDE, 0});
}
}
PipeBarrier<PIPE_V>();
LocalTensor<T> yLocal = outQueueY.AllocTensor<T>();
if constexpr (is_same<T, half>::value) {
Cast(yLocal, x_fp32, RoundMode::CAST_NONE, calc_row_num * numColAlign);
PipeBarrier<PIPE_V>();
for (uint32_t i_i = 0; i_i < calc_row_num; i_i++) {
Mul(yLocal[i_i * numColAlign], gammaLocal, yLocal[i_i * numColAlign], numCol);
}
if (!this->nullptrBeta) {
PipeBarrier<PIPE_V>();
for (uint32_t i_i = 0; i_i < calc_row_num; i_i++) {
Add(yLocal[i_i * numColAlign], betaLocal, yLocal[i_i * numColAlign], numCol);
}
}
} else {
Cast(yLocal, x_fp32, RoundMode::CAST_RINT, calc_row_num * numColAlign);
PipeBarrier<PIPE_V>();
LocalTensor<float> yfp32 = xFp32Buf.Get<float>();
Cast(yfp32, yLocal, RoundMode::CAST_NONE, calc_row_num * numColAlign);
PipeBarrier<PIPE_V>();
LocalTensor<float> gammaFp32 = sqxBuf.Get<float>();
Cast(gammaFp32, gammaLocal, RoundMode::CAST_NONE, numCol);
PipeBarrier<PIPE_V>();
for (uint32_t i_i = 0; i_i < calc_row_num; i_i++) {
Mul(yfp32[i_i * numColAlign], gammaFp32, yfp32[i_i * numColAlign], numCol);
}
PipeBarrier<PIPE_V>();
if (!this->nullptrBeta) {
Cast(gammaFp32, betaLocal, RoundMode::CAST_NONE, numCol);
PipeBarrier<PIPE_V>();
for (uint32_t i_i = 0; i_i < calc_row_num; i_i++) {
Add(yfp32[i_i * numColAlign], gammaFp32, yfp32[i_i * numColAlign], numCol);
}
PipeBarrier<PIPE_V>();
}
Cast(yLocal, yfp32, RoundMode::CAST_RINT, calc_row_num * numColAlign);
}
PipeBarrier<PIPE_V>();
outQueueY.EnQue<T>(yLocal);
}
__aicore__ inline void CopyOutY(uint32_t progress, uint32_t calc_row_num)
{
LocalTensor<T> yLocal = outQueueY.DeQue<T>();
DataCopyCustom<T>(yGm[progress], yLocal, calc_row_num * numCol);
outQueueY.FreeTensor(yLocal);
}
#if __CCE_AICORE__ == 220 || (defined(__NPU_ARCH__) && __NPU_ARCH__ == 3003)
__aicore__ inline void CopyOutRstd(uint32_t outer_progress, uint32_t num)
{
LocalTensor<float> rstdLocal = outQueueRstd.DeQue<float>();
DataCopyParams copyParams;
copyParams.blockLen = sizeof(float);
copyParams.blockCount = num;
DataCopyPad(rstdGm[outer_progress], rstdLocal, copyParams);
outQueueRstd.FreeTensor(rstdLocal);
}
#endif
private:
TPipe* Ppipe = nullptr;
// create queues for input, in this case depth is equal to buffer num
TQue<QuePosition::VECIN, BUFFER_NUM> inQueueGamma;
TQue<QuePosition::VECIN, BUFFER_NUM> inQueueBeta;
TQue<QuePosition::VECIN, DOUBLE_BUFFER_NUM> inQueueX;
// create queues for output, in this case depth is equal to buffer num
#if __CCE_AICORE__ == 220 || (defined(__NPU_ARCH__) && __NPU_ARCH__ == 3003)
TQue<QuePosition::VECOUT, BUFFER_NUM> outQueueRstd;
#else
TBuf<TPosition::VECCALC> rstdBuf;
#endif
TQue<QuePosition::VECOUT, DOUBLE_BUFFER_NUM> outQueueY;
TBuf<TPosition::VECCALC> xFp32Buf;
TBuf<TPosition::VECCALC> sqxBuf;
TBuf<TPosition::VECCALC> reduceFp32Buf;
TBuf<TPosition::VECCALC> offsetBuf;
GlobalTensor<T> x1Gm;
GlobalTensor<T> x2Gm;
GlobalTensor<T> gammaGm;
GlobalTensor<T> betaGm;
GlobalTensor<T> yGm;
GlobalTensor<float> rstdGm;
GlobalTensor<T> xGm;
uint32_t numRow;
uint32_t numCol;
uint32_t blockFactor; // number of rows processed by each core
uint32_t rowFactor;
uint32_t ubFactor;
float epsilon;
float avgFactor;
uint32_t numColAlign;
int32_t blockIdx_;
uint32_t rowWork = 1;
uint32_t rowLoop = 1;
uint32_t rowTail = 0;
uint32_t nullptrBeta = 0;
};
#endif // ADD_RMS_NORM_BIAS_MULTI_N_H_


@@ -0,0 +1,376 @@
/**
* This program is free software, you can redistribute it and/or modify.
* Copyright (c) 2025 Huawei Technologies Co., Ltd.
* This file is a part of the CANN Open Software.
* Licensed under CANN Open Software License Agreement Version 2.0 (the "License").
* Please refer to the License for details. You may not use this file except in compliance with the License.
* THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE.
* See LICENSE in the root of the software repository for the full text of the License.
*/
/*!
* \file add_rms_norm_bias_single_n.h
* \brief add rms norm bias single n file
*/
#ifndef ADD_RMS_NORM_BIAS_SINGLE_N_H_
#define ADD_RMS_NORM_BIAS_SINGLE_N_H_
#include "./rms_norm_base.h"
using namespace AscendC;
using namespace RmsNorm;
template <typename T>
class KernelAddRmsNormBiasSingleN {
static constexpr int32_t MAXBUFFER = 195584;
public:
__aicore__ inline KernelAddRmsNormBiasSingleN(TPipe* pipe)
{
Ppipe = pipe;
}
__aicore__ inline void Init(
GM_ADDR x1, GM_ADDR x2, GM_ADDR gamma, GM_ADDR beta, GM_ADDR y, GM_ADDR rstd, GM_ADDR x, const AddRMSNormBiasTilingData* tiling)
{
ASSERT(GetBlockNum() != 0 && "Block dim can not be zero!");
this->numCol = tiling->num_col;
this->blockFactor = 1; // in this case, blockFactor = 1
this->ubFactor = tiling->ub_factor;
this->epsilon = tiling->epsilon;
this->avgFactor = (numCol != 0) ? (float)1.0 / numCol : 0;
this->nullptrBeta = tiling->nullptr_beta;
this->rowWork = 1;
blockIdx_ = GetBlockIdx();
// get start index for current core, core parallel
x1Gm.SetGlobalBuffer((__gm__ T*)x1 + blockIdx_ * numCol, numCol);
x2Gm.SetGlobalBuffer((__gm__ T*)x2 + blockIdx_ * numCol, numCol);
gammaGm.SetGlobalBuffer((__gm__ T*)gamma, numCol);
if (!this->nullptrBeta) {
betaGm.SetGlobalBuffer((__gm__ T*)beta, numCol);
}
yGm.SetGlobalBuffer((__gm__ T*)y + blockIdx_ * numCol, numCol);
rstdGm.SetGlobalBuffer((__gm__ float*)rstd + blockIdx_, 1);
xGm.SetGlobalBuffer((__gm__ T*)x + blockIdx_ * numCol, numCol);
Ppipe->InitBuffer(unitBuf, MAXBUFFER); // (192 - 1) * 1024 byte
}
__aicore__ inline void Process()
{
if constexpr (is_same<T, half>::value) {
ProcessFp16();
} else if constexpr (is_same<T, float>::value) {
ProcessFp32();
} else {
ProcessBf16();
}
}
private:
__aicore__ inline void ProcessFp16()
{
LocalTensor<float> ubLocal = unitBuf.Get<float>();
LocalTensor<T> xLocal = ubLocal.template ReinterpretCast<T>();
LocalTensor<T> x1Local = xLocal[0];
LocalTensor<T> x2Local = xLocal[ubFactor];
LocalTensor<float> xFp32Local = ubLocal[ubFactor];
LocalTensor<float> sqxLocal = ubLocal[ubFactor * 2];
LocalTensor<float> tmpLocal = ubLocal[ubFactor * 3];
DataCopyCustom<T>(x1Local, x1Gm, numCol);
event_t eventMTE2V1 = static_cast<event_t>(GetTPipePtr()->FetchEventID(HardEvent::MTE2_V));
SetFlag<HardEvent::MTE2_V>(eventMTE2V1);
DataCopyCustom<T>(x2Local, x2Gm, numCol);
event_t eventMTE2V2 = static_cast<event_t>(GetTPipePtr()->FetchEventID(HardEvent::MTE2_V));
SetFlag<HardEvent::MTE2_V>(eventMTE2V2);
WaitFlag<HardEvent::MTE2_V>(eventMTE2V1);
WaitFlag<HardEvent::MTE2_V>(eventMTE2V2);
Add(x1Local, x1Local, x2Local, numCol);
PipeBarrier<PIPE_V>();
// copy gamma
event_t eventVMTE2 = static_cast<event_t>(GetTPipePtr()->FetchEventID(HardEvent::V_MTE2));
SetFlag<HardEvent::V_MTE2>(eventVMTE2);
WaitFlag<HardEvent::V_MTE2>(eventVMTE2);
DataCopyCustom<T>(x2Local, gammaGm, numCol); // reuse x2Local to hold gamma (x2 is no longer needed after the add)
SetFlag<HardEvent::MTE2_V>(eventMTE2V2);
// copy x out
event_t eventVMTE3 = static_cast<event_t>(GetTPipePtr()->FetchEventID(HardEvent::V_MTE3));
SetFlag<HardEvent::V_MTE3>(eventVMTE3);
WaitFlag<HardEvent::V_MTE3>(eventVMTE3);
DataCopyCustom<T>(xGm, x1Local, numCol);
event_t eventMTE3V = static_cast<event_t>(GetTPipePtr()->FetchEventID(HardEvent::MTE3_V));
SetFlag<HardEvent::MTE3_V>(eventMTE3V);
Cast(xFp32Local, x1Local, RoundMode::CAST_NONE, numCol);
PipeBarrier<PIPE_V>();
Mul(sqxLocal, xFp32Local, xFp32Local, numCol);
PipeBarrier<PIPE_V>();
Muls(sqxLocal, sqxLocal, avgFactor, numCol);
PipeBarrier<PIPE_V>();
ReduceSumCustom(sqxLocal, sqxLocal, tmpLocal, numCol);
PipeBarrier<PIPE_V>();
Adds(sqxLocal, sqxLocal, epsilon, 1);
PipeBarrier<PIPE_V>();
Sqrt(sqxLocal, sqxLocal, 1);
Duplicate(tmpLocal, ONE, 1);
PipeBarrier<PIPE_V>();
Div(sqxLocal, tmpLocal, sqxLocal, 1);
PipeBarrier<PIPE_V>();
// copyout rstd
#if (defined(__CCE_AICORE__) && __CCE_AICORE__ == 220) || (defined(__NPU_ARCH__) && __NPU_ARCH__ == 3003)
SetFlag<HardEvent::V_MTE3>(eventVMTE3);
WaitFlag<HardEvent::V_MTE3>(eventVMTE3);
DataCopyCustom<float>(rstdGm, sqxLocal, 1);
#endif
event_t eventVS = static_cast<event_t>(GetTPipePtr()->FetchEventID(HardEvent::V_S));
SetFlag<HardEvent::V_S>(eventVS);
WaitFlag<HardEvent::V_S>(eventVS);
float rstdValue = sqxLocal.GetValue(0);
event_t eventSV = static_cast<event_t>(GetTPipePtr()->FetchEventID(HardEvent::S_V));
SetFlag<HardEvent::S_V>(eventSV);
WaitFlag<HardEvent::S_V>(eventSV);
Muls(xFp32Local, xFp32Local, rstdValue, numCol);
PipeBarrier<PIPE_V>();
WaitFlag<HardEvent::MTE3_V>(eventMTE3V);
Cast(x1Local, xFp32Local, RoundMode::CAST_NONE, numCol);
PipeBarrier<PIPE_V>();
WaitFlag<HardEvent::MTE2_V>(eventMTE2V2);
Mul(x1Local, x1Local, x2Local, numCol);
if (!this->nullptrBeta) {
event_t eventVMTE2Beta = static_cast<event_t>(GetTPipePtr()->FetchEventID(HardEvent::V_MTE2));
SetFlag<HardEvent::V_MTE2>(eventVMTE2Beta);
WaitFlag<HardEvent::V_MTE2>(eventVMTE2Beta);
DataCopyCustom<T>(x2Local, betaGm, numCol);
event_t eventMTE2XBeta = static_cast<event_t>(GetTPipePtr()->FetchEventID(HardEvent::MTE2_V));
SetFlag<HardEvent::MTE2_V>(eventMTE2XBeta);
WaitFlag<HardEvent::MTE2_V>(eventMTE2XBeta);
Add(x1Local, x1Local, x2Local, numCol);
}
SetFlag<HardEvent::V_MTE3>(eventVMTE3);
WaitFlag<HardEvent::V_MTE3>(eventVMTE3);
DataCopyCustom<T>(yGm, x1Local, numCol);
}
__aicore__ inline void ProcessFp32()
{
LocalTensor<float> ubLocal = unitBuf.Get<float>();
LocalTensor<T> x1Local = ubLocal[0];
LocalTensor<T> x2Local = ubLocal[ubFactor];
LocalTensor<float> sqxLocal = ubLocal[ubFactor * 2];
LocalTensor<float> tmpLocal = ubLocal[ubFactor * 3];
DataCopyCustom<T>(x1Local, x1Gm, numCol);
event_t eventMTE2V1 = static_cast<event_t>(GetTPipePtr()->FetchEventID(HardEvent::MTE2_V));
SetFlag<HardEvent::MTE2_V>(eventMTE2V1);
DataCopyCustom<T>(x2Local, x2Gm, numCol);
event_t eventMTE2V2 = static_cast<event_t>(GetTPipePtr()->FetchEventID(HardEvent::MTE2_V));
SetFlag<HardEvent::MTE2_V>(eventMTE2V2);
WaitFlag<HardEvent::MTE2_V>(eventMTE2V1);
WaitFlag<HardEvent::MTE2_V>(eventMTE2V2);
Add(x1Local, x1Local, x2Local, numCol);
PipeBarrier<PIPE_V>();
// copy gamma
event_t eventVMTE2 = static_cast<event_t>(GetTPipePtr()->FetchEventID(HardEvent::V_MTE2));
SetFlag<HardEvent::V_MTE2>(eventVMTE2);
WaitFlag<HardEvent::V_MTE2>(eventVMTE2);
DataCopyCustom<T>(x2Local, gammaGm, numCol); // reuse x2Local to hold gamma (x2 is no longer needed after the add)
SetFlag<HardEvent::MTE2_V>(eventMTE2V2);
// copy x out
event_t eventVMTE3 = static_cast<event_t>(GetTPipePtr()->FetchEventID(HardEvent::V_MTE3));
SetFlag<HardEvent::V_MTE3>(eventVMTE3);
WaitFlag<HardEvent::V_MTE3>(eventVMTE3);
DataCopyCustom<T>(xGm, x1Local, numCol);
event_t eventMTE3V = static_cast<event_t>(GetTPipePtr()->FetchEventID(HardEvent::MTE3_V));
SetFlag<HardEvent::MTE3_V>(eventMTE3V);
Mul(sqxLocal, x1Local, x1Local, numCol);
PipeBarrier<PIPE_V>();
Muls(sqxLocal, sqxLocal, avgFactor, numCol);
PipeBarrier<PIPE_V>();
ReduceSumCustom(sqxLocal, sqxLocal, tmpLocal, numCol);
PipeBarrier<PIPE_V>();
Adds(sqxLocal, sqxLocal, epsilon, 1);
PipeBarrier<PIPE_V>();
Sqrt(sqxLocal, sqxLocal, 1);
Duplicate(tmpLocal, ONE, 1);
PipeBarrier<PIPE_V>();
Div(sqxLocal, tmpLocal, sqxLocal, 1);
PipeBarrier<PIPE_V>();
// copyout rstd
#if defined(__CCE_AICORE__) && __CCE_AICORE__ == 220 || (defined(__NPU_ARCH__) && __NPU_ARCH__ == 3003)
SetFlag<HardEvent::V_MTE3>(eventVMTE3);
WaitFlag<HardEvent::V_MTE3>(eventVMTE3);
DataCopyCustom<float>(rstdGm, sqxLocal, 1);
#endif
event_t eventVS = static_cast<event_t>(GetTPipePtr()->FetchEventID(HardEvent::V_S));
SetFlag<HardEvent::V_S>(eventVS);
WaitFlag<HardEvent::V_S>(eventVS);
float rstdValue = sqxLocal.GetValue(0);
event_t eventSV = static_cast<event_t>(GetTPipePtr()->FetchEventID(HardEvent::S_V));
SetFlag<HardEvent::S_V>(eventSV);
WaitFlag<HardEvent::S_V>(eventSV);
WaitFlag<HardEvent::MTE3_V>(eventMTE3V);
Muls(x1Local, x1Local, rstdValue, numCol);
PipeBarrier<PIPE_V>();
WaitFlag<HardEvent::MTE2_V>(eventMTE2V2);
Mul(x1Local, x1Local, x2Local, numCol);
if (!this->nullptrBeta) {
event_t eventVMTE2Beta = static_cast<event_t>(GetTPipePtr()->FetchEventID(HardEvent::V_MTE2));
SetFlag<HardEvent::V_MTE2>(eventVMTE2Beta);
WaitFlag<HardEvent::V_MTE2>(eventVMTE2Beta);
DataCopyCustom<T>(x2Local, betaGm, numCol);
event_t eventMTE2XBeta = static_cast<event_t>(GetTPipePtr()->FetchEventID(HardEvent::MTE2_V));
SetFlag<HardEvent::MTE2_V>(eventMTE2XBeta);
WaitFlag<HardEvent::MTE2_V>(eventMTE2XBeta);
Add(x1Local, x1Local, x2Local, numCol);
}
SetFlag<HardEvent::V_MTE3>(eventVMTE3);
WaitFlag<HardEvent::V_MTE3>(eventVMTE3);
DataCopyCustom<T>(yGm, x1Local, numCol);
}
__aicore__ inline void ProcessBf16()
{
LocalTensor<float> ubLocal = unitBuf.Get<float>();
LocalTensor<T> xLocal = ubLocal.template ReinterpretCast<T>();
LocalTensor<T> x1Local = xLocal[0];
LocalTensor<T> x2Local = xLocal[ubFactor];
LocalTensor<float> xFp32Local = ubLocal[ubFactor];
LocalTensor<float> sqxLocal = ubLocal[ubFactor * 2];
LocalTensor<float> tmpLocal = ubLocal[ubFactor * 3];
DataCopyCustom<T>(x1Local, x1Gm, numCol);
event_t eventMTE2V1_BF16_0 = static_cast<event_t>(GetTPipePtr()->AllocEventID<HardEvent::MTE2_V>());
SetFlag<HardEvent::MTE2_V>(eventMTE2V1_BF16_0);
DataCopyCustom<T>(x2Local, x2Gm, numCol);
event_t eventMTE2V2_BF16_0 = static_cast<event_t>(GetTPipePtr()->AllocEventID<HardEvent::MTE2_V>());
SetFlag<HardEvent::MTE2_V>(eventMTE2V2_BF16_0);
WaitFlag<HardEvent::MTE2_V>(eventMTE2V1_BF16_0);
GetTPipePtr()->ReleaseEventID<HardEvent::MTE2_V>(eventMTE2V1_BF16_0);
Cast(xFp32Local, x1Local, RoundMode::CAST_NONE, numCol);
WaitFlag<HardEvent::MTE2_V>(eventMTE2V2_BF16_0);
GetTPipePtr()->ReleaseEventID<HardEvent::MTE2_V>(eventMTE2V2_BF16_0);
Cast(sqxLocal, x2Local, RoundMode::CAST_NONE, numCol);
PipeBarrier<PIPE_V>();
Add(xFp32Local, xFp32Local, sqxLocal, numCol);
PipeBarrier<PIPE_V>();
Cast(x1Local, xFp32Local, RoundMode::CAST_RINT, numCol);
PipeBarrier<PIPE_V>();
// copy gamma
event_t eventVMTE2_BF16_0 = static_cast<event_t>(GetTPipePtr()->FetchEventID(HardEvent::V_MTE2));
SetFlag<HardEvent::V_MTE2>(eventVMTE2_BF16_0);
WaitFlag<HardEvent::V_MTE2>(eventVMTE2_BF16_0);
DataCopyCustom<T>(x2Local, gammaGm, numCol); // reuse x2Local to hold gamma (x2 is no longer needed after the add)
event_t eventMTE2V2_BF16_1 = static_cast<event_t>(GetTPipePtr()->FetchEventID(HardEvent::MTE2_V));
SetFlag<HardEvent::MTE2_V>(eventMTE2V2_BF16_1);
// copy x out
event_t eventVMTE3_BF16_0 = static_cast<event_t>(GetTPipePtr()->FetchEventID(HardEvent::V_MTE3));
SetFlag<HardEvent::V_MTE3>(eventVMTE3_BF16_0);
WaitFlag<HardEvent::V_MTE3>(eventVMTE3_BF16_0);
DataCopyCustom<T>(xGm, x1Local, numCol);
event_t eventMTE3V_BF16_0 = static_cast<event_t>(GetTPipePtr()->AllocEventID<HardEvent::MTE3_V>());
SetFlag<HardEvent::MTE3_V>(eventMTE3V_BF16_0);
Cast(xFp32Local, x1Local, RoundMode::CAST_NONE, numCol);
PipeBarrier<PIPE_V>();
Mul(sqxLocal, xFp32Local, xFp32Local, numCol);
PipeBarrier<PIPE_V>();
Muls(sqxLocal, sqxLocal, avgFactor, numCol);
PipeBarrier<PIPE_V>();
ReduceSumCustom(sqxLocal, sqxLocal, tmpLocal, numCol);
PipeBarrier<PIPE_V>();
Adds(sqxLocal, sqxLocal, epsilon, 1);
PipeBarrier<PIPE_V>();
Sqrt(sqxLocal, sqxLocal, 1);
Duplicate(tmpLocal, ONE, 1);
PipeBarrier<PIPE_V>();
Div(sqxLocal, tmpLocal, sqxLocal, 1);
PipeBarrier<PIPE_V>();
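// sqxLocal[0] now holds rstd = 1 / sqrt(mean(x^2) + epsilon), where x is the
// rounded sum x1 + x2 cast back to fp32.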
event_t eventVS_BF16_0 = static_cast<event_t>(GetTPipePtr()->FetchEventID(HardEvent::V_S));
SetFlag<HardEvent::V_S>(eventVS_BF16_0);
WaitFlag<HardEvent::V_S>(eventVS_BF16_0);
float rstdValue = sqxLocal.GetValue(0);
event_t eventSV_BF16_0 = static_cast<event_t>(GetTPipePtr()->FetchEventID(HardEvent::S_V));
SetFlag<HardEvent::S_V>(eventSV_BF16_0);
WaitFlag<HardEvent::S_V>(eventSV_BF16_0);
// copyout rstd
#if (defined(__CCE_AICORE__) && __CCE_AICORE__ == 220) || (defined(__NPU_ARCH__) && __NPU_ARCH__ == 3003)
event_t eventVMTE3_BF16_1 = static_cast<event_t>(GetTPipePtr()->FetchEventID(HardEvent::V_MTE3));
SetFlag<HardEvent::V_MTE3>(eventVMTE3_BF16_1);
WaitFlag<HardEvent::V_MTE3>(eventVMTE3_BF16_1);
DataCopyCustom<float>(rstdGm, sqxLocal, 1);
event_t eventMTE3V2_BF16_0 = static_cast<event_t>(GetTPipePtr()->AllocEventID<HardEvent::MTE3_V>());
SetFlag<HardEvent::MTE3_V>(eventMTE3V2_BF16_0);
#endif
Muls(xFp32Local, xFp32Local, rstdValue, numCol);
PipeBarrier<PIPE_V>();
WaitFlag<HardEvent::MTE3_V>(eventMTE3V_BF16_0);
GetTPipePtr()->ReleaseEventID<HardEvent::MTE3_V>(eventMTE3V_BF16_0);
Cast(x1Local, xFp32Local, RoundMode::CAST_RINT, numCol);
PipeBarrier<PIPE_V>();
Cast(xFp32Local, x1Local, RoundMode::CAST_NONE, numCol);
PipeBarrier<PIPE_V>();
WaitFlag<HardEvent::MTE2_V>(eventMTE2V2_BF16_1);
#if (defined(__CCE_AICORE__) && __CCE_AICORE__ == 220) || (defined(__NPU_ARCH__) && __NPU_ARCH__ == 3003)
WaitFlag<HardEvent::MTE3_V>(eventMTE3V2_BF16_0);
GetTPipePtr()->ReleaseEventID<HardEvent::MTE3_V>(eventMTE3V2_BF16_0);
#endif
Cast(sqxLocal, x2Local, RoundMode::CAST_NONE, numCol);
PipeBarrier<PIPE_V>();
Mul(xFp32Local, xFp32Local, sqxLocal, numCol);
if (!this->nullptrBeta) {
event_t eventVMTE2Beta = static_cast<event_t>(GetTPipePtr()->FetchEventID(HardEvent::V_MTE2));
SetFlag<HardEvent::V_MTE2>(eventVMTE2Beta);
WaitFlag<HardEvent::V_MTE2>(eventVMTE2Beta);
DataCopyCustom<T>(x2Local, betaGm, numCol);
event_t eventMTE2XBeta = static_cast<event_t>(GetTPipePtr()->FetchEventID(HardEvent::MTE2_V));
SetFlag<HardEvent::MTE2_V>(eventMTE2XBeta);
WaitFlag<HardEvent::MTE2_V>(eventMTE2XBeta);
Cast(sqxLocal, x2Local, RoundMode::CAST_NONE, numCol);
PipeBarrier<PIPE_V>();
Add(xFp32Local, xFp32Local, sqxLocal, numCol);
}
PipeBarrier<PIPE_V>();
Cast(x1Local, xFp32Local, RoundMode::CAST_RINT, numCol);
event_t eventVMTE3_BF16_2 = static_cast<event_t>(GetTPipePtr()->FetchEventID(HardEvent::V_MTE3));
SetFlag<HardEvent::V_MTE3>(eventVMTE3_BF16_2);
WaitFlag<HardEvent::V_MTE3>(eventVMTE3_BF16_2);
DataCopyCustom<T>(yGm, x1Local, numCol);
}
private:
TPipe* Ppipe = nullptr;
TBuf<TPosition::VECCALC> unitBuf;
GlobalTensor<T> x1Gm;
GlobalTensor<T> x2Gm;
GlobalTensor<T> gammaGm;
GlobalTensor<T> betaGm;
GlobalTensor<T> yGm;
GlobalTensor<float> rstdGm;
GlobalTensor<T> xGm;
uint32_t numRow;
uint32_t numCol;
uint32_t blockFactor; // number of rows calculated on each core
uint32_t ubFactor;
float epsilon;
float avgFactor;
int32_t blockIdx_;
uint32_t rowWork = 1;
uint32_t nullptrBeta = 0;
};
#endif // _ADD_RMS_NORM_BIAS_SINGLE_N_H_

View File

@@ -0,0 +1,395 @@
/**
* This program is free software, you can redistribute it and/or modify.
* Copyright (c) 2025 Huawei Technologies Co., Ltd.
* This file is a part of the CANN Open Software.
* Licensed under CANN Open Software License Agreement Version 2.0 (the "License").
* Please refer to the License for details. You may not use this file except in compliance with the License.
* THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE.
* See LICENSE in the root of the software repository for the full text of the License.
*/
/*!
* \file add_rms_norm_bias_split_d.h
* \brief add rms norm bias split d file
*/
#ifndef ADD_RMS_NORM_BIAS_SPLIT_D_H_
#define ADD_RMS_NORM_BIAS_SPLIT_D_H_
#include "./rms_norm_base.h"
using namespace AscendC;
using namespace RmsNorm;
template <typename T>
class KernelAddRmsNormBiasSplitD {
public:
__aicore__ inline KernelAddRmsNormBiasSplitD(TPipe* pipe)
{
Ppipe = pipe;
}
__aicore__ inline void Init(
GM_ADDR x1, GM_ADDR x2, GM_ADDR gamma, GM_ADDR beta, GM_ADDR y, GM_ADDR rstd, GM_ADDR x, const AddRMSNormBiasTilingData* tiling)
{
ASSERT(GetBlockNum() != 0 && "Block dim can not be zero!");
this->numRow = tiling->num_row;
this->numCol = tiling->num_col;
this->blockFactor = tiling->block_factor;
this->rowFactor = tiling->row_factor;
this->ubFactor = tiling->ub_factor;
this->epsilon = tiling->epsilon;
this->avgFactor = (numCol != 0) ? (float)1.0 / numCol : 0;
this->nullptrBeta = tiling->nullptr_beta;
blockIdx_ = GetBlockIdx();
if (blockIdx_ < GetBlockNum() - 1) {
this->rowWork = blockFactor;
} else if (blockIdx_ == GetBlockNum() - 1) {
this->rowWork = numRow - (GetBlockNum() - 1) * blockFactor;
} else {
}
// each core works on a disjoint block of rows; compute this core's start offset
x1Gm.SetGlobalBuffer((__gm__ T*)x1 + blockIdx_ * blockFactor * numCol, rowWork * numCol);
x2Gm.SetGlobalBuffer((__gm__ T*)x2 + blockIdx_ * blockFactor * numCol, rowWork * numCol);
gammaGm.SetGlobalBuffer((__gm__ T*)gamma, numCol);
if (!this->nullptrBeta) {
betaGm.SetGlobalBuffer((__gm__ T*)beta, numCol);
}
yGm.SetGlobalBuffer((__gm__ T*)y + blockIdx_ * blockFactor * numCol, rowWork * numCol);
rstdGm.SetGlobalBuffer((__gm__ float*)rstd + blockIdx_ * blockFactor, blockFactor);
xGm.SetGlobalBuffer((__gm__ T*)x + blockIdx_ * blockFactor * numCol, rowWork * numCol);
// pipe allocates queue/buffer memory; sizes below are in bytes.
// inQueueX holds x1 and x2 back to back, hence 2 * ubFactor elements.
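// Rough UB budget per tile: 2*ubFactor*sizeof(T) (x1/x2) + ubFactor*sizeof(T) (gamma)
// [+ ubFactor*sizeof(T) (beta)] + ubFactor*sizeof(T) (y) + rowFactor*4 (rstd)
// + ubFactor*4 (fp32 work, half/bf16 only) + ubFactor*4 (sqx) + rowFactor*32 (sum)
// + 256 (reduce scratch) bytes.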
Ppipe->InitBuffer(inQueueX, BUFFER_NUM, 2 * ubFactor * sizeof(T));
Ppipe->InitBuffer(inQueueGamma, BUFFER_NUM, ubFactor * sizeof(T));
if (!this->nullptrBeta) {
Ppipe->InitBuffer(inQueueBeta, BUFFER_NUM, ubFactor * sizeof(T));
}
Ppipe->InitBuffer(outQueueY, BUFFER_NUM, ubFactor * sizeof(T));
Ppipe->InitBuffer(outQueueRstd, BUFFER_NUM, rowFactor * sizeof(float));
if constexpr (is_same<T, half>::value || is_same<T, bfloat16_t>::value) {
Ppipe->InitBuffer(xFp32Buf, ubFactor * sizeof(float));
}
Ppipe->InitBuffer(sqxBuf, ubFactor * sizeof(float));
Ppipe->InitBuffer(sumBuf, rowFactor * NUM_PER_BLK_FP32 * sizeof(float));
Ppipe->InitBuffer(reduceFp32Buf, NUM_PER_REP_FP32 * sizeof(float));
}
__aicore__ inline void Process()
{
uint32_t i_o_max = RmsNorm::CeilDiv(rowWork, rowFactor);
uint32_t row_tail = rowWork - (i_o_max - 1) * rowFactor;
uint32_t j_max = RmsNorm::CeilDiv(numCol, ubFactor);
uint32_t col_tail = numCol - (j_max - 1) * ubFactor;
for (uint32_t i_o = 0; i_o < i_o_max - 1; i_o++) {
SubProcess(i_o, rowFactor, j_max, col_tail);
}
SubProcess(i_o_max - 1, row_tail, j_max, col_tail);
}
__aicore__ inline void SubProcess(uint32_t i_o, uint32_t calc_row_num, uint32_t j_max, uint32_t col_tail)
{
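// Two-pass split-D scheme: the ComputeFormer loop streams column tiles of x1/x2,
// writes x = x1 + x2 to xGm, and accumulates each row's mean-square into rstdLocal;
// ComputeRstd turns that into 1/sqrt(mean + epsilon); the ComputeLatter loop then
// re-reads x from xGm and applies rstd, gamma and (optionally) beta to produce y.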
LocalTensor<float> sumLocal = sumBuf.Get<float>();
LocalTensor<float> rstdLocal = outQueueRstd.AllocTensor<float>();
Duplicate(rstdLocal, (float)0.0, calc_row_num);
PipeBarrier<PIPE_V>();
for (uint32_t j = 0; j < j_max - 1; j++) {
ComputeFormer(i_o, calc_row_num, j, rstdLocal, sumLocal, ubFactor);
}
// do tail
ComputeFormer(i_o, calc_row_num, j_max - 1, rstdLocal, sumLocal, col_tail);
ComputeRstd(rstdLocal, calc_row_num);
for (uint32_t j = 0; j < j_max - 1; j++) {
ComputeLatter(i_o, calc_row_num, j, rstdLocal, ubFactor);
}
ComputeLatter(i_o, calc_row_num, j_max - 1, rstdLocal, col_tail);
outQueueRstd.EnQue<float>(rstdLocal);
CopyOutRstd(i_o, calc_row_num);
}
private:
__aicore__ inline void CopyInAndAdd(uint32_t i_idx, uint32_t j_idx, uint32_t num)
{
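// Loads one (row, column-tile) of x1 and x2, forms x = x1 + x2, writes x to xGm for
// the second pass, and keeps an fp32 copy of the sum in xFp32Buf (half/bf16); for
// float inputs the sum stays in the single-buffered inQueueX memory and is picked up
// again in ComputeSum.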
LocalTensor<T> x1x2_in = inQueueX.AllocTensor<T>();
LocalTensor<T> x1_in = x1x2_in[0];
LocalTensor<T> x2_in = x1x2_in[ubFactor];
DataCopyCustom<T>(x1_in, x1Gm[i_idx * numCol + j_idx * ubFactor], num);
DataCopyCustom<T>(x2_in, x2Gm[i_idx * numCol + j_idx * ubFactor], num);
inQueueX.EnQue(x1x2_in);
LocalTensor<T> x1x2Local = inQueueX.DeQue<T>();
auto x1Local = x1x2Local[0];
auto x2Local = x1x2Local[ubFactor];
LocalTensor<T> xLocal = outQueueY.AllocTensor<T>();
if constexpr (is_same<T, half>::value) {
LocalTensor<float> x1_fp32 = xFp32Buf.Get<float>();
Add(xLocal, x1Local, x2Local, num);
PipeBarrier<PIPE_V>();
Cast(x1_fp32, xLocal, RoundMode::CAST_NONE, num);
PipeBarrier<PIPE_V>();
// x1+x2 saved in x1_fp32
} else if constexpr (is_same<T, bfloat16_t>::value) {
LocalTensor<float> x1_fp32 = xFp32Buf.Get<float>();
LocalTensor<float> x2_fp32 = x1x2Local.template ReinterpretCast<float>();
Cast(x1_fp32, x1Local, RoundMode::CAST_NONE, num);
PipeBarrier<PIPE_V>();
Cast(x2_fp32, x2Local, RoundMode::CAST_NONE, num);
PipeBarrier<PIPE_V>();
Add(x1_fp32, x1_fp32, x2_fp32, num);
PipeBarrier<PIPE_V>();
Cast(xLocal, x1_fp32, RoundMode::CAST_RINT, num);
PipeBarrier<PIPE_V>();
// x1+x2 saved in x1_fp32
} else {
Add(x1Local, x1Local, x2Local, num);
PipeBarrier<PIPE_V>();
Adds(xLocal, x1Local, (float)0.0, num);
// x1+x2 saved in inQueueX
}
inQueueX.FreeTensor(x1x2Local);
// copy out to workspace && x_out
outQueueY.EnQue(xLocal);
auto x_out = outQueueY.DeQue<T>();
DataCopyCustom<T>(xGm[i_idx * numCol + j_idx * ubFactor], x_out, num);
outQueueY.FreeTensor(x_out);
}
__aicore__ inline void ComputeFormer(
uint32_t i_o_idx, uint32_t calc_row_num, uint32_t j_idx, LocalTensor<float>& rstdLocal,
LocalTensor<float>& sumLocal, uint32_t num)
{
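// For each of the calc_row_num rows, ReduceSumFP32ToBlock leaves an 8-float partial
// sum of x^2 / numCol for this column tile in sumLocal; BlockReduceSumFP32 collapses
// each block to one value, which is accumulated into rstdLocal.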
for (uint32_t i_i = 0; i_i < calc_row_num; i_i++) {
CopyInAndAdd(i_o_idx * rowFactor + i_i, j_idx, num);
ComputeSum(i_i, sumLocal, num);
}
BlockReduceSumFP32(sumLocal, sumLocal, calc_row_num * NUM_PER_BLK_FP32);
Add(rstdLocal, rstdLocal, sumLocal, calc_row_num);
PipeBarrier<PIPE_V>();
}
__aicore__ inline void ComputeSum(uint32_t i_i_idx, LocalTensor<float>& sumLocal, uint32_t num)
{
LocalTensor<float> sqx = sqxBuf.Get<float>();
LocalTensor<float> reduce_buf_local = reduceFp32Buf.Get<float>();
if constexpr (is_same<T, half>::value || is_same<T, bfloat16_t>::value) {
LocalTensor<float> x_fp32 = xFp32Buf.Get<float>();
PipeBarrier<PIPE_V>();
Mul(sqx, x_fp32, x_fp32, num);
} else {
LocalTensor<T> xLocal = inQueueX.AllocTensor<float>();
PipeBarrier<PIPE_V>();
Mul(sqx, xLocal, xLocal, num);
inQueueX.FreeTensor(xLocal);
}
PipeBarrier<PIPE_V>();
Muls(sqx, sqx, avgFactor, num);
PipeBarrier<PIPE_V>();
// 8 means 8 fp32 elements per block
ReduceSumFP32ToBlock(sumLocal[i_i_idx * 8], sqx, reduce_buf_local, num);
}
__aicore__ inline void ComputeRstd(LocalTensor<float> rstdLocal, uint32_t num)
{
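// On entry rstdLocal holds mean(x^2) per row; convert it in place to
// rstd = 1 / sqrt(mean + epsilon).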
LocalTensor<float> reduce_buf_local = reduceFp32Buf.Get<float>();
Adds(rstdLocal, rstdLocal, epsilon, num);
PipeBarrier<PIPE_V>();
Sqrt(rstdLocal, rstdLocal, num);
Duplicate(reduce_buf_local, ONE, num);
PipeBarrier<PIPE_V>();
Div(rstdLocal, reduce_buf_local, rstdLocal, num);
PipeBarrier<PIPE_V>();
}
__aicore__ inline void ComputeLatter(
uint32_t i_o_idx, uint32_t calc_row_num, uint32_t j_idx, LocalTensor<float>& rstdLocal, uint32_t num)
{
CopyInGammaBeta(j_idx, num);
LocalTensor<T> gammaLocal = inQueueGamma.DeQue<T>();
LocalTensor<T> betaLocal;
if (!this->nullptrBeta) {
betaLocal = inQueueBeta.DeQue<T>();
}
for (uint32_t i_i = 0; i_i < calc_row_num; i_i++) {
CopyInX(i_o_idx * rowFactor + i_i, j_idx, num);
ComputeY(i_i, gammaLocal, betaLocal, rstdLocal, num);
CopyOutY(i_o_idx * rowFactor + i_i, j_idx, num);
}
inQueueGamma.FreeTensor(gammaLocal);
if (!this->nullptrBeta) {
inQueueBeta.FreeTensor(betaLocal);
}
}
__aicore__ inline void CopyInGammaBeta(uint32_t j_idx, uint32_t num)
{
LocalTensor<T> gammaLocal = inQueueGamma.AllocTensor<T>();
DataCopyCustom<T>(gammaLocal, gammaGm[j_idx * ubFactor], num);
inQueueGamma.EnQue(gammaLocal);
if (!this->nullptrBeta) {
LocalTensor<T> betaLocal = inQueueBeta.AllocTensor<T>();
DataCopyCustom<T>(betaLocal, betaGm[j_idx * ubFactor], num);
inQueueBeta.EnQue(betaLocal);
}
}
__aicore__ inline void CopyInX(uint32_t i_idx, uint32_t j_idx, uint32_t num)
{
LocalTensor<T> xLocal = inQueueX.AllocTensor<T>();
DataCopyCustom<T>(xLocal, xGm[i_idx * numCol + j_idx * ubFactor], num);
inQueueX.EnQue<T>(xLocal);
if constexpr (is_same<T, half>::value || is_same<T, bfloat16_t>::value) {
LocalTensor<float> x_fp32 = xFp32Buf.Get<float>();
LocalTensor<T> xLocal = inQueueX.DeQue<T>();
Cast(x_fp32, xLocal, RoundMode::CAST_NONE, num);
PipeBarrier<PIPE_V>();
inQueueX.FreeTensor(xLocal);
}
}
__aicore__ inline void ComputeY(
uint32_t i_i_idx, LocalTensor<half>& gammaLocal, LocalTensor<half>& betaLocal, LocalTensor<float>& rstdLocal, uint32_t num)
{
LocalTensor<float> x_fp32 = xFp32Buf.Get<float>();
LocalTensor<float> sqx = sqxBuf.Get<float>();
event_t event_v_s = static_cast<event_t>(GetTPipePtr()->FetchEventID(HardEvent::V_S));
SetFlag<HardEvent::V_S>(event_v_s);
WaitFlag<HardEvent::V_S>(event_v_s);
float rstdValue = rstdLocal.GetValue(i_i_idx);
event_t event_s_v = static_cast<event_t>(GetTPipePtr()->FetchEventID(HardEvent::S_V));
SetFlag<HardEvent::S_V>(event_s_v);
WaitFlag<HardEvent::S_V>(event_s_v);
PipeBarrier<PIPE_V>();
Muls(x_fp32, x_fp32, rstdValue, num);
PipeBarrier<PIPE_V>();
LocalTensor<half> yLocal = outQueueY.AllocTensor<half>();
Cast(yLocal, x_fp32, RoundMode::CAST_NONE, num);
PipeBarrier<PIPE_V>();
Mul(yLocal, gammaLocal, yLocal, num);
PipeBarrier<PIPE_V>();
if (!this->nullptrBeta) {
Add(yLocal, betaLocal, yLocal, num);
PipeBarrier<PIPE_V>();
}
outQueueY.EnQue<half>(yLocal);
}
__aicore__ inline void ComputeY(
uint32_t i_i_idx, LocalTensor<float>& gammaLocal, LocalTensor<float>& betaLocal, LocalTensor<float>& rstdLocal, uint32_t num)
{
LocalTensor<float> xLocal = inQueueX.DeQue<float>();
LocalTensor<float> sqx = sqxBuf.Get<float>();
event_t event_v_s = static_cast<event_t>(GetTPipePtr()->FetchEventID(HardEvent::V_S));
SetFlag<HardEvent::V_S>(event_v_s);
WaitFlag<HardEvent::V_S>(event_v_s);
float rstdValue = rstdLocal.GetValue(i_i_idx);
event_t event_s_v = static_cast<event_t>(GetTPipePtr()->FetchEventID(HardEvent::S_V));
SetFlag<HardEvent::S_V>(event_s_v);
WaitFlag<HardEvent::S_V>(event_s_v);
LocalTensor<float> yLocal = outQueueY.AllocTensor<float>();
Muls(yLocal, xLocal, rstdValue, num);
inQueueX.FreeTensor(xLocal);
PipeBarrier<PIPE_V>();
Mul(yLocal, gammaLocal, yLocal, num);
PipeBarrier<PIPE_V>();
if (!this->nullptrBeta) {
Add(yLocal, betaLocal, yLocal, num);
PipeBarrier<PIPE_V>();
}
outQueueY.EnQue<float>(yLocal);
}
__aicore__ inline void ComputeY(
uint32_t i_i_idx, LocalTensor<bfloat16_t>& gammaLocal, LocalTensor<bfloat16_t>& betaLocal, LocalTensor<float>& rstdLocal, uint32_t num)
{
LocalTensor<float> x_fp32 = xFp32Buf.Get<float>();
LocalTensor<float> sqx = sqxBuf.Get<float>();
event_t event_v_s = static_cast<event_t>(GetTPipePtr()->FetchEventID(HardEvent::V_S));
SetFlag<HardEvent::V_S>(event_v_s);
WaitFlag<HardEvent::V_S>(event_v_s);
float rstdValue = rstdLocal.GetValue(i_i_idx);
event_t event_s_v = static_cast<event_t>(GetTPipePtr()->FetchEventID(HardEvent::S_V));
SetFlag<HardEvent::S_V>(event_s_v);
WaitFlag<HardEvent::S_V>(event_s_v);
PipeBarrier<PIPE_V>();
Muls(x_fp32, x_fp32, rstdValue, num);
PipeBarrier<PIPE_V>();
LocalTensor<bfloat16_t> yLocal = outQueueY.AllocTensor<bfloat16_t>();
Cast(yLocal, x_fp32, RoundMode::CAST_RINT, num);
PipeBarrier<PIPE_V>();
Cast(x_fp32, yLocal, RoundMode::CAST_NONE, num);
PipeBarrier<PIPE_V>();
Cast(sqx, gammaLocal, RoundMode::CAST_NONE, num);
PipeBarrier<PIPE_V>();
Mul(x_fp32, x_fp32, sqx, num);
PipeBarrier<PIPE_V>();
if (!this->nullptrBeta) {
Cast(sqx, betaLocal, RoundMode::CAST_NONE, num);
PipeBarrier<PIPE_V>();
Add(x_fp32, x_fp32, sqx, num);
PipeBarrier<PIPE_V>();
}
Cast(yLocal, x_fp32, RoundMode::CAST_RINT, num);
PipeBarrier<PIPE_V>();
outQueueY.EnQue<bfloat16_t>(yLocal);
}
__aicore__ inline void CopyOutY(uint32_t i_idx, uint32_t j_idx, uint32_t num)
{
LocalTensor<T> yLocal = outQueueY.DeQue<T>();
DataCopyCustom<T>(yGm[i_idx * numCol + j_idx * ubFactor], yLocal, num);
outQueueY.FreeTensor(yLocal);
}
__aicore__ inline void CopyOutRstd(uint32_t i_o_idx, uint32_t num)
{
LocalTensor<float> rstdLocal = outQueueRstd.DeQue<float>();
#if (defined(__CCE_AICORE__) && __CCE_AICORE__ == 220) || (defined(__NPU_ARCH__) && __NPU_ARCH__ == 3003)
DataCopyCustom<float>(rstdGm[i_o_idx * rowFactor], rstdLocal, num);
#endif
outQueueRstd.FreeTensor(rstdLocal);
}
private:
TPipe* Ppipe = nullptr;
// create queues for input, in this case depth is equal to buffer num
TQue<QuePosition::VECIN, BUFFER_NUM> inQueueX;
TQue<QuePosition::VECIN, BUFFER_NUM> inQueueGamma;
TQue<QuePosition::VECIN, BUFFER_NUM> inQueueBeta;
// create queues for output, in this case depth is equal to buffer num
TQue<QuePosition::VECOUT, BUFFER_NUM> outQueueY;
TQue<QuePosition::VECOUT, BUFFER_NUM> outQueueRstd;
TBuf<TPosition::VECCALC> xFp32Buf;
TBuf<TPosition::VECCALC> sqxBuf;
TBuf<TPosition::VECCALC> sumBuf;
TBuf<TPosition::VECCALC> reduceFp32Buf;
GlobalTensor<T> x1Gm;
GlobalTensor<T> x2Gm;
GlobalTensor<T> gammaGm;
GlobalTensor<T> betaGm;
GlobalTensor<T> yGm;
GlobalTensor<float> rstdGm;
GlobalTensor<T> xGm;
uint32_t numRow;
uint32_t numCol;
uint32_t blockFactor; // number of rows calculated on each core
uint32_t rowFactor;
uint32_t ubFactor;
float epsilon;
float avgFactor;
int32_t blockIdx_;
uint32_t rowWork = 1;
uint32_t nullptrBeta = 0;
int tempbufNum;
};
#endif // ADD_RMS_NORM_BIAS_SPLIT_D_H_

View File

@@ -0,0 +1,179 @@
/**
* This program is free software, you can redistribute it and/or modify.
* Copyright (c) 2025 Huawei Technologies Co., Ltd.
* This file is a part of the CANN Open Software.
* Licensed under CANN Open Software License Agreement Version 2.0 (the "License").
* Please refer to the License for details. You may not use this file except in compliance with the License.
* THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE.
* See LICENSE in the root of the software repository for the full text of the License.
*/
/*!
* \file reduce_common.h
*/
#ifndef REDUCE_COMMON_H_RMS_NORM
#define REDUCE_COMMON_H_RMS_NORM
#include "kernel_operator.h"
using namespace AscendC;
constexpr uint32_t MAX_REP_NUM = 255;
constexpr uint32_t ELEM_PER_REP_FP32 = 64;
constexpr uint32_t ELEM_PER_BLK_FP32 = 8;
constexpr float ZERO = 0;
constexpr int32_t HALf_INTERVAL = 2;
constexpr int32_t INDEX_TWO = 2;
constexpr int32_t INDEX_FOUR = 4;
constexpr int32_t INDEX_EIGHT = 8;
constexpr int32_t INDEX_SIXTEEN = 16;
__aicore__ inline void ReduceSumForSmallReduceDimPreRepeat(
const LocalTensor<float>& dstLocal, const LocalTensor<float>& srcLocal, const LocalTensor<float>& tmpLocal,
const uint32_t elemNum, const uint32_t numLastDim, const uint32_t tailCount, const uint32_t repeat,
const uint8_t repStride)
{
uint32_t elemIndex = 0;
for (; elemIndex + ELEM_PER_REP_FP32 <= numLastDim; elemIndex += ELEM_PER_REP_FP32) {
Add(tmpLocal, srcLocal[elemIndex], tmpLocal, elemNum, repeat,
{1, 1, 1, ELEM_PER_BLK_FP32, repStride, ELEM_PER_BLK_FP32});
PipeBarrier<PIPE_V>();
}
if (unlikely(tailCount != 0)) {
Add(tmpLocal, srcLocal[elemIndex], tmpLocal, tailCount, repeat,
{1, 1, 1, ELEM_PER_BLK_FP32, repStride, ELEM_PER_BLK_FP32});
}
PipeBarrier<PIPE_V>();
AscendCUtils::SetMask<float>(ELEM_PER_REP_FP32); // set mask = 64
#if defined(__CCE_AICORE__) && __CCE_AICORE__ == 220
if ASCEND_IS_AIV {
WholeReduceSum<float, false>(dstLocal, tmpLocal, elemNum, repeat, 1, 1, ELEM_PER_BLK_FP32);
}
#elif defined(__NPU_ARCH__) && __NPU_ARCH__ == 3003
WholeReduceSum(dstLocal, tmpLocal, elemNum, repeat, 1, 1, ELEM_PER_BLK_FP32);
#else
WholeReduceSum<float, false>(dstLocal, tmpLocal, elemNum, repeat, 1, 1, ELEM_PER_BLK_FP32);
#endif
}
/*
* reduce dim from (N, D) to (N, 1)
* this reduce sum is intended for a small reduce dim.
*/
__aicore__ inline void ReduceSumForSmallReduceDim(
const LocalTensor<float>& dstLocal, const LocalTensor<float>& srcLocal, const LocalTensor<float>& tmpLocal,
const uint32_t numLastDimAligned, const uint32_t numLastDim, const uint32_t tailCount, const uint32_t repeat,
const uint8_t repStride)
{
uint32_t repeatTimes = repeat / MAX_REP_NUM;
if (repeatTimes == 0) {
ReduceSumForSmallReduceDimPreRepeat(
dstLocal, srcLocal, tmpLocal, ELEM_PER_REP_FP32, numLastDim, tailCount, repeat, repStride);
} else {
uint32_t repTailNum = repeat % MAX_REP_NUM;
uint32_t repIndex = 0;
uint32_t repElem;
for (; repIndex + MAX_REP_NUM <= repeat; repIndex += MAX_REP_NUM) {
ReduceSumForSmallReduceDimPreRepeat(
dstLocal[repIndex], srcLocal[repIndex * numLastDimAligned], tmpLocal[repIndex * ELEM_PER_REP_FP32],
ELEM_PER_REP_FP32, numLastDim, tailCount, MAX_REP_NUM, repStride);
}
if (repTailNum != 0) {
ReduceSumForSmallReduceDimPreRepeat(
dstLocal[repIndex], srcLocal[repIndex * numLastDimAligned], tmpLocal[repIndex * ELEM_PER_REP_FP32],
ELEM_PER_REP_FP32, numLastDim, tailCount, repTailNum, repStride);
}
}
}
/*
* reduce dim from (N, D) to (N, 1)
* this reduce sum is for a small reduce dim and requires D < 255 * 8.
* size of tmpLocal: (N, 64)
*/
__aicore__ inline void ReduceSumMultiN(
const LocalTensor<float>& dstLocal, const LocalTensor<float>& srcLocal, const LocalTensor<float>& tmpLocal,
const uint32_t numRow, const uint32_t numCol, const uint32_t numColAlign)
{
const uint32_t tailCount = numCol % ELEM_PER_REP_FP32;
const uint32_t repeat = numRow;
const uint8_t repStride = numColAlign / ELEM_PER_BLK_FP32;
Duplicate(tmpLocal, ZERO, numRow * ELEM_PER_REP_FP32);
PipeBarrier<PIPE_V>();
ReduceSumForSmallReduceDim(dstLocal, srcLocal, tmpLocal, numColAlign, numCol, tailCount, repeat, repStride);
}
__aicore__ inline int32_t findPowerTwo(int32_t n)
{
// find max power of 2 no more than n (32 bit)
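// e.g. findPowerTwo(100): 0b1100100 is smeared to 0b1111111 (127) by the shifts
// below, and (127 + 1) >> 1 == 64, the largest power of two not exceeding 100.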
n |= n >> 1; // propagate the highest set bit downward (first smear step)
n |= n >> INDEX_TWO;
n |= n >> INDEX_FOUR;
n |= n >> INDEX_EIGHT;
n |= n >> INDEX_SIXTEEN;
return (n + 1) >> 1;
}
__aicore__ inline void ReduceSumHalfInterval(
const LocalTensor<float>& dst_local, const LocalTensor<float>& src_local, int32_t count)
{
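// Worked example (count = 200): bodyCount = findPowerTwo(200) = 128; the 72-element
// tail is folded onto src_local[0..71]; one halving step folds src_local[64..127]
// onto src_local[0..63]; WholeReduceSum then collapses those 64 lanes into
// dst_local[0].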
if (likely(count > ELEM_PER_REP_FP32)) {
int32_t bodyCount = findPowerTwo(count);
int32_t tailCount = count - bodyCount;
if (tailCount > 0) {
Add(src_local, src_local, src_local[bodyCount], tailCount);
PipeBarrier<PIPE_V>();
}
while (bodyCount > ELEM_PER_REP_FP32) {
bodyCount = bodyCount / HALf_INTERVAL;
Add(src_local, src_local, src_local[bodyCount], bodyCount);
PipeBarrier<PIPE_V>();
}
AscendCUtils::SetMask<float>(ELEM_PER_REP_FP32);
} else {
AscendCUtils::SetMask<float>(count);
}
#if defined(__CCE_AICORE__) && __CCE_AICORE__ == 220
if (g_coreType == AIV) {
WholeReduceSum<float, false>(dst_local, src_local, ELEM_PER_REP_FP32, 1, 0, 1, 0);
}
#elif defined(__NPU_ARCH__) && __NPU_ARCH__ == 3003
WholeReduceSum(dst_local, src_local, ELEM_PER_REP_FP32, 1, 1, 1, ELEM_PER_BLK_FP32);
#else
WholeReduceSum<float, false>(dst_local, src_local, ELEM_PER_REP_FP32, 1, 1, 1, DEFAULT_REPEAT_STRIDE);
#endif
PipeBarrier<PIPE_V>();
}
__aicore__ inline float ReduceSumHalfInterval(const LocalTensor<float>& src_local, int32_t count)
{
if (likely(count > ELEM_PER_REP_FP32)) {
int32_t bodyCount = findPowerTwo(count);
int32_t tailCount = count - bodyCount;
if (tailCount > 0) {
Add(src_local, src_local, src_local[bodyCount], tailCount);
PipeBarrier<PIPE_V>();
}
while (bodyCount > ELEM_PER_REP_FP32) {
bodyCount = bodyCount / HALf_INTERVAL;
Add(src_local, src_local, src_local[bodyCount], bodyCount);
PipeBarrier<PIPE_V>();
}
AscendCUtils::SetMask<float>(ELEM_PER_REP_FP32);
} else {
AscendCUtils::SetMask<float>(count);
}
#if defined(__CCE_AICORE__) && __CCE_AICORE__ == 220
if (g_coreType == AIV) {
WholeReduceSum<float, false>(src_local, src_local, ELEM_PER_REP_FP32, 1, 0, 1, 0);
}
#elif defined(__NPU_ARCH__) && __NPU_ARCH__ == 3003
WholeReduceSum(src_local, src_local, ELEM_PER_REP_FP32, 1, 1, 1, ELEM_PER_BLK_FP32);
#else
WholeReduceSum<float, false>(src_local, src_local, ELEM_PER_REP_FP32, 1, 1, 1, DEFAULT_REPEAT_STRIDE);
#endif
event_t event_v_s = static_cast<event_t>(GetTPipePtr()->FetchEventID(HardEvent::V_S));
SetFlag<HardEvent::V_S>(event_v_s);
WaitFlag<HardEvent::V_S>(event_v_s);
return src_local.GetValue(0);
}
#endif // REDUCE_COMMON_H_RMS_NORM

View File

@@ -0,0 +1,316 @@
/**
* This program is free software, you can redistribute it and/or modify.
* Copyright (c) 2025 Huawei Technologies Co., Ltd.
* This file is a part of the CANN Open Software.
* Licensed under CANN Open Software License Agreement Version 2.0 (the "License").
* Please refer to the License for details. You may not use this file except in compliance with the License.
* THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE.
* See LICENSE in the root of the software repository for the full text of the License.
*/
#ifndef RMS_NORM_BASE_H_
#define RMS_NORM_BASE_H_
#include "kernel_operator.h"
#include "reduce_common.h"
namespace RmsNorm {
using namespace AscendC;
/**
* Get the block size of unified buffer in bytes
*/
__aicore__ inline constexpr uint32_t GetUbBlockSize()
{
return 32U;
}
/**
* Get the size of vector registers in bytes
*/
__aicore__ inline constexpr uint32_t GetVRegSize()
{
#if __CCE_AICORE__ == 310
return AscendC::VECTOR_REG_WIDTH;
#else
return 256U;
#endif
}
#if defined(__CCE_AICORE__) && __CCE_AICORE__ != 220 && __CCE_AICORE__ != 310
#define bfloat16_t int16_t
#endif
constexpr int32_t BUFFER_NUM = 1; // tensor num for each queue
constexpr int32_t DOUBLE_BUFFER_NUM = 2;
constexpr int32_t UNROLL_NUM = 2;
constexpr int32_t NUM_PER_REP_FP32 = 64; // ONE_REPEAT_BYTE_SIZE / sizeof(float);
constexpr int32_t NUM_PER_BLK_FP32 = 8;
constexpr int32_t FLOAT_BTYPE_SIZE = 4;
constexpr int32_t NUM_PER_BLK_FP16 = 16;
constexpr int32_t CONTINUE_STRIDE = 8;
constexpr int32_t BLOCK_SIZE = 32;
constexpr uint32_t ONCE_VECTOR_SIZE = 256;
constexpr float MINUS_HALF = -0.5f;
constexpr uint32_t ZERO_UINT = 0;
constexpr uint32_t ONE_UINT = 1;
constexpr uint32_t TWO_UINT = 2;
constexpr uint32_t THREE_UINT = 3;
constexpr float ONE = 1;
constexpr int32_t SECOND_LOOP = 2;
constexpr int32_t HALf_INTERVAL = 2;
constexpr int32_t MAX_REAPEAT = 255;
constexpr int32_t DIM_NUM = 2;
constexpr int32_t NDDMA_DIM = 5;
constexpr uint32_t V_LENGTH = GetVRegSize() / sizeof(float);
constexpr uint64_t ALIGN_512_FACTOR = 512;
constexpr uint64_t ALIGN_32_FACTOR = 32;
constexpr int32_t CONST_FACTOR_2 = 2;
constexpr uint32_t SUM_COUNT = 2;
constexpr int32_t MOV_2 = 2;
constexpr int32_t MOV_4 = 4;
constexpr int32_t MOV_8 = 8;
constexpr int32_t MOV_16 = 16;
template <typename T>
__aicore__ inline T CeilDiv(T x, T y)
{
return y == 0 ? x : (x + y - 1) / y;
}
template <typename T>
__aicore__ inline T Min(T left, T right)
{
return (left < right ? left : right);
}
template <typename Tp, Tp v>
struct integral_constant {
static constexpr Tp value = v;
};
using true_type = integral_constant<bool, true>;
using false_type = integral_constant<bool, false>;
template <typename, typename>
struct is_same : public false_type {};
template <typename Tp>
struct is_same<Tp, Tp> : public true_type {};
template <typename T, typename T_GAMMA>
class KernelRmsNormBase {
#define IS_X_FP32 (is_same<T, float>::value)
#define IS_GAMMA_FP32 (is_same<T_GAMMA, float>::value)
#define IS_MIX_DTYPE ((!IS_X_FP32) && IS_GAMMA_FP32)
};
__aicore__ inline int32_t findPowerTwo(int32_t n)
{
// find max power of 2 no more than n (32 bit)
n |= n >> 1; // propagate the highest set bit downward (first smear step)
n |= n >> MOV_2;
n |= n >> MOV_4;
n |= n >> MOV_8;
n |= n >> MOV_16;
return (n + 1) >> 1;
}
__aicore__ inline void ReduceSumHalfIntervalToRepeat(
const LocalTensor<float>& dst_local, const LocalTensor<float>& src_local, int32_t count, int32_t left)
{
// count must fit within 255 repeats
if (likely(count > NUM_PER_BLK_FP32)) {
int32_t bodyCount = count - left;
int32_t tailCount = left;
if (tailCount > 0) {
Add(src_local, src_local, src_local[bodyCount], tailCount);
PipeBarrier<PIPE_V>();
}
while (bodyCount > SECOND_LOOP * NUM_PER_BLK_FP32) {
bodyCount = bodyCount / HALf_INTERVAL;
Add(src_local, src_local, src_local[bodyCount], bodyCount);
PipeBarrier<PIPE_V>();
}
bodyCount = bodyCount / HALf_INTERVAL;
Add(dst_local, src_local, src_local[bodyCount], bodyCount);
PipeBarrier<PIPE_V>();
}
}
__aicore__ inline void ReduceSumFP32(
const LocalTensor<float>& dst_local, const LocalTensor<float>& src_local, const LocalTensor<float>& work_local,
int32_t count)
{
// count must fit within 255 repeats
uint64_t mask = NUM_PER_REP_FP32;
int32_t repeatTimes = count / NUM_PER_REP_FP32;
int32_t tailCount = count % NUM_PER_REP_FP32;
int32_t bodyCount = repeatTimes * NUM_PER_REP_FP32;
BinaryRepeatParams repeatParams;
repeatParams.src0RepStride = ONE_REPEAT_BYTE_SIZE / ONE_BLK_SIZE;
repeatParams.src0BlkStride = 1;
repeatParams.src1RepStride = 0;
repeatParams.src1BlkStride = 1;
repeatParams.dstRepStride = 0;
repeatParams.dstBlkStride = 1;
Duplicate(work_local, ZERO, NUM_PER_REP_FP32);
PipeBarrier<PIPE_V>();
if (likely(repeatTimes > 0)) {
Add(work_local, src_local, work_local, mask, repeatTimes, repeatParams);
PipeBarrier<PIPE_V>();
}
if (unlikely(tailCount != 0)) {
Add(work_local, src_local[bodyCount], work_local, tailCount, 1, repeatParams);
PipeBarrier<PIPE_V>();
}
AscendCUtils::SetMask<float>(NUM_PER_REP_FP32);
#if defined(__CCE_AICORE__) && __CCE_AICORE__ == 220
if (g_coreType == AIV) {
WholeReduceSum<float, false>(dst_local, work_local, MASK_PLACEHOLDER, 1, 0, 1, 0);
}
#elif !(defined(__NPU_ARCH__) && __NPU_ARCH__ == 3003)
WholeReduceSum<float, false>(dst_local, work_local, MASK_PLACEHOLDER, 1, 1, 1, DEFAULT_REPEAT_STRIDE);
#endif
PipeBarrier<PIPE_V>();
}
__aicore__ inline void ReduceSumCustom(
const LocalTensor<float>& dst_local, const LocalTensor<float>& src_local, const LocalTensor<float>& work_local,
int32_t count)
{
ReduceSumFP32(dst_local, src_local, work_local, count);
}
__aicore__ inline void ReduceSumFP32ToBlock(
const LocalTensor<float>& dst_local, const LocalTensor<float>& src_local, const LocalTensor<float>& work_local,
int32_t count)
{
// count must fit within 255 repeats
uint64_t mask = NUM_PER_REP_FP32;
int32_t repeatTimes = count / NUM_PER_REP_FP32;
int32_t tailCount = count % NUM_PER_REP_FP32;
int32_t bodyCount = repeatTimes * NUM_PER_REP_FP32;
BinaryRepeatParams repeatParams;
repeatParams.src0RepStride = ONCE_VECTOR_SIZE / BLOCK_SIZE;
repeatParams.src0BlkStride = 1;
repeatParams.src1RepStride = 0;
repeatParams.src1BlkStride = 1;
repeatParams.dstRepStride = 0;
repeatParams.dstBlkStride = 1;
Duplicate(work_local, ZERO, NUM_PER_REP_FP32);
PipeBarrier<PIPE_V>();
if (likely(repeatTimes > 0)) {
Add(work_local, src_local, work_local, mask, repeatTimes, repeatParams);
PipeBarrier<PIPE_V>();
}
if (unlikely(tailCount != 0)) {
Add(work_local, src_local[bodyCount], work_local, tailCount, 1, repeatParams);
PipeBarrier<PIPE_V>();
}
BlockReduceSum(dst_local, work_local, 1, mask, 1, 1, DEFAULT_REPEAT_STRIDE);
PipeBarrier<PIPE_V>();
}
__aicore__ inline void BlockReduceSumFP32(
const LocalTensor<float>& dst_local, const LocalTensor<float>& src_local, int32_t count)
{
// count must be a multiple of 8
int32_t repeatTimes = count / NUM_PER_REP_FP32;
int32_t tailCount = count % NUM_PER_REP_FP32;
int32_t dstAddr = repeatTimes * 8;
int32_t srcAddr = repeatTimes * NUM_PER_REP_FP32;
if (likely(repeatTimes > 0)) {
BlockReduceSum(dst_local, src_local, repeatTimes, NUM_PER_REP_FP32, 1, 1, DEFAULT_REPEAT_STRIDE);
PipeBarrier<PIPE_V>();
}
if (tailCount != 0) {
BlockReduceSum(dst_local[dstAddr], src_local[srcAddr], 1, tailCount, 1, 1, DEFAULT_REPEAT_STRIDE);
PipeBarrier<PIPE_V>();
}
}
template <typename T, typename U, typename R>
__aicore__ inline void DataCopyCustom(const U& dstTensor, const R& srcTensor, const uint32_t count)
{
#if (defined(__CCE_AICORE__) && __CCE_AICORE__ == 220) || (defined(__NPU_ARCH__) && __NPU_ARCH__ == 3003)
DataCopyParams copyParams;
copyParams.blockLen = count * sizeof(T);
copyParams.blockCount = 1;
if constexpr (is_same<U, AscendC::LocalTensor<T>>::value) {
DataCopyPadParams padParams;
DataCopyPad(dstTensor, srcTensor, copyParams, padParams);
} else {
DataCopyPad(dstTensor, srcTensor, copyParams);
}
#else
// fallback without DataCopyPad: copies are issued in whole 32-byte blocks
int32_t numPerBlock = ONE_BLK_SIZE / sizeof(T);
if (count % numPerBlock == 0) {
DataCopy(dstTensor, srcTensor, count);
} else {
if constexpr (is_same<U, AscendC::LocalTensor<T>>::value) {
int32_t num = AlignUp(count, numPerBlock);
DataCopy(dstTensor, srcTensor, num);
} else {
if (count < numPerBlock) {
DataCopy(dstTensor, srcTensor, numPerBlock);
} else {
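// Unaligned GM store: copy the 32-byte-aligned body first, then stage the last
// numPerBlock elements at the head of srcTensor and copy them to
// dstTensor[count - numPerBlock], overlapping the body copy to cover the tail.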
int32_t num = count / numPerBlock * numPerBlock;
DataCopy(dstTensor, srcTensor, num);
SetFlag<HardEvent::MTE3_S>(EVENT_ID0);
WaitFlag<HardEvent::MTE3_S>(EVENT_ID0);
for (int32_t i = 0; i < numPerBlock; i++) {
T tensorValue = srcTensor.GetValue(count - numPerBlock + i);
srcTensor.SetValue(i, tensorValue);
}
SetFlag<HardEvent::S_MTE3>(EVENT_ID0);
WaitFlag<HardEvent::S_MTE3>(EVENT_ID0);
DataCopy(dstTensor[count - numPerBlock], srcTensor, numPerBlock);
}
}
}
#endif
}
template <typename T>
__aicore__ inline void DataCopyCustom(
const LocalTensor<T>& dstTensor, const GlobalTensor<T>& srcTensor, const uint32_t numRow, const uint32_t numCol)
{
#if defined(__CCE_AICORE__) && __CCE_AICORE__ == 220
DataCopyParams copyParams;
copyParams.blockLen = numCol * sizeof(T);
copyParams.blockCount = numRow;
DataCopyPadParams padParams;
DataCopyPad(dstTensor, srcTensor, copyParams, padParams);
#endif
}
template <typename T>
__aicore__ inline void DataCopyCustom(
const GlobalTensor<T>& dstTensor, const LocalTensor<T>& srcTensor, const uint32_t numRow, const uint32_t numCol)
{
#if defined(__CCE_AICORE__) && __CCE_AICORE__ == 220
DataCopyParams copyParams;
copyParams.blockLen = numCol * sizeof(T);
copyParams.blockCount = numRow;
DataCopyPad(dstTensor, srcTensor, copyParams);
#endif
}
__aicore__ inline void RoundFloat2Int8(LocalTensor<int8_t>& dstTensor, LocalTensor<float>& srcTensor, int32_t size)
{
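// Conversion chain (in-place reinterpret casts on srcTensor):
// float -> int32 (round to nearest) -> half via the DEQ scale register (set to 1.0
// here) -> int8 (truncate).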
Cast(srcTensor.ReinterpretCast<int32_t>(), srcTensor, RoundMode::CAST_RINT, size);
PipeBarrier<PIPE_V>();
SetDeqScale((half)1.000000e+00f);
PipeBarrier<PIPE_V>();
Cast(srcTensor.ReinterpretCast<half>(), srcTensor.ReinterpretCast<int32_t>(), RoundMode::CAST_NONE, size);
PipeBarrier<PIPE_V>();
Cast(dstTensor, srcTensor.ReinterpretCast<half>(), RoundMode::CAST_TRUNC, size);
}
__aicore__ inline uint32_t ROUND_UP(uint32_t x, uint32_t block_number)
{
if (block_number > 0) {
return (x + block_number - 1) / block_number * block_number;
}
return 0;
}
} // namespace RmsNorm
#endif // RMS_NORM_BASE_H_