moe_gating_top_k (#5271)
1. What this PR does / why we need it?
This PR supports the moe_gating_top_k operator, which enables
post-positioned renormalization (renorm) on the basis of softmax.
2. Does this PR introduce any user-facing change?
No user-facing changes are required.
3. How was this patch tested?
This patch was tested with the test_npu_moe_gating_top_k test case.
vLLM version: release/v0.13.0
vLLM main:
ad32e3e19c
---------
Signed-off-by: ZCG12345 <2097562023@qq.com>
Signed-off-by: zzzzwwjj <34335947+zzzzwwjj@users.noreply.github.com>
Co-authored-by: zzzzwwjj <34335947+zzzzwwjj@users.noreply.github.com>
This commit is contained in:
@@ -0,0 +1,51 @@
|
||||
/**
|
||||
* This program is free software, you can redistribute it and/or modify.
|
||||
* Copyright (c) 2025 Huawei Technologies Co., Ltd.
|
||||
* This file is a part of the CANN Open Software.
|
||||
* Licensed under CANN Open Software License Agreement Version 2.0 (the "License").
|
||||
* Please refer to the License for details. You may not use this file except in compliance with the License.
|
||||
* THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE.
|
||||
* See LICENSE in the root of the software repository for the full text of the License.
|
||||
*/
|
||||
|
||||
/*!
|
||||
* \file data_copy_transpose_tiling.h
|
||||
* \brief
|
||||
*/
|
||||
|
||||
#pragma once
|
||||
|
||||
#include <vector>
|
||||
#include <graph/tensor.h>
|
||||
#include "data_copy_transpose_tiling_def.h"
|
||||
|
||||
namespace optiling {
|
||||
|
||||
inline void GetDataCopyTransposeTiling(const ge::Shape &dstShape, const ge::Shape &srcShape, const uint32_t typeSize,
|
||||
optiling::CopyTransposeTiling &tiling)
|
||||
{
|
||||
constexpr int64_t B_INDEX = 0;
|
||||
constexpr int64_t N_INDEX = 1;
|
||||
constexpr int64_t S_INDEX = 2;
|
||||
constexpr int64_t H_INDEX = 3;
|
||||
std::vector<int64_t> dstShapeInfo = dstShape.GetDims();
|
||||
std::vector<int64_t> srcShapeInfo = srcShape.GetDims();
|
||||
|
||||
tiling.set_dstShapeB(dstShapeInfo[B_INDEX]);
|
||||
tiling.set_dstShapeN(dstShapeInfo[N_INDEX]);
|
||||
tiling.set_dstShapeS(dstShapeInfo[S_INDEX]);
|
||||
tiling.set_dstShapeH(dstShapeInfo[H_INDEX]);
|
||||
tiling.set_dstShapeHN(tiling.get_dstShapeH() / tiling.get_dstShapeN());
|
||||
|
||||
tiling.set_srcShapeB(srcShapeInfo[B_INDEX]);
|
||||
tiling.set_srcShapeN(srcShapeInfo[N_INDEX]);
|
||||
tiling.set_srcShapeS(srcShapeInfo[S_INDEX]);
|
||||
tiling.set_srcShapeHN(srcShapeInfo[H_INDEX]);
|
||||
tiling.set_originalShapeNLen(tiling.get_srcShapeHN() * typeSize);
|
||||
tiling.set_shapeSHValue(tiling.get_dstShapeS() * tiling.get_dstShapeH());
|
||||
tiling.set_shapeNsValue(tiling.get_dstShapeN() * tiling.get_dstShapeS());
|
||||
tiling.set_shapeNsnValue(tiling.get_dstShapeN() * tiling.get_srcShapeS() * tiling.get_srcShapeN());
|
||||
tiling.set_shapeBHValue(tiling.get_dstShapeB() * tiling.get_dstShapeH());
|
||||
}
|
||||
|
||||
} // namespace optiling
|
||||
@@ -0,0 +1,43 @@
|
||||
/**
|
||||
* This program is free software, you can redistribute it and/or modify.
|
||||
* Copyright (c) 2025 Huawei Technologies Co., Ltd.
|
||||
* This file is a part of the CANN Open Software.
|
||||
* Licensed under CANN Open Software License Agreement Version 2.0 (the "License").
|
||||
* Please refer to the License for details. You may not use this file except in compliance with the License.
|
||||
* THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE.
|
||||
* See LICENSE in the root of the software repository for the full text of the License.
|
||||
*/
|
||||
|
||||
/*!
|
||||
* \file data_copy_transpose_tiling_def.h
|
||||
* \brief
|
||||
*/
|
||||
|
||||
#pragma once
|
||||
|
||||
#include <cstdint>
|
||||
#include <register/tilingdata_base.h>
|
||||
|
||||
namespace optiling {
|
||||
|
||||
BEGIN_TILING_DATA_DEF(CopyTransposeTiling)
|
||||
TILING_DATA_FIELD_DEF(uint32_t, dstShapeB);
|
||||
TILING_DATA_FIELD_DEF(uint32_t, dstShapeN);
|
||||
TILING_DATA_FIELD_DEF(uint32_t, dstShapeS);
|
||||
TILING_DATA_FIELD_DEF(uint32_t, dstShapeHN);
|
||||
TILING_DATA_FIELD_DEF(uint32_t, dstShapeH);
|
||||
TILING_DATA_FIELD_DEF(uint32_t, srcShapeB);
|
||||
TILING_DATA_FIELD_DEF(uint32_t, srcShapeN);
|
||||
TILING_DATA_FIELD_DEF(uint32_t, srcShapeS);
|
||||
TILING_DATA_FIELD_DEF(uint32_t, srcShapeHN);
|
||||
TILING_DATA_FIELD_DEF(uint32_t, originalShapeNLen);
|
||||
TILING_DATA_FIELD_DEF(uint32_t, shapeSHValue);
|
||||
TILING_DATA_FIELD_DEF(uint32_t, shapeNsValue);
|
||||
TILING_DATA_FIELD_DEF(uint32_t, shapeNsnValue);
|
||||
TILING_DATA_FIELD_DEF(uint32_t, invalidParamCopyTransposeTiling);
|
||||
TILING_DATA_FIELD_DEF(uint32_t, shapeBHValue);
|
||||
TILING_DATA_FIELD_DEF(uint32_t, paramsAlign);
|
||||
END_TILING_DATA_DEF;
|
||||
REGISTER_TILING_DATA_CLASS(CopyTransposeTilingOp, CopyTransposeTiling)
|
||||
|
||||
} // namespace optiling
|
||||
56
csrc/moe_gating_top_k/tiling_base/error_log.h
Normal file
56
csrc/moe_gating_top_k/tiling_base/error_log.h
Normal file
@@ -0,0 +1,56 @@
|
||||
#ifndef OPS_BUILT_IN_OP_TILING_ERROR_LOG_H_
|
||||
#define OPS_BUILT_IN_OP_TILING_ERROR_LOG_H_
|
||||
|
||||
#include <string>
|
||||
#include "toolchain/slog.h"
|
||||
|
||||
#define OP_LOGI(opname, ...)
|
||||
#define OP_LOGW(opname, ...) \
|
||||
do { \
|
||||
printf("[WARN][%s] ", (opname), ##__VA_ARGS__); \
|
||||
printf("\n"); \
|
||||
} while (0)
|
||||
|
||||
#define OP_LOGE_WITHOUT_REPORT(opname, ...) \
|
||||
do { \
|
||||
printf("[ERRORx][%s] ", (opname), ##__VA_ARGS__); \
|
||||
printf("\n"); \
|
||||
} while (0)
|
||||
|
||||
#define OP_LOGE(opname, ...) \
|
||||
do { \
|
||||
printf("[ERROR][%s] ", (opname), ##__VA_ARGS__); \
|
||||
printf("\n"); \
|
||||
} while (0)
|
||||
|
||||
#define OP_LOGD(opname, ...)
|
||||
|
||||
namespace optiling {
|
||||
|
||||
#define VECTOR_INNER_ERR_REPORT_TILIING(op_name, err_msg, ...) \
|
||||
do { \
|
||||
OP_LOGE_WITHOUT_REPORT(op_name, err_msg, ##__VA_ARGS__); \
|
||||
} while (0)
|
||||
|
||||
// 修改 OP_TILING_CHECK 宏,确保正确处理表达式
|
||||
#define OP_CHECK_IF(cond, log_func, expr) \
|
||||
do { \
|
||||
if (cond) { \
|
||||
log_func; \
|
||||
expr; \
|
||||
} \
|
||||
} while (0)
|
||||
|
||||
|
||||
|
||||
#define OP_CHECK_NULL_WITH_CONTEXT(context, ptr) \
|
||||
do { \
|
||||
if ((ptr) == nullptr) { \
|
||||
OP_LOGE(context->GetNodeType(), "%s is null", #ptr); \
|
||||
return ge::GRAPH_FAILED; \
|
||||
} \
|
||||
} while (0)
|
||||
|
||||
} // namespace optiling
|
||||
|
||||
#endif // OPS_BUILT_IN_OP_TILING_ERROR_LOG_H_
|
||||
256
csrc/moe_gating_top_k/tiling_base/tiling_base.h
Normal file
256
csrc/moe_gating_top_k/tiling_base/tiling_base.h
Normal file
@@ -0,0 +1,256 @@
|
||||
/**
|
||||
* This program is free software, you can redistribute it and/or modify.
|
||||
* Copyright (c) 2025 Huawei Technologies Co., Ltd.
|
||||
* This file is a part of the CANN Open Software.
|
||||
* Licensed under CANN Open Software License Agreement Version 2.0 (the "License").
|
||||
* Please refer to the License for details. You may not use this file except in compliance with the License.
|
||||
* THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE.
|
||||
* See LICENSE in the root of the software repository for the full text of the License.
|
||||
*/
|
||||
|
||||
/*!
|
||||
* \file tiling_base.h
|
||||
* \brief
|
||||
*/
|
||||
|
||||
#pragma once
|
||||
|
||||
#include <sstream>
|
||||
#include <exe_graph/runtime/tiling_context.h>
|
||||
#include <graph/utils/type_utils.h>
|
||||
#include "tiling/platform/platform_ascendc.h"
|
||||
#include "error_log.h"
|
||||
|
||||
#ifdef ASCENDC_OP_TEST
|
||||
#define ASCENDC_EXTERN_C extern "C"
|
||||
#else
|
||||
#define ASCENDC_EXTERN_C
|
||||
#endif
|
||||
|
||||
namespace Ops {
|
||||
namespace Transformer {
|
||||
namespace OpTiling {
|
||||
|
||||
struct AiCoreParams {
|
||||
uint64_t ubSize = 0;
|
||||
uint64_t blockDim = 0;
|
||||
uint64_t aicNum = 0;
|
||||
uint64_t l1Size = 0;
|
||||
uint64_t l0aSize = 0;
|
||||
uint64_t l0bSize = 0;
|
||||
uint64_t l0cSize = 0;
|
||||
};
|
||||
|
||||
struct CompileInfoCommon {
|
||||
uint32_t aivNum;
|
||||
uint32_t aicNum;
|
||||
uint64_t ubSize;
|
||||
uint64_t l1Size;
|
||||
uint64_t l0aSize;
|
||||
uint64_t l0bSize;
|
||||
uint64_t l0cSize;
|
||||
uint64_t l2CacheSize;
|
||||
int64_t coreNum;
|
||||
int32_t socVersion;
|
||||
uint32_t rsvd;
|
||||
};
|
||||
|
||||
struct FlashAttentionScoreGradCompileInfo {
|
||||
uint32_t aivNum;
|
||||
uint32_t aicNum;
|
||||
uint64_t ubSize;
|
||||
uint64_t l1Size;
|
||||
uint64_t l0aSize;
|
||||
uint64_t l0bSize;
|
||||
uint64_t l0cSize;
|
||||
uint64_t l2CacheSize;
|
||||
int64_t coreNum;
|
||||
platform_ascendc::SocVersion socVersion;
|
||||
};
|
||||
|
||||
struct FACompileInfoCommon {
|
||||
uint32_t aivNum;
|
||||
uint32_t aicNum;
|
||||
uint64_t ubSize;
|
||||
uint64_t l1Size;
|
||||
uint64_t l0aSize;
|
||||
uint64_t l0bSize;
|
||||
uint64_t l0cSize;
|
||||
uint64_t l2CacheSize;
|
||||
int64_t coreNum;
|
||||
int32_t socVersion;
|
||||
uint32_t rsvd;
|
||||
};
|
||||
|
||||
class TilingBaseClass {
|
||||
public:
|
||||
explicit TilingBaseClass(gert::TilingContext* context) : context_(context)
|
||||
{}
|
||||
|
||||
virtual ~TilingBaseClass() = default;
|
||||
|
||||
// Tiling执行框架
|
||||
// 1、GRAPH_SUCCESS: 成功,并且不需要继续执行后续Tiling类的实现
|
||||
// 2、GRAPH_FAILED: 失败,中止整个Tiling流程
|
||||
// 3、GRAPH_PARAM_INVALID: 本类不支持,需要继续往下执行其他Tiling类的实现
|
||||
ge::graphStatus DoTiling()
|
||||
{
|
||||
auto ret = GetShapeAttrsInfo();
|
||||
if (ret != ge::GRAPH_SUCCESS) {
|
||||
return ret;
|
||||
}
|
||||
ret = GetPlatformInfo();
|
||||
if (ret != ge::GRAPH_SUCCESS) {
|
||||
return ret;
|
||||
}
|
||||
if (!IsCapable()) {
|
||||
return ge::GRAPH_PARAM_INVALID;
|
||||
}
|
||||
ret = DoOpTiling();
|
||||
if (ret != ge::GRAPH_SUCCESS) {
|
||||
return ret;
|
||||
}
|
||||
ret = DoLibApiTiling();
|
||||
if (ret != ge::GRAPH_SUCCESS) {
|
||||
return ret;
|
||||
}
|
||||
ret = GetWorkspaceSize();
|
||||
if (ret != ge::GRAPH_SUCCESS) {
|
||||
return ret;
|
||||
}
|
||||
ret = PostTiling();
|
||||
if (ret != ge::GRAPH_SUCCESS) {
|
||||
return ret;
|
||||
}
|
||||
context_->SetTilingKey(GetTilingKey());
|
||||
DumpTilingInfo();
|
||||
return ge::GRAPH_SUCCESS;
|
||||
}
|
||||
|
||||
// 更新 context
|
||||
virtual void Reset(gert::TilingContext* context)
|
||||
{
|
||||
context_ = context;
|
||||
}
|
||||
|
||||
protected:
|
||||
virtual bool IsCapable() = 0;
|
||||
// 1、获取平台信息比如CoreNum、UB/L1/L0C资源大小
|
||||
virtual ge::graphStatus GetPlatformInfo() = 0;
|
||||
// 2、获取INPUT/OUTPUT/ATTR信息
|
||||
virtual ge::graphStatus GetShapeAttrsInfo() = 0;
|
||||
// 3、计算数据切分TilingData
|
||||
virtual ge::graphStatus DoOpTiling() = 0;
|
||||
// 4、计算高阶API的TilingData
|
||||
virtual ge::graphStatus DoLibApiTiling() = 0;
|
||||
// 5、计算TilingKey
|
||||
[[nodiscard]] virtual uint64_t GetTilingKey() const = 0;
|
||||
// 6、计算Workspace 大小
|
||||
virtual ge::graphStatus GetWorkspaceSize() = 0;
|
||||
// 7、保存Tiling数据
|
||||
virtual ge::graphStatus PostTiling() = 0;
|
||||
// 8、Dump Tiling数据
|
||||
virtual void DumpTilingInfo()
|
||||
{
|
||||
int32_t enable = CheckLogLevel(static_cast<int32_t>(OP), DLOG_DEBUG);
|
||||
if (enable != 1) {
|
||||
return;
|
||||
}
|
||||
auto buf = (uint32_t*)context_->GetRawTilingData()->GetData();
|
||||
auto bufLen = context_->GetRawTilingData()->GetDataSize();
|
||||
std::ostringstream oss;
|
||||
oss << "Start to dump tiling info. tilingkey:" << context_->GetTilingKey() << ", tiling data size:" << bufLen
|
||||
<< ", content:";
|
||||
for (size_t i = 0; i < bufLen / sizeof(uint32_t); i++) {
|
||||
oss << *(buf + i) << ",";
|
||||
if (oss.str().length() > 640) { // Split according to 640 to avoid truncation
|
||||
OP_LOGD(context_, "%s", oss.str().c_str());
|
||||
oss.str("");
|
||||
}
|
||||
}
|
||||
OP_LOGD(context_, "%s", oss.str().c_str());
|
||||
}
|
||||
|
||||
static uint32_t CalcTschBlockDim(uint32_t sliceNum, uint32_t aicCoreNum, uint32_t aivCoreNum)
|
||||
{
|
||||
uint32_t ration;
|
||||
if (aicCoreNum == 0 || aivCoreNum == 0 || aicCoreNum > aivCoreNum) {
|
||||
return sliceNum;
|
||||
}
|
||||
ration = aivCoreNum / aicCoreNum;
|
||||
return (sliceNum + (ration - 1)) / ration;
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
[[nodiscard]] std::string GetShapeDebugStr(const T& shape) const
|
||||
{
|
||||
std::ostringstream oss;
|
||||
oss << "[";
|
||||
if (shape.GetDimNum() > 0) {
|
||||
for (size_t i = 0; i < shape.GetDimNum() - 1; ++i) {
|
||||
oss << shape.GetDim(i) << ", ";
|
||||
}
|
||||
oss << shape.GetDim(shape.GetDimNum() - 1);
|
||||
}
|
||||
oss << "]";
|
||||
return oss.str();
|
||||
}
|
||||
|
||||
[[nodiscard]] std::string GetTensorDebugStr(
|
||||
const gert::StorageShape* shape, const gert::CompileTimeTensorDesc* tensor)
|
||||
{
|
||||
if (shape == nullptr || tensor == nullptr) {
|
||||
return "nil ";
|
||||
}
|
||||
std::ostringstream oss;
|
||||
oss << "(dtype: " << ge::TypeUtils::DataTypeToSerialString(tensor->GetDataType()) << "),";
|
||||
oss << "(shape:" << GetShapeDebugStr(shape->GetStorageShape()) << "),";
|
||||
oss << "(ori_shape:" << GetShapeDebugStr(shape->GetOriginShape()) << "),";
|
||||
oss << "(format: "
|
||||
<< ge::TypeUtils::FormatToSerialString(
|
||||
static_cast<ge::Format>(ge::GetPrimaryFormat(tensor->GetStorageFormat())))
|
||||
<< "),";
|
||||
oss << "(ori_format: " << ge::TypeUtils::FormatToSerialString(tensor->GetOriginFormat()) << ") ";
|
||||
return oss.str();
|
||||
}
|
||||
|
||||
[[nodiscard]] std::string GetTilingContextDebugStr()
|
||||
{
|
||||
std::ostringstream oss;
|
||||
for (size_t i = 0; i < context_->GetComputeNodeInfo()->GetInputsNum(); ++i) {
|
||||
oss << "input" << i << ": ";
|
||||
oss << GetTensorDebugStr(context_->GetInputShape(i), context_->GetInputDesc(i));
|
||||
}
|
||||
|
||||
for (size_t i = 0; i < context_->GetComputeNodeInfo()->GetOutputsNum(); ++i) {
|
||||
oss << "output" << i << ": ";
|
||||
oss << GetTensorDebugStr(context_->GetOutputShape(i), context_->GetOutputDesc(i));
|
||||
}
|
||||
return oss.str();
|
||||
}
|
||||
|
||||
[[nodiscard]] std::string GetTilingDataDebugStr() const
|
||||
{
|
||||
auto rawTilingData = context_->GetRawTilingData();
|
||||
auto rawTilingDataSize = rawTilingData->GetDataSize();
|
||||
auto data = reinterpret_cast<const int32_t*>(rawTilingData->GetData());
|
||||
size_t len = rawTilingDataSize / sizeof(int32_t);
|
||||
std::ostringstream oss;
|
||||
for (size_t i = 0; i < len; i++) {
|
||||
oss << data[i] << ", ";
|
||||
}
|
||||
return oss.str();
|
||||
}
|
||||
|
||||
protected:
|
||||
gert::TilingContext* context_ = nullptr;
|
||||
std::unique_ptr<platform_ascendc::PlatformAscendC> ascendcPlatform_{nullptr};
|
||||
uint32_t blockDim_{0};
|
||||
uint64_t workspaceSize_{0};
|
||||
uint64_t tilingKey_{0};
|
||||
AiCoreParams aicoreParams_;
|
||||
};
|
||||
|
||||
} // namespace OpTiling
|
||||
} // namespace Transformer
|
||||
} // namespace Ops
|
||||
63
csrc/moe_gating_top_k/tiling_base/tiling_key.h
Normal file
63
csrc/moe_gating_top_k/tiling_base/tiling_key.h
Normal file
@@ -0,0 +1,63 @@
|
||||
/**
|
||||
* This program is free software, you can redistribute it and/or modify.
|
||||
* Copyright (c) 2025 Huawei Technologies Co., Ltd.
|
||||
* This file is a part of the CANN Open Software.
|
||||
* Licensed under CANN Open Software License Agreement Version 2.0 (the "License").
|
||||
* Please refer to the License for details. You may not use this file except in compliance with the License.
|
||||
* THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE.
|
||||
* See LICENSE in the root of the software repository for the full text of the License.
|
||||
*/
|
||||
|
||||
/*!
|
||||
* \file tiling_key.h
|
||||
* \brief
|
||||
*/
|
||||
|
||||
#pragma once
|
||||
|
||||
#include <cstdint>
|
||||
|
||||
namespace Ops {
|
||||
namespace Transformer {
|
||||
namespace OpTiling {
|
||||
constexpr uint64_t RecursiveSum()
|
||||
{
|
||||
return 0;
|
||||
}
|
||||
|
||||
constexpr uint64_t kBase = 10; // 10进制进位基数
|
||||
template <typename T, typename... Args> constexpr uint64_t RecursiveSum(T templateId, Args... templateIds)
|
||||
{
|
||||
return static_cast<uint64_t>(templateId) + kBase * RecursiveSum(templateIds...);
|
||||
}
|
||||
|
||||
// TilingKey 的生成规则:
|
||||
// FlashAttentionScore/FlashAttentionScoreGrad 十进制位组装tiling key,包含以下关键参数,从低位到高位依次是:Ub0, Ub1,
|
||||
// Block, DataType, Format, Sparse, 特化模板 Ub0、Ub1:
|
||||
// 表示Ub核内切分的轴,使用枚举AxisEnum表示,因为我们允许最多切分两根轴,所以存在UB0和UB1,如果没有UB核内切分,
|
||||
// 那么填AXIS_NONE。UB0和UB1各占一个十进制位;
|
||||
// Block: 表示UB用来分核的轴,使用枚举AxisEnum表示,占一个十进制位;
|
||||
// DataType: 表示当前tiling key支持的输入输出的数据类型,使用枚举SupportedDtype来表示,占一个十进制位
|
||||
// Format: 表示当前tiling key支持的Format, 使用枚举InputLayout表示,占一个十进制位
|
||||
// Sparse: 表示当前tiling key是否支持Sparse,使用枚举SparseCapability表示,占一个十进制位
|
||||
// 其余特化场景,定义自己的位域和值
|
||||
// usage: get tilingKey from inputed types
|
||||
// uint64_t tilingKey = GET_FLASHATTENTION_TILINGKEY(AxisEnum::AXIS_S1, AxisEnum::AXIS_S2, AxisEnum::AXIS_N2,
|
||||
// SupportedDtype::FLOAT32, InputLayout::BSH, SparseCapability::SUPPORT_ALL)
|
||||
|
||||
constexpr uint64_t TILINGKEYOFFSET = uint64_t(10000000000000000000UL); // 10^19
|
||||
template <typename... Args> constexpr uint64_t GET_TILINGKEY(Args... templateIds)
|
||||
{
|
||||
return TILINGKEYOFFSET + RecursiveSum(templateIds...);
|
||||
}
|
||||
|
||||
// usage: get tilingKey from inputed types
|
||||
// uint64_t tilingKey = TILINGKEY(S2, S1, N2, FLOAT32, BSND, ALL)
|
||||
|
||||
#define TILINGKEY(ub2, ub1, block, dtype, layout, sparse) \
|
||||
(GET_TILINGKEY(AxisEnum::ub2, AxisEnum::ub1, AxisEnum::block, DtypeEnum::dtype, LayoutEnum::layout, \
|
||||
SparseEnum::sparse))
|
||||
|
||||
} // namespace Optiling
|
||||
} // namespace Transformer
|
||||
} // namespace Ops
|
||||
351
csrc/moe_gating_top_k/tiling_base/tiling_templates_registry.h
Normal file
351
csrc/moe_gating_top_k/tiling_base/tiling_templates_registry.h
Normal file
@@ -0,0 +1,351 @@
|
||||
/**
|
||||
* This program is free software, you can redistribute it and/or modify.
|
||||
* Copyright (c) 2025 Huawei Technologies Co., Ltd.
|
||||
* This file is a part of the CANN Open Software.
|
||||
* Licensed under CANN Open Software License Agreement Version 2.0 (the "License").
|
||||
* Please refer to the License for details. You may not use this file except in compliance with the License.
|
||||
* THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE.
|
||||
* See LICENSE in the root of the software repository for the full text of the License.
|
||||
*/
|
||||
|
||||
/*!
|
||||
* \file tiling_templates_registry.h
|
||||
* \brief
|
||||
*/
|
||||
|
||||
#pragma once
|
||||
|
||||
#include <map>
|
||||
#include <string>
|
||||
#include <memory>
|
||||
#include "exe_graph/runtime/tiling_context.h"
|
||||
#include "tiling_base.h"
|
||||
#include "error_log.h"
|
||||
|
||||
namespace Ops {
|
||||
namespace Transformer {
|
||||
namespace OpTiling {
|
||||
|
||||
template <typename T>
|
||||
std::unique_ptr<TilingBaseClass> TILING_CLASS(gert::TilingContext* context)
|
||||
{
|
||||
return std::unique_ptr<T>(new (std::nothrow) T(context));
|
||||
}
|
||||
|
||||
using TilingClassCase = std::unique_ptr<TilingBaseClass> (*)(gert::TilingContext*);
|
||||
|
||||
class TilingCases {
|
||||
public:
|
||||
explicit TilingCases(std::string op_type) : op_type_(std::move(op_type))
|
||||
{}
|
||||
|
||||
template <typename T>
|
||||
void AddTiling(int32_t priority)
|
||||
{
|
||||
OP_CHECK_IF(
|
||||
cases_.find(priority) != cases_.end(), OP_LOGE(op_type_, "There are duplicate registrations."), return);
|
||||
cases_[priority] = TILING_CLASS<T>;
|
||||
OP_CHECK_IF(
|
||||
cases_[priority] == nullptr,
|
||||
OP_LOGE(op_type_, "Register op tiling func failed, please check the class name."), return);
|
||||
}
|
||||
|
||||
const std::map<int32_t, TilingClassCase>& GetTilingCases()
|
||||
{
|
||||
return cases_;
|
||||
}
|
||||
|
||||
private:
|
||||
std::map<int32_t, TilingClassCase> cases_;
|
||||
const std::string op_type_;
|
||||
};
|
||||
|
||||
// --------------------------------Interfacce with soc version --------------------------------
|
||||
class TilingRegistryNew {
|
||||
public:
|
||||
TilingRegistryNew() = default;
|
||||
|
||||
#ifdef ASCENDC_OP_TEST
|
||||
static TilingRegistryNew& GetInstance();
|
||||
#else
|
||||
static TilingRegistryNew& GetInstance()
|
||||
{
|
||||
static TilingRegistryNew registry_impl_;
|
||||
return registry_impl_;
|
||||
}
|
||||
#endif
|
||||
|
||||
std::shared_ptr<TilingCases> RegisterOp(const std::string& op_type, int32_t soc_version)
|
||||
{
|
||||
auto soc_iter = registry_map_.find(soc_version);
|
||||
if (soc_iter == registry_map_.end()) {
|
||||
std::map<std::string, std::shared_ptr<TilingCases>> op_type_map;
|
||||
op_type_map[op_type] = std::shared_ptr<TilingCases>(new (std::nothrow) TilingCases(op_type));
|
||||
registry_map_[soc_version] = op_type_map;
|
||||
} else {
|
||||
if (soc_iter->second.find(op_type) == soc_iter->second.end()) {
|
||||
soc_iter->second[op_type] = std::shared_ptr<TilingCases>(new (std::nothrow) TilingCases(op_type));
|
||||
}
|
||||
}
|
||||
|
||||
OP_CHECK_IF(
|
||||
registry_map_[soc_version][op_type] == nullptr,
|
||||
OP_LOGE(op_type, "Register tiling func failed, please check the class name."), return nullptr);
|
||||
return registry_map_[soc_version][op_type];
|
||||
}
|
||||
|
||||
ge::graphStatus DoTilingImpl(gert::TilingContext* context)
|
||||
{
|
||||
int32_t soc_version = (int32_t)platform_ascendc::SocVersion::RESERVED_VERSION;
|
||||
const char* op_type = context->GetNodeType();
|
||||
fe::PlatFormInfos* platformInfoPtr = context->GetPlatformInfo();
|
||||
if (platformInfoPtr == nullptr) {
|
||||
auto compileInfoPtr = static_cast<const CompileInfoCommon*>(context->GetCompileInfo());
|
||||
OP_CHECK_IF(
|
||||
compileInfoPtr == nullptr, OP_LOGE(op_type, "compileInfoPtr is null."), return ge::GRAPH_FAILED);
|
||||
soc_version = compileInfoPtr->socVersion;
|
||||
OP_LOGD(context, "soc version in compileInfo is %d", soc_version);
|
||||
} else {
|
||||
auto ascendcPlatform = platform_ascendc::PlatformAscendC(platformInfoPtr);
|
||||
soc_version = static_cast<int32_t>(ascendcPlatform.GetSocVersion());
|
||||
OP_LOGD(context, "soc version is %d", soc_version);
|
||||
if (soc_version == (int32_t)platform_ascendc::SocVersion::RESERVED_VERSION) {
|
||||
OP_LOGE(op_type, "Do op tiling failed, cannot find soc version.");
|
||||
return ge::GRAPH_FAILED;
|
||||
}
|
||||
}
|
||||
auto tilingTemplateRegistryMap = GetTilingTemplates(op_type, soc_version);
|
||||
for (auto it = tilingTemplateRegistryMap.begin(); it != tilingTemplateRegistryMap.end(); ++it) {
|
||||
auto tilingTemplate = it->second(context);
|
||||
if (tilingTemplate != nullptr) {
|
||||
ge::graphStatus status = tilingTemplate->DoTiling();
|
||||
if (status != ge::GRAPH_PARAM_INVALID) {
|
||||
OP_LOGD(context, "Do general op tiling success priority=%d", it->first);
|
||||
return status;
|
||||
}
|
||||
OP_LOGD(context, "Ignore general op tiling priority=%d", it->first);
|
||||
}
|
||||
}
|
||||
OP_LOGE(op_type, "Do op tiling failed, no valid template is found.");
|
||||
return ge::GRAPH_FAILED;
|
||||
}
|
||||
|
||||
ge::graphStatus DoTilingImpl(gert::TilingContext* context, const std::vector<int32_t>& priorities)
|
||||
{
|
||||
int32_t soc_version;
|
||||
const char* op_type = context->GetNodeType();
|
||||
auto platformInfoPtr = context->GetPlatformInfo();
|
||||
if (platformInfoPtr == nullptr) {
|
||||
auto compileInfoPtr = reinterpret_cast<const CompileInfoCommon*>(context->GetCompileInfo());
|
||||
OP_CHECK_IF(
|
||||
compileInfoPtr == nullptr, OP_LOGE(op_type, "compileInfoPtr is null."), return ge::GRAPH_FAILED);
|
||||
soc_version = compileInfoPtr->socVersion;
|
||||
OP_LOGD(context, "soc version in compileInfo is %d", soc_version);
|
||||
} else {
|
||||
auto ascendcPlatform = platform_ascendc::PlatformAscendC(platformInfoPtr);
|
||||
soc_version = static_cast<int32_t>(ascendcPlatform.GetSocVersion());
|
||||
OP_LOGD(context, "soc version is %d", soc_version);
|
||||
}
|
||||
|
||||
auto tilingTemplateRegistryMap = GetTilingTemplates(op_type, soc_version);
|
||||
for (auto priority_id : priorities) {
|
||||
auto tilingCaseIter = tilingTemplateRegistryMap.find(priority_id);
|
||||
if (tilingCaseIter != tilingTemplateRegistryMap.end()) {
|
||||
auto templateFunc = tilingCaseIter->second(context);
|
||||
if (templateFunc != nullptr) {
|
||||
ge::graphStatus status = templateFunc->DoTiling();
|
||||
if (status == ge::GRAPH_SUCCESS) {
|
||||
OP_LOGD(context, "Do general op tiling success priority=%d", priority_id);
|
||||
return status;
|
||||
}
|
||||
OP_LOGD(context, "Ignore general op tiling priority=%d", priority_id);
|
||||
}
|
||||
}
|
||||
}
|
||||
return ge::GRAPH_FAILED;
|
||||
}
|
||||
|
||||
const std::map<int32_t, TilingClassCase>& GetTilingTemplates(const std::string& op_type, int32_t soc_version)
|
||||
{
|
||||
auto soc_iter = registry_map_.find(soc_version);
|
||||
OP_CHECK_IF(
|
||||
soc_iter == registry_map_.end(),
|
||||
OP_LOGE(op_type, "Get op tiling func failed, please check the soc version %d", soc_version),
|
||||
return empty_tiling_case_);
|
||||
auto op_iter = soc_iter->second.find(op_type);
|
||||
OP_CHECK_IF(
|
||||
op_iter == soc_iter->second.end(), OP_LOGE(op_type, "Get op tiling func failed, please check the op name."),
|
||||
return empty_tiling_case_);
|
||||
return op_iter->second->GetTilingCases();
|
||||
}
|
||||
|
||||
private:
|
||||
std::map<int32_t, std::map<std::string, std::shared_ptr<TilingCases>>> registry_map_; // key is socversion
|
||||
const std::map<int32_t, TilingClassCase> empty_tiling_case_{};
|
||||
};
|
||||
|
||||
class RegisterNew {
|
||||
public:
|
||||
explicit RegisterNew(std::string op_type) : op_type_(std::move(op_type))
|
||||
{}
|
||||
|
||||
template <typename T>
|
||||
RegisterNew& tiling(int32_t priority, int32_t soc_version)
|
||||
{
|
||||
auto tilingCases = TilingRegistryNew::GetInstance().RegisterOp(op_type_, soc_version);
|
||||
OP_CHECK_IF(
|
||||
tilingCases == nullptr, OP_LOGE(op_type_, "Register op tiling failed, please the op name."), return *this);
|
||||
tilingCases->AddTiling<T>(priority);
|
||||
return *this;
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
RegisterNew& tiling(int32_t priority, const std::vector<int32_t>& soc_versions)
|
||||
{
|
||||
for (int32_t soc_version : soc_versions) {
|
||||
auto tilingCases = TilingRegistryNew::GetInstance().RegisterOp(op_type_, soc_version);
|
||||
OP_CHECK_IF(
|
||||
tilingCases == nullptr, OP_LOGE(op_type_, "Register op tiling failed, please the op name."),
|
||||
return *this);
|
||||
tilingCases->AddTiling<T>(priority);
|
||||
}
|
||||
return *this;
|
||||
}
|
||||
|
||||
private:
|
||||
const std::string op_type_;
|
||||
};
|
||||
|
||||
// --------------------------------Interfacce without soc version --------------------------------
|
||||
class TilingRegistry {
|
||||
public:
|
||||
TilingRegistry() = default;
|
||||
|
||||
#ifdef ASCENDC_OP_TEST
|
||||
static TilingRegistry& GetInstance();
|
||||
#else
|
||||
static TilingRegistry& GetInstance()
|
||||
{
|
||||
static TilingRegistry registry_impl_;
|
||||
return registry_impl_;
|
||||
}
|
||||
#endif
|
||||
|
||||
std::shared_ptr<TilingCases> RegisterOp(const std::string& op_type)
|
||||
{
|
||||
if (registry_map_.find(op_type) == registry_map_.end()) {
|
||||
registry_map_[op_type] = std::shared_ptr<TilingCases>(new (std::nothrow) TilingCases(op_type));
|
||||
}
|
||||
OP_CHECK_IF(
|
||||
registry_map_[op_type] == nullptr,
|
||||
OP_LOGE(op_type, "Register tiling func failed, please check the class name."), return nullptr);
|
||||
return registry_map_[op_type];
|
||||
}
|
||||
|
||||
ge::graphStatus DoTilingImpl(gert::TilingContext* context)
|
||||
{
|
||||
const char* op_type = context->GetNodeType();
|
||||
auto tilingTemplateRegistryMap = GetTilingTemplates(op_type);
|
||||
for (auto it = tilingTemplateRegistryMap.begin(); it != tilingTemplateRegistryMap.end(); ++it) {
|
||||
auto tilingTemplate = it->second(context);
|
||||
if (tilingTemplate != nullptr) {
|
||||
ge::graphStatus status = tilingTemplate->DoTiling();
|
||||
if (status != ge::GRAPH_PARAM_INVALID) {
|
||||
OP_LOGD(context, "Do general op tiling success priority=%d", it->first);
|
||||
return status;
|
||||
}
|
||||
OP_LOGD(context, "Ignore general op tiling priority=%d", it->first);
|
||||
}
|
||||
}
|
||||
OP_LOGE(op_type, "Do op tiling failed, no valid template is found.");
|
||||
return ge::GRAPH_FAILED;
|
||||
}
|
||||
|
||||
ge::graphStatus DoTilingImpl(gert::TilingContext* context, const std::vector<int32_t>& priorities)
|
||||
{
|
||||
const char* op_type = context->GetNodeType();
|
||||
auto tilingTemplateRegistryMap = GetTilingTemplates(op_type);
|
||||
for (auto priorityId : priorities) {
|
||||
auto templateFunc = tilingTemplateRegistryMap[priorityId](context);
|
||||
if (templateFunc != nullptr) {
|
||||
ge::graphStatus status = templateFunc->DoTiling();
|
||||
if (status == ge::GRAPH_SUCCESS) {
|
||||
OP_LOGD(context, "Do general op tiling success priority=%d", priorityId);
|
||||
return status;
|
||||
}
|
||||
if (status != ge::GRAPH_PARAM_INVALID) {
|
||||
OP_LOGD(context, "Do op tiling failed");
|
||||
return status;
|
||||
}
|
||||
OP_LOGD(context, "Ignore general op tiling priority=%d", priorityId);
|
||||
}
|
||||
}
|
||||
OP_LOGE(op_type, "Do op tiling failed, no valid template is found.");
|
||||
return ge::GRAPH_FAILED;
|
||||
}
|
||||
|
||||
const std::map<int32_t, TilingClassCase>& GetTilingTemplates(const std::string& op_type)
|
||||
{
|
||||
OP_CHECK_IF(
|
||||
registry_map_.find(op_type) == registry_map_.end(),
|
||||
OP_LOGE(op_type, "Get op tiling func failed, please check the op name."), return empty_tiling_case_);
|
||||
return registry_map_[op_type]->GetTilingCases();
|
||||
}
|
||||
|
||||
private:
|
||||
std::map<std::string, std::shared_ptr<TilingCases>> registry_map_;
|
||||
const std::map<int32_t, TilingClassCase> empty_tiling_case_;
|
||||
};
|
||||
|
||||
class Register {
|
||||
public:
|
||||
explicit Register(std::string op_type) : op_type_(std::move(op_type))
|
||||
{}
|
||||
|
||||
template <typename T>
|
||||
Register& tiling(int32_t priority)
|
||||
{
|
||||
auto tilingCases = TilingRegistry::GetInstance().RegisterOp(op_type_);
|
||||
OP_CHECK_IF(
|
||||
tilingCases == nullptr, OP_LOGE(op_type_, "Register op tiling failed, please the op name."), return *this);
|
||||
tilingCases->AddTiling<T>(priority);
|
||||
return *this;
|
||||
}
|
||||
|
||||
private:
|
||||
const std::string op_type_;
|
||||
};
|
||||
} // namespace OpTiling
|
||||
} // namespace Transformer
|
||||
} // namespace Ops
|
||||
|
||||
// op_type: 算子名称, class_name: 注册的 tiling 类, soc_version:芯片版本号
|
||||
// priority: tiling 类的优先级, 越小表示优先级越高, 即会优先选择这个tiling类
|
||||
#define REGISTER_TILING_TEMPLATE_WITH_SOCVERSION(op_type, class_name, soc_versions, priority) \
|
||||
[[maybe_unused]] uint32_t op_impl_register_template_##op_type##_##class_name##priority; \
|
||||
static Ops::Transformer::OpTiling::RegisterNew VAR_UNUSED##op_type##class_name##priority_register = \
|
||||
Ops::Transformer::OpTiling::RegisterNew(#op_type).tiling<class_name>(priority, soc_versions)
|
||||
|
||||
// op_type: 算子名称, class_name: 注册的 tiling 类,
|
||||
// priority: tiling 类的优先级, 越小表示优先级越高, 即被选中的概率越大
|
||||
#define REGISTER_TILING_TEMPLATE(op_type, class_name, priority) \
|
||||
[[maybe_unused]] uint32_t op_impl_register_template_##op_type##_##class_name##priority; \
|
||||
static Ops::Transformer::OpTiling::Register VAR_UNUSED##op_type_##class_name##priority_register = \
|
||||
Ops::Transformer::OpTiling::Register(op_type).tiling<class_name>(priority)
|
||||
|
||||
// op_type: 算子名称, class_name: 注册的 tiling 类,
|
||||
// soc_version: soc版本,用于区分不同的soc
|
||||
// priority: tiling 类的优先级, 越小表示优先级越高, 即会优先选择这个tiling类
|
||||
#define REGISTER_TILING_TEMPLATE_NEW(op_type, class_name, soc_version, priority) \
|
||||
[[maybe_unused]] uint32_t op_impl_register_template_##op_type##_##class_name##priority; \
|
||||
static Ops::Transformer::OpTiling::RegisterNew VAR_UNUSED##op_type##class_name##priority_register = \
|
||||
Ops::Transformer::OpTiling::RegisterNew(#op_type).tiling<class_name>(priority, soc_version)
|
||||
|
||||
// op_type: 算子名称, class_name: 注册的 tiling 类,
|
||||
// priority: tiling 类的优先级, 越小表示优先级越高, 即被选中的概率越大
|
||||
// 取代 REGISTER_TILING_TEMPLATE , 传入的op_type如果是字符串常量,需要去掉引号
|
||||
#define REGISTER_OPS_TILING_TEMPLATE(op_type, class_name, priority) \
|
||||
[[maybe_unused]] uint32_t op_impl_register_template_##op_type##_##class_name##priority; \
|
||||
static Ops::Transformer::OpTiling::Register \
|
||||
__attribute__((unused)) tiling_##op_type##_##class_name##_##priority##_register = \
|
||||
Ops::Transformer::OpTiling::Register(#op_type).tiling<class_name>(priority)
|
||||
139
csrc/moe_gating_top_k/tiling_base/tiling_type.h
Normal file
139
csrc/moe_gating_top_k/tiling_base/tiling_type.h
Normal file
@@ -0,0 +1,139 @@
|
||||
/**
|
||||
* This program is free software, you can redistribute it and/or modify.
|
||||
* Copyright (c) 2025 Huawei Technologies Co., Ltd.
|
||||
* This file is a part of the CANN Open Software.
|
||||
* Licensed under CANN Open Software License Agreement Version 2.0 (the "License").
|
||||
* Please refer to the License for details. You may not use this file except in compliance with the License.
|
||||
* THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE.
|
||||
* See LICENSE in the root of the software repository for the full text of the License.
|
||||
*/
|
||||
|
||||
/*!
|
||||
* \file tiling_type.h
|
||||
* \brief
|
||||
*/
|
||||
|
||||
#pragma once
|
||||
|
||||
#include <cstdint>
|
||||
|
||||
namespace optiling {
|
||||
|
||||
enum class AxisEnum {
|
||||
B = 0,
|
||||
N2 = 1,
|
||||
G = 2,
|
||||
S1 = 3,
|
||||
S2 = 4,
|
||||
D = 5,
|
||||
NONE = 9,
|
||||
};
|
||||
|
||||
enum class DtypeEnum {
|
||||
FLOAT16 = 0,
|
||||
FLOAT32 = 1,
|
||||
BFLOAT16 = 2,
|
||||
FLOAT16_PRECISION = 3,
|
||||
};
|
||||
|
||||
enum class PerformanceOrientedEnum {
|
||||
BIG_BUFFER = 1,
|
||||
BIG_DOUBLE_BUFFER = 2,
|
||||
};
|
||||
|
||||
enum class MatmulConfig {
|
||||
NULL_CONFIG = 0,
|
||||
NORMAL_CONFIG = 1,
|
||||
MDL_CONFIG = 2
|
||||
};
|
||||
|
||||
enum class PseConfig {
|
||||
NO_PSE = 0,
|
||||
EXIST_PSE = 1
|
||||
};
|
||||
|
||||
enum class AttenMaskConfig {
|
||||
NO_ATTEN_MASK = 0,
|
||||
EXIST_ATTEN_MASK = 1
|
||||
};
|
||||
|
||||
enum class DropOutConfig {
|
||||
NO_DROP_OUT = 0,
|
||||
EXIST_DROP_OUT = 1
|
||||
};
|
||||
|
||||
enum class CubeFormatEnum {
|
||||
ND = 0,
|
||||
NZ = 1
|
||||
};
|
||||
enum class LayoutEnum {
|
||||
BSND = 0,
|
||||
SBND = 1,
|
||||
BNSD = 2,
|
||||
TND = 3,
|
||||
NTD_TND = 4
|
||||
};
|
||||
|
||||
enum class CubeInputSourceEnum {
|
||||
GM = 0,
|
||||
L1 = 1
|
||||
};
|
||||
|
||||
enum class OptionEnum {
|
||||
DISABLE = 0,
|
||||
ENABLE = 1
|
||||
};
|
||||
|
||||
enum class SparseEnum {
|
||||
ALL = 0,
|
||||
NONE = 1,
|
||||
ANY = 2,
|
||||
CAUSAL = 3,
|
||||
BAND = 4,
|
||||
PREFIX = 5,
|
||||
BAND_COMPRESS = 6,
|
||||
RIGHT_DOWN_CAUSAL = 7,
|
||||
RIGHT_DOWN_CAUSAL_BAND = 8,
|
||||
BAND_LEFT_UP_CAUSAL = 9
|
||||
};
|
||||
|
||||
constexpr uint64_t RecursiveSum()
|
||||
{
|
||||
return 0;
|
||||
}
|
||||
|
||||
constexpr int64_t base10Multiplier = 10;
|
||||
|
||||
template <typename T, typename... Args> constexpr uint64_t RecursiveSum(T templateId, Args... templateIds)
|
||||
{
|
||||
return static_cast<uint64_t>(templateId) + base10Multiplier * RecursiveSum(templateIds...);
|
||||
}
|
||||
|
||||
// TilingKey 的生成规则:
|
||||
// FlashAttentionScore/FlashAttentionScoreGrad 十进制位组装tiling key,包含以下关键参数,从低位到高位依次是:Ub0, Ub1,
|
||||
// Block, DataType, Format, Sparse, 特化模板 Ub0、Ub1:
|
||||
// 表示Ub核内切分的轴,使用枚举AxisEnum表示,因为我们允许最多切分两根轴,所以存在UB0和UB1,如果没有UB核内切分,
|
||||
// 那么填AXIS_NONE。UB0和UB1各占一个十进制位;
|
||||
// Block: 表示UB用来分核的轴,使用枚举AxisEnum表示,占一个十进制位;
|
||||
// DataType: 表示当前tiling key支持的输入输出的数据类型,使用枚举SupportedDtype来表示,占一个十进制位
|
||||
// Format: 表示当前tiling key支持的Format, 使用枚举InputLayout表示,占一个十进制位
|
||||
// Sparse: 表示当前tiling key是否支持Sparse,使用枚举SparseCapability表示,占一个十进制位
|
||||
// 其余特化场景,定义自己的位域和值
|
||||
// usage: get tilingKey from inputed types
|
||||
// uint64_t tilingKey = GET_FLASHATTENTION_TILINGKEY(AxisEnum::AXIS_S1, AxisEnum::AXIS_S2, AxisEnum::AXIS_N2,
|
||||
// SupportedDtype::FLOAT32, InputLayout::BSH, SparseCapability::SUPPORT_ALL)
|
||||
|
||||
constexpr uint64_t TILINGKEYOFFSET = uint64_t(10000000000000000000UL); // 10^19
|
||||
template <typename... Args> constexpr uint64_t GET_TILINGKEY(Args... templateIds)
|
||||
{
|
||||
return TILINGKEYOFFSET + RecursiveSum(templateIds...);
|
||||
}
|
||||
|
||||
// usage: get tilingKey from inputed types
|
||||
// uint64_t tilingKey = TILINGKEY(S2, S1, N2, FLOAT32, BSND, ALL)
|
||||
|
||||
#define TILINGKEY(ub2, ub1, block, dtype, layout, sparse) \
|
||||
(GET_TILINGKEY(AxisEnum::ub2, AxisEnum::ub1, AxisEnum::block, DtypeEnum::dtype, LayoutEnum::layout, \
|
||||
SparseEnum::sparse))
|
||||
|
||||
} // namespace optiling
|
||||
30
csrc/moe_gating_top_k/tiling_base/tiling_util.h
Normal file
30
csrc/moe_gating_top_k/tiling_base/tiling_util.h
Normal file
@@ -0,0 +1,30 @@
|
||||
/**
|
||||
* This program is free software, you can redistribute it and/or modify.
|
||||
* Copyright (c) 2025 Huawei Technologies Co., Ltd.
|
||||
* This file is a part of the CANN Open Software.
|
||||
* Licensed under CANN Open Software License Agreement Version 2.0 (the "License").
|
||||
* Please refer to the License for details. You may not use this file except in compliance with the License.
|
||||
* THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE.
|
||||
* See LICENSE in the root of the software repository for the full text of the License.
|
||||
*/
|
||||
|
||||
/*!
|
||||
* \file tiling_util.h
|
||||
* \brief
|
||||
*/
|
||||
|
||||
#pragma once
|
||||
|
||||
#include "register/op_impl_registry.h"
|
||||
|
||||
namespace Ops {
|
||||
namespace Transformer {
|
||||
namespace OpTiling {
|
||||
bool IsRegbaseSocVersion(const gert::TilingParseContext* context);
|
||||
|
||||
bool IsRegbaseSocVersion(const gert::TilingContext* context);
|
||||
|
||||
const gert::Shape& EnsureNotScalar(const gert::Shape& inShape);
|
||||
} // namespace OpTiling
|
||||
} // namespace Transformer
|
||||
} // namespace Ops
|
||||
Reference in New Issue
Block a user