mlapo add qdown output (#4707)

### What this PR does / why we need it?
This PR adds mlapo operation support qdown of output.
### Does this PR introduce _any_ user-facing change?
mlapo operation add enable_inner_out of input
### How was this patch tested?
CI passed with new added/existing test.


- vLLM version: v0.12.0
- vLLM main:
ad32e3e19c

---------

Signed-off-by: h1074112368 <h1074112368@gmail.com>
Co-authored-by: wangxiyuan <wangxiyuan1007@gmail.com>
This commit is contained in:
h1074112368
2025-12-06 11:18:53 +08:00
committed by GitHub
parent 8378f56f53
commit 74033999ed
8 changed files with 3136 additions and 26 deletions

View File

@@ -127,6 +127,7 @@ struct OpParam {
int32_t cacheMode;
QuantMode quantMode;
caffe2::TypeMeta inDtype;
bool enableInnerOut;
};
class PpMatmulTilingApi
@@ -540,7 +541,8 @@ void MlaPreprocessTiling::SetMlapoWorkSpace()
void MlaPreprocessTiling::SetTilingKey()
{
uint64_t tilingKey = (static_cast<uint64_t>(opParam.inDtype == at::kBFloat16)) << 8;
uint64_t tilingKey = (static_cast<uint64_t>(opParam.enableInnerOut)) << 9;
tilingKey |= (static_cast<uint64_t>(opParam.inDtype == at::kBFloat16)) << 8;
tilingKey |= static_cast<uint64_t>(opParam.cacheMode);
tilingKey |= (static_cast<uint64_t>(opParam.quantMode) << 3);
@@ -619,21 +621,12 @@ inline int get_op_mode(const MapType &mode_map, c10::optional<c10::string_view>
return it->second;
}
// std::tuple<at::Tensor &, at::Tensor &, at::Tensor &, at::Tensor &> mla_preprocess(
// const at::Tensor &hiddenState, const at::Tensor &gamma0, const at::Tensor &beta0, const at::Tensor &wdqkv,
// const at::Tensor &descale0, const at::Tensor &gamma1, const at::Tensor &beta1, const at::Tensor &wuq,
// const at::Tensor &descale1, const at::Tensor &gamma2, const at::Tensor &cos, const at::Tensor &sin,
// const at::Tensor &wuk, const at::Tensor &kv_cache, const at::Tensor &kv_cache_rope, const at::Tensor &slotmapping,
// const at::Tensor &quant_scale0, const at::Tensor &quant_offset0, const at::Tensor &bias0,
// const at::Tensor &quant_scale1, const at::Tensor &quant_offset1, const at::Tensor &bias1,
// const c10::optional<at::Tensor> &ctkv_scale, const c10::optional<at::Tensor> &q_nope_scale,
// c10::optional<c10::string_view> cache_mode, c10::optional<c10::string_view> quant_mode, at::Tensor &q_out0,
// at::Tensor &kv_cache_out0, at::Tensor &q_out1, at::Tensor &kv_cache_out1)
std::tuple<at::Tensor, at::Tensor, uint32_t> mla_preprocess_tiling(
const at::Tensor &hiddenState,
const at::Tensor &wuk,
c10::optional<c10::string_view> cache_mode,
c10::optional<c10::string_view> quant_mode
c10::optional<c10::string_view> quant_mode,
bool enable_inner_out
)
{
auto cacheMode = get_op_mode(cache_mode_map, cache_mode, "krope_ctkv", "cache_mode");
@@ -661,6 +654,7 @@ std::tuple<at::Tensor, at::Tensor, uint32_t> mla_preprocess_tiling(
opParam.cacheMode = static_cast<int32_t>(cacheMode);
opParam.quantMode = static_cast<QuantMode>(quantMode);
opParam.inDtype = hiddenState.options().dtype();
opParam.enableInnerOut = enable_inner_out;
MlaTilingData tilingData;
MlaPreprocessTiling mlaTiling(platformInfo, opParam, &tilingData);

View File

@@ -103,6 +103,9 @@ constexpr uint32_t KEY_FP16_CACHEMODE_1_QUANTMODE_0 = 1;
constexpr uint32_t KEY_BF16_CACHEMODE_0_QUANTMODE_0 = 256;
constexpr uint32_t KEY_BF16_CACHEMODE_1_QUANTMODE_0 = 257;
constexpr uint32_t KEY_BF16_CACHEMODE_3_QUANTMODE_0 = 259;
constexpr uint32_t KEY_BF16_CACHEMODE_0_QUANTMODE_0_INNER = 256 + 512;
constexpr uint32_t KEY_BF16_CACHEMODE_1_QUANTMODE_0_INNER = 257 + 512;
constexpr uint32_t KEY_BF16_CACHEMODE_3_QUANTMODE_0_INNER = 259 + 512;
enum class QuantMode : int32_t {
PER_TENSOR_ASYMM_QUANT = 0,

View File

@@ -15,6 +15,7 @@
#include "mla_preprocess_mix_fp16.hpp"
#include "mla_preprocess_mix_bf16.hpp"
#include "mla_preprocess_mix_bf16_qdown.hpp"
#include "../op_host/tiling/mla_preprocess_tiling.h"
@@ -23,7 +24,7 @@ extern "C" __global__ __aicore__ void mla_preprocess(
GM_ADDR bias1, GM_ADDR gamma2, GM_ADDR beta2, GM_ADDR quantScale2, GM_ADDR quantOffset2, GM_ADDR gamma3,
GM_ADDR sin1, GM_ADDR cos1, GM_ADDR sin2, GM_ADDR cos2, GM_ADDR keycache, GM_ADDR slotMapping, GM_ADDR wuq,
GM_ADDR bias2, GM_ADDR wuk, GM_ADDR descale1, GM_ADDR descale2, GM_ADDR ctkvScale, GM_ADDR qnopeScale, GM_ADDR q,
GM_ADDR keycacheOut, GM_ADDR q2, GM_ADDR keycacheOut2, GM_ADDR workspace, GM_ADDR tiling)
GM_ADDR keycacheOut, GM_ADDR q2, GM_ADDR keycacheOut2, GM_ADDR innerOut, GM_ADDR workspace, GM_ADDR tiling)
{
#if defined(__CCE_KT_TEST__) || (__CCE_AICORE__ == 220)
PRELOAD(2);
@@ -218,6 +219,54 @@ extern "C" __global__ __aicore__ void mla_preprocess(
}
break;
}
case KEY_BF16_CACHEMODE_0_QUANTMODE_0_INNER: {
MLAPO_BF16_INNER::MLAOperation<__bf16, 0, DataFormat::NZ, DataFormat::NZ, DataFormat::ND,
QuantMode::PER_TENSOR_ASYMM_QUANT>
opBf16Cm0Qm0Inner(mlaTilingData, tiling);
opBf16Cm0Qm0Inner.Init(hiddenState, quantScale1, quantOffset1, wdqkv, bias1, gamma2, beta2,
quantScale2, quantOffset2, gamma3, sin1, cos1, sin2, cos2, keycache, slotMapping, wuq,
bias2, wuk, descale1, descale2, ctkvScale, qnopeScale, q, keycacheOut, q2, keycacheOut2,
s1, s2, s3, s4, s5, innerOut);
if ASCEND_IS_AIC {
opBf16Cm0Qm0Inner.ProcessCube();
}
if ASCEND_IS_AIV {
opBf16Cm0Qm0Inner.ProcessVector();
}
break;
}
case KEY_BF16_CACHEMODE_1_QUANTMODE_0_INNER: {
MLAPO_BF16_INNER::MLAOperation<__bf16, 1, DataFormat::NZ, DataFormat::NZ, DataFormat::ND,
QuantMode::PER_TENSOR_ASYMM_QUANT>
opBf16Cm1Qm0Inner(mlaTilingData, tiling);
opBf16Cm1Qm0Inner.Init(hiddenState, quantScale1, quantOffset1, wdqkv, bias1, gamma2, beta2,
quantScale2, quantOffset2, gamma3, sin1, cos1, sin2, cos2, keycache, slotMapping, wuq,
bias2, wuk, descale1, descale2, ctkvScale, qnopeScale, q, keycacheOut, q2, keycacheOut2,
s1, s2, s3, s4, s5, innerOut);
if ASCEND_IS_AIC {
opBf16Cm1Qm0Inner.ProcessCube();
}
if ASCEND_IS_AIV {
opBf16Cm1Qm0Inner.ProcessVector();
}
break;
}
case KEY_BF16_CACHEMODE_3_QUANTMODE_0_INNER: {
MLAPO_BF16_INNER::MLAOperation<__bf16, 3, DataFormat::NZ, DataFormat::NZ, DataFormat::ND,
QuantMode::PER_TENSOR_ASYMM_QUANT>
opBf16Cm3Qm0Inner(mlaTilingData, tiling);
opBf16Cm3Qm0Inner.Init(hiddenState, quantScale1, quantOffset1, wdqkv, bias1, gamma2, beta2,
quantScale2, quantOffset2, gamma3, sin1, cos1, sin2, cos2, keycache, slotMapping, wuq,
bias2, wuk, descale1, descale2, ctkvScale, qnopeScale, q, keycacheOut, q2, keycacheOut2,
s1, s2, s3, s4, s5, innerOut);
if ASCEND_IS_AIC {
opBf16Cm3Qm0Inner.ProcessCube();
}
if ASCEND_IS_AIV {
opBf16Cm3Qm0Inner.ProcessVector();
}
break;
}
default: {
break;
}
@@ -256,6 +305,7 @@ extern void mla_preprocess_impl(
void* keycache_out,
void* q2,
void* keycache_out2,
void* inner_out,
void* workspace,
void* tiling,
const uint32_t block_dim)
@@ -288,6 +338,7 @@ extern void mla_preprocess_impl(
keycache_out,
q2,
keycache_out2,
inner_out,
workspace,
tiling);
}

File diff suppressed because it is too large Load Diff

View File

@@ -154,6 +154,7 @@ namespace vllm_ascend {
void* keycache_out,
void* q2,
void* keycache_out2,
void* inner_out,
void* workspace,
void* tiling,
const uint32_t block_dim

View File

@@ -173,7 +173,7 @@ std::tuple<at::Tensor, at::Tensor> rotary_embedding(at::Tensor &positions, at::T
return {query_dst, key_dst};
}
std::tuple<at::Tensor &, at::Tensor &, at::Tensor &, at::Tensor &> mla_preprocess(
std::tuple<at::Tensor &, at::Tensor &, at::Tensor &, at::Tensor &, at::Tensor &> mla_preprocess(
const at::Tensor &hiddenState, const at::Tensor &wdqkv,
const at::Tensor &descale0, const at::Tensor &gamma1, const at::Tensor &beta1, const at::Tensor &wuq,
const at::Tensor &descale1, const at::Tensor &gamma2, const at::Tensor &cos, const at::Tensor &sin,
@@ -181,8 +181,8 @@ std::tuple<at::Tensor &, at::Tensor &, at::Tensor &, at::Tensor &> mla_preproces
const at::Tensor &quant_scale0, const at::Tensor &quant_offset0, const at::Tensor &bias0,
const at::Tensor &quant_scale1, const at::Tensor &quant_offset1, const at::Tensor &bias1,
const c10::optional<at::Tensor> &ctkv_scale, const c10::optional<at::Tensor> &q_nope_scale,
c10::optional<c10::string_view> cache_mode, c10::optional<c10::string_view> quant_mode, at::Tensor &q_out0,
at::Tensor &kv_cache_out0, at::Tensor &q_out1, at::Tensor &kv_cache_out1)
c10::optional<c10::string_view> cache_mode, c10::optional<c10::string_view> quant_mode, c10::optional<bool> enable_inner_out, at::Tensor &q_out0,
at::Tensor &kv_cache_out0, at::Tensor &q_out1, at::Tensor &kv_cache_out1, at::Tensor &inner_out)
{
at::Tensor CtkvScale =
ctkv_scale.has_value()
@@ -192,12 +192,17 @@ std::tuple<at::Tensor &, at::Tensor &, at::Tensor &, at::Tensor &> mla_preproces
q_nope_scale.has_value()
? q_nope_scale.value()
: at::empty({1}, at::TensorOptions().dtype(at::kHalf).device(hiddenState.options().device()));
bool enableInnerOut =
enable_inner_out.has_value()
? enable_inner_out.value()
: false;
auto [workspace_tensor, tiling, block_dim] = mlapo::mla_preprocess_tiling(
hiddenState,
wuk,
cache_mode,
quant_mode
quant_mode,
enableInnerOut
);
void *hidden_state_ptr = hiddenState.data_ptr();
@@ -225,6 +230,7 @@ std::tuple<at::Tensor &, at::Tensor &, at::Tensor &, at::Tensor &> mla_preproces
void *kv_cache_out0_ptr = kv_cache_out0.data_ptr();
void *q_out1_ptr = q_out1.data_ptr();
void *kv_cache_out1_ptr = kv_cache_out1.data_ptr();
void *inner_out_ptr = inner_out.data_ptr();
void *workspace_ptr = workspace_tensor.data_ptr();
void *tiling_ptr = tiling.data_ptr();
@@ -235,17 +241,17 @@ std::tuple<at::Tensor &, at::Tensor &, at::Tensor &, at::Tensor &> mla_preproces
cmd.SetCustomHandler([stream, hidden_state_ptr, quant_scale0_ptr, quant_offset0_ptr, wdqkv_ptr, bias0_ptr,
gamma1_ptr, beta1_ptr, quant_scale1_ptr, quant_offset1_ptr, gamma2_ptr, sin_ptr, cos_ptr,
kv_cache_ptr, slotmapping_ptr, wuq_ptr, bias1_ptr, wuk_ptr, descale0_ptr, descale1_ptr, ctkv_scale_ptr,
qnope_scale_ptr, q_out0_ptr, kv_cache_out0_ptr, q_out1_ptr, kv_cache_out1_ptr, workspace_ptr,
qnope_scale_ptr, q_out0_ptr, kv_cache_out0_ptr, q_out1_ptr, kv_cache_out1_ptr, inner_out_ptr, workspace_ptr,
tiling_ptr, block_dim]() -> int {
mla_preprocess_impl(stream, hidden_state_ptr, quant_scale0_ptr, quant_offset0_ptr, wdqkv_ptr, bias0_ptr,
gamma1_ptr, beta1_ptr, quant_scale1_ptr, quant_offset1_ptr, gamma2_ptr, sin_ptr, cos_ptr, sin_ptr, cos_ptr,
kv_cache_ptr, slotmapping_ptr, wuq_ptr, bias1_ptr, wuk_ptr, descale0_ptr, descale1_ptr, ctkv_scale_ptr,
qnope_scale_ptr, q_out0_ptr, kv_cache_out0_ptr, q_out1_ptr, kv_cache_out1_ptr, workspace_ptr,
qnope_scale_ptr, q_out0_ptr, kv_cache_out0_ptr, q_out1_ptr, kv_cache_out1_ptr, inner_out_ptr, workspace_ptr,
tiling_ptr, block_dim);
return 0;
});
cmd.Run();
return std::forward_as_tuple(q_out0, kv_cache_out0, q_out1, kv_cache_out1);
return std::forward_as_tuple(q_out0, kv_cache_out0, q_out1, kv_cache_out1, inner_out);
}
std::tuple<at::Tensor, at::Tensor> get_masked_input_and_mask(
@@ -792,9 +798,9 @@ TORCH_LIBRARY_EXPAND(CONCAT(_C, _ascend), ops)
" Tensor kv_cache_rope, Tensor slotmapping, Tensor quant_scale0,"
" Tensor quant_offset0, Tensor bias0, Tensor quant_scale1, Tensor quant_offset1,"
" Tensor bias1, Tensor? ctkv_scale, Tensor? q_nope_scale, str? cache_mode,"
" str? quant_mode, Tensor! q_out0, Tensor! kv_cache_out0, Tensor! q_out1,"
" Tensor! kv_cache_out1) -> (Tensor q_out0, Tensor kv_cache_out0,"
" Tensor q_out1, Tensor kv_cache_out1)"
" str? quant_mode, bool? enable_inner_out, Tensor! q_out0, Tensor! kv_cache_out0, Tensor! q_out1,"
" Tensor! kv_cache_out1, Tensor! inner_out) -> (Tensor q_out0, Tensor kv_cache_out0,"
" Tensor q_out1, Tensor kv_cache_out1, Tensor inner_out)"
);
ops.impl("mla_preprocess", torch::kPrivateUse1, &vllm_ascend::mla_preprocess);

View File

@@ -81,7 +81,7 @@ at::Tensor sgmv_expand_meta(at::Tensor &x, at::Tensor &weight, at::Tensor &lora_
return y_out;
}
std::tuple<at::Tensor &, at::Tensor &, at::Tensor &, at::Tensor &> mla_preprocess(
std::tuple<at::Tensor &, at::Tensor &, at::Tensor &, at::Tensor &, at::Tensor &> mla_preprocess(
const at::Tensor &hiddenState,
const at::Tensor &wdqkv,
const at::Tensor &descale0,
@@ -106,12 +106,15 @@ std::tuple<at::Tensor &, at::Tensor &, at::Tensor &, at::Tensor &> mla_preproces
const c10::optional<at::Tensor> &q_nope_scale,
c10::optional<c10::string_view> cache_mode,
c10::optional<c10::string_view> quant_mode,
c10::optional<bool> enable_inner_out,
at::Tensor &q_out0,
at::Tensor &kv_cache_out0,
at::Tensor &q_out1,
at::Tensor &kv_cache_out1)
at::Tensor &kv_cache_out1,
at::Tensor &inner_out
)
{
return {q_out0, kv_cache_out0, q_out1, kv_cache_out1};
return {q_out0, kv_cache_out0, q_out1, kv_cache_out1, inner_out};
}
std::tuple<at::Tensor, at::Tensor, at::Tensor> grouped_matmul_swiglu_quant(

View File

@@ -0,0 +1,116 @@
import gc
import torch
import torch_npu
from vllm_ascend.utils import enable_custom_op
enable_custom_op()
@torch.inference_mode()
def test_mla_preprocess_kernel():
token_num = 1
head_num = 2
N_7168 = 7168
block_num = 1
block_size = 128
dtype = torch.bfloat16
hidden_states = torch.randn((token_num, N_7168), dtype=dtype).npu()
quant_scale0 = torch.randn((1, ), dtype=dtype).npu()
quant_offset0 = torch.randint(0, 7, (1, ), dtype=torch.int8).npu()
wdqkv = torch.randint(0, 7, (1, 224, 2112, 32), dtype=torch.int8).npu()
wdqkv = torch_npu.npu_format_cast(wdqkv.contiguous(), 29)
de_scale0 = torch.rand((2112, ), dtype=torch.float).npu()
bias0 = torch.randint(0, 7, (2112, ), dtype=torch.int32).npu()
gamma1 = torch.randn((1536), dtype=dtype).npu()
beta1 = torch.randn((1536), dtype=dtype).npu()
quant_scale1 = torch.randn((1, ), dtype=dtype).npu()
quant_offset1 = torch.randint(0, 7, (1, ), dtype=torch.int8).npu()
wuq = torch.randint(0, 7, (1, 48, head_num * 192, 32),
dtype=torch.int8).npu()
wuq = torch_npu.npu_format_cast(wuq.contiguous(), 29)
de_scale1 = torch.rand((head_num * 192, ), dtype=torch.float).npu()
bias1 = torch.randint(0, 7, (head_num * 192, ), dtype=torch.int32).npu()
gamma2 = torch.randn((512), dtype=dtype).npu()
cos = torch.randn((token_num, 64), dtype=dtype).npu()
sin = torch.randn((token_num, 64), dtype=dtype).npu()
wuk = torch.randn((head_num, 128, 512), dtype=dtype).npu()
wuk = torch_npu.npu_format_cast(wuk, 29)
kv_cache = torch.randint(0,
7,
(block_num, head_num * 512 // 32, block_size, 32),
dtype=dtype).npu()
kv_cache_rope = torch.randn(
(block_num, head_num * 64 // 16, block_size, 16), dtype=dtype).npu()
slotmapping = torch.randint(0, 7, (token_num, ), dtype=torch.int32).npu()
ctkv_scale = torch.randn((1, ), dtype=dtype).npu()
qnope_scale = torch.randn((head_num), dtype=dtype).npu()
q_nope_out = torch.empty(
(hidden_states.shape[0], wuk.shape[0], kv_cache.shape[-1]),
dtype=hidden_states.dtype,
device=hidden_states.device,
)
q_rope_out = torch.empty(
(hidden_states.shape[0], wuk.shape[0], kv_cache_rope.shape[-1]),
dtype=hidden_states.dtype,
device=hidden_states.device,
)
q_down = torch.empty(
(hidden_states.shape[0], 1536),
dtype=hidden_states.dtype,
device=hidden_states.device,
)
q_nope_old = q_nope_out.clone()
q_rope_old = q_rope_out.clone()
q_down_old = q_down.clone()
torch.ops._C_ascend.mla_preprocess(hidden_states,
wdqkv,
de_scale0,
gamma1,
beta1,
wuq,
de_scale1,
gamma2,
cos,
sin,
wuk,
kv_cache,
kv_cache_rope,
slotmapping,
quant_scale0=quant_scale0,
quant_offset0=quant_offset0,
bias0=bias0,
quant_scale1=quant_scale1,
quant_offset1=quant_offset1,
bias1=bias1,
ctkv_scale=ctkv_scale,
q_nope_scale=qnope_scale,
cache_mode="krope_ctkv",
quant_mode="per_tensor_quant_asymm",
enable_inner_out=True,
q_out0=q_nope_out,
kv_cache_out0=kv_cache,
q_out1=q_rope_out,
kv_cache_out1=kv_cache_rope,
inner_out=q_down)
assert not torch.equal(q_nope_out, q_nope_old)
assert not torch.equal(q_rope_out, q_rope_old)
assert not torch.equal(q_down, q_down_old)
gc.collect()
torch.npu.empty_cache()
torch.npu.reset_peak_memory_stats()