[qwen3 next ]add ascend c casual_conv1d_fn (#6661)

### What this PR does / why we need it?
add ascend c casual_conv1d_fn

- vLLM version: v0.15.0
- vLLM main:
13397841ab
---------
Signed-off-by: ZT-AIA <1028681969@qq.com>
Signed-off-by: ZT-AIA <63220130+ZT-AIA@users.noreply.github.com>
Co-authored-by: gemini-code-assist[bot] <176961590+gemini-code-assist[bot]@users.noreply.github.com>
This commit is contained in:
ZT-AIA
2026-03-09 23:29:49 +08:00
committed by GitHub
parent 48b624e4cc
commit ee5347e824
26 changed files with 2504 additions and 14 deletions

View File

@@ -0,0 +1,50 @@
# Copyright (c) 2025 Huawei Technologies Co., Ltd.
# This file is a part of the CANN Open Software.
# Licensed under CANN Open Software License Agreement Version 1.0 (the "License").
# Please refer to the License for details. You may not use this file except in compliance with the License.
# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED,
# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE.
# See LICENSE in the root of the software repository for the full text of the License.
# ======================================================================================================================
add_ops_compile_options(
OP_NAME CausalConv1d
OPTIONS --cce-auto-sync=off
-Wno-deprecated-declarations
-Werror
)
target_sources(op_host_aclnn PRIVATE
causal_conv1d_def.cpp
)
# target_sources(opapi PRIVATE
# aclnn_causal_conv1d.cpp
# )
# if (NOT BUILD_OPEN_PROJECT)
# target_sources(aclnn_ops_train PRIVATE
# aclnn_causal_conv1d.cpp
# )
# target_sources(aclnn_ops_infer PRIVATE
# aclnn_causal_conv1d.cpp
# )
# endif ()
target_sources(optiling PRIVATE
causal_conv1d_tiling.cpp
tiling_util.cpp
)
target_include_directories(optiling PRIVATE
${CMAKE_CURRENT_SOURCE_DIR}
)
target_sources(opsproto PRIVATE)
file(GLOB _GMM_Aclnn_header "${CMAKE_CURRENT_SOURCE_DIR}/aclnn_causal_conv1d.h")
install(FILES ${_GMM_Aclnn_header}
DESTINATION ${ACLNN_INC_INSTALL_DIR} OPTIONAL
)

View File

@@ -0,0 +1,83 @@
/**
* This program is free software, you can redistribute it and/or modify it.
* Copyright (c) 2025 Huawei Technologies Co., Ltd.
* This file is a part of the CANN Open Software.
* Licensed under CANN Open Software License Agreement Version 2.0 (the "License").
* Please refer to the License for details. You may not use this file except in compliance with the License.
* THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, INCLUDING
* BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE.
* See LICENSE in the root of the software repository for the full text of the License.
*/
/*!
* \file causal_conv1d_def.cpp
* \brief
*/
#include "register/op_def_registry.h"
namespace ops {
class CausalConv1d : public OpDef {
public:
explicit CausalConv1d(const char* name) : OpDef(name)
{
this->Input("x")
.ParamType(REQUIRED)
.DataType({ge::DT_FLOAT16, ge::DT_BF16})
.FormatList({ge::FORMAT_ND})
.AutoContiguous();
this->Input("weight")
.ParamType(REQUIRED)
.DataType({ge::DT_FLOAT16, ge::DT_BF16})
.FormatList({ge::FORMAT_ND})
.AutoContiguous();
this->Input("bias")
.ParamType(OPTIONAL)
.DataType({ge::DT_FLOAT16, ge::DT_BF16})
.FormatList({ge::FORMAT_ND})
.AutoContiguous();
this->Input("convStates")
.ParamType(REQUIRED)
.DataType({ge::DT_FLOAT16, ge::DT_BF16})
.FormatList({ge::FORMAT_ND})
.AutoContiguous();
this->Input("queryStartLoc")
.ParamType(REQUIRED)
.DataTypeList({ge::DT_INT32})
.FormatList({ge::FORMAT_ND})
.AutoContiguous();
this->Input("cacheIndices")
.ParamType(REQUIRED)
.DataTypeList({ge::DT_INT32})
.FormatList({ge::FORMAT_ND})
.AutoContiguous();
this->Input("hasInitialState")
.ParamType(REQUIRED)
.DataTypeList({ge::DT_BOOL})
.FormatList({ge::FORMAT_ND})
.AutoContiguous();
this->Output("y")
.ParamType(REQUIRED)
.DataType({ge::DT_FLOAT16, ge::DT_BF16})
.FormatList({ge::FORMAT_ND})
.AutoContiguous();
this->Attr("activationMode").AttrType(OPTIONAL).Int(0);
this->Attr("padSlotId").AttrType(OPTIONAL).Int(-1);
OpAICoreConfig aicoreConfig;
aicoreConfig.DynamicCompileStaticFlag(true)
.DynamicFormatFlag(false)
.DynamicRankSupportFlag(true)
.DynamicShapeSupportFlag(true)
.NeedCheckSupportFlag(false)
.PrecisionReduceFlag(true)
.ExtendCfgInfo("coreType.value", "AiCore");
this->AICore().AddConfig("ascend910b", aicoreConfig);
this->AICore().AddConfig("ascend910_93", aicoreConfig);
}
};
OP_ADD(CausalConv1d);
} // namespace ops

View File

@@ -0,0 +1,49 @@
/**
* This program is free software, you can redistribute it and/or modify it.
* Copyright (c) 2025 Huawei Technologies Co., Ltd.
* This file is a part of the CANN Open Software.
* Licensed under CANN Open Software License Agreement Version 2.0 (the "License").
* Please refer to the License for details. You may not use this file except in compliance with the License.
* THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, INCLUDING
* BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE.
* See LICENSE in the root of the software repository for the full text of the License.
*/
/*!
* \file causal_conv1d_infershape.cpp
* \brief
*/
#include "register/op_impl_registry.h"
#include "error_log.h"
using namespace ge;
namespace ops {
static constexpr int64_t IDX_0 = 0;
static ge::graphStatus InferShapeCausalConv1d(gert::InferShapeContext* context)
{
// OPS_LOG_D(context->GetNodeName(), "Begin to do InferShapeCausalConv1d");
// get input shapes
const gert::Shape* xShape = context->GetInputShape(IDX_0);
OP_CHECK_NULL_WITH_CONTEXT(context, xShape);
// get output shapes
gert::Shape* yShape = context->GetOutputShape(IDX_0);
OP_CHECK_NULL_WITH_CONTEXT(context, yShape);
// 填充输出shape大小
auto xShapeSize = xShape->GetDimNum();
yShape->SetDimNum(xShapeSize);
for (size_t i = 0; i < xShapeSize; i++) {
int64_t dim = xShape->GetDim(i);
yShape->SetDim(i, dim);
}
// OPS_LOG_D(context->GetNodeName(), "End to do InferShapeCausalConv1d");
return GRAPH_SUCCESS;
}
IMPL_OP_INFERSHAPE(CausalConv1d).InferShape(InferShapeCausalConv1d);
} // namespace ops

View File

@@ -0,0 +1,365 @@
/**
* This program is free software, you can redistribute it and/or modify it.
* Copyright (c) 2025 Huawei Technologies Co., Ltd.
* This file is a part of the CANN Open Software.
* Licensed under CANN Open Software License Agreement Version 2.0 (the "License").
* Please refer to the License for details. You may not use this file except in compliance with the License.
* THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, INCLUDING
* BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE.
* See LICENSE in the root of the software repository for the full text of the License.
*/
/*!
* \file causal_conv1d_tiling.cpp
* \brief
*/
// #include "error_log.h"
#include "log/ops_log.h"
#include "../tiling_base/tiling_templates_registry.h"
#include "../tiling_base/tiling_util.h"
#include "math_util.h"
#include "causal_conv1d_tiling.h"
#include "../op_kernel/causal_conv1d_tiling_key.h"
#include <set>
#include <limits>
namespace optiling {
using namespace Ops::Transformer::OpTiling;
constexpr uint32_t X_INDEX = 0;
constexpr uint32_t WEIGHT_INDEX = 1;
constexpr uint32_t BIAS_INDEX = 2;
constexpr uint32_t CONV_STATES_INDEX = 3;
constexpr uint32_t QUERY_START_LOC_INDEX = 4;
constexpr uint32_t CACHE_INDICES_INDEX = 5;
constexpr uint32_t HAS_INITIAL_STATE_INDEX = 6;
constexpr int32_t ATTR_ACTIVATION_MODE_INDEX = 0;
constexpr int32_t ATTR_PAD_SLOT_ID_INDEX = 1;
struct DimTileChoice {
int64_t dimTileSize = 0;
int64_t blocksPerSeq = 0;
int64_t gridSize = 0;
};
static inline DimTileChoice ChooseDimTileSize(gert::TilingContext* context, int64_t batch, int64_t dim, uint32_t coreNum)
{
const int64_t candidates[] = {4096, 2048, 1024, 512,384};
DimTileChoice bestOver;
int64_t bestOverGap = std::numeric_limits<int64_t>::max();
DimTileChoice bestUnder;
for (int64_t dimTileSize : candidates) {
if (dim % dimTileSize != 0) {
continue;
}
const int64_t blocksPerSeq = dim / dimTileSize;
const int64_t gridSize = batch * blocksPerSeq;
if (gridSize <= 0) {
continue;
}
if (gridSize >= static_cast<int64_t>(coreNum)) {
const int64_t gap = gridSize - static_cast<int64_t>(coreNum);
if (gap < bestOverGap) {
bestOver.dimTileSize = dimTileSize;
bestOver.blocksPerSeq = blocksPerSeq;
bestOver.gridSize = gridSize;
bestOverGap = gap;
}
} else if (gridSize > bestUnder.gridSize ||
(gridSize == bestUnder.gridSize && dimTileSize < bestUnder.dimTileSize)) {
bestUnder.dimTileSize = dimTileSize;
bestUnder.blocksPerSeq = blocksPerSeq;
bestUnder.gridSize = gridSize;
}
}
DimTileChoice result = (bestOver.dimTileSize != 0) ? bestOver : bestUnder;
return result;
}
static ge::graphStatus GetPlatformInfo(gert::TilingContext* context, uint64_t& ubSize, uint32_t& coreNum)
{
auto compileInfoPtr = context->GetCompileInfo<CausalConv1dCompileInfo>();
if (compileInfoPtr != nullptr && compileInfoPtr->coreNum != 0 && compileInfoPtr->ubSize != 0) {
ubSize = compileInfoPtr->ubSize;
coreNum = compileInfoPtr->coreNum;
return ge::GRAPH_SUCCESS;
}
fe::PlatFormInfos* platformInfoPtr = context->GetPlatformInfo();
OP_CHECK_NULL_WITH_CONTEXT(context, platformInfoPtr);
auto ascendcPlatform = platform_ascendc::PlatformAscendC(platformInfoPtr);
coreNum = ascendcPlatform.GetCoreNumAiv();
if(coreNum == 0) {
return ge::GRAPH_FAILED;
}
ascendcPlatform.GetCoreMemSize(platform_ascendc::CoreMemType::UB, ubSize);
if(ubSize == 0) {
return ge::GRAPH_FAILED;
}
return ge::GRAPH_SUCCESS;
}
static ge::graphStatus GetWorkspaceSize(gert::TilingContext* context)
{
size_t* currentWorkspace = context->GetWorkspaceSizes(1);
OP_CHECK_NULL_WITH_CONTEXT(context, currentWorkspace);
currentWorkspace[0] = 0;
return ge::GRAPH_SUCCESS;
}
static ge::graphStatus GetAttrsInfo(gert::TilingContext* context, int64_t& activationMode, int64_t& padSlotId)
{
auto attrs = context->GetAttrs();
OP_CHECK_NULL_WITH_CONTEXT(context, attrs);
const int64_t* activationModePtr = attrs->GetAttrPointer<int64_t>(ATTR_ACTIVATION_MODE_INDEX);
OP_CHECK_NULL_WITH_CONTEXT(context, activationModePtr);
activationMode = *activationModePtr;
if(activationMode != 0 && activationMode != 1){
return ge::GRAPH_FAILED;
}
const int64_t* padSlotIdPtr = attrs->GetAttrPointer<int64_t>(ATTR_PAD_SLOT_ID_INDEX);
OP_CHECK_NULL_WITH_CONTEXT(context, padSlotIdPtr);
padSlotId = *padSlotIdPtr;
return ge::GRAPH_SUCCESS;
}
static ge::graphStatus GetShapeDtypeInfo(gert::TilingContext* context, CausalConv1dTilingData& tiling)
{
auto xShapePtr = context->GetInputShape(X_INDEX);
OP_CHECK_NULL_WITH_CONTEXT(context, xShapePtr);
auto xShape = EnsureNotScalar(xShapePtr->GetStorageShape());
int64_t dim = 0;
int64_t cuSeqlen = 0;
int64_t seqLen = 0;
int64_t batch = 0;
int64_t inputMode = 0;
if (xShape.GetDimNum() == 2) {
inputMode = 0;
cuSeqlen = xShape.GetDim(0);
dim = xShape.GetDim(1);
seqLen = 0;
if(dim <= 0 || cuSeqlen < 0){
return ge::GRAPH_FAILED;
}
} else if (xShape.GetDimNum() == 3) {
inputMode = 1;
batch = xShape.GetDim(0);
seqLen = xShape.GetDim(1);
dim = xShape.GetDim(2);
cuSeqlen = batch * seqLen;
if(batch <= 0 || dim <= 0 || seqLen <= 0){
return ge::GRAPH_FAILED;
}
} else {
return ge::GRAPH_FAILED;
}
auto wShapePtr = context->GetInputShape(WEIGHT_INDEX);
OP_CHECK_NULL_WITH_CONTEXT(context, wShapePtr);
auto wShape = EnsureNotScalar(wShapePtr->GetStorageShape());
if(wShape.GetDimNum() != 2){
return ge::GRAPH_FAILED;
}
const int64_t width = wShape.GetDim(0);
const int64_t wDim = wShape.GetDim(1);
if(wDim != dim){
return ge::GRAPH_FAILED;
}
if(width != 4){
return ge::GRAPH_FAILED;
}
auto sShapePtr = context->GetInputShape(CONV_STATES_INDEX);
OP_CHECK_NULL_WITH_CONTEXT(context, sShapePtr);
auto sShape = EnsureNotScalar(sShapePtr->GetStorageShape());
if(sShape.GetDimNum() != 3){
return ge::GRAPH_FAILED;
}
const int64_t numCacheLines = sShape.GetDim(0);
const int64_t stateLen = sShape.GetDim(1);
const int64_t sDim = sShape.GetDim(2);
if(numCacheLines <= 0){
return ge::GRAPH_FAILED;}
if(sDim != dim){
return ge::GRAPH_FAILED;}
if(stateLen < (width - 1)){
return ge::GRAPH_FAILED;}
auto qslShapePtr = context->GetInputShape(QUERY_START_LOC_INDEX);
OP_CHECK_NULL_WITH_CONTEXT(context, qslShapePtr);
auto qslShape = EnsureNotScalar(qslShapePtr->GetStorageShape());
if(qslShape.GetDimNum() != 1){
return ge::GRAPH_FAILED;}
const int64_t qslSize = qslShape.GetDim(0);
if(qslSize < 1){
return ge::GRAPH_FAILED;}
if (inputMode == 0) {
batch = qslSize - 1;
}
if (inputMode == 1) {
if(qslSize != batch + 1){
return ge::GRAPH_FAILED;
}
}
auto ciShapePtr = context->GetInputShape(CACHE_INDICES_INDEX);
OP_CHECK_NULL_WITH_CONTEXT(context, ciShapePtr);
auto ciShape = EnsureNotScalar(ciShapePtr->GetStorageShape());
if(ciShape.GetDimNum() != 1){return ge::GRAPH_FAILED;}
if(ciShape.GetDim(0) != batch){return ge::GRAPH_FAILED;}
auto hisShapePtr = context->GetInputShape(HAS_INITIAL_STATE_INDEX);
OP_CHECK_NULL_WITH_CONTEXT(context, hisShapePtr);
auto hisShape = EnsureNotScalar(hisShapePtr->GetStorageShape());
if(hisShape.GetDimNum() != 1){
return ge::GRAPH_FAILED;}
if(hisShape.GetDim(0) != batch){
return ge::GRAPH_FAILED;}
tiling.set_hasBias(0);
auto biasShapePtr = context->GetOptionalInputShape(BIAS_INDEX);
if (biasShapePtr != nullptr && biasShapePtr->GetStorageShape().GetDimNum() != 0) {
auto biasShape = EnsureNotScalar(biasShapePtr->GetStorageShape());
if(biasShape.GetDimNum() != 1){
return ge::GRAPH_FAILED;}
if(biasShape.GetDim(0) != dim){
return ge::GRAPH_FAILED;}
tiling.set_hasBias(1);
}
const std::set<ge::DataType> supportedXDtype = {ge::DT_BF16, ge::DT_FLOAT16};
auto xDesc = context->GetInputDesc(X_INDEX);
OP_CHECK_NULL_WITH_CONTEXT(context, xDesc);
const ge::DataType xDtype = xDesc->GetDataType();
if(supportedXDtype.count(xDtype) == 0){
return ge::GRAPH_FAILED;}
auto wDesc = context->GetInputDesc(WEIGHT_INDEX);
OP_CHECK_NULL_WITH_CONTEXT(context, wDesc);
if(wDesc->GetDataType() != xDtype){
return ge::GRAPH_FAILED;}
if (tiling.get_hasBias() == 1) {
auto biasDesc = context->GetOptionalInputDesc(BIAS_INDEX);
OP_CHECK_NULL_WITH_CONTEXT(context, biasDesc);
if(biasDesc->GetDataType() != xDtype){
return ge::GRAPH_FAILED;}
}
auto sDesc = context->GetInputDesc(CONV_STATES_INDEX);
OP_CHECK_NULL_WITH_CONTEXT(context, sDesc);
if(sDesc->GetDataType() != xDtype){
return ge::GRAPH_FAILED;}
auto qslDesc = context->GetInputDesc(QUERY_START_LOC_INDEX);
OP_CHECK_NULL_WITH_CONTEXT(context, qslDesc);
if(qslDesc->GetDataType() != ge::DT_INT32){
return ge::GRAPH_FAILED;}
auto ciDesc = context->GetInputDesc(CACHE_INDICES_INDEX);
OP_CHECK_NULL_WITH_CONTEXT(context, ciDesc);
if(ciDesc->GetDataType() != ge::DT_INT32){
return ge::GRAPH_FAILED;}
auto hisDesc = context->GetInputDesc(HAS_INITIAL_STATE_INDEX);
OP_CHECK_NULL_WITH_CONTEXT(context, hisDesc);
if(hisDesc->GetDataType() != ge::DT_BOOL){
return ge::GRAPH_FAILED;}
tiling.set_dim(dim);
tiling.set_cuSeqlen(cuSeqlen);
tiling.set_seqLen(seqLen);
tiling.set_inputMode(inputMode);
tiling.set_width(width);
tiling.set_stateLen(stateLen);
tiling.set_numCacheLines(numCacheLines);
tiling.set_batch(batch);
return ge::GRAPH_SUCCESS;
}
static ge::graphStatus CausalConv1dTilingFunc(gert::TilingContext* context)
{
uint64_t ubSize;
uint32_t coreNum;
if( GetPlatformInfo(context, ubSize, coreNum) != ge::GRAPH_SUCCESS){
return ge::GRAPH_FAILED;
}
if(GetWorkspaceSize(context) != ge::GRAPH_SUCCESS){
return ge::GRAPH_FAILED;
}
CausalConv1dTilingData tilingData;
int64_t activationMode = 0;
int64_t padSlotId = -1;
if(GetAttrsInfo(context, activationMode, padSlotId) != ge::GRAPH_SUCCESS){
return ge::GRAPH_FAILED;
}
tilingData.set_activationMode(activationMode);
tilingData.set_padSlotId(padSlotId);
if( GetShapeDtypeInfo(context, tilingData) != ge::GRAPH_SUCCESS){
return ge::GRAPH_FAILED;
}
const int64_t dim = tilingData.get_dim();
const int64_t batch = tilingData.get_batch();
if(dim <= 0 || batch <= 0){
return ge::GRAPH_FAILED;
}
const DimTileChoice choice = ChooseDimTileSize(context, batch, dim, coreNum);
const uint32_t blockDim = (choice.gridSize < static_cast<int64_t>(coreNum))
? static_cast<uint32_t>(choice.gridSize)
: coreNum;
context->SetBlockDim(blockDim);
tilingData.set_dimTileSize(choice.dimTileSize);
tilingData.set_blocksPerSeq(choice.blocksPerSeq);
const uint64_t tilingKey = GET_TPL_TILING_KEY(CAUSAL_CONV1D_TPL_SCH_MODE_DEFAULT);
context->SetTilingKey(tilingKey);
tilingData.SaveToBuffer(context->GetRawTilingData()->GetData(), context->GetRawTilingData()->GetCapacity());
context->GetRawTilingData()->SetDataSize(tilingData.GetDataSize());
return ge::GRAPH_SUCCESS;
}
static ge::graphStatus TilingParseForCausalConv1d(gert::TilingParseContext* context)
{
auto platformInfoPtr = context->GetPlatformInfo();
OP_CHECK_NULL_WITH_CONTEXT(context, platformInfoPtr);
auto compileInfoPtr = context->GetCompiledInfo<CausalConv1dCompileInfo>();
OP_CHECK_NULL_WITH_CONTEXT(context, compileInfoPtr);
auto ascendcPlatform = platform_ascendc::PlatformAscendC(platformInfoPtr);
compileInfoPtr->coreNum = static_cast<uint32_t>(ascendcPlatform.GetCoreNumAiv());
if(compileInfoPtr->coreNum == 0){
return ge::GRAPH_FAILED;
}
ascendcPlatform.GetCoreMemSize(platform_ascendc::CoreMemType::UB, compileInfoPtr->ubSize);
if(compileInfoPtr->ubSize == 0){
return ge::GRAPH_FAILED;
}
return ge::GRAPH_SUCCESS;
}
IMPL_OP_OPTILING(CausalConv1d)
.Tiling(CausalConv1dTilingFunc)
.TilingParse<CausalConv1dCompileInfo>(TilingParseForCausalConv1d);
} // namespace optiling

View File

@@ -0,0 +1,60 @@
/**
* This program is free software, you can redistribute it and/or modify it.
* Copyright (c) 2025 Huawei Technologies Co., Ltd.
* This file is a part of the CANN Open Software.
* Licensed under CANN Open Software License Agreement Version 2.0 (the "License").
* Please refer to the License for details. You may not use this file except in compliance with the License.
* THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, INCLUDING
* BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE.
* See LICENSE in the root of the software repository for the full text of the License.
*/
/*!
* \file causal_conv1d_tiling_data.h
* \brief
*/
#ifndef ASCEND_OPS_CAUSAL_CONV1D_TILING_DATA_H
#define ASCEND_OPS_CAUSAL_CONV1D_TILING_DATA_H
#include <cstdint>
// #include "register/tilingdata_base.h"
// #include "tiling/tiling_api.h"
#include "register/tilingdata_base.h"
#include "error_log.h"
#include "register/op_impl_registry.h"
#include "tiling/platform/platform_ascendc.h"
#include "platform/platform_infos_def.h"
namespace optiling {
BEGIN_TILING_DATA_DEF(CausalConv1dTilingData)
TILING_DATA_FIELD_DEF(int64_t, dim);
TILING_DATA_FIELD_DEF(int64_t, cuSeqlen);
TILING_DATA_FIELD_DEF(int64_t, seqLen);
TILING_DATA_FIELD_DEF(int64_t, inputMode);
TILING_DATA_FIELD_DEF(int64_t, width);
TILING_DATA_FIELD_DEF(int64_t, stateLen);
TILING_DATA_FIELD_DEF(int64_t, numCacheLines);
TILING_DATA_FIELD_DEF(int64_t, batch);
TILING_DATA_FIELD_DEF(int64_t, activationMode);
TILING_DATA_FIELD_DEF(int64_t, padSlotId);
TILING_DATA_FIELD_DEF(int64_t, hasBias);
TILING_DATA_FIELD_DEF(int64_t, dimTileSize);
TILING_DATA_FIELD_DEF(int64_t, blocksPerSeq);
END_TILING_DATA_DEF;
struct CausalConv1dCompileInfo {
uint64_t ubSize = 0;
uint32_t coreNum = 0;
};
REGISTER_TILING_DATA_CLASS(CausalConv1d, CausalConv1dTilingData)
} // namespace optiling
#endif // ASCEND_OPS_CAUSAL_CONV1D_TILING_DATA_H

View File

@@ -0,0 +1,71 @@
#ifndef OPS_BUILT_IN_OP_TILING_ERROR_LOG_H_
#define OPS_BUILT_IN_OP_TILING_ERROR_LOG_H_
#include <string>
#include "toolchain/slog.h"
#define OP_LOGI(opname, ...)
#define OP_LOGW(opname, ...) \
do { \
printf("[WARN][%s] ", (opname), ##__VA_ARGS__); \
printf("\n"); \
} while (0)
#define OP_LOGE_WITHOUT_REPORT(opname, ...) \
do { \
printf("[ERRORx][%s] ", (opname), ##__VA_ARGS__); \
printf("\n"); \
} while (0)
#define OP_LOGE(opname, ...) \
do { \
printf("[ERROR][%s] ", (opname), ##__VA_ARGS__); \
printf("\n"); \
} while (0)
#define OP_LOGD(opname, ...)
namespace optiling {
#define VECTOR_INNER_ERR_REPORT_TILIING(op_name, err_msg, ...) \
do { \
OP_LOGE_WITHOUT_REPORT(op_name, err_msg, ##__VA_ARGS__); \
} while (0)
#define OP_CHECK_IF(cond, log_func, expr) \
do { \
if (cond) { \
log_func; \
expr; \
} \
} while (0)
#define OP_CHECK_NULL_WITH_CONTEXT(context, ptr) \
do { \
if ((ptr) == nullptr) { \
OP_LOGE(context->GetNodeType(), "%s is null", #ptr); \
return ge::GRAPH_FAILED; \
} \
} while (0)
} // namespace optiling
template <typename T>
T CeilAlign(T a, T b)
{
return (a + b - 1) / b * b;
}
template <typename T>
T CeilDiv(T a, T b)
{
if (b == 0) {
return a;
}
return (a + b - 1) / b;
}
#endif // OPS_BUILT_IN_OP_TILING_ERROR_LOG_H_

View File

@@ -0,0 +1,61 @@
/**
* Copyright (c) 2025 Huawei Technologies Co., Ltd.
* This program is free software, you can redistribute it and/or modify it under the terms and conditions of
* CANN Open Software License Agreement Version 2.0 (the "License").
* Please refer to the License for details. You may not use this file except in compliance with the License.
* THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED,
* INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE.
* See LICENSE in the root of the software repository for the full text of the License.
*/
/*!
* \file math_util.h
* \brief
*/
#ifndef TILING_MATMUL_MATH_UTIL_H
#define TILING_MATMUL_MATH_UTIL_H
#include <array>
#include <cstdint>
#include <vector>
#include <utility>
namespace matmul_tiling {
class MathUtil {
public:
static bool IsEqual(float leftValue, float rightValue);
template<typename T>
static auto CeilDivision(T num1, T num2) -> T
{
if (num2 == 0) {
return 0;
}
return static_cast<T>((static_cast<int64_t>(num1) + static_cast<int64_t>(num2) - 1) /
static_cast<int64_t>(num2));
}
template<typename T>
static auto Align(T num1, T num2) -> T
{
return CeilDivision(num1, num2) * num2;
}
static int32_t AlignDown(int32_t num1, int32_t num2);
static bool CheckMulOverflow(int32_t a, int32_t b, int32_t &c);
static int32_t MapShape(int32_t shape, bool roundUpFlag = true);
static void AddFactor(std::vector<int32_t> &dimsFactors, int32_t dim);
static void GetFactorCnt(const int32_t shape, int32_t &factorCnt, const int32_t factorStart,
const int32_t factorEnd);
static void GetFactorLayerCnt(const int32_t shape, int32_t &factorCnt, const int32_t factorStart,
const int32_t factorEnd);
static bool CheckFactorNumSatisfy(const int32_t dim);
static int32_t FindBestSingleCore(const int32_t oriShape, const int32_t mappedShape, const int32_t coreNum,
bool isKDim);
static void GetFactors(std::vector<int32_t> &factorList, int32_t srcNum, int32_t minFactor, int32_t maxFactor);
static void GetFactors(std::vector<int32_t> &factorList, int32_t srcNum, int32_t maxFactor);
static void GetBlockFactors(std::vector<int32_t> &factorList, const int32_t oriShape, const int32_t mpShape,
const int32_t coreNum, const int32_t maxNum);
static int32_t GetNonFactorMap(std::vector<int32_t> &factorList, int32_t srcNum, int32_t maxFactor);
static std::vector<std::pair<int, int>> GetFactorPairs(int32_t num);
static std::pair<int32_t, int32_t> DivideIntoMainAndTail(int32_t num, int32_t divisor);
};
} // namespace matmul_tiling
#endif // _MATH_UTIL_H_

View File

@@ -0,0 +1,31 @@
/**
* Copyright (c) 2025 Huawei Technologies Co., Ltd.
* This program is free software, you can redistribute it and/or modify it under the terms and conditions of
* CANN Open Software License Agreement Version 2.0 (the "License").
* Please refer to the License for details. You may not use this file except in compliance with the License.
* THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED,
* INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE.
* See LICENSE in the root of the software repository for the full text of the License.
*/
/*!
* \file tiling_util.cpp
* \brief
*/
#include "../tiling_base/tiling_util.h"
namespace Ops {
namespace Transformer {
namespace OpTiling {
static const gert::Shape g_vec_1_shape = {1};
const gert::Shape &EnsureNotScalar(const gert::Shape &inShape)
{
if (inShape.IsScalar()) {
return g_vec_1_shape;
}
return inShape;
}
} // namespace OpTiling
} // namespace Transformer
} // namespace Ops