[feat] mlapo add bf16 no_quant support (#4852)

### What this PR does / why we need it?
This PR adds mlapo operation support for bf16 no_quant mode.

### Does this PR introduce _any_ user-facing change?
This PR makes quant related parameters optional. 
### How was this patch tested?
CI passed with new added/existing test.

- vLLM version: v0.12.0
- vLLM main:
ad32e3e19c

---------

Signed-off-by: chenjunyi <isjunyi.chen@gmail.com>
This commit is contained in:
chenjunyi
2025-12-11 11:06:56 +08:00
committed by GitHub
parent c95c271538
commit c12eb22cbe
12 changed files with 1510 additions and 81 deletions

View File

@@ -176,15 +176,51 @@ std::tuple<at::Tensor, at::Tensor> rotary_embedding(at::Tensor &positions, at::T
std::tuple<at::Tensor &, at::Tensor &, at::Tensor &, at::Tensor &, at::Tensor &> mla_preprocess(
const at::Tensor &hiddenState, const at::Tensor &wdqkv,
const at::Tensor &descale0, const at::Tensor &gamma1, const at::Tensor &beta1, const at::Tensor &wuq,
const at::Tensor &descale1, const at::Tensor &gamma2, const at::Tensor &cos, const at::Tensor &sin,
const c10::optional<at::Tensor> &descale0, const at::Tensor &gamma1, const c10::optional<at::Tensor> &beta1, const at::Tensor &wuq,
const c10::optional<at::Tensor> &descale1, const at::Tensor &gamma2, const at::Tensor &cos, const at::Tensor &sin,
const at::Tensor &wuk, const at::Tensor &kv_cache, const at::Tensor &kv_cache_rope, const at::Tensor &slotmapping,
const at::Tensor &quant_scale0, const at::Tensor &quant_offset0, const at::Tensor &bias0,
const at::Tensor &quant_scale1, const at::Tensor &quant_offset1, const at::Tensor &bias1,
const c10::optional<at::Tensor> &quant_scale0, const c10::optional<at::Tensor> &quant_offset0, const c10::optional<at::Tensor> &bias0,
const c10::optional<at::Tensor> &quant_scale1, const c10::optional<at::Tensor> &quant_offset1, const c10::optional<at::Tensor> &bias1,
const c10::optional<at::Tensor> &ctkv_scale, const c10::optional<at::Tensor> &q_nope_scale,
c10::optional<c10::string_view> cache_mode, c10::optional<c10::string_view> quant_mode, c10::optional<bool> enable_inner_out, at::Tensor &q_out0,
at::Tensor &kv_cache_out0, at::Tensor &q_out1, at::Tensor &kv_cache_out1, at::Tensor &inner_out)
{
at::Tensor Descale0 =
descale0.has_value()
? descale0.value()
: at::empty({1}, at::TensorOptions().dtype(at::kHalf).device(hiddenState.options().device()));
at::Tensor Descale1 =
descale1.has_value()
? descale1.value()
: at::empty({1}, at::TensorOptions().dtype(at::kHalf).device(hiddenState.options().device()));
at::Tensor Beta1 =
beta1.has_value()
? beta1.value()
: at::empty({1}, at::TensorOptions().dtype(at::kHalf).device(hiddenState.options().device()));
at::Tensor Quant_scale0 =
quant_scale0.has_value()
? quant_scale0.value()
: at::empty({1}, at::TensorOptions().dtype(at::kHalf).device(hiddenState.options().device()));
at::Tensor Quant_scale1 =
quant_scale1.has_value()
? quant_scale1.value()
: at::empty({1}, at::TensorOptions().dtype(at::kHalf).device(hiddenState.options().device()));
at::Tensor Quant_offset0 =
quant_offset0.has_value()
? quant_offset0.value()
: at::empty({1}, at::TensorOptions().dtype(at::kHalf).device(hiddenState.options().device()));
at::Tensor Quant_offset1 =
quant_offset1.has_value()
? quant_offset1.value()
: at::empty({1}, at::TensorOptions().dtype(at::kHalf).device(hiddenState.options().device()));
at::Tensor Bias0 =
bias0.has_value()
? bias0.value()
: at::empty({1}, at::TensorOptions().dtype(at::kHalf).device(hiddenState.options().device()));
at::Tensor Bias1 =
bias1.has_value()
? bias1.value()
: at::empty({1}, at::TensorOptions().dtype(at::kHalf).device(hiddenState.options().device()));
at::Tensor CtkvScale =
ctkv_scale.has_value()
? ctkv_scale.value()
@@ -200,6 +236,7 @@ std::tuple<at::Tensor &, at::Tensor &, at::Tensor &, at::Tensor &, at::Tensor &>
auto [workspace_tensor, tiling, block_dim] = mlapo::mla_preprocess_tiling(
hiddenState,
wdqkv,
wuk,
cache_mode,
quant_mode,
@@ -207,24 +244,24 @@ std::tuple<at::Tensor &, at::Tensor &, at::Tensor &, at::Tensor &, at::Tensor &>
);
void *hidden_state_ptr = hiddenState.data_ptr();
void *quant_scale0_ptr = quant_scale0.data_ptr();
void *quant_offset0_ptr = quant_offset0.data_ptr();
void *quant_scale0_ptr = Quant_scale0.data_ptr();
void *quant_offset0_ptr = Quant_offset0.data_ptr();
void *wdqkv_ptr = wdqkv.data_ptr();
void *bias0_ptr = bias0.data_ptr();
void *bias0_ptr = Bias0.data_ptr();
void *gamma1_ptr = gamma1.data_ptr();
void *beta1_ptr = beta1.data_ptr();
void *quant_scale1_ptr = quant_scale1.data_ptr();
void *quant_offset1_ptr = quant_offset1.data_ptr();
void *beta1_ptr = Beta1.data_ptr();
void *quant_scale1_ptr = Quant_scale1.data_ptr();
void *quant_offset1_ptr = Quant_offset1.data_ptr();
void *gamma2_ptr = gamma2.data_ptr();
void *sin_ptr = sin.data_ptr();
void *cos_ptr = cos.data_ptr();
void *kv_cache_ptr = kv_cache.data_ptr();
void *slotmapping_ptr = slotmapping.data_ptr();
void *wuq_ptr = wuq.data_ptr();
void *bias1_ptr = bias1.data_ptr();
void *bias1_ptr = Bias1.data_ptr();
void *wuk_ptr = wuk.data_ptr();
void *descale0_ptr = descale0.data_ptr();
void *descale1_ptr = descale1.data_ptr();
void *descale0_ptr = Descale0.data_ptr();
void *descale1_ptr = Descale1.data_ptr();
void *ctkv_scale_ptr = CtkvScale.data_ptr();
void *qnope_scale_ptr = QnopeScale.data_ptr();
void *q_out0_ptr = q_out0.data_ptr();
@@ -1122,11 +1159,11 @@ TORCH_LIBRARY_EXPAND(CONCAT(_C, _ascend), ops)
ops.def(
"mla_preprocess(Tensor hiddenState, Tensor wdqkv,"
" Tensor descale0, Tensor gamma1, Tensor beta1, Tensor wuq, Tensor descale1,"
" Tensor? descale0, Tensor gamma1, Tensor? beta1, Tensor wuq, Tensor? descale1,"
" Tensor gamma2, Tensor cos, Tensor sin, Tensor wuk, Tensor kv_cache,"
" Tensor kv_cache_rope, Tensor slotmapping, Tensor quant_scale0,"
" Tensor quant_offset0, Tensor bias0, Tensor quant_scale1, Tensor quant_offset1,"
" Tensor bias1, Tensor? ctkv_scale, Tensor? q_nope_scale, str? cache_mode,"
" Tensor kv_cache_rope, Tensor slotmapping, Tensor? quant_scale0,"
" Tensor? quant_offset0, Tensor? bias0, Tensor? quant_scale1, Tensor? quant_offset1,"
" Tensor? bias1, Tensor? ctkv_scale, Tensor? q_nope_scale, str? cache_mode,"
" str? quant_mode, bool? enable_inner_out, Tensor! q_out0, Tensor! kv_cache_out0, Tensor! q_out1,"
" Tensor! kv_cache_out1, Tensor! inner_out) -> (Tensor q_out0, Tensor kv_cache_out0,"
" Tensor q_out1, Tensor kv_cache_out1, Tensor inner_out)"