[Cherry-pick]bmm_transpose to v011dev (#3995)
### What this PR does / why we need it?
Add a custom op to accelerate the DeepSeek model. The fused op combines
the bmm and the transpose into a single kernel, which is applied in the MLA module.
Cherry-picked from commit c68ddc11ce
### Does this PR introduce _any_ user-facing change?
No
---------
Signed-off-by: hust17yixuan <303660421@qq.com>
This commit is contained in:
155
csrc/batch_matmul_transpose/op_host/tiling/tiling_data.cpp
Normal file
155
csrc/batch_matmul_transpose/op_host/tiling/tiling_data.cpp
Normal file
@@ -0,0 +1,155 @@
|
||||
#include <map>
|
||||
#include "tiling_data.h"
|
||||
#include "common.h"
|
||||
#include "common_tiling.h"
|
||||
|
||||
namespace pp_matmul {
|
||||
|
||||
// Amount End() subtracts from L1AB_PINGPONG_BUFFER_LEN for int8 kernels before sizing k0.
constexpr uint32_t L1_DESCALE_BUFFER_LEN_MAX{6144};

// Named literals used by the tiling heuristics in this file.
constexpr uint32_t CONST_3{3};      // with CONST_4: the "coreNum / 4 * 3" threshold in SetBaseOp
constexpr uint32_t CONST_4{4};      // see CONST_3
constexpr uint32_t CONST_16{16};    // RoundUp granularity for mBase / nBase in SetBaseOp
constexpr uint32_t CONST_32{32};    // RoundUp granularity for m0 / n0 in End() on the int8 path
constexpr uint32_t CONST_256{256};  // cap on maxN0 and multiplier in the L1 size check in SetBaseOp
constexpr uint32_t CONST_512{512};  // End() rounds k0 down to a multiple of this once k0 exceeds it
|
||||
|
||||
const std::map<TensorDType, uint32_t> G_DTYPE_MAP = {{TensorDType::TENSOR_DTYPE_FLOAT16, 1u},
|
||||
{TensorDType::TENSOR_DTYPE_BF16, 2u}};
|
||||
const std::map<TensorFormat, uint32_t> G_FORMAT_MAP = {{TensorFormat::TENSOR_FORMAT_ND, 0u},
|
||||
{TensorFormat::TENSOR_FORMAT_NZ, 1u}};
|
||||
using MmType = MatMul::MatMulType;
|
||||
using QmType = MatMul::QuantMode;
|
||||
using namespace host_utils;
|
||||
|
||||
bool IsI8Bf16Kernel(const MatMulInfo &mmInfo)
|
||||
{
|
||||
bool isI8Bf16 = mmInfo.isInt8 && mmInfo.dtypeC == TensorDType::TENSOR_DTYPE_BF16;
|
||||
bool isI8Fp16 = mmInfo.isInt8 && mmInfo.dtypeC == TensorDType::TENSOR_DTYPE_FLOAT16 &&
|
||||
mmInfo.quantMode == QmType::PER_TOKEN_SYMM;
|
||||
return isI8Bf16 || isI8Fp16;
|
||||
}
|
||||
|
||||
/**
 * Fills the hardware description from the PlatformInfo singleton accessor.
 * Core count and buffer sizes are copied from the platform; the two bandwidth
 * fields are fixed relative weights.
 */
HardwareInfo::HardwareInfo()
{
    const auto &info = PlatformInfo::Instance();

    // Relative bandwidth weights: HBM is the unit, L2 counts as 5x faster than HBM.
    hbmBandWidth = 1;
    l2BandWidth = 5;

    // Compute resources and on-chip buffer sizes as reported by the platform.
    coreNum = info.coreNumAic;
    l2Size = info.l2Size;
    l1Size = info.l1Size;
    l0aSize = info.l0aSize;
    l0bSize = info.l0bSize;
    l0cSize = info.l0cSize;
}
|
||||
|
||||
void PpMatmulTilingData::SetBaseShape(uint32_t batchSize, uint32_t m, uint32_t k, uint32_t n)
|
||||
{
|
||||
opShape.batchSize = batchSize;
|
||||
opShape.m = m;
|
||||
opShape.k = k;
|
||||
opShape.n = n;
|
||||
}
|
||||
|
||||
void PpMatmulTilingData::SetBaseOp(uint32_t coreNum, uint32_t mBase, uint32_t nBase, const MatMulInfo &mmInfo)
|
||||
{
|
||||
opShape.m0 = mBase;
|
||||
opShape.n0 = nBase;
|
||||
mLoop = CeilDiv(opShape.m, opShape.m0);
|
||||
nLoop = CeilDiv(opShape.n, opShape.n0);
|
||||
coreLoop = opShape.batchSize * mLoop * nLoop;
|
||||
|
||||
if (mLoop == 1 && mmInfo.transB && coreLoop % coreNum < coreNum / CONST_4 * CONST_3) {
|
||||
mBase = RoundUp<uint32_t>(opShape.m, CONST_16);
|
||||
opShape.m0 = mBase;
|
||||
uint32_t maxN0 = PlatformInfo::Instance().l0cSize / (mBase * sizeof(float));
|
||||
if (mmInfo.isInt8 || mmInfo.mmType == MmType::MATMUL_WITH_BIAS) {
|
||||
maxN0 = maxN0 < CONST_256 ? maxN0 : CONST_256;
|
||||
}
|
||||
uint32_t x = CeilDiv(opShape.n, coreNum);
|
||||
uint32_t y = CeilDiv(x, maxN0);
|
||||
nBase = RoundUp<uint32_t>(CeilDiv(x, y), CONST_16);
|
||||
uint32_t rqdL0CSize = mBase * nBase * sizeof(float);
|
||||
if (rqdL0CSize < PlatformInfo::Instance().l0cSize &&
|
||||
(mBase + nBase) * CONST_256 * sizeof(uint16_t) < L1AB_PINGPONG_BUFFER_LEN) {
|
||||
opShape.n0 = nBase;
|
||||
nLoop = CeilDiv(opShape.n, opShape.n0);
|
||||
coreLoop = opShape.batchSize * nLoop;
|
||||
}
|
||||
}
|
||||
blockDim = std::min(coreLoop, coreNum);
|
||||
}
|
||||
|
||||
// transA transB quantMode [dtype] format
|
||||
void PpMatmulTilingData::SetTilingKey(const MatMulInfo &mmInfo, uint32_t swizzleDirect, uint32_t enSplitK)
|
||||
{
|
||||
if (mmInfo.mmType == MmType::MATMUL_ACCUM_ATOMIC || mmInfo.mmType == MmType::MATMUL_WITH_BIAS ||
|
||||
mmInfo.mmType == MmType::MATMUL_EIN_SUM || mmInfo.mmType == MmType::MATMUL_DEQUANT || IsI8Bf16Kernel(mmInfo)) {
|
||||
// SwizzleDir[1] TransA[1] TransB[1] DtypeA[3] DtypeB[3] DtypeC[3] FormatA[1] FormatB[1] FormatC[1] WithBias[1]
|
||||
tilingKey = swizzleDirect;
|
||||
tilingKey = (tilingKey << 1) + static_cast<uint32_t>(mmInfo.transA);
|
||||
tilingKey = (tilingKey << 1) + static_cast<uint32_t>(mmInfo.transB);
|
||||
tilingKey = (tilingKey << 3) + G_DTYPE_MAP.at(mmInfo.dtypeA); // 3bit for dtypeA.
|
||||
tilingKey = (tilingKey << 3) + G_DTYPE_MAP.at(mmInfo.dtypeB); // 3bit for dtypeB.
|
||||
tilingKey = (tilingKey << 3) + G_DTYPE_MAP.at(mmInfo.dtypeC); // 3bit for dtypeC.
|
||||
tilingKey = (tilingKey << 1) + G_FORMAT_MAP.at(mmInfo.formatA);
|
||||
tilingKey = (tilingKey << 1) + G_FORMAT_MAP.at(mmInfo.formatB);
|
||||
tilingKey = (tilingKey << 1) + G_FORMAT_MAP.at(mmInfo.formatC);
|
||||
tilingKey = (tilingKey << 1) + static_cast<uint32_t>(mmInfo.biasFlag);
|
||||
} else {
|
||||
tilingKey = swizzleDirect;
|
||||
tilingKey = (tilingKey << 1) + static_cast<uint32_t>(mmInfo.transA);
|
||||
tilingKey = (tilingKey << 1) + static_cast<uint32_t>(mmInfo.transB);
|
||||
tilingKey = (tilingKey << 1) + static_cast<uint32_t>(mmInfo.isInt8);
|
||||
tilingKey = (tilingKey << 1) + static_cast<uint32_t>(mmInfo.biasFlag);
|
||||
tilingKey = (tilingKey << 1) + enSplitK;
|
||||
}
|
||||
}
|
||||
|
||||
uint32_t PpMatmulTilingData::End(const MatMulInfo &mmInfo)
|
||||
{
|
||||
uint32_t cubeBlockSize = mmInfo.isInt8 ? CUBE_BLOCK_SIZE_INT8 : CUBE_BLOCK_SIZE;
|
||||
uint32_t kBlockSize = mmInfo.isInt8 ? BLOCK_SIZE_INT8_K : BLOCK_SIZE;
|
||||
uint32_t scaleBlockSize = mmInfo.isInt8 ? L1_DESCALE_BUFFER_LEN_MAX : 0;
|
||||
uint32_t shapeSum = opShape.m0 + opShape.n0;
|
||||
if (mmInfo.isInt8 && (mmInfo.transA || !mmInfo.transB)) {
|
||||
shapeSum = RoundUp<uint32_t>(opShape.m0, CONST_32) + RoundUp<uint32_t>(opShape.n0, CONST_32);
|
||||
}
|
||||
uint32_t k0Max = shapeSum == 0
|
||||
? L1AB_PINGPONG_BUFFER_LEN
|
||||
: static_cast<uint32_t>(static_cast<float>(L1AB_PINGPONG_BUFFER_LEN - scaleBlockSize) /
|
||||
(shapeSum * mmInfo.inDtype));
|
||||
if (mmInfo.mmType == MatMul::MatMulType::MATMUL_WITH_BIAS) {
|
||||
uint32_t l1AbSize = L1AB_PINGPONG_BUFFER_LEN - opShape.n0 * sizeof(float);
|
||||
k0Max = l1AbSize / (shapeSum * mmInfo.inDtype);
|
||||
}
|
||||
|
||||
opShape.k0 =
|
||||
k0Max < cubeBlockSize ? RoundDown<uint32_t>(k0Max, kBlockSize) : RoundDown<uint32_t>(k0Max, cubeBlockSize);
|
||||
if (opShape.k0 > CONST_512) {
|
||||
opShape.k0 = RoundDown<uint32_t>(opShape.k0, CONST_512);
|
||||
}
|
||||
kLoop = CeilDiv(opShape.k, opShape.k0);
|
||||
return blockDim;
|
||||
}
|
||||
|
||||
void GetPpMatmulTiling(const MatMulInfo &mmInfo, const HardwareInfo &hwInfo, uint32_t &blockDim,
|
||||
PpMatmulTilingData &tilingData)
|
||||
{
|
||||
OpShape opShape;
|
||||
opShape.batchSize = mmInfo.batchSize;
|
||||
opShape.m = mmInfo.m;
|
||||
opShape.n = mmInfo.n;
|
||||
opShape.k = mmInfo.k;
|
||||
tilingData.opShape = opShape;
|
||||
tilingData.quantMode = static_cast<uint32_t>(mmInfo.quantMode);
|
||||
tilingData.SetTilingKey(mmInfo, 0, 0); // init tilingkey with transA transB.
|
||||
if (opShape.m < opShape.n) {
|
||||
TilingFunc<false, OpShape, PpMatmulTilingData, HardwareInfo, MatMulInfo>(opShape, tilingData, hwInfo, mmInfo);
|
||||
} else {
|
||||
TilingFunc<true, OpShape, PpMatmulTilingData, HardwareInfo, MatMulInfo>(opShape, tilingData, hwInfo, mmInfo);
|
||||
}
|
||||
uint32_t direct = Swizzl<PpMatmulTilingData>(tilingData);
|
||||
blockDim = tilingData.End(mmInfo);
|
||||
tilingData.SetTilingKey(mmInfo, direct, 0);
|
||||
}
|
||||
} // namespace pp_matmul
|
||||
90
csrc/batch_matmul_transpose/op_host/tiling/tiling_data.h
Normal file
90
csrc/batch_matmul_transpose/op_host/tiling/tiling_data.h
Normal file
@@ -0,0 +1,90 @@
|
||||
#ifndef PP_MATMUL_TILING_DATA
|
||||
#define PP_MATMUL_TILING_DATA
|
||||
#include <cstdint>
|
||||
|
||||
namespace pp_matmul {
|
||||
// Scope holder for the matmul mode enumerations.
struct MatMul {
    // Variant of the matmul computation.
    enum class MatMulType : uint32_t {
        MATMUL_DEFAULT = 0,       // C = op(A) * op(B)
        MATMUL_DEQUANT = 1,       // dequantizing variant
        MATMUL_ACCUM_ATOMIC = 2,  // C += op(A) * op(B)
        MATMUL_WITH_BIAS = 3,     // C = op(A) * op(B) + Bias, where Bias is a vector.
        MATMUL_EIN_SUM = 4
    };

    // Quantization scheme.
    enum class QuantMode : uint32_t {
        PER_CHANNEL_SYMM = 0,
        PER_CHANNEL_ASYMM = 1,
        PER_TOKEN_SYMM = 2
    };
};
|
||||
|
||||
// Element types a tensor may carry.
enum class TensorDType : uint32_t {
    TENSOR_DTYPE_FLOAT16 = 0,
    TENSOR_DTYPE_BF16 = 1
};
|
||||
|
||||
// Layout formats a tensor may use.
enum class TensorFormat : uint32_t {
    TENSOR_FORMAT_ND = 0,
    TENSOR_FORMAT_NZ = 1
};
|
||||
|
||||
struct MatMulInfo {
|
||||
uint32_t batchSize{0};
|
||||
uint32_t m{0}; // actual input m
|
||||
uint32_t k{0}; // actual input k
|
||||
uint32_t n{0}; // actual input n
|
||||
TensorDType dtypeA{TensorDType::TENSOR_DTYPE_FLOAT16};
|
||||
TensorDType dtypeB{TensorDType::TENSOR_DTYPE_FLOAT16};
|
||||
TensorDType dtypeC{TensorDType::TENSOR_DTYPE_FLOAT16};
|
||||
TensorFormat formatA{TensorFormat::TENSOR_FORMAT_ND};
|
||||
TensorFormat formatB{TensorFormat::TENSOR_FORMAT_ND};
|
||||
TensorFormat formatC{TensorFormat::TENSOR_FORMAT_ND};
|
||||
MatMul::MatMulType mmType{MatMul::MatMulType::MATMUL_DEFAULT};
|
||||
bool transA{0}; // false: 0, true: 1
|
||||
bool transB{0}; // false: 0, true: 1
|
||||
bool biasFlag{0}; // false: 0, true: 1
|
||||
bool isInt8{0}; // false: 0, true: 1
|
||||
float inDtype{0};
|
||||
float outDtype{0};
|
||||
MatMul::QuantMode quantMode{MatMul::QuantMode::PER_CHANNEL_SYMM};
|
||||
};
|
||||
|
||||
// Problem dimensions together with the tile sizes chosen for them.
struct OpShape {
    // Full problem extents.
    uint32_t batchSize = 0;
    uint32_t m = 0;
    uint32_t k = 0;
    uint32_t n = 0;

    // Tile extents along M, K and N.
    uint32_t m0 = 0;
    uint32_t k0 = 0;
    uint32_t n0 = 0;
};
|
||||
|
||||
// Hardware resources the tiling heuristics take into account.
struct HardwareInfo {
    uint32_t coreNum = 0;

    // Buffer sizes for each memory level.
    uint32_t l2Size = 0;
    uint32_t l1Size = 0;
    uint32_t l0aSize = 0;
    uint32_t l0bSize = 0;
    uint32_t l0cSize = 0;

    // Bandwidth figures for HBM and L2.
    uint32_t hbmBandWidth = 0;
    uint32_t l2BandWidth = 0;

    // Defined out of line; populates the fields above.
    HardwareInfo();
};
|
||||
|
||||
#pragma pack(push, 1)
|
||||
struct PpMatmulTilingData {
|
||||
OpShape opShape{};
|
||||
uint32_t mLoop{1};
|
||||
uint32_t kLoop{1};
|
||||
uint32_t nLoop{1};
|
||||
uint32_t coreLoop{1};
|
||||
uint32_t swizzlCount{1};
|
||||
uint32_t tilingKey{0};
|
||||
uint32_t blockDim{1};
|
||||
uint32_t swizzlDirect{0};
|
||||
uint32_t splitk{0};
|
||||
uint32_t enShuffleK{0};
|
||||
uint32_t quantMode{0};
|
||||
|
||||
void SetBaseShape(uint32_t batchSize, uint32_t m, uint32_t k, uint32_t n);
|
||||
void SetBaseOp(uint32_t coreNum, uint32_t mBase, uint32_t nBase, const MatMulInfo &mmInfo);
|
||||
void SetTilingKey(const MatMulInfo &mmInfo, uint32_t swizzleDirect, uint32_t enSplitK);
|
||||
uint32_t End(const MatMulInfo &mmInfo);
|
||||
};
|
||||
#pragma pack(pop)
|
||||
|
||||
void GetPpMatmulTiling(const MatMulInfo &mmInfo, const HardwareInfo &hwInfo, uint32_t &blockDim,
|
||||
PpMatmulTilingData &tilingData);
|
||||
} // namespace pp_matmul
|
||||
#endif
|
||||
Reference in New Issue
Block a user