[qwen3 next] add ascend c causal_conv1d_fn (#6661)

### What this PR does / why we need it?
add ascend c causal_conv1d_fn

- vLLM version: v0.15.0
- vLLM main:
13397841ab
---------
Signed-off-by: ZT-AIA <1028681969@qq.com>
Signed-off-by: ZT-AIA <63220130+ZT-AIA@users.noreply.github.com>
Co-authored-by: gemini-code-assist[bot] <176961590+gemini-code-assist[bot]@users.noreply.github.com>
This commit is contained in:
ZT-AIA
2026-03-09 23:29:49 +08:00
committed by GitHub
parent 48b624e4cc
commit ee5347e824
26 changed files with 2504 additions and 14 deletions

View File

@@ -24,7 +24,8 @@ elif [[ "$SOC_VERSION" =~ ^ascend910b ]]; then
ABSOLUTE_CATLASS_PATH=$(cd "${CATLASS_PATH}" && pwd)
export CPATH=${ABSOLUTE_CATLASS_PATH}:${CPATH}
CUSTOM_OPS="moe_grouped_matmul;grouped_matmul_swiglu_quant_weight_nz_tensor_list;lightning_indexer_vllm;sparse_flash_attention;matmul_allreduce_add_rmsnorm;moe_init_routing_custom;moe_gating_top_k;add_rms_norm_bias;apply_top_k_top_p_custom;transpose_kv_cache_by_block;"
CUSTOM_OPS="moe_grouped_matmul;grouped_matmul_swiglu_quant_weight_nz_tensor_list;lightning_indexer_vllm;sparse_flash_attention;matmul_allreduce_add_rmsnorm;moe_init_routing_custom;moe_gating_top_k;add_rms_norm_bias;apply_top_k_top_p_custom;transpose_kv_cache_by_block;causal_conv1d;"
SOC_ARG="ascend910b"
elif [[ "$SOC_VERSION" =~ ^ascend910_93 ]]; then
# ASCEND910C (A3) series
@@ -63,6 +64,7 @@ elif [[ "$SOC_VERSION" =~ ^ascend910_93 ]]; then
"add_rms_norm_bias"
"apply_top_k_top_p_custom"
"transpose_kv_cache_by_block"
"causal_conv1d"
"moe_grouped_matmul"
)
CUSTOM_OPS=$(IFS=';'; echo "${CUSTOM_OPS_ARRAY[*]}")

View File

@@ -0,0 +1,50 @@
# Copyright (c) 2025 Huawei Technologies Co., Ltd.
# This file is a part of the CANN Open Software.
# Licensed under CANN Open Software License Agreement Version 1.0 (the "License").
# Please refer to the License for details. You may not use this file except in compliance with the License.
# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED,
# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE.
# See LICENSE in the root of the software repository for the full text of the License.
# ======================================================================================================================
# Build configuration for the CausalConv1d custom op.
add_ops_compile_options(
    OP_NAME CausalConv1d
    OPTIONS --cce-auto-sync=off
            -Wno-deprecated-declarations
            -Werror
)

# Host-side op prototype registration.
target_sources(op_host_aclnn PRIVATE
    causal_conv1d_def.cpp
)

# aclnn adapter sources are not wired up for this op yet.
# target_sources(opapi PRIVATE
#     aclnn_causal_conv1d.cpp
# )
# if (NOT BUILD_OPEN_PROJECT)
#     target_sources(aclnn_ops_train PRIVATE
#         aclnn_causal_conv1d.cpp
#     )
#     target_sources(aclnn_ops_infer PRIVATE
#         aclnn_causal_conv1d.cpp
#     )
# endif ()

# Tiling implementation.
target_sources(optiling PRIVATE
    causal_conv1d_tiling.cpp
    tiling_util.cpp
)
target_include_directories(optiling PRIVATE
    ${CMAKE_CURRENT_SOURCE_DIR}
)

# No proto sources registered for this op.
target_sources(opsproto PRIVATE)

# Install the generated aclnn header if it exists (OPTIONAL tolerates absence).
# NOTE: variable renamed from the copy-pasted "_GMM_Aclnn_header" (grouped
# matmul) to reflect this op.
file(GLOB _CausalConv1d_Aclnn_header "${CMAKE_CURRENT_SOURCE_DIR}/aclnn_causal_conv1d.h")
install(FILES ${_CausalConv1d_Aclnn_header}
    DESTINATION ${ACLNN_INC_INSTALL_DIR} OPTIONAL
)

View File

@@ -0,0 +1,83 @@
/**
* This program is free software, you can redistribute it and/or modify it.
* Copyright (c) 2025 Huawei Technologies Co., Ltd.
* This file is a part of the CANN Open Software.
* Licensed under CANN Open Software License Agreement Version 2.0 (the "License").
* Please refer to the License for details. You may not use this file except in compliance with the License.
* THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, INCLUDING
* BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE.
* See LICENSE in the root of the software repository for the full text of the License.
*/
/*!
* \file causal_conv1d_def.cpp
* \brief
*/
#include "register/op_def_registry.h"
namespace ops {
// Op prototype registration for CausalConv1d.
// x / weight / bias / convStates / y are fp16 or bf16; queryStartLoc and
// cacheIndices are int32, hasInitialState is bool. Registered for the
// ascend910b and ascend910_93 AI cores.
class CausalConv1d : public OpDef {
public:
    explicit CausalConv1d(const char* name) : OpDef(name)
    {
        // Token activations.
        this->Input("x")
            .ParamType(REQUIRED)
            .DataType({ge::DT_FLOAT16, ge::DT_BF16})
            .FormatList({ge::FORMAT_ND})
            .AutoContiguous();
        // Convolution weights.
        this->Input("weight")
            .ParamType(REQUIRED)
            .DataType({ge::DT_FLOAT16, ge::DT_BF16})
            .FormatList({ge::FORMAT_ND})
            .AutoContiguous();
        // bias is the only OPTIONAL input.
        this->Input("bias")
            .ParamType(OPTIONAL)
            .DataType({ge::DT_FLOAT16, ge::DT_BF16})
            .FormatList({ge::FORMAT_ND})
            .AutoContiguous();
        // Per-cache-line convolution state (read and written by the kernel).
        this->Input("convStates")
            .ParamType(REQUIRED)
            .DataType({ge::DT_FLOAT16, ge::DT_BF16})
            .FormatList({ge::FORMAT_ND})
            .AutoContiguous();
        // Cumulative sequence start offsets.
        this->Input("queryStartLoc")
            .ParamType(REQUIRED)
            .DataTypeList({ge::DT_INT32})
            .FormatList({ge::FORMAT_ND})
            .AutoContiguous();
        // Maps each sequence to its state cache line.
        this->Input("cacheIndices")
            .ParamType(REQUIRED)
            .DataTypeList({ge::DT_INT32})
            .FormatList({ge::FORMAT_ND})
            .AutoContiguous();
        // Per-sequence flag: whether an initial state exists in the cache.
        this->Input("hasInitialState")
            .ParamType(REQUIRED)
            .DataTypeList({ge::DT_BOOL})
            .FormatList({ge::FORMAT_ND})
            .AutoContiguous();
        // Output has the same dtype/shape as x (see infershape).
        this->Output("y")
            .ParamType(REQUIRED)
            .DataType({ge::DT_FLOAT16, ge::DT_BF16})
            .FormatList({ge::FORMAT_ND})
            .AutoContiguous();
        // activationMode: 0 = none, 1 = activation enabled; padSlotId marks
        // padded slots (default -1).
        this->Attr("activationMode").AttrType(OPTIONAL).Int(0);
        this->Attr("padSlotId").AttrType(OPTIONAL).Int(-1);
        OpAICoreConfig aicoreConfig;
        aicoreConfig.DynamicCompileStaticFlag(true)
            .DynamicFormatFlag(false)
            .DynamicRankSupportFlag(true)
            .DynamicShapeSupportFlag(true)
            .NeedCheckSupportFlag(false)
            .PrecisionReduceFlag(true)
            .ExtendCfgInfo("coreType.value", "AiCore");
        this->AICore().AddConfig("ascend910b", aicoreConfig);
        this->AICore().AddConfig("ascend910_93", aicoreConfig);
    }
};
OP_ADD(CausalConv1d);
} // namespace ops

View File

@@ -0,0 +1,49 @@
/**
* This program is free software, you can redistribute it and/or modify it.
* Copyright (c) 2025 Huawei Technologies Co., Ltd.
* This file is a part of the CANN Open Software.
* Licensed under CANN Open Software License Agreement Version 2.0 (the "License").
* Please refer to the License for details. You may not use this file except in compliance with the License.
* THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, INCLUDING
* BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE.
* See LICENSE in the root of the software repository for the full text of the License.
*/
/*!
* \file causal_conv1d_infershape.cpp
* \brief
*/
#include "register/op_impl_registry.h"
#include "error_log.h"
using namespace ge;
namespace ops {
static constexpr int64_t IDX_0 = 0;
// y has exactly the shape of x: copy x's rank and dims onto the output.
static ge::graphStatus InferShapeCausalConv1d(gert::InferShapeContext* context)
{
    const gert::Shape* inShape = context->GetInputShape(IDX_0);
    OP_CHECK_NULL_WITH_CONTEXT(context, inShape);
    gert::Shape* outShape = context->GetOutputShape(IDX_0);
    OP_CHECK_NULL_WITH_CONTEXT(context, outShape);
    const size_t rank = inShape->GetDimNum();
    outShape->SetDimNum(rank);
    for (size_t axis = 0U; axis < rank; ++axis) {
        outShape->SetDim(axis, inShape->GetDim(axis));
    }
    return GRAPH_SUCCESS;
}
IMPL_OP_INFERSHAPE(CausalConv1d).InferShape(InferShapeCausalConv1d);
} // namespace ops

View File

