[Kernel] add custom op MatmulAllreduceAddRmsnorm (#4606)

What this PR does / why we need it?
Optimization of the fused operator for Qwen3 32B: Matmul, AllReduce,
Add, and RMSNorm

Does this PR introduce _any_ user-facing change?
No

How was this patch tested?

vLLM version: v0.11.2
vLLM main: https://github.com/vllm-project/vllm/commit/v0.11.2

Signed-off-by: tongrunze <t00574058@china.huawei.com>
Co-authored-by: tongrunze <t00574058@china.huawei.com>
This commit is contained in:
Trunrain
2025-12-10 09:05:33 +08:00
committed by GitHub
parent f404c9af7f
commit ba9cda9dfd
16 changed files with 2854 additions and 1 deletions

View File

@@ -0,0 +1,50 @@
/*
* Copyright (c) Huawei Technologies Co., Ltd. 2025. All rights reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "lib/matmul_intf.h"
#include <kernel_operator.h>
#include "matmul_allreduce_add_rmsnorm_aic_kernel.h"
#include "matmul_allreduce_add_rmsnorm_aiv_kernel.h"
extern "C" __global__ __aicore__ void matmul_allreduce_add_rmsnorm(
GM_ADDR x1, GM_ADDR x2, GM_ADDR residual,
GM_ADDR gamma, GM_ADDR y, GM_ADDR add_out, GM_ADDR workspace, GM_ADDR tiling)
{
REGISTER_TILING_DEFAULT(MatmulAllreduceAddRmsnormTilingData);
GET_TILING_DATA(tiling_data, tiling);
KERNEL_TASK_TYPE_DEFAULT(KERNEL_TYPE_MIX_AIC_1_2);
Hccl<HCCL_SERVER_TYPE_AICPU> hccl_;
auto tilingData = (__gm__ MatmulAllreduceAddRmsnormTilingData*)tiling;
__gm__ void* mc2InitTiling = (__gm__ void*)(&(tilingData->mc2InitTiling));
__gm__ void* mc2CcTiling = (__gm__ void*)(&(tilingData->mc2CcTiling));
auto contextGM0 = AscendC::GetHcclContext<AscendC::HCCL_GROUP_ID_0>();
if ASCEND_IS_AIC {
MatmulAllreduceAddRmsnormAicKernel<DTYPE_X1, DTYPE_Y> op;
op.Init(x1, x2, residual, gamma, y, workspace, &tiling_data, hccl_);
op.Process();
return;
}
if ASCEND_IS_AIV {
MatmulAllreduceAddRmsnormAivKernel<DTYPE_X1, DTYPE_Y> op;
op.Init(x1, x2, residual, gamma, y, add_out, workspace, &tiling_data, hccl_);
op.Process(&tiling_data);
return;
}
}

View File

@@ -0,0 +1,359 @@
/*
* Copyright (c) Huawei Technologies Co., Ltd. 2025. All rights reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef MATMUL_ALLREDUCE_ADD_RMSNORM_AIC_KERNEL_H
#define MATMUL_ALLREDUCE_ADD_RMSNORM_AIC_KERNEL_H
#define ASCENDC_CUBE_ONLY
#include "catlass/catlass.hpp"
#include "catlass/arch/arch.hpp"
#include "catlass/gemm/block/block_mmad.hpp"
#include "catlass/gemm/block/block_swizzle.hpp"
#include "catlass/gemm/dispatch_policy.hpp"
#include "catlass/gemm/kernel/basic_matmul.hpp"
#include "catlass/gemm/gemm_type.hpp"
#include "catlass/layout/layout.hpp"
#include "matmul_allreduce_add_rmsnorm_utils.h"
#include "matmul_allreduce_add_rmsnorm_tiling.h"
constexpr int32_t SCALE_L1_SIZE_A = 256 * 8;
constexpr int32_t SCALE_L1_SIZE_B = 128 * 1024;
constexpr int32_t CUBE_MATRIX_SIZE_B16 = 256; // 16 * 16
constexpr int32_t CUBE_MATRIX_SIZE_B8 = 16 * 32; // 16 * 32
constexpr int32_t SCALE_L1_SIZE = 256 * 8; // 2 KB
constexpr int32_t BLOCK_SIZE_16 = 16;
constexpr int32_t BLOCK_SIZE_32 = 32;
constexpr int32_t DOUBLE_BUFFER_SIZE = 2;
constexpr uint32_t MM_L1_TILE_SHAPE_M = 128;
constexpr uint32_t MM_L1_TILE_SHAPE_N = 256;
constexpr uint32_t MM_L1_TILE_SHAPE_K = 256;
constexpr uint32_t MM_L0_TILE_SHAPE_M = MM_L1_TILE_SHAPE_M;
constexpr uint32_t MM_L0_TILE_SHAPE_N = MM_L1_TILE_SHAPE_N;
constexpr uint32_t MM_L0_TILE_SHAPE_K = 64;
using namespace Catlass;
template <typename T_INPUT>
struct GetAccumType {
using T = float;
};
__aicore__ inline bool IsQuant(const QuantGranularity &granularity)
{
return (granularity > QuantGranularity::QUANT_GRANULARITY_UNDEFINED) &&
(granularity < QuantGranularity::QUANT_GRANULARITY_MAX);
}
template <typename MmadDtype, typename OutDtype>
class MatmulAllreduceAddRmsnormAicKernel {
using T_ACCUM = typename GetAccumType<MmadDtype>::T;
public:
int PIPE_DEPTH = 2;
Arch::Resource<Arch::AtlasA2> resource;
__aicore__ inline MatmulAllreduceAddRmsnormAicKernel<MmadDtype, OutDtype>() { }
__aicore__ inline void Init(GM_ADDR x1, GM_ADDR x2, GM_ADDR residual, GM_ADDR gamma, GM_ADDR y,
GM_ADDR workspace, const MatmulAllreduceAddRmsnormTilingData* tilingData,
Hccl<HCCL_SERVER_TYPE_AICPU> &hccl_)
{
this->hccl_ = hccl_;
this->gm_c = reinterpret_cast<__gm__ OutDtype *>(y);
this->gm_dequant_scale = nullptr;
this->has_offset = false;
auto ppTilingData = &tilingData->matmulAllreduceAddRmsnormInfo.ppTilingData;
auto commTilingData = &tilingData->matmulAllreduceAddRmsnormInfo.commTilingData;
auto quantInfo = &tilingData->matmulAllreduceAddRmsnormInfo.quantInfo;
this->batch_size = ppTilingData->opShape.batchSize;
this->m = ppTilingData->opShape.m;
this->k = ppTilingData->opShape.k;
this->n = ppTilingData->opShape.n;
this->weight_nz = false;
this->is_int8 = false;
this->cube_matrix_size = this->is_int8 ? CUBE_MATRIX_SIZE_B8 : CUBE_MATRIX_SIZE_B16;
this->m_align = Block512B<MmadDtype>::AlignUp(m);
this->k_align = Block512B<MmadDtype>::AlignUp(k);
this->n_align = Block512B<MmadDtype>::AlignUp(n);
this->m0 = ppTilingData->m0;
this->k0 = ppTilingData->k0;
this->n0 = ppTilingData->n0;
int32_t tiling_key = ppTilingData->tilingKey;
this->trans_a = ppTilingData->isTransA;
this->trans_b = ppTilingData->isTransB;
int32_t aligned_a;
int32_t aligned_b;
this->dequant_granularity = quantInfo->dequantGranularity;
AlignJudge(this->trans_a, this->trans_b, this->m, this->k, this->n,
this->m_align, this->k_align, this->n_align, aligned_a, aligned_b);
this->aligned_a = aligned_a;
this->aligned_b = aligned_b;
if (weight_nz) {
this->k_align16 = Block32B<MmadDtype>::AlignUp(k);
this->n_align16 = Block32B<MmadDtype>::AlignUp(n);
}
bool has_a_align = IsQuant(quantInfo->quantGranularity) || aligned_a;
bool has_b_align = IsQuant(this->dequant_granularity) && !this->is_int8 || aligned_b;
bool has_accum = IsQuant(this->dequant_granularity) &&
this->is_int8 && std::is_same<OutDtype, bfloat16_t>::value;
bool has_format_dequant_offset =
(this->dequant_granularity == QuantGranularity::PER_TENSOR) && this->is_int8 && this->has_offset;
auto workspace_info = GetWorkspaceInfo(workspace, this->batch_size, this->m, this->k, this->n,
this->m_align, this->k_align, this->n_align, this->trans_a, this->trans_b,
sizeof(MmadDtype), has_a_align, has_b_align, has_accum, has_format_dequant_offset);
this->gm_a_src = reinterpret_cast<__gm__ MmadDtype *>(x1);
this->gm_b_src = reinterpret_cast<__gm__ MmadDtype *>(x2);
this->gm_format_dequant_offset = reinterpret_cast<__gm__ int32_t *>(has_format_dequant_offset ?
workspace_info.gm_dequant_param : nullptr);
this->gm_workspace_src = workspace;
this->block_size = BLOCK_SIZE_32 / sizeof(MmadDtype);
int32_t a_l1_size = this->m0 * this->k0 * sizeof(MmadDtype);
int32_t a_l1_size_round = AscendC::DivCeil(a_l1_size, 512) * 512;
int32_t b_l1_size = this->n0 * this->k0 * sizeof(MmadDtype);
int32_t b_l1_size_round = AscendC::DivCeil(b_l1_size, 512) * 512;
this->l1_base_a = reinterpret_cast<__cbuf__ MmadDtype *>((uintptr_t)(this->is_int8 ? SCALE_L1_SIZE : 0));
this->l1_base_b =
reinterpret_cast<__cbuf__ MmadDtype *>(a_l1_size_round * (this->is_int8 ? DOUBLE_BUFFER_SIZE : 1) +
(uintptr_t) this->l1_base_a);
this->core_num = get_block_num();
this->core_idx = get_block_idx();
this->m_loop = ppTilingData->mLoop;
this->k_loop = ppTilingData->kLoop;
this->n_loop = ppTilingData->nLoop;
this->core_loop = ppTilingData->coreLoop;
this->swizzl_count = ppTilingData->swizzlCount;
this->swizzl_direct = ppTilingData->swizzlDirect;
this->is_91093 = commTilingData->is91093;
this->ping_flag = 1;
this->rank = hccl_.GetRankId();
this->rank_size = hccl_.GetRankDim();
this->withSerialMode = commTilingData->withSerialMode;
this->gm_peer_mem = (__gm__ OutDtype *)hccl_.GetWindowsInAddr(this->rank);
}
__aicore__ inline void MoveL0CToGM(__gm__ OutDtype *gm_dst, int64_t offset_c,
int32_t m_actual, int32_t n_actual, int32_t src_stride, int32_t dst_stride) {
if constexpr (std::is_same<OutDtype, __bf16>::value) {
copy_matrix_cc_to_gm(
gm_dst + offset_c,
l0c_buf,
0,
n_actual,
m_actual,
dst_stride,
src_stride,
0,
F322BF16,
0,
false,
true
);
} else {
copy_matrix_cc_to_gm(
gm_dst + offset_c,
l0c_buf,
0,
n_actual,
m_actual,
dst_stride,
src_stride,
0,
F322F16,
0,
false,
true
);
}
SetFlag<HardEvent::FIX_M>(EVENT_ID0);
}
__aicore__ inline void InitFlags()
{
WaitEvent(AIC_WAIT_AIV_FINISH_ALIGN_FLAG_ID);
}
__aicore__ inline void Endflags()
{
}
__aicore__ inline void Process()
{
// AIC matmul func, waits for AIV to complete [AllReduce & Add & RMSNorm].
InitFlags();
uint32_t m = this->m;
uint32_t k = this->k;
uint32_t n = this->n;
gmB.SetGlobalBuffer(gm_b_src, k * n);
using LayoutA = layout::RowMajor;
using LayoutB = layout::ColumnMajor;
using LayoutC = layout::RowMajor;
LayoutB layoutB {(layout::ColumnMajor::Index)k, (layout::ColumnMajor::Index)n};
using L1TileShape = GemmShape<MM_L1_TILE_SHAPE_M, MM_L1_TILE_SHAPE_N, MM_L1_TILE_SHAPE_K>;
using L0TileShape = GemmShape<MM_L0_TILE_SHAPE_M, MM_L0_TILE_SHAPE_N, MM_L0_TILE_SHAPE_K>;
using AType = Gemm::GemmType<MmadDtype, LayoutA>;
using BType = Gemm::GemmType<MmadDtype, LayoutB>;
using CType = AType;
constexpr bool ENABLE_UNIT_FLAG = true;
using MmadDispatchPolicy = Gemm::MmadAtlasA2Pingpong<ENABLE_UNIT_FLAG>;
using BlockMmad = Gemm::Block::BlockMmad<MmadDispatchPolicy, L1TileShape, L0TileShape, AType, BType, CType>;
GemmCoord blockShape = L1TileShape::ToCoord();
BlockMmad blockMmad(resource);
int mPerSplit = this->m0 * this->swizzl_count;
int mAvg = mPerSplit;
int splitM = AscendC::DivCeil(m, mPerSplit);
int flag_idx = 0;
icache_preload(8); // 8 corresponding to 16k
for (int splitIndex = 0; splitIndex < splitM; ++splitIndex) {
uint32_t mStart = splitIndex * mAvg;
uint32_t mActual = mAvg > (m - mStart) ? m - mStart:mAvg;
flag_idx = splitIndex % PIPE_DEPTH;
if (splitIndex >= PIPE_DEPTH) {
WaitEvent(flag_idx);
}
__gm__ MmadDtype *gm_a_src_tmp = reinterpret_cast<__gm__ MmadDtype *>(gm_a_src) + mStart * k;
__gm__ MmadDtype *gm_c_src_tmp = reinterpret_cast<__gm__ MmadDtype *>(gm_peer_mem) + mStart * n;
gmA.SetGlobalBuffer(gm_a_src_tmp, mActual*k);
gmC.SetGlobalBuffer(gm_c_src_tmp, mActual*n);
GemmCoord splitShape{mActual, n, k};
using BlockScheduler = typename Gemm::Block::GemmIdentityBlockSwizzle<3, 1>; // SwizzleOffset=3
BlockScheduler splitScheduler(splitShape, blockShape.GetCoordMN());
uint32_t coreLoops = splitScheduler.GetCoreLoops();
LayoutA layoutA{mActual, k};
LayoutC layoutC{mActual, n};
for (uint32_t loopIdx = core_idx; loopIdx < coreLoops; loopIdx += core_num) {
GemmCoord blockCoord = splitScheduler.GetBlockCoord(loopIdx);
GemmCoord actualBlockShape = splitScheduler.GetActualBlockShape(blockCoord);
GemmCoord offsetCoord = blockCoord * blockShape;
MatrixCoord offsetA = offsetCoord.GetCoordMK();
MatrixCoord offsetB = offsetCoord.GetCoordKN();
MatrixCoord offsetC = offsetCoord.GetCoordMN();
int64_t gmOffsetA = layoutA.GetOffset(offsetA);
int64_t gmOffsetB = layoutB.GetOffset(offsetB);
int64_t gmOffsetC = layoutC.GetOffset(offsetC);
blockMmad (gmA[gmOffsetA], layoutA, gmB[gmOffsetB], layoutB, gmC[gmOffsetC], layoutC, actualBlockShape);
}
FFTSCrossCoreSync<PIPE_FIX>(FFTS_SYNC_AICORE_GROUP_MODE, flag_idx);
}
Endflags();
PipeBarrier<PIPE_ALL>();
}
private:
AscendC::GlobalTensor<MmadDtype> gmA;
AscendC::GlobalTensor<MmadDtype> gmB;
AscendC::GlobalTensor<MmadDtype> gmC;
__gm__ MmadDtype *gm_a_src{nullptr};
__gm__ MmadDtype *gm_b_src{nullptr};
__gm__ OutDtype *gm_c{nullptr};
__gm__ OutDtype *gm_peer_mem{nullptr};
__gm__ int64_t *gm_dequant_scale{nullptr};
__gm__ int32_t *gm_format_dequant_offset{nullptr};
__gm__ int32_t *gm_accum{nullptr};
__gm__ uint8_t *gm_workspace_src;
__cbuf__ MmadDtype *l1_base_a = reinterpret_cast<__cbuf__ MmadDtype *>((uintptr_t) SCALE_L1_SIZE_A);
__cbuf__ MmadDtype *l1_base_b = reinterpret_cast<__cbuf__ MmadDtype *>((uintptr_t) SCALE_L1_SIZE_B);
__ca__ MmadDtype *l0a_base = reinterpret_cast<__ca__ MmadDtype *>((uintptr_t) 0);
__cb__ MmadDtype *l0b_base = reinterpret_cast<__cb__ MmadDtype *>((uintptr_t) 0);
__cc__ T_ACCUM *l0c_buf = reinterpret_cast<__cc__ T_ACCUM *>((uintptr_t) 0);
__cbuf__ int64_t *scale_l1 = reinterpret_cast<__cbuf__ int64_t *>((uintptr_t) 0);
__fbuf__ int64_t *scale_FB = (__fbuf__ int64_t *)(0);
__cbuf__ int32_t *bias_l1 = reinterpret_cast<__cbuf__ int32_t *>((uintptr_t)0);
uint16_t bias_bt = 0;
bool has_offset{false};
int32_t core_num;
int32_t batch_size;
int32_t m;
int32_t k;
int32_t n;
int32_t m_align;
int32_t k_align;
int32_t n_align;
int32_t k_align16;
int32_t n_align16;
int32_t m0;
int32_t k0;
int32_t n0;
int32_t m_loop;
int32_t n_loop;
int32_t k_loop;
int32_t core_loop;
int32_t core_idx;
int32_t ping_flag;
int32_t block_size;
int32_t cube_matrix_size;
int32_t aligned_a;
int32_t aligned_b;
int32_t swizzl_count;
int32_t swizzl_direct;
int32_t rank;
int32_t rank_size;
int32_t withSerialMode;
int32_t ag_dim;
int32_t rs_dim;
bool inner_dim_is_Ag{false};
int32_t ag_rank_idx;
int32_t rs_rank_idx;
bool weight_nz{false};
bool is_91093{false};
QuantGranularity dequant_granularity;
bool is_int8;
bool trans_a;
bool trans_b;
Hccl<HCCL_SERVER_TYPE_AICPU> hccl_;
};
#endif // MATMUL_ALLREDUCE_ADD_RMSNORM_AIC_KERNEL_H

View File

@@ -0,0 +1,702 @@
/*
* Copyright (c) Huawei Technologies Co., Ltd. 2025. All rights reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef MATMUL_ALLREDUCE_ADD_RMSNORM_AIV_KERNEL_H
#define MATMUL_ALLREDUCE_ADD_RMSNORM_AIV_KERNEL_H
#include "kernel_operator.h"
#include "matmul_allreduce_add_rmsnorm_tiling.h"
#include "matmul_allreduce_add_rmsnorm_utils.h"
using namespace AscendC;
constexpr int32_t DIFUSION_ADD_LEN = 512;
constexpr int32_t TQUE_DEPTH = 1;
constexpr uint32_t TBUF_POOL_MAX_BUFID_SIZE = 8;
enum CrossRankSyncFlagEnum {
FLAG_ZERO_IDX,
FLAG_ONE_IDX,
FLAG_TWO_IDX,
FLAG_ADD_IDX,
FLAG_FOUR_IDX,
FLAG_GATHER_ADD_OUT_STEP1,
FLAG_GATHER_ADD_OUT_STEP2,
FLAG_NUM
};
constexpr int32_t FLAG_VALUE = 1;
constexpr int32_t NUM_PER_REP_FP32 = 64;
template <typename T>
__aicore__ void CopyUbufToGmAlignB16(__gm__ T *dst, __ubuf__ T *src, uint16_t nBurst, uint32_t lenBurst,
uint16_t srcSTride, uint16_t dstStride)
{
DataCopyExtParams dataCopyParams(nBurst,
lenBurst,
srcSTride,
dstStride,
0);
LocalTensor<uint8_t> ubTensor;
TBuffAddr ubAddr;
ubAddr.logicPos = static_cast<uint8_t>(TPosition::VECIN);
ubAddr.bufferAddr = reinterpret_cast<uint64_t>(src);
ubTensor.SetAddr(ubAddr);
GlobalTensor<uint8_t> gmTensor;
gmTensor.SetGlobalBuffer(reinterpret_cast<__gm__ uint8_t *>(dst));
DataCopyPad(gmTensor, ubTensor, dataCopyParams);
}
template <typename T>
__aicore__ void CopyGmToUbufAlignB16(__ubuf__ T *dst, __gm__ T *src, uint16_t nBurst, uint32_t lenBurst,
uint16_t srcSTride, uint16_t dstStride)
{
DataCopyExtParams dataCopyParams(nBurst,
lenBurst,
srcSTride,
dstStride,
0);
LocalTensor<uint8_t> ubTensor;
TBuffAddr ubAddr;
ubAddr.logicPos = static_cast<uint8_t>(TPosition::VECIN);
ubAddr.bufferAddr = reinterpret_cast<uint64_t>(dst);
ubTensor.SetAddr(ubAddr);
GlobalTensor<uint8_t> gmTensor;
gmTensor.SetGlobalBuffer(reinterpret_cast<__gm__ uint8_t *>(src));
DataCopyPadExtParams<uint8_t> padParams;
DataCopyPad(ubTensor, gmTensor, dataCopyParams, padParams);
}
template <typename MmadDtype, typename OutDtype>
class MatmulAllreduceAddRmsnormAivKernel {
public:
__aicore__ inline MatmulAllreduceAddRmsnormAivKernel<MmadDtype, OutDtype>() { }
__aicore__ inline void Init(GM_ADDR x1, GM_ADDR x2, GM_ADDR residual, GM_ADDR gamma, GM_ADDR y, GM_ADDR add_out,
GM_ADDR workspace, const MatmulAllreduceAddRmsnormTilingData *tilingData,
Hccl<HCCL_SERVER_TYPE_AICPU> &hccl_)
{
this->hccl_ = hccl_;
is_deterministic = false;
auto ppTilingData = &tilingData->matmulAllreduceAddRmsnormInfo.ppTilingData;
auto commTilingData = &tilingData->matmulAllreduceAddRmsnormInfo.commTilingData;
auto quantInfo = &tilingData->matmulAllreduceAddRmsnormInfo.quantInfo;
gm_out = reinterpret_cast<__gm__ MmadDtype *>(y);
gm_add_input = reinterpret_cast<__gm__ MmadDtype *>(residual);
gm_add_output = reinterpret_cast<__gm__ MmadDtype *>(add_out);
gm_gamma = reinterpret_cast<__gm__ MmadDtype *>(gamma);
batch_size = ppTilingData->opShape.batchSize;
m = ppTilingData->opShape.m;
k = ppTilingData->opShape.k;
n = ppTilingData->opShape.n;
m0 = ppTilingData->m0;
k0 = ppTilingData->k0;
n0 = ppTilingData->n0;
m_loop = ppTilingData->mLoop;
k_loop = ppTilingData->kLoop;
n_loop = ppTilingData->nLoop;
core_loop = ppTilingData->coreLoop;
swizzl_count = ppTilingData->swizzlCount;
tiling_key = ppTilingData->tilingKey;
rank = hccl_.GetRankId();
rank_size = hccl_.GetRankDim();
max_ub_single_dma_size = commTilingData->ubMoveNum;
withSerialMode = false;
tag = commTilingData->tag;
comm_npu_split = commTilingData->commNpuSplit;
comm_data_split = commTilingData->commDataSplit;
comm_direct = commTilingData->commDirect;
is_91093 = false;
core_count = comm_npu_split * comm_data_split;
dequant_granularity = static_cast<QuantGranularity>(quantInfo->dequantGranularity);
dequant_group_size = quantInfo->dequantGroupSize;
quant_granularity = static_cast<QuantGranularity>(quantInfo->quantGranularity);
quant_group_size = quantInfo->quantGroupSize;
epsilon = tilingData->matmulAllreduceAddRmsnormInfo.rmsnormTilingData.epsilon;
is_gather_add_out = tilingData->matmulAllreduceAddRmsnormInfo.ppTilingData.isGatherAddOut;
swizzl_direct = (tiling_key & SWIZZL_MASK) ? true : false;
trans_a = ppTilingData->isTransA;
trans_b = ppTilingData->isTransB;
is_int8 = false;
ag_dim = 0;
rs_dim = 0;
inner_dim_is_Ag = false;
weight_nz = false;
max_ub_ping_pong_size = max_ub_single_dma_size / 2; // 2 - double buffer
core_idx = get_block_idx();
core_num = get_block_num();
aiv_idx = get_subblockid();
other_rank = (core_idx < rank_size) ? core_idx : -1;
// init ub usage
pipe.InitBuffer(ctrlBuf, AscendC::ONE_BLK_SIZE);
ub_ctrl_flag = reinterpret_cast<__ubuf__ int32_t *>(ctrlBuf.Get<int32_t>().GetPhyAddr());
pipe.InitBuffer(gammaBuf, n * sizeof(MmadDtype));
uint32_t step1_ub_usage = AscendC::AlignUp(
n * sizeof(MmadDtype) +
2 * (rank_size * DIFUSION_ADD_LEN * sizeof(MmadDtype)) +
n * sizeof(MmadDtype) +
n * sizeof(MmadDtype) +
n * sizeof(float) +
n * sizeof(float) +
n * sizeof(float),
AscendC::ONE_BLK_SIZE);
uint32_t step2_ub_usage = AscendC::AlignUp(
max_ub_ping_pong_size * sizeof(MmadDtype),
AscendC::ONE_BLK_SIZE) * 2;
uint32_t max_step_ub_usage = max(step1_ub_usage, step2_ub_usage);
pipe.InitBufPool(step1BufPool, max_step_ub_usage);
pipe.InitBufPool(step2BufPool, max_step_ub_usage, step1BufPool);
step1BufPool.InitBuffer(inQueueX, 1, n * sizeof(MmadDtype));
step1BufPool.InitBuffer(inQueueY, 2, rank_size * DIFUSION_ADD_LEN * sizeof(MmadDtype));
step1BufPool.InitBuffer(addOutQueue, 1, n * sizeof(MmadDtype));
step1BufPool.InitBuffer(outQueue, 1, n * sizeof(MmadDtype));
step1BufPool.InitBuffer(xFp32Buf, n * sizeof(float));
step1BufPool.InitBuffer(sqxBuf, n * sizeof(float));
step1BufPool.InitBuffer(reduceFp32Buf, n * sizeof(float));
step2BufPool.InitBuffer(allgatherBuf[0], max_ub_ping_pong_size * sizeof(MmadDtype));
step2BufPool.InitBuffer(allgatherBuf[1], max_ub_ping_pong_size * sizeof(MmadDtype));
CopyInGamma();
}
__aicore__ inline void Process(const MatmulAllreduceAddRmsnormTilingData *tilingData)
{
// AIV AllReduce & Add & RMSNorm func, waits for AIC to complete [Matmul].
FFTSCrossCoreSync<PIPE_MTE3>(FFTS_SYNC_AICORE_GROUP_MODE, AIC_WAIT_AIV_FINISH_ALIGN_FLAG_ID);
PipeBarrier<PIPE_ALL>();
ResetIpcFlags(FLAG_NUM);
CrossRankSyncEx(FLAG_NUM);
constexpr int32_t allreduce_used_core = 16;
int32_t one_comm_count = swizzl_count;
int32_t loop_num_per_comm = one_comm_count * n_loop;
int32_t comm_count = DivCeil(core_loop, loop_num_per_comm);
int32_t pipe_depth = is_91093 ? BLOCK_COUNT_4 : MAX_BLOCK_COUNT;
for (int cal_idx = 0; cal_idx < comm_count; ++cal_idx) {
uint64_t flag_idx = cal_idx % pipe_depth;
int32_t m_total = (cal_idx == comm_count - 1) ?
m - cal_idx * swizzl_count * m0 : swizzl_count * m0;
int32_t m_per_rank = DivCeil(m_total, rank_size);
int32_t loop_offset = cal_idx * swizzl_count * m0;
WaitEvent(flag_idx);
SetAndWaitAivSync(flag_idx, is_91093 ? BLOCK_COUNT_4 : MAX_BLOCK_COUNT);
CrossRankSyncV1(FLAG_ZERO_IDX, cal_idx + 1);
SetAndWaitAivSync(flag_idx, is_91093 ? BLOCK_COUNT_4 : MAX_BLOCK_COUNT);
if (aiv_idx == 0 && core_idx < allreduce_used_core) {
int32_t m_cur_rank = LimitRange(m_total - rank * m_per_rank, 0, m_per_rank);
int32_t m_per_core = DivCeil(m_cur_rank, allreduce_used_core);
int32_t m_cur_core = LimitRange(m_cur_rank - core_idx * m_per_core, 0, m_per_core);
int32_t core_offset_m = loop_offset + rank * m_per_rank + core_idx * m_per_core;
ParallelWithSplitStepOneAddNorm(core_offset_m * n, m_cur_core);
}
PipeBarrier<PIPE_ALL>();
SetAndWaitAivSync(flag_idx, is_91093 ? BLOCK_COUNT_4 : MAX_BLOCK_COUNT);
CrossRankSyncV1(FLAG_ADD_IDX, cal_idx + 1);
SetAndWaitAivSync(flag_idx, is_91093 ? BLOCK_COUNT_4 : MAX_BLOCK_COUNT);
{ // ParallelWithSplitStepTwo
int32_t used_core_per_rank = allreduce_used_core / rank_size;
int32_t sub_core_idx = core_idx % used_core_per_rank;
int32_t gather_rank_id = core_idx / used_core_per_rank;
int32_t m_in_rank = LimitRange(m_total - gather_rank_id * m_per_rank, 0, m_per_rank);
int32_t m_per_core = DivCeil(m_in_rank, used_core_per_rank);
int32_t m_cur_core = LimitRange(m_in_rank - sub_core_idx * m_per_core, 0, m_per_core);
int32_t core_offset_m = loop_offset + gather_rank_id * m_per_rank + sub_core_idx * m_per_core;
auto gm_share_buff = (__gm__ MmadDtype *)hccl_.GetWindowsInAddr(gather_rank_id);
bool filter_core_cond = aiv_idx == 0 && core_idx < allreduce_used_core && m_cur_core > 0;
if (filter_core_cond) {
ParallelAllGather(gm_out, gm_share_buff, core_offset_m * n, m_cur_core * n);
}
SetAndWaitAivSync(flag_idx);
CrossRankSyncV2(FLAG_TWO_IDX, cal_idx + 1);
SetAndWaitAivSync(flag_idx);
if (is_gather_add_out) {
if (filter_core_cond && gather_rank_id == rank) {
ParallelAllGather(gm_share_buff, gm_add_output, core_offset_m * n, m_cur_core * n);
}
SetAndWaitAivSync(flag_idx);
CrossRankSyncV2(FLAG_GATHER_ADD_OUT_STEP1, cal_idx + 1);
SetAndWaitAivSync(flag_idx);
if (filter_core_cond && gather_rank_id != rank) {
ParallelAllGather(gm_add_output, gm_share_buff, core_offset_m * n, m_cur_core * n);
}
SetAndWaitAivSync(flag_idx);
CrossRankSyncV2(FLAG_GATHER_ADD_OUT_STEP2, cal_idx + 1);
SetAndWaitAivSync(flag_idx);
}
}
if (cal_idx <= comm_count - pipe_depth) {
SetAicSync(flag_idx);
}
}
ResetIpcFlags(FLAG_NUM);
if (aiv_idx == 0 && core_idx < rank_size) {
__gm__ int32_t *state_buff = (__gm__ int32_t *)hccl_.GetWindowsOutAddr(other_rank);
CheckBuffFlag(ub_ctrl_flag, state_buff + FLAG_ZERO_IDX, 0);
}
}
private:
__aicore__ void SetBuffFlag(__ubuf__ int32_t *ub_ctrl_flag, __gm__ int32_t *buff, int32_t flag)
{
*ub_ctrl_flag = flag;
SetFlag<HardEvent::S_MTE3>(EVENT_ID2);
WaitFlag<HardEvent::S_MTE3>(EVENT_ID2);
CopyUbufToGmAlignB16(buff, ub_ctrl_flag, 1, sizeof(int32_t), 0, 0);
}
__aicore__ void SetBuffFlagByAdd(__ubuf__ int32_t *ub_ctrl_flag, __gm__ int32_t *buff, int32_t flag)
{
PipeBarrier<PIPE_ALL>();
*ub_ctrl_flag = flag;
PipeBarrier<PIPE_ALL>();
SetAtomicAdd<int32_t>();
PipeBarrier<PIPE_ALL>();
CopyUbufToGmAlignB16(buff, ub_ctrl_flag, 1, sizeof(int32_t), 0, 0);
PipeBarrier<PIPE_ALL>();
SetAtomicNone();
PipeBarrier<PIPE_ALL>();
}
__aicore__ void CheckBuffFlag(__ubuf__ int32_t *ub_ctrl_flag, __gm__ int32_t *buff, int32_t flag)
{
SetFlag<HardEvent::MTE3_MTE2>(EVENT_ID1);
WaitFlag<HardEvent::MTE3_MTE2>(EVENT_ID1);
while (true) {
CopyGmToUbufAlignB16(ub_ctrl_flag, buff, 1, sizeof(int32_t), 0, 0);
SetFlag<HardEvent::MTE2_S>(EVENT_ID3);
WaitFlag<HardEvent::MTE2_S>(EVENT_ID3);
if (*ub_ctrl_flag == flag) {
break;
}
}
}
__aicore__ void SetAicSync(uint64_t flag_idx)
{
FFTSCrossCoreSync<PIPE_MTE3>(FFTS_SYNC_AICORE_GROUP_MODE, flag_idx);
}
__aicore__ void ResetIpcFlags(int32_t num_flags)
{
for (int32_t idx = 0; idx <= num_flags; ++idx) {
if (core_idx == 0 && aiv_idx == 0) {
__gm__ int32_t *state_buff = (__gm__ int32_t *)hccl_.GetWindowsOutAddr(rank);
SetBuffFlag(ub_ctrl_flag, state_buff + idx, 0);
}
}
}
__aicore__ void CrossRankSyncV1(int32_t flag_idx, int32_t flag_data)
{
if (aiv_idx == 0 && core_idx == rank) {
__gm__ int32_t *state_buff = (__gm__ int32_t *)hccl_.GetWindowsOutAddr(rank);
SetBuffFlagByAdd(ub_ctrl_flag, state_buff + flag_idx, FLAG_VALUE);
} else if (aiv_idx == 0 && core_idx < rank_size) {
__gm__ int32_t *state_buff = (__gm__ int32_t *)hccl_.GetWindowsOutAddr(core_idx);
CheckBuffFlag(ub_ctrl_flag, state_buff + flag_idx, FLAG_VALUE * flag_data);
}
}
__aicore__ void CrossRankSyncV2(int32_t flag_idx, int32_t flag_data)
{
if (aiv_idx == 0 && core_idx < rank_size) {
__gm__ int32_t *state_buff = (__gm__ int32_t *)hccl_.GetWindowsOutAddr(core_idx);
SetBuffFlagByAdd(ub_ctrl_flag, state_buff + flag_idx, FLAG_VALUE);
}
if (aiv_idx == 0 && core_idx == rank) {
__gm__ int32_t *state_buff = (__gm__ int32_t *)hccl_.GetWindowsOutAddr(rank);
CheckBuffFlag(ub_ctrl_flag, state_buff + flag_idx, FLAG_VALUE * rank_size * flag_data);
}
}
__aicore__ void SetAndWaitAivSync(uint64_t flag_idx, int32_t pipe_depth = 2)
{
FFTSCrossCoreSync<PIPE_MTE3>(0, flag_idx + pipe_depth);
WaitEvent(flag_idx + pipe_depth);
}
__aicore__ inline uint32_t GetGmU32(GM_ADDR gm_addr)
{
copy_gm_to_ubuf_align_b32(ub_ctrl_flag, gm_addr, 0, 1, sizeof(uint32_t), 0, 0, 0, 0);
PipeSync<HardEvent::MTE2_S>();
return *reinterpret_cast<__ubuf__ uint32_t *>(ub_ctrl_flag);
}
__aicore__ inline void SetGmU32(GM_ADDR gm_addr, uint32_t data)
{
*reinterpret_cast<__ubuf__ uint32_t *>(ub_ctrl_flag) = data;
PipeSync<HardEvent::S_MTE3>();
copy_ubuf_to_gm_align_b32(gm_addr, ub_ctrl_flag, 0, 1, sizeof(uint32_t), 0, 0, 0, 0);
}
__aicore__ inline void CrossRankSyncEx(uint32_t flag_idx)
{
AscendC::SyncAll<true>();
__asm__ __volatile__("");
if (aiv_idx == 0 && core_idx == 0) {
auto flag_addr = (GM_ADDR)hccl_.GetWindowsOutAddr(0) + flag_idx * AscendC::ONE_BLK_SIZE;
uint32_t old_flag_data = GetGmU32(flag_addr);
__asm__ __volatile__("");
SetAtomicAdd<int32_t>();
SetGmU32(flag_addr, 1);
PipeSync<HardEvent::MTE3_S>();
SetAtomicNone();
__asm__ __volatile__("");
uint32_t new_flag_data;
do {
new_flag_data = GetGmU32(flag_addr);
__asm__ __volatile__("");
} while (new_flag_data - old_flag_data < rank_size);
__asm__ __volatile__("");
SetAtomicAdd<int32_t>();
SetGmU32(flag_addr, 1);
PipeSync<HardEvent::MTE3_S>();
SetAtomicNone();
}
__asm__ __volatile__("");
AscendC::SyncAll<true>();
}
template <typename T>
__aicore__ inline T min(const T& a, const T& b) {
return (a < b) ? a : b;
}
template <typename T>
__aicore__ inline T max(const T& a, const T& b) {
return (a > b) ? a : b;
}
template <typename T>
__aicore__ inline T LimitRange(const T& val, const T& low, const T& high) {
return min(max(val, low), high);
}
template <AscendC::HardEvent EVENT>
__aicore__ inline void PipeSync()
{
AscendC::TEventID event_id = static_cast<event_t>(GetTPipePtr()->FetchEventID(EVENT));
AscendC::SetFlag<EVENT>(event_id);
AscendC::WaitFlag<EVENT>(event_id);
}
__aicore__ inline void CopyInGamma()
{
GlobalTensor<MmadDtype> gamma_global;
gamma_global.SetGlobalBuffer((__gm__ MmadDtype *)gm_gamma, n);
DataCopy(gammaBuf.Get<MmadDtype>(), gamma_global, n);
PipeSync<HardEvent::MTE2_V>();
}
__aicore__ void ParallelWithSplitStepOneAddNorm(uint32_t core_buf_offset, uint32_t m_cur_core)
{
if (m_cur_core <= 0) {
return;
}
auto buff = (__gm__ MmadDtype *)hccl_.GetWindowsInAddr(rank);
GlobalTensor<MmadDtype> x_global;
GlobalTensor<MmadDtype> y_global;
GlobalTensor<MmadDtype> out_global;
GlobalTensor<MmadDtype> add_out_global;
x_global.SetGlobalBuffer(buff + core_buf_offset);
out_global.SetGlobalBuffer(buff + core_buf_offset);
add_out_global.SetGlobalBuffer(gm_add_output + core_buf_offset);
uint32_t add_count = DivCeil(n, DIFUSION_ADD_LEN);
LocalTensor<MmadDtype> x_local;
LocalTensor<MmadDtype> y_local;
for (uint32_t i = 0; i < m_cur_core; i++) {
LocalTensor<float> x_fp32 = xFp32Buf.Get<float>();
LocalTensor<float> sqx = sqxBuf.Get<float>();
x_local = inQueueX.AllocTensor<MmadDtype>();
for (uint32_t j = 0; j < add_count; j++) {
uint32_t add_offset = j * DIFUSION_ADD_LEN;
uint32_t add_len = min<uint32_t>(n - add_offset, DIFUSION_ADD_LEN);
DataCopy(x_local[add_offset], x_global[i * n + add_offset], add_len);
inQueueX.EnQue(x_local);
uint32_t iterate_end = (rank + 1) % rank_size;
y_local = inQueueY.AllocTensor<MmadDtype>();
for (uint32_t k = 0; k < rank_size; ++k) {
uint32_t iterate_idx = iterate_end + k;
if (iterate_idx >= rank_size) {
iterate_idx -= rank_size;
}
if (iterate_idx == rank) {
y_global.SetGlobalBuffer(gm_add_input + core_buf_offset);
} else {
auto other_buff = (__gm__ MmadDtype *)hccl_.GetWindowsInAddr(iterate_idx);
y_global.SetGlobalBuffer(other_buff + core_buf_offset);
}
DataCopy(y_local[k * add_len], y_global[i * n + add_offset], add_len);
}
inQueueY.EnQue(y_local);
x_local = inQueueX.DeQue<MmadDtype>();
y_local = inQueueY.DeQue<MmadDtype>();
Cast(x_fp32[add_offset], x_local[add_offset], RoundMode::CAST_NONE, add_len);
PipeBarrier<PIPE_V>();
for (uint32_t k = 0; k < rank_size; ++k) {
// use sqx as shared buf, required n >= add_len
Cast(sqx, y_local[k * add_len], RoundMode::CAST_NONE, add_len);
PipeBarrier<PIPE_V>();
Add(x_fp32[add_offset], x_fp32[add_offset], sqx, add_len);
PipeBarrier<PIPE_V>();
}
inQueueY.FreeTensor(y_local);
}
inQueueX.FreeTensor(x_local);
// copy add result out
LocalTensor<MmadDtype> add_out = addOutQueue.AllocTensor<MmadDtype>();
Cast(add_out, x_fp32, RoundMode::CAST_RINT, n);
addOutQueue.EnQue(add_out);
add_out = addOutQueue.DeQue<MmadDtype>();
DataCopy(add_out_global[i * n], add_out, n);
addOutQueue.FreeTensor(add_out);
LocalTensor<MmadDtype> gamma_local = gammaBuf.Get<MmadDtype>();
LocalTensor<MmadDtype> out_local = outQueue.AllocTensor<MmadDtype>();
LocalTensor<float> reduce_buf_local = reduceFp32Buf.Get<float>();
// make sure precision is same in bf16 case
Cast(out_local, x_fp32, RoundMode::CAST_RINT, n);
PipeBarrier<PIPE_V>();
Cast(x_fp32, out_local, RoundMode::CAST_NONE, n);
PipeBarrier<PIPE_V>();
Mul(sqx, x_fp32, x_fp32, n);
PipeBarrier<PIPE_V>();
Muls(sqx, sqx, (float)1.0 / n, n);
PipeBarrier<PIPE_V>();
ReduceSum(sqx, sqx, reduce_buf_local, n);
PipeBarrier<PIPE_V>();
Adds(sqx, sqx, epsilon, 1);
PipeBarrier<PIPE_V>();
Sqrt(sqx, sqx, 1);
Duplicate(reduce_buf_local, (float)1.0, 1);
PipeBarrier<PIPE_V>();
Div(sqx, reduce_buf_local, sqx, 1);
PipeBarrier<PIPE_V>();
PipeSync<HardEvent::V_S>();
float rstd_value = sqx.GetValue(0);
PipeSync<HardEvent::S_V>();
PipeBarrier<PIPE_V>();
Muls(x_fp32, x_fp32, rstd_value, n);
PipeBarrier<PIPE_V>();
if constexpr (std::is_same<MmadDtype, half>::value) {
Cast(out_local, x_fp32, RoundMode::CAST_NONE, n);
PipeBarrier<PIPE_V>();
Mul(out_local, gamma_local, out_local, n);
PipeBarrier<PIPE_V>();
} else if constexpr (std::is_same<MmadDtype, bfloat16_t>::value) {
Cast(out_local, x_fp32, RoundMode::CAST_RINT, n);
PipeBarrier<PIPE_V>();
Cast(x_fp32, out_local, RoundMode::CAST_NONE, n);
PipeBarrier<PIPE_V>();
Cast(sqx, gamma_local, RoundMode::CAST_NONE, n);
PipeBarrier<PIPE_V>();
Mul(x_fp32, x_fp32, sqx, n);
PipeBarrier<PIPE_V>();
Cast(out_local, x_fp32, RoundMode::CAST_RINT, n);
PipeBarrier<PIPE_V>();
PipeSync<HardEvent::V_MTE2>();
}
outQueue.EnQue(out_local);
out_local = outQueue.DeQue<MmadDtype>();
DataCopy(out_global[i * n], out_local, n);
outQueue.FreeTensor(out_local);
}
}
__aicore__ void ParallelAllGather(__gm__ MmadDtype *gm_dst, __gm__ MmadDtype *gm_src,
uint32_t core_buf_offset, uint32_t data_len)
{
GlobalTensor<MmadDtype> src_global;
GlobalTensor<MmadDtype> dst_global;
src_global.SetGlobalBuffer(gm_src);
dst_global.SetGlobalBuffer(gm_dst);
constexpr uint32_t PIPELINE_COPY_NUM = sizeof(allgatherBuf) / sizeof(allgatherBuf[0]);
TEventID ev_mte3_mte2[PIPELINE_COPY_NUM];
TEventID ev_mte2_mte3[PIPELINE_COPY_NUM];
LocalTensor<MmadDtype> local_tensors[PIPELINE_COPY_NUM];
for (uint32_t i = 0; i < PIPELINE_COPY_NUM; i++) {
ev_mte3_mte2[i] = GetTPipePtr()->AllocEventID<HardEvent::MTE3_MTE2>();
ev_mte2_mte3[i] = GetTPipePtr()->AllocEventID<HardEvent::MTE2_MTE3>();
SetFlag<HardEvent::MTE3_MTE2>(ev_mte3_mte2[i]);
local_tensors[i] = allgatherBuf[i].Get<MmadDtype>();
}
uint32_t offset = core_buf_offset;
uint32_t copy_len = max_ub_ping_pong_size; // num of MmadDtype, not the byte length
uint32_t copy_count = DivCeil(data_len, copy_len);
uint32_t pipe_id = 0;
for (uint32_t i = 0; i < copy_count; i++) {
uint32_t actual_copy_len =
(i == copy_count - 1) ? (data_len - i * copy_len) : copy_len;
auto &local_tensor = local_tensors[pipe_id];
WaitFlag<HardEvent::MTE3_MTE2>(ev_mte3_mte2[pipe_id]);
DataCopy(local_tensor, src_global[offset], actual_copy_len);
SetFlag<HardEvent::MTE2_MTE3>(ev_mte2_mte3[pipe_id]);
WaitFlag<HardEvent::MTE2_MTE3>(ev_mte2_mte3[pipe_id]);
DataCopy(dst_global[offset], local_tensor, actual_copy_len);
SetFlag<HardEvent::MTE3_MTE2>(ev_mte3_mte2[pipe_id]);
offset += actual_copy_len;
pipe_id = (pipe_id + 1) % PIPELINE_COPY_NUM;
}
for (uint32_t i = 0; i < PIPELINE_COPY_NUM; i++) {
WaitFlag<HardEvent::MTE3_MTE2>(ev_mte3_mte2[i]);
GetTPipePtr()->ReleaseEventID<HardEvent::MTE3_MTE2>(ev_mte3_mte2[i]);
GetTPipePtr()->ReleaseEventID<HardEvent::MTE2_MTE3>(ev_mte2_mte3[i]);
}
PipeBarrier<PIPE_ALL>();
}
__gm__ MmadDtype *gm_out;
__gm__ MmadDtype *gm_add_input;
__gm__ MmadDtype *gm_add_output;
__gm__ MmadDtype *gm_gamma;
__ubuf__ int32_t *ub_ctrl_flag;
int32_t batch_size;
int32_t m;
int32_t k;
int32_t n;
int32_t m0;
int32_t k0;
int32_t n0;
int32_t m_loop;
int32_t n_loop;
int32_t k_loop;
int32_t core_loop;
int32_t core_idx;
int32_t rank;
int32_t rank_size;
int32_t tiling_key;
int32_t swizzl_count;
bool swizzl_direct;
bool trans_a;
bool trans_b;
bool is_int8;
bool is_91093;
bool is_gather_add_out;
int32_t aiv_idx;
int32_t other_rank;
int32_t core_num;
int32_t max_ub_single_dma_size;
int32_t max_ub_ping_pong_size;
int32_t gm_c_pingpong_size;
int32_t withSerialMode;
int32_t tag;
int32_t comm_npu_split;
int32_t comm_data_split;
int32_t comm_direct;
int32_t core_count;
bool is_deterministic;
QuantGranularity dequant_granularity;
int32_t dequant_group_size;
QuantGranularity quant_granularity;
int32_t quant_group_size;
WorkspaceInfo workspace_info;
int32_t ag_dim;
int32_t rs_dim;
bool inner_dim_is_Ag;
bool weight_nz{false};
float epsilon;
TPipe pipe;
AscendC::TBufPool<TPosition::VECCALC, TBUF_POOL_MAX_BUFID_SIZE> step1BufPool;
AscendC::TBufPool<TPosition::VECCALC, TBUF_POOL_MAX_BUFID_SIZE> step2BufPool;
AscendC::TQue<AscendC::QuePosition::VECIN, TQUE_DEPTH> inQueueX, inQueueY;
AscendC::TQue<AscendC::QuePosition::VECOUT, TQUE_DEPTH> outQueueZ;
AscendC::TQue<AscendC::QuePosition::VECOUT, TQUE_DEPTH> addOutQueue;
AscendC::TQue<AscendC::QuePosition::VECOUT, TQUE_DEPTH> outQueue;
AscendC::TBuf<TPosition::VECCALC> ctrlBuf;
AscendC::TBuf<TPosition::VECCALC> gammaBuf;
AscendC::TBuf<TPosition::VECCALC> xFp32Buf;
AscendC::TBuf<TPosition::VECCALC> sqxBuf;
AscendC::TBuf<TPosition::VECCALC> reduceFp32Buf;
AscendC::TBuf<TPosition::VECCALC> allgatherBuf[2];
Hccl<HCCL_SERVER_TYPE_AICPU> hccl_;
};
#endif

View File

@@ -0,0 +1,101 @@
/*
* Copyright (c) Huawei Technologies Co., Ltd. 2025. All rights reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef MATMUL_ALLREDUCE_ADD_RMSNORM_TILING_H
#define MATMUL_ALLREDUCE_ADD_RMSNORM_TILING_H
#include <cstdint>
#include "kernel_tiling/kernel_tiling.h"
enum QuantGranularity : int {
QUANT_GRANULARITY_UNDEFINED = -1,
PER_TENSOR = 0,
PER_CHANNEL = 1,
PER_GROUP = 2,
QUANT_GRANULARITY_MAX = 3,
};
struct Opshape {
int32_t batchSize = 1;
int32_t m = -1;
int32_t k = -1;
int32_t n = -1;
};
struct PPTilingData {
Opshape opShape = {};
int32_t m0 = 1;
int32_t k0 = 1;
int32_t n0 = 1;
int32_t mLoop = 1;
int32_t kLoop = 1;
int32_t nLoop = 1;
int32_t coreLoop = 1;
int32_t swizzlCount = 1;
int32_t swizzlDirect = 0;
uint32_t tilingKey = 0;
int32_t blockDim = 1;
int32_t splitK = 0;
bool weightNz = false;
bool isTransA = false;
bool isTransB = false;
bool isGatherAddOut = false;
};
struct CommTilingData {
int32_t rank = 1;
int32_t rankSize = 1;
int32_t pValue = 1;
int32_t ubMoveNum = 1;
int32_t write2OtherRank = 0;
int32_t withSerialMode = 0;
int32_t tag = 0;
int32_t commNpuSplit = 1;
int32_t commDataSplit = 1;
int32_t commDirect = 0;
int32_t lenPerLoop = 1;
int32_t is91093 = 0;
int32_t buffer_size = 0;
};
struct RmsNormTilingData {
RmsNormTiling tiling{};
uint32_t loopCount;
uint32_t calcBytes;
float epsilon{};
};
struct QuantInfo {
QuantGranularity dequantGranularity = QuantGranularity::QUANT_GRANULARITY_UNDEFINED;
int32_t dequantGroupSize = -1;
QuantGranularity quantGranularity = QuantGranularity::QUANT_GRANULARITY_UNDEFINED;
int32_t quantGroupSize = -1;
};
struct MatmulAllreduceAddRmsnormInfo {
PPTilingData ppTilingData{};
CommTilingData commTilingData{};
RmsNormTilingData rmsnormTilingData{};
QuantInfo quantInfo{};
};
struct MatmulAllreduceAddRmsnormTilingData {
Mc2InitTiling mc2InitTiling;
Mc2CcTiling mc2CcTiling;
MatmulAllreduceAddRmsnormInfo matmulAllreduceAddRmsnormInfo;
};
#endif // MATMUL_ALLREDUCE_ADD_RMSNORM_TILING_H

View File

@@ -0,0 +1,414 @@
/*
* Copyright (c) Huawei Technologies Co., Ltd. 2025. All rights reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef MATMUL_ALLREDUCE_ADD_RMSNORM_UTILS_H
#define MATMUL_ALLREDUCE_ADD_RMSNORM_UTILS_H
#include <type_traits>
#include "kernel_operator.h"
using namespace AscendC;
constexpr int64_t ND2NZ_STRIDE_LIMIT = 65536;
constexpr int32_t AIC_WAIT_AIV_FINISH_ALIGN_FLAG_ID = 12;
constexpr int32_t MAX_BLOCK_COUNT = 2;
constexpr int32_t BLOCK_COUNT_3 = 3;
constexpr int32_t BLOCK_COUNT_4 = 4;
constexpr int32_t TILE_BLOCK_MOD = 2;
constexpr int32_t BLOCK_SIZE_32B = 32;
constexpr int32_t BLOCK_SIZE_256B = 256;
constexpr int32_t BLOCK_SIZE_512B = 512;
constexpr int32_t FFTS_SYNC_INTERNEL_MODE = 0;
constexpr int32_t FFTS_SYNC_AICORE_GROUP_MODE = 2;
constexpr int32_t SWIZZL_MASK = 0b100000;
constexpr int32_t TRANS_A_MASK = 0b010000;
constexpr int32_t TRANS_B_MASK = 0b001000;
constexpr int32_t INT8_MASK = 0b000100;
constexpr int32_t BIAS_MASK = 0b000010;
template <typename T, size_t SIZE>
struct BaseBlock {
static_assert((SIZE & (SIZE - 1)) == 0, "Invalid block size");
static constexpr size_t size = SIZE / sizeof(T);
static __aicore__ inline size_t Count(size_t len)
{
return (len + size - 1) / size;
}
static __aicore__ inline bool IsAligned(size_t len)
{
return len % size == 0;
}
static __aicore__ inline size_t AlignUp(size_t len)
{
return (len + size - 1) & ~(size - 1);
}
static __aicore__ inline size_t AlignDown(size_t len)
{
return len & ~(size - 1);
}
};
template <typename T>
using Block32B = BaseBlock<T, BLOCK_SIZE_32B>;
template <typename T>
using Block256B = BaseBlock<T, BLOCK_SIZE_256B>;
template <typename T>
using Block512B = BaseBlock<T, BLOCK_SIZE_512B>;
struct WorkspaceInfo {
__gm__ uint8_t *gm_a_align{ nullptr };
__gm__ uint8_t *gm_b_align{ nullptr };
__gm__ uint8_t *gm_accum{ nullptr };
__gm__ uint8_t *gm_dequant_param{ nullptr };
};
template <typename T>
__aicore__ inline LocalTensor<T> CreateLocalTensor(__ubuf__ T *addr)
{
LocalTensor<T> tensor;
TBuffAddr taddr;
taddr.bufferAddr = reinterpret_cast<uint64_t>(addr);
tensor.SetAddr(taddr);
return tensor;
}
template <typename T>
__aicore__ inline LocalTensor<T> CreateLocalTensor(uint32_t buffer_offset)
{
LocalTensor<T> tensor;
tensor.address_.bufferAddr = buffer_offset;
return tensor;
}
template <typename T>
__aicore__ inline LocalTensor<T> CreateLocalTensor(uint32_t buffer_offset, uint8_t logic_pos)
{
LocalTensor<T> tensor;
tensor.address_.logicPos = logic_pos;
tensor.address_.bufferAddr = buffer_offset;
return tensor;
}
template<typename T>
struct IntrinsicCopyGmToL1Nd2Nz {
static __aicore__ inline void move(
__cbuf__ T *dst, __gm__ T *src,
uint8_t sid, uint16_t ndNum, uint16_t nValue, uint16_t dValue,
uint16_t srcNdMatrixStride, uint16_t srcDValue, uint16_t dstNzC0Stride,
uint16_t dstNzNStride, uint16_t dstNzMatrixStride) {
Nd2NzParams nd2nzParams(
ndNum, nValue, dValue,
srcNdMatrixStride, srcDValue, dstNzC0Stride,
dstNzNStride, dstNzMatrixStride
);
uint32_t dst_buffer_offset = reinterpret_cast<uint64_t>(dst);
uint8_t dst_logicpos = static_cast<uint8_t>(TPosition::C1);
LocalTensor<T> dstTensor;
dstTensor = CreateLocalTensor<T>(dst_buffer_offset, dst_logicpos);
GlobalTensor<T> srcTensor;
srcTensor.SetGlobalBuffer(src);
DataCopy(dstTensor, srcTensor, nd2nzParams);
}
};
template <typename T>
struct CopyGmToL1Nd2zN {
static __aicore__ inline void move(
__cbuf__ T *dst, __gm__ T *src,
uint16_t nValue, uint16_t dValue, uint32_t srcDValue, uint16_t dstNzC0Stride) {
constexpr int BLOCK_LEN = 32 / sizeof(T);
if (srcDValue < ND2NZ_STRIDE_LIMIT) {
IntrinsicCopyGmToL1Nd2Nz<T>::move(
dst,
src,
0,
1,
nValue,
dValue,
0,
srcDValue,
dstNzC0Stride,
1,
0
);
} else {
for (int i = 0; i < nValue; i++) {
IntrinsicCopyGmToL1Nd2Nz<T>::move(
dst + i * BLOCK_LEN,
src + i * srcDValue,
0,
1,
1,
dValue,
0,
0,
dstNzC0Stride,
0,
0
);
}
}
}
};
__aicore__ inline void AlignJudge(bool trans_a, bool trans_b, int32_t m, int32_t k, int32_t n, int32_t m_align,
int32_t k_align, int32_t n_align, int32_t &aligned_a, int32_t &aligned_b)
{
if (!trans_a) {
aligned_a = k != k_align;
} else {
aligned_a = (m != m_align && m != 1);
}
if (!trans_b) {
aligned_b = (n != n_align);
} else {
aligned_b = (k != k_align);
}
}
__aicore__ inline WorkspaceInfo GetWorkspaceInfo(__gm__ uint8_t *gm_workspace, int32_t batch_size, int32_t m,
int32_t k, int32_t n, int32_t m_align, int32_t k_align, int32_t n_align, bool trans_a, bool trans_b,
int32_t mmad_dsize, bool has_a_align, bool has_b_align, bool has_accum = false, bool has_dequant_param = false)
{
WorkspaceInfo workspace_info;
uint64_t workspace_offset = 0;
if (has_a_align) {
workspace_info.gm_a_align = gm_workspace + workspace_offset;
workspace_offset += static_cast<uint64_t>(batch_size) * (trans_a ? k * m_align : m * k_align) * mmad_dsize;
}
if (has_b_align) {
workspace_info.gm_b_align = gm_workspace + workspace_offset;
workspace_offset += static_cast<uint64_t>(batch_size) * (trans_b ? n * k_align : k * n_align) * mmad_dsize;
}
if (has_accum) {
workspace_info.gm_accum = gm_workspace + workspace_offset;
workspace_offset += static_cast<uint64_t>(batch_size) * m * n * sizeof(int32_t);
}
if (has_dequant_param) {
workspace_info.gm_dequant_param = gm_workspace + workspace_offset;
workspace_offset += n * sizeof(float32_t);
}
return workspace_info;
}
template<typename T>
__aicore__ inline void CopyCubfToBt(uint64_t dst, __cbuf__ T *src, uint16_t convControl, uint16_t nBurst,
uint16_t lenBurst, uint16_t sourceGap, uint16_t dstGap)
{
DataCopyParams intriParams(nBurst, lenBurst, sourceGap, dstGap);
uint32_t src_buffer_offset = reinterpret_cast<uint64_t>(src);
uint32_t dst_buffer_offset = reinterpret_cast<uint64_t>(dst);
uint8_t src_logicpos = static_cast<uint8_t>(TPosition::C1);
uint8_t dst_logicpos = static_cast<uint8_t>(TPosition::C2);
LocalTensor<T> srcTensor;
LocalTensor<T> dstTensor;
srcTensor = CreateLocalTensor<T>(src_buffer_offset, src_logicpos);
dstTensor = CreateLocalTensor<T>(dst_buffer_offset, dst_logicpos);
DataCopy(dstTensor, srcTensor, intriParams);
}
template<typename T>
__aicore__ inline void CopyGmToCbuf(__cbuf__ T *dst, __gm__ T *src, uint8_t sid, uint16_t nBurst,
uint16_t lenBurst, uint16_t srcStride, uint16_t dstStride, pad_t padMode)
{
DataCopyParams intriParams(nBurst, lenBurst, srcStride, dstStride);
GlobalTensor<T> srcTensor;
srcTensor.SetGlobalBuffer(src);
uint32_t dst_buffer_offset = reinterpret_cast<uint64_t>(dst);
uint8_t logicpos = static_cast<uint8_t>(TPosition::C1);
LocalTensor<T> dstTensor;
dstTensor = CreateLocalTensor<T>(dst_buffer_offset, logicpos);
DataCopy(dstTensor, srcTensor, intriParams);
}
template<typename T>
__aicore__ inline void SetFpc(__fbuf__ T *src)
{
LocalTensor<T> tensor;
uint32_t src_buffer_offset = reinterpret_cast<uint64_t>(src);
tensor = CreateLocalTensor<T>(src_buffer_offset);
SetFixPipeConfig(tensor);
}
template<typename T>
__aicore__ inline void LoadCbufToCaTranspose(__ca__ T *dst, __cbuf__ T *src, uint16_t indexID, uint8_t repeat,
uint16_t srcStride, uint16_t dstStride, bool addrmode,
uint16_t dstFracStride)
{
LoadData2dTransposeParams params(
indexID,
repeat,
srcStride,
dstStride,
dstFracStride,
addrmode
);
uint32_t src_buffer_offset = reinterpret_cast<uint64_t>(src);
uint32_t dst_buffer_offset = reinterpret_cast<uint64_t>(dst);
uint8_t src_logicpos = static_cast<uint8_t>(TPosition::C1);
uint8_t dst_logicpos = static_cast<uint8_t>(TPosition::A2);
LocalTensor<T> srcTensor;
LocalTensor<T> dstTensor;
srcTensor = CreateLocalTensor<T>(src_buffer_offset, src_logicpos);
dstTensor = CreateLocalTensor<T>(dst_buffer_offset, dst_logicpos);
LoadDataWithTranspose(dstTensor, srcTensor, params);
}
template<typename T>
__aicore__ inline void LoadCbufToCbTranspose(__cb__ T *dst, __cbuf__ T *src, uint16_t indexID, uint8_t repeat,
uint16_t srcStride, uint16_t dstStride, bool addrmode,
uint16_t dstFracStride)
{
LoadData2dTransposeParams params(
indexID,
repeat,
srcStride,
dstStride,
dstFracStride,
addrmode
);
uint32_t src_buffer_offset = reinterpret_cast<uint64_t>(src);
uint32_t dst_buffer_offset = reinterpret_cast<uint64_t>(dst);
uint8_t src_logicpos = static_cast<uint8_t>(TPosition::C1);
uint8_t dst_logicpos = static_cast<uint8_t>(TPosition::B2);
LocalTensor<T> srcTensor;
LocalTensor<T> dstTensor;
srcTensor = CreateLocalTensor<T>(src_buffer_offset, src_logicpos);
dstTensor = CreateLocalTensor<T>(dst_buffer_offset, dst_logicpos);
LoadDataWithTranspose(dstTensor, srcTensor, params);
}
template <typename T>
__aicore__ inline void LoadCbufToCa(__ca__ T *dst, __cbuf__ T *src, uint16_t baseIdx, uint8_t repeat,
uint16_t srcStride, uint16_t dstStride, uint8_t sid, bool transpose,
uint8_t addr_cal_mode)
{
LoadData2dParams params(
baseIdx,
repeat,
srcStride,
sid,
dstStride,
transpose,
addr_cal_mode
);
uint32_t src_buffer_offset = reinterpret_cast<uint64_t>(src);
uint32_t dst_buffer_offset = reinterpret_cast<uint64_t>(dst);
uint8_t src_logicpos = static_cast<uint8_t>(TPosition::C1);
uint8_t dst_logicpos = static_cast<uint8_t>(TPosition::A2);
LocalTensor<T> srcTensor;
LocalTensor<T> dstTensor;
srcTensor = CreateLocalTensor<T>(src_buffer_offset, src_logicpos);
dstTensor = CreateLocalTensor<T>(dst_buffer_offset, dst_logicpos);
LoadData(dstTensor, srcTensor, params);
}
template<typename T>
__aicore__ inline void LoadCbufToCb(__cb__ T *dst, __cbuf__ T *src, uint16_t baseIdx, uint8_t repeat,
uint16_t srcStride, uint16_t dstStride, uint8_t sid, bool transpose,
uint8_t addr_cal_mode)
{
LoadData2dParams params(
baseIdx,
repeat,
srcStride,
sid,
dstStride,
transpose,
addr_cal_mode
);
uint32_t src_buffer_offset = reinterpret_cast<uint64_t>(src);
uint32_t dst_buffer_offset = reinterpret_cast<uint64_t>(dst);
uint8_t src_logicpos = static_cast<uint8_t>(TPosition::C1);
uint8_t dst_logicpos = static_cast<uint8_t>(TPosition::B2);
LocalTensor<T> srcTensor;
LocalTensor<T> dstTensor;
srcTensor = CreateLocalTensor<T>(src_buffer_offset, src_logicpos);
dstTensor = CreateLocalTensor<T>(dst_buffer_offset, dst_logicpos);
LoadData(dstTensor, srcTensor, params);
}
__aicore__ inline void GetBlockIdx(int32_t loop_idx, int32_t m_loop, int32_t n_loop, int32_t swizzl_direction,
int32_t swizzl_count, int64_t &m_idx, int64_t &n_idx)
{
uint32_t in_batch_idx = loop_idx % (m_loop * n_loop);
if (swizzl_direction == 0) {
uint32_t tile_block_loop = (m_loop + swizzl_count - 1) / swizzl_count;
uint32_t tile_block_idx = in_batch_idx / (swizzl_count * n_loop);
uint32_t in_tile_block_idx = in_batch_idx % (swizzl_count * n_loop);
uint32_t n_row = swizzl_count;
if (tile_block_idx == tile_block_loop - 1) {
n_row = m_loop - swizzl_count * tile_block_idx;
}
m_idx = tile_block_idx * swizzl_count + in_tile_block_idx % n_row;
n_idx = in_tile_block_idx / n_row;
if (tile_block_idx % TILE_BLOCK_MOD != 0) {
n_idx = n_loop - n_idx - 1;
}
} else if (swizzl_direction == 1) {
uint32_t tile_block_loop = (n_loop + swizzl_count - 1) / swizzl_count;
uint32_t tile_block_idx = in_batch_idx / (swizzl_count * m_loop);
uint32_t in_tile_block_idx = in_batch_idx % (swizzl_count * m_loop);
uint32_t n_col = swizzl_count;
if (tile_block_idx == tile_block_loop - 1) {
n_col = n_loop - swizzl_count * tile_block_idx;
}
m_idx = in_tile_block_idx / n_col;
n_idx = tile_block_idx * swizzl_count + in_tile_block_idx % n_col;
if (tile_block_idx % TILE_BLOCK_MOD != 0) {
m_idx = m_loop - m_idx - 1;
}
}
}
template <pipe_t pipe>
__aicore__ inline void FFTSCrossCoreSync(uint64_t mode, uint64_t flag_id)
{
uint64_t config = 1 | (mode << 4) | (flag_id << 8);
ffts_cross_core_sync(pipe, config);
}
template <typename T>
__aicore__ GlobalTensor<T> CreateGlobalTensor(__gm__ T *addr)
{
GlobalTensor<T> tensor;
tensor.SetGlobalBuffer(addr);
return tensor;
}
#endif // MATMUL_ALLREDUCE_ADD_RMSNORM_H