[feat] mlapo add bf16 no_quant support (#4852)
### What this PR does / why we need it?
This PR adds mlapo operation support for bf16 no_quant mode.
### Does this PR introduce _any_ user-facing change?
This PR makes the quantization-related parameters optional.
### How was this patch tested?
CI passed with newly added and existing tests.
- vLLM version: v0.12.0
- vLLM main:
ad32e3e19c
---------
Signed-off-by: chenjunyi <isjunyi.chen@gmail.com>
This commit is contained in:
@@ -16,6 +16,7 @@
|
||||
#include "mla_preprocess_mix_fp16.hpp"
|
||||
#include "mla_preprocess_mix_bf16.hpp"
|
||||
#include "mla_preprocess_mix_bf16_qdown.hpp"
|
||||
#include "mla_preprocess_mix_bf16_nq.hpp"
|
||||
|
||||
#include "../op_host/tiling/mla_preprocess_tiling.h"
|
||||
|
||||
@@ -42,6 +43,7 @@ extern "C" __global__ __aicore__ void mla_preprocess(
|
||||
|
||||
mlaTilingData.tilingKey = tilingData->tilingKey;
|
||||
mlaTilingData.n = tilingData->n;
|
||||
mlaTilingData.hiddenStateDim = tilingData->hiddenStateDim;
|
||||
|
||||
mlaTilingData.mm1.numBatch = tilingData->mm1.numBatch;
|
||||
mlaTilingData.mm1.m = tilingData->mm1.m;
|
||||
@@ -173,12 +175,12 @@ extern "C" __global__ __aicore__ void mla_preprocess(
|
||||
}
|
||||
case KEY_BF16_CACHEMODE_0_QUANTMODE_0: {
|
||||
MLAPO_BF16::MLAOperation<__bf16, 0, DataFormat::NZ, DataFormat::NZ, DataFormat::ND,
|
||||
QuantMode::PER_TENSOR_ASYMM_QUANT>
|
||||
QuantMode::PER_TENSOR_ASYMM_QUANT>
|
||||
opBf16Cm0Qm0(mlaTilingData, tiling);
|
||||
opBf16Cm0Qm0.Init(hiddenState, quantScale1, quantOffset1, wdqkv, bias1, gamma2, beta2,
|
||||
quantScale2, quantOffset2, gamma3, sin1, cos1, sin2, cos2, keycache, slotMapping, wuq,
|
||||
bias2, wuk, descale1, descale2, ctkvScale, qnopeScale, q, keycacheOut, q2, keycacheOut2,
|
||||
s1, s2, s3, s4, s5);
|
||||
quantScale2, quantOffset2, gamma3, sin1, cos1, sin2, cos2, keycache, slotMapping, wuq,
|
||||
bias2, wuk, descale1, descale2, ctkvScale, qnopeScale, q, keycacheOut, q2, keycacheOut2,
|
||||
s1, s2, s3, s4, s5);
|
||||
if ASCEND_IS_AIC {
|
||||
opBf16Cm0Qm0.ProcessCube();
|
||||
}
|
||||
@@ -189,12 +191,12 @@ extern "C" __global__ __aicore__ void mla_preprocess(
|
||||
}
|
||||
case KEY_BF16_CACHEMODE_1_QUANTMODE_0: {
|
||||
MLAPO_BF16::MLAOperation<__bf16, 1, DataFormat::NZ, DataFormat::NZ, DataFormat::ND,
|
||||
QuantMode::PER_TENSOR_ASYMM_QUANT>
|
||||
QuantMode::PER_TENSOR_ASYMM_QUANT>
|
||||
opBf16Cm1Qm0(mlaTilingData, tiling);
|
||||
opBf16Cm1Qm0.Init(hiddenState, quantScale1, quantOffset1, wdqkv, bias1, gamma2, beta2,
|
||||
quantScale2, quantOffset2, gamma3, sin1, cos1, sin2, cos2, keycache, slotMapping, wuq,
|
||||
bias2, wuk, descale1, descale2, ctkvScale, qnopeScale, q, keycacheOut, q2, keycacheOut2,
|
||||
s1, s2, s3, s4, s5);
|
||||
quantScale2, quantOffset2, gamma3, sin1, cos1, sin2, cos2, keycache, slotMapping, wuq,
|
||||
bias2, wuk, descale1, descale2, ctkvScale, qnopeScale, q, keycacheOut, q2, keycacheOut2,
|
||||
s1, s2, s3, s4, s5);
|
||||
if ASCEND_IS_AIC {
|
||||
opBf16Cm1Qm0.ProcessCube();
|
||||
}
|
||||
@@ -219,6 +221,21 @@ extern "C" __global__ __aicore__ void mla_preprocess(
|
||||
}
|
||||
break;
|
||||
}
|
||||
case KEY_BF16_CACHEMODE_1_QUANTMODE_3: {
|
||||
MLAPO_BF16_NQ::MLAOperation<__bf16, 1, DataFormat::NZ, DataFormat::NZ, DataFormat::ND>
|
||||
opBf16Cm1Qm0(mlaTilingData, tiling);
|
||||
opBf16Cm1Qm0.Init(hiddenState, wdqkv, gamma2, beta2,
|
||||
gamma3, sin1, cos1, sin2, cos2, keycache, slotMapping, wuq,
|
||||
wuk, q, keycacheOut, q2, keycacheOut2,
|
||||
s1, s2, s3);
|
||||
if ASCEND_IS_AIC {
|
||||
opBf16Cm1Qm0.ProcessCube();
|
||||
}
|
||||
if ASCEND_IS_AIV {
|
||||
opBf16Cm1Qm0.ProcessVector();
|
||||
}
|
||||
break;
|
||||
}
|
||||
case KEY_BF16_CACHEMODE_0_QUANTMODE_0_INNER: {
|
||||
MLAPO_BF16_INNER::MLAOperation<__bf16, 0, DataFormat::NZ, DataFormat::NZ, DataFormat::ND,
|
||||
QuantMode::PER_TENSOR_ASYMM_QUANT>
|
||||
|
||||
Reference in New Issue
Block a user