[feat] mlapo add bf16 no_quant support (#4852)

### What this PR does / why we need it?
This PR adds mlapo operation support for bf16 no_quant mode.

### Does this PR introduce _any_ user-facing change?
This PR makes the quant-related parameters optional.
### How was this patch tested?
CI passed with newly added and existing tests.

- vLLM version: v0.12.0
- vLLM main:
ad32e3e19c

---------

Signed-off-by: chenjunyi <isjunyi.chen@gmail.com>
This commit is contained in:
chenjunyi
2025-12-11 11:06:56 +08:00
committed by GitHub
parent c95c271538
commit c12eb22cbe
12 changed files with 1510 additions and 81 deletions

View File

@@ -49,7 +49,6 @@ constexpr uint8_t CACHE_MODE_INT8_NZCACHE = 2; // high performance KV NZ format
constexpr uint8_t CACHE_MODE_NZCACHE = 3;
// pp matmul
constexpr uint32_t HIDDTEN_STATE = 7168;
constexpr uint32_t FLOAT_BLOCK_SIZE = 64;
constexpr uint32_t HALF_BLOCK_SIZE = 64;
constexpr uint32_t HALF_VECTOR_SIZE = 64;
@@ -103,6 +102,7 @@ constexpr uint32_t KEY_FP16_CACHEMODE_1_QUANTMODE_0 = 1;
constexpr uint32_t KEY_BF16_CACHEMODE_0_QUANTMODE_0 = 256;
constexpr uint32_t KEY_BF16_CACHEMODE_1_QUANTMODE_0 = 257;
constexpr uint32_t KEY_BF16_CACHEMODE_3_QUANTMODE_0 = 259;
constexpr uint32_t KEY_BF16_CACHEMODE_1_QUANTMODE_3 = 281;
constexpr uint32_t KEY_BF16_CACHEMODE_0_QUANTMODE_0_INNER = 256 + 512;
constexpr uint32_t KEY_BF16_CACHEMODE_1_QUANTMODE_0_INNER = 257 + 512;
constexpr uint32_t KEY_BF16_CACHEMODE_3_QUANTMODE_0_INNER = 259 + 512;

View File

@@ -16,6 +16,7 @@
#include "mla_preprocess_mix_fp16.hpp"
#include "mla_preprocess_mix_bf16.hpp"
#include "mla_preprocess_mix_bf16_qdown.hpp"
#include "mla_preprocess_mix_bf16_nq.hpp"
#include "../op_host/tiling/mla_preprocess_tiling.h"
@@ -42,6 +43,7 @@ extern "C" __global__ __aicore__ void mla_preprocess(
mlaTilingData.tilingKey = tilingData->tilingKey;
mlaTilingData.n = tilingData->n;
mlaTilingData.hiddenStateDim = tilingData->hiddenStateDim;
mlaTilingData.mm1.numBatch = tilingData->mm1.numBatch;
mlaTilingData.mm1.m = tilingData->mm1.m;
@@ -173,12 +175,12 @@ extern "C" __global__ __aicore__ void mla_preprocess(
}
case KEY_BF16_CACHEMODE_0_QUANTMODE_0: {
MLAPO_BF16::MLAOperation<__bf16, 0, DataFormat::NZ, DataFormat::NZ, DataFormat::ND,
QuantMode::PER_TENSOR_ASYMM_QUANT>
QuantMode::PER_TENSOR_ASYMM_QUANT>
opBf16Cm0Qm0(mlaTilingData, tiling);
opBf16Cm0Qm0.Init(hiddenState, quantScale1, quantOffset1, wdqkv, bias1, gamma2, beta2,
quantScale2, quantOffset2, gamma3, sin1, cos1, sin2, cos2, keycache, slotMapping, wuq,
bias2, wuk, descale1, descale2, ctkvScale, qnopeScale, q, keycacheOut, q2, keycacheOut2,
s1, s2, s3, s4, s5);
quantScale2, quantOffset2, gamma3, sin1, cos1, sin2, cos2, keycache, slotMapping, wuq,
bias2, wuk, descale1, descale2, ctkvScale, qnopeScale, q, keycacheOut, q2, keycacheOut2,
s1, s2, s3, s4, s5);
if ASCEND_IS_AIC {
opBf16Cm0Qm0.ProcessCube();
}
@@ -189,12 +191,12 @@ extern "C" __global__ __aicore__ void mla_preprocess(
}
case KEY_BF16_CACHEMODE_1_QUANTMODE_0: {
MLAPO_BF16::MLAOperation<__bf16, 1, DataFormat::NZ, DataFormat::NZ, DataFormat::ND,
QuantMode::PER_TENSOR_ASYMM_QUANT>
QuantMode::PER_TENSOR_ASYMM_QUANT>
opBf16Cm1Qm0(mlaTilingData, tiling);
opBf16Cm1Qm0.Init(hiddenState, quantScale1, quantOffset1, wdqkv, bias1, gamma2, beta2,
quantScale2, quantOffset2, gamma3, sin1, cos1, sin2, cos2, keycache, slotMapping, wuq,
bias2, wuk, descale1, descale2, ctkvScale, qnopeScale, q, keycacheOut, q2, keycacheOut2,
s1, s2, s3, s4, s5);
quantScale2, quantOffset2, gamma3, sin1, cos1, sin2, cos2, keycache, slotMapping, wuq,
bias2, wuk, descale1, descale2, ctkvScale, qnopeScale, q, keycacheOut, q2, keycacheOut2,
s1, s2, s3, s4, s5);
if ASCEND_IS_AIC {
opBf16Cm1Qm0.ProcessCube();
}
@@ -219,6 +221,21 @@ extern "C" __global__ __aicore__ void mla_preprocess(
}
break;
}
case KEY_BF16_CACHEMODE_1_QUANTMODE_3: {
MLAPO_BF16_NQ::MLAOperation<__bf16, 1, DataFormat::NZ, DataFormat::NZ, DataFormat::ND>
opBf16Cm1Qm0(mlaTilingData, tiling);
opBf16Cm1Qm0.Init(hiddenState, wdqkv, gamma2, beta2,
gamma3, sin1, cos1, sin2, cos2, keycache, slotMapping, wuq,
wuk, q, keycacheOut, q2, keycacheOut2,
s1, s2, s3);
if ASCEND_IS_AIC {
opBf16Cm1Qm0.ProcessCube();
}
if ASCEND_IS_AIV {
opBf16Cm1Qm0.ProcessVector();
}
break;
}
case KEY_BF16_CACHEMODE_0_QUANTMODE_0_INNER: {
MLAPO_BF16_INNER::MLAOperation<__bf16, 0, DataFormat::NZ, DataFormat::NZ, DataFormat::ND,
QuantMode::PER_TENSOR_ASYMM_QUANT>

View File

@@ -2386,6 +2386,7 @@ public:
this->num_row = mlaParams_.n;
this->epsilon_ = 1e-6;
this->mlaParams = mlaParams_;
this->hiddenStateDim = mlaParams_.hiddenStateDim;
}
__aicore__ inline void Init(GM_ADDR hiddenStateGm, GM_ADDR quantScale1Gm,
@@ -2692,6 +2693,7 @@ private:
uint32_t blockOffset;
uint32_t perTaskNum;
uint32_t resTaskNum;
uint32_t hiddenStateDim;
MlaTilingData mlaParams;
uint32_t num_core_;
@@ -2795,18 +2797,15 @@ MLAOperation<InDtype, CACHE_MODE, weightFormat1, weightFormat2, weightFormat3, q
uint32_t num_col_align_int8 = (num_col_1 + REPEAT_TIME_256 - 1) / REPEAT_TIME_256 * REPEAT_TIME_256;
uint32_t num_col_align_f16 = (num_col_1 + REPEAT_TIME_128 - 1) / REPEAT_TIME_128 * REPEAT_TIME_128;
uint32_t num_col_align_f32 = (num_col_1 + REPEAT_TIME_64 - 1) / REPEAT_TIME_64 * REPEAT_TIME_64;
const uint32_t base_offset = hiddenStateDim * 6;
AscendC::LocalTensor<InDtype> input_tensor = buf.GetBuffer<BufferType::ASCEND_UB, InDtype>(0);
AscendC::LocalTensor<InDtype> scale_tensor =
buf.GetBuffer<BufferType::ASCEND_UB, InDtype>(HIDDTEN_STATE * 2 + HIDDTEN_STATE * 2 + HIDDTEN_STATE * 2);
AscendC::LocalTensor<int8_t> offset_tensor = buf.GetBuffer<BufferType::ASCEND_UB, int8_t>(
HIDDTEN_STATE * 2 + HIDDTEN_STATE * 2 + HIDDTEN_STATE * 2 + 32);
AscendC::LocalTensor<float> res1_tensor =
buf.GetBuffer<BufferType::ASCEND_UB, float>(HIDDTEN_STATE * 2 + HIDDTEN_STATE * 2 + HIDDTEN_STATE * 2 + 64);
AscendC::LocalTensor<InDtype> scale_tensor = buf.GetBuffer<BufferType::ASCEND_UB, InDtype>(base_offset);
AscendC::LocalTensor<int8_t> offset_tensor = buf.GetBuffer<BufferType::ASCEND_UB, int8_t>(base_offset + 32);
AscendC::LocalTensor<float> res1_tensor = buf.GetBuffer<BufferType::ASCEND_UB, float>(base_offset + 64);
AscendC::LocalTensor<float> res3_tensor = buf.GetBuffer<BufferType::ASCEND_UB, float>(
HIDDTEN_STATE * 2 + HIDDTEN_STATE * 2 + HIDDTEN_STATE * 2 + 64 + num_col_align_f32 * 4);
base_offset + 64 + num_col_align_f32 * 4);
AscendC::LocalTensor<int8_t> output_tensor = buf.GetBuffer<BufferType::ASCEND_UB, int8_t>(
HIDDTEN_STATE * 2 + HIDDTEN_STATE * 2 + HIDDTEN_STATE * 2 + 64 + num_col_align_f32 * 4 +
BUF_FACTOR * num_col_align_f32 * 4 + 64);
base_offset + 64 + num_col_align_f32 * 4 + BUF_FACTOR * num_col_align_f32 * 4 + 64);
Quant1.Launch(output_tensor, input_tensor, scale_tensor, offset_tensor, res1_tensor, res3_tensor);
}
FftsCrossCoreSync<PIPE_MTE3, 0>(QUANT1);

