remove redundant params in mla_preprocess kernel (#3530)
### What this PR does / why we need it? This pull request removes the redundant parameters `gamma1` and `beta1` (also named `gamma0`/`beta0` in some places) from the `mla_preprocess` kernel and its calling hierarchy. The changes are applied consistently across the C++ kernel code, the bindings, and the Python call sites. The parameters were unused in the lower-level functions, so removing them is a straightforward cleanup. ### Does this PR introduce _any_ user-facing change? Yes. The Python interface of the kernel changes: the `gamma0` and `beta0` parameters are no longer needed and have been removed from the call signature. ### How was this patch tested? The kernel's unit test was adapted accordingly. - vLLM version: v0.11.0rc3 - vLLM main: https://github.com/vllm-project/vllm/commit/v0.11.0 Signed-off-by: mojave2 <chenchen145@huawei.com>
This commit is contained in:
@@ -108,7 +108,7 @@ std::tuple<at::Tensor, at::Tensor> rotary_embedding(at::Tensor &positions, at::T
|
||||
}
|
||||
|
||||
std::tuple<at::Tensor &, at::Tensor &, at::Tensor &, at::Tensor &> mla_preprocess(
|
||||
const at::Tensor &hiddenState, const at::Tensor &gamma0, const at::Tensor &beta0, const at::Tensor &wdqkv,
|
||||
const at::Tensor &hiddenState, const at::Tensor &wdqkv,
|
||||
const at::Tensor &descale0, const at::Tensor &gamma1, const at::Tensor &beta1, const at::Tensor &wuq,
|
||||
const at::Tensor &descale1, const at::Tensor &gamma2, const at::Tensor &cos, const at::Tensor &sin,
|
||||
const at::Tensor &wuk, const at::Tensor &kv_cache, const at::Tensor &kv_cache_rope, const at::Tensor &slotmapping,
|
||||
@@ -135,8 +135,6 @@ std::tuple<at::Tensor &, at::Tensor &, at::Tensor &, at::Tensor &> mla_preproces
|
||||
);
|
||||
|
||||
void *hidden_state_ptr = hiddenState.data_ptr();
|
||||
void *gamma0_ptr = gamma0.data_ptr();
|
||||
void *beta0_ptr = beta0.data_ptr();
|
||||
void *quant_scale0_ptr = quant_scale0.data_ptr();
|
||||
void *quant_offset0_ptr = quant_offset0.data_ptr();
|
||||
void *wdqkv_ptr = wdqkv.data_ptr();
|
||||
@@ -168,12 +166,12 @@ std::tuple<at::Tensor &, at::Tensor &, at::Tensor &, at::Tensor &> mla_preproces
|
||||
at_npu::native::OpCommand cmd;
|
||||
cmd.Name("mla_preprocess");
|
||||
|
||||
cmd.SetCustomHandler([stream, hidden_state_ptr, gamma0_ptr, beta0_ptr, quant_scale0_ptr, quant_offset0_ptr, wdqkv_ptr, bias0_ptr,
|
||||
cmd.SetCustomHandler([stream, hidden_state_ptr, quant_scale0_ptr, quant_offset0_ptr, wdqkv_ptr, bias0_ptr,
|
||||
gamma1_ptr, beta1_ptr, quant_scale1_ptr, quant_offset1_ptr, gamma2_ptr, sin_ptr, cos_ptr,
|
||||
kv_cache_ptr, slotmapping_ptr, wuq_ptr, bias1_ptr, wuk_ptr, descale0_ptr, descale1_ptr, ctkv_scale_ptr,
|
||||
qnope_scale_ptr, q_out0_ptr, kv_cache_out0_ptr, q_out1_ptr, kv_cache_out1_ptr, workspace_ptr,
|
||||
tiling_ptr, block_dim]() -> int {
|
||||
mla_preprocess_impl(stream, hidden_state_ptr, gamma0_ptr, beta0_ptr, quant_scale0_ptr, quant_offset0_ptr, wdqkv_ptr, bias0_ptr,
|
||||
mla_preprocess_impl(stream, hidden_state_ptr, quant_scale0_ptr, quant_offset0_ptr, wdqkv_ptr, bias0_ptr,
|
||||
gamma1_ptr, beta1_ptr, quant_scale1_ptr, quant_offset1_ptr, gamma2_ptr, sin_ptr, cos_ptr, sin_ptr, cos_ptr,
|
||||
kv_cache_ptr, slotmapping_ptr, wuq_ptr, bias1_ptr, wuk_ptr, descale0_ptr, descale1_ptr, ctkv_scale_ptr,
|
||||
qnope_scale_ptr, q_out0_ptr, kv_cache_out0_ptr, q_out1_ptr, kv_cache_out1_ptr, workspace_ptr,
|
||||
@@ -502,7 +500,7 @@ TORCH_LIBRARY_EXPAND(CONCAT(_C, _ascend), ops)
|
||||
ops.impl("sgmv_expand", torch::kPrivateUse1, &vllm_ascend::sgmv_expand);
|
||||
|
||||
ops.def(
|
||||
"mla_preprocess(Tensor hiddenState, Tensor gamma0, Tensor beta0, Tensor wdqkv,"
|
||||
"mla_preprocess(Tensor hiddenState, Tensor wdqkv,"
|
||||
" Tensor descale0, Tensor gamma1, Tensor beta1, Tensor wuq, Tensor descale1,"
|
||||
" Tensor gamma2, Tensor cos, Tensor sin, Tensor wuk, Tensor kv_cache,"
|
||||
" Tensor kv_cache_rope, Tensor slotmapping, Tensor quant_scale0,"
|
||||
|
||||
Reference in New Issue
Block a user