[Kernel] add custom op GmmSwigluQuantWeightNzTensorList (#3804)
### What this PR does / why we need it? This PR introduces support for adding custom CANN `aclnn` ops to `vllm-ascend`, allowing users to define and use their own custom operators. Key changes include: - Building and installing custom ops into the `vllm-ascend`-specified directory - Binding the `aclnn` op interface to the `torch.ops._C_ascend` module - Enabling invocation of these ops within `vllm-ascend` This PR includes a sample custom op: `aclnnGroupedMatmulSwigluQuantWeightNzTensorList`, which is adapted from the CANN operator [`aclnnGroupedMatmulSwigluQuantWeightNZ`](https://www.hiascend.com/document/detail/zh/canncommercial/83RC1/API/aolapi/context/aclnnGroupedMatmulSwigluQuantWeightNZ.md). Its input parameters `weight` and `weight_scale` now accept `list[torch.Tensor]` (i.e., `at::TensorList`). ### Does this PR introduce _any_ user-facing change? No. - vLLM version: v0.11.2 --------- Signed-off-by: QianChenxi <chenxi.qian.cq@outlook.com>
This commit is contained in:
14
csrc/utils/inc/aclnn_util.h
Normal file
14
csrc/utils/inc/aclnn_util.h
Normal file
@@ -0,0 +1,14 @@
|
||||
/**
|
||||
* Copyright (c) 2025 Huawei Technologies Co., Ltd.
|
||||
* This file is a part of the CANN Open Software.
|
||||
* Licensed under CANN Open Software License Agreement Version 1.0 (the "License").
|
||||
* Please refer to the License for details. You may not use this file except in compliance with the License.
|
||||
* THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED,
|
||||
* INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE.
|
||||
* See LICENSE in the root of the software repository for the full text of the License.
|
||||
*/
|
||||
#ifndef OP_API_INC_ACLNN_UTIL_H
#define OP_API_INC_ACLNN_UTIL_H

// Marks aclnn entry points as exported (default visibility) so they stay
// callable from outside the shared library even when it is compiled with
// -fvisibility=hidden.
#define ACLNN_API __attribute__((visibility("default")))
#endif  // OP_API_INC_ACLNN_UTIL_H
|
||||
25
csrc/utils/inc/error/ops_error.h
Normal file
25
csrc/utils/inc/error/ops_error.h
Normal file
@@ -0,0 +1,25 @@
|
||||
/**
|
||||
* Copyright (c) 2024 Huawei Technologies Co., Ltd.
|
||||
* This file is a part of the CANN Open Software.
|
||||
* Licensed under CANN Open Software License Agreement Version 1.0 (the "License").
|
||||
* Please refer to the License for details. You may not use this file except in compliance with the License.
|
||||
* THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED,
|
||||
* INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE.
|
||||
* See LICENSE in the root of the software repository for the full text of the License.
|
||||
*/
|
||||
|
||||
/*!
|
||||
* \file ops_error.h
|
||||
* \brief
|
||||
*/
|
||||
|
||||
#pragma once

#include "log/ops_log.h"

/* Basic error reporting: forwards to the inner-error stub with a fixed error
 * code — "E89999" for vector ops, "E69999" for cube ops. */
#define OPS_REPORT_VECTOR_INNER_ERR(OPS_DESC, ...) OPS_INNER_ERR_STUB("E89999", OPS_DESC, __VA_ARGS__)
#define OPS_REPORT_CUBE_INNER_ERR(OPS_DESC, ...) OPS_INNER_ERR_STUB("E69999", OPS_DESC, __VA_ARGS__)

/* Conditional error reporting: when COND holds, run LOG_FUNC and then EXPR
 * (typically a `return ...` statement in the expansion site). */
#define OPS_ERR_IF(COND, LOG_FUNC, EXPR) OPS_LOG_STUB_IF(COND, LOG_FUNC, EXPR)
|
||||
497
csrc/utils/inc/fallback.h
Normal file
497
csrc/utils/inc/fallback.h
Normal file
@@ -0,0 +1,497 @@
|
||||
/**
|
||||
* Copyright (c) 2024 Huawei Technologies Co., Ltd.
|
||||
* This file is a part of the CANN Open Software.
|
||||
* Licensed under CANN Open Software License Agreement Version 1.0 (the "License").
|
||||
* Please refer to the License for details. You may not use this file except in compliance with the License.
|
||||
* THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED,
|
||||
* INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE.
|
||||
* See LICENSE in the root of the software repository for the full text of the License.
|
||||
*/
|
||||
|
||||
/*!
|
||||
* \file fallback.h
|
||||
* \brief
|
||||
*/
|
||||
|
||||
#ifndef ACLNNFALLBACK_OPAPI_H_
|
||||
#define ACLNNFALLBACK_OPAPI_H_
|
||||
|
||||
#include <dlfcn.h>
|
||||
|
||||
#include <functional>
|
||||
#include <tuple>
|
||||
#include <type_traits>
|
||||
#include <vector>
|
||||
|
||||
#include "aclnn/aclnn_base.h"
|
||||
#include "fallback_comm.h"
|
||||
#include "error/ops_error.h"
|
||||
#include "runtime/base.h"
|
||||
|
||||
namespace fallback {
|
||||
using namespace std;
|
||||
using namespace gert;
|
||||
using namespace ge;
|
||||
using namespace std;
|
||||
|
||||
namespace std_utils {
// Minimal local re-implementation of C++14 std::index_sequence /
// std::make_index_sequence, so this header only requires C++11.
template <std::size_t... Indices>
struct index_sequence {};

// Recursive helper: peels N down to zero while prepending N-1 to the pack.
template <std::size_t N, std::size_t... Indices>
struct make_index_sequence_helper : make_index_sequence_helper<N - 1, N - 1, Indices...> {};

// Base case: the accumulated pack becomes the resulting index_sequence.
template <std::size_t... Indices>
struct make_index_sequence_helper<0, Indices...> {
    using type = index_sequence<Indices...>;
};

// make_index_sequence<N> == index_sequence<0, 1, ..., N-1>.
template <std::size_t N>
using make_index_sequence = typename make_index_sequence_helper<N>::type;
}  // namespace std_utils
|
||||
|
||||
// Opaque handle types from the aclnn runtime; only pointers to them are used
// here, so forward declarations are sufficient.
using aclOpExecutor = struct aclOpExecutor;
using aclTensor = struct aclTensor;
using aclScalar = struct aclScalar;
using aclIntArray = struct aclIntArray;
using aclFloatArray = struct aclFloatArray;
using aclBoolArray = struct aclBoolArray;
using aclTensorList = struct aclTensorList;

// Function-pointer signatures matching the aclCreate* / aclDestroy* entry
// points that are resolved from the op-api shared libraries at runtime.
using _aclCreateTensor = aclTensor* (*)(const int64_t* view_dims, uint64_t view_dims_num, aclDataType data_type,
                                        const int64_t* stride, int64_t offset, aclFormat format,
                                        const int64_t* storage_dims, uint64_t storage_dims_num, void* tensor_data);

using _aclCreateScalar = aclScalar* (*)(void* value, aclDataType data_type);
using _aclCreateIntArray = aclIntArray* (*)(const int64_t* value, uint64_t size);
using _aclCreateFloatArray = aclFloatArray* (*)(const float* value, uint64_t size);
using _aclCreateBoolArray = aclBoolArray* (*)(const bool* value, uint64_t size);
using _aclCreateTensorList = aclTensorList* (*)(const aclTensor* const* value, uint64_t size);

using _aclDestroyTensor = int (*)(const aclTensor* tensor);
using _aclDestroyScalar = int (*)(const aclScalar* scalar);
using _aclDestroyIntArray = int (*)(const aclIntArray* array);
using _aclDestroyFloatArray = int (*)(const aclFloatArray* array);
using _aclDestroyBoolArray = int (*)(const aclBoolArray* array);
using _aclDestroyTensorList = int (*)(const aclTensorList* array);

// Resolve symbol `apiName` and cast it to the matching `_apiName` signature.
#define GET_OP_API_FUNC(apiName) reinterpret_cast<_##apiName>(GetOpApiFuncAddr(#apiName))
|
||||
|
||||
// Name of the main CANN op-api shared library.
inline const char* GetOpApiLibName(void)
{
    return "libopapi.so";
}
|
||||
|
||||
// Name of the custom-op op-api shared library (searched before the main one).
inline const char* GetCustOpApiLibName(void)
{
    return "libcust_opapi.so";
}
|
||||
|
||||
inline void* GetOpApiFuncAddrInLib(void* handler, const char* libName, const char* apiName) {
|
||||
auto funcAddr = dlsym(handler, apiName);
|
||||
if (funcAddr == nullptr) {
|
||||
OPS_LOG_W("aclnnfallback", "dlsym %s from %s failed, error:%s.", apiName, libName, dlerror());
|
||||
}
|
||||
return funcAddr;
|
||||
}
|
||||
|
||||
inline void* GetOpApiLibHandler(const char* libName) {
|
||||
auto handler = dlopen(libName, RTLD_LAZY);
|
||||
if (handler == nullptr) {
|
||||
OPS_LOG_W("aclnnfallback", "dlopen %s failed, error:%s.", libName, dlerror());
|
||||
}
|
||||
return handler;
|
||||
}
|
||||
|
||||
inline void* GetAclnnArrdByApiName(const char *apiName) {
|
||||
vector<std:: string> libs = {"libaclnn_ops_infer.so", "libaclnn_ops_train.so", "libaclnn_math.so",
|
||||
"libaclnn_rand.so", "libaclnn_sparse.so", "libaclnn_fft.so"};
|
||||
for (const auto &libName : libs) {
|
||||
static auto libHandler = GetOpApiLibHandler(libName.c_str());
|
||||
if (libHandler != nullptr) {
|
||||
auto funcAddr = GetOpApiFuncAddrInLib(libHandler, libName.c_str(), apiName);
|
||||
if (funcAddr != nullptr) {
|
||||
return funcAddr;
|
||||
}
|
||||
}
|
||||
}
|
||||
OPS_LOG_E("aclnnfallback", "api %s can't find in any aclnn lib.", apiName);
|
||||
return nullptr;
|
||||
}
|
||||
|
||||
// Resolve an op-api symbol: try the custom-op library first, then the main
// op-api library, and finally fall back to the per-domain aclnn libraries.
// Library handles are cached in function-local statics, so the main library is
// only dlopen'd the first time the custom library fails to satisfy a lookup.
inline void* GetOpApiFuncAddr(const char* apiName) {
    static auto custOpApiHandler = GetOpApiLibHandler(GetCustOpApiLibName());
    if (custOpApiHandler != nullptr) {
        auto funcAddr = GetOpApiFuncAddrInLib(custOpApiHandler, GetCustOpApiLibName(), apiName);
        if (funcAddr != nullptr) {
            return funcAddr;
        }
    }

    static auto opApiHandler = GetOpApiLibHandler(GetOpApiLibName());
    if (opApiHandler != nullptr) {
        auto funcAddr = GetOpApiFuncAddrInLib(opApiHandler, GetOpApiLibName(), apiName);
        if (funcAddr != nullptr) {
            return funcAddr;
        }
    }
    OPS_LOG_D("aclnnfallback", "opapi lib is not exist,will use aclnn lib.");
    return GetAclnnArrdByApiName(apiName);
}
|
||||
|
||||
// Identity conversion: an aclTensor already has the form the op-api expects.
inline aclTensor* ConvertType(aclTensor* ge_tensor) { return ge_tensor; }
|
||||
|
||||
// Convert an int64 vector into an aclIntArray; an empty vector maps to nullptr.
inline aclIntArray* ConvertType(const std::vector<int64_t> &arr) {
    if (arr.empty()) {
        return nullptr;
    }
    static const auto aclCreateIntArray = GET_OP_API_FUNC(aclCreateIntArray);
    return aclCreateIntArray(arr.data(), arr.size());
}
|
||||
|
||||
// Map a gert::Tensor's ge::DataType to the matching aclDataType.
// Unrecognized types deliberately fall back to ACL_FLOAT16, matching the
// original if/else chain's default.
inline aclDataType GetConvertType(const gert::Tensor* ge_tensor) {
    switch (ge_tensor->GetDataType()) {
        case DT_FLOAT:  return aclDataType::ACL_FLOAT;
        case DT_BF16:   return aclDataType::ACL_BF16;
        case DT_BOOL:   return aclDataType::ACL_BOOL;
        case DT_INT64:  return aclDataType::ACL_INT64;
        case DT_INT32:  return aclDataType::ACL_INT32;
        case DT_UINT64: return aclDataType::ACL_UINT64;
        case DT_UINT32: return aclDataType::ACL_UINT32;
        case DT_INT8:   return aclDataType::ACL_INT8;
        case DT_UINT8:  return aclDataType::ACL_UINT8;
        case DT_INT4:   return aclDataType::ACL_INT4;
        default:        return aclDataType::ACL_FLOAT16;
    }
}
|
||||
|
||||
// Wrap a gert::Tensor as an aclTensor (ND format, contiguous strides, offset 0,
// view shape == storage shape). Returns nullptr when the input is null, the
// aclCreateTensor symbol is unavailable, or tensor creation fails.
inline aclTensor* ConvertType(const gert::Tensor* ge_tensor) {
    if (ge_tensor == nullptr) {
        return nullptr;
    }

    static const auto aclCreateTensor = GET_OP_API_FUNC(aclCreateTensor);
    OPS_ERR_IF(aclCreateTensor == nullptr, OPS_LOG_E("aclnnfallback", "aclCreateTensor nullptr"), return nullptr);

    // CLEANUP: the original also fetched GetPlacement() into an unused local;
    // only the raw device address is needed here.
    void* device_addr = const_cast<void*>(ge_tensor->GetAddr());

    auto dataType = GetConvertType(ge_tensor);

    OPS_LOG_D("aclnnfallback", "aclCreateTensor: tensor type is %d", dataType);

    // Copy the storage shape into a plain int64 vector.
    auto gert_shape = ge_tensor->GetStorageShape();
    std::vector<int64_t> shape;
    shape.reserve(gert_shape.GetDimNum());
    for (size_t i = 0; i < gert_shape.GetDimNum(); ++i) {
        shape.push_back(gert_shape.GetDim(i));
    }

    // Compute row-major strides for a contiguous tensor.
    std::vector<int64_t> strides(shape.size(), 1);
    for (int64_t i = shape.size() - 2; i >= 0; i--) {
        strides[i] = shape[i + 1] * strides[i + 1];
    }

    aclTensor* out = aclCreateTensor(shape.data(), shape.size(), dataType, strides.data(),
                                     0, aclFormat::ACL_FORMAT_ND,
                                     shape.data(), shape.size(), device_addr);

    OPS_ERR_IF(out == nullptr,
               OPS_LOG_E("aclnnfallback", "out nullptr"), return nullptr);

    return out;
}
|
||||
|
||||
// Convert a list of gert::Tensors into an aclTensorList.
// Returns nullptr when the list is empty, the aclCreateTensorList symbol is
// unavailable, or creation fails.
inline aclTensorList* ConvertType(std::vector<const gert::Tensor*>& ge_tenserList) {
    OPS_ERR_IF(ge_tenserList.size() == 0,
               OPS_LOG_E("aclnnfallback", "ge_tenserList size 0"), return nullptr);

    static const auto aclCreateTensorList = GET_OP_API_FUNC(aclCreateTensorList);
    // BUGFIX: this error previously reused the "ge_tenserList size 0" message,
    // mislabeling a missing aclCreateTensorList symbol as an empty input.
    OPS_ERR_IF(aclCreateTensorList == nullptr,
               OPS_LOG_E("aclnnfallback", "aclCreateTensorList nullptr"), return nullptr);

    // Convert each element; ConvertType(const gert::Tensor*) may yield nullptr
    // entries, matching the original behavior.
    std::vector<aclTensor*> tmp;
    tmp.reserve(ge_tenserList.size());
    for (size_t i = 0; i < ge_tenserList.size(); i++) {
        tmp.push_back(ConvertType(ge_tenserList[i]));
    }

    return aclCreateTensorList(tmp.data(), tmp.size());
}
|
||||
|
||||
template <typename T>
|
||||
inline aclScalar* ConvertScalarType(T value) {
|
||||
static const auto aclCreateScalar = GET_OP_API_FUNC(aclCreateScalar);
|
||||
OPS_ERR_IF(aclCreateScalar == nullptr,
|
||||
OPS_LOG_E("aclnnfallback", "aclCreateScalar nullptr"), return nullptr);
|
||||
if (typeid(value) == typeid(float)) {
|
||||
return aclCreateScalar(&value, aclDataType::ACL_FLOAT);
|
||||
}
|
||||
return nullptr;
|
||||
}
|
||||
|
||||
// Fallback pass-through: values that need no acl conversion (ints, raw
// pointers to out-params, ...) are forwarded unchanged.
template <typename T>
T ConvertType(T passthrough) {
    return passthrough;
}
|
||||
|
||||
// Wrap a matmul operand gert::Tensor as an aclTensor.
// When `transpose` is set, the last two dims of the view shape and strides are
// swapped so the op-api sees the logical (transposed) layout without moving
// data. When `enable_NZ` is set and the tensor's storage format is
// FRACTAL_NZ, the acl tensor is tagged accordingly; otherwise ND is used.
// Tensors of rank <= 1 are delegated to the plain ConvertType path.
inline aclTensor* ConvertMmType(const gert::Tensor* ge_tensor, bool transpose, bool enable_NZ=false) {
    if (ge_tensor == nullptr) {
        return nullptr;
    }
    auto gert_shape = ge_tensor->GetStorageShape();
    if (gert_shape.GetDimNum() <= 1) {
        return ConvertType(ge_tensor);
    }

    static const auto aclCreateTensor = GET_OP_API_FUNC(aclCreateTensor);
    OPS_ERR_IF(aclCreateTensor == nullptr, OPS_LOG_E("aclnnfallback", "aclCreateTensor nullptr"), return nullptr);

    void* device_addr = const_cast<void*>(ge_tensor->GetAddr());
    // convert data type
    auto dataType_ge = ge_tensor->GetDataType();
    auto dataType = ToAclDataType(dataType_ge);
    // convert shape
    std::vector<int64_t> shape;
    for (size_t i = 0; i < gert_shape.GetDimNum(); ++i) {
        shape.push_back(gert_shape.GetDim(i));
    }
    // Compute row-major strides for a contiguous tensor.
    std::vector<int64_t> strides(shape.size(), 1);
    for (int64_t i = shape.size() - 2; i >= 0; i--) {
        strides[i] = shape[i + 1] * strides[i + 1];
    }

    auto viewShape = shape;
    // For a transposed tensor, swap strides and viewShape on the last two dims.
    if (transpose) {
        // dimM is the second-to-last dim, dimN the last dim.
        auto dimM = shape.size() - 2;
        auto dimN = shape.size() - 1;
        auto swap = strides[dimN];
        strides[dimN] = strides[dimM];
        strides[dimM] = swap;
        // Adjust the view shape to match the swapped strides.
        viewShape[dimN] = shape[dimM];
        viewShape[dimM] = shape[dimN];
    }
    auto acl_format = aclFormat::ACL_FORMAT_ND;
    if (enable_NZ && GetPrimaryFormat(ge_tensor->GetStorageFormat()) == ge::Format::FORMAT_FRACTAL_NZ) {
        acl_format = aclFormat::ACL_FORMAT_FRACTAL_NZ;
    }
    aclTensor* out = aclCreateTensor(viewShape.data(), shape.size(), dataType, strides.data(),
                                     0, acl_format, shape.data(), shape.size(), device_addr);
    OPS_ERR_IF(out == nullptr, OPS_LOG_E("aclnnfallback", "out nullptr"), return nullptr);

    return out;
}
|
||||
|
||||
// Destroy an acl object created by the matching ConvertType/ConvertScalarType
// call. Each overload resolves its aclDestroy* symbol once (cached in a
// function-local static) and logs-and-returns if the symbol is unavailable.

inline void Release(aclTensor* p) {
    static const auto aclDestroyTensor = GET_OP_API_FUNC(aclDestroyTensor);
    OPS_ERR_IF(aclDestroyTensor == nullptr,
               OPS_LOG_E("aclnnfallback", "aclDestroyTensor is null"), return);
    aclDestroyTensor(p);
}

inline void Release(aclScalar* p) {
    static const auto aclDestroyScalar = GET_OP_API_FUNC(aclDestroyScalar);
    OPS_ERR_IF(aclDestroyScalar == nullptr,
               OPS_LOG_E("aclnnfallback", "aclDestroyScalar is null"), return);
    aclDestroyScalar(p);
}

inline void Release(aclIntArray* p) {
    static const auto aclDestroyIntArray = GET_OP_API_FUNC(aclDestroyIntArray);
    OPS_ERR_IF(aclDestroyIntArray == nullptr,
               OPS_LOG_E("aclnnfallback", "aclDestroyIntArray is null"), return);
    aclDestroyIntArray(p);
}

inline void Release(aclBoolArray* p) {
    static const auto aclDestroyBoolArray = GET_OP_API_FUNC(aclDestroyBoolArray);
    OPS_ERR_IF(aclDestroyBoolArray == nullptr,
               OPS_LOG_E("aclnnfallback", "aclDestroyBoolArray is null"), return);
    aclDestroyBoolArray(p);
}

inline void Release(aclTensorList* p) {
    static const auto aclDestroyTensorList = GET_OP_API_FUNC(aclDestroyTensorList);
    OPS_ERR_IF(aclDestroyTensorList == nullptr,
               OPS_LOG_E("aclnnfallback", "aclDestroyTensorList is null"), return);
    aclDestroyTensorList(p);
}

// Catch-all for plain values (ints, out-param pointers, ...): nothing to free.
template <typename T>
void Release(T value) {
    (void)value;
}
|
||||
|
||||
// Call Release() on every element of the tuple by expanding over the index
// pack; the initializer_list forces left-to-right evaluation.
template <typename Tuple, size_t... I>
void CallRelease(Tuple t, std_utils::index_sequence<I...>) {
    (void)std::initializer_list<int>{(Release(std::get<I>(t)), 0)...};
}

// Release every converted op-api parameter held in the tuple.
template <typename Tuple>
void ReleaseConvertTypes(Tuple& t) {
    static constexpr auto size = std::tuple_size<Tuple>::value;
    CallRelease(t, std_utils::make_index_sequence<size>{});
}
|
||||
|
||||
// Convert each argument with the matching ConvertType overload and collect the
// results into a tuple — the parameter pack later passed to the op-api call.
template <typename... Ts>
auto ConvertTypes(Ts&... args) -> decltype(std::make_tuple(ConvertType(args)...)) {
    auto tp = std::make_tuple(ConvertType(args)...);
    return tp;
}
|
||||
|
||||
// Invoke f with the tuple's elements as arguments (a C++11 stand-in for
// std::apply, restricted to int-returning functions).
template <typename Function, typename Tuple, size_t... I>
auto call(Function f, Tuple t, std_utils::index_sequence<I...>) -> int {
    return f(std::get<I>(t)...);
}

// Convenience overload: derives the index pack from the tuple size.
template <typename Function, typename Tuple>
auto call(Function f, Tuple t) -> int {
    static constexpr auto size = std::tuple_size<Tuple>::value;
    return call(f, t, std_utils::make_index_sequence<size>{});
}
|
||||
|
||||
// Reinterpret the raw symbol address as a function pointer whose parameter
// types are the (decayed) element types of the params tuple. Only the tuple's
// TYPES are used; its values are not read here.
template <typename Tuple, size_t... I>
auto ConvertToOpApiFunc(const Tuple& params, void* opApiAddr, std_utils::index_sequence<I...>)
    -> int (*)(typename std::decay<decltype(std::get<I>(params))>::type...) {
    using OpApiFunc = int (*)(typename std::decay<decltype(std::get<I>(params))>::type...);
    auto func = reinterpret_cast<OpApiFunc>(opApiAddr);
    return func;
}

// Convenience overload: derives the index pack from the tuple size
// (SFINAE-disabled for empty tuples).
template <typename Tuple>
auto ConvertToOpApiFunc(const Tuple& params, void* opApiAddr)
    -> typename std::enable_if<std::tuple_size<Tuple>::value != 0,
        decltype(ConvertToOpApiFunc(params, opApiAddr, std_utils::make_index_sequence<std::tuple_size<Tuple>::value>{}))>::type {
    static constexpr auto size = std::tuple_size<Tuple>::value;
    return ConvertToOpApiFunc(params, opApiAddr, std_utils::make_index_sequence<size>{});
}
|
||||
|
||||
// Move-only RAII owner of a tuple of converted op-api parameters: the acl
// objects (tensors, arrays, ...) in the tuple are released exactly once, by
// whichever owner still holds them when it is destroyed.
template <typename Tuple>
class ConvertedParams {
public:
    ConvertedParams(Tuple&& convertedParams) : convertedParams_(std::move(convertedParams)){};
    // BUGFIX: propagate the source's validity instead of assuming true, so that
    // moving from an already-moved-from object cannot lead to a double release.
    ConvertedParams(ConvertedParams&& other)
        : convertedParams_(std::move(other.convertedParams_)), validParams_(other.validParams_) {
        other.validParams_ = false;
    };
    ConvertedParams& operator=(ConvertedParams&& other) {
        if (this == &other) {
            return *this;
        }

        // BUGFIX: release the parameters currently held before overwriting
        // them; the original overwrote silently and leaked their acl objects.
        if (validParams_) {
            ReleaseConvertTypes(convertedParams_);
        }
        convertedParams_ = std::move(other.convertedParams_);
        validParams_ = other.validParams_;
        other.validParams_ = false;
        return *this;
    }

    ConvertedParams() = delete;
    ConvertedParams(const ConvertedParams& other) = delete;
    ConvertedParams& operator=(const ConvertedParams& other) = delete;

    ~ConvertedParams() {
        if (validParams_) {
            ReleaseConvertTypes(convertedParams_);
        }
    }

    // Read-only access to the converted parameter tuple.
    const Tuple& GetConvertedParams() const {
        return convertedParams_;
    }

private:
    Tuple convertedParams_;   // converted op-api arguments
    bool validParams_{true};  // false once ownership has been moved away
};
|
||||
|
||||
// Optional helper entry points resolved from the op-api library at runtime.
using InitHugeMemThreadLocal = int (*)(void*, bool);
using UnInitHugeMemThreadLocal = void (*)(void*, bool);
using ReleaseHugeMem = void (*)(void*, bool);
// Executor-cache hooks (PTA = PyTorch adapter cache plumbing).
using PTAGetExecCache = aclOpExecutor* (*)(uint64_t, uint64_t*);
using InitPTACacheThreadLocal = void (*)();
using SetPTAHashKey = void (*)(uint64_t);
using CanUsePTACache = bool (*)(const char*);

using ResetCacheThreadLocal = void (*)();
|
||||
|
||||
// Execute an aclnn two-stage op call:
//   1. resolve <aclnn_api>GetWorkspaceSize, <aclnn_api> and
//      ResetCacheThreadLocal from the op-api libraries (cached in statics),
//   2. convert the variadic arguments, query the workspace size,
//   3. allocate the workspace through `host_api_ctx` (which must be in scope
//      at the expansion site), launch the op on the context's stream,
//   4. release the converted parameters and free the workspace.
// The statement expression evaluates to GRAPH_SUCCESS or GRAPH_FAILED.
// NOTE(review): `ret` is declared `static`, which looks unintentional — every
// path through the do/while reassigns it, so behavior is unaffected, but a
// plain local would be clearer; confirm before changing.
#define EXEC_OPAPI_CMD(aclnn_api, ...) \
    ({ \
        static auto ret = GRAPH_SUCCESS; \
        do { \
            static const auto ResetCacheThreadLocalAddr = GetOpApiFuncAddr("ResetCacheThreadLocal"); \
            static const auto getWorkspaceSizeFuncAddr = GetOpApiFuncAddr(#aclnn_api "GetWorkspaceSize"); \
            static const auto opApiFuncAddr = GetOpApiFuncAddr(#aclnn_api); \
            if (getWorkspaceSizeFuncAddr == nullptr || opApiFuncAddr == nullptr || ResetCacheThreadLocalAddr == nullptr) { \
                OPS_LOG_E("aclnnfallback", "%s or %s not in %s or %s or ResetCacheThreadLocal not found.", \
                          #aclnn_api "GetWorkspaceSize", #aclnn_api, GetOpApiLibName(), GetOpApiLibName()); \
                ret = GRAPH_FAILED; \
                break; \
            } \
            auto ResetCacheThreadLocalFunc = reinterpret_cast<ResetCacheThreadLocal>(ResetCacheThreadLocalAddr); \
            ResetCacheThreadLocalFunc(); \
            uint64_t workspace_size = 0; \
            uint64_t* workspace_size_addr = &workspace_size; \
            aclOpExecutor* executor = nullptr; \
            aclOpExecutor** executor_addr = &executor; \
            auto converted_params = ConvertTypes(__VA_ARGS__, workspace_size_addr, executor_addr); \
            static auto getWorkspaceSizeFunc = ConvertToOpApiFunc(converted_params, getWorkspaceSizeFuncAddr); \
            auto workspace_status = call(getWorkspaceSizeFunc, converted_params); \
            if (workspace_status != 0) { \
                OPS_LOG_E("aclnnfallback", "call %s failed:", #aclnn_api); \
                ret = GRAPH_FAILED; \
                break; \
            } \
            void* workspace_addr = nullptr; \
            if (workspace_size > 0) { \
                workspace_addr = host_api_ctx->MallocWorkspace(workspace_size); \
                if (workspace_addr == nullptr) { \
                    OPS_LOG_E("aclnnfallback", "call %s allocate workspace failed", #aclnn_api); \
                    ret = GRAPH_FAILED; \
                    break; \
                } \
            } \
            auto acl_stream = host_api_ctx->GetStream(); \
            auto acl_call = [converted_params, workspace_addr, workspace_size, host_api_ctx, acl_stream, \
                             executor]() -> int { \
                using OpApiFunc = int (*)(void*, uint64_t, aclOpExecutor*, const aclrtStream); \
                OpApiFunc opApiFunc = reinterpret_cast<OpApiFunc>(opApiFuncAddr); \
                auto api_ret = opApiFunc(workspace_addr, workspace_size, executor, acl_stream); \
                ReleaseConvertTypes(converted_params); \
                host_api_ctx->FreeWorkspace(); \
                if (api_ret != 0) { \
                    OPS_LOG_E("aclnnfallback", "call %s allocate workspace failed api_ret: %d", #aclnn_api, api_ret); \
                    return GRAPH_FAILED; \
                } \
                return api_ret; \
            }; \
            \
            ret = acl_call(); \
        } while (false); \
        (ret); \
    })
|
||||
|
||||
} // namespace fallback
|
||||
|
||||
#endif // ACLNNFALLBACK_OPAPI_H_
|
||||
38
csrc/utils/inc/fallback_comm.h
Normal file
38
csrc/utils/inc/fallback_comm.h
Normal file
@@ -0,0 +1,38 @@
|
||||
/**
|
||||
* Copyright (c) 2024 Huawei Technologies Co., Ltd.
|
||||
* This file is a part of the CANN Open Software.
|
||||
* Licensed under CANN Open Software License Agreement Version 1.0 (the "License").
|
||||
* Please refer to the License for details. You may not use this file except in compliance with the License.
|
||||
* THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED,
|
||||
* INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE.
|
||||
* See LICENSE in the root of the software repository for the full text of the License.
|
||||
*/
|
||||
|
||||
/*!
|
||||
* \file fallback_comm.h
|
||||
* \brief
|
||||
*/
|
||||
|
||||
#ifndef INC_EXTERNAL_GRAPH_FALLBACK_COMMON_H_
#define INC_EXTERNAL_GRAPH_FALLBACK_COMMON_H_

#include "aclnn/aclnn_base.h"
#include "exe_graph/runtime/op_execute_context.h"
#include "exe_graph/runtime/tensor.h"
#include "register/op_impl_registry.h"
#include "runtime/base.h"

#ifdef __cplusplus
extern "C" {
#endif

namespace fallback {

// Maps a ge::DataType enum value onto the corresponding aclDataType.
// NOTE(review): the surrounding extern "C" block gives this namespaced C++
// function C linkage, which is unusual — confirm it is intentional.
aclDataType ToAclDataType(ge::DataType dtype);
} // namespace fallback

#ifdef __cplusplus
}
#endif

#endif  // INC_EXTERNAL_GRAPH_FALLBACK_COMMON_H_
|
||||
121
csrc/utils/inc/kernel/dropmask.h
Normal file
121
csrc/utils/inc/kernel/dropmask.h
Normal file
@@ -0,0 +1,121 @@
|
||||
/**
|
||||
* Copyright (c) 2024 Huawei Technologies Co., Ltd.
|
||||
* This file is a part of the CANN Open Software.
|
||||
* Licensed under CANN Open Software License Agreement Version 1.0 (the "License").
|
||||
* Please refer to the License for details. You may not use this file except in compliance with the License.
|
||||
* THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED,
|
||||
* INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE.
|
||||
* See LICENSE in the root of the software repository for the full text of the License.
|
||||
*/
|
||||
|
||||
/*!
|
||||
* \file dropmask.h
|
||||
* \brief
|
||||
*/
|
||||
|
||||
#ifndef DROPMASK_H
|
||||
#define DROPMASK_H
|
||||
|
||||
#include "util.h"
|
||||
|
||||
using AscendC::DROPOUT_MODE_BIT_MISALIGN;
|
||||
using AscendC::DropOutShapeInfo;
|
||||
using AscendC::DropOut;
|
||||
|
||||
// Parameters describing where the current tile's dropout mask lives and how to
// copy and apply it.
struct DropMaskInfo {
    // for compute dropout mask offset
    // Offsets are computed as if all of B / N / G / S1 / S2 were tiled; for an
    // axis that is not tiled, set its fields to 0 or to the full original value.
    int64_t n2G;              // n2 * g
    int64_t gSize;            // g
    int64_t s1Size;           // s1
    int64_t s2Size;           // s2
    int64_t gOutIdx;          // g out index
    int64_t bSSOffset;        // boidx * s1 * s2
    int64_t n2OutIdx;         // n out index
    int64_t s1OutIdx;         // s1 out index (a.k.a. s1oIdx)
    int64_t s1InnerIdx;       // s1 inner index within the ratio loop (a.k.a. loopIdx)
    int64_t s1BaseSize;       // S1 base block size
    int64_t splitS1BaseSize;  // s1 split size (a.k.a. vec1S1BaseSize)
    int64_t s2StartIdx;       // s2 start index
    int64_t s2Idx;            // s2 index (a.k.a. s2LoopCount)
    int64_t s2BaseNratioSize; // s2 ratio length: s2BaseSize (S2 base block) * nRatio

    // for copy in dropout mask
    uint32_t s1CopySize;
    uint32_t s2CopySize;
    int64_t s2TotalSize;

    // for compute dropout mask
    uint32_t firstAxis;
    uint32_t lstAxis;
    uint32_t maskLstAxis;
    int64_t vecCoreOffset = 0;
    float keepProb;   // keep probability handed to AscendC::DropOut

    bool boolMode;    // true: one byte per mask element; false: bit-packed mask
};
|
||||
|
||||
// Compute the flat element offset of the current tile inside the dropout mask,
// treating the mask as laid out over [B, N2, G, S1, S2]. Returns 0 when
// dropout is compiled out (hasDrop == false).
template <bool hasDrop>
__aicore__ inline int64_t ComputeDropOffset(DropMaskInfo &dropMaskInfo)
{
    if constexpr (hasDrop == true) {
        // boidx * n2 * g * s1 * s2
        int64_t bOffset = dropMaskInfo.bSSOffset * dropMaskInfo.n2G;
        // n2oIdx * g * s1 * s2
        int64_t n2Offset = dropMaskInfo.n2OutIdx * dropMaskInfo.gSize * dropMaskInfo.s1Size * dropMaskInfo.s2Size;
        // goIdx * s1 * s2
        int64_t gOffset = dropMaskInfo.gOutIdx * dropMaskInfo.s1Size * dropMaskInfo.s2Size;
        // s1oIdx * s1BaseSize * s2Size + s1innerindex * vec1S1BaseSize * s2Size
        int64_t s1Offset = (dropMaskInfo.s1OutIdx * dropMaskInfo.s1BaseSize + dropMaskInfo.vecCoreOffset +
                            dropMaskInfo.s1InnerIdx * dropMaskInfo.splitS1BaseSize) * dropMaskInfo.s2Size;
        // s2StartIdx + s2index * s2BaseNratioSize
        int64_t s2Offset = dropMaskInfo.s2StartIdx + dropMaskInfo.s2Idx * dropMaskInfo.s2BaseNratioSize;
        return bOffset + n2Offset + gOffset + s1Offset + s2Offset;
    } else {
        return 0;
    }
}
|
||||
|
||||
// Load the dropout mask tile for the current block into `dstTensor`, choosing
// the byte-mask path (boolMode) or the bit-packed path. No-op when dropout is
// compiled out. `alignedSize` is the copy alignment in bytes.
template <bool hasDrop>
__aicore__ inline void CopyInDropMask(LocalTensor<uint8_t>&dstTensor, GlobalTensor<uint8_t>& srcBoolTensor,
    GlobalTensor<uint8_t>& srcByteTensor, DropMaskInfo &dropMaskInfo, int64_t alignedSize = blockBytes)
{
    if constexpr (hasDrop == true) {
        int64_t dropMaskOffset = ComputeDropOffset<hasDrop>(dropMaskInfo);
        if (unlikely(dropMaskInfo.boolMode)) {
            // One byte per element: straight copy of the bool mask.
            BoolCopyIn(dstTensor, srcBoolTensor, dropMaskOffset,
                dropMaskInfo.s1CopySize, dropMaskInfo.s2CopySize, dropMaskInfo.s2TotalSize, alignedSize);
        } else {
            // Bit-packed mask: expand bits to int8 while copying in.
            Bit2Int8CopyIn(dstTensor, srcByteTensor, dropMaskOffset, 1,
                dropMaskInfo.s1CopySize, dropMaskInfo.s2CopySize, dropMaskInfo.s2TotalSize, alignedSize);
        }
        return;
    }
}
|
||||
|
||||
// Apply the dropout mask in `dropoutBuffer` to `srcTensor`, writing the result
// to `dstTensor`. Handles both byte (bool) masks and bit-packed masks; for a
// bit-packed mask whose last axis is not block-aligned, falls back to the
// DROPOUT_MODE_BIT_MISALIGN variant of AscendC::DropOut. No-op when dropout is
// compiled out.
template <typename T, bool hasDrop>
__aicore__ inline void ComputeDropMask(LocalTensor<T>& dstTensor, LocalTensor<T>& srcTensor,
    LocalTensor<uint8_t>& dropoutBuffer, LocalTensor<uint8_t>& tmpDropBuffer, DropMaskInfo &dropMaskInfo)
{
    if constexpr (hasDrop == true) {
        DropOutShapeInfo dropOutShapeInfo;
        dropOutShapeInfo.firstAxis = dropMaskInfo.firstAxis;
        dropOutShapeInfo.srcLastAxis = dropMaskInfo.lstAxis;

        if (unlikely(dropMaskInfo.boolMode)) {
            // Byte mask: round the mask's last axis up to whole blocks.
            dropOutShapeInfo.maskLastAxis = CeilDiv(dropMaskInfo.maskLstAxis, blockBytes) * blockBytes;
            DropOut(dstTensor, srcTensor, dropoutBuffer, tmpDropBuffer, dropMaskInfo.keepProb, dropOutShapeInfo);
        } else {
            // Bit-packed mask: last axis is in bits — convert to bytes first.
            dropOutShapeInfo.maskLastAxis = CeilDiv(dropMaskInfo.maskLstAxis / byteBitRatio, blockBytes) * blockBytes;
            if (likely(dropMaskInfo.lstAxis / byteBitRatio % blockBytes == 0)) {
                DropOut(dstTensor, srcTensor, dropoutBuffer, tmpDropBuffer, dropMaskInfo.keepProb, dropOutShapeInfo);
            } else {
                DropOut<T, false, DROPOUT_MODE_BIT_MISALIGN>(dstTensor, srcTensor, dropoutBuffer, tmpDropBuffer,
                    dropMaskInfo.keepProb, dropOutShapeInfo);
            }
        }
        return;
    }
}
|
||||
|
||||
#endif // DROPMASK_H
|
||||
483
csrc/utils/inc/kernel/pse.h
Normal file
483
csrc/utils/inc/kernel/pse.h
Normal file
@@ -0,0 +1,483 @@
|
||||
/**
|
||||
* Copyright (c) 2024 Huawei Technologies Co., Ltd.
|
||||
* This file is a part of the CANN Open Software.
|
||||
* Licensed under CANN Open Software License Agreement Version 1.0 (the "License").
|
||||
* Please refer to the License for details. You may not use this file except in compliance with the License.
|
||||
* THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED,
|
||||
* INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE.
|
||||
* See LICENSE in the root of the software repository for the full text of the License.
|
||||
*/
|
||||
|
||||
/*!
|
||||
* \file pse.h
|
||||
* \brief
|
||||
*/
|
||||
|
||||
#ifndef FLASH_ATTENTION_SCORE_PSE_H
|
||||
#define FLASH_ATTENTION_SCORE_PSE_H
|
||||
|
||||
#include "kernel_operator.h"
|
||||
#include "util.h"
|
||||
|
||||
// pse shape-type discriminators (stored in PseInfo::pseShapeType).
constexpr static int64_t pseS1S2 = 0;    // bias shaped [S1, S2]
constexpr static int64_t pse1S2 = 1;     // bias shaped [1, S2]
constexpr static int64_t pseSlopeBn = 2; // slope per (B, N) — presumably alibi; confirm
constexpr static int64_t pseSlopeN = 3;  // slope per N — presumably alibi; confirm

// Encoding tag marking an alibi pse with a full S2 extent.
constexpr static uint8_t pseEncodeALibiS2Full = 0x11;
|
||||
// How the pse bias is combined with the attention score (see PseInfo::pseType).
enum class PseTypeEnum {
    PSE_OUTER_MUL_ADD_TYPE = 0, // default
    PSE_OUTER_ADD_MUL_TYPE,
    PSE_INNER_MUL_ADD_TYPE,
    PSE_INNER_MUL_ADD_SQRT_TYPE,
    PSE_INVALID_TYPE
};
|
||||
|
||||
// Carries every index, size and flag needed to locate one PSE (positional
// score encoding) tile and apply it during the attention-score computation.
// NOTE(review): several field semantics below are inferred from their use in
// this header; confirm against the tiling code that populates the struct.
struct PseInfo {
    int64_t blockCount;
    int64_t bSSOffset; // boidx * s1 * s2
    int64_t boIdx;     // outer batch index (used to form the b-axis offset)
    int64_t gSize;
    int64_t goIdx;
    int64_t loopIdx;
    int64_t n2G;
    int64_t n2oIdx;
    int64_t pseBSize;         // batch extent of the PSE buffer; 1 means shared across batches
    int64_t pseS1Size;        // for alibi
    int64_t pseS2ComputeSize; // for alibi, do not need assignment
    int64_t pseS2Size;        // for alibi
    uint32_t pseShapeType;    // one of pseS1S2 / pse1S2 / pseSlopeBn / pseSlopeN
    int64_t readS2Size; // for alibi, do not need assignment
    int64_t s1BaseSize;
    int64_t s1Size;
    int64_t s1oIdx;
    int64_t s2AlignedSize;
    int64_t s2BaseNratioSize;
    int64_t s2LoopCount;
    int64_t s2RealSize;
    int64_t s2Size;
    int64_t s2SizeAcc; // accumulated sum of s2 size
    int64_t s2StartIdx;
    int64_t vec1S1BaseSize;
    int64_t vec1S1RealSize;
    uint32_t pseEncodeType; // for distinguish alibi
    uint32_t pseType; // 0: outer, mul-add 1:outer, add-mul 2:inner, mul-add 3:inner, mul-add-sqrt
    int64_t pseAlibiBaseS1;
    int64_t pseAlibiBaseS2;
    int64_t qStartIdx;
    int64_t kvStartIdx;
    int64_t vecCoreOffset = 0;
    bool needCast;
    bool align8 = false;
    bool pseEndogenous = false;
};
|
||||
|
||||
// Copies an (s1Size x s2Size) tile from global memory into a local tensor whose
// rows are padded out to alignedS2Size elements. Chooses between the
// block-granular DataCopy path (when each source row is a whole number of
// 32-byte blocks) and the byte-granular DataCopyPad path otherwise. Compiles
// to a no-op when hasPse is false.
template <typename INPUT_T, bool hasPse>
__aicore__ inline void DataCopyInCommon(LocalTensor<INPUT_T> &dstTensor, GlobalTensor<INPUT_T> &srcTensor, int64_t offset,
                                        int64_t s1Size, int64_t s2Size, int64_t actualS2Len, int32_t dtypeSize,
                                        int32_t alignedS2Size)
{
    if constexpr (hasPse == true) {
        uint32_t shapeArray[] = {static_cast<uint32_t>(s1Size), static_cast<uint32_t>(alignedS2Size)};
        dstTensor.SetShapeInfo(ShapeInfo(2, shapeArray, DataFormat::ND));
        dstTensor.SetSize(s1Size * alignedS2Size);
        DataCopyParams dataCopyParams;
        dataCopyParams.blockCount = s1Size;
        dataCopyParams.blockLen = CeilDiv(s2Size * dtypeSize, blockBytes); // unit: 32-byte blocks
        dataCopyParams.dstStride = alignedS2Size * dtypeSize / blockBytes - dataCopyParams.blockLen; // gap
        if (actualS2Len * dtypeSize % blockBytes == 0) {
            // Source rows are 32B-aligned: plain block copy with a row gap.
            dataCopyParams.srcStride =
                (actualS2Len * dtypeSize - dataCopyParams.blockLen * blockBytes) / blockBytes; // srcGap
            DataCopy(dstTensor, srcTensor[offset], dataCopyParams);
        } else {
            // Unaligned source rows: switch blockLen/srcStride to byte units
            // and let DataCopyPad handle the per-row remainder.
            dataCopyParams.blockLen = s2Size * dtypeSize; // unit: bytes
            dataCopyParams.srcStride = (actualS2Len * dtypeSize - dataCopyParams.blockLen);
            dataCopyParams.dstStride = (alignedS2Size - s2Size) * dtypeSize / blockBytes;
            DataCopyPadParams dataCopyPadParams;
            dataCopyPadParams.isPad = false;
            DataCopyPad(dstTensor, srcTensor[offset], dataCopyParams, dataCopyPadParams);
        }
    }
}
|
||||
|
||||
template <typename INPUT_T, bool hasPse>
|
||||
__aicore__ inline void DataCopyIn(LocalTensor<INPUT_T> &dstTensor, GlobalTensor<INPUT_T> &srcTensor, int64_t offset,
|
||||
int64_t s1Size, int64_t s2Size, int64_t actualS2Len, int64_t alignedSize = 16)
|
||||
{
|
||||
if constexpr (hasPse == true) {
|
||||
int32_t dtypeSize = sizeof(INPUT_T);
|
||||
int32_t alignedS2Size = CeilDiv(s2Size, alignedSize) * alignedSize;
|
||||
DataCopyInCommon<INPUT_T, hasPse>(dstTensor, srcTensor, offset, s1Size, s2Size,
|
||||
actualS2Len, dtypeSize, alignedS2Size);
|
||||
}
|
||||
}
|
||||
|
||||
template <typename INPUT_T, bool hasPse>
|
||||
__aicore__ inline void DataCopyInAlign8(LocalTensor<INPUT_T> &dstTensor, GlobalTensor<INPUT_T> &srcTensor, int64_t offset,
|
||||
int64_t s1Size, int64_t s2Size, int64_t actualS2Len)
|
||||
{
|
||||
if constexpr (hasPse == true) {
|
||||
int32_t dtypeSize = sizeof(INPUT_T);
|
||||
if (dtypeSize == 0){
|
||||
return;
|
||||
}
|
||||
int32_t alignedS2Size = CeilDiv(s2Size, 32 / dtypeSize) * (32 / dtypeSize);
|
||||
DataCopyInCommon<INPUT_T, hasPse>(dstTensor, srcTensor, offset, s1Size, s2Size,
|
||||
actualS2Len, dtypeSize, alignedS2Size);
|
||||
}
|
||||
}
|
||||
|
||||
/*
dst = BroadcastAdd(src0, src1)
src0 shape: (s1, s2)
src1 shape: (1, s2)
dst shape: (s1, s2)
The single src1 row is added to `repeatTimes` consecutive rows of src0 starting
at element offset `src0Offset`, in place. One vector instruction covers at most
repeatMaxSize elements, so the s2 axis is split into full chunks plus a tail.
*/
template <typename T, bool hasPse>
__aicore__ inline void BroadcastAdd(const LocalTensor<T> &src0Tensor, const LocalTensor<T> &src1Tensor,
                                    int64_t src0Offset, int32_t src1Size, int32_t repeatTimes)
{
    if constexpr (hasPse == true) {
        /* Total data number of single step should be smaller than 256bytes.
         * If larger, we need to do add multiple times. */
        int32_t innerLoop = src1Size / repeatMaxSize;   // number of full chunks along the s2 axis
        int32_t innerRemain = src1Size % repeatMaxSize; // tail chunk size along the s2 axis
        BinaryRepeatParams binaryRepeatParams;
        binaryRepeatParams.src0BlkStride = 1;
        binaryRepeatParams.src0RepStride = src1Size / blockSize; // each repeat advances one whole s2 row
        binaryRepeatParams.src1BlkStride = 1;
        binaryRepeatParams.src1RepStride = 0; // src1 row reused for every repeat — the broadcast
        binaryRepeatParams.dstRepStride = binaryRepeatParams.src0RepStride;
        binaryRepeatParams.blockNumber = binaryRepeatParams.src0RepStride;

        for (int32_t j = 0; j < innerLoop; j++) {
            auto innerOffset = j * repeatMaxSize;
            auto ubOffset = src0Offset + innerOffset;
            Add(src0Tensor[ubOffset], src0Tensor[ubOffset], src1Tensor[innerOffset], repeatMaxSize, repeatTimes,
                binaryRepeatParams);
        }
        if (innerRemain > 0) {
            auto innerOffset = innerLoop * repeatMaxSize;
            auto ubOffset = src0Offset + innerOffset;
            Add(src0Tensor[ubOffset], src0Tensor[ubOffset], src1Tensor[innerOffset], innerRemain, repeatTimes,
                binaryRepeatParams);
        }
    }
}
|
||||
|
||||
// Adds the PSE tile in pseUb onto dstTensor. For full-tile shapes (pseS1S2,
// pseSlopeBn, pseSlopeN) the add is a single element-wise Add over
// `computeSize` elements; for the (1, s2) shape the single row is
// broadcast-added to every s1 row via BroadcastAdd, split into outer loops of
// at most repeatMaxTimes repeats per vector instruction.
template <typename T, bool hasPse>
__aicore__ inline void PseBroadcastAdd(int32_t s1Size, int32_t s2Size, int32_t computeSize, const LocalTensor<T> &pseUb,
                                       const LocalTensor<T> &dstTensor, uint32_t pseShapeType)
{
    if constexpr (hasPse == true) {
        if (pseShapeType == pseS1S2 || pseShapeType == pseSlopeBn || pseShapeType == pseSlopeN) {
            Add(dstTensor, dstTensor, pseUb, computeSize);
        } else {
            /* Total repeated times should be <= repeatMaxTimes. If larger,
             * we need to do multiple inner loops. */
            int32_t s1OuterLoop = s1Size / repeatMaxTimes;
            int32_t s1OuterRemain = s1Size % repeatMaxTimes;
            for (int32_t s1OuterIdx = 0; s1OuterIdx < s1OuterLoop; s1OuterIdx++) {
                int32_t s1OuterOffset = s1OuterIdx * repeatMaxTimes * s2Size;
                BroadcastAdd<T, hasPse>(dstTensor, pseUb, s1OuterOffset, s2Size, repeatMaxTimes);
            }
            if (s1OuterRemain > 0) {
                int32_t s1OuterOffset = s1OuterLoop * repeatMaxTimes * s2Size;
                BroadcastAdd<T, hasPse>(dstTensor, pseUb, s1OuterOffset, s2Size, s1OuterRemain);
            }
        }
    }
}
|
||||
// Computes the flat element offset of the current tile inside the PSE global
// buffer for the non-ALiBi layouts:
//   pseS1S2: (b, n2, g, s1, s2)    pse1S2: (b, n2, g, 1, s2)
// When the PSE batch extent is 1 the batch offset is dropped — the single
// batch slice is shared across all batches. Returns 0 when hasPse is false.
template <bool hasPse> __aicore__ inline int64_t PseComputeOffset(PseInfo &pseInfo)
{
    if constexpr (hasPse == true) {
        int64_t bOffset = 0;
        int64_t n2Offset = 0;
        int64_t s1Offset = 0;
        int64_t s2Offset = pseInfo.s2StartIdx + pseInfo.s2LoopCount * pseInfo.s2BaseNratioSize;
        int64_t gOffset = 0;
        if (pseInfo.pseShapeType == pseS1S2) {
            // b, n2, g, s1, s2
            bOffset = pseInfo.bSSOffset * pseInfo.n2G;
            n2Offset = pseInfo.n2oIdx * pseInfo.gSize * pseInfo.s1Size * pseInfo.s2Size;
            gOffset = pseInfo.goIdx * pseInfo.s1Size * pseInfo.s2Size;
            s1Offset = (pseInfo.s1oIdx * pseInfo.s1BaseSize + pseInfo.vecCoreOffset +
                        pseInfo.loopIdx * pseInfo.vec1S1BaseSize) * pseInfo.s2Size;
        } else if (pseInfo.pseShapeType == pse1S2) {
            // b, n2, g, 1, s2
            bOffset = pseInfo.s2SizeAcc * pseInfo.n2G;
            n2Offset = pseInfo.n2oIdx * pseInfo.gSize * pseInfo.s2Size;
            gOffset = pseInfo.goIdx * pseInfo.s2Size;
        }
        if (pseInfo.pseBSize == 1) {
            bOffset = 0; // PSE shared across batches
        }
        return bOffset + n2Offset + gOffset + s1Offset + s2Offset;
    } else {
        return 0;
    }
}
|
||||
|
||||
// Computes the flat element offset of the current tile inside the compressed
// ALiBi PSE buffer of shape (b, n2, g, pseS1Size, pseS2Size). The tile's
// (row, column) position in the full (s1, s2) score matrix is mapped onto the
// compressed (m, k) coordinates; for the TND layout the diagonal is located
// via pseS2Size - pseS1Size instead of s1Size - pseS1Size.
// Side effects: updates pseInfo.readS2Size (s2 elements to read, clipped at
// the buffer edge) and pseInfo.pseS2ComputeSize (that count rounded up for
// vector compute). Returns 0 when hasPse is false.
template <LayOutTypeEnum layOutType, bool hasPse> __aicore__ inline int64_t PseAlibiComputeOffset(PseInfo &pseInfo)
{
    if constexpr (hasPse == true) {
        int64_t bOffset = (pseInfo.boIdx % pseInfo.pseBSize) * pseInfo.n2G * pseInfo.pseS2Size * pseInfo.pseS1Size;
        int64_t n2Offset = pseInfo.n2oIdx * pseInfo.gSize * pseInfo.pseS2Size * pseInfo.pseS1Size;
        int64_t gOffset = pseInfo.goIdx * pseInfo.pseS2Size * pseInfo.pseS1Size;
        int64_t row = pseInfo.s1oIdx * pseInfo.s1BaseSize + pseInfo.vecCoreOffset +
                      pseInfo.loopIdx * pseInfo.vec1S1BaseSize;
        int64_t column = pseInfo.s2StartIdx + pseInfo.s2LoopCount * pseInfo.s2BaseNratioSize;
        int64_t m = 0;
        int64_t k = 0;
        if constexpr (layOutType != LayOutTypeEnum::LAYOUT_TND) {
            int64_t threshold = pseInfo.s1Size - pseInfo.pseS1Size;
            if (row >= threshold) {
                // Bottom rows of the score matrix are stored verbatim.
                m = row - threshold;
                k = column;
            } else {
                // Rows above the stored band: fold into the compressed band.
                m = row % pseInfo.pseS1Size;
                k = pseInfo.pseS2Size - (row - column) - (pseInfo.pseS1Size - m);
            }
        } else {
            int64_t threshold = pseInfo.pseS2Size - pseInfo.pseS1Size;
            int64_t posVal = row - column - threshold;
            if (threshold >= 0) {
                if (posVal >= 0) {
                    m = posVal;
                    k = 0;
                } else {
                    m = 0;
                    k = -posVal;
                }
            } else {
                m = posVal;
                k = 0;
            }
        }
        int64_t s1Offset = m * pseInfo.pseS2Size;
        int64_t s2Offset = k;
        // Clip the read at the buffer edge, then align the compute width.
        pseInfo.readS2Size = Min(pseInfo.s2AlignedSize, pseInfo.pseS2Size - k);
        pseInfo.pseS2ComputeSize = Align(pseInfo.readS2Size);

        return bOffset + n2Offset + gOffset + s1Offset + s2Offset;
    } else {
        return 0;
    }
}
|
||||
|
||||
// ALiBi encoding only contributes to the lower triangle of the score matrix:
// a tile is skipped when its one-past-the-last row index does not exceed its
// first column index. Always false when hasPse is false.
template <bool hasPse> __aicore__ inline bool NeedPseAlibiCompute(PseInfo &pseInfo)
{
    if constexpr (hasPse == true) {
        int64_t rowEnd = pseInfo.s1oIdx * pseInfo.s1BaseSize + pseInfo.vecCoreOffset +
                         (pseInfo.loopIdx + 1) * pseInfo.vec1S1BaseSize;
        int64_t colBegin = pseInfo.s2StartIdx + pseInfo.s2LoopCount * pseInfo.s2BaseNratioSize;
        return rowEnd > colBegin;
    } else {
        return false;
    }
}
|
||||
|
||||
// Loads the ALiBi PSE tile for the current block from global memory. Skipped
// entirely for upper-triangle blocks (NeedPseAlibiCompute). When the GM dtype
// already matches the compute dtype the data goes straight into dstTensor
// (8-element-aligned variant when pseInfo.align8 is set); otherwise it is
// staged in tmpTensor and, if pseInfo.needCast is set, cast into dstTensor
// after the MTE2->V synchronization.
template <typename INPUT_T, typename T, LayOutTypeEnum layOutType, bool hasPse>
__aicore__ inline void PseAlibiCopyIn(LocalTensor<T> &dstTensor, LocalTensor<INPUT_T> &tmpTensor,
                                      GlobalTensor<INPUT_T> &srcTensor, PseInfo &pseInfo, int64_t alignedSize = 16)
{
    if constexpr (hasPse == true) {
        if (!NeedPseAlibiCompute<hasPse>(pseInfo)) {
            return;
        }
        // Side effect: also refreshes pseInfo.readS2Size / pseS2ComputeSize.
        int64_t offset = PseAlibiComputeOffset<layOutType, hasPse>(pseInfo);
        if constexpr (IsSameType<INPUT_T, T>::value) {
            if (!pseInfo.align8){
                DataCopyIn<INPUT_T, hasPse>(dstTensor, srcTensor, offset, pseInfo.vec1S1RealSize, pseInfo.readS2Size,
                                            pseInfo.pseS2Size, alignedSize);
            } else {
                DataCopyInAlign8<INPUT_T, hasPse>(dstTensor, srcTensor, offset, pseInfo.vec1S1RealSize,
                                                  pseInfo.readS2Size, pseInfo.pseS2Size);
            }
            return;
        }

        DataCopyIn<INPUT_T, hasPse>(tmpTensor, srcTensor, offset, pseInfo.vec1S1RealSize, pseInfo.readS2Size,
                                    pseInfo.pseS2Size, alignedSize);
        if (pseInfo.needCast) {
            // Wait for the GM->UB copy to land before the vector Cast reads it.
            event_t eventIdMte2ToV = static_cast<event_t>(GetTPipePtr()->FetchEventID(HardEvent::MTE2_V));
            SetFlag<HardEvent::MTE2_V>(eventIdMte2ToV);
            WaitFlag<HardEvent::MTE2_V>(eventIdMte2ToV);
            Cast(dstTensor, tmpTensor, RoundMode::CAST_NONE, pseInfo.vec1S1RealSize * pseInfo.pseS2ComputeSize);
        }
        return;
    }
}
|
||||
|
||||
// Builds the "inner" slope-based PSE tile: copies the precomputed index ramp
// from alibiGm into helpTensor and then, when pseInfo.needCast is set, casts
// it to T and converts it into -slope * |q-k distance| (with an extra Sqrt for
// the PSE_INNER_MUL_ADD_SQRT_TYPE variant). The per-head slope is read from
// the pseSlope GM array at index (b,) n2, g.
template <typename T, bool hasPse>
__aicore__ inline void PseSlopeCopyIn(LocalTensor<T> &dstTensor, LocalTensor<half> &helpTensor,
                                      __gm__ uint8_t *pseSlope, GlobalTensor<half> &alibiGm, PseInfo &pseInfo,
                                      int64_t alignedSize = 16) {
    if constexpr (hasPse == true) {
        int64_t bOffset = 0;
        int64_t n2Offset = pseInfo.n2oIdx * pseInfo.gSize;
        int64_t gOffset = pseInfo.goIdx;

        if (pseInfo.pseShapeType == pseSlopeBn) {
            bOffset = pseInfo.boIdx * pseInfo.n2G;
        }
        int64_t offset = bOffset + n2Offset + gOffset; // index of this head's slope

        DataCopyIn<half, hasPse>(helpTensor, alibiGm, 0, pseInfo.vec1S1RealSize,
                                 pseInfo.s2RealSize, pseInfo.pseAlibiBaseS2, alignedSize);
        // Ensure the GM->UB copy has landed before vector ops consume it.
        event_t eventIdMte2ToV = static_cast<event_t>(GetTPipePtr()->FetchEventID(HardEvent::MTE2_V));
        SetFlag<HardEvent::MTE2_V>(eventIdMte2ToV);
        WaitFlag<HardEvent::MTE2_V>(eventIdMte2ToV);

        if (pseInfo.needCast) {
            int64_t computeSize = pseInfo.vec1S1RealSize * pseInfo.s2AlignedSize;
            Cast(dstTensor, helpTensor, RoundMode::CAST_NONE, computeSize);
            pipe_barrier(PIPE_V);

            int64_t s1Offset = pseInfo.s1oIdx * pseInfo.s1BaseSize + pseInfo.vecCoreOffset +
                               pseInfo.loopIdx * pseInfo.vec1S1BaseSize;
            int64_t s2Offset = pseInfo.s2StartIdx + pseInfo.s2LoopCount * pseInfo.s2BaseNratioSize;

            // Shift the ramp so each element holds its signed q-k distance.
            float posShift = float(s2Offset + pseInfo.kvStartIdx - s1Offset - pseInfo.qStartIdx);

            Adds(dstTensor, dstTensor, posShift, computeSize);
            pipe_barrier(PIPE_V);
            Abs(dstTensor, dstTensor, computeSize);
            pipe_barrier(PIPE_V);
            float slopes = ((__gm__ T *)pseSlope)[offset] * -1; // negated ALiBi slope
            if (pseInfo.pseType == (uint32_t)PseTypeEnum::PSE_INNER_MUL_ADD_SQRT_TYPE) {
                Sqrt(dstTensor, dstTensor, computeSize);
                pipe_barrier(PIPE_V);
            }
            Muls(dstTensor, dstTensor, slopes, computeSize);
            pipe_barrier(PIPE_V);
        }
    }
}
|
||||
|
||||
// Compute-only counterpart of PseSlopeCopyIn for when the index ramp is
// already resident in helpTensor: casts it to T and converts it into
// -slope * |q-k distance| (with an extra Sqrt for PSE_INNER_MUL_ADD_SQRT_TYPE).
// NOTE(review): this body duplicates the needCast branch of PseSlopeCopyIn,
// except that it runs unconditionally; consider extracting a shared helper.
template <typename T, bool hasPse>
__aicore__ inline void PseSlopeCast(LocalTensor<T> &dstTensor, LocalTensor<half> &helpTensor,
                                    __gm__ uint8_t *pseSlope, PseInfo &pseInfo) {
    if constexpr (hasPse == true) {
        int64_t bOffset = 0;
        int64_t n2Offset = pseInfo.n2oIdx * pseInfo.gSize;
        int64_t gOffset = pseInfo.goIdx;

        if (pseInfo.pseShapeType == pseSlopeBn) {
            bOffset = pseInfo.boIdx * pseInfo.n2G;
        }
        int64_t offset = bOffset + n2Offset + gOffset; // index of this head's slope
        int64_t computeSize = pseInfo.vec1S1RealSize * pseInfo.s2AlignedSize;
        Cast(dstTensor, helpTensor, RoundMode::CAST_NONE, computeSize);
        pipe_barrier(PIPE_V);

        int64_t s1Offset = pseInfo.s1oIdx * pseInfo.s1BaseSize + pseInfo.vecCoreOffset +
                           pseInfo.loopIdx * pseInfo.vec1S1BaseSize;
        int64_t s2Offset = pseInfo.s2StartIdx + pseInfo.s2LoopCount * pseInfo.s2BaseNratioSize;

        // Shift the ramp so each element holds its signed q-k distance.
        float posShift = float(s2Offset + pseInfo.kvStartIdx - s1Offset - pseInfo.qStartIdx);

        Adds(dstTensor, dstTensor, posShift, computeSize);
        pipe_barrier(PIPE_V);
        Abs(dstTensor, dstTensor, computeSize);
        pipe_barrier(PIPE_V);
        float slopes = ((__gm__ T *)pseSlope)[offset] * -1; // negated ALiBi slope
        if (pseInfo.pseType == (uint32_t)PseTypeEnum::PSE_INNER_MUL_ADD_SQRT_TYPE) {
            Sqrt(dstTensor, dstTensor, computeSize);
            pipe_barrier(PIPE_V);
        }
        Muls(dstTensor, dstTensor, slopes, computeSize);
        pipe_barrier(PIPE_V);
    }
}
|
||||
|
||||
// Entry point for loading a PSE tile from global memory. Dispatches to the
// ALiBi path when the buffer uses the compressed ALiBi encoding; otherwise
// computes the tile offset for the (s1, s2) / (1, s2) layouts and copies it
// in, staging through tmpTensor plus a Cast when the GM and compute dtypes
// differ. For the (1, s2) shape the row count degenerates to blockCount
// (at least 1) instead of vec1S1RealSize.
template <typename INPUT_T, typename T, LayOutTypeEnum layOutType, bool hasPse>
__aicore__ inline void PseCopyIn(LocalTensor<T> &dstTensor, LocalTensor<INPUT_T> &tmpTensor,
                                 GlobalTensor<INPUT_T> &srcTensor, PseInfo &pseInfo, int64_t alignedSize = 16)
{
    if constexpr (hasPse == true) {
        if (pseInfo.pseEncodeType == pseEncodeALibiS2Full) {
            return PseAlibiCopyIn<INPUT_T, T, layOutType, hasPse>(dstTensor, tmpTensor, srcTensor, pseInfo, alignedSize);
        }
        int64_t offset = PseComputeOffset<hasPse>(pseInfo);
        int64_t s1Size = pseInfo.pseShapeType == pse1S2 ? (pseInfo.blockCount == 0 ? 1 : pseInfo.blockCount) :
                                                          pseInfo.vec1S1RealSize;

        if constexpr (IsSameType<INPUT_T, T>::value) {
            if (!pseInfo.align8){
                DataCopyIn<INPUT_T, hasPse>(dstTensor, srcTensor, offset, s1Size, pseInfo.s2RealSize,
                                            pseInfo.s2Size, alignedSize);
            } else {
                DataCopyInAlign8<INPUT_T, hasPse>(dstTensor, srcTensor, offset, s1Size, pseInfo.s2RealSize, pseInfo.s2Size);
            }
            return;
        }
        DataCopyIn<INPUT_T, hasPse>(tmpTensor, srcTensor, offset, s1Size, pseInfo.s2RealSize, pseInfo.s2Size,
                                    alignedSize);
        if (pseInfo.needCast) {
            // Wait for the GM->UB copy to land before the vector Cast reads it.
            event_t eventIdMte2ToV = static_cast<event_t>(GetTPipePtr()->FetchEventID(HardEvent::MTE2_V));
            SetFlag<HardEvent::MTE2_V>(eventIdMte2ToV);
            WaitFlag<HardEvent::MTE2_V>(eventIdMte2ToV);
            Cast(dstTensor, tmpTensor, RoundMode::CAST_NONE, s1Size * pseInfo.s2AlignedSize);
        }
        return;
    }
}
|
||||
|
||||
template <typename T, bool hasPse>
|
||||
__aicore__ inline void PseAlibiCompute(LocalTensor<T> &dstTensor, LocalTensor<T> &pseTensor, PseInfo &pseInfo)
|
||||
{
|
||||
if constexpr (hasPse == true) {
|
||||
if (!NeedPseAlibiCompute<hasPse>(pseInfo)) {
|
||||
return;
|
||||
}
|
||||
Add(dstTensor, dstTensor, pseTensor, pseInfo.vec1S1RealSize * pseInfo.pseS2ComputeSize);
|
||||
return;
|
||||
}
|
||||
}
|
||||
|
||||
// Applies the loaded PSE tile to the score tile in dstTensor. Dispatches to
// the ALiBi add for the compressed encoding; otherwise selects the compute
// size (full tile for (s1, s2)/slope shapes, single row for (1, s2)) and
// delegates to PseBroadcastAdd.
template <typename T, bool hasPse>
__aicore__ inline void PseCompute(LocalTensor<T> &dstTensor, LocalTensor<T> &pseTensor, PseInfo &pseInfo)
{
    if constexpr (hasPse == true) {
        if (pseInfo.pseEncodeType == pseEncodeALibiS2Full) {
            return PseAlibiCompute<T, hasPse>(dstTensor, pseTensor, pseInfo);
        }
        int64_t computeSize = (pseInfo.pseShapeType == pseS1S2 || pseInfo.pseShapeType == pseSlopeBn ||
                               pseInfo.pseShapeType == pseSlopeN)
                                  ? pseInfo.vec1S1RealSize * pseInfo.s2AlignedSize
                                  : pseInfo.s2AlignedSize;
        PseBroadcastAdd<T, hasPse>(pseInfo.vec1S1RealSize, pseInfo.s2AlignedSize, computeSize, pseTensor,
                                   dstTensor, pseInfo.pseShapeType);
        return;
    }
}
|
||||
|
||||
// Precomputes the shared ALiBi index ramp in global memory for the "inner"
// PSE modes: row i of the (pseAlibiBaseS1 x pseAlibiBaseS2) table holds the
// sequence -i, -i+1, ... generated with CreateVecIndex and copied out row by
// row. No-op for the outer PSE types and when hasPse is false.
template <bool hasPse>
__aicore__ inline void PseInnerAlibiCreate(GlobalTensor<half> &dstTensor, LocalTensor<half> &helpTensor, PseInfo &pseInfo) {
    if constexpr (hasPse == true) {
        if (pseInfo.pseType != (uint32_t)PseTypeEnum::PSE_INNER_MUL_ADD_TYPE && pseInfo.pseType != (uint32_t)PseTypeEnum::PSE_INNER_MUL_ADD_SQRT_TYPE) {
            return;
        }
        event_t eventIdMte3ToV = static_cast<event_t>(GetTPipePtr()->FetchEventID(HardEvent::MTE3_V));
        event_t eventIdMte3ToS = static_cast<event_t>(GetTPipePtr()->FetchEventID(HardEvent::MTE3_S));
        event_t eventIdVToMte3 = static_cast<event_t>(GetTPipePtr()->FetchEventID(HardEvent::V_MTE3));
        float tmpValue = -1.0;

        for (int64_t i = 0; i < pseInfo.pseAlibiBaseS1; i++) {
            // Row i starts at -i and increments by 1 along the s2 axis.
            CreateVecIndex(helpTensor, (half)(i * tmpValue), pseInfo.pseAlibiBaseS2);
            // Serialize: vector write -> copy out -> next vector write.
            SetFlag<HardEvent::V_MTE3>(eventIdVToMte3);
            WaitFlag<HardEvent::V_MTE3>(eventIdVToMte3);
            DataCopy(dstTensor[i * pseInfo.pseAlibiBaseS2], helpTensor, pseInfo.pseAlibiBaseS2);
            SetFlag<HardEvent::MTE3_V>(eventIdMte3ToV);
            WaitFlag<HardEvent::MTE3_V>(eventIdMte3ToV);
            SetFlag<HardEvent::MTE3_S>(eventIdMte3ToS);
            WaitFlag<HardEvent::MTE3_S>(eventIdMte3ToS);
        }
    }
}
|
||||
#endif
|
||||
144
csrc/utils/inc/kernel/util.h
Normal file
144
csrc/utils/inc/kernel/util.h
Normal file
@@ -0,0 +1,144 @@
|
||||
/**
|
||||
* Copyright (c) 2024 Huawei Technologies Co., Ltd.
|
||||
* This file is a part of the CANN Open Software.
|
||||
* Licensed under CANN Open Software License Agreement Version 1.0 (the "License").
|
||||
* Please refer to the License for details. You may not use this file except in compliance with the License.
|
||||
* THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED,
|
||||
* INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE.
|
||||
* See LICENSE in the root of the software repository for the full text of the License.
|
||||
*/
|
||||
|
||||
/*!
|
||||
* \file util.h
|
||||
* \brief
|
||||
*/
|
||||
|
||||
#ifndef FLASH_ATTENTION_UTIL_H
|
||||
#define FLASH_ATTENTION_UTIL_H
|
||||
|
||||
constexpr int32_t blockBytes = 32;   // data-transfer block granularity, bytes
constexpr int32_t byteBitRatio = 8;  // bits per byte (used for bit-packed masks)
constexpr int64_t prefixAttenMaskDownHeight = 1024;
constexpr static int32_t blockSize = blockBytes / 4; // 4 means sizeof(T)
constexpr static int32_t repeatMaxBytes = 256;       // max bytes one vector repeat covers
constexpr static int32_t repeatMaxTimes = 255;       // max repeats per vector instruction
constexpr static int32_t repeatMaxSize = repeatMaxBytes / 4; // 4 means sizeof(T)

using AscendC::LocalTensor;
using AscendC::GlobalTensor;
using AscendC::DataFormat;
using AscendC::ShapeInfo;
using AscendC::DataCopyParams;
using AscendC::DataCopyPadParams;
using AscendC::BinaryRepeatParams;
using AscendC::IsSameType;
using AscendC::HardEvent;
using AscendC::SetFlag;
using AscendC::WaitFlag;

// Supported tensor memory layouts. NOTE(review): letters presumably stand for
// B=batch, S=sequence, H=hidden, N=heads, D=head dim, T=packed token count —
// confirm against the op's tiling documentation.
enum class LayOutTypeEnum { None = 0, LAYOUT_BSH = 1, LAYOUT_SBH = 2, LAYOUT_BNSD = 3, LAYOUT_TND = 4, LAYOUT_NTD_TND = 5};
|
||||
|
||||
namespace math {
|
||||
template <typename T> __aicore__ inline T Ceil(T a, T b)
|
||||
{
|
||||
if (b == 0) {
|
||||
return 0;
|
||||
}
|
||||
return (a + b - 1) / b;
|
||||
}
|
||||
|
||||
template <typename T> __aicore__ inline T Align(T a, T b)
|
||||
{
|
||||
if (b == 0) {
|
||||
return 0;
|
||||
}
|
||||
return (a + b - 1) / b * b;
|
||||
}
|
||||
}
|
||||
|
||||
template <typename T1, typename T2>
|
||||
__aicore__ inline T1 CeilDiv(T1 a, T2 b)
|
||||
{
|
||||
if (b == 0) {
|
||||
return 0;
|
||||
}
|
||||
return (a + b - 1) / b;
|
||||
}
|
||||
|
||||
template <typename T1, typename T2>
|
||||
__aicore__ inline T1 Max(T1 a, T2 b)
|
||||
{
|
||||
return (a > b) ? (a) : (b);
|
||||
}
|
||||
|
||||
template <typename T1, typename T2>
|
||||
__aicore__ inline T1 Min(T1 a, T2 b)
|
||||
{
|
||||
return (a > b) ? (b) : (a);
|
||||
}
|
||||
|
||||
// Copies an (s1Size x s2Size) byte-mask tile from global memory into a local
// uint8 tensor whose rows are padded to alignedS2Size bytes. Uses a plain
// DataCopy when each source row is a whole number of alignedSize bytes;
// otherwise uses DataCopyPad, right-padding each row with value 1.
// NOTE(review): paddingValue = 1 presumably marks padded positions as
// "masked" — confirm against the mask consumers.
__aicore__ inline void BoolCopyIn(LocalTensor<uint8_t> &dstTensor, GlobalTensor<uint8_t> &srcTensor,
    int64_t srcOffset, uint32_t s1Size, uint32_t s2Size, int64_t totalS2Size, int64_t alignedSize = blockBytes)
{
    uint32_t alignedS2Size = CeilDiv(s2Size, alignedSize) * alignedSize;
    uint32_t shapeArray[] = {s1Size, alignedS2Size};
    dstTensor.SetShapeInfo(ShapeInfo(2, shapeArray, DataFormat::ND));
    dstTensor.SetSize(s1Size * alignedS2Size);
    DataCopyParams dataCopyParams;
    dataCopyParams.blockCount = s1Size;
    dataCopyParams.dstStride = 0;
    if (totalS2Size == blockBytes && alignedSize == 64) { // totalS2Size < 64 && totalS2Size % blockBytes == 0
        // Special case: 32-byte source rows requested at 64-byte alignment —
        // copy at 32-byte granularity and leave one block of gap per dst row.
        dataCopyParams.dstStride = 1;
        alignedSize = blockBytes;
        alignedS2Size = CeilDiv(s2Size, blockBytes) * blockBytes;
    }
    if (totalS2Size % alignedSize == 0) {
        dataCopyParams.blockLen = alignedS2Size / blockBytes;
        dataCopyParams.srcStride = (totalS2Size - alignedS2Size) / blockBytes;
        DataCopy(dstTensor, srcTensor[srcOffset], dataCopyParams);
    } else {
        dataCopyParams.blockLen = s2Size; // byte units for DataCopyPad
        dataCopyParams.srcStride = totalS2Size - s2Size;
        DataCopyPadParams dataCopyPadParams;
        dataCopyPadParams.isPad = true;
        dataCopyPadParams.rightPadding = Min(alignedS2Size - s2Size, blockBytes);
        dataCopyPadParams.paddingValue = 1;
        DataCopyPad(dstTensor, srcTensor[srcOffset], dataCopyParams, dataCopyPadParams);
    }
}
|
||||
|
||||
// Copies a bit-packed mask tile (one bit per element, hence the /byteBitRatio
// conversions throughout) from global memory into a local uint8 tensor of
// batchSize * s1BaseSize rows, each padded to alignedS2Size bytes. Falls back
// to DataCopyPad when either the total or the base s2 byte width is not a
// whole number of alignedSize bytes.
__aicore__ inline void Bit2Int8CopyIn(LocalTensor<uint8_t> &dstTensor, GlobalTensor<uint8_t> &srcTensor,
    int64_t srcOffset, uint32_t batchSize, uint32_t s1BaseSize, uint32_t s2BaseSize, int64_t s2TotalSize,
    int64_t alignedSize = blockBytes)
{
    uint32_t alignedS2Size = CeilDiv(s2BaseSize / byteBitRatio, alignedSize) * alignedSize;
    uint32_t shapeArray[] = {batchSize * s1BaseSize, alignedS2Size};
    dstTensor.SetShapeInfo(ShapeInfo(2, shapeArray, DataFormat::ND));
    dstTensor.SetSize(batchSize * s1BaseSize * alignedS2Size);
    DataCopyParams dataCopyParams;
    dataCopyParams.blockCount = batchSize * s1BaseSize;
    dataCopyParams.blockLen = CeilDiv(s2BaseSize / byteBitRatio, blockBytes); // unit: 32-byte blocks
    dataCopyParams.dstStride = 0;
    if (s2TotalSize / byteBitRatio % alignedSize == 0 && s2BaseSize / byteBitRatio % alignedSize == 0) {
        dataCopyParams.srcStride =
            (s2TotalSize / byteBitRatio - dataCopyParams.blockLen * blockBytes) / blockBytes;
        DataCopy(dstTensor, srcTensor[srcOffset / byteBitRatio], dataCopyParams);
    } else {
        // NOTE(review): blockLen switches to byte units here for DataCopyPad,
        // vs 32B-block units above — confirm against the AscendC DataCopy docs.
        dataCopyParams.blockLen = CeilDiv(s2BaseSize , byteBitRatio);
        dataCopyParams.srcStride = (s2TotalSize - s2BaseSize) / byteBitRatio;
        DataCopyPadParams dataCopyPadParams;
        dataCopyPadParams.isPad = true;
        dataCopyPadParams.rightPadding = 0;
        dataCopyPadParams.paddingValue = 0;
        DataCopyPad(dstTensor, srcTensor[srcOffset / byteBitRatio], dataCopyParams, dataCopyPadParams);
    }
}
|
||||
|
||||
// Rounds `shape` up to the next multiple of 16, the default alignment factor
// used for vector-friendly row widths.
__aicore__ inline int32_t Align(int32_t shape)
{
    constexpr int32_t kAlignFactor = 16;
    return CeilDiv<int32_t, int32_t>(shape, kAlignFactor) * kAlignFactor;
}
|
||||
|
||||
#endif // FLASH_ATTENTION_UTIL_H
|
||||
190
csrc/utils/inc/log/inner/dfx_base.h
Normal file
190
csrc/utils/inc/log/inner/dfx_base.h
Normal file
@@ -0,0 +1,190 @@
|
||||
/**
|
||||
* Copyright (c) 2024 Huawei Technologies Co., Ltd.
|
||||
* This file is a part of the CANN Open Software.
|
||||
* Licensed under CANN Open Software License Agreement Version 1.0 (the "License").
|
||||
* Please refer to the License for details. You may not use this file except in compliance with the License.
|
||||
* THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED,
|
||||
* INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE.
|
||||
* See LICENSE in the root of the software repository for the full text of the License.
|
||||
*/
|
||||
|
||||
/*!
|
||||
* \file dfx_base.h
|
||||
* \brief 外部模块不应直接引用本头文件
|
||||
*/
|
||||
|
||||
#pragma once
|
||||
|
||||
#include <string>
|
||||
#include <cstdint>
|
||||
#include <sstream>
|
||||
#include <unistd.h>
|
||||
#include <sys/syscall.h>
|
||||
#include <securec.h>
|
||||
#include <base/alog_pub.h>
|
||||
#include <base/err_msg.h>
|
||||
#include <exe_graph/runtime/tiling_context.h>
|
||||
#include <exe_graph/runtime/tiling_parse_context.h>
|
||||
#include <exe_graph/runtime/infer_shape_context.h>
|
||||
#include <exe_graph/runtime/infer_datatype_context.h>
|
||||
|
||||
namespace ops {
|
||||
namespace utils {
|
||||
|
||||
// Static helpers used by the logging macros below: normalizes the various
// "op descriptor" argument kinds (C strings, std::string, gert contexts) into
// printable strings, and exposes the calling thread's id.
class LogBase {
public:
    static constexpr const int MAX_LOG_LEN = 16000; // max formatted message length, bytes
    static constexpr const int MSG_HDR_LEN = 200;   // bytes reserved for the log line header

    // Linux thread id of the calling thread, via raw syscall (gettid lacks a
    // libc wrapper on older toolchains).
    static inline uint64_t GetTid()
    {
        return static_cast<uint64_t>(syscall(__NR_gettid));
    }

    static inline const char *GetStr(const std::string &str)
    {
        return str.c_str();
    }

    static inline const char *GetStr(const char *str)
    {
        return str;
    }

    static inline const std::string &GetOpInfo(const std::string &str)
    {
        return str;
    }

    static inline const char *GetOpInfo(const char *str)
    {
        return str;
    }

    // Context overloads: render "NodeType:NodeName" for each gert context kind.
    static inline std::string GetOpInfo(const gert::TilingContext *context)
    {
        return GetOpInfoFromContext(context);
    }

    static inline std::string GetOpInfo(const gert::TilingParseContext *context)
    {
        return GetOpInfoFromContext(context);
    }

    static inline std::string GetOpInfo(const gert::InferShapeContext *context)
    {
        return GetOpInfoFromContext(context);
    }

    static inline std::string GetOpInfo(const gert::InferDataTypeContext *context)
    {
        return GetOpInfoFromContext(context);
    }

private:
    // Builds "type:name", substituting "nil" for any missing piece.
    template <class T> static inline std::string GetOpInfoFromContext(T context)
    {
        if (context == nullptr) {
            return "nil:nil";
        }
        std::string opInfo = context->GetNodeType() != nullptr ? context->GetNodeType() : "nil";
        opInfo += ":";
        opInfo += context->GetNodeName() != nullptr ? context->GetNodeName() : "nil";
        return opInfo;
    }
};
|
||||
|
||||
} // namespace utils
|
||||
|
||||
// Renders a shape-like object (anything exposing GetDimNum() / GetDim(i)) as
// a bracketed, comma-separated list, e.g. "[2, 3, 4]"; an empty shape yields
// "[]".
template <typename T>
std::string Shape2String(const T& shape) {
    std::ostringstream oss;
    oss << '[';
    const size_t dimNum = shape.GetDimNum();
    for (size_t i = 0; i < dimNum; ++i) {
        if (i > 0) {
            oss << ", ";
        }
        oss << shape.GetDim(i);
    }
    oss << ']';
    return oss.str();
}
|
||||
} // namespace ops
|
||||
|
||||
// Before using these macros, predefine OPS_UTILS_LOG_SUB_MOD_NAME with the
// submodule name, e.g. #define OPS_UTILS_LOG_SUB_MOD_NAME "OP_TILING", or pass
// it as a predefined macro via CMake (OPS_UTILS_LOG_PACKAGE_TYPE likewise).
// Emits one log record when LOG_LEVEL is enabled for MOD_ID; the header embeds
// file/line, submodule, package, function, thread id and the op descriptor.
#define OPS_LOG_STUB(MOD_ID, LOG_LEVEL, OPS_DESC, FMT, ...) \
    do { \
        if (AlogCheckDebugLevel(static_cast<int>(MOD_ID), (LOG_LEVEL)) == 1) { \
            AlogRecord(static_cast<int>(MOD_ID), DLOG_TYPE_DEBUG, (LOG_LEVEL), \
                       "[%s:%d][%s]%s[%s][%lu] OpName:[%s] " #FMT, \
                       __FILE__, __LINE__, (OPS_UTILS_LOG_SUB_MOD_NAME), \
                       (OPS_UTILS_LOG_PACKAGE_TYPE), __FUNCTION__, ops::utils::LogBase::GetTid(), \
                       ops::utils::LogBase::GetStr(ops::utils::LogBase::GetOpInfo(OPS_DESC)), ##__VA_ARGS__); \
        } \
    }while (0)

// If COND (which must be a bool) is unexpectedly true, run LOG_FUNC then EXPR.
// NOTE(review): the static_assert sits outside the do/while, so this macro
// expands to two statements and cannot be the sole body of an un-braced
// if/else — confirm whether it should be folded into the do/while.
#define OPS_LOG_STUB_IF(COND, LOG_FUNC, EXPR) \
    static_assert(std::is_same<bool, std::decay<decltype(COND)>::type>::value, "condition should be bool"); \
    do { \
        if (__builtin_expect((COND), 0)) { \
            LOG_FUNC; \
            EXPR; \
        } \
    } while (0)

// Error log plus inner-error report with the given error code string.
#define OPS_INNER_ERR_STUB(ERR_CODE_STR, OPS_DESC, FMT, ...) \
    do { \
        OPS_LOG_STUB(OP, DLOG_ERROR, OPS_DESC, FMT, ##__VA_ARGS__); \
        REPORT_INNER_ERR_MSG(ERR_CODE_STR, FMT, ##__VA_ARGS__); \
    } while (0)

// NOTE(review): body is identical to OPS_INNER_ERR_STUB — it may have been
// intended to call a call-error report macro instead of REPORT_INNER_ERR_MSG.
#define OPS_CALL_ERR_STUB(ERR_CODE_STR, OPS_DESC, FMT, ...) \
    do { \
        OPS_LOG_STUB(OP, DLOG_ERROR, OPS_DESC, FMT, ##__VA_ARGS__); \
        REPORT_INNER_ERR_MSG(ERR_CODE_STR, FMT, ##__VA_ARGS__); \
    } while (0)

// Per-level shorthands over OPS_LOG_STUB.
#define OPS_LOG_STUB_D(OPS_DESC, FMT, ...) OPS_LOG_STUB(OP, DLOG_DEBUG, OPS_DESC, FMT, ##__VA_ARGS__)
#define OPS_LOG_STUB_I(OPS_DESC, FMT, ...) OPS_LOG_STUB(OP, DLOG_INFO, OPS_DESC, FMT, ##__VA_ARGS__)
#define OPS_LOG_STUB_W(OPS_DESC, FMT, ...) OPS_LOG_STUB(OP, DLOG_WARN, OPS_DESC, FMT, ##__VA_ARGS__)
#define OPS_LOG_STUB_E(OPS_DESC, FMT, ...) OPS_LOG_STUB(OP, DLOG_ERROR, OPS_DESC, FMT, ##__VA_ARGS__)
#define OPS_LOG_STUB_EVENT(OPS_DESC, FMT, ...) OPS_LOG_STUB(OP, DLOG_EVENT, OPS_DESC, FMT, ##__VA_ARGS__)

// Formats the full message into a local buffer first, then logs it — split
// into chunks (and at newlines) when it exceeds the per-record limit.
// NOTE(review): MSG_LENGTH is expected to be defined by the including
// environment (alog headers); confirm it matches MAX_LOG_LEN's intent.
#define OPS_LOG_STUB_FULL(LEVEL, OPS_DESC, FMT, ...) \
    do { \
        if (0 == AlogCheckDebugLevel(OP, (LEVEL))) { \
            break; \
        } \
        char msgbufxyz[ops::utils::LogBase::MAX_LOG_LEN]; \
        size_t msgmaxlen = (MSG_LENGTH - ops::utils::LogBase::MSG_HDR_LEN); \
        int rettmp = snprintf_s(msgbufxyz, sizeof(msgbufxyz), sizeof(msgbufxyz) - 1, FMT, ##__VA_ARGS__); \
        if (rettmp == -1) { \
            msgbufxyz[sizeof(msgbufxyz) - 1] = '\0'; \
        } \
        size_t msglength = std::strlen(msgbufxyz); \
        if (msglength < msgmaxlen) { \
            OPS_LOG_STUB(OP, (LEVEL), (OPS_DESC), "%s", msgbufxyz); \
            break; \
        } \
        char *msgchunkbegin = msgbufxyz; \
        char *msgchunkend = nullptr; \
        while (msgchunkbegin < msgbufxyz + msglength) { \
            if (msgchunkbegin[0] == '\n') { \
                OPS_LOG_STUB(OP, (LEVEL), (OPS_DESC), ""); \
                msgchunkbegin += 1; \
                continue; \
            } \
            msgchunkend = std::strchr(msgchunkbegin, '\n'); \
            if (msgchunkend == nullptr) { \
                msgchunkend = msgchunkbegin + std::strlen(msgchunkbegin); \
            } \
            while (msgchunkend > msgchunkbegin) { \
                std::string msgchunk(msgchunkbegin, \
                                     std::min(msgmaxlen, static_cast<size_t>(msgchunkend - msgchunkbegin))); \
                OPS_LOG_STUB(OP, (LEVEL), (OPS_DESC), "%s", msgchunk.c_str()); \
                msgchunkbegin += msgchunk.size(); \
            } \
            msgchunkbegin += 1; \
        } \
    } while (0)
|
||||
59
csrc/utils/inc/log/ops_log.h
Normal file
59
csrc/utils/inc/log/ops_log.h
Normal file
@@ -0,0 +1,59 @@
|
||||
/**
|
||||
* Copyright (c) 2024 Huawei Technologies Co., Ltd.
|
||||
* This file is a part of the CANN Open Software.
|
||||
* Licensed under CANN Open Software License Agreement Version 1.0 (the "License").
|
||||
* Please refer to the License for details. You may not use this file except in compliance with the License.
|
||||
* THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED,
|
||||
* INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE.
|
||||
* See LICENSE in the root of the software repository for the full text of the License.
|
||||
*/
|
||||
|
||||
/*!
|
||||
* \file ops_log.h
|
||||
* \brief
|
||||
*/
|
||||
|
||||
#pragma once
|
||||
|
||||
#include "log/inner/dfx_base.h"
|
||||
|
||||
/* Basic logging: thin wrappers over the per-level stub macros.  OPS_LOG_E
 * additionally reports the error to the framework under code "EZ9999";
 * OPS_LOG_E_WITHOUT_REPORT only writes the error log record. */
#define OPS_LOG_D(OPS_DESC, ...) OPS_LOG_STUB_D(OPS_DESC, __VA_ARGS__)
#define OPS_LOG_I(OPS_DESC, ...) OPS_LOG_STUB_I(OPS_DESC, __VA_ARGS__)
#define OPS_LOG_W(OPS_DESC, ...) OPS_LOG_STUB_W(OPS_DESC, __VA_ARGS__)
#define OPS_LOG_E(OPS_DESC, ...) OPS_INNER_ERR_STUB("EZ9999", OPS_DESC, __VA_ARGS__)
#define OPS_LOG_E_WITHOUT_REPORT(OPS_DESC, ...) OPS_LOG_STUB_E(OPS_DESC, __VA_ARGS__)
#define OPS_LOG_EVENT(OPS_DESC, ...) OPS_LOG_STUB_EVENT(OPS_DESC, __VA_ARGS__)
|
||||
|
||||
/* Full logging:
 * emits an over-long message; if the message exceeds the per-record limit it
 * is split across multiple log lines (see OPS_LOG_STUB_FULL). */
#define OPS_LOG_FULL(LEVEL, OPS_DESC, ...) OPS_LOG_STUB_FULL(LEVEL, OPS_DESC, __VA_ARGS__)
#define OPS_LOG_D_FULL(OPS_DESC, ...) OPS_LOG_STUB_FULL(DLOG_DEBUG, OPS_DESC, __VA_ARGS__)
#define OPS_LOG_I_FULL(OPS_DESC, ...) OPS_LOG_STUB_FULL(DLOG_INFO, OPS_DESC, __VA_ARGS__)
#define OPS_LOG_W_FULL(OPS_DESC, ...) OPS_LOG_STUB_FULL(DLOG_WARN, OPS_DESC, __VA_ARGS__)
|
||||
|
||||
/* Conditional logging: when COND holds, log at the given level (with the
 * trailing format arguments) and then evaluate EXPR (typically `return x`,
 * `break` or `continue`).  Expansion is delegated to OPS_LOG_STUB_IF. */
#define OPS_LOG_D_IF(COND, OP_DESC, EXPR, ...) OPS_LOG_STUB_IF(COND, OPS_LOG_D(OP_DESC, __VA_ARGS__), EXPR)
#define OPS_LOG_I_IF(COND, OP_DESC, EXPR, ...) OPS_LOG_STUB_IF(COND, OPS_LOG_I(OP_DESC, __VA_ARGS__), EXPR)
#define OPS_LOG_W_IF(COND, OP_DESC, EXPR, ...) OPS_LOG_STUB_IF(COND, OPS_LOG_W(OP_DESC, __VA_ARGS__), EXPR)
#define OPS_LOG_E_IF(COND, OP_DESC, EXPR, ...) OPS_LOG_STUB_IF(COND, OPS_LOG_E(OP_DESC, __VA_ARGS__), EXPR)
#define OPS_LOG_EVENT_IF(COND, OP_DESC, EXPR, ...) OPS_LOG_STUB_IF(COND, OPS_LOG_EVENT(OP_DESC, __VA_ARGS__), EXPR)
|
||||
|
||||
// If PTR is nullptr: log an error, report a call error under "EZ9999", then
// evaluate EXPR (typically `return ...`).  The trailing `else ((void)0)`
// makes the macro safe inside unbraced if/else chains while still allowing
// `break`/`continue`/`return` in EXPR (a do/while(0) wrapper would swallow
// break/continue).  A trailing ';' at the call site is required.
// NOTE(review): the message is emitted twice (OPS_LOG_STUB_E plus the log
// inside OPS_CALL_ERR_STUB) - confirm the duplication is intentional.
#define OPS_LOG_E_IF_NULL(OPS_DESC, PTR, EXPR) \
    if (__builtin_expect((PTR) == nullptr, 0)) { \
        OPS_LOG_STUB_E(OPS_DESC, "%s is nullptr!", #PTR); \
        OPS_CALL_ERR_STUB("EZ9999", OPS_DESC, "%s is nullptr!", #PTR); \
        EXPR; \
    } else \
        ((void)0)
|
||||
|
||||
// If COND holds, evaluate LOG_FUNC (a logging call) and then EXPR (typically
// `return ...`, `break` or `continue`).  Deliberately NOT wrapped in
// do/while(0) so break/continue in EXPR act on the caller's loop; the
// trailing `else ((void)0)` makes the macro safe inside unbraced if/else
// chains.  A trailing ';' at the call site is required.
#define OPS_CHECK(COND, LOG_FUNC, EXPR) \
    if (COND) { \
        LOG_FUNC; \
        EXPR; \
    } else \
        ((void)0)
|
||||
|
||||
// Identical to OPS_CHECK (kept for source compatibility with callers that use
// the shorter name): if COND holds, run LOG_FUNC then EXPR.  Not do/while(0)
// wrapped so break/continue in EXPR reach the caller's loop; the trailing
// `else ((void)0)` makes it safe inside unbraced if/else chains.  A trailing
// ';' at the call site is required.
#define OP_CHECK(COND, LOG_FUNC, EXPR) \
    if (COND) { \
        LOG_FUNC; \
        EXPR; \
    } else \
        ((void)0)
|
||||
47
csrc/utils/inc/tiling/data_copy_transpose_tiling.h
Normal file
47
csrc/utils/inc/tiling/data_copy_transpose_tiling.h
Normal file
@@ -0,0 +1,47 @@
|
||||
/**
|
||||
* Copyright (c) 2023-2024 Huawei Technologies Co., Ltd.
|
||||
* This file is a part of the CANN Open Software.
|
||||
* Licensed under CANN Open Software License Agreement Version 1.0 (the "License").
|
||||
* Please refer to the License for details. You may not use this file except in compliance with the License.
|
||||
* THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED,
|
||||
* INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE.
|
||||
* See LICENSE in the root of the software repository for the full text of the License.
|
||||
*/
|
||||
|
||||
/*!
|
||||
* \file data_copy_transpose_tiling.h
|
||||
* \brief
|
||||
*/
|
||||
|
||||
#pragma once
|
||||
|
||||
#include <vector>
|
||||
#include <graph/tensor.h>
|
||||
#include "data_copy_transpose_tiling_def.h"
|
||||
|
||||
namespace optiling {
|
||||
|
||||
// Fills a CopyTransposeTiling struct from destination/source shapes.
// Expected layouts (from the field assignments below): dst = (B, N, S, H),
// src = (B, N, S, H/N); typeSize is the element size in bytes.
// Shapes with fewer than 4 dimensions are rejected (the tiling is left
// untouched) instead of reading out of bounds.
// NOTE(review): dstShapeHN divides H by N; callers must guarantee N != 0.
inline void GetDataCopyTransposeTiling(const ge::Shape &dstShape, const ge::Shape &srcShape, const uint32_t typeSize,
                                       optiling::CopyTransposeTiling &tiling)
{
    const std::vector<int64_t> dstShapeInfo = dstShape.GetDims();
    const std::vector<int64_t> srcShapeInfo = srcShape.GetDims();

    // Guard: every index below assumes rank-4 shapes.
    constexpr size_t kRequiredDims = 4;
    if (dstShapeInfo.size() < kRequiredDims || srcShapeInfo.size() < kRequiredDims) {
        return;
    }

    tiling.set_dstShapeB(dstShapeInfo[0]);
    tiling.set_dstShapeN(dstShapeInfo[1]);
    tiling.set_dstShapeS(dstShapeInfo[2]);
    tiling.set_dstShapeH(dstShapeInfo[3]);
    tiling.set_dstShapeHN(tiling.get_dstShapeH() / tiling.get_dstShapeN());

    tiling.set_srcShapeB(srcShapeInfo[0]);
    tiling.set_srcShapeN(srcShapeInfo[1]);
    tiling.set_srcShapeS(srcShapeInfo[2]);
    tiling.set_srcShapeHN(srcShapeInfo[3]);
    // Derived products used by the transpose kernel's address arithmetic.
    tiling.set_originalShapeNLen(tiling.get_srcShapeHN() * typeSize);
    tiling.set_shapeSHValue(tiling.get_dstShapeS() * tiling.get_dstShapeH());
    tiling.set_shapeNsValue(tiling.get_dstShapeN() * tiling.get_dstShapeS());
    tiling.set_shapeNsnValue(tiling.get_dstShapeN() * tiling.get_srcShapeS() * tiling.get_srcShapeN());
    tiling.set_shapeBHValue(tiling.get_dstShapeB() * tiling.get_dstShapeH());
}
|
||||
|
||||
} // namespace optiling
|
||||
43
csrc/utils/inc/tiling/data_copy_transpose_tiling_def.h
Normal file
43
csrc/utils/inc/tiling/data_copy_transpose_tiling_def.h
Normal file
@@ -0,0 +1,43 @@
|
||||
/**
|
||||
* Copyright (c) 2024 Huawei Technologies Co., Ltd.
|
||||
* This file is a part of the CANN Open Software.
|
||||
* Licensed under CANN Open Software License Agreement Version 1.0 (the "License").
|
||||
* Please refer to the License for details. You may not use this file except in compliance with the License.
|
||||
* THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED,
|
||||
* INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE.
|
||||
* See LICENSE in the root of the software repository for the full text of the License.
|
||||
*/
|
||||
|
||||
/*!
|
||||
* \file data_copy_transpose_tiling_def.h
|
||||
* \brief
|
||||
*/
|
||||
|
||||
#pragma once
|
||||
|
||||
#include <cstdint>
|
||||
#include <register/tilingdata_base.h>
|
||||
|
||||
namespace optiling {
|
||||
|
||||
// Tiling data consumed by the DataCopy transpose kernels; populated by
// GetDataCopyTransposeTiling (data_copy_transpose_tiling.h).  Field meanings
// follow that function's assignments: dst = (B, N, S, H), src = (B, N, S, HN).
BEGIN_TILING_DATA_DEF(CopyTransposeTiling)
TILING_DATA_FIELD_DEF(uint32_t, dstShapeB);   // destination dim 0
TILING_DATA_FIELD_DEF(uint32_t, dstShapeN);   // destination dim 1
TILING_DATA_FIELD_DEF(uint32_t, dstShapeS);   // destination dim 2
TILING_DATA_FIELD_DEF(uint32_t, dstShapeHN);  // dstShapeH / dstShapeN
TILING_DATA_FIELD_DEF(uint32_t, dstShapeH);   // destination dim 3
TILING_DATA_FIELD_DEF(uint32_t, srcShapeB);   // source dim 0
TILING_DATA_FIELD_DEF(uint32_t, srcShapeN);   // source dim 1
TILING_DATA_FIELD_DEF(uint32_t, srcShapeS);   // source dim 2
TILING_DATA_FIELD_DEF(uint32_t, srcShapeHN);  // source dim 3
TILING_DATA_FIELD_DEF(uint32_t, originalShapeNLen);  // srcShapeHN * typeSize (bytes)
TILING_DATA_FIELD_DEF(uint32_t, shapeSHValue);       // dstShapeS * dstShapeH
TILING_DATA_FIELD_DEF(uint32_t, shapeNsValue);       // dstShapeN * dstShapeS
TILING_DATA_FIELD_DEF(uint32_t, shapeNsnValue);      // dstShapeN * srcShapeS * srcShapeN
TILING_DATA_FIELD_DEF(uint32_t, invalidParamCopyTransposeTiling);  // not set by GetDataCopyTransposeTiling - presumably an invalid-input flag, confirm at call sites
TILING_DATA_FIELD_DEF(uint32_t, shapeBHValue);       // dstShapeB * dstShapeH
TILING_DATA_FIELD_DEF(uint32_t, paramsAlign);        // not set by GetDataCopyTransposeTiling - confirm at call sites
END_TILING_DATA_DEF;
REGISTER_TILING_DATA_CLASS(CopyTransposeTilingOp, CopyTransposeTiling)
|
||||
|
||||
} // namespace optiling
|
||||
225
csrc/utils/inc/tiling/tiling_base.h
Normal file
225
csrc/utils/inc/tiling/tiling_base.h
Normal file
@@ -0,0 +1,225 @@
|
||||
/**
|
||||
* Copyright (c) 2023-2024 Huawei Technologies Co., Ltd.
|
||||
* This file is a part of the CANN Open Software.
|
||||
* Licensed under CANN Open Software License Agreement Version 1.0 (the "License").
|
||||
* Please refer to the License for details. You may not use this file except in compliance with the License.
|
||||
* THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED,
|
||||
* INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE.
|
||||
* See LICENSE in the root of the software repository for the full text of the License.
|
||||
*/
|
||||
|
||||
/*!
|
||||
* \file tiling_base.h
|
||||
* \brief
|
||||
*/
|
||||
|
||||
#pragma once
|
||||
|
||||
#include <sstream>
|
||||
#include <exe_graph/runtime/tiling_context.h>
|
||||
#include <graph/utils/type_utils.h>
|
||||
#include <tiling/platform/platform_ascendc.h>
|
||||
#include "log/ops_log.h"
|
||||
|
||||
#ifdef ASCENDC_OP_TEST
|
||||
#define ASCENDC_EXTERN_C extern "C"
|
||||
#else
|
||||
#define ASCENDC_EXTERN_C
|
||||
#endif
|
||||
|
||||
namespace optiling {
|
||||
|
||||
// Per-platform AI-core resource limits used by tiling computations.
// NOTE(review): sizes are presumably in bytes - confirm against the platform
// query API that fills this struct.
struct AiCoreParams {
    uint64_t ubSize;    // Unified Buffer capacity
    uint64_t blockDim;  // block dimension used for kernel launch
    uint64_t aicNum;    // number of AI Cube cores
    uint64_t l1Size;    // L1 buffer capacity
    uint64_t l0aSize;   // L0A buffer capacity
    uint64_t l0bSize;   // L0B buffer capacity
    uint64_t l0cSize;   // L0C buffer capacity
};
|
||||
|
||||
// Platform/compile information cached for FlashAttentionScoreGrad tiling.
// NOTE(review): sizes are presumably in bytes - confirm against the producer.
struct FlashAttentionScoreGradCompileInfo {
    uint32_t aivNum;        // number of AI Vector cores
    uint32_t aicNum;        // number of AI Cube cores
    uint64_t ubSize;        // Unified Buffer capacity
    uint64_t l1Size;        // L1 buffer capacity
    uint64_t l0aSize;       // L0A buffer capacity
    uint64_t l0bSize;       // L0B buffer capacity
    uint64_t l0cSize;       // L0C buffer capacity
    uint64_t l2CacheSize;   // L2 cache capacity
    int64_t coreNum;        // total core count
};
|
||||
|
||||
// Base class driving the op tiling workflow.  Concrete tilings implement the
// numbered hooks below and DoTiling() invokes them in a fixed order.
class TilingBaseClass {
public:
    TilingBaseClass() = default;

    explicit TilingBaseClass(gert::TilingContext *context) : context_(context)
    {
    }

    virtual ~TilingBaseClass() = default;

    // Tiling execution framework.  Return values:
    // 1. GRAPH_SUCCESS: success; no further tiling class needs to run.
    // 2. GRAPH_FAILED: failure; abort the whole tiling flow.
    // 3. GRAPH_PARAM_INVALID: this class does not support the case; continue
    //    with the next registered tiling class.
    ge::graphStatus DoTiling()
    {
        auto ret = GetShapeAttrsInfo();
        if (ret != ge::GRAPH_SUCCESS) {
            return ret;
        }
        ret = GetPlatformInfo();
        if (ret != ge::GRAPH_SUCCESS) {
            return ret;
        }
        if (!IsCapable()) {
            return ge::GRAPH_PARAM_INVALID;
        }
        ret = DoOpTiling();
        if (ret != ge::GRAPH_SUCCESS) {
            return ret;
        }
        ret = DoLibApiTiling();
        if (ret != ge::GRAPH_SUCCESS) {
            return ret;
        }
        ret = GetWorkspaceSize();
        if (ret != ge::GRAPH_SUCCESS) {
            return ret;
        }
        ret = PostTiling();
        if (ret != ge::GRAPH_SUCCESS) {
            return ret;
        }
        context_->SetTilingKey(GetTilingKey());
        DumpTilingInfo();
        return ge::GRAPH_SUCCESS;
    }

    // Rebind this tiling object to a new context (allows instance reuse).
    virtual void Reset(gert::TilingContext *context)
    {
        context_ = context;
    }

protected:
    // Whether this tiling class can handle the current inputs/platform.
    virtual bool IsCapable() = 0;
    // 1. Query platform info such as core count and UB/L1/L0C capacities.
    virtual ge::graphStatus GetPlatformInfo() = 0;
    // 2. Collect input/output/attribute information.
    virtual ge::graphStatus GetShapeAttrsInfo() = 0;
    // 3. Compute the data-split tiling data.
    virtual ge::graphStatus DoOpTiling() = 0;
    // 4. Compute tiling data for high-level API libraries.
    virtual ge::graphStatus DoLibApiTiling() = 0;
    // 5. Compute the tiling key.
    [[nodiscard]] virtual uint64_t GetTilingKey() const = 0;
    // 6. Compute the workspace size.
    virtual ge::graphStatus GetWorkspaceSize() = 0;
    // 7. Persist the tiling data.
    virtual ge::graphStatus PostTiling() = 0;
    // 8. Dump the tiling data (only when debug logging is enabled).
    virtual void DumpTilingInfo()
    {
        int32_t enable = AlogCheckDebugLevel(static_cast<int32_t>(OP), DLOG_DEBUG);
        if (enable != 1) {
            return;
        }
        auto buf = (uint32_t *)context_->GetRawTilingData()->GetData();
        auto bufLen = context_->GetRawTilingData()->GetDataSize();
        std::ostringstream oss;
        oss << "Start to dump tiling info. tilingkey:" << GetTilingKey() << ", tiling data size:" << bufLen
            << ", content:";
        for (size_t i = 0; i < bufLen / sizeof(uint32_t); i++) {
            oss << *(buf + i) << ",";
            // Flush periodically so one log record never exceeds the line limit.
            if (oss.str().length() > 640) { // Split according to 640 to avoid truncation
                OPS_LOG_D(context_, "%s", oss.str().c_str());
                oss.str("");
            }
        }
        OPS_LOG_D(context_, "%s", oss.str().c_str());
    }

    // Converts a slice count into a TSCH block dim, dividing by the AIV/AIC
    // core ratio (rounded up).  Falls back to sliceNum for degenerate counts.
    static uint32_t CalcTschBlockDim(uint32_t sliceNum, uint32_t aicCoreNum, uint32_t aivCoreNum)
    {
        uint32_t ration;
        if (aicCoreNum == 0 || aivCoreNum == 0 || aicCoreNum > aivCoreNum) {
            return sliceNum;
        }
        ration = aivCoreNum / aicCoreNum;
        return (sliceNum + (ration - 1)) / ration;
    }

    // Renders a shape as "[d0, d1, ...]" for debug logs.
    template <typename T> [[nodiscard]] std::string GetShapeDebugStr(const T &shape) const
    {
        std::ostringstream oss;
        oss << "[";
        if (shape.GetDimNum() > 0) {
            for (size_t i = 0; i < shape.GetDimNum() - 1; ++i) {
                oss << shape.GetDim(i) << ", ";
            }
            oss << shape.GetDim(shape.GetDimNum() - 1);
        }
        oss << "]";
        return oss.str();
    }

    // Renders dtype/shape/format of one tensor for debug logs; returns "nil "
    // when either descriptor is missing.
    [[nodiscard]] std::string GetTensorDebugStr(const gert::StorageShape *shape,
                                                const gert::CompileTimeTensorDesc *tensor)
    {
        if (shape == nullptr || tensor == nullptr) {
            return "nil ";
        }
        std::ostringstream oss;
        oss << "(dtype: " << ge::TypeUtils::DataTypeToSerialString(tensor->GetDataType()) << "),";
        oss << "(shape:" << GetShapeDebugStr(shape->GetStorageShape()) << "),";
        oss << "(ori_shape:" << GetShapeDebugStr(shape->GetOriginShape()) << "),";
        oss << "(format: "
            << ge::TypeUtils::FormatToSerialString(
                   static_cast<ge::Format>(ge::GetPrimaryFormat(tensor->GetStorageFormat())))
            << "),";
        oss << "(ori_format: " << ge::TypeUtils::FormatToSerialString(tensor->GetOriginFormat()) << ") ";
        return oss.str();
    }

    // Renders every input/output tensor of the current context for debug logs.
    [[nodiscard]] std::string GetTilingContextDebugStr()
    {
        std::ostringstream oss;
        for (size_t i = 0; i < context_->GetComputeNodeInfo()->GetInputsNum(); ++i) {
            oss << "input" << i << ": ";
            oss << GetTensorDebugStr(context_->GetInputShape(i), context_->GetInputDesc(i));
        }

        for (size_t i = 0; i < context_->GetComputeNodeInfo()->GetOutputsNum(); ++i) {
            oss << "output" << i << ": ";
            oss << GetTensorDebugStr(context_->GetOutputShape(i), context_->GetOutputDesc(i));
        }
        return oss.str();
    }

    // Renders the raw tiling data as comma-separated int32 values.
    [[nodiscard]] std::string GetTilingDataDebugStr() const
    {
        auto rawTilingData = context_->GetRawTilingData();
        auto rawTilingDataSize = rawTilingData->GetDataSize();
        auto data = reinterpret_cast<const int32_t *>(rawTilingData->GetData());
        size_t len = rawTilingDataSize / sizeof(int32_t);
        std::ostringstream oss;
        for (size_t i = 0; i < len; i++) {
            oss << data[i] << ", ";
        }
        return oss.str();
    }

protected:
    gert::TilingContext *context_ = nullptr;                                     // not owned
    std::unique_ptr<platform_ascendc::PlatformAscendC> ascendcPlatform_{nullptr};
    uint32_t blockDim_{0};        // chosen block dim
    uint64_t workspaceSize_{0};   // computed workspace size
    uint64_t tilingKey_{0};       // computed tiling key
    AiCoreParams aicoreParams_{0, 0, 0, 0, 0, 0, 0};
};
|
||||
|
||||
} // namespace optiling
|
||||
162
csrc/utils/inc/tiling/tiling_templates_registry.h
Normal file
162
csrc/utils/inc/tiling/tiling_templates_registry.h
Normal file
@@ -0,0 +1,162 @@
|
||||
/**
|
||||
* Copyright (c) 2023-2024 Huawei Technologies Co., Ltd.
|
||||
* This file is a part of the CANN Open Software.
|
||||
* Licensed under CANN Open Software License Agreement Version 1.0 (the "License").
|
||||
* Please refer to the License for details. You may not use this file except in compliance with the License.
|
||||
* THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED,
|
||||
* INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE.
|
||||
* See LICENSE in the root of the software repository for the full text of the License.
|
||||
*/
|
||||
|
||||
/*!
|
||||
* \file tiling_templates_registry.h
|
||||
* \brief
|
||||
*/
|
||||
|
||||
#pragma once
|
||||
|
||||
#include <map>
|
||||
#include <string>
|
||||
#include <memory>
|
||||
#include <exe_graph/runtime/tiling_context.h>
|
||||
#include "tiling/tiling_base.h"
|
||||
#include "log/ops_log.h"
|
||||
#include "error/ops_error.h"
|
||||
|
||||
namespace optiling {
|
||||
|
||||
// Factory helper: creates a tiling object of concrete type T (which must
// derive from TilingBaseClass) bound to the given context.  Returns nullptr
// on allocation failure (nothrow new).
template <typename T> std::unique_ptr<TilingBaseClass> TILING_CLASS(gert::TilingContext *context)
{
    return std::unique_ptr<T>(new (std::nothrow) T(context));
}

// Signature of the factory function stored per registered tiling class.
using TilingClassCase = std::unique_ptr<TilingBaseClass> (*)(gert::TilingContext *);
|
||||
|
||||
class TilingCases {
|
||||
public:
|
||||
explicit TilingCases(std::string op_type) : op_type_(std::move(op_type))
|
||||
{
|
||||
}
|
||||
|
||||
template <typename T> void AddTiling(int32_t priority)
|
||||
{
|
||||
OPS_ERR_IF(cases_.find(priority) != cases_.end(),
|
||||
OPS_REPORT_VECTOR_INNER_ERR(op_type_, "There are duplicate registrations."), return);
|
||||
cases_[priority] = TILING_CLASS<T>;
|
||||
OPS_ERR_IF(
|
||||
cases_[priority] == nullptr,
|
||||
OPS_REPORT_VECTOR_INNER_ERR(op_type_, "Register op tiling func failed, please check the class name."),
|
||||
return);
|
||||
}
|
||||
|
||||
const std::map<int32_t, TilingClassCase> &GetTilingCases()
|
||||
{
|
||||
return cases_;
|
||||
}
|
||||
|
||||
private:
|
||||
std::map<int32_t, TilingClassCase> cases_;
|
||||
const std::string op_type_;
|
||||
};
|
||||
|
||||
class TilingRegistry {
|
||||
public:
|
||||
TilingRegistry() = default;
|
||||
|
||||
#ifdef ASCENDC_OP_TEST
|
||||
static TilingRegistry &GetInstance();
|
||||
#else
|
||||
static TilingRegistry &GetInstance()
|
||||
{
|
||||
static TilingRegistry registry_impl_;
|
||||
return registry_impl_;
|
||||
}
|
||||
#endif
|
||||
|
||||
std::shared_ptr<TilingCases> RegisterOp(const std::string &op_type)
|
||||
{
|
||||
if (registry_map_.find(op_type) == registry_map_.end()) {
|
||||
registry_map_[op_type] = std::shared_ptr<TilingCases>(new (std::nothrow) TilingCases(op_type));
|
||||
}
|
||||
OPS_ERR_IF(registry_map_[op_type] == nullptr,
|
||||
OPS_REPORT_VECTOR_INNER_ERR(op_type, "Register tiling func failed, please check the class name."),
|
||||
return nullptr);
|
||||
return registry_map_[op_type];
|
||||
}
|
||||
|
||||
ge::graphStatus DoTilingImpl(gert::TilingContext *context)
|
||||
{
|
||||
const char *op_type = context->GetNodeType();
|
||||
auto tilingTemplateRegistryMap = GetTilingTemplates(op_type);
|
||||
for (auto it = tilingTemplateRegistryMap.begin(); it != tilingTemplateRegistryMap.end(); ++it) {
|
||||
auto tilingTemplate = it->second(context);
|
||||
if (tilingTemplate != nullptr) {
|
||||
ge::graphStatus status = tilingTemplate->DoTiling();
|
||||
if (status != ge::GRAPH_PARAM_INVALID) {
|
||||
OPS_LOG_D(context, "Do general op tiling success priority=%d", it->first);
|
||||
return status;
|
||||
}
|
||||
OPS_LOG_D(context, "Ignore general op tiling priority=%d", it->first);
|
||||
}
|
||||
}
|
||||
OPS_REPORT_VECTOR_INNER_ERR(op_type, "Do op tiling failed, no valid template is found.");
|
||||
return ge::GRAPH_FAILED;
|
||||
}
|
||||
|
||||
ge::graphStatus DoTilingImpl(gert::TilingContext *context, const std::vector<int32_t> &priorities)
|
||||
{
|
||||
const char *op_type = context->GetNodeType();
|
||||
auto tilingTemplateRegistryMap = GetTilingTemplates(op_type);
|
||||
for (auto priorityId : priorities) {
|
||||
auto templateFunc = tilingTemplateRegistryMap[priorityId](context);
|
||||
if (templateFunc != nullptr) {
|
||||
ge::graphStatus status = templateFunc->DoTiling();
|
||||
if (status == ge::GRAPH_SUCCESS) {
|
||||
OPS_LOG_D(context, "Do general op tiling success priority=%d", priorityId);
|
||||
return status;
|
||||
}
|
||||
OPS_LOG_D(context, "Ignore general op tiling priority=%d", priorityId);
|
||||
}
|
||||
}
|
||||
return ge::GRAPH_FAILED;
|
||||
}
|
||||
|
||||
const std::map<int32_t, TilingClassCase> &GetTilingTemplates(const std::string &op_type)
|
||||
{
|
||||
OPS_ERR_IF(registry_map_.find(op_type) == registry_map_.end(),
|
||||
OPS_REPORT_VECTOR_INNER_ERR(op_type, "Get op tiling func failed, please check the op name."),
|
||||
return empty_tiling_case_);
|
||||
return registry_map_[op_type]->GetTilingCases();
|
||||
}
|
||||
|
||||
private:
|
||||
std::map<std::string, std::shared_ptr<TilingCases>> registry_map_;
|
||||
const std::map<int32_t, TilingClassCase> empty_tiling_case_ {};
|
||||
};
|
||||
|
||||
class Register {
|
||||
public:
|
||||
explicit Register(std::string op_type) : op_type_(std::move(op_type))
|
||||
{
|
||||
}
|
||||
|
||||
template <typename T> Register &tiling(int32_t priority)
|
||||
{
|
||||
auto tilingCases = TilingRegistry::GetInstance().RegisterOp(op_type_);
|
||||
OPS_ERR_IF(tilingCases == nullptr,
|
||||
OPS_REPORT_VECTOR_INNER_ERR(op_type_, "Register op tiling failed, please the op name."),
|
||||
return *this);
|
||||
tilingCases->AddTiling<T>(priority);
|
||||
return *this;
|
||||
}
|
||||
|
||||
private:
|
||||
const std::string op_type_;
|
||||
};
|
||||
|
||||
// op_type: operator name; class_name: the tiling class being registered;
// priority: priority of the tiling class - smaller means higher priority,
// i.e. the class is tried (and therefore selected) earlier.
// NOTE(review): in the pasted identifier, `op_type_` and `priority_register`
// are NOT macro parameters (the parameters are `op_type` and `priority`), so
// they paste as literal tokens and the generated variable name varies only
// with class_name.  Registering the same class for two ops/priorities in one
// translation unit would collide - confirm intended usage before relying on it.
#define REGISTER_TILING_TEMPLATE(op_type, class_name, priority) \
    static Register VAR_UNUSED##op_type_##class_name##priority_register = Register(op_type).tiling<class_name>(priority)
|
||||
|
||||
} // namespace optiling
|
||||
136
csrc/utils/inc/tiling/tiling_type.h
Normal file
136
csrc/utils/inc/tiling/tiling_type.h
Normal file
@@ -0,0 +1,136 @@
|
||||
/**
|
||||
* Copyright (c) 2023-2024 Huawei Technologies Co., Ltd.
|
||||
* This file is a part of the CANN Open Software.
|
||||
* Licensed under CANN Open Software License Agreement Version 1.0 (the "License").
|
||||
* Please refer to the License for details. You may not use this file except in compliance with the License.
|
||||
* THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED,
|
||||
* INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE.
|
||||
* See LICENSE in the root of the software repository for the full text of the License.
|
||||
*/
|
||||
|
||||
/*!
|
||||
* \file tiling_type.h
|
||||
* \brief
|
||||
*/
|
||||
|
||||
#pragma once
|
||||
|
||||
#include <cstdint>
|
||||
|
||||
namespace optiling {
|
||||
|
||||
// Axes a tiling key digit can refer to (attention-style dimension names;
// NONE = no axis / not split).
enum class AxisEnum {
    B = 0,
    N2 = 1,
    G = 2,
    S1 = 3,
    S2 = 4,
    D = 5,
    NONE = 9,
};

// Input/output data types encodable in a tiling key digit.
enum class DtypeEnum {
    FLOAT16 = 0,
    FLOAT32 = 1,
    BFLOAT16 = 2,
    FLOAT16_PRECISION = 3,
};

// Buffering strategy selector.
enum class PerformanceOrientedEnum {
    BIG_BUFFER = 1,
    BIG_DOUBLE_BUFFER = 2,
};

// Matmul high-level-API configuration variants.
enum class MatmulConfig {
    NULL_CONFIG = 0,
    NORMAL_CONFIG = 1,
    MDL_CONFIG = 2
};

// Whether a PSE (positional bias) input is present.
enum class PseConfig {
    NO_PSE = 0,
    EXIST_PSE = 1
};

// Whether an attention mask input is present.
enum class AttenMaskConfig {
    NO_ATTEN_MASK = 0,
    EXIST_ATTEN_MASK = 1
};

// Whether a dropout mask input is present.
enum class DropOutConfig {
    NO_DROP_OUT = 0,
    EXIST_DROP_OUT = 1
};

// Cube-input data format: plain ND or NZ (fractal).
enum class CubeFormatEnum {
    ND = 0,
    NZ = 1
};
// Supported tensor layouts.
enum class LayoutEnum {
    BSND = 0,
    SBND = 1,
    BNSD = 2,
    TND = 3
};

// Where cube inputs are sourced from: global memory or L1 buffer.
enum class CubeInputSourceEnum {
    GM = 0,
    L1 = 1
};

// Generic on/off switch for optional features.
enum class OptionEnum {
    DISABLE = 0,
    ENABLE = 1
};

// Sparse attention pattern variants encodable in a tiling key digit.
enum class SparseEnum {
    ALL = 0,
    NONE = 1,
    ANY = 2,
    CAUSAL = 3,
    BAND = 4,
    PREFIX = 5,
    BAND_COMPRESS = 6,
    RIGHT_DOWN_CAUSAL = 7,
    RIGHT_DOWN_CAUSAL_BAND = 8,
    BAND_LEFT_UP_CAUSAL = 9
};
|
||||
|
||||
// Packs the given template ids into one integer, one decimal digit of weight
// per argument: the first argument occupies the least-significant position
// and each subsequent argument is scaled by a further factor of 10.
constexpr uint64_t RecursiveSum()
{
    // Base case: an empty pack contributes nothing.
    return 0;
}

template <typename T, typename... Args> constexpr uint64_t RecursiveSum(T headId, Args... tailIds)
{
    // Peel the head off the pack and shift the remainder up by one digit.
    const uint64_t head = static_cast<uint64_t>(headId);
    return head + 10 * RecursiveSum(tailIds...);
}
|
||||
|
||||
// TilingKey generation rules:
// FlashAttentionScore/FlashAttentionScoreGrad assemble the tiling key from
// decimal digits holding the following parameters, lowest digit first:
// Ub0, Ub1, Block, DataType, Format, Sparse, then any specialized template.
// Ub0, Ub1: the axes split inside a UB core (AxisEnum).  Up to two axes may
//   be split, hence UB0 and UB1; use AXIS_NONE when there is no in-core UB
//   split.  One decimal digit each.
// Block: the axis used to split work across cores (AxisEnum), one digit.
// DataType: the input/output data type supported by this key
//   (SupportedDtype), one digit.
// Format: the format supported by this key (InputLayout), one digit.
// Sparse: whether this key supports sparse (SparseCapability), one digit.
// Other specialized scenarios define their own digit fields and values.
// usage: get tilingKey from input types
// uint64_t tilingKey = GET_FLASHATTENTION_TILINGKEY(AxisEnum::AXIS_S1, AxisEnum::AXIS_S2, AxisEnum::AXIS_N2,
// SupportedDtype::FLOAT32, InputLayout::BSH, SparseCapability::SUPPORT_ALL)

// Base offset giving every key a fixed decimal width; 10^19 still fits in
// uint64_t (max ~1.8 * 10^19).
constexpr uint64_t TILINGKEYOFFSET = uint64_t(10000000000000000000UL); // 10^19
// Builds a tiling key by packing the ids (one decimal digit each, first id in
// the lowest digit via RecursiveSum) on top of TILINGKEYOFFSET.
template <typename... Args> constexpr uint64_t GET_TILINGKEY(Args... templateIds)
{
    return TILINGKEYOFFSET + RecursiveSum(templateIds...);
}
|
||||
|
||||
// usage: get tilingKey from input enumerator names
// uint64_t tilingKey = TILINGKEY(S2, S1, N2, FLOAT32, BSND, ALL)

// Convenience wrapper: takes bare enumerator names, qualifies each with its
// enum class (Axis/Dtype/Layout/Sparse) and forwards to GET_TILINGKEY.
#define TILINGKEY(ub2, ub1, block, dtype, layout, sparse) \
    (GET_TILINGKEY(AxisEnum::ub2, AxisEnum::ub1, AxisEnum::block, DtypeEnum::dtype, LayoutEnum::layout, \
                   SparseEnum::sparse))
|
||||
|
||||
} // namespace optiling
|
||||
Reference in New Issue
Block a user