[Perf][1/N] w8a8c8 support in dsv3.2/glm5 (#7029)
### What this PR does / why we need it?
This PR adds W8A8C8 support for dsv3.2/glm5 using the lightning_indexer_quant
op, mainly targeting the pd-mix scenario.
Because the code for the current PD-disaggregated scenario is still
under refactoring and cleanup, this PR prioritizes ensuring the C8
functionality in the pd-mix scenario.
The next steps are planned in three parts:
① Once the optimized scatter operator is updated, we will replace the
original operator to improve the performance of storing k_scale.
② Once the code logic for the PD-disaggregated scenario becomes stable,
we will carry out more comprehensive validation and make appropriate
adaptations.
③ Because enabling C8 currently introduces several new operators whose
performance still needs improvement, performance may regress in some
scenarios. Therefore, only after all the operators are fully ready can
we ensure that this feature does not cause any performance degradation.
At that point, we will enable this feature by default and remove the
switch in `additional_config`.
### Does this PR introduce _any_ user-facing change?
No.
### How was this patch tested?
CI passed with newly added and existing tests.
- vLLM version: v0.16.0
- vLLM main:
4034c3d32e
---------
Signed-off-by: rjg-lyh <1318825571@qq.com>
This commit is contained in:
41
csrc/lightning_indexer_quant/op_host/CMakeLists.txt
Normal file
41
csrc/lightning_indexer_quant/op_host/CMakeLists.txt
Normal file
@@ -0,0 +1,41 @@
|
||||
# This program is free software, you can redistribute it and/or modify it.
|
||||
# Copyright (c) 2025 Huawei Technologies Co., Ltd.
|
||||
# This file is a part of the CANN Open Software.
|
||||
# Licensed under CANN Open Software License Agreement Version 2.0 (the "License").
|
||||
# Please refer to the License for details. You may not use this file except in compliance with the License.
|
||||
# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE.
|
||||
# See LICENSE in the root of the software repository for the full text of the License.
|
||||
# ======================================================================================================================
|
||||
|
||||
add_ops_compile_options(
|
||||
OP_NAME LightningIndexerQuant
|
||||
OPTIONS --cce-auto-sync=off
|
||||
-Wno-deprecated-declarations
|
||||
-Werror
|
||||
-mllvm -cce-aicore-hoist-movemask=false
|
||||
--op_relocatable_kernel_binary=true
|
||||
)
|
||||
|
||||
set(lightning_indexer_quant_depends transformer/attention/lightning_indexer_quant PARENT_SCOPE)
|
||||
|
||||
target_sources(op_host_aclnn PRIVATE
|
||||
lightning_indexer_quant_def.cpp
|
||||
)
|
||||
|
||||
target_sources(optiling PRIVATE
|
||||
lightning_indexer_quant_tiling.cpp
|
||||
)
|
||||
|
||||
if (NOT BUILD_OPEN_PROJECT)
|
||||
target_sources(opmaster_ct PRIVATE
|
||||
lightning_indexer_quant_tiling.cpp
|
||||
)
|
||||
endif ()
|
||||
|
||||
target_include_directories(optiling PRIVATE
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/op_host
|
||||
)
|
||||
|
||||
target_sources(opsproto PRIVATE
|
||||
lightning_indexer_quant_proto.cpp
|
||||
)
|
||||
@@ -0,0 +1,85 @@
|
||||
/**
|
||||
* This program is free software, you can redistribute it and/or modify it.
|
||||
* Copyright (c) 2025 Huawei Technologies Co., Ltd.
|
||||
* This file is a part of the CANN Open Software.
|
||||
* Licensed under CANN Open Software License Agreement Version 2.0 (the "License").
|
||||
* Please refer to the License for details. You may not use this file except in compliance with the License.
|
||||
* THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE.
|
||||
* See LICENSE in the root of the software repository for the full text of the License.
|
||||
*/
|
||||
|
||||
/*!
|
||||
* \file lightning_indexer_quant_def.cpp
|
||||
* \brief
|
||||
*/
|
||||
#include <cstdint>
|
||||
|
||||
#include "register/op_def_registry.h"
|
||||
|
||||
namespace ops {
|
||||
class LightningIndexerQuant : public OpDef {
|
||||
public:
|
||||
explicit LightningIndexerQuant(const char *name) : OpDef(name)
|
||||
{
|
||||
this->Input("query")
|
||||
.ParamType(REQUIRED)
|
||||
.DataType({ge::DT_INT8})
|
||||
.Format({ge::FORMAT_ND})
|
||||
.AutoContiguous();
|
||||
this->Input("key")
|
||||
.ParamType(REQUIRED)
|
||||
.DataType({ge::DT_INT8})
|
||||
.Format({ge::FORMAT_ND})
|
||||
.AutoContiguous();
|
||||
this->Input("weights")
|
||||
.ParamType(REQUIRED)
|
||||
.DataType({ge::DT_FLOAT16})
|
||||
.Format({ge::FORMAT_ND})
|
||||
.AutoContiguous();
|
||||
this->Input("query_dequant_scale")
|
||||
.ParamType(REQUIRED)
|
||||
.DataType({ge::DT_FLOAT16})
|
||||
.Format({ge::FORMAT_ND})
|
||||
.AutoContiguous();
|
||||
this->Input("key_dequant_scale")
|
||||
.ParamType(REQUIRED)
|
||||
.DataType({ge::DT_FLOAT16})
|
||||
.Format({ge::FORMAT_ND})
|
||||
.AutoContiguous();
|
||||
this->Input("actual_seq_lengths_query")
|
||||
.ParamType(OPTIONAL)
|
||||
.DataType({ge::DT_INT32})
|
||||
.Format({ge::FORMAT_ND})
|
||||
.AutoContiguous();
|
||||
this->Input("actual_seq_lengths_key")
|
||||
.ParamType(OPTIONAL)
|
||||
.DataType({ge::DT_INT32})
|
||||
.Format({ge::FORMAT_ND})
|
||||
.AutoContiguous();
|
||||
this->Input("block_table")
|
||||
.ParamType(OPTIONAL)
|
||||
.DataType({ge::DT_INT32})
|
||||
.Format({ge::FORMAT_ND})
|
||||
.AutoContiguous();
|
||||
this->Output("sparse_indices").ParamType(REQUIRED).DataType({ge::DT_INT32}).Format({ge::FORMAT_ND});
|
||||
this->Attr("query_quant_mode").AttrType(REQUIRED).Int(0); // 0: 默认值,per-token-head
|
||||
this->Attr("key_quant_mode").AttrType(REQUIRED).Int(0); // 0: 默认值,per-token-head
|
||||
this->Attr("layout_query").AttrType(OPTIONAL).String("BSND");
|
||||
this->Attr("layout_key").AttrType(OPTIONAL).String("PA_BSND");
|
||||
this->Attr("sparse_count").AttrType(OPTIONAL).Int(2048); // 2048: 默认值,筛选前2048
|
||||
this->Attr("sparse_mode").AttrType(OPTIONAL).Int(3); // 3: 默认值,只计算下三角
|
||||
OpAICoreConfig aicore_config;
|
||||
aicore_config.DynamicCompileStaticFlag(true)
|
||||
.DynamicFormatFlag(true)
|
||||
.DynamicRankSupportFlag(true)
|
||||
.DynamicShapeSupportFlag(true)
|
||||
.NeedCheckSupportFlag(false)
|
||||
.PrecisionReduceFlag(true)
|
||||
.ExtendCfgInfo("aclnnSupport.value", "support_aclnn")
|
||||
.ExtendCfgInfo("jitCompile.flag", "static_false,dynamic_false");
|
||||
this->AICore().AddConfig("ascend910b", aicore_config);
|
||||
this->AICore().AddConfig("ascend910_93", aicore_config);
|
||||
}
|
||||
};
|
||||
OP_ADD(LightningIndexerQuant);
|
||||
} // namespace ops
|
||||
@@ -0,0 +1,91 @@
|
||||
/**
|
||||
* This program is free software, you can redistribute it and/or modify it.
|
||||
* Copyright (c) 2025 Huawei Technologies Co., Ltd.
|
||||
* This file is a part of the CANN Open Software.
|
||||
* Licensed under CANN Open Software License Agreement Version 2.0 (the "License").
|
||||
* Please refer to the License for details. You may not use this file except in compliance with the License.
|
||||
* THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE.
|
||||
* See LICENSE in the root of the software repository for the full text of the License.
|
||||
*/
|
||||
|
||||
/*!
|
||||
* \file lightning_indexer_quant_proto.cpp
|
||||
* \brief
|
||||
*/
|
||||
#include <graph/utils/type_utils.h>
|
||||
#include <register/op_impl_registry.h>
|
||||
|
||||
#include "error/ops_error.h"
|
||||
|
||||
using namespace ge;
|
||||
|
||||
namespace ops {
|
||||
constexpr uint32_t QUERY_INDEX = 0;
|
||||
constexpr uint32_t KEY_INDEX = 1;
|
||||
constexpr uint32_t ATTR_QUERY_LAYOUT_INDEX = 2;
|
||||
constexpr uint32_t ATTR_KV_LAYOUT_INDEX = 3;
|
||||
constexpr uint32_t ATTR_SPARSE_COUNT_INDEX = 4;
|
||||
|
||||
/*!
 * \brief Infers the shape of output 0 (sparse_indices).
 *
 * The output takes the rank of query. For a "BSND" query layout the output is
 * [query dim0, query dim1, key head num, sparse_count]; otherwise ("TND") it is
 * [query dim0, key head num, sparse_count]. The key head num is read from key
 * dim 1 when layout_key is "TND" and from key dim 2 for every other layout.
 *
 * \param context shape-inference context supplied by the framework
 * \return ge::GRAPH_SUCCESS on success; ge::GRAPH_FAILED when the context, a
 *         shape, an attribute is null or layout_query is neither TND nor BSND
 */
static ge::graphStatus InferShapeLightningIndexerQuant(gert::InferShapeContext *context)
{
    if (context == nullptr) {
        OPS_LOG_E("LightningIndexerQuant", "context is nullptr!");
        return ge::GRAPH_FAILED;
    }
    const gert::Shape *queryShape = context->GetInputShape(QUERY_INDEX);
    OPS_LOG_E_IF_NULL(context, queryShape, return ge::GRAPH_FAILED);
    const gert::Shape *keyShape = context->GetInputShape(KEY_INDEX);
    OPS_LOG_E_IF_NULL(context, keyShape, return ge::GRAPH_FAILED);
    gert::Shape *outShape = context->GetOutputShape(0);
    // outShape is written below, so it must be validated just like the inputs.
    OPS_LOG_E_IF_NULL(context, outShape, return ge::GRAPH_FAILED);

    auto attrs = context->GetAttrs();
    OPS_LOG_E_IF_NULL(context, attrs, return ge::GRAPH_FAILED);
    const char *inputLayoutQueryPtr = attrs->GetAttrPointer<char>(ATTR_QUERY_LAYOUT_INDEX);
    OPS_LOG_E_IF_NULL(context, inputLayoutQueryPtr, return ge::GRAPH_FAILED);
    const char *inputLayoutKeyPtr = attrs->GetAttrPointer<char>(ATTR_KV_LAYOUT_INDEX);
    OPS_LOG_E_IF_NULL(context, inputLayoutKeyPtr, return ge::GRAPH_FAILED);
    const int64_t *sparse_count = attrs->GetInt(ATTR_SPARSE_COUNT_INDEX);
    OPS_LOG_E_IF_NULL(context, sparse_count, return ge::GRAPH_FAILED);

    const std::string inputLayoutQueryPtrStr(inputLayoutQueryPtr);
    const std::string inputLayoutKeyPtrStr(inputLayoutKeyPtr);
    if (inputLayoutQueryPtrStr != "TND" && inputLayoutQueryPtrStr != "BSND") {
        OPS_LOG_E(context, "The input layout query should be TND or BSND, but got %s.", inputLayoutQueryPtrStr.c_str());
        return ge::GRAPH_FAILED;
    }

    outShape->SetDimNum(queryShape->GetDimNum());
    const int64_t keyHeadNum = (inputLayoutKeyPtrStr == "TND") ? keyShape->GetDim(1) : keyShape->GetDim(2);
    if (inputLayoutQueryPtrStr == "BSND") {
        outShape->SetDim(0, queryShape->GetDim(0)); // 0: Dim B
        outShape->SetDim(1, queryShape->GetDim(1)); // 1: Dim S
        outShape->SetDim(2, keyHeadNum);            // 2: Dim N
        outShape->SetDim(3, *sparse_count);         // 3: Dim K
    } else {
        outShape->SetDim(0, queryShape->GetDim(0)); // 0: Dim T
        outShape->SetDim(1, keyHeadNum);            // 1: Dim N, taken from the key shape
        outShape->SetDim(2, *sparse_count);         // 2: Dim K
    }

    OPS_LOG_D(context->GetNodeName(), "LightningIndexerQuant InferShape end.");
    return ge::GRAPH_SUCCESS;
}
|
||||
|
||||
/*!
 * \brief Sets the data type of output 0 to int32.
 *
 * \param context data-type inference context supplied by the framework
 * \return ge::GRAPH_FAILED when context is null, otherwise GRAPH_SUCCESS
 */
static ge::graphStatus InferDataTypeLightningIndexerQuant(gert::InferDataTypeContext *context)
{
    if (context == nullptr) {
        OPS_LOG_E("LightningIndexerQuant", "InferDataTypeContext context is nullptr!");
        return ge::GRAPH_FAILED;
    }

    OPS_LOG_D(context->GetNodeName(), "Enter LightningIndexerQuant InferDataType impl.");
    // The index output is always int32.
    context->SetOutputDataType(0, ge::DT_INT32);
    OPS_LOG_D(context->GetNodeName(), "LightningIndexerQuant InferDataType end.");

    return GRAPH_SUCCESS;
}
|
||||
|
||||
IMPL_OP_INFERSHAPE(LightningIndexerQuant)
|
||||
.InferShape(InferShapeLightningIndexerQuant)
|
||||
.InferDataType(InferDataTypeLightningIndexerQuant);
|
||||
} // namespace ops
|
||||
@@ -0,0 +1,828 @@
|
||||
/**
|
||||
* This program is free software, you can redistribute it and/or modify it.
|
||||
* Copyright (c) 2025 Huawei Technologies Co., Ltd.
|
||||
* This file is a part of the CANN Open Software.
|
||||
* Licensed under CANN Open Software License Agreement Version 2.0 (the "License").
|
||||
* Please refer to the License for details. You may not use this file except in compliance with the License.
|
||||
* THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE.
|
||||
* See LICENSE in the root of the software repository for the full text of the License.
|
||||
*/
|
||||
|
||||
/*!
|
||||
* \file lightning_indexer_quant_tiling.cpp
|
||||
* \brief
|
||||
*/
|
||||
|
||||
#include "lightning_indexer_quant_tiling.h"
|
||||
|
||||
#include "../op_kernel/lightning_indexer_quant_template_tiling_key.h"
|
||||
|
||||
using namespace ge;
|
||||
using namespace AscendC;
|
||||
using std::map;
|
||||
using std::string;
|
||||
namespace optiling {
|
||||
// --------------------------LIQInfoParser类成员函数定义-------------------------------------
|
||||
/*!
 * \brief Verifies that the shape and desc pointers of every required input and
 *        of the output have been populated.
 *
 * \return ge::GRAPH_FAILED (after logging) on the first null pointer found,
 *         ge::GRAPH_SUCCESS otherwise
 */
ge::graphStatus LIQInfoParser::CheckRequiredInOutExistence() const
{
    const auto &query = opParamInfo_.query;
    const auto &key = opParamInfo_.key;
    const auto &weights = opParamInfo_.weights;
    const auto &queryScale = opParamInfo_.query_dequant_scale;
    const auto &keyScale = opParamInfo_.key_dequant_scale;
    const auto &output = opParamInfo_.attenOut;

    OPS_ERR_IF(query.shape == nullptr, OPS_LOG_E(opName_, "Shape of tensor query is nullptr"),
               return ge::GRAPH_FAILED);
    OPS_ERR_IF(query.desc == nullptr, OPS_LOG_E(opName_, "Desc of tensor query is nullptr"),
               return ge::GRAPH_FAILED);

    OPS_ERR_IF(key.shape == nullptr, OPS_LOG_E(opName_, "Shape of tensor key is nullptr"),
               return ge::GRAPH_FAILED);
    OPS_ERR_IF(key.desc == nullptr, OPS_LOG_E(opName_, "Desc of tensor key is nullptr"),
               return ge::GRAPH_FAILED);

    OPS_ERR_IF(weights.shape == nullptr, OPS_LOG_E(opName_, "Shape of tensor weights is nullptr"),
               return ge::GRAPH_FAILED);
    OPS_ERR_IF(weights.desc == nullptr, OPS_LOG_E(opName_, "Desc of tensor weights is nullptr"),
               return ge::GRAPH_FAILED);

    OPS_ERR_IF(queryScale.shape == nullptr, OPS_LOG_E(opName_, "Shape of tensor query_dequant_scale is nullptr"),
               return ge::GRAPH_FAILED);
    OPS_ERR_IF(queryScale.desc == nullptr, OPS_LOG_E(opName_, "Desc of tensor query_dequant_scale is nullptr"),
               return ge::GRAPH_FAILED);

    OPS_ERR_IF(keyScale.shape == nullptr, OPS_LOG_E(opName_, "Shape of tensor key_dequant_scale is nullptr"),
               return ge::GRAPH_FAILED);
    OPS_ERR_IF(keyScale.desc == nullptr, OPS_LOG_E(opName_, "Desc of tensor key_dequant_scale is nullptr"),
               return ge::GRAPH_FAILED);

    OPS_ERR_IF(output.shape == nullptr, OPS_LOG_E(opName_, "Shape of tensor output is nullptr"),
               return ge::GRAPH_FAILED);
    OPS_ERR_IF(output.desc == nullptr, OPS_LOG_E(opName_, "Desc of tensor output is nullptr"),
               return ge::GRAPH_FAILED);

    return ge::GRAPH_SUCCESS;
}
|
||||
|
||||
/*!
 * \brief Verifies that every attribute pointer stored in opParamInfo_ is set.
 *
 * \return ge::GRAPH_FAILED (after logging) on the first null attribute pointer,
 *         ge::GRAPH_SUCCESS otherwise
 */
ge::graphStatus LIQInfoParser::CheckRequiredAttrExistence() const
{
    const auto &attrInfo = opParamInfo_;

    OPS_ERR_IF(attrInfo.layOutQuery == nullptr, OPS_LOG_E(opName_, "attr layout_query is nullptr"),
               return ge::GRAPH_FAILED);
    OPS_ERR_IF(attrInfo.layOutKey == nullptr, OPS_LOG_E(opName_, "attr layout_key is nullptr"),
               return ge::GRAPH_FAILED);
    OPS_ERR_IF(attrInfo.sparseCount == nullptr, OPS_LOG_E(opName_, "attr sparse_count is nullptr"),
               return ge::GRAPH_FAILED);
    OPS_ERR_IF(attrInfo.sparseMode == nullptr, OPS_LOG_E(opName_, "attr sparse_mode is nullptr"),
               return ge::GRAPH_FAILED);
    OPS_ERR_IF(attrInfo.queryQuantMode == nullptr, OPS_LOG_E(opName_, "query_quant_mode is nullptr"),
               return ge::GRAPH_FAILED);
    OPS_ERR_IF(attrInfo.keyQuantMode == nullptr, OPS_LOG_E(opName_, "key_quant_mode is nullptr"),
               return ge::GRAPH_FAILED);

    return ge::GRAPH_SUCCESS;
}
|
||||
|
||||
/*!
 * \brief Runs the input/output existence check followed by the attribute
 *        existence check.
 *
 * The attribute check is not executed when the input/output check fails.
 *
 * \return ge::GRAPH_SUCCESS only when both checks succeed
 */
ge::graphStatus LIQInfoParser::CheckRequiredParaExistence() const
{
    if (CheckRequiredInOutExistence() != ge::GRAPH_SUCCESS) {
        return ge::GRAPH_FAILED;
    }
    if (CheckRequiredAttrExistence() != ge::GRAPH_SUCCESS) {
        return ge::GRAPH_FAILED;
    }
    return ge::GRAPH_SUCCESS;
}
|
||||
|
||||
/*!
 * \brief Caches the node name reported by the tiling context in opName_.
 *
 * \return ge::GRAPH_FAILED when the context reports a null node name,
 *         ge::GRAPH_SUCCESS otherwise
 */
ge::graphStatus LIQInfoParser::GetOpName()
{
    const auto nodeName = context_->GetNodeName();
    if (nodeName == nullptr) {
        OPS_LOG_E("LightningIndexerQuant", "opName got from TilingContext is nullptr");
        return ge::GRAPH_FAILED;
    }

    opName_ = nodeName;
    return ge::GRAPH_SUCCESS;
}
|
||||
|
||||
/*!
 * \brief Reads platform information from the tiling context and validates it.
 *
 * Fails when the platform info is null, when either the AIC or AIV core count
 * is zero, when the SoC version is neither ASCEND910B nor ASCEND910_93, or when
 * the workspace-size buffer or the raw tiling data is unavailable.
 *
 * \return ge::GRAPH_SUCCESS when every check passes, GRAPH_FAILED otherwise
 */
ge::graphStatus LIQInfoParser::GetNpuInfo()
{
    platformInfo_ = context_->GetPlatformInfo();
    OPS_ERR_IF(platformInfo_ == nullptr, OPS_LOG_E(opName_, "GetPlatformInfo is nullptr."), return ge::GRAPH_FAILED);

    auto ascendcPlatform = platform_ascendc::PlatformAscendC(platformInfo_);
    const uint32_t vectorCoreNum = ascendcPlatform.GetCoreNumAiv();
    const uint32_t cubeCoreNum = ascendcPlatform.GetCoreNumAic();
    OPS_ERR_IF(cubeCoreNum == 0 || vectorCoreNum == 0, OPS_LOG_E(opName_, "num of core obtained is 0."),
               return GRAPH_FAILED);

    socVersion_ = ascendcPlatform.GetSocVersion();
    const bool isSupportedSoc = (socVersion_ == platform_ascendc::SocVersion::ASCEND910B) ||
                                (socVersion_ == platform_ascendc::SocVersion::ASCEND910_93);
    if (!isSupportedSoc) {
        OPS_LOG_E(opName_, "SOC Version[%d] is not support.", static_cast<int32_t>(socVersion_));
        return GRAPH_FAILED;
    }

    OPS_ERR_IF(context_->GetWorkspaceSizes(1) == nullptr, OPS_LOG_E(opName_, "workSpaceSize got from ge is nullptr"),
               return ge::GRAPH_FAILED);
    OPS_ERR_IF(context_->GetRawTilingData() == nullptr,
               OPS_LOG_E(context_->GetNodeName(), "RawTilingData got from GE context is nullptr."),
               return ge::GRAPH_FAILED);

    return ge::GRAPH_SUCCESS;
}
|
||||
|
||||
void LIQInfoParser::GetOptionalInputParaInfo()
|
||||
{
|
||||
opParamInfo_.actualSeqLengthsQ.tensor = context_->GetOptionalInputTensor(ACTUAL_SEQ_Q_INDEX);
|
||||
opParamInfo_.actualSeqLengthsQ.desc = context_->GetOptionalInputDesc(ACTUAL_SEQ_Q_INDEX);
|
||||
opParamInfo_.actualSeqLengthsK.tensor = context_->GetOptionalInputTensor(ACTUAL_SEQ_K_INDEX);
|
||||
opParamInfo_.actualSeqLengthsK.desc = context_->GetOptionalInputDesc(ACTUAL_SEQ_K_INDEX);
|
||||
opParamInfo_.blockTable.tensor = context_->GetOptionalInputTensor(BLOCK_TABLE_INDEX);
|
||||
opParamInfo_.blockTable.desc = context_->GetOptionalInputDesc(BLOCK_TABLE_INDEX);
|
||||
}
|
||||
|
||||
void LIQInfoParser::GetInputParaInfo()
|
||||
{
|
||||
opParamInfo_.query.desc = context_->GetInputDesc(QUERY_INDEX);
|
||||
opParamInfo_.query.shape = context_->GetInputShape(QUERY_INDEX);
|
||||
opParamInfo_.key.desc = context_->GetInputDesc(KEY_INDEX);
|
||||
opParamInfo_.key.shape = context_->GetInputShape(KEY_INDEX);
|
||||
opParamInfo_.weights.desc = context_->GetInputDesc(WEIGTHS_INDEX);
|
||||
opParamInfo_.weights.shape = context_->GetInputShape(WEIGTHS_INDEX);
|
||||
opParamInfo_.query_dequant_scale.desc = context_->GetInputDesc(QUERY_DEQUANT_SCALE_INDEX);
|
||||
opParamInfo_.query_dequant_scale.shape = context_->GetInputShape(QUERY_DEQUANT_SCALE_INDEX);
|
||||
opParamInfo_.key_dequant_scale.desc = context_->GetInputDesc(KEY_DEQUANT_SCALE_INDEX);
|
||||
opParamInfo_.key_dequant_scale.shape = context_->GetInputShape(KEY_DEQUANT_SCALE_INDEX);
|
||||
GetOptionalInputParaInfo();
|
||||
}
|
||||
|
||||
void LIQInfoParser::GetOutputParaInfo()
|
||||
{
|
||||
opParamInfo_.attenOut.desc = context_->GetOutputDesc(LIGHTNING_INDEXER_QUANT);
|
||||
opParamInfo_.attenOut.shape = context_->GetOutputShape(LIGHTNING_INDEXER_QUANT);
|
||||
}
|
||||
|
||||
/*!
 * \brief Fetches every operator attribute pointer from the tiling context and
 *        logs the values that are present.
 *
 * Each attribute is read exactly once. Null attribute pointers are stored
 * as-is and are not treated as an error here.
 *
 * \return ge::GRAPH_FAILED when the attribute container itself is null,
 *         ge::GRAPH_SUCCESS otherwise
 */
ge::graphStatus LIQInfoParser::GetAttrParaInfo()
{
    auto attrs = context_->GetAttrs();
    OPS_ERR_IF(attrs == nullptr, OPS_REPORT_VECTOR_INNER_ERR(context_->GetNodeName(), "attrs got from ge is nullptr"),
               return ge::GRAPH_FAILED);

    OPS_LOG_I(context_->GetNodeName(), "GetAttrParaInfo start");
    opParamInfo_.queryQuantMode = attrs->GetAttrPointer<int32_t>(ATTR_QUERY_QUANT_MODE_INDEX);
    opParamInfo_.keyQuantMode = attrs->GetAttrPointer<int32_t>(ATTR_KEY_QUANT_MODE_INDEX);
    opParamInfo_.layOutQuery = attrs->GetStr(ATTR_QUERY_LAYOUT_INDEX);
    opParamInfo_.layOutKey = attrs->GetStr(ATTR_KEY_LAYOUT_INDEX);
    opParamInfo_.sparseCount = attrs->GetAttrPointer<int32_t>(ATTR_SPARSE_COUNT_INDEX);
    opParamInfo_.sparseMode = attrs->GetAttrPointer<int32_t>(ATTR_SPARSE_MODE_INDEX);

    // Only log attributes that are actually present; presence is enforced elsewhere.
    if (opParamInfo_.layOutQuery != nullptr) {
        OPS_LOG_I(context_->GetNodeName(), "layout_query is:%s", opParamInfo_.layOutQuery);
    }
    if (opParamInfo_.layOutKey != nullptr) {
        OPS_LOG_I(context_->GetNodeName(), "layout_key is:%s", opParamInfo_.layOutKey);
    }
    if (opParamInfo_.sparseCount != nullptr) {
        OPS_LOG_I(context_->GetNodeName(), "selscted count is:%d", *opParamInfo_.sparseCount);
    }
    if (opParamInfo_.sparseMode != nullptr) {
        OPS_LOG_I(context_->GetNodeName(), "sparse mode is:%d", *opParamInfo_.sparseMode);
    }
    if (opParamInfo_.queryQuantMode != nullptr) {
        OPS_LOG_I(context_->GetNodeName(), "query_quant_mode mode is:%d", *opParamInfo_.queryQuantMode);
    }
    if (opParamInfo_.keyQuantMode != nullptr) {
        OPS_LOG_I(context_->GetNodeName(), "key_quant_mode mode is:%d", *opParamInfo_.keyQuantMode);
    }
    OPS_LOG_I(context_->GetNodeName(), "GetAttrParaInfo end");

    return ge::GRAPH_SUCCESS;
}
|
||||
|
||||
/*!
 * \brief Validates the attribute values previously stored in opParamInfo_.
 *
 * Accepted combinations:
 *  - layout_key is PA_BSND, BSND or TND;
 *  - layout_query is BSND or TND;
 *  - unless layout_key is PA_BSND, layout_query must equal layout_key;
 *  - sparse_count is in (0, SPARSE_LIMIT];
 *  - sparse_mode is 0 or SPARSE_MODE_LOWER;
 *  - query_quant_mode and key_quant_mode are 0.
 *
 * \return ge::GRAPH_FAILED (after logging) on the first violated rule or when
 *         any attribute pointer is null, ge::GRAPH_SUCCESS otherwise
 */
ge::graphStatus LIQInfoParser::CheckAttrParaInfo()
{
    // Every attribute pointer is dereferenced below (and std::string must not be
    // built from nullptr), so reject missing attributes before touching them.
    OPS_ERR_IF(opParamInfo_.layOutQuery == nullptr || opParamInfo_.layOutKey == nullptr,
               OPS_LOG_E(opName_, "attr layout_query or layout_key is nullptr"), return ge::GRAPH_FAILED);
    OPS_ERR_IF(opParamInfo_.sparseCount == nullptr || opParamInfo_.sparseMode == nullptr,
               OPS_LOG_E(opName_, "attr sparse_count or sparse_mode is nullptr"), return ge::GRAPH_FAILED);
    OPS_ERR_IF(opParamInfo_.queryQuantMode == nullptr || opParamInfo_.keyQuantMode == nullptr,
               OPS_LOG_E(opName_, "attr query_quant_mode or key_quant_mode is nullptr"), return ge::GRAPH_FAILED);

    const std::string layout_key(opParamInfo_.layOutKey);
    const std::string layout_query(opParamInfo_.layOutQuery);

    OPS_ERR_IF((layout_key != "PA_BSND") && (layout_key != "BSND") && (layout_key != "TND"),
               OPS_LOG_E(opName_, "input attr layout_key only supported PA_BSND, BSND or TND, "
                                  "but now layout_key is %s.", layout_key.c_str()),
               return ge::GRAPH_FAILED);
    OPS_ERR_IF((layout_query != "BSND") && (layout_query != "TND"),
               OPS_LOG_E(opName_, "input attr layout_query only supported BSND or TND."), return ge::GRAPH_FAILED);
    OPS_ERR_IF((layout_key != "PA_BSND") && (layout_query != layout_key),
               OPS_LOG_E(opName_, "outside of PA, input attr layout_query and input attr layout_key must be the same, but now layout_key is %s, layout_query is %s.",
                         layout_key.c_str(), layout_query.c_str()),
               return ge::GRAPH_FAILED);

    OPS_ERR_IF(!((*opParamInfo_.sparseCount > 0) && (*opParamInfo_.sparseCount <= SPARSE_LIMIT)),
               OPS_LOG_E(opName_, "input attr sparse_count must > 0 and <= 2048."), return ge::GRAPH_FAILED);
    // sparse_mode is a signed value, so it is printed with %d.
    OPS_ERR_IF(!((*opParamInfo_.sparseMode == 0) || (*opParamInfo_.sparseMode == SPARSE_MODE_LOWER)),
               OPS_LOG_E(opName_, "input attr sparse_mode only supported 0 or 3, but now is %d.",
                         *opParamInfo_.sparseMode),
               return ge::GRAPH_FAILED);

    OPS_ERR_IF(*opParamInfo_.queryQuantMode != 0, OPS_LOG_E(opName_, "input attr query_quant_mode only supported 0."),
               return ge::GRAPH_FAILED);
    OPS_ERR_IF(*opParamInfo_.keyQuantMode != 0, OPS_LOG_E(opName_, "input attr key_quant_mode only supported 0."),
               return ge::GRAPH_FAILED);

    return ge::GRAPH_SUCCESS;
}
|
||||
|
||||
/*!
 * \brief Collects input, output and attribute information, then validates the
 *        attribute values.
 *
 * CheckAttrParaInfo() is not executed when GetAttrParaInfo() fails.
 *
 * \return ge::GRAPH_SUCCESS only when both attribute steps succeed
 */
ge::graphStatus LIQInfoParser::GetOpParaInfo()
{
    GetInputParaInfo();
    GetOutputParaInfo();

    const bool attrsFetched = (GetAttrParaInfo() == ge::GRAPH_SUCCESS);
    if (!attrsFetched) {
        return ge::GRAPH_FAILED;
    }

    const bool attrsValid = (CheckAttrParaInfo() == ge::GRAPH_SUCCESS);
    if (!attrsValid) {
        return ge::GRAPH_FAILED;
    }

    return ge::GRAPH_SUCCESS;
}
|
||||
|
||||
/*!
 * \brief Reads the data types of the inputs and the output from their
 *        descriptors and validates them.
 *
 * Requirements: query and key share one type, which must be int8; the two
 * dequant scales share one type, which must be float16; weights are float16;
 * the output is int32.
 *
 * NOTE(review): the descriptors are dereferenced without a null check here —
 * this presumably relies on the existence checks having run first; confirm
 * against the caller.
 *
 * \return ge::GRAPH_FAILED (after logging) on the first violated requirement,
 *         ge::GRAPH_SUCCESS otherwise
 */
ge::graphStatus LIQInfoParser::GetAndCheckInOutDataType()
{
    inputQType_ = opParamInfo_.query.desc->GetDataType();
    inputKType_ = opParamInfo_.key.desc->GetDataType();
    weightsType_ = opParamInfo_.weights.desc->GetDataType();
    inputQueryScaleType_ = opParamInfo_.query_dequant_scale.desc->GetDataType();
    inputKeyScaleType_ = opParamInfo_.key_dequant_scale.desc->GetDataType();
    outputType_ = opParamInfo_.attenOut.desc->GetDataType();

    // The data types are enum values, not C strings: print their numeric value
    // instead of passing them to %s, which would be undefined behaviour.
    OPS_ERR_IF(!(inputQType_ == inputKType_),
               OPS_LOG_E(opName_,
                         "The data types of the input query and key must be the same, but now is %d, %d respectively.",
                         static_cast<int32_t>(inputQType_), static_cast<int32_t>(inputKType_)),
               return ge::GRAPH_FAILED);

    OPS_ERR_IF(
        !(inputQueryScaleType_ == inputKeyScaleType_),
        OPS_LOG_E(opName_, "The data types of the input query_dequant_scale and key_dequant_scale must be the same."),
        return ge::GRAPH_FAILED);

    OPS_ERR_IF(inputQType_ != ge::DT_INT8,
               OPS_LOG_E(opName_, "The data types of the input query and key must be int8."), return ge::GRAPH_FAILED);

    OPS_ERR_IF(weightsType_ != ge::DT_FLOAT16,
               OPS_LOG_E(opName_, "The data types of the input weights must be float16."), return ge::GRAPH_FAILED);

    OPS_ERR_IF(
        inputQueryScaleType_ != ge::DT_FLOAT16,
        OPS_LOG_E(opName_, "The data types of the input query_dequant_scale and key_dequant_scale must be float16."),
        return ge::GRAPH_FAILED);

    OPS_ERR_IF(outputType_ != ge::DT_INT32,
               OPS_LOG_E(opName_, "The data types of the output sparse_indices must be int32."),
               return ge::GRAPH_FAILED);

    return ge::GRAPH_SUCCESS;
}
|
||||
|
||||
/*!
 * \brief Translates the layout_query / layout_key attribute strings into
 *        DataLayout values stored in qLayout_ and kLayout_.
 *
 * Query accepts "BSND" and "TND". Key accepts "BSND", "TND", "PA_BSND" and
 * "PA_BBND" (the last two both map to DataLayout::PA_BSND).
 *
 * \return ge::GRAPH_FAILED when a layout pointer is null or a layout string is
 *         not in the table (the members are then left untouched),
 *         ge::GRAPH_SUCCESS otherwise
 */
ge::graphStatus LIQInfoParser::GetQueryKeyAndOutLayout()
{
    // std::string must not be constructed from nullptr.
    OPS_ERR_IF(opParamInfo_.layOutQuery == nullptr || opParamInfo_.layOutKey == nullptr,
               OPS_LOG_E(opName_, "attr layout_query or layout_key is nullptr"), return ge::GRAPH_FAILED);

    // Reference table for the query layout.
    const std::map<std::string, DataLayout> queryLayoutTable = {{"BSND", DataLayout::BSND},
                                                                {"TND", DataLayout::TND}};
    const auto queryEntry = queryLayoutTable.find(std::string(opParamInfo_.layOutQuery));
    // Report an unknown layout instead of silently keeping the previous value.
    OPS_ERR_IF(queryEntry == queryLayoutTable.end(),
               OPS_LOG_E(opName_, "unsupported layout_query %s.", opParamInfo_.layOutQuery),
               return ge::GRAPH_FAILED);
    qLayout_ = queryEntry->second;

    // Reference table for the key layout.
    const std::map<std::string, DataLayout> keyLayoutTable = {{"BSND", DataLayout::BSND},
                                                              {"TND", DataLayout::TND},
                                                              {"PA_BSND", DataLayout::PA_BSND},
                                                              {"PA_BBND", DataLayout::PA_BSND}};
    const auto keyEntry = keyLayoutTable.find(std::string(opParamInfo_.layOutKey));
    OPS_ERR_IF(keyEntry == keyLayoutTable.end(),
               OPS_LOG_E(opName_, "unsupported layout_key %s.", opParamInfo_.layOutKey),
               return ge::GRAPH_FAILED);
    kLayout_ = keyEntry->second;

    return ge::GRAPH_SUCCESS;
}
|
||||
|
||||
/*!
 * \brief Validates presence and data type of the optional inputs against the
 *        selected layouts.
 *
 * Rules:
 *  - key layout PA_BSND: block_table and actual_seq_lengths_key are mandatory
 *    and block_table must be int32; any other key layout forbids block_table;
 *  - key layout TND: actual_seq_lengths_key is mandatory;
 *  - query layout TND: actual_seq_lengths_query is mandatory;
 *  - whenever an actual_seq_lengths tensor is supplied it must be int32.
 *
 * A descriptor is only dereferenced after it has been checked for null.
 *
 * \return ge::GRAPH_FAILED (after logging) on the first violated rule,
 *         ge::GRAPH_SUCCESS otherwise
 */
ge::graphStatus LIQInfoParser::GetAndCheckOptionalInput()
{
    if (kLayout_ == DataLayout::PA_BSND) {
        OPS_ERR_IF(opParamInfo_.blockTable.tensor == nullptr,
                   OPS_LOG_E(opName_, "key layout only supported PA_BSND, input block_table must not be null"),
                   return ge::GRAPH_FAILED);
        OPS_ERR_IF(
            opParamInfo_.actualSeqLengthsK.tensor == nullptr,
            OPS_LOG_E(opName_, "key layout only supported PA_BSND, input actual_seq_lengths_key must not be null"),
            return ge::GRAPH_FAILED);
        OPS_ERR_IF(opParamInfo_.blockTable.desc == nullptr,
                   OPS_LOG_E(opName_, "Desc of input block_table is nullptr"), return ge::GRAPH_FAILED);
        OPS_ERR_IF(opParamInfo_.blockTable.desc->GetDataType() != ge::DT_INT32,
                   OPS_LOG_E(opName_, "input block_table data type only support int32"), return ge::GRAPH_FAILED);
    } else {
        OPS_ERR_IF(opParamInfo_.blockTable.tensor != nullptr,
                   OPS_LOG_E(opName_, "key layout is not PA_BSND, input block_table must be null"),
                   return ge::GRAPH_FAILED);
    }

    if (kLayout_ == DataLayout::TND) {
        OPS_ERR_IF(opParamInfo_.actualSeqLengthsK.tensor == nullptr,
                   OPS_LOG_E(opName_, "when layout_key is TND, input actual_seq_lengths_key must not be null"),
                   return ge::GRAPH_FAILED);
    }
    if (opParamInfo_.actualSeqLengthsK.tensor != nullptr) {
        OPS_ERR_IF(opParamInfo_.actualSeqLengthsK.desc == nullptr,
                   OPS_LOG_E(opName_, "Desc of input actual_seq_lengths_key is nullptr"), return ge::GRAPH_FAILED);
        OPS_ERR_IF(opParamInfo_.actualSeqLengthsK.desc->GetDataType() != ge::DT_INT32,
                   OPS_LOG_E(opName_, "input actual_seq_lengths_key data type only support int32"),
                   return ge::GRAPH_FAILED);
    }

    if (qLayout_ == DataLayout::TND) {
        OPS_ERR_IF(opParamInfo_.actualSeqLengthsQ.tensor == nullptr,
                   OPS_LOG_E(opName_, "when layout_query is TND, input actual_seq_lengths_query must not be null"),
                   return ge::GRAPH_FAILED);
    }
    if (opParamInfo_.actualSeqLengthsQ.tensor != nullptr) {
        OPS_ERR_IF(opParamInfo_.actualSeqLengthsQ.desc == nullptr,
                   OPS_LOG_E(opName_, "Desc of input actual_seq_lengths_query is nullptr"), return ge::GRAPH_FAILED);
        OPS_ERR_IF(opParamInfo_.actualSeqLengthsQ.desc->GetDataType() != ge::DT_INT32,
                   OPS_LOG_E(opName_, "input actual_seq_lengths_query data type only support int32"),
                   return ge::GRAPH_FAILED);
    }

    return ge::GRAPH_SUCCESS;
}
|
||||
|
||||
/*!
 * \brief Validates the number of dimensions of block_table, key, query,
 *        weights and the output.
 *
 * Rules:
 *  - block_table, when present, has DIM_NUM_TWO dims;
 *  - key has DIM_NUM_FOUR dims when the key layout is PA_BSND;
 *  - query and output have DIM_NUM_THREE dims for a TND query layout and
 *    DIM_NUM_FOUR otherwise; weights have one dim fewer than query.
 *
 * \return ge::GRAPH_FAILED (after logging) on the first mismatch,
 *         ge::GRAPH_SUCCESS otherwise
 */
ge::graphStatus LIQInfoParser::CheckShapeDim()
{
    // Dim counts are narrowed to uint32_t before logging so the argument width
    // matches the %u conversion.
    if (opParamInfo_.blockTable.tensor != nullptr) {
        const uint32_t blockTableDim =
            static_cast<uint32_t>(opParamInfo_.blockTable.tensor->GetStorageShape().GetDimNum());
        OPS_ERR_IF(blockTableDim != DIM_NUM_TWO,
                   OPS_LOG_E(opName_, "the dim num of block_table's shape should be 2, but now is %u", blockTableDim),
                   return ge::GRAPH_FAILED);
    }

    if (kLayout_ == DataLayout::PA_BSND) {
        const uint32_t keyShapeDim = static_cast<uint32_t>(opParamInfo_.key.shape->GetStorageShape().GetDimNum());
        OPS_ERR_IF(keyShapeDim != DIM_NUM_FOUR,
                   OPS_LOG_E(opName_, "the dim num of key's shape should be 4, but now is %u", keyShapeDim),
                   return ge::GRAPH_FAILED);
    }

    const uint32_t qShapeDim = static_cast<uint32_t>(opParamInfo_.query.shape->GetStorageShape().GetDimNum());
    const uint32_t weightsShapeDim = static_cast<uint32_t>(opParamInfo_.weights.shape->GetStorageShape().GetDimNum());
    const uint32_t outShapeDim = static_cast<uint32_t>(opParamInfo_.attenOut.shape->GetStorageShape().GetDimNum());
    uint32_t expectShapeDim = DIM_NUM_FOUR;
    if (qLayout_ == DataLayout::TND) {
        expectShapeDim = DIM_NUM_THREE;
    }

    OPS_ERR_IF(
        qShapeDim != expectShapeDim,
        OPS_LOG_E(opName_, "the dim num of query's shape should be %u, but now is %u", expectShapeDim, qShapeDim),
        return ge::GRAPH_FAILED);
    OPS_ERR_IF(outShapeDim != expectShapeDim,
               OPS_LOG_E(opName_, "the dim num of sparse_indices's shape should be %u, but now is %u", expectShapeDim,
                         outShapeDim),
               return ge::GRAPH_FAILED);
    OPS_ERR_IF(!(weightsShapeDim == expectShapeDim - 1),
               OPS_LOG_E(opName_, "the dim num of weights's shape should be %u, but now is %u", expectShapeDim - 1,
                         weightsShapeDim),
               return ge::GRAPH_FAILED);

    return ge::GRAPH_SUCCESS;
}
|
||||
|
||||
/*!
 * \brief Stores the query head count in n1Size_.
 *
 * Reads query dim DIM_IDX_TWO for the BSND layout and dim DIM_IDX_ONE for any
 * other layout (TND).
 *
 * \return always ge::GRAPH_SUCCESS
 */
ge::graphStatus LIQInfoParser::GetN1Size()
{
    const auto &queryDims = opParamInfo_.query.shape->GetStorageShape();
    n1Size_ = (qLayout_ == DataLayout::BSND) ? static_cast<uint32_t>(queryDims.GetDim(DIM_IDX_TWO))
                                             : static_cast<uint32_t>(queryDims.GetDim(DIM_IDX_ONE)); // TND
    OPS_LOG_I(context_->GetNodeName(), "n1Size is %d", n1Size_);
    return ge::GRAPH_SUCCESS;
}
|
||||
|
||||
/*!
 * \brief Reads the element count of an actual_seq_lengths tensor.
 *
 * The 64-bit shape size is range-checked before it is narrowed to uint32_t, so
 * a negative size is rejected rather than wrapping to a large positive value.
 *
 * \param size             receives the element count; set to 0 on failure
 * \param tensor           tensor whose shape size is queried; may be null
 * \param actualSeqLenName name used in error messages
 * \return ge::GRAPH_FAILED when tensor is null or its shape size is not in
 *         [1, UINT32_MAX], ge::GRAPH_SUCCESS otherwise
 */
ge::graphStatus LIQInfoParser::GetActualSeqLenSize(uint32_t &size, const gert::Tensor *tensor,
                                                   const std::string &actualSeqLenName)
{
    size = 0;
    if (tensor == nullptr) {
        OPS_LOG_E(opName_, "%s is nullptr.", actualSeqLenName.c_str());
        return ge::GRAPH_FAILED;
    }

    const int64_t shapeSize = tensor->GetShapeSize();
    if (shapeSize <= 0) {
        OPS_LOG_E(opName_, "%s's shape size is %lld, it should be greater than 0.", actualSeqLenName.c_str(),
                  static_cast<long long>(shapeSize));
        return ge::GRAPH_FAILED;
    }
    if (shapeSize > static_cast<int64_t>(UINT32_MAX)) {
        OPS_LOG_E(opName_, "%s's shape size is %lld, it exceeds the supported maximum.", actualSeqLenName.c_str(),
                  static_cast<long long>(shapeSize));
        return ge::GRAPH_FAILED;
    }

    size = static_cast<uint32_t>(shapeSize);
    return ge::GRAPH_SUCCESS;
}
|
||||
|
||||
/*!
 * \brief Stores the key head count in n2Size_ and requires it to be 1.
 *
 * Reads key dim DIM_IDX_ONE for the TND layout and dim DIM_IDX_TWO for every
 * other key layout.
 *
 * \return ge::GRAPH_FAILED when the head count is not 1,
 *         ge::GRAPH_SUCCESS otherwise
 */
ge::graphStatus LIQInfoParser::GetAndCheckN2Size()
{
    const auto &keyDims = opParamInfo_.key.shape->GetStorageShape();
    n2Size_ = (kLayout_ == DataLayout::TND) ? static_cast<uint32_t>(keyDims.GetDim(DIM_IDX_ONE))
                                            : static_cast<uint32_t>(keyDims.GetDim(DIM_IDX_TWO));
    OPS_LOG_I(context_->GetNodeName(), "N2 is %d", n2Size_);

    OPS_ERR_IF(n2Size_ != 1, OPS_LOG_E(opName_, "key shape[2] is numhead, only support 1."), return ge::GRAPH_FAILED);
    return ge::GRAPH_SUCCESS;
}
|
||||
|
||||
/*!
 * \brief Computes gSize_ = n1Size_ / n2Size_.
 *
 * \return ge::GRAPH_FAILED when n2Size_ is zero or n1Size_ is not a multiple of
 *         n2Size_, ge::GRAPH_SUCCESS otherwise
 */
ge::graphStatus LIQInfoParser::GetGSize()
{
    // Guard the modulo/division below against a zero divisor.
    if (n2Size_ == 0) {
        OPS_LOG_E(opName_, "key's head_num is 0, the group size can not be computed.");
        return ge::GRAPH_FAILED;
    }
    if (n1Size_ % n2Size_ != 0) {
        OPS_LOG_E(opName_, "input query's head_num %u can not be a multiple of key's head_num %u.", n1Size_, n2Size_);
        return ge::GRAPH_FAILED;
    }
    gSize_ = n1Size_ / n2Size_;

    return ge::GRAPH_SUCCESS;
}
|
||||
|
||||
/*!
 * \brief Determines the batch size bSize_.
 *
 *  1. For a non-TND query layout the batch size is query dim DIM_IDX_ZERO.
 *  2. For the TND query layout the batch size is the element count of
 *     actual_seq_lengths_query, obtained through GetActualSeqLenSize().
 *
 * \return the status of GetActualSeqLenSize() for TND, ge::GRAPH_SUCCESS otherwise
 */
ge::graphStatus LIQInfoParser::GetBatchSize()
{
    if (qLayout_ == DataLayout::TND) {
        return GetActualSeqLenSize(bSize_, opParamInfo_.actualSeqLengthsQ.tensor, "input actual_seq_lengths_query");
    }

    // BSND
    const auto &queryDims = opParamInfo_.query.shape->GetStorageShape();
    bSize_ = static_cast<uint32_t>(queryDims.GetDim(DIM_IDX_ZERO));
    // The dims are 64-bit values, so they are logged with %lld rather than %d.
    OPS_LOG_I(context_->GetNodeName(), "b: %lld, s: %lld, n: %lld,d :%lld",
              static_cast<long long>(queryDims.GetDim(DIM_IDX_ZERO)),
              static_cast<long long>(queryDims.GetDim(DIM_IDX_ONE)),
              static_cast<long long>(queryDims.GetDim(DIM_IDX_TWO)),
              static_cast<long long>(queryDims.GetDim(DIM_IDX_THREE)));
    return ge::GRAPH_SUCCESS;
}
|
||||
|
||||
/**
 * @brief Reads the query head dimension (D) into headDim_ and checks it against HEAD_DIM_LIMIT.
 *
 * D is the last query axis: index 2 for TND ([Total, N, D]) and index 3 for
 * BSND ([Batch, SeqLen, N, D]). Any other query layout is rejected.
 *
 * @return ge::GRAPH_SUCCESS when the layout is supported and D equals HEAD_DIM_LIMIT,
 *         otherwise ge::GRAPH_FAILED.
 */
ge::graphStatus LIQInfoParser::GetHeadDim()
{
    // Defaults to the TND position; adjusted below for BSND.
    uint32_t headDimAxis = DIM_IDX_TWO;
    if (qLayout_ == DataLayout::BSND) {
        headDimAxis = DIM_IDX_THREE;
    } else if (qLayout_ != DataLayout::TND) {
        OPS_LOG_E(opName_, "unsupported layout for getting head dim.");
        return ge::GRAPH_FAILED;
    }

    headDim_ = opParamInfo_.query.shape->GetStorageShape().GetDim(headDimAxis);
    OPS_ERR_IF(headDim_ != HEAD_DIM_LIMIT, OPS_LOG_E(opName_, "input query's last dim head_dim only support 128, but now is %u.", headDim_),
               return ge::GRAPH_FAILED);

    return ge::GRAPH_SUCCESS;
}
|
||||
|
||||
/**
 * @brief Reads the query sequence length (S1) into s1Size_ for a BSND query.
 *
 * For any other query layout s1Size_ is left unchanged.
 *
 * @return ge::GRAPH_SUCCESS always.
 */
ge::graphStatus LIQInfoParser::GetS1Size()
{
    if (qLayout_ != DataLayout::BSND) {
        return ge::GRAPH_SUCCESS;
    }
    // BSND: the sequence axis is dim 1.
    s1Size_ = opParamInfo_.query.shape->GetStorageShape().GetDim(DIM_IDX_ONE);
    return ge::GRAPH_SUCCESS;
}
|
||||
|
||||
/**
 * @brief Reads the block size from dim 1 of the key storage shape and validates it.
 *
 * The value must be non-zero, a multiple of BLOCK_SIZE_FACTOR and no larger
 * than BLOCK_SIZE_LIMIT.
 *
 * @return ge::GRAPH_SUCCESS when the block size is valid, otherwise ge::GRAPH_FAILED.
 */
ge::graphStatus LIQInfoParser::GetAndCheckBlockSize()
{
    const auto &keyDims = opParamInfo_.key.shape->GetStorageShape();
    blockSize_ = static_cast<uint32_t>(keyDims.GetDim(DIM_IDX_ONE));
    OPS_LOG_I(context_->GetNodeName(), "blockSize_ is %d", blockSize_);

    const bool misaligned = (blockSize_ % BLOCK_SIZE_FACTOR != 0);
    const bool isZero = (blockSize_ == 0);
    const bool overLimit = (blockSize_ > BLOCK_SIZE_LIMIT);
    OPS_ERR_IF(
        (misaligned || isZero || overLimit),
        OPS_LOG_E(opName_, "input key's block_size must be a multiple of 16 and belong to (0, 1024], but now is %u.",
                  blockSize_),
        return ge::GRAPH_FAILED);

    return ge::GRAPH_SUCCESS;
}
|
||||
|
||||
/**
 * @brief Derives S2 for the paged key layout: S2 = block_table.dim1 * block_size.
 *
 * Also populates blockSize_ (via GetAndCheckBlockSize) and maxBlockNumPerBatch_.
 * The product is computed in 64 bits so that a large block table cannot wrap
 * in 32-bit arithmetic before being stored in the 64-bit s2Size_.
 *
 * @return ge::GRAPH_SUCCESS on success; ge::GRAPH_FAILED when the block size is
 *         invalid, the key has zero blocks, or block_table is missing.
 */
ge::graphStatus LIQInfoParser::GetS2SizeForPageAttention()
{
    if (GetAndCheckBlockSize() != ge::GRAPH_SUCCESS) {
        return ge::GRAPH_FAILED;
    }

    const int64_t blockCount = opParamInfo_.key.shape->GetStorageShape().GetDim(DIM_IDX_ZERO);
    OPS_ERR_IF((blockCount == 0), OPS_LOG_E(opName_, "input key's block_count cannot be 0."), return ge::GRAPH_FAILED);

    // block_table is an optional input slot, so guard the pointer before reading its shape.
    const gert::Tensor *blockTable = opParamInfo_.blockTable.tensor;
    OPS_ERR_IF((blockTable == nullptr),
               OPS_LOG_E(opName_, "input block_table must be provided when key layout is PA_BSND."),
               return ge::GRAPH_FAILED);

    maxBlockNumPerBatch_ = static_cast<uint32_t>(blockTable->GetStorageShape().GetDim(DIM_IDX_ONE));
    s2Size_ = static_cast<int64_t>(maxBlockNumPerBatch_) * static_cast<int64_t>(blockSize_);
    OPS_LOG_I(context_->GetNodeName(), "maxBlockNumPerBatch_ is %u, blockSize_ is %d, s2Size_ is %lld",
              maxBlockNumPerBatch_, blockSize_, static_cast<long long>(s2Size_));
    return ge::GRAPH_SUCCESS;
}
|
||||
|
||||
/**
 * @brief Reads S2 directly from the key shape for non-paged key layouts.
 *
 * - BSND key: S2 is dim 1.
 * - TND key:  S2 is dim 0.
 * Any other layout is reported as unsupported. The layout name is only read on
 * the error path and is null-guarded, so no std::string is built from a
 * possibly-null pointer.
 *
 * @return ge::GRAPH_SUCCESS for BSND/TND, otherwise ge::GRAPH_FAILED.
 */
ge::graphStatus LIQInfoParser::GetS2SizeForBatchContinuous()
{
    const auto &keyDims = opParamInfo_.key.shape->GetStorageShape();
    switch (kLayout_) {
        case DataLayout::BSND:
            s2Size_ = keyDims.GetDim(DIM_IDX_ONE);
            return ge::GRAPH_SUCCESS;
        case DataLayout::TND:
            s2Size_ = keyDims.GetDim(DIM_IDX_ZERO);
            return ge::GRAPH_SUCCESS;
        default:
            break;
    }

    const char *layoutName = (opParamInfo_.layOutKey != nullptr) ? opParamInfo_.layOutKey : "null";
    OPS_LOG_E(opName_, "the layout of key is %s, it is unsupported.", layoutName);
    return ge::GRAPH_FAILED;
}
|
||||
|
||||
/**
 * @brief Determines S2 according to the key layout.
 *
 * - Paged key (PA_BSND): S2 = block_table.dim1 * block_size.
 * - Any other key layout: S2 is read from the key's sequence axis.
 *
 * @return Status of the layout-specific helper.
 */
ge::graphStatus LIQInfoParser::GetS2Size()
{
    return (kLayout_ == DataLayout::PA_BSND) ? GetS2SizeForPageAttention() : GetS2SizeForBatchContinuous();
}
|
||||
|
||||
/**
 * @brief Cross-checks the shapes of query, key, weights, optional inputs and output.
 *
 * Relies on bSize_, s1Size_, n1Size_, n2Size_ and headDim_ having been populated.
 *
 * Expected shapes:
 *   TND query:
 *     query        [T, N1, D]
 *     key          [BlockNum, BlockSize, N2, D]
 *     weight       [T, N1]
 *     block_table  [BatchSize, BatchMaxBlockNum]
 *     act_seq_k    [BatchSize]
 *     act_seq_q    [BatchSize]
 *     out          [T, N2, topk]
 *   BSND query:
 *     query        [BatchSize, S1, N1, D]
 *     key          [BlockNum, BlockSize, N2, D]
 *     weight       [BatchSize, S1, N1]
 *     block_table  [BatchSize, BatchMaxBlockNum]
 *     act_seq_k    [BatchSize]
 *     act_seq_q    [BatchSize] (optional)
 *     out          [BatchSize, S1, N2, topk]
 *
 * Optional tensors are read once through null-guarded locals, so an absent
 * tensor can never be dereferenced while building an error message; an absent
 * tensor is printed as -1. All 64-bit dims are printed with %lld.
 *
 * @return ge::GRAPH_SUCCESS when every check passes, otherwise ge::GRAPH_FAILED.
 */
ge::graphStatus LIQInfoParser::ValidateInputShapesMatch()
{
    const auto &queryShape = opParamInfo_.query.shape->GetStorageShape();
    const auto &keyShape = opParamInfo_.key.shape->GetStorageShape();
    const auto &weightsShape = opParamInfo_.weights.shape->GetStorageShape();
    const auto &outShape = opParamInfo_.attenOut.shape->GetStorageShape();
    const gert::Tensor *actSeqQ = opParamInfo_.actualSeqLengthsQ.tensor;
    const gert::Tensor *actSeqK = opParamInfo_.actualSeqLengthsK.tensor;
    const gert::Tensor *blockTable = opParamInfo_.blockTable.tensor;
    const bool pagedKey = (kLayout_ == DataLayout::PA_BSND);

    const long long batch = static_cast<long long>(bSize_);
    // -1 stands for "tensor not provided"; it is only ever compared after a null check.
    const long long actSeqQLen = (actSeqQ != nullptr) ? static_cast<long long>(actSeqQ->GetShapeSize()) : -1;
    const long long actSeqKLen = (actSeqK != nullptr) ? static_cast<long long>(actSeqK->GetShapeSize()) : -1;
    const long long blockTableBatch =
        (blockTable != nullptr) ? static_cast<long long>(blockTable->GetStorageShape().GetDim(DIM_IDX_ZERO)) : -1;
    const bool blockTableMismatch = (blockTable != nullptr) && (blockTableBatch != batch);

    uint32_t queryWeightsN1Dim = DIM_IDX_ONE;
    uint32_t outN2Dim = DIM_IDX_ONE;

    if (qLayout_ == DataLayout::TND) {
        // ----------------------- check BatchSize -------------------
        // bSize_ comes from actual_seq_lengths_query.
        OPS_ERR_IF(actSeqK == nullptr,
                   OPS_LOG_E(opName_, "TND case input actual_seq_lengths_key must be provided."),
                   return ge::GRAPH_FAILED);
        OPS_ERR_IF(pagedKey && ((actSeqKLen != batch) || blockTableMismatch),
                   OPS_LOG_E(
                       opName_,
                       "TND case input actual_seq_lengths_query, actual_seq_lengths_key, block_table dim 0 are %u, %lld, %lld respectively, they must be same.",
                       bSize_, actSeqKLen, blockTableBatch),
                   return ge::GRAPH_FAILED);
        OPS_ERR_IF(!pagedKey && (actSeqKLen != batch),
                   OPS_LOG_E(
                       opName_,
                       "TND case input actual_seq_lengths_query, actual_seq_lengths_key, are %u, %lld respectively, they must be same.",
                       bSize_, actSeqKLen),
                   return ge::GRAPH_FAILED);

        // ----------------------- check T -------------------
        const long long queryT = static_cast<long long>(queryShape.GetDim(DIM_IDX_ZERO));
        const long long weightsT = static_cast<long long>(weightsShape.GetDim(DIM_IDX_ZERO));
        const long long outT = static_cast<long long>(outShape.GetDim(DIM_IDX_ZERO));
        OPS_ERR_IF((weightsT != queryT) || (outT != queryT),
                   OPS_LOG_E(opName_,
                             "TND case input query, weights, sparse_indices dim 0 are %lld, %lld, %lld respectively, they must be same.",
                             queryT, weightsT, outT),
                   return ge::GRAPH_FAILED);
    } else {
        // ----------------------- check BatchSize -------------------
        // bSize_ comes from the query shape.
        const long long weightsBatch = static_cast<long long>(weightsShape.GetDim(DIM_IDX_ZERO));
        const long long outBatch = static_cast<long long>(outShape.GetDim(DIM_IDX_ZERO));
        if (pagedKey) {
            OPS_ERR_IF(actSeqK == nullptr,
                       OPS_LOG_E(opName_,
                                 "BSND case input actual_seq_lengths_key must be provided when key layout is PA_BSND."),
                       return ge::GRAPH_FAILED);
            OPS_ERR_IF((weightsBatch != batch) || blockTableMismatch || (actSeqKLen != batch) || (outBatch != batch),
                       OPS_LOG_E(opName_,
                                 "BSND case input query, weight, actual_seq_lengths_key, block_table, sparse_indices dim 0 are %u, %lld, %lld, %lld, %lld respectively, they must be same.",
                                 bSize_, weightsBatch, actSeqKLen, blockTableBatch, outBatch),
                       return ge::GRAPH_FAILED);
        } else {
            // actual_seq_lengths_key may be absent here; it is only compared when present.
            const bool actSeqKMismatch = (actSeqK != nullptr) && (actSeqKLen != batch);
            OPS_ERR_IF((weightsBatch != batch) || actSeqKMismatch || (outBatch != batch),
                       OPS_LOG_E(opName_,
                                 "BSND case input query, weight, actual_seq_lengths_key, sparse_indices dim 0 are %u, %lld, %lld, %lld respectively, they must be same.",
                                 bSize_, weightsBatch, actSeqKLen, outBatch),
                       return ge::GRAPH_FAILED);
        }
        OPS_ERR_IF((actSeqQ != nullptr) && (actSeqQLen != batch),
                   OPS_LOG_E(
                       opName_,
                       "BSND case input query, actual_seq_lengths_query dim 0 are %u, %lld respectively, they must be same",
                       bSize_, actSeqQLen),
                   return ge::GRAPH_FAILED);

        // ----------------------- check S1 -------------------
        const long long s1 = static_cast<long long>(s1Size_);
        const long long weightsS1 = static_cast<long long>(weightsShape.GetDim(DIM_IDX_ONE));
        const long long outS1 = static_cast<long long>(outShape.GetDim(DIM_IDX_ONE));
        OPS_ERR_IF((weightsS1 != s1) || (outS1 != s1),
                   OPS_LOG_E(opName_,
                             "BSND case input query, weight, sparse_indices dim 1 are %u, %lld, %lld, they must be same.",
                             s1Size_, weightsS1, outS1),
                   return ge::GRAPH_FAILED);

        // In BSND the head axes of weights and output sit one position further right.
        queryWeightsN1Dim = DIM_IDX_TWO;
        outN2Dim = DIM_IDX_TWO;
    }

    // ----------------------- check N1 -------------------
    const long long weightsN1 = static_cast<long long>(weightsShape.GetDim(queryWeightsN1Dim));
    OPS_ERR_IF((weightsN1 != static_cast<long long>(n1Size_)),
               OPS_LOG_E(opName_, "input query, weight shape dim N1 must be same, but now are %u, %lld respectively.",
                         n1Size_, weightsN1),
               return ge::GRAPH_FAILED);

    // ----------------------- check D -------------------
    // A TND key has three dims, so its head dim is at index 2; other key layouts use index 3.
    const uint32_t keyHeadDimAxis = (kLayout_ == DataLayout::TND) ? DIM_IDX_TWO : DIM_IDX_THREE;
    const long long keyHeadDim = static_cast<long long>(keyShape.GetDim(keyHeadDimAxis));
    OPS_ERR_IF((keyHeadDim != static_cast<long long>(headDim_)),
               OPS_LOG_E(opName_, "input query, key shape last dim must be same, now are %u, %lld respectively.",
                         headDim_, keyHeadDim),
               return ge::GRAPH_FAILED);

    // ----------------------- check N2 -------------------
    OPS_ERR_IF((outShape.GetDim(outN2Dim) != n2Size_),
               OPS_LOG_E(opName_, "input query and output sparse_indices shape n2 dim must be same."),
               return ge::GRAPH_FAILED);

    // ----------------------- check sparse_count -------------------
    OPS_ERR_IF((outShape.GetDim(outN2Dim + 1) != *opParamInfo_.sparseCount),
               OPS_LOG_E(opName_, "output sparse_indices shape last dim must be same as attr sparse_count."),
               return ge::GRAPH_FAILED);

    return ge::GRAPH_SUCCESS;
}
|
||||
|
||||
/**
 * @brief Checks that each dequant-scale tensor matches its data tensor with the last axis dropped.
 *
 * query_dequant_scale must have one dim fewer than query and agree with it on
 * every leading dim; the same rule applies to key_dequant_scale and key.
 *
 * @return ge::GRAPH_SUCCESS when both scale shapes are consistent, otherwise ge::GRAPH_FAILED.
 */
ge::graphStatus LIQInfoParser::CheckScaleShape()
{
    const auto &queryDims = opParamInfo_.query.shape->GetStorageShape();
    const auto &keyDims = opParamInfo_.key.shape->GetStorageShape();
    const auto &queryScaleDims = opParamInfo_.query_dequant_scale.shape->GetStorageShape();
    const auto &keyScaleDims = opParamInfo_.key_dequant_scale.shape->GetStorageShape();

    const uint32_t queryRank = queryDims.GetDimNum();
    const uint32_t keyRank = keyDims.GetDimNum();
    const uint32_t queryScaleRank = queryScaleDims.GetDimNum();
    const uint32_t keyScaleRank = keyScaleDims.GetDimNum();

    // Rank checks: a scale has exactly one dim fewer than its tensor.
    OPS_ERR_IF(queryScaleRank != (queryRank - 1),
               OPS_LOG_E(opName_, "the dim num of query_dequant_scale's shape should be %u, but now is %u",
                         queryRank - 1, queryScaleRank),
               return ge::GRAPH_FAILED);
    OPS_ERR_IF(keyScaleRank != (keyRank - 1),
               OPS_LOG_E(opName_, "the dim num of key_dequant_scale's shape should be %u, but now is %u", keyRank - 1,
                         keyScaleRank),
               return ge::GRAPH_FAILED);

    // Leading dims of the query scale must equal those of the query.
    for (uint32_t axis = 0; axis < (queryRank - 1); ++axis) {
        const uint32_t scaleExtent = queryScaleDims.GetDim(axis);
        const uint32_t queryExtent = queryDims.GetDim(axis);
        OPS_ERR_IF(scaleExtent != queryExtent,
                   OPS_LOG_E(opName_, "query_dequant_scale's shape[%u] %u and query's shape[%u] %u is not same", axis,
                             scaleExtent, axis, queryExtent),
                   return ge::GRAPH_FAILED);
    }

    // Leading dims of the key scale must equal those of the key.
    for (uint32_t axis = 0; axis < (keyRank - 1); ++axis) {
        const uint32_t scaleExtent = keyScaleDims.GetDim(axis);
        const uint32_t keyExtent = keyDims.GetDim(axis);
        OPS_ERR_IF(scaleExtent != keyExtent,
                   OPS_LOG_E(opName_, "key_dequant_scale's shape[%u] %u and key's shape[%u] %u is not same", axis,
                             scaleExtent, axis, keyExtent),
                   return ge::GRAPH_FAILED);
    }

    return ge::GRAPH_SUCCESS;
}
|
||||
|
||||
/**
 * @brief Copies the parsed and validated parameters into @p liqInfo.
 *
 * Dereferences the sparseMode and sparseCount attribute pointers, so it must
 * only be called once those attributes are known to be present.
 *
 * @param liqInfo Destination structure consumed by the tiling step.
 */
void LIQInfoParser::GenerateInfo(LIQTilingInfo &liqInfo)
{
    // Context
    liqInfo.opName = opName_;
    liqInfo.platformInfo = platformInfo_;
    liqInfo.opParamInfo = opParamInfo_;
    liqInfo.socVersion = socVersion_;

    // Base sizes
    liqInfo.bSize = bSize_;
    liqInfo.n1Size = n1Size_;
    liqInfo.n2Size = n2Size_;
    liqInfo.s1Size = s1Size_;
    liqInfo.s2Size = s2Size_;
    liqInfo.gSize = gSize_;
    // Previously never assigned, which left the field at its default of 0.
    liqInfo.qkHeadDim = headDim_;

    // Data types
    liqInfo.inputQType = inputQType_;
    liqInfo.inputKType = inputKType_;
    liqInfo.outputType = outputType_;

    // Page attention
    liqInfo.blockSize = blockSize_;
    liqInfo.maxBlockNumPerBatch = maxBlockNumPerBatch_;
    liqInfo.pageAttentionFlag = (kLayout_ == DataLayout::PA_BSND);

    // Sparse attributes
    liqInfo.sparseMode = *opParamInfo_.sparseMode;
    liqInfo.sparseCount = *opParamInfo_.sparseCount;

    // Layouts
    liqInfo.inputQLayout = qLayout_;
    liqInfo.inputKLayout = kLayout_;
}
|
||||
|
||||
/**
 * @brief Runs every parsing and validation stage in order and fills @p liqInfo on success.
 *
 * Stages are evaluated left to right and stop at the first failure, so later
 * stages can rely on the members populated by earlier ones.
 *
 * @param liqInfo Output structure; only written when all stages succeed.
 * @return ge::GRAPH_SUCCESS when every stage passes, otherwise ge::GRAPH_FAILED.
 */
ge::graphStatus LIQInfoParser::ParseAndCheck(LIQTilingInfo &liqInfo)
{
    const auto failed = [](ge::graphStatus status) { return status != ge::GRAPH_SUCCESS; };

    // Stage 1: basic context and presence of required parameters.
    if (failed(GetOpName()) || failed(GetNpuInfo()) || failed(GetOpParaInfo()) ||
        failed(CheckRequiredParaExistence())) {
        return ge::GRAPH_FAILED;
    }

    // Stage 2: data types, layouts and optional inputs.
    if (failed(GetAndCheckInOutDataType()) || failed(GetQueryKeyAndOutLayout()) ||
        failed(GetAndCheckOptionalInput())) {
        return ge::GRAPH_FAILED;
    }

    // Stage 3: ranks and head counts.
    if (failed(CheckShapeDim()) || failed(GetN1Size()) || failed(GetAndCheckN2Size()) || failed(GetGSize())) {
        return ge::GRAPH_FAILED;
    }

    // Stage 4: batch, sequence and head-dim sizes.
    if (failed(GetBatchSize()) || failed(GetS1Size()) || failed(GetHeadDim()) || failed(GetS2Size())) {
        return ge::GRAPH_FAILED;
    }

    // Stage 5: cross-tensor consistency.
    if (failed(ValidateInputShapesMatch()) || failed(CheckScaleShape())) {
        return ge::GRAPH_FAILED;
    }

    GenerateInfo(liqInfo);

    return ge::GRAPH_SUCCESS;
}
|
||||
|
||||
// --------------------------TilingPrepare函数定义-------------------------------------
|
||||
/**
 * @brief Tiling-parse hook for LightningIndexerQuant.
 *
 * The context argument is unused and the function reports success unconditionally.
 *
 * @return ge::GRAPH_SUCCESS always.
 */
static ge::graphStatus TilingPrepareForLightningIndexerQuant(gert::TilingParseContext * /* context */)
{
    const ge::graphStatus status = ge::GRAPH_SUCCESS;
    return status;
}
|
||||
|
||||
// --------------------------LightningIndexerQuantTiling类成员函数定义-----------------------
|
||||
/**
 * @brief Produces block dim, workspace size, tiling data and tiling key for the kernel.
 *
 * Compared with a plain field copy, this validates everything it dereferences:
 * the incoming tiling info, the workspace-size array and the raw tiling-data
 * buffer. It also rejects an s2Size that cannot be represented in the uint32
 * tiling field instead of silently truncating it. Workspace is accumulated in
 * size_t, matching the type of the slot it is written to.
 *
 * @param tilingInfo Parsed operator parameters; must not be null.
 * @return ge::GRAPH_SUCCESS on success, otherwise ge::GRAPH_FAILED.
 */
ge::graphStatus LightningIndexerQuantTiling::DoTiling(LIQTilingInfo *tilingInfo)
{
    OPS_ERR_IF(tilingInfo == nullptr, OPS_LOG_E(context_->GetNodeName(), "tilingInfo is null."),
               return ge::GRAPH_FAILED);
    OPS_ERR_IF((tilingInfo->s2Size < 0) || (tilingInfo->s2Size > static_cast<int64_t>(UINT32_MAX)),
               OPS_LOG_E(context_->GetNodeName(), "s2Size %lld cannot be stored in the uint32 tiling field.",
                         static_cast<long long>(tilingInfo->s2Size)),
               return ge::GRAPH_FAILED);

    // ------------- set blockdim -----------------
    auto ascendcPlatform = platform_ascendc::PlatformAscendC(tilingInfo->platformInfo);
    uint32_t aivNum = ascendcPlatform.GetCoreNumAiv();
    uint32_t aicNum = ascendcPlatform.GetCoreNumAic();
    uint32_t blockDim = ascendcPlatform.CalcTschBlockDim(aivNum, aicNum, aivNum);
    context_->SetBlockDim(blockDim);

    // ------------- set workspace size -----------------
    constexpr uint32_t MM1_RES_ELEM_SIZE = 4;         // 4: fp32
    constexpr uint32_t DOUBLE_BUFFER = 2;             // double buffering
    constexpr uint32_t M_BASE_SIZE = 512;             // base block size along the M axis
    constexpr uint32_t S2_BASE_SIZE = 512;            // base block size along the S2 axis
    constexpr uint32_t V1_RES_ELEM_SIZE = 4;          // 4: int32
    constexpr uint32_t V1_RES_ELEM_TYPE = 2;          // two kinds of data are kept: index and value
    constexpr uint32_t V1_DECODE_PARAM_ELEM_SIZE = 8; // 8: int64
    constexpr uint32_t V1_DECODE_PARAM_NUM = 16;      // number of decode parameters
    constexpr uint32_t V1_DECODE_DATA_NUM = 2;        // each core stores a head and a tail block for decode
    constexpr uint32_t S1_BASE_SIZE = 8;              // base block size along the S1 axis
    constexpr uint32_t TOPK_MAX_SIZE = 2048;          // number of TopK selections

    size_t workspaceSize = ascendcPlatform.GetLibApiWorkSpaceSize();
    // Workspace needed by the main flow.
    const size_t mm1ResSize = static_cast<size_t>(M_BASE_SIZE) * S2_BASE_SIZE;
    workspaceSize += mm1ResSize * MM1_RES_ELEM_SIZE * DOUBLE_BUFFER * aicNum;
    // Workspace needed by the decode flow (LD).
    // Intermediate decode results: 2(head/tail)*8(s1Base)*2(idx/value)*2048(K)*sizeof(int32)*aicNum
    // (6M when aicNum is 24).
    workspaceSize += static_cast<size_t>(V1_DECODE_DATA_NUM) * S1_BASE_SIZE * V1_RES_ELEM_TYPE * TOPK_MAX_SIZE *
                     V1_RES_ELEM_SIZE * aicNum;
    // Intermediate decode parameters: 2(head/tail)*8(s1Base)*16(paramNum)*sizeof(int64_t)*aicNum
    // (48k when aicNum is 24).
    workspaceSize += static_cast<size_t>(V1_DECODE_DATA_NUM) * S1_BASE_SIZE * V1_DECODE_PARAM_NUM *
                     V1_DECODE_PARAM_ELEM_SIZE * aicNum;

    size_t *workSpaces = context_->GetWorkspaceSizes(1);
    OPS_ERR_IF(workSpaces == nullptr, OPS_LOG_E(context_->GetNodeName(), "workspace sizes pointer is null."),
               return ge::GRAPH_FAILED);
    workSpaces[0] = workspaceSize;

    // ------------- set tilingdata -----------------
    // NOTE(review): LIQTilingData also declares n2Size, which is not written here — confirm whether
    // the kernel reads it.
    tilingData_.set_bSize(tilingInfo->bSize);
    tilingData_.set_s2Size(static_cast<uint32_t>(tilingInfo->s2Size));
    tilingData_.set_s1Size(tilingInfo->s1Size);
    tilingData_.set_sparseCount(tilingInfo->sparseCount);
    tilingData_.set_gSize(tilingInfo->gSize);
    tilingData_.set_blockSize(tilingInfo->blockSize);
    tilingData_.set_maxBlockNumPerBatch(tilingInfo->maxBlockNumPerBatch);
    tilingData_.set_sparseMode(tilingInfo->sparseMode);
    tilingData_.set_usedCoreNum(blockDim);

    auto *rawTilingData = context_->GetRawTilingData();
    OPS_ERR_IF(rawTilingData == nullptr, OPS_LOG_E(context_->GetNodeName(), "raw tiling data is null."),
               return ge::GRAPH_FAILED);
    tilingData_.SaveToBuffer(rawTilingData->GetData(), rawTilingData->GetCapacity());
    rawTilingData->SetDataSize(tilingData_.GetDataSize());

    // ------------- set tilingkey -----------------
    // DT_Q, DT_KV, DT_OUT, PAGE_ATTENTION, FLASH_DECODE, LAYOUT_T, KV_LAYOUT_T
    uint32_t inputQType = static_cast<uint32_t>(tilingInfo->inputQType);
    uint32_t inputKType = static_cast<uint32_t>(tilingInfo->inputKType);
    uint32_t outputType = static_cast<uint32_t>(tilingInfo->outputType);
    uint32_t pageAttentionFlag = static_cast<uint32_t>(tilingInfo->pageAttentionFlag);
    uint32_t inputQLayout = static_cast<uint32_t>(tilingInfo->inputQLayout);
    uint32_t inputKLayout = static_cast<uint32_t>(tilingInfo->inputKLayout);
    uint32_t tilingKey =
        GET_TPL_TILING_KEY(inputQType, inputKType, outputType, pageAttentionFlag, inputQLayout, inputKLayout);
    context_->SetTilingKey(tilingKey);

    return ge::GRAPH_SUCCESS;
}
|
||||
|
||||
// --------------------------Tiling函数定义---------------------------
|
||||
/**
 * @brief Tiling entry point: parses and validates the operator inputs, then performs tiling.
 *
 * @param context Tiling context supplied by the framework; a null pointer is reported as an error.
 * @return ge::GRAPH_SUCCESS when parsing and tiling both succeed, otherwise ge::GRAPH_FAILED.
 */
ge::graphStatus TilingForLightningIndexerQuant(gert::TilingContext *context)
{
    OPS_ERR_IF(context == nullptr, OPS_REPORT_VECTOR_INNER_ERR("LightningIndexerQuant", "Tiling context is null."),
               return ge::GRAPH_FAILED);

    // Step 1: parse and validate every input, output and attribute.
    LIQTilingInfo parsedInfo;
    LIQInfoParser parser(context);
    const ge::graphStatus parseStatus = parser.ParseAndCheck(parsedInfo);
    if (parseStatus != ge::GRAPH_SUCCESS) {
        return ge::GRAPH_FAILED;
    }

    // Step 2: turn the parsed info into block dim, workspace, tiling data and tiling key.
    LightningIndexerQuantTiling tiler(context);
    return tiler.DoTiling(&parsedInfo);
}
|
||||
|
||||
// --------------------------Tiling及函数TilingPrepare函数注册--------
|
||||
// Registers the tiling function and the tiling-parse hook for the LightningIndexerQuant operator.
IMPL_OP_OPTILING(LightningIndexerQuant).Tiling(TilingForLightningIndexerQuant).TilingParse<LIQCompileInfo>(TilingPrepareForLightningIndexerQuant);
|
||||
|
||||
} // namespace optiling
|
||||
@@ -0,0 +1,234 @@
|
||||
/**
|
||||
* This program is free software, you can redistribute it and/or modify it.
|
||||
* Copyright (c) 2025 Huawei Technologies Co., Ltd.
|
||||
* This file is a part of the CANN Open Software.
|
||||
* Licensed under CANN Open Software License Agreement Version 2.0 (the "License").
|
||||
* Please refer to the License for details. You may not use this file except in compliance with the License.
|
||||
* THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE.
|
||||
* See LICENSE in the root of the software repository for the full text of the License.
|
||||
*/
|
||||
|
||||
/*!
 * \file lightning_indexer_quant_tiling.h
 * \brief Tiling data, parameter structures and parser/tiling class declarations
 *        for the LightningIndexerQuant operator.
 */
|
||||
|
||||
#ifndef LIGHTNING_INDEXER_QUANT_TILING_H_
|
||||
#define LIGHTNING_INDEXER_QUANT_TILING_H_
|
||||
|
||||
#include <cstdint>
#include <string>

#include "error/ops_error.h"
#include "exe_graph/runtime/tiling_context.h"
#include "platform/platform_info.h"
#include "register/op_def_registry.h"
#include "register/tilingdata_base.h"
#include "tiling/platform/platform_ascendc.h"
#include "tiling/tiling_api.h"
|
||||
|
||||
namespace optiling {
|
||||
// ------------------公共定义--------------------------
|
||||
struct TilingRequiredParaInfo {
|
||||
const gert::CompileTimeTensorDesc *desc;
|
||||
const gert::StorageShape *shape;
|
||||
};
|
||||
|
||||
struct TilingOptionalParaInfo {
|
||||
const gert::CompileTimeTensorDesc *desc;
|
||||
const gert::Tensor *tensor;
|
||||
};
|
||||
|
||||
// Tensor layouts handled by the tiling code. The numeric values are cast to uint32_t
// when the tiling key is built, so they must stay stable.
enum class DataLayout : uint32_t {
    BSND = 0,     // [Batch, SeqLen, N, D]
    TND = 1,      // [Total, N, D]
    PA_BSND = 2,  // paged key: [BlockNum, BlockSize, N, D]
};
|
||||
|
||||
// ------------------算子原型索引常量定义----------------
|
||||
// Input indices
constexpr uint32_t QUERY_INDEX = 0U;
constexpr uint32_t KEY_INDEX = 1U;
constexpr uint32_t WEIGTHS_INDEX = 2U;
constexpr uint32_t QUERY_DEQUANT_SCALE_INDEX = 3U;
constexpr uint32_t KEY_DEQUANT_SCALE_INDEX = 4U;
constexpr uint32_t ACTUAL_SEQ_Q_INDEX = 5U;
constexpr uint32_t ACTUAL_SEQ_K_INDEX = 6U;
constexpr uint32_t BLOCK_TABLE_INDEX = 7U;
constexpr uint32_t LIGHTNING_INDEXER_QUANT = 0U;

// Attribute indices
constexpr uint32_t ATTR_QUERY_QUANT_MODE_INDEX = 0U;
constexpr uint32_t ATTR_KEY_QUANT_MODE_INDEX = 1U;
constexpr uint32_t ATTR_QUERY_LAYOUT_INDEX = 2U;
constexpr uint32_t ATTR_KEY_LAYOUT_INDEX = 3U;
constexpr uint32_t ATTR_SPARSE_COUNT_INDEX = 4U;
constexpr uint32_t ATTR_SPARSE_MODE_INDEX = 5U;

// Dimension indices
constexpr uint32_t DIM_IDX_ZERO = 0U;
constexpr uint32_t DIM_IDX_ONE = 1U;
constexpr uint32_t DIM_IDX_TWO = 2U;
constexpr uint32_t DIM_IDX_THREE = 3U;

// Dimension counts
constexpr uint32_t DIM_NUM_TWO = 2U;
constexpr uint32_t DIM_NUM_THREE = 3U;
constexpr uint32_t DIM_NUM_FOUR = 4U;

// Limits on input parameters
constexpr uint32_t HEAD_DIM_LIMIT = 128U;
constexpr uint32_t SPARSE_LIMIT = 2048U;
constexpr uint32_t G_SIZE_LIMIT = 64U;
constexpr uint32_t BLOCK_SIZE_LIMIT = 1024U;
constexpr uint32_t BLOCK_SIZE_FACTOR = 16U;
constexpr uint32_t SPARSE_MODE_LOWER = 3U;
|
||||
|
||||
// -----------算子TilingData定义---------------
|
||||
// Tiling data layout for the operator. Field order is kept exactly as declared.
BEGIN_TILING_DATA_DEF(LIQTilingData)
TILING_DATA_FIELD_DEF(uint32_t, bSize)                // batch size
TILING_DATA_FIELD_DEF(uint32_t, n2Size)               // key head count
TILING_DATA_FIELD_DEF(uint32_t, gSize)                // query heads per key head
TILING_DATA_FIELD_DEF(uint32_t, s1Size)               // query sequence length
TILING_DATA_FIELD_DEF(uint32_t, s2Size)               // key sequence length
TILING_DATA_FIELD_DEF(uint32_t, sparseCount)          // sparse_count attribute
TILING_DATA_FIELD_DEF(uint32_t, usedCoreNum)          // block dim chosen during tiling
TILING_DATA_FIELD_DEF(uint32_t, blockSize)            // page-attention block size
TILING_DATA_FIELD_DEF(uint32_t, maxBlockNumPerBatch)  // dim 1 of block_table
TILING_DATA_FIELD_DEF(uint32_t, sparseMode)           // sparse_mode attribute
END_TILING_DATA_DEF
REGISTER_TILING_DATA_CLASS(LightningIndexerQuant, LIQTilingData)
|
||||
|
||||
// -----------算子CompileInfo定义-------------------
|
||||
// Compile-info type supplied to TilingParse<> when the operator is registered; it declares no fields.
struct LIQCompileInfo {};
|
||||
|
||||
// -----------算子Tiling入参结构体定义---------------
|
||||
struct LIQParaInfo {
|
||||
TilingRequiredParaInfo query = {nullptr, nullptr};
|
||||
TilingRequiredParaInfo key = {nullptr, nullptr};
|
||||
TilingRequiredParaInfo weights = {nullptr, nullptr};
|
||||
TilingRequiredParaInfo query_dequant_scale = {nullptr, nullptr};
|
||||
TilingRequiredParaInfo key_dequant_scale = {nullptr, nullptr};
|
||||
TilingOptionalParaInfo actualSeqLengthsQ = {nullptr, nullptr};
|
||||
TilingOptionalParaInfo actualSeqLengthsK = {nullptr, nullptr};
|
||||
TilingOptionalParaInfo blockTable = {nullptr, nullptr};
|
||||
TilingRequiredParaInfo attenOut = {nullptr, nullptr};
|
||||
|
||||
const int32_t *queryQuantMode = nullptr;
|
||||
const int32_t *keyQuantMode = nullptr;
|
||||
const char *layOutQuery = nullptr;
|
||||
const char *layOutKey = nullptr;
|
||||
const int32_t *blockSize = nullptr;
|
||||
const int32_t *sparseMode = nullptr;
|
||||
const int32_t *sparseCount = nullptr;
|
||||
};
|
||||
|
||||
// -----------算子Tiling入参信息类---------------
|
||||
class LIQTilingInfo {
|
||||
public:
|
||||
const char *opName = nullptr;
|
||||
fe::PlatFormInfos *platformInfo = nullptr;
|
||||
LIQParaInfo opParamInfo;
|
||||
// Base Param
|
||||
platform_ascendc::SocVersion socVersion = platform_ascendc::SocVersion::ASCEND910B;
|
||||
uint32_t bSize = 0;
|
||||
uint32_t n1Size = 0;
|
||||
uint32_t n2Size = 0;
|
||||
uint32_t s1Size = 0;
|
||||
int64_t s2Size = 0;
|
||||
uint32_t qkHeadDim = 0;
|
||||
uint32_t gSize = 0;
|
||||
// PageAttention
|
||||
bool pageAttentionFlag = false;
|
||||
int32_t blockSize = 0;
|
||||
uint32_t maxBlockNumPerBatch = 0;
|
||||
// Mask
|
||||
int32_t sparseMode = 0;
|
||||
// Others Flag
|
||||
uint32_t sparseCount = 0;
|
||||
// DType
|
||||
ge::DataType inputQType = ge::DT_FLOAT16;
|
||||
ge::DataType inputKType = ge::DT_FLOAT16;
|
||||
ge::DataType outputType = ge::DT_INT32;
|
||||
// Layout
|
||||
DataLayout inputQLayout = DataLayout::BSND;
|
||||
DataLayout inputKLayout = DataLayout::PA_BSND;
|
||||
};
|
||||
|
||||
// -----------算子Tiling入参信息解析及Check类---------------
|
||||
class LIQInfoParser {
|
||||
public:
|
||||
explicit LIQInfoParser(gert::TilingContext *context) : context_(context) {}
|
||||
~LIQInfoParser() = default;
|
||||
|
||||
ge::graphStatus CheckRequiredInOutExistence() const;
|
||||
ge::graphStatus CheckRequiredAttrExistence() const;
|
||||
ge::graphStatus CheckRequiredParaExistence() const;
|
||||
ge::graphStatus GetActualSeqLenSize(uint32_t &size, const gert::Tensor *tensor,
|
||||
const std::string &actualSeqLenName);
|
||||
ge::graphStatus GetOpName();
|
||||
ge::graphStatus GetNpuInfo();
|
||||
void GetOptionalInputParaInfo();
|
||||
void GetInputParaInfo();
|
||||
void GetOutputParaInfo();
|
||||
ge::graphStatus GetAttrParaInfo();
|
||||
ge::graphStatus CheckAttrParaInfo();
|
||||
ge::graphStatus GetOpParaInfo();
|
||||
ge::graphStatus ValidateInputShapesMatch();
|
||||
ge::graphStatus CheckScaleShape();
|
||||
ge::graphStatus GetAndCheckInOutDataType();
|
||||
ge::graphStatus GetBatchSize();
|
||||
ge::graphStatus GetHeadDim();
|
||||
ge::graphStatus GetS1Size();
|
||||
ge::graphStatus GetAndCheckOptionalInput();
|
||||
ge::graphStatus CheckShapeDim();
|
||||
ge::graphStatus GetAndCheckBlockSize();
|
||||
ge::graphStatus GetS2SizeForPageAttention();
|
||||
ge::graphStatus GetS2SizeForBatchContinuous();
|
||||
ge::graphStatus GetS2Size();
|
||||
ge::graphStatus GetQueryKeyAndOutLayout();
|
||||
ge::graphStatus GetN1Size();
|
||||
ge::graphStatus GetAndCheckN2Size();
|
||||
ge::graphStatus GetGSize();
|
||||
ge::graphStatus GetAttenMaskInfo();
|
||||
ge::graphStatus GetActualSeqInfo();
|
||||
void GenerateInfo(LIQTilingInfo &liqInfo);
|
||||
ge::graphStatus ParseAndCheck(LIQTilingInfo &liqInfo);
|
||||
|
||||
public:
|
||||
gert::TilingContext *context_ = nullptr;
|
||||
const char *opName_;
|
||||
fe::PlatFormInfos *platformInfo_;
|
||||
LIQParaInfo opParamInfo_;
|
||||
|
||||
// BaseParams
|
||||
uint32_t bSize_ = 0;
|
||||
uint32_t n1Size_ = 0;
|
||||
uint32_t n2Size_ = 0;
|
||||
uint32_t gSize_ = 0;
|
||||
uint32_t s1Size_ = 0;
|
||||
int64_t s2Size_ = 0;
|
||||
uint32_t headDim_ = 0;
|
||||
// Layout
|
||||
DataLayout qLayout_ = DataLayout::BSND;
|
||||
DataLayout kLayout_ = DataLayout::PA_BSND;
|
||||
// PageAttention
|
||||
uint32_t maxBlockNumPerBatch_ = 0;
|
||||
int32_t blockSize_ = 0;
|
||||
platform_ascendc::SocVersion socVersion_ = platform_ascendc::SocVersion::ASCEND910B;
|
||||
ge::DataType inputQType_ = ge::DT_FLOAT16;
|
||||
ge::DataType inputKType_ = ge::DT_FLOAT16;
|
||||
ge::DataType weightsType_ = ge::DT_FLOAT16;
|
||||
ge::DataType inputQueryScaleType_ = ge::DT_FLOAT16;
|
||||
ge::DataType inputKeyScaleType_ = ge::DT_FLOAT16;
|
||||
ge::DataType blockTableType_ = ge::DT_FLOAT16;
|
||||
ge::DataType inputKRopeType_ = ge::DT_FLOAT16;
|
||||
ge::DataType outputType_ = ge::DT_FLOAT16;
|
||||
};
|
||||
|
||||
// ---------------算子Tiling类---------------
|
||||
class LightningIndexerQuantTiling {
|
||||
public:
|
||||
explicit LightningIndexerQuantTiling(gert::TilingContext *context) : context_(context) {};
|
||||
ge::graphStatus DoTiling(LIQTilingInfo *tilingInfo);
|
||||
|
||||
private:
|
||||
gert::TilingContext *context_ = nullptr;
|
||||
LIQTilingData tilingData_;
|
||||
};
|
||||
|
||||
} // namespace optiling
|
||||
#endif // LIGHTNING_INDEXER_QUANT_TILING_H_
|
||||
Reference in New Issue
Block a user