/** * This program is free software, you can redistribute it and/or modify it. * Copyright (c) 2025 Huawei Technologies Co., Ltd. * This file is a part of the CANN Open Software. * Licensed under CANN Open Software License Agreement Version 2.0 (the "License"). * Please refer to the License for details. You may not use this file except in compliance with the License. * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. * See LICENSE in the root of the software repository for the full text of the License. */ /*! * \file lightning_indexer_service_cube.h * \brief use 5 buffer for matmul l1, better pipeline */ #ifndef LIGHTNING_INDEXER_SERVICE_CUBE_H #define LIGHTNING_INDEXER_SERVICE_CUBE_H #include "kernel_operator.h" #include "kernel_operator_list_tensor_intf.h" #include "kernel_tiling/kernel_tiling.h" #include "lib/matmul_intf.h" #include "lib/matrix/matmul/tiling.h" #include "lightning_indexer_common.h" namespace LIKernel { using namespace LICommon; template class LIMatmul { public: using Q_T = typename LIT::queryType; using K_T = typename LIT::keyType; __aicore__ inline LIMatmul(){}; __aicore__ inline void InitBuffers(TPipe *pipe); __aicore__ inline void InitMm1GlobalTensor(const GlobalTensor &blkTableGm, const GlobalTensor &keyGm, const GlobalTensor &queryGm, const GlobalTensor &mm1ResGm); __aicore__ inline void InitParams(const ConstInfo &constInfo); __aicore__ inline void AllocEventID(); __aicore__ inline void FreeEventID(); __aicore__ inline void ComputeMm1(const LICommon::RunInfo &runInfo); static constexpr IsResetLoad3dConfig LOAD3DV2_CONFIG = {true, true}; // isSetFMatrix isSetPadding; static constexpr uint64_t KEY_BUF_NUM = 3; static constexpr uint64_t QUERY_BUF_NUM = 2; static constexpr uint64_t L0_BUF_NUM = 2; static constexpr uint32_t KEY_MTE1_MTE2_EVENT = EVENT_ID2; static constexpr uint32_t QUERY_MTE1_MTE2_EVENT = EVENT_ID5; // KEY_MTE1_MTE2_EVENT + KEY_BUF_NUM; static constexpr uint32_t M_MTE1_EVENT = EVENT_ID3; static constexpr uint32_t MTE2_MTE1_EVENT = EVENT_ID2; static constexpr uint32_t MTE1_M_EVENT = EVENT_ID2; static constexpr uint64_t M_BASIC_BLOCK = 256; static constexpr uint64_t D_BASIC_BLOCK = 128; static constexpr uint64_t S2_BASIC_BLOCK = 256; static constexpr uint64_t M_BASIC_BLOCK_L0 = 128; static constexpr uint64_t D_BASIC_BLOCK_L0 = 128; static constexpr uint64_t S2_BASIC_BLOCK_L0 = 128; static constexpr uint64_t QUERY_BUFFER_OFFSET = M_BASIC_BLOCK * D_BASIC_BLOCK; static constexpr uint64_t KEY_BUFFER_OFFSET = S2_BASIC_BLOCK * D_BASIC_BLOCK; static constexpr uint64_t L0AB_BUFFER_OFFSET = M_BASIC_BLOCK_L0 * D_BASIC_BLOCK_L0; static constexpr uint64_t L0C_BUFFER_OFFSET = M_BASIC_BLOCK_L0 * S2_BASIC_BLOCK_L0; protected: __aicore__ inline void Fixp(uint64_t s1gGmOffset, uint64_t s2GmOffset, uint64_t s1gL0RealSize, uint64_t s2L0RealSize, const LICommon::RunInfo &runInfo); __aicore__ inline void ComuteL0c(uint64_t s1gL0RealSize, uint64_t s2L0RealSize, const LICommon::RunInfo &runInfo); __aicore__ inline void LoadKeyToL0b(uint64_t s2L0Offset, uint64_t s2L1RealSize, uint64_t s2L0RealSize, const LICommon::RunInfo &runInfo); __aicore__ inline void LoadQueryToL0a(uint64_t s1gL1Offset, uint64_t s1gL0Offset, uint64_t s1gL1RealSize, uint64_t s1gL0RealSize, const LICommon::RunInfo &runInfo); __aicore__ inline void QueryNd2Nz(uint64_t s1gL1RealSize, uint64_t s1gL1Offset, const LICommon::RunInfo &runInfo); __aicore__ inline void KeyNd2Nz(uint64_t s2L1RealSize, uint64_t s2GmOffset, const LICommon::RunInfo &runInfo); __aicore__ inline void KeyNd2NzForPA(uint64_t s2L1RealSize, uint64_t s2GmOffset, const LICommon::RunInfo &runInfo); GlobalTensor blkTableGm_; GlobalTensor keyGm_; GlobalTensor queryGm_; GlobalTensor mm1ResGm_; TBuf bufQL1_; LocalTensor queryL1_; TBuf bufKeyL1_; LocalTensor keyL1_; TBuf bufQL0_; LocalTensor queryL0_; TBuf bufKeyL0_; LocalTensor keyL0_; TBuf bufL0C_; LocalTensor cL0_; uint64_t keyL1BufIdx_ = 0; uint64_t queryL1Mte2BufIdx_ = 0; uint64_t queryL1Mte1BufIdx_ = 0; uint64_t l0BufIdx_ = 0; ConstInfo constInfo_; private: static constexpr bool PAGE_ATTENTION = LIT::pageAttention; }; template __aicore__ inline void LIMatmul::InitParams(const ConstInfo &constInfo) { constInfo_ = constInfo; } template __aicore__ inline void LIMatmul::InitBuffers(TPipe *pipe) { pipe->InitBuffer(bufQL1_, QUERY_BUF_NUM * M_BASIC_BLOCK * D_BASIC_BLOCK * sizeof(Q_T)); queryL1_ = bufQL1_.Get(); pipe->InitBuffer(bufKeyL1_, KEY_BUF_NUM * S2_BASIC_BLOCK * D_BASIC_BLOCK * sizeof(K_T)); keyL1_ = bufKeyL1_.Get(); pipe->InitBuffer(bufQL0_, L0_BUF_NUM * M_BASIC_BLOCK_L0 * D_BASIC_BLOCK_L0 * sizeof(Q_T)); queryL0_ = bufQL0_.Get(); pipe->InitBuffer(bufKeyL0_, L0_BUF_NUM * D_BASIC_BLOCK_L0 * S2_BASIC_BLOCK_L0 * sizeof(K_T)); keyL0_ = bufKeyL0_.Get(); pipe->InitBuffer(bufL0C_, L0_BUF_NUM * M_BASIC_BLOCK_L0 * S2_BASIC_BLOCK_L0 * sizeof(float)); cL0_ = bufL0C_.Get(); } template __aicore__ inline void LIMatmul::InitMm1GlobalTensor(const GlobalTensor &blkTableGm, const GlobalTensor &keyGm, const GlobalTensor &queryGm, const GlobalTensor &mm1ResGm) { blkTableGm_ = blkTableGm; keyGm_ = keyGm; queryGm_ = queryGm; mm1ResGm_ = mm1ResGm; } template __aicore__ inline void LIMatmul::ComputeMm1(const LICommon::RunInfo &runInfo) { uint64_t s2GmBaseOffset = runInfo.s2Idx * constInfo_.s2BaseSize; uint64_t s1gProcessSize = runInfo.actMBaseSize; uint64_t s2ProcessSize = runInfo.actualSingleProcessSInnerSize; for (uint64_t s2GmOffset = 0; s2GmOffset < s2ProcessSize; s2GmOffset += S2_BASIC_BLOCK) { WaitFlag(KEY_MTE1_MTE2_EVENT + keyL1BufIdx_ % KEY_BUF_NUM); uint64_t s2L1RealSize = s2GmOffset + S2_BASIC_BLOCK > s2ProcessSize ? s2ProcessSize - s2GmOffset : S2_BASIC_BLOCK; if (PAGE_ATTENTION) { KeyNd2NzForPA(s2L1RealSize, s2GmBaseOffset + s2GmOffset, runInfo); }else { KeyNd2Nz(s2L1RealSize, s2GmOffset, runInfo); } SetFlag(MTE2_MTE1_EVENT); WaitFlag(MTE2_MTE1_EVENT); for (uint64_t s1gGmOffset = 0; s1gGmOffset < s1gProcessSize; s1gGmOffset += M_BASIC_BLOCK) { uint64_t s1gL1RealSize = s1gGmOffset + M_BASIC_BLOCK > s1gProcessSize ? s1gProcessSize - s1gGmOffset : M_BASIC_BLOCK; if (runInfo.isFirstS2InnerLoop && s2GmOffset == 0) { queryL1Mte2BufIdx_++; queryL1Mte1BufIdx_ = queryL1Mte2BufIdx_; WaitFlag(QUERY_MTE1_MTE2_EVENT + queryL1Mte2BufIdx_ % QUERY_BUF_NUM); QueryNd2Nz(s1gL1RealSize, s1gGmOffset, runInfo); SetFlag(MTE2_MTE1_EVENT); WaitFlag(MTE2_MTE1_EVENT); } else { queryL1Mte1BufIdx_ = queryL1Mte2BufIdx_ - (CeilDiv(s1gProcessSize, M_BASIC_BLOCK) - 1 - (s1gGmOffset > 0)); } for (uint64_t s2L1Offset = 0; s2L1Offset < s2L1RealSize; s2L1Offset += S2_BASIC_BLOCK_L0) { uint64_t s2L0RealSize = s2L1Offset + S2_BASIC_BLOCK_L0 > s2L1RealSize ? s2L1RealSize - s2L1Offset : S2_BASIC_BLOCK_L0; for (uint64_t s1gL1Offset = 0; s1gL1Offset < s1gL1RealSize; s1gL1Offset += M_BASIC_BLOCK_L0) { WaitFlag(M_MTE1_EVENT + l0BufIdx_ % L0_BUF_NUM); uint64_t s1gL0RealSize = s1gL1Offset + M_BASIC_BLOCK_L0 > s1gL1RealSize ? s1gL1RealSize - s1gL1Offset : M_BASIC_BLOCK_L0; LoadQueryToL0a(s1gGmOffset, s1gL1Offset, s1gL1RealSize, s1gL0RealSize, runInfo); LoadKeyToL0b(s2L1Offset, s2L1RealSize, s2L0RealSize, runInfo); SetFlag(MTE1_M_EVENT); WaitFlag(MTE1_M_EVENT); ComuteL0c(s1gL0RealSize, s2L0RealSize, runInfo); SetFlag(M_MTE1_EVENT + l0BufIdx_ % L0_BUF_NUM); Fixp(s1gGmOffset + s1gL1Offset, s2GmOffset + s2L1Offset, s1gL0RealSize, s2L0RealSize, runInfo); l0BufIdx_++; } } if (s2GmOffset + S2_BASIC_BLOCK >= s2ProcessSize && runInfo.isLastS2InnerLoop) { SetFlag(QUERY_MTE1_MTE2_EVENT + queryL1Mte1BufIdx_ % QUERY_BUF_NUM); } } SetFlag(KEY_MTE1_MTE2_EVENT + keyL1BufIdx_ % KEY_BUF_NUM); keyL1BufIdx_++; } } template __aicore__ inline void LIMatmul::KeyNd2Nz(uint64_t s2L1RealSize, uint64_t s2GmOffset, const LICommon::RunInfo &runInfo) { uint64_t s2L1Offset = 0; while (s2L1Offset < s2L1RealSize) { uint64_t keyGmOffset = runInfo.tensorKeyOffset + (s2GmOffset + s2L1Offset) * constInfo_.headDim; uint64_t s2Mte2Size = (s2L1RealSize <= S2_BASIC_BLOCK_L0 || s2L1Offset >= S2_BASIC_BLOCK_L0) ? s2L1RealSize - s2L1Offset : S2_BASIC_BLOCK_L0 - s2L1Offset; Nd2NzParams nd2nzPara; nd2nzPara.ndNum = 1; nd2nzPara.nValue = s2Mte2Size; // 行数 nd2nzPara.dValue = constInfo_.headDim; nd2nzPara.srcDValue = constInfo_.headDim; nd2nzPara.dstNzC0Stride = s2L1Offset >= S2_BASIC_BLOCK_L0 ? CeilAlign(s2L1RealSize - S2_BASIC_BLOCK_L0, (uint64_t)BLOCK_CUBE) : (s2L1RealSize > S2_BASIC_BLOCK_L0 ? S2_BASIC_BLOCK_L0 : CeilAlign(s2L1RealSize, (uint64_t)BLOCK_CUBE)); nd2nzPara.dstNzNStride = 1; nd2nzPara.srcNdMatrixStride = 0; nd2nzPara.dstNzMatrixStride = 0; DataCopy(keyL1_[(keyL1BufIdx_ % KEY_BUF_NUM) * KEY_BUFFER_OFFSET + (s2L1Offset >= S2_BASIC_BLOCK_L0 ? S2_BASIC_BLOCK_L0 * D_BASIC_BLOCK_L0 + (s2L1Offset - S2_BASIC_BLOCK_L0) * BLOCK_CUBE : s2L1Offset * BLOCK_CUBE)], keyGm_[keyGmOffset], nd2nzPara); s2L1Offset += s2Mte2Size; } } // blkNum, blkSize, N2, D template __aicore__ inline void LIMatmul::KeyNd2NzForPA(uint64_t s2L1RealSize, uint64_t s2GmOffset, const LICommon::RunInfo &runInfo) { uint64_t s2L1Offset = 0; while (s2L1Offset < s2L1RealSize) { uint64_t s2BlkId = (s2L1Offset + s2GmOffset) / constInfo_.kCacheBlockSize; uint64_t s2BlkOffset = (s2L1Offset + s2GmOffset) % constInfo_.kCacheBlockSize; uint64_t keyGmOffset = blkTableGm_.GetValue(runInfo.bIdx * constInfo_.maxBlockNumPerBatch + s2BlkId) * constInfo_.kCacheBlockSize * constInfo_.kHeadNum * constInfo_.headDim + s2BlkOffset * constInfo_.headDim; uint64_t s2Mte2Size = (s2L1RealSize <= S2_BASIC_BLOCK_L0 || s2L1Offset >= S2_BASIC_BLOCK_L0) ? s2L1RealSize - s2L1Offset : S2_BASIC_BLOCK_L0 - s2L1Offset; s2Mte2Size = s2BlkOffset + s2Mte2Size >= constInfo_.kCacheBlockSize ? constInfo_.kCacheBlockSize - s2BlkOffset : s2Mte2Size; Nd2NzParams nd2nzPara; nd2nzPara.ndNum = 1; nd2nzPara.nValue = s2Mte2Size; nd2nzPara.dValue = constInfo_.headDim; nd2nzPara.srcDValue = constInfo_.headDim; nd2nzPara.dstNzC0Stride = s2L1Offset >= S2_BASIC_BLOCK_L0 ? CeilAlign(s2L1RealSize - S2_BASIC_BLOCK_L0, (uint64_t)BLOCK_CUBE) : (s2L1RealSize > S2_BASIC_BLOCK_L0 ? S2_BASIC_BLOCK_L0 : CeilAlign(s2L1RealSize, (uint64_t)BLOCK_CUBE)); nd2nzPara.dstNzNStride = 1; nd2nzPara.srcNdMatrixStride = 0; nd2nzPara.dstNzMatrixStride = 0; DataCopy(keyL1_[(keyL1BufIdx_ % KEY_BUF_NUM) * KEY_BUFFER_OFFSET + (s2L1Offset >= S2_BASIC_BLOCK_L0 ? S2_BASIC_BLOCK_L0 * D_BASIC_BLOCK_L0 + (s2L1Offset - S2_BASIC_BLOCK_L0) * BLOCK_CUBE : s2L1Offset * BLOCK_CUBE)], keyGm_[keyGmOffset], nd2nzPara); s2L1Offset += s2Mte2Size; } } // batch, s1, n2, g, d template __aicore__ inline void LIMatmul::QueryNd2Nz(uint64_t s1gL1RealSize, uint64_t s1gGmOffset, const LICommon::RunInfo &runInfo) { Nd2NzParams nd2nzPara; nd2nzPara.ndNum = 1; nd2nzPara.nValue = s1gL1RealSize; nd2nzPara.dValue = constInfo_.headDim; nd2nzPara.srcDValue = constInfo_.headDim; nd2nzPara.dstNzC0Stride = CeilAlign(s1gL1RealSize, (uint64_t)BLOCK_CUBE); nd2nzPara.dstNzNStride = 1; nd2nzPara.srcNdMatrixStride = 0; nd2nzPara.dstNzMatrixStride = 0; DataCopy(queryL1_[(queryL1Mte2BufIdx_ % QUERY_BUF_NUM) * QUERY_BUFFER_OFFSET], queryGm_[runInfo.tensorQueryOffset + s1gGmOffset * constInfo_.headDim], nd2nzPara); } template __aicore__ inline void LIMatmul::LoadQueryToL0a(uint64_t s1gGmOffset, uint64_t s1gL1Offset, uint64_t s1gL1RealSize, uint64_t s1gL0RealSize, const LICommon::RunInfo &runInfo) { LoadData3DParamsV2 loadData3DParams; // SetFmatrixParams loadData3DParams.l1H = CeilDiv(s1gL1RealSize, BLOCK_CUBE); // Hin=M1=8 loadData3DParams.l1W = BLOCK_CUBE; // Win=M0 loadData3DParams.channelSize = constInfo_.headDim; // Cin=K loadData3DParams.padList[0] = 0; loadData3DParams.padList[1] = 0; loadData3DParams.padList[2] = 0; loadData3DParams.padList[3] = 255; // SetLoadToA0Params loadData3DParams.mExtension = CeilAlign(s1gL0RealSize, BLOCK_CUBE); loadData3DParams.kExtension = constInfo_.headDim; loadData3DParams.mStartPt = s1gL1Offset; loadData3DParams.kStartPt = 0; loadData3DParams.strideW = 1; loadData3DParams.strideH = 1; loadData3DParams.filterW = 1; loadData3DParams.filterSizeW = (1 >> 8) & 255; loadData3DParams.filterH = 1; loadData3DParams.filterSizeH = (1 >> 8) & 255; loadData3DParams.dilationFilterW = 1; loadData3DParams.dilationFilterH = 1; loadData3DParams.enTranspose = 0; loadData3DParams.fMatrixCtrl = 0; LoadData(queryL0_[(l0BufIdx_ % L0_BUF_NUM) * L0AB_BUFFER_OFFSET], queryL1_[(queryL1Mte1BufIdx_ % QUERY_BUF_NUM) * QUERY_BUFFER_OFFSET], loadData3DParams); } template __aicore__ inline void LIMatmul::LoadKeyToL0b(uint64_t s2L1Offset, uint64_t s2L1RealSize, uint64_t s2L0RealSize, const LICommon::RunInfo &runInfo) { uint64_t keyL1Offset = s2L1Offset >= S2_BASIC_BLOCK_L0 ? S2_BASIC_BLOCK_L0 * D_BASIC_BLOCK_L0 : 0; LoadData2DParams loadData2DParams; loadData2DParams.startIndex = 0; loadData2DParams.repeatTimes = CeilDiv(s2L0RealSize, BLOCK_CUBE) * CeilDiv(constInfo_.headDim, BLOCK_CUBE); loadData2DParams.srcStride = 1; loadData2DParams.dstGap = 0; loadData2DParams.ifTranspose = false; LoadData(keyL0_[(l0BufIdx_ % L0_BUF_NUM) * L0AB_BUFFER_OFFSET], keyL1_[(keyL1BufIdx_ % KEY_BUF_NUM) * KEY_BUFFER_OFFSET + keyL1Offset], loadData2DParams); } template __aicore__ inline void LIMatmul::ComuteL0c(uint64_t s1gL0RealSize, uint64_t s2L0RealSize, const LICommon::RunInfo &runInfo) { MmadParams mmadParams; mmadParams.m = CeilAlign(s1gL0RealSize, BLOCK_CUBE); mmadParams.n = s2L0RealSize; mmadParams.k = constInfo_.headDim; mmadParams.cmatrixInitVal = true; mmadParams.cmatrixSource = false; mmadParams.unitFlag = 0b11; Mmad(cL0_[(l0BufIdx_ % L0_BUF_NUM) * L0C_BUFFER_OFFSET], queryL0_[(l0BufIdx_ % L0_BUF_NUM) * L0AB_BUFFER_OFFSET], keyL0_[(l0BufIdx_ % L0_BUF_NUM) * L0AB_BUFFER_OFFSET], mmadParams); if ((mmadParams.m / 16) * (mmadParams.n / 16) < 10) { PipeBarrier(); } } template __aicore__ inline void LIMatmul::Fixp(uint64_t s1gGmOffset, uint64_t s2GmOffset, uint64_t s1gL0RealSize, uint64_t s2L0RealSize, const LICommon::RunInfo &runInfo) { AscendC::DataCopyCO12DstParams intriParams; intriParams.mSize = CeilAlign(s1gL0RealSize, BLOCK_CUBE); intriParams.nSize = s2L0RealSize; intriParams.dstStride = runInfo.actualSingleProcessSInnerSizeAlign; intriParams.srcStride = CeilAlign(s1gL0RealSize, BLOCK_CUBE); // set mode according to dtype intriParams.quantPre = QuantMode_t::NoQuant; intriParams.nz2ndEn = true; intriParams.unitFlag = 0b11; // 3 unitflag intriParams.reluPre = 1; AscendC::SetFixpipeNz2ndFlag(1, 1, 1); AscendC::DataCopy(mm1ResGm_[(runInfo.loop % 2) * constInfo_.mBaseSize * constInfo_.s2BaseSize + s1gGmOffset * intriParams.dstStride + s2GmOffset], cL0_[(l0BufIdx_ % L0_BUF_NUM) * L0C_BUFFER_OFFSET], intriParams); } template __aicore__ inline void LIMatmul::AllocEventID() { SetMMLayoutTransform(true); SetFlag(KEY_MTE1_MTE2_EVENT + 0); SetFlag(KEY_MTE1_MTE2_EVENT + 1); SetFlag(KEY_MTE1_MTE2_EVENT + 2); SetFlag(QUERY_MTE1_MTE2_EVENT + 0); SetFlag(QUERY_MTE1_MTE2_EVENT + 1); SetFlag(M_MTE1_EVENT + 0); SetFlag(M_MTE1_EVENT + 1); } template __aicore__ inline void LIMatmul::FreeEventID() { SetMMLayoutTransform(false); WaitFlag(KEY_MTE1_MTE2_EVENT + 0); WaitFlag(KEY_MTE1_MTE2_EVENT + 1); WaitFlag(KEY_MTE1_MTE2_EVENT + 2); WaitFlag(QUERY_MTE1_MTE2_EVENT + 0); WaitFlag(QUERY_MTE1_MTE2_EVENT + 1); WaitFlag(M_MTE1_EVENT + 0); WaitFlag(M_MTE1_EVENT + 1); } } // namespace LIKernel #endif