[Perf][1/N] w8a8c8 support in dsv3.2/glm5 (#7029)
### What this PR does / why we need it?
This PR supports W8A8C8 in dsv3.2/glm5 with lightning_indexer_quant ops
in pd-mix stage mainly.
Because the code for the current PD-disaggregated scenario is still
under refactoring and cleanup, this PR prioritizes ensuring the C8
functionality in the pd-mix scenario.
The next steps are planned in two parts:
① Once the optimized scatter operator is updated, we will replace the
original operator to improve the performance of storing k_scale.
② Once the code logic for the PD-disaggregated scenario becomes stable,
we will carry out more comprehensive validation and make appropriate
adaptations.
③ Because enabling C8 currently introduces several new operators whose
performance still needs improvement, performance may regress in some
scenarios. Therefore, only after all the operators are fully ready can
we ensure that this feature does not cause any performance degradation.
At that point, we will enable this feature by default and remove the
switch in `additional_config`.
### Does this PR introduce _any_ user-facing change?
No.
### How was this patch tested?
CI passed with new added/existing test.
- vLLM version: v0.16.0
- vLLM main:
4034c3d32e
---------
Signed-off-by: rjg-lyh <1318825571@qq.com>
This commit is contained in:
@@ -0,0 +1,50 @@
|
||||
/**
|
||||
* This program is free software, you can redistribute it and/or modify it.
|
||||
* Copyright (c) 2025 Huawei Technologies Co., Ltd.
|
||||
* This file is a part of the CANN Open Software.
|
||||
* Licensed under CANN Open Software License Agreement Version 2.0 (the "License").
|
||||
* Please refer to the License for details. You may not use this file except in compliance with the License.
|
||||
* THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE.
|
||||
* See LICENSE in the root of the software repository for the full text of the License.
|
||||
*/
|
||||
|
||||
/*!
|
||||
* \file lightning_indexer_quant.cpp
|
||||
* \brief
|
||||
*/
|
||||
|
||||
#include "kernel_operator.h"
|
||||
#include "lib/matmul_intf.h"
|
||||
#include "lightning_indexer_quant_kernel.h"
|
||||
#include "lightning_indexer_quant_template_tiling_key.h"
|
||||
|
||||
using namespace LIQKernel;
|
||||
|
||||
#define INVOKE_LI_NO_KFC_OP_IMPL(templateClass, ...) \
|
||||
do { \
|
||||
templateClass<LIQType<__VA_ARGS__>> op; \
|
||||
GET_TILING_DATA_WITH_STRUCT(LIQTilingData, tiling_data_in, tiling); \
|
||||
const LIQTilingData *__restrict tiling_data = &tiling_data_in; \
|
||||
op.Init(query, key, weights, queryScale, keyScale, actualSeqLengthsQ, actualSeqLengthsK, blocktable, \
|
||||
sparseIndices, user, tiling_data, &tPipe); \
|
||||
op.Process(); \
|
||||
} while (0)
|
||||
|
||||
template <int DT_Q, int DT_K, int DT_OUT, int PAGE_ATTENTION, int Q_LAYOUT_T, int K_LAYOUT_T>
|
||||
__global__ __aicore__ void lightning_indexer_quant(__gm__ uint8_t *query, __gm__ uint8_t *key, __gm__ uint8_t *weights,
|
||||
__gm__ uint8_t *queryScale, __gm__ uint8_t *keyScale,
|
||||
__gm__ uint8_t *actualSeqLengthsQ, __gm__ uint8_t *actualSeqLengthsK,
|
||||
__gm__ uint8_t *blocktable, __gm__ uint8_t *sparseIndices,
|
||||
__gm__ uint8_t *workspace, __gm__ uint8_t *tiling)
|
||||
{
|
||||
#if (__CCE_AICORE__ == 310) || (defined __DAV_310R6__) || (__CCE_AICORE__ == 200)
|
||||
|
||||
#else
|
||||
TPipe tPipe;
|
||||
__gm__ uint8_t *user = GetUserWorkspace(workspace);
|
||||
KERNEL_TASK_TYPE_DEFAULT(KERNEL_TYPE_MIX_AIC_1_2);
|
||||
|
||||
INVOKE_LI_NO_KFC_OP_IMPL(LIQPreload, int8_t, int8_t, int32_t,
|
||||
PAGE_ATTENTION, LI_LAYOUT(Q_LAYOUT_T), LI_LAYOUT(K_LAYOUT_T));
|
||||
#endif
|
||||
}
|
||||
@@ -0,0 +1,146 @@
|
||||
/**
|
||||
* This program is free software, you can redistribute it and/or modify it.
|
||||
* Copyright (c) 2025 Huawei Technologies Co., Ltd.
|
||||
* This file is a part of the CANN Open Software.
|
||||
* Licensed under CANN Open Software License Agreement Version 2.0 (the "License").
|
||||
* Please refer to the License for details. You may not use this file except in compliance with the License.
|
||||
* THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE.
|
||||
* See LICENSE in the root of the software repository for the full text of the License.
|
||||
*/
|
||||
|
||||
/*!
|
||||
* \file lightning_indexer_quant_common.h
|
||||
* \brief
|
||||
*/
|
||||
#ifndef LIGHTNING_INDEXER_QUANT_COMMON_H
|
||||
#define LIGHTNING_INDEXER_QUANT_COMMON_H
|
||||
|
||||
namespace LIQCommon {
|
||||
|
||||
// 与tiling的layout保持一致
|
||||
enum class LI_LAYOUT : uint32_t {
|
||||
BSND = 0,
|
||||
TND = 1,
|
||||
PA_BSND = 2
|
||||
};
|
||||
|
||||
template <typename Q_T, typename K_T, typename OUT_T, const bool PAGE_ATTENTION = false,
|
||||
LI_LAYOUT Q_LAYOUT_T = LI_LAYOUT::BSND, LI_LAYOUT K_LAYOUT_T = LI_LAYOUT::PA_BSND, typename... Args>
|
||||
struct LIQType {
|
||||
using queryType = Q_T;
|
||||
using keyType = K_T;
|
||||
using outputType = OUT_T;
|
||||
static constexpr bool pageAttention = PAGE_ATTENTION;
|
||||
static constexpr LI_LAYOUT layout = Q_LAYOUT_T;
|
||||
static constexpr LI_LAYOUT keyLayout = K_LAYOUT_T;
|
||||
};
|
||||
|
||||
struct RunInfo {
|
||||
uint32_t loop;
|
||||
uint32_t bN2Idx;
|
||||
uint32_t bIdx;
|
||||
uint32_t n2Idx = 0;
|
||||
uint32_t gS1Idx;
|
||||
uint32_t s2Idx;
|
||||
|
||||
uint32_t actS1Size = 1;
|
||||
uint32_t actS2Size = 1;
|
||||
uint32_t actMBaseSize;
|
||||
uint32_t actualSingleProcessSInnerSize;
|
||||
uint32_t actualSingleProcessSInnerSizeAlign;
|
||||
|
||||
uint64_t tensorQueryOffset;
|
||||
uint64_t tensorKeyOffset;
|
||||
uint64_t tensorKeyScaleOffset;
|
||||
uint64_t tensorWeightsOffset;
|
||||
uint64_t indiceOutOffset;
|
||||
|
||||
bool isFirstS2InnerLoop;
|
||||
bool isLastS2InnerLoop;
|
||||
bool isAllLoopEnd = false;
|
||||
bool isValid = false;
|
||||
};
|
||||
|
||||
struct ConstInfo {
|
||||
// CUBE与VEC核间同步的模式
|
||||
static constexpr uint32_t FIA_SYNC_MODE2 = 2;
|
||||
// BUFFER的字节数
|
||||
static constexpr uint32_t BUFFER_SIZE_BYTE_32B = 32;
|
||||
static constexpr uint32_t BUFFER_SIZE_BYTE_64B = 64;
|
||||
static constexpr uint32_t BUFFER_SIZE_BYTE_256B = 256;
|
||||
static constexpr uint32_t BUFFER_SIZE_BYTE_512B = 512;
|
||||
static constexpr uint32_t BUFFER_SIZE_BYTE_1K = 1024;
|
||||
static constexpr uint32_t BUFFER_SIZE_BYTE_2K = 2048;
|
||||
static constexpr uint32_t BUFFER_SIZE_BYTE_4K = 4096;
|
||||
static constexpr uint32_t BUFFER_SIZE_BYTE_8K = 8192;
|
||||
static constexpr uint32_t BUFFER_SIZE_BYTE_16K = 16384;
|
||||
static constexpr uint32_t BUFFER_SIZE_BYTE_32K = 32768;
|
||||
// 无效索引
|
||||
static constexpr int INVALID_IDX = -1;
|
||||
|
||||
// CUBE和VEC的核间同步EventID
|
||||
uint32_t syncC1V1 = 0U;
|
||||
uint32_t syncC1V0 = 2U;
|
||||
uint32_t syncV1C1 = 0U;
|
||||
uint32_t syncV0C1 = 1U;
|
||||
|
||||
// 基本块大小
|
||||
uint32_t mBaseSize = 1ULL;
|
||||
uint32_t s1BaseSize = 1ULL;
|
||||
uint32_t s2BaseSize = 1ULL;
|
||||
|
||||
uint64_t batchSize = 0ULL;
|
||||
uint64_t gSize = 0ULL;
|
||||
uint64_t qHeadNum = 0ULL;
|
||||
uint64_t kHeadNum;
|
||||
uint64_t headDim;
|
||||
uint64_t sparseCount; // topK选取大小
|
||||
uint64_t kSeqSize = 0ULL; // kv最大S长度
|
||||
uint64_t qSeqSize = 1ULL; // q最大S长度
|
||||
uint32_t kCacheBlockSize = 0; // PA场景的block size
|
||||
uint32_t maxBlockNumPerBatch = 0; // PA场景的最大单batch block number
|
||||
LI_LAYOUT outputLayout; // 输出的格式
|
||||
bool attenMaskFlag = false;
|
||||
|
||||
uint32_t actualLenQDims = 0U; // query的actualSeqLength 的维度
|
||||
uint32_t actualLenDims = 0U; // KV 的actualSeqLength 的维度
|
||||
bool isAccumSeqS1 = false; // 是否累加模式
|
||||
bool isAccumSeqS2 = false; // 是否累加模式
|
||||
};
|
||||
|
||||
struct SplitCoreInfo {
|
||||
uint32_t s2Start = 0U; // S2的起始位置
|
||||
uint32_t s2End = 0U; // S2循环index上限
|
||||
uint32_t bN2Start = 0U;
|
||||
uint32_t bN2End = 0U;
|
||||
uint32_t gS1Start = 0U;
|
||||
uint32_t gS1End = 0U;
|
||||
bool isLD = false; // 当前核是否需要进行Decode归约任务
|
||||
};
|
||||
|
||||
template <typename T>
|
||||
__aicore__ inline T Align(T num, T rnd)
|
||||
{
|
||||
return (((rnd) == 0) ? 0 : (((num) + (rnd)-1) / (rnd) * (rnd)));
|
||||
}
|
||||
|
||||
template <typename T1, typename T2>
|
||||
__aicore__ inline T1 Min(T1 a, T2 b)
|
||||
{
|
||||
return (a > b) ? (b) : (a);
|
||||
}
|
||||
|
||||
template <typename T1, typename T2>
|
||||
__aicore__ inline T1 Max(T1 a, T2 b)
|
||||
{
|
||||
return (a > b) ? (a) : (b);
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
__aicore__ inline T CeilDiv(T num, T rnd)
|
||||
{
|
||||
return (((rnd) == 0) ? 0 : (((num) + (rnd)-1) / (rnd)));
|
||||
}
|
||||
} // namespace LIQCommon
|
||||
|
||||
#endif // LIGHTNING_INDEXER_QUANT_COMMON_H
|
||||
@@ -0,0 +1,714 @@
|
||||
/**
|
||||
* This program is free software, you can redistribute it and/or modify it.
|
||||
* Copyright (c) 2025 Huawei Technologies Co., Ltd.
|
||||
* This file is a part of the CANN Open Software.
|
||||
* Licensed under CANN Open Software License Agreement Version 2.0 (the "License").
|
||||
* Please refer to the License for details. You may not use this file except in compliance with the License.
|
||||
* THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE.
|
||||
* See LICENSE in the root of the software repository for the full text of the License.
|
||||
*/
|
||||
|
||||
/*!
|
||||
* \file lightning_indexer_quant_kernel.h
|
||||
* \brief
|
||||
*/
|
||||
|
||||
#ifndef LIGHTNING_INDEXER_QUANT_KERNEL_H
|
||||
#define LIGHTNING_INDEXER_QUANT_KERNEL_H
|
||||
|
||||
#include "kernel_operator.h"
|
||||
#include "kernel_operator_list_tensor_intf.h"
|
||||
#include "kernel_tiling/kernel_tiling.h"
|
||||
#include "lib/matmul_intf.h"
|
||||
#include "lib/matrix/matmul/tiling.h"
|
||||
#include "lightning_indexer_quant_common.h"
|
||||
#include "lightning_indexer_quant_service_vector.h"
|
||||
#include "lightning_indexer_quant_service_cube.h"
|
||||
|
||||
namespace LIQKernel {
|
||||
using namespace LIQCommon;
|
||||
using namespace LIQServiceVec;
|
||||
using namespace matmul;
|
||||
using AscendC::CacheMode;
|
||||
using AscendC::CrossCoreSetFlag;
|
||||
using AscendC::CrossCoreWaitFlag;
|
||||
|
||||
// 由于S2循环前,RunInfo还没有赋值,使用TempLoopInfo临时存放B、N、S1轴相关的信息;同时减少重复计算
|
||||
struct TempLoopInfo {
|
||||
uint32_t bN2Idx = 0;
|
||||
uint32_t bIdx = 0U;
|
||||
uint32_t n2Idx = 0U;
|
||||
uint32_t gS1Idx = 0U;
|
||||
uint32_t gS1LoopEnd = 0U; // gS1方向循环的结束Idx
|
||||
uint32_t s2LoopEnd = 0U; // S2方向循环的结束Idx
|
||||
uint32_t actS1Size = 1ULL; // 当前Batch循环处理的S1轴的实际大小
|
||||
uint32_t actS2Size = 0ULL;
|
||||
bool curActSeqLenIsZero = false;
|
||||
bool needDealActS1LessThanS1 = false; // S1的实际长度小于shape的S1长度时,是否需要清理输出
|
||||
uint32_t actMBaseSize = 0U; // m轴(gS1)方向实际大小
|
||||
uint32_t mBasicSizeTail = 0U; // gS1方向循环的尾基本块大小
|
||||
uint32_t s2BasicSizeTail = 0U; // S2方向循环的尾基本块大小
|
||||
};
|
||||
|
||||
template <typename LIQT>
|
||||
class LIQPreload {
|
||||
public:
|
||||
__aicore__ inline LIQPreload(){};
|
||||
__aicore__ inline void Init(__gm__ uint8_t *query, __gm__ uint8_t *key, __gm__ uint8_t *weights,
|
||||
__gm__ uint8_t *queryScale, __gm__ uint8_t *keyScale, __gm__ uint8_t *actualSeqLengthsQ,
|
||||
__gm__ uint8_t *actualSeqLengthsK, __gm__ uint8_t *blockTable,
|
||||
__gm__ uint8_t *sparseIndices, __gm__ uint8_t *workspace,
|
||||
const LIQTilingData *__restrict tiling, TPipe *tPipe);
|
||||
__aicore__ inline void Process();
|
||||
|
||||
// =================================类型定义区=================================
|
||||
using Q_T = typename LIQT::queryType;
|
||||
using K_T = typename LIQT::keyType;
|
||||
using OUT_T = typename LIQT::outputType;
|
||||
static constexpr bool PAGE_ATTENTION = LIQT::pageAttention;
|
||||
static constexpr LI_LAYOUT Q_LAYOUT_T = LIQT::layout;
|
||||
static constexpr LI_LAYOUT K_LAYOUT_T = LIQT::keyLayout;
|
||||
|
||||
using MM1_OUT_T = float;
|
||||
|
||||
LIQMatmul<LIQT> matmulService;
|
||||
LIQVector<LIQT> vectorService;
|
||||
|
||||
// =================================常量区=================================
|
||||
static constexpr uint32_t SYNC_C1_V1_FLAG = 4;
|
||||
static constexpr uint32_t SYNC_V1_C1_FLAG = 5;
|
||||
|
||||
static constexpr uint32_t M_BASE_SIZE = 256;
|
||||
static constexpr uint32_t S2_BASE_SIZE = 2048;
|
||||
static constexpr uint32_t HEAD_DIM = 128;
|
||||
static constexpr uint32_t K_HEAD_NUM = 1;
|
||||
static constexpr uint32_t GM_ALIGN_BYTES = 512;
|
||||
static constexpr uint32_t LI_QUANT_PRELOAD_TASK_CACHE_SIZE = 2;
|
||||
|
||||
static constexpr int64_t LD_PREFETCH_LEN = 2;
|
||||
// for workspace double
|
||||
static constexpr uint32_t WS_DOBULE = 2;
|
||||
|
||||
protected:
|
||||
TPipe *pipe = nullptr;
|
||||
|
||||
// offset
|
||||
uint64_t queryCoreOffset = 0ULL;
|
||||
uint64_t keyCoreOffset = 0ULL;
|
||||
uint64_t keyScaleCoreOffset = 0ULL;
|
||||
uint64_t weightsCoreOffset = 0ULL;
|
||||
uint64_t indiceOutCoreOffset = 0ULL;
|
||||
|
||||
// ================================Global Buffer区=================================
|
||||
GlobalTensor<Q_T> queryGm;
|
||||
GlobalTensor<K_T> keyGm;
|
||||
GlobalTensor<half> weightsGm;
|
||||
|
||||
GlobalTensor<int32_t> indiceOutGm;
|
||||
GlobalTensor<int32_t> blockTableGm;
|
||||
|
||||
GlobalTensor<uint32_t> actualSeqLengthsGmQ;
|
||||
GlobalTensor<uint32_t> actualSeqLengthsGm;
|
||||
|
||||
// ================================类成员变量====================================
|
||||
// aic、aiv核信息
|
||||
uint32_t tmpBlockIdx = 0U;
|
||||
uint32_t aiCoreIdx = 0U;
|
||||
uint32_t usedCoreNum = 0U;
|
||||
|
||||
LIQCommon::ConstInfo constInfo{};
|
||||
TempLoopInfo tempLoopInfo{};
|
||||
LIQCommon::SplitCoreInfo splitCoreInfo{};
|
||||
|
||||
// ================================Init functions==================================
|
||||
__aicore__ inline void InitTilingData(const LIQTilingData *__restrict tilingData);
|
||||
__aicore__ inline void InitBuffers();
|
||||
__aicore__ inline void InitActualSeqLen(__gm__ uint8_t *actualSeqLengthsQ, __gm__ uint8_t *actualSeqLengthsK);
|
||||
// ================================Split Core================================
|
||||
__aicore__ inline void SplitCore(uint32_t curCoreIdx, uint32_t &coreNum, LIQCommon::SplitCoreInfo &info);
|
||||
__aicore__ inline uint32_t GetS2BaseBlockNumOnMask(uint32_t s1gIdx, uint32_t actS1Size, uint32_t actS2Size);
|
||||
__aicore__ inline uint32_t GetTotalBaseBlockNum();
|
||||
// ================================Process functions================================
|
||||
__aicore__ inline void ProcessMain();
|
||||
__aicore__ inline void ProcessBaseBlock(uint32_t loop, uint64_t s2LoopIdx,
|
||||
LIQCommon::RunInfo runInfo[LI_QUANT_PRELOAD_TASK_CACHE_SIZE]);
|
||||
__aicore__ inline void ProcessDecode();
|
||||
__aicore__ inline void ProcessInvalid();
|
||||
// ================================Params Calc=====================================
|
||||
__aicore__ inline void CalcGS1LoopParams(uint32_t bN2Idx);
|
||||
__aicore__ inline void GetBN2Idx(uint32_t bN2Idx);
|
||||
__aicore__ inline uint32_t GetActualSeqLen(uint32_t bIdx, uint32_t actualLenDims, bool isAccumSeq,
|
||||
GlobalTensor<uint32_t> &actualSeqLengthsGm, uint32_t defaultSeqLen);
|
||||
__aicore__ inline void GetS1S2ActualSeqLen(uint32_t bIdx, uint32_t &actS1Size, uint32_t &actS2Size);
|
||||
__aicore__ inline void CalcS2LoopParams(uint32_t bN2LoopIdx, uint32_t gS1LoopIdx);
|
||||
__aicore__ inline void CalcRunInfo(uint32_t loop, uint32_t s2LoopIdx, LIQCommon::RunInfo &runInfo);
|
||||
__aicore__ inline void DealActSeqLenIsZero(uint32_t bIdx, uint32_t n2Idx, uint32_t s1Start);
|
||||
};
|
||||
|
||||
template <typename LIQT>
|
||||
__aicore__ inline void LIQPreload<LIQT>::InitTilingData(const LIQTilingData *__restrict tilingData)
|
||||
{
|
||||
usedCoreNum = tilingData->usedCoreNum;
|
||||
constInfo.batchSize = tilingData->bSize;
|
||||
constInfo.qHeadNum = constInfo.gSize = tilingData->gSize;
|
||||
constInfo.kSeqSize = tilingData->s2Size;
|
||||
constInfo.qSeqSize = tilingData->s1Size;
|
||||
constInfo.attenMaskFlag = (tilingData->sparseMode == 3);
|
||||
constInfo.kCacheBlockSize = tilingData->blockSize;
|
||||
constInfo.maxBlockNumPerBatch = tilingData->maxBlockNumPerBatch;
|
||||
constInfo.sparseCount = tilingData->sparseCount;
|
||||
constInfo.outputLayout = Q_LAYOUT_T; // 输出和输入形状一致
|
||||
if (Q_LAYOUT_T == LI_LAYOUT::TND) {
|
||||
constInfo.isAccumSeqS1 = true;
|
||||
}
|
||||
if (K_LAYOUT_T == LI_LAYOUT::TND) {
|
||||
constInfo.isAccumSeqS2 = true;
|
||||
}
|
||||
|
||||
constInfo.kHeadNum = K_HEAD_NUM;
|
||||
constInfo.headDim = HEAD_DIM;
|
||||
|
||||
constInfo.mBaseSize = M_BASE_SIZE;
|
||||
constInfo.s2BaseSize = S2_BASE_SIZE;
|
||||
constInfo.s1BaseSize = (constInfo.mBaseSize + constInfo.gSize - 1) / constInfo.gSize;
|
||||
}
|
||||
|
||||
template <typename LIQT>
|
||||
__aicore__ inline void LIQPreload<LIQT>::InitBuffers()
|
||||
{
|
||||
if ASCEND_IS_AIV {
|
||||
vectorService.InitBuffers(pipe);
|
||||
} else {
|
||||
matmulService.InitBuffers(pipe);
|
||||
}
|
||||
}
|
||||
|
||||
template <typename LIQT>
|
||||
__aicore__ inline void LIQPreload<LIQT>::InitActualSeqLen(__gm__ uint8_t *actualSeqLengthsQ,
|
||||
__gm__ uint8_t *actualSeqLengthsK)
|
||||
{
|
||||
if (actualSeqLengthsQ == nullptr) {
|
||||
constInfo.actualLenQDims = 0;
|
||||
} else {
|
||||
constInfo.actualLenQDims = constInfo.batchSize;
|
||||
actualSeqLengthsGmQ.SetGlobalBuffer((__gm__ uint32_t *)actualSeqLengthsQ, constInfo.actualLenQDims);
|
||||
}
|
||||
if (actualSeqLengthsK == nullptr) {
|
||||
constInfo.actualLenDims = 0;
|
||||
} else {
|
||||
constInfo.actualLenDims = constInfo.batchSize;
|
||||
actualSeqLengthsGm.SetGlobalBuffer((__gm__ uint32_t *)actualSeqLengthsK, constInfo.actualLenDims);
|
||||
}
|
||||
}
|
||||
|
||||
template <typename LIQT>
|
||||
__aicore__ inline uint32_t LIQPreload<LIQT>::GetActualSeqLen(uint32_t bIdx, uint32_t actualLenDims, bool isAccumSeq,
|
||||
GlobalTensor<uint32_t> &actualSeqLengthsGm,
|
||||
uint32_t defaultSeqLen)
|
||||
{
|
||||
if (actualLenDims == 0) {
|
||||
return defaultSeqLen;
|
||||
} else if (isAccumSeq && bIdx > 0) {
|
||||
return actualSeqLengthsGm.GetValue(bIdx) - actualSeqLengthsGm.GetValue(bIdx - 1);
|
||||
} else {
|
||||
return actualSeqLengthsGm.GetValue(bIdx);
|
||||
}
|
||||
}
|
||||
|
||||
template <typename LIQT>
|
||||
__aicore__ inline void LIQPreload<LIQT>::GetS1S2ActualSeqLen(uint32_t bIdx, uint32_t &actS1Size, uint32_t &actS2Size)
|
||||
{
|
||||
actS1Size = GetActualSeqLen(bIdx, constInfo.actualLenQDims, constInfo.isAccumSeqS1, actualSeqLengthsGmQ,
|
||||
constInfo.qSeqSize);
|
||||
actS2Size =
|
||||
GetActualSeqLen(bIdx, constInfo.actualLenDims, constInfo.isAccumSeqS2, actualSeqLengthsGm, constInfo.kSeqSize);
|
||||
}
|
||||
|
||||
template <typename LIQT>
|
||||
__aicore__ inline uint32_t LIQPreload<LIQT>::GetS2BaseBlockNumOnMask(uint32_t s1gIdx, uint32_t actS1Size,
|
||||
uint32_t actS2Size)
|
||||
{
|
||||
if (actS2Size == 0) {
|
||||
return 0;
|
||||
}
|
||||
uint32_t s1Offset = constInfo.s1BaseSize * s1gIdx;
|
||||
int32_t validS2LenBase = static_cast<int32_t>(actS2Size) - static_cast<int32_t>(actS1Size);
|
||||
int32_t validS2Len = s1Offset + validS2LenBase + constInfo.s1BaseSize;
|
||||
validS2Len = Min(validS2Len, static_cast<int32_t>(actS2Size));
|
||||
validS2Len = Max(validS2Len, 1);
|
||||
return (validS2Len + constInfo.s2BaseSize - 1) / constInfo.s2BaseSize;
|
||||
}
|
||||
|
||||
template <typename LIQT>
|
||||
__aicore__ inline uint32_t LIQPreload<LIQT>::GetTotalBaseBlockNum()
|
||||
{
|
||||
uint32_t totalBlockNum = 0;
|
||||
uint32_t actS1Size, actS2Size;
|
||||
uint32_t s1GBaseNum, s2BaseNum;
|
||||
for (uint32_t bIdx = 0; bIdx < constInfo.batchSize; bIdx++) {
|
||||
GetS1S2ActualSeqLen(bIdx, actS1Size, actS2Size);
|
||||
s1GBaseNum = CeilDiv(actS1Size, constInfo.s1BaseSize);
|
||||
if (!constInfo.attenMaskFlag) {
|
||||
s2BaseNum = CeilDiv(actS2Size, constInfo.s2BaseSize);
|
||||
totalBlockNum += s1GBaseNum * s2BaseNum * constInfo.kHeadNum;
|
||||
continue;
|
||||
}
|
||||
for (uint32_t s1gIdx = 0; s1gIdx < s1GBaseNum; s1gIdx++) {
|
||||
s2BaseNum = GetS2BaseBlockNumOnMask(s1gIdx, actS1Size, actS2Size);
|
||||
totalBlockNum += s2BaseNum * constInfo.kHeadNum;
|
||||
}
|
||||
}
|
||||
return totalBlockNum;
|
||||
}
|
||||
|
||||
// 多核版本,双闭区间。基本原则:计算每个核最少处理的块数, 剩余的部分前面的核每个核多处理一块
|
||||
template <typename LIQT>
|
||||
__aicore__ void inline LIQPreload<LIQT>::SplitCore(uint32_t curCoreIdx, uint32_t &coreNum,
|
||||
LIQCommon::SplitCoreInfo &info)
|
||||
{
|
||||
uint32_t totalBlockNum = GetTotalBaseBlockNum();
|
||||
uint32_t minBlockPerCore = totalBlockNum / coreNum;
|
||||
uint32_t deal1MoreBlockCoreNum = totalBlockNum % coreNum;
|
||||
uint32_t coreIdx = 0;
|
||||
uint32_t lastGS1RemainBlockCnt = 0;
|
||||
uint32_t coreDealBlockCnt = coreIdx < deal1MoreBlockCoreNum ? minBlockPerCore + 1 : minBlockPerCore;
|
||||
coreNum = minBlockPerCore == 0 ? deal1MoreBlockCoreNum : coreNum;
|
||||
|
||||
bool findLastCoreEnd = true;
|
||||
uint32_t actS1Size, actS2Size;
|
||||
uint32_t s1GBaseNum, s2BaseNum;
|
||||
for (uint32_t bN2Idx = 0; bN2Idx < constInfo.batchSize * constInfo.kHeadNum; bN2Idx++) {
|
||||
uint32_t bIdx = bN2Idx / constInfo.kHeadNum;
|
||||
if (bN2Idx % constInfo.kHeadNum == 0) {
|
||||
GetS1S2ActualSeqLen(bIdx, actS1Size, actS2Size);
|
||||
s1GBaseNum = CeilDiv(actS1Size, constInfo.s1BaseSize);
|
||||
s2BaseNum = CeilDiv(actS2Size, constInfo.s2BaseSize);
|
||||
}
|
||||
if constexpr (Q_LAYOUT_T == LI_LAYOUT::BSND) {
|
||||
if (findLastCoreEnd && (s1GBaseNum == 0U || s2BaseNum == 0U)) {
|
||||
info.bN2Start = bN2Idx;
|
||||
info.gS1Start = 0;
|
||||
info.s2Start = 0;
|
||||
findLastCoreEnd = false;
|
||||
}
|
||||
}
|
||||
for (uint32_t gS1Idx = 0; gS1Idx < s1GBaseNum; gS1Idx++) {
|
||||
if (constInfo.attenMaskFlag) {
|
||||
s2BaseNum = GetS2BaseBlockNumOnMask(gS1Idx, actS1Size, actS2Size);
|
||||
}
|
||||
if (findLastCoreEnd && s2BaseNum == 0U) {
|
||||
info.bN2Start = bN2Idx;
|
||||
info.gS1Start = gS1Idx;
|
||||
info.s2Start = 0;
|
||||
findLastCoreEnd = false;
|
||||
}
|
||||
for (uint32_t s2Idx = 0; s2Idx < s2BaseNum;) {
|
||||
if (findLastCoreEnd) {
|
||||
info.bN2Start = bN2Idx;
|
||||
info.gS1Start = gS1Idx;
|
||||
info.s2Start = s2Idx;
|
||||
findLastCoreEnd = false;
|
||||
}
|
||||
uint32_t s2RemainBaseNum = s2BaseNum - s2Idx;
|
||||
if (lastGS1RemainBlockCnt + s2RemainBaseNum >= coreDealBlockCnt) {
|
||||
info.bN2End = bN2Idx;
|
||||
info.gS1End = gS1Idx;
|
||||
info.s2End = s2Idx + coreDealBlockCnt - lastGS1RemainBlockCnt - 1;
|
||||
|
||||
if (coreIdx == curCoreIdx) {
|
||||
// S2被切N核,那么只有第一个核需要处理LD,其他核不用
|
||||
if (s2Idx == 0 && info.s2End + 1 < s2BaseNum) {
|
||||
info.isLD = true;
|
||||
}
|
||||
// 最后一个核处理的不是最后一个Batch,表明后面的Batch为空块(S2=0), 调整终点坐标以便清理输出
|
||||
if (coreIdx == coreNum - 1 && info.bN2End != constInfo.batchSize - 1) {
|
||||
info.bN2End = constInfo.batchSize - 1;
|
||||
info.gS1End = 0;
|
||||
info.s2End = 0;
|
||||
}
|
||||
return;
|
||||
}
|
||||
coreIdx++;
|
||||
findLastCoreEnd = true;
|
||||
s2Idx = info.s2End + 1;
|
||||
lastGS1RemainBlockCnt = 0;
|
||||
coreDealBlockCnt = coreIdx < deal1MoreBlockCoreNum ? minBlockPerCore + 1 : minBlockPerCore;
|
||||
} else {
|
||||
lastGS1RemainBlockCnt += s2RemainBaseNum;
|
||||
break;
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
template <typename LIQT>
|
||||
__aicore__ inline void LIQPreload<LIQT>::DealActSeqLenIsZero(uint32_t bIdx, uint32_t n2Idx, uint32_t s1Start)
|
||||
{
|
||||
if ASCEND_IS_AIV {
|
||||
if (constInfo.outputLayout == LI_LAYOUT::TND) {
|
||||
uint32_t tSize = actualSeqLengthsGmQ.GetValue(constInfo.batchSize - 1);
|
||||
uint32_t tBase = bIdx == 0 ? 0 : actualSeqLengthsGmQ.GetValue(bIdx - 1);
|
||||
uint32_t s1Count = tempLoopInfo.actS1Size;
|
||||
|
||||
for (uint32_t s1Idx = s1Start; s1Idx < s1Count; s1Idx++) {
|
||||
uint64_t indiceOutOffset =
|
||||
(tBase + s1Idx) * constInfo.kHeadNum * constInfo.sparseCount + // T轴、s1轴偏移
|
||||
n2Idx * constInfo.sparseCount; // N2轴偏移
|
||||
vectorService.CleanInvalidOutput(indiceOutOffset);
|
||||
}
|
||||
} else if (constInfo.outputLayout == LI_LAYOUT::BSND) {
|
||||
for (uint32_t s1Idx = s1Start; s1Idx < constInfo.qSeqSize; s1Idx++) {
|
||||
// B,S1,N2,K
|
||||
uint64_t indiceOutOffset = bIdx * constInfo.qSeqSize * constInfo.kHeadNum * constInfo.sparseCount +
|
||||
s1Idx * constInfo.kHeadNum * constInfo.sparseCount + // B轴、S1轴偏移
|
||||
n2Idx * constInfo.sparseCount; // N2轴偏移
|
||||
vectorService.CleanInvalidOutput(indiceOutOffset);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
template <typename LIQT>
|
||||
__aicore__ inline void LIQPreload<LIQT>::Init(__gm__ uint8_t *query, __gm__ uint8_t *key, __gm__ uint8_t *weights,
|
||||
__gm__ uint8_t *queryScale, __gm__ uint8_t *keyScale,
|
||||
__gm__ uint8_t *actualSeqLengthsQ, __gm__ uint8_t *actualSeqLengthsK,
|
||||
__gm__ uint8_t *blockTable, __gm__ uint8_t *sparseIndices,
|
||||
__gm__ uint8_t *workspace, const LIQTilingData *__restrict tiling,
|
||||
TPipe *tPipe)
|
||||
{
|
||||
if ASCEND_IS_AIV {
|
||||
tmpBlockIdx = GetBlockIdx(); // vec:0-47
|
||||
aiCoreIdx = tmpBlockIdx / 2;
|
||||
} else {
|
||||
tmpBlockIdx = GetBlockIdx(); // cube:0-23
|
||||
aiCoreIdx = tmpBlockIdx;
|
||||
}
|
||||
|
||||
InitTilingData(tiling);
|
||||
InitActualSeqLen(actualSeqLengthsQ, actualSeqLengthsK);
|
||||
|
||||
// 计算分核
|
||||
SplitCore(aiCoreIdx, usedCoreNum, splitCoreInfo);
|
||||
|
||||
pipe = tPipe;
|
||||
// workspace 内存排布
|
||||
// |mm1ResGm(存S)|vec1ResGm(存LD中间结果)|vec1ParamGm(存LD参数)
|
||||
// |Core0_mm1ResDB0-Core0_mm1ResDB1-Core1_mm1ResDB0....Core23_mm1ResDB0-Core23_mm1ResDB1|Core0_vec1Res...
|
||||
uint64_t offset = 0;
|
||||
|
||||
// mm1开DoubleBuffer
|
||||
GlobalTensor<MM1_OUT_T> mm1ResGm; // 存放S
|
||||
uint64_t singleCoreMm1ResSize = WS_DOBULE * constInfo.s1BaseSize * constInfo.s2BaseSize * sizeof(MM1_OUT_T);
|
||||
mm1ResGm.SetGlobalBuffer((__gm__ MM1_OUT_T *)(workspace + aiCoreIdx * singleCoreMm1ResSize));
|
||||
offset += GetBlockNum() * singleCoreMm1ResSize;
|
||||
|
||||
// ld流程需要ws大小: [aicnum, 2, CeilDiv(constInfo.mBaseSize, constInfo.gSize), topkOut_*2]
|
||||
// (aic, 8, 2, 2, 2048)
|
||||
// (aic, s1_cube, 头尾, idx/value, K)
|
||||
GlobalTensor<float> vec1ResGm; // 存放TopK计算中间结果
|
||||
vec1ResGm.SetGlobalBuffer((__gm__ float *)(workspace + offset));
|
||||
offset += GetBlockNum() * constInfo.s1BaseSize * WS_DOBULE * WS_DOBULE * BASE_TOPK * sizeof(float);
|
||||
|
||||
// (aic, 8, 2, 16)
|
||||
// (aic, s1_cube, 头尾,16ele)
|
||||
GlobalTensor<int64_t> vec1ParamGm; // 存放LD参数信息
|
||||
vec1ParamGm.SetGlobalBuffer((__gm__ int64_t *)(workspace + offset));
|
||||
offset += GetBlockNum() * constInfo.s1BaseSize * WS_DOBULE * LD_PARAM_NUM * sizeof(int64_t);
|
||||
|
||||
GlobalTensor<half> weightWorkspaceGm; // v1阶段处理w*scale后的结果
|
||||
uint64_t weightMemSize = BLOCK_CUBE * constInfo.mBaseSize * WS_DOBULE * sizeof(half);
|
||||
weightWorkspaceGm.SetGlobalBuffer((__gm__ half *)(workspace + offset + aiCoreIdx * weightMemSize));
|
||||
offset += GetBlockNum() * weightMemSize;
|
||||
|
||||
GlobalTensor<half> qScaleGm;
|
||||
GlobalTensor<half> kScaleGm;
|
||||
if ASCEND_IS_AIV {
|
||||
vectorService.InitParams(constInfo, tiling);
|
||||
indiceOutGm.SetGlobalBuffer((__gm__ int32_t *)sparseIndices);
|
||||
weightsGm.SetGlobalBuffer((__gm__ half *)weights);
|
||||
qScaleGm.SetGlobalBuffer((__gm__ half *)queryScale);
|
||||
kScaleGm.SetGlobalBuffer((__gm__ half *)keyScale);
|
||||
blockTableGm.SetGlobalBuffer((__gm__ int32_t *)blockTable);
|
||||
vectorService.InitVecInputTensor(weightsGm, qScaleGm, kScaleGm, indiceOutGm, blockTableGm);
|
||||
vectorService.InitVecWorkspaceTensor(weightWorkspaceGm, mm1ResGm, vec1ResGm, vec1ParamGm);
|
||||
} else {
|
||||
matmulService.InitParams(constInfo);
|
||||
queryGm.SetGlobalBuffer((__gm__ Q_T *)query);
|
||||
if constexpr (PAGE_ATTENTION) {
|
||||
blockTableGm.SetGlobalBuffer((__gm__ int32_t *)blockTable);
|
||||
}
|
||||
keyGm.SetGlobalBuffer((__gm__ K_T *)key);
|
||||
matmulService.InitMm1GlobalTensor(blockTableGm, keyGm, queryGm, mm1ResGm, weightWorkspaceGm);
|
||||
}
|
||||
InitBuffers();
|
||||
}
|
||||
|
||||
template <typename LIQT>
|
||||
__aicore__ inline void LIQPreload<LIQT>::GetBN2Idx(uint32_t bN2Idx)
|
||||
{
|
||||
tempLoopInfo.bN2Idx = bN2Idx;
|
||||
tempLoopInfo.bIdx = bN2Idx / constInfo.kHeadNum;
|
||||
tempLoopInfo.n2Idx = bN2Idx % constInfo.kHeadNum;
|
||||
}
|
||||
|
||||
template <typename LIQT>
|
||||
__aicore__ inline void LIQPreload<LIQT>::CalcS2LoopParams(uint32_t bN2LoopIdx, uint32_t gS1LoopIdx)
|
||||
{
|
||||
tempLoopInfo.gS1Idx = gS1LoopIdx;
|
||||
tempLoopInfo.actMBaseSize = constInfo.mBaseSize;
|
||||
uint32_t remainedGS1Size = tempLoopInfo.actS1Size * constInfo.gSize - tempLoopInfo.gS1Idx * constInfo.mBaseSize;
|
||||
if (remainedGS1Size <= constInfo.mBaseSize && remainedGS1Size > 0) {
|
||||
tempLoopInfo.actMBaseSize = tempLoopInfo.mBasicSizeTail;
|
||||
}
|
||||
|
||||
bool isEnd = (bN2LoopIdx == splitCoreInfo.bN2End) && (gS1LoopIdx == splitCoreInfo.gS1End);
|
||||
uint32_t s2BlockNum;
|
||||
if (constInfo.attenMaskFlag) {
|
||||
s2BlockNum = GetS2BaseBlockNumOnMask(gS1LoopIdx, tempLoopInfo.actS1Size, tempLoopInfo.actS2Size);
|
||||
} else {
|
||||
s2BlockNum = (tempLoopInfo.actS2Size + constInfo.s2BaseSize - 1) / constInfo.s2BaseSize;
|
||||
}
|
||||
tempLoopInfo.s2LoopEnd = isEnd ? splitCoreInfo.s2End : s2BlockNum - 1;
|
||||
}
|
||||
|
||||
template <typename LIQT>
|
||||
__aicore__ inline void LIQPreload<LIQT>::CalcGS1LoopParams(uint32_t bN2LoopIdx)
|
||||
{
|
||||
GetBN2Idx(bN2LoopIdx);
|
||||
GetS1S2ActualSeqLen(tempLoopInfo.bIdx, tempLoopInfo.actS1Size, tempLoopInfo.actS2Size);
|
||||
if ((tempLoopInfo.actS2Size == 0) || (tempLoopInfo.actS1Size == 0)) {
|
||||
tempLoopInfo.curActSeqLenIsZero = true;
|
||||
return;
|
||||
}
|
||||
tempLoopInfo.curActSeqLenIsZero = false;
|
||||
tempLoopInfo.s2BasicSizeTail = tempLoopInfo.actS2Size % constInfo.s2BaseSize;
|
||||
tempLoopInfo.s2BasicSizeTail =
|
||||
(tempLoopInfo.s2BasicSizeTail == 0) ? constInfo.s2BaseSize : tempLoopInfo.s2BasicSizeTail;
|
||||
tempLoopInfo.mBasicSizeTail = (tempLoopInfo.actS1Size * constInfo.gSize) % constInfo.mBaseSize;
|
||||
tempLoopInfo.mBasicSizeTail =
|
||||
(tempLoopInfo.mBasicSizeTail == 0) ? constInfo.mBaseSize : tempLoopInfo.mBasicSizeTail;
|
||||
|
||||
uint32_t gS1SplitNum = (tempLoopInfo.actS1Size * constInfo.gSize + constInfo.mBaseSize - 1) / constInfo.mBaseSize;
|
||||
tempLoopInfo.gS1LoopEnd = (bN2LoopIdx == splitCoreInfo.bN2End) ? splitCoreInfo.gS1End : gS1SplitNum - 1;
|
||||
if constexpr (Q_LAYOUT_T == LI_LAYOUT::BSND) {
|
||||
if (tempLoopInfo.gS1LoopEnd == gS1SplitNum - 1 && constInfo.qSeqSize > tempLoopInfo.actS1Size) {
|
||||
tempLoopInfo.needDealActS1LessThanS1 = true;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
template <typename LIQT>
|
||||
__aicore__ inline void LIQPreload<LIQT>::CalcRunInfo(uint32_t loop, uint32_t s2LoopIdx, LIQCommon::RunInfo &runInfo)
|
||||
{
|
||||
runInfo.loop = loop;
|
||||
runInfo.bIdx = tempLoopInfo.bIdx;
|
||||
runInfo.gS1Idx = tempLoopInfo.gS1Idx;
|
||||
runInfo.s2Idx = s2LoopIdx;
|
||||
runInfo.bN2Idx = tempLoopInfo.bN2Idx;
|
||||
runInfo.isValid = s2LoopIdx <= tempLoopInfo.s2LoopEnd;
|
||||
|
||||
if (!runInfo.isValid) {
|
||||
return; // 需要验证, v1 时候需要runInfo
|
||||
}
|
||||
|
||||
runInfo.actS1Size = tempLoopInfo.actS1Size;
|
||||
runInfo.actS2Size = tempLoopInfo.actS2Size;
|
||||
// 计算实际基本块size
|
||||
runInfo.actMBaseSize = tempLoopInfo.actMBaseSize;
|
||||
runInfo.actualSingleProcessSInnerSize = constInfo.s2BaseSize;
|
||||
uint32_t s2SplitNum = (tempLoopInfo.actS2Size + constInfo.s2BaseSize - 1) / constInfo.s2BaseSize;
|
||||
if (runInfo.s2Idx == s2SplitNum - 1) {
|
||||
runInfo.actualSingleProcessSInnerSize = tempLoopInfo.s2BasicSizeTail;
|
||||
}
|
||||
runInfo.actualSingleProcessSInnerSizeAlign =
|
||||
LIQCommon::Align((uint32_t)runInfo.actualSingleProcessSInnerSize, LIQCommon::ConstInfo::BUFFER_SIZE_BYTE_32B);
|
||||
|
||||
runInfo.isFirstS2InnerLoop = s2LoopIdx == splitCoreInfo.s2Start;
|
||||
runInfo.isLastS2InnerLoop = s2LoopIdx == tempLoopInfo.s2LoopEnd;
|
||||
runInfo.isAllLoopEnd = (runInfo.bN2Idx == splitCoreInfo.bN2End) && (runInfo.gS1Idx == splitCoreInfo.gS1End) &&
|
||||
(runInfo.s2Idx == splitCoreInfo.s2End);
|
||||
|
||||
if (runInfo.isFirstS2InnerLoop) {
|
||||
uint64_t actualSeqQPrefixSum;
|
||||
if constexpr (Q_LAYOUT_T == LI_LAYOUT::TND) {
|
||||
actualSeqQPrefixSum = (runInfo.bIdx <= 0) ? 0 : actualSeqLengthsGmQ.GetValue(runInfo.bIdx - 1);
|
||||
} else { // BSND
|
||||
actualSeqQPrefixSum = (runInfo.bIdx <= 0) ? 0 : runInfo.bIdx * constInfo.qSeqSize;
|
||||
}
|
||||
uint64_t tndBIdxOffset = actualSeqQPrefixSum * constInfo.qHeadNum * constInfo.headDim;
|
||||
// B,S1,N1(N2,G),D
|
||||
queryCoreOffset = tndBIdxOffset + runInfo.gS1Idx * constInfo.mBaseSize * constInfo.headDim;
|
||||
// B,S1,N1(N2,G)/T,N1(N2,G)
|
||||
weightsCoreOffset = actualSeqQPrefixSum * constInfo.qHeadNum + runInfo.n2Idx * constInfo.gSize;
|
||||
// B,S1,N2,k/T,N2,k
|
||||
indiceOutCoreOffset =
|
||||
actualSeqQPrefixSum * constInfo.kHeadNum * constInfo.sparseCount + runInfo.n2Idx * constInfo.sparseCount;
|
||||
}
|
||||
uint64_t actualSeqKPrefixSum;
|
||||
if constexpr (K_LAYOUT_T == LI_LAYOUT::TND) { // T N2 D
|
||||
actualSeqKPrefixSum = (runInfo.bIdx <= 0) ? 0 : actualSeqLengthsGm.GetValue(runInfo.bIdx - 1);
|
||||
} else {
|
||||
actualSeqKPrefixSum = (runInfo.bIdx <= 0) ? 0 : runInfo.bIdx * constInfo.kSeqSize;
|
||||
}
|
||||
uint64_t tndBIdxOffsetForK = actualSeqKPrefixSum * constInfo.kHeadNum * constInfo.headDim;
|
||||
keyCoreOffset = tndBIdxOffsetForK + runInfo.s2Idx * constInfo.s2BaseSize * constInfo.kHeadNum * constInfo.headDim;
|
||||
keyScaleCoreOffset = (actualSeqKPrefixSum + runInfo.s2Idx * constInfo.s2BaseSize) * constInfo.kHeadNum;
|
||||
runInfo.tensorQueryOffset = queryCoreOffset;
|
||||
runInfo.tensorKeyOffset = keyCoreOffset;
|
||||
runInfo.tensorKeyScaleOffset = keyScaleCoreOffset;
|
||||
runInfo.tensorWeightsOffset = weightsCoreOffset;
|
||||
runInfo.indiceOutOffset = indiceOutCoreOffset;
|
||||
}
|
||||
|
||||
template <typename LIQT>
|
||||
__aicore__ inline void LIQPreload<LIQT>::Process()
|
||||
{
|
||||
if (usedCoreNum == 0) {
|
||||
// 没有计算任务,直接清理输出
|
||||
ProcessInvalid();
|
||||
return;
|
||||
}
|
||||
|
||||
ProcessMain();
|
||||
|
||||
ProcessDecode();
|
||||
}
|
||||
|
||||
template <typename LIQT>
|
||||
__aicore__ inline void LIQPreload<LIQT>::ProcessInvalid()
|
||||
{
|
||||
if ASCEND_IS_AIV {
|
||||
uint32_t aivCoreNum = GetBlockNum() * 2; // 2 means c:v = 1:2
|
||||
uint64_t totalOutputSize =
|
||||
constInfo.batchSize * constInfo.qSeqSize * constInfo.kHeadNum * constInfo.sparseCount;
|
||||
uint64_t singleCoreSize =
|
||||
LIQCommon::Align((totalOutputSize + aivCoreNum - 1) / aivCoreNum, GM_ALIGN_BYTES / sizeof(OUT_T));
|
||||
uint64_t baseSize = tmpBlockIdx * singleCoreSize;
|
||||
if (baseSize < totalOutputSize) {
|
||||
uint64_t dealSize =
|
||||
(baseSize + singleCoreSize <= totalOutputSize) ? singleCoreSize : totalOutputSize - baseSize;
|
||||
GlobalTensor<OUT_T> output = indiceOutGm[baseSize];
|
||||
AscendC::InitGlobalMemory(output, dealSize, constInfo.INVALID_IDX);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
template <typename LIQT>
|
||||
__aicore__ inline void LIQPreload<LIQT>::ProcessMain()
|
||||
{
|
||||
if (aiCoreIdx >= usedCoreNum) {
|
||||
// 无任务核直接返回
|
||||
return;
|
||||
}
|
||||
|
||||
if ASCEND_IS_AIV {
|
||||
vectorService.AllocEventID();
|
||||
CrossCoreSetFlag<LIQCommon::ConstInfo::FIA_SYNC_MODE2, PIPE_MTE2>(constInfo.syncV1C1);
|
||||
CrossCoreSetFlag<LIQCommon::ConstInfo::FIA_SYNC_MODE2, PIPE_MTE2>(constInfo.syncV1C1);
|
||||
} else {
|
||||
matmulService.AllocEventID();
|
||||
CrossCoreSetFlag<LIQCommon::ConstInfo::FIA_SYNC_MODE2, PIPE_FIX>(constInfo.syncC1V0);
|
||||
CrossCoreSetFlag<LIQCommon::ConstInfo::FIA_SYNC_MODE2, PIPE_FIX>(constInfo.syncC1V0);
|
||||
}
|
||||
|
||||
LIQCommon::RunInfo runInfo[LI_QUANT_PRELOAD_TASK_CACHE_SIZE];
|
||||
|
||||
uint32_t gloop = 0;
|
||||
for (uint32_t bN2LoopIdx = splitCoreInfo.bN2Start; bN2LoopIdx <= splitCoreInfo.bN2End; bN2LoopIdx++) {
|
||||
CalcGS1LoopParams(bN2LoopIdx);
|
||||
if (tempLoopInfo.curActSeqLenIsZero) {
|
||||
DealActSeqLenIsZero(tempLoopInfo.bIdx, tempLoopInfo.n2Idx, 0U);
|
||||
|
||||
if ASCEND_IS_AIV {
|
||||
if (bN2LoopIdx == splitCoreInfo.bN2End && gloop > 0) {
|
||||
CrossCoreWaitFlag(constInfo.syncC1V1);
|
||||
vectorService.ProcessVec1(runInfo[1 - gloop % LI_QUANT_PRELOAD_TASK_CACHE_SIZE]);
|
||||
CrossCoreSetFlag<LIQCommon::ConstInfo::FIA_SYNC_MODE2, PIPE_MTE3>(
|
||||
constInfo.syncV1C1); // 反向同步 1
|
||||
}
|
||||
}
|
||||
continue;
|
||||
}
|
||||
for (uint32_t gS1LoopIdx = splitCoreInfo.gS1Start; gS1LoopIdx <= tempLoopInfo.gS1LoopEnd; gS1LoopIdx++) {
|
||||
CalcS2LoopParams(bN2LoopIdx, gS1LoopIdx);
|
||||
bool isEnd = (bN2LoopIdx == splitCoreInfo.bN2End) && (gS1LoopIdx == splitCoreInfo.gS1End);
|
||||
uint32_t extraLoop = isEnd ? LI_QUANT_PRELOAD_TASK_CACHE_SIZE - 1 : 0;
|
||||
for (int s2LoopIdx = splitCoreInfo.s2Start; s2LoopIdx <= (tempLoopInfo.s2LoopEnd + extraLoop);
|
||||
s2LoopIdx++) {
|
||||
ProcessBaseBlock(gloop, s2LoopIdx, runInfo);
|
||||
++gloop;
|
||||
}
|
||||
splitCoreInfo.s2Start = 0;
|
||||
}
|
||||
if (tempLoopInfo.needDealActS1LessThanS1) {
|
||||
DealActSeqLenIsZero(tempLoopInfo.bIdx, tempLoopInfo.n2Idx, tempLoopInfo.actS1Size);
|
||||
}
|
||||
splitCoreInfo.gS1Start = 0;
|
||||
}
|
||||
|
||||
if ASCEND_IS_AIV {
|
||||
vectorService.FreeEventID();
|
||||
CrossCoreWaitFlag(constInfo.syncC1V0);
|
||||
CrossCoreWaitFlag(constInfo.syncC1V0);
|
||||
} else {
|
||||
matmulService.FreeEventID();
|
||||
CrossCoreWaitFlag(constInfo.syncV1C1);
|
||||
CrossCoreWaitFlag(constInfo.syncV1C1);
|
||||
}
|
||||
}
|
||||
|
||||
template <typename LIQT>
|
||||
__aicore__ inline void LIQPreload<LIQT>::ProcessBaseBlock(uint32_t loop, uint64_t s2LoopIdx,
|
||||
LIQCommon::RunInfo runInfo[LI_QUANT_PRELOAD_TASK_CACHE_SIZE])
|
||||
{
|
||||
int32_t curTaskId = loop % LI_QUANT_PRELOAD_TASK_CACHE_SIZE;
|
||||
LIQCommon::RunInfo &curRunInfo = runInfo[curTaskId];
|
||||
LIQCommon::RunInfo &lastRunInfo = runInfo[1 - curTaskId];
|
||||
|
||||
CalcRunInfo(loop, s2LoopIdx, curRunInfo);
|
||||
|
||||
if (curRunInfo.isValid) {
|
||||
if ASCEND_IS_AIC {
|
||||
if (curRunInfo.isFirstS2InnerLoop) {
|
||||
CrossCoreWaitFlag(constInfo.syncV0C1);
|
||||
}
|
||||
CrossCoreWaitFlag(constInfo.syncV1C1); // 反向同步 1
|
||||
matmulService.ComputeMm1(curRunInfo);
|
||||
CrossCoreSetFlag<LIQCommon::ConstInfo::FIA_SYNC_MODE2, PIPE_FIX>(constInfo.syncC1V1);
|
||||
if (curRunInfo.isLastS2InnerLoop) {
|
||||
CrossCoreSetFlag<LIQCommon::ConstInfo::FIA_SYNC_MODE2, PIPE_FIX>(constInfo.syncC1V0); // 反向同步 0
|
||||
}
|
||||
} else {
|
||||
if (curRunInfo.isFirstS2InnerLoop) {
|
||||
CrossCoreWaitFlag(constInfo.syncC1V0); // 反向同步 0
|
||||
vectorService.ProcessVec0(curRunInfo);
|
||||
CrossCoreSetFlag<LIQCommon::ConstInfo::FIA_SYNC_MODE2, PIPE_MTE3>(constInfo.syncV0C1);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if (lastRunInfo.isValid) {
|
||||
if ASCEND_IS_AIV {
|
||||
CrossCoreWaitFlag(constInfo.syncC1V1);
|
||||
vectorService.ProcessVec1(lastRunInfo);
|
||||
CrossCoreSetFlag<LIQCommon::ConstInfo::FIA_SYNC_MODE2, PIPE_MTE3>(constInfo.syncV1C1); // 反向同步 1
|
||||
}
|
||||
lastRunInfo.isValid = false;
|
||||
}
|
||||
}
|
||||
|
||||
template <typename LIQT>
|
||||
__aicore__ inline void LIQPreload<LIQT>::ProcessDecode()
|
||||
{
|
||||
if ASCEND_IS_AIV {
|
||||
vectorService.InitLDBuffers(pipe);
|
||||
ICachePreLoad(LD_PREFETCH_LEN);
|
||||
SyncAll();
|
||||
if (splitCoreInfo.isLD) {
|
||||
vectorService.ProcessLD();
|
||||
}
|
||||
}
|
||||
}
|
||||
} // namespace LIQKernel
|
||||
#endif // LIGHTNING_INDEXER_QUANT_KERNEL_H
|
||||
@@ -0,0 +1,613 @@
|
||||
/**
|
||||
* This program is free software, you can redistribute it and/or modify it.
|
||||
* Copyright (c) 2025 Huawei Technologies Co., Ltd.
|
||||
* This file is a part of the CANN Open Software.
|
||||
* Licensed under CANN Open Software License Agreement Version 2.0 (the "License").
|
||||
* Please refer to the License for details. You may not use this file except in compliance with the License.
|
||||
* THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE.
|
||||
* See LICENSE in the root of the software repository for the full text of the License.
|
||||
*/
|
||||
|
||||
/*!
|
||||
* \file lightning_indexer_quant_service_cube.h
|
||||
* \brief use 5 buffer for matmul l1, better pipeline
|
||||
*/
|
||||
#ifndef LIGHTNING_INDEXER_QUANT_SERVICE_CUBE_H
|
||||
#define LIGHTNING_INDEXER_QUANT_SERVICE_CUBE_H
|
||||
|
||||
#include "kernel_operator.h"
|
||||
#include "kernel_operator_list_tensor_intf.h"
|
||||
#include "kernel_tiling/kernel_tiling.h"
|
||||
#include "lib/matmul_intf.h"
|
||||
#include "lib/matrix/matmul/tiling.h"
|
||||
#include "lightning_indexer_quant_common.h"
|
||||
|
||||
namespace LIQKernel {
|
||||
using namespace LIQCommon;
|
||||
struct MmInfo {
|
||||
int64_t s2L0LoopId;
|
||||
int64_t s1gL0LoopId;
|
||||
int64_t s2L0RealSize;
|
||||
int64_t s2GmOffset;
|
||||
};
|
||||
|
||||
template <typename LIQT>
|
||||
class LIQMatmul {
|
||||
public:
|
||||
using Q_T = typename LIQT::queryType;
|
||||
using K_T = typename LIQT::keyType;
|
||||
|
||||
__aicore__ inline LIQMatmul(){};
|
||||
__aicore__ inline void InitBuffers(TPipe *pipe);
|
||||
__aicore__ inline void InitMm1GlobalTensor(const GlobalTensor<int32_t> &blkTableGm, const GlobalTensor<K_T> &keyGm,
|
||||
const GlobalTensor<Q_T> &queryGm, const GlobalTensor<float> &mm1ResGm,
|
||||
const GlobalTensor<half> &weightWorkspaceGm);
|
||||
__aicore__ inline void InitParams(const ConstInfo &constInfo);
|
||||
__aicore__ inline void AllocEventID();
|
||||
__aicore__ inline void FreeEventID();
|
||||
__aicore__ inline void ComputeMm1(const LIQCommon::RunInfo &runInfo);
|
||||
|
||||
static constexpr IsResetLoad3dConfig LOAD3DV2_CONFIG = {true, true}; // isSetFMatrix isSetPadding;
|
||||
static constexpr uint64_t DOUBLE_BUF_NUM = 2;
|
||||
static constexpr uint64_t L0AB_BUF_NUM = 4;
|
||||
|
||||
static constexpr uint32_t KEY_MTE1_MTE2_EVENT = EVENT_ID2;
|
||||
static constexpr uint32_t QW_MTE1_MTE2_EVENT = EVENT_ID5; // KEY_MTE1_MTE2_EVENT + DOUBLE_BUF_NUM;
|
||||
static constexpr uint32_t M_MTE1_EVENT = EVENT_ID3;
|
||||
static constexpr uint32_t M_FIX_EVENT = EVENT_ID0;
|
||||
static constexpr uint32_t FIX_M_EVENT = EVENT_ID2;
|
||||
static constexpr uint32_t FIX_MTE1_EVENT = EVENT_ID4;
|
||||
|
||||
static constexpr uint64_t S8_BLOCK_CUBE = 32;
|
||||
|
||||
static constexpr uint32_t MTE2_MTE1_EVENT = EVENT_ID2;
|
||||
static constexpr uint32_t MTE1_M_EVENT = EVENT_ID2;
|
||||
|
||||
static constexpr uint64_t D_BASIC_BLOCK = 128;
|
||||
static constexpr uint64_t S1G_BASIC_BLOCK_L1 = 256;
|
||||
|
||||
static constexpr uint64_t S1G_BASIC_BLOCK_L0 = 128;
|
||||
static constexpr uint64_t S2_BASIC_BLOCK_L0 = 128;
|
||||
|
||||
static constexpr uint64_t QUERY_BUFFER_OFFSET = S1G_BASIC_BLOCK_L1 * D_BASIC_BLOCK;
|
||||
static constexpr uint64_t SL1_BUFFER_OFFSET = S1G_BASIC_BLOCK_L0 * S2_BASIC_BLOCK_L0;
|
||||
static constexpr uint64_t KEY_BUFFER_OFFSET = S2_BASIC_BLOCK_L0 * D_BASIC_BLOCK;
|
||||
static constexpr uint64_t WEIGHT_BUFFER_OFFSET = S1G_BASIC_BLOCK_L1 * BLOCK_CUBE;
|
||||
static constexpr uint64_t L0AB_BUFFER_OFFSET_S8_16K = 16 * 1024;
|
||||
static constexpr uint64_t L0AB_BUFFER_OFFSET_FP16_16K = 16 * 512;
|
||||
static constexpr uint64_t L0C_BUFFER_OFFSET = 64 * 256;
|
||||
|
||||
private:
|
||||
__aicore__ inline void WeightDmaCopy(uint64_t s1gL1RealSize, const LIQCommon::RunInfo &runInfo);
|
||||
__aicore__ inline void LoadKeyToL0b(uint64_t s2L0RealSize);
|
||||
__aicore__ inline void LoadQueryToL0a(uint64_t s1gL1Offset, uint64_t s1gL1RealSize, uint64_t s1gL0RealSize);
|
||||
__aicore__ inline void QueryNd2Nz(uint64_t s1gL1RealSize, const LIQCommon::RunInfo &runInfo);
|
||||
__aicore__ inline void KeyNd2NzForPA(uint64_t s2L1RealSize, uint64_t s2GmOffset, const LIQCommon::RunInfo &runInfo);
|
||||
__aicore__ inline void KeyNd2Nz(uint64_t s2L1RealSize, const MmInfo &mmInfo, const LIQCommon::RunInfo &runInfo);
|
||||
__aicore__ inline void FixpSToL1(uint64_t s1gL0RealSize, uint64_t s2L0RealSize);
|
||||
__aicore__ inline void LoadSToL0b(uint64_t s1gL1RealSize, uint64_t s2L0RealSize, uint64_t sL1BufIdx,
|
||||
int64_t mStartPt);
|
||||
__aicore__ inline void LoadWeightToL0a(uint64_t s1gL1Offset);
|
||||
__aicore__ inline void ComputeWs(uint64_t s1gL0RealSize, uint64_t s2L0RealSize, int64_t s1gOffset);
|
||||
__aicore__ inline void FixpResToGm(uint64_t s1L0RealCount, uint64_t s2L0RealSize, uint64_t s1GmOffset,
|
||||
uint64_t s2GmOffset, const LIQCommon::RunInfo &runInfo);
|
||||
__aicore__ inline void ComputeQk(uint64_t s1gL0RealSize, uint64_t s2L0RealSize);
|
||||
__aicore__ inline void ProcessWs(uint64_t s1gL0RealSize, uint64_t s1gL1Offset, uint64_t sL1BufIdx,
|
||||
const MmInfo &mmInfo, const LIQCommon::RunInfo &runInfo);
|
||||
__aicore__ inline void ProcessQk(uint64_t s1gL0RealSize, uint64_t s1gL1Offset, uint64_t s1L0LoopCnt,
|
||||
const MmInfo &mmInfo, const LIQCommon::RunInfo &runInfo);
|
||||
__aicore__ inline void CalcMmInfo(MmInfo &mmInfo, uint64_t loopIdx, uint64_t s1L0LoopCnt, const MmInfo &lastMmInfo,
|
||||
const LIQCommon::RunInfo &runInfo);
|
||||
static constexpr LI_LAYOUT Q_LAYOUT_T = LIQT::layout;
|
||||
static constexpr LI_LAYOUT K_LAYOUT_T = LIQT::keyLayout;
|
||||
GlobalTensor<int32_t> blkTableGm_;
|
||||
GlobalTensor<K_T> keyGm_;
|
||||
GlobalTensor<Q_T> queryGm_;
|
||||
GlobalTensor<half> weightGm_;
|
||||
GlobalTensor<float> mm1ResGm_;
|
||||
|
||||
TBuf<TPosition::A1> bufQL1_;
|
||||
LocalTensor<Q_T> queryL1_;
|
||||
TBuf<TPosition::B1> bufKeyL1_;
|
||||
LocalTensor<K_T> keyL1_;
|
||||
TBuf<TPosition::A1> bufWeightL1_;
|
||||
LocalTensor<half> weightL1_;
|
||||
TBuf<TPosition::B1> bufSL1_;
|
||||
LocalTensor<half> sL1_;
|
||||
|
||||
TBuf<TPosition::A2> bufL0A_;
|
||||
LocalTensor<Q_T> l0a_;
|
||||
TBuf<TPosition::B2> bufL0B_;
|
||||
LocalTensor<K_T> l0b_;
|
||||
|
||||
TBuf<TPosition::CO1> bufL0C_;
|
||||
LocalTensor<int32_t> cL0_;
|
||||
|
||||
uint64_t keyL1BufIdx_ = 0;
|
||||
uint64_t qwL1Mte2BufIdx_ = 0;
|
||||
uint64_t sL1BufIdx_ = 0;
|
||||
uint64_t l0BufIdx_ = 0;
|
||||
uint64_t l0cBufIdx_ = 0;
|
||||
|
||||
ConstInfo constInfo_;
|
||||
};
|
||||
|
||||
template <typename LIQT>
|
||||
__aicore__ inline void LIQMatmul<LIQT>::InitParams(const ConstInfo &constInfo)
|
||||
{
|
||||
constInfo_ = constInfo;
|
||||
}
|
||||
|
||||
template <typename LIQT>
|
||||
__aicore__ inline void LIQMatmul<LIQT>::InitBuffers(TPipe *pipe)
|
||||
{
|
||||
pipe->InitBuffer(bufQL1_, DOUBLE_BUF_NUM * S1G_BASIC_BLOCK_L1 * D_BASIC_BLOCK * sizeof(Q_T));
|
||||
queryL1_ = bufQL1_.Get<Q_T>();
|
||||
pipe->InitBuffer(bufKeyL1_, DOUBLE_BUF_NUM * S2_BASIC_BLOCK_L0 * D_BASIC_BLOCK * sizeof(K_T));
|
||||
keyL1_ = bufKeyL1_.Get<K_T>();
|
||||
|
||||
pipe->InitBuffer(bufWeightL1_, DOUBLE_BUF_NUM * S1G_BASIC_BLOCK_L1 * BLOCK_CUBE * sizeof(half));
|
||||
weightL1_ = bufWeightL1_.Get<half>();
|
||||
pipe->InitBuffer(bufSL1_, DOUBLE_BUF_NUM * S2_BASIC_BLOCK_L0 * S1G_BASIC_BLOCK_L0 * sizeof(half));
|
||||
sL1_ = bufSL1_.Get<half>();
|
||||
|
||||
pipe->InitBuffer(bufL0A_, 64 * 1024);
|
||||
l0a_ = bufL0A_.Get<Q_T>();
|
||||
pipe->InitBuffer(bufL0B_, 64 * 1024);
|
||||
l0b_ = bufL0B_.Get<K_T>();
|
||||
|
||||
pipe->InitBuffer(bufL0C_, 128 * 1024);
|
||||
cL0_ = bufL0C_.Get<int32_t>();
|
||||
}
|
||||
|
||||
template <typename LIQT>
|
||||
__aicore__ inline void LIQMatmul<LIQT>::InitMm1GlobalTensor(const GlobalTensor<int32_t> &blkTableGm,
|
||||
const GlobalTensor<K_T> &keyGm,
|
||||
const GlobalTensor<Q_T> &queryGm,
|
||||
const GlobalTensor<float> &mm1ResGm,
|
||||
const GlobalTensor<half> &weightWorkspaceGm)
|
||||
{
|
||||
blkTableGm_ = blkTableGm;
|
||||
keyGm_ = keyGm;
|
||||
queryGm_ = queryGm;
|
||||
mm1ResGm_ = mm1ResGm;
|
||||
weightGm_ = weightWorkspaceGm;
|
||||
}
|
||||
|
||||
template <typename LIQT>
|
||||
__aicore__ inline void LIQMatmul<LIQT>::ProcessWs(uint64_t s1gL0RealSize, uint64_t s1gL1Offset, uint64_t sL1BufIdx,
|
||||
const MmInfo &mmInfo, const LIQCommon::RunInfo &runInfo)
|
||||
{
|
||||
WaitFlag<HardEvent::FIX_M>(FIX_M_EVENT + l0cBufIdx_ % DOUBLE_BUF_NUM);
|
||||
for (int64_t s1gOffset = 0; s1gOffset < s1gL0RealSize; s1gOffset += constInfo_.gSize) {
|
||||
WaitFlag<HardEvent::M_MTE1>(M_MTE1_EVENT + l0BufIdx_ % L0AB_BUF_NUM);
|
||||
LoadSToL0b(s1gL0RealSize, mmInfo.s2L0RealSize, sL1BufIdx, s1gOffset);
|
||||
LoadWeightToL0a(s1gOffset + s1gL1Offset);
|
||||
|
||||
ComputeWs(s1gL0RealSize, mmInfo.s2L0RealSize, s1gOffset);
|
||||
|
||||
SetFlag<HardEvent::M_MTE1>(M_MTE1_EVENT + l0BufIdx_ % L0AB_BUF_NUM);
|
||||
l0BufIdx_++;
|
||||
}
|
||||
|
||||
FixpResToGm(s1gL0RealSize / constInfo_.gSize, mmInfo.s2L0RealSize, s1gL1Offset / constInfo_.gSize,
|
||||
mmInfo.s2L0LoopId * S2_BASIC_BLOCK_L0, runInfo);
|
||||
SetFlag<HardEvent::FIX_M>(FIX_M_EVENT + l0cBufIdx_ % DOUBLE_BUF_NUM);
|
||||
l0cBufIdx_++;
|
||||
}
|
||||
|
||||
template <typename LIQT>
|
||||
__aicore__ inline void LIQMatmul<LIQT>::ProcessQk(uint64_t s1gL0RealSize, uint64_t s1gL1Offset, uint64_t s1L0LoopCnt,
|
||||
const MmInfo &mmInfo, const LIQCommon::RunInfo &runInfo)
|
||||
{
|
||||
if (mmInfo.s1gL0LoopId == 0) {
|
||||
WaitFlag<HardEvent::MTE1_MTE2>(KEY_MTE1_MTE2_EVENT + keyL1BufIdx_ % DOUBLE_BUF_NUM);
|
||||
if constexpr (K_LAYOUT_T == LI_LAYOUT::PA_BSND) {
|
||||
KeyNd2NzForPA(mmInfo.s2L0RealSize, runInfo.s2Idx * constInfo_.s2BaseSize + mmInfo.s2GmOffset, runInfo);
|
||||
} else {
|
||||
KeyNd2Nz(mmInfo.s2L0RealSize, mmInfo, runInfo);
|
||||
}
|
||||
|
||||
SetFlag<HardEvent::MTE2_MTE1>(MTE2_MTE1_EVENT);
|
||||
WaitFlag<HardEvent::MTE2_MTE1>(MTE2_MTE1_EVENT);
|
||||
}
|
||||
|
||||
WaitFlag<HardEvent::M_MTE1>(M_MTE1_EVENT + l0BufIdx_ % L0AB_BUF_NUM);
|
||||
LoadQueryToL0a(s1gL1Offset, runInfo.actMBaseSize, s1gL0RealSize);
|
||||
LoadKeyToL0b(mmInfo.s2L0RealSize);
|
||||
|
||||
if (mmInfo.s1gL0LoopId + 1 >= s1L0LoopCnt) {
|
||||
SetFlag<HardEvent::MTE1_MTE2>(KEY_MTE1_MTE2_EVENT + keyL1BufIdx_ % DOUBLE_BUF_NUM);
|
||||
keyL1BufIdx_++;
|
||||
}
|
||||
|
||||
WaitFlag<HardEvent::FIX_M>(FIX_M_EVENT + l0cBufIdx_ % DOUBLE_BUF_NUM);
|
||||
ComputeQk(s1gL0RealSize, mmInfo.s2L0RealSize);
|
||||
SetFlag<HardEvent::M_MTE1>(M_MTE1_EVENT + l0BufIdx_ % L0AB_BUF_NUM);
|
||||
|
||||
FixpSToL1(s1gL0RealSize, mmInfo.s2L0RealSize);
|
||||
SetFlag<HardEvent::FIX_M>(FIX_M_EVENT + l0cBufIdx_ % DOUBLE_BUF_NUM);
|
||||
l0BufIdx_++;
|
||||
l0cBufIdx_++;
|
||||
}
|
||||
|
||||
template <typename LIQT>
|
||||
__aicore__ inline void LIQMatmul<LIQT>::CalcMmInfo(MmInfo &mmInfo, uint64_t loopIdx, uint64_t s1L0LoopCnt,
|
||||
const MmInfo &lastMmInfo, const LIQCommon::RunInfo &runInfo)
|
||||
{
|
||||
mmInfo.s2L0LoopId = loopIdx / s1L0LoopCnt;
|
||||
mmInfo.s1gL0LoopId = loopIdx % s1L0LoopCnt;
|
||||
|
||||
if (mmInfo.s1gL0LoopId == 0) {
|
||||
mmInfo.s2GmOffset = mmInfo.s2L0LoopId * S2_BASIC_BLOCK_L0;
|
||||
mmInfo.s2L0RealSize = mmInfo.s2GmOffset + S2_BASIC_BLOCK_L0 > runInfo.actualSingleProcessSInnerSize
|
||||
? runInfo.actualSingleProcessSInnerSize - mmInfo.s2GmOffset
|
||||
: S2_BASIC_BLOCK_L0;
|
||||
} else {
|
||||
mmInfo.s2L0RealSize = lastMmInfo.s2L0RealSize;
|
||||
}
|
||||
}
|
||||
|
||||
template <typename LIQT>
|
||||
__aicore__ inline void LIQMatmul<LIQT>::ComputeMm1(const LIQCommon::RunInfo &runInfo)
|
||||
{
|
||||
if (runInfo.isFirstS2InnerLoop) {
|
||||
WaitFlag<HardEvent::MTE1_MTE2>(QW_MTE1_MTE2_EVENT + qwL1Mte2BufIdx_ % DOUBLE_BUF_NUM);
|
||||
QueryNd2Nz(runInfo.actMBaseSize, runInfo); // 256 * 128 // L1BasicBlock
|
||||
WeightDmaCopy(runInfo.actMBaseSize, runInfo);
|
||||
}
|
||||
int64_t loopIdx = 0;
|
||||
int64_t s2L0LoopCnt = CeilDiv(runInfo.actualSingleProcessSInnerSize, S2_BASIC_BLOCK_L0); // 2048取128
|
||||
int64_t s1L0LoopCnt = CeilDiv(runInfo.actMBaseSize, S1G_BASIC_BLOCK_L0); // 256取128
|
||||
int64_t s1gL1Offset[2] = {0, static_cast<int64_t>(S1G_BASIC_BLOCK_L0)};
|
||||
int64_t s1gL0RealSize[2] = {s1L0LoopCnt > 1 ? static_cast<int64_t>(S1G_BASIC_BLOCK_L0) : runInfo.actMBaseSize,
|
||||
runInfo.actMBaseSize - s1gL1Offset[1]};
|
||||
MmInfo mmInfo[2];
|
||||
CalcMmInfo(mmInfo[loopIdx & 1], loopIdx, s1L0LoopCnt, mmInfo[(loopIdx + 1) & 1], runInfo);
|
||||
|
||||
ProcessQk(s1gL0RealSize[mmInfo[loopIdx & 1].s1gL0LoopId % s1L0LoopCnt],
|
||||
s1gL1Offset[mmInfo[loopIdx & 1].s1gL0LoopId % s1L0LoopCnt], s1L0LoopCnt, mmInfo[loopIdx & 1],
|
||||
runInfo);
|
||||
|
||||
SetFlag<HardEvent::FIX_MTE1>(FIX_MTE1_EVENT + sL1BufIdx_ % DOUBLE_BUF_NUM);
|
||||
sL1BufIdx_++;
|
||||
loopIdx++;
|
||||
|
||||
while (loopIdx < s2L0LoopCnt * s1L0LoopCnt) {
|
||||
CalcMmInfo(mmInfo[loopIdx & 1], loopIdx, s1L0LoopCnt, mmInfo[(loopIdx + 1) & 1], runInfo);
|
||||
|
||||
ProcessQk(s1gL0RealSize[mmInfo[loopIdx & 1].s1gL0LoopId % s1L0LoopCnt],
|
||||
s1gL1Offset[mmInfo[loopIdx & 1].s1gL0LoopId % s1L0LoopCnt], s1L0LoopCnt, mmInfo[loopIdx & 1],
|
||||
runInfo);
|
||||
|
||||
SetFlag<HardEvent::FIX_MTE1>(FIX_MTE1_EVENT + sL1BufIdx_ % DOUBLE_BUF_NUM);
|
||||
sL1BufIdx_++;
|
||||
|
||||
WaitFlag<HardEvent::FIX_MTE1>(FIX_MTE1_EVENT + sL1BufIdx_ % DOUBLE_BUF_NUM);
|
||||
|
||||
ProcessWs(s1gL0RealSize[mmInfo[(loopIdx + 1) & 1].s1gL0LoopId % s1L0LoopCnt],
|
||||
s1gL1Offset[mmInfo[(loopIdx + 1) & 1].s1gL0LoopId % s1L0LoopCnt], sL1BufIdx_,
|
||||
mmInfo[(loopIdx + 1) & 1], runInfo);
|
||||
loopIdx++;
|
||||
}
|
||||
|
||||
WaitFlag<HardEvent::FIX_MTE1>(FIX_MTE1_EVENT + (sL1BufIdx_ + 1) % DOUBLE_BUF_NUM);
|
||||
|
||||
ProcessWs(s1gL0RealSize[mmInfo[(loopIdx + 1) & 1].s1gL0LoopId % s1L0LoopCnt],
|
||||
s1gL1Offset[mmInfo[(loopIdx + 1) & 1].s1gL0LoopId % s1L0LoopCnt], sL1BufIdx_ - 1,
|
||||
mmInfo[(loopIdx + 1) & 1], runInfo);
|
||||
|
||||
if (runInfo.isLastS2InnerLoop) {
|
||||
SetFlag<HardEvent::MTE1_MTE2>(QW_MTE1_MTE2_EVENT + qwL1Mte2BufIdx_ % DOUBLE_BUF_NUM);
|
||||
qwL1Mte2BufIdx_++;
|
||||
}
|
||||
}
|
||||
|
||||
// blkNum, blkSize, N2, D
|
||||
template <typename LIQT>
|
||||
__aicore__ inline void LIQMatmul<LIQT>::KeyNd2NzForPA(uint64_t s2L1RealSize, uint64_t s2GmOffset,
|
||||
const LIQCommon::RunInfo &runInfo)
|
||||
{
|
||||
uint64_t s2L1Offset = 0;
|
||||
while (s2L1Offset < s2L1RealSize) {
|
||||
uint64_t s2BlkId = (s2L1Offset + s2GmOffset) / constInfo_.kCacheBlockSize;
|
||||
uint64_t s2BlkOffset = (s2L1Offset + s2GmOffset) % constInfo_.kCacheBlockSize;
|
||||
uint64_t keyGmOffset = blkTableGm_.GetValue(runInfo.bIdx * constInfo_.maxBlockNumPerBatch + s2BlkId) *
|
||||
constInfo_.kCacheBlockSize * constInfo_.kHeadNum * constInfo_.headDim +
|
||||
s2BlkOffset * constInfo_.headDim;
|
||||
uint64_t s2Mte2Size = s2L1RealSize - s2L1Offset;
|
||||
s2Mte2Size = s2BlkOffset + s2Mte2Size >= constInfo_.kCacheBlockSize ? constInfo_.kCacheBlockSize - s2BlkOffset
|
||||
: s2Mte2Size;
|
||||
Nd2NzParams nd2nzPara;
|
||||
nd2nzPara.ndNum = 1;
|
||||
nd2nzPara.nValue = s2Mte2Size; // 行数
|
||||
nd2nzPara.dValue = constInfo_.headDim;
|
||||
nd2nzPara.srcDValue = constInfo_.headDim;
|
||||
nd2nzPara.dstNzC0Stride = CeilAlign(s2L1RealSize, (uint64_t)BLOCK_CUBE); // 对齐到16 单位block
|
||||
nd2nzPara.dstNzNStride = 1;
|
||||
nd2nzPara.srcNdMatrixStride = 0;
|
||||
nd2nzPara.dstNzMatrixStride = 0;
|
||||
DataCopy(keyL1_[(keyL1BufIdx_ % DOUBLE_BUF_NUM) * KEY_BUFFER_OFFSET + s2L1Offset * S8_BLOCK_CUBE],
|
||||
keyGm_[keyGmOffset], nd2nzPara);
|
||||
|
||||
s2L1Offset += s2Mte2Size;
|
||||
}
|
||||
}
|
||||
|
||||
template <typename LIQT>
|
||||
__aicore__ inline void LIQMatmul<LIQT>::KeyNd2Nz(uint64_t s2L1RealSize, const MmInfo &mmInfo,
|
||||
const LIQCommon::RunInfo &runInfo)
|
||||
{
|
||||
uint64_t dStride = constInfo_.headDim;
|
||||
if constexpr (K_LAYOUT_T == LI_LAYOUT::BSND || K_LAYOUT_T == LI_LAYOUT::TND) {
|
||||
dStride = constInfo_.headDim * constInfo_.kHeadNum; // constInfo_.kHeadNum
|
||||
}
|
||||
Nd2NzParams nd2nzPara;
|
||||
nd2nzPara.ndNum = 1;
|
||||
nd2nzPara.nValue = s2L1RealSize; // 行数
|
||||
nd2nzPara.dValue = constInfo_.headDim;
|
||||
nd2nzPara.srcDValue = dStride;
|
||||
nd2nzPara.dstNzC0Stride = CeilAlign(s2L1RealSize, (uint64_t)BLOCK_CUBE); // 对齐到16 单位block
|
||||
nd2nzPara.dstNzNStride = 1;
|
||||
nd2nzPara.srcNdMatrixStride = 0;
|
||||
nd2nzPara.dstNzMatrixStride = 0;
|
||||
// 默认一块buf最多放两份
|
||||
DataCopy(keyL1_[(keyL1BufIdx_ % DOUBLE_BUF_NUM) * KEY_BUFFER_OFFSET],
|
||||
keyGm_[runInfo.tensorKeyOffset + mmInfo.s2GmOffset * constInfo_.headDim], nd2nzPara);
|
||||
}
|
||||
|
||||
// batch, s1, g, 1
|
||||
template <typename LIQT>
|
||||
__aicore__ inline void LIQMatmul<LIQT>::WeightDmaCopy(uint64_t s1gL1RealSize, const LIQCommon::RunInfo &runInfo)
|
||||
{
|
||||
DataCopyParams copyInParams;
|
||||
copyInParams.blockCount = 1;
|
||||
copyInParams.blockLen = s1gL1RealSize;
|
||||
copyInParams.srcStride = 0;
|
||||
copyInParams.dstStride = 0;
|
||||
DataCopy(weightL1_[(qwL1Mte2BufIdx_ % DOUBLE_BUF_NUM) * WEIGHT_BUFFER_OFFSET],
|
||||
weightGm_[runInfo.loop % DOUBLE_BUF_NUM * BLOCK_CUBE * constInfo_.mBaseSize], copyInParams);
|
||||
}
|
||||
|
||||
// batch, s1, n2, g, d
|
||||
template <typename LIQT>
|
||||
__aicore__ inline void LIQMatmul<LIQT>::QueryNd2Nz(uint64_t s1gL1RealSize, const LIQCommon::RunInfo &runInfo)
|
||||
{
|
||||
Nd2NzParams nd2nzPara;
|
||||
nd2nzPara.ndNum = 1;
|
||||
nd2nzPara.nValue = s1gL1RealSize; // 行数
|
||||
nd2nzPara.dValue = constInfo_.headDim;
|
||||
nd2nzPara.srcDValue = constInfo_.headDim;
|
||||
nd2nzPara.dstNzC0Stride = CeilAlign(s1gL1RealSize, (uint64_t)BLOCK_CUBE); // 对齐到16 单位block
|
||||
nd2nzPara.dstNzNStride = 1;
|
||||
nd2nzPara.srcNdMatrixStride = 0;
|
||||
nd2nzPara.dstNzMatrixStride = 0;
|
||||
// 默认一块buf最多放两份
|
||||
DataCopy(queryL1_[(qwL1Mte2BufIdx_ % DOUBLE_BUF_NUM) * QUERY_BUFFER_OFFSET], queryGm_[runInfo.tensorQueryOffset],
|
||||
nd2nzPara);
|
||||
}
|
||||
|
||||
// s1g, d
|
||||
template <typename LIQT>
|
||||
__aicore__ inline void LIQMatmul<LIQT>::LoadQueryToL0a(uint64_t s1gL1Offset, uint64_t s1gL1RealSize,
|
||||
uint64_t s1gL0RealSize)
|
||||
{
|
||||
LoadData3DParamsV2<Q_T> loadData3DParams;
|
||||
// SetFmatrixParams
|
||||
loadData3DParams.l1H = CeilDiv(s1gL1RealSize, BLOCK_CUBE); // Hin=M1=8
|
||||
loadData3DParams.l1W = BLOCK_CUBE; // Win=M0
|
||||
loadData3DParams.channelSize = constInfo_.headDim; // Cin=K
|
||||
|
||||
loadData3DParams.padList[0] = 0;
|
||||
loadData3DParams.padList[1] = 0;
|
||||
loadData3DParams.padList[2] = 0;
|
||||
loadData3DParams.padList[3] = 255; // 尾部数据不影响滑窗的结果
|
||||
|
||||
// SetLoadToA0Params
|
||||
loadData3DParams.mExtension = CeilAlign(s1gL0RealSize, BLOCK_CUBE); // M height维度目的
|
||||
loadData3DParams.kExtension = constInfo_.headDim; // K width维度目的
|
||||
loadData3DParams.mStartPt = s1gL1Offset;
|
||||
loadData3DParams.kStartPt = 0;
|
||||
loadData3DParams.strideW = 1;
|
||||
loadData3DParams.strideH = 1;
|
||||
loadData3DParams.filterW = 1;
|
||||
loadData3DParams.filterSizeW = (1 >> 8) & 255;
|
||||
loadData3DParams.filterH = 1;
|
||||
loadData3DParams.filterSizeH = (1 >> 8) & 255;
|
||||
loadData3DParams.dilationFilterW = 1;
|
||||
loadData3DParams.dilationFilterH = 1;
|
||||
loadData3DParams.enTranspose = 0;
|
||||
loadData3DParams.fMatrixCtrl = 0;
|
||||
|
||||
LoadData<Q_T, LOAD3DV2_CONFIG>(l0a_[(l0BufIdx_ % L0AB_BUF_NUM) * L0AB_BUFFER_OFFSET_S8_16K],
|
||||
queryL1_[(qwL1Mte2BufIdx_ % DOUBLE_BUF_NUM) * QUERY_BUFFER_OFFSET],
|
||||
loadData3DParams);
|
||||
}
|
||||
|
||||
// s1, g, s2 --> 2 * 64* 128
|
||||
template <typename LIQT>
|
||||
__aicore__ inline void LIQMatmul<LIQT>::LoadSToL0b(uint64_t s1gL1RealSize, uint64_t s2L0RealSize, uint64_t sL1BufIdx,
|
||||
int64_t mStartPt)
|
||||
{
|
||||
LoadData3DParamsV2<half> loadData3DParams;
|
||||
// SetFmatrixParams
|
||||
loadData3DParams.l1H = S1G_BASIC_BLOCK_L0 / BLOCK_CUBE; // Hin=M1=8
|
||||
loadData3DParams.l1W = BLOCK_CUBE; // Win=M0
|
||||
loadData3DParams.channelSize = CeilAlign(s2L0RealSize, BLOCK_CUBE); // Cin=K
|
||||
|
||||
loadData3DParams.padList[0] = 0;
|
||||
loadData3DParams.padList[1] = 0;
|
||||
loadData3DParams.padList[2] = 0;
|
||||
loadData3DParams.padList[3] = 255; // 尾部数据不影响滑窗的结果
|
||||
|
||||
// SetLoadToA0Params
|
||||
loadData3DParams.mExtension = constInfo_.gSize; // M height维度目的
|
||||
loadData3DParams.kExtension = CeilAlign(s2L0RealSize, BLOCK_CUBE); // K width维度目的
|
||||
loadData3DParams.kStartPt = 0;
|
||||
loadData3DParams.strideW = 1;
|
||||
loadData3DParams.strideH = 1;
|
||||
loadData3DParams.filterW = 1;
|
||||
loadData3DParams.filterSizeW = (1 >> 8) & 255;
|
||||
loadData3DParams.filterH = 1;
|
||||
loadData3DParams.filterSizeH = (1 >> 8) & 255;
|
||||
loadData3DParams.dilationFilterW = 1;
|
||||
loadData3DParams.dilationFilterH = 1;
|
||||
loadData3DParams.enTranspose = 1;
|
||||
loadData3DParams.fMatrixCtrl = 0;
|
||||
|
||||
loadData3DParams.mStartPt = mStartPt;
|
||||
LoadData<half, LOAD3DV2_CONFIG>(
|
||||
l0b_.template ReinterpretCast<half>()[(l0BufIdx_ % L0AB_BUF_NUM) * L0AB_BUFFER_OFFSET_FP16_16K],
|
||||
sL1_[(sL1BufIdx % DOUBLE_BUF_NUM) * SL1_BUFFER_OFFSET], loadData3DParams);
|
||||
}
|
||||
|
||||
// s1,g,1(16), 2,64,16
|
||||
template <typename LIQT>
|
||||
__aicore__ inline void LIQMatmul<LIQT>::LoadWeightToL0a(uint64_t s1gL1Offset)
|
||||
{
|
||||
LoadData2DParams loadData2DParams;
|
||||
loadData2DParams.startIndex = 0;
|
||||
loadData2DParams.repeatTimes = CeilDiv(constInfo_.gSize, BLOCK_CUBE);
|
||||
loadData2DParams.srcStride = 1;
|
||||
loadData2DParams.dstGap = 0;
|
||||
loadData2DParams.ifTranspose = true;
|
||||
LoadData(l0a_.template ReinterpretCast<half>()[(l0BufIdx_ % L0AB_BUF_NUM) * L0AB_BUFFER_OFFSET_FP16_16K],
|
||||
weightL1_[(qwL1Mte2BufIdx_ % DOUBLE_BUF_NUM) * WEIGHT_BUFFER_OFFSET + s1gL1Offset* BLOCK_CUBE],
|
||||
loadData2DParams);
|
||||
}
|
||||
|
||||
// s2, d -> 128,128
|
||||
template <typename LIQT>
|
||||
__aicore__ inline void LIQMatmul<LIQT>::LoadKeyToL0b(uint64_t s2L0RealSize)
|
||||
{
|
||||
LoadData2DParams loadData2DParams;
|
||||
loadData2DParams.startIndex = 0;
|
||||
loadData2DParams.repeatTimes = CeilDiv(s2L0RealSize, BLOCK_CUBE) * CeilDiv(constInfo_.headDim, S8_BLOCK_CUBE);
|
||||
loadData2DParams.srcStride = 1;
|
||||
loadData2DParams.dstGap = 0;
|
||||
loadData2DParams.ifTranspose = false;
|
||||
LoadData(l0b_[(l0BufIdx_ % L0AB_BUF_NUM) * L0AB_BUFFER_OFFSET_S8_16K],
|
||||
keyL1_[(keyL1BufIdx_ % DOUBLE_BUF_NUM) * KEY_BUFFER_OFFSET], loadData2DParams);
|
||||
}
|
||||
|
||||
// A: s1,g,1(16) B: s1,g,s2 C: s1, 1(16), s2
|
||||
template <typename LIQT>
|
||||
__aicore__ inline void LIQMatmul<LIQT>::ComputeWs(uint64_t s1gL0RealSize, uint64_t s2L0RealSize, int64_t s1gOffset)
|
||||
{
|
||||
SetFlag<HardEvent::MTE1_M>(MTE1_M_EVENT);
|
||||
WaitFlag<HardEvent::MTE1_M>(MTE1_M_EVENT);
|
||||
MmadParams mmadParams;
|
||||
mmadParams.m = BLOCK_CUBE;
|
||||
mmadParams.n = s2L0RealSize;
|
||||
mmadParams.k = constInfo_.gSize;
|
||||
mmadParams.cmatrixInitVal = true;
|
||||
mmadParams.cmatrixSource = false;
|
||||
Mmad(cL0_.template ReinterpretCast<float>()[(l0cBufIdx_ % DOUBLE_BUF_NUM) * L0C_BUFFER_OFFSET +
|
||||
s1gOffset * S2_BASIC_BLOCK_L0],
|
||||
l0a_.template ReinterpretCast<half>()[(l0BufIdx_ % L0AB_BUF_NUM) * L0AB_BUFFER_OFFSET_FP16_16K],
|
||||
l0b_.template ReinterpretCast<half>()[(l0BufIdx_ % L0AB_BUF_NUM) * L0AB_BUFFER_OFFSET_FP16_16K],
|
||||
mmadParams);
|
||||
}
|
||||
|
||||
template <typename LIQT>
|
||||
__aicore__ inline void LIQMatmul<LIQT>::ComputeQk(uint64_t s1gL0RealSize, uint64_t s2L0RealSize)
|
||||
{
|
||||
SetFlag<HardEvent::MTE1_M>(MTE1_M_EVENT);
|
||||
WaitFlag<HardEvent::MTE1_M>(MTE1_M_EVENT);
|
||||
|
||||
MmadParams mmadParams;
|
||||
mmadParams.m = CeilAlign(s1gL0RealSize, BLOCK_CUBE);
|
||||
mmadParams.n = s2L0RealSize;
|
||||
mmadParams.k = constInfo_.headDim;
|
||||
mmadParams.cmatrixInitVal = true;
|
||||
mmadParams.cmatrixSource = false;
|
||||
Mmad(cL0_[(l0cBufIdx_ % DOUBLE_BUF_NUM) * L0C_BUFFER_OFFSET],
|
||||
l0a_[(l0BufIdx_ % L0AB_BUF_NUM) * L0AB_BUFFER_OFFSET_S8_16K],
|
||||
l0b_[(l0BufIdx_ % L0AB_BUF_NUM) * L0AB_BUFFER_OFFSET_S8_16K], mmadParams);
|
||||
if ((mmadParams.m / 16) * (mmadParams.n / 16) < 10) {
|
||||
PipeBarrier<PIPE_M>();
|
||||
}
|
||||
}
|
||||
|
||||
template <typename LIQT>
|
||||
__aicore__ inline void LIQMatmul<LIQT>::FixpSToL1(uint64_t s1gL0RealSize, uint64_t s2L0RealSize)
|
||||
{
|
||||
SetFlag<HardEvent::M_FIX>(M_FIX_EVENT);
|
||||
WaitFlag<HardEvent::M_FIX>(M_FIX_EVENT);
|
||||
DataCopyCO12DstParams params;
|
||||
params.mSize = CeilAlign(s1gL0RealSize, BLOCK_CUBE);
|
||||
params.nSize = CeilAlign(s2L0RealSize, BLOCK_CUBE);
|
||||
params.dstStride = S1G_BASIC_BLOCK_L0;
|
||||
params.srcStride = params.mSize;
|
||||
params.quantPre = QuantMode_t::DEQF16;
|
||||
params.reluPre = 1;
|
||||
params.channelSplit = 0;
|
||||
params.nz2ndEn = 0;
|
||||
SetFixpipePreQuantFlag(0x3a800000);
|
||||
DataCopy(sL1_[(sL1BufIdx_ % DOUBLE_BUF_NUM) * SL1_BUFFER_OFFSET],
|
||||
cL0_[(l0cBufIdx_ % DOUBLE_BUF_NUM) * L0C_BUFFER_OFFSET], params);
|
||||
}
|
||||
|
||||
template <typename LIQT>
|
||||
__aicore__ inline void LIQMatmul<LIQT>::FixpResToGm(uint64_t s1L0RealCount, uint64_t s2L0RealSize, uint64_t s1GmOffset,
|
||||
uint64_t s2GmOffset, const LIQCommon::RunInfo &runInfo)
|
||||
{
|
||||
SetFlag<HardEvent::M_FIX>(M_FIX_EVENT);
|
||||
WaitFlag<HardEvent::M_FIX>(M_FIX_EVENT);
|
||||
|
||||
AscendC::DataCopyCO12DstParams intriParams;
|
||||
intriParams.mSize = 1;
|
||||
intriParams.nSize = s2L0RealSize;
|
||||
intriParams.dstStride = constInfo_.s2BaseSize;
|
||||
intriParams.srcStride = 16;
|
||||
// set mode according to dtype
|
||||
intriParams.quantPre = QuantMode_t::NoQuant;
|
||||
intriParams.nz2ndEn = true;
|
||||
intriParams.reluPre = 0;
|
||||
AscendC::SetFixpipeNz2ndFlag(s1L0RealCount, CeilDiv(constInfo_.gSize, BLOCK_CUBE) * S2_BASIC_BLOCK_L0 / BLOCK_CUBE,
|
||||
2048);
|
||||
AscendC::DataCopy(mm1ResGm_[(runInfo.loop % 2) * constInfo_.mBaseSize / constInfo_.gSize * constInfo_.s2BaseSize +
|
||||
s1GmOffset * intriParams.dstStride + s2GmOffset],
|
||||
cL0_.template ReinterpretCast<float>()[(l0cBufIdx_ % DOUBLE_BUF_NUM) * L0C_BUFFER_OFFSET],
|
||||
intriParams);
|
||||
}
|
||||
|
||||
template <typename LIQT>
|
||||
__aicore__ inline void LIQMatmul<LIQT>::AllocEventID()
|
||||
{
|
||||
SetFlag<HardEvent::MTE1_MTE2>(KEY_MTE1_MTE2_EVENT + 0);
|
||||
SetFlag<HardEvent::MTE1_MTE2>(KEY_MTE1_MTE2_EVENT + 1);
|
||||
SetFlag<HardEvent::MTE1_MTE2>(KEY_MTE1_MTE2_EVENT + 2);
|
||||
|
||||
SetFlag<HardEvent::MTE1_MTE2>(QW_MTE1_MTE2_EVENT + 0);
|
||||
SetFlag<HardEvent::MTE1_MTE2>(QW_MTE1_MTE2_EVENT + 1);
|
||||
|
||||
SetFlag<HardEvent::M_MTE1>(M_MTE1_EVENT + 0);
|
||||
SetFlag<HardEvent::M_MTE1>(M_MTE1_EVENT + 1);
|
||||
SetFlag<HardEvent::M_MTE1>(M_MTE1_EVENT + 2);
|
||||
SetFlag<HardEvent::M_MTE1>(M_MTE1_EVENT + 3);
|
||||
|
||||
SetFlag<HardEvent::FIX_M>(FIX_M_EVENT + 0);
|
||||
SetFlag<HardEvent::FIX_M>(FIX_M_EVENT + 1);
|
||||
}
|
||||
|
||||
template <typename LIQT>
|
||||
__aicore__ inline void LIQMatmul<LIQT>::FreeEventID()
|
||||
{
|
||||
WaitFlag<HardEvent::MTE1_MTE2>(KEY_MTE1_MTE2_EVENT + 0);
|
||||
WaitFlag<HardEvent::MTE1_MTE2>(KEY_MTE1_MTE2_EVENT + 1);
|
||||
WaitFlag<HardEvent::MTE1_MTE2>(KEY_MTE1_MTE2_EVENT + 2);
|
||||
|
||||
WaitFlag<HardEvent::MTE1_MTE2>(QW_MTE1_MTE2_EVENT + 0);
|
||||
WaitFlag<HardEvent::MTE1_MTE2>(QW_MTE1_MTE2_EVENT + 1);
|
||||
|
||||
WaitFlag<HardEvent::M_MTE1>(M_MTE1_EVENT + 0);
|
||||
WaitFlag<HardEvent::M_MTE1>(M_MTE1_EVENT + 1);
|
||||
WaitFlag<HardEvent::M_MTE1>(M_MTE1_EVENT + 2);
|
||||
WaitFlag<HardEvent::M_MTE1>(M_MTE1_EVENT + 3);
|
||||
|
||||
WaitFlag<HardEvent::FIX_M>(FIX_M_EVENT + 0);
|
||||
WaitFlag<HardEvent::FIX_M>(FIX_M_EVENT + 1);
|
||||
}
|
||||
} // namespace LIQKernel
|
||||
#endif
|
||||
@@ -0,0 +1,665 @@
|
||||
/**
|
||||
* This program is free software, you can redistribute it and/or modify it.
|
||||
* Copyright (c) 2025 Huawei Technologies Co., Ltd.
|
||||
* This file is a part of the CANN Open Software.
|
||||
* Licensed under CANN Open Software License Agreement Version 2.0 (the "License").
|
||||
* Please refer to the License for details. You may not use this file except in compliance with the License.
|
||||
* THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE.
|
||||
* See LICENSE in the root of the software repository for the full text of the License.
|
||||
*/
|
||||
|
||||
/*!
|
||||
* \file lightning_indexer_quant_service_vector.h
|
||||
* \brief
|
||||
*/
|
||||
#ifndef LIGHTNING_INDEXER_QUANT_SERVICE_VECTOR_H
|
||||
#define LIGHTNING_INDEXER_QUANT_SERVICE_VECTOR_H
|
||||
|
||||
#include "kernel_operator.h"
|
||||
#include "kernel_operator_list_tensor_intf.h"
|
||||
#include "kernel_tiling/kernel_tiling.h"
|
||||
#include "lib/matmul_intf.h"
|
||||
#include "lib/matrix/matmul/tiling.h"
|
||||
#include "lightning_indexer_quant_common.h"
|
||||
#include "lightning_indexer_quant_vector.h"
|
||||
|
||||
namespace LIQKernel {
|
||||
using namespace LIQCommon;
|
||||
using namespace LIQServiceVec;
|
||||
constexpr uint32_t BASE_TOPK = 2048;
|
||||
constexpr uint32_t BASE_TOPK_VALUE_IDX_SIZE = 4096;
|
||||
constexpr uint32_t LD_PARAM_NUM = 16;
|
||||
|
||||
template <typename LIQT>
|
||||
class LIQVector {
|
||||
public:
|
||||
// =================================类型定义区=================================
|
||||
static constexpr LI_LAYOUT Q_LAYOUT_T = LIQT::layout;
|
||||
static constexpr LI_LAYOUT K_LAYOUT_T = LIQT::keyLayout;
|
||||
static constexpr bool PAGE_ATTENTION = LIQT::pageAttention;
|
||||
// MM输出数据类型, 当前只支持float
|
||||
using MM1_OUT_T = float;
|
||||
|
||||
__aicore__ inline LIQVector(){};
|
||||
__aicore__ inline void ProcessVec0(const LIQCommon::RunInfo &info);
|
||||
__aicore__ inline void ProcessVec1(const LIQCommon::RunInfo &info);
|
||||
__aicore__ inline void ProcessLD();
|
||||
__aicore__ inline void InitBuffers(TPipe *pipe);
|
||||
__aicore__ inline void InitParams(const struct LIQCommon::ConstInfo &constInfo,
|
||||
const LIQTilingData *__restrict tilingData);
|
||||
__aicore__ inline void InitVecWorkspaceTensor(GlobalTensor<half> vec0OutGm, GlobalTensor<MM1_OUT_T> mm1ResGm,
|
||||
GlobalTensor<float> vec1ResGm, GlobalTensor<int64_t> vec1ParamGm);
|
||||
__aicore__ inline void InitVecInputTensor(GlobalTensor<half> weightsGm, GlobalTensor<half> qScaleGm,
|
||||
GlobalTensor<half> kScaleGm, GlobalTensor<int32_t> indiceOutGm,
|
||||
GlobalTensor<int32_t> blockTableGm);
|
||||
__aicore__ inline void CleanInvalidOutput(int64_t invalidS1offset);
|
||||
__aicore__ inline void AllocEventID();
|
||||
__aicore__ inline void FreeEventID();
|
||||
__aicore__ inline void InitLDBuffers(TPipe *pipe);
|
||||
|
||||
protected:
|
||||
GlobalTensor<MM1_OUT_T> mm1ResGm;
|
||||
GlobalTensor<float> vec1ResGm;
|
||||
GlobalTensor<int64_t> vec1ParamGm;
|
||||
GlobalTensor<half> weightsGm;
|
||||
GlobalTensor<half> qScaleGm;
|
||||
GlobalTensor<half> kScaleGm;
|
||||
GlobalTensor<half> vec0OutGm;
|
||||
GlobalTensor<int32_t> indiceOutGm;
|
||||
GlobalTensor<int32_t> blockTableGm;
|
||||
// =================================常量区=================================
|
||||
|
||||
private:
|
||||
__aicore__ inline void GetKeyScale(const LIQCommon::RunInfo &runInfo, const LocalTensor<half> &resUb,
|
||||
int64_t batchId, int64_t startS2, int64_t getLen);
|
||||
// ================================Local Buffer区====================================
|
||||
// queue
|
||||
TQue<QuePosition::VECIN, 1> inQueue_;
|
||||
TQue<QuePosition::VECOUT, 1> outQueue_;
|
||||
|
||||
// tmp buff for vector
|
||||
TBuf<TPosition::VECCALC> sortOutBuf_;
|
||||
TBuf<TPosition::VECCALC> indexBuf_;
|
||||
TBuf<TPosition::VECCALC> paramBuf_;
|
||||
TBuf<TPosition::VECCALC> tmpBuf_;
|
||||
|
||||
// tmp buff for LD
|
||||
TBuf<> ldToBeMrgBuf_;
|
||||
TBuf<> ldTmpBuf_;
|
||||
TBuf<> ldOutValueBuf_;
|
||||
TBuf<> ldOutIdxBuf_;
|
||||
|
||||
LocalTensor<int32_t> globalTopkIndice_;
|
||||
LocalTensor<float> globalTopkUb_;
|
||||
|
||||
int32_t blockId_ = -1;
|
||||
// para for vector
|
||||
int32_t groupInner_ = 0;
|
||||
int32_t globalTopkNum_ = 0;
|
||||
int64_t blockS2StartIdx_ = 0;
|
||||
int32_t gSize_ = 0;
|
||||
int32_t kSeqSize_ = 0;
|
||||
int32_t kHeadNum_ = 0;
|
||||
int32_t qHeadNum_ = 0;
|
||||
int32_t s1BaseSize_ = 0;
|
||||
int32_t s2BaseSize_ = 0;
|
||||
int32_t kCacheBlockSize_ = 0;
|
||||
int32_t maxBlockNumPerBatch_ = 0;
|
||||
|
||||
// para for LD
|
||||
uint32_t mrgListNum_ = 4;
|
||||
uint32_t paramNum_ = 16;
|
||||
|
||||
struct LIQCommon::ConstInfo constInfo_;
|
||||
};
|
||||
|
||||
template <typename LIQT>
|
||||
__aicore__ inline void LIQVector<LIQT>::GetKeyScale(const LIQCommon::RunInfo &runInfo, const LocalTensor<half> &resUb,
|
||||
int64_t batchId, int64_t startS2, int64_t getLen)
|
||||
{
|
||||
// startS2一定能整除kCacheBlockSize_
|
||||
AscendC::DataCopyPadExtParams<half> padParams{false, 0, 0, 0};
|
||||
AscendC::DataCopyExtParams copyInParams;
|
||||
if constexpr (PAGE_ATTENTION) {
|
||||
int32_t startBlockTableIdx = startS2 / kCacheBlockSize_;
|
||||
int32_t startBlockTableOffset = startS2 % kCacheBlockSize_;
|
||||
int32_t blockTableBatchOffset = batchId * maxBlockNumPerBatch_;
|
||||
copyInParams.blockCount = 1;
|
||||
copyInParams.srcStride = 0;
|
||||
copyInParams.dstStride = 0;
|
||||
copyInParams.rsv = 0;
|
||||
int32_t resUbBaseOffset = 0;
|
||||
if (startBlockTableOffset > 0) {
|
||||
int32_t firstPartLen =
|
||||
kCacheBlockSize_ - startBlockTableOffset > getLen ? getLen : kCacheBlockSize_ - startBlockTableOffset;
|
||||
copyInParams.blockLen = firstPartLen * sizeof(half);
|
||||
int32_t blockId = blockTableGm.GetValue(blockTableBatchOffset + startBlockTableIdx);
|
||||
SetWaitFlag<HardEvent::S_MTE2>(HardEvent::S_MTE2);
|
||||
AscendC::DataCopyPad(resUb, kScaleGm[blockId * kCacheBlockSize_ + startBlockTableOffset],
|
||||
copyInParams, padParams);
|
||||
startBlockTableIdx++;
|
||||
getLen = getLen - firstPartLen;
|
||||
resUbBaseOffset = firstPartLen;
|
||||
}
|
||||
int32_t getLoopNum = CeilDiv(getLen, kCacheBlockSize_);
|
||||
copyInParams.blockLen = kCacheBlockSize_ * sizeof(half);
|
||||
for (int32_t i = 0; i < getLoopNum; i++) {
|
||||
if (i == getLoopNum - 1) {
|
||||
copyInParams.blockLen = (getLen - i * kCacheBlockSize_) * sizeof(half);
|
||||
}
|
||||
int32_t blockId = blockTableGm.GetValue(blockTableBatchOffset + startBlockTableIdx + i);
|
||||
SetWaitFlag<HardEvent::S_MTE2>(HardEvent::S_MTE2);
|
||||
AscendC::DataCopyPad(resUb[resUbBaseOffset + i * kCacheBlockSize_], kScaleGm[blockId * kCacheBlockSize_],
|
||||
copyInParams, padParams);
|
||||
}
|
||||
} else {
|
||||
copyInParams.blockCount = 1;
|
||||
copyInParams.blockLen = getLen * sizeof(half);
|
||||
copyInParams.srcStride = 0;
|
||||
copyInParams.dstStride = 0;
|
||||
copyInParams.rsv = 0;
|
||||
AscendC::DataCopyPad(resUb, kScaleGm[runInfo.tensorKeyScaleOffset], copyInParams, padParams);
|
||||
}
|
||||
}
|
||||
|
||||
template <typename LIQT>
|
||||
__aicore__ inline void LIQVector<LIQT>::InitBuffers(TPipe *pipe)
|
||||
{
|
||||
pipe->InitBuffer(paramBuf_, LD_PARAM_NUM * sizeof(int64_t)); // 1 KB
|
||||
pipe->InitBuffer(inQueue_, 2, s2BaseSize_ * sizeof(float) * 2); // 32KB
|
||||
pipe->InitBuffer(outQueue_, 1, BASE_TOPK * sizeof(float)); // 8 KB
|
||||
pipe->InitBuffer(indexBuf_, s2BaseSize_ * sizeof(int32_t)); // 8 KB
|
||||
pipe->InitBuffer(tmpBuf_, 64 * 1024); // 64KB
|
||||
pipe->InitBuffer(sortOutBuf_, CeilDiv(s1BaseSize_, 2) * BASE_TOPK_VALUE_IDX_SIZE * sizeof(float)); // 32KB
|
||||
|
||||
globalTopkIndice_ = indexBuf_.Get<int32_t>();
|
||||
globalTopkUb_ = sortOutBuf_.Get<float>();
|
||||
globalTopkNum_ = 0;
|
||||
|
||||
// 基本块执行前初始化UB和GM
|
||||
// step1. 初始化一个有序索引 0 - s2BaseSize_
|
||||
ArithProgression<int32_t>(globalTopkIndice_, 0, 1, s2BaseSize_);
|
||||
// step2. globalTopkUb_ [CeilDiv(s1BaseSize_, 2), BASE_TOPK, 2] -inf,-1
|
||||
InitSortOutBuf(globalTopkUb_, CeilDiv(s1BaseSize_, 2) * BASE_TOPK_VALUE_IDX_SIZE);
|
||||
|
||||
// step3. 初始化vec1ParamGm,是否进行LD的标志位设为-1(needFd=-1)
|
||||
// vec1ResIn32Gm = [aic, 2, s1BaseSize_, 16] int32
|
||||
// ws清零 [needFd, s2AcSeq, s2Start, s2End, isS2End, bn2idx, s1Idx, ......]
|
||||
LocalTensor<float> tmpfBuff = outQueue_.AllocTensor<float>();
|
||||
Duplicate(tmpfBuff.template ReinterpretCast<int32_t>(), -1, 2 * (s1BaseSize_ / 2) * paramNum_ * 2);
|
||||
SetWaitFlag<HardEvent::V_MTE3>(HardEvent::V_MTE3);
|
||||
int64_t wsInfoOffset = (blockId_ / 2) * s1BaseSize_ * 2 * paramNum_ + // 2个AIV共同地址偏移
|
||||
(blockId_ % 2) * (s1BaseSize_ / 2) * 2 * paramNum_; // 每个AIV的地址偏移,S1方向
|
||||
DataCopyPad(vec1ParamGm[wsInfoOffset], tmpfBuff.template ReinterpretCast<int64_t>(),
|
||||
{1, static_cast<uint16_t>((s1BaseSize_ / 2) * 2 * paramNum_ * sizeof(int64_t)), 0, 0});
|
||||
SetWaitFlag<HardEvent::MTE3_V>(HardEvent::MTE3_V);
|
||||
outQueue_.FreeTensor(tmpfBuff);
|
||||
}
|
||||
|
||||
template <typename LIQT>
|
||||
__aicore__ inline void LIQVector<LIQT>::InitLDBuffers(TPipe *pipe)
|
||||
{
|
||||
pipe->Reset();
|
||||
pipe->InitBuffer(ldToBeMrgBuf_, BASE_TOPK_VALUE_IDX_SIZE * mrgListNum_ * sizeof(float));
|
||||
pipe->InitBuffer(ldTmpBuf_, BASE_TOPK_VALUE_IDX_SIZE * mrgListNum_ * sizeof(float));
|
||||
pipe->InitBuffer(ldOutValueBuf_, BASE_TOPK * sizeof(float));
|
||||
pipe->InitBuffer(ldOutIdxBuf_, BASE_TOPK * sizeof(int32_t));
|
||||
}
|
||||
|
||||
template <typename LIQT>
|
||||
__aicore__ inline void LIQVector<LIQT>::InitParams(const struct LIQCommon::ConstInfo &constInfo,
|
||||
const LIQTilingData *__restrict tilingData)
|
||||
{
|
||||
this->constInfo_ = constInfo;
|
||||
blockS2StartIdx_ = 0;
|
||||
gSize_ = constInfo.gSize;
|
||||
kSeqSize_ = constInfo.kSeqSize;
|
||||
// define N2 para
|
||||
kHeadNum_ = constInfo.kHeadNum;
|
||||
qHeadNum_ = constInfo.qHeadNum;
|
||||
// define MMBase para
|
||||
s1BaseSize_ = constInfo.s1BaseSize; // 4
|
||||
s2BaseSize_ = constInfo.s2BaseSize; // 2048
|
||||
kCacheBlockSize_ = constInfo.kCacheBlockSize;
|
||||
maxBlockNumPerBatch_ = constInfo.maxBlockNumPerBatch;
|
||||
blockId_ = GetBlockIdx();
|
||||
}
|
||||
|
||||
template <typename LIQT>
|
||||
__aicore__ inline void LIQVector<LIQT>::InitVecInputTensor(GlobalTensor<half> weightsGm, GlobalTensor<half> qScaleGm,
|
||||
GlobalTensor<half> kScaleGm,
|
||||
GlobalTensor<int32_t> indiceOutGm,
|
||||
GlobalTensor<int32_t> blockTableGm)
|
||||
{
|
||||
this->weightsGm = weightsGm;
|
||||
this->qScaleGm = qScaleGm;
|
||||
this->kScaleGm = kScaleGm;
|
||||
this->indiceOutGm = indiceOutGm;
|
||||
this->blockTableGm = blockTableGm;
|
||||
}
|
||||
|
||||
template <typename LIQT>
|
||||
__aicore__ inline void LIQVector<LIQT>::InitVecWorkspaceTensor(GlobalTensor<half> vec0OutGm,
|
||||
GlobalTensor<MM1_OUT_T> mm1ResGm,
|
||||
GlobalTensor<float> vec1ResGm,
|
||||
GlobalTensor<int64_t> vec1ParamGm)
|
||||
{
|
||||
this->mm1ResGm = mm1ResGm;
|
||||
this->vec1ResGm = vec1ResGm;
|
||||
this->vec0OutGm = vec0OutGm;
|
||||
this->vec1ParamGm = vec1ParamGm;
|
||||
}
|
||||
|
||||
template <typename LIQT>
|
||||
__aicore__ inline void LIQVector<LIQT>::AllocEventID()
|
||||
{
|
||||
}
|
||||
|
||||
template <typename LIQT>
|
||||
__aicore__ inline void LIQVector<LIQT>::FreeEventID()
|
||||
{
|
||||
}
|
||||
|
||||
template <typename LIQT>
|
||||
__aicore__ inline void LIQVector<LIQT>::CleanInvalidOutput(int64_t invalidS1offset)
|
||||
{
|
||||
// init -1 and copy to output
|
||||
LocalTensor<float> valueULocal = outQueue_.AllocTensor<float>();
|
||||
LocalTensor<int32_t> idxULocal1 = valueULocal.template ReinterpretCast<int32_t>();
|
||||
Duplicate(idxULocal1, constInfo_.INVALID_IDX, constInfo_.sparseCount);
|
||||
outQueue_.EnQue<float>(valueULocal);
|
||||
valueULocal = outQueue_.DeQue<float>();
|
||||
LIQServiceVec::CopyOut(indiceOutGm[invalidS1offset], idxULocal1, constInfo_.sparseCount);
|
||||
outQueue_.FreeTensor(valueULocal);
|
||||
}
|
||||
|
||||
template <typename LIQT>
|
||||
__aicore__ inline void LIQVector<LIQT>::ProcessVec0(const LIQCommon::RunInfo &info)
|
||||
{
|
||||
// 只需要一个v核做
|
||||
if (blockId_ % 2 != 0) {
|
||||
return;
|
||||
}
|
||||
int32_t cuBaseS1Idx = info.gS1Idx * s1BaseSize_;
|
||||
// 计算输出w基地址偏移 偶数循环 -> 0 + aic_offset 奇数循环 -> 4*64 + aic_offset
|
||||
int64_t vec0OutGmOffset = (info.loop % 2) * ((s1BaseSize_ * gSize_ * BLOCK_CUBE));
|
||||
// 计算输入weight的地址偏移,qScale的地址偏移与weight相同
|
||||
int64_t weightGmOffset = info.tensorWeightsOffset + cuBaseS1Idx * qHeadNum_;
|
||||
// 当前需要计算的S1行数,处理尾块场景
|
||||
int32_t cuS1ProcNum = cuBaseS1Idx + s1BaseSize_ > info.actS1Size ? info.actS1Size % s1BaseSize_ : s1BaseSize_;
|
||||
int32_t cuProcEleNum = cuS1ProcNum * gSize_;
|
||||
|
||||
LocalTensor<half> inWeightsUb = inQueue_.AllocTensor<half>();
|
||||
LocalTensor<half> inQScaleUb = inWeightsUb[cuProcEleNum];
|
||||
AscendC::DataCopyPadExtParams<half> padParams{false, 0, 0, 0};
|
||||
AscendC::DataCopyExtParams copyInParams;
|
||||
copyInParams.blockCount = 1;
|
||||
copyInParams.blockLen = cuProcEleNum * sizeof(half);
|
||||
copyInParams.srcStride = 0;
|
||||
copyInParams.dstStride = 0;
|
||||
copyInParams.rsv = 0;
|
||||
AscendC::DataCopyPad(inWeightsUb, weightsGm[weightGmOffset], copyInParams, padParams);
|
||||
AscendC::DataCopyPad(inQScaleUb, qScaleGm[weightGmOffset], copyInParams, padParams);
|
||||
|
||||
inQueue_.EnQue<half>(inWeightsUb);
|
||||
inWeightsUb = inQueue_.DeQue<half>();
|
||||
AscendC::Mul(inWeightsUb, inWeightsUb, inQScaleUb, cuProcEleNum);
|
||||
PipeBarrier<PIPE_V>();
|
||||
LocalTensor<half> resUb = outQueue_.AllocTensor<half>();
|
||||
AscendC::Brcb(resUb, inWeightsUb, static_cast<uint8_t>(cuProcEleNum / 8), {1, 8});
|
||||
inQueue_.FreeTensor(inWeightsUb);
|
||||
|
||||
outQueue_.EnQue<half>(resUb);
|
||||
resUb = outQueue_.DeQue<half>();
|
||||
AscendC::DataCopyParams copyOutParams;
|
||||
copyOutParams.blockCount = 1;
|
||||
copyOutParams.blockLen = cuProcEleNum * BLOCK_CUBE * sizeof(half);
|
||||
copyOutParams.srcStride = 0;
|
||||
copyOutParams.dstStride = 0;
|
||||
AscendC::DataCopyPad(vec0OutGm[vec0OutGmOffset], resUb, copyOutParams);
|
||||
outQueue_.FreeTensor(resUb);
|
||||
}
|
||||
|
||||
template <typename LIQT>
|
||||
__aicore__ inline void LIQVector<LIQT>::ProcessVec1(const LIQCommon::RunInfo &info)
|
||||
{
|
||||
int32_t cuBaseS1Idx = info.gS1Idx * s1BaseSize_;
|
||||
int32_t cuBaseS2Idx = info.s2Idx * s2BaseSize_;
|
||||
|
||||
// 计算基本块基地址偏移 偶数循环 -> 0 + aic_offset 奇数循环 -> 4*2048 + aic_offset
|
||||
int64_t mmGmOffset = (info.loop % 2) * (s1BaseSize_ * s2BaseSize_);
|
||||
|
||||
// cuS1BeginIdxPerAiv: 每个AIV的S1起始偏移
|
||||
int32_t cuS1BeginIdxPerAiv = cuBaseS1Idx;
|
||||
int32_t cuS1ProcNum =
|
||||
cuS1BeginIdxPerAiv + s1BaseSize_ > info.actS1Size ? info.actS1Size % s1BaseSize_ : s1BaseSize_;
|
||||
// cuS1ProcNumPerAiv: 每个AIv的S1计算量
|
||||
int32_t cuS1ProcNumPerAiv = blockId_ % 2 == 0 ? CeilDiv(cuS1ProcNum, 2) : (cuS1ProcNum / 2);
|
||||
cuS1BeginIdxPerAiv += (blockId_ % 2) * CeilDiv(cuS1ProcNum, 2);
|
||||
// 基本块基地址偏移奇数核加一个S1地址偏移
|
||||
mmGmOffset += (blockId_ % 2) * CeilDiv(cuS1ProcNum, 2) * s2BaseSize_;
|
||||
// 非首个基本块, M(S1)轴发生切换需要初始化
|
||||
if (info.loop != 0 && info.s2Idx == 0) {
|
||||
// globalTopkUb_ value,index=-inf,-1
|
||||
InitSortOutBuf(globalTopkUb_, CeilDiv(s1BaseSize_, 2) * BASE_TOPK_VALUE_IDX_SIZE);
|
||||
blockS2StartIdx_ = 0;
|
||||
} else if (info.loop == 0) {
|
||||
blockS2StartIdx_ = info.s2Idx;
|
||||
}
|
||||
// cuRealAcSeq: 当前基本块S1对应的AcSeq
|
||||
int32_t cuRealAcSeq = info.actS2Size;
|
||||
if (constInfo_.attenMaskFlag) {
|
||||
// attenMask true场景
|
||||
cuRealAcSeq = info.actS2Size - (info.actS1Size - cuS1BeginIdxPerAiv);
|
||||
}
|
||||
|
||||
// LD输出S1方向偏移,保证2个Vector输出的内容连续
|
||||
uint32_t ldS1Offset = (blockId_ % 2 == 0) ? s1BaseSize_ / 2 - cuS1ProcNumPerAiv : 0;
|
||||
for (int innerS1Idx = 0; innerS1Idx < cuS1ProcNumPerAiv; innerS1Idx++) {
|
||||
if (constInfo_.attenMaskFlag) {
|
||||
cuRealAcSeq += 1;
|
||||
}
|
||||
int32_t cuS2Len = cuBaseS2Idx + s2BaseSize_ >= cuRealAcSeq ? cuRealAcSeq - cuBaseS2Idx : s2BaseSize_;
|
||||
int32_t cuS1Idx = cuS1BeginIdxPerAiv + innerS1Idx;
|
||||
if (cuRealAcSeq > 0 && cuS2Len > 0) {
|
||||
int32_t cuS2LenVecAlign = CeilDiv(cuS2Len, s2BaseSize_) * s2BaseSize_;
|
||||
LocalTensor<float> mmInUb = inQueue_.AllocTensor<float>();
|
||||
LocalTensor<float> kScaleUb = mmInUb[cuS2LenVecAlign];
|
||||
LocalTensor<half> kScaleTUb = kScaleUb.template ReinterpretCast<half>()[cuS2LenVecAlign];
|
||||
AscendC::DataCopyPadExtParams<float> padParams{false, 0, 0, 0};
|
||||
AscendC::DataCopyPadExtParams<half> padTParams{false, 0, 0, 0};
|
||||
AscendC::DataCopyExtParams copyInParams;
|
||||
copyInParams.blockCount = 1;
|
||||
copyInParams.blockLen = cuS2Len * sizeof(float);
|
||||
copyInParams.srcStride = 0;
|
||||
copyInParams.dstStride = 0;
|
||||
copyInParams.rsv = 0;
|
||||
AscendC::DataCopyPad(mmInUb, mm1ResGm[mmGmOffset + innerS1Idx * s2BaseSize_], copyInParams, padParams);
|
||||
GetKeyScale(info, kScaleTUb, info.bIdx, cuBaseS2Idx, cuS2Len);
|
||||
inQueue_.EnQue<float>(mmInUb);
|
||||
mmInUb = inQueue_.DeQue<float>();
|
||||
AscendC::Cast(kScaleUb, kScaleTUb, RoundMode::CAST_NONE, cuS2Len);
|
||||
PipeBarrier<PIPE_V>();
|
||||
AscendC::Mul(mmInUb, mmInUb, kScaleUb, cuS2Len);
|
||||
PipeBarrier<PIPE_V>();
|
||||
LocalTensor<float> sortBuff = tmpBuf_.Get<float>();
|
||||
LocalTensor<float> sortScoreUb = sortBuff;
|
||||
LocalTensor<float> sortIndiceUb = sortBuff[cuS2LenVecAlign];
|
||||
PipeBarrier<PIPE_V>();
|
||||
Duplicate(sortScoreUb.template ReinterpretCast<int32_t>(), LIQServiceVec::NEG_INF, cuS2LenVecAlign);
|
||||
PipeBarrier<PIPE_V>();
|
||||
Adds(sortScoreUb, mmInUb, 0.0f, cuS2Len);
|
||||
PipeBarrier<PIPE_V>();
|
||||
inQueue_.FreeTensor(mmInUb);
|
||||
LocalTensor<int32_t> sortIndiceUbInt = sortIndiceUb.template ReinterpretCast<int32_t>();
|
||||
// 无效数据索引填充为-1
|
||||
if (cuS2LenVecAlign != cuS2Len) {
|
||||
Duplicate(sortIndiceUbInt, -1, cuS2LenVecAlign);
|
||||
PipeBarrier<PIPE_V>();
|
||||
}
|
||||
Adds(sortIndiceUbInt, globalTopkIndice_, static_cast<int32_t>(cuBaseS2Idx), cuS2Len);
|
||||
PipeBarrier<PIPE_V>();
|
||||
LocalTensor<float> tmpSortBuf = sortBuff[2 * cuS2LenVecAlign];
|
||||
LIQServiceVec::SortAll(sortBuff, tmpSortBuf, cuS2LenVecAlign);
|
||||
PipeBarrier<PIPE_V>();
|
||||
LIQServiceVec::MergeSort(globalTopkUb_[innerS1Idx * BASE_TOPK_VALUE_IDX_SIZE], BASE_TOPK, sortBuff,
|
||||
cuS2LenVecAlign, tmpSortBuf);
|
||||
PipeBarrier<PIPE_V>();
|
||||
bool isS2End = cuBaseS2Idx + s2BaseSize_ >= cuRealAcSeq;
|
||||
bool needCopyOutGm = blockS2StartIdx_ == 0 && isS2End;
|
||||
// 中间结果保存
|
||||
bool needCopyWsGm = info.isAllLoopEnd || isS2End;
|
||||
if (needCopyOutGm) {
|
||||
LocalTensor<uint32_t> idxULocal = outQueue_.AllocTensor<uint32_t>();
|
||||
ExtractIndex(idxULocal,
|
||||
globalTopkUb_[innerS1Idx * BASE_TOPK_VALUE_IDX_SIZE].template ReinterpretCast<uint32_t>(),
|
||||
BASE_TOPK);
|
||||
PipeBarrier<PIPE_V>();
|
||||
InitSortOutBuf(globalTopkUb_[innerS1Idx * BASE_TOPK_VALUE_IDX_SIZE], BASE_TOPK_VALUE_IDX_SIZE);
|
||||
outQueue_.EnQue<uint32_t>(idxULocal);
|
||||
idxULocal = outQueue_.DeQue<uint32_t>();
|
||||
LIQServiceVec::CopyOut(indiceOutGm[info.indiceOutOffset + cuS1Idx * constInfo_.sparseCount],
|
||||
idxULocal.template ReinterpretCast<int32_t>(), constInfo_.sparseCount);
|
||||
outQueue_.FreeTensor(idxULocal);
|
||||
} else if (needCopyWsGm) {
|
||||
// vec1Res Gm = [aic, s1BaseSize_, 2, 2, topkOut_] float32
|
||||
// vec1Param Gm = [aic, s1BaseSize_, 2, 16] int64
|
||||
// 16 = [needFd, s2AcSeq, s2Start, s2End, isS2End, bn2idx, s1Idx, S1ProcNum, ......]
|
||||
|
||||
int64_t wsOffset =
|
||||
(blockId_ / 2) * s1BaseSize_ * 2 * BASE_TOPK_VALUE_IDX_SIZE + // 2个AIV共同地址偏移
|
||||
(blockId_ % 2) * (s1BaseSize_ / 2) * 2 * BASE_TOPK_VALUE_IDX_SIZE + // 每个AIV的地址偏移,S1方向
|
||||
(ldS1Offset + innerS1Idx) * 2 * BASE_TOPK_VALUE_IDX_SIZE;
|
||||
int64_t wsInfoOffset =
|
||||
(blockId_ / 2) * s1BaseSize_ * 2 * paramNum_ + // 2个AIV共同地址偏移
|
||||
(blockId_ % 2) * (s1BaseSize_ / 2) * 2 * paramNum_ + // 每个AIV的地址偏移,S1方向
|
||||
(ldS1Offset + innerS1Idx) * 2 * paramNum_;
|
||||
|
||||
LocalTensor<int64_t> tmpiBuff = paramBuf_.Get<int64_t>();
|
||||
SetWaitFlag<HardEvent::MTE3_S>(HardEvent::MTE3_S);
|
||||
tmpiBuff.SetValue(0, static_cast<int64_t>(1));
|
||||
tmpiBuff.SetValue(1, static_cast<int64_t>(cuRealAcSeq));
|
||||
tmpiBuff.SetValue(2, static_cast<int64_t>(blockS2StartIdx_));
|
||||
tmpiBuff.SetValue(3, static_cast<int64_t>(cuBaseS2Idx + cuS2Len));
|
||||
tmpiBuff.SetValue(4, static_cast<int64_t>(isS2End));
|
||||
tmpiBuff.SetValue(5, static_cast<int64_t>(info.bN2Idx));
|
||||
tmpiBuff.SetValue(6, static_cast<int64_t>(cuS1Idx));
|
||||
tmpiBuff.SetValue(7, static_cast<int64_t>(cuS1ProcNum));
|
||||
tmpiBuff.SetValue(8, static_cast<int64_t>(info.indiceOutOffset + cuS1Idx * constInfo_.sparseCount));
|
||||
// 写入头尾判断
|
||||
// [head, tail]
|
||||
// head: 与前面规约,与前后规约
|
||||
// tail: 与后面规约
|
||||
bool isTailReduce = blockS2StartIdx_ == 0; // 一定是isLastTile
|
||||
// WS偏移规则 blockS2StartIdx_ != 0
|
||||
// 跟前面块做规约 写到0偏移 不用做计算 blockS2StartIdx_ == 0 and !isS2End
|
||||
// 跟后面块做规约 写到1偏移 需要 + s1BaseSize_, BASE_TOPK*2
|
||||
if (isTailReduce) { // S2不是最后结束的数据就需要往后做规约,放入第二块ws
|
||||
wsInfoOffset += paramNum_;
|
||||
wsOffset += BASE_TOPK_VALUE_IDX_SIZE;
|
||||
}
|
||||
SetWaitFlag<HardEvent::S_MTE3>(HardEvent::S_MTE3);
|
||||
LIQServiceVec::CopyOut(vec1ParamGm[wsInfoOffset], tmpiBuff, 16);
|
||||
SetWaitFlag<HardEvent::V_MTE3>(HardEvent::V_MTE3);
|
||||
LIQServiceVec::CopyOut(vec1ResGm[wsOffset], globalTopkUb_[innerS1Idx * BASE_TOPK_VALUE_IDX_SIZE],
|
||||
BASE_TOPK_VALUE_IDX_SIZE);
|
||||
SetWaitFlag<HardEvent::MTE3_V>(HardEvent::MTE3_V);
|
||||
}
|
||||
} else if (cuRealAcSeq <= 0) {
|
||||
CleanInvalidOutput(info.indiceOutOffset + cuS1Idx * constInfo_.sparseCount);
|
||||
}
|
||||
}
|
||||
|
||||
// BNSD场景无效S1 输出-1
|
||||
if (Q_LAYOUT_T == LI_LAYOUT::BSND) {
|
||||
// 最后一个S1的基本块, 需要 >= info.actS1Size
|
||||
bool isS1LoopEnd = (cuBaseS1Idx + s1BaseSize_) >= info.actS1Size;
|
||||
int32_t invalidS1Num = constInfo_.qSeqSize - info.actS1Size;
|
||||
// blockS2StartIdx_ == 0 控制S2从开始的核去做冗余清理
|
||||
if (invalidS1Num > 0 && isS1LoopEnd && blockS2StartIdx_ == 0) {
|
||||
int32_t s1NumPerAiv = blockId_ % 2 == 0 ? CeilDiv(invalidS1Num, 2) : (invalidS1Num / 2);
|
||||
int32_t s1OffsetPerAiv = info.actS1Size + (blockId_ % 2) * CeilDiv(invalidS1Num, 2);
|
||||
for (int innerS1Idx = 0; innerS1Idx < s1NumPerAiv; innerS1Idx++) {
|
||||
CleanInvalidOutput(info.indiceOutOffset + (s1OffsetPerAiv + innerS1Idx) * constInfo_.sparseCount);
|
||||
}
|
||||
}
|
||||
|
||||
int32_t invalidS1Num2 = info.actS1Size - info.actS2Size;
|
||||
if (invalidS1Num2 > 0 && isS1LoopEnd && blockS2StartIdx_ == 0 && constInfo_.attenMaskFlag) {
|
||||
int32_t s1NumPerAiv = blockId_ % 2 == 0 ? CeilDiv(invalidS1Num2, 2) : (invalidS1Num2 / 2);
|
||||
int32_t s1OffsetPerAiv = (blockId_ % 2) * CeilDiv(invalidS1Num2, 2);
|
||||
for (int innerS1Idx = 0; innerS1Idx < s1NumPerAiv; innerS1Idx++) {
|
||||
CleanInvalidOutput((info.bN2Idx * constInfo_.qSeqSize + s1OffsetPerAiv + innerS1Idx) *
|
||||
constInfo_.sparseCount);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if (info.isLastS2InnerLoop) {
|
||||
// S2最后一个Loop后, 下一个基本块初始从0开始
|
||||
blockS2StartIdx_ = 0;
|
||||
}
|
||||
}
|
||||
|
||||
template <typename LIQT>
|
||||
__aicore__ inline void LIQVector<LIQT>::ProcessLD()
|
||||
{
|
||||
int32_t curCubeId = blockId_ / 2;
|
||||
int32_t tmpCubeId = curCubeId;
|
||||
|
||||
int64_t s2ActSeq;
|
||||
int64_t s2Start;
|
||||
int64_t s2End;
|
||||
int64_t isS2End;
|
||||
int64_t bn2Idx;
|
||||
int64_t s1Idx;
|
||||
uint32_t acc_list_num = 0;
|
||||
int64_t bIdx = 0;
|
||||
int64_t needFd;
|
||||
int64_t wsOffset;
|
||||
int64_t wsInfoOffset = 0;
|
||||
int64_t nextneedFd;
|
||||
int64_t valueOffset = 0;
|
||||
int64_t outOffset = 0;
|
||||
|
||||
LocalTensor<float> curValueIdxUb = ldToBeMrgBuf_.Get<float>();
|
||||
LocalTensor<float> tmpUb = ldTmpBuf_.Get<float>();
|
||||
|
||||
// S2开头信息
|
||||
// 开始必然没有头规约,因此从尾规约开始处理,while循环读取下一个核的头规约
|
||||
// 存满4个list或者遇到S2结尾,则做merge,直到做完S2
|
||||
// 每个核都忽略自己的头规约,因为必然由前面的核做完
|
||||
uint32_t s1LdStartIdx = 0;
|
||||
uint32_t s1ProcNum = 0;
|
||||
uint64_t paramGmCoreOffset = tmpCubeId * s1BaseSize_ * 2 * paramNum_;
|
||||
for (uint32_t innerS1Idx = 0; innerS1Idx < s1BaseSize_; innerS1Idx++) {
|
||||
needFd = vec1ParamGm.GetValue(paramGmCoreOffset + innerS1Idx * 2 * paramNum_ + paramNum_);
|
||||
if (needFd == 1) {
|
||||
s1LdStartIdx = (s1ProcNum == 0) ? innerS1Idx : s1LdStartIdx;
|
||||
s1ProcNum++;
|
||||
}
|
||||
}
|
||||
|
||||
if (s1ProcNum == 0) {
|
||||
return;
|
||||
}
|
||||
|
||||
// S1逐行计算
|
||||
uint32_t s1VecNum = CeilDiv(s1ProcNum, 2);
|
||||
if (blockId_ % 2 == 1) {
|
||||
s1LdStartIdx = s1LdStartIdx + s1VecNum;
|
||||
s1VecNum = s1ProcNum - s1VecNum;
|
||||
}
|
||||
for (uint32_t innerS1Idx = s1LdStartIdx; innerS1Idx < s1LdStartIdx + s1VecNum; innerS1Idx++) {
|
||||
// 重置偏移
|
||||
tmpCubeId = curCubeId;
|
||||
acc_list_num = 0;
|
||||
valueOffset = 0;
|
||||
|
||||
// 搬入数据
|
||||
wsOffset = tmpCubeId * s1BaseSize_ * 2 * BASE_TOPK_VALUE_IDX_SIZE + // 2个AIV共同地址偏移
|
||||
innerS1Idx * 2 * BASE_TOPK_VALUE_IDX_SIZE + BASE_TOPK_VALUE_IDX_SIZE;
|
||||
SetWaitFlag<HardEvent::V_MTE2>(HardEvent::V_MTE2);
|
||||
SetWaitFlag<HardEvent::S_MTE2>(HardEvent::S_MTE2);
|
||||
DataCopyPad(curValueIdxUb, vec1ResGm[wsOffset],
|
||||
{1, static_cast<uint16_t>(BASE_TOPK_VALUE_IDX_SIZE * sizeof(int32_t)), 0, 0}, {true, 0, 0, 0});
|
||||
acc_list_num++;
|
||||
valueOffset += BASE_TOPK_VALUE_IDX_SIZE;
|
||||
|
||||
// 获取下一个核规约信息
|
||||
tmpCubeId++;
|
||||
wsInfoOffset = tmpCubeId * s1BaseSize_ * 2 * paramNum_ + innerS1Idx * 2 * paramNum_;
|
||||
needFd = vec1ParamGm.GetValue(wsInfoOffset);
|
||||
isS2End = vec1ParamGm.GetValue(wsInfoOffset + 4);
|
||||
s1Idx = vec1ParamGm.GetValue(wsInfoOffset + 6);
|
||||
outOffset = vec1ParamGm.GetValue(wsInfoOffset + 8);
|
||||
|
||||
while (needFd == 1) {
|
||||
// 搬入头规约数据
|
||||
wsOffset = tmpCubeId * s1BaseSize_ * 2 * BASE_TOPK_VALUE_IDX_SIZE + // 2个AIV共同地址偏移
|
||||
innerS1Idx * 2 * BASE_TOPK_VALUE_IDX_SIZE;
|
||||
SetWaitFlag<HardEvent::V_MTE2>(HardEvent::V_MTE2);
|
||||
SetWaitFlag<HardEvent::S_MTE2>(HardEvent::S_MTE2);
|
||||
DataCopyPad(curValueIdxUb[valueOffset], vec1ResGm[wsOffset],
|
||||
{1, static_cast<uint16_t>(BASE_TOPK_VALUE_IDX_SIZE * sizeof(int32_t)), 0, 0}, {true, 0, 0, 0});
|
||||
valueOffset += BASE_TOPK_VALUE_IDX_SIZE;
|
||||
acc_list_num++;
|
||||
|
||||
// 每满4个list,聚合 前2K为mrg结果
|
||||
if (acc_list_num == mrgListNum_) {
|
||||
// MrgSort 四条2048的队列,Mrg成一条
|
||||
AscendC::MrgSort4Info params;
|
||||
params.elementLengths[0] = BASE_TOPK;
|
||||
params.elementLengths[1] = BASE_TOPK;
|
||||
params.elementLengths[2] = BASE_TOPK;
|
||||
params.elementLengths[3] = BASE_TOPK;
|
||||
params.ifExhaustedSuspension = true;
|
||||
params.validBit = 0b1111;
|
||||
params.repeatTimes = 1;
|
||||
|
||||
AscendC::MrgSortSrcList<float> srcList;
|
||||
srcList.src1 = curValueIdxUb[0];
|
||||
srcList.src2 = curValueIdxUb[BASE_TOPK_VALUE_IDX_SIZE];
|
||||
srcList.src3 = curValueIdxUb[2 * BASE_TOPK_VALUE_IDX_SIZE];
|
||||
srcList.src4 = curValueIdxUb[3 * BASE_TOPK_VALUE_IDX_SIZE];
|
||||
SetWaitFlag<HardEvent::MTE2_V>(HardEvent::MTE2_V);
|
||||
MrgSort(tmpUb, srcList, params);
|
||||
PipeBarrier<PIPE_V>();
|
||||
DataCopy(curValueIdxUb, tmpUb, BASE_TOPK_VALUE_IDX_SIZE);
|
||||
PipeBarrier<PIPE_V>();
|
||||
acc_list_num = 1;
|
||||
valueOffset = BASE_TOPK_VALUE_IDX_SIZE;
|
||||
}
|
||||
|
||||
// reduce到S2末尾,则跳出
|
||||
if (isS2End == 1) {
|
||||
break;
|
||||
}
|
||||
|
||||
tmpCubeId++;
|
||||
wsInfoOffset = tmpCubeId * s1BaseSize_ * 2 * paramNum_ + innerS1Idx * 2 * paramNum_;
|
||||
needFd = vec1ParamGm.GetValue(wsInfoOffset);
|
||||
isS2End = vec1ParamGm.GetValue(wsInfoOffset + 4);
|
||||
}
|
||||
|
||||
// mrg不足4个list的数据
|
||||
if (acc_list_num != 1) {
|
||||
AscendC::MrgSort4Info params;
|
||||
params.elementLengths[0] = BASE_TOPK;
|
||||
params.elementLengths[1] = BASE_TOPK;
|
||||
params.elementLengths[2] = BASE_TOPK;
|
||||
params.elementLengths[3] = BASE_TOPK;
|
||||
params.ifExhaustedSuspension = true;
|
||||
if (acc_list_num == 2) {
|
||||
params.validBit = 0b0011;
|
||||
} else if (acc_list_num == 3) {
|
||||
params.validBit = 0b0111;
|
||||
}
|
||||
params.repeatTimes = 1;
|
||||
|
||||
AscendC::MrgSortSrcList<float> srcList;
|
||||
srcList.src1 = curValueIdxUb[0];
|
||||
srcList.src2 = curValueIdxUb[BASE_TOPK_VALUE_IDX_SIZE];
|
||||
srcList.src3 = curValueIdxUb[2 * BASE_TOPK_VALUE_IDX_SIZE];
|
||||
srcList.src4 = curValueIdxUb[3 * BASE_TOPK_VALUE_IDX_SIZE];
|
||||
SetWaitFlag<HardEvent::MTE2_V>(HardEvent::MTE2_V);
|
||||
MrgSort(tmpUb, srcList, params);
|
||||
PipeBarrier<PIPE_V>();
|
||||
DataCopy(curValueIdxUb, tmpUb, BASE_TOPK_VALUE_IDX_SIZE);
|
||||
PipeBarrier<PIPE_V>();
|
||||
}
|
||||
|
||||
// 搬出
|
||||
LocalTensor<float> outValueUb = ldOutValueBuf_.Get<float>();
|
||||
LocalTensor<uint32_t> outIdxUb = ldOutIdxBuf_.Get<uint32_t>();
|
||||
Extract(outValueUb, outIdxUb, curValueIdxUb, (BASE_TOPK / 32));
|
||||
LocalTensor<int32_t> idxULocal1 = outIdxUb.template ReinterpretCast<int32_t>();
|
||||
SetWaitFlag<HardEvent::V_MTE3>(HardEvent::V_MTE3);
|
||||
SetWaitFlag<HardEvent::S_MTE3>(HardEvent::S_MTE3);
|
||||
DataCopyPad(indiceOutGm[outOffset], idxULocal1,
|
||||
{1, static_cast<uint16_t>(constInfo_.sparseCount * sizeof(int32_t)), 0, 0});
|
||||
SetWaitFlag<HardEvent::MTE3_V>(HardEvent::MTE3_V);
|
||||
}
|
||||
}
|
||||
} // namespace LIQKernel
|
||||
#endif
|
||||
@@ -0,0 +1,53 @@
|
||||
/**
|
||||
* This program is free software, you can redistribute it and/or modify it.
|
||||
* Copyright (c) 2025 Huawei Technologies Co., Ltd.
|
||||
* This file is a part of the CANN Open Software.
|
||||
* Licensed under CANN Open Software License Agreement Version 2.0 (the "License").
|
||||
* Please refer to the License for details. You may not use this file except in compliance with the License.
|
||||
* THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE.
|
||||
* See LICENSE in the root of the software repository for the full text of the License.
|
||||
*/
|
||||
|
||||
/*!
|
||||
* \file lightning_indexer_quant_template_tiling_key.h
|
||||
* \brief
|
||||
*/
|
||||
|
||||
#ifndef TEMPLATE_TILING_KEY_LI_H_
|
||||
#define TEMPLATE_TILING_KEY_LI_H_
|
||||
|
||||
#include "ascendc/host_api/tiling/template_argument.h"
|
||||
|
||||
#define LI_TPL_FP16 1
|
||||
#define LI_TPL_IN8 2
|
||||
#define LI_TPL_INT32 3
|
||||
#define LI_TPL_BF16 27
|
||||
|
||||
#define LIQ_LAYOUT_BSND 0
|
||||
#define LIQ_LAYOUT_TND 1
|
||||
#define LIQ_LAYOUT_PA_BSND 2
|
||||
|
||||
#define ASCENDC_TPL_4_BW 4
|
||||
|
||||
// 模板参数支持的范围定义
|
||||
ASCENDC_TPL_ARGS_DECL(LightningIndexerQuant, // 算子OpType
|
||||
ASCENDC_TPL_DTYPE_DECL(DT_Q, LI_TPL_IN8), ASCENDC_TPL_DTYPE_DECL(DT_K, LI_TPL_IN8),
|
||||
ASCENDC_TPL_DTYPE_DECL(DT_OUT, LI_TPL_INT32), ASCENDC_TPL_BOOL_DECL(PAGE_ATTENTION, 1, 0),
|
||||
ASCENDC_TPL_UINT_DECL(Q_LAYOUT_T, ASCENDC_TPL_4_BW, ASCENDC_TPL_UI_LIST, LIQ_LAYOUT_BSND,
|
||||
LIQ_LAYOUT_TND),
|
||||
ASCENDC_TPL_UINT_DECL(K_LAYOUT_T, ASCENDC_TPL_4_BW, ASCENDC_TPL_UI_LIST,
|
||||
LIQ_LAYOUT_PA_BSND, LIQ_LAYOUT_BSND, LIQ_LAYOUT_TND), );
|
||||
|
||||
// 支持的模板参数组合
|
||||
// 用于调用GET_TPL_TILING_KEY获取TilingKey时,接口内部校验TilingKey是否合法
|
||||
ASCENDC_TPL_SEL(
|
||||
ASCENDC_TPL_ARGS_SEL(ASCENDC_TPL_DTYPE_SEL(DT_Q, LI_TPL_IN8), ASCENDC_TPL_DTYPE_SEL(DT_K, LI_TPL_IN8),
|
||||
ASCENDC_TPL_DTYPE_SEL(DT_OUT, LI_TPL_INT32), ASCENDC_TPL_BOOL_SEL(PAGE_ATTENTION, 1),
|
||||
ASCENDC_TPL_UINT_SEL(Q_LAYOUT_T, ASCENDC_TPL_UI_LIST, LIQ_LAYOUT_BSND, LIQ_LAYOUT_TND),
|
||||
ASCENDC_TPL_UINT_SEL(K_LAYOUT_T, ASCENDC_TPL_UI_LIST, LIQ_LAYOUT_PA_BSND), ),
|
||||
ASCENDC_TPL_ARGS_SEL(ASCENDC_TPL_DTYPE_SEL(DT_Q, LI_TPL_IN8), ASCENDC_TPL_DTYPE_SEL(DT_K, LI_TPL_IN8),
|
||||
ASCENDC_TPL_DTYPE_SEL(DT_OUT, LI_TPL_INT32), ASCENDC_TPL_BOOL_SEL(PAGE_ATTENTION, 0),
|
||||
ASCENDC_TPL_UINT_SEL(Q_LAYOUT_T, ASCENDC_TPL_UI_LIST, LIQ_LAYOUT_BSND, LIQ_LAYOUT_TND),
|
||||
ASCENDC_TPL_UINT_SEL(K_LAYOUT_T, ASCENDC_TPL_UI_LIST, LIQ_LAYOUT_BSND, LIQ_LAYOUT_TND), ), );
|
||||
|
||||
#endif
|
||||
@@ -0,0 +1,193 @@
|
||||
/**
|
||||
* This program is free software, you can redistribute it and/or modify it.
|
||||
* Copyright (c) 2025 Huawei Technologies Co., Ltd.
|
||||
* This file is a part of the CANN Open Software.
|
||||
* Licensed under CANN Open Software License Agreement Version 2.0 (the "License").
|
||||
* Please refer to the License for details. You may not use this file except in compliance with the License.
|
||||
* THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE.
|
||||
* See LICENSE in the root of the software repository for the full text of the License.
|
||||
*/
|
||||
|
||||
/*!
|
||||
* \file lightning_indexer_quant_vector.h
|
||||
* \brief
|
||||
*/
|
||||
#ifndef LIGHTNING_INDEXER_QUANT_VECTOR_H
|
||||
#define LIGHTNING_INDEXER_QUANT_VECTOR_H
|
||||
|
||||
#include "kernel_operator.h"
|
||||
#include "lightning_indexer_quant_vector.h"
|
||||
|
||||
namespace LIQServiceVec {
|
||||
using namespace AscendC;
|
||||
|
||||
constexpr int32_t NEG_INF = 0xFF800000;
|
||||
constexpr int32_t INVALID_INDEX = -1;
|
||||
constexpr uint8_t VEC_REPEAT_MAX = 255;
|
||||
constexpr uint8_t B32_VEC_ELM_NUM = 64;
|
||||
constexpr uint8_t B32_BLOCK_ALIGN_NUM = 8;
|
||||
constexpr uint8_t B32_VEC_REPEAT_STRIDE = 8;
|
||||
constexpr uint64_t VEC_REPEAT_BYTES = 256;
|
||||
constexpr int32_t CONST_TWO = 2;
|
||||
constexpr int64_t VALUE_AND_INDEX_NUM = 2;
|
||||
constexpr int64_t BLOCK_BYTES = 32;
|
||||
constexpr int64_t MRG_QUE_0 = 0;
|
||||
constexpr int64_t MRG_QUE_1 = 1;
|
||||
constexpr int64_t MRG_QUE_2 = 2;
|
||||
constexpr int64_t MRG_QUE_3 = 3;
|
||||
constexpr int64_t MRG_BLOCK_2 = 2;
|
||||
constexpr int64_t MRG_BLOCK_3 = 3;
|
||||
constexpr int64_t MRG_BLOCK_4 = 4;
|
||||
|
||||
template <typename T>
|
||||
__aicore__ inline void CopyOut(const GlobalTensor<T> &dstGm, const LocalTensor<T> &srcUb, int64_t copyCount)
|
||||
{
|
||||
AscendC::DataCopyParams dataCopyOutyParams;
|
||||
dataCopyOutyParams.blockCount = 1;
|
||||
dataCopyOutyParams.blockLen = copyCount * sizeof(T);
|
||||
dataCopyOutyParams.srcStride = 0;
|
||||
dataCopyOutyParams.dstStride = 0;
|
||||
AscendC::DataCopyPad(dstGm, srcUb, dataCopyOutyParams);
|
||||
}
|
||||
|
||||
/**
|
||||
src: 传入的初始化空间
|
||||
eleNum: 需要初始化的元素个数需为64整数倍,元素将被初始化为交错排布的-inf,-1
|
||||
*/
|
||||
__aicore__ inline void InitSortOutBuf(const LocalTensor<float> &src, int64_t eleNum)
|
||||
{
|
||||
uint64_t mask1[2] = {0x5555555555555555, 0};
|
||||
uint64_t mask0[2] = {0xaaaaaaaaaaaaaaaa, 0};
|
||||
int64_t repeatNum = eleNum / B32_VEC_ELM_NUM;
|
||||
int64_t forLoop = repeatNum / VEC_REPEAT_MAX;
|
||||
int64_t forRemain = repeatNum % VEC_REPEAT_MAX;
|
||||
for (int i = 0; i < forLoop; i++) {
|
||||
AscendC::Duplicate(src.template ReinterpretCast<int32_t>(), NEG_INF, mask1, VEC_REPEAT_MAX, 1,
|
||||
B32_VEC_REPEAT_STRIDE);
|
||||
AscendC::Duplicate(src.template ReinterpretCast<int32_t>(), INVALID_INDEX, mask0, VEC_REPEAT_MAX, 1,
|
||||
B32_VEC_REPEAT_STRIDE);
|
||||
}
|
||||
if (forRemain > 0) {
|
||||
AscendC::Duplicate(src.template ReinterpretCast<int32_t>()[forLoop * VEC_REPEAT_MAX * B32_VEC_ELM_NUM], NEG_INF,
|
||||
mask1, forRemain, 1, B32_VEC_REPEAT_STRIDE);
|
||||
AscendC::Duplicate(src.template ReinterpretCast<int32_t>()[forLoop * VEC_REPEAT_MAX * B32_VEC_ELM_NUM],
|
||||
INVALID_INDEX, mask0, forRemain, 1, B32_VEC_REPEAT_STRIDE);
|
||||
}
|
||||
AscendC::PipeBarrier<PIPE_V>();
|
||||
}
|
||||
|
||||
/**
|
||||
src: logits和索引,前logitsNum为logits,后logitsNum为索引
|
||||
tmp: 计算使用到的临时空间,大小与src一致
|
||||
logitsNum: 排序的元素个数, 暂只支持[128,256,384,512,1024,2048]
|
||||
*/
|
||||
__aicore__ inline void SortAll(LocalTensor<float> &src, LocalTensor<float> &tmp, int64_t logitsNum)
|
||||
{
|
||||
int64_t sort32Repeats = logitsNum / BLOCK_BYTES;
|
||||
AscendC::Sort32(tmp, src, src[logitsNum].ReinterpretCast<uint32_t>(), sort32Repeats);
|
||||
AscendC::PipeBarrier<PIPE_V>();
|
||||
|
||||
int64_t mrgGroups = sort32Repeats;
|
||||
int64_t mrgElements = BLOCK_BYTES;
|
||||
int64_t i = 0;
|
||||
AscendC::LocalTensor<float> srcTensor;
|
||||
AscendC::LocalTensor<float> dstTensor;
|
||||
while (true) {
|
||||
if (i % CONST_TWO == 0) {
|
||||
srcTensor = tmp;
|
||||
dstTensor = src;
|
||||
} else {
|
||||
srcTensor = src;
|
||||
dstTensor = tmp;
|
||||
}
|
||||
AscendC::MrgSort4Info params;
|
||||
params.elementLengths[0] = mrgElements;
|
||||
params.elementLengths[MRG_QUE_1] = mrgElements;
|
||||
params.elementLengths[MRG_QUE_2] = mrgElements;
|
||||
params.elementLengths[MRG_QUE_3] = mrgElements;
|
||||
params.ifExhaustedSuspension = false;
|
||||
params.validBit = 0b1111;
|
||||
|
||||
AscendC::MrgSortSrcList<float> srcList;
|
||||
srcList.src1 = srcTensor[0];
|
||||
srcList.src2 = srcTensor[MRG_QUE_1 * VALUE_AND_INDEX_NUM * mrgElements];
|
||||
srcList.src3 = srcTensor[MRG_QUE_2 * VALUE_AND_INDEX_NUM * mrgElements];
|
||||
srcList.src4 = srcTensor[MRG_QUE_3 * VALUE_AND_INDEX_NUM * mrgElements];
|
||||
if (mrgGroups <= MRG_BLOCK_4) {
|
||||
params.repeatTimes = 1;
|
||||
if (mrgGroups == 1) {
|
||||
break;
|
||||
} else if (mrgGroups == MRG_BLOCK_2) {
|
||||
params.validBit = 0b0011;
|
||||
} else if (mrgGroups == MRG_BLOCK_3) {
|
||||
params.validBit = 0b0111;
|
||||
} else if (mrgGroups == MRG_BLOCK_4) {
|
||||
params.validBit = 0b1111;
|
||||
}
|
||||
AscendC::MrgSort<float>(dstTensor, srcList, params);
|
||||
i += 1;
|
||||
break;
|
||||
} else {
|
||||
params.repeatTimes = mrgGroups / MRG_BLOCK_4;
|
||||
AscendC::MrgSort<float>(dstTensor, srcList, params);
|
||||
i += 1;
|
||||
mrgElements = mrgElements * MRG_BLOCK_4;
|
||||
mrgGroups = mrgGroups / MRG_BLOCK_4;
|
||||
}
|
||||
AscendC::PipeBarrier<PIPE_V>();
|
||||
}
|
||||
if (i % CONST_TWO == 0) {
|
||||
AscendC::DataCopy(src, tmp, logitsNum * VALUE_AND_INDEX_NUM);
|
||||
AscendC::PipeBarrier<PIPE_V>();
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
mrgDst: 合并进的Tensor
|
||||
mrgSrc: 待合并的Tensor
|
||||
tmpTensor:空间为mrgDst+mrgSrc
|
||||
*/
|
||||
__aicore__ inline void MergeSort(const LocalTensor<float> &mrgDst, int32_t mrgDstNum, LocalTensor<float> &mrgSrc,
|
||||
int32_t mrgSrcNum, LocalTensor<float> &tmpTensor)
|
||||
{
|
||||
AscendC::MrgSort4Info params;
|
||||
params.elementLengths[0] = mrgSrcNum;
|
||||
params.elementLengths[1] = mrgDstNum;
|
||||
params.ifExhaustedSuspension = false;
|
||||
params.validBit = 0b0011;
|
||||
params.repeatTimes = 1;
|
||||
|
||||
AscendC::MrgSortSrcList<float> srcList;
|
||||
srcList.src1 = mrgSrc;
|
||||
srcList.src2 = mrgDst;
|
||||
|
||||
AscendC::MrgSort<float>(tmpTensor, srcList, params);
|
||||
AscendC::PipeBarrier<PIPE_V>();
|
||||
AscendC::DataCopy(mrgDst, tmpTensor, mrgDstNum * VALUE_AND_INDEX_NUM);
|
||||
AscendC::PipeBarrier<PIPE_V>();
|
||||
}
|
||||
|
||||
__aicore__ inline void ExtractIndex(const LocalTensor<uint32_t> &idxULocal, const LocalTensor<uint32_t> &sortLocal,
|
||||
int64_t extractNum)
|
||||
{
|
||||
AscendC::GatherMaskParams gatherMaskParams;
|
||||
gatherMaskParams.repeatTimes = Ceil(extractNum * sizeof(float) * VALUE_AND_INDEX_NUM, VEC_REPEAT_BYTES);
|
||||
gatherMaskParams.src0BlockStride = 1;
|
||||
gatherMaskParams.src0RepeatStride = B32_VEC_REPEAT_STRIDE;
|
||||
gatherMaskParams.src1RepeatStride = 0;
|
||||
uint64_t rsvdCnt = 0; // 用于保存筛选后保留下来的元素个数
|
||||
uint8_t src1Pattern = 2; // 固定模式2,表示筛选出奇数索引的数
|
||||
AscendC::GatherMask(idxULocal, sortLocal, src1Pattern, false, static_cast<uint32_t>(0), gatherMaskParams, rsvdCnt);
|
||||
AscendC::PipeBarrier<PIPE_V>();
|
||||
}
|
||||
|
||||
template <HardEvent event>
|
||||
__aicore__ inline void SetWaitFlag(HardEvent evt)
|
||||
{
|
||||
event_t eventId = static_cast<event_t>(GetTPipePtr()->FetchEventID(evt));
|
||||
AscendC::SetFlag<event>(eventId);
|
||||
AscendC::WaitFlag<event>(eventId);
|
||||
}
|
||||
|
||||
} // namespace LIQServiceVec
|
||||
#endif // LIGHTNING_INDEXER_QUANT_VECTOR_H
|
||||
Reference in New Issue
Block a user