diff --git a/csrc/mla_preprocess/op_host/mla_preprocess.h b/csrc/mla_preprocess/op_host/mla_preprocess.h index f554869b..ae4a1225 100644 --- a/csrc/mla_preprocess/op_host/mla_preprocess.h +++ b/csrc/mla_preprocess/op_host/mla_preprocess.h @@ -43,7 +43,6 @@ constexpr uint32_t L1_BIAS_SIZE = 2048; constexpr uint32_t L0C_SIZE = 128 * 1024; constexpr uint32_t CONCAT_SIZE = 512; -constexpr uint32_t HIDDEN_STRATE = 7168; constexpr uint32_t HIDDEN_STRATE_ROPE = 192; constexpr uint32_t HIDDEN_STRATE_MM = 2112; constexpr uint32_t HIDDEN_STRATE_RMS = 1536; @@ -122,6 +121,8 @@ struct PlatformInfo { }; struct OpParam { + uint32_t isWeightQuantized; + uint32_t hiddenStateDim; uint32_t N; uint32_t headNum; int32_t cacheMode; @@ -392,7 +393,7 @@ private: void MlaPreprocessTiling::RmsNormQuantTiling() { tilingData->rmsNumCore1 = platformInfo.coreNumAiv; - tilingData->rmsNumCol1 = HIDDEN_STRATE; + tilingData->rmsNumCol1 = opParam.hiddenStateDim; tilingData->rmsNumRow1 = opParam.N; tilingData->rmsQuantMin1 = -CONST_128; tilingData->rmsNumCore2 = platformInfo.coreNumAiv; @@ -508,9 +509,9 @@ void MlaPreprocessTiling::EinSumQuantTiling() void MlaPreprocessTiling::SetMlapoWorkSpace() { uint64_t s1wsFactor = - static_cast(opParam.cacheMode == 2 ? std::max(HIDDEN_STRATE * sizeof(int8_t), + static_cast(opParam.cacheMode == 2 ? std::max(opParam.hiddenStateDim * sizeof(int8_t), opParam.headNum * AXES_ALIGN_SIZE * sizeof(uint16_t)) - : HIDDEN_STRATE * sizeof(int8_t)); + : opParam.hiddenStateDim * sizeof(int8_t)); uint64_t workSizeS1 = s1wsFactor; uint64_t workSizeS2 = opParam.headNum * HIDDEN_STRATE_ROPE * sizeof(uint16_t); uint64_t workSizeS3 = HIDDEN_STRATE_MM * sizeof(uint16_t); @@ -525,7 +526,8 @@ void MlaPreprocessTiling::SetMlapoWorkSpace() uint64_t pertokenWorkspace = static_cast(opParam.N) * sizeof(float) * 2; uint64_t userWorkspaceSize; - if (opParam.inDtype == at::kBFloat16 || opParam.quantMode == QuantMode::PER_TOKEN_SYMM_QUANT) { + if (opParam.isWeightQuantized == 1 && + (opParam.inDtype == at::kBFloat16 || opParam.quantMode == QuantMode::PER_TOKEN_SYMM_QUANT)) { userWorkspaceSize = 4 * maxWorkspaceSize + pertokenWorkspace; } else { userWorkspaceSize = 3 * maxWorkspaceSize; @@ -554,21 +556,23 @@ void MlaPreprocessTiling::Init() { tilingData->numCore = platformInfo.coreNumAic; tilingData->n = opParam.N; - + tilingData->hiddenStateDim = opParam.hiddenStateDim; + tilingData->isWeightQuantized = opParam.isWeightQuantized; + bool enDequant = (opParam.isWeightQuantized == 1); bool deqOnTheFly = false; - if (opParam.inDtype == at::kBFloat16 || opParam.quantMode == QuantMode::PER_TOKEN_SYMM_QUANT) { + if (enDequant && (opParam.inDtype == at::kBFloat16 || opParam.quantMode == QuantMode::PER_TOKEN_SYMM_QUANT)) { deqOnTheFly = true; } PpMatmulTilingApi mm1TilingApi(platformInfo, - 1, // numBatch - opParam.N, // m - HIDDEN_STRATE, // k - HIDDEN_STRATE_MM, // n - false, // transA - true, // transB - true, // enDequant - deqOnTheFly); // in bf16.cce? + 1, // numBatch + opParam.N, // m + opParam.hiddenStateDim, // k + HIDDEN_STRATE_MM, // n + false, // transA + true, // transB + enDequant, // enDequant + deqOnTheFly); // in bf16.cce? mm1TilingApi.GetTilingData(tilingData->mm1); PpMatmulTilingApi mm2TilingApi(platformInfo, @@ -578,7 +582,7 @@ void MlaPreprocessTiling::Init() opParam.headNum * HIDDEN_STRATE_ROPE, // n false, // transA true, // transB - true, // enDequant + enDequant, // enDequant deqOnTheFly); // in bf16.cce? 
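+    // Note: as for mm1 above, enDequant is set only when the weights are int8-quantized
+    // (opParam.isWeightQuantized == 1); with bf16/fp16 weights both enDequant and deqOnTheFly stay false.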
mm2TilingApi.GetTilingData(tilingData->mm2); @@ -609,6 +613,8 @@ std::unordered_map cache_mode_map = { std::unordered_map quant_mode_map = { {"per_tensor_quant_asymm", 0}, {"per_token_quant_symm", 1}, + {"per_token_quant_asymm", 2}, + {"no_quant", 3} }; template @@ -623,6 +629,7 @@ inline int get_op_mode(const MapType &mode_map, c10::optional std::tuple mla_preprocess_tiling( const at::Tensor &hiddenState, + const at::Tensor &wdqkv, const at::Tensor &wuk, c10::optional cache_mode, c10::optional quant_mode, @@ -647,14 +654,21 @@ std::tuple mla_preprocess_tiling( int32_t N = hiddenState.sizes()[0]; int32_t headNum = wuk.sizes()[0]; + uint32_t hiddenStateDim = hiddenState.sizes().back(); OpParam opParam; + opParam.hiddenStateDim = hiddenStateDim; opParam.N = N; opParam.headNum = headNum; opParam.cacheMode = static_cast(cacheMode); opParam.quantMode = static_cast(quantMode); opParam.inDtype = hiddenState.options().dtype(); opParam.enableInnerOut = enable_inner_out; + if (wdqkv.options().dtype() == at::kBFloat16 || wdqkv.options().dtype() == at::kHalf) { + opParam.isWeightQuantized = 0; + } else { + opParam.isWeightQuantized = 1; + } MlaTilingData tilingData; MlaPreprocessTiling mlaTiling(platformInfo, opParam, &tilingData); diff --git a/csrc/mla_preprocess/op_host/tiling/mla_preprocess_tiling.h b/csrc/mla_preprocess/op_host/tiling/mla_preprocess_tiling.h index aab1f3a7..00be2898 100644 --- a/csrc/mla_preprocess/op_host/tiling/mla_preprocess_tiling.h +++ b/csrc/mla_preprocess/op_host/tiling/mla_preprocess_tiling.h @@ -90,6 +90,11 @@ struct MlaTilingData { uint32_t esqHeadTail{0}; uint32_t esqColLoop{0}; uint32_t esqColTail{0}; + + // hidden state dimension + uint32_t hiddenStateDim{7168}; + + uint32_t isWeightQuantized{1}; }; #endif // MLAPREPROCESS_TILING_H diff --git a/csrc/mla_preprocess/op_kernel/mla_preprocess.h b/csrc/mla_preprocess/op_kernel/mla_preprocess.h index fe7a125c..6c894bd1 100644 --- a/csrc/mla_preprocess/op_kernel/mla_preprocess.h +++ b/csrc/mla_preprocess/op_kernel/mla_preprocess.h @@ -49,7 +49,6 @@ constexpr uint8_t CACHE_MODE_INT8_NZCACHE = 2; // high performance KV NZ format constexpr uint8_t CACHE_MODE_NZCACHE = 3; // pp matmul -constexpr uint32_t HIDDTEN_STATE = 7168; constexpr uint32_t FLOAT_BLOCK_SIZE = 64; constexpr uint32_t HALF_BLOCK_SIZE = 64; constexpr uint32_t HALF_VECTOR_SIZE = 64; @@ -103,6 +102,7 @@ constexpr uint32_t KEY_FP16_CACHEMODE_1_QUANTMODE_0 = 1; constexpr uint32_t KEY_BF16_CACHEMODE_0_QUANTMODE_0 = 256; constexpr uint32_t KEY_BF16_CACHEMODE_1_QUANTMODE_0 = 257; constexpr uint32_t KEY_BF16_CACHEMODE_3_QUANTMODE_0 = 259; +constexpr uint32_t KEY_BF16_CACHEMODE_1_QUANTMODE_3 = 281; constexpr uint32_t KEY_BF16_CACHEMODE_0_QUANTMODE_0_INNER = 256 + 512; constexpr uint32_t KEY_BF16_CACHEMODE_1_QUANTMODE_0_INNER = 257 + 512; constexpr uint32_t KEY_BF16_CACHEMODE_3_QUANTMODE_0_INNER = 259 + 512; diff --git a/csrc/mla_preprocess/op_kernel/mla_preprocess_kernel.cpp b/csrc/mla_preprocess/op_kernel/mla_preprocess_kernel.cpp index 5f367a92..f5dbe210 100644 --- a/csrc/mla_preprocess/op_kernel/mla_preprocess_kernel.cpp +++ b/csrc/mla_preprocess/op_kernel/mla_preprocess_kernel.cpp @@ -16,6 +16,7 @@ #include "mla_preprocess_mix_fp16.hpp" #include "mla_preprocess_mix_bf16.hpp" #include "mla_preprocess_mix_bf16_qdown.hpp" +#include "mla_preprocess_mix_bf16_nq.hpp" #include "../op_host/tiling/mla_preprocess_tiling.h" @@ -42,6 +43,7 @@ extern "C" __global__ __aicore__ void mla_preprocess( mlaTilingData.tilingKey = tilingData->tilingKey; mlaTilingData.n = tilingData->n; + 
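+    // tilingKey (copied above) selects the kernel specialization dispatched below. The new no-quant bf16
+    // path uses KEY_BF16_CACHEMODE_1_QUANTMODE_3 = 281, which matches the existing keys if the encoding is
+    // 256 (bf16) + cacheMode + 8 * quantMode (an inference from the 256/257/259 constants, not something
+    // this patch states).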
mlaTilingData.hiddenStateDim = tilingData->hiddenStateDim; mlaTilingData.mm1.numBatch = tilingData->mm1.numBatch; mlaTilingData.mm1.m = tilingData->mm1.m; @@ -173,12 +175,12 @@ extern "C" __global__ __aicore__ void mla_preprocess( } case KEY_BF16_CACHEMODE_0_QUANTMODE_0: { MLAPO_BF16::MLAOperation<__bf16, 0, DataFormat::NZ, DataFormat::NZ, DataFormat::ND, - QuantMode::PER_TENSOR_ASYMM_QUANT> + QuantMode::PER_TENSOR_ASYMM_QUANT> opBf16Cm0Qm0(mlaTilingData, tiling); opBf16Cm0Qm0.Init(hiddenState, quantScale1, quantOffset1, wdqkv, bias1, gamma2, beta2, - quantScale2, quantOffset2, gamma3, sin1, cos1, sin2, cos2, keycache, slotMapping, wuq, - bias2, wuk, descale1, descale2, ctkvScale, qnopeScale, q, keycacheOut, q2, keycacheOut2, - s1, s2, s3, s4, s5); + quantScale2, quantOffset2, gamma3, sin1, cos1, sin2, cos2, keycache, slotMapping, wuq, + bias2, wuk, descale1, descale2, ctkvScale, qnopeScale, q, keycacheOut, q2, keycacheOut2, + s1, s2, s3, s4, s5); if ASCEND_IS_AIC { opBf16Cm0Qm0.ProcessCube(); } @@ -189,12 +191,12 @@ extern "C" __global__ __aicore__ void mla_preprocess( } case KEY_BF16_CACHEMODE_1_QUANTMODE_0: { MLAPO_BF16::MLAOperation<__bf16, 1, DataFormat::NZ, DataFormat::NZ, DataFormat::ND, - QuantMode::PER_TENSOR_ASYMM_QUANT> + QuantMode::PER_TENSOR_ASYMM_QUANT> opBf16Cm1Qm0(mlaTilingData, tiling); opBf16Cm1Qm0.Init(hiddenState, quantScale1, quantOffset1, wdqkv, bias1, gamma2, beta2, - quantScale2, quantOffset2, gamma3, sin1, cos1, sin2, cos2, keycache, slotMapping, wuq, - bias2, wuk, descale1, descale2, ctkvScale, qnopeScale, q, keycacheOut, q2, keycacheOut2, - s1, s2, s3, s4, s5); + quantScale2, quantOffset2, gamma3, sin1, cos1, sin2, cos2, keycache, slotMapping, wuq, + bias2, wuk, descale1, descale2, ctkvScale, qnopeScale, q, keycacheOut, q2, keycacheOut2, + s1, s2, s3, s4, s5); if ASCEND_IS_AIC { opBf16Cm1Qm0.ProcessCube(); } @@ -219,6 +221,21 @@ extern "C" __global__ __aicore__ void mla_preprocess( } break; } + case KEY_BF16_CACHEMODE_1_QUANTMODE_3: { + MLAPO_BF16_NQ::MLAOperation<__bf16, 1, DataFormat::NZ, DataFormat::NZ, DataFormat::ND> + opBf16Cm1Qm0(mlaTilingData, tiling); + opBf16Cm1Qm0.Init(hiddenState, wdqkv, gamma2, beta2, + gamma3, sin1, cos1, sin2, cos2, keycache, slotMapping, wuq, + wuk, q, keycacheOut, q2, keycacheOut2, + s1, s2, s3); + if ASCEND_IS_AIC { + opBf16Cm1Qm0.ProcessCube(); + } + if ASCEND_IS_AIV { + opBf16Cm1Qm0.ProcessVector(); + } + break; + } case KEY_BF16_CACHEMODE_0_QUANTMODE_0_INNER: { MLAPO_BF16_INNER::MLAOperation<__bf16, 0, DataFormat::NZ, DataFormat::NZ, DataFormat::ND, QuantMode::PER_TENSOR_ASYMM_QUANT> diff --git a/csrc/mla_preprocess/op_kernel/mla_preprocess_mix_bf16.hpp b/csrc/mla_preprocess/op_kernel/mla_preprocess_mix_bf16.hpp index 43d9509e..5c641890 100644 --- a/csrc/mla_preprocess/op_kernel/mla_preprocess_mix_bf16.hpp +++ b/csrc/mla_preprocess/op_kernel/mla_preprocess_mix_bf16.hpp @@ -2386,6 +2386,7 @@ public: this->num_row = mlaParams_.n; this->epsilon_ = 1e-6; this->mlaParams = mlaParams_; + this->hiddenStateDim = mlaParams_.hiddenStateDim; } __aicore__ inline void Init(GM_ADDR hiddenStateGm, GM_ADDR quantScale1Gm, @@ -2692,6 +2693,7 @@ private: uint32_t blockOffset; uint32_t perTaskNum; uint32_t resTaskNum; + uint32_t hiddenStateDim; MlaTilingData mlaParams; uint32_t num_core_; @@ -2795,18 +2797,15 @@ MLAOperation input_tensor = buf.GetBuffer(0); - AscendC::LocalTensor scale_tensor = - buf.GetBuffer(HIDDTEN_STATE * 2 + HIDDTEN_STATE * 2 + HIDDTEN_STATE * 2); - AscendC::LocalTensor offset_tensor = buf.GetBuffer( - HIDDTEN_STATE * 2 + 
HIDDTEN_STATE * 2 + HIDDTEN_STATE * 2 + 32); - AscendC::LocalTensor res1_tensor = - buf.GetBuffer(HIDDTEN_STATE * 2 + HIDDTEN_STATE * 2 + HIDDTEN_STATE * 2 + 64); + AscendC::LocalTensor scale_tensor = buf.GetBuffer(base_offset); + AscendC::LocalTensor offset_tensor = buf.GetBuffer(base_offset + 32); + AscendC::LocalTensor res1_tensor = buf.GetBuffer(base_offset + 64); AscendC::LocalTensor res3_tensor = buf.GetBuffer( - HIDDTEN_STATE * 2 + HIDDTEN_STATE * 2 + HIDDTEN_STATE * 2 + 64 + num_col_align_f32 * 4); + base_offset + 64 + num_col_align_f32 * 4); AscendC::LocalTensor output_tensor = buf.GetBuffer( - HIDDTEN_STATE * 2 + HIDDTEN_STATE * 2 + HIDDTEN_STATE * 2 + 64 + num_col_align_f32 * 4 + - BUF_FACTOR * num_col_align_f32 * 4 + 64); + base_offset + 64 + num_col_align_f32 * 4 + BUF_FACTOR * num_col_align_f32 * 4 + 64); Quant1.Launch(output_tensor, input_tensor, scale_tensor, offset_tensor, res1_tensor, res3_tensor); } FftsCrossCoreSync(QUANT1); diff --git a/csrc/mla_preprocess/op_kernel/mla_preprocess_mix_bf16_nq.hpp b/csrc/mla_preprocess/op_kernel/mla_preprocess_mix_bf16_nq.hpp new file mode 100644 index 00000000..ec42aebb --- /dev/null +++ b/csrc/mla_preprocess/op_kernel/mla_preprocess_mix_bf16_nq.hpp @@ -0,0 +1,1252 @@ +// Adapted from +// https://gitee.com/ascend/ascend-transformer-boost +// +// Copyright (c) Huawei Technologies Co., Ltd. 2025. All rights reserved. +// This file is a part of the CANN Open Software. +// Licensed under CANN Open Software License Agreement Version 1.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. +// + +#include "kernel/common.h" +#include "kernel/iterator.h" +#include "kernel/mem.h" +#include "kernel/mma.h" +#include "kernel/utils.h" +#include "kernel/simd.h" +#include "kernel/kernel_utils.h" + +#include "lib/matmul_intf.h" + +#include "mla_preprocess.h" +#include "../op_host/tiling/mla_preprocess_tiling.h" + +namespace MLAPO_BF16_NQ { +template +class RopeFp16 +{ +public: + __aicore__ inline RopeFp16() : blockIdx_(AscendC::GetBlockIdx()) {} + + __aicore__ inline void RopeInit(GM_ADDR qGm, AscendC::GlobalTensor &cosGm, + AscendC::GlobalTensor &sinGm, + AscendC::GlobalTensor &outRopeConcatGm, + AscendC::GlobalTensor &outRopeConcatGm2, MlaTilingData &ropeConcatParams) + { + qGm_.SetGlobalBuffer(reinterpret_cast<__gm__ QkDtype *>(qGm)); + this->cosGm_ = cosGm; + this->sinGm_ = sinGm; + this->outRopeConcatGm_ = outRopeConcatGm; + this->outRopeConcatGm2_ = outRopeConcatGm2; + + headDim = ropeConcatParams.headDim; + headNumQ = ropeConcatParams.headNumQ; + rotaryCoeff = ropeConcatParams.rotaryCoeff; + ntokens = ropeConcatParams.ntokens; + realCore = ropeConcatParams.realCore; + nlCoreRun = ropeConcatParams.nlCoreRun; + lCoreRun = ropeConcatParams.lCoreRun; + maxNPerLoopForUb = ropeConcatParams.maxNPerLoopForUb; + preCoreLoopTime = ropeConcatParams.preCoreLoopTime; + preCoreLoopNLast = ropeConcatParams.preCoreLoopNLast; + lastCoreLoopTime = ropeConcatParams.lastCoreLoopTime; + lastCoreLoopNLast = ropeConcatParams.lastCoreLoopNLast; + concatSize = ropeConcatParams.concatSize; + blockIdx_ = (blockIdx_ / 2) * 2 + static_cast(GetSubBlockidx()); + loopTime = (blockIdx_ == realCore - 1) ? 
lastCoreLoopTime : preCoreLoopTime;
+        lastLoopN = (blockIdx_ == realCore - 1) ? lastCoreLoopNLast : preCoreLoopNLast;
+        this->repeatSize_ = 64;  // 64 = 256B / sizeof(fp32)
+        this->rotateStride_ = this->headDim / this->rotaryCoeff;
+        headBlockLen = static_cast(this->headDim / ELE_NUM_FP16);
+        headBlockLenFP32 = static_cast(this->headDim / ELE_NUM_FP32);
+        rotaryLen = static_cast(this->rotateStride_ / ELE_NUM_FP32);
+        concatBlockLen = static_cast(this->concatSize / ELE_NUM_FP16);
+        outLineOffset = this->headDim + this->concatSize;
+        uint32_t dataNum = this->headDim * this->maxNPerLoopForUb;
+        dataSizeFp16 = dataNum * sizeof(QkDtype);
+        dataSizeFp32 = dataNum * sizeof(float);
+        uint32_t concatDataSize = this->concatSize * sizeof(QkDtype) * this->maxNPerLoopForUb;
+    }
+
+    __aicore__ inline void Process()
+    {
+        if (blockIdx_ >= realCore) {
+            return;
+        }
+        uint64_t startCoreLineIndex = this->blockIdx_ * this->nlCoreRun;
+        // neg pattern of shape [maxNPerLoopForUb, head_dim]
+        AscendC::LocalTensor negLocal =
+            buf.GetBuffer(dataSizeFp32 * 4 + dataSizeFp16 * 3);
+        ExpandNeg(negLocal, this->maxNPerLoopForUb);
+
+        SET_FLAG(MTE3, MTE2, EVENT_ID1);
+        for (uint32_t zz = 0; zz < this->loopTime; ++zz) {
+            uint16_t loopN = (zz == this->loopTime - 1) ? this->lastLoopN : this->maxNPerLoopForUb;
+            uint64_t startHead = startCoreLineIndex + zz * this->maxNPerLoopForUb;
+            uint64_t endHead = startHead + loopN;
+
+            // move in Q
+            AscendC::LocalTensor inputQ = buf.GetBuffer(0);
+            AscendC::LocalTensor inputQCastFP32 = buf.GetBuffer(dataSizeFp16);
+            AscendC::LocalTensor reverseQ =
+                buf.GetBuffer(dataSizeFp32 + dataSizeFp16);
+            uint64_t qOffset = startHead * 192 + 128;
+            CopyQGenReverseQ(inputQ, inputQCastFP32, reverseQ, qOffset, loopN);
+
+            // move in cos/sin
+            AscendC::LocalTensor inputCos =
+                buf.GetBuffer(dataSizeFp32 * 2 + dataSizeFp16);
+            AscendC::LocalTensor inputSin =
+                buf.GetBuffer(dataSizeFp32 * 2 + dataSizeFp16 * 2);
+            uint64_t startSinCosHeadIndex = startHead;
+            uint64_t headRemain = startHead % this->headNumQ;
+            uint64_t localStartAddr = 0;
+            if (headRemain != 0) {
+                uint64_t preProcessHeadNum = this->headNumQ - headRemain;
+                uint64_t needToProcesHead = preProcessHeadNum > loopN ? loopN : preProcessHeadNum;
+                CopyCosSin(inputCos, inputSin, localStartAddr, (startSinCosHeadIndex / this->headNumQ) * this->headDim,
+                           needToProcesHead);
+                startSinCosHeadIndex += needToProcesHead;
+                localStartAddr += needToProcesHead * this->headDim;
+            }
+
+            if (startSinCosHeadIndex < endHead) {
+                uint64_t startSinCosIndex = startSinCosHeadIndex / this->headNumQ;
+                uint64_t endSinCosIndex = (endHead + this->headNumQ - 1) / this->headNumQ;
+                for (uint32_t index = startSinCosIndex; index < endSinCosIndex; ++index) {
+                    uint32_t repeatNum =
+                        index == endSinCosIndex - 1 ?
endHead - index * this->headNumQ : this->headNumQ; + CopyCosSin(inputCos, inputSin, localStartAddr, index * this->headDim, repeatNum); + localStartAddr += this->headDim * this->headNumQ; + } + } + AscendC::LocalTensor inputCosCastFP32 = + buf.GetBuffer(dataSizeFp32 * 2 + dataSizeFp16 * 3); + AscendC::LocalTensor inputSinCastFP32 = + buf.GetBuffer(dataSizeFp32 * 3 + dataSizeFp16 * 3); + AscendC::Cast(inputCosCastFP32, inputCos, AscendC::RoundMode::CAST_NONE, loopN * this->headDim); + AscendC::Cast(inputSinCastFP32, inputSin, AscendC::RoundMode::CAST_NONE, loopN * this->headDim); + AscendC::PipeBarrier(); + + uint32_t repeatTime = this->headDim * loopN; + AscendC::Mul(inputQCastFP32, inputCosCastFP32, inputQCastFP32, repeatTime); + AscendC::Mul(reverseQ, negLocal, reverseQ, repeatTime); + AscendC::PipeBarrier(); + + AscendC::Mul(reverseQ, inputSinCastFP32, reverseQ, repeatTime); + AscendC::PipeBarrier(); + + AscendC::Add(inputQCastFP32, reverseQ, inputQCastFP32, repeatTime); + AscendC::PipeBarrier(); + + AscendC::Cast(inputQ, inputQCastFP32, AscendC::RoundMode::CAST_RINT, loopN * this->headDim); + AscendC::PipeBarrier(); + uint64_t outQOffset = startHead * outLineOffset + this->concatSize; + uint64_t outQOffset2 = startHead * this->headDim; + SET_FLAG(V, MTE3, EVENT_ID1); + WAIT_FLAG(V, MTE3, EVENT_ID1); + if constexpr (CacheMode == CACHE_MODE_KVCACHE) { + AscendC::DataCopy(this->outRopeConcatGm_[outQOffset], inputQ, {loopN, headBlockLen, 0, concatBlockLen}); + } else { + AscendC::DataCopy(this->outRopeConcatGm2_[outQOffset2], inputQ, loopN * this->headDim); + } + SET_FLAG(MTE3, MTE2, EVENT_ID1); + } + WAIT_FLAG(MTE3, MTE2, EVENT_ID1); + } + // tensor -1 -1 -1 1 1 1 + template + __aicore__ inline void ExpandNeg(const AscendC::LocalTensor &tempBuf, uint32_t headNumTemp) + { + for (uint32_t i = 0; i < this->rotateStride_; ++i) { + tempBuf.SetValue(i, (BUF_TYPE)-1); + tempBuf.SetValue(i + this->rotateStride_, (BUF_TYPE)1); + } + AscendC::SetFlag(EVENT_ID1); + AscendC::WaitFlag(EVENT_ID1); + AscendC::Copy(tempBuf[this->headDim], tempBuf, this->headDim, headNumTemp - 1, {1, 1, headBlockLenFP32, 0}); + AscendC::PipeBarrier(); + } + + template + __aicore__ inline void CopyQGenReverseQ(const AscendC::LocalTensor &tempBufQ, + const AscendC::LocalTensor &tempBufQCast, + const AscendC::LocalTensor &tempBufRverseQ, uint64_t qOffset, + uint16_t loopN) + { + // move in Q + WAIT_FLAG(MTE3, MTE2, EVENT_ID1); + AscendC::DataCopy(tempBufQ, this->qGm_[qOffset], {loopN, headBlockLen, 128 / 16, 0}); + SET_FLAG(MTE2, V, EVENT_ID1); + WAIT_FLAG(MTE2, V, EVENT_ID1); + // cast fp32 + AscendC::Cast(tempBufQCast, tempBufQ, AscendC::RoundMode::CAST_NONE, loopN * this->headDim); + AscendC::PipeBarrier(); + // move out reverseQ + AscendC::DataCopy(tempBufRverseQ, tempBufQCast[this->rotateStride_], {loopN, rotaryLen, rotaryLen, rotaryLen}); + AscendC::DataCopy(tempBufRverseQ[this->rotateStride_], tempBufQCast, {loopN, rotaryLen, rotaryLen, rotaryLen}); + AscendC::PipeBarrier(); + } + + template + __aicore__ inline void CopyCosSin(const AscendC::LocalTensor &tempBufCos, + const AscendC::LocalTensor &tempBufSin, uint64_t localStartAddr, + uint64_t gmStartAddr, uint64_t repeatNum) + { + AscendC::DataCopy(tempBufCos[localStartAddr], this->cosGm_[gmStartAddr], {1, headBlockLen, 0, 0}); + AscendC::DataCopy(tempBufSin[localStartAddr], this->sinGm_[gmStartAddr], {1, headBlockLen, 0, 0}); + SET_FLAG(MTE2, V, EVENT_ID1); + WAIT_FLAG(MTE2, V, EVENT_ID1); + AscendC::Copy(tempBufCos[localStartAddr + this->headDim], 
tempBufCos[localStartAddr], this->headDim, + repeatNum - 1, {1, 1, headBlockLen, 0}); + AscendC::Copy(tempBufSin[localStartAddr + this->headDim], tempBufSin[localStartAddr], this->headDim, + repeatNum - 1, {1, 1, headBlockLen, 0}); + AscendC::PipeBarrier(); + } + +private: + AsdopsBuffer buf; + + AscendC::GlobalTensor qGm_; + AscendC::GlobalTensor cosGm_; + AscendC::GlobalTensor sinGm_; + AscendC::GlobalTensor outRopeConcatGm_; + AscendC::GlobalTensor outRopeConcatGm2_; + + uint32_t repeatSize_{0}; + uint32_t rotateStride_{0}; // this->headDim / rope conf + uint32_t headDim; + uint32_t headNumQ; + uint32_t rotaryCoeff; + uint32_t ntokens; + uint32_t realCore; + uint32_t nlCoreRun; + uint32_t lCoreRun; + uint32_t maxNPerLoopForUb; + uint32_t preCoreLoopTime; + uint32_t preCoreLoopNLast; + uint32_t lastCoreLoopTime; + uint32_t lastCoreLoopNLast; + uint32_t concatSize; + uint32_t blockIdx_; + uint32_t loopTime{0}; + uint32_t lastLoopN{0}; + + uint32_t dataSizeFp32; + uint32_t dataSizeFp16; + uint16_t headBlockLen{0}; + uint16_t headBlockLenFP32{0}; + uint16_t rotaryLen{0}; + uint16_t concatBlockLen{0}; + uint64_t outLineOffset{0}; +}; + +__aicore__ inline void ReduceSumCustom(const AscendC::LocalTensor &dst_local, + const AscendC::LocalTensor &src_local, + const AscendC::LocalTensor &work_local, int32_t count) +{ +#ifdef __DAV_C220_VEC__ + uint64_t mask = NUM_PER_REP_FP32; + int32_t repeatTimes = count / NUM_PER_REP_FP32; + int32_t tailCount = count % NUM_PER_REP_FP32; + int32_t bodyCount = repeatTimes * NUM_PER_REP_FP32; + AscendC::BinaryRepeatParams repeatParams; + repeatParams.src0RepStride = AscendC::ONE_REPEAT_BYTE_SIZE / AscendC::ONE_BLK_SIZE; + repeatParams.src0BlkStride = 1; + repeatParams.src1RepStride = 0; + repeatParams.src1BlkStride = 1; + repeatParams.dstRepStride = 0; + repeatParams.dstBlkStride = 1; + Duplicate(work_local, ZERO, NUM_PER_REP_FP32); + AscendC::PipeBarrier(); + if (likely(repeatTimes > 0)) { + Add(work_local, src_local, work_local, mask, repeatTimes, repeatParams); + AscendC::PipeBarrier(); + } + if (unlikely(tailCount != 0)) { + Add(work_local, src_local[bodyCount], work_local, tailCount, 1, repeatParams); + AscendC::PipeBarrier(); + } + AscendC::AscendCUtils::SetMask(NUM_PER_REP_FP32); + cadd_v(dst_local, // dst + work_local, // src + 1, // repeat + 0, // dstRepeatStride + 1, // srcBlockStride + 0); // srcRepeatStride + AscendC::PipeBarrier(); +#endif +} + +template +class RmsNormQuant +{ +public: + __aicore__ inline RmsNormQuant() {} + + __aicore__ inline void Init(AscendC::GlobalTensor &gammaGmTensor, AscendC::GlobalTensor &betaGmTensor, + GM_ADDR gmInput, GM_ADDR gmOutput, uint32_t stride, + uint32_t num_col, float avg_factor, uint64_t gm_offset, uint64_t gm_out_offset, + uint32_t row_work_, const MlaTilingData &mlaParams_) + { + this->gammaGmTensor = gammaGmTensor; + this->betaGmTensor = betaGmTensor; + inputGmTensor.SetGlobalBuffer(reinterpret_cast<__gm__ T *>(gmInput)); + outputGmTensor.SetGlobalBuffer(reinterpret_cast<__gm__ T *>(gmOutput)); + + num_col_ = num_col; + avg_factor_ = avg_factor; + epsilon_ = 1e-6; + this->num_row_ = mlaParams_.n; + this->row_work = row_work; + this->row_work_ = row_work_; + gm_offset_ = gm_offset; + gm_out_offset_ = gm_out_offset; + num_col_align_int8 = (num_col_ + REPEAT_TIME_256 - 1) / REPEAT_TIME_256 * REPEAT_TIME_256; + num_col_align_f16 = (num_col_ + REPEAT_TIME_128 - 1) / REPEAT_TIME_128 * REPEAT_TIME_128; + num_col_align_f32 = (num_col_ + REPEAT_TIME_64 - 1) / REPEAT_TIME_64 * REPEAT_TIME_64; + input_stride_ = stride; 
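+        // The *_withStride_* counts below cover only the (num_col_ - input_stride_) columns that are
+        // actually normalized (the leading input_stride_ columns of each row are skipped on load),
+        // rounded up to the 256/128/64-element repeat sizes used by the vector instructions.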
+ + num_col_align_withStride_int8 = + (num_col_ - input_stride_ + REPEAT_TIME_256 - 1) / REPEAT_TIME_256 * REPEAT_TIME_256; + num_col_align_withStride_fp16 = + (num_col_ - input_stride_ + REPEAT_TIME_128 - 1) / REPEAT_TIME_128 * REPEAT_TIME_128; + num_col_align_withStride_fp32 = + (num_col_ - input_stride_ + REPEAT_TIME_64 - 1) / REPEAT_TIME_64 * REPEAT_TIME_64; + } + + __aicore__ inline void Launch(const AscendC::LocalTensor &dstTensor, + const AscendC::LocalTensor &srcTensor, const AscendC::LocalTensor &gammaTensor, + const AscendC::LocalTensor &betaTensor, + const AscendC::LocalTensor &res1Tensor, + const AscendC::LocalTensor &res3Tensor) + { + this->dstTensor = dstTensor; + this->srcTensor = srcTensor; + this->gammaTensor = gammaTensor; + this->betaTensor = betaTensor; + this->fp32_xy = res1Tensor; + this->buf = res3Tensor; + + AscendC::LocalTensor g = buf[OFFSET_GAMMA * num_col_align_withStride_fp32]; // 0 + AscendC::LocalTensor sqx = buf[OFFSET_SQX * num_col_align_withStride_fp32]; // 1 + AscendC::LocalTensor work = buf[OFFSET_SUM * num_col_align_withStride_fp32]; // 2 + AscendC::LocalTensor sum = buf[OFFSET_WORKSPACE_BF16 * num_col_align_withStride_fp32]; // 4 + + AscendC::DataCopy(gammaTensor, gammaGmTensor, + AscendC::DataCopyParams(1, (num_col_ - input_stride_) / BLOCK_SIZE_16, 0, 0)); + AscendC::DataCopy(betaTensor, betaGmTensor, + AscendC::DataCopyParams(1, (num_col_ - input_stride_) / BLOCK_SIZE_16, 0, 0)); + SET_FLAG(MTE2, V, EVENT_ID1); + + AscendC::SetFlag(EVENT_ID0); + AscendC::WaitFlag(EVENT_ID0); + + WAIT_FLAG(MTE2, V, EVENT_ID1); + Cast(buf[OFFSET_GAMMA * num_col_align_withStride_fp32], gammaTensor, AscendC::RoundMode::CAST_NONE, + REPEAT_TIME_64, num_col_align_withStride_fp32 / REPEAT_TIME_64, + {1, 1, AscendC::DEFAULT_REPEAT_STRIDE, AscendC::DEFAULT_REPEAT_STRIDE / OFFSET_SUM}); + AscendC::PipeBarrier(); + uint64_t pid = 0; + SET_FLAG(MTE3, MTE2, EVENT_ID0); + while (pid < row_work_) { + uint64_t offset = pid * num_col_; + uint64_t outOffset = pid * (num_col_ - input_stride_); + WAIT_FLAG(MTE3, MTE2, EVENT_ID0); + AscendC::DataCopy(srcTensor, inputGmTensor[gm_offset_ + offset], + AscendC::DataCopyParams(1, num_col_ / BLOCK_SIZE_16, 0, 0)); + SET_FLAG(MTE2, V, EVENT_ID0); + WAIT_FLAG(MTE2, V, EVENT_ID0); + + Cast(fp32_xy, srcTensor[input_stride_], AscendC::RoundMode::CAST_NONE, REPEAT_TIME_64, + num_col_align_withStride_fp32 / REPEAT_TIME_64, + {1, 1, AscendC::DEFAULT_REPEAT_STRIDE, AscendC::DEFAULT_REPEAT_STRIDE / OFFSET_SUM}); + AscendC::PipeBarrier(); + Mul(sqx, fp32_xy, fp32_xy, REPEAT_TIME_64, num_col_align_withStride_fp32 / REPEAT_TIME_64, + {1, 1, 1, AscendC::DEFAULT_REPEAT_STRIDE, AscendC::DEFAULT_REPEAT_STRIDE, + AscendC::DEFAULT_REPEAT_STRIDE}); + AscendC::PipeBarrier(); + Muls(sqx, sqx, avg_factor_, num_col_ - input_stride_); + AscendC::PipeBarrier(); + ReduceSumCustom(sum, sqx, work, num_col_ - input_stride_); + AscendC::PipeBarrier(); + Adds(sum, sum, epsilon_, 1); + AscendC::PipeBarrier(); + Sqrt(sum, sum, 1); + SET_FLAG(V, S, EVENT_ID0); + WAIT_FLAG(V, S, EVENT_ID0); + float factor = 1 / sum.GetValue(0); + SET_FLAG(S, V, EVENT_ID0); + WAIT_FLAG(S, V, EVENT_ID0); + Muls(fp32_xy, fp32_xy, factor, REPEAT_TIME_64, num_col_align_withStride_fp32 / REPEAT_TIME_64, + {1, 1, AscendC::DEFAULT_REPEAT_STRIDE, AscendC::DEFAULT_REPEAT_STRIDE}); + AscendC::PipeBarrier(); + Mul(fp32_xy, fp32_xy, g, REPEAT_TIME_64, num_col_align_withStride_fp32 / REPEAT_TIME_64, + {1, 1, 1, AscendC::DEFAULT_REPEAT_STRIDE, AscendC::DEFAULT_REPEAT_STRIDE, + 
AscendC::DEFAULT_REPEAT_STRIDE}); + AscendC::PipeBarrier(); + + AscendC::LocalTensor tmpfp16 = + buf.ReinterpretCast()[OFFSET_SUM * num_col_align_withStride_fp32 * 2]; + CastFrom32To16(tmpfp16, fp32_xy, num_col_align_withStride_fp32); + AscendC::PipeBarrier(); + SET_FLAG(V, MTE3, EVENT_ID0); + WAIT_FLAG(V, MTE3, EVENT_ID0); + AscendC::DataCopy(outputGmTensor[gm_out_offset_ + outOffset], tmpfp16, + AscendC::DataCopyParams(1, (num_col_ - input_stride_) / 16, 0, 0)); + SET_FLAG(MTE3, MTE2, EVENT_ID0); + ++pid; + } + WAIT_FLAG(MTE3, MTE2, EVENT_ID0); + } + +private: + AscendC::LocalTensor dstTensor; + AscendC::LocalTensor srcTensor; + AscendC::LocalTensor gammaTensor; + AscendC::LocalTensor betaTensor; + AscendC::LocalTensor fp32_xy; + AscendC::LocalTensor buf; + + AscendC::GlobalTensor gammaGmTensor; + AscendC::GlobalTensor betaGmTensor; + AscendC::GlobalTensor inputGmTensor; + AscendC::GlobalTensor outputGmTensor; + + uint32_t num_col_{0}; + uint32_t num_row_{0}; + uint32_t row_work_{0}; + uint32_t row_work{0}; + uint32_t row_step_{0}; + uint32_t row_tail_{0}; + uint64_t gm_offset_{0}; + uint64_t gm_out_offset_{0}; + float avg_factor_{1.0}; + int32_t input_stride_{0}; + float epsilon_{1e-12f}; + uint32_t num_col_align_int8{0}; + uint32_t num_col_align_f16{0}; + uint32_t num_col_align_f32{0}; + uint32_t num_col_align_f32_long{0}; + uint32_t num_col_align_withStride_int8{0}; + uint32_t num_col_align_withStride_fp16{0}; + uint32_t num_col_align_withStride_fp32{0}; + uint32_t num_col_temp; + uint32_t num_slice_{0}; + uint32_t tail_size_{0}; + uint32_t tail_copy_{0}; +}; + +#ifdef __DAV_C220_CUBE__ +struct MatCoord { + uint64_t m{0}; + uint64_t k{0}; + uint64_t n{0}; +}; + +template +class PpMatmulEinSum +{ + using AccumDtype = float; + + template + using CopyGmToCbuf = gm_to_l1; + using LoadCbufToCa = l1_to_l0_a; + using LoadCbufToCb = l1_to_l0_b; + using Mad = mmad; + using CopyCcToGm = l0c_to_gm; + + static constexpr uint32_t L0_PINGPONG_BUFFER_LEN = 16384; + static constexpr uint32_t L1_PINGPONG_BUFFER_LEN = 131072; + static constexpr uint32_t CONST_16 = 16; + static constexpr uint32_t CONST_256 = 256; + +public: + __aicore__ explicit PpMatmulEinSum(){}; + + __aicore__ __force_inline__ void Init(GM_ADDR gmA, GM_ADDR gmB, GM_ADDR gmC, PpMatmulTilingData &tilingdata) + { +#ifdef __DAV_C220_CUBE__ + batch_size = tilingdata.numBatch; + m = tilingdata.m; + k = tilingdata.k; + n = tilingdata.n; + m0 = tilingdata.m0; + k0 = tilingdata.k0; + n0 = tilingdata.n0; + tdim.m = tilingdata.mLoop; + tdim.k = tilingdata.kLoop; + tdim.n = tilingdata.nLoop; + core_loop = tilingdata.coreLoop; + swizzle_cnt = tilingdata.swizzleCount; + num_core = tilingdata.blockDim; + core_idx = AscendC::GetBlockIdx(); + ping_flag = 1; + + gm_a.SetGlobalBuffer(reinterpret_cast<__gm__ InDtype *>(gmA)); + gm_b.SetGlobalBuffer(reinterpret_cast<__gm__ InDtype *>(gmB)); + gm_c.SetGlobalBuffer(reinterpret_cast<__gm__ OutDtype *>(gmC)); + + AsdopsBuffer buf; + l1_base_a = buf.GetBuffer(0); + l1_base_b = buf.GetBuffer(RoundUp(m0 * k0 * sizeof(InDtype))); + l0a_base = buf.GetBuffer(0); + l0b_base = buf.GetBuffer(0); +#endif + return; + } + + __aicore__ __force_inline__ void Process() + { +#ifdef __DAV_C220_CUBE__ + if (block_idx >= num_core) { + return; + } + using LocalTensor = AscendC::LocalTensor; + + SET_FLAG(MTE1, MTE2, EVENT_ID0); + SET_FLAG(MTE1, MTE2, EVENT_ID1); + SET_FLAG(MTE1, MTE2, EVENT_ID2); + SET_FLAG(MTE1, MTE2, EVENT_ID3); + SET_FLAG(FIX, M, EVENT_ID0); + SET_FLAG(M, MTE1, EVENT_ID0); + SET_FLAG(M, MTE1, EVENT_ID1); + + 
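+        // Each core walks its share of (batch, m, n) output tiles. A and B tiles are double-buffered in L1
+        // (ping_flag), the next K slice is prefetched while the current one is loaded into L0A/L0B, and
+        // Mad accumulates the partial products in L0C before the result is copied back to GM.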
for (uint64_t loop_idx = core_idx; loop_idx < core_loop; loop_idx += num_core) { + uint64_t batch_idx = loop_idx / tdim.n / tdim.m; + MatCoord tidx{0}; + GetBaseBlockIdx(loop_idx, tidx); + uint64_t offset_c = 0; + if constexpr (ein) { + offset_c = tidx.m * m0 * batch_size * (n + splitGapC) + batch_idx * (n + splitGapC) + tidx.n * n0; + } else { + offset_c = batch_idx * m * n + tidx.m * m0 * n + tidx.n * n0; + } + uint64_t m_actual = (tidx.m == (tdim.m - 1)) ? (m - tidx.m * m0) : m0; + uint64_t n_actual = (tidx.n == (tdim.n - 1)) ? (n - tidx.n * n0) : n0; + uint64_t m_round = RoundUp(m_actual); + uint64_t n_round = RoundUp(n_actual); + uint64_t mn_max = m_round > n_round ? m_round : n_round; + uint64_t k_part_len = L0_PINGPONG_BUFFER_LEN / mn_max / CONST_16 * CONST_16; + uint64_t shuffle_k = en_shuffle_k ? (core_idx % tdim.k) : 0; + + uint64_t k_actual = (shuffle_k == tdim.k - 1) ? k - shuffle_k * k0 : k0; + uint64_t k_round = (k_actual + CONST_16 - 1) / CONST_16 * CONST_16; + + LocalTensor l1_buf_a = ping_flag ? l1_base_a : l1_base_a[L1_PINGPONG_BUFFER_LEN]; + LocalTensor l1_buf_b = ping_flag ? l1_base_b : l1_base_b[L1_PINGPONG_BUFFER_LEN]; + event_t event_id = ping_flag ? EVENT_ID0 : EVENT_ID1; + + if (loop_idx == core_idx) { + // Copy A from gm to l1 buffer + uint64_t offset_a = GetOffsetA(batch_idx, tidx.m, shuffle_k); + WAIT_FLAG(MTE1, MTE2, event_id); + CopyTileA(l1_buf_a, gm_a[offset_a], m_actual, m_round, k_actual, k_round); + SET_FLAG(MTE2, MTE1, event_id); + + // Copy B from gm to l1 buffer + uint64_t offset_b = GetOffsetB(batch_idx, shuffle_k, tidx.n); + WAIT_FLAG(MTE1, MTE2, event_id + 2); + CopyTileB(l1_buf_b, gm_b[offset_b], k_actual, k_round, n_actual, n_round); + SET_FLAG(MTE2, MTE1, event_id + 2); + } + + for (tidx.k = 0; tidx.k < tdim.k; ++tidx.k) { + shuffle_k = en_shuffle_k ? (tidx.k + core_idx) % tdim.k : tidx.k; + uint64_t k_actual = (shuffle_k == (tdim.k - 1)) ? (k - shuffle_k * k0) : k0; + uint64_t k_round = (k_actual + CONST_16 - 1) / CONST_16 * CONST_16; + fdim.k = (k_actual + k_part_len - 1) / k_part_len; + + LocalTensor l1_buf_a = ping_flag ? l1_base_a : l1_base_a[L1_PINGPONG_BUFFER_LEN]; + LocalTensor l1_buf_b = ping_flag ? l1_base_b : l1_base_b[L1_PINGPONG_BUFFER_LEN]; + auto event_id = ping_flag ? EVENT_ID0 : EVENT_ID1; + + if (tidx.k < tdim.k - 1) { + uint64_t shuffle_k_next = en_shuffle_k ? (core_idx + tidx.k + 1) % tdim.k : (tidx.k + 1); + uint64_t offset_a_next = GetOffsetA(batch_idx, tidx.m, shuffle_k_next); + uint64_t offset_b_next = GetOffsetB(batch_idx, shuffle_k_next, tidx.n); + + uint64_t k_actual_next = (shuffle_k_next == (tdim.k - 1)) ? (k - shuffle_k_next * k0) : k0; + uint64_t k_round_next = (k_actual_next + CONST_16 - 1) / CONST_16 * CONST_16; + + LocalTensor l1_buf_a_next = (1 - ping_flag) ? l1_base_a : l1_base_a[L1_PINGPONG_BUFFER_LEN]; + LocalTensor l1_buf_b_next = (1 - ping_flag) ? l1_base_b : l1_base_b[L1_PINGPONG_BUFFER_LEN]; + event_t event_id_next = (1 - ping_flag) ? EVENT_ID0 : EVENT_ID1; + + // Preload A from gm to l1 buffer. + WAIT_FLAG(MTE1, MTE2, event_id_next); + CopyTileA(l1_buf_a_next, gm_a[offset_a_next], m_actual, m_round, k_actual_next, k_round_next); + SET_FLAG(MTE2, MTE1, event_id_next); + + // Preload B from gm to l1 buffer. 
+ WAIT_FLAG(MTE1, MTE2, event_id_next + 2); + CopyTileB(l1_buf_b_next, gm_b[offset_b_next], k_actual_next, k_round_next, n_actual, n_round); + SET_FLAG(MTE2, MTE1, event_id_next + 2); + } + + if (tidx.k == tdim.k - 1 && loop_idx + num_core < core_loop) { + uint64_t b_idx_next = (loop_idx + num_core) / tdim.n / tdim.m; + MatCoord tidx{0}; + GetBaseBlockIdx(loop_idx + num_core, tidx); + uint64_t shuffle_k_next = en_shuffle_k ? (core_idx % tdim.k) : 0; + uint64_t m_actual_next = (tidx.m == (tdim.m - 1)) ? (m - tidx.m * m0) : m0; + uint64_t n_actual_next = (tidx.n == (tdim.n - 1)) ? (n - tidx.n * n0) : n0; + uint64_t m_round_next = (m_actual_next + CONST_16 - 1) / CONST_16 * CONST_16; + uint64_t n_round_next = (n_actual_next + CONST_16 - 1) / CONST_16 * CONST_16; + uint64_t k_actual_next = (shuffle_k_next == (tdim.k - 1)) ? (k - shuffle_k_next * k0) : k0; + uint64_t k_round_next = (k_actual_next + CONST_16 - 1) / CONST_16 * CONST_16; + uint64_t offset_a_next = GetOffsetA(b_idx_next, tidx.m, shuffle_k_next); + uint64_t offset_b_next = GetOffsetB(b_idx_next, shuffle_k_next, tidx.n); + + LocalTensor l1_buf_a_next = (1 - ping_flag) ? l1_base_a : l1_base_a[L1_PINGPONG_BUFFER_LEN]; + LocalTensor l1_buf_b_next = (1 - ping_flag) ? l1_base_b : l1_base_b[L1_PINGPONG_BUFFER_LEN]; + event_t event_id_next = (1 - ping_flag) ? EVENT_ID0 : EVENT_ID1; + + // Preload A from gm to l1 buffer. + WAIT_FLAG(MTE1, MTE2, event_id_next); + CopyTileA(l1_buf_a_next, gm_a[offset_a_next], m_actual_next, m_round_next, k_actual_next, + k_round_next); + SET_FLAG(MTE2, MTE1, event_id_next); + + // Preload B from gm to l1 buffer. + WAIT_FLAG(MTE1, MTE2, event_id_next + 2); + CopyTileB(l1_buf_b_next, gm_b[offset_b_next], k_actual_next, k_round_next, n_actual_next, + n_round_next); + SET_FLAG(MTE2, MTE1, event_id_next + 2); + } + + MatCoord fidx{0}; + for (fidx.k = 0; fidx.k < fdim.k; ++fidx.k) { + uint32_t k0_round = (fidx.k < fdim.k - 1) ? k_part_len : k_round - fidx.k * k_part_len; + uint32_t k0_actual = (fidx.k < fdim.k - 1) ? k_part_len : k_actual - fidx.k * k_part_len; + + auto mte1_mad_ping_flag = 1 - fidx.k % 2; + auto mte1_mad_event_id = mte1_mad_ping_flag ? 
EVENT_ID0 : EVENT_ID1; + LocalTensor l0a_buf = l0a_base[(fidx.k & 0b1) * L0_PINGPONG_BUFFER_LEN]; + LocalTensor l0b_buf = l0b_base[(fidx.k & 0b1) * L0_PINGPONG_BUFFER_LEN]; + + // *** load matrix A from L1 to L0A + if (fidx.k == 0) { + WAIT_FLAG(MTE2, MTE1, event_id); + } + WAIT_FLAG(M, MTE1, mte1_mad_event_id); + if ((m == 1) || (m_actual == 1)) { + l1_to_l0_a( + l0a_buf, // dst + l1_buf_a[fidx.k * k_part_len], // src + 0, // mTileCeil + CeilDiv(k0_round), // kPartCeil + 0, // mSrcStride + 1, // kSrcStride + 0, // mDstStride + 0); // kDstStride + } else { + LoadCbufToCa(l0a_buf, // l0Tensor + l1_buf_a[fidx.k * k_part_len * m_round], // l1Tensor + m_round, // mTileCeil + k0_round, // kPartCeil + 1, // mSrcStride + m_round / CONST_16, // kSrcStride + k0_round / CONST_16, // mDstStride + 1); // kDstStride + } + if (fidx.k == fdim.k - 1) { + SET_FLAG(MTE1, MTE2, event_id); + } + + // *** load matrix B from L1 to L0B + if (fidx.k == 0) { + WAIT_FLAG(MTE2, MTE1, event_id + 2); + } + if constexpr (transB) { + LoadCbufToCb(l0b_buf, // l0Tensor + l1_buf_b[fidx.k * k_part_len * n_round], // l1Tensor + n_round, // nTileCeil + k0_round, // kPartCeil + 1, // nSrcStride + n_round / CONST_16, // kSrcStride + 1, // nDstStride + k0_round / CONST_16); // kDstStride + } else { + LoadCbufToCb(l0b_buf, // l0Tensor + l1_buf_b[fidx.k * k_part_len * CONST_16], // l1Tensor + n_round, // nTileCeil + k0_round, // kPartCeil + k_round / CONST_16, // nSrcStride + 1, // kSrcStride + 1, // nDstStride + n_round / CONST_16); // kDstStride + } + if (fidx.k == fdim.k - 1) { + SET_FLAG(MTE1, MTE2, event_id + 2); + } + + SET_FLAG(MTE1, M, mte1_mad_event_id); + WAIT_FLAG(MTE1, M, mte1_mad_event_id); + + bool init_c = (tidx.k == 0 && fidx.k == 0); + if (init_c) { + WAIT_FLAG(FIX, M, EVENT_ID0); + } + + Mad(l0c_buf, // c + l0a_buf, // a + l0b_buf, // b + m_actual, // mTileActual + n_actual, // nTileActual + k0_actual, // kTileActual + init_c); // initC + + PIPE_BARRIER(M); + SET_FLAG(M, MTE1, mte1_mad_event_id); + } + + ping_flag = 1 - ping_flag; + } + + SET_FLAG(M, FIX, EVENT_ID0); + WAIT_FLAG(M, FIX, EVENT_ID0); + + // copy from L0C to gm + if constexpr (ein) { + CopyCcToGm(gm_c[offset_c], // dst + l0c_buf, // src + m_actual, // mTileActual + n_actual, // nTileActual + m_round, // mTileCeil + (n + splitGapC) * batch_size); // nActual + + } else { + CopyCcToGm(gm_c[offset_c], // dst + l0c_buf, // src + m_actual, // MSize + n_actual, // NSize + m_round, // srcStride + n); // dstStride_dst_D + } + SET_FLAG(FIX, M, EVENT_ID0); + } + + WAIT_FLAG(M, MTE1, EVENT_ID0); + WAIT_FLAG(M, MTE1, EVENT_ID1); + WAIT_FLAG(MTE1, MTE2, EVENT_ID0); + WAIT_FLAG(MTE1, MTE2, EVENT_ID1); + WAIT_FLAG(MTE1, MTE2, EVENT_ID2); + WAIT_FLAG(MTE1, MTE2, EVENT_ID3); + WAIT_FLAG(FIX, M, EVENT_ID0); +#endif + } + +private: + __aicore__ __force_inline__ void GetBaseBlockIdx(uint64_t index, MatCoord &tidx) + { + uint64_t in_batch_idx = index % (tdim.m * tdim.n); + if constexpr (swizzleDirect == 0) { // Zn + uint64_t tile_block_loop = (tdim.m + swizzle_cnt - 1) / swizzle_cnt; + uint64_t tile_block_idx = in_batch_idx / (swizzle_cnt * tdim.n); + uint64_t in_tile_block_idx = in_batch_idx % (swizzle_cnt * tdim.n); + + uint64_t n_row = swizzle_cnt; + if (tile_block_idx == tile_block_loop - 1) { + n_row = tdim.m - swizzle_cnt * tile_block_idx; + } + tidx.m = tile_block_idx * swizzle_cnt + in_tile_block_idx % n_row; + tidx.n = in_tile_block_idx / n_row; + if (tile_block_idx % 2 != 0) { + tidx.n = tdim.n - tidx.n - 1; + } + } else if constexpr (swizzleDirect == 1) { // 
Nz + uint64_t tile_block_loop = (tdim.n + swizzle_cnt - 1) / swizzle_cnt; + uint64_t tile_block_idx = in_batch_idx / (swizzle_cnt * tdim.m); + uint64_t in_tile_block_idx = in_batch_idx % (swizzle_cnt * tdim.m); + + uint64_t n_col = swizzle_cnt; + if (tile_block_idx == tile_block_loop - 1) { + n_col = tdim.n - swizzle_cnt * tile_block_idx; + } + tidx.m = in_tile_block_idx / n_col; + tidx.n = tile_block_idx * swizzle_cnt + in_tile_block_idx % n_col; + if (tile_block_idx % 2 != 0) { + tidx.m = tdim.m - tidx.m - 1; + } + } + return; + } + + __aicore__ __force_inline__ uint64_t GetOffsetA(const uint64_t bIdx, const uint64_t mIdx, const uint64_t kIdx) + { + if constexpr (ein) { + return mIdx * m0 * batch_size * (k + splitGapA) + bIdx * (k + splitGapA) + kIdx * k0; + } else { + return bIdx * m * k + mIdx * m0 * k + kIdx * k0; + } + } + + __aicore__ __force_inline__ uint64_t GetOffsetB(const uint64_t bIdx, const uint64_t kIdx, const uint64_t nIdx) + { + if constexpr (formatB == DataFormat::ND) { + if constexpr (transB) { + return bIdx * k * n + nIdx * n0 * k + kIdx * k0; + } else { + return bIdx * k * n + kIdx * k0 * n + nIdx * n0; + } + } else { + if constexpr (transB) { + return bIdx * RoundUp(n) * RoundUp(k) + kIdx * k0 * RoundUp(n) + + nIdx * n0 * CONST_16; + } else { + return bIdx * RoundUp(k) * RoundUp(n) + nIdx * n0 * RoundUp(k) + + kIdx * k0 * CONST_16; + } + } + } + + __aicore__ __force_inline__ void CopyTileA(AscendC::LocalTensor &dstTensor, + const AscendC::GlobalTensor &srcTensor, const uint64_t m_actual, + const uint64_t m_round, const uint64_t k_actual, const uint64_t k_round) + { + if ((m == 1) || (m_actual == 1)) { + CopyGmToCbuf(dstTensor, // dst + srcTensor, // src + 1, // nTileActual + CONST_16, // nTileCeil + 1, // nVal + k_actual, // kTileActual + k_round, // kTileCeil + k); // dVal + } else { + if constexpr (ein) { + CopyGmToCbuf(dstTensor, // dst + srcTensor, // src + m_actual, // nTileActual + m_round, // nTileCeil + m, // nVal + k_actual, // dTileActual + k_round, // dTileCeil + (k + splitGapA) * batch_size); // dVal + } else { + CopyGmToCbuf(dstTensor, // dst + srcTensor, // src + m_actual, // nTileActual + m_round, // nTileCeil + m, // nVal + k_actual, // dTileActual + k_round, // dTileCeil + k); // dVal + } + } + } + + __aicore__ __force_inline__ void CopyTileB(AscendC::LocalTensor &dstTensor, + const AscendC::GlobalTensor &srcTensor, const uint64_t k_actual, + const uint64_t k_round, const uint64_t n_actual, const uint64_t n_round) + { + if constexpr (formatB == DataFormat::ND) { + if constexpr (transB) { + CopyGmToCbuf(dstTensor, // dst + srcTensor, // src + n_actual, // nTileActual + n_round, // nTileCeil + n, // nVal + k_actual, // dTileActual + k_round, // dTileCeil + k); // dVal + } else { + CopyGmToCbuf(dstTensor, // dst + srcTensor, // src + k_actual, // nTileActual + k_round, // nTileCeil + k, // nVal + n_actual, // dTileActual + n_round, // dTileCeil + n); // dVal + } + } else { + if constexpr (transB) { + CopyGmToCbuf(dstTensor, // dst + srcTensor, // src + n_actual, // nTileActual + n_round, // nTileCeil + RoundUp(n), // nVal + k_actual, // dTileActual + k_round, // dTileCeil + RoundUp(k)); // dVal + } else { + CopyGmToCbuf(dstTensor, // dst + srcTensor, // src + k_actual, // nTileActual + k_round, // nTileCeil + RoundUp(k), // nVal + n_actual, // dTileActual + n_round, // dTileCeil + RoundUp(n)); // dVal + } + } + } + +private: + AscendC::GlobalTensor gm_a; + AscendC::GlobalTensor gm_b; + AscendC::GlobalTensor gm_c; + AscendC::LocalTensor l1_base_a; + 
AscendC::LocalTensor l1_base_b; + AscendC::LocalTensor l0a_base; + AscendC::LocalTensor l0b_base; + AscendC::LocalTensor l0c_buf; + + uint32_t num_core{0}; + uint32_t batch_size{0}; + uint32_t m{0}; + uint32_t k{0}; + uint32_t n{0}; + uint32_t m0{0}; + uint32_t k0{0}; + uint32_t n0{0}; + MatCoord tdim{0}; + MatCoord fdim{0}; + uint32_t core_loop{0}; + uint32_t swizzle_cnt{1}; + uint32_t core_idx{0}; + uint32_t en_shuffle_k{0}; + uint32_t ping_flag{0}; +}; +#endif + +template +class MLAOperation +{ + static constexpr uint64_t splitGapC = CACHE_MODE == CACHE_MODE_KVCACHE ? CONST_64 : CONST_0; + +public: + __aicore__ inline MLAOperation(const MlaTilingData &mlaParams_, GM_ADDR tilingGm) + { + blockIdx = AscendC::GetBlockIdx(); +#ifdef __DAV_C220_VEC__ + sub_block_idx = static_cast(GetSubBlockidx()); +#endif + vectorBlockIdx = (blockIdx / 2) * 2 + sub_block_idx; + this->n = mlaParams_.n; + this->num_core_ = mlaParams_.rmsNumCore1; + this->num_col_1 = mlaParams_.rmsNumCol1; + this->num_col_2 = mlaParams_.rmsNumCol2; + this->num_row = mlaParams_.n; + this->epsilon_ = 1e-6; + this->mlaParams = mlaParams_; + } + + __aicore__ inline void Init(GM_ADDR hiddenStateGm, GM_ADDR wdqkvGm, GM_ADDR gamma2Gm, + GM_ADDR beta2Gm, GM_ADDR gamma3Gm, + GM_ADDR sin1Gm, GM_ADDR cos1Gm, GM_ADDR sin2Gm, GM_ADDR cos2Gm, GM_ADDR keycacheGm, + GM_ADDR slotMappingGm, GM_ADDR wuqGm, GM_ADDR wukGm, + GM_ADDR qGm, GM_ADDR keycacheOutGm, GM_ADDR qGm2, GM_ADDR keycacheOutGm2, GM_ADDR s1Gm, + GM_ADDR s2Gm, GM_ADDR s3Gm) + { + gamma3GmTensor.SetGlobalBuffer(reinterpret_cast<__gm__ InDtype *>(gamma3Gm)); + sin1GmTensor.SetGlobalBuffer(reinterpret_cast<__gm__ InDtype *>(sin1Gm)); + cos1GmTensor.SetGlobalBuffer(reinterpret_cast<__gm__ InDtype *>(cos1Gm)); + keycacheGmTensor1.SetGlobalBuffer(reinterpret_cast<__gm__ InDtype *>(keycacheOutGm)); + keycacheGmTensor2.SetGlobalBuffer(reinterpret_cast<__gm__ InDtype *>(keycacheOutGm2)); + slotMappingGmTensor.SetGlobalBuffer(reinterpret_cast<__gm__ int32_t *>(slotMappingGm)); + s2GmTensor.SetGlobalBuffer(reinterpret_cast<__gm__ InDtype *>(s2Gm)); + s3GmTensor.SetGlobalBuffer(reinterpret_cast<__gm__ InDtype *>(s3Gm)); + +#ifdef __DAV_C220_CUBE__ + mm_1.Init(hiddenStateGm, wdqkvGm, s3Gm, mlaParams.mm1); + mm_2.Init(s1Gm, wuqGm, s2Gm, mlaParams.mm2); + mm_ein_sum.Init(s2Gm, wukGm, qGm, mlaParams.mm3); +#endif + + hiddenStateGmTensor.SetGlobalBuffer(reinterpret_cast<__gm__ InDtype *>(hiddenStateGm)); + wdqkvGmTensor.SetGlobalBuffer(reinterpret_cast<__gm__ InDtype *>(wdqkvGm)); + gamma2GmTensor.SetGlobalBuffer(reinterpret_cast<__gm__ InDtype *>(gamma2Gm)); + sin2GmTensor.SetGlobalBuffer(reinterpret_cast<__gm__ InDtype *>(sin2Gm)); + cos2GmTensor.SetGlobalBuffer(reinterpret_cast<__gm__ InDtype *>(cos2Gm)); + wuqGmTensor.SetGlobalBuffer(reinterpret_cast<__gm__ InDtype *>(wuqGm)); + wukGmTensor.SetGlobalBuffer(reinterpret_cast<__gm__ InDtype *>(wukGm)); + s1GmTensor.SetGlobalBuffer(reinterpret_cast<__gm__ InDtype *>(s1Gm)); + qGmTensor.SetGlobalBuffer(reinterpret_cast<__gm__ InDtype *>(qGm)); + qGmTensor2.SetGlobalBuffer(reinterpret_cast<__gm__ InDtype *>(qGm2)); + beta2GmTensor.SetGlobalBuffer(reinterpret_cast<__gm__ InDtype *>(beta2Gm)); + +#ifdef __DAV_C220_VEC__ + row_work = (num_row + num_core_ - 1) / num_core_; + row_work_ = 0; + uint32_t need_core = (num_row + row_work - 1) / row_work; + if (vectorBlockIdx < need_core - 1) { + row_work_ = row_work; + } else if (vectorBlockIdx == need_core - 1) { + row_work_ = num_row - (need_core - 1) * row_work; + } else { + row_work_ = 0; + } + 
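+        // Rows are split evenly across the vector cores: each core handles row_work rows, the last active
+        // core takes the remainder, and cores beyond need_core do no RmsNorm work (row_work_ == 0).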
this->splitN = mlaParams.perTaskNum; + rmsNormQuant2.Init(gamma2GmTensor, beta2GmTensor, s3Gm, s1Gm, SPLIT_SIZE_ONE, + num_col_2, 0.000651041666, vectorBlockIdx * static_cast(row_work) * num_col_2, + vectorBlockIdx * static_cast(row_work) * SPLIT_SIZE_TWO, row_work_, mlaParams); + ropeFp16.RopeInit(s2Gm, cos2GmTensor, sin2GmTensor, qGmTensor, qGmTensor2, mlaParams); +#endif + } + + __aicore__ inline void ProcessCube(); + + __aicore__ inline void ProcessVector(); + +private: + constexpr static uint32_t C0_SIZE = 16; + constexpr static uint32_t I8_C0_SIZE = 32; + + template + __aicore__ inline void RmsNormAndRopeConvergence1( + const AscendC::LocalTensor &srcTensor, const AscendC::LocalTensor &gammaTensor, + const AscendC::LocalTensor &sinTensor, const AscendC::LocalTensor &cosTensor, + const AscendC::LocalTensor &slotMappingTensor, const uint32_t sN, + const AscendC::LocalTensor &rmsNormTensor, const AscendC::LocalTensor &gammaFp32, + const AscendC::LocalTensor &ropeKTensor, const AscendC::LocalTensor &ropeKRevertTensor, + const AscendC::LocalTensor &calTensor, const AscendC::LocalTensor &outTmpTensor) + { + int64_t slotMapGmOffset = vectorBlockIdx * row_work; + AscendC::DataCopy(gammaTensor, gamma3GmTensor, SPLIT_RMSNRORM_SIZE_ONE); + SET_FLAG(MTE2, V, EVENT_ID1); + WAIT_FLAG(MTE2, V, EVENT_ID1); + Cast(gammaFp32, gammaTensor, AscendC::RoundMode::CAST_NONE, SPLIT_RMSNRORM_SIZE_ONE); + AscendC::DataCopyPad(slotMappingTensor, slotMappingGmTensor[slotMapGmOffset], + AscendC::DataCopyExtParams(1, sN * sizeof(int32_t), 0, 0, 0), + AscendC::DataCopyPadExtParams(false, 0, 8 - sN % 8, 0)); + SET_FLAG(MTE2, V, EVENT_ID2); + WAIT_FLAG(MTE2, V, EVENT_ID2); + SET_FLAG(MTE2, S, EVENT_ID2); + WAIT_FLAG(MTE2, S, EVENT_ID2); + for (uint64_t loop = 0; loop < sN; ++loop) { + uint64_t offset = vectorBlockIdx * static_cast(row_work) * num_col_2 + loop * MM1_OUT_SIZE; + int64_t slotValue = static_cast(slotMappingTensor.GetValue(loop)); + if (slotValue == -1) { + continue; + } + AscendC::DataCopy(srcTensor, s3GmTensor[offset], + AscendC::DataCopyParams(1, MM1_OUT_SIZE / BLOCK_SIZE_16, 0, 0)); + AscendC::DataCopy(sinTensor, sin1GmTensor[(row_work * vectorBlockIdx + loop) * SPLIT_RMSNRORM_SIZE_TWO], + SPLIT_RMSNRORM_SIZE_TWO); + AscendC::DataCopy(cosTensor, cos1GmTensor[(row_work * vectorBlockIdx + loop) * SPLIT_RMSNRORM_SIZE_TWO], + SPLIT_RMSNRORM_SIZE_TWO); + SET_FLAG(MTE2, V, EVENT_ID0); + // ND + uint64_t cacheStart = static_cast(slotValue) * static_cast(SPLIT_SIZE_ONE); + uint64_t cacheStart1 = static_cast(slotValue) * static_cast(SPLIT_RMSNRORM_SIZE_ONE); + uint64_t cacheStart2 = static_cast(slotValue) * static_cast(SPLIT_RMSNRORM_SIZE_TWO); + // NZ + uint32_t outer_idx = slotValue / 128; + uint32_t inner_idx = slotValue % 128; + + SET_FLAG(S, MTE3, EVENT_ID0); + /* RmsNorm start */ + WAIT_FLAG(MTE2, V, EVENT_ID0); + Cast(rmsNormTensor, srcTensor, AscendC::RoundMode::CAST_NONE, SPLIT_RMSNRORM_SIZE_ONE); + AscendC::PipeBarrier(); + Mul(calTensor, rmsNormTensor, rmsNormTensor, SPLIT_RMSNRORM_SIZE_ONE); + AscendC::PipeBarrier(); + ReduceSumCustom(calTensor[SPLIT_RMSNRORM_SIZE_ONE], calTensor, calTensor[SPLIT_RMSNRORM_SIZE_ONE * 2], + SPLIT_RMSNRORM_SIZE_ONE); + SET_FLAG(V, S, EVENT_ID1); + WAIT_FLAG(V, S, EVENT_ID1); + float rms = sqrt(calTensor.GetValue(SPLIT_RMSNRORM_SIZE_ONE) / SPLIT_RMSNRORM_SIZE_ONE + epsilon_); + SET_FLAG(S, V, EVENT_ID1); + WAIT_FLAG(S, V, EVENT_ID1); + AscendC::PipeBarrier(); + Duplicate(calTensor, rms, SPLIT_RMSNRORM_SIZE_ONE); + AscendC::PipeBarrier(); + Div(calTensor, rmsNormTensor, 
calTensor, SPLIT_RMSNRORM_SIZE_ONE); + AscendC::PipeBarrier(); + Mul(rmsNormTensor, gammaFp32, calTensor, SPLIT_RMSNRORM_SIZE_ONE); + AscendC::PipeBarrier(); + Cast(outTmpTensor, rmsNormTensor, AscendC::RoundMode::CAST_RINT, SPLIT_RMSNRORM_SIZE_ONE); + /* RmsNorm end */ + /* Rope K start */ + uint64_t revertOffset = SPLIT_RMSNRORM_SIZE_TWO / 2; + Cast(ropeKTensor, srcTensor[SPLIT_RMSNRORM_SIZE_ONE], AscendC::RoundMode::CAST_NONE, + SPLIT_RMSNRORM_SIZE_TWO); + Cast(ropeKRevertTensor[revertOffset], srcTensor[SPLIT_RMSNRORM_SIZE_ONE], AscendC::RoundMode::CAST_NONE, + revertOffset); + Cast(ropeKRevertTensor, srcTensor[SPLIT_RMSNRORM_SIZE_ONE + revertOffset], AscendC::RoundMode::CAST_NONE, + revertOffset); + Duplicate(calTensor, static_cast(-1), revertOffset); + Duplicate(calTensor[revertOffset], static_cast(1), revertOffset); + AscendC::PipeBarrier(); + Cast(calTensor[SPLIT_RMSNRORM_SIZE_TWO], cosTensor, AscendC::RoundMode::CAST_NONE, SPLIT_RMSNRORM_SIZE_TWO); + Cast(calTensor[SPLIT_RMSNRORM_SIZE_TWO * 2], sinTensor, AscendC::RoundMode::CAST_NONE, + SPLIT_RMSNRORM_SIZE_TWO); + AscendC::PipeBarrier(); + Mul(ropeKTensor, calTensor[SPLIT_RMSNRORM_SIZE_TWO], ropeKTensor, SPLIT_RMSNRORM_SIZE_TWO); + Mul(ropeKRevertTensor, calTensor[SPLIT_RMSNRORM_SIZE_TWO * 2], ropeKRevertTensor, SPLIT_RMSNRORM_SIZE_TWO); + AscendC::PipeBarrier(); + Mul(ropeKRevertTensor, calTensor, ropeKRevertTensor, SPLIT_RMSNRORM_SIZE_TWO); + AscendC::PipeBarrier(); + Add(ropeKRevertTensor, ropeKTensor, ropeKRevertTensor, SPLIT_RMSNRORM_SIZE_TWO); + AscendC::PipeBarrier(); + Cast(outTmpTensor[SPLIT_RMSNRORM_SIZE_ONE], ropeKRevertTensor, AscendC::RoundMode::CAST_RINT, + SPLIT_RMSNRORM_SIZE_TWO); + AscendC::PipeBarrier(); + /* Rope K end */ + SET_FLAG(V, MTE3, EVENT_ID0); + WAIT_FLAG(V, MTE3, EVENT_ID0); + WAIT_FLAG(S, MTE3, EVENT_ID0); + if constexpr (CACHE_MODE == CACHE_MODE_KVCACHE) { + DataCopy(keycacheGmTensor1[cacheStart], outTmpTensor, SPLIT_SIZE_ONE); + } else { + // keycache1 + DataCopy(keycacheGmTensor1[cacheStart1], outTmpTensor, SPLIT_RMSNRORM_SIZE_ONE); + // keycache2 + DataCopy(keycacheGmTensor2[cacheStart2], outTmpTensor[SPLIT_RMSNRORM_SIZE_ONE], + SPLIT_RMSNRORM_SIZE_TWO); + } + SET_FLAG(MTE3, MTE2, EVENT_ID1); + WAIT_FLAG(MTE3, MTE2, EVENT_ID1); + } + } + +private: + uint32_t n; + uint32_t splitN; + uint32_t rotaryCoeff; + uint32_t blockIdx; + uint32_t sub_block_idx; + uint32_t vectorBlockIdx; + uint32_t blockOffset; + uint32_t perTaskNum; + uint32_t resTaskNum; + MlaTilingData mlaParams; + + uint32_t num_core_; + uint32_t num_col_1; + uint32_t num_col_2; + float epsilon_; + uint32_t num_row; + uint32_t row_work; + uint32_t row_work_; + + AsdopsBuffer buf; + + AscendC::GlobalTensor hiddenStateGmTensor; + + AscendC::GlobalTensor wdqkvGmTensor; + AscendC::GlobalTensor gamma2GmTensor; + AscendC::GlobalTensor gamma3GmTensor; + AscendC::GlobalTensor sin1GmTensor; + AscendC::GlobalTensor cos1GmTensor; + AscendC::GlobalTensor sin2GmTensor; + AscendC::GlobalTensor cos2GmTensor; + AscendC::GlobalTensor keycacheGmTensor1; + AscendC::GlobalTensor keycacheGmTensor2; + AscendC::GlobalTensor slotMappingGmTensor; + AscendC::GlobalTensor wuqGmTensor; + AscendC::GlobalTensor wukGmTensor; + + AscendC::GlobalTensor qGmTensor; + AscendC::GlobalTensor qGmTensor2; + AscendC::GlobalTensor s1GmTensor; + AscendC::GlobalTensor s2GmTensor; + AscendC::GlobalTensor s3GmTensor; + AscendC::GlobalTensor beta2GmTensor; + +#ifdef __DAV_C220_CUBE__ + PpMatmulEinSum mm_1; + PpMatmulEinSum mm_2; + PpMatmulEinSum mm_ein_sum; +#endif + +#ifdef 
__DAV_C220_VEC__ + RmsNormQuant rmsNormQuant2; + RopeFp16 ropeFp16; +#endif +}; + +template +__aicore__ inline void +MLAOperation::ProcessCube() +{ +#ifdef __DAV_C220_CUBE__ + mm_1.Process(); + FftsCrossCoreSync(MM1); + WaitFlagDev(MM1); + FftsCrossCoreSync(QUANT1); + + WaitFlagDev(AIC_MM2_START); + mm_2.Process(); + FftsCrossCoreSync(MM2); + WaitFlagDev(MM2); + + FftsCrossCoreSync(EINSUMOUT); + mm_ein_sum.Process(); +#endif +} + +template +__aicore__ inline void +MLAOperation::ProcessVector() +{ +#ifdef __DAV_C220_VEC__ + WaitFlagDev(QUANT1); + if (row_work_ != 0) { + uint32_t num_col_align_int8 = (num_col_2 + REPEAT_TIME_256 - 1) / REPEAT_TIME_256 * REPEAT_TIME_256; + uint32_t num_col_align_f16 = (num_col_2 + REPEAT_TIME_128 - 1) / REPEAT_TIME_128 * REPEAT_TIME_128; + uint32_t num_col_align_f32 = (num_col_2 + REPEAT_TIME_64 - 1) / REPEAT_TIME_64 * REPEAT_TIME_64; + AscendC::LocalTensor input_tensor = buf.GetBuffer(0); + AscendC::LocalTensor gamma_tensor = buf.GetBuffer(MM1_OUT_SIZE * 2); + AscendC::LocalTensor beta_tensor = + buf.GetBuffer(MM1_OUT_SIZE * 2 + SPLIT_SIZE_TWO * 2); + AscendC::LocalTensor res1_tensor = buf.GetBuffer( + MM1_OUT_SIZE * 2 + SPLIT_SIZE_TWO * 2 + SPLIT_SIZE_TWO * 2 + 64); + AscendC::LocalTensor res3_tensor = buf.GetBuffer( + MM1_OUT_SIZE * 2 + SPLIT_SIZE_TWO * 2 + SPLIT_SIZE_TWO * 2 + 64 + num_col_align_f32 * 4); + AscendC::LocalTensor output_tensor = buf.GetBuffer( + MM1_OUT_SIZE * 2 + SPLIT_SIZE_TWO * 2 + SPLIT_SIZE_TWO * 2 + 64 + num_col_align_f32 * 4 + + BUF_FACTOR * num_col_align_f32 * 4 + 64 + MM1_OUT_SIZE * 4 * 2 + 32); + rmsNormQuant2.Launch(output_tensor, input_tensor, gamma_tensor, beta_tensor, res1_tensor, res3_tensor); + } + FftsCrossCoreSync(RMSNORMQUANT2); + WaitFlagDev(RMSNORMQUANT2); + AscendC::CrossCoreSetFlag<0x2, PIPE_MTE3>(AIC_MM2_START); + + if (row_work_ != 0) { + AscendC::LocalTensor input_tensor = buf.GetBuffer(0); + AscendC::LocalTensor gamma_tensor = buf.GetBuffer(MM1_OUT_SIZE * 2); + AscendC::LocalTensor sin_tensor = + buf.GetBuffer(MM1_OUT_SIZE * 2 + SPLIT_RMSNRORM_SIZE_ONE * 2); + AscendC::LocalTensor cos_tensor = buf.GetBuffer( + MM1_OUT_SIZE * 2 + SPLIT_RMSNRORM_SIZE_ONE * 2 + SPLIT_RMSNRORM_SIZE_TWO * 2); + AscendC::LocalTensor slotMapping_tensor = buf.GetBuffer( + MM1_OUT_SIZE * 2 + SPLIT_RMSNRORM_SIZE_ONE * 2 + SPLIT_RMSNRORM_SIZE_TWO * 4); + int32_t rms3_ub_offset = + MM1_OUT_SIZE * 2 + SPLIT_RMSNRORM_SIZE_ONE * 2 + SPLIT_RMSNRORM_SIZE_TWO * 4 + 4096 * 32; + AscendC::LocalTensor tmp32_tensor = buf.GetBuffer(rms3_ub_offset); + + int32_t out_ub_offset = MM1_OUT_SIZE * 2 + SPLIT_RMSNRORM_SIZE_ONE * 2 + SPLIT_RMSNRORM_SIZE_TWO * 4 + + 4096 * 32 + SPLIT_RMSNRORM_SIZE_ONE * 3 * 4 + SPLIT_RMSNRORM_SIZE_TWO * 2 * 4 + + MM1_OUT_SIZE * 4 * 2 + 32; + AscendC::LocalTensor temp_tensor = buf.GetBuffer(out_ub_offset); + + RmsNormAndRopeConvergence1( + input_tensor, // n * 576 + gamma_tensor, // gamma + sin_tensor, // sin + cos_tensor, // cons + slotMapping_tensor, // slotMapping + row_work_, tmp32_tensor, tmp32_tensor[SPLIT_RMSNRORM_SIZE_ONE], + tmp32_tensor[SPLIT_RMSNRORM_SIZE_ONE + SPLIT_RMSNRORM_SIZE_ONE], + tmp32_tensor[SPLIT_RMSNRORM_SIZE_ONE + SPLIT_RMSNRORM_SIZE_ONE + SPLIT_RMSNRORM_SIZE_TWO], + tmp32_tensor[SPLIT_RMSNRORM_SIZE_ONE + SPLIT_RMSNRORM_SIZE_ONE + SPLIT_RMSNRORM_SIZE_TWO + + SPLIT_RMSNRORM_SIZE_TWO], + temp_tensor); + } + + WaitFlagDev(EINSUMOUT); + ropeFp16.Process(); +#endif +} + +} // namespace MLAPO_BF16_NQ diff --git a/csrc/mla_preprocess/op_kernel/mla_preprocess_mix_bf16_qdown.hpp 
diff --git a/csrc/mla_preprocess/op_kernel/mla_preprocess_mix_bf16_qdown.hpp b/csrc/mla_preprocess/op_kernel/mla_preprocess_mix_bf16_qdown.hpp
index aa990cbb..1df954e8 100644
--- a/csrc/mla_preprocess/op_kernel/mla_preprocess_mix_bf16_qdown.hpp
+++ b/csrc/mla_preprocess/op_kernel/mla_preprocess_mix_bf16_qdown.hpp
@@ -2406,6 +2406,7 @@ public:
         this->num_row = mlaParams_.n;
         this->epsilon_ = 1e-6;
         this->mlaParams = mlaParams_;
+        this->hiddenStateDim = mlaParams_.hiddenStateDim;
     }
 
     __aicore__ inline void Init(GM_ADDR hiddenStateGm, GM_ADDR quantScale1Gm,
@@ -2713,6 +2714,7 @@ private:
    uint32_t blockOffset;
    uint32_t perTaskNum;
    uint32_t resTaskNum;
+   uint32_t hiddenStateDim;
    MlaTilingData mlaParams;
 
    uint32_t num_core_;
@@ -2817,18 +2819,15 @@ MLAOperation
         input_tensor = buf.GetBuffer(0);
-        AscendC::LocalTensor scale_tensor =
-            buf.GetBuffer(HIDDTEN_STATE * 2 + HIDDTEN_STATE * 2 + HIDDTEN_STATE * 2);
-        AscendC::LocalTensor offset_tensor = buf.GetBuffer(
-            HIDDTEN_STATE * 2 + HIDDTEN_STATE * 2 + HIDDTEN_STATE * 2 + 32);
-        AscendC::LocalTensor res1_tensor =
-            buf.GetBuffer(HIDDTEN_STATE * 2 + HIDDTEN_STATE * 2 + HIDDTEN_STATE * 2 + 64);
+        AscendC::LocalTensor scale_tensor = buf.GetBuffer(base_offset);
+        AscendC::LocalTensor offset_tensor = buf.GetBuffer(base_offset + 32);
+        AscendC::LocalTensor res1_tensor = buf.GetBuffer(base_offset + 64);
         AscendC::LocalTensor res3_tensor = buf.GetBuffer(
-            HIDDTEN_STATE * 2 + HIDDTEN_STATE * 2 + HIDDTEN_STATE * 2 + 64 + num_col_align_f32 * 4);
+            base_offset + 64 + num_col_align_f32 * 4);
         AscendC::LocalTensor output_tensor = buf.GetBuffer(
-            HIDDTEN_STATE * 2 + HIDDTEN_STATE * 2 + HIDDTEN_STATE * 2 + 64 + num_col_align_f32 * 4 +
-            BUF_FACTOR * num_col_align_f32 * 4 + 64);
+            base_offset + 64 + num_col_align_f32 * 4 + BUF_FACTOR * num_col_align_f32 * 4 + 64);
         Quant1.Launch(output_tensor, input_tensor, scale_tensor, offset_tensor, res1_tensor, res3_tensor);
     }
     FftsCrossCoreSync(QUANT1);
diff --git a/csrc/mla_preprocess/op_kernel/mla_preprocess_mix_fp16.hpp b/csrc/mla_preprocess/op_kernel/mla_preprocess_mix_fp16.hpp
index 73cb04de..a26bdfea 100644
--- a/csrc/mla_preprocess/op_kernel/mla_preprocess_mix_fp16.hpp
+++ b/csrc/mla_preprocess/op_kernel/mla_preprocess_mix_fp16.hpp
@@ -2034,6 +2034,7 @@ public:
         this->num_row = mlaParams_.n;
         this->epsilon_ = 1e-6;
         this->mlaParams = mlaParams_;
+        this->hiddenStateDim = mlaParams_.hiddenStateDim;
     }
 
     __aicore__ inline void Init(GM_ADDR hiddenStateGm, GM_ADDR quantScale1Gm,
@@ -2294,6 +2295,7 @@ private:
    uint32_t blockOffset;
    uint32_t perTaskNum;
    uint32_t resTaskNum;
+   uint32_t hiddenStateDim;
    MlaTilingData mlaParams;
 
    // rmsnormQuant
@@ -2389,21 +2391,19 @@ __aicore__ inline void MLAOperation
         input_tensor = buf.GetBuffer(0);
-        AscendC::LocalTensor gamma_tensor = buf.GetBuffer(HIDDTEN_STATE * 2);
-        AscendC::LocalTensor beta_tensor =
-            buf.GetBuffer(HIDDTEN_STATE * 2 + HIDDTEN_STATE * 2);
-        AscendC::LocalTensor scale_tensor =
-            buf.GetBuffer(HIDDTEN_STATE * 2 + HIDDTEN_STATE * 2 + HIDDTEN_STATE * 2);
-        AscendC::LocalTensor offset_tensor = buf.GetBuffer(
-            HIDDTEN_STATE * 2 + HIDDTEN_STATE * 2 + HIDDTEN_STATE * 2 + 32);
-        AscendC::LocalTensor res1_tensor =
-            buf.GetBuffer(HIDDTEN_STATE * 2 + HIDDTEN_STATE * 2 + HIDDTEN_STATE * 2 + 64);
+        AscendC::LocalTensor gamma_tensor = buf.GetBuffer(gamma_offset);
+        AscendC::LocalTensor beta_tensor = buf.GetBuffer(beta_offset);
+        AscendC::LocalTensor scale_tensor = buf.GetBuffer(scale_offset);
+        AscendC::LocalTensor offset_tensor = buf.GetBuffer(scale_offset + 32);
+        AscendC::LocalTensor res1_tensor = buf.GetBuffer(scale_offset + 64);
         AscendC::LocalTensor res3_tensor = buf.GetBuffer(
-            HIDDTEN_STATE * 2 + HIDDTEN_STATE * 2 + HIDDTEN_STATE * 2 + 64 + num_col_align_f32 * 4);
+            scale_offset + 64 + num_col_align_f32 * 4);
         AscendC::LocalTensor output_tensor = buf.GetBuffer(
-            HIDDTEN_STATE * 2 + HIDDTEN_STATE * 2 + HIDDTEN_STATE * 2 + 64 + num_col_align_f32 * 4 +
-            BUF_FACTOR * num_col_align_f32 * 4 + 32);
+            scale_offset + 64 + num_col_align_f32 * 4 + BUF_FACTOR * num_col_align_f32 * 4 + 32);
         Quant1.Launch(output_tensor, input_tensor, gamma_tensor, beta_tensor, scale_tensor, offset_tensor, res1_tensor, res3_tensor);
     }
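Note: in both kernel hunks above, the unified-buffer offsets that used to be spelled out as HIDDTEN_STATE * 2 + HIDDTEN_STATE * 2 + HIDDTEN_STATE * 2 are replaced by precomputed offsets (base_offset, gamma_offset, beta_offset, scale_offset) so the scratch layout follows the runtime hiddenStateDim rather than the hard-coded 7168. The definitions of those offsets are outside the hunks shown; assuming they are the direct equivalents of the old constant expressions, the arithmetic below checks that nothing moves for the default width:

    # Assumed offset definitions -- the real ones live outside the hunks above.
    hidden_state_dim = 7168                            # previously the HIDDTEN_STATE constant
    gamma_offset = hidden_state_dim * 2                # fp16 input region: hidden_state_dim * sizeof(half)
    beta_offset = gamma_offset + hidden_state_dim * 2
    scale_offset = beta_offset + hidden_state_dim * 2
    base_offset = scale_offset                         # the bf16 qdown variant only needs this one

    assert scale_offset == 7168 * 2 + 7168 * 2 + 7168 * 2 == 43008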
diff --git a/csrc/torch_binding.cpp b/csrc/torch_binding.cpp
index 4465cd34..e0c976c4 100644
--- a/csrc/torch_binding.cpp
+++ b/csrc/torch_binding.cpp
@@ -176,15 +176,51 @@ std::tuple rotary_embedding(at::Tensor &positions, at::T
 std::tuple mla_preprocess(
     const at::Tensor &hiddenState,
     const at::Tensor &wdqkv,
-    const at::Tensor &descale0, const at::Tensor &gamma1, const at::Tensor &beta1, const at::Tensor &wuq,
-    const at::Tensor &descale1, const at::Tensor &gamma2, const at::Tensor &cos, const at::Tensor &sin,
+    const c10::optional &descale0, const at::Tensor &gamma1, const c10::optional &beta1, const at::Tensor &wuq,
+    const c10::optional &descale1, const at::Tensor &gamma2, const at::Tensor &cos, const at::Tensor &sin,
     const at::Tensor &wuk, const at::Tensor &kv_cache, const at::Tensor &kv_cache_rope, const at::Tensor &slotmapping,
-    const at::Tensor &quant_scale0, const at::Tensor &quant_offset0, const at::Tensor &bias0,
-    const at::Tensor &quant_scale1, const at::Tensor &quant_offset1, const at::Tensor &bias1,
+    const c10::optional &quant_scale0, const c10::optional &quant_offset0, const c10::optional &bias0,
+    const c10::optional &quant_scale1, const c10::optional &quant_offset1, const c10::optional &bias1,
     const c10::optional &ctkv_scale, const c10::optional &q_nope_scale,
     c10::optional cache_mode, c10::optional quant_mode, c10::optional enable_inner_out,
     at::Tensor &q_out0, at::Tensor &kv_cache_out0, at::Tensor &q_out1, at::Tensor &kv_cache_out1, at::Tensor &inner_out)
 {
+    at::Tensor Descale0 =
+        descale0.has_value()
+            ? descale0.value()
+            : at::empty({1}, at::TensorOptions().dtype(at::kHalf).device(hiddenState.options().device()));
+    at::Tensor Descale1 =
+        descale1.has_value()
+            ? descale1.value()
+            : at::empty({1}, at::TensorOptions().dtype(at::kHalf).device(hiddenState.options().device()));
+    at::Tensor Beta1 =
+        beta1.has_value()
+            ? beta1.value()
+            : at::empty({1}, at::TensorOptions().dtype(at::kHalf).device(hiddenState.options().device()));
+    at::Tensor Quant_scale0 =
+        quant_scale0.has_value()
+            ? quant_scale0.value()
+            : at::empty({1}, at::TensorOptions().dtype(at::kHalf).device(hiddenState.options().device()));
+    at::Tensor Quant_scale1 =
+        quant_scale1.has_value()
+            ? quant_scale1.value()
+            : at::empty({1}, at::TensorOptions().dtype(at::kHalf).device(hiddenState.options().device()));
+    at::Tensor Quant_offset0 =
+        quant_offset0.has_value()
+            ? quant_offset0.value()
+            : at::empty({1}, at::TensorOptions().dtype(at::kHalf).device(hiddenState.options().device()));
+    at::Tensor Quant_offset1 =
+        quant_offset1.has_value()
+            ? quant_offset1.value()
+            : at::empty({1}, at::TensorOptions().dtype(at::kHalf).device(hiddenState.options().device()));
+    at::Tensor Bias0 =
+        bias0.has_value()
+            ? bias0.value()
+            : at::empty({1}, at::TensorOptions().dtype(at::kHalf).device(hiddenState.options().device()));
+    at::Tensor Bias1 =
+        bias1.has_value()
+            ? bias1.value()
+            : at::empty({1}, at::TensorOptions().dtype(at::kHalf).device(hiddenState.options().device()));
     at::Tensor CtkvScale =
         ctkv_scale.has_value()
             ? ctkv_scale.value()
@@ -200,6 +236,7 @@ std::tuple
     auto [workspace_tensor, tiling, block_dim] = mlapo::mla_preprocess_tiling(
         hiddenState,
+        wdqkv,
         wuk,
         cache_mode,
         quant_mode,
@@ -207,24 +244,24 @@ std::tuple
     );
 
     void *hidden_state_ptr = hiddenState.data_ptr();
-    void *quant_scale0_ptr = quant_scale0.data_ptr();
-    void *quant_offset0_ptr = quant_offset0.data_ptr();
+    void *quant_scale0_ptr = Quant_scale0.data_ptr();
+    void *quant_offset0_ptr = Quant_offset0.data_ptr();
     void *wdqkv_ptr = wdqkv.data_ptr();
-    void *bias0_ptr = bias0.data_ptr();
+    void *bias0_ptr = Bias0.data_ptr();
     void *gamma1_ptr = gamma1.data_ptr();
-    void *beta1_ptr = beta1.data_ptr();
-    void *quant_scale1_ptr = quant_scale1.data_ptr();
-    void *quant_offset1_ptr = quant_offset1.data_ptr();
+    void *beta1_ptr = Beta1.data_ptr();
+    void *quant_scale1_ptr = Quant_scale1.data_ptr();
+    void *quant_offset1_ptr = Quant_offset1.data_ptr();
     void *gamma2_ptr = gamma2.data_ptr();
     void *sin_ptr = sin.data_ptr();
     void *cos_ptr = cos.data_ptr();
     void *kv_cache_ptr = kv_cache.data_ptr();
     void *slotmapping_ptr = slotmapping.data_ptr();
     void *wuq_ptr = wuq.data_ptr();
-    void *bias1_ptr = bias1.data_ptr();
+    void *bias1_ptr = Bias1.data_ptr();
     void *wuk_ptr = wuk.data_ptr();
-    void *descale0_ptr = descale0.data_ptr();
-    void *descale1_ptr = descale1.data_ptr();
+    void *descale0_ptr = Descale0.data_ptr();
+    void *descale1_ptr = Descale1.data_ptr();
     void *ctkv_scale_ptr = CtkvScale.data_ptr();
     void *qnope_scale_ptr = QnopeScale.data_ptr();
     void *q_out0_ptr = q_out0.data_ptr();
@@ -1122,11 +1159,11 @@ TORCH_LIBRARY_EXPAND(CONCAT(_C, _ascend), ops)
     ops.def(
         "mla_preprocess(Tensor hiddenState, Tensor wdqkv,"
-        " Tensor descale0, Tensor gamma1, Tensor beta1, Tensor wuq, Tensor descale1,"
+        " Tensor? descale0, Tensor gamma1, Tensor? beta1, Tensor wuq, Tensor? descale1,"
         " Tensor gamma2, Tensor cos, Tensor sin, Tensor wuk, Tensor kv_cache,"
-        " Tensor kv_cache_rope, Tensor slotmapping, Tensor quant_scale0,"
-        " Tensor quant_offset0, Tensor bias0, Tensor quant_scale1, Tensor quant_offset1,"
-        " Tensor bias1, Tensor? ctkv_scale, Tensor? q_nope_scale, str? cache_mode,"
+        " Tensor kv_cache_rope, Tensor slotmapping, Tensor? quant_scale0,"
+        " Tensor? quant_offset0, Tensor? bias0, Tensor? quant_scale1, Tensor? quant_offset1,"
+        " Tensor? bias1, Tensor? ctkv_scale, Tensor? q_nope_scale, str? cache_mode,"
         " str? quant_mode, bool? enable_inner_out, Tensor! q_out0, Tensor! kv_cache_out0, Tensor! q_out1,"
         " Tensor! kv_cache_out1, Tensor! inner_out) -> (Tensor q_out0, Tensor kv_cache_out0,"
         " Tensor q_out1, Tensor kv_cache_out1, Tensor inner_out)"
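Note: with the binding above, the quantization-related inputs become c10::optional and, when absent, are replaced by 1-element half placeholders on the hidden state's device, so the kernel launch always receives valid pointers (the placeholders are presumably never read on the no-quant path). The same fallback written as a small Python sketch (hypothetical helper, shown only to illustrate the pattern):

    import torch

    def or_placeholder(t, device):
        # mirrors the C++ fallback: at::empty({1}, kHalf) on the hiddenState device
        return t if t is not None else torch.empty(1, dtype=torch.float16, device=device)

    descale0 = or_placeholder(None, device="cpu")  # 1-element stand-in when no descale tensor is supplied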
diff --git a/csrc/torch_binding_meta.cpp b/csrc/torch_binding_meta.cpp
index 50b28e5d..3223801b 100644
--- a/csrc/torch_binding_meta.cpp
+++ b/csrc/torch_binding_meta.cpp
@@ -84,11 +84,11 @@ at::Tensor sgmv_expand_meta(at::Tensor &x, at::Tensor &weight, at::Tensor &lora_
 std::tuple mla_preprocess(
     const at::Tensor &hiddenState,
     const at::Tensor &wdqkv,
-    const at::Tensor &descale0,
+    const c10::optional &descale0,
     const at::Tensor &gamma1,
-    const at::Tensor &beta1,
+    const c10::optional &beta1,
     const at::Tensor &wuq,
-    const at::Tensor &descale1,
+    const c10::optional &descale1,
     const at::Tensor &gamma2,
     const at::Tensor &cos,
     const at::Tensor &sin,
@@ -96,12 +96,12 @@ std::tuple
     const at::Tensor &kv_cache,
     const at::Tensor &kv_cache_rope,
     const at::Tensor &slotmapping,
-    const at::Tensor &quant_scale0,
-    const at::Tensor &quant_offset0,
-    const at::Tensor &bias0,
-    const at::Tensor &quant_scale1,
-    const at::Tensor &quant_offset1,
-    const at::Tensor &bias1,
+    const c10::optional &quant_scale0,
+    const c10::optional &quant_offset0,
+    const c10::optional &bias0,
+    const c10::optional &quant_scale1,
+    const c10::optional &quant_offset1,
+    const c10::optional &bias1,
     const c10::optional &ctkv_scale,
     const c10::optional &q_nope_scale,
     c10::optional cache_mode,
diff --git a/tests/e2e/nightly/ops/test_mla_preprocess.py b/tests/e2e/nightly/ops/test_mla_preprocess.py
index e73310f4..99b383ba 100644
--- a/tests/e2e/nightly/ops/test_mla_preprocess.py
+++ b/tests/e2e/nightly/ops/test_mla_preprocess.py
@@ -67,6 +67,11 @@ def test_mla_preprocess_kernel():
         dtype=hidden_states.dtype,
         device=hidden_states.device,
     )
+    q_down = torch.empty(
+        (hidden_states.shape[0], 1536),
+        dtype=hidden_states.dtype,
+        device=hidden_states.device,
+    )
     q_nope_old = q_nope_out.clone()
     q_rope_old = q_rope_out.clone()
 
@@ -95,10 +100,12 @@ def test_mla_preprocess_kernel():
         q_nope_scale=qnope_scale,
         cache_mode="krope_ctkv",
         quant_mode="per_tensor_quant_asymm",
+        enable_inner_out=False,
         q_out0=q_nope_out,
         kv_cache_out0=kv_cache,
         q_out1=q_rope_out,
         kv_cache_out1=kv_cache_rope,
+        inner_out=q_down,
     )
     assert not torch.equal(q_nope_out, q_nope_old)
     assert not torch.equal(q_rope_out, q_rope_old)
diff --git a/tests/e2e/nightly/ops/test_mla_preprocess_nq.py b/tests/e2e/nightly/ops/test_mla_preprocess_nq.py
new file mode 100644
index 00000000..b18c63f6
--- /dev/null
+++ b/tests/e2e/nightly/ops/test_mla_preprocess_nq.py
@@ -0,0 +1,99 @@
+import gc
+
+import torch
+import torch_npu
+
+from vllm_ascend.utils import enable_custom_op
+
+enable_custom_op()
+
+
+@torch.inference_mode()
+def test_mla_preprocess_kernel():
+    token_num = 1
+    head_num = 2
+    N_7168 = 7168
+    block_num = 1
+    block_size = 128
+    dtype = torch.bfloat16
+
+    hidden_states = torch.randn((token_num, N_7168), dtype=dtype).npu()
+
+    wdqkv = torch.randint(0, 7, (1, 448, 2112, 16), dtype=dtype).npu()
+    wdqkv = torch_npu.npu_format_cast(wdqkv.contiguous(), 29)
+    gamma1 = torch.randn((1536), dtype=dtype).npu()
+
+    wuq = torch.randint(0, 7, (1, 96, head_num * 192, 16), dtype=dtype).npu()
+    wuq = torch_npu.npu_format_cast(wuq.contiguous(), 29)
+    gamma2 = torch.randn((512), dtype=dtype).npu()
+
+    cos = torch.randn((token_num, 64), dtype=dtype).npu()
+    sin = torch.randn((token_num, 64), dtype=dtype).npu()
+
+    wuk = torch.randn((head_num, 128, 512), dtype=dtype).npu()
+    # wuk = torch_npu.npu_format_cast(wuk, 29)
+    kv_cache = torch.randint(0,
+                             7,
+                             (block_num, head_num * 512 // 32, block_size, 32),
+                             dtype=dtype).npu()
+    kv_cache_rope = torch.randn(
+        (block_num, head_num * 64 // 16, block_size, 16), dtype=dtype).npu()
+
+    slotmapping = torch.randint(0, 7, (token_num, ), dtype=torch.int32).npu()
+
+    q_nope_out = torch.empty(
+        (hidden_states.shape[0], wuk.shape[0], kv_cache.shape[-1]),
+        dtype=hidden_states.dtype,
+        device=hidden_states.device,
+    )
+    q_rope_out = torch.empty(
+        (hidden_states.shape[0], wuk.shape[0], kv_cache_rope.shape[-1]),
+        dtype=hidden_states.dtype,
+        device=hidden_states.device,
+    )
+    q_down = torch.empty(
+        (hidden_states.shape[0], 1536),
+        dtype=hidden_states.dtype,
+        device=hidden_states.device,
+    )
+    q_nope_old = q_nope_out.clone()
+    q_rope_old = q_rope_out.clone()
+
+    torch.ops._C_ascend.mla_preprocess(
+        hidden_states,
+        wdqkv,
+        None,
+        gamma1,
+        None,
+        wuq,
+        None,
+        gamma2,
+        cos,
+        sin,
+        wuk,
+        kv_cache,
+        kv_cache_rope,
+        slotmapping,
+        None,
+        None,
+        None,
+        None,
+        None,
+        None,
+        None,
+        None,
+        cache_mode="krope_ctkv",
+        quant_mode="no_quant",
+        enable_inner_out=False,
+        q_out0=q_nope_out,
+        kv_cache_out0=kv_cache,
+        q_out1=q_rope_out,
+        kv_cache_out1=kv_cache_rope,
+        inner_out=q_down,
+    )
+    assert not torch.equal(q_nope_out, q_nope_old)
+    assert not torch.equal(q_rope_out, q_rope_old)
+
+    gc.collect()
+    torch.npu.empty_cache()
+    torch.npu.reset_peak_memory_stats()
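Note: the new test exercises the no-quant path end to end: the weights stay in bf16, every quantization tensor is passed as None, and quant_mode="no_quant" is selected. The weight shapes are consistent with the dimensions used elsewhere in the test; assuming the trailing 16 of the npu_format_cast(..., 29) layout packs the reduction dimension, the sizes line up as below:

    head_num = 2
    wdqkv_shape = (1, 448, 2112, 16)
    wuq_shape = (1, 96, head_num * 192, 16)

    assert wdqkv_shape[1] * wdqkv_shape[3] == 7168   # hidden-state width (N_7168 in the test)
    assert wdqkv_shape[2] == 2112                    # down-projection output width
    assert wuq_shape[1] * wuq_shape[3] == 1536       # matches gamma1 / q_down width
    assert wuq_shape[2] == head_num * 192            # per-head q width (128 nope + 64 rope)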