remove redundant params in mla_preprocess kernel (#3530)

### What this PR does / why we need it?

This pull request removes the redundant parameters `gamma1` and `beta1`
(also named `gamma0`/`beta0` in some places) from the `mla_preprocess`
kernel and its calling hierarchy. The changes are consistent across C++
kernel code, bindings, and Python call sites. The parameters were unused
in the lower-level functions, so their removal is a good cleanup.

### Does this PR introduce _any_ user-facing change?

Yes. The Python interface of the kernel changes: the `gamma0` and `beta0`
parameters are removed, since they are no longer needed.

### How was this patch tested?

The kernel's unit test was adapted accordingly.


- vLLM version: v0.11.0rc3
- vLLM main: https://github.com/vllm-project/vllm/commit/v0.11.0

Signed-off-by: mojave2 <chenchen145@huawei.com>
This commit is contained in:
Chen Chen
2025-10-21 19:20:13 +08:00
committed by GitHub
parent 80b8df881f
commit 6b290acfe1
9 changed files with 34 additions and 50 deletions

View File

@@ -19,7 +19,7 @@
#include "../op_host/tiling/mla_preprocess_tiling.h" #include "../op_host/tiling/mla_preprocess_tiling.h"
extern "C" __global__ __aicore__ void mla_preprocess( extern "C" __global__ __aicore__ void mla_preprocess(
GM_ADDR hiddenState, GM_ADDR gamma1, GM_ADDR beta1, GM_ADDR quantScale1, GM_ADDR quantOffset1, GM_ADDR wdqkv, GM_ADDR hiddenState, GM_ADDR quantScale1, GM_ADDR quantOffset1, GM_ADDR wdqkv,
GM_ADDR bias1, GM_ADDR gamma2, GM_ADDR beta2, GM_ADDR quantScale2, GM_ADDR quantOffset2, GM_ADDR gamma3, GM_ADDR bias1, GM_ADDR gamma2, GM_ADDR beta2, GM_ADDR quantScale2, GM_ADDR quantOffset2, GM_ADDR gamma3,
GM_ADDR sin1, GM_ADDR cos1, GM_ADDR sin2, GM_ADDR cos2, GM_ADDR keycache, GM_ADDR slotMapping, GM_ADDR wuq, GM_ADDR sin1, GM_ADDR cos1, GM_ADDR sin2, GM_ADDR cos2, GM_ADDR keycache, GM_ADDR slotMapping, GM_ADDR wuq,
GM_ADDR bias2, GM_ADDR wuk, GM_ADDR descale1, GM_ADDR descale2, GM_ADDR ctkvScale, GM_ADDR qnopeScale, GM_ADDR q, GM_ADDR bias2, GM_ADDR wuk, GM_ADDR descale1, GM_ADDR descale2, GM_ADDR ctkvScale, GM_ADDR qnopeScale, GM_ADDR q,
@@ -143,7 +143,7 @@ extern "C" __global__ __aicore__ void mla_preprocess(
case KEY_FP16_CACHEMODE_0_QUANTMODE_0: { case KEY_FP16_CACHEMODE_0_QUANTMODE_0: {
MLAPO_FP16::MLAOperation<CACHE_MODE_KVCACHE, DataFormat::NZ, DataFormat::NZ, DataFormat::ND> opFp16Cm0Qm0( MLAPO_FP16::MLAOperation<CACHE_MODE_KVCACHE, DataFormat::NZ, DataFormat::NZ, DataFormat::ND> opFp16Cm0Qm0(
mlaTilingData, tiling); mlaTilingData, tiling);
opFp16Cm0Qm0.Init(hiddenState, gamma1, beta1, quantScale1, quantOffset1, wdqkv, bias1, gamma2, beta2, opFp16Cm0Qm0.Init(hiddenState, quantScale1, quantOffset1, wdqkv, bias1, gamma2, beta2,
quantScale2, quantOffset2, gamma3, sin1, cos1, sin2, cos2, keycache, slotMapping, wuq, quantScale2, quantOffset2, gamma3, sin1, cos1, sin2, cos2, keycache, slotMapping, wuq,
bias2, wuk, descale1, descale2, ctkvScale, qnopeScale, q, keycacheOut, q2, keycacheOut2, bias2, wuk, descale1, descale2, ctkvScale, qnopeScale, q, keycacheOut, q2, keycacheOut2,
s1, s2, s3); s1, s2, s3);
@@ -158,7 +158,7 @@ extern "C" __global__ __aicore__ void mla_preprocess(
case KEY_FP16_CACHEMODE_1_QUANTMODE_0: { case KEY_FP16_CACHEMODE_1_QUANTMODE_0: {
MLAPO_FP16::MLAOperation<CACHE_MODE_KROPE_CTKV, DataFormat::NZ, DataFormat::NZ, DataFormat::ND> MLAPO_FP16::MLAOperation<CACHE_MODE_KROPE_CTKV, DataFormat::NZ, DataFormat::NZ, DataFormat::ND>
opFp16Cm1Qm0(mlaTilingData, tiling); opFp16Cm1Qm0(mlaTilingData, tiling);
opFp16Cm1Qm0.Init(hiddenState, gamma1, beta1, quantScale1, quantOffset1, wdqkv, bias1, gamma2, beta2, opFp16Cm1Qm0.Init(hiddenState, quantScale1, quantOffset1, wdqkv, bias1, gamma2, beta2,
quantScale2, quantOffset2, gamma3, sin1, cos1, sin2, cos2, keycache, slotMapping, wuq, quantScale2, quantOffset2, gamma3, sin1, cos1, sin2, cos2, keycache, slotMapping, wuq,
bias2, wuk, descale1, descale2, ctkvScale, qnopeScale, q, keycacheOut, q2, keycacheOut2, bias2, wuk, descale1, descale2, ctkvScale, qnopeScale, q, keycacheOut, q2, keycacheOut2,
s1, s2, s3); s1, s2, s3);
@@ -174,7 +174,7 @@ extern "C" __global__ __aicore__ void mla_preprocess(
MLAPO_BF16::MLAOperation<__bf16, 0, DataFormat::NZ, DataFormat::NZ, DataFormat::ND, MLAPO_BF16::MLAOperation<__bf16, 0, DataFormat::NZ, DataFormat::NZ, DataFormat::ND,
QuantMode::PER_TENSOR_ASYMM_QUANT> QuantMode::PER_TENSOR_ASYMM_QUANT>
opBf16Cm0Qm0(mlaTilingData, tiling); opBf16Cm0Qm0(mlaTilingData, tiling);
opBf16Cm0Qm0.Init(hiddenState, gamma1, beta1, quantScale1, quantOffset1, wdqkv, bias1, gamma2, beta2, opBf16Cm0Qm0.Init(hiddenState, quantScale1, quantOffset1, wdqkv, bias1, gamma2, beta2,
quantScale2, quantOffset2, gamma3, sin1, cos1, sin2, cos2, keycache, slotMapping, wuq, quantScale2, quantOffset2, gamma3, sin1, cos1, sin2, cos2, keycache, slotMapping, wuq,
bias2, wuk, descale1, descale2, ctkvScale, qnopeScale, q, keycacheOut, q2, keycacheOut2, bias2, wuk, descale1, descale2, ctkvScale, qnopeScale, q, keycacheOut, q2, keycacheOut2,
s1, s2, s3, s4, s5); s1, s2, s3, s4, s5);
@@ -190,7 +190,7 @@ extern "C" __global__ __aicore__ void mla_preprocess(
MLAPO_BF16::MLAOperation<__bf16, 1, DataFormat::NZ, DataFormat::NZ, DataFormat::ND, MLAPO_BF16::MLAOperation<__bf16, 1, DataFormat::NZ, DataFormat::NZ, DataFormat::ND,
QuantMode::PER_TENSOR_ASYMM_QUANT> QuantMode::PER_TENSOR_ASYMM_QUANT>
opBf16Cm1Qm0(mlaTilingData, tiling); opBf16Cm1Qm0(mlaTilingData, tiling);
opBf16Cm1Qm0.Init(hiddenState, gamma1, beta1, quantScale1, quantOffset1, wdqkv, bias1, gamma2, beta2, opBf16Cm1Qm0.Init(hiddenState, quantScale1, quantOffset1, wdqkv, bias1, gamma2, beta2,
quantScale2, quantOffset2, gamma3, sin1, cos1, sin2, cos2, keycache, slotMapping, wuq, quantScale2, quantOffset2, gamma3, sin1, cos1, sin2, cos2, keycache, slotMapping, wuq,
bias2, wuk, descale1, descale2, ctkvScale, qnopeScale, q, keycacheOut, q2, keycacheOut2, bias2, wuk, descale1, descale2, ctkvScale, qnopeScale, q, keycacheOut, q2, keycacheOut2,
s1, s2, s3, s4, s5); s1, s2, s3, s4, s5);
@@ -206,7 +206,7 @@ extern "C" __global__ __aicore__ void mla_preprocess(
MLAPO_BF16::MLAOperation<__bf16, 3, DataFormat::NZ, DataFormat::NZ, DataFormat::ND, MLAPO_BF16::MLAOperation<__bf16, 3, DataFormat::NZ, DataFormat::NZ, DataFormat::ND,
QuantMode::PER_TENSOR_ASYMM_QUANT> QuantMode::PER_TENSOR_ASYMM_QUANT>
opBf16Cm3Qm0(mlaTilingData, tiling); opBf16Cm3Qm0(mlaTilingData, tiling);
opBf16Cm3Qm0.Init(hiddenState, gamma1, beta1, quantScale1, quantOffset1, wdqkv, bias1, gamma2, beta2, opBf16Cm3Qm0.Init(hiddenState, quantScale1, quantOffset1, wdqkv, bias1, gamma2, beta2,
quantScale2, quantOffset2, gamma3, sin1, cos1, sin2, cos2, keycache, slotMapping, wuq, quantScale2, quantOffset2, gamma3, sin1, cos1, sin2, cos2, keycache, slotMapping, wuq,
bias2, wuk, descale1, descale2, ctkvScale, qnopeScale, q, keycacheOut, q2, keycacheOut2, bias2, wuk, descale1, descale2, ctkvScale, qnopeScale, q, keycacheOut, q2, keycacheOut2,
s1, s2, s3, s4, s5); s1, s2, s3, s4, s5);
@@ -230,8 +230,6 @@ namespace vllm_ascend {
extern void mla_preprocess_impl( extern void mla_preprocess_impl(
void* stream, void* stream,
void* hidden_state, void* hidden_state,
void* gamma1,
void* beta1,
void* quant_scale1, void* quant_scale1,
void* quant_offset1, void* quant_offset1,
void* wdqkv, void* wdqkv,
@@ -264,8 +262,6 @@ extern void mla_preprocess_impl(
{ {
mla_preprocess<<<block_dim, nullptr, stream>>>( mla_preprocess<<<block_dim, nullptr, stream>>>(
hidden_state, hidden_state,
gamma1,
beta1,
quant_scale1, quant_scale1,
quant_offset1, quant_offset1,
wdqkv, wdqkv,

View File

@@ -2388,7 +2388,7 @@ public:
this->mlaParams = mlaParams_; this->mlaParams = mlaParams_;
} }
__aicore__ inline void Init(GM_ADDR hiddenStateGm, GM_ADDR gamma1Gm, GM_ADDR beta1Gm, GM_ADDR quantScale1Gm, __aicore__ inline void Init(GM_ADDR hiddenStateGm, GM_ADDR quantScale1Gm,
GM_ADDR quantOffset1Gm, GM_ADDR wdqkvGm, GM_ADDR bias1Gm, GM_ADDR gamma2Gm, GM_ADDR quantOffset1Gm, GM_ADDR wdqkvGm, GM_ADDR bias1Gm, GM_ADDR gamma2Gm,
GM_ADDR beta2Gm, GM_ADDR quantScale2Gm, GM_ADDR quantOffset2Gm, GM_ADDR gamma3Gm, GM_ADDR beta2Gm, GM_ADDR quantScale2Gm, GM_ADDR quantOffset2Gm, GM_ADDR gamma3Gm,
GM_ADDR sin1Gm, GM_ADDR cos1Gm, GM_ADDR sin2Gm, GM_ADDR cos2Gm, GM_ADDR keycacheGm, GM_ADDR sin1Gm, GM_ADDR cos1Gm, GM_ADDR sin2Gm, GM_ADDR cos2Gm, GM_ADDR keycacheGm,
@@ -2426,7 +2426,6 @@ public:
#endif #endif
hiddenStateGmTensor.SetGlobalBuffer(reinterpret_cast<__gm__ InDtype *>(hiddenStateGm)); hiddenStateGmTensor.SetGlobalBuffer(reinterpret_cast<__gm__ InDtype *>(hiddenStateGm));
gamma1GmTensor.SetGlobalBuffer(reinterpret_cast<__gm__ InDtype *>(gamma1Gm));
quantScale1GmTensor.SetGlobalBuffer(reinterpret_cast<__gm__ InDtype *>(quantScale1Gm)); quantScale1GmTensor.SetGlobalBuffer(reinterpret_cast<__gm__ InDtype *>(quantScale1Gm));
quantOffset1GmTensor.SetGlobalBuffer(reinterpret_cast<__gm__ int8_t *>(quantOffset1Gm)); quantOffset1GmTensor.SetGlobalBuffer(reinterpret_cast<__gm__ int8_t *>(quantOffset1Gm));
wdqkvGmTensor.SetGlobalBuffer(reinterpret_cast<__gm__ int8_t *>(wdqkvGm)); wdqkvGmTensor.SetGlobalBuffer(reinterpret_cast<__gm__ int8_t *>(wdqkvGm));
@@ -2444,7 +2443,6 @@ public:
qGmTensor2.SetGlobalBuffer(reinterpret_cast<__gm__ InDtype *>(qGm2)); qGmTensor2.SetGlobalBuffer(reinterpret_cast<__gm__ InDtype *>(qGm2));
bias1gmTensor.SetGlobalBuffer(reinterpret_cast<__gm__ int32_t *>(bias1Gm)); bias1gmTensor.SetGlobalBuffer(reinterpret_cast<__gm__ int32_t *>(bias1Gm));
bias2gmTensor.SetGlobalBuffer(reinterpret_cast<__gm__ int32_t *>(bias2Gm)); bias2gmTensor.SetGlobalBuffer(reinterpret_cast<__gm__ int32_t *>(bias2Gm));
beta1GmTensor.SetGlobalBuffer(reinterpret_cast<__gm__ InDtype *>(beta1Gm));
beta2GmTensor.SetGlobalBuffer(reinterpret_cast<__gm__ InDtype *>(beta2Gm)); beta2GmTensor.SetGlobalBuffer(reinterpret_cast<__gm__ InDtype *>(beta2Gm));
#ifdef __DAV_C220_VEC__ #ifdef __DAV_C220_VEC__
@@ -2711,7 +2709,6 @@ private:
AscendC::GlobalTensor<InDtype> hiddenStateGmTensor; AscendC::GlobalTensor<InDtype> hiddenStateGmTensor;
AscendC::GlobalTensor<InDtype> gamma1GmTensor;
AscendC::GlobalTensor<InDtype> quantScale1GmTensor; AscendC::GlobalTensor<InDtype> quantScale1GmTensor;
AscendC::GlobalTensor<int8_t> quantOffset1GmTensor; AscendC::GlobalTensor<int8_t> quantOffset1GmTensor;
@@ -2741,7 +2738,6 @@ private:
AscendC::GlobalTensor<float> s5GmTensor; AscendC::GlobalTensor<float> s5GmTensor;
AscendC::GlobalTensor<float> descale1gmTensor; AscendC::GlobalTensor<float> descale1gmTensor;
AscendC::GlobalTensor<float> descale2gmTensor; AscendC::GlobalTensor<float> descale2gmTensor;
AscendC::GlobalTensor<InDtype> beta1GmTensor;
AscendC::GlobalTensor<InDtype> beta2GmTensor; AscendC::GlobalTensor<InDtype> beta2GmTensor;
AscendC::GlobalTensor<int32_t> bias1gmTensor; AscendC::GlobalTensor<int32_t> bias1gmTensor;

View File

@@ -294,8 +294,7 @@ class Quant
public: public:
__aicore__ inline Quant() {} __aicore__ inline Quant() {}
__aicore__ inline void Init(AscendC::GlobalTensor<T> gammaGmTensor, AscendC::GlobalTensor<T> betaGmTensor, __aicore__ inline void Init(AscendC::GlobalTensor<T> quantScaleGmTensor,
AscendC::GlobalTensor<T> quantScaleGmTensor,
AscendC::GlobalTensor<int8_t> quantOffsetGmTensor, AscendC::GlobalTensor<int8_t> quantOffsetGmTensor,
AscendC::GlobalTensor<T> inputGmTensor, AscendC::GlobalTensor<int8_t> outputGmTensor, AscendC::GlobalTensor<T> inputGmTensor, AscendC::GlobalTensor<int8_t> outputGmTensor,
uint32_t stride, uint32_t num_col, float avg_factor, uint64_t gm_offset, uint32_t stride, uint32_t num_col, float avg_factor, uint64_t gm_offset,
@@ -2037,7 +2036,7 @@ public:
this->mlaParams = mlaParams_; this->mlaParams = mlaParams_;
} }
__aicore__ inline void Init(GM_ADDR hiddenStateGm, GM_ADDR gamma1Gm, GM_ADDR beta1Gm, GM_ADDR quantScale1Gm, __aicore__ inline void Init(GM_ADDR hiddenStateGm, GM_ADDR quantScale1Gm,
GM_ADDR quantOffset1Gm, GM_ADDR wdqkvGm, GM_ADDR bias1Gm, GM_ADDR gamma2Gm, GM_ADDR quantOffset1Gm, GM_ADDR wdqkvGm, GM_ADDR bias1Gm, GM_ADDR gamma2Gm,
GM_ADDR beta2Gm, GM_ADDR quantScale2Gm, GM_ADDR quantOffset2Gm, GM_ADDR gamma3Gm, GM_ADDR beta2Gm, GM_ADDR quantScale2Gm, GM_ADDR quantOffset2Gm, GM_ADDR gamma3Gm,
GM_ADDR sin1Gm, GM_ADDR cos1Gm, GM_ADDR sin2Gm, GM_ADDR cos2Gm, GM_ADDR keycacheGm, GM_ADDR sin1Gm, GM_ADDR cos1Gm, GM_ADDR sin2Gm, GM_ADDR cos2Gm, GM_ADDR keycacheGm,
@@ -2057,7 +2056,6 @@ public:
mm_w8a8_1.PreloadDoubleWeight(); mm_w8a8_1.PreloadDoubleWeight();
#endif #endif
hiddenStateGmTensor.SetGlobalBuffer(reinterpret_cast<__gm__ half *>(hiddenStateGm)); hiddenStateGmTensor.SetGlobalBuffer(reinterpret_cast<__gm__ half *>(hiddenStateGm));
gamma1GmTensor.SetGlobalBuffer(reinterpret_cast<__gm__ half *>(gamma1Gm));
quantScale1GmTensor.SetGlobalBuffer(reinterpret_cast<__gm__ half *>(quantScale1Gm)); quantScale1GmTensor.SetGlobalBuffer(reinterpret_cast<__gm__ half *>(quantScale1Gm));
quantOffset1GmTensor.SetGlobalBuffer(reinterpret_cast<__gm__ int8_t *>(quantOffset1Gm)); quantOffset1GmTensor.SetGlobalBuffer(reinterpret_cast<__gm__ int8_t *>(quantOffset1Gm));
@@ -2081,7 +2079,6 @@ public:
qGmTensor2.SetGlobalBuffer(reinterpret_cast<__gm__ half *>(qGm2)); qGmTensor2.SetGlobalBuffer(reinterpret_cast<__gm__ half *>(qGm2));
bias2gmTensor.SetGlobalBuffer(reinterpret_cast<__gm__ int32_t *>(bias2Gm)); bias2gmTensor.SetGlobalBuffer(reinterpret_cast<__gm__ int32_t *>(bias2Gm));
beta1GmTensor.SetGlobalBuffer(reinterpret_cast<__gm__ half *>(beta1Gm));
beta2GmTensor.SetGlobalBuffer(reinterpret_cast<__gm__ half *>(beta2Gm)); beta2GmTensor.SetGlobalBuffer(reinterpret_cast<__gm__ half *>(beta2Gm));
#ifdef __DAV_C220_CUBE__ #ifdef __DAV_C220_CUBE__
mm_w8a8_2.Init(s1GmTensor, wuqGmTensor, bias2gmTensor, descale2gmTensor, s2GmTensor, mlaParams, 1); mm_w8a8_2.Init(s1GmTensor, wuqGmTensor, bias2gmTensor, descale2gmTensor, s2GmTensor, mlaParams, 1);
@@ -2105,7 +2102,7 @@ public:
row_work_ = 0; row_work_ = 0;
} }
this->splitN = mlaParams.perTaskNum; this->splitN = mlaParams.perTaskNum;
Quant1.Init(gamma1GmTensor, beta1GmTensor, quantScale1GmTensor, quantOffset1GmTensor, hiddenStateGmTensor, Quant1.Init(quantScale1GmTensor, quantOffset1GmTensor, hiddenStateGmTensor,
s1GmTensor, 0, num_col_1, 0.0001395089285, s1GmTensor, 0, num_col_1, 0.0001395089285,
vectorBlockIdx * static_cast<uint64_t>(row_work) * num_col_1, vectorBlockIdx * static_cast<uint64_t>(row_work) * num_col_1,
vectorBlockIdx * static_cast<uint64_t>(row_work) * num_col_1, row_work_, mlaParams); vectorBlockIdx * static_cast<uint64_t>(row_work) * num_col_1, row_work_, mlaParams);
@@ -2316,7 +2313,6 @@ private:
AscendC::GlobalTensor<half> hiddenStateGmTensor; AscendC::GlobalTensor<half> hiddenStateGmTensor;
AscendC::GlobalTensor<half> gamma1GmTensor;
AscendC::GlobalTensor<half> quantScale1GmTensor; AscendC::GlobalTensor<half> quantScale1GmTensor;
AscendC::GlobalTensor<int8_t> quantOffset1GmTensor; AscendC::GlobalTensor<int8_t> quantOffset1GmTensor;
@@ -2343,7 +2339,6 @@ private:
AscendC::GlobalTensor<half> s3GmTensor; AscendC::GlobalTensor<half> s3GmTensor;
AscendC::GlobalTensor<uint64_t> descale1gmTensor; AscendC::GlobalTensor<uint64_t> descale1gmTensor;
AscendC::GlobalTensor<uint64_t> descale2gmTensor; AscendC::GlobalTensor<uint64_t> descale2gmTensor;
AscendC::GlobalTensor<half> beta1GmTensor;
AscendC::GlobalTensor<half> beta2GmTensor; AscendC::GlobalTensor<half> beta2GmTensor;
AscendC::GlobalTensor<int32_t> bias1gmTensor; AscendC::GlobalTensor<int32_t> bias1gmTensor;

View File

@@ -128,8 +128,6 @@ namespace vllm_ascend {
extern void mla_preprocess_impl( extern void mla_preprocess_impl(
void* stream, void* stream,
void* hidden_state, void* hidden_state,
void* gamma1,
void* beta1,
void* quant_scale1, void* quant_scale1,
void* quant_offset1, void* quant_offset1,
void* wdqkv, void* wdqkv,

View File

@@ -108,7 +108,7 @@ std::tuple<at::Tensor, at::Tensor> rotary_embedding(at::Tensor &positions, at::T
} }
std::tuple<at::Tensor &, at::Tensor &, at::Tensor &, at::Tensor &> mla_preprocess( std::tuple<at::Tensor &, at::Tensor &, at::Tensor &, at::Tensor &> mla_preprocess(
const at::Tensor &hiddenState, const at::Tensor &gamma0, const at::Tensor &beta0, const at::Tensor &wdqkv, const at::Tensor &hiddenState, const at::Tensor &wdqkv,
const at::Tensor &descale0, const at::Tensor &gamma1, const at::Tensor &beta1, const at::Tensor &wuq, const at::Tensor &descale0, const at::Tensor &gamma1, const at::Tensor &beta1, const at::Tensor &wuq,
const at::Tensor &descale1, const at::Tensor &gamma2, const at::Tensor &cos, const at::Tensor &sin, const at::Tensor &descale1, const at::Tensor &gamma2, const at::Tensor &cos, const at::Tensor &sin,
const at::Tensor &wuk, const at::Tensor &kv_cache, const at::Tensor &kv_cache_rope, const at::Tensor &slotmapping, const at::Tensor &wuk, const at::Tensor &kv_cache, const at::Tensor &kv_cache_rope, const at::Tensor &slotmapping,
@@ -135,8 +135,6 @@ std::tuple<at::Tensor &, at::Tensor &, at::Tensor &, at::Tensor &> mla_preproces
); );
void *hidden_state_ptr = hiddenState.data_ptr(); void *hidden_state_ptr = hiddenState.data_ptr();
void *gamma0_ptr = gamma0.data_ptr();
void *beta0_ptr = beta0.data_ptr();
void *quant_scale0_ptr = quant_scale0.data_ptr(); void *quant_scale0_ptr = quant_scale0.data_ptr();
void *quant_offset0_ptr = quant_offset0.data_ptr(); void *quant_offset0_ptr = quant_offset0.data_ptr();
void *wdqkv_ptr = wdqkv.data_ptr(); void *wdqkv_ptr = wdqkv.data_ptr();
@@ -168,12 +166,12 @@ std::tuple<at::Tensor &, at::Tensor &, at::Tensor &, at::Tensor &> mla_preproces
at_npu::native::OpCommand cmd; at_npu::native::OpCommand cmd;
cmd.Name("mla_preprocess"); cmd.Name("mla_preprocess");
cmd.SetCustomHandler([stream, hidden_state_ptr, gamma0_ptr, beta0_ptr, quant_scale0_ptr, quant_offset0_ptr, wdqkv_ptr, bias0_ptr, cmd.SetCustomHandler([stream, hidden_state_ptr, quant_scale0_ptr, quant_offset0_ptr, wdqkv_ptr, bias0_ptr,
gamma1_ptr, beta1_ptr, quant_scale1_ptr, quant_offset1_ptr, gamma2_ptr, sin_ptr, cos_ptr, gamma1_ptr, beta1_ptr, quant_scale1_ptr, quant_offset1_ptr, gamma2_ptr, sin_ptr, cos_ptr,
kv_cache_ptr, slotmapping_ptr, wuq_ptr, bias1_ptr, wuk_ptr, descale0_ptr, descale1_ptr, ctkv_scale_ptr, kv_cache_ptr, slotmapping_ptr, wuq_ptr, bias1_ptr, wuk_ptr, descale0_ptr, descale1_ptr, ctkv_scale_ptr,
qnope_scale_ptr, q_out0_ptr, kv_cache_out0_ptr, q_out1_ptr, kv_cache_out1_ptr, workspace_ptr, qnope_scale_ptr, q_out0_ptr, kv_cache_out0_ptr, q_out1_ptr, kv_cache_out1_ptr, workspace_ptr,
tiling_ptr, block_dim]() -> int { tiling_ptr, block_dim]() -> int {
mla_preprocess_impl(stream, hidden_state_ptr, gamma0_ptr, beta0_ptr, quant_scale0_ptr, quant_offset0_ptr, wdqkv_ptr, bias0_ptr, mla_preprocess_impl(stream, hidden_state_ptr, quant_scale0_ptr, quant_offset0_ptr, wdqkv_ptr, bias0_ptr,
gamma1_ptr, beta1_ptr, quant_scale1_ptr, quant_offset1_ptr, gamma2_ptr, sin_ptr, cos_ptr, sin_ptr, cos_ptr, gamma1_ptr, beta1_ptr, quant_scale1_ptr, quant_offset1_ptr, gamma2_ptr, sin_ptr, cos_ptr, sin_ptr, cos_ptr,
kv_cache_ptr, slotmapping_ptr, wuq_ptr, bias1_ptr, wuk_ptr, descale0_ptr, descale1_ptr, ctkv_scale_ptr, kv_cache_ptr, slotmapping_ptr, wuq_ptr, bias1_ptr, wuk_ptr, descale0_ptr, descale1_ptr, ctkv_scale_ptr,
qnope_scale_ptr, q_out0_ptr, kv_cache_out0_ptr, q_out1_ptr, kv_cache_out1_ptr, workspace_ptr, qnope_scale_ptr, q_out0_ptr, kv_cache_out0_ptr, q_out1_ptr, kv_cache_out1_ptr, workspace_ptr,
@@ -502,7 +500,7 @@ TORCH_LIBRARY_EXPAND(CONCAT(_C, _ascend), ops)
ops.impl("sgmv_expand", torch::kPrivateUse1, &vllm_ascend::sgmv_expand); ops.impl("sgmv_expand", torch::kPrivateUse1, &vllm_ascend::sgmv_expand);
ops.def( ops.def(
"mla_preprocess(Tensor hiddenState, Tensor gamma0, Tensor beta0, Tensor wdqkv," "mla_preprocess(Tensor hiddenState, Tensor wdqkv,"
" Tensor descale0, Tensor gamma1, Tensor beta1, Tensor wuq, Tensor descale1," " Tensor descale0, Tensor gamma1, Tensor beta1, Tensor wuq, Tensor descale1,"
" Tensor gamma2, Tensor cos, Tensor sin, Tensor wuk, Tensor kv_cache," " Tensor gamma2, Tensor cos, Tensor sin, Tensor wuk, Tensor kv_cache,"
" Tensor kv_cache_rope, Tensor slotmapping, Tensor quant_scale0," " Tensor kv_cache_rope, Tensor slotmapping, Tensor quant_scale0,"

View File

@@ -83,8 +83,6 @@ at::Tensor sgmv_expand_meta(at::Tensor &x, at::Tensor &weight, at::Tensor &lora_
std::tuple<at::Tensor &, at::Tensor &, at::Tensor &, at::Tensor &> mla_preprocess( std::tuple<at::Tensor &, at::Tensor &, at::Tensor &, at::Tensor &> mla_preprocess(
const at::Tensor &hiddenState, const at::Tensor &hiddenState,
const at::Tensor &gamma0,
const at::Tensor &beta0,
const at::Tensor &wdqkv, const at::Tensor &wdqkv,
const at::Tensor &descale0, const at::Tensor &descale0,
const at::Tensor &gamma1, const at::Tensor &gamma1,

View File

@@ -18,8 +18,6 @@ def test_mla_preprocess_kernel():
dtype = torch.bfloat16 dtype = torch.bfloat16
hidden_states = torch.randn((token_num, N_7168), dtype=dtype).npu() hidden_states = torch.randn((token_num, N_7168), dtype=dtype).npu()
gamma0 = torch.randn((N_7168), dtype=dtype).npu()
beta0 = torch.randn((N_7168), dtype=dtype).npu()
quant_scale0 = torch.randn((1, ), dtype=dtype).npu() quant_scale0 = torch.randn((1, ), dtype=dtype).npu()
quant_offset0 = torch.randint(0, 7, (1, ), dtype=torch.int8).npu() quant_offset0 = torch.randint(0, 7, (1, ), dtype=torch.int8).npu()
@@ -74,8 +72,6 @@ def test_mla_preprocess_kernel():
torch.ops._C_ascend.mla_preprocess( torch.ops._C_ascend.mla_preprocess(
hidden_states, hidden_states,
gamma0,
beta0,
wdqkv, wdqkv,
de_scale0, de_scale0,
gamma1, gamma1,

View File

@@ -1,5 +1,8 @@
from __future__ import annotations from __future__ import annotations
import os
from unittest.mock import patch
import pytest import pytest
from vllm import SamplingParams from vllm import SamplingParams
from vllm.config import CompilationConfig, CUDAGraphMode from vllm.config import CompilationConfig, CUDAGraphMode
@@ -108,3 +111,19 @@ def test_mtp2_correctness_full_graph(
model_name: str, model_name: str,
): ):
mtp_correctness(sampling_config, model_name, 2, CUDAGraphMode.FULL) mtp_correctness(sampling_config, model_name, 2, CUDAGraphMode.FULL)
@patch.dict(os.environ, {"VLLM_ASCEND_ENABLE_MLAPO": "1"})
def test_mtp_correctness_piecewise_graph_with_mlapo_kernel(
sampling_config: SamplingParams,
model_name: str,
):
mtp_correctness(sampling_config, model_name, 1)
@patch.dict(os.environ, {"VLLM_ASCEND_ENABLE_MLAPO": "1"})
def test_mtp_correctness_full_graph_with_mlapo_kernel(
sampling_config: SamplingParams,
model_name: str,
):
mtp_correctness(sampling_config, model_name, 1, CUDAGraphMode.FULL)

View File

@@ -716,17 +716,7 @@ class AscendMLAImpl(MLAAttentionImpl):
self.qb_qt_bias = qb_qt_bias.reshape( self.qb_qt_bias = qb_qt_bias.reshape(
self.num_heads * (self.qk_nope_head_dim + self.qk_rope_head_dim)) self.num_heads * (self.qk_nope_head_dim + self.qk_rope_head_dim))
device = self.q_proj.weight.device device = self.q_a_proj.weight.device
self.gamma0 = torch.ones(
[self.fused_qkv_a_proj.weight.shape[-1]],
dtype=act_dtype,
device=device,
)
self.beta0 = torch.zeros(
[self.fused_qkv_a_proj.weight.shape[-1]],
dtype=act_dtype,
device=device,
)
self.gamma1 = self.q_a_layernorm.weight.data self.gamma1 = self.q_a_layernorm.weight.data
self.beta1 = self.q_a_layernorm.bias.data self.beta1 = self.q_a_layernorm.bias.data
self.gamma2 = self.kv_a_layernorm.weight.data self.gamma2 = self.kv_a_layernorm.weight.data
@@ -1085,8 +1075,6 @@ class AscendMLAImpl(MLAAttentionImpl):
torch.ops._C_ascend.mla_preprocess( torch.ops._C_ascend.mla_preprocess(
hidden_states, hidden_states,
self.gamma0,
self.beta0,
self.wd_qkv, self.wd_qkv,
self.deq_scale_qkv, self.deq_scale_qkv,
self.gamma1, self.gamma1,