remove redundant params in mla_preprocess kernel (#3530)
### What this PR does / why we need it? This pull request removes the redundant parameters `gamma1` and `beta1` (also named `gamma0`/`beta0` in some places) from the `mla_preprocess` kernel and its calling hierarchy. The changes are applied consistently across the C++ kernel code, the bindings, and the Python call sites. The parameters were unused in the lower-level functions, so removing them is a straightforward cleanup. ### Does this PR introduce _any_ user-facing change? Yes. The Python interface of the kernel changes: the `gamma0` and `beta0` parameters are no longer accepted, since they were never used. ### How was this patch tested? The kernel's unit test was adapted accordingly. - vLLM version: v0.11.0rc3 - vLLM main: https://github.com/vllm-project/vllm/commit/v0.11.0 Signed-off-by: mojave2 <chenchen145@huawei.com>
This commit is contained in:
@@ -19,7 +19,7 @@
|
|||||||
#include "../op_host/tiling/mla_preprocess_tiling.h"
|
#include "../op_host/tiling/mla_preprocess_tiling.h"
|
||||||
|
|
||||||
extern "C" __global__ __aicore__ void mla_preprocess(
|
extern "C" __global__ __aicore__ void mla_preprocess(
|
||||||
GM_ADDR hiddenState, GM_ADDR gamma1, GM_ADDR beta1, GM_ADDR quantScale1, GM_ADDR quantOffset1, GM_ADDR wdqkv,
|
GM_ADDR hiddenState, GM_ADDR quantScale1, GM_ADDR quantOffset1, GM_ADDR wdqkv,
|
||||||
GM_ADDR bias1, GM_ADDR gamma2, GM_ADDR beta2, GM_ADDR quantScale2, GM_ADDR quantOffset2, GM_ADDR gamma3,
|
GM_ADDR bias1, GM_ADDR gamma2, GM_ADDR beta2, GM_ADDR quantScale2, GM_ADDR quantOffset2, GM_ADDR gamma3,
|
||||||
GM_ADDR sin1, GM_ADDR cos1, GM_ADDR sin2, GM_ADDR cos2, GM_ADDR keycache, GM_ADDR slotMapping, GM_ADDR wuq,
|
GM_ADDR sin1, GM_ADDR cos1, GM_ADDR sin2, GM_ADDR cos2, GM_ADDR keycache, GM_ADDR slotMapping, GM_ADDR wuq,
|
||||||
GM_ADDR bias2, GM_ADDR wuk, GM_ADDR descale1, GM_ADDR descale2, GM_ADDR ctkvScale, GM_ADDR qnopeScale, GM_ADDR q,
|
GM_ADDR bias2, GM_ADDR wuk, GM_ADDR descale1, GM_ADDR descale2, GM_ADDR ctkvScale, GM_ADDR qnopeScale, GM_ADDR q,
|
||||||
@@ -143,7 +143,7 @@ extern "C" __global__ __aicore__ void mla_preprocess(
|
|||||||
case KEY_FP16_CACHEMODE_0_QUANTMODE_0: {
|
case KEY_FP16_CACHEMODE_0_QUANTMODE_0: {
|
||||||
MLAPO_FP16::MLAOperation<CACHE_MODE_KVCACHE, DataFormat::NZ, DataFormat::NZ, DataFormat::ND> opFp16Cm0Qm0(
|
MLAPO_FP16::MLAOperation<CACHE_MODE_KVCACHE, DataFormat::NZ, DataFormat::NZ, DataFormat::ND> opFp16Cm0Qm0(
|
||||||
mlaTilingData, tiling);
|
mlaTilingData, tiling);
|
||||||
opFp16Cm0Qm0.Init(hiddenState, gamma1, beta1, quantScale1, quantOffset1, wdqkv, bias1, gamma2, beta2,
|
opFp16Cm0Qm0.Init(hiddenState, quantScale1, quantOffset1, wdqkv, bias1, gamma2, beta2,
|
||||||
quantScale2, quantOffset2, gamma3, sin1, cos1, sin2, cos2, keycache, slotMapping, wuq,
|
quantScale2, quantOffset2, gamma3, sin1, cos1, sin2, cos2, keycache, slotMapping, wuq,
|
||||||
bias2, wuk, descale1, descale2, ctkvScale, qnopeScale, q, keycacheOut, q2, keycacheOut2,
|
bias2, wuk, descale1, descale2, ctkvScale, qnopeScale, q, keycacheOut, q2, keycacheOut2,
|
||||||
s1, s2, s3);
|
s1, s2, s3);
|
||||||
@@ -158,7 +158,7 @@ extern "C" __global__ __aicore__ void mla_preprocess(
|
|||||||
case KEY_FP16_CACHEMODE_1_QUANTMODE_0: {
|
case KEY_FP16_CACHEMODE_1_QUANTMODE_0: {
|
||||||
MLAPO_FP16::MLAOperation<CACHE_MODE_KROPE_CTKV, DataFormat::NZ, DataFormat::NZ, DataFormat::ND>
|
MLAPO_FP16::MLAOperation<CACHE_MODE_KROPE_CTKV, DataFormat::NZ, DataFormat::NZ, DataFormat::ND>
|
||||||
opFp16Cm1Qm0(mlaTilingData, tiling);
|
opFp16Cm1Qm0(mlaTilingData, tiling);
|
||||||
opFp16Cm1Qm0.Init(hiddenState, gamma1, beta1, quantScale1, quantOffset1, wdqkv, bias1, gamma2, beta2,
|
opFp16Cm1Qm0.Init(hiddenState, quantScale1, quantOffset1, wdqkv, bias1, gamma2, beta2,
|
||||||
quantScale2, quantOffset2, gamma3, sin1, cos1, sin2, cos2, keycache, slotMapping, wuq,
|
quantScale2, quantOffset2, gamma3, sin1, cos1, sin2, cos2, keycache, slotMapping, wuq,
|
||||||
bias2, wuk, descale1, descale2, ctkvScale, qnopeScale, q, keycacheOut, q2, keycacheOut2,
|
bias2, wuk, descale1, descale2, ctkvScale, qnopeScale, q, keycacheOut, q2, keycacheOut2,
|
||||||
s1, s2, s3);
|
s1, s2, s3);
|
||||||
@@ -174,7 +174,7 @@ extern "C" __global__ __aicore__ void mla_preprocess(
|
|||||||
MLAPO_BF16::MLAOperation<__bf16, 0, DataFormat::NZ, DataFormat::NZ, DataFormat::ND,
|
MLAPO_BF16::MLAOperation<__bf16, 0, DataFormat::NZ, DataFormat::NZ, DataFormat::ND,
|
||||||
QuantMode::PER_TENSOR_ASYMM_QUANT>
|
QuantMode::PER_TENSOR_ASYMM_QUANT>
|
||||||
opBf16Cm0Qm0(mlaTilingData, tiling);
|
opBf16Cm0Qm0(mlaTilingData, tiling);
|
||||||
opBf16Cm0Qm0.Init(hiddenState, gamma1, beta1, quantScale1, quantOffset1, wdqkv, bias1, gamma2, beta2,
|
opBf16Cm0Qm0.Init(hiddenState, quantScale1, quantOffset1, wdqkv, bias1, gamma2, beta2,
|
||||||
quantScale2, quantOffset2, gamma3, sin1, cos1, sin2, cos2, keycache, slotMapping, wuq,
|
quantScale2, quantOffset2, gamma3, sin1, cos1, sin2, cos2, keycache, slotMapping, wuq,
|
||||||
bias2, wuk, descale1, descale2, ctkvScale, qnopeScale, q, keycacheOut, q2, keycacheOut2,
|
bias2, wuk, descale1, descale2, ctkvScale, qnopeScale, q, keycacheOut, q2, keycacheOut2,
|
||||||
s1, s2, s3, s4, s5);
|
s1, s2, s3, s4, s5);
|
||||||
@@ -190,7 +190,7 @@ extern "C" __global__ __aicore__ void mla_preprocess(
|
|||||||
MLAPO_BF16::MLAOperation<__bf16, 1, DataFormat::NZ, DataFormat::NZ, DataFormat::ND,
|
MLAPO_BF16::MLAOperation<__bf16, 1, DataFormat::NZ, DataFormat::NZ, DataFormat::ND,
|
||||||
QuantMode::PER_TENSOR_ASYMM_QUANT>
|
QuantMode::PER_TENSOR_ASYMM_QUANT>
|
||||||
opBf16Cm1Qm0(mlaTilingData, tiling);
|
opBf16Cm1Qm0(mlaTilingData, tiling);
|
||||||
opBf16Cm1Qm0.Init(hiddenState, gamma1, beta1, quantScale1, quantOffset1, wdqkv, bias1, gamma2, beta2,
|
opBf16Cm1Qm0.Init(hiddenState, quantScale1, quantOffset1, wdqkv, bias1, gamma2, beta2,
|
||||||
quantScale2, quantOffset2, gamma3, sin1, cos1, sin2, cos2, keycache, slotMapping, wuq,
|
quantScale2, quantOffset2, gamma3, sin1, cos1, sin2, cos2, keycache, slotMapping, wuq,
|
||||||
bias2, wuk, descale1, descale2, ctkvScale, qnopeScale, q, keycacheOut, q2, keycacheOut2,
|
bias2, wuk, descale1, descale2, ctkvScale, qnopeScale, q, keycacheOut, q2, keycacheOut2,
|
||||||
s1, s2, s3, s4, s5);
|
s1, s2, s3, s4, s5);
|
||||||
@@ -206,7 +206,7 @@ extern "C" __global__ __aicore__ void mla_preprocess(
|
|||||||
MLAPO_BF16::MLAOperation<__bf16, 3, DataFormat::NZ, DataFormat::NZ, DataFormat::ND,
|
MLAPO_BF16::MLAOperation<__bf16, 3, DataFormat::NZ, DataFormat::NZ, DataFormat::ND,
|
||||||
QuantMode::PER_TENSOR_ASYMM_QUANT>
|
QuantMode::PER_TENSOR_ASYMM_QUANT>
|
||||||
opBf16Cm3Qm0(mlaTilingData, tiling);
|
opBf16Cm3Qm0(mlaTilingData, tiling);
|
||||||
opBf16Cm3Qm0.Init(hiddenState, gamma1, beta1, quantScale1, quantOffset1, wdqkv, bias1, gamma2, beta2,
|
opBf16Cm3Qm0.Init(hiddenState, quantScale1, quantOffset1, wdqkv, bias1, gamma2, beta2,
|
||||||
quantScale2, quantOffset2, gamma3, sin1, cos1, sin2, cos2, keycache, slotMapping, wuq,
|
quantScale2, quantOffset2, gamma3, sin1, cos1, sin2, cos2, keycache, slotMapping, wuq,
|
||||||
bias2, wuk, descale1, descale2, ctkvScale, qnopeScale, q, keycacheOut, q2, keycacheOut2,
|
bias2, wuk, descale1, descale2, ctkvScale, qnopeScale, q, keycacheOut, q2, keycacheOut2,
|
||||||
s1, s2, s3, s4, s5);
|
s1, s2, s3, s4, s5);
|
||||||
@@ -230,8 +230,6 @@ namespace vllm_ascend {
|
|||||||
extern void mla_preprocess_impl(
|
extern void mla_preprocess_impl(
|
||||||
void* stream,
|
void* stream,
|
||||||
void* hidden_state,
|
void* hidden_state,
|
||||||
void* gamma1,
|
|
||||||
void* beta1,
|
|
||||||
void* quant_scale1,
|
void* quant_scale1,
|
||||||
void* quant_offset1,
|
void* quant_offset1,
|
||||||
void* wdqkv,
|
void* wdqkv,
|
||||||
@@ -264,8 +262,6 @@ extern void mla_preprocess_impl(
|
|||||||
{
|
{
|
||||||
mla_preprocess<<<block_dim, nullptr, stream>>>(
|
mla_preprocess<<<block_dim, nullptr, stream>>>(
|
||||||
hidden_state,
|
hidden_state,
|
||||||
gamma1,
|
|
||||||
beta1,
|
|
||||||
quant_scale1,
|
quant_scale1,
|
||||||
quant_offset1,
|
quant_offset1,
|
||||||
wdqkv,
|
wdqkv,
|
||||||
|
|||||||
@@ -2388,7 +2388,7 @@ public:
|
|||||||
this->mlaParams = mlaParams_;
|
this->mlaParams = mlaParams_;
|
||||||
}
|
}
|
||||||
|
|
||||||
__aicore__ inline void Init(GM_ADDR hiddenStateGm, GM_ADDR gamma1Gm, GM_ADDR beta1Gm, GM_ADDR quantScale1Gm,
|
__aicore__ inline void Init(GM_ADDR hiddenStateGm, GM_ADDR quantScale1Gm,
|
||||||
GM_ADDR quantOffset1Gm, GM_ADDR wdqkvGm, GM_ADDR bias1Gm, GM_ADDR gamma2Gm,
|
GM_ADDR quantOffset1Gm, GM_ADDR wdqkvGm, GM_ADDR bias1Gm, GM_ADDR gamma2Gm,
|
||||||
GM_ADDR beta2Gm, GM_ADDR quantScale2Gm, GM_ADDR quantOffset2Gm, GM_ADDR gamma3Gm,
|
GM_ADDR beta2Gm, GM_ADDR quantScale2Gm, GM_ADDR quantOffset2Gm, GM_ADDR gamma3Gm,
|
||||||
GM_ADDR sin1Gm, GM_ADDR cos1Gm, GM_ADDR sin2Gm, GM_ADDR cos2Gm, GM_ADDR keycacheGm,
|
GM_ADDR sin1Gm, GM_ADDR cos1Gm, GM_ADDR sin2Gm, GM_ADDR cos2Gm, GM_ADDR keycacheGm,
|
||||||
@@ -2426,7 +2426,6 @@ public:
|
|||||||
#endif
|
#endif
|
||||||
|
|
||||||
hiddenStateGmTensor.SetGlobalBuffer(reinterpret_cast<__gm__ InDtype *>(hiddenStateGm));
|
hiddenStateGmTensor.SetGlobalBuffer(reinterpret_cast<__gm__ InDtype *>(hiddenStateGm));
|
||||||
gamma1GmTensor.SetGlobalBuffer(reinterpret_cast<__gm__ InDtype *>(gamma1Gm));
|
|
||||||
quantScale1GmTensor.SetGlobalBuffer(reinterpret_cast<__gm__ InDtype *>(quantScale1Gm));
|
quantScale1GmTensor.SetGlobalBuffer(reinterpret_cast<__gm__ InDtype *>(quantScale1Gm));
|
||||||
quantOffset1GmTensor.SetGlobalBuffer(reinterpret_cast<__gm__ int8_t *>(quantOffset1Gm));
|
quantOffset1GmTensor.SetGlobalBuffer(reinterpret_cast<__gm__ int8_t *>(quantOffset1Gm));
|
||||||
wdqkvGmTensor.SetGlobalBuffer(reinterpret_cast<__gm__ int8_t *>(wdqkvGm));
|
wdqkvGmTensor.SetGlobalBuffer(reinterpret_cast<__gm__ int8_t *>(wdqkvGm));
|
||||||
@@ -2444,7 +2443,6 @@ public:
|
|||||||
qGmTensor2.SetGlobalBuffer(reinterpret_cast<__gm__ InDtype *>(qGm2));
|
qGmTensor2.SetGlobalBuffer(reinterpret_cast<__gm__ InDtype *>(qGm2));
|
||||||
bias1gmTensor.SetGlobalBuffer(reinterpret_cast<__gm__ int32_t *>(bias1Gm));
|
bias1gmTensor.SetGlobalBuffer(reinterpret_cast<__gm__ int32_t *>(bias1Gm));
|
||||||
bias2gmTensor.SetGlobalBuffer(reinterpret_cast<__gm__ int32_t *>(bias2Gm));
|
bias2gmTensor.SetGlobalBuffer(reinterpret_cast<__gm__ int32_t *>(bias2Gm));
|
||||||
beta1GmTensor.SetGlobalBuffer(reinterpret_cast<__gm__ InDtype *>(beta1Gm));
|
|
||||||
beta2GmTensor.SetGlobalBuffer(reinterpret_cast<__gm__ InDtype *>(beta2Gm));
|
beta2GmTensor.SetGlobalBuffer(reinterpret_cast<__gm__ InDtype *>(beta2Gm));
|
||||||
|
|
||||||
#ifdef __DAV_C220_VEC__
|
#ifdef __DAV_C220_VEC__
|
||||||
@@ -2711,7 +2709,6 @@ private:
|
|||||||
|
|
||||||
AscendC::GlobalTensor<InDtype> hiddenStateGmTensor;
|
AscendC::GlobalTensor<InDtype> hiddenStateGmTensor;
|
||||||
|
|
||||||
AscendC::GlobalTensor<InDtype> gamma1GmTensor;
|
|
||||||
AscendC::GlobalTensor<InDtype> quantScale1GmTensor;
|
AscendC::GlobalTensor<InDtype> quantScale1GmTensor;
|
||||||
AscendC::GlobalTensor<int8_t> quantOffset1GmTensor;
|
AscendC::GlobalTensor<int8_t> quantOffset1GmTensor;
|
||||||
|
|
||||||
@@ -2741,7 +2738,6 @@ private:
|
|||||||
AscendC::GlobalTensor<float> s5GmTensor;
|
AscendC::GlobalTensor<float> s5GmTensor;
|
||||||
AscendC::GlobalTensor<float> descale1gmTensor;
|
AscendC::GlobalTensor<float> descale1gmTensor;
|
||||||
AscendC::GlobalTensor<float> descale2gmTensor;
|
AscendC::GlobalTensor<float> descale2gmTensor;
|
||||||
AscendC::GlobalTensor<InDtype> beta1GmTensor;
|
|
||||||
AscendC::GlobalTensor<InDtype> beta2GmTensor;
|
AscendC::GlobalTensor<InDtype> beta2GmTensor;
|
||||||
|
|
||||||
AscendC::GlobalTensor<int32_t> bias1gmTensor;
|
AscendC::GlobalTensor<int32_t> bias1gmTensor;
|
||||||
|
|||||||
@@ -294,8 +294,7 @@ class Quant
|
|||||||
public:
|
public:
|
||||||
__aicore__ inline Quant() {}
|
__aicore__ inline Quant() {}
|
||||||
|
|
||||||
__aicore__ inline void Init(AscendC::GlobalTensor<T> gammaGmTensor, AscendC::GlobalTensor<T> betaGmTensor,
|
__aicore__ inline void Init(AscendC::GlobalTensor<T> quantScaleGmTensor,
|
||||||
AscendC::GlobalTensor<T> quantScaleGmTensor,
|
|
||||||
AscendC::GlobalTensor<int8_t> quantOffsetGmTensor,
|
AscendC::GlobalTensor<int8_t> quantOffsetGmTensor,
|
||||||
AscendC::GlobalTensor<T> inputGmTensor, AscendC::GlobalTensor<int8_t> outputGmTensor,
|
AscendC::GlobalTensor<T> inputGmTensor, AscendC::GlobalTensor<int8_t> outputGmTensor,
|
||||||
uint32_t stride, uint32_t num_col, float avg_factor, uint64_t gm_offset,
|
uint32_t stride, uint32_t num_col, float avg_factor, uint64_t gm_offset,
|
||||||
@@ -2037,7 +2036,7 @@ public:
|
|||||||
this->mlaParams = mlaParams_;
|
this->mlaParams = mlaParams_;
|
||||||
}
|
}
|
||||||
|
|
||||||
__aicore__ inline void Init(GM_ADDR hiddenStateGm, GM_ADDR gamma1Gm, GM_ADDR beta1Gm, GM_ADDR quantScale1Gm,
|
__aicore__ inline void Init(GM_ADDR hiddenStateGm, GM_ADDR quantScale1Gm,
|
||||||
GM_ADDR quantOffset1Gm, GM_ADDR wdqkvGm, GM_ADDR bias1Gm, GM_ADDR gamma2Gm,
|
GM_ADDR quantOffset1Gm, GM_ADDR wdqkvGm, GM_ADDR bias1Gm, GM_ADDR gamma2Gm,
|
||||||
GM_ADDR beta2Gm, GM_ADDR quantScale2Gm, GM_ADDR quantOffset2Gm, GM_ADDR gamma3Gm,
|
GM_ADDR beta2Gm, GM_ADDR quantScale2Gm, GM_ADDR quantOffset2Gm, GM_ADDR gamma3Gm,
|
||||||
GM_ADDR sin1Gm, GM_ADDR cos1Gm, GM_ADDR sin2Gm, GM_ADDR cos2Gm, GM_ADDR keycacheGm,
|
GM_ADDR sin1Gm, GM_ADDR cos1Gm, GM_ADDR sin2Gm, GM_ADDR cos2Gm, GM_ADDR keycacheGm,
|
||||||
@@ -2057,7 +2056,6 @@ public:
|
|||||||
mm_w8a8_1.PreloadDoubleWeight();
|
mm_w8a8_1.PreloadDoubleWeight();
|
||||||
#endif
|
#endif
|
||||||
hiddenStateGmTensor.SetGlobalBuffer(reinterpret_cast<__gm__ half *>(hiddenStateGm));
|
hiddenStateGmTensor.SetGlobalBuffer(reinterpret_cast<__gm__ half *>(hiddenStateGm));
|
||||||
gamma1GmTensor.SetGlobalBuffer(reinterpret_cast<__gm__ half *>(gamma1Gm));
|
|
||||||
quantScale1GmTensor.SetGlobalBuffer(reinterpret_cast<__gm__ half *>(quantScale1Gm));
|
quantScale1GmTensor.SetGlobalBuffer(reinterpret_cast<__gm__ half *>(quantScale1Gm));
|
||||||
quantOffset1GmTensor.SetGlobalBuffer(reinterpret_cast<__gm__ int8_t *>(quantOffset1Gm));
|
quantOffset1GmTensor.SetGlobalBuffer(reinterpret_cast<__gm__ int8_t *>(quantOffset1Gm));
|
||||||
|
|
||||||
@@ -2081,7 +2079,6 @@ public:
|
|||||||
qGmTensor2.SetGlobalBuffer(reinterpret_cast<__gm__ half *>(qGm2));
|
qGmTensor2.SetGlobalBuffer(reinterpret_cast<__gm__ half *>(qGm2));
|
||||||
bias2gmTensor.SetGlobalBuffer(reinterpret_cast<__gm__ int32_t *>(bias2Gm));
|
bias2gmTensor.SetGlobalBuffer(reinterpret_cast<__gm__ int32_t *>(bias2Gm));
|
||||||
|
|
||||||
beta1GmTensor.SetGlobalBuffer(reinterpret_cast<__gm__ half *>(beta1Gm));
|
|
||||||
beta2GmTensor.SetGlobalBuffer(reinterpret_cast<__gm__ half *>(beta2Gm));
|
beta2GmTensor.SetGlobalBuffer(reinterpret_cast<__gm__ half *>(beta2Gm));
|
||||||
#ifdef __DAV_C220_CUBE__
|
#ifdef __DAV_C220_CUBE__
|
||||||
mm_w8a8_2.Init(s1GmTensor, wuqGmTensor, bias2gmTensor, descale2gmTensor, s2GmTensor, mlaParams, 1);
|
mm_w8a8_2.Init(s1GmTensor, wuqGmTensor, bias2gmTensor, descale2gmTensor, s2GmTensor, mlaParams, 1);
|
||||||
@@ -2105,7 +2102,7 @@ public:
|
|||||||
row_work_ = 0;
|
row_work_ = 0;
|
||||||
}
|
}
|
||||||
this->splitN = mlaParams.perTaskNum;
|
this->splitN = mlaParams.perTaskNum;
|
||||||
Quant1.Init(gamma1GmTensor, beta1GmTensor, quantScale1GmTensor, quantOffset1GmTensor, hiddenStateGmTensor,
|
Quant1.Init(quantScale1GmTensor, quantOffset1GmTensor, hiddenStateGmTensor,
|
||||||
s1GmTensor, 0, num_col_1, 0.0001395089285,
|
s1GmTensor, 0, num_col_1, 0.0001395089285,
|
||||||
vectorBlockIdx * static_cast<uint64_t>(row_work) * num_col_1,
|
vectorBlockIdx * static_cast<uint64_t>(row_work) * num_col_1,
|
||||||
vectorBlockIdx * static_cast<uint64_t>(row_work) * num_col_1, row_work_, mlaParams);
|
vectorBlockIdx * static_cast<uint64_t>(row_work) * num_col_1, row_work_, mlaParams);
|
||||||
@@ -2316,7 +2313,6 @@ private:
|
|||||||
|
|
||||||
AscendC::GlobalTensor<half> hiddenStateGmTensor;
|
AscendC::GlobalTensor<half> hiddenStateGmTensor;
|
||||||
|
|
||||||
AscendC::GlobalTensor<half> gamma1GmTensor;
|
|
||||||
AscendC::GlobalTensor<half> quantScale1GmTensor;
|
AscendC::GlobalTensor<half> quantScale1GmTensor;
|
||||||
AscendC::GlobalTensor<int8_t> quantOffset1GmTensor;
|
AscendC::GlobalTensor<int8_t> quantOffset1GmTensor;
|
||||||
|
|
||||||
@@ -2343,7 +2339,6 @@ private:
|
|||||||
AscendC::GlobalTensor<half> s3GmTensor;
|
AscendC::GlobalTensor<half> s3GmTensor;
|
||||||
AscendC::GlobalTensor<uint64_t> descale1gmTensor;
|
AscendC::GlobalTensor<uint64_t> descale1gmTensor;
|
||||||
AscendC::GlobalTensor<uint64_t> descale2gmTensor;
|
AscendC::GlobalTensor<uint64_t> descale2gmTensor;
|
||||||
AscendC::GlobalTensor<half> beta1GmTensor;
|
|
||||||
AscendC::GlobalTensor<half> beta2GmTensor;
|
AscendC::GlobalTensor<half> beta2GmTensor;
|
||||||
|
|
||||||
AscendC::GlobalTensor<int32_t> bias1gmTensor;
|
AscendC::GlobalTensor<int32_t> bias1gmTensor;
|
||||||
|
|||||||
@@ -128,8 +128,6 @@ namespace vllm_ascend {
|
|||||||
extern void mla_preprocess_impl(
|
extern void mla_preprocess_impl(
|
||||||
void* stream,
|
void* stream,
|
||||||
void* hidden_state,
|
void* hidden_state,
|
||||||
void* gamma1,
|
|
||||||
void* beta1,
|
|
||||||
void* quant_scale1,
|
void* quant_scale1,
|
||||||
void* quant_offset1,
|
void* quant_offset1,
|
||||||
void* wdqkv,
|
void* wdqkv,
|
||||||
|
|||||||
@@ -108,7 +108,7 @@ std::tuple<at::Tensor, at::Tensor> rotary_embedding(at::Tensor &positions, at::T
|
|||||||
}
|
}
|
||||||
|
|
||||||
std::tuple<at::Tensor &, at::Tensor &, at::Tensor &, at::Tensor &> mla_preprocess(
|
std::tuple<at::Tensor &, at::Tensor &, at::Tensor &, at::Tensor &> mla_preprocess(
|
||||||
const at::Tensor &hiddenState, const at::Tensor &gamma0, const at::Tensor &beta0, const at::Tensor &wdqkv,
|
const at::Tensor &hiddenState, const at::Tensor &wdqkv,
|
||||||
const at::Tensor &descale0, const at::Tensor &gamma1, const at::Tensor &beta1, const at::Tensor &wuq,
|
const at::Tensor &descale0, const at::Tensor &gamma1, const at::Tensor &beta1, const at::Tensor &wuq,
|
||||||
const at::Tensor &descale1, const at::Tensor &gamma2, const at::Tensor &cos, const at::Tensor &sin,
|
const at::Tensor &descale1, const at::Tensor &gamma2, const at::Tensor &cos, const at::Tensor &sin,
|
||||||
const at::Tensor &wuk, const at::Tensor &kv_cache, const at::Tensor &kv_cache_rope, const at::Tensor &slotmapping,
|
const at::Tensor &wuk, const at::Tensor &kv_cache, const at::Tensor &kv_cache_rope, const at::Tensor &slotmapping,
|
||||||
@@ -135,8 +135,6 @@ std::tuple<at::Tensor &, at::Tensor &, at::Tensor &, at::Tensor &> mla_preproces
|
|||||||
);
|
);
|
||||||
|
|
||||||
void *hidden_state_ptr = hiddenState.data_ptr();
|
void *hidden_state_ptr = hiddenState.data_ptr();
|
||||||
void *gamma0_ptr = gamma0.data_ptr();
|
|
||||||
void *beta0_ptr = beta0.data_ptr();
|
|
||||||
void *quant_scale0_ptr = quant_scale0.data_ptr();
|
void *quant_scale0_ptr = quant_scale0.data_ptr();
|
||||||
void *quant_offset0_ptr = quant_offset0.data_ptr();
|
void *quant_offset0_ptr = quant_offset0.data_ptr();
|
||||||
void *wdqkv_ptr = wdqkv.data_ptr();
|
void *wdqkv_ptr = wdqkv.data_ptr();
|
||||||
@@ -168,12 +166,12 @@ std::tuple<at::Tensor &, at::Tensor &, at::Tensor &, at::Tensor &> mla_preproces
|
|||||||
at_npu::native::OpCommand cmd;
|
at_npu::native::OpCommand cmd;
|
||||||
cmd.Name("mla_preprocess");
|
cmd.Name("mla_preprocess");
|
||||||
|
|
||||||
cmd.SetCustomHandler([stream, hidden_state_ptr, gamma0_ptr, beta0_ptr, quant_scale0_ptr, quant_offset0_ptr, wdqkv_ptr, bias0_ptr,
|
cmd.SetCustomHandler([stream, hidden_state_ptr, quant_scale0_ptr, quant_offset0_ptr, wdqkv_ptr, bias0_ptr,
|
||||||
gamma1_ptr, beta1_ptr, quant_scale1_ptr, quant_offset1_ptr, gamma2_ptr, sin_ptr, cos_ptr,
|
gamma1_ptr, beta1_ptr, quant_scale1_ptr, quant_offset1_ptr, gamma2_ptr, sin_ptr, cos_ptr,
|
||||||
kv_cache_ptr, slotmapping_ptr, wuq_ptr, bias1_ptr, wuk_ptr, descale0_ptr, descale1_ptr, ctkv_scale_ptr,
|
kv_cache_ptr, slotmapping_ptr, wuq_ptr, bias1_ptr, wuk_ptr, descale0_ptr, descale1_ptr, ctkv_scale_ptr,
|
||||||
qnope_scale_ptr, q_out0_ptr, kv_cache_out0_ptr, q_out1_ptr, kv_cache_out1_ptr, workspace_ptr,
|
qnope_scale_ptr, q_out0_ptr, kv_cache_out0_ptr, q_out1_ptr, kv_cache_out1_ptr, workspace_ptr,
|
||||||
tiling_ptr, block_dim]() -> int {
|
tiling_ptr, block_dim]() -> int {
|
||||||
mla_preprocess_impl(stream, hidden_state_ptr, gamma0_ptr, beta0_ptr, quant_scale0_ptr, quant_offset0_ptr, wdqkv_ptr, bias0_ptr,
|
mla_preprocess_impl(stream, hidden_state_ptr, quant_scale0_ptr, quant_offset0_ptr, wdqkv_ptr, bias0_ptr,
|
||||||
gamma1_ptr, beta1_ptr, quant_scale1_ptr, quant_offset1_ptr, gamma2_ptr, sin_ptr, cos_ptr, sin_ptr, cos_ptr,
|
gamma1_ptr, beta1_ptr, quant_scale1_ptr, quant_offset1_ptr, gamma2_ptr, sin_ptr, cos_ptr, sin_ptr, cos_ptr,
|
||||||
kv_cache_ptr, slotmapping_ptr, wuq_ptr, bias1_ptr, wuk_ptr, descale0_ptr, descale1_ptr, ctkv_scale_ptr,
|
kv_cache_ptr, slotmapping_ptr, wuq_ptr, bias1_ptr, wuk_ptr, descale0_ptr, descale1_ptr, ctkv_scale_ptr,
|
||||||
qnope_scale_ptr, q_out0_ptr, kv_cache_out0_ptr, q_out1_ptr, kv_cache_out1_ptr, workspace_ptr,
|
qnope_scale_ptr, q_out0_ptr, kv_cache_out0_ptr, q_out1_ptr, kv_cache_out1_ptr, workspace_ptr,
|
||||||
@@ -502,7 +500,7 @@ TORCH_LIBRARY_EXPAND(CONCAT(_C, _ascend), ops)
|
|||||||
ops.impl("sgmv_expand", torch::kPrivateUse1, &vllm_ascend::sgmv_expand);
|
ops.impl("sgmv_expand", torch::kPrivateUse1, &vllm_ascend::sgmv_expand);
|
||||||
|
|
||||||
ops.def(
|
ops.def(
|
||||||
"mla_preprocess(Tensor hiddenState, Tensor gamma0, Tensor beta0, Tensor wdqkv,"
|
"mla_preprocess(Tensor hiddenState, Tensor wdqkv,"
|
||||||
" Tensor descale0, Tensor gamma1, Tensor beta1, Tensor wuq, Tensor descale1,"
|
" Tensor descale0, Tensor gamma1, Tensor beta1, Tensor wuq, Tensor descale1,"
|
||||||
" Tensor gamma2, Tensor cos, Tensor sin, Tensor wuk, Tensor kv_cache,"
|
" Tensor gamma2, Tensor cos, Tensor sin, Tensor wuk, Tensor kv_cache,"
|
||||||
" Tensor kv_cache_rope, Tensor slotmapping, Tensor quant_scale0,"
|
" Tensor kv_cache_rope, Tensor slotmapping, Tensor quant_scale0,"
|
||||||
|
|||||||
@@ -83,8 +83,6 @@ at::Tensor sgmv_expand_meta(at::Tensor &x, at::Tensor &weight, at::Tensor &lora_
|
|||||||
|
|
||||||
std::tuple<at::Tensor &, at::Tensor &, at::Tensor &, at::Tensor &> mla_preprocess(
|
std::tuple<at::Tensor &, at::Tensor &, at::Tensor &, at::Tensor &> mla_preprocess(
|
||||||
const at::Tensor &hiddenState,
|
const at::Tensor &hiddenState,
|
||||||
const at::Tensor &gamma0,
|
|
||||||
const at::Tensor &beta0,
|
|
||||||
const at::Tensor &wdqkv,
|
const at::Tensor &wdqkv,
|
||||||
const at::Tensor &descale0,
|
const at::Tensor &descale0,
|
||||||
const at::Tensor &gamma1,
|
const at::Tensor &gamma1,
|
||||||
|
|||||||
@@ -18,8 +18,6 @@ def test_mla_preprocess_kernel():
|
|||||||
dtype = torch.bfloat16
|
dtype = torch.bfloat16
|
||||||
|
|
||||||
hidden_states = torch.randn((token_num, N_7168), dtype=dtype).npu()
|
hidden_states = torch.randn((token_num, N_7168), dtype=dtype).npu()
|
||||||
gamma0 = torch.randn((N_7168), dtype=dtype).npu()
|
|
||||||
beta0 = torch.randn((N_7168), dtype=dtype).npu()
|
|
||||||
quant_scale0 = torch.randn((1, ), dtype=dtype).npu()
|
quant_scale0 = torch.randn((1, ), dtype=dtype).npu()
|
||||||
quant_offset0 = torch.randint(0, 7, (1, ), dtype=torch.int8).npu()
|
quant_offset0 = torch.randint(0, 7, (1, ), dtype=torch.int8).npu()
|
||||||
|
|
||||||
@@ -74,8 +72,6 @@ def test_mla_preprocess_kernel():
|
|||||||
|
|
||||||
torch.ops._C_ascend.mla_preprocess(
|
torch.ops._C_ascend.mla_preprocess(
|
||||||
hidden_states,
|
hidden_states,
|
||||||
gamma0,
|
|
||||||
beta0,
|
|
||||||
wdqkv,
|
wdqkv,
|
||||||
de_scale0,
|
de_scale0,
|
||||||
gamma1,
|
gamma1,
|
||||||
|
|||||||
@@ -1,5 +1,8 @@
|
|||||||
from __future__ import annotations
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import os
|
||||||
|
from unittest.mock import patch
|
||||||
|
|
||||||
import pytest
|
import pytest
|
||||||
from vllm import SamplingParams
|
from vllm import SamplingParams
|
||||||
from vllm.config import CompilationConfig, CUDAGraphMode
|
from vllm.config import CompilationConfig, CUDAGraphMode
|
||||||
@@ -108,3 +111,19 @@ def test_mtp2_correctness_full_graph(
|
|||||||
model_name: str,
|
model_name: str,
|
||||||
):
|
):
|
||||||
mtp_correctness(sampling_config, model_name, 2, CUDAGraphMode.FULL)
|
mtp_correctness(sampling_config, model_name, 2, CUDAGraphMode.FULL)
|
||||||
|
|
||||||
|
|
||||||
|
@patch.dict(os.environ, {"VLLM_ASCEND_ENABLE_MLAPO": "1"})
|
||||||
|
def test_mtp_correctness_piecewise_graph_with_mlapo_kernel(
|
||||||
|
sampling_config: SamplingParams,
|
||||||
|
model_name: str,
|
||||||
|
):
|
||||||
|
mtp_correctness(sampling_config, model_name, 1)
|
||||||
|
|
||||||
|
|
||||||
|
@patch.dict(os.environ, {"VLLM_ASCEND_ENABLE_MLAPO": "1"})
|
||||||
|
def test_mtp_correctness_full_graph_with_mlapo_kernel(
|
||||||
|
sampling_config: SamplingParams,
|
||||||
|
model_name: str,
|
||||||
|
):
|
||||||
|
mtp_correctness(sampling_config, model_name, 1, CUDAGraphMode.FULL)
|
||||||
|
|||||||
@@ -716,17 +716,7 @@ class AscendMLAImpl(MLAAttentionImpl):
|
|||||||
self.qb_qt_bias = qb_qt_bias.reshape(
|
self.qb_qt_bias = qb_qt_bias.reshape(
|
||||||
self.num_heads * (self.qk_nope_head_dim + self.qk_rope_head_dim))
|
self.num_heads * (self.qk_nope_head_dim + self.qk_rope_head_dim))
|
||||||
|
|
||||||
device = self.q_proj.weight.device
|
device = self.q_a_proj.weight.device
|
||||||
self.gamma0 = torch.ones(
|
|
||||||
[self.fused_qkv_a_proj.weight.shape[-1]],
|
|
||||||
dtype=act_dtype,
|
|
||||||
device=device,
|
|
||||||
)
|
|
||||||
self.beta0 = torch.zeros(
|
|
||||||
[self.fused_qkv_a_proj.weight.shape[-1]],
|
|
||||||
dtype=act_dtype,
|
|
||||||
device=device,
|
|
||||||
)
|
|
||||||
self.gamma1 = self.q_a_layernorm.weight.data
|
self.gamma1 = self.q_a_layernorm.weight.data
|
||||||
self.beta1 = self.q_a_layernorm.bias.data
|
self.beta1 = self.q_a_layernorm.bias.data
|
||||||
self.gamma2 = self.kv_a_layernorm.weight.data
|
self.gamma2 = self.kv_a_layernorm.weight.data
|
||||||
@@ -1085,8 +1075,6 @@ class AscendMLAImpl(MLAAttentionImpl):
|
|||||||
|
|
||||||
torch.ops._C_ascend.mla_preprocess(
|
torch.ops._C_ascend.mla_preprocess(
|
||||||
hidden_states,
|
hidden_states,
|
||||||
self.gamma0,
|
|
||||||
self.beta0,
|
|
||||||
self.wd_qkv,
|
self.wd_qkv,
|
||||||
self.deq_scale_qkv,
|
self.deq_scale_qkv,
|
||||||
self.gamma1,
|
self.gamma1,
|
||||||
|
|||||||
Reference in New Issue
Block a user