### What this PR does / why we need it?
add ascend c casual_conv1d_fn
- vLLM version: v0.15.0
- vLLM main:
13397841ab
---------
Signed-off-by: ZT-AIA <1028681969@qq.com>
Signed-off-by: ZT-AIA <63220130+ZT-AIA@users.noreply.github.com>
Co-authored-by: gemini-code-assist[bot] <176961590+gemini-code-assist[bot]@users.noreply.github.com>
352 lines
15 KiB
C++
352 lines
15 KiB
C++
/**
|
||
* This program is free software, you can redistribute it and/or modify.
|
||
* Copyright (c) 2025 Huawei Technologies Co., Ltd.
|
||
* This file is a part of the CANN Open Software.
|
||
* Licensed under CANN Open Software License Agreement Version 2.0 (the "License").
|
||
* Please refer to the License for details. You may not use this file except in compliance with the License.
|
||
* THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE.
|
||
* See LICENSE in the root of the software repository for the full text of the License.
|
||
*/
|
||
|
||
/*!
|
||
* \file tiling_templates_registry.h
|
||
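 * \brief Registries that map an operator type (optionally per SoC version) to prioritized tiling template factories, plus the macros used to register them.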
|
||
*/
|
||
|
||
#pragma once
|
||
|
||
#include <cstdint>
#include <map>
#include <memory>
#include <new>
#include <string>
#include <utility>
#include <vector>

#include "exe_graph/runtime/tiling_context.h"
#include "tiling_base.h"
#include "error_log.h"
|
||
|
||
namespace Ops {
|
||
namespace Transformer {
|
||
namespace OpTiling {
|
||
|
||
template <typename T>
|
||
std::unique_ptr<TilingBaseClass> TILING_CLASS(gert::TilingContext* context)
|
||
{
|
||
return std::unique_ptr<T>(new (std::nothrow) T(context));
|
||
}
|
||
|
||
using TilingClassCase = std::unique_ptr<TilingBaseClass> (*)(gert::TilingContext*);
|
||
|
||
class TilingCases {
|
||
public:
|
||
explicit TilingCases(std::string op_type) : op_type_(std::move(op_type))
|
||
{}
|
||
|
||
template <typename T>
|
||
void AddTiling(int32_t priority)
|
||
{
|
||
OP_CHECK_IF(
|
||
cases_.find(priority) != cases_.end(), OP_LOGE(op_type_, "There are duplicate registrations."), return);
|
||
cases_[priority] = TILING_CLASS<T>;
|
||
OP_CHECK_IF(
|
||
cases_[priority] == nullptr,
|
||
OP_LOGE(op_type_, "Register op tiling func failed, please check the class name."), return);
|
||
}
|
||
|
||
const std::map<int32_t, TilingClassCase>& GetTilingCases()
|
||
{
|
||
return cases_;
|
||
}
|
||
|
||
private:
|
||
std::map<int32_t, TilingClassCase> cases_;
|
||
const std::string op_type_;
|
||
};
|
||
|
||
// -------------------------------- Interface with soc version --------------------------------
|
||
class TilingRegistryNew {
|
||
public:
|
||
TilingRegistryNew() = default;
|
||
|
||
#ifdef ASCENDC_OP_TEST
|
||
static TilingRegistryNew& GetInstance();
|
||
#else
|
||
static TilingRegistryNew& GetInstance()
|
||
{
|
||
static TilingRegistryNew registry_impl_;
|
||
return registry_impl_;
|
||
}
|
||
#endif
|
||
|
||
std::shared_ptr<TilingCases> RegisterOp(const std::string& op_type, int32_t soc_version)
|
||
{
|
||
auto soc_iter = registry_map_.find(soc_version);
|
||
if (soc_iter == registry_map_.end()) {
|
||
std::map<std::string, std::shared_ptr<TilingCases>> op_type_map;
|
||
op_type_map[op_type] = std::shared_ptr<TilingCases>(new (std::nothrow) TilingCases(op_type));
|
||
registry_map_[soc_version] = op_type_map;
|
||
} else {
|
||
if (soc_iter->second.find(op_type) == soc_iter->second.end()) {
|
||
soc_iter->second[op_type] = std::shared_ptr<TilingCases>(new (std::nothrow) TilingCases(op_type));
|
||
}
|
||
}
|
||
|
||
OP_CHECK_IF(
|
||
registry_map_[soc_version][op_type] == nullptr,
|
||
OP_LOGE(op_type, "Register tiling func failed, please check the class name."), return nullptr);
|
||
return registry_map_[soc_version][op_type];
|
||
}
|
||
|
||
ge::graphStatus DoTilingImpl(gert::TilingContext* context)
|
||
{
|
||
int32_t soc_version = (int32_t)platform_ascendc::SocVersion::RESERVED_VERSION;
|
||
const char* op_type = context->GetNodeType();
|
||
fe::PlatFormInfos* platformInfoPtr = context->GetPlatformInfo();
|
||
if (platformInfoPtr == nullptr) {
|
||
auto compileInfoPtr = static_cast<const CompileInfoCommon*>(context->GetCompileInfo());
|
||
OP_CHECK_IF(
|
||
compileInfoPtr == nullptr, OP_LOGE(op_type, "compileInfoPtr is null."), return ge::GRAPH_FAILED);
|
||
soc_version = compileInfoPtr->socVersion;
|
||
OP_LOGD(context, "soc version in compileInfo is %d", soc_version);
|
||
} else {
|
||
auto ascendcPlatform = platform_ascendc::PlatformAscendC(platformInfoPtr);
|
||
soc_version = static_cast<int32_t>(ascendcPlatform.GetSocVersion());
|
||
OP_LOGD(context, "soc version is %d", soc_version);
|
||
if (soc_version == (int32_t)platform_ascendc::SocVersion::RESERVED_VERSION) {
|
||
OP_LOGE(op_type, "Do op tiling failed, cannot find soc version.");
|
||
return ge::GRAPH_FAILED;
|
||
}
|
||
}
|
||
auto tilingTemplateRegistryMap = GetTilingTemplates(op_type, soc_version);
|
||
for (auto it = tilingTemplateRegistryMap.begin(); it != tilingTemplateRegistryMap.end(); ++it) {
|
||
auto tilingTemplate = it->second(context);
|
||
if (tilingTemplate != nullptr) {
|
||
ge::graphStatus status = tilingTemplate->DoTiling();
|
||
if (status != ge::GRAPH_PARAM_INVALID) {
|
||
OP_LOGD(context, "Do general op tiling success priority=%d", it->first);
|
||
return status;
|
||
}
|
||
OP_LOGD(context, "Ignore general op tiling priority=%d", it->first);
|
||
}
|
||
}
|
||
OP_LOGE(op_type, "Do op tiling failed, no valid template is found.");
|
||
return ge::GRAPH_FAILED;
|
||
}
|
||
|
||
ge::graphStatus DoTilingImpl(gert::TilingContext* context, const std::vector<int32_t>& priorities)
|
||
{
|
||
int32_t soc_version;
|
||
const char* op_type = context->GetNodeType();
|
||
auto platformInfoPtr = context->GetPlatformInfo();
|
||
if (platformInfoPtr == nullptr) {
|
||
auto compileInfoPtr = reinterpret_cast<const CompileInfoCommon*>(context->GetCompileInfo());
|
||
OP_CHECK_IF(
|
||
compileInfoPtr == nullptr, OP_LOGE(op_type, "compileInfoPtr is null."), return ge::GRAPH_FAILED);
|
||
soc_version = compileInfoPtr->socVersion;
|
||
OP_LOGD(context, "soc version in compileInfo is %d", soc_version);
|
||
} else {
|
||
auto ascendcPlatform = platform_ascendc::PlatformAscendC(platformInfoPtr);
|
||
soc_version = static_cast<int32_t>(ascendcPlatform.GetSocVersion());
|
||
OP_LOGD(context, "soc version is %d", soc_version);
|
||
}
|
||
|
||
auto tilingTemplateRegistryMap = GetTilingTemplates(op_type, soc_version);
|
||
for (auto priority_id : priorities) {
|
||
auto tilingCaseIter = tilingTemplateRegistryMap.find(priority_id);
|
||
if (tilingCaseIter != tilingTemplateRegistryMap.end()) {
|
||
auto templateFunc = tilingCaseIter->second(context);
|
||
if (templateFunc != nullptr) {
|
||
ge::graphStatus status = templateFunc->DoTiling();
|
||
if (status == ge::GRAPH_SUCCESS) {
|
||
OP_LOGD(context, "Do general op tiling success priority=%d", priority_id);
|
||
return status;
|
||
}
|
||
OP_LOGD(context, "Ignore general op tiling priority=%d", priority_id);
|
||
}
|
||
}
|
||
}
|
||
return ge::GRAPH_FAILED;
|
||
}
|
||
|
||
const std::map<int32_t, TilingClassCase>& GetTilingTemplates(const std::string& op_type, int32_t soc_version)
|
||
{
|
||
auto soc_iter = registry_map_.find(soc_version);
|
||
OP_CHECK_IF(
|
||
soc_iter == registry_map_.end(),
|
||
OP_LOGE(op_type, "Get op tiling func failed, please check the soc version %d", soc_version),
|
||
return empty_tiling_case_);
|
||
auto op_iter = soc_iter->second.find(op_type);
|
||
OP_CHECK_IF(
|
||
op_iter == soc_iter->second.end(), OP_LOGE(op_type, "Get op tiling func failed, please check the op name."),
|
||
return empty_tiling_case_);
|
||
return op_iter->second->GetTilingCases();
|
||
}
|
||
|
||
private:
|
||
std::map<int32_t, std::map<std::string, std::shared_ptr<TilingCases>>> registry_map_; // key is socversion
|
||
const std::map<int32_t, TilingClassCase> empty_tiling_case_{};
|
||
};
|
||
|
||
class RegisterNew {
|
||
public:
|
||
explicit RegisterNew(std::string op_type) : op_type_(std::move(op_type))
|
||
{}
|
||
|
||
template <typename T>
|
||
RegisterNew& tiling(int32_t priority, int32_t soc_version)
|
||
{
|
||
auto tilingCases = TilingRegistryNew::GetInstance().RegisterOp(op_type_, soc_version);
|
||
OP_CHECK_IF(
|
||
tilingCases == nullptr, OP_LOGE(op_type_, "Register op tiling failed, please the op name."), return *this);
|
||
tilingCases->AddTiling<T>(priority);
|
||
return *this;
|
||
}
|
||
|
||
template <typename T>
|
||
RegisterNew& tiling(int32_t priority, const std::vector<int32_t>& soc_versions)
|
||
{
|
||
for (int32_t soc_version : soc_versions) {
|
||
auto tilingCases = TilingRegistryNew::GetInstance().RegisterOp(op_type_, soc_version);
|
||
OP_CHECK_IF(
|
||
tilingCases == nullptr, OP_LOGE(op_type_, "Register op tiling failed, please the op name."),
|
||
return *this);
|
||
tilingCases->AddTiling<T>(priority);
|
||
}
|
||
return *this;
|
||
}
|
||
|
||
private:
|
||
const std::string op_type_;
|
||
};
|
||
|
||
// -------------------------------- Interface without soc version --------------------------------
|
||
class TilingRegistry {
|
||
public:
|
||
TilingRegistry() = default;
|
||
|
||
#ifdef ASCENDC_OP_TEST
|
||
static TilingRegistry& GetInstance();
|
||
#else
|
||
static TilingRegistry& GetInstance()
|
||
{
|
||
static TilingRegistry registry_impl_;
|
||
return registry_impl_;
|
||
}
|
||
#endif
|
||
|
||
std::shared_ptr<TilingCases> RegisterOp(const std::string& op_type)
|
||
{
|
||
if (registry_map_.find(op_type) == registry_map_.end()) {
|
||
registry_map_[op_type] = std::shared_ptr<TilingCases>(new (std::nothrow) TilingCases(op_type));
|
||
}
|
||
OP_CHECK_IF(
|
||
registry_map_[op_type] == nullptr,
|
||
OP_LOGE(op_type, "Register tiling func failed, please check the class name."), return nullptr);
|
||
return registry_map_[op_type];
|
||
}
|
||
|
||
ge::graphStatus DoTilingImpl(gert::TilingContext* context)
|
||
{
|
||
const char* op_type = context->GetNodeType();
|
||
auto tilingTemplateRegistryMap = GetTilingTemplates(op_type);
|
||
for (auto it = tilingTemplateRegistryMap.begin(); it != tilingTemplateRegistryMap.end(); ++it) {
|
||
auto tilingTemplate = it->second(context);
|
||
if (tilingTemplate != nullptr) {
|
||
ge::graphStatus status = tilingTemplate->DoTiling();
|
||
if (status != ge::GRAPH_PARAM_INVALID) {
|
||
OP_LOGD(context, "Do general op tiling success priority=%d", it->first);
|
||
return status;
|
||
}
|
||
OP_LOGD(context, "Ignore general op tiling priority=%d", it->first);
|
||
}
|
||
}
|
||
OP_LOGE(op_type, "Do op tiling failed, no valid template is found.");
|
||
return ge::GRAPH_FAILED;
|
||
}
|
||
|
||
ge::graphStatus DoTilingImpl(gert::TilingContext* context, const std::vector<int32_t>& priorities)
|
||
{
|
||
const char* op_type = context->GetNodeType();
|
||
auto tilingTemplateRegistryMap = GetTilingTemplates(op_type);
|
||
for (auto priorityId : priorities) {
|
||
auto templateFunc = tilingTemplateRegistryMap[priorityId](context);
|
||
if (templateFunc != nullptr) {
|
||
ge::graphStatus status = templateFunc->DoTiling();
|
||
if (status == ge::GRAPH_SUCCESS) {
|
||
OP_LOGD(context, "Do general op tiling success priority=%d", priorityId);
|
||
return status;
|
||
}
|
||
if (status != ge::GRAPH_PARAM_INVALID) {
|
||
OP_LOGD(context, "Do op tiling failed");
|
||
return status;
|
||
}
|
||
OP_LOGD(context, "Ignore general op tiling priority=%d", priorityId);
|
||
}
|
||
}
|
||
OP_LOGE(op_type, "Do op tiling failed, no valid template is found.");
|
||
return ge::GRAPH_FAILED;
|
||
}
|
||
|
||
const std::map<int32_t, TilingClassCase>& GetTilingTemplates(const std::string& op_type)
|
||
{
|
||
OP_CHECK_IF(
|
||
registry_map_.find(op_type) == registry_map_.end(),
|
||
OP_LOGE(op_type, "Get op tiling func failed, please check the op name."), return empty_tiling_case_);
|
||
return registry_map_[op_type]->GetTilingCases();
|
||
}
|
||
|
||
private:
|
||
std::map<std::string, std::shared_ptr<TilingCases>> registry_map_;
|
||
const std::map<int32_t, TilingClassCase> empty_tiling_case_;
|
||
};
|
||
|
||
class Register {
|
||
public:
|
||
explicit Register(std::string op_type) : op_type_(std::move(op_type))
|
||
{}
|
||
|
||
template <typename T>
|
||
Register& tiling(int32_t priority)
|
||
{
|
||
auto tilingCases = TilingRegistry::GetInstance().RegisterOp(op_type_);
|
||
OP_CHECK_IF(
|
||
tilingCases == nullptr, OP_LOGE(op_type_, "Register op tiling failed, please the op name."), return *this);
|
||
tilingCases->AddTiling<T>(priority);
|
||
return *this;
|
||
}
|
||
|
||
private:
|
||
const std::string op_type_;
|
||
};
|
||
} // namespace OpTiling
|
||
} // namespace Transformer
|
||
} // namespace Ops
|
||
|
||
// Registers a tiling class in TilingRegistryNew for the given SoC version(s).
// op_type: operator name, written as an identifier (the macro stringizes it)
// class_name: tiling class to register
// soc_versions: SoC version(s) forwarded to RegisterNew::tiling
// priority: priority of the tiling class; templates are tried in ascending priority value,
//           so a smaller value is selected first
// All arguments are token-pasted into identifiers, so each must be a single pasteable token.
// The priority is part of the helper variable's name, so the same op/class pair can be
// registered at several priorities in one translation unit.
#define REGISTER_TILING_TEMPLATE_WITH_SOCVERSION(op_type, class_name, soc_versions, priority)                 \
    [[maybe_unused]] uint32_t op_impl_register_template_##op_type##_##class_name##priority;                   \
    static Ops::Transformer::OpTiling::RegisterNew VAR_UNUSED##op_type##class_name##priority##_register =     \
        Ops::Transformer::OpTiling::RegisterNew(#op_type).tiling<class_name>(priority, soc_versions)
|
||
|
||
// Registers a tiling class in TilingRegistry (no SoC version).
// op_type: operator name; it is passed to Register unquoted and is also token-pasted into
//          identifiers, so it must be an identifier that evaluates to the operator name
// class_name: tiling class to register
// priority: priority of the tiling class; templates are tried in ascending priority value,
//           so a smaller value is selected first
// op_type and priority are both part of the helper variable's name, so registrations that
// differ only in operator or priority do not collide within one translation unit.
#define REGISTER_TILING_TEMPLATE(op_type, class_name, priority)                                               \
    [[maybe_unused]] uint32_t op_impl_register_template_##op_type##_##class_name##priority;                   \
    static Ops::Transformer::OpTiling::Register VAR_UNUSED##op_type##_##class_name##priority##_register =     \
        Ops::Transformer::OpTiling::Register(op_type).tiling<class_name>(priority)
|
||
|
||
// Registers a tiling class in TilingRegistryNew for one SoC version.
// op_type: operator name, written as an identifier (the macro stringizes it)
// class_name: tiling class to register
// soc_version: SoC version, used to distinguish different SoCs
// priority: priority of the tiling class; templates are tried in ascending priority value,
//           so a smaller value is selected first
// All arguments except soc_version are token-pasted into identifiers, so each must be a
// single pasteable token. The priority is part of the helper variable's name, so the same
// op/class pair can be registered at several priorities in one translation unit.
#define REGISTER_TILING_TEMPLATE_NEW(op_type, class_name, soc_version, priority)                              \
    [[maybe_unused]] uint32_t op_impl_register_template_##op_type##_##class_name##priority;                   \
    static Ops::Transformer::OpTiling::RegisterNew VAR_UNUSED##op_type##class_name##priority##_register =     \
        Ops::Transformer::OpTiling::RegisterNew(#op_type).tiling<class_name>(priority, soc_version)
|
||
|
||
// Registers a tiling class in TilingRegistry (no SoC version); replacement for
// REGISTER_TILING_TEMPLATE.
// op_type: operator name, written as an identifier without quotes (the macro stringizes it)
// class_name: tiling class to register
// priority: priority of the tiling class; templates are tried in ascending priority value,
//           so a smaller value is selected first
#define REGISTER_OPS_TILING_TEMPLATE(op_type, class_name, priority)                                           \
    [[maybe_unused]] uint32_t op_impl_register_template_##op_type##_##class_name##priority;                   \
    static Ops::Transformer::OpTiling::Register                                                               \
        __attribute__((unused)) tiling_##op_type##_##class_name##_##priority##_register =                     \
            Ops::Transformer::OpTiling::Register(#op_type).tiling<class_name>(priority)
|