[feat] parameterize hardcoded MLA dimensions to support GLM5-W8A8 (#6902)
Derive MLA dimension constants (q_lora_rank, qk_nope_head_dim, etc.)
from tensor shapes at runtime instead of hardcoding DeepSeek V3 values.
This enables the mla_preprocess fused op to work with both DeepSeek V3
and GLM5 models without Python API changes.
- Add 9 dimension fields to MlaTilingData with DeepSeek V3 defaults
- Add OpParam fields and make all host-side tiling functions use them instead of fixed sizes
- Derive dimensions from wuk, gamma1, kv_cache_rope tensor shapes
- Replace 310+ hardcoded constants across 4 kernel .hpp files
- Remove unused MMSIZE1/MMSIZE2 constants
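
The shape-based derivation described above might look roughly like the host-side sketch below. It is not the code from this commit. `MlaDims` and `DeriveMlaDims` are hypothetical names, the field names are borrowed from the kernel diff, and the tensor layouts (wuk as [num_heads, qk_nope_head_dim, kv_lora_rank], gamma1 as [q_lora_rank], kv_cache_rope with the rope head dim on its last axis) as well as every formula are assumptions chosen to reproduce the usual DeepSeek V3 MLA sizes (1536, 512, 128, 64).

```cpp
#include <cstdint>
#include <stdexcept>
#include <vector>

using Shape = std::vector<int64_t>;

// Hypothetical container. Defaults are assumed to be the DeepSeek V3 values.
struct MlaDims {
    uint32_t mm1OutSize = 2112;          // q_lora_rank + kv_lora_rank + rope dim
    uint32_t splitSizeOne = 576;         // kv_lora_rank + rope dim
    uint32_t splitSizeTwo = 1536;        // q_lora_rank
    uint32_t splitRmsNormSizeOne = 512;  // kv_lora_rank
    uint32_t splitRmsNormSizeTwo = 64;   // rope dim
    uint32_t ropeSplitSizeOne = 64;      // rope dim
    uint32_t ropeSplitSizeTwo = 128;     // qk_nope_head_dim
    uint32_t hiddenStrideRope = 192;     // qk_nope_head_dim + rope dim
    uint32_t qkNopeHeadDim = 128;
    float avgFactor = 1.0f / 512.0f;     // 1 / kv_lora_rank
};

// Derive the dimensions from tensor shapes (assumed layouts, see lead-in).
MlaDims DeriveMlaDims(const Shape& wuk, const Shape& gamma1,
                      const Shape& kvCacheRope) {
    if (wuk.size() != 3 || gamma1.empty() || kvCacheRope.empty()) {
        throw std::invalid_argument("unexpected tensor rank");
    }
    if (wuk[1] <= 0 || wuk[2] <= 0 || gamma1.back() <= 0 ||
        kvCacheRope.back() <= 0) {
        throw std::invalid_argument("dimensions must be positive");
    }
    const auto nope = static_cast<uint32_t>(wuk[1]);
    const auto kvLora = static_cast<uint32_t>(wuk[2]);
    const auto qLora = static_cast<uint32_t>(gamma1.back());
    const auto rope = static_cast<uint32_t>(kvCacheRope.back());

    MlaDims d;
    d.mm1OutSize = qLora + kvLora + rope;
    d.splitSizeOne = kvLora + rope;
    d.splitSizeTwo = qLora;
    d.splitRmsNormSizeOne = kvLora;
    d.splitRmsNormSizeTwo = rope;
    d.ropeSplitSizeOne = rope;
    d.ropeSplitSizeTwo = nope;
    d.hiddenStrideRope = nope + rope;
    d.qkNopeHeadDim = nope;
    d.avgFactor = 1.0f / static_cast<float>(kvLora);
    return d;
}
```

Under these assumptions, DeepSeek V3 shaped inputs reproduce the struct defaults, and a model with different ranks or head dims gets its own values without any change to the caller's API.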
### What this PR does / why we need it?
### Does this PR introduce _any_ user-facing change?
### How was this patch tested?
- vLLM version: v0.16.0
- vLLM main: 15d76f74e2
---------
Signed-off-by: liuchenbing <chenliumail@163.com>
Co-authored-by: liuchenbing <chenliumail@163.com>
@@ -136,6 +136,18 @@ extern "C" __global__ __aicore__ void mla_preprocess(
    mlaTilingData.s4Offset = tilingData->s4Offset;
    mlaTilingData.s5Offset = tilingData->s5Offset;

    // Model-specific MLA dimensions
    mlaTilingData.mm1OutSize = tilingData->mm1OutSize;
    mlaTilingData.splitSizeOne = tilingData->splitSizeOne;
    mlaTilingData.splitSizeTwo = tilingData->splitSizeTwo;
    mlaTilingData.splitRmsNormSizeOne = tilingData->splitRmsNormSizeOne;
    mlaTilingData.splitRmsNormSizeTwo = tilingData->splitRmsNormSizeTwo;
    mlaTilingData.ropeSplitSizeOne = tilingData->ropeSplitSizeOne;
    mlaTilingData.ropeSplitSizeTwo = tilingData->ropeSplitSizeTwo;
    mlaTilingData.hiddenStrideRope = tilingData->hiddenStrideRope;
    mlaTilingData.qkNopeHeadDim = tilingData->qkNopeHeadDim;
    mlaTilingData.avgFactor = tilingData->avgFactor;

    GM_ADDR s1 = workspace + static_cast<uint64_t>(mlaTilingData.s1Offset);
    GM_ADDR s2 = workspace + static_cast<uint64_t>(mlaTilingData.s2Offset);
    GM_ADDR s3 = workspace + static_cast<uint64_t>(mlaTilingData.s3Offset);