// Adapted from // https://gitee.com/ascend/ascend-transformer-boost // // Copyright (c) Huawei Technologies Co., Ltd. 2025. All rights reserved. // This file is a part of the CANN Open Software. // Licensed under CANN Open Software License Agreement Version 1.0 (the "License"). // Please refer to the License for details. You may not use this file except in compliance with the License. // THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, // INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. // See LICENSE in the root of the software repository for the full text of the License. // #include "kernel/common.h" #include "kernel/iterator.h" #include "kernel/mem.h" #include "kernel/mma.h" #include "kernel/utils.h" #include "kernel/simd.h" #include "kernel/kernel_utils.h" #include "lib/matmul_intf.h" #include "mla_preprocess.h" #include "../op_host/tiling/mla_preprocess_tiling.h" namespace MLAPO_BF16 { template class RopeFp16 { public: __aicore__ inline RopeFp16() : blockIdx_(AscendC::GetBlockIdx()) {} __aicore__ inline void RopeInit(GM_ADDR qGm, AscendC::GlobalTensor &cosGm, AscendC::GlobalTensor &sinGm, AscendC::GlobalTensor &outRopeConcatGm, AscendC::GlobalTensor &outRopeConcatGm2, MlaTilingData &ropeConcatParams) { qGm_.SetGlobalBuffer(reinterpret_cast<__gm__ QkDtype *>(qGm)); this->cosGm_ = cosGm; this->sinGm_ = sinGm; this->outRopeConcatGm_ = outRopeConcatGm; this->outRopeConcatGm2_ = outRopeConcatGm2; headDim = ropeConcatParams.headDim; headNumQ = ropeConcatParams.headNumQ; rotaryCoeff = ropeConcatParams.rotaryCoeff; ntokens = ropeConcatParams.ntokens; realCore = ropeConcatParams.realCore; nlCoreRun = ropeConcatParams.nlCoreRun; lCoreRun = ropeConcatParams.lCoreRun; maxNPerLoopForUb = ropeConcatParams.maxNPerLoopForUb; preCoreLoopTime = ropeConcatParams.preCoreLoopTime; preCoreLoopNLast = ropeConcatParams.preCoreLoopNLast; lastCoreLoopTime = 
ropeConcatParams.lastCoreLoopTime; lastCoreLoopNLast = ropeConcatParams.lastCoreLoopNLast; concatSize = ropeConcatParams.concatSize; blockIdx_ = (blockIdx_ / 2) * 2 + static_cast(GetSubBlockidx()); loopTime = (blockIdx_ == realCore - 1) ? lastCoreLoopTime : preCoreLoopTime; lastLoopN = (blockIdx_ == realCore - 1) ? lastCoreLoopNLast : preCoreLoopNLast; this->repeatSize_ = 64; // 128 = 256B / sizeof(fp32) this->rotateStride_ = this->headDim / this->rotaryCoeff; headBlockLen = static_cast(this->headDim / ELE_NUM_FP16); headBlockLenFP32 = static_cast(this->headDim / ELE_NUM_FP32); rotaryLen = static_cast(this->rotateStride_ / ELE_NUM_FP32); concatBlockLen = static_cast(this->concatSize / ELE_NUM_FP16); outLineOffset = this->headDim + this->concatSize; uint32_t dataNum = this->headDim * this->maxNPerLoopForUb; dataSizeFp16 = dataNum * sizeof(QkDtype); dataSizeFp32 = dataNum * sizeof(float); uint32_t concatDataSize = this->concatSize * sizeof(QkDtype) * this->maxNPerLoopForUb; } __aicore__ inline void Process() { if (blockIdx_ >= realCore) { return; } uint64_t startCoreLineIndex = this->blockIdx_ * this->nlCoreRun; // [maxNPerLoopForUb,head_dim] 的 neg AscendC::LocalTensor negLocal = buf.GetBuffer(dataSizeFp32 * 4 + dataSizeFp16 * 3); ExpandNeg(negLocal, this->maxNPerLoopForUb); SET_FLAG(MTE3, MTE2, EVENT_ID1); for (uint32_t zz = 0; zz < this->loopTime; ++zz) { uint16_t loopN = (zz == this->loopTime - 1) ? 
this->lastLoopN : this->maxNPerLoopForUb; uint64_t startHead = startCoreLineIndex + zz * this->maxNPerLoopForUb; uint64_t endHead = startHead + loopN; // move in Q AscendC::LocalTensor inputQ = buf.GetBuffer(0); AscendC::LocalTensor inputQCastFP32 = buf.GetBuffer(dataSizeFp16); AscendC::LocalTensor reverseQ = buf.GetBuffer(dataSizeFp32 + dataSizeFp16); uint64_t qOffset = startHead * 192 + 128; CopyQGenReverseQ(inputQ, inputQCastFP32, reverseQ, qOffset, loopN); // move in cos/sin AscendC::LocalTensor inputCos = buf.GetBuffer(dataSizeFp32 * 2 + dataSizeFp16); AscendC::LocalTensor inputSin = buf.GetBuffer(dataSizeFp32 * 2 + dataSizeFp16 * 2); uint64_t startSinCosHeadIndex = startHead; uint64_t headRemain = startHead % this->headNumQ; uint64_t localStartAddr = 0; if (headRemain != 0) { uint64_t preProcessHeadNum = this->headNumQ - headRemain; uint64_t needToProcesHead = preProcessHeadNum > loopN ? loopN : preProcessHeadNum; CopyCosSin(inputCos, inputSin, localStartAddr, (startSinCosHeadIndex / this->headNumQ) * this->headDim, needToProcesHead); startSinCosHeadIndex += needToProcesHead; localStartAddr += needToProcesHead * this->headDim; } if (startSinCosHeadIndex < endHead) { uint64_t startSinCosIndex = startSinCosHeadIndex / this->headNumQ; uint64_t endSinCosIndex = (endHead + this->headNumQ - 1) / this->headNumQ; for (uint32_t index = startSinCosIndex; index < endSinCosIndex; ++index) { uint32_t repeatNum = index == endSinCosIndex - 1 ? 
endHead - index * this->headNumQ : this->headNumQ; CopyCosSin(inputCos, inputSin, localStartAddr, index * this->headDim, repeatNum); localStartAddr += this->headDim * this->headNumQ; } } AscendC::LocalTensor inputCosCastFP32 = buf.GetBuffer(dataSizeFp32 * 2 + dataSizeFp16 * 3); AscendC::LocalTensor inputSinCastFP32 = buf.GetBuffer(dataSizeFp32 * 3 + dataSizeFp16 * 3); AscendC::Cast(inputCosCastFP32, inputCos, AscendC::RoundMode::CAST_NONE, loopN * this->headDim); AscendC::Cast(inputSinCastFP32, inputSin, AscendC::RoundMode::CAST_NONE, loopN * this->headDim); AscendC::PipeBarrier(); uint32_t repeatTime = this->headDim * loopN; AscendC::Mul(inputQCastFP32, inputCosCastFP32, inputQCastFP32, repeatTime); AscendC::Mul(reverseQ, negLocal, reverseQ, repeatTime); AscendC::PipeBarrier(); AscendC::Mul(reverseQ, inputSinCastFP32, reverseQ, repeatTime); AscendC::PipeBarrier(); AscendC::Add(inputQCastFP32, reverseQ, inputQCastFP32, repeatTime); AscendC::PipeBarrier(); AscendC::Cast(inputQ, inputQCastFP32, AscendC::RoundMode::CAST_RINT, loopN * this->headDim); AscendC::PipeBarrier(); uint64_t outQOffset = startHead * outLineOffset + this->concatSize; uint64_t outQOffset2 = startHead * this->headDim; SET_FLAG(V, MTE3, EVENT_ID1); WAIT_FLAG(V, MTE3, EVENT_ID1); if constexpr (CacheMode == CACHE_MODE_KVCACHE) { AscendC::DataCopy(this->outRopeConcatGm_[outQOffset], inputQ, {loopN, headBlockLen, 0, concatBlockLen}); } else { AscendC::DataCopy(this->outRopeConcatGm2_[outQOffset2], inputQ, loopN * this->headDim); } SET_FLAG(MTE3, MTE2, EVENT_ID1); } WAIT_FLAG(MTE3, MTE2, EVENT_ID1); } // tensor -1 -1 -1 1 1 1 template __aicore__ inline void ExpandNeg(const AscendC::LocalTensor &tempBuf, uint32_t headNumTemp) { for (uint32_t i = 0; i < this->rotateStride_; ++i) { tempBuf.SetValue(i, (BUF_TYPE)-1); tempBuf.SetValue(i + this->rotateStride_, (BUF_TYPE)1); } AscendC::SetFlag(EVENT_ID1); AscendC::WaitFlag(EVENT_ID1); AscendC::Copy(tempBuf[this->headDim], tempBuf, this->headDim, headNumTemp - 
1, {1, 1, headBlockLenFP32, 0}); AscendC::PipeBarrier(); } template __aicore__ inline void CopyQGenReverseQ(const AscendC::LocalTensor &tempBufQ, const AscendC::LocalTensor &tempBufQCast, const AscendC::LocalTensor &tempBufRverseQ, uint64_t qOffset, uint16_t loopN) { // move in Q WAIT_FLAG(MTE3, MTE2, EVENT_ID1); AscendC::DataCopy(tempBufQ, this->qGm_[qOffset], {loopN, headBlockLen, 128 / 16, 0}); SET_FLAG(MTE2, V, EVENT_ID1); WAIT_FLAG(MTE2, V, EVENT_ID1); // cast fp32 AscendC::Cast(tempBufQCast, tempBufQ, AscendC::RoundMode::CAST_NONE, loopN * this->headDim); AscendC::PipeBarrier(); // move out reverseQ AscendC::DataCopy(tempBufRverseQ, tempBufQCast[this->rotateStride_], {loopN, rotaryLen, rotaryLen, rotaryLen}); AscendC::DataCopy(tempBufRverseQ[this->rotateStride_], tempBufQCast, {loopN, rotaryLen, rotaryLen, rotaryLen}); AscendC::PipeBarrier(); } template __aicore__ inline void CopyCosSin(const AscendC::LocalTensor &tempBufCos, const AscendC::LocalTensor &tempBufSin, uint64_t localStartAddr, uint64_t gmStartAddr, uint64_t repeatNum) { AscendC::DataCopy(tempBufCos[localStartAddr], this->cosGm_[gmStartAddr], {1, headBlockLen, 0, 0}); AscendC::DataCopy(tempBufSin[localStartAddr], this->sinGm_[gmStartAddr], {1, headBlockLen, 0, 0}); SET_FLAG(MTE2, V, EVENT_ID1); WAIT_FLAG(MTE2, V, EVENT_ID1); AscendC::Copy(tempBufCos[localStartAddr + this->headDim], tempBufCos[localStartAddr], this->headDim, repeatNum - 1, {1, 1, headBlockLen, 0}); AscendC::Copy(tempBufSin[localStartAddr + this->headDim], tempBufSin[localStartAddr], this->headDim, repeatNum - 1, {1, 1, headBlockLen, 0}); AscendC::PipeBarrier(); } private: AsdopsBuffer buf; AscendC::GlobalTensor qGm_; AscendC::GlobalTensor cosGm_; AscendC::GlobalTensor sinGm_; AscendC::GlobalTensor outRopeConcatGm_; AscendC::GlobalTensor outRopeConcatGm2_; uint32_t repeatSize_{0}; uint32_t rotateStride_{0}; // this->headDim / rope conf uint32_t headDim; uint32_t headNumQ; uint32_t rotaryCoeff; uint32_t ntokens; uint32_t realCore; 
uint32_t nlCoreRun; uint32_t lCoreRun; uint32_t maxNPerLoopForUb; uint32_t preCoreLoopTime; uint32_t preCoreLoopNLast; uint32_t lastCoreLoopTime; uint32_t lastCoreLoopNLast; uint32_t concatSize; uint32_t blockIdx_; uint32_t loopTime{0}; uint32_t lastLoopN{0}; uint32_t dataSizeFp32; uint32_t dataSizeFp16; uint16_t headBlockLen{0}; uint16_t headBlockLenFP32{0}; uint16_t rotaryLen{0}; uint16_t concatBlockLen{0}; uint64_t outLineOffset{0}; }; __aicore__ inline void ReduceSumCustom(const AscendC::LocalTensor &dst_local, const AscendC::LocalTensor &src_local, const AscendC::LocalTensor &work_local, int32_t count) { #ifdef __DAV_C220_VEC__ uint64_t mask = NUM_PER_REP_FP32; int32_t repeatTimes = count / NUM_PER_REP_FP32; int32_t tailCount = count % NUM_PER_REP_FP32; int32_t bodyCount = repeatTimes * NUM_PER_REP_FP32; AscendC::BinaryRepeatParams repeatParams; repeatParams.src0RepStride = AscendC::ONE_REPEAT_BYTE_SIZE / AscendC::ONE_BLK_SIZE; repeatParams.src0BlkStride = 1; repeatParams.src1RepStride = 0; repeatParams.src1BlkStride = 1; repeatParams.dstRepStride = 0; repeatParams.dstBlkStride = 1; Duplicate(work_local, ZERO, NUM_PER_REP_FP32); AscendC::PipeBarrier(); if (likely(repeatTimes > 0)) { Add(work_local, src_local, work_local, mask, repeatTimes, repeatParams); AscendC::PipeBarrier(); } if (unlikely(tailCount != 0)) { Add(work_local, src_local[bodyCount], work_local, tailCount, 1, repeatParams); AscendC::PipeBarrier(); } AscendC::AscendCUtils::SetMask(NUM_PER_REP_FP32); cadd_v(dst_local, // dst work_local, // src 1, // repeat 0, // dstRepeatStride 1, // srcBlockStride 0); // srcRepeatStride AscendC::PipeBarrier(); #endif } template class Quant { public: __aicore__ inline Quant() {} __aicore__ inline void Init(AscendC::GlobalTensor &quantScaleGmTensor, AscendC::GlobalTensor &quantOffsetGmTensor, GM_ADDR perTokenDescaleGm, GM_ADDR perChannelDescaleGm, GM_ADDR gmInput, GM_ADDR gmOutput, uint32_t stride, uint32_t num_col, uint64_t gm_offset, uint64_t gm_out_offset, 
uint32_t row_work_, const MlaTilingData &mlaParams_) { this->quantScaleGmTensor = quantScaleGmTensor; this->quantOffsetGmTensor = quantOffsetGmTensor; this->perTokenDescaleGmTensor.SetGlobalBuffer(reinterpret_cast<__gm__ float *>(perTokenDescaleGm)); this->perChannelDescaleGmTensor.SetGlobalBuffer(reinterpret_cast<__gm__ float *>(perChannelDescaleGm)); if constexpr (!NEED_DEQUANT) { inputGmTensor.SetGlobalBuffer(reinterpret_cast<__gm__ T *>(gmInput)); } else { mmGmTensor.SetGlobalBuffer(reinterpret_cast<__gm__ int32_t *>(gmInput)); } outputGmTensor.SetGlobalBuffer(reinterpret_cast<__gm__ int8_t *>(gmOutput)); num_col_ = num_col; quantMin_ = -128; this->num_row_ = mlaParams_.n; this->row_work = row_work; this->row_work_ = row_work_; gm_offset_ = gm_offset; gm_out_offset_ = gm_out_offset; num_col_align_int8 = (num_col_ + REPEAT_TIME_256 - 1) / REPEAT_TIME_256 * REPEAT_TIME_256; num_col_align_f16 = (num_col_ + REPEAT_TIME_128 - 1) / REPEAT_TIME_128 * REPEAT_TIME_128; num_col_align_f32 = (num_col_ + REPEAT_TIME_64 - 1) / REPEAT_TIME_64 * REPEAT_TIME_64; input_stride_ = stride; num_col_align_withStride_int8 = (num_col_ - input_stride_ + REPEAT_TIME_256 - 1) / REPEAT_TIME_256 * REPEAT_TIME_256; num_col_align_withStride_fp16 = (num_col_ - input_stride_ + REPEAT_TIME_128 - 1) / REPEAT_TIME_128 * REPEAT_TIME_128; num_col_align_withStride_fp32 = (num_col_ - input_stride_ + REPEAT_TIME_64 - 1) / REPEAT_TIME_64 * REPEAT_TIME_64; } __aicore__ inline void Launch(const AscendC::LocalTensor &dstTensor, const AscendC::LocalTensor &srcTensor, const AscendC::LocalTensor &quantScaleTensor, const AscendC::LocalTensor &quantOffsetTensor, const AscendC::LocalTensor &res1Tensor, const AscendC::LocalTensor &res3Tensor) { this->dstTensor = dstTensor; this->srcTensor = srcTensor; this->fp32_xy = res1Tensor; this->buf = res3Tensor; AscendC::LocalTensor g = buf[OFFSET_GAMMA * num_col_align_withStride_fp32]; // 0 AscendC::LocalTensor sqx = buf[OFFSET_SQX * num_col_align_withStride_fp32]; // 1 
AscendC::LocalTensor work = buf[OFFSET_SUM * num_col_align_withStride_fp32]; // 2 AscendC::LocalTensor abs = buf[OFFSET_ABS * num_col_align_withStride_fp32]; // 3 AscendC::LocalTensor sum = buf[OFFSET_WORKSPACE_BF16 * num_col_align_withStride_fp32]; // 4 AscendC::LocalTensor max = buf[OFFSET_WORKSPACE_BF16 * num_col_align_withStride_fp32 + 8]; // 5 AscendC::LocalTensor perTokenDescaleTensor = buf[OFFSET_WORKSPACE_BF16 * num_col_align_withStride_fp32 + 16]; // 6 SET_FLAG(MTE2, V, EVENT_ID1); if constexpr (quantMode == QuantMode::PER_TENSOR_ASYMM_QUANT) { AscendC::DataCopy(quantScaleTensor, quantScaleGmTensor, AscendC::DataCopyParams(1, 1, 0, 0)); AscendC::DataCopy(quantOffsetTensor, quantOffsetGmTensor, AscendC::DataCopyParams(1, 1, 0, 0)); } if constexpr (NEED_DEQUANT) { mmTensor = buf.ReinterpretCast()[OFFSET_WORKSPACE_BF16 * num_col_align_withStride_fp32 + 16]; deScaleTensor = buf[OFFSET_WORKSPACE_BF16 * num_col_align_withStride_fp32 + 16 + MM1_OUT_SIZE]; perTokenDescaleTensor = buf[OFFSET_WORKSPACE_BF16 * num_col_align_withStride_fp32 + 16 + MM1_OUT_SIZE * 2]; AscendC::DataCopy(deScaleTensor, perChannelDescaleGmTensor, AscendC::DataCopyParams(1, num_col_ / 8, 0, 0)); } if constexpr (quantMode == QuantMode::PER_TENSOR_ASYMM_QUANT) { if (std::is_same::value) { SET_FLAG(MTE2, V, EVENT_ID0); WAIT_FLAG(MTE2, V, EVENT_ID0); Cast(g, quantScaleTensor, AscendC::RoundMode::CAST_NONE, 1); AscendC::SetFlag(EVENT_ID0); AscendC::WaitFlag(EVENT_ID0); input_scale_ = 1 / (float)(g.GetValue(0)); input_offset_ = (float)(quantOffsetTensor.GetValue(0)); } else { SET_FLAG(MTE2, S, EVENT_ID0); WAIT_FLAG(MTE2, S, EVENT_ID0); input_scale_ = 1 / (float)(quantScaleTensor.GetValue(0)); input_offset_ = (float)(quantOffsetTensor.GetValue(0)); } AscendC::SetFlag(EVENT_ID0); AscendC::WaitFlag(EVENT_ID0); } WAIT_FLAG(MTE2, V, EVENT_ID1); uint64_t pid = 0; SET_FLAG(MTE3, MTE2, EVENT_ID0); while (pid < row_work_) { uint64_t offset = pid * num_col_; uint64_t outOffset = pid * (num_col_ - 
input_stride_); WAIT_FLAG(MTE3, MTE2, EVENT_ID0); if constexpr (!NEED_DEQUANT) { AscendC::DataCopy(srcTensor, inputGmTensor[gm_offset_ + offset], AscendC::DataCopyParams(1, num_col_ / BLOCK_SIZE_16, 0, 0)); SET_FLAG(MTE2, V, EVENT_ID0); WAIT_FLAG(MTE2, V, EVENT_ID0); } else { /* Dequant start */ AscendC::DataCopy(mmTensor, mmGmTensor[gm_offset_ + offset], AscendC::DataCopyParams(1, num_col_ / 8, 0, 0)); // 2112 SET_FLAG(MTE2, V, EVENT_ID0); WAIT_FLAG(MTE2, V, EVENT_ID0); AscendC::Cast(mmTensor.ReinterpretCast(), mmTensor, AscendC::RoundMode::CAST_NONE, num_col_); AscendC::PipeBarrier(); AscendC::Mul(mmTensor.ReinterpretCast(), mmTensor.ReinterpretCast(), deScaleTensor, num_col_); SET_FLAG(V, MTE2, EVENT_ID0); WAIT_FLAG(V, MTE2, EVENT_ID0); gm_to_ub_align(perTokenDescaleTensor, perTokenDescaleGmTensor[pid], 0, // sid 1, // nBurst sizeof(float), // lenBurst 0, // leftPaddingNum 0, // rightPaddingNum 0, // srcGap 0 // dstGap ); SET_FLAG(MTE2, S, EVENT_ID0); WAIT_FLAG(MTE2, S, EVENT_ID0); float perTokenDescale = perTokenDescaleTensor.GetValue(0); SET_FLAG(S, V, EVENT_ID0); WAIT_FLAG(S, V, EVENT_ID0); AscendC::Muls(mmTensor.ReinterpretCast(), mmTensor.ReinterpretCast(), perTokenDescale, num_col_); AscendC::PipeBarrier(); AscendC::Cast(srcTensor, mmTensor.ReinterpretCast(), AscendC::RoundMode::CAST_RINT, num_col_); AscendC::PipeBarrier(); } Cast(fp32_xy, srcTensor[input_stride_], AscendC::RoundMode::CAST_NONE, REPEAT_TIME_64, num_col_align_withStride_fp32 / REPEAT_TIME_64, {1, 1, AscendC::DEFAULT_REPEAT_STRIDE, AscendC::DEFAULT_REPEAT_STRIDE / OFFSET_SUM}); AscendC::PipeBarrier(); /* Quant start */ if constexpr (quantMode == QuantMode::PER_TENSOR_ASYMM_QUANT) { Muls(fp32_xy, fp32_xy, input_scale_, REPEAT_TIME_64, num_col_align_withStride_fp32 / REPEAT_TIME_64, {1, 1, AscendC::DEFAULT_REPEAT_STRIDE, AscendC::DEFAULT_REPEAT_STRIDE}); AscendC::PipeBarrier(); Adds(fp32_xy, fp32_xy, input_offset_, REPEAT_TIME_64, num_col_align_withStride_fp32 / REPEAT_TIME_64, {1, 1, 
AscendC::DEFAULT_REPEAT_STRIDE, AscendC::DEFAULT_REPEAT_STRIDE}); AscendC::PipeBarrier(); } else if constexpr (quantMode == QuantMode::PER_TOKEN_SYMM_QUANT) { Abs(abs, fp32_xy, REPEAT_TIME_64, num_col_align_withStride_fp32 / REPEAT_TIME_64, {1, 1, AscendC::DEFAULT_REPEAT_STRIDE, AscendC::DEFAULT_REPEAT_STRIDE}); AscendC::PipeBarrier(); ReduceMax(max, abs, work, num_col_ - input_stride_); AscendC::PipeBarrier(); float scaleOut = max.GetValue(0) / 127; SET_FLAG(S, V, EVENT_ID0); WAIT_FLAG(S, V, EVENT_ID0); Muls(fp32_xy, fp32_xy, (float)(1 / scaleOut), REPEAT_TIME_64, num_col_align_withStride_fp32 / REPEAT_TIME_64, {1, 1, AscendC::DEFAULT_REPEAT_STRIDE, AscendC::DEFAULT_REPEAT_STRIDE}); AscendC::PipeBarrier(); perTokenDescaleTensor.SetValue(0, scaleOut); SET_FLAG(S, MTE3, EVENT_ID0); WAIT_FLAG(S, MTE3, EVENT_ID0); if constexpr (!NEED_DEQUANT) { ub_to_gm_align(perTokenDescaleGmTensor[pid], perTokenDescaleTensor, 0, 1, // nBurst 1 * sizeof(float), // lenBurst 0, // leftPaddingNum 0, // rightPaddingNum 0, // srcGap 0 // dstGap ); } else { ub_to_gm_align(perTokenDescaleGmTensor[num_row_ + pid], perTokenDescaleTensor, 0, 1, // nBurst 1 * sizeof(float), // lenBurst 0, // leftPaddingNum 0, // rightPaddingNum 0, // srcGap 0 // dstGap ); } SET_FLAG(MTE3, V, EVENT_ID0); WAIT_FLAG(MTE3, V, EVENT_ID0); } AscendC::LocalTensor tmpfp16 = buf.ReinterpretCast()[OFFSET_SUM * num_col_align_withStride_fp32 * 2]; CastFrom32To16(tmpfp16, fp32_xy, num_col_align_withStride_fp32); AscendC::PipeBarrier(); CastFromF16ToI8(dstTensor, tmpfp16, quantMin_, num_col_align_withStride_fp16); AscendC::PipeBarrier(); SET_FLAG(V, MTE3, EVENT_ID0); WAIT_FLAG(V, MTE3, EVENT_ID0); AscendC::DataCopy(outputGmTensor[gm_out_offset_ + outOffset], dstTensor, AscendC::DataCopyParams(1, (num_col_ - input_stride_) / 32, 0, 0)); SET_FLAG(MTE3, MTE2, EVENT_ID0); ++pid; } WAIT_FLAG(MTE3, MTE2, EVENT_ID0); } private: AscendC::LocalTensor dstTensor; AscendC::LocalTensor srcTensor; AscendC::LocalTensor fp32_xy; 
AscendC::LocalTensor buf; AscendC::LocalTensor mmTensor; AscendC::LocalTensor deScaleTensor; AscendC::GlobalTensor quantScaleGmTensor; AscendC::GlobalTensor quantOffsetGmTensor; AscendC::GlobalTensor inputGmTensor; AscendC::GlobalTensor outputGmTensor; AscendC::GlobalTensor perTokenDescaleGmTensor; AscendC::GlobalTensor perChannelDescaleGmTensor; AscendC::GlobalTensor mmGmTensor; uint32_t num_col_{0}; // input columns uint32_t num_row_{0}; // input rows uint32_t row_work_{0}; // rows need process uint32_t row_work{0}; // rows need process uint32_t row_step_{0}; // rows move in once uint32_t row_tail_{0}; // rows move in last time uint64_t gm_offset_{0}; // GM data offset uint64_t gm_out_offset_{0}; // GM data offset float avg_factor_{1.0}; // 1/num_col_ float input_scale_{1.0}; float input_offset_{0}; int32_t input_stride_{0}; float epsilon_{1e-12f}; uint32_t num_col_align_int8{0}; uint32_t num_col_align_f16{0}; uint32_t num_col_align_f32{0}; uint32_t num_col_align_f32_long{0}; uint32_t num_col_align_withStride_int8{0}; uint32_t num_col_align_withStride_fp16{0}; uint32_t num_col_align_withStride_fp32{0}; uint32_t num_col_temp; half quantMin_{-128}; uint32_t num_slice_{0}; uint32_t tail_size_{0}; uint32_t tail_copy_{0}; }; template class RmsNormQuant { public: __aicore__ inline RmsNormQuant() {} __aicore__ inline void Init(AscendC::GlobalTensor &gammaGmTensor, AscendC::GlobalTensor &betaGmTensor, AscendC::GlobalTensor &quantScaleGmTensor, AscendC::GlobalTensor &quantOffsetGmTensor, GM_ADDR perTokenDescaleGm, GM_ADDR perChannelDescaleGm, GM_ADDR gmInput, GM_ADDR gmOutput, uint32_t stride, uint32_t num_col, float avg_factor, uint64_t gm_offset, uint64_t gm_out_offset, uint32_t row_work_, const MlaTilingData &mlaParams_) { this->gammaGmTensor = gammaGmTensor; this->betaGmTensor = betaGmTensor; this->quantScaleGmTensor = quantScaleGmTensor; this->quantOffsetGmTensor = quantOffsetGmTensor; this->perTokenDescaleGmTensor.SetGlobalBuffer(reinterpret_cast<__gm__ float 
*>(perTokenDescaleGm)); this->perChannelDescaleGmTensor.SetGlobalBuffer(reinterpret_cast<__gm__ float *>(perChannelDescaleGm)); if constexpr (!NEED_DEQUANT) { inputGmTensor.SetGlobalBuffer(reinterpret_cast<__gm__ T *>(gmInput)); } else { mmGmTensor.SetGlobalBuffer(reinterpret_cast<__gm__ int32_t *>(gmInput)); } outputGmTensor.SetGlobalBuffer(reinterpret_cast<__gm__ int8_t *>(gmOutput)); num_col_ = num_col; avg_factor_ = avg_factor; epsilon_ = 1e-6; quantMin_ = -128; this->num_row_ = mlaParams_.n; this->row_work = row_work; this->row_work_ = row_work_; gm_offset_ = gm_offset; gm_out_offset_ = gm_out_offset; num_col_align_int8 = (num_col_ + REPEAT_TIME_256 - 1) / REPEAT_TIME_256 * REPEAT_TIME_256; num_col_align_f16 = (num_col_ + REPEAT_TIME_128 - 1) / REPEAT_TIME_128 * REPEAT_TIME_128; num_col_align_f32 = (num_col_ + REPEAT_TIME_64 - 1) / REPEAT_TIME_64 * REPEAT_TIME_64; input_stride_ = stride; num_col_align_withStride_int8 = (num_col_ - input_stride_ + REPEAT_TIME_256 - 1) / REPEAT_TIME_256 * REPEAT_TIME_256; num_col_align_withStride_fp16 = (num_col_ - input_stride_ + REPEAT_TIME_128 - 1) / REPEAT_TIME_128 * REPEAT_TIME_128; num_col_align_withStride_fp32 = (num_col_ - input_stride_ + REPEAT_TIME_64 - 1) / REPEAT_TIME_64 * REPEAT_TIME_64; } __aicore__ inline void Launch(const AscendC::LocalTensor &dstTensor, const AscendC::LocalTensor &srcTensor, const AscendC::LocalTensor &gammaTensor, const AscendC::LocalTensor &betaTensor, const AscendC::LocalTensor &quantScaleTensor, const AscendC::LocalTensor &quantOffsetTensor, const AscendC::LocalTensor &res1Tensor, const AscendC::LocalTensor &res3Tensor) { this->dstTensor = dstTensor; this->srcTensor = srcTensor; this->gammaTensor = gammaTensor; this->betaTensor = betaTensor; this->fp32_xy = res1Tensor; this->buf = res3Tensor; AscendC::LocalTensor g = buf[OFFSET_GAMMA * num_col_align_withStride_fp32]; // 0 AscendC::LocalTensor sqx = buf[OFFSET_SQX * num_col_align_withStride_fp32]; // 1 AscendC::LocalTensor work = 
buf[OFFSET_SUM * num_col_align_withStride_fp32]; // 2 AscendC::LocalTensor abs = buf[OFFSET_ABS * num_col_align_withStride_fp32]; // 3 AscendC::LocalTensor sum = buf[OFFSET_WORKSPACE_BF16 * num_col_align_withStride_fp32]; // 4 AscendC::LocalTensor max = buf[OFFSET_WORKSPACE_BF16 * num_col_align_withStride_fp32 + 8]; // 5 AscendC::LocalTensor perTokenDescaleTensor = buf[OFFSET_WORKSPACE_BF16 * num_col_align_withStride_fp32 + 16]; // 6 AscendC::DataCopy(gammaTensor, gammaGmTensor, AscendC::DataCopyParams(1, (num_col_ - input_stride_) / BLOCK_SIZE_16, 0, 0)); AscendC::DataCopy(betaTensor, betaGmTensor, AscendC::DataCopyParams(1, (num_col_ - input_stride_) / BLOCK_SIZE_16, 0, 0)); SET_FLAG(MTE2, V, EVENT_ID1); if constexpr (quantMode == QuantMode::PER_TENSOR_ASYMM_QUANT) { AscendC::DataCopy(quantScaleTensor, quantScaleGmTensor, AscendC::DataCopyParams(1, 1, 0, 0)); AscendC::DataCopy(quantOffsetTensor, quantOffsetGmTensor, AscendC::DataCopyParams(1, 1, 0, 0)); } if constexpr (NEED_DEQUANT) { mmTensor = buf.ReinterpretCast()[OFFSET_WORKSPACE_BF16 * num_col_align_withStride_fp32 + 16]; deScaleTensor = buf[OFFSET_WORKSPACE_BF16 * num_col_align_withStride_fp32 + 16 + MM1_OUT_SIZE]; perTokenDescaleTensor = buf[OFFSET_WORKSPACE_BF16 * num_col_align_withStride_fp32 + 16 + MM1_OUT_SIZE * 2]; AscendC::DataCopy(deScaleTensor, perChannelDescaleGmTensor, AscendC::DataCopyParams(1, num_col_ / 8, 0, 0)); } if constexpr (quantMode == QuantMode::PER_TENSOR_ASYMM_QUANT) { if (std::is_same::value) { SET_FLAG(MTE2, V, EVENT_ID0); WAIT_FLAG(MTE2, V, EVENT_ID0); Cast(g, quantScaleTensor, AscendC::RoundMode::CAST_NONE, 1); AscendC::SetFlag(EVENT_ID0); AscendC::WaitFlag(EVENT_ID0); input_scale_ = 1 / (float)(g.GetValue(0)); input_offset_ = (float)(quantOffsetTensor.GetValue(0)); } else { SET_FLAG(MTE2, S, EVENT_ID0); WAIT_FLAG(MTE2, S, EVENT_ID0); input_scale_ = 1 / (float)(quantScaleTensor.GetValue(0)); input_offset_ = (float)(quantOffsetTensor.GetValue(0)); } AscendC::SetFlag(EVENT_ID0); 
AscendC::WaitFlag(EVENT_ID0); } WAIT_FLAG(MTE2, V, EVENT_ID1); Cast(buf[OFFSET_GAMMA * num_col_align_withStride_fp32], gammaTensor, AscendC::RoundMode::CAST_NONE, REPEAT_TIME_64, num_col_align_withStride_fp32 / REPEAT_TIME_64, {1, 1, AscendC::DEFAULT_REPEAT_STRIDE, AscendC::DEFAULT_REPEAT_STRIDE / OFFSET_SUM}); AscendC::PipeBarrier(); uint64_t pid = 0; SET_FLAG(MTE3, MTE2, EVENT_ID0); while (pid < row_work_) { uint64_t offset = pid * num_col_; uint64_t outOffset = pid * (num_col_ - input_stride_); WAIT_FLAG(MTE3, MTE2, EVENT_ID0); if constexpr (!NEED_DEQUANT) { AscendC::DataCopy(srcTensor, inputGmTensor[gm_offset_ + offset], AscendC::DataCopyParams(1, num_col_ / BLOCK_SIZE_16, 0, 0)); SET_FLAG(MTE2, V, EVENT_ID0); WAIT_FLAG(MTE2, V, EVENT_ID0); } else { /* Dequant start */ AscendC::DataCopy(mmTensor, mmGmTensor[gm_offset_ + offset], AscendC::DataCopyParams(1, num_col_ / 8, 0, 0)); // 2112 SET_FLAG(MTE2, V, EVENT_ID0); WAIT_FLAG(MTE2, V, EVENT_ID0); AscendC::Cast(mmTensor.ReinterpretCast(), mmTensor, AscendC::RoundMode::CAST_NONE, num_col_); AscendC::PipeBarrier(); AscendC::Mul(mmTensor.ReinterpretCast(), mmTensor.ReinterpretCast(), deScaleTensor, num_col_); SET_FLAG(V, MTE2, EVENT_ID0); WAIT_FLAG(V, MTE2, EVENT_ID0); gm_to_ub_align(perTokenDescaleTensor, perTokenDescaleGmTensor[pid], 0, // sid 1, // nBurst sizeof(float), // lenBurst 0, // leftPaddingNum 0, // rightPaddingNum 0, // srcGap 0 // dstGap ); SET_FLAG(MTE2, S, EVENT_ID0); WAIT_FLAG(MTE2, S, EVENT_ID0); float perTokenDescale = perTokenDescaleTensor.GetValue(0); SET_FLAG(S, V, EVENT_ID0); WAIT_FLAG(S, V, EVENT_ID0); AscendC::Muls(mmTensor.ReinterpretCast(), mmTensor.ReinterpretCast(), perTokenDescale, num_col_); AscendC::PipeBarrier(); AscendC::Cast(srcTensor, mmTensor.ReinterpretCast(), AscendC::RoundMode::CAST_RINT, num_col_); AscendC::PipeBarrier(); } Cast(fp32_xy, srcTensor[input_stride_], AscendC::RoundMode::CAST_NONE, REPEAT_TIME_64, num_col_align_withStride_fp32 / REPEAT_TIME_64, {1, 1, 
AscendC::DEFAULT_REPEAT_STRIDE, AscendC::DEFAULT_REPEAT_STRIDE / OFFSET_SUM}); AscendC::PipeBarrier(); Mul(sqx, fp32_xy, fp32_xy, REPEAT_TIME_64, num_col_align_withStride_fp32 / REPEAT_TIME_64, {1, 1, 1, AscendC::DEFAULT_REPEAT_STRIDE, AscendC::DEFAULT_REPEAT_STRIDE, AscendC::DEFAULT_REPEAT_STRIDE}); AscendC::PipeBarrier(); Muls(sqx, sqx, avg_factor_, num_col_ - input_stride_); AscendC::PipeBarrier(); ReduceSumCustom(sum, sqx, work, num_col_ - input_stride_); AscendC::PipeBarrier(); Adds(sum, sum, epsilon_, 1); AscendC::PipeBarrier(); Sqrt(sum, sum, 1); SET_FLAG(V, S, EVENT_ID0); WAIT_FLAG(V, S, EVENT_ID0); float factor = 1 / sum.GetValue(0); SET_FLAG(S, V, EVENT_ID0); WAIT_FLAG(S, V, EVENT_ID0); Muls(fp32_xy, fp32_xy, factor, REPEAT_TIME_64, num_col_align_withStride_fp32 / REPEAT_TIME_64, {1, 1, AscendC::DEFAULT_REPEAT_STRIDE, AscendC::DEFAULT_REPEAT_STRIDE}); AscendC::PipeBarrier(); Mul(fp32_xy, fp32_xy, g, REPEAT_TIME_64, num_col_align_withStride_fp32 / REPEAT_TIME_64, {1, 1, 1, AscendC::DEFAULT_REPEAT_STRIDE, AscendC::DEFAULT_REPEAT_STRIDE, AscendC::DEFAULT_REPEAT_STRIDE}); AscendC::PipeBarrier(); if constexpr (WITH_BETA) { AscendC::LocalTensor b = this->betaTensor; Cast(work, b, AscendC::RoundMode::CAST_NONE, REPEAT_TIME_64, num_col_align_withStride_fp32 / REPEAT_TIME_64, {1, 1, AscendC::DEFAULT_REPEAT_STRIDE, AscendC::DEFAULT_REPEAT_STRIDE / OFFSET_SUM}); AscendC::PipeBarrier(); Add(fp32_xy, fp32_xy, work, REPEAT_TIME_64, num_col_align_withStride_fp32 / REPEAT_TIME_64, {1, 1, 1, AscendC::DEFAULT_REPEAT_STRIDE, AscendC::DEFAULT_REPEAT_STRIDE, AscendC::DEFAULT_REPEAT_STRIDE}); AscendC::PipeBarrier(); } /* Quant start */ if constexpr (quantMode == QuantMode::PER_TENSOR_ASYMM_QUANT) { Muls(fp32_xy, fp32_xy, input_scale_, REPEAT_TIME_64, num_col_align_withStride_fp32 / REPEAT_TIME_64, {1, 1, AscendC::DEFAULT_REPEAT_STRIDE, AscendC::DEFAULT_REPEAT_STRIDE}); AscendC::PipeBarrier(); Adds(fp32_xy, fp32_xy, input_offset_, REPEAT_TIME_64, num_col_align_withStride_fp32 / 
REPEAT_TIME_64, {1, 1, AscendC::DEFAULT_REPEAT_STRIDE, AscendC::DEFAULT_REPEAT_STRIDE}); AscendC::PipeBarrier(); } else if constexpr (quantMode == QuantMode::PER_TOKEN_SYMM_QUANT) { Abs(abs, fp32_xy, REPEAT_TIME_64, num_col_align_withStride_fp32 / REPEAT_TIME_64, {1, 1, AscendC::DEFAULT_REPEAT_STRIDE, AscendC::DEFAULT_REPEAT_STRIDE}); AscendC::PipeBarrier(); ReduceMax(max, abs, work, num_col_ - input_stride_); AscendC::PipeBarrier(); float scaleOut = max.GetValue(0) / 127; SET_FLAG(S, V, EVENT_ID0); WAIT_FLAG(S, V, EVENT_ID0); Muls(fp32_xy, fp32_xy, (float)(1 / scaleOut), REPEAT_TIME_64, num_col_align_withStride_fp32 / REPEAT_TIME_64, {1, 1, AscendC::DEFAULT_REPEAT_STRIDE, AscendC::DEFAULT_REPEAT_STRIDE}); AscendC::PipeBarrier(); perTokenDescaleTensor.SetValue(0, scaleOut); SET_FLAG(S, MTE3, EVENT_ID0); WAIT_FLAG(S, MTE3, EVENT_ID0); if constexpr (!NEED_DEQUANT) { ub_to_gm_align(perTokenDescaleGmTensor[pid], perTokenDescaleTensor, 0, 1, // nBurst 1 * sizeof(float), // lenBurst 0, // leftPaddingNum 0, // rightPaddingNum 0, // srcGap 0 // dstGap ); } else { ub_to_gm_align(perTokenDescaleGmTensor[num_row_ + pid], perTokenDescaleTensor, 0, 1, // nBurst 1 * sizeof(float), // lenBurst 0, // leftPaddingNum 0, // rightPaddingNum 0, // srcGap 0 // dstGap ); } SET_FLAG(MTE3, V, EVENT_ID0); WAIT_FLAG(MTE3, V, EVENT_ID0); } AscendC::LocalTensor tmpfp16 = buf.ReinterpretCast()[OFFSET_SUM * num_col_align_withStride_fp32 * 2]; CastFrom32To16(tmpfp16, fp32_xy, num_col_align_withStride_fp32); AscendC::PipeBarrier(); CastFromF16ToI8(dstTensor, tmpfp16, quantMin_, num_col_align_withStride_fp16); AscendC::PipeBarrier(); SET_FLAG(V, MTE3, EVENT_ID0); WAIT_FLAG(V, MTE3, EVENT_ID0); AscendC::DataCopy(outputGmTensor[gm_out_offset_ + outOffset], dstTensor, AscendC::DataCopyParams(1, (num_col_ - input_stride_) / 32, 0, 0)); SET_FLAG(MTE3, MTE2, EVENT_ID0); ++pid; } WAIT_FLAG(MTE3, MTE2, EVENT_ID0); } private: AscendC::LocalTensor dstTensor; AscendC::LocalTensor srcTensor; 
// ---- Tail of the preceding class's member list (its header lies outside this chunk). ----
    AscendC::LocalTensor gammaTensor;
    AscendC::LocalTensor betaTensor;
    AscendC::LocalTensor fp32_xy;
    AscendC::LocalTensor buf;
    AscendC::LocalTensor mmTensor;
    AscendC::LocalTensor deScaleTensor;
    AscendC::GlobalTensor gammaGmTensor;
    AscendC::GlobalTensor betaGmTensor;
    AscendC::GlobalTensor quantScaleGmTensor;
    AscendC::GlobalTensor quantOffsetGmTensor;
    AscendC::GlobalTensor inputGmTensor;
    AscendC::GlobalTensor outputGmTensor;
    AscendC::GlobalTensor perTokenDescaleGmTensor;
    AscendC::GlobalTensor perChannelDescaleGmTensor;
    AscendC::GlobalTensor mmGmTensor;
    uint32_t num_col_{0};
    uint32_t num_row_{0};
    uint32_t row_work_{0};
    uint32_t row_work{0};
    uint32_t row_step_{0};
    uint32_t row_tail_{0};
    uint64_t gm_offset_{0};
    uint64_t gm_out_offset_{0};
    float avg_factor_{1.0};
    float input_scale_{1.0};
    float input_offset_{0};
    int32_t input_stride_{0};
    float epsilon_{1e-12f};
    uint32_t num_col_align_int8{0};
    uint32_t num_col_align_f16{0};
    uint32_t num_col_align_f32{0};
    uint32_t num_col_align_f32_long{0};
    uint32_t num_col_align_withStride_int8{0};
    uint32_t num_col_align_withStride_fp16{0};
    uint32_t num_col_align_withStride_fp32{0};
    uint32_t num_col_temp;
    half quantMin_{-128};
    uint32_t num_slice_{0};
    uint32_t tail_size_{0};
    uint32_t tail_copy_{0};
};

/**
 * EinSumQuant: per-head int8 quantization of the einsum output (vector side).
 *
 * Each core handles `batchNum` consecutive slices of [headNum, colNum] InDtype elements, starting at
 * element offset `currentCoreStartOffset` of einSumOutGm. For every head h, the row is multiplied by the
 * fp32-widened scaleGm[h], narrowed fp32 -> fp16 -> int8, and written to the same element offset of
 * quantOutGm. Heads are processed `headPerLoop` at a time in UB, followed by one `headTail` remainder
 * pass. UB buffers are double-buffered on `pingFlag`.
 *
 * NOTE(review): `colTail` is read from tiling but never used in Process(). The Mul loop scales only
 * colLoop * CONST_64 columns per row. Presumably the tiling guarantees colNum == colLoop * 64 -- confirm.
 */
template class EinSumQuant {
public:
    __aicore__ explicit EinSumQuant() {}

    /**
     * Binds the GM tensors and derives this core's work split and UB sizes from the esq* tiling fields.
     *
     * @param einSumOutGm  einsum result to quantize (InDtype elements).
     * @param scaleGm      per-head scales (InDtype); headNum elements are read.
     * @param quantOutGm   int8 output, written at the same element offsets as the input.
     * @param tilingData   host tiling data.
     */
    __aicore__ __force_inline__ void Init(GM_ADDR einSumOutGm, GM_ADDR scaleGm, GM_ADDR quantOutGm,
                                          const MlaTilingData &tilingData) {
        einSumOutGm_.SetGlobalBuffer(reinterpret_cast<__gm__ InDtype *>(einSumOutGm));
        scaleGm_.SetGlobalBuffer(reinterpret_cast<__gm__ InDtype *>(scaleGm));
        quantOutGm_.SetGlobalBuffer(reinterpret_cast<__gm__ int8_t *>(quantOutGm));

        headNum = tilingData.esqHeadNum;
        colNum = tilingData.esqColNum;
        ubHeadLoop = tilingData.esqUbHeadLoop;
        headPerLoop = tilingData.esqHeadPerLoop;
        headTail = tilingData.esqHeadTail;
        colLoop = tilingData.esqColLoop;
        colTail = tilingData.esqColTail;

        // Core index is rebuilt from GetBlockIdx() and the sub-block id.
        currentIdx = (AscendC::GetBlockIdx() / 2) * 2 + GetSubBlockidx();
        // The first esqFrontCore cores each take esqFrontCoreBatch slices.
        // The remaining cores each take esqTailCoreBatch slices, laid out after the front cores' work.
        if (currentIdx < tilingData.esqFrontCore) {
            batchNum = tilingData.esqFrontCoreBatch;
            currentCoreStartOffset = currentIdx * tilingData.esqFrontCoreBatch * headNum * colNum;
        } else {
            batchNum = tilingData.esqTailCoreBatch;
            currentCoreStartOffset = (tilingData.esqFrontCore * tilingData.esqFrontCoreBatch +
                                      (currentIdx - tilingData.esqFrontCore) * tilingData.esqTailCoreBatch) *
                                     headNum * colNum;
        }
        // One head row expressed in fp32 blocks; used as the Mul repeat stride.
        // This assumes ELE_NUM_FP32 is the number of fp32 elements per 32-byte block.
        calcRepeatStride = static_cast(colNum / ELE_NUM_FP32);
        // headNum rounded up to a multiple of ELE_NUM_FP16; padLen scale elements are cast and broadcast.
        padLen = RoundUp(headNum, ELE_NUM_FP16);
        // Elements handled per full UB loop.
        calcLength = headPerLoop * colNum;

        // Tensor sizes in bytes, and DataCopy lengths in BLOCK_SIZE_32-byte blocks.
        scaleBrcbFp32DataSize = padLen * ELE_NUM_FP32 * sizeof(float);
        inputDataSize = calcLength * sizeof(InDtype);
        inputDataBlock = calcLength * sizeof(InDtype) / BLOCK_SIZE_32;
        inputFp32DataSize = calcLength * sizeof(float);
        int8OutDataBlcok = calcLength / BLOCK_SIZE_32;
        headTailDataBlock = headTail * colNum * sizeof(InDtype) / BLOCK_SIZE_32;
        int8TailOutDataBlock = headTail * colNum / BLOCK_SIZE_32;
        // A headNum that is not a multiple of ELE_NUM_FP16 cannot use a plain block DataCopy.
        // Instead, copy headNum elements with DataCopyPad and pad the rest up to padLen.
        if (padLen > headNum) {
            scaleCopyParams = AscendC::DataCopyExtParams(1, static_cast(headNum * sizeof(InDtype)), 0, 0, 0);
            scalePadParams = AscendC::DataCopyPadExtParams(true, 0, static_cast(padLen - headNum), 0);
        }
    }

    /**
     * Quantizes this core's slices. The steps are:
     *   1. Load the per-head scale, cast it to fp32, and Brcb each head's scale into one block.
     *   2. For each batch and each group of headPerLoop heads:
     *      copy in -> cast to fp32 -> multiply by scale -> fp32 to fp16 -> fp16 to int8 -> copy out.
     *   3. Run the same steps once more for the headTail remainder.
     *
     * Event usage: pingFlag (0/1) doubles as the event id.
     *   - V->MTE2 : the input slot has been consumed by Cast and may be overwritten.
     *   - MTE3->V : the previous int8 copy-out from this slot has finished.
     * Both pairs are pre-set before the loops and drained at the end.
     */
    __aicore__ __force_inline__ void Process() {
        if (batchNum == 0) {
            return;
        }
        // Local tensor layout in UB, as byte offsets:
        //   0                                                              scaleBrcbFp32_
        //   scaleBrcbFp32DataSize                                          inputTensor_ slots
        //   + inputDataSize * ROPE_CONCAT_NUM_BUFFER                       inputFp32_ slots
        //   + (inputDataSize + inputFp32DataSize) * ROPE_CONCAT_NUM_BUFFER int8OutTensor_ slots
        // Slots are indexed by pingFlag * calcLength, so ROPE_CONCAT_NUM_BUFFER is presumably >= 2 -- confirm.
        scaleBrcbFp32_ = buf.GetBuffer(0);
        inputTensor_ = buf.GetBuffer(scaleBrcbFp32DataSize);
        inputFp32_ = buf.GetBuffer(scaleBrcbFp32DataSize + inputDataSize * ROPE_CONCAT_NUM_BUFFER);
        int8OutTensor_ = buf.GetBuffer(
            scaleBrcbFp32DataSize + (inputDataSize + inputFp32DataSize) * ROPE_CONCAT_NUM_BUFFER);

        // Scale: copy in, cast, then Brcb [H, 1] --> [H, 8]. The input UB space is used as scratch.
        if (headNum == padLen) {
            AscendC::DataCopy(inputTensor_, scaleGm_, headNum);
        } else {
            AscendC::DataCopyPad(inputTensor_, scaleGm_, scaleCopyParams, scalePadParams);
        }
        SET_FLAG(MTE2, V, EVENT_ID0);
        WAIT_FLAG(MTE2, V, EVENT_ID0);
        AscendC::Cast(inputFp32_, inputTensor_, AscendC::RoundMode::CAST_NONE, padLen);
        AscendC::PipeBarrier();
        AscendC::Brcb(scaleBrcbFp32_, inputFp32_, padLen / ELE_NUM_FP32, {1, 8});
        AscendC::PipeBarrier();

        uint8_t pingFlag = 0;
        // Batch loop: pre-set both slots' events so the first WAIT_FLAG on each slot passes.
        SET_FLAG(V, MTE2, EVENT_ID0);  // input copy-in waits for vector to release the UB slot
        SET_FLAG(V, MTE2, EVENT_ID1);
        SET_FLAG(MTE3, V, EVENT_ID0);  // quant calc waits for the previous result copy-out
        SET_FLAG(MTE3, V, EVENT_ID1);
        for (uint32_t batchIdx = 0; batchIdx < batchNum; batchIdx++) {
            batchOffset = batchIdx * headNum * colNum;
            // UB loop over full groups of headPerLoop heads.
            for (uint32_t ubLoopIdx = 0; ubLoopIdx < ubHeadLoop; ubLoopIdx++) {
                scaleBrcbOffset = ubLoopIdx * headPerLoop * ELE_NUM_FP32;
                inputLoopOffset = ubLoopIdx * headPerLoop * colNum;
                calcStartOffset = currentCoreStartOffset + batchOffset + inputLoopOffset;
                calcTmpOffset = pingFlag * calcLength;

                // Input copy-in and cast.
                WAIT_FLAG(V, MTE2, pingFlag);
                AscendC::DataCopy(inputTensor_[calcTmpOffset], einSumOutGm_[calcStartOffset],
                                  {1, inputDataBlock, 0, 0});
                SET_FLAG(MTE2, V, pingFlag);
                WAIT_FLAG(MTE2, V, pingFlag);
                AscendC::Cast(inputFp32_[calcTmpOffset], inputTensor_[calcTmpOffset],
                              AscendC::RoundMode::CAST_NONE, calcLength);
                AscendC::PipeBarrier();
                SET_FLAG(V, MTE2, pingFlag);
                // Quant calc: each Mul handles CONST_64 columns of every head in the group.
                // It runs one repeat per head. src1 has block stride 0 and repeat stride 1,
                // so each repeat reuses that head's broadcast scale block.
                for (uint32_t colIdx = 0; colIdx < colLoop; colIdx++) {
                    colOffset = colIdx * CONST_64;
                    AscendC::Mul(inputFp32_[calcTmpOffset + colOffset], inputFp32_[calcTmpOffset + colOffset],
                                 scaleBrcbFp32_[scaleBrcbOffset], CONST_64, headPerLoop,
                                 {1, 1, 0, calcRepeatStride, calcRepeatStride, 1});
                }
                AscendC::PipeBarrier();
                // Quantize fp32 --> fp16 --> int8 (the fp16 result reuses the fp32 slot).
                CastFrom32To16(inputFp32_[calcTmpOffset].template ReinterpretCast(), inputFp32_[calcTmpOffset],
                               calcLength);
                AscendC::PipeBarrier();
                WAIT_FLAG(MTE3, V, pingFlag);  // wait for the previous result copy-out from this slot
                CastFromF16ToI8(int8OutTensor_[calcTmpOffset],
                                inputFp32_[calcTmpOffset].template ReinterpretCast(), quantMin_, calcLength);
                AscendC::PipeBarrier();
                SET_FLAG(V, MTE3, pingFlag);
                WAIT_FLAG(V, MTE3, pingFlag);
                // int8 copy-out.
                AscendC::DataCopy(quantOutGm_[calcStartOffset], int8OutTensor_[calcTmpOffset],
                                  {1, int8OutDataBlcok, 0, 0});
                SET_FLAG(MTE3, V, pingFlag);
                pingFlag = 1 - pingFlag;
            }
            // Remainder pass for the last headTail heads of this batch.
            // Same steps as above, with headTail in place of headPerLoop.
            if (headTail > 0) {
                scaleBrcbOffset = ubHeadLoop * headPerLoop * ELE_NUM_FP32;
                inputLoopOffset = ubHeadLoop * headPerLoop * colNum;
                calcStartOffset = currentCoreStartOffset + batchOffset + inputLoopOffset;
                calcTmpOffset = pingFlag * calcLength;

                // Input copy-in and cast.
                WAIT_FLAG(V, MTE2, pingFlag);
                AscendC::DataCopy(inputTensor_[calcTmpOffset], einSumOutGm_[calcStartOffset],
                                  {1, headTailDataBlock, 0, 0});
                SET_FLAG(MTE2, V, pingFlag);
                WAIT_FLAG(MTE2, V, pingFlag);
                AscendC::Cast(inputFp32_[calcTmpOffset], inputTensor_[calcTmpOffset],
                              AscendC::RoundMode::CAST_NONE, headTail * colNum);
                AscendC::PipeBarrier();
                SET_FLAG(V, MTE2, pingFlag);
                // Quant calc.
                for (uint32_t colIdx = 0; colIdx < colLoop; colIdx++) {
                    colOffset = colIdx * CONST_64;
                    AscendC::Mul(inputFp32_[calcTmpOffset + colOffset], inputFp32_[calcTmpOffset + colOffset],
                                 scaleBrcbFp32_[scaleBrcbOffset], CONST_64, headTail,
                                 {1, 1, 0, calcRepeatStride, calcRepeatStride, 1});
                }
                AscendC::PipeBarrier();
                // Quantize fp32 --> fp16 --> int8.
                CastFrom32To16(inputFp32_[calcTmpOffset].template ReinterpretCast(), inputFp32_[calcTmpOffset],
                               headTail * colNum);
                AscendC::PipeBarrier();
                WAIT_FLAG(MTE3, V, pingFlag);  // wait for the previous result copy-out from this slot
                CastFromF16ToI8(int8OutTensor_[calcTmpOffset],
                                inputFp32_[calcTmpOffset].template ReinterpretCast(), quantMin_,
                                headTail * colNum);
                AscendC::PipeBarrier();
                SET_FLAG(V, MTE3, pingFlag);
                WAIT_FLAG(V, MTE3, pingFlag);
                // int8 copy-out.
                AscendC::DataCopy(quantOutGm_[calcStartOffset], int8OutTensor_[calcTmpOffset],
                                  {1, int8TailOutDataBlock, 0, 0});
                SET_FLAG(MTE3, V, pingFlag);
                pingFlag = 1 - pingFlag;
            }
        }
        // Drain the events pre-set before the batch loop.
        WAIT_FLAG(V, MTE2, EVENT_ID0);
        WAIT_FLAG(V, MTE2, EVENT_ID1);
        WAIT_FLAG(MTE3, V, EVENT_ID0);
        WAIT_FLAG(MTE3, V, EVENT_ID1);
    }

private:
    AsdopsBuffer buf;

    AscendC::GlobalTensor einSumOutGm_;
    AscendC::GlobalTensor scaleGm_;
    AscendC::GlobalTensor quantOutGm_;

    AscendC::LocalTensor scaleBrcbFp32_;
    AscendC::LocalTensor inputTensor_;
    AscendC::LocalTensor inputFp32_;
    AscendC::LocalTensor int8OutTensor_;

    AscendC::DataCopyExtParams scaleCopyParams;
    AscendC::DataCopyPadExtParams scalePadParams;

    // Data processed by a single core: [batchNum, headNum, colNum].
    uint32_t batchNum;  // number of batches processed by this core
    uint32_t headNum;
    uint32_t colNum;  // number of columns per head row
    // UB loop
    uint32_t ubHeadLoop;   // number of full UB iterations over heads
    uint32_t headPerLoop;  // number of heads processed per full UB iteration
    uint32_t headTail;     // number of heads in the final remainder pass
    // Column loop
    uint32_t colLoop;  // number of CONST_64-column Mul passes per row
    uint32_t colTail;  // number of remaining columns (not used in Process; see class note)

    uint32_t currentIdx;
    uint64_t currentCoreStartOffset;
    uint32_t inputDataSize;  // bytes of InDtype input per full UB iteration
    uint32_t inputFp32DataSize;
    uint32_t scaleBrcbFp32DataSize;
    uint16_t inputDataBlock;  // BLOCK_SIZE_32-byte blocks copied in per full UB iteration
    uint16_t int8OutDataBlcok;
    uint16_t headTailDataBlock;
    uint16_t int8TailOutDataBlock;
    // GM offsets, in elements
    uint64_t inputLoopOffset{0};
    uint64_t batchOffset{0};
    uint64_t calcStartOffset{0};
    // Double-buffer offsets and lengths, in elements
    uint32_t scaleBrcbOffset{0};
    uint32_t calcLength{0};
    uint32_t calcTmpOffset{0};

    half quantMin_{-128};
    uint32_t colOffset{0};
    uint32_t padLen;
    uint8_t calcRepeatStride;
};

#ifdef __DAV_C220_CUBE__

// (m, k, n) tile coordinates or tile counts used by PpMatmulEinSum.
struct MatCoord {
    uint64_t m{0};
    uint64_t k{0};
    uint64_t n{0};
};

/**
 * PpMatmulEinSum: batched cube matmul (AIC side). It reads its tiling from mlaParams.mm3 and waits on
 * AIC_MM3_START before its first tile.
 *
 * Output tiles are m0 x n0 and are distributed over cores as loop_idx = core_idx, core_idx + num_core, ...
 * K is walked in k0 steps through L1, then in k_part_len chunks through L0. L1 and L0 use ping/pong buffers.
 *
 * GM layouts implied by the offset math (row r, batch b):
 *   A at (r * batch_size + b) * (k + splitGapA)
 *   C at (r * batch_size + b) * (n + splitGapC)
 * For the B layout, see GetOffsetB.
 * splitGapA, splitGapC, transB, formatB and swizzleDirect are not declared in this chunk; they are
 * presumably template parameters.
 *
 * Event usage in Process():
 *   MTE1->MTE2 EVENT_ID0/1 : L1 A ping/pong buffer is free
 *   MTE1->MTE2 EVENT_ID2/3 : (event_id + 2) L1 B ping/pong buffer is free
 *   MTE2->MTE1 (same ids)  : L1 tile has been loaded
 *   M->MTE1 EVENT_ID0/1    : L0A/L0B ping/pong buffer is free
 *   FIX->M EVENT_ID0       : L0C has been written back and may be re-initialised
 */
template class PpMatmulEinSum {
    using AccumDtype = float;

    template using CopyGmToCbuf = gm_to_l1;
    using LoadCbufToCa = l1_to_l0_a;
    using LoadCbufToCb = l1_to_l0_b;
    using Mad = mmad;
    using CopyCcToGm = l0c_to_gm;

    // Element offset of the pong half of each L0A/L0B buffer.
    static constexpr uint32_t L0_PINGPONG_BUFFER_LEN = 16384;
    // Element offset of the pong half of the L1 A and L1 B buffers.
    static constexpr uint32_t L1_PINGPONG_BUFFER_LEN = 131072;
    static constexpr uint32_t CONST_16 = 16;
    static constexpr uint32_t CONST_256 = 256;

public:
    __aicore__ explicit PpMatmulEinSum(){};

    // Copies the mm3 tiling into members, binds GM tensors, and carves L1/L0A/L0B buffers.
    // NOTE(review): l0c_buf is not assigned here, unlike l0a_base/l0b_base, so Mad/CopyCcToGm use a
    // default-constructed tensor. Confirm that it addresses the intended L0C region.
    __aicore__ __force_inline__ void Init(GM_ADDR gmA, GM_ADDR gmB, GM_ADDR gmC, const MlaTilingData &mlaParams) {
#ifdef __DAV_C220_CUBE__
        batch_size = mlaParams.mm3.numBatch;
        m = mlaParams.mm3.m;
        k = mlaParams.mm3.k;
        n = mlaParams.mm3.n;
        m0 = mlaParams.mm3.m0;
        k0 = mlaParams.mm3.k0;
        n0 = mlaParams.mm3.n0;
        tdim.m = mlaParams.mm3.mLoop;
        tdim.k = mlaParams.mm3.kLoop;
        tdim.n = mlaParams.mm3.nLoop;
        core_loop = mlaParams.mm3.coreLoop;
        swizzle_cnt = mlaParams.mm3.swizzleCount;
        num_core = mlaParams.mm3.blockDim;
        core_idx = AscendC::GetBlockIdx();
        ping_flag = 1;

        gm_a.SetGlobalBuffer(reinterpret_cast<__gm__ InDtype *>(gmA));
        gm_b.SetGlobalBuffer(reinterpret_cast<__gm__ InDtype *>(gmB));
        gm_c.SetGlobalBuffer(reinterpret_cast<__gm__ OutDtype *>(gmC));

        AsdopsBuffer buf;
        l1_base_a = buf.GetBuffer(0);
        // L1 B starts right after one A tile (m0 * k0 elements, rounded up).
        l1_base_b = buf.GetBuffer(RoundUp(m0 * k0 * sizeof(InDtype)));
        l0a_base = buf.GetBuffer(0);
        l0b_base = buf.GetBuffer(0);
#endif
        return;
    }

    // Computes this core's output tiles. A and B tiles for the next K step, or the next tile,
    // are prefetched into the other L1 half while the current half is fed to the cube.
    // NOTE(review): the early-exit guard compares `block_idx`, which is not declared in this chunk,
    // while the tile loop starts at `core_idx`. Confirm that both name the same index.
    __aicore__ __force_inline__ void Process() {
#ifdef __DAV_C220_CUBE__
        // Idle cores still consume AIC_MM3_START before returning.
        if (block_idx >= num_core) {
            WaitFlagDev(AIC_MM3_START);
            return;
        }
        using LocalTensor = AscendC::LocalTensor;

        // Pre-set "buffer free" events so that each first WAIT_FLAG passes.
        SET_FLAG(MTE1, MTE2, EVENT_ID0);
        SET_FLAG(MTE1, MTE2, EVENT_ID1);
        SET_FLAG(MTE1, MTE2, EVENT_ID2);
        SET_FLAG(MTE1, MTE2, EVENT_ID3);
        SET_FLAG(FIX, M, EVENT_ID0);
        SET_FLAG(M, MTE1, EVENT_ID0);
        SET_FLAG(M, MTE1, EVENT_ID1);

        for (uint64_t loop_idx = core_idx; loop_idx < core_loop; loop_idx += num_core) {
            uint64_t batch_idx = loop_idx / tdim.n / tdim.m;
            MatCoord tidx{0};
            GetBaseBlockIdx(loop_idx, tidx);
            uint64_t offset_c = tidx.m * m0 * batch_size * (n + splitGapC) + batch_idx * (n + splitGapC) + tidx.n * n0;
            // The last tile in each dimension may be partial.
            uint64_t m_actual = (tidx.m == (tdim.m - 1)) ? (m - tidx.m * m0) : m0;
            uint64_t n_actual = (tidx.n == (tdim.n - 1)) ? (n - tidx.n * n0) : n0;
            uint64_t m_round = RoundUp(m_actual);
            uint64_t n_round = RoundUp(n_actual);
            uint64_t mn_max = m_round > n_round ? m_round : n_round;
            // Largest multiple of 16 such that mn_max * k_part_len fits one L0 ping/pong half.
            uint64_t k_part_len = L0_PINGPONG_BUFFER_LEN / mn_max / CONST_16 * CONST_16;
            uint64_t shuffle_k = en_shuffle_k ? (core_idx % tdim.k) : 0;
            uint64_t k_actual = (shuffle_k == tdim.k - 1) ? k - shuffle_k * k0 : k0;
            uint64_t k_round = (k_actual + CONST_16 - 1) / CONST_16 * CONST_16;

            LocalTensor l1_buf_a = ping_flag ? l1_base_a : l1_base_a[L1_PINGPONG_BUFFER_LEN];
            LocalTensor l1_buf_b = ping_flag ? l1_base_b : l1_base_b[L1_PINGPONG_BUFFER_LEN];
            event_t event_id = ping_flag ? EVENT_ID0 : EVENT_ID1;

            // Only the core's first tile loads its first A/B K-step here.
            // Later tiles were prefetched by the "next tile" branch below.
            if (loop_idx == core_idx) {
                WaitFlagDev(AIC_MM3_START);

                // Copy A from GM to the L1 buffer.
                uint64_t offset_a = GetOffsetA(batch_idx, tidx.m, shuffle_k);
                WAIT_FLAG(MTE1, MTE2, event_id);
                CopyTileA(l1_buf_a, gm_a[offset_a], m_actual, m_round, k_actual, k_round);
                SET_FLAG(MTE2, MTE1, event_id);

                // Copy B from GM to the L1 buffer.
                uint64_t offset_b = GetOffsetB(batch_idx, shuffle_k, tidx.n);
                WAIT_FLAG(MTE1, MTE2, event_id + 2);
                CopyTileB(l1_buf_b, gm_b[offset_b], k_actual, k_round, n_actual, n_round);
                SET_FLAG(MTE2, MTE1, event_id + 2);
            }

            for (tidx.k = 0; tidx.k < tdim.k; ++tidx.k) {
                shuffle_k = en_shuffle_k ? (tidx.k + core_idx) % tdim.k : tidx.k;
                uint64_t k_actual = (shuffle_k == (tdim.k - 1)) ? (k - shuffle_k * k0) : k0;
                uint64_t k_round = (k_actual + CONST_16 - 1) / CONST_16 * CONST_16;
                fdim.k = (k_actual + k_part_len - 1) / k_part_len;

                LocalTensor l1_buf_a = ping_flag ? l1_base_a : l1_base_a[L1_PINGPONG_BUFFER_LEN];
                LocalTensor l1_buf_b = ping_flag ? l1_base_b : l1_base_b[L1_PINGPONG_BUFFER_LEN];
                auto event_id = ping_flag ? EVENT_ID0 : EVENT_ID1;

                // Prefetch the next K step of the same output tile into the other L1 half.
                if (tidx.k < tdim.k - 1) {
                    uint64_t shuffle_k_next = en_shuffle_k ? (core_idx + tidx.k + 1) % tdim.k : (tidx.k + 1);
                    uint64_t offset_a_next = GetOffsetA(batch_idx, tidx.m, shuffle_k_next);
                    uint64_t offset_b_next = GetOffsetB(batch_idx, shuffle_k_next, tidx.n);

                    uint64_t k_actual_next = (shuffle_k_next == (tdim.k - 1)) ? (k - shuffle_k_next * k0) : k0;
                    uint64_t k_round_next = (k_actual_next + CONST_16 - 1) / CONST_16 * CONST_16;

                    LocalTensor l1_buf_a_next = (1 - ping_flag) ? l1_base_a : l1_base_a[L1_PINGPONG_BUFFER_LEN];
                    LocalTensor l1_buf_b_next = (1 - ping_flag) ? l1_base_b : l1_base_b[L1_PINGPONG_BUFFER_LEN];
                    event_t event_id_next = (1 - ping_flag) ? EVENT_ID0 : EVENT_ID1;

                    // Preload A from GM to the L1 buffer.
                    WAIT_FLAG(MTE1, MTE2, event_id_next);
                    CopyTileA(l1_buf_a_next, gm_a[offset_a_next], m_actual, m_round, k_actual_next, k_round_next);
                    SET_FLAG(MTE2, MTE1, event_id_next);

                    // Preload B from GM to the L1 buffer.
                    WAIT_FLAG(MTE1, MTE2, event_id_next + 2);
                    CopyTileB(l1_buf_b_next, gm_b[offset_b_next], k_actual_next, k_round_next, n_actual, n_round);
                    SET_FLAG(MTE2, MTE1, event_id_next + 2);
                }

                // On the last K step, prefetch the first K step of this core's next output tile.
                if (tidx.k == tdim.k - 1 && loop_idx + num_core < core_loop) {
                    uint64_t b_idx_next = (loop_idx + num_core) / tdim.n / tdim.m;
                    MatCoord tidx{0};
                    GetBaseBlockIdx(loop_idx + num_core, tidx);
                    uint64_t shuffle_k_next = en_shuffle_k ? (core_idx % tdim.k) : 0;
                    uint64_t m_actual_next = (tidx.m == (tdim.m - 1)) ? (m - tidx.m * m0) : m0;
                    uint64_t n_actual_next = (tidx.n == (tdim.n - 1)) ? (n - tidx.n * n0) : n0;
                    uint64_t m_round_next = (m_actual_next + CONST_16 - 1) / CONST_16 * CONST_16;
                    uint64_t n_round_next = (n_actual_next + CONST_16 - 1) / CONST_16 * CONST_16;
                    uint64_t k_actual_next = (shuffle_k_next == (tdim.k - 1)) ? (k - shuffle_k_next * k0) : k0;
                    uint64_t k_round_next = (k_actual_next + CONST_16 - 1) / CONST_16 * CONST_16;
                    uint64_t offset_a_next = GetOffsetA(b_idx_next, tidx.m, shuffle_k_next);
                    uint64_t offset_b_next = GetOffsetB(b_idx_next, shuffle_k_next, tidx.n);

                    LocalTensor l1_buf_a_next = (1 - ping_flag) ? l1_base_a : l1_base_a[L1_PINGPONG_BUFFER_LEN];
                    LocalTensor l1_buf_b_next = (1 - ping_flag) ? l1_base_b : l1_base_b[L1_PINGPONG_BUFFER_LEN];
                    event_t event_id_next = (1 - ping_flag) ? EVENT_ID0 : EVENT_ID1;

                    // Preload A from GM to the L1 buffer.
                    WAIT_FLAG(MTE1, MTE2, event_id_next);
                    CopyTileA(l1_buf_a_next, gm_a[offset_a_next], m_actual_next, m_round_next, k_actual_next,
                              k_round_next);
                    SET_FLAG(MTE2, MTE1, event_id_next);

                    // Preload B from GM to the L1 buffer.
                    WAIT_FLAG(MTE1, MTE2, event_id_next + 2);
                    CopyTileB(l1_buf_b_next, gm_b[offset_b_next], k_actual_next, k_round_next, n_actual_next,
                              n_round_next);
                    SET_FLAG(MTE2, MTE1, event_id_next + 2);
                }

                // Feed the current L1 K step to the cube in k_part_len chunks, alternating L0 ping/pong.
                MatCoord fidx{0};
                for (fidx.k = 0; fidx.k < fdim.k; ++fidx.k) {
                    uint32_t k0_round = (fidx.k < fdim.k - 1) ? k_part_len : k_round - fidx.k * k_part_len;
                    uint32_t k0_actual = (fidx.k < fdim.k - 1) ? k_part_len : k_actual - fidx.k * k_part_len;

                    auto mte1_mad_ping_flag = 1 - fidx.k % 2;
                    auto mte1_mad_event_id = mte1_mad_ping_flag ? EVENT_ID0 : EVENT_ID1;
                    LocalTensor l0a_buf = l0a_base[(fidx.k & 0b1) * L0_PINGPONG_BUFFER_LEN];
                    LocalTensor l0b_buf = l0b_base[(fidx.k & 0b1) * L0_PINGPONG_BUFFER_LEN];

                    // *** Load matrix A from L1 to L0A.
                    if (fidx.k == 0) {
                        WAIT_FLAG(MTE2, MTE1, event_id);
                    }
                    WAIT_FLAG(M, MTE1, mte1_mad_event_id);
                    if ((m == 1) || (m_actual == 1)) {
                        l1_to_l0_a(
                            l0a_buf,                         // dst
                            l1_buf_a[fidx.k * k_part_len],   // src
                            0,                               // mTileCeil
                            CeilDiv(k0_round),               // kPartCeil
                            0,                               // mSrcStride
                            1,                               // kSrcStride
                            0,                               // mDstStride
                            0);                              // kDstStride
                    } else {
                        LoadCbufToCa(l0a_buf,                                    // l0Tensor
                                     l1_buf_a[fidx.k * k_part_len * m_round],    // l1Tensor
                                     m_round,                                    // mTileCeil
                                     k0_round,                                   // kPartCeil
                                     1,                                          // mSrcStride
                                     m_round / CONST_16,                         // kSrcStride
                                     k0_round / CONST_16,                        // mDstStride
                                     1);                                         // kDstStride
                    }
                    // The last chunk of this K step releases the L1 A half for the next prefetch.
                    if (fidx.k == fdim.k - 1) {
                        SET_FLAG(MTE1, MTE2, event_id);
                    }

                    // *** Load matrix B from L1 to L0B.
                    if (fidx.k == 0) {
                        WAIT_FLAG(MTE2, MTE1, event_id + 2);
                    }
                    if constexpr (transB) {
                        LoadCbufToCb(l0b_buf,                                    // l0Tensor
                                     l1_buf_b[fidx.k * k_part_len * n_round],    // l1Tensor
                                     n_round,                                    // nTileCeil
                                     k0_round,                                   // kPartCeil
                                     1,                                          // nSrcStride
                                     n_round / CONST_16,                         // kSrcStride
                                     1,                                          // nDstStride
                                     k0_round / CONST_16);                       // kDstStride
                    } else {
                        LoadCbufToCb(l0b_buf,                                    // l0Tensor
                                     l1_buf_b[fidx.k * k_part_len * CONST_16],   // l1Tensor
                                     n_round,                                    // nTileCeil
                                     k0_round,                                   // kPartCeil
                                     k_round / CONST_16,                         // nSrcStride
                                     1,                                          // kSrcStride
                                     1,                                          // nDstStride
                                     n_round / CONST_16);                        // kDstStride
                    }
                    if (fidx.k == fdim.k - 1) {
                        SET_FLAG(MTE1, MTE2, event_id + 2);
                    }

                    SET_FLAG(MTE1, M, mte1_mad_event_id);
                    WAIT_FLAG(MTE1, M, mte1_mad_event_id);

                    // The first chunk of the first K step overwrites L0C; later chunks accumulate.
                    // Before overwriting, wait until the previous tile's L0C copy-out has finished.
                    bool init_c = (tidx.k == 0 && fidx.k == 0);
                    if (init_c) {
                        WAIT_FLAG(FIX, M, EVENT_ID0);
                    }

                    Mad(l0c_buf,    // c
                        l0a_buf,    // a
                        l0b_buf,    // b
                        m_actual,   // mTileActual
                        n_actual,   // nTileActual
                        k0_actual,  // kTileActual
                        init_c);    // initC

                    PIPE_BARRIER(M);
                    SET_FLAG(M, MTE1, mte1_mad_event_id);
                }

                ping_flag = 1 - ping_flag;
            }

            SET_FLAG(M, FIX, EVENT_ID0);
            WAIT_FLAG(M, FIX, EVENT_ID0);

            // Copy from L0C to GM. The destination row stride is (n + splitGapC) * batch_size.
            CopyCcToGm(gm_c[offset_c],                    // dst
                       l0c_buf,                           // src
                       m_actual,                          // mTileActual
                       n_actual,                          // nTileActual
                       m_round,                           // mTileCeil
                       (n + splitGapC) * batch_size);     // nActual
            SET_FLAG(FIX, M, EVENT_ID0);
        }

        // Drain the events pre-set at the top.
        WAIT_FLAG(M, MTE1, EVENT_ID0);
        WAIT_FLAG(M, MTE1, EVENT_ID1);
        WAIT_FLAG(MTE1, MTE2, EVENT_ID0);
        WAIT_FLAG(MTE1, MTE2, EVENT_ID1);
        WAIT_FLAG(MTE1, MTE2, EVENT_ID2);
        WAIT_FLAG(MTE1, MTE2, EVENT_ID3);
        WAIT_FLAG(FIX, M, EVENT_ID0);
#endif
    }

private:
    // Maps a linear tile index to (m, n) tile coordinates within one batch.
    //   Zn (swizzleDirect == 0): bands of swizzle_cnt m-rows; odd-numbered bands walk n in reverse.
    //   Nz (swizzleDirect == 1): bands of swizzle_cnt n-columns; odd-numbered bands walk m in reverse.
    // Any other swizzleDirect value leaves tidx unchanged.
    __aicore__ __force_inline__ void GetBaseBlockIdx(uint64_t index, MatCoord &tidx) {
        uint64_t in_batch_idx = index % (tdim.m * tdim.n);
        if constexpr (swizzleDirect == 0) {  // Zn
            uint64_t tile_block_loop = (tdim.m + swizzle_cnt - 1) / swizzle_cnt;
            uint64_t tile_block_idx = in_batch_idx / (swizzle_cnt * tdim.n);
            uint64_t in_tile_block_idx = in_batch_idx % (swizzle_cnt * tdim.n);

            // The last band may hold fewer than swizzle_cnt rows.
            uint64_t n_row = swizzle_cnt;
            if (tile_block_idx == tile_block_loop - 1) {
                n_row = tdim.m - swizzle_cnt * tile_block_idx;
            }
            tidx.m = tile_block_idx * swizzle_cnt + in_tile_block_idx % n_row;
            tidx.n = in_tile_block_idx / n_row;
            if (tile_block_idx % 2 != 0) {
                tidx.n = tdim.n - tidx.n - 1;
            }
        } else if constexpr (swizzleDirect == 1) {  // Nz
            uint64_t tile_block_loop = (tdim.n + swizzle_cnt - 1) / swizzle_cnt;
            uint64_t tile_block_idx = in_batch_idx / (swizzle_cnt * tdim.m);
            uint64_t in_tile_block_idx = in_batch_idx % (swizzle_cnt * tdim.m);

            // The last band may hold fewer than swizzle_cnt columns.
            uint64_t n_col = swizzle_cnt;
            if (tile_block_idx == tile_block_loop - 1) {
                n_col = tdim.n - swizzle_cnt * tile_block_idx;
            }
            tidx.m = in_tile_block_idx / n_col;
            tidx.n = tile_block_idx * swizzle_cnt + in_tile_block_idx % n_col;
            if (tile_block_idx % 2 != 0) {
                tidx.m = tdim.m - tidx.m - 1;
            }
        }
        return;
    }

    // Element offset of A tile (bIdx, mIdx, kIdx). A rows are interleaved by batch with a
    // row pitch of k + splitGapA.
    __aicore__ __force_inline__ uint64_t GetOffsetA(const uint64_t bIdx, const uint64_t mIdx, const uint64_t kIdx) {
        return mIdx * m0 * batch_size * (k + splitGapA) + bIdx * (k + splitGapA) + kIdx * k0;
    }

    // Element offset of B tile (bIdx, kIdx, nIdx).
    //   ND + transB : B is [batch, n, k].
    //   ND          : B is [batch, k, n].
    //   Otherwise   : B is in a rounded-up fractal layout with 16-wide column/row blocks.
    __aicore__ __force_inline__ uint64_t GetOffsetB(const uint64_t bIdx, const uint64_t kIdx, const uint64_t nIdx) {
        if constexpr (formatB == DataFormat::ND) {
            if constexpr (transB) {
                return bIdx * k * n + nIdx * n0 * k + kIdx * k0;
            } else {
                return bIdx * k * n + kIdx * k0 * n + nIdx * n0;
            }
        } else {
            if constexpr (transB) {
                return bIdx * RoundUp(n) * RoundUp(k) + kIdx * k0 * RoundUp(n) + nIdx * n0 * CONST_16;
            } else {
                return bIdx * RoundUp(k) * RoundUp(n) + nIdx * n0 * RoundUp(k) + kIdx * k0 * CONST_16;
            }
        }
    }

    // Copies an A tile (m_actual x k_actual, padded to m_round x k_round) from GM into L1.
    // Single-row A (m == 1 or m_actual == 1) uses a one-row copy padded to CONST_16 rows.
    __aicore__ __force_inline__ void CopyTileA(AscendC::LocalTensor &dstTensor,
                                               const AscendC::GlobalTensor &srcTensor,
                                               const uint64_t m_actual,
                                               const uint64_t m_round,
                                               const uint64_t k_actual,
                                               const uint64_t k_round) {
        if ((m == 1) || (m_actual == 1)) {
            CopyGmToCbuf(dstTensor,  // dst
                         srcTensor,  // src
                         1,          // nTileActual
                         CONST_16,   // nTileCeil
                         1,          // nVal
                         k_actual,   // kTileActual
                         k_round,    // kTileCeil
                         k);         // dVal
        } else {
            CopyGmToCbuf(dstTensor,                         // dst
                         srcTensor,                         // src
                         m_actual,                          // nTileActual
                         m_round,                           // nTileCeil
                         m,                                 // nVal
                         k_actual,                          // dTileActual
                         k_round,                           // dTileCeil
                         (k + splitGapA) * batch_size);     // dVal (GM row pitch)
        }
    }

    // Copies a B tile from GM into L1, using the layout selected by formatB / transB (see GetOffsetB).
    __aicore__ __force_inline__ void CopyTileB(AscendC::LocalTensor &dstTensor,
                                               const AscendC::GlobalTensor &srcTensor,
                                               const uint64_t k_actual,
                                               const uint64_t k_round,
                                               const uint64_t n_actual,
                                               const uint64_t n_round) {
        if constexpr (formatB == DataFormat::ND) {
            if constexpr (transB) {
                CopyGmToCbuf(dstTensor,  // dst
                             srcTensor,  // src
                             n_actual,   // nTileActual
                             n_round,    // nTileCeil
                             n,          // nVal
                             k_actual,   // dTileActual
                             k_round,    // dTileCeil
                             k);         // dVal
            } else {
                CopyGmToCbuf(dstTensor,  // dst
                             srcTensor,  // src
                             k_actual,   // nTileActual
                             k_round,    // nTileCeil
                             k,          // nVal
                             n_actual,   // dTileActual
                             n_round,    // dTileCeil
                             n);         // dVal
            }
        } else {
            if constexpr (transB) {
                CopyGmToCbuf(dstTensor,     // dst
                             srcTensor,     // src
                             n_actual,      // nTileActual
                             n_round,       // nTileCeil
                             RoundUp(n),    // nVal
                             k_actual,      // dTileActual
                             k_round,       // dTileCeil
                             RoundUp(k));   // dVal
            } else {
                CopyGmToCbuf(dstTensor,     // dst
                             srcTensor,     // src
                             k_actual,      // nTileActual
                             k_round,       // nTileCeil
                             RoundUp(k),    // nVal
                             n_actual,      // dTileActual
                             n_round,       // dTileCeil
                             RoundUp(n));   // dVal
            }
        }
    }

private:
    AscendC::GlobalTensor gm_a;
    AscendC::GlobalTensor gm_b;
    AscendC::GlobalTensor gm_c;
    AscendC::LocalTensor l1_base_a;
    AscendC::LocalTensor l1_base_b;
    AscendC::LocalTensor l0a_base;
    AscendC::LocalTensor l0b_base;
    AscendC::LocalTensor l0c_buf;

    uint32_t num_core{0};
    uint32_t batch_size{0};
    uint32_t m{0};
    uint32_t k{0};
    uint32_t n{0};
    uint32_t m0{0};
    uint32_t k0{0};
    uint32_t n0{0};
    MatCoord tdim{0};  // tile counts per dimension (mLoop, kLoop, nLoop)
    MatCoord fdim{0};  // fdim.k = number of L0 k_part_len chunks in the current K step
    uint32_t core_loop{0};
    uint32_t swizzle_cnt{1};
    uint32_t core_idx{0};
    uint32_t en_shuffle_k{0};  // never set in Init, so K shuffling stays off for this class
    uint32_t ping_flag{0};
};

/**
 * PpMatmulW8a8Aic: int8 x int8 -> int32 cube matmul (AIC side) for MM1 / MM2.
 *
 * GM layouts implied by the offset math:
 *   A is [batch, m, k].
 *   C is [batch, m, n].
 *   B is [batch, n, k] in ND format; the NZ offsets are in GetOffsetB.
 *
 * Each output tile is written to GM. Then, unless withSyncAll is set, the vector side is signalled via
 * FftsCrossCoreSync(MMAIC), and every MAX_HW_SYNC_COUNTER tiles this core waits for MMAIV.
 *
 * Optional modes:
 *   - load_all_Amat_flag: the whole A tile row (all of k) is loaded once into L1.
 *   - PreloadWeight(): the first one or two B K-steps are loaded before Process().
 */
template class PpMatmulW8a8Aic {
    using InDtype = int8_t;
    using OutDtype = int32_t;
    using AccumDtype = int32_t;

    template using CopyGmToCbuf = gm_to_l1;
    using LoadCbufToCa = l1_to_l0_a;
    using LoadCbufToCb = l1_to_l0_b;
    using Mmad = mmad;
    using CopyCcToGm = l0c_to_gm;

    static constexpr uint64_t L0_PINGPONG_BUFFER_LEN = 32768;
    static constexpr uint64_t L1_PINGPONG_BUFFER_LEN = 262144;
    static constexpr uint64_t BLOCK_SIZE_16 = 16;
    static constexpr uint64_t BLOCK_SIZE_32 = 32;
    static constexpr uint64_t CUBE_MATRIX_SIZE_512 = 512;
    static constexpr uint64_t CONST_4 = 4;
    static constexpr uint64_t CONST_8 = 8;
    static constexpr uint64_t CONST_32 = 32;
    static constexpr uint64_t CONST_64 = 64;
    static constexpr uint64_t CONST_128 = 128;

public:
    __aicore__ PpMatmulW8a8Aic() {};

    /**
     * Binds the GM tensors, copies the tiling into members, and carves the L1/L0 buffers.
     * @param mode  0 selects MM1 (waits on AIC_MM1_START); 1 selects MM2 (waits on AIC_MM2_START).
     */
    __aicore__ __force_inline__ void Init(GM_ADDR gmA, GM_ADDR gmB, GM_ADDR gmC, PpMatmulTilingData &tilingdata,
                                          uint32_t mode) {
        gm_a.SetGlobalBuffer(reinterpret_cast<__gm__ InDtype *>(gmA));
        gm_b.SetGlobalBuffer(reinterpret_cast<__gm__ InDtype *>(gmB));
        gm_c.SetGlobalBuffer(reinterpret_cast<__gm__ OutDtype *>(gmC));

        batch_size = tilingdata.numBatch;
        m = tilingdata.m;
        k = tilingdata.k;
        n = tilingdata.n;
        m0 = tilingdata.m0;
        k0 = tilingdata.k0;
        n0 = tilingdata.n0;
        m_loop = tilingdata.mLoop;
        k_loop = tilingdata.kLoop;
        n_loop = tilingdata.nLoop;
        core_loop = tilingdata.coreLoop;
        swizzle_cnt = tilingdata.swizzleCount;
        en_shuffle_k = tilingdata.enShuffleK;
        core_num = tilingdata.blockDim;
        load_all_Amat_flag = tilingdata.enLoadAllAmat;
        b0mat_pingpong_buffer_len = tilingdata.b0matPingPongBufferLen;

        core_idx = AscendC::GetBlockIdx();
        ping_flag = 1;
        MM1_MM2_mode = mode;  // 0: MM1, 1: MM2

        InitBuffer();
        return;
    }

    // Element offset of A tile (batchIdx, mIdx, kIdx) in a [batch, m, k] layout.
    __aicore__ __force_inline__ uint64_t GetOffsetA(const uint64_t batchIdx, const uint64_t mIdx, uint64_t kIdx) {
        return batchIdx * m * k + mIdx * m0 * k + kIdx * k0;
    }

    // Element offset of B tile (batchIdx, kIdx, nIdx).
    //   ND : B is [batch, n, k].
    //   NZ : B is in a fractal layout with n rounded to 16 and k rounded to 32.
    __aicore__ __force_inline__ uint64_t GetOffsetB(const uint64_t batchIdx, const uint64_t kIdx, uint64_t nIdx) {
        if constexpr (formatB == DataFormat::ND) {
            return batchIdx * k * n + nIdx * n0 * k + kIdx * k0;
        } else {
            return batchIdx * RoundUp<16>(n) * RoundUp<32>(k) + kIdx * k0 * RoundUp<16>(n) + nIdx * n0 * CONST_32;
        }
    }

    // Copies an A tile from GM into L1.
    // Single-row A uses a one-row copy padded to BLOCK_SIZE_16 rows.
    // NOTE(review): the multi-row path passes `n` as nVal. The analogous call in
    // PpMatmulEinSum::CopyTileA passes `m`. Confirm whether gm_to_l1 reads nVal for this copy.
    __aicore__ __force_inline__ void CopyTileA(AscendC::LocalTensor &dstTensor,
                                               const AscendC::GlobalTensor &srcTensor,
                                               const uint64_t m_actual,
                                               const uint64_t m_round,
                                               const uint64_t k_actual,
                                               const uint64_t k_round) {
        if ((m == 1) || (m_actual == 1)) {
            CopyGmToCbuf(dstTensor,  // dst
                         srcTensor,  // src
                         1, BLOCK_SIZE_16, 1, k_actual, k_round, k);
        } else {
            CopyGmToCbuf(dstTensor,  // dst
                         srcTensor,  // src
                         m_actual,   // nTileActual
                         m_round,    // nTileCeil
                         n,          // nVal
                         k_actual,   // dTileActual
                         k_round,    // dTileCeil
                         k);         // dVal
        }
    }

    // Copies a B tile (n_actual x k_actual, padded to n_round x k_round) from GM into L1.
    __aicore__ __force_inline__ void CopyTileB(const AscendC::LocalTensor &dstTensor,
                                               const AscendC::GlobalTensor &srcTensor,
                                               const uint64_t k_actual,
                                               const uint64_t k_round,
                                               const uint64_t n_actual,
                                               const uint64_t n_round) {
        if constexpr (formatB == DataFormat::ND) {
            CopyGmToCbuf(dstTensor,  // dst
                         srcTensor,  // src
                         n_actual,   // nTileActual
                         n_round,    // nTileCeil
                         n,          // nVal
                         k_actual,   // dTileActual
                         k_round,    // dTileCeil
                         k);         // dVal
        } else {
            CopyGmToCbuf(dstTensor,         // dst
                         srcTensor,         // src
                         n_actual,          // nTileActual
                         n_round,           // nTileCeil
                         RoundUp<16>(n),    // nVal
                         k_actual,          // dTileActual
                         k_round,           // dTileCeil
                         RoundUp<32>(k));   // dVal
        }
    }

    // Loads this core's first B K-step into l1_base_b. When k_loop > 1, it also loads the second
    // K-step into l1_base_b[b0mat_pingpong_buffer_len].
    // Process() skips exactly these two copies on its first tile, so it presumably expects this
    // to be called first -- confirm at the call site.
    // NOTE(review): batch index 0 is used here, while Process() uses batch_idx. The two match only
    // when this core's first tile is in batch 0.
    __aicore__ __force_inline__ void PreloadWeight() {
        if (core_idx < core_num) {
            uint64_t m_idx = 0;
            uint64_t n_idx = 0;
            GetBaseBlockIdx(core_idx, m_idx, n_idx);
            uint64_t shuffle_k = en_shuffle_k ? core_idx % k_loop : 0;
            uint64_t offset_b = GetOffsetB(0, shuffle_k, n_idx);
            uint64_t k_actual = (shuffle_k == k_loop - 1) ? k - shuffle_k * k0 : k0;
            uint64_t k_round = RoundUp(k_actual);
            uint64_t n_actual = (n_idx == (n_loop - 1)) ? (n - n_idx * n0) : n0;
            uint64_t n_round = RoundUp(n_actual);
            CopyTileB(l1_base_b, gm_b[offset_b], k_actual, k_round, n_actual, n_round);
        }
        if (core_idx < core_num && k_loop > 1) {
            uint64_t m_idx = 0;
            uint64_t n_idx = 0;
            GetBaseBlockIdx(core_idx, m_idx, n_idx);
            uint64_t shuffle_k = en_shuffle_k ? (core_idx + 1) % k_loop : 1;
            uint64_t offset_b = GetOffsetB(0, shuffle_k, n_idx);
            uint64_t k_actual = (shuffle_k == k_loop - 1) ? k - shuffle_k * k0 : k0;
            uint64_t k_round = RoundUp(k_actual);
            uint64_t n_actual = (n_idx == (n_loop - 1)) ? (n - n_idx * n0) : n0;
            uint64_t n_round = RoundUp(n_actual);
            CopyTileB(l1_base_b[b0mat_pingpong_buffer_len], gm_b[offset_b], k_actual, k_round, n_actual, n_round);
        }
    }

    __aicore__ __force_inline__ void Process();

private:
    // Carves the L1, L0A, L0B and L0C buffers.
    // L1 B starts right after the A region. The A region is the whole A (rounded m x rounded k) when
    // load_all_Amat_flag is set, and one m0 x k0 tile otherwise.
    __aicore__ __force_inline__ void InitBuffer() {
        AsdopsBuffer buf;
        l1_base_a = buf.template GetBuffer(0);

        // Try to load the whole A matrix.
        uint32_t a_l1_size = RoundUp(m) * RoundUp(k);
        if (!load_all_Amat_flag) {
            a_l1_size = RoundUp(m0 * k0);
        }
        l1_base_b = l1_base_a[a_l1_size];
        l0a_base = buf.template GetBuffer(0);
        l0b_base = buf.template GetBuffer(0);
        l0c_buf = buf.template GetBuffer(0);
    }

    // Maps a linear tile index to (m_idx, n_idx) within one batch.
    //   Zn (swizzleDir == 0): bands of swizzle_cnt m-rows; odd-numbered bands walk n in reverse.
    //   Nz (otherwise):       bands of swizzle_cnt n-columns; odd-numbered bands walk m in reverse.
    __aicore__ __force_inline__ void GetBaseBlockIdx(uint64_t index, uint64_t &m_idx, uint64_t &n_idx) {
        uint64_t in_batch_idx = index % (m_loop * n_loop);
        if constexpr (swizzleDir == 0) {  // Zn
            uint64_t tile_block_loop = (m_loop + swizzle_cnt - 1) / swizzle_cnt;
            uint64_t tile_block_idx = in_batch_idx / (swizzle_cnt * n_loop);
            uint64_t in_tile_block_idx = in_batch_idx % (swizzle_cnt * n_loop);

            uint64_t n_row = swizzle_cnt;
            if (tile_block_idx == tile_block_loop - 1) {
                n_row = m_loop - swizzle_cnt * tile_block_idx;
            }
            m_idx = tile_block_idx * swizzle_cnt + in_tile_block_idx % n_row;
            n_idx = in_tile_block_idx / n_row;
            if ((tile_block_idx & 0b1) != 0) {
                n_idx = n_loop - n_idx - 1;
            }
        } else {  // Nz
            uint64_t tile_block_loop = (n_loop + swizzle_cnt - 1) / swizzle_cnt;
            uint64_t tile_block_idx = in_batch_idx / (swizzle_cnt * m_loop);
            uint64_t in_tile_block_idx = in_batch_idx % (swizzle_cnt * m_loop);

            uint64_t n_col = swizzle_cnt;
            if (tile_block_idx == tile_block_loop - 1) {
                n_col = n_loop - swizzle_cnt * tile_block_idx;
            }
            m_idx = in_tile_block_idx / n_col;
            n_idx = tile_block_idx * swizzle_cnt + in_tile_block_idx % n_col;
            if ((tile_block_idx & 0b1) != 0) {
                m_idx = m_loop - m_idx - 1;
            }
        }
        return;
    }

private:
    AscendC::GlobalTensor gm_a;
    AscendC::GlobalTensor gm_b;
    AscendC::GlobalTensor gm_c;
    AscendC::LocalTensor l1_base_a;
    AscendC::LocalTensor l1_base_b;
    AscendC::LocalTensor l0a_base;
    AscendC::LocalTensor l0b_base;
    AscendC::LocalTensor l0c_buf;

    uint64_t bias_bt{0};  // not referenced in this chunk
    uint32_t core_num{0};
    uint32_t batch_size{0};
    uint32_t m{0};
    uint32_t k{0};
    uint32_t n{0};
    uint32_t m0{0};
    uint32_t k0{0};
    uint32_t n0{0};
    uint32_t m_loop{0};
    uint32_t n_loop{0};
    uint32_t k_loop{0};
    uint32_t core_loop{0};
    uint32_t core_idx{0};
    uint32_t ping_flag{0};
    uint32_t swizzle_cnt{1};
    uint32_t en_shuffle_k{0};
    uint32_t MM1_MM2_mode{0};
    uint64_t b0mat_pingpong_buffer_len{0};  // element offset of the pong half of L1 B
    bool load_all_Amat_flag{false};
};

/**
 * Computes this core's output tiles: loop_idx = core_idx, core_idx + core_num, ...
 *
 * For each tile:
 *   1. Load the first A/B K-step into the current L1 half.
 *   2. For each K step, prefetch the next K step into the other L1 half.
 *   3. Feed the current half to the cube in k_part_len chunks through L0 ping/pong buffers.
 *   4. Write L0C to GM.
 *
 * Event usage:
 *   MTE1->MTE2 EVENT_ID0/1            : L1 A half is free
 *   MTE1->MTE2 EVENT_ID2/3 (+CONST_2) : L1 B half is free
 *   M->MTE1 EVENT_ID0/1               : L0 ping/pong is free
 *   FIX->M EVENT_ID0                  : L0C has been written back
 *
 * NOTE(review): FIX->MTE2 EVENT_ID0 and MTE1->MTE2 EVENT_ID7 are set at entry and waited at exit, with
 * no use in between.
 * NOTE(review): m_round_32 and offset_bias are computed but never read.
 */
template
__aicore__ __force_inline__ void PpMatmulW8a8Aic::Process() {
    using LocalTensor = AscendC::LocalTensor;
    // Idle cores still consume the start flag for the selected mode before returning.
    if (core_idx >= core_num) {
        if (MM1_MM2_mode == 0) {
            WaitFlagDev(AIC_MM1_START);
        } else if (MM1_MM2_mode == 1) {
            WaitFlagDev(AIC_MM2_START);
        }
        return;
    }
    SET_FLAG(MTE1, MTE2, EVENT_ID0);
    SET_FLAG(MTE1, MTE2, EVENT_ID1);
    SET_FLAG(MTE1, MTE2, EVENT_ID2);
    SET_FLAG(MTE1, MTE2, EVENT_ID3);
    SET_FLAG(M, MTE1, EVENT_ID0);
    SET_FLAG(M, MTE1, EVENT_ID1);
    SET_FLAG(FIX, M, EVENT_ID0);
    SET_FLAG(FIX, MTE2, EVENT_ID0);
    SET_FLAG(MTE1, MTE2, EVENT_ID7);
    for (uint64_t loop_idx = core_idx; loop_idx < core_loop; loop_idx += core_num) {
        uint64_t batch_idx = loop_idx / n_loop / m_loop;
        uint64_t m_idx = 0;
        uint64_t n_idx = 0;
        GetBaseBlockIdx(loop_idx, m_idx, n_idx);
        uint64_t offset_a;
        uint64_t offset_b;
        uint64_t offset_bias;
        uint64_t offset_a_next;
        uint64_t offset_b_next;
        uint64_t offset_c = batch_idx * m * n + m_idx * m0 * n + n_idx * n0;
        // The last tile in each dimension may be partial.
        uint64_t m_actual = (m_idx == (m_loop - 1)) ? (m - m_idx * m0) : m0;
        uint64_t n_actual = (n_idx == (n_loop - 1)) ? (n - n_idx * n0) : n0;
        uint64_t m_round = 0;
        uint64_t n_round = 0;
        uint64_t shuffle_k = en_shuffle_k ? core_idx % k_loop : 0;
        uint64_t m_round_16 = RoundUp(m_actual);
        uint64_t m_round_32 = RoundUp(m_actual);
        m_round = m_round_16;
        n_round = RoundUp(n_actual);

        uint64_t mn_max = m_round > n_round ? m_round : n_round;
        uint64_t k_part_len = 0;
        // Largest multiple of BLOCK_SIZE_32 such that mn_max * k_part_len fits one L0 ping/pong half.
        k_part_len = L0_PINGPONG_BUFFER_LEN / mn_max / BLOCK_SIZE_32 * BLOCK_SIZE_32;

        offset_b = GetOffsetB(batch_idx, shuffle_k, n_idx);
        offset_bias = batch_idx * n + n_idx * n0;

        uint64_t k_actual = (shuffle_k == k_loop - 1) ? k - shuffle_k * k0 : k0;
        uint64_t k_round = RoundUp(k_actual);

        auto event_id = ping_flag ? EVENT_ID0 : EVENT_ID1;

        // Wait for the cross-core start flag before touching GM inputs (first tile only).
        if (loop_idx == core_idx) {
            if (MM1_MM2_mode == 0) {
                WaitFlagDev(AIC_MM1_START);
            } else if (MM1_MM2_mode == 1) {
                WaitFlagDev(AIC_MM2_START);
            }
        }

        WAIT_FLAG(MTE1, MTE2, event_id);
        LocalTensor l1_buf_a =
            load_all_Amat_flag ? l1_base_a : (ping_flag ? l1_base_a : l1_base_a[L1_PINGPONG_BUFFER_LEN]);
        LocalTensor l1_buf_b = ping_flag ? l1_base_b : l1_base_b[b0mat_pingpong_buffer_len];
        if (load_all_Amat_flag) {
            // Whole A row (all of k) is loaded once, on this core's first tile only, and reused afterwards.
            // NOTE(review): this assumes every tile on this core shares m_idx/batch_idx -- confirm the tiling.
            if (loop_idx == core_idx) {
                offset_a = GetOffsetA(batch_idx, m_idx, 0);
                uint64_t k_actual_first = k;
                uint64_t k_round_first = RoundUp(k_actual_first);
                CopyTileA(l1_buf_a, gm_a[offset_a], m_actual, m_round, k_actual_first, k_round_first);
            }
        } else {
            offset_a = GetOffsetA(batch_idx, m_idx, shuffle_k);
            CopyTileA(l1_buf_a, gm_a[offset_a], m_actual, m_round, k_actual, k_round);
        }
        SET_FLAG(MTE2, MTE1, event_id);

        WAIT_FLAG(MTE1, MTE2, event_id + CONST_2);
        // The first weight block of this core's first tile was already loaded by PreloadWeight().
        if (loop_idx != core_idx) {
            CopyTileB(l1_buf_b, gm_b[offset_b], k_actual, k_round, n_actual, n_round);
        }
        SET_FLAG(MTE2, MTE1, event_id + CONST_2);

        for (uint64_t k_idx = 0; k_idx < k_loop; k_idx++) {
            shuffle_k = en_shuffle_k ? (k_idx + core_idx) % k_loop : k_idx;
            uint32_t k_actual = (shuffle_k == (k_loop - 1)) ? (k - shuffle_k * k0) : k0;
            uint32_t k_round = RoundUp(k_actual);
            uint32_t k_part_loop = (k_actual + k_part_len - 1) / k_part_len;

            // --------- With the whole A in L1, only the address changes per K step. -------------
            // NOTE(review): this offset uses k_idx (not shuffle_k) and m0 (not m_round). The L0A load below
            // scales its k offset by m_round. The two agree only when m_round == m0 and k-shuffling is off.
            // Presumably the tiling guarantees this when enLoadAllAmat is set -- confirm.
            LocalTensor l1_buf_a = load_all_Amat_flag ? (l1_base_a[k_idx * m0 * k0 * sizeof(int8_t)])
                                                      : (ping_flag ? l1_base_a : l1_base_a[L1_PINGPONG_BUFFER_LEN]);
            LocalTensor l1_buf_b = ping_flag ? l1_base_b : l1_base_b[b0mat_pingpong_buffer_len];
            auto event_id = ping_flag ? EVENT_ID0 : EVENT_ID1;

            // Prefetch the next K step of this tile into the other L1 half.
            if (k_idx < k_loop - 1) {
                uint64_t shuffle_k_next = en_shuffle_k ? (core_idx + k_idx + 1) % k_loop : k_idx + 1;

                offset_b_next = GetOffsetB(batch_idx, shuffle_k_next, n_idx);
                uint32_t k_actual_next = (shuffle_k_next == (k_loop - 1)) ? (k - shuffle_k_next * k0) : k0;
                uint32_t k_round_next = RoundUp(k_actual_next);

                LocalTensor l1_buf_a_next =
                    load_all_Amat_flag ? l1_base_a : ((1 - ping_flag) ? l1_base_a : l1_base_a[L1_PINGPONG_BUFFER_LEN]);
                LocalTensor l1_buf_b_next = (1 - ping_flag) ? l1_base_b : l1_base_b[b0mat_pingpong_buffer_len];

                auto event_id_next = (1 - ping_flag) ? EVENT_ID0 : EVENT_ID1;

                WAIT_FLAG(MTE1, MTE2, event_id_next);
                if (!load_all_Amat_flag) {
                    offset_a_next = GetOffsetA(batch_idx, m_idx, shuffle_k_next);
                    CopyTileA(l1_buf_a_next, gm_a[offset_a_next], m_actual, m_round, k_actual_next, k_round_next);
                }
                SET_FLAG(MTE2, MTE1, event_id_next);

                WAIT_FLAG(MTE1, MTE2, event_id_next + CONST_2);
                // The second weight block of this core's first tile was already loaded by PreloadWeight().
                if (loop_idx != core_idx || k_idx != 0) {
                    CopyTileB(l1_buf_b_next, gm_b[offset_b_next], k_actual_next, k_round_next, n_actual, n_round);
                }
                SET_FLAG(MTE2, MTE1, event_id_next + CONST_2);
            }

            // Feed the current L1 K step to the cube in k_part_len chunks, alternating L0 ping/pong.
            for (int k_part_idx = 0; k_part_idx < k_part_loop; k_part_idx++) {
                uint32_t k0_round = (k_part_idx < k_part_loop - 1) ? k_part_len : k_round - k_part_idx * k_part_len;
                uint32_t k0_actual = (k_part_idx < k_part_loop - 1) ? k_part_len : k_actual - k_part_idx * k_part_len;

                auto mte1_mad_ping_flag = 1 - k_part_idx % 2;
                auto mte1_mad_event_id = mte1_mad_ping_flag ? EVENT_ID0 : EVENT_ID1;
                AscendC::LocalTensor l0a_buf = l0a_base[(k_part_idx % 2) * L0_PINGPONG_BUFFER_LEN];
                AscendC::LocalTensor l0b_buf = l0b_base[(k_part_idx % 2) * L0_PINGPONG_BUFFER_LEN];

                // *** Load matrix A from L1 to L0A.
                if (k_part_idx == 0) {
                    WAIT_FLAG(MTE2, MTE1, event_id);
                }
                WAIT_FLAG(M, MTE1, mte1_mad_event_id);
                if ((m == 1) || (m_actual == 1)) {
                    l1_to_l0_a(
                        l0a_buf, l1_buf_a[k_part_idx * k_part_len],
                        0,                  // mTileCeil
                        CeilDiv(k0_round),  // kPartCeil
                        0,                  // mSrcStride
                        1,                  // kSrcStride
                        0,                  // mDstStride
                        0);                 // kDstStride
                } else {
                    LoadCbufToCa(l0a_buf,                                        // l0Tensor
                                 l1_buf_a[k_part_idx * k_part_len * m_round],    // l1Tensor
                                 m_round,                                        // mTileCeil
                                 k0_round,                                       // kPartCeil
                                 1,                                              // mSrcStride
                                 m_round / BLOCK_SIZE_16,                        // kSrcStride
                                 k0_round / BLOCK_SIZE_32,                       // mDstStride
                                 1);                                             // kDstStride
                }
                // The last chunk of this K step releases the L1 A half for the next prefetch.
                if (k_part_idx == k_part_loop - 1) {
                    SET_FLAG(MTE1, MTE2, event_id);
                }

                // *** Load matrix B from L1 to L0B.
                if (k_part_idx == 0) {
                    WAIT_FLAG(MTE2, MTE1, event_id + CONST_2);
                }
                LoadCbufToCb(l0b_buf,                                        // l0Tensor
                             l1_buf_b[k_part_idx * k_part_len * n_round],    // l1Tensor
                             n_round,                                        // nTileCeil
                             k0_round,                                       // kPartCeil
                             1,                                              // nSrcStride
                             n_round / BLOCK_SIZE_16,                        // kSrcStride
                             1,                                              // nDstStride
                             k0_round / BLOCK_SIZE_32);                      // kDstStride
                if (k_part_idx == k_part_loop - 1) {
                    SET_FLAG(MTE1, MTE2, event_id + CONST_2);
                }

                SET_FLAG(MTE1, M, mte1_mad_event_id);
                WAIT_FLAG(MTE1, M, mte1_mad_event_id);

                // The first chunk of the first K step overwrites L0C; later chunks accumulate.
                // Before overwriting, wait until the previous tile's L0C copy-out has finished.
                bool init_c = (k_idx == 0 && k_part_idx == 0);
                if (init_c) {
                    WAIT_FLAG(FIX, M, EVENT_ID0);
                }

                Mmad(l0c_buf, l0a_buf, l0b_buf,
                     m_actual,   // m
                     n_actual,   // n
                     k0_actual,  // k
                     init_c);    // cmatrixInitVal

                PIPE_BARRIER(M);
                SET_FLAG(M, MTE1, mte1_mad_event_id);
            }

            ping_flag = 1 - ping_flag;
        }

        SET_FLAG(M, FIX, EVENT_ID0);
        WAIT_FLAG(M, FIX, EVENT_ID0);

        // Copy from L0C to GM.
        CopyCcToGm(gm_c[offset_c],  // dst
                   l0c_buf,         // src
                   m_actual,        // MSize
                   n_actual,        // NSize
                   m_round_16,      // srcStride
                   n);              // dstStride_dst_D
        SET_FLAG(FIX, M, EVENT_ID0);

        // Signal the vector side after each tile. Every MAX_HW_SYNC_COUNTER tiles, wait for its
        // acknowledgement (MMAIV) -- presumably to bound the number of outstanding cross-core flags.
        if constexpr (!withSyncAll) {
            FftsCrossCoreSync(MMAIC);
            if ((loop_idx / core_num + 1) % MAX_HW_SYNC_COUNTER == 0) {
                WaitFlagDev(MMAIV);
            }
        }
    }

    // Drain the events pre-set at entry.
    WAIT_FLAG(MTE1, MTE2, EVENT_ID0);
    WAIT_FLAG(MTE1, MTE2, EVENT_ID1);
    WAIT_FLAG(MTE1, MTE2, EVENT_ID2);
    WAIT_FLAG(MTE1, MTE2, EVENT_ID3);
    WAIT_FLAG(M, MTE1, EVENT_ID0);
    WAIT_FLAG(M, MTE1, EVENT_ID1);
    WAIT_FLAG(FIX, M, EVENT_ID0);
    WAIT_FLAG(FIX, MTE2, EVENT_ID0);
    WAIT_FLAG(MTE1, MTE2, EVENT_ID7);
}
#endif

#if defined(__DAV_C220_VEC__)
template class PpMatmulW8a8Aiv {
    using InDtype = int32_t;
    using ScaleDtype = float;
    using BiasDtype = int32_t;

public:
    __aicore__ PpMatmulW8a8Aiv() {};

    __aicore__ __force_inline__ void Init(GM_ADDR gmInput, GM_ADDR gmOutput, GM_ADDR gmDescale,
                                          GM_ADDR gmPerTensorBias, GM_ADDR gmPertokenDescale,
                                          const PpMatmulTilingData &gmTilingData) {
        gmInput_.SetGlobalBuffer(reinterpret_cast<__gm__ InDtype *>(gmInput));
        gmOutput_.SetGlobalBuffer(reinterpret_cast<__gm__ OutDtype *>(gmOutput));
        gmPerTensorScale_.SetGlobalBuffer(reinterpret_cast<__gm__ ScaleDtype *>(gmDescale));
        gmPerTensorBias_.SetGlobalBuffer(reinterpret_cast<__gm__ BiasDtype *>(gmPerTensorBias));
        gmPerTokenScale_.SetGlobalBuffer(reinterpret_cast<__gm__ ScaleDtype *>(gmPertokenDescale));

        batch_size = gmTilingData.numBatch;
        m = gmTilingData.m;
        k = gmTilingData.k;
        n = gmTilingData.n;
        m0 = gmTilingData.m0;
        k0 = gmTilingData.k0;
        n0 = gmTilingData.n0;
        m_loop = gmTilingData.mLoop;
        k_loop = gmTilingData.kLoop;
        n_loop = gmTilingData.nLoop;
        core_loop = gmTilingData.coreLoop;
        swizzle_cnt = gmTilingData.swizzleCount;
        swizzlDirect = gmTilingData.swizzleDirect;
        en_shuffle_k = gmTilingData.enShuffleK;

        AsdopsBuffer buf;
        ubInput_ = buf.GetBuffer(0);
        ubTempFp32_ = buf.GetBuffer(94 * 1024);
        ubOutput_ = buf.GetBuffer(0);
        ubPerTensorScale_ = buf.GetBuffer(188 * 1024);
        block_size = BLOCK_SIZE_32;
        core_num = AscendC::GetBlockNum();
        core_idx = AscendC::GetBlockIdx() / 2;
        ping_flag = 1;
    }

    __aicore__ __force_inline__ void GetBlockIdx(uint32_t index, uint32_t &m_idx, uint32_t &n_idx) {
        uint32_t in_batch_idx = index % (m_loop * n_loop);
        if (swizzlDirect == 0)
{ // Zn uint32_t tile_block_loop = (m_loop + swizzle_cnt - 1) / swizzle_cnt; uint32_t tile_block_idx = in_batch_idx / (swizzle_cnt * n_loop); uint32_t in_tile_block_idx = in_batch_idx % (swizzle_cnt * n_loop); uint32_t n_row = swizzle_cnt; if (tile_block_idx == tile_block_loop - 1) { n_row = m_loop - swizzle_cnt * tile_block_idx; } m_idx = tile_block_idx * swizzle_cnt + in_tile_block_idx % n_row; n_idx = in_tile_block_idx / n_row; if (tile_block_idx % 2 != 0) { n_idx = n_loop - n_idx - 1; } } else { // Nz uint32_t tile_block_loop = (n_loop + swizzle_cnt - 1) / swizzle_cnt; uint32_t tile_block_idx = in_batch_idx / (swizzle_cnt * m_loop); uint32_t in_tile_block_idx = in_batch_idx % (swizzle_cnt * m_loop); uint32_t n_col = swizzle_cnt; if (tile_block_idx == tile_block_loop - 1) { n_col = n_loop - swizzle_cnt * tile_block_idx; } m_idx = in_tile_block_idx / n_col; n_idx = tile_block_idx * swizzle_cnt + in_tile_block_idx % n_col; if (tile_block_idx % 2 != 0) { m_idx = m_loop - m_idx - 1; } } } __aicore__ __force_inline__ void Process(); private: AscendC::GlobalTensor gmPerTensorScale_; AscendC::GlobalTensor gmPerTensorBias_; AscendC::GlobalTensor gmPerTokenScale_; AscendC::GlobalTensor gmInput_; AscendC::GlobalTensor gmOutput_; AscendC::LocalTensor ubInput_; AscendC::LocalTensor ubTempFp32_; AscendC::LocalTensor ubOutput_; AscendC::LocalTensor ubPerTensorScale_; uint32_t core_num{0}; uint32_t batch_size{0}; uint32_t m{0}; uint32_t k{0}; uint32_t n{0}; uint32_t m0{0}; uint32_t k0{0}; uint32_t n0{0}; uint32_t m_loop{0}; uint32_t n_loop{0}; uint32_t k_loop{0}; uint32_t core_loop{0}; uint32_t core_idx{0}; uint32_t ping_flag{0}; uint32_t block_size{0}; uint32_t cube_matrix_size{0}; uint32_t swizzle_cnt{1}; uint32_t en_shuffle_k{0}; uint32_t swizzlDirect{0}; uint64_t L1_PINGPONG_BUFFER_LEN{0}; uint32_t L0AB_PINGPONG_BUFFER_LEN{0}; }; template __aicore__ __force_inline__ void PpMatmulW8a8Aiv::Process() { uint32_t m_idx = 0; uint32_t n_idx = 0; SET_FLAG(V, MTE2, EVENT_ID0); 
SET_FLAG(MTE3, V, EVENT_ID0); SET_FLAG(MTE3, MTE2, EVENT_ID0); for (uint32_t loop_idx = core_idx; loop_idx < core_loop; loop_idx += core_num) { GetBlockIdx(loop_idx, m_idx, n_idx); uint64_t batch_idx = loop_idx / n_loop / m_loop; uint64_t offsetC = batch_idx * m * n + m_idx * m0 * n + n_idx * n0; uint32_t m_actual = (m_idx == (m_loop - 1)) ? (m - m_idx * m0) : m0; uint32_t n_actual = (n_idx == (n_loop - 1)) ? (n - n_idx * n0) : n0; uint32_t m_round = RoundUp(m_actual); uint32_t n_round = RoundUp(n_actual); uint32_t n_round_16 = RoundUp(n_actual); uint32_t m_actual_per_vec = m_actual / AscendC::GetTaskRation(); uint32_t m_offset = m + m_idx * m0; if (GetSubBlockidx() != 0) { offsetC += m_actual_per_vec * n; m_offset += m_actual_per_vec; m_actual_per_vec = m_actual - m_actual_per_vec; } if constexpr (!withSyncAll) { if (m_actual_per_vec == 0) { WaitFlagDev(MMAIC); if ((loop_idx / core_num + 1) % MAX_HW_SYNC_COUNTER == 1) { FftsCrossCoreSync(MMAIV); } continue; } } uint64_t offsetScale = batch_idx * n + n_idx * n0; bool aligned_s32 = ((n & 0b111) == 0); // 32B aligned bool aligned_f16 = ((n & 0b1111) == 0); // 32B aligned WAIT_FLAG(V, MTE2, EVENT_ID0); if constexpr (quantMode == QuantMode::PER_TENSOR_ASYMM_QUANT) { if (aligned_s32) { gm_to_ub(ubPerTensorScale_.ReinterpretCast(), gmPerTensorBias_[offsetScale], 0, // sid 1, // nBurst n_round * sizeof(BiasDtype) / BLOCK_SIZE_32, // lenBurst 0, // srcStride 0); // dstStride } else { gm_to_ub_align(ubPerTensorScale_.ReinterpretCast(), gmPerTensorBias_[offsetScale], 0, // sid 1, // nBurst n_actual * sizeof(float), // lenBurst 0, // leftPaddingNum 0, // rightPaddingNum 0, // srcGap 0); // dstGap } } else { if (aligned_s32) { gm_to_ub(ubPerTensorScale_, gmPerTensorScale_[offsetScale], 0, // sid 1, // nBurst n_round * 4 / BLOCK_SIZE_32, // lenBurst 0, // srcStride 0); // dstStride } else { gm_to_ub_align(ubPerTensorScale_, gmPerTensorScale_[offsetScale], 0, // sid 1, // nBurst n_actual * sizeof(float), // lenBurst 0, // 
leftPaddingNum 0, // rightPaddingNum 0, // srcGap 0); // dstGap } } if constexpr (!withSyncAll) { WaitFlagDev(MMAIC); } WAIT_FLAG(MTE3, MTE2, EVENT_ID0); if (aligned_s32) { gm_to_ub(ubInput_, gmInput_[offsetC], 0, // sid m_actual_per_vec, // nBurst n_round / 8, // lenBurst (n - n_round) / 8, // srcStride 0 // dstStride ); } else { gm_to_ub_align(ubInput_, gmInput_[offsetC], 0, // sid m_actual_per_vec, // nBurst n_actual * sizeof(int32_t), // lenBurst 0, // leftPaddingNum 0, // rightPaddingNum (n - n_actual) * sizeof(int32_t), // srcGap 0 // dstGap ); } SET_FLAG(MTE2, V, EVENT_ID0); WAIT_FLAG(MTE2, V, EVENT_ID0); WAIT_FLAG(MTE3, V, EVENT_ID0); uint32_t nRepeatCnt = CeilDiv(n_actual); if constexpr (quantMode == QuantMode::PER_TENSOR_ASYMM_QUANT) { AscendC::SetMaskCount(); AscendC::SetVectorMask(n_round); for (uint32_t i = 0; i < m_actual_per_vec; ++i) { // add_v(ubInput_[i * n_round], // ubInput_[i * n_round], // ubPerTensorScale_.ReinterpretCast(), // (uint8_t)(nRepeatCnt), // repeat // (uint8_t)1, // dstBlockStride // (uint8_t)1, // src0BlockStride // (uint8_t)1, // src1BlockStride // (uint8_t)8, // dstRepeatStride // (uint8_t)8, // src0RepeatStride // (uint8_t)8 // src1RepeatStride // ); AscendC::Add(ubInput_[i * n_round], ubInput_[i * n_round], ubPerTensorScale_.ReinterpretCast(), AscendC::MASK_PLACEHOLDER, 1, AscendC::BinaryRepeatParams((uint8_t)1, (uint8_t)1, (uint8_t)1, (uint8_t)8, (uint8_t)8, (uint8_t)8)); } AscendC::ResetMask(); SetMasknorm(); SET_FLAG(V, MTE2, EVENT_ID0); WAIT_FLAG(V, MTE2, EVENT_ID0); if (aligned_s32) { gm_to_ub(ubPerTensorScale_, gmPerTensorScale_[offsetScale], 0, // sid 1, // nBurst n_round * sizeof(ScaleDtype) / BLOCK_SIZE_32, // lenBurst 0, // srcStride 0 // dstStride ); } else { gm_to_ub_align(ubPerTensorScale_, gmPerTensorScale_[offsetScale], 0, // sid 1, // nBurst n_actual * sizeof(ScaleDtype), // lenBurst 0, // leftPaddingNum 0, // rightPaddingNum 0, // srcGap 0 // dstGap ); } SET_FLAG(MTE2, V, EVENT_ID0); WAIT_FLAG(MTE2, V, 
EVENT_ID0); } // CASTF32 * f32 tf16 constexpr uint32_t maxRepeat = 255; constexpr uint32_t perRepeatNum = maxRepeat * 64; uint32_t loopCnt = (m_actual_per_vec * n_actual + perRepeatNum - 1) / perRepeatNum; for (uint32_t i = 0; i < loopCnt; i++) { conv_v(ubInput_.ReinterpretCast()[perRepeatNum * i], ubInput_[perRepeatNum * i], (uint8_t)maxRepeat, // repeat (uint16_t)1, // dstBlockStride (uint16_t)1, // srcBlockStride (uint16_t)8, // dstRepeatStride (uint16_t)8 // srcRepeatStride ); } AscendC::PipeBarrier(); for (uint32_t i = 0; i < m_actual_per_vec; ++i) { mul_v(ubTempFp32_[i * n_round], ubInput_.ReinterpretCast()[i * n_round], ubPerTensorScale_.ReinterpretCast(), (uint8_t)(nRepeatCnt), // repeat (uint8_t)1, // dstBlockStride (uint8_t)1, // src0BlockStride (uint8_t)1, // src1BlockStride (uint8_t)8, // dstRepeatStride (uint8_t)8, // src0RepeatStride (uint8_t)8 // src1RepeatStride ); if constexpr (quantMode == QuantMode::PER_TOKEN_SYMM_QUANT) { AscendC::PipeBarrier(); float perTokenDescale = gmPerTokenScale_.GetValue(m_offset + i); SET_FLAG(S, V, EVENT_ID0); WAIT_FLAG(S, V, EVENT_ID0); AscendC::Muls(ubTempFp32_[i * n_round], ubTempFp32_[i * n_round], perTokenDescale, n_round); } AscendC::PipeBarrier(); } SET_FLAG(V, MTE2, EVENT_ID0); AscendC::PipeBarrier(); if (n_actual % 16 > 8) { for (uint32_t i = 0; i < loopCnt; i++) { if constexpr (std::is_same_v) { convr_v(ubOutput_[perRepeatNum * i], ubTempFp32_[perRepeatNum * i], (uint8_t)maxRepeat, // repeat (uint16_t)1, // dstBlockStride (uint16_t)1, // srcBlockStride (uint16_t)4, // dstRepeatStride (uint16_t)8); // srcRepeatStride } else { conv_v(ubOutput_[perRepeatNum * i], ubTempFp32_[perRepeatNum * i], (uint8_t)maxRepeat, // repeat (uint16_t)1, // dstBlockStride (uint16_t)1, // srcBlockStride (uint16_t)4, // dstRepeatStride (uint16_t)8); // srcRepeatStride } } } else { for (uint32_t i = 0; i < m_actual_per_vec; i++) { if constexpr (std::is_same_v) { convr_v(ubOutput_[n_round_16 * i], ubTempFp32_[n_round * i], 
(uint8_t)nRepeatCnt, // repeat (uint16_t)1, // dstBlockStride (uint16_t)1, // srcBlockStride (uint16_t)4, // dstRepeatStride (uint16_t)8); // srcRepeatStride } else { conv_v(ubOutput_[n_round_16 * i], ubTempFp32_[n_round * i], (uint8_t)nRepeatCnt, // repeat (uint16_t)1, // dstBlockStride (uint16_t)1, // srcBlockStride (uint16_t)4, // dstRepeatStride (uint16_t)8); // srcRepeatStride } } } SET_FLAG(V, MTE3, EVENT_ID0); WAIT_FLAG(V, MTE3, EVENT_ID0); if (aligned_f16) { ub_to_gm(gmOutput_[offsetC], ubOutput_, 0, m_actual_per_vec, // nBurst n_round / 16, // lenBurst 0, // srcStride (n - n_round) / 16 // dstStride ); } else { ub_to_gm_align(gmOutput_[offsetC], ubOutput_, 0, m_actual_per_vec, // nBurst n_actual * sizeof(OutDtype), // lenBurst 0, // leftPaddingNum 0, // rightPaddingNum 0, // srcGap (n - n_actual) * sizeof(OutDtype) // dstGap ); } SET_FLAG(MTE3, V, EVENT_ID0); SET_FLAG(MTE3, MTE2, EVENT_ID0); if constexpr (!withSyncAll) { if ((loop_idx / core_num + 1) % MAX_HW_SYNC_COUNTER == 1) { FftsCrossCoreSync(MMAIV); } } } WAIT_FLAG(V, MTE2, EVENT_ID0); WAIT_FLAG(MTE3, V, EVENT_ID0); WAIT_FLAG(MTE3, MTE2, EVENT_ID0); } #endif template class MLAOperation { static constexpr bool mm1WithSyncAll = (quantMode == QuantMode::PER_TOKEN_SYMM_QUANT); static constexpr uint64_t splitGapC = CACHE_MODE == CACHE_MODE_KVCACHE ? 
CONST_64 : CONST_0; using Q_OUT_DTYPE = typename std::conditional_t; using K_NOPE_DTYPE = typename std::conditional_t; public: __aicore__ inline MLAOperation(const MlaTilingData &mlaParams_, GM_ADDR tilingGm) { blockIdx = AscendC::GetBlockIdx(); #ifdef __DAV_C220_VEC__ sub_block_idx = static_cast(GetSubBlockidx()); #endif vectorBlockIdx = (blockIdx / 2) * 2 + sub_block_idx; this->n = mlaParams_.n; this->num_core_ = mlaParams_.rmsNumCore1; this->num_col_1 = mlaParams_.rmsNumCol1; this->num_col_2 = mlaParams_.rmsNumCol2; this->num_row = mlaParams_.n; this->epsilon_ = 1e-6; this->mlaParams = mlaParams_; } __aicore__ inline void Init(GM_ADDR hiddenStateGm, GM_ADDR quantScale1Gm, GM_ADDR quantOffset1Gm, GM_ADDR wdqkvGm, GM_ADDR bias1Gm, GM_ADDR gamma2Gm, GM_ADDR beta2Gm, GM_ADDR quantScale2Gm, GM_ADDR quantOffset2Gm, GM_ADDR gamma3Gm, GM_ADDR sin1Gm, GM_ADDR cos1Gm, GM_ADDR sin2Gm, GM_ADDR cos2Gm, GM_ADDR keycacheGm, GM_ADDR slotMappingGm, GM_ADDR wuqGm, GM_ADDR bias2Gm, GM_ADDR wukGm, GM_ADDR descale1Gm, GM_ADDR descale2Gm, GM_ADDR gmCtkvScale, GM_ADDR gmQnopeScale, GM_ADDR qGm, GM_ADDR keycacheOutGm, GM_ADDR qGm2, GM_ADDR keycacheOutGm2, GM_ADDR s1Gm, GM_ADDR s2Gm, GM_ADDR s3Gm, GM_ADDR s4Gm, GM_ADDR s5Gm) { quantScale3GmTensor.SetGlobalBuffer(reinterpret_cast<__gm__ InDtype *>(gmCtkvScale)); gamma3GmTensor.SetGlobalBuffer(reinterpret_cast<__gm__ InDtype *>(gamma3Gm)); sin1GmTensor.SetGlobalBuffer(reinterpret_cast<__gm__ InDtype *>(sin1Gm)); cos1GmTensor.SetGlobalBuffer(reinterpret_cast<__gm__ InDtype *>(cos1Gm)); keycacheGmTensor1.SetGlobalBuffer(reinterpret_cast<__gm__ K_NOPE_DTYPE *>(keycacheOutGm)); keycacheGmTensor2.SetGlobalBuffer(reinterpret_cast<__gm__ InDtype *>(keycacheOutGm2)); slotMappingGmTensor.SetGlobalBuffer(reinterpret_cast<__gm__ int32_t *>(slotMappingGm)); descale1gmTensor.SetGlobalBuffer(reinterpret_cast<__gm__ float *>(descale1Gm)); s2GmTensor.SetGlobalBuffer(reinterpret_cast<__gm__ int32_t *>(s2Gm)); 
s3GmTensor.SetGlobalBuffer(reinterpret_cast<__gm__ InDtype *>(s3Gm)); s5GmTensor.SetGlobalBuffer(reinterpret_cast<__gm__ float *>(s5Gm)); #ifdef __DAV_C220_CUBE__ mm_w8a8_aic_1.Init(s1Gm, wdqkvGm, s2Gm, mlaParams.mm1, 0); mm_w8a8_aic_1.PreloadWeight(); if constexpr (quantMode == QuantMode::PER_TENSOR_ASYMM_QUANT) { mm_w8a8_aic_2.Init(s1Gm, wuqGm, s2Gm, mlaParams.mm2, 1); } else { // quantMode == QuantMode::PER_TOKEN_SYMM_QUANT mm_w8a8_aic_2.Init(s1Gm, wuqGm, s3Gm, mlaParams.mm2, 1); } if constexpr (CACHE_MODE == CACHE_MODE_INT8_NZCACHE) { mm_ein_sum.Init(s4Gm, wukGm, s1Gm, mlaParams); } else { mm_ein_sum.Init(s4Gm, wukGm, qGm, mlaParams); } #endif hiddenStateGmTensor.SetGlobalBuffer(reinterpret_cast<__gm__ InDtype *>(hiddenStateGm)); quantScale1GmTensor.SetGlobalBuffer(reinterpret_cast<__gm__ InDtype *>(quantScale1Gm)); quantOffset1GmTensor.SetGlobalBuffer(reinterpret_cast<__gm__ int8_t *>(quantOffset1Gm)); wdqkvGmTensor.SetGlobalBuffer(reinterpret_cast<__gm__ int8_t *>(wdqkvGm)); gamma2GmTensor.SetGlobalBuffer(reinterpret_cast<__gm__ InDtype *>(gamma2Gm)); quantScale2GmTensor.SetGlobalBuffer(reinterpret_cast<__gm__ InDtype *>(quantScale2Gm)); quantOffset2GmTensor.SetGlobalBuffer(reinterpret_cast<__gm__ int8_t *>(quantOffset2Gm)); sin2GmTensor.SetGlobalBuffer(reinterpret_cast<__gm__ InDtype *>(sin2Gm)); cos2GmTensor.SetGlobalBuffer(reinterpret_cast<__gm__ InDtype *>(cos2Gm)); wuqGmTensor.SetGlobalBuffer(reinterpret_cast<__gm__ int8_t *>(wuqGm)); wukGmTensor.SetGlobalBuffer(reinterpret_cast<__gm__ InDtype *>(wukGm)); descale2gmTensor.SetGlobalBuffer(reinterpret_cast<__gm__ float *>(descale2Gm)); s1GmTensor.SetGlobalBuffer(reinterpret_cast<__gm__ int8_t *>(s1Gm)); s4GmTensor.SetGlobalBuffer(reinterpret_cast<__gm__ int32_t *>(s4Gm)); qGmTensor.SetGlobalBuffer(reinterpret_cast<__gm__ Q_OUT_DTYPE *>(qGm)); qGmTensor2.SetGlobalBuffer(reinterpret_cast<__gm__ InDtype *>(qGm2)); bias1gmTensor.SetGlobalBuffer(reinterpret_cast<__gm__ int32_t *>(bias1Gm)); 
bias2gmTensor.SetGlobalBuffer(reinterpret_cast<__gm__ int32_t *>(bias2Gm)); beta2GmTensor.SetGlobalBuffer(reinterpret_cast<__gm__ InDtype *>(beta2Gm)); #ifdef __DAV_C220_VEC__ if constexpr (quantMode == QuantMode::PER_TENSOR_ASYMM_QUANT) { mm_w8a8_aiv_1.Init(s2Gm, s3Gm, descale1Gm, bias1Gm, s5Gm, mlaParams.mm1); mm_w8a8_aiv_2.Init(s2Gm, s4Gm, descale2Gm, bias2Gm, s5Gm, mlaParams.mm2); } else { // quantMode == QuantMode::PER_TOKEN_SYMM_QUANT mm_w8a8_aiv_2.Init(s3Gm, s4Gm, descale2Gm, bias2Gm, s5Gm, mlaParams.mm2); } row_work = (num_row + num_core_ - 1) / num_core_; row_work_ = 0; uint32_t need_core = (num_row + row_work - 1) / row_work; if (vectorBlockIdx < need_core - 1) { row_work_ = row_work; } else if (vectorBlockIdx == need_core - 1) { row_work_ = num_row - (need_core - 1) * row_work; } else { row_work_ = 0; } this->splitN = mlaParams.perTaskNum; Quant1.Init(quantScale1GmTensor, quantOffset1GmTensor, s5Gm + row_work * vectorBlockIdx * sizeof(float), descale1Gm, hiddenStateGm, s1Gm, 0, num_col_1, vectorBlockIdx * static_cast(row_work) * num_col_1, vectorBlockIdx * static_cast(row_work) * num_col_1, row_work_, mlaParams); if constexpr (quantMode == QuantMode::PER_TENSOR_ASYMM_QUANT) { rmsNormQuant2.Init(gamma2GmTensor, beta2GmTensor, quantScale2GmTensor, quantOffset2GmTensor, s5Gm + row_work * vectorBlockIdx * sizeof(float), descale1Gm, s3Gm, s1Gm, SPLIT_SIZE_ONE, num_col_2, 0.000651041666, vectorBlockIdx * static_cast(row_work) * num_col_2, vectorBlockIdx * static_cast(row_work) * SPLIT_SIZE_TWO, row_work_, mlaParams); } else { // quantMode == QuantMode::PER_TOKEN_SYMM_QUANT rmsNormQuant2.Init(gamma2GmTensor, beta2GmTensor, quantScale2GmTensor, quantOffset2GmTensor, s5Gm + row_work * vectorBlockIdx * sizeof(float), descale1Gm, s2Gm, s1Gm, SPLIT_SIZE_ONE, num_col_2, 0.000651041666, vectorBlockIdx * static_cast(row_work) * num_col_2, vectorBlockIdx * static_cast(row_work) * SPLIT_SIZE_TWO, row_work_, mlaParams); } ropeFp16.RopeInit(s4Gm, cos2GmTensor, 
sin2GmTensor, qGmTensor, qGmTensor2, mlaParams); einSumQuant.Init(s1Gm, gmQnopeScale, qGm, mlaParams); #endif } __aicore__ inline void ProcessCube(); __aicore__ inline void ProcessVector(); private: constexpr static uint32_t C0_SIZE = 16; constexpr static uint32_t I8_C0_SIZE = 32; template __aicore__ inline void RmsNormAndRopeConvergence1( const AscendC::LocalTensor &srcTensor, const AscendC::LocalTensor &gammaTensor, const AscendC::LocalTensor &sinTensor, const AscendC::LocalTensor &cosTensor, const AscendC::LocalTensor &slotMappingTensor, const uint32_t sN, const AscendC::LocalTensor &rmsNormTensor, const AscendC::LocalTensor &gammaFp32, const AscendC::LocalTensor &ropeKTensor, const AscendC::LocalTensor &ropeKRevertTensor, const AscendC::LocalTensor &calTensor, const AscendC::LocalTensor &outTmpTensor, AscendC::LocalTensor &tmpfp16, AscendC::LocalTensor &int8OutTensor, float quantScale3) { int64_t slotMapGmOffset = vectorBlockIdx * row_work; AscendC::DataCopy(gammaTensor, gamma3GmTensor, SPLIT_RMSNRORM_SIZE_ONE); SET_FLAG(MTE2, V, EVENT_ID1); WAIT_FLAG(MTE2, V, EVENT_ID1); Cast(gammaFp32, gammaTensor, AscendC::RoundMode::CAST_NONE, SPLIT_RMSNRORM_SIZE_ONE); AscendC::DataCopyPad(slotMappingTensor, slotMappingGmTensor[slotMapGmOffset], AscendC::DataCopyExtParams(1, sN * sizeof(int32_t), 0, 0, 0), AscendC::DataCopyPadExtParams(false, 0, 8 - sN % 8, 0)); if constexpr (quantMode == QuantMode::PER_TOKEN_SYMM_QUANT) { mmTensor = calTensor.ReinterpretCast()[SPLIT_SIZE_ONE]; deScaleTensor = calTensor.ReinterpretCast()[SPLIT_SIZE_ONE * 2]; AscendC::DataCopy(deScaleTensor, descale1gmTensor, AscendC::DataCopyParams(1, SPLIT_SIZE_ONE / 8, 0, 0)); } SET_FLAG(MTE2, V, EVENT_ID2); WAIT_FLAG(MTE2, V, EVENT_ID2); SET_FLAG(MTE2, S, EVENT_ID2); WAIT_FLAG(MTE2, S, EVENT_ID2); for (uint64_t loop = 0; loop < sN; ++loop) { uint64_t offset = vectorBlockIdx * static_cast(row_work) * num_col_2 + loop * MM1_OUT_SIZE; int64_t slotValue = static_cast(slotMappingTensor.GetValue(loop)); if 
(slotValue == -1) { continue; } if constexpr (quantMode == QuantMode::PER_TENSOR_ASYMM_QUANT) { AscendC::DataCopy(srcTensor, s3GmTensor[offset], AscendC::DataCopyParams(1, MM1_OUT_SIZE / BLOCK_SIZE_16, 0, 0)); } else { // quantMode == QuantMode::PER_TOKEN_SYMM_QUANT AscendC::DataCopy(mmTensor, s2GmTensor[offset], AscendC::DataCopyParams(1, SPLIT_SIZE_ONE / 8, 0, 0)); } AscendC::DataCopy(sinTensor, sin1GmTensor[(row_work * vectorBlockIdx + loop) * SPLIT_RMSNRORM_SIZE_TWO], SPLIT_RMSNRORM_SIZE_TWO); AscendC::DataCopy(cosTensor, cos1GmTensor[(row_work * vectorBlockIdx + loop) * SPLIT_RMSNRORM_SIZE_TWO], SPLIT_RMSNRORM_SIZE_TWO); SET_FLAG(MTE2, V, EVENT_ID0); // ND uint64_t cacheStart = static_cast(slotValue) * static_cast(SPLIT_SIZE_ONE); uint64_t cacheStart1 = static_cast(slotValue) * static_cast(SPLIT_RMSNRORM_SIZE_ONE); uint64_t cacheStart2 = static_cast(slotValue) * static_cast(SPLIT_RMSNRORM_SIZE_TWO); // NZ uint32_t outer_idx = slotValue / 128; uint32_t inner_idx = slotValue % 128; SET_FLAG(S, MTE3, EVENT_ID0); /* RmsNorm start */ WAIT_FLAG(MTE2, V, EVENT_ID0); if constexpr (quantMode == QuantMode::PER_TOKEN_SYMM_QUANT) { /* DeQuant */ AscendC::Cast(mmTensor.ReinterpretCast(), mmTensor, AscendC::RoundMode::CAST_NONE, SPLIT_SIZE_ONE); AscendC::PipeBarrier(); AscendC::Mul(mmTensor.ReinterpretCast(), mmTensor.ReinterpretCast(), deScaleTensor, SPLIT_SIZE_ONE); AscendC::PipeBarrier(); float perTokenDescale = s5GmTensor.GetValue(row_work * vectorBlockIdx + loop); SET_FLAG(S, V, EVENT_ID0); WAIT_FLAG(S, V, EVENT_ID0); AscendC::Muls(mmTensor.ReinterpretCast(), mmTensor.ReinterpretCast(), perTokenDescale, SPLIT_SIZE_ONE); AscendC::PipeBarrier(); AscendC::Cast(srcTensor, mmTensor.ReinterpretCast(), AscendC::RoundMode::CAST_RINT, SPLIT_SIZE_ONE); AscendC::PipeBarrier(); } Cast(rmsNormTensor, srcTensor, AscendC::RoundMode::CAST_NONE, SPLIT_RMSNRORM_SIZE_ONE); AscendC::PipeBarrier(); Mul(calTensor, rmsNormTensor, rmsNormTensor, SPLIT_RMSNRORM_SIZE_ONE); 
AscendC::PipeBarrier(); ReduceSumCustom(calTensor[SPLIT_RMSNRORM_SIZE_ONE], calTensor, calTensor[SPLIT_RMSNRORM_SIZE_ONE * 2], SPLIT_RMSNRORM_SIZE_ONE); SET_FLAG(V, S, EVENT_ID1); WAIT_FLAG(V, S, EVENT_ID1); float rms = sqrt(calTensor.GetValue(SPLIT_RMSNRORM_SIZE_ONE) / SPLIT_RMSNRORM_SIZE_ONE + epsilon_); SET_FLAG(S, V, EVENT_ID1); WAIT_FLAG(S, V, EVENT_ID1); AscendC::PipeBarrier(); Duplicate(calTensor, rms, SPLIT_RMSNRORM_SIZE_ONE); AscendC::PipeBarrier(); Div(calTensor, rmsNormTensor, calTensor, SPLIT_RMSNRORM_SIZE_ONE); AscendC::PipeBarrier(); Mul(rmsNormTensor, gammaFp32, calTensor, SPLIT_RMSNRORM_SIZE_ONE); AscendC::PipeBarrier(); if constexpr (CACHE_MODE == CACHE_MODE_INT8_NZCACHE) { // quant Muls(rmsNormTensor, rmsNormTensor, quantScale3, SPLIT_RMSNRORM_SIZE_ONE); AscendC::PipeBarrier(); CastFrom32To16(tmpfp16, rmsNormTensor, SPLIT_RMSNRORM_SIZE_ONE); AscendC::PipeBarrier(); CastFromF16ToI8(int8OutTensor, tmpfp16, -128, SPLIT_RMSNRORM_SIZE_ONE); AscendC::PipeBarrier(); } else { AscendC::PipeBarrier(); if (std::is_same::value) { Cast(outTmpTensor, rmsNormTensor, AscendC::RoundMode::CAST_RINT, SPLIT_RMSNRORM_SIZE_ONE); } else { Cast(outTmpTensor, rmsNormTensor, AscendC::RoundMode::CAST_NONE, SPLIT_RMSNRORM_SIZE_ONE); } } /* RmsNorm end */ /* Rope K start */ uint64_t revertOffset = SPLIT_RMSNRORM_SIZE_TWO / 2; Cast(ropeKTensor, srcTensor[SPLIT_RMSNRORM_SIZE_ONE], AscendC::RoundMode::CAST_NONE, SPLIT_RMSNRORM_SIZE_TWO); Cast(ropeKRevertTensor[revertOffset], srcTensor[SPLIT_RMSNRORM_SIZE_ONE], AscendC::RoundMode::CAST_NONE, revertOffset); Cast(ropeKRevertTensor, srcTensor[SPLIT_RMSNRORM_SIZE_ONE + revertOffset], AscendC::RoundMode::CAST_NONE, revertOffset); Duplicate(calTensor, static_cast(-1), revertOffset); Duplicate(calTensor[revertOffset], static_cast(1), revertOffset); AscendC::PipeBarrier(); Cast(calTensor[SPLIT_RMSNRORM_SIZE_TWO], cosTensor, AscendC::RoundMode::CAST_NONE, SPLIT_RMSNRORM_SIZE_TWO); Cast(calTensor[SPLIT_RMSNRORM_SIZE_TWO * 2], sinTensor, 
AscendC::RoundMode::CAST_NONE, SPLIT_RMSNRORM_SIZE_TWO); AscendC::PipeBarrier(); Mul(ropeKTensor, calTensor[SPLIT_RMSNRORM_SIZE_TWO], ropeKTensor, SPLIT_RMSNRORM_SIZE_TWO); Mul(ropeKRevertTensor, calTensor[SPLIT_RMSNRORM_SIZE_TWO * 2], ropeKRevertTensor, SPLIT_RMSNRORM_SIZE_TWO); AscendC::PipeBarrier(); Mul(ropeKRevertTensor, calTensor, ropeKRevertTensor, SPLIT_RMSNRORM_SIZE_TWO); AscendC::PipeBarrier(); Add(ropeKRevertTensor, ropeKTensor, ropeKRevertTensor, SPLIT_RMSNRORM_SIZE_TWO); AscendC::PipeBarrier(); if (std::is_same::value) { Cast(outTmpTensor[SPLIT_RMSNRORM_SIZE_ONE], ropeKRevertTensor, AscendC::RoundMode::CAST_RINT, SPLIT_RMSNRORM_SIZE_TWO); } else { Cast(outTmpTensor[SPLIT_RMSNRORM_SIZE_ONE], ropeKRevertTensor, AscendC::RoundMode::CAST_NONE, SPLIT_RMSNRORM_SIZE_TWO); } AscendC::PipeBarrier(); /* Rope K end */ SET_FLAG(V, MTE3, EVENT_ID0); WAIT_FLAG(V, MTE3, EVENT_ID0); WAIT_FLAG(S, MTE3, EVENT_ID0); if constexpr (CACHE_MODE == CACHE_MODE_KVCACHE) { DataCopy(keycacheGmTensor1[cacheStart], outTmpTensor, SPLIT_SIZE_ONE); } else if constexpr (CACHE_MODE == CACHE_MODE_INT8_NZCACHE) { uint64_t cacheSatartI8Nz1 = outer_idx * 128 * 512 + inner_idx * I8_C0_SIZE; uint64_t cacheSatartNz2 = outer_idx * 128 * 64 + inner_idx * C0_SIZE; // nope:int8 nz AscendC::DataCopyExtParams outExt; outExt.blockCount = SPLIT_RMSNRORM_SIZE_ONE / I8_C0_SIZE; outExt.blockLen = I8_C0_SIZE * sizeof(int8_t); outExt.srcStride = 0; outExt.dstStride = (128 * I8_C0_SIZE - I8_C0_SIZE) * sizeof(int8_t); DataCopyPad(keycacheGmTensor1[cacheSatartI8Nz1], int8OutTensor, outExt); // rope:T1 nz outExt.blockCount = SPLIT_RMSNRORM_SIZE_TWO / C0_SIZE; outExt.blockLen = C0_SIZE * sizeof(T1); outExt.srcStride = 0; outExt.dstStride = (128 * C0_SIZE - C0_SIZE) * sizeof(T1); DataCopyPad(keycacheGmTensor2[cacheSatartNz2], outTmpTensor[SPLIT_RMSNRORM_SIZE_ONE], outExt); } else if constexpr (CACHE_MODE == CACHE_MODE_NZCACHE) { uint64_t cacheSatartNz1 = outer_idx * 128 * 512 + inner_idx * C0_SIZE; uint64_t 
cacheSatartNz2 = outer_idx * 128 * 64 + inner_idx * C0_SIZE; // nope:T1 nz AscendC::DataCopyExtParams outExt; outExt.blockCount = SPLIT_RMSNRORM_SIZE_ONE / C0_SIZE; outExt.blockLen = C0_SIZE * sizeof(T1); outExt.srcStride = 0; outExt.dstStride = (128 * C0_SIZE - C0_SIZE) * sizeof(T1); DataCopyPad(keycacheGmTensor1[cacheSatartNz1], outTmpTensor, outExt); // rope:T1 nz outExt.blockCount = SPLIT_RMSNRORM_SIZE_TWO / C0_SIZE; outExt.blockLen = C0_SIZE * sizeof(T1); outExt.srcStride = 0; outExt.dstStride = (128 * C0_SIZE - C0_SIZE) * sizeof(T1); DataCopyPad(keycacheGmTensor2[cacheSatartNz2], outTmpTensor[SPLIT_RMSNRORM_SIZE_ONE], outExt); } else { // keycache1 DataCopy(keycacheGmTensor1[cacheStart1], outTmpTensor, SPLIT_RMSNRORM_SIZE_ONE); // keycache2 DataCopy(keycacheGmTensor2[cacheStart2], outTmpTensor[SPLIT_RMSNRORM_SIZE_ONE], SPLIT_RMSNRORM_SIZE_TWO); } SET_FLAG(MTE3, MTE2, EVENT_ID1); WAIT_FLAG(MTE3, MTE2, EVENT_ID1); } } private: uint32_t n; uint32_t splitN; uint32_t rotaryCoeff; uint32_t blockIdx; uint32_t sub_block_idx; uint32_t vectorBlockIdx; uint32_t blockOffset; uint32_t perTaskNum; uint32_t resTaskNum; MlaTilingData mlaParams; uint32_t num_core_; uint32_t num_col_1; uint32_t num_col_2; float epsilon_; uint32_t num_row; uint32_t quantMin_; uint32_t row_work; uint32_t row_work_; AsdopsBuffer buf; AscendC::LocalTensor mmTensor; AscendC::LocalTensor deScaleTensor; AscendC::GlobalTensor hiddenStateGmTensor; AscendC::GlobalTensor quantScale1GmTensor; AscendC::GlobalTensor quantOffset1GmTensor; AscendC::GlobalTensor wdqkvGmTensor; AscendC::GlobalTensor gamma2GmTensor; AscendC::GlobalTensor quantScale2GmTensor; AscendC::GlobalTensor quantScale3GmTensor; AscendC::GlobalTensor quantOffset2GmTensor; AscendC::GlobalTensor gamma3GmTensor; AscendC::GlobalTensor sin1GmTensor; AscendC::GlobalTensor cos1GmTensor; AscendC::GlobalTensor sin2GmTensor; AscendC::GlobalTensor cos2GmTensor; AscendC::GlobalTensor keycacheGmTensor1; AscendC::GlobalTensor keycacheGmTensor2; 
AscendC::GlobalTensor slotMappingGmTensor; AscendC::GlobalTensor wuqGmTensor; AscendC::GlobalTensor wukGmTensor; // cachemode2-->int8; else bf16 AscendC::GlobalTensor qGmTensor; AscendC::GlobalTensor qGmTensor2; AscendC::GlobalTensor s1GmTensor; AscendC::GlobalTensor s2GmTensor; AscendC::GlobalTensor s3GmTensor; AscendC::GlobalTensor s4GmTensor; AscendC::GlobalTensor s5GmTensor; AscendC::GlobalTensor descale1gmTensor; AscendC::GlobalTensor descale2gmTensor; AscendC::GlobalTensor beta2GmTensor; AscendC::GlobalTensor bias1gmTensor; AscendC::GlobalTensor bias2gmTensor; #ifdef __DAV_C220_CUBE__ PpMatmulW8a8Aic mm_w8a8_aic_1; PpMatmulW8a8Aic mm_w8a8_aic_2; PpMatmulEinSum mm_ein_sum; #endif #ifdef __DAV_C220_VEC__ PpMatmulW8a8Aiv mm_w8a8_aiv_1; PpMatmulW8a8Aiv mm_w8a8_aiv_2; Quant Quant1; RmsNormQuant rmsNormQuant2; RopeFp16 ropeFp16; EinSumQuant einSumQuant; #endif }; template __aicore__ inline void MLAOperation::ProcessCube() { #ifdef __DAV_C220_CUBE__ mm_w8a8_aic_1.Process(); if constexpr (quantMode == QuantMode::PER_TOKEN_SYMM_QUANT) { FftsCrossCoreSync(MMAIC); WaitFlagDev(MMAIC); FftsCrossCoreSync(MMAIV); } mm_w8a8_aic_2.PreloadWeight(); mm_w8a8_aic_2.Process(); mm_ein_sum.Process(); if constexpr (CACHE_MODE == CACHE_MODE_INT8_NZCACHE) { FftsCrossCoreSync(EINSUMOUT); WaitFlagDev(EINSUMOUT); FftsCrossCoreSync(EINSUMQUANT); } #endif } template __aicore__ inline void MLAOperation::ProcessVector() { #ifdef __DAV_C220_VEC__ if (row_work_ != 0) { uint32_t num_col_align_int8 = (num_col_1 + REPEAT_TIME_256 - 1) / REPEAT_TIME_256 * REPEAT_TIME_256; uint32_t num_col_align_f16 = (num_col_1 + REPEAT_TIME_128 - 1) / REPEAT_TIME_128 * REPEAT_TIME_128; uint32_t num_col_align_f32 = (num_col_1 + REPEAT_TIME_64 - 1) / REPEAT_TIME_64 * REPEAT_TIME_64; AscendC::LocalTensor input_tensor = buf.GetBuffer(0); AscendC::LocalTensor scale_tensor = buf.GetBuffer(HIDDTEN_STATE * 2 + HIDDTEN_STATE * 2 + HIDDTEN_STATE * 2); AscendC::LocalTensor offset_tensor = buf.GetBuffer( HIDDTEN_STATE * 2 + 
HIDDTEN_STATE * 2 + HIDDTEN_STATE * 2 + 32); AscendC::LocalTensor res1_tensor = buf.GetBuffer(HIDDTEN_STATE * 2 + HIDDTEN_STATE * 2 + HIDDTEN_STATE * 2 + 64); AscendC::LocalTensor res3_tensor = buf.GetBuffer( HIDDTEN_STATE * 2 + HIDDTEN_STATE * 2 + HIDDTEN_STATE * 2 + 64 + num_col_align_f32 * 4); AscendC::LocalTensor output_tensor = buf.GetBuffer( HIDDTEN_STATE * 2 + HIDDTEN_STATE * 2 + HIDDTEN_STATE * 2 + 64 + num_col_align_f32 * 4 + BUF_FACTOR * num_col_align_f32 * 4 + 64); Quant1.Launch(output_tensor, input_tensor, scale_tensor, offset_tensor, res1_tensor, res3_tensor); } FftsCrossCoreSync(QUANT1); WaitFlagDev(QUANT1); AscendC::CrossCoreSetFlag<0x2, PIPE_MTE3>(AIC_MM1_START); if constexpr (quantMode == QuantMode::PER_TENSOR_ASYMM_QUANT) { mm_w8a8_aiv_1.Process(); FftsCrossCoreSync(RMSNORMQUANT2); WaitFlagDev(RMSNORMQUANT2); } else { // quantMode == QuantMode::PER_TOKEN_SYMM_QUANT WaitFlagDev(MMAIV); } if (row_work_ != 0) { uint32_t num_col_align_int8 = (num_col_2 + REPEAT_TIME_256 - 1) / REPEAT_TIME_256 * REPEAT_TIME_256; uint32_t num_col_align_f16 = (num_col_2 + REPEAT_TIME_128 - 1) / REPEAT_TIME_128 * REPEAT_TIME_128; uint32_t num_col_align_f32 = (num_col_2 + REPEAT_TIME_64 - 1) / REPEAT_TIME_64 * REPEAT_TIME_64; AscendC::LocalTensor input_tensor = buf.GetBuffer(0); AscendC::LocalTensor gamma_tensor = buf.GetBuffer(MM1_OUT_SIZE * 2); AscendC::LocalTensor beta_tensor = buf.GetBuffer(MM1_OUT_SIZE * 2 + SPLIT_SIZE_TWO * 2); AscendC::LocalTensor scale_tensor = buf.GetBuffer(MM1_OUT_SIZE * 2 + SPLIT_SIZE_TWO * 2 + SPLIT_SIZE_TWO * 2); AscendC::LocalTensor offset_tensor = buf.GetBuffer( MM1_OUT_SIZE * 2 + SPLIT_SIZE_TWO * 2 + SPLIT_SIZE_TWO * 2 + 32); AscendC::LocalTensor res1_tensor = buf.GetBuffer( MM1_OUT_SIZE * 2 + SPLIT_SIZE_TWO * 2 + SPLIT_SIZE_TWO * 2 + 64); AscendC::LocalTensor res3_tensor = buf.GetBuffer( MM1_OUT_SIZE * 2 + SPLIT_SIZE_TWO * 2 + SPLIT_SIZE_TWO * 2 + 64 + num_col_align_f32 * 4); AscendC::LocalTensor output_tensor = buf.GetBuffer( MM1_OUT_SIZE 
* 2 + SPLIT_SIZE_TWO * 2 + SPLIT_SIZE_TWO * 2 + 64 + num_col_align_f32 * 4 + BUF_FACTOR * num_col_align_f32 * 4 + 64 + MM1_OUT_SIZE * 4 * 2 + 32); rmsNormQuant2.Launch(output_tensor, input_tensor, gamma_tensor, beta_tensor, scale_tensor, offset_tensor, res1_tensor, res3_tensor); } FftsCrossCoreSync(MM2); WaitFlagDev(MM2); AscendC::CrossCoreSetFlag<0x2, PIPE_MTE3>(AIC_MM2_START); if (row_work_ != 0) { AscendC::LocalTensor input_tensor = buf.GetBuffer(0); AscendC::LocalTensor gamma_tensor = buf.GetBuffer(MM1_OUT_SIZE * 2); AscendC::LocalTensor sin_tensor = buf.GetBuffer(MM1_OUT_SIZE * 2 + SPLIT_RMSNRORM_SIZE_ONE * 2); AscendC::LocalTensor cos_tensor = buf.GetBuffer( MM1_OUT_SIZE * 2 + SPLIT_RMSNRORM_SIZE_ONE * 2 + SPLIT_RMSNRORM_SIZE_TWO * 2); AscendC::LocalTensor slotMapping_tensor = buf.GetBuffer( MM1_OUT_SIZE * 2 + SPLIT_RMSNRORM_SIZE_ONE * 2 + SPLIT_RMSNRORM_SIZE_TWO * 4); int32_t rms3_ub_offset = MM1_OUT_SIZE * 2 + SPLIT_RMSNRORM_SIZE_ONE * 2 + SPLIT_RMSNRORM_SIZE_TWO * 4 + 4096 * 32; AscendC::LocalTensor tmp32_tensor = buf.GetBuffer(rms3_ub_offset); int32_t out_ub_offset = MM1_OUT_SIZE * 2 + SPLIT_RMSNRORM_SIZE_ONE * 2 + SPLIT_RMSNRORM_SIZE_TWO * 4 + 4096 * 32 + SPLIT_RMSNRORM_SIZE_ONE * 3 * 4 + SPLIT_RMSNRORM_SIZE_TWO * 2 * 4 + MM1_OUT_SIZE * 4 * 2 + 32; AscendC::LocalTensor temp_tensor = buf.GetBuffer(out_ub_offset); AscendC::LocalTensor tmpfp16; AscendC::LocalTensor int8OutTensor; float scale3 = 0; if constexpr (CACHE_MODE == CACHE_MODE_INT8_NZCACHE) { // quantScale3 AscendC::LocalTensor quantScaleTensor = buf.GetBuffer(rms3_ub_offset); AscendC::LocalTensor floatQuantScaleTensor = buf.GetBuffer(rms3_ub_offset + 32); // int8out tmpfp16 = buf.GetBuffer(rms3_ub_offset + SPLIT_RMSNRORM_SIZE_ONE * sizeof(float) * 2); int8OutTensor = buf.GetBuffer(out_ub_offset); AscendC::DataCopy(quantScaleTensor, quantScale3GmTensor, AscendC::DataCopyParams(1, 1, 0, 0)); SET_FLAG(MTE2, V, EVENT_ID1); WAIT_FLAG(MTE2, V, EVENT_ID1); Cast(floatQuantScaleTensor, quantScaleTensor, 
AscendC::RoundMode::CAST_NONE, 1); AscendC::SetFlag(EVENT_ID1); AscendC::WaitFlag(EVENT_ID1); scale3 = 1 / (float)(floatQuantScaleTensor.GetValue(0)); } RmsNormAndRopeConvergence1( input_tensor, // n * 576 gamma_tensor, // gamma sin_tensor, // sin cos_tensor, // cons slotMapping_tensor, // slotMapping row_work_, tmp32_tensor, tmp32_tensor[SPLIT_RMSNRORM_SIZE_ONE], tmp32_tensor[SPLIT_RMSNRORM_SIZE_ONE + SPLIT_RMSNRORM_SIZE_ONE], tmp32_tensor[SPLIT_RMSNRORM_SIZE_ONE + SPLIT_RMSNRORM_SIZE_ONE + SPLIT_RMSNRORM_SIZE_TWO], tmp32_tensor[SPLIT_RMSNRORM_SIZE_ONE + SPLIT_RMSNRORM_SIZE_ONE + SPLIT_RMSNRORM_SIZE_TWO + SPLIT_RMSNRORM_SIZE_TWO], temp_tensor, tmpfp16, int8OutTensor, scale3); } mm_w8a8_aiv_2.Process(); FftsCrossCoreSync(MM2OUT); WaitFlagDev(MM2OUT); AscendC::CrossCoreSetFlag<0x2, PIPE_MTE3>(AIC_MM3_START); ropeFp16.Process(); if constexpr (CACHE_MODE == CACHE_MODE_INT8_NZCACHE) { WaitFlagDev(EINSUMQUANT); einSumQuant.Process(); } #endif } } // namespace MLAPO_BF16