[qwen3 next ]add ascend c casual_conv1d_fn (#6661)
### What this PR does / why we need it?
add ascend c casual_conv1d_fn
- vLLM version: v0.15.0
- vLLM main:
13397841ab
---------
Signed-off-by: ZT-AIA <1028681969@qq.com>
Signed-off-by: ZT-AIA <63220130+ZT-AIA@users.noreply.github.com>
Co-authored-by: gemini-code-assist[bot] <176961590+gemini-code-assist[bot]@users.noreply.github.com>
This commit is contained in:
50
csrc/causal_conv1d/op_host/CMakeLists.txt
Normal file
50
csrc/causal_conv1d/op_host/CMakeLists.txt
Normal file
@@ -0,0 +1,50 @@
|
||||
# Copyright (c) 2025 Huawei Technologies Co., Ltd.
|
||||
# This file is a part of the CANN Open Software.
|
||||
# Licensed under CANN Open Software License Agreement Version 1.0 (the "License").
|
||||
# Please refer to the License for details. You may not use this file except in compliance with the License.
|
||||
# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED,
|
||||
# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE.
|
||||
# See LICENSE in the root of the software repository for the full text of the License.
|
||||
# ======================================================================================================================
|
||||
|
||||
add_ops_compile_options(
|
||||
OP_NAME CausalConv1d
|
||||
OPTIONS --cce-auto-sync=off
|
||||
-Wno-deprecated-declarations
|
||||
-Werror
|
||||
)
|
||||
|
||||
target_sources(op_host_aclnn PRIVATE
|
||||
causal_conv1d_def.cpp
|
||||
)
|
||||
|
||||
# target_sources(opapi PRIVATE
|
||||
# aclnn_causal_conv1d.cpp
|
||||
# )
|
||||
|
||||
# if (NOT BUILD_OPEN_PROJECT)
|
||||
# target_sources(aclnn_ops_train PRIVATE
|
||||
# aclnn_causal_conv1d.cpp
|
||||
# )
|
||||
|
||||
# target_sources(aclnn_ops_infer PRIVATE
|
||||
# aclnn_causal_conv1d.cpp
|
||||
# )
|
||||
# endif ()
|
||||
|
||||
target_sources(optiling PRIVATE
|
||||
causal_conv1d_tiling.cpp
|
||||
tiling_util.cpp
|
||||
)
|
||||
|
||||
target_include_directories(optiling PRIVATE
|
||||
${CMAKE_CURRENT_SOURCE_DIR}
|
||||
)
|
||||
|
||||
target_sources(opsproto PRIVATE)
|
||||
|
||||
file(GLOB _GMM_Aclnn_header "${CMAKE_CURRENT_SOURCE_DIR}/aclnn_causal_conv1d.h")
|
||||
|
||||
install(FILES ${_GMM_Aclnn_header}
|
||||
DESTINATION ${ACLNN_INC_INSTALL_DIR} OPTIONAL
|
||||
)
|
||||
83
csrc/causal_conv1d/op_host/causal_conv1d_def.cpp
Normal file
83
csrc/causal_conv1d/op_host/causal_conv1d_def.cpp
Normal file
@@ -0,0 +1,83 @@
|
||||
/**
|
||||
* This program is free software, you can redistribute it and/or modify it.
|
||||
* Copyright (c) 2025 Huawei Technologies Co., Ltd.
|
||||
* This file is a part of the CANN Open Software.
|
||||
* Licensed under CANN Open Software License Agreement Version 2.0 (the "License").
|
||||
* Please refer to the License for details. You may not use this file except in compliance with the License.
|
||||
* THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, INCLUDING
|
||||
* BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE.
|
||||
* See LICENSE in the root of the software repository for the full text of the License.
|
||||
*/
|
||||
|
||||
/*!
|
||||
* \file causal_conv1d_def.cpp
|
||||
* \brief
|
||||
*/
|
||||
#include "register/op_def_registry.h"
|
||||
|
||||
namespace ops {
|
||||
|
||||
class CausalConv1d : public OpDef {
|
||||
public:
|
||||
explicit CausalConv1d(const char* name) : OpDef(name)
|
||||
{
|
||||
this->Input("x")
|
||||
.ParamType(REQUIRED)
|
||||
.DataType({ge::DT_FLOAT16, ge::DT_BF16})
|
||||
.FormatList({ge::FORMAT_ND})
|
||||
.AutoContiguous();
|
||||
this->Input("weight")
|
||||
.ParamType(REQUIRED)
|
||||
.DataType({ge::DT_FLOAT16, ge::DT_BF16})
|
||||
.FormatList({ge::FORMAT_ND})
|
||||
.AutoContiguous();
|
||||
this->Input("bias")
|
||||
.ParamType(OPTIONAL)
|
||||
.DataType({ge::DT_FLOAT16, ge::DT_BF16})
|
||||
.FormatList({ge::FORMAT_ND})
|
||||
.AutoContiguous();
|
||||
this->Input("convStates")
|
||||
.ParamType(REQUIRED)
|
||||
.DataType({ge::DT_FLOAT16, ge::DT_BF16})
|
||||
.FormatList({ge::FORMAT_ND})
|
||||
.AutoContiguous();
|
||||
this->Input("queryStartLoc")
|
||||
.ParamType(REQUIRED)
|
||||
.DataTypeList({ge::DT_INT32})
|
||||
.FormatList({ge::FORMAT_ND})
|
||||
.AutoContiguous();
|
||||
this->Input("cacheIndices")
|
||||
.ParamType(REQUIRED)
|
||||
.DataTypeList({ge::DT_INT32})
|
||||
.FormatList({ge::FORMAT_ND})
|
||||
.AutoContiguous();
|
||||
this->Input("hasInitialState")
|
||||
.ParamType(REQUIRED)
|
||||
.DataTypeList({ge::DT_BOOL})
|
||||
.FormatList({ge::FORMAT_ND})
|
||||
.AutoContiguous();
|
||||
|
||||
this->Output("y")
|
||||
.ParamType(REQUIRED)
|
||||
.DataType({ge::DT_FLOAT16, ge::DT_BF16})
|
||||
.FormatList({ge::FORMAT_ND})
|
||||
.AutoContiguous();
|
||||
|
||||
this->Attr("activationMode").AttrType(OPTIONAL).Int(0);
|
||||
this->Attr("padSlotId").AttrType(OPTIONAL).Int(-1);
|
||||
|
||||
OpAICoreConfig aicoreConfig;
|
||||
aicoreConfig.DynamicCompileStaticFlag(true)
|
||||
.DynamicFormatFlag(false)
|
||||
.DynamicRankSupportFlag(true)
|
||||
.DynamicShapeSupportFlag(true)
|
||||
.NeedCheckSupportFlag(false)
|
||||
.PrecisionReduceFlag(true)
|
||||
.ExtendCfgInfo("coreType.value", "AiCore");
|
||||
this->AICore().AddConfig("ascend910b", aicoreConfig);
|
||||
this->AICore().AddConfig("ascend910_93", aicoreConfig);
|
||||
}
|
||||
};
|
||||
OP_ADD(CausalConv1d);
|
||||
|
||||
} // namespace ops
|
||||
49
csrc/causal_conv1d/op_host/causal_conv1d_infershape.cpp
Normal file
49
csrc/causal_conv1d/op_host/causal_conv1d_infershape.cpp
Normal file
@@ -0,0 +1,49 @@
|
||||
/**
|
||||
* This program is free software, you can redistribute it and/or modify it.
|
||||
* Copyright (c) 2025 Huawei Technologies Co., Ltd.
|
||||
* This file is a part of the CANN Open Software.
|
||||
* Licensed under CANN Open Software License Agreement Version 2.0 (the "License").
|
||||
* Please refer to the License for details. You may not use this file except in compliance with the License.
|
||||
* THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, INCLUDING
|
||||
* BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE.
|
||||
* See LICENSE in the root of the software repository for the full text of the License.
|
||||
*/
|
||||
|
||||
/*!
|
||||
* \file causal_conv1d_infershape.cpp
|
||||
* \brief
|
||||
*/
|
||||
#include "register/op_impl_registry.h"
|
||||
#include "error_log.h"
|
||||
|
||||
using namespace ge;
|
||||
|
||||
namespace ops {
|
||||
static constexpr int64_t IDX_0 = 0;
|
||||
|
||||
static ge::graphStatus InferShapeCausalConv1d(gert::InferShapeContext* context)
|
||||
{
|
||||
// OPS_LOG_D(context->GetNodeName(), "Begin to do InferShapeCausalConv1d");
|
||||
|
||||
// get input shapes
|
||||
const gert::Shape* xShape = context->GetInputShape(IDX_0);
|
||||
OP_CHECK_NULL_WITH_CONTEXT(context, xShape);
|
||||
|
||||
// get output shapes
|
||||
gert::Shape* yShape = context->GetOutputShape(IDX_0);
|
||||
OP_CHECK_NULL_WITH_CONTEXT(context, yShape);
|
||||
|
||||
// 填充输出shape大小
|
||||
auto xShapeSize = xShape->GetDimNum();
|
||||
yShape->SetDimNum(xShapeSize);
|
||||
for (size_t i = 0; i < xShapeSize; i++) {
|
||||
int64_t dim = xShape->GetDim(i);
|
||||
yShape->SetDim(i, dim);
|
||||
}
|
||||
|
||||
// OPS_LOG_D(context->GetNodeName(), "End to do InferShapeCausalConv1d");
|
||||
return GRAPH_SUCCESS;
|
||||
}
|
||||
|
||||
IMPL_OP_INFERSHAPE(CausalConv1d).InferShape(InferShapeCausalConv1d);
|
||||
} // namespace ops
|
||||
365
csrc/causal_conv1d/op_host/causal_conv1d_tiling.cpp
Normal file
365
csrc/causal_conv1d/op_host/causal_conv1d_tiling.cpp
Normal file
@@ -0,0 +1,365 @@
|
||||
/**
|
||||
* This program is free software, you can redistribute it and/or modify it.
|
||||
* Copyright (c) 2025 Huawei Technologies Co., Ltd.
|
||||
* This file is a part of the CANN Open Software.
|
||||
* Licensed under CANN Open Software License Agreement Version 2.0 (the "License").
|
||||
* Please refer to the License for details. You may not use this file except in compliance with the License.
|
||||
* THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, INCLUDING
|
||||
* BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE.
|
||||
* See LICENSE in the root of the software repository for the full text of the License.
|
||||
*/
|
||||
|
||||
/*!
|
||||
* \file causal_conv1d_tiling.cpp
|
||||
* \brief
|
||||
*/
|
||||
|
||||
// #include "error_log.h"
|
||||
#include "log/ops_log.h"
|
||||
#include "../tiling_base/tiling_templates_registry.h"
|
||||
#include "../tiling_base/tiling_util.h"
|
||||
#include "math_util.h"
|
||||
#include "causal_conv1d_tiling.h"
|
||||
#include "../op_kernel/causal_conv1d_tiling_key.h"
|
||||
|
||||
#include <set>
|
||||
#include <limits>
|
||||
|
||||
namespace optiling {
|
||||
|
||||
using namespace Ops::Transformer::OpTiling;
|
||||
|
||||
constexpr uint32_t X_INDEX = 0;
|
||||
constexpr uint32_t WEIGHT_INDEX = 1;
|
||||
constexpr uint32_t BIAS_INDEX = 2;
|
||||
constexpr uint32_t CONV_STATES_INDEX = 3;
|
||||
constexpr uint32_t QUERY_START_LOC_INDEX = 4;
|
||||
constexpr uint32_t CACHE_INDICES_INDEX = 5;
|
||||
constexpr uint32_t HAS_INITIAL_STATE_INDEX = 6;
|
||||
|
||||
constexpr int32_t ATTR_ACTIVATION_MODE_INDEX = 0;
|
||||
constexpr int32_t ATTR_PAD_SLOT_ID_INDEX = 1;
|
||||
|
||||
|
||||
|
||||
struct DimTileChoice {
|
||||
int64_t dimTileSize = 0;
|
||||
int64_t blocksPerSeq = 0;
|
||||
int64_t gridSize = 0;
|
||||
};
|
||||
|
||||
static inline DimTileChoice ChooseDimTileSize(gert::TilingContext* context, int64_t batch, int64_t dim, uint32_t coreNum)
|
||||
{
|
||||
|
||||
const int64_t candidates[] = {4096, 2048, 1024, 512,384};
|
||||
DimTileChoice bestOver;
|
||||
int64_t bestOverGap = std::numeric_limits<int64_t>::max();
|
||||
DimTileChoice bestUnder;
|
||||
|
||||
for (int64_t dimTileSize : candidates) {
|
||||
if (dim % dimTileSize != 0) {
|
||||
continue;
|
||||
}
|
||||
const int64_t blocksPerSeq = dim / dimTileSize;
|
||||
const int64_t gridSize = batch * blocksPerSeq;
|
||||
if (gridSize <= 0) {
|
||||
continue;
|
||||
}
|
||||
|
||||
if (gridSize >= static_cast<int64_t>(coreNum)) {
|
||||
const int64_t gap = gridSize - static_cast<int64_t>(coreNum);
|
||||
if (gap < bestOverGap) {
|
||||
bestOver.dimTileSize = dimTileSize;
|
||||
bestOver.blocksPerSeq = blocksPerSeq;
|
||||
bestOver.gridSize = gridSize;
|
||||
bestOverGap = gap;
|
||||
}
|
||||
} else if (gridSize > bestUnder.gridSize ||
|
||||
(gridSize == bestUnder.gridSize && dimTileSize < bestUnder.dimTileSize)) {
|
||||
bestUnder.dimTileSize = dimTileSize;
|
||||
bestUnder.blocksPerSeq = blocksPerSeq;
|
||||
bestUnder.gridSize = gridSize;
|
||||
}
|
||||
}
|
||||
DimTileChoice result = (bestOver.dimTileSize != 0) ? bestOver : bestUnder;
|
||||
|
||||
return result;
|
||||
}
|
||||
|
||||
static ge::graphStatus GetPlatformInfo(gert::TilingContext* context, uint64_t& ubSize, uint32_t& coreNum)
|
||||
{
|
||||
auto compileInfoPtr = context->GetCompileInfo<CausalConv1dCompileInfo>();
|
||||
if (compileInfoPtr != nullptr && compileInfoPtr->coreNum != 0 && compileInfoPtr->ubSize != 0) {
|
||||
ubSize = compileInfoPtr->ubSize;
|
||||
coreNum = compileInfoPtr->coreNum;
|
||||
return ge::GRAPH_SUCCESS;
|
||||
}
|
||||
fe::PlatFormInfos* platformInfoPtr = context->GetPlatformInfo();
|
||||
OP_CHECK_NULL_WITH_CONTEXT(context, platformInfoPtr);
|
||||
auto ascendcPlatform = platform_ascendc::PlatformAscendC(platformInfoPtr);
|
||||
coreNum = ascendcPlatform.GetCoreNumAiv();
|
||||
if(coreNum == 0) {
|
||||
return ge::GRAPH_FAILED;
|
||||
}
|
||||
ascendcPlatform.GetCoreMemSize(platform_ascendc::CoreMemType::UB, ubSize);
|
||||
if(ubSize == 0) {
|
||||
return ge::GRAPH_FAILED;
|
||||
}
|
||||
return ge::GRAPH_SUCCESS;
|
||||
}
|
||||
|
||||
static ge::graphStatus GetWorkspaceSize(gert::TilingContext* context)
|
||||
{
|
||||
size_t* currentWorkspace = context->GetWorkspaceSizes(1);
|
||||
OP_CHECK_NULL_WITH_CONTEXT(context, currentWorkspace);
|
||||
currentWorkspace[0] = 0;
|
||||
return ge::GRAPH_SUCCESS;
|
||||
}
|
||||
|
||||
static ge::graphStatus GetAttrsInfo(gert::TilingContext* context, int64_t& activationMode, int64_t& padSlotId)
|
||||
{
|
||||
auto attrs = context->GetAttrs();
|
||||
OP_CHECK_NULL_WITH_CONTEXT(context, attrs);
|
||||
|
||||
const int64_t* activationModePtr = attrs->GetAttrPointer<int64_t>(ATTR_ACTIVATION_MODE_INDEX);
|
||||
OP_CHECK_NULL_WITH_CONTEXT(context, activationModePtr);
|
||||
activationMode = *activationModePtr;
|
||||
if(activationMode != 0 && activationMode != 1){
|
||||
return ge::GRAPH_FAILED;
|
||||
}
|
||||
const int64_t* padSlotIdPtr = attrs->GetAttrPointer<int64_t>(ATTR_PAD_SLOT_ID_INDEX);
|
||||
OP_CHECK_NULL_WITH_CONTEXT(context, padSlotIdPtr);
|
||||
padSlotId = *padSlotIdPtr;
|
||||
|
||||
return ge::GRAPH_SUCCESS;
|
||||
}
|
||||
static ge::graphStatus GetShapeDtypeInfo(gert::TilingContext* context, CausalConv1dTilingData& tiling)
|
||||
{
|
||||
auto xShapePtr = context->GetInputShape(X_INDEX);
|
||||
OP_CHECK_NULL_WITH_CONTEXT(context, xShapePtr);
|
||||
auto xShape = EnsureNotScalar(xShapePtr->GetStorageShape());
|
||||
|
||||
int64_t dim = 0;
|
||||
int64_t cuSeqlen = 0;
|
||||
int64_t seqLen = 0;
|
||||
int64_t batch = 0;
|
||||
int64_t inputMode = 0;
|
||||
|
||||
if (xShape.GetDimNum() == 2) {
|
||||
inputMode = 0;
|
||||
cuSeqlen = xShape.GetDim(0);
|
||||
dim = xShape.GetDim(1);
|
||||
seqLen = 0;
|
||||
if(dim <= 0 || cuSeqlen < 0){
|
||||
return ge::GRAPH_FAILED;
|
||||
}
|
||||
|
||||
} else if (xShape.GetDimNum() == 3) {
|
||||
inputMode = 1;
|
||||
batch = xShape.GetDim(0);
|
||||
seqLen = xShape.GetDim(1);
|
||||
dim = xShape.GetDim(2);
|
||||
cuSeqlen = batch * seqLen;
|
||||
if(batch <= 0 || dim <= 0 || seqLen <= 0){
|
||||
return ge::GRAPH_FAILED;
|
||||
}
|
||||
} else {
|
||||
return ge::GRAPH_FAILED;
|
||||
}
|
||||
|
||||
auto wShapePtr = context->GetInputShape(WEIGHT_INDEX);
|
||||
OP_CHECK_NULL_WITH_CONTEXT(context, wShapePtr);
|
||||
auto wShape = EnsureNotScalar(wShapePtr->GetStorageShape());
|
||||
if(wShape.GetDimNum() != 2){
|
||||
return ge::GRAPH_FAILED;
|
||||
}
|
||||
const int64_t width = wShape.GetDim(0);
|
||||
const int64_t wDim = wShape.GetDim(1);
|
||||
if(wDim != dim){
|
||||
return ge::GRAPH_FAILED;
|
||||
}
|
||||
if(width != 4){
|
||||
return ge::GRAPH_FAILED;
|
||||
}
|
||||
|
||||
auto sShapePtr = context->GetInputShape(CONV_STATES_INDEX);
|
||||
OP_CHECK_NULL_WITH_CONTEXT(context, sShapePtr);
|
||||
auto sShape = EnsureNotScalar(sShapePtr->GetStorageShape());
|
||||
if(sShape.GetDimNum() != 3){
|
||||
return ge::GRAPH_FAILED;
|
||||
}
|
||||
const int64_t numCacheLines = sShape.GetDim(0);
|
||||
const int64_t stateLen = sShape.GetDim(1);
|
||||
const int64_t sDim = sShape.GetDim(2);
|
||||
if(numCacheLines <= 0){
|
||||
return ge::GRAPH_FAILED;}
|
||||
if(sDim != dim){
|
||||
return ge::GRAPH_FAILED;}
|
||||
if(stateLen < (width - 1)){
|
||||
return ge::GRAPH_FAILED;}
|
||||
|
||||
auto qslShapePtr = context->GetInputShape(QUERY_START_LOC_INDEX);
|
||||
OP_CHECK_NULL_WITH_CONTEXT(context, qslShapePtr);
|
||||
auto qslShape = EnsureNotScalar(qslShapePtr->GetStorageShape());
|
||||
if(qslShape.GetDimNum() != 1){
|
||||
return ge::GRAPH_FAILED;}
|
||||
const int64_t qslSize = qslShape.GetDim(0);
|
||||
if(qslSize < 1){
|
||||
return ge::GRAPH_FAILED;}
|
||||
|
||||
if (inputMode == 0) {
|
||||
batch = qslSize - 1;
|
||||
}
|
||||
|
||||
if (inputMode == 1) {
|
||||
if(qslSize != batch + 1){
|
||||
return ge::GRAPH_FAILED;
|
||||
}
|
||||
}
|
||||
|
||||
auto ciShapePtr = context->GetInputShape(CACHE_INDICES_INDEX);
|
||||
OP_CHECK_NULL_WITH_CONTEXT(context, ciShapePtr);
|
||||
auto ciShape = EnsureNotScalar(ciShapePtr->GetStorageShape());
|
||||
if(ciShape.GetDimNum() != 1){return ge::GRAPH_FAILED;}
|
||||
if(ciShape.GetDim(0) != batch){return ge::GRAPH_FAILED;}
|
||||
|
||||
auto hisShapePtr = context->GetInputShape(HAS_INITIAL_STATE_INDEX);
|
||||
OP_CHECK_NULL_WITH_CONTEXT(context, hisShapePtr);
|
||||
auto hisShape = EnsureNotScalar(hisShapePtr->GetStorageShape());
|
||||
if(hisShape.GetDimNum() != 1){
|
||||
return ge::GRAPH_FAILED;}
|
||||
if(hisShape.GetDim(0) != batch){
|
||||
return ge::GRAPH_FAILED;}
|
||||
|
||||
tiling.set_hasBias(0);
|
||||
auto biasShapePtr = context->GetOptionalInputShape(BIAS_INDEX);
|
||||
if (biasShapePtr != nullptr && biasShapePtr->GetStorageShape().GetDimNum() != 0) {
|
||||
auto biasShape = EnsureNotScalar(biasShapePtr->GetStorageShape());
|
||||
if(biasShape.GetDimNum() != 1){
|
||||
return ge::GRAPH_FAILED;}
|
||||
if(biasShape.GetDim(0) != dim){
|
||||
return ge::GRAPH_FAILED;}
|
||||
tiling.set_hasBias(1);
|
||||
}
|
||||
|
||||
const std::set<ge::DataType> supportedXDtype = {ge::DT_BF16, ge::DT_FLOAT16};
|
||||
auto xDesc = context->GetInputDesc(X_INDEX);
|
||||
OP_CHECK_NULL_WITH_CONTEXT(context, xDesc);
|
||||
const ge::DataType xDtype = xDesc->GetDataType();
|
||||
if(supportedXDtype.count(xDtype) == 0){
|
||||
return ge::GRAPH_FAILED;}
|
||||
|
||||
auto wDesc = context->GetInputDesc(WEIGHT_INDEX);
|
||||
OP_CHECK_NULL_WITH_CONTEXT(context, wDesc);
|
||||
if(wDesc->GetDataType() != xDtype){
|
||||
return ge::GRAPH_FAILED;}
|
||||
|
||||
if (tiling.get_hasBias() == 1) {
|
||||
auto biasDesc = context->GetOptionalInputDesc(BIAS_INDEX);
|
||||
OP_CHECK_NULL_WITH_CONTEXT(context, biasDesc);
|
||||
if(biasDesc->GetDataType() != xDtype){
|
||||
return ge::GRAPH_FAILED;}
|
||||
}
|
||||
|
||||
auto sDesc = context->GetInputDesc(CONV_STATES_INDEX);
|
||||
OP_CHECK_NULL_WITH_CONTEXT(context, sDesc);
|
||||
if(sDesc->GetDataType() != xDtype){
|
||||
return ge::GRAPH_FAILED;}
|
||||
|
||||
auto qslDesc = context->GetInputDesc(QUERY_START_LOC_INDEX);
|
||||
OP_CHECK_NULL_WITH_CONTEXT(context, qslDesc);
|
||||
if(qslDesc->GetDataType() != ge::DT_INT32){
|
||||
return ge::GRAPH_FAILED;}
|
||||
|
||||
auto ciDesc = context->GetInputDesc(CACHE_INDICES_INDEX);
|
||||
OP_CHECK_NULL_WITH_CONTEXT(context, ciDesc);
|
||||
if(ciDesc->GetDataType() != ge::DT_INT32){
|
||||
return ge::GRAPH_FAILED;}
|
||||
|
||||
auto hisDesc = context->GetInputDesc(HAS_INITIAL_STATE_INDEX);
|
||||
OP_CHECK_NULL_WITH_CONTEXT(context, hisDesc);
|
||||
if(hisDesc->GetDataType() != ge::DT_BOOL){
|
||||
return ge::GRAPH_FAILED;}
|
||||
|
||||
tiling.set_dim(dim);
|
||||
tiling.set_cuSeqlen(cuSeqlen);
|
||||
tiling.set_seqLen(seqLen);
|
||||
tiling.set_inputMode(inputMode);
|
||||
tiling.set_width(width);
|
||||
tiling.set_stateLen(stateLen);
|
||||
tiling.set_numCacheLines(numCacheLines);
|
||||
tiling.set_batch(batch);
|
||||
return ge::GRAPH_SUCCESS;
|
||||
}
|
||||
|
||||
static ge::graphStatus CausalConv1dTilingFunc(gert::TilingContext* context)
|
||||
{
|
||||
uint64_t ubSize;
|
||||
uint32_t coreNum;
|
||||
if( GetPlatformInfo(context, ubSize, coreNum) != ge::GRAPH_SUCCESS){
|
||||
return ge::GRAPH_FAILED;
|
||||
}
|
||||
|
||||
if(GetWorkspaceSize(context) != ge::GRAPH_SUCCESS){
|
||||
return ge::GRAPH_FAILED;
|
||||
}
|
||||
CausalConv1dTilingData tilingData;
|
||||
|
||||
int64_t activationMode = 0;
|
||||
int64_t padSlotId = -1;
|
||||
if(GetAttrsInfo(context, activationMode, padSlotId) != ge::GRAPH_SUCCESS){
|
||||
return ge::GRAPH_FAILED;
|
||||
}
|
||||
tilingData.set_activationMode(activationMode);
|
||||
tilingData.set_padSlotId(padSlotId);
|
||||
|
||||
if( GetShapeDtypeInfo(context, tilingData) != ge::GRAPH_SUCCESS){
|
||||
return ge::GRAPH_FAILED;
|
||||
}
|
||||
|
||||
const int64_t dim = tilingData.get_dim();
|
||||
const int64_t batch = tilingData.get_batch();
|
||||
if(dim <= 0 || batch <= 0){
|
||||
return ge::GRAPH_FAILED;
|
||||
}
|
||||
const DimTileChoice choice = ChooseDimTileSize(context, batch, dim, coreNum);
|
||||
const uint32_t blockDim = (choice.gridSize < static_cast<int64_t>(coreNum))
|
||||
? static_cast<uint32_t>(choice.gridSize)
|
||||
: coreNum;
|
||||
context->SetBlockDim(blockDim);
|
||||
tilingData.set_dimTileSize(choice.dimTileSize);
|
||||
tilingData.set_blocksPerSeq(choice.blocksPerSeq);
|
||||
|
||||
const uint64_t tilingKey = GET_TPL_TILING_KEY(CAUSAL_CONV1D_TPL_SCH_MODE_DEFAULT);
|
||||
context->SetTilingKey(tilingKey);
|
||||
|
||||
tilingData.SaveToBuffer(context->GetRawTilingData()->GetData(), context->GetRawTilingData()->GetCapacity());
|
||||
context->GetRawTilingData()->SetDataSize(tilingData.GetDataSize());
|
||||
return ge::GRAPH_SUCCESS;
|
||||
}
|
||||
|
||||
|
||||
|
||||
static ge::graphStatus TilingParseForCausalConv1d(gert::TilingParseContext* context)
|
||||
{
|
||||
auto platformInfoPtr = context->GetPlatformInfo();
|
||||
OP_CHECK_NULL_WITH_CONTEXT(context, platformInfoPtr);
|
||||
auto compileInfoPtr = context->GetCompiledInfo<CausalConv1dCompileInfo>();
|
||||
OP_CHECK_NULL_WITH_CONTEXT(context, compileInfoPtr);
|
||||
|
||||
auto ascendcPlatform = platform_ascendc::PlatformAscendC(platformInfoPtr);
|
||||
compileInfoPtr->coreNum = static_cast<uint32_t>(ascendcPlatform.GetCoreNumAiv());
|
||||
if(compileInfoPtr->coreNum == 0){
|
||||
return ge::GRAPH_FAILED;
|
||||
}
|
||||
ascendcPlatform.GetCoreMemSize(platform_ascendc::CoreMemType::UB, compileInfoPtr->ubSize);
|
||||
if(compileInfoPtr->ubSize == 0){
|
||||
return ge::GRAPH_FAILED;
|
||||
}
|
||||
return ge::GRAPH_SUCCESS;
|
||||
}
|
||||
|
||||
IMPL_OP_OPTILING(CausalConv1d)
|
||||
.Tiling(CausalConv1dTilingFunc)
|
||||
.TilingParse<CausalConv1dCompileInfo>(TilingParseForCausalConv1d);
|
||||
} // namespace optiling
|
||||
60
csrc/causal_conv1d/op_host/causal_conv1d_tiling.h
Normal file
60
csrc/causal_conv1d/op_host/causal_conv1d_tiling.h
Normal file
@@ -0,0 +1,60 @@
|
||||
/**
|
||||
* This program is free software, you can redistribute it and/or modify it.
|
||||
* Copyright (c) 2025 Huawei Technologies Co., Ltd.
|
||||
* This file is a part of the CANN Open Software.
|
||||
* Licensed under CANN Open Software License Agreement Version 2.0 (the "License").
|
||||
* Please refer to the License for details. You may not use this file except in compliance with the License.
|
||||
* THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, INCLUDING
|
||||
* BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE.
|
||||
* See LICENSE in the root of the software repository for the full text of the License.
|
||||
*/
|
||||
|
||||
/*!
|
||||
* \file causal_conv1d_tiling_data.h
|
||||
* \brief
|
||||
*/
|
||||
|
||||
#ifndef ASCEND_OPS_CAUSAL_CONV1D_TILING_DATA_H
|
||||
#define ASCEND_OPS_CAUSAL_CONV1D_TILING_DATA_H
|
||||
|
||||
#include <cstdint>
|
||||
|
||||
// #include "register/tilingdata_base.h"
|
||||
// #include "tiling/tiling_api.h"
|
||||
#include "register/tilingdata_base.h"
|
||||
#include "error_log.h"
|
||||
#include "register/op_impl_registry.h"
|
||||
#include "tiling/platform/platform_ascendc.h"
|
||||
#include "platform/platform_infos_def.h"
|
||||
namespace optiling {
|
||||
|
||||
BEGIN_TILING_DATA_DEF(CausalConv1dTilingData)
|
||||
TILING_DATA_FIELD_DEF(int64_t, dim);
|
||||
TILING_DATA_FIELD_DEF(int64_t, cuSeqlen);
|
||||
TILING_DATA_FIELD_DEF(int64_t, seqLen);
|
||||
TILING_DATA_FIELD_DEF(int64_t, inputMode);
|
||||
|
||||
TILING_DATA_FIELD_DEF(int64_t, width);
|
||||
|
||||
TILING_DATA_FIELD_DEF(int64_t, stateLen);
|
||||
TILING_DATA_FIELD_DEF(int64_t, numCacheLines);
|
||||
|
||||
TILING_DATA_FIELD_DEF(int64_t, batch);
|
||||
|
||||
TILING_DATA_FIELD_DEF(int64_t, activationMode);
|
||||
TILING_DATA_FIELD_DEF(int64_t, padSlotId);
|
||||
|
||||
TILING_DATA_FIELD_DEF(int64_t, hasBias);
|
||||
|
||||
TILING_DATA_FIELD_DEF(int64_t, dimTileSize);
|
||||
TILING_DATA_FIELD_DEF(int64_t, blocksPerSeq);
|
||||
END_TILING_DATA_DEF;
|
||||
struct CausalConv1dCompileInfo {
|
||||
uint64_t ubSize = 0;
|
||||
uint32_t coreNum = 0;
|
||||
};
|
||||
REGISTER_TILING_DATA_CLASS(CausalConv1d, CausalConv1dTilingData)
|
||||
|
||||
} // namespace optiling
|
||||
|
||||
#endif // ASCEND_OPS_CAUSAL_CONV1D_TILING_DATA_H
|
||||
71
csrc/causal_conv1d/op_host/error_log.h
Normal file
71
csrc/causal_conv1d/op_host/error_log.h
Normal file
@@ -0,0 +1,71 @@
|
||||
#ifndef OPS_BUILT_IN_OP_TILING_ERROR_LOG_H_
|
||||
#define OPS_BUILT_IN_OP_TILING_ERROR_LOG_H_
|
||||
|
||||
#include <string>
|
||||
#include "toolchain/slog.h"
|
||||
|
||||
#define OP_LOGI(opname, ...)
|
||||
#define OP_LOGW(opname, ...) \
|
||||
do { \
|
||||
printf("[WARN][%s] ", (opname), ##__VA_ARGS__); \
|
||||
printf("\n"); \
|
||||
} while (0)
|
||||
|
||||
#define OP_LOGE_WITHOUT_REPORT(opname, ...) \
|
||||
do { \
|
||||
printf("[ERRORx][%s] ", (opname), ##__VA_ARGS__); \
|
||||
printf("\n"); \
|
||||
} while (0)
|
||||
|
||||
#define OP_LOGE(opname, ...) \
|
||||
do { \
|
||||
printf("[ERROR][%s] ", (opname), ##__VA_ARGS__); \
|
||||
printf("\n"); \
|
||||
} while (0)
|
||||
|
||||
#define OP_LOGD(opname, ...)
|
||||
|
||||
namespace optiling {
|
||||
|
||||
#define VECTOR_INNER_ERR_REPORT_TILIING(op_name, err_msg, ...) \
|
||||
do { \
|
||||
OP_LOGE_WITHOUT_REPORT(op_name, err_msg, ##__VA_ARGS__); \
|
||||
} while (0)
|
||||
|
||||
|
||||
#define OP_CHECK_IF(cond, log_func, expr) \
|
||||
do { \
|
||||
if (cond) { \
|
||||
log_func; \
|
||||
expr; \
|
||||
} \
|
||||
} while (0)
|
||||
|
||||
|
||||
|
||||
#define OP_CHECK_NULL_WITH_CONTEXT(context, ptr) \
|
||||
do { \
|
||||
if ((ptr) == nullptr) { \
|
||||
OP_LOGE(context->GetNodeType(), "%s is null", #ptr); \
|
||||
return ge::GRAPH_FAILED; \
|
||||
} \
|
||||
} while (0)
|
||||
|
||||
} // namespace optiling
|
||||
|
||||
template <typename T>
|
||||
T CeilAlign(T a, T b)
|
||||
{
|
||||
return (a + b - 1) / b * b;
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
T CeilDiv(T a, T b)
|
||||
{
|
||||
if (b == 0) {
|
||||
return a;
|
||||
}
|
||||
return (a + b - 1) / b;
|
||||
}
|
||||
|
||||
#endif // OPS_BUILT_IN_OP_TILING_ERROR_LOG_H_
|
||||
61
csrc/causal_conv1d/op_host/math_util.h
Normal file
61
csrc/causal_conv1d/op_host/math_util.h
Normal file
@@ -0,0 +1,61 @@
|
||||
/**
|
||||
* Copyright (c) 2025 Huawei Technologies Co., Ltd.
|
||||
* This program is free software, you can redistribute it and/or modify it under the terms and conditions of
|
||||
* CANN Open Software License Agreement Version 2.0 (the "License").
|
||||
* Please refer to the License for details. You may not use this file except in compliance with the License.
|
||||
* THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED,
|
||||
* INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE.
|
||||
* See LICENSE in the root of the software repository for the full text of the License.
|
||||
*/
|
||||
|
||||
/*!
|
||||
* \file math_util.h
|
||||
* \brief
|
||||
*/
|
||||
|
||||
#ifndef TILING_MATMUL_MATH_UTIL_H
|
||||
#define TILING_MATMUL_MATH_UTIL_H
|
||||
|
||||
#include <array>
|
||||
#include <cstdint>
|
||||
#include <vector>
|
||||
#include <utility>
|
||||
namespace matmul_tiling {
|
||||
class MathUtil {
|
||||
public:
|
||||
static bool IsEqual(float leftValue, float rightValue);
|
||||
template<typename T>
|
||||
static auto CeilDivision(T num1, T num2) -> T
|
||||
{
|
||||
if (num2 == 0) {
|
||||
return 0;
|
||||
}
|
||||
return static_cast<T>((static_cast<int64_t>(num1) + static_cast<int64_t>(num2) - 1) /
|
||||
static_cast<int64_t>(num2));
|
||||
}
|
||||
template<typename T>
|
||||
static auto Align(T num1, T num2) -> T
|
||||
{
|
||||
return CeilDivision(num1, num2) * num2;
|
||||
}
|
||||
static int32_t AlignDown(int32_t num1, int32_t num2);
|
||||
static bool CheckMulOverflow(int32_t a, int32_t b, int32_t &c);
|
||||
static int32_t MapShape(int32_t shape, bool roundUpFlag = true);
|
||||
static void AddFactor(std::vector<int32_t> &dimsFactors, int32_t dim);
|
||||
static void GetFactorCnt(const int32_t shape, int32_t &factorCnt, const int32_t factorStart,
|
||||
const int32_t factorEnd);
|
||||
static void GetFactorLayerCnt(const int32_t shape, int32_t &factorCnt, const int32_t factorStart,
|
||||
const int32_t factorEnd);
|
||||
static bool CheckFactorNumSatisfy(const int32_t dim);
|
||||
static int32_t FindBestSingleCore(const int32_t oriShape, const int32_t mappedShape, const int32_t coreNum,
|
||||
bool isKDim);
|
||||
static void GetFactors(std::vector<int32_t> &factorList, int32_t srcNum, int32_t minFactor, int32_t maxFactor);
|
||||
static void GetFactors(std::vector<int32_t> &factorList, int32_t srcNum, int32_t maxFactor);
|
||||
static void GetBlockFactors(std::vector<int32_t> &factorList, const int32_t oriShape, const int32_t mpShape,
|
||||
const int32_t coreNum, const int32_t maxNum);
|
||||
static int32_t GetNonFactorMap(std::vector<int32_t> &factorList, int32_t srcNum, int32_t maxFactor);
|
||||
static std::vector<std::pair<int, int>> GetFactorPairs(int32_t num);
|
||||
static std::pair<int32_t, int32_t> DivideIntoMainAndTail(int32_t num, int32_t divisor);
|
||||
};
|
||||
} // namespace matmul_tiling
|
||||
#endif // _MATH_UTIL_H_
|
||||
31
csrc/causal_conv1d/op_host/tiling_util.cpp
Normal file
31
csrc/causal_conv1d/op_host/tiling_util.cpp
Normal file
@@ -0,0 +1,31 @@
|
||||
/**
|
||||
* Copyright (c) 2025 Huawei Technologies Co., Ltd.
|
||||
* This program is free software, you can redistribute it and/or modify it under the terms and conditions of
|
||||
* CANN Open Software License Agreement Version 2.0 (the "License").
|
||||
* Please refer to the License for details. You may not use this file except in compliance with the License.
|
||||
* THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED,
|
||||
* INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE.
|
||||
* See LICENSE in the root of the software repository for the full text of the License.
|
||||
*/
|
||||
|
||||
/*!
|
||||
* \file tiling_util.cpp
|
||||
* \brief
|
||||
*/
|
||||
|
||||
#include "../tiling_base/tiling_util.h"
|
||||
namespace Ops {
|
||||
namespace Transformer {
|
||||
namespace OpTiling {
|
||||
static const gert::Shape g_vec_1_shape = {1};
|
||||
|
||||
const gert::Shape &EnsureNotScalar(const gert::Shape &inShape)
|
||||
{
|
||||
if (inShape.IsScalar()) {
|
||||
return g_vec_1_shape;
|
||||
}
|
||||
return inShape;
|
||||
}
|
||||
} // namespace OpTiling
|
||||
} // namespace Transformer
|
||||
} // namespace Ops
|
||||
Reference in New Issue
Block a user