add mla_preprocess kernel (#3226)
### What this PR does / why we need it?

- Adds the `mla_preprocess` custom kernel to provide an optimized pre-processing operator for Multi-head Latent Attention (MLA) on Ascend NPUs.
- Wires the new kernel into the C++ extension pipeline so vLLM can invoke it directly, cutting the Python-side tensor shuffling and memory copies that previously bottlenecked MLA compilation paths.

### Does this PR introduce any user-facing change?

- No. The change only introduces a low-level kernel; public APIs and inference behavior remain unchanged.

### How was this patch tested?

- Dedicated Ascend kernels are not covered by our CI yet, so no extra automated tests were added. Future MLA-focused regression runs will cover this path.
- vLLM version: v0.11.0

Signed-off-by: Chen Chen <0109chenchen@gmail.com>
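For reference, a minimal sketch of how the registered op can be reached once the extension is loaded. The helper below is hypothetical and not part of this PR; it only illustrates dispatcher-based resolution of the `_C_ascend::mla_preprocess` schema defined in this diff (from Python the equivalent would be `torch.ops._C_ascend.mla_preprocess(...)`):

```cpp
#include <ATen/core/dispatch/Dispatcher.h>
#include <ATen/core/ivalue.h>
#include <vector>

// Hypothetical call-site sketch (not part of this PR). 'stack' must already
// hold all 30 arguments (22 tensors, 2 optional tensors, 2 optional strings,
// 4 preallocated output tensors) in the exact order declared by ops.def(...).
void call_mla_preprocess_boxed(std::vector<c10::IValue> &stack) {
    static auto handle = c10::Dispatcher::singleton()
        .findSchemaOrThrow("_C_ascend::mla_preprocess", "");
    // After the call, the stack holds the four returned tensors, which alias
    // the q_out*/kv_cache_out* buffers mutated in place by the kernel.
    handle.callBoxed(&stack);
}
```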
@@ -23,6 +23,7 @@
#include "acl/acl.h"
#include "ops.h"
#include "utils.h"
#include "mla_preprocess/op_host/mla_preprocess.h"

namespace vllm_ascend {

@@ -106,6 +107,83 @@ std::tuple<at::Tensor, at::Tensor> rotary_embedding(at::Tensor &positions, at::T
    return {query_dst, key_dst};
}

std::tuple<at::Tensor &, at::Tensor &, at::Tensor &, at::Tensor &> mla_preprocess(
    const at::Tensor &hiddenState, const at::Tensor &gamma0, const at::Tensor &beta0, const at::Tensor &wdqkv,
    const at::Tensor &descale0, const at::Tensor &gamma1, const at::Tensor &beta1, const at::Tensor &wuq,
    const at::Tensor &descale1, const at::Tensor &gamma2, const at::Tensor &cos, const at::Tensor &sin,
    const at::Tensor &wuk, const at::Tensor &kv_cache, const at::Tensor &kv_cache_rope, const at::Tensor &slotmapping,
    const at::Tensor &quant_scale0, const at::Tensor &quant_offset0, const at::Tensor &bias0,
    const at::Tensor &quant_scale1, const at::Tensor &quant_offset1, const at::Tensor &bias1,
    const c10::optional<at::Tensor> &ctkv_scale, const c10::optional<at::Tensor> &q_nope_scale,
    c10::optional<c10::string_view> cache_mode, c10::optional<c10::string_view> quant_mode, at::Tensor &q_out0,
    at::Tensor &kv_cache_out0, at::Tensor &q_out1, at::Tensor &kv_cache_out1)
{
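    // Missing optional scales fall back to 1-element fp16 placeholders on the
    // input's device so the kernel always receives valid buffers.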
    at::Tensor CtkvScale =
        ctkv_scale.has_value()
            ? ctkv_scale.value()
            : at::empty({1}, at::TensorOptions().dtype(at::kHalf).device(hiddenState.options().device()));
    at::Tensor QnopeScale =
        q_nope_scale.has_value()
            ? q_nope_scale.value()
            : at::empty({1}, at::TensorOptions().dtype(at::kHalf).device(hiddenState.options().device()));

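    // Host-side tiling: derive the workspace tensor, tiling metadata, and
    // launch block count from the input shapes and the cache/quant modes.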
    auto [workspace_tensor, tiling, block_dim] = mlapo::mla_preprocess_tiling(
        hiddenState,
        wuk,
        cache_mode,
        quant_mode
    );

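    // Extract raw device pointers up front; the custom handler below captures
    // these by value rather than holding at::Tensor references.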
    void *hidden_state_ptr = hiddenState.data_ptr();
    void *gamma0_ptr = gamma0.data_ptr();
    void *beta0_ptr = beta0.data_ptr();
    void *quant_scale0_ptr = quant_scale0.data_ptr();
    void *quant_offset0_ptr = quant_offset0.data_ptr();
    void *wdqkv_ptr = wdqkv.data_ptr();
    void *bias0_ptr = bias0.data_ptr();
    void *gamma1_ptr = gamma1.data_ptr();
    void *beta1_ptr = beta1.data_ptr();
    void *quant_scale1_ptr = quant_scale1.data_ptr();
    void *quant_offset1_ptr = quant_offset1.data_ptr();
    void *gamma2_ptr = gamma2.data_ptr();
    void *sin_ptr = sin.data_ptr();
    void *cos_ptr = cos.data_ptr();
    void *kv_cache_ptr = kv_cache.data_ptr();
    void *slotmapping_ptr = slotmapping.data_ptr();
    void *wuq_ptr = wuq.data_ptr();
    void *bias1_ptr = bias1.data_ptr();
    void *wuk_ptr = wuk.data_ptr();
    void *descale0_ptr = descale0.data_ptr();
    void *descale1_ptr = descale1.data_ptr();
    void *ctkv_scale_ptr = CtkvScale.data_ptr();
    void *qnope_scale_ptr = QnopeScale.data_ptr();
    void *q_out0_ptr = q_out0.data_ptr();
    void *kv_cache_out0_ptr = kv_cache_out0.data_ptr();
    void *q_out1_ptr = q_out1.data_ptr();
    void *kv_cache_out1_ptr = kv_cache_out1.data_ptr();
    void *workspace_ptr = workspace_tensor.data_ptr();
    void *tiling_ptr = tiling.data_ptr();

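    // Enqueue the kernel on the current NPU stream through torch_npu's
    // OpCommand so it executes in order with other queued NPU work.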
    aclrtStream stream = c10_npu::getCurrentNPUStream().stream();
    at_npu::native::OpCommand cmd;
    cmd.Name("mla_preprocess");

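    // The custom handler defers the actual launch until OpCommand executes the
    // command; returning 0 signals success (an assumption about the convention).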
    cmd.SetCustomHandler([stream, hidden_state_ptr, gamma0_ptr, beta0_ptr, quant_scale0_ptr, quant_offset0_ptr, wdqkv_ptr, bias0_ptr,
                          gamma1_ptr, beta1_ptr, quant_scale1_ptr, quant_offset1_ptr, gamma2_ptr, sin_ptr, cos_ptr,
                          kv_cache_ptr, slotmapping_ptr, wuq_ptr, bias1_ptr, wuk_ptr, descale0_ptr, descale1_ptr, ctkv_scale_ptr,
                          qnope_scale_ptr, q_out0_ptr, kv_cache_out0_ptr, q_out1_ptr, kv_cache_out1_ptr, workspace_ptr,
                          tiling_ptr, block_dim]() -> int {
        mla_preprocess_impl(stream, hidden_state_ptr, gamma0_ptr, beta0_ptr, quant_scale0_ptr, quant_offset0_ptr, wdqkv_ptr, bias0_ptr,
                            gamma1_ptr, beta1_ptr, quant_scale1_ptr, quant_offset1_ptr, gamma2_ptr, sin_ptr, cos_ptr, sin_ptr, cos_ptr,
                            kv_cache_ptr, slotmapping_ptr, wuq_ptr, bias1_ptr, wuk_ptr, descale0_ptr, descale1_ptr, ctkv_scale_ptr,
                            qnope_scale_ptr, q_out0_ptr, kv_cache_out0_ptr, q_out1_ptr, kv_cache_out1_ptr, workspace_ptr,
                            tiling_ptr, block_dim);
        return 0;
    });
    cmd.Run();
    return std::forward_as_tuple(q_out0, kv_cache_out0, q_out1, kv_cache_out1);
}

std::tuple<at::Tensor, at::Tensor> get_masked_input_and_mask(
    at::Tensor &input,
    const int64_t org_vocab_start_index,
@@ -422,4 +500,17 @@ TORCH_LIBRARY_EXPAND(CONCAT(_C, _ascend), ops)
        "sgmv_expand(Tensor! x, Tensor! weight, Tensor! lora_indices, Tensor! seq_len, Tensor! y,"
        " int slice_offset, int slice_size) -> Tensor");
    ops.impl("sgmv_expand", torch::kPrivateUse1, &vllm_ascend::sgmv_expand);

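    // Schema notes: Tensor! marks in-place outputs (the four preallocated
    // result tensors); Tensor? and str? mark optional arguments.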
    ops.def(
        "mla_preprocess(Tensor hiddenState, Tensor gamma0, Tensor beta0, Tensor wdqkv,"
        " Tensor descale0, Tensor gamma1, Tensor beta1, Tensor wuq, Tensor descale1,"
        " Tensor gamma2, Tensor cos, Tensor sin, Tensor wuk, Tensor kv_cache,"
        " Tensor kv_cache_rope, Tensor slotmapping, Tensor quant_scale0,"
        " Tensor quant_offset0, Tensor bias0, Tensor quant_scale1, Tensor quant_offset1,"
        " Tensor bias1, Tensor? ctkv_scale, Tensor? q_nope_scale, str? cache_mode,"
        " str? quant_mode, Tensor! q_out0, Tensor! kv_cache_out0, Tensor! q_out1,"
        " Tensor! kv_cache_out1) -> (Tensor q_out0, Tensor kv_cache_out0,"
        " Tensor q_out1, Tensor kv_cache_out1)"
    );
ops.impl("mla_preprocess", torch::kPrivateUse1, &vllm_ascend::mla_preprocess);
|
||||
}
|
||||
|
||||