From 542258ac9d9229aab4e8822de42443245a93f001 Mon Sep 17 00:00:00 2001
From: liuchen2026fly
Date: Mon, 9 Mar 2026 20:17:21 +0800
Subject: [PATCH] [feat] parameterize hardcoded MLA dimensions to support GLM5-W8A8 (#6902)

### What this PR does / why we need it?
Derive MLA dimension constants (q_lora_rank, qk_nope_head_dim, etc.) from tensor shapes at runtime instead of hardcoding DeepSeek V3 values. This enables the mla_preprocess fused op to work with both DeepSeek V3 and GLM5 models without Python API changes.

- Add 9 dimension fields to MlaTilingData with DeepSeek V3 defaults
- Add OpParam fields and parameterize all host-side tiling functions
- Derive dimensions from the wuk, gamma1, and kv_cache_rope tensor shapes
- Replace 310+ hardcoded constants across 4 kernel .hpp files
- Remove the unused MMSIZE1/MMSIZE2 constants

### Does this PR introduce _any_ user-facing change?

### How was this patch tested?
- vLLM version: v0.16.0
- vLLM main: https://github.com/vllm-project/vllm/commit/15d76f74e2fdb12a95ea00f0ca283acf6219a2b7

---------

Signed-off-by: liuchenbing
Co-authored-by: liuchenbing
---
 .../mla_preprocess_torch_adpt.h               |   2 +
 csrc/mla_preprocess/op_host/mla_preprocess.h  |  77 +++++--
 .../op_host/tiling/mla_preprocess_tiling.h    |  12 +
 .../mla_preprocess/op_kernel/mla_preprocess.h |   3 -
 .../op_kernel/mla_preprocess_kernel.cpp       |  12 +
 .../op_kernel/mla_preprocess_mix_bf16.hpp     | 212 ++++++++++--------
 .../op_kernel/mla_preprocess_mix_bf16_nq.hpp  | 150 ++++++++-----
 .../mla_preprocess_mix_bf16_qdown.hpp         | 212 ++++++++++--------
 .../op_kernel/mla_preprocess_mix_fp16.hpp     | 170 ++++++++------
 9 files changed, 508 insertions(+), 342 deletions(-)

diff --git a/csrc/mla_preprocess/mla_preprocess_torch_adpt.h b/csrc/mla_preprocess/mla_preprocess_torch_adpt.h index 0bcd3ca8..7f70e55f 100644 --- a/csrc/mla_preprocess/mla_preprocess_torch_adpt.h +++ b/csrc/mla_preprocess/mla_preprocess_torch_adpt.h @@ -84,6 +84,8 @@ std::tuple hiddenState, wdqkv, wuk, + gamma1, + kv_cache_rope, cache_mode, quant_mode, enableInnerOut diff --git a/csrc/mla_preprocess/op_host/mla_preprocess.h b/csrc/mla_preprocess/op_host/mla_preprocess.h index ae4a1225..bff9274a 100644 --- a/csrc/mla_preprocess/op_host/mla_preprocess.h +++ b/csrc/mla_preprocess/op_host/mla_preprocess.h @@ -129,6 +129,11 @@ struct OpParam { QuantMode quantMode; caffe2::TypeMeta inDtype; bool enableInnerOut; + // MLA dimensions derived from tensor shapes + uint32_t qLoraRank; + uint32_t qkNopeHeadDim; + uint32_t qkRopeHeadDim; + uint32_t kvLoraRank; }; class PpMatmulTilingApi @@ -397,7 +402,7 @@ void MlaPreprocessTiling::RmsNormQuantTiling() tilingData->rmsNumRow1 = opParam.N; tilingData->rmsQuantMin1 = -CONST_128; tilingData->rmsNumCore2 = platformInfo.coreNumAiv; - tilingData->rmsNumCol2 = HIDDEN_STRATE_MM; + tilingData->rmsNumCol2 = opParam.qLoraRank + opParam.kvLoraRank + opParam.qkRopeHeadDim; tilingData->rmsNumRow2 = opParam.N; tilingData->rmsQuantMin2 = -CONST_128; } @@ -405,10 +410,10 @@ void MlaPreprocessTiling::RopeConcatTiling() { uint32_t ntokens = opParam.N; - uint32_t hiddenSizeQ = HEADDIM * opParam.headNum; - uint32_t headDim = HEADDIM; + uint32_t hiddenSizeQ = opParam.qkRopeHeadDim * opParam.headNum; + uint32_t headDim = opParam.qkRopeHeadDim; uint32_t headNumQ = hiddenSizeQ / headDim; - uint32_t concatSize = CONCAT_SIZE; + uint32_t concatSize = opParam.kvLoraRank; uint32_t maxCore = platformInfo.coreNumAiv; uint32_t maxUbSize = platformInfo.ubSize; @@ -458,7 +463,7 @@ void
MlaPreprocessTiling::EinSumQuantTiling() // input shape uint32_t esqBatch = opParam.N; // tokenNum uint32_t esqHeadNum = opParam.headNum; // headNum - uint32_t esqColNum = AXES_ALIGN_SIZE; // 512 + uint32_t esqColNum = opParam.kvLoraRank; // kv_lora_rank // split core uint32_t esqFrontCore = esqBatch % aivCore; @@ -508,14 +513,16 @@ void MlaPreprocessTiling::EinSumQuantTiling() void MlaPreprocessTiling::SetMlapoWorkSpace() { + uint32_t hiddenStrideRope = opParam.qkNopeHeadDim + opParam.qkRopeHeadDim; + uint32_t hiddenStrateMm = opParam.qLoraRank + opParam.kvLoraRank + opParam.qkRopeHeadDim; uint64_t s1wsFactor = static_cast(opParam.cacheMode == 2 ? std::max(opParam.hiddenStateDim * sizeof(int8_t), - opParam.headNum * AXES_ALIGN_SIZE * sizeof(uint16_t)) + opParam.headNum * opParam.kvLoraRank * sizeof(uint16_t)) : opParam.hiddenStateDim * sizeof(int8_t)); uint64_t workSizeS1 = s1wsFactor; - uint64_t workSizeS2 = opParam.headNum * HIDDEN_STRATE_ROPE * sizeof(uint16_t); - uint64_t workSizeS3 = HIDDEN_STRATE_MM * sizeof(uint16_t); - uint64_t workSizeS4 = std::max(opParam.headNum * HIDDEN_STRATE_ROPE, HIDDEN_STRATE_MM) * sizeof(uint32_t); + uint64_t workSizeS2 = opParam.headNum * hiddenStrideRope * sizeof(uint16_t); + uint64_t workSizeS3 = hiddenStrateMm * sizeof(uint16_t); + uint64_t workSizeS4 = std::max(opParam.headNum * hiddenStrideRope, hiddenStrateMm) * sizeof(uint32_t); uint64_t maxWorkspaceSize = workSizeS1; maxWorkspaceSize = std::max(maxWorkspaceSize, workSizeS2); @@ -564,11 +571,17 @@ void MlaPreprocessTiling::Init() deqOnTheFly = true; } + uint32_t mm1N = opParam.qLoraRank + opParam.kvLoraRank + opParam.qkRopeHeadDim; + uint32_t mm2K = opParam.qLoraRank; + uint32_t mm2N = opParam.headNum * (opParam.qkNopeHeadDim + opParam.qkRopeHeadDim); + uint32_t mm3K = opParam.qkNopeHeadDim; + uint32_t mm3N = opParam.kvLoraRank; + PpMatmulTilingApi mm1TilingApi(platformInfo, 1, // numBatch opParam.N, // m opParam.hiddenStateDim, // k - HIDDEN_STRATE_MM, // n + mm1N, // n false, // transA true, // transB enDequant, // enDequant @@ -576,21 +589,21 @@ void MlaPreprocessTiling::Init() mm1TilingApi.GetTilingData(tilingData->mm1); PpMatmulTilingApi mm2TilingApi(platformInfo, - 1, // numBatch - opParam.N, // m - HIDDEN_STRATE_RMS, // k - opParam.headNum * HIDDEN_STRATE_ROPE, // n - false, // transA - true, // transB - enDequant, // enDequant - deqOnTheFly); // in bf16.cce? + 1, // numBatch + opParam.N, // m + mm2K, // k + mm2N, // n + false, // transA + true, // transB + enDequant, // enDequant + deqOnTheFly); // in bf16.cce? 
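For reference, the (m, k, n) shapes fed to the three PpMatmulTilingApi instances in Init() now come from the derived dimensions instead of `HIDDEN_STRATE_MM`, `HIDDEN_STRATE_RMS`, `CONST_128`, and `CONCAT_SIZE`. A minimal sketch of that mapping (the `MlaDims`/`MmShape` types are illustrative, not the real OpParam; shapes and defaults are taken directly from this hunk):

```cpp
#include <cstdint>

struct MlaDims {
    uint32_t qLoraRank;      // 1536 for DeepSeek V3
    uint32_t qkNopeHeadDim;  // 128
    uint32_t qkRopeHeadDim;  // 64
    uint32_t kvLoraRank;     // 512
};

struct MmShape { uint32_t m, k, n; };

// mm1: hidden state -> fused [q_lora | kv_lora | k_rope] projection
MmShape Mm1Shape(uint32_t nTokens, uint32_t hiddenDim, const MlaDims &d) {
    return {nTokens, hiddenDim, d.qLoraRank + d.kvLoraRank + d.qkRopeHeadDim};  // n = 2112 by default
}

// mm2: q_lora -> per-head [q_nope | q_rope]
MmShape Mm2Shape(uint32_t nTokens, uint32_t headNum, const MlaDims &d) {
    return {nTokens, d.qLoraRank, headNum * (d.qkNopeHeadDim + d.qkRopeHeadDim)};
}

// mm3: batched per head (numBatch = headNum), q_nope x W_UK -> kv_lora space
MmShape Mm3Shape(uint32_t nTokens, const MlaDims &d) {
    return {nTokens, d.qkNopeHeadDim, d.kvLoraRank};  // was hardcoded CONST_128 x CONCAT_SIZE
}
```

With GLM5 weights, the same three formulas pick up the model's actual ranks and head dimensions from OpParam, which is why no per-model branch is needed in the tiling code.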
mm2TilingApi.GetTilingData(tilingData->mm2); PpMatmulTilingApi mm3TilingApi(platformInfo, opParam.headNum, // numBatch opParam.N, // m - CONST_128, // k - CONCAT_SIZE, // n + mm3K, // k + mm3N, // n false, // transA false, // transB false, // enDequant @@ -604,6 +617,18 @@ void MlaPreprocessTiling::Init() SetMlapoWorkSpace(); SetTilingKey(); + // Populate model-specific MLA dimension fields + tilingData->mm1OutSize = opParam.qLoraRank + opParam.kvLoraRank + opParam.qkRopeHeadDim; + tilingData->splitSizeOne = opParam.kvLoraRank + opParam.qkRopeHeadDim; + tilingData->splitSizeTwo = opParam.qLoraRank; + tilingData->splitRmsNormSizeOne = opParam.kvLoraRank; + tilingData->splitRmsNormSizeTwo = opParam.qkRopeHeadDim; + tilingData->ropeSplitSizeOne = opParam.qkRopeHeadDim; + tilingData->ropeSplitSizeTwo = opParam.qkNopeHeadDim; + tilingData->hiddenStrideRope = opParam.qkNopeHeadDim + opParam.qkRopeHeadDim; + tilingData->qkNopeHeadDim = opParam.qkNopeHeadDim; + tilingData->avgFactor = 1.0f / static_cast(opParam.qLoraRank); + return; } @@ -631,6 +656,8 @@ std::tuple mla_preprocess_tiling( const at::Tensor &hiddenState, const at::Tensor &wdqkv, const at::Tensor &wuk, + const at::Tensor &gamma1, + const at::Tensor &kv_cache_rope, c10::optional cache_mode, c10::optional quant_mode, bool enable_inner_out @@ -656,6 +683,12 @@ std::tuple mla_preprocess_tiling( int32_t headNum = wuk.sizes()[0]; uint32_t hiddenStateDim = hiddenState.sizes().back(); + // Derive MLA dimensions from tensor shapes + uint32_t qkNopeHeadDim = wuk.sizes()[1]; + uint32_t kvLoraRank = wuk.sizes()[2]; + uint32_t qLoraRank = gamma1.sizes()[0]; + uint32_t qkRopeHeadDim = kv_cache_rope.sizes().back(); + OpParam opParam; opParam.hiddenStateDim = hiddenStateDim; opParam.N = N; @@ -664,6 +697,10 @@ std::tuple mla_preprocess_tiling( opParam.quantMode = static_cast(quantMode); opParam.inDtype = hiddenState.options().dtype(); opParam.enableInnerOut = enable_inner_out; + opParam.qLoraRank = qLoraRank; + opParam.qkNopeHeadDim = qkNopeHeadDim; + opParam.qkRopeHeadDim = qkRopeHeadDim; + opParam.kvLoraRank = kvLoraRank; if (wdqkv.options().dtype() == at::kBFloat16 || wdqkv.options().dtype() == at::kHalf) { opParam.isWeightQuantized = 0; } else { diff --git a/csrc/mla_preprocess/op_host/tiling/mla_preprocess_tiling.h b/csrc/mla_preprocess/op_host/tiling/mla_preprocess_tiling.h index 00be2898..fcae7960 100644 --- a/csrc/mla_preprocess/op_host/tiling/mla_preprocess_tiling.h +++ b/csrc/mla_preprocess/op_host/tiling/mla_preprocess_tiling.h @@ -95,6 +95,18 @@ struct MlaTilingData { uint32_t hiddenStateDim{7168}; uint32_t isWeightQuantized{1}; + + // Model-specific MLA dimensions (derived from tensor shapes) + uint32_t mm1OutSize{2112}; // q_lora_rank + kv_lora_rank + qk_rope_head_dim + uint32_t splitSizeOne{576}; // kv_lora_rank + qk_rope_head_dim + uint32_t splitSizeTwo{1536}; // q_lora_rank + uint32_t splitRmsNormSizeOne{512}; // kv_lora_rank + uint32_t splitRmsNormSizeTwo{64}; // qk_rope_head_dim + uint32_t ropeSplitSizeOne{64}; // qk_rope_head_dim + uint32_t ropeSplitSizeTwo{128}; // qk_nope_head_dim + uint32_t hiddenStrideRope{192}; // qk_nope_head_dim + qk_rope_head_dim + uint32_t qkNopeHeadDim{128}; // for RoPE offset calc + float avgFactor{0.000651041666f}; // 1/splitSizeTwo (1/qLoraRank), for RmsNorm avg }; #endif // MLAPREPROCESS_TILING_H diff --git a/csrc/mla_preprocess/op_kernel/mla_preprocess.h b/csrc/mla_preprocess/op_kernel/mla_preprocess.h index 6c894bd1..6917137b 100644 --- a/csrc/mla_preprocess/op_kernel/mla_preprocess.h +++ 
b/csrc/mla_preprocess/op_kernel/mla_preprocess.h @@ -60,9 +60,6 @@ constexpr uint32_t SPLIT_RMSNRORM_SIZE_TWO = 64; constexpr uint32_t ROPE_SPLIT_SIZE_ONE = 64; constexpr uint32_t ROPE_SPLIT_SIZE_TWO = 128; -constexpr uint32_t MMSIZE1 = 128 * 192; // 24576 -constexpr uint32_t MMSIZE2 = 64 * 128; // 8192 - constexpr uint64_t L0_PINGPONG_BUFFER_LEN = 32768; // 32 KB constexpr uint64_t L1_PINGPONG_BUFFER_LEN = 262144; // 256 KB constexpr uint64_t BLOCK_SIZE_16 = 16; diff --git a/csrc/mla_preprocess/op_kernel/mla_preprocess_kernel.cpp b/csrc/mla_preprocess/op_kernel/mla_preprocess_kernel.cpp index f5dbe210..d5c4eca6 100644 --- a/csrc/mla_preprocess/op_kernel/mla_preprocess_kernel.cpp +++ b/csrc/mla_preprocess/op_kernel/mla_preprocess_kernel.cpp @@ -136,6 +136,18 @@ extern "C" __global__ __aicore__ void mla_preprocess( mlaTilingData.s4Offset = tilingData->s4Offset; mlaTilingData.s5Offset = tilingData->s5Offset; + // Model-specific MLA dimensions + mlaTilingData.mm1OutSize = tilingData->mm1OutSize; + mlaTilingData.splitSizeOne = tilingData->splitSizeOne; + mlaTilingData.splitSizeTwo = tilingData->splitSizeTwo; + mlaTilingData.splitRmsNormSizeOne = tilingData->splitRmsNormSizeOne; + mlaTilingData.splitRmsNormSizeTwo = tilingData->splitRmsNormSizeTwo; + mlaTilingData.ropeSplitSizeOne = tilingData->ropeSplitSizeOne; + mlaTilingData.ropeSplitSizeTwo = tilingData->ropeSplitSizeTwo; + mlaTilingData.hiddenStrideRope = tilingData->hiddenStrideRope; + mlaTilingData.qkNopeHeadDim = tilingData->qkNopeHeadDim; + mlaTilingData.avgFactor = tilingData->avgFactor; + GM_ADDR s1 = workspace + static_cast(mlaTilingData.s1Offset); GM_ADDR s2 = workspace + static_cast(mlaTilingData.s2Offset); GM_ADDR s3 = workspace + static_cast(mlaTilingData.s3Offset); diff --git a/csrc/mla_preprocess/op_kernel/mla_preprocess_mix_bf16.hpp b/csrc/mla_preprocess/op_kernel/mla_preprocess_mix_bf16.hpp index 5c641890..1e0220ef 100644 --- a/csrc/mla_preprocess/op_kernel/mla_preprocess_mix_bf16.hpp +++ b/csrc/mla_preprocess/op_kernel/mla_preprocess_mix_bf16.hpp @@ -43,6 +43,8 @@ public: headDim = ropeConcatParams.headDim; headNumQ = ropeConcatParams.headNumQ; + this->hiddenStrideRope_ = ropeConcatParams.hiddenStrideRope; + this->qkNopeHeadDim_ = ropeConcatParams.qkNopeHeadDim; rotaryCoeff = ropeConcatParams.rotaryCoeff; ntokens = ropeConcatParams.ntokens; realCore = ropeConcatParams.realCore; @@ -92,7 +94,7 @@ public: AscendC::LocalTensor inputQCastFP32 = buf.GetBuffer(dataSizeFp16); AscendC::LocalTensor reverseQ = buf.GetBuffer(dataSizeFp32 + dataSizeFp16); - uint64_t qOffset = startHead * 192 + 128; + uint64_t qOffset = startHead * hiddenStrideRope_ + qkNopeHeadDim_; CopyQGenReverseQ(inputQ, inputQCastFP32, reverseQ, qOffset, loopN); // move in cos/sin @@ -178,7 +180,7 @@ public: { // move in Q WAIT_FLAG(MTE3, MTE2, EVENT_ID1); - AscendC::DataCopy(tempBufQ, this->qGm_[qOffset], {loopN, headBlockLen, 128 / 16, 0}); + AscendC::DataCopy(tempBufQ, this->qGm_[qOffset], {loopN, headBlockLen, static_cast(qkNopeHeadDim_ / 16), 0}); SET_FLAG(MTE2, V, EVENT_ID1); WAIT_FLAG(MTE2, V, EVENT_ID1); // cast fp32 @@ -215,6 +217,8 @@ private: AscendC::GlobalTensor outRopeConcatGm_; AscendC::GlobalTensor outRopeConcatGm2_; + uint32_t hiddenStrideRope_{0}; + uint32_t qkNopeHeadDim_{0}; uint32_t repeatSize_{0}; uint32_t rotateStride_{0}; // this->headDim / rope conf uint32_t headDim; @@ -309,6 +313,7 @@ public: this->num_row_ = mlaParams_.n; this->row_work = row_work; this->row_work_ = row_work_; + this->mm1OutSize_ = mlaParams_.mm1OutSize; gm_offset_ = 
gm_offset; gm_out_offset_ = gm_out_offset; num_col_align_int8 = (num_col_ + REPEAT_TIME_256 - 1) / REPEAT_TIME_256 * REPEAT_TIME_256; @@ -353,8 +358,8 @@ public: if constexpr (NEED_DEQUANT) { mmTensor = buf.ReinterpretCast()[OFFSET_WORKSPACE_BF16 * num_col_align_withStride_fp32 + 16]; - deScaleTensor = buf[OFFSET_WORKSPACE_BF16 * num_col_align_withStride_fp32 + 16 + MM1_OUT_SIZE]; - perTokenDescaleTensor = buf[OFFSET_WORKSPACE_BF16 * num_col_align_withStride_fp32 + 16 + MM1_OUT_SIZE * 2]; + deScaleTensor = buf[OFFSET_WORKSPACE_BF16 * num_col_align_withStride_fp32 + 16 + mm1OutSize_]; + perTokenDescaleTensor = buf[OFFSET_WORKSPACE_BF16 * num_col_align_withStride_fp32 + 16 + mm1OutSize_ * 2]; AscendC::DataCopy(deScaleTensor, perChannelDescaleGmTensor, AscendC::DataCopyParams(1, num_col_ / 8, 0, 0)); } @@ -528,6 +533,7 @@ private: uint32_t num_col_align_withStride_fp32{0}; uint32_t num_col_temp; half quantMin_{-128}; + uint32_t mm1OutSize_{0}; uint32_t num_slice_{0}; uint32_t tail_size_{0}; uint32_t tail_copy_{0}; @@ -567,6 +573,7 @@ public: this->num_row_ = mlaParams_.n; this->row_work = row_work; this->row_work_ = row_work_; + this->mm1OutSize_ = mlaParams_.mm1OutSize; gm_offset_ = gm_offset; gm_out_offset_ = gm_out_offset; num_col_align_int8 = (num_col_ + REPEAT_TIME_256 - 1) / REPEAT_TIME_256 * REPEAT_TIME_256; @@ -618,8 +625,8 @@ public: if constexpr (NEED_DEQUANT) { mmTensor = buf.ReinterpretCast()[OFFSET_WORKSPACE_BF16 * num_col_align_withStride_fp32 + 16]; - deScaleTensor = buf[OFFSET_WORKSPACE_BF16 * num_col_align_withStride_fp32 + 16 + MM1_OUT_SIZE]; - perTokenDescaleTensor = buf[OFFSET_WORKSPACE_BF16 * num_col_align_withStride_fp32 + 16 + MM1_OUT_SIZE * 2]; + deScaleTensor = buf[OFFSET_WORKSPACE_BF16 * num_col_align_withStride_fp32 + 16 + mm1OutSize_]; + perTokenDescaleTensor = buf[OFFSET_WORKSPACE_BF16 * num_col_align_withStride_fp32 + 16 + mm1OutSize_ * 2]; AscendC::DataCopy(deScaleTensor, perChannelDescaleGmTensor, AscendC::DataCopyParams(1, num_col_ / 8, 0, 0)); } @@ -834,6 +841,7 @@ private: uint32_t num_col_align_withStride_fp32{0}; uint32_t num_col_temp; half quantMin_{-128}; + uint32_t mm1OutSize_{0}; uint32_t num_slice_{0}; uint32_t tail_size_{0}; uint32_t tail_copy_{0}; @@ -2387,6 +2395,15 @@ public: this->epsilon_ = 1e-6; this->mlaParams = mlaParams_; this->hiddenStateDim = mlaParams_.hiddenStateDim; + this->mm1OutSize_ = mlaParams_.mm1OutSize; + this->splitSizeOne_ = mlaParams_.splitSizeOne; + this->splitSizeTwo_ = mlaParams_.splitSizeTwo; + this->splitRmsNormSizeOne_ = mlaParams_.splitRmsNormSizeOne; + this->splitRmsNormSizeTwo_ = mlaParams_.splitRmsNormSizeTwo; + this->ropeSplitSizeOne_ = mlaParams_.ropeSplitSizeOne; + this->ropeSplitSizeTwo_ = mlaParams_.ropeSplitSizeTwo; + this->hiddenStrideRope_ = mlaParams_.hiddenStrideRope; + this->qkNopeHeadDim_ = mlaParams_.qkNopeHeadDim; } __aicore__ inline void Init(GM_ADDR hiddenStateGm, GM_ADDR quantScale1Gm, @@ -2471,15 +2488,15 @@ public: vectorBlockIdx * static_cast(row_work) * num_col_1, row_work_, mlaParams); if constexpr (quantMode == QuantMode::PER_TENSOR_ASYMM_QUANT) { rmsNormQuant2.Init(gamma2GmTensor, beta2GmTensor, quantScale2GmTensor, quantOffset2GmTensor, - s5Gm + row_work * vectorBlockIdx * sizeof(float), descale1Gm, s3Gm, s1Gm, SPLIT_SIZE_ONE, - num_col_2, 0.000651041666, vectorBlockIdx * static_cast(row_work) * num_col_2, - vectorBlockIdx * static_cast(row_work) * SPLIT_SIZE_TWO, row_work_, mlaParams); + s5Gm + row_work * vectorBlockIdx * sizeof(float), descale1Gm, s3Gm, s1Gm, splitSizeOne_, + num_col_2, 
mlaParams.avgFactor, vectorBlockIdx * static_cast(row_work) * num_col_2, + vectorBlockIdx * static_cast(row_work) * splitSizeTwo_, row_work_, mlaParams); } else { // quantMode == QuantMode::PER_TOKEN_SYMM_QUANT rmsNormQuant2.Init(gamma2GmTensor, beta2GmTensor, quantScale2GmTensor, quantOffset2GmTensor, - s5Gm + row_work * vectorBlockIdx * sizeof(float), descale1Gm, s2Gm, s1Gm, SPLIT_SIZE_ONE, - num_col_2, 0.000651041666, vectorBlockIdx * static_cast(row_work) * num_col_2, - vectorBlockIdx * static_cast(row_work) * SPLIT_SIZE_TWO, row_work_, mlaParams); + s5Gm + row_work * vectorBlockIdx * sizeof(float), descale1Gm, s2Gm, s1Gm, splitSizeOne_, + num_col_2, mlaParams.avgFactor, vectorBlockIdx * static_cast(row_work) * num_col_2, + vectorBlockIdx * static_cast(row_work) * splitSizeTwo_, row_work_, mlaParams); } ropeFp16.RopeInit(s4Gm, cos2GmTensor, sin2GmTensor, qGmTensor, qGmTensor2, mlaParams); einSumQuant.Init(s1Gm, gmQnopeScale, qGm, mlaParams); @@ -2491,6 +2508,17 @@ public: __aicore__ inline void ProcessVector(); private: + // Model-specific MLA dimensions from tiling data + uint32_t mm1OutSize_; + uint32_t splitSizeOne_; + uint32_t splitSizeTwo_; + uint32_t splitRmsNormSizeOne_; + uint32_t splitRmsNormSizeTwo_; + uint32_t ropeSplitSizeOne_; + uint32_t ropeSplitSizeTwo_; + uint32_t hiddenStrideRope_; + uint32_t qkNopeHeadDim_; + constexpr static uint32_t C0_SIZE = 16; constexpr static uint32_t I8_C0_SIZE = 32; @@ -2505,44 +2533,44 @@ private: AscendC::LocalTensor &tmpfp16, AscendC::LocalTensor &int8OutTensor, float quantScale3) { int64_t slotMapGmOffset = vectorBlockIdx * row_work; - AscendC::DataCopy(gammaTensor, gamma3GmTensor, SPLIT_RMSNRORM_SIZE_ONE); + AscendC::DataCopy(gammaTensor, gamma3GmTensor, splitRmsNormSizeOne_); SET_FLAG(MTE2, V, EVENT_ID1); WAIT_FLAG(MTE2, V, EVENT_ID1); - Cast(gammaFp32, gammaTensor, AscendC::RoundMode::CAST_NONE, SPLIT_RMSNRORM_SIZE_ONE); + Cast(gammaFp32, gammaTensor, AscendC::RoundMode::CAST_NONE, splitRmsNormSizeOne_); AscendC::DataCopyPad(slotMappingTensor, slotMappingGmTensor[slotMapGmOffset], AscendC::DataCopyExtParams(1, sN * sizeof(int32_t), 0, 0, 0), AscendC::DataCopyPadExtParams(false, 0, 8 - sN % 8, 0)); if constexpr (quantMode == QuantMode::PER_TOKEN_SYMM_QUANT) { - mmTensor = calTensor.ReinterpretCast()[SPLIT_SIZE_ONE]; - deScaleTensor = calTensor.ReinterpretCast()[SPLIT_SIZE_ONE * 2]; - AscendC::DataCopy(deScaleTensor, descale1gmTensor, AscendC::DataCopyParams(1, SPLIT_SIZE_ONE / 8, 0, 0)); + mmTensor = calTensor.ReinterpretCast()[splitSizeOne_]; + deScaleTensor = calTensor.ReinterpretCast()[splitSizeOne_ * 2]; + AscendC::DataCopy(deScaleTensor, descale1gmTensor, AscendC::DataCopyParams(1, splitSizeOne_ / 8, 0, 0)); } SET_FLAG(MTE2, V, EVENT_ID2); WAIT_FLAG(MTE2, V, EVENT_ID2); SET_FLAG(MTE2, S, EVENT_ID2); WAIT_FLAG(MTE2, S, EVENT_ID2); for (uint64_t loop = 0; loop < sN; ++loop) { - uint64_t offset = vectorBlockIdx * static_cast(row_work) * num_col_2 + loop * MM1_OUT_SIZE; + uint64_t offset = vectorBlockIdx * static_cast(row_work) * num_col_2 + loop * mm1OutSize_; int64_t slotValue = static_cast(slotMappingTensor.GetValue(loop)); if (slotValue == -1) { continue; } if constexpr (quantMode == QuantMode::PER_TENSOR_ASYMM_QUANT) { AscendC::DataCopy(srcTensor, s3GmTensor[offset], - AscendC::DataCopyParams(1, MM1_OUT_SIZE / BLOCK_SIZE_16, 0, 0)); + AscendC::DataCopyParams(1, mm1OutSize_ / BLOCK_SIZE_16, 0, 0)); } else { // quantMode == QuantMode::PER_TOKEN_SYMM_QUANT - AscendC::DataCopy(mmTensor, s2GmTensor[offset], AscendC::DataCopyParams(1, 
SPLIT_SIZE_ONE / 8, 0, 0)); + AscendC::DataCopy(mmTensor, s2GmTensor[offset], AscendC::DataCopyParams(1, splitSizeOne_ / 8, 0, 0)); } - AscendC::DataCopy(sinTensor, sin1GmTensor[(row_work * vectorBlockIdx + loop) * SPLIT_RMSNRORM_SIZE_TWO], - SPLIT_RMSNRORM_SIZE_TWO); - AscendC::DataCopy(cosTensor, cos1GmTensor[(row_work * vectorBlockIdx + loop) * SPLIT_RMSNRORM_SIZE_TWO], - SPLIT_RMSNRORM_SIZE_TWO); + AscendC::DataCopy(sinTensor, sin1GmTensor[(row_work * vectorBlockIdx + loop) * splitRmsNormSizeTwo_], + splitRmsNormSizeTwo_); + AscendC::DataCopy(cosTensor, cos1GmTensor[(row_work * vectorBlockIdx + loop) * splitRmsNormSizeTwo_], + splitRmsNormSizeTwo_); SET_FLAG(MTE2, V, EVENT_ID0); // ND - uint64_t cacheStart = static_cast(slotValue) * static_cast(SPLIT_SIZE_ONE); - uint64_t cacheStart1 = static_cast(slotValue) * static_cast(SPLIT_RMSNRORM_SIZE_ONE); - uint64_t cacheStart2 = static_cast(slotValue) * static_cast(SPLIT_RMSNRORM_SIZE_TWO); + uint64_t cacheStart = static_cast(slotValue) * static_cast(splitSizeOne_); + uint64_t cacheStart1 = static_cast(slotValue) * static_cast(splitRmsNormSizeOne_); + uint64_t cacheStart2 = static_cast(slotValue) * static_cast(splitRmsNormSizeTwo_); // NZ uint32_t outer_idx = slotValue / 128; uint32_t inner_idx = slotValue % 128; @@ -2553,84 +2581,84 @@ private: if constexpr (quantMode == QuantMode::PER_TOKEN_SYMM_QUANT) { /* DeQuant */ AscendC::Cast(mmTensor.ReinterpretCast(), mmTensor, AscendC::RoundMode::CAST_NONE, - SPLIT_SIZE_ONE); + splitSizeOne_); AscendC::PipeBarrier(); AscendC::Mul(mmTensor.ReinterpretCast(), mmTensor.ReinterpretCast(), deScaleTensor, - SPLIT_SIZE_ONE); + splitSizeOne_); AscendC::PipeBarrier(); float perTokenDescale = s5GmTensor.GetValue(row_work * vectorBlockIdx + loop); SET_FLAG(S, V, EVENT_ID0); WAIT_FLAG(S, V, EVENT_ID0); AscendC::Muls(mmTensor.ReinterpretCast(), mmTensor.ReinterpretCast(), perTokenDescale, - SPLIT_SIZE_ONE); + splitSizeOne_); AscendC::PipeBarrier(); AscendC::Cast(srcTensor, mmTensor.ReinterpretCast(), AscendC::RoundMode::CAST_RINT, - SPLIT_SIZE_ONE); + splitSizeOne_); AscendC::PipeBarrier(); } - Cast(rmsNormTensor, srcTensor, AscendC::RoundMode::CAST_NONE, SPLIT_RMSNRORM_SIZE_ONE); + Cast(rmsNormTensor, srcTensor, AscendC::RoundMode::CAST_NONE, splitRmsNormSizeOne_); AscendC::PipeBarrier(); - Mul(calTensor, rmsNormTensor, rmsNormTensor, SPLIT_RMSNRORM_SIZE_ONE); + Mul(calTensor, rmsNormTensor, rmsNormTensor, splitRmsNormSizeOne_); AscendC::PipeBarrier(); - ReduceSumCustom(calTensor[SPLIT_RMSNRORM_SIZE_ONE], calTensor, calTensor[SPLIT_RMSNRORM_SIZE_ONE * 2], - SPLIT_RMSNRORM_SIZE_ONE); + ReduceSumCustom(calTensor[splitRmsNormSizeOne_], calTensor, calTensor[splitRmsNormSizeOne_ * 2], + splitRmsNormSizeOne_); SET_FLAG(V, S, EVENT_ID1); WAIT_FLAG(V, S, EVENT_ID1); - float rms = sqrt(calTensor.GetValue(SPLIT_RMSNRORM_SIZE_ONE) / SPLIT_RMSNRORM_SIZE_ONE + epsilon_); + float rms = sqrt(calTensor.GetValue(splitRmsNormSizeOne_) / splitRmsNormSizeOne_ + epsilon_); SET_FLAG(S, V, EVENT_ID1); WAIT_FLAG(S, V, EVENT_ID1); AscendC::PipeBarrier(); - Duplicate(calTensor, rms, SPLIT_RMSNRORM_SIZE_ONE); + Duplicate(calTensor, rms, splitRmsNormSizeOne_); AscendC::PipeBarrier(); - Div(calTensor, rmsNormTensor, calTensor, SPLIT_RMSNRORM_SIZE_ONE); + Div(calTensor, rmsNormTensor, calTensor, splitRmsNormSizeOne_); AscendC::PipeBarrier(); - Mul(rmsNormTensor, gammaFp32, calTensor, SPLIT_RMSNRORM_SIZE_ONE); + Mul(rmsNormTensor, gammaFp32, calTensor, splitRmsNormSizeOne_); AscendC::PipeBarrier(); if constexpr (CACHE_MODE == 
CACHE_MODE_INT8_NZCACHE) { // quant - Muls(rmsNormTensor, rmsNormTensor, quantScale3, SPLIT_RMSNRORM_SIZE_ONE); + Muls(rmsNormTensor, rmsNormTensor, quantScale3, splitRmsNormSizeOne_); AscendC::PipeBarrier(); - CastFrom32To16(tmpfp16, rmsNormTensor, SPLIT_RMSNRORM_SIZE_ONE); + CastFrom32To16(tmpfp16, rmsNormTensor, splitRmsNormSizeOne_); AscendC::PipeBarrier(); - CastFromF16ToI8(int8OutTensor, tmpfp16, -128, SPLIT_RMSNRORM_SIZE_ONE); + CastFromF16ToI8(int8OutTensor, tmpfp16, -128, splitRmsNormSizeOne_); AscendC::PipeBarrier(); } else { AscendC::PipeBarrier(); if (std::is_same::value) { - Cast(outTmpTensor, rmsNormTensor, AscendC::RoundMode::CAST_RINT, SPLIT_RMSNRORM_SIZE_ONE); + Cast(outTmpTensor, rmsNormTensor, AscendC::RoundMode::CAST_RINT, splitRmsNormSizeOne_); } else { - Cast(outTmpTensor, rmsNormTensor, AscendC::RoundMode::CAST_NONE, SPLIT_RMSNRORM_SIZE_ONE); + Cast(outTmpTensor, rmsNormTensor, AscendC::RoundMode::CAST_NONE, splitRmsNormSizeOne_); } } /* RmsNorm end */ /* Rope K start */ - uint64_t revertOffset = SPLIT_RMSNRORM_SIZE_TWO / 2; - Cast(ropeKTensor, srcTensor[SPLIT_RMSNRORM_SIZE_ONE], AscendC::RoundMode::CAST_NONE, - SPLIT_RMSNRORM_SIZE_TWO); - Cast(ropeKRevertTensor[revertOffset], srcTensor[SPLIT_RMSNRORM_SIZE_ONE], AscendC::RoundMode::CAST_NONE, + uint64_t revertOffset = splitRmsNormSizeTwo_ / 2; + Cast(ropeKTensor, srcTensor[splitRmsNormSizeOne_], AscendC::RoundMode::CAST_NONE, + splitRmsNormSizeTwo_); + Cast(ropeKRevertTensor[revertOffset], srcTensor[splitRmsNormSizeOne_], AscendC::RoundMode::CAST_NONE, revertOffset); - Cast(ropeKRevertTensor, srcTensor[SPLIT_RMSNRORM_SIZE_ONE + revertOffset], AscendC::RoundMode::CAST_NONE, + Cast(ropeKRevertTensor, srcTensor[splitRmsNormSizeOne_ + revertOffset], AscendC::RoundMode::CAST_NONE, revertOffset); Duplicate(calTensor, static_cast(-1), revertOffset); Duplicate(calTensor[revertOffset], static_cast(1), revertOffset); AscendC::PipeBarrier(); - Cast(calTensor[SPLIT_RMSNRORM_SIZE_TWO], cosTensor, AscendC::RoundMode::CAST_NONE, SPLIT_RMSNRORM_SIZE_TWO); - Cast(calTensor[SPLIT_RMSNRORM_SIZE_TWO * 2], sinTensor, AscendC::RoundMode::CAST_NONE, - SPLIT_RMSNRORM_SIZE_TWO); + Cast(calTensor[splitRmsNormSizeTwo_], cosTensor, AscendC::RoundMode::CAST_NONE, splitRmsNormSizeTwo_); + Cast(calTensor[splitRmsNormSizeTwo_ * 2], sinTensor, AscendC::RoundMode::CAST_NONE, + splitRmsNormSizeTwo_); AscendC::PipeBarrier(); - Mul(ropeKTensor, calTensor[SPLIT_RMSNRORM_SIZE_TWO], ropeKTensor, SPLIT_RMSNRORM_SIZE_TWO); - Mul(ropeKRevertTensor, calTensor[SPLIT_RMSNRORM_SIZE_TWO * 2], ropeKRevertTensor, SPLIT_RMSNRORM_SIZE_TWO); + Mul(ropeKTensor, calTensor[splitRmsNormSizeTwo_], ropeKTensor, splitRmsNormSizeTwo_); + Mul(ropeKRevertTensor, calTensor[splitRmsNormSizeTwo_ * 2], ropeKRevertTensor, splitRmsNormSizeTwo_); AscendC::PipeBarrier(); - Mul(ropeKRevertTensor, calTensor, ropeKRevertTensor, SPLIT_RMSNRORM_SIZE_TWO); + Mul(ropeKRevertTensor, calTensor, ropeKRevertTensor, splitRmsNormSizeTwo_); AscendC::PipeBarrier(); - Add(ropeKRevertTensor, ropeKTensor, ropeKRevertTensor, SPLIT_RMSNRORM_SIZE_TWO); + Add(ropeKRevertTensor, ropeKTensor, ropeKRevertTensor, splitRmsNormSizeTwo_); AscendC::PipeBarrier(); if (std::is_same::value) { - Cast(outTmpTensor[SPLIT_RMSNRORM_SIZE_ONE], ropeKRevertTensor, AscendC::RoundMode::CAST_RINT, - SPLIT_RMSNRORM_SIZE_TWO); + Cast(outTmpTensor[splitRmsNormSizeOne_], ropeKRevertTensor, AscendC::RoundMode::CAST_RINT, + splitRmsNormSizeTwo_); } else { - Cast(outTmpTensor[SPLIT_RMSNRORM_SIZE_ONE], ropeKRevertTensor, 
AscendC::RoundMode::CAST_NONE, - SPLIT_RMSNRORM_SIZE_TWO); + Cast(outTmpTensor[splitRmsNormSizeOne_], ropeKRevertTensor, AscendC::RoundMode::CAST_NONE, + splitRmsNormSizeTwo_); } AscendC::PipeBarrier(); /* Rope K end */ @@ -2638,45 +2666,45 @@ private: WAIT_FLAG(V, MTE3, EVENT_ID0); WAIT_FLAG(S, MTE3, EVENT_ID0); if constexpr (CACHE_MODE == CACHE_MODE_KVCACHE) { - DataCopy(keycacheGmTensor1[cacheStart], outTmpTensor, SPLIT_SIZE_ONE); + DataCopy(keycacheGmTensor1[cacheStart], outTmpTensor, splitSizeOne_); } else if constexpr (CACHE_MODE == CACHE_MODE_INT8_NZCACHE) { uint64_t cacheSatartI8Nz1 = outer_idx * 128 * 512 + inner_idx * I8_C0_SIZE; uint64_t cacheSatartNz2 = outer_idx * 128 * 64 + inner_idx * C0_SIZE; // nope:int8 nz AscendC::DataCopyExtParams outExt; - outExt.blockCount = SPLIT_RMSNRORM_SIZE_ONE / I8_C0_SIZE; + outExt.blockCount = splitRmsNormSizeOne_ / I8_C0_SIZE; outExt.blockLen = I8_C0_SIZE * sizeof(int8_t); outExt.srcStride = 0; outExt.dstStride = (128 * I8_C0_SIZE - I8_C0_SIZE) * sizeof(int8_t); DataCopyPad(keycacheGmTensor1[cacheSatartI8Nz1], int8OutTensor, outExt); // rope:T1 nz - outExt.blockCount = SPLIT_RMSNRORM_SIZE_TWO / C0_SIZE; + outExt.blockCount = splitRmsNormSizeTwo_ / C0_SIZE; outExt.blockLen = C0_SIZE * sizeof(T1); outExt.srcStride = 0; outExt.dstStride = (128 * C0_SIZE - C0_SIZE) * sizeof(T1); - DataCopyPad(keycacheGmTensor2[cacheSatartNz2], outTmpTensor[SPLIT_RMSNRORM_SIZE_ONE], outExt); + DataCopyPad(keycacheGmTensor2[cacheSatartNz2], outTmpTensor[splitRmsNormSizeOne_], outExt); } else if constexpr (CACHE_MODE == CACHE_MODE_NZCACHE) { uint64_t cacheSatartNz1 = outer_idx * 128 * 512 + inner_idx * C0_SIZE; uint64_t cacheSatartNz2 = outer_idx * 128 * 64 + inner_idx * C0_SIZE; // nope:T1 nz AscendC::DataCopyExtParams outExt; - outExt.blockCount = SPLIT_RMSNRORM_SIZE_ONE / C0_SIZE; + outExt.blockCount = splitRmsNormSizeOne_ / C0_SIZE; outExt.blockLen = C0_SIZE * sizeof(T1); outExt.srcStride = 0; outExt.dstStride = (128 * C0_SIZE - C0_SIZE) * sizeof(T1); DataCopyPad(keycacheGmTensor1[cacheSatartNz1], outTmpTensor, outExt); // rope:T1 nz - outExt.blockCount = SPLIT_RMSNRORM_SIZE_TWO / C0_SIZE; + outExt.blockCount = splitRmsNormSizeTwo_ / C0_SIZE; outExt.blockLen = C0_SIZE * sizeof(T1); outExt.srcStride = 0; outExt.dstStride = (128 * C0_SIZE - C0_SIZE) * sizeof(T1); - DataCopyPad(keycacheGmTensor2[cacheSatartNz2], outTmpTensor[SPLIT_RMSNRORM_SIZE_ONE], outExt); + DataCopyPad(keycacheGmTensor2[cacheSatartNz2], outTmpTensor[splitRmsNormSizeOne_], outExt); } else { // keycache1 - DataCopy(keycacheGmTensor1[cacheStart1], outTmpTensor, SPLIT_RMSNRORM_SIZE_ONE); + DataCopy(keycacheGmTensor1[cacheStart1], outTmpTensor, splitRmsNormSizeOne_); // keycache2 - DataCopy(keycacheGmTensor2[cacheStart2], outTmpTensor[SPLIT_RMSNRORM_SIZE_ONE], - SPLIT_RMSNRORM_SIZE_TWO); + DataCopy(keycacheGmTensor2[cacheStart2], outTmpTensor[splitRmsNormSizeOne_], + splitRmsNormSizeTwo_); } SET_FLAG(MTE3, MTE2, EVENT_ID1); WAIT_FLAG(MTE3, MTE2, EVENT_ID1); @@ -2823,20 +2851,20 @@ MLAOperation input_tensor = buf.GetBuffer(0); - AscendC::LocalTensor gamma_tensor = buf.GetBuffer(MM1_OUT_SIZE * 2); + AscendC::LocalTensor gamma_tensor = buf.GetBuffer(mm1OutSize_ * 2); AscendC::LocalTensor beta_tensor = - buf.GetBuffer(MM1_OUT_SIZE * 2 + SPLIT_SIZE_TWO * 2); + buf.GetBuffer(mm1OutSize_ * 2 + splitSizeTwo_ * 2); AscendC::LocalTensor scale_tensor = - buf.GetBuffer(MM1_OUT_SIZE * 2 + SPLIT_SIZE_TWO * 2 + SPLIT_SIZE_TWO * 2); + buf.GetBuffer(mm1OutSize_ * 2 + splitSizeTwo_ * 2 + splitSizeTwo_ * 2); 
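To make the unified-buffer carve-up in this hunk easier to follow, here is a sketch of the same byte-offset arithmetic (the struct and function names are mine; the real code indexes the UB buffer directly via GetBuffer, and element sizes follow the patch: *2 for half, *4 for float):

```cpp
#include <cstdint>

struct RmsNormQuantLayout {
    uint32_t input, gamma, beta, scale, offset, res1, res3;
};

RmsNormQuantLayout Carve(uint32_t mm1OutSize, uint32_t splitSizeTwo,
                         uint32_t numColAlignF32) {
    RmsNormQuantLayout l;
    l.input  = 0;                            // [mm1OutSize] half
    l.gamma  = mm1OutSize * 2;               // [splitSizeTwo] half
    l.beta   = l.gamma + splitSizeTwo * 2;   // [splitSizeTwo] half
    l.scale  = l.beta + splitSizeTwo * 2;    // quant scale, 32 B pad
    l.offset = l.scale + 32;                 // quant offset, 32 B pad
    l.res1   = l.scale + 64;                 // [numColAlignF32] float
    l.res3   = l.res1 + numColAlignF32 * 4;  // [numColAlignF32] float
    return l;
}
```

Because every term is a runtime field rather than `MM1_OUT_SIZE`/`SPLIT_SIZE_TWO`, the same carve-up holds for both the DeepSeek V3 and GLM5 column widths.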
AscendC::LocalTensor offset_tensor = buf.GetBuffer( - MM1_OUT_SIZE * 2 + SPLIT_SIZE_TWO * 2 + SPLIT_SIZE_TWO * 2 + 32); + mm1OutSize_ * 2 + splitSizeTwo_ * 2 + splitSizeTwo_ * 2 + 32); AscendC::LocalTensor res1_tensor = buf.GetBuffer( - MM1_OUT_SIZE * 2 + SPLIT_SIZE_TWO * 2 + SPLIT_SIZE_TWO * 2 + 64); + mm1OutSize_ * 2 + splitSizeTwo_ * 2 + splitSizeTwo_ * 2 + 64); AscendC::LocalTensor res3_tensor = buf.GetBuffer( - MM1_OUT_SIZE * 2 + SPLIT_SIZE_TWO * 2 + SPLIT_SIZE_TWO * 2 + 64 + num_col_align_f32 * 4); + mm1OutSize_ * 2 + splitSizeTwo_ * 2 + splitSizeTwo_ * 2 + 64 + num_col_align_f32 * 4); AscendC::LocalTensor output_tensor = buf.GetBuffer( - MM1_OUT_SIZE * 2 + SPLIT_SIZE_TWO * 2 + SPLIT_SIZE_TWO * 2 + 64 + num_col_align_f32 * 4 + - BUF_FACTOR * num_col_align_f32 * 4 + 64 + MM1_OUT_SIZE * 4 * 2 + 32); + mm1OutSize_ * 2 + splitSizeTwo_ * 2 + splitSizeTwo_ * 2 + 64 + num_col_align_f32 * 4 + + BUF_FACTOR * num_col_align_f32 * 4 + 64 + mm1OutSize_ * 4 * 2 + 32); rmsNormQuant2.Launch(output_tensor, input_tensor, gamma_tensor, beta_tensor, scale_tensor, offset_tensor, res1_tensor, res3_tensor); } @@ -2846,20 +2874,20 @@ MLAOperation input_tensor = buf.GetBuffer(0); - AscendC::LocalTensor gamma_tensor = buf.GetBuffer(MM1_OUT_SIZE * 2); + AscendC::LocalTensor gamma_tensor = buf.GetBuffer(mm1OutSize_ * 2); AscendC::LocalTensor sin_tensor = - buf.GetBuffer(MM1_OUT_SIZE * 2 + SPLIT_RMSNRORM_SIZE_ONE * 2); + buf.GetBuffer(mm1OutSize_ * 2 + splitRmsNormSizeOne_ * 2); AscendC::LocalTensor cos_tensor = buf.GetBuffer( - MM1_OUT_SIZE * 2 + SPLIT_RMSNRORM_SIZE_ONE * 2 + SPLIT_RMSNRORM_SIZE_TWO * 2); + mm1OutSize_ * 2 + splitRmsNormSizeOne_ * 2 + splitRmsNormSizeTwo_ * 2); AscendC::LocalTensor slotMapping_tensor = buf.GetBuffer( - MM1_OUT_SIZE * 2 + SPLIT_RMSNRORM_SIZE_ONE * 2 + SPLIT_RMSNRORM_SIZE_TWO * 4); + mm1OutSize_ * 2 + splitRmsNormSizeOne_ * 2 + splitRmsNormSizeTwo_ * 4); int32_t rms3_ub_offset = - MM1_OUT_SIZE * 2 + SPLIT_RMSNRORM_SIZE_ONE * 2 + SPLIT_RMSNRORM_SIZE_TWO * 4 + 4096 * 32; + mm1OutSize_ * 2 + splitRmsNormSizeOne_ * 2 + splitRmsNormSizeTwo_ * 4 + 4096 * 32; AscendC::LocalTensor tmp32_tensor = buf.GetBuffer(rms3_ub_offset); - int32_t out_ub_offset = MM1_OUT_SIZE * 2 + SPLIT_RMSNRORM_SIZE_ONE * 2 + SPLIT_RMSNRORM_SIZE_TWO * 4 + - 4096 * 32 + SPLIT_RMSNRORM_SIZE_ONE * 3 * 4 + SPLIT_RMSNRORM_SIZE_TWO * 2 * 4 + - MM1_OUT_SIZE * 4 * 2 + 32; + int32_t out_ub_offset = mm1OutSize_ * 2 + splitRmsNormSizeOne_ * 2 + splitRmsNormSizeTwo_ * 4 + + 4096 * 32 + splitRmsNormSizeOne_ * 3 * 4 + splitRmsNormSizeTwo_ * 2 * 4 + + mm1OutSize_ * 4 * 2 + 32; AscendC::LocalTensor temp_tensor = buf.GetBuffer(out_ub_offset); AscendC::LocalTensor tmpfp16; @@ -2873,7 +2901,7 @@ MLAOperation(rms3_ub_offset + 32); // int8out tmpfp16 = buf.GetBuffer(rms3_ub_offset + - SPLIT_RMSNRORM_SIZE_ONE * sizeof(float) * 2); + splitRmsNormSizeOne_ * sizeof(float) * 2); int8OutTensor = buf.GetBuffer(out_ub_offset); AscendC::DataCopy(quantScaleTensor, quantScale3GmTensor, AscendC::DataCopyParams(1, 1, 0, 0)); SET_FLAG(MTE2, V, EVENT_ID1); @@ -2890,11 +2918,11 @@ MLAOperation(GetSubBlockidx()); loopTime = (blockIdx_ == realCore - 1) ? lastCoreLoopTime : preCoreLoopTime; lastLoopN = (blockIdx_ == realCore - 1) ? 
lastCoreLoopNLast : preCoreLoopNLast; @@ -92,7 +94,7 @@ public: AscendC::LocalTensor inputQCastFP32 = buf.GetBuffer(dataSizeFp16); AscendC::LocalTensor reverseQ = buf.GetBuffer(dataSizeFp32 + dataSizeFp16); - uint64_t qOffset = startHead * 192 + 128; + uint64_t qOffset = startHead * hiddenStrideRope_ + qkNopeHeadDim_; CopyQGenReverseQ(inputQ, inputQCastFP32, reverseQ, qOffset, loopN); // move in cos/sin @@ -178,7 +180,7 @@ public: { // move in Q WAIT_FLAG(MTE3, MTE2, EVENT_ID1); - AscendC::DataCopy(tempBufQ, this->qGm_[qOffset], {loopN, headBlockLen, 128 / 16, 0}); + AscendC::DataCopy(tempBufQ, this->qGm_[qOffset], {loopN, headBlockLen, static_cast(qkNopeHeadDim_ / 16), 0}); SET_FLAG(MTE2, V, EVENT_ID1); WAIT_FLAG(MTE2, V, EVENT_ID1); // cast fp32 @@ -230,6 +232,8 @@ private: uint32_t lastCoreLoopTime; uint32_t lastCoreLoopNLast; uint32_t concatSize; + uint32_t hiddenStrideRope_; + uint32_t qkNopeHeadDim_; uint32_t blockIdx_; uint32_t loopTime{0}; uint32_t lastLoopN{0}; @@ -936,6 +940,16 @@ public: this->num_row = mlaParams_.n; this->epsilon_ = 1e-6; this->mlaParams = mlaParams_; + this->hiddenStateDim = mlaParams_.hiddenStateDim; + this->mm1OutSize_ = mlaParams_.mm1OutSize; + this->splitSizeOne_ = mlaParams_.splitSizeOne; + this->splitSizeTwo_ = mlaParams_.splitSizeTwo; + this->splitRmsNormSizeOne_ = mlaParams_.splitRmsNormSizeOne; + this->splitRmsNormSizeTwo_ = mlaParams_.splitRmsNormSizeTwo; + this->ropeSplitSizeOne_ = mlaParams_.ropeSplitSizeOne; + this->ropeSplitSizeTwo_ = mlaParams_.ropeSplitSizeTwo; + this->hiddenStrideRope_ = mlaParams_.hiddenStrideRope; + this->qkNopeHeadDim_ = mlaParams_.qkNopeHeadDim; } __aicore__ inline void Init(GM_ADDR hiddenStateGm, GM_ADDR wdqkvGm, GM_ADDR gamma2Gm, @@ -984,9 +998,9 @@ public: row_work_ = 0; } this->splitN = mlaParams.perTaskNum; - rmsNormQuant2.Init(gamma2GmTensor, beta2GmTensor, s3Gm, s1Gm, SPLIT_SIZE_ONE, - num_col_2, 0.000651041666, vectorBlockIdx * static_cast(row_work) * num_col_2, - vectorBlockIdx * static_cast(row_work) * SPLIT_SIZE_TWO, row_work_, mlaParams); + rmsNormQuant2.Init(gamma2GmTensor, beta2GmTensor, s3Gm, s1Gm, splitSizeOne_, + num_col_2, mlaParams.avgFactor, vectorBlockIdx * static_cast(row_work) * num_col_2, + vectorBlockIdx * static_cast(row_work) * splitSizeTwo_, row_work_, mlaParams); ropeFp16.RopeInit(s2Gm, cos2GmTensor, sin2GmTensor, qGmTensor, qGmTensor2, mlaParams); #endif } @@ -996,6 +1010,18 @@ public: __aicore__ inline void ProcessVector(); private: + // Model-specific MLA dimensions from tiling data + uint32_t hiddenStateDim; + uint32_t mm1OutSize_; + uint32_t splitSizeOne_; + uint32_t splitSizeTwo_; + uint32_t splitRmsNormSizeOne_; + uint32_t splitRmsNormSizeTwo_; + uint32_t ropeSplitSizeOne_; + uint32_t ropeSplitSizeTwo_; + uint32_t hiddenStrideRope_; + uint32_t qkNopeHeadDim_; + constexpr static uint32_t C0_SIZE = 16; constexpr static uint32_t I8_C0_SIZE = 32; @@ -1009,10 +1035,10 @@ private: const AscendC::LocalTensor &calTensor, const AscendC::LocalTensor &outTmpTensor) { int64_t slotMapGmOffset = vectorBlockIdx * row_work; - AscendC::DataCopy(gammaTensor, gamma3GmTensor, SPLIT_RMSNRORM_SIZE_ONE); + AscendC::DataCopy(gammaTensor, gamma3GmTensor, splitRmsNormSizeOne_); SET_FLAG(MTE2, V, EVENT_ID1); WAIT_FLAG(MTE2, V, EVENT_ID1); - Cast(gammaFp32, gammaTensor, AscendC::RoundMode::CAST_NONE, SPLIT_RMSNRORM_SIZE_ONE); + Cast(gammaFp32, gammaTensor, AscendC::RoundMode::CAST_NONE, splitRmsNormSizeOne_); AscendC::DataCopyPad(slotMappingTensor, slotMappingGmTensor[slotMapGmOffset], AscendC::DataCopyExtParams(1, 
sN * sizeof(int32_t), 0, 0, 0), AscendC::DataCopyPadExtParams(false, 0, 8 - sN % 8, 0)); @@ -1021,22 +1047,22 @@ private: SET_FLAG(MTE2, S, EVENT_ID2); WAIT_FLAG(MTE2, S, EVENT_ID2); for (uint64_t loop = 0; loop < sN; ++loop) { - uint64_t offset = vectorBlockIdx * static_cast(row_work) * num_col_2 + loop * MM1_OUT_SIZE; + uint64_t offset = vectorBlockIdx * static_cast(row_work) * num_col_2 + loop * mm1OutSize_; int64_t slotValue = static_cast(slotMappingTensor.GetValue(loop)); if (slotValue == -1) { continue; } AscendC::DataCopy(srcTensor, s3GmTensor[offset], - AscendC::DataCopyParams(1, MM1_OUT_SIZE / BLOCK_SIZE_16, 0, 0)); - AscendC::DataCopy(sinTensor, sin1GmTensor[(row_work * vectorBlockIdx + loop) * SPLIT_RMSNRORM_SIZE_TWO], - SPLIT_RMSNRORM_SIZE_TWO); - AscendC::DataCopy(cosTensor, cos1GmTensor[(row_work * vectorBlockIdx + loop) * SPLIT_RMSNRORM_SIZE_TWO], - SPLIT_RMSNRORM_SIZE_TWO); + AscendC::DataCopyParams(1, mm1OutSize_ / BLOCK_SIZE_16, 0, 0)); + AscendC::DataCopy(sinTensor, sin1GmTensor[(row_work * vectorBlockIdx + loop) * splitRmsNormSizeTwo_], + splitRmsNormSizeTwo_); + AscendC::DataCopy(cosTensor, cos1GmTensor[(row_work * vectorBlockIdx + loop) * splitRmsNormSizeTwo_], + splitRmsNormSizeTwo_); SET_FLAG(MTE2, V, EVENT_ID0); // ND - uint64_t cacheStart = static_cast(slotValue) * static_cast(SPLIT_SIZE_ONE); - uint64_t cacheStart1 = static_cast(slotValue) * static_cast(SPLIT_RMSNRORM_SIZE_ONE); - uint64_t cacheStart2 = static_cast(slotValue) * static_cast(SPLIT_RMSNRORM_SIZE_TWO); + uint64_t cacheStart = static_cast(slotValue) * static_cast(splitSizeOne_); + uint64_t cacheStart1 = static_cast(slotValue) * static_cast(splitRmsNormSizeOne_); + uint64_t cacheStart2 = static_cast(slotValue) * static_cast(splitRmsNormSizeTwo_); // NZ uint32_t outer_idx = slotValue / 128; uint32_t inner_idx = slotValue % 128; @@ -1044,63 +1070,63 @@ private: SET_FLAG(S, MTE3, EVENT_ID0); /* RmsNorm start */ WAIT_FLAG(MTE2, V, EVENT_ID0); - Cast(rmsNormTensor, srcTensor, AscendC::RoundMode::CAST_NONE, SPLIT_RMSNRORM_SIZE_ONE); + Cast(rmsNormTensor, srcTensor, AscendC::RoundMode::CAST_NONE, splitRmsNormSizeOne_); AscendC::PipeBarrier(); - Mul(calTensor, rmsNormTensor, rmsNormTensor, SPLIT_RMSNRORM_SIZE_ONE); + Mul(calTensor, rmsNormTensor, rmsNormTensor, splitRmsNormSizeOne_); AscendC::PipeBarrier(); - ReduceSumCustom(calTensor[SPLIT_RMSNRORM_SIZE_ONE], calTensor, calTensor[SPLIT_RMSNRORM_SIZE_ONE * 2], - SPLIT_RMSNRORM_SIZE_ONE); + ReduceSumCustom(calTensor[splitRmsNormSizeOne_], calTensor, calTensor[splitRmsNormSizeOne_ * 2], + splitRmsNormSizeOne_); SET_FLAG(V, S, EVENT_ID1); WAIT_FLAG(V, S, EVENT_ID1); - float rms = sqrt(calTensor.GetValue(SPLIT_RMSNRORM_SIZE_ONE) / SPLIT_RMSNRORM_SIZE_ONE + epsilon_); + float rms = sqrt(calTensor.GetValue(splitRmsNormSizeOne_) / splitRmsNormSizeOne_ + epsilon_); SET_FLAG(S, V, EVENT_ID1); WAIT_FLAG(S, V, EVENT_ID1); AscendC::PipeBarrier(); - Duplicate(calTensor, rms, SPLIT_RMSNRORM_SIZE_ONE); + Duplicate(calTensor, rms, splitRmsNormSizeOne_); AscendC::PipeBarrier(); - Div(calTensor, rmsNormTensor, calTensor, SPLIT_RMSNRORM_SIZE_ONE); + Div(calTensor, rmsNormTensor, calTensor, splitRmsNormSizeOne_); AscendC::PipeBarrier(); - Mul(rmsNormTensor, gammaFp32, calTensor, SPLIT_RMSNRORM_SIZE_ONE); + Mul(rmsNormTensor, gammaFp32, calTensor, splitRmsNormSizeOne_); AscendC::PipeBarrier(); - Cast(outTmpTensor, rmsNormTensor, AscendC::RoundMode::CAST_RINT, SPLIT_RMSNRORM_SIZE_ONE); + Cast(outTmpTensor, rmsNormTensor, AscendC::RoundMode::CAST_RINT, splitRmsNormSizeOne_); /* 
RmsNorm end */ /* Rope K start */ - uint64_t revertOffset = SPLIT_RMSNRORM_SIZE_TWO / 2; - Cast(ropeKTensor, srcTensor[SPLIT_RMSNRORM_SIZE_ONE], AscendC::RoundMode::CAST_NONE, - SPLIT_RMSNRORM_SIZE_TWO); - Cast(ropeKRevertTensor[revertOffset], srcTensor[SPLIT_RMSNRORM_SIZE_ONE], AscendC::RoundMode::CAST_NONE, + uint64_t revertOffset = splitRmsNormSizeTwo_ / 2; + Cast(ropeKTensor, srcTensor[splitRmsNormSizeOne_], AscendC::RoundMode::CAST_NONE, + splitRmsNormSizeTwo_); + Cast(ropeKRevertTensor[revertOffset], srcTensor[splitRmsNormSizeOne_], AscendC::RoundMode::CAST_NONE, revertOffset); - Cast(ropeKRevertTensor, srcTensor[SPLIT_RMSNRORM_SIZE_ONE + revertOffset], AscendC::RoundMode::CAST_NONE, + Cast(ropeKRevertTensor, srcTensor[splitRmsNormSizeOne_ + revertOffset], AscendC::RoundMode::CAST_NONE, revertOffset); Duplicate(calTensor, static_cast(-1), revertOffset); Duplicate(calTensor[revertOffset], static_cast(1), revertOffset); AscendC::PipeBarrier(); - Cast(calTensor[SPLIT_RMSNRORM_SIZE_TWO], cosTensor, AscendC::RoundMode::CAST_NONE, SPLIT_RMSNRORM_SIZE_TWO); - Cast(calTensor[SPLIT_RMSNRORM_SIZE_TWO * 2], sinTensor, AscendC::RoundMode::CAST_NONE, - SPLIT_RMSNRORM_SIZE_TWO); + Cast(calTensor[splitRmsNormSizeTwo_], cosTensor, AscendC::RoundMode::CAST_NONE, splitRmsNormSizeTwo_); + Cast(calTensor[splitRmsNormSizeTwo_ * 2], sinTensor, AscendC::RoundMode::CAST_NONE, + splitRmsNormSizeTwo_); AscendC::PipeBarrier(); - Mul(ropeKTensor, calTensor[SPLIT_RMSNRORM_SIZE_TWO], ropeKTensor, SPLIT_RMSNRORM_SIZE_TWO); - Mul(ropeKRevertTensor, calTensor[SPLIT_RMSNRORM_SIZE_TWO * 2], ropeKRevertTensor, SPLIT_RMSNRORM_SIZE_TWO); + Mul(ropeKTensor, calTensor[splitRmsNormSizeTwo_], ropeKTensor, splitRmsNormSizeTwo_); + Mul(ropeKRevertTensor, calTensor[splitRmsNormSizeTwo_ * 2], ropeKRevertTensor, splitRmsNormSizeTwo_); AscendC::PipeBarrier(); - Mul(ropeKRevertTensor, calTensor, ropeKRevertTensor, SPLIT_RMSNRORM_SIZE_TWO); + Mul(ropeKRevertTensor, calTensor, ropeKRevertTensor, splitRmsNormSizeTwo_); AscendC::PipeBarrier(); - Add(ropeKRevertTensor, ropeKTensor, ropeKRevertTensor, SPLIT_RMSNRORM_SIZE_TWO); + Add(ropeKRevertTensor, ropeKTensor, ropeKRevertTensor, splitRmsNormSizeTwo_); AscendC::PipeBarrier(); - Cast(outTmpTensor[SPLIT_RMSNRORM_SIZE_ONE], ropeKRevertTensor, AscendC::RoundMode::CAST_RINT, - SPLIT_RMSNRORM_SIZE_TWO); + Cast(outTmpTensor[splitRmsNormSizeOne_], ropeKRevertTensor, AscendC::RoundMode::CAST_RINT, + splitRmsNormSizeTwo_); AscendC::PipeBarrier(); /* Rope K end */ SET_FLAG(V, MTE3, EVENT_ID0); WAIT_FLAG(V, MTE3, EVENT_ID0); WAIT_FLAG(S, MTE3, EVENT_ID0); if constexpr (CACHE_MODE == CACHE_MODE_KVCACHE) { - DataCopy(keycacheGmTensor1[cacheStart], outTmpTensor, SPLIT_SIZE_ONE); + DataCopy(keycacheGmTensor1[cacheStart], outTmpTensor, splitSizeOne_); } else { // keycache1 - DataCopy(keycacheGmTensor1[cacheStart1], outTmpTensor, SPLIT_RMSNRORM_SIZE_ONE); + DataCopy(keycacheGmTensor1[cacheStart1], outTmpTensor, splitRmsNormSizeOne_); // keycache2 - DataCopy(keycacheGmTensor2[cacheStart2], outTmpTensor[SPLIT_RMSNRORM_SIZE_ONE], - SPLIT_RMSNRORM_SIZE_TWO); + DataCopy(keycacheGmTensor2[cacheStart2], outTmpTensor[splitRmsNormSizeOne_], + splitRmsNormSizeTwo_); } SET_FLAG(MTE3, MTE2, EVENT_ID1); WAIT_FLAG(MTE3, MTE2, EVENT_ID1); @@ -1196,16 +1222,16 @@ MLAOperation:: uint32_t num_col_align_f16 = (num_col_2 + REPEAT_TIME_128 - 1) / REPEAT_TIME_128 * REPEAT_TIME_128; uint32_t num_col_align_f32 = (num_col_2 + REPEAT_TIME_64 - 1) / REPEAT_TIME_64 * REPEAT_TIME_64; AscendC::LocalTensor input_tensor = 
buf.GetBuffer(0); - AscendC::LocalTensor gamma_tensor = buf.GetBuffer(MM1_OUT_SIZE * 2); + AscendC::LocalTensor gamma_tensor = buf.GetBuffer(mm1OutSize_ * 2); AscendC::LocalTensor beta_tensor = - buf.GetBuffer(MM1_OUT_SIZE * 2 + SPLIT_SIZE_TWO * 2); + buf.GetBuffer(mm1OutSize_ * 2 + splitSizeTwo_ * 2); AscendC::LocalTensor res1_tensor = buf.GetBuffer( - MM1_OUT_SIZE * 2 + SPLIT_SIZE_TWO * 2 + SPLIT_SIZE_TWO * 2 + 64); + mm1OutSize_ * 2 + splitSizeTwo_ * 2 + splitSizeTwo_ * 2 + 64); AscendC::LocalTensor res3_tensor = buf.GetBuffer( - MM1_OUT_SIZE * 2 + SPLIT_SIZE_TWO * 2 + SPLIT_SIZE_TWO * 2 + 64 + num_col_align_f32 * 4); + mm1OutSize_ * 2 + splitSizeTwo_ * 2 + splitSizeTwo_ * 2 + 64 + num_col_align_f32 * 4); AscendC::LocalTensor output_tensor = buf.GetBuffer( - MM1_OUT_SIZE * 2 + SPLIT_SIZE_TWO * 2 + SPLIT_SIZE_TWO * 2 + 64 + num_col_align_f32 * 4 + - BUF_FACTOR * num_col_align_f32 * 4 + 64 + MM1_OUT_SIZE * 4 * 2 + 32); + mm1OutSize_ * 2 + splitSizeTwo_ * 2 + splitSizeTwo_ * 2 + 64 + num_col_align_f32 * 4 + + BUF_FACTOR * num_col_align_f32 * 4 + 64 + mm1OutSize_ * 4 * 2 + 32); rmsNormQuant2.Launch(output_tensor, input_tensor, gamma_tensor, beta_tensor, res1_tensor, res3_tensor); } FftsCrossCoreSync(RMSNORMQUANT2); @@ -1214,20 +1240,20 @@ MLAOperation:: if (row_work_ != 0) { AscendC::LocalTensor input_tensor = buf.GetBuffer(0); - AscendC::LocalTensor gamma_tensor = buf.GetBuffer(MM1_OUT_SIZE * 2); + AscendC::LocalTensor gamma_tensor = buf.GetBuffer(mm1OutSize_ * 2); AscendC::LocalTensor sin_tensor = - buf.GetBuffer(MM1_OUT_SIZE * 2 + SPLIT_RMSNRORM_SIZE_ONE * 2); + buf.GetBuffer(mm1OutSize_ * 2 + splitRmsNormSizeOne_ * 2); AscendC::LocalTensor cos_tensor = buf.GetBuffer( - MM1_OUT_SIZE * 2 + SPLIT_RMSNRORM_SIZE_ONE * 2 + SPLIT_RMSNRORM_SIZE_TWO * 2); + mm1OutSize_ * 2 + splitRmsNormSizeOne_ * 2 + splitRmsNormSizeTwo_ * 2); AscendC::LocalTensor slotMapping_tensor = buf.GetBuffer( - MM1_OUT_SIZE * 2 + SPLIT_RMSNRORM_SIZE_ONE * 2 + SPLIT_RMSNRORM_SIZE_TWO * 4); + mm1OutSize_ * 2 + splitRmsNormSizeOne_ * 2 + splitRmsNormSizeTwo_ * 4); int32_t rms3_ub_offset = - MM1_OUT_SIZE * 2 + SPLIT_RMSNRORM_SIZE_ONE * 2 + SPLIT_RMSNRORM_SIZE_TWO * 4 + 4096 * 32; + mm1OutSize_ * 2 + splitRmsNormSizeOne_ * 2 + splitRmsNormSizeTwo_ * 4 + 4096 * 32; AscendC::LocalTensor tmp32_tensor = buf.GetBuffer(rms3_ub_offset); - int32_t out_ub_offset = MM1_OUT_SIZE * 2 + SPLIT_RMSNRORM_SIZE_ONE * 2 + SPLIT_RMSNRORM_SIZE_TWO * 4 + - 4096 * 32 + SPLIT_RMSNRORM_SIZE_ONE * 3 * 4 + SPLIT_RMSNRORM_SIZE_TWO * 2 * 4 + - MM1_OUT_SIZE * 4 * 2 + 32; + int32_t out_ub_offset = mm1OutSize_ * 2 + splitRmsNormSizeOne_ * 2 + splitRmsNormSizeTwo_ * 4 + + 4096 * 32 + splitRmsNormSizeOne_ * 3 * 4 + splitRmsNormSizeTwo_ * 2 * 4 + + mm1OutSize_ * 4 * 2 + 32; AscendC::LocalTensor temp_tensor = buf.GetBuffer(out_ub_offset); RmsNormAndRopeConvergence1( @@ -1236,11 +1262,11 @@ MLAOperation:: sin_tensor, // sin cos_tensor, // cons slotMapping_tensor, // slotMapping - row_work_, tmp32_tensor, tmp32_tensor[SPLIT_RMSNRORM_SIZE_ONE], - tmp32_tensor[SPLIT_RMSNRORM_SIZE_ONE + SPLIT_RMSNRORM_SIZE_ONE], - tmp32_tensor[SPLIT_RMSNRORM_SIZE_ONE + SPLIT_RMSNRORM_SIZE_ONE + SPLIT_RMSNRORM_SIZE_TWO], - tmp32_tensor[SPLIT_RMSNRORM_SIZE_ONE + SPLIT_RMSNRORM_SIZE_ONE + SPLIT_RMSNRORM_SIZE_TWO + - SPLIT_RMSNRORM_SIZE_TWO], + row_work_, tmp32_tensor, tmp32_tensor[splitRmsNormSizeOne_], + tmp32_tensor[splitRmsNormSizeOne_ + splitRmsNormSizeOne_], + tmp32_tensor[splitRmsNormSizeOne_ + splitRmsNormSizeOne_ + splitRmsNormSizeTwo_], + tmp32_tensor[splitRmsNormSizeOne_ + 
splitRmsNormSizeOne_ + splitRmsNormSizeTwo_ + + splitRmsNormSizeTwo_], temp_tensor); } diff --git a/csrc/mla_preprocess/op_kernel/mla_preprocess_mix_bf16_qdown.hpp b/csrc/mla_preprocess/op_kernel/mla_preprocess_mix_bf16_qdown.hpp index 1df954e8..1aa7bf37 100644 --- a/csrc/mla_preprocess/op_kernel/mla_preprocess_mix_bf16_qdown.hpp +++ b/csrc/mla_preprocess/op_kernel/mla_preprocess_mix_bf16_qdown.hpp @@ -54,6 +54,8 @@ public: lastCoreLoopTime = ropeConcatParams.lastCoreLoopTime; lastCoreLoopNLast = ropeConcatParams.lastCoreLoopNLast; concatSize = ropeConcatParams.concatSize; + hiddenStrideRope_ = ropeConcatParams.hiddenStrideRope; + qkNopeHeadDim_ = ropeConcatParams.qkNopeHeadDim; blockIdx_ = (blockIdx_ / 2) * 2 + static_cast(GetSubBlockidx()); loopTime = (blockIdx_ == realCore - 1) ? lastCoreLoopTime : preCoreLoopTime; lastLoopN = (blockIdx_ == realCore - 1) ? lastCoreLoopNLast : preCoreLoopNLast; @@ -92,7 +94,7 @@ public: AscendC::LocalTensor inputQCastFP32 = buf.GetBuffer(dataSizeFp16); AscendC::LocalTensor reverseQ = buf.GetBuffer(dataSizeFp32 + dataSizeFp16); - uint64_t qOffset = startHead * 192 + 128; + uint64_t qOffset = startHead * hiddenStrideRope_ + qkNopeHeadDim_; CopyQGenReverseQ(inputQ, inputQCastFP32, reverseQ, qOffset, loopN); // move in cos/sin @@ -178,7 +180,7 @@ public: { // move in Q WAIT_FLAG(MTE3, MTE2, EVENT_ID1); - AscendC::DataCopy(tempBufQ, this->qGm_[qOffset], {loopN, headBlockLen, 128 / 16, 0}); + AscendC::DataCopy(tempBufQ, this->qGm_[qOffset], {loopN, headBlockLen, static_cast(qkNopeHeadDim_ / 16), 0}); SET_FLAG(MTE2, V, EVENT_ID1); WAIT_FLAG(MTE2, V, EVENT_ID1); // cast fp32 @@ -230,6 +232,8 @@ private: uint32_t lastCoreLoopTime; uint32_t lastCoreLoopNLast; uint32_t concatSize; + uint32_t hiddenStrideRope_; + uint32_t qkNopeHeadDim_; uint32_t blockIdx_; uint32_t loopTime{0}; uint32_t lastLoopN{0}; @@ -309,6 +313,7 @@ public: this->num_row_ = mlaParams_.n; this->row_work = row_work; this->row_work_ = row_work_; + this->mm1OutSize_ = mlaParams_.mm1OutSize; gm_offset_ = gm_offset; gm_out_offset_ = gm_out_offset; num_col_align_int8 = (num_col_ + REPEAT_TIME_256 - 1) / REPEAT_TIME_256 * REPEAT_TIME_256; @@ -353,8 +358,8 @@ public: if constexpr (NEED_DEQUANT) { mmTensor = buf.ReinterpretCast()[OFFSET_WORKSPACE_BF16 * num_col_align_withStride_fp32 + 16]; - deScaleTensor = buf[OFFSET_WORKSPACE_BF16 * num_col_align_withStride_fp32 + 16 + MM1_OUT_SIZE]; - perTokenDescaleTensor = buf[OFFSET_WORKSPACE_BF16 * num_col_align_withStride_fp32 + 16 + MM1_OUT_SIZE * 2]; + deScaleTensor = buf[OFFSET_WORKSPACE_BF16 * num_col_align_withStride_fp32 + 16 + mm1OutSize_]; + perTokenDescaleTensor = buf[OFFSET_WORKSPACE_BF16 * num_col_align_withStride_fp32 + 16 + mm1OutSize_ * 2]; AscendC::DataCopy(deScaleTensor, perChannelDescaleGmTensor, AscendC::DataCopyParams(1, num_col_ / 8, 0, 0)); } @@ -531,6 +536,7 @@ private: uint32_t num_slice_{0}; uint32_t tail_size_{0}; uint32_t tail_copy_{0}; + uint32_t mm1OutSize_{0}; }; template num_row_ = mlaParams_.n; this->row_work = row_work; this->row_work_ = row_work_; + this->mm1OutSize_ = mlaParams_.mm1OutSize; gm_offset_ = gm_offset; gm_out_offset_ = gm_out_offset; num_col_align_int8 = (num_col_ + REPEAT_TIME_256 - 1) / REPEAT_TIME_256 * REPEAT_TIME_256; @@ -620,8 +627,8 @@ public: if constexpr (NEED_DEQUANT) { mmTensor = buf.ReinterpretCast()[OFFSET_WORKSPACE_BF16 * num_col_align_withStride_fp32 + 16]; - deScaleTensor = buf[OFFSET_WORKSPACE_BF16 * num_col_align_withStride_fp32 + 16 + MM1_OUT_SIZE]; - perTokenDescaleTensor = 
buf[OFFSET_WORKSPACE_BF16 * num_col_align_withStride_fp32 + 16 + MM1_OUT_SIZE * 2]; + deScaleTensor = buf[OFFSET_WORKSPACE_BF16 * num_col_align_withStride_fp32 + 16 + mm1OutSize_]; + perTokenDescaleTensor = buf[OFFSET_WORKSPACE_BF16 * num_col_align_withStride_fp32 + 16 + mm1OutSize_ * 2]; AscendC::DataCopy(deScaleTensor, perChannelDescaleGmTensor, AscendC::DataCopyParams(1, num_col_ / 8, 0, 0)); } @@ -857,6 +864,7 @@ private: uint32_t num_slice_{0}; uint32_t tail_size_{0}; uint32_t tail_copy_{0}; + uint32_t mm1OutSize_{0}; }; template @@ -2407,6 +2415,15 @@ public: this->epsilon_ = 1e-6; this->mlaParams = mlaParams_; this->hiddenStateDim = mlaParams_.hiddenStateDim; + this->mm1OutSize_ = mlaParams_.mm1OutSize; + this->splitSizeOne_ = mlaParams_.splitSizeOne; + this->splitSizeTwo_ = mlaParams_.splitSizeTwo; + this->splitRmsNormSizeOne_ = mlaParams_.splitRmsNormSizeOne; + this->splitRmsNormSizeTwo_ = mlaParams_.splitRmsNormSizeTwo; + this->ropeSplitSizeOne_ = mlaParams_.ropeSplitSizeOne; + this->ropeSplitSizeTwo_ = mlaParams_.ropeSplitSizeTwo; + this->hiddenStrideRope_ = mlaParams_.hiddenStrideRope; + this->qkNopeHeadDim_ = mlaParams_.qkNopeHeadDim; } __aicore__ inline void Init(GM_ADDR hiddenStateGm, GM_ADDR quantScale1Gm, @@ -2492,15 +2509,15 @@ public: vectorBlockIdx * static_cast(row_work) * num_col_1, row_work_, mlaParams); if constexpr (quantMode == QuantMode::PER_TENSOR_ASYMM_QUANT) { rmsNormQuant2.Init(gamma2GmTensor, beta2GmTensor, quantScale2GmTensor, quantOffset2GmTensor, - s5Gm + row_work * vectorBlockIdx * sizeof(float), descale1Gm, s3Gm, s1Gm, SPLIT_SIZE_ONE, - num_col_2, 0.000651041666, vectorBlockIdx * static_cast(row_work) * num_col_2, - vectorBlockIdx * static_cast(row_work) * SPLIT_SIZE_TWO, row_work_, mlaParams, innerGmTensor); + s5Gm + row_work * vectorBlockIdx * sizeof(float), descale1Gm, s3Gm, s1Gm, splitSizeOne_, + num_col_2, mlaParams.avgFactor, vectorBlockIdx * static_cast(row_work) * num_col_2, + vectorBlockIdx * static_cast(row_work) * splitSizeTwo_, row_work_, mlaParams, innerGmTensor); } else { // quantMode == QuantMode::PER_TOKEN_SYMM_QUANT rmsNormQuant2.Init(gamma2GmTensor, beta2GmTensor, quantScale2GmTensor, quantOffset2GmTensor, - s5Gm + row_work * vectorBlockIdx * sizeof(float), descale1Gm, s2Gm, s1Gm, SPLIT_SIZE_ONE, - num_col_2, 0.000651041666, vectorBlockIdx * static_cast(row_work) * num_col_2, - vectorBlockIdx * static_cast(row_work) * SPLIT_SIZE_TWO, row_work_, mlaParams, innerGmTensor); + s5Gm + row_work * vectorBlockIdx * sizeof(float), descale1Gm, s2Gm, s1Gm, splitSizeOne_, + num_col_2, mlaParams.avgFactor, vectorBlockIdx * static_cast(row_work) * num_col_2, + vectorBlockIdx * static_cast(row_work) * splitSizeTwo_, row_work_, mlaParams, innerGmTensor); } ropeFp16.RopeInit(s4Gm, cos2GmTensor, sin2GmTensor, qGmTensor, qGmTensor2, mlaParams); einSumQuant.Init(s1Gm, gmQnopeScale, qGm, mlaParams); @@ -2512,6 +2529,17 @@ public: __aicore__ inline void ProcessVector(); private: + // Model-specific MLA dimensions from tiling data + uint32_t mm1OutSize_; + uint32_t splitSizeOne_; + uint32_t splitSizeTwo_; + uint32_t splitRmsNormSizeOne_; + uint32_t splitRmsNormSizeTwo_; + uint32_t ropeSplitSizeOne_; + uint32_t ropeSplitSizeTwo_; + uint32_t hiddenStrideRope_; + uint32_t qkNopeHeadDim_; + constexpr static uint32_t C0_SIZE = 16; constexpr static uint32_t I8_C0_SIZE = 32; @@ -2526,44 +2554,44 @@ private: AscendC::LocalTensor &tmpfp16, AscendC::LocalTensor &int8OutTensor, float quantScale3) { int64_t slotMapGmOffset = vectorBlockIdx * row_work; - 
AscendC::DataCopy(gammaTensor, gamma3GmTensor, SPLIT_RMSNRORM_SIZE_ONE); + AscendC::DataCopy(gammaTensor, gamma3GmTensor, splitRmsNormSizeOne_); SET_FLAG(MTE2, V, EVENT_ID1); WAIT_FLAG(MTE2, V, EVENT_ID1); - Cast(gammaFp32, gammaTensor, AscendC::RoundMode::CAST_NONE, SPLIT_RMSNRORM_SIZE_ONE); + Cast(gammaFp32, gammaTensor, AscendC::RoundMode::CAST_NONE, splitRmsNormSizeOne_); AscendC::DataCopyPad(slotMappingTensor, slotMappingGmTensor[slotMapGmOffset], AscendC::DataCopyExtParams(1, sN * sizeof(int32_t), 0, 0, 0), AscendC::DataCopyPadExtParams(false, 0, 8 - sN % 8, 0)); if constexpr (quantMode == QuantMode::PER_TOKEN_SYMM_QUANT) { - mmTensor = calTensor.ReinterpretCast()[SPLIT_SIZE_ONE]; - deScaleTensor = calTensor.ReinterpretCast()[SPLIT_SIZE_ONE * 2]; - AscendC::DataCopy(deScaleTensor, descale1gmTensor, AscendC::DataCopyParams(1, SPLIT_SIZE_ONE / 8, 0, 0)); + mmTensor = calTensor.ReinterpretCast()[splitSizeOne_]; + deScaleTensor = calTensor.ReinterpretCast()[splitSizeOne_ * 2]; + AscendC::DataCopy(deScaleTensor, descale1gmTensor, AscendC::DataCopyParams(1, splitSizeOne_ / 8, 0, 0)); } SET_FLAG(MTE2, V, EVENT_ID2); WAIT_FLAG(MTE2, V, EVENT_ID2); SET_FLAG(MTE2, S, EVENT_ID2); WAIT_FLAG(MTE2, S, EVENT_ID2); for (uint64_t loop = 0; loop < sN; ++loop) { - uint64_t offset = vectorBlockIdx * static_cast(row_work) * num_col_2 + loop * MM1_OUT_SIZE; + uint64_t offset = vectorBlockIdx * static_cast(row_work) * num_col_2 + loop * mm1OutSize_; int64_t slotValue = static_cast(slotMappingTensor.GetValue(loop)); if (slotValue == -1) { continue; } if constexpr (quantMode == QuantMode::PER_TENSOR_ASYMM_QUANT) { AscendC::DataCopy(srcTensor, s3GmTensor[offset], - AscendC::DataCopyParams(1, MM1_OUT_SIZE / BLOCK_SIZE_16, 0, 0)); + AscendC::DataCopyParams(1, mm1OutSize_ / BLOCK_SIZE_16, 0, 0)); } else { // quantMode == QuantMode::PER_TOKEN_SYMM_QUANT - AscendC::DataCopy(mmTensor, s2GmTensor[offset], AscendC::DataCopyParams(1, SPLIT_SIZE_ONE / 8, 0, 0)); + AscendC::DataCopy(mmTensor, s2GmTensor[offset], AscendC::DataCopyParams(1, splitSizeOne_ / 8, 0, 0)); } - AscendC::DataCopy(sinTensor, sin1GmTensor[(row_work * vectorBlockIdx + loop) * SPLIT_RMSNRORM_SIZE_TWO], - SPLIT_RMSNRORM_SIZE_TWO); - AscendC::DataCopy(cosTensor, cos1GmTensor[(row_work * vectorBlockIdx + loop) * SPLIT_RMSNRORM_SIZE_TWO], - SPLIT_RMSNRORM_SIZE_TWO); + AscendC::DataCopy(sinTensor, sin1GmTensor[(row_work * vectorBlockIdx + loop) * splitRmsNormSizeTwo_], + splitRmsNormSizeTwo_); + AscendC::DataCopy(cosTensor, cos1GmTensor[(row_work * vectorBlockIdx + loop) * splitRmsNormSizeTwo_], + splitRmsNormSizeTwo_); SET_FLAG(MTE2, V, EVENT_ID0); // ND - uint64_t cacheStart = static_cast(slotValue) * static_cast(SPLIT_SIZE_ONE); - uint64_t cacheStart1 = static_cast(slotValue) * static_cast(SPLIT_RMSNRORM_SIZE_ONE); - uint64_t cacheStart2 = static_cast(slotValue) * static_cast(SPLIT_RMSNRORM_SIZE_TWO); + uint64_t cacheStart = static_cast(slotValue) * static_cast(splitSizeOne_); + uint64_t cacheStart1 = static_cast(slotValue) * static_cast(splitRmsNormSizeOne_); + uint64_t cacheStart2 = static_cast(slotValue) * static_cast(splitRmsNormSizeTwo_); // NZ uint32_t outer_idx = slotValue / 128; uint32_t inner_idx = slotValue % 128; @@ -2574,84 +2602,84 @@ private: if constexpr (quantMode == QuantMode::PER_TOKEN_SYMM_QUANT) { /* DeQuant */ AscendC::Cast(mmTensor.ReinterpretCast(), mmTensor, AscendC::RoundMode::CAST_NONE, - SPLIT_SIZE_ONE); + splitSizeOne_); AscendC::PipeBarrier(); AscendC::Mul(mmTensor.ReinterpretCast(), mmTensor.ReinterpretCast(), 
deScaleTensor, - SPLIT_SIZE_ONE); + splitSizeOne_); AscendC::PipeBarrier(); float perTokenDescale = s5GmTensor.GetValue(row_work * vectorBlockIdx + loop); SET_FLAG(S, V, EVENT_ID0); WAIT_FLAG(S, V, EVENT_ID0); AscendC::Muls(mmTensor.ReinterpretCast(), mmTensor.ReinterpretCast(), perTokenDescale, - SPLIT_SIZE_ONE); + splitSizeOne_); AscendC::PipeBarrier(); AscendC::Cast(srcTensor, mmTensor.ReinterpretCast(), AscendC::RoundMode::CAST_RINT, - SPLIT_SIZE_ONE); + splitSizeOne_); AscendC::PipeBarrier(); } - Cast(rmsNormTensor, srcTensor, AscendC::RoundMode::CAST_NONE, SPLIT_RMSNRORM_SIZE_ONE); + Cast(rmsNormTensor, srcTensor, AscendC::RoundMode::CAST_NONE, splitRmsNormSizeOne_); AscendC::PipeBarrier(); - Mul(calTensor, rmsNormTensor, rmsNormTensor, SPLIT_RMSNRORM_SIZE_ONE); + Mul(calTensor, rmsNormTensor, rmsNormTensor, splitRmsNormSizeOne_); AscendC::PipeBarrier(); - ReduceSumCustom(calTensor[SPLIT_RMSNRORM_SIZE_ONE], calTensor, calTensor[SPLIT_RMSNRORM_SIZE_ONE * 2], - SPLIT_RMSNRORM_SIZE_ONE); + ReduceSumCustom(calTensor[splitRmsNormSizeOne_], calTensor, calTensor[splitRmsNormSizeOne_ * 2], + splitRmsNormSizeOne_); SET_FLAG(V, S, EVENT_ID1); WAIT_FLAG(V, S, EVENT_ID1); - float rms = sqrt(calTensor.GetValue(SPLIT_RMSNRORM_SIZE_ONE) / SPLIT_RMSNRORM_SIZE_ONE + epsilon_); + float rms = sqrt(calTensor.GetValue(splitRmsNormSizeOne_) / splitRmsNormSizeOne_ + epsilon_); SET_FLAG(S, V, EVENT_ID1); WAIT_FLAG(S, V, EVENT_ID1); AscendC::PipeBarrier(); - Duplicate(calTensor, rms, SPLIT_RMSNRORM_SIZE_ONE); + Duplicate(calTensor, rms, splitRmsNormSizeOne_); AscendC::PipeBarrier(); - Div(calTensor, rmsNormTensor, calTensor, SPLIT_RMSNRORM_SIZE_ONE); + Div(calTensor, rmsNormTensor, calTensor, splitRmsNormSizeOne_); AscendC::PipeBarrier(); - Mul(rmsNormTensor, gammaFp32, calTensor, SPLIT_RMSNRORM_SIZE_ONE); + Mul(rmsNormTensor, gammaFp32, calTensor, splitRmsNormSizeOne_); AscendC::PipeBarrier(); if constexpr (CACHE_MODE == CACHE_MODE_INT8_NZCACHE) { // quant - Muls(rmsNormTensor, rmsNormTensor, quantScale3, SPLIT_RMSNRORM_SIZE_ONE); + Muls(rmsNormTensor, rmsNormTensor, quantScale3, splitRmsNormSizeOne_); AscendC::PipeBarrier(); - CastFrom32To16(tmpfp16, rmsNormTensor, SPLIT_RMSNRORM_SIZE_ONE); + CastFrom32To16(tmpfp16, rmsNormTensor, splitRmsNormSizeOne_); AscendC::PipeBarrier(); - CastFromF16ToI8(int8OutTensor, tmpfp16, -128, SPLIT_RMSNRORM_SIZE_ONE); + CastFromF16ToI8(int8OutTensor, tmpfp16, -128, splitRmsNormSizeOne_); AscendC::PipeBarrier(); } else { AscendC::PipeBarrier(); if (std::is_same::value) { - Cast(outTmpTensor, rmsNormTensor, AscendC::RoundMode::CAST_RINT, SPLIT_RMSNRORM_SIZE_ONE); + Cast(outTmpTensor, rmsNormTensor, AscendC::RoundMode::CAST_RINT, splitRmsNormSizeOne_); } else { - Cast(outTmpTensor, rmsNormTensor, AscendC::RoundMode::CAST_NONE, SPLIT_RMSNRORM_SIZE_ONE); + Cast(outTmpTensor, rmsNormTensor, AscendC::RoundMode::CAST_NONE, splitRmsNormSizeOne_); } } /* RmsNorm end */ /* Rope K start */ - uint64_t revertOffset = SPLIT_RMSNRORM_SIZE_TWO / 2; - Cast(ropeKTensor, srcTensor[SPLIT_RMSNRORM_SIZE_ONE], AscendC::RoundMode::CAST_NONE, - SPLIT_RMSNRORM_SIZE_TWO); - Cast(ropeKRevertTensor[revertOffset], srcTensor[SPLIT_RMSNRORM_SIZE_ONE], AscendC::RoundMode::CAST_NONE, + uint64_t revertOffset = splitRmsNormSizeTwo_ / 2; + Cast(ropeKTensor, srcTensor[splitRmsNormSizeOne_], AscendC::RoundMode::CAST_NONE, + splitRmsNormSizeTwo_); + Cast(ropeKRevertTensor[revertOffset], srcTensor[splitRmsNormSizeOne_], AscendC::RoundMode::CAST_NONE, revertOffset); - Cast(ropeKRevertTensor, 
srcTensor[SPLIT_RMSNRORM_SIZE_ONE + revertOffset], AscendC::RoundMode::CAST_NONE, + Cast(ropeKRevertTensor, srcTensor[splitRmsNormSizeOne_ + revertOffset], AscendC::RoundMode::CAST_NONE, revertOffset); Duplicate(calTensor, static_cast(-1), revertOffset); Duplicate(calTensor[revertOffset], static_cast(1), revertOffset); AscendC::PipeBarrier(); - Cast(calTensor[SPLIT_RMSNRORM_SIZE_TWO], cosTensor, AscendC::RoundMode::CAST_NONE, SPLIT_RMSNRORM_SIZE_TWO); - Cast(calTensor[SPLIT_RMSNRORM_SIZE_TWO * 2], sinTensor, AscendC::RoundMode::CAST_NONE, - SPLIT_RMSNRORM_SIZE_TWO); + Cast(calTensor[splitRmsNormSizeTwo_], cosTensor, AscendC::RoundMode::CAST_NONE, splitRmsNormSizeTwo_); + Cast(calTensor[splitRmsNormSizeTwo_ * 2], sinTensor, AscendC::RoundMode::CAST_NONE, + splitRmsNormSizeTwo_); AscendC::PipeBarrier(); - Mul(ropeKTensor, calTensor[SPLIT_RMSNRORM_SIZE_TWO], ropeKTensor, SPLIT_RMSNRORM_SIZE_TWO); - Mul(ropeKRevertTensor, calTensor[SPLIT_RMSNRORM_SIZE_TWO * 2], ropeKRevertTensor, SPLIT_RMSNRORM_SIZE_TWO); + Mul(ropeKTensor, calTensor[splitRmsNormSizeTwo_], ropeKTensor, splitRmsNormSizeTwo_); + Mul(ropeKRevertTensor, calTensor[splitRmsNormSizeTwo_ * 2], ropeKRevertTensor, splitRmsNormSizeTwo_); AscendC::PipeBarrier(); - Mul(ropeKRevertTensor, calTensor, ropeKRevertTensor, SPLIT_RMSNRORM_SIZE_TWO); + Mul(ropeKRevertTensor, calTensor, ropeKRevertTensor, splitRmsNormSizeTwo_); AscendC::PipeBarrier(); - Add(ropeKRevertTensor, ropeKTensor, ropeKRevertTensor, SPLIT_RMSNRORM_SIZE_TWO); + Add(ropeKRevertTensor, ropeKTensor, ropeKRevertTensor, splitRmsNormSizeTwo_); AscendC::PipeBarrier(); if (std::is_same::value) { - Cast(outTmpTensor[SPLIT_RMSNRORM_SIZE_ONE], ropeKRevertTensor, AscendC::RoundMode::CAST_RINT, - SPLIT_RMSNRORM_SIZE_TWO); + Cast(outTmpTensor[splitRmsNormSizeOne_], ropeKRevertTensor, AscendC::RoundMode::CAST_RINT, + splitRmsNormSizeTwo_); } else { - Cast(outTmpTensor[SPLIT_RMSNRORM_SIZE_ONE], ropeKRevertTensor, AscendC::RoundMode::CAST_NONE, - SPLIT_RMSNRORM_SIZE_TWO); + Cast(outTmpTensor[splitRmsNormSizeOne_], ropeKRevertTensor, AscendC::RoundMode::CAST_NONE, + splitRmsNormSizeTwo_); } AscendC::PipeBarrier(); /* Rope K end */ @@ -2659,45 +2687,45 @@ private: WAIT_FLAG(V, MTE3, EVENT_ID0); WAIT_FLAG(S, MTE3, EVENT_ID0); if constexpr (CACHE_MODE == CACHE_MODE_KVCACHE) { - DataCopy(keycacheGmTensor1[cacheStart], outTmpTensor, SPLIT_SIZE_ONE); + DataCopy(keycacheGmTensor1[cacheStart], outTmpTensor, splitSizeOne_); } else if constexpr (CACHE_MODE == CACHE_MODE_INT8_NZCACHE) { uint64_t cacheSatartI8Nz1 = outer_idx * 128 * 512 + inner_idx * I8_C0_SIZE; uint64_t cacheSatartNz2 = outer_idx * 128 * 64 + inner_idx * C0_SIZE; // nope:int8 nz AscendC::DataCopyExtParams outExt; - outExt.blockCount = SPLIT_RMSNRORM_SIZE_ONE / I8_C0_SIZE; + outExt.blockCount = splitRmsNormSizeOne_ / I8_C0_SIZE; outExt.blockLen = I8_C0_SIZE * sizeof(int8_t); outExt.srcStride = 0; outExt.dstStride = (128 * I8_C0_SIZE - I8_C0_SIZE) * sizeof(int8_t); DataCopyPad(keycacheGmTensor1[cacheSatartI8Nz1], int8OutTensor, outExt); // rope:T1 nz - outExt.blockCount = SPLIT_RMSNRORM_SIZE_TWO / C0_SIZE; + outExt.blockCount = splitRmsNormSizeTwo_ / C0_SIZE; outExt.blockLen = C0_SIZE * sizeof(T1); outExt.srcStride = 0; outExt.dstStride = (128 * C0_SIZE - C0_SIZE) * sizeof(T1); - DataCopyPad(keycacheGmTensor2[cacheSatartNz2], outTmpTensor[SPLIT_RMSNRORM_SIZE_ONE], outExt); + DataCopyPad(keycacheGmTensor2[cacheSatartNz2], outTmpTensor[splitRmsNormSizeOne_], outExt); } else if constexpr (CACHE_MODE == CACHE_MODE_NZCACHE) { uint64_t 
cacheSatartNz1 = outer_idx * 128 * 512 + inner_idx * C0_SIZE; uint64_t cacheSatartNz2 = outer_idx * 128 * 64 + inner_idx * C0_SIZE; // nope:T1 nz AscendC::DataCopyExtParams outExt; - outExt.blockCount = SPLIT_RMSNRORM_SIZE_ONE / C0_SIZE; + outExt.blockCount = splitRmsNormSizeOne_ / C0_SIZE; outExt.blockLen = C0_SIZE * sizeof(T1); outExt.srcStride = 0; outExt.dstStride = (128 * C0_SIZE - C0_SIZE) * sizeof(T1); DataCopyPad(keycacheGmTensor1[cacheSatartNz1], outTmpTensor, outExt); // rope:T1 nz - outExt.blockCount = SPLIT_RMSNRORM_SIZE_TWO / C0_SIZE; + outExt.blockCount = splitRmsNormSizeTwo_ / C0_SIZE; outExt.blockLen = C0_SIZE * sizeof(T1); outExt.srcStride = 0; outExt.dstStride = (128 * C0_SIZE - C0_SIZE) * sizeof(T1); - DataCopyPad(keycacheGmTensor2[cacheSatartNz2], outTmpTensor[SPLIT_RMSNRORM_SIZE_ONE], outExt); + DataCopyPad(keycacheGmTensor2[cacheSatartNz2], outTmpTensor[splitRmsNormSizeOne_], outExt); } else { // keycache1 - DataCopy(keycacheGmTensor1[cacheStart1], outTmpTensor, SPLIT_RMSNRORM_SIZE_ONE); + DataCopy(keycacheGmTensor1[cacheStart1], outTmpTensor, splitRmsNormSizeOne_); // keycache2 - DataCopy(keycacheGmTensor2[cacheStart2], outTmpTensor[SPLIT_RMSNRORM_SIZE_ONE], - SPLIT_RMSNRORM_SIZE_TWO); + DataCopy(keycacheGmTensor2[cacheStart2], outTmpTensor[splitRmsNormSizeOne_], + splitRmsNormSizeTwo_); } SET_FLAG(MTE3, MTE2, EVENT_ID1); WAIT_FLAG(MTE3, MTE2, EVENT_ID1); @@ -2845,20 +2873,20 @@ MLAOperation input_tensor = buf.GetBuffer(0); - AscendC::LocalTensor gamma_tensor = buf.GetBuffer(MM1_OUT_SIZE * 2); + AscendC::LocalTensor gamma_tensor = buf.GetBuffer(mm1OutSize_ * 2); AscendC::LocalTensor beta_tensor = - buf.GetBuffer(MM1_OUT_SIZE * 2 + SPLIT_SIZE_TWO * 2); + buf.GetBuffer(mm1OutSize_ * 2 + splitSizeTwo_ * 2); AscendC::LocalTensor scale_tensor = - buf.GetBuffer(MM1_OUT_SIZE * 2 + SPLIT_SIZE_TWO * 2 + SPLIT_SIZE_TWO * 2); + buf.GetBuffer(mm1OutSize_ * 2 + splitSizeTwo_ * 2 + splitSizeTwo_ * 2); AscendC::LocalTensor offset_tensor = buf.GetBuffer( - MM1_OUT_SIZE * 2 + SPLIT_SIZE_TWO * 2 + SPLIT_SIZE_TWO * 2 + 32); + mm1OutSize_ * 2 + splitSizeTwo_ * 2 + splitSizeTwo_ * 2 + 32); AscendC::LocalTensor res1_tensor = buf.GetBuffer( - MM1_OUT_SIZE * 2 + SPLIT_SIZE_TWO * 2 + SPLIT_SIZE_TWO * 2 + 64); + mm1OutSize_ * 2 + splitSizeTwo_ * 2 + splitSizeTwo_ * 2 + 64); AscendC::LocalTensor res3_tensor = buf.GetBuffer( - MM1_OUT_SIZE * 2 + SPLIT_SIZE_TWO * 2 + SPLIT_SIZE_TWO * 2 + 64 + num_col_align_f32 * 4); + mm1OutSize_ * 2 + splitSizeTwo_ * 2 + splitSizeTwo_ * 2 + 64 + num_col_align_f32 * 4); AscendC::LocalTensor output_tensor = buf.GetBuffer( - MM1_OUT_SIZE * 2 + SPLIT_SIZE_TWO * 2 + SPLIT_SIZE_TWO * 2 + 64 + num_col_align_f32 * 4 + - BUF_FACTOR * num_col_align_f32 * 4 + 64 + MM1_OUT_SIZE * 4 * 2 + 32); + mm1OutSize_ * 2 + splitSizeTwo_ * 2 + splitSizeTwo_ * 2 + 64 + num_col_align_f32 * 4 + + BUF_FACTOR * num_col_align_f32 * 4 + 64 + mm1OutSize_ * 4 * 2 + 32); rmsNormQuant2.Launch(output_tensor, input_tensor, gamma_tensor, beta_tensor, scale_tensor, offset_tensor, res1_tensor, res3_tensor); } @@ -2868,20 +2896,20 @@ MLAOperation input_tensor = buf.GetBuffer(0); - AscendC::LocalTensor gamma_tensor = buf.GetBuffer(MM1_OUT_SIZE * 2); + AscendC::LocalTensor gamma_tensor = buf.GetBuffer(mm1OutSize_ * 2); AscendC::LocalTensor sin_tensor = - buf.GetBuffer(MM1_OUT_SIZE * 2 + SPLIT_RMSNRORM_SIZE_ONE * 2); + buf.GetBuffer(mm1OutSize_ * 2 + splitRmsNormSizeOne_ * 2); AscendC::LocalTensor cos_tensor = buf.GetBuffer( - MM1_OUT_SIZE * 2 + SPLIT_RMSNRORM_SIZE_ONE * 2 + SPLIT_RMSNRORM_SIZE_TWO * 2); 
+ mm1OutSize_ * 2 + splitRmsNormSizeOne_ * 2 + splitRmsNormSizeTwo_ * 2); AscendC::LocalTensor slotMapping_tensor = buf.GetBuffer( - MM1_OUT_SIZE * 2 + SPLIT_RMSNRORM_SIZE_ONE * 2 + SPLIT_RMSNRORM_SIZE_TWO * 4); + mm1OutSize_ * 2 + splitRmsNormSizeOne_ * 2 + splitRmsNormSizeTwo_ * 4); int32_t rms3_ub_offset = - MM1_OUT_SIZE * 2 + SPLIT_RMSNRORM_SIZE_ONE * 2 + SPLIT_RMSNRORM_SIZE_TWO * 4 + 4096 * 32; + mm1OutSize_ * 2 + splitRmsNormSizeOne_ * 2 + splitRmsNormSizeTwo_ * 4 + 4096 * 32; AscendC::LocalTensor tmp32_tensor = buf.GetBuffer(rms3_ub_offset); - int32_t out_ub_offset = MM1_OUT_SIZE * 2 + SPLIT_RMSNRORM_SIZE_ONE * 2 + SPLIT_RMSNRORM_SIZE_TWO * 4 + - 4096 * 32 + SPLIT_RMSNRORM_SIZE_ONE * 3 * 4 + SPLIT_RMSNRORM_SIZE_TWO * 2 * 4 + - MM1_OUT_SIZE * 4 * 2 + 32; + int32_t out_ub_offset = mm1OutSize_ * 2 + splitRmsNormSizeOne_ * 2 + splitRmsNormSizeTwo_ * 4 + + 4096 * 32 + splitRmsNormSizeOne_ * 3 * 4 + splitRmsNormSizeTwo_ * 2 * 4 + + mm1OutSize_ * 4 * 2 + 32; AscendC::LocalTensor temp_tensor = buf.GetBuffer(out_ub_offset); AscendC::LocalTensor tmpfp16; @@ -2895,7 +2923,7 @@ MLAOperation(rms3_ub_offset + 32); // int8out tmpfp16 = buf.GetBuffer(rms3_ub_offset + - SPLIT_RMSNRORM_SIZE_ONE * sizeof(float) * 2); + splitRmsNormSizeOne_ * sizeof(float) * 2); int8OutTensor = buf.GetBuffer(out_ub_offset); AscendC::DataCopy(quantScaleTensor, quantScale3GmTensor, AscendC::DataCopyParams(1, 1, 0, 0)); SET_FLAG(MTE2, V, EVENT_ID1); @@ -2912,11 +2940,11 @@ MLAOperation(GetSubBlockidx()); loopTime = (blockIdx_ == realCore - 1) ? lastCoreLoopTime : preCoreLoopTime; lastLoopN = (blockIdx_ == realCore - 1) ? lastCoreLoopNLast : preCoreLoopNLast; @@ -92,7 +94,7 @@ public: AscendC::LocalTensor inputQCastFP32 = buf.GetBuffer(dataSizeFp16); AscendC::LocalTensor reverseQ = buf.GetBuffer(dataSizeFp32 + dataSizeFp16); - uint64_t qOffset = startHead * 192 + 128; + uint64_t qOffset = startHead * hiddenStrideRope_ + qkNopeHeadDim_; CopyQGenReverseQ(inputQ, inputQCastFP32, reverseQ, qOffset, loopN); // move in cos/sin @@ -184,7 +186,7 @@ public: WAIT_FLAG(S, MTE2, EVENT_ID1); WAIT_FLAG(MTE3, MTE2, EVENT_ID1); // move in Q - AscendC::DataCopy(tempBufQ, this->qGm_[qOffset], {loopN, headBlockLen, 128 / 16, 0}); + AscendC::DataCopy(tempBufQ, this->qGm_[qOffset], {loopN, headBlockLen, static_cast(qkNopeHeadDim_ / 16), 0}); SET_FLAG(MTE2, V, EVENT_ID1); WAIT_FLAG(MTE2, V, EVENT_ID1); // cast fp32 @@ -238,6 +240,8 @@ private: uint32_t lastCoreLoopTime; uint32_t lastCoreLoopNLast; uint32_t concatSize; + uint32_t hiddenStrideRope_; + uint32_t qkNopeHeadDim_; uint32_t blockIdx_; uint32_t loopTime{0}; // The number of current data rounds uint32_t lastLoopN{0}; // The number of lines currently processed by tails kernel @@ -2035,6 +2039,15 @@ public: this->epsilon_ = 1e-6; this->mlaParams = mlaParams_; this->hiddenStateDim = mlaParams_.hiddenStateDim; + this->mm1OutSize_ = mlaParams_.mm1OutSize; + this->splitSizeOne_ = mlaParams_.splitSizeOne; + this->splitSizeTwo_ = mlaParams_.splitSizeTwo; + this->splitRmsNormSizeOne_ = mlaParams_.splitRmsNormSizeOne; + this->splitRmsNormSizeTwo_ = mlaParams_.splitRmsNormSizeTwo; + this->ropeSplitSizeOne_ = mlaParams_.ropeSplitSizeOne; + this->ropeSplitSizeTwo_ = mlaParams_.ropeSplitSizeTwo; + this->hiddenStrideRope_ = mlaParams_.hiddenStrideRope; + this->qkNopeHeadDim_ = mlaParams_.qkNopeHeadDim; } __aicore__ inline void Init(GM_ADDR hiddenStateGm, GM_ADDR quantScale1Gm, @@ -2109,9 +2122,9 @@ public: vectorBlockIdx * static_cast(row_work) * num_col_1, row_work_, mlaParams); 
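// The 0.000651041666 literal replaced in the next hunk is 1/1536, the
// reciprocal of DeepSeek V3's q_lora_rank, i.e. the width the q-side RmsNorm
// averages over. mlaParams.avgFactor carries the same factor, presumably
// computed host-side as 1.0f / qLoraRank, so a model with a different rank
// normalizes correctly without kernel changes.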
rmsNormQuant2.Init(gamma2GmTensor, beta2GmTensor, quantScale2GmTensor, quantOffset2GmTensor, s3GmTensor, - s1GmTensor, SPLIT_SIZE_ONE, num_col_2, 0.000651041666, + s1GmTensor, splitSizeOne_, num_col_2, mlaParams.avgFactor, vectorBlockIdx * static_cast(row_work) * num_col_2, - vectorBlockIdx * static_cast(row_work) * SPLIT_SIZE_TWO, row_work_, mlaParams); + vectorBlockIdx * static_cast(row_work) * splitSizeTwo_, row_work_, mlaParams); ropeFp16.RopeInit(s2GmTensor, cos2GmTensor, sin2GmTensor, qGmTensor, qGmTensor2, mlaParams); einSumQuant.Init(s1Gm, gmQnopeScale, qGm, mlaParams); ubTensor = buf.GetBuffer(0); @@ -2125,6 +2138,17 @@ public: __aicore__ inline void ProcessVector(); private: + // Model-specific MLA dimensions from tiling data + uint32_t mm1OutSize_; + uint32_t splitSizeOne_; + uint32_t splitSizeTwo_; + uint32_t splitRmsNormSizeOne_; + uint32_t splitRmsNormSizeTwo_; + uint32_t ropeSplitSizeOne_; + uint32_t ropeSplitSizeTwo_; + uint32_t hiddenStrideRope_; + uint32_t qkNopeHeadDim_; + constexpr static uint32_t C0_SIZE = 16; constexpr static uint32_t I8_C0_SIZE = 32; @@ -2139,10 +2163,10 @@ private: AscendC::LocalTensor &tmpfp16, AscendC::LocalTensor &int8OutTensor, float quantScale3) { int64_t slotMapGmOffset = vectorBlockIdx * row_work; - AscendC::DataCopy(gammaTensor, gamma3GmTensor, SPLIT_RMSNRORM_SIZE_ONE); + AscendC::DataCopy(gammaTensor, gamma3GmTensor, splitRmsNormSizeOne_); SET_FLAG(MTE2, V, EVENT_ID1); WAIT_FLAG(MTE2, V, EVENT_ID1); - Cast(gammaFp32, gammaTensor, AscendC::RoundMode::CAST_NONE, SPLIT_RMSNRORM_SIZE_ONE); + Cast(gammaFp32, gammaTensor, AscendC::RoundMode::CAST_NONE, splitRmsNormSizeOne_); AscendC::DataCopyPad(slotMappingTensor, slotMappingGmTensor[slotMapGmOffset], AscendC::DataCopyExtParams(1, sN * sizeof(int32_t), 0, 0, 0), AscendC::DataCopyPadExtParams(false, 0, 8 - sN % 8, 0)); @@ -2151,134 +2175,134 @@ private: SET_FLAG(MTE2, S, EVENT_ID2); WAIT_FLAG(MTE2, S, EVENT_ID2); for (uint64_t loop = 0; loop < sN; ++loop) { - uint64_t offset = vectorBlockIdx * static_cast(row_work) * num_col_2 + loop * MM1_OUT_SIZE; + uint64_t offset = vectorBlockIdx * static_cast(row_work) * num_col_2 + loop * mm1OutSize_; int64_t slotValue = static_cast(slotMappingTensor.GetValue(loop)); if (slotValue == -1) { continue; } - AscendC::DataCopy(srcTensor, s3GmTensor[offset], SPLIT_SIZE_ONE); - AscendC::DataCopy(sinTensor, sin1GmTensor[(row_work * vectorBlockIdx + loop) * SPLIT_RMSNRORM_SIZE_TWO], - SPLIT_RMSNRORM_SIZE_TWO); - AscendC::DataCopy(cosTensor, cos1GmTensor[(row_work * vectorBlockIdx + loop) * SPLIT_RMSNRORM_SIZE_TWO], - SPLIT_RMSNRORM_SIZE_TWO); + AscendC::DataCopy(srcTensor, s3GmTensor[offset], splitSizeOne_); + AscendC::DataCopy(sinTensor, sin1GmTensor[(row_work * vectorBlockIdx + loop) * splitRmsNormSizeTwo_], + splitRmsNormSizeTwo_); + AscendC::DataCopy(cosTensor, cos1GmTensor[(row_work * vectorBlockIdx + loop) * splitRmsNormSizeTwo_], + splitRmsNormSizeTwo_); SET_FLAG(MTE2, V, EVENT_ID0); // ND - uint64_t cacheStart = static_cast(slotValue) * static_cast(SPLIT_SIZE_ONE); - uint64_t cacheStart1 = static_cast(slotValue) * static_cast(SPLIT_RMSNRORM_SIZE_ONE); - uint64_t cacheStart2 = static_cast(slotValue) * static_cast(SPLIT_RMSNRORM_SIZE_TWO); + uint64_t cacheStart = static_cast(slotValue) * static_cast(splitSizeOne_); + uint64_t cacheStart1 = static_cast(slotValue) * static_cast(splitRmsNormSizeOne_); + uint64_t cacheStart2 = static_cast(slotValue) * static_cast(splitRmsNormSizeTwo_); // NZ uint32_t outer_idx = slotValue / 128; uint32_t inner_idx = slotValue % 128; 
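// NZ cache addressing: slots are tiled into fractal blocks of 128 rows, so
// outer_idx selects the 128-row block and inner_idx the row within it. The ND
// offsets above now scale by splitSizeOne_ (full nope+rope row),
// splitRmsNormSizeOne_ (nope part) and splitRmsNormSizeTwo_ (rope part) in
// place of the former 576/512/64 DeepSeek V3 literals.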
SET_FLAG(S, MTE3, EVENT_ID0); /* RmsNorm start */ WAIT_FLAG(MTE2, V, EVENT_ID0); - Cast(rmsNormTensor, srcTensor, AscendC::RoundMode::CAST_NONE, SPLIT_RMSNRORM_SIZE_ONE); + Cast(rmsNormTensor, srcTensor, AscendC::RoundMode::CAST_NONE, splitRmsNormSizeOne_); AscendC::PipeBarrier(); - Mul(calTensor, rmsNormTensor, rmsNormTensor, SPLIT_RMSNRORM_SIZE_ONE); + Mul(calTensor, rmsNormTensor, rmsNormTensor, splitRmsNormSizeOne_); AscendC::PipeBarrier(); - ReduceSumCustom(calTensor[SPLIT_RMSNRORM_SIZE_ONE], calTensor, calTensor[SPLIT_RMSNRORM_SIZE_ONE * 2], - SPLIT_RMSNRORM_SIZE_ONE); + ReduceSumCustom(calTensor[splitRmsNormSizeOne_], calTensor, calTensor[splitRmsNormSizeOne_ * 2], + splitRmsNormSizeOne_); SET_FLAG(V, S, EVENT_ID1); WAIT_FLAG(V, S, EVENT_ID1); - float rms = sqrt(calTensor.GetValue(SPLIT_RMSNRORM_SIZE_ONE) / SPLIT_RMSNRORM_SIZE_ONE + epsilon_); + float rms = sqrt(calTensor.GetValue(splitRmsNormSizeOne_) / splitRmsNormSizeOne_ + epsilon_); SET_FLAG(S, V, EVENT_ID1); WAIT_FLAG(S, V, EVENT_ID1); AscendC::PipeBarrier(); - Duplicate(calTensor, rms, SPLIT_RMSNRORM_SIZE_ONE); + Duplicate(calTensor, rms, splitRmsNormSizeOne_); AscendC::PipeBarrier(); - Div(calTensor, rmsNormTensor, calTensor, SPLIT_RMSNRORM_SIZE_ONE); + Div(calTensor, rmsNormTensor, calTensor, splitRmsNormSizeOne_); AscendC::PipeBarrier(); - Mul(rmsNormTensor, gammaFp32, calTensor, SPLIT_RMSNRORM_SIZE_ONE); + Mul(rmsNormTensor, gammaFp32, calTensor, splitRmsNormSizeOne_); AscendC::PipeBarrier(); - Cast(outTmpTensor, rmsNormTensor, AscendC::RoundMode::CAST_NONE, SPLIT_RMSNRORM_SIZE_ONE); + Cast(outTmpTensor, rmsNormTensor, AscendC::RoundMode::CAST_NONE, splitRmsNormSizeOne_); AscendC::PipeBarrier(); if constexpr (cacheMode == CACHE_MODE_INT8_NZCACHE) { // quant - Muls(rmsNormTensor, rmsNormTensor, quantScale3, SPLIT_RMSNRORM_SIZE_ONE); + Muls(rmsNormTensor, rmsNormTensor, quantScale3, splitRmsNormSizeOne_); AscendC::PipeBarrier(); - CastFrom32To16(tmpfp16, rmsNormTensor, SPLIT_RMSNRORM_SIZE_ONE); + CastFrom32To16(tmpfp16, rmsNormTensor, splitRmsNormSizeOne_); AscendC::PipeBarrier(); - CastFromF16ToI8(int8OutTensor, tmpfp16, -128, SPLIT_RMSNRORM_SIZE_ONE); + CastFromF16ToI8(int8OutTensor, tmpfp16, -128, splitRmsNormSizeOne_); AscendC::PipeBarrier(); } else { AscendC::PipeBarrier(); if (std::is_same::value) { - Cast(outTmpTensor, rmsNormTensor, AscendC::RoundMode::CAST_RINT, SPLIT_RMSNRORM_SIZE_ONE); + Cast(outTmpTensor, rmsNormTensor, AscendC::RoundMode::CAST_RINT, splitRmsNormSizeOne_); } else { - Cast(outTmpTensor, rmsNormTensor, AscendC::RoundMode::CAST_NONE, SPLIT_RMSNRORM_SIZE_ONE); + Cast(outTmpTensor, rmsNormTensor, AscendC::RoundMode::CAST_NONE, splitRmsNormSizeOne_); } } /* RmsNorm end */ // /* Rope K start */ - uint64_t revertOffset = SPLIT_RMSNRORM_SIZE_TWO / 2; - Cast(ropeKTensor, srcTensor[SPLIT_RMSNRORM_SIZE_ONE], AscendC::RoundMode::CAST_NONE, - SPLIT_RMSNRORM_SIZE_TWO); - Cast(ropeKRevertTensor[revertOffset], srcTensor[SPLIT_RMSNRORM_SIZE_ONE], AscendC::RoundMode::CAST_NONE, + uint64_t revertOffset = splitRmsNormSizeTwo_ / 2; + Cast(ropeKTensor, srcTensor[splitRmsNormSizeOne_], AscendC::RoundMode::CAST_NONE, + splitRmsNormSizeTwo_); + Cast(ropeKRevertTensor[revertOffset], srcTensor[splitRmsNormSizeOne_], AscendC::RoundMode::CAST_NONE, revertOffset); - Cast(ropeKRevertTensor, srcTensor[SPLIT_RMSNRORM_SIZE_ONE + revertOffset], AscendC::RoundMode::CAST_NONE, + Cast(ropeKRevertTensor, srcTensor[splitRmsNormSizeOne_ + revertOffset], AscendC::RoundMode::CAST_NONE, revertOffset); Duplicate(calTensor, static_cast(-1), 
revertOffset); Duplicate(calTensor[revertOffset], static_cast(1), revertOffset); AscendC::PipeBarrier(); - Cast(calTensor[SPLIT_RMSNRORM_SIZE_TWO], cosTensor, AscendC::RoundMode::CAST_NONE, SPLIT_RMSNRORM_SIZE_TWO); - Cast(calTensor[SPLIT_RMSNRORM_SIZE_TWO * 2], sinTensor, AscendC::RoundMode::CAST_NONE, - SPLIT_RMSNRORM_SIZE_TWO); + Cast(calTensor[splitRmsNormSizeTwo_], cosTensor, AscendC::RoundMode::CAST_NONE, splitRmsNormSizeTwo_); + Cast(calTensor[splitRmsNormSizeTwo_ * 2], sinTensor, AscendC::RoundMode::CAST_NONE, + splitRmsNormSizeTwo_); AscendC::PipeBarrier(); - Mul(ropeKTensor, calTensor[SPLIT_RMSNRORM_SIZE_TWO], ropeKTensor, SPLIT_RMSNRORM_SIZE_TWO); - Mul(ropeKRevertTensor, calTensor[SPLIT_RMSNRORM_SIZE_TWO * 2], ropeKRevertTensor, SPLIT_RMSNRORM_SIZE_TWO); + Mul(ropeKTensor, calTensor[splitRmsNormSizeTwo_], ropeKTensor, splitRmsNormSizeTwo_); + Mul(ropeKRevertTensor, calTensor[splitRmsNormSizeTwo_ * 2], ropeKRevertTensor, splitRmsNormSizeTwo_); AscendC::PipeBarrier(); - Mul(ropeKRevertTensor, calTensor, ropeKRevertTensor, SPLIT_RMSNRORM_SIZE_TWO); + Mul(ropeKRevertTensor, calTensor, ropeKRevertTensor, splitRmsNormSizeTwo_); AscendC::PipeBarrier(); - Add(ropeKRevertTensor, ropeKTensor, ropeKRevertTensor, SPLIT_RMSNRORM_SIZE_TWO); + Add(ropeKRevertTensor, ropeKTensor, ropeKRevertTensor, splitRmsNormSizeTwo_); AscendC::PipeBarrier(); - Cast(outTmpTensor[SPLIT_RMSNRORM_SIZE_ONE], ropeKRevertTensor, AscendC::RoundMode::CAST_NONE, - SPLIT_RMSNRORM_SIZE_TWO); + Cast(outTmpTensor[splitRmsNormSizeOne_], ropeKRevertTensor, AscendC::RoundMode::CAST_NONE, + splitRmsNormSizeTwo_); /* Rope K end */ // reshapeAndcache SET_FLAG(V, MTE3, EVENT_ID0); WAIT_FLAG(V, MTE3, EVENT_ID0); WAIT_FLAG(S, MTE3, EVENT_ID0); if constexpr (cacheMode == CACHE_MODE_KVCACHE) { - DataCopy(keycacheGmTensor1[cacheStart], outTmpTensor, SPLIT_SIZE_ONE); + DataCopy(keycacheGmTensor1[cacheStart], outTmpTensor, splitSizeOne_); } else if constexpr (cacheMode == CACHE_MODE_INT8_NZCACHE) { // NZ int64_t cacheSatartI8Nz1 = outer_idx * 128 * 512 + inner_idx * I8_C0_SIZE; uint64_t cacheSatartNz2 = outer_idx * 128 * 64 + inner_idx * C0_SIZE; AscendC::DataCopyExtParams outExt; // nope:int8 nz - outExt.blockCount = SPLIT_RMSNRORM_SIZE_ONE / I8_C0_SIZE; + outExt.blockCount = splitRmsNormSizeOne_ / I8_C0_SIZE; outExt.blockLen = I8_C0_SIZE * sizeof(int8_t); outExt.srcStride = 0; outExt.dstStride = (128 * I8_C0_SIZE - I8_C0_SIZE) * sizeof(int8_t); DataCopyPad(keycacheGmTensor1[cacheSatartI8Nz1], int8OutTensor, outExt); // rope:T1 nz - outExt.blockCount = SPLIT_RMSNRORM_SIZE_TWO / C0_SIZE; + outExt.blockCount = splitRmsNormSizeTwo_ / C0_SIZE; outExt.blockLen = C0_SIZE * sizeof(T1); outExt.srcStride = 0; outExt.dstStride = (128 * C0_SIZE - C0_SIZE) * sizeof(T1); - DataCopyPad(keycacheGmTensor2[cacheSatartNz2], outTmpTensor[SPLIT_RMSNRORM_SIZE_ONE], outExt); + DataCopyPad(keycacheGmTensor2[cacheSatartNz2], outTmpTensor[splitRmsNormSizeOne_], outExt); } else if constexpr (cacheMode == CACHE_MODE_NZCACHE) { uint64_t cacheSatartNz1 = outer_idx * 128 * 512 + inner_idx * C0_SIZE; uint64_t cacheSatartNz2 = outer_idx * 128 * 64 + inner_idx * C0_SIZE; // nope:T1 nz AscendC::DataCopyExtParams outExt; - outExt.blockCount = SPLIT_RMSNRORM_SIZE_ONE / C0_SIZE; + outExt.blockCount = splitRmsNormSizeOne_ / C0_SIZE; outExt.blockLen = C0_SIZE * sizeof(T1); outExt.srcStride = 0; outExt.dstStride = (128 * C0_SIZE - C0_SIZE) * sizeof(T1); DataCopyPad(keycacheGmTensor1[cacheSatartNz1], outTmpTensor, outExt); // rope:T1 nz - outExt.blockCount = 
SPLIT_RMSNRORM_SIZE_TWO / C0_SIZE; + outExt.blockCount = splitRmsNormSizeTwo_ / C0_SIZE; outExt.blockLen = C0_SIZE * sizeof(T1); outExt.srcStride = 0; outExt.dstStride = (128 * C0_SIZE - C0_SIZE) * sizeof(T1); - DataCopyPad(keycacheGmTensor2[cacheSatartNz2], outTmpTensor[SPLIT_RMSNRORM_SIZE_ONE], outExt); + DataCopyPad(keycacheGmTensor2[cacheSatartNz2], outTmpTensor[splitRmsNormSizeOne_], outExt); } else { // keycache1 - DataCopy(keycacheGmTensor1[cacheStart1], outTmpTensor, SPLIT_RMSNRORM_SIZE_ONE); + DataCopy(keycacheGmTensor1[cacheStart1], outTmpTensor, splitRmsNormSizeOne_); // keycache2 - DataCopy(keycacheGmTensor2[cacheStart2], outTmpTensor[SPLIT_RMSNRORM_SIZE_ONE], - SPLIT_RMSNRORM_SIZE_TWO); + DataCopy(keycacheGmTensor2[cacheStart2], outTmpTensor[splitRmsNormSizeOne_], + splitRmsNormSizeTwo_); } SET_FLAG(MTE3, MTE2, EVENT_ID1); WAIT_FLAG(MTE3, MTE2, EVENT_ID1); @@ -2417,19 +2441,19 @@ __aicore__ inline void MLAOperation input_tensor = buf.GetBuffer(0); - AscendC::LocalTensor gamma_tensor = buf.GetBuffer(MM1_OUT_SIZE * 2); + AscendC::LocalTensor gamma_tensor = buf.GetBuffer(mm1OutSize_ * 2); AscendC::LocalTensor beta_tensor = - buf.GetBuffer(MM1_OUT_SIZE * 2 + SPLIT_SIZE_TWO * 2); + buf.GetBuffer(mm1OutSize_ * 2 + splitSizeTwo_ * 2); AscendC::LocalTensor scale_tensor = - buf.GetBuffer(MM1_OUT_SIZE * 2 + SPLIT_SIZE_TWO * 2 + SPLIT_SIZE_TWO * 2); + buf.GetBuffer(mm1OutSize_ * 2 + splitSizeTwo_ * 2 + splitSizeTwo_ * 2); AscendC::LocalTensor offset_tensor = buf.GetBuffer( - MM1_OUT_SIZE * 2 + SPLIT_SIZE_TWO * 2 + SPLIT_SIZE_TWO * 2 + 32); + mm1OutSize_ * 2 + splitSizeTwo_ * 2 + splitSizeTwo_ * 2 + 32); AscendC::LocalTensor res1_tensor = buf.GetBuffer( - MM1_OUT_SIZE * 2 + SPLIT_SIZE_TWO * 2 + SPLIT_SIZE_TWO * 2 + 64); + mm1OutSize_ * 2 + splitSizeTwo_ * 2 + splitSizeTwo_ * 2 + 64); AscendC::LocalTensor res3_tensor = buf.GetBuffer( - MM1_OUT_SIZE * 2 + SPLIT_SIZE_TWO * 2 + SPLIT_SIZE_TWO * 2 + 64 + num_col_align_f32 * 4); + mm1OutSize_ * 2 + splitSizeTwo_ * 2 + splitSizeTwo_ * 2 + 64 + num_col_align_f32 * 4); AscendC::LocalTensor output_tensor = buf.GetBuffer( - MM1_OUT_SIZE * 2 + SPLIT_SIZE_TWO * 2 + SPLIT_SIZE_TWO * 2 + 64 + num_col_align_f32 * 4 + + mm1OutSize_ * 2 + splitSizeTwo_ * 2 + splitSizeTwo_ * 2 + 64 + num_col_align_f32 * 4 + BUF_FACTOR * num_col_align_f32 * 4 + 32); rmsNormQuant2.Launch(output_tensor, input_tensor, gamma_tensor, beta_tensor, scale_tensor, offset_tensor, res1_tensor, res3_tensor); @@ -2440,19 +2464,19 @@ __aicore__ inline void MLAOperation input_tensor = buf.GetBuffer(0); - AscendC::LocalTensor gamma_tensor = buf.GetBuffer(MM1_OUT_SIZE * 2); + AscendC::LocalTensor gamma_tensor = buf.GetBuffer(mm1OutSize_ * 2); AscendC::LocalTensor sin_tensor = - buf.GetBuffer(MM1_OUT_SIZE * 2 + SPLIT_RMSNRORM_SIZE_ONE * 2); + buf.GetBuffer(mm1OutSize_ * 2 + splitRmsNormSizeOne_ * 2); AscendC::LocalTensor cos_tensor = buf.GetBuffer( - MM1_OUT_SIZE * 2 + SPLIT_RMSNRORM_SIZE_ONE * 2 + SPLIT_RMSNRORM_SIZE_TWO * 2); + mm1OutSize_ * 2 + splitRmsNormSizeOne_ * 2 + splitRmsNormSizeTwo_ * 2); AscendC::LocalTensor slotMapping_tensor = buf.GetBuffer( - MM1_OUT_SIZE * 2 + SPLIT_RMSNRORM_SIZE_ONE * 2 + SPLIT_RMSNRORM_SIZE_TWO * 4); + mm1OutSize_ * 2 + splitRmsNormSizeOne_ * 2 + splitRmsNormSizeTwo_ * 4); int32_t rms3_ub_offset = - MM1_OUT_SIZE * 2 + SPLIT_RMSNRORM_SIZE_ONE * 2 + SPLIT_RMSNRORM_SIZE_TWO * 4 + 4096 * 32; + mm1OutSize_ * 2 + splitRmsNormSizeOne_ * 2 + splitRmsNormSizeTwo_ * 4 + 4096 * 32; AscendC::LocalTensor tmp32_tensor = buf.GetBuffer(rms3_ub_offset); - int32_t 
out_ub_offset = MM1_OUT_SIZE * 2 + SPLIT_RMSNRORM_SIZE_ONE * 2 + SPLIT_RMSNRORM_SIZE_TWO * 4 + - 4096 * 32 + SPLIT_RMSNRORM_SIZE_ONE * 3 * 4 + SPLIT_RMSNRORM_SIZE_TWO * 2 * 4; + int32_t out_ub_offset = mm1OutSize_ * 2 + splitRmsNormSizeOne_ * 2 + splitRmsNormSizeTwo_ * 4 + + 4096 * 32 + splitRmsNormSizeOne_ * 3 * 4 + splitRmsNormSizeTwo_ * 2 * 4; AscendC::LocalTensor temp_tensor = buf.GetBuffer(out_ub_offset); AscendC::LocalTensor tmpfp16; @@ -2465,7 +2489,7 @@ __aicore__ inline void MLAOperation(rms3_ub_offset + 32); // int8out tmpfp16 = buf.GetBuffer(rms3_ub_offset + - SPLIT_RMSNRORM_SIZE_ONE * sizeof(float) * 2); + splitRmsNormSizeOne_ * sizeof(float) * 2); int8OutTensor = buf.GetBuffer(out_ub_offset); AscendC::DataCopy(quantScaleTensor, quantScale3GmTensor, AscendC::DataCopyParams(1, 1, 0, 0)); SET_FLAG(MTE2, V, EVENT_ID1); @@ -2482,11 +2506,11 @@ __aicore__ inline void MLAOperation
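For readers mapping the new tiling fields back to the four base dimensions added to OpParam, the sketch below summarizes the relationships implied by the kernel-side uses in these hunks. It is illustrative only: MlaDerivedDims and DeriveDims are hypothetical names, the numeric comments are the DeepSeek V3 defaults the fields replace, and the two ropeSplit fields are omitted because their widths are not visible in this part of the patch.

#include <cstdint>

// Hypothetical, minimal stand-in for the MlaTilingData fields touched here;
// the real struct carries many more members.
struct MlaDerivedDims {
    uint32_t qkNopeHeadDim;
    uint32_t hiddenStrideRope;
    uint32_t splitRmsNormSizeOne;
    uint32_t splitRmsNormSizeTwo;
    uint32_t splitSizeOne;
    uint32_t splitSizeTwo;
    uint32_t mm1OutSize;
    float avgFactor;
};

// Mapping implied by the kernel-side uses above (illustrative, not the
// patch's actual host code).
inline MlaDerivedDims DeriveDims(uint32_t qLoraRank, uint32_t qkNopeHeadDim,
                                 uint32_t qkRopeHeadDim, uint32_t kvLoraRank)
{
    MlaDerivedDims d{};
    d.qkNopeHeadDim       = qkNopeHeadDim;                           // 128
    d.hiddenStrideRope    = qkNopeHeadDim + qkRopeHeadDim;           // 192
    d.splitRmsNormSizeOne = kvLoraRank;                              // 512
    d.splitRmsNormSizeTwo = qkRopeHeadDim;                           // 64
    d.splitSizeOne        = kvLoraRank + qkRopeHeadDim;              // 576
    d.splitSizeTwo        = qLoraRank;                               // 1536
    d.mm1OutSize          = qLoraRank + kvLoraRank + qkRopeHeadDim;  // 2112
    d.avgFactor           = 1.0f / static_cast<float>(qLoraRank);    // ~0.000651
    return d;
}

With DeepSeek V3 inputs (1536, 128, 64, 512) this reproduces every constant the hunks above parameterize; with GLM5 shapes the same code paths pick up the new values from the tiling data.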