vllm-ascend/csrc/lightning_indexer/op_kernel/lightning_indexer_vector.h
Song Mingyang 18b90b501d [kernel] add AscendC op: lightning_indexer and sparse_flash_attention (#4625)
### What this PR does / why we need it?
Provides the high-performance AscendC operators lightning_indexer and
sparse_flash_attention to boost the execution performance of the
DeepSeek v3.2 model, and adapts both operators to the vllm-ascend
framework.

### Does this PR introduce _any_ user-facing change?
No (underlying operator optimizations only).

### How was this patch tested?

- vLLM version: v0.11.2
- vLLM main: https://github.com/vllm-project/vllm/commit/v0.11.2

Signed-off-by: MingYang119 <songmingyang@huawei.com>
2025-12-03 09:53:10 +08:00

/**
* This program is free software, you can redistribute it and/or modify it.
* Copyright (c) 2025 Huawei Technologies Co., Ltd.
* This file is a part of the CANN Open Software.
* Licensed under CANN Open Software License Agreement Version 2.0 (the "License").
* Please refer to the License for details. You may not use this file except in compliance with the License.
* THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE.
* See LICENSE in the root of the software repository for the full text of the License.
*/
/*!
 * \file lightning_indexer_vector.h
 * \brief Vector-core helpers for the lightning_indexer AscendC kernel: data staging,
 *        per-row weight scaling, row reduction, sorting, and top-k merging.
 */
#ifndef LIGHTNING_INDEXER_VECTOR_H
#define LIGHTNING_INDEXER_VECTOR_H
#include "lightning_indexer_vector.h"
#include "kernel_operator.h"
namespace LIServiceVec {
using namespace AscendC;
constexpr int32_t NEG_INF = 0xFF800000;       // bit pattern of float -inf
constexpr int32_t INVALID_INDEX = -1;         // sentinel index for padded sort records
constexpr uint8_t VEC_REPEAT_MAX = 255;       // max repeat count of one vector instruction
constexpr uint8_t B32_VEC_ELM_NUM = 64;       // 32-bit elements per 256-byte vector repeat
constexpr uint8_t B32_BLOCK_ALIGN_NUM = 8;    // 32-bit elements per 32-byte block
constexpr uint8_t B32_VEC_REPEAT_STRIDE = 8;  // repeat stride, in 32-byte blocks
constexpr uint64_t VEC_REPEAT_BYTES = 256;    // bytes processed by one vector repeat
constexpr int32_t CONST_TWO = 2;
constexpr int64_t VALUE_AND_INDEX_NUM = 2;    // sort records interleave (value, index) pairs
constexpr int64_t BLOCK_BYTES = 32;           // 32-byte block; also the records per Sort32 repeat
constexpr int64_t MRG_QUE_0 = 0;
constexpr int64_t MRG_QUE_1 = 1;
constexpr int64_t MRG_QUE_2 = 2;
constexpr int64_t MRG_QUE_3 = 3;
constexpr int64_t MRG_BLOCK_2 = 2;
constexpr int64_t MRG_BLOCK_3 = 3;
constexpr int64_t MRG_BLOCK_4 = 4;
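// Stages one [groupInner, s2Inner] tile of the indexer matmul output, plus the matching
// per-row weight scales, from global memory into the unified buffer; mmUbStride adds
// row-to-row padding on the UB side of the tile copy.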
template <typename T>
__aicore__ inline void CopyIn(LocalTensor<float> &mmOutUb, LocalTensor<T> &weightsUb, GlobalTensor<float> &mmOutGm,
                              GlobalTensor<T> &weightScaleGm, int64_t mmOutGmOffset, int64_t weightsGmOffset,
                              int64_t groupInner, int64_t s2Inner, int64_t mmUbStride)
{
    // Copy the [groupInner, s2Inner] matmul-output tile row by row.
    AscendC::DataCopyPadExtParams<float> padParams{false, 0, 0, 0};
    AscendC::DataCopyExtParams dataCopyMmOutParams;
    dataCopyMmOutParams.blockCount = groupInner;
    dataCopyMmOutParams.blockLen = s2Inner * sizeof(float);
    dataCopyMmOutParams.srcStride = 0;
    dataCopyMmOutParams.dstStride = mmUbStride;
    dataCopyMmOutParams.rsv = 0;
    AscendC::DataCopyPad(mmOutUb, mmOutGm[mmOutGmOffset], dataCopyMmOutParams, padParams);
    // Copy the groupInner weight scales in a single burst.
    AscendC::DataCopyPadExtParams<T> padTParams{false, 0, 0, 0};
    AscendC::DataCopyExtParams dataCopyWeightParams;
    dataCopyWeightParams.blockCount = 1;
    dataCopyWeightParams.blockLen = groupInner * sizeof(T);
    dataCopyWeightParams.srcStride = 0;
    dataCopyWeightParams.dstStride = 0;
    dataCopyWeightParams.rsv = 0;
    AscendC::DataCopyPad(weightsUb, weightScaleGm[weightsGmOffset], dataCopyWeightParams, padTParams);
}
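// Writes copyCount contiguous elements from the unified buffer back to global memory
// as a single DataCopyPad burst.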
template <typename T>
__aicore__ inline void CopyOut(const GlobalTensor<T> &dstGm, const LocalTensor<T> &srcUb, int64_t copyCount)
{
    AscendC::DataCopyParams dataCopyOutParams;
    dataCopyOutParams.blockCount = 1;
    dataCopyOutParams.blockLen = copyCount * sizeof(T);
    dataCopyOutParams.srcStride = 0;
    dataCopyOutParams.dstStride = 0;
    AscendC::DataCopyPad(dstGm, srcUb, dataCopyOutParams);
}
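// Scales each matmul-output row by its per-row weight and accumulates the scaled tile
// into reduceCacheBuf across outer group iterations: outerGidx == 0 writes the cache
// directly; later calls multiply in place and then add into it.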
template <typename T>
__aicore__ inline void DoScale(const LocalTensor<float> &reduceCacheBuf, LocalTensor<float> &mmOutUb,
                               LocalTensor<float> &weightsUb, LocalTensor<T> &weightsTUb, LocalTensor<float> &tmpBuff,
                               int64_t groupInner, int64_t s2Inner, int32_t outerGidx)
{
    // Cast non-float weights (e.g. bfloat16_t) to float.
    if constexpr (!IsSameType<T, float>::value) {
        AscendC::Cast(weightsUb, weightsTUb, RoundMode::CAST_NONE, groupInner);
        AscendC::PipeBarrier<PIPE_V>();
    }
    // Weight broadcast across one 32-byte block: [groupInner, 1] -> [groupInner, 8].
    AscendC::Brcb(tmpBuff, weightsUb, LICommon::CeilDiv(groupInner, static_cast<int64_t>(B32_BLOCK_ALIGN_NUM)),
                  {1, B32_VEC_REPEAT_STRIDE});
    AscendC::PipeBarrier<PIPE_V>();
    // Do scale: [groupInner, 8] * [groupInner, s2Inner], the broadcast weight block
    // (src1 block stride 0) multiplying every block of its row.
    uint64_t countPerRepeat = VEC_REPEAT_BYTES / sizeof(float);
    uint64_t repeatTimes = s2Inner / countPerRepeat;
    for (int32_t i = 0; i < groupInner; i++) {
        if (outerGidx == 0) {
            AscendC::Mul(reduceCacheBuf[i * s2Inner], mmOutUb[i * s2Inner], tmpBuff[i * B32_BLOCK_ALIGN_NUM],
                         countPerRepeat, repeatTimes, {1, 1, 0, B32_VEC_REPEAT_STRIDE, B32_VEC_REPEAT_STRIDE, 0});
        } else {
            AscendC::Mul(mmOutUb[i * s2Inner], mmOutUb[i * s2Inner], tmpBuff[i * B32_BLOCK_ALIGN_NUM], countPerRepeat,
                         repeatTimes, {1, 1, 0, B32_VEC_REPEAT_STRIDE, B32_VEC_REPEAT_STRIDE, 0});
        }
    }
    if (outerGidx != 0) {
        // Later group chunks: accumulate the scaled rows into the reduce cache.
        AscendC::PipeBarrier<PIPE_V>();
        AscendC::Add(reduceCacheBuf, mmOutUb, reduceCacheBuf, groupInner * s2Inner);
    }
    AscendC::PipeBarrier<PIPE_V>();
}
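// Despite the name, returns the largest power of two that does not exceed value
// (e.g. 6 -> 4, 8 -> 8); DoReduce uses it to size its binary folds.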
__aicore__ inline uint64_t FindNearestPower2(uint64_t value)
{
    if (value <= CONST_TWO) {
        return value;
    } else {
        // 63 - clz(value) is the index of the highest set bit; shift a 64-bit
        // literal so the result is correct for values above 2^31.
        const uint64_t pow = 63 - clz(value);
        return (1ULL << pow);
    }
}
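// Sums rNum rows of aNum floats into one row by binary folding: rows above the largest
// power-of-two boundary are first added into the front rows (e.g. rNum = 6 folds rows
// 4..5 into 0..1), then the row count is halved until the final pair is added into
// dstTensor.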
__aicore__ inline void DoReduce(const LocalTensor<float> &srcTensor, LocalTensor<float> &dstTensor, int32_t rNum,
                                int32_t aNum)
{
    if (rNum == 1) {
        // Single row: copy it through (add 0) to dstTensor.
        AscendC::Adds<float>(dstTensor, srcTensor, 0, aNum);
        AscendC::PipeBarrier<PIPE_V>();
        return;
    }
    uint32_t dichotomizeAddPow = FindNearestPower2(rNum);
    uint32_t dichotomizeAddDiffSize = rNum - dichotomizeAddPow;
    if (dichotomizeAddDiffSize != 0) {
        // Fold the rows beyond the power-of-two boundary into the front rows.
        AscendC::Add(srcTensor, srcTensor, srcTensor[dichotomizeAddPow * aNum], dichotomizeAddDiffSize * aNum);
        AscendC::PipeBarrier<PIPE_V>();
    }
    int32_t nowRows = dichotomizeAddPow;
    while (nowRows > CONST_TWO) {
        nowRows = nowRows / CONST_TWO;
        AscendC::Add(srcTensor, srcTensor, srcTensor[nowRows * aNum], nowRows * aNum);
        AscendC::PipeBarrier<PIPE_V>();
    }
    AscendC::Add(dstTensor, srcTensor, srcTensor[aNum], aNum);
    AscendC::PipeBarrier<PIPE_V>();
}
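// Seeds a sort output buffer of eleNum interleaved (value, index) 32-bit slots with
// sentinels: values (even slots) get -inf and indices (odd slots) get -1, so padding
// records always sort behind real data.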
__aicore__ inline void InitSortOutBuf(const LocalTensor<float> &src, int64_t eleNum)
{
    uint64_t mask1[2] = {0x5555555555555555, 0};  // even 32-bit elements: values
    uint64_t mask0[2] = {0xaaaaaaaaaaaaaaaa, 0};  // odd 32-bit elements: indices
    int64_t repeatNum = eleNum / B32_VEC_ELM_NUM;
    int64_t forLoop = repeatNum / VEC_REPEAT_MAX;
    int64_t forRemain = repeatNum % VEC_REPEAT_MAX;
    for (int i = 0; i < forLoop; i++) {
        // Advance by the span already written each iteration.
        AscendC::Duplicate(src.template ReinterpretCast<int32_t>()[i * VEC_REPEAT_MAX * B32_VEC_ELM_NUM], NEG_INF,
                           mask1, VEC_REPEAT_MAX, 1, B32_VEC_REPEAT_STRIDE);
        AscendC::Duplicate(src.template ReinterpretCast<int32_t>()[i * VEC_REPEAT_MAX * B32_VEC_ELM_NUM],
                           INVALID_INDEX, mask0, VEC_REPEAT_MAX, 1, B32_VEC_REPEAT_STRIDE);
    }
    if (forRemain > 0) {
        AscendC::Duplicate(src.template ReinterpretCast<int32_t>()[forLoop * VEC_REPEAT_MAX * B32_VEC_ELM_NUM], NEG_INF,
                           mask1, forRemain, 1, B32_VEC_REPEAT_STRIDE);
        AscendC::Duplicate(src.template ReinterpretCast<int32_t>()[forLoop * VEC_REPEAT_MAX * B32_VEC_ELM_NUM],
                           INVALID_INDEX, mask0, forRemain, 1, B32_VEC_REPEAT_STRIDE);
    }
    AscendC::PipeBarrier<PIPE_V>();
}
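// Fully sorts logitsNum (value, index) records held in src (values first, indices in
// the following logitsNum slots): Sort32 orders runs of 32 records, then 4-way MrgSort
// passes merge the runs, ping-ponging between tmp and src; the pass parity decides
// which buffer holds the final result.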
__aicore__ inline void SortAll(LocalTensor<float> &src, LocalTensor<float> &tmp, int64_t logitsNum)
{
    int64_t sort32Repeats = logitsNum / BLOCK_BYTES;  // Sort32 handles 32 records per repeat
    AscendC::Sort32(tmp, src, src[logitsNum].ReinterpretCast<uint32_t>(), sort32Repeats);
    AscendC::PipeBarrier<PIPE_V>();
    int64_t mrgGroups = sort32Repeats;
    int64_t mrgElements = BLOCK_BYTES;
    int64_t i = 0;
    AscendC::LocalTensor<float> srcTensor;
    AscendC::LocalTensor<float> dstTensor;
    while (true) {
        // Alternate merge direction on every pass.
        if (i % CONST_TWO == 0) {
            srcTensor = tmp;
            dstTensor = src;
        } else {
            srcTensor = src;
            dstTensor = tmp;
        }
        AscendC::MrgSort4Info params;
        params.elementLengths[MRG_QUE_0] = mrgElements;
        params.elementLengths[MRG_QUE_1] = mrgElements;
        params.elementLengths[MRG_QUE_2] = mrgElements;
        params.elementLengths[MRG_QUE_3] = mrgElements;
        params.ifExhaustedSuspension = false;
        params.validBit = 0b1111;
        AscendC::MrgSortSrcList<float> srcList;
        srcList.src1 = srcTensor[0];
        srcList.src2 = srcTensor[MRG_QUE_1 * VALUE_AND_INDEX_NUM * mrgElements];
        srcList.src3 = srcTensor[MRG_QUE_2 * VALUE_AND_INDEX_NUM * mrgElements];
        srcList.src4 = srcTensor[MRG_QUE_3 * VALUE_AND_INDEX_NUM * mrgElements];
        if (mrgGroups <= MRG_BLOCK_4) {
            // Final pass: merge the remaining 2-4 groups, or stop if only one is left.
            params.repeatTimes = 1;
            if (mrgGroups == 1) {
                break;
            } else if (mrgGroups == MRG_BLOCK_2) {
                params.validBit = 0b0011;
            } else if (mrgGroups == MRG_BLOCK_3) {
                params.validBit = 0b0111;
            } else if (mrgGroups == MRG_BLOCK_4) {
                params.validBit = 0b1111;
            }
            AscendC::MrgSort<float>(dstTensor, srcList, params);
            i += 1;
            break;
        } else {
            // Full pass: merge every four neighbouring groups in one repeated MrgSort.
            // Callers are expected to keep mrgGroups a multiple of four until it drops
            // to at most four; any leftover groups would be dropped here.
            params.repeatTimes = mrgGroups / MRG_BLOCK_4;
            AscendC::MrgSort<float>(dstTensor, srcList, params);
            i += 1;
            mrgElements = mrgElements * MRG_BLOCK_4;
            mrgGroups = mrgGroups / MRG_BLOCK_4;
        }
        AscendC::PipeBarrier<PIPE_V>();
    }
    // An even pass count means the sorted data ended up in tmp; copy it back to src.
    if (i % CONST_TWO == 0) {
        AscendC::DataCopy(src, tmp, logitsNum * VALUE_AND_INDEX_NUM);
        AscendC::PipeBarrier<PIPE_V>();
    }
}
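// Overload sorting separate value/index tensors into dst with the high-level
// AscendC::Sort API; the true template argument requests a full sort across all repeats.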
__aicore__ inline void SortAll(LocalTensor<float> &dst, LocalTensor<float> &srcValue, LocalTensor<uint32_t> &srcIndex,
                               LocalTensor<float> &tmpTensor, int64_t logitsNum)
{
    int64_t sort32Repeats = logitsNum / BLOCK_BYTES;
    AscendC::Sort<float, true>(dst, srcValue, srcIndex, tmpTensor, sort32Repeats);
    AscendC::PipeBarrier<PIPE_V>();
}
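// Two-way merge of a sorted block (mrgSrc) into a sorted running list (mrgDst), keeping
// the best mrgDstNum records; tmpTensor must hold mrgDstNum + mrgSrcNum records.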
__aicore__ inline void MergeSort(const LocalTensor<float> &mrgDst, int32_t mrgDstNum, LocalTensor<float> &mrgSrc,
                                 int32_t mrgSrcNum, LocalTensor<float> &tmpTensor)
{
    AscendC::MrgSort4Info params;
    params.elementLengths[MRG_QUE_0] = mrgDstNum;
    params.elementLengths[MRG_QUE_1] = mrgSrcNum;
    params.ifExhaustedSuspension = false;
    params.validBit = 0b0011;
    params.repeatTimes = 1;
    AscendC::MrgSortSrcList<float> srcList;
    srcList.src1 = mrgDst;
    srcList.src2 = mrgSrc;
    AscendC::MrgSort<float>(tmpTensor, srcList, params);
    AscendC::PipeBarrier<PIPE_V>();
    AscendC::DataCopy(mrgDst, tmpTensor, mrgDstNum * VALUE_AND_INDEX_NUM);
    AscendC::PipeBarrier<PIPE_V>();
}
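// Merges up to four equally sized sorted blocks from src into dst with one MrgSort call;
// a single block is already sorted and is simply copied through.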
__aicore__ inline void MrgBasicBlock(const LocalTensor<float> &dst, const LocalTensor<float> &src, int64_t blockNum,
                                     int64_t basicBlockSize)
{
    AscendC::MrgSort4Info params;
    params.elementLengths[MRG_QUE_0] = basicBlockSize;
    params.elementLengths[MRG_QUE_1] = basicBlockSize;
    params.elementLengths[MRG_QUE_2] = basicBlockSize;
    params.elementLengths[MRG_QUE_3] = basicBlockSize;
    params.ifExhaustedSuspension = false;
    if (blockNum == MRG_BLOCK_2) {
        params.validBit = 0b0011;
    } else if (blockNum == MRG_BLOCK_3) {
        params.validBit = 0b0111;
    } else if (blockNum == MRG_BLOCK_4) {
        params.validBit = 0b1111;
    } else {
        AscendC::DataCopy(dst, src, basicBlockSize * VALUE_AND_INDEX_NUM);
        return;
    }
    AscendC::MrgSortSrcList<float> srcList;
    srcList.src1 = src[0];
    srcList.src2 = src[basicBlockSize * VALUE_AND_INDEX_NUM * MRG_QUE_1];
    srcList.src3 = src[basicBlockSize * VALUE_AND_INDEX_NUM * MRG_QUE_2];
    srcList.src4 = src[basicBlockSize * VALUE_AND_INDEX_NUM * MRG_QUE_3];
    AscendC::MrgSort<float>(dst, srcList, params);
}
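// Running top-k maintenance: merges a freshly sorted block (needsMerging, mergSize
// records) into the current top-k list in dst and keeps the best topk records. With
// needMrg == false (first block) the input simply becomes the initial list. Exhausted-
// queue suspension lets the merge stop as soon as one source queue runs out, and is
// enabled when topk == mergSize.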
template <bool needMrg = true>
__aicore__ inline void SparseTopK(const LocalTensor<float> &dst, const LocalTensor<float> &needsMerging,
                                  const LocalTensor<float> &tmp, int64_t topk, int64_t mergSize)
{
    if constexpr (!needMrg) {
        AscendC::DataCopy(dst, needsMerging, mergSize * VALUE_AND_INDEX_NUM);
        return;
    }
    AscendC::MrgSort4Info params;
    params.elementLengths[MRG_QUE_0] = topk;
    params.elementLengths[MRG_QUE_1] = mergSize;
    params.ifExhaustedSuspension = (topk == mergSize);
    params.validBit = 0b0011;
    AscendC::MrgSortSrcList<float> srcList;
    srcList.src1 = dst;
    srcList.src2 = needsMerging;
    AscendC::MrgSort<float>(tmp, srcList, params);
    AscendC::DataCopy(dst, tmp, topk * VALUE_AND_INDEX_NUM);
}
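// Compacts the index halves of interleaved (value, index) sort records into idxULocal;
// GatherMask built-in pattern 2 selects the odd 32-bit elements, which hold the indices.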
__aicore__ inline void ExtractIndex(const LocalTensor<uint32_t> &idxULocal, const LocalTensor<uint32_t> &sortLocal,
                                    int64_t extractNum)
{
    AscendC::GatherMaskParams gatherMaskParams;
    gatherMaskParams.repeatTimes = Ceil(extractNum * sizeof(float) * VALUE_AND_INDEX_NUM, VEC_REPEAT_BYTES);
    gatherMaskParams.src0BlockStride = 1;
    gatherMaskParams.src0RepeatStride = B32_VEC_REPEAT_STRIDE;
    gatherMaskParams.src1RepeatStride = 0;
    uint64_t rsvdCnt = 0;
    uint8_t src1Pattern = 2;
    AscendC::GatherMask(idxULocal, sortLocal, src1Pattern, false, static_cast<uint32_t>(0), gatherMaskParams, rsvdCnt);
    AscendC::PipeBarrier<PIPE_V>();
}
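// Convenience barrier: sets then immediately waits on the same event ID to order the
// two pipes named by the HardEvent (e.g. MTE2 -> V); the runtime argument evt is
// expected to match the compile-time template argument event.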
template <HardEvent event>
__aicore__ inline void SetWaitFlag(HardEvent evt)
{
    event_t eventId = static_cast<event_t>(GetTPipePtr()->FetchEventID(evt));
    AscendC::SetFlag<event>(eventId);
    AscendC::WaitFlag<event>(eventId);
}
} // namespace LIServiceVec
#endif // LIGHTNING_INDEXER_VECTOR_H