remove redundant params in mla_preprocess kernel (#3530)
### What this PR does / why we need it? This pull request removes the redundant parameters `gamma1` and `beta1` (also named `gamma0`/`beta0` in some places) from the `mla_preprocess` kernel and its calling hierarchy. The changes are applied consistently across the C++ kernel code, the bindings, and the Python call sites. The parameters were unused in the lower-level functions, so removing them is a straightforward cleanup. ### Does this PR introduce _any_ user-facing change? Yes. The Python interface of the kernel changes: the `gamma0` and `beta0` parameters are no longer accepted, since they are not needed. ### How was this patch tested? The kernel's unit test was adapted accordingly. - vLLM version: v0.11.0rc3 - vLLM main: https://github.com/vllm-project/vllm/commit/v0.11.0 Signed-off-by: mojave2 <chenchen145@huawei.com>
This commit is contained in:
@@ -294,8 +294,7 @@ class Quant
|
||||
public:
|
||||
__aicore__ inline Quant() {}
|
||||
|
||||
__aicore__ inline void Init(AscendC::GlobalTensor<T> gammaGmTensor, AscendC::GlobalTensor<T> betaGmTensor,
|
||||
AscendC::GlobalTensor<T> quantScaleGmTensor,
|
||||
__aicore__ inline void Init(AscendC::GlobalTensor<T> quantScaleGmTensor,
|
||||
AscendC::GlobalTensor<int8_t> quantOffsetGmTensor,
|
||||
AscendC::GlobalTensor<T> inputGmTensor, AscendC::GlobalTensor<int8_t> outputGmTensor,
|
||||
uint32_t stride, uint32_t num_col, float avg_factor, uint64_t gm_offset,
|
||||
@@ -2037,7 +2036,7 @@ public:
|
||||
this->mlaParams = mlaParams_;
|
||||
}
|
||||
|
||||
__aicore__ inline void Init(GM_ADDR hiddenStateGm, GM_ADDR gamma1Gm, GM_ADDR beta1Gm, GM_ADDR quantScale1Gm,
|
||||
__aicore__ inline void Init(GM_ADDR hiddenStateGm, GM_ADDR quantScale1Gm,
|
||||
GM_ADDR quantOffset1Gm, GM_ADDR wdqkvGm, GM_ADDR bias1Gm, GM_ADDR gamma2Gm,
|
||||
GM_ADDR beta2Gm, GM_ADDR quantScale2Gm, GM_ADDR quantOffset2Gm, GM_ADDR gamma3Gm,
|
||||
GM_ADDR sin1Gm, GM_ADDR cos1Gm, GM_ADDR sin2Gm, GM_ADDR cos2Gm, GM_ADDR keycacheGm,
|
||||
@@ -2057,7 +2056,6 @@ public:
|
||||
mm_w8a8_1.PreloadDoubleWeight();
|
||||
#endif
|
||||
hiddenStateGmTensor.SetGlobalBuffer(reinterpret_cast<__gm__ half *>(hiddenStateGm));
|
||||
gamma1GmTensor.SetGlobalBuffer(reinterpret_cast<__gm__ half *>(gamma1Gm));
|
||||
quantScale1GmTensor.SetGlobalBuffer(reinterpret_cast<__gm__ half *>(quantScale1Gm));
|
||||
quantOffset1GmTensor.SetGlobalBuffer(reinterpret_cast<__gm__ int8_t *>(quantOffset1Gm));
|
||||
|
||||
@@ -2081,7 +2079,6 @@ public:
|
||||
qGmTensor2.SetGlobalBuffer(reinterpret_cast<__gm__ half *>(qGm2));
|
||||
bias2gmTensor.SetGlobalBuffer(reinterpret_cast<__gm__ int32_t *>(bias2Gm));
|
||||
|
||||
beta1GmTensor.SetGlobalBuffer(reinterpret_cast<__gm__ half *>(beta1Gm));
|
||||
beta2GmTensor.SetGlobalBuffer(reinterpret_cast<__gm__ half *>(beta2Gm));
|
||||
#ifdef __DAV_C220_CUBE__
|
||||
mm_w8a8_2.Init(s1GmTensor, wuqGmTensor, bias2gmTensor, descale2gmTensor, s2GmTensor, mlaParams, 1);
|
||||
@@ -2105,7 +2102,7 @@ public:
|
||||
row_work_ = 0;
|
||||
}
|
||||
this->splitN = mlaParams.perTaskNum;
|
||||
Quant1.Init(gamma1GmTensor, beta1GmTensor, quantScale1GmTensor, quantOffset1GmTensor, hiddenStateGmTensor,
|
||||
Quant1.Init(quantScale1GmTensor, quantOffset1GmTensor, hiddenStateGmTensor,
|
||||
s1GmTensor, 0, num_col_1, 0.0001395089285,
|
||||
vectorBlockIdx * static_cast<uint64_t>(row_work) * num_col_1,
|
||||
vectorBlockIdx * static_cast<uint64_t>(row_work) * num_col_1, row_work_, mlaParams);
|
||||
@@ -2316,7 +2313,6 @@ private:
|
||||
|
||||
AscendC::GlobalTensor<half> hiddenStateGmTensor;
|
||||
|
||||
AscendC::GlobalTensor<half> gamma1GmTensor;
|
||||
AscendC::GlobalTensor<half> quantScale1GmTensor;
|
||||
AscendC::GlobalTensor<int8_t> quantOffset1GmTensor;
|
||||
|
||||
@@ -2343,7 +2339,6 @@ private:
|
||||
AscendC::GlobalTensor<half> s3GmTensor;
|
||||
AscendC::GlobalTensor<uint64_t> descale1gmTensor;
|
||||
AscendC::GlobalTensor<uint64_t> descale2gmTensor;
|
||||
AscendC::GlobalTensor<half> beta1GmTensor;
|
||||
AscendC::GlobalTensor<half> beta2GmTensor;
|
||||
|
||||
AscendC::GlobalTensor<int32_t> bias1gmTensor;
|
||||
|
||||
Reference in New Issue
Block a user