[Cherry-pick]bmm_transpose to v011dev (#3995)

### What this PR does / why we need it?
Add a custom op to acclerater the deepseek model. The fusion ops combine
the bmm and transpose together, which is applied to mla module.
Cherry-pick from this commtid c68ddc11ce

### Does this PR introduce _any_ user-facing change?
No

---------

Signed-off-by: hust17yixuan <303660421@qq.com>
This commit is contained in:
Wang Yixuan
2025-12-08 19:22:14 +08:00
committed by GitHub
parent 6391f0625f
commit d412565ec9
15 changed files with 1736 additions and 13 deletions

View File

@@ -0,0 +1,123 @@
#include <iostream>
#include <string>
#include "acl/acl.h"
#include "kernel_tiling/kernel_tiling.h"
#include "tiling/platform/platform_ascendc.h"
#include "tiling/tiling_data.h"
#include "common_tiling.h"
namespace bmm_trans {
using namespace pp_matmul;
std::unordered_map<c10::string_view, uint16_t> quantModeMap = {
{"per_channel_symm", 0},
{"per_channel_asymm", 1},
{"per_token_symm", 2},
};
std::unordered_map<c10::string_view, uint16_t> formatModeMap = {
{"ND", 0},
{"NZ", 1},
};
std::unordered_map<c10::ScalarType, TensorDType> atType2tensorDType = {
{at::ScalarType::BFloat16, TensorDType::TENSOR_DTYPE_BF16},
{at::ScalarType::Half, TensorDType::TENSOR_DTYPE_FLOAT16}};
// batch size -> memory index
constexpr uint32_t MAX_CAPTURE_NUM = 1024;
template <typename MapType>
inline int GetModeVal(const MapType &mode_map, c10::optional<c10::string_view> mode_opt, c10::string_view default_mode,
const char *mode_name)
{
std::string modeStr(mode_name);
c10::string_view mode_str = mode_opt.value_or(default_mode);
auto it = mode_map.find(mode_str);
// if input mode is unsupported, use default value
TORCH_CHECK(it != mode_map.end(), modeStr, c10::str(": Unsupported mode value ", mode_str));
return it->second;
}
std::tuple<at::Tensor, uint32_t> batch_matmul_transpose_tiling(const at::Tensor &tensor_a, const at::Tensor &tensor_b, at::Tensor &tensor_c,
c10::optional<c10::string_view> format_mode,
c10::optional<c10::string_view> quant_mode)
{
auto tensorAShape = tensor_a.sizes();
auto tensorBShape = tensor_b.sizes();
auto tensorCShape = tensor_c.sizes();
uint32_t n;
uint32_t block_dim;
//auto &platform = PlatformInfo::Instance();
HardwareInfo hwInfo;
std::map<c10::ScalarType, float> dTypeMap = {{at::ScalarType::Half, 2.0}, {at::ScalarType::BFloat16, 2.0}};
at::ScalarType aType = tensor_a.scalar_type();
at::ScalarType bType = tensor_b.scalar_type();
at::ScalarType cType = tensor_c.scalar_type();
TORCH_CHECK(aType == bType && bType == cType, "tensor type is not the same");
TORCH_CHECK((aType == at::ScalarType::BFloat16) || (aType == at::ScalarType::Half),
"tensor type only support half or bf16");
TensorFormat formatMode = static_cast<TensorFormat>(GetModeVal(formatModeMap, format_mode, "ND", "format_mode"));
MatMul::QuantMode quantMode =
static_cast<MatMul::QuantMode>(GetModeVal(quantModeMap, quant_mode, "per_channel_symm", "quant_mode"));
TORCH_CHECK(tensorAShape.size() == 3, "batch size is not same between srcTensor and dstTensor");
if (formatMode == TensorFormat::TENSOR_FORMAT_ND) {
TORCH_CHECK(tensorBShape.size() == 3, "tensor shape should be dim3 in ND format");
TORCH_CHECK(tensorAShape[2] == tensorBShape[1], "tensor shape is wrong");
n = tensorBShape[2];
} else {
TORCH_CHECK(tensorBShape.size() == 4, "tensor shape should be dim4 in nz format");
TORCH_CHECK(tensorAShape[2] == tensorBShape[2], "tensor shape is wrong");
n = tensorBShape[1] * tensorBShape[3];
}
TORCH_CHECK(tensorAShape[1] == tensorBShape[0], "tensor shape is wrong");
OpShape opShape = {.batchSize = static_cast<uint32_t>(tensorAShape[1]),
.m = static_cast<uint32_t>(tensorAShape[0]),
.k = static_cast<uint32_t>(tensorAShape[2]),
.n = n};
pp_matmul::PpMatmulTilingData matmulTilingData = {
.opShape = opShape,
};
auto dType = atType2tensorDType[aType];
MatMulInfo mmInfo = {.batchSize = opShape.batchSize,
.m = opShape.m,
.k = opShape.k,
.n = opShape.n,
.dtypeA = dType,
.dtypeB = dType,
.dtypeC = dType,
.formatB = formatMode,
.mmType = MatMul::MatMulType::MATMUL_EIN_SUM,
.inDtype = dTypeMap[aType],
.outDtype = dTypeMap[cType],
.quantMode = quantMode};
GetPpMatmulTiling(mmInfo, hwInfo, block_dim, matmulTilingData);
host_utils::PpMatmulTilingCheck(matmulTilingData);
// tiling
int32_t batchIdx = opShape.m - 1;
uint32_t tilingSize = sizeof(pp_matmul::PpMatmulTilingData);
static auto global_tiling_data = at::empty(
{tilingSize * MAX_CAPTURE_NUM}, at::TensorOptions().dtype(at::kByte).device(tensor_a.options().device()));
if (batchIdx >= 0 && batchIdx < MAX_CAPTURE_NUM) {
aclrtMemcpy(global_tiling_data.data_ptr<uint8_t>() + (tilingSize * batchIdx), tilingSize, &matmulTilingData,
tilingSize, ACL_MEMCPY_HOST_TO_DEVICE);
} else {
// Handle the case where batchIdx is out of range
TORCH_CHECK(false, "batchIdx is out of range: ", batchIdx);
}
at::Tensor tiling_tensor =
at::from_blob(global_tiling_data.data_ptr<uint8_t>() + (tilingSize * batchIdx), tilingSize, at::kByte);
return std::make_tuple(tiling_tensor, block_dim);
}
}

View File

@@ -0,0 +1,57 @@
// Licensed under the BSD 3-Clause License (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#ifndef UTILS_COMMON_H
#define UTILS_COMMON_H
namespace host_utils {
constexpr uint32_t BLK_SIZE_ALIN_FOR_INT64 = 4;
constexpr uint32_t BLK_SIZE_ALIN_FOR_INT32 = 8;
inline uint64_t alinInt64Count(uint64_t count)
{
return (count + BLK_SIZE_ALIN_FOR_INT64 - 1) / BLK_SIZE_ALIN_FOR_INT64 * BLK_SIZE_ALIN_FOR_INT64;
}
inline uint64_t alinInt32Count(uint64_t count)
{
return (count + BLK_SIZE_ALIN_FOR_INT32 - 1) / BLK_SIZE_ALIN_FOR_INT32 * BLK_SIZE_ALIN_FOR_INT32;
}
template <typename T>
inline T CeilDiv(const T dividend, const T divisor)
{
if (divisor == 0) {
return UINT32_MAX;
}
return (dividend + divisor - 1) / divisor;
}
template <typename T>
inline T RoundUp(const T val, const T align = 16)
{
if (align == 0 || val + align - 1 < val) {
return 0;
}
return (val + align - 1) / align * align;
}
template <typename T>
inline T RoundDown(const T val, const T align = 16)
{
if (align == 0) {
return 0;
}
return val / align * align;
}
} // namespace host_utils
#endif // UTILS_COMMON_H

View File

@@ -0,0 +1,239 @@
/*
* Copyright (c) 2024 Huawei Technologies Co., Ltd.
* This file is a part of the CANN Open Software.
* Licensed under CANN Open Software License Agreement Version 1.0 (the "License").
* Please refer to the License for details. You may not use this file except in compliance with the License.
* THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED,
* INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE.
* See LICENSE in the root of the software repository for the full text of the License.
*/
#ifndef COMMMON_TILING_H
#define COMMMON_TILING_H
#include <iostream>
#include <cmath>
#include "common.h"
#include "tiling/platform/platform_ascendc.h"
namespace host_utils {
constexpr uint32_t FP16_SIZE = 2;
constexpr uint32_t FP32_SIZE = 4;
constexpr uint32_t BLOCK_SIZE = 16;
constexpr uint32_t BLOCK_SIZE_INT8_K = 32;
constexpr uint32_t BASE_BLOCK_STEP = 2;
constexpr uint32_t AXES_ALIGN_SIZE = 512;
constexpr uint32_t AXES_ALIGN_SIZE_INT8 = 256;
constexpr uint32_t ND_SHAPE_SIZE = 2;
constexpr uint32_t NZ_SHAPE_SIZE = 4;
constexpr uint32_t CUBE_BLOCK_SIZE = 256;
constexpr uint32_t CUBE_BLOCK_SIZE_INT8 = 512;
constexpr uint32_t L1AB_PINGPONG_BUFFER_LEN = 262144;
constexpr uint32_t L0AB_PINGPONG_BUFFER_LEN_INT8 = 131072 * 2; // 256 KB
constexpr uint32_t L0AB_PINGPONG_BUFFER_LEN_FP16 = 131072; // 128 KB
constexpr uint32_t L1AB_PINGPONG_BUFFER_LEN_INT8_SPARSE = 160 * 1024;
constexpr uint32_t UB_LIMIT_SIZE_910A = 128 * 1024;
enum class PlatformType { ASCEND_310P, ASCEND_910A, ASCEND_910B, ASCEND_910C, PLATFORM_INVALID };
struct PlatformInfo {
public:
static const PlatformInfo &Instance()
{
static PlatformInfo platformInfo;
return platformInfo;
}
PlatformType socType;
uint32_t coreNum;
uint32_t coreNumAic;
uint32_t coreNumAiv;
uint64_t ubSize;
uint64_t l1Size;
uint64_t l2Size;
uint64_t l0aSize;
uint64_t l0bSize;
uint64_t l0cSize;
private:
PlatformInfo()
{
auto ascendcPlatform = platform_ascendc::PlatformAscendCManager::GetInstance();
// TODO Hard coding set to 910_93xx, parse using aclrtGetSocName is better
socType = PlatformType::ASCEND_910C;
coreNum = ascendcPlatform->GetCoreNum();
coreNumAic = ascendcPlatform->GetCoreNumAic();
coreNumAiv = ascendcPlatform->GetCoreNumAiv();
ascendcPlatform->GetCoreMemSize(platform_ascendc::CoreMemType::UB, ubSize);
ascendcPlatform->GetCoreMemSize(platform_ascendc::CoreMemType::L1, l1Size);
ascendcPlatform->GetCoreMemSize(platform_ascendc::CoreMemType::L2, l2Size);
ascendcPlatform->GetCoreMemSize(platform_ascendc::CoreMemType::L0_A, l0aSize);
ascendcPlatform->GetCoreMemSize(platform_ascendc::CoreMemType::L0_B, l0bSize);
ascendcPlatform->GetCoreMemSize(platform_ascendc::CoreMemType::L0_C, l0cSize);
}
PlatformInfo(const PlatformInfo &) = delete;
PlatformInfo &operator=(const PlatformInfo &) = delete;
PlatformInfo(PlatformInfo &&) = delete;
PlatformInfo &operator=(PlatformInfo &&) = delete;
};
inline __attribute__((always_inline)) uint32_t GetN0TilingLimit(bool compressFlag, uint32_t tilingN,
const PlatformType &platformType)
{
if (compressFlag) {
return std::min(tilingN * BLOCK_SIZE, AXES_ALIGN_SIZE_INT8);
} else {
return (platformType == PlatformType::ASCEND_310P || platformType == PlatformType::ASCEND_910A)
? AXES_ALIGN_SIZE
: AXES_ALIGN_SIZE_INT8;
}
}
template <typename OpShareType>
inline __attribute__((always_inline)) uint32_t GetN0TilingInit(const OpShareType &opShape, bool compressFlag,
uint32_t tilingN)
{
const uint32_t rnd = 16;
return compressFlag
? ((tilingN * BLOCK_SIZE > opShape.n) ? RoundUp<uint32_t>(opShape.n, rnd) : tilingN * BLOCK_SIZE)
: BLOCK_SIZE;
}
template <bool PRI_FLAG>
inline __attribute__((always_inline)) bool IsExceedTilingLimit(uint32_t axes0, uint32_t priAxes0,
uint32_t n0TilingLimit, PlatformType platformType,
uint32_t basicBlockSize)
{
return (PRI_FLAG && axes0 > n0TilingLimit) || (!PRI_FLAG && priAxes0 > n0TilingLimit) ||
(platformType == PlatformType::ASCEND_910A && basicBlockSize > UB_LIMIT_SIZE_910A);
}
template <bool PRI_FLAG, typename OpShareType>
inline __attribute__((always_inline)) void SetOpShapeAxesInfo(OpShareType &opShape, uint32_t priAxes0, uint32_t axes0)
{
opShape.m0 = PRI_FLAG ? priAxes0 : axes0;
opShape.n0 = PRI_FLAG ? axes0 : priAxes0;
}
template <typename HardwareType, typename OpShapeType>
inline __attribute__((always_inline)) float CostFunc(const HardwareType &hwInfor, OpShapeType &shape)
{
float aCoef = 1;
float bCoef = 1;
float bwCoef = static_cast<float>(hwInfor.l2BandWidth) / static_cast<float>(hwInfor.hbmBandWidth);
uint32_t mLoop = CeilDiv(shape.m, shape.m0);
uint32_t nLoop = CeilDiv(shape.n, shape.n0);
if (mLoop == 0 || nLoop == 0) {
return 1;
}
uint32_t coreNeed = shape.batchSize * mLoop * nLoop;
uint32_t blockDim = std::min(coreNeed, hwInfor.coreNum);
uint32_t mOnce = blockDim < nLoop ? shape.m0 : blockDim / nLoop * shape.m0;
uint32_t nOnce = blockDim < nLoop ? hwInfor.coreNum * shape.n0 : shape.n;
if (mOnce * shape.k * FP16_SIZE > hwInfor.l2Size) {
aCoef = bwCoef;
}
if (nOnce * shape.k * FP16_SIZE > hwInfor.l2Size) {
bCoef = bwCoef;
}
return 1 / (aCoef * static_cast<float>(shape.n0)) + 1 / (bCoef * static_cast<float>(shape.m0));
}
template <bool PRI_FLAG, typename OpShareType, typename TilingType, typename HardwareType, typename MatMulInfoType>
void TilingFunc(OpShareType &opShape, TilingType &tilingParam, const HardwareType &hwInfor,
const MatMulInfoType &mmInfo, bool compressFlag = false, const uint32_t tilingN = 1)
{
float costMin = 1;
const float CONST_2 = 2.0;
const uint32_t ROUND_CONST_16 = 16;
uint32_t roundBase = static_cast<uint32_t>(
pow(2, ceil(log(CeilDiv(PRI_FLAG ? opShape.n : opShape.m, ROUND_CONST_16)))) * ROUND_CONST_16);
uint32_t priAxes = RoundUp<uint32_t>(PRI_FLAG ? opShape.m : opShape.n, ROUND_CONST_16);
uint32_t axes = RoundUp<uint32_t>(PRI_FLAG ? opShape.n : opShape.m, roundBase);
float axes0Max = static_cast<float>(AXES_ALIGN_SIZE) / mmInfo.inDtype;
auto platformType = PlatformInfo::Instance().socType;
if (mmInfo.isInt8 && (platformType == PlatformType::ASCEND_310P || platformType == PlatformType::ASCEND_910A)) {
axes0Max /= CONST_2;
}
uint32_t n0TilingInit = GetN0TilingInit(opShape, compressFlag, tilingN);
uint32_t n0TilingLimit = GetN0TilingLimit(compressFlag, tilingN, platformType);
uint32_t priAxes0Init = PRI_FLAG ? BLOCK_SIZE : n0TilingInit;
uint32_t axes0Init = PRI_FLAG ? n0TilingInit : BLOCK_SIZE;
for (uint32_t priAxes0 = priAxes0Init; priAxes0 <= priAxes && priAxes0 <= axes0Max; priAxes0 *= BASE_BLOCK_STEP) {
for (uint32_t axes0 = axes0Init; axes0 <= axes && axes0 <= axes0Max; axes0 *= BASE_BLOCK_STEP) {
uint32_t basicBlockSize = priAxes0 * axes0 * FP32_SIZE;
if (basicBlockSize > hwInfor.l0cSize) {
continue;
}
if (mmInfo.isInt8 &&
IsExceedTilingLimit<PRI_FLAG>(axes0, priAxes0, n0TilingLimit, platformType, basicBlockSize)) {
continue;
}
SetOpShapeAxesInfo<PRI_FLAG>(opShape, priAxes0, axes0);
float cost = CostFunc<HardwareType, OpShareType>(hwInfor, opShape);
if (cost >= costMin) {
continue;
}
costMin = cost;
if constexpr (std::is_same<TilingType, pp_matmul::PpMatmulTilingData>::value) {
tilingParam.SetBaseOp(hwInfor.coreNum, opShape.m0, opShape.n0, mmInfo);
} else {
tilingParam.SetBaseOp(hwInfor.coreNum, opShape.m0, opShape.n0);
}
}
}
}
template <typename PpTilingDataType>
uint32_t Swizzl(PpTilingDataType &tilingData)
{
uint32_t swizzlDirect = 0;
uint32_t swizzlCount = 1;
float m0 = tilingData.opShape.m0;
float n0 = tilingData.opShape.n0;
float m = tilingData.opShape.m;
float k = tilingData.opShape.k;
float n = tilingData.opShape.n;
float mincost = m * k + k * n;
for (uint32_t i = 1; i <= tilingData.blockDim; ++i) {
int c = static_cast<int32_t>((tilingData.blockDim + i - 1) / i);
float cost;
// B0 + A < A0 + B
if (i * n0 + m < m0 * c + n) {
swizzlDirect = 1; // Nz
cost = n0 * i + m0 * c;
if (cost <= mincost) {
mincost = cost;
swizzlCount = i;
}
} else {
swizzlDirect = 0; // Zn
cost = m0 * i + n0 * c;
if (cost < mincost) {
mincost = cost;
swizzlCount = i;
}
}
}
tilingData.swizzlDirect = swizzlDirect;
tilingData.swizzlCount = swizzlCount;
return swizzlDirect;
}
template <typename PpTilingDataType>
inline __attribute__((always_inline)) void PpMatmulTilingCheck(const PpTilingDataType &tilingData)
{
TORCH_CHECK(tilingData.opShape.m0 > 0, "m0 is invalid");
TORCH_CHECK(tilingData.opShape.k0 > 0, "k0 is invalid");
TORCH_CHECK(tilingData.opShape.n0 > 0, "n0 is invalid");
TORCH_CHECK(tilingData.mLoop > 0, "mLoop is invalid");
TORCH_CHECK(tilingData.kLoop > 0, "kLoop is invalid");
TORCH_CHECK(tilingData.nLoop > 0, "nLoop is invalid");
TORCH_CHECK(tilingData.blockDim > 0, "nLoop is invalid");
}
} // namespace host_utils
#endif

View File

@@ -0,0 +1,155 @@
#include <map>
#include "tiling_data.h"
#include "common.h"
#include "common_tiling.h"
namespace pp_matmul {
constexpr uint32_t L1_DESCALE_BUFFER_LEN_MAX = 6144;
constexpr uint32_t CONST_3 = 3;
constexpr uint32_t CONST_4 = 4;
constexpr uint32_t CONST_16 = 16;
constexpr uint32_t CONST_32 = 32;
constexpr uint32_t CONST_256 = 256;
constexpr uint32_t CONST_512 = 512;
const std::map<TensorDType, uint32_t> G_DTYPE_MAP = {{TensorDType::TENSOR_DTYPE_FLOAT16, 1u},
{TensorDType::TENSOR_DTYPE_BF16, 2u}};
const std::map<TensorFormat, uint32_t> G_FORMAT_MAP = {{TensorFormat::TENSOR_FORMAT_ND, 0u},
{TensorFormat::TENSOR_FORMAT_NZ, 1u}};
using MmType = MatMul::MatMulType;
using QmType = MatMul::QuantMode;
using namespace host_utils;
bool IsI8Bf16Kernel(const MatMulInfo &mmInfo)
{
bool isI8Bf16 = mmInfo.isInt8 && mmInfo.dtypeC == TensorDType::TENSOR_DTYPE_BF16;
bool isI8Fp16 = mmInfo.isInt8 && mmInfo.dtypeC == TensorDType::TENSOR_DTYPE_FLOAT16 &&
mmInfo.quantMode == QmType::PER_TOKEN_SYMM;
return isI8Bf16 || isI8Fp16;
}
HardwareInfo::HardwareInfo()
{
auto &platform = PlatformInfo::Instance();
coreNum = platform.coreNumAic;
l2Size = platform.l2Size;
l1Size = platform.l1Size;
l0aSize = platform.l0aSize;
l0bSize = platform.l0bSize;
l0cSize = platform.l0cSize;
hbmBandWidth = 1;
l2BandWidth = 5; // 5x faster than hbm.
}
void PpMatmulTilingData::SetBaseShape(uint32_t batchSize, uint32_t m, uint32_t k, uint32_t n)
{
opShape.batchSize = batchSize;
opShape.m = m;
opShape.k = k;
opShape.n = n;
}
void PpMatmulTilingData::SetBaseOp(uint32_t coreNum, uint32_t mBase, uint32_t nBase, const MatMulInfo &mmInfo)
{
opShape.m0 = mBase;
opShape.n0 = nBase;
mLoop = CeilDiv(opShape.m, opShape.m0);
nLoop = CeilDiv(opShape.n, opShape.n0);
coreLoop = opShape.batchSize * mLoop * nLoop;
if (mLoop == 1 && mmInfo.transB && coreLoop % coreNum < coreNum / CONST_4 * CONST_3) {
mBase = RoundUp<uint32_t>(opShape.m, CONST_16);
opShape.m0 = mBase;
uint32_t maxN0 = PlatformInfo::Instance().l0cSize / (mBase * sizeof(float));
if (mmInfo.isInt8 || mmInfo.mmType == MmType::MATMUL_WITH_BIAS) {
maxN0 = maxN0 < CONST_256 ? maxN0 : CONST_256;
}
uint32_t x = CeilDiv(opShape.n, coreNum);
uint32_t y = CeilDiv(x, maxN0);
nBase = RoundUp<uint32_t>(CeilDiv(x, y), CONST_16);
uint32_t rqdL0CSize = mBase * nBase * sizeof(float);
if (rqdL0CSize < PlatformInfo::Instance().l0cSize &&
(mBase + nBase) * CONST_256 * sizeof(uint16_t) < L1AB_PINGPONG_BUFFER_LEN) {
opShape.n0 = nBase;
nLoop = CeilDiv(opShape.n, opShape.n0);
coreLoop = opShape.batchSize * nLoop;
}
}
blockDim = std::min(coreLoop, coreNum);
}
// transA transB quantMode [dtype] format
void PpMatmulTilingData::SetTilingKey(const MatMulInfo &mmInfo, uint32_t swizzleDirect, uint32_t enSplitK)
{
if (mmInfo.mmType == MmType::MATMUL_ACCUM_ATOMIC || mmInfo.mmType == MmType::MATMUL_WITH_BIAS ||
mmInfo.mmType == MmType::MATMUL_EIN_SUM || mmInfo.mmType == MmType::MATMUL_DEQUANT || IsI8Bf16Kernel(mmInfo)) {
// SwizzleDir[1] TransA[1] TransB[1] DtypeA[3] DtypeB[3] DtypeC[3] FormatA[1] FormatB[1] FormatC[1] WithBias[1]
tilingKey = swizzleDirect;
tilingKey = (tilingKey << 1) + static_cast<uint32_t>(mmInfo.transA);
tilingKey = (tilingKey << 1) + static_cast<uint32_t>(mmInfo.transB);
tilingKey = (tilingKey << 3) + G_DTYPE_MAP.at(mmInfo.dtypeA); // 3bit for dtypeA.
tilingKey = (tilingKey << 3) + G_DTYPE_MAP.at(mmInfo.dtypeB); // 3bit for dtypeB.
tilingKey = (tilingKey << 3) + G_DTYPE_MAP.at(mmInfo.dtypeC); // 3bit for dtypeC.
tilingKey = (tilingKey << 1) + G_FORMAT_MAP.at(mmInfo.formatA);
tilingKey = (tilingKey << 1) + G_FORMAT_MAP.at(mmInfo.formatB);
tilingKey = (tilingKey << 1) + G_FORMAT_MAP.at(mmInfo.formatC);
tilingKey = (tilingKey << 1) + static_cast<uint32_t>(mmInfo.biasFlag);
} else {
tilingKey = swizzleDirect;
tilingKey = (tilingKey << 1) + static_cast<uint32_t>(mmInfo.transA);
tilingKey = (tilingKey << 1) + static_cast<uint32_t>(mmInfo.transB);
tilingKey = (tilingKey << 1) + static_cast<uint32_t>(mmInfo.isInt8);
tilingKey = (tilingKey << 1) + static_cast<uint32_t>(mmInfo.biasFlag);
tilingKey = (tilingKey << 1) + enSplitK;
}
}
uint32_t PpMatmulTilingData::End(const MatMulInfo &mmInfo)
{
uint32_t cubeBlockSize = mmInfo.isInt8 ? CUBE_BLOCK_SIZE_INT8 : CUBE_BLOCK_SIZE;
uint32_t kBlockSize = mmInfo.isInt8 ? BLOCK_SIZE_INT8_K : BLOCK_SIZE;
uint32_t scaleBlockSize = mmInfo.isInt8 ? L1_DESCALE_BUFFER_LEN_MAX : 0;
uint32_t shapeSum = opShape.m0 + opShape.n0;
if (mmInfo.isInt8 && (mmInfo.transA || !mmInfo.transB)) {
shapeSum = RoundUp<uint32_t>(opShape.m0, CONST_32) + RoundUp<uint32_t>(opShape.n0, CONST_32);
}
uint32_t k0Max = shapeSum == 0
? L1AB_PINGPONG_BUFFER_LEN
: static_cast<uint32_t>(static_cast<float>(L1AB_PINGPONG_BUFFER_LEN - scaleBlockSize) /
(shapeSum * mmInfo.inDtype));
if (mmInfo.mmType == MatMul::MatMulType::MATMUL_WITH_BIAS) {
uint32_t l1AbSize = L1AB_PINGPONG_BUFFER_LEN - opShape.n0 * sizeof(float);
k0Max = l1AbSize / (shapeSum * mmInfo.inDtype);
}
opShape.k0 =
k0Max < cubeBlockSize ? RoundDown<uint32_t>(k0Max, kBlockSize) : RoundDown<uint32_t>(k0Max, cubeBlockSize);
if (opShape.k0 > CONST_512) {
opShape.k0 = RoundDown<uint32_t>(opShape.k0, CONST_512);
}
kLoop = CeilDiv(opShape.k, opShape.k0);
return blockDim;
}
void GetPpMatmulTiling(const MatMulInfo &mmInfo, const HardwareInfo &hwInfo, uint32_t &blockDim,
PpMatmulTilingData &tilingData)
{
OpShape opShape;
opShape.batchSize = mmInfo.batchSize;
opShape.m = mmInfo.m;
opShape.n = mmInfo.n;
opShape.k = mmInfo.k;
tilingData.opShape = opShape;
tilingData.quantMode = static_cast<uint32_t>(mmInfo.quantMode);
tilingData.SetTilingKey(mmInfo, 0, 0); // init tilingkey with transA transB.
if (opShape.m < opShape.n) {
TilingFunc<false, OpShape, PpMatmulTilingData, HardwareInfo, MatMulInfo>(opShape, tilingData, hwInfo, mmInfo);
} else {
TilingFunc<true, OpShape, PpMatmulTilingData, HardwareInfo, MatMulInfo>(opShape, tilingData, hwInfo, mmInfo);
}
uint32_t direct = Swizzl<PpMatmulTilingData>(tilingData);
blockDim = tilingData.End(mmInfo);
tilingData.SetTilingKey(mmInfo, direct, 0);
}
} // namespace pp_matmul

View File

@@ -0,0 +1,90 @@
#ifndef PP_MATMUL_TILING_DATA
#define PP_MATMUL_TILING_DATA
#include <cstdint>
namespace pp_matmul {
struct MatMul {
enum class MatMulType : uint32_t {
MATMUL_DEFAULT = 0, // C = op(A) * op(B)
MATMUL_DEQUANT, //
MATMUL_ACCUM_ATOMIC, // C += op(A) * op(B)
MATMUL_WITH_BIAS, // C = op(A) * op(B) + Bias, where Bias is a vector.
MATMUL_EIN_SUM
};
enum class QuantMode : uint32_t { PER_CHANNEL_SYMM = 0, PER_CHANNEL_ASYMM, PER_TOKEN_SYMM };
};
enum class TensorDType : uint32_t { TENSOR_DTYPE_FLOAT16 = 0, TENSOR_DTYPE_BF16 };
enum class TensorFormat : uint32_t { TENSOR_FORMAT_ND = 0, TENSOR_FORMAT_NZ };
struct MatMulInfo {
uint32_t batchSize{0};
uint32_t m{0}; // actual input m
uint32_t k{0}; // actual input k
uint32_t n{0}; // actual input n
TensorDType dtypeA{TensorDType::TENSOR_DTYPE_FLOAT16};
TensorDType dtypeB{TensorDType::TENSOR_DTYPE_FLOAT16};
TensorDType dtypeC{TensorDType::TENSOR_DTYPE_FLOAT16};
TensorFormat formatA{TensorFormat::TENSOR_FORMAT_ND};
TensorFormat formatB{TensorFormat::TENSOR_FORMAT_ND};
TensorFormat formatC{TensorFormat::TENSOR_FORMAT_ND};
MatMul::MatMulType mmType{MatMul::MatMulType::MATMUL_DEFAULT};
bool transA{0}; // false: 0, true: 1
bool transB{0}; // false: 0, true: 1
bool biasFlag{0}; // false: 0, true: 1
bool isInt8{0}; // false: 0, true: 1
float inDtype{0};
float outDtype{0};
MatMul::QuantMode quantMode{MatMul::QuantMode::PER_CHANNEL_SYMM};
};
struct OpShape {
uint32_t batchSize{0};
uint32_t m{0};
uint32_t k{0};
uint32_t n{0};
uint32_t m0{0};
uint32_t k0{0};
uint32_t n0{0};
};
struct HardwareInfo {
uint32_t coreNum{0};
uint32_t l2Size{0};
uint32_t l1Size{0};
uint32_t l0aSize{0};
uint32_t l0bSize{0};
uint32_t l0cSize{0};
uint32_t hbmBandWidth{0};
uint32_t l2BandWidth{0};
HardwareInfo();
};
#pragma pack(push, 1)
struct PpMatmulTilingData {
OpShape opShape{};
uint32_t mLoop{1};
uint32_t kLoop{1};
uint32_t nLoop{1};
uint32_t coreLoop{1};
uint32_t swizzlCount{1};
uint32_t tilingKey{0};
uint32_t blockDim{1};
uint32_t swizzlDirect{0};
uint32_t splitk{0};
uint32_t enShuffleK{0};
uint32_t quantMode{0};
void SetBaseShape(uint32_t batchSize, uint32_t m, uint32_t k, uint32_t n);
void SetBaseOp(uint32_t coreNum, uint32_t mBase, uint32_t nBase, const MatMulInfo &mmInfo);
void SetTilingKey(const MatMulInfo &mmInfo, uint32_t swizzleDirect, uint32_t enSplitK);
uint32_t End(const MatMulInfo &mmInfo);
};
#pragma pack(pop)
void GetPpMatmulTiling(const MatMulInfo &mmInfo, const HardwareInfo &hwInfo, uint32_t &blockDim,
PpMatmulTilingData &tilingData);
} // namespace pp_matmul
#endif