diff --git a/csrc/mla_preprocess/op_host/mla_preprocess.h b/csrc/mla_preprocess/op_host/mla_preprocess.h
index 66ad8c33..f554869b 100644
--- a/csrc/mla_preprocess/op_host/mla_preprocess.h
+++ b/csrc/mla_preprocess/op_host/mla_preprocess.h
@@ -127,6 +127,7 @@ struct OpParam {
     int32_t cacheMode;
     QuantMode quantMode;
     caffe2::TypeMeta inDtype;
+    bool enableInnerOut;
 };
 
 class PpMatmulTilingApi
@@ -540,7 +541,8 @@ void MlaPreprocessTiling::SetMlapoWorkSpace()
 
 void MlaPreprocessTiling::SetTilingKey()
 {
-    uint64_t tilingKey = (static_cast<uint64_t>(opParam.inDtype == at::kBFloat16)) << 8;
+    uint64_t tilingKey = (static_cast<uint64_t>(opParam.enableInnerOut)) << 9;
+    tilingKey |= (static_cast<uint64_t>(opParam.inDtype == at::kBFloat16)) << 8;
     tilingKey |= static_cast<uint64_t>(opParam.cacheMode);
     tilingKey |= (static_cast<uint64_t>(opParam.quantMode) << 3);
 
@@ -619,21 +621,12 @@ inline int get_op_mode(const MapType &mode_map, c10::optional
     return it->second;
 }
 
-// std::tuple mla_preprocess(
-//     const at::Tensor &hiddenState, const at::Tensor &gamma0, const at::Tensor &beta0, const at::Tensor &wdqkv,
-//     const at::Tensor &descale0, const at::Tensor &gamma1, const at::Tensor &beta1, const at::Tensor &wuq,
-//     const at::Tensor &descale1, const at::Tensor &gamma2, const at::Tensor &cos, const at::Tensor &sin,
-//     const at::Tensor &wuk, const at::Tensor &kv_cache, const at::Tensor &kv_cache_rope, const at::Tensor &slotmapping,
-//     const at::Tensor &quant_scale0, const at::Tensor &quant_offset0, const at::Tensor &bias0,
-//     const at::Tensor &quant_scale1, const at::Tensor &quant_offset1, const at::Tensor &bias1,
-//     const c10::optional &ctkv_scale, const c10::optional &q_nope_scale,
-//     c10::optional cache_mode, c10::optional quant_mode, at::Tensor &q_out0,
-//     at::Tensor &kv_cache_out0, at::Tensor &q_out1, at::Tensor &kv_cache_out1)
 std::tuple mla_preprocess_tiling(
     const at::Tensor &hiddenState,
     const at::Tensor &wuk,
     c10::optional<std::string> cache_mode,
-    c10::optional<std::string> quant_mode
+    c10::optional<std::string> quant_mode,
+    bool enable_inner_out
 )
 {
     auto cacheMode = get_op_mode(cache_mode_map, cache_mode, "krope_ctkv", "cache_mode");
@@ -661,6 +654,7 @@ std::tuple mla_preprocess_tiling(
     opParam.cacheMode = static_cast<int32_t>(cacheMode);
     opParam.quantMode = static_cast<QuantMode>(quantMode);
     opParam.inDtype = hiddenState.options().dtype();
+    opParam.enableInnerOut = enable_inner_out;
 
     MlaTilingData tilingData;
     MlaPreprocessTiling mlaTiling(platformInfo, opParam, &tilingData);
diff --git a/csrc/mla_preprocess/op_kernel/mla_preprocess.h b/csrc/mla_preprocess/op_kernel/mla_preprocess.h
index 35254112..fe7a125c 100644
--- a/csrc/mla_preprocess/op_kernel/mla_preprocess.h
+++ b/csrc/mla_preprocess/op_kernel/mla_preprocess.h
@@ -103,6 +103,9 @@ constexpr uint32_t KEY_FP16_CACHEMODE_1_QUANTMODE_0 = 1;
 constexpr uint32_t KEY_BF16_CACHEMODE_0_QUANTMODE_0 = 256;
 constexpr uint32_t KEY_BF16_CACHEMODE_1_QUANTMODE_0 = 257;
 constexpr uint32_t KEY_BF16_CACHEMODE_3_QUANTMODE_0 = 259;
+constexpr uint32_t KEY_BF16_CACHEMODE_0_QUANTMODE_0_INNER = 256 + 512;
+constexpr uint32_t KEY_BF16_CACHEMODE_1_QUANTMODE_0_INNER = 257 + 512;
+constexpr uint32_t KEY_BF16_CACHEMODE_3_QUANTMODE_0_INNER = 259 + 512;
 
 enum class QuantMode : int32_t {
     PER_TENSOR_ASYMM_QUANT = 0,
diff --git a/csrc/mla_preprocess/op_kernel/mla_preprocess_kernel.cpp b/csrc/mla_preprocess/op_kernel/mla_preprocess_kernel.cpp
index 7ea231c7..5f367a92 100644
--- a/csrc/mla_preprocess/op_kernel/mla_preprocess_kernel.cpp
+++ b/csrc/mla_preprocess/op_kernel/mla_preprocess_kernel.cpp
@@ -15,6 +15,7 @@
 #include "mla_preprocess_mix_fp16.hpp"
 #include "mla_preprocess_mix_bf16.hpp"
+#include "mla_preprocess_mix_bf16_qdown.hpp"
 
 #include "../op_host/tiling/mla_preprocess_tiling.h"
 
@@ -23,7 +24,7 @@ extern "C" __global__ __aicore__ void mla_preprocess(
     GM_ADDR bias1, GM_ADDR gamma2, GM_ADDR beta2, GM_ADDR quantScale2, GM_ADDR quantOffset2, GM_ADDR gamma3,
     GM_ADDR sin1, GM_ADDR cos1, GM_ADDR sin2, GM_ADDR cos2, GM_ADDR keycache, GM_ADDR slotMapping, GM_ADDR wuq,
     GM_ADDR bias2, GM_ADDR wuk, GM_ADDR descale1, GM_ADDR descale2, GM_ADDR ctkvScale, GM_ADDR qnopeScale, GM_ADDR q,
-    GM_ADDR keycacheOut, GM_ADDR q2, GM_ADDR keycacheOut2, GM_ADDR workspace, GM_ADDR tiling)
+    GM_ADDR keycacheOut, GM_ADDR q2, GM_ADDR keycacheOut2, GM_ADDR innerOut, GM_ADDR workspace, GM_ADDR tiling)
 {
 #if defined(__CCE_KT_TEST__) || (__CCE_AICORE__ == 220)
     PRELOAD(2);
@@ -218,6 +219,54 @@ extern "C" __global__ __aicore__ void mla_preprocess(
             }
             break;
         }
+        case KEY_BF16_CACHEMODE_0_QUANTMODE_0_INNER: {
+            MLAPO_BF16_INNER::MLAOperation<__bf16, 0, DataFormat::NZ, DataFormat::NZ, DataFormat::ND,
+                QuantMode::PER_TENSOR_ASYMM_QUANT>
+                opBf16Cm0Qm0Inner(mlaTilingData, tiling);
+            opBf16Cm0Qm0Inner.Init(hiddenState, quantScale1, quantOffset1, wdqkv, bias1, gamma2, beta2,
+                quantScale2, quantOffset2, gamma3, sin1, cos1, sin2, cos2, keycache, slotMapping, wuq,
+                bias2, wuk, descale1, descale2, ctkvScale, qnopeScale, q, keycacheOut, q2, keycacheOut2,
+                s1, s2, s3, s4, s5, innerOut);
+            if ASCEND_IS_AIC {
+                opBf16Cm0Qm0Inner.ProcessCube();
+            }
+            if ASCEND_IS_AIV {
+                opBf16Cm0Qm0Inner.ProcessVector();
+            }
+            break;
+        }
+        case KEY_BF16_CACHEMODE_1_QUANTMODE_0_INNER: {
+            MLAPO_BF16_INNER::MLAOperation<__bf16, 1, DataFormat::NZ, DataFormat::NZ, DataFormat::ND,
+                QuantMode::PER_TENSOR_ASYMM_QUANT>
+                opBf16Cm1Qm0Inner(mlaTilingData, tiling);
+            opBf16Cm1Qm0Inner.Init(hiddenState, quantScale1, quantOffset1, wdqkv, bias1, gamma2, beta2,
+                quantScale2, quantOffset2, gamma3, sin1, cos1, sin2, cos2, keycache, slotMapping, wuq,
+                bias2, wuk, descale1, descale2, ctkvScale, qnopeScale, q, keycacheOut, q2, keycacheOut2,
+                s1, s2, s3, s4, s5, innerOut);
+            if ASCEND_IS_AIC {
+                opBf16Cm1Qm0Inner.ProcessCube();
+            }
+            if ASCEND_IS_AIV {
+                opBf16Cm1Qm0Inner.ProcessVector();
+            }
+            break;
+        }
+        case KEY_BF16_CACHEMODE_3_QUANTMODE_0_INNER: {
+            MLAPO_BF16_INNER::MLAOperation<__bf16, 3, DataFormat::NZ, DataFormat::NZ, DataFormat::ND,
+                QuantMode::PER_TENSOR_ASYMM_QUANT>
+                opBf16Cm3Qm0Inner(mlaTilingData, tiling);
+            opBf16Cm3Qm0Inner.Init(hiddenState, quantScale1, quantOffset1, wdqkv, bias1, gamma2, beta2,
+                quantScale2, quantOffset2, gamma3, sin1, cos1, sin2, cos2, keycache, slotMapping, wuq,
+                bias2, wuk, descale1, descale2, ctkvScale, qnopeScale, q, keycacheOut, q2, keycacheOut2,
+                s1, s2, s3, s4, s5, innerOut);
+            if ASCEND_IS_AIC {
+                opBf16Cm3Qm0Inner.ProcessCube();
+            }
+            if ASCEND_IS_AIV {
+                opBf16Cm3Qm0Inner.ProcessVector();
+            }
+            break;
+        }
         default: {
             break;
         }
@@ -256,6 +305,7 @@ extern void mla_preprocess_impl(
     void* keycache_out,
     void* q2,
     void* keycache_out2,
+    void* inner_out,
     void* workspace,
     void* tiling,
     const uint32_t block_dim)
@@ -288,6 +338,7 @@ extern void mla_preprocess_impl(
         keycache_out,
         q2,
         keycache_out2,
+        inner_out,
         workspace,
         tiling);
 }
diff --git a/csrc/mla_preprocess/op_kernel/mla_preprocess_mix_bf16_qdown.hpp b/csrc/mla_preprocess/op_kernel/mla_preprocess_mix_bf16_qdown.hpp
new file mode 100644
index 00000000..aa990cbb
--- /dev/null
+++ b/csrc/mla_preprocess/op_kernel/mla_preprocess_mix_bf16_qdown.hpp
@@ -0,0 +1,2936 @@
+// Adapted from
+// https://gitee.com/ascend/ascend-transformer-boost
+//
+// Copyright (c) Huawei Technologies Co., Ltd. 2025. All rights reserved.
+// This file is a part of the CANN Open Software.
+// Licensed under CANN Open Software License Agreement Version 1.0 (the "License").
+// Please refer to the License for details. You may not use this file except in compliance with the License.
+// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED,
+// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE.
+// See LICENSE in the root of the software repository for the full text of the License.
+//
+
+#include "kernel/common.h"
+#include "kernel/iterator.h"
+#include "kernel/mem.h"
+#include "kernel/mma.h"
+#include "kernel/utils.h"
+#include "kernel/simd.h"
+#include "kernel/kernel_utils.h"
+
+#include "lib/matmul_intf.h"
+
+#include "mla_preprocess.h"
+#include "../op_host/tiling/mla_preprocess_tiling.h"
+
+namespace MLAPO_BF16_INNER {
+template
+class RopeFp16
+{
+public:
+    __aicore__ inline RopeFp16() : blockIdx_(AscendC::GetBlockIdx()) {}
+
+    __aicore__ inline void RopeInit(GM_ADDR qGm, AscendC::GlobalTensor &cosGm,
+                                    AscendC::GlobalTensor &sinGm,
+                                    AscendC::GlobalTensor &outRopeConcatGm,
+                                    AscendC::GlobalTensor &outRopeConcatGm2, MlaTilingData &ropeConcatParams)
+    {
+        qGm_.SetGlobalBuffer(reinterpret_cast<__gm__ QkDtype *>(qGm));
+        this->cosGm_ = cosGm;
+        this->sinGm_ = sinGm;
+        this->outRopeConcatGm_ = outRopeConcatGm;
+        this->outRopeConcatGm2_ = outRopeConcatGm2;
+
+        headDim = ropeConcatParams.headDim;
+        headNumQ = ropeConcatParams.headNumQ;
+        rotaryCoeff = ropeConcatParams.rotaryCoeff;
+        ntokens = ropeConcatParams.ntokens;
+        realCore = ropeConcatParams.realCore;
+        nlCoreRun = ropeConcatParams.nlCoreRun;
+        lCoreRun = ropeConcatParams.lCoreRun;
+        maxNPerLoopForUb = ropeConcatParams.maxNPerLoopForUb;
+        preCoreLoopTime = ropeConcatParams.preCoreLoopTime;
+        preCoreLoopNLast = ropeConcatParams.preCoreLoopNLast;
+        lastCoreLoopTime = ropeConcatParams.lastCoreLoopTime;
+        lastCoreLoopNLast = ropeConcatParams.lastCoreLoopNLast;
+        concatSize = ropeConcatParams.concatSize;
+        blockIdx_ = (blockIdx_ / 2) * 2 + static_cast<uint32_t>(GetSubBlockidx());
+        loopTime = (blockIdx_ == realCore - 1) ? lastCoreLoopTime : preCoreLoopTime;
+        lastLoopN = (blockIdx_ == realCore - 1) ? lastCoreLoopNLast : preCoreLoopNLast;
+        this->repeatSize_ = 64;  // 64 = 256B / sizeof(fp32)
+        this->rotateStride_ = this->headDim / this->rotaryCoeff;
+        headBlockLen = static_cast<uint16_t>(this->headDim / ELE_NUM_FP16);
+        headBlockLenFP32 = static_cast<uint16_t>(this->headDim / ELE_NUM_FP32);
+        rotaryLen = static_cast<uint16_t>(this->rotateStride_ / ELE_NUM_FP32);
+        concatBlockLen = static_cast<uint16_t>(this->concatSize / ELE_NUM_FP16);
+        outLineOffset = this->headDim + this->concatSize;
+        uint32_t dataNum = this->headDim * this->maxNPerLoopForUb;
+        dataSizeFp16 = dataNum * sizeof(QkDtype);
+        dataSizeFp32 = dataNum * sizeof(float);
+        uint32_t concatDataSize = this->concatSize * sizeof(QkDtype) * this->maxNPerLoopForUb;
+    }
+
+    __aicore__ inline void Process()
+    {
+        if (blockIdx_ >= realCore) {
+            return;
+        }
+        uint64_t startCoreLineIndex = this->blockIdx_ * this->nlCoreRun;
+        // negation pattern of shape [maxNPerLoopForUb, head_dim]
+        AscendC::LocalTensor negLocal =
+            buf.GetBuffer(dataSizeFp32 * 4 + dataSizeFp16 * 3);
+        ExpandNeg(negLocal, this->maxNPerLoopForUb);
+
+        SET_FLAG(MTE3, MTE2, EVENT_ID1);
+        for (uint32_t zz = 0; zz < this->loopTime; ++zz) {
+            uint16_t loopN = (zz == this->loopTime - 1) ?
this->lastLoopN : this->maxNPerLoopForUb; + uint64_t startHead = startCoreLineIndex + zz * this->maxNPerLoopForUb; + uint64_t endHead = startHead + loopN; + + // move in Q + AscendC::LocalTensor inputQ = buf.GetBuffer(0); + AscendC::LocalTensor inputQCastFP32 = buf.GetBuffer(dataSizeFp16); + AscendC::LocalTensor reverseQ = + buf.GetBuffer(dataSizeFp32 + dataSizeFp16); + uint64_t qOffset = startHead * 192 + 128; + CopyQGenReverseQ(inputQ, inputQCastFP32, reverseQ, qOffset, loopN); + + // move in cos/sin + AscendC::LocalTensor inputCos = + buf.GetBuffer(dataSizeFp32 * 2 + dataSizeFp16); + AscendC::LocalTensor inputSin = + buf.GetBuffer(dataSizeFp32 * 2 + dataSizeFp16 * 2); + uint64_t startSinCosHeadIndex = startHead; + uint64_t headRemain = startHead % this->headNumQ; + uint64_t localStartAddr = 0; + if (headRemain != 0) { + uint64_t preProcessHeadNum = this->headNumQ - headRemain; + uint64_t needToProcesHead = preProcessHeadNum > loopN ? loopN : preProcessHeadNum; + CopyCosSin(inputCos, inputSin, localStartAddr, (startSinCosHeadIndex / this->headNumQ) * this->headDim, + needToProcesHead); + startSinCosHeadIndex += needToProcesHead; + localStartAddr += needToProcesHead * this->headDim; + } + + if (startSinCosHeadIndex < endHead) { + uint64_t startSinCosIndex = startSinCosHeadIndex / this->headNumQ; + uint64_t endSinCosIndex = (endHead + this->headNumQ - 1) / this->headNumQ; + for (uint32_t index = startSinCosIndex; index < endSinCosIndex; ++index) { + uint32_t repeatNum = + index == endSinCosIndex - 1 ? endHead - index * this->headNumQ : this->headNumQ; + CopyCosSin(inputCos, inputSin, localStartAddr, index * this->headDim, repeatNum); + localStartAddr += this->headDim * this->headNumQ; + } + } + AscendC::LocalTensor inputCosCastFP32 = + buf.GetBuffer(dataSizeFp32 * 2 + dataSizeFp16 * 3); + AscendC::LocalTensor inputSinCastFP32 = + buf.GetBuffer(dataSizeFp32 * 3 + dataSizeFp16 * 3); + AscendC::Cast(inputCosCastFP32, inputCos, AscendC::RoundMode::CAST_NONE, loopN * this->headDim); + AscendC::Cast(inputSinCastFP32, inputSin, AscendC::RoundMode::CAST_NONE, loopN * this->headDim); + AscendC::PipeBarrier(); + + uint32_t repeatTime = this->headDim * loopN; + AscendC::Mul(inputQCastFP32, inputCosCastFP32, inputQCastFP32, repeatTime); + AscendC::Mul(reverseQ, negLocal, reverseQ, repeatTime); + AscendC::PipeBarrier(); + + AscendC::Mul(reverseQ, inputSinCastFP32, reverseQ, repeatTime); + AscendC::PipeBarrier(); + + AscendC::Add(inputQCastFP32, reverseQ, inputQCastFP32, repeatTime); + AscendC::PipeBarrier(); + + AscendC::Cast(inputQ, inputQCastFP32, AscendC::RoundMode::CAST_RINT, loopN * this->headDim); + AscendC::PipeBarrier(); + uint64_t outQOffset = startHead * outLineOffset + this->concatSize; + uint64_t outQOffset2 = startHead * this->headDim; + SET_FLAG(V, MTE3, EVENT_ID1); + WAIT_FLAG(V, MTE3, EVENT_ID1); + if constexpr (CacheMode == CACHE_MODE_KVCACHE) { + AscendC::DataCopy(this->outRopeConcatGm_[outQOffset], inputQ, {loopN, headBlockLen, 0, concatBlockLen}); + } else { + AscendC::DataCopy(this->outRopeConcatGm2_[outQOffset2], inputQ, loopN * this->headDim); + } + SET_FLAG(MTE3, MTE2, EVENT_ID1); + } + WAIT_FLAG(MTE3, MTE2, EVENT_ID1); + } + // tensor -1 -1 -1 1 1 1 + template + __aicore__ inline void ExpandNeg(const AscendC::LocalTensor &tempBuf, uint32_t headNumTemp) + { + for (uint32_t i = 0; i < this->rotateStride_; ++i) { + tempBuf.SetValue(i, (BUF_TYPE)-1); + tempBuf.SetValue(i + this->rotateStride_, (BUF_TYPE)1); + } + AscendC::SetFlag(EVENT_ID1); + AscendC::WaitFlag(EVENT_ID1); + 
AscendC::Copy(tempBuf[this->headDim], tempBuf, this->headDim, headNumTemp - 1, {1, 1, headBlockLenFP32, 0}); + AscendC::PipeBarrier(); + } + + template + __aicore__ inline void CopyQGenReverseQ(const AscendC::LocalTensor &tempBufQ, + const AscendC::LocalTensor &tempBufQCast, + const AscendC::LocalTensor &tempBufRverseQ, uint64_t qOffset, + uint16_t loopN) + { + // move in Q + WAIT_FLAG(MTE3, MTE2, EVENT_ID1); + AscendC::DataCopy(tempBufQ, this->qGm_[qOffset], {loopN, headBlockLen, 128 / 16, 0}); + SET_FLAG(MTE2, V, EVENT_ID1); + WAIT_FLAG(MTE2, V, EVENT_ID1); + // cast fp32 + AscendC::Cast(tempBufQCast, tempBufQ, AscendC::RoundMode::CAST_NONE, loopN * this->headDim); + AscendC::PipeBarrier(); + // move out reverseQ + AscendC::DataCopy(tempBufRverseQ, tempBufQCast[this->rotateStride_], {loopN, rotaryLen, rotaryLen, rotaryLen}); + AscendC::DataCopy(tempBufRverseQ[this->rotateStride_], tempBufQCast, {loopN, rotaryLen, rotaryLen, rotaryLen}); + AscendC::PipeBarrier(); + } + + template + __aicore__ inline void CopyCosSin(const AscendC::LocalTensor &tempBufCos, + const AscendC::LocalTensor &tempBufSin, uint64_t localStartAddr, + uint64_t gmStartAddr, uint64_t repeatNum) + { + AscendC::DataCopy(tempBufCos[localStartAddr], this->cosGm_[gmStartAddr], {1, headBlockLen, 0, 0}); + AscendC::DataCopy(tempBufSin[localStartAddr], this->sinGm_[gmStartAddr], {1, headBlockLen, 0, 0}); + SET_FLAG(MTE2, V, EVENT_ID1); + WAIT_FLAG(MTE2, V, EVENT_ID1); + AscendC::Copy(tempBufCos[localStartAddr + this->headDim], tempBufCos[localStartAddr], this->headDim, + repeatNum - 1, {1, 1, headBlockLen, 0}); + AscendC::Copy(tempBufSin[localStartAddr + this->headDim], tempBufSin[localStartAddr], this->headDim, + repeatNum - 1, {1, 1, headBlockLen, 0}); + AscendC::PipeBarrier(); + } + +private: + AsdopsBuffer buf; + + AscendC::GlobalTensor qGm_; + AscendC::GlobalTensor cosGm_; + AscendC::GlobalTensor sinGm_; + AscendC::GlobalTensor outRopeConcatGm_; + AscendC::GlobalTensor outRopeConcatGm2_; + + uint32_t repeatSize_{0}; + uint32_t rotateStride_{0}; // this->headDim / rope conf + uint32_t headDim; + uint32_t headNumQ; + uint32_t rotaryCoeff; + uint32_t ntokens; + uint32_t realCore; + uint32_t nlCoreRun; + uint32_t lCoreRun; + uint32_t maxNPerLoopForUb; + uint32_t preCoreLoopTime; + uint32_t preCoreLoopNLast; + uint32_t lastCoreLoopTime; + uint32_t lastCoreLoopNLast; + uint32_t concatSize; + uint32_t blockIdx_; + uint32_t loopTime{0}; + uint32_t lastLoopN{0}; + + uint32_t dataSizeFp32; + uint32_t dataSizeFp16; + uint16_t headBlockLen{0}; + uint16_t headBlockLenFP32{0}; + uint16_t rotaryLen{0}; + uint16_t concatBlockLen{0}; + uint64_t outLineOffset{0}; +}; + +__aicore__ inline void ReduceSumCustom(const AscendC::LocalTensor &dst_local, + const AscendC::LocalTensor &src_local, + const AscendC::LocalTensor &work_local, int32_t count) +{ +#ifdef __DAV_C220_VEC__ + uint64_t mask = NUM_PER_REP_FP32; + int32_t repeatTimes = count / NUM_PER_REP_FP32; + int32_t tailCount = count % NUM_PER_REP_FP32; + int32_t bodyCount = repeatTimes * NUM_PER_REP_FP32; + AscendC::BinaryRepeatParams repeatParams; + repeatParams.src0RepStride = AscendC::ONE_REPEAT_BYTE_SIZE / AscendC::ONE_BLK_SIZE; + repeatParams.src0BlkStride = 1; + repeatParams.src1RepStride = 0; + repeatParams.src1BlkStride = 1; + repeatParams.dstRepStride = 0; + repeatParams.dstBlkStride = 1; + Duplicate(work_local, ZERO, NUM_PER_REP_FP32); + AscendC::PipeBarrier(); + if (likely(repeatTimes > 0)) { + Add(work_local, src_local, work_local, mask, repeatTimes, repeatParams); + 
AscendC::PipeBarrier(); + } + if (unlikely(tailCount != 0)) { + Add(work_local, src_local[bodyCount], work_local, tailCount, 1, repeatParams); + AscendC::PipeBarrier(); + } + AscendC::AscendCUtils::SetMask(NUM_PER_REP_FP32); + cadd_v(dst_local, // dst + work_local, // src + 1, // repeat + 0, // dstRepeatStride + 1, // srcBlockStride + 0); // srcRepeatStride + AscendC::PipeBarrier(); +#endif +} + +template +class Quant +{ +public: + __aicore__ inline Quant() {} + + __aicore__ inline void Init(AscendC::GlobalTensor &quantScaleGmTensor, + AscendC::GlobalTensor &quantOffsetGmTensor, GM_ADDR perTokenDescaleGm, + GM_ADDR perChannelDescaleGm, GM_ADDR gmInput, GM_ADDR gmOutput, uint32_t stride, + uint32_t num_col, uint64_t gm_offset, uint64_t gm_out_offset, uint32_t row_work_, + const MlaTilingData &mlaParams_) + { + this->quantScaleGmTensor = quantScaleGmTensor; + this->quantOffsetGmTensor = quantOffsetGmTensor; + this->perTokenDescaleGmTensor.SetGlobalBuffer(reinterpret_cast<__gm__ float *>(perTokenDescaleGm)); + this->perChannelDescaleGmTensor.SetGlobalBuffer(reinterpret_cast<__gm__ float *>(perChannelDescaleGm)); + if constexpr (!NEED_DEQUANT) { + inputGmTensor.SetGlobalBuffer(reinterpret_cast<__gm__ T *>(gmInput)); + } else { + mmGmTensor.SetGlobalBuffer(reinterpret_cast<__gm__ int32_t *>(gmInput)); + } + outputGmTensor.SetGlobalBuffer(reinterpret_cast<__gm__ int8_t *>(gmOutput)); + + num_col_ = num_col; + quantMin_ = -128; + this->num_row_ = mlaParams_.n; + this->row_work = row_work; + this->row_work_ = row_work_; + gm_offset_ = gm_offset; + gm_out_offset_ = gm_out_offset; + num_col_align_int8 = (num_col_ + REPEAT_TIME_256 - 1) / REPEAT_TIME_256 * REPEAT_TIME_256; + num_col_align_f16 = (num_col_ + REPEAT_TIME_128 - 1) / REPEAT_TIME_128 * REPEAT_TIME_128; + num_col_align_f32 = (num_col_ + REPEAT_TIME_64 - 1) / REPEAT_TIME_64 * REPEAT_TIME_64; + input_stride_ = stride; + + num_col_align_withStride_int8 = + (num_col_ - input_stride_ + REPEAT_TIME_256 - 1) / REPEAT_TIME_256 * REPEAT_TIME_256; + num_col_align_withStride_fp16 = + (num_col_ - input_stride_ + REPEAT_TIME_128 - 1) / REPEAT_TIME_128 * REPEAT_TIME_128; + num_col_align_withStride_fp32 = + (num_col_ - input_stride_ + REPEAT_TIME_64 - 1) / REPEAT_TIME_64 * REPEAT_TIME_64; + } + + __aicore__ inline void Launch(const AscendC::LocalTensor &dstTensor, + const AscendC::LocalTensor &srcTensor, + const AscendC::LocalTensor &quantScaleTensor, + const AscendC::LocalTensor &quantOffsetTensor, + const AscendC::LocalTensor &res1Tensor, + const AscendC::LocalTensor &res3Tensor) + { + this->dstTensor = dstTensor; + this->srcTensor = srcTensor; + this->fp32_xy = res1Tensor; + this->buf = res3Tensor; + + AscendC::LocalTensor g = buf[OFFSET_GAMMA * num_col_align_withStride_fp32]; // 0 + AscendC::LocalTensor sqx = buf[OFFSET_SQX * num_col_align_withStride_fp32]; // 1 + AscendC::LocalTensor work = buf[OFFSET_SUM * num_col_align_withStride_fp32]; // 2 + AscendC::LocalTensor abs = buf[OFFSET_ABS * num_col_align_withStride_fp32]; // 3 + AscendC::LocalTensor sum = buf[OFFSET_WORKSPACE_BF16 * num_col_align_withStride_fp32]; // 4 + AscendC::LocalTensor max = buf[OFFSET_WORKSPACE_BF16 * num_col_align_withStride_fp32 + 8]; // 5 + AscendC::LocalTensor perTokenDescaleTensor = + buf[OFFSET_WORKSPACE_BF16 * num_col_align_withStride_fp32 + 16]; // 6 + + SET_FLAG(MTE2, V, EVENT_ID1); + if constexpr (quantMode == QuantMode::PER_TENSOR_ASYMM_QUANT) { + AscendC::DataCopy(quantScaleTensor, quantScaleGmTensor, AscendC::DataCopyParams(1, 1, 0, 0)); + 
AscendC::DataCopy(quantOffsetTensor, quantOffsetGmTensor, AscendC::DataCopyParams(1, 1, 0, 0)); + } + + if constexpr (NEED_DEQUANT) { + mmTensor = buf.ReinterpretCast()[OFFSET_WORKSPACE_BF16 * num_col_align_withStride_fp32 + 16]; + deScaleTensor = buf[OFFSET_WORKSPACE_BF16 * num_col_align_withStride_fp32 + 16 + MM1_OUT_SIZE]; + perTokenDescaleTensor = buf[OFFSET_WORKSPACE_BF16 * num_col_align_withStride_fp32 + 16 + MM1_OUT_SIZE * 2]; + AscendC::DataCopy(deScaleTensor, perChannelDescaleGmTensor, AscendC::DataCopyParams(1, num_col_ / 8, 0, 0)); + } + + if constexpr (quantMode == QuantMode::PER_TENSOR_ASYMM_QUANT) { + if (std::is_same::value) { + SET_FLAG(MTE2, V, EVENT_ID0); + WAIT_FLAG(MTE2, V, EVENT_ID0); + Cast(g, quantScaleTensor, AscendC::RoundMode::CAST_NONE, 1); + AscendC::SetFlag(EVENT_ID0); + AscendC::WaitFlag(EVENT_ID0); + input_scale_ = 1 / (float)(g.GetValue(0)); + input_offset_ = (float)(quantOffsetTensor.GetValue(0)); + } else { + SET_FLAG(MTE2, S, EVENT_ID0); + WAIT_FLAG(MTE2, S, EVENT_ID0); + input_scale_ = 1 / (float)(quantScaleTensor.GetValue(0)); + input_offset_ = (float)(quantOffsetTensor.GetValue(0)); + } + AscendC::SetFlag(EVENT_ID0); + AscendC::WaitFlag(EVENT_ID0); + } + WAIT_FLAG(MTE2, V, EVENT_ID1); + uint64_t pid = 0; + SET_FLAG(MTE3, MTE2, EVENT_ID0); + while (pid < row_work_) { + uint64_t offset = pid * num_col_; + uint64_t outOffset = pid * (num_col_ - input_stride_); + WAIT_FLAG(MTE3, MTE2, EVENT_ID0); + if constexpr (!NEED_DEQUANT) { + AscendC::DataCopy(srcTensor, inputGmTensor[gm_offset_ + offset], + AscendC::DataCopyParams(1, num_col_ / BLOCK_SIZE_16, 0, 0)); + SET_FLAG(MTE2, V, EVENT_ID0); + WAIT_FLAG(MTE2, V, EVENT_ID0); + } else { + /* Dequant start */ + AscendC::DataCopy(mmTensor, mmGmTensor[gm_offset_ + offset], + AscendC::DataCopyParams(1, num_col_ / 8, 0, 0)); // 2112 + SET_FLAG(MTE2, V, EVENT_ID0); + WAIT_FLAG(MTE2, V, EVENT_ID0); + AscendC::Cast(mmTensor.ReinterpretCast(), mmTensor, AscendC::RoundMode::CAST_NONE, num_col_); + AscendC::PipeBarrier(); + AscendC::Mul(mmTensor.ReinterpretCast(), mmTensor.ReinterpretCast(), deScaleTensor, + num_col_); + SET_FLAG(V, MTE2, EVENT_ID0); + WAIT_FLAG(V, MTE2, EVENT_ID0); + gm_to_ub_align(perTokenDescaleTensor, perTokenDescaleGmTensor[pid], + 0, // sid + 1, // nBurst + sizeof(float), // lenBurst + 0, // leftPaddingNum + 0, // rightPaddingNum + 0, // srcGap + 0 // dstGap + ); + SET_FLAG(MTE2, S, EVENT_ID0); + WAIT_FLAG(MTE2, S, EVENT_ID0); + float perTokenDescale = perTokenDescaleTensor.GetValue(0); + SET_FLAG(S, V, EVENT_ID0); + WAIT_FLAG(S, V, EVENT_ID0); + AscendC::Muls(mmTensor.ReinterpretCast(), mmTensor.ReinterpretCast(), perTokenDescale, + num_col_); + AscendC::PipeBarrier(); + AscendC::Cast(srcTensor, mmTensor.ReinterpretCast(), AscendC::RoundMode::CAST_RINT, num_col_); + AscendC::PipeBarrier(); + } + + Cast(fp32_xy, srcTensor[input_stride_], AscendC::RoundMode::CAST_NONE, REPEAT_TIME_64, + num_col_align_withStride_fp32 / REPEAT_TIME_64, + {1, 1, AscendC::DEFAULT_REPEAT_STRIDE, AscendC::DEFAULT_REPEAT_STRIDE / OFFSET_SUM}); + AscendC::PipeBarrier(); + + /* Quant start */ + if constexpr (quantMode == QuantMode::PER_TENSOR_ASYMM_QUANT) { + Muls(fp32_xy, fp32_xy, input_scale_, REPEAT_TIME_64, num_col_align_withStride_fp32 / REPEAT_TIME_64, + {1, 1, AscendC::DEFAULT_REPEAT_STRIDE, AscendC::DEFAULT_REPEAT_STRIDE}); + AscendC::PipeBarrier(); + Adds(fp32_xy, fp32_xy, input_offset_, REPEAT_TIME_64, num_col_align_withStride_fp32 / REPEAT_TIME_64, + {1, 1, AscendC::DEFAULT_REPEAT_STRIDE, 
AscendC::DEFAULT_REPEAT_STRIDE}); + AscendC::PipeBarrier(); + } else if constexpr (quantMode == QuantMode::PER_TOKEN_SYMM_QUANT) { + Abs(abs, fp32_xy, REPEAT_TIME_64, num_col_align_withStride_fp32 / REPEAT_TIME_64, + {1, 1, AscendC::DEFAULT_REPEAT_STRIDE, AscendC::DEFAULT_REPEAT_STRIDE}); + AscendC::PipeBarrier(); + ReduceMax(max, abs, work, num_col_ - input_stride_); + AscendC::PipeBarrier(); + float scaleOut = max.GetValue(0) / 127; + SET_FLAG(S, V, EVENT_ID0); + WAIT_FLAG(S, V, EVENT_ID0); + Muls(fp32_xy, fp32_xy, (float)(1 / scaleOut), REPEAT_TIME_64, + num_col_align_withStride_fp32 / REPEAT_TIME_64, + {1, 1, AscendC::DEFAULT_REPEAT_STRIDE, AscendC::DEFAULT_REPEAT_STRIDE}); + AscendC::PipeBarrier(); + perTokenDescaleTensor.SetValue(0, scaleOut); + SET_FLAG(S, MTE3, EVENT_ID0); + WAIT_FLAG(S, MTE3, EVENT_ID0); + if constexpr (!NEED_DEQUANT) { + ub_to_gm_align(perTokenDescaleGmTensor[pid], perTokenDescaleTensor, 0, + 1, // nBurst + 1 * sizeof(float), // lenBurst + 0, // leftPaddingNum + 0, // rightPaddingNum + 0, // srcGap + 0 // dstGap + ); + } else { + ub_to_gm_align(perTokenDescaleGmTensor[num_row_ + pid], + perTokenDescaleTensor, 0, + 1, // nBurst + 1 * sizeof(float), // lenBurst + 0, // leftPaddingNum + 0, // rightPaddingNum + 0, // srcGap + 0 // dstGap + ); + } + SET_FLAG(MTE3, V, EVENT_ID0); + WAIT_FLAG(MTE3, V, EVENT_ID0); + } + + AscendC::LocalTensor tmpfp16 = + buf.ReinterpretCast()[OFFSET_SUM * num_col_align_withStride_fp32 * 2]; + CastFrom32To16(tmpfp16, fp32_xy, num_col_align_withStride_fp32); + AscendC::PipeBarrier(); + CastFromF16ToI8(dstTensor, tmpfp16, quantMin_, num_col_align_withStride_fp16); + AscendC::PipeBarrier(); + SET_FLAG(V, MTE3, EVENT_ID0); + WAIT_FLAG(V, MTE3, EVENT_ID0); + AscendC::DataCopy(outputGmTensor[gm_out_offset_ + outOffset], dstTensor, + AscendC::DataCopyParams(1, (num_col_ - input_stride_) / 32, 0, 0)); + SET_FLAG(MTE3, MTE2, EVENT_ID0); + ++pid; + } + WAIT_FLAG(MTE3, MTE2, EVENT_ID0); + } + +private: + AscendC::LocalTensor dstTensor; + AscendC::LocalTensor srcTensor; + AscendC::LocalTensor fp32_xy; + AscendC::LocalTensor buf; + AscendC::LocalTensor mmTensor; + AscendC::LocalTensor deScaleTensor; + + AscendC::GlobalTensor quantScaleGmTensor; + AscendC::GlobalTensor quantOffsetGmTensor; + AscendC::GlobalTensor inputGmTensor; + AscendC::GlobalTensor outputGmTensor; + AscendC::GlobalTensor perTokenDescaleGmTensor; + AscendC::GlobalTensor perChannelDescaleGmTensor; + AscendC::GlobalTensor mmGmTensor; + + uint32_t num_col_{0}; // input columns + uint32_t num_row_{0}; // input rows + uint32_t row_work_{0}; // rows need process + uint32_t row_work{0}; // rows need process + uint32_t row_step_{0}; // rows move in once + uint32_t row_tail_{0}; // rows move in last time + uint64_t gm_offset_{0}; // GM data offset + uint64_t gm_out_offset_{0}; // GM data offset + float avg_factor_{1.0}; // 1/num_col_ + float input_scale_{1.0}; + float input_offset_{0}; + int32_t input_stride_{0}; + float epsilon_{1e-12f}; + uint32_t num_col_align_int8{0}; + uint32_t num_col_align_f16{0}; + uint32_t num_col_align_f32{0}; + uint32_t num_col_align_f32_long{0}; + uint32_t num_col_align_withStride_int8{0}; + uint32_t num_col_align_withStride_fp16{0}; + uint32_t num_col_align_withStride_fp32{0}; + uint32_t num_col_temp; + half quantMin_{-128}; + uint32_t num_slice_{0}; + uint32_t tail_size_{0}; + uint32_t tail_copy_{0}; +}; + +template +class RmsNormQuant +{ +public: + __aicore__ inline RmsNormQuant() {} + + __aicore__ inline void Init(AscendC::GlobalTensor &gammaGmTensor, 
AscendC::GlobalTensor &betaGmTensor, + AscendC::GlobalTensor &quantScaleGmTensor, + AscendC::GlobalTensor &quantOffsetGmTensor, GM_ADDR perTokenDescaleGm, + GM_ADDR perChannelDescaleGm, GM_ADDR gmInput, GM_ADDR gmOutput, uint32_t stride, + uint32_t num_col, float avg_factor, uint64_t gm_offset, uint64_t gm_out_offset, + uint32_t row_work_, const MlaTilingData &mlaParams_, AscendC::GlobalTensor innerGmTensor) + { + this->gammaGmTensor = gammaGmTensor; + this->betaGmTensor = betaGmTensor; + this->quantScaleGmTensor = quantScaleGmTensor; + this->quantOffsetGmTensor = quantOffsetGmTensor; + this->perTokenDescaleGmTensor.SetGlobalBuffer(reinterpret_cast<__gm__ float *>(perTokenDescaleGm)); + this->perChannelDescaleGmTensor.SetGlobalBuffer(reinterpret_cast<__gm__ float *>(perChannelDescaleGm)); + if constexpr (!NEED_DEQUANT) { + inputGmTensor.SetGlobalBuffer(reinterpret_cast<__gm__ T *>(gmInput)); + } else { + mmGmTensor.SetGlobalBuffer(reinterpret_cast<__gm__ int32_t *>(gmInput)); + } + outputGmTensor.SetGlobalBuffer(reinterpret_cast<__gm__ int8_t *>(gmOutput)); + this->innerGmTensor = innerGmTensor; + + num_col_ = num_col; + avg_factor_ = avg_factor; + epsilon_ = 1e-6; + quantMin_ = -128; + this->num_row_ = mlaParams_.n; + this->row_work = row_work; + this->row_work_ = row_work_; + gm_offset_ = gm_offset; + gm_out_offset_ = gm_out_offset; + num_col_align_int8 = (num_col_ + REPEAT_TIME_256 - 1) / REPEAT_TIME_256 * REPEAT_TIME_256; + num_col_align_f16 = (num_col_ + REPEAT_TIME_128 - 1) / REPEAT_TIME_128 * REPEAT_TIME_128; + num_col_align_f32 = (num_col_ + REPEAT_TIME_64 - 1) / REPEAT_TIME_64 * REPEAT_TIME_64; + input_stride_ = stride; + + num_col_align_withStride_int8 = + (num_col_ - input_stride_ + REPEAT_TIME_256 - 1) / REPEAT_TIME_256 * REPEAT_TIME_256; + num_col_align_withStride_fp16 = + (num_col_ - input_stride_ + REPEAT_TIME_128 - 1) / REPEAT_TIME_128 * REPEAT_TIME_128; + num_col_align_withStride_fp32 = + (num_col_ - input_stride_ + REPEAT_TIME_64 - 1) / REPEAT_TIME_64 * REPEAT_TIME_64; + } + + __aicore__ inline void Launch(const AscendC::LocalTensor &dstTensor, + const AscendC::LocalTensor &srcTensor, const AscendC::LocalTensor &gammaTensor, + const AscendC::LocalTensor &betaTensor, + const AscendC::LocalTensor &quantScaleTensor, + const AscendC::LocalTensor &quantOffsetTensor, + const AscendC::LocalTensor &res1Tensor, + const AscendC::LocalTensor &res3Tensor) + { + this->dstTensor = dstTensor; + this->srcTensor = srcTensor; + this->gammaTensor = gammaTensor; + this->betaTensor = betaTensor; + this->fp32_xy = res1Tensor; + this->buf = res3Tensor; + + AscendC::LocalTensor g = buf[OFFSET_GAMMA * num_col_align_withStride_fp32]; // 0 + AscendC::LocalTensor sqx = buf[OFFSET_SQX * num_col_align_withStride_fp32]; // 1 + AscendC::LocalTensor work = buf[OFFSET_SUM * num_col_align_withStride_fp32]; // 2 + AscendC::LocalTensor abs = buf[OFFSET_ABS * num_col_align_withStride_fp32]; // 3 + AscendC::LocalTensor inner = buf[OFFSET_ABS * num_col_align_withStride_fp32].ReinterpretCast(); + AscendC::LocalTensor sum = buf[OFFSET_WORKSPACE_BF16 * num_col_align_withStride_fp32]; // 4 + AscendC::LocalTensor max = buf[OFFSET_WORKSPACE_BF16 * num_col_align_withStride_fp32 + 8]; // 5 + AscendC::LocalTensor perTokenDescaleTensor = + buf[OFFSET_WORKSPACE_BF16 * num_col_align_withStride_fp32 + 16]; // 6 + + AscendC::DataCopy(gammaTensor, gammaGmTensor, + AscendC::DataCopyParams(1, (num_col_ - input_stride_) / BLOCK_SIZE_16, 0, 0)); + AscendC::DataCopy(betaTensor, betaGmTensor, + AscendC::DataCopyParams(1, (num_col_ 
- input_stride_) / BLOCK_SIZE_16, 0, 0)); + SET_FLAG(MTE2, V, EVENT_ID1); + if constexpr (quantMode == QuantMode::PER_TENSOR_ASYMM_QUANT) { + AscendC::DataCopy(quantScaleTensor, quantScaleGmTensor, AscendC::DataCopyParams(1, 1, 0, 0)); + AscendC::DataCopy(quantOffsetTensor, quantOffsetGmTensor, AscendC::DataCopyParams(1, 1, 0, 0)); + } + + if constexpr (NEED_DEQUANT) { + mmTensor = buf.ReinterpretCast()[OFFSET_WORKSPACE_BF16 * num_col_align_withStride_fp32 + 16]; + deScaleTensor = buf[OFFSET_WORKSPACE_BF16 * num_col_align_withStride_fp32 + 16 + MM1_OUT_SIZE]; + perTokenDescaleTensor = buf[OFFSET_WORKSPACE_BF16 * num_col_align_withStride_fp32 + 16 + MM1_OUT_SIZE * 2]; + AscendC::DataCopy(deScaleTensor, perChannelDescaleGmTensor, AscendC::DataCopyParams(1, num_col_ / 8, 0, 0)); + } + + if constexpr (quantMode == QuantMode::PER_TENSOR_ASYMM_QUANT) { + if (std::is_same::value) { + SET_FLAG(MTE2, V, EVENT_ID0); + WAIT_FLAG(MTE2, V, EVENT_ID0); + Cast(g, quantScaleTensor, AscendC::RoundMode::CAST_NONE, 1); + AscendC::SetFlag(EVENT_ID0); + AscendC::WaitFlag(EVENT_ID0); + input_scale_ = 1 / (float)(g.GetValue(0)); + input_offset_ = (float)(quantOffsetTensor.GetValue(0)); + } else { + SET_FLAG(MTE2, S, EVENT_ID0); + WAIT_FLAG(MTE2, S, EVENT_ID0); + input_scale_ = 1 / (float)(quantScaleTensor.GetValue(0)); + input_offset_ = (float)(quantOffsetTensor.GetValue(0)); + } + AscendC::SetFlag(EVENT_ID0); + AscendC::WaitFlag(EVENT_ID0); + } + WAIT_FLAG(MTE2, V, EVENT_ID1); + Cast(buf[OFFSET_GAMMA * num_col_align_withStride_fp32], gammaTensor, AscendC::RoundMode::CAST_NONE, + REPEAT_TIME_64, num_col_align_withStride_fp32 / REPEAT_TIME_64, + {1, 1, AscendC::DEFAULT_REPEAT_STRIDE, AscendC::DEFAULT_REPEAT_STRIDE / OFFSET_SUM}); + AscendC::PipeBarrier(); + uint64_t pid = 0; + SET_FLAG(MTE3, MTE2, EVENT_ID0); + while (pid < row_work_) { + uint64_t offset = pid * num_col_; + uint64_t outOffset = pid * (num_col_ - input_stride_); + WAIT_FLAG(MTE3, MTE2, EVENT_ID0); + if constexpr (!NEED_DEQUANT) { + AscendC::DataCopy(srcTensor, inputGmTensor[gm_offset_ + offset], + AscendC::DataCopyParams(1, num_col_ / BLOCK_SIZE_16, 0, 0)); + SET_FLAG(MTE2, V, EVENT_ID0); + WAIT_FLAG(MTE2, V, EVENT_ID0); + } else { + /* Dequant start */ + AscendC::DataCopy(mmTensor, mmGmTensor[gm_offset_ + offset], + AscendC::DataCopyParams(1, num_col_ / 8, 0, 0)); // 2112 + SET_FLAG(MTE2, V, EVENT_ID0); + WAIT_FLAG(MTE2, V, EVENT_ID0); + AscendC::Cast(mmTensor.ReinterpretCast(), mmTensor, AscendC::RoundMode::CAST_NONE, num_col_); + AscendC::PipeBarrier(); + AscendC::Mul(mmTensor.ReinterpretCast(), mmTensor.ReinterpretCast(), deScaleTensor, + num_col_); + SET_FLAG(V, MTE2, EVENT_ID0); + WAIT_FLAG(V, MTE2, EVENT_ID0); + gm_to_ub_align(perTokenDescaleTensor, perTokenDescaleGmTensor[pid], + 0, // sid + 1, // nBurst + sizeof(float), // lenBurst + 0, // leftPaddingNum + 0, // rightPaddingNum + 0, // srcGap + 0 // dstGap + ); + SET_FLAG(MTE2, S, EVENT_ID0); + WAIT_FLAG(MTE2, S, EVENT_ID0); + float perTokenDescale = perTokenDescaleTensor.GetValue(0); + SET_FLAG(S, V, EVENT_ID0); + WAIT_FLAG(S, V, EVENT_ID0); + AscendC::Muls(mmTensor.ReinterpretCast(), mmTensor.ReinterpretCast(), perTokenDescale, + num_col_); + AscendC::PipeBarrier(); + AscendC::Cast(srcTensor, mmTensor.ReinterpretCast(), AscendC::RoundMode::CAST_RINT, num_col_); + AscendC::PipeBarrier(); + } + + Cast(fp32_xy, srcTensor[input_stride_], AscendC::RoundMode::CAST_NONE, REPEAT_TIME_64, + num_col_align_withStride_fp32 / REPEAT_TIME_64, + {1, 1, AscendC::DEFAULT_REPEAT_STRIDE, 
AscendC::DEFAULT_REPEAT_STRIDE / OFFSET_SUM}); + AscendC::PipeBarrier(); + Mul(sqx, fp32_xy, fp32_xy, REPEAT_TIME_64, num_col_align_withStride_fp32 / REPEAT_TIME_64, + {1, 1, 1, AscendC::DEFAULT_REPEAT_STRIDE, AscendC::DEFAULT_REPEAT_STRIDE, + AscendC::DEFAULT_REPEAT_STRIDE}); + AscendC::PipeBarrier(); + Muls(sqx, sqx, avg_factor_, num_col_ - input_stride_); + AscendC::PipeBarrier(); + ReduceSumCustom(sum, sqx, work, num_col_ - input_stride_); + AscendC::PipeBarrier(); + Adds(sum, sum, epsilon_, 1); + AscendC::PipeBarrier(); + Sqrt(sum, sum, 1); + SET_FLAG(V, S, EVENT_ID0); + WAIT_FLAG(V, S, EVENT_ID0); + float factor = 1 / sum.GetValue(0); + SET_FLAG(S, V, EVENT_ID0); + WAIT_FLAG(S, V, EVENT_ID0); + Muls(fp32_xy, fp32_xy, factor, REPEAT_TIME_64, num_col_align_withStride_fp32 / REPEAT_TIME_64, + {1, 1, AscendC::DEFAULT_REPEAT_STRIDE, AscendC::DEFAULT_REPEAT_STRIDE}); + AscendC::PipeBarrier(); + Mul(fp32_xy, fp32_xy, g, REPEAT_TIME_64, num_col_align_withStride_fp32 / REPEAT_TIME_64, + {1, 1, 1, AscendC::DEFAULT_REPEAT_STRIDE, AscendC::DEFAULT_REPEAT_STRIDE, + AscendC::DEFAULT_REPEAT_STRIDE}); + AscendC::PipeBarrier(); + if constexpr (WITH_BETA) { + AscendC::LocalTensor b = this->betaTensor; + Cast(work, b, AscendC::RoundMode::CAST_NONE, REPEAT_TIME_64, + num_col_align_withStride_fp32 / REPEAT_TIME_64, + {1, 1, AscendC::DEFAULT_REPEAT_STRIDE, AscendC::DEFAULT_REPEAT_STRIDE / OFFSET_SUM}); + AscendC::PipeBarrier(); + Add(fp32_xy, fp32_xy, work, REPEAT_TIME_64, num_col_align_withStride_fp32 / REPEAT_TIME_64, + {1, 1, 1, AscendC::DEFAULT_REPEAT_STRIDE, AscendC::DEFAULT_REPEAT_STRIDE, + AscendC::DEFAULT_REPEAT_STRIDE}); + AscendC::PipeBarrier(); + } + + // q_down + // cast + + AscendC::Cast(inner, fp32_xy, AscendC::RoundMode::CAST_RINT, num_col_align_withStride_fp32); + + SET_FLAG(V, MTE3, EVENT_ID0); + WAIT_FLAG(V, MTE3, EVENT_ID0); + + // copy out; ************ + AscendC::DataCopy(innerGmTensor[gm_out_offset_ + outOffset], inner.ReinterpretCast(), + AscendC::DataCopyParams(1, (num_col_ - input_stride_) / BLOCK_SIZE_16, 0, 0)); + SET_FLAG(MTE3, V, EVENT_ID0); + WAIT_FLAG(MTE3, V, EVENT_ID0); + AscendC::Duplicate(inner, (__bf16) 0, num_col_align_withStride_fp32);// qdown清空 + AscendC::PipeBarrier(); + + /* Quant start */ + if constexpr (quantMode == QuantMode::PER_TENSOR_ASYMM_QUANT) { + Muls(fp32_xy, fp32_xy, input_scale_, REPEAT_TIME_64, num_col_align_withStride_fp32 / REPEAT_TIME_64, + {1, 1, AscendC::DEFAULT_REPEAT_STRIDE, AscendC::DEFAULT_REPEAT_STRIDE}); + AscendC::PipeBarrier(); + Adds(fp32_xy, fp32_xy, input_offset_, REPEAT_TIME_64, num_col_align_withStride_fp32 / REPEAT_TIME_64, + {1, 1, AscendC::DEFAULT_REPEAT_STRIDE, AscendC::DEFAULT_REPEAT_STRIDE}); + AscendC::PipeBarrier(); + } else if constexpr (quantMode == QuantMode::PER_TOKEN_SYMM_QUANT) { + Abs(abs, fp32_xy, REPEAT_TIME_64, num_col_align_withStride_fp32 / REPEAT_TIME_64, + {1, 1, AscendC::DEFAULT_REPEAT_STRIDE, AscendC::DEFAULT_REPEAT_STRIDE}); + AscendC::PipeBarrier(); + ReduceMax(max, abs, work, num_col_ - input_stride_); + AscendC::PipeBarrier(); + float scaleOut = max.GetValue(0) / 127; + SET_FLAG(S, V, EVENT_ID0); + WAIT_FLAG(S, V, EVENT_ID0); + Muls(fp32_xy, fp32_xy, (float)(1 / scaleOut), REPEAT_TIME_64, + num_col_align_withStride_fp32 / REPEAT_TIME_64, + {1, 1, AscendC::DEFAULT_REPEAT_STRIDE, AscendC::DEFAULT_REPEAT_STRIDE}); + AscendC::PipeBarrier(); + perTokenDescaleTensor.SetValue(0, scaleOut); + SET_FLAG(S, MTE3, EVENT_ID0); + WAIT_FLAG(S, MTE3, EVENT_ID0); + if constexpr (!NEED_DEQUANT) { + 
ub_to_gm_align(perTokenDescaleGmTensor[pid], perTokenDescaleTensor, 0, + 1, // nBurst + 1 * sizeof(float), // lenBurst + 0, // leftPaddingNum + 0, // rightPaddingNum + 0, // srcGap + 0 // dstGap + ); + } else { + ub_to_gm_align(perTokenDescaleGmTensor[num_row_ + pid], + perTokenDescaleTensor, 0, + 1, // nBurst + 1 * sizeof(float), // lenBurst + 0, // leftPaddingNum + 0, // rightPaddingNum + 0, // srcGap + 0 // dstGap + ); + } + SET_FLAG(MTE3, V, EVENT_ID0); + WAIT_FLAG(MTE3, V, EVENT_ID0); + } + + AscendC::LocalTensor tmpfp16 = + buf.ReinterpretCast()[OFFSET_SUM * num_col_align_withStride_fp32 * 2]; + CastFrom32To16(tmpfp16, fp32_xy, num_col_align_withStride_fp32); + AscendC::PipeBarrier(); + CastFromF16ToI8(dstTensor, tmpfp16, quantMin_, num_col_align_withStride_fp16); + AscendC::PipeBarrier(); + SET_FLAG(V, MTE3, EVENT_ID0); + WAIT_FLAG(V, MTE3, EVENT_ID0); + AscendC::DataCopy(outputGmTensor[gm_out_offset_ + outOffset], dstTensor, + AscendC::DataCopyParams(1, (num_col_ - input_stride_) / 32, 0, 0)); + SET_FLAG(MTE3, MTE2, EVENT_ID0); + ++pid; + } + WAIT_FLAG(MTE3, MTE2, EVENT_ID0); + } + +private: + AscendC::LocalTensor dstTensor; + AscendC::LocalTensor srcTensor; + AscendC::LocalTensor gammaTensor; + AscendC::LocalTensor betaTensor; + AscendC::LocalTensor fp32_xy; + AscendC::LocalTensor buf; + AscendC::LocalTensor mmTensor; + AscendC::LocalTensor deScaleTensor; + + AscendC::GlobalTensor gammaGmTensor; + AscendC::GlobalTensor betaGmTensor; + AscendC::GlobalTensor quantScaleGmTensor; + AscendC::GlobalTensor quantOffsetGmTensor; + AscendC::GlobalTensor inputGmTensor; + AscendC::GlobalTensor outputGmTensor; + AscendC::GlobalTensor perTokenDescaleGmTensor; + AscendC::GlobalTensor perChannelDescaleGmTensor; + AscendC::GlobalTensor mmGmTensor; + AscendC::GlobalTensor innerGmTensor; + + uint32_t num_col_{0}; + uint32_t num_row_{0}; + uint32_t row_work_{0}; + uint32_t row_work{0}; + uint32_t row_step_{0}; + uint32_t row_tail_{0}; + uint64_t gm_offset_{0}; + uint64_t gm_out_offset_{0}; + float avg_factor_{1.0}; + float input_scale_{1.0}; + float input_offset_{0}; + int32_t input_stride_{0}; + float epsilon_{1e-12f}; + uint32_t num_col_align_int8{0}; + uint32_t num_col_align_f16{0}; + uint32_t num_col_align_f32{0}; + uint32_t num_col_align_f32_long{0}; + uint32_t num_col_align_withStride_int8{0}; + uint32_t num_col_align_withStride_fp16{0}; + uint32_t num_col_align_withStride_fp32{0}; + uint32_t num_col_temp; + half quantMin_{-128}; + uint32_t num_slice_{0}; + uint32_t tail_size_{0}; + uint32_t tail_copy_{0}; +}; + +template +class EinSumQuant +{ +public: + __aicore__ explicit EinSumQuant() {} + + __aicore__ __force_inline__ void Init(GM_ADDR einSumOutGm, GM_ADDR scaleGm, GM_ADDR quantOutGm, + const MlaTilingData &tilingData) + { + einSumOutGm_.SetGlobalBuffer(reinterpret_cast<__gm__ InDtype *>(einSumOutGm)); + scaleGm_.SetGlobalBuffer(reinterpret_cast<__gm__ InDtype *>(scaleGm)); + quantOutGm_.SetGlobalBuffer(reinterpret_cast<__gm__ int8_t *>(quantOutGm)); + + headNum = tilingData.esqHeadNum; + colNum = tilingData.esqColNum; + ubHeadLoop = tilingData.esqUbHeadLoop; + headPerLoop = tilingData.esqHeadPerLoop; + headTail = tilingData.esqHeadTail; + colLoop = tilingData.esqColLoop; + colTail = tilingData.esqColTail; + + currentIdx = (AscendC::GetBlockIdx() / 2) * 2 + GetSubBlockidx(); + if (currentIdx < tilingData.esqFrontCore) { + batchNum = tilingData.esqFrontCoreBatch; + currentCoreStartOffset = currentIdx * tilingData.esqFrontCoreBatch * headNum * colNum; + } else { + batchNum = 
tilingData.esqTailCoreBatch; + currentCoreStartOffset = (tilingData.esqFrontCore * tilingData.esqFrontCoreBatch + + (currentIdx - tilingData.esqFrontCore) * tilingData.esqTailCoreBatch) * + headNum * colNum; + } + calcRepeatStride = static_cast(colNum / ELE_NUM_FP32); + padLen = RoundUp(headNum, ELE_NUM_FP16); + calcLength = headPerLoop * colNum; + + // calc tensors' data size(bytes) and block + scaleBrcbFp32DataSize = padLen * ELE_NUM_FP32 * sizeof(float); + inputDataSize = calcLength * sizeof(InDtype); + inputDataBlock = calcLength * sizeof(InDtype) / BLOCK_SIZE_32; + inputFp32DataSize = calcLength * sizeof(float); + int8OutDataBlcok = calcLength / BLOCK_SIZE_32; + headTailDataBlock = headTail * colNum * sizeof(InDtype) / BLOCK_SIZE_32; + int8TailOutDataBlock = headTail * colNum / BLOCK_SIZE_32; + if (padLen > headNum) { + scaleCopyParams = AscendC::DataCopyExtParams(1, static_cast(headNum * sizeof(InDtype)), 0, 0, 0); + scalePadParams = AscendC::DataCopyPadExtParams(true, 0, static_cast(padLen - headNum), 0); + } + } + + __aicore__ __force_inline__ void Process() + { + if (batchNum == 0) { + return; + } + // init local tensor + scaleBrcbFp32_ = buf.GetBuffer(0); + inputTensor_ = buf.GetBuffer(scaleBrcbFp32DataSize); + inputFp32_ = + buf.GetBuffer(scaleBrcbFp32DataSize + inputDataSize * ROPE_CONCAT_NUM_BUFFER); + int8OutTensor_ = buf.GetBuffer( + scaleBrcbFp32DataSize + (inputDataSize + inputFp32DataSize) * ROPE_CONCAT_NUM_BUFFER); + + // scale copy in, cast, brcb[H, 1] --> [H, 8], use input ub space + if (headNum == padLen) { + AscendC::DataCopy(inputTensor_, scaleGm_, headNum); + } else { + AscendC::DataCopyPad(inputTensor_, scaleGm_, scaleCopyParams, scalePadParams); + } + SET_FLAG(MTE2, V, EVENT_ID0); + WAIT_FLAG(MTE2, V, EVENT_ID0); + AscendC::Cast(inputFp32_, inputTensor_, AscendC::RoundMode::CAST_NONE, padLen); + AscendC::PipeBarrier(); + AscendC::Brcb(scaleBrcbFp32_, inputFp32_, padLen / ELE_NUM_FP32, {1, 8}); + AscendC::PipeBarrier(); + + uint8_t pingFlag = 0; + // batch Loop + SET_FLAG(V, MTE2, EVENT_ID0); // input copy in wait vector release ub + SET_FLAG(V, MTE2, EVENT_ID1); + SET_FLAG(MTE3, V, EVENT_ID0); // quant calc wait last result copyout + SET_FLAG(MTE3, V, EVENT_ID1); + for (uint32_t batchIdx = 0; batchIdx < batchNum; batchIdx++) { + batchOffset = batchIdx * headNum * colNum; + // ub Loop + for (uint32_t ubLoopIdx = 0; ubLoopIdx < ubHeadLoop; ubLoopIdx++) { + scaleBrcbOffset = ubLoopIdx * headPerLoop * ELE_NUM_FP32; + inputLoopOffset = ubLoopIdx * headPerLoop * colNum; + calcStartOffset = currentCoreStartOffset + batchOffset + inputLoopOffset; + calcTmpOffset = pingFlag * calcLength; + + // input CopyIn and Cast + WAIT_FLAG(V, MTE2, pingFlag); + AscendC::DataCopy(inputTensor_[calcTmpOffset], einSumOutGm_[calcStartOffset], + {1, inputDataBlock, 0, 0}); + SET_FLAG(MTE2, V, pingFlag); + WAIT_FLAG(MTE2, V, pingFlag); + AscendC::Cast(inputFp32_[calcTmpOffset], inputTensor_[calcTmpOffset], AscendC::RoundMode::CAST_NONE, + calcLength); + AscendC::PipeBarrier(); + SET_FLAG(V, MTE2, pingFlag); + // quant calc + for (uint32_t colIdx = 0; colIdx < colLoop; colIdx++) { + colOffset = colIdx * CONST_64; + AscendC::Mul(inputFp32_[calcTmpOffset + colOffset], inputFp32_[calcTmpOffset + colOffset], + scaleBrcbFp32_[scaleBrcbOffset], CONST_64, headPerLoop, + {1, 1, 0, calcRepeatStride, calcRepeatStride, 1}); + } + AscendC::PipeBarrier(); + // quant fp32 --> fp16 --> int8 + CastFrom32To16(inputFp32_[calcTmpOffset].template ReinterpretCast(), inputFp32_[calcTmpOffset], + calcLength); + 
AscendC::PipeBarrier(); + WAIT_FLAG(MTE3, V, pingFlag); // wait last result copy out + CastFromF16ToI8(int8OutTensor_[calcTmpOffset], + inputFp32_[calcTmpOffset].template ReinterpretCast(), quantMin_, calcLength); + AscendC::PipeBarrier(); + SET_FLAG(V, MTE3, pingFlag); + WAIT_FLAG(V, MTE3, pingFlag); + // int8 CopyOut + AscendC::DataCopy(quantOutGm_[calcStartOffset], int8OutTensor_[calcTmpOffset], + {1, int8OutDataBlcok, 0, 0}); + SET_FLAG(MTE3, V, pingFlag); + pingFlag = 1 - pingFlag; + } + + // deal with head tail + if (headTail > 0) { + scaleBrcbOffset = ubHeadLoop * headPerLoop * ELE_NUM_FP32; + inputLoopOffset = ubHeadLoop * headPerLoop * colNum; + calcStartOffset = currentCoreStartOffset + batchOffset + inputLoopOffset; + calcTmpOffset = pingFlag * calcLength; + + // input CopyIn and Cast + WAIT_FLAG(V, MTE2, pingFlag); + AscendC::DataCopy(inputTensor_[calcTmpOffset], einSumOutGm_[calcStartOffset], + {1, headTailDataBlock, 0, 0}); + SET_FLAG(MTE2, V, pingFlag); + WAIT_FLAG(MTE2, V, pingFlag); + AscendC::Cast(inputFp32_[calcTmpOffset], inputTensor_[calcTmpOffset], AscendC::RoundMode::CAST_NONE, + headTail * colNum); + AscendC::PipeBarrier(); + SET_FLAG(V, MTE2, pingFlag); + // quant calc + for (uint32_t colIdx = 0; colIdx < colLoop; colIdx++) { + colOffset = colIdx * CONST_64; + AscendC::Mul(inputFp32_[calcTmpOffset + colOffset], inputFp32_[calcTmpOffset + colOffset], + scaleBrcbFp32_[scaleBrcbOffset], CONST_64, headTail, + {1, 1, 0, calcRepeatStride, calcRepeatStride, 1}); + } + AscendC::PipeBarrier(); + // quant fp32 --> fp16 --> int8 + CastFrom32To16(inputFp32_[calcTmpOffset].template ReinterpretCast(), inputFp32_[calcTmpOffset], + headTail * colNum); + AscendC::PipeBarrier(); + WAIT_FLAG(MTE3, V, pingFlag); // wait last result copy out + CastFromF16ToI8(int8OutTensor_[calcTmpOffset], + inputFp32_[calcTmpOffset].template ReinterpretCast(), quantMin_, + headTail * colNum); + AscendC::PipeBarrier(); + SET_FLAG(V, MTE3, pingFlag); + WAIT_FLAG(V, MTE3, pingFlag); + // int8 CopyOut + AscendC::DataCopy(quantOutGm_[calcStartOffset], int8OutTensor_[calcTmpOffset], + {1, int8TailOutDataBlock, 0, 0}); + SET_FLAG(MTE3, V, pingFlag); + pingFlag = 1 - pingFlag; + } + } + WAIT_FLAG(V, MTE2, EVENT_ID0); + WAIT_FLAG(V, MTE2, EVENT_ID1); + WAIT_FLAG(MTE3, V, EVENT_ID0); + WAIT_FLAG(MTE3, V, EVENT_ID1); + } + +private: + AsdopsBuffer buf; + + AscendC::GlobalTensor einSumOutGm_; + AscendC::GlobalTensor scaleGm_; + AscendC::GlobalTensor quantOutGm_; + + AscendC::LocalTensor scaleBrcbFp32_; + AscendC::LocalTensor inputTensor_; + AscendC::LocalTensor inputFp32_; + AscendC::LocalTensor int8OutTensor_; + + AscendC::DataCopyExtParams scaleCopyParams; + AscendC::DataCopyPadExtParams scalePadParams; + + // data processed by a single core[batchNum, headNum, colNum] + uint32_t batchNum; // The number of batches per kernel processed + uint32_t headNum; + uint32_t colNum; // Number of columns per row + // ub loop + uint32_t ubHeadLoop; // The number of times the UB loops through the head. + uint32_t headPerLoop; // The number of heads processed per UB cycle + uint32_t headTail; // The number of heads last processed + // col loop + uint32_t colLoop; // The number of calculations in the column direction cycle. 
+ uint32_t colTail; // The number of cols last processed + + uint32_t currentIdx; + uint64_t currentCoreStartOffset; + uint32_t inputDataSize; // The size of each carry,bytes + uint32_t inputFp32DataSize; + uint32_t scaleBrcbFp32DataSize; + uint16_t inputDataBlock; // The number of blocks brought in per move,bytes + uint16_t int8OutDataBlcok; + uint16_t headTailDataBlock; + uint16_t int8TailOutDataBlock; + // gm offset + uint64_t inputLoopOffset{0}; + uint64_t batchOffset{0}; + uint64_t calcStartOffset{0}; + // double buffer tmp tensor length + uint32_t scaleBrcbOffset{0}; + uint32_t calcLength{0}; + uint32_t calcTmpOffset{0}; + + half quantMin_{-128}; + uint32_t colOffset{0}; + uint32_t padLen; + uint8_t calcRepeatStride; +}; + +#ifdef __DAV_C220_CUBE__ +struct MatCoord { + uint64_t m{0}; + uint64_t k{0}; + uint64_t n{0}; +}; + +template +class PpMatmulEinSum +{ + using AccumDtype = float; + + template + using CopyGmToCbuf = gm_to_l1; + using LoadCbufToCa = l1_to_l0_a; + using LoadCbufToCb = l1_to_l0_b; + using Mad = mmad; + using CopyCcToGm = l0c_to_gm; + + static constexpr uint32_t L0_PINGPONG_BUFFER_LEN = 16384; + static constexpr uint32_t L1_PINGPONG_BUFFER_LEN = 131072; + static constexpr uint32_t CONST_16 = 16; + static constexpr uint32_t CONST_256 = 256; + +public: + __aicore__ explicit PpMatmulEinSum(){}; + + __aicore__ __force_inline__ void Init(GM_ADDR gmA, GM_ADDR gmB, GM_ADDR gmC, const MlaTilingData &mlaParams) + { +#ifdef __DAV_C220_CUBE__ + batch_size = mlaParams.mm3.numBatch; + m = mlaParams.mm3.m; + k = mlaParams.mm3.k; + n = mlaParams.mm3.n; + m0 = mlaParams.mm3.m0; + k0 = mlaParams.mm3.k0; + n0 = mlaParams.mm3.n0; + tdim.m = mlaParams.mm3.mLoop; + tdim.k = mlaParams.mm3.kLoop; + tdim.n = mlaParams.mm3.nLoop; + core_loop = mlaParams.mm3.coreLoop; + swizzle_cnt = mlaParams.mm3.swizzleCount; + num_core = mlaParams.mm3.blockDim; + core_idx = AscendC::GetBlockIdx(); + ping_flag = 1; + + gm_a.SetGlobalBuffer(reinterpret_cast<__gm__ InDtype *>(gmA)); + gm_b.SetGlobalBuffer(reinterpret_cast<__gm__ InDtype *>(gmB)); + gm_c.SetGlobalBuffer(reinterpret_cast<__gm__ OutDtype *>(gmC)); + + AsdopsBuffer buf; + l1_base_a = buf.GetBuffer(0); + l1_base_b = buf.GetBuffer(RoundUp(m0 * k0 * sizeof(InDtype))); + l0a_base = buf.GetBuffer(0); + l0b_base = buf.GetBuffer(0); +#endif + return; + } + + __aicore__ __force_inline__ void Process() + { +#ifdef __DAV_C220_CUBE__ + if (block_idx >= num_core) { + WaitFlagDev(AIC_MM3_START); + return; + } + using LocalTensor = AscendC::LocalTensor; + + SET_FLAG(MTE1, MTE2, EVENT_ID0); + SET_FLAG(MTE1, MTE2, EVENT_ID1); + SET_FLAG(MTE1, MTE2, EVENT_ID2); + SET_FLAG(MTE1, MTE2, EVENT_ID3); + SET_FLAG(FIX, M, EVENT_ID0); + SET_FLAG(M, MTE1, EVENT_ID0); + SET_FLAG(M, MTE1, EVENT_ID1); + + for (uint64_t loop_idx = core_idx; loop_idx < core_loop; loop_idx += num_core) { + uint64_t batch_idx = loop_idx / tdim.n / tdim.m; + MatCoord tidx{0}; + GetBaseBlockIdx(loop_idx, tidx); + uint64_t offset_c = tidx.m * m0 * batch_size * (n + splitGapC) + batch_idx * (n + splitGapC) + tidx.n * n0; + uint64_t m_actual = (tidx.m == (tdim.m - 1)) ? (m - tidx.m * m0) : m0; + uint64_t n_actual = (tidx.n == (tdim.n - 1)) ? (n - tidx.n * n0) : n0; + uint64_t m_round = RoundUp(m_actual); + uint64_t n_round = RoundUp(n_actual); + uint64_t mn_max = m_round > n_round ? m_round : n_round; + uint64_t k_part_len = L0_PINGPONG_BUFFER_LEN / mn_max / CONST_16 * CONST_16; + uint64_t shuffle_k = en_shuffle_k ? (core_idx % tdim.k) : 0; + + uint64_t k_actual = (shuffle_k == tdim.k - 1) ? 
k - shuffle_k * k0 : k0; + uint64_t k_round = (k_actual + CONST_16 - 1) / CONST_16 * CONST_16; + + LocalTensor l1_buf_a = ping_flag ? l1_base_a : l1_base_a[L1_PINGPONG_BUFFER_LEN]; + LocalTensor l1_buf_b = ping_flag ? l1_base_b : l1_base_b[L1_PINGPONG_BUFFER_LEN]; + event_t event_id = ping_flag ? EVENT_ID0 : EVENT_ID1; + + if (loop_idx == core_idx) { + WaitFlagDev(AIC_MM3_START); + + // Copy A from gm to l1 buffer + uint64_t offset_a = GetOffsetA(batch_idx, tidx.m, shuffle_k); + WAIT_FLAG(MTE1, MTE2, event_id); + CopyTileA(l1_buf_a, gm_a[offset_a], m_actual, m_round, k_actual, k_round); + SET_FLAG(MTE2, MTE1, event_id); + + // Copy B from gm to l1 buffer + uint64_t offset_b = GetOffsetB(batch_idx, shuffle_k, tidx.n); + WAIT_FLAG(MTE1, MTE2, event_id + 2); + CopyTileB(l1_buf_b, gm_b[offset_b], k_actual, k_round, n_actual, n_round); + SET_FLAG(MTE2, MTE1, event_id + 2); + } + + for (tidx.k = 0; tidx.k < tdim.k; ++tidx.k) { + shuffle_k = en_shuffle_k ? (tidx.k + core_idx) % tdim.k : tidx.k; + uint64_t k_actual = (shuffle_k == (tdim.k - 1)) ? (k - shuffle_k * k0) : k0; + uint64_t k_round = (k_actual + CONST_16 - 1) / CONST_16 * CONST_16; + fdim.k = (k_actual + k_part_len - 1) / k_part_len; + + LocalTensor l1_buf_a = ping_flag ? l1_base_a : l1_base_a[L1_PINGPONG_BUFFER_LEN]; + LocalTensor l1_buf_b = ping_flag ? l1_base_b : l1_base_b[L1_PINGPONG_BUFFER_LEN]; + auto event_id = ping_flag ? EVENT_ID0 : EVENT_ID1; + + if (tidx.k < tdim.k - 1) { + uint64_t shuffle_k_next = en_shuffle_k ? (core_idx + tidx.k + 1) % tdim.k : (tidx.k + 1); + uint64_t offset_a_next = GetOffsetA(batch_idx, tidx.m, shuffle_k_next); + uint64_t offset_b_next = GetOffsetB(batch_idx, shuffle_k_next, tidx.n); + + uint64_t k_actual_next = (shuffle_k_next == (tdim.k - 1)) ? (k - shuffle_k_next * k0) : k0; + uint64_t k_round_next = (k_actual_next + CONST_16 - 1) / CONST_16 * CONST_16; + + LocalTensor l1_buf_a_next = (1 - ping_flag) ? l1_base_a : l1_base_a[L1_PINGPONG_BUFFER_LEN]; + LocalTensor l1_buf_b_next = (1 - ping_flag) ? l1_base_b : l1_base_b[L1_PINGPONG_BUFFER_LEN]; + event_t event_id_next = (1 - ping_flag) ? EVENT_ID0 : EVENT_ID1; + + // Preload A from gm to l1 buffer. + WAIT_FLAG(MTE1, MTE2, event_id_next); + CopyTileA(l1_buf_a_next, gm_a[offset_a_next], m_actual, m_round, k_actual_next, k_round_next); + SET_FLAG(MTE2, MTE1, event_id_next); + + // Preload B from gm to l1 buffer. + WAIT_FLAG(MTE1, MTE2, event_id_next + 2); + CopyTileB(l1_buf_b_next, gm_b[offset_b_next], k_actual_next, k_round_next, n_actual, n_round); + SET_FLAG(MTE2, MTE1, event_id_next + 2); + } + + if (tidx.k == tdim.k - 1 && loop_idx + num_core < core_loop) { + uint64_t b_idx_next = (loop_idx + num_core) / tdim.n / tdim.m; + MatCoord tidx{0}; + GetBaseBlockIdx(loop_idx + num_core, tidx); + uint64_t shuffle_k_next = en_shuffle_k ? (core_idx % tdim.k) : 0; + uint64_t m_actual_next = (tidx.m == (tdim.m - 1)) ? (m - tidx.m * m0) : m0; + uint64_t n_actual_next = (tidx.n == (tdim.n - 1)) ? (n - tidx.n * n0) : n0; + uint64_t m_round_next = (m_actual_next + CONST_16 - 1) / CONST_16 * CONST_16; + uint64_t n_round_next = (n_actual_next + CONST_16 - 1) / CONST_16 * CONST_16; + uint64_t k_actual_next = (shuffle_k_next == (tdim.k - 1)) ? (k - shuffle_k_next * k0) : k0; + uint64_t k_round_next = (k_actual_next + CONST_16 - 1) / CONST_16 * CONST_16; + uint64_t offset_a_next = GetOffsetA(b_idx_next, tidx.m, shuffle_k_next); + uint64_t offset_b_next = GetOffsetB(b_idx_next, shuffle_k_next, tidx.n); + + LocalTensor l1_buf_a_next = (1 - ping_flag) ? 
l1_base_a : l1_base_a[L1_PINGPONG_BUFFER_LEN]; + LocalTensor l1_buf_b_next = (1 - ping_flag) ? l1_base_b : l1_base_b[L1_PINGPONG_BUFFER_LEN]; + event_t event_id_next = (1 - ping_flag) ? EVENT_ID0 : EVENT_ID1; + + // Preload A from gm to l1 buffer. + WAIT_FLAG(MTE1, MTE2, event_id_next); + CopyTileA(l1_buf_a_next, gm_a[offset_a_next], m_actual_next, m_round_next, k_actual_next, + k_round_next); + SET_FLAG(MTE2, MTE1, event_id_next); + + // Preload B from gm to l1 buffer. + WAIT_FLAG(MTE1, MTE2, event_id_next + 2); + CopyTileB(l1_buf_b_next, gm_b[offset_b_next], k_actual_next, k_round_next, n_actual_next, + n_round_next); + SET_FLAG(MTE2, MTE1, event_id_next + 2); + } + + MatCoord fidx{0}; + for (fidx.k = 0; fidx.k < fdim.k; ++fidx.k) { + uint32_t k0_round = (fidx.k < fdim.k - 1) ? k_part_len : k_round - fidx.k * k_part_len; + uint32_t k0_actual = (fidx.k < fdim.k - 1) ? k_part_len : k_actual - fidx.k * k_part_len; + + auto mte1_mad_ping_flag = 1 - fidx.k % 2; + auto mte1_mad_event_id = mte1_mad_ping_flag ? EVENT_ID0 : EVENT_ID1; + LocalTensor l0a_buf = l0a_base[(fidx.k & 0b1) * L0_PINGPONG_BUFFER_LEN]; + LocalTensor l0b_buf = l0b_base[(fidx.k & 0b1) * L0_PINGPONG_BUFFER_LEN]; + + // *** load matrix A from L1 to L0A + if (fidx.k == 0) { + WAIT_FLAG(MTE2, MTE1, event_id); + } + WAIT_FLAG(M, MTE1, mte1_mad_event_id); + if ((m == 1) || (m_actual == 1)) { + l1_to_l0_a( + l0a_buf, // dst + l1_buf_a[fidx.k * k_part_len], // src + 0, // mTileCeil + CeilDiv(k0_round), // kPartCeil + 0, // mSrcStride + 1, // kSrcStride + 0, // mDstStride + 0); // kDstStride + } else { + LoadCbufToCa(l0a_buf, // l0Tensor + l1_buf_a[fidx.k * k_part_len * m_round], // l1Tensor + m_round, // mTileCeil + k0_round, // kPartCeil + 1, // mSrcStride + m_round / CONST_16, // kSrcStride + k0_round / CONST_16, // mDstStride + 1); // kDstStride + } + if (fidx.k == fdim.k - 1) { + SET_FLAG(MTE1, MTE2, event_id); + } + + // *** load matrix B from L1 to L0B + if (fidx.k == 0) { + WAIT_FLAG(MTE2, MTE1, event_id + 2); + } + if constexpr (transB) { + LoadCbufToCb(l0b_buf, // l0Tensor + l1_buf_b[fidx.k * k_part_len * n_round], // l1Tensor + n_round, // nTileCeil + k0_round, // kPartCeil + 1, // nSrcStride + n_round / CONST_16, // kSrcStride + 1, // nDstStride + k0_round / CONST_16); // kDstStride + } else { + LoadCbufToCb(l0b_buf, // l0Tensor + l1_buf_b[fidx.k * k_part_len * CONST_16], // l1Tensor + n_round, // nTileCeil + k0_round, // kPartCeil + k_round / CONST_16, // nSrcStride + 1, // kSrcStride + 1, // nDstStride + n_round / CONST_16); // kDstStride + } + if (fidx.k == fdim.k - 1) { + SET_FLAG(MTE1, MTE2, event_id + 2); + } + + SET_FLAG(MTE1, M, mte1_mad_event_id); + WAIT_FLAG(MTE1, M, mte1_mad_event_id); + + bool init_c = (tidx.k == 0 && fidx.k == 0); + if (init_c) { + WAIT_FLAG(FIX, M, EVENT_ID0); + } + + Mad(l0c_buf, // c + l0a_buf, // a + l0b_buf, // b + m_actual, // mTileActual + n_actual, // nTileActual + k0_actual, // kTileActual + init_c); // initC + + PIPE_BARRIER(M); + SET_FLAG(M, MTE1, mte1_mad_event_id); + } + + ping_flag = 1 - ping_flag; + } + + SET_FLAG(M, FIX, EVENT_ID0); + WAIT_FLAG(M, FIX, EVENT_ID0); + + // copy from L0C to gm + CopyCcToGm(gm_c[offset_c], // dst + l0c_buf, // src + m_actual, // mTileActual + n_actual, // nTileActual + m_round, // mTileCeil + (n + splitGapC) * batch_size); // nActual + SET_FLAG(FIX, M, EVENT_ID0); + } + + WAIT_FLAG(M, MTE1, EVENT_ID0); + WAIT_FLAG(M, MTE1, EVENT_ID1); + WAIT_FLAG(MTE1, MTE2, EVENT_ID0); + WAIT_FLAG(MTE1, MTE2, EVENT_ID1); + WAIT_FLAG(MTE1, MTE2, EVENT_ID2); + 
WAIT_FLAG(MTE1, MTE2, EVENT_ID3); + WAIT_FLAG(FIX, M, EVENT_ID0); +#endif + } + +private: + __aicore__ __force_inline__ void GetBaseBlockIdx(uint64_t index, MatCoord &tidx) + { + uint64_t in_batch_idx = index % (tdim.m * tdim.n); + if constexpr (swizzleDirect == 0) { // Zn + uint64_t tile_block_loop = (tdim.m + swizzle_cnt - 1) / swizzle_cnt; + uint64_t tile_block_idx = in_batch_idx / (swizzle_cnt * tdim.n); + uint64_t in_tile_block_idx = in_batch_idx % (swizzle_cnt * tdim.n); + + uint64_t n_row = swizzle_cnt; + if (tile_block_idx == tile_block_loop - 1) { + n_row = tdim.m - swizzle_cnt * tile_block_idx; + } + tidx.m = tile_block_idx * swizzle_cnt + in_tile_block_idx % n_row; + tidx.n = in_tile_block_idx / n_row; + if (tile_block_idx % 2 != 0) { + tidx.n = tdim.n - tidx.n - 1; + } + } else if constexpr (swizzleDirect == 1) { // Nz + uint64_t tile_block_loop = (tdim.n + swizzle_cnt - 1) / swizzle_cnt; + uint64_t tile_block_idx = in_batch_idx / (swizzle_cnt * tdim.m); + uint64_t in_tile_block_idx = in_batch_idx % (swizzle_cnt * tdim.m); + + uint64_t n_col = swizzle_cnt; + if (tile_block_idx == tile_block_loop - 1) { + n_col = tdim.n - swizzle_cnt * tile_block_idx; + } + tidx.m = in_tile_block_idx / n_col; + tidx.n = tile_block_idx * swizzle_cnt + in_tile_block_idx % n_col; + if (tile_block_idx % 2 != 0) { + tidx.m = tdim.m - tidx.m - 1; + } + } + return; + } + + __aicore__ __force_inline__ uint64_t GetOffsetA(const uint64_t bIdx, const uint64_t mIdx, const uint64_t kIdx) + { + return mIdx * m0 * batch_size * (k + splitGapA) + bIdx * (k + splitGapA) + kIdx * k0; + } + + __aicore__ __force_inline__ uint64_t GetOffsetB(const uint64_t bIdx, const uint64_t kIdx, const uint64_t nIdx) + { + if constexpr (formatB == DataFormat::ND) { + if constexpr (transB) { + return bIdx * k * n + nIdx * n0 * k + kIdx * k0; + } else { + return bIdx * k * n + kIdx * k0 * n + nIdx * n0; + } + } else { + if constexpr (transB) { + return bIdx * RoundUp(n) * RoundUp(k) + kIdx * k0 * RoundUp(n) + + nIdx * n0 * CONST_16; + } else { + return bIdx * RoundUp(k) * RoundUp(n) + nIdx * n0 * RoundUp(k) + + kIdx * k0 * CONST_16; + } + } + } + + __aicore__ __force_inline__ void CopyTileA(AscendC::LocalTensor &dstTensor, + const AscendC::GlobalTensor &srcTensor, const uint64_t m_actual, + const uint64_t m_round, const uint64_t k_actual, const uint64_t k_round) + { + if ((m == 1) || (m_actual == 1)) { + CopyGmToCbuf(dstTensor, // dst + srcTensor, // src + 1, // nTileActual + CONST_16, // nTileCeil + 1, // nVal + k_actual, // kTileActual + k_round, // kTileCeil + k); // dVal + } else { + CopyGmToCbuf(dstTensor, // dst + srcTensor, // src + m_actual, // nTileActual + m_round, // nTileCeil + m, // nVal + k_actual, // dTileActual + k_round, // dTileCeil + (k + splitGapA) * batch_size); // dVal + } + } + + __aicore__ __force_inline__ void CopyTileB(AscendC::LocalTensor &dstTensor, + const AscendC::GlobalTensor &srcTensor, const uint64_t k_actual, + const uint64_t k_round, const uint64_t n_actual, const uint64_t n_round) + { + if constexpr (formatB == DataFormat::ND) { + if constexpr (transB) { + CopyGmToCbuf(dstTensor, // dst + srcTensor, // src + n_actual, // nTileActual + n_round, // nTileCeil + n, // nVal + k_actual, // dTileActual + k_round, // dTileCeil + k); // dVal + } else { + CopyGmToCbuf(dstTensor, // dst + srcTensor, // src + k_actual, // nTileActual + k_round, // nTileCeil + k, // nVal + n_actual, // dTileActual + n_round, // dTileCeil + n); // dVal + } + } else { + if constexpr (transB) { + CopyGmToCbuf(dstTensor, // dst + 
srcTensor, // src + n_actual, // nTileActual + n_round, // nTileCeil + RoundUp(n), // nVal + k_actual, // dTileActual + k_round, // dTileCeil + RoundUp(k)); // dVal + } else { + CopyGmToCbuf(dstTensor, // dst + srcTensor, // src + k_actual, // nTileActual + k_round, // nTileCeil + RoundUp(k), // nVal + n_actual, // dTileActual + n_round, // dTileCeil + RoundUp(n)); // dVal + } + } + } + +private: + AscendC::GlobalTensor gm_a; + AscendC::GlobalTensor gm_b; + AscendC::GlobalTensor gm_c; + AscendC::LocalTensor l1_base_a; + AscendC::LocalTensor l1_base_b; + AscendC::LocalTensor l0a_base; + AscendC::LocalTensor l0b_base; + AscendC::LocalTensor l0c_buf; + + uint32_t num_core{0}; + uint32_t batch_size{0}; + uint32_t m{0}; + uint32_t k{0}; + uint32_t n{0}; + uint32_t m0{0}; + uint32_t k0{0}; + uint32_t n0{0}; + MatCoord tdim{0}; + MatCoord fdim{0}; + uint32_t core_loop{0}; + uint32_t swizzle_cnt{1}; + uint32_t core_idx{0}; + uint32_t en_shuffle_k{0}; + uint32_t ping_flag{0}; +}; + +template +class PpMatmulW8a8Aic +{ + using InDtype = int8_t; + using OutDtype = int32_t; + using AccumDtype = int32_t; + + template + using CopyGmToCbuf = gm_to_l1; + using LoadCbufToCa = l1_to_l0_a; + using LoadCbufToCb = l1_to_l0_b; + using Mmad = mmad; + using CopyCcToGm = l0c_to_gm; + + static constexpr uint64_t L0_PINGPONG_BUFFER_LEN = 32768; + static constexpr uint64_t L1_PINGPONG_BUFFER_LEN = 262144; + static constexpr uint64_t BLOCK_SIZE_16 = 16; + static constexpr uint64_t BLOCK_SIZE_32 = 32; + static constexpr uint64_t CUBE_MATRIX_SIZE_512 = 512; + static constexpr uint64_t CONST_4 = 4; + static constexpr uint64_t CONST_8 = 8; + static constexpr uint64_t CONST_32 = 32; + static constexpr uint64_t CONST_64 = 64; + static constexpr uint64_t CONST_128 = 128; + +public: + __aicore__ PpMatmulW8a8Aic() {}; + + __aicore__ __force_inline__ void Init(GM_ADDR gmA, GM_ADDR gmB, GM_ADDR gmC, PpMatmulTilingData &tilingdata, + uint32_t mode) + { + gm_a.SetGlobalBuffer(reinterpret_cast<__gm__ InDtype *>(gmA)); + gm_b.SetGlobalBuffer(reinterpret_cast<__gm__ InDtype *>(gmB)); + gm_c.SetGlobalBuffer(reinterpret_cast<__gm__ OutDtype *>(gmC)); + + batch_size = tilingdata.numBatch; + m = tilingdata.m; + k = tilingdata.k; + n = tilingdata.n; + m0 = tilingdata.m0; + k0 = tilingdata.k0; + n0 = tilingdata.n0; + m_loop = tilingdata.mLoop; + k_loop = tilingdata.kLoop; + n_loop = tilingdata.nLoop; + core_loop = tilingdata.coreLoop; + swizzle_cnt = tilingdata.swizzleCount; + en_shuffle_k = tilingdata.enShuffleK; + core_num = tilingdata.blockDim; + load_all_Amat_flag = tilingdata.enLoadAllAmat; + b0mat_pingpong_buffer_len = tilingdata.b0matPingPongBufferLen; + + core_idx = AscendC::GetBlockIdx(); + ping_flag = 1; + MM1_MM2_mode = mode; // MM1 or MM2 + + InitBuffer(); + return; + } + + __aicore__ __force_inline__ uint64_t GetOffsetA(const uint64_t batchIdx, const uint64_t mIdx, uint64_t kIdx) + { + return batchIdx * m * k + mIdx * m0 * k + kIdx * k0; + } + + __aicore__ __force_inline__ uint64_t GetOffsetB(const uint64_t batchIdx, const uint64_t kIdx, uint64_t nIdx) + { + if constexpr (formatB == DataFormat::ND) { + return batchIdx * k * n + nIdx * n0 * k + kIdx * k0; + } else { + return batchIdx * RoundUp<16>(n) * RoundUp<32>(k) + kIdx * k0 * RoundUp<16>(n) + nIdx * n0 * CONST_32; + } + } + + __aicore__ __force_inline__ void CopyTileA(AscendC::LocalTensor &dstTensor, + const AscendC::GlobalTensor &srcTensor, const uint64_t m_actual, + const uint64_t m_round, const uint64_t k_actual, const uint64_t k_round) + { + if ((m == 1) || (m_actual == 
1)) { + CopyGmToCbuf(dstTensor, // dst + srcTensor, // src + 1, BLOCK_SIZE_16, 1, k_actual, k_round, k); + } else { + CopyGmToCbuf(dstTensor, // dst + srcTensor, // src + m_actual, // nTileActual + m_round, // nTileCeil + n, // nVal + k_actual, // dTileActual + k_round, // dTileCeil + k); // dVal + } + } + + __aicore__ __force_inline__ void CopyTileB(const AscendC::LocalTensor &dstTensor, + const AscendC::GlobalTensor &srcTensor, const uint64_t k_actual, + const uint64_t k_round, const uint64_t n_actual, const uint64_t n_round) + { + if constexpr (formatB == DataFormat::ND) { + CopyGmToCbuf(dstTensor, // dst + srcTensor, // src + n_actual, // nTileActual + n_round, // nTileCeil + n, // nVal + k_actual, // dTileActual + k_round, // dTileCeil + k); // dVal + } else { + CopyGmToCbuf(dstTensor, // dst + srcTensor, // src + n_actual, // nTileActual + n_round, // nTileCeil + RoundUp<16>(n), // nVal + k_actual, // dTileActual + k_round, // dTileCeil + RoundUp<32>(k)); // dVal + } + } + + __aicore__ __force_inline__ void PreloadWeight() + { + if (core_idx < core_num) { + uint64_t m_idx = 0; + uint64_t n_idx = 0; + GetBaseBlockIdx(core_idx, m_idx, n_idx); + uint64_t shuffle_k = en_shuffle_k ? core_idx % k_loop : 0; + uint64_t offset_b = GetOffsetB(0, shuffle_k, n_idx); + uint64_t k_actual = (shuffle_k == k_loop - 1) ? k - shuffle_k * k0 : k0; + uint64_t k_round = RoundUp(k_actual); + uint64_t n_actual = (n_idx == (n_loop - 1)) ? (n - n_idx * n0) : n0; + uint64_t n_round = RoundUp(n_actual); + CopyTileB(l1_base_b, gm_b[offset_b], k_actual, k_round, n_actual, n_round); + } + if (core_idx < core_num && k_loop > 1) { + uint64_t m_idx = 0; + uint64_t n_idx = 0; + GetBaseBlockIdx(core_idx, m_idx, n_idx); + uint64_t shuffle_k = en_shuffle_k ? (core_idx + 1) % k_loop : 1; + uint64_t offset_b = GetOffsetB(0, shuffle_k, n_idx); + uint64_t k_actual = (shuffle_k == k_loop - 1) ? k - shuffle_k * k0 : k0; + uint64_t k_round = RoundUp(k_actual); + uint64_t n_actual = (n_idx == (n_loop - 1)) ? 
(n - n_idx * n0) : n0; + uint64_t n_round = RoundUp(n_actual); + CopyTileB(l1_base_b[b0mat_pingpong_buffer_len], gm_b[offset_b], k_actual, k_round, n_actual, n_round); + } + } + + __aicore__ __force_inline__ void Process(); + +private: + __aicore__ __force_inline__ void InitBuffer() + { + AsdopsBuffer buf; + l1_base_a = buf.template GetBuffer(0); + + // try load all A matrix + uint32_t a_l1_size = RoundUp(m) * RoundUp(k); + if (!load_all_Amat_flag) { + a_l1_size = RoundUp(m0 * k0); + } + + l1_base_b = l1_base_a[a_l1_size]; + l0a_base = buf.template GetBuffer(0); + l0b_base = buf.template GetBuffer(0); + l0c_buf = buf.template GetBuffer(0); + } + + __aicore__ __force_inline__ void GetBaseBlockIdx(uint64_t index, uint64_t &m_idx, uint64_t &n_idx) + { + uint64_t in_batch_idx = index % (m_loop * n_loop); + if constexpr (swizzleDir == 0) { // Zn + uint64_t tile_block_loop = (m_loop + swizzle_cnt - 1) / swizzle_cnt; + uint64_t tile_block_idx = in_batch_idx / (swizzle_cnt * n_loop); + uint64_t in_tile_block_idx = in_batch_idx % (swizzle_cnt * n_loop); + + uint64_t n_row = swizzle_cnt; + if (tile_block_idx == tile_block_loop - 1) { + n_row = m_loop - swizzle_cnt * tile_block_idx; + } + m_idx = tile_block_idx * swizzle_cnt + in_tile_block_idx % n_row; + n_idx = in_tile_block_idx / n_row; + if ((tile_block_idx & 0b1) != 0) { + n_idx = n_loop - n_idx - 1; + } + } else { // Nz + uint64_t tile_block_loop = (n_loop + swizzle_cnt - 1) / swizzle_cnt; + uint64_t tile_block_idx = in_batch_idx / (swizzle_cnt * m_loop); + uint64_t in_tile_block_idx = in_batch_idx % (swizzle_cnt * m_loop); + + uint64_t n_col = swizzle_cnt; + if (tile_block_idx == tile_block_loop - 1) { + n_col = n_loop - swizzle_cnt * tile_block_idx; + } + m_idx = in_tile_block_idx / n_col; + n_idx = tile_block_idx * swizzle_cnt + in_tile_block_idx % n_col; + if ((tile_block_idx & 0b1) != 0) { + m_idx = m_loop - m_idx - 1; + } + } + return; + } + +private: + AscendC::GlobalTensor gm_a; + AscendC::GlobalTensor gm_b; + AscendC::GlobalTensor gm_c; + + AscendC::LocalTensor l1_base_a; + AscendC::LocalTensor l1_base_b; + AscendC::LocalTensor l0a_base; + AscendC::LocalTensor l0b_base; + AscendC::LocalTensor l0c_buf; + + uint64_t bias_bt{0}; + uint32_t core_num{0}; + uint32_t batch_size{0}; + uint32_t m{0}; + uint32_t k{0}; + uint32_t n{0}; + uint32_t m0{0}; + uint32_t k0{0}; + uint32_t n0{0}; + uint32_t m_loop{0}; + uint32_t n_loop{0}; + uint32_t k_loop{0}; + uint32_t core_loop{0}; + uint32_t core_idx{0}; + uint32_t ping_flag{0}; + uint32_t swizzle_cnt{1}; + uint32_t en_shuffle_k{0}; + uint32_t MM1_MM2_mode{0}; + uint64_t b0mat_pingpong_buffer_len{0}; + bool load_all_Amat_flag{false}; +}; + +template +__aicore__ __force_inline__ void PpMatmulW8a8Aic::Process() +{ + using LocalTensor = AscendC::LocalTensor; + if (core_idx >= core_num) { + if (MM1_MM2_mode == 0) { + WaitFlagDev(AIC_MM1_START); + } else if (MM1_MM2_mode == 1) { + WaitFlagDev(AIC_MM2_START); + } + return; + } + SET_FLAG(MTE1, MTE2, EVENT_ID0); + SET_FLAG(MTE1, MTE2, EVENT_ID1); + SET_FLAG(MTE1, MTE2, EVENT_ID2); + SET_FLAG(MTE1, MTE2, EVENT_ID3); + SET_FLAG(M, MTE1, EVENT_ID0); + SET_FLAG(M, MTE1, EVENT_ID1); + SET_FLAG(FIX, M, EVENT_ID0); + SET_FLAG(FIX, MTE2, EVENT_ID0); + SET_FLAG(MTE1, MTE2, EVENT_ID7); + for (uint64_t loop_idx = core_idx; loop_idx < core_loop; loop_idx += core_num) { + uint64_t batch_idx = loop_idx / n_loop / m_loop; + uint64_t m_idx = 0; + uint64_t n_idx = 0; + GetBaseBlockIdx(loop_idx, m_idx, n_idx); + uint64_t offset_a; + uint64_t offset_b; + uint64_t 
offset_bias; + uint64_t offset_a_next; + uint64_t offset_b_next; + uint64_t offset_c = batch_idx * m * n + m_idx * m0 * n + n_idx * n0; + uint64_t m_actual = (m_idx == (m_loop - 1)) ? (m - m_idx * m0) : m0; + uint64_t n_actual = (n_idx == (n_loop - 1)) ? (n - n_idx * n0) : n0; + uint64_t m_round = 0; + uint64_t n_round = 0; + uint64_t shuffle_k = en_shuffle_k ? core_idx % k_loop : 0; + uint64_t m_round_16 = RoundUp(m_actual); + uint64_t m_round_32 = RoundUp(m_actual); + m_round = m_round_16; + n_round = RoundUp(n_actual); + + uint64_t mn_max = m_round > n_round ? m_round : n_round; + uint64_t k_part_len = 0; + k_part_len = L0_PINGPONG_BUFFER_LEN / mn_max / BLOCK_SIZE_32 * BLOCK_SIZE_32; + + offset_b = GetOffsetB(batch_idx, shuffle_k, n_idx); + offset_bias = batch_idx * n + n_idx * n0; + + uint64_t k_actual = (shuffle_k == k_loop - 1) ? k - shuffle_k * k0 : k0; + uint64_t k_round = RoundUp(k_actual); + + auto event_id = ping_flag ? EVENT_ID0 : EVENT_ID1; + + // Wait after Scalar + if (loop_idx == core_idx) { + if (MM1_MM2_mode == 0) { + WaitFlagDev(AIC_MM1_START); + } else if (MM1_MM2_mode == 1) { + WaitFlagDev(AIC_MM2_START); + } + } + + WAIT_FLAG(MTE1, MTE2, event_id); + LocalTensor l1_buf_a = + load_all_Amat_flag ? l1_base_a : (ping_flag ? l1_base_a : l1_base_a[L1_PINGPONG_BUFFER_LEN]); + LocalTensor l1_buf_b = ping_flag ? l1_base_b : l1_base_b[b0mat_pingpong_buffer_len]; + if (load_all_Amat_flag) { + if (loop_idx == core_idx) { + offset_a = GetOffsetA(batch_idx, m_idx, 0); + uint64_t k_actual_first = k; + uint64_t k_round_first = RoundUp(k_actual_first); + CopyTileA(l1_buf_a, gm_a[offset_a], m_actual, m_round, k_actual_first, k_round_first); + } + } else { + offset_a = GetOffsetA(batch_idx, m_idx, shuffle_k); + CopyTileA(l1_buf_a, gm_a[offset_a], m_actual, m_round, k_actual, k_round); + } + SET_FLAG(MTE2, MTE1, event_id); + + WAIT_FLAG(MTE1, MTE2, event_id + CONST_2); + // The first weight matrix block is loaded in advance. + if (loop_idx != core_idx) { + CopyTileB(l1_buf_b, gm_b[offset_b], k_actual, k_round, n_actual, n_round); + } + SET_FLAG(MTE2, MTE1, event_id + CONST_2); + + for (uint64_t k_idx = 0; k_idx < k_loop; k_idx++) { + shuffle_k = en_shuffle_k ? (k_idx + core_idx) % k_loop : k_idx; + uint32_t k_actual = (shuffle_k == (k_loop - 1)) ? (k - shuffle_k * k0) : k0; + uint32_t k_round = RoundUp(k_actual); + uint32_t k_part_loop = (k_actual + k_part_len - 1) / k_part_len; + + // --------- load whole A in l1a addr change ------------- + LocalTensor l1_buf_a = load_all_Amat_flag ? (l1_base_a[k_idx * m0 * k0 * sizeof(int8_t)]) + : (ping_flag ? l1_base_a : l1_base_a[L1_PINGPONG_BUFFER_LEN]); + LocalTensor l1_buf_b = ping_flag ? l1_base_b : l1_base_b[b0mat_pingpong_buffer_len]; + auto event_id = ping_flag ? EVENT_ID0 : EVENT_ID1; + + if (k_idx < k_loop - 1) { + uint64_t shuffle_k_next = en_shuffle_k ? (core_idx + k_idx + 1) % k_loop : k_idx + 1; + + offset_b_next = GetOffsetB(batch_idx, shuffle_k_next, n_idx); + uint32_t k_actual_next = (shuffle_k_next == (k_loop - 1)) ? (k - shuffle_k_next * k0) : k0; + uint32_t k_round_next = RoundUp(k_actual_next); + + LocalTensor l1_buf_a_next = + load_all_Amat_flag ? l1_base_a : ((1 - ping_flag) ? l1_base_a : l1_base_a[L1_PINGPONG_BUFFER_LEN]); + LocalTensor l1_buf_b_next = (1 - ping_flag) ? l1_base_b : l1_base_b[b0mat_pingpong_buffer_len]; + auto event_id_next = (1 - ping_flag) ? 
EVENT_ID0 : EVENT_ID1; + + WAIT_FLAG(MTE1, MTE2, event_id_next); + if (!load_all_Amat_flag) { + offset_a_next = GetOffsetA(batch_idx, m_idx, shuffle_k_next); + CopyTileA(l1_buf_a_next, gm_a[offset_a_next], m_actual, m_round, k_actual_next, k_round_next); + } + SET_FLAG(MTE2, MTE1, event_id_next); + + WAIT_FLAG(MTE1, MTE2, event_id_next + CONST_2); + // The second weight matrix is preloaded. + if (loop_idx != core_idx || k_idx != 0) { + CopyTileB(l1_buf_b_next, gm_b[offset_b_next], k_actual_next, k_round_next, n_actual, n_round); + } + SET_FLAG(MTE2, MTE1, event_id_next + CONST_2); + } + + for (int k_part_idx = 0; k_part_idx < k_part_loop; k_part_idx++) { + uint32_t k0_round = (k_part_idx < k_part_loop - 1) ? k_part_len : k_round - k_part_idx * k_part_len; + uint32_t k0_actual = (k_part_idx < k_part_loop - 1) ? k_part_len : k_actual - k_part_idx * k_part_len; + + auto mte1_mad_ping_flag = 1 - k_part_idx % 2; + auto mte1_mad_event_id = mte1_mad_ping_flag ? EVENT_ID0 : EVENT_ID1; + AscendC::LocalTensor l0a_buf = l0a_base[(k_part_idx % 2) * L0_PINGPONG_BUFFER_LEN]; + AscendC::LocalTensor l0b_buf = l0b_base[(k_part_idx % 2) * L0_PINGPONG_BUFFER_LEN]; + + // *** load matrix A from L1 to L0A + if (k_part_idx == 0) { + WAIT_FLAG(MTE2, MTE1, event_id); + } + WAIT_FLAG(M, MTE1, mte1_mad_event_id); + if ((m == 1) || (m_actual == 1)) { + l1_to_l0_a( + l0a_buf, l1_buf_a[k_part_idx * k_part_len], + 0, // mTileCeil + CeilDiv(k0_round), // kPartCeil + 0, // mSrcStride + 1, // kSrcStride + 0, // mDstStride + 0); // kDstStride + } else { + LoadCbufToCa(l0a_buf, // l0Tensor + l1_buf_a[k_part_idx * k_part_len * m_round], // l1Tensor + m_round, // mTileCeil + k0_round, // kPartCeil + 1, // mSrcStride + m_round / BLOCK_SIZE_16, // kSrcStride + k0_round / BLOCK_SIZE_32, // mDstStride + 1); // kDstStride + } + if (k_part_idx == k_part_loop - 1) { + SET_FLAG(MTE1, MTE2, event_id); + } + + // *** load matrix B from L1 to L0B + if (k_part_idx == 0) { + WAIT_FLAG(MTE2, MTE1, event_id + CONST_2); + } + LoadCbufToCb(l0b_buf, // l0Tensor + l1_buf_b[k_part_idx * k_part_len * n_round], // l1Tensor + n_round, // nTileCeil + k0_round, // kPartCeil + 1, // nSrcStride + n_round / BLOCK_SIZE_16, // kSrcStride + 1, // nDstStride + k0_round / BLOCK_SIZE_32); // kDstStride + if (k_part_idx == k_part_loop - 1) { + SET_FLAG(MTE1, MTE2, event_id + CONST_2); + } + + SET_FLAG(MTE1, M, mte1_mad_event_id); + WAIT_FLAG(MTE1, M, mte1_mad_event_id); + + bool init_c = (k_idx == 0 && k_part_idx == 0); + if (init_c) { + WAIT_FLAG(FIX, M, EVENT_ID0); + } + Mmad(l0c_buf, l0a_buf, l0b_buf, + m_actual, // m + n_actual, // n + k0_actual, // k + init_c); // cmatrixInitVal + PIPE_BARRIER(M); + SET_FLAG(M, MTE1, mte1_mad_event_id); + } + + ping_flag = 1 - ping_flag; + } + SET_FLAG(M, FIX, EVENT_ID0); + WAIT_FLAG(M, FIX, EVENT_ID0); + // copy from L0C to gm + CopyCcToGm(gm_c[offset_c], // dst + l0c_buf, // src + m_actual, // MSize + n_actual, // NSize + m_round_16, // srcStride + n); // dstStride_dst_D + SET_FLAG(FIX, M, EVENT_ID0); + if constexpr (!withSyncAll) { + FftsCrossCoreSync(MMAIC); + if ((loop_idx / core_num + 1) % MAX_HW_SYNC_COUNTER == 0) { + WaitFlagDev(MMAIV); + } + } + } + + WAIT_FLAG(MTE1, MTE2, EVENT_ID0); + WAIT_FLAG(MTE1, MTE2, EVENT_ID1); + WAIT_FLAG(MTE1, MTE2, EVENT_ID2); + WAIT_FLAG(MTE1, MTE2, EVENT_ID3); + WAIT_FLAG(M, MTE1, EVENT_ID0); + WAIT_FLAG(M, MTE1, EVENT_ID1); + WAIT_FLAG(FIX, M, EVENT_ID0); + WAIT_FLAG(FIX, MTE2, EVENT_ID0); + WAIT_FLAG(MTE1, MTE2, EVENT_ID7); +} + +#endif + +#if defined(__DAV_C220_VEC__) + +template 
+class PpMatmulW8a8Aiv +{ + using InDtype = int32_t; + using ScaleDtype = float; + using BiasDtype = int32_t; + +public: + __aicore__ PpMatmulW8a8Aiv() {}; + + __aicore__ __force_inline__ void Init(GM_ADDR gmInput, GM_ADDR gmOutput, GM_ADDR gmDescale, GM_ADDR gmPerTensorBias, + GM_ADDR gmPertokenDescale, const PpMatmulTilingData &gmTilingData) + { + gmInput_.SetGlobalBuffer(reinterpret_cast<__gm__ InDtype *>(gmInput)); + gmOutput_.SetGlobalBuffer(reinterpret_cast<__gm__ OutDtype *>(gmOutput)); + gmPerTensorScale_.SetGlobalBuffer(reinterpret_cast<__gm__ ScaleDtype *>(gmDescale)); + gmPerTensorBias_.SetGlobalBuffer(reinterpret_cast<__gm__ BiasDtype *>(gmPerTensorBias)); + gmPerTokenScale_.SetGlobalBuffer(reinterpret_cast<__gm__ ScaleDtype *>(gmPertokenDescale)); + + batch_size = gmTilingData.numBatch; + m = gmTilingData.m; + k = gmTilingData.k; + n = gmTilingData.n; + m0 = gmTilingData.m0; + k0 = gmTilingData.k0; + n0 = gmTilingData.n0; + m_loop = gmTilingData.mLoop; + k_loop = gmTilingData.kLoop; + n_loop = gmTilingData.nLoop; + core_loop = gmTilingData.coreLoop; + swizzle_cnt = gmTilingData.swizzleCount; + swizzlDirect = gmTilingData.swizzleDirect; + en_shuffle_k = gmTilingData.enShuffleK; + + AsdopsBuffer buf; + ubInput_ = buf.GetBuffer(0); + ubTempFp32_ = buf.GetBuffer(94 * 1024); + ubOutput_ = buf.GetBuffer(0); + ubPerTensorScale_ = buf.GetBuffer(188 * 1024); + block_size = BLOCK_SIZE_32; + core_num = AscendC::GetBlockNum(); + core_idx = AscendC::GetBlockIdx() / 2; + ping_flag = 1; + } + + __aicore__ __force_inline__ void GetBlockIdx(uint32_t index, uint32_t &m_idx, uint32_t &n_idx) + { + uint32_t in_batch_idx = index % (m_loop * n_loop); + if (swizzlDirect == 0) { // Zn + uint32_t tile_block_loop = (m_loop + swizzle_cnt - 1) / swizzle_cnt; + uint32_t tile_block_idx = in_batch_idx / (swizzle_cnt * n_loop); + uint32_t in_tile_block_idx = in_batch_idx % (swizzle_cnt * n_loop); + + uint32_t n_row = swizzle_cnt; + if (tile_block_idx == tile_block_loop - 1) { + n_row = m_loop - swizzle_cnt * tile_block_idx; + } + m_idx = tile_block_idx * swizzle_cnt + in_tile_block_idx % n_row; + n_idx = in_tile_block_idx / n_row; + if (tile_block_idx % 2 != 0) { + n_idx = n_loop - n_idx - 1; + } + } else { // Nz + uint32_t tile_block_loop = (n_loop + swizzle_cnt - 1) / swizzle_cnt; + uint32_t tile_block_idx = in_batch_idx / (swizzle_cnt * m_loop); + uint32_t in_tile_block_idx = in_batch_idx % (swizzle_cnt * m_loop); + + uint32_t n_col = swizzle_cnt; + if (tile_block_idx == tile_block_loop - 1) { + n_col = n_loop - swizzle_cnt * tile_block_idx; + } + m_idx = in_tile_block_idx / n_col; + n_idx = tile_block_idx * swizzle_cnt + in_tile_block_idx % n_col; + if (tile_block_idx % 2 != 0) { + m_idx = m_loop - m_idx - 1; + } + } + } + + __aicore__ __force_inline__ void Process(); + +private: + AscendC::GlobalTensor gmPerTensorScale_; + AscendC::GlobalTensor gmPerTensorBias_; + AscendC::GlobalTensor gmPerTokenScale_; + AscendC::GlobalTensor gmInput_; + AscendC::GlobalTensor gmOutput_; + + AscendC::LocalTensor ubInput_; + AscendC::LocalTensor ubTempFp32_; + AscendC::LocalTensor ubOutput_; + AscendC::LocalTensor ubPerTensorScale_; + + uint32_t core_num{0}; + uint32_t batch_size{0}; + uint32_t m{0}; + uint32_t k{0}; + uint32_t n{0}; + uint32_t m0{0}; + uint32_t k0{0}; + uint32_t n0{0}; + uint32_t m_loop{0}; + uint32_t n_loop{0}; + uint32_t k_loop{0}; + uint32_t core_loop{0}; + uint32_t core_idx{0}; + uint32_t ping_flag{0}; + uint32_t block_size{0}; + uint32_t cube_matrix_size{0}; + uint32_t swizzle_cnt{1}; + uint32_t 
en_shuffle_k{0}; + uint32_t swizzlDirect{0}; + uint64_t L1_PINGPONG_BUFFER_LEN{0}; + uint32_t L0AB_PINGPONG_BUFFER_LEN{0}; +}; + +template +__aicore__ __force_inline__ void PpMatmulW8a8Aiv::Process() +{ + uint32_t m_idx = 0; + uint32_t n_idx = 0; + SET_FLAG(V, MTE2, EVENT_ID0); + SET_FLAG(MTE3, V, EVENT_ID0); + SET_FLAG(MTE3, MTE2, EVENT_ID0); + for (uint32_t loop_idx = core_idx; loop_idx < core_loop; loop_idx += core_num) { + GetBlockIdx(loop_idx, m_idx, n_idx); + uint64_t batch_idx = loop_idx / n_loop / m_loop; + uint64_t offsetC = batch_idx * m * n + m_idx * m0 * n + n_idx * n0; + uint32_t m_actual = (m_idx == (m_loop - 1)) ? (m - m_idx * m0) : m0; + uint32_t n_actual = (n_idx == (n_loop - 1)) ? (n - n_idx * n0) : n0; + uint32_t m_round = RoundUp(m_actual); + uint32_t n_round = RoundUp(n_actual); + uint32_t n_round_16 = RoundUp(n_actual); + uint32_t m_actual_per_vec = m_actual / AscendC::GetTaskRation(); + uint32_t m_offset = m + m_idx * m0; + if (GetSubBlockidx() != 0) { + offsetC += m_actual_per_vec * n; + m_offset += m_actual_per_vec; + m_actual_per_vec = m_actual - m_actual_per_vec; + } + + if constexpr (!withSyncAll) { + if (m_actual_per_vec == 0) { + WaitFlagDev(MMAIC); + if ((loop_idx / core_num + 1) % MAX_HW_SYNC_COUNTER == 1) { + FftsCrossCoreSync(MMAIV); + } + continue; + } + } + + uint64_t offsetScale = batch_idx * n + n_idx * n0; + bool aligned_s32 = ((n & 0b111) == 0); // 32B aligned + bool aligned_f16 = ((n & 0b1111) == 0); // 32B aligned + WAIT_FLAG(V, MTE2, EVENT_ID0); + if constexpr (quantMode == QuantMode::PER_TENSOR_ASYMM_QUANT) { + if (aligned_s32) { + gm_to_ub(ubPerTensorScale_.ReinterpretCast(), + gmPerTensorBias_[offsetScale], + 0, // sid + 1, // nBurst + n_round * sizeof(BiasDtype) / BLOCK_SIZE_32, // lenBurst + 0, // srcStride + 0); // dstStride + } else { + gm_to_ub_align(ubPerTensorScale_.ReinterpretCast(), + gmPerTensorBias_[offsetScale], + 0, // sid + 1, // nBurst + n_actual * sizeof(float), // lenBurst + 0, // leftPaddingNum + 0, // rightPaddingNum + 0, // srcGap + 0); // dstGap + } + } else { + if (aligned_s32) { + gm_to_ub(ubPerTensorScale_, gmPerTensorScale_[offsetScale], + 0, // sid + 1, // nBurst + n_round * 4 / BLOCK_SIZE_32, // lenBurst + 0, // srcStride + 0); // dstStride + } else { + gm_to_ub_align(ubPerTensorScale_, gmPerTensorScale_[offsetScale], + 0, // sid + 1, // nBurst + n_actual * sizeof(float), // lenBurst + 0, // leftPaddingNum + 0, // rightPaddingNum + 0, // srcGap + 0); // dstGap + } + } + + if constexpr (!withSyncAll) { + WaitFlagDev(MMAIC); + } + WAIT_FLAG(MTE3, MTE2, EVENT_ID0); + if (aligned_s32) { + gm_to_ub(ubInput_, gmInput_[offsetC], + 0, // sid + m_actual_per_vec, // nBurst + n_round / 8, // lenBurst + (n - n_round) / 8, // srcStride + 0 // dstStride + ); + } else { + gm_to_ub_align(ubInput_, gmInput_[offsetC], + 0, // sid + m_actual_per_vec, // nBurst + n_actual * sizeof(int32_t), // lenBurst + 0, // leftPaddingNum + 0, // rightPaddingNum + (n - n_actual) * sizeof(int32_t), // srcGap + 0 // dstGap + ); + } + SET_FLAG(MTE2, V, EVENT_ID0); + WAIT_FLAG(MTE2, V, EVENT_ID0); + + WAIT_FLAG(MTE3, V, EVENT_ID0); + uint32_t nRepeatCnt = CeilDiv(n_actual); + if constexpr (quantMode == QuantMode::PER_TENSOR_ASYMM_QUANT) { + AscendC::SetMaskCount(); + AscendC::SetVectorMask(n_round); + for (uint32_t i = 0; i < m_actual_per_vec; ++i) { + // add_v(ubInput_[i * n_round], + // ubInput_[i * n_round], + // ubPerTensorScale_.ReinterpretCast(), + // (uint8_t)(nRepeatCnt), // repeat + // (uint8_t)1, // dstBlockStride + // (uint8_t)1, // 
src0BlockStride + // (uint8_t)1, // src1BlockStride + // (uint8_t)8, // dstRepeatStride + // (uint8_t)8, // src0RepeatStride + // (uint8_t)8 // src1RepeatStride + // ); + AscendC::Add(ubInput_[i * n_round], ubInput_[i * n_round], + ubPerTensorScale_.ReinterpretCast(), + AscendC::MASK_PLACEHOLDER, 1, + AscendC::BinaryRepeatParams((uint8_t)1, (uint8_t)1, (uint8_t)1, + (uint8_t)8, (uint8_t)8, (uint8_t)8)); + } + AscendC::ResetMask(); + SetMasknorm(); + + SET_FLAG(V, MTE2, EVENT_ID0); + WAIT_FLAG(V, MTE2, EVENT_ID0); + if (aligned_s32) { + gm_to_ub(ubPerTensorScale_, gmPerTensorScale_[offsetScale], + 0, // sid + 1, // nBurst + n_round * sizeof(ScaleDtype) / BLOCK_SIZE_32, // lenBurst + 0, // srcStride + 0 // dstStride + ); + } else { + gm_to_ub_align(ubPerTensorScale_, gmPerTensorScale_[offsetScale], + 0, // sid + 1, // nBurst + n_actual * sizeof(ScaleDtype), // lenBurst + 0, // leftPaddingNum + 0, // rightPaddingNum + 0, // srcGap + 0 // dstGap + ); + } + SET_FLAG(MTE2, V, EVENT_ID0); + WAIT_FLAG(MTE2, V, EVENT_ID0); + } + + // CASTF32 * f32 tf16 + constexpr uint32_t maxRepeat = 255; + constexpr uint32_t perRepeatNum = maxRepeat * 64; + uint32_t loopCnt = (m_actual_per_vec * n_actual + perRepeatNum - 1) / perRepeatNum; + for (uint32_t i = 0; i < loopCnt; i++) { + conv_v(ubInput_.ReinterpretCast()[perRepeatNum * i], + ubInput_[perRepeatNum * i], + (uint8_t)maxRepeat, // repeat + (uint16_t)1, // dstBlockStride + (uint16_t)1, // srcBlockStride + (uint16_t)8, // dstRepeatStride + (uint16_t)8 // srcRepeatStride + ); + } + AscendC::PipeBarrier(); + + for (uint32_t i = 0; i < m_actual_per_vec; ++i) { + mul_v(ubTempFp32_[i * n_round], + ubInput_.ReinterpretCast()[i * n_round], + ubPerTensorScale_.ReinterpretCast(), + (uint8_t)(nRepeatCnt), // repeat + (uint8_t)1, // dstBlockStride + (uint8_t)1, // src0BlockStride + (uint8_t)1, // src1BlockStride + (uint8_t)8, // dstRepeatStride + (uint8_t)8, // src0RepeatStride + (uint8_t)8 // src1RepeatStride + ); + if constexpr (quantMode == QuantMode::PER_TOKEN_SYMM_QUANT) { + AscendC::PipeBarrier(); + float perTokenDescale = gmPerTokenScale_.GetValue(m_offset + i); + SET_FLAG(S, V, EVENT_ID0); + WAIT_FLAG(S, V, EVENT_ID0); + AscendC::Muls(ubTempFp32_[i * n_round], ubTempFp32_[i * n_round], perTokenDescale, n_round); + } + AscendC::PipeBarrier(); + } + SET_FLAG(V, MTE2, EVENT_ID0); + AscendC::PipeBarrier(); + if (n_actual % 16 > 8) { + for (uint32_t i = 0; i < loopCnt; i++) { + if constexpr (std::is_same_v) { + convr_v(ubOutput_[perRepeatNum * i], + ubTempFp32_[perRepeatNum * i], + (uint8_t)maxRepeat, // repeat + (uint16_t)1, // dstBlockStride + (uint16_t)1, // srcBlockStride + (uint16_t)4, // dstRepeatStride + (uint16_t)8); // srcRepeatStride + } else { + conv_v(ubOutput_[perRepeatNum * i], + ubTempFp32_[perRepeatNum * i], + (uint8_t)maxRepeat, // repeat + (uint16_t)1, // dstBlockStride + (uint16_t)1, // srcBlockStride + (uint16_t)4, // dstRepeatStride + (uint16_t)8); // srcRepeatStride + } + } + } else { + for (uint32_t i = 0; i < m_actual_per_vec; i++) { + if constexpr (std::is_same_v) { + convr_v(ubOutput_[n_round_16 * i], ubTempFp32_[n_round * i], + (uint8_t)nRepeatCnt, // repeat + (uint16_t)1, // dstBlockStride + (uint16_t)1, // srcBlockStride + (uint16_t)4, // dstRepeatStride + (uint16_t)8); // srcRepeatStride + } else { + conv_v(ubOutput_[n_round_16 * i], ubTempFp32_[n_round * i], + (uint8_t)nRepeatCnt, // repeat + (uint16_t)1, // dstBlockStride + (uint16_t)1, // srcBlockStride + (uint16_t)4, // dstRepeatStride + (uint16_t)8); // srcRepeatStride + } + } + 
} + SET_FLAG(V, MTE3, EVENT_ID0); + WAIT_FLAG(V, MTE3, EVENT_ID0); + if (aligned_f16) { + ub_to_gm(gmOutput_[offsetC], ubOutput_, 0, + m_actual_per_vec, // nBurst + n_round / 16, // lenBurst + 0, // srcStride + (n - n_round) / 16 // dstStride + ); + } else { + ub_to_gm_align(gmOutput_[offsetC], ubOutput_, 0, + m_actual_per_vec, // nBurst + n_actual * sizeof(OutDtype), // lenBurst + 0, // leftPaddingNum + 0, // rightPaddingNum + 0, // srcGap + (n - n_actual) * sizeof(OutDtype) // dstGap + ); + } + SET_FLAG(MTE3, V, EVENT_ID0); + SET_FLAG(MTE3, MTE2, EVENT_ID0); + if constexpr (!withSyncAll) { + if ((loop_idx / core_num + 1) % MAX_HW_SYNC_COUNTER == 1) { + FftsCrossCoreSync(MMAIV); + } + } + } + WAIT_FLAG(V, MTE2, EVENT_ID0); + WAIT_FLAG(MTE3, V, EVENT_ID0); + WAIT_FLAG(MTE3, MTE2, EVENT_ID0); +} +#endif + +template +class MLAOperation +{ + static constexpr bool mm1WithSyncAll = (quantMode == QuantMode::PER_TOKEN_SYMM_QUANT); + static constexpr uint64_t splitGapC = CACHE_MODE == CACHE_MODE_KVCACHE ? CONST_64 : CONST_0; + using Q_OUT_DTYPE = typename std::conditional_t; + using K_NOPE_DTYPE = typename std::conditional_t; + +public: + __aicore__ inline MLAOperation(const MlaTilingData &mlaParams_, GM_ADDR tilingGm) + { + blockIdx = AscendC::GetBlockIdx(); +#ifdef __DAV_C220_VEC__ + sub_block_idx = static_cast(GetSubBlockidx()); +#endif + vectorBlockIdx = (blockIdx / 2) * 2 + sub_block_idx; + this->n = mlaParams_.n; + this->num_core_ = mlaParams_.rmsNumCore1; + this->num_col_1 = mlaParams_.rmsNumCol1; + this->num_col_2 = mlaParams_.rmsNumCol2; + this->num_row = mlaParams_.n; + this->epsilon_ = 1e-6; + this->mlaParams = mlaParams_; + } + + __aicore__ inline void Init(GM_ADDR hiddenStateGm, GM_ADDR quantScale1Gm, + GM_ADDR quantOffset1Gm, GM_ADDR wdqkvGm, GM_ADDR bias1Gm, GM_ADDR gamma2Gm, + GM_ADDR beta2Gm, GM_ADDR quantScale2Gm, GM_ADDR quantOffset2Gm, GM_ADDR gamma3Gm, + GM_ADDR sin1Gm, GM_ADDR cos1Gm, GM_ADDR sin2Gm, GM_ADDR cos2Gm, GM_ADDR keycacheGm, + GM_ADDR slotMappingGm, GM_ADDR wuqGm, GM_ADDR bias2Gm, GM_ADDR wukGm, + GM_ADDR descale1Gm, GM_ADDR descale2Gm, GM_ADDR gmCtkvScale, GM_ADDR gmQnopeScale, + GM_ADDR qGm, GM_ADDR keycacheOutGm, GM_ADDR qGm2, GM_ADDR keycacheOutGm2, GM_ADDR s1Gm, + GM_ADDR s2Gm, GM_ADDR s3Gm, GM_ADDR s4Gm, GM_ADDR s5Gm, GM_ADDR innerGm) + { + quantScale3GmTensor.SetGlobalBuffer(reinterpret_cast<__gm__ InDtype *>(gmCtkvScale)); + gamma3GmTensor.SetGlobalBuffer(reinterpret_cast<__gm__ InDtype *>(gamma3Gm)); + sin1GmTensor.SetGlobalBuffer(reinterpret_cast<__gm__ InDtype *>(sin1Gm)); + cos1GmTensor.SetGlobalBuffer(reinterpret_cast<__gm__ InDtype *>(cos1Gm)); + keycacheGmTensor1.SetGlobalBuffer(reinterpret_cast<__gm__ K_NOPE_DTYPE *>(keycacheOutGm)); + keycacheGmTensor2.SetGlobalBuffer(reinterpret_cast<__gm__ InDtype *>(keycacheOutGm2)); + slotMappingGmTensor.SetGlobalBuffer(reinterpret_cast<__gm__ int32_t *>(slotMappingGm)); + descale1gmTensor.SetGlobalBuffer(reinterpret_cast<__gm__ float *>(descale1Gm)); + s2GmTensor.SetGlobalBuffer(reinterpret_cast<__gm__ int32_t *>(s2Gm)); + s3GmTensor.SetGlobalBuffer(reinterpret_cast<__gm__ InDtype *>(s3Gm)); + s5GmTensor.SetGlobalBuffer(reinterpret_cast<__gm__ float *>(s5Gm)); + innerGmTensor.SetGlobalBuffer(reinterpret_cast<__gm__ bfloat16_t *>(innerGm)); + +#ifdef __DAV_C220_CUBE__ + mm_w8a8_aic_1.Init(s1Gm, wdqkvGm, s2Gm, mlaParams.mm1, 0); + mm_w8a8_aic_1.PreloadWeight(); + if constexpr (quantMode == QuantMode::PER_TENSOR_ASYMM_QUANT) { + mm_w8a8_aic_2.Init(s1Gm, wuqGm, s2Gm, mlaParams.mm2, 1); + } else { + // quantMode == 
QuantMode::PER_TOKEN_SYMM_QUANT + mm_w8a8_aic_2.Init(s1Gm, wuqGm, s3Gm, mlaParams.mm2, 1); + } + if constexpr (CACHE_MODE == CACHE_MODE_INT8_NZCACHE) { + mm_ein_sum.Init(s4Gm, wukGm, s1Gm, mlaParams); + } else { + mm_ein_sum.Init(s4Gm, wukGm, qGm, mlaParams); + } +#endif + + hiddenStateGmTensor.SetGlobalBuffer(reinterpret_cast<__gm__ InDtype *>(hiddenStateGm)); + quantScale1GmTensor.SetGlobalBuffer(reinterpret_cast<__gm__ InDtype *>(quantScale1Gm)); + quantOffset1GmTensor.SetGlobalBuffer(reinterpret_cast<__gm__ int8_t *>(quantOffset1Gm)); + wdqkvGmTensor.SetGlobalBuffer(reinterpret_cast<__gm__ int8_t *>(wdqkvGm)); + gamma2GmTensor.SetGlobalBuffer(reinterpret_cast<__gm__ InDtype *>(gamma2Gm)); + quantScale2GmTensor.SetGlobalBuffer(reinterpret_cast<__gm__ InDtype *>(quantScale2Gm)); + quantOffset2GmTensor.SetGlobalBuffer(reinterpret_cast<__gm__ int8_t *>(quantOffset2Gm)); + sin2GmTensor.SetGlobalBuffer(reinterpret_cast<__gm__ InDtype *>(sin2Gm)); + cos2GmTensor.SetGlobalBuffer(reinterpret_cast<__gm__ InDtype *>(cos2Gm)); + wuqGmTensor.SetGlobalBuffer(reinterpret_cast<__gm__ int8_t *>(wuqGm)); + wukGmTensor.SetGlobalBuffer(reinterpret_cast<__gm__ InDtype *>(wukGm)); + descale2gmTensor.SetGlobalBuffer(reinterpret_cast<__gm__ float *>(descale2Gm)); + s1GmTensor.SetGlobalBuffer(reinterpret_cast<__gm__ int8_t *>(s1Gm)); + s4GmTensor.SetGlobalBuffer(reinterpret_cast<__gm__ int32_t *>(s4Gm)); + qGmTensor.SetGlobalBuffer(reinterpret_cast<__gm__ Q_OUT_DTYPE *>(qGm)); + qGmTensor2.SetGlobalBuffer(reinterpret_cast<__gm__ InDtype *>(qGm2)); + bias1gmTensor.SetGlobalBuffer(reinterpret_cast<__gm__ int32_t *>(bias1Gm)); + bias2gmTensor.SetGlobalBuffer(reinterpret_cast<__gm__ int32_t *>(bias2Gm)); + beta2GmTensor.SetGlobalBuffer(reinterpret_cast<__gm__ InDtype *>(beta2Gm)); + +#ifdef __DAV_C220_VEC__ + if constexpr (quantMode == QuantMode::PER_TENSOR_ASYMM_QUANT) { + mm_w8a8_aiv_1.Init(s2Gm, s3Gm, descale1Gm, bias1Gm, s5Gm, mlaParams.mm1); + mm_w8a8_aiv_2.Init(s2Gm, s4Gm, descale2Gm, bias2Gm, s5Gm, mlaParams.mm2); + } else { + // quantMode == QuantMode::PER_TOKEN_SYMM_QUANT + mm_w8a8_aiv_2.Init(s3Gm, s4Gm, descale2Gm, bias2Gm, s5Gm, mlaParams.mm2); + } + row_work = (num_row + num_core_ - 1) / num_core_; + row_work_ = 0; + uint32_t need_core = (num_row + row_work - 1) / row_work; + if (vectorBlockIdx < need_core - 1) { + row_work_ = row_work; + } else if (vectorBlockIdx == need_core - 1) { + row_work_ = num_row - (need_core - 1) * row_work; + } else { + row_work_ = 0; + } + this->splitN = mlaParams.perTaskNum; + Quant1.Init(quantScale1GmTensor, quantOffset1GmTensor, s5Gm + row_work * vectorBlockIdx * sizeof(float), + descale1Gm, hiddenStateGm, s1Gm, 0, num_col_1, + vectorBlockIdx * static_cast(row_work) * num_col_1, + vectorBlockIdx * static_cast(row_work) * num_col_1, row_work_, mlaParams); + if constexpr (quantMode == QuantMode::PER_TENSOR_ASYMM_QUANT) { + rmsNormQuant2.Init(gamma2GmTensor, beta2GmTensor, quantScale2GmTensor, quantOffset2GmTensor, + s5Gm + row_work * vectorBlockIdx * sizeof(float), descale1Gm, s3Gm, s1Gm, SPLIT_SIZE_ONE, + num_col_2, 0.000651041666, vectorBlockIdx * static_cast(row_work) * num_col_2, + vectorBlockIdx * static_cast(row_work) * SPLIT_SIZE_TWO, row_work_, mlaParams, innerGmTensor); + } else { + // quantMode == QuantMode::PER_TOKEN_SYMM_QUANT + rmsNormQuant2.Init(gamma2GmTensor, beta2GmTensor, quantScale2GmTensor, quantOffset2GmTensor, + s5Gm + row_work * vectorBlockIdx * sizeof(float), descale1Gm, s2Gm, s1Gm, SPLIT_SIZE_ONE, + num_col_2, 0.000651041666, vectorBlockIdx * 
static_cast(row_work) * num_col_2, + vectorBlockIdx * static_cast(row_work) * SPLIT_SIZE_TWO, row_work_, mlaParams, innerGmTensor); + } + ropeFp16.RopeInit(s4Gm, cos2GmTensor, sin2GmTensor, qGmTensor, qGmTensor2, mlaParams); + einSumQuant.Init(s1Gm, gmQnopeScale, qGm, mlaParams); +#endif + } + + __aicore__ inline void ProcessCube(); + + __aicore__ inline void ProcessVector(); + +private: + constexpr static uint32_t C0_SIZE = 16; + constexpr static uint32_t I8_C0_SIZE = 32; + + template + __aicore__ inline void RmsNormAndRopeConvergence1( + const AscendC::LocalTensor &srcTensor, const AscendC::LocalTensor &gammaTensor, + const AscendC::LocalTensor &sinTensor, const AscendC::LocalTensor &cosTensor, + const AscendC::LocalTensor &slotMappingTensor, const uint32_t sN, + const AscendC::LocalTensor &rmsNormTensor, const AscendC::LocalTensor &gammaFp32, + const AscendC::LocalTensor &ropeKTensor, const AscendC::LocalTensor &ropeKRevertTensor, + const AscendC::LocalTensor &calTensor, const AscendC::LocalTensor &outTmpTensor, + AscendC::LocalTensor &tmpfp16, AscendC::LocalTensor &int8OutTensor, float quantScale3) + { + int64_t slotMapGmOffset = vectorBlockIdx * row_work; + AscendC::DataCopy(gammaTensor, gamma3GmTensor, SPLIT_RMSNRORM_SIZE_ONE); + SET_FLAG(MTE2, V, EVENT_ID1); + WAIT_FLAG(MTE2, V, EVENT_ID1); + Cast(gammaFp32, gammaTensor, AscendC::RoundMode::CAST_NONE, SPLIT_RMSNRORM_SIZE_ONE); + AscendC::DataCopyPad(slotMappingTensor, slotMappingGmTensor[slotMapGmOffset], + AscendC::DataCopyExtParams(1, sN * sizeof(int32_t), 0, 0, 0), + AscendC::DataCopyPadExtParams(false, 0, 8 - sN % 8, 0)); + if constexpr (quantMode == QuantMode::PER_TOKEN_SYMM_QUANT) { + mmTensor = calTensor.ReinterpretCast()[SPLIT_SIZE_ONE]; + deScaleTensor = calTensor.ReinterpretCast()[SPLIT_SIZE_ONE * 2]; + AscendC::DataCopy(deScaleTensor, descale1gmTensor, AscendC::DataCopyParams(1, SPLIT_SIZE_ONE / 8, 0, 0)); + } + SET_FLAG(MTE2, V, EVENT_ID2); + WAIT_FLAG(MTE2, V, EVENT_ID2); + SET_FLAG(MTE2, S, EVENT_ID2); + WAIT_FLAG(MTE2, S, EVENT_ID2); + for (uint64_t loop = 0; loop < sN; ++loop) { + uint64_t offset = vectorBlockIdx * static_cast(row_work) * num_col_2 + loop * MM1_OUT_SIZE; + int64_t slotValue = static_cast(slotMappingTensor.GetValue(loop)); + if (slotValue == -1) { + continue; + } + if constexpr (quantMode == QuantMode::PER_TENSOR_ASYMM_QUANT) { + AscendC::DataCopy(srcTensor, s3GmTensor[offset], + AscendC::DataCopyParams(1, MM1_OUT_SIZE / BLOCK_SIZE_16, 0, 0)); + } else { + // quantMode == QuantMode::PER_TOKEN_SYMM_QUANT + AscendC::DataCopy(mmTensor, s2GmTensor[offset], AscendC::DataCopyParams(1, SPLIT_SIZE_ONE / 8, 0, 0)); + } + AscendC::DataCopy(sinTensor, sin1GmTensor[(row_work * vectorBlockIdx + loop) * SPLIT_RMSNRORM_SIZE_TWO], + SPLIT_RMSNRORM_SIZE_TWO); + AscendC::DataCopy(cosTensor, cos1GmTensor[(row_work * vectorBlockIdx + loop) * SPLIT_RMSNRORM_SIZE_TWO], + SPLIT_RMSNRORM_SIZE_TWO); + SET_FLAG(MTE2, V, EVENT_ID0); + // ND + uint64_t cacheStart = static_cast(slotValue) * static_cast(SPLIT_SIZE_ONE); + uint64_t cacheStart1 = static_cast(slotValue) * static_cast(SPLIT_RMSNRORM_SIZE_ONE); + uint64_t cacheStart2 = static_cast(slotValue) * static_cast(SPLIT_RMSNRORM_SIZE_TWO); + // NZ + uint32_t outer_idx = slotValue / 128; + uint32_t inner_idx = slotValue % 128; + + SET_FLAG(S, MTE3, EVENT_ID0); + /* RmsNorm start */ + WAIT_FLAG(MTE2, V, EVENT_ID0); + if constexpr (quantMode == QuantMode::PER_TOKEN_SYMM_QUANT) { + /* DeQuant */ + AscendC::Cast(mmTensor.ReinterpretCast(), mmTensor, AscendC::RoundMode::CAST_NONE, + 
SPLIT_SIZE_ONE); + AscendC::PipeBarrier(); + AscendC::Mul(mmTensor.ReinterpretCast(), mmTensor.ReinterpretCast(), deScaleTensor, + SPLIT_SIZE_ONE); + AscendC::PipeBarrier(); + float perTokenDescale = s5GmTensor.GetValue(row_work * vectorBlockIdx + loop); + SET_FLAG(S, V, EVENT_ID0); + WAIT_FLAG(S, V, EVENT_ID0); + AscendC::Muls(mmTensor.ReinterpretCast(), mmTensor.ReinterpretCast(), perTokenDescale, + SPLIT_SIZE_ONE); + AscendC::PipeBarrier(); + AscendC::Cast(srcTensor, mmTensor.ReinterpretCast(), AscendC::RoundMode::CAST_RINT, + SPLIT_SIZE_ONE); + AscendC::PipeBarrier(); + } + Cast(rmsNormTensor, srcTensor, AscendC::RoundMode::CAST_NONE, SPLIT_RMSNRORM_SIZE_ONE); + AscendC::PipeBarrier(); + Mul(calTensor, rmsNormTensor, rmsNormTensor, SPLIT_RMSNRORM_SIZE_ONE); + AscendC::PipeBarrier(); + ReduceSumCustom(calTensor[SPLIT_RMSNRORM_SIZE_ONE], calTensor, calTensor[SPLIT_RMSNRORM_SIZE_ONE * 2], + SPLIT_RMSNRORM_SIZE_ONE); + SET_FLAG(V, S, EVENT_ID1); + WAIT_FLAG(V, S, EVENT_ID1); + float rms = sqrt(calTensor.GetValue(SPLIT_RMSNRORM_SIZE_ONE) / SPLIT_RMSNRORM_SIZE_ONE + epsilon_); + SET_FLAG(S, V, EVENT_ID1); + WAIT_FLAG(S, V, EVENT_ID1); + AscendC::PipeBarrier(); + Duplicate(calTensor, rms, SPLIT_RMSNRORM_SIZE_ONE); + AscendC::PipeBarrier(); + Div(calTensor, rmsNormTensor, calTensor, SPLIT_RMSNRORM_SIZE_ONE); + AscendC::PipeBarrier(); + Mul(rmsNormTensor, gammaFp32, calTensor, SPLIT_RMSNRORM_SIZE_ONE); + AscendC::PipeBarrier(); + if constexpr (CACHE_MODE == CACHE_MODE_INT8_NZCACHE) { + // quant + Muls(rmsNormTensor, rmsNormTensor, quantScale3, SPLIT_RMSNRORM_SIZE_ONE); + AscendC::PipeBarrier(); + CastFrom32To16(tmpfp16, rmsNormTensor, SPLIT_RMSNRORM_SIZE_ONE); + AscendC::PipeBarrier(); + CastFromF16ToI8(int8OutTensor, tmpfp16, -128, SPLIT_RMSNRORM_SIZE_ONE); + AscendC::PipeBarrier(); + } else { + AscendC::PipeBarrier(); + if (std::is_same::value) { + Cast(outTmpTensor, rmsNormTensor, AscendC::RoundMode::CAST_RINT, SPLIT_RMSNRORM_SIZE_ONE); + } else { + Cast(outTmpTensor, rmsNormTensor, AscendC::RoundMode::CAST_NONE, SPLIT_RMSNRORM_SIZE_ONE); + } + } + /* RmsNorm end */ + /* Rope K start */ + uint64_t revertOffset = SPLIT_RMSNRORM_SIZE_TWO / 2; + Cast(ropeKTensor, srcTensor[SPLIT_RMSNRORM_SIZE_ONE], AscendC::RoundMode::CAST_NONE, + SPLIT_RMSNRORM_SIZE_TWO); + Cast(ropeKRevertTensor[revertOffset], srcTensor[SPLIT_RMSNRORM_SIZE_ONE], AscendC::RoundMode::CAST_NONE, + revertOffset); + Cast(ropeKRevertTensor, srcTensor[SPLIT_RMSNRORM_SIZE_ONE + revertOffset], AscendC::RoundMode::CAST_NONE, + revertOffset); + Duplicate(calTensor, static_cast(-1), revertOffset); + Duplicate(calTensor[revertOffset], static_cast(1), revertOffset); + AscendC::PipeBarrier(); + Cast(calTensor[SPLIT_RMSNRORM_SIZE_TWO], cosTensor, AscendC::RoundMode::CAST_NONE, SPLIT_RMSNRORM_SIZE_TWO); + Cast(calTensor[SPLIT_RMSNRORM_SIZE_TWO * 2], sinTensor, AscendC::RoundMode::CAST_NONE, + SPLIT_RMSNRORM_SIZE_TWO); + AscendC::PipeBarrier(); + Mul(ropeKTensor, calTensor[SPLIT_RMSNRORM_SIZE_TWO], ropeKTensor, SPLIT_RMSNRORM_SIZE_TWO); + Mul(ropeKRevertTensor, calTensor[SPLIT_RMSNRORM_SIZE_TWO * 2], ropeKRevertTensor, SPLIT_RMSNRORM_SIZE_TWO); + AscendC::PipeBarrier(); + Mul(ropeKRevertTensor, calTensor, ropeKRevertTensor, SPLIT_RMSNRORM_SIZE_TWO); + AscendC::PipeBarrier(); + Add(ropeKRevertTensor, ropeKTensor, ropeKRevertTensor, SPLIT_RMSNRORM_SIZE_TWO); + AscendC::PipeBarrier(); + if (std::is_same::value) { + Cast(outTmpTensor[SPLIT_RMSNRORM_SIZE_ONE], ropeKRevertTensor, AscendC::RoundMode::CAST_RINT, + SPLIT_RMSNRORM_SIZE_TWO); + } else { 
+ Cast(outTmpTensor[SPLIT_RMSNRORM_SIZE_ONE], ropeKRevertTensor, AscendC::RoundMode::CAST_NONE, + SPLIT_RMSNRORM_SIZE_TWO); + } + AscendC::PipeBarrier(); + /* Rope K end */ + SET_FLAG(V, MTE3, EVENT_ID0); + WAIT_FLAG(V, MTE3, EVENT_ID0); + WAIT_FLAG(S, MTE3, EVENT_ID0); + if constexpr (CACHE_MODE == CACHE_MODE_KVCACHE) { + DataCopy(keycacheGmTensor1[cacheStart], outTmpTensor, SPLIT_SIZE_ONE); + } else if constexpr (CACHE_MODE == CACHE_MODE_INT8_NZCACHE) { + uint64_t cacheSatartI8Nz1 = outer_idx * 128 * 512 + inner_idx * I8_C0_SIZE; + uint64_t cacheSatartNz2 = outer_idx * 128 * 64 + inner_idx * C0_SIZE; + // nope:int8 nz + AscendC::DataCopyExtParams outExt; + outExt.blockCount = SPLIT_RMSNRORM_SIZE_ONE / I8_C0_SIZE; + outExt.blockLen = I8_C0_SIZE * sizeof(int8_t); + outExt.srcStride = 0; + outExt.dstStride = (128 * I8_C0_SIZE - I8_C0_SIZE) * sizeof(int8_t); + DataCopyPad(keycacheGmTensor1[cacheSatartI8Nz1], int8OutTensor, outExt); + // rope:T1 nz + outExt.blockCount = SPLIT_RMSNRORM_SIZE_TWO / C0_SIZE; + outExt.blockLen = C0_SIZE * sizeof(T1); + outExt.srcStride = 0; + outExt.dstStride = (128 * C0_SIZE - C0_SIZE) * sizeof(T1); + DataCopyPad(keycacheGmTensor2[cacheSatartNz2], outTmpTensor[SPLIT_RMSNRORM_SIZE_ONE], outExt); + } else if constexpr (CACHE_MODE == CACHE_MODE_NZCACHE) { + uint64_t cacheSatartNz1 = outer_idx * 128 * 512 + inner_idx * C0_SIZE; + uint64_t cacheSatartNz2 = outer_idx * 128 * 64 + inner_idx * C0_SIZE; + // nope:T1 nz + AscendC::DataCopyExtParams outExt; + outExt.blockCount = SPLIT_RMSNRORM_SIZE_ONE / C0_SIZE; + outExt.blockLen = C0_SIZE * sizeof(T1); + outExt.srcStride = 0; + outExt.dstStride = (128 * C0_SIZE - C0_SIZE) * sizeof(T1); + DataCopyPad(keycacheGmTensor1[cacheSatartNz1], outTmpTensor, outExt); + // rope:T1 nz + outExt.blockCount = SPLIT_RMSNRORM_SIZE_TWO / C0_SIZE; + outExt.blockLen = C0_SIZE * sizeof(T1); + outExt.srcStride = 0; + outExt.dstStride = (128 * C0_SIZE - C0_SIZE) * sizeof(T1); + DataCopyPad(keycacheGmTensor2[cacheSatartNz2], outTmpTensor[SPLIT_RMSNRORM_SIZE_ONE], outExt); + } else { + // keycache1 + DataCopy(keycacheGmTensor1[cacheStart1], outTmpTensor, SPLIT_RMSNRORM_SIZE_ONE); + // keycache2 + DataCopy(keycacheGmTensor2[cacheStart2], outTmpTensor[SPLIT_RMSNRORM_SIZE_ONE], + SPLIT_RMSNRORM_SIZE_TWO); + } + SET_FLAG(MTE3, MTE2, EVENT_ID1); + WAIT_FLAG(MTE3, MTE2, EVENT_ID1); + } + } + +private: + uint32_t n; + uint32_t splitN; + uint32_t rotaryCoeff; + uint32_t blockIdx; + uint32_t sub_block_idx; + uint32_t vectorBlockIdx; + uint32_t blockOffset; + uint32_t perTaskNum; + uint32_t resTaskNum; + MlaTilingData mlaParams; + + uint32_t num_core_; + uint32_t num_col_1; + uint32_t num_col_2; + float epsilon_; + uint32_t num_row; + uint32_t quantMin_; + uint32_t row_work; + uint32_t row_work_; + + AsdopsBuffer buf; + AscendC::LocalTensor mmTensor; + AscendC::LocalTensor deScaleTensor; + + AscendC::GlobalTensor hiddenStateGmTensor; + + AscendC::GlobalTensor quantScale1GmTensor; + AscendC::GlobalTensor quantOffset1GmTensor; + + AscendC::GlobalTensor wdqkvGmTensor; + AscendC::GlobalTensor gamma2GmTensor; + AscendC::GlobalTensor quantScale2GmTensor; + AscendC::GlobalTensor quantScale3GmTensor; + AscendC::GlobalTensor quantOffset2GmTensor; + AscendC::GlobalTensor gamma3GmTensor; + AscendC::GlobalTensor sin1GmTensor; + AscendC::GlobalTensor cos1GmTensor; + AscendC::GlobalTensor sin2GmTensor; + AscendC::GlobalTensor cos2GmTensor; + AscendC::GlobalTensor keycacheGmTensor1; + AscendC::GlobalTensor keycacheGmTensor2; + AscendC::GlobalTensor slotMappingGmTensor; + 
AscendC::GlobalTensor wuqGmTensor; + AscendC::GlobalTensor wukGmTensor; + + // cachemode2-->int8; else bf16 + AscendC::GlobalTensor qGmTensor; + AscendC::GlobalTensor qGmTensor2; + AscendC::GlobalTensor s1GmTensor; + AscendC::GlobalTensor s2GmTensor; + AscendC::GlobalTensor s3GmTensor; + AscendC::GlobalTensor s4GmTensor; + AscendC::GlobalTensor s5GmTensor; + AscendC::GlobalTensor innerGmTensor; + AscendC::GlobalTensor descale1gmTensor; + AscendC::GlobalTensor descale2gmTensor; + AscendC::GlobalTensor beta2GmTensor; + + AscendC::GlobalTensor bias1gmTensor; + AscendC::GlobalTensor bias2gmTensor; + +#ifdef __DAV_C220_CUBE__ + PpMatmulW8a8Aic mm_w8a8_aic_1; + PpMatmulW8a8Aic mm_w8a8_aic_2; + PpMatmulEinSum mm_ein_sum; +#endif + +#ifdef __DAV_C220_VEC__ + PpMatmulW8a8Aiv mm_w8a8_aiv_1; + PpMatmulW8a8Aiv mm_w8a8_aiv_2; + + Quant Quant1; + + RmsNormQuant rmsNormQuant2; + RopeFp16 ropeFp16; + EinSumQuant einSumQuant; +#endif +}; + +template +__aicore__ inline void +MLAOperation::ProcessCube() +{ +#ifdef __DAV_C220_CUBE__ + mm_w8a8_aic_1.Process(); + if constexpr (quantMode == QuantMode::PER_TOKEN_SYMM_QUANT) { + FftsCrossCoreSync(MMAIC); + WaitFlagDev(MMAIC); + FftsCrossCoreSync(MMAIV); + } + + mm_w8a8_aic_2.PreloadWeight(); + mm_w8a8_aic_2.Process(); + mm_ein_sum.Process(); + if constexpr (CACHE_MODE == CACHE_MODE_INT8_NZCACHE) { + FftsCrossCoreSync(EINSUMOUT); + WaitFlagDev(EINSUMOUT); + FftsCrossCoreSync(EINSUMQUANT); + } +#endif +} + +template +__aicore__ inline void +MLAOperation::ProcessVector() +{ +#ifdef __DAV_C220_VEC__ + if (row_work_ != 0) { + uint32_t num_col_align_int8 = (num_col_1 + REPEAT_TIME_256 - 1) / REPEAT_TIME_256 * REPEAT_TIME_256; + uint32_t num_col_align_f16 = (num_col_1 + REPEAT_TIME_128 - 1) / REPEAT_TIME_128 * REPEAT_TIME_128; + uint32_t num_col_align_f32 = (num_col_1 + REPEAT_TIME_64 - 1) / REPEAT_TIME_64 * REPEAT_TIME_64; + AscendC::LocalTensor input_tensor = buf.GetBuffer(0); + AscendC::LocalTensor scale_tensor = + buf.GetBuffer(HIDDTEN_STATE * 2 + HIDDTEN_STATE * 2 + HIDDTEN_STATE * 2); + AscendC::LocalTensor offset_tensor = buf.GetBuffer( + HIDDTEN_STATE * 2 + HIDDTEN_STATE * 2 + HIDDTEN_STATE * 2 + 32); + AscendC::LocalTensor res1_tensor = + buf.GetBuffer(HIDDTEN_STATE * 2 + HIDDTEN_STATE * 2 + HIDDTEN_STATE * 2 + 64); + AscendC::LocalTensor res3_tensor = buf.GetBuffer( + HIDDTEN_STATE * 2 + HIDDTEN_STATE * 2 + HIDDTEN_STATE * 2 + 64 + num_col_align_f32 * 4); + AscendC::LocalTensor output_tensor = buf.GetBuffer( + HIDDTEN_STATE * 2 + HIDDTEN_STATE * 2 + HIDDTEN_STATE * 2 + 64 + num_col_align_f32 * 4 + + BUF_FACTOR * num_col_align_f32 * 4 + 64); + Quant1.Launch(output_tensor, input_tensor, scale_tensor, offset_tensor, res1_tensor, res3_tensor); + } + FftsCrossCoreSync(QUANT1); + WaitFlagDev(QUANT1); + AscendC::CrossCoreSetFlag<0x2, PIPE_MTE3>(AIC_MM1_START); + if constexpr (quantMode == QuantMode::PER_TENSOR_ASYMM_QUANT) { + mm_w8a8_aiv_1.Process(); + FftsCrossCoreSync(RMSNORMQUANT2); + WaitFlagDev(RMSNORMQUANT2); + } else { // quantMode == QuantMode::PER_TOKEN_SYMM_QUANT + WaitFlagDev(MMAIV); + } + if (row_work_ != 0) { + uint32_t num_col_align_int8 = (num_col_2 + REPEAT_TIME_256 - 1) / REPEAT_TIME_256 * REPEAT_TIME_256; + uint32_t num_col_align_f16 = (num_col_2 + REPEAT_TIME_128 - 1) / REPEAT_TIME_128 * REPEAT_TIME_128; + uint32_t num_col_align_f32 = (num_col_2 + REPEAT_TIME_64 - 1) / REPEAT_TIME_64 * REPEAT_TIME_64; + AscendC::LocalTensor input_tensor = buf.GetBuffer(0); + AscendC::LocalTensor gamma_tensor = buf.GetBuffer(MM1_OUT_SIZE * 2); + AscendC::LocalTensor 
beta_tensor = + buf.GetBuffer(MM1_OUT_SIZE * 2 + SPLIT_SIZE_TWO * 2); + AscendC::LocalTensor scale_tensor = + buf.GetBuffer(MM1_OUT_SIZE * 2 + SPLIT_SIZE_TWO * 2 + SPLIT_SIZE_TWO * 2); + AscendC::LocalTensor offset_tensor = buf.GetBuffer( + MM1_OUT_SIZE * 2 + SPLIT_SIZE_TWO * 2 + SPLIT_SIZE_TWO * 2 + 32); + AscendC::LocalTensor res1_tensor = buf.GetBuffer( + MM1_OUT_SIZE * 2 + SPLIT_SIZE_TWO * 2 + SPLIT_SIZE_TWO * 2 + 64); + AscendC::LocalTensor res3_tensor = buf.GetBuffer( + MM1_OUT_SIZE * 2 + SPLIT_SIZE_TWO * 2 + SPLIT_SIZE_TWO * 2 + 64 + num_col_align_f32 * 4); + AscendC::LocalTensor output_tensor = buf.GetBuffer( + MM1_OUT_SIZE * 2 + SPLIT_SIZE_TWO * 2 + SPLIT_SIZE_TWO * 2 + 64 + num_col_align_f32 * 4 + + BUF_FACTOR * num_col_align_f32 * 4 + 64 + MM1_OUT_SIZE * 4 * 2 + 32); + rmsNormQuant2.Launch(output_tensor, input_tensor, gamma_tensor, beta_tensor, scale_tensor, offset_tensor, + res1_tensor, res3_tensor); + } + FftsCrossCoreSync(MM2); + WaitFlagDev(MM2); + AscendC::CrossCoreSetFlag<0x2, PIPE_MTE3>(AIC_MM2_START); + + if (row_work_ != 0) { + AscendC::LocalTensor input_tensor = buf.GetBuffer(0); + AscendC::LocalTensor gamma_tensor = buf.GetBuffer(MM1_OUT_SIZE * 2); + AscendC::LocalTensor sin_tensor = + buf.GetBuffer(MM1_OUT_SIZE * 2 + SPLIT_RMSNRORM_SIZE_ONE * 2); + AscendC::LocalTensor cos_tensor = buf.GetBuffer( + MM1_OUT_SIZE * 2 + SPLIT_RMSNRORM_SIZE_ONE * 2 + SPLIT_RMSNRORM_SIZE_TWO * 2); + AscendC::LocalTensor slotMapping_tensor = buf.GetBuffer( + MM1_OUT_SIZE * 2 + SPLIT_RMSNRORM_SIZE_ONE * 2 + SPLIT_RMSNRORM_SIZE_TWO * 4); + int32_t rms3_ub_offset = + MM1_OUT_SIZE * 2 + SPLIT_RMSNRORM_SIZE_ONE * 2 + SPLIT_RMSNRORM_SIZE_TWO * 4 + 4096 * 32; + AscendC::LocalTensor tmp32_tensor = buf.GetBuffer(rms3_ub_offset); + + int32_t out_ub_offset = MM1_OUT_SIZE * 2 + SPLIT_RMSNRORM_SIZE_ONE * 2 + SPLIT_RMSNRORM_SIZE_TWO * 4 + + 4096 * 32 + SPLIT_RMSNRORM_SIZE_ONE * 3 * 4 + SPLIT_RMSNRORM_SIZE_TWO * 2 * 4 + + MM1_OUT_SIZE * 4 * 2 + 32; + AscendC::LocalTensor temp_tensor = buf.GetBuffer(out_ub_offset); + + AscendC::LocalTensor tmpfp16; + AscendC::LocalTensor int8OutTensor; + float scale3 = 0; + if constexpr (CACHE_MODE == CACHE_MODE_INT8_NZCACHE) { + // quantScale3 + AscendC::LocalTensor quantScaleTensor = + buf.GetBuffer(rms3_ub_offset); + AscendC::LocalTensor floatQuantScaleTensor = + buf.GetBuffer(rms3_ub_offset + 32); + // int8out + tmpfp16 = buf.GetBuffer(rms3_ub_offset + + SPLIT_RMSNRORM_SIZE_ONE * sizeof(float) * 2); + int8OutTensor = buf.GetBuffer(out_ub_offset); + AscendC::DataCopy(quantScaleTensor, quantScale3GmTensor, AscendC::DataCopyParams(1, 1, 0, 0)); + SET_FLAG(MTE2, V, EVENT_ID1); + WAIT_FLAG(MTE2, V, EVENT_ID1); + Cast(floatQuantScaleTensor, quantScaleTensor, AscendC::RoundMode::CAST_NONE, 1); + AscendC::SetFlag(EVENT_ID1); + AscendC::WaitFlag(EVENT_ID1); + scale3 = 1 / (float)(floatQuantScaleTensor.GetValue(0)); + } + + RmsNormAndRopeConvergence1( + input_tensor, // n * 576 + gamma_tensor, // gamma + sin_tensor, // sin + cos_tensor, // cons + slotMapping_tensor, // slotMapping + row_work_, tmp32_tensor, tmp32_tensor[SPLIT_RMSNRORM_SIZE_ONE], + tmp32_tensor[SPLIT_RMSNRORM_SIZE_ONE + SPLIT_RMSNRORM_SIZE_ONE], + tmp32_tensor[SPLIT_RMSNRORM_SIZE_ONE + SPLIT_RMSNRORM_SIZE_ONE + SPLIT_RMSNRORM_SIZE_TWO], + tmp32_tensor[SPLIT_RMSNRORM_SIZE_ONE + SPLIT_RMSNRORM_SIZE_ONE + SPLIT_RMSNRORM_SIZE_TWO + + SPLIT_RMSNRORM_SIZE_TWO], + temp_tensor, tmpfp16, int8OutTensor, scale3); + } + mm_w8a8_aiv_2.Process(); + FftsCrossCoreSync(MM2OUT); + WaitFlagDev(MM2OUT); + 
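+    // (added clarifying comment, not part of the original patch) at this point the dequantized
+    // MM2 result has been written to s4 by the vector cores: release the cube cores for the
+    // einsum (MM3) stage, then continue with the query RoPE on the vector side.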
AscendC::CrossCoreSetFlag<0x2, PIPE_MTE3>(AIC_MM3_START); + ropeFp16.Process(); + + if constexpr (CACHE_MODE == CACHE_MODE_INT8_NZCACHE) { + WaitFlagDev(EINSUMQUANT); + einSumQuant.Process(); + } +#endif +} + +} // namespace MLAPO_BF16 diff --git a/csrc/ops.h b/csrc/ops.h index 2401792d..4edf0a8f 100644 --- a/csrc/ops.h +++ b/csrc/ops.h @@ -154,6 +154,7 @@ namespace vllm_ascend { void* keycache_out, void* q2, void* keycache_out2, + void* inner_out, void* workspace, void* tiling, const uint32_t block_dim diff --git a/csrc/torch_binding.cpp b/csrc/torch_binding.cpp index 351c6745..ec762b08 100644 --- a/csrc/torch_binding.cpp +++ b/csrc/torch_binding.cpp @@ -173,7 +173,7 @@ std::tuple rotary_embedding(at::Tensor &positions, at::T return {query_dst, key_dst}; } -std::tuple mla_preprocess( +std::tuple mla_preprocess( const at::Tensor &hiddenState, const at::Tensor &wdqkv, const at::Tensor &descale0, const at::Tensor &gamma1, const at::Tensor &beta1, const at::Tensor &wuq, const at::Tensor &descale1, const at::Tensor &gamma2, const at::Tensor &cos, const at::Tensor &sin, @@ -181,8 +181,8 @@ std::tuple mla_preproces const at::Tensor &quant_scale0, const at::Tensor &quant_offset0, const at::Tensor &bias0, const at::Tensor &quant_scale1, const at::Tensor &quant_offset1, const at::Tensor &bias1, const c10::optional &ctkv_scale, const c10::optional &q_nope_scale, - c10::optional cache_mode, c10::optional quant_mode, at::Tensor &q_out0, - at::Tensor &kv_cache_out0, at::Tensor &q_out1, at::Tensor &kv_cache_out1) + c10::optional cache_mode, c10::optional quant_mode, c10::optional enable_inner_out, at::Tensor &q_out0, + at::Tensor &kv_cache_out0, at::Tensor &q_out1, at::Tensor &kv_cache_out1, at::Tensor &inner_out) { at::Tensor CtkvScale = ctkv_scale.has_value() @@ -192,12 +192,17 @@ std::tuple mla_preproces q_nope_scale.has_value() ? q_nope_scale.value() : at::empty({1}, at::TensorOptions().dtype(at::kHalf).device(hiddenState.options().device())); + bool enableInnerOut = + enable_inner_out.has_value() + ? 
+            ? enable_inner_out.value()
+            : false;
 
     auto [workspace_tensor, tiling, block_dim] = mlapo::mla_preprocess_tiling(
         hiddenState,
         wuk,
         cache_mode,
-        quant_mode
+        quant_mode,
+        enableInnerOut
     );
 
     void *hidden_state_ptr = hiddenState.data_ptr();
@@ -225,6 +230,7 @@ std::tuple mla_preproces
     void *kv_cache_out0_ptr = kv_cache_out0.data_ptr();
     void *q_out1_ptr = q_out1.data_ptr();
     void *kv_cache_out1_ptr = kv_cache_out1.data_ptr();
+    void *inner_out_ptr = inner_out.data_ptr();
     void *workspace_ptr = workspace_tensor.data_ptr();
     void *tiling_ptr = tiling.data_ptr();
 
@@ -235,17 +241,17 @@ std::tuple mla_preproces
     cmd.SetCustomHandler([stream, hidden_state_ptr, quant_scale0_ptr, quant_offset0_ptr, wdqkv_ptr, bias0_ptr,
                           gamma1_ptr, beta1_ptr, quant_scale1_ptr, quant_offset1_ptr, gamma2_ptr, sin_ptr, cos_ptr, kv_cache_ptr,
                           slotmapping_ptr, wuq_ptr, bias1_ptr, wuk_ptr, descale0_ptr, descale1_ptr, ctkv_scale_ptr,
-                          qnope_scale_ptr, q_out0_ptr, kv_cache_out0_ptr, q_out1_ptr, kv_cache_out1_ptr, workspace_ptr,
+                          qnope_scale_ptr, q_out0_ptr, kv_cache_out0_ptr, q_out1_ptr, kv_cache_out1_ptr, inner_out_ptr, workspace_ptr,
                           tiling_ptr, block_dim]() -> int {
         mla_preprocess_impl(stream, hidden_state_ptr, quant_scale0_ptr, quant_offset0_ptr, wdqkv_ptr, bias0_ptr,
                             gamma1_ptr, beta1_ptr, quant_scale1_ptr, quant_offset1_ptr, gamma2_ptr, sin_ptr, cos_ptr, sin_ptr, cos_ptr,
                             kv_cache_ptr, slotmapping_ptr, wuq_ptr, bias1_ptr, wuk_ptr, descale0_ptr, descale1_ptr, ctkv_scale_ptr,
-                            qnope_scale_ptr, q_out0_ptr, kv_cache_out0_ptr, q_out1_ptr, kv_cache_out1_ptr, workspace_ptr,
+                            qnope_scale_ptr, q_out0_ptr, kv_cache_out0_ptr, q_out1_ptr, kv_cache_out1_ptr, inner_out_ptr, workspace_ptr,
                             tiling_ptr, block_dim);
         return 0;
     });
     cmd.Run();
-    return std::forward_as_tuple(q_out0, kv_cache_out0, q_out1, kv_cache_out1);
+    return std::forward_as_tuple(q_out0, kv_cache_out0, q_out1, kv_cache_out1, inner_out);
 }
 
 std::tuple get_masked_input_and_mask(
@@ -792,9 +798,9 @@ TORCH_LIBRARY_EXPAND(CONCAT(_C, _ascend), ops)
         " Tensor kv_cache_rope, Tensor slotmapping, Tensor quant_scale0,"
         " Tensor quant_offset0, Tensor bias0, Tensor quant_scale1, Tensor quant_offset1,"
         " Tensor bias1, Tensor? ctkv_scale, Tensor? q_nope_scale, str? cache_mode,"
-        " str? quant_mode, Tensor! q_out0, Tensor! kv_cache_out0, Tensor! q_out1,"
-        " Tensor! kv_cache_out1) -> (Tensor q_out0, Tensor kv_cache_out0,"
-        " Tensor q_out1, Tensor kv_cache_out1)"
+        " str? quant_mode, bool? enable_inner_out, Tensor! q_out0, Tensor! kv_cache_out0, Tensor! q_out1,"
+        " Tensor! kv_cache_out1, Tensor! inner_out) -> (Tensor q_out0, Tensor kv_cache_out0,"
+        " Tensor q_out1, Tensor kv_cache_out1, Tensor inner_out)"
     );
 
     ops.impl("mla_preprocess", torch::kPrivateUse1, &vllm_ascend::mla_preprocess);
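For reference, a minimal caller-side sketch of the schema registered above: the two new arguments are `enable_inner_out` (optional; treated as false when omitted) and the `inner_out` out-tensor, and the op now returns five tensors. The helper name and the `q_down` buffer are illustrative; its `[num_tokens, 1536]` shape is taken from the new test further below, not from any documented contract.

```python
import torch


def mla_preprocess_with_q_down(args, kwargs, q_down):
    """Sketch of the call-site delta only.

    `args`/`kwargs` are whatever an existing caller already passes to the op;
    `q_down` is a freshly preallocated [num_tokens, 1536] buffer (assumed
    shape, see the e2e test below).
    """
    return torch.ops._C_ascend.mla_preprocess(
        *args,
        **kwargs,
        enable_inner_out=True,  # new optional flag; omitted/None behaves as False
        inner_out=q_down,       # new required out-tensor, returned as the fifth result
    )
```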
diff --git a/csrc/torch_binding_meta.cpp b/csrc/torch_binding_meta.cpp
index 4a09c8df..50e9ed63 100644
--- a/csrc/torch_binding_meta.cpp
+++ b/csrc/torch_binding_meta.cpp
@@ -81,7 +81,7 @@ at::Tensor sgmv_expand_meta(at::Tensor &x, at::Tensor &weight, at::Tensor &lora_
     return y_out;
 }
 
-std::tuple mla_preprocess(
+std::tuple mla_preprocess(
     const at::Tensor &hiddenState,
     const at::Tensor &wdqkv,
     const at::Tensor &descale0,
@@ -106,12 +106,15 @@ std::tuple mla_preproces
     const c10::optional &q_nope_scale,
     c10::optional cache_mode,
     c10::optional quant_mode,
+    c10::optional enable_inner_out,
     at::Tensor &q_out0,
     at::Tensor &kv_cache_out0,
     at::Tensor &q_out1,
-    at::Tensor &kv_cache_out1)
+    at::Tensor &kv_cache_out1,
+    at::Tensor &inner_out
+    )
 {
-    return {q_out0, kv_cache_out0, q_out1, kv_cache_out1};
+    return {q_out0, kv_cache_out0, q_out1, kv_cache_out1, inner_out};
 }
 
 std::tuple grouped_matmul_swiglu_quant(
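The meta ("fake") kernel above is what shape and dtype inference sees when the op is traced with meta or fake tensors; since every result is a caller-allocated out-tensor, it can simply hand the buffers back. A Python-level analogue of that behaviour, illustrative only (the real registration is the C++ code above):

```python
def mla_preprocess_meta(*inputs, q_out0, kv_cache_out0, q_out1, kv_cache_out1,
                        inner_out, **unused):
    # Out-variant op: no computation is needed for shape inference; the
    # preallocated outputs (now including inner_out) are returned unchanged.
    return q_out0, kv_cache_out0, q_out1, kv_cache_out1, inner_out
```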
diff --git a/tests/e2e/nightly/ops/test_mla_preprocess_qdown.py b/tests/e2e/nightly/ops/test_mla_preprocess_qdown.py
new file mode 100644
index 00000000..9eb7e1ca
--- /dev/null
+++ b/tests/e2e/nightly/ops/test_mla_preprocess_qdown.py
@@ -0,0 +1,116 @@
+import gc
+
+import torch
+import torch_npu
+
+from vllm_ascend.utils import enable_custom_op
+
+enable_custom_op()
+
+
+@torch.inference_mode()
+def test_mla_preprocess_kernel():
+    token_num = 1
+    head_num = 2
+    N_7168 = 7168
+    block_num = 1
+    block_size = 128
+    dtype = torch.bfloat16
+
+    hidden_states = torch.randn((token_num, N_7168), dtype=dtype).npu()
+    quant_scale0 = torch.randn((1, ), dtype=dtype).npu()
+    quant_offset0 = torch.randint(0, 7, (1, ), dtype=torch.int8).npu()
+
+    wdqkv = torch.randint(0, 7, (1, 224, 2112, 32), dtype=torch.int8).npu()
+    wdqkv = torch_npu.npu_format_cast(wdqkv.contiguous(), 29)
+
+    de_scale0 = torch.rand((2112, ), dtype=torch.float).npu()
+    bias0 = torch.randint(0, 7, (2112, ), dtype=torch.int32).npu()
+    gamma1 = torch.randn((1536), dtype=dtype).npu()
+    beta1 = torch.randn((1536), dtype=dtype).npu()
+    quant_scale1 = torch.randn((1, ), dtype=dtype).npu()
+    quant_offset1 = torch.randint(0, 7, (1, ), dtype=torch.int8).npu()
+
+    wuq = torch.randint(0, 7, (1, 48, head_num * 192, 32),
+                        dtype=torch.int8).npu()
+    wuq = torch_npu.npu_format_cast(wuq.contiguous(), 29)
+
+    de_scale1 = torch.rand((head_num * 192, ), dtype=torch.float).npu()
+    bias1 = torch.randint(0, 7, (head_num * 192, ), dtype=torch.int32).npu()
+
+    gamma2 = torch.randn((512), dtype=dtype).npu()
+
+    cos = torch.randn((token_num, 64), dtype=dtype).npu()
+    sin = torch.randn((token_num, 64), dtype=dtype).npu()
+
+    wuk = torch.randn((head_num, 128, 512), dtype=dtype).npu()
+    wuk = torch_npu.npu_format_cast(wuk, 29)
+    kv_cache = torch.randint(0,
+                             7,
+                             (block_num, head_num * 512 // 32, block_size, 32),
+                             dtype=dtype).npu()
+    kv_cache_rope = torch.randn(
+        (block_num, head_num * 64 // 16, block_size, 16), dtype=dtype).npu()
+
+    slotmapping = torch.randint(0, 7, (token_num, ), dtype=torch.int32).npu()
+
+    ctkv_scale = torch.randn((1, ), dtype=dtype).npu()
+    qnope_scale = torch.randn((head_num), dtype=dtype).npu()
+
+    q_nope_out = torch.empty(
+        (hidden_states.shape[0], wuk.shape[0], kv_cache.shape[-1]),
+        dtype=hidden_states.dtype,
+        device=hidden_states.device,
+    )
+    q_rope_out = torch.empty(
+        (hidden_states.shape[0], wuk.shape[0], kv_cache_rope.shape[-1]),
+        dtype=hidden_states.dtype,
+        device=hidden_states.device,
+    )
+    q_down = torch.empty(
+        (hidden_states.shape[0], 1536),
+        dtype=hidden_states.dtype,
+        device=hidden_states.device,
+    )
+
+    q_nope_old = q_nope_out.clone()
+    q_rope_old = q_rope_out.clone()
+    q_down_old = q_down.clone()
+
+    torch.ops._C_ascend.mla_preprocess(hidden_states,
+                                       wdqkv,
+                                       de_scale0,
+                                       gamma1,
+                                       beta1,
+                                       wuq,
+                                       de_scale1,
+                                       gamma2,
+                                       cos,
+                                       sin,
+                                       wuk,
+                                       kv_cache,
+                                       kv_cache_rope,
+                                       slotmapping,
+                                       quant_scale0=quant_scale0,
+                                       quant_offset0=quant_offset0,
+                                       bias0=bias0,
+                                       quant_scale1=quant_scale1,
+                                       quant_offset1=quant_offset1,
+                                       bias1=bias1,
+                                       ctkv_scale=ctkv_scale,
+                                       q_nope_scale=qnope_scale,
+                                       cache_mode="krope_ctkv",
+                                       quant_mode="per_tensor_quant_asymm",
+                                       enable_inner_out=True,
+                                       q_out0=q_nope_out,
+                                       kv_cache_out0=kv_cache,
+                                       q_out1=q_rope_out,
+                                       kv_cache_out1=kv_cache_rope,
+                                       inner_out=q_down)
+    assert not torch.equal(q_nope_out, q_nope_old)
+    assert not torch.equal(q_rope_out, q_rope_old)
+    assert not torch.equal(q_down, q_down_old)
+
+    gc.collect()
+    torch.npu.empty_cache()
+    torch.npu.reset_peak_memory_stats()
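The test above allocates its three query-side output buffers inline; the sketch below only makes the implied shapes explicit in one place. Treat the helper name, and the reading of 1536 as the q down-projection width (the q_lora_rank used by DeepSeek-style MLA configurations), as assumptions drawn from the test rather than from the kernel interface.

```python
import torch


def alloc_mlapo_query_outputs(hidden_states, wuk, kv_cache, kv_cache_rope,
                              q_down_dim=1536):
    """Preallocate q_out0 / q_out1 / inner_out the same way the test does."""
    num_tokens = hidden_states.shape[0]
    num_heads = wuk.shape[0]
    opts = dict(dtype=hidden_states.dtype, device=hidden_states.device)
    q_nope_out = torch.empty((num_tokens, num_heads, kv_cache.shape[-1]), **opts)
    q_rope_out = torch.empty((num_tokens, num_heads, kv_cache_rope.shape[-1]), **opts)
    q_down = torch.empty((num_tokens, q_down_dim), **opts)
    return q_nope_out, q_rope_out, q_down
```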