remove redundant params in mla_preprocess kernel (#3530)

### What this PR does / why we need it?

This pull request removes the redundant parameters `gamma1` and `beta1`
(also named `gamma0`/`beta0` in some places, such as the Python interface)
from the `mla_preprocess` kernel and its calling hierarchy. The change is
applied consistently across the C++ kernel code, the bindings, and the
Python call sites. Since the lower-level functions never read these
parameters, removing them is a straightforward cleanup.

### Does this PR introduce _any_ user-facing change?

Yes. The Python interface of the kernel changes: the `gamma0` and `beta0`
parameters are removed, so callers no longer need to pass them.
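
For illustration, here is a minimal before/after sketch of a Python call site. The wrapper below is a stand-in stub, not the real binding, and the trailing arguments are abbreviated; the actual parameter order is the one shown in the diff further down.

```python
import torch

# Minimal sketch of the new Python-facing signature (a stand-in stub, not
# the real vllm-ascend binding). Parameter names follow the kernel diff
# below; shapes here are placeholders, not the kernel's real contract.
def mla_preprocess(hidden_state, quant_scale1, quant_offset1, wdqkv, *rest):
    ...

# Before this PR, callers also had to pass two tensors the kernel never read:
#   mla_preprocess(hidden_state, gamma0, beta0, quant_scale1, quant_offset1, wdqkv, ...)
# After this PR, the two arguments are simply dropped:
hidden_state = torch.empty(4, 16)
quant_scale1 = torch.empty(1)
quant_offset1 = torch.empty(1)
wdqkv = torch.empty(16, 16)
mla_preprocess(hidden_state, quant_scale1, quant_offset1, wdqkv)
```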

### How was this patch tested?

The unit test of the kernel is adapted accordingly.
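
A rough sketch of what such an adaptation looks like (illustrative only; the real test exercises the NPU kernel with its actual tensor shapes and checks numerical outputs):

```python
import torch

# Illustrative pytest-style sketch reusing the stub signature from the
# example above; the repository's real test is more involved. The point is
# only that the gamma0/beta0 fixtures and call arguments are deleted.
def test_mla_preprocess_drops_gamma0_beta0():
    hidden_state = torch.empty(4, 16)
    quant_scale1 = torch.empty(1)
    quant_offset1 = torch.empty(1)
    wdqkv = torch.empty(16, 16)
    # Removed fixtures (no longer constructed or passed):
    #   gamma0 = torch.randn(...); beta0 = torch.randn(...)
    mla_preprocess(hidden_state, quant_scale1, quant_offset1, wdqkv)
```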


- vLLM version: v0.11.0rc3
- vLLM main: https://github.com/vllm-project/vllm/commit/v0.11.0

Signed-off-by: mojave2 <chenchen145@huawei.com>
Author: Chen Chen
Date: 2025-10-21 19:20:13 +08:00
Committed by: GitHub
Parent: 80b8df881f
Commit: 6b290acfe1
9 changed files with 34 additions and 50 deletions


```diff
@@ -19,7 +19,7 @@
 #include "../op_host/tiling/mla_preprocess_tiling.h"
 extern "C" __global__ __aicore__ void mla_preprocess(
-    GM_ADDR hiddenState, GM_ADDR gamma1, GM_ADDR beta1, GM_ADDR quantScale1, GM_ADDR quantOffset1, GM_ADDR wdqkv,
+    GM_ADDR hiddenState, GM_ADDR quantScale1, GM_ADDR quantOffset1, GM_ADDR wdqkv,
     GM_ADDR bias1, GM_ADDR gamma2, GM_ADDR beta2, GM_ADDR quantScale2, GM_ADDR quantOffset2, GM_ADDR gamma3,
     GM_ADDR sin1, GM_ADDR cos1, GM_ADDR sin2, GM_ADDR cos2, GM_ADDR keycache, GM_ADDR slotMapping, GM_ADDR wuq,
     GM_ADDR bias2, GM_ADDR wuk, GM_ADDR descale1, GM_ADDR descale2, GM_ADDR ctkvScale, GM_ADDR qnopeScale, GM_ADDR q,
@@ -143,7 +143,7 @@ extern "C" __global__ __aicore__ void mla_preprocess(
         case KEY_FP16_CACHEMODE_0_QUANTMODE_0: {
             MLAPO_FP16::MLAOperation<CACHE_MODE_KVCACHE, DataFormat::NZ, DataFormat::NZ, DataFormat::ND> opFp16Cm0Qm0(
                 mlaTilingData, tiling);
-            opFp16Cm0Qm0.Init(hiddenState, gamma1, beta1, quantScale1, quantOffset1, wdqkv, bias1, gamma2, beta2,
+            opFp16Cm0Qm0.Init(hiddenState, quantScale1, quantOffset1, wdqkv, bias1, gamma2, beta2,
                 quantScale2, quantOffset2, gamma3, sin1, cos1, sin2, cos2, keycache, slotMapping, wuq,
                 bias2, wuk, descale1, descale2, ctkvScale, qnopeScale, q, keycacheOut, q2, keycacheOut2,
                 s1, s2, s3);
@@ -158,7 +158,7 @@ extern "C" __global__ __aicore__ void mla_preprocess(
         case KEY_FP16_CACHEMODE_1_QUANTMODE_0: {
             MLAPO_FP16::MLAOperation<CACHE_MODE_KROPE_CTKV, DataFormat::NZ, DataFormat::NZ, DataFormat::ND>
                 opFp16Cm1Qm0(mlaTilingData, tiling);
-            opFp16Cm1Qm0.Init(hiddenState, gamma1, beta1, quantScale1, quantOffset1, wdqkv, bias1, gamma2, beta2,
+            opFp16Cm1Qm0.Init(hiddenState, quantScale1, quantOffset1, wdqkv, bias1, gamma2, beta2,
                 quantScale2, quantOffset2, gamma3, sin1, cos1, sin2, cos2, keycache, slotMapping, wuq,
                 bias2, wuk, descale1, descale2, ctkvScale, qnopeScale, q, keycacheOut, q2, keycacheOut2,
                 s1, s2, s3);
@@ -174,7 +174,7 @@ extern "C" __global__ __aicore__ void mla_preprocess(
             MLAPO_BF16::MLAOperation<__bf16, 0, DataFormat::NZ, DataFormat::NZ, DataFormat::ND,
                 QuantMode::PER_TENSOR_ASYMM_QUANT>
                 opBf16Cm0Qm0(mlaTilingData, tiling);
-            opBf16Cm0Qm0.Init(hiddenState, gamma1, beta1, quantScale1, quantOffset1, wdqkv, bias1, gamma2, beta2,
+            opBf16Cm0Qm0.Init(hiddenState, quantScale1, quantOffset1, wdqkv, bias1, gamma2, beta2,
                 quantScale2, quantOffset2, gamma3, sin1, cos1, sin2, cos2, keycache, slotMapping, wuq,
                 bias2, wuk, descale1, descale2, ctkvScale, qnopeScale, q, keycacheOut, q2, keycacheOut2,
                 s1, s2, s3, s4, s5);
@@ -190,7 +190,7 @@ extern "C" __global__ __aicore__ void mla_preprocess(
             MLAPO_BF16::MLAOperation<__bf16, 1, DataFormat::NZ, DataFormat::NZ, DataFormat::ND,
                 QuantMode::PER_TENSOR_ASYMM_QUANT>
                 opBf16Cm1Qm0(mlaTilingData, tiling);
-            opBf16Cm1Qm0.Init(hiddenState, gamma1, beta1, quantScale1, quantOffset1, wdqkv, bias1, gamma2, beta2,
+            opBf16Cm1Qm0.Init(hiddenState, quantScale1, quantOffset1, wdqkv, bias1, gamma2, beta2,
                 quantScale2, quantOffset2, gamma3, sin1, cos1, sin2, cos2, keycache, slotMapping, wuq,
                 bias2, wuk, descale1, descale2, ctkvScale, qnopeScale, q, keycacheOut, q2, keycacheOut2,
                 s1, s2, s3, s4, s5);
@@ -206,7 +206,7 @@ extern "C" __global__ __aicore__ void mla_preprocess(
             MLAPO_BF16::MLAOperation<__bf16, 3, DataFormat::NZ, DataFormat::NZ, DataFormat::ND,
                 QuantMode::PER_TENSOR_ASYMM_QUANT>
                 opBf16Cm3Qm0(mlaTilingData, tiling);
-            opBf16Cm3Qm0.Init(hiddenState, gamma1, beta1, quantScale1, quantOffset1, wdqkv, bias1, gamma2, beta2,
+            opBf16Cm3Qm0.Init(hiddenState, quantScale1, quantOffset1, wdqkv, bias1, gamma2, beta2,
                 quantScale2, quantOffset2, gamma3, sin1, cos1, sin2, cos2, keycache, slotMapping, wuq,
                 bias2, wuk, descale1, descale2, ctkvScale, qnopeScale, q, keycacheOut, q2, keycacheOut2,
                 s1, s2, s3, s4, s5);
@@ -230,8 +230,6 @@ namespace vllm_ascend {
 extern void mla_preprocess_impl(
     void* stream,
     void* hidden_state,
-    void* gamma1,
-    void* beta1,
     void* quant_scale1,
     void* quant_offset1,
     void* wdqkv,
@@ -264,8 +262,6 @@ extern void mla_preprocess_impl(
 {
     mla_preprocess<<<block_dim, nullptr, stream>>>(
         hidden_state,
-        gamma1,
-        beta1,
         quant_scale1,
         quant_offset1,
         wdqkv,
```