[feat] parameterize hardcoded MLA dimensions to support GLM5-W8A8 (#6902)

Derive MLA dimension constants (q_lora_rank, qk_nope_head_dim, etc.)
from tensor shapes at runtime instead of hardcoding DeepSeek V3 values.
This enables the mla_preprocess fused op to work with both DeepSeek V3
and GLM5 models without Python API changes.

- Add 9 dimension fields to MlaTilingData with DeepSeek V3 defaults
- Add OpParam fields and parameterize all host-side tiling functions
- Derive dimensions from the wuk, gamma1, and kv_cache_rope tensor shapes (see the sketch after this list)
- Replace 310+ hardcoded constants across 4 kernel .hpp files
- Remove unused MMSIZE1/MMSIZE2 constants
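
For reference, a minimal standalone sketch of the shape-based derivation, using DeepSeek V3 shapes as the illustrative example (head_num 128, qk_nope_head_dim 128, kv_lora_rank 512, q_lora_rank 1536, qk_rope_head_dim 64). Only the index positions mirror the tiling code below; the plain vectors and the `main` wrapper are stand-ins for the real `at::Tensor::sizes()` calls.

```cpp
// Illustrative only: plain vectors stand in for at::Tensor::sizes(); the
// example numbers are the DeepSeek V3 shapes the old constants assumed.
#include <cstdint>
#include <cstdio>
#include <vector>

int main() {
    std::vector<int64_t> wuk = {128, 128, 512};  // [head_num, qk_nope_head_dim, kv_lora_rank]
    std::vector<int64_t> gamma1 = {1536};        // [q_lora_rank]
    std::vector<int64_t> kvCacheRope = {64};     // only the last dim (qk_rope_head_dim) is read

    uint32_t qkNopeHeadDim = static_cast<uint32_t>(wuk[1]);
    uint32_t kvLoraRank = static_cast<uint32_t>(wuk[2]);
    uint32_t qLoraRank = static_cast<uint32_t>(gamma1[0]);
    uint32_t qkRopeHeadDim = static_cast<uint32_t>(kvCacheRope.back());

    // For DeepSeek V3 this reproduces the fused-matmul output width:
    // 1536 + 512 + 64 = 2112, matching the mm1OutSize default in MlaTilingData.
    std::printf("mm1 output columns: %u\n", qLoraRank + kvLoraRank + qkRopeHeadDim);
    return 0;
}
```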

### What this PR does / why we need it?

### Does this PR introduce _any_ user-facing change?

### How was this patch tested?

- vLLM version: v0.16.0
- vLLM main: 15d76f74e2

---------

Signed-off-by: liuchenbing <chenliumail@163.com>
Co-authored-by: liuchenbing <chenliumail@163.com>
Author: liuchen2026fly
Date: 2026-03-09 20:17:21 +08:00 (committed by GitHub)
Parent: 13adcbe44b
Commit: 542258ac9d
9 changed files with 508 additions and 342 deletions


@@ -129,6 +129,11 @@ struct OpParam {
     QuantMode quantMode;
     caffe2::TypeMeta inDtype;
     bool enableInnerOut;
+    // MLA dimensions derived from tensor shapes
+    uint32_t qLoraRank;
+    uint32_t qkNopeHeadDim;
+    uint32_t qkRopeHeadDim;
+    uint32_t kvLoraRank;
 };
 
 class PpMatmulTilingApi
@@ -397,7 +402,7 @@ void MlaPreprocessTiling::RmsNormQuantTiling()
     tilingData->rmsNumRow1 = opParam.N;
     tilingData->rmsQuantMin1 = -CONST_128;
     tilingData->rmsNumCore2 = platformInfo.coreNumAiv;
-    tilingData->rmsNumCol2 = HIDDEN_STRATE_MM;
+    tilingData->rmsNumCol2 = opParam.qLoraRank + opParam.kvLoraRank + opParam.qkRopeHeadDim;
     tilingData->rmsNumRow2 = opParam.N;
     tilingData->rmsQuantMin2 = -CONST_128;
 }
@@ -405,10 +410,10 @@ void MlaPreprocessTiling::RmsNormQuantTiling()
 void MlaPreprocessTiling::RopeConcatTiling()
 {
     uint32_t ntokens = opParam.N;
-    uint32_t hiddenSizeQ = HEADDIM * opParam.headNum;
-    uint32_t headDim = HEADDIM;
+    uint32_t hiddenSizeQ = opParam.qkRopeHeadDim * opParam.headNum;
+    uint32_t headDim = opParam.qkRopeHeadDim;
     uint32_t headNumQ = hiddenSizeQ / headDim;
-    uint32_t concatSize = CONCAT_SIZE;
+    uint32_t concatSize = opParam.kvLoraRank;
 
     uint32_t maxCore = platformInfo.coreNumAiv;
     uint32_t maxUbSize = platformInfo.ubSize;
@@ -458,7 +463,7 @@ void MlaPreprocessTiling::EinSumQuantTiling()
     // input shape
     uint32_t esqBatch = opParam.N;          // tokenNum
     uint32_t esqHeadNum = opParam.headNum;  // headNum
-    uint32_t esqColNum = AXES_ALIGN_SIZE;   // 512
+    uint32_t esqColNum = opParam.kvLoraRank;  // kv_lora_rank
 
     // split core
     uint32_t esqFrontCore = esqBatch % aivCore;
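
For orientation, a small sketch of what these three vector-kernel tilings evaluate to under the DeepSeek V3 defaults listed in the tiling header at the end of this diff (q_lora_rank 1536, kv_lora_rank 512, qk_rope_head_dim 64); head_num 128 is an assumed example value, and the asserts only document the arithmetic.

```cpp
// Illustrative arithmetic for RmsNormQuantTiling / RopeConcatTiling /
// EinSumQuantTiling above, assuming DeepSeek V3 dimensions.
#include <cstdint>

constexpr uint32_t kQLoraRank = 1536, kKvLoraRank = 512;
constexpr uint32_t kQkRopeHeadDim = 64, kHeadNum = 128;  // kHeadNum is an assumed example

static_assert(kQLoraRank + kKvLoraRank + kQkRopeHeadDim == 2112, "rmsNumCol2");
static_assert(kQkRopeHeadDim * kHeadNum == 8192, "hiddenSizeQ");
static_assert(kKvLoraRank == 512, "concatSize and esqColNum");

int main() { return 0; }
```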
@@ -508,14 +513,16 @@ void MlaPreprocessTiling::EinSumQuantTiling()
 void MlaPreprocessTiling::SetMlapoWorkSpace()
 {
+    uint32_t hiddenStrideRope = opParam.qkNopeHeadDim + opParam.qkRopeHeadDim;
+    uint32_t hiddenStrateMm = opParam.qLoraRank + opParam.kvLoraRank + opParam.qkRopeHeadDim;
     uint64_t s1wsFactor =
         static_cast<uint64_t>(opParam.cacheMode == 2 ? std::max(opParam.hiddenStateDim * sizeof(int8_t),
-                                                                 opParam.headNum * AXES_ALIGN_SIZE * sizeof(uint16_t))
+                                                                 opParam.headNum * opParam.kvLoraRank * sizeof(uint16_t))
                                                       : opParam.hiddenStateDim * sizeof(int8_t));
 
     uint64_t workSizeS1 = s1wsFactor;
-    uint64_t workSizeS2 = opParam.headNum * HIDDEN_STRATE_ROPE * sizeof(uint16_t);
-    uint64_t workSizeS3 = HIDDEN_STRATE_MM * sizeof(uint16_t);
-    uint64_t workSizeS4 = std::max(opParam.headNum * HIDDEN_STRATE_ROPE, HIDDEN_STRATE_MM) * sizeof(uint32_t);
+    uint64_t workSizeS2 = opParam.headNum * hiddenStrideRope * sizeof(uint16_t);
+    uint64_t workSizeS3 = hiddenStrateMm * sizeof(uint16_t);
+    uint64_t workSizeS4 = std::max(opParam.headNum * hiddenStrideRope, hiddenStrateMm) * sizeof(uint32_t);
 
     uint64_t maxWorkspaceSize = workSizeS1;
     maxWorkspaceSize = std::max(maxWorkspaceSize, workSizeS2);
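
A rough evaluation of the workspace factors computed above, again assuming DeepSeek V3 dimensions (head_num 128 is an example value); these are just the byte factors visible in this hunk, before any scaling done elsewhere.

```cpp
// Illustrative workspace-factor arithmetic mirroring SetMlapoWorkSpace() above.
#include <algorithm>
#include <cstdint>
#include <cstdio>

int main() {
    const uint64_t headNum = 128;                     // assumed example
    const uint64_t hiddenStrideRope = 128 + 64;       // qk_nope_head_dim + qk_rope_head_dim
    const uint64_t hiddenStrateMm = 1536 + 512 + 64;  // q_lora_rank + kv_lora_rank + qk_rope_head_dim

    const uint64_t workSizeS2 = headNum * hiddenStrideRope * sizeof(uint16_t);  // 49152
    const uint64_t workSizeS3 = hiddenStrateMm * sizeof(uint16_t);              // 4224
    const uint64_t workSizeS4 =
        std::max(headNum * hiddenStrideRope, hiddenStrateMm) * sizeof(uint32_t);  // 98304

    std::printf("S2=%llu S3=%llu S4=%llu\n",
                static_cast<unsigned long long>(workSizeS2),
                static_cast<unsigned long long>(workSizeS3),
                static_cast<unsigned long long>(workSizeS4));
    return 0;
}
```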
@@ -564,11 +571,17 @@ void MlaPreprocessTiling::Init()
         deqOnTheFly = true;
     }
 
+    uint32_t mm1N = opParam.qLoraRank + opParam.kvLoraRank + opParam.qkRopeHeadDim;
+    uint32_t mm2K = opParam.qLoraRank;
+    uint32_t mm2N = opParam.headNum * (opParam.qkNopeHeadDim + opParam.qkRopeHeadDim);
+    uint32_t mm3K = opParam.qkNopeHeadDim;
+    uint32_t mm3N = opParam.kvLoraRank;
+
     PpMatmulTilingApi mm1TilingApi(platformInfo,
                                    1,                      // numBatch
                                    opParam.N,              // m
                                    opParam.hiddenStateDim, // k
-                                   HIDDEN_STRATE_MM,       // n
+                                   mm1N,                   // n
                                    false,                  // transA
                                    true,                   // transB
                                    enDequant,              // enDequant
@@ -576,21 +589,21 @@ void MlaPreprocessTiling::Init()
     mm1TilingApi.GetTilingData(tilingData->mm1);
 
     PpMatmulTilingApi mm2TilingApi(platformInfo,
-                                   1,                                    // numBatch
-                                   opParam.N,                            // m
-                                   HIDDEN_STRATE_RMS,                    // k
-                                   opParam.headNum * HIDDEN_STRATE_ROPE, // n
-                                   false,                                // transA
-                                   true,                                 // transB
-                                   enDequant,                            // enDequant
-                                   deqOnTheFly);                         // in bf16.cce?
+                                   1,            // numBatch
+                                   opParam.N,    // m
+                                   mm2K,         // k
+                                   mm2N,         // n
+                                   false,        // transA
+                                   true,         // transB
+                                   enDequant,    // enDequant
+                                   deqOnTheFly); // in bf16.cce?
     mm2TilingApi.GetTilingData(tilingData->mm2);
 
     PpMatmulTilingApi mm3TilingApi(platformInfo,
                                    opParam.headNum, // numBatch
                                    opParam.N,       // m
-                                   CONST_128,       // k
-                                   CONCAT_SIZE,     // n
+                                   mm3K,            // k
+                                   mm3N,            // n
                                    false,           // transA
                                    false,           // transB
                                    false,           // enDequant
@@ -604,6 +617,18 @@ void MlaPreprocessTiling::Init()
 
     SetMlapoWorkSpace();
     SetTilingKey();
 
+    // Populate model-specific MLA dimension fields
+    tilingData->mm1OutSize = opParam.qLoraRank + opParam.kvLoraRank + opParam.qkRopeHeadDim;
+    tilingData->splitSizeOne = opParam.kvLoraRank + opParam.qkRopeHeadDim;
+    tilingData->splitSizeTwo = opParam.qLoraRank;
+    tilingData->splitRmsNormSizeOne = opParam.kvLoraRank;
+    tilingData->splitRmsNormSizeTwo = opParam.qkRopeHeadDim;
+    tilingData->ropeSplitSizeOne = opParam.qkRopeHeadDim;
+    tilingData->ropeSplitSizeTwo = opParam.qkNopeHeadDim;
+    tilingData->hiddenStrideRope = opParam.qkNopeHeadDim + opParam.qkRopeHeadDim;
+    tilingData->qkNopeHeadDim = opParam.qkNopeHeadDim;
+    tilingData->avgFactor = 1.0f / static_cast<float>(opParam.qLoraRank);
+
     return;
 }
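
Read together with the new mm1N/mm2K/mm2N/mm3K/mm3N locals, the three PpMatmul problems become the shapes below. This is a sketch using the DeepSeek V3 defaults from the tiling header (hidden_state_dim 7168) plus assumed example values for head_num and the token count; GLM5 gets different shapes purely through opParam.

```cpp
// Illustrative GEMM problem sizes, built the same way as Init() above.
#include <cstdint>
#include <cstdio>

int main() {
    const uint32_t N = 16;         // token count, assumed example
    const uint32_t headNum = 128;  // assumed example
    const uint32_t hiddenStateDim = 7168;
    const uint32_t qLoraRank = 1536, kvLoraRank = 512;
    const uint32_t qkNopeHeadDim = 128, qkRopeHeadDim = 64;

    const uint32_t mm1N = qLoraRank + kvLoraRank + qkRopeHeadDim;     // 2112
    const uint32_t mm2K = qLoraRank;                                  // 1536
    const uint32_t mm2N = headNum * (qkNopeHeadDim + qkRopeHeadDim);  // 24576
    const uint32_t mm3K = qkNopeHeadDim;                              // 128
    const uint32_t mm3N = kvLoraRank;                                 // 512

    // mm1: [N, 7168] x [2112, 7168]^T -> [N, 2112]           (transB = true)
    std::printf("mm1: m=%u k=%u n=%u\n", N, hiddenStateDim, mm1N);
    // mm2: [N, 1536] x [24576, 1536]^T -> [N, 24576]         (transB = true)
    std::printf("mm2: m=%u k=%u n=%u\n", N, mm2K, mm2N);
    // mm3: headNum-batched [N, 128] x [128, 512] -> [N, 512]
    std::printf("mm3: batch=%u m=%u k=%u n=%u\n", headNum, N, mm3K, mm3N);
    return 0;
}
```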
@@ -631,6 +656,8 @@ std::tuple<at::Tensor, at::Tensor, uint32_t> mla_preprocess_tiling(
     const at::Tensor &hiddenState,
     const at::Tensor &wdqkv,
     const at::Tensor &wuk,
+    const at::Tensor &gamma1,
+    const at::Tensor &kv_cache_rope,
     c10::optional<c10::string_view> cache_mode,
     c10::optional<c10::string_view> quant_mode,
     bool enable_inner_out
@@ -656,6 +683,12 @@ std::tuple<at::Tensor, at::Tensor, uint32_t> mla_preprocess_tiling(
     int32_t headNum = wuk.sizes()[0];
     uint32_t hiddenStateDim = hiddenState.sizes().back();
 
+    // Derive MLA dimensions from tensor shapes
+    uint32_t qkNopeHeadDim = wuk.sizes()[1];
+    uint32_t kvLoraRank = wuk.sizes()[2];
+    uint32_t qLoraRank = gamma1.sizes()[0];
+    uint32_t qkRopeHeadDim = kv_cache_rope.sizes().back();
+
     OpParam opParam;
     opParam.hiddenStateDim = hiddenStateDim;
     opParam.N = N;
@@ -664,6 +697,10 @@ std::tuple<at::Tensor, at::Tensor, uint32_t> mla_preprocess_tiling(
     opParam.quantMode = static_cast<QuantMode>(quantMode);
     opParam.inDtype = hiddenState.options().dtype();
     opParam.enableInnerOut = enable_inner_out;
+    opParam.qLoraRank = qLoraRank;
+    opParam.qkNopeHeadDim = qkNopeHeadDim;
+    opParam.qkRopeHeadDim = qkRopeHeadDim;
+    opParam.kvLoraRank = kvLoraRank;
     if (wdqkv.options().dtype() == at::kBFloat16 || wdqkv.options().dtype() == at::kHalf) {
         opParam.isWeightQuantized = 0;
     } else {


@@ -95,6 +95,18 @@ struct MlaTilingData {
     uint32_t hiddenStateDim{7168};
     uint32_t isWeightQuantized{1};
+
+    // Model-specific MLA dimensions (derived from tensor shapes)
+    uint32_t mm1OutSize{2112};           // q_lora_rank + kv_lora_rank + qk_rope_head_dim
+    uint32_t splitSizeOne{576};          // kv_lora_rank + qk_rope_head_dim
+    uint32_t splitSizeTwo{1536};         // q_lora_rank
+    uint32_t splitRmsNormSizeOne{512};   // kv_lora_rank
+    uint32_t splitRmsNormSizeTwo{64};    // qk_rope_head_dim
+    uint32_t ropeSplitSizeOne{64};       // qk_rope_head_dim
+    uint32_t ropeSplitSizeTwo{128};      // qk_nope_head_dim
+    uint32_t hiddenStrideRope{192};      // qk_nope_head_dim + qk_rope_head_dim
+    uint32_t qkNopeHeadDim{128};         // for RoPE offset calc
+    float avgFactor{0.000651041666f};    // 1 / splitSizeTwo (1 / qLoraRank), for RmsNorm avg
 };
 
 #endif // MLAPREPROCESS_TILING_H
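
A quick consistency check of the defaults above (illustrative): every composite default is a sum of the base dimensions, and avgFactor is simply 1 / q_lora_rank.

```cpp
// Sanity-check sketch for the DeepSeek V3 defaults in MlaTilingData.
#include <cmath>
#include <cstdint>

constexpr uint32_t kQLoraRank = 1536, kKvLoraRank = 512;
constexpr uint32_t kQkNopeHeadDim = 128, kQkRopeHeadDim = 64;

static_assert(kQLoraRank + kKvLoraRank + kQkRopeHeadDim == 2112, "mm1OutSize");
static_assert(kKvLoraRank + kQkRopeHeadDim == 576, "splitSizeOne");
static_assert(kQkNopeHeadDim + kQkRopeHeadDim == 192, "hiddenStrideRope");

int main() {
    // avgFactor default: 1 / 1536 ~= 0.000651041666
    return std::fabs(1.0 / kQLoraRank - 0.000651041666) < 1e-9 ? 0 : 1;
}
```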