mlapo: add qdown output (#4707)

### What this PR does / why we need it?
This PR adds support for a qdown (inner) output to the mlapo operation.
### Does this PR introduce _any_ user-facing change?
Yes. The mlapo operation gains a new `enable_inner_out` input.
### How was this patch tested?
CI passed with newly added and existing tests.


- vLLM version: v0.12.0
- vLLM main: ad32e3e19c

---------

Signed-off-by: h1074112368 <h1074112368@gmail.com>
Co-authored-by: wangxiyuan <wangxiyuan1007@gmail.com>
Author: h1074112368
Committed by: GitHub
Date: 2025-12-06 11:18:53 +08:00
Parent: 8378f56f53
Commit: 74033999ed
8 changed files with 3136 additions and 26 deletions


@@ -127,6 +127,7 @@ struct OpParam {
int32_t cacheMode;
QuantMode quantMode;
caffe2::TypeMeta inDtype;
+ bool enableInnerOut;
};
class PpMatmulTilingApi
@@ -540,7 +541,8 @@ void MlaPreprocessTiling::SetMlapoWorkSpace()
void MlaPreprocessTiling::SetTilingKey()
{
- uint64_t tilingKey = (static_cast<uint64_t>(opParam.inDtype == at::kBFloat16)) << 8;
+ uint64_t tilingKey = (static_cast<uint64_t>(opParam.enableInnerOut)) << 9;
+ tilingKey |= (static_cast<uint64_t>(opParam.inDtype == at::kBFloat16)) << 8;
tilingKey |= static_cast<uint64_t>(opParam.cacheMode);
tilingKey |= (static_cast<uint64_t>(opParam.quantMode) << 3);
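
Taken together with the new `OpParam::enableInnerOut` field above, this hunk fixes the full bit layout of the tiling key. Below is a minimal standalone sketch of that composition, assuming only the shift amounts visible in this diff (cache mode in the low bits, quant mode from bit 3, the bf16 flag at bit 8, and the new inner-out flag at bit 9); the helper name is hypothetical:

```cpp
#include <cstdint>

// Hypothetical helper mirroring SetTilingKey after this change; field widths
// beyond the shift amounts shown above are assumptions.
uint64_t ComposeTilingKey(bool enableInnerOut, bool isBf16, int32_t cacheMode, int32_t quantMode)
{
    uint64_t tilingKey = static_cast<uint64_t>(enableInnerOut) << 9;  // bit added by this PR
    tilingKey |= static_cast<uint64_t>(isBf16) << 8;                  // input dtype is bf16
    tilingKey |= static_cast<uint64_t>(cacheMode);                    // cache mode in the low bits
    tilingKey |= static_cast<uint64_t>(quantMode) << 3;               // quant mode starting at bit 3
    return tilingKey;
}
```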
@@ -619,21 +621,12 @@ inline int get_op_mode(const MapType &mode_map, c10::optional<c10::string_view>
return it->second;
}
- // std::tuple<at::Tensor &, at::Tensor &, at::Tensor &, at::Tensor &> mla_preprocess(
- // const at::Tensor &hiddenState, const at::Tensor &gamma0, const at::Tensor &beta0, const at::Tensor &wdqkv,
- // const at::Tensor &descale0, const at::Tensor &gamma1, const at::Tensor &beta1, const at::Tensor &wuq,
- // const at::Tensor &descale1, const at::Tensor &gamma2, const at::Tensor &cos, const at::Tensor &sin,
- // const at::Tensor &wuk, const at::Tensor &kv_cache, const at::Tensor &kv_cache_rope, const at::Tensor &slotmapping,
- // const at::Tensor &quant_scale0, const at::Tensor &quant_offset0, const at::Tensor &bias0,
- // const at::Tensor &quant_scale1, const at::Tensor &quant_offset1, const at::Tensor &bias1,
- // const c10::optional<at::Tensor> &ctkv_scale, const c10::optional<at::Tensor> &q_nope_scale,
- // c10::optional<c10::string_view> cache_mode, c10::optional<c10::string_view> quant_mode, at::Tensor &q_out0,
- // at::Tensor &kv_cache_out0, at::Tensor &q_out1, at::Tensor &kv_cache_out1)
std::tuple<at::Tensor, at::Tensor, uint32_t> mla_preprocess_tiling(
const at::Tensor &hiddenState,
const at::Tensor &wuk,
c10::optional<c10::string_view> cache_mode,
- c10::optional<c10::string_view> quant_mode
+ c10::optional<c10::string_view> quant_mode,
+ bool enable_inner_out
)
{
auto cacheMode = get_op_mode(cache_mode_map, cache_mode, "krope_ctkv", "cache_mode");
@@ -661,6 +654,7 @@ std::tuple<at::Tensor, at::Tensor, uint32_t> mla_preprocess_tiling(
opParam.cacheMode = static_cast<int32_t>(cacheMode);
opParam.quantMode = static_cast<QuantMode>(quantMode);
opParam.inDtype = hiddenState.options().dtype();
+ opParam.enableInnerOut = enable_inner_out;
MlaTilingData tilingData;
MlaPreprocessTiling mlaTiling(platformInfo, opParam, &tilingData);
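
The new `enable_inner_out` flag flows from this host entry point into `OpParam` and, via `SetTilingKey`, into the kernel dispatch. A hedged sketch of a call site (not part of this diff; `c10::nullopt` is passed for the mode strings on the assumption that `get_op_mode` then falls back to its defaults, e.g. "krope_ctkv" for the cache mode):

```cpp
#include <torch/extension.h>  // assumes a torch extension translation unit
#include <tuple>

// Declaration as changed by this PR (see the hunk above).
std::tuple<at::Tensor, at::Tensor, uint32_t> mla_preprocess_tiling(
    const at::Tensor &hiddenState, const at::Tensor &wuk,
    c10::optional<c10::string_view> cache_mode, c10::optional<c10::string_view> quant_mode,
    bool enable_inner_out);

// Hypothetical caller: enable_inner_out=true sets OpParam::enableInnerOut and
// therefore selects one of the *_INNER tiling keys introduced below.
uint32_t TilingExample(const at::Tensor &hiddenState, const at::Tensor &wuk)
{
    auto tiling = mla_preprocess_tiling(hiddenState, wuk, c10::nullopt, c10::nullopt,
                                        /*enable_inner_out=*/true);
    return std::get<2>(tiling);  // the uint32_t element, presumably the launch block dim
}
```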


@@ -103,6 +103,9 @@ constexpr uint32_t KEY_FP16_CACHEMODE_1_QUANTMODE_0 = 1;
constexpr uint32_t KEY_BF16_CACHEMODE_0_QUANTMODE_0 = 256;
constexpr uint32_t KEY_BF16_CACHEMODE_1_QUANTMODE_0 = 257;
constexpr uint32_t KEY_BF16_CACHEMODE_3_QUANTMODE_0 = 259;
+ constexpr uint32_t KEY_BF16_CACHEMODE_0_QUANTMODE_0_INNER = 256 + 512;
+ constexpr uint32_t KEY_BF16_CACHEMODE_1_QUANTMODE_0_INNER = 257 + 512;
+ constexpr uint32_t KEY_BF16_CACHEMODE_3_QUANTMODE_0_INNER = 259 + 512;
enum class QuantMode : int32_t {
PER_TENSOR_ASYMM_QUANT = 0,

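The `+ 512` in the new constants is exactly the `enableInnerOut` bit (`1 << 9`) from `SetTilingKey`. Illustrative compile-time checks (not part of the PR; they assume the `KEY_*` constexprs above are in scope) make that relationship explicit:

```cpp
static_assert(KEY_BF16_CACHEMODE_0_QUANTMODE_0_INNER ==
                  (KEY_BF16_CACHEMODE_0_QUANTMODE_0 | (1u << 9)),
              "inner-out key is the base bf16 key with bit 9 set");
static_assert(KEY_BF16_CACHEMODE_1_QUANTMODE_0_INNER ==
                  (KEY_BF16_CACHEMODE_1_QUANTMODE_0 | (1u << 9)),
              "inner-out key is the base bf16 key with bit 9 set");
static_assert(KEY_BF16_CACHEMODE_3_QUANTMODE_0_INNER ==
                  (KEY_BF16_CACHEMODE_3_QUANTMODE_0 | (1u << 9)),
              "inner-out key is the base bf16 key with bit 9 set");
```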

@@ -15,6 +15,7 @@
#include "mla_preprocess_mix_fp16.hpp"
#include "mla_preprocess_mix_bf16.hpp"
#include "mla_preprocess_mix_bf16_qdown.hpp"
#include "../op_host/tiling/mla_preprocess_tiling.h"
@@ -23,7 +24,7 @@ extern "C" __global__ __aicore__ void mla_preprocess(
GM_ADDR bias1, GM_ADDR gamma2, GM_ADDR beta2, GM_ADDR quantScale2, GM_ADDR quantOffset2, GM_ADDR gamma3,
GM_ADDR sin1, GM_ADDR cos1, GM_ADDR sin2, GM_ADDR cos2, GM_ADDR keycache, GM_ADDR slotMapping, GM_ADDR wuq,
GM_ADDR bias2, GM_ADDR wuk, GM_ADDR descale1, GM_ADDR descale2, GM_ADDR ctkvScale, GM_ADDR qnopeScale, GM_ADDR q,
- GM_ADDR keycacheOut, GM_ADDR q2, GM_ADDR keycacheOut2, GM_ADDR workspace, GM_ADDR tiling)
+ GM_ADDR keycacheOut, GM_ADDR q2, GM_ADDR keycacheOut2, GM_ADDR innerOut, GM_ADDR workspace, GM_ADDR tiling)
{
#if defined(__CCE_KT_TEST__) || (__CCE_AICORE__ == 220)
PRELOAD(2);
@@ -218,6 +219,54 @@ extern "C" __global__ __aicore__ void mla_preprocess(
}
break;
}
+ case KEY_BF16_CACHEMODE_0_QUANTMODE_0_INNER: {
+ MLAPO_BF16_INNER::MLAOperation<__bf16, 0, DataFormat::NZ, DataFormat::NZ, DataFormat::ND,
+ QuantMode::PER_TENSOR_ASYMM_QUANT>
+ opBf16Cm0Qm0Inner(mlaTilingData, tiling);
+ opBf16Cm0Qm0Inner.Init(hiddenState, quantScale1, quantOffset1, wdqkv, bias1, gamma2, beta2,
+ quantScale2, quantOffset2, gamma3, sin1, cos1, sin2, cos2, keycache, slotMapping, wuq,
+ bias2, wuk, descale1, descale2, ctkvScale, qnopeScale, q, keycacheOut, q2, keycacheOut2,
+ s1, s2, s3, s4, s5, innerOut);
+ if ASCEND_IS_AIC {
+ opBf16Cm0Qm0Inner.ProcessCube();
+ }
+ if ASCEND_IS_AIV {
+ opBf16Cm0Qm0Inner.ProcessVector();
+ }
+ break;
+ }
+ case KEY_BF16_CACHEMODE_1_QUANTMODE_0_INNER: {
+ MLAPO_BF16_INNER::MLAOperation<__bf16, 1, DataFormat::NZ, DataFormat::NZ, DataFormat::ND,
+ QuantMode::PER_TENSOR_ASYMM_QUANT>
+ opBf16Cm1Qm0Inner(mlaTilingData, tiling);
+ opBf16Cm1Qm0Inner.Init(hiddenState, quantScale1, quantOffset1, wdqkv, bias1, gamma2, beta2,
+ quantScale2, quantOffset2, gamma3, sin1, cos1, sin2, cos2, keycache, slotMapping, wuq,
+ bias2, wuk, descale1, descale2, ctkvScale, qnopeScale, q, keycacheOut, q2, keycacheOut2,
+ s1, s2, s3, s4, s5, innerOut);
+ if ASCEND_IS_AIC {
+ opBf16Cm1Qm0Inner.ProcessCube();
+ }
+ if ASCEND_IS_AIV {
+ opBf16Cm1Qm0Inner.ProcessVector();
+ }
+ break;
+ }
+ case KEY_BF16_CACHEMODE_3_QUANTMODE_0_INNER: {
+ MLAPO_BF16_INNER::MLAOperation<__bf16, 3, DataFormat::NZ, DataFormat::NZ, DataFormat::ND,
+ QuantMode::PER_TENSOR_ASYMM_QUANT>
+ opBf16Cm3Qm0Inner(mlaTilingData, tiling);
+ opBf16Cm3Qm0Inner.Init(hiddenState, quantScale1, quantOffset1, wdqkv, bias1, gamma2, beta2,
+ quantScale2, quantOffset2, gamma3, sin1, cos1, sin2, cos2, keycache, slotMapping, wuq,
+ bias2, wuk, descale1, descale2, ctkvScale, qnopeScale, q, keycacheOut, q2, keycacheOut2,
+ s1, s2, s3, s4, s5, innerOut);
+ if ASCEND_IS_AIC {
+ opBf16Cm3Qm0Inner.ProcessCube();
+ }
+ if ASCEND_IS_AIV {
+ opBf16Cm3Qm0Inner.ProcessVector();
+ }
+ break;
+ }
default: {
break;
}
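
Putting the pieces together: the dispatch above reaches one of the new `MLAPO_BF16_INNER` specializations (defined in the newly included `mla_preprocess_mix_bf16_qdown.hpp`, presumably the large file whose diff is suppressed at the end of this page) only for bf16 inputs with per-tensor asymmetric quantization, cache mode 0, 1, or 3, and the inner-out bit set. A standalone predicate sketch, under the bit-layout assumptions stated earlier:

```cpp
#include <cstdint>

// Sketch only: assumes cache mode occupies bits 0-2, quant mode starts at
// bit 3, bf16 is bit 8, and enableInnerOut is bit 9 (per SetTilingKey).
constexpr bool RoutesToBf16InnerKernel(uint64_t tilingKey)
{
    const bool innerOut = ((tilingKey >> 9) & 1) != 0;
    const bool isBf16 = ((tilingKey >> 8) & 1) != 0;
    const uint64_t cacheMode = tilingKey & 0x7;
    const uint64_t quantMode = (tilingKey >> 3) & 0x1F;  // field width is an assumption
    return innerOut && isBf16 && quantMode == 0 /* PER_TENSOR_ASYMM_QUANT */ &&
           (cacheMode == 0 || cacheMode == 1 || cacheMode == 3);
}

// 259 + 512 matches KEY_BF16_CACHEMODE_3_QUANTMODE_0_INNER from the constants above.
static_assert(RoutesToBf16InnerKernel(259 + 512), "inner-out bf16 key reaches the new cases");
```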
@@ -256,6 +305,7 @@ extern void mla_preprocess_impl(
void* keycache_out,
void* q2,
void* keycache_out2,
+ void* inner_out,
void* workspace,
void* tiling,
const uint32_t block_dim)
@@ -288,6 +338,7 @@ extern void mla_preprocess_impl(
keycache_out,
q2,
keycache_out2,
+ inner_out,
workspace,
tiling);
}

(One file's diff was suppressed because it is too large.)