/**  * This program is free software, you can redistribute it and/or modify.  * Copyright (c) 2025 Huawei Technologies Co., Ltd. * This file is a part of the CANN Open Software. * Licensed under CANN Open Software License Agreement Version 2.0 (the "License"). * Please refer to the License for details. You may not use this file except in compliance with the License. * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. * See LICENSE in the root of the software repository for the full text of the License. */ /*! * \file tiling_templates_registry.h * \brief */ #pragma once #include #include #include #include "exe_graph/runtime/tiling_context.h" #include "tiling_base.h" #include "error_log.h" namespace Ops { namespace Transformer { namespace OpTiling { template std::unique_ptr TILING_CLASS(gert::TilingContext* context) { return std::unique_ptr(new (std::nothrow) T(context)); } using TilingClassCase = std::unique_ptr (*)(gert::TilingContext*); class TilingCases { public: explicit TilingCases(std::string op_type) : op_type_(std::move(op_type)) {} template void AddTiling(int32_t priority) { OP_CHECK_IF( cases_.find(priority) != cases_.end(), OP_LOGE(op_type_, "There are duplicate registrations."), return); cases_[priority] = TILING_CLASS; OP_CHECK_IF( cases_[priority] == nullptr, OP_LOGE(op_type_, "Register op tiling func failed, please check the class name."), return); } const std::map& GetTilingCases() { return cases_; } private: std::map cases_; const std::string op_type_; }; // --------------------------------Interfacce with soc version -------------------------------- class TilingRegistryNew { public: TilingRegistryNew() = default; #ifdef ASCENDC_OP_TEST static TilingRegistryNew& GetInstance(); #else static TilingRegistryNew& GetInstance() { static TilingRegistryNew registry_impl_; return registry_impl_; } #endif std::shared_ptr RegisterOp(const std::string& op_type, int32_t soc_version) { auto soc_iter = registry_map_.find(soc_version); if (soc_iter == registry_map_.end()) { std::map> op_type_map; op_type_map[op_type] = std::shared_ptr(new (std::nothrow) TilingCases(op_type)); registry_map_[soc_version] = op_type_map; } else { if (soc_iter->second.find(op_type) == soc_iter->second.end()) { soc_iter->second[op_type] = std::shared_ptr(new (std::nothrow) TilingCases(op_type)); } } OP_CHECK_IF( registry_map_[soc_version][op_type] == nullptr, OP_LOGE(op_type, "Register tiling func failed, please check the class name."), return nullptr); return registry_map_[soc_version][op_type]; } ge::graphStatus DoTilingImpl(gert::TilingContext* context) { int32_t soc_version = (int32_t)platform_ascendc::SocVersion::RESERVED_VERSION; const char* op_type = context->GetNodeType(); fe::PlatFormInfos* platformInfoPtr = context->GetPlatformInfo(); if (platformInfoPtr == nullptr) { auto compileInfoPtr = static_cast(context->GetCompileInfo()); OP_CHECK_IF( compileInfoPtr == nullptr, OP_LOGE(op_type, "compileInfoPtr is null."), return ge::GRAPH_FAILED); soc_version = compileInfoPtr->socVersion; OP_LOGD(context, "soc version in compileInfo is %d", soc_version); } else { auto ascendcPlatform = platform_ascendc::PlatformAscendC(platformInfoPtr); soc_version = static_cast(ascendcPlatform.GetSocVersion()); OP_LOGD(context, "soc version is %d", soc_version); if (soc_version == (int32_t)platform_ascendc::SocVersion::RESERVED_VERSION) { OP_LOGE(op_type, "Do op tiling failed, cannot find soc version."); return ge::GRAPH_FAILED; } } auto tilingTemplateRegistryMap = GetTilingTemplates(op_type, soc_version); for (auto it = tilingTemplateRegistryMap.begin(); it != tilingTemplateRegistryMap.end(); ++it) { auto tilingTemplate = it->second(context); if (tilingTemplate != nullptr) { ge::graphStatus status = tilingTemplate->DoTiling(); if (status != ge::GRAPH_PARAM_INVALID) { OP_LOGD(context, "Do general op tiling success priority=%d", it->first); return status; } OP_LOGD(context, "Ignore general op tiling priority=%d", it->first); } } OP_LOGE(op_type, "Do op tiling failed, no valid template is found."); return ge::GRAPH_FAILED; } ge::graphStatus DoTilingImpl(gert::TilingContext* context, const std::vector& priorities) { int32_t soc_version; const char* op_type = context->GetNodeType(); auto platformInfoPtr = context->GetPlatformInfo(); if (platformInfoPtr == nullptr) { auto compileInfoPtr = reinterpret_cast(context->GetCompileInfo()); OP_CHECK_IF( compileInfoPtr == nullptr, OP_LOGE(op_type, "compileInfoPtr is null."), return ge::GRAPH_FAILED); soc_version = compileInfoPtr->socVersion; OP_LOGD(context, "soc version in compileInfo is %d", soc_version); } else { auto ascendcPlatform = platform_ascendc::PlatformAscendC(platformInfoPtr); soc_version = static_cast(ascendcPlatform.GetSocVersion()); OP_LOGD(context, "soc version is %d", soc_version); } auto tilingTemplateRegistryMap = GetTilingTemplates(op_type, soc_version); for (auto priority_id : priorities) { auto tilingCaseIter = tilingTemplateRegistryMap.find(priority_id); if (tilingCaseIter != tilingTemplateRegistryMap.end()) { auto templateFunc = tilingCaseIter->second(context); if (templateFunc != nullptr) { ge::graphStatus status = templateFunc->DoTiling(); if (status == ge::GRAPH_SUCCESS) { OP_LOGD(context, "Do general op tiling success priority=%d", priority_id); return status; } OP_LOGD(context, "Ignore general op tiling priority=%d", priority_id); } } } return ge::GRAPH_FAILED; } const std::map& GetTilingTemplates(const std::string& op_type, int32_t soc_version) { auto soc_iter = registry_map_.find(soc_version); OP_CHECK_IF( soc_iter == registry_map_.end(), OP_LOGE(op_type, "Get op tiling func failed, please check the soc version %d", soc_version), return empty_tiling_case_); auto op_iter = soc_iter->second.find(op_type); OP_CHECK_IF( op_iter == soc_iter->second.end(), OP_LOGE(op_type, "Get op tiling func failed, please check the op name."), return empty_tiling_case_); return op_iter->second->GetTilingCases(); } private: std::map>> registry_map_; // key is socversion const std::map empty_tiling_case_{}; }; class RegisterNew { public: explicit RegisterNew(std::string op_type) : op_type_(std::move(op_type)) {} template RegisterNew& tiling(int32_t priority, int32_t soc_version) { auto tilingCases = TilingRegistryNew::GetInstance().RegisterOp(op_type_, soc_version); OP_CHECK_IF( tilingCases == nullptr, OP_LOGE(op_type_, "Register op tiling failed, please the op name."), return *this); tilingCases->AddTiling(priority); return *this; } template RegisterNew& tiling(int32_t priority, const std::vector& soc_versions) { for (int32_t soc_version : soc_versions) { auto tilingCases = TilingRegistryNew::GetInstance().RegisterOp(op_type_, soc_version); OP_CHECK_IF( tilingCases == nullptr, OP_LOGE(op_type_, "Register op tiling failed, please the op name."), return *this); tilingCases->AddTiling(priority); } return *this; } private: const std::string op_type_; }; // --------------------------------Interfacce without soc version -------------------------------- class TilingRegistry { public: TilingRegistry() = default; #ifdef ASCENDC_OP_TEST static TilingRegistry& GetInstance(); #else static TilingRegistry& GetInstance() { static TilingRegistry registry_impl_; return registry_impl_; } #endif std::shared_ptr RegisterOp(const std::string& op_type) { if (registry_map_.find(op_type) == registry_map_.end()) { registry_map_[op_type] = std::shared_ptr(new (std::nothrow) TilingCases(op_type)); } OP_CHECK_IF( registry_map_[op_type] == nullptr, OP_LOGE(op_type, "Register tiling func failed, please check the class name."), return nullptr); return registry_map_[op_type]; } ge::graphStatus DoTilingImpl(gert::TilingContext* context) { const char* op_type = context->GetNodeType(); auto tilingTemplateRegistryMap = GetTilingTemplates(op_type); for (auto it = tilingTemplateRegistryMap.begin(); it != tilingTemplateRegistryMap.end(); ++it) { auto tilingTemplate = it->second(context); if (tilingTemplate != nullptr) { ge::graphStatus status = tilingTemplate->DoTiling(); if (status != ge::GRAPH_PARAM_INVALID) { OP_LOGD(context, "Do general op tiling success priority=%d", it->first); return status; } OP_LOGD(context, "Ignore general op tiling priority=%d", it->first); } } OP_LOGE(op_type, "Do op tiling failed, no valid template is found."); return ge::GRAPH_FAILED; } ge::graphStatus DoTilingImpl(gert::TilingContext* context, const std::vector& priorities) { const char* op_type = context->GetNodeType(); auto tilingTemplateRegistryMap = GetTilingTemplates(op_type); for (auto priorityId : priorities) { auto templateFunc = tilingTemplateRegistryMap[priorityId](context); if (templateFunc != nullptr) { ge::graphStatus status = templateFunc->DoTiling(); if (status == ge::GRAPH_SUCCESS) { OP_LOGD(context, "Do general op tiling success priority=%d", priorityId); return status; } if (status != ge::GRAPH_PARAM_INVALID) { OP_LOGD(context, "Do op tiling failed"); return status; } OP_LOGD(context, "Ignore general op tiling priority=%d", priorityId); } } OP_LOGE(op_type, "Do op tiling failed, no valid template is found."); return ge::GRAPH_FAILED; } const std::map& GetTilingTemplates(const std::string& op_type) { OP_CHECK_IF( registry_map_.find(op_type) == registry_map_.end(), OP_LOGE(op_type, "Get op tiling func failed, please check the op name."), return empty_tiling_case_); return registry_map_[op_type]->GetTilingCases(); } private: std::map> registry_map_; const std::map empty_tiling_case_; }; class Register { public: explicit Register(std::string op_type) : op_type_(std::move(op_type)) {} template Register& tiling(int32_t priority) { auto tilingCases = TilingRegistry::GetInstance().RegisterOp(op_type_); OP_CHECK_IF( tilingCases == nullptr, OP_LOGE(op_type_, "Register op tiling failed, please the op name."), return *this); tilingCases->AddTiling(priority); return *this; } private: const std::string op_type_; }; } // namespace OpTiling } // namespace Transformer } // namespace Ops // op_type: operator name, class_name: registered tiling class, soc_version: chip version number // priority: priority of tiling class, smaller value means higher priority, i.e., this tiling class will be selected first #define REGISTER_TILING_TEMPLATE_WITH_SOCVERSION(op_type, class_name, soc_versions, priority) \ [[maybe_unused]] uint32_t op_impl_register_template_##op_type##_##class_name##priority; \ static Ops::Transformer::OpTiling::RegisterNew VAR_UNUSED##op_type##class_name##priority_register = \ Ops::Transformer::OpTiling::RegisterNew(#op_type).tiling(priority, soc_versions) // op_type: operator name, class_name: registered tiling class // priority: priority of tiling class, smaller value means higher priority, i.e., higher probability of being selected #define REGISTER_TILING_TEMPLATE(op_type, class_name, priority) \ [[maybe_unused]] uint32_t op_impl_register_template_##op_type##_##class_name##priority; \ static Ops::Transformer::OpTiling::Register VAR_UNUSED##op_type_##class_name##priority_register = \ Ops::Transformer::OpTiling::Register(op_type).tiling(priority) // op_type: operator name, class_name: registered tiling class // soc_version: SOC version, used to distinguish different SOCs // priority: priority of tiling class, smaller value means higher priority, i.e., this tiling class will be selected first #define REGISTER_TILING_TEMPLATE_NEW(op_type, class_name, soc_version, priority) \ [[maybe_unused]] uint32_t op_impl_register_template_##op_type##_##class_name##priority; \ static Ops::Transformer::OpTiling::RegisterNew VAR_UNUSED##op_type##class_name##priority_register = \ Ops::Transformer::OpTiling::RegisterNew(#op_type).tiling(priority, soc_version) // op_type: operator name, class_name: registered tiling class // priority: priority of tiling class, smaller value means higher priority, i.e., higher probability of being selected // Replaces REGISTER_TILING_TEMPLATE, if op_type is a string constant, remove the quotes #define REGISTER_OPS_TILING_TEMPLATE(op_type, class_name, priority) \ [[maybe_unused]] uint32_t op_impl_register_template_##op_type##_##class_name##priority; \ static Ops::Transformer::OpTiling::Register \ __attribute__((unused)) tiling_##op_type##_##class_name##_##priority##_register = \ Ops::Transformer::OpTiling::Register(#op_type).tiling(priority)