[feat] parameterize hardcoded MLA dimensions to support GLM5-W8A8 (#6902)

Derive MLA dimension constants (q_lora_rank, qk_nope_head_dim, etc.)
from tensor shapes at runtime instead of hardcoding DeepSeek V3 values.
This enables the mla_preprocess fused op to work with both DeepSeek V3
and GLM5 models without Python API changes.

- Add 9 dimension fields to MlaTilingData with DeepSeek V3 defaults
- Add OpParam fields and dynamize all host-side tiling functions
- Derive dimensions from wuk, gamma1, kv_cache_rope tensor shapes
- Replace 310+ hardcoded constants across 4 kernel .hpp files
- Remove unused MMSIZE1/MMSIZE2 constants

### What this PR does / why we need it?

### Does this PR introduce _any_ user-facing change?

### How was this patch tested?

- vLLM version: v0.16.0
- vLLM main:
15d76f74e2

---------

Signed-off-by: liuchenbing <chenliumail@163.com>
Co-authored-by: liuchenbing <chenliumail@163.com>
This commit is contained in:
liuchen2026fly
2026-03-09 20:17:21 +08:00
committed by GitHub
parent 13adcbe44b
commit 542258ac9d
9 changed files with 508 additions and 342 deletions

View File

@@ -136,6 +136,18 @@ extern "C" __global__ __aicore__ void mla_preprocess(
mlaTilingData.s4Offset = tilingData->s4Offset;
mlaTilingData.s5Offset = tilingData->s5Offset;
// Model-specific MLA dimensions
mlaTilingData.mm1OutSize = tilingData->mm1OutSize;
mlaTilingData.splitSizeOne = tilingData->splitSizeOne;
mlaTilingData.splitSizeTwo = tilingData->splitSizeTwo;
mlaTilingData.splitRmsNormSizeOne = tilingData->splitRmsNormSizeOne;
mlaTilingData.splitRmsNormSizeTwo = tilingData->splitRmsNormSizeTwo;
mlaTilingData.ropeSplitSizeOne = tilingData->ropeSplitSizeOne;
mlaTilingData.ropeSplitSizeTwo = tilingData->ropeSplitSizeTwo;
mlaTilingData.hiddenStrideRope = tilingData->hiddenStrideRope;
mlaTilingData.qkNopeHeadDim = tilingData->qkNopeHeadDim;
mlaTilingData.avgFactor = tilingData->avgFactor;
GM_ADDR s1 = workspace + static_cast<uint64_t>(mlaTilingData.s1Offset);
GM_ADDR s2 = workspace + static_cast<uint64_t>(mlaTilingData.s2Offset);
GM_ADDR s3 = workspace + static_cast<uint64_t>(mlaTilingData.s3Offset);