remove redundant params in mla_preprocess kernel (#3530)

### What this PR does / why we need it?

This pull request removes the redundant parameters `gamma1` and `beta1`
(also named `gamma0`/`beta0` in some places) from the `mla_preprocess`
kernel and its calling hierarchy. The changes are consistent across C++
kernel code, bindings, and Python call sites. The parameters were unused
in the lower-level functions, so their removal is a good cleanup.

### Does this PR introduce _any_ user-facing change?

The Python interface of the kernel is affected: the `gamma0` and `beta0`
parameters are no longer needed.

### How was this patch tested?

The kernel's unit test was adapted accordingly.


- vLLM version: v0.11.0rc3
- vLLM main: https://github.com/vllm-project/vllm/commit/v0.11.0

Signed-off-by: mojave2 <chenchen145@huawei.com>
This commit is contained in:
Chen Chen
2025-10-21 19:20:13 +08:00
committed by GitHub
parent 80b8df881f
commit 6b290acfe1
9 changed files with 34 additions and 50 deletions

View File

@@ -2388,7 +2388,7 @@ public:
this->mlaParams = mlaParams_;
}
__aicore__ inline void Init(GM_ADDR hiddenStateGm, GM_ADDR gamma1Gm, GM_ADDR beta1Gm, GM_ADDR quantScale1Gm,
__aicore__ inline void Init(GM_ADDR hiddenStateGm, GM_ADDR quantScale1Gm,
GM_ADDR quantOffset1Gm, GM_ADDR wdqkvGm, GM_ADDR bias1Gm, GM_ADDR gamma2Gm,
GM_ADDR beta2Gm, GM_ADDR quantScale2Gm, GM_ADDR quantOffset2Gm, GM_ADDR gamma3Gm,
GM_ADDR sin1Gm, GM_ADDR cos1Gm, GM_ADDR sin2Gm, GM_ADDR cos2Gm, GM_ADDR keycacheGm,
@@ -2426,7 +2426,6 @@ public:
#endif
hiddenStateGmTensor.SetGlobalBuffer(reinterpret_cast<__gm__ InDtype *>(hiddenStateGm));
gamma1GmTensor.SetGlobalBuffer(reinterpret_cast<__gm__ InDtype *>(gamma1Gm));
quantScale1GmTensor.SetGlobalBuffer(reinterpret_cast<__gm__ InDtype *>(quantScale1Gm));
quantOffset1GmTensor.SetGlobalBuffer(reinterpret_cast<__gm__ int8_t *>(quantOffset1Gm));
wdqkvGmTensor.SetGlobalBuffer(reinterpret_cast<__gm__ int8_t *>(wdqkvGm));
@@ -2444,7 +2443,6 @@ public:
qGmTensor2.SetGlobalBuffer(reinterpret_cast<__gm__ InDtype *>(qGm2));
bias1gmTensor.SetGlobalBuffer(reinterpret_cast<__gm__ int32_t *>(bias1Gm));
bias2gmTensor.SetGlobalBuffer(reinterpret_cast<__gm__ int32_t *>(bias2Gm));
beta1GmTensor.SetGlobalBuffer(reinterpret_cast<__gm__ InDtype *>(beta1Gm));
beta2GmTensor.SetGlobalBuffer(reinterpret_cast<__gm__ InDtype *>(beta2Gm));
#ifdef __DAV_C220_VEC__
@@ -2711,7 +2709,6 @@ private:
AscendC::GlobalTensor<InDtype> hiddenStateGmTensor;
AscendC::GlobalTensor<InDtype> gamma1GmTensor;
AscendC::GlobalTensor<InDtype> quantScale1GmTensor;
AscendC::GlobalTensor<int8_t> quantOffset1GmTensor;
@@ -2741,7 +2738,6 @@ private:
AscendC::GlobalTensor<float> s5GmTensor;
AscendC::GlobalTensor<float> descale1gmTensor;
AscendC::GlobalTensor<float> descale2gmTensor;
AscendC::GlobalTensor<InDtype> beta1GmTensor;
AscendC::GlobalTensor<InDtype> beta2GmTensor;
AscendC::GlobalTensor<int32_t> bias1gmTensor;