### What this PR does / why we need it? Provide high-performance AscendC operators lightning_indexer and sparse_flash_attention to boost the execution performance of the DeepSeek v3.2 model. Meanwhile, adapt the two AscendC operators to vllm-ascend framework. ### Does this PR introduce _any_ user-facing change? No (only underlying operator optimizations, with no user-facing changes) ### How was this patch tested? - vLLM version: v0.11.2 - vLLM main: https://github.com/vllm-project/vllm/commit/v0.11.2 Signed-off-by: MingYang119 <songmingyang@huawei.com>
623 lines
26 KiB
C++
/**
 * This program is free software, you can redistribute it and/or modify it.
 * Copyright (c) 2025 Huawei Technologies Co., Ltd.
 * This file is a part of the CANN Open Software.
 * Licensed under CANN Open Software License Agreement Version 2.0 (the "License").
 * Please refer to the License for details. You may not use this file except in compliance with the License.
 * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE.
 * See LICENSE in the root of the software repository for the full text of the License.
 */

/*!
 * \file lightning_indexer_kernel.h
 * \brief
 */
|
|
|
|
#ifndef LIGHTNING_INDEXER_KERNEL_H
|
|
#define LIGHTNING_INDEXER_KERNEL_H
|
|
|
|
#include "kernel_operator.h"
|
|
#include "kernel_operator_list_tensor_intf.h"
|
|
#include "kernel_tiling/kernel_tiling.h"
|
|
#include "lib/matmul_intf.h"
|
|
#include "lib/matrix/matmul/tiling.h"
|
|
#include "lightning_indexer_common.h"
|
|
#include "lightning_indexer_service_vector.h"
|
|
#include "lightning_indexer_service_cube.h"
|
|
|
|
namespace LIKernel {
|
|
using namespace LICommon;
|
|
using namespace LIServiceVec;
|
|
using namespace matmul;
|
|
using AscendC::CacheMode;
|
|
using AscendC::CrossCoreSetFlag;
|
|
using AscendC::CrossCoreWaitFlag;
|
|
|
|
// Per-(batch, kv-head) loop bookkeeping, refreshed by CalcGS1LoopParams /
// CalcS2LoopParams on every outer-loop iteration of LIPreload::ProcessMain.
// Fix: initializers used 'ULL' suffixes on uint32_t members (narrowing) and
// mixed plain '0' with '0U'; normalized to 'U' throughout (values unchanged).
struct TempLoopInfo {
    uint32_t bN2Idx = 0U;             // flattened batch * kv-head index
    uint32_t bIdx = 0U;               // batch index (bN2Idx / kHeadNum)
    uint32_t n2Idx = 0U;              // kv-head index (bN2Idx % kHeadNum)
    uint32_t gS1Idx = 0U;             // current (g * s1) tile index
    uint32_t gS1LoopEnd = 0U;         // inclusive gS1 loop end for this core
    uint32_t s2LoopEnd = 0U;          // inclusive s2 loop end for this core
    uint32_t actS1Size = 1U;          // actual query sequence length of the batch
    uint32_t actS2Size = 0U;          // actual key sequence length of the batch
    bool curActSeqLenIsZero = false;  // true when either actual length is 0
    bool needDealActS1LessThanS1 = false;  // BSND: pad output rows beyond actS1Size
    uint32_t actMBaseSize = 0U;       // m-dim size of the current base block
    uint32_t mBasicSizeTail = 0U;     // m-dim size of the last (tail) block
    uint32_t s2BasicSizeTail = 0U;    // s2-dim size of the last (tail) block
};
|
|
|
|
template <typename LIT>
// Lightning-indexer preload kernel. Cube cores (AIC) compute the q*k matmul
// (mm1) into a per-core workspace; vector cores (AIV) consume the result and
// emit sparse top-k indices. The two sides run in a ping-pong driven by
// cross-core flags (see ProcessBaseBlock / ProcessMain).
class LIPreload {
public:
    __aicore__ inline LIPreload(){};
    // Binds all GM inputs/outputs and workspace regions, reads tiling, and
    // computes this core's block range. Call once per core before Process().
    __aicore__ inline void Init(__gm__ uint8_t *query, __gm__ uint8_t *key, __gm__ uint8_t *weights,
                                __gm__ uint8_t *actualSeqLengthsQ, __gm__ uint8_t *actualSeqLengths,
                                __gm__ uint8_t *blockTable, __gm__ uint8_t *sparseIndices, __gm__ uint8_t *workspace,
                                const LITilingData *__restrict tiling, TPipe *tPipe);
    // Runs the main pipeline, or the invalid-fill fallback when no core has work.
    __aicore__ inline void Process();

    // Element types and layout switches resolved from the template config.
    using Q_T = typename LIT::queryType;
    using K_T = typename LIT::keyType;
    using OUT_T = typename LIT::outputType;
    static constexpr bool PAGE_ATTENTION = LIT::pageAttention;
    static constexpr LI_LAYOUT LAYOUT_T = LIT::layout;
    static constexpr LI_LAYOUT K_LAYOUT_T = LIT::keyLayout;

    // mm1 always accumulates into fp32 workspace.
    using MM1_OUT_T = float;

    LIMatmul<LIT> matmulService;  // cube-side (AIC) service
    LIVector<LIT> vectorService;  // vector-side (AIV) service

