[feat] mlapo add bf16 no_quant support (#4852)
### What this PR does / why we need it?
This PR adds support for the bf16 no_quant mode to the mlapo operation.
### Does this PR introduce _any_ user-facing change?
Yes. This PR makes the quant-related parameters of the mlapo operation optional.
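As an illustration only (a minimal hypothetical sketch, not this PR's actual code): once the quant inputs are optional, the kernel side can branch on whether they were supplied, e.g. by checking the global-memory address.

```cpp
// Hypothetical sketch: quantScale1Gm mirrors the kernel's real Init parameter
// (see the diff below), but this dispatch helper is invented for illustration.
__aicore__ inline void InitNormPath(GM_ADDR quantScale1Gm)
{
    if (quantScale1Gm == nullptr) {
        // bf16 no_quant mode: no quant scale/offset was supplied, so run
        // plain RmsNorm in bf16 and skip the quantization step entirely.
    } else {
        // Quantized mode: load the quant scale from global memory as before.
    }
}
```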
### How was this patch tested?
CI passed with newly added and existing tests.
- vLLM version: v0.12.0
- vLLM main: ad32e3e19c
---------
Signed-off-by: chenjunyi <isjunyi.chen@gmail.com>
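The diff below makes three changes: the kernel constructor now captures `hiddenStateDim` from the tiling data, a matching `hiddenStateDim` member is added alongside the other per-task fields, and the unified-buffer offsets for the rmsnorm/quant tensors are derived from that runtime dimension instead of the compile-time `HIDDTEN_STATE` constant.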
@@ -2034,6 +2034,7 @@ public:
         this->num_row = mlaParams_.n;
         this->epsilon_ = 1e-6;
         this->mlaParams = mlaParams_;
+        this->hiddenStateDim = mlaParams_.hiddenStateDim;
     }
 
     __aicore__ inline void Init(GM_ADDR hiddenStateGm, GM_ADDR quantScale1Gm,
@@ -2294,6 +2295,7 @@ private:
     uint32_t blockOffset;
     uint32_t perTaskNum;
     uint32_t resTaskNum;
+    uint32_t hiddenStateDim;
     MlaTilingData mlaParams;
 
     // rmsnormQuant
@@ -2389,21 +2391,19 @@ __aicore__ inline void MLAOperation<cacheMode, weightFormat1, weightFormat2, wei
     uint32_t num_col_align_int8 = (num_col_1 + REPEAT_TIME_256 - 1) / REPEAT_TIME_256 * REPEAT_TIME_256;
     uint32_t num_col_align_f16 = (num_col_1 + REPEAT_TIME_128 - 1) / REPEAT_TIME_128 * REPEAT_TIME_128;
     uint32_t num_col_align_f32 = (num_col_1 + REPEAT_TIME_64 - 1) / REPEAT_TIME_64 * REPEAT_TIME_64;
+    const uint32_t gamma_offset = hiddenStateDim * 2;
+    const uint32_t beta_offset = gamma_offset + hiddenStateDim * 2;
+    const uint32_t scale_offset = beta_offset + hiddenStateDim * 2;
     AscendC::LocalTensor<half> input_tensor = buf.GetBuffer<BufferType::ASCEND_UB, half>(0);
-    AscendC::LocalTensor<half> gamma_tensor = buf.GetBuffer<BufferType::ASCEND_UB, half>(HIDDTEN_STATE * 2);
-    AscendC::LocalTensor<half> beta_tensor =
-        buf.GetBuffer<BufferType::ASCEND_UB, half>(HIDDTEN_STATE * 2 + HIDDTEN_STATE * 2);
-    AscendC::LocalTensor<half> scale_tensor =
-        buf.GetBuffer<BufferType::ASCEND_UB, half>(HIDDTEN_STATE * 2 + HIDDTEN_STATE * 2 + HIDDTEN_STATE * 2);
-    AscendC::LocalTensor<int8_t> offset_tensor = buf.GetBuffer<BufferType::ASCEND_UB, int8_t>(
-        HIDDTEN_STATE * 2 + HIDDTEN_STATE * 2 + HIDDTEN_STATE * 2 + 32);
-    AscendC::LocalTensor<float> res1_tensor =
-        buf.GetBuffer<BufferType::ASCEND_UB, float>(HIDDTEN_STATE * 2 + HIDDTEN_STATE * 2 + HIDDTEN_STATE * 2 + 64);
+    AscendC::LocalTensor<half> gamma_tensor = buf.GetBuffer<BufferType::ASCEND_UB, half>(gamma_offset);
+    AscendC::LocalTensor<half> beta_tensor = buf.GetBuffer<BufferType::ASCEND_UB, half>(beta_offset);
+    AscendC::LocalTensor<half> scale_tensor = buf.GetBuffer<BufferType::ASCEND_UB, half>(scale_offset);
+    AscendC::LocalTensor<int8_t> offset_tensor = buf.GetBuffer<BufferType::ASCEND_UB, int8_t>(scale_offset + 32);
+    AscendC::LocalTensor<float> res1_tensor = buf.GetBuffer<BufferType::ASCEND_UB, float>(scale_offset + 64);
     AscendC::LocalTensor<float> res3_tensor = buf.GetBuffer<BufferType::ASCEND_UB, float>(
-        HIDDTEN_STATE * 2 + HIDDTEN_STATE * 2 + HIDDTEN_STATE * 2 + 64 + num_col_align_f32 * 4);
+        scale_offset + 64 + num_col_align_f32 * 4);
     AscendC::LocalTensor<int8_t> output_tensor = buf.GetBuffer<BufferType::ASCEND_UB, int8_t>(
-        HIDDTEN_STATE * 2 + HIDDTEN_STATE * 2 + HIDDTEN_STATE * 2 + 64 + num_col_align_f32 * 4 +
-        BUF_FACTOR * num_col_align_f32 * 4 + 32);
+        scale_offset + 64 + num_col_align_f32 * 4 + BUF_FACTOR * num_col_align_f32 * 4 + 32);
     Quant1.Launch(output_tensor, input_tensor, gamma_tensor, beta_tensor, scale_tensor, offset_tensor, res1_tensor,
                   res3_tensor);
 }
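For reference, the buffer layout rewritten by the third hunk can be checked with plain host-side arithmetic. The sketch below uses assumed example values for `hiddenStateDim` and `num_col_1` (the real values come from the tiling data); offsets are in bytes, and a row of `hiddenStateDim` half-precision elements occupies `hiddenStateDim * 2` bytes, so each tensor starts right after the previous one.

```cpp
#include <cstdint>
#include <cstdio>

int main() {
    const uint32_t hiddenStateDim = 7168;  // example only; supplied by tiling at runtime
    const uint32_t num_col_1 = 1536;       // example only
    const uint32_t REPEAT_TIME_64 = 64;

    // Round num_col_1 up to a multiple of 64, as in the kernel's f32 alignment.
    const uint32_t num_col_align_f32 =
        (num_col_1 + REPEAT_TIME_64 - 1) / REPEAT_TIME_64 * REPEAT_TIME_64;

    // Offsets now derive from the runtime hiddenStateDim instead of the
    // compile-time HIDDTEN_STATE constant the old code used.
    const uint32_t gamma_offset = hiddenStateDim * 2;                 // after input_tensor
    const uint32_t beta_offset  = gamma_offset + hiddenStateDim * 2;  // after gamma_tensor
    const uint32_t scale_offset = beta_offset + hiddenStateDim * 2;   // after beta_tensor

    printf("gamma=%u beta=%u scale=%u offset=%u res1=%u res3=%u\n",
           gamma_offset, beta_offset, scale_offset,
           scale_offset + 32, scale_offset + 64,
           scale_offset + 64 + num_col_align_f32 * 4);
    return 0;
}
```

Computing the offsets from the runtime `hiddenStateDim` rather than a fixed constant presumably lets the same kernel serve inputs whose hidden dimension is only known at launch time, which is what the no_quant path requires.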