@@ -0,0 +1,365 @@
/**
* This program is free software, you can redistribute it and/or modify it.
* Copyright (c) 2025 Huawei Technologies Co., Ltd.
* This file is a part of the CANN Open Software.
* Licensed under CANN Open Software License Agreement Version 2.0 (the "License").
* Please refer to the License for details. You may not use this file except in compliance with the License.
* THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, INCLUDING
* BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE.
* See LICENSE in the root of the software repository for the full text of the License.
*/
/*!
* \file causal_conv1d_tiling.cpp
* \brief
*/
// #include "error_log.h"
#include "log/ops_log.h"
#include "../tiling_base/tiling_templates_registry.h"
#include "../tiling_base/tiling_util.h"
#include "math_util.h"
#include "causal_conv1d_tiling.h"
#include "../op_kernel/causal_conv1d_tiling_key.h"
#include <set>
#include <limits>
namespace optiling {
using namespace Ops::Transformer::OpTiling;
constexpr uint32_t X_INDEX = 0;
constexpr uint32_t WEIGHT_INDEX = 1;
constexpr uint32_t BIAS_INDEX = 2;
constexpr uint32_t CONV_STATES_INDEX = 3;
constexpr uint32_t QUERY_START_LOC_INDEX = 4;
constexpr uint32_t CACHE_INDICES_INDEX = 5;
constexpr uint32_t HAS_INITIAL_STATE_INDEX = 6;
constexpr int32_t ATTR_ACTIVATION_MODE_INDEX = 0;
constexpr int32_t ATTR_PAD_SLOT_ID_INDEX = 1;
// Result of the channel-tile selection in ChooseDimTileSize.
struct DimTileChoice {
    int64_t dimTileSize = 0;   // channels handled per logical block
    int64_t blocksPerSeq = 0;  // dim / dimTileSize
    int64_t gridSize = 0;      // batch * blocksPerSeq
};
// Pick a channel tile size from a fixed candidate list. Only candidates that
// evenly divide `dim` are considered. Preference: among candidates whose grid
// (batch * blocksPerSeq) reaches at least `coreNum`, the one closest to
// coreNum wins; otherwise fall back to the candidate with the largest grid,
// breaking ties toward the smaller tile. `context` is currently unused.
static inline DimTileChoice ChooseDimTileSize(gert::TilingContext* context, int64_t batch, int64_t dim, uint32_t coreNum)
{
    const int64_t candidates[] = {4096, 2048, 1024, 512, 384};
    DimTileChoice bestOver;  // best candidate with gridSize >= coreNum
    int64_t bestOverGap = std::numeric_limits<int64_t>::max();
    DimTileChoice bestUnder;  // best candidate with gridSize < coreNum
    for (int64_t dimTileSize : candidates) {
        if (dim % dimTileSize != 0) {
            continue;
        }
        const int64_t blocksPerSeq = dim / dimTileSize;
        const int64_t gridSize = batch * blocksPerSeq;
        if (gridSize <= 0) {
            continue;
        }
        if (gridSize >= static_cast<int64_t>(coreNum)) {
            // Minimize the surplus over coreNum so blocks spread evenly.
            const int64_t gap = gridSize - static_cast<int64_t>(coreNum);
            if (gap < bestOverGap) {
                bestOver.dimTileSize = dimTileSize;
                bestOver.blocksPerSeq = blocksPerSeq;
                bestOver.gridSize = gridSize;
                bestOverGap = gap;
            }
        } else if (gridSize > bestUnder.gridSize ||
                   (gridSize == bestUnder.gridSize && dimTileSize < bestUnder.dimTileSize)) {
            bestUnder.dimTileSize = dimTileSize;
            bestUnder.blocksPerSeq = blocksPerSeq;
            bestUnder.gridSize = gridSize;
        }
    }
    // dimTileSize == 0 means no "over" candidate was found.
    DimTileChoice result = (bestOver.dimTileSize != 0) ? bestOver : bestUnder;
    return result;
}
// Fetch AIV core count and UB size, preferring the compile info cached by
// TilingParse; fall back to querying the platform directly.
static ge::graphStatus GetPlatformInfo(gert::TilingContext* context, uint64_t& ubSize, uint32_t& coreNum)
{
    const auto* cachedInfo = context->GetCompileInfo<CausalConv1dCompileInfo>();
    if (cachedInfo != nullptr && cachedInfo->coreNum != 0 && cachedInfo->ubSize != 0) {
        ubSize = cachedInfo->ubSize;
        coreNum = cachedInfo->coreNum;
        return ge::GRAPH_SUCCESS;
    }
    fe::PlatFormInfos* rawPlatformInfo = context->GetPlatformInfo();
    OP_CHECK_NULL_WITH_CONTEXT(context, rawPlatformInfo);
    auto platform = platform_ascendc::PlatformAscendC(rawPlatformInfo);
    coreNum = platform.GetCoreNumAiv();
    if (coreNum == 0) {
        return ge::GRAPH_FAILED;
    }
    platform.GetCoreMemSize(platform_ascendc::CoreMemType::UB, ubSize);
    if (ubSize == 0) {
        return ge::GRAPH_FAILED;
    }
    return ge::GRAPH_SUCCESS;
}
// CausalConv1d needs no device workspace; publish a size of zero.
static ge::graphStatus GetWorkspaceSize(gert::TilingContext* context)
{
    size_t* workspaceSizes = context->GetWorkspaceSizes(1);
    OP_CHECK_NULL_WITH_CONTEXT(context, workspaceSizes);
    workspaceSizes[0] = 0;
    return ge::GRAPH_SUCCESS;
}
// Read and validate the op attributes.
// activationMode must be 0 (none) or 1; padSlotId is passed through as-is.
static ge::graphStatus GetAttrsInfo(gert::TilingContext* context, int64_t& activationMode, int64_t& padSlotId)
{
    const auto* attrs = context->GetAttrs();
    OP_CHECK_NULL_WITH_CONTEXT(context, attrs);
    const int64_t* modePtr = attrs->GetAttrPointer<int64_t>(ATTR_ACTIVATION_MODE_INDEX);
    OP_CHECK_NULL_WITH_CONTEXT(context, modePtr);
    activationMode = *modePtr;
    const bool modeValid = (activationMode == 0) || (activationMode == 1);
    if (!modeValid) {
        return ge::GRAPH_FAILED;
    }
    const int64_t* slotPtr = attrs->GetAttrPointer<int64_t>(ATTR_PAD_SLOT_ID_INDEX);
    OP_CHECK_NULL_WITH_CONTEXT(context, slotPtr);
    padSlotId = *slotPtr;
    return ge::GRAPH_SUCCESS;
}
// Validate every input's shape and dtype and record the derived geometry in
// `tiling`. Two layouts are accepted for x:
//   2-D (cuSeqlen, dim)      -> varlen / continuous-batching mode (inputMode 0)
//   3-D (batch, seqLen, dim) -> padded batch mode (inputMode 1)
// Returns GRAPH_FAILED on the first violated constraint.
static ge::graphStatus GetShapeDtypeInfo(gert::TilingContext* context, CausalConv1dTilingData& tiling)
{
    auto xShapePtr = context->GetInputShape(X_INDEX);
    OP_CHECK_NULL_WITH_CONTEXT(context, xShapePtr);
    auto xShape = EnsureNotScalar(xShapePtr->GetStorageShape());
    int64_t dim = 0;
    int64_t cuSeqlen = 0;
    int64_t seqLen = 0;
    int64_t batch = 0;
    int64_t inputMode = 0;
    if (xShape.GetDimNum() == 2) {
        // Varlen mode: batch is derived later from queryStartLoc's size.
        inputMode = 0;
        cuSeqlen = xShape.GetDim(0);
        dim = xShape.GetDim(1);
        seqLen = 0;
        if (dim <= 0 || cuSeqlen < 0) {
            return ge::GRAPH_FAILED;
        }
    } else if (xShape.GetDimNum() == 3) {
        inputMode = 1;
        batch = xShape.GetDim(0);
        seqLen = xShape.GetDim(1);
        dim = xShape.GetDim(2);
        cuSeqlen = batch * seqLen;
        if (batch <= 0 || dim <= 0 || seqLen <= 0) {
            return ge::GRAPH_FAILED;
        }
    } else {
        return ge::GRAPH_FAILED;
    }
    // weight must be (width, dim); only width == 4 is supported.
    auto wShapePtr = context->GetInputShape(WEIGHT_INDEX);
    OP_CHECK_NULL_WITH_CONTEXT(context, wShapePtr);
    auto wShape = EnsureNotScalar(wShapePtr->GetStorageShape());
    if (wShape.GetDimNum() != 2) {
        return ge::GRAPH_FAILED;
    }
    const int64_t width = wShape.GetDim(0);
    const int64_t wDim = wShape.GetDim(1);
    if (wDim != dim) {
        return ge::GRAPH_FAILED;
    }
    if (width != 4) {
        return ge::GRAPH_FAILED;
    }
    // convStates must be (numCacheLines, stateLen, dim), with enough state
    // entries to hold the width-1 trailing elements of a sequence.
    auto sShapePtr = context->GetInputShape(CONV_STATES_INDEX);
    OP_CHECK_NULL_WITH_CONTEXT(context, sShapePtr);
    auto sShape = EnsureNotScalar(sShapePtr->GetStorageShape());
    if (sShape.GetDimNum() != 3) {
        return ge::GRAPH_FAILED;
    }
    const int64_t numCacheLines = sShape.GetDim(0);
    const int64_t stateLen = sShape.GetDim(1);
    const int64_t sDim = sShape.GetDim(2);
    if (numCacheLines <= 0) {
        return ge::GRAPH_FAILED;
    }
    if (sDim != dim) {
        return ge::GRAPH_FAILED;
    }
    if (stateLen < (width - 1)) {
        return ge::GRAPH_FAILED;
    }
    // queryStartLoc is 1-D with batch + 1 entries (cumulative offsets).
    auto qslShapePtr = context->GetInputShape(QUERY_START_LOC_INDEX);
    OP_CHECK_NULL_WITH_CONTEXT(context, qslShapePtr);
    auto qslShape = EnsureNotScalar(qslShapePtr->GetStorageShape());
    if (qslShape.GetDimNum() != 1) {
        return ge::GRAPH_FAILED;
    }
    const int64_t qslSize = qslShape.GetDim(0);
    if (qslSize < 1) {
        return ge::GRAPH_FAILED;
    }
    if (inputMode == 0) {
        // Varlen mode: batch is defined by queryStartLoc itself.
        batch = qslSize - 1;
    }
    if (inputMode == 1) {
        if (qslSize != batch + 1) {
            return ge::GRAPH_FAILED;
        }
    }
    // cacheIndices: one cache-line index per sequence.
    auto ciShapePtr = context->GetInputShape(CACHE_INDICES_INDEX);
    OP_CHECK_NULL_WITH_CONTEXT(context, ciShapePtr);
    auto ciShape = EnsureNotScalar(ciShapePtr->GetStorageShape());
    if (ciShape.GetDimNum() != 1) {
        return ge::GRAPH_FAILED;
    }
    if (ciShape.GetDim(0) != batch) {
        return ge::GRAPH_FAILED;
    }
    // hasInitialState: one flag per sequence.
    auto hisShapePtr = context->GetInputShape(HAS_INITIAL_STATE_INDEX);
    OP_CHECK_NULL_WITH_CONTEXT(context, hisShapePtr);
    auto hisShape = EnsureNotScalar(hisShapePtr->GetStorageShape());
    if (hisShape.GetDimNum() != 1) {
        return ge::GRAPH_FAILED;
    }
    if (hisShape.GetDim(0) != batch) {
        return ge::GRAPH_FAILED;
    }
    // bias is optional: if supplied it must be 1-D of length dim.
    tiling.set_hasBias(0);
    auto biasShapePtr = context->GetOptionalInputShape(BIAS_INDEX);
    if (biasShapePtr != nullptr && biasShapePtr->GetStorageShape().GetDimNum() != 0) {
        auto biasShape = EnsureNotScalar(biasShapePtr->GetStorageShape());
        if (biasShape.GetDimNum() != 1) {
            return ge::GRAPH_FAILED;
        }
        if (biasShape.GetDim(0) != dim) {
            return ge::GRAPH_FAILED;
        }
        tiling.set_hasBias(1);
    }
    // Dtype checks: x sets the floating dtype; weight/bias/convStates must
    // match it, index inputs must be int32, hasInitialState must be bool.
    const std::set<ge::DataType> supportedXDtype = {ge::DT_BF16, ge::DT_FLOAT16};
    auto xDesc = context->GetInputDesc(X_INDEX);
    OP_CHECK_NULL_WITH_CONTEXT(context, xDesc);
    const ge::DataType xDtype = xDesc->GetDataType();
    if (supportedXDtype.count(xDtype) == 0) {
        return ge::GRAPH_FAILED;
    }
    auto wDesc = context->GetInputDesc(WEIGHT_INDEX);
    OP_CHECK_NULL_WITH_CONTEXT(context, wDesc);
    if (wDesc->GetDataType() != xDtype) {
        return ge::GRAPH_FAILED;
    }
    if (tiling.get_hasBias() == 1) {
        auto biasDesc = context->GetOptionalInputDesc(BIAS_INDEX);
        OP_CHECK_NULL_WITH_CONTEXT(context, biasDesc);
        if (biasDesc->GetDataType() != xDtype) {
            return ge::GRAPH_FAILED;
        }
    }
    auto sDesc = context->GetInputDesc(CONV_STATES_INDEX);
    OP_CHECK_NULL_WITH_CONTEXT(context, sDesc);
    if (sDesc->GetDataType() != xDtype) {
        return ge::GRAPH_FAILED;
    }
    auto qslDesc = context->GetInputDesc(QUERY_START_LOC_INDEX);
    OP_CHECK_NULL_WITH_CONTEXT(context, qslDesc);
    if (qslDesc->GetDataType() != ge::DT_INT32) {
        return ge::GRAPH_FAILED;
    }
    auto ciDesc = context->GetInputDesc(CACHE_INDICES_INDEX);
    OP_CHECK_NULL_WITH_CONTEXT(context, ciDesc);
    if (ciDesc->GetDataType() != ge::DT_INT32) {
        return ge::GRAPH_FAILED;
    }
    auto hisDesc = context->GetInputDesc(HAS_INITIAL_STATE_INDEX);
    OP_CHECK_NULL_WITH_CONTEXT(context, hisDesc);
    if (hisDesc->GetDataType() != ge::DT_BOOL) {
        return ge::GRAPH_FAILED;
    }
    // All checks passed: publish the derived geometry.
    tiling.set_dim(dim);
    tiling.set_cuSeqlen(cuSeqlen);
    tiling.set_seqLen(seqLen);
    tiling.set_inputMode(inputMode);
    tiling.set_width(width);
    tiling.set_stateLen(stateLen);
    tiling.set_numCacheLines(numCacheLines);
    tiling.set_batch(batch);
    return ge::GRAPH_SUCCESS;
}
// Main tiling entry: query platform limits, validate inputs/attrs, choose the
// channel tile, and serialize CausalConv1dTilingData for the kernel.
static ge::graphStatus CausalConv1dTilingFunc(gert::TilingContext* context)
{
    uint64_t ubSize;
    uint32_t coreNum;
    if (GetPlatformInfo(context, ubSize, coreNum) != ge::GRAPH_SUCCESS) {
        return ge::GRAPH_FAILED;
    }
    if (GetWorkspaceSize(context) != ge::GRAPH_SUCCESS) {
        return ge::GRAPH_FAILED;
    }
    CausalConv1dTilingData tilingData;
    int64_t activationMode = 0;
    int64_t padSlotId = -1;
    if (GetAttrsInfo(context, activationMode, padSlotId) != ge::GRAPH_SUCCESS) {
        return ge::GRAPH_FAILED;
    }
    tilingData.set_activationMode(activationMode);
    tilingData.set_padSlotId(padSlotId);
    if (GetShapeDtypeInfo(context, tilingData) != ge::GRAPH_SUCCESS) {
        return ge::GRAPH_FAILED;
    }
    const int64_t dim = tilingData.get_dim();
    const int64_t batch = tilingData.get_batch();
    if (dim <= 0 || batch <= 0) {
        return ge::GRAPH_FAILED;
    }
    // One logical block per (sequence, channel-tile); blockDim is capped at
    // the physical core count.
    const DimTileChoice choice = ChooseDimTileSize(context, batch, dim, coreNum);
    const uint32_t blockDim = (choice.gridSize < static_cast<int64_t>(coreNum))
                                  ? static_cast<uint32_t>(choice.gridSize)
                                  : coreNum;
    context->SetBlockDim(blockDim);
    tilingData.set_dimTileSize(choice.dimTileSize);
    tilingData.set_blocksPerSeq(choice.blocksPerSeq);
    // Tiling key: default template schedule mode.
    const uint64_t tilingKey = GET_TPL_TILING_KEY(CAUSAL_CONV1D_TPL_SCH_MODE_DEFAULT);
    context->SetTilingKey(tilingKey);
    tilingData.SaveToBuffer(context->GetRawTilingData()->GetData(), context->GetRawTilingData()->GetCapacity());
    context->GetRawTilingData()->SetDataSize(tilingData.GetDataSize());
    return ge::GRAPH_SUCCESS;
}
// Runs once per compiled op: caches AIV core count and UB size in
// CausalConv1dCompileInfo so per-invocation tiling can skip the platform query.
static ge::graphStatus TilingParseForCausalConv1d(gert::TilingParseContext* context)
{
    auto platformInfoPtr = context->GetPlatformInfo();
    OP_CHECK_NULL_WITH_CONTEXT(context, platformInfoPtr);
    auto compileInfoPtr = context->GetCompiledInfo<CausalConv1dCompileInfo>();
    OP_CHECK_NULL_WITH_CONTEXT(context, compileInfoPtr);
    auto ascendcPlatform = platform_ascendc::PlatformAscendC(platformInfoPtr);
    compileInfoPtr->coreNum = static_cast<uint32_t>(ascendcPlatform.GetCoreNumAiv());
    if (compileInfoPtr->coreNum == 0) {
        return ge::GRAPH_FAILED;
    }
    ascendcPlatform.GetCoreMemSize(platform_ascendc::CoreMemType::UB, compileInfoPtr->ubSize);
    if (compileInfoPtr->ubSize == 0) {
        return ge::GRAPH_FAILED;
    }
    return ge::GRAPH_SUCCESS;
}
// Register both callbacks with the optiling framework.
IMPL_OP_OPTILING(CausalConv1d)
    .Tiling(CausalConv1dTilingFunc)
    .TilingParse<CausalConv1dCompileInfo>(TilingParseForCausalConv1d);
} // namespace optiling

View File

@@ -0,0 +1,60 @@
/**
* This program is free software, you can redistribute it and/or modify it.
* Copyright (c) 2025 Huawei Technologies Co., Ltd.
* This file is a part of the CANN Open Software.
* Licensed under CANN Open Software License Agreement Version 2.0 (the "License").
* Please refer to the License for details. You may not use this file except in compliance with the License.
* THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, INCLUDING
* BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE.
* See LICENSE in the root of the software repository for the full text of the License.
*/
/*!
* \file causal_conv1d_tiling_data.h
* \brief
*/
#ifndef ASCEND_OPS_CAUSAL_CONV1D_TILING_DATA_H
#define ASCEND_OPS_CAUSAL_CONV1D_TILING_DATA_H
#include <cstdint>
// #include "register/tilingdata_base.h"
// #include "tiling/tiling_api.h"
#include "register/tilingdata_base.h"
#include "error_log.h"
#include "register/op_impl_registry.h"
#include "tiling/platform/platform_ascendc.h"
#include "platform/platform_infos_def.h"
namespace optiling {
// Host-side tiling data serialized to the kernel. The kernel reads the same
// layout, so field order/types must not change without updating the kernel.
BEGIN_TILING_DATA_DEF(CausalConv1dTilingData)
    TILING_DATA_FIELD_DEF(int64_t, dim);            // channel count
    TILING_DATA_FIELD_DEF(int64_t, cuSeqlen);       // total tokens (2-D x) or batch*seqLen (3-D x)
    TILING_DATA_FIELD_DEF(int64_t, seqLen);         // per-sequence length (3-D mode; 0 in 2-D mode)
    TILING_DATA_FIELD_DEF(int64_t, inputMode);      // 0: 2-D varlen x, 1: 3-D batched x
    TILING_DATA_FIELD_DEF(int64_t, width);          // conv kernel width (validated to 4)
    TILING_DATA_FIELD_DEF(int64_t, stateLen);       // convStates dim-1 extent
    TILING_DATA_FIELD_DEF(int64_t, numCacheLines);  // convStates dim-0 extent
    TILING_DATA_FIELD_DEF(int64_t, batch);          // sequence count
    TILING_DATA_FIELD_DEF(int64_t, activationMode); // attr: 0 none, 1 activation enabled
    TILING_DATA_FIELD_DEF(int64_t, padSlotId);      // attr: pad slot marker (default -1)
    TILING_DATA_FIELD_DEF(int64_t, hasBias);        // 0/1: optional bias supplied
    TILING_DATA_FIELD_DEF(int64_t, dimTileSize);    // channels per logical block
    TILING_DATA_FIELD_DEF(int64_t, blocksPerSeq);   // dim / dimTileSize
END_TILING_DATA_DEF;
// Platform limits cached once by TilingParseForCausalConv1d.
struct CausalConv1dCompileInfo {
    uint64_t ubSize = 0;
    uint32_t coreNum = 0;
};
REGISTER_TILING_DATA_CLASS(CausalConv1d, CausalConv1dTilingData)
} // namespace optiling
#endif // ASCEND_OPS_CAUSAL_CONV1D_TILING_DATA_H

View File

@@ -0,0 +1,71 @@
#ifndef OPS_BUILT_IN_OP_TILING_ERROR_LOG_H_
#define OPS_BUILT_IN_OP_TILING_ERROR_LOG_H_
#include <string>
#include "toolchain/slog.h"
// Info/debug logs are compiled out in this standalone build.
#define OP_LOGI(opname, ...)
// BUGFIX: the original macros passed the caller's format string and args as
// *extra* arguments to a printf whose format was only "[LEVEL][%s] ", so the
// actual log message was silently discarded. Print the prefix first, then the
// caller-supplied message. The "" prefix keeps the expansion valid when only
// a literal format string is passed, and makes a missing format a compile error.
#define OP_LOGW(opname, ...) \
    do { \
        printf("[WARN][%s] ", (opname)); \
        printf("" __VA_ARGS__); \
        printf("\n"); \
    } while (0)
#define OP_LOGE_WITHOUT_REPORT(opname, ...) \
    do { \
        printf("[ERRORx][%s] ", (opname)); \
        printf("" __VA_ARGS__); \
        printf("\n"); \
    } while (0)
#define OP_LOGE(opname, ...) \
    do { \
        printf("[ERROR][%s] ", (opname)); \
        printf("" __VA_ARGS__); \
        printf("\n"); \
    } while (0)
