add mla_preprocess kernel (#3226)
### What this PR does / why we need it?
- Adds the `mla_preprocess` custom kernel to provide an optimized pre-processing operator for Multi-head Latent Attention (MLA) on Ascend NPUs.
- Wires the new kernel into the C++ extension pipeline so vLLM can invoke it directly, cutting Python-side tensor shuffling and memory copies that previously bottlenecked MLA compilation paths.

### Does this PR introduce any user-facing change?
- No. The change only introduces a low-level kernel; public APIs and inference behavior remain unchanged.

### How was this patch tested?
- Dedicated Ascend kernels are not covered by our CI yet, so no extra automated tests were added. Future MLA-focused regression runs will cover this path.
- vLLM version: v0.11.0

Signed-off-by: Chen Chen <0109chenchen@gmail.com>
This commit is contained in:
@@ -81,6 +81,41 @@ at::Tensor sgmv_expand_meta(at::Tensor &x, at::Tensor &weight, at::Tensor &lora_
|
||||
return y_out;
|
||||
}
|
||||
|
||||
std::tuple<at::Tensor &, at::Tensor &, at::Tensor &, at::Tensor &> mla_preprocess(
|
||||
const at::Tensor &hiddenState,
|
||||
const at::Tensor &gamma0,
|
||||
const at::Tensor &beta0,
|
||||
const at::Tensor &wdqkv,
|
||||
const at::Tensor &descale0,
|
||||
const at::Tensor &gamma1,
|
||||
const at::Tensor &beta1,
|
||||
const at::Tensor &wuq,
|
||||
const at::Tensor &descale1,
|
||||
const at::Tensor &gamma2,
|
||||
const at::Tensor &cos,
|
||||
const at::Tensor &sin,
|
||||
const at::Tensor &wuk,
|
||||
const at::Tensor &kv_cache,
|
||||
const at::Tensor &kv_cache_rope,
|
||||
const at::Tensor &slotmapping,
|
||||
const at::Tensor &quant_scale0,
|
||||
const at::Tensor &quant_offset0,
|
||||
const at::Tensor &bias0,
|
||||
const at::Tensor &quant_scale1,
|
||||
const at::Tensor &quant_offset1,
|
||||
const at::Tensor &bias1,
|
||||
const c10::optional<at::Tensor> &ctkv_scale,
|
||||
const c10::optional<at::Tensor> &q_nope_scale,
|
||||
c10::optional<c10::string_view> cache_mode,
|
||||
c10::optional<c10::string_view> quant_mode,
|
||||
at::Tensor &q_out0,
|
||||
at::Tensor &kv_cache_out0,
|
||||
at::Tensor &q_out1,
|
||||
at::Tensor &kv_cache_out1)
|
||||
{
|
||||
return {q_out0, kv_cache_out0, q_out1, kv_cache_out1};
|
||||
}
|
||||
|
||||
|
||||
} // namespace meta
|
||||
} // namespace vllm_ascend
|
||||
@@ -97,6 +132,7 @@ namespace {
|
||||
ops.impl("bgmv_expand", &vllm_ascend::meta::bgmv_expand_meta);
|
||||
// Sgmv expand
|
||||
ops.impl("sgmv_expand", &vllm_ascend::meta::sgmv_expand_meta);
|
||||
|
||||
// MLA preprocess
|
||||
ops.impl("mla_preprocess", &vllm_ascend::meta::mla_preprocess);
|
||||
}
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user