add mla_preprocess kernel (#3226)

### What this PR does / why we need it?

- Adds the `mla_preprocess` custom kernel to provide an optimized
pre-processing operator for Multi-head Latent Attention (MLA) on Ascend
NPUs.
- Wires the new kernel into the C++ extension pipeline so vLLM can
invoke it directly, cutting Python-side tensor shuffling and memory
copies that previously bottlenecked MLA compilation paths.

### Does this PR introduce any user-facing change?

- No. The change only introduces a low-level kernel; public APIs and
inference behavior remain unchanged.

### How was this patch tested?

- Dedicated Ascend kernels are not covered by our CI yet, so no extra
automated tests were added. Future MLA-focused regression runs will
cover this path.

- vLLM version: v0.11.0

Signed-off-by: Chen Chen <0109chenchen@gmail.com>
This commit is contained in:
Chen Chen
2025-10-12 07:39:45 +08:00
committed by GitHub
parent 1b1207e3c3
commit bcc313e8f2
32 changed files with 9158 additions and 3 deletions

View File

@@ -81,6 +81,41 @@ at::Tensor sgmv_expand_meta(at::Tensor &x, at::Tensor &weight, at::Tensor &lora_
return y_out;
}
// Meta ("fake") implementation of the MLA (Multi-head Latent Attention)
// preprocess custom op. Meta kernels never run the real computation: they
// exist so shape/aliasing information is available to the dispatcher during
// tracing/compilation. The real kernel writes its results into the four
// caller-provided output tensors, so this shim simply aliases them back.
// All input tensors and mode strings are intentionally unused here.
std::tuple<at::Tensor &, at::Tensor &, at::Tensor &, at::Tensor &> mla_preprocess(
    const at::Tensor &hiddenState,
    const at::Tensor &gamma0,
    const at::Tensor &beta0,
    const at::Tensor &wdqkv,
    const at::Tensor &descale0,
    const at::Tensor &gamma1,
    const at::Tensor &beta1,
    const at::Tensor &wuq,
    const at::Tensor &descale1,
    const at::Tensor &gamma2,
    const at::Tensor &cos,
    const at::Tensor &sin,
    const at::Tensor &wuk,
    const at::Tensor &kv_cache,
    const at::Tensor &kv_cache_rope,
    const at::Tensor &slotmapping,
    const at::Tensor &quant_scale0,
    const at::Tensor &quant_offset0,
    const at::Tensor &bias0,
    const at::Tensor &quant_scale1,
    const at::Tensor &quant_offset1,
    const at::Tensor &bias1,
    const c10::optional<at::Tensor> &ctkv_scale,
    const c10::optional<at::Tensor> &q_nope_scale,
    c10::optional<c10::string_view> cache_mode,
    c10::optional<c10::string_view> quant_mode,
    at::Tensor &q_out0,
    at::Tensor &kv_cache_out0,
    at::Tensor &q_out1,
    at::Tensor &kv_cache_out1)
{
    // std::tie yields exactly a tuple of lvalue references, matching the
    // declared return type; the outputs are returned unmodified.
    return std::tie(q_out0, kv_cache_out0, q_out1, kv_cache_out1);
}
} // namespace meta
} // namespace vllm_ascend
@@ -97,6 +132,7 @@ namespace {
ops.impl("bgmv_expand", &vllm_ascend::meta::bgmv_expand_meta);
// Sgmv expand
ops.impl("sgmv_expand", &vllm_ascend::meta::sgmv_expand_meta);
// MLA preprocess
ops.impl("mla_preprocess", &vllm_ascend::meta::mla_preprocess);
}
}