[qwen3 next ]add ascend c casual_conv1d_fn (#6661)
### What this PR does / why we need it?
add ascend c casual_conv1d_fn
- vLLM version: v0.15.0
- vLLM main:
13397841ab
---------
Signed-off-by: ZT-AIA <1028681969@qq.com>
Signed-off-by: ZT-AIA <63220130+ZT-AIA@users.noreply.github.com>
Co-authored-by: gemini-code-assist[bot] <176961590+gemini-code-assist[bot]@users.noreply.github.com>
This commit is contained in:
@@ -24,7 +24,8 @@ elif [[ "$SOC_VERSION" =~ ^ascend910b ]]; then
|
|||||||
ABSOLUTE_CATLASS_PATH=$(cd "${CATLASS_PATH}" && pwd)
|
ABSOLUTE_CATLASS_PATH=$(cd "${CATLASS_PATH}" && pwd)
|
||||||
export CPATH=${ABSOLUTE_CATLASS_PATH}:${CPATH}
|
export CPATH=${ABSOLUTE_CATLASS_PATH}:${CPATH}
|
||||||
|
|
||||||
CUSTOM_OPS="moe_grouped_matmul;grouped_matmul_swiglu_quant_weight_nz_tensor_list;lightning_indexer_vllm;sparse_flash_attention;matmul_allreduce_add_rmsnorm;moe_init_routing_custom;moe_gating_top_k;add_rms_norm_bias;apply_top_k_top_p_custom;transpose_kv_cache_by_block;"
|
|
||||||
|
CUSTOM_OPS="moe_grouped_matmul;grouped_matmul_swiglu_quant_weight_nz_tensor_list;lightning_indexer_vllm;sparse_flash_attention;matmul_allreduce_add_rmsnorm;moe_init_routing_custom;moe_gating_top_k;add_rms_norm_bias;apply_top_k_top_p_custom;transpose_kv_cache_by_block;causal_conv1d;"
|
||||||
SOC_ARG="ascend910b"
|
SOC_ARG="ascend910b"
|
||||||
elif [[ "$SOC_VERSION" =~ ^ascend910_93 ]]; then
|
elif [[ "$SOC_VERSION" =~ ^ascend910_93 ]]; then
|
||||||
# ASCEND910C (A3) series
|
# ASCEND910C (A3) series
|
||||||
@@ -63,6 +64,7 @@ elif [[ "$SOC_VERSION" =~ ^ascend910_93 ]]; then
|
|||||||
"add_rms_norm_bias"
|
"add_rms_norm_bias"
|
||||||
"apply_top_k_top_p_custom"
|
"apply_top_k_top_p_custom"
|
||||||
"transpose_kv_cache_by_block"
|
"transpose_kv_cache_by_block"
|
||||||
|
"causal_conv1d"
|
||||||
"moe_grouped_matmul"
|
"moe_grouped_matmul"
|
||||||
)
|
)
|
||||||
CUSTOM_OPS=$(IFS=';'; echo "${CUSTOM_OPS_ARRAY[*]}")
|
CUSTOM_OPS=$(IFS=';'; echo "${CUSTOM_OPS_ARRAY[*]}")
|
||||||
|
|||||||
50
csrc/causal_conv1d/op_host/CMakeLists.txt
Normal file
50
csrc/causal_conv1d/op_host/CMakeLists.txt
Normal file
@@ -0,0 +1,50 @@
|
|||||||
|
# Copyright (c) 2025 Huawei Technologies Co., Ltd.
|
||||||
|
# This file is a part of the CANN Open Software.
|
||||||
|
# Licensed under CANN Open Software License Agreement Version 1.0 (the "License").
|
||||||
|
# Please refer to the License for details. You may not use this file except in compliance with the License.
|
||||||
|
# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED,
|
||||||
|
# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE.
|
||||||
|
# See LICENSE in the root of the software repository for the full text of the License.
|
||||||
|
# ======================================================================================================================
|
||||||
|
|
||||||
|
add_ops_compile_options(
|
||||||
|
OP_NAME CausalConv1d
|
||||||
|
OPTIONS --cce-auto-sync=off
|
||||||
|
-Wno-deprecated-declarations
|
||||||
|
-Werror
|
||||||
|
)
|
||||||
|
|
||||||
|
target_sources(op_host_aclnn PRIVATE
|
||||||
|
causal_conv1d_def.cpp
|
||||||
|
)
|
||||||
|
|
||||||
|
# target_sources(opapi PRIVATE
|
||||||
|
# aclnn_causal_conv1d.cpp
|
||||||
|
# )
|
||||||
|
|
||||||
|
# if (NOT BUILD_OPEN_PROJECT)
|
||||||
|
# target_sources(aclnn_ops_train PRIVATE
|
||||||
|
# aclnn_causal_conv1d.cpp
|
||||||
|
# )
|
||||||
|
|
||||||
|
# target_sources(aclnn_ops_infer PRIVATE
|
||||||
|
# aclnn_causal_conv1d.cpp
|
||||||
|
# )
|
||||||
|
# endif ()
|
||||||
|
|
||||||
|
target_sources(optiling PRIVATE
|
||||||
|
causal_conv1d_tiling.cpp
|
||||||
|
tiling_util.cpp
|
||||||
|
)
|
||||||
|
|
||||||
|
target_include_directories(optiling PRIVATE
|
||||||
|
${CMAKE_CURRENT_SOURCE_DIR}
|
||||||
|
)
|
||||||
|
|
||||||
|
target_sources(opsproto PRIVATE)
|
||||||
|
|
||||||
|
file(GLOB _GMM_Aclnn_header "${CMAKE_CURRENT_SOURCE_DIR}/aclnn_causal_conv1d.h")
|
||||||
|
|
||||||
|
install(FILES ${_GMM_Aclnn_header}
|
||||||
|
DESTINATION ${ACLNN_INC_INSTALL_DIR} OPTIONAL
|
||||||
|
)
|
||||||
83
csrc/causal_conv1d/op_host/causal_conv1d_def.cpp
Normal file
83
csrc/causal_conv1d/op_host/causal_conv1d_def.cpp
Normal file
@@ -0,0 +1,83 @@
|
|||||||
|
/**
|
||||||
|
* This program is free software, you can redistribute it and/or modify it.
|
||||||
|
* Copyright (c) 2025 Huawei Technologies Co., Ltd.
|
||||||
|
* This file is a part of the CANN Open Software.
|
||||||
|
* Licensed under CANN Open Software License Agreement Version 2.0 (the "License").
|
||||||
|
* Please refer to the License for details. You may not use this file except in compliance with the License.
|
||||||
|
* THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, INCLUDING
|
||||||
|
* BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE.
|
||||||
|
* See LICENSE in the root of the software repository for the full text of the License.
|
||||||
|
*/
|
||||||
|
|
||||||
|
/*!
|
||||||
|
* \file causal_conv1d_def.cpp
|
||||||
|
* \brief
|
||||||
|
*/
|
||||||
|
#include "register/op_def_registry.h"
|
||||||
|
|
||||||
|
namespace ops {
|
||||||
|
|
||||||
|
class CausalConv1d : public OpDef {
|
||||||
|
public:
|
||||||
|
explicit CausalConv1d(const char* name) : OpDef(name)
|
||||||
|
{
|
||||||
|
this->Input("x")
|
||||||
|
.ParamType(REQUIRED)
|
||||||
|
.DataType({ge::DT_FLOAT16, ge::DT_BF16})
|
||||||
|
.FormatList({ge::FORMAT_ND})
|
||||||
|
.AutoContiguous();
|
||||||
|
this->Input("weight")
|
||||||
|
.ParamType(REQUIRED)
|
||||||
|
.DataType({ge::DT_FLOAT16, ge::DT_BF16})
|
||||||
|
.FormatList({ge::FORMAT_ND})
|
||||||
|
.AutoContiguous();
|
||||||
|
this->Input("bias")
|
||||||
|
.ParamType(OPTIONAL)
|
||||||
|
.DataType({ge::DT_FLOAT16, ge::DT_BF16})
|
||||||
|
.FormatList({ge::FORMAT_ND})
|
||||||
|
.AutoContiguous();
|
||||||
|
this->Input("convStates")
|
||||||
|
.ParamType(REQUIRED)
|
||||||
|
.DataType({ge::DT_FLOAT16, ge::DT_BF16})
|
||||||
|
.FormatList({ge::FORMAT_ND})
|
||||||
|
.AutoContiguous();
|
||||||
|
this->Input("queryStartLoc")
|
||||||
|
.ParamType(REQUIRED)
|
||||||
|
.DataTypeList({ge::DT_INT32})
|
||||||
|
.FormatList({ge::FORMAT_ND})
|
||||||
|
.AutoContiguous();
|
||||||
|
this->Input("cacheIndices")
|
||||||
|
.ParamType(REQUIRED)
|
||||||
|
.DataTypeList({ge::DT_INT32})
|
||||||
|
.FormatList({ge::FORMAT_ND})
|
||||||
|
.AutoContiguous();
|
||||||
|
this->Input("hasInitialState")
|
||||||
|
.ParamType(REQUIRED)
|
||||||
|
.DataTypeList({ge::DT_BOOL})
|
||||||
|
.FormatList({ge::FORMAT_ND})
|
||||||
|
.AutoContiguous();
|
||||||
|
|
||||||
|
this->Output("y")
|
||||||
|
.ParamType(REQUIRED)
|
||||||
|
.DataType({ge::DT_FLOAT16, ge::DT_BF16})
|
||||||
|
.FormatList({ge::FORMAT_ND})
|
||||||
|
.AutoContiguous();
|
||||||
|
|
||||||
|
this->Attr("activationMode").AttrType(OPTIONAL).Int(0);
|
||||||
|
this->Attr("padSlotId").AttrType(OPTIONAL).Int(-1);
|
||||||
|
|
||||||
|
OpAICoreConfig aicoreConfig;
|
||||||
|
aicoreConfig.DynamicCompileStaticFlag(true)
|
||||||
|
.DynamicFormatFlag(false)
|
||||||
|
.DynamicRankSupportFlag(true)
|
||||||
|
.DynamicShapeSupportFlag(true)
|
||||||
|
.NeedCheckSupportFlag(false)
|
||||||
|
.PrecisionReduceFlag(true)
|
||||||
|
.ExtendCfgInfo("coreType.value", "AiCore");
|
||||||
|
this->AICore().AddConfig("ascend910b", aicoreConfig);
|
||||||
|
this->AICore().AddConfig("ascend910_93", aicoreConfig);
|
||||||
|
}
|
||||||
|
};
|
||||||
|
OP_ADD(CausalConv1d);
|
||||||
|
|
||||||
|
} // namespace ops
|
||||||
49
csrc/causal_conv1d/op_host/causal_conv1d_infershape.cpp
Normal file
49
csrc/causal_conv1d/op_host/causal_conv1d_infershape.cpp
Normal file
@@ -0,0 +1,49 @@
|
|||||||
|
/**
|
||||||
|
* This program is free software, you can redistribute it and/or modify it.
|
||||||
|
* Copyright (c) 2025 Huawei Technologies Co., Ltd.
|
||||||
|
* This file is a part of the CANN Open Software.
|
||||||
|
* Licensed under CANN Open Software License Agreement Version 2.0 (the "License").
|
||||||
|
* Please refer to the License for details. You may not use this file except in compliance with the License.
|
||||||
|
* THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, INCLUDING
|
||||||
|
* BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE.
|
||||||
|
* See LICENSE in the root of the software repository for the full text of the License.
|
||||||
|
*/
|
||||||
|
|
||||||
|
/*!
|
||||||
|
* \file causal_conv1d_infershape.cpp
|
||||||
|
* \brief
|
||||||
|
*/
|
||||||
|
#include "register/op_impl_registry.h"
|
||||||
|
#include "error_log.h"
|
||||||
|
|
||||||
|
using namespace ge;
|
||||||
|
|
||||||
|
namespace ops {
|
||||||
|
static constexpr int64_t IDX_0 = 0;
|
||||||
|
|
||||||
|
static ge::graphStatus InferShapeCausalConv1d(gert::InferShapeContext* context)
|
||||||
|
{
|
||||||
|
// OPS_LOG_D(context->GetNodeName(), "Begin to do InferShapeCausalConv1d");
|
||||||
|
|
||||||
|
// get input shapes
|
||||||
|
const gert::Shape* xShape = context->GetInputShape(IDX_0);
|
||||||
|
OP_CHECK_NULL_WITH_CONTEXT(context, xShape);
|
||||||
|
|
||||||
|
// get output shapes
|
||||||
|
gert::Shape* yShape = context->GetOutputShape(IDX_0);
|
||||||
|
OP_CHECK_NULL_WITH_CONTEXT(context, yShape);
|
||||||
|
|
||||||
|
// 填充输出shape大小
|
||||||
|
auto xShapeSize = xShape->GetDimNum();
|
||||||
|
yShape->SetDimNum(xShapeSize);
|
||||||
|
for (size_t i = 0; i < xShapeSize; i++) {
|
||||||
|
int64_t dim = xShape->GetDim(i);
|
||||||
|
yShape->SetDim(i, dim);
|
||||||
|
}
|
||||||
|
|
||||||
|
// OPS_LOG_D(context->GetNodeName(), "End to do InferShapeCausalConv1d");
|
||||||
|
return GRAPH_SUCCESS;
|
||||||
|
}
|
||||||
|
|
||||||
|
IMPL_OP_INFERSHAPE(CausalConv1d).InferShape(InferShapeCausalConv1d);
|
||||||
|
} // namespace ops
|
||||||
365
csrc/causal_conv1d/op_host/causal_conv1d_tiling.cpp
Normal file
365
csrc/causal_conv1d/op_host/causal_conv1d_tiling.cpp
Normal file
@@ -0,0 +1,365 @@
|
|||||||
|
/**
|
||||||
|
* This program is free software, you can redistribute it and/or modify it.
|
||||||
|
* Copyright (c) 2025 Huawei Technologies Co., Ltd.
|
||||||
|
* This file is a part of the CANN Open Software.
|
||||||
|
* Licensed under CANN Open Software License Agreement Version 2.0 (the "License").
|
||||||
|
* Please refer to the License for details. You may not use this file except in compliance with the License.
|
||||||
|
* THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, INCLUDING
|
||||||
|
* BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE.
|
||||||
|
* See LICENSE in the root of the software repository for the full text of the License.
|
||||||
|
*/
|
||||||
|
|
||||||
|
/*!
|
||||||
|
* \file causal_conv1d_tiling.cpp
|
||||||
|
* \brief
|
||||||
|
*/
|
||||||
|
|
||||||
|
// #include "error_log.h"
|
||||||
|
#include "log/ops_log.h"
|
||||||
|
#include "../tiling_base/tiling_templates_registry.h"
|
||||||
|
#include "../tiling_base/tiling_util.h"
|
||||||
|
#include "math_util.h"
|
||||||
|
#include "causal_conv1d_tiling.h"
|
||||||
|
#include "../op_kernel/causal_conv1d_tiling_key.h"
|
||||||
|
|
||||||
|
#include <set>
|
||||||
|
#include <limits>
|
||||||
|
|
||||||
|
namespace optiling {
|
||||||
|
|
||||||
|
using namespace Ops::Transformer::OpTiling;
|
||||||
|
|
||||||
|
constexpr uint32_t X_INDEX = 0;
|
||||||
|
constexpr uint32_t WEIGHT_INDEX = 1;
|
||||||
|
constexpr uint32_t BIAS_INDEX = 2;
|
||||||
|
constexpr uint32_t CONV_STATES_INDEX = 3;
|
||||||
|
constexpr uint32_t QUERY_START_LOC_INDEX = 4;
|
||||||
|
constexpr uint32_t CACHE_INDICES_INDEX = 5;
|
||||||
|
constexpr uint32_t HAS_INITIAL_STATE_INDEX = 6;
|
||||||
|
|
||||||
|
constexpr int32_t ATTR_ACTIVATION_MODE_INDEX = 0;
|
||||||
|
constexpr int32_t ATTR_PAD_SLOT_ID_INDEX = 1;
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
struct DimTileChoice {
|
||||||
|
int64_t dimTileSize = 0;
|
||||||
|
int64_t blocksPerSeq = 0;
|
||||||
|
int64_t gridSize = 0;
|
||||||
|
};
|
||||||
|
|
||||||
|
static inline DimTileChoice ChooseDimTileSize(gert::TilingContext* context, int64_t batch, int64_t dim, uint32_t coreNum)
|
||||||
|
{
|
||||||
|
|
||||||
|
const int64_t candidates[] = {4096, 2048, 1024, 512,384};
|
||||||
|
DimTileChoice bestOver;
|
||||||
|
int64_t bestOverGap = std::numeric_limits<int64_t>::max();
|
||||||
|
DimTileChoice bestUnder;
|
||||||
|
|
||||||
|
for (int64_t dimTileSize : candidates) {
|
||||||
|
if (dim % dimTileSize != 0) {
|
||||||
|
continue;
|
||||||
|
}
|
||||||
|
const int64_t blocksPerSeq = dim / dimTileSize;
|
||||||
|
const int64_t gridSize = batch * blocksPerSeq;
|
||||||
|
if (gridSize <= 0) {
|
||||||
|
continue;
|
||||||
|
}
|
||||||
|
|
||||||
|
if (gridSize >= static_cast<int64_t>(coreNum)) {
|
||||||
|
const int64_t gap = gridSize - static_cast<int64_t>(coreNum);
|
||||||
|
if (gap < bestOverGap) {
|
||||||
|
bestOver.dimTileSize = dimTileSize;
|
||||||
|
bestOver.blocksPerSeq = blocksPerSeq;
|
||||||
|
bestOver.gridSize = gridSize;
|
||||||
|
bestOverGap = gap;
|
||||||
|
}
|
||||||
|
} else if (gridSize > bestUnder.gridSize ||
|
||||||
|
(gridSize == bestUnder.gridSize && dimTileSize < bestUnder.dimTileSize)) {
|
||||||
|
bestUnder.dimTileSize = dimTileSize;
|
||||||
|
bestUnder.blocksPerSeq = blocksPerSeq;
|
||||||
|
bestUnder.gridSize = gridSize;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
DimTileChoice result = (bestOver.dimTileSize != 0) ? bestOver : bestUnder;
|
||||||
|
|
||||||
|
return result;
|
||||||
|
}
|
||||||
|
|
||||||
|
static ge::graphStatus GetPlatformInfo(gert::TilingContext* context, uint64_t& ubSize, uint32_t& coreNum)
|
||||||
|
{
|
||||||
|
auto compileInfoPtr = context->GetCompileInfo<CausalConv1dCompileInfo>();
|
||||||
|
if (compileInfoPtr != nullptr && compileInfoPtr->coreNum != 0 && compileInfoPtr->ubSize != 0) {
|
||||||
|
ubSize = compileInfoPtr->ubSize;
|
||||||
|
coreNum = compileInfoPtr->coreNum;
|
||||||
|
return ge::GRAPH_SUCCESS;
|
||||||
|
}
|
||||||
|
fe::PlatFormInfos* platformInfoPtr = context->GetPlatformInfo();
|
||||||
|
OP_CHECK_NULL_WITH_CONTEXT(context, platformInfoPtr);
|
||||||
|
auto ascendcPlatform = platform_ascendc::PlatformAscendC(platformInfoPtr);
|
||||||
|
coreNum = ascendcPlatform.GetCoreNumAiv();
|
||||||
|
if(coreNum == 0) {
|
||||||
|
return ge::GRAPH_FAILED;
|
||||||
|
}
|
||||||
|
ascendcPlatform.GetCoreMemSize(platform_ascendc::CoreMemType::UB, ubSize);
|
||||||
|
if(ubSize == 0) {
|
||||||
|
return ge::GRAPH_FAILED;
|
||||||
|
}
|
||||||
|
return ge::GRAPH_SUCCESS;
|
||||||
|
}
|
||||||
|
|
||||||
|
static ge::graphStatus GetWorkspaceSize(gert::TilingContext* context)
|
||||||
|
{
|
||||||
|
size_t* currentWorkspace = context->GetWorkspaceSizes(1);
|
||||||
|
OP_CHECK_NULL_WITH_CONTEXT(context, currentWorkspace);
|
||||||
|
currentWorkspace[0] = 0;
|
||||||
|
return ge::GRAPH_SUCCESS;
|
||||||
|
}
|
||||||
|
|
||||||
|
static ge::graphStatus GetAttrsInfo(gert::TilingContext* context, int64_t& activationMode, int64_t& padSlotId)
|
||||||
|
{
|
||||||
|
auto attrs = context->GetAttrs();
|
||||||
|
OP_CHECK_NULL_WITH_CONTEXT(context, attrs);
|
||||||
|
|
||||||
|
const int64_t* activationModePtr = attrs->GetAttrPointer<int64_t>(ATTR_ACTIVATION_MODE_INDEX);
|
||||||
|
OP_CHECK_NULL_WITH_CONTEXT(context, activationModePtr);
|
||||||
|
activationMode = *activationModePtr;
|
||||||
|
if(activationMode != 0 && activationMode != 1){
|
||||||
|
return ge::GRAPH_FAILED;
|
||||||
|
}
|
||||||
|
const int64_t* padSlotIdPtr = attrs->GetAttrPointer<int64_t>(ATTR_PAD_SLOT_ID_INDEX);
|
||||||
|
OP_CHECK_NULL_WITH_CONTEXT(context, padSlotIdPtr);
|
||||||
|
padSlotId = *padSlotIdPtr;
|
||||||
|
|
||||||
|
return ge::GRAPH_SUCCESS;
|
||||||
|
}
|
||||||
|
static ge::graphStatus GetShapeDtypeInfo(gert::TilingContext* context, CausalConv1dTilingData& tiling)
|
||||||
|
{
|
||||||
|
auto xShapePtr = context->GetInputShape(X_INDEX);
|
||||||
|
OP_CHECK_NULL_WITH_CONTEXT(context, xShapePtr);
|
||||||
|
auto xShape = EnsureNotScalar(xShapePtr->GetStorageShape());
|
||||||
|
|
||||||
|
int64_t dim = 0;
|
||||||
|
int64_t cuSeqlen = 0;
|
||||||
|
int64_t seqLen = 0;
|
||||||
|
int64_t batch = 0;
|
||||||
|
int64_t inputMode = 0;
|
||||||
|
|
||||||
|
if (xShape.GetDimNum() == 2) {
|
||||||
|
inputMode = 0;
|
||||||
|
cuSeqlen = xShape.GetDim(0);
|
||||||
|
dim = xShape.GetDim(1);
|
||||||
|
seqLen = 0;
|
||||||
|
if(dim <= 0 || cuSeqlen < 0){
|
||||||
|
return ge::GRAPH_FAILED;
|
||||||
|
}
|
||||||
|
|
||||||
|
} else if (xShape.GetDimNum() == 3) {
|
||||||
|
inputMode = 1;
|
||||||
|
batch = xShape.GetDim(0);
|
||||||
|
seqLen = xShape.GetDim(1);
|
||||||
|
dim = xShape.GetDim(2);
|
||||||
|
cuSeqlen = batch * seqLen;
|
||||||
|
if(batch <= 0 || dim <= 0 || seqLen <= 0){
|
||||||
|
return ge::GRAPH_FAILED;
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
return ge::GRAPH_FAILED;
|
||||||
|
}
|
||||||
|
|
||||||
|
auto wShapePtr = context->GetInputShape(WEIGHT_INDEX);
|
||||||
|
OP_CHECK_NULL_WITH_CONTEXT(context, wShapePtr);
|
||||||
|
auto wShape = EnsureNotScalar(wShapePtr->GetStorageShape());
|
||||||
|
if(wShape.GetDimNum() != 2){
|
||||||
|
return ge::GRAPH_FAILED;
|
||||||
|
}
|
||||||
|
const int64_t width = wShape.GetDim(0);
|
||||||
|
const int64_t wDim = wShape.GetDim(1);
|
||||||
|
if(wDim != dim){
|
||||||
|
return ge::GRAPH_FAILED;
|
||||||
|
}
|
||||||
|
if(width != 4){
|
||||||
|
return ge::GRAPH_FAILED;
|
||||||
|
}
|
||||||
|
|
||||||
|
auto sShapePtr = context->GetInputShape(CONV_STATES_INDEX);
|
||||||
|
OP_CHECK_NULL_WITH_CONTEXT(context, sShapePtr);
|
||||||
|
auto sShape = EnsureNotScalar(sShapePtr->GetStorageShape());
|
||||||
|
if(sShape.GetDimNum() != 3){
|
||||||
|
return ge::GRAPH_FAILED;
|
||||||
|
}
|
||||||
|
const int64_t numCacheLines = sShape.GetDim(0);
|
||||||
|
const int64_t stateLen = sShape.GetDim(1);
|
||||||
|
const int64_t sDim = sShape.GetDim(2);
|
||||||
|
if(numCacheLines <= 0){
|
||||||
|
return ge::GRAPH_FAILED;}
|
||||||
|
if(sDim != dim){
|
||||||
|
return ge::GRAPH_FAILED;}
|
||||||
|
if(stateLen < (width - 1)){
|
||||||
|
return ge::GRAPH_FAILED;}
|
||||||
|
|
||||||
|
auto qslShapePtr = context->GetInputShape(QUERY_START_LOC_INDEX);
|
||||||
|
OP_CHECK_NULL_WITH_CONTEXT(context, qslShapePtr);
|
||||||
|
auto qslShape = EnsureNotScalar(qslShapePtr->GetStorageShape());
|
||||||
|
if(qslShape.GetDimNum() != 1){
|
||||||
|
return ge::GRAPH_FAILED;}
|
||||||
|
const int64_t qslSize = qslShape.GetDim(0);
|
||||||
|
if(qslSize < 1){
|
||||||
|
return ge::GRAPH_FAILED;}
|
||||||
|
|
||||||
|
if (inputMode == 0) {
|
||||||
|
batch = qslSize - 1;
|
||||||
|
}
|
||||||
|
|
||||||
|
if (inputMode == 1) {
|
||||||
|
if(qslSize != batch + 1){
|
||||||
|
return ge::GRAPH_FAILED;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
auto ciShapePtr = context->GetInputShape(CACHE_INDICES_INDEX);
|
||||||
|
OP_CHECK_NULL_WITH_CONTEXT(context, ciShapePtr);
|
||||||
|
auto ciShape = EnsureNotScalar(ciShapePtr->GetStorageShape());
|
||||||
|
if(ciShape.GetDimNum() != 1){return ge::GRAPH_FAILED;}
|
||||||
|
if(ciShape.GetDim(0) != batch){return ge::GRAPH_FAILED;}
|
||||||
|
|
||||||
|
auto hisShapePtr = context->GetInputShape(HAS_INITIAL_STATE_INDEX);
|
||||||
|
OP_CHECK_NULL_WITH_CONTEXT(context, hisShapePtr);
|
||||||
|
auto hisShape = EnsureNotScalar(hisShapePtr->GetStorageShape());
|
||||||
|
if(hisShape.GetDimNum() != 1){
|
||||||
|
return ge::GRAPH_FAILED;}
|
||||||
|
if(hisShape.GetDim(0) != batch){
|
||||||
|
return ge::GRAPH_FAILED;}
|
||||||
|
|
||||||
|
tiling.set_hasBias(0);
|
||||||
|
auto biasShapePtr = context->GetOptionalInputShape(BIAS_INDEX);
|
||||||
|
if (biasShapePtr != nullptr && biasShapePtr->GetStorageShape().GetDimNum() != 0) {
|
||||||
|
auto biasShape = EnsureNotScalar(biasShapePtr->GetStorageShape());
|
||||||
|
if(biasShape.GetDimNum() != 1){
|
||||||
|
return ge::GRAPH_FAILED;}
|
||||||
|
if(biasShape.GetDim(0) != dim){
|
||||||
|
return ge::GRAPH_FAILED;}
|
||||||
|
tiling.set_hasBias(1);
|
||||||
|
}
|
||||||
|
|
||||||
|
const std::set<ge::DataType> supportedXDtype = {ge::DT_BF16, ge::DT_FLOAT16};
|
||||||
|
auto xDesc = context->GetInputDesc(X_INDEX);
|
||||||
|
OP_CHECK_NULL_WITH_CONTEXT(context, xDesc);
|
||||||
|
const ge::DataType xDtype = xDesc->GetDataType();
|
||||||
|
if(supportedXDtype.count(xDtype) == 0){
|
||||||
|
return ge::GRAPH_FAILED;}
|
||||||
|
|
||||||
|
auto wDesc = context->GetInputDesc(WEIGHT_INDEX);
|
||||||
|
OP_CHECK_NULL_WITH_CONTEXT(context, wDesc);
|
||||||
|
if(wDesc->GetDataType() != xDtype){
|
||||||
|
return ge::GRAPH_FAILED;}
|
||||||
|
|
||||||
|
if (tiling.get_hasBias() == 1) {
|
||||||
|
auto biasDesc = context->GetOptionalInputDesc(BIAS_INDEX);
|
||||||
|
OP_CHECK_NULL_WITH_CONTEXT(context, biasDesc);
|
||||||
|
if(biasDesc->GetDataType() != xDtype){
|
||||||
|
return ge::GRAPH_FAILED;}
|
||||||
|
}
|
||||||
|
|
||||||
|
auto sDesc = context->GetInputDesc(CONV_STATES_INDEX);
|
||||||
|
OP_CHECK_NULL_WITH_CONTEXT(context, sDesc);
|
||||||
|
if(sDesc->GetDataType() != xDtype){
|
||||||
|
return ge::GRAPH_FAILED;}
|
||||||
|
|
||||||
|
auto qslDesc = context->GetInputDesc(QUERY_START_LOC_INDEX);
|
||||||
|
OP_CHECK_NULL_WITH_CONTEXT(context, qslDesc);
|
||||||
|
if(qslDesc->GetDataType() != ge::DT_INT32){
|
||||||
|
return ge::GRAPH_FAILED;}
|
||||||
|
|
||||||
|
auto ciDesc = context->GetInputDesc(CACHE_INDICES_INDEX);
|
||||||
|
OP_CHECK_NULL_WITH_CONTEXT(context, ciDesc);
|
||||||
|
if(ciDesc->GetDataType() != ge::DT_INT32){
|
||||||
|
return ge::GRAPH_FAILED;}
|
||||||
|
|
||||||
|
auto hisDesc = context->GetInputDesc(HAS_INITIAL_STATE_INDEX);
|
||||||
|
OP_CHECK_NULL_WITH_CONTEXT(context, hisDesc);
|
||||||
|
if(hisDesc->GetDataType() != ge::DT_BOOL){
|
||||||
|
return ge::GRAPH_FAILED;}
|
||||||
|
|
||||||
|
tiling.set_dim(dim);
|
||||||
|
tiling.set_cuSeqlen(cuSeqlen);
|
||||||
|
tiling.set_seqLen(seqLen);
|
||||||
|
tiling.set_inputMode(inputMode);
|
||||||
|
tiling.set_width(width);
|
||||||
|
tiling.set_stateLen(stateLen);
|
||||||
|
tiling.set_numCacheLines(numCacheLines);
|
||||||
|
tiling.set_batch(batch);
|
||||||
|
return ge::GRAPH_SUCCESS;
|
||||||
|
}
|
||||||
|
|
||||||
|
static ge::graphStatus CausalConv1dTilingFunc(gert::TilingContext* context)
|
||||||
|
{
|
||||||
|
uint64_t ubSize;
|
||||||
|
uint32_t coreNum;
|
||||||
|
if( GetPlatformInfo(context, ubSize, coreNum) != ge::GRAPH_SUCCESS){
|
||||||
|
return ge::GRAPH_FAILED;
|
||||||
|
}
|
||||||
|
|
||||||
|
if(GetWorkspaceSize(context) != ge::GRAPH_SUCCESS){
|
||||||
|
return ge::GRAPH_FAILED;
|
||||||
|
}
|
||||||
|
CausalConv1dTilingData tilingData;
|
||||||
|
|
||||||
|
int64_t activationMode = 0;
|
||||||
|
int64_t padSlotId = -1;
|
||||||
|
if(GetAttrsInfo(context, activationMode, padSlotId) != ge::GRAPH_SUCCESS){
|
||||||
|
return ge::GRAPH_FAILED;
|
||||||
|
}
|
||||||
|
tilingData.set_activationMode(activationMode);
|
||||||
|
tilingData.set_padSlotId(padSlotId);
|
||||||
|
|
||||||
|
if( GetShapeDtypeInfo(context, tilingData) != ge::GRAPH_SUCCESS){
|
||||||
|
return ge::GRAPH_FAILED;
|
||||||
|
}
|
||||||
|
|
||||||
|
const int64_t dim = tilingData.get_dim();
|
||||||
|
const int64_t batch = tilingData.get_batch();
|
||||||
|
if(dim <= 0 || batch <= 0){
|
||||||
|
return ge::GRAPH_FAILED;
|
||||||
|
}
|
||||||
|
const DimTileChoice choice = ChooseDimTileSize(context, batch, dim, coreNum);
|
||||||
|
const uint32_t blockDim = (choice.gridSize < static_cast<int64_t>(coreNum))
|
||||||
|
? static_cast<uint32_t>(choice.gridSize)
|
||||||
|
: coreNum;
|
||||||
|
context->SetBlockDim(blockDim);
|
||||||
|
tilingData.set_dimTileSize(choice.dimTileSize);
|
||||||
|
tilingData.set_blocksPerSeq(choice.blocksPerSeq);
|
||||||
|
|
||||||
|
const uint64_t tilingKey = GET_TPL_TILING_KEY(CAUSAL_CONV1D_TPL_SCH_MODE_DEFAULT);
|
||||||
|
context->SetTilingKey(tilingKey);
|
||||||
|
|
||||||
|
tilingData.SaveToBuffer(context->GetRawTilingData()->GetData(), context->GetRawTilingData()->GetCapacity());
|
||||||
|
context->GetRawTilingData()->SetDataSize(tilingData.GetDataSize());
|
||||||
|
return ge::GRAPH_SUCCESS;
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
static ge::graphStatus TilingParseForCausalConv1d(gert::TilingParseContext* context)
|
||||||
|
{
|
||||||
|
auto platformInfoPtr = context->GetPlatformInfo();
|
||||||
|
OP_CHECK_NULL_WITH_CONTEXT(context, platformInfoPtr);
|
||||||
|
auto compileInfoPtr = context->GetCompiledInfo<CausalConv1dCompileInfo>();
|
||||||
|
OP_CHECK_NULL_WITH_CONTEXT(context, compileInfoPtr);
|
||||||
|
|
||||||
|
auto ascendcPlatform = platform_ascendc::PlatformAscendC(platformInfoPtr);
|
||||||
|
compileInfoPtr->coreNum = static_cast<uint32_t>(ascendcPlatform.GetCoreNumAiv());
|
||||||
|
if(compileInfoPtr->coreNum == 0){
|
||||||
|
return ge::GRAPH_FAILED;
|
||||||
|
}
|
||||||
|
ascendcPlatform.GetCoreMemSize(platform_ascendc::CoreMemType::UB, compileInfoPtr->ubSize);
|
||||||
|
if(compileInfoPtr->ubSize == 0){
|
||||||
|
return ge::GRAPH_FAILED;
|
||||||
|
}
|
||||||
|
return ge::GRAPH_SUCCESS;
|
||||||
|
}
|
||||||
|
|
||||||
|
IMPL_OP_OPTILING(CausalConv1d)
|
||||||
|
.Tiling(CausalConv1dTilingFunc)
|
||||||
|
.TilingParse<CausalConv1dCompileInfo>(TilingParseForCausalConv1d);
|
||||||
|
} // namespace optiling
|
||||||
60
csrc/causal_conv1d/op_host/causal_conv1d_tiling.h
Normal file
60
csrc/causal_conv1d/op_host/causal_conv1d_tiling.h
Normal file
@@ -0,0 +1,60 @@
|
|||||||
|
/**
|
||||||
|
* This program is free software, you can redistribute it and/or modify it.
|
||||||
|
* Copyright (c) 2025 Huawei Technologies Co., Ltd.
|
||||||
|
* This file is a part of the CANN Open Software.
|
||||||
|
* Licensed under CANN Open Software License Agreement Version 2.0 (the "License").
|
||||||
|
* Please refer to the License for details. You may not use this file except in compliance with the License.
|
||||||
|
* THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, INCLUDING
|
||||||
|
* BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE.
|
||||||
|
* See LICENSE in the root of the software repository for the full text of the License.
|
||||||
|
*/
|
||||||
|
|
||||||
|
/*!
|
||||||
|
* \file causal_conv1d_tiling_data.h
|
||||||
|
* \brief
|
||||||
|
*/
|
||||||
|
|
||||||
|
#ifndef ASCEND_OPS_CAUSAL_CONV1D_TILING_DATA_H
|
||||||
|
#define ASCEND_OPS_CAUSAL_CONV1D_TILING_DATA_H
|
||||||
|
|
||||||
|
#include <cstdint>
|
||||||
|
|
||||||
|
// #include "register/tilingdata_base.h"
|
||||||
|
// #include "tiling/tiling_api.h"
|
||||||
|
#include "register/tilingdata_base.h"
|
||||||
|
#include "error_log.h"
|
||||||
|
#include "register/op_impl_registry.h"
|
||||||
|
#include "tiling/platform/platform_ascendc.h"
|
||||||
|
#include "platform/platform_infos_def.h"
|
||||||
|
namespace optiling {
|
||||||
|
|
||||||
|
BEGIN_TILING_DATA_DEF(CausalConv1dTilingData)
|
||||||
|
TILING_DATA_FIELD_DEF(int64_t, dim);
|
||||||
|
TILING_DATA_FIELD_DEF(int64_t, cuSeqlen);
|
||||||
|
TILING_DATA_FIELD_DEF(int64_t, seqLen);
|
||||||
|
TILING_DATA_FIELD_DEF(int64_t, inputMode);
|
||||||
|
|
||||||
|
TILING_DATA_FIELD_DEF(int64_t, width);
|
||||||
|
|
||||||
|
TILING_DATA_FIELD_DEF(int64_t, stateLen);
|
||||||
|
TILING_DATA_FIELD_DEF(int64_t, numCacheLines);
|
||||||
|
|
||||||
|
TILING_DATA_FIELD_DEF(int64_t, batch);
|
||||||
|
|
||||||
|
TILING_DATA_FIELD_DEF(int64_t, activationMode);
|
||||||
|
TILING_DATA_FIELD_DEF(int64_t, padSlotId);
|
||||||
|
|
||||||
|
TILING_DATA_FIELD_DEF(int64_t, hasBias);
|
||||||
|
|
||||||
|
TILING_DATA_FIELD_DEF(int64_t, dimTileSize);
|
||||||
|
TILING_DATA_FIELD_DEF(int64_t, blocksPerSeq);
|
||||||
|
END_TILING_DATA_DEF;
|
||||||
|
struct CausalConv1dCompileInfo {
|
||||||
|
uint64_t ubSize = 0;
|
||||||
|
uint32_t coreNum = 0;
|
||||||
|
};
|
||||||
|
REGISTER_TILING_DATA_CLASS(CausalConv1d, CausalConv1dTilingData)
|
||||||
|
|
||||||
|
} // namespace optiling
|
||||||
|
|
||||||
|
#endif // ASCEND_OPS_CAUSAL_CONV1D_TILING_DATA_H
|
||||||
71
csrc/causal_conv1d/op_host/error_log.h
Normal file
71
csrc/causal_conv1d/op_host/error_log.h
Normal file
@@ -0,0 +1,71 @@
|
|||||||
|
#ifndef OPS_BUILT_IN_OP_TILING_ERROR_LOG_H_
|
||||||
|
#define OPS_BUILT_IN_OP_TILING_ERROR_LOG_H_
|
||||||
|
|
||||||
|
#include <string>
|
||||||
|
#include "toolchain/slog.h"
|
||||||
|
|
||||||
|
#define OP_LOGI(opname, ...)
|
||||||
|
#define OP_LOGW(opname, ...) \
|
||||||
|
do { \
|
||||||
|
printf("[WARN][%s] ", (opname), ##__VA_ARGS__); \
|
||||||
|
printf("\n"); \
|
||||||
|
} while (0)
|
||||||
|
|
||||||
|
#define OP_LOGE_WITHOUT_REPORT(opname, ...) \
|
||||||
|
do { \
|
||||||
|
printf("[ERRORx][%s] ", (opname), ##__VA_ARGS__); \
|
||||||
|
printf("\n"); \
|
||||||
|
} while (0)
|
||||||
|
|
||||||
|
#define OP_LOGE(opname, ...) \
|
||||||
|
do { \
|
||||||
|
printf("[ERROR][%s] ", (opname), ##__VA_ARGS__); \
|
||||||
|
printf("\n"); \
|
||||||
|
} while (0)
|
||||||
|
|
||||||
|
#define OP_LOGD(opname, ...)
|
||||||
|
|
||||||
|
namespace optiling {
|
||||||
|
|
||||||
|
#define VECTOR_INNER_ERR_REPORT_TILIING(op_name, err_msg, ...) \
|
||||||
|
do { \
|
||||||
|
OP_LOGE_WITHOUT_REPORT(op_name, err_msg, ##__VA_ARGS__); \
|
||||||
|
} while (0)
|
||||||
|
|
||||||
|
|
||||||
|
#define OP_CHECK_IF(cond, log_func, expr) \
|
||||||
|
do { \
|
||||||
|
if (cond) { \
|
||||||
|
log_func; \
|
||||||
|
expr; \
|
||||||
|
} \
|
||||||
|
} while (0)
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
#define OP_CHECK_NULL_WITH_CONTEXT(context, ptr) \
|
||||||
|
do { \
|
||||||
|
if ((ptr) == nullptr) { \
|
||||||
|
OP_LOGE(context->GetNodeType(), "%s is null", #ptr); \
|
||||||
|
return ge::GRAPH_FAILED; \
|
||||||
|
} \
|
||||||
|
} while (0)
|
||||||
|
|
||||||
|
} // namespace optiling
|
||||||
|
|
||||||
|
template <typename T>
|
||||||
|
T CeilAlign(T a, T b)
|
||||||
|
{
|
||||||
|
return (a + b - 1) / b * b;
|
||||||
|
}
|
||||||
|
|
||||||
|
template <typename T>
|
||||||
|
T CeilDiv(T a, T b)
|
||||||
|
{
|
||||||
|
if (b == 0) {
|
||||||
|
return a;
|
||||||
|
}
|
||||||
|
return (a + b - 1) / b;
|
||||||
|
}
|
||||||
|
|
||||||
|
#endif // OPS_BUILT_IN_OP_TILING_ERROR_LOG_H_
|
||||||
61
csrc/causal_conv1d/op_host/math_util.h
Normal file
61
csrc/causal_conv1d/op_host/math_util.h
Normal file
@@ -0,0 +1,61 @@
|
|||||||
|
/**
|
||||||
|
* Copyright (c) 2025 Huawei Technologies Co., Ltd.
|
||||||
|
* This program is free software, you can redistribute it and/or modify it under the terms and conditions of
|
||||||
|
* CANN Open Software License Agreement Version 2.0 (the "License").
|
||||||
|
* Please refer to the License for details. You may not use this file except in compliance with the License.
|
||||||
|
* THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED,
|
||||||
|
* INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE.
|
||||||
|
* See LICENSE in the root of the software repository for the full text of the License.
|
||||||
|
*/
|
||||||
|
|
||||||
|
/*!
|
||||||
|
* \file math_util.h
|
||||||
|
* \brief
|
||||||
|
*/
|
||||||
|
|
||||||
|
#ifndef TILING_MATMUL_MATH_UTIL_H
|
||||||
|
#define TILING_MATMUL_MATH_UTIL_H
|
||||||
|
|
||||||
|
#include <array>
|
||||||
|
#include <cstdint>
|
||||||
|
#include <vector>
|
||||||
|
#include <utility>
|
||||||
|
namespace matmul_tiling {
|
||||||
|
class MathUtil {
|
||||||
|
public:
|
||||||
|
static bool IsEqual(float leftValue, float rightValue);
|
||||||
|
template<typename T>
|
||||||
|
static auto CeilDivision(T num1, T num2) -> T
|
||||||
|
{
|
||||||
|
if (num2 == 0) {
|
||||||
|
return 0;
|
||||||
|
}
|
||||||
|
return static_cast<T>((static_cast<int64_t>(num1) + static_cast<int64_t>(num2) - 1) /
|
||||||
|
static_cast<int64_t>(num2));
|
||||||
|
}
|
||||||
|
template<typename T>
|
||||||
|
static auto Align(T num1, T num2) -> T
|
||||||
|
{
|
||||||
|
return CeilDivision(num1, num2) * num2;
|
||||||
|
}
|
||||||
|
static int32_t AlignDown(int32_t num1, int32_t num2);
|
||||||
|
static bool CheckMulOverflow(int32_t a, int32_t b, int32_t &c);
|
||||||
|
static int32_t MapShape(int32_t shape, bool roundUpFlag = true);
|
||||||
|
static void AddFactor(std::vector<int32_t> &dimsFactors, int32_t dim);
|
||||||
|
static void GetFactorCnt(const int32_t shape, int32_t &factorCnt, const int32_t factorStart,
|
||||||
|
const int32_t factorEnd);
|
||||||
|
static void GetFactorLayerCnt(const int32_t shape, int32_t &factorCnt, const int32_t factorStart,
|
||||||
|
const int32_t factorEnd);
|
||||||
|
static bool CheckFactorNumSatisfy(const int32_t dim);
|
||||||
|
static int32_t FindBestSingleCore(const int32_t oriShape, const int32_t mappedShape, const int32_t coreNum,
|
||||||
|
bool isKDim);
|
||||||
|
static void GetFactors(std::vector<int32_t> &factorList, int32_t srcNum, int32_t minFactor, int32_t maxFactor);
|
||||||
|
static void GetFactors(std::vector<int32_t> &factorList, int32_t srcNum, int32_t maxFactor);
|
||||||
|
static void GetBlockFactors(std::vector<int32_t> &factorList, const int32_t oriShape, const int32_t mpShape,
|
||||||
|
const int32_t coreNum, const int32_t maxNum);
|
||||||
|
static int32_t GetNonFactorMap(std::vector<int32_t> &factorList, int32_t srcNum, int32_t maxFactor);
|
||||||
|
static std::vector<std::pair<int, int>> GetFactorPairs(int32_t num);
|
||||||
|
static std::pair<int32_t, int32_t> DivideIntoMainAndTail(int32_t num, int32_t divisor);
|
||||||
|
};
|
||||||
|
} // namespace matmul_tiling
|
||||||
|
#endif // _MATH_UTIL_H_
|
||||||
31
csrc/causal_conv1d/op_host/tiling_util.cpp
Normal file
31
csrc/causal_conv1d/op_host/tiling_util.cpp
Normal file
@@ -0,0 +1,31 @@
|
|||||||
|
/**
|
||||||
|
* Copyright (c) 2025 Huawei Technologies Co., Ltd.
|
||||||
|
* This program is free software, you can redistribute it and/or modify it under the terms and conditions of
|
||||||
|
* CANN Open Software License Agreement Version 2.0 (the "License").
|
||||||
|
* Please refer to the License for details. You may not use this file except in compliance with the License.
|
||||||
|
* THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED,
|
||||||
|
* INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE.
|
||||||
|
* See LICENSE in the root of the software repository for the full text of the License.
|
||||||
|
*/
|
||||||
|
|
||||||
|
/*!
|
||||||
|
* \file tiling_util.cpp
|
||||||
|
* \brief
|
||||||
|
*/
|
||||||
|
|
||||||
|
#include "../tiling_base/tiling_util.h"
|
||||||
|
namespace Ops {
|
||||||
|
namespace Transformer {
|
||||||
|
namespace OpTiling {
|
||||||
|
static const gert::Shape g_vec_1_shape = {1};
|
||||||
|
|
||||||
|
const gert::Shape &EnsureNotScalar(const gert::Shape &inShape)
|
||||||
|
{
|
||||||
|
if (inShape.IsScalar()) {
|
||||||
|
return g_vec_1_shape;
|
||||||
|
}
|
||||||
|
return inShape;
|
||||||
|
}
|
||||||
|
} // namespace OpTiling
|
||||||
|
} // namespace Transformer
|
||||||
|
} // namespace Ops
|
||||||
58
csrc/causal_conv1d/op_kernel/causal_conv1d.cpp
Normal file
58
csrc/causal_conv1d/op_kernel/causal_conv1d.cpp
Normal file
@@ -0,0 +1,58 @@
|
|||||||
|
/**
|
||||||
|
* This program is free software, you can redistribute it and/or modify it.
|
||||||
|
* Copyright (c) 2025 Huawei Technologies Co., Ltd.
|
||||||
|
* This file is a part of the CANN Open Software.
|
||||||
|
* Licensed under CANN Open Software License Agreement Version 2.0 (the "License").
|
||||||
|
* Please refer to the License for details. You may not use this file except in compliance with the License.
|
||||||
|
* THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, INCLUDING
|
||||||
|
* BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE.
|
||||||
|
* See LICENSE in the root of the software repository for the full text of the License.
|
||||||
|
*/
|
||||||
|
|
||||||
|
/*!
|
||||||
|
* \file causal_conv1d.cpp
|
||||||
|
* \brief
|
||||||
|
*/
|
||||||
|
|
||||||
|
#include "causal_conv1d.h"
|
||||||
|
|
||||||
|
namespace {
|
||||||
|
|
||||||
|
template <typename T>
|
||||||
|
__aicore__ inline void RunCausalConv1d(GM_ADDR x, GM_ADDR weight, GM_ADDR bias, GM_ADDR convStates,
|
||||||
|
GM_ADDR queryStartLoc, GM_ADDR cacheIndices, GM_ADDR hasInitialState,
|
||||||
|
GM_ADDR y, const NsCausalConv1d::CausalConv1dTilingData* tilingData)
|
||||||
|
{
|
||||||
|
NsCausalConv1d::CausalConv1d<T> op;
|
||||||
|
op.Init(x, weight, bias, convStates, queryStartLoc, cacheIndices, hasInitialState, y, tilingData);
|
||||||
|
op.Process();
|
||||||
|
}
|
||||||
|
|
||||||
|
} // namespace
|
||||||
|
|
||||||
|
template <uint32_t schMode>
|
||||||
|
__global__ __aicore__ void causal_conv1d(GM_ADDR x, GM_ADDR weight, GM_ADDR bias, GM_ADDR convStates,
|
||||||
|
GM_ADDR queryStartLoc, GM_ADDR cacheIndices, GM_ADDR hasInitialState,
|
||||||
|
GM_ADDR y, GM_ADDR workspace, GM_ADDR tiling)
|
||||||
|
{
|
||||||
|
REGISTER_TILING_DEFAULT( NsCausalConv1d::CausalConv1dTilingData);
|
||||||
|
// GET_TILING_DATA_WITH_STRUCT( NsCausalConv1d::CausalConv1dTilingData, tilingData, tiling);
|
||||||
|
GET_TILING_DATA(tilingData, tiling);
|
||||||
|
#if defined(ORIG_DTYPE_X)
|
||||||
|
#if (ORIG_DTYPE_X == DT_FLOAT16)
|
||||||
|
RunCausalConv1d<half>(x, weight, bias, convStates, queryStartLoc, cacheIndices, hasInitialState, y, &tilingData);
|
||||||
|
#elif (ORIG_DTYPE_X == DT_BF16)
|
||||||
|
RunCausalConv1d<bfloat16_t>(x, weight, bias, convStates, queryStartLoc, cacheIndices, hasInitialState, y, &tilingData);
|
||||||
|
#elif (ORIG_DTYPE_X == DT_FLOAT)
|
||||||
|
RunCausalConv1d<float>(x, weight, bias, convStates, queryStartLoc, cacheIndices, hasInitialState, y, &tilingData);
|
||||||
|
#endif
|
||||||
|
#else
|
||||||
|
#if (DTYPE_X == DT_FLOAT16)
|
||||||
|
RunCausalConv1d<half>(x, weight, bias, convStates, queryStartLoc, cacheIndices, hasInitialState, y, &tilingData);
|
||||||
|
#elif (DTYPE_X == DT_BF16)
|
||||||
|
RunCausalConv1d<bfloat16_t>(x, weight, bias, convStates, queryStartLoc, cacheIndices, hasInitialState, y, &tilingData);
|
||||||
|
#elif (DTYPE_X == DT_FLOAT)
|
||||||
|
RunCausalConv1d<float>(x, weight, bias, convStates, queryStartLoc, cacheIndices, hasInitialState, y, &tilingData);
|
||||||
|
#endif
|
||||||
|
#endif
|
||||||
|
}
|
||||||
436
csrc/causal_conv1d/op_kernel/causal_conv1d.h
Normal file
436
csrc/causal_conv1d/op_kernel/causal_conv1d.h
Normal file
@@ -0,0 +1,436 @@
|
|||||||
|
/**
|
||||||
|
* This program is free software, you can redistribute it and/or modify it.
|
||||||
|
* Copyright (c) 2025 Huawei Technologies Co., Ltd.
|
||||||
|
* This file is a part of the CANN Open Software.
|
||||||
|
* Licensed under CANN Open Software License Agreement Version 2.0 (the "License").
|
||||||
|
* Please refer to the License for details. You may not use this file except in compliance with the License.
|
||||||
|
* THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, INCLUDING
|
||||||
|
* BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE.
|
||||||
|
* See LICENSE in the root of the software repository for the full text of the License.
|
||||||
|
*/
|
||||||
|
|
||||||
|
/*!
|
||||||
|
* \file causal_conv1d.h
|
||||||
|
* \brief CausalConv1D (prefill/extend) AscendC kernel implementation.
|
||||||
|
*/
|
||||||
|
|
||||||
|
#ifndef CAUSAL_CONV1D_H
|
||||||
|
#define CAUSAL_CONV1D_H
|
||||||
|
|
||||||
|
#include "kernel_operator.h"
|
||||||
|
// #include "kernel_tiling/kernel_tiling.h"
|
||||||
|
#include "causal_conv1d_tiling_key.h"
|
||||||
|
#include "causal_conv1d_common.h"
|
||||||
|
|
||||||
|
// #define ENABLE_CAUSAL_CONV1D_DEBUG
|
||||||
|
|
||||||
|
// #ifdef ENABLE_CAUSAL_CONV1D_DEBUG
|
||||||
|
// #define CCONV_PRINTF(fmt, ...) printf(fmt, ##__VA_ARGS__)
|
||||||
|
// #else
|
||||||
|
// #define CCONV_PRINTF(fmt, ...)
|
||||||
|
// #endif
|
||||||
|
|
||||||
|
// #define CCONV_PRINT_IF(cond, fmt, ...) \
|
||||||
|
// do { \
|
||||||
|
// if (cond) { \
|
||||||
|
// CCONV_PRINTF(fmt, ##__VA_ARGS__); \
|
||||||
|
// } \
|
||||||
|
// } while (0)
|
||||||
|
|
||||||
|
// #ifdef ENABLE_CAUSAL_CONV1D_DEBUG
|
||||||
|
|
||||||
|
// #define CCONV_DUMP_TENSOR_IF(cond, tensor, size) \
|
||||||
|
// do { \
|
||||||
|
// if (cond) { \
|
||||||
|
// DumpTensor(tensor, __LINE__, size); \
|
||||||
|
// } \
|
||||||
|
// } while (0)
|
||||||
|
// #else
|
||||||
|
constexpr int32_t CCONV_DBG_SEQ = -1;
|
||||||
|
constexpr int32_t CCONV_DBG_C0 = -1;
|
||||||
|
constexpr int32_t CCONV_DBG_MAX_TOKENS = 0;
|
||||||
|
constexpr int32_t CCONV_DBG_VERBOSE_TOKENS = 0;
|
||||||
|
constexpr int32_t CCONV_DBG_DUMP_SIZE = 0;
|
||||||
|
constexpr bool CCONV_DBG_PRINT_SYNC = false;
|
||||||
|
constexpr bool CCONV_DBG_DUMP_WEIGHTS = false;
|
||||||
|
constexpr bool CCONV_DBG_DUMP_BIAS = false;
|
||||||
|
constexpr bool CCONV_DBG_DUMP_INIT_RING = false;
|
||||||
|
constexpr bool CCONV_DBG_DUMP_RUNSEQ = false;
|
||||||
|
constexpr bool CCONV_DBG_DUMP_PREFETCH = false;
|
||||||
|
constexpr bool CCONV_DBG_DUMP_STATE = false;
|
||||||
|
|
||||||
|
// #define CCONV_DUMP_TENSOR_IF(cond, tensor, size) \
|
||||||
|
// do { \
|
||||||
|
// } while (0)
|
||||||
|
// #endif
|
||||||
|
using namespace AscendC;
|
||||||
|
namespace NsCausalConv1d {
|
||||||
|
using namespace NsCausalConv1dCommon;
|
||||||
|
|
||||||
|
#ifndef CAUSAL_CONV1D_TILING_DATA_H_
|
||||||
|
#define CAUSAL_CONV1D_TILING_DATA_H_
|
||||||
|
|
||||||
|
struct CausalConv1dTilingData {
|
||||||
|
int64_t dim;
|
||||||
|
int64_t cuSeqlen;
|
||||||
|
int64_t seqLen;
|
||||||
|
int64_t inputMode;
|
||||||
|
|
||||||
|
int64_t width;
|
||||||
|
|
||||||
|
int64_t stateLen;
|
||||||
|
int64_t numCacheLines;
|
||||||
|
|
||||||
|
int64_t batch;
|
||||||
|
|
||||||
|
// attrs
|
||||||
|
int64_t activationMode; // 0: none, 1: silu/swish
|
||||||
|
int64_t padSlotId; // default -1
|
||||||
|
|
||||||
|
// optional inputs
|
||||||
|
int64_t hasBias; // 0/1
|
||||||
|
|
||||||
|
// Channel-wise tiling
|
||||||
|
int64_t dimTileSize;
|
||||||
|
int64_t blocksPerSeq;
|
||||||
|
};
|
||||||
|
#endif // CAUSAL_CONV1D_TILING_DATA_H_
|
||||||
|
|
||||||
|
template <typename T>
|
||||||
|
class CausalConv1d
|
||||||
|
{
|
||||||
|
public:
|
||||||
|
__aicore__ inline CausalConv1d() = default;
|
||||||
|
|
||||||
|
__aicore__ inline void Init(GM_ADDR x, GM_ADDR weight, GM_ADDR bias, GM_ADDR convStates, GM_ADDR queryStartLoc,
|
||||||
|
GM_ADDR cacheIndices, GM_ADDR hasInitialState, GM_ADDR y
|
||||||
|
,
|
||||||
|
const CausalConv1dTilingData* tilingData);
|
||||||
|
__aicore__ inline void Process();
|
||||||
|
|
||||||
|
private:
|
||||||
|
__aicore__ inline void LoadWeightAndBias(int32_t c0, int32_t dimTileSize, bool dbg);
|
||||||
|
__aicore__ inline void InitRing(int32_t cacheIdx, bool hasInit, int32_t start, int32_t len,
|
||||||
|
int32_t c0, int32_t dimTileSize, int32_t dim, bool dbg);
|
||||||
|
__aicore__ inline void RunSeq(int32_t start, int32_t len, int32_t c0, int32_t dimTileSize, int32_t dim, bool dbg);
|
||||||
|
__aicore__ inline void WriteBackState(int32_t cacheIdx, int32_t len, int32_t c0,
|
||||||
|
int32_t dimTileSize, int32_t dim, bool dbg);
|
||||||
|
__aicore__ inline void AllocEvents();
|
||||||
|
__aicore__ inline void ReleaseEvents();
|
||||||
|
|
||||||
|
private:
|
||||||
|
TPipe pipe;
|
||||||
|
TBuf<QuePosition::VECIN> inBuf;
|
||||||
|
TBuf<QuePosition::VECOUT> outBuf;
|
||||||
|
TBuf<QuePosition::VECCALC> calcBuf;
|
||||||
|
|
||||||
|
TEventID tempVToMte2Event_;
|
||||||
|
TEventID tempMte2ToVEvent_;
|
||||||
|
TEventID inputMte2ToVEvent_;
|
||||||
|
TEventID outMte3ToVEvent_[2];
|
||||||
|
TEventID outVToMte3Event_[2];
|
||||||
|
|
||||||
|
GlobalTensor<T> xGm;
|
||||||
|
GlobalTensor<T> weightGm;
|
||||||
|
GlobalTensor<T> biasGm;
|
||||||
|
GlobalTensor<T> convStatesGm;
|
||||||
|
GlobalTensor<int32_t> queryStartLocGm;
|
||||||
|
GlobalTensor<int32_t> cacheIndicesGm;
|
||||||
|
GlobalTensor<bool> hasInitialStateGm;
|
||||||
|
GlobalTensor<T> yGm;
|
||||||
|
|
||||||
|
const CausalConv1dTilingData* tilingData_ {nullptr};
|
||||||
|
};
|
||||||
|
|
||||||
|
template <typename T>
|
||||||
|
__aicore__ inline void CausalConv1d<T>::Init(GM_ADDR x, GM_ADDR weight, GM_ADDR bias, GM_ADDR convStates,
|
||||||
|
GM_ADDR queryStartLoc, GM_ADDR cacheIndices, GM_ADDR hasInitialState,
|
||||||
|
GM_ADDR y
|
||||||
|
, const CausalConv1dTilingData* tilingData)
|
||||||
|
{
|
||||||
|
// REGISTER_TILING_DEFAULT(CausalConv1dTilingData);
|
||||||
|
// auto tiling = (__gm__ CausalConv1dTilingData*)tilingGM;
|
||||||
|
// GET_TILING_DATA(tilingData, tilingGM);
|
||||||
|
tilingData_ = tilingData;
|
||||||
|
|
||||||
|
xGm.SetGlobalBuffer(reinterpret_cast<__gm__ T*>(x));
|
||||||
|
weightGm.SetGlobalBuffer(reinterpret_cast<__gm__ T*>(weight));
|
||||||
|
if (tilingData_->hasBias != 0) {
|
||||||
|
biasGm.SetGlobalBuffer(reinterpret_cast<__gm__ T*>(bias));
|
||||||
|
}
|
||||||
|
convStatesGm.SetGlobalBuffer(reinterpret_cast<__gm__ T*>(convStates));
|
||||||
|
queryStartLocGm.SetGlobalBuffer(reinterpret_cast<__gm__ int32_t*>(queryStartLoc));
|
||||||
|
cacheIndicesGm.SetGlobalBuffer(reinterpret_cast<__gm__ int32_t*>(cacheIndices));
|
||||||
|
hasInitialStateGm.SetGlobalBuffer(reinterpret_cast<__gm__ bool*>(hasInitialState));
|
||||||
|
yGm.SetGlobalBuffer(reinterpret_cast<__gm__ T*>(y));
|
||||||
|
|
||||||
|
pipe.InitBuffer(inBuf, RING_SLOTS * MAX_BLOCK_DIM * sizeof(T));
|
||||||
|
pipe.InitBuffer(outBuf, 2 * MAX_BLOCK_DIM * sizeof(T));
|
||||||
|
pipe.InitBuffer(calcBuf, (MAX_WIDTH + 3) * MAX_BLOCK_DIM * sizeof(float));
|
||||||
|
|
||||||
|
AllocEvents();
|
||||||
|
|
||||||
|
// CCONV_PRINT_IF(GetBlockIdx() == 0U, "[Init] dim=%d, dimTileSize=%d, blocksPerSeq=%d, batch=%d\n",
|
||||||
|
// tilingData_->dim, tilingData_->dimTileSize, tilingData_->blocksPerSeq, tilingData_->batch);
|
||||||
|
// CCONV_PRINT_IF(GetBlockIdx() == 0U, "[Init] hasBias=%d, activationMode=%d, stateLen=%d, inputMode=%d\n",
|
||||||
|
// tilingData_->hasBias, tilingData_->activationMode, tilingData_->stateLen, tilingData_->inputMode);
|
||||||
|
}
|
||||||
|
|
||||||
|
template <typename T>
|
||||||
|
__aicore__ inline void CausalConv1d<T>::AllocEvents()
|
||||||
|
{
|
||||||
|
tempVToMte2Event_ = GetTPipePtr()->AllocEventID<HardEvent::V_MTE2>();
|
||||||
|
tempMte2ToVEvent_ = GetTPipePtr()->AllocEventID<HardEvent::MTE2_V>();
|
||||||
|
inputMte2ToVEvent_ = GetTPipePtr()->AllocEventID<HardEvent::MTE2_V>();
|
||||||
|
outMte3ToVEvent_[0] = GetTPipePtr()->AllocEventID<HardEvent::MTE3_V>();
|
||||||
|
outMte3ToVEvent_[1] = GetTPipePtr()->AllocEventID<HardEvent::MTE3_V>();
|
||||||
|
outVToMte3Event_[0] = GetTPipePtr()->AllocEventID<HardEvent::V_MTE3>();
|
||||||
|
outVToMte3Event_[1] = GetTPipePtr()->AllocEventID<HardEvent::V_MTE3>();
|
||||||
|
}
|
||||||
|
|
||||||
|
template <typename T>
|
||||||
|
__aicore__ inline void CausalConv1d<T>::ReleaseEvents()
|
||||||
|
{
|
||||||
|
GetTPipePtr()->ReleaseEventID<HardEvent::V_MTE2>(tempVToMte2Event_);
|
||||||
|
GetTPipePtr()->ReleaseEventID<HardEvent::MTE2_V>(tempMte2ToVEvent_);
|
||||||
|
GetTPipePtr()->ReleaseEventID<HardEvent::MTE2_V>(inputMte2ToVEvent_);
|
||||||
|
GetTPipePtr()->ReleaseEventID<HardEvent::MTE3_V>(outMte3ToVEvent_[0]);
|
||||||
|
GetTPipePtr()->ReleaseEventID<HardEvent::MTE3_V>(outMte3ToVEvent_[1]);
|
||||||
|
GetTPipePtr()->ReleaseEventID<HardEvent::V_MTE3>(outVToMte3Event_[0]);
|
||||||
|
GetTPipePtr()->ReleaseEventID<HardEvent::V_MTE3>(outVToMte3Event_[1]);
|
||||||
|
}
|
||||||
|
|
||||||
|
template <typename T>
|
||||||
|
__aicore__ inline void CausalConv1d<T>::LoadWeightAndBias(int32_t c0, int32_t dimTileSize, bool dbg)
|
||||||
|
{
|
||||||
|
const int32_t dim = tilingData_->dim;
|
||||||
|
const bool dbgSync = dbg && CCONV_DBG_PRINT_SYNC;
|
||||||
|
(void)dbgSync;
|
||||||
|
LocalTensor<float> calc = calcBuf.Get<float>();
|
||||||
|
LocalTensor<float> weightF = calc;
|
||||||
|
LocalTensor<float> biasF = weightF[MAX_WIDTH * MAX_BLOCK_DIM];
|
||||||
|
LocalTensor<T> tempT = outBuf.Get<T>();
|
||||||
|
|
||||||
|
// CCONV_PRINT_IF(dbg, "[LoadWeightAndBias] c0=%d, dimTileSize=%d\n", c0, dimTileSize);
|
||||||
|
|
||||||
|
for (int32_t j = 0; j < MAX_WIDTH; ++j) {
|
||||||
|
const int64_t weightOffset = static_cast<int64_t>(j) * dim + c0;
|
||||||
|
PipeBarrier<PIPE_ALL>();
|
||||||
|
DataCopy(tempT, weightGm[weightOffset], dimTileSize);
|
||||||
|
PipeBarrier<PIPE_ALL>();
|
||||||
|
Cast(weightF[j * MAX_BLOCK_DIM], tempT, RoundMode::CAST_NONE, dimTileSize);
|
||||||
|
PipeBarrier<PIPE_ALL>();
|
||||||
|
// if (dbg && CCONV_DBG_DUMP_WEIGHTS) {
|
||||||
|
// CCONV_PRINTF("[Dump][weightF] j=%d\n", j);
|
||||||
|
// CCONV_DUMP_TENSOR_IF(true, weightF[j * MAX_BLOCK_DIM], CCONV_DBG_DUMP_SIZE);
|
||||||
|
// }
|
||||||
|
}
|
||||||
|
|
||||||
|
if (tilingData_->hasBias != 0) {
|
||||||
|
PipeBarrier<PIPE_ALL>();
|
||||||
|
DataCopy(tempT, biasGm[c0], dimTileSize);
|
||||||
|
PipeBarrier<PIPE_ALL>();
|
||||||
|
Cast(biasF, tempT, RoundMode::CAST_NONE, dimTileSize);
|
||||||
|
PipeBarrier<PIPE_ALL>();
|
||||||
|
// if (dbg && CCONV_DBG_DUMP_BIAS) {
|
||||||
|
// CCONV_PRINTF("[Dump][biasF]\n");
|
||||||
|
// CCONV_DUMP_TENSOR_IF(true, biasF, CCONV_DBG_DUMP_SIZE);
|
||||||
|
// }
|
||||||
|
} else {
|
||||||
|
Duplicate(biasF, 0.0f, dimTileSize);
|
||||||
|
// CCONV_PRINT_IF(dbg, "[LoadWeightAndBias] bias=0 (no bias)\n");
|
||||||
|
}
|
||||||
|
PipeBarrier<PIPE_ALL>();
|
||||||
|
}
|
||||||
|
|
||||||
|
template <typename T>
|
||||||
|
__aicore__ inline void CausalConv1d<T>::InitRing(int32_t cacheIdx, bool hasInit, int32_t start, int32_t len,
|
||||||
|
int32_t c0, int32_t dimTileSize, int32_t dim, bool dbg)
|
||||||
|
{
|
||||||
|
const int32_t stateLen = tilingData_->stateLen;
|
||||||
|
LocalTensor<T> ring = inBuf.Get<T>();
|
||||||
|
|
||||||
|
PipeBarrier<PIPE_ALL>();
|
||||||
|
if (hasInit) {
|
||||||
|
for (int32_t i = 0; i < (MAX_WIDTH - 1); ++i) {
|
||||||
|
const int64_t stateOffset = static_cast<int64_t>(cacheIdx) * stateLen * dim +
|
||||||
|
static_cast<int64_t>(i) * dim + c0;
|
||||||
|
DataCopy(ring[i * MAX_BLOCK_DIM], convStatesGm[stateOffset], dimTileSize);
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
for (int32_t i = 0; i < (MAX_WIDTH - 1); ++i) {
|
||||||
|
Duplicate(ring[i * MAX_BLOCK_DIM], static_cast<T>(0), dimTileSize);
|
||||||
|
}
|
||||||
|
|
||||||
|
}
|
||||||
|
PipeBarrier<PIPE_ALL>();
|
||||||
|
|
||||||
|
if (len > 0) {
|
||||||
|
const int64_t xOffset = static_cast<int64_t>(start) * dim + c0;
|
||||||
|
PipeBarrier<PIPE_ALL>();
|
||||||
|
DataCopy(ring[SlotCurr(0) * MAX_BLOCK_DIM], xGm[xOffset], dimTileSize);
|
||||||
|
PipeBarrier<PIPE_ALL>();
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
template <typename T>
|
||||||
|
__aicore__ inline void CausalConv1d<T>::RunSeq(int32_t start, int32_t len, int32_t c0, int32_t dimTileSize,
|
||||||
|
int32_t dim, bool dbg)
|
||||||
|
{
|
||||||
|
LocalTensor<float> calc = calcBuf.Get<float>();
|
||||||
|
LocalTensor<float> weightF = calc;
|
||||||
|
LocalTensor<float> biasF = weightF[MAX_WIDTH * MAX_BLOCK_DIM];
|
||||||
|
LocalTensor<float> accF = biasF[MAX_BLOCK_DIM];
|
||||||
|
LocalTensor<float> tmpF = accF[MAX_BLOCK_DIM];
|
||||||
|
LocalTensor<T> ring = inBuf.Get<T>();
|
||||||
|
LocalTensor<T> outT = outBuf.Get<T>();
|
||||||
|
const bool dbgSync = dbg && CCONV_DBG_PRINT_SYNC;
|
||||||
|
(void)dbgSync;
|
||||||
|
const bool hasActivation = (tilingData_->activationMode != 0);
|
||||||
|
const int32_t dbgMaxTokens = CCONV_DBG_MAX_TOKENS;
|
||||||
|
const int32_t dbgVerboseTokens = CCONV_DBG_VERBOSE_TOKENS;
|
||||||
|
|
||||||
|
for (int32_t t = 0; t < len; ++t) {
|
||||||
|
const bool dbgTok = dbg && (t < dbgMaxTokens);
|
||||||
|
const bool dbgVerbose = dbg && CCONV_DBG_DUMP_RUNSEQ && (t < dbgVerboseTokens);
|
||||||
|
const bool dbgStep = dbgVerbose && (t == 0);
|
||||||
|
const int32_t slotCurr = SlotCurr(t);
|
||||||
|
const int32_t slotH1 = SlotHist(t, 1);
|
||||||
|
const int32_t slotH2 = SlotHist(t, 2);
|
||||||
|
const int32_t slotH3 = SlotHist(t, 3);
|
||||||
|
const int32_t slotPref = (t + 1 < len) ? SlotPrefetch(t) : -1;
|
||||||
|
const int32_t outSlot = t & 1;
|
||||||
|
|
||||||
|
if (t + 1 < len) {
|
||||||
|
const int64_t xOffset = static_cast<int64_t>(start + t + 1) * dim + c0;
|
||||||
|
PipeBarrier<PIPE_ALL>();
|
||||||
|
DataCopy(ring[slotPref * MAX_BLOCK_DIM], xGm[xOffset], dimTileSize);
|
||||||
|
PipeBarrier<PIPE_ALL>();
|
||||||
|
|
||||||
|
}
|
||||||
|
|
||||||
|
DataCopy(accF, biasF, dimTileSize);
|
||||||
|
|
||||||
|
|
||||||
|
for (int32_t j = 0; j < MAX_WIDTH; ++j) {
|
||||||
|
const int32_t tap = (MAX_WIDTH - 1) - j;
|
||||||
|
const int32_t slot = (tap == 0) ? slotCurr : SlotHist(t, tap);
|
||||||
|
PipeBarrier<PIPE_ALL>();
|
||||||
|
Cast(tmpF, ring[slot * MAX_BLOCK_DIM], RoundMode::CAST_NONE, dimTileSize);
|
||||||
|
PipeBarrier<PIPE_ALL>();
|
||||||
|
|
||||||
|
PipeBarrier<PIPE_ALL>();
|
||||||
|
MulAddDst(accF, tmpF, weightF[j * MAX_BLOCK_DIM], dimTileSize);
|
||||||
|
PipeBarrier<PIPE_ALL>();
|
||||||
|
}
|
||||||
|
|
||||||
|
if (hasActivation) {
|
||||||
|
Silu(tmpF, accF, dimTileSize);
|
||||||
|
}
|
||||||
|
|
||||||
|
PipeBarrier<PIPE_ALL>();
|
||||||
|
if constexpr (IsSameType<T, float>::value) {
|
||||||
|
if (hasActivation) {
|
||||||
|
DataCopy(outT[outSlot * MAX_BLOCK_DIM], tmpF, dimTileSize);
|
||||||
|
} else {
|
||||||
|
DataCopy(outT[outSlot * MAX_BLOCK_DIM], accF, dimTileSize);
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
if (hasActivation) {
|
||||||
|
Cast(outT[outSlot * MAX_BLOCK_DIM], tmpF, RoundMode::CAST_RINT, dimTileSize);
|
||||||
|
} else {
|
||||||
|
Cast(outT[outSlot * MAX_BLOCK_DIM], accF, RoundMode::CAST_RINT, dimTileSize);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
PipeBarrier<PIPE_ALL>();
|
||||||
|
|
||||||
|
const int64_t outOffset = static_cast<int64_t>(start + t) * dim + c0;
|
||||||
|
PipeBarrier<PIPE_ALL>();
|
||||||
|
DataCopy(yGm[outOffset], outT[outSlot * MAX_BLOCK_DIM], dimTileSize);
|
||||||
|
PipeBarrier<PIPE_ALL>();
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
template <typename T>
|
||||||
|
__aicore__ inline void CausalConv1d<T>::WriteBackState(int32_t cacheIdx, int32_t len, int32_t c0,
|
||||||
|
int32_t dimTileSize, int32_t dim, bool dbg)
|
||||||
|
{
|
||||||
|
const int32_t stateLen = tilingData_->stateLen;
|
||||||
|
if (len <= 0) {
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
|
||||||
|
const int32_t lastT = len - 1;
|
||||||
|
LocalTensor<T> ring = inBuf.Get<T>();
|
||||||
|
|
||||||
|
for (int32_t pos = 0; pos < (MAX_WIDTH - 1); ++pos) {
|
||||||
|
const int32_t tap = (MAX_WIDTH - 2) - pos;
|
||||||
|
const int32_t slot = (tap == 0) ? SlotCurr(lastT) : SlotHist(lastT, tap);
|
||||||
|
const int64_t stateOffset = static_cast<int64_t>(cacheIdx) * stateLen * dim +
|
||||||
|
static_cast<int64_t>(pos) * dim + c0;
|
||||||
|
PipeBarrier<PIPE_ALL>();
|
||||||
|
DataCopy(convStatesGm[stateOffset], ring[slot * MAX_BLOCK_DIM], dimTileSize);
|
||||||
|
PipeBarrier<PIPE_ALL>();
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
template <typename T>
|
||||||
|
__aicore__ inline void CausalConv1d<T>::Process()
|
||||||
|
{
|
||||||
|
const int32_t dim = tilingData_->dim;
|
||||||
|
const int32_t batch = tilingData_->batch;
|
||||||
|
const int32_t inputMode = tilingData_->inputMode;
|
||||||
|
const int32_t seqLen = tilingData_->seqLen;
|
||||||
|
const int32_t dimTileSize = static_cast<int32_t>(tilingData_->dimTileSize);
|
||||||
|
const int32_t blocksPerSeq = static_cast<int32_t>(tilingData_->blocksPerSeq);
|
||||||
|
|
||||||
|
const uint32_t blockIdx = GetBlockIdx();
|
||||||
|
const uint32_t blockNum = GetBlockNum();
|
||||||
|
|
||||||
|
if (dimTileSize <= 0 || blocksPerSeq <= 0 || dimTileSize > MAX_BLOCK_DIM || blocksPerSeq * dimTileSize != dim) {
|
||||||
|
ReleaseEvents();
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
|
||||||
|
const int64_t gridSize = static_cast<int64_t>(batch) * blocksPerSeq;
|
||||||
|
for (int64_t task = static_cast<int64_t>(blockIdx); task < gridSize; task += static_cast<int64_t>(blockNum)) {
|
||||||
|
const int32_t seq = static_cast<int32_t>(task / blocksPerSeq);
|
||||||
|
const int32_t dimBlockId = static_cast<int32_t>(task % blocksPerSeq);
|
||||||
|
const int32_t c0 = dimBlockId * dimTileSize;
|
||||||
|
const bool dbg = (seq == CCONV_DBG_SEQ) && (c0 == CCONV_DBG_C0);
|
||||||
|
|
||||||
|
LoadWeightAndBias(c0, dimTileSize, dbg);
|
||||||
|
|
||||||
|
int32_t start = 0;
|
||||||
|
int32_t len = 0;
|
||||||
|
if (inputMode == 0) {
|
||||||
|
const int32_t startVal = queryStartLocGm.GetValue(seq);
|
||||||
|
const int32_t endVal = queryStartLocGm.GetValue(seq + 1);
|
||||||
|
start = startVal;
|
||||||
|
len = endVal - startVal;
|
||||||
|
} else {
|
||||||
|
start = seq * seqLen;
|
||||||
|
len = seqLen;
|
||||||
|
}
|
||||||
|
|
||||||
|
if (len <= 0) {
|
||||||
|
continue;
|
||||||
|
}
|
||||||
|
|
||||||
|
const int32_t cacheIdx = cacheIndicesGm.GetValue(seq);
|
||||||
|
if (cacheIdx == tilingData_->padSlotId) {
|
||||||
|
continue;
|
||||||
|
}
|
||||||
|
|
||||||
|
const bool hasInit = hasInitialStateGm.GetValue(seq);
|
||||||
|
|
||||||
|
InitRing(cacheIdx, hasInit, start, len, c0, dimTileSize, dim, dbg);
|
||||||
|
RunSeq(start, len, c0, dimTileSize, dim, dbg);
|
||||||
|
WriteBackState(cacheIdx, len, c0, dimTileSize, dim, dbg);
|
||||||
|
}
|
||||||
|
|
||||||
|
ReleaseEvents();
|
||||||
|
}
|
||||||
|
|
||||||
|
} // namespace NsCausalConv1d
|
||||||
|
#endif // CAUSAL_CONV1D_H
|
||||||
45
csrc/causal_conv1d/op_kernel/causal_conv1d_common.h
Normal file
45
csrc/causal_conv1d/op_kernel/causal_conv1d_common.h
Normal file
@@ -0,0 +1,45 @@
|
|||||||
|
/**
|
||||||
|
* This program is free software, you can redistribute it and/or modify it.
|
||||||
|
* Copyright (c) 2025 Huawei Technologies Co., Ltd.
|
||||||
|
* This file is a part of the CANN Open Software.
|
||||||
|
* Licensed under CANN Open Software License Agreement Version 2.0 (the "License").
|
||||||
|
* Please refer to the License for details. You may not use this file except in compliance with the License.
|
||||||
|
* THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, INCLUDING
|
||||||
|
* BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE.
|
||||||
|
* See LICENSE in the root of the software repository for the full text of the License.
|
||||||
|
*/
|
||||||
|
|
||||||
|
/*!
|
||||||
|
* \file causal_conv1d_common.h
|
||||||
|
* \brief Common utilities and constants for CausalConv1D prefill kernel.
|
||||||
|
*/
|
||||||
|
|
||||||
|
#ifndef CAUSAL_CONV1D_COMMON_H
|
||||||
|
#define CAUSAL_CONV1D_COMMON_H
|
||||||
|
|
||||||
|
#include "kernel_operator.h"
|
||||||
|
|
||||||
|
namespace NsCausalConv1dCommon {
|
||||||
|
|
||||||
|
constexpr int32_t MAX_WIDTH = 4;
|
||||||
|
constexpr int32_t MAX_BLOCK_DIM = 4096;
|
||||||
|
constexpr int32_t RING_SLOTS = 5;
|
||||||
|
|
||||||
|
__aicore__ inline int32_t SlotCurr(int32_t t)
|
||||||
|
{
|
||||||
|
return (t + 3) % RING_SLOTS;
|
||||||
|
}
|
||||||
|
|
||||||
|
__aicore__ inline int32_t SlotHist(int32_t t, int32_t i)
|
||||||
|
{
|
||||||
|
return (t + 3 - i) % RING_SLOTS;
|
||||||
|
}
|
||||||
|
|
||||||
|
__aicore__ inline int32_t SlotPrefetch(int32_t t)
|
||||||
|
{
|
||||||
|
return (t + 4) % RING_SLOTS;
|
||||||
|
}
|
||||||
|
|
||||||
|
} // namespace NsCausalConv1dCommon
|
||||||
|
|
||||||
|
#endif // CAUSAL_CONV1D_COMMON_H
|
||||||
34
csrc/causal_conv1d/op_kernel/causal_conv1d_tiling_key.h
Normal file
34
csrc/causal_conv1d/op_kernel/causal_conv1d_tiling_key.h
Normal file
@@ -0,0 +1,34 @@
|
|||||||
|
/**
|
||||||
|
* This program is free software, you can redistribute it and/or modify it.
|
||||||
|
* Copyright (c) 2025 Huawei Technologies Co., Ltd.
|
||||||
|
* This file is a part of the CANN Open Software.
|
||||||
|
* Licensed under CANN Open Software License Agreement Version 2.0 (the "License").
|
||||||
|
* Please refer to the License for details. You may not use this file except in compliance with the License.
|
||||||
|
* THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, INCLUDING
|
||||||
|
* BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE.
|
||||||
|
* See LICENSE in the root of the software repository for the full text of the License.
|
||||||
|
*/
|
||||||
|
|
||||||
|
/*!
|
||||||
|
* \file causal_conv1d_tiling_key.h
|
||||||
|
* \brief causal_conv1d tiling key declare
|
||||||
|
*/
|
||||||
|
|
||||||
|
#ifndef __CAUSAL_CONV1D_TILING_KEY_H__
|
||||||
|
#define __CAUSAL_CONV1D_TILING_KEY_H__
|
||||||
|
|
||||||
|
#include "ascendc/host_api/tiling/template_argument.h"
|
||||||
|
|
||||||
|
#define CAUSAL_CONV1D_TPL_SCH_MODE_DEFAULT 0
|
||||||
|
|
||||||
|
ASCENDC_TPL_ARGS_DECL(CausalConv1d,
|
||||||
|
ASCENDC_TPL_UINT_DECL(
|
||||||
|
schMode, 1, ASCENDC_TPL_UI_LIST, CAUSAL_CONV1D_TPL_SCH_MODE_DEFAULT)
|
||||||
|
);
|
||||||
|
|
||||||
|
ASCENDC_TPL_SEL(
|
||||||
|
ASCENDC_TPL_ARGS_SEL(
|
||||||
|
ASCENDC_TPL_UINT_SEL(
|
||||||
|
schMode, ASCENDC_TPL_UI_LIST, CAUSAL_CONV1D_TPL_SCH_MODE_DEFAULT)));
|
||||||
|
|
||||||
|
#endif // __CAUSAL_CONV1D_TILING_KEY_H__
|
||||||
51
csrc/causal_conv1d/tiling_base/data_copy_transpose_tiling.h
Normal file
51
csrc/causal_conv1d/tiling_base/data_copy_transpose_tiling.h
Normal file
@@ -0,0 +1,51 @@
|
|||||||
|
/**
|
||||||
|
* This program is free software, you can redistribute it and/or modify.
|
||||||
|
* Copyright (c) 2025 Huawei Technologies Co., Ltd.
|
||||||
|
* This file is a part of the CANN Open Software.
|
||||||
|
* Licensed under CANN Open Software License Agreement Version 2.0 (the "License").
|
||||||
|
* Please refer to the License for details. You may not use this file except in compliance with the License.
|
||||||
|
* THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY, EITHER EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE.
|
||||||
|
* See LICENSE in the root of the software repository for the full text of the License.
|
||||||
|
*/
|
||||||
|
|
||||||
|
/*!
|
||||||
|
* \file data_copy_transpose_tiling.h
|
||||||
|
* \brief
|
||||||
|
*/
|
||||||
|
|
||||||
|
#pragma once
|
||||||
|
|
||||||
|
#include <vector>
|
||||||
|
#include <graph/tensor.h>
|
||||||
|
#include "data_copy_transpose_tiling_def.h"
|
||||||
|
|
||||||
|
namespace optiling {
|
||||||
|
|
||||||
|
inline void GetDataCopyTransposeTiling(const ge::Shape &dstShape, const ge::Shape &srcShape, const uint32_t typeSize,
|
||||||
|
optiling::CopyTransposeTiling &tiling)
|
||||||
|
{
|
||||||
|
constexpr int64_t B_INDEX = 0;
|
||||||
|
constexpr int64_t N_INDEX = 1;
|
||||||
|
constexpr int64_t S_INDEX = 2;
|
||||||
|
constexpr int64_t H_INDEX = 3;
|
||||||
|
std::vector<int64_t> dstShapeInfo = dstShape.GetDims();
|
||||||
|
std::vector<int64_t> srcShapeInfo = srcShape.GetDims();
|
||||||
|
|
||||||
|
tiling.set_dstShapeB(dstShapeInfo[B_INDEX]);
|
||||||
|
tiling.set_dstShapeN(dstShapeInfo[N_INDEX]);
|
||||||
|
tiling.set_dstShapeS(dstShapeInfo[S_INDEX]);
|
||||||
|
tiling.set_dstShapeH(dstShapeInfo[H_INDEX]);
|
||||||
|
tiling.set_dstShapeHN(tiling.get_dstShapeH() / tiling.get_dstShapeN());
|
||||||
|
|
||||||
|
tiling.set_srcShapeB(srcShapeInfo[B_INDEX]);
|
||||||
|
tiling.set_srcShapeN(srcShapeInfo[N_INDEX]);
|
||||||
|
tiling.set_srcShapeS(srcShapeInfo[S_INDEX]);
|
||||||
|
tiling.set_srcShapeHN(srcShapeInfo[H_INDEX]);
|
||||||
|
tiling.set_originalShapeNLen(tiling.get_srcShapeHN() * typeSize);
|
||||||
|
tiling.set_shapeSHValue(tiling.get_dstShapeS() * tiling.get_dstShapeH());
|
||||||
|
tiling.set_shapeNsValue(tiling.get_dstShapeN() * tiling.get_dstShapeS());
|
||||||
|
tiling.set_shapeNsnValue(tiling.get_dstShapeN() * tiling.get_srcShapeS() * tiling.get_srcShapeN());
|
||||||
|
tiling.set_shapeBHValue(tiling.get_dstShapeB() * tiling.get_dstShapeH());
|
||||||
|
}
|
||||||
|
|
||||||
|
} // namespace optiling
|
||||||
@@ -0,0 +1,43 @@
|
|||||||
|
/**
|
||||||
|
* This program is free software, you can redistribute it and/or modify.
|
||||||
|
* Copyright (c) 2025 Huawei Technologies Co., Ltd.
|
||||||
|
* This file is a part of the CANN Open Software.
|
||||||
|
* Licensed under CANN Open Software License Agreement Version 2.0 (the "License").
|
||||||
|
* Please refer to the License for details. You may not use this file except in compliance with the License.
|
||||||
|
* THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE.
|
||||||
|
* See LICENSE in the root of the software repository for the full text of the License.
|
||||||
|
*/
|
||||||
|
|
||||||
|
/*!
|
||||||
|
* \file data_copy_transpose_tiling_def.h
|
||||||
|
* \brief
|
||||||
|
*/
|
||||||
|
|
||||||
|
#pragma once
|
||||||
|
|
||||||
|
#include <cstdint>
|
||||||
|
#include <register/tilingdata_base.h>
|
||||||
|
|
||||||
|
namespace optiling {
|
||||||
|
|
||||||
|
BEGIN_TILING_DATA_DEF(CopyTransposeTiling)
|
||||||
|
TILING_DATA_FIELD_DEF(uint32_t, dstShapeB);
|
||||||
|
TILING_DATA_FIELD_DEF(uint32_t, dstShapeN);
|
||||||
|
TILING_DATA_FIELD_DEF(uint32_t, dstShapeS);
|
||||||
|
TILING_DATA_FIELD_DEF(uint32_t, dstShapeHN);
|
||||||
|
TILING_DATA_FIELD_DEF(uint32_t, dstShapeH);
|
||||||
|
TILING_DATA_FIELD_DEF(uint32_t, srcShapeB);
|
||||||
|
TILING_DATA_FIELD_DEF(uint32_t, srcShapeN);
|
||||||
|
TILING_DATA_FIELD_DEF(uint32_t, srcShapeS);
|
||||||
|
TILING_DATA_FIELD_DEF(uint32_t, srcShapeHN);
|
||||||
|
TILING_DATA_FIELD_DEF(uint32_t, originalShapeNLen);
|
||||||
|
TILING_DATA_FIELD_DEF(uint32_t, shapeSHValue);
|
||||||
|
TILING_DATA_FIELD_DEF(uint32_t, shapeNsValue);
|
||||||
|
TILING_DATA_FIELD_DEF(uint32_t, shapeNsnValue);
|
||||||
|
TILING_DATA_FIELD_DEF(uint32_t, invalidParamCopyTransposeTiling);
|
||||||
|
TILING_DATA_FIELD_DEF(uint32_t, shapeBHValue);
|
||||||
|
TILING_DATA_FIELD_DEF(uint32_t, paramsAlign);
|
||||||
|
END_TILING_DATA_DEF;
|
||||||
|
REGISTER_TILING_DATA_CLASS(CopyTransposeTilingOp, CopyTransposeTiling)
|
||||||
|
|
||||||
|
} // namespace optiling
|
||||||
56
csrc/causal_conv1d/tiling_base/error_log.h
Normal file
56
csrc/causal_conv1d/tiling_base/error_log.h
Normal file
@@ -0,0 +1,56 @@
|
|||||||
|
#ifndef OPS_BUILT_IN_OP_TILING_ERROR_LOG_H_
|
||||||
|
#define OPS_BUILT_IN_OP_TILING_ERROR_LOG_H_
|
||||||
|
|
||||||
|
#include <string>
|
||||||
|
#include "toolchain/slog.h"
|
||||||
|
|
||||||
|
#define OP_LOGI(opname, ...)
|
||||||
|
#define OP_LOGW(opname, ...) \
|
||||||
|
do { \
|
||||||
|
printf("[WARN][%s] ", (opname), ##__VA_ARGS__); \
|
||||||
|
printf("\n"); \
|
||||||
|
} while (0)
|
||||||
|
|
||||||
|
#define OP_LOGE_WITHOUT_REPORT(opname, ...) \
|
||||||
|
do { \
|
||||||
|
printf("[ERRORx][%s] ", (opname), ##__VA_ARGS__); \
|
||||||
|
printf("\n"); \
|
||||||
|
} while (0)
|
||||||
|
|
||||||
|
#define OP_LOGE(opname, ...) \
|
||||||
|
do { \
|
||||||
|
printf("[ERROR][%s] ", (opname), ##__VA_ARGS__); \
|
||||||
|
printf("\n"); \
|
||||||
|
} while (0)
|
||||||
|
|
||||||
|
#define OP_LOGD(opname, ...)
|
||||||
|
|
||||||
|
namespace optiling {
|
||||||
|
|
||||||
|
#define VECTOR_INNER_ERR_REPORT_TILIING(op_name, err_msg, ...) \
|
||||||
|
do { \
|
||||||
|
OP_LOGE_WITHOUT_REPORT(op_name, err_msg, ##__VA_ARGS__); \
|
||||||
|
} while (0)
|
||||||
|
|
||||||
|
// Modify OP_TILING_CHECK macro to ensure proper handling of expressions
|
||||||
|
#define OP_CHECK_IF(cond, log_func, expr) \
|
||||||
|
do { \
|
||||||
|
if (cond) { \
|
||||||
|
log_func; \
|
||||||
|
expr; \
|
||||||
|
} \
|
||||||
|
} while (0)
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
#define OP_CHECK_NULL_WITH_CONTEXT(context, ptr) \
|
||||||
|
do { \
|
||||||
|
if ((ptr) == nullptr) { \
|
||||||
|
OP_LOGE(context->GetNodeType(), "%s is null", #ptr); \
|
||||||
|
return ge::GRAPH_FAILED; \
|
||||||
|
} \
|
||||||
|
} while (0)
|
||||||
|
|
||||||
|
} // namespace optiling
|
||||||
|
|
||||||
|
#endif // OPS_BUILT_IN_OP_TILING_ERROR_LOG_H_
|
||||||
256
csrc/causal_conv1d/tiling_base/tiling_base.h
Normal file
256
csrc/causal_conv1d/tiling_base/tiling_base.h
Normal file
@@ -0,0 +1,256 @@
|
|||||||
|
/**
|
||||||
|
* This program is free software, you can redistribute it and/or modify.
|
||||||
|
* Copyright (c) 2025 Huawei Technologies Co., Ltd.
|
||||||
|
* This file is a part of the CANN Open Software.
|
||||||
|
* Licensed under CANN Open Software License Agreement Version 2.0 (the "License").
|
||||||
|
* Please refer to the License for details. You may not use this file except in compliance with the License.
|
||||||
|
* THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY, EITHER EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE.
|
||||||
|
* See LICENSE in the root of the software repository for the full text of the License.
|
||||||
|
*/
|
||||||
|
|
||||||
|
/*!
|
||||||
|
* \file tiling_base.h
|
||||||
|
* \brief
|
||||||
|
*/
|
||||||
|
|
||||||
|
#pragma once
|
||||||
|
|
||||||
|
#include <sstream>
|
||||||
|
#include <exe_graph/runtime/tiling_context.h>
|
||||||
|
#include <graph/utils/type_utils.h>
|
||||||
|
#include "tiling/platform/platform_ascendc.h"
|
||||||
|
#include "error_log.h"
|
||||||
|
|
||||||
|
#ifdef ASCENDC_OP_TEST
|
||||||
|
#define ASCENDC_EXTERN_C extern "C"
|
||||||
|
#else
|
||||||
|
#define ASCENDC_EXTERN_C
|
||||||
|
#endif
|
||||||
|
|
||||||
|
namespace Ops {
|
||||||
|
namespace Transformer {
|
||||||
|
namespace OpTiling {
|
||||||
|
|
||||||
|
struct AiCoreParams {
|
||||||
|
uint64_t ubSize = 0;
|
||||||
|
uint64_t blockDim = 0;
|
||||||
|
uint64_t aicNum = 0;
|
||||||
|
uint64_t l1Size = 0;
|
||||||
|
uint64_t l0aSize = 0;
|
||||||
|
uint64_t l0bSize = 0;
|
||||||
|
uint64_t l0cSize = 0;
|
||||||
|
};
|
||||||
|
|
||||||
|
struct CompileInfoCommon {
|
||||||
|
uint32_t aivNum;
|
||||||
|
uint32_t aicNum;
|
||||||
|
uint64_t ubSize;
|
||||||
|
uint64_t l1Size;
|
||||||
|
uint64_t l0aSize;
|
||||||
|
uint64_t l0bSize;
|
||||||
|
uint64_t l0cSize;
|
||||||
|
uint64_t l2CacheSize;
|
||||||
|
int64_t coreNum;
|
||||||
|
int32_t socVersion;
|
||||||
|
uint32_t rsvd;
|
||||||
|
};
|
||||||
|
|
||||||
|
struct FlashAttentionScoreGradCompileInfo {
|
||||||
|
uint32_t aivNum;
|
||||||
|
uint32_t aicNum;
|
||||||
|
uint64_t ubSize;
|
||||||
|
uint64_t l1Size;
|
||||||
|
uint64_t l0aSize;
|
||||||
|
uint64_t l0bSize;
|
||||||
|
uint64_t l0cSize;
|
||||||
|
uint64_t l2CacheSize;
|
||||||
|
int64_t coreNum;
|
||||||
|
platform_ascendc::SocVersion socVersion;
|
||||||
|
};
|
||||||
|
|
||||||
|
struct FACompileInfoCommon {
|
||||||
|
uint32_t aivNum;
|
||||||
|
uint32_t aicNum;
|
||||||
|
uint64_t ubSize;
|
||||||
|
uint64_t l1Size;
|
||||||
|
uint64_t l0aSize;
|
||||||
|
uint64_t l0bSize;
|
||||||
|
uint64_t l0cSize;
|
||||||
|
uint64_t l2CacheSize;
|
||||||
|
int64_t coreNum;
|
||||||
|
int32_t socVersion;
|
||||||
|
uint32_t rsvd;
|
||||||
|
};
|
||||||
|
|
||||||
|
class TilingBaseClass {
|
||||||
|
public:
|
||||||
|
explicit TilingBaseClass(gert::TilingContext* context) : context_(context)
|
||||||
|
{}
|
||||||
|
|
||||||
|
virtual ~TilingBaseClass() = default;
|
||||||
|
|
||||||
|
// Tiling execution framework
|
||||||
|
// 1. GRAPH_SUCCESS: Success, and no need to continue executing subsequent Tiling class implementations
|
||||||
|
// 2. GRAPH_FAILED: Failure, abort the entire Tiling process
|
||||||
|
// 3. GRAPH_PARAM_INVALID: This class does not support, need to continue executing other Tiling class implementations
|
||||||
|
ge::graphStatus DoTiling()
|
||||||
|
{
|
||||||
|
auto ret = GetShapeAttrsInfo();
|
||||||
|
if (ret != ge::GRAPH_SUCCESS) {
|
||||||
|
return ret;
|
||||||
|
}
|
||||||
|
ret = GetPlatformInfo();
|
||||||
|
if (ret != ge::GRAPH_SUCCESS) {
|
||||||
|
return ret;
|
||||||
|
}
|
||||||
|
if (!IsCapable()) {
|
||||||
|
return ge::GRAPH_PARAM_INVALID;
|
||||||
|
}
|
||||||
|
ret = DoOpTiling();
|
||||||
|
if (ret != ge::GRAPH_SUCCESS) {
|
||||||
|
return ret;
|
||||||
|
}
|
||||||
|
ret = DoLibApiTiling();
|
||||||
|
if (ret != ge::GRAPH_SUCCESS) {
|
||||||
|
return ret;
|
||||||
|
}
|
||||||
|
ret = GetWorkspaceSize();
|
||||||
|
if (ret != ge::GRAPH_SUCCESS) {
|
||||||
|
return ret;
|
||||||
|
}
|
||||||
|
ret = PostTiling();
|
||||||
|
if (ret != ge::GRAPH_SUCCESS) {
|
||||||
|
return ret;
|
||||||
|
}
|
||||||
|
context_->SetTilingKey(GetTilingKey());
|
||||||
|
DumpTilingInfo();
|
||||||
|
return ge::GRAPH_SUCCESS;
|
||||||
|
}
|
||||||
|
|
||||||
|
// Update context
|
||||||
|
virtual void Reset(gert::TilingContext* context)
|
||||||
|
{
|
||||||
|
context_ = context;
|
||||||
|
}
|
||||||
|
|
||||||
|
protected:
|
||||||
|
virtual bool IsCapable() = 0;
|
||||||
|
// 1. Get platform information such as CoreNum, UB/L1/L0C resource sizes
|
||||||
|
virtual ge::graphStatus GetPlatformInfo() = 0;
|
||||||
|
// 2. Get INPUT/OUTPUT/ATTR information
|
||||||
|
virtual ge::graphStatus GetShapeAttrsInfo() = 0;
|
||||||
|
// 3. Calculate data splitting TilingData
|
||||||
|
virtual ge::graphStatus DoOpTiling() = 0;
|
||||||
|
// 4. Calculate high-level API TilingData
|
||||||
|
virtual ge::graphStatus DoLibApiTiling() = 0;
|
||||||
|
// 5. Calculate TilingKey
|
||||||
|
[[nodiscard]] virtual uint64_t GetTilingKey() const = 0;
|
||||||
|
// 6. Calculate Workspace size
|
||||||
|
virtual ge::graphStatus GetWorkspaceSize() = 0;
|
||||||
|
// 7. Save Tiling data
|
||||||
|
virtual ge::graphStatus PostTiling() = 0;
|
||||||
|
// 8. Dump Tiling data
|
||||||
|
virtual void DumpTilingInfo()
|
||||||
|
{
|
||||||
|
int32_t enable = CheckLogLevel(static_cast<int32_t>(OP), DLOG_DEBUG);
|
||||||
|
if (enable != 1) {
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
auto buf = (uint32_t*)context_->GetRawTilingData()->GetData();
|
||||||
|
auto bufLen = context_->GetRawTilingData()->GetDataSize();
|
||||||
|
std::ostringstream oss;
|
||||||
|
oss << "Start to dump tiling info. tilingkey:" << context_->GetTilingKey() << ", tiling data size:" << bufLen
|
||||||
|
<< ", content:";
|
||||||
|
for (size_t i = 0; i < bufLen / sizeof(uint32_t); i++) {
|
||||||
|
oss << *(buf + i) << ",";
|
||||||
|
if (oss.str().length() > 640) { // Split according to 640 to avoid truncation
|
||||||
|
OP_LOGD(context_, "%s", oss.str().c_str());
|
||||||
|
oss.str("");
|
||||||
|
}
|
||||||
|
}
|
||||||
|
OP_LOGD(context_, "%s", oss.str().c_str());
|
||||||
|
}
|
||||||
|
|
||||||
|
static uint32_t CalcTschBlockDim(uint32_t sliceNum, uint32_t aicCoreNum, uint32_t aivCoreNum)
|
||||||
|
{
|
||||||
|
uint32_t ration;
|
||||||
|
if (aicCoreNum == 0 || aivCoreNum == 0 || aicCoreNum > aivCoreNum) {
|
||||||
|
return sliceNum;
|
||||||
|
}
|
||||||
|
ration = aivCoreNum / aicCoreNum;
|
||||||
|
return (sliceNum + (ration - 1)) / ration;
|
||||||
|
}
|
||||||
|
|
||||||
|
template <typename T>
|
||||||
|
[[nodiscard]] std::string GetShapeDebugStr(const T& shape) const
|
||||||
|
{
|
||||||
|
std::ostringstream oss;
|
||||||
|
oss << "[";
|
||||||
|
if (shape.GetDimNum() > 0) {
|
||||||
|
for (size_t i = 0; i < shape.GetDimNum() - 1; ++i) {
|
||||||
|
oss << shape.GetDim(i) << ", ";
|
||||||
|
}
|
||||||
|
oss << shape.GetDim(shape.GetDimNum() - 1);
|
||||||
|
}
|
||||||
|
oss << "]";
|
||||||
|
return oss.str();
|
||||||
|
}
|
||||||
|
|
||||||
|
[[nodiscard]] std::string GetTensorDebugStr(
|
||||||
|
const gert::StorageShape* shape, const gert::CompileTimeTensorDesc* tensor)
|
||||||
|
{
|
||||||
|
if (shape == nullptr || tensor == nullptr) {
|
||||||
|
return "nil ";
|
||||||
|
}
|
||||||
|
std::ostringstream oss;
|
||||||
|
oss << "(dtype: " << ge::TypeUtils::DataTypeToSerialString(tensor->GetDataType()) << "),";
|
||||||
|
oss << "(shape:" << GetShapeDebugStr(shape->GetStorageShape()) << "),";
|
||||||
|
oss << "(ori_shape:" << GetShapeDebugStr(shape->GetOriginShape()) << "),";
|
||||||
|
oss << "(format: "
|
||||||
|
<< ge::TypeUtils::FormatToSerialString(
|
||||||
|
static_cast<ge::Format>(ge::GetPrimaryFormat(tensor->GetStorageFormat())))
|
||||||
|
<< "),";
|
||||||
|
oss << "(ori_format: " << ge::TypeUtils::FormatToSerialString(tensor->GetOriginFormat()) << ") ";
|
||||||
|
return oss.str();
|
||||||
|
}
|
||||||
|
|
||||||
|
[[nodiscard]] std::string GetTilingContextDebugStr()
|
||||||
|
{
|
||||||
|
std::ostringstream oss;
|
||||||
|
for (size_t i = 0; i < context_->GetComputeNodeInfo()->GetInputsNum(); ++i) {
|
||||||
|
oss << "input" << i << ": ";
|
||||||
|
oss << GetTensorDebugStr(context_->GetInputShape(i), context_->GetInputDesc(i));
|
||||||
|
}
|
||||||
|
|
||||||
|
for (size_t i = 0; i < context_->GetComputeNodeInfo()->GetOutputsNum(); ++i) {
|
||||||
|
oss << "output" << i << ": ";
|
||||||
|
oss << GetTensorDebugStr(context_->GetOutputShape(i), context_->GetOutputDesc(i));
|
||||||
|
}
|
||||||
|
return oss.str();
|
||||||
|
}
|
||||||
|
|
||||||
|
[[nodiscard]] std::string GetTilingDataDebugStr() const
|
||||||
|
{
|
||||||
|
auto rawTilingData = context_->GetRawTilingData();
|
||||||
|
auto rawTilingDataSize = rawTilingData->GetDataSize();
|
||||||
|
auto data = reinterpret_cast<const int32_t*>(rawTilingData->GetData());
|
||||||
|
size_t len = rawTilingDataSize / sizeof(int32_t);
|
||||||
|
std::ostringstream oss;
|
||||||
|
for (size_t i = 0; i < len; i++) {
|
||||||
|
oss << data[i] << ", ";
|
||||||
|
}
|
||||||
|
return oss.str();
|
||||||
|
}
|
||||||
|
|
||||||
|
protected:
|
||||||
|
gert::TilingContext* context_ = nullptr;
|
||||||
|
std::unique_ptr<platform_ascendc::PlatformAscendC> ascendcPlatform_{nullptr};
|
||||||
|
uint32_t blockDim_{0};
|
||||||
|
uint64_t workspaceSize_{0};
|
||||||
|
uint64_t tilingKey_{0};
|
||||||
|
AiCoreParams aicoreParams_;
|
||||||
|
};
|
||||||
|
|
||||||
|
} // namespace OpTiling
|
||||||
|
} // namespace Transformer
|
||||||
|
} // namespace Ops
|
||||||
63
csrc/causal_conv1d/tiling_base/tiling_key.h
Normal file
63
csrc/causal_conv1d/tiling_base/tiling_key.h
Normal file
@@ -0,0 +1,63 @@
|
|||||||
|
/**
|
||||||
|
* This program is free software, you can redistribute it and/or modify.
|
||||||
|
* Copyright (c) 2025 Huawei Technologies Co., Ltd.
|
||||||
|
* This file is a part of the CANN Open Software.
|
||||||
|
* Licensed under CANN Open Software License Agreement Version 2.0 (the "License").
|
||||||
|
* Please refer to the License for details. You may not use this file except in compliance with the License.
|
||||||
|
* THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE.
|
||||||
|
* See LICENSE in the root of the software repository for the full text of the License.
|
||||||
|
*/
|
||||||
|
|
||||||
|
/*!
|
||||||
|
* \file tiling_key.h
|
||||||
|
* \brief
|
||||||
|
*/
|
||||||
|
|
||||||
|
#pragma once
|
||||||
|
|
||||||
|
#include <cstdint>
|
||||||
|
|
||||||
|
namespace Ops {
|
||||||
|
namespace Transformer {
|
||||||
|
namespace OpTiling {
|
||||||
|
constexpr uint64_t RecursiveSum()
|
||||||
|
{
|
||||||
|
return 0;
|
||||||
|
}
|
||||||
|
|
||||||
|
constexpr uint64_t kBase = 10; // Base-10 carry base
|
||||||
|
template <typename T, typename... Args> constexpr uint64_t RecursiveSum(T templateId, Args... templateIds)
|
||||||
|
{
|
||||||
|
return static_cast<uint64_t>(templateId) + kBase * RecursiveSum(templateIds...);
|
||||||
|
}
|
||||||
|
|
||||||
|
// TilingKey generation rules:
|
||||||
|
// FlashAttentionScore/FlashAttentionScoreGrad assembles tiling key using decimal digits, containing the following key parameters from low to high: Ub0, Ub1,
|
||||||
|
// Block, DataType, Format, Sparse. Specialized template Ub0, Ub1:
|
||||||
|
// Represents the axis for UB intra-core splitting, using AxisEnum. Since we allow at most two axes to be split, UB0 and UB1 exist. If there is no UB intra-core splitting,
|
||||||
|
// fill with AXIS_NONE. UB0 and UB1 each occupy one decimal digit;
|
||||||
|
// Block: Represents the axis used by UB for multi-core splitting, using AxisEnum, occupies one decimal digit;
|
||||||
|
// DataType: Represents the input/output data types supported by the current tiling key, using SupportedDtype enum, occupies one decimal digit
|
||||||
|
// Format: Represents the Format supported by the current tiling key, using InputLayout enum, occupies one decimal digit
|
||||||
|
// Sparse: Represents whether the current tiling key supports Sparse, using SparseCapability enum, occupies one decimal digit
|
||||||
|
// For other specialized scenarios, define your own bit fields and values
|
||||||
|
// usage: get tilingKey from inputted types
|
||||||
|
// uint64_t tilingKey = GET_FLASHATTENTION_TILINGKEY(AxisEnum::AXIS_S1, AxisEnum::AXIS_S2, AxisEnum::AXIS_N2,
|
||||||
|
// SupportedDtype::FLOAT32, InputLayout::BSH, SparseCapability::SUPPORT_ALL)
|
||||||
|
|
||||||
|
constexpr uint64_t TILINGKEYOFFSET = uint64_t(10000000000000000000UL); // 10^19
|
||||||
|
template <typename... Args> constexpr uint64_t GET_TILINGKEY(Args... templateIds)
|
||||||
|
{
|
||||||
|
return TILINGKEYOFFSET + RecursiveSum(templateIds...);
|
||||||
|
}
|
||||||
|
|
||||||
|
// usage: get tilingKey from inputted types
|
||||||
|
// uint64_t tilingKey = TILINGKEY(S2, S1, N2, FLOAT32, BSND, ALL)
|
||||||
|
|
||||||
|
#define TILINGKEY(ub2, ub1, block, dtype, layout, sparse) \
|
||||||
|
(GET_TILINGKEY(AxisEnum::ub2, AxisEnum::ub1, AxisEnum::block, DtypeEnum::dtype, LayoutEnum::layout, \
|
||||||
|
SparseEnum::sparse))
|
||||||
|
|
||||||
|
} // namespace Optiling
|
||||||
|
} // namespace Transformer
|
||||||
|
} // namespace Ops
|
||||||
351
csrc/causal_conv1d/tiling_base/tiling_templates_registry.h
Normal file
351
csrc/causal_conv1d/tiling_base/tiling_templates_registry.h
Normal file
@@ -0,0 +1,351 @@
|
|||||||
|
/**
|
||||||
|
* This program is free software, you can redistribute it and/or modify.
|
||||||
|
* Copyright (c) 2025 Huawei Technologies Co., Ltd.
|
||||||
|
* This file is a part of the CANN Open Software.
|
||||||
|
* Licensed under CANN Open Software License Agreement Version 2.0 (the "License").
|
||||||
|
* Please refer to the License for details. You may not use this file except in compliance with the License.
|
||||||
|
* THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE.
|
||||||
|
* See LICENSE in the root of the software repository for the full text of the License.
|
||||||
|
*/
|
||||||
|
|
||||||
|
/*!
|
||||||
|
* \file tiling_templates_registry.h
|
||||||
|
* \brief
|
||||||
|
*/
|
||||||
|
|
||||||
|
#pragma once
|
||||||
|
|
||||||
|
#include <map>
|
||||||
|
#include <string>
|
||||||
|
#include <memory>
|
||||||
|
#include "exe_graph/runtime/tiling_context.h"
|
||||||
|
#include "tiling_base.h"
|
||||||
|
#include "error_log.h"
|
||||||
|
|
||||||
|
namespace Ops {
|
||||||
|
namespace Transformer {
|
||||||
|
namespace OpTiling {
|
||||||
|
|
||||||
|
template <typename T>
|
||||||
|
std::unique_ptr<TilingBaseClass> TILING_CLASS(gert::TilingContext* context)
|
||||||
|
{
|
||||||
|
return std::unique_ptr<T>(new (std::nothrow) T(context));
|
||||||
|
}
|
||||||
|
|
||||||
|
using TilingClassCase = std::unique_ptr<TilingBaseClass> (*)(gert::TilingContext*);
|
||||||
|
|
||||||
|
class TilingCases {
|
||||||
|
public:
|
||||||
|
explicit TilingCases(std::string op_type) : op_type_(std::move(op_type))
|
||||||
|
{}
|
||||||
|
|
||||||
|
template <typename T>
|
||||||
|
void AddTiling(int32_t priority)
|
||||||
|
{
|
||||||
|
OP_CHECK_IF(
|
||||||
|
cases_.find(priority) != cases_.end(), OP_LOGE(op_type_, "There are duplicate registrations."), return);
|
||||||
|
cases_[priority] = TILING_CLASS<T>;
|
||||||
|
OP_CHECK_IF(
|
||||||
|
cases_[priority] == nullptr,
|
||||||
|
OP_LOGE(op_type_, "Register op tiling func failed, please check the class name."), return);
|
||||||
|
}
|
||||||
|
|
||||||
|
const std::map<int32_t, TilingClassCase>& GetTilingCases()
|
||||||
|
{
|
||||||
|
return cases_;
|
||||||
|
}
|
||||||
|
|
||||||
|
private:
|
||||||
|
std::map<int32_t, TilingClassCase> cases_;
|
||||||
|
const std::string op_type_;
|
||||||
|
};
|
||||||
|
|
||||||
|
// --------------------------------Interfacce with soc version --------------------------------
|
||||||
|
class TilingRegistryNew {
|
||||||
|
public:
|
||||||
|
TilingRegistryNew() = default;
|
||||||
|
|
||||||
|
#ifdef ASCENDC_OP_TEST
|
||||||
|
static TilingRegistryNew& GetInstance();
|
||||||
|
#else
|
||||||
|
static TilingRegistryNew& GetInstance()
|
||||||
|
{
|
||||||
|
static TilingRegistryNew registry_impl_;
|
||||||
|
return registry_impl_;
|
||||||
|
}
|
||||||
|
#endif
|
||||||
|
|
||||||
|
std::shared_ptr<TilingCases> RegisterOp(const std::string& op_type, int32_t soc_version)
|
||||||
|
{
|
||||||
|
auto soc_iter = registry_map_.find(soc_version);
|
||||||
|
if (soc_iter == registry_map_.end()) {
|
||||||
|
std::map<std::string, std::shared_ptr<TilingCases>> op_type_map;
|
||||||
|
op_type_map[op_type] = std::shared_ptr<TilingCases>(new (std::nothrow) TilingCases(op_type));
|
||||||
|
registry_map_[soc_version] = op_type_map;
|
||||||
|
} else {
|
||||||
|
if (soc_iter->second.find(op_type) == soc_iter->second.end()) {
|
||||||
|
soc_iter->second[op_type] = std::shared_ptr<TilingCases>(new (std::nothrow) TilingCases(op_type));
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
OP_CHECK_IF(
|
||||||
|
registry_map_[soc_version][op_type] == nullptr,
|
||||||
|
OP_LOGE(op_type, "Register tiling func failed, please check the class name."), return nullptr);
|
||||||
|
return registry_map_[soc_version][op_type];
|
||||||
|
}
|
||||||
|
|
||||||
|
ge::graphStatus DoTilingImpl(gert::TilingContext* context)
|
||||||
|
{
|
||||||
|
int32_t soc_version = (int32_t)platform_ascendc::SocVersion::RESERVED_VERSION;
|
||||||
|
const char* op_type = context->GetNodeType();
|
||||||
|
fe::PlatFormInfos* platformInfoPtr = context->GetPlatformInfo();
|
||||||
|
if (platformInfoPtr == nullptr) {
|
||||||
|
auto compileInfoPtr = static_cast<const CompileInfoCommon*>(context->GetCompileInfo());
|
||||||
|
OP_CHECK_IF(
|
||||||
|
compileInfoPtr == nullptr, OP_LOGE(op_type, "compileInfoPtr is null."), return ge::GRAPH_FAILED);
|
||||||
|
soc_version = compileInfoPtr->socVersion;
|
||||||
|
OP_LOGD(context, "soc version in compileInfo is %d", soc_version);
|
||||||
|
} else {
|
||||||
|
auto ascendcPlatform = platform_ascendc::PlatformAscendC(platformInfoPtr);
|
||||||
|
soc_version = static_cast<int32_t>(ascendcPlatform.GetSocVersion());
|
||||||
|
OP_LOGD(context, "soc version is %d", soc_version);
|
||||||
|
if (soc_version == (int32_t)platform_ascendc::SocVersion::RESERVED_VERSION) {
|
||||||
|
OP_LOGE(op_type, "Do op tiling failed, cannot find soc version.");
|
||||||
|
return ge::GRAPH_FAILED;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
auto tilingTemplateRegistryMap = GetTilingTemplates(op_type, soc_version);
|
||||||
|
for (auto it = tilingTemplateRegistryMap.begin(); it != tilingTemplateRegistryMap.end(); ++it) {
|
||||||
|
auto tilingTemplate = it->second(context);
|
||||||
|
if (tilingTemplate != nullptr) {
|
||||||
|
ge::graphStatus status = tilingTemplate->DoTiling();
|
||||||
|
if (status != ge::GRAPH_PARAM_INVALID) {
|
||||||
|
OP_LOGD(context, "Do general op tiling success priority=%d", it->first);
|
||||||
|
return status;
|
||||||
|
}
|
||||||
|
OP_LOGD(context, "Ignore general op tiling priority=%d", it->first);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
OP_LOGE(op_type, "Do op tiling failed, no valid template is found.");
|
||||||
|
return ge::GRAPH_FAILED;
|
||||||
|
}
|
||||||
|
|
||||||
|
ge::graphStatus DoTilingImpl(gert::TilingContext* context, const std::vector<int32_t>& priorities)
|
||||||
|
{
|
||||||
|
int32_t soc_version;
|
||||||
|
const char* op_type = context->GetNodeType();
|
||||||
|
auto platformInfoPtr = context->GetPlatformInfo();
|
||||||
|
if (platformInfoPtr == nullptr) {
|
||||||
|
auto compileInfoPtr = reinterpret_cast<const CompileInfoCommon*>(context->GetCompileInfo());
|
||||||
|
OP_CHECK_IF(
|
||||||
|
compileInfoPtr == nullptr, OP_LOGE(op_type, "compileInfoPtr is null."), return ge::GRAPH_FAILED);
|
||||||
|
soc_version = compileInfoPtr->socVersion;
|
||||||
|
OP_LOGD(context, "soc version in compileInfo is %d", soc_version);
|
||||||
|
} else {
|
||||||
|
auto ascendcPlatform = platform_ascendc::PlatformAscendC(platformInfoPtr);
|
||||||
|
soc_version = static_cast<int32_t>(ascendcPlatform.GetSocVersion());
|
||||||
|
OP_LOGD(context, "soc version is %d", soc_version);
|
||||||
|
}
|
||||||
|
|
||||||
|
auto tilingTemplateRegistryMap = GetTilingTemplates(op_type, soc_version);
|
||||||
|
for (auto priority_id : priorities) {
|
||||||
|
auto tilingCaseIter = tilingTemplateRegistryMap.find(priority_id);
|
||||||
|
if (tilingCaseIter != tilingTemplateRegistryMap.end()) {
|
||||||
|
auto templateFunc = tilingCaseIter->second(context);
|
||||||
|
if (templateFunc != nullptr) {
|
||||||
|
ge::graphStatus status = templateFunc->DoTiling();
|
||||||
|
if (status == ge::GRAPH_SUCCESS) {
|
||||||
|
OP_LOGD(context, "Do general op tiling success priority=%d", priority_id);
|
||||||
|
return status;
|
||||||
|
}
|
||||||
|
OP_LOGD(context, "Ignore general op tiling priority=%d", priority_id);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return ge::GRAPH_FAILED;
|
||||||
|
}
|
||||||
|
|
||||||
|
const std::map<int32_t, TilingClassCase>& GetTilingTemplates(const std::string& op_type, int32_t soc_version)
|
||||||
|
{
|
||||||
|
auto soc_iter = registry_map_.find(soc_version);
|
||||||
|
OP_CHECK_IF(
|
||||||
|
soc_iter == registry_map_.end(),
|
||||||
|
OP_LOGE(op_type, "Get op tiling func failed, please check the soc version %d", soc_version),
|
||||||
|
return empty_tiling_case_);
|
||||||
|
auto op_iter = soc_iter->second.find(op_type);
|
||||||
|
OP_CHECK_IF(
|
||||||
|
op_iter == soc_iter->second.end(), OP_LOGE(op_type, "Get op tiling func failed, please check the op name."),
|
||||||
|
return empty_tiling_case_);
|
||||||
|
return op_iter->second->GetTilingCases();
|
||||||
|
}
|
||||||
|
|
||||||
|
private:
|
||||||
|
std::map<int32_t, std::map<std::string, std::shared_ptr<TilingCases>>> registry_map_; // key is socversion
|
||||||
|
const std::map<int32_t, TilingClassCase> empty_tiling_case_{};
|
||||||
|
};
|
||||||
|
|
||||||
|
class RegisterNew {
|
||||||
|
public:
|
||||||
|
explicit RegisterNew(std::string op_type) : op_type_(std::move(op_type))
|
||||||
|
{}
|
||||||
|
|
||||||
|
template <typename T>
|
||||||
|
RegisterNew& tiling(int32_t priority, int32_t soc_version)
|
||||||
|
{
|
||||||
|
auto tilingCases = TilingRegistryNew::GetInstance().RegisterOp(op_type_, soc_version);
|
||||||
|
OP_CHECK_IF(
|
||||||
|
tilingCases == nullptr, OP_LOGE(op_type_, "Register op tiling failed, please the op name."), return *this);
|
||||||
|
tilingCases->AddTiling<T>(priority);
|
||||||
|
return *this;
|
||||||
|
}
|
||||||
|
|
||||||
|
template <typename T>
|
||||||
|
RegisterNew& tiling(int32_t priority, const std::vector<int32_t>& soc_versions)
|
||||||
|
{
|
||||||
|
for (int32_t soc_version : soc_versions) {
|
||||||
|
auto tilingCases = TilingRegistryNew::GetInstance().RegisterOp(op_type_, soc_version);
|
||||||
|
OP_CHECK_IF(
|
||||||
|
tilingCases == nullptr, OP_LOGE(op_type_, "Register op tiling failed, please the op name."),
|
||||||
|
return *this);
|
||||||
|
tilingCases->AddTiling<T>(priority);
|
||||||
|
}
|
||||||
|
return *this;
|
||||||
|
}
|
||||||
|
|
||||||
|
private:
|
||||||
|
const std::string op_type_;
|
||||||
|
};
|
||||||
|
|
||||||
|
// --------------------------------Interfacce without soc version --------------------------------
|
||||||
|
class TilingRegistry {
|
||||||
|
public:
|
||||||
|
TilingRegistry() = default;
|
||||||
|
|
||||||
|
#ifdef ASCENDC_OP_TEST
|
||||||
|
static TilingRegistry& GetInstance();
|
||||||
|
#else
|
||||||
|
static TilingRegistry& GetInstance()
|
||||||
|
{
|
||||||
|
static TilingRegistry registry_impl_;
|
||||||
|
return registry_impl_;
|
||||||
|
}
|
||||||
|
#endif
|
||||||
|
|
||||||
|
std::shared_ptr<TilingCases> RegisterOp(const std::string& op_type)
|
||||||
|
{
|
||||||
|
if (registry_map_.find(op_type) == registry_map_.end()) {
|
||||||
|
registry_map_[op_type] = std::shared_ptr<TilingCases>(new (std::nothrow) TilingCases(op_type));
|
||||||
|
}
|
||||||
|
OP_CHECK_IF(
|
||||||
|
registry_map_[op_type] == nullptr,
|
||||||
|
OP_LOGE(op_type, "Register tiling func failed, please check the class name."), return nullptr);
|
||||||
|
return registry_map_[op_type];
|
||||||
|
}
|
||||||
|
|
||||||
|
ge::graphStatus DoTilingImpl(gert::TilingContext* context)
|
||||||
|
{
|
||||||
|
const char* op_type = context->GetNodeType();
|
||||||
|
auto tilingTemplateRegistryMap = GetTilingTemplates(op_type);
|
||||||
|
for (auto it = tilingTemplateRegistryMap.begin(); it != tilingTemplateRegistryMap.end(); ++it) {
|
||||||
|
auto tilingTemplate = it->second(context);
|
||||||
|
if (tilingTemplate != nullptr) {
|
||||||
|
ge::graphStatus status = tilingTemplate->DoTiling();
|
||||||
|
if (status != ge::GRAPH_PARAM_INVALID) {
|
||||||
|
OP_LOGD(context, "Do general op tiling success priority=%d", it->first);
|
||||||
|
return status;
|
||||||
|
}
|
||||||
|
OP_LOGD(context, "Ignore general op tiling priority=%d", it->first);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
OP_LOGE(op_type, "Do op tiling failed, no valid template is found.");
|
||||||
|
return ge::GRAPH_FAILED;
|
||||||
|
}
|
||||||
|
|
||||||
|
ge::graphStatus DoTilingImpl(gert::TilingContext* context, const std::vector<int32_t>& priorities)
|
||||||
|
{
|
||||||
|
const char* op_type = context->GetNodeType();
|
||||||
|
auto tilingTemplateRegistryMap = GetTilingTemplates(op_type);
|
||||||
|
for (auto priorityId : priorities) {
|
||||||
|
auto templateFunc = tilingTemplateRegistryMap[priorityId](context);
|
||||||
|
if (templateFunc != nullptr) {
|
||||||
|
ge::graphStatus status = templateFunc->DoTiling();
|
||||||
|
if (status == ge::GRAPH_SUCCESS) {
|
||||||
|
OP_LOGD(context, "Do general op tiling success priority=%d", priorityId);
|
||||||
|
return status;
|
||||||
|
}
|
||||||
|
if (status != ge::GRAPH_PARAM_INVALID) {
|
||||||
|
OP_LOGD(context, "Do op tiling failed");
|
||||||
|
return status;
|
||||||
|
}
|
||||||
|
OP_LOGD(context, "Ignore general op tiling priority=%d", priorityId);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
OP_LOGE(op_type, "Do op tiling failed, no valid template is found.");
|
||||||
|
return ge::GRAPH_FAILED;
|
||||||
|
}
|
||||||
|
|
||||||
|
const std::map<int32_t, TilingClassCase>& GetTilingTemplates(const std::string& op_type)
|
||||||
|
{
|
||||||
|
OP_CHECK_IF(
|
||||||
|
registry_map_.find(op_type) == registry_map_.end(),
|
||||||
|
OP_LOGE(op_type, "Get op tiling func failed, please check the op name."), return empty_tiling_case_);
|
||||||
|
return registry_map_[op_type]->GetTilingCases();
|
||||||
|
}
|
||||||
|
|
||||||
|
private:
|
||||||
|
std::map<std::string, std::shared_ptr<TilingCases>> registry_map_;
|
||||||
|
const std::map<int32_t, TilingClassCase> empty_tiling_case_;
|
||||||
|
};
|
||||||
|
|
||||||
|
class Register {
|
||||||
|
public:
|
||||||
|
explicit Register(std::string op_type) : op_type_(std::move(op_type))
|
||||||
|
{}
|
||||||
|
|
||||||
|
template <typename T>
|
||||||
|
Register& tiling(int32_t priority)
|
||||||
|
{
|
||||||
|
auto tilingCases = TilingRegistry::GetInstance().RegisterOp(op_type_);
|
||||||
|
OP_CHECK_IF(
|
||||||
|
tilingCases == nullptr, OP_LOGE(op_type_, "Register op tiling failed, please the op name."), return *this);
|
||||||
|
tilingCases->AddTiling<T>(priority);
|
||||||
|
return *this;
|
||||||
|
}
|
||||||
|
|
||||||
|
private:
|
||||||
|
const std::string op_type_;
|
||||||
|
};
|
||||||
|
} // namespace OpTiling
|
||||||
|
} // namespace Transformer
|
||||||
|
} // namespace Ops
|
||||||
|
|
||||||
|
// op_type: operator name, class_name: registered tiling class, soc_version: chip version number
|
||||||
|
// priority: priority of tiling class, smaller value means higher priority, i.e., this tiling class will be selected first
|
||||||
|
#define REGISTER_TILING_TEMPLATE_WITH_SOCVERSION(op_type, class_name, soc_versions, priority) \
|
||||||
|
[[maybe_unused]] uint32_t op_impl_register_template_##op_type##_##class_name##priority; \
|
||||||
|
static Ops::Transformer::OpTiling::RegisterNew VAR_UNUSED##op_type##class_name##priority_register = \
|
||||||
|
Ops::Transformer::OpTiling::RegisterNew(#op_type).tiling<class_name>(priority, soc_versions)
|
||||||
|
|
||||||
|
// op_type: operator name, class_name: registered tiling class
|
||||||
|
// priority: priority of tiling class, smaller value means higher priority, i.e., higher probability of being selected
|
||||||
|
#define REGISTER_TILING_TEMPLATE(op_type, class_name, priority) \
|
||||||
|
[[maybe_unused]] uint32_t op_impl_register_template_##op_type##_##class_name##priority; \
|
||||||
|
static Ops::Transformer::OpTiling::Register VAR_UNUSED##op_type_##class_name##priority_register = \
|
||||||
|
Ops::Transformer::OpTiling::Register(op_type).tiling<class_name>(priority)
|
||||||
|
|
||||||
|
// op_type: operator name, class_name: registered tiling class
|
||||||
|
// soc_version: SOC version, used to distinguish different SOCs
|
||||||
|
// priority: priority of tiling class, smaller value means higher priority, i.e., this tiling class will be selected first
|
||||||
|
#define REGISTER_TILING_TEMPLATE_NEW(op_type, class_name, soc_version, priority) \
|
||||||
|
[[maybe_unused]] uint32_t op_impl_register_template_##op_type##_##class_name##priority; \
|
||||||
|
static Ops::Transformer::OpTiling::RegisterNew VAR_UNUSED##op_type##class_name##priority_register = \
|
||||||
|
Ops::Transformer::OpTiling::RegisterNew(#op_type).tiling<class_name>(priority, soc_version)
|
||||||
|
|
||||||
|
// op_type: operator name, class_name: registered tiling class
|
||||||
|
// priority: priority of tiling class, smaller value means higher priority, i.e., higher probability of being selected
|
||||||
|
// Replaces REGISTER_TILING_TEMPLATE, if op_type is a string constant, remove the quotes
|
||||||
|
#define REGISTER_OPS_TILING_TEMPLATE(op_type, class_name, priority) \
|
||||||
|
[[maybe_unused]] uint32_t op_impl_register_template_##op_type##_##class_name##priority; \
|
||||||
|
static Ops::Transformer::OpTiling::Register \
|
||||||
|
__attribute__((unused)) tiling_##op_type##_##class_name##_##priority##_register = \
|
||||||
|
Ops::Transformer::OpTiling::Register(#op_type).tiling<class_name>(priority)
|
||||||
139
csrc/causal_conv1d/tiling_base/tiling_type.h
Normal file
139
csrc/causal_conv1d/tiling_base/tiling_type.h
Normal file
@@ -0,0 +1,139 @@
|
|||||||
|
/**
|
||||||
|
* This program is free software, you can redistribute it and/or modify.
|
||||||
|
* Copyright (c) 2025 Huawei Technologies Co., Ltd.
|
||||||
|
* This file is a part of the CANN Open Software.
|
||||||
|
* Licensed under CANN Open Software License Agreement Version 2.0 (the "License").
|
||||||
|
* Please refer to the License for details. You may not use this file except in compliance with the License.
|
||||||
|
* THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE.
|
||||||
|
* See LICENSE in the root of the software repository for the full text of the License.
|
||||||
|
*/
|
||||||
|
|
||||||
|
/*!
|
||||||
|
* \file tiling_type.h
|
||||||
|
* \brief
|
||||||
|
*/
|
||||||
|
|
||||||
|
#pragma once
|
||||||
|
|
||||||
|
#include <cstdint>
|
||||||
|
|
||||||
|
namespace optiling {
|
||||||
|
|
||||||
|
enum class AxisEnum {
|
||||||
|
B = 0,
|
||||||
|
N2 = 1,
|
||||||
|
G = 2,
|
||||||
|
S1 = 3,
|
||||||
|
S2 = 4,
|
||||||
|
D = 5,
|
||||||
|
NONE = 9,
|
||||||
|
};
|
||||||
|
|
||||||
|
enum class DtypeEnum {
|
||||||
|
FLOAT16 = 0,
|
||||||
|
FLOAT32 = 1,
|
||||||
|
BFLOAT16 = 2,
|
||||||
|
FLOAT16_PRECISION = 3,
|
||||||
|
};
|
||||||
|
|
||||||
|
enum class PerformanceOrientedEnum {
|
||||||
|
BIG_BUFFER = 1,
|
||||||
|
BIG_DOUBLE_BUFFER = 2,
|
||||||
|
};
|
||||||
|
|
||||||
|
enum class MatmulConfig {
|
||||||
|
NULL_CONFIG = 0,
|
||||||
|
NORMAL_CONFIG = 1,
|
||||||
|
MDL_CONFIG = 2
|
||||||
|
};
|
||||||
|
|
||||||
|
enum class PseConfig {
|
||||||
|
NO_PSE = 0,
|
||||||
|
EXIST_PSE = 1
|
||||||
|
};
|
||||||
|
|
||||||
|
enum class AttenMaskConfig {
|
||||||
|
NO_ATTEN_MASK = 0,
|
||||||
|
EXIST_ATTEN_MASK = 1
|
||||||
|
};
|
||||||
|
|
||||||
|
enum class DropOutConfig {
|
||||||
|
NO_DROP_OUT = 0,
|
||||||
|
EXIST_DROP_OUT = 1
|
||||||
|
};
|
||||||
|
|
||||||
|
enum class CubeFormatEnum {
|
||||||
|
ND = 0,
|
||||||
|
NZ = 1
|
||||||
|
};
|
||||||
|
enum class LayoutEnum {
|
||||||
|
BSND = 0,
|
||||||
|
SBND = 1,
|
||||||
|
BNSD = 2,
|
||||||
|
TND = 3,
|
||||||
|
NTD_TND = 4
|
||||||
|
};
|
||||||
|
|
||||||
|
enum class CubeInputSourceEnum {
|
||||||
|
GM = 0,
|
||||||
|
L1 = 1
|
||||||
|
};
|
||||||
|
|
||||||
|
enum class OptionEnum {
|
||||||
|
DISABLE = 0,
|
||||||
|
ENABLE = 1
|
||||||
|
};
|
||||||
|
|
||||||
|
enum class SparseEnum {
|
||||||
|
ALL = 0,
|
||||||
|
NONE = 1,
|
||||||
|
ANY = 2,
|
||||||
|
CAUSAL = 3,
|
||||||
|
BAND = 4,
|
||||||
|
PREFIX = 5,
|
||||||
|
BAND_COMPRESS = 6,
|
||||||
|
RIGHT_DOWN_CAUSAL = 7,
|
||||||
|
RIGHT_DOWN_CAUSAL_BAND = 8,
|
||||||
|
BAND_LEFT_UP_CAUSAL = 9
|
||||||
|
};
|
||||||
|
|
||||||
|
constexpr uint64_t RecursiveSum()
|
||||||
|
{
|
||||||
|
return 0;
|
||||||
|
}
|
||||||
|
|
||||||
|
constexpr int64_t base10Multiplier = 10;
|
||||||
|
|
||||||
|
template <typename T, typename... Args> constexpr uint64_t RecursiveSum(T templateId, Args... templateIds)
|
||||||
|
{
|
||||||
|
return static_cast<uint64_t>(templateId) + base10Multiplier * RecursiveSum(templateIds...);
|
||||||
|
}
|
||||||
|
|
||||||
|
// TilingKey generation rules:
|
||||||
|
// FlashAttentionScore/FlashAttentionScoreGrad assembles tiling key using decimal digits, containing the following key parameters from low to high: Ub0, Ub1,
|
||||||
|
// Block, DataType, Format, Sparse. Specialized template Ub0, Ub1:
|
||||||
|
// Represents the axis for UB intra-core splitting, using AxisEnum. Since we allow at most two axes to be split, UB0 and UB1 exist. If there is no UB intra-core splitting,
|
||||||
|
// fill with AXIS_NONE. UB0 and UB1 each occupy one decimal digit;
|
||||||
|
// Block: Represents the axis used by UB for multi-core splitting, using AxisEnum, occupies one decimal digit;
|
||||||
|
// DataType: Represents the input/output data types supported by the current tiling key, using SupportedDtype enum, occupies one decimal digit
|
||||||
|
// Format: Represents the Format supported by the current tiling key, using InputLayout enum, occupies one decimal digit
|
||||||
|
// Sparse: Represents whether the current tiling key supports Sparse, using SparseCapability enum, occupies one decimal digit
|
||||||
|
// For other specialized scenarios, define your own bit fields and values
|
||||||
|
// usage: get tilingKey from inputted types
|
||||||
|
// uint64_t tilingKey = GET_FLASHATTENTION_TILINGKEY(AxisEnum::AXIS_S1, AxisEnum::AXIS_S2, AxisEnum::AXIS_N2,
|
||||||
|
// SupportedDtype::FLOAT32, InputLayout::BSH, SparseCapability::SUPPORT_ALL)
|
||||||
|
|
||||||
|
constexpr uint64_t TILINGKEYOFFSET = uint64_t(10000000000000000000UL); // 10^19
|
||||||
|
template <typename... Args> constexpr uint64_t GET_TILINGKEY(Args... templateIds)
|
||||||
|
{
|
||||||
|
return TILINGKEYOFFSET + RecursiveSum(templateIds...);
|
||||||
|
}
|
||||||
|
|
||||||
|
// usage: get tilingKey from inputted types
|
||||||
|
// uint64_t tilingKey = TILINGKEY(S2, S1, N2, FLOAT32, BSND, ALL)
|
||||||
|
|
||||||
|
#define TILINGKEY(ub2, ub1, block, dtype, layout, sparse) \
|
||||||
|
(GET_TILINGKEY(AxisEnum::ub2, AxisEnum::ub1, AxisEnum::block, DtypeEnum::dtype, LayoutEnum::layout, \
|
||||||
|
SparseEnum::sparse))
|
||||||
|
|
||||||
|
} // namespace optiling
|
||||||
30
csrc/causal_conv1d/tiling_base/tiling_util.h
Normal file
30
csrc/causal_conv1d/tiling_base/tiling_util.h
Normal file
@@ -0,0 +1,30 @@
|
|||||||
|
/**
|
||||||
|
* This program is free software, you can redistribute it and/or modify.
|
||||||
|
* Copyright (c) 2025 Huawei Technologies Co., Ltd.
|
||||||
|
* This file is a part of the CANN Open Software.
|
||||||
|
* Licensed under CANN Open Software License Agreement Version 2.0 (the "License").
|
||||||
|
* Please refer to the License for details. You may not use this file except in compliance with the License.
|
||||||
|
* THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY, EITHER EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE.
|
||||||
|
* See LICENSE in the root of the software repository for the full text of the License.
|
||||||
|
*/
|
||||||
|
|
||||||
|
/*!
|
||||||
|
* \file tiling_util.h
|
||||||
|
* \brief
|
||||||
|
*/
|
||||||
|
|
||||||
|
#pragma once
|
||||||
|
|
||||||
|
#include "register/op_impl_registry.h"
|
||||||
|
|
||||||
|
namespace Ops {
|
||||||
|
namespace Transformer {
|
||||||
|
namespace OpTiling {
|
||||||
|
bool IsRegbaseSocVersion(const gert::TilingParseContext* context);
|
||||||
|
|
||||||
|
bool IsRegbaseSocVersion(const gert::TilingContext* context);
|
||||||
|
|
||||||
|
const gert::Shape& EnsureNotScalar(const gert::Shape& inShape);
|
||||||
|
} // namespace OpTiling
|
||||||
|
} // namespace Transformer
|
||||||
|
} // namespace Ops
|
||||||
@@ -597,6 +597,44 @@ void transpose_kv_cache_by_block(
|
|||||||
|
|
||||||
}
|
}
|
||||||
|
|
||||||
|
at::Tensor causal_conv1d_fn(
|
||||||
|
const at::Tensor& mixed_qkv_non_spec_T,
|
||||||
|
const at::Tensor& conv_weights,
|
||||||
|
const c10::optional<at::Tensor>& bias_opt,
|
||||||
|
c10::string_view activation,
|
||||||
|
const at::Tensor& conv_state,
|
||||||
|
const at::Tensor& has_initial_state,
|
||||||
|
const at::Tensor& non_spec_state_indices_tensor,
|
||||||
|
const at::Tensor& non_spec_query_start_loc,
|
||||||
|
int64_t pad_slot_id)
|
||||||
|
{
|
||||||
|
at::Tensor x=mixed_qkv_non_spec_T; //不需要转置
|
||||||
|
at::Tensor weight=conv_weights;//不需要转置
|
||||||
|
c10::optional<at::Tensor> biasOptional =bias_opt;
|
||||||
|
at::Tensor convStates= conv_state;
|
||||||
|
at::Tensor queryStartLoc=non_spec_query_start_loc;
|
||||||
|
at::Tensor cacheIndices=non_spec_state_indices_tensor;
|
||||||
|
at::Tensor hasInitialState=has_initial_state;
|
||||||
|
int64_t activationMode=(activation.empty()?0:1);
|
||||||
|
int64_t padSlotId=pad_slot_id;
|
||||||
|
|
||||||
|
at::Tensor output = at::empty(mixed_qkv_non_spec_T.sizes(), mixed_qkv_non_spec_T.options());
|
||||||
|
EXEC_NPU_CMD(aclnnCausalConv1d,
|
||||||
|
x,
|
||||||
|
weight,
|
||||||
|
biasOptional,
|
||||||
|
convStates,
|
||||||
|
queryStartLoc,
|
||||||
|
cacheIndices,
|
||||||
|
hasInitialState,
|
||||||
|
activationMode,
|
||||||
|
padSlotId,
|
||||||
|
output
|
||||||
|
);
|
||||||
|
|
||||||
|
return output;
|
||||||
|
}
|
||||||
|
|
||||||
// It is expected that further improvements will be made after it is incorporated into CANN on June 30th.
|
// It is expected that further improvements will be made after it is incorporated into CANN on June 30th.
|
||||||
std::vector<at::Tensor> moe_grouped_matmul(
|
std::vector<at::Tensor> moe_grouped_matmul(
|
||||||
at::Tensor x,
|
at::Tensor x,
|
||||||
@@ -811,6 +849,18 @@ TORCH_LIBRARY_EXPAND(CONCAT(_C, _ascend), ops)
|
|||||||
"transpose_kv_cache_by_block(Tensor[] kCache, Tensor[] vCache, Tensor blockIDs, int blockSize, int headNum, int headDim, int splitNum, int layerNum) -> ()"
|
"transpose_kv_cache_by_block(Tensor[] kCache, Tensor[] vCache, Tensor blockIDs, int blockSize, int headNum, int headDim, int splitNum, int layerNum) -> ()"
|
||||||
);
|
);
|
||||||
ops.impl("transpose_kv_cache_by_block", torch::kPrivateUse1, &vllm_ascend::transpose_kv_cache_by_block);
|
ops.impl("transpose_kv_cache_by_block", torch::kPrivateUse1, &vllm_ascend::transpose_kv_cache_by_block);
|
||||||
|
// causal_conv1d_fn
|
||||||
|
ops.def(
|
||||||
|
"causal_conv1d_fn(Tensor mixed_qkv_non_spec_T, "
|
||||||
|
" Tensor conv_weights, "
|
||||||
|
" Tensor? bias_opt, "
|
||||||
|
" str activation, "
|
||||||
|
" Tensor conv_state, "
|
||||||
|
" Tensor has_initial_state, "
|
||||||
|
" Tensor non_spec_state_indices_tensor, "
|
||||||
|
" Tensor non_spec_query_start_loc, "
|
||||||
|
" int pad_slot_id) -> (Tensor output)");
|
||||||
|
ops.impl("causal_conv1d_fn", torch::kPrivateUse1, &vllm_ascend::causal_conv1d_fn);
|
||||||
ops.def(
|
ops.def(
|
||||||
"moe_grouped_matmul("
|
"moe_grouped_matmul("
|
||||||
"Tensor x,"
|
"Tensor x,"
|
||||||
|
|||||||
@@ -458,6 +458,22 @@ void transpose_kv_cache_by_block_meta(
|
|||||||
return;
|
return;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
at::Tensor causal_conv1d_fn_meta(
|
||||||
|
const at::Tensor& mixed_qkv_non_spec_T,
|
||||||
|
const at::Tensor& conv_weights,
|
||||||
|
const c10::optional<at::Tensor>& bias_opt,
|
||||||
|
c10::string_view activation,
|
||||||
|
const at::Tensor& conv_state,
|
||||||
|
const at::Tensor& has_initial_state,
|
||||||
|
const at::Tensor& non_spec_state_indices_tensor,
|
||||||
|
const at::Tensor& non_spec_query_start_loc,
|
||||||
|
int64_t pad_slot_id)
|
||||||
|
{
|
||||||
|
|
||||||
|
at::Tensor output = at::empty_symint(mixed_qkv_non_spec_T.sym_sizes(), mixed_qkv_non_spec_T.options());
|
||||||
|
return output;
|
||||||
|
}
|
||||||
|
|
||||||
std::vector<at::Tensor> moe_grouped_matmul_meta(
|
std::vector<at::Tensor> moe_grouped_matmul_meta(
|
||||||
at::Tensor x,
|
at::Tensor x,
|
||||||
at::Tensor weight,
|
at::Tensor weight,
|
||||||
@@ -527,6 +543,8 @@ TORCH_LIBRARY_IMPL_EXPAND(CONCAT(_C, _ascend), Meta, ops) {
|
|||||||
ops.impl("npu_add_rms_norm_bias", &vllm_ascend::meta::npu_add_rms_norm_bias_meta);
|
ops.impl("npu_add_rms_norm_bias", &vllm_ascend::meta::npu_add_rms_norm_bias_meta);
|
||||||
// transpose_kv_cache_by_block
|
// transpose_kv_cache_by_block
|
||||||
ops.impl("transpose_kv_cache_by_block", &vllm_ascend::meta::transpose_kv_cache_by_block_meta);
|
ops.impl("transpose_kv_cache_by_block", &vllm_ascend::meta::transpose_kv_cache_by_block_meta);
|
||||||
|
// causal_conv1d_fn
|
||||||
|
ops.impl("causal_conv1d_fn", &vllm_ascend::meta::causal_conv1d_fn_meta);
|
||||||
// moe_grouped_matmul
|
// moe_grouped_matmul
|
||||||
ops.impl("moe_grouped_matmul", &vllm_ascend::meta::moe_grouped_matmul_meta);
|
ops.impl("moe_grouped_matmul", &vllm_ascend::meta::moe_grouped_matmul_meta);
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -92,10 +92,15 @@ def test_qwen3_next_mtp_acceptance_tp4(model_name):
|
|||||||
@pytest.mark.parametrize("model_name", MODELS)
|
@pytest.mark.parametrize("model_name", MODELS)
|
||||||
@pytest.mark.parametrize("num_speculative_tokens", [1])
|
@pytest.mark.parametrize("num_speculative_tokens", [1])
|
||||||
@pytest.mark.parametrize("disable_padded_drafter_batch", [True, False])
|
@pytest.mark.parametrize("disable_padded_drafter_batch", [True, False])
|
||||||
|
@pytest.mark.skip("Skip this CI.")
|
||||||
def test_qwen3_next_mtp_correctness_tp4(model_name: str,
|
def test_qwen3_next_mtp_correctness_tp4(model_name: str,
|
||||||
num_speculative_tokens: int,
|
num_speculative_tokens: int,
|
||||||
disable_padded_drafter_batch: bool):
|
disable_padded_drafter_batch: bool):
|
||||||
example_prompts = [
|
example_prompts = [
|
||||||
|
"Hello, my name is",
|
||||||
|
"The president of the United States is",
|
||||||
|
"The capital of France is",
|
||||||
|
"The future of AI is",
|
||||||
"Hello, my name is",
|
"Hello, my name is",
|
||||||
"The president of the United States is",
|
"The president of the United States is",
|
||||||
"The capital of France is",
|
"The capital of France is",
|
||||||
|
|||||||
@@ -8,7 +8,7 @@ from vllm_ascend.ops.triton.mamba.causal_conv1d import (PAD_SLOT_ID,
|
|||||||
causal_conv1d_fn)
|
causal_conv1d_fn)
|
||||||
from vllm_ascend.ops.triton.mamba.causal_conv1d import \
|
from vllm_ascend.ops.triton.mamba.causal_conv1d import \
|
||||||
causal_conv1d_update_npu as causal_conv1d_update
|
causal_conv1d_update_npu as causal_conv1d_update
|
||||||
|
from vllm_ascend.utils import enable_custom_op
|
||||||
|
|
||||||
def validate_cmp(y_cal, y_ref, dtype, device='npu'):
|
def validate_cmp(y_cal, y_ref, dtype, device='npu'):
|
||||||
y_cal = y_cal.to(device)
|
y_cal = y_cal.to(device)
|
||||||
@@ -157,6 +157,90 @@ def causal_conv1d_fn_pytorch(
|
|||||||
return out_ref_tensor
|
return out_ref_tensor
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.parametrize('has_initial_state', [False, True])
|
||||||
|
@pytest.mark.parametrize('itype', [torch.bfloat16])
|
||||||
|
@pytest.mark.parametrize('silu_activation', [True])
|
||||||
|
@pytest.mark.parametrize('has_bias', [True])
|
||||||
|
@pytest.mark.parametrize('seq_len', [[128, 1024, 2048, 4096]])
|
||||||
|
@pytest.mark.parametrize('extra_state_len', [0, 2])
|
||||||
|
@pytest.mark.parametrize('width', [4])
|
||||||
|
@pytest.mark.parametrize('dim', [2048])
|
||||||
|
def test_ascend_causal_conv1d(dim, width, extra_state_len, seq_len, has_bias,
|
||||||
|
silu_activation, itype, has_initial_state):
|
||||||
|
|
||||||
|
torch.random.manual_seed(0)
|
||||||
|
enable_custom_op()
|
||||||
|
device = "npu"
|
||||||
|
cu_seqlen, num_seq = sum(seq_len), len(seq_len)
|
||||||
|
state_len = width - 1 + extra_state_len
|
||||||
|
|
||||||
|
x = torch.randn(cu_seqlen, dim, device=device, dtype=itype).transpose(0, 1)
|
||||||
|
weight = torch.randn(dim, width, device=device, dtype=itype)#
|
||||||
|
query_start_loc = torch.cumsum(torch.tensor([0] + seq_len,
|
||||||
|
device=device,
|
||||||
|
dtype=torch.int32),
|
||||||
|
dim=0).to(dtype=torch.int32)
|
||||||
|
cache_indices = torch.arange(num_seq, device=device, dtype=torch.int32)
|
||||||
|
has_initial_state_tensor = torch.tensor([has_initial_state] * num_seq,
|
||||||
|
device=device,
|
||||||
|
dtype=torch.bool)
|
||||||
|
activation = None if not silu_activation else "silu"
|
||||||
|
|
||||||
|
if has_initial_state:
|
||||||
|
conv_states = torch.randn((num_seq, state_len, dim),
|
||||||
|
device=device,
|
||||||
|
dtype=itype).transpose(-1, -2)
|
||||||
|
conv_states_ref = torch.randn(
|
||||||
|
(num_seq, state_len, dim), device=device,
|
||||||
|
dtype=itype).transpose(-1, -2).copy_(conv_states)
|
||||||
|
else:
|
||||||
|
conv_states = torch.zeros((num_seq, state_len, dim),
|
||||||
|
device=device,
|
||||||
|
dtype=itype).transpose(-1, -2)
|
||||||
|
conv_states_ref = torch.zeros((num_seq, state_len, dim),
|
||||||
|
device=device,
|
||||||
|
dtype=itype).transpose(-1, -2)
|
||||||
|
|
||||||
|
if has_bias:
|
||||||
|
bias = torch.randn(dim, device=device, dtype=itype)
|
||||||
|
else:
|
||||||
|
bias = None
|
||||||
|
|
||||||
|
out_ref = causal_conv1d_fn_pytorch(
|
||||||
|
x,
|
||||||
|
weight,
|
||||||
|
bias=bias,
|
||||||
|
activation=activation,
|
||||||
|
conv_states=conv_states_ref,
|
||||||
|
has_initial_state=has_initial_state_tensor,
|
||||||
|
cache_indices=cache_indices,
|
||||||
|
query_start_loc=query_start_loc)
|
||||||
|
# out = causal_conv1d_fn(x,
|
||||||
|
# weight,
|
||||||
|
# bias=bias,
|
||||||
|
# activation=activation,
|
||||||
|
# conv_states=conv_states,
|
||||||
|
# has_initial_state=has_initial_state_tensor,
|
||||||
|
# cache_indices=cache_indices,
|
||||||
|
# query_start_loc=query_start_loc)
|
||||||
|
x_origin=x.transpose(-1, -2)
|
||||||
|
weight_origin=weight.transpose(-1, -2)
|
||||||
|
conv_states_origin=conv_states.transpose(-1, -2)
|
||||||
|
out = torch.ops._C_ascend.causal_conv1d_fn(
|
||||||
|
x_origin,
|
||||||
|
weight_origin,
|
||||||
|
bias,
|
||||||
|
activation=activation,
|
||||||
|
conv_state=conv_states_origin,
|
||||||
|
has_initial_state=has_initial_state_tensor,
|
||||||
|
non_spec_state_indices_tensor=cache_indices,
|
||||||
|
non_spec_query_start_loc=query_start_loc,
|
||||||
|
pad_slot_id=PAD_SLOT_ID,
|
||||||
|
).transpose(-1, -2)
|
||||||
|
validate_cmp(out, out_ref, itype)
|
||||||
|
validate_cmp(conv_states, conv_states_ref, itype)
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.parametrize('has_initial_state', [False, True])
|
@pytest.mark.parametrize('has_initial_state', [False, True])
|
||||||
@pytest.mark.parametrize('itype', [torch.bfloat16])
|
@pytest.mark.parametrize('itype', [torch.bfloat16])
|
||||||
@pytest.mark.parametrize('silu_activation', [True])
|
@pytest.mark.parametrize('silu_activation', [True])
|
||||||
|
|||||||
@@ -22,11 +22,12 @@ from einops import rearrange
|
|||||||
from vllm.forward_context import get_forward_context
|
from vllm.forward_context import get_forward_context
|
||||||
from vllm.model_executor.layers.fla.ops import chunk_gated_delta_rule
|
from vllm.model_executor.layers.fla.ops import chunk_gated_delta_rule
|
||||||
from vllm.model_executor.layers.fla.ops.l2norm import l2norm_fwd
|
from vllm.model_executor.layers.fla.ops.l2norm import l2norm_fwd
|
||||||
from vllm.model_executor.layers.mamba.ops.causal_conv1d import causal_conv1d_fn, causal_conv1d_update
|
from vllm.model_executor.layers.mamba.ops.causal_conv1d import causal_conv1d_update
|
||||||
from vllm.model_executor.models.qwen3_next import Qwen3NextGatedDeltaNet
|
from vllm.model_executor.models.qwen3_next import Qwen3NextGatedDeltaNet
|
||||||
from vllm.triton_utils import triton
|
from vllm.triton_utils import triton
|
||||||
from vllm.v1.attention.backend import AttentionMetadata # type: ignore
|
from vllm.v1.attention.backend import AttentionMetadata # type: ignore
|
||||||
from vllm.v1.attention.backends.gdn_attn import GDNAttentionMetadata
|
from vllm.v1.attention.backends.gdn_attn import GDNAttentionMetadata
|
||||||
|
from vllm.v1.attention.backends.utils import PAD_SLOT_ID
|
||||||
|
|
||||||
from vllm_ascend.ops.triton.fla.fused_qkvzba_split_reshape import fused_qkvzba_split_reshape_cat
|
from vllm_ascend.ops.triton.fla.fused_qkvzba_split_reshape import fused_qkvzba_split_reshape_cat
|
||||||
from vllm_ascend.ops.triton.fused_gdn_gating import fused_gdn_gating_patch
|
from vllm_ascend.ops.triton.fused_gdn_gating import fused_gdn_gating_patch
|
||||||
@@ -163,20 +164,18 @@ class AscendQwen3Next_GatedDeltaNet(Qwen3NextGatedDeltaNet):
|
|||||||
# 1.2: Process the remaining part
|
# 1.2: Process the remaining part
|
||||||
if attn_metadata.num_prefills > 0:
|
if attn_metadata.num_prefills > 0:
|
||||||
if mixed_qkv_non_spec is not None:
|
if mixed_qkv_non_spec is not None:
|
||||||
mixed_qkv_non_spec_T = mixed_qkv_non_spec.transpose(0, 1)
|
conv_weights_T = conv_weights.transpose(0, 1)
|
||||||
# - "cache_indices" updates the conv_state cache in positions
|
mixed_qkv_non_spec = torch.ops._C_ascend.causal_conv1d_fn(
|
||||||
# pointed to by "state_indices_tensor"
|
mixed_qkv_non_spec,
|
||||||
mixed_qkv_non_spec = causal_conv1d_fn(
|
conv_weights_T,
|
||||||
mixed_qkv_non_spec_T,
|
|
||||||
conv_weights,
|
|
||||||
self.conv1d.bias,
|
self.conv1d.bias,
|
||||||
activation=self.activation,
|
activation=self.activation,
|
||||||
conv_states=conv_state,
|
conv_state=self_kv_cache[0],
|
||||||
has_initial_state=has_initial_state,
|
has_initial_state=has_initial_state,
|
||||||
cache_indices=non_spec_state_indices_tensor,
|
non_spec_state_indices_tensor=non_spec_state_indices_tensor,
|
||||||
query_start_loc=non_spec_query_start_loc,
|
non_spec_query_start_loc=non_spec_query_start_loc,
|
||||||
metadata=attn_metadata,
|
pad_slot_id=PAD_SLOT_ID,
|
||||||
).transpose(0, 1)
|
)
|
||||||
elif attn_metadata.num_decodes > 0:
|
elif attn_metadata.num_decodes > 0:
|
||||||
mixed_qkv_non_spec = causal_conv1d_update(
|
mixed_qkv_non_spec = causal_conv1d_update(
|
||||||
mixed_qkv_non_spec,
|
mixed_qkv_non_spec,
|
||||||
|
|||||||
Reference in New Issue
Block a user