[feat] parameterize hardcoded MLA dimensions to support GLM5-W8A8 (#6902)

Derive MLA dimension constants (q_lora_rank, qk_nope_head_dim, etc.)
from tensor shapes at runtime instead of hardcoding DeepSeek V3 values.
This enables the mla_preprocess fused op to work with both DeepSeek V3
and GLM5 models without Python API changes.

- Add 9 dimension fields to MlaTilingData with DeepSeek V3 defaults
- Add OpParam fields and dynamize all host-side tiling functions
- Derive dimensions from wuk, gamma1, kv_cache_rope tensor shapes
- Replace 310+ hardcoded constants across 4 kernel .hpp files
- Remove unused MMSIZE1/MMSIZE2 constants

### What this PR does / why we need it?

### Does this PR introduce _any_ user-facing change?

### How was this patch tested?

- vLLM version: v0.16.0
- vLLM main:
15d76f74e2

---------

Signed-off-by: liuchenbing <chenliumail@163.com>
Co-authored-by: liuchenbing <chenliumail@163.com>
This commit is contained in:
liuchen2026fly
2026-03-09 20:17:21 +08:00
committed by GitHub
parent 13adcbe44b
commit 542258ac9d
9 changed files with 508 additions and 342 deletions

View File

@@ -95,6 +95,18 @@ struct MlaTilingData {
uint32_t hiddenStateDim{7168};
uint32_t isWeightQuantized{1};
// Model-specific MLA dimensions (derived from tensor shapes)
uint32_t mm1OutSize{2112}; // q_lora_rank + kv_lora_rank + qk_rope_head_dim
uint32_t splitSizeOne{576}; // kv_lora_rank + qk_rope_head_dim
uint32_t splitSizeTwo{1536}; // q_lora_rank
uint32_t splitRmsNormSizeOne{512}; // kv_lora_rank
uint32_t splitRmsNormSizeTwo{64}; // qk_rope_head_dim
uint32_t ropeSplitSizeOne{64}; // qk_rope_head_dim
uint32_t ropeSplitSizeTwo{128}; // qk_nope_head_dim
uint32_t hiddenStrideRope{192}; // qk_nope_head_dim + qk_rope_head_dim
uint32_t qkNopeHeadDim{128}; // for RoPE offset calc
float avgFactor{0.000651041666f}; // 1/splitSizeTwo (1/qLoraRank), for RmsNorm avg
};
#endif // MLAPREPROCESS_TILING_H