[feat] parameterize hardcoded MLA dimensions to support GLM5-W8A8 (#6902)
Derive MLA dimension constants (q_lora_rank, qk_nope_head_dim, etc.)
from tensor shapes at runtime instead of hardcoding DeepSeek V3 values.
This enables the mla_preprocess fused op to work with both DeepSeek V3
and GLM5 models without Python API changes.
- Add 9 dimension fields to MlaTilingData with DeepSeek V3 defaults
- Add OpParam fields and dynamize all host-side tiling functions
- Derive dimensions from wuk, gamma1, kv_cache_rope tensor shapes
- Replace 310+ hardcoded constants across 4 kernel .hpp files
- Remove unused MMSIZE1/MMSIZE2 constants
### What this PR does / why we need it?
### Does this PR introduce _any_ user-facing change?
### How was this patch tested?
- vLLM version: v0.16.0
- vLLM main:
15d76f74e2
---------
Signed-off-by: liuchenbing <chenliumail@163.com>
Co-authored-by: liuchenbing <chenliumail@163.com>
This commit is contained in:
@@ -84,6 +84,8 @@ std::tuple<at::Tensor &, at::Tensor &, at::Tensor &, at::Tensor &, at::Tensor &>
|
||||
hiddenState,
|
||||
wdqkv,
|
||||
wuk,
|
||||
gamma1,
|
||||
kv_cache_rope,
|
||||
cache_mode,
|
||||
quant_mode,
|
||||
enableInnerOut
|
||||
|
||||
@@ -129,6 +129,11 @@ struct OpParam {
|
||||
QuantMode quantMode;
|
||||
caffe2::TypeMeta inDtype;
|
||||
bool enableInnerOut;
|
||||
// MLA dimensions derived from tensor shapes
|
||||
uint32_t qLoraRank;
|
||||
uint32_t qkNopeHeadDim;
|
||||
uint32_t qkRopeHeadDim;
|
||||
uint32_t kvLoraRank;
|
||||
};
|
||||
|
||||
class PpMatmulTilingApi
|
||||
@@ -397,7 +402,7 @@ void MlaPreprocessTiling::RmsNormQuantTiling()
|
||||
tilingData->rmsNumRow1 = opParam.N;
|
||||
tilingData->rmsQuantMin1 = -CONST_128;
|
||||
tilingData->rmsNumCore2 = platformInfo.coreNumAiv;
|
||||
tilingData->rmsNumCol2 = HIDDEN_STRATE_MM;
|
||||
tilingData->rmsNumCol2 = opParam.qLoraRank + opParam.kvLoraRank + opParam.qkRopeHeadDim;
|
||||
tilingData->rmsNumRow2 = opParam.N;
|
||||
tilingData->rmsQuantMin2 = -CONST_128;
|
||||
}
|
||||
@@ -405,10 +410,10 @@ void MlaPreprocessTiling::RmsNormQuantTiling()
|
||||
void MlaPreprocessTiling::RopeConcatTiling()
|
||||
{
|
||||
uint32_t ntokens = opParam.N;
|
||||
uint32_t hiddenSizeQ = HEADDIM * opParam.headNum;
|
||||
uint32_t headDim = HEADDIM;
|
||||
uint32_t hiddenSizeQ = opParam.qkRopeHeadDim * opParam.headNum;
|
||||
uint32_t headDim = opParam.qkRopeHeadDim;
|
||||
uint32_t headNumQ = hiddenSizeQ / headDim;
|
||||
uint32_t concatSize = CONCAT_SIZE;
|
||||
uint32_t concatSize = opParam.kvLoraRank;
|
||||
uint32_t maxCore = platformInfo.coreNumAiv;
|
||||
uint32_t maxUbSize = platformInfo.ubSize;
|
||||
|
||||
@@ -458,7 +463,7 @@ void MlaPreprocessTiling::EinSumQuantTiling()
|
||||
// input shape
|
||||
uint32_t esqBatch = opParam.N; // tokenNum
|
||||
uint32_t esqHeadNum = opParam.headNum; // headNum
|
||||
uint32_t esqColNum = AXES_ALIGN_SIZE; // 512
|
||||
uint32_t esqColNum = opParam.kvLoraRank; // kv_lora_rank
|
||||
|
||||
// split core
|
||||
uint32_t esqFrontCore = esqBatch % aivCore;
|
||||
@@ -508,14 +513,16 @@ void MlaPreprocessTiling::EinSumQuantTiling()
|
||||
|
||||
void MlaPreprocessTiling::SetMlapoWorkSpace()
|
||||
{
|
||||
uint32_t hiddenStrideRope = opParam.qkNopeHeadDim + opParam.qkRopeHeadDim;
|
||||
uint32_t hiddenStrateMm = opParam.qLoraRank + opParam.kvLoraRank + opParam.qkRopeHeadDim;
|
||||
uint64_t s1wsFactor =
|
||||
static_cast<uint64_t>(opParam.cacheMode == 2 ? std::max(opParam.hiddenStateDim * sizeof(int8_t),
|
||||
opParam.headNum * AXES_ALIGN_SIZE * sizeof(uint16_t))
|
||||
opParam.headNum * opParam.kvLoraRank * sizeof(uint16_t))
|
||||
: opParam.hiddenStateDim * sizeof(int8_t));
|
||||
uint64_t workSizeS1 = s1wsFactor;
|
||||
uint64_t workSizeS2 = opParam.headNum * HIDDEN_STRATE_ROPE * sizeof(uint16_t);
|
||||
uint64_t workSizeS3 = HIDDEN_STRATE_MM * sizeof(uint16_t);
|
||||
uint64_t workSizeS4 = std::max(opParam.headNum * HIDDEN_STRATE_ROPE, HIDDEN_STRATE_MM) * sizeof(uint32_t);
|
||||
uint64_t workSizeS2 = opParam.headNum * hiddenStrideRope * sizeof(uint16_t);
|
||||
uint64_t workSizeS3 = hiddenStrateMm * sizeof(uint16_t);
|
||||
uint64_t workSizeS4 = std::max(opParam.headNum * hiddenStrideRope, hiddenStrateMm) * sizeof(uint32_t);
|
||||
|
||||
uint64_t maxWorkspaceSize = workSizeS1;
|
||||
maxWorkspaceSize = std::max(maxWorkspaceSize, workSizeS2);
|
||||
@@ -564,11 +571,17 @@ void MlaPreprocessTiling::Init()
|
||||
deqOnTheFly = true;
|
||||
}
|
||||
|
||||
uint32_t mm1N = opParam.qLoraRank + opParam.kvLoraRank + opParam.qkRopeHeadDim;
|
||||
uint32_t mm2K = opParam.qLoraRank;
|
||||
uint32_t mm2N = opParam.headNum * (opParam.qkNopeHeadDim + opParam.qkRopeHeadDim);
|
||||
uint32_t mm3K = opParam.qkNopeHeadDim;
|
||||
uint32_t mm3N = opParam.kvLoraRank;
|
||||
|
||||
PpMatmulTilingApi mm1TilingApi(platformInfo,
|
||||
1, // numBatch
|
||||
opParam.N, // m
|
||||
opParam.hiddenStateDim, // k
|
||||
HIDDEN_STRATE_MM, // n
|
||||
mm1N, // n
|
||||
false, // transA
|
||||
true, // transB
|
||||
enDequant, // enDequant
|
||||
@@ -576,21 +589,21 @@ void MlaPreprocessTiling::Init()
|
||||
mm1TilingApi.GetTilingData(tilingData->mm1);
|
||||
|
||||
PpMatmulTilingApi mm2TilingApi(platformInfo,
|
||||
1, // numBatch
|
||||
opParam.N, // m
|
||||
HIDDEN_STRATE_RMS, // k
|
||||
opParam.headNum * HIDDEN_STRATE_ROPE, // n
|
||||
false, // transA
|
||||
true, // transB
|
||||
enDequant, // enDequant
|
||||
deqOnTheFly); // in bf16.cce?
|
||||
1, // numBatch
|
||||
opParam.N, // m
|
||||
mm2K, // k
|
||||
mm2N, // n
|
||||
false, // transA
|
||||
true, // transB
|
||||
enDequant, // enDequant
|
||||
deqOnTheFly); // in bf16.cce?
|
||||
mm2TilingApi.GetTilingData(tilingData->mm2);
|
||||
|
||||
PpMatmulTilingApi mm3TilingApi(platformInfo,
|
||||
opParam.headNum, // numBatch
|
||||
opParam.N, // m
|
||||
CONST_128, // k
|
||||
CONCAT_SIZE, // n
|
||||
mm3K, // k
|
||||
mm3N, // n
|
||||
false, // transA
|
||||
false, // transB
|
||||
false, // enDequant
|
||||
@@ -604,6 +617,18 @@ void MlaPreprocessTiling::Init()
|
||||
SetMlapoWorkSpace();
|
||||
SetTilingKey();
|
||||
|
||||
// Populate model-specific MLA dimension fields
|
||||
tilingData->mm1OutSize = opParam.qLoraRank + opParam.kvLoraRank + opParam.qkRopeHeadDim;
|
||||
tilingData->splitSizeOne = opParam.kvLoraRank + opParam.qkRopeHeadDim;
|
||||
tilingData->splitSizeTwo = opParam.qLoraRank;
|
||||
tilingData->splitRmsNormSizeOne = opParam.kvLoraRank;
|
||||
tilingData->splitRmsNormSizeTwo = opParam.qkRopeHeadDim;
|
||||
tilingData->ropeSplitSizeOne = opParam.qkRopeHeadDim;
|
||||
tilingData->ropeSplitSizeTwo = opParam.qkNopeHeadDim;
|
||||
tilingData->hiddenStrideRope = opParam.qkNopeHeadDim + opParam.qkRopeHeadDim;
|
||||
tilingData->qkNopeHeadDim = opParam.qkNopeHeadDim;
|
||||
tilingData->avgFactor = 1.0f / static_cast<float>(opParam.qLoraRank);
|
||||
|
||||
return;
|
||||
}
|
||||
|
||||
@@ -631,6 +656,8 @@ std::tuple<at::Tensor, at::Tensor, uint32_t> mla_preprocess_tiling(
|
||||
const at::Tensor &hiddenState,
|
||||
const at::Tensor &wdqkv,
|
||||
const at::Tensor &wuk,
|
||||
const at::Tensor &gamma1,
|
||||
const at::Tensor &kv_cache_rope,
|
||||
c10::optional<c10::string_view> cache_mode,
|
||||
c10::optional<c10::string_view> quant_mode,
|
||||
bool enable_inner_out
|
||||
@@ -656,6 +683,12 @@ std::tuple<at::Tensor, at::Tensor, uint32_t> mla_preprocess_tiling(
|
||||
int32_t headNum = wuk.sizes()[0];
|
||||
uint32_t hiddenStateDim = hiddenState.sizes().back();
|
||||
|
||||
// Derive MLA dimensions from tensor shapes
|
||||
uint32_t qkNopeHeadDim = wuk.sizes()[1];
|
||||
uint32_t kvLoraRank = wuk.sizes()[2];
|
||||
uint32_t qLoraRank = gamma1.sizes()[0];
|
||||
uint32_t qkRopeHeadDim = kv_cache_rope.sizes().back();
|
||||
|
||||
OpParam opParam;
|
||||
opParam.hiddenStateDim = hiddenStateDim;
|
||||
opParam.N = N;
|
||||
@@ -664,6 +697,10 @@ std::tuple<at::Tensor, at::Tensor, uint32_t> mla_preprocess_tiling(
|
||||
opParam.quantMode = static_cast<QuantMode>(quantMode);
|
||||
opParam.inDtype = hiddenState.options().dtype();
|
||||
opParam.enableInnerOut = enable_inner_out;
|
||||
opParam.qLoraRank = qLoraRank;
|
||||
opParam.qkNopeHeadDim = qkNopeHeadDim;
|
||||
opParam.qkRopeHeadDim = qkRopeHeadDim;
|
||||
opParam.kvLoraRank = kvLoraRank;
|
||||
if (wdqkv.options().dtype() == at::kBFloat16 || wdqkv.options().dtype() == at::kHalf) {
|
||||
opParam.isWeightQuantized = 0;
|
||||
} else {
|
||||
|
||||
@@ -95,6 +95,18 @@ struct MlaTilingData {
|
||||
uint32_t hiddenStateDim{7168};
|
||||
|
||||
uint32_t isWeightQuantized{1};
|
||||
|
||||
// Model-specific MLA dimensions (derived from tensor shapes)
|
||||
uint32_t mm1OutSize{2112}; // q_lora_rank + kv_lora_rank + qk_rope_head_dim
|
||||
uint32_t splitSizeOne{576}; // kv_lora_rank + qk_rope_head_dim
|
||||
uint32_t splitSizeTwo{1536}; // q_lora_rank
|
||||
uint32_t splitRmsNormSizeOne{512}; // kv_lora_rank
|
||||
uint32_t splitRmsNormSizeTwo{64}; // qk_rope_head_dim
|
||||
uint32_t ropeSplitSizeOne{64}; // qk_rope_head_dim
|
||||
uint32_t ropeSplitSizeTwo{128}; // qk_nope_head_dim
|
||||
uint32_t hiddenStrideRope{192}; // qk_nope_head_dim + qk_rope_head_dim
|
||||
uint32_t qkNopeHeadDim{128}; // for RoPE offset calc
|
||||
float avgFactor{0.000651041666f}; // 1/splitSizeTwo (1/qLoraRank), for RmsNorm avg
|
||||
};
|
||||
|
||||
#endif // MLAPREPROCESS_TILING_H
|
||||
|
||||
@@ -60,9 +60,6 @@ constexpr uint32_t SPLIT_RMSNRORM_SIZE_TWO = 64;
|
||||
constexpr uint32_t ROPE_SPLIT_SIZE_ONE = 64;
|
||||
constexpr uint32_t ROPE_SPLIT_SIZE_TWO = 128;
|
||||
|
||||
constexpr uint32_t MMSIZE1 = 128 * 192; // 24576
|
||||
constexpr uint32_t MMSIZE2 = 64 * 128; // 8192
|
||||
|
||||
constexpr uint64_t L0_PINGPONG_BUFFER_LEN = 32768; // 32 KB
|
||||
constexpr uint64_t L1_PINGPONG_BUFFER_LEN = 262144; // 256 KB
|
||||
constexpr uint64_t BLOCK_SIZE_16 = 16;
|
||||
|
||||
@@ -136,6 +136,18 @@ extern "C" __global__ __aicore__ void mla_preprocess(
|
||||
mlaTilingData.s4Offset = tilingData->s4Offset;
|
||||
mlaTilingData.s5Offset = tilingData->s5Offset;
|
||||
|
||||
// Model-specific MLA dimensions
|
||||
mlaTilingData.mm1OutSize = tilingData->mm1OutSize;
|
||||
mlaTilingData.splitSizeOne = tilingData->splitSizeOne;
|
||||
mlaTilingData.splitSizeTwo = tilingData->splitSizeTwo;
|
||||
mlaTilingData.splitRmsNormSizeOne = tilingData->splitRmsNormSizeOne;
|
||||
mlaTilingData.splitRmsNormSizeTwo = tilingData->splitRmsNormSizeTwo;
|
||||
mlaTilingData.ropeSplitSizeOne = tilingData->ropeSplitSizeOne;
|
||||
mlaTilingData.ropeSplitSizeTwo = tilingData->ropeSplitSizeTwo;
|
||||
mlaTilingData.hiddenStrideRope = tilingData->hiddenStrideRope;
|
||||
mlaTilingData.qkNopeHeadDim = tilingData->qkNopeHeadDim;
|
||||
mlaTilingData.avgFactor = tilingData->avgFactor;
|
||||
|
||||
GM_ADDR s1 = workspace + static_cast<uint64_t>(mlaTilingData.s1Offset);
|
||||
GM_ADDR s2 = workspace + static_cast<uint64_t>(mlaTilingData.s2Offset);
|
||||
GM_ADDR s3 = workspace + static_cast<uint64_t>(mlaTilingData.s3Offset);
|
||||
|
||||
@@ -43,6 +43,8 @@ public:
|
||||
|
||||
headDim = ropeConcatParams.headDim;
|
||||
headNumQ = ropeConcatParams.headNumQ;
|
||||
this->hiddenStrideRope_ = ropeConcatParams.hiddenStrideRope;
|
||||
this->qkNopeHeadDim_ = ropeConcatParams.qkNopeHeadDim;
|
||||
rotaryCoeff = ropeConcatParams.rotaryCoeff;
|
||||
ntokens = ropeConcatParams.ntokens;
|
||||
realCore = ropeConcatParams.realCore;
|
||||
@@ -92,7 +94,7 @@ public:
|
||||
AscendC::LocalTensor<float> inputQCastFP32 = buf.GetBuffer<BufferType::ASCEND_UB, float>(dataSizeFp16);
|
||||
AscendC::LocalTensor<float> reverseQ =
|
||||
buf.GetBuffer<BufferType::ASCEND_UB, float>(dataSizeFp32 + dataSizeFp16);
|
||||
uint64_t qOffset = startHead * 192 + 128;
|
||||
uint64_t qOffset = startHead * hiddenStrideRope_ + qkNopeHeadDim_;
|
||||
CopyQGenReverseQ(inputQ, inputQCastFP32, reverseQ, qOffset, loopN);
|
||||
|
||||
// move in cos/sin
|
||||
@@ -178,7 +180,7 @@ public:
|
||||
{
|
||||
// move in Q
|
||||
WAIT_FLAG(MTE3, MTE2, EVENT_ID1);
|
||||
AscendC::DataCopy(tempBufQ, this->qGm_[qOffset], {loopN, headBlockLen, 128 / 16, 0});
|
||||
AscendC::DataCopy(tempBufQ, this->qGm_[qOffset], {loopN, headBlockLen, static_cast<uint16_t>(qkNopeHeadDim_ / 16), 0});
|
||||
SET_FLAG(MTE2, V, EVENT_ID1);
|
||||
WAIT_FLAG(MTE2, V, EVENT_ID1);
|
||||
// cast fp32
|
||||
@@ -215,6 +217,8 @@ private:
|
||||
AscendC::GlobalTensor<QOutDtype> outRopeConcatGm_;
|
||||
AscendC::GlobalTensor<QkDtype> outRopeConcatGm2_;
|
||||
|
||||
uint32_t hiddenStrideRope_{0};
|
||||
uint32_t qkNopeHeadDim_{0};
|
||||
uint32_t repeatSize_{0};
|
||||
uint32_t rotateStride_{0}; // this->headDim / rope conf
|
||||
uint32_t headDim;
|
||||
@@ -309,6 +313,7 @@ public:
|
||||
this->num_row_ = mlaParams_.n;
|
||||
this->row_work = row_work;
|
||||
this->row_work_ = row_work_;
|
||||
this->mm1OutSize_ = mlaParams_.mm1OutSize;
|
||||
gm_offset_ = gm_offset;
|
||||
gm_out_offset_ = gm_out_offset;
|
||||
num_col_align_int8 = (num_col_ + REPEAT_TIME_256 - 1) / REPEAT_TIME_256 * REPEAT_TIME_256;
|
||||
@@ -353,8 +358,8 @@ public:
|
||||
|
||||
if constexpr (NEED_DEQUANT) {
|
||||
mmTensor = buf.ReinterpretCast<int32_t>()[OFFSET_WORKSPACE_BF16 * num_col_align_withStride_fp32 + 16];
|
||||
deScaleTensor = buf[OFFSET_WORKSPACE_BF16 * num_col_align_withStride_fp32 + 16 + MM1_OUT_SIZE];
|
||||
perTokenDescaleTensor = buf[OFFSET_WORKSPACE_BF16 * num_col_align_withStride_fp32 + 16 + MM1_OUT_SIZE * 2];
|
||||
deScaleTensor = buf[OFFSET_WORKSPACE_BF16 * num_col_align_withStride_fp32 + 16 + mm1OutSize_];
|
||||
perTokenDescaleTensor = buf[OFFSET_WORKSPACE_BF16 * num_col_align_withStride_fp32 + 16 + mm1OutSize_ * 2];
|
||||
AscendC::DataCopy(deScaleTensor, perChannelDescaleGmTensor, AscendC::DataCopyParams(1, num_col_ / 8, 0, 0));
|
||||
}
|
||||
|
||||
@@ -528,6 +533,7 @@ private:
|
||||
uint32_t num_col_align_withStride_fp32{0};
|
||||
uint32_t num_col_temp;
|
||||
half quantMin_{-128};
|
||||
uint32_t mm1OutSize_{0};
|
||||
uint32_t num_slice_{0};
|
||||
uint32_t tail_size_{0};
|
||||
uint32_t tail_copy_{0};
|
||||
@@ -567,6 +573,7 @@ public:
|
||||
this->num_row_ = mlaParams_.n;
|
||||
this->row_work = row_work;
|
||||
this->row_work_ = row_work_;
|
||||
this->mm1OutSize_ = mlaParams_.mm1OutSize;
|
||||
gm_offset_ = gm_offset;
|
||||
gm_out_offset_ = gm_out_offset;
|
||||
num_col_align_int8 = (num_col_ + REPEAT_TIME_256 - 1) / REPEAT_TIME_256 * REPEAT_TIME_256;
|
||||
@@ -618,8 +625,8 @@ public:
|
||||
|
||||
if constexpr (NEED_DEQUANT) {
|
||||
mmTensor = buf.ReinterpretCast<int32_t>()[OFFSET_WORKSPACE_BF16 * num_col_align_withStride_fp32 + 16];
|
||||
deScaleTensor = buf[OFFSET_WORKSPACE_BF16 * num_col_align_withStride_fp32 + 16 + MM1_OUT_SIZE];
|
||||
perTokenDescaleTensor = buf[OFFSET_WORKSPACE_BF16 * num_col_align_withStride_fp32 + 16 + MM1_OUT_SIZE * 2];
|
||||
deScaleTensor = buf[OFFSET_WORKSPACE_BF16 * num_col_align_withStride_fp32 + 16 + mm1OutSize_];
|
||||
perTokenDescaleTensor = buf[OFFSET_WORKSPACE_BF16 * num_col_align_withStride_fp32 + 16 + mm1OutSize_ * 2];
|
||||
AscendC::DataCopy(deScaleTensor, perChannelDescaleGmTensor, AscendC::DataCopyParams(1, num_col_ / 8, 0, 0));
|
||||
}
|
||||
|
||||
@@ -834,6 +841,7 @@ private:
|
||||
uint32_t num_col_align_withStride_fp32{0};
|
||||
uint32_t num_col_temp;
|
||||
half quantMin_{-128};
|
||||
uint32_t mm1OutSize_{0};
|
||||
uint32_t num_slice_{0};
|
||||
uint32_t tail_size_{0};
|
||||
uint32_t tail_copy_{0};
|
||||
@@ -2387,6 +2395,15 @@ public:
|
||||
this->epsilon_ = 1e-6;
|
||||
this->mlaParams = mlaParams_;
|
||||
this->hiddenStateDim = mlaParams_.hiddenStateDim;
|
||||
this->mm1OutSize_ = mlaParams_.mm1OutSize;
|
||||
this->splitSizeOne_ = mlaParams_.splitSizeOne;
|
||||
this->splitSizeTwo_ = mlaParams_.splitSizeTwo;
|
||||
this->splitRmsNormSizeOne_ = mlaParams_.splitRmsNormSizeOne;
|
||||
this->splitRmsNormSizeTwo_ = mlaParams_.splitRmsNormSizeTwo;
|
||||
this->ropeSplitSizeOne_ = mlaParams_.ropeSplitSizeOne;
|
||||
this->ropeSplitSizeTwo_ = mlaParams_.ropeSplitSizeTwo;
|
||||
this->hiddenStrideRope_ = mlaParams_.hiddenStrideRope;
|
||||
this->qkNopeHeadDim_ = mlaParams_.qkNopeHeadDim;
|
||||
}
|
||||
|
||||
__aicore__ inline void Init(GM_ADDR hiddenStateGm, GM_ADDR quantScale1Gm,
|
||||
@@ -2471,15 +2488,15 @@ public:
|
||||
vectorBlockIdx * static_cast<uint64_t>(row_work) * num_col_1, row_work_, mlaParams);
|
||||
if constexpr (quantMode == QuantMode::PER_TENSOR_ASYMM_QUANT) {
|
||||
rmsNormQuant2.Init(gamma2GmTensor, beta2GmTensor, quantScale2GmTensor, quantOffset2GmTensor,
|
||||
s5Gm + row_work * vectorBlockIdx * sizeof(float), descale1Gm, s3Gm, s1Gm, SPLIT_SIZE_ONE,
|
||||
num_col_2, 0.000651041666, vectorBlockIdx * static_cast<uint64_t>(row_work) * num_col_2,
|
||||
vectorBlockIdx * static_cast<uint64_t>(row_work) * SPLIT_SIZE_TWO, row_work_, mlaParams);
|
||||
s5Gm + row_work * vectorBlockIdx * sizeof(float), descale1Gm, s3Gm, s1Gm, splitSizeOne_,
|
||||
num_col_2, mlaParams.avgFactor, vectorBlockIdx * static_cast<uint64_t>(row_work) * num_col_2,
|
||||
vectorBlockIdx * static_cast<uint64_t>(row_work) * splitSizeTwo_, row_work_, mlaParams);
|
||||
} else {
|
||||
// quantMode == QuantMode::PER_TOKEN_SYMM_QUANT
|
||||
rmsNormQuant2.Init(gamma2GmTensor, beta2GmTensor, quantScale2GmTensor, quantOffset2GmTensor,
|
||||
s5Gm + row_work * vectorBlockIdx * sizeof(float), descale1Gm, s2Gm, s1Gm, SPLIT_SIZE_ONE,
|
||||
num_col_2, 0.000651041666, vectorBlockIdx * static_cast<uint64_t>(row_work) * num_col_2,
|
||||
vectorBlockIdx * static_cast<uint64_t>(row_work) * SPLIT_SIZE_TWO, row_work_, mlaParams);
|
||||
s5Gm + row_work * vectorBlockIdx * sizeof(float), descale1Gm, s2Gm, s1Gm, splitSizeOne_,
|
||||
num_col_2, mlaParams.avgFactor, vectorBlockIdx * static_cast<uint64_t>(row_work) * num_col_2,
|
||||
vectorBlockIdx * static_cast<uint64_t>(row_work) * splitSizeTwo_, row_work_, mlaParams);
|
||||
}
|
||||
ropeFp16.RopeInit(s4Gm, cos2GmTensor, sin2GmTensor, qGmTensor, qGmTensor2, mlaParams);
|
||||
einSumQuant.Init(s1Gm, gmQnopeScale, qGm, mlaParams);
|
||||
@@ -2491,6 +2508,17 @@ public:
|
||||
__aicore__ inline void ProcessVector();
|
||||
|
||||
private:
|
||||
// Model-specific MLA dimensions from tiling data
|
||||
uint32_t mm1OutSize_;
|
||||
uint32_t splitSizeOne_;
|
||||
uint32_t splitSizeTwo_;
|
||||
uint32_t splitRmsNormSizeOne_;
|
||||
uint32_t splitRmsNormSizeTwo_;
|
||||
uint32_t ropeSplitSizeOne_;
|
||||
uint32_t ropeSplitSizeTwo_;
|
||||
uint32_t hiddenStrideRope_;
|
||||
uint32_t qkNopeHeadDim_;
|
||||
|
||||
constexpr static uint32_t C0_SIZE = 16;
|
||||
constexpr static uint32_t I8_C0_SIZE = 32;
|
||||
|
||||
@@ -2505,44 +2533,44 @@ private:
|
||||
AscendC::LocalTensor<half> &tmpfp16, AscendC::LocalTensor<int8_t> &int8OutTensor, float quantScale3)
|
||||
{
|
||||
int64_t slotMapGmOffset = vectorBlockIdx * row_work;
|
||||
AscendC::DataCopy(gammaTensor, gamma3GmTensor, SPLIT_RMSNRORM_SIZE_ONE);
|
||||
AscendC::DataCopy(gammaTensor, gamma3GmTensor, splitRmsNormSizeOne_);
|
||||
SET_FLAG(MTE2, V, EVENT_ID1);
|
||||
WAIT_FLAG(MTE2, V, EVENT_ID1);
|
||||
Cast(gammaFp32, gammaTensor, AscendC::RoundMode::CAST_NONE, SPLIT_RMSNRORM_SIZE_ONE);
|
||||
Cast(gammaFp32, gammaTensor, AscendC::RoundMode::CAST_NONE, splitRmsNormSizeOne_);
|
||||
AscendC::DataCopyPad(slotMappingTensor, slotMappingGmTensor[slotMapGmOffset],
|
||||
AscendC::DataCopyExtParams(1, sN * sizeof(int32_t), 0, 0, 0),
|
||||
AscendC::DataCopyPadExtParams<int32_t>(false, 0, 8 - sN % 8, 0));
|
||||
if constexpr (quantMode == QuantMode::PER_TOKEN_SYMM_QUANT) {
|
||||
mmTensor = calTensor.ReinterpretCast<int32_t>()[SPLIT_SIZE_ONE];
|
||||
deScaleTensor = calTensor.ReinterpretCast<float>()[SPLIT_SIZE_ONE * 2];
|
||||
AscendC::DataCopy(deScaleTensor, descale1gmTensor, AscendC::DataCopyParams(1, SPLIT_SIZE_ONE / 8, 0, 0));
|
||||
mmTensor = calTensor.ReinterpretCast<int32_t>()[splitSizeOne_];
|
||||
deScaleTensor = calTensor.ReinterpretCast<float>()[splitSizeOne_ * 2];
|
||||
AscendC::DataCopy(deScaleTensor, descale1gmTensor, AscendC::DataCopyParams(1, splitSizeOne_ / 8, 0, 0));
|
||||
}
|
||||
SET_FLAG(MTE2, V, EVENT_ID2);
|
||||
WAIT_FLAG(MTE2, V, EVENT_ID2);
|
||||
SET_FLAG(MTE2, S, EVENT_ID2);
|
||||
WAIT_FLAG(MTE2, S, EVENT_ID2);
|
||||
for (uint64_t loop = 0; loop < sN; ++loop) {
|
||||
uint64_t offset = vectorBlockIdx * static_cast<uint64_t>(row_work) * num_col_2 + loop * MM1_OUT_SIZE;
|
||||
uint64_t offset = vectorBlockIdx * static_cast<uint64_t>(row_work) * num_col_2 + loop * mm1OutSize_;
|
||||
int64_t slotValue = static_cast<int64_t>(slotMappingTensor.GetValue(loop));
|
||||
if (slotValue == -1) {
|
||||
continue;
|
||||
}
|
||||
if constexpr (quantMode == QuantMode::PER_TENSOR_ASYMM_QUANT) {
|
||||
AscendC::DataCopy(srcTensor, s3GmTensor[offset],
|
||||
AscendC::DataCopyParams(1, MM1_OUT_SIZE / BLOCK_SIZE_16, 0, 0));
|
||||
AscendC::DataCopyParams(1, mm1OutSize_ / BLOCK_SIZE_16, 0, 0));
|
||||
} else {
|
||||
// quantMode == QuantMode::PER_TOKEN_SYMM_QUANT
|
||||
AscendC::DataCopy(mmTensor, s2GmTensor[offset], AscendC::DataCopyParams(1, SPLIT_SIZE_ONE / 8, 0, 0));
|
||||
AscendC::DataCopy(mmTensor, s2GmTensor[offset], AscendC::DataCopyParams(1, splitSizeOne_ / 8, 0, 0));
|
||||
}
|
||||
AscendC::DataCopy(sinTensor, sin1GmTensor[(row_work * vectorBlockIdx + loop) * SPLIT_RMSNRORM_SIZE_TWO],
|
||||
SPLIT_RMSNRORM_SIZE_TWO);
|
||||
AscendC::DataCopy(cosTensor, cos1GmTensor[(row_work * vectorBlockIdx + loop) * SPLIT_RMSNRORM_SIZE_TWO],
|
||||
SPLIT_RMSNRORM_SIZE_TWO);
|
||||
AscendC::DataCopy(sinTensor, sin1GmTensor[(row_work * vectorBlockIdx + loop) * splitRmsNormSizeTwo_],
|
||||
splitRmsNormSizeTwo_);
|
||||
AscendC::DataCopy(cosTensor, cos1GmTensor[(row_work * vectorBlockIdx + loop) * splitRmsNormSizeTwo_],
|
||||
splitRmsNormSizeTwo_);
|
||||
SET_FLAG(MTE2, V, EVENT_ID0);
|
||||
// ND
|
||||
uint64_t cacheStart = static_cast<uint64_t>(slotValue) * static_cast<uint64_t>(SPLIT_SIZE_ONE);
|
||||
uint64_t cacheStart1 = static_cast<uint64_t>(slotValue) * static_cast<uint64_t>(SPLIT_RMSNRORM_SIZE_ONE);
|
||||
uint64_t cacheStart2 = static_cast<uint64_t>(slotValue) * static_cast<uint64_t>(SPLIT_RMSNRORM_SIZE_TWO);
|
||||
uint64_t cacheStart = static_cast<uint64_t>(slotValue) * static_cast<uint64_t>(splitSizeOne_);
|
||||
uint64_t cacheStart1 = static_cast<uint64_t>(slotValue) * static_cast<uint64_t>(splitRmsNormSizeOne_);
|
||||
uint64_t cacheStart2 = static_cast<uint64_t>(slotValue) * static_cast<uint64_t>(splitRmsNormSizeTwo_);
|
||||
// NZ
|
||||
uint32_t outer_idx = slotValue / 128;
|
||||
uint32_t inner_idx = slotValue % 128;
|
||||
@@ -2553,84 +2581,84 @@ private:
|
||||
if constexpr (quantMode == QuantMode::PER_TOKEN_SYMM_QUANT) {
|
||||
/* DeQuant */
|
||||
AscendC::Cast(mmTensor.ReinterpretCast<float>(), mmTensor, AscendC::RoundMode::CAST_NONE,
|
||||
SPLIT_SIZE_ONE);
|
||||
splitSizeOne_);
|
||||
AscendC::PipeBarrier<PIPE_V>();
|
||||
AscendC::Mul(mmTensor.ReinterpretCast<float>(), mmTensor.ReinterpretCast<float>(), deScaleTensor,
|
||||
SPLIT_SIZE_ONE);
|
||||
splitSizeOne_);
|
||||
AscendC::PipeBarrier<PIPE_V>();
|
||||
float perTokenDescale = s5GmTensor.GetValue(row_work * vectorBlockIdx + loop);
|
||||
SET_FLAG(S, V, EVENT_ID0);
|
||||
WAIT_FLAG(S, V, EVENT_ID0);
|
||||
AscendC::Muls(mmTensor.ReinterpretCast<float>(), mmTensor.ReinterpretCast<float>(), perTokenDescale,
|
||||
SPLIT_SIZE_ONE);
|
||||
splitSizeOne_);
|
||||
AscendC::PipeBarrier<PIPE_V>();
|
||||
AscendC::Cast(srcTensor, mmTensor.ReinterpretCast<float>(), AscendC::RoundMode::CAST_RINT,
|
||||
SPLIT_SIZE_ONE);
|
||||
splitSizeOne_);
|
||||
AscendC::PipeBarrier<PIPE_V>();
|
||||
}
|
||||
Cast(rmsNormTensor, srcTensor, AscendC::RoundMode::CAST_NONE, SPLIT_RMSNRORM_SIZE_ONE);
|
||||
Cast(rmsNormTensor, srcTensor, AscendC::RoundMode::CAST_NONE, splitRmsNormSizeOne_);
|
||||
AscendC::PipeBarrier<PIPE_V>();
|
||||
Mul(calTensor, rmsNormTensor, rmsNormTensor, SPLIT_RMSNRORM_SIZE_ONE);
|
||||
Mul(calTensor, rmsNormTensor, rmsNormTensor, splitRmsNormSizeOne_);
|
||||
AscendC::PipeBarrier<PIPE_V>();
|
||||
ReduceSumCustom(calTensor[SPLIT_RMSNRORM_SIZE_ONE], calTensor, calTensor[SPLIT_RMSNRORM_SIZE_ONE * 2],
|
||||
SPLIT_RMSNRORM_SIZE_ONE);
|
||||
ReduceSumCustom(calTensor[splitRmsNormSizeOne_], calTensor, calTensor[splitRmsNormSizeOne_ * 2],
|
||||
splitRmsNormSizeOne_);
|
||||
SET_FLAG(V, S, EVENT_ID1);
|
||||
WAIT_FLAG(V, S, EVENT_ID1);
|
||||
float rms = sqrt(calTensor.GetValue(SPLIT_RMSNRORM_SIZE_ONE) / SPLIT_RMSNRORM_SIZE_ONE + epsilon_);
|
||||
float rms = sqrt(calTensor.GetValue(splitRmsNormSizeOne_) / splitRmsNormSizeOne_ + epsilon_);
|
||||
SET_FLAG(S, V, EVENT_ID1);
|
||||
WAIT_FLAG(S, V, EVENT_ID1);
|
||||
AscendC::PipeBarrier<PIPE_V>();
|
||||
Duplicate(calTensor, rms, SPLIT_RMSNRORM_SIZE_ONE);
|
||||
Duplicate(calTensor, rms, splitRmsNormSizeOne_);
|
||||
AscendC::PipeBarrier<PIPE_V>();
|
||||
Div(calTensor, rmsNormTensor, calTensor, SPLIT_RMSNRORM_SIZE_ONE);
|
||||
Div(calTensor, rmsNormTensor, calTensor, splitRmsNormSizeOne_);
|
||||
AscendC::PipeBarrier<PIPE_V>();
|
||||
Mul(rmsNormTensor, gammaFp32, calTensor, SPLIT_RMSNRORM_SIZE_ONE);
|
||||
Mul(rmsNormTensor, gammaFp32, calTensor, splitRmsNormSizeOne_);
|
||||
AscendC::PipeBarrier<PIPE_V>();
|
||||
if constexpr (CACHE_MODE == CACHE_MODE_INT8_NZCACHE) {
|
||||
// quant
|
||||
Muls(rmsNormTensor, rmsNormTensor, quantScale3, SPLIT_RMSNRORM_SIZE_ONE);
|
||||
Muls(rmsNormTensor, rmsNormTensor, quantScale3, splitRmsNormSizeOne_);
|
||||
AscendC::PipeBarrier<PIPE_V>();
|
||||
CastFrom32To16(tmpfp16, rmsNormTensor, SPLIT_RMSNRORM_SIZE_ONE);
|
||||
CastFrom32To16(tmpfp16, rmsNormTensor, splitRmsNormSizeOne_);
|
||||
AscendC::PipeBarrier<PIPE_V>();
|
||||
CastFromF16ToI8(int8OutTensor, tmpfp16, -128, SPLIT_RMSNRORM_SIZE_ONE);
|
||||
CastFromF16ToI8(int8OutTensor, tmpfp16, -128, splitRmsNormSizeOne_);
|
||||
AscendC::PipeBarrier<PIPE_V>();
|
||||
} else {
|
||||
AscendC::PipeBarrier<PIPE_V>();
|
||||
if (std::is_same<T1, __bf16>::value) {
|
||||
Cast(outTmpTensor, rmsNormTensor, AscendC::RoundMode::CAST_RINT, SPLIT_RMSNRORM_SIZE_ONE);
|
||||
Cast(outTmpTensor, rmsNormTensor, AscendC::RoundMode::CAST_RINT, splitRmsNormSizeOne_);
|
||||
} else {
|
||||
Cast(outTmpTensor, rmsNormTensor, AscendC::RoundMode::CAST_NONE, SPLIT_RMSNRORM_SIZE_ONE);
|
||||
Cast(outTmpTensor, rmsNormTensor, AscendC::RoundMode::CAST_NONE, splitRmsNormSizeOne_);
|
||||
}
|
||||
}
|
||||
/* RmsNorm end */
|
||||
/* Rope K start */
|
||||
uint64_t revertOffset = SPLIT_RMSNRORM_SIZE_TWO / 2;
|
||||
Cast(ropeKTensor, srcTensor[SPLIT_RMSNRORM_SIZE_ONE], AscendC::RoundMode::CAST_NONE,
|
||||
SPLIT_RMSNRORM_SIZE_TWO);
|
||||
Cast(ropeKRevertTensor[revertOffset], srcTensor[SPLIT_RMSNRORM_SIZE_ONE], AscendC::RoundMode::CAST_NONE,
|
||||
uint64_t revertOffset = splitRmsNormSizeTwo_ / 2;
|
||||
Cast(ropeKTensor, srcTensor[splitRmsNormSizeOne_], AscendC::RoundMode::CAST_NONE,
|
||||
splitRmsNormSizeTwo_);
|
||||
Cast(ropeKRevertTensor[revertOffset], srcTensor[splitRmsNormSizeOne_], AscendC::RoundMode::CAST_NONE,
|
||||
revertOffset);
|
||||
Cast(ropeKRevertTensor, srcTensor[SPLIT_RMSNRORM_SIZE_ONE + revertOffset], AscendC::RoundMode::CAST_NONE,
|
||||
Cast(ropeKRevertTensor, srcTensor[splitRmsNormSizeOne_ + revertOffset], AscendC::RoundMode::CAST_NONE,
|
||||
revertOffset);
|
||||
Duplicate(calTensor, static_cast<float>(-1), revertOffset);
|
||||
Duplicate(calTensor[revertOffset], static_cast<float>(1), revertOffset);
|
||||
AscendC::PipeBarrier<PIPE_V>();
|
||||
Cast(calTensor[SPLIT_RMSNRORM_SIZE_TWO], cosTensor, AscendC::RoundMode::CAST_NONE, SPLIT_RMSNRORM_SIZE_TWO);
|
||||
Cast(calTensor[SPLIT_RMSNRORM_SIZE_TWO * 2], sinTensor, AscendC::RoundMode::CAST_NONE,
|
||||
SPLIT_RMSNRORM_SIZE_TWO);
|
||||
Cast(calTensor[splitRmsNormSizeTwo_], cosTensor, AscendC::RoundMode::CAST_NONE, splitRmsNormSizeTwo_);
|
||||
Cast(calTensor[splitRmsNormSizeTwo_ * 2], sinTensor, AscendC::RoundMode::CAST_NONE,
|
||||
splitRmsNormSizeTwo_);
|
||||
AscendC::PipeBarrier<PIPE_V>();
|
||||
Mul(ropeKTensor, calTensor[SPLIT_RMSNRORM_SIZE_TWO], ropeKTensor, SPLIT_RMSNRORM_SIZE_TWO);
|
||||
Mul(ropeKRevertTensor, calTensor[SPLIT_RMSNRORM_SIZE_TWO * 2], ropeKRevertTensor, SPLIT_RMSNRORM_SIZE_TWO);
|
||||
Mul(ropeKTensor, calTensor[splitRmsNormSizeTwo_], ropeKTensor, splitRmsNormSizeTwo_);
|
||||
Mul(ropeKRevertTensor, calTensor[splitRmsNormSizeTwo_ * 2], ropeKRevertTensor, splitRmsNormSizeTwo_);
|
||||
AscendC::PipeBarrier<PIPE_V>();
|
||||
Mul(ropeKRevertTensor, calTensor, ropeKRevertTensor, SPLIT_RMSNRORM_SIZE_TWO);
|
||||
Mul(ropeKRevertTensor, calTensor, ropeKRevertTensor, splitRmsNormSizeTwo_);
|
||||
AscendC::PipeBarrier<PIPE_V>();
|
||||
Add(ropeKRevertTensor, ropeKTensor, ropeKRevertTensor, SPLIT_RMSNRORM_SIZE_TWO);
|
||||
Add(ropeKRevertTensor, ropeKTensor, ropeKRevertTensor, splitRmsNormSizeTwo_);
|
||||
AscendC::PipeBarrier<PIPE_V>();
|
||||
if (std::is_same<T1, __bf16>::value) {
|
||||
Cast(outTmpTensor[SPLIT_RMSNRORM_SIZE_ONE], ropeKRevertTensor, AscendC::RoundMode::CAST_RINT,
|
||||
SPLIT_RMSNRORM_SIZE_TWO);
|
||||
Cast(outTmpTensor[splitRmsNormSizeOne_], ropeKRevertTensor, AscendC::RoundMode::CAST_RINT,
|
||||
splitRmsNormSizeTwo_);
|
||||
} else {
|
||||
Cast(outTmpTensor[SPLIT_RMSNRORM_SIZE_ONE], ropeKRevertTensor, AscendC::RoundMode::CAST_NONE,
|
||||
SPLIT_RMSNRORM_SIZE_TWO);
|
||||
Cast(outTmpTensor[splitRmsNormSizeOne_], ropeKRevertTensor, AscendC::RoundMode::CAST_NONE,
|
||||
splitRmsNormSizeTwo_);
|
||||
}
|
||||
AscendC::PipeBarrier<PIPE_V>();
|
||||
/* Rope K end */
|
||||
@@ -2638,45 +2666,45 @@ private:
|
||||
WAIT_FLAG(V, MTE3, EVENT_ID0);
|
||||
WAIT_FLAG(S, MTE3, EVENT_ID0);
|
||||
if constexpr (CACHE_MODE == CACHE_MODE_KVCACHE) {
|
||||
DataCopy(keycacheGmTensor1[cacheStart], outTmpTensor, SPLIT_SIZE_ONE);
|
||||
DataCopy(keycacheGmTensor1[cacheStart], outTmpTensor, splitSizeOne_);
|
||||
} else if constexpr (CACHE_MODE == CACHE_MODE_INT8_NZCACHE) {
|
||||
uint64_t cacheSatartI8Nz1 = outer_idx * 128 * 512 + inner_idx * I8_C0_SIZE;
|
||||
uint64_t cacheSatartNz2 = outer_idx * 128 * 64 + inner_idx * C0_SIZE;
|
||||
// nope:int8 nz
|
||||
AscendC::DataCopyExtParams outExt;
|
||||
outExt.blockCount = SPLIT_RMSNRORM_SIZE_ONE / I8_C0_SIZE;
|
||||
outExt.blockCount = splitRmsNormSizeOne_ / I8_C0_SIZE;
|
||||
outExt.blockLen = I8_C0_SIZE * sizeof(int8_t);
|
||||
outExt.srcStride = 0;
|
||||
outExt.dstStride = (128 * I8_C0_SIZE - I8_C0_SIZE) * sizeof(int8_t);
|
||||
DataCopyPad(keycacheGmTensor1[cacheSatartI8Nz1], int8OutTensor, outExt);
|
||||
// rope:T1 nz
|
||||
outExt.blockCount = SPLIT_RMSNRORM_SIZE_TWO / C0_SIZE;
|
||||
outExt.blockCount = splitRmsNormSizeTwo_ / C0_SIZE;
|
||||
outExt.blockLen = C0_SIZE * sizeof(T1);
|
||||
outExt.srcStride = 0;
|
||||
outExt.dstStride = (128 * C0_SIZE - C0_SIZE) * sizeof(T1);
|
||||
DataCopyPad(keycacheGmTensor2[cacheSatartNz2], outTmpTensor[SPLIT_RMSNRORM_SIZE_ONE], outExt);
|
||||
DataCopyPad(keycacheGmTensor2[cacheSatartNz2], outTmpTensor[splitRmsNormSizeOne_], outExt);
|
||||
} else if constexpr (CACHE_MODE == CACHE_MODE_NZCACHE) {
|
||||
uint64_t cacheSatartNz1 = outer_idx * 128 * 512 + inner_idx * C0_SIZE;
|
||||
uint64_t cacheSatartNz2 = outer_idx * 128 * 64 + inner_idx * C0_SIZE;
|
||||
// nope:T1 nz
|
||||
AscendC::DataCopyExtParams outExt;
|
||||
outExt.blockCount = SPLIT_RMSNRORM_SIZE_ONE / C0_SIZE;
|
||||
outExt.blockCount = splitRmsNormSizeOne_ / C0_SIZE;
|
||||
outExt.blockLen = C0_SIZE * sizeof(T1);
|
||||
outExt.srcStride = 0;
|
||||
outExt.dstStride = (128 * C0_SIZE - C0_SIZE) * sizeof(T1);
|
||||
DataCopyPad(keycacheGmTensor1[cacheSatartNz1], outTmpTensor, outExt);
|
||||
// rope:T1 nz
|
||||
outExt.blockCount = SPLIT_RMSNRORM_SIZE_TWO / C0_SIZE;
|
||||
outExt.blockCount = splitRmsNormSizeTwo_ / C0_SIZE;
|
||||
outExt.blockLen = C0_SIZE * sizeof(T1);
|
||||
outExt.srcStride = 0;
|
||||
outExt.dstStride = (128 * C0_SIZE - C0_SIZE) * sizeof(T1);
|
||||
DataCopyPad(keycacheGmTensor2[cacheSatartNz2], outTmpTensor[SPLIT_RMSNRORM_SIZE_ONE], outExt);
|
||||
DataCopyPad(keycacheGmTensor2[cacheSatartNz2], outTmpTensor[splitRmsNormSizeOne_], outExt);
|
||||
} else {
|
||||
// keycache1
|
||||
DataCopy(keycacheGmTensor1[cacheStart1], outTmpTensor, SPLIT_RMSNRORM_SIZE_ONE);
|
||||
DataCopy(keycacheGmTensor1[cacheStart1], outTmpTensor, splitRmsNormSizeOne_);
|
||||
// keycache2
|
||||
DataCopy(keycacheGmTensor2[cacheStart2], outTmpTensor[SPLIT_RMSNRORM_SIZE_ONE],
|
||||
SPLIT_RMSNRORM_SIZE_TWO);
|
||||
DataCopy(keycacheGmTensor2[cacheStart2], outTmpTensor[splitRmsNormSizeOne_],
|
||||
splitRmsNormSizeTwo_);
|
||||
}
|
||||
SET_FLAG(MTE3, MTE2, EVENT_ID1);
|
||||
WAIT_FLAG(MTE3, MTE2, EVENT_ID1);
|
||||
@@ -2823,20 +2851,20 @@ MLAOperation<InDtype, CACHE_MODE, weightFormat1, weightFormat2, weightFormat3, q
|
||||
uint32_t num_col_align_f16 = (num_col_2 + REPEAT_TIME_128 - 1) / REPEAT_TIME_128 * REPEAT_TIME_128;
|
||||
uint32_t num_col_align_f32 = (num_col_2 + REPEAT_TIME_64 - 1) / REPEAT_TIME_64 * REPEAT_TIME_64;
|
||||
AscendC::LocalTensor<InDtype> input_tensor = buf.GetBuffer<BufferType::ASCEND_UB, InDtype>(0);
|
||||
AscendC::LocalTensor<InDtype> gamma_tensor = buf.GetBuffer<BufferType::ASCEND_UB, InDtype>(MM1_OUT_SIZE * 2);
|
||||
AscendC::LocalTensor<InDtype> gamma_tensor = buf.GetBuffer<BufferType::ASCEND_UB, InDtype>(mm1OutSize_ * 2);
|
||||
AscendC::LocalTensor<InDtype> beta_tensor =
|
||||
buf.GetBuffer<BufferType::ASCEND_UB, InDtype>(MM1_OUT_SIZE * 2 + SPLIT_SIZE_TWO * 2);
|
||||
buf.GetBuffer<BufferType::ASCEND_UB, InDtype>(mm1OutSize_ * 2 + splitSizeTwo_ * 2);
|
||||
AscendC::LocalTensor<InDtype> scale_tensor =
|
||||
buf.GetBuffer<BufferType::ASCEND_UB, InDtype>(MM1_OUT_SIZE * 2 + SPLIT_SIZE_TWO * 2 + SPLIT_SIZE_TWO * 2);
|
||||
buf.GetBuffer<BufferType::ASCEND_UB, InDtype>(mm1OutSize_ * 2 + splitSizeTwo_ * 2 + splitSizeTwo_ * 2);
|
||||
AscendC::LocalTensor<int8_t> offset_tensor = buf.GetBuffer<BufferType::ASCEND_UB, int8_t>(
|
||||
MM1_OUT_SIZE * 2 + SPLIT_SIZE_TWO * 2 + SPLIT_SIZE_TWO * 2 + 32);
|
||||
mm1OutSize_ * 2 + splitSizeTwo_ * 2 + splitSizeTwo_ * 2 + 32);
|
||||
AscendC::LocalTensor<float> res1_tensor = buf.GetBuffer<BufferType::ASCEND_UB, float>(
|
||||
MM1_OUT_SIZE * 2 + SPLIT_SIZE_TWO * 2 + SPLIT_SIZE_TWO * 2 + 64);
|
||||
mm1OutSize_ * 2 + splitSizeTwo_ * 2 + splitSizeTwo_ * 2 + 64);
|
||||
AscendC::LocalTensor<float> res3_tensor = buf.GetBuffer<BufferType::ASCEND_UB, float>(
|
||||
MM1_OUT_SIZE * 2 + SPLIT_SIZE_TWO * 2 + SPLIT_SIZE_TWO * 2 + 64 + num_col_align_f32 * 4);
|
||||
mm1OutSize_ * 2 + splitSizeTwo_ * 2 + splitSizeTwo_ * 2 + 64 + num_col_align_f32 * 4);
|
||||
AscendC::LocalTensor<int8_t> output_tensor = buf.GetBuffer<BufferType::ASCEND_UB, int8_t>(
|
||||
MM1_OUT_SIZE * 2 + SPLIT_SIZE_TWO * 2 + SPLIT_SIZE_TWO * 2 + 64 + num_col_align_f32 * 4 +
|
||||
BUF_FACTOR * num_col_align_f32 * 4 + 64 + MM1_OUT_SIZE * 4 * 2 + 32);
|
||||
mm1OutSize_ * 2 + splitSizeTwo_ * 2 + splitSizeTwo_ * 2 + 64 + num_col_align_f32 * 4 +
|
||||
BUF_FACTOR * num_col_align_f32 * 4 + 64 + mm1OutSize_ * 4 * 2 + 32);
|
||||
rmsNormQuant2.Launch(output_tensor, input_tensor, gamma_tensor, beta_tensor, scale_tensor, offset_tensor,
|
||||
res1_tensor, res3_tensor);
|
||||
}
|
||||
@@ -2846,20 +2874,20 @@ MLAOperation<InDtype, CACHE_MODE, weightFormat1, weightFormat2, weightFormat3, q
|
||||
|
||||
if (row_work_ != 0) {
|
||||
AscendC::LocalTensor<InDtype> input_tensor = buf.GetBuffer<BufferType::ASCEND_UB, InDtype>(0);
|
||||
AscendC::LocalTensor<InDtype> gamma_tensor = buf.GetBuffer<BufferType::ASCEND_UB, InDtype>(MM1_OUT_SIZE * 2);
|
||||
AscendC::LocalTensor<InDtype> gamma_tensor = buf.GetBuffer<BufferType::ASCEND_UB, InDtype>(mm1OutSize_ * 2);
|
||||
AscendC::LocalTensor<InDtype> sin_tensor =
|
||||
buf.GetBuffer<BufferType::ASCEND_UB, InDtype>(MM1_OUT_SIZE * 2 + SPLIT_RMSNRORM_SIZE_ONE * 2);
|
||||
buf.GetBuffer<BufferType::ASCEND_UB, InDtype>(mm1OutSize_ * 2 + splitRmsNormSizeOne_ * 2);
|
||||
AscendC::LocalTensor<InDtype> cos_tensor = buf.GetBuffer<BufferType::ASCEND_UB, InDtype>(
|
||||
MM1_OUT_SIZE * 2 + SPLIT_RMSNRORM_SIZE_ONE * 2 + SPLIT_RMSNRORM_SIZE_TWO * 2);
|
||||
mm1OutSize_ * 2 + splitRmsNormSizeOne_ * 2 + splitRmsNormSizeTwo_ * 2);
|
||||
AscendC::LocalTensor<int32_t> slotMapping_tensor = buf.GetBuffer<BufferType::ASCEND_UB, int32_t>(
|
||||
MM1_OUT_SIZE * 2 + SPLIT_RMSNRORM_SIZE_ONE * 2 + SPLIT_RMSNRORM_SIZE_TWO * 4);
|
||||
mm1OutSize_ * 2 + splitRmsNormSizeOne_ * 2 + splitRmsNormSizeTwo_ * 4);
|
||||
int32_t rms3_ub_offset =
|
||||
MM1_OUT_SIZE * 2 + SPLIT_RMSNRORM_SIZE_ONE * 2 + SPLIT_RMSNRORM_SIZE_TWO * 4 + 4096 * 32;
|
||||
mm1OutSize_ * 2 + splitRmsNormSizeOne_ * 2 + splitRmsNormSizeTwo_ * 4 + 4096 * 32;
|
||||
AscendC::LocalTensor<float> tmp32_tensor = buf.GetBuffer<BufferType::ASCEND_UB, float>(rms3_ub_offset);
|
||||
|
||||
int32_t out_ub_offset = MM1_OUT_SIZE * 2 + SPLIT_RMSNRORM_SIZE_ONE * 2 + SPLIT_RMSNRORM_SIZE_TWO * 4 +
|
||||
4096 * 32 + SPLIT_RMSNRORM_SIZE_ONE * 3 * 4 + SPLIT_RMSNRORM_SIZE_TWO * 2 * 4 +
|
||||
MM1_OUT_SIZE * 4 * 2 + 32;
|
||||
int32_t out_ub_offset = mm1OutSize_ * 2 + splitRmsNormSizeOne_ * 2 + splitRmsNormSizeTwo_ * 4 +
|
||||
4096 * 32 + splitRmsNormSizeOne_ * 3 * 4 + splitRmsNormSizeTwo_ * 2 * 4 +
|
||||
mm1OutSize_ * 4 * 2 + 32;
|
||||
AscendC::LocalTensor<InDtype> temp_tensor = buf.GetBuffer<BufferType::ASCEND_UB, InDtype>(out_ub_offset);
|
||||
|
||||
AscendC::LocalTensor<half> tmpfp16;
|
||||
@@ -2873,7 +2901,7 @@ MLAOperation<InDtype, CACHE_MODE, weightFormat1, weightFormat2, weightFormat3, q
|
||||
buf.GetBuffer<BufferType::ASCEND_UB, float>(rms3_ub_offset + 32);
|
||||
// int8out
|
||||
tmpfp16 = buf.GetBuffer<BufferType::ASCEND_UB, half>(rms3_ub_offset +
|
||||
SPLIT_RMSNRORM_SIZE_ONE * sizeof(float) * 2);
|
||||
splitRmsNormSizeOne_ * sizeof(float) * 2);
|
||||
int8OutTensor = buf.GetBuffer<BufferType::ASCEND_UB, int8_t>(out_ub_offset);
|
||||
AscendC::DataCopy(quantScaleTensor, quantScale3GmTensor, AscendC::DataCopyParams(1, 1, 0, 0));
|
||||
SET_FLAG(MTE2, V, EVENT_ID1);
|
||||
@@ -2890,11 +2918,11 @@ MLAOperation<InDtype, CACHE_MODE, weightFormat1, weightFormat2, weightFormat3, q
|
||||
sin_tensor, // sin
|
||||
cos_tensor, // cons
|
||||
slotMapping_tensor, // slotMapping
|
||||
row_work_, tmp32_tensor, tmp32_tensor[SPLIT_RMSNRORM_SIZE_ONE],
|
||||
tmp32_tensor[SPLIT_RMSNRORM_SIZE_ONE + SPLIT_RMSNRORM_SIZE_ONE],
|
||||
tmp32_tensor[SPLIT_RMSNRORM_SIZE_ONE + SPLIT_RMSNRORM_SIZE_ONE + SPLIT_RMSNRORM_SIZE_TWO],
|
||||
tmp32_tensor[SPLIT_RMSNRORM_SIZE_ONE + SPLIT_RMSNRORM_SIZE_ONE + SPLIT_RMSNRORM_SIZE_TWO +
|
||||
SPLIT_RMSNRORM_SIZE_TWO],
|
||||
row_work_, tmp32_tensor, tmp32_tensor[splitRmsNormSizeOne_],
|
||||
tmp32_tensor[splitRmsNormSizeOne_ + splitRmsNormSizeOne_],
|
||||
tmp32_tensor[splitRmsNormSizeOne_ + splitRmsNormSizeOne_ + splitRmsNormSizeTwo_],
|
||||
tmp32_tensor[splitRmsNormSizeOne_ + splitRmsNormSizeOne_ + splitRmsNormSizeTwo_ +
|
||||
splitRmsNormSizeTwo_],
|
||||
temp_tensor, tmpfp16, int8OutTensor, scale3);
|
||||
}
|
||||
mm_w8a8_aiv_2.Process();
|
||||
|
||||
@@ -54,6 +54,8 @@ public:
|
||||
lastCoreLoopTime = ropeConcatParams.lastCoreLoopTime;
|
||||
lastCoreLoopNLast = ropeConcatParams.lastCoreLoopNLast;
|
||||
concatSize = ropeConcatParams.concatSize;
|
||||
hiddenStrideRope_ = ropeConcatParams.hiddenStrideRope;
|
||||
qkNopeHeadDim_ = ropeConcatParams.qkNopeHeadDim;
|
||||
blockIdx_ = (blockIdx_ / 2) * 2 + static_cast<uint64_t>(GetSubBlockidx());
|
||||
loopTime = (blockIdx_ == realCore - 1) ? lastCoreLoopTime : preCoreLoopTime;
|
||||
lastLoopN = (blockIdx_ == realCore - 1) ? lastCoreLoopNLast : preCoreLoopNLast;
|
||||
@@ -92,7 +94,7 @@ public:
|
||||
AscendC::LocalTensor<float> inputQCastFP32 = buf.GetBuffer<BufferType::ASCEND_UB, float>(dataSizeFp16);
|
||||
AscendC::LocalTensor<float> reverseQ =
|
||||
buf.GetBuffer<BufferType::ASCEND_UB, float>(dataSizeFp32 + dataSizeFp16);
|
||||
uint64_t qOffset = startHead * 192 + 128;
|
||||
uint64_t qOffset = startHead * hiddenStrideRope_ + qkNopeHeadDim_;
|
||||
CopyQGenReverseQ(inputQ, inputQCastFP32, reverseQ, qOffset, loopN);
|
||||
|
||||
// move in cos/sin
|
||||
@@ -178,7 +180,7 @@ public:
|
||||
{
|
||||
// move in Q
|
||||
WAIT_FLAG(MTE3, MTE2, EVENT_ID1);
|
||||
AscendC::DataCopy(tempBufQ, this->qGm_[qOffset], {loopN, headBlockLen, 128 / 16, 0});
|
||||
AscendC::DataCopy(tempBufQ, this->qGm_[qOffset], {loopN, headBlockLen, static_cast<uint16_t>(qkNopeHeadDim_ / 16), 0});
|
||||
SET_FLAG(MTE2, V, EVENT_ID1);
|
||||
WAIT_FLAG(MTE2, V, EVENT_ID1);
|
||||
// cast fp32
|
||||
@@ -230,6 +232,8 @@ private:
|
||||
uint32_t lastCoreLoopTime;
|
||||
uint32_t lastCoreLoopNLast;
|
||||
uint32_t concatSize;
|
||||
uint32_t hiddenStrideRope_;
|
||||
uint32_t qkNopeHeadDim_;
|
||||
uint32_t blockIdx_;
|
||||
uint32_t loopTime{0};
|
||||
uint32_t lastLoopN{0};
|
||||
@@ -936,6 +940,16 @@ public:
|
||||
this->num_row = mlaParams_.n;
|
||||
this->epsilon_ = 1e-6;
|
||||
this->mlaParams = mlaParams_;
|
||||
this->hiddenStateDim = mlaParams_.hiddenStateDim;
|
||||
this->mm1OutSize_ = mlaParams_.mm1OutSize;
|
||||
this->splitSizeOne_ = mlaParams_.splitSizeOne;
|
||||
this->splitSizeTwo_ = mlaParams_.splitSizeTwo;
|
||||
this->splitRmsNormSizeOne_ = mlaParams_.splitRmsNormSizeOne;
|
||||
this->splitRmsNormSizeTwo_ = mlaParams_.splitRmsNormSizeTwo;
|
||||
this->ropeSplitSizeOne_ = mlaParams_.ropeSplitSizeOne;
|
||||
this->ropeSplitSizeTwo_ = mlaParams_.ropeSplitSizeTwo;
|
||||
this->hiddenStrideRope_ = mlaParams_.hiddenStrideRope;
|
||||
this->qkNopeHeadDim_ = mlaParams_.qkNopeHeadDim;
|
||||
}
|
||||
|
||||
__aicore__ inline void Init(GM_ADDR hiddenStateGm, GM_ADDR wdqkvGm, GM_ADDR gamma2Gm,
|
||||
@@ -984,9 +998,9 @@ public:
|
||||
row_work_ = 0;
|
||||
}
|
||||
this->splitN = mlaParams.perTaskNum;
|
||||
rmsNormQuant2.Init(gamma2GmTensor, beta2GmTensor, s3Gm, s1Gm, SPLIT_SIZE_ONE,
|
||||
num_col_2, 0.000651041666, vectorBlockIdx * static_cast<uint64_t>(row_work) * num_col_2,
|
||||
vectorBlockIdx * static_cast<uint64_t>(row_work) * SPLIT_SIZE_TWO, row_work_, mlaParams);
|
||||
rmsNormQuant2.Init(gamma2GmTensor, beta2GmTensor, s3Gm, s1Gm, splitSizeOne_,
|
||||
num_col_2, mlaParams.avgFactor, vectorBlockIdx * static_cast<uint64_t>(row_work) * num_col_2,
|
||||
vectorBlockIdx * static_cast<uint64_t>(row_work) * splitSizeTwo_, row_work_, mlaParams);
|
||||
ropeFp16.RopeInit(s2Gm, cos2GmTensor, sin2GmTensor, qGmTensor, qGmTensor2, mlaParams);
|
||||
#endif
|
||||
}
|
||||
@@ -996,6 +1010,18 @@ public:
|
||||
__aicore__ inline void ProcessVector();
|
||||
|
||||
private:
|
||||
// Model-specific MLA dimensions from tiling data
|
||||
uint32_t hiddenStateDim;
|
||||
uint32_t mm1OutSize_;
|
||||
uint32_t splitSizeOne_;
|
||||
uint32_t splitSizeTwo_;
|
||||
uint32_t splitRmsNormSizeOne_;
|
||||
uint32_t splitRmsNormSizeTwo_;
|
||||
uint32_t ropeSplitSizeOne_;
|
||||
uint32_t ropeSplitSizeTwo_;
|
||||
uint32_t hiddenStrideRope_;
|
||||
uint32_t qkNopeHeadDim_;
|
||||
|
||||
constexpr static uint32_t C0_SIZE = 16;
|
||||
constexpr static uint32_t I8_C0_SIZE = 32;
|
||||
|
||||
@@ -1009,10 +1035,10 @@ private:
|
||||
const AscendC::LocalTensor<float> &calTensor, const AscendC::LocalTensor<T1> &outTmpTensor)
|
||||
{
|
||||
int64_t slotMapGmOffset = vectorBlockIdx * row_work;
|
||||
AscendC::DataCopy(gammaTensor, gamma3GmTensor, SPLIT_RMSNRORM_SIZE_ONE);
|
||||
AscendC::DataCopy(gammaTensor, gamma3GmTensor, splitRmsNormSizeOne_);
|
||||
SET_FLAG(MTE2, V, EVENT_ID1);
|
||||
WAIT_FLAG(MTE2, V, EVENT_ID1);
|
||||
Cast(gammaFp32, gammaTensor, AscendC::RoundMode::CAST_NONE, SPLIT_RMSNRORM_SIZE_ONE);
|
||||
Cast(gammaFp32, gammaTensor, AscendC::RoundMode::CAST_NONE, splitRmsNormSizeOne_);
|
||||
AscendC::DataCopyPad(slotMappingTensor, slotMappingGmTensor[slotMapGmOffset],
|
||||
AscendC::DataCopyExtParams(1, sN * sizeof(int32_t), 0, 0, 0),
|
||||
AscendC::DataCopyPadExtParams<int32_t>(false, 0, 8 - sN % 8, 0));
|
||||
@@ -1021,22 +1047,22 @@ private:
|
||||
SET_FLAG(MTE2, S, EVENT_ID2);
|
||||
WAIT_FLAG(MTE2, S, EVENT_ID2);
|
||||
for (uint64_t loop = 0; loop < sN; ++loop) {
|
||||
uint64_t offset = vectorBlockIdx * static_cast<uint64_t>(row_work) * num_col_2 + loop * MM1_OUT_SIZE;
|
||||
uint64_t offset = vectorBlockIdx * static_cast<uint64_t>(row_work) * num_col_2 + loop * mm1OutSize_;
|
||||
int64_t slotValue = static_cast<int64_t>(slotMappingTensor.GetValue(loop));
|
||||
if (slotValue == -1) {
|
||||
continue;
|
||||
}
|
||||
AscendC::DataCopy(srcTensor, s3GmTensor[offset],
|
||||
AscendC::DataCopyParams(1, MM1_OUT_SIZE / BLOCK_SIZE_16, 0, 0));
|
||||
AscendC::DataCopy(sinTensor, sin1GmTensor[(row_work * vectorBlockIdx + loop) * SPLIT_RMSNRORM_SIZE_TWO],
|
||||
SPLIT_RMSNRORM_SIZE_TWO);
|
||||
AscendC::DataCopy(cosTensor, cos1GmTensor[(row_work * vectorBlockIdx + loop) * SPLIT_RMSNRORM_SIZE_TWO],
|
||||
SPLIT_RMSNRORM_SIZE_TWO);
|
||||
AscendC::DataCopyParams(1, mm1OutSize_ / BLOCK_SIZE_16, 0, 0));
|
||||
AscendC::DataCopy(sinTensor, sin1GmTensor[(row_work * vectorBlockIdx + loop) * splitRmsNormSizeTwo_],
|
||||
splitRmsNormSizeTwo_);
|
||||
AscendC::DataCopy(cosTensor, cos1GmTensor[(row_work * vectorBlockIdx + loop) * splitRmsNormSizeTwo_],
|
||||
splitRmsNormSizeTwo_);
|
||||
SET_FLAG(MTE2, V, EVENT_ID0);
|
||||
// ND
|
||||
uint64_t cacheStart = static_cast<uint64_t>(slotValue) * static_cast<uint64_t>(SPLIT_SIZE_ONE);
|
||||
uint64_t cacheStart1 = static_cast<uint64_t>(slotValue) * static_cast<uint64_t>(SPLIT_RMSNRORM_SIZE_ONE);
|
||||
uint64_t cacheStart2 = static_cast<uint64_t>(slotValue) * static_cast<uint64_t>(SPLIT_RMSNRORM_SIZE_TWO);
|
||||
uint64_t cacheStart = static_cast<uint64_t>(slotValue) * static_cast<uint64_t>(splitSizeOne_);
|
||||
uint64_t cacheStart1 = static_cast<uint64_t>(slotValue) * static_cast<uint64_t>(splitRmsNormSizeOne_);
|
||||
uint64_t cacheStart2 = static_cast<uint64_t>(slotValue) * static_cast<uint64_t>(splitRmsNormSizeTwo_);
|
||||
// NZ
|
||||
uint32_t outer_idx = slotValue / 128;
|
||||
uint32_t inner_idx = slotValue % 128;
|
||||
@@ -1044,63 +1070,63 @@ private:
|
||||
SET_FLAG(S, MTE3, EVENT_ID0);
|
||||
/* RmsNorm start */
|
||||
WAIT_FLAG(MTE2, V, EVENT_ID0);
|
||||
Cast(rmsNormTensor, srcTensor, AscendC::RoundMode::CAST_NONE, SPLIT_RMSNRORM_SIZE_ONE);
|
||||
Cast(rmsNormTensor, srcTensor, AscendC::RoundMode::CAST_NONE, splitRmsNormSizeOne_);
|
||||
AscendC::PipeBarrier<PIPE_V>();
|
||||
Mul(calTensor, rmsNormTensor, rmsNormTensor, SPLIT_RMSNRORM_SIZE_ONE);
|
||||
Mul(calTensor, rmsNormTensor, rmsNormTensor, splitRmsNormSizeOne_);
|
||||
AscendC::PipeBarrier<PIPE_V>();
|
||||
ReduceSumCustom(calTensor[SPLIT_RMSNRORM_SIZE_ONE], calTensor, calTensor[SPLIT_RMSNRORM_SIZE_ONE * 2],
|
||||
SPLIT_RMSNRORM_SIZE_ONE);
|
||||
ReduceSumCustom(calTensor[splitRmsNormSizeOne_], calTensor, calTensor[splitRmsNormSizeOne_ * 2],
|
||||
splitRmsNormSizeOne_);
|
||||
SET_FLAG(V, S, EVENT_ID1);
|
||||
WAIT_FLAG(V, S, EVENT_ID1);
|
||||
float rms = sqrt(calTensor.GetValue(SPLIT_RMSNRORM_SIZE_ONE) / SPLIT_RMSNRORM_SIZE_ONE + epsilon_);
|
||||
float rms = sqrt(calTensor.GetValue(splitRmsNormSizeOne_) / splitRmsNormSizeOne_ + epsilon_);
|
||||
SET_FLAG(S, V, EVENT_ID1);
|
||||
WAIT_FLAG(S, V, EVENT_ID1);
|
||||
AscendC::PipeBarrier<PIPE_V>();
|
||||
Duplicate(calTensor, rms, SPLIT_RMSNRORM_SIZE_ONE);
|
||||
Duplicate(calTensor, rms, splitRmsNormSizeOne_);
|
||||
AscendC::PipeBarrier<PIPE_V>();
|
||||
Div(calTensor, rmsNormTensor, calTensor, SPLIT_RMSNRORM_SIZE_ONE);
|
||||
Div(calTensor, rmsNormTensor, calTensor, splitRmsNormSizeOne_);
|
||||
AscendC::PipeBarrier<PIPE_V>();
|
||||
Mul(rmsNormTensor, gammaFp32, calTensor, SPLIT_RMSNRORM_SIZE_ONE);
|
||||
Mul(rmsNormTensor, gammaFp32, calTensor, splitRmsNormSizeOne_);
|
||||
AscendC::PipeBarrier<PIPE_V>();
|
||||
Cast(outTmpTensor, rmsNormTensor, AscendC::RoundMode::CAST_RINT, SPLIT_RMSNRORM_SIZE_ONE);
|
||||
Cast(outTmpTensor, rmsNormTensor, AscendC::RoundMode::CAST_RINT, splitRmsNormSizeOne_);
|
||||
/* RmsNorm end */
|
||||
/* Rope K start */
|
||||
uint64_t revertOffset = SPLIT_RMSNRORM_SIZE_TWO / 2;
|
||||
Cast(ropeKTensor, srcTensor[SPLIT_RMSNRORM_SIZE_ONE], AscendC::RoundMode::CAST_NONE,
|
||||
SPLIT_RMSNRORM_SIZE_TWO);
|
||||
Cast(ropeKRevertTensor[revertOffset], srcTensor[SPLIT_RMSNRORM_SIZE_ONE], AscendC::RoundMode::CAST_NONE,
|
||||
uint64_t revertOffset = splitRmsNormSizeTwo_ / 2;
|
||||
Cast(ropeKTensor, srcTensor[splitRmsNormSizeOne_], AscendC::RoundMode::CAST_NONE,
|
||||
splitRmsNormSizeTwo_);
|
||||
Cast(ropeKRevertTensor[revertOffset], srcTensor[splitRmsNormSizeOne_], AscendC::RoundMode::CAST_NONE,
|
||||
revertOffset);
|
||||
Cast(ropeKRevertTensor, srcTensor[SPLIT_RMSNRORM_SIZE_ONE + revertOffset], AscendC::RoundMode::CAST_NONE,
|
||||
Cast(ropeKRevertTensor, srcTensor[splitRmsNormSizeOne_ + revertOffset], AscendC::RoundMode::CAST_NONE,
|
||||
revertOffset);
|
||||
Duplicate(calTensor, static_cast<float>(-1), revertOffset);
|
||||
Duplicate(calTensor[revertOffset], static_cast<float>(1), revertOffset);
|
||||
AscendC::PipeBarrier<PIPE_V>();
|
||||
Cast(calTensor[SPLIT_RMSNRORM_SIZE_TWO], cosTensor, AscendC::RoundMode::CAST_NONE, SPLIT_RMSNRORM_SIZE_TWO);
|
||||
Cast(calTensor[SPLIT_RMSNRORM_SIZE_TWO * 2], sinTensor, AscendC::RoundMode::CAST_NONE,
|
||||
SPLIT_RMSNRORM_SIZE_TWO);
|
||||
Cast(calTensor[splitRmsNormSizeTwo_], cosTensor, AscendC::RoundMode::CAST_NONE, splitRmsNormSizeTwo_);
|
||||
Cast(calTensor[splitRmsNormSizeTwo_ * 2], sinTensor, AscendC::RoundMode::CAST_NONE,
|
||||
splitRmsNormSizeTwo_);
|
||||
AscendC::PipeBarrier<PIPE_V>();
|
||||
Mul(ropeKTensor, calTensor[SPLIT_RMSNRORM_SIZE_TWO], ropeKTensor, SPLIT_RMSNRORM_SIZE_TWO);
|
||||
Mul(ropeKRevertTensor, calTensor[SPLIT_RMSNRORM_SIZE_TWO * 2], ropeKRevertTensor, SPLIT_RMSNRORM_SIZE_TWO);
|
||||
Mul(ropeKTensor, calTensor[splitRmsNormSizeTwo_], ropeKTensor, splitRmsNormSizeTwo_);
|
||||
Mul(ropeKRevertTensor, calTensor[splitRmsNormSizeTwo_ * 2], ropeKRevertTensor, splitRmsNormSizeTwo_);
|
||||
AscendC::PipeBarrier<PIPE_V>();
|
||||
Mul(ropeKRevertTensor, calTensor, ropeKRevertTensor, SPLIT_RMSNRORM_SIZE_TWO);
|
||||
Mul(ropeKRevertTensor, calTensor, ropeKRevertTensor, splitRmsNormSizeTwo_);
|
||||
AscendC::PipeBarrier<PIPE_V>();
|
||||
Add(ropeKRevertTensor, ropeKTensor, ropeKRevertTensor, SPLIT_RMSNRORM_SIZE_TWO);
|
||||
Add(ropeKRevertTensor, ropeKTensor, ropeKRevertTensor, splitRmsNormSizeTwo_);
|
||||
AscendC::PipeBarrier<PIPE_V>();
|
||||
Cast(outTmpTensor[SPLIT_RMSNRORM_SIZE_ONE], ropeKRevertTensor, AscendC::RoundMode::CAST_RINT,
|
||||
SPLIT_RMSNRORM_SIZE_TWO);
|
||||
Cast(outTmpTensor[splitRmsNormSizeOne_], ropeKRevertTensor, AscendC::RoundMode::CAST_RINT,
|
||||
splitRmsNormSizeTwo_);
|
||||
AscendC::PipeBarrier<PIPE_V>();
|
||||
/* Rope K end */
|
||||
SET_FLAG(V, MTE3, EVENT_ID0);
|
||||
WAIT_FLAG(V, MTE3, EVENT_ID0);
|
||||
WAIT_FLAG(S, MTE3, EVENT_ID0);
|
||||
if constexpr (CACHE_MODE == CACHE_MODE_KVCACHE) {
|
||||
DataCopy(keycacheGmTensor1[cacheStart], outTmpTensor, SPLIT_SIZE_ONE);
|
||||
DataCopy(keycacheGmTensor1[cacheStart], outTmpTensor, splitSizeOne_);
|
||||
} else {
|
||||
// keycache1
|
||||
DataCopy(keycacheGmTensor1[cacheStart1], outTmpTensor, SPLIT_RMSNRORM_SIZE_ONE);
|
||||
DataCopy(keycacheGmTensor1[cacheStart1], outTmpTensor, splitRmsNormSizeOne_);
|
||||
// keycache2
|
||||
DataCopy(keycacheGmTensor2[cacheStart2], outTmpTensor[SPLIT_RMSNRORM_SIZE_ONE],
|
||||
SPLIT_RMSNRORM_SIZE_TWO);
|
||||
DataCopy(keycacheGmTensor2[cacheStart2], outTmpTensor[splitRmsNormSizeOne_],
|
||||
splitRmsNormSizeTwo_);
|
||||
}
|
||||
SET_FLAG(MTE3, MTE2, EVENT_ID1);
|
||||
WAIT_FLAG(MTE3, MTE2, EVENT_ID1);
|
||||
@@ -1196,16 +1222,16 @@ MLAOperation<InDtype, CACHE_MODE, weightFormat1, weightFormat2, weightFormat3>::
|
||||
uint32_t num_col_align_f16 = (num_col_2 + REPEAT_TIME_128 - 1) / REPEAT_TIME_128 * REPEAT_TIME_128;
|
||||
uint32_t num_col_align_f32 = (num_col_2 + REPEAT_TIME_64 - 1) / REPEAT_TIME_64 * REPEAT_TIME_64;
|
||||
AscendC::LocalTensor<InDtype> input_tensor = buf.GetBuffer<BufferType::ASCEND_UB, InDtype>(0);
|
||||
AscendC::LocalTensor<InDtype> gamma_tensor = buf.GetBuffer<BufferType::ASCEND_UB, InDtype>(MM1_OUT_SIZE * 2);
|
||||
AscendC::LocalTensor<InDtype> gamma_tensor = buf.GetBuffer<BufferType::ASCEND_UB, InDtype>(mm1OutSize_ * 2);
|
||||
AscendC::LocalTensor<InDtype> beta_tensor =
|
||||
buf.GetBuffer<BufferType::ASCEND_UB, InDtype>(MM1_OUT_SIZE * 2 + SPLIT_SIZE_TWO * 2);
|
||||
buf.GetBuffer<BufferType::ASCEND_UB, InDtype>(mm1OutSize_ * 2 + splitSizeTwo_ * 2);
|
||||
AscendC::LocalTensor<float> res1_tensor = buf.GetBuffer<BufferType::ASCEND_UB, float>(
|
||||
MM1_OUT_SIZE * 2 + SPLIT_SIZE_TWO * 2 + SPLIT_SIZE_TWO * 2 + 64);
|
||||
mm1OutSize_ * 2 + splitSizeTwo_ * 2 + splitSizeTwo_ * 2 + 64);
|
||||
AscendC::LocalTensor<float> res3_tensor = buf.GetBuffer<BufferType::ASCEND_UB, float>(
|
||||
MM1_OUT_SIZE * 2 + SPLIT_SIZE_TWO * 2 + SPLIT_SIZE_TWO * 2 + 64 + num_col_align_f32 * 4);
|
||||
mm1OutSize_ * 2 + splitSizeTwo_ * 2 + splitSizeTwo_ * 2 + 64 + num_col_align_f32 * 4);
|
||||
AscendC::LocalTensor<int8_t> output_tensor = buf.GetBuffer<BufferType::ASCEND_UB, int8_t>(
|
||||
MM1_OUT_SIZE * 2 + SPLIT_SIZE_TWO * 2 + SPLIT_SIZE_TWO * 2 + 64 + num_col_align_f32 * 4 +
|
||||
BUF_FACTOR * num_col_align_f32 * 4 + 64 + MM1_OUT_SIZE * 4 * 2 + 32);
|
||||
mm1OutSize_ * 2 + splitSizeTwo_ * 2 + splitSizeTwo_ * 2 + 64 + num_col_align_f32 * 4 +
|
||||
BUF_FACTOR * num_col_align_f32 * 4 + 64 + mm1OutSize_ * 4 * 2 + 32);
|
||||
rmsNormQuant2.Launch(output_tensor, input_tensor, gamma_tensor, beta_tensor, res1_tensor, res3_tensor);
|
||||
}
|
||||
FftsCrossCoreSync<PIPE_MTE3, 0>(RMSNORMQUANT2);
|
||||
@@ -1214,20 +1240,20 @@ MLAOperation<InDtype, CACHE_MODE, weightFormat1, weightFormat2, weightFormat3>::
|
||||
|
||||
if (row_work_ != 0) {
|
||||
AscendC::LocalTensor<InDtype> input_tensor = buf.GetBuffer<BufferType::ASCEND_UB, InDtype>(0);
|
||||
AscendC::LocalTensor<InDtype> gamma_tensor = buf.GetBuffer<BufferType::ASCEND_UB, InDtype>(MM1_OUT_SIZE * 2);
|
||||
AscendC::LocalTensor<InDtype> gamma_tensor = buf.GetBuffer<BufferType::ASCEND_UB, InDtype>(mm1OutSize_ * 2);
|
||||
AscendC::LocalTensor<InDtype> sin_tensor =
|
||||
buf.GetBuffer<BufferType::ASCEND_UB, InDtype>(MM1_OUT_SIZE * 2 + SPLIT_RMSNRORM_SIZE_ONE * 2);
|
||||
buf.GetBuffer<BufferType::ASCEND_UB, InDtype>(mm1OutSize_ * 2 + splitRmsNormSizeOne_ * 2);
|
||||
AscendC::LocalTensor<InDtype> cos_tensor = buf.GetBuffer<BufferType::ASCEND_UB, InDtype>(
|
||||
MM1_OUT_SIZE * 2 + SPLIT_RMSNRORM_SIZE_ONE * 2 + SPLIT_RMSNRORM_SIZE_TWO * 2);
|
||||
mm1OutSize_ * 2 + splitRmsNormSizeOne_ * 2 + splitRmsNormSizeTwo_ * 2);
|
||||
AscendC::LocalTensor<int32_t> slotMapping_tensor = buf.GetBuffer<BufferType::ASCEND_UB, int32_t>(
|
||||
MM1_OUT_SIZE * 2 + SPLIT_RMSNRORM_SIZE_ONE * 2 + SPLIT_RMSNRORM_SIZE_TWO * 4);
|
||||
mm1OutSize_ * 2 + splitRmsNormSizeOne_ * 2 + splitRmsNormSizeTwo_ * 4);
|
||||
int32_t rms3_ub_offset =
|
||||
MM1_OUT_SIZE * 2 + SPLIT_RMSNRORM_SIZE_ONE * 2 + SPLIT_RMSNRORM_SIZE_TWO * 4 + 4096 * 32;
|
||||
mm1OutSize_ * 2 + splitRmsNormSizeOne_ * 2 + splitRmsNormSizeTwo_ * 4 + 4096 * 32;
|
||||
AscendC::LocalTensor<float> tmp32_tensor = buf.GetBuffer<BufferType::ASCEND_UB, float>(rms3_ub_offset);
|
||||
|
||||
int32_t out_ub_offset = MM1_OUT_SIZE * 2 + SPLIT_RMSNRORM_SIZE_ONE * 2 + SPLIT_RMSNRORM_SIZE_TWO * 4 +
|
||||
4096 * 32 + SPLIT_RMSNRORM_SIZE_ONE * 3 * 4 + SPLIT_RMSNRORM_SIZE_TWO * 2 * 4 +
|
||||
MM1_OUT_SIZE * 4 * 2 + 32;
|
||||
int32_t out_ub_offset = mm1OutSize_ * 2 + splitRmsNormSizeOne_ * 2 + splitRmsNormSizeTwo_ * 4 +
|
||||
4096 * 32 + splitRmsNormSizeOne_ * 3 * 4 + splitRmsNormSizeTwo_ * 2 * 4 +
|
||||
mm1OutSize_ * 4 * 2 + 32;
|
||||
AscendC::LocalTensor<InDtype> temp_tensor = buf.GetBuffer<BufferType::ASCEND_UB, InDtype>(out_ub_offset);
|
||||
|
||||
RmsNormAndRopeConvergence1<InDtype>(
|
||||
@@ -1236,11 +1262,11 @@ MLAOperation<InDtype, CACHE_MODE, weightFormat1, weightFormat2, weightFormat3>::
|
||||
sin_tensor, // sin
|
||||
cos_tensor, // cons
|
||||
slotMapping_tensor, // slotMapping
|
||||
row_work_, tmp32_tensor, tmp32_tensor[SPLIT_RMSNRORM_SIZE_ONE],
|
||||
tmp32_tensor[SPLIT_RMSNRORM_SIZE_ONE + SPLIT_RMSNRORM_SIZE_ONE],
|
||||
tmp32_tensor[SPLIT_RMSNRORM_SIZE_ONE + SPLIT_RMSNRORM_SIZE_ONE + SPLIT_RMSNRORM_SIZE_TWO],
|
||||
tmp32_tensor[SPLIT_RMSNRORM_SIZE_ONE + SPLIT_RMSNRORM_SIZE_ONE + SPLIT_RMSNRORM_SIZE_TWO +
|
||||
SPLIT_RMSNRORM_SIZE_TWO],
|
||||
row_work_, tmp32_tensor, tmp32_tensor[splitRmsNormSizeOne_],
|
||||
tmp32_tensor[splitRmsNormSizeOne_ + splitRmsNormSizeOne_],
|
||||
tmp32_tensor[splitRmsNormSizeOne_ + splitRmsNormSizeOne_ + splitRmsNormSizeTwo_],
|
||||
tmp32_tensor[splitRmsNormSizeOne_ + splitRmsNormSizeOne_ + splitRmsNormSizeTwo_ +
|
||||
splitRmsNormSizeTwo_],
|
||||
temp_tensor);
|
||||
}
|
||||
|
||||
|
||||
@@ -54,6 +54,8 @@ public:
|
||||
lastCoreLoopTime = ropeConcatParams.lastCoreLoopTime;
|
||||
lastCoreLoopNLast = ropeConcatParams.lastCoreLoopNLast;
|
||||
concatSize = ropeConcatParams.concatSize;
|
||||
hiddenStrideRope_ = ropeConcatParams.hiddenStrideRope;
|
||||
qkNopeHeadDim_ = ropeConcatParams.qkNopeHeadDim;
|
||||
blockIdx_ = (blockIdx_ / 2) * 2 + static_cast<uint64_t>(GetSubBlockidx());
|
||||
loopTime = (blockIdx_ == realCore - 1) ? lastCoreLoopTime : preCoreLoopTime;
|
||||
lastLoopN = (blockIdx_ == realCore - 1) ? lastCoreLoopNLast : preCoreLoopNLast;
|
||||
@@ -92,7 +94,7 @@ public:
|
||||
AscendC::LocalTensor<float> inputQCastFP32 = buf.GetBuffer<BufferType::ASCEND_UB, float>(dataSizeFp16);
|
||||
AscendC::LocalTensor<float> reverseQ =
|
||||
buf.GetBuffer<BufferType::ASCEND_UB, float>(dataSizeFp32 + dataSizeFp16);
|
||||
uint64_t qOffset = startHead * 192 + 128;
|
||||
uint64_t qOffset = startHead * hiddenStrideRope_ + qkNopeHeadDim_;
|
||||
CopyQGenReverseQ(inputQ, inputQCastFP32, reverseQ, qOffset, loopN);
|
||||
|
||||
// move in cos/sin
|
||||
@@ -178,7 +180,7 @@ public:
|
||||
{
|
||||
// move in Q
|
||||
WAIT_FLAG(MTE3, MTE2, EVENT_ID1);
|
||||
AscendC::DataCopy(tempBufQ, this->qGm_[qOffset], {loopN, headBlockLen, 128 / 16, 0});
|
||||
AscendC::DataCopy(tempBufQ, this->qGm_[qOffset], {loopN, headBlockLen, static_cast<uint16_t>(qkNopeHeadDim_ / 16), 0});
|
||||
SET_FLAG(MTE2, V, EVENT_ID1);
|
||||
WAIT_FLAG(MTE2, V, EVENT_ID1);
|
||||
// cast fp32
|
||||
@@ -230,6 +232,8 @@ private:
|
||||
uint32_t lastCoreLoopTime;
|
||||
uint32_t lastCoreLoopNLast;
|
||||
uint32_t concatSize;
|
||||
uint32_t hiddenStrideRope_;
|
||||
uint32_t qkNopeHeadDim_;
|
||||
uint32_t blockIdx_;
|
||||
uint32_t loopTime{0};
|
||||
uint32_t lastLoopN{0};
|
||||
@@ -309,6 +313,7 @@ public:
|
||||
this->num_row_ = mlaParams_.n;
|
||||
this->row_work = row_work;
|
||||
this->row_work_ = row_work_;
|
||||
this->mm1OutSize_ = mlaParams_.mm1OutSize;
|
||||
gm_offset_ = gm_offset;
|
||||
gm_out_offset_ = gm_out_offset;
|
||||
num_col_align_int8 = (num_col_ + REPEAT_TIME_256 - 1) / REPEAT_TIME_256 * REPEAT_TIME_256;
|
||||
@@ -353,8 +358,8 @@ public:
|
||||
|
||||
if constexpr (NEED_DEQUANT) {
|
||||
mmTensor = buf.ReinterpretCast<int32_t>()[OFFSET_WORKSPACE_BF16 * num_col_align_withStride_fp32 + 16];
|
||||
deScaleTensor = buf[OFFSET_WORKSPACE_BF16 * num_col_align_withStride_fp32 + 16 + MM1_OUT_SIZE];
|
||||
perTokenDescaleTensor = buf[OFFSET_WORKSPACE_BF16 * num_col_align_withStride_fp32 + 16 + MM1_OUT_SIZE * 2];
|
||||
deScaleTensor = buf[OFFSET_WORKSPACE_BF16 * num_col_align_withStride_fp32 + 16 + mm1OutSize_];
|
||||
perTokenDescaleTensor = buf[OFFSET_WORKSPACE_BF16 * num_col_align_withStride_fp32 + 16 + mm1OutSize_ * 2];
|
||||
AscendC::DataCopy(deScaleTensor, perChannelDescaleGmTensor, AscendC::DataCopyParams(1, num_col_ / 8, 0, 0));
|
||||
}
|
||||
|
||||
@@ -531,6 +536,7 @@ private:
|
||||
uint32_t num_slice_{0};
|
||||
uint32_t tail_size_{0};
|
||||
uint32_t tail_copy_{0};
|
||||
uint32_t mm1OutSize_{0};
|
||||
};
|
||||
|
||||
template <typename T, bool WITH_BETA, bool FastComputeMode = false,
|
||||
@@ -568,6 +574,7 @@ public:
|
||||
this->num_row_ = mlaParams_.n;
|
||||
this->row_work = row_work;
|
||||
this->row_work_ = row_work_;
|
||||
this->mm1OutSize_ = mlaParams_.mm1OutSize;
|
||||
gm_offset_ = gm_offset;
|
||||
gm_out_offset_ = gm_out_offset;
|
||||
num_col_align_int8 = (num_col_ + REPEAT_TIME_256 - 1) / REPEAT_TIME_256 * REPEAT_TIME_256;
|
||||
@@ -620,8 +627,8 @@ public:
|
||||
|
||||
if constexpr (NEED_DEQUANT) {
|
||||
mmTensor = buf.ReinterpretCast<int32_t>()[OFFSET_WORKSPACE_BF16 * num_col_align_withStride_fp32 + 16];
|
||||
deScaleTensor = buf[OFFSET_WORKSPACE_BF16 * num_col_align_withStride_fp32 + 16 + MM1_OUT_SIZE];
|
||||
perTokenDescaleTensor = buf[OFFSET_WORKSPACE_BF16 * num_col_align_withStride_fp32 + 16 + MM1_OUT_SIZE * 2];
|
||||
deScaleTensor = buf[OFFSET_WORKSPACE_BF16 * num_col_align_withStride_fp32 + 16 + mm1OutSize_];
|
||||
perTokenDescaleTensor = buf[OFFSET_WORKSPACE_BF16 * num_col_align_withStride_fp32 + 16 + mm1OutSize_ * 2];
|
||||
AscendC::DataCopy(deScaleTensor, perChannelDescaleGmTensor, AscendC::DataCopyParams(1, num_col_ / 8, 0, 0));
|
||||
}
|
||||
|
||||
@@ -857,6 +864,7 @@ private:
|
||||
uint32_t num_slice_{0};
|
||||
uint32_t tail_size_{0};
|
||||
uint32_t tail_copy_{0};
|
||||
uint32_t mm1OutSize_{0};
|
||||
};
|
||||
|
||||
template <typename InDtype, typename ScaleDtype>
|
||||
@@ -2407,6 +2415,15 @@ public:
|
||||
this->epsilon_ = 1e-6;
|
||||
this->mlaParams = mlaParams_;
|
||||
this->hiddenStateDim = mlaParams_.hiddenStateDim;
|
||||
this->mm1OutSize_ = mlaParams_.mm1OutSize;
|
||||
this->splitSizeOne_ = mlaParams_.splitSizeOne;
|
||||
this->splitSizeTwo_ = mlaParams_.splitSizeTwo;
|
||||
this->splitRmsNormSizeOne_ = mlaParams_.splitRmsNormSizeOne;
|
||||
this->splitRmsNormSizeTwo_ = mlaParams_.splitRmsNormSizeTwo;
|
||||
this->ropeSplitSizeOne_ = mlaParams_.ropeSplitSizeOne;
|
||||
this->ropeSplitSizeTwo_ = mlaParams_.ropeSplitSizeTwo;
|
||||
this->hiddenStrideRope_ = mlaParams_.hiddenStrideRope;
|
||||
this->qkNopeHeadDim_ = mlaParams_.qkNopeHeadDim;
|
||||
}
|
||||
|
||||
__aicore__ inline void Init(GM_ADDR hiddenStateGm, GM_ADDR quantScale1Gm,
|
||||
@@ -2492,15 +2509,15 @@ public:
|
||||
vectorBlockIdx * static_cast<uint64_t>(row_work) * num_col_1, row_work_, mlaParams);
|
||||
if constexpr (quantMode == QuantMode::PER_TENSOR_ASYMM_QUANT) {
|
||||
rmsNormQuant2.Init(gamma2GmTensor, beta2GmTensor, quantScale2GmTensor, quantOffset2GmTensor,
|
||||
s5Gm + row_work * vectorBlockIdx * sizeof(float), descale1Gm, s3Gm, s1Gm, SPLIT_SIZE_ONE,
|
||||
num_col_2, 0.000651041666, vectorBlockIdx * static_cast<uint64_t>(row_work) * num_col_2,
|
||||
vectorBlockIdx * static_cast<uint64_t>(row_work) * SPLIT_SIZE_TWO, row_work_, mlaParams, innerGmTensor);
|
||||
s5Gm + row_work * vectorBlockIdx * sizeof(float), descale1Gm, s3Gm, s1Gm, splitSizeOne_,
|
||||
num_col_2, mlaParams.avgFactor, vectorBlockIdx * static_cast<uint64_t>(row_work) * num_col_2,
|
||||
vectorBlockIdx * static_cast<uint64_t>(row_work) * splitSizeTwo_, row_work_, mlaParams, innerGmTensor);
|
||||
} else {
|
||||
// quantMode == QuantMode::PER_TOKEN_SYMM_QUANT
|
||||
rmsNormQuant2.Init(gamma2GmTensor, beta2GmTensor, quantScale2GmTensor, quantOffset2GmTensor,
|
||||
s5Gm + row_work * vectorBlockIdx * sizeof(float), descale1Gm, s2Gm, s1Gm, SPLIT_SIZE_ONE,
|
||||
num_col_2, 0.000651041666, vectorBlockIdx * static_cast<uint64_t>(row_work) * num_col_2,
|
||||
vectorBlockIdx * static_cast<uint64_t>(row_work) * SPLIT_SIZE_TWO, row_work_, mlaParams, innerGmTensor);
|
||||
s5Gm + row_work * vectorBlockIdx * sizeof(float), descale1Gm, s2Gm, s1Gm, splitSizeOne_,
|
||||
num_col_2, mlaParams.avgFactor, vectorBlockIdx * static_cast<uint64_t>(row_work) * num_col_2,
|
||||
vectorBlockIdx * static_cast<uint64_t>(row_work) * splitSizeTwo_, row_work_, mlaParams, innerGmTensor);
|
||||
}
|
||||
ropeFp16.RopeInit(s4Gm, cos2GmTensor, sin2GmTensor, qGmTensor, qGmTensor2, mlaParams);
|
||||
einSumQuant.Init(s1Gm, gmQnopeScale, qGm, mlaParams);
|
||||
@@ -2512,6 +2529,17 @@ public:
|
||||
__aicore__ inline void ProcessVector();
|
||||
|
||||
private:
|
||||
// Model-specific MLA dimensions from tiling data
|
||||
uint32_t mm1OutSize_;
|
||||
uint32_t splitSizeOne_;
|
||||
uint32_t splitSizeTwo_;
|
||||
uint32_t splitRmsNormSizeOne_;
|
||||
uint32_t splitRmsNormSizeTwo_;
|
||||
uint32_t ropeSplitSizeOne_;
|
||||
uint32_t ropeSplitSizeTwo_;
|
||||
uint32_t hiddenStrideRope_;
|
||||
uint32_t qkNopeHeadDim_;
|
||||
|
||||
constexpr static uint32_t C0_SIZE = 16;
|
||||
constexpr static uint32_t I8_C0_SIZE = 32;
|
||||
|
||||
@@ -2526,44 +2554,44 @@ private:
|
||||
AscendC::LocalTensor<half> &tmpfp16, AscendC::LocalTensor<int8_t> &int8OutTensor, float quantScale3)
|
||||
{
|
||||
int64_t slotMapGmOffset = vectorBlockIdx * row_work;
|
||||
AscendC::DataCopy(gammaTensor, gamma3GmTensor, SPLIT_RMSNRORM_SIZE_ONE);
|
||||
AscendC::DataCopy(gammaTensor, gamma3GmTensor, splitRmsNormSizeOne_);
|
||||
SET_FLAG(MTE2, V, EVENT_ID1);
|
||||
WAIT_FLAG(MTE2, V, EVENT_ID1);
|
||||
Cast(gammaFp32, gammaTensor, AscendC::RoundMode::CAST_NONE, SPLIT_RMSNRORM_SIZE_ONE);
|
||||
Cast(gammaFp32, gammaTensor, AscendC::RoundMode::CAST_NONE, splitRmsNormSizeOne_);
|
||||
AscendC::DataCopyPad(slotMappingTensor, slotMappingGmTensor[slotMapGmOffset],
|
||||
AscendC::DataCopyExtParams(1, sN * sizeof(int32_t), 0, 0, 0),
|
||||
AscendC::DataCopyPadExtParams<int32_t>(false, 0, 8 - sN % 8, 0));
|
||||
if constexpr (quantMode == QuantMode::PER_TOKEN_SYMM_QUANT) {
|
||||
mmTensor = calTensor.ReinterpretCast<int32_t>()[SPLIT_SIZE_ONE];
|
||||
deScaleTensor = calTensor.ReinterpretCast<float>()[SPLIT_SIZE_ONE * 2];
|
||||
AscendC::DataCopy(deScaleTensor, descale1gmTensor, AscendC::DataCopyParams(1, SPLIT_SIZE_ONE / 8, 0, 0));
|
||||
mmTensor = calTensor.ReinterpretCast<int32_t>()[splitSizeOne_];
|
||||
deScaleTensor = calTensor.ReinterpretCast<float>()[splitSizeOne_ * 2];
|
||||
AscendC::DataCopy(deScaleTensor, descale1gmTensor, AscendC::DataCopyParams(1, splitSizeOne_ / 8, 0, 0));
|
||||
}
|
||||
SET_FLAG(MTE2, V, EVENT_ID2);
|
||||
WAIT_FLAG(MTE2, V, EVENT_ID2);
|
||||
SET_FLAG(MTE2, S, EVENT_ID2);
|
||||
WAIT_FLAG(MTE2, S, EVENT_ID2);
|
||||
for (uint64_t loop = 0; loop < sN; ++loop) {
|
||||
uint64_t offset = vectorBlockIdx * static_cast<uint64_t>(row_work) * num_col_2 + loop * MM1_OUT_SIZE;
|
||||
uint64_t offset = vectorBlockIdx * static_cast<uint64_t>(row_work) * num_col_2 + loop * mm1OutSize_;
|
||||
int64_t slotValue = static_cast<int64_t>(slotMappingTensor.GetValue(loop));
|
||||
if (slotValue == -1) {
|
||||
continue;
|
||||
}
|
||||
if constexpr (quantMode == QuantMode::PER_TENSOR_ASYMM_QUANT) {
|
||||
AscendC::DataCopy(srcTensor, s3GmTensor[offset],
|
||||
AscendC::DataCopyParams(1, MM1_OUT_SIZE / BLOCK_SIZE_16, 0, 0));
|
||||
AscendC::DataCopyParams(1, mm1OutSize_ / BLOCK_SIZE_16, 0, 0));
|
||||
} else {
|
||||
// quantMode == QuantMode::PER_TOKEN_SYMM_QUANT
|
||||
AscendC::DataCopy(mmTensor, s2GmTensor[offset], AscendC::DataCopyParams(1, SPLIT_SIZE_ONE / 8, 0, 0));
|
||||
AscendC::DataCopy(mmTensor, s2GmTensor[offset], AscendC::DataCopyParams(1, splitSizeOne_ / 8, 0, 0));
|
||||
}
|
||||
AscendC::DataCopy(sinTensor, sin1GmTensor[(row_work * vectorBlockIdx + loop) * SPLIT_RMSNRORM_SIZE_TWO],
|
||||
SPLIT_RMSNRORM_SIZE_TWO);
|
||||
AscendC::DataCopy(cosTensor, cos1GmTensor[(row_work * vectorBlockIdx + loop) * SPLIT_RMSNRORM_SIZE_TWO],
|
||||
SPLIT_RMSNRORM_SIZE_TWO);
|
||||
AscendC::DataCopy(sinTensor, sin1GmTensor[(row_work * vectorBlockIdx + loop) * splitRmsNormSizeTwo_],
|
||||
splitRmsNormSizeTwo_);
|
||||
AscendC::DataCopy(cosTensor, cos1GmTensor[(row_work * vectorBlockIdx + loop) * splitRmsNormSizeTwo_],
|
||||
splitRmsNormSizeTwo_);
|
||||
SET_FLAG(MTE2, V, EVENT_ID0);
|
||||
// ND
|
||||
uint64_t cacheStart = static_cast<uint64_t>(slotValue) * static_cast<uint64_t>(SPLIT_SIZE_ONE);
|
||||
uint64_t cacheStart1 = static_cast<uint64_t>(slotValue) * static_cast<uint64_t>(SPLIT_RMSNRORM_SIZE_ONE);
|
||||
uint64_t cacheStart2 = static_cast<uint64_t>(slotValue) * static_cast<uint64_t>(SPLIT_RMSNRORM_SIZE_TWO);
|
||||
uint64_t cacheStart = static_cast<uint64_t>(slotValue) * static_cast<uint64_t>(splitSizeOne_);
|
||||
uint64_t cacheStart1 = static_cast<uint64_t>(slotValue) * static_cast<uint64_t>(splitRmsNormSizeOne_);
|
||||
uint64_t cacheStart2 = static_cast<uint64_t>(slotValue) * static_cast<uint64_t>(splitRmsNormSizeTwo_);
|
||||
// NZ
|
||||
uint32_t outer_idx = slotValue / 128;
|
||||
uint32_t inner_idx = slotValue % 128;
|
||||
@@ -2574,84 +2602,84 @@ private:
|
||||
if constexpr (quantMode == QuantMode::PER_TOKEN_SYMM_QUANT) {
|
||||
/* DeQuant */
|
||||
AscendC::Cast(mmTensor.ReinterpretCast<float>(), mmTensor, AscendC::RoundMode::CAST_NONE,
|
||||
SPLIT_SIZE_ONE);
|
||||
splitSizeOne_);
|
||||
AscendC::PipeBarrier<PIPE_V>();
|
||||
AscendC::Mul(mmTensor.ReinterpretCast<float>(), mmTensor.ReinterpretCast<float>(), deScaleTensor,
|
||||
SPLIT_SIZE_ONE);
|
||||
splitSizeOne_);
|
||||
AscendC::PipeBarrier<PIPE_V>();
|
||||
float perTokenDescale = s5GmTensor.GetValue(row_work * vectorBlockIdx + loop);
|
||||
SET_FLAG(S, V, EVENT_ID0);
|
||||
WAIT_FLAG(S, V, EVENT_ID0);
|
||||
AscendC::Muls(mmTensor.ReinterpretCast<float>(), mmTensor.ReinterpretCast<float>(), perTokenDescale,
|
||||
SPLIT_SIZE_ONE);
|
||||
splitSizeOne_);
|
||||
AscendC::PipeBarrier<PIPE_V>();
|
||||
AscendC::Cast(srcTensor, mmTensor.ReinterpretCast<float>(), AscendC::RoundMode::CAST_RINT,
|
||||
SPLIT_SIZE_ONE);
|
||||
splitSizeOne_);
|
||||
AscendC::PipeBarrier<PIPE_V>();
|
||||
}
|
||||
Cast(rmsNormTensor, srcTensor, AscendC::RoundMode::CAST_NONE, SPLIT_RMSNRORM_SIZE_ONE);
|
||||
Cast(rmsNormTensor, srcTensor, AscendC::RoundMode::CAST_NONE, splitRmsNormSizeOne_);
|
||||
AscendC::PipeBarrier<PIPE_V>();
|
||||
Mul(calTensor, rmsNormTensor, rmsNormTensor, SPLIT_RMSNRORM_SIZE_ONE);
|
||||
Mul(calTensor, rmsNormTensor, rmsNormTensor, splitRmsNormSizeOne_);
|
||||
AscendC::PipeBarrier<PIPE_V>();
|
||||
ReduceSumCustom(calTensor[SPLIT_RMSNRORM_SIZE_ONE], calTensor, calTensor[SPLIT_RMSNRORM_SIZE_ONE * 2],
|
||||
SPLIT_RMSNRORM_SIZE_ONE);
|
||||
ReduceSumCustom(calTensor[splitRmsNormSizeOne_], calTensor, calTensor[splitRmsNormSizeOne_ * 2],
|
||||
splitRmsNormSizeOne_);
|
||||
SET_FLAG(V, S, EVENT_ID1);
|
||||
WAIT_FLAG(V, S, EVENT_ID1);
|
||||
float rms = sqrt(calTensor.GetValue(SPLIT_RMSNRORM_SIZE_ONE) / SPLIT_RMSNRORM_SIZE_ONE + epsilon_);
|
||||
float rms = sqrt(calTensor.GetValue(splitRmsNormSizeOne_) / splitRmsNormSizeOne_ + epsilon_);
|
||||
SET_FLAG(S, V, EVENT_ID1);
|
||||
WAIT_FLAG(S, V, EVENT_ID1);
|
||||
AscendC::PipeBarrier<PIPE_V>();
|
||||
Duplicate(calTensor, rms, SPLIT_RMSNRORM_SIZE_ONE);
|
||||
Duplicate(calTensor, rms, splitRmsNormSizeOne_);
|
||||
AscendC::PipeBarrier<PIPE_V>();
|
||||
Div(calTensor, rmsNormTensor, calTensor, SPLIT_RMSNRORM_SIZE_ONE);
|
||||
Div(calTensor, rmsNormTensor, calTensor, splitRmsNormSizeOne_);
|
||||
AscendC::PipeBarrier<PIPE_V>();
|
||||
Mul(rmsNormTensor, gammaFp32, calTensor, SPLIT_RMSNRORM_SIZE_ONE);
|
||||
Mul(rmsNormTensor, gammaFp32, calTensor, splitRmsNormSizeOne_);
|
||||
AscendC::PipeBarrier<PIPE_V>();
|
||||
if constexpr (CACHE_MODE == CACHE_MODE_INT8_NZCACHE) {
|
||||
// quant
|
||||
Muls(rmsNormTensor, rmsNormTensor, quantScale3, SPLIT_RMSNRORM_SIZE_ONE);
|
||||
Muls(rmsNormTensor, rmsNormTensor, quantScale3, splitRmsNormSizeOne_);
|
||||
AscendC::PipeBarrier<PIPE_V>();
|
||||
CastFrom32To16(tmpfp16, rmsNormTensor, SPLIT_RMSNRORM_SIZE_ONE);
|
||||
CastFrom32To16(tmpfp16, rmsNormTensor, splitRmsNormSizeOne_);
|
||||
AscendC::PipeBarrier<PIPE_V>();
|
||||
CastFromF16ToI8(int8OutTensor, tmpfp16, -128, SPLIT_RMSNRORM_SIZE_ONE);
|
||||
CastFromF16ToI8(int8OutTensor, tmpfp16, -128, splitRmsNormSizeOne_);
|
||||
AscendC::PipeBarrier<PIPE_V>();
|
||||
} else {
|
||||
AscendC::PipeBarrier<PIPE_V>();
|
||||
if (std::is_same<T1, __bf16>::value) {
|
||||
Cast(outTmpTensor, rmsNormTensor, AscendC::RoundMode::CAST_RINT, SPLIT_RMSNRORM_SIZE_ONE);
|
||||
Cast(outTmpTensor, rmsNormTensor, AscendC::RoundMode::CAST_RINT, splitRmsNormSizeOne_);
|
||||
} else {
|
||||
Cast(outTmpTensor, rmsNormTensor, AscendC::RoundMode::CAST_NONE, SPLIT_RMSNRORM_SIZE_ONE);
|
||||
Cast(outTmpTensor, rmsNormTensor, AscendC::RoundMode::CAST_NONE, splitRmsNormSizeOne_);
|
||||
}
|
||||
}
|
||||
/* RmsNorm end */
|
||||
/* Rope K start */
|
||||
uint64_t revertOffset = SPLIT_RMSNRORM_SIZE_TWO / 2;
|
||||
Cast(ropeKTensor, srcTensor[SPLIT_RMSNRORM_SIZE_ONE], AscendC::RoundMode::CAST_NONE,
|
||||
SPLIT_RMSNRORM_SIZE_TWO);
|
||||
Cast(ropeKRevertTensor[revertOffset], srcTensor[SPLIT_RMSNRORM_SIZE_ONE], AscendC::RoundMode::CAST_NONE,
|
||||
uint64_t revertOffset = splitRmsNormSizeTwo_ / 2;
|
||||
Cast(ropeKTensor, srcTensor[splitRmsNormSizeOne_], AscendC::RoundMode::CAST_NONE,
|
||||
splitRmsNormSizeTwo_);
|
||||
Cast(ropeKRevertTensor[revertOffset], srcTensor[splitRmsNormSizeOne_], AscendC::RoundMode::CAST_NONE,
|
||||
revertOffset);
|
||||
Cast(ropeKRevertTensor, srcTensor[SPLIT_RMSNRORM_SIZE_ONE + revertOffset], AscendC::RoundMode::CAST_NONE,
|
||||
Cast(ropeKRevertTensor, srcTensor[splitRmsNormSizeOne_ + revertOffset], AscendC::RoundMode::CAST_NONE,
|
||||
revertOffset);
|
||||
Duplicate(calTensor, static_cast<float>(-1), revertOffset);
|
||||
Duplicate(calTensor[revertOffset], static_cast<float>(1), revertOffset);
|
||||
AscendC::PipeBarrier<PIPE_V>();
|
||||
Cast(calTensor[SPLIT_RMSNRORM_SIZE_TWO], cosTensor, AscendC::RoundMode::CAST_NONE, SPLIT_RMSNRORM_SIZE_TWO);
|
||||
Cast(calTensor[SPLIT_RMSNRORM_SIZE_TWO * 2], sinTensor, AscendC::RoundMode::CAST_NONE,
|
||||
SPLIT_RMSNRORM_SIZE_TWO);
|
||||
Cast(calTensor[splitRmsNormSizeTwo_], cosTensor, AscendC::RoundMode::CAST_NONE, splitRmsNormSizeTwo_);
|
||||
Cast(calTensor[splitRmsNormSizeTwo_ * 2], sinTensor, AscendC::RoundMode::CAST_NONE,
|
||||
splitRmsNormSizeTwo_);
|
||||
AscendC::PipeBarrier<PIPE_V>();
|
||||
Mul(ropeKTensor, calTensor[SPLIT_RMSNRORM_SIZE_TWO], ropeKTensor, SPLIT_RMSNRORM_SIZE_TWO);
|
||||
Mul(ropeKRevertTensor, calTensor[SPLIT_RMSNRORM_SIZE_TWO * 2], ropeKRevertTensor, SPLIT_RMSNRORM_SIZE_TWO);
|
||||
Mul(ropeKTensor, calTensor[splitRmsNormSizeTwo_], ropeKTensor, splitRmsNormSizeTwo_);
|
||||
Mul(ropeKRevertTensor, calTensor[splitRmsNormSizeTwo_ * 2], ropeKRevertTensor, splitRmsNormSizeTwo_);
|
||||
AscendC::PipeBarrier<PIPE_V>();
|
||||
Mul(ropeKRevertTensor, calTensor, ropeKRevertTensor, SPLIT_RMSNRORM_SIZE_TWO);
|
||||
Mul(ropeKRevertTensor, calTensor, ropeKRevertTensor, splitRmsNormSizeTwo_);
|
||||
AscendC::PipeBarrier<PIPE_V>();
|
||||
Add(ropeKRevertTensor, ropeKTensor, ropeKRevertTensor, SPLIT_RMSNRORM_SIZE_TWO);
|
||||
Add(ropeKRevertTensor, ropeKTensor, ropeKRevertTensor, splitRmsNormSizeTwo_);
|
||||
AscendC::PipeBarrier<PIPE_V>();
|
||||
if (std::is_same<T1, __bf16>::value) {
|
||||
Cast(outTmpTensor[SPLIT_RMSNRORM_SIZE_ONE], ropeKRevertTensor, AscendC::RoundMode::CAST_RINT,
|
||||
SPLIT_RMSNRORM_SIZE_TWO);
|
||||
Cast(outTmpTensor[splitRmsNormSizeOne_], ropeKRevertTensor, AscendC::RoundMode::CAST_RINT,
|
||||
splitRmsNormSizeTwo_);
|
||||
} else {
|
||||
Cast(outTmpTensor[SPLIT_RMSNRORM_SIZE_ONE], ropeKRevertTensor, AscendC::RoundMode::CAST_NONE,
|
||||
SPLIT_RMSNRORM_SIZE_TWO);
|
||||
Cast(outTmpTensor[splitRmsNormSizeOne_], ropeKRevertTensor, AscendC::RoundMode::CAST_NONE,
|
||||
splitRmsNormSizeTwo_);
|
||||
}
|
||||
AscendC::PipeBarrier<PIPE_V>();
|
||||
/* Rope K end */
|
||||
@@ -2659,45 +2687,45 @@ private:
|
||||
WAIT_FLAG(V, MTE3, EVENT_ID0);
|
||||
WAIT_FLAG(S, MTE3, EVENT_ID0);
|
||||
if constexpr (CACHE_MODE == CACHE_MODE_KVCACHE) {
|
||||
DataCopy(keycacheGmTensor1[cacheStart], outTmpTensor, SPLIT_SIZE_ONE);
|
||||
DataCopy(keycacheGmTensor1[cacheStart], outTmpTensor, splitSizeOne_);
|
||||
} else if constexpr (CACHE_MODE == CACHE_MODE_INT8_NZCACHE) {
|
||||
uint64_t cacheSatartI8Nz1 = outer_idx * 128 * 512 + inner_idx * I8_C0_SIZE;
|
||||
uint64_t cacheSatartNz2 = outer_idx * 128 * 64 + inner_idx * C0_SIZE;
|
||||
// nope:int8 nz
|
||||
AscendC::DataCopyExtParams outExt;
|
||||
outExt.blockCount = SPLIT_RMSNRORM_SIZE_ONE / I8_C0_SIZE;
|
||||
outExt.blockCount = splitRmsNormSizeOne_ / I8_C0_SIZE;
|
||||
outExt.blockLen = I8_C0_SIZE * sizeof(int8_t);
|
||||
outExt.srcStride = 0;
|
||||
outExt.dstStride = (128 * I8_C0_SIZE - I8_C0_SIZE) * sizeof(int8_t);
|
||||
DataCopyPad(keycacheGmTensor1[cacheSatartI8Nz1], int8OutTensor, outExt);
|
||||
// rope:T1 nz
|
||||
outExt.blockCount = SPLIT_RMSNRORM_SIZE_TWO / C0_SIZE;
|
||||
outExt.blockCount = splitRmsNormSizeTwo_ / C0_SIZE;
|
||||
outExt.blockLen = C0_SIZE * sizeof(T1);
|
||||
outExt.srcStride = 0;
|
||||
outExt.dstStride = (128 * C0_SIZE - C0_SIZE) * sizeof(T1);
|
||||
DataCopyPad(keycacheGmTensor2[cacheSatartNz2], outTmpTensor[SPLIT_RMSNRORM_SIZE_ONE], outExt);
|
||||
DataCopyPad(keycacheGmTensor2[cacheSatartNz2], outTmpTensor[splitRmsNormSizeOne_], outExt);
|
||||
} else if constexpr (CACHE_MODE == CACHE_MODE_NZCACHE) {
|
||||
uint64_t cacheSatartNz1 = outer_idx * 128 * 512 + inner_idx * C0_SIZE;
|
||||
uint64_t cacheSatartNz2 = outer_idx * 128 * 64 + inner_idx * C0_SIZE;
|
||||
// nope:T1 nz
|
||||
AscendC::DataCopyExtParams outExt;
|
||||
outExt.blockCount = SPLIT_RMSNRORM_SIZE_ONE / C0_SIZE;
|
||||
outExt.blockCount = splitRmsNormSizeOne_ / C0_SIZE;
|
||||
outExt.blockLen = C0_SIZE * sizeof(T1);
|
||||
outExt.srcStride = 0;
|
||||
outExt.dstStride = (128 * C0_SIZE - C0_SIZE) * sizeof(T1);
|
||||
DataCopyPad(keycacheGmTensor1[cacheSatartNz1], outTmpTensor, outExt);
|
||||
// rope:T1 nz
|
||||
outExt.blockCount = SPLIT_RMSNRORM_SIZE_TWO / C0_SIZE;
|
||||
outExt.blockCount = splitRmsNormSizeTwo_ / C0_SIZE;
|
||||
outExt.blockLen = C0_SIZE * sizeof(T1);
|
||||
outExt.srcStride = 0;
|
||||
outExt.dstStride = (128 * C0_SIZE - C0_SIZE) * sizeof(T1);
|
||||
DataCopyPad(keycacheGmTensor2[cacheSatartNz2], outTmpTensor[SPLIT_RMSNRORM_SIZE_ONE], outExt);
|
||||
DataCopyPad(keycacheGmTensor2[cacheSatartNz2], outTmpTensor[splitRmsNormSizeOne_], outExt);
|
||||
} else {
|
||||
// keycache1
|
||||
DataCopy(keycacheGmTensor1[cacheStart1], outTmpTensor, SPLIT_RMSNRORM_SIZE_ONE);
|
||||
DataCopy(keycacheGmTensor1[cacheStart1], outTmpTensor, splitRmsNormSizeOne_);
|
||||
// keycache2
|
||||
DataCopy(keycacheGmTensor2[cacheStart2], outTmpTensor[SPLIT_RMSNRORM_SIZE_ONE],
|
||||
SPLIT_RMSNRORM_SIZE_TWO);
|
||||
DataCopy(keycacheGmTensor2[cacheStart2], outTmpTensor[splitRmsNormSizeOne_],
|
||||
splitRmsNormSizeTwo_);
|
||||
}
|
||||
SET_FLAG(MTE3, MTE2, EVENT_ID1);
|
||||
WAIT_FLAG(MTE3, MTE2, EVENT_ID1);
|
||||
@@ -2845,20 +2873,20 @@ MLAOperation<InDtype, CACHE_MODE, weightFormat1, weightFormat2, weightFormat3, q
|
||||
uint32_t num_col_align_f16 = (num_col_2 + REPEAT_TIME_128 - 1) / REPEAT_TIME_128 * REPEAT_TIME_128;
|
||||
uint32_t num_col_align_f32 = (num_col_2 + REPEAT_TIME_64 - 1) / REPEAT_TIME_64 * REPEAT_TIME_64;
|
||||
AscendC::LocalTensor<InDtype> input_tensor = buf.GetBuffer<BufferType::ASCEND_UB, InDtype>(0);
|
||||
AscendC::LocalTensor<InDtype> gamma_tensor = buf.GetBuffer<BufferType::ASCEND_UB, InDtype>(MM1_OUT_SIZE * 2);
|
||||
AscendC::LocalTensor<InDtype> gamma_tensor = buf.GetBuffer<BufferType::ASCEND_UB, InDtype>(mm1OutSize_ * 2);
|
||||
AscendC::LocalTensor<InDtype> beta_tensor =
|
||||
buf.GetBuffer<BufferType::ASCEND_UB, InDtype>(MM1_OUT_SIZE * 2 + SPLIT_SIZE_TWO * 2);
|
||||
buf.GetBuffer<BufferType::ASCEND_UB, InDtype>(mm1OutSize_ * 2 + splitSizeTwo_ * 2);
|
||||
AscendC::LocalTensor<InDtype> scale_tensor =
|
||||
buf.GetBuffer<BufferType::ASCEND_UB, InDtype>(MM1_OUT_SIZE * 2 + SPLIT_SIZE_TWO * 2 + SPLIT_SIZE_TWO * 2);
|
||||
buf.GetBuffer<BufferType::ASCEND_UB, InDtype>(mm1OutSize_ * 2 + splitSizeTwo_ * 2 + splitSizeTwo_ * 2);
|
||||
AscendC::LocalTensor<int8_t> offset_tensor = buf.GetBuffer<BufferType::ASCEND_UB, int8_t>(
|
||||
MM1_OUT_SIZE * 2 + SPLIT_SIZE_TWO * 2 + SPLIT_SIZE_TWO * 2 + 32);
|
||||
mm1OutSize_ * 2 + splitSizeTwo_ * 2 + splitSizeTwo_ * 2 + 32);
|
||||
AscendC::LocalTensor<float> res1_tensor = buf.GetBuffer<BufferType::ASCEND_UB, float>(
|
||||
MM1_OUT_SIZE * 2 + SPLIT_SIZE_TWO * 2 + SPLIT_SIZE_TWO * 2 + 64);
|
||||
mm1OutSize_ * 2 + splitSizeTwo_ * 2 + splitSizeTwo_ * 2 + 64);
|
||||
AscendC::LocalTensor<float> res3_tensor = buf.GetBuffer<BufferType::ASCEND_UB, float>(
|
||||
MM1_OUT_SIZE * 2 + SPLIT_SIZE_TWO * 2 + SPLIT_SIZE_TWO * 2 + 64 + num_col_align_f32 * 4);
|
||||
mm1OutSize_ * 2 + splitSizeTwo_ * 2 + splitSizeTwo_ * 2 + 64 + num_col_align_f32 * 4);
|
||||
AscendC::LocalTensor<int8_t> output_tensor = buf.GetBuffer<BufferType::ASCEND_UB, int8_t>(
|
||||
MM1_OUT_SIZE * 2 + SPLIT_SIZE_TWO * 2 + SPLIT_SIZE_TWO * 2 + 64 + num_col_align_f32 * 4 +
|
||||
BUF_FACTOR * num_col_align_f32 * 4 + 64 + MM1_OUT_SIZE * 4 * 2 + 32);
|
||||
mm1OutSize_ * 2 + splitSizeTwo_ * 2 + splitSizeTwo_ * 2 + 64 + num_col_align_f32 * 4 +
|
||||
BUF_FACTOR * num_col_align_f32 * 4 + 64 + mm1OutSize_ * 4 * 2 + 32);
|
||||
rmsNormQuant2.Launch(output_tensor, input_tensor, gamma_tensor, beta_tensor, scale_tensor, offset_tensor,
|
||||
res1_tensor, res3_tensor);
|
||||
}
|
||||
@@ -2868,20 +2896,20 @@ MLAOperation<InDtype, CACHE_MODE, weightFormat1, weightFormat2, weightFormat3, q
|
||||
|
||||
if (row_work_ != 0) {
|
||||
AscendC::LocalTensor<InDtype> input_tensor = buf.GetBuffer<BufferType::ASCEND_UB, InDtype>(0);
|
||||
AscendC::LocalTensor<InDtype> gamma_tensor = buf.GetBuffer<BufferType::ASCEND_UB, InDtype>(MM1_OUT_SIZE * 2);
|
||||
AscendC::LocalTensor<InDtype> gamma_tensor = buf.GetBuffer<BufferType::ASCEND_UB, InDtype>(mm1OutSize_ * 2);
|
||||
AscendC::LocalTensor<InDtype> sin_tensor =
|
||||
buf.GetBuffer<BufferType::ASCEND_UB, InDtype>(MM1_OUT_SIZE * 2 + SPLIT_RMSNRORM_SIZE_ONE * 2);
|
||||
buf.GetBuffer<BufferType::ASCEND_UB, InDtype>(mm1OutSize_ * 2 + splitRmsNormSizeOne_ * 2);
|
||||
AscendC::LocalTensor<InDtype> cos_tensor = buf.GetBuffer<BufferType::ASCEND_UB, InDtype>(
|
||||
MM1_OUT_SIZE * 2 + SPLIT_RMSNRORM_SIZE_ONE * 2 + SPLIT_RMSNRORM_SIZE_TWO * 2);
|
||||
mm1OutSize_ * 2 + splitRmsNormSizeOne_ * 2 + splitRmsNormSizeTwo_ * 2);
|
||||
AscendC::LocalTensor<int32_t> slotMapping_tensor = buf.GetBuffer<BufferType::ASCEND_UB, int32_t>(
|
||||
MM1_OUT_SIZE * 2 + SPLIT_RMSNRORM_SIZE_ONE * 2 + SPLIT_RMSNRORM_SIZE_TWO * 4);
|
||||
mm1OutSize_ * 2 + splitRmsNormSizeOne_ * 2 + splitRmsNormSizeTwo_ * 4);
|
||||
int32_t rms3_ub_offset =
|
||||
MM1_OUT_SIZE * 2 + SPLIT_RMSNRORM_SIZE_ONE * 2 + SPLIT_RMSNRORM_SIZE_TWO * 4 + 4096 * 32;
|
||||
mm1OutSize_ * 2 + splitRmsNormSizeOne_ * 2 + splitRmsNormSizeTwo_ * 4 + 4096 * 32;
|
||||
AscendC::LocalTensor<float> tmp32_tensor = buf.GetBuffer<BufferType::ASCEND_UB, float>(rms3_ub_offset);
|
||||
|
||||
int32_t out_ub_offset = MM1_OUT_SIZE * 2 + SPLIT_RMSNRORM_SIZE_ONE * 2 + SPLIT_RMSNRORM_SIZE_TWO * 4 +
|
||||
4096 * 32 + SPLIT_RMSNRORM_SIZE_ONE * 3 * 4 + SPLIT_RMSNRORM_SIZE_TWO * 2 * 4 +
|
||||
MM1_OUT_SIZE * 4 * 2 + 32;
|
||||
int32_t out_ub_offset = mm1OutSize_ * 2 + splitRmsNormSizeOne_ * 2 + splitRmsNormSizeTwo_ * 4 +
|
||||
4096 * 32 + splitRmsNormSizeOne_ * 3 * 4 + splitRmsNormSizeTwo_ * 2 * 4 +
|
||||
mm1OutSize_ * 4 * 2 + 32;
|
||||
AscendC::LocalTensor<InDtype> temp_tensor = buf.GetBuffer<BufferType::ASCEND_UB, InDtype>(out_ub_offset);
|
||||
|
||||
AscendC::LocalTensor<half> tmpfp16;
|
||||
@@ -2895,7 +2923,7 @@ MLAOperation<InDtype, CACHE_MODE, weightFormat1, weightFormat2, weightFormat3, q
|
||||
buf.GetBuffer<BufferType::ASCEND_UB, float>(rms3_ub_offset + 32);
|
||||
// int8out
|
||||
tmpfp16 = buf.GetBuffer<BufferType::ASCEND_UB, half>(rms3_ub_offset +
|
||||
SPLIT_RMSNRORM_SIZE_ONE * sizeof(float) * 2);
|
||||
splitRmsNormSizeOne_ * sizeof(float) * 2);
|
||||
int8OutTensor = buf.GetBuffer<BufferType::ASCEND_UB, int8_t>(out_ub_offset);
|
||||
AscendC::DataCopy(quantScaleTensor, quantScale3GmTensor, AscendC::DataCopyParams(1, 1, 0, 0));
|
||||
SET_FLAG(MTE2, V, EVENT_ID1);
|
||||
@@ -2912,11 +2940,11 @@ MLAOperation<InDtype, CACHE_MODE, weightFormat1, weightFormat2, weightFormat3, q
|
||||
sin_tensor, // sin
|
||||
cos_tensor, // cons
|
||||
slotMapping_tensor, // slotMapping
|
||||
row_work_, tmp32_tensor, tmp32_tensor[SPLIT_RMSNRORM_SIZE_ONE],
|
||||
tmp32_tensor[SPLIT_RMSNRORM_SIZE_ONE + SPLIT_RMSNRORM_SIZE_ONE],
|
||||
tmp32_tensor[SPLIT_RMSNRORM_SIZE_ONE + SPLIT_RMSNRORM_SIZE_ONE + SPLIT_RMSNRORM_SIZE_TWO],
|
||||
tmp32_tensor[SPLIT_RMSNRORM_SIZE_ONE + SPLIT_RMSNRORM_SIZE_ONE + SPLIT_RMSNRORM_SIZE_TWO +
|
||||
SPLIT_RMSNRORM_SIZE_TWO],
|
||||
row_work_, tmp32_tensor, tmp32_tensor[splitRmsNormSizeOne_],
|
||||
tmp32_tensor[splitRmsNormSizeOne_ + splitRmsNormSizeOne_],
|
||||
tmp32_tensor[splitRmsNormSizeOne_ + splitRmsNormSizeOne_ + splitRmsNormSizeTwo_],
|
||||
tmp32_tensor[splitRmsNormSizeOne_ + splitRmsNormSizeOne_ + splitRmsNormSizeTwo_ +
|
||||
splitRmsNormSizeTwo_],
|
||||
temp_tensor, tmpfp16, int8OutTensor, scale3);
|
||||
}
|
||||
mm_w8a8_aiv_2.Process();
|
||||
|
||||
@@ -56,6 +56,8 @@ public:
|
||||
lastCoreLoopTime = ropeConcatParams.lastCoreLoopTime;
|
||||
lastCoreLoopNLast = ropeConcatParams.lastCoreLoopNLast;
|
||||
concatSize = ropeConcatParams.concatSize;
|
||||
hiddenStrideRope_ = ropeConcatParams.hiddenStrideRope;
|
||||
qkNopeHeadDim_ = ropeConcatParams.qkNopeHeadDim;
|
||||
blockIdx_ = (blockIdx_ / 2) * 2 + static_cast<uint64_t>(GetSubBlockidx());
|
||||
loopTime = (blockIdx_ == realCore - 1) ? lastCoreLoopTime : preCoreLoopTime;
|
||||
lastLoopN = (blockIdx_ == realCore - 1) ? lastCoreLoopNLast : preCoreLoopNLast;
|
||||
@@ -92,7 +94,7 @@ public:
|
||||
AscendC::LocalTensor<float> inputQCastFP32 = buf.GetBuffer<BufferType::ASCEND_UB, float>(dataSizeFp16);
|
||||
AscendC::LocalTensor<float> reverseQ =
|
||||
buf.GetBuffer<BufferType::ASCEND_UB, float>(dataSizeFp32 + dataSizeFp16);
|
||||
uint64_t qOffset = startHead * 192 + 128;
|
||||
uint64_t qOffset = startHead * hiddenStrideRope_ + qkNopeHeadDim_;
|
||||
CopyQGenReverseQ(inputQ, inputQCastFP32, reverseQ, qOffset, loopN);
|
||||
|
||||
// move in cos/sin
|
||||
@@ -184,7 +186,7 @@ public:
|
||||
WAIT_FLAG(S, MTE2, EVENT_ID1);
|
||||
WAIT_FLAG(MTE3, MTE2, EVENT_ID1);
|
||||
// move in Q
|
||||
AscendC::DataCopy(tempBufQ, this->qGm_[qOffset], {loopN, headBlockLen, 128 / 16, 0});
|
||||
AscendC::DataCopy(tempBufQ, this->qGm_[qOffset], {loopN, headBlockLen, static_cast<uint16_t>(qkNopeHeadDim_ / 16), 0});
|
||||
SET_FLAG(MTE2, V, EVENT_ID1);
|
||||
WAIT_FLAG(MTE2, V, EVENT_ID1);
|
||||
// cast fp32
|
||||
@@ -238,6 +240,8 @@ private:
|
||||
uint32_t lastCoreLoopTime;
|
||||
uint32_t lastCoreLoopNLast;
|
||||
uint32_t concatSize;
|
||||
uint32_t hiddenStrideRope_;
|
||||
uint32_t qkNopeHeadDim_;
|
||||
uint32_t blockIdx_;
|
||||
uint32_t loopTime{0}; // The number of current data rounds
|
||||
uint32_t lastLoopN{0}; // The number of lines currently processed by tails kernel
|
||||
@@ -2035,6 +2039,15 @@ public:
|
||||
this->epsilon_ = 1e-6;
|
||||
this->mlaParams = mlaParams_;
|
||||
this->hiddenStateDim = mlaParams_.hiddenStateDim;
|
||||
this->mm1OutSize_ = mlaParams_.mm1OutSize;
|
||||
this->splitSizeOne_ = mlaParams_.splitSizeOne;
|
||||
this->splitSizeTwo_ = mlaParams_.splitSizeTwo;
|
||||
this->splitRmsNormSizeOne_ = mlaParams_.splitRmsNormSizeOne;
|
||||
this->splitRmsNormSizeTwo_ = mlaParams_.splitRmsNormSizeTwo;
|
||||
this->ropeSplitSizeOne_ = mlaParams_.ropeSplitSizeOne;
|
||||
this->ropeSplitSizeTwo_ = mlaParams_.ropeSplitSizeTwo;
|
||||
this->hiddenStrideRope_ = mlaParams_.hiddenStrideRope;
|
||||
this->qkNopeHeadDim_ = mlaParams_.qkNopeHeadDim;
|
||||
}
|
||||
|
||||
__aicore__ inline void Init(GM_ADDR hiddenStateGm, GM_ADDR quantScale1Gm,
|
||||
@@ -2109,9 +2122,9 @@ public:
|
||||
vectorBlockIdx * static_cast<uint64_t>(row_work) * num_col_1, row_work_, mlaParams);
|
||||
|
||||
rmsNormQuant2.Init(gamma2GmTensor, beta2GmTensor, quantScale2GmTensor, quantOffset2GmTensor, s3GmTensor,
|
||||
s1GmTensor, SPLIT_SIZE_ONE, num_col_2, 0.000651041666,
|
||||
s1GmTensor, splitSizeOne_, num_col_2, mlaParams.avgFactor,
|
||||
vectorBlockIdx * static_cast<uint64_t>(row_work) * num_col_2,
|
||||
vectorBlockIdx * static_cast<uint64_t>(row_work) * SPLIT_SIZE_TWO, row_work_, mlaParams);
|
||||
vectorBlockIdx * static_cast<uint64_t>(row_work) * splitSizeTwo_, row_work_, mlaParams);
|
||||
ropeFp16.RopeInit(s2GmTensor, cos2GmTensor, sin2GmTensor, qGmTensor, qGmTensor2, mlaParams);
|
||||
einSumQuant.Init(s1Gm, gmQnopeScale, qGm, mlaParams);
|
||||
ubTensor = buf.GetBuffer<BufferType::ASCEND_UB, half>(0);
|
||||
@@ -2125,6 +2138,17 @@ public:
|
||||
__aicore__ inline void ProcessVector();
|
||||
|
||||
private:
|
||||
// Model-specific MLA dimensions from tiling data
|
||||
uint32_t mm1OutSize_;
|
||||
uint32_t splitSizeOne_;
|
||||
uint32_t splitSizeTwo_;
|
||||
uint32_t splitRmsNormSizeOne_;
|
||||
uint32_t splitRmsNormSizeTwo_;
|
||||
uint32_t ropeSplitSizeOne_;
|
||||
uint32_t ropeSplitSizeTwo_;
|
||||
uint32_t hiddenStrideRope_;
|
||||
uint32_t qkNopeHeadDim_;
|
||||
|
||||
constexpr static uint32_t C0_SIZE = 16;
|
||||
constexpr static uint32_t I8_C0_SIZE = 32;
|
||||
|
||||
@@ -2139,10 +2163,10 @@ private:
|
||||
AscendC::LocalTensor<half> &tmpfp16, AscendC::LocalTensor<int8_t> &int8OutTensor, float quantScale3)
|
||||
{
|
||||
int64_t slotMapGmOffset = vectorBlockIdx * row_work;
|
||||
AscendC::DataCopy(gammaTensor, gamma3GmTensor, SPLIT_RMSNRORM_SIZE_ONE);
|
||||
AscendC::DataCopy(gammaTensor, gamma3GmTensor, splitRmsNormSizeOne_);
|
||||
SET_FLAG(MTE2, V, EVENT_ID1);
|
||||
WAIT_FLAG(MTE2, V, EVENT_ID1);
|
||||
Cast(gammaFp32, gammaTensor, AscendC::RoundMode::CAST_NONE, SPLIT_RMSNRORM_SIZE_ONE);
|
||||
Cast(gammaFp32, gammaTensor, AscendC::RoundMode::CAST_NONE, splitRmsNormSizeOne_);
|
||||
AscendC::DataCopyPad(slotMappingTensor, slotMappingGmTensor[slotMapGmOffset],
|
||||
AscendC::DataCopyExtParams(1, sN * sizeof(int32_t), 0, 0, 0),
|
||||
AscendC::DataCopyPadExtParams<int32_t>(false, 0, 8 - sN % 8, 0));
|
||||
@@ -2151,134 +2175,134 @@ private:
|
||||
SET_FLAG(MTE2, S, EVENT_ID2);
|
||||
WAIT_FLAG(MTE2, S, EVENT_ID2);
|
||||
for (uint64_t loop = 0; loop < sN; ++loop) {
|
||||
uint64_t offset = vectorBlockIdx * static_cast<uint64_t>(row_work) * num_col_2 + loop * MM1_OUT_SIZE;
|
||||
uint64_t offset = vectorBlockIdx * static_cast<uint64_t>(row_work) * num_col_2 + loop * mm1OutSize_;
|
||||
int64_t slotValue = static_cast<int64_t>(slotMappingTensor.GetValue(loop));
|
||||
if (slotValue == -1) {
|
||||
continue;
|
||||
}
|
||||
AscendC::DataCopy(srcTensor, s3GmTensor[offset], SPLIT_SIZE_ONE);
|
||||
AscendC::DataCopy(sinTensor, sin1GmTensor[(row_work * vectorBlockIdx + loop) * SPLIT_RMSNRORM_SIZE_TWO],
|
||||
SPLIT_RMSNRORM_SIZE_TWO);
|
||||
AscendC::DataCopy(cosTensor, cos1GmTensor[(row_work * vectorBlockIdx + loop) * SPLIT_RMSNRORM_SIZE_TWO],
|
||||
SPLIT_RMSNRORM_SIZE_TWO);
|
||||
AscendC::DataCopy(srcTensor, s3GmTensor[offset], splitSizeOne_);
|
||||
AscendC::DataCopy(sinTensor, sin1GmTensor[(row_work * vectorBlockIdx + loop) * splitRmsNormSizeTwo_],
|
||||
splitRmsNormSizeTwo_);
|
||||
AscendC::DataCopy(cosTensor, cos1GmTensor[(row_work * vectorBlockIdx + loop) * splitRmsNormSizeTwo_],
|
||||
splitRmsNormSizeTwo_);
|
||||
SET_FLAG(MTE2, V, EVENT_ID0);
|
||||
// ND
|
||||
uint64_t cacheStart = static_cast<uint64_t>(slotValue) * static_cast<uint64_t>(SPLIT_SIZE_ONE);
|
||||
uint64_t cacheStart1 = static_cast<uint64_t>(slotValue) * static_cast<uint64_t>(SPLIT_RMSNRORM_SIZE_ONE);
|
||||
uint64_t cacheStart2 = static_cast<uint64_t>(slotValue) * static_cast<uint64_t>(SPLIT_RMSNRORM_SIZE_TWO);
|
||||
uint64_t cacheStart = static_cast<uint64_t>(slotValue) * static_cast<uint64_t>(splitSizeOne_);
|
||||
uint64_t cacheStart1 = static_cast<uint64_t>(slotValue) * static_cast<uint64_t>(splitRmsNormSizeOne_);
|
||||
uint64_t cacheStart2 = static_cast<uint64_t>(slotValue) * static_cast<uint64_t>(splitRmsNormSizeTwo_);
|
||||
// NZ
|
||||
uint32_t outer_idx = slotValue / 128;
|
||||
uint32_t inner_idx = slotValue % 128;
|
||||
SET_FLAG(S, MTE3, EVENT_ID0);
|
||||
/* RmsNorm start */
|
||||
WAIT_FLAG(MTE2, V, EVENT_ID0);
|
||||
Cast(rmsNormTensor, srcTensor, AscendC::RoundMode::CAST_NONE, SPLIT_RMSNRORM_SIZE_ONE);
|
||||
Cast(rmsNormTensor, srcTensor, AscendC::RoundMode::CAST_NONE, splitRmsNormSizeOne_);
|
||||
AscendC::PipeBarrier<PIPE_V>();
|
||||
Mul(calTensor, rmsNormTensor, rmsNormTensor, SPLIT_RMSNRORM_SIZE_ONE);
|
||||
Mul(calTensor, rmsNormTensor, rmsNormTensor, splitRmsNormSizeOne_);
|
||||
AscendC::PipeBarrier<PIPE_V>();
|
||||
ReduceSumCustom(calTensor[SPLIT_RMSNRORM_SIZE_ONE], calTensor, calTensor[SPLIT_RMSNRORM_SIZE_ONE * 2],
|
||||
SPLIT_RMSNRORM_SIZE_ONE);
|
||||
ReduceSumCustom(calTensor[splitRmsNormSizeOne_], calTensor, calTensor[splitRmsNormSizeOne_ * 2],
|
||||
splitRmsNormSizeOne_);
|
||||
SET_FLAG(V, S, EVENT_ID1);
|
||||
WAIT_FLAG(V, S, EVENT_ID1);
|
||||
float rms = sqrt(calTensor.GetValue(SPLIT_RMSNRORM_SIZE_ONE) / SPLIT_RMSNRORM_SIZE_ONE + epsilon_);
|
||||
float rms = sqrt(calTensor.GetValue(splitRmsNormSizeOne_) / splitRmsNormSizeOne_ + epsilon_);
|
||||
SET_FLAG(S, V, EVENT_ID1);
|
||||
WAIT_FLAG(S, V, EVENT_ID1);
|
||||
AscendC::PipeBarrier<PIPE_V>();
|
||||
Duplicate(calTensor, rms, SPLIT_RMSNRORM_SIZE_ONE);
|
||||
Duplicate(calTensor, rms, splitRmsNormSizeOne_);
|
||||
AscendC::PipeBarrier<PIPE_V>();
|
||||
Div(calTensor, rmsNormTensor, calTensor, SPLIT_RMSNRORM_SIZE_ONE);
|
||||
Div(calTensor, rmsNormTensor, calTensor, splitRmsNormSizeOne_);
|
||||
AscendC::PipeBarrier<PIPE_V>();
|
||||
Mul(rmsNormTensor, gammaFp32, calTensor, SPLIT_RMSNRORM_SIZE_ONE);
|
||||
Mul(rmsNormTensor, gammaFp32, calTensor, splitRmsNormSizeOne_);
|
||||
AscendC::PipeBarrier<PIPE_V>();
|
||||
Cast(outTmpTensor, rmsNormTensor, AscendC::RoundMode::CAST_NONE, SPLIT_RMSNRORM_SIZE_ONE);
|
||||
Cast(outTmpTensor, rmsNormTensor, AscendC::RoundMode::CAST_NONE, splitRmsNormSizeOne_);
|
||||
AscendC::PipeBarrier<PIPE_V>();
|
||||
if constexpr (cacheMode == CACHE_MODE_INT8_NZCACHE) {
|
||||
// quant
|
||||
Muls(rmsNormTensor, rmsNormTensor, quantScale3, SPLIT_RMSNRORM_SIZE_ONE);
|
||||
Muls(rmsNormTensor, rmsNormTensor, quantScale3, splitRmsNormSizeOne_);
|
||||
AscendC::PipeBarrier<PIPE_V>();
|
||||
CastFrom32To16(tmpfp16, rmsNormTensor, SPLIT_RMSNRORM_SIZE_ONE);
|
||||
CastFrom32To16(tmpfp16, rmsNormTensor, splitRmsNormSizeOne_);
|
||||
AscendC::PipeBarrier<PIPE_V>();
|
||||
CastFromF16ToI8(int8OutTensor, tmpfp16, -128, SPLIT_RMSNRORM_SIZE_ONE);
|
||||
CastFromF16ToI8(int8OutTensor, tmpfp16, -128, splitRmsNormSizeOne_);
|
||||
AscendC::PipeBarrier<PIPE_V>();
|
||||
} else {
|
||||
AscendC::PipeBarrier<PIPE_V>();
|
||||
if (std::is_same<T1, __bf16>::value) {
|
||||
Cast(outTmpTensor, rmsNormTensor, AscendC::RoundMode::CAST_RINT, SPLIT_RMSNRORM_SIZE_ONE);
|
||||
Cast(outTmpTensor, rmsNormTensor, AscendC::RoundMode::CAST_RINT, splitRmsNormSizeOne_);
|
||||
} else {
|
||||
Cast(outTmpTensor, rmsNormTensor, AscendC::RoundMode::CAST_NONE, SPLIT_RMSNRORM_SIZE_ONE);
|
||||
Cast(outTmpTensor, rmsNormTensor, AscendC::RoundMode::CAST_NONE, splitRmsNormSizeOne_);
|
||||
}
|
||||
}
|
||||
/* RmsNorm end */
|
||||
// /* Rope K start */
|
||||
uint64_t revertOffset = SPLIT_RMSNRORM_SIZE_TWO / 2;
|
||||
Cast(ropeKTensor, srcTensor[SPLIT_RMSNRORM_SIZE_ONE], AscendC::RoundMode::CAST_NONE,
|
||||
SPLIT_RMSNRORM_SIZE_TWO);
|
||||
Cast(ropeKRevertTensor[revertOffset], srcTensor[SPLIT_RMSNRORM_SIZE_ONE], AscendC::RoundMode::CAST_NONE,
|
||||
uint64_t revertOffset = splitRmsNormSizeTwo_ / 2;
|
||||
Cast(ropeKTensor, srcTensor[splitRmsNormSizeOne_], AscendC::RoundMode::CAST_NONE,
|
||||
splitRmsNormSizeTwo_);
|
||||
Cast(ropeKRevertTensor[revertOffset], srcTensor[splitRmsNormSizeOne_], AscendC::RoundMode::CAST_NONE,
|
||||
revertOffset);
|
||||
Cast(ropeKRevertTensor, srcTensor[SPLIT_RMSNRORM_SIZE_ONE + revertOffset], AscendC::RoundMode::CAST_NONE,
|
||||
Cast(ropeKRevertTensor, srcTensor[splitRmsNormSizeOne_ + revertOffset], AscendC::RoundMode::CAST_NONE,
|
||||
revertOffset);
|
||||
Duplicate(calTensor, static_cast<float>(-1), revertOffset);
|
||||
Duplicate(calTensor[revertOffset], static_cast<float>(1), revertOffset);
|
||||
AscendC::PipeBarrier<PIPE_V>();
|
||||
Cast(calTensor[SPLIT_RMSNRORM_SIZE_TWO], cosTensor, AscendC::RoundMode::CAST_NONE, SPLIT_RMSNRORM_SIZE_TWO);
|
||||
Cast(calTensor[SPLIT_RMSNRORM_SIZE_TWO * 2], sinTensor, AscendC::RoundMode::CAST_NONE,
|
||||
SPLIT_RMSNRORM_SIZE_TWO);
|
||||
Cast(calTensor[splitRmsNormSizeTwo_], cosTensor, AscendC::RoundMode::CAST_NONE, splitRmsNormSizeTwo_);
|
||||
Cast(calTensor[splitRmsNormSizeTwo_ * 2], sinTensor, AscendC::RoundMode::CAST_NONE,
|
||||
splitRmsNormSizeTwo_);
|
||||
AscendC::PipeBarrier<PIPE_V>();
|
||||
Mul(ropeKTensor, calTensor[SPLIT_RMSNRORM_SIZE_TWO], ropeKTensor, SPLIT_RMSNRORM_SIZE_TWO);
|
||||
Mul(ropeKRevertTensor, calTensor[SPLIT_RMSNRORM_SIZE_TWO * 2], ropeKRevertTensor, SPLIT_RMSNRORM_SIZE_TWO);
|
||||
Mul(ropeKTensor, calTensor[splitRmsNormSizeTwo_], ropeKTensor, splitRmsNormSizeTwo_);
|
||||
Mul(ropeKRevertTensor, calTensor[splitRmsNormSizeTwo_ * 2], ropeKRevertTensor, splitRmsNormSizeTwo_);
|
||||
AscendC::PipeBarrier<PIPE_V>();
|
||||
Mul(ropeKRevertTensor, calTensor, ropeKRevertTensor, SPLIT_RMSNRORM_SIZE_TWO);
|
||||
Mul(ropeKRevertTensor, calTensor, ropeKRevertTensor, splitRmsNormSizeTwo_);
|
||||
AscendC::PipeBarrier<PIPE_V>();
|
||||
Add(ropeKRevertTensor, ropeKTensor, ropeKRevertTensor, SPLIT_RMSNRORM_SIZE_TWO);
|
||||
Add(ropeKRevertTensor, ropeKTensor, ropeKRevertTensor, splitRmsNormSizeTwo_);
|
||||
AscendC::PipeBarrier<PIPE_V>();
|
||||
Cast(outTmpTensor[SPLIT_RMSNRORM_SIZE_ONE], ropeKRevertTensor, AscendC::RoundMode::CAST_NONE,
|
||||
SPLIT_RMSNRORM_SIZE_TWO);
|
||||
Cast(outTmpTensor[splitRmsNormSizeOne_], ropeKRevertTensor, AscendC::RoundMode::CAST_NONE,
|
||||
splitRmsNormSizeTwo_);
|
||||
/* Rope K end */
|
||||
// reshapeAndcache
|
||||
SET_FLAG(V, MTE3, EVENT_ID0);
|
||||
WAIT_FLAG(V, MTE3, EVENT_ID0);
|
||||
WAIT_FLAG(S, MTE3, EVENT_ID0);
|
||||
if constexpr (cacheMode == CACHE_MODE_KVCACHE) {
|
||||
DataCopy(keycacheGmTensor1[cacheStart], outTmpTensor, SPLIT_SIZE_ONE);
|
||||
DataCopy(keycacheGmTensor1[cacheStart], outTmpTensor, splitSizeOne_);
|
||||
} else if constexpr (cacheMode == CACHE_MODE_INT8_NZCACHE) {
|
||||
// NZ
|
||||
int64_t cacheSatartI8Nz1 = outer_idx * 128 * 512 + inner_idx * I8_C0_SIZE;
|
||||
uint64_t cacheSatartNz2 = outer_idx * 128 * 64 + inner_idx * C0_SIZE;
|
||||
AscendC::DataCopyExtParams outExt;
|
||||
// nope:int8 nz
|
||||
outExt.blockCount = SPLIT_RMSNRORM_SIZE_ONE / I8_C0_SIZE;
|
||||
outExt.blockCount = splitRmsNormSizeOne_ / I8_C0_SIZE;
|
||||
outExt.blockLen = I8_C0_SIZE * sizeof(int8_t);
|
||||
outExt.srcStride = 0;
|
||||
outExt.dstStride = (128 * I8_C0_SIZE - I8_C0_SIZE) * sizeof(int8_t);
|
||||
DataCopyPad(keycacheGmTensor1[cacheSatartI8Nz1], int8OutTensor, outExt);
|
||||
// rope:T1 nz
|
||||
outExt.blockCount = SPLIT_RMSNRORM_SIZE_TWO / C0_SIZE;
|
||||
outExt.blockCount = splitRmsNormSizeTwo_ / C0_SIZE;
|
||||
outExt.blockLen = C0_SIZE * sizeof(T1);
|
||||
outExt.srcStride = 0;
|
||||
outExt.dstStride = (128 * C0_SIZE - C0_SIZE) * sizeof(T1);
|
||||
DataCopyPad(keycacheGmTensor2[cacheSatartNz2], outTmpTensor[SPLIT_RMSNRORM_SIZE_ONE], outExt);
|
||||
DataCopyPad(keycacheGmTensor2[cacheSatartNz2], outTmpTensor[splitRmsNormSizeOne_], outExt);
|
||||
} else if constexpr (cacheMode == CACHE_MODE_NZCACHE) {
|
||||
uint64_t cacheSatartNz1 = outer_idx * 128 * 512 + inner_idx * C0_SIZE;
|
||||
uint64_t cacheSatartNz2 = outer_idx * 128 * 64 + inner_idx * C0_SIZE;
|
||||
// nope:T1 nz
|
||||
AscendC::DataCopyExtParams outExt;
|
||||
outExt.blockCount = SPLIT_RMSNRORM_SIZE_ONE / C0_SIZE;
|
||||
outExt.blockCount = splitRmsNormSizeOne_ / C0_SIZE;
|
||||
outExt.blockLen = C0_SIZE * sizeof(T1);
|
||||
outExt.srcStride = 0;
|
||||
outExt.dstStride = (128 * C0_SIZE - C0_SIZE) * sizeof(T1);
|
||||
DataCopyPad(keycacheGmTensor1[cacheSatartNz1], outTmpTensor, outExt);
|
||||
// rope:T1 nz
|
||||
outExt.blockCount = SPLIT_RMSNRORM_SIZE_TWO / C0_SIZE;
|
||||
outExt.blockCount = splitRmsNormSizeTwo_ / C0_SIZE;
|
||||
outExt.blockLen = C0_SIZE * sizeof(T1);
|
||||
outExt.srcStride = 0;
|
||||
outExt.dstStride = (128 * C0_SIZE - C0_SIZE) * sizeof(T1);
|
||||
DataCopyPad(keycacheGmTensor2[cacheSatartNz2], outTmpTensor[SPLIT_RMSNRORM_SIZE_ONE], outExt);
|
||||
DataCopyPad(keycacheGmTensor2[cacheSatartNz2], outTmpTensor[splitRmsNormSizeOne_], outExt);
|
||||
} else {
|
||||
// keycache1
|
||||
DataCopy(keycacheGmTensor1[cacheStart1], outTmpTensor, SPLIT_RMSNRORM_SIZE_ONE);
|
||||
DataCopy(keycacheGmTensor1[cacheStart1], outTmpTensor, splitRmsNormSizeOne_);
|
||||
// keycache2
|
||||
DataCopy(keycacheGmTensor2[cacheStart2], outTmpTensor[SPLIT_RMSNRORM_SIZE_ONE],
|
||||
SPLIT_RMSNRORM_SIZE_TWO);
|
||||
DataCopy(keycacheGmTensor2[cacheStart2], outTmpTensor[splitRmsNormSizeOne_],
|
||||
splitRmsNormSizeTwo_);
|
||||
}
|
||||
SET_FLAG(MTE3, MTE2, EVENT_ID1);
|
||||
WAIT_FLAG(MTE3, MTE2, EVENT_ID1);
|
||||
@@ -2417,19 +2441,19 @@ __aicore__ inline void MLAOperation<cacheMode, weightFormat1, weightFormat2, wei
|
||||
uint32_t num_col_align_f16 = (num_col_2 + REPEAT_TIME_128 - 1) / REPEAT_TIME_128 * REPEAT_TIME_128;
|
||||
uint32_t num_col_align_f32 = (num_col_2 + REPEAT_TIME_64 - 1) / REPEAT_TIME_64 * REPEAT_TIME_64;
|
||||
AscendC::LocalTensor<half> input_tensor = buf.GetBuffer<BufferType::ASCEND_UB, half>(0);
|
||||
AscendC::LocalTensor<half> gamma_tensor = buf.GetBuffer<BufferType::ASCEND_UB, half>(MM1_OUT_SIZE * 2);
|
||||
AscendC::LocalTensor<half> gamma_tensor = buf.GetBuffer<BufferType::ASCEND_UB, half>(mm1OutSize_ * 2);
|
||||
AscendC::LocalTensor<half> beta_tensor =
|
||||
buf.GetBuffer<BufferType::ASCEND_UB, half>(MM1_OUT_SIZE * 2 + SPLIT_SIZE_TWO * 2);
|
||||
buf.GetBuffer<BufferType::ASCEND_UB, half>(mm1OutSize_ * 2 + splitSizeTwo_ * 2);
|
||||
AscendC::LocalTensor<half> scale_tensor =
|
||||
buf.GetBuffer<BufferType::ASCEND_UB, half>(MM1_OUT_SIZE * 2 + SPLIT_SIZE_TWO * 2 + SPLIT_SIZE_TWO * 2);
|
||||
buf.GetBuffer<BufferType::ASCEND_UB, half>(mm1OutSize_ * 2 + splitSizeTwo_ * 2 + splitSizeTwo_ * 2);
|
||||
AscendC::LocalTensor<int8_t> offset_tensor = buf.GetBuffer<BufferType::ASCEND_UB, int8_t>(
|
||||
MM1_OUT_SIZE * 2 + SPLIT_SIZE_TWO * 2 + SPLIT_SIZE_TWO * 2 + 32);
|
||||
mm1OutSize_ * 2 + splitSizeTwo_ * 2 + splitSizeTwo_ * 2 + 32);
|
||||
AscendC::LocalTensor<float> res1_tensor = buf.GetBuffer<BufferType::ASCEND_UB, float>(
|
||||
MM1_OUT_SIZE * 2 + SPLIT_SIZE_TWO * 2 + SPLIT_SIZE_TWO * 2 + 64);
|
||||
mm1OutSize_ * 2 + splitSizeTwo_ * 2 + splitSizeTwo_ * 2 + 64);
|
||||
AscendC::LocalTensor<float> res3_tensor = buf.GetBuffer<BufferType::ASCEND_UB, float>(
|
||||
MM1_OUT_SIZE * 2 + SPLIT_SIZE_TWO * 2 + SPLIT_SIZE_TWO * 2 + 64 + num_col_align_f32 * 4);
|
||||
mm1OutSize_ * 2 + splitSizeTwo_ * 2 + splitSizeTwo_ * 2 + 64 + num_col_align_f32 * 4);
|
||||
AscendC::LocalTensor<int8_t> output_tensor = buf.GetBuffer<BufferType::ASCEND_UB, int8_t>(
|
||||
MM1_OUT_SIZE * 2 + SPLIT_SIZE_TWO * 2 + SPLIT_SIZE_TWO * 2 + 64 + num_col_align_f32 * 4 +
|
||||
mm1OutSize_ * 2 + splitSizeTwo_ * 2 + splitSizeTwo_ * 2 + 64 + num_col_align_f32 * 4 +
|
||||
BUF_FACTOR * num_col_align_f32 * 4 + 32);
|
||||
rmsNormQuant2.Launch(output_tensor, input_tensor, gamma_tensor, beta_tensor, scale_tensor, offset_tensor,
|
||||
res1_tensor, res3_tensor);
|
||||
@@ -2440,19 +2464,19 @@ __aicore__ inline void MLAOperation<cacheMode, weightFormat1, weightFormat2, wei
|
||||
|
||||
if (row_work_ != 0) {
|
||||
AscendC::LocalTensor<half> input_tensor = buf.GetBuffer<BufferType::ASCEND_UB, half>(0);
|
||||
AscendC::LocalTensor<half> gamma_tensor = buf.GetBuffer<BufferType::ASCEND_UB, half>(MM1_OUT_SIZE * 2);
|
||||
AscendC::LocalTensor<half> gamma_tensor = buf.GetBuffer<BufferType::ASCEND_UB, half>(mm1OutSize_ * 2);
|
||||
AscendC::LocalTensor<half> sin_tensor =
|
||||
buf.GetBuffer<BufferType::ASCEND_UB, half>(MM1_OUT_SIZE * 2 + SPLIT_RMSNRORM_SIZE_ONE * 2);
|
||||
buf.GetBuffer<BufferType::ASCEND_UB, half>(mm1OutSize_ * 2 + splitRmsNormSizeOne_ * 2);
|
||||
AscendC::LocalTensor<half> cos_tensor = buf.GetBuffer<BufferType::ASCEND_UB, half>(
|
||||
MM1_OUT_SIZE * 2 + SPLIT_RMSNRORM_SIZE_ONE * 2 + SPLIT_RMSNRORM_SIZE_TWO * 2);
|
||||
mm1OutSize_ * 2 + splitRmsNormSizeOne_ * 2 + splitRmsNormSizeTwo_ * 2);
|
||||
AscendC::LocalTensor<int32_t> slotMapping_tensor = buf.GetBuffer<BufferType::ASCEND_UB, int32_t>(
|
||||
MM1_OUT_SIZE * 2 + SPLIT_RMSNRORM_SIZE_ONE * 2 + SPLIT_RMSNRORM_SIZE_TWO * 4);
|
||||
mm1OutSize_ * 2 + splitRmsNormSizeOne_ * 2 + splitRmsNormSizeTwo_ * 4);
|
||||
int32_t rms3_ub_offset =
|
||||
MM1_OUT_SIZE * 2 + SPLIT_RMSNRORM_SIZE_ONE * 2 + SPLIT_RMSNRORM_SIZE_TWO * 4 + 4096 * 32;
|
||||
mm1OutSize_ * 2 + splitRmsNormSizeOne_ * 2 + splitRmsNormSizeTwo_ * 4 + 4096 * 32;
|
||||
AscendC::LocalTensor<float> tmp32_tensor = buf.GetBuffer<BufferType::ASCEND_UB, float>(rms3_ub_offset);
|
||||
|
||||
int32_t out_ub_offset = MM1_OUT_SIZE * 2 + SPLIT_RMSNRORM_SIZE_ONE * 2 + SPLIT_RMSNRORM_SIZE_TWO * 4 +
|
||||
4096 * 32 + SPLIT_RMSNRORM_SIZE_ONE * 3 * 4 + SPLIT_RMSNRORM_SIZE_TWO * 2 * 4;
|
||||
int32_t out_ub_offset = mm1OutSize_ * 2 + splitRmsNormSizeOne_ * 2 + splitRmsNormSizeTwo_ * 4 +
|
||||
4096 * 32 + splitRmsNormSizeOne_ * 3 * 4 + splitRmsNormSizeTwo_ * 2 * 4;
|
||||
AscendC::LocalTensor<half> temp_tensor = buf.GetBuffer<BufferType::ASCEND_UB, half>(out_ub_offset);
|
||||
|
||||
AscendC::LocalTensor<half> tmpfp16;
|
||||
@@ -2465,7 +2489,7 @@ __aicore__ inline void MLAOperation<cacheMode, weightFormat1, weightFormat2, wei
|
||||
buf.GetBuffer<BufferType::ASCEND_UB, float>(rms3_ub_offset + 32);
|
||||
// int8out
|
||||
tmpfp16 = buf.GetBuffer<BufferType::ASCEND_UB, half>(rms3_ub_offset +
|
||||
SPLIT_RMSNRORM_SIZE_ONE * sizeof(float) * 2);
|
||||
splitRmsNormSizeOne_ * sizeof(float) * 2);
|
||||
int8OutTensor = buf.GetBuffer<BufferType::ASCEND_UB, int8_t>(out_ub_offset);
|
||||
AscendC::DataCopy(quantScaleTensor, quantScale3GmTensor, AscendC::DataCopyParams(1, 1, 0, 0));
|
||||
SET_FLAG(MTE2, V, EVENT_ID1);
|
||||
@@ -2482,11 +2506,11 @@ __aicore__ inline void MLAOperation<cacheMode, weightFormat1, weightFormat2, wei
|
||||
sin_tensor, // sin
|
||||
cos_tensor, // cons
|
||||
slotMapping_tensor, // slotMapping
|
||||
row_work_, tmp32_tensor, tmp32_tensor[SPLIT_RMSNRORM_SIZE_ONE],
|
||||
tmp32_tensor[SPLIT_RMSNRORM_SIZE_ONE + SPLIT_RMSNRORM_SIZE_ONE],
|
||||
tmp32_tensor[SPLIT_RMSNRORM_SIZE_ONE + SPLIT_RMSNRORM_SIZE_ONE + SPLIT_RMSNRORM_SIZE_TWO],
|
||||
tmp32_tensor[SPLIT_RMSNRORM_SIZE_ONE + SPLIT_RMSNRORM_SIZE_ONE + SPLIT_RMSNRORM_SIZE_TWO +
|
||||
SPLIT_RMSNRORM_SIZE_TWO],
|
||||
row_work_, tmp32_tensor, tmp32_tensor[splitRmsNormSizeOne_],
|
||||
tmp32_tensor[splitRmsNormSizeOne_ + splitRmsNormSizeOne_],
|
||||
tmp32_tensor[splitRmsNormSizeOne_ + splitRmsNormSizeOne_ + splitRmsNormSizeTwo_],
|
||||
tmp32_tensor[splitRmsNormSizeOne_ + splitRmsNormSizeOne_ + splitRmsNormSizeTwo_ +
|
||||
splitRmsNormSizeTwo_],
|
||||
temp_tensor, tmpfp16, int8OutTensor, scale3);
|
||||
}
|
||||
WaitFlagDev(BMM3SPLIT);
|
||||
|
||||
Reference in New Issue
Block a user