    // Cross-core sync flag ids between cube (C1) and vector (V1).
    static constexpr uint32_t SYNC_C1_V1_FLAG = 4;
    static constexpr uint32_t SYNC_V1_C1_FLAG = 5;

    static constexpr uint32_t M_BASE_SIZE = 512;     // base block size on the m (= g*s1) axis
    static constexpr uint32_t S2_BASE_SIZE = 512;    // base block size on the s2 (key) axis
    static constexpr uint32_t HEAD_DIM = 128;        // fixed head dim of the indexer
    static constexpr uint32_t K_HEAD_NUM = 1;        // single kv head
    static constexpr uint32_t GM_ALIGN_BYTES = 512;  // GM slice alignment for the invalid-fill path

    static constexpr int64_t LD_PREFETCH_LEN = 2;  // icache prefetch length before the LD pass
    // for workspace double
    static constexpr uint32_t WS_DOBULE = 2;

protected:
    TPipe *pipe = nullptr;

    // offset
    // Per-core GM base offsets, recomputed on the first s2 iteration of each
    // (bN2, gS1) pair in CalcRunInfo.
    uint64_t queryCoreOffset = 0ULL;
    uint64_t keyCoreOffset = 0ULL;
    uint64_t weightsCoreOffset = 0ULL;
    uint64_t indiceOutCoreOffset = 0ULL;

    GlobalTensor<Q_T> queryGm;
    GlobalTensor<K_T> keyGm;
    GlobalTensor<K_T> weightsGm;

    GlobalTensor<int32_t> indiceOutGm;   // sparse indices output
    GlobalTensor<int32_t> blockTableGm;  // page-attention block table (PAGE_ATTENTION only)

    GlobalTensor<uint32_t> actualSeqLengthsGmQ;  // optional per-batch q lengths (cumulative for TND)
    GlobalTensor<uint32_t> actualSeqLengthsGm;   // optional per-batch k lengths (cumulative for TND)
    // workspace
    GlobalTensor<MM1_OUT_T> mm1ResGm;   // mm1 results, per-core double buffer
    GlobalTensor<float> vec1ResGm;      // vector stage-1 results
    GlobalTensor<int64_t> vec1ParamGm;  // vector stage-1 parameters

    // aic, aiv kernel info
    uint32_t tmpBlockIdx = 0U;  // raw block idx (vec: 0-47, cube: 0-23; see Init)
    uint32_t aiCoreIdx = 0U;    // logical core idx (vec idx / 2, cube idx as-is)
    uint32_t usedCoreNum = 0U;  // cores with assigned work; 0 selects ProcessInvalid

    LICommon::ConstInfo constInfo{};
    TempLoopInfo tempLoopInfo{};
    LICommon::SplitCoreInfo splitCoreInfo{};

