remove redundant params in mla_preprocess kernel (#3530)

### What this PR does / why we need it?

This pull request removes the redundant parameters `gamma1` and `beta1`
(also named `gamma0`/`beta0` in some places) from the `mla_preprocess`
kernel and its calling hierarchy. The changes are consistent across C++
kernel code, bindings, and Python call sites. The parameters were unused
in the lower-level functions, so their removal is a good cleanup.

### Does this PR introduce _any_ user-facing change?

Yes. The Python interface of the kernel changes: the `gamma0` and `beta0`
parameters are removed and must no longer be passed by callers.

### How was this patch tested?

The kernel's unit test is updated to match the new signature.


- vLLM version: v0.11.0rc3
- vLLM main: https://github.com/vllm-project/vllm/commit/v0.11.0

Signed-off-by: mojave2 <chenchen145@huawei.com>
This commit is contained in:
Chen Chen
2025-10-21 19:20:13 +08:00
committed by GitHub
parent 80b8df881f
commit 6b290acfe1
9 changed files with 34 additions and 50 deletions

View File

@@ -18,8 +18,6 @@ def test_mla_preprocess_kernel():
dtype = torch.bfloat16
hidden_states = torch.randn((token_num, N_7168), dtype=dtype).npu()
gamma0 = torch.randn((N_7168), dtype=dtype).npu()
beta0 = torch.randn((N_7168), dtype=dtype).npu()
quant_scale0 = torch.randn((1, ), dtype=dtype).npu()
quant_offset0 = torch.randint(0, 7, (1, ), dtype=torch.int8).npu()
@@ -74,8 +72,6 @@ def test_mla_preprocess_kernel():
torch.ops._C_ascend.mla_preprocess(
hidden_states,
gamma0,
beta0,
wdqkv,
de_scale0,
gamma1,

View File

@@ -1,5 +1,8 @@
from __future__ import annotations
import os
from unittest.mock import patch
import pytest
from vllm import SamplingParams
from vllm.config import CompilationConfig, CUDAGraphMode
@@ -108,3 +111,19 @@ def test_mtp2_correctness_full_graph(
model_name: str,
):
mtp_correctness(sampling_config, model_name, 2, CUDAGraphMode.FULL)
# Runs the shared mtp_correctness check with VLLM_ASCEND_ENABLE_MLAPO="1"
# patched into os.environ for the duration of this test only.
# No CUDAGraphMode is passed here, so the helper's default mode applies;
# presumably that is the piecewise mode the test name refers to — confirm
# against mtp_correctness.
# NOTE(review): the third argument (1) is presumably the number of
# speculative tokens, by analogy with the mtp2 test that passes 2 — confirm.
@patch.dict(os.environ, {"VLLM_ASCEND_ENABLE_MLAPO": "1"})
def test_mtp_correctness_piecewise_graph_with_mlapo_kernel(
sampling_config: SamplingParams,
model_name: str,
):
mtp_correctness(sampling_config, model_name, 1)
# Runs the shared mtp_correctness check with VLLM_ASCEND_ENABLE_MLAPO="1"
# patched into os.environ for the duration of this test only, explicitly
# requesting CUDAGraphMode.FULL.
# NOTE(review): the third argument (1) is presumably the number of
# speculative tokens, by analogy with the mtp2 test that passes 2 — confirm.
@patch.dict(os.environ, {"VLLM_ASCEND_ENABLE_MLAPO": "1"})
def test_mtp_correctness_full_graph_with_mlapo_kernel(
sampling_config: SamplingParams,
model_name: str,
):
mtp_correctness(sampling_config, model_name, 1, CUDAGraphMode.FULL)