File diff suppressed because it is too large Load Diff

View File

@@ -2406,6 +2406,7 @@ public:
this->num_row = mlaParams_.n;
this->epsilon_ = 1e-6;
this->mlaParams = mlaParams_;
this->hiddenStateDim = mlaParams_.hiddenStateDim;
}
__aicore__ inline void Init(GM_ADDR hiddenStateGm, GM_ADDR quantScale1Gm,
@@ -2713,6 +2714,7 @@ private:
uint32_t blockOffset;
uint32_t perTaskNum;
uint32_t resTaskNum;
uint32_t hiddenStateDim;
MlaTilingData mlaParams;
uint32_t num_core_;
@@ -2817,18 +2819,15 @@ MLAOperation<InDtype, CACHE_MODE, weightFormat1, weightFormat2, weightFormat3, q
uint32_t num_col_align_int8 = (num_col_1 + REPEAT_TIME_256 - 1) / REPEAT_TIME_256 * REPEAT_TIME_256;
uint32_t num_col_align_f16 = (num_col_1 + REPEAT_TIME_128 - 1) / REPEAT_TIME_128 * REPEAT_TIME_128;
uint32_t num_col_align_f32 = (num_col_1 + REPEAT_TIME_64 - 1) / REPEAT_TIME_64 * REPEAT_TIME_64;
const uint32_t base_offset = hiddenStateDim * 6;
AscendC::LocalTensor<InDtype> input_tensor = buf.GetBuffer<BufferType::ASCEND_UB, InDtype>(0);
AscendC::LocalTensor<InDtype> scale_tensor =
buf.GetBuffer<BufferType::ASCEND_UB, InDtype>(HIDDTEN_STATE * 2 + HIDDTEN_STATE * 2 + HIDDTEN_STATE * 2);
AscendC::LocalTensor<int8_t> offset_tensor = buf.GetBuffer<BufferType::ASCEND_UB, int8_t>(
HIDDTEN_STATE * 2 + HIDDTEN_STATE * 2 + HIDDTEN_STATE * 2 + 32);
AscendC::LocalTensor<float> res1_tensor =
buf.GetBuffer<BufferType::ASCEND_UB, float>(HIDDTEN_STATE * 2 + HIDDTEN_STATE * 2 + HIDDTEN_STATE * 2 + 64);
AscendC::LocalTensor<InDtype> scale_tensor = buf.GetBuffer<BufferType::ASCEND_UB, InDtype>(base_offset);
AscendC::LocalTensor<int8_t> offset_tensor = buf.GetBuffer<BufferType::ASCEND_UB, int8_t>(base_offset + 32);
AscendC::LocalTensor<float> res1_tensor = buf.GetBuffer<BufferType::ASCEND_UB, float>(base_offset + 64);
AscendC::LocalTensor<float> res3_tensor = buf.GetBuffer<BufferType::ASCEND_UB, float>(
HIDDTEN_STATE * 2 + HIDDTEN_STATE * 2 + HIDDTEN_STATE * 2 + 64 + num_col_align_f32 * 4);
base_offset + 64 + num_col_align_f32 * 4);
AscendC::LocalTensor<int8_t> output_tensor = buf.GetBuffer<BufferType::ASCEND_UB, int8_t>(
HIDDTEN_STATE * 2 + HIDDTEN_STATE * 2 + HIDDTEN_STATE * 2 + 64 + num_col_align_f32 * 4 +
BUF_FACTOR * num_col_align_f32 * 4 + 64);
base_offset + 64 + num_col_align_f32 * 4 + BUF_FACTOR * num_col_align_f32 * 4 + 64);
Quant1.Launch(output_tensor, input_tensor, scale_tensor, offset_tensor, res1_tensor, res3_tensor);
}
FftsCrossCoreSync<PIPE_MTE3, 0>(QUANT1);

