[feat] parameterize hardcoded MLA dimensions to support GLM5-W8A8 (#6902)
Derive MLA dimension constants (q_lora_rank, qk_nope_head_dim, etc.)
from tensor shapes at runtime instead of hardcoding DeepSeek V3 values.
This enables the mla_preprocess fused op to work with both DeepSeek V3
and GLM5 models without Python API changes.
- Add 9 dimension fields to MlaTilingData with DeepSeek V3 defaults
- Add OpParam fields and parameterize all host-side tiling functions
- Derive dimensions from wuk, gamma1, kv_cache_rope tensor shapes
- Replace 310+ hardcoded constants across 4 kernel .hpp files
- Remove unused MMSIZE1/MMSIZE2 constants
### What this PR does / why we need it?
### Does this PR introduce _any_ user-facing change?
### How was this patch tested?
- vLLM version: v0.16.0
- vLLM main: 15d76f74e2
---------
Signed-off-by: liuchenbing <chenliumail@163.com>
Co-authored-by: liuchenbing <chenliumail@163.com>
This commit is contained in:
@@ -129,6 +129,11 @@ struct OpParam {
|
||||
QuantMode quantMode;
|
||||
caffe2::TypeMeta inDtype;
|
||||
bool enableInnerOut;
|
||||
// MLA dimensions derived from tensor shapes
|
||||
uint32_t qLoraRank;
|
||||
uint32_t qkNopeHeadDim;
|
||||
uint32_t qkRopeHeadDim;
|
||||
uint32_t kvLoraRank;
|
||||
};
|
||||
|
||||
class PpMatmulTilingApi
|
||||
@@ -397,7 +402,7 @@ void MlaPreprocessTiling::RmsNormQuantTiling()
|
||||
tilingData->rmsNumRow1 = opParam.N;
|
||||
tilingData->rmsQuantMin1 = -CONST_128;
|
||||
tilingData->rmsNumCore2 = platformInfo.coreNumAiv;
|
||||
tilingData->rmsNumCol2 = HIDDEN_STRATE_MM;
|
||||
tilingData->rmsNumCol2 = opParam.qLoraRank + opParam.kvLoraRank + opParam.qkRopeHeadDim;
|
||||
tilingData->rmsNumRow2 = opParam.N;
|
||||
tilingData->rmsQuantMin2 = -CONST_128;
|
||||
}
|
||||
@@ -405,10 +410,10 @@ void MlaPreprocessTiling::RmsNormQuantTiling()
|
||||
void MlaPreprocessTiling::RopeConcatTiling()
|
||||
{
|
||||
uint32_t ntokens = opParam.N;
|
||||
uint32_t hiddenSizeQ = HEADDIM * opParam.headNum;
|
||||
uint32_t headDim = HEADDIM;
|
||||
uint32_t hiddenSizeQ = opParam.qkRopeHeadDim * opParam.headNum;
|
||||
uint32_t headDim = opParam.qkRopeHeadDim;
|
||||
uint32_t headNumQ = hiddenSizeQ / headDim;
|
||||
uint32_t concatSize = CONCAT_SIZE;
|
||||
uint32_t concatSize = opParam.kvLoraRank;
|
||||
uint32_t maxCore = platformInfo.coreNumAiv;
|
||||
uint32_t maxUbSize = platformInfo.ubSize;
|
||||
|
||||
@@ -458,7 +463,7 @@ void MlaPreprocessTiling::EinSumQuantTiling()
|
||||
// input shape
|
||||
uint32_t esqBatch = opParam.N; // tokenNum
|
||||
uint32_t esqHeadNum = opParam.headNum; // headNum
|
||||
uint32_t esqColNum = AXES_ALIGN_SIZE; // 512
|
||||
uint32_t esqColNum = opParam.kvLoraRank; // kv_lora_rank
|
||||
|
||||
// split core
|
||||
uint32_t esqFrontCore = esqBatch % aivCore;
|
||||
@@ -508,14 +513,16 @@ void MlaPreprocessTiling::EinSumQuantTiling()
|
||||
|
||||
void MlaPreprocessTiling::SetMlapoWorkSpace()
|
||||
{
|
||||
uint32_t hiddenStrideRope = opParam.qkNopeHeadDim + opParam.qkRopeHeadDim;
|
||||
uint32_t hiddenStrateMm = opParam.qLoraRank + opParam.kvLoraRank + opParam.qkRopeHeadDim;
|
||||
uint64_t s1wsFactor =
|
||||
static_cast<uint64_t>(opParam.cacheMode == 2 ? std::max(opParam.hiddenStateDim * sizeof(int8_t),
|
||||
opParam.headNum * AXES_ALIGN_SIZE * sizeof(uint16_t))
|
||||
opParam.headNum * opParam.kvLoraRank * sizeof(uint16_t))
|
||||
: opParam.hiddenStateDim * sizeof(int8_t));
|
||||
uint64_t workSizeS1 = s1wsFactor;
|
||||
uint64_t workSizeS2 = opParam.headNum * HIDDEN_STRATE_ROPE * sizeof(uint16_t);
|
||||
uint64_t workSizeS3 = HIDDEN_STRATE_MM * sizeof(uint16_t);
|
||||
uint64_t workSizeS4 = std::max(opParam.headNum * HIDDEN_STRATE_ROPE, HIDDEN_STRATE_MM) * sizeof(uint32_t);
|
||||
uint64_t workSizeS2 = opParam.headNum * hiddenStrideRope * sizeof(uint16_t);
|
||||
uint64_t workSizeS3 = hiddenStrateMm * sizeof(uint16_t);
|
||||
uint64_t workSizeS4 = std::max(opParam.headNum * hiddenStrideRope, hiddenStrateMm) * sizeof(uint32_t);
|
||||
|
||||
uint64_t maxWorkspaceSize = workSizeS1;
|
||||
maxWorkspaceSize = std::max(maxWorkspaceSize, workSizeS2);
|
||||
@@ -564,11 +571,17 @@ void MlaPreprocessTiling::Init()
|
||||
deqOnTheFly = true;
|
||||
}
|
||||
|
||||
uint32_t mm1N = opParam.qLoraRank + opParam.kvLoraRank + opParam.qkRopeHeadDim;
|
||||
uint32_t mm2K = opParam.qLoraRank;
|
||||
uint32_t mm2N = opParam.headNum * (opParam.qkNopeHeadDim + opParam.qkRopeHeadDim);
|
||||
uint32_t mm3K = opParam.qkNopeHeadDim;
|
||||
uint32_t mm3N = opParam.kvLoraRank;
|
||||
|
||||
PpMatmulTilingApi mm1TilingApi(platformInfo,
|
||||
1, // numBatch
|
||||
opParam.N, // m
|
||||
opParam.hiddenStateDim, // k
|
||||
HIDDEN_STRATE_MM, // n
|
||||
mm1N, // n
|
||||
false, // transA
|
||||
true, // transB
|
||||
enDequant, // enDequant
|
||||
@@ -576,21 +589,21 @@ void MlaPreprocessTiling::Init()
|
||||
mm1TilingApi.GetTilingData(tilingData->mm1);
|
||||
|
||||
PpMatmulTilingApi mm2TilingApi(platformInfo,
|
||||
1, // numBatch
|
||||
opParam.N, // m
|
||||
HIDDEN_STRATE_RMS, // k
|
||||
opParam.headNum * HIDDEN_STRATE_ROPE, // n
|
||||
false, // transA
|
||||
true, // transB
|
||||
enDequant, // enDequant
|
||||
deqOnTheFly); // in bf16.cce?
|
||||
1, // numBatch
|
||||
opParam.N, // m
|
||||
mm2K, // k
|
||||
mm2N, // n
|
||||
false, // transA
|
||||
true, // transB
|
||||
enDequant, // enDequant
|
||||
deqOnTheFly); // in bf16.cce?
|
||||
mm2TilingApi.GetTilingData(tilingData->mm2);
|
||||
|
||||
PpMatmulTilingApi mm3TilingApi(platformInfo,
|
||||
opParam.headNum, // numBatch
|
||||
opParam.N, // m
|
||||
CONST_128, // k
|
||||
CONCAT_SIZE, // n
|
||||
mm3K, // k
|
||||
mm3N, // n
|
||||
false, // transA
|
||||
false, // transB
|
||||
false, // enDequant
|
||||
@@ -604,6 +617,18 @@ void MlaPreprocessTiling::Init()
|
||||
SetMlapoWorkSpace();
|
||||
SetTilingKey();
|
||||
|
||||
// Populate model-specific MLA dimension fields
|
||||
tilingData->mm1OutSize = opParam.qLoraRank + opParam.kvLoraRank + opParam.qkRopeHeadDim;
|
||||
tilingData->splitSizeOne = opParam.kvLoraRank + opParam.qkRopeHeadDim;
|
||||
tilingData->splitSizeTwo = opParam.qLoraRank;
|
||||
tilingData->splitRmsNormSizeOne = opParam.kvLoraRank;
|
||||
tilingData->splitRmsNormSizeTwo = opParam.qkRopeHeadDim;
|
||||
tilingData->ropeSplitSizeOne = opParam.qkRopeHeadDim;
|
||||
tilingData->ropeSplitSizeTwo = opParam.qkNopeHeadDim;
|
||||
tilingData->hiddenStrideRope = opParam.qkNopeHeadDim + opParam.qkRopeHeadDim;
|
||||
tilingData->qkNopeHeadDim = opParam.qkNopeHeadDim;
|
||||
tilingData->avgFactor = 1.0f / static_cast<float>(opParam.qLoraRank);
|
||||
|
||||
return;
|
||||
}
|
||||
|
||||
@@ -631,6 +656,8 @@ std::tuple<at::Tensor, at::Tensor, uint32_t> mla_preprocess_tiling(
|
||||
const at::Tensor &hiddenState,
|
||||
const at::Tensor &wdqkv,
|
||||
const at::Tensor &wuk,
|
||||
const at::Tensor &gamma1,
|
||||
const at::Tensor &kv_cache_rope,
|
||||
c10::optional<c10::string_view> cache_mode,
|
||||
c10::optional<c10::string_view> quant_mode,
|
||||
bool enable_inner_out
|
||||
@@ -656,6 +683,12 @@ std::tuple<at::Tensor, at::Tensor, uint32_t> mla_preprocess_tiling(
|
||||
int32_t headNum = wuk.sizes()[0];
|
||||
uint32_t hiddenStateDim = hiddenState.sizes().back();
|
||||
|
||||
// Derive MLA dimensions from tensor shapes
|
||||
uint32_t qkNopeHeadDim = wuk.sizes()[1];
|
||||
uint32_t kvLoraRank = wuk.sizes()[2];
|
||||
uint32_t qLoraRank = gamma1.sizes()[0];
|
||||
uint32_t qkRopeHeadDim = kv_cache_rope.sizes().back();
|
||||
|
||||
OpParam opParam;
|
||||
opParam.hiddenStateDim = hiddenStateDim;
|
||||
opParam.N = N;
|
||||
@@ -664,6 +697,10 @@ std::tuple<at::Tensor, at::Tensor, uint32_t> mla_preprocess_tiling(
|
||||
opParam.quantMode = static_cast<QuantMode>(quantMode);
|
||||
opParam.inDtype = hiddenState.options().dtype();
|
||||
opParam.enableInnerOut = enable_inner_out;
|
||||
opParam.qLoraRank = qLoraRank;
|
||||
opParam.qkNopeHeadDim = qkNopeHeadDim;
|
||||
opParam.qkRopeHeadDim = qkRopeHeadDim;
|
||||
opParam.kvLoraRank = kvLoraRank;
|
||||
if (wdqkv.options().dtype() == at::kBFloat16 || wdqkv.options().dtype() == at::kHalf) {
|
||||
opParam.isWeightQuantized = 0;
|
||||
} else {
|
||||
|
||||
@@ -95,6 +95,18 @@ struct MlaTilingData {
|
||||
uint32_t hiddenStateDim{7168};
|
||||
|
||||
uint32_t isWeightQuantized{1};
|
||||
|
||||
// Model-specific MLA dimensions (derived from tensor shapes)
|
||||
uint32_t mm1OutSize{2112}; // q_lora_rank + kv_lora_rank + qk_rope_head_dim
|
||||
uint32_t splitSizeOne{576}; // kv_lora_rank + qk_rope_head_dim
|
||||
uint32_t splitSizeTwo{1536}; // q_lora_rank
|
||||
uint32_t splitRmsNormSizeOne{512}; // kv_lora_rank
|
||||
uint32_t splitRmsNormSizeTwo{64}; // qk_rope_head_dim
|
||||
uint32_t ropeSplitSizeOne{64}; // qk_rope_head_dim
|
||||
uint32_t ropeSplitSizeTwo{128}; // qk_nope_head_dim
|
||||
uint32_t hiddenStrideRope{192}; // qk_nope_head_dim + qk_rope_head_dim
|
||||
uint32_t qkNopeHeadDim{128}; // for RoPE offset calc
|
||||
float avgFactor{0.000651041666f}; // 1/splitSizeTwo (1/qLoraRank), for RmsNorm avg
|
||||
};
|
||||
|
||||
#endif // MLAPREPROCESS_TILING_H
|
||||
|
||||
Reference in New Issue
Block a user