[feat] mlapo add bf16 no_quant support (#4852)
### What this PR does / why we need it?
This PR adds mlapo operation support for bf16 no_quant mode.
### Does this PR introduce _any_ user-facing change?
This PR makes quant related parameters optional.
### How was this patch tested?
CI passed with new added/existing test.
- vLLM version: v0.12.0
- vLLM main:
ad32e3e19c
---------
Signed-off-by: chenjunyi <isjunyi.chen@gmail.com>
This commit is contained in:
@@ -43,7 +43,6 @@ constexpr uint32_t L1_BIAS_SIZE = 2048;
|
||||
constexpr uint32_t L0C_SIZE = 128 * 1024;
|
||||
constexpr uint32_t CONCAT_SIZE = 512;
|
||||
|
||||
constexpr uint32_t HIDDEN_STRATE = 7168;
|
||||
constexpr uint32_t HIDDEN_STRATE_ROPE = 192;
|
||||
constexpr uint32_t HIDDEN_STRATE_MM = 2112;
|
||||
constexpr uint32_t HIDDEN_STRATE_RMS = 1536;
|
||||
@@ -122,6 +121,8 @@ struct PlatformInfo {
|
||||
};
|
||||
|
||||
struct OpParam {
|
||||
uint32_t isWeightQuantized;
|
||||
uint32_t hiddenStateDim;
|
||||
uint32_t N;
|
||||
uint32_t headNum;
|
||||
int32_t cacheMode;
|
||||
@@ -392,7 +393,7 @@ private:
|
||||
void MlaPreprocessTiling::RmsNormQuantTiling()
|
||||
{
|
||||
tilingData->rmsNumCore1 = platformInfo.coreNumAiv;
|
||||
tilingData->rmsNumCol1 = HIDDEN_STRATE;
|
||||
tilingData->rmsNumCol1 = opParam.hiddenStateDim;
|
||||
tilingData->rmsNumRow1 = opParam.N;
|
||||
tilingData->rmsQuantMin1 = -CONST_128;
|
||||
tilingData->rmsNumCore2 = platformInfo.coreNumAiv;
|
||||
@@ -508,9 +509,9 @@ void MlaPreprocessTiling::EinSumQuantTiling()
|
||||
void MlaPreprocessTiling::SetMlapoWorkSpace()
|
||||
{
|
||||
uint64_t s1wsFactor =
|
||||
static_cast<uint64_t>(opParam.cacheMode == 2 ? std::max(HIDDEN_STRATE * sizeof(int8_t),
|
||||
static_cast<uint64_t>(opParam.cacheMode == 2 ? std::max(opParam.hiddenStateDim * sizeof(int8_t),
|
||||
opParam.headNum * AXES_ALIGN_SIZE * sizeof(uint16_t))
|
||||
: HIDDEN_STRATE * sizeof(int8_t));
|
||||
: opParam.hiddenStateDim * sizeof(int8_t));
|
||||
uint64_t workSizeS1 = s1wsFactor;
|
||||
uint64_t workSizeS2 = opParam.headNum * HIDDEN_STRATE_ROPE * sizeof(uint16_t);
|
||||
uint64_t workSizeS3 = HIDDEN_STRATE_MM * sizeof(uint16_t);
|
||||
@@ -525,7 +526,8 @@ void MlaPreprocessTiling::SetMlapoWorkSpace()
|
||||
uint64_t pertokenWorkspace = static_cast<uint64_t>(opParam.N) * sizeof(float) * 2;
|
||||
|
||||
uint64_t userWorkspaceSize;
|
||||
if (opParam.inDtype == at::kBFloat16 || opParam.quantMode == QuantMode::PER_TOKEN_SYMM_QUANT) {
|
||||
if (opParam.isWeightQuantized == 1 &&
|
||||
(opParam.inDtype == at::kBFloat16 || opParam.quantMode == QuantMode::PER_TOKEN_SYMM_QUANT)) {
|
||||
userWorkspaceSize = 4 * maxWorkspaceSize + pertokenWorkspace;
|
||||
} else {
|
||||
userWorkspaceSize = 3 * maxWorkspaceSize;
|
||||
@@ -554,21 +556,23 @@ void MlaPreprocessTiling::Init()
|
||||
{
|
||||
tilingData->numCore = platformInfo.coreNumAic;
|
||||
tilingData->n = opParam.N;
|
||||
|
||||
tilingData->hiddenStateDim = opParam.hiddenStateDim;
|
||||
tilingData->isWeightQuantized = opParam.isWeightQuantized;
|
||||
bool enDequant = (opParam.isWeightQuantized == 1);
|
||||
bool deqOnTheFly = false;
|
||||
if (opParam.inDtype == at::kBFloat16 || opParam.quantMode == QuantMode::PER_TOKEN_SYMM_QUANT) {
|
||||
if (enDequant && (opParam.inDtype == at::kBFloat16 || opParam.quantMode == QuantMode::PER_TOKEN_SYMM_QUANT)) {
|
||||
deqOnTheFly = true;
|
||||
}
|
||||
|
||||
PpMatmulTilingApi mm1TilingApi(platformInfo,
|
||||
1, // numBatch
|
||||
opParam.N, // m
|
||||
HIDDEN_STRATE, // k
|
||||
HIDDEN_STRATE_MM, // n
|
||||
false, // transA
|
||||
true, // transB
|
||||
true, // enDequant
|
||||
deqOnTheFly); // in bf16.cce?
|
||||
1, // numBatch
|
||||
opParam.N, // m
|
||||
opParam.hiddenStateDim, // k
|
||||
HIDDEN_STRATE_MM, // n
|
||||
false, // transA
|
||||
true, // transB
|
||||
enDequant, // enDequant
|
||||
deqOnTheFly); // in bf16.cce?
|
||||
mm1TilingApi.GetTilingData(tilingData->mm1);
|
||||
|
||||
PpMatmulTilingApi mm2TilingApi(platformInfo,
|
||||
@@ -578,7 +582,7 @@ void MlaPreprocessTiling::Init()
|
||||
opParam.headNum * HIDDEN_STRATE_ROPE, // n
|
||||
false, // transA
|
||||
true, // transB
|
||||
true, // enDequant
|
||||
enDequant, // enDequant
|
||||
deqOnTheFly); // in bf16.cce?
|
||||
mm2TilingApi.GetTilingData(tilingData->mm2);
|
||||
|
||||
@@ -609,6 +613,8 @@ std::unordered_map<c10::string_view, uint16_t> cache_mode_map = {
|
||||
std::unordered_map<c10::string_view, uint16_t> quant_mode_map = {
|
||||
{"per_tensor_quant_asymm", 0},
|
||||
{"per_token_quant_symm", 1},
|
||||
{"per_token_quant_asymm", 2},
|
||||
{"no_quant", 3}
|
||||
};
|
||||
|
||||
template <typename MapType>
|
||||
@@ -623,6 +629,7 @@ inline int get_op_mode(const MapType &mode_map, c10::optional<c10::string_view>
|
||||
|
||||
std::tuple<at::Tensor, at::Tensor, uint32_t> mla_preprocess_tiling(
|
||||
const at::Tensor &hiddenState,
|
||||
const at::Tensor &wdqkv,
|
||||
const at::Tensor &wuk,
|
||||
c10::optional<c10::string_view> cache_mode,
|
||||
c10::optional<c10::string_view> quant_mode,
|
||||
@@ -647,14 +654,21 @@ std::tuple<at::Tensor, at::Tensor, uint32_t> mla_preprocess_tiling(
|
||||
|
||||
int32_t N = hiddenState.sizes()[0];
|
||||
int32_t headNum = wuk.sizes()[0];
|
||||
uint32_t hiddenStateDim = hiddenState.sizes().back();
|
||||
|
||||
OpParam opParam;
|
||||
opParam.hiddenStateDim = hiddenStateDim;
|
||||
opParam.N = N;
|
||||
opParam.headNum = headNum;
|
||||
opParam.cacheMode = static_cast<int32_t>(cacheMode);
|
||||
opParam.quantMode = static_cast<QuantMode>(quantMode);
|
||||
opParam.inDtype = hiddenState.options().dtype();
|
||||
opParam.enableInnerOut = enable_inner_out;
|
||||
if (wdqkv.options().dtype() == at::kBFloat16 || wdqkv.options().dtype() == at::kHalf) {
|
||||
opParam.isWeightQuantized = 0;
|
||||
} else {
|
||||
opParam.isWeightQuantized = 1;
|
||||
}
|
||||
|
||||
MlaTilingData tilingData;
|
||||
MlaPreprocessTiling mlaTiling(platformInfo, opParam, &tilingData);
|
||||
|
||||
@@ -90,6 +90,11 @@ struct MlaTilingData {
|
||||
uint32_t esqHeadTail{0};
|
||||
uint32_t esqColLoop{0};
|
||||
uint32_t esqColTail{0};
|
||||
|
||||
// hidden state dimension
|
||||
uint32_t hiddenStateDim{7168};
|
||||
|
||||
uint32_t isWeightQuantized{1};
|
||||
};
|
||||
|
||||
#endif // MLAPREPROCESS_TILING_H
|
||||
|
||||
@@ -49,7 +49,6 @@ constexpr uint8_t CACHE_MODE_INT8_NZCACHE = 2; // high performance KV NZ format
|
||||
constexpr uint8_t CACHE_MODE_NZCACHE = 3;
|
||||
|
||||
// pp matmul
|
||||
constexpr uint32_t HIDDTEN_STATE = 7168;
|
||||
constexpr uint32_t FLOAT_BLOCK_SIZE = 64;
|
||||
constexpr uint32_t HALF_BLOCK_SIZE = 64;
|
||||
constexpr uint32_t HALF_VECTOR_SIZE = 64;
|
||||
@@ -103,6 +102,7 @@ constexpr uint32_t KEY_FP16_CACHEMODE_1_QUANTMODE_0 = 1;
|
||||
constexpr uint32_t KEY_BF16_CACHEMODE_0_QUANTMODE_0 = 256;
|
||||
constexpr uint32_t KEY_BF16_CACHEMODE_1_QUANTMODE_0 = 257;
|
||||
constexpr uint32_t KEY_BF16_CACHEMODE_3_QUANTMODE_0 = 259;
|
||||
constexpr uint32_t KEY_BF16_CACHEMODE_1_QUANTMODE_3 = 281;
|
||||
constexpr uint32_t KEY_BF16_CACHEMODE_0_QUANTMODE_0_INNER = 256 + 512;
|
||||
constexpr uint32_t KEY_BF16_CACHEMODE_1_QUANTMODE_0_INNER = 257 + 512;
|
||||
constexpr uint32_t KEY_BF16_CACHEMODE_3_QUANTMODE_0_INNER = 259 + 512;
|
||||
|
||||
@@ -16,6 +16,7 @@
|
||||
#include "mla_preprocess_mix_fp16.hpp"
|
||||
#include "mla_preprocess_mix_bf16.hpp"
|
||||
#include "mla_preprocess_mix_bf16_qdown.hpp"
|
||||
#include "mla_preprocess_mix_bf16_nq.hpp"
|
||||
|
||||
#include "../op_host/tiling/mla_preprocess_tiling.h"
|
||||
|
||||
@@ -42,6 +43,7 @@ extern "C" __global__ __aicore__ void mla_preprocess(
|
||||
|
||||
mlaTilingData.tilingKey = tilingData->tilingKey;
|
||||
mlaTilingData.n = tilingData->n;
|
||||
mlaTilingData.hiddenStateDim = tilingData->hiddenStateDim;
|
||||
|
||||
mlaTilingData.mm1.numBatch = tilingData->mm1.numBatch;
|
||||
mlaTilingData.mm1.m = tilingData->mm1.m;
|
||||
@@ -173,12 +175,12 @@ extern "C" __global__ __aicore__ void mla_preprocess(
|
||||
}
|
||||
case KEY_BF16_CACHEMODE_0_QUANTMODE_0: {
|
||||
MLAPO_BF16::MLAOperation<__bf16, 0, DataFormat::NZ, DataFormat::NZ, DataFormat::ND,
|
||||
QuantMode::PER_TENSOR_ASYMM_QUANT>
|
||||
QuantMode::PER_TENSOR_ASYMM_QUANT>
|
||||
opBf16Cm0Qm0(mlaTilingData, tiling);
|
||||
opBf16Cm0Qm0.Init(hiddenState, quantScale1, quantOffset1, wdqkv, bias1, gamma2, beta2,
|
||||
quantScale2, quantOffset2, gamma3, sin1, cos1, sin2, cos2, keycache, slotMapping, wuq,
|
||||
bias2, wuk, descale1, descale2, ctkvScale, qnopeScale, q, keycacheOut, q2, keycacheOut2,
|
||||
s1, s2, s3, s4, s5);
|
||||
quantScale2, quantOffset2, gamma3, sin1, cos1, sin2, cos2, keycache, slotMapping, wuq,
|
||||
bias2, wuk, descale1, descale2, ctkvScale, qnopeScale, q, keycacheOut, q2, keycacheOut2,
|
||||
s1, s2, s3, s4, s5);
|
||||
if ASCEND_IS_AIC {
|
||||
opBf16Cm0Qm0.ProcessCube();
|
||||
}
|
||||
@@ -189,12 +191,12 @@ extern "C" __global__ __aicore__ void mla_preprocess(
|
||||
}
|
||||
case KEY_BF16_CACHEMODE_1_QUANTMODE_0: {
|
||||
MLAPO_BF16::MLAOperation<__bf16, 1, DataFormat::NZ, DataFormat::NZ, DataFormat::ND,
|
||||
QuantMode::PER_TENSOR_ASYMM_QUANT>
|
||||
QuantMode::PER_TENSOR_ASYMM_QUANT>
|
||||
opBf16Cm1Qm0(mlaTilingData, tiling);
|
||||
opBf16Cm1Qm0.Init(hiddenState, quantScale1, quantOffset1, wdqkv, bias1, gamma2, beta2,
|
||||
quantScale2, quantOffset2, gamma3, sin1, cos1, sin2, cos2, keycache, slotMapping, wuq,
|
||||
bias2, wuk, descale1, descale2, ctkvScale, qnopeScale, q, keycacheOut, q2, keycacheOut2,
|
||||
s1, s2, s3, s4, s5);
|
||||
quantScale2, quantOffset2, gamma3, sin1, cos1, sin2, cos2, keycache, slotMapping, wuq,
|
||||
bias2, wuk, descale1, descale2, ctkvScale, qnopeScale, q, keycacheOut, q2, keycacheOut2,
|
||||
s1, s2, s3, s4, s5);
|
||||
if ASCEND_IS_AIC {
|
||||
opBf16Cm1Qm0.ProcessCube();
|
||||
}
|
||||
@@ -219,6 +221,21 @@ extern "C" __global__ __aicore__ void mla_preprocess(
|
||||
}
|
||||
break;
|
||||
}
|
||||
case KEY_BF16_CACHEMODE_1_QUANTMODE_3: {
|
||||
MLAPO_BF16_NQ::MLAOperation<__bf16, 1, DataFormat::NZ, DataFormat::NZ, DataFormat::ND>
|
||||
opBf16Cm1Qm0(mlaTilingData, tiling);
|
||||
opBf16Cm1Qm0.Init(hiddenState, wdqkv, gamma2, beta2,
|
||||
gamma3, sin1, cos1, sin2, cos2, keycache, slotMapping, wuq,
|
||||
wuk, q, keycacheOut, q2, keycacheOut2,
|
||||
s1, s2, s3);
|
||||
if ASCEND_IS_AIC {
|
||||
opBf16Cm1Qm0.ProcessCube();
|
||||
}
|
||||
if ASCEND_IS_AIV {
|
||||
opBf16Cm1Qm0.ProcessVector();
|
||||
}
|
||||
break;
|
||||
}
|
||||
case KEY_BF16_CACHEMODE_0_QUANTMODE_0_INNER: {
|
||||
MLAPO_BF16_INNER::MLAOperation<__bf16, 0, DataFormat::NZ, DataFormat::NZ, DataFormat::ND,
|
||||
QuantMode::PER_TENSOR_ASYMM_QUANT>
|
||||
|
||||
@@ -2386,6 +2386,7 @@ public:
|
||||
this->num_row = mlaParams_.n;
|
||||
this->epsilon_ = 1e-6;
|
||||
this->mlaParams = mlaParams_;
|
||||
this->hiddenStateDim = mlaParams_.hiddenStateDim;
|
||||
}
|
||||
|
||||
__aicore__ inline void Init(GM_ADDR hiddenStateGm, GM_ADDR quantScale1Gm,
|
||||
@@ -2692,6 +2693,7 @@ private:
|
||||
uint32_t blockOffset;
|
||||
uint32_t perTaskNum;
|
||||
uint32_t resTaskNum;
|
||||
uint32_t hiddenStateDim;
|
||||
MlaTilingData mlaParams;
|
||||
|
||||
uint32_t num_core_;
|
||||
@@ -2795,18 +2797,15 @@ MLAOperation<InDtype, CACHE_MODE, weightFormat1, weightFormat2, weightFormat3, q
|
||||
uint32_t num_col_align_int8 = (num_col_1 + REPEAT_TIME_256 - 1) / REPEAT_TIME_256 * REPEAT_TIME_256;
|
||||
uint32_t num_col_align_f16 = (num_col_1 + REPEAT_TIME_128 - 1) / REPEAT_TIME_128 * REPEAT_TIME_128;
|
||||
uint32_t num_col_align_f32 = (num_col_1 + REPEAT_TIME_64 - 1) / REPEAT_TIME_64 * REPEAT_TIME_64;
|
||||
const uint32_t base_offset = hiddenStateDim * 6;
|
||||
AscendC::LocalTensor<InDtype> input_tensor = buf.GetBuffer<BufferType::ASCEND_UB, InDtype>(0);
|
||||
AscendC::LocalTensor<InDtype> scale_tensor =
|
||||
buf.GetBuffer<BufferType::ASCEND_UB, InDtype>(HIDDTEN_STATE * 2 + HIDDTEN_STATE * 2 + HIDDTEN_STATE * 2);
|
||||
AscendC::LocalTensor<int8_t> offset_tensor = buf.GetBuffer<BufferType::ASCEND_UB, int8_t>(
|
||||
HIDDTEN_STATE * 2 + HIDDTEN_STATE * 2 + HIDDTEN_STATE * 2 + 32);
|
||||
AscendC::LocalTensor<float> res1_tensor =
|
||||
buf.GetBuffer<BufferType::ASCEND_UB, float>(HIDDTEN_STATE * 2 + HIDDTEN_STATE * 2 + HIDDTEN_STATE * 2 + 64);
|
||||
AscendC::LocalTensor<InDtype> scale_tensor = buf.GetBuffer<BufferType::ASCEND_UB, InDtype>(base_offset);
|
||||
AscendC::LocalTensor<int8_t> offset_tensor = buf.GetBuffer<BufferType::ASCEND_UB, int8_t>(base_offset + 32);
|
||||
AscendC::LocalTensor<float> res1_tensor = buf.GetBuffer<BufferType::ASCEND_UB, float>(base_offset + 64);
|
||||
AscendC::LocalTensor<float> res3_tensor = buf.GetBuffer<BufferType::ASCEND_UB, float>(
|
||||
HIDDTEN_STATE * 2 + HIDDTEN_STATE * 2 + HIDDTEN_STATE * 2 + 64 + num_col_align_f32 * 4);
|
||||
base_offset + 64 + num_col_align_f32 * 4);
|
||||
AscendC::LocalTensor<int8_t> output_tensor = buf.GetBuffer<BufferType::ASCEND_UB, int8_t>(
|
||||
HIDDTEN_STATE * 2 + HIDDTEN_STATE * 2 + HIDDTEN_STATE * 2 + 64 + num_col_align_f32 * 4 +
|
||||
BUF_FACTOR * num_col_align_f32 * 4 + 64);
|
||||
base_offset + 64 + num_col_align_f32 * 4 + BUF_FACTOR * num_col_align_f32 * 4 + 64);
|
||||
Quant1.Launch(output_tensor, input_tensor, scale_tensor, offset_tensor, res1_tensor, res3_tensor);
|
||||
}
|
||||
FftsCrossCoreSync<PIPE_MTE3, 0>(QUANT1);
|
||||
|
||||
1252
csrc/mla_preprocess/op_kernel/mla_preprocess_mix_bf16_nq.hpp
Normal file
1252
csrc/mla_preprocess/op_kernel/mla_preprocess_mix_bf16_nq.hpp
Normal file
File diff suppressed because it is too large
Load Diff
@@ -2406,6 +2406,7 @@ public:
|
||||
this->num_row = mlaParams_.n;
|
||||
this->epsilon_ = 1e-6;
|
||||
this->mlaParams = mlaParams_;
|
||||
this->hiddenStateDim = mlaParams_.hiddenStateDim;
|
||||
}
|
||||
|
||||
__aicore__ inline void Init(GM_ADDR hiddenStateGm, GM_ADDR quantScale1Gm,
|
||||
@@ -2713,6 +2714,7 @@ private:
|
||||
uint32_t blockOffset;
|
||||
uint32_t perTaskNum;
|
||||
uint32_t resTaskNum;
|
||||
uint32_t hiddenStateDim;
|
||||
MlaTilingData mlaParams;
|
||||
|
||||
uint32_t num_core_;
|
||||
@@ -2817,18 +2819,15 @@ MLAOperation<InDtype, CACHE_MODE, weightFormat1, weightFormat2, weightFormat3, q
|
||||
uint32_t num_col_align_int8 = (num_col_1 + REPEAT_TIME_256 - 1) / REPEAT_TIME_256 * REPEAT_TIME_256;
|
||||
uint32_t num_col_align_f16 = (num_col_1 + REPEAT_TIME_128 - 1) / REPEAT_TIME_128 * REPEAT_TIME_128;
|
||||
uint32_t num_col_align_f32 = (num_col_1 + REPEAT_TIME_64 - 1) / REPEAT_TIME_64 * REPEAT_TIME_64;
|
||||
const uint32_t base_offset = hiddenStateDim * 6;
|
||||
AscendC::LocalTensor<InDtype> input_tensor = buf.GetBuffer<BufferType::ASCEND_UB, InDtype>(0);
|
||||
AscendC::LocalTensor<InDtype> scale_tensor =
|
||||
buf.GetBuffer<BufferType::ASCEND_UB, InDtype>(HIDDTEN_STATE * 2 + HIDDTEN_STATE * 2 + HIDDTEN_STATE * 2);
|
||||
AscendC::LocalTensor<int8_t> offset_tensor = buf.GetBuffer<BufferType::ASCEND_UB, int8_t>(
|
||||
HIDDTEN_STATE * 2 + HIDDTEN_STATE * 2 + HIDDTEN_STATE * 2 + 32);
|
||||
AscendC::LocalTensor<float> res1_tensor =
|
||||
buf.GetBuffer<BufferType::ASCEND_UB, float>(HIDDTEN_STATE * 2 + HIDDTEN_STATE * 2 + HIDDTEN_STATE * 2 + 64);
|
||||
AscendC::LocalTensor<InDtype> scale_tensor = buf.GetBuffer<BufferType::ASCEND_UB, InDtype>(base_offset);
|
||||
AscendC::LocalTensor<int8_t> offset_tensor = buf.GetBuffer<BufferType::ASCEND_UB, int8_t>(base_offset + 32);
|
||||
AscendC::LocalTensor<float> res1_tensor = buf.GetBuffer<BufferType::ASCEND_UB, float>(base_offset + 64);
|
||||
AscendC::LocalTensor<float> res3_tensor = buf.GetBuffer<BufferType::ASCEND_UB, float>(
|
||||
HIDDTEN_STATE * 2 + HIDDTEN_STATE * 2 + HIDDTEN_STATE * 2 + 64 + num_col_align_f32 * 4);
|
||||
base_offset + 64 + num_col_align_f32 * 4);
|
||||
AscendC::LocalTensor<int8_t> output_tensor = buf.GetBuffer<BufferType::ASCEND_UB, int8_t>(
|
||||
HIDDTEN_STATE * 2 + HIDDTEN_STATE * 2 + HIDDTEN_STATE * 2 + 64 + num_col_align_f32 * 4 +
|
||||
BUF_FACTOR * num_col_align_f32 * 4 + 64);
|
||||
base_offset + 64 + num_col_align_f32 * 4 + BUF_FACTOR * num_col_align_f32 * 4 + 64);
|
||||
Quant1.Launch(output_tensor, input_tensor, scale_tensor, offset_tensor, res1_tensor, res3_tensor);
|
||||
}
|
||||
FftsCrossCoreSync<PIPE_MTE3, 0>(QUANT1);
|
||||
|
||||
@@ -2034,6 +2034,7 @@ public:
|
||||
this->num_row = mlaParams_.n;
|
||||
this->epsilon_ = 1e-6;
|
||||
this->mlaParams = mlaParams_;
|
||||
this->hiddenStateDim = mlaParams_.hiddenStateDim;
|
||||
}
|
||||
|
||||
__aicore__ inline void Init(GM_ADDR hiddenStateGm, GM_ADDR quantScale1Gm,
|
||||
@@ -2294,6 +2295,7 @@ private:
|
||||
uint32_t blockOffset;
|
||||
uint32_t perTaskNum;
|
||||
uint32_t resTaskNum;
|
||||
uint32_t hiddenStateDim;
|
||||
MlaTilingData mlaParams;
|
||||
|
||||
// rmsnormQuant
|
||||
@@ -2389,21 +2391,19 @@ __aicore__ inline void MLAOperation<cacheMode, weightFormat1, weightFormat2, wei
|
||||
uint32_t num_col_align_int8 = (num_col_1 + REPEAT_TIME_256 - 1) / REPEAT_TIME_256 * REPEAT_TIME_256;
|
||||
uint32_t num_col_align_f16 = (num_col_1 + REPEAT_TIME_128 - 1) / REPEAT_TIME_128 * REPEAT_TIME_128;
|
||||
uint32_t num_col_align_f32 = (num_col_1 + REPEAT_TIME_64 - 1) / REPEAT_TIME_64 * REPEAT_TIME_64;
|
||||
const uint32_t gamma_offset = hiddenStateDim * 2;
|
||||
const uint32_t beta_offset = gamma_offset + hiddenStateDim * 2;
|
||||
const uint32_t scale_offset = beta_offset + hiddenStateDim * 2;
|
||||
AscendC::LocalTensor<half> input_tensor = buf.GetBuffer<BufferType::ASCEND_UB, half>(0);
|
||||
AscendC::LocalTensor<half> gamma_tensor = buf.GetBuffer<BufferType::ASCEND_UB, half>(HIDDTEN_STATE * 2);
|
||||
AscendC::LocalTensor<half> beta_tensor =
|
||||
buf.GetBuffer<BufferType::ASCEND_UB, half>(HIDDTEN_STATE * 2 + HIDDTEN_STATE * 2);
|
||||
AscendC::LocalTensor<half> scale_tensor =
|
||||
buf.GetBuffer<BufferType::ASCEND_UB, half>(HIDDTEN_STATE * 2 + HIDDTEN_STATE * 2 + HIDDTEN_STATE * 2);
|
||||
AscendC::LocalTensor<int8_t> offset_tensor = buf.GetBuffer<BufferType::ASCEND_UB, int8_t>(
|
||||
HIDDTEN_STATE * 2 + HIDDTEN_STATE * 2 + HIDDTEN_STATE * 2 + 32);
|
||||
AscendC::LocalTensor<float> res1_tensor =
|
||||
buf.GetBuffer<BufferType::ASCEND_UB, float>(HIDDTEN_STATE * 2 + HIDDTEN_STATE * 2 + HIDDTEN_STATE * 2 + 64);
|
||||
AscendC::LocalTensor<half> gamma_tensor = buf.GetBuffer<BufferType::ASCEND_UB, half>(gamma_offset);
|
||||
AscendC::LocalTensor<half> beta_tensor = buf.GetBuffer<BufferType::ASCEND_UB, half>(beta_offset);
|
||||
AscendC::LocalTensor<half> scale_tensor = buf.GetBuffer<BufferType::ASCEND_UB, half>(scale_offset);
|
||||
AscendC::LocalTensor<int8_t> offset_tensor = buf.GetBuffer<BufferType::ASCEND_UB, int8_t>(scale_offset + 32);
|
||||
AscendC::LocalTensor<float> res1_tensor = buf.GetBuffer<BufferType::ASCEND_UB, float>(scale_offset + 64);
|
||||
AscendC::LocalTensor<float> res3_tensor = buf.GetBuffer<BufferType::ASCEND_UB, float>(
|
||||
HIDDTEN_STATE * 2 + HIDDTEN_STATE * 2 + HIDDTEN_STATE * 2 + 64 + num_col_align_f32 * 4);
|
||||
scale_offset + 64 + num_col_align_f32 * 4);
|
||||
AscendC::LocalTensor<int8_t> output_tensor = buf.GetBuffer<BufferType::ASCEND_UB, int8_t>(
|
||||
HIDDTEN_STATE * 2 + HIDDTEN_STATE * 2 + HIDDTEN_STATE * 2 + 64 + num_col_align_f32 * 4 +
|
||||
BUF_FACTOR * num_col_align_f32 * 4 + 32);
|
||||
scale_offset + 64 + num_col_align_f32 * 4 + BUF_FACTOR * num_col_align_f32 * 4 + 32);
|
||||
Quant1.Launch(output_tensor, input_tensor, gamma_tensor, beta_tensor, scale_tensor, offset_tensor, res1_tensor,
|
||||
res3_tensor);
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user