View File

@@ -2034,6 +2034,7 @@ public:
this->num_row = mlaParams_.n;
this->epsilon_ = 1e-6;
this->mlaParams = mlaParams_;
this->hiddenStateDim = mlaParams_.hiddenStateDim;
}
__aicore__ inline void Init(GM_ADDR hiddenStateGm, GM_ADDR quantScale1Gm,
@@ -2294,6 +2295,7 @@ private:
uint32_t blockOffset;
uint32_t perTaskNum;
uint32_t resTaskNum;
uint32_t hiddenStateDim;
MlaTilingData mlaParams;
// rmsnormQuant
@@ -2389,21 +2391,19 @@ __aicore__ inline void MLAOperation<cacheMode, weightFormat1, weightFormat2, wei
uint32_t num_col_align_int8 = (num_col_1 + REPEAT_TIME_256 - 1) / REPEAT_TIME_256 * REPEAT_TIME_256;
uint32_t num_col_align_f16 = (num_col_1 + REPEAT_TIME_128 - 1) / REPEAT_TIME_128 * REPEAT_TIME_128;
uint32_t num_col_align_f32 = (num_col_1 + REPEAT_TIME_64 - 1) / REPEAT_TIME_64 * REPEAT_TIME_64;
const uint32_t gamma_offset = hiddenStateDim * 2;
const uint32_t beta_offset = gamma_offset + hiddenStateDim * 2;
const uint32_t scale_offset = beta_offset + hiddenStateDim * 2;
AscendC::LocalTensor<half> input_tensor = buf.GetBuffer<BufferType::ASCEND_UB, half>(0);
AscendC::LocalTensor<half> gamma_tensor = buf.GetBuffer<BufferType::ASCEND_UB, half>(HIDDTEN_STATE * 2);
AscendC::LocalTensor<half> beta_tensor =
buf.GetBuffer<BufferType::ASCEND_UB, half>(HIDDTEN_STATE * 2 + HIDDTEN_STATE * 2);
AscendC::LocalTensor<half> scale_tensor =
buf.GetBuffer<BufferType::ASCEND_UB, half>(HIDDTEN_STATE * 2 + HIDDTEN_STATE * 2 + HIDDTEN_STATE * 2);
AscendC::LocalTensor<int8_t> offset_tensor = buf.GetBuffer<BufferType::ASCEND_UB, int8_t>(
HIDDTEN_STATE * 2 + HIDDTEN_STATE * 2 + HIDDTEN_STATE * 2 + 32);
AscendC::LocalTensor<float> res1_tensor =
buf.GetBuffer<BufferType::ASCEND_UB, float>(HIDDTEN_STATE * 2 + HIDDTEN_STATE * 2 + HIDDTEN_STATE * 2 + 64);
AscendC::LocalTensor<half> gamma_tensor = buf.GetBuffer<BufferType::ASCEND_UB, half>(gamma_offset);
AscendC::LocalTensor<half> beta_tensor = buf.GetBuffer<BufferType::ASCEND_UB, half>(beta_offset);
AscendC::LocalTensor<half> scale_tensor = buf.GetBuffer<BufferType::ASCEND_UB, half>(scale_offset);
AscendC::LocalTensor<int8_t> offset_tensor = buf.GetBuffer<BufferType::ASCEND_UB, int8_t>(scale_offset + 32);
AscendC::LocalTensor<float> res1_tensor = buf.GetBuffer<BufferType::ASCEND_UB, float>(scale_offset + 64);
AscendC::LocalTensor<float> res3_tensor = buf.GetBuffer<BufferType::ASCEND_UB, float>(
HIDDTEN_STATE * 2 + HIDDTEN_STATE * 2 + HIDDTEN_STATE * 2 + 64 + num_col_align_f32 * 4);
scale_offset + 64 + num_col_align_f32 * 4);
AscendC::LocalTensor<int8_t> output_tensor = buf.GetBuffer<BufferType::ASCEND_UB, int8_t>(
HIDDTEN_STATE * 2 + HIDDTEN_STATE * 2 + HIDDTEN_STATE * 2 + 64 + num_col_align_f32 * 4 +
BUF_FACTOR * num_col_align_f32 * 4 + 32);
scale_offset + 64 + num_col_align_f32 * 4 + BUF_FACTOR * num_col_align_f32 * 4 + 32);
Quant1.Launch(output_tensor, input_tensor, gamma_tensor, beta_tensor, scale_tensor, offset_tensor, res1_tensor,
res3_tensor);
}