### What this PR does / why we need it?
Provide high-performance AscendC operators `lightning_indexer` and `sparse_flash_attention` to boost the execution performance of the DeepSeek V3.2 model, and adapt both operators to the vllm-ascend framework.

### Does this PR introduce _any_ user-facing change?
No (only underlying operator optimizations, with no user-facing changes).

### How was this patch tested?
- vLLM version: v0.11.2
- vLLM main: https://github.com/vllm-project/vllm/commit/v0.11.2

Signed-off-by: MingYang119 <songmingyang@huawei.com>
/**
 * This program is free software; you can redistribute it and/or modify it.
 * Copyright (c) 2025 Huawei Technologies Co., Ltd.
 * This file is a part of the CANN Open Software.
 * Licensed under CANN Open Software License Agreement Version 2.0 (the "License").
 * Please refer to the License for details. You may not use this file except in compliance with the License.
 * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED,
 * INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE.
 * See LICENSE in the root of the software repository for the full text of the License.
 */

/*!
 * \file lightning_indexer_vector.h
 * \brief Vector-core helper routines (copy, scale, reduce, sort, top-k merge) for the lightning_indexer operator.
 */
#ifndef LIGHTNING_INDEXER_VECTOR_H
#define LIGHTNING_INDEXER_VECTOR_H

#include "kernel_operator.h"
namespace LIServiceVec {
using namespace AscendC;

constexpr int32_t NEG_INF = 0xFF800000;       // bit pattern of float -inf, used to pad sort buffers
constexpr int32_t INVALID_INDEX = -1;         // index payload for padded (invalid) sort entries
constexpr uint8_t VEC_REPEAT_MAX = 255;       // maximum repeat count of a single vector instruction
constexpr uint8_t B32_VEC_ELM_NUM = 64;       // 32-bit elements processed per 256-byte vector repeat
constexpr uint8_t B32_BLOCK_ALIGN_NUM = 8;    // 32-bit elements per 32-byte block
constexpr uint8_t B32_VEC_REPEAT_STRIDE = 8;  // blocks covered by one vector repeat of 32-bit data
constexpr uint64_t VEC_REPEAT_BYTES = 256;
constexpr int32_t CONST_TWO = 2;
constexpr int64_t VALUE_AND_INDEX_NUM = 2;    // each sorted record is a (value, index) pair
constexpr int64_t BLOCK_BYTES = 32;
constexpr int64_t MRG_QUE_0 = 0;
constexpr int64_t MRG_QUE_1 = 1;
constexpr int64_t MRG_QUE_2 = 2;
constexpr int64_t MRG_QUE_3 = 3;
constexpr int64_t MRG_BLOCK_2 = 2;
constexpr int64_t MRG_BLOCK_3 = 3;
constexpr int64_t MRG_BLOCK_4 = 4;
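// Loads one [groupInner, s2Inner] tile of the matmul output and the matching
// groupInner per-row weight scales from global memory into unified buffer.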
template <typename T>
__aicore__ inline void CopyIn(LocalTensor<float> &mmOutUb, LocalTensor<T> &weightsUb, GlobalTensor<float> &mMoutGm,
                              GlobalTensor<T> &weightScaleGm, int64_t mmOutGmOffset, int64_t weightsGmOffset,
                              int64_t groupInner, int64_t s2Inner, int64_t mmUbStride)
{
    AscendC::DataCopyPadExtParams<float> padParams{false, 0, 0, 0};
    AscendC::DataCopyExtParams dataCopyMmOutParams;
    dataCopyMmOutParams.blockCount = groupInner;
    dataCopyMmOutParams.blockLen = s2Inner * sizeof(float);
    dataCopyMmOutParams.srcStride = 0;
    dataCopyMmOutParams.dstStride = mmUbStride;
    dataCopyMmOutParams.rsv = 0;
    AscendC::DataCopyPad(mmOutUb, mMoutGm[mmOutGmOffset], dataCopyMmOutParams, padParams);

    AscendC::DataCopyPadExtParams<T> padTParams{false, 0, 0, 0};
    AscendC::DataCopyExtParams dataCopyWeightParams;
    dataCopyWeightParams.blockCount = 1;
    dataCopyWeightParams.blockLen = groupInner * sizeof(T);
    dataCopyWeightParams.srcStride = 0;
    dataCopyWeightParams.dstStride = 0;
    dataCopyWeightParams.rsv = 0;
    AscendC::DataCopyPad(weightsUb, weightScaleGm[weightsGmOffset], dataCopyWeightParams, padTParams);
}

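// Writes copyCount contiguous elements from unified buffer back to global memory.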
template <typename T>
__aicore__ inline void CopyOut(const GlobalTensor<T> &dstGm, const LocalTensor<T> &srcUb, int64_t copyCount)
{
    AscendC::DataCopyParams dataCopyOutParams;
    dataCopyOutParams.blockCount = 1;
    dataCopyOutParams.blockLen = copyCount * sizeof(T);
    dataCopyOutParams.srcStride = 0;
    dataCopyOutParams.dstStride = 0;
    AscendC::DataCopyPad(dstGm, srcUb, dataCopyOutParams);
}

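// Scales each row of the [groupInner, s2Inner] matmul tile by its per-row
// weight. For the first group chunk (outerGidx == 0) the scaled rows are
// written into reduceCacheBuf; for later chunks they are accumulated into it.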
template <typename T>
__aicore__ inline void DoScale(const LocalTensor<float> &reduceCacheBuf, LocalTensor<float> &mmOutUb,
                               LocalTensor<float> &weightsUb, LocalTensor<T> &weightsTUb, LocalTensor<float> &tmpBuff,
                               int64_t groupInner, int64_t s2Inner, int32_t outerGidx)
{
    // cast bfloat16_t weights to float
    if constexpr (!IsSameType<T, float>::value) {
        AscendC::Cast(weightsUb, weightsTUb, RoundMode::CAST_NONE, groupInner);
        AscendC::PipeBarrier<PIPE_V>();
    }

    // weight broadcast: [groupInner, 1] -> [groupInner, 8]
    AscendC::Brcb(tmpBuff, weightsUb, LICommon::CeilDiv(groupInner, static_cast<int64_t>(B32_BLOCK_ALIGN_NUM)),
                  {1, B32_VEC_REPEAT_STRIDE});
    AscendC::PipeBarrier<PIPE_V>();

    // do scale: [groupInner, 8] * [groupInner, s2Inner]
    uint64_t countPerRepeat = VEC_REPEAT_BYTES / sizeof(float);
    uint64_t repeatTimes = s2Inner / countPerRepeat;
    for (int32_t i = 0; i < groupInner; i++) {
        if (outerGidx == 0) {
            AscendC::Mul(reduceCacheBuf[i * s2Inner], mmOutUb[i * s2Inner], tmpBuff[i * B32_BLOCK_ALIGN_NUM],
                         countPerRepeat, repeatTimes, {1, 1, 0, B32_VEC_REPEAT_STRIDE, B32_VEC_REPEAT_STRIDE, 0});
        } else {
            AscendC::Mul(mmOutUb[i * s2Inner], mmOutUb[i * s2Inner], tmpBuff[i * B32_BLOCK_ALIGN_NUM], countPerRepeat,
                         repeatTimes, {1, 1, 0, B32_VEC_REPEAT_STRIDE, B32_VEC_REPEAT_STRIDE, 0});
        }
    }

    if (outerGidx != 0) {
        AscendC::PipeBarrier<PIPE_V>();
        AscendC::Add(reduceCacheBuf, mmOutUb, reduceCacheBuf, groupInner * s2Inner);
    }
    AscendC::PipeBarrier<PIPE_V>();
}

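// Returns the largest power of two that is <= value (values 0..2 are returned
// unchanged), used to split a row count for the dichotomized reduction below.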
__aicore__ inline uint64_t FindNearestPower2(uint64_t value)
{
    if (value <= CONST_TWO) {
        return value;
    }
    const uint64_t pow = 63 - clz(value);
    return (static_cast<uint64_t>(1) << pow);
}

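// Sums rNum rows of aNum elements down to a single row in dstTensor via
// dichotomized (pairwise halving) adds: fold the non-power-of-two tail first,
// then halve the remaining power-of-two block until one row is left.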
__aicore__ inline void DoReduce(const LocalTensor<float> &srcTensor, LocalTensor<float> &dstTensor, int32_t rNum,
                                int32_t aNum)
{
    if (rNum == 1) {
        AscendC::Adds<float>(dstTensor, srcTensor, 0, aNum);
        AscendC::PipeBarrier<PIPE_V>();
        return;
    }

    uint32_t dichotomizeAddPow = FindNearestPower2(rNum);
    uint32_t dichotomizeAddDiffSize = rNum - dichotomizeAddPow;
    // fold the rows beyond the power-of-two boundary back into the front rows
    if (dichotomizeAddDiffSize != 0) {
        AscendC::Add(srcTensor, srcTensor, srcTensor[dichotomizeAddPow * aNum], dichotomizeAddDiffSize * aNum);
        AscendC::PipeBarrier<PIPE_V>();
    }
    int32_t nowRows = dichotomizeAddPow;
    while (nowRows > CONST_TWO) {
        nowRows = nowRows / CONST_TWO;
        AscendC::Add(srcTensor, srcTensor, srcTensor[nowRows * aNum], nowRows * aNum);
        AscendC::PipeBarrier<PIPE_V>();
    }
    AscendC::Add(dstTensor, srcTensor, srcTensor[aNum], aNum);
    AscendC::PipeBarrier<PIPE_V>();
}

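// Fills a sort output buffer with (value, index) = (-inf, -1) pairs so that
// padded slots always lose against real scores during merge sort.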
__aicore__ inline void InitSortOutBuf(const LocalTensor<float> &src, int64_t eleNum)
{
    uint64_t mask1[2] = {0x5555555555555555, 0};  // even lanes: value slots
    uint64_t mask0[2] = {0xaaaaaaaaaaaaaaaa, 0};  // odd lanes: index slots
    int64_t repeatNum = eleNum / B32_VEC_ELM_NUM;
    int64_t forLoop = repeatNum / VEC_REPEAT_MAX;
    int64_t forRemain = repeatNum % VEC_REPEAT_MAX;
    for (int i = 0; i < forLoop; i++) {
        AscendC::Duplicate(src.template ReinterpretCast<int32_t>()[i * VEC_REPEAT_MAX * B32_VEC_ELM_NUM], NEG_INF,
                           mask1, VEC_REPEAT_MAX, 1, B32_VEC_REPEAT_STRIDE);
        AscendC::Duplicate(src.template ReinterpretCast<int32_t>()[i * VEC_REPEAT_MAX * B32_VEC_ELM_NUM],
                           INVALID_INDEX, mask0, VEC_REPEAT_MAX, 1, B32_VEC_REPEAT_STRIDE);
    }
    if (forRemain > 0) {
        AscendC::Duplicate(src.template ReinterpretCast<int32_t>()[forLoop * VEC_REPEAT_MAX * B32_VEC_ELM_NUM], NEG_INF,
                           mask1, forRemain, 1, B32_VEC_REPEAT_STRIDE);
        AscendC::Duplicate(src.template ReinterpretCast<int32_t>()[forLoop * VEC_REPEAT_MAX * B32_VEC_ELM_NUM],
                           INVALID_INDEX, mask0, forRemain, 1, B32_VEC_REPEAT_STRIDE);
    }
    AscendC::PipeBarrier<PIPE_V>();
}

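// Full descending sort of logitsNum (value, index) records laid out as Sort32
// input (values in src, indices appended at src[logitsNum]). Sorts 32-element
// runs first, then merges runs four at a time, ping-ponging between src and
// tmp; the sorted result always ends up back in src.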
__aicore__ inline void SortAll(LocalTensor<float> &src, LocalTensor<float> &tmp, int64_t logitsNum)
{
    int64_t sort32Repeats = logitsNum / BLOCK_BYTES;
    AscendC::Sort32(tmp, src, src[logitsNum].ReinterpretCast<uint32_t>(), sort32Repeats);
    AscendC::PipeBarrier<PIPE_V>();

    // merge the sorted runs four at a time (the run count is expected to stay
    // divisible by four until no more than four runs remain)
    int64_t mrgGroups = sort32Repeats;
    int64_t mrgElements = BLOCK_BYTES;
    int64_t i = 0;
    AscendC::LocalTensor<float> srcTensor;
    AscendC::LocalTensor<float> dstTensor;
    while (true) {
        // ping-pong between tmp and src on every merge round
        if (i % CONST_TWO == 0) {
            srcTensor = tmp;
            dstTensor = src;
        } else {
            srcTensor = src;
            dstTensor = tmp;
        }
        AscendC::MrgSort4Info params;
        params.elementLengths[MRG_QUE_0] = mrgElements;
        params.elementLengths[MRG_QUE_1] = mrgElements;
        params.elementLengths[MRG_QUE_2] = mrgElements;
        params.elementLengths[MRG_QUE_3] = mrgElements;
        params.ifExhaustedSuspension = false;
        params.validBit = 0b1111;

        AscendC::MrgSortSrcList<float> srcList;
        srcList.src1 = srcTensor[0];
        srcList.src2 = srcTensor[MRG_QUE_1 * VALUE_AND_INDEX_NUM * mrgElements];
        srcList.src3 = srcTensor[MRG_QUE_2 * VALUE_AND_INDEX_NUM * mrgElements];
        srcList.src4 = srcTensor[MRG_QUE_3 * VALUE_AND_INDEX_NUM * mrgElements];
        if (mrgGroups <= MRG_BLOCK_4) {
            // final round: merge the remaining 1-4 runs in a single pass
            params.repeatTimes = 1;
            if (mrgGroups == 1) {
                break;
            } else if (mrgGroups == MRG_BLOCK_2) {
                params.validBit = 0b0011;
            } else if (mrgGroups == MRG_BLOCK_3) {
                params.validBit = 0b0111;
            } else if (mrgGroups == MRG_BLOCK_4) {
                params.validBit = 0b1111;
            }
            AscendC::MrgSort<float>(dstTensor, srcList, params);
            i += 1;
            break;
        } else {
            params.repeatTimes = mrgGroups / MRG_BLOCK_4;
            AscendC::MrgSort<float>(dstTensor, srcList, params);
            i += 1;
            mrgElements = mrgElements * MRG_BLOCK_4;
            mrgGroups = mrgGroups / MRG_BLOCK_4;
        }
        AscendC::PipeBarrier<PIPE_V>();
    }
    // an even number of merge rounds leaves the result in tmp; copy it back
    if (i % CONST_TWO == 0) {
        AscendC::DataCopy(src, tmp, logitsNum * VALUE_AND_INDEX_NUM);
        AscendC::PipeBarrier<PIPE_V>();
    }
}

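// Overload used when values and indices live in separate tensors: a single
// full-sort AscendC::Sort pass writes sorted (value, index) records into dst.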
__aicore__ inline void SortAll(LocalTensor<float> &dst, LocalTensor<float> &srcValue, LocalTensor<uint32_t> &srcIndex,
                               LocalTensor<float> &tmpTensor, int64_t logitsNum)
{
    int64_t sort32Repeats = logitsNum / BLOCK_BYTES;
    AscendC::Sort<float, true>(dst, srcValue, srcIndex, tmpTensor, sort32Repeats);
    AscendC::PipeBarrier<PIPE_V>();
}

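// Two-way merge of two sorted record lists; only the top mrgDstNum records of
// the merged result are copied back into mrgDst.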
__aicore__ inline void MergeSort(const LocalTensor<float> &mrgDst, int32_t mrgDstNum, LocalTensor<float> &mrgSrc,
                                 int32_t mrgSrcNum, LocalTensor<float> &tmpTensor)
{
    AscendC::MrgSort4Info params;
    params.elementLengths[MRG_QUE_0] = mrgDstNum;
    params.elementLengths[MRG_QUE_1] = mrgSrcNum;
    params.ifExhaustedSuspension = false;
    params.validBit = 0b0011;
    params.repeatTimes = 1;

    AscendC::MrgSortSrcList<float> srcList;
    srcList.src1 = mrgDst;
    srcList.src2 = mrgSrc;

    AscendC::MrgSort<float>(tmpTensor, srcList, params);
    AscendC::PipeBarrier<PIPE_V>();
    AscendC::DataCopy(mrgDst, tmpTensor, mrgDstNum * VALUE_AND_INDEX_NUM);
    AscendC::PipeBarrier<PIPE_V>();
}

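// Merges 2-4 consecutive sorted basic blocks of equal size into dst; with a
// single block there is nothing to merge and the records are copied through.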
__aicore__ inline void MrgBasicBlock(const LocalTensor<float> &dst, const LocalTensor<float> &src, int64_t blockNum,
                                     int64_t basicBlockSize)
{
    AscendC::MrgSort4Info params;
    params.elementLengths[MRG_QUE_0] = basicBlockSize;
    params.elementLengths[MRG_QUE_1] = basicBlockSize;
    params.elementLengths[MRG_QUE_2] = basicBlockSize;
    params.elementLengths[MRG_QUE_3] = basicBlockSize;
    params.ifExhaustedSuspension = false;
    if (blockNum == MRG_BLOCK_2) {
        params.validBit = 0b0011;
    } else if (blockNum == MRG_BLOCK_3) {
        params.validBit = 0b0111;
    } else if (blockNum == MRG_BLOCK_4) {
        params.validBit = 0b1111;
    } else {
        AscendC::DataCopy(dst, src, basicBlockSize * VALUE_AND_INDEX_NUM);
        return;
    }
    AscendC::MrgSortSrcList<float> srcList;
    srcList.src1 = src[0];
    srcList.src2 = src[basicBlockSize * VALUE_AND_INDEX_NUM * MRG_QUE_1];
    srcList.src3 = src[basicBlockSize * VALUE_AND_INDEX_NUM * MRG_QUE_2];
    srcList.src4 = src[basicBlockSize * VALUE_AND_INDEX_NUM * MRG_QUE_3];
    AscendC::MrgSort<float>(dst, srcList, params);
}

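// Maintains the running top-k: merges a freshly sorted candidate list into
// the current top-k held in dst. When needMrg is false (first batch) the
// candidates are copied in directly. Early exhaustion is allowed once both
// lists are exactly topk long, since only the first topk records matter.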
template <bool needMrg = true>
__aicore__ inline void SparseTopK(const LocalTensor<float> &dst, const LocalTensor<float> &needsMerging,
                                  const LocalTensor<float> &tmp, int64_t topk, int64_t mergSize)
{
    if constexpr (!needMrg) {
        AscendC::DataCopy(dst, needsMerging, mergSize * VALUE_AND_INDEX_NUM);
        return;
    }
    AscendC::MrgSort4Info params;
    params.elementLengths[MRG_QUE_0] = topk;
    params.elementLengths[MRG_QUE_1] = mergSize;
    params.ifExhaustedSuspension = (topk == mergSize);
    params.validBit = 0b0011;
    AscendC::MrgSortSrcList<float> srcList;
    srcList.src1 = dst;
    srcList.src2 = needsMerging;
    AscendC::MrgSort<float>(tmp, srcList, params);
    AscendC::PipeBarrier<PIPE_V>();
    AscendC::DataCopy(dst, tmp, topk * VALUE_AND_INDEX_NUM);
}

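// Extracts the index halves out of extractNum interleaved (value, index)
// records via GatherMask built-in pattern 2 (odd 32-bit lanes), packing the
// indices contiguously into idxULocal.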
__aicore__ inline void ExtractIndex(const LocalTensor<uint32_t> &idxULocal, const LocalTensor<uint32_t> &sortLocal,
                                    int64_t extractNum)
{
    AscendC::GatherMaskParams gatherMaskParams;
    gatherMaskParams.repeatTimes = Ceil(extractNum * sizeof(float) * VALUE_AND_INDEX_NUM, VEC_REPEAT_BYTES);
    gatherMaskParams.src0BlockStride = 1;
    gatherMaskParams.src0RepeatStride = B32_VEC_REPEAT_STRIDE;
    gatherMaskParams.src1RepeatStride = 0;
    uint64_t rsvdCnt = 0;
    uint8_t src1Pattern = 2;  // built-in pattern: gather the odd (index) lanes
    AscendC::GatherMask(idxULocal, sortLocal, src1Pattern, false, static_cast<uint32_t>(0), gatherMaskParams, rsvdCnt);
    AscendC::PipeBarrier<PIPE_V>();
}

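// Issues a set-flag/wait-flag pair on the same event so the calling pipe
// blocks until the hardware event fires; evt is expected to name the same
// event as the template parameter.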
template <HardEvent event>
__aicore__ inline void SetWaitFlag(HardEvent evt)
{
    event_t eventId = static_cast<event_t>(GetTPipePtr()->FetchEventID(evt));
    AscendC::SetFlag<event>(eventId);
    AscendC::WaitFlag<event>(eventId);
}

}  // namespace LIServiceVec

#endif  // LIGHTNING_INDEXER_VECTOR_H