[feat] mlapo add bf16 no_quant support (#4852)

### What this PR does / why we need it?
This PR adds mlapo operation support for bf16 no_quant mode.

### Does this PR introduce _any_ user-facing change?
This PR makes quant related parameters optional. 
### How was this patch tested?
CI passed with new added/existing test.

- vLLM version: v0.12.0
- vLLM main:
ad32e3e19c

---------

Signed-off-by: chenjunyi <isjunyi.chen@gmail.com>
This commit is contained in:
chenjunyi
2025-12-11 11:06:56 +08:00
committed by GitHub
parent c95c271538
commit c12eb22cbe
12 changed files with 1510 additions and 81 deletions

View File

@@ -43,7 +43,6 @@ constexpr uint32_t L1_BIAS_SIZE = 2048;
constexpr uint32_t L0C_SIZE = 128 * 1024; constexpr uint32_t L0C_SIZE = 128 * 1024;
constexpr uint32_t CONCAT_SIZE = 512; constexpr uint32_t CONCAT_SIZE = 512;
constexpr uint32_t HIDDEN_STRATE = 7168;
constexpr uint32_t HIDDEN_STRATE_ROPE = 192; constexpr uint32_t HIDDEN_STRATE_ROPE = 192;
constexpr uint32_t HIDDEN_STRATE_MM = 2112; constexpr uint32_t HIDDEN_STRATE_MM = 2112;
constexpr uint32_t HIDDEN_STRATE_RMS = 1536; constexpr uint32_t HIDDEN_STRATE_RMS = 1536;
@@ -122,6 +121,8 @@ struct PlatformInfo {
}; };
struct OpParam { struct OpParam {
uint32_t isWeightQuantized;
uint32_t hiddenStateDim;
uint32_t N; uint32_t N;
uint32_t headNum; uint32_t headNum;
int32_t cacheMode; int32_t cacheMode;
@@ -392,7 +393,7 @@ private:
void MlaPreprocessTiling::RmsNormQuantTiling() void MlaPreprocessTiling::RmsNormQuantTiling()
{ {
tilingData->rmsNumCore1 = platformInfo.coreNumAiv; tilingData->rmsNumCore1 = platformInfo.coreNumAiv;
tilingData->rmsNumCol1 = HIDDEN_STRATE; tilingData->rmsNumCol1 = opParam.hiddenStateDim;
tilingData->rmsNumRow1 = opParam.N; tilingData->rmsNumRow1 = opParam.N;
tilingData->rmsQuantMin1 = -CONST_128; tilingData->rmsQuantMin1 = -CONST_128;
tilingData->rmsNumCore2 = platformInfo.coreNumAiv; tilingData->rmsNumCore2 = platformInfo.coreNumAiv;
@@ -508,9 +509,9 @@ void MlaPreprocessTiling::EinSumQuantTiling()
void MlaPreprocessTiling::SetMlapoWorkSpace() void MlaPreprocessTiling::SetMlapoWorkSpace()
{ {
uint64_t s1wsFactor = uint64_t s1wsFactor =
static_cast<uint64_t>(opParam.cacheMode == 2 ? std::max(HIDDEN_STRATE * sizeof(int8_t), static_cast<uint64_t>(opParam.cacheMode == 2 ? std::max(opParam.hiddenStateDim * sizeof(int8_t),
opParam.headNum * AXES_ALIGN_SIZE * sizeof(uint16_t)) opParam.headNum * AXES_ALIGN_SIZE * sizeof(uint16_t))
: HIDDEN_STRATE * sizeof(int8_t)); : opParam.hiddenStateDim * sizeof(int8_t));
uint64_t workSizeS1 = s1wsFactor; uint64_t workSizeS1 = s1wsFactor;
uint64_t workSizeS2 = opParam.headNum * HIDDEN_STRATE_ROPE * sizeof(uint16_t); uint64_t workSizeS2 = opParam.headNum * HIDDEN_STRATE_ROPE * sizeof(uint16_t);
uint64_t workSizeS3 = HIDDEN_STRATE_MM * sizeof(uint16_t); uint64_t workSizeS3 = HIDDEN_STRATE_MM * sizeof(uint16_t);
@@ -525,7 +526,8 @@ void MlaPreprocessTiling::SetMlapoWorkSpace()
uint64_t pertokenWorkspace = static_cast<uint64_t>(opParam.N) * sizeof(float) * 2; uint64_t pertokenWorkspace = static_cast<uint64_t>(opParam.N) * sizeof(float) * 2;
uint64_t userWorkspaceSize; uint64_t userWorkspaceSize;
if (opParam.inDtype == at::kBFloat16 || opParam.quantMode == QuantMode::PER_TOKEN_SYMM_QUANT) { if (opParam.isWeightQuantized == 1 &&
(opParam.inDtype == at::kBFloat16 || opParam.quantMode == QuantMode::PER_TOKEN_SYMM_QUANT)) {
userWorkspaceSize = 4 * maxWorkspaceSize + pertokenWorkspace; userWorkspaceSize = 4 * maxWorkspaceSize + pertokenWorkspace;
} else { } else {
userWorkspaceSize = 3 * maxWorkspaceSize; userWorkspaceSize = 3 * maxWorkspaceSize;
@@ -554,21 +556,23 @@ void MlaPreprocessTiling::Init()
{ {
tilingData->numCore = platformInfo.coreNumAic; tilingData->numCore = platformInfo.coreNumAic;
tilingData->n = opParam.N; tilingData->n = opParam.N;
tilingData->hiddenStateDim = opParam.hiddenStateDim;
tilingData->isWeightQuantized = opParam.isWeightQuantized;
bool enDequant = (opParam.isWeightQuantized == 1);
bool deqOnTheFly = false; bool deqOnTheFly = false;
if (opParam.inDtype == at::kBFloat16 || opParam.quantMode == QuantMode::PER_TOKEN_SYMM_QUANT) { if (enDequant && (opParam.inDtype == at::kBFloat16 || opParam.quantMode == QuantMode::PER_TOKEN_SYMM_QUANT)) {
deqOnTheFly = true; deqOnTheFly = true;
} }
PpMatmulTilingApi mm1TilingApi(platformInfo, PpMatmulTilingApi mm1TilingApi(platformInfo,
1, // numBatch 1, // numBatch
opParam.N, // m opParam.N, // m
HIDDEN_STRATE, // k opParam.hiddenStateDim, // k
HIDDEN_STRATE_MM, // n HIDDEN_STRATE_MM, // n
false, // transA false, // transA
true, // transB true, // transB
true, // enDequant enDequant, // enDequant
deqOnTheFly); // in bf16.cce? deqOnTheFly); // in bf16.cce?
mm1TilingApi.GetTilingData(tilingData->mm1); mm1TilingApi.GetTilingData(tilingData->mm1);
PpMatmulTilingApi mm2TilingApi(platformInfo, PpMatmulTilingApi mm2TilingApi(platformInfo,
@@ -578,7 +582,7 @@ void MlaPreprocessTiling::Init()
opParam.headNum * HIDDEN_STRATE_ROPE, // n opParam.headNum * HIDDEN_STRATE_ROPE, // n
false, // transA false, // transA
true, // transB true, // transB
true, // enDequant enDequant, // enDequant
deqOnTheFly); // in bf16.cce? deqOnTheFly); // in bf16.cce?
mm2TilingApi.GetTilingData(tilingData->mm2); mm2TilingApi.GetTilingData(tilingData->mm2);
@@ -609,6 +613,8 @@ std::unordered_map<c10::string_view, uint16_t> cache_mode_map = {
std::unordered_map<c10::string_view, uint16_t> quant_mode_map = { std::unordered_map<c10::string_view, uint16_t> quant_mode_map = {
{"per_tensor_quant_asymm", 0}, {"per_tensor_quant_asymm", 0},
{"per_token_quant_symm", 1}, {"per_token_quant_symm", 1},
{"per_token_quant_asymm", 2},
{"no_quant", 3}
}; };
template <typename MapType> template <typename MapType>
@@ -623,6 +629,7 @@ inline int get_op_mode(const MapType &mode_map, c10::optional<c10::string_view>
std::tuple<at::Tensor, at::Tensor, uint32_t> mla_preprocess_tiling( std::tuple<at::Tensor, at::Tensor, uint32_t> mla_preprocess_tiling(
const at::Tensor &hiddenState, const at::Tensor &hiddenState,
const at::Tensor &wdqkv,
const at::Tensor &wuk, const at::Tensor &wuk,
c10::optional<c10::string_view> cache_mode, c10::optional<c10::string_view> cache_mode,
c10::optional<c10::string_view> quant_mode, c10::optional<c10::string_view> quant_mode,
@@ -647,14 +654,21 @@ std::tuple<at::Tensor, at::Tensor, uint32_t> mla_preprocess_tiling(
int32_t N = hiddenState.sizes()[0]; int32_t N = hiddenState.sizes()[0];
int32_t headNum = wuk.sizes()[0]; int32_t headNum = wuk.sizes()[0];
uint32_t hiddenStateDim = hiddenState.sizes().back();
OpParam opParam; OpParam opParam;
opParam.hiddenStateDim = hiddenStateDim;
opParam.N = N; opParam.N = N;
opParam.headNum = headNum; opParam.headNum = headNum;
opParam.cacheMode = static_cast<int32_t>(cacheMode); opParam.cacheMode = static_cast<int32_t>(cacheMode);
opParam.quantMode = static_cast<QuantMode>(quantMode); opParam.quantMode = static_cast<QuantMode>(quantMode);
opParam.inDtype = hiddenState.options().dtype(); opParam.inDtype = hiddenState.options().dtype();
opParam.enableInnerOut = enable_inner_out; opParam.enableInnerOut = enable_inner_out;
if (wdqkv.options().dtype() == at::kBFloat16 || wdqkv.options().dtype() == at::kHalf) {
opParam.isWeightQuantized = 0;
} else {
opParam.isWeightQuantized = 1;
}
MlaTilingData tilingData; MlaTilingData tilingData;
MlaPreprocessTiling mlaTiling(platformInfo, opParam, &tilingData); MlaPreprocessTiling mlaTiling(platformInfo, opParam, &tilingData);

View File

@@ -90,6 +90,11 @@ struct MlaTilingData {
uint32_t esqHeadTail{0}; uint32_t esqHeadTail{0};
uint32_t esqColLoop{0}; uint32_t esqColLoop{0};
uint32_t esqColTail{0}; uint32_t esqColTail{0};
// hidden state dimension
uint32_t hiddenStateDim{7168};
uint32_t isWeightQuantized{1};
}; };
#endif // MLAPREPROCESS_TILING_H #endif // MLAPREPROCESS_TILING_H

View File

@@ -49,7 +49,6 @@ constexpr uint8_t CACHE_MODE_INT8_NZCACHE = 2; // high performance KV NZ format
constexpr uint8_t CACHE_MODE_NZCACHE = 3; constexpr uint8_t CACHE_MODE_NZCACHE = 3;
// pp matmul // pp matmul
constexpr uint32_t HIDDTEN_STATE = 7168;
constexpr uint32_t FLOAT_BLOCK_SIZE = 64; constexpr uint32_t FLOAT_BLOCK_SIZE = 64;
constexpr uint32_t HALF_BLOCK_SIZE = 64; constexpr uint32_t HALF_BLOCK_SIZE = 64;
constexpr uint32_t HALF_VECTOR_SIZE = 64; constexpr uint32_t HALF_VECTOR_SIZE = 64;
@@ -103,6 +102,7 @@ constexpr uint32_t KEY_FP16_CACHEMODE_1_QUANTMODE_0 = 1;
constexpr uint32_t KEY_BF16_CACHEMODE_0_QUANTMODE_0 = 256; constexpr uint32_t KEY_BF16_CACHEMODE_0_QUANTMODE_0 = 256;
constexpr uint32_t KEY_BF16_CACHEMODE_1_QUANTMODE_0 = 257; constexpr uint32_t KEY_BF16_CACHEMODE_1_QUANTMODE_0 = 257;
constexpr uint32_t KEY_BF16_CACHEMODE_3_QUANTMODE_0 = 259; constexpr uint32_t KEY_BF16_CACHEMODE_3_QUANTMODE_0 = 259;
constexpr uint32_t KEY_BF16_CACHEMODE_1_QUANTMODE_3 = 281;
constexpr uint32_t KEY_BF16_CACHEMODE_0_QUANTMODE_0_INNER = 256 + 512; constexpr uint32_t KEY_BF16_CACHEMODE_0_QUANTMODE_0_INNER = 256 + 512;
constexpr uint32_t KEY_BF16_CACHEMODE_1_QUANTMODE_0_INNER = 257 + 512; constexpr uint32_t KEY_BF16_CACHEMODE_1_QUANTMODE_0_INNER = 257 + 512;
constexpr uint32_t KEY_BF16_CACHEMODE_3_QUANTMODE_0_INNER = 259 + 512; constexpr uint32_t KEY_BF16_CACHEMODE_3_QUANTMODE_0_INNER = 259 + 512;

View File

@@ -16,6 +16,7 @@
#include "mla_preprocess_mix_fp16.hpp" #include "mla_preprocess_mix_fp16.hpp"
#include "mla_preprocess_mix_bf16.hpp" #include "mla_preprocess_mix_bf16.hpp"
#include "mla_preprocess_mix_bf16_qdown.hpp" #include "mla_preprocess_mix_bf16_qdown.hpp"
#include "mla_preprocess_mix_bf16_nq.hpp"
#include "../op_host/tiling/mla_preprocess_tiling.h" #include "../op_host/tiling/mla_preprocess_tiling.h"
@@ -42,6 +43,7 @@ extern "C" __global__ __aicore__ void mla_preprocess(
mlaTilingData.tilingKey = tilingData->tilingKey; mlaTilingData.tilingKey = tilingData->tilingKey;
mlaTilingData.n = tilingData->n; mlaTilingData.n = tilingData->n;
mlaTilingData.hiddenStateDim = tilingData->hiddenStateDim;
mlaTilingData.mm1.numBatch = tilingData->mm1.numBatch; mlaTilingData.mm1.numBatch = tilingData->mm1.numBatch;
mlaTilingData.mm1.m = tilingData->mm1.m; mlaTilingData.mm1.m = tilingData->mm1.m;
@@ -173,12 +175,12 @@ extern "C" __global__ __aicore__ void mla_preprocess(
} }
case KEY_BF16_CACHEMODE_0_QUANTMODE_0: { case KEY_BF16_CACHEMODE_0_QUANTMODE_0: {
MLAPO_BF16::MLAOperation<__bf16, 0, DataFormat::NZ, DataFormat::NZ, DataFormat::ND, MLAPO_BF16::MLAOperation<__bf16, 0, DataFormat::NZ, DataFormat::NZ, DataFormat::ND,
QuantMode::PER_TENSOR_ASYMM_QUANT> QuantMode::PER_TENSOR_ASYMM_QUANT>
opBf16Cm0Qm0(mlaTilingData, tiling); opBf16Cm0Qm0(mlaTilingData, tiling);
opBf16Cm0Qm0.Init(hiddenState, quantScale1, quantOffset1, wdqkv, bias1, gamma2, beta2, opBf16Cm0Qm0.Init(hiddenState, quantScale1, quantOffset1, wdqkv, bias1, gamma2, beta2,
quantScale2, quantOffset2, gamma3, sin1, cos1, sin2, cos2, keycache, slotMapping, wuq, quantScale2, quantOffset2, gamma3, sin1, cos1, sin2, cos2, keycache, slotMapping, wuq,
bias2, wuk, descale1, descale2, ctkvScale, qnopeScale, q, keycacheOut, q2, keycacheOut2, bias2, wuk, descale1, descale2, ctkvScale, qnopeScale, q, keycacheOut, q2, keycacheOut2,
s1, s2, s3, s4, s5); s1, s2, s3, s4, s5);
if ASCEND_IS_AIC { if ASCEND_IS_AIC {
opBf16Cm0Qm0.ProcessCube(); opBf16Cm0Qm0.ProcessCube();
} }
@@ -189,12 +191,12 @@ extern "C" __global__ __aicore__ void mla_preprocess(
} }
case KEY_BF16_CACHEMODE_1_QUANTMODE_0: { case KEY_BF16_CACHEMODE_1_QUANTMODE_0: {
MLAPO_BF16::MLAOperation<__bf16, 1, DataFormat::NZ, DataFormat::NZ, DataFormat::ND, MLAPO_BF16::MLAOperation<__bf16, 1, DataFormat::NZ, DataFormat::NZ, DataFormat::ND,
QuantMode::PER_TENSOR_ASYMM_QUANT> QuantMode::PER_TENSOR_ASYMM_QUANT>
opBf16Cm1Qm0(mlaTilingData, tiling); opBf16Cm1Qm0(mlaTilingData, tiling);
opBf16Cm1Qm0.Init(hiddenState, quantScale1, quantOffset1, wdqkv, bias1, gamma2, beta2, opBf16Cm1Qm0.Init(hiddenState, quantScale1, quantOffset1, wdqkv, bias1, gamma2, beta2,
quantScale2, quantOffset2, gamma3, sin1, cos1, sin2, cos2, keycache, slotMapping, wuq, quantScale2, quantOffset2, gamma3, sin1, cos1, sin2, cos2, keycache, slotMapping, wuq,
bias2, wuk, descale1, descale2, ctkvScale, qnopeScale, q, keycacheOut, q2, keycacheOut2, bias2, wuk, descale1, descale2, ctkvScale, qnopeScale, q, keycacheOut, q2, keycacheOut2,
s1, s2, s3, s4, s5); s1, s2, s3, s4, s5);
if ASCEND_IS_AIC { if ASCEND_IS_AIC {
opBf16Cm1Qm0.ProcessCube(); opBf16Cm1Qm0.ProcessCube();
} }
@@ -219,6 +221,21 @@ extern "C" __global__ __aicore__ void mla_preprocess(
} }
break; break;
} }
case KEY_BF16_CACHEMODE_1_QUANTMODE_3: {
MLAPO_BF16_NQ::MLAOperation<__bf16, 1, DataFormat::NZ, DataFormat::NZ, DataFormat::ND>
opBf16Cm1Qm0(mlaTilingData, tiling);
opBf16Cm1Qm0.Init(hiddenState, wdqkv, gamma2, beta2,
gamma3, sin1, cos1, sin2, cos2, keycache, slotMapping, wuq,
wuk, q, keycacheOut, q2, keycacheOut2,
s1, s2, s3);
if ASCEND_IS_AIC {
opBf16Cm1Qm0.ProcessCube();
}
if ASCEND_IS_AIV {
opBf16Cm1Qm0.ProcessVector();
}
break;
}
case KEY_BF16_CACHEMODE_0_QUANTMODE_0_INNER: { case KEY_BF16_CACHEMODE_0_QUANTMODE_0_INNER: {
MLAPO_BF16_INNER::MLAOperation<__bf16, 0, DataFormat::NZ, DataFormat::NZ, DataFormat::ND, MLAPO_BF16_INNER::MLAOperation<__bf16, 0, DataFormat::NZ, DataFormat::NZ, DataFormat::ND,
QuantMode::PER_TENSOR_ASYMM_QUANT> QuantMode::PER_TENSOR_ASYMM_QUANT>

View File

@@ -2386,6 +2386,7 @@ public:
this->num_row = mlaParams_.n; this->num_row = mlaParams_.n;
this->epsilon_ = 1e-6; this->epsilon_ = 1e-6;
this->mlaParams = mlaParams_; this->mlaParams = mlaParams_;
this->hiddenStateDim = mlaParams_.hiddenStateDim;
} }
__aicore__ inline void Init(GM_ADDR hiddenStateGm, GM_ADDR quantScale1Gm, __aicore__ inline void Init(GM_ADDR hiddenStateGm, GM_ADDR quantScale1Gm,
@@ -2692,6 +2693,7 @@ private:
uint32_t blockOffset; uint32_t blockOffset;
uint32_t perTaskNum; uint32_t perTaskNum;
uint32_t resTaskNum; uint32_t resTaskNum;
uint32_t hiddenStateDim;
MlaTilingData mlaParams; MlaTilingData mlaParams;
uint32_t num_core_; uint32_t num_core_;
@@ -2795,18 +2797,15 @@ MLAOperation<InDtype, CACHE_MODE, weightFormat1, weightFormat2, weightFormat3, q
uint32_t num_col_align_int8 = (num_col_1 + REPEAT_TIME_256 - 1) / REPEAT_TIME_256 * REPEAT_TIME_256; uint32_t num_col_align_int8 = (num_col_1 + REPEAT_TIME_256 - 1) / REPEAT_TIME_256 * REPEAT_TIME_256;
uint32_t num_col_align_f16 = (num_col_1 + REPEAT_TIME_128 - 1) / REPEAT_TIME_128 * REPEAT_TIME_128; uint32_t num_col_align_f16 = (num_col_1 + REPEAT_TIME_128 - 1) / REPEAT_TIME_128 * REPEAT_TIME_128;
uint32_t num_col_align_f32 = (num_col_1 + REPEAT_TIME_64 - 1) / REPEAT_TIME_64 * REPEAT_TIME_64; uint32_t num_col_align_f32 = (num_col_1 + REPEAT_TIME_64 - 1) / REPEAT_TIME_64 * REPEAT_TIME_64;
const uint32_t base_offset = hiddenStateDim * 6;
AscendC::LocalTensor<InDtype> input_tensor = buf.GetBuffer<BufferType::ASCEND_UB, InDtype>(0); AscendC::LocalTensor<InDtype> input_tensor = buf.GetBuffer<BufferType::ASCEND_UB, InDtype>(0);
AscendC::LocalTensor<InDtype> scale_tensor = AscendC::LocalTensor<InDtype> scale_tensor = buf.GetBuffer<BufferType::ASCEND_UB, InDtype>(base_offset);
buf.GetBuffer<BufferType::ASCEND_UB, InDtype>(HIDDTEN_STATE * 2 + HIDDTEN_STATE * 2 + HIDDTEN_STATE * 2); AscendC::LocalTensor<int8_t> offset_tensor = buf.GetBuffer<BufferType::ASCEND_UB, int8_t>(base_offset + 32);
AscendC::LocalTensor<int8_t> offset_tensor = buf.GetBuffer<BufferType::ASCEND_UB, int8_t>( AscendC::LocalTensor<float> res1_tensor = buf.GetBuffer<BufferType::ASCEND_UB, float>(base_offset + 64);
HIDDTEN_STATE * 2 + HIDDTEN_STATE * 2 + HIDDTEN_STATE * 2 + 32);
AscendC::LocalTensor<float> res1_tensor =
buf.GetBuffer<BufferType::ASCEND_UB, float>(HIDDTEN_STATE * 2 + HIDDTEN_STATE * 2 + HIDDTEN_STATE * 2 + 64);
AscendC::LocalTensor<float> res3_tensor = buf.GetBuffer<BufferType::ASCEND_UB, float>( AscendC::LocalTensor<float> res3_tensor = buf.GetBuffer<BufferType::ASCEND_UB, float>(
HIDDTEN_STATE * 2 + HIDDTEN_STATE * 2 + HIDDTEN_STATE * 2 + 64 + num_col_align_f32 * 4); base_offset + 64 + num_col_align_f32 * 4);
AscendC::LocalTensor<int8_t> output_tensor = buf.GetBuffer<BufferType::ASCEND_UB, int8_t>( AscendC::LocalTensor<int8_t> output_tensor = buf.GetBuffer<BufferType::ASCEND_UB, int8_t>(
HIDDTEN_STATE * 2 + HIDDTEN_STATE * 2 + HIDDTEN_STATE * 2 + 64 + num_col_align_f32 * 4 + base_offset + 64 + num_col_align_f32 * 4 + BUF_FACTOR * num_col_align_f32 * 4 + 64);
BUF_FACTOR * num_col_align_f32 * 4 + 64);
Quant1.Launch(output_tensor, input_tensor, scale_tensor, offset_tensor, res1_tensor, res3_tensor); Quant1.Launch(output_tensor, input_tensor, scale_tensor, offset_tensor, res1_tensor, res3_tensor);
} }
FftsCrossCoreSync<PIPE_MTE3, 0>(QUANT1); FftsCrossCoreSync<PIPE_MTE3, 0>(QUANT1);

File diff suppressed because it is too large Load Diff

View File

@@ -2406,6 +2406,7 @@ public:
this->num_row = mlaParams_.n; this->num_row = mlaParams_.n;
this->epsilon_ = 1e-6; this->epsilon_ = 1e-6;
this->mlaParams = mlaParams_; this->mlaParams = mlaParams_;
this->hiddenStateDim = mlaParams_.hiddenStateDim;
} }
__aicore__ inline void Init(GM_ADDR hiddenStateGm, GM_ADDR quantScale1Gm, __aicore__ inline void Init(GM_ADDR hiddenStateGm, GM_ADDR quantScale1Gm,
@@ -2713,6 +2714,7 @@ private:
uint32_t blockOffset; uint32_t blockOffset;
uint32_t perTaskNum; uint32_t perTaskNum;
uint32_t resTaskNum; uint32_t resTaskNum;
uint32_t hiddenStateDim;
MlaTilingData mlaParams; MlaTilingData mlaParams;
uint32_t num_core_; uint32_t num_core_;
@@ -2817,18 +2819,15 @@ MLAOperation<InDtype, CACHE_MODE, weightFormat1, weightFormat2, weightFormat3, q
uint32_t num_col_align_int8 = (num_col_1 + REPEAT_TIME_256 - 1) / REPEAT_TIME_256 * REPEAT_TIME_256; uint32_t num_col_align_int8 = (num_col_1 + REPEAT_TIME_256 - 1) / REPEAT_TIME_256 * REPEAT_TIME_256;
uint32_t num_col_align_f16 = (num_col_1 + REPEAT_TIME_128 - 1) / REPEAT_TIME_128 * REPEAT_TIME_128; uint32_t num_col_align_f16 = (num_col_1 + REPEAT_TIME_128 - 1) / REPEAT_TIME_128 * REPEAT_TIME_128;
uint32_t num_col_align_f32 = (num_col_1 + REPEAT_TIME_64 - 1) / REPEAT_TIME_64 * REPEAT_TIME_64; uint32_t num_col_align_f32 = (num_col_1 + REPEAT_TIME_64 - 1) / REPEAT_TIME_64 * REPEAT_TIME_64;
const uint32_t base_offset = hiddenStateDim * 6;
AscendC::LocalTensor<InDtype> input_tensor = buf.GetBuffer<BufferType::ASCEND_UB, InDtype>(0); AscendC::LocalTensor<InDtype> input_tensor = buf.GetBuffer<BufferType::ASCEND_UB, InDtype>(0);
AscendC::LocalTensor<InDtype> scale_tensor = AscendC::LocalTensor<InDtype> scale_tensor = buf.GetBuffer<BufferType::ASCEND_UB, InDtype>(base_offset);
buf.GetBuffer<BufferType::ASCEND_UB, InDtype>(HIDDTEN_STATE * 2 + HIDDTEN_STATE * 2 + HIDDTEN_STATE * 2); AscendC::LocalTensor<int8_t> offset_tensor = buf.GetBuffer<BufferType::ASCEND_UB, int8_t>(base_offset + 32);
AscendC::LocalTensor<int8_t> offset_tensor = buf.GetBuffer<BufferType::ASCEND_UB, int8_t>( AscendC::LocalTensor<float> res1_tensor = buf.GetBuffer<BufferType::ASCEND_UB, float>(base_offset + 64);
HIDDTEN_STATE * 2 + HIDDTEN_STATE * 2 + HIDDTEN_STATE * 2 + 32);
AscendC::LocalTensor<float> res1_tensor =
buf.GetBuffer<BufferType::ASCEND_UB, float>(HIDDTEN_STATE * 2 + HIDDTEN_STATE * 2 + HIDDTEN_STATE * 2 + 64);
AscendC::LocalTensor<float> res3_tensor = buf.GetBuffer<BufferType::ASCEND_UB, float>( AscendC::LocalTensor<float> res3_tensor = buf.GetBuffer<BufferType::ASCEND_UB, float>(
HIDDTEN_STATE * 2 + HIDDTEN_STATE * 2 + HIDDTEN_STATE * 2 + 64 + num_col_align_f32 * 4); base_offset + 64 + num_col_align_f32 * 4);
AscendC::LocalTensor<int8_t> output_tensor = buf.GetBuffer<BufferType::ASCEND_UB, int8_t>( AscendC::LocalTensor<int8_t> output_tensor = buf.GetBuffer<BufferType::ASCEND_UB, int8_t>(
HIDDTEN_STATE * 2 + HIDDTEN_STATE * 2 + HIDDTEN_STATE * 2 + 64 + num_col_align_f32 * 4 + base_offset + 64 + num_col_align_f32 * 4 + BUF_FACTOR * num_col_align_f32 * 4 + 64);
BUF_FACTOR * num_col_align_f32 * 4 + 64);
Quant1.Launch(output_tensor, input_tensor, scale_tensor, offset_tensor, res1_tensor, res3_tensor); Quant1.Launch(output_tensor, input_tensor, scale_tensor, offset_tensor, res1_tensor, res3_tensor);
} }
FftsCrossCoreSync<PIPE_MTE3, 0>(QUANT1); FftsCrossCoreSync<PIPE_MTE3, 0>(QUANT1);

View File

@@ -2034,6 +2034,7 @@ public:
this->num_row = mlaParams_.n; this->num_row = mlaParams_.n;
this->epsilon_ = 1e-6; this->epsilon_ = 1e-6;
this->mlaParams = mlaParams_; this->mlaParams = mlaParams_;
this->hiddenStateDim = mlaParams_.hiddenStateDim;
} }
__aicore__ inline void Init(GM_ADDR hiddenStateGm, GM_ADDR quantScale1Gm, __aicore__ inline void Init(GM_ADDR hiddenStateGm, GM_ADDR quantScale1Gm,
@@ -2294,6 +2295,7 @@ private:
uint32_t blockOffset; uint32_t blockOffset;
uint32_t perTaskNum; uint32_t perTaskNum;
uint32_t resTaskNum; uint32_t resTaskNum;
uint32_t hiddenStateDim;
MlaTilingData mlaParams; MlaTilingData mlaParams;
// rmsnormQuant // rmsnormQuant
@@ -2389,21 +2391,19 @@ __aicore__ inline void MLAOperation<cacheMode, weightFormat1, weightFormat2, wei
uint32_t num_col_align_int8 = (num_col_1 + REPEAT_TIME_256 - 1) / REPEAT_TIME_256 * REPEAT_TIME_256; uint32_t num_col_align_int8 = (num_col_1 + REPEAT_TIME_256 - 1) / REPEAT_TIME_256 * REPEAT_TIME_256;
uint32_t num_col_align_f16 = (num_col_1 + REPEAT_TIME_128 - 1) / REPEAT_TIME_128 * REPEAT_TIME_128; uint32_t num_col_align_f16 = (num_col_1 + REPEAT_TIME_128 - 1) / REPEAT_TIME_128 * REPEAT_TIME_128;
uint32_t num_col_align_f32 = (num_col_1 + REPEAT_TIME_64 - 1) / REPEAT_TIME_64 * REPEAT_TIME_64; uint32_t num_col_align_f32 = (num_col_1 + REPEAT_TIME_64 - 1) / REPEAT_TIME_64 * REPEAT_TIME_64;
const uint32_t gamma_offset = hiddenStateDim * 2;
const uint32_t beta_offset = gamma_offset + hiddenStateDim * 2;
const uint32_t scale_offset = beta_offset + hiddenStateDim * 2;
AscendC::LocalTensor<half> input_tensor = buf.GetBuffer<BufferType::ASCEND_UB, half>(0); AscendC::LocalTensor<half> input_tensor = buf.GetBuffer<BufferType::ASCEND_UB, half>(0);
AscendC::LocalTensor<half> gamma_tensor = buf.GetBuffer<BufferType::ASCEND_UB, half>(HIDDTEN_STATE * 2); AscendC::LocalTensor<half> gamma_tensor = buf.GetBuffer<BufferType::ASCEND_UB, half>(gamma_offset);
AscendC::LocalTensor<half> beta_tensor = AscendC::LocalTensor<half> beta_tensor = buf.GetBuffer<BufferType::ASCEND_UB, half>(beta_offset);
buf.GetBuffer<BufferType::ASCEND_UB, half>(HIDDTEN_STATE * 2 + HIDDTEN_STATE * 2); AscendC::LocalTensor<half> scale_tensor = buf.GetBuffer<BufferType::ASCEND_UB, half>(scale_offset);
AscendC::LocalTensor<half> scale_tensor = AscendC::LocalTensor<int8_t> offset_tensor = buf.GetBuffer<BufferType::ASCEND_UB, int8_t>(scale_offset + 32);
buf.GetBuffer<BufferType::ASCEND_UB, half>(HIDDTEN_STATE * 2 + HIDDTEN_STATE * 2 + HIDDTEN_STATE * 2); AscendC::LocalTensor<float> res1_tensor = buf.GetBuffer<BufferType::ASCEND_UB, float>(scale_offset + 64);
AscendC::LocalTensor<int8_t> offset_tensor = buf.GetBuffer<BufferType::ASCEND_UB, int8_t>(
HIDDTEN_STATE * 2 + HIDDTEN_STATE * 2 + HIDDTEN_STATE * 2 + 32);
AscendC::LocalTensor<float> res1_tensor =
buf.GetBuffer<BufferType::ASCEND_UB, float>(HIDDTEN_STATE * 2 + HIDDTEN_STATE * 2 + HIDDTEN_STATE * 2 + 64);
AscendC::LocalTensor<float> res3_tensor = buf.GetBuffer<BufferType::ASCEND_UB, float>( AscendC::LocalTensor<float> res3_tensor = buf.GetBuffer<BufferType::ASCEND_UB, float>(
HIDDTEN_STATE * 2 + HIDDTEN_STATE * 2 + HIDDTEN_STATE * 2 + 64 + num_col_align_f32 * 4); scale_offset + 64 + num_col_align_f32 * 4);
AscendC::LocalTensor<int8_t> output_tensor = buf.GetBuffer<BufferType::ASCEND_UB, int8_t>( AscendC::LocalTensor<int8_t> output_tensor = buf.GetBuffer<BufferType::ASCEND_UB, int8_t>(
HIDDTEN_STATE * 2 + HIDDTEN_STATE * 2 + HIDDTEN_STATE * 2 + 64 + num_col_align_f32 * 4 + scale_offset + 64 + num_col_align_f32 * 4 + BUF_FACTOR * num_col_align_f32 * 4 + 32);
BUF_FACTOR * num_col_align_f32 * 4 + 32);
Quant1.Launch(output_tensor, input_tensor, gamma_tensor, beta_tensor, scale_tensor, offset_tensor, res1_tensor, Quant1.Launch(output_tensor, input_tensor, gamma_tensor, beta_tensor, scale_tensor, offset_tensor, res1_tensor,
res3_tensor); res3_tensor);
} }

View File

@@ -176,15 +176,51 @@ std::tuple<at::Tensor, at::Tensor> rotary_embedding(at::Tensor &positions, at::T
std::tuple<at::Tensor &, at::Tensor &, at::Tensor &, at::Tensor &, at::Tensor &> mla_preprocess( std::tuple<at::Tensor &, at::Tensor &, at::Tensor &, at::Tensor &, at::Tensor &> mla_preprocess(
const at::Tensor &hiddenState, const at::Tensor &wdqkv, const at::Tensor &hiddenState, const at::Tensor &wdqkv,
const at::Tensor &descale0, const at::Tensor &gamma1, const at::Tensor &beta1, const at::Tensor &wuq, const c10::optional<at::Tensor> &descale0, const at::Tensor &gamma1, const c10::optional<at::Tensor> &beta1, const at::Tensor &wuq,
const at::Tensor &descale1, const at::Tensor &gamma2, const at::Tensor &cos, const at::Tensor &sin, const c10::optional<at::Tensor> &descale1, const at::Tensor &gamma2, const at::Tensor &cos, const at::Tensor &sin,
const at::Tensor &wuk, const at::Tensor &kv_cache, const at::Tensor &kv_cache_rope, const at::Tensor &slotmapping, const at::Tensor &wuk, const at::Tensor &kv_cache, const at::Tensor &kv_cache_rope, const at::Tensor &slotmapping,
const at::Tensor &quant_scale0, const at::Tensor &quant_offset0, const at::Tensor &bias0, const c10::optional<at::Tensor> &quant_scale0, const c10::optional<at::Tensor> &quant_offset0, const c10::optional<at::Tensor> &bias0,
const at::Tensor &quant_scale1, const at::Tensor &quant_offset1, const at::Tensor &bias1, const c10::optional<at::Tensor> &quant_scale1, const c10::optional<at::Tensor> &quant_offset1, const c10::optional<at::Tensor> &bias1,
const c10::optional<at::Tensor> &ctkv_scale, const c10::optional<at::Tensor> &q_nope_scale, const c10::optional<at::Tensor> &ctkv_scale, const c10::optional<at::Tensor> &q_nope_scale,
c10::optional<c10::string_view> cache_mode, c10::optional<c10::string_view> quant_mode, c10::optional<bool> enable_inner_out, at::Tensor &q_out0, c10::optional<c10::string_view> cache_mode, c10::optional<c10::string_view> quant_mode, c10::optional<bool> enable_inner_out, at::Tensor &q_out0,
at::Tensor &kv_cache_out0, at::Tensor &q_out1, at::Tensor &kv_cache_out1, at::Tensor &inner_out) at::Tensor &kv_cache_out0, at::Tensor &q_out1, at::Tensor &kv_cache_out1, at::Tensor &inner_out)
{ {
at::Tensor Descale0 =
descale0.has_value()
? descale0.value()
: at::empty({1}, at::TensorOptions().dtype(at::kHalf).device(hiddenState.options().device()));
at::Tensor Descale1 =
descale1.has_value()
? descale1.value()
: at::empty({1}, at::TensorOptions().dtype(at::kHalf).device(hiddenState.options().device()));
at::Tensor Beta1 =
beta1.has_value()
? beta1.value()
: at::empty({1}, at::TensorOptions().dtype(at::kHalf).device(hiddenState.options().device()));
at::Tensor Quant_scale0 =
quant_scale0.has_value()
? quant_scale0.value()
: at::empty({1}, at::TensorOptions().dtype(at::kHalf).device(hiddenState.options().device()));
at::Tensor Quant_scale1 =
quant_scale1.has_value()
? quant_scale1.value()
: at::empty({1}, at::TensorOptions().dtype(at::kHalf).device(hiddenState.options().device()));
at::Tensor Quant_offset0 =
quant_offset0.has_value()
? quant_offset0.value()
: at::empty({1}, at::TensorOptions().dtype(at::kHalf).device(hiddenState.options().device()));
at::Tensor Quant_offset1 =
quant_offset1.has_value()
? quant_offset1.value()
: at::empty({1}, at::TensorOptions().dtype(at::kHalf).device(hiddenState.options().device()));
at::Tensor Bias0 =
bias0.has_value()
? bias0.value()
: at::empty({1}, at::TensorOptions().dtype(at::kHalf).device(hiddenState.options().device()));
at::Tensor Bias1 =
bias1.has_value()
? bias1.value()
: at::empty({1}, at::TensorOptions().dtype(at::kHalf).device(hiddenState.options().device()));
at::Tensor CtkvScale = at::Tensor CtkvScale =
ctkv_scale.has_value() ctkv_scale.has_value()
? ctkv_scale.value() ? ctkv_scale.value()
@@ -200,6 +236,7 @@ std::tuple<at::Tensor &, at::Tensor &, at::Tensor &, at::Tensor &, at::Tensor &>
auto [workspace_tensor, tiling, block_dim] = mlapo::mla_preprocess_tiling( auto [workspace_tensor, tiling, block_dim] = mlapo::mla_preprocess_tiling(
hiddenState, hiddenState,
wdqkv,
wuk, wuk,
cache_mode, cache_mode,
quant_mode, quant_mode,
@@ -207,24 +244,24 @@ std::tuple<at::Tensor &, at::Tensor &, at::Tensor &, at::Tensor &, at::Tensor &>
); );
void *hidden_state_ptr = hiddenState.data_ptr(); void *hidden_state_ptr = hiddenState.data_ptr();
void *quant_scale0_ptr = quant_scale0.data_ptr(); void *quant_scale0_ptr = Quant_scale0.data_ptr();
void *quant_offset0_ptr = quant_offset0.data_ptr(); void *quant_offset0_ptr = Quant_offset0.data_ptr();
void *wdqkv_ptr = wdqkv.data_ptr(); void *wdqkv_ptr = wdqkv.data_ptr();
void *bias0_ptr = bias0.data_ptr(); void *bias0_ptr = Bias0.data_ptr();
void *gamma1_ptr = gamma1.data_ptr(); void *gamma1_ptr = gamma1.data_ptr();
void *beta1_ptr = beta1.data_ptr(); void *beta1_ptr = Beta1.data_ptr();
void *quant_scale1_ptr = quant_scale1.data_ptr(); void *quant_scale1_ptr = Quant_scale1.data_ptr();
void *quant_offset1_ptr = quant_offset1.data_ptr(); void *quant_offset1_ptr = Quant_offset1.data_ptr();
void *gamma2_ptr = gamma2.data_ptr(); void *gamma2_ptr = gamma2.data_ptr();
void *sin_ptr = sin.data_ptr(); void *sin_ptr = sin.data_ptr();
void *cos_ptr = cos.data_ptr(); void *cos_ptr = cos.data_ptr();
void *kv_cache_ptr = kv_cache.data_ptr(); void *kv_cache_ptr = kv_cache.data_ptr();
void *slotmapping_ptr = slotmapping.data_ptr(); void *slotmapping_ptr = slotmapping.data_ptr();
void *wuq_ptr = wuq.data_ptr(); void *wuq_ptr = wuq.data_ptr();
void *bias1_ptr = bias1.data_ptr(); void *bias1_ptr = Bias1.data_ptr();
void *wuk_ptr = wuk.data_ptr(); void *wuk_ptr = wuk.data_ptr();
void *descale0_ptr = descale0.data_ptr(); void *descale0_ptr = Descale0.data_ptr();
void *descale1_ptr = descale1.data_ptr(); void *descale1_ptr = Descale1.data_ptr();
void *ctkv_scale_ptr = CtkvScale.data_ptr(); void *ctkv_scale_ptr = CtkvScale.data_ptr();
void *qnope_scale_ptr = QnopeScale.data_ptr(); void *qnope_scale_ptr = QnopeScale.data_ptr();
void *q_out0_ptr = q_out0.data_ptr(); void *q_out0_ptr = q_out0.data_ptr();
@@ -1122,11 +1159,11 @@ TORCH_LIBRARY_EXPAND(CONCAT(_C, _ascend), ops)
ops.def( ops.def(
"mla_preprocess(Tensor hiddenState, Tensor wdqkv," "mla_preprocess(Tensor hiddenState, Tensor wdqkv,"
" Tensor descale0, Tensor gamma1, Tensor beta1, Tensor wuq, Tensor descale1," " Tensor? descale0, Tensor gamma1, Tensor? beta1, Tensor wuq, Tensor? descale1,"
" Tensor gamma2, Tensor cos, Tensor sin, Tensor wuk, Tensor kv_cache," " Tensor gamma2, Tensor cos, Tensor sin, Tensor wuk, Tensor kv_cache,"
" Tensor kv_cache_rope, Tensor slotmapping, Tensor quant_scale0," " Tensor kv_cache_rope, Tensor slotmapping, Tensor? quant_scale0,"
" Tensor quant_offset0, Tensor bias0, Tensor quant_scale1, Tensor quant_offset1," " Tensor? quant_offset0, Tensor? bias0, Tensor? quant_scale1, Tensor? quant_offset1,"
" Tensor bias1, Tensor? ctkv_scale, Tensor? q_nope_scale, str? cache_mode," " Tensor? bias1, Tensor? ctkv_scale, Tensor? q_nope_scale, str? cache_mode,"
" str? quant_mode, bool? enable_inner_out, Tensor! q_out0, Tensor! kv_cache_out0, Tensor! q_out1," " str? quant_mode, bool? enable_inner_out, Tensor! q_out0, Tensor! kv_cache_out0, Tensor! q_out1,"
" Tensor! kv_cache_out1, Tensor! inner_out) -> (Tensor q_out0, Tensor kv_cache_out0," " Tensor! kv_cache_out1, Tensor! inner_out) -> (Tensor q_out0, Tensor kv_cache_out0,"
" Tensor q_out1, Tensor kv_cache_out1, Tensor inner_out)" " Tensor q_out1, Tensor kv_cache_out1, Tensor inner_out)"

View File

@@ -84,11 +84,11 @@ at::Tensor sgmv_expand_meta(at::Tensor &x, at::Tensor &weight, at::Tensor &lora_
std::tuple<at::Tensor &, at::Tensor &, at::Tensor &, at::Tensor &, at::Tensor &> mla_preprocess( std::tuple<at::Tensor &, at::Tensor &, at::Tensor &, at::Tensor &, at::Tensor &> mla_preprocess(
const at::Tensor &hiddenState, const at::Tensor &hiddenState,
const at::Tensor &wdqkv, const at::Tensor &wdqkv,
const at::Tensor &descale0, const c10::optional<at::Tensor> &descale0,
const at::Tensor &gamma1, const at::Tensor &gamma1,
const at::Tensor &beta1, const c10::optional<at::Tensor> &beta1,
const at::Tensor &wuq, const at::Tensor &wuq,
const at::Tensor &descale1, const c10::optional<at::Tensor> &descale1,
const at::Tensor &gamma2, const at::Tensor &gamma2,
const at::Tensor &cos, const at::Tensor &cos,
const at::Tensor &sin, const at::Tensor &sin,
@@ -96,12 +96,12 @@ std::tuple<at::Tensor &, at::Tensor &, at::Tensor &, at::Tensor &, at::Tensor &>
const at::Tensor &kv_cache, const at::Tensor &kv_cache,
const at::Tensor &kv_cache_rope, const at::Tensor &kv_cache_rope,
const at::Tensor &slotmapping, const at::Tensor &slotmapping,
const at::Tensor &quant_scale0, const c10::optional<at::Tensor> &quant_scale0,
const at::Tensor &quant_offset0, const c10::optional<at::Tensor> &quant_offset0,
const at::Tensor &bias0, const c10::optional<at::Tensor> &bias0,
const at::Tensor &quant_scale1, const c10::optional<at::Tensor> &quant_scale1,
const at::Tensor &quant_offset1, const c10::optional<at::Tensor> &quant_offset1,
const at::Tensor &bias1, const c10::optional<at::Tensor> &bias1,
const c10::optional<at::Tensor> &ctkv_scale, const c10::optional<at::Tensor> &ctkv_scale,
const c10::optional<at::Tensor> &q_nope_scale, const c10::optional<at::Tensor> &q_nope_scale,
c10::optional<c10::string_view> cache_mode, c10::optional<c10::string_view> cache_mode,

View File

@@ -67,6 +67,11 @@ def test_mla_preprocess_kernel():
dtype=hidden_states.dtype, dtype=hidden_states.dtype,
device=hidden_states.device, device=hidden_states.device,
) )
q_down = torch.empty(
(hidden_states.shape[0], 1536),
dtype=hidden_states.dtype,
device=hidden_states.device,
)
q_nope_old = q_nope_out.clone() q_nope_old = q_nope_out.clone()
q_rope_old = q_rope_out.clone() q_rope_old = q_rope_out.clone()
@@ -95,10 +100,12 @@ def test_mla_preprocess_kernel():
q_nope_scale=qnope_scale, q_nope_scale=qnope_scale,
cache_mode="krope_ctkv", cache_mode="krope_ctkv",
quant_mode="per_tensor_quant_asymm", quant_mode="per_tensor_quant_asymm",
enable_inner_out=False,
q_out0=q_nope_out, q_out0=q_nope_out,
kv_cache_out0=kv_cache, kv_cache_out0=kv_cache,
q_out1=q_rope_out, q_out1=q_rope_out,
kv_cache_out1=kv_cache_rope, kv_cache_out1=kv_cache_rope,
inner_out=q_down,
) )
assert not torch.equal(q_nope_out, q_nope_old) assert not torch.equal(q_nope_out, q_nope_old)
assert not torch.equal(q_rope_out, q_rope_old) assert not torch.equal(q_rope_out, q_rope_old)

View File

@@ -0,0 +1,99 @@
import gc
import torch
import torch_npu
from vllm_ascend.utils import enable_custom_op
enable_custom_op()
@torch.inference_mode()
def test_mla_preprocess_kernel():
token_num = 1
head_num = 2
N_7168 = 7168
block_num = 1
block_size = 128
dtype = torch.bfloat16
hidden_states = torch.randn((token_num, N_7168), dtype=dtype).npu()
wdqkv = torch.randint(0, 7, (1, 448, 2112, 16), dtype=dtype).npu()
wdqkv = torch_npu.npu_format_cast(wdqkv.contiguous(), 29)
gamma1 = torch.randn((1536), dtype=dtype).npu()
wuq = torch.randint(0, 7, (1, 96, head_num * 192, 16), dtype=dtype).npu()
wuq = torch_npu.npu_format_cast(wuq.contiguous(), 29)
gamma2 = torch.randn((512), dtype=dtype).npu()
cos = torch.randn((token_num, 64), dtype=dtype).npu()
sin = torch.randn((token_num, 64), dtype=dtype).npu()
wuk = torch.randn((head_num, 128, 512), dtype=dtype).npu()
# wuk = torch_npu.npu_format_cast(wuk, 29)
kv_cache = torch.randint(0,
7,
(block_num, head_num * 512 // 32, block_size, 32),
dtype=dtype).npu()
kv_cache_rope = torch.randn(
(block_num, head_num * 64 // 16, block_size, 16), dtype=dtype).npu()
slotmapping = torch.randint(0, 7, (token_num, ), dtype=torch.int32).npu()
q_nope_out = torch.empty(
(hidden_states.shape[0], wuk.shape[0], kv_cache.shape[-1]),
dtype=hidden_states.dtype,
device=hidden_states.device,
)
q_rope_out = torch.empty(
(hidden_states.shape[0], wuk.shape[0], kv_cache_rope.shape[-1]),
dtype=hidden_states.dtype,
device=hidden_states.device,
)
q_down = torch.empty(
(hidden_states.shape[0], 1536),
dtype=hidden_states.dtype,
device=hidden_states.device,
)
q_nope_old = q_nope_out.clone()
q_rope_old = q_rope_out.clone()
torch.ops._C_ascend.mla_preprocess(
hidden_states,
wdqkv,
None,
gamma1,
None,
wuq,
None,
gamma2,
cos,
sin,
wuk,
kv_cache,
kv_cache_rope,
slotmapping,
None,
None,
None,
None,
None,
None,
None,
None,
cache_mode="krope_ctkv",
quant_mode="no_quant",
enable_inner_out=False,
q_out0=q_nope_out,
kv_cache_out0=kv_cache,
q_out1=q_rope_out,
kv_cache_out1=kv_cache_rope,
inner_out=q_down,
)
assert not torch.equal(q_nope_out, q_nope_old)
assert not torch.equal(q_rope_out, q_rope_old)
gc.collect()
torch.npu.empty_cache()
torch.npu.reset_peak_memory_stats()