[Model] GLM5 adaptation (#6642)

### What this PR does / why we need it?
GLM5 adaptation
1. use torch_npu.npu_lightning_indexer for GLM5
2. forbid eagle proposer when fullgraph mode is enabled because of bugs
3. add quatization config for GLM5
### Does this PR introduce _any_ user-facing change?
N/A
### How was this patch tested?
by ci
- vLLM main:
978a37c823

---------

Signed-off-by: yydyzr <liuyuncong1@huawei.com>
Signed-off-by: shenchuxiaofugui <1311027364@qq.com>
Co-authored-by: shenchuxiaofugui <1311027364@qq.com>
This commit is contained in:
yydyzr
2026-02-11 22:22:22 +08:00
committed by GitHub
parent 140fcaffc3
commit ff3a50d011
17 changed files with 77 additions and 34 deletions

View File

@@ -0,0 +1,42 @@
# This program is free software, you can redistribute it and/or modify it.
# Copyright (c) 2025 Huawei Technologies Co., Ltd.
# This file is a part of the CANN Open Software.
# Licensed under CANN Open Software License Agreement Version 2.0 (the "License").
# Please refer to the License for details. You may not use this file except in compliance with the License.
# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE.
# See LICENSE in the root of the software repository for the full text of the License.
# ======================================================================================================================
add_ops_compile_options(
OP_NAME LightningIndexerVllm
OPTIONS --cce-auto-sync=off
-Wno-deprecated-declarations
-Werror
-mllvm -cce-aicore-hoist-movemask=false
--op_relocatable_kernel_binary=true
)
set(lightning_indexer_vllm_depends transformer/attention/lightning_indexer_vllm PARENT_SCOPE)
target_sources(op_host_aclnn PRIVATE
lightning_indexer_vllm_def.cpp
)
target_sources(optiling PRIVATE
lightning_indexer_vllm_tiling.cpp
)
if (NOT BUILD_OPEN_PROJECT)
target_sources(opmaster_ct PRIVATE
lightning_indexer_vllm_tiling.cpp
)
endif ()
target_include_directories(optiling PRIVATE
${CMAKE_CURRENT_SOURCE_DIR}
)
target_sources(opsproto PRIVATE
lightning_indexer_vllm_proto.cpp
)

View File

@@ -0,0 +1,72 @@
/**
* This program is free software, you can redistribute it and/or modify it.
* Copyright (c) 2025 Huawei Technologies Co., Ltd.
* This file is a part of the CANN Open Software.
* Licensed under CANN Open Software License Agreement Version 2.0 (the "License").
* Please refer to the License for details. You may not use this file except in compliance with the License.
* THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE.
* See LICENSE in the root of the software repository for the full text of the License.
*/
/*!
* \file lightning_indexer_def.cpp
* \brief
*/
#include <cstdint>
#include "register/op_def_registry.h"
namespace ops {
class LightningIndexerVllm : public OpDef {
public:
explicit LightningIndexerVllm(const char *name) : OpDef(name)
{
this->Input("query")
.ParamType(REQUIRED)
.DataType({ge::DT_BF16, ge::DT_FLOAT16})
.FormatList({ge::FORMAT_ND})
.AutoContiguous();
this->Input("key")
.ParamType(REQUIRED)
.DataType({ge::DT_BF16, ge::DT_FLOAT16})
.FormatList({ge::FORMAT_ND})
.AutoContiguous();
this->Input("weights")
.ParamType(REQUIRED)
.DataType({ge::DT_BF16, ge::DT_FLOAT16})
.FormatList({ge::FORMAT_ND})
.AutoContiguous();
this->Input("actual_seq_lengths_query")
.ParamType(OPTIONAL)
.DataType({ge::DT_INT32, ge::DT_INT32})
.FormatList({ge::FORMAT_ND})
.AutoContiguous();
this->Input("actual_seq_lengths_key")
.ParamType(OPTIONAL)
.DataType({ge::DT_INT32, ge::DT_INT32})
.FormatList({ge::FORMAT_ND})
.AutoContiguous();
this->Input("block_table")
.ParamType(OPTIONAL)
.DataTypeList({ge::DT_INT32})
.FormatList({ge::FORMAT_ND})
.AutoContiguous();
this->Output("sparse_indices").ParamType(REQUIRED).DataTypeList({ge::DT_INT32}).FormatList({ge::FORMAT_ND});
this->Attr("layout_query").AttrType(OPTIONAL).String("BSND");
this->Attr("layout_key").AttrType(OPTIONAL).String("PA_BSND");
this->Attr("sparse_count").AttrType(OPTIONAL).Int(2048); // 2048: Default value, filter the top 2048
this->Attr("sparse_mode").AttrType(OPTIONAL).Int(3); // 3: Default value, only calculate the lower triangular matrix
OpAICoreConfig aicore_config;
aicore_config.DynamicCompileStaticFlag(true)
.DynamicFormatFlag(true)
.DynamicRankSupportFlag(true)
.DynamicShapeSupportFlag(true)
.NeedCheckSupportFlag(false)
.PrecisionReduceFlag(true)
.ExtendCfgInfo("aclnnSupport.value", "support_aclnn")
.ExtendCfgInfo("jitCompile.flag", "static_false,dynamic_false");
this->AICore().AddConfig("ascend910b", aicore_config);
this->AICore().AddConfig("ascend910_93", aicore_config);
}
};
OP_ADD(LightningIndexerVllm);
} // namespace ops

View File

@@ -0,0 +1,96 @@
/**
* This program is free software, you can redistribute it and/or modify it.
* Copyright (c) 2025 Huawei Technologies Co., Ltd.
* This file is a part of the CANN Open Software.
* Licensed under CANN Open Software License Agreement Version 2.0 (the "License").
* Please refer to the License for details. You may not use this file except in compliance with the License.
* THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE.
* See LICENSE in the root of the software repository for the full text of the License.
*/
/*!
* \file lightning_indexer_proto.cpp
* \brief
*/
#include <graph/utils/type_utils.h>
#include <register/op_impl_registry.h>
#include "error/ops_error.h"
using namespace ge;
namespace ops {
constexpr uint32_t QUERY_INDEX = 0;
constexpr uint32_t KEY_INDEX = 1;
constexpr uint32_t ACTUAL_SEQ_K_INDEX = 4;
constexpr uint32_t ATTR_QUERY_LAYOUT_INDEX = 0;
constexpr uint32_t ATTR_KEY_LAYOUT_INDEX = 1;
constexpr uint32_t ATTR_SPARSE_COUNT_INDEX = 2;
static ge::graphStatus InferShapeLightningIndexer(gert::InferShapeContext *context)
{
OPS_ERR_IF(context == nullptr, OPS_LOG_E("LightningIndexer", "InferShapeContext is nullptr!"),
return ge::GRAPH_FAILED);
const gert::Shape *queryShape = context->GetInputShape(QUERY_INDEX);
OPS_LOG_E_IF_NULL(context, queryShape, return ge::GRAPH_FAILED);
const gert::Shape *keyShape = context->GetInputShape(KEY_INDEX);
OPS_LOG_E_IF_NULL(context, keyShape, return ge::GRAPH_FAILED);
gert::Shape *outShape = context->GetOutputShape(0);
auto attrs = context->GetAttrs();
OPS_LOG_E_IF_NULL(context, attrs, return ge::GRAPH_FAILED);
const char *inputLayoutQueryPtr = attrs->GetAttrPointer<char>(ATTR_QUERY_LAYOUT_INDEX);
OPS_LOG_E_IF_NULL(context, inputLayoutQueryPtr, return ge::GRAPH_FAILED);
const char *inputLayoutKeyPtr = attrs->GetAttrPointer<char>(ATTR_KEY_LAYOUT_INDEX);
OPS_LOG_E_IF_NULL(context, inputLayoutKeyPtr, return ge::GRAPH_FAILED);
const int64_t *seleced_count = attrs->GetInt(ATTR_SPARSE_COUNT_INDEX);
OPS_LOG_E_IF_NULL(context, seleced_count, return ge::GRAPH_FAILED);
std::string inputLayoutQueryPtrStr = std::string(inputLayoutQueryPtr);
std::string inputLayoutKeyPtrStr = std::string(inputLayoutKeyPtr);
OPS_ERR_IF(
inputLayoutQueryPtrStr != "TND" && inputLayoutQueryPtrStr != "BSND",
OPS_LOG_E(context, "The attr layout_query should be TND or BSND, but got %s.", inputLayoutQueryPtrStr.c_str()),
return ge::GRAPH_FAILED);
outShape->SetDimNum(queryShape->GetDimNum());
if (inputLayoutQueryPtrStr == "BSND") {
OPS_ERR_IF(
queryShape->GetDimNum() != 4,
OPS_LOG_E(context, "Layout BSND, queryDims (%zu) must be 4!", queryShape->GetDimNum()),
return ge::GRAPH_FAILED);
outShape->SetDim(0, queryShape->GetDim(0)); // 0:Dim B
outShape->SetDim(1, queryShape->GetDim(1)); // 1:Dim S
outShape->SetDim(2, keyShape->GetDim(2)); // 2:Dim N
outShape->SetDim(3, *seleced_count); // 3:Dim K
} else {
OPS_ERR_IF(
queryShape->GetDimNum() != 3,
OPS_LOG_E(context, "Layout TND, queryDims (%zu) must be 3!", queryShape->GetDimNum()),
return ge::GRAPH_FAILED);
outShape->SetDim(0, queryShape->GetDim(0)); // 0:Dim T
int32_t nDimIndex = (inputLayoutKeyPtrStr == "PA_BSND") ? 2 : 1; // 2:Key Dim N
outShape->SetDim(1, keyShape->GetDim(nDimIndex)); // 1:Dim N
outShape->SetDim(2, *seleced_count); // 2:Dim K
}
OPS_LOG_D(context->GetNodeName(), "LightningIndexer InferShape end.");
return ge::GRAPH_SUCCESS;
}
static ge::graphStatus InferDataTypeLightningIndexer(gert::InferDataTypeContext *context)
{
OPS_ERR_IF(context == nullptr, OPS_LOG_E("LightningIndexer", "InferDataTypeContext is nullptr!"),
return ge::GRAPH_FAILED);
OPS_LOG_D(context->GetNodeName(), "Enter LightningIndexer InferDataType impl.");
// default set q's dtype as fia's output type
ge::DataType outputType = ge::DT_INT32;
// attention_out, outidx:0
context->SetOutputDataType(0, outputType);
OPS_LOG_D(context->GetNodeName(), "LightningIndexer InferDataType end.");
return GRAPH_SUCCESS;
}
IMPL_OP_INFERSHAPE(LightningIndexerVllm)
.InferShape(InferShapeLightningIndexer)
.InferDataType(InferDataTypeLightningIndexer);
} // namespace ops

View File

@@ -0,0 +1,694 @@
/**
* This program is free software, you can redistribute it and/or modify it.
* Copyright (c) 2025 Huawei Technologies Co., Ltd.
* This file is a part of the CANN Open Software.
* Licensed under CANN Open Software License Agreement Version 2.0 (the "License").
* Please refer to the License for details. You may not use this file except in compliance with the License.
* THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE.
* See LICENSE in the root of the software repository for the full text of the License.
*/
/*!
* \file lightning_indexer_tiling.cpp
* \brief
*/
#include "lightning_indexer_vllm_tiling.h"
#include "../op_kernel/lightning_indexer_template_tiling_key.h"
using namespace ge;
using namespace AscendC;
using std::map;
using std::string;
namespace optiling {
ge::graphStatus LIInfoParser::CheckRequiredInOutExistence() const
{
OPS_ERR_IF(opParamInfo_.query.shape == nullptr, OPS_LOG_E(opName_, "Shape of tensor query is nullptr"),
return ge::GRAPH_FAILED);
OPS_ERR_IF(opParamInfo_.query.desc == nullptr, OPS_LOG_E(opName_, "Desc of tensor query is nullptr"),
return ge::GRAPH_FAILED);
OPS_ERR_IF(opParamInfo_.key.shape == nullptr, OPS_LOG_E(opName_, "Shape of tensor k is nullptr"),
return ge::GRAPH_FAILED);
OPS_ERR_IF(opParamInfo_.key.desc == nullptr, OPS_LOG_E(opName_, "Desc of tensor k is nullptr"),
return ge::GRAPH_FAILED);
OPS_ERR_IF(opParamInfo_.weights.shape == nullptr, OPS_LOG_E(opName_, "Shape of tensor value is nullptr"),
return ge::GRAPH_FAILED);
OPS_ERR_IF(opParamInfo_.weights.desc == nullptr, OPS_LOG_E(opName_, "Desc of tensor value is nullptr"),
return ge::GRAPH_FAILED);
OPS_ERR_IF(opParamInfo_.attenOut.shape == nullptr, OPS_LOG_E(opName_, "Shape of tensor output is nullptr"),
return ge::GRAPH_FAILED);
OPS_ERR_IF(opParamInfo_.attenOut.desc == nullptr, OPS_LOG_E(opName_, "Desc of tensor output is nullptr"),
return ge::GRAPH_FAILED);
return ge::GRAPH_SUCCESS;
}
ge::graphStatus LIInfoParser::CheckRequiredAttrExistence() const
{
OPS_ERR_IF(opParamInfo_.layOut == nullptr, OPS_LOG_E(opName_, "attr layout_query is nullptr"),
return ge::GRAPH_FAILED);
OPS_ERR_IF(opParamInfo_.layOutKey == nullptr, OPS_LOG_E(opName_, "attr layout_key is nullptr"),
return ge::GRAPH_FAILED);
OPS_ERR_IF(opParamInfo_.sparseCount == nullptr, OPS_LOG_E(opName_, "attr sparse_count is nullptr"),
return ge::GRAPH_FAILED);
OPS_ERR_IF(opParamInfo_.sparseMode == nullptr, OPS_LOG_E(opName_, "attr sparse_mode is nullptr"),
return ge::GRAPH_FAILED);
return ge::GRAPH_SUCCESS;
}
ge::graphStatus LIInfoParser::CheckRequiredParaExistence() const
{
if (CheckRequiredInOutExistence() != ge::GRAPH_SUCCESS || CheckRequiredAttrExistence() != ge::GRAPH_SUCCESS) {
return ge::GRAPH_FAILED;
}
return ge::GRAPH_SUCCESS;
}
ge::graphStatus LIInfoParser::GetOpName()
{
if (context_->GetNodeName() == nullptr) {
OPS_LOG_E("LightningIndexer", "opName got from TilingContext is nullptr");
return ge::GRAPH_FAILED;
}
opName_ = context_->GetNodeName();
return ge::GRAPH_SUCCESS;
}
ge::graphStatus LIInfoParser::GetNpuInfo()
{
platformInfo_ = context_->GetPlatformInfo();
OPS_ERR_IF(platformInfo_ == nullptr, OPS_LOG_E(opName_, "GetPlatformInfo is nullptr."), return ge::GRAPH_FAILED);
auto ascendcPlatform = platform_ascendc::PlatformAscendC(platformInfo_);
uint32_t aivNum = ascendcPlatform.GetCoreNumAiv();
uint32_t aicNum = ascendcPlatform.GetCoreNumAic();
OPS_ERR_IF(aicNum == 0 || aivNum == 0, OPS_LOG_E(opName_, "num of core obtained is 0."), return GRAPH_FAILED);
socVersion_ = ascendcPlatform.GetSocVersion();
if ((socVersion_ != platform_ascendc::SocVersion::ASCEND910B) &&
(socVersion_ != platform_ascendc::SocVersion::ASCEND910_93)) {
OPS_LOG_E(opName_, "SOC Version[%d] is not support.", (int32_t)socVersion_);
return GRAPH_FAILED;
}
OPS_ERR_IF(context_->GetWorkspaceSizes(1) == nullptr, OPS_LOG_E(opName_, "workSpaceSize got from ge is nullptr"),
return ge::GRAPH_FAILED);
OPS_ERR_IF(context_->GetRawTilingData() == nullptr,
OPS_LOG_E(context_->GetNodeName(), "RawTilingData got from GE context is nullptr."),
return ge::GRAPH_FAILED);
return ge::GRAPH_SUCCESS;
}
void LIInfoParser::GetOptionalInputParaInfo()
{
opParamInfo_.actualSeqLengthsQ.tensor = context_->GetOptionalInputTensor(ACTUAL_SEQ_Q_INDEX);
opParamInfo_.actualSeqLengthsQ.desc = context_->GetOptionalInputDesc(ACTUAL_SEQ_Q_INDEX);
opParamInfo_.actualSeqLengths.tensor = context_->GetOptionalInputTensor(ACTUAL_SEQ_K_INDEX);
opParamInfo_.actualSeqLengths.desc = context_->GetOptionalInputDesc(ACTUAL_SEQ_K_INDEX);
opParamInfo_.blockTable.tensor = context_->GetOptionalInputTensor(BLOCK_TABLE_INDEX);
opParamInfo_.blockTable.desc = context_->GetOptionalInputDesc(BLOCK_TABLE_INDEX);
}
void LIInfoParser::GetInputParaInfo()
{
opParamInfo_.query.desc = context_->GetInputDesc(QUERY_INDEX);
opParamInfo_.query.shape = context_->GetInputShape(QUERY_INDEX);
opParamInfo_.key.desc = context_->GetInputDesc(KEY_INDEX);
opParamInfo_.key.shape = context_->GetInputShape(KEY_INDEX);
opParamInfo_.weights.desc = context_->GetInputDesc(WEIGTHS_INDEX);
opParamInfo_.weights.shape = context_->GetInputShape(WEIGTHS_INDEX);
GetOptionalInputParaInfo();
}
void LIInfoParser::GetOutputParaInfo()
{
opParamInfo_.attenOut.desc = context_->GetOutputDesc(LIGHTNING_INDEXER);
opParamInfo_.attenOut.shape = context_->GetOutputShape(LIGHTNING_INDEXER);
}
ge::graphStatus LIInfoParser::GetAndCheckAttrParaInfo()
{
auto attrs = context_->GetAttrs();
OPS_ERR_IF(attrs == nullptr, OPS_REPORT_VECTOR_INNER_ERR(context_->GetNodeName(), "attrs got from ge is nullptr"),
return ge::GRAPH_FAILED);
OPS_LOG_I(context_->GetNodeName(), "GetAndCheckAttrParaInfo start");
opParamInfo_.layOut = attrs->GetStr(ATTR_QUERY_LAYOUT_INDEX);
opParamInfo_.layOutKey = attrs->GetStr(ATTR_KEY_LAYOUT_INDEX);
opParamInfo_.sparseCount = attrs->GetAttrPointer<int32_t>(ATTR_SPARSE_COUNT_INDEX);
opParamInfo_.sparseMode = attrs->GetAttrPointer<int32_t>(ATTR_SPARSE_MODE_INDEX);
if (opParamInfo_.layOut != nullptr) {
OPS_LOG_I(context_->GetNodeName(), "layout_query is:%s", opParamInfo_.layOut);
}
if (opParamInfo_.layOutKey != nullptr) {
OPS_LOG_I(context_->GetNodeName(), "layout_key is:%s", opParamInfo_.layOutKey);
}
if (opParamInfo_.sparseCount != nullptr) {
OPS_LOG_I(context_->GetNodeName(), "selscted count is:%d", *opParamInfo_.sparseCount);
}
if (opParamInfo_.sparseMode != nullptr) {
OPS_LOG_I(context_->GetNodeName(), "sparse mode is:%d", *opParamInfo_.sparseMode);
}
OPS_LOG_I(context_->GetNodeName(), "GetAndCheckAttrParaInfo end");
OPS_ERR_IF(
((std::string(opParamInfo_.layOutKey) != "PA_BSND")
&& (std::string(opParamInfo_.layOut) != std::string(opParamInfo_.layOutKey))),
OPS_LOG_E(opName_, "under non-PA conditions, layout_query and layout_key should be equal."),
return ge::GRAPH_FAILED);
OPS_ERR_IF(
((std::string(opParamInfo_.layOutKey) != "PA_BSND") && (std::string(opParamInfo_.layOutKey) != "BSND")
&& (std::string(opParamInfo_.layOutKey) != "TND")),
OPS_LOG_E(opName_, "input attr layout_key only supported PA_BSND, BSND or TND"), return ge::GRAPH_FAILED);
OPS_ERR_IF(((std::string(opParamInfo_.layOut) != "BSND") && (std::string(opParamInfo_.layOut) != "TND")),
OPS_LOG_E(opName_, "input attr layout_query only supported BSND or TND."), return ge::GRAPH_FAILED);
OPS_ERR_IF(!((*opParamInfo_.sparseCount > 0) && (*opParamInfo_.sparseCount <= SPARSE_LIMIT)),
OPS_LOG_E(opName_, "input attr sparse_count must > 0 and <= 2048."), return ge::GRAPH_FAILED);
OPS_ERR_IF(!((*opParamInfo_.sparseMode == 0) || (*opParamInfo_.sparseMode == SPARSE_MODE_LOWER)),
OPS_LOG_E(opName_, "input attr sparse_mode only supported 0 or 3."), return ge::GRAPH_FAILED);
return ge::GRAPH_SUCCESS;
}
ge::graphStatus LIInfoParser::GetOpParaInfo()
{
GetInputParaInfo();
GetOutputParaInfo();
if (ge::GRAPH_SUCCESS != GetAndCheckAttrParaInfo()) {
return ge::GRAPH_FAILED;
}
return ge::GRAPH_SUCCESS;
}
ge::graphStatus LIInfoParser::GetAndCheckInOutDataType()
{
inputQType_ = opParamInfo_.query.desc->GetDataType();
inputKType_ = opParamInfo_.key.desc->GetDataType();
weightsType_ = opParamInfo_.weights.desc->GetDataType();
outputType_ = opParamInfo_.attenOut.desc->GetDataType();
bool inDTypeAllEqual = (inputQType_ == inputKType_) && (inputKType_ == weightsType_);
OPS_ERR_IF(!inDTypeAllEqual,
OPS_LOG_E(opName_, "The data types of the input query, key, and weights must be the same."),
return ge::GRAPH_FAILED);
OPS_ERR_IF(((inputQType_ != ge::DT_FLOAT16) && (inputQType_ != ge::DT_BF16)),
OPS_LOG_E(opName_, "The data types of the input query, key, and weights must be float16 or bfloat16."),
return ge::GRAPH_FAILED);
OPS_ERR_IF(outputType_ != ge::DT_INT32,
OPS_LOG_E(opName_, "The data types of the output sparse_indices must be int32."),
return ge::GRAPH_FAILED);
return ge::GRAPH_SUCCESS;
}
ge::graphStatus LIInfoParser::GetQueryKeyAndOutLayout()
{
const map<string, DataLayout> layoutMap = {
{"BSND", DataLayout::BSND},
{"TND", DataLayout::TND},
{"PA_BSND", DataLayout::BnBsND}
};
std::string layout(opParamInfo_.layOut);
auto it = layoutMap.find(layout);
if (it != layoutMap.end()) {
qLayout_ = it->second;
}
std::string layoutKey(opParamInfo_.layOutKey);
auto itKey = layoutMap.find(layoutKey);
if (itKey != layoutMap.end()) {
kLayout_ = itKey->second;
}
return ge::GRAPH_SUCCESS;
}
ge::graphStatus LIInfoParser::GetAndCheckOptionalInput()
{
if (kLayout_ == DataLayout::BnBsND) {
OPS_ERR_IF(opParamInfo_.blockTable.tensor == nullptr,
OPS_LOG_E(opName_, "key layout only supported PA_BSND, input block_table must not be null"),
return ge::GRAPH_FAILED);
OPS_ERR_IF(
opParamInfo_.actualSeqLengths.tensor == nullptr,
OPS_LOG_E(opName_, "key layout only supported PA_BSND, input actual_seq_lengths_key must not be null"),
return ge::GRAPH_FAILED);
OPS_ERR_IF(opParamInfo_.blockTable.desc->GetDataType() != ge::DT_INT32,
OPS_LOG_E(opName_, "input block_table data type only support int32"), return ge::GRAPH_FAILED);
} else if (kLayout_ == DataLayout::TND) {
OPS_ERR_IF(opParamInfo_.actualSeqLengths.tensor == nullptr,
OPS_LOG_E(opName_, "when layout_key is TND, input actual_seq_lengths_key must not be null"),
return ge::GRAPH_FAILED);
}
OPS_ERR_IF(opParamInfo_.actualSeqLengths.tensor != nullptr &&
opParamInfo_.actualSeqLengths.desc->GetDataType() != ge::DT_INT32,
OPS_LOG_E(opName_, "input actual_seq_lengths_key data type only support int32"),
return ge::GRAPH_FAILED);
OPS_ERR_IF(opParamInfo_.actualSeqLengths.tensor != nullptr &&
opParamInfo_.actualSeqLengths.desc->GetDataType() != ge::DT_INT32,
OPS_LOG_E(opName_, "input actual_seq_lengths_key data type only support int32"),
return ge::GRAPH_FAILED);
if (qLayout_ == DataLayout::TND) {
OPS_ERR_IF(opParamInfo_.actualSeqLengthsQ.tensor == nullptr,
OPS_LOG_E(opName_, "when layout_query is TND, input actual_seq_lengths_query must not be null"),
return ge::GRAPH_FAILED);
}
OPS_ERR_IF(opParamInfo_.actualSeqLengthsQ.tensor != nullptr &&
opParamInfo_.actualSeqLengthsQ.desc->GetDataType() != ge::DT_INT32,
OPS_LOG_E(opName_, "input actual_seq_lengths_query data type only support int32"),
return ge::GRAPH_FAILED);
OPS_ERR_IF(kLayout_ != DataLayout::BnBsND && opParamInfo_.blockTable.tensor != nullptr,
OPS_LOG_E(opName_, "when key layout is not PA_BSND, input block_table must be null"),
return ge::GRAPH_FAILED);
return ge::GRAPH_SUCCESS;
}
ge::graphStatus LIInfoParser::CheckShapeDim()
{
OPS_ERR_IF((opParamInfo_.blockTable.tensor != nullptr) &&
(opParamInfo_.blockTable.tensor->GetStorageShape().GetDimNum() != DIM_NUM_TWO),
OPS_LOG_E(opName_, "the dim num of block_table's shape should be 2"), return ge::GRAPH_FAILED);
uint32_t kShapeDim = opParamInfo_.key.shape->GetStorageShape().GetDimNum();
uint32_t qShapeDim = opParamInfo_.query.shape->GetStorageShape().GetDimNum();
uint32_t weightsShapeDim = opParamInfo_.weights.shape->GetStorageShape().GetDimNum();
uint32_t outShapeDim = opParamInfo_.attenOut.shape->GetStorageShape().GetDimNum();
uint32_t qExpectShapeDim = DIM_NUM_FOUR;
uint32_t kExpectShapeDim = DIM_NUM_FOUR;
if (qLayout_ == DataLayout::TND) {
qExpectShapeDim = DIM_NUM_THREE;
}
if (kLayout_ == DataLayout::TND) {
kExpectShapeDim = DIM_NUM_THREE;
}
OPS_ERR_IF(kShapeDim != kExpectShapeDim,
OPS_LOG_E(opName_, "the dim num of key's shape should be %u, but now is %u", kExpectShapeDim, kShapeDim),
return ge::GRAPH_FAILED);
OPS_ERR_IF(qShapeDim != qExpectShapeDim,
OPS_LOG_E(opName_, "the dim num of query's shape should be %u, but now is %u",
qExpectShapeDim, qShapeDim),
return ge::GRAPH_FAILED);
OPS_ERR_IF(outShapeDim != qExpectShapeDim,
OPS_LOG_E(opName_, "the dim num of sparse_indices's shape should be %u, but now is %u",
qExpectShapeDim, outShapeDim),
return ge::GRAPH_FAILED);
OPS_ERR_IF(!(weightsShapeDim == qExpectShapeDim - 1),
OPS_LOG_E(opName_, "the dim num of weights's shape should be %u, but now is %u", qExpectShapeDim - 1,
weightsShapeDim),
return ge::GRAPH_FAILED);
return ge::GRAPH_SUCCESS;
}
ge::graphStatus LIInfoParser::GetN1Size()
{
if (qLayout_ == DataLayout::BSND) {
n1Size_ = static_cast<uint32_t>(opParamInfo_.query.shape->GetStorageShape().GetDim(DIM_IDX_TWO));
} else {
// TND
n1Size_ = static_cast<uint32_t>(opParamInfo_.query.shape->GetStorageShape().GetDim(1));
}
OPS_LOG_I(context_->GetNodeName(), "n1Size is %d", n1Size_);
return ge::GRAPH_SUCCESS;
}
ge::graphStatus LIInfoParser::GetActualSeqLenSize(uint32_t &size, const gert::Tensor *tensor,
const std::string &actualSeqLenName)
{
size = static_cast<uint32_t>(tensor->GetShapeSize());
if (size <= 0) {
OPS_LOG_E(opName_, "%s's shape size is %u, it should be greater than 0.", actualSeqLenName.c_str(), size);
return ge::GRAPH_FAILED;
}
return ge::GRAPH_SUCCESS;
}
ge::graphStatus LIInfoParser::GetAndCheckN2Size()
{
uint32_t n2Index = (kLayout_ == DataLayout::TND) ? DIM_IDX_ONE : DIM_IDX_TWO;
n2Size_ = static_cast<uint32_t>(opParamInfo_.key.shape->GetStorageShape().GetDim(n2Index));
OPS_LOG_I(context_->GetNodeName(), "n2Size_ is %d", n2Size_);
OPS_ERR_IF(n2Size_ != 1, OPS_LOG_E(opName_, "key shape[%u] is numhead, only support 1.", n2Index),
return ge::GRAPH_FAILED);
return ge::GRAPH_SUCCESS;
}
ge::graphStatus LIInfoParser::GetGSize()
{
if (n1Size_ % n2Size_ != 0) {
OPS_LOG_E(opName_, "input query's head_num %u can not be a multiple of key's head_num %u.", n1Size_, n2Size_);
return ge::GRAPH_FAILED;
}
gSize_ = n1Size_ / n2Size_;
OPS_ERR_IF(gSize_ != 64, OPS_LOG_E(opName_, "N1 is %u, N2 is %u, N1 divided by N2 must equal 64.",
n1Size_, n2Size_), return ge::GRAPH_FAILED);
return ge::GRAPH_SUCCESS;
}
ge::graphStatus LIInfoParser::GetBatchSize()
{
if ((qLayout_ == DataLayout::TND)) {
return GetActualSeqLenSize(bSize_, opParamInfo_.actualSeqLengthsQ.tensor, "input actual_seq_lengths_query");
} else { // BSND
bSize_ = opParamInfo_.query.shape->GetStorageShape().GetDim(0);
return ge::GRAPH_SUCCESS;
}
}
ge::graphStatus LIInfoParser::GetHeadDim()
{
uint32_t dIndex = DIM_IDX_TWO;
switch (qLayout_) {
case DataLayout::TND:
// TND: [Total, N, D] -> D is the 2nd dimension
dIndex = DIM_IDX_TWO;
break;
case DataLayout::BSND:
// BSND: [Batch, SeqLen, N, D] -> D is the 3rd dimension
dIndex = DIM_IDX_THREE;
break;
default:
OPS_LOG_E(opName_, "unsupported layout for getting head dim.");
return ge::GRAPH_FAILED;
}
headDim_ = opParamInfo_.query.shape->GetStorageShape().GetDim(dIndex);
OPS_ERR_IF(headDim_ != HEAD_DIM_LIMIT, OPS_LOG_E(opName_, "input query's last dim head_dim only support 128."),
return ge::GRAPH_FAILED);
return ge::GRAPH_SUCCESS;
}
ge::graphStatus LIInfoParser::GetS1Size()
{
if (qLayout_ == DataLayout::BSND) {
s1Size_ = opParamInfo_.query.shape->GetStorageShape().GetDim(1);
}
return ge::GRAPH_SUCCESS;
}
ge::graphStatus LIInfoParser::GetAndCheckBlockSize()
{
blockSize_ = static_cast<uint32_t>(opParamInfo_.key.shape->GetStorageShape().GetDim(1));
OPS_LOG_I(context_->GetNodeName(), "blockSize_ is %d", blockSize_);
OPS_ERR_IF(((blockSize_ % 16 != 0) || (blockSize_ == 0) || (blockSize_ > 1024)),
OPS_LOG_E(opName_, "input key's block_size must be a multiple of 16 and belong to (0, 1024]."),
return ge::GRAPH_FAILED);
return ge::GRAPH_SUCCESS;
}
ge::graphStatus LIInfoParser::CheckBlockCount()
{
int32_t blockCount_ = static_cast<uint32_t>(opParamInfo_.key.shape->GetStorageShape().GetDim(0));
OPS_ERR_IF((blockCount_ == 0),
OPS_LOG_E(opName_, "input key's block_count cannot be 0."),
return ge::GRAPH_FAILED);
return ge::GRAPH_SUCCESS;
}
ge::graphStatus LIInfoParser::GetS2SizeForPageAttention()
{
if (GetAndCheckBlockSize() != ge::GRAPH_SUCCESS) {
return ge::GRAPH_FAILED;
}
if (CheckBlockCount() != ge::GRAPH_SUCCESS) {
return ge::GRAPH_FAILED;
}
maxBlockNumPerBatch_ = opParamInfo_.blockTable.tensor->GetStorageShape().GetDim(1);
s2Size_ = maxBlockNumPerBatch_ * blockSize_;
OPS_LOG_I(context_->GetNodeName(), "maxBlockNumPerBatch_ is %d, blockSize_ is %d, s2Size_ is %d",
maxBlockNumPerBatch_, blockSize_, s2Size_);
return ge::GRAPH_SUCCESS;
}
ge::graphStatus LIInfoParser::GetS2Size()
{
if (kLayout_ == DataLayout::BnBsND) {
return GetS2SizeForPageAttention();
} else if (kLayout_ == DataLayout::TND) {
s2Size_ = opParamInfo_.key.shape->GetStorageShape().GetDim(0);
} else if (kLayout_ == DataLayout::BSND) {
s2Size_ = opParamInfo_.key.shape->GetStorageShape().GetDim(1);
}
return ge::GRAPH_SUCCESS;
}
ge::graphStatus LIInfoParser::ValidateInputShapesMatchQTnd()
{
// -----------------------check BatchSize-------------------
if (kLayout_ == DataLayout::TND) {
OPS_ERR_IF(
(opParamInfo_.actualSeqLengths.tensor->GetShapeSize() != bSize_),
OPS_LOG_E(opName_,
"TND case input actual_seq_lengths_query, actual_seq_lengths_key are %u, %ld respectively, they must be same.",
bSize_, opParamInfo_.actualSeqLengths.tensor->GetShapeSize()),
return ge::GRAPH_FAILED);
} else { // kLayout_ PA_BSND
OPS_ERR_IF(
(opParamInfo_.actualSeqLengths.tensor->GetShapeSize() != bSize_) ||
(opParamInfo_.blockTable.tensor->GetStorageShape().GetDim(0) != bSize_),
OPS_LOG_E(
opName_,
"TND case input actual_seq_lengths_query, actual_seq_lengths_key, block_table dim 0 are %u, %ld, %ld respectively, they must be same.",
bSize_, opParamInfo_.actualSeqLengths.tensor->GetShapeSize(),
opParamInfo_.blockTable.tensor->GetStorageShape().GetDim(0)),
return ge::GRAPH_FAILED);
}
// -----------------------check T-------------------
uint32_t qTsize = opParamInfo_.query.shape->GetStorageShape().GetDim(0);
OPS_ERR_IF((opParamInfo_.weights.shape->GetStorageShape().GetDim(0) != qTsize) ||
(opParamInfo_.attenOut.shape->GetStorageShape().GetDim(0) != qTsize),
OPS_LOG_E(opName_, "TND case input query, weights, sparse_indices dim 0 are %u, %ld, %ld respectively, they must be same.",
qTsize, opParamInfo_.weights.shape->GetStorageShape().GetDim(0),
opParamInfo_.attenOut.shape->GetStorageShape().GetDim(0)),
return ge::GRAPH_FAILED);
return ge::GRAPH_SUCCESS;
}
ge::graphStatus LIInfoParser::ValidateInputShapesMatchQBsnd()
{
// -----------------------check BatchSize-------------------
if (kLayout_ == DataLayout::BnBsND) {
OPS_ERR_IF((opParamInfo_.blockTable.tensor->GetStorageShape().GetDim(0) != bSize_) ||
(opParamInfo_.actualSeqLengths.tensor->GetShapeSize() != bSize_),
OPS_LOG_E(opName_, "BSND case input query, actual_seq_lengths_key, block_table dim 0 are %u, %ld, %ld respectively, they must be same.",
bSize_, opParamInfo_.actualSeqLengths.tensor->GetShapeSize(),
opParamInfo_.blockTable.tensor->GetStorageShape().GetDim(0)),
return ge::GRAPH_FAILED);
} else if (kLayout_ == DataLayout::BSND) {
OPS_ERR_IF(opParamInfo_.key.shape->GetStorageShape().GetDim(0) != bSize_,
OPS_LOG_E(opName_, "BSND case input query, key dim 0 are %u, %ld respectively, they must be same.",
bSize_, opParamInfo_.key.shape->GetStorageShape().GetDim(0)),
return ge::GRAPH_FAILED);
OPS_ERR_IF((opParamInfo_.actualSeqLengths.tensor != nullptr) &&
(opParamInfo_.actualSeqLengths.tensor->GetShapeSize() != bSize_),
OPS_LOG_E(opName_, "BSND case input query, actual_seq_lengths_key dim 0 are %u, %ld respectively, they must be same.",
bSize_, opParamInfo_.actualSeqLengths.tensor->GetShapeSize()),
return ge::GRAPH_FAILED);
}
OPS_ERR_IF((opParamInfo_.weights.shape->GetStorageShape().GetDim(0) != bSize_) ||
(opParamInfo_.attenOut.shape->GetStorageShape().GetDim(0) != bSize_),
OPS_LOG_E(opName_, "BSND case input query, weight, sparse_indices dim 0 are %u, %ld, %ld respectively, they must be same.",
bSize_, opParamInfo_.weights.shape->GetStorageShape().GetDim(0),
opParamInfo_.attenOut.shape->GetStorageShape().GetDim(0)),
return ge::GRAPH_FAILED);
OPS_ERR_IF((opParamInfo_.actualSeqLengthsQ.tensor != nullptr) &&
(opParamInfo_.actualSeqLengthsQ.tensor->GetShapeSize() != bSize_),
OPS_LOG_E(opName_, "BSND case input query, actual_seq_lengths_query dim 0 are %u, %ld respectively, they must be same.",
bSize_, opParamInfo_.actualSeqLengthsQ.tensor->GetShapeSize()),
return ge::GRAPH_FAILED);
// -----------------------check S1-------------------
OPS_ERR_IF((opParamInfo_.weights.shape->GetStorageShape().GetDim(1) != s1Size_) ||
(opParamInfo_.attenOut.shape->GetStorageShape().GetDim(1) != s1Size_),
OPS_LOG_E(opName_, "BSND case input query, weight, sparse_indices dim 1 are %u, %ld, %ld, they must be same.",
s1Size_, opParamInfo_.weights.shape->GetStorageShape().GetDim(1),
opParamInfo_.attenOut.shape->GetStorageShape().GetDim(1)),
return ge::GRAPH_FAILED);
return ge::GRAPH_SUCCESS;
}
ge::graphStatus LIInfoParser::ValidateInputShapesMatch()
{
uint32_t queryWeightsN1Dim = 1;
uint32_t outN2Dim = 1;
if (qLayout_ == DataLayout::TND) {
if (ValidateInputShapesMatchQTnd() != ge::GRAPH_SUCCESS) {
return ge::GRAPH_FAILED;
}
} else {
if (ValidateInputShapesMatchQBsnd() != ge::GRAPH_SUCCESS) {
return ge::GRAPH_FAILED;
}
queryWeightsN1Dim = DIM_IDX_TWO;
outN2Dim = DIM_IDX_TWO;
}
// -----------------------check N1-------------------
OPS_ERR_IF((opParamInfo_.weights.shape->GetStorageShape().GetDim(queryWeightsN1Dim) != n1Size_),
OPS_LOG_E(opName_, "input query, weight shape dim N1 must be same."), return ge::GRAPH_FAILED);
// -----------------------check D-------------------
uint32_t keyDDim = kLayout_ == DataLayout::TND ? DIM_IDX_TWO : DIM_IDX_THREE;
OPS_ERR_IF((opParamInfo_.key.shape->GetStorageShape().GetDim(keyDDim) != headDim_),
OPS_LOG_E(opName_, "input query, key shape last dim must be same."), return ge::GRAPH_FAILED);
// -----------------------check N2-------------------
OPS_ERR_IF((opParamInfo_.attenOut.shape->GetStorageShape().GetDim(outN2Dim) != n2Size_),
OPS_LOG_E(opName_, "input query and output sparse_indices shape n2 dim must be same."),
return ge::GRAPH_FAILED);
// -----------------------check sparse_count-------------------
OPS_ERR_IF((opParamInfo_.attenOut.shape->GetStorageShape().GetDim(outN2Dim + 1) != *opParamInfo_.sparseCount),
OPS_LOG_E(opName_, "output sparse_indices shape last dim must be same as attr sparse_count."),
return ge::GRAPH_FAILED);
return ge::GRAPH_SUCCESS;
}
void LIInfoParser::GenerateInfo(LITilingInfo &liInfo)
{
liInfo.opName = opName_;
liInfo.platformInfo = platformInfo_;
liInfo.opParamInfo = opParamInfo_;
liInfo.socVersion = socVersion_;
liInfo.bSize = bSize_;
liInfo.n1Size = n1Size_;
liInfo.n2Size = n2Size_;
liInfo.s1Size = s1Size_;
liInfo.s2Size = s2Size_;
liInfo.gSize = gSize_;
liInfo.inputQType = inputQType_;
liInfo.inputKType = inputKType_;
liInfo.outputType = outputType_;
liInfo.blockSize = blockSize_;
liInfo.maxBlockNumPerBatch = maxBlockNumPerBatch_;
std::string layOutKeyStr(opParamInfo_.layOutKey);
liInfo.pageAttentionFlag = layOutKeyStr == "PA_BSND" ? true : false;
liInfo.sparseMode = *opParamInfo_.sparseMode;
liInfo.sparseCount = *opParamInfo_.sparseCount;
liInfo.inputQLayout = qLayout_;
liInfo.inputKLayout = kLayout_;
}
ge::graphStatus LIInfoParser::ParseAndCheck(LITilingInfo &liInfo)
{
if (ge::GRAPH_SUCCESS != GetOpName() || ge::GRAPH_SUCCESS != GetNpuInfo() || ge::GRAPH_SUCCESS != GetOpParaInfo() ||
ge::GRAPH_SUCCESS != CheckRequiredParaExistence()) {
return ge::GRAPH_FAILED;
}
if (ge::GRAPH_SUCCESS != GetAndCheckInOutDataType() || ge::GRAPH_SUCCESS != GetQueryKeyAndOutLayout() ||
ge::GRAPH_SUCCESS != GetAndCheckOptionalInput()) {
return ge::GRAPH_FAILED;
}
if (ge::GRAPH_SUCCESS != CheckShapeDim() || ge::GRAPH_SUCCESS != GetN1Size() ||
ge::GRAPH_SUCCESS != GetAndCheckN2Size() || ge::GRAPH_SUCCESS != GetGSize()) {
return ge::GRAPH_FAILED;
}
if (ge::GRAPH_SUCCESS != GetBatchSize() || ge::GRAPH_SUCCESS != GetS1Size() || ge::GRAPH_SUCCESS != GetHeadDim() ||
ge::GRAPH_SUCCESS != GetS2Size()) {
return ge::GRAPH_FAILED;
}
if (ge::GRAPH_SUCCESS != ValidateInputShapesMatch()) {
return ge::GRAPH_FAILED;
}
GenerateInfo(liInfo);
return ge::GRAPH_SUCCESS;
}
static ge::graphStatus TilingPrepareForLightningIndexer(gert::TilingParseContext * /* context */)
{
return ge::GRAPH_SUCCESS;
}
ge::graphStatus LightningIndexerTiling::DoTiling(LITilingInfo *tilingInfo)
{
// -------------set blockdim-----------------
auto ascendcPlatform = platform_ascendc::PlatformAscendC(tilingInfo->platformInfo);
uint32_t aivNum = ascendcPlatform.GetCoreNumAiv();
uint32_t aicNum = ascendcPlatform.GetCoreNumAic();
uint32_t blockDim = ascendcPlatform.CalcTschBlockDim(aivNum, aicNum, aivNum);
context_->SetBlockDim(blockDim);
// -------------set workspacesize-----------------
constexpr uint32_t MM1_RES_ELEM_SIZE = 4;
constexpr uint32_t DOUBLE_BUFFER = 2;
constexpr uint32_t M_BASE_SIZE = 512;
constexpr uint32_t S2_BASE_SIZE = 512;
constexpr uint32_t V1_RES_ELEM_SIZE = 4;
constexpr uint32_t V1_RES_ELEM_TYPE = 2;
constexpr uint32_t V1_DECODE_PARAM_ELEM_SIZE = 8;
constexpr uint32_t V1_DECODE_PARAM_NUM = 16;
constexpr uint32_t V1_DECODE_DATA_NUM = 2;
constexpr uint32_t S1_BASE_SIZE = 8;
constexpr uint32_t TOPK_MAX_SIZE = 2048;
uint32_t workspaceSize = ascendcPlatform.GetLibApiWorkSpaceSize();
uint32_t mm1ResSize = M_BASE_SIZE * S2_BASE_SIZE;
workspaceSize += mm1ResSize * MM1_RES_ELEM_SIZE * DOUBLE_BUFFER * aicNum;
workspaceSize += V1_DECODE_DATA_NUM * S1_BASE_SIZE * V1_RES_ELEM_TYPE * TOPK_MAX_SIZE * V1_RES_ELEM_SIZE * aicNum;
workspaceSize += V1_DECODE_DATA_NUM * S1_BASE_SIZE * V1_DECODE_PARAM_NUM * V1_DECODE_PARAM_ELEM_SIZE * aicNum;
size_t *workSpaces = context_->GetWorkspaceSizes(1);
workSpaces[0] = workspaceSize;
// -------------set tilingdata-----------------
tilingData_.set_bSize(tilingInfo->bSize);
tilingData_.set_s2Size(tilingInfo->s2Size);
tilingData_.set_s1Size(tilingInfo->s1Size);
tilingData_.set_sparseCount(tilingInfo->sparseCount);
tilingData_.set_gSize(tilingInfo->gSize);
tilingData_.set_blockSize(tilingInfo->blockSize);
tilingData_.set_maxBlockNumPerBatch(tilingInfo->maxBlockNumPerBatch);
tilingData_.set_sparseMode(tilingInfo->sparseMode);
tilingData_.set_usedCoreNum(blockDim);
tilingData_.SaveToBuffer(context_->GetRawTilingData()->GetData(), context_->GetRawTilingData()->GetCapacity());
context_->GetRawTilingData()->SetDataSize(tilingData_.GetDataSize());
// -------------set tilingkey-----------------
// DT_Q, DT_KV, DT_OUT, PAGE_ATTENTION, FLASH_DECODE, LAYOUT_T, KV_LAYOUT_T
uint32_t inputQType = static_cast<uint32_t>(tilingInfo->inputQType);
uint32_t inputKType = static_cast<uint32_t>(tilingInfo->inputKType);
uint32_t outputType = static_cast<uint32_t>(tilingInfo->outputType);
uint32_t pageAttentionFlag = static_cast<uint32_t>(tilingInfo->pageAttentionFlag);
uint32_t inputQLayout = static_cast<uint32_t>(tilingInfo->inputQLayout);
uint32_t inputKLayout = static_cast<uint32_t>(tilingInfo->inputKLayout);
uint32_t tilingKey =
GET_TPL_TILING_KEY(inputQType, inputKType, outputType, pageAttentionFlag, inputQLayout, inputKLayout);
context_->SetTilingKey(tilingKey);
return ge::GRAPH_SUCCESS;
}
ge::graphStatus TilingForLightningIndexer(gert::TilingContext *context)
{
OPS_ERR_IF(context == nullptr, OPS_REPORT_VECTOR_INNER_ERR("LightningIndexer", "Tiling context is null."),
return ge::GRAPH_FAILED);
LITilingInfo liInfo;
LIInfoParser LIInfoParser(context);
if (LIInfoParser.ParseAndCheck(liInfo) != ge::GRAPH_SUCCESS) {
return ge::GRAPH_FAILED;
}
LightningIndexerTiling liTiling(context);
return liTiling.DoTiling(&liInfo);
}
IMPL_OP_OPTILING(LightningIndexerVllm)
.Tiling(TilingForLightningIndexer)
.TilingParse<LICompileInfo>(TilingPrepareForLightningIndexer);
} // namespace optiling

View File

@@ -0,0 +1,215 @@
/**
* This program is free software, you can redistribute it and/or modify it.
* Copyright (c) 2025 Huawei Technologies Co., Ltd.
* This file is a part of the CANN Open Software.
* Licensed under CANN Open Software License Agreement Version 2.0 (the "License").
* Please refer to the License for details. You may not use this file except in compliance with the License.
* THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE.
* See LICENSE in the root of the software repository for the full text of the License.
*/
/*!
* \file lightning_indexer_tiling.h
* \brief
*/
#ifndef LIGHTNING_INDEXER_TILING_H_
#define LIGHTNING_INDEXER_TILING_H_
#include "exe_graph/runtime/tiling_context.h"
#include "tiling/platform/platform_ascendc.h"
#include "register/op_def_registry.h"
#include "register/tilingdata_base.h"
#include "tiling/tiling_api.h"
#include "error/ops_error.h"
#include "platform/platform_info.h"
namespace optiling {
struct TilingRequiredParaInfo {
const gert::CompileTimeTensorDesc *desc;
const gert::StorageShape *shape;
};
struct TilingOptionalParaInfo {
const gert::CompileTimeTensorDesc *desc;
const gert::Tensor *tensor;
};
enum class DataLayout : uint32_t {
BSND = 0,
TND = 1,
BnBsND = 2
};
// Inputs Index
constexpr uint32_t QUERY_INDEX = 0;
constexpr uint32_t KEY_INDEX = 1;
constexpr uint32_t WEIGTHS_INDEX = 2;
constexpr uint32_t ACTUAL_SEQ_Q_INDEX = 3;
constexpr uint32_t ACTUAL_SEQ_K_INDEX = 4;
constexpr uint32_t BLOCK_TABLE_INDEX = 5;
constexpr uint32_t LIGHTNING_INDEXER = 0;
// Attributes Index
constexpr uint32_t ATTR_QUERY_LAYOUT_INDEX = 0;
constexpr uint32_t ATTR_KEY_LAYOUT_INDEX = 1;
constexpr uint32_t ATTR_SPARSE_COUNT_INDEX = 2;
constexpr uint32_t ATTR_SPARSE_MODE_INDEX = 3;
// Dim Index
constexpr uint32_t DIM_IDX_ONE = 1;
constexpr uint32_t DIM_IDX_TWO = 2;
constexpr uint32_t DIM_IDX_THREE = 3;
// Dim Num
constexpr uint32_t DIM_NUM_TWO = 2;
constexpr uint32_t DIM_NUM_THREE = 3;
constexpr uint32_t DIM_NUM_FOUR = 4;
// Input Parameter Limit Constant
constexpr uint32_t HEAD_DIM_LIMIT = 128;
constexpr uint32_t SPARSE_LIMIT = 2048;
constexpr uint32_t SPARSE_MODE_LOWER = 3;
BEGIN_TILING_DATA_DEF(LITilingData)
TILING_DATA_FIELD_DEF(uint32_t, bSize)
TILING_DATA_FIELD_DEF(uint32_t, n2Size)
TILING_DATA_FIELD_DEF(uint32_t, gSize)
TILING_DATA_FIELD_DEF(uint32_t, s1Size)
TILING_DATA_FIELD_DEF(uint32_t, s2Size)
TILING_DATA_FIELD_DEF(uint32_t, sparseCount)
TILING_DATA_FIELD_DEF(uint32_t, usedCoreNum)
TILING_DATA_FIELD_DEF(uint32_t, blockSize)
TILING_DATA_FIELD_DEF(uint32_t, maxBlockNumPerBatch)
TILING_DATA_FIELD_DEF(uint32_t, sparseMode)
END_TILING_DATA_DEF
REGISTER_TILING_DATA_CLASS(LightningIndexerVllm, LITilingData)
struct LICompileInfo {};
struct LiParaInfo {
TilingRequiredParaInfo query = {nullptr, nullptr};
TilingRequiredParaInfo key = {nullptr, nullptr};
TilingRequiredParaInfo weights = {nullptr, nullptr};
TilingOptionalParaInfo actualSeqLengthsQ = {nullptr, nullptr};
TilingOptionalParaInfo actualSeqLengths = {nullptr, nullptr};
TilingOptionalParaInfo blockTable = {nullptr, nullptr};
TilingRequiredParaInfo attenOut = {nullptr, nullptr};
const char *layOut = nullptr;
const char *layOutKey = nullptr;
const int32_t *blockSize = nullptr;
const int32_t *sparseMode = nullptr;
const int32_t *sparseCount = nullptr;
};
class LITilingInfo {
public:
const char *opName = nullptr;
fe::PlatFormInfos *platformInfo = nullptr;
LiParaInfo opParamInfo;
// Base Param
platform_ascendc::SocVersion socVersion = platform_ascendc::SocVersion::ASCEND910B;
uint32_t bSize = 0;
uint32_t n1Size = 0;
uint32_t n2Size = 0;
uint32_t s1Size = 0;
int64_t s2Size = 0;
uint32_t qkHeadDim = 0;
uint32_t gSize = 0;
// PageAttention
bool pageAttentionFlag = false;
int32_t blockSize = 0;
uint32_t maxBlockNumPerBatch = 0;
// Mask
int32_t sparseMode = 0;
// Others Flag
uint32_t sparseCount = 0;
// DType
ge::DataType inputQType = ge::DT_FLOAT16;
ge::DataType inputKType = ge::DT_FLOAT16;
ge::DataType outputType = ge::DT_INT32;
// Layout
DataLayout inputQLayout = DataLayout::BSND;
DataLayout inputKLayout = DataLayout::BnBsND;
};
class LIInfoParser {
public:
explicit LIInfoParser(gert::TilingContext *context) : context_(context)
{
}
~LIInfoParser() = default;
ge::graphStatus CheckRequiredInOutExistence() const;
ge::graphStatus CheckRequiredAttrExistence() const;
ge::graphStatus CheckRequiredParaExistence() const;
ge::graphStatus GetActualSeqLenSize(uint32_t &size, const gert::Tensor *tensor,
const std::string &actualSeqLenName);
ge::graphStatus GetOpName();
ge::graphStatus GetNpuInfo();
void GetOptionalInputParaInfo();
void GetInputParaInfo();
void GetOutputParaInfo();
ge::graphStatus GetAndCheckAttrParaInfo();
ge::graphStatus GetOpParaInfo();
ge::graphStatus ValidateInputShapesMatchQBsnd();
ge::graphStatus ValidateInputShapesMatchQTnd();
ge::graphStatus ValidateInputShapesMatch();
ge::graphStatus GetAndCheckInOutDataType();
ge::graphStatus GetBatchSize();
ge::graphStatus GetHeadDim();
ge::graphStatus GetS1Size();
ge::graphStatus GetAndCheckOptionalInput();
ge::graphStatus CheckShapeDim();
ge::graphStatus GetAndCheckBlockSize();
ge::graphStatus CheckBlockCount();
ge::graphStatus GetS2SizeForPageAttention();
ge::graphStatus GetS2Size();
ge::graphStatus GetQueryKeyAndOutLayout();
ge::graphStatus GetN1Size();
ge::graphStatus GetAndCheckN2Size();
ge::graphStatus GetGSize();
ge::graphStatus GetAttenMaskInfo();
ge::graphStatus GetActualSeqInfo();
void GenerateInfo(LITilingInfo &liInfo);
ge::graphStatus ParseAndCheck(LITilingInfo &liInfo);
public:
gert::TilingContext *context_ = nullptr;
const char *opName_;
fe::PlatFormInfos *platformInfo_;
LiParaInfo opParamInfo_;
// BaseParams
uint32_t bSize_ = 0;
uint32_t n1Size_ = 0;
uint32_t n2Size_ = 0;
uint32_t gSize_ = 0;
uint32_t s1Size_ = 0;
int64_t s2Size_ = 0;
uint32_t headDim_ = 0;
// Layout
DataLayout qLayout_ = DataLayout::BSND;
DataLayout kLayout_ = DataLayout::BnBsND;
// PageAttention
uint32_t maxBlockNumPerBatch_ = 0;
int32_t blockSize_ = 0;
platform_ascendc::SocVersion socVersion_ = platform_ascendc::SocVersion::ASCEND910B;
ge::DataType inputQType_ = ge::DT_FLOAT16;
ge::DataType inputKType_ = ge::DT_FLOAT16;
ge::DataType weightsType_ = ge::DT_FLOAT16;
ge::DataType blockTableType_ = ge::DT_FLOAT16;
ge::DataType inputKRopeType_ = ge::DT_FLOAT16;
ge::DataType outputType_ = ge::DT_FLOAT16;
};
class LightningIndexerTiling {
public:
explicit LightningIndexerTiling(gert::TilingContext *context) : context_(context){};
ge::graphStatus DoTiling(LITilingInfo *tilingInfo);
private:
gert::TilingContext *context_ = nullptr;
LITilingData tilingData_;
};
} // namespace optiling
#endif // LIGHTNING_INDEXER_TILING_H_