remove redundant params in mla_preprocess kernel (#3530)
### What this PR does / why we need it? This pull request removes the redundant parameters `gamma1` and `beta1` (also named `gamma0`/`beta0` in some places) from the `mla_preprocess` kernel and its calling hierarchy. The changes are consistent across C++ kernel code, bindings, and Python call sites. The parameters were unused in the lower-level functions, so their removal is a good cleanup. ### Does this PR introduce _any_ user-facing change? The python interface of the kernel is affected, and the params of `gamma0` and `beta0` are not needed. ### How was this patch tested? The unit-test of the kernel is adapted accordingly. - vLLM version: v0.11.0rc3 - vLLM main: https://github.com/vllm-project/vllm/commit/v0.11.0 Signed-off-by: mojave2 <chenchen145@huawei.com>
This commit is contained in:
@@ -18,8 +18,6 @@ def test_mla_preprocess_kernel():
|
||||
dtype = torch.bfloat16
|
||||
|
||||
hidden_states = torch.randn((token_num, N_7168), dtype=dtype).npu()
|
||||
gamma0 = torch.randn((N_7168), dtype=dtype).npu()
|
||||
beta0 = torch.randn((N_7168), dtype=dtype).npu()
|
||||
quant_scale0 = torch.randn((1, ), dtype=dtype).npu()
|
||||
quant_offset0 = torch.randint(0, 7, (1, ), dtype=torch.int8).npu()
|
||||
|
||||
@@ -74,8 +72,6 @@ def test_mla_preprocess_kernel():
|
||||
|
||||
torch.ops._C_ascend.mla_preprocess(
|
||||
hidden_states,
|
||||
gamma0,
|
||||
beta0,
|
||||
wdqkv,
|
||||
de_scale0,
|
||||
gamma1,
|
||||
|
||||
Reference in New Issue
Block a user