    // ================================Init functions==================================
    __aicore__ inline void InitTilingData(const LITilingData *__restrict tilingData);
    __aicore__ inline void InitBuffers();
    __aicore__ inline void InitActualSeqLen(__gm__ uint8_t *actualSeqLengthsQ, __gm__ uint8_t *actualSeqLengths);
    // ================================Split Core================================
    __aicore__ inline void SplitCore(uint32_t curCoreIdx, uint32_t &coreNum, LICommon::SplitCoreInfo &info);
    __aicore__ inline uint32_t GetS2BaseBlockNumOnMask(uint32_t s1gIdx, uint32_t actS1Size, uint32_t actS2Size);
    __aicore__ inline uint32_t GetTotalBaseBlockNum();
    // ================================Process functions================================
    __aicore__ inline void ProcessMain();
    __aicore__ inline void ProcessBaseBlock(uint32_t loop, uint64_t s2LoopIdx, LICommon::RunInfo &runInfo);
    __aicore__ inline void ProcessDecode();
    __aicore__ inline void ProcessInvalid();
    // ================================Params Calc=====================================
    __aicore__ inline void CalcGS1LoopParams(uint32_t bN2Idx);
    __aicore__ inline void GetBN2Idx(uint32_t bN2Idx);
    __aicore__ inline uint32_t GetActualSeqLen(uint32_t bIdx, uint32_t actualLenDims, bool isAccumSeq,
                                               GlobalTensor<uint32_t> &actualSeqLengthsGm, uint32_t defaultSeqLen);
    __aicore__ inline void GetS1S2ActualSeqLen(uint32_t bIdx, uint32_t &actS1Size, uint32_t &actS2Size);
    __aicore__ inline void CalcS2LoopParams(uint32_t bN2LoopIdx, uint32_t gS1LoopIdx);
    __aicore__ inline void CalcRunInfo(uint32_t loop, uint32_t s2LoopIdx, LICommon::RunInfo &runInfo);
    __aicore__ inline void DealActSeqLenIsZero(uint32_t bIdx, uint32_t n2Idx, uint32_t s1Start);
};
|
|
|
|
template <typename LIT>
|
|
__aicore__ inline void LIPreload<LIT>::InitTilingData(const LITilingData *__restrict tilingData)
|
|
{
|
|
usedCoreNum = tilingData->usedCoreNum;
|
|
constInfo.batchSize = tilingData->bSize;
|
|
constInfo.qHeadNum = constInfo.gSize = tilingData->gSize;
|
|
constInfo.kSeqSize = tilingData->s2Size;
|
|
constInfo.qSeqSize = tilingData->s1Size;
|
|
constInfo.attenMaskFlag = (tilingData->sparseMode == 3);
|
|
constInfo.kCacheBlockSize = tilingData->blockSize;
|
|
constInfo.maxBlockNumPerBatch = tilingData->maxBlockNumPerBatch;
|
|
constInfo.sparseCount = tilingData->sparseCount;
|
|
constInfo.outputLayout = LAYOUT_T;
|
|
if (LAYOUT_T == LI_LAYOUT::TND) {
|
|
constInfo.isAccumSeqS1 = true;
|
|
}
|
|
if (K_LAYOUT_T == LI_LAYOUT::TND) {
|
|
constInfo.isAccumSeqS2 = true;
|
|
}
|
|
|
|
constInfo.kHeadNum = K_HEAD_NUM;
|
|
constInfo.headDim = HEAD_DIM;
|
|
|
|
constInfo.mBaseSize = M_BASE_SIZE;
|
|
constInfo.s2BaseSize = S2_BASE_SIZE;
|
|
constInfo.s1BaseSize = (constInfo.mBaseSize + constInfo.gSize - 1) / constInfo.gSize;
|
|
}
|
|
|
|
template <typename LIT>
|
|
__aicore__ inline void LIPreload<LIT>::InitBuffers()
|
|
{
|
|
if ASCEND_IS_AIV {
|
|
vectorService.InitBuffers(pipe);
|
|
} else {
|
|
matmulService.InitBuffers(pipe);
|
|
}
|
|
}
|
|
|
|
template <typename LIT>
|
|
__aicore__ inline void LIPreload<LIT>::InitActualSeqLen(__gm__ uint8_t *actualSeqLengthsQ,
|
|
__gm__ uint8_t *actualSeqLengths)
|
|
{
|
|
if (actualSeqLengthsQ == nullptr) {
|
|
constInfo.actualLenQDims = 0;
|
|
} else {
|
|
constInfo.actualLenQDims = constInfo.batchSize;
|
|
actualSeqLengthsGmQ.SetGlobalBuffer((__gm__ uint32_t *)actualSeqLengthsQ, constInfo.actualLenQDims);
|
|
}
|
|
if (actualSeqLengths == nullptr) {
|
|
constInfo.actualLenDims = 0;
|
|
} else {
|
|
constInfo.actualLenDims = constInfo.batchSize;
|
|
actualSeqLengthsGm.SetGlobalBuffer((__gm__ uint32_t *)actualSeqLengths, constInfo.actualLenDims);
|
|
}
|
|
}
|
|
|
|
template <typename LIT>
|
|
__aicore__ inline uint32_t LIPreload<LIT>::GetActualSeqLen(uint32_t bIdx, uint32_t actualLenDims, bool isAccumSeq,
|
|
GlobalTensor<uint32_t> &actualSeqLengthsGm,
|
|
uint32_t defaultSeqLen)
|
|
{
|
|
if (actualLenDims == 0) {
|
|
return defaultSeqLen;
|
|
} else if (isAccumSeq && bIdx > 0) {
|
|
return actualSeqLengthsGm.GetValue(bIdx) - actualSeqLengthsGm.GetValue(bIdx - 1);
|
|
} else {
|
|
return actualSeqLengthsGm.GetValue(bIdx);
|
|
}
|
|
}
|
|
|
|
template <typename LIT>
|
|
__aicore__ inline void LIPreload<LIT>::GetS1S2ActualSeqLen(uint32_t bIdx, uint32_t &actS1Size, uint32_t &actS2Size)
|
|
{
|
|
actS1Size = GetActualSeqLen(bIdx, constInfo.actualLenQDims, constInfo.isAccumSeqS1, actualSeqLengthsGmQ,
|
|
constInfo.qSeqSize);
|
|
actS2Size =
|
|
GetActualSeqLen(bIdx, constInfo.actualLenDims, constInfo.isAccumSeqS2, actualSeqLengthsGm, constInfo.kSeqSize);
|
|
}
|
|
|
|
template <typename LIT>
|
|
__aicore__ inline uint32_t LIPreload<LIT>::GetS2BaseBlockNumOnMask(uint32_t s1gIdx, uint32_t actS1Size,
|
|
uint32_t actS2Size)
|
|
{
|
|
if (actS2Size == 0) {
|
|
return 0;
|
|
}
|
|
uint32_t s1Offset = constInfo.s1BaseSize * s1gIdx;
|
|
int32_t validS2LenBase = static_cast<int32_t>(actS2Size) - static_cast<int32_t>(actS1Size);
|
|
int32_t validS2Len = s1Offset + validS2LenBase + constInfo.s1BaseSize;
|
|
validS2Len = Min(validS2Len, static_cast<int32_t>(actS2Size));
|
|
validS2Len = Max(validS2Len, 1);
|
|
return (validS2Len + constInfo.s2BaseSize - 1) / constInfo.s2BaseSize;
|
|
}
|
|
|
|
template <typename LIT>
|
|
__aicore__ inline uint32_t LIPreload<LIT>::GetTotalBaseBlockNum()
|
|
{
|
|
uint32_t totalBlockNum = 0;
|
|
uint32_t actS1Size, actS2Size;
|
|
uint32_t s1GBaseNum, s2BaseNum;
|
|
for (uint32_t bIdx = 0; bIdx < constInfo.batchSize; bIdx++) {
|
|
GetS1S2ActualSeqLen(bIdx, actS1Size, actS2Size);
|
|
s1GBaseNum = CeilDiv(actS1Size, constInfo.s1BaseSize);
|
|
if (!constInfo.attenMaskFlag) {
|
|
s2BaseNum = CeilDiv(actS2Size, constInfo.s2BaseSize);
|
|
totalBlockNum += s1GBaseNum * s2BaseNum * constInfo.kHeadNum;
|
|
continue;
|
|
}
|
|
for (uint32_t s1gIdx = 0; s1gIdx < s1GBaseNum; s1gIdx++) {
|
|
s2BaseNum = GetS2BaseBlockNumOnMask(s1gIdx, actS1Size, actS2Size);
|
|
totalBlockNum += s2BaseNum * constInfo.kHeadNum;
|
|
}
|
|
}
|
|
return totalBlockNum;
|
|
}
|
|
|
|
template <typename LIT>
// Evenly partitions all (bN2, gS1, s2) base blocks across 'coreNum' cores and
// fills 'info' with the inclusive start/end block coordinates assigned to
// core 'curCoreIdx'. When there are fewer blocks than cores, coreNum is
// shrunk to the number of cores that actually receive a block.
__aicore__ void inline LIPreload<LIT>::SplitCore(uint32_t curCoreIdx, uint32_t &coreNum, LICommon::SplitCoreInfo &info)
{
    uint32_t totalBlockNum = GetTotalBaseBlockNum();
    uint32_t minBlockPerCore = totalBlockNum / coreNum;
    uint32_t deal1MoreBlockCoreNum = totalBlockNum % coreNum; // first N cores take one extra block
    uint32_t coreIdx = 0;
    uint32_t lastGS1RemainBlockCnt = 0; // blocks already credited to coreIdx from previous s2 rows
    uint32_t coreDealBlockCnt = coreIdx < deal1MoreBlockCoreNum ? minBlockPerCore + 1 : minBlockPerCore;
    coreNum = minBlockPerCore == 0 ? deal1MoreBlockCoreNum : coreNum;

    bool findLastCoreEnd = true; // true => the next visited block opens a new core's range
    uint32_t actS1Size, actS2Size;
    uint32_t s1GBaseNum, s2BaseNum;
    for (uint32_t bN2Idx = 0; bN2Idx < constInfo.batchSize * constInfo.kHeadNum; bN2Idx++) {
        uint32_t bIdx = bN2Idx / constInfo.kHeadNum;
        if (bN2Idx % constInfo.kHeadNum == 0) {
            // Entering a new batch: refresh actual lengths and block counts.
            GetS1S2ActualSeqLen(bIdx, actS1Size, actS2Size);
            s1GBaseNum = CeilDiv(actS1Size, constInfo.s1BaseSize);
            s2BaseNum = CeilDiv(actS2Size, constInfo.s2BaseSize);
        }
        if constexpr (LAYOUT_T == LI_LAYOUT::BSND) {
            // BSND keeps output rows for empty batches; advance the start
            // coordinate past them so the owning core can pad their output.
            if (findLastCoreEnd && (s1GBaseNum == 0U || s2BaseNum == 0U)) {
                info.bN2Start = bN2Idx;
                info.gS1Start = 0;
                info.s2Start = 0;
                findLastCoreEnd = false;
            }
        }
        for (uint32_t gS1Idx = 0; gS1Idx < s1GBaseNum; gS1Idx++) {
            if (constInfo.attenMaskFlag) {
                // Masked path: this s1 row only covers blocks up to the diagonal.
                s2BaseNum = GetS2BaseBlockNumOnMask(gS1Idx, actS1Size, actS2Size);
            }
            if (findLastCoreEnd && s2BaseNum == 0U) {
                info.bN2Start = bN2Idx;
                info.gS1Start = gS1Idx;
                info.s2Start = 0;
                findLastCoreEnd = false;
            }
            for (uint32_t s2Idx = 0; s2Idx < s2BaseNum;) {
                if (findLastCoreEnd) {
                    info.bN2Start = bN2Idx;
                    info.gS1Start = gS1Idx;
                    info.s2Start = s2Idx;
                    findLastCoreEnd = false;
                }
                uint32_t s2RemainBaseNum = s2BaseNum - s2Idx;
                if (lastGS1RemainBlockCnt + s2RemainBaseNum >= coreDealBlockCnt) {
                    // The current core's quota fills up inside this s2 row.
                    info.bN2End = bN2Idx;
                    info.gS1End = gS1Idx;
                    info.s2End = s2Idx + coreDealBlockCnt - lastGS1RemainBlockCnt - 1;

                    if (coreIdx == curCoreIdx) {
                        if (s2Idx == 0 && info.s2End + 1 < s2BaseNum) {
                            // This core owns a proper prefix of an s2 row:
                            // eligible for the LD post pass (see ProcessDecode).
                            info.isLD = true;
                        }
                        // Last core also absorbs any trailing empty batches.
                        // NOTE(review): compares bN2End against batchSize-1;
                        // only correct while kHeadNum == 1 — confirm if the
                        // kv-head count ever becomes > 1.
                        if (coreIdx == coreNum - 1 && info.bN2End != constInfo.batchSize -1) {
                            info.bN2End = constInfo.batchSize -1;
                            info.gS1End = 0;
                            info.s2End = 0;
                        }
                        return;
                    }
                    coreIdx++;
                    findLastCoreEnd = true;
                    s2Idx = info.s2End + 1;
                    lastGS1RemainBlockCnt = 0;
                    coreDealBlockCnt = coreIdx < deal1MoreBlockCoreNum ? minBlockPerCore + 1 : minBlockPerCore;
                } else {
                    // Row exhausted before the quota: carry the partial count forward.
                    lastGS1RemainBlockCnt += s2RemainBaseNum;
                    break;
                }
            }
        }
    }
}
|
|
|
|
template <typename LIT>
|
|
__aicore__ inline void LIPreload<LIT>::DealActSeqLenIsZero(uint32_t bIdx, uint32_t n2Idx, uint32_t s1Start)
|
|
{
|
|
if ASCEND_IS_AIV {
|
|
if (constInfo.outputLayout == LI_LAYOUT::TND) {
|
|
uint32_t tSize = actualSeqLengthsGmQ.GetValue(constInfo.batchSize - 1);
|
|
uint32_t tBase = bIdx == 0 ? 0 : actualSeqLengthsGmQ.GetValue(bIdx - 1);
|
|
uint32_t s1Count = tempLoopInfo.actS1Size;
|
|
|
|
for (uint32_t s1Idx = s1Start; s1Idx < s1Count; s1Idx++) {
|
|
uint64_t indiceOutOffset =
|
|
(tBase + s1Idx) * constInfo.kHeadNum * constInfo.sparseCount +
|
|
n2Idx * constInfo.sparseCount;
|
|
vectorService.CleanInvalidOutput(indiceOutOffset);
|
|
}
|
|
} else if (constInfo.outputLayout == LI_LAYOUT::BSND) {
|
|
for (uint32_t s1Idx = s1Start; s1Idx < constInfo.qSeqSize; s1Idx++) {
|
|
// B,S1,N2,K
|
|
uint64_t indiceOutOffset = bIdx * constInfo.qSeqSize * constInfo.kHeadNum * constInfo.sparseCount +
|
|
s1Idx * constInfo.kHeadNum * constInfo.sparseCount +
|
|
n2Idx * constInfo.sparseCount;
|
|
vectorService.CleanInvalidOutput(indiceOutOffset);
|
|
}
|
|
}
|
|
}
|
|
}
|
|
|
|
template <typename LIT>
// Per-core initialization: derives the logical core index, reads tiling,
// splits work across cores, carves the shared workspace into the three
// regions (mm1 results | vec1 results | vec1 params), and binds the GM
// tensors each core type needs.
__aicore__ inline void LIPreload<LIT>::Init(__gm__ uint8_t *query, __gm__ uint8_t *key, __gm__ uint8_t *weights,
                                            __gm__ uint8_t *actualSeqLengthsQ, __gm__ uint8_t *actualSeqLengths,
                                            __gm__ uint8_t *blockTable, __gm__ uint8_t *sparseIndices,
                                            __gm__ uint8_t *workspace, const LITilingData *__restrict tiling,
                                            TPipe *tPipe)
{
    // Two vector cores pair with one cube core, so AIV indices are halved to
    // get the shared logical core index.
    if ASCEND_IS_AIV {
        tmpBlockIdx = GetBlockIdx(); // vec:0-47
        aiCoreIdx = tmpBlockIdx / 2;
    } else {
        tmpBlockIdx = GetBlockIdx(); // cube:0-23
        aiCoreIdx = tmpBlockIdx;
    }

    InitTilingData(tiling);
    InitActualSeqLen(actualSeqLengthsQ, actualSeqLengths);

    // May shrink usedCoreNum when there are fewer base blocks than cores.
    SplitCore(aiCoreIdx, usedCoreNum, splitCoreInfo);

    pipe = tPipe;
    // Workspace layout (all regions sized for GetBlockNum() cores; WS_DOBULE
    // provides double buffering). mm1ResGm points directly at this core's slot.
    uint64_t offset = 0;
    uint64_t singleCoreMm1ResSize = WS_DOBULE * constInfo.mBaseSize * constInfo.s2BaseSize * sizeof(MM1_OUT_T);
    mm1ResGm.SetGlobalBuffer((__gm__ MM1_OUT_T *)(workspace + offset + aiCoreIdx * singleCoreMm1ResSize));
    offset += GetBlockNum() * singleCoreMm1ResSize;

    vec1ResGm.SetGlobalBuffer((__gm__ float *)(workspace + offset));
    offset += GetBlockNum() * constInfo.s1BaseSize * WS_DOBULE * WS_DOBULE * BASE_TOPK * sizeof(float);

    vec1ParamGm.SetGlobalBuffer((__gm__ int64_t *)(workspace + offset));
    offset += GetBlockNum() * constInfo.s1BaseSize * WS_DOBULE * LD_PARAM_NUM * sizeof(int64_t);

    // Each core type binds only the tensors it touches.
    if ASCEND_IS_AIV {
        vectorService.InitParams(constInfo, tiling);
        indiceOutGm.SetGlobalBuffer((__gm__ int32_t *)sparseIndices);
        weightsGm.SetGlobalBuffer((__gm__ K_T *)weights);
        vectorService.InitVec1GlobalTensor(mm1ResGm, vec1ResGm, vec1ParamGm, weightsGm, indiceOutGm);
    } else {
        matmulService.InitParams(constInfo);
        queryGm.SetGlobalBuffer((__gm__ Q_T *)query);
        if constexpr (PAGE_ATTENTION) {
            blockTableGm.SetGlobalBuffer((__gm__ int32_t *)blockTable);
        }
        keyGm.SetGlobalBuffer((__gm__ K_T *)key);
        matmulService.InitMm1GlobalTensor(blockTableGm, keyGm, queryGm, mm1ResGm);
    }
    InitBuffers();
}
|
|
|
|
template <typename LIT>
|
|
__aicore__ inline void LIPreload<LIT>::GetBN2Idx(uint32_t bN2Idx)
|
|
{
|
|
tempLoopInfo.bN2Idx = bN2Idx;
|
|
tempLoopInfo.bIdx = bN2Idx / constInfo.kHeadNum;
|
|
tempLoopInfo.n2Idx = bN2Idx % constInfo.kHeadNum;
|
|
}
|
|
|
|
template <typename LIT>
|
|
__aicore__ inline void LIPreload<LIT>::CalcS2LoopParams(uint32_t bN2LoopIdx, uint32_t gS1LoopIdx)
|
|
{
|
|
tempLoopInfo.gS1Idx = gS1LoopIdx;
|
|
tempLoopInfo.actMBaseSize = constInfo.mBaseSize;
|
|
uint32_t remainedGS1Size = tempLoopInfo.actS1Size * constInfo.gSize - tempLoopInfo.gS1Idx * constInfo.mBaseSize;
|
|
if (remainedGS1Size <= constInfo.mBaseSize && remainedGS1Size > 0) {
|
|
tempLoopInfo.actMBaseSize = tempLoopInfo.mBasicSizeTail;
|
|
}
|
|
|
|
bool isEnd = (bN2LoopIdx == splitCoreInfo.bN2End) && (gS1LoopIdx == splitCoreInfo.gS1End);
|
|
uint32_t s2BlockNum;
|
|
if (constInfo.attenMaskFlag) {
|
|
s2BlockNum = GetS2BaseBlockNumOnMask(gS1LoopIdx, tempLoopInfo.actS1Size, tempLoopInfo.actS2Size);
|
|
} else {
|
|
s2BlockNum = (tempLoopInfo.actS2Size + constInfo.s2BaseSize - 1) / constInfo.s2BaseSize;
|
|
}
|
|
tempLoopInfo.s2LoopEnd = isEnd ? splitCoreInfo.s2End : s2BlockNum - 1;
|
|
}
|
|
|
|
template <typename LIT>
|
|
__aicore__ inline void LIPreload<LIT>::CalcGS1LoopParams(uint32_t bN2LoopIdx)
|
|
{
|
|
GetBN2Idx(bN2LoopIdx);
|
|
GetS1S2ActualSeqLen(tempLoopInfo.bIdx, tempLoopInfo.actS1Size, tempLoopInfo.actS2Size);
|
|
if ((tempLoopInfo.actS2Size == 0) || (tempLoopInfo.actS1Size == 0)) {
|
|
tempLoopInfo.curActSeqLenIsZero = true;
|
|
return;
|
|
}
|
|
tempLoopInfo.curActSeqLenIsZero = false;
|
|
tempLoopInfo.s2BasicSizeTail = tempLoopInfo.actS2Size % constInfo.s2BaseSize;
|
|
tempLoopInfo.s2BasicSizeTail =
|
|
(tempLoopInfo.s2BasicSizeTail == 0) ? constInfo.s2BaseSize : tempLoopInfo.s2BasicSizeTail;
|
|
tempLoopInfo.mBasicSizeTail = (tempLoopInfo.actS1Size * constInfo.gSize) % constInfo.mBaseSize;
|
|
tempLoopInfo.mBasicSizeTail =
|
|
(tempLoopInfo.mBasicSizeTail == 0) ? constInfo.mBaseSize : tempLoopInfo.mBasicSizeTail;
|
|
|
|
uint32_t gS1SplitNum = (tempLoopInfo.actS1Size * constInfo.gSize + constInfo.mBaseSize - 1) / constInfo.mBaseSize;
|
|
tempLoopInfo.gS1LoopEnd = (bN2LoopIdx == splitCoreInfo.bN2End) ? splitCoreInfo.gS1End : gS1SplitNum - 1;
|
|
if constexpr (LAYOUT_T == LI_LAYOUT::BSND) {
|
|
if (tempLoopInfo.gS1LoopEnd == gS1SplitNum - 1 && constInfo.qSeqSize > tempLoopInfo.actS1Size) {
|
|
tempLoopInfo.needDealActS1LessThanS1 = true;
|
|
}
|
|
}
|
|
}
|
|
|
|
template <typename LIT>
|
|
__aicore__ inline void LIPreload<LIT>::CalcRunInfo(uint32_t loop, uint32_t s2LoopIdx, LICommon::RunInfo &runInfo)
|
|
{
|
|
runInfo.loop = loop;
|
|
runInfo.bIdx = tempLoopInfo.bIdx;
|
|
runInfo.gS1Idx = tempLoopInfo.gS1Idx;
|
|
runInfo.s2Idx = s2LoopIdx;
|
|
runInfo.bN2Idx = tempLoopInfo.bN2Idx;
|
|
|
|
runInfo.actS1Size = tempLoopInfo.actS1Size;
|
|
runInfo.actS2Size = tempLoopInfo.actS2Size;
|
|
runInfo.actMBaseSize = tempLoopInfo.actMBaseSize;
|
|
runInfo.actualSingleProcessSInnerSize = constInfo.s2BaseSize;
|
|
uint32_t s2SplitNum = (tempLoopInfo.actS2Size + constInfo.s2BaseSize - 1) / constInfo.s2BaseSize;
|
|
if (runInfo.s2Idx == s2SplitNum - 1) {
|
|
runInfo.actualSingleProcessSInnerSize = tempLoopInfo.s2BasicSizeTail;
|
|
}
|
|
runInfo.actualSingleProcessSInnerSizeAlign =
|
|
LICommon::Align((uint32_t)runInfo.actualSingleProcessSInnerSize, LICommon::ConstInfo::BUFFER_SIZE_BYTE_32B);
|
|
|
|
runInfo.isFirstS2InnerLoop = s2LoopIdx == splitCoreInfo.s2Start;
|
|
runInfo.isLastS2InnerLoop = s2LoopIdx == tempLoopInfo.s2LoopEnd;
|
|
runInfo.isAllLoopEnd = (runInfo.bN2Idx == splitCoreInfo.bN2End) && (runInfo.gS1Idx == splitCoreInfo.gS1End) &&
|
|
(runInfo.s2Idx == splitCoreInfo.s2End);
|
|
|
|
if (runInfo.isFirstS2InnerLoop) {
|
|
uint64_t actualSeqQPrefixSum;
|
|
uint64_t actualSeqKPrefixSum;
|
|
if constexpr (LAYOUT_T == LI_LAYOUT::TND) {
|
|
actualSeqQPrefixSum = (runInfo.bIdx <= 0) ? 0 : actualSeqLengthsGmQ.GetValue(runInfo.bIdx - 1);
|
|
actualSeqKPrefixSum = (runInfo.bIdx <= 0) ? 0 : actualSeqLengthsGm.GetValue(runInfo.bIdx - 1);
|
|
} else { // BSND
|
|
actualSeqQPrefixSum = (runInfo.bIdx <= 0) ? 0 : runInfo.bIdx * constInfo.qSeqSize;
|
|
actualSeqKPrefixSum = (runInfo.bIdx <= 0) ? 0 : runInfo.bIdx * constInfo.kSeqSize;
|
|
}
|
|
uint64_t tndBIdxOffset = actualSeqQPrefixSum * constInfo.qHeadNum * constInfo.headDim;
|
|
uint64_t tndKeyBIdxOffset = actualSeqKPrefixSum * constInfo.kHeadNum * constInfo.headDim;
|
|
// B,S1,N1(N2,G),D
|
|
queryCoreOffset = tndBIdxOffset + runInfo.gS1Idx * constInfo.mBaseSize * constInfo.headDim;
|
|
keyCoreOffset = tndKeyBIdxOffset + runInfo.n2Idx * constInfo.headDim;
|
|
// B,S1,N1(N2,G)/T,N1(N2,G)
|
|
weightsCoreOffset = actualSeqQPrefixSum * constInfo.qHeadNum + runInfo.n2Idx * constInfo.gSize;
|
|
// B,S1,N2,k/T,N2,k
|
|
indiceOutCoreOffset = actualSeqQPrefixSum * constInfo.kHeadNum * constInfo.sparseCount +
|
|
runInfo.n2Idx * constInfo.sparseCount;
|
|
}
|
|
runInfo.tensorQueryOffset = queryCoreOffset;
|
|
runInfo.tensorKeyOffset = keyCoreOffset + runInfo.s2Idx * constInfo.s2BaseSize * constInfo.kHeadNum
|
|
* constInfo.headDim;
|
|
runInfo.tensorWeightsOffset = weightsCoreOffset;
|
|
runInfo.indiceOutOffset = indiceOutCoreOffset;
|
|
}
|
|
|
|
template <typename LIT>
|
|
__aicore__ inline void LIPreload<LIT>::Process()
|
|
{
|
|
if (usedCoreNum == 0) {
|
|
ProcessInvalid();
|
|
return;
|
|
}
|
|
ProcessMain();
|
|
ProcessDecode();
|
|
}
|
|
|
|
template <typename LIT>
|
|
__aicore__ inline void LIPreload<LIT>::ProcessInvalid()
|
|
{
|
|
if ASCEND_IS_AIV {
|
|
uint32_t aivCoreNum = GetBlockNum() * 2; // 2 means c:v = 1:2
|
|
uint64_t totalOutputSize =
|
|
constInfo.batchSize * constInfo.qSeqSize * constInfo.kHeadNum * constInfo.sparseCount;
|
|
uint64_t singleCoreSize =
|
|
LICommon::Align((totalOutputSize + aivCoreNum - 1) / aivCoreNum, GM_ALIGN_BYTES / sizeof(OUT_T));
|
|
uint64_t baseSize = tmpBlockIdx * singleCoreSize;
|
|
if (baseSize < totalOutputSize) {
|
|
uint64_t dealSize =
|
|
(baseSize + singleCoreSize > totalOutputSize) ? singleCoreSize : totalOutputSize - baseSize;
|
|
GlobalTensor<OUT_T> output = indiceOutGm[baseSize];
|
|
AscendC::InitGlobalMemory(output, dealSize, constInfo.INVALID_IDX);
|
|
}
|
|
}
|
|
}
|
|
|
|
template <typename LIT>
// Main loop: walks this core's assigned [bN2, gS1, s2] block range, driving
// the cube/vector ping-pong via cross-core flags. Two V1->C1 credits are set
// up front so the cube side can run up to two mm1 iterations ahead of the
// vector side (matching the WS_DOBULE workspace buffering).
__aicore__ inline void LIPreload<LIT>::ProcessMain()
{
    if (aiCoreIdx >= usedCoreNum) {
        // SplitCore assigned this core no blocks.
        return;
    }

    if ASCEND_IS_AIV {
        vectorService.AllocEventID();
        // Prime two workspace-slot credits for the cube side.
        CrossCoreSetFlag<LICommon::ConstInfo::FIA_SYNC_MODE2, PIPE_MTE2>(constInfo.syncV1C1);
        CrossCoreSetFlag<LICommon::ConstInfo::FIA_SYNC_MODE2, PIPE_MTE2>(constInfo.syncV1C1);
    } else {
        matmulService.AllocEventID();
    }

    LICommon::RunInfo runInfo;
    // Global base-block counter across the whole range; presumably selects the
    // workspace double-buffer slot inside the services — confirm there.
    uint32_t gloop = 0;
    for (uint32_t bN2LoopIdx = splitCoreInfo.bN2Start; bN2LoopIdx <= splitCoreInfo.bN2End; bN2LoopIdx++) {
        CalcGS1LoopParams(bN2LoopIdx);
        if (tempLoopInfo.curActSeqLenIsZero) {
            // Nothing to compute: pad this batch's output with invalid indices.
            DealActSeqLenIsZero(tempLoopInfo.bIdx, tempLoopInfo.n2Idx, 0U);
            continue;
        }
        for (uint32_t gS1LoopIdx = splitCoreInfo.gS1Start; gS1LoopIdx <= tempLoopInfo.gS1LoopEnd; gS1LoopIdx++) {
            CalcS2LoopParams(bN2LoopIdx, gS1LoopIdx);
            // NOTE(review): s2LoopIdx is a signed int compared against the
            // uint32_t s2LoopEnd — fine for realistic block counts, but worth
            // normalizing to uint32_t.
            for (int s2LoopIdx = splitCoreInfo.s2Start; s2LoopIdx <= tempLoopInfo.s2LoopEnd; s2LoopIdx++) {
                ProcessBaseBlock(gloop, s2LoopIdx, runInfo);
                ++gloop;
            }
            // Only the first gS1 row of the range starts mid-row.
            splitCoreInfo.s2Start = 0;
        }
        if (tempLoopInfo.needDealActS1LessThanS1) {
            // BSND: output rows in [actS1Size, qSeqSize) carry no valid data.
            DealActSeqLenIsZero(tempLoopInfo.bIdx, tempLoopInfo.n2Idx, tempLoopInfo.actS1Size);
        }
        // Only the first batch of the range starts mid-batch.
        splitCoreInfo.gS1Start = 0;
    }

    if ASCEND_IS_AIV {
        vectorService.FreeEventID();
    } else {
        matmulService.FreeEventID();
        // Drain the two primed credits so flag state is balanced on exit.
        CrossCoreWaitFlag(constInfo.syncV1C1);
        CrossCoreWaitFlag(constInfo.syncV1C1);
    }
}
|
|
|
|
template <typename LIT>
// Processes one (m x s2) base block. Cube side: wait for a free workspace
// slot (V1->C1 credit), compute mm1, signal the result ready (C1->V1).
// Vector side: wait for the mm1 result, consume it, return the slot credit.
__aicore__ inline void LIPreload<LIT>::ProcessBaseBlock(uint32_t loop, uint64_t s2LoopIdx, LICommon::RunInfo &runInfo)
{
    CalcRunInfo(loop, s2LoopIdx, runInfo);
    if ASCEND_IS_AIC {
        CrossCoreWaitFlag(constInfo.syncV1C1);   // wait until the vector side freed a slot
        matmulService.ComputeMm1(runInfo);
        CrossCoreSetFlag<LICommon::ConstInfo::FIA_SYNC_MODE2, PIPE_FIX>(constInfo.syncC1V1);   // mm1 result ready
    } else {
        CrossCoreWaitFlag(constInfo.syncC1V1);   // wait for the mm1 result
        vectorService.ProcessVec(runInfo);
        CrossCoreSetFlag<LICommon::ConstInfo::FIA_SYNC_MODE2, PIPE_MTE2>(constInfo.syncV1C1);  // slot consumed
    }
}
|
|
|
|
template <typename LIT>
// Vector-only post pass: cores flagged isLD by SplitCore run the LD
// processing after the main loop. SyncAll is unconditional so every vector
// core reaches the barrier even when it has no LD work — do not move it
// inside the isLD branch.
__aicore__ inline void LIPreload<LIT>::ProcessDecode()
{
    if ASCEND_IS_AIV {
        vectorService.InitLDBuffers(pipe);
        // Prefetch upcoming instructions before the barrier.
        ICachePreLoad(LD_PREFETCH_LEN);
        SyncAll();
        if (splitCoreInfo.isLD) {
            vectorService.ProcessLD();
        }
    }
}
|
|
} // namespace LIKernel
|
|
#endif // LIGHTNING_INDEXER_KERNEL_H
|