#define OP_LOGD(opname, ...)
namespace optiling {
// Report a tiling error without going through the framework error-report path.
#define VECTOR_INNER_ERR_REPORT_TILIING(op_name, err_msg, ...) \
    do { \
        OP_LOGE_WITHOUT_REPORT(op_name, err_msg, ##__VA_ARGS__); \
    } while (0)
// If `cond` holds: log, then evaluate `expr` (typically a return statement).
#define OP_CHECK_IF(cond, log_func, expr) \
    do { \
        if (cond) { \
            log_func; \
            expr; \
        } \
    } while (0)
// Null-pointer guard used throughout tiling: logs the pointer name and
// returns GRAPH_FAILED from the enclosing function.
#define OP_CHECK_NULL_WITH_CONTEXT(context, ptr) \
    do { \
        if ((ptr) == nullptr) { \
            OP_LOGE(context->GetNodeType(), "%s is null", #ptr); \
            return ge::GRAPH_FAILED; \
        } \
    } while (0)
} // namespace optiling
// Round `a` up to the nearest multiple of `b`.
// Mirrors CeilDiv's guard below: when `b` is 0 the value is returned
// unchanged instead of dividing by zero (the original had no such guard).
template <typename T>
T CeilAlign(T a, T b)
{
    if (b == 0) {
        return a;
    }
    return (a + b - 1) / b * b;
}
// Integer ceiling division; returns `a` unchanged when `b` is zero.
template <typename T>
T CeilDiv(T a, T b)
{
    return (b == 0) ? a : (a + b - 1) / b;
}
#endif // OPS_BUILT_IN_OP_TILING_ERROR_LOG_H_

View File

@@ -0,0 +1,61 @@
/**
* Copyright (c) 2025 Huawei Technologies Co., Ltd.
* This program is free software, you can redistribute it and/or modify it under the terms and conditions of
* CANN Open Software License Agreement Version 2.0 (the "License").
* Please refer to the License for details. You may not use this file except in compliance with the License.
* THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED,
* INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE.
* See LICENSE in the root of the software repository for the full text of the License.
*/
/*!
* \file math_util.h
* \brief
*/
#ifndef TILING_MATMUL_MATH_UTIL_H
#define TILING_MATMUL_MATH_UTIL_H
#include <array>
#include <cstdint>
#include <vector>
#include <utility>
namespace matmul_tiling {
// Static math helpers shared by matmul-style tiling code. Only CeilDivision
// and Align are defined inline here; the remaining members are declarations
// implemented outside this header.
class MathUtil {
public:
    static bool IsEqual(float leftValue, float rightValue);
    // Ceiling division; returns 0 when num2 == 0. Intermediate arithmetic is
    // widened to int64_t so num1 + num2 - 1 cannot overflow narrower T.
    template<typename T>
    static auto CeilDivision(T num1, T num2) -> T
    {
        if (num2 == 0) {
            return 0;
        }
        return static_cast<T>((static_cast<int64_t>(num1) + static_cast<int64_t>(num2) - 1) /
                              static_cast<int64_t>(num2));
    }
    // Round num1 up to a multiple of num2 (0 when num2 == 0, via CeilDivision).
    template<typename T>
    static auto Align(T num1, T num2) -> T
    {
        return CeilDivision(num1, num2) * num2;
    }
    static int32_t AlignDown(int32_t num1, int32_t num2);
    static bool CheckMulOverflow(int32_t a, int32_t b, int32_t &c);
    static int32_t MapShape(int32_t shape, bool roundUpFlag = true);
    static void AddFactor(std::vector<int32_t> &dimsFactors, int32_t dim);
    static void GetFactorCnt(const int32_t shape, int32_t &factorCnt, const int32_t factorStart,
                             const int32_t factorEnd);
    static void GetFactorLayerCnt(const int32_t shape, int32_t &factorCnt, const int32_t factorStart,
                                  const int32_t factorEnd);
    static bool CheckFactorNumSatisfy(const int32_t dim);
    static int32_t FindBestSingleCore(const int32_t oriShape, const int32_t mappedShape, const int32_t coreNum,
                                      bool isKDim);
    static void GetFactors(std::vector<int32_t> &factorList, int32_t srcNum, int32_t minFactor, int32_t maxFactor);
    static void GetFactors(std::vector<int32_t> &factorList, int32_t srcNum, int32_t maxFactor);
    static void GetBlockFactors(std::vector<int32_t> &factorList, const int32_t oriShape, const int32_t mpShape,
                                const int32_t coreNum, const int32_t maxNum);
    static int32_t GetNonFactorMap(std::vector<int32_t> &factorList, int32_t srcNum, int32_t maxFactor);
    static std::vector<std::pair<int, int>> GetFactorPairs(int32_t num);
    static std::pair<int32_t, int32_t> DivideIntoMainAndTail(int32_t num, int32_t divisor);
};
} // namespace matmul_tiling
#endif // _MATH_UTIL_H_

View File

@@ -0,0 +1,31 @@
/**
* Copyright (c) 2025 Huawei Technologies Co., Ltd.
* This program is free software, you can redistribute it and/or modify it under the terms and conditions of
* CANN Open Software License Agreement Version 2.0 (the "License").
* Please refer to the License for details. You may not use this file except in compliance with the License.
* THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED,
* INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE.
* See LICENSE in the root of the software repository for the full text of the License.
*/
/*!
* \file tiling_util.cpp
* \brief
*/
#include "../tiling_base/tiling_util.h"
namespace Ops {
namespace Transformer {
namespace OpTiling {
static const gert::Shape g_vec_1_shape = {1};
// Normalize a scalar shape to a 1-D shape of a single element; any other
// shape is passed through untouched.
const gert::Shape &EnsureNotScalar(const gert::Shape &inShape)
{
    return inShape.IsScalar() ? g_vec_1_shape : inShape;
}
} // namespace OpTiling
} // namespace Transformer
} // namespace Ops

View File

@@ -0,0 +1,58 @@
/**
* This program is free software, you can redistribute it and/or modify it.
* Copyright (c) 2025 Huawei Technologies Co., Ltd.
* This file is a part of the CANN Open Software.
* Licensed under CANN Open Software License Agreement Version 2.0 (the "License").
* Please refer to the License for details. You may not use this file except in compliance with the License.
* THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, INCLUDING
* BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE.
* See LICENSE in the root of the software repository for the full text of the License.
*/
/*!
* \file causal_conv1d.cpp
* \brief
*/
#include "causal_conv1d.h"
namespace {
// Instantiate the CausalConv1d kernel object for element type T, bind the
// global-memory inputs/outputs, and run it.
template <typename T>
__aicore__ inline void RunCausalConv1d(GM_ADDR x, GM_ADDR weight, GM_ADDR bias, GM_ADDR convStates,
                                       GM_ADDR queryStartLoc, GM_ADDR cacheIndices, GM_ADDR hasInitialState,
                                       GM_ADDR y, const NsCausalConv1d::CausalConv1dTilingData* tilingData)
{
    NsCausalConv1d::CausalConv1d<T> op;
    op.Init(x, weight, bias, convStates, queryStartLoc, cacheIndices, hasInitialState, y, tilingData);
    op.Process();
}
} // namespace
// Kernel entry point. Parses the tiling data, then dispatches on the
// compile-time dtype of input x: ORIG_DTYPE_X when the build defines it,
// otherwise DTYPE_X. Both branches are otherwise identical.
template <uint32_t schMode>
__global__ __aicore__ void causal_conv1d(GM_ADDR x, GM_ADDR weight, GM_ADDR bias, GM_ADDR convStates,
                                         GM_ADDR queryStartLoc, GM_ADDR cacheIndices, GM_ADDR hasInitialState,
                                         GM_ADDR y, GM_ADDR workspace, GM_ADDR tiling)
{
    REGISTER_TILING_DEFAULT(NsCausalConv1d::CausalConv1dTilingData);
    GET_TILING_DATA(tilingData, tiling);
#if defined(ORIG_DTYPE_X)
#if (ORIG_DTYPE_X == DT_FLOAT16)
    RunCausalConv1d<half>(x, weight, bias, convStates, queryStartLoc, cacheIndices, hasInitialState, y, &tilingData);
#elif (ORIG_DTYPE_X == DT_BF16)
    RunCausalConv1d<bfloat16_t>(x, weight, bias, convStates, queryStartLoc, cacheIndices, hasInitialState, y, &tilingData);
#elif (ORIG_DTYPE_X == DT_FLOAT)
    RunCausalConv1d<float>(x, weight, bias, convStates, queryStartLoc, cacheIndices, hasInitialState, y, &tilingData);
#endif
#else
#if (DTYPE_X == DT_FLOAT16)
    RunCausalConv1d<half>(x, weight, bias, convStates, queryStartLoc, cacheIndices, hasInitialState, y, &tilingData);
#elif (DTYPE_X == DT_BF16)
    RunCausalConv1d<bfloat16_t>(x, weight, bias, convStates, queryStartLoc, cacheIndices, hasInitialState, y, &tilingData);
#elif (DTYPE_X == DT_FLOAT)
    RunCausalConv1d<float>(x, weight, bias, convStates, queryStartLoc, cacheIndices, hasInitialState, y, &tilingData);
#endif
#endif
}

View File

@@ -0,0 +1,436 @@
/**
* This program is free software, you can redistribute it and/or modify it.
* Copyright (c) 2025 Huawei Technologies Co., Ltd.
* This file is a part of the CANN Open Software.
* Licensed under CANN Open Software License Agreement Version 2.0 (the "License").
* Please refer to the License for details. You may not use this file except in compliance with the License.
* THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, INCLUDING
* BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE.
* See LICENSE in the root of the software repository for the full text of the License.
*/
/*!
* \file causal_conv1d.h
* \brief CausalConv1D (prefill/extend) AscendC kernel implementation.
*/
#ifndef CAUSAL_CONV1D_H
#define CAUSAL_CONV1D_H
#include "kernel_operator.h"
// #include "kernel_tiling/kernel_tiling.h"
#include "causal_conv1d_tiling_key.h"
#include "causal_conv1d_common.h"
// #define ENABLE_CAUSAL_CONV1D_DEBUG
// #ifdef ENABLE_CAUSAL_CONV1D_DEBUG
// #define CCONV_PRINTF(fmt, ...) printf(fmt, ##__VA_ARGS__)
// #else
// #define CCONV_PRINTF(fmt, ...)
// #endif
// #define CCONV_PRINT_IF(cond, fmt, ...) \
// do { \
// if (cond) { \
// CCONV_PRINTF(fmt, ##__VA_ARGS__); \
// } \
// } while (0)
// #ifdef ENABLE_CAUSAL_CONV1D_DEBUG
// #define CCONV_DUMP_TENSOR_IF(cond, tensor, size) \
// do { \
// if (cond) { \
// DumpTensor(tensor, __LINE__, size); \
// } \
// } while (0)
// #else
// Debug configuration constants. The CCONV_* debug print/dump macros above
// are commented out; with these defaults any debug path that consults them
// is disabled. (-1 / 0 / false appear to mean "no filter / off" — confirm
// against the kernel's debug code if it is re-enabled.)
constexpr int32_t CCONV_DBG_SEQ = -1;
constexpr int32_t CCONV_DBG_C0 = -1;
constexpr int32_t CCONV_DBG_MAX_TOKENS = 0;
constexpr int32_t CCONV_DBG_VERBOSE_TOKENS = 0;
constexpr int32_t CCONV_DBG_DUMP_SIZE = 0;
constexpr bool CCONV_DBG_PRINT_SYNC = false;
constexpr bool CCONV_DBG_DUMP_WEIGHTS = false;
constexpr bool CCONV_DBG_DUMP_BIAS = false;
constexpr bool CCONV_DBG_DUMP_INIT_RING = false;
constexpr bool CCONV_DBG_DUMP_RUNSEQ = false;
constexpr bool CCONV_DBG_DUMP_PREFETCH = false;
constexpr bool CCONV_DBG_DUMP_STATE = false;
// #define CCONV_DUMP_TENSOR_IF(cond, tensor, size) \
// do { \
// } while (0)
// #endif
using namespace AscendC;
namespace NsCausalConv1d {
using namespace NsCausalConv1dCommon;
#ifndef CAUSAL_CONV1D_TILING_DATA_H_
#define CAUSAL_CONV1D_TILING_DATA_H_
// Kernel-side view of the tiling data. Field names, types, and order mirror
// the host-side BEGIN_TILING_DATA_DEF(CausalConv1dTilingData) definition and
// must be kept in sync with it.
struct CausalConv1dTilingData {
    int64_t dim;            // channel count
    int64_t cuSeqlen;       // total tokens
    int64_t seqLen;         // per-sequence length (3-D mode; 0 in 2-D mode)
    int64_t inputMode;      // 0: 2-D varlen x, 1: 3-D batched x
    int64_t width;          // conv kernel width
    int64_t stateLen;       // convStates dim-1 extent
    int64_t numCacheLines;  // convStates dim-0 extent
    int64_t batch;
    // attrs
    int64_t activationMode; // 0: none, 1: silu/swish
    int64_t padSlotId;      // default -1
    // optional inputs
    int64_t hasBias;        // 0/1
    // Channel-wise tiling
    int64_t dimTileSize;    // channels per logical block
    int64_t blocksPerSeq;   // dim / dimTileSize
};
#endif // CAUSAL_CONV1D_TILING_DATA_H_
// CausalConv1d kernel, one instance per AI vector core. Per-sequence work is
// split into the phases declared below: load weights/bias, seed the state
// ring, run over the sequence, then write the final state back to the cache.
template <typename T>
class CausalConv1d
{
public:
    __aicore__ inline CausalConv1d() = default;
    // Binds GM tensors, allocates UB buffers and sync events.
    __aicore__ inline void Init(GM_ADDR x, GM_ADDR weight, GM_ADDR bias, GM_ADDR convStates, GM_ADDR queryStartLoc,
                                GM_ADDR cacheIndices, GM_ADDR hasInitialState, GM_ADDR y,
                                const CausalConv1dTilingData* tilingData);
    __aicore__ inline void Process();
private:
    // `dbg` enables debug dumps; the debug macros are commented out in this
    // build, so it is effectively inert.
    __aicore__ inline void LoadWeightAndBias(int32_t c0, int32_t dimTileSize, bool dbg);
    __aicore__ inline void InitRing(int32_t cacheIdx, bool hasInit, int32_t start, int32_t len,
                                    int32_t c0, int32_t dimTileSize, int32_t dim, bool dbg);
    __aicore__ inline void RunSeq(int32_t start, int32_t len, int32_t c0, int32_t dimTileSize, int32_t dim, bool dbg);
    __aicore__ inline void WriteBackState(int32_t cacheIdx, int32_t len, int32_t c0,
                                          int32_t dimTileSize, int32_t dim, bool dbg);
    __aicore__ inline void AllocEvents();
    __aicore__ inline void ReleaseEvents();
private:
    TPipe pipe;
    TBuf<QuePosition::VECIN> inBuf;    // RING_SLOTS input slots (see Init)
    TBuf<QuePosition::VECOUT> outBuf;  // 2 output slots (see Init)
    TBuf<QuePosition::VECCALC> calcBuf;
    // Manually managed cross-unit sync events; two-slot arrays pair with the
    // two output slots.
    TEventID tempVToMte2Event_;
    TEventID tempMte2ToVEvent_;
    TEventID inputMte2ToVEvent_;
    TEventID outMte3ToVEvent_[2];
    TEventID outVToMte3Event_[2];
    GlobalTensor<T> xGm;
    GlobalTensor<T> weightGm;
    GlobalTensor<T> biasGm;             // bound only when tiling reports hasBias
    GlobalTensor<T> convStatesGm;
    GlobalTensor<int32_t> queryStartLocGm;
    GlobalTensor<int32_t> cacheIndicesGm;
    GlobalTensor<bool> hasInitialStateGm;
    GlobalTensor<T> yGm;
    const CausalConv1dTilingData* tilingData_ {nullptr};
};
// Bind raw GM addresses to typed tensors and carve up the unified buffer.
// UB budget: inBuf = RING_SLOTS rows, outBuf = 2 rows (double buffer),
// calcBuf = (MAX_WIDTH + 3) fp32 rows (weights + bias + accumulator + scratch),
// each row MAX_BLOCK_DIM elements wide.
template <typename T>
__aicore__ inline void CausalConv1d<T>::Init(GM_ADDR x, GM_ADDR weight, GM_ADDR bias, GM_ADDR convStates,
                                             GM_ADDR queryStartLoc, GM_ADDR cacheIndices, GM_ADDR hasInitialState,
                                             GM_ADDR y
                                             , const CausalConv1dTilingData* tilingData)
{
    // REGISTER_TILING_DEFAULT(CausalConv1dTilingData);
    // auto tiling = (__gm__ CausalConv1dTilingData*)tilingGM;
    // GET_TILING_DATA(tilingData, tilingGM);
    tilingData_ = tilingData;
    xGm.SetGlobalBuffer(reinterpret_cast<__gm__ T*>(x));
    weightGm.SetGlobalBuffer(reinterpret_cast<__gm__ T*>(weight));
    // bias GM pointer may be invalid when hasBias == 0, so bind it lazily.
    if (tilingData_->hasBias != 0) {
        biasGm.SetGlobalBuffer(reinterpret_cast<__gm__ T*>(bias));
    }
    convStatesGm.SetGlobalBuffer(reinterpret_cast<__gm__ T*>(convStates));
    queryStartLocGm.SetGlobalBuffer(reinterpret_cast<__gm__ int32_t*>(queryStartLoc));
    cacheIndicesGm.SetGlobalBuffer(reinterpret_cast<__gm__ int32_t*>(cacheIndices));
    hasInitialStateGm.SetGlobalBuffer(reinterpret_cast<__gm__ bool*>(hasInitialState));
    yGm.SetGlobalBuffer(reinterpret_cast<__gm__ T*>(y));
    pipe.InitBuffer(inBuf, RING_SLOTS * MAX_BLOCK_DIM * sizeof(T));
    pipe.InitBuffer(outBuf, 2 * MAX_BLOCK_DIM * sizeof(T));
    pipe.InitBuffer(calcBuf, (MAX_WIDTH + 3) * MAX_BLOCK_DIM * sizeof(float));
    AllocEvents();
    // CCONV_PRINT_IF(GetBlockIdx() == 0U, "[Init] dim=%d, dimTileSize=%d, blocksPerSeq=%d, batch=%d\n",
    //               tilingData_->dim, tilingData_->dimTileSize, tilingData_->blocksPerSeq, tilingData_->batch);
    // CCONV_PRINT_IF(GetBlockIdx() == 0U, "[Init] hasBias=%d, activationMode=%d, stateLen=%d, inputMode=%d\n",
    //               tilingData_->hasBias, tilingData_->activationMode, tilingData_->stateLen, tilingData_->inputMode);
}
// Reserve the hardware event IDs used for cross-pipe ordering. Allocation
// order deliberately matches the release order in ReleaseEvents() so the
// IDs round-trip through the same per-type pools.
template <typename T>
__aicore__ inline void CausalConv1d<T>::AllocEvents()
{
    auto* tpipe = GetTPipePtr();
    tempVToMte2Event_ = tpipe->AllocEventID<HardEvent::V_MTE2>();
    tempMte2ToVEvent_ = tpipe->AllocEventID<HardEvent::MTE2_V>();
    inputMte2ToVEvent_ = tpipe->AllocEventID<HardEvent::MTE2_V>();
    for (int32_t slot = 0; slot < 2; ++slot) {
        outMte3ToVEvent_[slot] = tpipe->AllocEventID<HardEvent::MTE3_V>();
    }
    for (int32_t slot = 0; slot < 2; ++slot) {
        outVToMte3Event_[slot] = tpipe->AllocEventID<HardEvent::V_MTE3>();
    }
}
// Return every event ID reserved in AllocEvents() back to the TPipe pools,
// in the same order they were allocated.
template <typename T>
__aicore__ inline void CausalConv1d<T>::ReleaseEvents()
{
    auto* tpipe = GetTPipePtr();
    tpipe->ReleaseEventID<HardEvent::V_MTE2>(tempVToMte2Event_);
    tpipe->ReleaseEventID<HardEvent::MTE2_V>(tempMte2ToVEvent_);
    tpipe->ReleaseEventID<HardEvent::MTE2_V>(inputMte2ToVEvent_);
    for (int32_t slot = 0; slot < 2; ++slot) {
        tpipe->ReleaseEventID<HardEvent::MTE3_V>(outMte3ToVEvent_[slot]);
    }
    for (int32_t slot = 0; slot < 2; ++slot) {
        tpipe->ReleaseEventID<HardEvent::V_MTE3>(outVToMte3Event_[slot]);
    }
}
// Stage the channel tile [c0, c0 + dimTileSize) of every weight tap, plus the
// bias (or zeros), into fp32 rows of calcBuf. The outBuf tile is borrowed as a
// T-typed staging buffer, so this must run before any output is in flight for
// the task (Process() calls it first).
template <typename T>
__aicore__ inline void CausalConv1d<T>::LoadWeightAndBias(int32_t c0, int32_t dimTileSize, bool dbg)
{
    const int32_t dim = tilingData_->dim;
    const bool dbgSync = dbg && CCONV_DBG_PRINT_SYNC;
    (void)dbgSync;
    LocalTensor<float> calc = calcBuf.Get<float>();
    LocalTensor<float> weightF = calc;                         // rows 0..MAX_WIDTH-1: one row per tap
    LocalTensor<float> biasF = weightF[MAX_WIDTH * MAX_BLOCK_DIM]; // row MAX_WIDTH: bias
    LocalTensor<T> tempT = outBuf.Get<T>();                    // staging in native dtype before Cast
    // CCONV_PRINT_IF(dbg, "[LoadWeightAndBias] c0=%d, dimTileSize=%d\n", c0, dimTileSize);
    for (int32_t j = 0; j < MAX_WIDTH; ++j) {
        // weightGm is laid out [MAX_WIDTH, dim] row-major.
        const int64_t weightOffset = static_cast<int64_t>(j) * dim + c0;
        PipeBarrier<PIPE_ALL>();
        DataCopy(tempT, weightGm[weightOffset], dimTileSize);
        PipeBarrier<PIPE_ALL>();
        Cast(weightF[j * MAX_BLOCK_DIM], tempT, RoundMode::CAST_NONE, dimTileSize);
        PipeBarrier<PIPE_ALL>();
        // if (dbg && CCONV_DBG_DUMP_WEIGHTS) {
        //     CCONV_PRINTF("[Dump][weightF] j=%d\n", j);
        //     CCONV_DUMP_TENSOR_IF(true, weightF[j * MAX_BLOCK_DIM], CCONV_DBG_DUMP_SIZE);
        // }
    }
    if (tilingData_->hasBias != 0) {
        PipeBarrier<PIPE_ALL>();
        DataCopy(tempT, biasGm[c0], dimTileSize);
        PipeBarrier<PIPE_ALL>();
        Cast(biasF, tempT, RoundMode::CAST_NONE, dimTileSize);
        PipeBarrier<PIPE_ALL>();
        // if (dbg && CCONV_DBG_DUMP_BIAS) {
        //     CCONV_PRINTF("[Dump][biasF]\n");
        //     CCONV_DUMP_TENSOR_IF(true, biasF, CCONV_DBG_DUMP_SIZE);
        // }
    } else {
        // No bias tensor: accumulate from zero instead.
        Duplicate(biasF, 0.0f, dimTileSize);
        // CCONV_PRINT_IF(dbg, "[LoadWeightAndBias] bias=0 (no bias)\n");
    }
    PipeBarrier<PIPE_ALL>();
}
// Seed the ring buffer for one sequence/channel tile: history slots
// 0..MAX_WIDTH-2 come from the conv state cache (when hasInit) or zeros, and
// the first token of the sequence is loaded into SlotCurr(0).
template <typename T>
__aicore__ inline void CausalConv1d<T>::InitRing(int32_t cacheIdx, bool hasInit, int32_t start, int32_t len,
                                                 int32_t c0, int32_t dimTileSize, int32_t dim, bool dbg)
{
    const int32_t stateLen = tilingData_->stateLen;
    LocalTensor<T> ring = inBuf.Get<T>();
    PipeBarrier<PIPE_ALL>();
    if (hasInit) {
        // convStates is [cacheLines, stateLen, dim]; copy the first
        // MAX_WIDTH-1 state rows of this cache line into the history slots.
        for (int32_t i = 0; i < (MAX_WIDTH - 1); ++i) {
            const int64_t stateOffset = static_cast<int64_t>(cacheIdx) * stateLen * dim +
                                        static_cast<int64_t>(i) * dim + c0;
            DataCopy(ring[i * MAX_BLOCK_DIM], convStatesGm[stateOffset], dimTileSize);
        }
    } else {
        // No initial state: causal padding with zeros.
        for (int32_t i = 0; i < (MAX_WIDTH - 1); ++i) {
            Duplicate(ring[i * MAX_BLOCK_DIM], static_cast<T>(0), dimTileSize);
        }
    }
    PipeBarrier<PIPE_ALL>();
    if (len > 0) {
        // Preload token 0 so RunSeq's first iteration has its current slot filled.
        const int64_t xOffset = static_cast<int64_t>(start) * dim + c0;
        PipeBarrier<PIPE_ALL>();
        DataCopy(ring[SlotCurr(0) * MAX_BLOCK_DIM], xGm[xOffset], dimTileSize);
        PipeBarrier<PIPE_ALL>();
    }
}
// Token loop for one sequence/channel tile. For each token t: prefetch token
// t + 1 into the ring buffer, compute acc = bias + sum_j(weight[j] * window[j])
// in fp32, optionally apply SiLU, then cast/copy the result to yGm through a
// double-buffered staging tile.
// Cleanup vs. previous revision: removed locals that were never read
// (slotH1..slotH3 and the dbgTok/dbgVerbose/dbgStep debug flags) and two
// redundant back-to-back PipeBarrier<PIPE_ALL> calls (only scalar work, or
// nothing, happened between the pairs).
template <typename T>
__aicore__ inline void CausalConv1d<T>::RunSeq(int32_t start, int32_t len, int32_t c0, int32_t dimTileSize,
                                               int32_t dim, bool dbg)
{
    // fp32 scratch layout inside calcBuf: [weights | bias | accumulator | tmp].
    LocalTensor<float> calc = calcBuf.Get<float>();
    LocalTensor<float> weightF = calc;
    LocalTensor<float> biasF = weightF[MAX_WIDTH * MAX_BLOCK_DIM];
    LocalTensor<float> accF = biasF[MAX_BLOCK_DIM];
    LocalTensor<float> tmpF = accF[MAX_BLOCK_DIM];
    LocalTensor<T> ring = inBuf.Get<T>();
    LocalTensor<T> outT = outBuf.Get<T>();
    const bool dbgSync = dbg && CCONV_DBG_PRINT_SYNC;
    (void)dbgSync;
    const bool hasActivation = (tilingData_->activationMode != 0);
    for (int32_t t = 0; t < len; ++t) {
        const int32_t slotCurr = SlotCurr(t);
        const int32_t slotPref = (t + 1 < len) ? SlotPrefetch(t) : -1;
        const int32_t outSlot = t & 1; // alternate between the two staging rows
        // Prefetch the next token while the current one is being computed.
        if (t + 1 < len) {
            const int64_t xOffset = static_cast<int64_t>(start + t + 1) * dim + c0;
            PipeBarrier<PIPE_ALL>();
            DataCopy(ring[slotPref * MAX_BLOCK_DIM], xGm[xOffset], dimTileSize);
            PipeBarrier<PIPE_ALL>();
        }
        // acc = bias, then accumulate one tap per iteration (oldest first).
        DataCopy(accF, biasF, dimTileSize);
        for (int32_t j = 0; j < MAX_WIDTH; ++j) {
            const int32_t tap = (MAX_WIDTH - 1) - j;
            const int32_t slot = (tap == 0) ? slotCurr : SlotHist(t, tap);
            PipeBarrier<PIPE_ALL>();
            Cast(tmpF, ring[slot * MAX_BLOCK_DIM], RoundMode::CAST_NONE, dimTileSize);
            PipeBarrier<PIPE_ALL>();
            MulAddDst(accF, tmpF, weightF[j * MAX_BLOCK_DIM], dimTileSize);
            PipeBarrier<PIPE_ALL>();
        }
        if (hasActivation) {
            Silu(tmpF, accF, dimTileSize);
        }
        PipeBarrier<PIPE_ALL>();
        // Move the fp32 result into the native-dtype staging row.
        if constexpr (IsSameType<T, float>::value) {
            if (hasActivation) {
                DataCopy(outT[outSlot * MAX_BLOCK_DIM], tmpF, dimTileSize);
            } else {
                DataCopy(outT[outSlot * MAX_BLOCK_DIM], accF, dimTileSize);
            }
        } else {
            if (hasActivation) {
                Cast(outT[outSlot * MAX_BLOCK_DIM], tmpF, RoundMode::CAST_RINT, dimTileSize);
            } else {
                Cast(outT[outSlot * MAX_BLOCK_DIM], accF, RoundMode::CAST_RINT, dimTileSize);
            }
        }
        PipeBarrier<PIPE_ALL>();
        const int64_t outOffset = static_cast<int64_t>(start + t) * dim + c0;
        DataCopy(yGm[outOffset], outT[outSlot * MAX_BLOCK_DIM], dimTileSize);
        PipeBarrier<PIPE_ALL>();
    }
}
// Persist the trailing convolution window into the state cache so the next
// (decode) step can continue causally. State row pos holds the token that is
// (MAX_WIDTH - 2 - pos) steps behind the last processed token; row
// MAX_WIDTH-2 is the last token itself. When len < MAX_WIDTH - 1 the older
// slots still hold the seed written by InitRing, which is what gets copied.
template <typename T>
__aicore__ inline void CausalConv1d<T>::WriteBackState(int32_t cacheIdx, int32_t len, int32_t c0,
                                                       int32_t dimTileSize, int32_t dim, bool dbg)
{
    const int32_t stateLen = tilingData_->stateLen;
    if (len <= 0) {
        return;
    }
    const int32_t lastT = len - 1;
    LocalTensor<T> ring = inBuf.Get<T>();
    for (int32_t pos = 0; pos < (MAX_WIDTH - 1); ++pos) {
        const int32_t tap = (MAX_WIDTH - 2) - pos;
        const int32_t slot = (tap == 0) ? SlotCurr(lastT) : SlotHist(lastT, tap);
        const int64_t stateOffset = static_cast<int64_t>(cacheIdx) * stateLen * dim +
                                    static_cast<int64_t>(pos) * dim + c0;
        PipeBarrier<PIPE_ALL>();
        DataCopy(convStatesGm[stateOffset], ring[slot * MAX_BLOCK_DIM], dimTileSize);
        PipeBarrier<PIPE_ALL>();
    }
}
// Kernel entry: a grid of batch * blocksPerSeq tasks is strided across the
// vector cores; each task covers one sequence and one channel tile.
template <typename T>
__aicore__ inline void CausalConv1d<T>::Process()
{
    const int32_t dim = tilingData_->dim;
    const int32_t batch = tilingData_->batch;
    const int32_t inputMode = tilingData_->inputMode;
    const int32_t seqLen = tilingData_->seqLen;
    const int32_t dimTileSize = static_cast<int32_t>(tilingData_->dimTileSize);
    const int32_t blocksPerSeq = static_cast<int32_t>(tilingData_->blocksPerSeq);
    const uint32_t blockIdx = GetBlockIdx();
    const uint32_t blockNum = GetBlockNum();
    // Defensive tiling validation: the channel tiles must exactly cover dim
    // and fit the UB row width; otherwise bail out (releasing events).
    if (dimTileSize <= 0 || blocksPerSeq <= 0 || dimTileSize > MAX_BLOCK_DIM || blocksPerSeq * dimTileSize != dim) {
        ReleaseEvents();
        return;
    }
    const int64_t gridSize = static_cast<int64_t>(batch) * blocksPerSeq;
    for (int64_t task = static_cast<int64_t>(blockIdx); task < gridSize; task += static_cast<int64_t>(blockNum)) {
        const int32_t seq = static_cast<int32_t>(task / blocksPerSeq);
        const int32_t dimBlockId = static_cast<int32_t>(task % blocksPerSeq);
        const int32_t c0 = dimBlockId * dimTileSize; // first channel of this tile
        const bool dbg = (seq == CCONV_DBG_SEQ) && (c0 == CCONV_DBG_C0);
        LoadWeightAndBias(c0, dimTileSize, dbg);
        int32_t start = 0;
        int32_t len = 0;
        if (inputMode == 0) {
            // Varlen: token range comes from the query start-location prefix sums.
            const int32_t startVal = queryStartLocGm.GetValue(seq);
            const int32_t endVal = queryStartLocGm.GetValue(seq + 1);
            start = startVal;
            len = endVal - startVal;
        } else {
            // Fixed-length batch: every sequence is seqLen tokens.
            start = seq * seqLen;
            len = seqLen;
        }
        if (len <= 0) {
            continue;
        }
        const int32_t cacheIdx = cacheIndicesGm.GetValue(seq);
        // Padding slots carry no real sequence; skip them entirely.
        if (cacheIdx == tilingData_->padSlotId) {
            continue;
        }
        const bool hasInit = hasInitialStateGm.GetValue(seq);
        InitRing(cacheIdx, hasInit, start, len, c0, dimTileSize, dim, dbg);
        RunSeq(start, len, c0, dimTileSize, dim, dbg);
        WriteBackState(cacheIdx, len, c0, dimTileSize, dim, dbg);
    }
    ReleaseEvents();
}
} // namespace NsCausalConv1d
#endif // CAUSAL_CONV1D_H

View File

@@ -0,0 +1,45 @@
/**
* This program is free software, you can redistribute it and/or modify it.
* Copyright (c) 2025 Huawei Technologies Co., Ltd.
* This file is a part of the CANN Open Software.
* Licensed under CANN Open Software License Agreement Version 2.0 (the "License").
* Please refer to the License for details. You may not use this file except in compliance with the License.
* THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, INCLUDING
* BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE.
* See LICENSE in the root of the software repository for the full text of the License.
*/
/*!
* \file causal_conv1d_common.h
* \brief Common utilities and constants for CausalConv1D prefill kernel.
*/
#ifndef CAUSAL_CONV1D_COMMON_H
#define CAUSAL_CONV1D_COMMON_H
#include "kernel_operator.h"
namespace NsCausalConv1dCommon {
// Maximum supported convolution kernel width (number of taps).
constexpr int32_t MAX_WIDTH = 4;
// Maximum channels per tile a single vector core handles (UB row width).
constexpr int32_t MAX_BLOCK_DIM = 4096;
// Ring buffer slots: MAX_WIDTH - 1 history tokens, the current token and one
// prefetch slot must be live at the same time, so MAX_WIDTH + 1 slots total
// (previously the literals 3/4/5 encoded these relationships implicitly).
constexpr int32_t RING_SLOTS = MAX_WIDTH + 1;
// Slot holding the token processed at step t. History initially occupies
// slots [0, MAX_WIDTH - 2], so the first current token lands at MAX_WIDTH - 1.
__aicore__ inline int32_t SlotCurr(int32_t t)
{
    return (t + (MAX_WIDTH - 1)) % RING_SLOTS;
}
// Slot holding the token i steps behind the current one at step t.
__aicore__ inline int32_t SlotHist(int32_t t, int32_t i)
{
    return (t + (MAX_WIDTH - 1) - i) % RING_SLOTS;
}
// Slot into which the token for step t + 1 is prefetched.
__aicore__ inline int32_t SlotPrefetch(int32_t t)
{
    return (t + MAX_WIDTH) % RING_SLOTS;
}
} // namespace NsCausalConv1dCommon
#endif // CAUSAL_CONV1D_COMMON_H

View File

@@ -0,0 +1,34 @@
/**
* This program is free software, you can redistribute it and/or modify it.
* Copyright (c) 2025 Huawei Technologies Co., Ltd.
* This file is a part of the CANN Open Software.
* Licensed under CANN Open Software License Agreement Version 2.0 (the "License").
* Please refer to the License for details. You may not use this file except in compliance with the License.
* THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, INCLUDING
* BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE.
* See LICENSE in the root of the software repository for the full text of the License.
*/
/*!
* \file causal_conv1d_tiling_key.h
* \brief causal_conv1d tiling key declare
*/
#ifndef __CAUSAL_CONV1D_TILING_KEY_H__
#define __CAUSAL_CONV1D_TILING_KEY_H__
#include "ascendc/host_api/tiling/template_argument.h"
// Only one scheduling mode exists today; kept as a template argument so new
// schedules can be added without changing the tiling-key layout.
#define CAUSAL_CONV1D_TPL_SCH_MODE_DEFAULT 0
// Declare the kernel template arguments for CausalConv1d (schMode, 1 bit).
ASCENDC_TPL_ARGS_DECL(CausalConv1d,
    ASCENDC_TPL_UINT_DECL(
        schMode, 1, ASCENDC_TPL_UI_LIST, CAUSAL_CONV1D_TPL_SCH_MODE_DEFAULT)
);
// Select which template-argument combinations get compiled: default only.
ASCENDC_TPL_SEL(
    ASCENDC_TPL_ARGS_SEL(
        ASCENDC_TPL_UINT_SEL(
            schMode, ASCENDC_TPL_UI_LIST, CAUSAL_CONV1D_TPL_SCH_MODE_DEFAULT)));
#endif // __CAUSAL_CONV1D_TILING_KEY_H__

View File

@@ -0,0 +1,51 @@
/**
* This program is free software, you can redistribute it and/or modify.
* Copyright (c) 2025 Huawei Technologies Co., Ltd.
* This file is a part of the CANN Open Software.
* Licensed under CANN Open Software License Agreement Version 2.0 (the "License").
* Please refer to the License for details. You may not use this file except in compliance with the License.
* THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY, EITHER EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE.
* See LICENSE in the root of the software repository for the full text of the License.
*/
/*!
* \file data_copy_transpose_tiling.h
* \brief
*/
#pragma once
#include <vector>
#include <graph/tensor.h>
#include "data_copy_transpose_tiling_def.h"
namespace optiling {
// Populate CopyTransposeTiling from destination/source shapes, both
// interpreted as 4-D [B, N, S, H]. No rank or bounds checking is performed
// here; callers are expected to pass validated 4-D shapes.
inline void GetDataCopyTransposeTiling(const ge::Shape &dstShape, const ge::Shape &srcShape, const uint32_t typeSize,
                                       optiling::CopyTransposeTiling &tiling)
{
    constexpr int64_t DIM_B = 0;
    constexpr int64_t DIM_N = 1;
    constexpr int64_t DIM_S = 2;
    constexpr int64_t DIM_H = 3;
    const std::vector<int64_t> dstDims = dstShape.GetDims();
    const std::vector<int64_t> srcDims = srcShape.GetDims();
    // Raw destination dimensions plus the derived per-head size H / N.
    tiling.set_dstShapeB(dstDims[DIM_B]);
    tiling.set_dstShapeN(dstDims[DIM_N]);
    tiling.set_dstShapeS(dstDims[DIM_S]);
    tiling.set_dstShapeH(dstDims[DIM_H]);
    tiling.set_dstShapeHN(tiling.get_dstShapeH() / tiling.get_dstShapeN());
    // Raw source dimensions (the source H dim is stored as per-head "HN").
    tiling.set_srcShapeB(srcDims[DIM_B]);
    tiling.set_srcShapeN(srcDims[DIM_N]);
    tiling.set_srcShapeS(srcDims[DIM_S]);
    tiling.set_srcShapeHN(srcDims[DIM_H]);
    // Precomputed byte length and dimension products consumed by the kernel.
    // Derived values are read back through the getters on purpose, so they
    // inherit the uint32_t truncation of the stored fields.
    tiling.set_originalShapeNLen(tiling.get_srcShapeHN() * typeSize);
    tiling.set_shapeSHValue(tiling.get_dstShapeS() * tiling.get_dstShapeH());
    tiling.set_shapeNsValue(tiling.get_dstShapeN() * tiling.get_dstShapeS());
    tiling.set_shapeNsnValue(tiling.get_dstShapeN() * tiling.get_srcShapeS() * tiling.get_srcShapeN());
    tiling.set_shapeBHValue(tiling.get_dstShapeB() * tiling.get_dstShapeH());
}
} // namespace optiling

View File

@@ -0,0 +1,43 @@
/**
 * This program is free software, you can redistribute it and/or modify.
 * Copyright (c) 2025 Huawei Technologies Co., Ltd.
* This file is a part of the CANN Open Software.
* Licensed under CANN Open Software License Agreement Version 2.0 (the "License").
* Please refer to the License for details. You may not use this file except in compliance with the License.
* THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE.
* See LICENSE in the root of the software repository for the full text of the License.
*/
/*!
* \file data_copy_transpose_tiling_def.h
* \brief
*/
#pragma once
#include <cstdint>
#include <register/tilingdata_base.h>
namespace optiling {
// Tiling data for the DataCopyTranspose helper: destination/source [B, N, S, H]
// shape descriptors plus precomputed products consumed by the device kernel.
BEGIN_TILING_DATA_DEF(CopyTransposeTiling)
TILING_DATA_FIELD_DEF(uint32_t, dstShapeB);   // destination batch
TILING_DATA_FIELD_DEF(uint32_t, dstShapeN);   // destination head count
TILING_DATA_FIELD_DEF(uint32_t, dstShapeS);   // destination sequence length
TILING_DATA_FIELD_DEF(uint32_t, dstShapeHN);  // destination per-head size (H / N)
TILING_DATA_FIELD_DEF(uint32_t, dstShapeH);   // destination hidden size
TILING_DATA_FIELD_DEF(uint32_t, srcShapeB);   // source batch
TILING_DATA_FIELD_DEF(uint32_t, srcShapeN);   // source head count
TILING_DATA_FIELD_DEF(uint32_t, srcShapeS);   // source sequence length
TILING_DATA_FIELD_DEF(uint32_t, srcShapeHN);  // source per-head size
TILING_DATA_FIELD_DEF(uint32_t, originalShapeNLen);  // srcShapeHN * element size, in bytes
TILING_DATA_FIELD_DEF(uint32_t, shapeSHValue);       // dstS * dstH
TILING_DATA_FIELD_DEF(uint32_t, shapeNsValue);       // dstN * dstS
TILING_DATA_FIELD_DEF(uint32_t, shapeNsnValue);      // dstN * srcS * srcN
TILING_DATA_FIELD_DEF(uint32_t, invalidParamCopyTransposeTiling); // not set by GetDataCopyTransposeTiling — TODO confirm use
TILING_DATA_FIELD_DEF(uint32_t, shapeBHValue);       // dstB * dstH
TILING_DATA_FIELD_DEF(uint32_t, paramsAlign);        // not set by GetDataCopyTransposeTiling — TODO confirm use
END_TILING_DATA_DEF;
REGISTER_TILING_DATA_CLASS(CopyTransposeTilingOp, CopyTransposeTiling)
} // namespace optiling

View File

@@ -0,0 +1,56 @@
#ifndef OPS_BUILT_IN_OP_TILING_ERROR_LOG_H_
#define OPS_BUILT_IN_OP_TILING_ERROR_LOG_H_
#include <cstdio>
#include <string>
#include "toolchain/slog.h"
namespace optiling {
namespace log_detail {
// Callers pass the op name either as `const char*` or `std::string`;
// normalize to a C string so it is always safe to feed to printf's "%s"
// (a std::string passed through "..." to printf is undefined behavior).
inline const char* OpName(const char* name) { return name; }
inline const char* OpName(const std::string& name) { return name.c_str(); }
} // namespace log_detail
} // namespace optiling
#define OP_LOGI(opname, ...)
// BUGFIX: the previous macros forwarded the caller's format string and
// arguments into a printf whose format only consumed the op name, so the
// actual log message was silently discarded. Print the tag first, then the
// caller-supplied message. A message/format argument is now required.
#define OP_LOGW(opname, ...) \
    do { \
        printf("[WARN][%s] ", ::optiling::log_detail::OpName(opname)); \
        printf(__VA_ARGS__); \
        printf("\n"); \
    } while (0)
#define OP_LOGE_WITHOUT_REPORT(opname, ...) \
    do { \
        printf("[ERRORx][%s] ", ::optiling::log_detail::OpName(opname)); \
        printf(__VA_ARGS__); \
        printf("\n"); \
    } while (0)
#define OP_LOGE(opname, ...) \
    do { \
        printf("[ERROR][%s] ", ::optiling::log_detail::OpName(opname)); \
        printf(__VA_ARGS__); \
        printf("\n"); \
    } while (0)
#define OP_LOGD(opname, ...)
namespace optiling {
// Error report used inside tiling code; currently just an error log.
#define VECTOR_INNER_ERR_REPORT_TILIING(op_name, err_msg, ...) \
    do { \
        OP_LOGE_WITHOUT_REPORT(op_name, err_msg, ##__VA_ARGS__); \
    } while (0)
// Evaluate `cond`; when true, run the logging statement and then the recovery
// expression (typically a `return`).
#define OP_CHECK_IF(cond, log_func, expr) \
    do { \
        if (cond) { \
            log_func; \
            expr; \
        } \
    } while (0)
// Null-check a pointer obtained from a tiling context; fails the graph op.
#define OP_CHECK_NULL_WITH_CONTEXT(context, ptr) \
    do { \
        if ((ptr) == nullptr) { \
            OP_LOGE(context->GetNodeType(), "%s is null", #ptr); \
            return ge::GRAPH_FAILED; \
        } \
    } while (0)
} // namespace optiling
#endif // OPS_BUILT_IN_OP_TILING_ERROR_LOG_H_

View File

@@ -0,0 +1,256 @@
/**
* This program is free software, you can redistribute it and/or modify.
* Copyright (c) 2025 Huawei Technologies Co., Ltd.
* This file is a part of the CANN Open Software.
* Licensed under CANN Open Software License Agreement Version 2.0 (the "License").
* Please refer to the License for details. You may not use this file except in compliance with the License.
* THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY, EITHER EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE.
* See LICENSE in the root of the software repository for the full text of the License.
*/
/*!
* \file tiling_base.h
* \brief
*/
#pragma once
#include <sstream>
#include <exe_graph/runtime/tiling_context.h>
#include <graph/utils/type_utils.h>
#include "tiling/platform/platform_ascendc.h"
#include "error_log.h"
#ifdef ASCENDC_OP_TEST
#define ASCENDC_EXTERN_C extern "C"
#else
#define ASCENDC_EXTERN_C
#endif
namespace Ops {
namespace Transformer {
namespace OpTiling {
// Per-core hardware resource sizes queried from the platform, consumed by
// concrete tiling implementations.
struct AiCoreParams {
    uint64_t ubSize = 0;    // unified buffer size in bytes
    uint64_t blockDim = 0;  // number of schedulable blocks
    uint64_t aicNum = 0;    // number of AI-Core (cube) units
    uint64_t l1Size = 0;    // L1 buffer size in bytes
    uint64_t l0aSize = 0;   // L0A buffer size in bytes
    uint64_t l0bSize = 0;   // L0B buffer size in bytes
    uint64_t l0cSize = 0;   // L0C buffer size in bytes
};
// Compile-time platform info cached when no PlatFormInfos is available at
// tiling time. `socVersion` is stored as a raw int32_t (cast of
// platform_ascendc::SocVersion).
struct CompileInfoCommon {
    uint32_t aivNum;      // AI-Vector core count
    uint32_t aicNum;      // AI-Cube core count
    uint64_t ubSize;      // unified buffer bytes
    uint64_t l1Size;      // L1 bytes
    uint64_t l0aSize;     // L0A bytes
    uint64_t l0bSize;     // L0B bytes
    uint64_t l0cSize;     // L0C bytes
    uint64_t l2CacheSize; // L2 cache bytes
    int64_t coreNum;      // total core count
    int32_t socVersion;   // SocVersion as integer
    uint32_t rsvd;        // padding/reserved
};
// Same layout as CompileInfoCommon but with a strongly-typed SocVersion;
// used by the FlashAttentionScoreGrad tiling path.
struct FlashAttentionScoreGradCompileInfo {
    uint32_t aivNum;      // AI-Vector core count
    uint32_t aicNum;      // AI-Cube core count
    uint64_t ubSize;      // unified buffer bytes
    uint64_t l1Size;      // L1 bytes
    uint64_t l0aSize;     // L0A bytes
    uint64_t l0bSize;     // L0B bytes
    uint64_t l0cSize;     // L0C bytes
    uint64_t l2CacheSize; // L2 cache bytes
    int64_t coreNum;      // total core count
    platform_ascendc::SocVersion socVersion; // typed SoC version
};
// NOTE(review): field-for-field identical to CompileInfoCommon above;
// presumably kept separate for ABI/naming reasons — consider unifying.
struct FACompileInfoCommon {
    uint32_t aivNum;      // AI-Vector core count
    uint32_t aicNum;      // AI-Cube core count
    uint64_t ubSize;      // unified buffer bytes
    uint64_t l1Size;      // L1 bytes
    uint64_t l0aSize;     // L0A bytes
    uint64_t l0bSize;     // L0B bytes
    uint64_t l0cSize;     // L0C bytes
    uint64_t l2CacheSize; // L2 cache bytes
    int64_t coreNum;      // total core count
    int32_t socVersion;   // SocVersion as integer
    uint32_t rsvd;        // padding/reserved
};
// Abstract base for op tiling implementations. DoTiling() drives the fixed
// pipeline (shape/attrs -> platform -> capability -> op tiling -> lib tiling
// -> workspace -> post -> key/dump); subclasses implement the numbered hooks.
class TilingBaseClass {
public:
    explicit TilingBaseClass(gert::TilingContext* context) : context_(context)
    {}
    virtual ~TilingBaseClass() = default;
    // Tiling execution framework
    // 1. GRAPH_SUCCESS: Success, and no need to continue executing subsequent Tiling class implementations
    // 2. GRAPH_FAILED: Failure, abort the entire Tiling process
    // 3. GRAPH_PARAM_INVALID: This class does not support, need to continue executing other Tiling class implementations
    ge::graphStatus DoTiling()
    {
        auto ret = GetShapeAttrsInfo();
        if (ret != ge::GRAPH_SUCCESS) {
            return ret;
        }
        ret = GetPlatformInfo();
        if (ret != ge::GRAPH_SUCCESS) {
            return ret;
        }
        // IsCapable() is checked only after shapes/platform are known.
        if (!IsCapable()) {
            return ge::GRAPH_PARAM_INVALID;
        }
        ret = DoOpTiling();
        if (ret != ge::GRAPH_SUCCESS) {
            return ret;
        }
        ret = DoLibApiTiling();
        if (ret != ge::GRAPH_SUCCESS) {
            return ret;
        }
        ret = GetWorkspaceSize();
        if (ret != ge::GRAPH_SUCCESS) {
            return ret;
        }
        ret = PostTiling();
        if (ret != ge::GRAPH_SUCCESS) {
            return ret;
        }
        context_->SetTilingKey(GetTilingKey());
        DumpTilingInfo();
        return ge::GRAPH_SUCCESS;
    }
    // Update context
    virtual void Reset(gert::TilingContext* context)
    {
        context_ = context;
    }
protected:
    virtual bool IsCapable() = 0;
    // 1. Get platform information such as CoreNum, UB/L1/L0C resource sizes
    virtual ge::graphStatus GetPlatformInfo() = 0;
    // 2. Get INPUT/OUTPUT/ATTR information
    virtual ge::graphStatus GetShapeAttrsInfo() = 0;
    // 3. Calculate data splitting TilingData
    virtual ge::graphStatus DoOpTiling() = 0;
    // 4. Calculate high-level API TilingData
    virtual ge::graphStatus DoLibApiTiling() = 0;
    // 5. Calculate TilingKey
    [[nodiscard]] virtual uint64_t GetTilingKey() const = 0;
    // 6. Calculate Workspace size
    virtual ge::graphStatus GetWorkspaceSize() = 0;
    // 7. Save Tiling data
    virtual ge::graphStatus PostTiling() = 0;
    // 8. Dump Tiling data
    // Dumps raw tiling words at debug log level. CheckLogLevel/OP/DLOG_DEBUG
    // come from toolchain/slog.h (via error_log.h).
    virtual void DumpTilingInfo()
    {
        int32_t enable = CheckLogLevel(static_cast<int32_t>(OP), DLOG_DEBUG);
        if (enable != 1) {
            return;
        }
        auto buf = (uint32_t*)context_->GetRawTilingData()->GetData();
        auto bufLen = context_->GetRawTilingData()->GetDataSize();
        std::ostringstream oss;
        oss << "Start to dump tiling info. tilingkey:" << context_->GetTilingKey() << ", tiling data size:" << bufLen
            << ", content:";
        for (size_t i = 0; i < bufLen / sizeof(uint32_t); i++) {
            oss << *(buf + i) << ",";
            if (oss.str().length() > 640) { // Split according to 640 to avoid truncation
                OP_LOGD(context_, "%s", oss.str().c_str());
                oss.str("");
            }
        }
        OP_LOGD(context_, "%s", oss.str().c_str());
    }
    // Convert an AIV slice count into a scheduler block dim, given the
    // AIV-per-AIC ratio (rounding up). Falls back to sliceNum on bad input.
    static uint32_t CalcTschBlockDim(uint32_t sliceNum, uint32_t aicCoreNum, uint32_t aivCoreNum)
    {
        uint32_t ration;
        if (aicCoreNum == 0 || aivCoreNum == 0 || aicCoreNum > aivCoreNum) {
            return sliceNum;
        }
        ration = aivCoreNum / aicCoreNum;
        return (sliceNum + (ration - 1)) / ration;
    }
    // Format a shape as "[d0, d1, ...]" for logging.
    template <typename T>
    [[nodiscard]] std::string GetShapeDebugStr(const T& shape) const
    {
        std::ostringstream oss;
        oss << "[";
        if (shape.GetDimNum() > 0) {
            for (size_t i = 0; i < shape.GetDimNum() - 1; ++i) {
                oss << shape.GetDim(i) << ", ";
            }
            oss << shape.GetDim(shape.GetDimNum() - 1);
        }
        oss << "]";
        return oss.str();
    }
    // Format one tensor's dtype/shape/format for logging; tolerates nulls.
    [[nodiscard]] std::string GetTensorDebugStr(
        const gert::StorageShape* shape, const gert::CompileTimeTensorDesc* tensor)
    {
        if (shape == nullptr || tensor == nullptr) {
            return "nil ";
        }
        std::ostringstream oss;
        oss << "(dtype: " << ge::TypeUtils::DataTypeToSerialString(tensor->GetDataType()) << "),";
        oss << "(shape:" << GetShapeDebugStr(shape->GetStorageShape()) << "),";
        oss << "(ori_shape:" << GetShapeDebugStr(shape->GetOriginShape()) << "),";
        oss << "(format: "
            << ge::TypeUtils::FormatToSerialString(
                   static_cast<ge::Format>(ge::GetPrimaryFormat(tensor->GetStorageFormat())))
            << "),";
        oss << "(ori_format: " << ge::TypeUtils::FormatToSerialString(tensor->GetOriginFormat()) << ") ";
        return oss.str();
    }
    // Format every input/output tensor of the current node for logging.
    [[nodiscard]] std::string GetTilingContextDebugStr()
    {
        std::ostringstream oss;
        for (size_t i = 0; i < context_->GetComputeNodeInfo()->GetInputsNum(); ++i) {
            oss << "input" << i << ": ";
            oss << GetTensorDebugStr(context_->GetInputShape(i), context_->GetInputDesc(i));
        }
        for (size_t i = 0; i < context_->GetComputeNodeInfo()->GetOutputsNum(); ++i) {
            oss << "output" << i << ": ";
            oss << GetTensorDebugStr(context_->GetOutputShape(i), context_->GetOutputDesc(i));
        }
        return oss.str();
    }
    // Format the raw tiling buffer as comma-separated int32 words.
    [[nodiscard]] std::string GetTilingDataDebugStr() const
    {
        auto rawTilingData = context_->GetRawTilingData();
        auto rawTilingDataSize = rawTilingData->GetDataSize();
        auto data = reinterpret_cast<const int32_t*>(rawTilingData->GetData());
        size_t len = rawTilingDataSize / sizeof(int32_t);
        std::ostringstream oss;
        for (size_t i = 0; i < len; i++) {
            oss << data[i] << ", ";
        }
        return oss.str();
    }
protected:
    gert::TilingContext* context_ = nullptr;                  // current tiling context (non-owning)
    std::unique_ptr<platform_ascendc::PlatformAscendC> ascendcPlatform_{nullptr};
    uint32_t blockDim_{0};      // chosen block dim
    uint64_t workspaceSize_{0}; // workspace bytes requested
    uint64_t tilingKey_{0};     // cached tiling key
    AiCoreParams aicoreParams_; // queried hardware sizes
};
} // namespace OpTiling
} // namespace Transformer
} // namespace Ops

View File

@@ -0,0 +1,63 @@
/**
 * This program is free software, you can redistribute it and/or modify.
 * Copyright (c) 2025 Huawei Technologies Co., Ltd.
* This file is a part of the CANN Open Software.
* Licensed under CANN Open Software License Agreement Version 2.0 (the "License").
* Please refer to the License for details. You may not use this file except in compliance with the License.
* THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE.
* See LICENSE in the root of the software repository for the full text of the License.
*/
/*!
* \file tiling_key.h
* \brief
*/
#pragma once
#include <cstdint>
namespace Ops {
namespace Transformer {
namespace OpTiling {
// Base case of the variadic digit folding: no more template ids.
constexpr uint64_t RecursiveSum()
{
    return 0;
}
constexpr uint64_t kBase = 10; // Base-10 carry base
// Fold template ids into a decimal number, first argument in the lowest
// digit: RecursiveSum(a, b, c) == a + 10*b + 100*c. Each id must be < 10
// or it carries into the next digit.
template <typename T, typename... Args> constexpr uint64_t RecursiveSum(T templateId, Args... templateIds)
{
    return static_cast<uint64_t>(templateId) + kBase * RecursiveSum(templateIds...);
}
// TilingKey generation rules:
// FlashAttentionScore/FlashAttentionScoreGrad assembles tiling key using decimal digits, containing the following key parameters from low to high: Ub0, Ub1,
// Block, DataType, Format, Sparse. Specialized template Ub0, Ub1:
// Represents the axis for UB intra-core splitting, using AxisEnum. Since we allow at most two axes to be split, UB0 and UB1 exist. If there is no UB intra-core splitting,
// fill with AXIS_NONE. UB0 and UB1 each occupy one decimal digit;
// Block: Represents the axis used by UB for multi-core splitting, using AxisEnum, occupies one decimal digit;
// DataType: Represents the input/output data types supported by the current tiling key, using SupportedDtype enum, occupies one decimal digit
// Format: Represents the Format supported by the current tiling key, using InputLayout enum, occupies one decimal digit
// Sparse: Represents whether the current tiling key supports Sparse, using SparseCapability enum, occupies one decimal digit
// For other specialized scenarios, define your own bit fields and values
// usage: get tilingKey from inputted types
// uint64_t tilingKey = GET_FLASHATTENTION_TILINGKEY(AxisEnum::AXIS_S1, AxisEnum::AXIS_S2, AxisEnum::AXIS_N2,
//                                                   SupportedDtype::FLOAT32, InputLayout::BSH, SparseCapability::SUPPORT_ALL)
// Offset keeps every generated key 20 digits long so digit positions are stable.
constexpr uint64_t TILINGKEYOFFSET = uint64_t(10000000000000000000UL); // 10^19
template <typename... Args> constexpr uint64_t GET_TILINGKEY(Args... templateIds)
{
    return TILINGKEYOFFSET + RecursiveSum(templateIds...);
}
// usage: get tilingKey from inputted types
// uint64_t tilingKey = TILINGKEY(S2, S1, N2, FLOAT32, BSND, ALL)
// NOTE: AxisEnum/DtypeEnum/LayoutEnum/SparseEnum must be in scope at the
// macro's expansion site; they are not defined in this header.
#define TILINGKEY(ub2, ub1, block, dtype, layout, sparse)                                                      \
    (GET_TILINGKEY(AxisEnum::ub2, AxisEnum::ub1, AxisEnum::block, DtypeEnum::dtype, LayoutEnum::layout,        \
                   SparseEnum::sparse))
} // namespace OpTiling
} // namespace Transformer
} // namespace Ops

View File

@@ -0,0 +1,351 @@
/**
 * This program is free software, you can redistribute it and/or modify.
 * Copyright (c) 2025 Huawei Technologies Co., Ltd.
* This file is a part of the CANN Open Software.
* Licensed under CANN Open Software License Agreement Version 2.0 (the "License").
* Please refer to the License for details. You may not use this file except in compliance with the License.
* THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE.
* See LICENSE in the root of the software repository for the full text of the License.
*/
/*!
* \file tiling_templates_registry.h
* \brief
*/
#pragma once
#include <map>
#include <string>
#include <memory>
#include "exe_graph/runtime/tiling_context.h"
#include "tiling_base.h"
#include "error_log.h"
namespace Ops {
namespace Transformer {
namespace OpTiling {
// Factory that type-erases a concrete tiling class T into its base pointer.
template <typename T>
std::unique_ptr<TilingBaseClass> TILING_CLASS(gert::TilingContext* context)
{
    // new (std::nothrow): allocation failure surfaces as nullptr, not bad_alloc.
    return std::unique_ptr<T>(new (std::nothrow) T(context));
}
// Signature of the factory function stored in the per-priority map.
using TilingClassCase = std::unique_ptr<TilingBaseClass> (*)(gert::TilingContext*);
// Priority-ordered set of tiling template factories for one op type.
// Lower priority values are tried first by the registry.
class TilingCases {
public:
    explicit TilingCases(std::string op_type) : op_type_(std::move(op_type))
    {}
    // Register tiling class T at the given (unique) priority.
    template <typename T>
    void AddTiling(int32_t priority)
    {
        OP_CHECK_IF(
            cases_.find(priority) != cases_.end(), OP_LOGE(op_type_, "There are duplicate registrations."), return);
        cases_[priority] = TILING_CLASS<T>;
        // NOTE(review): TILING_CLASS<T> is a function pointer, so this check
        // can never fire; kept for defensive symmetry with the registry.
        OP_CHECK_IF(
            cases_[priority] == nullptr,
            OP_LOGE(op_type_, "Register op tiling func failed, please check the class name."), return);
    }
    // Factories keyed by ascending priority.
    const std::map<int32_t, TilingClassCase>& GetTilingCases()
    {
        return cases_;
    }
private:
    std::map<int32_t, TilingClassCase> cases_;
    const std::string op_type_;
};
// -------------------------------- Interface with soc version --------------------------------
// Registry of tiling templates keyed by (soc_version, op_type, priority).
// The class definition continues below this point.
class TilingRegistryNew {
public:
    TilingRegistryNew() = default;
#ifdef ASCENDC_OP_TEST
    // Unit tests provide their own definition so the singleton can be mocked.
    static TilingRegistryNew& GetInstance();
#else
    // Meyers singleton: thread-safe lazy initialization (guaranteed since C++11).
    static TilingRegistryNew& GetInstance()
    {
        static TilingRegistryNew registry_impl_;
        return registry_impl_;
    }
#endif
    // Return (creating on first use) the TilingCases bucket for the given
    // (soc_version, op_type) pair. Never throws; returns nullptr only if the
    // bucket allocation failed.
    std::shared_ptr<TilingCases> RegisterOp(const std::string& op_type, int32_t soc_version)
    {
        auto soc_iter = registry_map_.find(soc_version);
        if (soc_iter == registry_map_.end()) {
            // First op for this SoC: create the per-SoC map as well.
            std::map<std::string, std::shared_ptr<TilingCases>> op_type_map;
            op_type_map[op_type] = std::shared_ptr<TilingCases>(new (std::nothrow) TilingCases(op_type));
            registry_map_[soc_version] = op_type_map;
        } else {
            if (soc_iter->second.find(op_type) == soc_iter->second.end()) {
                soc_iter->second[op_type] = std::shared_ptr<TilingCases>(new (std::nothrow) TilingCases(op_type));
            }
        }
        OP_CHECK_IF(
            registry_map_[soc_version][op_type] == nullptr,
            OP_LOGE(op_type, "Register tiling func failed, please check the class name."), return nullptr);
        return registry_map_[soc_version][op_type];
    }
    // Run every registered tiling template for this op in ascending priority
    // order. The first template that does NOT return GRAPH_PARAM_INVALID
    // ("not my case") decides the overall result — including failures.
    ge::graphStatus DoTilingImpl(gert::TilingContext* context)
    {
        int32_t soc_version = (int32_t)platform_ascendc::SocVersion::RESERVED_VERSION;
        const char* op_type = context->GetNodeType();
        fe::PlatFormInfos* platformInfoPtr = context->GetPlatformInfo();
        if (platformInfoPtr == nullptr) {
            // No live platform info: fall back to the SoC version cached at
            // compile time in the common compile-info struct.
            auto compileInfoPtr = static_cast<const CompileInfoCommon*>(context->GetCompileInfo());
            OP_CHECK_IF(
                compileInfoPtr == nullptr, OP_LOGE(op_type, "compileInfoPtr is null."), return ge::GRAPH_FAILED);
            soc_version = compileInfoPtr->socVersion;
            OP_LOGD(context, "soc version in compileInfo is %d", soc_version);
        } else {
            auto ascendcPlatform = platform_ascendc::PlatformAscendC(platformInfoPtr);
            soc_version = static_cast<int32_t>(ascendcPlatform.GetSocVersion());
            OP_LOGD(context, "soc version is %d", soc_version);
            if (soc_version == (int32_t)platform_ascendc::SocVersion::RESERVED_VERSION) {
                OP_LOGE(op_type, "Do op tiling failed, cannot find soc version.");
                return ge::GRAPH_FAILED;
            }
        }
        auto tilingTemplateRegistryMap = GetTilingTemplates(op_type, soc_version);
        for (auto it = tilingTemplateRegistryMap.begin(); it != tilingTemplateRegistryMap.end(); ++it) {
            auto tilingTemplate = it->second(context);
            if (tilingTemplate != nullptr) {
                ge::graphStatus status = tilingTemplate->DoTiling();
                // Anything other than PARAM_INVALID (success OR hard failure)
                // terminates the search.
                if (status != ge::GRAPH_PARAM_INVALID) {
                    OP_LOGD(context, "Do general op tiling success priority=%d", it->first);
                    return status;
                }
                OP_LOGD(context, "Ignore general op tiling priority=%d", it->first);
            }
        }
        OP_LOGE(op_type, "Do op tiling failed, no valid template is found.");
        return ge::GRAPH_FAILED;
    }
ge::graphStatus DoTilingImpl(gert::TilingContext* context, const std::vector<int32_t>& priorities)
{
int32_t soc_version;
const char* op_type = context->GetNodeType();
auto platformInfoPtr = context->GetPlatformInfo();
if (platformInfoPtr == nullptr) {
auto compileInfoPtr = reinterpret_cast<const CompileInfoCommon*>(context->GetCompileInfo());
OP_CHECK_IF(
compileInfoPtr == nullptr, OP_LOGE(op_type, "compileInfoPtr is null."), return ge::GRAPH_FAILED);
soc_version = compileInfoPtr->socVersion;
OP_LOGD(context, "soc version in compileInfo is %d", soc_version);
} else {
auto ascendcPlatform = platform_ascendc::PlatformAscendC(platformInfoPtr);
soc_version = static_cast<int32_t>(ascendcPlatform.GetSocVersion());
OP_LOGD(context, "soc version is %d", soc_version);
}
auto tilingTemplateRegistryMap = GetTilingTemplates(op_type, soc_version);
for (auto priority_id : priorities) {
auto tilingCaseIter = tilingTemplateRegistryMap.find(priority_id);
if (tilingCaseIter != tilingTemplateRegistryMap.end()) {
auto templateFunc = tilingCaseIter->second(context);
if (templateFunc != nullptr) {
ge::graphStatus status = templateFunc->DoTiling();
if (status == ge::GRAPH_SUCCESS) {
OP_LOGD(context, "Do general op tiling success priority=%d", priority_id);
return status;
}
OP_LOGD(context, "Ignore general op tiling priority=%d", priority_id);
}
}
}
return ge::GRAPH_FAILED;
}
const std::map<int32_t, TilingClassCase>& GetTilingTemplates(const std::string& op_type, int32_t soc_version)
{
auto soc_iter = registry_map_.find(soc_version);
OP_CHECK_IF(
soc_iter == registry_map_.end(),
OP_LOGE(op_type, "Get op tiling func failed, please check the soc version %d", soc_version),
return empty_tiling_case_);
auto op_iter = soc_iter->second.find(op_type);
OP_CHECK_IF(
op_iter == soc_iter->second.end(), OP_LOGE(op_type, "Get op tiling func failed, please check the op name."),
return empty_tiling_case_);
return op_iter->second->GetTilingCases();
}
private:
std::map<int32_t, std::map<std::string, std::shared_ptr<TilingCases>>> registry_map_; // key is socversion
const std::map<int32_t, TilingClassCase> empty_tiling_case_{};
};
// Fluent registration helper for the soc-version-aware registry
// (TilingRegistryNew). Instantiated by the REGISTER_*_SOCVERSION macros.
class RegisterNew {
public:
    explicit RegisterNew(std::string op_type) : op_type_(std::move(op_type))
    {}
    // Registers tiling class T for op_type_ under a single soc version.
    // priority: smaller value means the template is tried earlier.
    template <typename T>
    RegisterNew& tiling(int32_t priority, int32_t soc_version)
    {
        auto tilingCases = TilingRegistryNew::GetInstance().RegisterOp(op_type_, soc_version);
        // Fixed log message: was "please the op name" (missing "check").
        OP_CHECK_IF(
            tilingCases == nullptr, OP_LOGE(op_type_, "Register op tiling failed, please check the op name."),
            return *this);
        tilingCases->AddTiling<T>(priority);
        return *this;
    }
    // Registers tiling class T for op_type_ under every listed soc version.
    // Stops at the first failed registration.
    template <typename T>
    RegisterNew& tiling(int32_t priority, const std::vector<int32_t>& soc_versions)
    {
        for (int32_t soc_version : soc_versions) {
            auto tilingCases = TilingRegistryNew::GetInstance().RegisterOp(op_type_, soc_version);
            OP_CHECK_IF(
                tilingCases == nullptr, OP_LOGE(op_type_, "Register op tiling failed, please check the op name."),
                return *this);
            tilingCases->AddTiling<T>(priority);
        }
        return *this;
    }
private:
    const std::string op_type_;
};
// --------------------------------Interface without soc version --------------------------------
// Legacy tiling registry keyed by op type only (no soc version dimension).
// Prefer TilingRegistryNew for new code.
class TilingRegistry {
public:
    TilingRegistry() = default;
#ifdef ASCENDC_OP_TEST
    static TilingRegistry& GetInstance();
#else
    static TilingRegistry& GetInstance()
    {
        static TilingRegistry registry_impl_;
        return registry_impl_;
    }
#endif
    // Returns the TilingCases container for op_type, creating it on first use.
    std::shared_ptr<TilingCases> RegisterOp(const std::string& op_type)
    {
        // Single lookup instead of the previous find + three operator[] calls.
        auto iter = registry_map_.find(op_type);
        if (iter == registry_map_.end()) {
            iter = registry_map_
                       .emplace(op_type, std::shared_ptr<TilingCases>(new (std::nothrow) TilingCases(op_type)))
                       .first;
        }
        OP_CHECK_IF(
            iter->second == nullptr,
            OP_LOGE(op_type, "Register tiling func failed, please check the class name."), return nullptr);
        return iter->second;
    }
    // Tries every registered template in ascending priority order;
    // GRAPH_PARAM_INVALID from a template means "not applicable, try the next".
    ge::graphStatus DoTilingImpl(gert::TilingContext* context)
    {
        const char* op_type = context->GetNodeType();
        // const&: avoid copying the whole template map on every tiling call.
        const auto& tilingTemplateRegistryMap = GetTilingTemplates(op_type);
        for (auto it = tilingTemplateRegistryMap.begin(); it != tilingTemplateRegistryMap.end(); ++it) {
            auto tilingTemplate = it->second(context);
            if (tilingTemplate != nullptr) {
                ge::graphStatus status = tilingTemplate->DoTiling();
                if (status != ge::GRAPH_PARAM_INVALID) {
                    OP_LOGD(context, "Do general op tiling success priority=%d", it->first);
                    return status;
                }
                OP_LOGD(context, "Ignore general op tiling priority=%d", it->first);
            }
        }
        OP_LOGE(op_type, "Do op tiling failed, no valid template is found.");
        return ge::GRAPH_FAILED;
    }
    // Tries only the listed priorities, in order, stopping at GRAPH_SUCCESS
    // (or at the first hard failure other than GRAPH_PARAM_INVALID).
    ge::graphStatus DoTilingImpl(gert::TilingContext* context, const std::vector<int32_t>& priorities)
    {
        const char* op_type = context->GetNodeType();
        const auto& tilingTemplateRegistryMap = GetTilingTemplates(op_type);
        for (auto priorityId : priorities) {
            // Bug fix: the previous code indexed with operator[], which
            // default-inserts an empty TilingClassCase for an unknown priority
            // and then invokes it. Use find() and skip missing entries, matching
            // the TilingRegistryNew implementation.
            auto caseIter = tilingTemplateRegistryMap.find(priorityId);
            if (caseIter == tilingTemplateRegistryMap.end()) {
                continue;
            }
            auto templateFunc = caseIter->second(context);
            if (templateFunc != nullptr) {
                ge::graphStatus status = templateFunc->DoTiling();
                if (status == ge::GRAPH_SUCCESS) {
                    OP_LOGD(context, "Do general op tiling success priority=%d", priorityId);
                    return status;
                }
                if (status != ge::GRAPH_PARAM_INVALID) {
                    OP_LOGD(context, "Do op tiling failed");
                    return status;
                }
                OP_LOGD(context, "Ignore general op tiling priority=%d", priorityId);
            }
        }
        OP_LOGE(op_type, "Do op tiling failed, no valid template is found.");
        return ge::GRAPH_FAILED;
    }
    // Returns the priority->template map for op_type, or the shared empty map
    // when the op was never registered.
    const std::map<int32_t, TilingClassCase>& GetTilingTemplates(const std::string& op_type)
    {
        auto iter = registry_map_.find(op_type);
        OP_CHECK_IF(
            iter == registry_map_.end(),
            OP_LOGE(op_type, "Get op tiling func failed, please check the op name."), return empty_tiling_case_);
        return iter->second->GetTilingCases();
    }
private:
    std::map<std::string, std::shared_ptr<TilingCases>> registry_map_;
    const std::map<int32_t, TilingClassCase> empty_tiling_case_;
};
// Fluent registration helper for the legacy (soc-version-agnostic) registry.
// Instantiated by REGISTER_TILING_TEMPLATE / REGISTER_OPS_TILING_TEMPLATE.
class Register {
public:
    explicit Register(std::string op_type) : op_type_(std::move(op_type))
    {}
    // Registers tiling class T for op_type_; smaller priority is tried earlier.
    template <typename T>
    Register& tiling(int32_t priority)
    {
        auto tilingCases = TilingRegistry::GetInstance().RegisterOp(op_type_);
        // Fixed log message: was "please the op name" (missing "check").
        OP_CHECK_IF(
            tilingCases == nullptr, OP_LOGE(op_type_, "Register op tiling failed, please check the op name."),
            return *this);
        tilingCases->AddTiling<T>(priority);
        return *this;
    }
private:
    const std::string op_type_;
};
} // namespace OpTiling
} // namespace Transformer
} // namespace Ops
// op_type: operator name, class_name: registered tiling class, soc_version: chip version number
// priority: priority of tiling class, smaller value means higher priority, i.e., this tiling class will be selected first
// NOTE: op_type must be a bare identifier here (it is stringized with #op_type and token-pasted
// into the generated variable names).
#define REGISTER_TILING_TEMPLATE_WITH_SOCVERSION(op_type, class_name, soc_versions, priority) \
[[maybe_unused]] uint32_t op_impl_register_template_##op_type##_##class_name##priority; \
static Ops::Transformer::OpTiling::RegisterNew VAR_UNUSED##op_type##class_name##priority_register = \
Ops::Transformer::OpTiling::RegisterNew(#op_type).tiling<class_name>(priority, soc_versions)
// op_type: operator name, class_name: registered tiling class
// priority: priority of tiling class, smaller value means higher priority, i.e., higher probability of being selected
// NOTE(review): in the static-variable name below, `op_type_` is NOT the macro parameter
// (`op_type`), so the generated name does not vary with op_type — two expansions with the
// same class_name and priority would collide. Deprecated; prefer REGISTER_OPS_TILING_TEMPLATE.
#define REGISTER_TILING_TEMPLATE(op_type, class_name, priority) \
[[maybe_unused]] uint32_t op_impl_register_template_##op_type##_##class_name##priority; \
static Ops::Transformer::OpTiling::Register VAR_UNUSED##op_type_##class_name##priority_register = \
Ops::Transformer::OpTiling::Register(op_type).tiling<class_name>(priority)
// op_type: operator name, class_name: registered tiling class
// soc_version: SOC version, used to distinguish different SOCs
// priority: priority of tiling class, smaller value means higher priority, i.e., this tiling class will be selected first
#define REGISTER_TILING_TEMPLATE_NEW(op_type, class_name, soc_version, priority) \
[[maybe_unused]] uint32_t op_impl_register_template_##op_type##_##class_name##priority; \
static Ops::Transformer::OpTiling::RegisterNew VAR_UNUSED##op_type##class_name##priority_register = \
Ops::Transformer::OpTiling::RegisterNew(#op_type).tiling<class_name>(priority, soc_version)
// op_type: operator name, class_name: registered tiling class
// priority: priority of tiling class, smaller value means higher priority, i.e., higher probability of being selected
// Replaces REGISTER_TILING_TEMPLATE, if op_type is a string constant, remove the quotes
#define REGISTER_OPS_TILING_TEMPLATE(op_type, class_name, priority) \
[[maybe_unused]] uint32_t op_impl_register_template_##op_type##_##class_name##priority; \
static Ops::Transformer::OpTiling::Register \
__attribute__((unused)) tiling_##op_type##_##class_name##_##priority##_register = \
Ops::Transformer::OpTiling::Register(#op_type).tiling<class_name>(priority)

View File

@@ -0,0 +1,139 @@
/**
 * This program is free software, you can redistribute it and/or modify.
 * Copyright (c) 2025 Huawei Technologies Co., Ltd.
* This file is a part of the CANN Open Software.
* Licensed under CANN Open Software License Agreement Version 2.0 (the "License").
* Please refer to the License for details. You may not use this file except in compliance with the License.
* THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE.
* See LICENSE in the root of the software repository for the full text of the License.
*/
/*!
* \file tiling_type.h
* \brief
*/
#pragma once
#include <cstdint>
namespace optiling {
// Axis identifiers used when describing tiling splits (see the tiling-key
// description further below in this file); NONE means "no axis".
enum class AxisEnum {
B = 0,
N2 = 1,
G = 2,
S1 = 3,
S2 = 4,
D = 5,
NONE = 9,
};
// Input/output data types a tiling key can declare support for.
enum class DtypeEnum {
FLOAT16 = 0,
FLOAT32 = 1,
BFLOAT16 = 2,
FLOAT16_PRECISION = 3,
};
// Buffering strategies. Presumably single large buffer vs. double buffering —
// confirm against the kernels that consume this enum.
enum class PerformanceOrientedEnum {
BIG_BUFFER = 1,
BIG_DOUBLE_BUFFER = 2,
};
// Matmul API configuration variant selected by a tiling key.
enum class MatmulConfig {
NULL_CONFIG = 0,
NORMAL_CONFIG = 1,
MDL_CONFIG = 2
};
// Whether the optional PSE (positional bias) input is present.
enum class PseConfig {
NO_PSE = 0,
EXIST_PSE = 1
};
// Whether the optional attention-mask input is present.
enum class AttenMaskConfig {
NO_ATTEN_MASK = 0,
EXIST_ATTEN_MASK = 1
};
// Whether the optional dropout input is present.
enum class DropOutConfig {
NO_DROP_OUT = 0,
EXIST_DROP_OUT = 1
};
// Data format of cube (matmul) inputs.
enum class CubeFormatEnum {
ND = 0,
NZ = 1
};
// Input tensor layouts supported by a tiling key.
enum class LayoutEnum {
BSND = 0,
SBND = 1,
BNSD = 2,
TND = 3,
NTD_TND = 4
};
// Where cube inputs are sourced from: global memory or L1 buffer.
enum class CubeInputSourceEnum {
GM = 0,
L1 = 1
};
// Generic on/off switch used by feature toggles.
enum class OptionEnum {
DISABLE = 0,
ENABLE = 1
};
// Sparse-attention modes a tiling key can declare support for.
enum class SparseEnum {
ALL = 0,
NONE = 1,
ANY = 2,
CAUSAL = 3,
BAND = 4,
PREFIX = 5,
BAND_COMPRESS = 6,
RIGHT_DOWN_CAUSAL = 7,
RIGHT_DOWN_CAUSAL_BAND = 8,
BAND_LEFT_UP_CAUSAL = 9
};
// Packs the given ids into a decimal number with the first argument in the
// least significant digit: RecursiveSum(a, b, c) == a + 10*b + 100*c.
// Each id is expected to occupy a single decimal digit.
constexpr uint64_t RecursiveSum()
{
    return 0;
}
constexpr int64_t base10Multiplier = 10;
template <typename T, typename... Args> constexpr uint64_t RecursiveSum(T head, Args... tail)
{
    // Fold the remaining ids first, then shift them one decimal place left and
    // add the current id into the lowest digit.
    const uint64_t rest = RecursiveSum(tail...);
    return base10Multiplier * rest + static_cast<uint64_t>(head);
}
// TilingKey generation rules:
// FlashAttentionScore/FlashAttentionScoreGrad assembles tiling key using decimal digits, containing the following key parameters from low to high: Ub0, Ub1,
// Block, DataType, Format, Sparse. Specialized template Ub0, Ub1:
// Represents the axis for UB intra-core splitting, using AxisEnum. Since we allow at most two axes to be split, UB0 and UB1 exist. If there is no UB intra-core splitting,
// fill with AXIS_NONE. UB0 and UB1 each occupy one decimal digit;
// Block: Represents the axis used by UB for multi-core splitting, using AxisEnum, occupies one decimal digit;
// DataType: Represents the input/output data types supported by the current tiling key, using SupportedDtype enum, occupies one decimal digit
// Format: Represents the Format supported by the current tiling key, using InputLayout enum, occupies one decimal digit
// Sparse: Represents whether the current tiling key supports Sparse, using SparseCapability enum, occupies one decimal digit
// For other specialized scenarios, define your own bit fields and values
// usage: get tilingKey from inputted types
// uint64_t tilingKey = GET_FLASHATTENTION_TILINGKEY(AxisEnum::AXIS_S1, AxisEnum::AXIS_S2, AxisEnum::AXIS_N2,
// SupportedDtype::FLOAT32, InputLayout::BSH, SparseCapability::SUPPORT_ALL)
// Base offset added to every generated tiling key so all keys share a fixed
// 20-digit decimal width.
constexpr uint64_t TILINGKEYOFFSET = uint64_t(10000000000000000000UL); // 10^19
// Builds a tiling key by packing the ids into decimal digits (first argument
// in the lowest digit, via RecursiveSum) on top of TILINGKEYOFFSET.
template <typename... Args> constexpr uint64_t GET_TILINGKEY(Args... templateIds)
{
return TILINGKEYOFFSET + RecursiveSum(templateIds...);
}
// usage: get tilingKey from inputted types
// uint64_t tilingKey = TILINGKEY(S2, S1, N2, FLOAT32, BSND, ALL)
// Convenience wrapper that qualifies each id with its enum type.
#define TILINGKEY(ub2, ub1, block, dtype, layout, sparse) \
(GET_TILINGKEY(AxisEnum::ub2, AxisEnum::ub1, AxisEnum::block, DtypeEnum::dtype, LayoutEnum::layout, \
SparseEnum::sparse))
} // namespace optiling

View File

@@ -0,0 +1,30 @@
/**
* This program is free software, you can redistribute it and/or modify.
* Copyright (c) 2025 Huawei Technologies Co., Ltd.
* This file is a part of the CANN Open Software.
* Licensed under CANN Open Software License Agreement Version 2.0 (the "License").
* Please refer to the License for details. You may not use this file except in compliance with the License.
* THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE.
* See LICENSE in the root of the software repository for the full text of the License.
*/
/*!
* \file tiling_util.h
* \brief
*/
#pragma once
#include "register/op_impl_registry.h"
namespace Ops {
namespace Transformer {
namespace OpTiling {
// NOTE(review): declarations only; bodies live in the corresponding .cpp.
// Presumably reports whether the soc described by the context uses the
// regbase architecture — confirm at the definition.
bool IsRegbaseSocVersion(const gert::TilingParseContext* context);
bool IsRegbaseSocVersion(const gert::TilingContext* context);
// Presumably maps a scalar (rank-0) shape to an equivalent non-scalar shape,
// returning other shapes unchanged — confirm at the definition.
const gert::Shape& EnsureNotScalar(const gert::Shape& inShape);
} // namespace OpTiling
} // namespace Transformer
} // namespace Ops

View File

@@ -597,6 +597,44 @@ void transpose_kv_cache_by_block(
}
// Causal 1-D convolution over variable-length sequences, dispatched to the
// AscendC aclnnCausalConv1d custom kernel. Updates conv_state in place on the
// device side and returns a freshly allocated output tensor with the same
// shape and dtype as the input activations.
at::Tensor causal_conv1d_fn(
    const at::Tensor& mixed_qkv_non_spec_T,
    const at::Tensor& conv_weights,
    const c10::optional<at::Tensor>& bias_opt,
    c10::string_view activation,
    const at::Tensor& conv_state,
    const at::Tensor& has_initial_state,
    const at::Tensor& non_spec_state_indices_tensor,
    const at::Tensor& non_spec_query_start_loc,
    int64_t pad_slot_id)
{
    // Local handles so the EXEC_NPU_CMD macro receives mutable lvalues
    // (at::Tensor copies are shallow, reference-counted views).
    at::Tensor x = mixed_qkv_non_spec_T;  // already in kernel layout, no transpose needed
    at::Tensor weight = conv_weights;     // already in kernel layout, no transpose needed
    c10::optional<at::Tensor> biasOptional = bias_opt;
    at::Tensor convStates = conv_state;
    at::Tensor queryStartLoc = non_spec_query_start_loc;
    at::Tensor cacheIndices = non_spec_state_indices_tensor;
    at::Tensor hasInitialState = has_initial_state;
    // Kernel activation switch: 0 = none, 1 = activation requested.
    int64_t activationMode = activation.empty() ? 0 : 1;
    int64_t padSlotId = pad_slot_id;
    at::Tensor output = at::empty(mixed_qkv_non_spec_T.sizes(), mixed_qkv_non_spec_T.options());
    EXEC_NPU_CMD(aclnnCausalConv1d,
        x,
        weight,
        biasOptional,
        convStates,
        queryStartLoc,
        cacheIndices,
        hasInitialState,
        activationMode,
        padSlotId,
        output);
    return output;
}
// It is expected that further improvements will be made after it is incorporated into CANN on June 30th.
std::vector<at::Tensor> moe_grouped_matmul(
at::Tensor x,
@@ -811,6 +849,18 @@ TORCH_LIBRARY_EXPAND(CONCAT(_C, _ascend), ops)
"transpose_kv_cache_by_block(Tensor[] kCache, Tensor[] vCache, Tensor blockIDs, int blockSize, int headNum, int headDim, int splitNum, int layerNum) -> ()"
);
ops.impl("transpose_kv_cache_by_block", torch::kPrivateUse1, &vllm_ascend::transpose_kv_cache_by_block);
// causal_conv1d_fn
ops.def(
"causal_conv1d_fn(Tensor mixed_qkv_non_spec_T, "
" Tensor conv_weights, "
" Tensor? bias_opt, "
" str activation, "
" Tensor conv_state, "
" Tensor has_initial_state, "
" Tensor non_spec_state_indices_tensor, "
" Tensor non_spec_query_start_loc, "
" int pad_slot_id) -> (Tensor output)");
ops.impl("causal_conv1d_fn", torch::kPrivateUse1, &vllm_ascend::causal_conv1d_fn);
ops.def(
"moe_grouped_matmul("
"Tensor x,"

View File

@@ -458,6 +458,22 @@ void transpose_kv_cache_by_block_meta(
return;
}
// Meta (shape-inference) implementation of causal_conv1d_fn: no kernel launch,
// just propagates shape and dtype. The output mirrors the input activations,
// matching the NPU implementation.
at::Tensor causal_conv1d_fn_meta(
    const at::Tensor& mixed_qkv_non_spec_T,
    const at::Tensor& conv_weights,
    const c10::optional<at::Tensor>& bias_opt,
    c10::string_view activation,
    const at::Tensor& conv_state,
    const at::Tensor& has_initial_state,
    const at::Tensor& non_spec_state_indices_tensor,
    const at::Tensor& non_spec_query_start_loc,
    int64_t pad_slot_id)
{
    return at::empty_symint(mixed_qkv_non_spec_T.sym_sizes(), mixed_qkv_non_spec_T.options());
}
std::vector<at::Tensor> moe_grouped_matmul_meta(
at::Tensor x,
at::Tensor weight,
@@ -527,6 +543,8 @@ TORCH_LIBRARY_IMPL_EXPAND(CONCAT(_C, _ascend), Meta, ops) {
ops.impl("npu_add_rms_norm_bias", &vllm_ascend::meta::npu_add_rms_norm_bias_meta);
// transpose_kv_cache_by_block
ops.impl("transpose_kv_cache_by_block", &vllm_ascend::meta::transpose_kv_cache_by_block_meta);
// causal_conv1d_fn
ops.impl("causal_conv1d_fn", &vllm_ascend::meta::causal_conv1d_fn_meta);
// moe_grouped_matmul
ops.impl("moe_grouped_matmul", &vllm_ascend::meta::moe_grouped_matmul_meta);
}

View File

@@ -92,10 +92,15 @@ def test_qwen3_next_mtp_acceptance_tp4(model_name):
@pytest.mark.parametrize("model_name", MODELS)
@pytest.mark.parametrize("num_speculative_tokens", [1])
@pytest.mark.parametrize("disable_padded_drafter_batch", [True, False])
@pytest.mark.skip("Skip this CI.")
def test_qwen3_next_mtp_correctness_tp4(model_name: str,
num_speculative_tokens: int,
disable_padded_drafter_batch: bool):
example_prompts = [
"Hello, my name is",
"The president of the United States is",
"The capital of France is",
"The future of AI is",
"Hello, my name is",
"The president of the United States is",
"The capital of France is",

View File

@@ -8,7 +8,7 @@ from vllm_ascend.ops.triton.mamba.causal_conv1d import (PAD_SLOT_ID,
causal_conv1d_fn)
from vllm_ascend.ops.triton.mamba.causal_conv1d import \
causal_conv1d_update_npu as causal_conv1d_update
from vllm_ascend.utils import enable_custom_op
def validate_cmp(y_cal, y_ref, dtype, device='npu'):
y_cal = y_cal.to(device)
@@ -157,6 +157,90 @@ def causal_conv1d_fn_pytorch(
return out_ref_tensor
@pytest.mark.parametrize('has_initial_state', [False, True])
@pytest.mark.parametrize('itype', [torch.bfloat16])
@pytest.mark.parametrize('silu_activation', [True])
@pytest.mark.parametrize('has_bias', [True])
@pytest.mark.parametrize('seq_len', [[128, 1024, 2048, 4096]])
@pytest.mark.parametrize('extra_state_len', [0, 2])
@pytest.mark.parametrize('width', [4])
@pytest.mark.parametrize('dim', [2048])
def test_ascend_causal_conv1d(dim, width, extra_state_len, seq_len, has_bias,
                              silu_activation, itype, has_initial_state):
    """Compare the AscendC causal_conv1d_fn custom op against the PyTorch
    reference implementation, for both the output and the updated conv state.
    """
    torch.random.manual_seed(0)
    enable_custom_op()
    device = "npu"
    cu_seqlen, num_seq = sum(seq_len), len(seq_len)
    state_len = width - 1 + extra_state_len
    # Reference layout: (dim, total_tokens) activations, (dim, width) weights.
    x = torch.randn(cu_seqlen, dim, device=device, dtype=itype).transpose(0, 1)
    weight = torch.randn(dim, width, device=device, dtype=itype)
    query_start_loc = torch.cumsum(torch.tensor([0] + seq_len,
                                                device=device,
                                                dtype=torch.int32),
                                   dim=0).to(dtype=torch.int32)
    cache_indices = torch.arange(num_seq, device=device, dtype=torch.int32)
    has_initial_state_tensor = torch.tensor([has_initial_state] * num_seq,
                                            device=device,
                                            dtype=torch.bool)
    activation = None if not silu_activation else "silu"
    if has_initial_state:
        conv_states = torch.randn((num_seq, state_len, dim),
                                  device=device,
                                  dtype=itype).transpose(-1, -2)
        conv_states_ref = torch.randn(
            (num_seq, state_len, dim), device=device,
            dtype=itype).transpose(-1, -2).copy_(conv_states)
    else:
        conv_states = torch.zeros((num_seq, state_len, dim),
                                  device=device,
                                  dtype=itype).transpose(-1, -2)
        conv_states_ref = torch.zeros((num_seq, state_len, dim),
                                      device=device,
                                      dtype=itype).transpose(-1, -2)
    bias = torch.randn(dim, device=device, dtype=itype) if has_bias else None
    out_ref = causal_conv1d_fn_pytorch(
        x,
        weight,
        bias=bias,
        activation=activation,
        conv_states=conv_states_ref,
        has_initial_state=has_initial_state_tensor,
        cache_indices=cache_indices,
        query_start_loc=query_start_loc)
    # The custom op consumes (total_tokens, dim) / (width, dim) /
    # (num_seq, state_len, dim) layouts, so feed it transposed views of the
    # reference inputs; the state view aliases conv_states, so the in-place
    # kernel update is visible in the comparison below.
    x_origin = x.transpose(-1, -2)
    weight_origin = weight.transpose(-1, -2)
    conv_states_origin = conv_states.transpose(-1, -2)
    out = torch.ops._C_ascend.causal_conv1d_fn(
        x_origin,
        weight_origin,
        bias,
        activation=activation,
        conv_state=conv_states_origin,
        has_initial_state=has_initial_state_tensor,
        non_spec_state_indices_tensor=cache_indices,
        non_spec_query_start_loc=query_start_loc,
        pad_slot_id=PAD_SLOT_ID,
    ).transpose(-1, -2)
    validate_cmp(out, out_ref, itype)
    validate_cmp(conv_states, conv_states_ref, itype)
@pytest.mark.parametrize('has_initial_state', [False, True])
@pytest.mark.parametrize('itype', [torch.bfloat16])
@pytest.mark.parametrize('silu_activation', [True])

View File

@@ -22,11 +22,12 @@ from einops import rearrange
from vllm.forward_context import get_forward_context
from vllm.model_executor.layers.fla.ops import chunk_gated_delta_rule
from vllm.model_executor.layers.fla.ops.l2norm import l2norm_fwd
from vllm.model_executor.layers.mamba.ops.causal_conv1d import causal_conv1d_fn, causal_conv1d_update
from vllm.model_executor.layers.mamba.ops.causal_conv1d import causal_conv1d_update
from vllm.model_executor.models.qwen3_next import Qwen3NextGatedDeltaNet
from vllm.triton_utils import triton
from vllm.v1.attention.backend import AttentionMetadata # type: ignore
from vllm.v1.attention.backends.gdn_attn import GDNAttentionMetadata
from vllm.v1.attention.backends.utils import PAD_SLOT_ID
from vllm_ascend.ops.triton.fla.fused_qkvzba_split_reshape import fused_qkvzba_split_reshape_cat
from vllm_ascend.ops.triton.fused_gdn_gating import fused_gdn_gating_patch
@@ -163,20 +164,18 @@ class AscendQwen3Next_GatedDeltaNet(Qwen3NextGatedDeltaNet):
# 1.2: Process the remaining part
if attn_metadata.num_prefills > 0:
if mixed_qkv_non_spec is not None:
mixed_qkv_non_spec_T = mixed_qkv_non_spec.transpose(0, 1)
# - "cache_indices" updates the conv_state cache in positions
# pointed to by "state_indices_tensor"
mixed_qkv_non_spec = causal_conv1d_fn(
mixed_qkv_non_spec_T,
conv_weights,
conv_weights_T = conv_weights.transpose(0, 1)
mixed_qkv_non_spec = torch.ops._C_ascend.causal_conv1d_fn(
mixed_qkv_non_spec,
conv_weights_T,
self.conv1d.bias,
activation=self.activation,
conv_states=conv_state,
conv_state=self_kv_cache[0],
has_initial_state=has_initial_state,
cache_indices=non_spec_state_indices_tensor,
query_start_loc=non_spec_query_start_loc,
metadata=attn_metadata,
).transpose(0, 1)
non_spec_state_indices_tensor=non_spec_state_indices_tensor,
non_spec_query_start_loc=non_spec_query_start_loc,
pad_slot_id=PAD_SLOT_ID,
)
elif attn_metadata.num_decodes > 0:
mixed_qkv_non_spec = causal_conv1d_update(
mixed_qkv_non_spec,