[Model] GLM5 adaptation (#6642)

### What this PR does / why we need it?
GLM5 adaptation
1. use torch_npu.npu_lightning_indexer for GLM5
2. forbid eagle proposer when fullgraph mode is enabled because of bugs
3. add quatization config for GLM5
### Does this PR introduce _any_ user-facing change?
N/A
### How was this patch tested?
by ci
- vLLM main:
978a37c823

---------

Signed-off-by: yydyzr <liuyuncong1@huawei.com>
Signed-off-by: shenchuxiaofugui <1311027364@qq.com>
Co-authored-by: shenchuxiaofugui <1311027364@qq.com>
This commit is contained in:
yydyzr
2026-02-11 22:22:22 +08:00
committed by GitHub
parent 140fcaffc3
commit ff3a50d011
17 changed files with 77 additions and 34 deletions

View File

@@ -0,0 +1,135 @@
/**
* This program is free software, you can redistribute it and/or modify it.
* Copyright (c) 2025 Huawei Technologies Co., Ltd.
* This file is a part of the CANN Open Software.
* Licensed under CANN Open Software License Agreement Version 2.0 (the "License").
* Please refer to the License for details. You may not use this file except in compliance with the License.
* THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE.
* See LICENSE in the root of the software repository for the full text of the License.
*/
/*!
* \file lightning_indexer_common.h
* \brief
*/
#ifndef LIGHTNING_INDEXER_COMMON_H
#define LIGHTNING_INDEXER_COMMON_H
namespace LICommon {
enum class LI_LAYOUT {
BSND = 0,
TND = 1,
PA_BSND = 2
};
template <typename Q_T, typename K_T, typename OUT_T, const bool PAGE_ATTENTION = false,
LI_LAYOUT LAYOUT_T = LI_LAYOUT::BSND, LI_LAYOUT K_LAYOUT_T = LI_LAYOUT::PA_BSND, typename... Args>
struct LIType {
using queryType = Q_T;
using keyType = K_T;
using outputType = OUT_T;
static constexpr bool pageAttention = PAGE_ATTENTION;
static constexpr LI_LAYOUT layout = LAYOUT_T;
static constexpr LI_LAYOUT keyLayout = K_LAYOUT_T;
};
struct RunInfo {
uint32_t loop;
uint32_t bN2Idx;
uint32_t bIdx;
uint32_t n2Idx = 0;
uint32_t gS1Idx;
uint32_t s2Idx;
uint32_t actS1Size = 1;
uint32_t actS2Size = 1;
uint32_t actMBaseSize;
uint32_t actualSingleProcessSInnerSize;
uint32_t actualSingleProcessSInnerSizeAlign;
uint64_t tensorQueryOffset;
uint64_t tensorKeyOffset;
uint64_t tensorWeightsOffset;
uint64_t indiceOutOffset;
bool isFirstS2InnerLoop;
bool isLastS2InnerLoop;
bool isAllLoopEnd = false;
};
struct ConstInfo {
static constexpr uint32_t FIA_SYNC_MODE2 = 2;
static constexpr uint32_t BUFFER_SIZE_BYTE_32B = 32;
static constexpr uint32_t BUFFER_SIZE_BYTE_64B = 64;
static constexpr uint32_t BUFFER_SIZE_BYTE_256B = 256;
static constexpr uint32_t BUFFER_SIZE_BYTE_512B = 512;
static constexpr uint32_t BUFFER_SIZE_BYTE_1K = 1024;
static constexpr uint32_t BUFFER_SIZE_BYTE_2K = 2048;
static constexpr uint32_t BUFFER_SIZE_BYTE_4K = 4096;
static constexpr uint32_t BUFFER_SIZE_BYTE_8K = 8192;
static constexpr uint32_t BUFFER_SIZE_BYTE_16K = 16384;
static constexpr uint32_t BUFFER_SIZE_BYTE_32K = 32768;
static constexpr int INVALID_IDX = -1;
uint32_t syncC1V1 = 0U;
uint32_t syncV1C1 = 0U;
uint32_t mBaseSize = 1ULL;
uint32_t s1BaseSize = 1ULL;
uint32_t s2BaseSize = 1ULL;
uint64_t batchSize = 0ULL;
uint64_t gSize = 0ULL;
uint64_t qHeadNum = 0ULL;
uint64_t kHeadNum;
uint64_t headDim;
uint64_t sparseCount;
uint64_t kSeqSize = 0ULL;
uint64_t qSeqSize = 1ULL;
uint32_t kCacheBlockSize = 0;
uint32_t maxBlockNumPerBatch = 0;
LI_LAYOUT outputLayout;
bool attenMaskFlag = false;
uint32_t actualLenQDims = 0U;
uint32_t actualLenDims = 0U;
bool isAccumSeqS1 = false;
bool isAccumSeqS2 = false;
};
struct SplitCoreInfo {
uint32_t s2Start = 0U;
uint32_t s2End = 0U;
uint32_t bN2Start = 0U;
uint32_t bN2End = 0U;
uint32_t gS1Start = 0U;
uint32_t gS1End = 0U;
bool isLD = false;
};
template <typename T>
__aicore__ inline T Align(T num, T rnd)
{
return (((rnd) == 0) ? 0 : (((num) + (rnd)-1) / (rnd) * (rnd)));
}
template <typename T1, typename T2>
__aicore__ inline T1 Min(T1 a, T2 b)
{
return (a > b) ? (b) : (a);
}
template <typename T1, typename T2>
__aicore__ inline T1 Max(T1 a, T2 b)
{
return (a > b) ? (a) : (b);
}
template <typename T>
__aicore__ inline T CeilDiv(T num, T rnd)
{
return (((rnd) == 0) ? 0 : (((num) + (rnd)-1) / (rnd)));
}
} // namespace LICommon
#endif // LIGHTNING_INDEXER_COMMON_H

View File

@@ -0,0 +1,623 @@
/**
* This program is free software, you can redistribute it and/or modify it.
* Copyright (c) 2025 Huawei Technologies Co., Ltd.
* This file is a part of the CANN Open Software.
* Licensed under CANN Open Software License Agreement Version 2.0 (the "License").
* Please refer to the License for details. You may not use this file except in compliance with the License.
* THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE.
* See LICENSE in the root of the software repository for the full text of the License.
*/
/*!
* \file lightning_indexer_kernel.h
* \brief
*/
#ifndef LIGHTNING_INDEXER_KERNEL_H
#define LIGHTNING_INDEXER_KERNEL_H
#include "kernel_operator.h"
#include "kernel_operator_list_tensor_intf.h"
#include "kernel_tiling/kernel_tiling.h"
#include "lib/matmul_intf.h"
#include "lib/matrix/matmul/tiling.h"
#include "lightning_indexer_common.h"
#include "lightning_indexer_service_vector.h"
#include "lightning_indexer_service_cube.h"
namespace LIKernel {
using namespace LICommon;
using namespace LIServiceVec;
using namespace matmul;
using AscendC::CacheMode;
using AscendC::CrossCoreSetFlag;
using AscendC::CrossCoreWaitFlag;
struct TempLoopInfo {
uint32_t bN2Idx = 0;
uint32_t bIdx = 0U;
uint32_t n2Idx = 0U;
uint32_t gS1Idx = 0U;
uint32_t gS1LoopEnd = 0U;
uint32_t s2LoopEnd = 0U;
uint32_t actS1Size = 1ULL;
uint32_t actS2Size = 0ULL;
bool curActSeqLenIsZero = false;
bool needDealActS1LessThanS1 = false;
uint32_t actMBaseSize = 0U;
uint32_t mBasicSizeTail = 0U;
uint32_t s2BasicSizeTail = 0U;
};
template <typename LIT>
class LIPreload {
public:
__aicore__ inline LIPreload(){};
__aicore__ inline void Init(__gm__ uint8_t *query, __gm__ uint8_t *key, __gm__ uint8_t *weights,
__gm__ uint8_t *actualSeqLengthsQ, __gm__ uint8_t *actualSeqLengths,
__gm__ uint8_t *blockTable, __gm__ uint8_t *sparseIndices, __gm__ uint8_t *workspace,
const LITilingData *__restrict tiling, TPipe *tPipe);
__aicore__ inline void Process();
using Q_T = typename LIT::queryType;
using K_T = typename LIT::keyType;
using OUT_T = typename LIT::outputType;
static constexpr bool PAGE_ATTENTION = LIT::pageAttention;
static constexpr LI_LAYOUT LAYOUT_T = LIT::layout;
static constexpr LI_LAYOUT K_LAYOUT_T = LIT::keyLayout;
using MM1_OUT_T = float;
LIMatmul<LIT> matmulService;
LIVector<LIT> vectorService;
static constexpr uint32_t SYNC_C1_V1_FLAG = 4;
static constexpr uint32_t SYNC_V1_C1_FLAG = 5;
static constexpr uint32_t M_BASE_SIZE = 512;
static constexpr uint32_t S2_BASE_SIZE = 512;
static constexpr uint32_t HEAD_DIM = 128;
static constexpr uint32_t K_HEAD_NUM = 1;
static constexpr uint32_t GM_ALIGN_BYTES = 512;
static constexpr int64_t LD_PREFETCH_LEN = 2;
// for workspace double
static constexpr uint32_t WS_DOBULE = 2;
protected:
TPipe *pipe = nullptr;
// offset
uint64_t queryCoreOffset = 0ULL;
uint64_t keyCoreOffset = 0ULL;
uint64_t weightsCoreOffset = 0ULL;
uint64_t indiceOutCoreOffset = 0ULL;
GlobalTensor<Q_T> queryGm;
GlobalTensor<K_T> keyGm;
GlobalTensor<K_T> weightsGm;
GlobalTensor<int32_t> indiceOutGm;
GlobalTensor<int32_t> blockTableGm;
GlobalTensor<uint32_t> actualSeqLengthsGmQ;
GlobalTensor<uint32_t> actualSeqLengthsGm;
// workspace
GlobalTensor<MM1_OUT_T> mm1ResGm;
GlobalTensor<float> vec1ResGm;
GlobalTensor<int64_t> vec1ParamGm;
// aic、aiv kernel info
uint32_t tmpBlockIdx = 0U;
uint32_t aiCoreIdx = 0U;
uint32_t usedCoreNum = 0U;
LICommon::ConstInfo constInfo{};
TempLoopInfo tempLoopInfo{};
LICommon::SplitCoreInfo splitCoreInfo{};
// ================================Init functions==================================
__aicore__ inline void InitTilingData(const LITilingData *__restrict tilingData);
__aicore__ inline void InitBuffers();
__aicore__ inline void InitActualSeqLen(__gm__ uint8_t *actualSeqLengthsQ, __gm__ uint8_t *actualSeqLengths);
// ================================Split Core================================
__aicore__ inline void SplitCore(uint32_t curCoreIdx, uint32_t &coreNum, LICommon::SplitCoreInfo &info);
__aicore__ inline uint32_t GetS2BaseBlockNumOnMask(uint32_t s1gIdx, uint32_t actS1Size, uint32_t actS2Size);
__aicore__ inline uint32_t GetTotalBaseBlockNum();
// ================================Process functions================================
__aicore__ inline void ProcessMain();
__aicore__ inline void ProcessBaseBlock(uint32_t loop, uint64_t s2LoopIdx, LICommon::RunInfo &runInfo);
__aicore__ inline void ProcessDecode();
__aicore__ inline void ProcessInvalid();
// ================================Params Calc=====================================
__aicore__ inline void CalcGS1LoopParams(uint32_t bN2Idx);
__aicore__ inline void GetBN2Idx(uint32_t bN2Idx);
__aicore__ inline uint32_t GetActualSeqLen(uint32_t bIdx, uint32_t actualLenDims, bool isAccumSeq,
GlobalTensor<uint32_t> &actualSeqLengthsGm, uint32_t defaultSeqLen);
__aicore__ inline void GetS1S2ActualSeqLen(uint32_t bIdx, uint32_t &actS1Size, uint32_t &actS2Size);
__aicore__ inline void CalcS2LoopParams(uint32_t bN2LoopIdx, uint32_t gS1LoopIdx);
__aicore__ inline void CalcRunInfo(uint32_t loop, uint32_t s2LoopIdx, LICommon::RunInfo &runInfo);
__aicore__ inline void DealActSeqLenIsZero(uint32_t bIdx, uint32_t n2Idx, uint32_t s1Start);
};
template <typename LIT>
__aicore__ inline void LIPreload<LIT>::InitTilingData(const LITilingData *__restrict tilingData)
{
usedCoreNum = tilingData->usedCoreNum;
constInfo.batchSize = tilingData->bSize;
constInfo.qHeadNum = constInfo.gSize = tilingData->gSize;
constInfo.kSeqSize = tilingData->s2Size;
constInfo.qSeqSize = tilingData->s1Size;
constInfo.attenMaskFlag = (tilingData->sparseMode == 3);
constInfo.kCacheBlockSize = tilingData->blockSize;
constInfo.maxBlockNumPerBatch = tilingData->maxBlockNumPerBatch;
constInfo.sparseCount = tilingData->sparseCount;
constInfo.outputLayout = LAYOUT_T;
if (LAYOUT_T == LI_LAYOUT::TND) {
constInfo.isAccumSeqS1 = true;
}
if (K_LAYOUT_T == LI_LAYOUT::TND) {
constInfo.isAccumSeqS2 = true;
}
constInfo.kHeadNum = K_HEAD_NUM;
constInfo.headDim = HEAD_DIM;
constInfo.mBaseSize = M_BASE_SIZE;
constInfo.s2BaseSize = S2_BASE_SIZE;
constInfo.s1BaseSize = (constInfo.mBaseSize + constInfo.gSize - 1) / constInfo.gSize;
}
template <typename LIT>
__aicore__ inline void LIPreload<LIT>::InitBuffers()
{
if ASCEND_IS_AIV {
vectorService.InitBuffers(pipe);
} else {
matmulService.InitBuffers(pipe);
}
}
template <typename LIT>
__aicore__ inline void LIPreload<LIT>::InitActualSeqLen(__gm__ uint8_t *actualSeqLengthsQ,
__gm__ uint8_t *actualSeqLengths)
{
if (actualSeqLengthsQ == nullptr) {
constInfo.actualLenQDims = 0;
} else {
constInfo.actualLenQDims = constInfo.batchSize;
actualSeqLengthsGmQ.SetGlobalBuffer((__gm__ uint32_t *)actualSeqLengthsQ, constInfo.actualLenQDims);
}
if (actualSeqLengths == nullptr) {
constInfo.actualLenDims = 0;
} else {
constInfo.actualLenDims = constInfo.batchSize;
actualSeqLengthsGm.SetGlobalBuffer((__gm__ uint32_t *)actualSeqLengths, constInfo.actualLenDims);
}
}
template <typename LIT>
__aicore__ inline uint32_t LIPreload<LIT>::GetActualSeqLen(uint32_t bIdx, uint32_t actualLenDims, bool isAccumSeq,
GlobalTensor<uint32_t> &actualSeqLengthsGm,
uint32_t defaultSeqLen)
{
if (actualLenDims == 0) {
return defaultSeqLen;
} else if (isAccumSeq && bIdx > 0) {
return actualSeqLengthsGm.GetValue(bIdx) - actualSeqLengthsGm.GetValue(bIdx - 1);
} else {
return actualSeqLengthsGm.GetValue(bIdx);
}
}
template <typename LIT>
__aicore__ inline void LIPreload<LIT>::GetS1S2ActualSeqLen(uint32_t bIdx, uint32_t &actS1Size, uint32_t &actS2Size)
{
actS1Size = GetActualSeqLen(bIdx, constInfo.actualLenQDims, constInfo.isAccumSeqS1, actualSeqLengthsGmQ,
constInfo.qSeqSize);
actS2Size =
GetActualSeqLen(bIdx, constInfo.actualLenDims, constInfo.isAccumSeqS2, actualSeqLengthsGm, constInfo.kSeqSize);
}
template <typename LIT>
__aicore__ inline uint32_t LIPreload<LIT>::GetS2BaseBlockNumOnMask(uint32_t s1gIdx, uint32_t actS1Size,
uint32_t actS2Size)
{
if (actS2Size == 0) {
return 0;
}
uint32_t s1Offset = constInfo.s1BaseSize * s1gIdx;
int32_t validS2LenBase = static_cast<int32_t>(actS2Size) - static_cast<int32_t>(actS1Size);
int32_t validS2Len = s1Offset + validS2LenBase + constInfo.s1BaseSize;
validS2Len = Min(validS2Len, static_cast<int32_t>(actS2Size));
validS2Len = Max(validS2Len, 1);
return (validS2Len + constInfo.s2BaseSize - 1) / constInfo.s2BaseSize;
}
template <typename LIT>
__aicore__ inline uint32_t LIPreload<LIT>::GetTotalBaseBlockNum()
{
uint32_t totalBlockNum = 0;
uint32_t actS1Size, actS2Size;
uint32_t s1GBaseNum, s2BaseNum;
for (uint32_t bIdx = 0; bIdx < constInfo.batchSize; bIdx++) {
GetS1S2ActualSeqLen(bIdx, actS1Size, actS2Size);
s1GBaseNum = CeilDiv(actS1Size, constInfo.s1BaseSize);
if (!constInfo.attenMaskFlag) {
s2BaseNum = CeilDiv(actS2Size, constInfo.s2BaseSize);
totalBlockNum += s1GBaseNum * s2BaseNum * constInfo.kHeadNum;
continue;
}
for (uint32_t s1gIdx = 0; s1gIdx < s1GBaseNum; s1gIdx++) {
s2BaseNum = GetS2BaseBlockNumOnMask(s1gIdx, actS1Size, actS2Size);
totalBlockNum += s2BaseNum * constInfo.kHeadNum;
}
}
return totalBlockNum;
}
template <typename LIT>
__aicore__ void inline LIPreload<LIT>::SplitCore(uint32_t curCoreIdx, uint32_t &coreNum, LICommon::SplitCoreInfo &info)
{
uint32_t totalBlockNum = GetTotalBaseBlockNum();
uint32_t minBlockPerCore = totalBlockNum / coreNum;
uint32_t deal1MoreBlockCoreNum = totalBlockNum % coreNum;
uint32_t coreIdx = 0;
uint32_t lastGS1RemainBlockCnt = 0;
uint32_t coreDealBlockCnt = coreIdx < deal1MoreBlockCoreNum ? minBlockPerCore + 1 : minBlockPerCore;
coreNum = minBlockPerCore == 0 ? deal1MoreBlockCoreNum : coreNum;
bool findLastCoreEnd = true;
uint32_t actS1Size, actS2Size;
uint32_t s1GBaseNum, s2BaseNum;
for (uint32_t bN2Idx = 0; bN2Idx < constInfo.batchSize * constInfo.kHeadNum; bN2Idx++) {
uint32_t bIdx = bN2Idx / constInfo.kHeadNum;
if (bN2Idx % constInfo.kHeadNum == 0) {
GetS1S2ActualSeqLen(bIdx, actS1Size, actS2Size);
s1GBaseNum = CeilDiv(actS1Size, constInfo.s1BaseSize);
s2BaseNum = CeilDiv(actS2Size, constInfo.s2BaseSize);
}
if constexpr (LAYOUT_T == LI_LAYOUT::BSND) {
if (findLastCoreEnd && (s1GBaseNum == 0U || s2BaseNum == 0U)) {
info.bN2Start = bN2Idx;
info.gS1Start = 0;
info.s2Start = 0;
findLastCoreEnd = false;
}
}
for (uint32_t gS1Idx = 0; gS1Idx < s1GBaseNum; gS1Idx++) {
if (constInfo.attenMaskFlag) {
s2BaseNum = GetS2BaseBlockNumOnMask(gS1Idx, actS1Size, actS2Size);
}
if (findLastCoreEnd && s2BaseNum == 0U) {
info.bN2Start = bN2Idx;
info.gS1Start = gS1Idx;
info.s2Start = 0;
findLastCoreEnd = false;
}
for (uint32_t s2Idx = 0; s2Idx < s2BaseNum;) {
if (findLastCoreEnd) {
info.bN2Start = bN2Idx;
info.gS1Start = gS1Idx;
info.s2Start = s2Idx;
findLastCoreEnd = false;
}
uint32_t s2RemainBaseNum = s2BaseNum - s2Idx;
if (lastGS1RemainBlockCnt + s2RemainBaseNum >= coreDealBlockCnt) {
info.bN2End = bN2Idx;
info.gS1End = gS1Idx;
info.s2End = s2Idx + coreDealBlockCnt - lastGS1RemainBlockCnt - 1;
if (coreIdx == curCoreIdx) {
if (s2Idx == 0 && info.s2End + 1 < s2BaseNum) {
info.isLD = true;
}
if (coreIdx == coreNum - 1 && info.bN2End != constInfo.batchSize -1) {
info.bN2End = constInfo.batchSize -1;
info.gS1End = 0;
info.s2End = 0;
}
return;
}
coreIdx++;
findLastCoreEnd = true;
s2Idx = info.s2End + 1;
lastGS1RemainBlockCnt = 0;
coreDealBlockCnt = coreIdx < deal1MoreBlockCoreNum ? minBlockPerCore + 1 : minBlockPerCore;
} else {
lastGS1RemainBlockCnt += s2RemainBaseNum;
break;
}
}
}
}
}
template <typename LIT>
__aicore__ inline void LIPreload<LIT>::DealActSeqLenIsZero(uint32_t bIdx, uint32_t n2Idx, uint32_t s1Start)
{
if ASCEND_IS_AIV {
if (constInfo.outputLayout == LI_LAYOUT::TND) {
uint32_t tSize = actualSeqLengthsGmQ.GetValue(constInfo.batchSize - 1);
uint32_t tBase = bIdx == 0 ? 0 : actualSeqLengthsGmQ.GetValue(bIdx - 1);
uint32_t s1Count = tempLoopInfo.actS1Size;
for (uint32_t s1Idx = s1Start; s1Idx < s1Count; s1Idx++) {
uint64_t indiceOutOffset =
(tBase + s1Idx) * constInfo.kHeadNum * constInfo.sparseCount +
n2Idx * constInfo.sparseCount;
vectorService.CleanInvalidOutput(indiceOutOffset);
}
} else if (constInfo.outputLayout == LI_LAYOUT::BSND) {
for (uint32_t s1Idx = s1Start; s1Idx < constInfo.qSeqSize; s1Idx++) {
// B,S1,N2,K
uint64_t indiceOutOffset = bIdx * constInfo.qSeqSize * constInfo.kHeadNum * constInfo.sparseCount +
s1Idx * constInfo.kHeadNum * constInfo.sparseCount +
n2Idx * constInfo.sparseCount;
vectorService.CleanInvalidOutput(indiceOutOffset);
}
}
}
}
template <typename LIT>
__aicore__ inline void LIPreload<LIT>::Init(__gm__ uint8_t *query, __gm__ uint8_t *key, __gm__ uint8_t *weights,
__gm__ uint8_t *actualSeqLengthsQ, __gm__ uint8_t *actualSeqLengths,
__gm__ uint8_t *blockTable, __gm__ uint8_t *sparseIndices,
__gm__ uint8_t *workspace, const LITilingData *__restrict tiling,
TPipe *tPipe)
{
if ASCEND_IS_AIV {
tmpBlockIdx = GetBlockIdx(); // vec:0-47
aiCoreIdx = tmpBlockIdx / 2;
} else {
tmpBlockIdx = GetBlockIdx(); // cube:0-23
aiCoreIdx = tmpBlockIdx;
}
InitTilingData(tiling);
InitActualSeqLen(actualSeqLengthsQ, actualSeqLengths);
SplitCore(aiCoreIdx, usedCoreNum, splitCoreInfo);
pipe = tPipe;
uint64_t offset = 0;
uint64_t singleCoreMm1ResSize = WS_DOBULE * constInfo.mBaseSize * constInfo.s2BaseSize * sizeof(MM1_OUT_T);
mm1ResGm.SetGlobalBuffer((__gm__ MM1_OUT_T *)(workspace + offset + aiCoreIdx * singleCoreMm1ResSize));
offset += GetBlockNum() * singleCoreMm1ResSize;
vec1ResGm.SetGlobalBuffer((__gm__ float *)(workspace + offset));
offset += GetBlockNum() * constInfo.s1BaseSize * WS_DOBULE * WS_DOBULE * BASE_TOPK * sizeof(float);
vec1ParamGm.SetGlobalBuffer((__gm__ int64_t *)(workspace + offset));
offset += GetBlockNum() * constInfo.s1BaseSize * WS_DOBULE * LD_PARAM_NUM * sizeof(int64_t);
if ASCEND_IS_AIV {
vectorService.InitParams(constInfo, tiling);
indiceOutGm.SetGlobalBuffer((__gm__ int32_t *)sparseIndices);
weightsGm.SetGlobalBuffer((__gm__ K_T *)weights);
vectorService.InitVec1GlobalTensor(mm1ResGm, vec1ResGm, vec1ParamGm, weightsGm, indiceOutGm);
} else {
matmulService.InitParams(constInfo);
queryGm.SetGlobalBuffer((__gm__ Q_T *)query);
if constexpr (PAGE_ATTENTION) {
blockTableGm.SetGlobalBuffer((__gm__ int32_t *)blockTable);
}
keyGm.SetGlobalBuffer((__gm__ K_T *)key);
matmulService.InitMm1GlobalTensor(blockTableGm, keyGm, queryGm, mm1ResGm);
}
InitBuffers();
}
template <typename LIT>
__aicore__ inline void LIPreload<LIT>::GetBN2Idx(uint32_t bN2Idx)
{
tempLoopInfo.bN2Idx = bN2Idx;
tempLoopInfo.bIdx = bN2Idx / constInfo.kHeadNum;
tempLoopInfo.n2Idx = bN2Idx % constInfo.kHeadNum;
}
template <typename LIT>
__aicore__ inline void LIPreload<LIT>::CalcS2LoopParams(uint32_t bN2LoopIdx, uint32_t gS1LoopIdx)
{
tempLoopInfo.gS1Idx = gS1LoopIdx;
tempLoopInfo.actMBaseSize = constInfo.mBaseSize;
uint32_t remainedGS1Size = tempLoopInfo.actS1Size * constInfo.gSize - tempLoopInfo.gS1Idx * constInfo.mBaseSize;
if (remainedGS1Size <= constInfo.mBaseSize && remainedGS1Size > 0) {
tempLoopInfo.actMBaseSize = tempLoopInfo.mBasicSizeTail;
}
bool isEnd = (bN2LoopIdx == splitCoreInfo.bN2End) && (gS1LoopIdx == splitCoreInfo.gS1End);
uint32_t s2BlockNum;
if (constInfo.attenMaskFlag) {
s2BlockNum = GetS2BaseBlockNumOnMask(gS1LoopIdx, tempLoopInfo.actS1Size, tempLoopInfo.actS2Size);
} else {
s2BlockNum = (tempLoopInfo.actS2Size + constInfo.s2BaseSize - 1) / constInfo.s2BaseSize;
}
tempLoopInfo.s2LoopEnd = isEnd ? splitCoreInfo.s2End : s2BlockNum - 1;
}
template <typename LIT>
__aicore__ inline void LIPreload<LIT>::CalcGS1LoopParams(uint32_t bN2LoopIdx)
{
GetBN2Idx(bN2LoopIdx);
GetS1S2ActualSeqLen(tempLoopInfo.bIdx, tempLoopInfo.actS1Size, tempLoopInfo.actS2Size);
if ((tempLoopInfo.actS2Size == 0) || (tempLoopInfo.actS1Size == 0)) {
tempLoopInfo.curActSeqLenIsZero = true;
return;
}
tempLoopInfo.curActSeqLenIsZero = false;
tempLoopInfo.s2BasicSizeTail = tempLoopInfo.actS2Size % constInfo.s2BaseSize;
tempLoopInfo.s2BasicSizeTail =
(tempLoopInfo.s2BasicSizeTail == 0) ? constInfo.s2BaseSize : tempLoopInfo.s2BasicSizeTail;
tempLoopInfo.mBasicSizeTail = (tempLoopInfo.actS1Size * constInfo.gSize) % constInfo.mBaseSize;
tempLoopInfo.mBasicSizeTail =
(tempLoopInfo.mBasicSizeTail == 0) ? constInfo.mBaseSize : tempLoopInfo.mBasicSizeTail;
uint32_t gS1SplitNum = (tempLoopInfo.actS1Size * constInfo.gSize + constInfo.mBaseSize - 1) / constInfo.mBaseSize;
tempLoopInfo.gS1LoopEnd = (bN2LoopIdx == splitCoreInfo.bN2End) ? splitCoreInfo.gS1End : gS1SplitNum - 1;
if constexpr (LAYOUT_T == LI_LAYOUT::BSND) {
if (tempLoopInfo.gS1LoopEnd == gS1SplitNum - 1 && constInfo.qSeqSize > tempLoopInfo.actS1Size) {
tempLoopInfo.needDealActS1LessThanS1 = true;
}
}
}
template <typename LIT>
__aicore__ inline void LIPreload<LIT>::CalcRunInfo(uint32_t loop, uint32_t s2LoopIdx, LICommon::RunInfo &runInfo)
{
runInfo.loop = loop;
runInfo.bIdx = tempLoopInfo.bIdx;
runInfo.gS1Idx = tempLoopInfo.gS1Idx;
runInfo.s2Idx = s2LoopIdx;
runInfo.bN2Idx = tempLoopInfo.bN2Idx;
runInfo.actS1Size = tempLoopInfo.actS1Size;
runInfo.actS2Size = tempLoopInfo.actS2Size;
runInfo.actMBaseSize = tempLoopInfo.actMBaseSize;
runInfo.actualSingleProcessSInnerSize = constInfo.s2BaseSize;
uint32_t s2SplitNum = (tempLoopInfo.actS2Size + constInfo.s2BaseSize - 1) / constInfo.s2BaseSize;
if (runInfo.s2Idx == s2SplitNum - 1) {
runInfo.actualSingleProcessSInnerSize = tempLoopInfo.s2BasicSizeTail;
}
runInfo.actualSingleProcessSInnerSizeAlign =
LICommon::Align((uint32_t)runInfo.actualSingleProcessSInnerSize, LICommon::ConstInfo::BUFFER_SIZE_BYTE_32B);
runInfo.isFirstS2InnerLoop = s2LoopIdx == splitCoreInfo.s2Start;
runInfo.isLastS2InnerLoop = s2LoopIdx == tempLoopInfo.s2LoopEnd;
runInfo.isAllLoopEnd = (runInfo.bN2Idx == splitCoreInfo.bN2End) && (runInfo.gS1Idx == splitCoreInfo.gS1End) &&
(runInfo.s2Idx == splitCoreInfo.s2End);
if (runInfo.isFirstS2InnerLoop) {
uint64_t actualSeqQPrefixSum;
uint64_t actualSeqKPrefixSum;
if constexpr (LAYOUT_T == LI_LAYOUT::TND) {
actualSeqQPrefixSum = (runInfo.bIdx <= 0) ? 0 : actualSeqLengthsGmQ.GetValue(runInfo.bIdx - 1);
actualSeqKPrefixSum = (runInfo.bIdx <= 0) ? 0 : actualSeqLengthsGm.GetValue(runInfo.bIdx - 1);
} else { // BSND
actualSeqQPrefixSum = (runInfo.bIdx <= 0) ? 0 : runInfo.bIdx * constInfo.qSeqSize;
actualSeqKPrefixSum = (runInfo.bIdx <= 0) ? 0 : runInfo.bIdx * constInfo.kSeqSize;
}
uint64_t tndBIdxOffset = actualSeqQPrefixSum * constInfo.qHeadNum * constInfo.headDim;
uint64_t tndKeyBIdxOffset = actualSeqKPrefixSum * constInfo.kHeadNum * constInfo.headDim;
// B,S1,N1(N2,G),D
queryCoreOffset = tndBIdxOffset + runInfo.gS1Idx * constInfo.mBaseSize * constInfo.headDim;
keyCoreOffset = tndKeyBIdxOffset + runInfo.n2Idx * constInfo.headDim;
// B,S1,N1(N2,G)/T,N1(N2,G)
weightsCoreOffset = actualSeqQPrefixSum * constInfo.qHeadNum + runInfo.n2Idx * constInfo.gSize;
// B,S1,N2,k/T,N2,k
indiceOutCoreOffset = actualSeqQPrefixSum * constInfo.kHeadNum * constInfo.sparseCount +
runInfo.n2Idx * constInfo.sparseCount;
}
runInfo.tensorQueryOffset = queryCoreOffset;
runInfo.tensorKeyOffset = keyCoreOffset + runInfo.s2Idx * constInfo.s2BaseSize * constInfo.kHeadNum
* constInfo.headDim;
runInfo.tensorWeightsOffset = weightsCoreOffset;
runInfo.indiceOutOffset = indiceOutCoreOffset;
}
template <typename LIT>
__aicore__ inline void LIPreload<LIT>::Process()
{
if (usedCoreNum == 0) {
ProcessInvalid();
return;
}
ProcessMain();
ProcessDecode();
}
template <typename LIT>
__aicore__ inline void LIPreload<LIT>::ProcessInvalid()
{
if ASCEND_IS_AIV {
uint32_t aivCoreNum = GetBlockNum() * 2; // 2 means c:v = 1:2
uint64_t totalOutputSize =
constInfo.batchSize * constInfo.qSeqSize * constInfo.kHeadNum * constInfo.sparseCount;
uint64_t singleCoreSize =
LICommon::Align((totalOutputSize + aivCoreNum - 1) / aivCoreNum, GM_ALIGN_BYTES / sizeof(OUT_T));
uint64_t baseSize = tmpBlockIdx * singleCoreSize;
if (baseSize < totalOutputSize) {
uint64_t dealSize =
(baseSize + singleCoreSize > totalOutputSize) ? singleCoreSize : totalOutputSize - baseSize;
GlobalTensor<OUT_T> output = indiceOutGm[baseSize];
AscendC::InitGlobalMemory(output, dealSize, constInfo.INVALID_IDX);
}
}
}
template <typename LIT>
__aicore__ inline void LIPreload<LIT>::ProcessMain()
{
if (aiCoreIdx >= usedCoreNum) {
return;
}
if ASCEND_IS_AIV {
vectorService.AllocEventID();
CrossCoreSetFlag<LICommon::ConstInfo::FIA_SYNC_MODE2, PIPE_MTE2>(constInfo.syncV1C1);
CrossCoreSetFlag<LICommon::ConstInfo::FIA_SYNC_MODE2, PIPE_MTE2>(constInfo.syncV1C1);
} else {
matmulService.AllocEventID();
}
LICommon::RunInfo runInfo;
uint32_t gloop = 0;
for (uint32_t bN2LoopIdx = splitCoreInfo.bN2Start; bN2LoopIdx <= splitCoreInfo.bN2End; bN2LoopIdx++) {
CalcGS1LoopParams(bN2LoopIdx);
if (tempLoopInfo.curActSeqLenIsZero) {
DealActSeqLenIsZero(tempLoopInfo.bIdx, tempLoopInfo.n2Idx, 0U);
continue;
}
for (uint32_t gS1LoopIdx = splitCoreInfo.gS1Start; gS1LoopIdx <= tempLoopInfo.gS1LoopEnd; gS1LoopIdx++) {
CalcS2LoopParams(bN2LoopIdx, gS1LoopIdx);
for (int s2LoopIdx = splitCoreInfo.s2Start; s2LoopIdx <= tempLoopInfo.s2LoopEnd; s2LoopIdx++) {
ProcessBaseBlock(gloop, s2LoopIdx, runInfo);
++gloop;
}
splitCoreInfo.s2Start = 0;
}
if (tempLoopInfo.needDealActS1LessThanS1) {
DealActSeqLenIsZero(tempLoopInfo.bIdx, tempLoopInfo.n2Idx, tempLoopInfo.actS1Size);
}
splitCoreInfo.gS1Start = 0;
}
if ASCEND_IS_AIV {
vectorService.FreeEventID();
} else {
matmulService.FreeEventID();
CrossCoreWaitFlag(constInfo.syncV1C1);
CrossCoreWaitFlag(constInfo.syncV1C1);
}
}
template <typename LIT>
__aicore__ inline void LIPreload<LIT>::ProcessBaseBlock(uint32_t loop, uint64_t s2LoopIdx, LICommon::RunInfo &runInfo)
{
CalcRunInfo(loop, s2LoopIdx, runInfo);
if ASCEND_IS_AIC {
CrossCoreWaitFlag(constInfo.syncV1C1);
matmulService.ComputeMm1(runInfo);
CrossCoreSetFlag<LICommon::ConstInfo::FIA_SYNC_MODE2, PIPE_FIX>(constInfo.syncC1V1);
} else {
CrossCoreWaitFlag(constInfo.syncC1V1);
vectorService.ProcessVec(runInfo);
CrossCoreSetFlag<LICommon::ConstInfo::FIA_SYNC_MODE2, PIPE_MTE2>(constInfo.syncV1C1);
}
}
template <typename LIT>
__aicore__ inline void LIPreload<LIT>::ProcessDecode()
{
if ASCEND_IS_AIV {
vectorService.InitLDBuffers(pipe);
ICachePreLoad(LD_PREFETCH_LEN);
SyncAll();
if (splitCoreInfo.isLD) {
vectorService.ProcessLD();
}
}
}
} // namespace LIKernel
#endif // LIGHTNING_INDEXER_KERNEL_H

View File

@@ -0,0 +1,415 @@
/**
* This program is free software, you can redistribute it and/or modify it.
* Copyright (c) 2025 Huawei Technologies Co., Ltd.
* This file is a part of the CANN Open Software.
* Licensed under CANN Open Software License Agreement Version 2.0 (the "License").
* Please refer to the License for details. You may not use this file except in compliance with the License.
* THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE.
* See LICENSE in the root of the software repository for the full text of the License.
*/
/*!
* \file lightning_indexer_service_cube.h
* \brief use 5 buffer for matmul l1, better pipeline
*/
#ifndef LIGHTNING_INDEXER_SERVICE_CUBE_H
#define LIGHTNING_INDEXER_SERVICE_CUBE_H
#include "kernel_operator.h"
#include "kernel_operator_list_tensor_intf.h"
#include "kernel_tiling/kernel_tiling.h"
#include "lib/matmul_intf.h"
#include "lib/matrix/matmul/tiling.h"
#include "lightning_indexer_common.h"
namespace LIKernel {
using namespace LICommon;
template <typename LIT>
class LIMatmul {
public:
using Q_T = typename LIT::queryType;
using K_T = typename LIT::keyType;
__aicore__ inline LIMatmul(){};
__aicore__ inline void InitBuffers(TPipe *pipe);
__aicore__ inline void InitMm1GlobalTensor(const GlobalTensor<int32_t> &blkTableGm, const GlobalTensor<K_T> &keyGm,
const GlobalTensor<Q_T> &queryGm, const GlobalTensor<float> &mm1ResGm);
__aicore__ inline void InitParams(const ConstInfo &constInfo);
__aicore__ inline void AllocEventID();
__aicore__ inline void FreeEventID();
__aicore__ inline void ComputeMm1(const LICommon::RunInfo &runInfo);
static constexpr IsResetLoad3dConfig LOAD3DV2_CONFIG = {true, true}; // isSetFMatrix isSetPadding;
static constexpr uint64_t KEY_BUF_NUM = 3;
static constexpr uint64_t QUERY_BUF_NUM = 2;
static constexpr uint64_t L0_BUF_NUM = 2;
static constexpr uint32_t KEY_MTE1_MTE2_EVENT = EVENT_ID2;
static constexpr uint32_t QUERY_MTE1_MTE2_EVENT = EVENT_ID5; // KEY_MTE1_MTE2_EVENT + KEY_BUF_NUM;
static constexpr uint32_t M_MTE1_EVENT = EVENT_ID3;
static constexpr uint32_t MTE2_MTE1_EVENT = EVENT_ID2;
static constexpr uint32_t MTE1_M_EVENT = EVENT_ID2;
static constexpr uint64_t M_BASIC_BLOCK = 256;
static constexpr uint64_t D_BASIC_BLOCK = 128;
static constexpr uint64_t S2_BASIC_BLOCK = 256;
static constexpr uint64_t M_BASIC_BLOCK_L0 = 128;
static constexpr uint64_t D_BASIC_BLOCK_L0 = 128;
static constexpr uint64_t S2_BASIC_BLOCK_L0 = 128;
static constexpr uint64_t QUERY_BUFFER_OFFSET = M_BASIC_BLOCK * D_BASIC_BLOCK;
static constexpr uint64_t KEY_BUFFER_OFFSET = S2_BASIC_BLOCK * D_BASIC_BLOCK;
static constexpr uint64_t L0AB_BUFFER_OFFSET = M_BASIC_BLOCK_L0 * D_BASIC_BLOCK_L0;
static constexpr uint64_t L0C_BUFFER_OFFSET = M_BASIC_BLOCK_L0 * S2_BASIC_BLOCK_L0;
protected:
__aicore__ inline void Fixp(uint64_t s1gGmOffset, uint64_t s2GmOffset, uint64_t s1gL0RealSize,
uint64_t s2L0RealSize, const LICommon::RunInfo &runInfo);
__aicore__ inline void ComuteL0c(uint64_t s1gL0RealSize, uint64_t s2L0RealSize, const LICommon::RunInfo &runInfo);
__aicore__ inline void LoadKeyToL0b(uint64_t s2L0Offset, uint64_t s2L1RealSize, uint64_t s2L0RealSize,
const LICommon::RunInfo &runInfo);
__aicore__ inline void LoadQueryToL0a(uint64_t s1gL1Offset, uint64_t s1gL0Offset, uint64_t s1gL1RealSize,
uint64_t s1gL0RealSize, const LICommon::RunInfo &runInfo);
__aicore__ inline void QueryNd2Nz(uint64_t s1gL1RealSize, uint64_t s1gL1Offset, const LICommon::RunInfo &runInfo);
__aicore__ inline void KeyNd2Nz(uint64_t s2L1RealSize, uint64_t s2GmOffset, const LICommon::RunInfo &runInfo);
__aicore__ inline void KeyNd2NzForPA(uint64_t s2L1RealSize, uint64_t s2GmOffset, const LICommon::RunInfo &runInfo);
GlobalTensor<int32_t> blkTableGm_;
GlobalTensor<K_T> keyGm_;
GlobalTensor<Q_T> queryGm_;
GlobalTensor<float> mm1ResGm_;
TBuf<TPosition::A1> bufQL1_;
LocalTensor<Q_T> queryL1_;
TBuf<TPosition::B1> bufKeyL1_;
LocalTensor<K_T> keyL1_;
TBuf<TPosition::A2> bufQL0_;
LocalTensor<Q_T> queryL0_;
TBuf<TPosition::B2> bufKeyL0_;
LocalTensor<K_T> keyL0_;
TBuf<TPosition::CO1> bufL0C_;
LocalTensor<float> cL0_;
uint64_t keyL1BufIdx_ = 0;
uint64_t queryL1Mte2BufIdx_ = 0;
uint64_t queryL1Mte1BufIdx_ = 0;
uint64_t l0BufIdx_ = 0;
ConstInfo constInfo_;
private:
static constexpr bool PAGE_ATTENTION = LIT::pageAttention;
};
template <typename LIT>
__aicore__ inline void LIMatmul<LIT>::InitParams(const ConstInfo &constInfo)
{
constInfo_ = constInfo;
}
template <typename LIT>
__aicore__ inline void LIMatmul<LIT>::InitBuffers(TPipe *pipe)
{
pipe->InitBuffer(bufQL1_, QUERY_BUF_NUM * M_BASIC_BLOCK * D_BASIC_BLOCK * sizeof(Q_T));
queryL1_ = bufQL1_.Get<Q_T>();
pipe->InitBuffer(bufKeyL1_, KEY_BUF_NUM * S2_BASIC_BLOCK * D_BASIC_BLOCK * sizeof(K_T));
keyL1_ = bufKeyL1_.Get<K_T>();
pipe->InitBuffer(bufQL0_, L0_BUF_NUM * M_BASIC_BLOCK_L0 * D_BASIC_BLOCK_L0 * sizeof(Q_T));
queryL0_ = bufQL0_.Get<Q_T>();
pipe->InitBuffer(bufKeyL0_, L0_BUF_NUM * D_BASIC_BLOCK_L0 * S2_BASIC_BLOCK_L0 * sizeof(K_T));
keyL0_ = bufKeyL0_.Get<K_T>();
pipe->InitBuffer(bufL0C_, L0_BUF_NUM * M_BASIC_BLOCK_L0 * S2_BASIC_BLOCK_L0 * sizeof(float));
cL0_ = bufL0C_.Get<float>();
}
template <typename LIT>
__aicore__ inline void
LIMatmul<LIT>::InitMm1GlobalTensor(const GlobalTensor<int32_t> &blkTableGm, const GlobalTensor<K_T> &keyGm,
const GlobalTensor<Q_T> &queryGm, const GlobalTensor<float> &mm1ResGm)
{
blkTableGm_ = blkTableGm;
keyGm_ = keyGm;
queryGm_ = queryGm;
mm1ResGm_ = mm1ResGm;
}
template <typename LIT>
__aicore__ inline void LIMatmul<LIT>::ComputeMm1(const LICommon::RunInfo &runInfo)
{
uint64_t s2GmBaseOffset = runInfo.s2Idx * constInfo_.s2BaseSize;
uint64_t s1gProcessSize = runInfo.actMBaseSize;
uint64_t s2ProcessSize = runInfo.actualSingleProcessSInnerSize;
for (uint64_t s2GmOffset = 0; s2GmOffset < s2ProcessSize; s2GmOffset += S2_BASIC_BLOCK) {
WaitFlag<HardEvent::MTE1_MTE2>(KEY_MTE1_MTE2_EVENT + keyL1BufIdx_ % KEY_BUF_NUM);
uint64_t s2L1RealSize =
s2GmOffset + S2_BASIC_BLOCK > s2ProcessSize ? s2ProcessSize - s2GmOffset : S2_BASIC_BLOCK;
if (PAGE_ATTENTION) {
KeyNd2NzForPA(s2L1RealSize, s2GmBaseOffset + s2GmOffset, runInfo);
}else {
KeyNd2Nz(s2L1RealSize, s2GmOffset, runInfo);
}
SetFlag<HardEvent::MTE2_MTE1>(MTE2_MTE1_EVENT);
WaitFlag<HardEvent::MTE2_MTE1>(MTE2_MTE1_EVENT);
for (uint64_t s1gGmOffset = 0; s1gGmOffset < s1gProcessSize; s1gGmOffset += M_BASIC_BLOCK) {
uint64_t s1gL1RealSize =
s1gGmOffset + M_BASIC_BLOCK > s1gProcessSize ? s1gProcessSize - s1gGmOffset : M_BASIC_BLOCK;
if (runInfo.isFirstS2InnerLoop && s2GmOffset == 0) {
queryL1Mte2BufIdx_++;
queryL1Mte1BufIdx_ = queryL1Mte2BufIdx_;
WaitFlag<HardEvent::MTE1_MTE2>(QUERY_MTE1_MTE2_EVENT + queryL1Mte2BufIdx_ % QUERY_BUF_NUM);
QueryNd2Nz(s1gL1RealSize, s1gGmOffset, runInfo);
SetFlag<HardEvent::MTE2_MTE1>(MTE2_MTE1_EVENT);
WaitFlag<HardEvent::MTE2_MTE1>(MTE2_MTE1_EVENT);
} else {
queryL1Mte1BufIdx_ =
queryL1Mte2BufIdx_ - (CeilDiv(s1gProcessSize, M_BASIC_BLOCK) - 1 - (s1gGmOffset > 0));
}
for (uint64_t s2L1Offset = 0; s2L1Offset < s2L1RealSize; s2L1Offset += S2_BASIC_BLOCK_L0) {
uint64_t s2L0RealSize =
s2L1Offset + S2_BASIC_BLOCK_L0 > s2L1RealSize ? s2L1RealSize - s2L1Offset : S2_BASIC_BLOCK_L0;
for (uint64_t s1gL1Offset = 0; s1gL1Offset < s1gL1RealSize; s1gL1Offset += M_BASIC_BLOCK_L0) {
WaitFlag<HardEvent::M_MTE1>(M_MTE1_EVENT + l0BufIdx_ % L0_BUF_NUM);
uint64_t s1gL0RealSize =
s1gL1Offset + M_BASIC_BLOCK_L0 > s1gL1RealSize ? s1gL1RealSize - s1gL1Offset : M_BASIC_BLOCK_L0;
LoadQueryToL0a(s1gGmOffset, s1gL1Offset, s1gL1RealSize, s1gL0RealSize, runInfo);
LoadKeyToL0b(s2L1Offset, s2L1RealSize, s2L0RealSize, runInfo);
SetFlag<HardEvent::MTE1_M>(MTE1_M_EVENT);
WaitFlag<HardEvent::MTE1_M>(MTE1_M_EVENT);
ComuteL0c(s1gL0RealSize, s2L0RealSize, runInfo);
SetFlag<HardEvent::M_MTE1>(M_MTE1_EVENT + l0BufIdx_ % L0_BUF_NUM);
Fixp(s1gGmOffset + s1gL1Offset, s2GmOffset + s2L1Offset, s1gL0RealSize, s2L0RealSize, runInfo);
l0BufIdx_++;
}
}
if (s2GmOffset + S2_BASIC_BLOCK >= s2ProcessSize && runInfo.isLastS2InnerLoop) {
SetFlag<HardEvent::MTE1_MTE2>(QUERY_MTE1_MTE2_EVENT + queryL1Mte1BufIdx_ % QUERY_BUF_NUM);
}
}
SetFlag<HardEvent::MTE1_MTE2>(KEY_MTE1_MTE2_EVENT + keyL1BufIdx_ % KEY_BUF_NUM);
keyL1BufIdx_++;
}
}
template <typename LIT>
__aicore__ inline void LIMatmul<LIT>::KeyNd2Nz(uint64_t s2L1RealSize, uint64_t s2GmOffset,
const LICommon::RunInfo &runInfo)
{
uint64_t s2L1Offset = 0;
while (s2L1Offset < s2L1RealSize) {
uint64_t keyGmOffset = runInfo.tensorKeyOffset + (s2GmOffset + s2L1Offset) * constInfo_.headDim;
uint64_t s2Mte2Size = (s2L1RealSize <= S2_BASIC_BLOCK_L0 || s2L1Offset >= S2_BASIC_BLOCK_L0) ?
s2L1RealSize - s2L1Offset :
S2_BASIC_BLOCK_L0 - s2L1Offset;
Nd2NzParams nd2nzPara;
nd2nzPara.ndNum = 1;
nd2nzPara.nValue = s2Mte2Size; // 行数
nd2nzPara.dValue = constInfo_.headDim;
nd2nzPara.srcDValue = constInfo_.headDim;
nd2nzPara.dstNzC0Stride = s2L1Offset >= S2_BASIC_BLOCK_L0 ?
CeilAlign(s2L1RealSize - S2_BASIC_BLOCK_L0, (uint64_t)BLOCK_CUBE) :
(s2L1RealSize > S2_BASIC_BLOCK_L0 ?
S2_BASIC_BLOCK_L0 :
CeilAlign(s2L1RealSize, (uint64_t)BLOCK_CUBE));
nd2nzPara.dstNzNStride = 1;
nd2nzPara.srcNdMatrixStride = 0;
nd2nzPara.dstNzMatrixStride = 0;
DataCopy(keyL1_[(keyL1BufIdx_ % KEY_BUF_NUM) * KEY_BUFFER_OFFSET +
(s2L1Offset >= S2_BASIC_BLOCK_L0 ?
S2_BASIC_BLOCK_L0 * D_BASIC_BLOCK_L0 + (s2L1Offset - S2_BASIC_BLOCK_L0) * BLOCK_CUBE :
s2L1Offset * BLOCK_CUBE)],
keyGm_[keyGmOffset], nd2nzPara);
s2L1Offset += s2Mte2Size;
}
}
// blkNum, blkSize, N2, D
template <typename LIT>
__aicore__ inline void LIMatmul<LIT>::KeyNd2NzForPA(uint64_t s2L1RealSize, uint64_t s2GmOffset,
const LICommon::RunInfo &runInfo)
{
uint64_t s2L1Offset = 0;
while (s2L1Offset < s2L1RealSize) {
uint64_t s2BlkId = (s2L1Offset + s2GmOffset) / constInfo_.kCacheBlockSize;
uint64_t s2BlkOffset = (s2L1Offset + s2GmOffset) % constInfo_.kCacheBlockSize;
uint64_t keyGmOffset = blkTableGm_.GetValue(runInfo.bIdx * constInfo_.maxBlockNumPerBatch + s2BlkId) *
constInfo_.kCacheBlockSize * constInfo_.kHeadNum * constInfo_.headDim +
s2BlkOffset * constInfo_.headDim;
uint64_t s2Mte2Size = (s2L1RealSize <= S2_BASIC_BLOCK_L0 || s2L1Offset >= S2_BASIC_BLOCK_L0) ?
s2L1RealSize - s2L1Offset :
S2_BASIC_BLOCK_L0 - s2L1Offset;
s2Mte2Size = s2BlkOffset + s2Mte2Size >= constInfo_.kCacheBlockSize ? constInfo_.kCacheBlockSize - s2BlkOffset :
s2Mte2Size;
Nd2NzParams nd2nzPara;
nd2nzPara.ndNum = 1;
nd2nzPara.nValue = s2Mte2Size;
nd2nzPara.dValue = constInfo_.headDim;
nd2nzPara.srcDValue = constInfo_.headDim;
nd2nzPara.dstNzC0Stride = s2L1Offset >= S2_BASIC_BLOCK_L0 ?
CeilAlign(s2L1RealSize - S2_BASIC_BLOCK_L0, (uint64_t)BLOCK_CUBE) :
(s2L1RealSize > S2_BASIC_BLOCK_L0 ?
S2_BASIC_BLOCK_L0 :
CeilAlign(s2L1RealSize, (uint64_t)BLOCK_CUBE));
nd2nzPara.dstNzNStride = 1;
nd2nzPara.srcNdMatrixStride = 0;
nd2nzPara.dstNzMatrixStride = 0;
DataCopy(keyL1_[(keyL1BufIdx_ % KEY_BUF_NUM) * KEY_BUFFER_OFFSET +
(s2L1Offset >= S2_BASIC_BLOCK_L0 ?
S2_BASIC_BLOCK_L0 * D_BASIC_BLOCK_L0 + (s2L1Offset - S2_BASIC_BLOCK_L0) * BLOCK_CUBE :
s2L1Offset * BLOCK_CUBE)],
keyGm_[keyGmOffset], nd2nzPara);
s2L1Offset += s2Mte2Size;
}
}
// batch, s1, n2, g, d
template <typename LIT>
__aicore__ inline void LIMatmul<LIT>::QueryNd2Nz(uint64_t s1gL1RealSize, uint64_t s1gGmOffset,
const LICommon::RunInfo &runInfo)
{
Nd2NzParams nd2nzPara;
nd2nzPara.ndNum = 1;
nd2nzPara.nValue = s1gL1RealSize;
nd2nzPara.dValue = constInfo_.headDim;
nd2nzPara.srcDValue = constInfo_.headDim;
nd2nzPara.dstNzC0Stride = CeilAlign(s1gL1RealSize, (uint64_t)BLOCK_CUBE);
nd2nzPara.dstNzNStride = 1;
nd2nzPara.srcNdMatrixStride = 0;
nd2nzPara.dstNzMatrixStride = 0;
DataCopy(queryL1_[(queryL1Mte2BufIdx_ % QUERY_BUF_NUM) * QUERY_BUFFER_OFFSET],
queryGm_[runInfo.tensorQueryOffset + s1gGmOffset * constInfo_.headDim], nd2nzPara);
}
template <typename LIT>
__aicore__ inline void LIMatmul<LIT>::LoadQueryToL0a(uint64_t s1gGmOffset, uint64_t s1gL1Offset, uint64_t s1gL1RealSize,
uint64_t s1gL0RealSize, const LICommon::RunInfo &runInfo)
{
LoadData3DParamsV2<Q_T> loadData3DParams;
// SetFmatrixParams
loadData3DParams.l1H = CeilDiv(s1gL1RealSize, BLOCK_CUBE); // Hin=M1=8
loadData3DParams.l1W = BLOCK_CUBE; // Win=M0
loadData3DParams.channelSize = constInfo_.headDim; // Cin=K
loadData3DParams.padList[0] = 0;
loadData3DParams.padList[1] = 0;
loadData3DParams.padList[2] = 0;
loadData3DParams.padList[3] = 255;
// SetLoadToA0Params
loadData3DParams.mExtension = CeilAlign(s1gL0RealSize, BLOCK_CUBE);
loadData3DParams.kExtension = constInfo_.headDim;
loadData3DParams.mStartPt = s1gL1Offset;
loadData3DParams.kStartPt = 0;
loadData3DParams.strideW = 1;
loadData3DParams.strideH = 1;
loadData3DParams.filterW = 1;
loadData3DParams.filterSizeW = (1 >> 8) & 255;
loadData3DParams.filterH = 1;
loadData3DParams.filterSizeH = (1 >> 8) & 255;
loadData3DParams.dilationFilterW = 1;
loadData3DParams.dilationFilterH = 1;
loadData3DParams.enTranspose = 0;
loadData3DParams.fMatrixCtrl = 0;
LoadData<Q_T, LOAD3DV2_CONFIG>(queryL0_[(l0BufIdx_ % L0_BUF_NUM) * L0AB_BUFFER_OFFSET],
queryL1_[(queryL1Mte1BufIdx_ % QUERY_BUF_NUM) * QUERY_BUFFER_OFFSET],
loadData3DParams);
}
template <typename LIT>
__aicore__ inline void LIMatmul<LIT>::LoadKeyToL0b(uint64_t s2L1Offset, uint64_t s2L1RealSize, uint64_t s2L0RealSize,
const LICommon::RunInfo &runInfo)
{
uint64_t keyL1Offset = s2L1Offset >= S2_BASIC_BLOCK_L0 ? S2_BASIC_BLOCK_L0 * D_BASIC_BLOCK_L0 : 0;
LoadData2DParams loadData2DParams;
loadData2DParams.startIndex = 0;
loadData2DParams.repeatTimes = CeilDiv(s2L0RealSize, BLOCK_CUBE) * CeilDiv(constInfo_.headDim, BLOCK_CUBE);
loadData2DParams.srcStride = 1;
loadData2DParams.dstGap = 0;
loadData2DParams.ifTranspose = false;
LoadData(keyL0_[(l0BufIdx_ % L0_BUF_NUM) * L0AB_BUFFER_OFFSET],
keyL1_[(keyL1BufIdx_ % KEY_BUF_NUM) * KEY_BUFFER_OFFSET + keyL1Offset], loadData2DParams);
}
template <typename LIT>
__aicore__ inline void LIMatmul<LIT>::ComuteL0c(uint64_t s1gL0RealSize, uint64_t s2L0RealSize,
const LICommon::RunInfo &runInfo)
{
MmadParams mmadParams;
mmadParams.m = CeilAlign(s1gL0RealSize, BLOCK_CUBE);
mmadParams.n = s2L0RealSize;
mmadParams.k = constInfo_.headDim;
mmadParams.cmatrixInitVal = true;
mmadParams.cmatrixSource = false;
mmadParams.unitFlag = 0b11;
Mmad(cL0_[(l0BufIdx_ % L0_BUF_NUM) * L0C_BUFFER_OFFSET], queryL0_[(l0BufIdx_ % L0_BUF_NUM) * L0AB_BUFFER_OFFSET],
keyL0_[(l0BufIdx_ % L0_BUF_NUM) * L0AB_BUFFER_OFFSET], mmadParams);
if ((mmadParams.m / 16) * (mmadParams.n / 16) < 10) {
PipeBarrier<PIPE_M>();
}
}
template <typename LIT>
__aicore__ inline void LIMatmul<LIT>::Fixp(uint64_t s1gGmOffset, uint64_t s2GmOffset, uint64_t s1gL0RealSize,
uint64_t s2L0RealSize, const LICommon::RunInfo &runInfo)
{
AscendC::DataCopyCO12DstParams intriParams;
intriParams.mSize = CeilAlign(s1gL0RealSize, BLOCK_CUBE);
intriParams.nSize = s2L0RealSize;
intriParams.dstStride = runInfo.actualSingleProcessSInnerSizeAlign;
intriParams.srcStride = CeilAlign(s1gL0RealSize, BLOCK_CUBE);
// set mode according to dtype
intriParams.quantPre = QuantMode_t::NoQuant;
intriParams.nz2ndEn = true;
intriParams.unitFlag = 0b11; // 3 unitflag
intriParams.reluPre = 1;
AscendC::SetFixpipeNz2ndFlag(1, 1, 1);
AscendC::DataCopy(mm1ResGm_[(runInfo.loop % 2) * constInfo_.mBaseSize * constInfo_.s2BaseSize +
s1gGmOffset * intriParams.dstStride + s2GmOffset],
cL0_[(l0BufIdx_ % L0_BUF_NUM) * L0C_BUFFER_OFFSET], intriParams);
}
template <typename LIT>
__aicore__ inline void LIMatmul<LIT>::AllocEventID()
{
SetMMLayoutTransform(true);
SetFlag<HardEvent::MTE1_MTE2>(KEY_MTE1_MTE2_EVENT + 0);
SetFlag<HardEvent::MTE1_MTE2>(KEY_MTE1_MTE2_EVENT + 1);
SetFlag<HardEvent::MTE1_MTE2>(KEY_MTE1_MTE2_EVENT + 2);
SetFlag<HardEvent::MTE1_MTE2>(QUERY_MTE1_MTE2_EVENT + 0);
SetFlag<HardEvent::MTE1_MTE2>(QUERY_MTE1_MTE2_EVENT + 1);
SetFlag<HardEvent::M_MTE1>(M_MTE1_EVENT + 0);
SetFlag<HardEvent::M_MTE1>(M_MTE1_EVENT + 1);
}
template <typename LIT>
__aicore__ inline void LIMatmul<LIT>::FreeEventID()
{
SetMMLayoutTransform(false);
WaitFlag<HardEvent::MTE1_MTE2>(KEY_MTE1_MTE2_EVENT + 0);
WaitFlag<HardEvent::MTE1_MTE2>(KEY_MTE1_MTE2_EVENT + 1);
WaitFlag<HardEvent::MTE1_MTE2>(KEY_MTE1_MTE2_EVENT + 2);
WaitFlag<HardEvent::MTE1_MTE2>(QUERY_MTE1_MTE2_EVENT + 0);
WaitFlag<HardEvent::MTE1_MTE2>(QUERY_MTE1_MTE2_EVENT + 1);
WaitFlag<HardEvent::M_MTE1>(M_MTE1_EVENT + 0);
WaitFlag<HardEvent::M_MTE1>(M_MTE1_EVENT + 1);
}
} // namespace LIKernel
#endif

View File

@@ -0,0 +1,559 @@
/**
* This program is free software, you can redistribute it and/or modify it.
* Copyright (c) 2025 Huawei Technologies Co., Ltd.
* This file is a part of the CANN Open Software.
* Licensed under CANN Open Software License Agreement Version 2.0 (the "License").
* Please refer to the License for details. You may not use this file except in compliance with the License.
* THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE.
* See LICENSE in the root of the software repository for the full text of the License.
*/
/*!
* \file lightning_indexer_service_vector.h
* \brief
*/
#ifndef LIGHTNING_INDEXER_SERVICE_VECTOR_H
#define LIGHTNING_INDEXER_SERVICE_VECTOR_H
#include "kernel_operator.h"
#include "kernel_operator_list_tensor_intf.h"
#include "kernel_tiling/kernel_tiling.h"
#include "lib/matmul_intf.h"
#include "lib/matrix/matmul/tiling.h"
#include "lightning_indexer_common.h"
#include "lightning_indexer_vector.h"
namespace LIKernel {
using namespace LICommon;
using namespace LIServiceVec;
constexpr uint32_t BASE_TOPK = 2048;
constexpr uint32_t LD_PARAM_NUM = 16;
template <typename LIT>
class LIVector {
public:
using K_T = typename LIT::keyType;
static constexpr LI_LAYOUT LAYOUT_T = LIT::layout;
using MM1_OUT_T = float;
__aicore__ inline LIVector(){};
__aicore__ inline void ProcessVec(const LICommon::RunInfo &info);
__aicore__ inline void ProcessLD();
__aicore__ inline void InitBuffers(TPipe *pipe);
__aicore__ inline void InitParams(const struct LICommon::ConstInfo &constInfo,
const LITilingData *__restrict tilingData);
__aicore__ inline void InitVec1GlobalTensor(GlobalTensor<MM1_OUT_T> mm1ResGm, GlobalTensor<float> vec1ResGm,
GlobalTensor<int64_t> vec1ParamGm, GlobalTensor<K_T> weightsGm,
GlobalTensor<int32_t> indiceOutGm);
__aicore__ inline void CleanInvalidOutput(int64_t invalidS1offset);
__aicore__ inline void AllocEventID();
__aicore__ inline void FreeEventID();
__aicore__ inline void InitLDBuffers(TPipe *pipe);
protected:
GlobalTensor<MM1_OUT_T> mm1ResGm;
GlobalTensor<float> vec1ResGm;
GlobalTensor<int64_t> vec1ParamGm;
GlobalTensor<K_T> weightsGm;
GlobalTensor<int32_t> indiceOutGm;
private:
// queue
TQue<QuePosition::VECIN, 1> inQueue_;
TQue<QuePosition::VECOUT, 1> outQueue_;
// tmp buff for vector
TBuf<TPosition::VECCALC> sortOutBuf_;
TBuf<TPosition::VECCALC> indexBuf_;
TBuf<TPosition::VECCALC> reduceOutBuf_;
TBuf<TPosition::VECCALC> brcBuf_;
TBuf<TPosition::VECCALC> paramBuf_;
// tmp buff for LD
TBuf<> ldToBeMrgBuf_;
TBuf<> ldTmpBuf_;
TBuf<> ldOutValueBuf_;
TBuf<> ldOutIdxBuf_;
LocalTensor<int32_t> globalTopkIndice_;
LocalTensor<float> globalTopkUb_;
LocalTensor<float> SortedBasicBlock_;
int32_t blockId_ = -1;
// para for vector
int32_t groupInner_ = 0;
int32_t globalTopkNum_ = 0;
int64_t blockS2StartIdx_ = 0;
int32_t gSize_ = 0;
int32_t kHeadNum_ = 0;
int32_t s1BaseSize_ = 0;
int32_t s2BaseSize_ = 0;
// para for LD
uint32_t mrgListNum_ = 4;
uint32_t paramNum_ = 16;
constexpr static uint32_t REDUCE_BANK_CONFLICT_OFFSETS = 256;
constexpr static uint32_t REDUCE_BANK_CONFLICT_NUM = REDUCE_BANK_CONFLICT_OFFSETS / sizeof(float);
struct LICommon::ConstInfo constInfo_;
};
template <typename LIT>
__aicore__ inline void LIVector<LIT>::InitBuffers(TPipe *pipe)
{
uint32_t outNeedBufSize = (BASE_TOPK * 2) * 2 * sizeof(float);
uint32_t reduceCacheSize = REDUCE_BANK_CONFLICT_OFFSETS + groupInner_ * s2BaseSize_ * sizeof(float);
outNeedBufSize = reduceCacheSize > outNeedBufSize ? reduceCacheSize : outNeedBufSize;
pipe->InitBuffer(inQueue_, 2,
groupInner_ * s2BaseSize_ * sizeof(float) + s2BaseSize_ * sizeof(float)); // 69KB mm_out_ub
pipe->InitBuffer(outQueue_, 1, outNeedBufSize); // 32KB extract
pipe->InitBuffer(sortOutBuf_, CeilDiv(s1BaseSize_, 2) * BASE_TOPK * 2 * sizeof(float)); // 64KB
pipe->InitBuffer(indexBuf_, s2BaseSize_ * sizeof(int32_t)); // 2KB
pipe->InitBuffer(reduceOutBuf_, s2BaseSize_ * 2 * sizeof(float)); // 4KB
pipe->InitBuffer(brcBuf_, groupInner_ * 8 * sizeof(float));
pipe->InitBuffer(paramBuf_, LD_PARAM_NUM * sizeof(int64_t));
//
globalTopkIndice_ = indexBuf_.Get<int32_t>();
globalTopkUb_ = sortOutBuf_.Get<float>();
SortedBasicBlock_ = globalTopkUb_[BASE_TOPK * 2 * 2];
globalTopkNum_ = 0;
ArithProgression<int32_t>(globalTopkIndice_, 0, 1, s2BaseSize_);
InitSortOutBuf(globalTopkUb_, CeilDiv(s1BaseSize_, 2) * BASE_TOPK * 2);
LocalTensor<float> tmpfBuff = outQueue_.AllocTensor<float>();
Duplicate(tmpfBuff.template ReinterpretCast<int32_t>(), -1, 2 * (s1BaseSize_ / 2) * paramNum_ * 2);
SetWaitFlag<HardEvent::V_MTE3>(HardEvent::V_MTE3);
int64_t wsInfoOffset = (blockId_ / 2) * s1BaseSize_ * 2 * paramNum_ +
(blockId_ % 2) * (s1BaseSize_ / 2) * 2 * paramNum_;
DataCopyPad(vec1ParamGm[wsInfoOffset], tmpfBuff.template ReinterpretCast<int64_t>(),
{1, static_cast<uint16_t>((s1BaseSize_ / 2) * 2 * paramNum_ * sizeof(int64_t)), 0, 0});
SetWaitFlag<HardEvent::MTE3_V>(HardEvent::MTE3_V);
outQueue_.FreeTensor(tmpfBuff);
}
template <typename LIT>
__aicore__ inline void LIVector<LIT>::InitLDBuffers(TPipe *pipe)
{
pipe->Reset();
pipe->InitBuffer(ldToBeMrgBuf_, 2 * BASE_TOPK * mrgListNum_ * sizeof(float)); // 2value + index
pipe->InitBuffer(ldTmpBuf_, 2 * BASE_TOPK * mrgListNum_ * sizeof(float)); // 2value + index
pipe->InitBuffer(ldOutValueBuf_, BASE_TOPK * sizeof(float));
pipe->InitBuffer(ldOutIdxBuf_, BASE_TOPK * sizeof(int32_t));
}
template <typename LIT>
__aicore__ inline void LIVector<LIT>::InitParams(const struct LICommon::ConstInfo &constInfo,
const LITilingData *__restrict tilingData)
{
this->constInfo_ = constInfo;
blockS2StartIdx_ = 0;
gSize_ = constInfo.gSize;
// define N2 para
kHeadNum_ = constInfo.kHeadNum;
// define MMBase para
s1BaseSize_ = constInfo.s1BaseSize;
s2BaseSize_ = constInfo.s2BaseSize;
groupInner_ = 16;
blockId_ = GetBlockIdx();
}
template <typename LIT>
__aicore__ inline void
LIVector<LIT>::InitVec1GlobalTensor(GlobalTensor<MM1_OUT_T> mm1ResGm, GlobalTensor<float> vec1ResGm,
GlobalTensor<int64_t> vec1ParamGm, GlobalTensor<K_T> weightsGm,
GlobalTensor<int32_t> indiceOutGm)
{
this->mm1ResGm = mm1ResGm;
this->vec1ResGm = vec1ResGm;
this->vec1ParamGm = vec1ParamGm;
this->weightsGm = weightsGm;
this->indiceOutGm = indiceOutGm;
}
template <typename LIT>
__aicore__ inline void LIVector<LIT>::AllocEventID()
{
}
template <typename LIT>
__aicore__ inline void LIVector<LIT>::FreeEventID()
{
}
template <typename LIT>
__aicore__ inline void LIVector<LIT>::CleanInvalidOutput(int64_t invalidS1offset)
{
// init -1 and copy to output
LocalTensor<float> valueULocal = outQueue_.AllocTensor<float>();
LocalTensor<int32_t> idxULocal1 = valueULocal.template ReinterpretCast<int32_t>();
Duplicate(idxULocal1, constInfo_.INVALID_IDX, constInfo_.sparseCount);
outQueue_.EnQue<float>(valueULocal);
valueULocal = outQueue_.DeQue<float>();
LIServiceVec::CopyOut(indiceOutGm[invalidS1offset], idxULocal1, constInfo_.sparseCount);
outQueue_.FreeTensor(valueULocal);
}
template <typename LIT>
__aicore__ inline void LIVector<LIT>::ProcessVec(const LICommon::RunInfo &info)
{
int32_t cuBaseS1Idx = info.gS1Idx * s1BaseSize_;
int32_t cuBaseS2Idx = info.s2Idx * s2BaseSize_;
int64_t mmGmOffset = (info.loop % 2) * ((s1BaseSize_ * gSize_) * s2BaseSize_);
int64_t weightGmOffset = info.tensorWeightsOffset + cuBaseS1Idx * kHeadNum_ * gSize_;
PipeBarrier<PIPE_V>();
int32_t cuS1BeginIdxPerAiv = cuBaseS1Idx;
int32_t cuS1ProcNum =
cuS1BeginIdxPerAiv + s1BaseSize_ > info.actS1Size ? info.actS1Size % s1BaseSize_ : s1BaseSize_;
int32_t cuS1ProcNumPerAiv = blockId_ % 2 == 0 ? CeilDiv(cuS1ProcNum, 2) : (cuS1ProcNum / 2);
cuS1BeginIdxPerAiv += (blockId_ % 2) * CeilDiv(cuS1ProcNum, 2);
weightGmOffset += (blockId_ % 2) * CeilDiv(cuS1ProcNum, 2) * kHeadNum_ * gSize_;
mmGmOffset += (blockId_ % 2) * CeilDiv(cuS1ProcNum, 2) * gSize_ * info.actualSingleProcessSInnerSizeAlign;
// cut G
int32_t outerG = CeilDiv(gSize_, groupInner_);
if (info.loop != 0 && info.s2Idx == 0) {
// globalTopkUb_ value,index=-inf,-1
InitSortOutBuf(globalTopkUb_, CeilDiv(s1BaseSize_, 2) * BASE_TOPK * 2);
blockS2StartIdx_ = 0;
} else if (info.loop == 0) {
blockS2StartIdx_ = info.s2Idx;
}
int32_t cuRealAcSeq = info.actS2Size;
if (constInfo_.attenMaskFlag) {
cuRealAcSeq = info.actS2Size - (info.actS1Size - cuS1BeginIdxPerAiv);
}
LocalTensor<float> reduceOutBuff = reduceOutBuf_.Get<float>();
LocalTensor<float> brcBuf = brcBuf_.Get<float>();
uint32_t ldS1Offset = (blockId_ % 2 == 0) ? s1BaseSize_ / 2 - cuS1ProcNumPerAiv : 0;
for (int innerS1Idx = 0; innerS1Idx < cuS1ProcNumPerAiv; innerS1Idx++) {
if (constInfo_.attenMaskFlag) {
cuRealAcSeq += 1;
}
int32_t cuS2Len = cuBaseS2Idx + s2BaseSize_ >= cuRealAcSeq ? cuRealAcSeq - cuBaseS2Idx : s2BaseSize_;
int32_t cuS1Idx = cuS1BeginIdxPerAiv + innerS1Idx;
if (cuRealAcSeq > 0 && cuS2Len > 0) {
int32_t cuS2LenVecAlign = CeilDiv(cuS2Len, s2BaseSize_) * s2BaseSize_;
int32_t mmUbStride = (cuS2LenVecAlign - info.actualSingleProcessSInnerSizeAlign) / B32_BLOCK_ALIGN_NUM;
LocalTensor<float> reduceOutInner = reduceOutBuff[s2BaseSize_];
PipeBarrier<PIPE_V>();
LocalTensor<float> reduceCacheBuf = outQueue_.AllocTensor<float>();
for (int outerGidx = 0; outerGidx < outerG; outerGidx++) {
int32_t procGnum = outerGidx != outerG - 1 ? groupInner_ : gSize_ - outerGidx * groupInner_;
LocalTensor<float> mmInUb = inQueue_.AllocTensor<float>();
LocalTensor<float> weightsInUb = mmInUb[procGnum * s2BaseSize_];
LocalTensor<K_T> weightsInTUb = weightsInUb.template ReinterpretCast<K_T>();
if constexpr (!IsSameType<K_T, float>::value) {
weightsInTUb = weightsInTUb[groupInner_];
}
LIServiceVec::CopyIn(mmInUb, weightsInTUb, mm1ResGm, weightsGm,
mmGmOffset + innerS1Idx * gSize_ * info.actualSingleProcessSInnerSizeAlign +
outerGidx * groupInner_ * info.actualSingleProcessSInnerSizeAlign,
weightGmOffset + innerS1Idx * gSize_ + outerGidx * groupInner_, procGnum,
info.actualSingleProcessSInnerSizeAlign, mmUbStride);
inQueue_.EnQue<float>(mmInUb);
mmInUb = inQueue_.DeQue<float>();
weightsInUb = mmInUb[procGnum * s2BaseSize_];
LIServiceVec::DoScale(reduceCacheBuf[REDUCE_BANK_CONFLICT_NUM], mmInUb, weightsInUb, weightsInTUb,
brcBuf, procGnum, s2BaseSize_, outerGidx);
// confused reduceOp in DoScale
// neednot use LIServiceVec::doReduce(mmInUb, reduceOutInner, procGnum, (s2BaseSize_+8));
inQueue_.FreeTensor(mmInUb);
}
int32_t gRedCnt = groupInner_ > gSize_ ? gSize_ : groupInner_;
bool isS2End = cuBaseS2Idx + s2BaseSize_ >= cuRealAcSeq;
LIServiceVec::DoReduce(reduceCacheBuf[REDUCE_BANK_CONFLICT_NUM], reduceOutInner, gRedCnt, s2BaseSize_);
outQueue_.FreeTensor(reduceCacheBuf);
LocalTensor<float> sortScoreUb = reduceOutBuff;
LocalTensor<float> sortIndiceUb = reduceOutBuff[cuS2LenVecAlign];
PipeBarrier<PIPE_V>();
Duplicate(sortScoreUb.template ReinterpretCast<int32_t>(), LIServiceVec::NEG_INF, cuS2LenVecAlign);
PipeBarrier<PIPE_V>();
Adds(sortScoreUb, reduceOutInner, 0.0f, cuS2Len);
PipeBarrier<PIPE_V>();
LocalTensor<int32_t> sortIndiceUbInt = sortIndiceUb.template ReinterpretCast<int32_t>();
if (cuS2LenVecAlign != cuS2Len) {
Duplicate(sortIndiceUbInt, -1, cuS2LenVecAlign);
}
PipeBarrier<PIPE_V>();
Adds(sortIndiceUbInt, globalTopkIndice_, static_cast<int32_t>(cuBaseS2Idx), cuS2Len);
PipeBarrier<PIPE_V>();
LocalTensor<float> tmpSortBuf = outQueue_.AllocTensor<float>();
if (info.actS1Size > 4) {
LIServiceVec::SortAll(reduceOutBuff, tmpSortBuf,
cuS2LenVecAlign); // cuS2LenVecAlign <= s2BaseSize_, fill -inf
PipeBarrier<PIPE_V>();
LIServiceVec::MergeSort(globalTopkUb_[innerS1Idx * BASE_TOPK * 2], BASE_TOPK, reduceOutBuff,
cuS2LenVecAlign, tmpSortBuf);
} else {
int64_t globalTopkUbCacheIdx = (info.s2Idx - blockS2StartIdx_) % 4;
Sort<float, true>(
SortedBasicBlock_[innerS1Idx * BASE_TOPK * 2 + globalTopkUbCacheIdx * s2BaseSize_ * 2],
reduceOutBuff, sortIndiceUbInt.template ReinterpretCast<uint32_t>(), tmpSortBuf,
cuS2LenVecAlign / 32);
if (globalTopkUbCacheIdx == 3 || isS2End || info.isAllLoopEnd) {
LocalTensor<float> tt = SortedBasicBlock_[innerS1Idx * BASE_TOPK * 2];
if (info.s2Idx - blockS2StartIdx_ < 4) {
MrgBasicBlock(globalTopkUb_[innerS1Idx * BASE_TOPK * 2], tt,
static_cast<int64_t>(globalTopkUbCacheIdx + 1), s2BaseSize_);
} else {
if (globalTopkUbCacheIdx > 0) {
MrgBasicBlock(tmpSortBuf, tt, static_cast<int64_t>(globalTopkUbCacheIdx + 1), s2BaseSize_);
PipeBarrier<PIPE_V>();
DataCopy(SortedBasicBlock_[innerS1Idx * BASE_TOPK * 2], tmpSortBuf,
(globalTopkUbCacheIdx + 1) * s2BaseSize_ * 2);
}
PipeBarrier<PIPE_V>();
SparseTopK(globalTopkUb_[innerS1Idx * BASE_TOPK * 2],
SortedBasicBlock_[innerS1Idx * BASE_TOPK * 2], tmpSortBuf, BASE_TOPK,
s2BaseSize_ * (globalTopkUbCacheIdx + 1));
}
}
}
PipeBarrier<PIPE_V>();
outQueue_.FreeTensor(tmpSortBuf);
bool needCopyOutGm = blockS2StartIdx_ == 0 && isS2End;
bool needCopyWsGm = info.isAllLoopEnd || isS2End;
if (needCopyOutGm) {
LocalTensor<float> valueULocal = outQueue_.AllocTensor<float>();
LocalTensor<uint32_t> idxULocal = valueULocal.template ReinterpretCast<uint32_t>()[BASE_TOPK];
ExtractIndex(idxULocal, globalTopkUb_[innerS1Idx * BASE_TOPK * 2].template ReinterpretCast<uint32_t>(),
BASE_TOPK);
PipeBarrier<PIPE_V>();
InitSortOutBuf(globalTopkUb_[innerS1Idx * BASE_TOPK * 2], BASE_TOPK * 2);
outQueue_.EnQue<float>(valueULocal);
valueULocal = outQueue_.DeQue<float>();
LocalTensor<int32_t> idxULocal1 = valueULocal.template ReinterpretCast<int32_t>()[BASE_TOPK];
LIServiceVec::CopyOut(indiceOutGm[info.indiceOutOffset + cuS1Idx * constInfo_.sparseCount],
idxULocal1, constInfo_.sparseCount);
outQueue_.FreeTensor(valueULocal);
} else if (needCopyWsGm) {
// vec1Res Gm = [aic, s1BaseSize_, 2, 2, topkOut_] float32
// vec1Param Gm = [aic, s1BaseSize_, 2, 16] int64
// 16 = [needFd, s2AcSeq, s2Start, s2End, isS2End, bn2idx, s1Idx, S1ProcNum, ......]
int64_t wsOffset = (blockId_ / 2) * s1BaseSize_ * 2 * 2 * BASE_TOPK +
(blockId_ % 2) * (s1BaseSize_ / 2) * 2 * 2 * BASE_TOPK +
(ldS1Offset + innerS1Idx) * 2 * 2 * BASE_TOPK;
int64_t wsInfoOffset = (blockId_ / 2) * s1BaseSize_ * 2 * paramNum_ +
(blockId_ % 2) * (s1BaseSize_ / 2) * 2 * paramNum_ +
(ldS1Offset + innerS1Idx) * 2 * paramNum_;
LocalTensor<int64_t> tmpiBuff = paramBuf_.Get<int64_t>();
SetWaitFlag<HardEvent::MTE3_S>(HardEvent::MTE3_S);
tmpiBuff.SetValue(0, static_cast<int64_t>(1));
tmpiBuff.SetValue(1, static_cast<int64_t>(cuRealAcSeq));
tmpiBuff.SetValue(2, static_cast<int64_t>(blockS2StartIdx_));
tmpiBuff.SetValue(3, static_cast<int64_t>(cuBaseS2Idx + cuS2Len));
tmpiBuff.SetValue(4, static_cast<int64_t>(isS2End));
tmpiBuff.SetValue(5, static_cast<int64_t>(info.bN2Idx));
tmpiBuff.SetValue(6, static_cast<int64_t>(cuS1Idx));
tmpiBuff.SetValue(7, static_cast<int64_t>(cuS1ProcNum));
tmpiBuff.SetValue(8, static_cast<int64_t>(info.indiceOutOffset + cuS1Idx * constInfo_.sparseCount));
bool isTailReduce = blockS2StartIdx_ == 0;
if (isTailReduce) {
wsInfoOffset += paramNum_;
wsOffset += 2 * BASE_TOPK;
}
SetWaitFlag<HardEvent::S_MTE3>(HardEvent::S_MTE3);
LIServiceVec::CopyOut(vec1ParamGm[wsInfoOffset], tmpiBuff, 16);
SetWaitFlag<HardEvent::V_MTE3>(HardEvent::V_MTE3);
LIServiceVec::CopyOut(vec1ResGm[wsOffset], globalTopkUb_[innerS1Idx * BASE_TOPK * 2], 2 * BASE_TOPK);
SetWaitFlag<HardEvent::MTE3_V>(HardEvent::MTE3_V);
}
} else if (cuRealAcSeq <= 0) {
CleanInvalidOutput(info.indiceOutOffset + cuS1Idx * constInfo_.sparseCount);
}
}
if (LAYOUT_T == LI_LAYOUT::BSND) {
bool isS1LoopEnd = (cuBaseS1Idx + s1BaseSize_) >= info.actS1Size;
int32_t invalidS1Num = constInfo_.qSeqSize - info.actS1Size;
if (invalidS1Num > 0 && isS1LoopEnd && blockS2StartIdx_ == 0) {
int32_t s1NumPerAiv = blockId_ % 2 == 0 ? CeilDiv(invalidS1Num, 2) : (invalidS1Num / 2);
int32_t s1OffsetPerAiv = info.actS1Size + (blockId_ % 2) * CeilDiv(invalidS1Num, 2);
for (int innerS1Idx = 0; innerS1Idx < s1NumPerAiv; innerS1Idx++) {
CleanInvalidOutput(info.indiceOutOffset + (s1OffsetPerAiv + innerS1Idx) * constInfo_.sparseCount);
}
}
int32_t invalidS1Num2 = info.actS1Size - info.actS2Size;
if (invalidS1Num2 > 0 && isS1LoopEnd && blockS2StartIdx_ == 0 && constInfo_.attenMaskFlag) {
int32_t s1NumPerAiv = blockId_ % 2 == 0 ? CeilDiv(invalidS1Num2, 2) : (invalidS1Num2 / 2);
int32_t s1OffsetPerAiv = (blockId_ % 2) * CeilDiv(invalidS1Num2, 2);
for (int innerS1Idx = 0; innerS1Idx < s1NumPerAiv; innerS1Idx++) {
CleanInvalidOutput((info.bN2Idx * constInfo_.qSeqSize + s1OffsetPerAiv + innerS1Idx) *
constInfo_.sparseCount);
}
}
}
if (info.isLastS2InnerLoop) {
blockS2StartIdx_ = 0;
}
}
template <typename LIT>
__aicore__ inline void LIVector<LIT>::ProcessLD()
{
int32_t curCubeId = blockId_ / 2;
int32_t tmpCubeId = curCubeId;
int64_t s2ActSeq;
int64_t s2Start;
int64_t s2End;
int64_t isS2End;
int64_t bn2Idx;
int64_t s1Idx;
uint32_t acc_list_num = 0;
int64_t bIdx = 0;
int64_t needFd;
int64_t wsOffset;
int64_t wsInfoOffset = 0;
int64_t nextneedFd;
int64_t valueOffset = 0;
int64_t outOffset = 0;
LocalTensor<float> curValueIdxUb = ldToBeMrgBuf_.Get<float>();
LocalTensor<float> tmpUb = ldTmpBuf_.Get<float>();
uint32_t s1LdStartIdx = 0;
uint32_t s1ProcNum = 0;
uint64_t paramGmCoreOffset = tmpCubeId * s1BaseSize_ * 2 * paramNum_;
for (uint32_t innerS1Idx = 0; innerS1Idx < s1BaseSize_; innerS1Idx++) {
needFd = vec1ParamGm.GetValue(paramGmCoreOffset + innerS1Idx * 2 * paramNum_ + paramNum_);
if (needFd == 1) {
s1LdStartIdx = (s1ProcNum == 0) ? innerS1Idx : s1LdStartIdx;
s1ProcNum++;
}
}
if (s1ProcNum == 0) {
return;
}
uint32_t s1VecNum = CeilDiv(s1ProcNum, 2);
if (blockId_ % 2 == 1) {
s1LdStartIdx = s1LdStartIdx + s1VecNum;
s1VecNum = s1ProcNum - s1VecNum;
}
for (uint32_t innerS1Idx = s1LdStartIdx; innerS1Idx < s1LdStartIdx + s1VecNum; innerS1Idx++) {
tmpCubeId = curCubeId;
acc_list_num = 0;
valueOffset = 0;
wsOffset = tmpCubeId * s1BaseSize_ * 2 * 2 * BASE_TOPK +
innerS1Idx * 2 * 2 * BASE_TOPK + 2 * BASE_TOPK;
SetWaitFlag<HardEvent::V_MTE2>(HardEvent::V_MTE2);
SetWaitFlag<HardEvent::S_MTE2>(HardEvent::S_MTE2);
DataCopyPad(curValueIdxUb, vec1ResGm[wsOffset],
{1, static_cast<uint16_t>(2 * BASE_TOPK * sizeof(int32_t)), 0, 0}, {true, 0, 0, 0});
acc_list_num++;
valueOffset += 2 * BASE_TOPK;
tmpCubeId++;
wsInfoOffset = tmpCubeId * s1BaseSize_ * 2 * paramNum_ + innerS1Idx * 2 * paramNum_;
needFd = vec1ParamGm.GetValue(wsInfoOffset);
isS2End = vec1ParamGm.GetValue(wsInfoOffset + 4);
s1Idx = vec1ParamGm.GetValue(wsInfoOffset + 6);
outOffset = vec1ParamGm.GetValue(wsInfoOffset + 8);
while (needFd == 1) {
wsOffset = tmpCubeId * s1BaseSize_ * 2 * 2 * BASE_TOPK +
innerS1Idx * 2 * 2 * BASE_TOPK;
SetWaitFlag<HardEvent::V_MTE2>(HardEvent::V_MTE2);
SetWaitFlag<HardEvent::S_MTE2>(HardEvent::S_MTE2);
DataCopyPad(curValueIdxUb[valueOffset], vec1ResGm[wsOffset],
{1, static_cast<uint16_t>(2 * BASE_TOPK * sizeof(int32_t)), 0, 0}, {true, 0, 0, 0});
valueOffset += 2 * BASE_TOPK;
acc_list_num++;
if (acc_list_num == mrgListNum_) {
AscendC::MrgSort4Info params;
params.elementLengths[0] = BASE_TOPK;
params.elementLengths[1] = BASE_TOPK;
params.elementLengths[2] = BASE_TOPK;
params.elementLengths[3] = BASE_TOPK;
params.ifExhaustedSuspension = true;
params.validBit = 0b1111;
params.repeatTimes = 1;
AscendC::MrgSortSrcList<float> srcList;
srcList.src1 = curValueIdxUb[0];
srcList.src2 = curValueIdxUb[2 * BASE_TOPK];
srcList.src3 = curValueIdxUb[4 * BASE_TOPK];
srcList.src4 = curValueIdxUb[6 * BASE_TOPK];
SetWaitFlag<HardEvent::MTE2_V>(HardEvent::MTE2_V);
MrgSort(tmpUb, srcList, params);
PipeBarrier<PIPE_V>();
DataCopy(curValueIdxUb, tmpUb, 2 * BASE_TOPK);
PipeBarrier<PIPE_V>();
acc_list_num = 1;
valueOffset = 2 * BASE_TOPK;
}
if (isS2End == 1) {
break;
}
tmpCubeId++;
wsInfoOffset = tmpCubeId * s1BaseSize_ * 2 * paramNum_ + innerS1Idx * 2 * paramNum_;
needFd = vec1ParamGm.GetValue(wsInfoOffset);
isS2End = vec1ParamGm.GetValue(wsInfoOffset + 4);
}
if (acc_list_num != 1) {
AscendC::MrgSort4Info params;
params.elementLengths[0] = BASE_TOPK;
params.elementLengths[1] = BASE_TOPK;
params.elementLengths[2] = BASE_TOPK;
params.elementLengths[3] = BASE_TOPK;
params.ifExhaustedSuspension = true;
if (acc_list_num == 2) {
params.validBit = 0b0011;
} else if (acc_list_num == 3) {
params.validBit = 0b0111;
}
params.repeatTimes = 1;
AscendC::MrgSortSrcList<float> srcList;
srcList.src1 = curValueIdxUb[0];
srcList.src2 = curValueIdxUb[2 * BASE_TOPK];
srcList.src3 = curValueIdxUb[4 * BASE_TOPK];
srcList.src4 = curValueIdxUb[6 * BASE_TOPK];
SetWaitFlag<HardEvent::MTE2_V>(HardEvent::MTE2_V);
MrgSort(tmpUb, srcList, params);
PipeBarrier<PIPE_V>();
DataCopy(curValueIdxUb, tmpUb, 2 * BASE_TOPK);
PipeBarrier<PIPE_V>();
}
LocalTensor<float> outValueUb = ldOutValueBuf_.Get<float>();
LocalTensor<uint32_t> outIdxUb = ldOutIdxBuf_.Get<uint32_t>();
Extract(outValueUb, outIdxUb, curValueIdxUb, (BASE_TOPK / 32));
LocalTensor<int32_t> idxULocal1 = outIdxUb.template ReinterpretCast<int32_t>();
SetWaitFlag<HardEvent::V_MTE3>(HardEvent::V_MTE3);
SetWaitFlag<HardEvent::S_MTE3>(HardEvent::S_MTE3);
DataCopyPad(indiceOutGm[outOffset], idxULocal1,
{1, static_cast<uint16_t>(constInfo_.sparseCount * sizeof(int32_t)), 0, 0});
SetWaitFlag<HardEvent::MTE3_V>(HardEvent::MTE3_V);
}
}
} // namespace LIKernel
#endif

View File

@@ -0,0 +1,66 @@
/**
* This program is free software, you can redistribute it and/or modify it.
* Copyright (c) 2025 Huawei Technologies Co., Ltd.
* This file is a part of the CANN Open Software.
* Licensed under CANN Open Software License Agreement Version 2.0 (the "License").
* Please refer to the License for details. You may not use this file except in compliance with the License.
* THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE.
* See LICENSE in the root of the software repository for the full text of the License.
*/
/*!
* \file lightning_indexer_template_tiling_key.h
* \brief
*/
#ifndef TEMPLATE_TILING_KEY_LI_H_
#define TEMPLATE_TILING_KEY_LI_H_
#include "ascendc/host_api/tiling/template_argument.h"
#define LI_TPL_FP16 1
#define LI_TPL_INT32 3
#define LI_TPL_BF16 27
#define LI_LAYOUT_BSND 0
#define LI_LAYOUT_TND 1
#define LI_LAYOUT_PA_BSND 2
#define ASCENDC_TPL_4_BW 4
ASCENDC_TPL_ARGS_DECL(LightningIndexerVllm,
ASCENDC_TPL_DTYPE_DECL(DT_Q, LI_TPL_FP16, LI_TPL_BF16),
ASCENDC_TPL_DTYPE_DECL(DT_K, LI_TPL_FP16, LI_TPL_BF16),
ASCENDC_TPL_DTYPE_DECL(DT_OUT, LI_TPL_INT32), ASCENDC_TPL_BOOL_DECL(PAGE_ATTENTION, 0, 1),
ASCENDC_TPL_UINT_DECL(LAYOUT_T, ASCENDC_TPL_4_BW, ASCENDC_TPL_UI_LIST, LI_LAYOUT_BSND,
LI_LAYOUT_TND),
ASCENDC_TPL_UINT_DECL(K_LAYOUT_T, ASCENDC_TPL_4_BW, ASCENDC_TPL_UI_LIST,
LI_LAYOUT_PA_BSND, LI_LAYOUT_BSND, LI_LAYOUT_TND), );
ASCENDC_TPL_SEL(
ASCENDC_TPL_ARGS_SEL(ASCENDC_TPL_DTYPE_SEL(DT_Q, LI_TPL_FP16), ASCENDC_TPL_DTYPE_SEL(DT_K, LI_TPL_FP16),
ASCENDC_TPL_DTYPE_SEL(DT_OUT, LI_TPL_INT32),
ASCENDC_TPL_BOOL_SEL(PAGE_ATTENTION, 1),
ASCENDC_TPL_UINT_SEL(LAYOUT_T, ASCENDC_TPL_UI_LIST, LI_LAYOUT_BSND, LI_LAYOUT_TND),
ASCENDC_TPL_UINT_SEL(K_LAYOUT_T, ASCENDC_TPL_UI_LIST, LI_LAYOUT_PA_BSND), ),
ASCENDC_TPL_ARGS_SEL(ASCENDC_TPL_DTYPE_SEL(DT_Q, LI_TPL_BF16), ASCENDC_TPL_DTYPE_SEL(DT_K, LI_TPL_BF16),
ASCENDC_TPL_DTYPE_SEL(DT_OUT, LI_TPL_INT32),
ASCENDC_TPL_BOOL_SEL(PAGE_ATTENTION, 1),
ASCENDC_TPL_UINT_SEL(LAYOUT_T, ASCENDC_TPL_UI_LIST, LI_LAYOUT_BSND, LI_LAYOUT_TND),
ASCENDC_TPL_UINT_SEL(K_LAYOUT_T, ASCENDC_TPL_UI_LIST, LI_LAYOUT_PA_BSND), ),
ASCENDC_TPL_ARGS_SEL(ASCENDC_TPL_DTYPE_SEL(DT_Q, LI_TPL_FP16), ASCENDC_TPL_DTYPE_SEL(DT_K, LI_TPL_FP16),
ASCENDC_TPL_DTYPE_SEL(DT_OUT, LI_TPL_INT32),
ASCENDC_TPL_BOOL_SEL(PAGE_ATTENTION, 0),
ASCENDC_TPL_UINT_SEL(LAYOUT_T, ASCENDC_TPL_UI_LIST, LI_LAYOUT_BSND, LI_LAYOUT_TND),
ASCENDC_TPL_UINT_SEL(K_LAYOUT_T, ASCENDC_TPL_UI_LIST,
LI_LAYOUT_BSND, LI_LAYOUT_TND), ),
ASCENDC_TPL_ARGS_SEL(ASCENDC_TPL_DTYPE_SEL(DT_Q, LI_TPL_BF16), ASCENDC_TPL_DTYPE_SEL(DT_K, LI_TPL_BF16),
ASCENDC_TPL_DTYPE_SEL(DT_OUT, LI_TPL_INT32),
ASCENDC_TPL_BOOL_SEL(PAGE_ATTENTION, 0),
ASCENDC_TPL_UINT_SEL(LAYOUT_T, ASCENDC_TPL_UI_LIST, LI_LAYOUT_BSND, LI_LAYOUT_TND),
ASCENDC_TPL_UINT_SEL(K_LAYOUT_T, ASCENDC_TPL_UI_LIST, LI_LAYOUT_BSND, LI_LAYOUT_TND), ), );
#endif

View File

@@ -0,0 +1,335 @@
/**
* This program is free software, you can redistribute it and/or modify it.
* Copyright (c) 2025 Huawei Technologies Co., Ltd.
* This file is a part of the CANN Open Software.
* Licensed under CANN Open Software License Agreement Version 2.0 (the "License").
* Please refer to the License for details. You may not use this file except in compliance with the License.
* THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE.
* See LICENSE in the root of the software repository for the full text of the License.
*/
/*!
* \file lightning_indexer_vector.h
* \brief
*/
#ifndef LIGHTNING_INDEXER_VECTOR_H
#define LIGHTNING_INDEXER_VECTOR_H
#include "lightning_indexer_vector.h"
#include "kernel_operator.h"
namespace LIServiceVec {
using namespace AscendC;
constexpr int32_t NEG_INF = 0xFF800000;
constexpr int32_t INVALID_INDEX = -1;
constexpr uint8_t VEC_REPEAT_MAX = 255;
constexpr uint8_t B32_VEC_ELM_NUM = 64;
constexpr uint8_t B32_BLOCK_ALIGN_NUM = 8;
constexpr uint8_t B32_VEC_REPEAT_STRIDE = 8;
constexpr uint64_t VEC_REPEAT_BYTES = 256;
constexpr int32_t CONST_TWO = 2;
constexpr int64_t VALUE_AND_INDEX_NUM = 2;
constexpr int64_t BLOCK_BYTES = 32;
constexpr int64_t MRG_QUE_0 = 0;
constexpr int64_t MRG_QUE_1 = 1;
constexpr int64_t MRG_QUE_2 = 2;
constexpr int64_t MRG_QUE_3 = 3;
constexpr int64_t MRG_BLOCK_2 = 2;
constexpr int64_t MRG_BLOCK_3 = 3;
constexpr int64_t MRG_BLOCK_4 = 4;
template <typename T>
__aicore__ inline void CopyIn(LocalTensor<float> &mmOutUb, LocalTensor<T> &weightsUb, GlobalTensor<float> &mMoutGm,
GlobalTensor<T> &weightScaleGm, int64_t MMout_gmoffset, int64_t weights_gmoffset,
int64_t groupInner, int64_t s2Inner, int64_t mmUbStride)
{
AscendC::DataCopyPadExtParams<float> padParams{false, 0, 0, 0};
AscendC::DataCopyExtParams dataCopymMoutParams;
dataCopymMoutParams.blockCount = groupInner;
dataCopymMoutParams.blockLen = s2Inner * sizeof(float);
dataCopymMoutParams.srcStride = 0;
dataCopymMoutParams.dstStride = mmUbStride;
dataCopymMoutParams.rsv = 0;
AscendC::DataCopyPad(mmOutUb, mMoutGm[MMout_gmoffset], dataCopymMoutParams, padParams);
AscendC::DataCopyPadExtParams<T> padTParams{false, 0, 0, 0};
AscendC::DataCopyExtParams dataCopyweightParams;
dataCopyweightParams.blockCount = 1;
dataCopyweightParams.blockLen = groupInner * sizeof(T);
dataCopyweightParams.srcStride = 0;
dataCopyweightParams.dstStride = 0;
dataCopyweightParams.rsv = 0;
AscendC::DataCopyPad(weightsUb, weightScaleGm[weights_gmoffset], dataCopyweightParams, padTParams);
}
template <typename T>
__aicore__ inline void CopyOut(const GlobalTensor<T> &dstGm, const LocalTensor<T> &srcUb, int64_t copyCount)
{
AscendC::DataCopyParams dataCopyOutyParams;
dataCopyOutyParams.blockCount = 1;
dataCopyOutyParams.blockLen = copyCount * sizeof(T);
dataCopyOutyParams.srcStride = 0;
dataCopyOutyParams.dstStride = 0;
AscendC::DataCopyPad(dstGm, srcUb, dataCopyOutyParams);
}
template <typename T>
__aicore__ inline void DoScale(const LocalTensor<float> &reduceCacheBuf, LocalTensor<float> &mmOutUb,
LocalTensor<float> &weightsUb, LocalTensor<T> &weightsTUb, LocalTensor<float> &tmpBuff,
int64_t groupInner, int64_t s2Inner, int32_t outerGidx)
{
// cast bfloat16_t to float
if constexpr (!IsSameType<T, float>::value) {
AscendC::Cast(weightsUb, weightsTUb, RoundMode::CAST_NONE, groupInner);
AscendC::PipeBarrier<PIPE_V>();
}
// weight broadcast: [groupInner, 1] -> [groupInner, 8]
AscendC::Brcb(tmpBuff, weightsUb, LICommon::CeilDiv(groupInner, static_cast<int64_t>(B32_BLOCK_ALIGN_NUM)),
{1, B32_VEC_REPEAT_STRIDE});
AscendC::PipeBarrier<PIPE_V>();
// do scale: [groupInner, 8] * [groupInner, s2Inner]
uint64_t countPerRepeat = VEC_REPEAT_BYTES / sizeof(float);
uint64_t repeatTimes = s2Inner / countPerRepeat;
for (int32_t i = 0; i < groupInner; i++) {
if (outerGidx == 0) {
AscendC::Mul(reduceCacheBuf[i * s2Inner], mmOutUb[i * s2Inner], tmpBuff[i * B32_BLOCK_ALIGN_NUM],
countPerRepeat, repeatTimes, {1, 1, 0, B32_VEC_REPEAT_STRIDE, B32_VEC_REPEAT_STRIDE, 0});
} else {
AscendC::Mul(mmOutUb[i * s2Inner], mmOutUb[i * s2Inner], tmpBuff[i * B32_BLOCK_ALIGN_NUM], countPerRepeat,
repeatTimes, {1, 1, 0, B32_VEC_REPEAT_STRIDE, B32_VEC_REPEAT_STRIDE, 0});
}
}
if (outerGidx != 0) {
AscendC::PipeBarrier<PIPE_V>();
AscendC::Add(reduceCacheBuf, mmOutUb, reduceCacheBuf, groupInner * s2Inner);
}
AscendC::PipeBarrier<PIPE_V>();
}
__aicore__ inline uint64_t FindNearestPower2(uint64_t value)
{
if (value <= CONST_TWO) {
return value;
} else {
const uint64_t pow = 63 - clz(value);
return (1 << pow);
}
}
__aicore__ inline void DoReduce(const LocalTensor<float> &srcTensor, LocalTensor<float> &dstTensor, int32_t rNum,
int32_t aNum)
{
if (rNum == 1) {
AscendC::Adds<float>(dstTensor, srcTensor, 0, aNum);
AscendC::PipeBarrier<PIPE_V>();
return;
}
uint32_t dichotomizeAddPow = FindNearestPower2(rNum);
uint32_t dichotomizeAddDiffSize = rNum - dichotomizeAddPow;
if (dichotomizeAddDiffSize != 0) {
AscendC::Add(srcTensor, srcTensor, srcTensor[dichotomizeAddPow * aNum], dichotomizeAddDiffSize * aNum);
AscendC::PipeBarrier<PIPE_V>();
}
int32_t nowRows = dichotomizeAddPow;
while (nowRows > CONST_TWO) {
nowRows = nowRows / CONST_TWO;
AscendC::Add(srcTensor, srcTensor, srcTensor[nowRows * aNum], nowRows * aNum);
AscendC::PipeBarrier<PIPE_V>();
}
AscendC::Add(dstTensor, srcTensor, srcTensor[aNum], aNum);
AscendC::PipeBarrier<PIPE_V>();
}
__aicore__ inline void InitSortOutBuf(const LocalTensor<float> &src, int64_t eleNum)
{
uint64_t mask1[2] = {0x5555555555555555, 0};
uint64_t mask0[2] = {0xaaaaaaaaaaaaaaaa, 0};
int64_t repeatNum = eleNum / B32_VEC_ELM_NUM;
int64_t forLoop = repeatNum / VEC_REPEAT_MAX;
int64_t forRemain = repeatNum % VEC_REPEAT_MAX;
for (int i = 0; i < forLoop; i++) {
AscendC::Duplicate(src.template ReinterpretCast<int32_t>(), NEG_INF, mask1, VEC_REPEAT_MAX, 1,
B32_VEC_REPEAT_STRIDE);
AscendC::Duplicate(src.template ReinterpretCast<int32_t>(), INVALID_INDEX, mask0, VEC_REPEAT_MAX, 1,
B32_VEC_REPEAT_STRIDE);
}
if (forRemain > 0) {
AscendC::Duplicate(src.template ReinterpretCast<int32_t>()[forLoop * VEC_REPEAT_MAX * B32_VEC_ELM_NUM], NEG_INF,
mask1, forRemain, 1, B32_VEC_REPEAT_STRIDE);
AscendC::Duplicate(src.template ReinterpretCast<int32_t>()[forLoop * VEC_REPEAT_MAX * B32_VEC_ELM_NUM],
INVALID_INDEX, mask0, forRemain, 1, B32_VEC_REPEAT_STRIDE);
}
AscendC::PipeBarrier<PIPE_V>();
}
__aicore__ inline void SortAll(LocalTensor<float> &src, LocalTensor<float> &tmp, int64_t logitsNum)
{
int64_t sort32Repeats = logitsNum / BLOCK_BYTES;
AscendC::Sort32(tmp, src, src[logitsNum].ReinterpretCast<uint32_t>(), sort32Repeats);
AscendC::PipeBarrier<PIPE_V>();
int64_t mrgGroups = sort32Repeats;
int64_t mrgElements = BLOCK_BYTES;
int64_t i = 0;
AscendC::LocalTensor<float> srcTensor;
AscendC::LocalTensor<float> dstTensor;
while (true) {
if (i % CONST_TWO == 0) {
srcTensor = tmp;
dstTensor = src;
} else {
srcTensor = src;
dstTensor = tmp;
}
AscendC::MrgSort4Info params;
params.elementLengths[0] = mrgElements;
params.elementLengths[MRG_QUE_1] = mrgElements;
params.elementLengths[MRG_QUE_2] = mrgElements;
params.elementLengths[MRG_QUE_3] = mrgElements;
params.ifExhaustedSuspension = false;
params.validBit = 0b1111;
AscendC::MrgSortSrcList<float> srcList;
srcList.src1 = srcTensor[0];
srcList.src2 = srcTensor[MRG_QUE_1 * VALUE_AND_INDEX_NUM * mrgElements];
srcList.src3 = srcTensor[MRG_QUE_2 * VALUE_AND_INDEX_NUM * mrgElements];
srcList.src4 = srcTensor[MRG_QUE_3 * VALUE_AND_INDEX_NUM * mrgElements];
if (mrgGroups <= MRG_BLOCK_4) {
params.repeatTimes = 1;
if (mrgGroups == 1) {
break;
} else if (mrgGroups == MRG_BLOCK_2) {
params.validBit = 0b0011;
} else if (mrgGroups == MRG_BLOCK_3) {
params.validBit = 0b0111;
} else if (mrgGroups == MRG_BLOCK_4) {
params.validBit = 0b1111;
}
AscendC::MrgSort<float>(dstTensor, srcList, params);
i += 1;
break;
} else {
params.repeatTimes = mrgGroups / MRG_BLOCK_4;
AscendC::MrgSort<float>(dstTensor, srcList, params);
i += 1;
mrgElements = mrgElements * MRG_BLOCK_4;
mrgGroups = mrgGroups / MRG_BLOCK_4;
}
AscendC::PipeBarrier<PIPE_V>();
}
if (i % CONST_TWO == 0) {
AscendC::DataCopy(src, tmp, logitsNum * VALUE_AND_INDEX_NUM);
AscendC::PipeBarrier<PIPE_V>();
}
}
__aicore__ inline void SortAll(LocalTensor<float> &dst, LocalTensor<float> &srcValue, LocalTensor<uint32_t> &srcIndex,
LocalTensor<float> &tmpTensor, int64_t logitsNum)
{
int64_t sort32Repeats = logitsNum / BLOCK_BYTES;
AscendC::Sort<float, true>(dst, srcValue, srcIndex, tmpTensor, sort32Repeats);
AscendC::PipeBarrier<PIPE_V>();
}
__aicore__ inline void MergeSort(const LocalTensor<float> &mrgDst, int32_t mrgDstNum, LocalTensor<float> &mrgSrc,
int32_t mrgSrcNum, LocalTensor<float> &tmpTensor)
{
AscendC::MrgSort4Info params;
params.elementLengths[0] = mrgDstNum;
params.elementLengths[1] = mrgSrcNum;
params.ifExhaustedSuspension = false;
params.validBit = 0b0011;
params.repeatTimes = 1;
AscendC::MrgSortSrcList<float> srcList;
srcList.src1 = mrgDst;
srcList.src2 = mrgSrc;
AscendC::MrgSort<float>(tmpTensor, srcList, params);
AscendC::PipeBarrier<PIPE_V>();
AscendC::DataCopy(mrgDst, tmpTensor, mrgDstNum * VALUE_AND_INDEX_NUM);
AscendC::PipeBarrier<PIPE_V>();
}
__aicore__ inline void MrgBasicBlock(const LocalTensor<float> &dst, const LocalTensor<float> &src, int64_t blockNum,
int64_t basicBlockSize)
{
AscendC::MrgSort4Info params;
params.elementLengths[MRG_QUE_0] = basicBlockSize;
params.elementLengths[MRG_QUE_1] = basicBlockSize;
params.elementLengths[MRG_QUE_2] = basicBlockSize;
params.elementLengths[MRG_QUE_3] = basicBlockSize;
params.ifExhaustedSuspension = false;
if (blockNum == MRG_BLOCK_2) {
params.validBit = 0b0011;
} else if (blockNum == MRG_BLOCK_3) {
params.validBit = 0b0111;
} else if (blockNum == MRG_BLOCK_4) {
params.validBit = 0b1111;
} else {
AscendC::DataCopy(dst, src, basicBlockSize * VALUE_AND_INDEX_NUM);
return;
}
AscendC::MrgSortSrcList<float> srcList;
srcList.src1 = src[0];
srcList.src2 = src[basicBlockSize * VALUE_AND_INDEX_NUM * MRG_QUE_1];
srcList.src3 = src[basicBlockSize * VALUE_AND_INDEX_NUM * MRG_QUE_2];
srcList.src4 = src[basicBlockSize * VALUE_AND_INDEX_NUM * MRG_QUE_3];
AscendC::MrgSort<float>(dst, srcList, params);
}
template <bool needMrg = true>
__aicore__ inline void SparseTopK(const LocalTensor<float> &dst, const LocalTensor<float> &needsMerging,
const LocalTensor<float> &tmp, int64_t topk, int64_t mergSize)
{
if (!needMrg) {
AscendC::DataCopy(dst, needsMerging, mergSize * VALUE_AND_INDEX_NUM);
return;
}
AscendC::MrgSort4Info params;
params.elementLengths[0] = topk;
params.elementLengths[1] = mergSize;
params.ifExhaustedSuspension = (topk == mergSize);
params.validBit = 0b0011;
AscendC::MrgSortSrcList<float> srcList;
srcList.src1 = dst;
srcList.src2 = needsMerging;
AscendC::MrgSort<float>(tmp, srcList, params);
AscendC::DataCopy(dst, tmp, topk * VALUE_AND_INDEX_NUM);
}
__aicore__ inline void ExtractIndex(const LocalTensor<uint32_t> &idxULocal, const LocalTensor<uint32_t> &sortLocal,
int64_t extractNum)
{
AscendC::GatherMaskParams gatherMaskParams;
gatherMaskParams.repeatTimes = Ceil(extractNum * sizeof(float) * VALUE_AND_INDEX_NUM, VEC_REPEAT_BYTES);
gatherMaskParams.src0BlockStride = 1;
gatherMaskParams.src0RepeatStride = B32_VEC_REPEAT_STRIDE;
gatherMaskParams.src1RepeatStride = 0;
uint64_t rsvdCnt = 0;
uint8_t src1Pattern = 2;
AscendC::GatherMask(idxULocal, sortLocal, src1Pattern, false, static_cast<uint32_t>(0), gatherMaskParams, rsvdCnt);
AscendC::PipeBarrier<PIPE_V>();
}
template <HardEvent event>
__aicore__ inline void SetWaitFlag(HardEvent evt)
{
event_t eventId = static_cast<event_t>(GetTPipePtr()->FetchEventID(evt));
AscendC::SetFlag<event>(eventId);
AscendC::WaitFlag<event>(eventId);
}
} // namespace LIServiceVec
#endif // LIGHTNING_INDEXER_VECTOR_H

View File

@@ -0,0 +1,58 @@
/**
* This program is free software, you can redistribute it and/or modify it.
* Copyright (c) 2025 Huawei Technologies Co., Ltd.
* This file is a part of the CANN Open Software.
* Licensed under CANN Open Software License Agreement Version 2.0 (the "License").
* Please refer to the License for details. You may not use this file except in compliance with the License.
* THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE.
* See LICENSE in the root of the software repository for the full text of the License.
*/
/*!
* \file lightning_indexer.cpp
* \brief
*/
#include "kernel_operator.h"
#include "lib/matmul_intf.h"
#include "lightning_indexer_template_tiling_key.h"
#include "lightning_indexer_kernel.h"
using namespace LIKernel;
#define INVOKE_LI_NO_KFC_OP_IMPL(templateClass, ...) \
do { \
templateClass<LIType<__VA_ARGS__>> op; \
LI_COPY_TILING_DATA(LITilingData, tiling); \
op.Init(query, key, weights, actualSeqLengthsQ, actualSeqLengths, blocktable, sparseIndices, user, \
tiling_data, &tPipe); \
op.Process(); \
} while (0)
#define LI_COPY_TILING_DATA(tilingDataStruct, tiling) \
GET_TILING_DATA_WITH_STRUCT(tilingDataStruct, tiling_data_in, tiling); \
const tilingDataStruct *__restrict tiling_data = &tiling_data_in;
template <int DT_Q, int DT_K, int DT_OUT, int PAGE_ATTENTION, int LAYOUT_T, int K_LAYOUT_T>
__global__ __aicore__ void lightning_indexer_vllm(__gm__ uint8_t *query, __gm__ uint8_t *key, __gm__ uint8_t *weights,
__gm__ uint8_t *actualSeqLengthsQ, __gm__ uint8_t *actualSeqLengths,
__gm__ uint8_t *blocktable, __gm__ uint8_t *sparseIndices,
__gm__ uint8_t *workspace, __gm__ uint8_t *tiling)
{
#if (__CCE_AICORE__ == 310) || (defined __DAV_310R6__) || (__CCE_AICORE__ == 200)
#else
TPipe tPipe;
__gm__ uint8_t *user = GetUserWorkspace(workspace);
KERNEL_TASK_TYPE_DEFAULT(KERNEL_TYPE_MIX_AIC_1_2);
if constexpr (DT_Q == LI_TPL_FP16 && DT_K == LI_TPL_FP16 && DT_OUT == LI_TPL_INT32) {
INVOKE_LI_NO_KFC_OP_IMPL(LIPreload, half, half, int32_t, PAGE_ATTENTION,
LI_LAYOUT(LAYOUT_T), LI_LAYOUT(K_LAYOUT_T));
} else {
INVOKE_LI_NO_KFC_OP_IMPL(LIPreload, bfloat16_t, bfloat16_t, int32_t, PAGE_ATTENTION,
LI_LAYOUT(LAYOUT_T), LI_LAYOUT(K_LAYOUT_T));
}
#endif
}