Files
xc-llm-ascend/csrc/mla_preprocess/op_host/tiling/mla_preprocess_tiling.h
liuchen2026fly 542258ac9d [feat] parameterize hardcoded MLA dimensions to support GLM5-W8A8 (#6902)
Derive MLA dimension constants (q_lora_rank, qk_nope_head_dim, etc.)
from tensor shapes at runtime instead of hardcoding DeepSeek V3 values.
This enables the mla_preprocess fused op to work with both DeepSeek V3
and GLM5 models without Python API changes.

- Add 9 dimension fields to MlaTilingData with DeepSeek V3 defaults
- Add OpParam fields and make all host-side tiling functions dynamic
- Derive dimensions from wuk, gamma1, kv_cache_rope tensor shapes
- Replace 310+ hardcoded constants across 4 kernel .hpp files
- Remove unused MMSIZE1/MMSIZE2 constants

### What this PR does / why we need it?

### Does this PR introduce _any_ user-facing change?

### How was this patch tested?

- vLLM version: v0.16.0
- vLLM main:
15d76f74e2

---------

Signed-off-by: liuchenbing <chenliumail@163.com>
Co-authored-by: liuchenbing <chenliumail@163.com>
2026-03-09 20:17:21 +08:00

113 lines
3.5 KiB
C++

// Adapted from
// https://gitee.com/ascend/ascend-transformer-boost
//
// Copyright (c) Huawei Technologies Co., Ltd. 2025. All rights reserved.
// This file is a part of the CANN Open Software.
// Licensed under CANN Open Software License Agreement Version 1.0 (the "License").
// Please refer to the License for details. You may not use this file except in compliance with the License.
// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED,
// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE.
// See LICENSE in the root of the software repository for the full text of the License.
//
#ifndef MLAPREPROCESS_TILING_H
#define MLAPREPROCESS_TILING_H
#include <cstdint>
// Host-computed tiling parameters for one ping-pong-pipelined matmul.
// NOTE(review): field layout is a host/device ABI — do not reorder or resize.
struct PpMatmulTilingData {
uint32_t numBatch{0};                // number of independent matmuls in the batch
uint32_t m{0};                       // GEMM dimension M (rows of A / C)
uint32_t k{0};                       // GEMM dimension K (cols of A / rows of B)
uint32_t n{0};                       // GEMM dimension N (cols of B / C)
uint32_t m0{0};                      // per-core tile size along M
uint32_t k0{0};                      // per-core tile size along K
uint32_t n0{0};                      // per-core tile size along N
uint32_t mLoop{0};                   // number of tile iterations along M (ceil(m/m0))
uint32_t kLoop{0};                   // number of tile iterations along K (ceil(k/k0))
uint32_t nLoop{0};                   // number of tile iterations along N (ceil(n/n0))
uint32_t coreLoop{0};                // total tile/work-item count distributed over cores
uint32_t swizzleCount{0};            // tiles per swizzle group for block scheduling
uint32_t swizzleDirect{0};           // swizzle traversal direction flag
uint32_t enShuffleK{0};              // non-zero: shuffle K iteration order across cores
uint32_t blockDim{0};                // number of AI cores the matmul is launched on
uint32_t enLoadAllAmat{0};           // non-zero: whole A matrix fits on-chip and is preloaded
uint32_t b0matPingPongBufferLen{0};  // byte length of the B-tile ping-pong buffer
};
// Aggregate tiling data for the fused MLA (Multi-head Latent Attention)
// preprocess operator: three matmuls, two RmsNorm+quant stages, RoPE, and
// an einsum-quant epilogue. Computed on the host, consumed by the kernels.
// NOTE(review): field layout is a host/device ABI — do not reorder or resize.
// Model-dimension defaults below are the DeepSeek V3 values; they are
// overwritten at runtime from tensor shapes (see commit message above).
struct MlaTilingData {
uint32_t tilingKey{0};               // selects the kernel variant/branch
uint64_t userWorkspaceSize{0};       // total device workspace bytes required
uint64_t s1Offset{0};                // byte offsets of workspace sub-buffers s1..s5
uint64_t s2Offset{0};
uint64_t s3Offset{0};
uint64_t s4Offset{0};
uint64_t s5Offset{0};
uint32_t numCore{0};                 // cores used by the main stage
uint32_t n{0};                       // total task count to distribute
uint32_t perTaskNum{0};              // tasks per core (evenly divided part)
uint32_t resTaskNum{0};              // remainder tasks assigned to leading cores
PpMatmulTilingData mm1;              // tiling for the first matmul
PpMatmulTilingData mm2;              // tiling for the second matmul
PpMatmulTilingData mm3;              // tiling for the third matmul
// rms1: first RmsNorm + quant stage
uint32_t rmsNumCore1{0};             // cores used
uint32_t rmsNumCol1{0};              // columns (normalized dimension)
uint32_t rmsNumRow1{0};              // rows (tokens)
uint32_t rmsQuantMin1{0};            // quantization clamp minimum
// rms2: second RmsNorm + quant stage
uint32_t rmsNumCore2{0};
uint32_t rmsNumCol2{0};
uint32_t rmsNumRow2{0};
uint32_t rmsQuantMin2{0};
uint32_t hiddenSizeQ{0};             // query hidden size (headNumQ * headDim)
uint32_t headNumQ{0};                // number of query heads
uint32_t headDim{0};                 // per-head dimension for RoPE
uint32_t concatSize{0};              // size of the concatenated output slice
uint32_t rotaryCoeff{0};             // RoPE rotary coefficient (pair split factor)
uint32_t ntokens{0};                 // number of input tokens
uint32_t realCore{0};                // cores actually used by the RoPE stage
uint32_t nlCoreRun{0};               // work items per non-last core
uint32_t lCoreRun{0};                // work items on the last core
uint32_t maxNPerLoopForUb{0};        // max items per loop that fit in the unified buffer
uint32_t preCoreLoopTime{0};         // loop count for non-last cores
uint32_t preCoreLoopNLast{0};        // items in the final loop of non-last cores
uint32_t lastCoreLoopTime{0};        // loop count for the last core
uint32_t lastCoreLoopNLast{0};       // items in the final loop of the last core
// EinSumQuant epilogue tiling
uint32_t esqFrontCore{0};            // number of front cores (carry one extra batch)
uint32_t esqTailCore{0};             // number of tail cores
uint32_t esqFrontCoreBatch{0};       // batches per front core
uint32_t esqTailCoreBatch{0};        // batches per tail core
uint32_t esqHeadNum{0};              // heads processed per batch
uint32_t esqColNum{0};               // columns per head
uint32_t esqUbHeadLoop{0};           // UB-sized loops over heads
uint32_t esqHeadPerLoop{0};          // heads per UB loop
uint32_t esqHeadTail{0};             // leftover heads after full loops
uint32_t esqColLoop{0};              // loops over columns
uint32_t esqColTail{0};              // leftover columns after full loops
// hidden state dimension (default: DeepSeek V3 hidden size)
uint32_t hiddenStateDim{7168};
uint32_t isWeightQuantized{1};       // non-zero: weights are quantized (W8A8 path)
// Model-specific MLA dimensions (derived from tensor shapes at runtime;
// defaults are the DeepSeek V3 values: q_lora_rank=1536, kv_lora_rank=512,
// qk_rope_head_dim=64, qk_nope_head_dim=128)
uint32_t mm1OutSize{2112};           // q_lora_rank + kv_lora_rank + qk_rope_head_dim
uint32_t splitSizeOne{576};          // kv_lora_rank + qk_rope_head_dim
uint32_t splitSizeTwo{1536};         // q_lora_rank
uint32_t splitRmsNormSizeOne{512};   // kv_lora_rank
uint32_t splitRmsNormSizeTwo{64};    // qk_rope_head_dim
uint32_t ropeSplitSizeOne{64};       // qk_rope_head_dim
uint32_t ropeSplitSizeTwo{128};      // qk_nope_head_dim
uint32_t hiddenStrideRope{192};      // qk_nope_head_dim + qk_rope_head_dim
uint32_t qkNopeHeadDim{128};         // for RoPE offset calc
float avgFactor{0.000651041666f};    // 1/splitSizeTwo (1/q_lora_rank), for RmsNorm mean
};
#endif // MLAPREPROCESS_TILING_H