[kernel] add AscendC op: lightning_indexer and sparse_flash_attention (#4625)

### What this PR does / why we need it?
Provide high-performance AscendC operators lightning_indexer and
sparse_flash_attention to boost the execution performance of the
DeepSeek v3.2 model. Meanwhile, adapt the two AscendC operators to
vllm-ascend framework.

### Does this PR introduce _any_ user-facing change?
No (only underlying operator optimizations, with no user-facing changes)

### How was this patch tested?

- vLLM version: v0.11.2
- vLLM main: https://github.com/vllm-project/vllm/commit/v0.11.2

Signed-off-by: MingYang119 <songmingyang@huawei.com>
This commit is contained in:
Song Mingyang
2025-12-03 09:53:10 +08:00
committed by GitHub
parent 7f2673ea2d
commit 18b90b501d
28 changed files with 9772 additions and 19 deletions

View File

@@ -0,0 +1,53 @@
/**
* This program is free software, you can redistribute it and/or modify it.
* Copyright (c) 2025 Huawei Technologies Co., Ltd.
* This file is a part of the CANN Open Software.
* Licensed under CANN Open Software License Agreement Version 2.0 (the "License").
* Please refer to the License for details. You may not use this file except in compliance with the License.
* THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE.
* See LICENSE in the root of the software repository for the full text of the License.
*/
/*!
* \file sparse_flash_attention.cpp
* \brief
*/
#include "kernel_operator.h"
#include "sparse_flash_attention_template_tiling_key.h"
#include "sparse_flash_attention_kernel_mla.h"
using namespace AscendC;
#define SFA_OP_IMPL(templateClass, tilingdataClass, ...) \
do { \
templateClass<SFAType<__VA_ARGS__>> op; \
GET_TILING_DATA_WITH_STRUCT(tilingdataClass, tiling_data_in, tiling); \
const tilingdataClass *__restrict tiling_data = &tiling_data_in; \
op.Init(query, key, value, sparseIndices, actualSeqLengthsQuery, actualSeqLengthsKV, \
blocktable, queryRope, keyRope, attentionOut, user, tiling_data, tiling, &tPipe); \
op.Process(); \
} while (0)
template<int FLASH_DECODE, int LAYOUT_T, int KV_LAYOUT_T, int TEMPLATE_MODE>
__global__ __aicore__ void
sparse_flash_attention(__gm__ uint8_t *query, __gm__ uint8_t *key, __gm__ uint8_t *value,
__gm__ uint8_t *sparseIndices, __gm__ uint8_t *blocktable,
__gm__ uint8_t *actualSeqLengthsQuery, __gm__ uint8_t *actualSeqLengthsKV,
__gm__ uint8_t* queryRope, __gm__ uint8_t* keyRope,
__gm__ uint8_t *attentionOut, __gm__ uint8_t *workspace, __gm__ uint8_t *tiling)
{
KERNEL_TASK_TYPE_DEFAULT(KERNEL_TYPE_MIX_AIC_1_2);
TPipe tPipe;
__gm__ uint8_t *user = GetUserWorkspace(workspace);
if constexpr (ORIG_DTYPE_QUERY == DT_FLOAT16 && ORIG_DTYPE_KEY == DT_FLOAT16 &&
ORIG_DTYPE_ATTENTION_OUT == DT_FLOAT16) {
SFA_OP_IMPL(SparseFlashAttentionMla, SparseFlashAttentionTilingDataMla, half, half, half,
FLASH_DECODE, static_cast<SFA_LAYOUT>(LAYOUT_T), static_cast<SFA_LAYOUT>(KV_LAYOUT_T), TEMPLATE_MODE);
} else { // bf16
SFA_OP_IMPL(SparseFlashAttentionMla, SparseFlashAttentionTilingDataMla, bfloat16_t, bfloat16_t, bfloat16_t,
FLASH_DECODE, static_cast<SFA_LAYOUT>(LAYOUT_T), static_cast<SFA_LAYOUT>(KV_LAYOUT_T), TEMPLATE_MODE);
}
}

View File

@@ -0,0 +1,192 @@
/**
* This program is free software, you can redistribute it and/or modify it.
* Copyright (c) 2025 Huawei Technologies Co., Ltd.
* This file is a part of the CANN Open Software.
* Licensed under CANN Open Software License Agreement Version 2.0 (the "License").
* Please refer to the License for details. You may not use this file except in compliance with the License.
* THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE.
* See LICENSE in the root of the software repository for the full text of the License.
*/
/*!
* \file sparse_flash_attention_common.h
* \brief
*/
#ifndef SPARSE_FLASH_ATTENTION_COMMON_H
#define SPARSE_FLASH_ATTENTION_COMMON_H
#include "kernel_operator.h"
#include "lib/matmul_intf.h"
#include "lib/matrix/matmul/tiling.h"
using namespace AscendC;
constexpr SoftmaxConfig SFA_SOFTMAX_FLASHV2_CFG_WITHOUT_BRC = {false, 0, 0, SoftmaxMode::SOFTMAX_OUTPUT_WITHOUT_BRC};
enum class SFA_LAYOUT
{
BSND = 0,
TND = 1,
PA_BSND = 2,
};
template <typename Q_T, typename KV_T, typename OUT_T, const bool FLASH_DECODE = false,
SFA_LAYOUT LAYOUT_T = SFA_LAYOUT::BSND, SFA_LAYOUT KV_LAYOUT_T = SFA_LAYOUT::BSND,
const int TEMPLATE_MODE = C_TEMPLATE, typename... Args>
struct SFAType {
using queryType = Q_T;
using kvType = KV_T;
using outputType = OUT_T;
static constexpr bool flashDecode = FLASH_DECODE;
static constexpr SFA_LAYOUT layout = LAYOUT_T;
static constexpr SFA_LAYOUT kvLayout = KV_LAYOUT_T;
static constexpr int templateMode = TEMPLATE_MODE;
static constexpr bool pageAttention = (KV_LAYOUT_T == SFA_LAYOUT::PA_BSND);
};
// ================================Util functions==================================
template <typename T> __aicore__ inline T SFAAlign(T num, T rnd)
{
return (((rnd) == 0) ? 0 : (((num) + (rnd) - 1) / (rnd) * (rnd)));
}
template <typename T1, typename T2> __aicore__ inline T1 Min(T1 a, T2 b)
{
return (a > b) ? (b) : (a);
}
template <typename T> __aicore__ inline size_t BlockAlign(size_t s)
{
if constexpr (IsSameType<T, int4b_t>::value) {
return (s + 63) / 64 * 64;
}
size_t n = (32 / sizeof(T));
return (s + n - 1) / n * n;
}
struct RunInfo {
uint32_t loop;
uint32_t bIdx;
uint32_t gIdx;
uint32_t s1Idx;
uint32_t s2Idx;
uint32_t bn2IdxInCurCore;
uint32_t curSInnerLoopTimes;
uint64_t tndBIdxOffsetForQ;
uint64_t tndBIdxOffsetForKV;
uint64_t tensorAOffset;
uint64_t tensorBOffset;
uint64_t tensorARopeOffset;
uint64_t tensorBRopeOffset;
uint64_t attenOutOffset;
uint64_t attenMaskOffset;
uint64_t topKBaseOffset;
uint32_t actualSingleProcessSInnerSize;
uint32_t actualSingleProcessSInnerSizeAlign;
bool isFirstSInnerLoop;
bool isChangeBatch;
uint32_t s2BatchOffset;
uint32_t gSize;
uint32_t s1Size;
uint32_t s2Size;
uint32_t mSize;
uint32_t mSizeV;
uint32_t mSizeVStart;
uint32_t tndIsS2SplitCore;
uint32_t tndCoreStartKVSplitPos;
bool isBmm2Output;
bool isValid = false;
static constexpr uint32_t n2Idx = 0;
uint64_t actS1Size = 1;
uint64_t curActualSeqLenOri = 0ULL;
uint32_t gS1Idx;
uint64_t actS2Size = 1;
uint32_t actMBaseSize;
bool isLastS2Loop;
int32_t nextTokensPerBatch = 0;
int64_t threshold;
uint32_t curTopKIdx = 0;
uint64_t curOffsetInSparseBlock = 0;
};
struct ConstInfo {
static constexpr uint32_t SFA_SYNC_MODE2 = 2;
static constexpr uint32_t BUFFER_SIZE_BYTE_32B = 32;
static constexpr uint32_t BUFFER_SIZE_BYTE_64B = 64;
static constexpr uint32_t BUFFER_SIZE_BYTE_256B = 256;
static constexpr uint32_t BUFFER_SIZE_BYTE_512B = 512;
static constexpr uint32_t BUFFER_SIZE_BYTE_1K = 1024;
static constexpr uint32_t BUFFER_SIZE_BYTE_2K = 2048;
static constexpr uint32_t BUFFER_SIZE_BYTE_4K = 4096;
static constexpr uint32_t BUFFER_SIZE_BYTE_8K = 8192;
static constexpr uint32_t BUFFER_SIZE_BYTE_16K = 16384;
static constexpr uint32_t BUFFER_SIZE_BYTE_32K = 32768;
static constexpr float FLOAT_ZERO = 0;
static constexpr float FLOAT_MAX = 3.402823466e+38F;
uint32_t preLoadNum = 0U;
uint32_t nBufferMBaseSize = 0U;
uint32_t syncV1NupdateC2 = 0U;
uint32_t syncV0C1 = 0U;
uint32_t syncC1V1 = 0U;
uint32_t syncV1C2 = 0U;
uint32_t syncC2V2 = 0U;
uint32_t syncC2V1 = 0U;
uint32_t mmResUbSize = 0U;
uint32_t vec1ResUbSize = 0U;
uint32_t bmm2ResUbSize = 0U;
uint64_t batchSize = 0ULL;
uint64_t gSize = 0ULL;
uint64_t qHeadNum = 0ULL;
uint64_t kvHeadNum;
uint64_t headDim;
uint64_t headDimRope;
uint64_t kvSeqSize = 0ULL;
uint64_t qSeqSize = 1ULL;
int64_t kvCacheBlockSize = 0;
uint32_t maxBlockNumPerBatch = 0;
uint32_t splitKVNum = 0U;
SFA_LAYOUT outputLayout;
uint32_t sparseMode = 0;
bool needInit = false;
// FlashDecoding
uint32_t actualCombineLoopSize = 0U;
uint64_t combineLseOffset = 0ULL;
uint64_t combineAccumOutOffset = 0ULL;
uint32_t actualLenDimsQ = 0U;
uint32_t actualLenDimsKV = 0U;
// TND
uint32_t s2Start = 0U;
uint32_t s2End = 0U;
uint32_t bN2Start = 0U;
uint32_t bN2End = 0U;
uint32_t gS1Start = 0U;
uint32_t gS1End = 0U;
uint32_t tndFDCoreArrLen = 0U;
uint32_t coreStartKVSplitPos = 0U;
uint32_t mBaseSize = 1ULL;
uint32_t s2BaseSize = 1ULL;
// sparse attr
int64_t sparseBlockSize = 0;
uint32_t sparseBlockCount = 0;
};
struct MSplitInfo {
uint32_t nBufferIdx = 0U;
uint32_t nBufferStartM = 0U;
uint32_t nBufferDealM = 0U;
uint32_t vecStartM = 0U;
uint32_t vecDealM = 0U;
};
#endif // SPARSE_FLASH_ATTENTION_COMMON_H

View File

@@ -0,0 +1,969 @@
/**
* This program is free software, you can redistribute it and/or modify it.
* Copyright (c) 2025 Huawei Technologies Co., Ltd.
* This file is a part of the CANN Open Software.
* Licensed under CANN Open Software License Agreement Version 2.0 (the "License").
* Please refer to the License for details. You may not use this file except in compliance with the License.
* THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE.
* See LICENSE in the root of the software repository for the full text of the License.
*/
/*!
* \file sparse_flash_attention_kernel_mla.h
* \brief
*/
#ifndef SPARSE_FLASH_ATTENTION_KERNEL_MLA_H
#define SPARSE_FLASH_ATTENTION_KERNEL_MLA_H
#include "kernel_operator.h"
#include "kernel_operator_list_tensor_intf.h"
#include "kernel_tiling/kernel_tiling.h"
#include "lib/matmul_intf.h"
#include "lib/matrix/matmul/tiling.h"
#include "sparse_flash_attention_common.h"
#include "sparse_flash_attention_service_cube_mla.h"
#include "sparse_flash_attention_service_vector_mla.h"
using namespace matmul;
using AscendC::CacheMode;
using AscendC::CrossCoreSetFlag;
using AscendC::CrossCoreWaitFlag;
struct TempLoopInfo {
uint32_t bn2IdxInCurCore = 0;
uint32_t bIdx = 0U;
uint32_t n2Idx = 0U;
uint64_t s2BasicSizeTail = 0U;
uint32_t s2LoopTimes = 0U;
uint64_t curActualSeqLen = 0ULL;
uint64_t curActualSeqLenOri = 0ULL;
bool curActSeqLenIsZero = false;
int32_t nextTokensPerBatch = 0;
uint64_t actS1Size = 1ULL;
uint32_t tndCoreStartKVSplitPos;
bool tndIsS2SplitCore;
uint32_t gS1Idx = 0U;
uint64_t mBasicSizeTail = 0U;
};
template <typename SFAT> class SparseFlashAttentionMla {
public:
using T = float;
using Q_T = typename SFAT::queryType;
using KV_T = typename SFAT::kvType;
using OUT_T = typename SFAT::outputType;
using Q_ROPE_T = Q_T;
using K_ROPE_T = KV_T;
using UPDATE_T = T;
using MM1_OUT_T = T;
using MM2_OUT_T = T;
__aicore__ inline SparseFlashAttentionMla(){};
__aicore__ inline void Init(__gm__ uint8_t *query, __gm__ uint8_t *key, __gm__ uint8_t *value,
__gm__ uint8_t *sparseIndices, __gm__ uint8_t *actualSeqLengthsQ,
__gm__ uint8_t *actualSeqLengths, __gm__ uint8_t *blockTable,
__gm__ uint8_t *queryRope, __gm__ uint8_t *keyRope,
__gm__ uint8_t *attentionOut, __gm__ uint8_t *workspace,
const SparseFlashAttentionTilingDataMla *__restrict tiling,
__gm__ uint8_t *gmTiling, TPipe *tPipe);
__aicore__ inline void Process();
private:
static constexpr bool PAGE_ATTENTION = SFAT::pageAttention;
static constexpr int TEMPLATE_MODE = SFAT::templateMode;
static constexpr bool FLASH_DECODE = SFAT::flashDecode;
static constexpr SFA_LAYOUT LAYOUT_T = SFAT::layout;
static constexpr SFA_LAYOUT KV_LAYOUT_T = SFAT::kvLayout;
static constexpr uint32_t PRELOAD_NUM = 2;
static constexpr uint32_t N_BUFFER_M_BASIC_SIZE = 256;
static constexpr uint32_t SFA_PRELOAD_TASK_CACHE_SIZE = 3;
static constexpr uint32_t SYNC_V0_C1_FLAG = 6;
static constexpr uint32_t SYNC_C1_V1_FLAG = 7;
static constexpr uint32_t SYNC_V1_C2_FLAG = 8;
static constexpr uint32_t SYNC_C2_V2_FLAG = 9;
static constexpr uint32_t SYNC_C2_V1_FLAG = 4;
static constexpr uint32_t SYNC_V1_NUPDATE_C2_FLAG = 5;
static constexpr uint64_t SYNC_MM2RES_BUF1_FLAG = 10;
static constexpr uint64_t SYNC_MM2RES_BUF2_FLAG = 11;
static constexpr uint64_t SYNC_FDOUTPUT_BUF_FLAG = 12;
static constexpr uint32_t BLOCK_ELEMENT_NUM = SFAVectorService<SFAT>::BYTE_BLOCK / sizeof(T);
static constexpr uint64_t kvHeadNum = 1ULL;
static constexpr uint64_t headDim = 512ULL;
static constexpr uint64_t headDimAlign = 512ULL;
static constexpr uint64_t headDimRope = 64ULL;
static constexpr uint32_t msdIterNum = 2U;
static constexpr uint32_t dbWorkspaceRatio = PRELOAD_NUM;
const SparseFlashAttentionTilingDataMla *__restrict tilingData = nullptr;
TPipe *pipe = nullptr;
uint64_t mSizeVStart = 0ULL;
int64_t threshold = 0;
uint64_t topKBaseOffset = 0ULL;
uint64_t s2BatchBaseOffset = 0;
uint64_t tensorACoreOffset = 0ULL;
uint64_t tensorBCoreOffset = 0ULL;
uint64_t tensorARopeCoreOffset = 0ULL;
uint64_t tensorBRopeCoreOffset = 0ULL;
uint64_t tensorBOffset = 0ULL;
uint64_t attenOutOffset = 0ULL;
uint32_t tmpBlockIdx = 0U;
uint32_t aiCoreIdx = 0U;
uint32_t usedCoreNum = 0U;
__gm__ uint8_t *keyPtr = nullptr;
__gm__ uint8_t *valuePtr = nullptr;
ConstInfo constInfo{};
TempLoopInfo tempLoopInfo{};
SFAMatmulService<SFAT> matmulService;
SFAVectorService<SFAT> vectorService;
GlobalTensor<Q_T> queryGm;
GlobalTensor<KV_T> keyGm;
GlobalTensor<KV_T> valueGm;
GlobalTensor<Q_ROPE_T> qRopeGm;
GlobalTensor<K_ROPE_T> kRopeGm;
GlobalTensor<OUT_T> attentionOutGm;
GlobalTensor<int32_t> blockTableGm;
GlobalTensor<int32_t> topKGm;
GlobalTensor<int32_t> actualSeqLengthsQGm;
GlobalTensor<int32_t> actualSeqLengthsKVGm;
// workspace
GlobalTensor<MM1_OUT_T> mm1ResGm;
GlobalTensor<KV_T> vec1ResGm;
GlobalTensor<MM2_OUT_T> mm2ResGm;
GlobalTensor<KV_T> kvMergeGm_;
GlobalTensor<int32_t> kvValidSizeGm_;
GlobalTensor<int32_t> mm2ResInt32Gm;
GlobalTensor<UPDATE_T> vec2ResGm;
GlobalTensor<T> accumOutGm;
GlobalTensor<T> lseSumFdGm;
GlobalTensor<T> lseMaxFdGm;
// ================================Init functions===================================
__aicore__ inline void InitTilingData();
__aicore__ inline void InitCalcParamsEach();
__aicore__ inline void InitBuffers();
__aicore__ inline void InitActualSeqLen(__gm__ uint8_t *actualSeqLengthsQ, __gm__ uint8_t *actualSeqLengths);
__aicore__ inline void InitOutputSingleCore();
// ================================Process functions================================
__aicore__ inline void ProcessBalance();
__aicore__ inline void PreloadPipeline(uint32_t loop, uint64_t s2Start, uint64_t s2LoopIdx,
RunInfo extraInfo[SFA_PRELOAD_TASK_CACHE_SIZE], uint32_t &curTopKIdx, uint64_t &curOffsetInSparseBlock);
// ================================Offset Calc=====================================
__aicore__ inline void GetActualSeqLen(uint32_t bIdx, uint32_t s1Idx = 0);
__aicore__ inline void GetSparseActualSeqLen(uint32_t bIdx, uint32_t s1Idx, uint32_t n2Idx);
__aicore__ inline void CalcSinnerTopKBegin(RunInfo &info, uint32_t &curTopKIdx, uint64_t &curOffsetInSparseBlock);
__aicore__ inline void UpdateInnerLoopCond();
__aicore__ inline void DealActSeqLenIsZero(uint32_t bIdx, uint32_t s1Idx, uint32_t n2Idx);
__aicore__ inline void CalcParams(uint32_t loop, uint64_t s2Start, uint32_t s2LoopIdx, RunInfo &info);
__aicore__ inline void GetAxisStartIdx(uint32_t bN2EndPrev, uint32_t gS1EndPrev, uint32_t s2EndPrev);
__aicore__ inline uint64_t GetBalanceActualSeqLengths(GlobalTensor<int32_t> &actualSeqLengths, uint32_t bIdx);
__aicore__ inline uint32_t GetActualSeqLenKV(uint32_t bIdx);
__aicore__ inline void GetBN2Idx(uint32_t bN2Idx, uint32_t &bIdx, uint32_t &n2Idx);
__aicore__ inline void UpdateInner(uint32_t &s2End, uint32_t &curS2End, uint32_t s1Idx, bool isEnd);
__aicore__ inline void GetPreNextTokensLeftUp();
// ================================Mm1==============================================
__aicore__ inline void ComputeMm1(const RunInfo &info);
// ================================Mm2==============================================
__aicore__ inline void ComputeMm2(const RunInfo &info);
__aicore__ inline void Bmm2DataCopyOut(uint64_t attenOutOffset, LocalTensor<OUT_T> &attenOutUb, uint32_t startRow,
uint32_t dealRowCount, uint32_t columnCount, uint32_t actualColumnCount);
__aicore__ inline void InitAllZeroOutput(uint32_t bIdx, uint32_t s1Idx, uint32_t n2Idx);
};
template <typename SFAT> __aicore__ inline void SparseFlashAttentionMla<SFAT>::InitTilingData()
{
usedCoreNum = tilingData->singleCoreParams.usedCoreNum;
constInfo.splitKVNum = tilingData->splitKVParams.s2;
constInfo.mmResUbSize = tilingData->singleCoreTensorSize.mmResUbSize;
constInfo.bmm2ResUbSize = tilingData->singleCoreTensorSize.bmm2ResUbSize;
constInfo.vec1ResUbSize = constInfo.mmResUbSize * msdIterNum;
constInfo.batchSize = tilingData->baseParams.batchSize;
constInfo.qHeadNum = constInfo.gSize = tilingData->baseParams.nNumOfQInOneGroup;
constInfo.kvSeqSize = tilingData->baseParams.seqSize;
constInfo.qSeqSize = tilingData->baseParams.qSeqSize;
constInfo.maxBlockNumPerBatch = tilingData->baseParams.maxBlockNumPerBatch;
constInfo.kvCacheBlockSize = tilingData->baseParams.blockSize;
constInfo.outputLayout = static_cast<SFA_LAYOUT>(tilingData->baseParams.outputLayout);
constInfo.mBaseSize = tilingData->innerSplitParams.mBaseSize;
constInfo.s2BaseSize = tilingData->innerSplitParams.s2BaseSize;
constInfo.kvHeadNum = kvHeadNum;
constInfo.headDim = headDim;
constInfo.headDimRope = headDimRope;
constInfo.sparseBlockSize = tilingData->baseParams.sparseBlockSize;
constInfo.sparseBlockCount = tilingData->baseParams.sparseBlockCount;
constInfo.sparseMode = tilingData->baseParams.sparseMode;
constInfo.preLoadNum = PRELOAD_NUM;
constInfo.nBufferMBaseSize = N_BUFFER_M_BASIC_SIZE;
constInfo.syncV0C1 = SYNC_V0_C1_FLAG;
constInfo.syncC1V1 = SYNC_C1_V1_FLAG;
constInfo.syncV1C2 = SYNC_V1_C2_FLAG;
constInfo.syncC2V2 = SYNC_C2_V2_FLAG;
constInfo.syncC2V1 = SYNC_C2_V1_FLAG;
constInfo.syncV1NupdateC2 = SYNC_V1_NUPDATE_C2_FLAG;
}
template <typename SFAT> __aicore__ inline void SparseFlashAttentionMla<SFAT>::InitBuffers()
{
if ASCEND_IS_AIV {
vectorService.InitBuffers(pipe);
} else {
matmulService.InitBuffers(pipe);
}
}
template <typename SFAT>
__aicore__ inline void
SparseFlashAttentionMla<SFAT>::InitActualSeqLen(__gm__ uint8_t *actualSeqLengthsQ,
__gm__ uint8_t *actualSeqLengths)
{
constInfo.actualLenDimsQ = tilingData->baseParams.actualLenDimsQ;
constInfo.actualLenDimsKV = tilingData->baseParams.actualLenDimsKV;
if (constInfo.actualLenDimsKV != 0) {
actualSeqLengthsKVGm.SetGlobalBuffer((__gm__ int32_t *)actualSeqLengths, constInfo.actualLenDimsKV);
}
if (constInfo.actualLenDimsQ != 0) {
actualSeqLengthsQGm.SetGlobalBuffer((__gm__ int32_t *)actualSeqLengthsQ, constInfo.actualLenDimsQ);
}
}
template <typename SFAT>
__aicore__ inline void SparseFlashAttentionMla<SFAT>::InitAllZeroOutput(uint32_t bIdx, uint32_t s1Idx, uint32_t n2Idx)
{
if (constInfo.outputLayout == SFA_LAYOUT::TND) {
uint32_t tBase = bIdx == 0 ? 0 : actualSeqLengthsQGm.GetValue(bIdx - 1);
uint32_t s1Count = tempLoopInfo.actS1Size;
uint64_t attenOutOffset = (tBase + s1Idx) * kvHeadNum * constInfo.gSize * headDim +
n2Idx * constInfo.gSize * headDim;
matmul::InitOutput<OUT_T>(attentionOutGm[attenOutOffset], constInfo.gSize * headDim, 0);
} else if (constInfo.outputLayout == SFA_LAYOUT::BSND) {
uint64_t attenOutOffset = bIdx * constInfo.qSeqSize * kvHeadNum * constInfo.gSize * headDim +
s1Idx * kvHeadNum * constInfo.gSize * headDim +
n2Idx * constInfo.gSize * headDim;
matmul::InitOutput<OUT_T>(attentionOutGm[attenOutOffset], constInfo.gSize * headDim, 0);
}
}
template <typename SFAT>
__aicore__ inline void SparseFlashAttentionMla<SFAT>::InitOutputSingleCore()
{
uint32_t coreNum = GetBlockNum();
if (coreNum != 0) {
uint64_t totalOutputSize = constInfo.batchSize * constInfo.qHeadNum * constInfo.qSeqSize * constInfo.headDim;
uint64_t singleCoreSize = (totalOutputSize + (2 * coreNum) - 1) / (2 * coreNum); // 2 means c:v = 1:2
uint64_t tailSize = totalOutputSize - tmpBlockIdx * singleCoreSize;
uint64_t singleInitOutputSize = tailSize < singleCoreSize ? tailSize : singleCoreSize;
if (singleInitOutputSize > 0) {
matmul::InitOutput<OUT_T>(attentionOutGm[tmpBlockIdx * singleCoreSize], singleInitOutputSize, 0);
}
SyncAll();
}
}
template <typename SFAT>
__aicore__ inline void SparseFlashAttentionMla<SFAT>::GetActualSeqLen(uint32_t bIdx, uint32_t s1Idx)
{
tempLoopInfo.curActualSeqLenOri = GetActualSeqLenKV(bIdx);
tempLoopInfo.actS1Size = GetBalanceActualSeqLengths(actualSeqLengthsQGm, bIdx);
}
template <typename SFAT>
__aicore__ inline void SparseFlashAttentionMla<SFAT>::GetSparseActualSeqLen(uint32_t bIdx, uint32_t s1Idx,
uint32_t n2Idx)
{
if (tempLoopInfo.nextTokensPerBatch < 0 && s1Idx < (-tempLoopInfo.nextTokensPerBatch)) {
tempLoopInfo.curActualSeqLen = 0;
return;
}
int64_t threshold = tempLoopInfo.curActualSeqLenOri;
if (constInfo.sparseMode == 3) {
threshold = static_cast<int64_t>(tempLoopInfo.nextTokensPerBatch) + s1Idx + 1;
}
tempLoopInfo.curActualSeqLen = (constInfo.sparseBlockCount * constInfo.sparseBlockSize > threshold) ?
threshold :
constInfo.sparseBlockCount * constInfo.sparseBlockSize;
}
template <typename SFAT>
__aicore__ inline uint32_t SparseFlashAttentionMla<SFAT>::GetActualSeqLenKV(uint32_t bIdx)
{
if constexpr (KV_LAYOUT_T == SFA_LAYOUT::TND) {
if (bIdx > 0) {
return actualSeqLengthsKVGm.GetValue(bIdx) - actualSeqLengthsKVGm.GetValue(bIdx - 1);
} else if (bIdx == 0) {
return actualSeqLengthsKVGm.GetValue(0);
} else {
return 0;
}
} else {
if (constInfo.actualLenDimsKV == 0) {
return constInfo.kvSeqSize;
} else if (constInfo.actualLenDimsKV == 1) {
return actualSeqLengthsKVGm.GetValue(0);
} else {
return actualSeqLengthsKVGm.GetValue(bIdx);
}
}
}
template <typename SFAT>
__aicore__ inline void SparseFlashAttentionMla<SFAT>::DealActSeqLenIsZero(uint32_t bIdx, uint32_t s1Idx, uint32_t n2Idx)
{
if ASCEND_IS_AIV {
InitAllZeroOutput(bIdx, s1Idx, n2Idx);
}
}
template <typename SFAT>
__aicore__ inline void SparseFlashAttentionMla<SFAT>::GetPreNextTokensLeftUp()
{
if (constInfo.sparseMode == 3) {
tempLoopInfo.nextTokensPerBatch =
static_cast<int32_t>(tempLoopInfo.curActualSeqLenOri) - static_cast<int32_t>(tempLoopInfo.actS1Size);
}
}
template <typename SFAT> __aicore__ inline void SparseFlashAttentionMla<SFAT>::UpdateInnerLoopCond()
{
if ((tempLoopInfo.curActualSeqLen == 0) || (tempLoopInfo.actS1Size == 0)) {
tempLoopInfo.curActSeqLenIsZero = true;
return;
}
tempLoopInfo.curActSeqLenIsZero = false;
tempLoopInfo.mBasicSizeTail = (tempLoopInfo.actS1Size * constInfo.gSize) % constInfo.mBaseSize;
tempLoopInfo.mBasicSizeTail =
(tempLoopInfo.mBasicSizeTail == 0) ? constInfo.mBaseSize : tempLoopInfo.mBasicSizeTail;
tempLoopInfo.s2LoopTimes = 0;
}
template <typename SFAT>
__aicore__ inline void SparseFlashAttentionMla<SFAT>::UpdateInner(uint32_t &s2End, uint32_t &curS2End,
uint32_t s1Idx, bool isEnd)
{
uint32_t s1BaseSize = 1;
int64_t s1Offset = s1BaseSize * s1Idx;
int64_t s2LastToken = Min(s1Offset + tempLoopInfo.nextTokensPerBatch + s1BaseSize,tempLoopInfo.curActualSeqLenOri);
s2LastToken = Min(constInfo.sparseBlockSize * constInfo.sparseBlockCount, s2LastToken);
curS2End = (s2LastToken + constInfo.s2BaseSize - 1) / constInfo.s2BaseSize;
tempLoopInfo.s2LoopTimes = isEnd ? constInfo.s2End + 1 : curS2End;
}
template <typename SFAT>
__aicore__ inline void SparseFlashAttentionMla<SFAT>::Init(__gm__ uint8_t *query,
__gm__ uint8_t *key, __gm__ uint8_t *value,
__gm__ uint8_t *sparseIndices, __gm__ uint8_t *actualSeqLengthsQ,
__gm__ uint8_t *actualSeqLengths, __gm__ uint8_t *blockTable,
__gm__ uint8_t *queryRope, __gm__ uint8_t *keyRope,
__gm__ uint8_t *attentionOut, __gm__ uint8_t *workspace,
const SparseFlashAttentionTilingDataMla *__restrict tiling,
__gm__ uint8_t *gmTiling, TPipe *tPipe)
{
if ASCEND_IS_AIV {
tmpBlockIdx = GetBlockIdx(); // vec:0-47
aiCoreIdx = tmpBlockIdx / 2;
} else {
tmpBlockIdx = GetBlockIdx(); // cube:0-23
aiCoreIdx = tmpBlockIdx;
}
// init tiling data
tilingData = tiling;
InitTilingData();
InitActualSeqLen(actualSeqLengthsQ, actualSeqLengths);
InitCalcParamsEach();
pipe = tPipe;
keyPtr = key;
valuePtr = value;
// init global buffer
queryGm.SetGlobalBuffer((__gm__ Q_T *)query);
keyGm.SetGlobalBuffer((__gm__ KV_T *)keyPtr);
valueGm.SetGlobalBuffer((__gm__ KV_T *)valuePtr);
qRopeGm.SetGlobalBuffer((__gm__ Q_ROPE_T *)queryRope);
kRopeGm.SetGlobalBuffer((__gm__ K_ROPE_T *)keyRope);
attentionOutGm.SetGlobalBuffer((__gm__ OUT_T *)attentionOut);
if ASCEND_IS_AIV {
if (constInfo.needInit && LAYOUT_T != SFA_LAYOUT::TND) {
InitOutputSingleCore();
}
}
if constexpr (PAGE_ATTENTION) {
blockTableGm.SetGlobalBuffer((__gm__ int32_t *)blockTable);
}
topKGm.SetGlobalBuffer((__gm__ int32_t *)sparseIndices);
uint64_t offset = 0;
mm1ResGm.SetGlobalBuffer(
(__gm__ MM1_OUT_T *)(workspace + offset +
aiCoreIdx * dbWorkspaceRatio * constInfo.mmResUbSize * sizeof(MM1_OUT_T)));
offset += GetBlockNum() * dbWorkspaceRatio * constInfo.mmResUbSize * sizeof(MM1_OUT_T);
vec1ResGm.SetGlobalBuffer(
(__gm__ KV_T *)(workspace + offset + aiCoreIdx * dbWorkspaceRatio * constInfo.mmResUbSize * sizeof(KV_T)));
offset += GetBlockNum() * dbWorkspaceRatio * constInfo.mmResUbSize * sizeof(KV_T);
mm2ResGm.SetGlobalBuffer(
(__gm__ MM2_OUT_T *)(workspace + offset +
aiCoreIdx * dbWorkspaceRatio * constInfo.bmm2ResUbSize * sizeof(MM2_OUT_T)));
offset += GetBlockNum() * dbWorkspaceRatio * constInfo.bmm2ResUbSize * sizeof(MM2_OUT_T);
mm2ResInt32Gm.SetGlobalBuffer(reinterpret_cast<__gm__ int32_t *>(mm2ResGm.GetPhyAddr(0)));
if constexpr (TEMPLATE_MODE == V_TEMPLATE) {
// s2 d+rope bufNum
kvMergeGm_.SetGlobalBuffer((__gm__ KV_T *)(workspace + offset + aiCoreIdx * 512 * 576 * 4 * sizeof(KV_T)));
offset += GetBlockNum() * 512 * 576 * 4 * sizeof(KV_T);
kvValidSizeGm_.SetGlobalBuffer(
(__gm__ int32_t *)(workspace + offset + (aiCoreIdx * 2) * 128 * 4 * sizeof(int32_t)));
}
if constexpr (FLASH_DECODE) {
accumOutGm.SetGlobalBuffer((__gm__ float *)(workspace + offset));
offset = offset + tilingData->splitKVParams.accumOutSize * sizeof(float);
lseSumFdGm.SetGlobalBuffer((__gm__ float *)(workspace + offset));
lseMaxFdGm.SetGlobalBuffer((__gm__ float *)(workspace + offset) + tilingData->splitKVParams.logSumExpSize / 2);
offset = offset + tilingData->splitKVParams.logSumExpSize * sizeof(float);
}
if ASCEND_IS_AIV {
vectorService.InitParams(constInfo, tilingData);
vectorService.InitMm2ResInt32GmGlobalTensor(mm2ResInt32Gm);
if constexpr (TEMPLATE_MODE == V_TEMPLATE) {
vectorService.InitVec0GlobalTensor(kvValidSizeGm_, kvMergeGm_, kRopeGm, keyGm, blockTableGm);
}
vectorService.InitVec1GlobalTensor(mm1ResGm, vec1ResGm, actualSeqLengthsQGm,
actualSeqLengthsKVGm, lseMaxFdGm, lseSumFdGm, topKGm);
vectorService.InitVec2GlobalTensor(accumOutGm, vec2ResGm, mm2ResGm, attentionOutGm);
}
if ASCEND_IS_AIC {
matmulService.InitParams(constInfo);
matmulService.InitMm1GlobalTensor(queryGm, qRopeGm, keyGm, kRopeGm, mm1ResGm);
matmulService.InitMm2GlobalTensor(vec1ResGm, valueGm, mm2ResGm, attentionOutGm);
matmulService.InitPageAttentionInfo(kvMergeGm_, blockTableGm, topKGm,
constInfo.kvCacheBlockSize, constInfo.maxBlockNumPerBatch);
}
if (pipe != nullptr) {
InitBuffers();
}
}
template <typename SFAT> __aicore__ inline void SparseFlashAttentionMla<SFAT>::InitCalcParamsEach()
{
uint32_t totalBaseNum = 0;
uint32_t s1GBaseSize = constInfo.gSize;
uint32_t actBatchS2 = 1;
uint32_t coreNum = GetBlockNum();
uint32_t currCoreIdx = aiCoreIdx;
uint32_t actBatchS1 = 1;
for (uint32_t bIdx = 0; bIdx < constInfo.batchSize; bIdx++) {
uint32_t actBatchS1 = GetBalanceActualSeqLengths(actualSeqLengthsQGm, bIdx);
if (actBatchS1 < constInfo.qSeqSize) {
constInfo.needInit = true;
}
totalBaseNum += actBatchS1*actBatchS2 ;
}
uint32_t avgBaseNum = 1;
if (totalBaseNum > coreNum) {
avgBaseNum = (totalBaseNum + coreNum - 1) / coreNum;
}else {
usedCoreNum = totalBaseNum;
}
if(aiCoreIdx>=usedCoreNum){
return;
}
uint32_t accumBaseNum = 0;
uint32_t targetBaseNum = 0;
uint32_t lastValidBIdx = 0;
uint32_t lastValidactBatchS1=0;
bool setStart=false;
targetBaseNum = (currCoreIdx + 1) * avgBaseNum;
uint32_t targetStartBaseNum = targetBaseNum-avgBaseNum;
for (uint32_t bN2Idx = 0; bN2Idx < constInfo.batchSize * constInfo.kvHeadNum; bN2Idx++) {
uint32_t bIdx = bN2Idx / constInfo.kvHeadNum;
actBatchS1 = GetBalanceActualSeqLengths(actualSeqLengthsQGm, bIdx);
for (uint32_t s1GIdx = 0; s1GIdx < actBatchS1; s1GIdx++) {
accumBaseNum += 1;
if(!setStart && accumBaseNum >= targetStartBaseNum){
constInfo.bN2Start = bN2Idx;
constInfo.gS1Start = s1GIdx;
setStart=true;
}
if (accumBaseNum >= targetBaseNum) {
constInfo.bN2End = bN2Idx;
constInfo.gS1End = s1GIdx;
constInfo.s2End = 0;
constInfo.coreStartKVSplitPos = 0;
if (aiCoreIdx != 0) {
GetAxisStartIdx(constInfo.bN2Start, constInfo.gS1Start, 0);
}
return;
}
}
if ((actBatchS1 > 0) && (actBatchS2 > 0)) {
lastValidBIdx = bIdx;
lastValidactBatchS1 = actBatchS1;
}
}
if (!setStart){
constInfo.bN2Start = lastValidBIdx;
constInfo.gS1Start = lastValidactBatchS1-1;
}
if (accumBaseNum < targetBaseNum) {
constInfo.bN2End = lastValidBIdx;
constInfo.gS1End = lastValidactBatchS1-1;
constInfo.s2End = 0;
constInfo.coreStartKVSplitPos = 0;
if (aiCoreIdx != 0) {
GetAxisStartIdx(constInfo.bN2Start, constInfo.gS1Start, 0);
}
return;
}
}
template <typename SFAT>
__aicore__ inline void
SparseFlashAttentionMla<SFAT>::Bmm2DataCopyOut(uint64_t attenOutOffset, LocalTensor<OUT_T> &attenOutUb,
uint32_t startRow, uint32_t dealRowCount,
uint32_t columnCount, uint32_t actualColumnCount)
{
DataCopyExtParams dataCopyParams;
dataCopyParams.blockCount = dealRowCount;
dataCopyParams.blockLen = actualColumnCount * sizeof(OUT_T);
dataCopyParams.srcStride = (columnCount - actualColumnCount) / (SFAVectorService<SFAT>::BYTE_BLOCK / sizeof(OUT_T));
dataCopyParams.dstStride = 0;
DataCopyPad(attentionOutGm[attenOutOffset + (mSizeVStart + startRow) * actualColumnCount], attenOutUb,
dataCopyParams);
}
template <typename SFAT>
__aicore__ inline void SparseFlashAttentionMla<SFAT>::CalcParams(uint32_t loop, uint64_t s2Start,
uint32_t s2LoopIdx, RunInfo &info)
{
info.loop = loop;
info.bIdx = tempLoopInfo.bIdx;
info.gS1Idx = tempLoopInfo.gS1Idx;
info.s2Idx = s2LoopIdx;
info.curSInnerLoopTimes = tempLoopInfo.s2LoopTimes;
info.tndIsS2SplitCore = tempLoopInfo.tndIsS2SplitCore;
info.tndCoreStartKVSplitPos = tempLoopInfo.tndCoreStartKVSplitPos;
info.isBmm2Output = false;
info.actS1Size = tempLoopInfo.actS1Size;
info.actMBaseSize = constInfo.mBaseSize;
uint32_t remainedGS1Size = tempLoopInfo.actS1Size * constInfo.gSize - tempLoopInfo.gS1Idx;
if (remainedGS1Size <= constInfo.mBaseSize && remainedGS1Size > 0) {
info.actMBaseSize = tempLoopInfo.mBasicSizeTail;
}
info.isValid = s2LoopIdx < tempLoopInfo.s2LoopTimes;
if ASCEND_IS_AIV {
info.mSize = info.actMBaseSize;
info.mSizeV = (info.mSize <= 16) ? info.mSize : (((info.mSize + 15) / 16 + 1) / 2 * 16);
info.mSizeVStart = 0;
if (tmpBlockIdx % 2 == 1) {
info.mSizeVStart = info.mSizeV;
info.mSizeV = info.mSize - info.mSizeV;
}
}
info.isChangeBatch = false;
info.isFirstSInnerLoop = s2LoopIdx == s2Start;
if (info.isFirstSInnerLoop) {
tempLoopInfo.bn2IdxInCurCore++;
}
info.isLastS2Loop = s2LoopIdx == tempLoopInfo.s2LoopTimes - 1;
info.bn2IdxInCurCore = tempLoopInfo.bn2IdxInCurCore - 1;
uint64_t actualSeqQPrefixSum;
if constexpr (LAYOUT_T == SFA_LAYOUT::TND) {
actualSeqQPrefixSum = (info.bIdx <= 0) ? 0 : actualSeqLengthsQGm.GetValue(info.bIdx - 1);
} else {
actualSeqQPrefixSum = (info.bIdx <= 0) ? 0 : info.bIdx * constInfo.qSeqSize;
}
info.tndBIdxOffsetForQ = actualSeqQPrefixSum * constInfo.qHeadNum * headDim;
uint64_t actualSeqKVPrefixSum;
if constexpr (KV_LAYOUT_T == SFA_LAYOUT::TND) {
actualSeqKVPrefixSum = (info.bIdx <= 0) ? 0 : actualSeqLengthsKVGm.GetValue(info.bIdx - 1);
} else {
actualSeqKVPrefixSum = (info.bIdx <= 0) ? 0 : info.bIdx * constInfo.kvSeqSize;
}
info.tndBIdxOffsetForKV = actualSeqKVPrefixSum * constInfo.kvHeadNum * headDim;
if (info.isFirstSInnerLoop) {
uint64_t tndBIdxRopeOffsetForQ = actualSeqQPrefixSum * constInfo.qHeadNum * headDimRope;
tensorACoreOffset = info.tndBIdxOffsetForQ + info.gS1Idx * headDim;
tensorARopeCoreOffset = tndBIdxRopeOffsetForQ + info.gS1Idx * headDimRope;
uint64_t tndBIdxRopeOffsetForK = actualSeqKVPrefixSum * constInfo.kvHeadNum * headDimRope;
tensorBCoreOffset = info.tndBIdxOffsetForKV + info.n2Idx * headDim;
tensorBRopeCoreOffset = tndBIdxRopeOffsetForK + info.n2Idx * headDimRope;
if (constInfo.sparseMode == 3) {
threshold = static_cast<int64_t>(tempLoopInfo.nextTokensPerBatch) + info.gS1Idx / constInfo.gSize + 1;
} else {
threshold = tempLoopInfo.curActualSeqLenOri;
}
if constexpr(LAYOUT_T == SFA_LAYOUT::BSND) { // B,S1,N2 K
topKBaseOffset = info.bIdx * constInfo.qSeqSize * constInfo.kvHeadNum * constInfo.sparseBlockCount +
info.gS1Idx / constInfo.gSize * constInfo.kvHeadNum * constInfo.sparseBlockCount +
info.n2Idx * constInfo.sparseBlockCount;
} else if (LAYOUT_T == SFA_LAYOUT::TND) { // T N2 K
topKBaseOffset = info.tndBIdxOffsetForQ / constInfo.gSize / constInfo.headDim * constInfo.kvHeadNum *
constInfo.sparseBlockCount + info.n2Idx * constInfo.sparseBlockCount +
info.gS1Idx / constInfo.gSize * constInfo.kvHeadNum * constInfo.sparseBlockCount;
} else { // B N2 S1 K
topKBaseOffset = info.bIdx * constInfo.kvHeadNum * constInfo.qSeqSize * constInfo.sparseBlockCount +
info.n2Idx * constInfo.qSeqSize * constInfo.sparseBlockCount +
info.gS1Idx / constInfo.gSize * constInfo.sparseBlockCount;
}
}
info.topKBaseOffset = topKBaseOffset;
info.threshold = threshold;
info.tensorAOffset = tensorACoreOffset;
info.tensorARopeOffset = tensorARopeCoreOffset;
info.tensorBOffset = tensorBCoreOffset;
info.tensorBRopeOffset = tensorBRopeCoreOffset;
info.attenOutOffset = tensorACoreOffset;
uint64_t sInnerOffsetDataSize = info.s2Idx * constInfo.s2BaseSize;
info.s2BatchOffset = s2BatchBaseOffset + sInnerOffsetDataSize;
info.curActualSeqLenOri = tempLoopInfo.curActualSeqLenOri;
if constexpr (TEMPLATE_MODE == V_TEMPLATE) {
if (tempLoopInfo.curActualSeqLen > sInnerOffsetDataSize) {
info.actualSingleProcessSInnerSize = tempLoopInfo.curActualSeqLen - sInnerOffsetDataSize;
info.actualSingleProcessSInnerSize = info.actualSingleProcessSInnerSize > constInfo.s2BaseSize ?
constInfo.s2BaseSize : info.actualSingleProcessSInnerSize;
info.actualSingleProcessSInnerSize =
SFAAlign((int64_t)info.actualSingleProcessSInnerSize, (int64_t)constInfo.sparseBlockSize);
} else {
info.actualSingleProcessSInnerSize = 0;
}
info.actualSingleProcessSInnerSizeAlign =
SFAAlign((uint32_t)info.actualSingleProcessSInnerSize, (uint32_t)SFAVectorService<SFAT>::BYTE_BLOCK);
}
}
template <typename SFAT>
__aicore__ inline void SparseFlashAttentionMla<SFAT>::ComputeMm1(const RunInfo &info)
{
uint32_t nBufferLoopTimes = (info.actMBaseSize + constInfo.nBufferMBaseSize - 1) / constInfo.nBufferMBaseSize;
uint32_t nBufferTail = info.actMBaseSize - (nBufferLoopTimes - 1) * constInfo.nBufferMBaseSize;
for (uint32_t i = 0; i < nBufferLoopTimes; i++) {
MSplitInfo mSplitInfo;
mSplitInfo.nBufferStartM = i * constInfo.nBufferMBaseSize;
mSplitInfo.nBufferDealM = (i + 1 != nBufferLoopTimes) ? constInfo.nBufferMBaseSize : nBufferTail;
matmulService.ComputeMm1(info, mSplitInfo);
CrossCoreSetFlag<ConstInfo::SFA_SYNC_MODE2, PIPE_FIX>(constInfo.syncC1V1);
}
}
template <typename SFAT>
__aicore__ inline void SparseFlashAttentionMla<SFAT>::ComputeMm2(const RunInfo &info)
{
uint32_t nBufferLoopTimes = (info.actMBaseSize + constInfo.nBufferMBaseSize - 1) / constInfo.nBufferMBaseSize;
uint32_t nBufferTail = info.actMBaseSize - (nBufferLoopTimes - 1) * constInfo.nBufferMBaseSize;
for (uint32_t i = 0; i < nBufferLoopTimes; i++) {
MSplitInfo mSplitInfo;
mSplitInfo.nBufferStartM = i * constInfo.nBufferMBaseSize;
mSplitInfo.nBufferDealM = (i + 1 != nBufferLoopTimes) ? constInfo.nBufferMBaseSize : nBufferTail;
CrossCoreWaitFlag(constInfo.syncV1C2);
matmulService.ComputeMm2(info, mSplitInfo);
CrossCoreSetFlag<ConstInfo::SFA_SYNC_MODE2, PIPE_FIX>(constInfo.syncC2V2);
CrossCoreSetFlag<ConstInfo::SFA_SYNC_MODE2, PIPE_FIX>(constInfo.syncC2V1);
}
}
template <typename SFAT> __aicore__ inline void SparseFlashAttentionMla<SFAT>::Process()
{
if (aiCoreIdx < usedCoreNum) {
if ASCEND_IS_AIV {
vectorService.AllocEventID();
vectorService.InitSoftmaxDefaultBuffer();
} else {
matmulService.AllocEventID();
}
ProcessBalance();
if ASCEND_IS_AIV {
vectorService.FreeEventID();
} else {
matmulService.FreeEventID();
}
}
}
template <typename SFAT>
__aicore__ inline void SparseFlashAttentionMla<SFAT>::GetBN2Idx(uint32_t bN2Idx, uint32_t &bIdx,
uint32_t &n2Idx)
{
bIdx = bN2Idx / kvHeadNum;
n2Idx = bN2Idx % kvHeadNum;
}
template <typename SFAT> __aicore__ inline void SparseFlashAttentionMla<SFAT>::ProcessBalance()
{
RunInfo extraInfo[SFA_PRELOAD_TASK_CACHE_SIZE];
uint32_t gloop = 0;
int gS1LoopEnd;
bool globalLoopStart = true;
if ASCEND_IS_AIC {
CrossCoreSetFlag<ConstInfo::SFA_SYNC_MODE2, PIPE_FIX>(constInfo.syncC2V1);
if constexpr (TEMPLATE_MODE == V_TEMPLATE) {
CrossCoreSetFlag<ConstInfo::SFA_SYNC_MODE2, PIPE_MTE2>(3);
CrossCoreSetFlag<ConstInfo::SFA_SYNC_MODE2, PIPE_MTE2>(3);
CrossCoreSetFlag<ConstInfo::SFA_SYNC_MODE2, PIPE_MTE2>(3);
CrossCoreSetFlag<ConstInfo::SFA_SYNC_MODE2, PIPE_MTE2>(3);
}
}
for (uint32_t bN2LoopIdx = constInfo.bN2Start; bN2LoopIdx <= constInfo.bN2End; bN2LoopIdx++) {
GetBN2Idx(bN2LoopIdx, tempLoopInfo.bIdx, tempLoopInfo.n2Idx);
GetActualSeqLen(tempLoopInfo.bIdx);
GetPreNextTokensLeftUp();
if (tempLoopInfo.actS1Size == 0) {
continue;
}
int gS1SplitNum = (tempLoopInfo.actS1Size * constInfo.gSize + constInfo.mBaseSize - 1) / constInfo.mBaseSize;
gS1LoopEnd = (bN2LoopIdx == constInfo.bN2End) ? constInfo.gS1End : gS1SplitNum - 1;
for (uint32_t gS1LoopIdx = constInfo.gS1Start; gS1LoopIdx <= gS1LoopEnd; gS1LoopIdx++) {
tempLoopInfo.gS1Idx = gS1LoopIdx * constInfo.mBaseSize;
GetSparseActualSeqLen(tempLoopInfo.bIdx, gS1LoopIdx, tempLoopInfo.n2Idx);
UpdateInnerLoopCond();
if (tempLoopInfo.curActSeqLenIsZero) {
DealActSeqLenIsZero(tempLoopInfo.bIdx, gS1LoopIdx, tempLoopInfo.n2Idx);
}
int s2SplitNum =
(tempLoopInfo.curActualSeqLen + constInfo.s2BaseSize - 1) / constInfo.s2BaseSize;
bool isEnd = (bN2LoopIdx == constInfo.bN2End) && (gS1LoopIdx == constInfo.gS1End);
tempLoopInfo.s2LoopTimes = s2SplitNum;
tempLoopInfo.tndIsS2SplitCore =
((constInfo.s2Start == 0) && (tempLoopInfo.s2LoopTimes == s2SplitNum)) ? false : true;
tempLoopInfo.tndCoreStartKVSplitPos = globalLoopStart ? constInfo.coreStartKVSplitPos : 0;
uint32_t extraLoop = isEnd ? 2 : 0;
uint32_t curTopKIdx = 0;
uint64_t curOffsetInSparseBlock = 0;
for (int s2LoopIdx = constInfo.s2Start; s2LoopIdx < (tempLoopInfo.s2LoopTimes + extraLoop); s2LoopIdx++) {
PreloadPipeline(gloop, constInfo.s2Start, s2LoopIdx, extraInfo, curTopKIdx, curOffsetInSparseBlock);
++gloop;
}
globalLoopStart = false;
constInfo.s2Start = 0;
}
constInfo.gS1Start = 0;
}
if ASCEND_IS_AIV {
CrossCoreWaitFlag(constInfo.syncC2V1);
if constexpr (TEMPLATE_MODE == V_TEMPLATE) {
CrossCoreWaitFlag(3);
CrossCoreWaitFlag(3);
CrossCoreWaitFlag(3);
CrossCoreWaitFlag(3);
}
}
}
template <typename SFAT>
__aicore__ inline void
SparseFlashAttentionMla<SFAT>::PreloadPipeline(uint32_t loop, uint64_t s2Start, uint64_t s2LoopIdx,
RunInfo extraInfo[SFA_PRELOAD_TASK_CACHE_SIZE], uint32_t &curTopKIdx, uint64_t &curOffsetInSparseBlock)
{
RunInfo &extraInfo0 = extraInfo[loop % SFA_PRELOAD_TASK_CACHE_SIZE];
RunInfo &extraInfo2 = extraInfo[(loop + 2) % SFA_PRELOAD_TASK_CACHE_SIZE];
RunInfo &extraInfo1 = extraInfo[(loop + 1) % SFA_PRELOAD_TASK_CACHE_SIZE];
CalcParams(loop, s2Start, s2LoopIdx, extraInfo0);
CalcSinnerTopKBegin(extraInfo0, curTopKIdx, curOffsetInSparseBlock);
if (extraInfo0.isValid) {
if ASCEND_IS_AIC {
if constexpr (TEMPLATE_MODE == V_TEMPLATE) {
CrossCoreWaitFlag(constInfo.syncV0C1);
}
ComputeMm1(extraInfo0);
} else {
if constexpr (TEMPLATE_MODE == V_TEMPLATE) {
CrossCoreWaitFlag(3);
vectorService.MergeKv(extraInfo0);
CrossCoreSetFlag<ConstInfo::SFA_SYNC_MODE2, PIPE_MTE3>(constInfo.syncV0C1);
}
}
}
if (extraInfo2.isValid) {
if ASCEND_IS_AIV {
vectorService.ProcessVec1L(extraInfo2);
}
if ASCEND_IS_AIC {
ComputeMm2(extraInfo2);
if constexpr (TEMPLATE_MODE == V_TEMPLATE) {
CrossCoreSetFlag<ConstInfo::SFA_SYNC_MODE2, PIPE_MTE2>(3);
}
}
}
if (extraInfo1.isValid) {
if ASCEND_IS_AIV {
vectorService.ProcessVec2L(extraInfo1);
}
extraInfo1.isValid = false;
}
}
template <typename SFAT>
__aicore__ inline uint64_t
SparseFlashAttentionMla<SFAT>::GetBalanceActualSeqLengths(GlobalTensor<int32_t> &actualSeqLengths,
uint32_t bIdx)
{
if constexpr (LAYOUT_T == SFA_LAYOUT::TND) {
if (bIdx > 0) {
return actualSeqLengths.GetValue(bIdx) - actualSeqLengths.GetValue(bIdx - 1);
} else if (bIdx == 0) {
return actualSeqLengths.GetValue(0);
} else {
return 0;
}
} else {
if (constInfo.actualLenDimsQ == 0) {
return constInfo.qSeqSize;
} else if (constInfo.actualLenDimsQ == 1) {
return actualSeqLengths.GetValue(0);
} else {
return actualSeqLengths.GetValue(bIdx);
}
}
}
template <typename SFAT>
__aicore__ inline void SparseFlashAttentionMla<SFAT>::GetAxisStartIdx(uint32_t bN2EndPrev,
uint32_t s1GEndPrev,
uint32_t s2EndPrev)
{
uint32_t bEndPrev = bN2EndPrev / kvHeadNum;
uint32_t actualSeqQPrev = GetBalanceActualSeqLengths(actualSeqLengthsQGm, bEndPrev);
uint32_t s1GPrevBaseNum = (actualSeqQPrev * constInfo.gSize + constInfo.mBaseSize - 1) / constInfo.mBaseSize;
constInfo.bN2Start = bN2EndPrev;
constInfo.gS1Start = s1GEndPrev;
constInfo.s2Start = 0;
if (s1GEndPrev >= s1GPrevBaseNum - 1) {
constInfo.gS1Start = 0;
constInfo.bN2Start++;
} else {
constInfo.gS1Start++;
}
}
template <typename SFAT>
__aicore__ inline void SparseFlashAttentionMla<SFAT>::CalcSinnerTopKBegin(RunInfo &info, uint32_t &curTopKIdx, uint64_t &curOffsetInSparseBlock)
{
if constexpr (TEMPLATE_MODE == V_TEMPLATE) {
return;
}
uint64_t thresholdSparseCount = (info.threshold + constInfo.sparseBlockSize - 1) / constInfo.sparseBlockSize;
uint64_t validCount = (constInfo.sparseBlockCount > thresholdSparseCount) ? thresholdSparseCount : constInfo.sparseBlockCount;
int32_t sparseIndices = topKGm.GetValue(info.topKBaseOffset + curTopKIdx);
if (sparseIndices == -1 || curTopKIdx == validCount) {
info.actualSingleProcessSInnerSize = 0;
info.actualSingleProcessSInnerSizeAlign = 0;
tempLoopInfo.s2BasicSizeTail = 0;
if (curTopKIdx == 0) {
DealActSeqLenIsZero(info.bIdx, info.gS1Idx / constInfo.gSize, tempLoopInfo.n2Idx);
}
return;
}
uint32_t sparseLen = 0;
uint64_t blockBegin = sparseIndices * constInfo.sparseBlockSize;
uint64_t blockEnd = (blockBegin + constInfo.sparseBlockSize > info.threshold) ? info.threshold : blockBegin + constInfo.sparseBlockSize;
int32_t blockLen = blockEnd - blockBegin;
sparseLen += (blockLen > static_cast<int32_t>(curOffsetInSparseBlock)) ? blockLen - curOffsetInSparseBlock : 0;
bool firstVaildFlag = false;
if (curTopKIdx > 0) {
info.curTopKIdx = curTopKIdx;
info.curOffsetInSparseBlock = curOffsetInSparseBlock;
} else if (curTopKIdx == 0 && sparseLen > 0) {
info.curTopKIdx = curTopKIdx;
info.curOffsetInSparseBlock = 0;
firstVaildFlag = true;
}
for (uint64_t topkIdx = curTopKIdx + 1; topkIdx < validCount; topkIdx++) {
int32_t sparseIndices = topKGm.GetValue(info.topKBaseOffset + topkIdx);
if (sparseIndices == -1) {
curTopKIdx = topkIdx;
curOffsetInSparseBlock = 0;
break;
}
uint64_t blockBegin = sparseIndices * constInfo.sparseBlockSize;
if (blockBegin >= info.threshold) {
continue;
}
if (firstVaildFlag == false && curTopKIdx == 0) {
info.curTopKIdx = topkIdx;
info.curOffsetInSparseBlock = 0;
firstVaildFlag = true;
}
uint64_t blockEnd = (blockBegin + constInfo.sparseBlockSize > info.threshold) ? info.threshold : blockBegin + constInfo.sparseBlockSize;
uint64_t blockLen = blockEnd - blockBegin;
sparseLen += blockLen;
if (sparseLen >= constInfo.s2BaseSize) {
curTopKIdx = topkIdx;
curOffsetInSparseBlock = blockLen - (sparseLen - constInfo.s2BaseSize);
sparseLen = constInfo.s2BaseSize;
break;
}
if (topkIdx == validCount - 1) {
curTopKIdx = validCount;
curOffsetInSparseBlock = 0;
}
}
info.actualSingleProcessSInnerSize = sparseLen;
info.actualSingleProcessSInnerSizeAlign = SFAAlign((uint32_t)info.actualSingleProcessSInnerSize, (uint32_t)SFAVectorService<SFAT>::BYTE_BLOCK);
tempLoopInfo.s2BasicSizeTail = (sparseLen == constInfo.s2BaseSize) ? 0 : sparseLen;
if (curTopKIdx == 0 && sparseLen == 0) {
DealActSeqLenIsZero(info.bIdx, info.gS1Idx / constInfo.gSize, tempLoopInfo.n2Idx);
}
}
#endif // SPARSE_FLASH_ATTENTION_KERNEL_MLA_H

File diff suppressed because it is too large Load Diff

File diff suppressed because it is too large Load Diff

View File

@@ -0,0 +1,54 @@
/**
* This program is free software, you can redistribute it and/or modify it.
* Copyright (c) 2025 Huawei Technologies Co., Ltd.
* This file is a part of the CANN Open Software.
* Licensed under CANN Open Software License Agreement Version 2.0 (the "License").
* Please refer to the License for details. You may not use this file except in compliance with the License.
* THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE.
* See LICENSE in the root of the software repository for the full text of the License.
*/
/*!
* \file sparse_flash_attention_template_tiling_key.h
* \brief
*/
#ifndef SPARSE_FLASH_ATTENTION_TEMPLATE_TILING_KEY_H
#define SPARSE_FLASH_ATTENTION_TEMPLATE_TILING_KEY_H
#include "ascendc/host_api/tiling/template_argument.h"
#define SFA_LAYOUT_BSND 0
#define SFA_LAYOUT_TND 1
#define SFA_LAYOUT_PA_BSND 2
#define ASCENDC_TPL_4_BW 4
#define C_TEMPLATE 0
#define V_TEMPLATE 1
ASCENDC_TPL_ARGS_DECL(SparseFlashAttention,
ASCENDC_TPL_BOOL_DECL(FLASH_DECODE, 0, 1),
ASCENDC_TPL_UINT_DECL(LAYOUT_T, ASCENDC_TPL_4_BW, ASCENDC_TPL_UI_LIST, SFA_LAYOUT_BSND, SFA_LAYOUT_TND),
ASCENDC_TPL_UINT_DECL(KV_LAYOUT_T, ASCENDC_TPL_4_BW, ASCENDC_TPL_UI_LIST, SFA_LAYOUT_BSND, SFA_LAYOUT_TND,
SFA_LAYOUT_PA_BSND),
ASCENDC_TPL_UINT_DECL(TEMPLATE_MODE, ASCENDC_TPL_4_BW, ASCENDC_TPL_UI_LIST, C_TEMPLATE, V_TEMPLATE),
);
ASCENDC_TPL_SEL(
ASCENDC_TPL_ARGS_SEL(
ASCENDC_TPL_BOOL_SEL(FLASH_DECODE, 0),
ASCENDC_TPL_UINT_SEL(LAYOUT_T, ASCENDC_TPL_UI_LIST, SFA_LAYOUT_BSND, SFA_LAYOUT_TND),
ASCENDC_TPL_UINT_SEL(KV_LAYOUT_T, ASCENDC_TPL_UI_LIST, SFA_LAYOUT_BSND, SFA_LAYOUT_PA_BSND, SFA_LAYOUT_TND),
ASCENDC_TPL_UINT_SEL(TEMPLATE_MODE, ASCENDC_TPL_UI_LIST, C_TEMPLATE),
),
ASCENDC_TPL_ARGS_SEL(
ASCENDC_TPL_BOOL_SEL(FLASH_DECODE, 0),
ASCENDC_TPL_UINT_SEL(LAYOUT_T, ASCENDC_TPL_UI_LIST, SFA_LAYOUT_BSND, SFA_LAYOUT_TND),
ASCENDC_TPL_UINT_SEL(KV_LAYOUT_T, ASCENDC_TPL_UI_LIST, SFA_LAYOUT_PA_BSND, SFA_LAYOUT_BSND, SFA_LAYOUT_TND),
ASCENDC_TPL_UINT_SEL(TEMPLATE_MODE, ASCENDC_TPL_UI_LIST, V_TEMPLATE), // V模板不支持非PA
),
);
#endif // TEMPLATE_TILING_KEY