[Cherry-pick]bmm_transpose to v011dev (#3995)
### What this PR does / why we need it?
Add a custom op to acclerater the deepseek model. The fusion ops combine
the bmm and transpose together, which is applied to mla module.
Cherry-pick from this commtid c68ddc11ce
### Does this PR introduce _any_ user-facing change?
No
---------
Signed-off-by: hust17yixuan <303660421@qq.com>
This commit is contained in:
90
csrc/batch_matmul_transpose/op_host/tiling/tiling_data.h
Normal file
90
csrc/batch_matmul_transpose/op_host/tiling/tiling_data.h
Normal file
@@ -0,0 +1,90 @@
|
||||
#ifndef PP_MATMUL_TILING_DATA
|
||||
#define PP_MATMUL_TILING_DATA
|
||||
#include <cstdint>
|
||||
|
||||
namespace pp_matmul {
|
||||
struct MatMul {
|
||||
enum class MatMulType : uint32_t {
|
||||
MATMUL_DEFAULT = 0, // C = op(A) * op(B)
|
||||
MATMUL_DEQUANT, //
|
||||
MATMUL_ACCUM_ATOMIC, // C += op(A) * op(B)
|
||||
MATMUL_WITH_BIAS, // C = op(A) * op(B) + Bias, where Bias is a vector.
|
||||
MATMUL_EIN_SUM
|
||||
};
|
||||
enum class QuantMode : uint32_t { PER_CHANNEL_SYMM = 0, PER_CHANNEL_ASYMM, PER_TOKEN_SYMM };
|
||||
};
|
||||
|
||||
enum class TensorDType : uint32_t { TENSOR_DTYPE_FLOAT16 = 0, TENSOR_DTYPE_BF16 };
|
||||
|
||||
enum class TensorFormat : uint32_t { TENSOR_FORMAT_ND = 0, TENSOR_FORMAT_NZ };
|
||||
|
||||
struct MatMulInfo {
|
||||
uint32_t batchSize{0};
|
||||
uint32_t m{0}; // actual input m
|
||||
uint32_t k{0}; // actual input k
|
||||
uint32_t n{0}; // actual input n
|
||||
TensorDType dtypeA{TensorDType::TENSOR_DTYPE_FLOAT16};
|
||||
TensorDType dtypeB{TensorDType::TENSOR_DTYPE_FLOAT16};
|
||||
TensorDType dtypeC{TensorDType::TENSOR_DTYPE_FLOAT16};
|
||||
TensorFormat formatA{TensorFormat::TENSOR_FORMAT_ND};
|
||||
TensorFormat formatB{TensorFormat::TENSOR_FORMAT_ND};
|
||||
TensorFormat formatC{TensorFormat::TENSOR_FORMAT_ND};
|
||||
MatMul::MatMulType mmType{MatMul::MatMulType::MATMUL_DEFAULT};
|
||||
bool transA{0}; // false: 0, true: 1
|
||||
bool transB{0}; // false: 0, true: 1
|
||||
bool biasFlag{0}; // false: 0, true: 1
|
||||
bool isInt8{0}; // false: 0, true: 1
|
||||
float inDtype{0};
|
||||
float outDtype{0};
|
||||
MatMul::QuantMode quantMode{MatMul::QuantMode::PER_CHANNEL_SYMM};
|
||||
};
|
||||
|
||||
struct OpShape {
|
||||
uint32_t batchSize{0};
|
||||
uint32_t m{0};
|
||||
uint32_t k{0};
|
||||
uint32_t n{0};
|
||||
uint32_t m0{0};
|
||||
uint32_t k0{0};
|
||||
uint32_t n0{0};
|
||||
};
|
||||
|
||||
struct HardwareInfo {
|
||||
uint32_t coreNum{0};
|
||||
uint32_t l2Size{0};
|
||||
uint32_t l1Size{0};
|
||||
uint32_t l0aSize{0};
|
||||
uint32_t l0bSize{0};
|
||||
uint32_t l0cSize{0};
|
||||
uint32_t hbmBandWidth{0};
|
||||
uint32_t l2BandWidth{0};
|
||||
|
||||
HardwareInfo();
|
||||
};
|
||||
|
||||
#pragma pack(push, 1)
|
||||
struct PpMatmulTilingData {
|
||||
OpShape opShape{};
|
||||
uint32_t mLoop{1};
|
||||
uint32_t kLoop{1};
|
||||
uint32_t nLoop{1};
|
||||
uint32_t coreLoop{1};
|
||||
uint32_t swizzlCount{1};
|
||||
uint32_t tilingKey{0};
|
||||
uint32_t blockDim{1};
|
||||
uint32_t swizzlDirect{0};
|
||||
uint32_t splitk{0};
|
||||
uint32_t enShuffleK{0};
|
||||
uint32_t quantMode{0};
|
||||
|
||||
void SetBaseShape(uint32_t batchSize, uint32_t m, uint32_t k, uint32_t n);
|
||||
void SetBaseOp(uint32_t coreNum, uint32_t mBase, uint32_t nBase, const MatMulInfo &mmInfo);
|
||||
void SetTilingKey(const MatMulInfo &mmInfo, uint32_t swizzleDirect, uint32_t enSplitK);
|
||||
uint32_t End(const MatMulInfo &mmInfo);
|
||||
};
|
||||
#pragma pack(pop)
|
||||
|
||||
void GetPpMatmulTiling(const MatMulInfo &mmInfo, const HardwareInfo &hwInfo, uint32_t &blockDim,
|
||||
PpMatmulTilingData &tilingData);
|
||||
} // namespace pp_matmul
|
||||
#endif
|
||||
Reference in New Issue
Block a user