[Cherry-pick] bmm_transpose to v011dev (#3995)
### What this PR does / why we need it?
Add a custom op to accelerate the DeepSeek model. The fused op combines the bmm and transpose operations and is applied in the MLA module. Cherry-picked from commit c68ddc11ce53334fc9a17bad58342148cbf14e86.

### Does this PR introduce _any_ user-facing change?
No

---------

Signed-off-by: hust17yixuan <303660421@qq.com>
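From the shape handling in the host tiling code below and the kernel's output addressing, the fused op in effect computes C[m, b, n] = sum_k A[m, b, k] * B[b, k, n] (einsum "mbk,bkn->mbn") for ND-format inputs, so no transposed intermediate is ever materialized. A minimal CPU reference sketch of these semantics (illustrative only; not the NPU kernel):

#include <cstdint>
#include <vector>

// Reference semantics of the fused bmm + transpose: C[m][b][n] = sum_k A[m][b][k] * B[b][k][n].
// A is [m, batch, k], B is [batch, k, n] (ND), C is [m, batch, n]; all row-major and contiguous.
void bmm_transpose_ref(const std::vector<float> &A, const std::vector<float> &B, std::vector<float> &C,
                       int64_t m, int64_t batch, int64_t k, int64_t n)
{
    for (int64_t mi = 0; mi < m; ++mi) {
        for (int64_t b = 0; b < batch; ++b) {
            for (int64_t ni = 0; ni < n; ++ni) {
                float acc = 0.0f;
                for (int64_t ki = 0; ki < k; ++ki) {
                    acc += A[(mi * batch + b) * k + ki] * B[(b * k + ki) * n + ni];
                }
                C[(mi * batch + b) * n + ni] = acc;  // written directly in [m, batch, n] layout
            }
        }
    }
}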
.pre-commit-config.yaml
@@ -12,8 +12,8 @@ repos:
   - id: codespell
     args: [
       --toml, pyproject.toml,
-      '--skip', 'tests/e2e/multicard/test_torchair_graph_mode.py,csrc/mla_preprocess/**,tests/prompts/**,./benchmarks/sonnet.txt,*tests/lora/data/**,build/**,./vllm_ascend.egg-info/**,.github/**,typos.toml',
-      '-L', 'CANN,cann,NNAL,nnal,ASCEND,ascend,EnQue,CopyIn,ArchType,AND'
+      '--skip', 'tests/e2e/multicard/test_torchair_graph_mode.py,csrc/**,tests/prompts/**,./benchmarks/sonnet.txt,*tests/lora/data/**,build/**,./vllm_ascend.egg-info/**,.github/**,typos.toml',
+      '-L', 'CANN,cann,NNAL,nnal,ASCEND,ascend,EnQue,CopyIn,ArchType,AND,ND'
     ]
     additional_dependencies:
       - tomli
CMakeLists.txt
@@ -55,15 +55,34 @@ include(${ASCENDC_CMAKE_DIR}/ascendc.cmake)
 file(GLOB KERNEL_FILES
     ${CMAKE_CURRENT_SOURCE_DIR}/csrc/kernels/*.cpp)
 
-ascendc_library(vllm_ascend_kernels SHARED
+set(VLLM_ASCEND_CUSTOM_OP
     ${KERNEL_FILES}
+    ${CMAKE_CURRENT_SOURCE_DIR}/csrc/mla_preprocess/op_kernel/mla_preprocess_kernel.cpp
+    ${CMAKE_CURRENT_SOURCE_DIR}/csrc/batch_matmul_transpose/op_kernel/batch_matmul_transpose_kernel.cpp
 )
 
+set(VLLM_ASCEND_CUSTOM_OP_EXCLUDE
+    ${CMAKE_CURRENT_SOURCE_DIR}/csrc/batch_matmul_transpose/op_kernel/batch_matmul_transpose_kernel.cpp
+)
+
+if(SOC_VERSION STREQUAL "ASCEND310P3")
+    list(REMOVE_ITEM VLLM_ASCEND_CUSTOM_OP ${VLLM_ASCEND_CUSTOM_OP_EXCLUDE})
+endif()
+
+ascendc_library(vllm_ascend_kernels SHARED
+    ${VLLM_ASCEND_CUSTOM_OP}
+)
+
 message("TORCH_NPU_PATH is ${TORCH_NPU_PATH}")
 
-file(GLOB VLLM_ASCEND_SRC
-    ${CMAKE_CURRENT_SOURCE_DIR}/csrc/*.cpp)
+if(SOC_VERSION STREQUAL "ASCEND310P3")
+    file(GLOB VLLM_ASCEND_SRC
+        ${CMAKE_CURRENT_SOURCE_DIR}/csrc/*.cpp)
+else()
+    file(GLOB VLLM_ASCEND_SRC
+        ${CMAKE_CURRENT_SOURCE_DIR}/csrc/*.cpp
+        ${CMAKE_CURRENT_SOURCE_DIR}/csrc/batch_matmul_transpose/op_host/tiling/tiling_data.cpp)
+endif()
 
 include_directories(
     ${pybind11_INCLUDE_DIRS}

@@ -73,6 +92,7 @@ include_directories(
     ${ASCEND_HOME_PATH}/include
     ${ASCEND_HOME_PATH}/aarch64-linux/include/experiment/platform
     ${ASCEND_HOME_PATH}/x86_64-linux/include/experiment/platform
+    ${CMAKE_CURRENT_SOURCE_DIR}/csrc/batch_matmul_transpose/op_host
 )
 
 set(
csrc/batch_matmul_transpose/op_host/batch_matmul_transpose.h (new file, 123 lines)
@@ -0,0 +1,123 @@
#include <iostream>
#include <string>
#include "acl/acl.h"
#include "kernel_tiling/kernel_tiling.h"
#include "tiling/platform/platform_ascendc.h"
#include "tiling/tiling_data.h"
#include "common_tiling.h"


namespace bmm_trans {
using namespace pp_matmul;

std::unordered_map<c10::string_view, uint16_t> quantModeMap = {
    {"per_channel_symm", 0},
    {"per_channel_asymm", 1},
    {"per_token_symm", 2},
};

std::unordered_map<c10::string_view, uint16_t> formatModeMap = {
    {"ND", 0},
    {"NZ", 1},
};

std::unordered_map<c10::ScalarType, TensorDType> atType2tensorDType = {
    {at::ScalarType::BFloat16, TensorDType::TENSOR_DTYPE_BF16},
    {at::ScalarType::Half, TensorDType::TENSOR_DTYPE_FLOAT16}};

// batch size -> memory index
constexpr uint32_t MAX_CAPTURE_NUM = 1024;

template <typename MapType>
inline int GetModeVal(const MapType &mode_map, c10::optional<c10::string_view> mode_opt, c10::string_view default_mode,
                      const char *mode_name)
{
    std::string modeStr(mode_name);
    c10::string_view mode_str = mode_opt.value_or(default_mode);
    auto it = mode_map.find(mode_str);
    // a missing mode falls back to the default; an unsupported value is rejected
    TORCH_CHECK(it != mode_map.end(), modeStr, c10::str(": Unsupported mode value ", mode_str));
    return it->second;
}

std::tuple<at::Tensor, uint32_t> batch_matmul_transpose_tiling(const at::Tensor &tensor_a, const at::Tensor &tensor_b, at::Tensor &tensor_c,
                                                               c10::optional<c10::string_view> format_mode,
                                                               c10::optional<c10::string_view> quant_mode)
{
    auto tensorAShape = tensor_a.sizes();
    auto tensorBShape = tensor_b.sizes();
    auto tensorCShape = tensor_c.sizes();
    uint32_t n;
    uint32_t block_dim;

    // auto &platform = PlatformInfo::Instance();
    HardwareInfo hwInfo;
    std::map<c10::ScalarType, float> dTypeMap = {{at::ScalarType::Half, 2.0}, {at::ScalarType::BFloat16, 2.0}};

    at::ScalarType aType = tensor_a.scalar_type();
    at::ScalarType bType = tensor_b.scalar_type();
    at::ScalarType cType = tensor_c.scalar_type();
    TORCH_CHECK(aType == bType && bType == cType, "tensor type is not the same");
    TORCH_CHECK((aType == at::ScalarType::BFloat16) || (aType == at::ScalarType::Half),
                "tensor type only support half or bf16");

    TensorFormat formatMode = static_cast<TensorFormat>(GetModeVal(formatModeMap, format_mode, "ND", "format_mode"));
    MatMul::QuantMode quantMode =
        static_cast<MatMul::QuantMode>(GetModeVal(quantModeMap, quant_mode, "per_channel_symm", "quant_mode"));

    TORCH_CHECK(tensorAShape.size() == 3, "tensor_a should be dim3");
    if (formatMode == TensorFormat::TENSOR_FORMAT_ND) {
        TORCH_CHECK(tensorBShape.size() == 3, "tensor shape should be dim3 in ND format");
        TORCH_CHECK(tensorAShape[2] == tensorBShape[1], "tensor shape is wrong");
        n = tensorBShape[2];
    } else {
        TORCH_CHECK(tensorBShape.size() == 4, "tensor shape should be dim4 in nz format");
        TORCH_CHECK(tensorAShape[2] == tensorBShape[2], "tensor shape is wrong");
        n = tensorBShape[1] * tensorBShape[3];
    }
    TORCH_CHECK(tensorAShape[1] == tensorBShape[0], "tensor shape is wrong");

    OpShape opShape = {.batchSize = static_cast<uint32_t>(tensorAShape[1]),
                       .m = static_cast<uint32_t>(tensorAShape[0]),
                       .k = static_cast<uint32_t>(tensorAShape[2]),
                       .n = n};
    pp_matmul::PpMatmulTilingData matmulTilingData = {
        .opShape = opShape,
    };
    auto dType = atType2tensorDType[aType];
    MatMulInfo mmInfo = {.batchSize = opShape.batchSize,
                         .m = opShape.m,
                         .k = opShape.k,
                         .n = opShape.n,
                         .dtypeA = dType,
                         .dtypeB = dType,
                         .dtypeC = dType,
                         .formatB = formatMode,
                         .mmType = MatMul::MatMulType::MATMUL_EIN_SUM,
                         .inDtype = dTypeMap[aType],
                         .outDtype = dTypeMap[cType],
                         .quantMode = quantMode};
    GetPpMatmulTiling(mmInfo, hwInfo, block_dim, matmulTilingData);
    host_utils::PpMatmulTilingCheck(matmulTilingData);

    // tiling
    int32_t batchIdx = opShape.m - 1;
    uint32_t tilingSize = sizeof(pp_matmul::PpMatmulTilingData);
    static auto global_tiling_data = at::empty(
        {tilingSize * MAX_CAPTURE_NUM}, at::TensorOptions().dtype(at::kByte).device(tensor_a.options().device()));
    if (batchIdx >= 0 && batchIdx < MAX_CAPTURE_NUM) {
        aclrtMemcpy(global_tiling_data.data_ptr<uint8_t>() + (tilingSize * batchIdx), tilingSize, &matmulTilingData,
                    tilingSize, ACL_MEMCPY_HOST_TO_DEVICE);
    } else {
        // Handle the case where batchIdx is out of range
        TORCH_CHECK(false, "batchIdx is out of range: ", batchIdx);
    }
    at::Tensor tiling_tensor =
        at::from_blob(global_tiling_data.data_ptr<uint8_t>() + (tilingSize * batchIdx), tilingSize, at::kByte);

    return std::make_tuple(tiling_tensor, block_dim);
}

}  // namespace bmm_trans
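A minimal host-side call sketch for the tiling entry above, with assumed illustrative shapes (A is [m, batch, k], B is [batch, k, n] in ND format, C is [m, batch, n]); the tensor setup here is hypothetical and omits device placement:

// Hypothetical usage sketch; shapes chosen to satisfy the TORCH_CHECKs above.
at::Tensor a = at::empty({32, 8, 128}, at::dtype(at::kHalf));  // [m, batch, k]
at::Tensor b = at::empty({8, 128, 64}, at::dtype(at::kHalf));  // [batch, k, n], ND
at::Tensor c = at::empty({32, 8, 64}, at::dtype(at::kHalf));   // [m, batch, n]
auto [tiling_tensor, block_dim] =
    bmm_trans::batch_matmul_transpose_tiling(a, b, c, "ND", c10::nullopt);
// tiling_tensor is a byte view of the cached PpMatmulTilingData slot for this m;
// block_dim is the number of AI-core blocks to launch the kernel with.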
csrc/batch_matmul_transpose/op_host/common.h (new file, 57 lines)
@@ -0,0 +1,57 @@
// Licensed under the BSD 3-Clause License (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#ifndef UTILS_COMMON_H
#define UTILS_COMMON_H

namespace host_utils {

constexpr uint32_t BLK_SIZE_ALIN_FOR_INT64 = 4;
constexpr uint32_t BLK_SIZE_ALIN_FOR_INT32 = 8;

inline uint64_t alinInt64Count(uint64_t count)
{
    return (count + BLK_SIZE_ALIN_FOR_INT64 - 1) / BLK_SIZE_ALIN_FOR_INT64 * BLK_SIZE_ALIN_FOR_INT64;
}

inline uint64_t alinInt32Count(uint64_t count)
{
    return (count + BLK_SIZE_ALIN_FOR_INT32 - 1) / BLK_SIZE_ALIN_FOR_INT32 * BLK_SIZE_ALIN_FOR_INT32;
}

template <typename T>
inline T CeilDiv(const T dividend, const T divisor)
{
    if (divisor == 0) {
        return UINT32_MAX;
    }
    return (dividend + divisor - 1) / divisor;
}

template <typename T>
inline T RoundUp(const T val, const T align = 16)
{
    if (align == 0 || val + align - 1 < val) {
        return 0;
    }
    return (val + align - 1) / align * align;
}

template <typename T>
inline T RoundDown(const T val, const T align = 16)
{
    if (align == 0) {
        return 0;
    }
    return val / align * align;
}
} // namespace host_utils
#endif // UTILS_COMMON_H
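Worked examples for the rounding helpers above (a quick sanity sketch; the include path is illustrative):

#include <cassert>
#include <cstdint>
#include "common.h"

int main()
{
    assert(host_utils::CeilDiv<uint32_t>(10, 3) == 4);          // (10 + 2) / 3
    assert(host_utils::RoundUp<uint32_t>(17, 16) == 32);        // next multiple of 16
    assert(host_utils::RoundDown<uint32_t>(17, 16) == 16);      // previous multiple of 16
    assert(host_utils::alinInt32Count(10) == 16);               // padded to an 8-element block
    assert(host_utils::CeilDiv<uint32_t>(7, 0) == UINT32_MAX);  // guarded divide-by-zero
    return 0;
}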
csrc/batch_matmul_transpose/op_host/common_tiling.h (new file, 239 lines)
@@ -0,0 +1,239 @@
/*
 * Copyright (c) 2024 Huawei Technologies Co., Ltd.
 * This file is a part of the CANN Open Software.
 * Licensed under CANN Open Software License Agreement Version 1.0 (the "License").
 * Please refer to the License for details. You may not use this file except in compliance with the License.
 * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED,
 * INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE.
 * See LICENSE in the root of the software repository for the full text of the License.
 */

#ifndef COMMMON_TILING_H
#define COMMMON_TILING_H

#include <iostream>
#include <cmath>
#include "common.h"
#include "tiling/platform/platform_ascendc.h"

namespace host_utils {

constexpr uint32_t FP16_SIZE = 2;
constexpr uint32_t FP32_SIZE = 4;
constexpr uint32_t BLOCK_SIZE = 16;
constexpr uint32_t BLOCK_SIZE_INT8_K = 32;
constexpr uint32_t BASE_BLOCK_STEP = 2;
constexpr uint32_t AXES_ALIGN_SIZE = 512;
constexpr uint32_t AXES_ALIGN_SIZE_INT8 = 256;
constexpr uint32_t ND_SHAPE_SIZE = 2;
constexpr uint32_t NZ_SHAPE_SIZE = 4;
constexpr uint32_t CUBE_BLOCK_SIZE = 256;
constexpr uint32_t CUBE_BLOCK_SIZE_INT8 = 512;
constexpr uint32_t L1AB_PINGPONG_BUFFER_LEN = 262144;
constexpr uint32_t L0AB_PINGPONG_BUFFER_LEN_INT8 = 131072 * 2; // 256 KB
constexpr uint32_t L0AB_PINGPONG_BUFFER_LEN_FP16 = 131072;     // 128 KB
constexpr uint32_t L1AB_PINGPONG_BUFFER_LEN_INT8_SPARSE = 160 * 1024;
constexpr uint32_t UB_LIMIT_SIZE_910A = 128 * 1024;

enum class PlatformType { ASCEND_310P, ASCEND_910A, ASCEND_910B, ASCEND_910C, PLATFORM_INVALID };

struct PlatformInfo {
public:
    static const PlatformInfo &Instance()
    {
        static PlatformInfo platformInfo;
        return platformInfo;
    }

    PlatformType socType;
    uint32_t coreNum;
    uint32_t coreNumAic;
    uint32_t coreNumAiv;
    uint64_t ubSize;
    uint64_t l1Size;
    uint64_t l2Size;
    uint64_t l0aSize;
    uint64_t l0bSize;
    uint64_t l0cSize;

private:
    PlatformInfo()
    {
        auto ascendcPlatform = platform_ascendc::PlatformAscendCManager::GetInstance();
        // TODO: hard-coded to 910_93xx; parsing aclrtGetSocName would be better
        socType = PlatformType::ASCEND_910C;
        coreNum = ascendcPlatform->GetCoreNum();
        coreNumAic = ascendcPlatform->GetCoreNumAic();
        coreNumAiv = ascendcPlatform->GetCoreNumAiv();
        ascendcPlatform->GetCoreMemSize(platform_ascendc::CoreMemType::UB, ubSize);
        ascendcPlatform->GetCoreMemSize(platform_ascendc::CoreMemType::L1, l1Size);
        ascendcPlatform->GetCoreMemSize(platform_ascendc::CoreMemType::L2, l2Size);
        ascendcPlatform->GetCoreMemSize(platform_ascendc::CoreMemType::L0_A, l0aSize);
        ascendcPlatform->GetCoreMemSize(platform_ascendc::CoreMemType::L0_B, l0bSize);
        ascendcPlatform->GetCoreMemSize(platform_ascendc::CoreMemType::L0_C, l0cSize);
    }

    PlatformInfo(const PlatformInfo &) = delete;
    PlatformInfo &operator=(const PlatformInfo &) = delete;
    PlatformInfo(PlatformInfo &&) = delete;
    PlatformInfo &operator=(PlatformInfo &&) = delete;
};

inline __attribute__((always_inline)) uint32_t GetN0TilingLimit(bool compressFlag, uint32_t tilingN,
                                                                const PlatformType &platformType)
{
    if (compressFlag) {
        return std::min(tilingN * BLOCK_SIZE, AXES_ALIGN_SIZE_INT8);
    } else {
        return (platformType == PlatformType::ASCEND_310P || platformType == PlatformType::ASCEND_910A)
                   ? AXES_ALIGN_SIZE
                   : AXES_ALIGN_SIZE_INT8;
    }
}

template <typename OpShareType>
inline __attribute__((always_inline)) uint32_t GetN0TilingInit(const OpShareType &opShape, bool compressFlag,
                                                               uint32_t tilingN)
{
    const uint32_t rnd = 16;
    return compressFlag
               ? ((tilingN * BLOCK_SIZE > opShape.n) ? RoundUp<uint32_t>(opShape.n, rnd) : tilingN * BLOCK_SIZE)
               : BLOCK_SIZE;
}

template <bool PRI_FLAG>
inline __attribute__((always_inline)) bool IsExceedTilingLimit(uint32_t axes0, uint32_t priAxes0,
                                                               uint32_t n0TilingLimit, PlatformType platformType,
                                                               uint32_t basicBlockSize)
{
    return (PRI_FLAG && axes0 > n0TilingLimit) || (!PRI_FLAG && priAxes0 > n0TilingLimit) ||
           (platformType == PlatformType::ASCEND_910A && basicBlockSize > UB_LIMIT_SIZE_910A);
}

template <bool PRI_FLAG, typename OpShareType>
inline __attribute__((always_inline)) void SetOpShapeAxesInfo(OpShareType &opShape, uint32_t priAxes0, uint32_t axes0)
{
    opShape.m0 = PRI_FLAG ? priAxes0 : axes0;
    opShape.n0 = PRI_FLAG ? axes0 : priAxes0;
}

template <typename HardwareType, typename OpShapeType>
inline __attribute__((always_inline)) float CostFunc(const HardwareType &hwInfor, OpShapeType &shape)
{
    float aCoef = 1;
    float bCoef = 1;
    float bwCoef = static_cast<float>(hwInfor.l2BandWidth) / static_cast<float>(hwInfor.hbmBandWidth);
    uint32_t mLoop = CeilDiv(shape.m, shape.m0);
    uint32_t nLoop = CeilDiv(shape.n, shape.n0);
    if (mLoop == 0 || nLoop == 0) {
        return 1;
    }
    uint32_t coreNeed = shape.batchSize * mLoop * nLoop;
    uint32_t blockDim = std::min(coreNeed, hwInfor.coreNum);
    uint32_t mOnce = blockDim < nLoop ? shape.m0 : blockDim / nLoop * shape.m0;
    uint32_t nOnce = blockDim < nLoop ? hwInfor.coreNum * shape.n0 : shape.n;
    if (mOnce * shape.k * FP16_SIZE > hwInfor.l2Size) {
        aCoef = bwCoef;
    }
    if (nOnce * shape.k * FP16_SIZE > hwInfor.l2Size) {
        bCoef = bwCoef;
    }
    return 1 / (aCoef * static_cast<float>(shape.n0)) + 1 / (bCoef * static_cast<float>(shape.m0));
}

template <bool PRI_FLAG, typename OpShareType, typename TilingType, typename HardwareType, typename MatMulInfoType>
void TilingFunc(OpShareType &opShape, TilingType &tilingParam, const HardwareType &hwInfor,
                const MatMulInfoType &mmInfo, bool compressFlag = false, const uint32_t tilingN = 1)
{
    float costMin = 1;
    const float CONST_2 = 2.0;
    const uint32_t ROUND_CONST_16 = 16;
    uint32_t roundBase = static_cast<uint32_t>(
        pow(2, ceil(log(CeilDiv(PRI_FLAG ? opShape.n : opShape.m, ROUND_CONST_16)))) * ROUND_CONST_16);
    uint32_t priAxes = RoundUp<uint32_t>(PRI_FLAG ? opShape.m : opShape.n, ROUND_CONST_16);
    uint32_t axes = RoundUp<uint32_t>(PRI_FLAG ? opShape.n : opShape.m, roundBase);
    float axes0Max = static_cast<float>(AXES_ALIGN_SIZE) / mmInfo.inDtype;
    auto platformType = PlatformInfo::Instance().socType;
    if (mmInfo.isInt8 && (platformType == PlatformType::ASCEND_310P || platformType == PlatformType::ASCEND_910A)) {
        axes0Max /= CONST_2;
    }

    uint32_t n0TilingInit = GetN0TilingInit(opShape, compressFlag, tilingN);
    uint32_t n0TilingLimit = GetN0TilingLimit(compressFlag, tilingN, platformType);
    uint32_t priAxes0Init = PRI_FLAG ? BLOCK_SIZE : n0TilingInit;
    uint32_t axes0Init = PRI_FLAG ? n0TilingInit : BLOCK_SIZE;
    for (uint32_t priAxes0 = priAxes0Init; priAxes0 <= priAxes && priAxes0 <= axes0Max; priAxes0 *= BASE_BLOCK_STEP) {
        for (uint32_t axes0 = axes0Init; axes0 <= axes && axes0 <= axes0Max; axes0 *= BASE_BLOCK_STEP) {
            uint32_t basicBlockSize = priAxes0 * axes0 * FP32_SIZE;
            if (basicBlockSize > hwInfor.l0cSize) {
                continue;
            }
            if (mmInfo.isInt8 &&
                IsExceedTilingLimit<PRI_FLAG>(axes0, priAxes0, n0TilingLimit, platformType, basicBlockSize)) {
                continue;
            }
            SetOpShapeAxesInfo<PRI_FLAG>(opShape, priAxes0, axes0);
            float cost = CostFunc<HardwareType, OpShareType>(hwInfor, opShape);
            if (cost >= costMin) {
                continue;
            }
            costMin = cost;
            if constexpr (std::is_same<TilingType, pp_matmul::PpMatmulTilingData>::value) {
                tilingParam.SetBaseOp(hwInfor.coreNum, opShape.m0, opShape.n0, mmInfo);
            } else {
                tilingParam.SetBaseOp(hwInfor.coreNum, opShape.m0, opShape.n0);
            }
        }
    }
}

template <typename PpTilingDataType>
uint32_t Swizzl(PpTilingDataType &tilingData)
{
    uint32_t swizzlDirect = 0;
    uint32_t swizzlCount = 1;
    float m0 = tilingData.opShape.m0;
    float n0 = tilingData.opShape.n0;
    float m = tilingData.opShape.m;
    float k = tilingData.opShape.k;
    float n = tilingData.opShape.n;
    float mincost = m * k + k * n;

    for (uint32_t i = 1; i <= tilingData.blockDim; ++i) {
        int c = static_cast<int32_t>((tilingData.blockDim + i - 1) / i);
        float cost;
        // B0 + A < A0 + B
        if (i * n0 + m < m0 * c + n) {
            swizzlDirect = 1; // Nz
            cost = n0 * i + m0 * c;
            if (cost <= mincost) {
                mincost = cost;
                swizzlCount = i;
            }
        } else {
            swizzlDirect = 0; // Zn
            cost = m0 * i + n0 * c;
            if (cost < mincost) {
                mincost = cost;
                swizzlCount = i;
            }
        }
    }
    tilingData.swizzlDirect = swizzlDirect;
    tilingData.swizzlCount = swizzlCount;
    return swizzlDirect;
}

template <typename PpTilingDataType>
inline __attribute__((always_inline)) void PpMatmulTilingCheck(const PpTilingDataType &tilingData)
{
    TORCH_CHECK(tilingData.opShape.m0 > 0, "m0 is invalid");
    TORCH_CHECK(tilingData.opShape.k0 > 0, "k0 is invalid");
    TORCH_CHECK(tilingData.opShape.n0 > 0, "n0 is invalid");
    TORCH_CHECK(tilingData.mLoop > 0, "mLoop is invalid");
    TORCH_CHECK(tilingData.kLoop > 0, "kLoop is invalid");
    TORCH_CHECK(tilingData.nLoop > 0, "nLoop is invalid");
    TORCH_CHECK(tilingData.blockDim > 0, "blockDim is invalid");
}
} // namespace host_utils
#endif
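A hand-evaluated sketch of the CostFunc trade-off above (illustrative numbers only, assuming aCoef = bCoef = 1): a larger n0 cheapens A-side traffic and a larger m0 cheapens B-side traffic, so the doubling search in TilingFunc tends toward balanced tiles.

#include <cstdio>

int main()
{
    float cost_small_n0 = 1.0f / 64 + 1.0f / 256;  // m0 = 256, n0 = 64  -> 0.0195...
    float cost_square   = 1.0f / 128 + 1.0f / 128; // m0 = 128, n0 = 128 -> 0.015625
    std::printf("square tile wins: %d\n", cost_square < cost_small_n0); // prints 1
    return 0;
}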
csrc/batch_matmul_transpose/op_host/tiling/tiling_data.cpp (new file, 155 lines)
@@ -0,0 +1,155 @@
#include <map>
#include "tiling_data.h"
#include "common.h"
#include "common_tiling.h"

namespace pp_matmul {

constexpr uint32_t L1_DESCALE_BUFFER_LEN_MAX = 6144;
constexpr uint32_t CONST_3 = 3;
constexpr uint32_t CONST_4 = 4;
constexpr uint32_t CONST_16 = 16;
constexpr uint32_t CONST_32 = 32;
constexpr uint32_t CONST_256 = 256;
constexpr uint32_t CONST_512 = 512;

const std::map<TensorDType, uint32_t> G_DTYPE_MAP = {{TensorDType::TENSOR_DTYPE_FLOAT16, 1u},
                                                     {TensorDType::TENSOR_DTYPE_BF16, 2u}};
const std::map<TensorFormat, uint32_t> G_FORMAT_MAP = {{TensorFormat::TENSOR_FORMAT_ND, 0u},
                                                       {TensorFormat::TENSOR_FORMAT_NZ, 1u}};
using MmType = MatMul::MatMulType;
using QmType = MatMul::QuantMode;
using namespace host_utils;

bool IsI8Bf16Kernel(const MatMulInfo &mmInfo)
{
    bool isI8Bf16 = mmInfo.isInt8 && mmInfo.dtypeC == TensorDType::TENSOR_DTYPE_BF16;
    bool isI8Fp16 = mmInfo.isInt8 && mmInfo.dtypeC == TensorDType::TENSOR_DTYPE_FLOAT16 &&
                    mmInfo.quantMode == QmType::PER_TOKEN_SYMM;
    return isI8Bf16 || isI8Fp16;
}

HardwareInfo::HardwareInfo()
{
    auto &platform = PlatformInfo::Instance();
    coreNum = platform.coreNumAic;
    l2Size = platform.l2Size;
    l1Size = platform.l1Size;
    l0aSize = platform.l0aSize;
    l0bSize = platform.l0bSize;
    l0cSize = platform.l0cSize;
    hbmBandWidth = 1;
    l2BandWidth = 5; // 5x faster than HBM.
}

void PpMatmulTilingData::SetBaseShape(uint32_t batchSize, uint32_t m, uint32_t k, uint32_t n)
{
    opShape.batchSize = batchSize;
    opShape.m = m;
    opShape.k = k;
    opShape.n = n;
}

void PpMatmulTilingData::SetBaseOp(uint32_t coreNum, uint32_t mBase, uint32_t nBase, const MatMulInfo &mmInfo)
{
    opShape.m0 = mBase;
    opShape.n0 = nBase;
    mLoop = CeilDiv(opShape.m, opShape.m0);
    nLoop = CeilDiv(opShape.n, opShape.n0);
    coreLoop = opShape.batchSize * mLoop * nLoop;

    if (mLoop == 1 && mmInfo.transB && coreLoop % coreNum < coreNum / CONST_4 * CONST_3) {
        mBase = RoundUp<uint32_t>(opShape.m, CONST_16);
        opShape.m0 = mBase;
        uint32_t maxN0 = PlatformInfo::Instance().l0cSize / (mBase * sizeof(float));
        if (mmInfo.isInt8 || mmInfo.mmType == MmType::MATMUL_WITH_BIAS) {
            maxN0 = maxN0 < CONST_256 ? maxN0 : CONST_256;
        }
        uint32_t x = CeilDiv(opShape.n, coreNum);
        uint32_t y = CeilDiv(x, maxN0);
        nBase = RoundUp<uint32_t>(CeilDiv(x, y), CONST_16);
        uint32_t rqdL0CSize = mBase * nBase * sizeof(float);
        if (rqdL0CSize < PlatformInfo::Instance().l0cSize &&
            (mBase + nBase) * CONST_256 * sizeof(uint16_t) < L1AB_PINGPONG_BUFFER_LEN) {
            opShape.n0 = nBase;
            nLoop = CeilDiv(opShape.n, opShape.n0);
            coreLoop = opShape.batchSize * nLoop;
        }
    }
    blockDim = std::min(coreLoop, coreNum);
}

// transA transB quantMode [dtype] format
void PpMatmulTilingData::SetTilingKey(const MatMulInfo &mmInfo, uint32_t swizzleDirect, uint32_t enSplitK)
{
    if (mmInfo.mmType == MmType::MATMUL_ACCUM_ATOMIC || mmInfo.mmType == MmType::MATMUL_WITH_BIAS ||
        mmInfo.mmType == MmType::MATMUL_EIN_SUM || mmInfo.mmType == MmType::MATMUL_DEQUANT || IsI8Bf16Kernel(mmInfo)) {
        // SwizzleDir[1] TransA[1] TransB[1] DtypeA[3] DtypeB[3] DtypeC[3] FormatA[1] FormatB[1] FormatC[1] WithBias[1]
        tilingKey = swizzleDirect;
        tilingKey = (tilingKey << 1) + static_cast<uint32_t>(mmInfo.transA);
        tilingKey = (tilingKey << 1) + static_cast<uint32_t>(mmInfo.transB);
        tilingKey = (tilingKey << 3) + G_DTYPE_MAP.at(mmInfo.dtypeA); // 3bit for dtypeA.
        tilingKey = (tilingKey << 3) + G_DTYPE_MAP.at(mmInfo.dtypeB); // 3bit for dtypeB.
        tilingKey = (tilingKey << 3) + G_DTYPE_MAP.at(mmInfo.dtypeC); // 3bit for dtypeC.
        tilingKey = (tilingKey << 1) + G_FORMAT_MAP.at(mmInfo.formatA);
        tilingKey = (tilingKey << 1) + G_FORMAT_MAP.at(mmInfo.formatB);
        tilingKey = (tilingKey << 1) + G_FORMAT_MAP.at(mmInfo.formatC);
        tilingKey = (tilingKey << 1) + static_cast<uint32_t>(mmInfo.biasFlag);
    } else {
        tilingKey = swizzleDirect;
        tilingKey = (tilingKey << 1) + static_cast<uint32_t>(mmInfo.transA);
        tilingKey = (tilingKey << 1) + static_cast<uint32_t>(mmInfo.transB);
        tilingKey = (tilingKey << 1) + static_cast<uint32_t>(mmInfo.isInt8);
        tilingKey = (tilingKey << 1) + static_cast<uint32_t>(mmInfo.biasFlag);
        tilingKey = (tilingKey << 1) + enSplitK;
    }
}

uint32_t PpMatmulTilingData::End(const MatMulInfo &mmInfo)
{
    uint32_t cubeBlockSize = mmInfo.isInt8 ? CUBE_BLOCK_SIZE_INT8 : CUBE_BLOCK_SIZE;
    uint32_t kBlockSize = mmInfo.isInt8 ? BLOCK_SIZE_INT8_K : BLOCK_SIZE;
    uint32_t scaleBlockSize = mmInfo.isInt8 ? L1_DESCALE_BUFFER_LEN_MAX : 0;
    uint32_t shapeSum = opShape.m0 + opShape.n0;
    if (mmInfo.isInt8 && (mmInfo.transA || !mmInfo.transB)) {
        shapeSum = RoundUp<uint32_t>(opShape.m0, CONST_32) + RoundUp<uint32_t>(opShape.n0, CONST_32);
    }
    uint32_t k0Max = shapeSum == 0
                         ? L1AB_PINGPONG_BUFFER_LEN
                         : static_cast<uint32_t>(static_cast<float>(L1AB_PINGPONG_BUFFER_LEN - scaleBlockSize) /
                                                 (shapeSum * mmInfo.inDtype));
    if (mmInfo.mmType == MatMul::MatMulType::MATMUL_WITH_BIAS) {
        uint32_t l1AbSize = L1AB_PINGPONG_BUFFER_LEN - opShape.n0 * sizeof(float);
        k0Max = l1AbSize / (shapeSum * mmInfo.inDtype);
    }

    opShape.k0 =
        k0Max < cubeBlockSize ? RoundDown<uint32_t>(k0Max, kBlockSize) : RoundDown<uint32_t>(k0Max, cubeBlockSize);
    if (opShape.k0 > CONST_512) {
        opShape.k0 = RoundDown<uint32_t>(opShape.k0, CONST_512);
    }
    kLoop = CeilDiv(opShape.k, opShape.k0);
    return blockDim;
}

void GetPpMatmulTiling(const MatMulInfo &mmInfo, const HardwareInfo &hwInfo, uint32_t &blockDim,
                       PpMatmulTilingData &tilingData)
{
    OpShape opShape;
    opShape.batchSize = mmInfo.batchSize;
    opShape.m = mmInfo.m;
    opShape.n = mmInfo.n;
    opShape.k = mmInfo.k;
    tilingData.opShape = opShape;
    tilingData.quantMode = static_cast<uint32_t>(mmInfo.quantMode);
    tilingData.SetTilingKey(mmInfo, 0, 0); // init tilingkey with transA transB.
    if (opShape.m < opShape.n) {
        TilingFunc<false, OpShape, PpMatmulTilingData, HardwareInfo, MatMulInfo>(opShape, tilingData, hwInfo, mmInfo);
    } else {
        TilingFunc<true, OpShape, PpMatmulTilingData, HardwareInfo, MatMulInfo>(opShape, tilingData, hwInfo, mmInfo);
    }
    uint32_t direct = Swizzl<PpMatmulTilingData>(tilingData);
    blockDim = tilingData.End(mmInfo);
    tilingData.SetTilingKey(mmInfo, direct, 0);
}
} // namespace pp_matmul
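To make the bit layout concrete, here is a small sketch that reproduces the EIN_SUM branch of SetTilingKey above for one configuration (fp16 everywhere, ND formats, no transpose, Nz swizzle); the numeric value is illustrative:

#include <cstdint>
#include <cstdio>

// SwizzleDir[1] TransA[1] TransB[1] DtypeA[3] DtypeB[3] DtypeC[3]
// FormatA[1] FormatB[1] FormatC[1] WithBias[1], packed MSB-first.
int main()
{
    uint32_t key = 1;     // swizzleDirect = 1 (Nz traversal)
    key = (key << 1) + 0; // transA = false
    key = (key << 1) + 0; // transB = false
    key = (key << 3) + 1; // dtypeA = fp16 (G_DTYPE_MAP value 1)
    key = (key << 3) + 1; // dtypeB = fp16
    key = (key << 3) + 1; // dtypeC = fp16
    key = (key << 1) + 0; // formatA = ND
    key = (key << 1) + 0; // formatB = ND
    key = (key << 1) + 0; // formatC = ND
    key = (key << 1) + 0; // biasFlag = false
    std::printf("tilingKey = %u\n", key); // prints 33936
    return 0;
}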
csrc/batch_matmul_transpose/op_host/tiling/tiling_data.h (new file, 90 lines)
@@ -0,0 +1,90 @@
#ifndef PP_MATMUL_TILING_DATA
#define PP_MATMUL_TILING_DATA
#include <cstdint>

namespace pp_matmul {
struct MatMul {
    enum class MatMulType : uint32_t {
        MATMUL_DEFAULT = 0,  // C = op(A) * op(B)
        MATMUL_DEQUANT,      //
        MATMUL_ACCUM_ATOMIC, // C += op(A) * op(B)
        MATMUL_WITH_BIAS,    // C = op(A) * op(B) + Bias, where Bias is a vector.
        MATMUL_EIN_SUM
    };
    enum class QuantMode : uint32_t { PER_CHANNEL_SYMM = 0, PER_CHANNEL_ASYMM, PER_TOKEN_SYMM };
};

enum class TensorDType : uint32_t { TENSOR_DTYPE_FLOAT16 = 0, TENSOR_DTYPE_BF16 };

enum class TensorFormat : uint32_t { TENSOR_FORMAT_ND = 0, TENSOR_FORMAT_NZ };

struct MatMulInfo {
    uint32_t batchSize{0};
    uint32_t m{0}; // actual input m
    uint32_t k{0}; // actual input k
    uint32_t n{0}; // actual input n
    TensorDType dtypeA{TensorDType::TENSOR_DTYPE_FLOAT16};
    TensorDType dtypeB{TensorDType::TENSOR_DTYPE_FLOAT16};
    TensorDType dtypeC{TensorDType::TENSOR_DTYPE_FLOAT16};
    TensorFormat formatA{TensorFormat::TENSOR_FORMAT_ND};
    TensorFormat formatB{TensorFormat::TENSOR_FORMAT_ND};
    TensorFormat formatC{TensorFormat::TENSOR_FORMAT_ND};
    MatMul::MatMulType mmType{MatMul::MatMulType::MATMUL_DEFAULT};
    bool transA{0};   // false: 0, true: 1
    bool transB{0};   // false: 0, true: 1
    bool biasFlag{0}; // false: 0, true: 1
    bool isInt8{0};   // false: 0, true: 1
    float inDtype{0};
    float outDtype{0};
    MatMul::QuantMode quantMode{MatMul::QuantMode::PER_CHANNEL_SYMM};
};

struct OpShape {
    uint32_t batchSize{0};
    uint32_t m{0};
    uint32_t k{0};
    uint32_t n{0};
    uint32_t m0{0};
    uint32_t k0{0};
    uint32_t n0{0};
};

struct HardwareInfo {
    uint32_t coreNum{0};
    uint32_t l2Size{0};
    uint32_t l1Size{0};
    uint32_t l0aSize{0};
    uint32_t l0bSize{0};
    uint32_t l0cSize{0};
    uint32_t hbmBandWidth{0};
    uint32_t l2BandWidth{0};

    HardwareInfo();
};

#pragma pack(push, 1)
struct PpMatmulTilingData {
    OpShape opShape{};
    uint32_t mLoop{1};
    uint32_t kLoop{1};
    uint32_t nLoop{1};
    uint32_t coreLoop{1};
    uint32_t swizzlCount{1};
    uint32_t tilingKey{0};
    uint32_t blockDim{1};
    uint32_t swizzlDirect{0};
    uint32_t splitk{0};
    uint32_t enShuffleK{0};
    uint32_t quantMode{0};

    void SetBaseShape(uint32_t batchSize, uint32_t m, uint32_t k, uint32_t n);
    void SetBaseOp(uint32_t coreNum, uint32_t mBase, uint32_t nBase, const MatMulInfo &mmInfo);
    void SetTilingKey(const MatMulInfo &mmInfo, uint32_t swizzleDirect, uint32_t enSplitK);
    uint32_t End(const MatMulInfo &mmInfo);
};
#pragma pack(pop)

void GetPpMatmulTiling(const MatMulInfo &mmInfo, const HardwareInfo &hwInfo, uint32_t &blockDim,
                       PpMatmulTilingData &tilingData);
} // namespace pp_matmul
#endif
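Because PpMatmulTilingData is declared under #pragma pack(push, 1), the struct is a contiguous block of 18 uint32_t fields (7 in OpShape plus 11 scalars) with no padding, so the host cache in batch_matmul_transpose_tiling can index slots with plain byte offsets. A small sanity sketch (kMaxCaptureNum mirrors the host-side MAX_CAPTURE_NUM; illustrative only):

#include <cstdio>
#include "tiling_data.h"

int main()
{
    constexpr size_t kMaxCaptureNum = 1024; // mirrors MAX_CAPTURE_NUM on the host side
    constexpr size_t tilingSize = sizeof(pp_matmul::PpMatmulTilingData);
    static_assert(tilingSize == 18 * sizeof(uint32_t), "packed layout: 72 bytes, no padding");
    std::printf("slot = %zu bytes, cache = %zu bytes\n", tilingSize, tilingSize * kMaxCaptureNum);
    return 0;
}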
@@ -0,0 +1,825 @@
|
||||
// Adapted from
|
||||
// https://gitee.com/ascend/ascend-transformer-boost
|
||||
//
|
||||
// Copyright (c) Huawei Technologies Co., Ltd. 2025. All rights reserved.
|
||||
// This file is a part of the CANN Open Software.
|
||||
// Licensed under CANN Open Software License Agreement Version 1.0 (the "License").
|
||||
// Please refer to the License for details. You may not use this file except in compliance with the License.
|
||||
// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED,
|
||||
// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE.
|
||||
// See LICENSE in the root of the software repository for the full text of the License.
|
||||
//
|
||||
|
||||
#define __aicore__ [aicore]
|
||||
#include "kernel_operator.h"
|
||||
#include "../op_host/tiling/tiling_data.h"
|
||||
#include "../../mla_preprocess/op_kernel/kernel/common.h"
|
||||
#include "../../mla_preprocess/op_kernel/kernel/hardware.h"
|
||||
#include "../../mla_preprocess/op_kernel/kernel/mma.h"
|
||||
#include "../../mla_preprocess/op_kernel/kernel/utils.h"
|
||||
#include "../../mla_preprocess/op_kernel/kernel/iterator.h"
|
||||
#include "../../kernels/math_utils.h"
|
||||
|
||||
constexpr uint32_t L0_PINGPONG_BUFFER_LEN = 16384;
|
||||
constexpr uint32_t L1_PINGPONG_BUFFER_LEN = 131072;
|
||||
constexpr uint32_t CONST_16 = 16;
|
||||
constexpr uint32_t CONST_256 = 256;
|
||||
constexpr uint64_t ND2NZ_STRIDE_LIMIT = 65536;
|
||||
constexpr uint64_t BLOCK_SIZE_16 = 16;
|
||||
constexpr uint64_t CONST_16UL = 16;
|
||||
constexpr uint64_t CONST_256UL = 256;
|
||||
|
||||
struct MatCoord {
|
||||
uint64_t m{0};
|
||||
uint64_t k{0};
|
||||
uint64_t n{0};
|
||||
};
|
||||
|
||||
using namespace device_utils;
|
||||
|
||||
template <uint32_t SwizzleDirect, bool TA, bool TB, typename InDtype = half, typename OutDtype = half,
|
||||
DataFormat FormatB = DataFormat::ND>
|
||||
class PpMatmulEinSum
|
||||
{
|
||||
using LocalTensor = AscendC::LocalTensor<InDtype>;
|
||||
template <DataFormat srcFormat = DataFormat::ND, DataFormat dstFormat = DataFormat::ND>
|
||||
using CopyGmToCbuf = gm_to_l1<ArchType::ASCEND_V220, InDtype, srcFormat, dstFormat>;
|
||||
using LoadCbufToCa = l1_to_l0_a<ArchType::ASCEND_V220, InDtype, TA, DataFormat::ZN, DataFormat::ZZ>;
|
||||
using LoadCbufToCb = l1_to_l0_b<ArchType::ASCEND_V220, InDtype, TB, DataFormat::ZN, DataFormat::NZ>;
|
||||
using Mad = mmad<ArchType::ASCEND_V220, InDtype, InDtype, float, TA>;
|
||||
using CopyCcToGm = l0c_to_gm<ArchType::ASCEND_V220, DataFormat::ND, OutDtype, float>;
|
||||
|
||||
public:
|
||||
__aicore__ explicit PpMatmulEinSum(){};
|
||||
|
||||
__aicore__ __force_inline__ void Init(__gm__ uint8_t *__restrict__ a, __gm__ uint8_t *__restrict__ b,
|
||||
__gm__ uint8_t *__restrict__ c, __gm__ uint8_t *__restrict__ tiling_data)
|
||||
{
|
||||
gm_a.SetGlobalBuffer(reinterpret_cast<__gm__ InDtype *>(a));
|
||||
gm_b.SetGlobalBuffer(reinterpret_cast<__gm__ InDtype *>(b));
|
||||
gm_c.SetGlobalBuffer(reinterpret_cast<__gm__ OutDtype *>(c));
|
||||
auto gm_tiling_data = reinterpret_cast<__gm__ pp_matmul::PpMatmulTilingData *>(tiling_data);
|
||||
|
||||
batch_size = gm_tiling_data->opShape.batchSize;
|
||||
m = gm_tiling_data->opShape.m;
|
||||
k = gm_tiling_data->opShape.k;
|
||||
n = gm_tiling_data->opShape.n;
|
||||
m0 = gm_tiling_data->opShape.m0;
|
||||
k0 = gm_tiling_data->opShape.k0;
|
||||
n0 = gm_tiling_data->opShape.n0;
|
||||
tdim.m = gm_tiling_data->mLoop;
|
||||
tdim.k = gm_tiling_data->kLoop;
|
||||
tdim.n = gm_tiling_data->nLoop;
|
||||
core_loop = gm_tiling_data->coreLoop;
|
||||
swizzle_cnt = gm_tiling_data->swizzlCount;
|
||||
en_shuffle_k = gm_tiling_data->enShuffleK;
|
||||
|
||||
AsdopsBuffer<ArchType::ASCEND_V220> buf;
|
||||
l1_base_a = buf.template GetBuffer<BufferType::ASCEND_CB, InDtype>(0);
|
||||
l1_base_b = buf.template GetBuffer<BufferType::ASCEND_CB, InDtype>(
|
||||
RoundUp<uint64_t>(m0 * k0 * sizeof(InDtype), CONST_256UL));
|
||||
l0a_base = buf.template GetBuffer<BufferType::ASCEND_L0A, InDtype>(0);
|
||||
l0b_base = buf.template GetBuffer<BufferType::ASCEND_L0B, InDtype>(0);
|
||||
num_core = AscendC::GetBlockNum();
|
||||
core_idx = AscendC::GetBlockIdx();
|
||||
ping_flag = 1;
|
||||
}
|
||||
|
||||
__aicore__ __force_inline__ void GetBlockIdx(uint64_t index, MatCoord &tidx)
|
||||
{
|
||||
uint64_t in_batch_idx = index % (tdim.m * tdim.n);
|
||||
if constexpr (SwizzleDirect == 0) { // Zn
|
||||
uint64_t tile_block_loop = (tdim.m + swizzle_cnt - 1) / swizzle_cnt;
|
||||
uint64_t tile_block_idx = in_batch_idx / (swizzle_cnt * tdim.n);
|
||||
uint64_t in_tile_block_idx = in_batch_idx % (swizzle_cnt * tdim.n);
|
||||
|
||||
uint64_t n_row = swizzle_cnt;
|
||||
if (tile_block_idx == tile_block_loop - 1) {
|
||||
n_row = tdim.m - swizzle_cnt * tile_block_idx;
|
||||
}
|
||||
tidx.m = tile_block_idx * swizzle_cnt + in_tile_block_idx % n_row;
|
||||
tidx.n = in_tile_block_idx / n_row;
|
||||
if (tile_block_idx % 2 != 0) {
|
||||
tidx.n = tdim.n - tidx.n - 1;
|
||||
}
|
||||
} else if constexpr (SwizzleDirect == 1) { // Nz
|
||||
uint64_t tile_block_loop = (tdim.n + swizzle_cnt - 1) / swizzle_cnt;
|
||||
uint64_t tile_block_idx = in_batch_idx / (swizzle_cnt * tdim.m);
|
||||
uint64_t in_tile_block_idx = in_batch_idx % (swizzle_cnt * tdim.m);
|
||||
|
||||
uint64_t n_col = swizzle_cnt;
|
||||
if (tile_block_idx == tile_block_loop - 1) {
|
||||
n_col = tdim.n - swizzle_cnt * tile_block_idx;
|
||||
}
|
||||
tidx.m = in_tile_block_idx / n_col;
|
||||
tidx.n = tile_block_idx * swizzle_cnt + in_tile_block_idx % n_col;
|
||||
if (tile_block_idx % 2 != 0) {
|
||||
tidx.m = tdim.m - tidx.m - 1;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
__aicore__ __force_inline__ void Process()
|
||||
{
|
||||
set_flag(PIPE_MTE1, PIPE_MTE2, EVENT_ID0);
|
||||
set_flag(PIPE_MTE1, PIPE_MTE2, EVENT_ID1);
|
||||
set_flag(PIPE_MTE1, PIPE_MTE2, EVENT_ID2);
|
||||
set_flag(PIPE_MTE1, PIPE_MTE2, EVENT_ID3);
|
||||
set_flag(PIPE_FIX, PIPE_M, EVENT_ID0);
|
||||
set_flag(PIPE_M, PIPE_MTE1, EVENT_ID0);
|
||||
set_flag(PIPE_M, PIPE_MTE1, EVENT_ID1);
|
||||
|
||||
for (uint64_t loop_idx = core_idx; loop_idx < core_loop; loop_idx += num_core) {
|
||||
uint64_t batch_idx = loop_idx / tdim.n / tdim.m;
|
||||
MatCoord tidx{0};
|
||||
GetBlockIdx(loop_idx, tidx);
|
||||
uint64_t offset_a = 0, offset_b = 0, offset_a_next = 0, offset_b_next = 0;
|
||||
uint64_t offset_c = tidx.m * m0 * batch_size * n + batch_idx * n + tidx.n * n0;
|
||||
uint64_t m_actual = (tidx.m == (tdim.m - 1)) ? (m - tidx.m * m0) : m0;
|
||||
uint64_t n_actual = (tidx.n == (tdim.n - 1)) ? (n - tidx.n * n0) : n0;
|
||||
uint64_t m_round = RoundUp<uint64_t, CONST_16UL>(m_actual);
|
||||
uint64_t n_round = RoundUp<uint64_t, CONST_16UL>(n_actual);
|
||||
uint64_t mn_max = m_round > n_round ? m_round : n_round;
|
||||
uint64_t k_part_len = L0_PINGPONG_BUFFER_LEN / mn_max / CONST_16 * CONST_16;
|
||||
uint64_t shuffle_k = en_shuffle_k ? (core_idx % tdim.k) : 0;
|
||||
if (TA) {
|
||||
offset_a = shuffle_k * k0 * m * batch_size + batch_idx * m + tidx.m * m0;
|
||||
} else {
|
||||
offset_a = tidx.m * m0 * batch_size * k + batch_idx * k + shuffle_k * k0;
|
||||
}
|
||||
|
||||
if (TB) {
|
||||
if constexpr (FormatB != DataFormat::NZ) {
|
||||
offset_b = batch_idx * k * n + tidx.n * n0 * k + shuffle_k * k0;
|
||||
} else {
|
||||
offset_b = batch_idx * RoundUp<uint64_t, CONST_16UL>(k) * RoundUp<uint64_t, CONST_16UL>(n) +
|
||||
tidx.n * n0 * BLOCK_SIZE_16 + shuffle_k * k0 * RoundUp<uint64_t, CONST_16UL>(n);
|
||||
}
|
||||
} else {
|
||||
if constexpr (FormatB != DataFormat::NZ) {
|
||||
offset_b = batch_idx * k * n + shuffle_k * k0 * n + tidx.n * n0;
|
||||
} else {
|
||||
offset_b = batch_idx * RoundUp<uint64_t, CONST_16UL>(k) * RoundUp<uint64_t, CONST_16UL>(n) +
|
||||
shuffle_k * k0 * BLOCK_SIZE_16 + tidx.n * n0 * RoundUp<uint64_t, CONST_16UL>(k);
|
||||
}
|
||||
}
|
||||
|
||||
uint64_t k_actual = (shuffle_k == tdim.k - 1) ? k - shuffle_k * k0 : k0;
|
||||
uint64_t k_round = (k_actual + CONST_16 - 1) / CONST_16 * CONST_16;
|
||||
|
||||
LocalTensor l1_buf_a = ping_flag ? l1_base_a : l1_base_a[L1_PINGPONG_BUFFER_LEN];
|
||||
LocalTensor l1_buf_b = ping_flag ? l1_base_b : l1_base_b[L1_PINGPONG_BUFFER_LEN];
|
||||
LocalTensor l0a_buf = ping_flag ? l0a_base : l0a_base[L0_PINGPONG_BUFFER_LEN];
|
||||
LocalTensor l0b_buf = ping_flag ? l0b_base : l0b_base[L0_PINGPONG_BUFFER_LEN];
|
||||
event_t event_id = ping_flag ? EVENT_ID0 : EVENT_ID1;
|
||||
|
||||
if (loop_idx == core_idx) {
|
||||
WAIT_FLAG(MTE1, MTE2, event_id);
|
||||
// *** load matrix A to L1
|
||||
if ((m == 1) || (m_actual == 1 && !TA)) {
|
||||
CopyGmToCbuf<DataFormat::ND, DataFormat::ND>(l1_buf_a, // dst
|
||||
gm_a[offset_a], // src
|
||||
1, // nTileActual
|
||||
16, // nTileCeil
|
||||
1, // nVal
|
||||
k_actual, // kTileActual
|
||||
k_round, // kTileCeil
|
||||
k); // dVal
|
||||
} else {
|
||||
if (TA) {
|
||||
CopyGmToCbuf<DataFormat::ND, DataFormat::NZ>(l1_buf_a, // dst
|
||||
gm_a[offset_a], // src
|
||||
k_actual, // nTileActual
|
||||
k_round, // nTileCeil
|
||||
k, // nVal
|
||||
m_actual, // dTileActual
|
||||
m_round, // dTileCeil
|
||||
m * batch_size); // dVal
|
||||
} else {
|
||||
CopyGmToCbuf<DataFormat::ND, DataFormat::NZ>(l1_buf_a, // dst
|
||||
gm_a[offset_a], // src
|
||||
m_actual, // nTileActual
|
||||
m_round, // nTileCeil
|
||||
m, // nVal
|
||||
k_actual, // dTileActual
|
||||
k_round, // dTileCeil
|
||||
k * batch_size); // dVal
|
||||
}
|
||||
}
|
||||
SET_FLAG(MTE2, MTE1, event_id);
|
||||
// *** load matrix B to L1
|
||||
wait_flag(PIPE_MTE1, PIPE_MTE2, event_id + 2);
|
||||
if constexpr (FormatB != DataFormat::NZ) {
|
||||
if (TB) {
|
||||
CopyGmToCbuf<DataFormat::ND, DataFormat::NZ>(l1_buf_b, // dst
|
||||
gm_b[offset_b], // src
|
||||
n_actual, // nTileActual
|
||||
n_round, // nTileCeil
|
||||
n, // nVal
|
||||
k_actual, // dTileActual
|
||||
k_round, // dTileCeil
|
||||
k); // dVal
|
||||
} else {
|
||||
CopyGmToCbuf<DataFormat::ND, DataFormat::NZ>(l1_buf_b, // dst
|
||||
gm_b[offset_b], // src
|
||||
k_actual, // nTileActual
|
||||
k_round, // nTileCeil
|
||||
k, // nVal
|
||||
n_actual, // dTileActual
|
||||
n_round, // dTileCeil
|
||||
n); // dVal
|
||||
}
|
||||
} else {
|
||||
if (TB) {
|
||||
CopyGmToCbuf<DataFormat::NZ, DataFormat::NZ>(l1_buf_b, // dst
|
||||
gm_b[offset_b], // src
|
||||
n_actual, // nTileActual
|
||||
n_round, // nTileCeil
|
||||
RoundUp<uint64_t, CONST_16UL>(n), // nVal
|
||||
k_actual, // dTileActual
|
||||
k_round, // dTileCeil
|
||||
RoundUp<uint64_t, CONST_16UL>(k)); // dVal
|
||||
} else {
|
||||
CopyGmToCbuf<DataFormat::NZ, DataFormat::NZ>(l1_buf_b, // dst
|
||||
gm_b[offset_b], // src
|
||||
k_actual, // nTileActual
|
||||
k_round, // nTileCeil
|
||||
RoundUp<uint64_t, CONST_16UL>(k), // nVal
|
||||
n_actual, // dTileActual
|
||||
n_round, // dTileCeil
|
||||
RoundUp<uint64_t, CONST_16UL>(n)); // dVal
|
||||
}
|
||||
}
|
||||
SET_FLAG(MTE2, MTE1, event_id + 2);
|
||||
}
|
||||
|
||||
for (tidx.k = 0; tidx.k < tdim.k; ++tidx.k) {
|
||||
shuffle_k = en_shuffle_k ? (tidx.k + core_idx) % tdim.k : tidx.k;
|
||||
uint64_t k_actual = (shuffle_k == (tdim.k - 1)) ? (k - shuffle_k * k0) : k0;
|
||||
uint64_t k_round = (k_actual + CONST_16 - 1) / CONST_16 * CONST_16;
|
||||
fdim.k = (k_actual + k_part_len - 1) / k_part_len;
|
||||
|
||||
LocalTensor l1_buf_a = ping_flag ? l1_base_a : l1_base_a[L1_PINGPONG_BUFFER_LEN];
|
||||
LocalTensor l1_buf_b = ping_flag ? l1_base_b : l1_base_b[L1_PINGPONG_BUFFER_LEN];
|
||||
auto event_id = ping_flag ? EVENT_ID0 : EVENT_ID1;
|
||||
|
||||
if (tidx.k < tdim.k - 1) {
|
||||
uint64_t shuffle_k_next = en_shuffle_k ? (core_idx + tidx.k + 1) % tdim.k : (tidx.k + 1);
|
||||
if (TA) {
|
||||
offset_a_next = shuffle_k_next * k0 * m * batch_size + batch_idx * m + tidx.m * m0;
|
||||
} else {
|
||||
offset_a_next = tidx.m * m0 * batch_size * k + batch_idx * k + shuffle_k_next * k0;
|
||||
}
|
||||
|
||||
if (TB) {
|
||||
if constexpr (FormatB != DataFormat::NZ) {
|
||||
offset_b_next = batch_idx * k * n + tidx.n * n0 * k + shuffle_k_next * k0;
|
||||
} else {
|
||||
offset_b_next =
|
||||
batch_idx * RoundUp<uint64_t, CONST_16UL>(k) * RoundUp<uint64_t, CONST_16UL>(n) +
|
||||
tidx.n * n0 * BLOCK_SIZE_16 + shuffle_k_next * k0 * RoundUp<uint64_t, CONST_16UL>(n);
|
||||
}
|
||||
} else {
|
||||
if constexpr (FormatB != DataFormat::NZ) {
|
||||
offset_b_next = batch_idx * k * n + shuffle_k_next * k0 * n + tidx.n * n0;
|
||||
} else {
|
||||
offset_b_next =
|
||||
batch_idx * RoundUp<uint64_t, CONST_16UL>(k) * RoundUp<uint64_t, CONST_16UL>(n) +
|
||||
shuffle_k_next * k0 * BLOCK_SIZE_16 + tidx.n * n0 * RoundUp<uint64_t, CONST_16UL>(k);
|
||||
}
|
||||
}
|
||||
|
||||
uint64_t k_actual_next = (shuffle_k_next == (tdim.k - 1)) ? (k - shuffle_k_next * k0) : k0;
|
||||
uint64_t k_round_next = (k_actual_next + CONST_16 - 1) / CONST_16 * CONST_16;
|
||||
|
||||
LocalTensor l1_buf_a_next = (1 - ping_flag) ? l1_base_a : l1_base_a[L1_PINGPONG_BUFFER_LEN];
|
||||
LocalTensor l1_buf_b_next = (1 - ping_flag) ? l1_base_b : l1_base_b[L1_PINGPONG_BUFFER_LEN];
|
||||
event_t event_id_next = (1 - ping_flag) ? EVENT_ID0 : EVENT_ID1;
|
||||
|
||||
WAIT_FLAG(MTE1, MTE2, event_id_next);
|
||||
// *** load matrix A to L1
|
||||
if ((m == 1) || (m_actual == 1 && !TA)) {
|
||||
CopyGmToCbuf<DataFormat::ND, DataFormat::ND>(l1_buf_a_next, // dst
|
||||
gm_a[offset_a_next], // src
|
||||
m_actual, // nTileActual
|
||||
m_round, // nTileCeil
|
||||
m, // nVal
|
||||
k_actual_next, // kTileActual
|
||||
k_round_next, // kTileCeil
|
||||
k); // dVal
|
||||
} else {
|
||||
if (TA) {
|
||||
CopyGmToCbuf<DataFormat::ND, DataFormat::NZ>(l1_buf_a_next, // dst
|
||||
gm_a[offset_a_next], // src
|
||||
k_actual_next, // nTileActual
|
||||
k_round_next, // nTileCeil
|
||||
k, // nVal
|
||||
m_actual, // dTileActual
|
||||
m_round, // dTileCeil
|
||||
m * batch_size); // dVal
|
||||
} else {
|
||||
CopyGmToCbuf<DataFormat::ND, DataFormat::NZ>(l1_buf_a_next, // dst
|
||||
gm_a[offset_a_next], // src
|
||||
m_actual, // nTileActual
|
||||
m_round, // nTileCeil
|
||||
m, // nVal
|
||||
k_actual_next, // dTileActual
|
||||
k_round_next, // dTileCeil
|
||||
k * batch_size); // dVal
|
||||
}
|
||||
}
|
||||
SET_FLAG(MTE2, MTE1, event_id_next);
|
||||
|
||||
// *** load matrix B to L1
|
||||
wait_flag(PIPE_MTE1, PIPE_MTE2, event_id_next + 2);
|
||||
if constexpr (FormatB != DataFormat::NZ) {
|
||||
if (TB) {
|
||||
CopyGmToCbuf<DataFormat::ND, DataFormat::NZ>(l1_buf_b_next, // dst
|
||||
gm_b[offset_b_next], // src
|
||||
n_actual, // nTileActual
|
||||
n_round, // nTileCeil
|
||||
n, // nVal
|
||||
k_actual_next, // dTileActual
|
||||
k_round_next, // dTileCeil
|
||||
k); // dVal
|
||||
} else {
|
||||
CopyGmToCbuf<DataFormat::ND, DataFormat::NZ>(l1_buf_b_next, // dst
|
||||
gm_b[offset_b_next], // src
|
||||
k_actual_next, // nTileActual
|
||||
k_round_next, // nTileCeil
|
||||
k, // nVal
|
||||
n_actual, // dTileActual
|
||||
n_round, // dTileCeil
|
||||
n); // dVal
|
||||
}
|
||||
} else {
|
||||
if (TB) {
|
||||
CopyGmToCbuf<DataFormat::NZ, DataFormat::NZ>(l1_buf_b_next, // dst
|
||||
gm_b[offset_b_next], // src
|
||||
n_actual, // nTileActual
|
||||
n_round, // nTileCeil
|
||||
RoundUp<uint64_t, CONST_16UL>(n), // nVal
|
||||
k_actual_next, // dTileActual
|
||||
k_round_next, // dTileCeil
|
||||
RoundUp<uint64_t, CONST_16UL>(k)); // dVal
|
||||
} else {
|
||||
CopyGmToCbuf<DataFormat::NZ, DataFormat::NZ>(l1_buf_b_next, // dst
|
||||
gm_b[offset_b_next], // src
|
||||
k_actual_next, // nTileActual
|
||||
k_round_next, // nTileCeil
|
||||
RoundUp<uint64_t, CONST_16UL>(k), // nVal
|
||||
n_actual, // dTileActual
|
||||
n_round, // dTileCeil
|
||||
RoundUp<uint64_t, CONST_16UL>(n)); // dVal
|
||||
}
|
||||
}
|
||||
SET_FLAG(MTE2, MTE1, event_id_next + 2);
|
||||
}
|
||||
|
||||
if (tidx.k == tdim.k - 1 && loop_idx + num_core < core_loop) {
|
||||
uint64_t b_idx_next = (loop_idx + num_core) / tdim.n / tdim.m;
|
||||
MatCoord tidx{0};
|
||||
GetBlockIdx(loop_idx + num_core, tidx);
|
||||
uint64_t shuffle_k_next = en_shuffle_k ? (core_idx % tdim.k) : 0;
|
||||
uint64_t m_actual_next = (tidx.m == (tdim.m - 1)) ? (m - tidx.m * m0) : m0;
|
||||
uint64_t n_actual_next = (tidx.n == (tdim.n - 1)) ? (n - tidx.n * n0) : n0;
|
||||
uint64_t m_round_next = (m_actual_next + CONST_16 - 1) / CONST_16 * CONST_16;
|
||||
uint64_t n_round_next = (n_actual_next + CONST_16 - 1) / CONST_16 * CONST_16;
|
||||
uint64_t k_actual_next = (shuffle_k_next == (tdim.k - 1)) ? (k - shuffle_k_next * k0) : k0;
|
||||
uint64_t k_round_next = (k_actual_next + CONST_16 - 1) / CONST_16 * CONST_16;
|
||||
if (TA) {
|
||||
offset_a_next = shuffle_k_next * k0 * m * batch_size + b_idx_next * m + tidx.m * m0;
|
||||
} else {
|
||||
offset_a_next = tidx.m * m0 * batch_size * k + b_idx_next * k + shuffle_k_next * k0;
|
||||
}
|
||||
|
||||
if (TB) {
|
||||
if constexpr (FormatB != DataFormat::NZ) {
|
||||
offset_b_next = b_idx_next * k * n + tidx.n * n0 * k + shuffle_k_next * k0;
|
||||
} else {
|
||||
offset_b_next =
|
||||
b_idx_next * RoundUp<uint64_t, CONST_16UL>(k) * RoundUp<uint64_t, CONST_16UL>(n) +
|
||||
tidx.n * n0 * BLOCK_SIZE_16 + shuffle_k_next * k0 * RoundUp<uint64_t, CONST_16UL>(n);
|
||||
}
|
||||
} else {
|
||||
if constexpr (FormatB != DataFormat::NZ) {
|
||||
offset_b_next = b_idx_next * k * n + shuffle_k_next * k0 * n + tidx.n * n0;
|
||||
} else {
|
||||
offset_b_next =
|
||||
b_idx_next * RoundUp<uint64_t, CONST_16UL>(k) * RoundUp<uint64_t, CONST_16UL>(n) +
|
||||
shuffle_k_next * k0 * BLOCK_SIZE_16 + tidx.n * n0 * RoundUp<uint64_t, CONST_16UL>(k);
|
||||
}
|
||||
}
|
||||
|
||||
LocalTensor l1_buf_a_next = (1 - ping_flag) ? l1_base_a : l1_base_a[L1_PINGPONG_BUFFER_LEN];
|
||||
LocalTensor l1_buf_b_next = (1 - ping_flag) ? l1_base_b : l1_base_b[L1_PINGPONG_BUFFER_LEN];
|
||||
event_t event_id_next = (1 - ping_flag) ? EVENT_ID0 : EVENT_ID1;
|
||||
|
||||
WAIT_FLAG(MTE1, MTE2, event_id_next);
|
||||
// *** load matrix A to L1
|
||||
if (m == 1 || m_actual_next == 1 && !TA) {
|
||||
CopyGmToCbuf<DataFormat::ND, DataFormat::ND>(l1_buf_a_next, // dst
|
||||
gm_a[offset_a_next], // src
|
||||
m_actual_next, // nTileActual
|
||||
m_round_next, // nTileCeil
|
||||
m, // nVal
|
||||
k_actual_next, // kTileActual
|
||||
k_round_next, // kTileCeil
|
||||
k); // dVal
|
||||
} else {
|
||||
if (TA) {
|
||||
CopyGmToCbuf<DataFormat::ND, DataFormat::NZ>(l1_buf_a_next, // dst
|
||||
gm_a[offset_a_next], // src
|
||||
k_actual_next, // nTileActual
|
||||
k_round_next, // nTileCeil
|
||||
k, // nVal
|
||||
m_actual_next, // dTileActual
|
||||
m_round_next, // dTileCeil
|
||||
m * batch_size); // dVal
|
||||
} else {
|
||||
CopyGmToCbuf<DataFormat::ND, DataFormat::NZ>(l1_buf_a_next, // dst
|
||||
gm_a[offset_a_next], // src
|
||||
m_actual_next, // nTileActual
|
||||
m_round_next, // nTileCeil
|
||||
m, // nVal
|
||||
k_actual_next, // dTileActual
|
||||
k_round_next, // dTileCeil
|
||||
k * batch_size); // dVal
|
||||
}
|
||||
}
|
||||
SET_FLAG(MTE2, MTE1, event_id_next);
|
||||
|
||||
// *** load matrix B to L1
|
||||
wait_flag(PIPE_MTE1, PIPE_MTE2, event_id_next + 2);
|
||||
if constexpr (FormatB != DataFormat::NZ) {
|
||||
if (TB) {
|
||||
CopyGmToCbuf<DataFormat::ND, DataFormat::NZ>(l1_buf_b_next, // dst
|
||||
gm_b[offset_b_next], // src
|
||||
n_actual_next, // nTileActual
|
||||
n_round_next, // nTileCeil
|
||||
n, // nVal
|
||||
k_actual_next, // dTileActual
|
||||
k_round_next, // dTileCeil
|
||||
k); // dVal
|
||||
} else {
|
||||
CopyGmToCbuf<DataFormat::ND, DataFormat::NZ>(l1_buf_b_next, // dst
|
||||
gm_b[offset_b_next], // src
|
||||
k_actual_next, // nTileActual
|
||||
k_round_next, // nTileCeil
|
||||
k, // nVal
|
||||
n_actual_next, // dTileActual
|
||||
n_round_next, // dTileCeil
|
||||
n); // dVal
|
||||
}
|
||||
} else {
|
||||
if (TB) {
|
||||
CopyGmToCbuf<DataFormat::NZ, DataFormat::NZ>(l1_buf_b_next, // dst
|
||||
gm_b[offset_b_next], // src
|
||||
n_actual_next, // nTileActual
|
||||
n_round_next, // nTileCeil
|
||||
RoundUp<uint64_t, CONST_16UL>(n), // nVal
|
||||
k_actual_next, // dTileActual
|
||||
k_round_next, // dTileCeil
|
||||
RoundUp<uint64_t, CONST_16UL>(k)); // dVal
|
||||
} else {
|
||||
CopyGmToCbuf<DataFormat::NZ, DataFormat::NZ>(l1_buf_b_next, // dst
|
||||
gm_b[offset_b_next], // src
|
||||
k_actual_next, // nTileActual
|
||||
k_round_next, // nTileCeil
|
||||
RoundUp<uint64_t, CONST_16UL>(k), // nVal
|
||||
n_actual_next, // dTileActual
|
||||
n_round_next, // dTileCeil
|
||||
RoundUp<uint64_t, CONST_16UL>(n)); // dVal
|
||||
}
|
||||
}
|
||||
SET_FLAG(MTE2, MTE1, event_id_next + 2);
|
||||
}
|
||||
|
||||
MatCoord fidx{0};
|
||||
for (fidx.k = 0; fidx.k < fdim.k; ++fidx.k) {
|
||||
uint32_t k0_round = (fidx.k < fdim.k - 1) ? k_part_len : k_round - fidx.k * k_part_len;
|
||||
uint32_t k0_actual = (fidx.k < fdim.k - 1) ? k_part_len : k_actual - fidx.k * k_part_len;
|
||||
|
||||
auto mte1_mad_ping_flag = 1 - fidx.k % 2;
|
||||
auto mte1_mad_event_id = mte1_mad_ping_flag ? EVENT_ID0 : EVENT_ID1;
|
||||
auto l0a_buf = l0a_base[(fidx.k % 2) * L0_PINGPONG_BUFFER_LEN];
|
||||
auto l0b_buf = l0b_base[(fidx.k % 2) * L0_PINGPONG_BUFFER_LEN];
|
||||
|
||||
// *** load matrix A from L1 to L0A
|
||||
if (fidx.k == 0) {
|
||||
WAIT_FLAG(MTE2, MTE1, event_id);
|
||||
}
|
||||
WAIT_FLAG(M, MTE1, mte1_mad_event_id);
|
||||
if ((m == 1) || (m_actual == 1 && !TA)) {
|
||||
l1_to_l0_a<ArchType::ASCEND_V220, InDtype, false, DataFormat::VECTOR, DataFormat::VECTOR>(
|
||||
l0a_buf, // dst
|
||||
l1_buf_a[fidx.k * k_part_len], // src
|
||||
0, // mTileCeil
|
||||
CeilDiv<CONST_256>(k0_round), // kPartCeil
|
||||
0, // mSrcStride
|
||||
1, // kSrcStride
|
||||
0, // mDstStride
|
||||
0); // kDstStride
|
||||
} else {
|
||||
if (TA) {
|
||||
LoadCbufToCa(l0a_buf, // l0Tensor
|
||||
l1_buf_a[fidx.k * k_part_len * CONST_16], // l1Tensor
|
||||
m_round, // mTileCeil
|
||||
k0_round, // kPartCeil
|
||||
k_round / CONST_16, // mSrcStride
|
||||
1, // kSrcStride
|
||||
k0_round / CONST_16, // mDstStride
|
||||
1); // kDstStride
|
||||
} else {
|
||||
LoadCbufToCa(l0a_buf, // l0Tensor
|
||||
l1_buf_a[fidx.k * k_part_len * m_round], // l1Tensor
|
||||
m_round, // mTileCeil
|
||||
k0_round, // kPartCeil
|
||||
1, // mSrcStride
|
||||
m_round / CONST_16, // kSrcStride
|
||||
k0_round / CONST_16, // mDstStride
|
||||
1); // kDstStride
|
||||
}
|
||||
}
|
||||
if (fidx.k == fdim.k - 1) {
|
||||
SET_FLAG(MTE1, MTE2, event_id);
|
||||
}
|
||||
|
||||
// *** load matrix B from L1 to L0B
|
||||
if (fidx.k == 0) {
|
||||
WAIT_FLAG(MTE2, MTE1, event_id + 2);
|
||||
}
|
||||
if (TB) {
|
||||
LoadCbufToCb(l0b_buf, // l0Tensor
|
||||
l1_buf_b[fidx.k * k_part_len * n_round], // l1Tensor
|
||||
n_round, // nTileCeil
|
||||
k0_round, // kPartCeil
|
||||
1, // nSrcStride
|
||||
n_round / CONST_16, // kSrcStride
|
||||
1, // nDstStride
|
||||
k0_round / CONST_16); // kDstStride
|
||||
} else {
|
||||
LoadCbufToCb(l0b_buf, // l0Tensor
|
||||
l1_buf_b[fidx.k * k_part_len * CONST_16], // l1Tensor
|
||||
n_round, // nTileCeil
|
||||
k0_round, // kPartCeil
|
||||
k_round / CONST_16, // nSrcStride
|
||||
1, // kSrcStride
|
||||
1, // nDstStride
|
||||
n_round / CONST_16); // kDstStride
|
||||
}
|
||||
if (fidx.k == fdim.k - 1) {
|
||||
SET_FLAG(MTE1, MTE2, event_id + 2);
|
||||
}
|
||||
|
||||
SET_FLAG(MTE1, M, mte1_mad_event_id);
|
||||
WAIT_FLAG(MTE1, M, mte1_mad_event_id);
|
||||
|
||||
bool init_c = (tidx.k == 0 && fidx.k == 0);
|
||||
if (init_c) {
|
||||
WAIT_FLAG(FIX, M, EVENT_ID0);
|
||||
}
|
||||
|
||||
if (m != 1 && m_actual == 1 && TA) {
|
||||
Mad(l0c_buf, // c
|
||||
l0a_buf, // a
|
||||
l0b_buf, // b
|
||||
CONST_16, // mTileActual
|
||||
n_actual, // nTileActual
|
||||
k0_actual, // kTileActual
|
||||
init_c); // initC
|
||||
} else {
|
||||
Mad(l0c_buf, // c
|
||||
l0a_buf, // a
|
||||
l0b_buf, // b
|
||||
m_actual, // mTileActual
|
||||
n_actual, // nTileActual
|
||||
k0_actual, // kTileActual
|
||||
init_c); // initC
|
||||
}
|
||||
|
||||
PIPE_BARRIER(M);
|
||||
SET_FLAG(M, MTE1, mte1_mad_event_id);
|
||||
}
|
||||
|
||||
ping_flag = 1 - ping_flag;
|
||||
}
|
||||
|
||||
SET_FLAG(M, FIX, EVENT_ID0);
|
||||
WAIT_FLAG(M, FIX, EVENT_ID0);
|
||||
|
||||
// copy from L0C to gm
|
||||
CopyCcToGm(gm_c[offset_c], // dst
|
||||
l0c_buf, // src
|
||||
m_actual, // mTileActual
|
||||
n_actual, // nTileActual
|
||||
m_round, // mTileCeil
|
||||
n * batch_size); // nActual
|
||||
SET_FLAG(FIX, M, EVENT_ID0);
|
||||
}
|
||||
|
||||
WAIT_FLAG(M, MTE1, EVENT_ID0);
|
||||
WAIT_FLAG(M, MTE1, EVENT_ID1);
|
||||
WAIT_FLAG(MTE1, MTE2, EVENT_ID0);
|
||||
WAIT_FLAG(MTE1, MTE2, EVENT_ID1);
|
||||
WAIT_FLAG(MTE1, MTE2, EVENT_ID2);
|
||||
WAIT_FLAG(MTE1, MTE2, EVENT_ID3);
|
||||
WAIT_FLAG(FIX, M, EVENT_ID0);
|
||||
PIPE_BARRIER(ALL);
|
||||
}
|
||||
|
||||
private:
    AscendC::GlobalTensor<InDtype> gm_a;
    AscendC::GlobalTensor<InDtype> gm_b;
    AscendC::GlobalTensor<OutDtype> gm_c;
    AscendC::LocalTensor<InDtype> l1_base_a;
    AscendC::LocalTensor<InDtype> l1_base_b;
    AscendC::LocalTensor<InDtype> l0a_base;
    AscendC::LocalTensor<InDtype> l0b_base;
    AscendC::LocalTensor<float> l0c_buf;

    uint32_t num_core{0};
    uint32_t batch_size{0};
    uint32_t m{0};
    uint32_t k{0};
    uint32_t n{0};
    uint32_t m0{0};
    uint32_t k0{0};
    uint32_t n0{0};
    MatCoord tdim{0};
    MatCoord fdim{0};
    uint32_t core_loop{0};
    uint32_t swizzle_cnt{1};
    uint32_t core_idx{0};
    uint32_t en_shuffle_k{0};
    uint32_t ping_flag{0};
};

extern "C" __global__ __aicore__ void batch_matmul_transpose(GM_ADDR gm_a, GM_ADDR gm_b, GM_ADDR gm_c,
                                                             GM_ADDR gm_tiling_data)
{
    KERNEL_TASK_TYPE_DEFAULT(KERNEL_TYPE_AIC_ONLY);
    PpMatmulEinSum<0, false, false, half, half, DataFormat::ND>
        einsum_0_n_fp16_nd; // swizzleDir[0] transA[0] transB[0] DtypeA[001] DtypeB[001] DtypeC[001] DataFormatA[0] DataFormatB[0]
    PpMatmulEinSum<1, false, false, half, half, DataFormat::ND>
        einsum_1_n_fp16_nd; // swizzleDir[1] transA[0] transB[0] DtypeA[001] DtypeB[001] DtypeC[001] DataFormatA[0] DataFormatB[0]
    PpMatmulEinSum<0, false, true, half, half, DataFormat::ND>
        einsum_0_t_fp16_nd; // swizzleDir[0] transA[0] transB[1] DtypeA[001] DtypeB[001] DtypeC[001] DataFormatA[0] DataFormatB[0]
    PpMatmulEinSum<1, false, true, half, half, DataFormat::ND>
        einsum_1_t_fp16_nd; // swizzleDir[1] transA[0] transB[1] DtypeA[001] DtypeB[001] DtypeC[001] DataFormatA[0] DataFormatB[0]
    PpMatmulEinSum<0, false, false, __bf16, __bf16, DataFormat::ND>
        einsum_0_n_bf16_nd; // swizzleDir[0] transA[0] transB[0] DtypeA[010] DtypeB[010] DtypeC[010] DataFormatA[0] DataFormatB[0]
    PpMatmulEinSum<1, false, false, __bf16, __bf16, DataFormat::ND>
        einsum_1_n_bf16_nd; // swizzleDir[1] transA[0] transB[0] DtypeA[010] DtypeB[010] DtypeC[010] DataFormatA[0] DataFormatB[0]
    PpMatmulEinSum<0, false, true, __bf16, __bf16, DataFormat::ND>
        einsum_0_t_bf16_nd; // swizzleDir[0] transA[0] transB[1] DtypeA[010] DtypeB[010] DtypeC[010] DataFormatA[0] DataFormatB[0]
    PpMatmulEinSum<1, false, true, __bf16, __bf16, DataFormat::ND>
        einsum_1_t_bf16_nd; // swizzleDir[1] transA[0] transB[1] DtypeA[010] DtypeB[010] DtypeC[010] DataFormatA[0] DataFormatB[0]

    PpMatmulEinSum<0, false, false, half, half, DataFormat::NZ>
        einsum_0_n_fp16_nz; // swizzleDir[0] transA[0] transB[0] DtypeA[001] DtypeB[001] DtypeC[001] DataFormatA[0] DataFormatB[1]
    PpMatmulEinSum<1, false, false, half, half, DataFormat::NZ>
        einsum_1_n_fp16_nz; // swizzleDir[1] transA[0] transB[0] DtypeA[001] DtypeB[001] DtypeC[001] DataFormatA[0] DataFormatB[1]
    PpMatmulEinSum<0, false, true, half, half, DataFormat::NZ>
        einsum_0_t_fp16_nz; // swizzleDir[0] transA[0] transB[1] DtypeA[001] DtypeB[001] DtypeC[001] DataFormatA[0] DataFormatB[1]
    PpMatmulEinSum<1, false, true, half, half, DataFormat::NZ>
        einsum_1_t_fp16_nz; // swizzleDir[1] transA[0] transB[1] DtypeA[001] DtypeB[001] DtypeC[001] DataFormatA[0] DataFormatB[1]
    PpMatmulEinSum<0, false, false, __bf16, __bf16, DataFormat::NZ>
        einsum_0_n_bf16_nz; // swizzleDir[0] transA[0] transB[0] DtypeA[010] DtypeB[010] DtypeC[010] DataFormatA[0] DataFormatB[1]
    PpMatmulEinSum<1, false, false, __bf16, __bf16, DataFormat::NZ>
        einsum_1_n_bf16_nz; // swizzleDir[1] transA[0] transB[0] DtypeA[010] DtypeB[010] DtypeC[010] DataFormatA[0] DataFormatB[1]
    PpMatmulEinSum<0, false, true, __bf16, __bf16, DataFormat::NZ>
        einsum_0_t_bf16_nz; // swizzleDir[0] transA[0] transB[1] DtypeA[010] DtypeB[010] DtypeC[010] DataFormatA[0] DataFormatB[1]
    PpMatmulEinSum<1, false, true, __bf16, __bf16, DataFormat::NZ>
        einsum_1_t_bf16_nz; // swizzleDir[1] transA[0] transB[1] DtypeA[010] DtypeB[010] DtypeC[010] DataFormatA[0] DataFormatB[1]

    SetPadding<uint64_t>((uint64_t)0);
    SetNdpara(1, 0, 0);
    SetAtomicnone();

    // get tiling args
    auto tiling_data = reinterpret_cast<__gm__ pp_matmul::PpMatmulTilingData *>(gm_tiling_data);
    uint32_t masked_key = tiling_data->tilingKey >> 2;

    switch (masked_key) {
        case 0b00000100100100:
        case 0b01000100100100:
            einsum_0_n_fp16_nd.Init(gm_a, gm_b, gm_c, gm_tiling_data);
            einsum_0_n_fp16_nd.Process();
            break;
        case 0b00100100100100:
        case 0b01100100100100:
            einsum_0_t_fp16_nd.Init(gm_a, gm_b, gm_c, gm_tiling_data);
            einsum_0_t_fp16_nd.Process();
            break;
        case 0b10000100100100:
        case 0b11000100100100:
            einsum_1_n_fp16_nd.Init(gm_a, gm_b, gm_c, gm_tiling_data);
            einsum_1_n_fp16_nd.Process();
            break;
        case 0b10100100100100:
        case 0b11100100100100:
            einsum_1_t_fp16_nd.Init(gm_a, gm_b, gm_c, gm_tiling_data);
            einsum_1_t_fp16_nd.Process();
            break;
        case 0b00001001001000:
        case 0b01001001001000:
            einsum_0_n_bf16_nd.Init(gm_a, gm_b, gm_c, gm_tiling_data);
            einsum_0_n_bf16_nd.Process();
            break;
        case 0b00101001001000:
        case 0b01101001001000:
            einsum_0_t_bf16_nd.Init(gm_a, gm_b, gm_c, gm_tiling_data);
            einsum_0_t_bf16_nd.Process();
            break;
        case 0b10001001001000:
        case 0b11001001001000:
            einsum_1_n_bf16_nd.Init(gm_a, gm_b, gm_c, gm_tiling_data);
            einsum_1_n_bf16_nd.Process();
            break;
        case 0b10101001001000:
        case 0b11101001001000:
            einsum_1_t_bf16_nd.Init(gm_a, gm_b, gm_c, gm_tiling_data);
            einsum_1_t_bf16_nd.Process();
            break;

        case 0b00000100100101:
        case 0b01000100100101:
            einsum_0_n_fp16_nz.Init(gm_a, gm_b, gm_c, gm_tiling_data);
            einsum_0_n_fp16_nz.Process();
            break;
        case 0b00100100100101:
        case 0b01100100100101:
            einsum_0_t_fp16_nz.Init(gm_a, gm_b, gm_c, gm_tiling_data);
            einsum_0_t_fp16_nz.Process();
            break;
        case 0b10000100100101:
        case 0b11000100100101:
            einsum_1_n_fp16_nz.Init(gm_a, gm_b, gm_c, gm_tiling_data);
            einsum_1_n_fp16_nz.Process();
            break;
        case 0b10100100100101:
        case 0b11100100100101:
            einsum_1_t_fp16_nz.Init(gm_a, gm_b, gm_c, gm_tiling_data);
            einsum_1_t_fp16_nz.Process();
            break;
        case 0b00001001001001:
        case 0b01001001001001:
            einsum_0_n_bf16_nz.Init(gm_a, gm_b, gm_c, gm_tiling_data);
            einsum_0_n_bf16_nz.Process();
            break;
        case 0b00101001001001:
        case 0b01101001001001:
            einsum_0_t_bf16_nz.Init(gm_a, gm_b, gm_c, gm_tiling_data);
            einsum_0_t_bf16_nz.Process();
            break;
        case 0b10001001001001:
        case 0b11001001001001:
            einsum_1_n_bf16_nz.Init(gm_a, gm_b, gm_c, gm_tiling_data);
            einsum_1_n_bf16_nz.Process();
            break;
        case 0b10101001001001:
        case 0b11101001001001:
            einsum_1_t_bf16_nz.Init(gm_a, gm_b, gm_c, gm_tiling_data);
            einsum_1_t_bf16_nz.Process();
            break;
        default:
            break;
    }
}
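For reference, every `Process()` dispatch above matches a pair of case labels that differ only in bit 12 of the masked key, i.e. that bit is a don't-care for kernel selection. Below is a minimal sketch of how such a key could be composed; the field layout is inferred from the instantiation comments and the `tilingKey >> 2` masking, so the names and bit positions are assumptions rather than the op's authoritative encoding.

    # Sketch only: bit positions inferred from the case labels above.
    def make_masked_key(swizzle_dir: int, trans_b: bool, bf16: bool, nz: bool) -> int:
        dtype = 0b010 if bf16 else 0b001    # DtypeA/DtypeB/DtypeC field value
        key = 1 if nz else 0                # bit 0: DataFormatB (0 = ND, 1 = NZ)
        key |= dtype << 2                   # bits 2-4: DtypeC
        key |= dtype << 5                   # bits 5-7: DtypeB
        key |= dtype << 8                   # bits 8-10: DtypeA
        key |= (1 << 11) if trans_b else 0  # bit 11: transB
        key |= swizzle_dir << 13            # bit 13: swizzleDir
        return key

    assert make_masked_key(0, False, False, False) == 0b00000100100100  # einsum_0_n_fp16_nd
    assert make_masked_key(1, True, True, True) == 0b10101001001001     # einsum_1_t_bf16_nz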
namespace vllm_ascend {

extern void batch_matmul_transpose_impl(
    void* stream,
    void* gm_a,
    void* gm_b,
    void* gm_c,
    void* gm_tiling_data,
    const uint32_t block_dim)
{
    batch_matmul_transpose<<<block_dim, nullptr, stream>>>(
        gm_a,
        gm_b,
        gm_c,
        gm_tiling_data);
}

}
15
csrc/kernels/math_utils.h
Normal file
@@ -0,0 +1,15 @@
#ifndef KERNEL_MATH_UTILS_H
#define KERNEL_MATH_UTILS_H
#include <cstdint>

namespace device_utils {

template <typename T, T roundVal>
__aicore__ __force_inline__ T RoundUp(const T &val)
{
    return (val + roundVal - 1) / roundVal * roundVal;
}

}; // namespace device_utils

#endif
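`RoundUp` rounds a value up to the nearest multiple of `roundVal`, which is how the kernel aligns tile lengths to the hardware's 16- and 256-element granularities. A host-side restatement for illustration (the real template is device-only, hence `__aicore__`):

    def round_up(val: int, round_val: int) -> int:
        return (val + round_val - 1) // round_val * round_val

    assert round_up(30, 16) == 32     # tile length aligned to a 16-element fractal
    assert round_up(300, 256) == 512  # k tile aligned to CONST_256
    assert round_up(32, 16) == 32     # already-aligned values are unchanged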
@@ -158,4 +158,13 @@ namespace vllm_ascend {
    void* tiling,
    const uint32_t block_dim
);

extern void batch_matmul_transpose_impl(
    void* stream,
    void* gm_a,
    void* gm_b,
    void* gm_c,
    void* gm_tiling_data,
    const uint32_t block_dim
);
}
@@ -24,6 +24,7 @@
#include "ops.h"
#include "utils.h"
#include "mla_preprocess/op_host/mla_preprocess.h"
#include "batch_matmul_transpose/op_host/batch_matmul_transpose.h"

namespace vllm_ascend {
@@ -458,6 +459,39 @@ at::Tensor sgmv_expand(at::Tensor &x, at::Tensor &weight, at::Tensor &lora_indic
    cmd.Run();
    return y_out;
}

void batch_matmul_transpose(const at::Tensor &tensor_a, const at::Tensor &tensor_b, at::Tensor &tensor_c,
                            c10::optional<c10::string_view> format_mode,
                            c10::optional<c10::string_view> quant_mode)
{
    auto [tiling_tensor, block_dim] = bmm_trans::batch_matmul_transpose_tiling(
        tensor_a,
        tensor_b,
        tensor_c,
        format_mode,
        quant_mode
    );

    void *gm_a = tensor_a.data_ptr();
    void *gm_b = tensor_b.data_ptr();
    void *gm_c = tensor_c.data_ptr();
    void *gm_tiling_data = tiling_tensor.data_ptr();

    aclrtStream stream = c10_npu::getCurrentNPUStream().stream();
    at_npu::native::OpCommand cmd;
    cmd.Name("batch_matmul_transpose");

    cmd.SetCustomHandler([stream, gm_a, gm_b, gm_c, gm_tiling_data,
                          block_dim]() -> int {
        batch_matmul_transpose_impl(stream, gm_a, gm_b, gm_c, gm_tiling_data,
                                    block_dim);
        return 0;
    });
    cmd.Run();
    return;
}

} // namespace vllm_ascend

TORCH_LIBRARY_EXPAND(CONCAT(_C, _ascend), ops)
@@ -511,4 +545,8 @@ TORCH_LIBRARY_EXPAND(CONCAT(_C, _ascend), ops)
    " Tensor q_out1, Tensor kv_cache_out1)"
    );
    ops.impl("mla_preprocess", torch::kPrivateUse1, &vllm_ascend::mla_preprocess);
    // batch_matmul ops refer to sgl-kernel-npu
    ops.def(
        "batch_matmul_transpose(Tensor tensor_a, Tensor tensor_b, Tensor tensor_c, str? format_mode=None, str? quant_mode=None) -> ()");
    ops.impl("batch_matmul_transpose", torch::kPrivateUse1, &vllm_ascend::batch_matmul_transpose);
}
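A hypothetical call site matching the schema registered above, mirroring the e2e test added later in this commit: the caller preallocates the output and the op writes it in place, which is why the schema returns `()`.

    import torch
    from vllm_ascend.utils import enable_custom_op

    enable_custom_op()  # loads the _C_ascend extension that defines the op

    b, m, k, n = 8, 160, 512, 128
    a = torch.randn(b, m, k, dtype=torch.float16, device="npu")
    w = torch.randn(m, k, n, dtype=torch.float16, device="npu")
    out = torch.empty(b, m, n, dtype=torch.float16, device="npu")

    # Computes out[b, m, n] = sum_k a[b, m, k] * w[m, k, n]; returns nothing.
    torch.ops._C_ascend.batch_matmul_transpose(a, w, out)  # format_mode/quant_mode default to None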
@@ -115,6 +115,14 @@ std::tuple<at::Tensor &, at::Tensor &, at::Tensor &, at::Tensor &> mla_preproces
}

void batch_matmul_transpose(const at::Tensor &tensor_a, const at::Tensor &tensor_b, at::Tensor &tensor_c,
                            c10::optional<c10::string_view> format_mode,
                            c10::optional<c10::string_view> quant_mode)
{
    return;
}

} // namespace meta
} // namespace vllm_ascend

@@ -132,5 +140,7 @@ namespace {
    ops.impl("sgmv_expand", &vllm_ascend::meta::sgmv_expand_meta);
    // MLA preprocess
    ops.impl("mla_preprocess", &vllm_ascend::meta::mla_preprocess);
    // batch_matmul_transpose
    ops.impl("batch_matmul_transpose", &vllm_ascend::meta::batch_matmul_transpose);
}
}
141
tests/e2e/singlecard/ops/test_batch_matmul_transpose.py
Normal file
@@ -0,0 +1,141 @@
import random
import unittest

import torch

from vllm_ascend.utils import enable_custom_op

enable_custom_op()

torch.set_printoptions(threshold=float("inf"))


class TestMatrixMultiplication(unittest.TestCase):

    def compute_golden(self, a, b, res1, m, n):
        """Compute the reference (golden) result."""
        torch.bmm(a.transpose(0, 1),
                  b,
                  out=res1.view(-1, m, n).transpose(0, 1))

    def assert_tensors_almost_equal(self, actual, expected, dtype):
        """Check that two tensors are approximately equal (allowing for floating-point error)."""
        self.assertEqual(actual.shape, expected.shape, "Shape mismatch")

        # Check for NaN
        self.assertFalse(
            torch.isnan(actual).any(), "Actual result contains NaN")
        self.assertFalse(
            torch.isnan(expected).any(), "Expected result contains NaN")

        # Check for Inf
        self.assertFalse(
            torch.isinf(actual).any(), "Actual result contains Inf")
        self.assertFalse(
            torch.isinf(expected).any(), "Expected result contains Inf")

        # Set tolerances based on data type
        if dtype == torch.float16:
            rtol, atol = 1e-5, 1e-5
        else:  # bfloat16
            rtol, atol = 1.5e-5, 1.5e-5

        # Compare values
        diff = torch.abs(actual - expected)
        max_diff = diff.max().item()
        max_expected = torch.abs(expected).max().item()

        # Check relative and absolute errors
        if max_expected > 0:
            relative_diff = max_diff / max_expected
            self.assertLessEqual(
                relative_diff,
                rtol,
                f"Relative error too large: {relative_diff} > {rtol}. Max difference: {max_diff}",
            )

        self.assertLessEqual(max_diff, atol,
                             f"Absolute error too large: {max_diff} > {atol}")

    def test_boundary_conditions(self):
        """Test boundary conditions."""
        test_cases = [
            # (b, m, k, n)
            (1, 1, 1, 1),  # Minimum size
            (1, 10, 1, 1),  # b=1
            (10, 1, 1, 10),  # m=1
            (5, 5, 1, 5),  # k=1
            (2, 2, 2, 1),  # n=1
            (100, 1, 1, 100),  # Flat case
            (1, 100, 100, 1),  # Flat case
            (2, 3, 4, 5),  # Random small size
            (10, 20, 30, 40),  # Medium size
            (36, 128, 512, 128),  # Target case
            (8, 160, 512, 128),
        ]

        dtypes = [torch.float16, torch.bfloat16]

        for dtype in dtypes:
            for b, m, k, n in test_cases:
                with self.subTest(dtype=dtype, shape=f"({b}, {m}, {k}, {n})"):
                    a = torch.randn(b, m, k, dtype=dtype, device="npu")
                    b_tensor = torch.randn(m, k, n, dtype=dtype, device="npu")
                    res1 = torch.empty((b, m * n), dtype=dtype, device="npu")
                    res2 = torch.empty((b, m, n), dtype=dtype, device="npu")

                    self.compute_golden(a, b_tensor, res1, m, n)
                    torch.ops._C_ascend.batch_matmul_transpose(
                        a, b_tensor, res2)

                    self.assert_tensors_almost_equal(res1.view(-1, m, n), res2,
                                                     dtype)

    def test_random_shapes(self):
        """Test randomly generated shapes."""
        num_tests = 1
        dtypes = [torch.float16, torch.bfloat16]

        for dtype in dtypes:
            for _ in range(num_tests):
                # Generate reasonable random sizes
                b = random.randint(1, 500)
                m = random.randint(1, 500)
                k = random.randint(1, 500)
                n = random.randint(1, 500)

                with self.subTest(dtype=dtype,
                                  shape=f"Random ({b}, {m}, {k}, {n})"):
                    a = torch.randn(b, m, k, dtype=dtype, device="npu")
                    b_tensor = torch.randn(m, k, n, dtype=dtype, device="npu")
                    res1 = torch.empty((b, m * n), dtype=dtype, device="npu")
                    res2 = torch.empty((b, m, n), dtype=dtype, device="npu")

                    self.compute_golden(a, b_tensor, res1, m, n)
                    torch.ops._C_ascend.batch_matmul_transpose(
                        a, b_tensor, res2)
                    self.assert_tensors_almost_equal(res1.view(-1, m, n), res2,
                                                     dtype)

    def test_zero_values(self):
        """Test all-zero inputs."""
        dtypes = [torch.float16, torch.bfloat16]
        b, m, k, n = 5, 4, 3, 2

        for dtype in dtypes:
            with self.subTest(dtype=dtype):
                a = torch.zeros(b, m, k, dtype=dtype, device="npu")
                b_tensor = torch.zeros(m, k, n, dtype=dtype, device="npu")
                res1 = torch.empty((b, m * n), dtype=dtype, device="npu")
                res2 = torch.empty((b, m, n), dtype=dtype, device="npu")

                self.compute_golden(a, b_tensor, res1, m, n)
                torch.ops._C_ascend.batch_matmul_transpose(a, b_tensor, res2)

                self.assert_tensors_almost_equal(res1.view(-1, m, n), res2,
                                                 dtype)
                self.assertTrue(torch.all(res2 == 0))


if __name__ == "__main__":
    unittest.main(verbosity=2)
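The golden path above relies on writing the `(m, b, n)` result of `torch.bmm` through a transposed view so it lands in `res1` with `(b, m, n)` layout. A self-contained CPU sketch of the same identity (using an explicit `copy_` instead of `out=` to sidestep backend-specific handling of non-contiguous outputs):

    import torch

    b, m, k, n = 2, 3, 4, 5
    a = torch.randn(b, m, k)
    w = torch.randn(m, k, n)

    res1 = torch.empty(b, m * n)
    # bmm of a^T (m, b, k) with w (m, k, n) gives (m, b, n); the transposed
    # view places it into res1 in (b, m, n) order.
    res1.view(b, m, n).transpose(0, 1).copy_(torch.bmm(a.transpose(0, 1), w))

    ref = torch.einsum("bmk,mkn->bmn", a, w)  # the op's einsum semantics
    assert torch.allclose(res1.view(b, m, n), ref, atol=1e-5)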
@@ -19,7 +19,7 @@ locale = "en"
extend-ignore-identifiers-re = [".*Unc.*", ".*_thw",
".*UE8M0.*", ".*[UE4M3|ue4m3].*", ".*eles.*", ".*fo.*", ".*ba.*",
".*ot.*", ".*[Tt]h[rR].*"]
extend-ignore-words-re = ["CANN", "cann"]
extend-ignore-words-re = ["CANN", "cann","ND"]
extend-ignore-re = []

[default.extend-identifiers]
@@ -565,14 +565,15 @@ class AscendMLAImpl(MLAAttentionImpl):
        self.enable_mlapo = envs.VLLM_ASCEND_ENABLE_MLAPO

    def _v_up_proj(self, x):
        if self.W_UV.shape[0] * self.W_UV.shape[1] < 65536:
            if x.dtype in [torch.float16, torch.bfloat16] \
                    and hasattr(torch.ops._C_ascend, "batch_matmul_transpose"):
                x = x.view(-1, self.num_heads, self.kv_lora_rank)
                x = torch_npu.npu_transpose_batchmatmul(x,
                                                        self.W_UV,
                                                        perm_x1=[1, 0, 2],
                                                        perm_x2=[0, 1, 2],
                                                        perm_y=[1, 0, 2])
                x = x.reshape(-1, self.num_heads * self.v_head_dim)
                b, _, _ = x.shape
                res = torch.empty((b, self.num_heads, self.v_head_dim),
                                  dtype=x.dtype,
                                  device=x.device)
                torch.ops._C_ascend.batch_matmul_transpose(x, self.W_UV, res)
                x = res.reshape(-1, self.num_heads * self.v_head_dim)
            else:
                # Convert from (B, N, L) to (N, B, L)
                x = x.view(-1, self.num_heads, self.kv_lora_rank).transpose(0, 1)
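Shape-level sketch of what the fused path in `_v_up_proj` computes, with hypothetical sizes `B` tokens, `N = num_heads`, `L = kv_lora_rank`, `V = v_head_dim`; the `W_UV.shape[0] * W_UV.shape[1] < 65536` guard appears to bound `N * L`, the per-head contraction footprint, before taking the custom-op path.

    import torch

    B, N, L, V = 4, 8, 16, 32                 # hypothetical sizes
    x = torch.randn(B, N, L)
    w_uv = torch.randn(N, L, V)

    res = torch.einsum("bnl,nlv->bnv", x, w_uv)                    # fused result
    fallback = torch.bmm(x.transpose(0, 1), w_uv).transpose(0, 1)  # (N, B, V) -> (B, N, V)
    assert torch.allclose(res, fallback, atol=1e-5)

    out = res.reshape(-1, N * V)              # _v_up_proj returns (B, N * V)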