[feat] parameterize hardcoded MLA dimensions to support GLM5-W8A8 (#6902)

Derive MLA dimension constants (q_lora_rank, qk_nope_head_dim, etc.)
from tensor shapes at runtime instead of hardcoding DeepSeek V3 values.
This enables the mla_preprocess fused op to work with both DeepSeek V3
and GLM5 models without Python API changes.

- Add 9 dimension fields to MlaTilingData with DeepSeek V3 defaults
- Add OpParam fields and dynamize all host-side tiling functions
- Derive dimensions from wuk, gamma1, kv_cache_rope tensor shapes
- Replace 310+ hardcoded constants across 4 kernel .hpp files
- Remove unused MMSIZE1/MMSIZE2 constants

### What this PR does / why we need it?

### Does this PR introduce _any_ user-facing change?

### How was this patch tested?

- vLLM version: v0.16.0
- vLLM main:
15d76f74e2

---------

Signed-off-by: liuchenbing <chenliumail@163.com>
Co-authored-by: liuchenbing <chenliumail@163.com>
This commit is contained in:
liuchen2026fly
2026-03-09 20:17:21 +08:00
committed by GitHub
parent 13adcbe44b
commit 542258ac9d
9 changed files with 508 additions and 342 deletions

View File

@@ -43,6 +43,8 @@ public:
headDim = ropeConcatParams.headDim;
headNumQ = ropeConcatParams.headNumQ;
this->hiddenStrideRope_ = ropeConcatParams.hiddenStrideRope;
this->qkNopeHeadDim_ = ropeConcatParams.qkNopeHeadDim;
rotaryCoeff = ropeConcatParams.rotaryCoeff;
ntokens = ropeConcatParams.ntokens;
realCore = ropeConcatParams.realCore;
@@ -92,7 +94,7 @@ public:
AscendC::LocalTensor<float> inputQCastFP32 = buf.GetBuffer<BufferType::ASCEND_UB, float>(dataSizeFp16);
AscendC::LocalTensor<float> reverseQ =
buf.GetBuffer<BufferType::ASCEND_UB, float>(dataSizeFp32 + dataSizeFp16);
uint64_t qOffset = startHead * 192 + 128;
uint64_t qOffset = startHead * hiddenStrideRope_ + qkNopeHeadDim_;
CopyQGenReverseQ(inputQ, inputQCastFP32, reverseQ, qOffset, loopN);
// move in cos/sin
@@ -178,7 +180,7 @@ public:
{
// move in Q
WAIT_FLAG(MTE3, MTE2, EVENT_ID1);
AscendC::DataCopy(tempBufQ, this->qGm_[qOffset], {loopN, headBlockLen, 128 / 16, 0});
AscendC::DataCopy(tempBufQ, this->qGm_[qOffset], {loopN, headBlockLen, static_cast<uint16_t>(qkNopeHeadDim_ / 16), 0});
SET_FLAG(MTE2, V, EVENT_ID1);
WAIT_FLAG(MTE2, V, EVENT_ID1);
// cast fp32
@@ -215,6 +217,8 @@ private:
AscendC::GlobalTensor<QOutDtype> outRopeConcatGm_;
AscendC::GlobalTensor<QkDtype> outRopeConcatGm2_;
uint32_t hiddenStrideRope_{0};
uint32_t qkNopeHeadDim_{0};
uint32_t repeatSize_{0};
uint32_t rotateStride_{0}; // this->headDim / rope conf
uint32_t headDim;
@@ -309,6 +313,7 @@ public:
this->num_row_ = mlaParams_.n;
this->row_work = row_work;
this->row_work_ = row_work_;
this->mm1OutSize_ = mlaParams_.mm1OutSize;
gm_offset_ = gm_offset;
gm_out_offset_ = gm_out_offset;
num_col_align_int8 = (num_col_ + REPEAT_TIME_256 - 1) / REPEAT_TIME_256 * REPEAT_TIME_256;
@@ -353,8 +358,8 @@ public:
if constexpr (NEED_DEQUANT) {
mmTensor = buf.ReinterpretCast<int32_t>()[OFFSET_WORKSPACE_BF16 * num_col_align_withStride_fp32 + 16];
deScaleTensor = buf[OFFSET_WORKSPACE_BF16 * num_col_align_withStride_fp32 + 16 + MM1_OUT_SIZE];
perTokenDescaleTensor = buf[OFFSET_WORKSPACE_BF16 * num_col_align_withStride_fp32 + 16 + MM1_OUT_SIZE * 2];
deScaleTensor = buf[OFFSET_WORKSPACE_BF16 * num_col_align_withStride_fp32 + 16 + mm1OutSize_];
perTokenDescaleTensor = buf[OFFSET_WORKSPACE_BF16 * num_col_align_withStride_fp32 + 16 + mm1OutSize_ * 2];
AscendC::DataCopy(deScaleTensor, perChannelDescaleGmTensor, AscendC::DataCopyParams(1, num_col_ / 8, 0, 0));
}
@@ -528,6 +533,7 @@ private:
uint32_t num_col_align_withStride_fp32{0};
uint32_t num_col_temp;
half quantMin_{-128};
uint32_t mm1OutSize_{0};
uint32_t num_slice_{0};
uint32_t tail_size_{0};
uint32_t tail_copy_{0};
@@ -567,6 +573,7 @@ public:
this->num_row_ = mlaParams_.n;
this->row_work = row_work;
this->row_work_ = row_work_;
this->mm1OutSize_ = mlaParams_.mm1OutSize;
gm_offset_ = gm_offset;
gm_out_offset_ = gm_out_offset;
num_col_align_int8 = (num_col_ + REPEAT_TIME_256 - 1) / REPEAT_TIME_256 * REPEAT_TIME_256;
@@ -618,8 +625,8 @@ public:
if constexpr (NEED_DEQUANT) {
mmTensor = buf.ReinterpretCast<int32_t>()[OFFSET_WORKSPACE_BF16 * num_col_align_withStride_fp32 + 16];
deScaleTensor = buf[OFFSET_WORKSPACE_BF16 * num_col_align_withStride_fp32 + 16 + MM1_OUT_SIZE];
perTokenDescaleTensor = buf[OFFSET_WORKSPACE_BF16 * num_col_align_withStride_fp32 + 16 + MM1_OUT_SIZE * 2];
deScaleTensor = buf[OFFSET_WORKSPACE_BF16 * num_col_align_withStride_fp32 + 16 + mm1OutSize_];
perTokenDescaleTensor = buf[OFFSET_WORKSPACE_BF16 * num_col_align_withStride_fp32 + 16 + mm1OutSize_ * 2];
AscendC::DataCopy(deScaleTensor, perChannelDescaleGmTensor, AscendC::DataCopyParams(1, num_col_ / 8, 0, 0));
}
@@ -834,6 +841,7 @@ private:
uint32_t num_col_align_withStride_fp32{0};
uint32_t num_col_temp;
half quantMin_{-128};
uint32_t mm1OutSize_{0};
uint32_t num_slice_{0};
uint32_t tail_size_{0};
uint32_t tail_copy_{0};
@@ -2387,6 +2395,15 @@ public:
this->epsilon_ = 1e-6;
this->mlaParams = mlaParams_;
this->hiddenStateDim = mlaParams_.hiddenStateDim;
this->mm1OutSize_ = mlaParams_.mm1OutSize;
this->splitSizeOne_ = mlaParams_.splitSizeOne;
this->splitSizeTwo_ = mlaParams_.splitSizeTwo;
this->splitRmsNormSizeOne_ = mlaParams_.splitRmsNormSizeOne;
this->splitRmsNormSizeTwo_ = mlaParams_.splitRmsNormSizeTwo;
this->ropeSplitSizeOne_ = mlaParams_.ropeSplitSizeOne;
this->ropeSplitSizeTwo_ = mlaParams_.ropeSplitSizeTwo;
this->hiddenStrideRope_ = mlaParams_.hiddenStrideRope;
this->qkNopeHeadDim_ = mlaParams_.qkNopeHeadDim;
}
__aicore__ inline void Init(GM_ADDR hiddenStateGm, GM_ADDR quantScale1Gm,
@@ -2471,15 +2488,15 @@ public:
vectorBlockIdx * static_cast<uint64_t>(row_work) * num_col_1, row_work_, mlaParams);
if constexpr (quantMode == QuantMode::PER_TENSOR_ASYMM_QUANT) {
rmsNormQuant2.Init(gamma2GmTensor, beta2GmTensor, quantScale2GmTensor, quantOffset2GmTensor,
s5Gm + row_work * vectorBlockIdx * sizeof(float), descale1Gm, s3Gm, s1Gm, SPLIT_SIZE_ONE,
num_col_2, 0.000651041666, vectorBlockIdx * static_cast<uint64_t>(row_work) * num_col_2,
vectorBlockIdx * static_cast<uint64_t>(row_work) * SPLIT_SIZE_TWO, row_work_, mlaParams);
s5Gm + row_work * vectorBlockIdx * sizeof(float), descale1Gm, s3Gm, s1Gm, splitSizeOne_,
num_col_2, mlaParams.avgFactor, vectorBlockIdx * static_cast<uint64_t>(row_work) * num_col_2,
vectorBlockIdx * static_cast<uint64_t>(row_work) * splitSizeTwo_, row_work_, mlaParams);
} else {
// quantMode == QuantMode::PER_TOKEN_SYMM_QUANT
rmsNormQuant2.Init(gamma2GmTensor, beta2GmTensor, quantScale2GmTensor, quantOffset2GmTensor,
s5Gm + row_work * vectorBlockIdx * sizeof(float), descale1Gm, s2Gm, s1Gm, SPLIT_SIZE_ONE,
num_col_2, 0.000651041666, vectorBlockIdx * static_cast<uint64_t>(row_work) * num_col_2,
vectorBlockIdx * static_cast<uint64_t>(row_work) * SPLIT_SIZE_TWO, row_work_, mlaParams);
s5Gm + row_work * vectorBlockIdx * sizeof(float), descale1Gm, s2Gm, s1Gm, splitSizeOne_,
num_col_2, mlaParams.avgFactor, vectorBlockIdx * static_cast<uint64_t>(row_work) * num_col_2,
vectorBlockIdx * static_cast<uint64_t>(row_work) * splitSizeTwo_, row_work_, mlaParams);
}
ropeFp16.RopeInit(s4Gm, cos2GmTensor, sin2GmTensor, qGmTensor, qGmTensor2, mlaParams);
einSumQuant.Init(s1Gm, gmQnopeScale, qGm, mlaParams);
@@ -2491,6 +2508,17 @@ public:
__aicore__ inline void ProcessVector();
private:
// Model-specific MLA dimensions from tiling data
uint32_t mm1OutSize_;
uint32_t splitSizeOne_;
uint32_t splitSizeTwo_;
uint32_t splitRmsNormSizeOne_;
uint32_t splitRmsNormSizeTwo_;
uint32_t ropeSplitSizeOne_;
uint32_t ropeSplitSizeTwo_;
uint32_t hiddenStrideRope_;
uint32_t qkNopeHeadDim_;
constexpr static uint32_t C0_SIZE = 16;
constexpr static uint32_t I8_C0_SIZE = 32;
@@ -2505,44 +2533,44 @@ private:
AscendC::LocalTensor<half> &tmpfp16, AscendC::LocalTensor<int8_t> &int8OutTensor, float quantScale3)
{
int64_t slotMapGmOffset = vectorBlockIdx * row_work;
AscendC::DataCopy(gammaTensor, gamma3GmTensor, SPLIT_RMSNRORM_SIZE_ONE);
AscendC::DataCopy(gammaTensor, gamma3GmTensor, splitRmsNormSizeOne_);
SET_FLAG(MTE2, V, EVENT_ID1);
WAIT_FLAG(MTE2, V, EVENT_ID1);
Cast(gammaFp32, gammaTensor, AscendC::RoundMode::CAST_NONE, SPLIT_RMSNRORM_SIZE_ONE);
Cast(gammaFp32, gammaTensor, AscendC::RoundMode::CAST_NONE, splitRmsNormSizeOne_);
AscendC::DataCopyPad(slotMappingTensor, slotMappingGmTensor[slotMapGmOffset],
AscendC::DataCopyExtParams(1, sN * sizeof(int32_t), 0, 0, 0),
AscendC::DataCopyPadExtParams<int32_t>(false, 0, 8 - sN % 8, 0));
if constexpr (quantMode == QuantMode::PER_TOKEN_SYMM_QUANT) {
mmTensor = calTensor.ReinterpretCast<int32_t>()[SPLIT_SIZE_ONE];
deScaleTensor = calTensor.ReinterpretCast<float>()[SPLIT_SIZE_ONE * 2];
AscendC::DataCopy(deScaleTensor, descale1gmTensor, AscendC::DataCopyParams(1, SPLIT_SIZE_ONE / 8, 0, 0));
mmTensor = calTensor.ReinterpretCast<int32_t>()[splitSizeOne_];
deScaleTensor = calTensor.ReinterpretCast<float>()[splitSizeOne_ * 2];
AscendC::DataCopy(deScaleTensor, descale1gmTensor, AscendC::DataCopyParams(1, splitSizeOne_ / 8, 0, 0));
}
SET_FLAG(MTE2, V, EVENT_ID2);
WAIT_FLAG(MTE2, V, EVENT_ID2);
SET_FLAG(MTE2, S, EVENT_ID2);
WAIT_FLAG(MTE2, S, EVENT_ID2);
for (uint64_t loop = 0; loop < sN; ++loop) {
uint64_t offset = vectorBlockIdx * static_cast<uint64_t>(row_work) * num_col_2 + loop * MM1_OUT_SIZE;
uint64_t offset = vectorBlockIdx * static_cast<uint64_t>(row_work) * num_col_2 + loop * mm1OutSize_;
int64_t slotValue = static_cast<int64_t>(slotMappingTensor.GetValue(loop));
if (slotValue == -1) {
continue;
}
if constexpr (quantMode == QuantMode::PER_TENSOR_ASYMM_QUANT) {
AscendC::DataCopy(srcTensor, s3GmTensor[offset],
AscendC::DataCopyParams(1, MM1_OUT_SIZE / BLOCK_SIZE_16, 0, 0));
AscendC::DataCopyParams(1, mm1OutSize_ / BLOCK_SIZE_16, 0, 0));
} else {
// quantMode == QuantMode::PER_TOKEN_SYMM_QUANT
AscendC::DataCopy(mmTensor, s2GmTensor[offset], AscendC::DataCopyParams(1, SPLIT_SIZE_ONE / 8, 0, 0));
AscendC::DataCopy(mmTensor, s2GmTensor[offset], AscendC::DataCopyParams(1, splitSizeOne_ / 8, 0, 0));
}
AscendC::DataCopy(sinTensor, sin1GmTensor[(row_work * vectorBlockIdx + loop) * SPLIT_RMSNRORM_SIZE_TWO],
SPLIT_RMSNRORM_SIZE_TWO);
AscendC::DataCopy(cosTensor, cos1GmTensor[(row_work * vectorBlockIdx + loop) * SPLIT_RMSNRORM_SIZE_TWO],
SPLIT_RMSNRORM_SIZE_TWO);
AscendC::DataCopy(sinTensor, sin1GmTensor[(row_work * vectorBlockIdx + loop) * splitRmsNormSizeTwo_],
splitRmsNormSizeTwo_);
AscendC::DataCopy(cosTensor, cos1GmTensor[(row_work * vectorBlockIdx + loop) * splitRmsNormSizeTwo_],
splitRmsNormSizeTwo_);
SET_FLAG(MTE2, V, EVENT_ID0);
// ND
uint64_t cacheStart = static_cast<uint64_t>(slotValue) * static_cast<uint64_t>(SPLIT_SIZE_ONE);
uint64_t cacheStart1 = static_cast<uint64_t>(slotValue) * static_cast<uint64_t>(SPLIT_RMSNRORM_SIZE_ONE);
uint64_t cacheStart2 = static_cast<uint64_t>(slotValue) * static_cast<uint64_t>(SPLIT_RMSNRORM_SIZE_TWO);
uint64_t cacheStart = static_cast<uint64_t>(slotValue) * static_cast<uint64_t>(splitSizeOne_);
uint64_t cacheStart1 = static_cast<uint64_t>(slotValue) * static_cast<uint64_t>(splitRmsNormSizeOne_);
uint64_t cacheStart2 = static_cast<uint64_t>(slotValue) * static_cast<uint64_t>(splitRmsNormSizeTwo_);
// NZ
uint32_t outer_idx = slotValue / 128;
uint32_t inner_idx = slotValue % 128;
@@ -2553,84 +2581,84 @@ private:
if constexpr (quantMode == QuantMode::PER_TOKEN_SYMM_QUANT) {
/* DeQuant */
AscendC::Cast(mmTensor.ReinterpretCast<float>(), mmTensor, AscendC::RoundMode::CAST_NONE,
SPLIT_SIZE_ONE);
splitSizeOne_);
AscendC::PipeBarrier<PIPE_V>();
AscendC::Mul(mmTensor.ReinterpretCast<float>(), mmTensor.ReinterpretCast<float>(), deScaleTensor,
SPLIT_SIZE_ONE);
splitSizeOne_);
AscendC::PipeBarrier<PIPE_V>();
float perTokenDescale = s5GmTensor.GetValue(row_work * vectorBlockIdx + loop);
SET_FLAG(S, V, EVENT_ID0);
WAIT_FLAG(S, V, EVENT_ID0);
AscendC::Muls(mmTensor.ReinterpretCast<float>(), mmTensor.ReinterpretCast<float>(), perTokenDescale,
SPLIT_SIZE_ONE);
splitSizeOne_);
AscendC::PipeBarrier<PIPE_V>();
AscendC::Cast(srcTensor, mmTensor.ReinterpretCast<float>(), AscendC::RoundMode::CAST_RINT,
SPLIT_SIZE_ONE);
splitSizeOne_);
AscendC::PipeBarrier<PIPE_V>();
}
Cast(rmsNormTensor, srcTensor, AscendC::RoundMode::CAST_NONE, SPLIT_RMSNRORM_SIZE_ONE);
Cast(rmsNormTensor, srcTensor, AscendC::RoundMode::CAST_NONE, splitRmsNormSizeOne_);
AscendC::PipeBarrier<PIPE_V>();
Mul(calTensor, rmsNormTensor, rmsNormTensor, SPLIT_RMSNRORM_SIZE_ONE);
Mul(calTensor, rmsNormTensor, rmsNormTensor, splitRmsNormSizeOne_);
AscendC::PipeBarrier<PIPE_V>();
ReduceSumCustom(calTensor[SPLIT_RMSNRORM_SIZE_ONE], calTensor, calTensor[SPLIT_RMSNRORM_SIZE_ONE * 2],
SPLIT_RMSNRORM_SIZE_ONE);
ReduceSumCustom(calTensor[splitRmsNormSizeOne_], calTensor, calTensor[splitRmsNormSizeOne_ * 2],
splitRmsNormSizeOne_);
SET_FLAG(V, S, EVENT_ID1);
WAIT_FLAG(V, S, EVENT_ID1);
float rms = sqrt(calTensor.GetValue(SPLIT_RMSNRORM_SIZE_ONE) / SPLIT_RMSNRORM_SIZE_ONE + epsilon_);
float rms = sqrt(calTensor.GetValue(splitRmsNormSizeOne_) / splitRmsNormSizeOne_ + epsilon_);
SET_FLAG(S, V, EVENT_ID1);
WAIT_FLAG(S, V, EVENT_ID1);
AscendC::PipeBarrier<PIPE_V>();
Duplicate(calTensor, rms, SPLIT_RMSNRORM_SIZE_ONE);
Duplicate(calTensor, rms, splitRmsNormSizeOne_);
AscendC::PipeBarrier<PIPE_V>();
Div(calTensor, rmsNormTensor, calTensor, SPLIT_RMSNRORM_SIZE_ONE);
Div(calTensor, rmsNormTensor, calTensor, splitRmsNormSizeOne_);
AscendC::PipeBarrier<PIPE_V>();
Mul(rmsNormTensor, gammaFp32, calTensor, SPLIT_RMSNRORM_SIZE_ONE);
Mul(rmsNormTensor, gammaFp32, calTensor, splitRmsNormSizeOne_);
AscendC::PipeBarrier<PIPE_V>();
if constexpr (CACHE_MODE == CACHE_MODE_INT8_NZCACHE) {
// quant
Muls(rmsNormTensor, rmsNormTensor, quantScale3, SPLIT_RMSNRORM_SIZE_ONE);
Muls(rmsNormTensor, rmsNormTensor, quantScale3, splitRmsNormSizeOne_);
AscendC::PipeBarrier<PIPE_V>();
CastFrom32To16(tmpfp16, rmsNormTensor, SPLIT_RMSNRORM_SIZE_ONE);
CastFrom32To16(tmpfp16, rmsNormTensor, splitRmsNormSizeOne_);
AscendC::PipeBarrier<PIPE_V>();
CastFromF16ToI8(int8OutTensor, tmpfp16, -128, SPLIT_RMSNRORM_SIZE_ONE);
CastFromF16ToI8(int8OutTensor, tmpfp16, -128, splitRmsNormSizeOne_);
AscendC::PipeBarrier<PIPE_V>();
} else {
AscendC::PipeBarrier<PIPE_V>();
if (std::is_same<T1, __bf16>::value) {
Cast(outTmpTensor, rmsNormTensor, AscendC::RoundMode::CAST_RINT, SPLIT_RMSNRORM_SIZE_ONE);
Cast(outTmpTensor, rmsNormTensor, AscendC::RoundMode::CAST_RINT, splitRmsNormSizeOne_);
} else {
Cast(outTmpTensor, rmsNormTensor, AscendC::RoundMode::CAST_NONE, SPLIT_RMSNRORM_SIZE_ONE);
Cast(outTmpTensor, rmsNormTensor, AscendC::RoundMode::CAST_NONE, splitRmsNormSizeOne_);
}
}
/* RmsNorm end */
/* Rope K start */
uint64_t revertOffset = SPLIT_RMSNRORM_SIZE_TWO / 2;
Cast(ropeKTensor, srcTensor[SPLIT_RMSNRORM_SIZE_ONE], AscendC::RoundMode::CAST_NONE,
SPLIT_RMSNRORM_SIZE_TWO);
Cast(ropeKRevertTensor[revertOffset], srcTensor[SPLIT_RMSNRORM_SIZE_ONE], AscendC::RoundMode::CAST_NONE,
uint64_t revertOffset = splitRmsNormSizeTwo_ / 2;
Cast(ropeKTensor, srcTensor[splitRmsNormSizeOne_], AscendC::RoundMode::CAST_NONE,
splitRmsNormSizeTwo_);
Cast(ropeKRevertTensor[revertOffset], srcTensor[splitRmsNormSizeOne_], AscendC::RoundMode::CAST_NONE,
revertOffset);
Cast(ropeKRevertTensor, srcTensor[SPLIT_RMSNRORM_SIZE_ONE + revertOffset], AscendC::RoundMode::CAST_NONE,
Cast(ropeKRevertTensor, srcTensor[splitRmsNormSizeOne_ + revertOffset], AscendC::RoundMode::CAST_NONE,
revertOffset);
Duplicate(calTensor, static_cast<float>(-1), revertOffset);
Duplicate(calTensor[revertOffset], static_cast<float>(1), revertOffset);
AscendC::PipeBarrier<PIPE_V>();
Cast(calTensor[SPLIT_RMSNRORM_SIZE_TWO], cosTensor, AscendC::RoundMode::CAST_NONE, SPLIT_RMSNRORM_SIZE_TWO);
Cast(calTensor[SPLIT_RMSNRORM_SIZE_TWO * 2], sinTensor, AscendC::RoundMode::CAST_NONE,
SPLIT_RMSNRORM_SIZE_TWO);
Cast(calTensor[splitRmsNormSizeTwo_], cosTensor, AscendC::RoundMode::CAST_NONE, splitRmsNormSizeTwo_);
Cast(calTensor[splitRmsNormSizeTwo_ * 2], sinTensor, AscendC::RoundMode::CAST_NONE,
splitRmsNormSizeTwo_);
AscendC::PipeBarrier<PIPE_V>();
Mul(ropeKTensor, calTensor[SPLIT_RMSNRORM_SIZE_TWO], ropeKTensor, SPLIT_RMSNRORM_SIZE_TWO);
Mul(ropeKRevertTensor, calTensor[SPLIT_RMSNRORM_SIZE_TWO * 2], ropeKRevertTensor, SPLIT_RMSNRORM_SIZE_TWO);
Mul(ropeKTensor, calTensor[splitRmsNormSizeTwo_], ropeKTensor, splitRmsNormSizeTwo_);
Mul(ropeKRevertTensor, calTensor[splitRmsNormSizeTwo_ * 2], ropeKRevertTensor, splitRmsNormSizeTwo_);
AscendC::PipeBarrier<PIPE_V>();
Mul(ropeKRevertTensor, calTensor, ropeKRevertTensor, SPLIT_RMSNRORM_SIZE_TWO);
Mul(ropeKRevertTensor, calTensor, ropeKRevertTensor, splitRmsNormSizeTwo_);
AscendC::PipeBarrier<PIPE_V>();
Add(ropeKRevertTensor, ropeKTensor, ropeKRevertTensor, SPLIT_RMSNRORM_SIZE_TWO);
Add(ropeKRevertTensor, ropeKTensor, ropeKRevertTensor, splitRmsNormSizeTwo_);
AscendC::PipeBarrier<PIPE_V>();
if (std::is_same<T1, __bf16>::value) {
Cast(outTmpTensor[SPLIT_RMSNRORM_SIZE_ONE], ropeKRevertTensor, AscendC::RoundMode::CAST_RINT,
SPLIT_RMSNRORM_SIZE_TWO);
Cast(outTmpTensor[splitRmsNormSizeOne_], ropeKRevertTensor, AscendC::RoundMode::CAST_RINT,
splitRmsNormSizeTwo_);
} else {
Cast(outTmpTensor[SPLIT_RMSNRORM_SIZE_ONE], ropeKRevertTensor, AscendC::RoundMode::CAST_NONE,
SPLIT_RMSNRORM_SIZE_TWO);
Cast(outTmpTensor[splitRmsNormSizeOne_], ropeKRevertTensor, AscendC::RoundMode::CAST_NONE,
splitRmsNormSizeTwo_);
}
AscendC::PipeBarrier<PIPE_V>();
/* Rope K end */
@@ -2638,45 +2666,45 @@ private:
WAIT_FLAG(V, MTE3, EVENT_ID0);
WAIT_FLAG(S, MTE3, EVENT_ID0);
if constexpr (CACHE_MODE == CACHE_MODE_KVCACHE) {
DataCopy(keycacheGmTensor1[cacheStart], outTmpTensor, SPLIT_SIZE_ONE);
DataCopy(keycacheGmTensor1[cacheStart], outTmpTensor, splitSizeOne_);
} else if constexpr (CACHE_MODE == CACHE_MODE_INT8_NZCACHE) {
uint64_t cacheSatartI8Nz1 = outer_idx * 128 * 512 + inner_idx * I8_C0_SIZE;
uint64_t cacheSatartNz2 = outer_idx * 128 * 64 + inner_idx * C0_SIZE;
// nope:int8 nz
AscendC::DataCopyExtParams outExt;
outExt.blockCount = SPLIT_RMSNRORM_SIZE_ONE / I8_C0_SIZE;
outExt.blockCount = splitRmsNormSizeOne_ / I8_C0_SIZE;
outExt.blockLen = I8_C0_SIZE * sizeof(int8_t);
outExt.srcStride = 0;
outExt.dstStride = (128 * I8_C0_SIZE - I8_C0_SIZE) * sizeof(int8_t);
DataCopyPad(keycacheGmTensor1[cacheSatartI8Nz1], int8OutTensor, outExt);
// rope:T1 nz
outExt.blockCount = SPLIT_RMSNRORM_SIZE_TWO / C0_SIZE;
outExt.blockCount = splitRmsNormSizeTwo_ / C0_SIZE;
outExt.blockLen = C0_SIZE * sizeof(T1);
outExt.srcStride = 0;
outExt.dstStride = (128 * C0_SIZE - C0_SIZE) * sizeof(T1);
DataCopyPad(keycacheGmTensor2[cacheSatartNz2], outTmpTensor[SPLIT_RMSNRORM_SIZE_ONE], outExt);
DataCopyPad(keycacheGmTensor2[cacheSatartNz2], outTmpTensor[splitRmsNormSizeOne_], outExt);
} else if constexpr (CACHE_MODE == CACHE_MODE_NZCACHE) {
uint64_t cacheSatartNz1 = outer_idx * 128 * 512 + inner_idx * C0_SIZE;
uint64_t cacheSatartNz2 = outer_idx * 128 * 64 + inner_idx * C0_SIZE;
// nope:T1 nz
AscendC::DataCopyExtParams outExt;
outExt.blockCount = SPLIT_RMSNRORM_SIZE_ONE / C0_SIZE;
outExt.blockCount = splitRmsNormSizeOne_ / C0_SIZE;
outExt.blockLen = C0_SIZE * sizeof(T1);
outExt.srcStride = 0;
outExt.dstStride = (128 * C0_SIZE - C0_SIZE) * sizeof(T1);
DataCopyPad(keycacheGmTensor1[cacheSatartNz1], outTmpTensor, outExt);
// rope:T1 nz
outExt.blockCount = SPLIT_RMSNRORM_SIZE_TWO / C0_SIZE;
outExt.blockCount = splitRmsNormSizeTwo_ / C0_SIZE;
outExt.blockLen = C0_SIZE * sizeof(T1);
outExt.srcStride = 0;
outExt.dstStride = (128 * C0_SIZE - C0_SIZE) * sizeof(T1);
DataCopyPad(keycacheGmTensor2[cacheSatartNz2], outTmpTensor[SPLIT_RMSNRORM_SIZE_ONE], outExt);
DataCopyPad(keycacheGmTensor2[cacheSatartNz2], outTmpTensor[splitRmsNormSizeOne_], outExt);
} else {
// keycache1
DataCopy(keycacheGmTensor1[cacheStart1], outTmpTensor, SPLIT_RMSNRORM_SIZE_ONE);
DataCopy(keycacheGmTensor1[cacheStart1], outTmpTensor, splitRmsNormSizeOne_);
// keycache2
DataCopy(keycacheGmTensor2[cacheStart2], outTmpTensor[SPLIT_RMSNRORM_SIZE_ONE],
SPLIT_RMSNRORM_SIZE_TWO);
DataCopy(keycacheGmTensor2[cacheStart2], outTmpTensor[splitRmsNormSizeOne_],
splitRmsNormSizeTwo_);
}
SET_FLAG(MTE3, MTE2, EVENT_ID1);
WAIT_FLAG(MTE3, MTE2, EVENT_ID1);
@@ -2823,20 +2851,20 @@ MLAOperation<InDtype, CACHE_MODE, weightFormat1, weightFormat2, weightFormat3, q
uint32_t num_col_align_f16 = (num_col_2 + REPEAT_TIME_128 - 1) / REPEAT_TIME_128 * REPEAT_TIME_128;
uint32_t num_col_align_f32 = (num_col_2 + REPEAT_TIME_64 - 1) / REPEAT_TIME_64 * REPEAT_TIME_64;
AscendC::LocalTensor<InDtype> input_tensor = buf.GetBuffer<BufferType::ASCEND_UB, InDtype>(0);
AscendC::LocalTensor<InDtype> gamma_tensor = buf.GetBuffer<BufferType::ASCEND_UB, InDtype>(MM1_OUT_SIZE * 2);
AscendC::LocalTensor<InDtype> gamma_tensor = buf.GetBuffer<BufferType::ASCEND_UB, InDtype>(mm1OutSize_ * 2);
AscendC::LocalTensor<InDtype> beta_tensor =
buf.GetBuffer<BufferType::ASCEND_UB, InDtype>(MM1_OUT_SIZE * 2 + SPLIT_SIZE_TWO * 2);
buf.GetBuffer<BufferType::ASCEND_UB, InDtype>(mm1OutSize_ * 2 + splitSizeTwo_ * 2);
AscendC::LocalTensor<InDtype> scale_tensor =
buf.GetBuffer<BufferType::ASCEND_UB, InDtype>(MM1_OUT_SIZE * 2 + SPLIT_SIZE_TWO * 2 + SPLIT_SIZE_TWO * 2);
buf.GetBuffer<BufferType::ASCEND_UB, InDtype>(mm1OutSize_ * 2 + splitSizeTwo_ * 2 + splitSizeTwo_ * 2);
AscendC::LocalTensor<int8_t> offset_tensor = buf.GetBuffer<BufferType::ASCEND_UB, int8_t>(
MM1_OUT_SIZE * 2 + SPLIT_SIZE_TWO * 2 + SPLIT_SIZE_TWO * 2 + 32);
mm1OutSize_ * 2 + splitSizeTwo_ * 2 + splitSizeTwo_ * 2 + 32);
AscendC::LocalTensor<float> res1_tensor = buf.GetBuffer<BufferType::ASCEND_UB, float>(
MM1_OUT_SIZE * 2 + SPLIT_SIZE_TWO * 2 + SPLIT_SIZE_TWO * 2 + 64);
mm1OutSize_ * 2 + splitSizeTwo_ * 2 + splitSizeTwo_ * 2 + 64);
AscendC::LocalTensor<float> res3_tensor = buf.GetBuffer<BufferType::ASCEND_UB, float>(
MM1_OUT_SIZE * 2 + SPLIT_SIZE_TWO * 2 + SPLIT_SIZE_TWO * 2 + 64 + num_col_align_f32 * 4);
mm1OutSize_ * 2 + splitSizeTwo_ * 2 + splitSizeTwo_ * 2 + 64 + num_col_align_f32 * 4);
AscendC::LocalTensor<int8_t> output_tensor = buf.GetBuffer<BufferType::ASCEND_UB, int8_t>(
MM1_OUT_SIZE * 2 + SPLIT_SIZE_TWO * 2 + SPLIT_SIZE_TWO * 2 + 64 + num_col_align_f32 * 4 +
BUF_FACTOR * num_col_align_f32 * 4 + 64 + MM1_OUT_SIZE * 4 * 2 + 32);
mm1OutSize_ * 2 + splitSizeTwo_ * 2 + splitSizeTwo_ * 2 + 64 + num_col_align_f32 * 4 +
BUF_FACTOR * num_col_align_f32 * 4 + 64 + mm1OutSize_ * 4 * 2 + 32);
rmsNormQuant2.Launch(output_tensor, input_tensor, gamma_tensor, beta_tensor, scale_tensor, offset_tensor,
res1_tensor, res3_tensor);
}
@@ -2846,20 +2874,20 @@ MLAOperation<InDtype, CACHE_MODE, weightFormat1, weightFormat2, weightFormat3, q
if (row_work_ != 0) {
AscendC::LocalTensor<InDtype> input_tensor = buf.GetBuffer<BufferType::ASCEND_UB, InDtype>(0);
AscendC::LocalTensor<InDtype> gamma_tensor = buf.GetBuffer<BufferType::ASCEND_UB, InDtype>(MM1_OUT_SIZE * 2);
AscendC::LocalTensor<InDtype> gamma_tensor = buf.GetBuffer<BufferType::ASCEND_UB, InDtype>(mm1OutSize_ * 2);
AscendC::LocalTensor<InDtype> sin_tensor =
buf.GetBuffer<BufferType::ASCEND_UB, InDtype>(MM1_OUT_SIZE * 2 + SPLIT_RMSNRORM_SIZE_ONE * 2);
buf.GetBuffer<BufferType::ASCEND_UB, InDtype>(mm1OutSize_ * 2 + splitRmsNormSizeOne_ * 2);
AscendC::LocalTensor<InDtype> cos_tensor = buf.GetBuffer<BufferType::ASCEND_UB, InDtype>(
MM1_OUT_SIZE * 2 + SPLIT_RMSNRORM_SIZE_ONE * 2 + SPLIT_RMSNRORM_SIZE_TWO * 2);
mm1OutSize_ * 2 + splitRmsNormSizeOne_ * 2 + splitRmsNormSizeTwo_ * 2);
AscendC::LocalTensor<int32_t> slotMapping_tensor = buf.GetBuffer<BufferType::ASCEND_UB, int32_t>(
MM1_OUT_SIZE * 2 + SPLIT_RMSNRORM_SIZE_ONE * 2 + SPLIT_RMSNRORM_SIZE_TWO * 4);
mm1OutSize_ * 2 + splitRmsNormSizeOne_ * 2 + splitRmsNormSizeTwo_ * 4);
int32_t rms3_ub_offset =
MM1_OUT_SIZE * 2 + SPLIT_RMSNRORM_SIZE_ONE * 2 + SPLIT_RMSNRORM_SIZE_TWO * 4 + 4096 * 32;
mm1OutSize_ * 2 + splitRmsNormSizeOne_ * 2 + splitRmsNormSizeTwo_ * 4 + 4096 * 32;
AscendC::LocalTensor<float> tmp32_tensor = buf.GetBuffer<BufferType::ASCEND_UB, float>(rms3_ub_offset);
int32_t out_ub_offset = MM1_OUT_SIZE * 2 + SPLIT_RMSNRORM_SIZE_ONE * 2 + SPLIT_RMSNRORM_SIZE_TWO * 4 +
4096 * 32 + SPLIT_RMSNRORM_SIZE_ONE * 3 * 4 + SPLIT_RMSNRORM_SIZE_TWO * 2 * 4 +
MM1_OUT_SIZE * 4 * 2 + 32;
int32_t out_ub_offset = mm1OutSize_ * 2 + splitRmsNormSizeOne_ * 2 + splitRmsNormSizeTwo_ * 4 +
4096 * 32 + splitRmsNormSizeOne_ * 3 * 4 + splitRmsNormSizeTwo_ * 2 * 4 +
mm1OutSize_ * 4 * 2 + 32;
AscendC::LocalTensor<InDtype> temp_tensor = buf.GetBuffer<BufferType::ASCEND_UB, InDtype>(out_ub_offset);
AscendC::LocalTensor<half> tmpfp16;
@@ -2873,7 +2901,7 @@ MLAOperation<InDtype, CACHE_MODE, weightFormat1, weightFormat2, weightFormat3, q
buf.GetBuffer<BufferType::ASCEND_UB, float>(rms3_ub_offset + 32);
// int8out
tmpfp16 = buf.GetBuffer<BufferType::ASCEND_UB, half>(rms3_ub_offset +
SPLIT_RMSNRORM_SIZE_ONE * sizeof(float) * 2);
splitRmsNormSizeOne_ * sizeof(float) * 2);
int8OutTensor = buf.GetBuffer<BufferType::ASCEND_UB, int8_t>(out_ub_offset);
AscendC::DataCopy(quantScaleTensor, quantScale3GmTensor, AscendC::DataCopyParams(1, 1, 0, 0));
SET_FLAG(MTE2, V, EVENT_ID1);
@@ -2890,11 +2918,11 @@ MLAOperation<InDtype, CACHE_MODE, weightFormat1, weightFormat2, weightFormat3, q
sin_tensor, // sin
cos_tensor, // cons
slotMapping_tensor, // slotMapping
row_work_, tmp32_tensor, tmp32_tensor[SPLIT_RMSNRORM_SIZE_ONE],
tmp32_tensor[SPLIT_RMSNRORM_SIZE_ONE + SPLIT_RMSNRORM_SIZE_ONE],
tmp32_tensor[SPLIT_RMSNRORM_SIZE_ONE + SPLIT_RMSNRORM_SIZE_ONE + SPLIT_RMSNRORM_SIZE_TWO],
tmp32_tensor[SPLIT_RMSNRORM_SIZE_ONE + SPLIT_RMSNRORM_SIZE_ONE + SPLIT_RMSNRORM_SIZE_TWO +
SPLIT_RMSNRORM_SIZE_TWO],
row_work_, tmp32_tensor, tmp32_tensor[splitRmsNormSizeOne_],
tmp32_tensor[splitRmsNormSizeOne_ + splitRmsNormSizeOne_],
tmp32_tensor[splitRmsNormSizeOne_ + splitRmsNormSizeOne_ + splitRmsNormSizeTwo_],
tmp32_tensor[splitRmsNormSizeOne_ + splitRmsNormSizeOne_ + splitRmsNormSizeTwo_ +
splitRmsNormSizeTwo_],
temp_tensor, tmpfp16, int8OutTensor, scale3);
}
mm_w8a8_aiv_2.Process();