/** * This program is free software, you can redistribute it and/or modify it. * Copyright (c) 2025 Huawei Technologies Co., Ltd. * This file is a part of the CANN Open Software. * Licensed under CANN Open Software License Agreement Version 2.0 (the "License"). * Please refer to the License for details. You may not use this file except in compliance with the License. * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. * See LICENSE in the root of the software repository for the full text of the License. */ /*! * \file lightning_indexer_vector.h * \brief */ #ifndef LIGHTNING_INDEXER_VECTOR_H #define LIGHTNING_INDEXER_VECTOR_H #include "lightning_indexer_vector.h" #include "kernel_operator.h" namespace LIServiceVec { using namespace AscendC; constexpr int32_t NEG_INF = 0xFF800000; constexpr int32_t INVALID_INDEX = -1; constexpr uint8_t VEC_REPEAT_MAX = 255; constexpr uint8_t B32_VEC_ELM_NUM = 64; constexpr uint8_t B32_BLOCK_ALIGN_NUM = 8; constexpr uint8_t B32_VEC_REPEAT_STRIDE = 8; constexpr uint64_t VEC_REPEAT_BYTES = 256; constexpr int32_t CONST_TWO = 2; constexpr int64_t VALUE_AND_INDEX_NUM = 2; constexpr int64_t BLOCK_BYTES = 32; constexpr int64_t MRG_QUE_0 = 0; constexpr int64_t MRG_QUE_1 = 1; constexpr int64_t MRG_QUE_2 = 2; constexpr int64_t MRG_QUE_3 = 3; constexpr int64_t MRG_BLOCK_2 = 2; constexpr int64_t MRG_BLOCK_3 = 3; constexpr int64_t MRG_BLOCK_4 = 4; template __aicore__ inline void CopyIn(LocalTensor &mmOutUb, LocalTensor &weightsUb, GlobalTensor &mMoutGm, GlobalTensor &weightScaleGm, int64_t MMout_gmoffset, int64_t weights_gmoffset, int64_t groupInner, int64_t s2Inner, int64_t mmUbStride) { AscendC::DataCopyPadExtParams padParams{false, 0, 0, 0}; AscendC::DataCopyExtParams dataCopymMoutParams; dataCopymMoutParams.blockCount = groupInner; dataCopymMoutParams.blockLen = s2Inner * sizeof(float); dataCopymMoutParams.srcStride = 0; dataCopymMoutParams.dstStride = mmUbStride; dataCopymMoutParams.rsv = 0; AscendC::DataCopyPad(mmOutUb, mMoutGm[MMout_gmoffset], dataCopymMoutParams, padParams); AscendC::DataCopyPadExtParams padTParams{false, 0, 0, 0}; AscendC::DataCopyExtParams dataCopyweightParams; dataCopyweightParams.blockCount = 1; dataCopyweightParams.blockLen = groupInner * sizeof(T); dataCopyweightParams.srcStride = 0; dataCopyweightParams.dstStride = 0; dataCopyweightParams.rsv = 0; AscendC::DataCopyPad(weightsUb, weightScaleGm[weights_gmoffset], dataCopyweightParams, padTParams); } template __aicore__ inline void CopyOut(const GlobalTensor &dstGm, const LocalTensor &srcUb, int64_t copyCount) { AscendC::DataCopyParams dataCopyOutyParams; dataCopyOutyParams.blockCount = 1; dataCopyOutyParams.blockLen = copyCount * sizeof(T); dataCopyOutyParams.srcStride = 0; dataCopyOutyParams.dstStride = 0; AscendC::DataCopyPad(dstGm, srcUb, dataCopyOutyParams); } template __aicore__ inline void DoScale(const LocalTensor &reduceCacheBuf, LocalTensor &mmOutUb, LocalTensor &weightsUb, LocalTensor &weightsTUb, LocalTensor &tmpBuff, int64_t groupInner, int64_t s2Inner, int32_t outerGidx) { // cast bfloat16_t to float if constexpr (!IsSameType::value) { AscendC::Cast(weightsUb, weightsTUb, RoundMode::CAST_NONE, groupInner); AscendC::PipeBarrier(); } // weight broadcast: [groupInner, 1] -> [groupInner, 8] AscendC::Brcb(tmpBuff, weightsUb, LICommon::CeilDiv(groupInner, static_cast(B32_BLOCK_ALIGN_NUM)), {1, B32_VEC_REPEAT_STRIDE}); AscendC::PipeBarrier(); // do scale: [groupInner, 8] * [groupInner, s2Inner] uint64_t countPerRepeat = VEC_REPEAT_BYTES / sizeof(float); uint64_t repeatTimes = s2Inner / countPerRepeat; for (int32_t i = 0; i < groupInner; i++) { if (outerGidx == 0) { AscendC::Mul(reduceCacheBuf[i * s2Inner], mmOutUb[i * s2Inner], tmpBuff[i * B32_BLOCK_ALIGN_NUM], countPerRepeat, repeatTimes, {1, 1, 0, B32_VEC_REPEAT_STRIDE, B32_VEC_REPEAT_STRIDE, 0}); } else { AscendC::Mul(mmOutUb[i * s2Inner], mmOutUb[i * s2Inner], tmpBuff[i * B32_BLOCK_ALIGN_NUM], countPerRepeat, repeatTimes, {1, 1, 0, B32_VEC_REPEAT_STRIDE, B32_VEC_REPEAT_STRIDE, 0}); } } if (outerGidx != 0) { AscendC::PipeBarrier(); AscendC::Add(reduceCacheBuf, mmOutUb, reduceCacheBuf, groupInner * s2Inner); } AscendC::PipeBarrier(); } __aicore__ inline uint64_t FindNearestPower2(uint64_t value) { if (value <= CONST_TWO) { return value; } else { const uint64_t pow = 63 - clz(value); return (1 << pow); } } __aicore__ inline void DoReduce(const LocalTensor &srcTensor, LocalTensor &dstTensor, int32_t rNum, int32_t aNum) { if (rNum == 1) { AscendC::Adds(dstTensor, srcTensor, 0, aNum); AscendC::PipeBarrier(); return; } uint32_t dichotomizeAddPow = FindNearestPower2(rNum); uint32_t dichotomizeAddDiffSize = rNum - dichotomizeAddPow; if (dichotomizeAddDiffSize != 0) { AscendC::Add(srcTensor, srcTensor, srcTensor[dichotomizeAddPow * aNum], dichotomizeAddDiffSize * aNum); AscendC::PipeBarrier(); } int32_t nowRows = dichotomizeAddPow; while (nowRows > CONST_TWO) { nowRows = nowRows / CONST_TWO; AscendC::Add(srcTensor, srcTensor, srcTensor[nowRows * aNum], nowRows * aNum); AscendC::PipeBarrier(); } AscendC::Add(dstTensor, srcTensor, srcTensor[aNum], aNum); AscendC::PipeBarrier(); } __aicore__ inline void InitSortOutBuf(const LocalTensor &src, int64_t eleNum) { uint64_t mask1[2] = {0x5555555555555555, 0}; uint64_t mask0[2] = {0xaaaaaaaaaaaaaaaa, 0}; int64_t repeatNum = eleNum / B32_VEC_ELM_NUM; int64_t forLoop = repeatNum / VEC_REPEAT_MAX; int64_t forRemain = repeatNum % VEC_REPEAT_MAX; for (int i = 0; i < forLoop; i++) { AscendC::Duplicate(src.template ReinterpretCast(), NEG_INF, mask1, VEC_REPEAT_MAX, 1, B32_VEC_REPEAT_STRIDE); AscendC::Duplicate(src.template ReinterpretCast(), INVALID_INDEX, mask0, VEC_REPEAT_MAX, 1, B32_VEC_REPEAT_STRIDE); } if (forRemain > 0) { AscendC::Duplicate(src.template ReinterpretCast()[forLoop * VEC_REPEAT_MAX * B32_VEC_ELM_NUM], NEG_INF, mask1, forRemain, 1, B32_VEC_REPEAT_STRIDE); AscendC::Duplicate(src.template ReinterpretCast()[forLoop * VEC_REPEAT_MAX * B32_VEC_ELM_NUM], INVALID_INDEX, mask0, forRemain, 1, B32_VEC_REPEAT_STRIDE); } AscendC::PipeBarrier(); } __aicore__ inline void SortAll(LocalTensor &src, LocalTensor &tmp, int64_t logitsNum) { int64_t sort32Repeats = logitsNum / BLOCK_BYTES; AscendC::Sort32(tmp, src, src[logitsNum].ReinterpretCast(), sort32Repeats); AscendC::PipeBarrier(); int64_t mrgGroups = sort32Repeats; int64_t mrgElements = BLOCK_BYTES; int64_t i = 0; AscendC::LocalTensor srcTensor; AscendC::LocalTensor dstTensor; while (true) { if (i % CONST_TWO == 0) { srcTensor = tmp; dstTensor = src; } else { srcTensor = src; dstTensor = tmp; } AscendC::MrgSort4Info params; params.elementLengths[0] = mrgElements; params.elementLengths[MRG_QUE_1] = mrgElements; params.elementLengths[MRG_QUE_2] = mrgElements; params.elementLengths[MRG_QUE_3] = mrgElements; params.ifExhaustedSuspension = false; params.validBit = 0b1111; AscendC::MrgSortSrcList srcList; srcList.src1 = srcTensor[0]; srcList.src2 = srcTensor[MRG_QUE_1 * VALUE_AND_INDEX_NUM * mrgElements]; srcList.src3 = srcTensor[MRG_QUE_2 * VALUE_AND_INDEX_NUM * mrgElements]; srcList.src4 = srcTensor[MRG_QUE_3 * VALUE_AND_INDEX_NUM * mrgElements]; if (mrgGroups <= MRG_BLOCK_4) { params.repeatTimes = 1; if (mrgGroups == 1) { break; } else if (mrgGroups == MRG_BLOCK_2) { params.validBit = 0b0011; } else if (mrgGroups == MRG_BLOCK_3) { params.validBit = 0b0111; } else if (mrgGroups == MRG_BLOCK_4) { params.validBit = 0b1111; } AscendC::MrgSort(dstTensor, srcList, params); i += 1; break; } else { params.repeatTimes = mrgGroups / MRG_BLOCK_4; AscendC::MrgSort(dstTensor, srcList, params); i += 1; mrgElements = mrgElements * MRG_BLOCK_4; mrgGroups = mrgGroups / MRG_BLOCK_4; } AscendC::PipeBarrier(); } if (i % CONST_TWO == 0) { AscendC::DataCopy(src, tmp, logitsNum * VALUE_AND_INDEX_NUM); AscendC::PipeBarrier(); } } __aicore__ inline void SortAll(LocalTensor &dst, LocalTensor &srcValue, LocalTensor &srcIndex, LocalTensor &tmpTensor, int64_t logitsNum) { int64_t sort32Repeats = logitsNum / BLOCK_BYTES; AscendC::Sort(dst, srcValue, srcIndex, tmpTensor, sort32Repeats); AscendC::PipeBarrier(); } __aicore__ inline void MergeSort(const LocalTensor &mrgDst, int32_t mrgDstNum, LocalTensor &mrgSrc, int32_t mrgSrcNum, LocalTensor &tmpTensor) { AscendC::MrgSort4Info params; params.elementLengths[0] = mrgDstNum; params.elementLengths[1] = mrgSrcNum; params.ifExhaustedSuspension = false; params.validBit = 0b0011; params.repeatTimes = 1; AscendC::MrgSortSrcList srcList; srcList.src1 = mrgDst; srcList.src2 = mrgSrc; AscendC::MrgSort(tmpTensor, srcList, params); AscendC::PipeBarrier(); AscendC::DataCopy(mrgDst, tmpTensor, mrgDstNum * VALUE_AND_INDEX_NUM); AscendC::PipeBarrier(); } __aicore__ inline void MrgBasicBlock(const LocalTensor &dst, const LocalTensor &src, int64_t blockNum, int64_t basicBlockSize) { AscendC::MrgSort4Info params; params.elementLengths[MRG_QUE_0] = basicBlockSize; params.elementLengths[MRG_QUE_1] = basicBlockSize; params.elementLengths[MRG_QUE_2] = basicBlockSize; params.elementLengths[MRG_QUE_3] = basicBlockSize; params.ifExhaustedSuspension = false; if (blockNum == MRG_BLOCK_2) { params.validBit = 0b0011; } else if (blockNum == MRG_BLOCK_3) { params.validBit = 0b0111; } else if (blockNum == MRG_BLOCK_4) { params.validBit = 0b1111; } else { AscendC::DataCopy(dst, src, basicBlockSize * VALUE_AND_INDEX_NUM); return; } AscendC::MrgSortSrcList srcList; srcList.src1 = src[0]; srcList.src2 = src[basicBlockSize * VALUE_AND_INDEX_NUM * MRG_QUE_1]; srcList.src3 = src[basicBlockSize * VALUE_AND_INDEX_NUM * MRG_QUE_2]; srcList.src4 = src[basicBlockSize * VALUE_AND_INDEX_NUM * MRG_QUE_3]; AscendC::MrgSort(dst, srcList, params); } template __aicore__ inline void SparseTopK(const LocalTensor &dst, const LocalTensor &needsMerging, const LocalTensor &tmp, int64_t topk, int64_t mergSize) { if (!needMrg) { AscendC::DataCopy(dst, needsMerging, mergSize * VALUE_AND_INDEX_NUM); return; } AscendC::MrgSort4Info params; params.elementLengths[0] = topk; params.elementLengths[1] = mergSize; params.ifExhaustedSuspension = (topk == mergSize); params.validBit = 0b0011; AscendC::MrgSortSrcList srcList; srcList.src1 = dst; srcList.src2 = needsMerging; AscendC::MrgSort(tmp, srcList, params); AscendC::DataCopy(dst, tmp, topk * VALUE_AND_INDEX_NUM); } __aicore__ inline void ExtractIndex(const LocalTensor &idxULocal, const LocalTensor &sortLocal, int64_t extractNum) { AscendC::GatherMaskParams gatherMaskParams; gatherMaskParams.repeatTimes = Ceil(extractNum * sizeof(float) * VALUE_AND_INDEX_NUM, VEC_REPEAT_BYTES); gatherMaskParams.src0BlockStride = 1; gatherMaskParams.src0RepeatStride = B32_VEC_REPEAT_STRIDE; gatherMaskParams.src1RepeatStride = 0; uint64_t rsvdCnt = 0; uint8_t src1Pattern = 2; AscendC::GatherMask(idxULocal, sortLocal, src1Pattern, false, static_cast(0), gatherMaskParams, rsvdCnt); AscendC::PipeBarrier(); } template __aicore__ inline void SetWaitFlag(HardEvent evt) { event_t eventId = static_cast(GetTPipePtr()->FetchEventID(evt)); AscendC::SetFlag(eventId); AscendC::WaitFlag(eventId); } } // namespace LIServiceVec #endif // LIGHTNING_INDEXER_VECTOR_H