[feat] mlapo add bf16 no_quant support (#4852)

### What this PR does / why we need it?
This PR adds mlapo operation support for bf16 no_quant mode.

### Does this PR introduce _any_ user-facing change?
This PR makes the quant-related parameters optional.
### How was this patch tested?
CI passed with newly added and existing tests.

- vLLM version: v0.12.0
- vLLM main:
ad32e3e19c

---------

Signed-off-by: chenjunyi <isjunyi.chen@gmail.com>
This commit is contained in:
chenjunyi
2025-12-11 11:06:56 +08:00
committed by GitHub
parent c95c271538
commit c12eb22cbe
12 changed files with 1510 additions and 81 deletions

View File

@@ -43,7 +43,6 @@ constexpr uint32_t L1_BIAS_SIZE = 2048;
constexpr uint32_t L0C_SIZE = 128 * 1024;
constexpr uint32_t CONCAT_SIZE = 512;
constexpr uint32_t HIDDEN_STRATE = 7168;
constexpr uint32_t HIDDEN_STRATE_ROPE = 192;
constexpr uint32_t HIDDEN_STRATE_MM = 2112;
constexpr uint32_t HIDDEN_STRATE_RMS = 1536;
@@ -122,6 +121,8 @@ struct PlatformInfo {
};
struct OpParam {
uint32_t isWeightQuantized;
uint32_t hiddenStateDim;
uint32_t N;
uint32_t headNum;
int32_t cacheMode;
@@ -392,7 +393,7 @@ private:
void MlaPreprocessTiling::RmsNormQuantTiling()
{
tilingData->rmsNumCore1 = platformInfo.coreNumAiv;
tilingData->rmsNumCol1 = HIDDEN_STRATE;
tilingData->rmsNumCol1 = opParam.hiddenStateDim;
tilingData->rmsNumRow1 = opParam.N;
tilingData->rmsQuantMin1 = -CONST_128;
tilingData->rmsNumCore2 = platformInfo.coreNumAiv;
@@ -508,9 +509,9 @@ void MlaPreprocessTiling::EinSumQuantTiling()
void MlaPreprocessTiling::SetMlapoWorkSpace()
{
uint64_t s1wsFactor =
static_cast<uint64_t>(opParam.cacheMode == 2 ? std::max(HIDDEN_STRATE * sizeof(int8_t),
static_cast<uint64_t>(opParam.cacheMode == 2 ? std::max(opParam.hiddenStateDim * sizeof(int8_t),
opParam.headNum * AXES_ALIGN_SIZE * sizeof(uint16_t))
: HIDDEN_STRATE * sizeof(int8_t));
: opParam.hiddenStateDim * sizeof(int8_t));
uint64_t workSizeS1 = s1wsFactor;
uint64_t workSizeS2 = opParam.headNum * HIDDEN_STRATE_ROPE * sizeof(uint16_t);
uint64_t workSizeS3 = HIDDEN_STRATE_MM * sizeof(uint16_t);
@@ -525,7 +526,8 @@ void MlaPreprocessTiling::SetMlapoWorkSpace()
uint64_t pertokenWorkspace = static_cast<uint64_t>(opParam.N) * sizeof(float) * 2;
uint64_t userWorkspaceSize;
if (opParam.inDtype == at::kBFloat16 || opParam.quantMode == QuantMode::PER_TOKEN_SYMM_QUANT) {
if (opParam.isWeightQuantized == 1 &&
(opParam.inDtype == at::kBFloat16 || opParam.quantMode == QuantMode::PER_TOKEN_SYMM_QUANT)) {
userWorkspaceSize = 4 * maxWorkspaceSize + pertokenWorkspace;
} else {
userWorkspaceSize = 3 * maxWorkspaceSize;
@@ -554,21 +556,23 @@ void MlaPreprocessTiling::Init()
{
tilingData->numCore = platformInfo.coreNumAic;
tilingData->n = opParam.N;
tilingData->hiddenStateDim = opParam.hiddenStateDim;
tilingData->isWeightQuantized = opParam.isWeightQuantized;
bool enDequant = (opParam.isWeightQuantized == 1);
bool deqOnTheFly = false;
if (opParam.inDtype == at::kBFloat16 || opParam.quantMode == QuantMode::PER_TOKEN_SYMM_QUANT) {
if (enDequant && (opParam.inDtype == at::kBFloat16 || opParam.quantMode == QuantMode::PER_TOKEN_SYMM_QUANT)) {
deqOnTheFly = true;
}
PpMatmulTilingApi mm1TilingApi(platformInfo,
1, // numBatch
opParam.N, // m
HIDDEN_STRATE, // k
HIDDEN_STRATE_MM, // n
false, // transA
true, // transB
true, // enDequant
deqOnTheFly); // in bf16.cce?
1, // numBatch
opParam.N, // m
opParam.hiddenStateDim, // k
HIDDEN_STRATE_MM, // n
false, // transA
true, // transB
enDequant, // enDequant
deqOnTheFly); // in bf16.cce?
mm1TilingApi.GetTilingData(tilingData->mm1);
PpMatmulTilingApi mm2TilingApi(platformInfo,
@@ -578,7 +582,7 @@ void MlaPreprocessTiling::Init()
opParam.headNum * HIDDEN_STRATE_ROPE, // n
false, // transA
true, // transB
true, // enDequant
enDequant, // enDequant
deqOnTheFly); // in bf16.cce?
mm2TilingApi.GetTilingData(tilingData->mm2);
@@ -609,6 +613,8 @@ std::unordered_map<c10::string_view, uint16_t> cache_mode_map = {
std::unordered_map<c10::string_view, uint16_t> quant_mode_map = {
{"per_tensor_quant_asymm", 0},
{"per_token_quant_symm", 1},
{"per_token_quant_asymm", 2},
{"no_quant", 3}
};
template <typename MapType>
@@ -623,6 +629,7 @@ inline int get_op_mode(const MapType &mode_map, c10::optional<c10::string_view>
std::tuple<at::Tensor, at::Tensor, uint32_t> mla_preprocess_tiling(
const at::Tensor &hiddenState,
const at::Tensor &wdqkv,
const at::Tensor &wuk,
c10::optional<c10::string_view> cache_mode,
c10::optional<c10::string_view> quant_mode,
@@ -647,14 +654,21 @@ std::tuple<at::Tensor, at::Tensor, uint32_t> mla_preprocess_tiling(
int32_t N = hiddenState.sizes()[0];
int32_t headNum = wuk.sizes()[0];
uint32_t hiddenStateDim = hiddenState.sizes().back();
OpParam opParam;
opParam.hiddenStateDim = hiddenStateDim;
opParam.N = N;
opParam.headNum = headNum;
opParam.cacheMode = static_cast<int32_t>(cacheMode);
opParam.quantMode = static_cast<QuantMode>(quantMode);
opParam.inDtype = hiddenState.options().dtype();
opParam.enableInnerOut = enable_inner_out;
if (wdqkv.options().dtype() == at::kBFloat16 || wdqkv.options().dtype() == at::kHalf) {
opParam.isWeightQuantized = 0;
} else {
opParam.isWeightQuantized = 1;
}
MlaTilingData tilingData;
MlaPreprocessTiling mlaTiling(platformInfo, opParam, &tilingData);

View File

@@ -90,6 +90,11 @@ struct MlaTilingData {
uint32_t esqHeadTail{0};
uint32_t esqColLoop{0};
uint32_t esqColTail{0};
// hidden state dimension
uint32_t hiddenStateDim{7168};
uint32_t isWeightQuantized{1};
};
#endif // MLAPREPROCESS_TILING_H