[kernel] add AscendC op: lightning_indexer and sparse_flash_attention (#4625)
### What this PR does / why we need it? Provide high-performance AscendC operators lightning_indexer and sparse_flash_attention to boost the execution performance of the DeepSeek v3.2 model. Meanwhile, adapt the two AscendC operators to vllm-ascend framework. ### Does this PR introduce _any_ user-facing change? No (only underlying operator optimizations, with no user-facing changes) ### How was this patch tested? - vLLM version: v0.11.2 - vLLM main: https://github.com/vllm-project/vllm/commit/v0.11.2 Signed-off-by: MingYang119 <songmingyang@huawei.com>
This commit is contained in:
@@ -11,11 +11,11 @@ if [[ "$SOC_VERSION" =~ ^ascend310 ]]; then
|
||||
exit 0
|
||||
elif [[ "$SOC_VERSION" =~ ^ascend910b ]]; then
|
||||
# ASCEND910B (A2) series
|
||||
CUSTOM_OPS="grouped_matmul_swiglu_quant_weight_nz_tensor_list"
|
||||
CUSTOM_OPS="grouped_matmul_swiglu_quant_weight_nz_tensor_list;lightning_indexer;sparse_flash_attention"
|
||||
SOC_ARG="ascend910b"
|
||||
elif [[ "$SOC_VERSION" =~ ^ascend910_93 ]]; then
|
||||
# ASCEND910C (A3) series
|
||||
CUSTOM_OPS="grouped_matmul_swiglu_quant_weight_nz_tensor_list"
|
||||
CUSTOM_OPS="grouped_matmul_swiglu_quant_weight_nz_tensor_list;lightning_indexer;sparse_flash_attention"
|
||||
SOC_ARG="ascend910_93"
|
||||
else
|
||||
# others
|
||||
|
||||
42
csrc/lightning_indexer/op_host/CMakeLists.txt
Normal file
42
csrc/lightning_indexer/op_host/CMakeLists.txt
Normal file
@@ -0,0 +1,42 @@
|
||||
# This program is free software, you can redistribute it and/or modify it.
|
||||
# Copyright (c) 2025 Huawei Technologies Co., Ltd.
|
||||
# This file is a part of the CANN Open Software.
|
||||
# Licensed under CANN Open Software License Agreement Version 2.0 (the "License").
|
||||
# Please refer to the License for details. You may not use this file except in compliance with the License.
|
||||
# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE.
|
||||
# See LICENSE in the root of the software repository for the full text of the License.
|
||||
# ======================================================================================================================
|
||||
|
||||
add_ops_compile_options(
|
||||
OP_NAME LightningIndexer
|
||||
OPTIONS --cce-auto-sync=off
|
||||
-Wno-deprecated-declarations
|
||||
-Werror
|
||||
-mllvm -cce-aicore-hoist-movemask=false
|
||||
--op_relocatable_kernel_binary=true
|
||||
)
|
||||
|
||||
set(lightning_indexer_depends transformer/attention/lightning_indexer PARENT_SCOPE)
|
||||
|
||||
target_sources(op_host_aclnn PRIVATE
|
||||
lightning_indexer_def.cpp
|
||||
)
|
||||
|
||||
target_sources(optiling PRIVATE
|
||||
lightning_indexer_tiling.cpp
|
||||
)
|
||||
|
||||
if (NOT BUILD_OPEN_PROJECT)
|
||||
target_sources(opmaster_ct PRIVATE
|
||||
lightning_indexer_tiling.cpp
|
||||
)
|
||||
endif ()
|
||||
|
||||
target_include_directories(optiling PRIVATE
|
||||
${CMAKE_CURRENT_SOURCE_DIR}
|
||||
)
|
||||
|
||||
target_sources(opsproto PRIVATE
|
||||
lightning_indexer_proto.cpp
|
||||
)
|
||||
|
||||
72
csrc/lightning_indexer/op_host/lightning_indexer_def.cpp
Normal file
72
csrc/lightning_indexer/op_host/lightning_indexer_def.cpp
Normal file
@@ -0,0 +1,72 @@
|
||||
/**
|
||||
* This program is free software, you can redistribute it and/or modify it.
|
||||
* Copyright (c) 2025 Huawei Technologies Co., Ltd.
|
||||
* This file is a part of the CANN Open Software.
|
||||
* Licensed under CANN Open Software License Agreement Version 2.0 (the "License").
|
||||
* Please refer to the License for details. You may not use this file except in compliance with the License.
|
||||
* THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE.
|
||||
* See LICENSE in the root of the software repository for the full text of the License.
|
||||
*/
|
||||
|
||||
/*!
|
||||
* \file lightning_indexer_def.cpp
|
||||
* \brief
|
||||
*/
|
||||
#include <cstdint>
|
||||
#include "register/op_def_registry.h"
|
||||
|
||||
namespace ops {
|
||||
class LightningIndexer : public OpDef {
|
||||
public:
|
||||
explicit LightningIndexer(const char *name) : OpDef(name)
|
||||
{
|
||||
this->Input("query")
|
||||
.ParamType(REQUIRED)
|
||||
.DataType({ge::DT_BF16, ge::DT_FLOAT16})
|
||||
.FormatList({ge::FORMAT_ND})
|
||||
.AutoContiguous();
|
||||
this->Input("key")
|
||||
.ParamType(REQUIRED)
|
||||
.DataType({ge::DT_BF16, ge::DT_FLOAT16})
|
||||
.FormatList({ge::FORMAT_ND})
|
||||
.AutoContiguous();
|
||||
this->Input("weights")
|
||||
.ParamType(REQUIRED)
|
||||
.DataType({ge::DT_BF16, ge::DT_FLOAT16})
|
||||
.FormatList({ge::FORMAT_ND})
|
||||
.AutoContiguous();
|
||||
this->Input("actual_seq_lengths_query")
|
||||
.ParamType(OPTIONAL)
|
||||
.DataType({ge::DT_INT32, ge::DT_INT32})
|
||||
.FormatList({ge::FORMAT_ND})
|
||||
.AutoContiguous();
|
||||
this->Input("actual_seq_lengths_key")
|
||||
.ParamType(OPTIONAL)
|
||||
.DataType({ge::DT_INT32, ge::DT_INT32})
|
||||
.FormatList({ge::FORMAT_ND})
|
||||
.AutoContiguous();
|
||||
this->Input("block_table")
|
||||
.ParamType(OPTIONAL)
|
||||
.DataTypeList({ge::DT_INT32})
|
||||
.FormatList({ge::FORMAT_ND})
|
||||
.AutoContiguous();
|
||||
this->Output("sparse_indices").ParamType(REQUIRED).DataTypeList({ge::DT_INT32}).FormatList({ge::FORMAT_ND});
|
||||
this->Attr("layout_query").AttrType(OPTIONAL).String("BSND");
|
||||
this->Attr("layout_key").AttrType(OPTIONAL).String("PA_BSND");
|
||||
this->Attr("sparse_count").AttrType(OPTIONAL).Int(2048); // 2048: Default value, filter the top 2048
|
||||
this->Attr("sparse_mode").AttrType(OPTIONAL).Int(3); // 3: Default value, only calculate the lower triangular matrix
|
||||
OpAICoreConfig aicore_config;
|
||||
aicore_config.DynamicCompileStaticFlag(true)
|
||||
.DynamicFormatFlag(true)
|
||||
.DynamicRankSupportFlag(true)
|
||||
.DynamicShapeSupportFlag(true)
|
||||
.NeedCheckSupportFlag(false)
|
||||
.PrecisionReduceFlag(true)
|
||||
.ExtendCfgInfo("aclnnSupport.value", "support_aclnn")
|
||||
.ExtendCfgInfo("jitCompile.flag", "static_false,dynamic_false");
|
||||
this->AICore().AddConfig("ascend910b", aicore_config);
|
||||
this->AICore().AddConfig("ascend910_93", aicore_config);
|
||||
}
|
||||
};
|
||||
OP_ADD(LightningIndexer);
|
||||
} // namespace ops
|
||||
96
csrc/lightning_indexer/op_host/lightning_indexer_proto.cpp
Normal file
96
csrc/lightning_indexer/op_host/lightning_indexer_proto.cpp
Normal file
@@ -0,0 +1,96 @@
|
||||
/**
|
||||
* This program is free software, you can redistribute it and/or modify it.
|
||||
* Copyright (c) 2025 Huawei Technologies Co., Ltd.
|
||||
* This file is a part of the CANN Open Software.
|
||||
* Licensed under CANN Open Software License Agreement Version 2.0 (the "License").
|
||||
* Please refer to the License for details. You may not use this file except in compliance with the License.
|
||||
* THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE.
|
||||
* See LICENSE in the root of the software repository for the full text of the License.
|
||||
*/
|
||||
|
||||
/*!
|
||||
* \file lightning_indexer_proto.cpp
|
||||
* \brief
|
||||
*/
|
||||
#include <graph/utils/type_utils.h>
|
||||
#include <register/op_impl_registry.h>
|
||||
#include "error/ops_error.h"
|
||||
|
||||
|
||||
using namespace ge;
|
||||
|
||||
namespace ops {
|
||||
constexpr uint32_t QUERY_INDEX = 0;
|
||||
constexpr uint32_t KEY_INDEX = 1;
|
||||
constexpr uint32_t ACTUAL_SEQ_K_INDEX = 4;
|
||||
constexpr uint32_t ATTR_QUERY_LAYOUT_INDEX = 0;
|
||||
constexpr uint32_t ATTR_KEY_LAYOUT_INDEX = 1;
|
||||
constexpr uint32_t ATTR_SPARSE_COUNT_INDEX = 2;
|
||||
|
||||
static ge::graphStatus InferShapeLightningIndexer(gert::InferShapeContext *context)
|
||||
{
|
||||
OPS_ERR_IF(context == nullptr, OPS_LOG_E("LightningIndexer", "InferShapeContext is nullptr!"),
|
||||
return ge::GRAPH_FAILED);
|
||||
const gert::Shape *queryShape = context->GetInputShape(QUERY_INDEX);
|
||||
OPS_LOG_E_IF_NULL(context, queryShape, return ge::GRAPH_FAILED);
|
||||
const gert::Shape *keyShape = context->GetInputShape(KEY_INDEX);
|
||||
OPS_LOG_E_IF_NULL(context, keyShape, return ge::GRAPH_FAILED);
|
||||
gert::Shape *outShape = context->GetOutputShape(0);
|
||||
|
||||
auto attrs = context->GetAttrs();
|
||||
OPS_LOG_E_IF_NULL(context, attrs, return ge::GRAPH_FAILED);
|
||||
const char *inputLayoutQueryPtr = attrs->GetAttrPointer<char>(ATTR_QUERY_LAYOUT_INDEX);
|
||||
OPS_LOG_E_IF_NULL(context, inputLayoutQueryPtr, return ge::GRAPH_FAILED);
|
||||
const char *inputLayoutKeyPtr = attrs->GetAttrPointer<char>(ATTR_KEY_LAYOUT_INDEX);
|
||||
OPS_LOG_E_IF_NULL(context, inputLayoutKeyPtr, return ge::GRAPH_FAILED);
|
||||
const int64_t *seleced_count = attrs->GetInt(ATTR_SPARSE_COUNT_INDEX);
|
||||
OPS_LOG_E_IF_NULL(context, seleced_count, return ge::GRAPH_FAILED);
|
||||
std::string inputLayoutQueryPtrStr = std::string(inputLayoutQueryPtr);
|
||||
std::string inputLayoutKeyPtrStr = std::string(inputLayoutKeyPtr);
|
||||
OPS_ERR_IF(
|
||||
inputLayoutQueryPtrStr != "TND" && inputLayoutQueryPtrStr != "BSND",
|
||||
OPS_LOG_E(context, "The attr layout_query should be TND or BSND, but got %s.", inputLayoutQueryPtrStr.c_str()),
|
||||
return ge::GRAPH_FAILED);
|
||||
|
||||
outShape->SetDimNum(queryShape->GetDimNum());
|
||||
if (inputLayoutQueryPtrStr == "BSND") {
|
||||
OPS_ERR_IF(
|
||||
queryShape->GetDimNum() != 4,
|
||||
OPS_LOG_E(context, "Layout BSND, queryDims (%zu) must be 4!", queryShape->GetDimNum()),
|
||||
return ge::GRAPH_FAILED);
|
||||
outShape->SetDim(0, queryShape->GetDim(0)); // 0:Dim B
|
||||
outShape->SetDim(1, queryShape->GetDim(1)); // 1:Dim S
|
||||
outShape->SetDim(2, keyShape->GetDim(2)); // 2:Dim N
|
||||
outShape->SetDim(3, *seleced_count); // 3:Dim K
|
||||
} else {
|
||||
OPS_ERR_IF(
|
||||
queryShape->GetDimNum() != 3,
|
||||
OPS_LOG_E(context, "Layout TND, queryDims (%zu) must be 3!", queryShape->GetDimNum()),
|
||||
return ge::GRAPH_FAILED);
|
||||
outShape->SetDim(0, queryShape->GetDim(0)); // 0:Dim T
|
||||
int32_t nDimIndex = (inputLayoutKeyPtrStr == "PA_BSND") ? 2 : 1; // 2:Key Dim N
|
||||
outShape->SetDim(1, keyShape->GetDim(nDimIndex)); // 1:Dim N
|
||||
outShape->SetDim(2, *seleced_count); // 2:Dim K
|
||||
}
|
||||
OPS_LOG_D(context->GetNodeName(), "LightningIndexer InferShape end.");
|
||||
|
||||
return ge::GRAPH_SUCCESS;
|
||||
}
|
||||
|
||||
static ge::graphStatus InferDataTypeLightningIndexer(gert::InferDataTypeContext *context)
|
||||
{
|
||||
OPS_ERR_IF(context == nullptr, OPS_LOG_E("LightningIndexer", "InferDataTypeContext is nullptr!"),
|
||||
return ge::GRAPH_FAILED);
|
||||
OPS_LOG_D(context->GetNodeName(), "Enter LightningIndexer InferDataType impl.");
|
||||
// default set q's dtype as fia's output type
|
||||
ge::DataType outputType = ge::DT_INT32;
|
||||
// attention_out, outidx:0
|
||||
context->SetOutputDataType(0, outputType);
|
||||
OPS_LOG_D(context->GetNodeName(), "LightningIndexer InferDataType end.");
|
||||
return GRAPH_SUCCESS;
|
||||
}
|
||||
|
||||
IMPL_OP_INFERSHAPE(LightningIndexer)
|
||||
.InferShape(InferShapeLightningIndexer)
|
||||
.InferDataType(InferDataTypeLightningIndexer);
|
||||
} // namespace ops
|
||||
694
csrc/lightning_indexer/op_host/lightning_indexer_tiling.cpp
Normal file
694
csrc/lightning_indexer/op_host/lightning_indexer_tiling.cpp
Normal file
@@ -0,0 +1,694 @@
|
||||
/**
|
||||
* This program is free software, you can redistribute it and/or modify it.
|
||||
* Copyright (c) 2025 Huawei Technologies Co., Ltd.
|
||||
* This file is a part of the CANN Open Software.
|
||||
* Licensed under CANN Open Software License Agreement Version 2.0 (the "License").
|
||||
* Please refer to the License for details. You may not use this file except in compliance with the License.
|
||||
* THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE.
|
||||
* See LICENSE in the root of the software repository for the full text of the License.
|
||||
*/
|
||||
|
||||
/*!
|
||||
* \file lightning_indexer_tiling.cpp
|
||||
* \brief
|
||||
*/
|
||||
|
||||
#include "lightning_indexer_tiling.h"
|
||||
#include "../op_kernel/lightning_indexer_template_tiling_key.h"
|
||||
|
||||
using namespace ge;
|
||||
using namespace AscendC;
|
||||
using std::map;
|
||||
using std::string;
|
||||
namespace optiling {
|
||||
ge::graphStatus LIInfoParser::CheckRequiredInOutExistence() const
|
||||
{
|
||||
OPS_ERR_IF(opParamInfo_.query.shape == nullptr, OPS_LOG_E(opName_, "Shape of tensor query is nullptr"),
|
||||
return ge::GRAPH_FAILED);
|
||||
OPS_ERR_IF(opParamInfo_.query.desc == nullptr, OPS_LOG_E(opName_, "Desc of tensor query is nullptr"),
|
||||
return ge::GRAPH_FAILED);
|
||||
OPS_ERR_IF(opParamInfo_.key.shape == nullptr, OPS_LOG_E(opName_, "Shape of tensor k is nullptr"),
|
||||
return ge::GRAPH_FAILED);
|
||||
OPS_ERR_IF(opParamInfo_.key.desc == nullptr, OPS_LOG_E(opName_, "Desc of tensor k is nullptr"),
|
||||
return ge::GRAPH_FAILED);
|
||||
OPS_ERR_IF(opParamInfo_.weights.shape == nullptr, OPS_LOG_E(opName_, "Shape of tensor value is nullptr"),
|
||||
return ge::GRAPH_FAILED);
|
||||
OPS_ERR_IF(opParamInfo_.weights.desc == nullptr, OPS_LOG_E(opName_, "Desc of tensor value is nullptr"),
|
||||
return ge::GRAPH_FAILED);
|
||||
OPS_ERR_IF(opParamInfo_.attenOut.shape == nullptr, OPS_LOG_E(opName_, "Shape of tensor output is nullptr"),
|
||||
return ge::GRAPH_FAILED);
|
||||
OPS_ERR_IF(opParamInfo_.attenOut.desc == nullptr, OPS_LOG_E(opName_, "Desc of tensor output is nullptr"),
|
||||
return ge::GRAPH_FAILED);
|
||||
|
||||
return ge::GRAPH_SUCCESS;
|
||||
}
|
||||
|
||||
ge::graphStatus LIInfoParser::CheckRequiredAttrExistence() const
|
||||
{
|
||||
OPS_ERR_IF(opParamInfo_.layOut == nullptr, OPS_LOG_E(opName_, "attr layout_query is nullptr"),
|
||||
return ge::GRAPH_FAILED);
|
||||
|
||||
OPS_ERR_IF(opParamInfo_.layOutKey == nullptr, OPS_LOG_E(opName_, "attr layout_key is nullptr"),
|
||||
return ge::GRAPH_FAILED);
|
||||
|
||||
OPS_ERR_IF(opParamInfo_.sparseCount == nullptr, OPS_LOG_E(opName_, "attr sparse_count is nullptr"),
|
||||
return ge::GRAPH_FAILED);
|
||||
|
||||
OPS_ERR_IF(opParamInfo_.sparseMode == nullptr, OPS_LOG_E(opName_, "attr sparse_mode is nullptr"),
|
||||
return ge::GRAPH_FAILED);
|
||||
|
||||
return ge::GRAPH_SUCCESS;
|
||||
}
|
||||
|
||||
ge::graphStatus LIInfoParser::CheckRequiredParaExistence() const
|
||||
{
|
||||
if (CheckRequiredInOutExistence() != ge::GRAPH_SUCCESS || CheckRequiredAttrExistence() != ge::GRAPH_SUCCESS) {
|
||||
return ge::GRAPH_FAILED;
|
||||
}
|
||||
|
||||
return ge::GRAPH_SUCCESS;
|
||||
}
|
||||
|
||||
ge::graphStatus LIInfoParser::GetOpName()
|
||||
{
|
||||
if (context_->GetNodeName() == nullptr) {
|
||||
OPS_LOG_E("LightningIndexer", "opName got from TilingContext is nullptr");
|
||||
return ge::GRAPH_FAILED;
|
||||
}
|
||||
opName_ = context_->GetNodeName();
|
||||
return ge::GRAPH_SUCCESS;
|
||||
}
|
||||
|
||||
ge::graphStatus LIInfoParser::GetNpuInfo()
|
||||
{
|
||||
platformInfo_ = context_->GetPlatformInfo();
|
||||
OPS_ERR_IF(platformInfo_ == nullptr, OPS_LOG_E(opName_, "GetPlatformInfo is nullptr."), return ge::GRAPH_FAILED);
|
||||
|
||||
auto ascendcPlatform = platform_ascendc::PlatformAscendC(platformInfo_);
|
||||
uint32_t aivNum = ascendcPlatform.GetCoreNumAiv();
|
||||
uint32_t aicNum = ascendcPlatform.GetCoreNumAic();
|
||||
OPS_ERR_IF(aicNum == 0 || aivNum == 0, OPS_LOG_E(opName_, "num of core obtained is 0."), return GRAPH_FAILED);
|
||||
|
||||
socVersion_ = ascendcPlatform.GetSocVersion();
|
||||
if ((socVersion_ != platform_ascendc::SocVersion::ASCEND910B) &&
|
||||
(socVersion_ != platform_ascendc::SocVersion::ASCEND910_93)) {
|
||||
OPS_LOG_E(opName_, "SOC Version[%d] is not support.", (int32_t)socVersion_);
|
||||
return GRAPH_FAILED;
|
||||
}
|
||||
OPS_ERR_IF(context_->GetWorkspaceSizes(1) == nullptr, OPS_LOG_E(opName_, "workSpaceSize got from ge is nullptr"),
|
||||
return ge::GRAPH_FAILED);
|
||||
OPS_ERR_IF(context_->GetRawTilingData() == nullptr,
|
||||
OPS_LOG_E(context_->GetNodeName(), "RawTilingData got from GE context is nullptr."),
|
||||
return ge::GRAPH_FAILED);
|
||||
|
||||
return ge::GRAPH_SUCCESS;
|
||||
}
|
||||
|
||||
void LIInfoParser::GetOptionalInputParaInfo()
|
||||
{
|
||||
opParamInfo_.actualSeqLengthsQ.tensor = context_->GetOptionalInputTensor(ACTUAL_SEQ_Q_INDEX);
|
||||
opParamInfo_.actualSeqLengthsQ.desc = context_->GetOptionalInputDesc(ACTUAL_SEQ_Q_INDEX);
|
||||
opParamInfo_.actualSeqLengths.tensor = context_->GetOptionalInputTensor(ACTUAL_SEQ_K_INDEX);
|
||||
opParamInfo_.actualSeqLengths.desc = context_->GetOptionalInputDesc(ACTUAL_SEQ_K_INDEX);
|
||||
opParamInfo_.blockTable.tensor = context_->GetOptionalInputTensor(BLOCK_TABLE_INDEX);
|
||||
opParamInfo_.blockTable.desc = context_->GetOptionalInputDesc(BLOCK_TABLE_INDEX);
|
||||
}
|
||||
|
||||
void LIInfoParser::GetInputParaInfo()
|
||||
{
|
||||
opParamInfo_.query.desc = context_->GetInputDesc(QUERY_INDEX);
|
||||
opParamInfo_.query.shape = context_->GetInputShape(QUERY_INDEX);
|
||||
opParamInfo_.key.desc = context_->GetInputDesc(KEY_INDEX);
|
||||
opParamInfo_.key.shape = context_->GetInputShape(KEY_INDEX);
|
||||
opParamInfo_.weights.desc = context_->GetInputDesc(WEIGTHS_INDEX);
|
||||
opParamInfo_.weights.shape = context_->GetInputShape(WEIGTHS_INDEX);
|
||||
GetOptionalInputParaInfo();
|
||||
}
|
||||
|
||||
void LIInfoParser::GetOutputParaInfo()
|
||||
{
|
||||
opParamInfo_.attenOut.desc = context_->GetOutputDesc(LIGHTNING_INDEXER);
|
||||
opParamInfo_.attenOut.shape = context_->GetOutputShape(LIGHTNING_INDEXER);
|
||||
}
|
||||
|
||||
ge::graphStatus LIInfoParser::GetAndCheckAttrParaInfo()
|
||||
{
|
||||
auto attrs = context_->GetAttrs();
|
||||
OPS_ERR_IF(attrs == nullptr, OPS_REPORT_VECTOR_INNER_ERR(context_->GetNodeName(), "attrs got from ge is nullptr"),
|
||||
return ge::GRAPH_FAILED);
|
||||
|
||||
OPS_LOG_I(context_->GetNodeName(), "GetAndCheckAttrParaInfo start");
|
||||
opParamInfo_.layOut = attrs->GetStr(ATTR_QUERY_LAYOUT_INDEX);
|
||||
opParamInfo_.layOutKey = attrs->GetStr(ATTR_KEY_LAYOUT_INDEX);
|
||||
opParamInfo_.sparseCount = attrs->GetAttrPointer<int32_t>(ATTR_SPARSE_COUNT_INDEX);
|
||||
opParamInfo_.sparseMode = attrs->GetAttrPointer<int32_t>(ATTR_SPARSE_MODE_INDEX);
|
||||
|
||||
if (opParamInfo_.layOut != nullptr) {
|
||||
OPS_LOG_I(context_->GetNodeName(), "layout_query is:%s", opParamInfo_.layOut);
|
||||
}
|
||||
if (opParamInfo_.layOutKey != nullptr) {
|
||||
OPS_LOG_I(context_->GetNodeName(), "layout_key is:%s", opParamInfo_.layOutKey);
|
||||
}
|
||||
if (opParamInfo_.sparseCount != nullptr) {
|
||||
OPS_LOG_I(context_->GetNodeName(), "selscted count is:%d", *opParamInfo_.sparseCount);
|
||||
}
|
||||
if (opParamInfo_.sparseMode != nullptr) {
|
||||
OPS_LOG_I(context_->GetNodeName(), "sparse mode is:%d", *opParamInfo_.sparseMode);
|
||||
}
|
||||
OPS_LOG_I(context_->GetNodeName(), "GetAndCheckAttrParaInfo end");
|
||||
|
||||
OPS_ERR_IF(
|
||||
((std::string(opParamInfo_.layOutKey) != "PA_BSND")
|
||||
&& (std::string(opParamInfo_.layOut) != std::string(opParamInfo_.layOutKey))),
|
||||
OPS_LOG_E(opName_, "under non-PA conditions, layout_query and layout_key should be equal."),
|
||||
return ge::GRAPH_FAILED);
|
||||
OPS_ERR_IF(
|
||||
((std::string(opParamInfo_.layOutKey) != "PA_BSND") && (std::string(opParamInfo_.layOutKey) != "BSND")
|
||||
&& (std::string(opParamInfo_.layOutKey) != "TND")),
|
||||
OPS_LOG_E(opName_, "input attr layout_key only supported PA_BSND, BSND or TND"), return ge::GRAPH_FAILED);
|
||||
OPS_ERR_IF(((std::string(opParamInfo_.layOut) != "BSND") && (std::string(opParamInfo_.layOut) != "TND")),
|
||||
OPS_LOG_E(opName_, "input attr layout_query only supported BSND or TND."), return ge::GRAPH_FAILED);
|
||||
OPS_ERR_IF(!((*opParamInfo_.sparseCount > 0) && (*opParamInfo_.sparseCount <= SPARSE_LIMIT)),
|
||||
OPS_LOG_E(opName_, "input attr sparse_count must > 0 and <= 2048."), return ge::GRAPH_FAILED);
|
||||
OPS_ERR_IF(!((*opParamInfo_.sparseMode == 0) || (*opParamInfo_.sparseMode == SPARSE_MODE_LOWER)),
|
||||
OPS_LOG_E(opName_, "input attr sparse_mode only supported 0 or 3."), return ge::GRAPH_FAILED);
|
||||
|
||||
return ge::GRAPH_SUCCESS;
|
||||
}
|
||||
|
||||
ge::graphStatus LIInfoParser::GetOpParaInfo()
|
||||
{
|
||||
GetInputParaInfo();
|
||||
GetOutputParaInfo();
|
||||
if (ge::GRAPH_SUCCESS != GetAndCheckAttrParaInfo()) {
|
||||
return ge::GRAPH_FAILED;
|
||||
}
|
||||
return ge::GRAPH_SUCCESS;
|
||||
}
|
||||
|
||||
ge::graphStatus LIInfoParser::GetAndCheckInOutDataType()
|
||||
{
|
||||
inputQType_ = opParamInfo_.query.desc->GetDataType();
|
||||
inputKType_ = opParamInfo_.key.desc->GetDataType();
|
||||
weightsType_ = opParamInfo_.weights.desc->GetDataType();
|
||||
outputType_ = opParamInfo_.attenOut.desc->GetDataType();
|
||||
|
||||
bool inDTypeAllEqual = (inputQType_ == inputKType_) && (inputKType_ == weightsType_);
|
||||
OPS_ERR_IF(!inDTypeAllEqual,
|
||||
OPS_LOG_E(opName_, "The data types of the input query, key, and weights must be the same."),
|
||||
return ge::GRAPH_FAILED);
|
||||
|
||||
OPS_ERR_IF(((inputQType_ != ge::DT_FLOAT16) && (inputQType_ != ge::DT_BF16)),
|
||||
OPS_LOG_E(opName_, "The data types of the input query, key, and weights must be float16 or bfloat16."),
|
||||
return ge::GRAPH_FAILED);
|
||||
|
||||
OPS_ERR_IF(outputType_ != ge::DT_INT32,
|
||||
OPS_LOG_E(opName_, "The data types of the output sparse_indices must be int32."),
|
||||
return ge::GRAPH_FAILED);
|
||||
|
||||
return ge::GRAPH_SUCCESS;
|
||||
}
|
||||
|
||||
ge::graphStatus LIInfoParser::GetQueryKeyAndOutLayout()
|
||||
{
|
||||
const map<string, DataLayout> layoutMap = {
|
||||
{"BSND", DataLayout::BSND},
|
||||
{"TND", DataLayout::TND},
|
||||
{"PA_BSND", DataLayout::BnBsND}
|
||||
};
|
||||
|
||||
std::string layout(opParamInfo_.layOut);
|
||||
auto it = layoutMap.find(layout);
|
||||
if (it != layoutMap.end()) {
|
||||
qLayout_ = it->second;
|
||||
}
|
||||
|
||||
std::string layoutKey(opParamInfo_.layOutKey);
|
||||
auto itKey = layoutMap.find(layoutKey);
|
||||
if (itKey != layoutMap.end()) {
|
||||
kLayout_ = itKey->second;
|
||||
}
|
||||
|
||||
return ge::GRAPH_SUCCESS;
|
||||
}
|
||||
|
||||
ge::graphStatus LIInfoParser::GetAndCheckOptionalInput()
|
||||
{
|
||||
if (kLayout_ == DataLayout::BnBsND) {
|
||||
OPS_ERR_IF(opParamInfo_.blockTable.tensor == nullptr,
|
||||
OPS_LOG_E(opName_, "key layout only supported PA_BSND, input block_table must not be null"),
|
||||
return ge::GRAPH_FAILED);
|
||||
OPS_ERR_IF(
|
||||
opParamInfo_.actualSeqLengths.tensor == nullptr,
|
||||
OPS_LOG_E(opName_, "key layout only supported PA_BSND, input actual_seq_lengths_key must not be null"),
|
||||
return ge::GRAPH_FAILED);
|
||||
OPS_ERR_IF(opParamInfo_.blockTable.desc->GetDataType() != ge::DT_INT32,
|
||||
OPS_LOG_E(opName_, "input block_table data type only support int32"), return ge::GRAPH_FAILED);
|
||||
} else if (kLayout_ == DataLayout::TND) {
|
||||
OPS_ERR_IF(opParamInfo_.actualSeqLengths.tensor == nullptr,
|
||||
OPS_LOG_E(opName_, "when layout_key is TND, input actual_seq_lengths_key must not be null"),
|
||||
return ge::GRAPH_FAILED);
|
||||
}
|
||||
OPS_ERR_IF(opParamInfo_.actualSeqLengths.tensor != nullptr &&
|
||||
opParamInfo_.actualSeqLengths.desc->GetDataType() != ge::DT_INT32,
|
||||
OPS_LOG_E(opName_, "input actual_seq_lengths_key data type only support int32"),
|
||||
return ge::GRAPH_FAILED);
|
||||
OPS_ERR_IF(opParamInfo_.actualSeqLengths.tensor != nullptr &&
|
||||
opParamInfo_.actualSeqLengths.desc->GetDataType() != ge::DT_INT32,
|
||||
OPS_LOG_E(opName_, "input actual_seq_lengths_key data type only support int32"),
|
||||
return ge::GRAPH_FAILED);
|
||||
if (qLayout_ == DataLayout::TND) {
|
||||
OPS_ERR_IF(opParamInfo_.actualSeqLengthsQ.tensor == nullptr,
|
||||
OPS_LOG_E(opName_, "when layout_query is TND, input actual_seq_lengths_query must not be null"),
|
||||
return ge::GRAPH_FAILED);
|
||||
}
|
||||
OPS_ERR_IF(opParamInfo_.actualSeqLengthsQ.tensor != nullptr &&
|
||||
opParamInfo_.actualSeqLengthsQ.desc->GetDataType() != ge::DT_INT32,
|
||||
OPS_LOG_E(opName_, "input actual_seq_lengths_query data type only support int32"),
|
||||
return ge::GRAPH_FAILED);
|
||||
OPS_ERR_IF(kLayout_ != DataLayout::BnBsND && opParamInfo_.blockTable.tensor != nullptr,
|
||||
OPS_LOG_E(opName_, "when key layout is not PA_BSND, input block_table must be null"),
|
||||
return ge::GRAPH_FAILED);
|
||||
return ge::GRAPH_SUCCESS;
|
||||
}
|
||||
|
||||
ge::graphStatus LIInfoParser::CheckShapeDim()
|
||||
{
|
||||
OPS_ERR_IF((opParamInfo_.blockTable.tensor != nullptr) &&
|
||||
(opParamInfo_.blockTable.tensor->GetStorageShape().GetDimNum() != DIM_NUM_TWO),
|
||||
OPS_LOG_E(opName_, "the dim num of block_table's shape should be 2"), return ge::GRAPH_FAILED);
|
||||
|
||||
uint32_t kShapeDim = opParamInfo_.key.shape->GetStorageShape().GetDimNum();
|
||||
uint32_t qShapeDim = opParamInfo_.query.shape->GetStorageShape().GetDimNum();
|
||||
uint32_t weightsShapeDim = opParamInfo_.weights.shape->GetStorageShape().GetDimNum();
|
||||
uint32_t outShapeDim = opParamInfo_.attenOut.shape->GetStorageShape().GetDimNum();
|
||||
uint32_t qExpectShapeDim = DIM_NUM_FOUR;
|
||||
uint32_t kExpectShapeDim = DIM_NUM_FOUR;
|
||||
if (qLayout_ == DataLayout::TND) {
|
||||
qExpectShapeDim = DIM_NUM_THREE;
|
||||
}
|
||||
if (kLayout_ == DataLayout::TND) {
|
||||
kExpectShapeDim = DIM_NUM_THREE;
|
||||
}
|
||||
OPS_ERR_IF(kShapeDim != kExpectShapeDim,
|
||||
OPS_LOG_E(opName_, "the dim num of key's shape should be %u, but now is %u", kExpectShapeDim, kShapeDim),
|
||||
return ge::GRAPH_FAILED);
|
||||
OPS_ERR_IF(qShapeDim != qExpectShapeDim,
|
||||
OPS_LOG_E(opName_, "the dim num of query's shape should be %u, but now is %u",
|
||||
qExpectShapeDim, qShapeDim),
|
||||
return ge::GRAPH_FAILED);
|
||||
OPS_ERR_IF(outShapeDim != qExpectShapeDim,
|
||||
OPS_LOG_E(opName_, "the dim num of sparse_indices's shape should be %u, but now is %u",
|
||||
qExpectShapeDim, outShapeDim),
|
||||
return ge::GRAPH_FAILED);
|
||||
OPS_ERR_IF(!(weightsShapeDim == qExpectShapeDim - 1),
|
||||
OPS_LOG_E(opName_, "the dim num of weights's shape should be %u, but now is %u", qExpectShapeDim - 1,
|
||||
weightsShapeDim),
|
||||
return ge::GRAPH_FAILED);
|
||||
|
||||
return ge::GRAPH_SUCCESS;
|
||||
}
|
||||
|
||||
ge::graphStatus LIInfoParser::GetN1Size()
|
||||
{
|
||||
if (qLayout_ == DataLayout::BSND) {
|
||||
n1Size_ = static_cast<uint32_t>(opParamInfo_.query.shape->GetStorageShape().GetDim(DIM_IDX_TWO));
|
||||
} else {
|
||||
// TND
|
||||
n1Size_ = static_cast<uint32_t>(opParamInfo_.query.shape->GetStorageShape().GetDim(1));
|
||||
}
|
||||
OPS_LOG_I(context_->GetNodeName(), "n1Size is %d", n1Size_);
|
||||
return ge::GRAPH_SUCCESS;
|
||||
}
|
||||
|
||||
ge::graphStatus LIInfoParser::GetActualSeqLenSize(uint32_t &size, const gert::Tensor *tensor,
|
||||
const std::string &actualSeqLenName)
|
||||
{
|
||||
size = static_cast<uint32_t>(tensor->GetShapeSize());
|
||||
if (size <= 0) {
|
||||
OPS_LOG_E(opName_, "%s's shape size is %u, it should be greater than 0.", actualSeqLenName.c_str(), size);
|
||||
return ge::GRAPH_FAILED;
|
||||
}
|
||||
return ge::GRAPH_SUCCESS;
|
||||
}
|
||||
|
||||
ge::graphStatus LIInfoParser::GetAndCheckN2Size()
|
||||
{
|
||||
uint32_t n2Index = (kLayout_ == DataLayout::TND) ? DIM_IDX_ONE : DIM_IDX_TWO;
|
||||
n2Size_ = static_cast<uint32_t>(opParamInfo_.key.shape->GetStorageShape().GetDim(n2Index));
|
||||
OPS_LOG_I(context_->GetNodeName(), "n2Size_ is %d", n2Size_);
|
||||
OPS_ERR_IF(n2Size_ != 1, OPS_LOG_E(opName_, "key shape[%u] is numhead, only support 1.", n2Index),
|
||||
return ge::GRAPH_FAILED);
|
||||
return ge::GRAPH_SUCCESS;
|
||||
}
|
||||
|
||||
ge::graphStatus LIInfoParser::GetGSize()
|
||||
{
|
||||
if (n1Size_ % n2Size_ != 0) {
|
||||
OPS_LOG_E(opName_, "input query's head_num %u can not be a multiple of key's head_num %u.", n1Size_, n2Size_);
|
||||
return ge::GRAPH_FAILED;
|
||||
}
|
||||
gSize_ = n1Size_ / n2Size_;
|
||||
OPS_ERR_IF(gSize_ != 64, OPS_LOG_E(opName_, "N1 is %u, N2 is %u, N1 divided by N2 must equal 64.",
|
||||
n1Size_, n2Size_), return ge::GRAPH_FAILED);
|
||||
|
||||
return ge::GRAPH_SUCCESS;
|
||||
}
|
||||
|
||||
ge::graphStatus LIInfoParser::GetBatchSize()
|
||||
{
|
||||
if ((qLayout_ == DataLayout::TND)) {
|
||||
return GetActualSeqLenSize(bSize_, opParamInfo_.actualSeqLengthsQ.tensor, "input actual_seq_lengths_query");
|
||||
} else { // BSND
|
||||
bSize_ = opParamInfo_.query.shape->GetStorageShape().GetDim(0);
|
||||
return ge::GRAPH_SUCCESS;
|
||||
}
|
||||
}
|
||||
|
||||
ge::graphStatus LIInfoParser::GetHeadDim()
|
||||
{
|
||||
uint32_t dIndex = DIM_IDX_TWO;
|
||||
switch (qLayout_) {
|
||||
case DataLayout::TND:
|
||||
// TND: [Total, N, D] -> D is the 2nd dimension
|
||||
dIndex = DIM_IDX_TWO;
|
||||
break;
|
||||
case DataLayout::BSND:
|
||||
// BSND: [Batch, SeqLen, N, D] -> D is the 3nd dimension
|
||||
dIndex = DIM_IDX_THREE;
|
||||
break;
|
||||
default:
|
||||
OPS_LOG_E(opName_, "unsupported layout for getting head dim.");
|
||||
return ge::GRAPH_FAILED;
|
||||
}
|
||||
headDim_ = opParamInfo_.query.shape->GetStorageShape().GetDim(dIndex);
|
||||
OPS_ERR_IF(headDim_ != HEAD_DIM_LIMIT, OPS_LOG_E(opName_, "input query's last dim head_dim only support 128."),
|
||||
return ge::GRAPH_FAILED);
|
||||
|
||||
return ge::GRAPH_SUCCESS;
|
||||
}
|
||||
|
||||
ge::graphStatus LIInfoParser::GetS1Size()
|
||||
{
|
||||
if (qLayout_ == DataLayout::BSND) {
|
||||
s1Size_ = opParamInfo_.query.shape->GetStorageShape().GetDim(1);
|
||||
}
|
||||
return ge::GRAPH_SUCCESS;
|
||||
}
|
||||
|
||||
ge::graphStatus LIInfoParser::GetAndCheckBlockSize()
|
||||
{
|
||||
blockSize_ = static_cast<uint32_t>(opParamInfo_.key.shape->GetStorageShape().GetDim(1));
|
||||
OPS_LOG_I(context_->GetNodeName(), "blockSize_ is %d", blockSize_);
|
||||
|
||||
OPS_ERR_IF(((blockSize_ % 16 != 0) || (blockSize_ == 0) || (blockSize_ > 1024)),
|
||||
OPS_LOG_E(opName_, "input key's block_size must be a multiple of 16 and belong to (0, 1024]."),
|
||||
return ge::GRAPH_FAILED);
|
||||
|
||||
return ge::GRAPH_SUCCESS;
|
||||
}
|
||||
|
||||
ge::graphStatus LIInfoParser::CheckBlockCount()
|
||||
{
|
||||
int32_t blockCount_ = static_cast<uint32_t>(opParamInfo_.key.shape->GetStorageShape().GetDim(0));
|
||||
OPS_ERR_IF((blockCount_ == 0),
|
||||
OPS_LOG_E(opName_, "input key's block_count cannot be 0."),
|
||||
return ge::GRAPH_FAILED);
|
||||
return ge::GRAPH_SUCCESS;
|
||||
}
|
||||
|
||||
ge::graphStatus LIInfoParser::GetS2SizeForPageAttention()
|
||||
{
|
||||
if (GetAndCheckBlockSize() != ge::GRAPH_SUCCESS) {
|
||||
return ge::GRAPH_FAILED;
|
||||
}
|
||||
if (CheckBlockCount() != ge::GRAPH_SUCCESS) {
|
||||
return ge::GRAPH_FAILED;
|
||||
}
|
||||
maxBlockNumPerBatch_ = opParamInfo_.blockTable.tensor->GetStorageShape().GetDim(1);
|
||||
s2Size_ = maxBlockNumPerBatch_ * blockSize_;
|
||||
OPS_LOG_I(context_->GetNodeName(), "maxBlockNumPerBatch_ is %d, blockSize_ is %d, s2Size_ is %d",
|
||||
maxBlockNumPerBatch_, blockSize_, s2Size_);
|
||||
return ge::GRAPH_SUCCESS;
|
||||
}
|
||||
|
||||
ge::graphStatus LIInfoParser::GetS2Size()
|
||||
{
|
||||
if (kLayout_ == DataLayout::BnBsND) {
|
||||
return GetS2SizeForPageAttention();
|
||||
} else if (kLayout_ == DataLayout::TND) {
|
||||
s2Size_ = opParamInfo_.key.shape->GetStorageShape().GetDim(0);
|
||||
} else if (kLayout_ == DataLayout::BSND) {
|
||||
s2Size_ = opParamInfo_.key.shape->GetStorageShape().GetDim(1);
|
||||
}
|
||||
return ge::GRAPH_SUCCESS;
|
||||
}
|
||||
|
||||
// Cross-checks input shapes when the query layout is TND (packed tokens).
// bSize_ was established earlier (presumably from actual_seq_lengths_query —
// confirm against GetBatchSize); here the key-side lengths and the block table
// must agree with it, and query/weights/sparse_indices must share the same T.
ge::graphStatus LIInfoParser::ValidateInputShapesMatchQTnd()
{
    // -----------------------check BatchSize-------------------
    if (kLayout_ == DataLayout::TND) {
        // Key is also packed: actual_seq_lengths_key must have one entry per batch.
        OPS_ERR_IF(
            (opParamInfo_.actualSeqLengths.tensor->GetShapeSize() != bSize_),
            OPS_LOG_E(opName_,
                "TND case input actual_seq_lengths_query, actual_seq_lengths_key are %u, %ld respectively, they must be same.",
                bSize_, opParamInfo_.actualSeqLengths.tensor->GetShapeSize()),
            return ge::GRAPH_FAILED);
    } else { // kLayout_ PA_BSND
        // Paged key cache: both the per-batch key lengths and block_table dim 0
        // must match the batch size.
        OPS_ERR_IF(
            (opParamInfo_.actualSeqLengths.tensor->GetShapeSize() != bSize_) ||
                (opParamInfo_.blockTable.tensor->GetStorageShape().GetDim(0) != bSize_),
            OPS_LOG_E(
                opName_,
                "TND case input actual_seq_lengths_query, actual_seq_lengths_key, block_table dim 0 are %u, %ld, %ld respectively, they must be same.",
                bSize_, opParamInfo_.actualSeqLengths.tensor->GetShapeSize(),
                opParamInfo_.blockTable.tensor->GetStorageShape().GetDim(0)),
            return ge::GRAPH_FAILED);
    }
    // -----------------------check T-------------------
    // T (total packed tokens) is dim 0 of the query; weights and the
    // sparse_indices output (attenOut) must match it.
    uint32_t qTsize = opParamInfo_.query.shape->GetStorageShape().GetDim(0);
    OPS_ERR_IF((opParamInfo_.weights.shape->GetStorageShape().GetDim(0) != qTsize) ||
                   (opParamInfo_.attenOut.shape->GetStorageShape().GetDim(0) != qTsize),
               OPS_LOG_E(opName_, "TND case input query, weights, sparse_indices dim 0 are %u, %ld, %ld respectively, they must be same.",
                         qTsize, opParamInfo_.weights.shape->GetStorageShape().GetDim(0),
                         opParamInfo_.attenOut.shape->GetStorageShape().GetDim(0)),
               return ge::GRAPH_FAILED);
    return ge::GRAPH_SUCCESS;
}
|
||||
|
||||
// Cross-checks input shapes when the query layout is BSND (batched).
// Verifies batch size (dim 0) agreement across key / block_table / length
// tensors and weights / sparse_indices, then S1 (dim 1) agreement.
ge::graphStatus LIInfoParser::ValidateInputShapesMatchQBsnd()
{
    // -----------------------check BatchSize-------------------
    if (kLayout_ == DataLayout::BnBsND) {
        // Paged key cache: batch size comes from block_table dim 0 and the
        // per-batch key-length tensor, both of which must equal bSize_.
        OPS_ERR_IF((opParamInfo_.blockTable.tensor->GetStorageShape().GetDim(0) != bSize_) ||
                       (opParamInfo_.actualSeqLengths.tensor->GetShapeSize() != bSize_),
                   OPS_LOG_E(opName_, "BSND case input query, actual_seq_lengths_key, block_table dim 0 are %u, %ld, %ld respectively, they must be same.",
                             bSize_, opParamInfo_.actualSeqLengths.tensor->GetShapeSize(),
                             opParamInfo_.blockTable.tensor->GetStorageShape().GetDim(0)),
                   return ge::GRAPH_FAILED);
    } else if (kLayout_ == DataLayout::BSND) {
        // Contiguous key: its dim 0 is the batch dimension.
        OPS_ERR_IF(opParamInfo_.key.shape->GetStorageShape().GetDim(0) != bSize_,
                   OPS_LOG_E(opName_, "BSND case input query, key dim 0 are %u, %ld respectively, they must be same.",
                             bSize_, opParamInfo_.key.shape->GetStorageShape().GetDim(0)),
                   return ge::GRAPH_FAILED);
        // actual_seq_lengths_key is optional here; validate only when supplied.
        OPS_ERR_IF((opParamInfo_.actualSeqLengths.tensor != nullptr) &&
                       (opParamInfo_.actualSeqLengths.tensor->GetShapeSize() != bSize_),
                   OPS_LOG_E(opName_, "BSND case input query, actual_seq_lengths_key dim 0 are %u, %ld respectively, they must be same.",
                             bSize_, opParamInfo_.actualSeqLengths.tensor->GetShapeSize()),
                   return ge::GRAPH_FAILED);
    }
    // weights and the sparse_indices output (attenOut) share the batch dim.
    OPS_ERR_IF((opParamInfo_.weights.shape->GetStorageShape().GetDim(0) != bSize_) ||
                   (opParamInfo_.attenOut.shape->GetStorageShape().GetDim(0) != bSize_),
               OPS_LOG_E(opName_, "BSND case input query, weight, sparse_indices dim 0 are %u, %ld, %ld respectively, they must be same.",
                         bSize_, opParamInfo_.weights.shape->GetStorageShape().GetDim(0),
                         opParamInfo_.attenOut.shape->GetStorageShape().GetDim(0)),
               return ge::GRAPH_FAILED);
    // actual_seq_lengths_query is optional; validate only when supplied.
    OPS_ERR_IF((opParamInfo_.actualSeqLengthsQ.tensor != nullptr) &&
                   (opParamInfo_.actualSeqLengthsQ.tensor->GetShapeSize() != bSize_),
               OPS_LOG_E(opName_, "BSND case input query, actual_seq_lengths_query dim 0 are %u, %ld respectively, they must be same.",
                         bSize_, opParamInfo_.actualSeqLengthsQ.tensor->GetShapeSize()),
               return ge::GRAPH_FAILED);
    // -----------------------check S1-------------------
    // In BSND, dim 1 is the query sequence length; weights and sparse_indices
    // must match it.
    OPS_ERR_IF((opParamInfo_.weights.shape->GetStorageShape().GetDim(1) != s1Size_) ||
                   (opParamInfo_.attenOut.shape->GetStorageShape().GetDim(1) != s1Size_),
               OPS_LOG_E(opName_, "BSND case input query, weight, sparse_indices dim 1 are %u, %ld, %ld, they must be same.",
                         s1Size_, opParamInfo_.weights.shape->GetStorageShape().GetDim(1),
                         opParamInfo_.attenOut.shape->GetStorageShape().GetDim(1)),
               return ge::GRAPH_FAILED);
    return ge::GRAPH_SUCCESS;
}
|
||||
|
||||
// Top-level shape consistency check. Dispatches batch/T/S1 validation to the
// layout-specific helper, then verifies N1, D, N2 and sparse_count, whose dim
// index shifts by one between TND (N at dim 1) and BSND (N at dim 2).
ge::graphStatus LIInfoParser::ValidateInputShapesMatch()
{
    // In TND the head dim (N) is dim 1; in BSND it is dim 2.
    uint32_t queryWeightsN1Dim = 1;
    uint32_t outN2Dim = 1;
    if (qLayout_ == DataLayout::TND) {
        if (ValidateInputShapesMatchQTnd() != ge::GRAPH_SUCCESS) {
            return ge::GRAPH_FAILED;
        }
    } else {
        if (ValidateInputShapesMatchQBsnd() != ge::GRAPH_SUCCESS) {
            return ge::GRAPH_FAILED;
        }
        queryWeightsN1Dim = DIM_IDX_TWO;
        outN2Dim = DIM_IDX_TWO;
    }
    // -----------------------check N1-------------------
    // weights must carry the same number of query heads as the query tensor.
    OPS_ERR_IF((opParamInfo_.weights.shape->GetStorageShape().GetDim(queryWeightsN1Dim) != n1Size_),
               OPS_LOG_E(opName_, "input query, weight shape dim N1 must be same."), return ge::GRAPH_FAILED);
    // -----------------------check D-------------------
    // Key head dim sits at dim 2 for a TND key, otherwise at dim 3.
    uint32_t keyDDim = kLayout_ == DataLayout::TND ? DIM_IDX_TWO : DIM_IDX_THREE;
    OPS_ERR_IF((opParamInfo_.key.shape->GetStorageShape().GetDim(keyDDim) != headDim_),
               OPS_LOG_E(opName_, "input query, key shape last dim must be same."), return ge::GRAPH_FAILED);
    // -----------------------check N2-------------------
    // sparse_indices output must carry the key-head count N2.
    OPS_ERR_IF((opParamInfo_.attenOut.shape->GetStorageShape().GetDim(outN2Dim) != n2Size_),
               OPS_LOG_E(opName_, "input query and output sparse_indices shape n2 dim must be same."),
               return ge::GRAPH_FAILED);
    // -----------------------check sparse_count-------------------
    // The dim right after N2 (i.e. the last dim) must equal the sparse_count attr.
    OPS_ERR_IF((opParamInfo_.attenOut.shape->GetStorageShape().GetDim(outN2Dim + 1) != *opParamInfo_.sparseCount),
               OPS_LOG_E(opName_, "output sparse_indices shape last dim must be same as attr sparse_count."),
               return ge::GRAPH_FAILED);

    return ge::GRAPH_SUCCESS;
}
|
||||
|
||||
// Copies every parsed and validated field into the LITilingInfo carrier that
// LightningIndexerTiling::DoTiling consumes. Pure field transfer; no checks.
void LIInfoParser::GenerateInfo(LITilingInfo &liInfo)
{
    liInfo.opName = opName_;
    liInfo.platformInfo = platformInfo_;
    liInfo.opParamInfo = opParamInfo_;
    liInfo.socVersion = socVersion_;

    // Problem sizes.
    liInfo.bSize = bSize_;
    liInfo.n1Size = n1Size_;
    liInfo.n2Size = n2Size_;
    liInfo.s1Size = s1Size_;
    liInfo.s2Size = s2Size_;
    liInfo.gSize = gSize_;

    // Data types.
    liInfo.inputQType = inputQType_;
    liInfo.inputKType = inputKType_;
    liInfo.outputType = outputType_;

    // Page-attention geometry.
    liInfo.blockSize = blockSize_;
    liInfo.maxBlockNumPerBatch = maxBlockNumPerBatch_;

    // Page attention is selected purely by the key-layout attribute string.
    // (Simplified from the redundant `== "PA_BSND" ? true : false`.)
    std::string layOutKeyStr(opParamInfo_.layOutKey);
    liInfo.pageAttentionFlag = (layOutKeyStr == "PA_BSND");
    liInfo.sparseMode = *opParamInfo_.sparseMode;
    liInfo.sparseCount = *opParamInfo_.sparseCount;

    // Layouts resolved by GetQueryKeyAndOutLayout.
    liInfo.inputQLayout = qLayout_;
    liInfo.inputKLayout = kLayout_;
}
|
||||
|
||||
// Runs the full parse/validate pipeline and, on success, fills liInfo for the
// tiling stage. Order matters: each stage consumes state produced by earlier
// ones (e.g. GetS2Size needs layouts; ValidateInputShapesMatch needs all sizes).
// Returns ge::GRAPH_FAILED as soon as any stage fails.
ge::graphStatus LIInfoParser::ParseAndCheck(LITilingInfo &liInfo)
{
    // Stage 1: op identity, platform info, raw parameter pointers, and
    // presence of all required inputs/attrs.
    if (ge::GRAPH_SUCCESS != GetOpName() || ge::GRAPH_SUCCESS != GetNpuInfo() || ge::GRAPH_SUCCESS != GetOpParaInfo() ||
        ge::GRAPH_SUCCESS != CheckRequiredParaExistence()) {
        return ge::GRAPH_FAILED;
    }

    // Stage 2: dtypes, layouts, and optional inputs.
    if (ge::GRAPH_SUCCESS != GetAndCheckInOutDataType() || ge::GRAPH_SUCCESS != GetQueryKeyAndOutLayout() ||
        ge::GRAPH_SUCCESS != GetAndCheckOptionalInput()) {
        return ge::GRAPH_FAILED;
    }

    // Stage 3: rank checks and head-count sizes (N1, N2, G).
    if (ge::GRAPH_SUCCESS != CheckShapeDim() || ge::GRAPH_SUCCESS != GetN1Size() ||
        ge::GRAPH_SUCCESS != GetAndCheckN2Size() || ge::GRAPH_SUCCESS != GetGSize()) {
        return ge::GRAPH_FAILED;
    }

    // Stage 4: batch, sequence and head-dim sizes (S2 may come from the
    // page-attention block table, so it needs the layout from stage 2).
    if (ge::GRAPH_SUCCESS != GetBatchSize() || ge::GRAPH_SUCCESS != GetS1Size() || ge::GRAPH_SUCCESS != GetHeadDim() ||
        ge::GRAPH_SUCCESS != GetS2Size()) {
        return ge::GRAPH_FAILED;
    }
    // Stage 5: cross-tensor shape consistency, using every size above.
    if (ge::GRAPH_SUCCESS != ValidateInputShapesMatch()) {
        return ge::GRAPH_FAILED;
    }

    // All checks passed: publish the parsed state.
    GenerateInfo(liInfo);

    return ge::GRAPH_SUCCESS;
}
|
||||
|
||||
// Compile-info parse hook for LightningIndexer. The op carries no compile-time
// information (LICompileInfo is empty), so there is nothing to do here.
static ge::graphStatus TilingPrepareForLightningIndexer(gert::TilingParseContext * /* context */)
{
    return ge::GRAPH_SUCCESS;
}
|
||||
|
||||
// Produces the complete tiling for LightningIndexer: block dim, workspace
// size, serialized tiling data, and the templated tiling key.
ge::graphStatus LightningIndexerTiling::DoTiling(LITilingInfo *tilingInfo)
{
    // -------------set blockdim-----------------
    auto ascendcPlatform = platform_ascendc::PlatformAscendC(tilingInfo->platformInfo);
    uint32_t aivNum = ascendcPlatform.GetCoreNumAiv();
    uint32_t aicNum = ascendcPlatform.GetCoreNumAic();
    // Use every vector core; CalcTschBlockDim folds in the AIC/AIV ratio.
    uint32_t blockDim = ascendcPlatform.CalcTschBlockDim(aivNum, aicNum, aivNum);
    context_->SetBlockDim(blockDim);

    // -------------set workspacesize-----------------
    // Per-core workspace layout (must mirror the kernel's buffers; the sizes
    // below match M_BASE_SIZE/S2_BASE_SIZE in lightning_indexer_kernel.h):
    constexpr uint32_t MM1_RES_ELEM_SIZE = 4;          // fp32 matmul-1 result elements
    constexpr uint32_t DOUBLE_BUFFER = 2;              // ping/pong buffering
    constexpr uint32_t M_BASE_SIZE = 512;              // matmul-1 M tile
    constexpr uint32_t S2_BASE_SIZE = 512;             // matmul-1 S2 tile
    constexpr uint32_t V1_RES_ELEM_SIZE = 4;           // fp32 vector-1 result elements
    constexpr uint32_t V1_RES_ELEM_TYPE = 2;
    constexpr uint32_t V1_DECODE_PARAM_ELEM_SIZE = 8;  // int64 decode params
    constexpr uint32_t V1_DECODE_PARAM_NUM = 16;
    constexpr uint32_t V1_DECODE_DATA_NUM = 2;
    constexpr uint32_t S1_BASE_SIZE = 8;
    constexpr uint32_t TOPK_MAX_SIZE = 2048;           // matches SPARSE_LIMIT
    // NOTE(review): the accumulation is done in uint32_t; with current
    // constants and realistic core counts it stays well below 2^32, but a
    // size_t accumulator would be safer — confirm before raising tile sizes.
    uint32_t workspaceSize = ascendcPlatform.GetLibApiWorkSpaceSize();
    uint32_t mm1ResSize = M_BASE_SIZE * S2_BASE_SIZE;
    // matmul-1 result, double-buffered, per AIC core.
    workspaceSize += mm1ResSize * MM1_RES_ELEM_SIZE * DOUBLE_BUFFER * aicNum;
    // vector-1 top-k result buffers, per AIC core.
    workspaceSize += V1_DECODE_DATA_NUM * S1_BASE_SIZE * V1_RES_ELEM_TYPE * TOPK_MAX_SIZE * V1_RES_ELEM_SIZE * aicNum;
    // vector-1 decode parameter buffers, per AIC core.
    workspaceSize += V1_DECODE_DATA_NUM * S1_BASE_SIZE * V1_DECODE_PARAM_NUM * V1_DECODE_PARAM_ELEM_SIZE * aicNum;
    size_t *workSpaces = context_->GetWorkspaceSizes(1);
    workSpaces[0] = workspaceSize;

    // -------------set tilingdata-----------------
    tilingData_.set_bSize(tilingInfo->bSize);
    tilingData_.set_s2Size(tilingInfo->s2Size);
    tilingData_.set_s1Size(tilingInfo->s1Size);
    tilingData_.set_sparseCount(tilingInfo->sparseCount);
    tilingData_.set_gSize(tilingInfo->gSize);
    tilingData_.set_blockSize(tilingInfo->blockSize);
    tilingData_.set_maxBlockNumPerBatch(tilingInfo->maxBlockNumPerBatch);
    tilingData_.set_sparseMode(tilingInfo->sparseMode);
    tilingData_.set_usedCoreNum(blockDim);
    tilingData_.SaveToBuffer(context_->GetRawTilingData()->GetData(), context_->GetRawTilingData()->GetCapacity());
    context_->GetRawTilingData()->SetDataSize(tilingData_.GetDataSize());

    // -------------set tilingkey-----------------
    // DT_Q, DT_KV, DT_OUT, PAGE_ATTENTION, FLASH_DECODE, LAYOUT_T, KV_LAYOUT_T
    // (six values are actually encoded; they must match the kernel's template
    // parameter order in lightning_indexer.cpp).
    uint32_t inputQType = static_cast<uint32_t>(tilingInfo->inputQType);
    uint32_t inputKType = static_cast<uint32_t>(tilingInfo->inputKType);
    uint32_t outputType = static_cast<uint32_t>(tilingInfo->outputType);
    uint32_t pageAttentionFlag = static_cast<uint32_t>(tilingInfo->pageAttentionFlag);
    uint32_t inputQLayout = static_cast<uint32_t>(tilingInfo->inputQLayout);
    uint32_t inputKLayout = static_cast<uint32_t>(tilingInfo->inputKLayout);
    uint32_t tilingKey =
        GET_TPL_TILING_KEY(inputQType, inputKType, outputType, pageAttentionFlag, inputQLayout, inputKLayout);
    context_->SetTilingKey(tilingKey);

    return ge::GRAPH_SUCCESS;
}
|
||||
|
||||
// Tiling entry point registered for LightningIndexer: parse + validate all op
// parameters, then compute block dim / workspace / tiling data / tiling key.
ge::graphStatus TilingForLightningIndexer(gert::TilingContext *context)
{
    OPS_ERR_IF(context == nullptr, OPS_REPORT_VECTOR_INNER_ERR("LightningIndexer", "Tiling context is null."),
               return ge::GRAPH_FAILED);
    LITilingInfo liInfo;
    // Renamed from `LIInfoParser LIInfoParser(context)`: the local shadowed its
    // own class name, which is confusing and blocks a second use of the type.
    LIInfoParser parser(context);
    if (parser.ParseAndCheck(liInfo) != ge::GRAPH_SUCCESS) {
        return ge::GRAPH_FAILED;
    }
    LightningIndexerTiling liTiling(context);
    return liTiling.DoTiling(&liInfo);
}
|
||||
|
||||
IMPL_OP_OPTILING(LightningIndexer)
|
||||
.Tiling(TilingForLightningIndexer)
|
||||
.TilingParse<LICompileInfo>(TilingPrepareForLightningIndexer);
|
||||
|
||||
} // namespace optiling
|
||||
215
csrc/lightning_indexer/op_host/lightning_indexer_tiling.h
Normal file
215
csrc/lightning_indexer/op_host/lightning_indexer_tiling.h
Normal file
@@ -0,0 +1,215 @@
|
||||
/**
|
||||
* This program is free software, you can redistribute it and/or modify it.
|
||||
* Copyright (c) 2025 Huawei Technologies Co., Ltd.
|
||||
* This file is a part of the CANN Open Software.
|
||||
* Licensed under CANN Open Software License Agreement Version 2.0 (the "License").
|
||||
* Please refer to the License for details. You may not use this file except in compliance with the License.
|
||||
* THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE.
|
||||
* See LICENSE in the root of the software repository for the full text of the License.
|
||||
*/
|
||||
|
||||
/*!
|
||||
* \file lightning_indexer_tiling.h
|
||||
* \brief
|
||||
*/
|
||||
|
||||
#ifndef LIGHTNING_INDEXER_TILING_H_
|
||||
#define LIGHTNING_INDEXER_TILING_H_
|
||||
|
||||
#include "exe_graph/runtime/tiling_context.h"
|
||||
#include "tiling/platform/platform_ascendc.h"
|
||||
#include "register/op_def_registry.h"
|
||||
#include "register/tilingdata_base.h"
|
||||
#include "tiling/tiling_api.h"
|
||||
#include "error/ops_error.h"
|
||||
#include "platform/platform_info.h"
|
||||
|
||||
namespace optiling {
|
||||
|
||||
// Descriptor + storage shape for a required op input/output, as exposed by the
// GE runtime tiling context. Non-owning pointers; lifetime is the context's.
struct TilingRequiredParaInfo {
    const gert::CompileTimeTensorDesc *desc;
    const gert::StorageShape *shape;
};
|
||||
|
||||
// Descriptor + tensor for an optional op input; `tensor` is nullptr when the
// caller did not supply the input. Non-owning pointers.
struct TilingOptionalParaInfo {
    const gert::CompileTimeTensorDesc *desc;
    const gert::Tensor *tensor;
};
|
||||
|
||||
// Tensor layouts understood by the tiling parser. Mirrors LI_LAYOUT in
// lightning_indexer_common.h (BSND=0, TND=1, PA/BnBsND=2) — the numeric
// values feed the tiling key and must stay in sync.
enum class DataLayout : uint32_t {
    BSND = 0,  // batch, seq, heads, head_dim
    TND = 1,   // packed tokens (T = sum of seq lens), heads, head_dim
    BnBsND = 2 // paged KV cache: block_num, block_size, heads, head_dim
};
|
||||
|
||||
// Inputs Index
|
||||
constexpr uint32_t QUERY_INDEX = 0;
|
||||
constexpr uint32_t KEY_INDEX = 1;
|
||||
constexpr uint32_t WEIGTHS_INDEX = 2;
|
||||
constexpr uint32_t ACTUAL_SEQ_Q_INDEX = 3;
|
||||
constexpr uint32_t ACTUAL_SEQ_K_INDEX = 4;
|
||||
constexpr uint32_t BLOCK_TABLE_INDEX = 5;
|
||||
constexpr uint32_t LIGHTNING_INDEXER = 0;
|
||||
// Attributes Index
|
||||
constexpr uint32_t ATTR_QUERY_LAYOUT_INDEX = 0;
|
||||
constexpr uint32_t ATTR_KEY_LAYOUT_INDEX = 1;
|
||||
constexpr uint32_t ATTR_SPARSE_COUNT_INDEX = 2;
|
||||
constexpr uint32_t ATTR_SPARSE_MODE_INDEX = 3;
|
||||
// Dim Index
|
||||
constexpr uint32_t DIM_IDX_ONE = 1;
|
||||
constexpr uint32_t DIM_IDX_TWO = 2;
|
||||
constexpr uint32_t DIM_IDX_THREE = 3;
|
||||
// Dim Num
|
||||
constexpr uint32_t DIM_NUM_TWO = 2;
|
||||
constexpr uint32_t DIM_NUM_THREE = 3;
|
||||
constexpr uint32_t DIM_NUM_FOUR = 4;
|
||||
// Input Parameter Limit Constant
|
||||
constexpr uint32_t HEAD_DIM_LIMIT = 128;
|
||||
constexpr uint32_t SPARSE_LIMIT = 2048;
|
||||
constexpr uint32_t SPARSE_MODE_LOWER = 3;
|
||||
|
||||
BEGIN_TILING_DATA_DEF(LITilingData)
|
||||
TILING_DATA_FIELD_DEF(uint32_t, bSize)
|
||||
TILING_DATA_FIELD_DEF(uint32_t, n2Size)
|
||||
TILING_DATA_FIELD_DEF(uint32_t, gSize)
|
||||
TILING_DATA_FIELD_DEF(uint32_t, s1Size)
|
||||
TILING_DATA_FIELD_DEF(uint32_t, s2Size)
|
||||
TILING_DATA_FIELD_DEF(uint32_t, sparseCount)
|
||||
TILING_DATA_FIELD_DEF(uint32_t, usedCoreNum)
|
||||
TILING_DATA_FIELD_DEF(uint32_t, blockSize)
|
||||
TILING_DATA_FIELD_DEF(uint32_t, maxBlockNumPerBatch)
|
||||
TILING_DATA_FIELD_DEF(uint32_t, sparseMode)
|
||||
END_TILING_DATA_DEF
|
||||
REGISTER_TILING_DATA_CLASS(LightningIndexer, LITilingData)
|
||||
|
||||
struct LICompileInfo {};
|
||||
|
||||
// All raw parameter handles of the LightningIndexer op as fetched from the
// tiling context: required tensors, optional tensors (nullptr when absent),
// and attribute pointers. Non-owning; valid only while the context lives.
struct LiParaInfo {
    TilingRequiredParaInfo query = {nullptr, nullptr};
    TilingRequiredParaInfo key = {nullptr, nullptr};
    TilingRequiredParaInfo weights = {nullptr, nullptr};
    TilingOptionalParaInfo actualSeqLengthsQ = {nullptr, nullptr};
    TilingOptionalParaInfo actualSeqLengths = {nullptr, nullptr};
    TilingOptionalParaInfo blockTable = {nullptr, nullptr};
    // Output sparse_indices (named attenOut throughout the parser).
    TilingRequiredParaInfo attenOut = {nullptr, nullptr};

    // Attribute pointers (layout strings and scalar attrs).
    const char *layOut = nullptr;
    const char *layOutKey = nullptr;  // "PA_BSND" selects page attention
    const int32_t *blockSize = nullptr;
    const int32_t *sparseMode = nullptr;
    const int32_t *sparseCount = nullptr;
};
|
||||
|
||||
class LITilingInfo {
|
||||
public:
|
||||
const char *opName = nullptr;
|
||||
fe::PlatFormInfos *platformInfo = nullptr;
|
||||
LiParaInfo opParamInfo;
|
||||
// Base Param
|
||||
platform_ascendc::SocVersion socVersion = platform_ascendc::SocVersion::ASCEND910B;
|
||||
uint32_t bSize = 0;
|
||||
uint32_t n1Size = 0;
|
||||
uint32_t n2Size = 0;
|
||||
uint32_t s1Size = 0;
|
||||
int64_t s2Size = 0;
|
||||
uint32_t qkHeadDim = 0;
|
||||
uint32_t gSize = 0;
|
||||
// PageAttention
|
||||
bool pageAttentionFlag = false;
|
||||
int32_t blockSize = 0;
|
||||
uint32_t maxBlockNumPerBatch = 0;
|
||||
// Mask
|
||||
int32_t sparseMode = 0;
|
||||
// Others Flag
|
||||
uint32_t sparseCount = 0;
|
||||
// DType
|
||||
ge::DataType inputQType = ge::DT_FLOAT16;
|
||||
ge::DataType inputKType = ge::DT_FLOAT16;
|
||||
ge::DataType outputType = ge::DT_INT32;
|
||||
// Layout
|
||||
DataLayout inputQLayout = DataLayout::BSND;
|
||||
DataLayout inputKLayout = DataLayout::BnBsND;
|
||||
};
|
||||
|
||||
class LIInfoParser {
|
||||
public:
|
||||
explicit LIInfoParser(gert::TilingContext *context) : context_(context)
|
||||
{
|
||||
}
|
||||
~LIInfoParser() = default;
|
||||
|
||||
ge::graphStatus CheckRequiredInOutExistence() const;
|
||||
ge::graphStatus CheckRequiredAttrExistence() const;
|
||||
ge::graphStatus CheckRequiredParaExistence() const;
|
||||
ge::graphStatus GetActualSeqLenSize(uint32_t &size, const gert::Tensor *tensor,
|
||||
const std::string &actualSeqLenName);
|
||||
ge::graphStatus GetOpName();
|
||||
ge::graphStatus GetNpuInfo();
|
||||
void GetOptionalInputParaInfo();
|
||||
void GetInputParaInfo();
|
||||
void GetOutputParaInfo();
|
||||
ge::graphStatus GetAndCheckAttrParaInfo();
|
||||
ge::graphStatus GetOpParaInfo();
|
||||
ge::graphStatus ValidateInputShapesMatchQBsnd();
|
||||
ge::graphStatus ValidateInputShapesMatchQTnd();
|
||||
ge::graphStatus ValidateInputShapesMatch();
|
||||
ge::graphStatus GetAndCheckInOutDataType();
|
||||
ge::graphStatus GetBatchSize();
|
||||
ge::graphStatus GetHeadDim();
|
||||
ge::graphStatus GetS1Size();
|
||||
ge::graphStatus GetAndCheckOptionalInput();
|
||||
ge::graphStatus CheckShapeDim();
|
||||
ge::graphStatus GetAndCheckBlockSize();
|
||||
ge::graphStatus CheckBlockCount();
|
||||
ge::graphStatus GetS2SizeForPageAttention();
|
||||
ge::graphStatus GetS2Size();
|
||||
ge::graphStatus GetQueryKeyAndOutLayout();
|
||||
ge::graphStatus GetN1Size();
|
||||
ge::graphStatus GetAndCheckN2Size();
|
||||
ge::graphStatus GetGSize();
|
||||
ge::graphStatus GetAttenMaskInfo();
|
||||
ge::graphStatus GetActualSeqInfo();
|
||||
void GenerateInfo(LITilingInfo &liInfo);
|
||||
ge::graphStatus ParseAndCheck(LITilingInfo &liInfo);
|
||||
|
||||
public:
|
||||
gert::TilingContext *context_ = nullptr;
|
||||
const char *opName_;
|
||||
fe::PlatFormInfos *platformInfo_;
|
||||
LiParaInfo opParamInfo_;
|
||||
|
||||
// BaseParams
|
||||
uint32_t bSize_ = 0;
|
||||
uint32_t n1Size_ = 0;
|
||||
uint32_t n2Size_ = 0;
|
||||
uint32_t gSize_ = 0;
|
||||
uint32_t s1Size_ = 0;
|
||||
int64_t s2Size_ = 0;
|
||||
uint32_t headDim_ = 0;
|
||||
// Layout
|
||||
DataLayout qLayout_ = DataLayout::BSND;
|
||||
DataLayout kLayout_ = DataLayout::BnBsND;
|
||||
// PageAttention
|
||||
uint32_t maxBlockNumPerBatch_ = 0;
|
||||
int32_t blockSize_ = 0;
|
||||
platform_ascendc::SocVersion socVersion_ = platform_ascendc::SocVersion::ASCEND910B;
|
||||
ge::DataType inputQType_ = ge::DT_FLOAT16;
|
||||
ge::DataType inputKType_ = ge::DT_FLOAT16;
|
||||
ge::DataType weightsType_ = ge::DT_FLOAT16;
|
||||
ge::DataType blockTableType_ = ge::DT_FLOAT16;
|
||||
ge::DataType inputKRopeType_ = ge::DT_FLOAT16;
|
||||
ge::DataType outputType_ = ge::DT_FLOAT16;
|
||||
};
|
||||
|
||||
// Final tiling stage: given a fully parsed/validated LITilingInfo, computes
// block dim, workspace sizes, serialized tiling data and the tiling key, and
// writes them into the tiling context.
class LightningIndexerTiling {
public:
    explicit LightningIndexerTiling(gert::TilingContext *context) : context_(context){};
    ge::graphStatus DoTiling(LITilingInfo *tilingInfo);

private:
    gert::TilingContext *context_ = nullptr;  // framework-owned, non-owning here
    LITilingData tilingData_;                 // serialized into the raw tiling buffer
};
|
||||
|
||||
} // namespace optiling
|
||||
#endif // LIGHTNING_INDEXER_TILING_H_
|
||||
58
csrc/lightning_indexer/op_kernel/lightning_indexer.cpp
Normal file
58
csrc/lightning_indexer/op_kernel/lightning_indexer.cpp
Normal file
@@ -0,0 +1,58 @@
|
||||
/**
|
||||
* This program is free software, you can redistribute it and/or modify it.
|
||||
* Copyright (c) 2025 Huawei Technologies Co., Ltd.
|
||||
* This file is a part of the CANN Open Software.
|
||||
* Licensed under CANN Open Software License Agreement Version 2.0 (the "License").
|
||||
* Please refer to the License for details. You may not use this file except in compliance with the License.
|
||||
* THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE.
|
||||
* See LICENSE in the root of the software repository for the full text of the License.
|
||||
*/
|
||||
|
||||
/*!
|
||||
* \file lightning_indexer.cpp
|
||||
* \brief
|
||||
*/
|
||||
|
||||
#include "kernel_operator.h"
|
||||
#include "lib/matmul_intf.h"
|
||||
#include "lightning_indexer_template_tiling_key.h"
|
||||
#include "lightning_indexer_kernel.h"
|
||||
|
||||
using namespace LIKernel;
|
||||
|
||||
#define INVOKE_LI_NO_KFC_OP_IMPL(templateClass, ...) \
|
||||
do { \
|
||||
templateClass<LIType<__VA_ARGS__>> op; \
|
||||
LI_COPY_TILING_DATA(LITilingData, tiling); \
|
||||
op.Init(query, key, weights, actualSeqLengthsQ, actualSeqLengths, blocktable, sparseIndices, user, \
|
||||
tiling_data, &tPipe); \
|
||||
op.Process(); \
|
||||
} while (0)
|
||||
|
||||
#define LI_COPY_TILING_DATA(tilingDataStruct, tiling) \
|
||||
GET_TILING_DATA_WITH_STRUCT(tilingDataStruct, tiling_data_in, tiling); \
|
||||
const tilingDataStruct *__restrict tiling_data = &tiling_data_in;
|
||||
|
||||
|
||||
// Device entry point for the LightningIndexer op. Template parameters mirror
// the tiling key produced by GET_TPL_TILING_KEY (query/key/output dtypes,
// page-attention flag, query and key layouts) — keep the order in sync with
// the tiling code in op_host.
template <int DT_Q, int DT_K, int DT_OUT, int PAGE_ATTENTION, int LAYOUT_T, int K_LAYOUT_T>
__global__ __aicore__ void lightning_indexer(__gm__ uint8_t *query, __gm__ uint8_t *key, __gm__ uint8_t *weights,
                                             __gm__ uint8_t *actualSeqLengthsQ, __gm__ uint8_t *actualSeqLengths,
                                             __gm__ uint8_t *blocktable, __gm__ uint8_t *sparseIndices,
                                             __gm__ uint8_t *workspace, __gm__ uint8_t *tiling)
{
// 310/200-class SoCs are unsupported: compile to an empty kernel there.
#if (__CCE_AICORE__ == 310) || (defined __DAV_310R6__) || (__CCE_AICORE__ == 200)

#else
    TPipe tPipe;
    // Skip the library-reserved region at the front of the workspace.
    __gm__ uint8_t *user = GetUserWorkspace(workspace);
    // Mixed cube/vector kernel with a 1:2 AIC:AIV ratio.
    KERNEL_TASK_TYPE_DEFAULT(KERNEL_TYPE_MIX_AIC_1_2);

    // Dispatch on the dtype part of the tiling key. Only two variants exist:
    // fp16 in / int32 indices out, and (the fallback) bf16 in / int32 out.
    // LI_LAYOUT(x) is a functional cast from the int template arg to the enum.
    if constexpr (DT_Q == LI_TPL_FP16 && DT_K == LI_TPL_FP16 && DT_OUT == LI_TPL_INT32) {
        INVOKE_LI_NO_KFC_OP_IMPL(LIPreload, half, half, int32_t, PAGE_ATTENTION,
                                 LI_LAYOUT(LAYOUT_T), LI_LAYOUT(K_LAYOUT_T));
    } else {
        INVOKE_LI_NO_KFC_OP_IMPL(LIPreload, bfloat16_t, bfloat16_t, int32_t, PAGE_ATTENTION,
                                 LI_LAYOUT(LAYOUT_T), LI_LAYOUT(K_LAYOUT_T));
    }
#endif
}
|
||||
135
csrc/lightning_indexer/op_kernel/lightning_indexer_common.h
Normal file
135
csrc/lightning_indexer/op_kernel/lightning_indexer_common.h
Normal file
@@ -0,0 +1,135 @@
|
||||
/**
|
||||
* This program is free software, you can redistribute it and/or modify it.
|
||||
* Copyright (c) 2025 Huawei Technologies Co., Ltd.
|
||||
* This file is a part of the CANN Open Software.
|
||||
* Licensed under CANN Open Software License Agreement Version 2.0 (the "License").
|
||||
* Please refer to the License for details. You may not use this file except in compliance with the License.
|
||||
* THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE.
|
||||
* See LICENSE in the root of the software repository for the full text of the License.
|
||||
*/
|
||||
|
||||
/*!
|
||||
* \file lightning_indexer_common.h
|
||||
* \brief
|
||||
*/
|
||||
#ifndef LIGHTNING_INDEXER_COMMON_H
|
||||
#define LIGHTNING_INDEXER_COMMON_H
|
||||
|
||||
namespace LICommon {
|
||||
// Kernel-side tensor layouts. Values must match optiling::DataLayout
// (BSND=0, TND=1, PA_BSND/BnBsND=2), since the tiling key transports the
// numeric value across the host/device boundary.
enum class LI_LAYOUT {
    BSND = 0,   // batch, seq, heads, head_dim
    TND = 1,    // packed tokens, heads, head_dim
    PA_BSND = 2 // paged KV cache layout
};
|
||||
|
||||
// Compile-time trait bundle selecting one kernel instantiation: query/key/
// output element types plus page-attention flag and the two tensor layouts.
// Passed as the single template argument of LIPreload.
template <typename Q_T, typename K_T, typename OUT_T, const bool PAGE_ATTENTION = false,
          LI_LAYOUT LAYOUT_T = LI_LAYOUT::BSND, LI_LAYOUT K_LAYOUT_T = LI_LAYOUT::PA_BSND, typename... Args>
struct LIType {
    using queryType = Q_T;
    using keyType = K_T;
    using outputType = OUT_T;
    static constexpr bool pageAttention = PAGE_ATTENTION;
    static constexpr LI_LAYOUT layout = LAYOUT_T;       // query/output layout
    static constexpr LI_LAYOUT keyLayout = K_LAYOUT_T;  // key-cache layout
};
|
||||
|
||||
// Per-iteration state of the kernel's main loop: current loop indices, the
// actual (tail-clipped) tile sizes, global-memory offsets of the tensors for
// this tile, and loop-boundary flags.
// NOTE(review): fields without initializers are expected to be assigned by the
// loop-setup code before first read — confirm in lightning_indexer_kernel.h.
struct RunInfo {
    uint32_t loop;
    uint32_t bN2Idx;   // combined batch*N2 index
    uint32_t bIdx;     // batch index
    uint32_t n2Idx = 0;
    uint32_t gS1Idx;   // combined G*S1 tile index
    uint32_t s2Idx;    // S2 tile index

    uint32_t actS1Size = 1;  // actual S1 length of this batch
    uint32_t actS2Size = 1;  // actual S2 length of this batch
    uint32_t actMBaseSize;
    uint32_t actualSingleProcessSInnerSize;
    uint32_t actualSingleProcessSInnerSizeAlign;

    // Element offsets into the corresponding global tensors for this tile.
    uint64_t tensorQueryOffset;
    uint64_t tensorKeyOffset;
    uint64_t tensorWeightsOffset;
    uint64_t indiceOutOffset;

    bool isFirstS2InnerLoop;
    bool isLastS2InnerLoop;
    bool isAllLoopEnd = false;
};
|
||||
|
||||
// Loop-invariant kernel configuration, filled once from tiling data at Init
// time. The static members are shared buffer-size constants.
// NOTE(review): kHeadNum/headDim/sparseCount/outputLayout carry no default
// initializer; LIPreload declares the instance as `constInfo{}` which
// value-initializes them to zero — keep that pattern when adding usages.
struct ConstInfo {
    static constexpr uint32_t FIA_SYNC_MODE2 = 2;
    static constexpr uint32_t BUFFER_SIZE_BYTE_32B = 32;
    static constexpr uint32_t BUFFER_SIZE_BYTE_64B = 64;
    static constexpr uint32_t BUFFER_SIZE_BYTE_256B = 256;
    static constexpr uint32_t BUFFER_SIZE_BYTE_512B = 512;
    static constexpr uint32_t BUFFER_SIZE_BYTE_1K = 1024;
    static constexpr uint32_t BUFFER_SIZE_BYTE_2K = 2048;
    static constexpr uint32_t BUFFER_SIZE_BYTE_4K = 4096;
    static constexpr uint32_t BUFFER_SIZE_BYTE_8K = 8192;
    static constexpr uint32_t BUFFER_SIZE_BYTE_16K = 16384;
    static constexpr uint32_t BUFFER_SIZE_BYTE_32K = 32768;
    static constexpr int INVALID_IDX = -1;

    // Cross-core sync flag ids (cube<->vector).
    uint32_t syncC1V1 = 0U;
    uint32_t syncV1C1 = 0U;

    // Base tile sizes (note: ULL suffixes here are harmlessly narrowed to u32).
    uint32_t mBaseSize = 1ULL;
    uint32_t s1BaseSize = 1ULL;
    uint32_t s2BaseSize = 1ULL;

    // Problem dimensions from tiling data.
    uint64_t batchSize = 0ULL;
    uint64_t gSize = 0ULL;
    uint64_t qHeadNum = 0ULL;
    uint64_t kHeadNum;
    uint64_t headDim;
    uint64_t sparseCount;
    uint64_t kSeqSize = 0ULL;
    uint64_t qSeqSize = 1ULL;
    // Page-attention geometry.
    uint32_t kCacheBlockSize = 0;
    uint32_t maxBlockNumPerBatch = 0;
    LI_LAYOUT outputLayout;
    bool attenMaskFlag = false;

    // Sizes of the optional actual-seq-length tensors (0 when absent).
    uint32_t actualLenQDims = 0U;
    uint32_t actualLenDims = 0U;
    // Whether seq-length tensors hold cumulative (prefix-sum) values.
    bool isAccumSeqS1 = false;
    bool isAccumSeqS2 = false;
};
|
||||
|
||||
// Per-core work assignment produced by SplitCore: half-open index ranges over
// the S2, batch*N2 and G*S1 loop dimensions, plus the flash/load-decode flag.
struct SplitCoreInfo {
    uint32_t s2Start = 0U;
    uint32_t s2End = 0U;
    uint32_t bN2Start = 0U;
    uint32_t bN2End = 0U;
    uint32_t gS1Start = 0U;
    uint32_t gS1End = 0U;
    bool isLD = false;  // this core participates in the LD (decode) path
};
|
||||
|
||||
template <typename T>
|
||||
__aicore__ inline T Align(T num, T rnd)
|
||||
{
|
||||
return (((rnd) == 0) ? 0 : (((num) + (rnd)-1) / (rnd) * (rnd)));
|
||||
}
|
||||
|
||||
template <typename T1, typename T2>
|
||||
__aicore__ inline T1 Min(T1 a, T2 b)
|
||||
{
|
||||
return (a > b) ? (b) : (a);
|
||||
}
|
||||
|
||||
template <typename T1, typename T2>
|
||||
__aicore__ inline T1 Max(T1 a, T2 b)
|
||||
{
|
||||
return (a > b) ? (a) : (b);
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
__aicore__ inline T CeilDiv(T num, T rnd)
|
||||
{
|
||||
return (((rnd) == 0) ? 0 : (((num) + (rnd)-1) / (rnd)));
|
||||
}
|
||||
} // namespace LICommon
|
||||
|
||||
#endif // LIGHTNING_INDEXER_COMMON_H
|
||||
623
csrc/lightning_indexer/op_kernel/lightning_indexer_kernel.h
Normal file
623
csrc/lightning_indexer/op_kernel/lightning_indexer_kernel.h
Normal file
@@ -0,0 +1,623 @@
|
||||
/**
|
||||
* This program is free software, you can redistribute it and/or modify it.
|
||||
* Copyright (c) 2025 Huawei Technologies Co., Ltd.
|
||||
* This file is a part of the CANN Open Software.
|
||||
* Licensed under CANN Open Software License Agreement Version 2.0 (the "License").
|
||||
* Please refer to the License for details. You may not use this file except in compliance with the License.
|
||||
* THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE.
|
||||
* See LICENSE in the root of the software repository for the full text of the License.
|
||||
*/
|
||||
|
||||
/*!
|
||||
* \file lightning_indexer_kernel.h
|
||||
* \brief
|
||||
*/
|
||||
|
||||
#ifndef LIGHTNING_INDEXER_KERNEL_H
|
||||
#define LIGHTNING_INDEXER_KERNEL_H
|
||||
|
||||
#include "kernel_operator.h"
|
||||
#include "kernel_operator_list_tensor_intf.h"
|
||||
#include "kernel_tiling/kernel_tiling.h"
|
||||
#include "lib/matmul_intf.h"
|
||||
#include "lib/matrix/matmul/tiling.h"
|
||||
#include "lightning_indexer_common.h"
|
||||
#include "lightning_indexer_service_vector.h"
|
||||
#include "lightning_indexer_service_cube.h"
|
||||
|
||||
namespace LIKernel {
|
||||
using namespace LICommon;
|
||||
using namespace LIServiceVec;
|
||||
using namespace matmul;
|
||||
using AscendC::CacheMode;
|
||||
using AscendC::CrossCoreSetFlag;
|
||||
using AscendC::CrossCoreWaitFlag;
|
||||
|
||||
// Scratch loop state used while iterating batches/tiles: current indices,
// per-batch actual sizes, loop ends, and tail-tile sizes.
// Fixed literal suffixes: the previous 1ULL/0ULL initializers on uint32_t
// members forced a silent uint64->uint32 narrowing (values unchanged).
struct TempLoopInfo {
    uint32_t bN2Idx = 0;           // combined batch*N2 index
    uint32_t bIdx = 0U;            // batch index
    uint32_t n2Idx = 0U;           // key-head index
    uint32_t gS1Idx = 0U;          // combined G*S1 tile index
    uint32_t gS1LoopEnd = 0U;
    uint32_t s2LoopEnd = 0U;
    uint32_t actS1Size = 1U;       // actual S1 of the current batch
    uint32_t actS2Size = 0U;       // actual S2 of the current batch
    bool curActSeqLenIsZero = false;
    bool needDealActS1LessThanS1 = false;
    uint32_t actMBaseSize = 0U;
    uint32_t mBasicSizeTail = 0U;  // tail size of the last M tile
    uint32_t s2BasicSizeTail = 0U; // tail size of the last S2 tile
};
|
||||
|
||||
// Top-level driver for the lightning-indexer "preload" kernel. It splits the
// (batch, kv-head, s1-tile, s2-tile) base-block space across AI cores and
// drives the cube service (matmul) and the vector service with cross-core
// handshakes; results go to the sparseIndices output.
template <typename LIT>
class LIPreload {
public:
    __aicore__ inline LIPreload(){};
    // Bind GM inputs/outputs/workspace, read tiling, and split work across cores.
    __aicore__ inline void Init(__gm__ uint8_t *query, __gm__ uint8_t *key, __gm__ uint8_t *weights,
                                __gm__ uint8_t *actualSeqLengthsQ, __gm__ uint8_t *actualSeqLengths,
                                __gm__ uint8_t *blockTable, __gm__ uint8_t *sparseIndices, __gm__ uint8_t *workspace,
                                const LITilingData *__restrict tiling, TPipe *tPipe);
    // Run the kernel: main block loop + decode tail, or invalid-fill when no work.
    __aicore__ inline void Process();

    using Q_T = typename LIT::queryType;
    using K_T = typename LIT::keyType;
    using OUT_T = typename LIT::outputType;
    static constexpr bool PAGE_ATTENTION = LIT::pageAttention;    // key cache is paged (block table)
    static constexpr LI_LAYOUT LAYOUT_T = LIT::layout;            // query/output layout (TND or BSND)
    static constexpr LI_LAYOUT K_LAYOUT_T = LIT::keyLayout;       // key layout

    using MM1_OUT_T = float;  // matmul-1 results are staged in fp32

    LIMatmul<LIT> matmulService;  // cube-core side (q x k^T)
    LIVector<LIT> vectorService;  // vector-core side (weighting / selection / output)

    // cross-core sync flag ids between cube (C1) and vector (V1)
    static constexpr uint32_t SYNC_C1_V1_FLAG = 4;
    static constexpr uint32_t SYNC_V1_C1_FLAG = 5;

    static constexpr uint32_t M_BASE_SIZE = 512;   // g*s1 (M) tile size
    static constexpr uint32_t S2_BASE_SIZE = 512;  // key-sequence tile size
    static constexpr uint32_t HEAD_DIM = 128;      // fixed head dimension
    static constexpr uint32_t K_HEAD_NUM = 1;      // single kv head
    static constexpr uint32_t GM_ALIGN_BYTES = 512;

    static constexpr int64_t LD_PREFETCH_LEN = 2;  // icache preload length for the decode phase
    // for workspace double
    static constexpr uint32_t WS_DOBULE = 2;  // NOTE(review): name looks like a typo of "WS_DOUBLE"

protected:
    TPipe *pipe = nullptr;

    // offset
    // Per-(b, n2) GM base offsets, refreshed on the first s2 tile of each loop.
    uint64_t queryCoreOffset = 0ULL;
    uint64_t keyCoreOffset = 0ULL;
    uint64_t weightsCoreOffset = 0ULL;
    uint64_t indiceOutCoreOffset = 0ULL;

    GlobalTensor<Q_T> queryGm;
    GlobalTensor<K_T> keyGm;
    GlobalTensor<K_T> weightsGm;

    GlobalTensor<int32_t> indiceOutGm;    // sparse indices output
    GlobalTensor<int32_t> blockTableGm;   // page-attention block table

    GlobalTensor<uint32_t> actualSeqLengthsGmQ;  // per-batch query lengths (or prefix sums)
    GlobalTensor<uint32_t> actualSeqLengthsGm;   // per-batch key lengths (or prefix sums)
    // workspace
    GlobalTensor<MM1_OUT_T> mm1ResGm;   // cube -> vector staging of q*k^T tiles
    GlobalTensor<float> vec1ResGm;
    GlobalTensor<int64_t> vec1ParamGm;

    // AIC / AIV kernel info
    uint32_t tmpBlockIdx = 0U;   // physical block index of this core
    uint32_t aiCoreIdx = 0U;     // logical core index (2 AIV share 1 AIC index)
    uint32_t usedCoreNum = 0U;

    LICommon::ConstInfo constInfo{};
    TempLoopInfo tempLoopInfo{};
    LICommon::SplitCoreInfo splitCoreInfo{};

    // ================================Init functions==================================
    __aicore__ inline void InitTilingData(const LITilingData *__restrict tilingData);
    __aicore__ inline void InitBuffers();
    __aicore__ inline void InitActualSeqLen(__gm__ uint8_t *actualSeqLengthsQ, __gm__ uint8_t *actualSeqLengths);
    // ================================Split Core================================
    __aicore__ inline void SplitCore(uint32_t curCoreIdx, uint32_t &coreNum, LICommon::SplitCoreInfo &info);
    __aicore__ inline uint32_t GetS2BaseBlockNumOnMask(uint32_t s1gIdx, uint32_t actS1Size, uint32_t actS2Size);
    __aicore__ inline uint32_t GetTotalBaseBlockNum();
    // ================================Process functions================================
    __aicore__ inline void ProcessMain();
    __aicore__ inline void ProcessBaseBlock(uint32_t loop, uint64_t s2LoopIdx, LICommon::RunInfo &runInfo);
    __aicore__ inline void ProcessDecode();
    __aicore__ inline void ProcessInvalid();
    // ================================Params Calc=====================================
    __aicore__ inline void CalcGS1LoopParams(uint32_t bN2Idx);
    __aicore__ inline void GetBN2Idx(uint32_t bN2Idx);
    __aicore__ inline uint32_t GetActualSeqLen(uint32_t bIdx, uint32_t actualLenDims, bool isAccumSeq,
                                               GlobalTensor<uint32_t> &actualSeqLengthsGm, uint32_t defaultSeqLen);
    __aicore__ inline void GetS1S2ActualSeqLen(uint32_t bIdx, uint32_t &actS1Size, uint32_t &actS2Size);
    __aicore__ inline void CalcS2LoopParams(uint32_t bN2LoopIdx, uint32_t gS1LoopIdx);
    __aicore__ inline void CalcRunInfo(uint32_t loop, uint32_t s2LoopIdx, LICommon::RunInfo &runInfo);
    __aicore__ inline void DealActSeqLenIsZero(uint32_t bIdx, uint32_t n2Idx, uint32_t s1Start);
};
|
||||
|
||||
// Copy host-side tiling results into kernel-wide constants (constInfo).
template <typename LIT>
__aicore__ inline void LIPreload<LIT>::InitTilingData(const LITilingData *__restrict tilingData)
{
    usedCoreNum = tilingData->usedCoreNum;
    constInfo.batchSize = tilingData->bSize;
    // Single kv head, so the query head count equals the group size.
    constInfo.qHeadNum = constInfo.gSize = tilingData->gSize;
    constInfo.kSeqSize = tilingData->s2Size;
    constInfo.qSeqSize = tilingData->s1Size;
    // sparseMode == 3 enables the causal attention mask.
    constInfo.attenMaskFlag = (tilingData->sparseMode == 3);
    constInfo.kCacheBlockSize = tilingData->blockSize;
    constInfo.maxBlockNumPerBatch = tilingData->maxBlockNumPerBatch;
    constInfo.sparseCount = tilingData->sparseCount;
    constInfo.outputLayout = LAYOUT_T;
    // TND layouts store sequence lengths as prefix sums.
    if (LAYOUT_T == LI_LAYOUT::TND) {
        constInfo.isAccumSeqS1 = true;
    }
    if (K_LAYOUT_T == LI_LAYOUT::TND) {
        constInfo.isAccumSeqS2 = true;
    }

    constInfo.kHeadNum = K_HEAD_NUM;
    constInfo.headDim = HEAD_DIM;

    constInfo.mBaseSize = M_BASE_SIZE;
    constInfo.s2BaseSize = S2_BASE_SIZE;
    // s1 rows covered by one M tile (ceil of M / group size).
    constInfo.s1BaseSize = (constInfo.mBaseSize + constInfo.gSize - 1) / constInfo.gSize;
}
|
||||
|
||||
// Allocate local (UB/L1/L0) buffers for whichever service runs on this core
// type: vector buffers on AIV cores, matmul buffers on cube cores.
template <typename LIT>
__aicore__ inline void LIPreload<LIT>::InitBuffers()
{
    if ASCEND_IS_AIV {
        vectorService.InitBuffers(pipe);
    } else {
        matmulService.InitBuffers(pipe);
    }
}
|
||||
|
||||
template <typename LIT>
|
||||
__aicore__ inline void LIPreload<LIT>::InitActualSeqLen(__gm__ uint8_t *actualSeqLengthsQ,
|
||||
__gm__ uint8_t *actualSeqLengths)
|
||||
{
|
||||
if (actualSeqLengthsQ == nullptr) {
|
||||
constInfo.actualLenQDims = 0;
|
||||
} else {
|
||||
constInfo.actualLenQDims = constInfo.batchSize;
|
||||
actualSeqLengthsGmQ.SetGlobalBuffer((__gm__ uint32_t *)actualSeqLengthsQ, constInfo.actualLenQDims);
|
||||
}
|
||||
if (actualSeqLengths == nullptr) {
|
||||
constInfo.actualLenDims = 0;
|
||||
} else {
|
||||
constInfo.actualLenDims = constInfo.batchSize;
|
||||
actualSeqLengthsGm.SetGlobalBuffer((__gm__ uint32_t *)actualSeqLengths, constInfo.actualLenDims);
|
||||
}
|
||||
}
|
||||
|
||||
template <typename LIT>
|
||||
__aicore__ inline uint32_t LIPreload<LIT>::GetActualSeqLen(uint32_t bIdx, uint32_t actualLenDims, bool isAccumSeq,
|
||||
GlobalTensor<uint32_t> &actualSeqLengthsGm,
|
||||
uint32_t defaultSeqLen)
|
||||
{
|
||||
if (actualLenDims == 0) {
|
||||
return defaultSeqLen;
|
||||
} else if (isAccumSeq && bIdx > 0) {
|
||||
return actualSeqLengthsGm.GetValue(bIdx) - actualSeqLengthsGm.GetValue(bIdx - 1);
|
||||
} else {
|
||||
return actualSeqLengthsGm.GetValue(bIdx);
|
||||
}
|
||||
}
|
||||
|
||||
template <typename LIT>
|
||||
__aicore__ inline void LIPreload<LIT>::GetS1S2ActualSeqLen(uint32_t bIdx, uint32_t &actS1Size, uint32_t &actS2Size)
|
||||
{
|
||||
actS1Size = GetActualSeqLen(bIdx, constInfo.actualLenQDims, constInfo.isAccumSeqS1, actualSeqLengthsGmQ,
|
||||
constInfo.qSeqSize);
|
||||
actS2Size =
|
||||
GetActualSeqLen(bIdx, constInfo.actualLenDims, constInfo.isAccumSeqS2, actualSeqLengthsGm, constInfo.kSeqSize);
|
||||
}
|
||||
|
||||
// Number of s2 base tiles containing any unmasked element for the given s1g
// tile row under the causal mask (sparseMode == 3).
template <typename LIT>
__aicore__ inline uint32_t LIPreload<LIT>::GetS2BaseBlockNumOnMask(uint32_t s1gIdx, uint32_t actS1Size,
                                                                   uint32_t actS2Size)
{
    if (actS2Size == 0) {
        return 0;
    }
    uint32_t s1Offset = constInfo.s1BaseSize * s1gIdx;  // first s1 row covered by this tile
    // Causal alignment: sequences are right-aligned, so row i may attend keys
    // up to i + (actS2 - actS1). Signed math because actS2 may be < actS1.
    int32_t validS2LenBase = static_cast<int32_t>(actS2Size) - static_cast<int32_t>(actS1Size);
    // Valid key length for the LAST row of this s1 tile (widest row).
    int32_t validS2Len = s1Offset + validS2LenBase + constInfo.s1BaseSize;
    validS2Len = Min(validS2Len, static_cast<int32_t>(actS2Size));
    // Keep at least one element so every row gets processed.
    validS2Len = Max(validS2Len, 1);
    return (validS2Len + constInfo.s2BaseSize - 1) / constInfo.s2BaseSize;
}
|
||||
|
||||
template <typename LIT>
|
||||
__aicore__ inline uint32_t LIPreload<LIT>::GetTotalBaseBlockNum()
|
||||
{
|
||||
uint32_t totalBlockNum = 0;
|
||||
uint32_t actS1Size, actS2Size;
|
||||
uint32_t s1GBaseNum, s2BaseNum;
|
||||
for (uint32_t bIdx = 0; bIdx < constInfo.batchSize; bIdx++) {
|
||||
GetS1S2ActualSeqLen(bIdx, actS1Size, actS2Size);
|
||||
s1GBaseNum = CeilDiv(actS1Size, constInfo.s1BaseSize);
|
||||
if (!constInfo.attenMaskFlag) {
|
||||
s2BaseNum = CeilDiv(actS2Size, constInfo.s2BaseSize);
|
||||
totalBlockNum += s1GBaseNum * s2BaseNum * constInfo.kHeadNum;
|
||||
continue;
|
||||
}
|
||||
for (uint32_t s1gIdx = 0; s1gIdx < s1GBaseNum; s1gIdx++) {
|
||||
s2BaseNum = GetS2BaseBlockNumOnMask(s1gIdx, actS1Size, actS2Size);
|
||||
totalBlockNum += s2BaseNum * constInfo.kHeadNum;
|
||||
}
|
||||
}
|
||||
return totalBlockNum;
|
||||
}
|
||||
|
||||
template <typename LIT>
|
||||
__aicore__ void inline LIPreload<LIT>::SplitCore(uint32_t curCoreIdx, uint32_t &coreNum, LICommon::SplitCoreInfo &info)
|
||||
{
|
||||
uint32_t totalBlockNum = GetTotalBaseBlockNum();
|
||||
uint32_t minBlockPerCore = totalBlockNum / coreNum;
|
||||
uint32_t deal1MoreBlockCoreNum = totalBlockNum % coreNum;
|
||||
uint32_t coreIdx = 0;
|
||||
uint32_t lastGS1RemainBlockCnt = 0;
|
||||
uint32_t coreDealBlockCnt = coreIdx < deal1MoreBlockCoreNum ? minBlockPerCore + 1 : minBlockPerCore;
|
||||
coreNum = minBlockPerCore == 0 ? deal1MoreBlockCoreNum : coreNum;
|
||||
|
||||
bool findLastCoreEnd = true;
|
||||
uint32_t actS1Size, actS2Size;
|
||||
uint32_t s1GBaseNum, s2BaseNum;
|
||||
for (uint32_t bN2Idx = 0; bN2Idx < constInfo.batchSize * constInfo.kHeadNum; bN2Idx++) {
|
||||
uint32_t bIdx = bN2Idx / constInfo.kHeadNum;
|
||||
if (bN2Idx % constInfo.kHeadNum == 0) {
|
||||
GetS1S2ActualSeqLen(bIdx, actS1Size, actS2Size);
|
||||
s1GBaseNum = CeilDiv(actS1Size, constInfo.s1BaseSize);
|
||||
s2BaseNum = CeilDiv(actS2Size, constInfo.s2BaseSize);
|
||||
}
|
||||
if constexpr (LAYOUT_T == LI_LAYOUT::BSND) {
|
||||
if (findLastCoreEnd && (s1GBaseNum == 0U || s2BaseNum == 0U)) {
|
||||
info.bN2Start = bN2Idx;
|
||||
info.gS1Start = 0;
|
||||
info.s2Start = 0;
|
||||
findLastCoreEnd = false;
|
||||
}
|
||||
}
|
||||
for (uint32_t gS1Idx = 0; gS1Idx < s1GBaseNum; gS1Idx++) {
|
||||
if (constInfo.attenMaskFlag) {
|
||||
s2BaseNum = GetS2BaseBlockNumOnMask(gS1Idx, actS1Size, actS2Size);
|
||||
}
|
||||
if (findLastCoreEnd && s2BaseNum == 0U) {
|
||||
info.bN2Start = bN2Idx;
|
||||
info.gS1Start = gS1Idx;
|
||||
info.s2Start = 0;
|
||||
findLastCoreEnd = false;
|
||||
}
|
||||
for (uint32_t s2Idx = 0; s2Idx < s2BaseNum;) {
|
||||
if (findLastCoreEnd) {
|
||||
info.bN2Start = bN2Idx;
|
||||
info.gS1Start = gS1Idx;
|
||||
info.s2Start = s2Idx;
|
||||
findLastCoreEnd = false;
|
||||
}
|
||||
uint32_t s2RemainBaseNum = s2BaseNum - s2Idx;
|
||||
if (lastGS1RemainBlockCnt + s2RemainBaseNum >= coreDealBlockCnt) {
|
||||
info.bN2End = bN2Idx;
|
||||
info.gS1End = gS1Idx;
|
||||
info.s2End = s2Idx + coreDealBlockCnt - lastGS1RemainBlockCnt - 1;
|
||||
|
||||
if (coreIdx == curCoreIdx) {
|
||||
if (s2Idx == 0 && info.s2End + 1 < s2BaseNum) {
|
||||
info.isLD = true;
|
||||
}
|
||||
if (coreIdx == coreNum - 1 && info.bN2End != constInfo.batchSize -1) {
|
||||
info.bN2End = constInfo.batchSize -1;
|
||||
info.gS1End = 0;
|
||||
info.s2End = 0;
|
||||
}
|
||||
return;
|
||||
}
|
||||
coreIdx++;
|
||||
findLastCoreEnd = true;
|
||||
s2Idx = info.s2End + 1;
|
||||
lastGS1RemainBlockCnt = 0;
|
||||
coreDealBlockCnt = coreIdx < deal1MoreBlockCoreNum ? minBlockPerCore + 1 : minBlockPerCore;
|
||||
} else {
|
||||
lastGS1RemainBlockCnt += s2RemainBaseNum;
|
||||
break;
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Pad the sparse-indices output with invalid entries for rows that produce no
// result: either the whole batch (s1Start == 0 when actS1/actS2 is zero) or
// the BSND tail rows beyond actS1Size. Runs on vector cores only.
template <typename LIT>
__aicore__ inline void LIPreload<LIT>::DealActSeqLenIsZero(uint32_t bIdx, uint32_t n2Idx, uint32_t s1Start)
{
    if ASCEND_IS_AIV {
        if (constInfo.outputLayout == LI_LAYOUT::TND) {
            // NOTE(review): tSize is computed but never used below — candidate for removal.
            uint32_t tSize = actualSeqLengthsGmQ.GetValue(constInfo.batchSize - 1);
            // Token base of this batch in the flattened T dimension (prefix sum).
            uint32_t tBase = bIdx == 0 ? 0 : actualSeqLengthsGmQ.GetValue(bIdx - 1);
            uint32_t s1Count = tempLoopInfo.actS1Size;

            for (uint32_t s1Idx = s1Start; s1Idx < s1Count; s1Idx++) {
                // T,N2,K layout offset of the row to invalidate.
                uint64_t indiceOutOffset =
                    (tBase + s1Idx) * constInfo.kHeadNum * constInfo.sparseCount +
                    n2Idx * constInfo.sparseCount;
                vectorService.CleanInvalidOutput(indiceOutOffset);
            }
        } else if (constInfo.outputLayout == LI_LAYOUT::BSND) {
            // Pad every row from s1Start up to the static qSeqSize.
            for (uint32_t s1Idx = s1Start; s1Idx < constInfo.qSeqSize; s1Idx++) {
                // B,S1,N2,K
                uint64_t indiceOutOffset = bIdx * constInfo.qSeqSize * constInfo.kHeadNum * constInfo.sparseCount +
                                           s1Idx * constInfo.kHeadNum * constInfo.sparseCount +
                                           n2Idx * constInfo.sparseCount;
                vectorService.CleanInvalidOutput(indiceOutOffset);
            }
        }
    }
}
|
||||
|
||||
// Kernel entry initialization: map core indices, read tiling, bind the
// per-batch length tensors, split work across cores, carve the GM workspace,
// and hand the relevant GM tensors to the cube/vector service for this core.
template <typename LIT>
__aicore__ inline void LIPreload<LIT>::Init(__gm__ uint8_t *query, __gm__ uint8_t *key, __gm__ uint8_t *weights,
                                            __gm__ uint8_t *actualSeqLengthsQ, __gm__ uint8_t *actualSeqLengths,
                                            __gm__ uint8_t *blockTable, __gm__ uint8_t *sparseIndices,
                                            __gm__ uint8_t *workspace, const LITilingData *__restrict tiling,
                                            TPipe *tPipe)
{
    // Two vector cores pair with one cube core, so they share one aiCoreIdx.
    if ASCEND_IS_AIV {
        tmpBlockIdx = GetBlockIdx(); // vec:0-47
        aiCoreIdx = tmpBlockIdx / 2;
    } else {
        tmpBlockIdx = GetBlockIdx(); // cube:0-23
        aiCoreIdx = tmpBlockIdx;
    }

    InitTilingData(tiling);
    InitActualSeqLen(actualSeqLengthsQ, actualSeqLengths);

    SplitCore(aiCoreIdx, usedCoreNum, splitCoreInfo);

    pipe = tPipe;
    // Workspace layout: [mm1 results, double-buffered per core][vec1 results][vec1 params].
    uint64_t offset = 0;
    uint64_t singleCoreMm1ResSize = WS_DOBULE * constInfo.mBaseSize * constInfo.s2BaseSize * sizeof(MM1_OUT_T);
    // Each core's view starts at its own slice of the mm1 region.
    mm1ResGm.SetGlobalBuffer((__gm__ MM1_OUT_T *)(workspace + offset + aiCoreIdx * singleCoreMm1ResSize));
    offset += GetBlockNum() * singleCoreMm1ResSize;

    vec1ResGm.SetGlobalBuffer((__gm__ float *)(workspace + offset));
    offset += GetBlockNum() * constInfo.s1BaseSize * WS_DOBULE * WS_DOBULE * BASE_TOPK * sizeof(float);

    vec1ParamGm.SetGlobalBuffer((__gm__ int64_t *)(workspace + offset));
    offset += GetBlockNum() * constInfo.s1BaseSize * WS_DOBULE * LD_PARAM_NUM * sizeof(int64_t);

    // Vector cores own weights/output; cube cores own query/key (+ block table for PA).
    if ASCEND_IS_AIV {
        vectorService.InitParams(constInfo, tiling);
        indiceOutGm.SetGlobalBuffer((__gm__ int32_t *)sparseIndices);
        weightsGm.SetGlobalBuffer((__gm__ K_T *)weights);
        vectorService.InitVec1GlobalTensor(mm1ResGm, vec1ResGm, vec1ParamGm, weightsGm, indiceOutGm);
    } else {
        matmulService.InitParams(constInfo);
        queryGm.SetGlobalBuffer((__gm__ Q_T *)query);
        if constexpr (PAGE_ATTENTION) {
            blockTableGm.SetGlobalBuffer((__gm__ int32_t *)blockTable);
        }
        keyGm.SetGlobalBuffer((__gm__ K_T *)key);
        matmulService.InitMm1GlobalTensor(blockTableGm, keyGm, queryGm, mm1ResGm);
    }
    InitBuffers();
}
|
||||
|
||||
template <typename LIT>
|
||||
__aicore__ inline void LIPreload<LIT>::GetBN2Idx(uint32_t bN2Idx)
|
||||
{
|
||||
tempLoopInfo.bN2Idx = bN2Idx;
|
||||
tempLoopInfo.bIdx = bN2Idx / constInfo.kHeadNum;
|
||||
tempLoopInfo.n2Idx = bN2Idx % constInfo.kHeadNum;
|
||||
}
|
||||
|
||||
// Compute the per-gS1-tile parameters: the actual M size of this tile (full
// or tail) and the inclusive end of this core's s2 loop for the tile.
template <typename LIT>
__aicore__ inline void LIPreload<LIT>::CalcS2LoopParams(uint32_t bN2LoopIdx, uint32_t gS1LoopIdx)
{
    tempLoopInfo.gS1Idx = gS1LoopIdx;
    tempLoopInfo.actMBaseSize = constInfo.mBaseSize;
    // Remaining g*s1 elements from this tile onward; the last tile uses the tail size.
    uint32_t remainedGS1Size = tempLoopInfo.actS1Size * constInfo.gSize - tempLoopInfo.gS1Idx * constInfo.mBaseSize;
    if (remainedGS1Size <= constInfo.mBaseSize && remainedGS1Size > 0) {
        tempLoopInfo.actMBaseSize = tempLoopInfo.mBasicSizeTail;
    }

    // If this is the core's final (bN2, gS1) pair, stop at the split s2End;
    // otherwise run to the last valid s2 tile of the row.
    bool isEnd = (bN2LoopIdx == splitCoreInfo.bN2End) && (gS1LoopIdx == splitCoreInfo.gS1End);
    uint32_t s2BlockNum;
    if (constInfo.attenMaskFlag) {
        s2BlockNum = GetS2BaseBlockNumOnMask(gS1LoopIdx, tempLoopInfo.actS1Size, tempLoopInfo.actS2Size);
    } else {
        s2BlockNum = (tempLoopInfo.actS2Size + constInfo.s2BaseSize - 1) / constInfo.s2BaseSize;
    }
    tempLoopInfo.s2LoopEnd = isEnd ? splitCoreInfo.s2End : s2BlockNum - 1;
}
|
||||
|
||||
// Compute the per-(batch, kv-head) loop parameters: actual sequence lengths,
// tail tile sizes, and the inclusive end of this core's gS1 loop. Flags the
// zero-length case so the caller can pad instead of compute.
template <typename LIT>
__aicore__ inline void LIPreload<LIT>::CalcGS1LoopParams(uint32_t bN2LoopIdx)
{
    GetBN2Idx(bN2LoopIdx);
    GetS1S2ActualSeqLen(tempLoopInfo.bIdx, tempLoopInfo.actS1Size, tempLoopInfo.actS2Size);
    if ((tempLoopInfo.actS2Size == 0) || (tempLoopInfo.actS1Size == 0)) {
        tempLoopInfo.curActSeqLenIsZero = true;
        return;
    }
    tempLoopInfo.curActSeqLenIsZero = false;
    // Tail sizes; an exact multiple means the "tail" is a full tile.
    tempLoopInfo.s2BasicSizeTail = tempLoopInfo.actS2Size % constInfo.s2BaseSize;
    tempLoopInfo.s2BasicSizeTail =
        (tempLoopInfo.s2BasicSizeTail == 0) ? constInfo.s2BaseSize : tempLoopInfo.s2BasicSizeTail;
    tempLoopInfo.mBasicSizeTail = (tempLoopInfo.actS1Size * constInfo.gSize) % constInfo.mBaseSize;
    tempLoopInfo.mBasicSizeTail =
        (tempLoopInfo.mBasicSizeTail == 0) ? constInfo.mBaseSize : tempLoopInfo.mBasicSizeTail;

    // On the core's final bN2, stop at the split's gS1End; otherwise cover all tiles.
    uint32_t gS1SplitNum = (tempLoopInfo.actS1Size * constInfo.gSize + constInfo.mBaseSize - 1) / constInfo.mBaseSize;
    tempLoopInfo.gS1LoopEnd = (bN2LoopIdx == splitCoreInfo.bN2End) ? splitCoreInfo.gS1End : gS1SplitNum - 1;
    if constexpr (LAYOUT_T == LI_LAYOUT::BSND) {
        // BSND outputs are statically sized: rows past actS1Size must be padded.
        if (tempLoopInfo.gS1LoopEnd == gS1SplitNum - 1 && constInfo.qSeqSize > tempLoopInfo.actS1Size) {
            tempLoopInfo.needDealActS1LessThanS1 = true;
        }
    }
}
|
||||
|
||||
// Populate runInfo for one base block: indices, actual tile sizes, loop-edge
// flags, and (on the first s2 tile of a row) the GM base offsets for query,
// key, weights and the output.
template <typename LIT>
__aicore__ inline void LIPreload<LIT>::CalcRunInfo(uint32_t loop, uint32_t s2LoopIdx, LICommon::RunInfo &runInfo)
{
    runInfo.loop = loop;
    runInfo.bIdx = tempLoopInfo.bIdx;
    runInfo.gS1Idx = tempLoopInfo.gS1Idx;
    runInfo.s2Idx = s2LoopIdx;
    runInfo.bN2Idx = tempLoopInfo.bN2Idx;
    // NOTE(review): runInfo.n2Idx is read below but never assigned here —
    // presumably it defaults to 0 (K_HEAD_NUM == 1) or is set elsewhere; verify.

    runInfo.actS1Size = tempLoopInfo.actS1Size;
    runInfo.actS2Size = tempLoopInfo.actS2Size;
    runInfo.actMBaseSize = tempLoopInfo.actMBaseSize;
    // Last s2 tile of the row uses the tail size.
    runInfo.actualSingleProcessSInnerSize = constInfo.s2BaseSize;
    uint32_t s2SplitNum = (tempLoopInfo.actS2Size + constInfo.s2BaseSize - 1) / constInfo.s2BaseSize;
    if (runInfo.s2Idx == s2SplitNum - 1) {
        runInfo.actualSingleProcessSInnerSize = tempLoopInfo.s2BasicSizeTail;
    }
    runInfo.actualSingleProcessSInnerSizeAlign =
        LICommon::Align((uint32_t)runInfo.actualSingleProcessSInnerSize, LICommon::ConstInfo::BUFFER_SIZE_BYTE_32B);

    runInfo.isFirstS2InnerLoop = s2LoopIdx == splitCoreInfo.s2Start;
    runInfo.isLastS2InnerLoop = s2LoopIdx == tempLoopInfo.s2LoopEnd;
    runInfo.isAllLoopEnd = (runInfo.bN2Idx == splitCoreInfo.bN2End) && (runInfo.gS1Idx == splitCoreInfo.gS1End) &&
                           (runInfo.s2Idx == splitCoreInfo.s2End);

    if (runInfo.isFirstS2InnerLoop) {
        // Base offsets only change per (b, n2, gS1); compute once per row.
        uint64_t actualSeqQPrefixSum;
        uint64_t actualSeqKPrefixSum;
        if constexpr (LAYOUT_T == LI_LAYOUT::TND) {
            // TND: the length tensors already hold token prefix sums.
            actualSeqQPrefixSum = (runInfo.bIdx <= 0) ? 0 : actualSeqLengthsGmQ.GetValue(runInfo.bIdx - 1);
            actualSeqKPrefixSum = (runInfo.bIdx <= 0) ? 0 : actualSeqLengthsGm.GetValue(runInfo.bIdx - 1);
        } else { // BSND
            // BSND: each batch occupies a fixed stride of qSeqSize/kSeqSize tokens.
            actualSeqQPrefixSum = (runInfo.bIdx <= 0) ? 0 : runInfo.bIdx * constInfo.qSeqSize;
            actualSeqKPrefixSum = (runInfo.bIdx <= 0) ? 0 : runInfo.bIdx * constInfo.kSeqSize;
        }
        uint64_t tndBIdxOffset = actualSeqQPrefixSum * constInfo.qHeadNum * constInfo.headDim;
        uint64_t tndKeyBIdxOffset = actualSeqKPrefixSum * constInfo.kHeadNum * constInfo.headDim;
        // B,S1,N1(N2,G),D
        queryCoreOffset = tndBIdxOffset + runInfo.gS1Idx * constInfo.mBaseSize * constInfo.headDim;
        keyCoreOffset = tndKeyBIdxOffset + runInfo.n2Idx * constInfo.headDim;
        // B,S1,N1(N2,G)/T,N1(N2,G)
        weightsCoreOffset = actualSeqQPrefixSum * constInfo.qHeadNum + runInfo.n2Idx * constInfo.gSize;
        // B,S1,N2,k/T,N2,k
        indiceOutCoreOffset = actualSeqQPrefixSum * constInfo.kHeadNum * constInfo.sparseCount +
                              runInfo.n2Idx * constInfo.sparseCount;
    }
    runInfo.tensorQueryOffset = queryCoreOffset;
    // Key offset advances with the s2 tile index.
    runInfo.tensorKeyOffset = keyCoreOffset + runInfo.s2Idx * constInfo.s2BaseSize * constInfo.kHeadNum
                              * constInfo.headDim;
    runInfo.tensorWeightsOffset = weightsCoreOffset;
    runInfo.indiceOutOffset = indiceOutCoreOffset;
}
|
||||
|
||||
template <typename LIT>
|
||||
__aicore__ inline void LIPreload<LIT>::Process()
|
||||
{
|
||||
if (usedCoreNum == 0) {
|
||||
ProcessInvalid();
|
||||
return;
|
||||
}
|
||||
ProcessMain();
|
||||
ProcessDecode();
|
||||
}
|
||||
|
||||
template <typename LIT>
|
||||
__aicore__ inline void LIPreload<LIT>::ProcessInvalid()
|
||||
{
|
||||
if ASCEND_IS_AIV {
|
||||
uint32_t aivCoreNum = GetBlockNum() * 2; // 2 means c:v = 1:2
|
||||
uint64_t totalOutputSize =
|
||||
constInfo.batchSize * constInfo.qSeqSize * constInfo.kHeadNum * constInfo.sparseCount;
|
||||
uint64_t singleCoreSize =
|
||||
LICommon::Align((totalOutputSize + aivCoreNum - 1) / aivCoreNum, GM_ALIGN_BYTES / sizeof(OUT_T));
|
||||
uint64_t baseSize = tmpBlockIdx * singleCoreSize;
|
||||
if (baseSize < totalOutputSize) {
|
||||
uint64_t dealSize =
|
||||
(baseSize + singleCoreSize > totalOutputSize) ? singleCoreSize : totalOutputSize - baseSize;
|
||||
GlobalTensor<OUT_T> output = indiceOutGm[baseSize];
|
||||
AscendC::InitGlobalMemory(output, dealSize, constInfo.INVALID_IDX);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
template <typename LIT>
|
||||
__aicore__ inline void LIPreload<LIT>::ProcessMain()
|
||||
{
|
||||
if (aiCoreIdx >= usedCoreNum) {
|
||||
return;
|
||||
}
|
||||
|
||||
if ASCEND_IS_AIV {
|
||||
vectorService.AllocEventID();
|
||||
CrossCoreSetFlag<LICommon::ConstInfo::FIA_SYNC_MODE2, PIPE_MTE2>(constInfo.syncV1C1);
|
||||
CrossCoreSetFlag<LICommon::ConstInfo::FIA_SYNC_MODE2, PIPE_MTE2>(constInfo.syncV1C1);
|
||||
} else {
|
||||
matmulService.AllocEventID();
|
||||
}
|
||||
|
||||
LICommon::RunInfo runInfo;
|
||||
uint32_t gloop = 0;
|
||||
for (uint32_t bN2LoopIdx = splitCoreInfo.bN2Start; bN2LoopIdx <= splitCoreInfo.bN2End; bN2LoopIdx++) {
|
||||
CalcGS1LoopParams(bN2LoopIdx);
|
||||
if (tempLoopInfo.curActSeqLenIsZero) {
|
||||
DealActSeqLenIsZero(tempLoopInfo.bIdx, tempLoopInfo.n2Idx, 0U);
|
||||
continue;
|
||||
}
|
||||
for (uint32_t gS1LoopIdx = splitCoreInfo.gS1Start; gS1LoopIdx <= tempLoopInfo.gS1LoopEnd; gS1LoopIdx++) {
|
||||
CalcS2LoopParams(bN2LoopIdx, gS1LoopIdx);
|
||||
for (int s2LoopIdx = splitCoreInfo.s2Start; s2LoopIdx <= tempLoopInfo.s2LoopEnd; s2LoopIdx++) {
|
||||
ProcessBaseBlock(gloop, s2LoopIdx, runInfo);
|
||||
++gloop;
|
||||
}
|
||||
splitCoreInfo.s2Start = 0;
|
||||
}
|
||||
if (tempLoopInfo.needDealActS1LessThanS1) {
|
||||
DealActSeqLenIsZero(tempLoopInfo.bIdx, tempLoopInfo.n2Idx, tempLoopInfo.actS1Size);
|
||||
}
|
||||
splitCoreInfo.gS1Start = 0;
|
||||
}
|
||||
|
||||
if ASCEND_IS_AIV {
|
||||
vectorService.FreeEventID();
|
||||
} else {
|
||||
matmulService.FreeEventID();
|
||||
CrossCoreWaitFlag(constInfo.syncV1C1);
|
||||
CrossCoreWaitFlag(constInfo.syncV1C1);
|
||||
}
|
||||
}
|
||||
|
||||
// Process one (M-tile, s2-tile) base block: the cube core waits for a free
// workspace slot, computes q*k^T into it, and signals the vector core, which
// post-processes the slot and returns it. The flag pair forms the pipeline.
template <typename LIT>
__aicore__ inline void LIPreload<LIT>::ProcessBaseBlock(uint32_t loop, uint64_t s2LoopIdx, LICommon::RunInfo &runInfo)
{
    CalcRunInfo(loop, s2LoopIdx, runInfo);
    if ASCEND_IS_AIC {
        CrossCoreWaitFlag(constInfo.syncV1C1);   // wait for a free mm1 workspace slot
        matmulService.ComputeMm1(runInfo);
        CrossCoreSetFlag<LICommon::ConstInfo::FIA_SYNC_MODE2, PIPE_FIX>(constInfo.syncC1V1);
    } else {
        CrossCoreWaitFlag(constInfo.syncC1V1);   // wait for cube results
        vectorService.ProcessVec(runInfo);
        CrossCoreSetFlag<LICommon::ConstInfo::FIA_SYNC_MODE2, PIPE_MTE2>(constInfo.syncV1C1);
    }
}
|
||||
|
||||
// Decode (LD) tail phase, vector cores only: after an all-core barrier, cores
// whose split produced an LD segment (a row started but not finished by the
// main split) run the dedicated LD path.
template <typename LIT>
__aicore__ inline void LIPreload<LIT>::ProcessDecode()
{
    if ASCEND_IS_AIV {
        vectorService.InitLDBuffers(pipe);
        // Preload instruction cache for the upcoming LD code path.
        ICachePreLoad(LD_PREFETCH_LEN);
        SyncAll();
        if (splitCoreInfo.isLD) {
            vectorService.ProcessLD();
        }
    }
}
|
||||
} // namespace LIKernel
|
||||
#endif // LIGHTNING_INDEXER_KERNEL_H
|
||||
@@ -0,0 +1,415 @@
|
||||
/**
|
||||
* This program is free software, you can redistribute it and/or modify it.
|
||||
* Copyright (c) 2025 Huawei Technologies Co., Ltd.
|
||||
* This file is a part of the CANN Open Software.
|
||||
* Licensed under CANN Open Software License Agreement Version 2.0 (the "License").
|
||||
* Please refer to the License for details. You may not use this file except in compliance with the License.
|
||||
* THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE.
|
||||
* See LICENSE in the root of the software repository for the full text of the License.
|
||||
*/
|
||||
|
||||
/*!
|
||||
* \file lightning_indexer_service_cube.h
|
||||
* \brief use 5 buffer for matmul l1, better pipeline
|
||||
*/
|
||||
#ifndef LIGHTNING_INDEXER_SERVICE_CUBE_H
|
||||
#define LIGHTNING_INDEXER_SERVICE_CUBE_H
|
||||
|
||||
#include "kernel_operator.h"
|
||||
#include "kernel_operator_list_tensor_intf.h"
|
||||
#include "kernel_tiling/kernel_tiling.h"
|
||||
#include "lib/matmul_intf.h"
|
||||
#include "lib/matrix/matmul/tiling.h"
|
||||
#include "lightning_indexer_common.h"
|
||||
|
||||
namespace LIKernel {
|
||||
using namespace LICommon;
|
||||
// Cube-core matmul service for matmul-1 (q x k^T). Manages L1/L0 buffering
// (3 key buffers, 2 query buffers, 2 L0 buffers) with explicit hardware event
// flags, and writes fp32 results to the GM workspace via the fixpipe.
template <typename LIT>
class LIMatmul {
public:
    using Q_T = typename LIT::queryType;
    using K_T = typename LIT::keyType;

    __aicore__ inline LIMatmul(){};
    __aicore__ inline void InitBuffers(TPipe *pipe);
    __aicore__ inline void InitMm1GlobalTensor(const GlobalTensor<int32_t> &blkTableGm, const GlobalTensor<K_T> &keyGm,
                                               const GlobalTensor<Q_T> &queryGm, const GlobalTensor<float> &mm1ResGm);
    __aicore__ inline void InitParams(const ConstInfo &constInfo);
    __aicore__ inline void AllocEventID();
    __aicore__ inline void FreeEventID();
    // Compute one base block of q*k^T into the GM workspace.
    __aicore__ inline void ComputeMm1(const LICommon::RunInfo &runInfo);

    static constexpr IsResetLoad3dConfig LOAD3DV2_CONFIG = {true, true}; // isSetFMatrix isSetPadding;
    // L1 buffer counts: 3-deep key, 2-deep query; L0 is double-buffered.
    static constexpr uint64_t KEY_BUF_NUM = 3;
    static constexpr uint64_t QUERY_BUF_NUM = 2;
    static constexpr uint64_t L0_BUF_NUM = 2;

    // Hardware event ids; key uses EVENT_ID2..4, query EVENT_ID5..6.
    static constexpr uint32_t KEY_MTE1_MTE2_EVENT = EVENT_ID2;
    static constexpr uint32_t QUERY_MTE1_MTE2_EVENT = EVENT_ID5; // KEY_MTE1_MTE2_EVENT + KEY_BUF_NUM;
    static constexpr uint32_t M_MTE1_EVENT = EVENT_ID3;

    static constexpr uint32_t MTE2_MTE1_EVENT = EVENT_ID2;
    static constexpr uint32_t MTE1_M_EVENT = EVENT_ID2;

    // L1 tile shapes (elements).
    static constexpr uint64_t M_BASIC_BLOCK = 256;
    static constexpr uint64_t D_BASIC_BLOCK = 128;
    static constexpr uint64_t S2_BASIC_BLOCK = 256;

    // L0 tile shapes (elements).
    static constexpr uint64_t M_BASIC_BLOCK_L0 = 128;
    static constexpr uint64_t D_BASIC_BLOCK_L0 = 128;
    static constexpr uint64_t S2_BASIC_BLOCK_L0 = 128;

    // Element strides between successive buffers of each multi-buffer region.
    static constexpr uint64_t QUERY_BUFFER_OFFSET = M_BASIC_BLOCK * D_BASIC_BLOCK;
    static constexpr uint64_t KEY_BUFFER_OFFSET = S2_BASIC_BLOCK * D_BASIC_BLOCK;
    static constexpr uint64_t L0AB_BUFFER_OFFSET = M_BASIC_BLOCK_L0 * D_BASIC_BLOCK_L0;
    static constexpr uint64_t L0C_BUFFER_OFFSET = M_BASIC_BLOCK_L0 * S2_BASIC_BLOCK_L0;

protected:
    // Move an L0C tile to the GM workspace through the fixpipe.
    __aicore__ inline void Fixp(uint64_t s1gGmOffset, uint64_t s2GmOffset, uint64_t s1gL0RealSize,
                                uint64_t s2L0RealSize, const LICommon::RunInfo &runInfo);
    // NOTE(review): "ComuteL0c" looks like a typo of "ComputeL0c" (definition lives elsewhere).
    __aicore__ inline void ComuteL0c(uint64_t s1gL0RealSize, uint64_t s2L0RealSize, const LICommon::RunInfo &runInfo);
    __aicore__ inline void LoadKeyToL0b(uint64_t s2L0Offset, uint64_t s2L1RealSize, uint64_t s2L0RealSize,
                                        const LICommon::RunInfo &runInfo);
    __aicore__ inline void LoadQueryToL0a(uint64_t s1gL1Offset, uint64_t s1gL0Offset, uint64_t s1gL1RealSize,
                                          uint64_t s1gL0RealSize, const LICommon::RunInfo &runInfo);
    // ND -> NZ format copies from GM into L1 (query / contiguous key / paged key).
    __aicore__ inline void QueryNd2Nz(uint64_t s1gL1RealSize, uint64_t s1gL1Offset, const LICommon::RunInfo &runInfo);
    __aicore__ inline void KeyNd2Nz(uint64_t s2L1RealSize, uint64_t s2GmOffset, const LICommon::RunInfo &runInfo);
    __aicore__ inline void KeyNd2NzForPA(uint64_t s2L1RealSize, uint64_t s2GmOffset, const LICommon::RunInfo &runInfo);
    GlobalTensor<int32_t> blkTableGm_;
    GlobalTensor<K_T> keyGm_;
    GlobalTensor<Q_T> queryGm_;
    GlobalTensor<float> mm1ResGm_;

    TBuf<TPosition::A1> bufQL1_;
    LocalTensor<Q_T> queryL1_;
    TBuf<TPosition::B1> bufKeyL1_;
    LocalTensor<K_T> keyL1_;

    TBuf<TPosition::A2> bufQL0_;
    LocalTensor<Q_T> queryL0_;
    TBuf<TPosition::B2> bufKeyL0_;
    LocalTensor<K_T> keyL0_;

    TBuf<TPosition::CO1> bufL0C_;
    LocalTensor<float> cL0_;

    // Rolling buffer cursors (modulo the *_BUF_NUM counts above).
    uint64_t keyL1BufIdx_ = 0;
    uint64_t queryL1Mte2BufIdx_ = 0;
    uint64_t queryL1Mte1BufIdx_ = 0;
    uint64_t l0BufIdx_ = 0;

    ConstInfo constInfo_;

private:
    static constexpr bool PAGE_ATTENTION = LIT::pageAttention;
};
|
||||
|
||||
// Cache the kernel-wide constants for use by the matmul routines.
template <typename LIT>
__aicore__ inline void LIMatmul<LIT>::InitParams(const ConstInfo &constInfo)
{
    constInfo_ = constInfo;
}
|
||||
|
||||
// Carve the L1/L0 on-chip buffers: multi-buffered query/key regions in L1,
// double-buffered A/B tiles in L0A/L0B, and double-buffered accumulators in L0C.
template <typename LIT>
__aicore__ inline void LIMatmul<LIT>::InitBuffers(TPipe *pipe)
{
    // L1: 2 query buffers, 3 key buffers.
    pipe->InitBuffer(bufQL1_, QUERY_BUF_NUM * M_BASIC_BLOCK * D_BASIC_BLOCK * sizeof(Q_T));
    queryL1_ = bufQL1_.Get<Q_T>();
    pipe->InitBuffer(bufKeyL1_, KEY_BUF_NUM * S2_BASIC_BLOCK * D_BASIC_BLOCK * sizeof(K_T));
    keyL1_ = bufKeyL1_.Get<K_T>();

    // L0A/L0B: double-buffered matmul input tiles.
    pipe->InitBuffer(bufQL0_, L0_BUF_NUM * M_BASIC_BLOCK_L0 * D_BASIC_BLOCK_L0 * sizeof(Q_T));
    queryL0_ = bufQL0_.Get<Q_T>();
    pipe->InitBuffer(bufKeyL0_, L0_BUF_NUM * D_BASIC_BLOCK_L0 * S2_BASIC_BLOCK_L0 * sizeof(K_T));
    keyL0_ = bufKeyL0_.Get<K_T>();

    // L0C: double-buffered fp32 accumulators.
    pipe->InitBuffer(bufL0C_, L0_BUF_NUM * M_BASIC_BLOCK_L0 * S2_BASIC_BLOCK_L0 * sizeof(float));
    cL0_ = bufL0C_.Get<float>();
}
|
||||
|
||||
template <typename LIT>
|
||||
__aicore__ inline void
|
||||
LIMatmul<LIT>::InitMm1GlobalTensor(const GlobalTensor<int32_t> &blkTableGm, const GlobalTensor<K_T> &keyGm,
|
||||
const GlobalTensor<Q_T> &queryGm, const GlobalTensor<float> &mm1ResGm)
|
||||
{
|
||||
blkTableGm_ = blkTableGm;
|
||||
keyGm_ = keyGm;
|
||||
queryGm_ = queryGm;
|
||||
mm1ResGm_ = mm1ResGm;
|
||||
}
|
||||
|
||||
template <typename LIT>
|
||||
__aicore__ inline void LIMatmul<LIT>::ComputeMm1(const LICommon::RunInfo &runInfo)
|
||||
{
|
||||
uint64_t s2GmBaseOffset = runInfo.s2Idx * constInfo_.s2BaseSize;
|
||||
uint64_t s1gProcessSize = runInfo.actMBaseSize;
|
||||
uint64_t s2ProcessSize = runInfo.actualSingleProcessSInnerSize;
|
||||
for (uint64_t s2GmOffset = 0; s2GmOffset < s2ProcessSize; s2GmOffset += S2_BASIC_BLOCK) {
|
||||
WaitFlag<HardEvent::MTE1_MTE2>(KEY_MTE1_MTE2_EVENT + keyL1BufIdx_ % KEY_BUF_NUM);
|
||||
uint64_t s2L1RealSize =
|
||||
s2GmOffset + S2_BASIC_BLOCK > s2ProcessSize ? s2ProcessSize - s2GmOffset : S2_BASIC_BLOCK;
|
||||
if (PAGE_ATTENTION) {
|
||||
KeyNd2NzForPA(s2L1RealSize, s2GmBaseOffset + s2GmOffset, runInfo);
|
||||
}else {
|
||||
KeyNd2Nz(s2L1RealSize, s2GmOffset, runInfo);
|
||||
}
|
||||
|
||||
SetFlag<HardEvent::MTE2_MTE1>(MTE2_MTE1_EVENT);
|
||||
WaitFlag<HardEvent::MTE2_MTE1>(MTE2_MTE1_EVENT);
|
||||
for (uint64_t s1gGmOffset = 0; s1gGmOffset < s1gProcessSize; s1gGmOffset += M_BASIC_BLOCK) {
|
||||
uint64_t s1gL1RealSize =
|
||||
s1gGmOffset + M_BASIC_BLOCK > s1gProcessSize ? s1gProcessSize - s1gGmOffset : M_BASIC_BLOCK;
|
||||
if (runInfo.isFirstS2InnerLoop && s2GmOffset == 0) {
|
||||
queryL1Mte2BufIdx_++;
|
||||
queryL1Mte1BufIdx_ = queryL1Mte2BufIdx_;
|
||||
WaitFlag<HardEvent::MTE1_MTE2>(QUERY_MTE1_MTE2_EVENT + queryL1Mte2BufIdx_ % QUERY_BUF_NUM);
|
||||
QueryNd2Nz(s1gL1RealSize, s1gGmOffset, runInfo);
|
||||
SetFlag<HardEvent::MTE2_MTE1>(MTE2_MTE1_EVENT);
|
||||
WaitFlag<HardEvent::MTE2_MTE1>(MTE2_MTE1_EVENT);
|
||||
} else {
|
||||
queryL1Mte1BufIdx_ =
|
||||
queryL1Mte2BufIdx_ - (CeilDiv(s1gProcessSize, M_BASIC_BLOCK) - 1 - (s1gGmOffset > 0));
|
||||
}
|
||||
for (uint64_t s2L1Offset = 0; s2L1Offset < s2L1RealSize; s2L1Offset += S2_BASIC_BLOCK_L0) {
|
||||
uint64_t s2L0RealSize =
|
||||
s2L1Offset + S2_BASIC_BLOCK_L0 > s2L1RealSize ? s2L1RealSize - s2L1Offset : S2_BASIC_BLOCK_L0;
|
||||
for (uint64_t s1gL1Offset = 0; s1gL1Offset < s1gL1RealSize; s1gL1Offset += M_BASIC_BLOCK_L0) {
|
||||
WaitFlag<HardEvent::M_MTE1>(M_MTE1_EVENT + l0BufIdx_ % L0_BUF_NUM);
|
||||
uint64_t s1gL0RealSize =
|
||||
s1gL1Offset + M_BASIC_BLOCK_L0 > s1gL1RealSize ? s1gL1RealSize - s1gL1Offset : M_BASIC_BLOCK_L0;
|
||||
LoadQueryToL0a(s1gGmOffset, s1gL1Offset, s1gL1RealSize, s1gL0RealSize, runInfo);
|
||||
LoadKeyToL0b(s2L1Offset, s2L1RealSize, s2L0RealSize, runInfo);
|
||||
|
||||
SetFlag<HardEvent::MTE1_M>(MTE1_M_EVENT);
|
||||
WaitFlag<HardEvent::MTE1_M>(MTE1_M_EVENT);
|
||||
|
||||
ComuteL0c(s1gL0RealSize, s2L0RealSize, runInfo);
|
||||
|
||||
SetFlag<HardEvent::M_MTE1>(M_MTE1_EVENT + l0BufIdx_ % L0_BUF_NUM);
|
||||
|
||||
Fixp(s1gGmOffset + s1gL1Offset, s2GmOffset + s2L1Offset, s1gL0RealSize, s2L0RealSize, runInfo);
|
||||
l0BufIdx_++;
|
||||
}
|
||||
}
|
||||
if (s2GmOffset + S2_BASIC_BLOCK >= s2ProcessSize && runInfo.isLastS2InnerLoop) {
|
||||
SetFlag<HardEvent::MTE1_MTE2>(QUERY_MTE1_MTE2_EVENT + queryL1Mte1BufIdx_ % QUERY_BUF_NUM);
|
||||
}
|
||||
}
|
||||
|
||||
SetFlag<HardEvent::MTE1_MTE2>(KEY_MTE1_MTE2_EVENT + keyL1BufIdx_ % KEY_BUF_NUM);
|
||||
keyL1BufIdx_++;
|
||||
}
|
||||
}
|
||||
|
||||
// Copies one key tile from GM (ND layout) into the L1 key buffer in NZ layout.
// The L1 tile is split into two halves of S2_BASIC_BLOCK_L0 rows each; the
// destination offset and dstNzC0Stride are chosen so that each half forms its
// own contiguous NZ fractal region. Loops MTE2 copies until s2L1RealSize rows
// are staged. Uses the keyL1BufIdx_ ring buffer slot; caller handles the
// MTE1/MTE2 event synchronization around this call.
template <typename LIT>
__aicore__ inline void LIMatmul<LIT>::KeyNd2Nz(uint64_t s2L1RealSize, uint64_t s2GmOffset,
                                               const LICommon::RunInfo &runInfo)
{
    uint64_t s2L1Offset = 0;
    while (s2L1Offset < s2L1RealSize) {
        // Source offset in GM: contiguous [s2, headDim] rows starting at the tile base.
        uint64_t keyGmOffset = runInfo.tensorKeyOffset + (s2GmOffset + s2L1Offset) * constInfo_.headDim;
        // Rows for this MTE2 issue: clip at the S2_BASIC_BLOCK_L0 half boundary so a
        // single copy never straddles the two L1 half-regions.
        uint64_t s2Mte2Size = (s2L1RealSize <= S2_BASIC_BLOCK_L0 || s2L1Offset >= S2_BASIC_BLOCK_L0) ?
                                  s2L1RealSize - s2L1Offset :
                                  S2_BASIC_BLOCK_L0 - s2L1Offset;

        Nd2NzParams nd2nzPara;
        nd2nzPara.ndNum = 1;
        nd2nzPara.nValue = s2Mte2Size; // number of rows
        nd2nzPara.dValue = constInfo_.headDim;
        nd2nzPara.srcDValue = constInfo_.headDim;
        // NZ column stride = row capacity of the destination half-region:
        //  - second half: remaining rows, 16-aligned;
        //  - first half when the tile spills over: exactly S2_BASIC_BLOCK_L0;
        //  - small tile that fits in the first half: rows 16-aligned.
        nd2nzPara.dstNzC0Stride = s2L1Offset >= S2_BASIC_BLOCK_L0 ?
                                      CeilAlign(s2L1RealSize - S2_BASIC_BLOCK_L0, (uint64_t)BLOCK_CUBE) :
                                      (s2L1RealSize > S2_BASIC_BLOCK_L0 ?
                                           S2_BASIC_BLOCK_L0 :
                                           CeilAlign(s2L1RealSize, (uint64_t)BLOCK_CUBE));
        nd2nzPara.dstNzNStride = 1;
        nd2nzPara.srcNdMatrixStride = 0;
        nd2nzPara.dstNzMatrixStride = 0;
        // Destination: ring-buffer slot base + half-region offset within the slot.
        DataCopy(keyL1_[(keyL1BufIdx_ % KEY_BUF_NUM) * KEY_BUFFER_OFFSET +
                        (s2L1Offset >= S2_BASIC_BLOCK_L0 ?
                             S2_BASIC_BLOCK_L0 * D_BASIC_BLOCK_L0 + (s2L1Offset - S2_BASIC_BLOCK_L0) * BLOCK_CUBE :
                             s2L1Offset * BLOCK_CUBE)],
                 keyGm_[keyGmOffset], nd2nzPara);

        s2L1Offset += s2Mte2Size;
    }
}
|
||||
|
||||
// Paged-attention variant of KeyNd2Nz. Key cache layout in GM: blkNum, blkSize, N2, D.
// For each chunk of rows, the logical s2 position is translated through the block
// table (blkTableGm_) into a physical cache-block offset, so an extra clipping step
// keeps every MTE2 copy inside a single kCacheBlockSize page. NZ destination layout
// in L1 is identical to KeyNd2Nz (two S2_BASIC_BLOCK_L0 half-regions per slot).
template <typename LIT>
__aicore__ inline void LIMatmul<LIT>::KeyNd2NzForPA(uint64_t s2L1RealSize, uint64_t s2GmOffset,
                                                    const LICommon::RunInfo &runInfo)
{
    uint64_t s2L1Offset = 0;
    while (s2L1Offset < s2L1RealSize) {
        // Logical s2 position -> (page id, offset within page).
        uint64_t s2BlkId = (s2L1Offset + s2GmOffset) / constInfo_.kCacheBlockSize;
        uint64_t s2BlkOffset = (s2L1Offset + s2GmOffset) % constInfo_.kCacheBlockSize;
        // Physical GM offset via the per-batch block table row.
        uint64_t keyGmOffset = blkTableGm_.GetValue(runInfo.bIdx * constInfo_.maxBlockNumPerBatch + s2BlkId) *
                                   constInfo_.kCacheBlockSize * constInfo_.kHeadNum * constInfo_.headDim +
                               s2BlkOffset * constInfo_.headDim;
        // Clip first at the L1 half-region boundary ...
        uint64_t s2Mte2Size = (s2L1RealSize <= S2_BASIC_BLOCK_L0 || s2L1Offset >= S2_BASIC_BLOCK_L0) ?
                                  s2L1RealSize - s2L1Offset :
                                  S2_BASIC_BLOCK_L0 - s2L1Offset;
        // ... then at the cache-page boundary so the copy stays within one page.
        s2Mte2Size = s2BlkOffset + s2Mte2Size >= constInfo_.kCacheBlockSize ? constInfo_.kCacheBlockSize - s2BlkOffset :
                                                                             s2Mte2Size;
        Nd2NzParams nd2nzPara;
        nd2nzPara.ndNum = 1;
        nd2nzPara.nValue = s2Mte2Size;
        nd2nzPara.dValue = constInfo_.headDim;
        nd2nzPara.srcDValue = constInfo_.headDim;
        // Same half-region NZ stride selection as KeyNd2Nz.
        nd2nzPara.dstNzC0Stride = s2L1Offset >= S2_BASIC_BLOCK_L0 ?
                                      CeilAlign(s2L1RealSize - S2_BASIC_BLOCK_L0, (uint64_t)BLOCK_CUBE) :
                                      (s2L1RealSize > S2_BASIC_BLOCK_L0 ?
                                           S2_BASIC_BLOCK_L0 :
                                           CeilAlign(s2L1RealSize, (uint64_t)BLOCK_CUBE));
        nd2nzPara.dstNzNStride = 1;
        nd2nzPara.srcNdMatrixStride = 0;
        nd2nzPara.dstNzMatrixStride = 0;
        DataCopy(keyL1_[(keyL1BufIdx_ % KEY_BUF_NUM) * KEY_BUFFER_OFFSET +
                        (s2L1Offset >= S2_BASIC_BLOCK_L0 ?
                             S2_BASIC_BLOCK_L0 * D_BASIC_BLOCK_L0 + (s2L1Offset - S2_BASIC_BLOCK_L0) * BLOCK_CUBE :
                             s2L1Offset * BLOCK_CUBE)],
                 keyGm_[keyGmOffset], nd2nzPara);

        s2L1Offset += s2Mte2Size;
    }
}
|
||||
|
||||
// batch, s1, n2, g, d
|
||||
template <typename LIT>
|
||||
__aicore__ inline void LIMatmul<LIT>::QueryNd2Nz(uint64_t s1gL1RealSize, uint64_t s1gGmOffset,
|
||||
const LICommon::RunInfo &runInfo)
|
||||
{
|
||||
Nd2NzParams nd2nzPara;
|
||||
nd2nzPara.ndNum = 1;
|
||||
nd2nzPara.nValue = s1gL1RealSize;
|
||||
nd2nzPara.dValue = constInfo_.headDim;
|
||||
nd2nzPara.srcDValue = constInfo_.headDim;
|
||||
nd2nzPara.dstNzC0Stride = CeilAlign(s1gL1RealSize, (uint64_t)BLOCK_CUBE);
|
||||
nd2nzPara.dstNzNStride = 1;
|
||||
nd2nzPara.srcNdMatrixStride = 0;
|
||||
nd2nzPara.dstNzMatrixStride = 0;
|
||||
DataCopy(queryL1_[(queryL1Mte2BufIdx_ % QUERY_BUF_NUM) * QUERY_BUFFER_OFFSET],
|
||||
queryGm_[runInfo.tensorQueryOffset + s1gGmOffset * constInfo_.headDim], nd2nzPara);
|
||||
}
|
||||
|
||||
// Moves a query sub-tile from L1 into L0A via LOAD3DV2 (img2col-style load).
// The 3D-load feature map is configured so that the "image" is the NZ query
// tile: l1H rows of 16x16 fractals, l1W = one fractal width, channel = headDim.
// mStartPt selects the s1g sub-block within the L1 tile.
template <typename LIT>
__aicore__ inline void LIMatmul<LIT>::LoadQueryToL0a(uint64_t s1gGmOffset, uint64_t s1gL1Offset, uint64_t s1gL1RealSize,
                                                     uint64_t s1gL0RealSize, const LICommon::RunInfo &runInfo)
{
    LoadData3DParamsV2<Q_T> loadData3DParams;
    // SetFmatrixParams
    loadData3DParams.l1H = CeilDiv(s1gL1RealSize, BLOCK_CUBE); // Hin=M1=8
    loadData3DParams.l1W = BLOCK_CUBE; // Win=M0
    loadData3DParams.channelSize = constInfo_.headDim; // Cin=K

    // Right padding of 255: presumably pads reads beyond the valid M rows —
    // TODO(review) confirm against the LOAD3DV2 ISA description.
    loadData3DParams.padList[0] = 0;
    loadData3DParams.padList[1] = 0;
    loadData3DParams.padList[2] = 0;
    loadData3DParams.padList[3] = 255;

    // SetLoadToA0Params
    loadData3DParams.mExtension = CeilAlign(s1gL0RealSize, BLOCK_CUBE);
    loadData3DParams.kExtension = constInfo_.headDim;
    loadData3DParams.mStartPt = s1gL1Offset;
    loadData3DParams.kStartPt = 0;
    loadData3DParams.strideW = 1;
    loadData3DParams.strideH = 1;
    loadData3DParams.filterW = 1;
    // NOTE(review): (1 >> 8) & 255 is always 0 — the high byte of filter size 1.
    // Looks intentional (filterSize holds only the upper 8 bits), but confirm.
    loadData3DParams.filterSizeW = (1 >> 8) & 255;
    loadData3DParams.filterH = 1;
    loadData3DParams.filterSizeH = (1 >> 8) & 255;
    loadData3DParams.dilationFilterW = 1;
    loadData3DParams.dilationFilterH = 1;
    loadData3DParams.enTranspose = 0;
    loadData3DParams.fMatrixCtrl = 0;

    // Source: the L1 slot last filled by MTE2 (tracked by queryL1Mte1BufIdx_);
    // destination: the current L0 ping-pong slot.
    LoadData<Q_T, LOAD3DV2_CONFIG>(queryL0_[(l0BufIdx_ % L0_BUF_NUM) * L0AB_BUFFER_OFFSET],
                                   queryL1_[(queryL1Mte1BufIdx_ % QUERY_BUF_NUM) * QUERY_BUFFER_OFFSET],
                                   loadData3DParams);
}
|
||||
|
||||
template <typename LIT>
|
||||
__aicore__ inline void LIMatmul<LIT>::LoadKeyToL0b(uint64_t s2L1Offset, uint64_t s2L1RealSize, uint64_t s2L0RealSize,
|
||||
const LICommon::RunInfo &runInfo)
|
||||
{
|
||||
uint64_t keyL1Offset = s2L1Offset >= S2_BASIC_BLOCK_L0 ? S2_BASIC_BLOCK_L0 * D_BASIC_BLOCK_L0 : 0;
|
||||
LoadData2DParams loadData2DParams;
|
||||
loadData2DParams.startIndex = 0;
|
||||
loadData2DParams.repeatTimes = CeilDiv(s2L0RealSize, BLOCK_CUBE) * CeilDiv(constInfo_.headDim, BLOCK_CUBE);
|
||||
loadData2DParams.srcStride = 1;
|
||||
loadData2DParams.dstGap = 0;
|
||||
loadData2DParams.ifTranspose = false;
|
||||
LoadData(keyL0_[(l0BufIdx_ % L0_BUF_NUM) * L0AB_BUFFER_OFFSET],
|
||||
keyL1_[(keyL1BufIdx_ % KEY_BUF_NUM) * KEY_BUFFER_OFFSET + keyL1Offset], loadData2DParams);
|
||||
}
|
||||
|
||||
template <typename LIT>
|
||||
__aicore__ inline void LIMatmul<LIT>::ComuteL0c(uint64_t s1gL0RealSize, uint64_t s2L0RealSize,
|
||||
const LICommon::RunInfo &runInfo)
|
||||
{
|
||||
MmadParams mmadParams;
|
||||
mmadParams.m = CeilAlign(s1gL0RealSize, BLOCK_CUBE);
|
||||
mmadParams.n = s2L0RealSize;
|
||||
mmadParams.k = constInfo_.headDim;
|
||||
mmadParams.cmatrixInitVal = true;
|
||||
mmadParams.cmatrixSource = false;
|
||||
mmadParams.unitFlag = 0b11;
|
||||
Mmad(cL0_[(l0BufIdx_ % L0_BUF_NUM) * L0C_BUFFER_OFFSET], queryL0_[(l0BufIdx_ % L0_BUF_NUM) * L0AB_BUFFER_OFFSET],
|
||||
keyL0_[(l0BufIdx_ % L0_BUF_NUM) * L0AB_BUFFER_OFFSET], mmadParams);
|
||||
if ((mmadParams.m / 16) * (mmadParams.n / 16) < 10) {
|
||||
PipeBarrier<PIPE_M>();
|
||||
}
|
||||
}
|
||||
|
||||
// Drains one L0C result tile through the fixpipe back to the mm1 workspace in GM.
// The result is converted NZ->ND on the fly (nz2ndEn) and written into one of
// two ping-pong workspace regions selected by runInfo.loop parity so the next
// matmul iteration can overlap with the vector stage consuming this one.
template <typename LIT>
__aicore__ inline void LIMatmul<LIT>::Fixp(uint64_t s1gGmOffset, uint64_t s2GmOffset, uint64_t s1gL0RealSize,
                                           uint64_t s2L0RealSize, const LICommon::RunInfo &runInfo)
{
    AscendC::DataCopyCO12DstParams intriParams;
    intriParams.mSize = CeilAlign(s1gL0RealSize, BLOCK_CUBE);
    intriParams.nSize = s2L0RealSize;
    // Destination row stride = padded s2 extent of the workspace tile.
    intriParams.dstStride = runInfo.actualSingleProcessSInnerSizeAlign;
    intriParams.srcStride = CeilAlign(s1gL0RealSize, BLOCK_CUBE);
    // set mode according to dtype
    intriParams.quantPre = QuantMode_t::NoQuant;
    intriParams.nz2ndEn = true;
    intriParams.unitFlag = 0b11; // 3 unitflag
    intriParams.reluPre = 1;
    AscendC::SetFixpipeNz2ndFlag(1, 1, 1);
    // Ping-pong base (loop parity) + row/column offset of this sub-tile.
    AscendC::DataCopy(mm1ResGm_[(runInfo.loop % 2) * constInfo_.mBaseSize * constInfo_.s2BaseSize +
                                s1gGmOffset * intriParams.dstStride + s2GmOffset],
                      cL0_[(l0BufIdx_ % L0_BUF_NUM) * L0C_BUFFER_OFFSET], intriParams);
}
|
||||
|
||||
template <typename LIT>
|
||||
__aicore__ inline void LIMatmul<LIT>::AllocEventID()
|
||||
{
|
||||
SetMMLayoutTransform(true);
|
||||
SetFlag<HardEvent::MTE1_MTE2>(KEY_MTE1_MTE2_EVENT + 0);
|
||||
SetFlag<HardEvent::MTE1_MTE2>(KEY_MTE1_MTE2_EVENT + 1);
|
||||
SetFlag<HardEvent::MTE1_MTE2>(KEY_MTE1_MTE2_EVENT + 2);
|
||||
|
||||
SetFlag<HardEvent::MTE1_MTE2>(QUERY_MTE1_MTE2_EVENT + 0);
|
||||
SetFlag<HardEvent::MTE1_MTE2>(QUERY_MTE1_MTE2_EVENT + 1);
|
||||
|
||||
SetFlag<HardEvent::M_MTE1>(M_MTE1_EVENT + 0);
|
||||
SetFlag<HardEvent::M_MTE1>(M_MTE1_EVENT + 1);
|
||||
}
|
||||
|
||||
template <typename LIT>
|
||||
__aicore__ inline void LIMatmul<LIT>::FreeEventID()
|
||||
{
|
||||
SetMMLayoutTransform(false);
|
||||
WaitFlag<HardEvent::MTE1_MTE2>(KEY_MTE1_MTE2_EVENT + 0);
|
||||
WaitFlag<HardEvent::MTE1_MTE2>(KEY_MTE1_MTE2_EVENT + 1);
|
||||
WaitFlag<HardEvent::MTE1_MTE2>(KEY_MTE1_MTE2_EVENT + 2);
|
||||
|
||||
WaitFlag<HardEvent::MTE1_MTE2>(QUERY_MTE1_MTE2_EVENT + 0);
|
||||
WaitFlag<HardEvent::MTE1_MTE2>(QUERY_MTE1_MTE2_EVENT + 1);
|
||||
|
||||
WaitFlag<HardEvent::M_MTE1>(M_MTE1_EVENT + 0);
|
||||
WaitFlag<HardEvent::M_MTE1>(M_MTE1_EVENT + 1);
|
||||
}
|
||||
} // namespace LIKernel
|
||||
#endif
|
||||
@@ -0,0 +1,559 @@
|
||||
/**
|
||||
* This program is free software, you can redistribute it and/or modify it.
|
||||
* Copyright (c) 2025 Huawei Technologies Co., Ltd.
|
||||
* This file is a part of the CANN Open Software.
|
||||
* Licensed under CANN Open Software License Agreement Version 2.0 (the "License").
|
||||
* Please refer to the License for details. You may not use this file except in compliance with the License.
|
||||
* THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE.
|
||||
* See LICENSE in the root of the software repository for the full text of the License.
|
||||
*/
|
||||
|
||||
/*!
|
||||
* \file lightning_indexer_service_vector.h
|
||||
* \brief
|
||||
*/
|
||||
#ifndef LIGHTNING_INDEXER_SERVICE_VECTOR_H
|
||||
#define LIGHTNING_INDEXER_SERVICE_VECTOR_H
|
||||
|
||||
#include "kernel_operator.h"
|
||||
#include "kernel_operator_list_tensor_intf.h"
|
||||
#include "kernel_tiling/kernel_tiling.h"
|
||||
#include "lib/matmul_intf.h"
|
||||
#include "lib/matrix/matmul/tiling.h"
|
||||
#include "lightning_indexer_common.h"
|
||||
#include "lightning_indexer_vector.h"
|
||||
|
||||
namespace LIKernel {
|
||||
using namespace LICommon;
|
||||
using namespace LIServiceVec;
|
||||
constexpr uint32_t BASE_TOPK = 2048;
|
||||
constexpr uint32_t LD_PARAM_NUM = 16;
|
||||
|
||||
// Vector-core side of the lightning_indexer kernel: consumes the mm1 (Q*K^T)
// result tiles produced by LIMatmul, applies per-head weights, reduces over
// the head group, and maintains a running top-k (BASE_TOPK) of s2 indices per
// s1 row, spilling partial results to workspace for the cross-core LD merge.
template <typename LIT>
class LIVector {
public:
    using K_T = typename LIT::keyType;
    static constexpr LI_LAYOUT LAYOUT_T = LIT::layout;

    // mm1 workspace element type.
    using MM1_OUT_T = float;

    __aicore__ inline LIVector(){};
    // Per-tile vector stage: scale, reduce, sort/merge into the running top-k.
    __aicore__ inline void ProcessVec(const LICommon::RunInfo &info);
    // Final cross-core merge of partial top-k lists from workspace.
    __aicore__ inline void ProcessLD();
    __aicore__ inline void InitBuffers(TPipe *pipe);
    __aicore__ inline void InitParams(const struct LICommon::ConstInfo &constInfo,
                                      const LITilingData *__restrict tilingData);
    __aicore__ inline void InitVec1GlobalTensor(GlobalTensor<MM1_OUT_T> mm1ResGm, GlobalTensor<float> vec1ResGm,
                                                GlobalTensor<int64_t> vec1ParamGm, GlobalTensor<K_T> weightsGm,
                                                GlobalTensor<int32_t> indiceOutGm);
    // Fills an output row with INVALID_IDX for s1 rows with no valid result.
    __aicore__ inline void CleanInvalidOutput(int64_t invalidS1offset);
    __aicore__ inline void AllocEventID();
    __aicore__ inline void FreeEventID();
    // Re-initializes UB buffers for the LD (merge) phase; resets the pipe.
    __aicore__ inline void InitLDBuffers(TPipe *pipe);

protected:
    GlobalTensor<MM1_OUT_T> mm1ResGm;   // matmul result workspace (ping-pong)
    GlobalTensor<float> vec1ResGm;      // partial top-k value/index workspace
    GlobalTensor<int64_t> vec1ParamGm;  // per-row merge metadata workspace
    GlobalTensor<K_T> weightsGm;        // per-head weights
    GlobalTensor<int32_t> indiceOutGm;  // final selected-index output

private:
    // queue
    TQue<QuePosition::VECIN, 1> inQueue_;
    TQue<QuePosition::VECOUT, 1> outQueue_;

    // tmp buff for vector
    TBuf<TPosition::VECCALC> sortOutBuf_;
    TBuf<TPosition::VECCALC> indexBuf_;
    TBuf<TPosition::VECCALC> reduceOutBuf_;
    TBuf<TPosition::VECCALC> brcBuf_;
    TBuf<TPosition::VECCALC> paramBuf_;

    // tmp buff for LD
    TBuf<> ldToBeMrgBuf_;
    TBuf<> ldTmpBuf_;
    TBuf<> ldOutValueBuf_;
    TBuf<> ldOutIdxBuf_;

    LocalTensor<int32_t> globalTopkIndice_;  // iota 0..s2BaseSize_-1 index base
    LocalTensor<float> globalTopkUb_;        // running top-k (value,index) pairs
    LocalTensor<float> SortedBasicBlock_;    // cache of up to 4 sorted s2 blocks

    int32_t blockId_ = -1;       // AIV core id; pairs of AIVs share one cube core
    // para for vector
    int32_t groupInner_ = 0;     // head-group tile size processed per pass
    int32_t globalTopkNum_ = 0;
    int64_t blockS2StartIdx_ = 0;
    int32_t gSize_ = 0;
    int32_t kHeadNum_ = 0;
    int32_t s1BaseSize_ = 0;
    int32_t s2BaseSize_ = 0;

    // para for LD
    uint32_t mrgListNum_ = 4;    // lists merged per MrgSort issue
    uint32_t paramNum_ = 16;     // int64 metadata slots per workspace record

    // Extra leading floats to stagger reduce accesses across UB banks.
    constexpr static uint32_t REDUCE_BANK_CONFLICT_OFFSETS = 256;
    constexpr static uint32_t REDUCE_BANK_CONFLICT_NUM = REDUCE_BANK_CONFLICT_OFFSETS / sizeof(float);

    struct LICommon::ConstInfo constInfo_;
};
|
||||
|
||||
// Allocates all UB queues/buffers for the vector stage and pre-clears this
// core's slice of the vec1Param workspace to -1 so ProcessLD can detect
// untouched records. Must run before the first ProcessVec call.
template <typename LIT>
__aicore__ inline void LIVector<LIT>::InitBuffers(TPipe *pipe)
{
    // outQueue_ doubles as scratch for both the top-k extract (2*BASE_TOPK
    // value/index pairs) and the grouped reduce cache — take the larger size.
    uint32_t outNeedBufSize = (BASE_TOPK * 2) * 2 * sizeof(float);
    uint32_t reduceCacheSize = REDUCE_BANK_CONFLICT_OFFSETS + groupInner_ * s2BaseSize_ * sizeof(float);
    outNeedBufSize = reduceCacheSize > outNeedBufSize ? reduceCacheSize : outNeedBufSize;

    pipe->InitBuffer(inQueue_, 2,
                     groupInner_ * s2BaseSize_ * sizeof(float) + s2BaseSize_ * sizeof(float)); // 69KB mm_out_ub
    pipe->InitBuffer(outQueue_, 1, outNeedBufSize); // 32KB extract
    pipe->InitBuffer(sortOutBuf_, CeilDiv(s1BaseSize_, 2) * BASE_TOPK * 2 * sizeof(float)); // 64KB
    pipe->InitBuffer(indexBuf_, s2BaseSize_ * sizeof(int32_t)); // 2KB
    pipe->InitBuffer(reduceOutBuf_, s2BaseSize_ * 2 * sizeof(float)); // 4KB
    pipe->InitBuffer(brcBuf_, groupInner_ * 8 * sizeof(float));
    pipe->InitBuffer(paramBuf_, LD_PARAM_NUM * sizeof(int64_t));

    // Bind the persistent local views and seed them.
    globalTopkIndice_ = indexBuf_.Get<int32_t>();
    globalTopkUb_ = sortOutBuf_.Get<float>();
    SortedBasicBlock_ = globalTopkUb_[BASE_TOPK * 2 * 2];
    globalTopkNum_ = 0;

    // Index base 0..s2BaseSize_-1; top-k buffer starts as (-inf, -1) pairs.
    ArithProgression<int32_t>(globalTopkIndice_, 0, 1, s2BaseSize_);
    InitSortOutBuf(globalTopkUb_, CeilDiv(s1BaseSize_, 2) * BASE_TOPK * 2);
    // Pre-clear this AIV's half of the per-row metadata records in GM to -1.
    LocalTensor<float> tmpfBuff = outQueue_.AllocTensor<float>();
    Duplicate(tmpfBuff.template ReinterpretCast<int32_t>(), -1, 2 * (s1BaseSize_ / 2) * paramNum_ * 2);
    SetWaitFlag<HardEvent::V_MTE3>(HardEvent::V_MTE3);
    int64_t wsInfoOffset = (blockId_ / 2) * s1BaseSize_ * 2 * paramNum_ +
                           (blockId_ % 2) * (s1BaseSize_ / 2) * 2 * paramNum_;
    DataCopyPad(vec1ParamGm[wsInfoOffset], tmpfBuff.template ReinterpretCast<int64_t>(),
                {1, static_cast<uint16_t>((s1BaseSize_ / 2) * 2 * paramNum_ * sizeof(int64_t)), 0, 0});
    SetWaitFlag<HardEvent::MTE3_V>(HardEvent::MTE3_V);
    outQueue_.FreeTensor(tmpfBuff);
}
|
||||
|
||||
template <typename LIT>
|
||||
__aicore__ inline void LIVector<LIT>::InitLDBuffers(TPipe *pipe)
|
||||
{
|
||||
pipe->Reset();
|
||||
pipe->InitBuffer(ldToBeMrgBuf_, 2 * BASE_TOPK * mrgListNum_ * sizeof(float)); // 2:value + index
|
||||
pipe->InitBuffer(ldTmpBuf_, 2 * BASE_TOPK * mrgListNum_ * sizeof(float)); // 2:value + index
|
||||
pipe->InitBuffer(ldOutValueBuf_, BASE_TOPK * sizeof(float));
|
||||
pipe->InitBuffer(ldOutIdxBuf_, BASE_TOPK * sizeof(int32_t));
|
||||
}
|
||||
|
||||
template <typename LIT>
|
||||
__aicore__ inline void LIVector<LIT>::InitParams(const struct LICommon::ConstInfo &constInfo,
|
||||
const LITilingData *__restrict tilingData)
|
||||
{
|
||||
this->constInfo_ = constInfo;
|
||||
blockS2StartIdx_ = 0;
|
||||
gSize_ = constInfo.gSize;
|
||||
// define N2 para
|
||||
kHeadNum_ = constInfo.kHeadNum;
|
||||
// define MMBase para
|
||||
s1BaseSize_ = constInfo.s1BaseSize;
|
||||
s2BaseSize_ = constInfo.s2BaseSize;
|
||||
|
||||
groupInner_ = 16;
|
||||
blockId_ = GetBlockIdx();
|
||||
}
|
||||
|
||||
template <typename LIT>
|
||||
__aicore__ inline void
|
||||
LIVector<LIT>::InitVec1GlobalTensor(GlobalTensor<MM1_OUT_T> mm1ResGm, GlobalTensor<float> vec1ResGm,
|
||||
GlobalTensor<int64_t> vec1ParamGm, GlobalTensor<K_T> weightsGm,
|
||||
GlobalTensor<int32_t> indiceOutGm)
|
||||
{
|
||||
this->mm1ResGm = mm1ResGm;
|
||||
this->vec1ResGm = vec1ResGm;
|
||||
this->vec1ParamGm = vec1ParamGm;
|
||||
this->weightsGm = weightsGm;
|
||||
this->indiceOutGm = indiceOutGm;
|
||||
}
|
||||
|
||||
// Intentionally empty: the vector stage synchronizes with SetWaitFlag pairs
// inline and pre-allocates no persistent event IDs (kept for interface
// symmetry with LIMatmul::AllocEventID).
template <typename LIT>
__aicore__ inline void LIVector<LIT>::AllocEventID()
{
}
|
||||
|
||||
// Intentionally empty: nothing allocated in AllocEventID, so nothing to
// release (kept for interface symmetry with LIMatmul::FreeEventID).
template <typename LIT>
__aicore__ inline void LIVector<LIT>::FreeEventID()
{
}
|
||||
|
||||
template <typename LIT>
|
||||
__aicore__ inline void LIVector<LIT>::CleanInvalidOutput(int64_t invalidS1offset)
|
||||
{
|
||||
// init -1 and copy to output
|
||||
LocalTensor<float> valueULocal = outQueue_.AllocTensor<float>();
|
||||
LocalTensor<int32_t> idxULocal1 = valueULocal.template ReinterpretCast<int32_t>();
|
||||
Duplicate(idxULocal1, constInfo_.INVALID_IDX, constInfo_.sparseCount);
|
||||
outQueue_.EnQue<float>(valueULocal);
|
||||
valueULocal = outQueue_.DeQue<float>();
|
||||
LIServiceVec::CopyOut(indiceOutGm[invalidS1offset], idxULocal1, constInfo_.sparseCount);
|
||||
outQueue_.FreeTensor(valueULocal);
|
||||
}
|
||||
|
||||
// Vector stage for one (s1 tile, s2 tile) pair of the mm1 result:
//  1) split the s1 tile between the two AIVs of a cube core,
//  2) per s1 row: load mm1 scores + weights, scale, reduce over the head
//     group, and sort the s2 block,
//  3) merge the sorted block into the running per-row top-k (two strategies:
//     full merge-sort for larger s1, a 4-block cache + SparseTopK otherwise),
//  4) when the s2 range for the row is complete, either write final indices
//     to GM or spill the partial top-k + metadata to workspace for ProcessLD.
// Also clears output rows that are invalid (causal mask / padded s1).
template <typename LIT>
__aicore__ inline void LIVector<LIT>::ProcessVec(const LICommon::RunInfo &info)
{
    int32_t cuBaseS1Idx = info.gS1Idx * s1BaseSize_;
    int32_t cuBaseS2Idx = info.s2Idx * s2BaseSize_;

    // Ping-pong select the mm1 workspace half written by the cube core.
    int64_t mmGmOffset = (info.loop % 2) * ((s1BaseSize_ * gSize_) * s2BaseSize_);
    int64_t weightGmOffset = info.tensorWeightsOffset + cuBaseS1Idx * kHeadNum_ * gSize_;

    PipeBarrier<PIPE_V>();
    // Split the s1 rows of this tile between the two AIVs (even id takes the
    // first, rounded-up half).
    int32_t cuS1BeginIdxPerAiv = cuBaseS1Idx;
    int32_t cuS1ProcNum =
        cuS1BeginIdxPerAiv + s1BaseSize_ > info.actS1Size ? info.actS1Size % s1BaseSize_ : s1BaseSize_;
    int32_t cuS1ProcNumPerAiv = blockId_ % 2 == 0 ? CeilDiv(cuS1ProcNum, 2) : (cuS1ProcNum / 2);
    cuS1BeginIdxPerAiv += (blockId_ % 2) * CeilDiv(cuS1ProcNum, 2);

    weightGmOffset += (blockId_ % 2) * CeilDiv(cuS1ProcNum, 2) * kHeadNum_ * gSize_;
    mmGmOffset += (blockId_ % 2) * CeilDiv(cuS1ProcNum, 2) * gSize_ * info.actualSingleProcessSInnerSizeAlign;

    // cut G
    int32_t outerG = CeilDiv(gSize_, groupInner_);

    // Reset the running top-k when a new s1 tile starts (s2 loop restarts).
    if (info.loop != 0 && info.s2Idx == 0) {
        // globalTopkUb_ value,index=-inf,-1
        InitSortOutBuf(globalTopkUb_, CeilDiv(s1BaseSize_, 2) * BASE_TOPK * 2);
        blockS2StartIdx_ = 0;
    } else if (info.loop == 0) {
        blockS2StartIdx_ = info.s2Idx;
    }
    // Effective s2 length for this AIV's first row (causal mask shortens it).
    int32_t cuRealAcSeq = info.actS2Size;
    if (constInfo_.attenMaskFlag) {
        cuRealAcSeq = info.actS2Size - (info.actS1Size - cuS1BeginIdxPerAiv);
    }
    LocalTensor<float> reduceOutBuff = reduceOutBuf_.Get<float>();
    LocalTensor<float> brcBuf = brcBuf_.Get<float>();
    // Row offset of this AIV's rows inside the per-core workspace records.
    uint32_t ldS1Offset = (blockId_ % 2 == 0) ? s1BaseSize_ / 2 - cuS1ProcNumPerAiv : 0;
    for (int innerS1Idx = 0; innerS1Idx < cuS1ProcNumPerAiv; innerS1Idx++) {
        if (constInfo_.attenMaskFlag) {
            cuRealAcSeq += 1; // each later s1 row sees one more key (causal)
        }
        int32_t cuS2Len = cuBaseS2Idx + s2BaseSize_ >= cuRealAcSeq ? cuRealAcSeq - cuBaseS2Idx : s2BaseSize_;
        int32_t cuS1Idx = cuS1BeginIdxPerAiv + innerS1Idx;
        if (cuRealAcSeq > 0 && cuS2Len > 0) {
            int32_t cuS2LenVecAlign = CeilDiv(cuS2Len, s2BaseSize_) * s2BaseSize_;
            int32_t mmUbStride = (cuS2LenVecAlign - info.actualSingleProcessSInnerSizeAlign) / B32_BLOCK_ALIGN_NUM;
            LocalTensor<float> reduceOutInner = reduceOutBuff[s2BaseSize_];
            PipeBarrier<PIPE_V>();
            LocalTensor<float> reduceCacheBuf = outQueue_.AllocTensor<float>();
            // Process the head group in tiles of groupInner_: load scores +
            // weights, scale, and accumulate into the reduce cache.
            for (int outerGidx = 0; outerGidx < outerG; outerGidx++) {
                int32_t procGnum = outerGidx != outerG - 1 ? groupInner_ : gSize_ - outerGidx * groupInner_;
                LocalTensor<float> mmInUb = inQueue_.AllocTensor<float>();
                LocalTensor<float> weightsInUb = mmInUb[procGnum * s2BaseSize_];
                LocalTensor<K_T> weightsInTUb = weightsInUb.template ReinterpretCast<K_T>();
                if constexpr (!IsSameType<K_T, float>::value) {
                    // Half-precision weights land in the upper half of the slot.
                    weightsInTUb = weightsInTUb[groupInner_];
                }
                LIServiceVec::CopyIn(mmInUb, weightsInTUb, mm1ResGm, weightsGm,
                                     mmGmOffset + innerS1Idx * gSize_ * info.actualSingleProcessSInnerSizeAlign +
                                         outerGidx * groupInner_ * info.actualSingleProcessSInnerSizeAlign,
                                     weightGmOffset + innerS1Idx * gSize_ + outerGidx * groupInner_, procGnum,
                                     info.actualSingleProcessSInnerSizeAlign, mmUbStride);

                inQueue_.EnQue<float>(mmInUb);
                mmInUb = inQueue_.DeQue<float>();
                weightsInUb = mmInUb[procGnum * s2BaseSize_];
                LIServiceVec::DoScale(reduceCacheBuf[REDUCE_BANK_CONFLICT_NUM], mmInUb, weightsInUb, weightsInTUb,
                                      brcBuf, procGnum, s2BaseSize_, outerGidx);
                // confused reduceOp in DoScale
                // neednot use LIServiceVec::doReduce(mmInUb, reduceOutInner, procGnum, (s2BaseSize_+8));
                inQueue_.FreeTensor(mmInUb);
            }

            // Final reduce over the (cached) group rows into one score per s2.
            int32_t gRedCnt = groupInner_ > gSize_ ? gSize_ : groupInner_;
            bool isS2End = cuBaseS2Idx + s2BaseSize_ >= cuRealAcSeq;
            LIServiceVec::DoReduce(reduceCacheBuf[REDUCE_BANK_CONFLICT_NUM], reduceOutInner, gRedCnt, s2BaseSize_);
            outQueue_.FreeTensor(reduceCacheBuf);

            // Build (score, global s2 index) pairs, padding the tail of the
            // block with (-inf, -1) so sorting ignores it.
            LocalTensor<float> sortScoreUb = reduceOutBuff;
            LocalTensor<float> sortIndiceUb = reduceOutBuff[cuS2LenVecAlign];
            PipeBarrier<PIPE_V>();
            Duplicate(sortScoreUb.template ReinterpretCast<int32_t>(), LIServiceVec::NEG_INF, cuS2LenVecAlign);
            PipeBarrier<PIPE_V>();
            Adds(sortScoreUb, reduceOutInner, 0.0f, cuS2Len);
            PipeBarrier<PIPE_V>();
            LocalTensor<int32_t> sortIndiceUbInt = sortIndiceUb.template ReinterpretCast<int32_t>();
            if (cuS2LenVecAlign != cuS2Len) {
                Duplicate(sortIndiceUbInt, -1, cuS2LenVecAlign);
            }
            PipeBarrier<PIPE_V>();
            Adds(sortIndiceUbInt, globalTopkIndice_, static_cast<int32_t>(cuBaseS2Idx), cuS2Len);
            PipeBarrier<PIPE_V>();

            LocalTensor<float> tmpSortBuf = outQueue_.AllocTensor<float>();
            if (info.actS1Size > 4) {
                // Larger s1: sort this block and merge directly into the row's top-k.
                LIServiceVec::SortAll(reduceOutBuff, tmpSortBuf,
                                      cuS2LenVecAlign); // cuS2LenVecAlign <= s2BaseSize_, fill -inf
                PipeBarrier<PIPE_V>();
                LIServiceVec::MergeSort(globalTopkUb_[innerS1Idx * BASE_TOPK * 2], BASE_TOPK, reduceOutBuff,
                                        cuS2LenVecAlign, tmpSortBuf);
            } else {
                // Small s1: cache up to 4 sorted blocks, then batch-merge them.
                int64_t globalTopkUbCacheIdx = (info.s2Idx - blockS2StartIdx_) % 4;
                Sort<float, true>(
                    SortedBasicBlock_[innerS1Idx * BASE_TOPK * 2 + globalTopkUbCacheIdx * s2BaseSize_ * 2],
                    reduceOutBuff, sortIndiceUbInt.template ReinterpretCast<uint32_t>(), tmpSortBuf,
                    cuS2LenVecAlign / 32);
                if (globalTopkUbCacheIdx == 3 || isS2End || info.isAllLoopEnd) {
                    LocalTensor<float> tt = SortedBasicBlock_[innerS1Idx * BASE_TOPK * 2];
                    if (info.s2Idx - blockS2StartIdx_ < 4) {
                        // First batch: the cached blocks become the top-k directly.
                        MrgBasicBlock(globalTopkUb_[innerS1Idx * BASE_TOPK * 2], tt,
                                      static_cast<int64_t>(globalTopkUbCacheIdx + 1), s2BaseSize_);
                    } else {
                        // Later batches: merge cached blocks, then fold into the
                        // existing top-k with SparseTopK.
                        if (globalTopkUbCacheIdx > 0) {
                            MrgBasicBlock(tmpSortBuf, tt, static_cast<int64_t>(globalTopkUbCacheIdx + 1), s2BaseSize_);
                            PipeBarrier<PIPE_V>();
                            DataCopy(SortedBasicBlock_[innerS1Idx * BASE_TOPK * 2], tmpSortBuf,
                                     (globalTopkUbCacheIdx + 1) * s2BaseSize_ * 2);
                        }
                        PipeBarrier<PIPE_V>();
                        SparseTopK(globalTopkUb_[innerS1Idx * BASE_TOPK * 2],
                                   SortedBasicBlock_[innerS1Idx * BASE_TOPK * 2], tmpSortBuf, BASE_TOPK,
                                   s2BaseSize_ * (globalTopkUbCacheIdx + 1));
                    }
                }
            }

            PipeBarrier<PIPE_V>();
            outQueue_.FreeTensor(tmpSortBuf);

            // This core saw the whole s2 range -> final output; otherwise spill
            // the partial result for the cross-core LD merge.
            bool needCopyOutGm = blockS2StartIdx_ == 0 && isS2End;
            bool needCopyWsGm = info.isAllLoopEnd || isS2End;

            if (needCopyOutGm) {
                LocalTensor<float> valueULocal = outQueue_.AllocTensor<float>();
                LocalTensor<uint32_t> idxULocal = valueULocal.template ReinterpretCast<uint32_t>()[BASE_TOPK];
                ExtractIndex(idxULocal, globalTopkUb_[innerS1Idx * BASE_TOPK * 2].template ReinterpretCast<uint32_t>(),
                             BASE_TOPK);
                PipeBarrier<PIPE_V>();
                InitSortOutBuf(globalTopkUb_[innerS1Idx * BASE_TOPK * 2], BASE_TOPK * 2);
                outQueue_.EnQue<float>(valueULocal);
                valueULocal = outQueue_.DeQue<float>();
                LocalTensor<int32_t> idxULocal1 = valueULocal.template ReinterpretCast<int32_t>()[BASE_TOPK];
                LIServiceVec::CopyOut(indiceOutGm[info.indiceOutOffset + cuS1Idx * constInfo_.sparseCount],
                                      idxULocal1, constInfo_.sparseCount);
                outQueue_.FreeTensor(valueULocal);
            } else if (needCopyWsGm) {
                // vec1Res Gm = [aic, s1BaseSize_, 2, 2, topkOut_] float32
                // vec1Param Gm = [aic, s1BaseSize_, 2, 16] int64
                // 16 = [needFd, s2AcSeq, s2Start, s2End, isS2End, bn2idx, s1Idx, S1ProcNum, ......]

                int64_t wsOffset = (blockId_ / 2) * s1BaseSize_ * 2 * 2 * BASE_TOPK +
                                   (blockId_ % 2) * (s1BaseSize_ / 2) * 2 * 2 * BASE_TOPK +
                                   (ldS1Offset + innerS1Idx) * 2 * 2 * BASE_TOPK;
                int64_t wsInfoOffset = (blockId_ / 2) * s1BaseSize_ * 2 * paramNum_ +
                                       (blockId_ % 2) * (s1BaseSize_ / 2) * 2 * paramNum_ +
                                       (ldS1Offset + innerS1Idx) * 2 * paramNum_;

                // Fill the 16-slot metadata record for this row.
                LocalTensor<int64_t> tmpiBuff = paramBuf_.Get<int64_t>();
                SetWaitFlag<HardEvent::MTE3_S>(HardEvent::MTE3_S);
                tmpiBuff.SetValue(0, static_cast<int64_t>(1));
                tmpiBuff.SetValue(1, static_cast<int64_t>(cuRealAcSeq));
                tmpiBuff.SetValue(2, static_cast<int64_t>(blockS2StartIdx_));
                tmpiBuff.SetValue(3, static_cast<int64_t>(cuBaseS2Idx + cuS2Len));
                tmpiBuff.SetValue(4, static_cast<int64_t>(isS2End));
                tmpiBuff.SetValue(5, static_cast<int64_t>(info.bN2Idx));
                tmpiBuff.SetValue(6, static_cast<int64_t>(cuS1Idx));
                tmpiBuff.SetValue(7, static_cast<int64_t>(cuS1ProcNum));
                tmpiBuff.SetValue(8, static_cast<int64_t>(info.indiceOutOffset + cuS1Idx * constInfo_.sparseCount));
                // Tail-reduce records go into the second half of the slot pair.
                bool isTailReduce = blockS2StartIdx_ == 0;
                if (isTailReduce) {
                    wsInfoOffset += paramNum_;
                    wsOffset += 2 * BASE_TOPK;
                }
                SetWaitFlag<HardEvent::S_MTE3>(HardEvent::S_MTE3);
                LIServiceVec::CopyOut(vec1ParamGm[wsInfoOffset], tmpiBuff, 16);
                SetWaitFlag<HardEvent::V_MTE3>(HardEvent::V_MTE3);
                LIServiceVec::CopyOut(vec1ResGm[wsOffset], globalTopkUb_[innerS1Idx * BASE_TOPK * 2], 2 * BASE_TOPK);
                SetWaitFlag<HardEvent::MTE3_V>(HardEvent::MTE3_V);
            }
        } else if (cuRealAcSeq <= 0) {
            // Row entirely masked out: emit invalid markers.
            CleanInvalidOutput(info.indiceOutOffset + cuS1Idx * constInfo_.sparseCount);
        }
    }

    if (LAYOUT_T == LI_LAYOUT::BSND) {
        // Clear padded s1 rows (beyond actS1Size) at the end of the s1 loop.
        bool isS1LoopEnd = (cuBaseS1Idx + s1BaseSize_) >= info.actS1Size;
        int32_t invalidS1Num = constInfo_.qSeqSize - info.actS1Size;
        if (invalidS1Num > 0 && isS1LoopEnd && blockS2StartIdx_ == 0) {
            int32_t s1NumPerAiv = blockId_ % 2 == 0 ? CeilDiv(invalidS1Num, 2) : (invalidS1Num / 2);
            int32_t s1OffsetPerAiv = info.actS1Size + (blockId_ % 2) * CeilDiv(invalidS1Num, 2);
            for (int innerS1Idx = 0; innerS1Idx < s1NumPerAiv; innerS1Idx++) {
                CleanInvalidOutput(info.indiceOutOffset + (s1OffsetPerAiv + innerS1Idx) * constInfo_.sparseCount);
            }
        }

        // With a causal mask, rows whose window is empty (s1 > s2) are invalid.
        int32_t invalidS1Num2 = info.actS1Size - info.actS2Size;
        if (invalidS1Num2 > 0 && isS1LoopEnd && blockS2StartIdx_ == 0 && constInfo_.attenMaskFlag) {
            int32_t s1NumPerAiv = blockId_ % 2 == 0 ? CeilDiv(invalidS1Num2, 2) : (invalidS1Num2 / 2);
            int32_t s1OffsetPerAiv = (blockId_ % 2) * CeilDiv(invalidS1Num2, 2);
            for (int innerS1Idx = 0; innerS1Idx < s1NumPerAiv; innerS1Idx++) {
                CleanInvalidOutput((info.bN2Idx * constInfo_.qSeqSize + s1OffsetPerAiv + innerS1Idx) *
                                   constInfo_.sparseCount);
            }
        }
    }

    if (info.isLastS2InnerLoop) {
        blockS2StartIdx_ = 0;
    }
}
|
||||
|
||||
// LD (list-distribution) phase: merges partial top-k lists that several cube
// cores spilled to workspace for the same s1 row. Starting from this core's
// own record, it walks forward through successive cube cores' records (needFd
// flag), accumulating up to mrgListNum_ lists and batch-merging them with
// MrgSort until the record flagged isS2End is consumed, then extracts the
// final indices and writes them to the output.
template <typename LIT>
__aicore__ inline void LIVector<LIT>::ProcessLD()
{
    int32_t curCubeId = blockId_ / 2;
    int32_t tmpCubeId = curCubeId;

    // NOTE(review): s2ActSeq, s2Start, s2End, bn2Idx, s1Idx, bIdx and
    // nextneedFd are declared (s1Idx even loaded) but not used below.
    int64_t s2ActSeq;
    int64_t s2Start;
    int64_t s2End;
    int64_t isS2End;
    int64_t bn2Idx;
    int64_t s1Idx;
    uint32_t acc_list_num = 0; // lists currently staged in curValueIdxUb
    int64_t bIdx = 0;
    int64_t needFd;
    int64_t wsOffset;
    int64_t wsInfoOffset = 0;
    int64_t nextneedFd;
    int64_t valueOffset = 0;   // write cursor into the staging buffer
    int64_t outOffset = 0;     // final output offset, read from the record

    LocalTensor<float> curValueIdxUb = ldToBeMrgBuf_.Get<float>();
    LocalTensor<float> tmpUb = ldTmpBuf_.Get<float>();

    // Scan this cube core's records to find which s1 rows need a merge
    // (needFd flag in the tail-reduce record at offset +paramNum_).
    uint32_t s1LdStartIdx = 0;
    uint32_t s1ProcNum = 0;
    uint64_t paramGmCoreOffset = tmpCubeId * s1BaseSize_ * 2 * paramNum_;
    for (uint32_t innerS1Idx = 0; innerS1Idx < s1BaseSize_; innerS1Idx++) {
        needFd = vec1ParamGm.GetValue(paramGmCoreOffset + innerS1Idx * 2 * paramNum_ + paramNum_);
        if (needFd == 1) {
            s1LdStartIdx = (s1ProcNum == 0) ? innerS1Idx : s1LdStartIdx;
            s1ProcNum++;
        }
    }

    if (s1ProcNum == 0) {
        return; // nothing spilled for this core
    }

    // Split the rows between the two AIVs of this cube core.
    uint32_t s1VecNum = CeilDiv(s1ProcNum, 2);
    if (blockId_ % 2 == 1) {
        s1LdStartIdx = s1LdStartIdx + s1VecNum;
        s1VecNum = s1ProcNum - s1VecNum;
    }
    for (uint32_t innerS1Idx = s1LdStartIdx; innerS1Idx < s1LdStartIdx + s1VecNum; innerS1Idx++) {
        tmpCubeId = curCubeId;
        acc_list_num = 0;
        valueOffset = 0;

        // Seed with this core's own tail-reduce list (second half of the slot).
        wsOffset = tmpCubeId * s1BaseSize_ * 2 * 2 * BASE_TOPK +
                   innerS1Idx * 2 * 2 * BASE_TOPK + 2 * BASE_TOPK;
        SetWaitFlag<HardEvent::V_MTE2>(HardEvent::V_MTE2);
        SetWaitFlag<HardEvent::S_MTE2>(HardEvent::S_MTE2);
        DataCopyPad(curValueIdxUb, vec1ResGm[wsOffset],
                    {1, static_cast<uint16_t>(2 * BASE_TOPK * sizeof(int32_t)), 0, 0}, {true, 0, 0, 0});
        acc_list_num++;
        valueOffset += 2 * BASE_TOPK;

        // Walk forward through successor cores' records for this row.
        tmpCubeId++;
        wsInfoOffset = tmpCubeId * s1BaseSize_ * 2 * paramNum_ + innerS1Idx * 2 * paramNum_;
        needFd = vec1ParamGm.GetValue(wsInfoOffset);
        isS2End = vec1ParamGm.GetValue(wsInfoOffset + 4);
        s1Idx = vec1ParamGm.GetValue(wsInfoOffset + 6);
        outOffset = vec1ParamGm.GetValue(wsInfoOffset + 8);

        while (needFd == 1) {
            wsOffset = tmpCubeId * s1BaseSize_ * 2 * 2 * BASE_TOPK +
                       innerS1Idx * 2 * 2 * BASE_TOPK;
            SetWaitFlag<HardEvent::V_MTE2>(HardEvent::V_MTE2);
            SetWaitFlag<HardEvent::S_MTE2>(HardEvent::S_MTE2);
            DataCopyPad(curValueIdxUb[valueOffset], vec1ResGm[wsOffset],
                        {1, static_cast<uint16_t>(2 * BASE_TOPK * sizeof(int32_t)), 0, 0}, {true, 0, 0, 0});
            valueOffset += 2 * BASE_TOPK;
            acc_list_num++;

            // Staging buffer full: collapse the 4 lists into one sorted list.
            if (acc_list_num == mrgListNum_) {
                AscendC::MrgSort4Info params;
                params.elementLengths[0] = BASE_TOPK;
                params.elementLengths[1] = BASE_TOPK;
                params.elementLengths[2] = BASE_TOPK;
                params.elementLengths[3] = BASE_TOPK;
                params.ifExhaustedSuspension = true;
                params.validBit = 0b1111;
                params.repeatTimes = 1;

                AscendC::MrgSortSrcList<float> srcList;
                srcList.src1 = curValueIdxUb[0];
                srcList.src2 = curValueIdxUb[2 * BASE_TOPK];
                srcList.src3 = curValueIdxUb[4 * BASE_TOPK];
                srcList.src4 = curValueIdxUb[6 * BASE_TOPK];
                SetWaitFlag<HardEvent::MTE2_V>(HardEvent::MTE2_V);
                MrgSort(tmpUb, srcList, params);
                PipeBarrier<PIPE_V>();
                DataCopy(curValueIdxUb, tmpUb, 2 * BASE_TOPK);
                PipeBarrier<PIPE_V>();
                acc_list_num = 1;
                valueOffset = 2 * BASE_TOPK;
            }

            if (isS2End == 1) {
                break; // this record completed the s2 range for the row
            }

            tmpCubeId++;
            wsInfoOffset = tmpCubeId * s1BaseSize_ * 2 * paramNum_ + innerS1Idx * 2 * paramNum_;
            needFd = vec1ParamGm.GetValue(wsInfoOffset);
            isS2End = vec1ParamGm.GetValue(wsInfoOffset + 4);
        }

        // Merge any leftover (2 or 3) staged lists.
        if (acc_list_num != 1) {
            AscendC::MrgSort4Info params;
            params.elementLengths[0] = BASE_TOPK;
            params.elementLengths[1] = BASE_TOPK;
            params.elementLengths[2] = BASE_TOPK;
            params.elementLengths[3] = BASE_TOPK;
            params.ifExhaustedSuspension = true;
            if (acc_list_num == 2) {
                params.validBit = 0b0011;
            } else if (acc_list_num == 3) {
                params.validBit = 0b0111;
            }
            params.repeatTimes = 1;

            AscendC::MrgSortSrcList<float> srcList;
            srcList.src1 = curValueIdxUb[0];
            srcList.src2 = curValueIdxUb[2 * BASE_TOPK];
            srcList.src3 = curValueIdxUb[4 * BASE_TOPK];
            srcList.src4 = curValueIdxUb[6 * BASE_TOPK];
            SetWaitFlag<HardEvent::MTE2_V>(HardEvent::MTE2_V);
            MrgSort(tmpUb, srcList, params);
            PipeBarrier<PIPE_V>();
            DataCopy(curValueIdxUb, tmpUb, 2 * BASE_TOPK);
            PipeBarrier<PIPE_V>();
        }

        // Split the merged (value, index) pairs and write the indices out.
        LocalTensor<float> outValueUb = ldOutValueBuf_.Get<float>();
        LocalTensor<uint32_t> outIdxUb = ldOutIdxBuf_.Get<uint32_t>();

        Extract(outValueUb, outIdxUb, curValueIdxUb, (BASE_TOPK / 32));
        LocalTensor<int32_t> idxULocal1 = outIdxUb.template ReinterpretCast<int32_t>();
        SetWaitFlag<HardEvent::V_MTE3>(HardEvent::V_MTE3);
        SetWaitFlag<HardEvent::S_MTE3>(HardEvent::S_MTE3);
        DataCopyPad(indiceOutGm[outOffset], idxULocal1,
                    {1, static_cast<uint16_t>(constInfo_.sparseCount * sizeof(int32_t)), 0, 0});
        SetWaitFlag<HardEvent::MTE3_V>(HardEvent::MTE3_V);
    }
}
|
||||
} // namespace LIKernel
|
||||
#endif
|
||||
@@ -0,0 +1,66 @@
|
||||
/**
|
||||
* This program is free software, you can redistribute it and/or modify it.
|
||||
* Copyright (c) 2025 Huawei Technologies Co., Ltd.
|
||||
* This file is a part of the CANN Open Software.
|
||||
* Licensed under CANN Open Software License Agreement Version 2.0 (the "License").
|
||||
* Please refer to the License for details. You may not use this file except in compliance with the License.
|
||||
* THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE.
|
||||
* See LICENSE in the root of the software repository for the full text of the License.
|
||||
*/
|
||||
|
||||
/*!
|
||||
* \file lightning_indexer_template_tiling_key.h
|
||||
* \brief
|
||||
*/
|
||||
|
||||
#ifndef TEMPLATE_TILING_KEY_LI_H_
|
||||
#define TEMPLATE_TILING_KEY_LI_H_
|
||||
|
||||
#include "ascendc/host_api/tiling/template_argument.h"
|
||||
|
||||
#define LI_TPL_FP16 1
|
||||
#define LI_TPL_INT32 3
|
||||
#define LI_TPL_BF16 27
|
||||
|
||||
#define LI_LAYOUT_BSND 0
|
||||
#define LI_LAYOUT_TND 1
|
||||
#define LI_LAYOUT_PA_BSND 2
|
||||
|
||||
#define ASCENDC_TPL_4_BW 4
|
||||
|
||||
ASCENDC_TPL_ARGS_DECL(LightningIndexer,
|
||||
ASCENDC_TPL_DTYPE_DECL(DT_Q, LI_TPL_FP16, LI_TPL_BF16),
|
||||
ASCENDC_TPL_DTYPE_DECL(DT_K, LI_TPL_FP16, LI_TPL_BF16),
|
||||
ASCENDC_TPL_DTYPE_DECL(DT_OUT, LI_TPL_INT32), ASCENDC_TPL_BOOL_DECL(PAGE_ATTENTION, 0, 1),
|
||||
ASCENDC_TPL_UINT_DECL(LAYOUT_T, ASCENDC_TPL_4_BW, ASCENDC_TPL_UI_LIST, LI_LAYOUT_BSND,
|
||||
LI_LAYOUT_TND),
|
||||
ASCENDC_TPL_UINT_DECL(K_LAYOUT_T, ASCENDC_TPL_4_BW, ASCENDC_TPL_UI_LIST,
|
||||
LI_LAYOUT_PA_BSND, LI_LAYOUT_BSND, LI_LAYOUT_TND), );
|
||||
|
||||
ASCENDC_TPL_SEL(
|
||||
ASCENDC_TPL_ARGS_SEL(ASCENDC_TPL_DTYPE_SEL(DT_Q, LI_TPL_FP16), ASCENDC_TPL_DTYPE_SEL(DT_K, LI_TPL_FP16),
|
||||
ASCENDC_TPL_DTYPE_SEL(DT_OUT, LI_TPL_INT32),
|
||||
ASCENDC_TPL_BOOL_SEL(PAGE_ATTENTION, 1),
|
||||
ASCENDC_TPL_UINT_SEL(LAYOUT_T, ASCENDC_TPL_UI_LIST, LI_LAYOUT_BSND, LI_LAYOUT_TND),
|
||||
ASCENDC_TPL_UINT_SEL(K_LAYOUT_T, ASCENDC_TPL_UI_LIST, LI_LAYOUT_PA_BSND), ),
|
||||
|
||||
ASCENDC_TPL_ARGS_SEL(ASCENDC_TPL_DTYPE_SEL(DT_Q, LI_TPL_BF16), ASCENDC_TPL_DTYPE_SEL(DT_K, LI_TPL_BF16),
|
||||
ASCENDC_TPL_DTYPE_SEL(DT_OUT, LI_TPL_INT32),
|
||||
ASCENDC_TPL_BOOL_SEL(PAGE_ATTENTION, 1),
|
||||
ASCENDC_TPL_UINT_SEL(LAYOUT_T, ASCENDC_TPL_UI_LIST, LI_LAYOUT_BSND, LI_LAYOUT_TND),
|
||||
ASCENDC_TPL_UINT_SEL(K_LAYOUT_T, ASCENDC_TPL_UI_LIST, LI_LAYOUT_PA_BSND), ),
|
||||
|
||||
ASCENDC_TPL_ARGS_SEL(ASCENDC_TPL_DTYPE_SEL(DT_Q, LI_TPL_FP16), ASCENDC_TPL_DTYPE_SEL(DT_K, LI_TPL_FP16),
|
||||
ASCENDC_TPL_DTYPE_SEL(DT_OUT, LI_TPL_INT32),
|
||||
ASCENDC_TPL_BOOL_SEL(PAGE_ATTENTION, 0),
|
||||
ASCENDC_TPL_UINT_SEL(LAYOUT_T, ASCENDC_TPL_UI_LIST, LI_LAYOUT_BSND, LI_LAYOUT_TND),
|
||||
ASCENDC_TPL_UINT_SEL(K_LAYOUT_T, ASCENDC_TPL_UI_LIST,
|
||||
LI_LAYOUT_BSND, LI_LAYOUT_TND), ),
|
||||
|
||||
ASCENDC_TPL_ARGS_SEL(ASCENDC_TPL_DTYPE_SEL(DT_Q, LI_TPL_BF16), ASCENDC_TPL_DTYPE_SEL(DT_K, LI_TPL_BF16),
|
||||
ASCENDC_TPL_DTYPE_SEL(DT_OUT, LI_TPL_INT32),
|
||||
ASCENDC_TPL_BOOL_SEL(PAGE_ATTENTION, 0),
|
||||
ASCENDC_TPL_UINT_SEL(LAYOUT_T, ASCENDC_TPL_UI_LIST, LI_LAYOUT_BSND, LI_LAYOUT_TND),
|
||||
ASCENDC_TPL_UINT_SEL(K_LAYOUT_T, ASCENDC_TPL_UI_LIST, LI_LAYOUT_BSND, LI_LAYOUT_TND), ), );
|
||||
|
||||
#endif
|
||||
335
csrc/lightning_indexer/op_kernel/lightning_indexer_vector.h
Normal file
335
csrc/lightning_indexer/op_kernel/lightning_indexer_vector.h
Normal file
@@ -0,0 +1,335 @@
|
||||
/**
|
||||
* This program is free software, you can redistribute it and/or modify it.
|
||||
* Copyright (c) 2025 Huawei Technologies Co., Ltd.
|
||||
* This file is a part of the CANN Open Software.
|
||||
* Licensed under CANN Open Software License Agreement Version 2.0 (the "License").
|
||||
* Please refer to the License for details. You may not use this file except in compliance with the License.
|
||||
* THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE.
|
||||
* See LICENSE in the root of the software repository for the full text of the License.
|
||||
*/
|
||||
|
||||
/*!
|
||||
* \file lightning_indexer_vector.h
|
||||
* \brief
|
||||
*/
|
||||
#ifndef LIGHTNING_INDEXER_VECTOR_H
|
||||
#define LIGHTNING_INDEXER_VECTOR_H
|
||||
|
||||
#include "lightning_indexer_vector.h"
|
||||
#include "kernel_operator.h"
|
||||
|
||||
namespace LIServiceVec {
|
||||
using namespace AscendC;
|
||||
|
||||
constexpr int32_t NEG_INF = 0xFF800000;
|
||||
constexpr int32_t INVALID_INDEX = -1;
|
||||
constexpr uint8_t VEC_REPEAT_MAX = 255;
|
||||
constexpr uint8_t B32_VEC_ELM_NUM = 64;
|
||||
constexpr uint8_t B32_BLOCK_ALIGN_NUM = 8;
|
||||
constexpr uint8_t B32_VEC_REPEAT_STRIDE = 8;
|
||||
constexpr uint64_t VEC_REPEAT_BYTES = 256;
|
||||
constexpr int32_t CONST_TWO = 2;
|
||||
constexpr int64_t VALUE_AND_INDEX_NUM = 2;
|
||||
constexpr int64_t BLOCK_BYTES = 32;
|
||||
constexpr int64_t MRG_QUE_0 = 0;
|
||||
constexpr int64_t MRG_QUE_1 = 1;
|
||||
constexpr int64_t MRG_QUE_2 = 2;
|
||||
constexpr int64_t MRG_QUE_3 = 3;
|
||||
constexpr int64_t MRG_BLOCK_2 = 2;
|
||||
constexpr int64_t MRG_BLOCK_3 = 3;
|
||||
constexpr int64_t MRG_BLOCK_4 = 4;
|
||||
|
||||
template <typename T>
|
||||
__aicore__ inline void CopyIn(LocalTensor<float> &mmOutUb, LocalTensor<T> &weightsUb, GlobalTensor<float> &mMoutGm,
|
||||
GlobalTensor<T> &weightScaleGm, int64_t MMout_gmoffset, int64_t weights_gmoffset,
|
||||
int64_t groupInner, int64_t s2Inner, int64_t mmUbStride)
|
||||
{
|
||||
AscendC::DataCopyPadExtParams<float> padParams{false, 0, 0, 0};
|
||||
AscendC::DataCopyExtParams dataCopymMoutParams;
|
||||
dataCopymMoutParams.blockCount = groupInner;
|
||||
dataCopymMoutParams.blockLen = s2Inner * sizeof(float);
|
||||
dataCopymMoutParams.srcStride = 0;
|
||||
dataCopymMoutParams.dstStride = mmUbStride;
|
||||
dataCopymMoutParams.rsv = 0;
|
||||
AscendC::DataCopyPad(mmOutUb, mMoutGm[MMout_gmoffset], dataCopymMoutParams, padParams);
|
||||
|
||||
AscendC::DataCopyPadExtParams<T> padTParams{false, 0, 0, 0};
|
||||
AscendC::DataCopyExtParams dataCopyweightParams;
|
||||
dataCopyweightParams.blockCount = 1;
|
||||
dataCopyweightParams.blockLen = groupInner * sizeof(T);
|
||||
dataCopyweightParams.srcStride = 0;
|
||||
dataCopyweightParams.dstStride = 0;
|
||||
dataCopyweightParams.rsv = 0;
|
||||
AscendC::DataCopyPad(weightsUb, weightScaleGm[weights_gmoffset], dataCopyweightParams, padTParams);
|
||||
}
|
||||
|
||||
|
||||
template <typename T>
|
||||
__aicore__ inline void CopyOut(const GlobalTensor<T> &dstGm, const LocalTensor<T> &srcUb, int64_t copyCount)
|
||||
{
|
||||
AscendC::DataCopyParams dataCopyOutyParams;
|
||||
dataCopyOutyParams.blockCount = 1;
|
||||
dataCopyOutyParams.blockLen = copyCount * sizeof(T);
|
||||
dataCopyOutyParams.srcStride = 0;
|
||||
dataCopyOutyParams.dstStride = 0;
|
||||
AscendC::DataCopyPad(dstGm, srcUb, dataCopyOutyParams);
|
||||
}
|
||||
|
||||
|
||||
// Scales each of the groupInner rows of the matmul output by its per-group
// weight, and accumulates the scaled rows into reduceCacheBuf across outer
// group iterations: outerGidx == 0 initializes the accumulator, later
// iterations scale in place and then add into it.
template <typename T>
__aicore__ inline void DoScale(const LocalTensor<float> &reduceCacheBuf, LocalTensor<float> &mmOutUb,
                               LocalTensor<float> &weightsUb, LocalTensor<T> &weightsTUb, LocalTensor<float> &tmpBuff,
                               int64_t groupInner, int64_t s2Inner, int32_t outerGidx)
{
    // cast bfloat16_t to float
    if constexpr (!IsSameType<T, float>::value) {
        AscendC::Cast(weightsUb, weightsTUb, RoundMode::CAST_NONE, groupInner);
        AscendC::PipeBarrier<PIPE_V>();
    }

    // weight broadcast: [groupInner, 1] -> [groupInner, 8]
    AscendC::Brcb(tmpBuff, weightsUb, LICommon::CeilDiv(groupInner, static_cast<int64_t>(B32_BLOCK_ALIGN_NUM)),
                  {1, B32_VEC_REPEAT_STRIDE});
    AscendC::PipeBarrier<PIPE_V>();

    // do scale: [groupInner, 8] * [groupInner, s2Inner]
    // src1 repeat stride is 0, so the broadcast 8-lane weight block is reused
    // for every repeat of the row. Assumes s2Inner is a multiple of
    // countPerRepeat (64 floats) — TODO confirm against tiling.
    uint64_t countPerRepeat = VEC_REPEAT_BYTES / sizeof(float);
    uint64_t repeatTimes = s2Inner / countPerRepeat;
    for (int32_t i = 0; i < groupInner; i++) {
        if (outerGidx == 0) {
            // First outer group: write scaled rows directly into the accumulator.
            AscendC::Mul(reduceCacheBuf[i * s2Inner], mmOutUb[i * s2Inner], tmpBuff[i * B32_BLOCK_ALIGN_NUM],
                         countPerRepeat, repeatTimes, {1, 1, 0, B32_VEC_REPEAT_STRIDE, B32_VEC_REPEAT_STRIDE, 0});
        } else {
            // Later groups: scale in place; the add into the accumulator follows.
            AscendC::Mul(mmOutUb[i * s2Inner], mmOutUb[i * s2Inner], tmpBuff[i * B32_BLOCK_ALIGN_NUM], countPerRepeat,
                         repeatTimes, {1, 1, 0, B32_VEC_REPEAT_STRIDE, B32_VEC_REPEAT_STRIDE, 0});
        }
    }

    if (outerGidx != 0) {
        AscendC::PipeBarrier<PIPE_V>();
        AscendC::Add(reduceCacheBuf, mmOutUb, reduceCacheBuf, groupInner * s2Inner);
    }
    AscendC::PipeBarrier<PIPE_V>();
}
|
||||
|
||||
|
||||
// Returns the largest power of two that is <= value (identity for 0, 1 and 2).
// Used to choose the starting width of the dichotomized reduction in DoReduce.
__aicore__ inline uint64_t FindNearestPower2(uint64_t value)
{
    if (value <= CONST_TWO) {
        return value;
    }
    // clz counts leading zero bits, so (63 - clz) is the index of the highest
    // set bit of the 64-bit value.
    const uint64_t pow = 63 - clz(value);
    // Shift a 64-bit one. The original `1 << pow` shifted a 32-bit int, which
    // is undefined behavior (and wraps in practice) once pow >= 31.
    return static_cast<uint64_t>(1) << pow;
}
|
||||
|
||||
|
||||
// Pairwise (dichotomized) row reduction: sums rNum rows of width aNum from
// srcTensor into a single row written to dstTensor. srcTensor is used as
// scratch and is clobbered in the process.
__aicore__ inline void DoReduce(const LocalTensor<float> &srcTensor, LocalTensor<float> &dstTensor, int32_t rNum,
                                int32_t aNum)
{
    if (rNum == 1) {
        // A single row: the reduction degenerates to a copy (Adds with 0).
        AscendC::Adds<float>(dstTensor, srcTensor, 0, aNum);
        AscendC::PipeBarrier<PIPE_V>();
        return;
    }

    // Fold the rows above the largest power-of-two boundary back onto the
    // lower rows so the remaining row count is an exact power of two.
    uint32_t dichotomizeAddPow = FindNearestPower2(rNum);
    uint32_t dichotomizeAddDiffSize = rNum - dichotomizeAddPow;
    if (dichotomizeAddDiffSize != 0) {
        AscendC::Add(srcTensor, srcTensor, srcTensor[dichotomizeAddPow * aNum], dichotomizeAddDiffSize * aNum);
        AscendC::PipeBarrier<PIPE_V>();
    }
    // Halve the row count each pass by adding the upper half onto the lower half.
    int32_t nowRows = dichotomizeAddPow;
    while (nowRows > CONST_TWO) {
        nowRows = nowRows / CONST_TWO;
        AscendC::Add(srcTensor, srcTensor, srcTensor[nowRows * aNum], nowRows * aNum);
        AscendC::PipeBarrier<PIPE_V>();
    }
    // Final pass: combine the last two rows into the destination row.
    AscendC::Add(dstTensor, srcTensor, srcTensor[aNum], aNum);
    AscendC::PipeBarrier<PIPE_V>();
}
|
||||
|
||||
// Fills an interleaved (value, index) sort buffer with sentinel entries:
// every even b32 lane (value) is set to the -inf bit pattern and every odd
// b32 lane (index) to -1. eleNum is the number of 32-bit elements and is
// assumed to be a multiple of B32_VEC_ELM_NUM — TODO confirm against callers.
__aicore__ inline void InitSortOutBuf(const LocalTensor<float> &src, int64_t eleNum)
{
    // Lane-select masks over one 64-element b32 repeat: 0x5555.. = even lanes
    // (values), 0xaaaa.. = odd lanes (indices).
    uint64_t mask1[2] = {0x5555555555555555, 0};
    uint64_t mask0[2] = {0xaaaaaaaaaaaaaaaa, 0};
    int64_t repeatNum = eleNum / B32_VEC_ELM_NUM;
    int64_t forLoop = repeatNum / VEC_REPEAT_MAX;
    int64_t forRemain = repeatNum % VEC_REPEAT_MAX;
    for (int i = 0; i < forLoop; i++) {
        // Advance by one full 255-repeat chunk per iteration. The original loop
        // omitted this offset, so it rewrote chunk 0 repeatedly and left every
        // later chunk uninitialized whenever repeatNum > VEC_REPEAT_MAX (the
        // remainder branch below shows the intended offset scheme).
        int64_t chunkOffset = static_cast<int64_t>(i) * VEC_REPEAT_MAX * B32_VEC_ELM_NUM;
        AscendC::Duplicate(src.template ReinterpretCast<int32_t>()[chunkOffset], NEG_INF, mask1, VEC_REPEAT_MAX, 1,
                           B32_VEC_REPEAT_STRIDE);
        AscendC::Duplicate(src.template ReinterpretCast<int32_t>()[chunkOffset], INVALID_INDEX, mask0, VEC_REPEAT_MAX,
                           1, B32_VEC_REPEAT_STRIDE);
    }
    if (forRemain > 0) {
        AscendC::Duplicate(src.template ReinterpretCast<int32_t>()[forLoop * VEC_REPEAT_MAX * B32_VEC_ELM_NUM], NEG_INF,
                           mask1, forRemain, 1, B32_VEC_REPEAT_STRIDE);
        AscendC::Duplicate(src.template ReinterpretCast<int32_t>()[forLoop * VEC_REPEAT_MAX * B32_VEC_ELM_NUM],
                           INVALID_INDEX, mask0, forRemain, 1, B32_VEC_REPEAT_STRIDE);
    }
    AscendC::PipeBarrier<PIPE_V>();
}
|
||||
|
||||
// Fully sorts an interleaved (value, index) buffer of logitsNum pairs held in
// `src`, using `tmp` as ping-pong scratch. Sort32 first produces sorted runs
// of 32 pairs into tmp; MrgSort then merges four runs at a time, alternating
// tmp <-> src, until a single run remains. The result always ends up in `src`.
__aicore__ inline void SortAll(LocalTensor<float> &src, LocalTensor<float> &tmp, int64_t logitsNum)
{
    int64_t sort32Repeats = logitsNum / BLOCK_BYTES;
    // Values are read from src and their indices from src[logitsNum]; sorted
    // 32-pair runs land in tmp.
    AscendC::Sort32(tmp, src, src[logitsNum].ReinterpretCast<uint32_t>(), sort32Repeats);
    AscendC::PipeBarrier<PIPE_V>();

    int64_t mrgGroups = sort32Repeats;   // number of sorted runs remaining
    int64_t mrgElements = BLOCK_BYTES;   // pairs per run at the current level
    int64_t i = 0;                       // pass counter; parity selects buffers
    AscendC::LocalTensor<float> srcTensor;
    AscendC::LocalTensor<float> dstTensor;
    while (true) {
        // Even passes read tmp and write src; odd passes the reverse.
        if (i % CONST_TWO == 0) {
            srcTensor = tmp;
            dstTensor = src;
        } else {
            srcTensor = src;
            dstTensor = tmp;
        }
        AscendC::MrgSort4Info params;
        params.elementLengths[0] = mrgElements;
        params.elementLengths[MRG_QUE_1] = mrgElements;
        params.elementLengths[MRG_QUE_2] = mrgElements;
        params.elementLengths[MRG_QUE_3] = mrgElements;
        params.ifExhaustedSuspension = false;
        params.validBit = 0b1111;

        // Four consecutive runs feed one merge; each pair occupies
        // VALUE_AND_INDEX_NUM floats.
        AscendC::MrgSortSrcList<float> srcList;
        srcList.src1 = srcTensor[0];
        srcList.src2 = srcTensor[MRG_QUE_1 * VALUE_AND_INDEX_NUM * mrgElements];
        srcList.src3 = srcTensor[MRG_QUE_2 * VALUE_AND_INDEX_NUM * mrgElements];
        srcList.src4 = srcTensor[MRG_QUE_3 * VALUE_AND_INDEX_NUM * mrgElements];
        if (mrgGroups <= MRG_BLOCK_4) {
            // Final level: at most four runs left, merged by a single MrgSort.
            params.repeatTimes = 1;
            if (mrgGroups == 1) {
                break;  // already one sorted run — nothing left to merge
            } else if (mrgGroups == MRG_BLOCK_2) {
                params.validBit = 0b0011;
            } else if (mrgGroups == MRG_BLOCK_3) {
                params.validBit = 0b0111;
            } else if (mrgGroups == MRG_BLOCK_4) {
                params.validBit = 0b1111;
            }
            AscendC::MrgSort<float>(dstTensor, srcList, params);
            i += 1;
            break;
        } else {
            // Bulk level: merge runs four at a time across the whole buffer.
            // NOTE(review): assumes mrgGroups divides evenly by four at every
            // bulk level — confirm buffer sizing guarantees this.
            params.repeatTimes = mrgGroups / MRG_BLOCK_4;
            AscendC::MrgSort<float>(dstTensor, srcList, params);
            i += 1;
            mrgElements = mrgElements * MRG_BLOCK_4;
            mrgGroups = mrgGroups / MRG_BLOCK_4;
        }
        AscendC::PipeBarrier<PIPE_V>();
    }
    // An even final pass count means the sorted data resides in tmp (including
    // the zero-pass case straight after Sort32); move it back into src.
    if (i % CONST_TWO == 0) {
        AscendC::DataCopy(src, tmp, logitsNum * VALUE_AND_INDEX_NUM);
        AscendC::PipeBarrier<PIPE_V>();
    }
}
|
||||
|
||||
// Sorts logitsNum (value, index) pairs taken from the separate srcValue /
// srcIndex tensors into the interleaved layout in dst, using tmpTensor as
// scratch for the hardware sort.
__aicore__ inline void SortAll(LocalTensor<float> &dst, LocalTensor<float> &srcValue, LocalTensor<uint32_t> &srcIndex,
                               LocalTensor<float> &tmpTensor, int64_t logitsNum)
{
    const int64_t repeats = logitsNum / BLOCK_BYTES;  // 32 pairs per repeat
    AscendC::Sort<float, true>(dst, srcValue, srcIndex, tmpTensor, repeats);
    AscendC::PipeBarrier<PIPE_V>();
}
|
||||
|
||||
// Two-way merge of the sorted queues mrgDst (mrgDstNum pairs) and mrgSrc
// (mrgSrcNum pairs) into tmpTensor, then copies the merged head — the first
// mrgDstNum pairs — back into mrgDst.
__aicore__ inline void MergeSort(const LocalTensor<float> &mrgDst, int32_t mrgDstNum, LocalTensor<float> &mrgSrc,
                                 int32_t mrgSrcNum, LocalTensor<float> &tmpTensor)
{
    AscendC::MrgSort4Info mergeInfo;
    mergeInfo.elementLengths[0] = mrgDstNum;
    mergeInfo.elementLengths[1] = mrgSrcNum;
    mergeInfo.ifExhaustedSuspension = false;
    mergeInfo.validBit = 0b0011;  // only the first two queues participate
    mergeInfo.repeatTimes = 1;

    AscendC::MrgSortSrcList<float> queues;
    queues.src1 = mrgDst;
    queues.src2 = mrgSrc;

    AscendC::MrgSort<float>(tmpTensor, queues, mergeInfo);
    AscendC::PipeBarrier<PIPE_V>();
    AscendC::DataCopy(mrgDst, tmpTensor, mrgDstNum * VALUE_AND_INDEX_NUM);
    AscendC::PipeBarrier<PIPE_V>();
}
|
||||
|
||||
// Merges up to four equally sized sorted basic blocks from src into dst.
// For a single block (or any unexpected count) the data is copied through
// unchanged.
__aicore__ inline void MrgBasicBlock(const LocalTensor<float> &dst, const LocalTensor<float> &src, int64_t blockNum,
                                     int64_t basicBlockSize)
{
    AscendC::MrgSort4Info mergeInfo;
    mergeInfo.elementLengths[MRG_QUE_0] = basicBlockSize;
    mergeInfo.elementLengths[MRG_QUE_1] = basicBlockSize;
    mergeInfo.elementLengths[MRG_QUE_2] = basicBlockSize;
    mergeInfo.elementLengths[MRG_QUE_3] = basicBlockSize;
    mergeInfo.ifExhaustedSuspension = false;
    switch (blockNum) {
        case MRG_BLOCK_2:
            mergeInfo.validBit = 0b0011;
            break;
        case MRG_BLOCK_3:
            mergeInfo.validBit = 0b0111;
            break;
        case MRG_BLOCK_4:
            mergeInfo.validBit = 0b1111;
            break;
        default:
            // Nothing to merge — pass the single block through.
            AscendC::DataCopy(dst, src, basicBlockSize * VALUE_AND_INDEX_NUM);
            return;
    }
    AscendC::MrgSortSrcList<float> queues;
    queues.src1 = src[0];
    queues.src2 = src[basicBlockSize * VALUE_AND_INDEX_NUM * MRG_QUE_1];
    queues.src3 = src[basicBlockSize * VALUE_AND_INDEX_NUM * MRG_QUE_2];
    queues.src4 = src[basicBlockSize * VALUE_AND_INDEX_NUM * MRG_QUE_3];
    AscendC::MrgSort<float>(dst, queues, mergeInfo);
}
|
||||
|
||||
template <bool needMrg = true>
|
||||
__aicore__ inline void SparseTopK(const LocalTensor<float> &dst, const LocalTensor<float> &needsMerging,
|
||||
const LocalTensor<float> &tmp, int64_t topk, int64_t mergSize)
|
||||
{
|
||||
if (!needMrg) {
|
||||
AscendC::DataCopy(dst, needsMerging, mergSize * VALUE_AND_INDEX_NUM);
|
||||
return;
|
||||
}
|
||||
AscendC::MrgSort4Info params;
|
||||
params.elementLengths[0] = topk;
|
||||
params.elementLengths[1] = mergSize;
|
||||
params.ifExhaustedSuspension = (topk == mergSize);
|
||||
params.validBit = 0b0011;
|
||||
AscendC::MrgSortSrcList<float> srcList;
|
||||
srcList.src1 = dst;
|
||||
srcList.src2 = needsMerging;
|
||||
AscendC::MrgSort<float>(tmp, srcList, params);
|
||||
AscendC::DataCopy(dst, tmp, topk * VALUE_AND_INDEX_NUM);
|
||||
}
|
||||
|
||||
|
||||
// Gathers the index halves out of an interleaved (value, index) sorted buffer
// into a compact index array. Built-in GatherMask pattern 2 selects the odd
// b32 lanes of each repeat, which hold the indices in the sort-pair layout.
__aicore__ inline void ExtractIndex(const LocalTensor<uint32_t> &idxULocal, const LocalTensor<uint32_t> &sortLocal,
                                    int64_t extractNum)
{
    AscendC::GatherMaskParams gatherMaskParams;
    // Each repeat covers 256 bytes of the interleaved source, so the repeat
    // count spans extractNum pairs (value + index per pair).
    gatherMaskParams.repeatTimes = Ceil(extractNum * sizeof(float) * VALUE_AND_INDEX_NUM, VEC_REPEAT_BYTES);
    gatherMaskParams.src0BlockStride = 1;
    gatherMaskParams.src0RepeatStride = B32_VEC_REPEAT_STRIDE;
    gatherMaskParams.src1RepeatStride = 0;
    uint64_t rsvdCnt = 0;  // receives the number of gathered elements (unused)
    uint8_t src1Pattern = 2;  // built-in pattern: odd b32 elements
    AscendC::GatherMask(idxULocal, sortLocal, src1Pattern, false, static_cast<uint32_t>(0), gatherMaskParams, rsvdCnt);
    AscendC::PipeBarrier<PIPE_V>();
}
|
||||
|
||||
|
||||
// Inserts a synchronization point for the given hard-event type: sets an event
// flag and immediately waits on it, serializing the two pipelines the event
// connects.
// NOTE(review): the event id is fetched from the runtime argument `evt` while
// Set/WaitFlag are parameterized by the template argument `event`; all visible
// call sites pass the same value for both — confirm they must always match.
template <HardEvent event>
__aicore__ inline void SetWaitFlag(HardEvent evt)
{
    event_t eventId = static_cast<event_t>(GetTPipePtr()->FetchEventID(evt));
    AscendC::SetFlag<event>(eventId);
    AscendC::WaitFlag<event>(eventId);
}
|
||||
|
||||
} // namespace LIServiceVec
|
||||
#endif // LIGHTNING_INDEXER_VECTOR_H
|
||||
39
csrc/sparse_flash_attention/op_host/CMakeLists.txt
Normal file
39
csrc/sparse_flash_attention/op_host/CMakeLists.txt
Normal file
@@ -0,0 +1,39 @@
|
||||
# This program is free software, you can redistribute it and/or modify it.
|
||||
# Copyright (c) 2025 Huawei Technologies Co., Ltd.
|
||||
# This file is a part of the CANN Open Software.
|
||||
# Licensed under CANN Open Software License Agreement Version 2.0 (the "License").
|
||||
# Please refer to the License for details. You may not use this file except in compliance with the License.
|
||||
# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE.
|
||||
# See LICENSE in the root of the software repository for the full text of the License.
|
||||
# ======================================================================================================================
|
||||
|
||||
add_ops_compile_options(
|
||||
OP_NAME SparseFlashAttention
|
||||
OPTIONS --cce-auto-sync=off
|
||||
-Wno-deprecated-declarations
|
||||
-Werror
|
||||
-fpermissive
|
||||
)
|
||||
|
||||
set(sparse_flash_attention_depends transformer/attention/sparse_flash_attention PARENT_SCOPE)
|
||||
target_sources(op_host_aclnn PRIVATE
|
||||
sparse_flash_attention_def.cpp
|
||||
)
|
||||
|
||||
target_sources(optiling PRIVATE
|
||||
sparse_flash_attention_tiling.cpp
|
||||
)
|
||||
|
||||
if (NOT BUILD_OPEN_PROJECT)
|
||||
target_sources(opmaster_ct PRIVATE
|
||||
sparse_flash_attention_tiling.cpp
|
||||
)
|
||||
endif ()
|
||||
|
||||
target_sources(opsproto PRIVATE
|
||||
sparse_flash_attention_proto.cpp
|
||||
)
|
||||
|
||||
target_include_directories(optiling PRIVATE
|
||||
${CMAKE_CURRENT_SOURCE_DIR}
|
||||
)
|
||||
@@ -0,0 +1,90 @@
|
||||
/**
|
||||
* This program is free software, you can redistribute it and/or modify it.
|
||||
* Copyright (c) 2025 Huawei Technologies Co., Ltd.
|
||||
* This file is a part of the CANN Open Software.
|
||||
* Licensed under CANN Open Software License Agreement Version 2.0 (the "License").
|
||||
* Please refer to the License for details. You may not use this file except in compliance with the License.
|
||||
* THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE.
|
||||
* See LICENSE in the root of the software repository for the full text of the License.
|
||||
*/
|
||||
|
||||
/*!
|
||||
* \file sparse_flash_attention_def.cpp
|
||||
* \brief
|
||||
*/
|
||||
|
||||
#include "register/op_def_registry.h"
|
||||
|
||||
namespace ops {
|
||||
// Operator definition/registration for SparseFlashAttention: declares the
// input/output tensors with their supported dtype/format combinations, the
// operator attributes, and the AI Core configurations for the ascend910b and
// ascend910_93 SoCs.
class SparseFlashAttention : public OpDef {
public:
    explicit SparseFlashAttention(const char *name) : OpDef(name)
    {
        // Required Q/K/V inputs: fp16 or bf16, ND format.
        this->Input("query")
            .ParamType(REQUIRED)
            .DataType({ge::DT_FLOAT16, ge::DT_BF16})
            .Format({ge::FORMAT_ND, ge::FORMAT_ND})
            .AutoContiguous();
        this->Input("key")
            .ParamType(REQUIRED)
            .DataType({ge::DT_FLOAT16, ge::DT_BF16})
            .Format({ge::FORMAT_ND, ge::FORMAT_ND})
            .AutoContiguous();
        this->Input("value")
            .ParamType(REQUIRED)
            .DataType({ge::DT_FLOAT16, ge::DT_BF16})
            .Format({ge::FORMAT_ND, ge::FORMAT_ND})
            .AutoContiguous();
        // Selected KV positions per query (int32).
        this->Input("sparse_indices")
            .ParamType(REQUIRED)
            .DataType({ge::DT_INT32, ge::DT_INT32})
            .Format({ge::FORMAT_ND, ge::FORMAT_ND})
            .AutoContiguous();
        // Page-attention block table (optional, int32).
        this->Input("block_table")
            .ParamType(OPTIONAL)
            .DataType({ge::DT_INT32, ge::DT_INT32})
            .Format({ge::FORMAT_ND, ge::FORMAT_ND})
            .AutoContiguous();
        // Actual (unpadded) sequence lengths for query and KV (optional).
        this->Input("actual_seq_lengths_query")
            .ParamType(OPTIONAL)
            .DataType({ge::DT_INT32, ge::DT_INT32})
            .Format({ge::FORMAT_ND, ge::FORMAT_ND})
            .AutoContiguous();
        this->Input("actual_seq_lengths_kv")
            .ParamType(OPTIONAL)
            .DataType({ge::DT_INT32, ge::DT_INT32})
            .Format({ge::FORMAT_ND, ge::FORMAT_ND})
            .AutoContiguous();
        // Optional rotary-position-embedding parts of query/key.
        this->Input("query_rope")
            .ParamType(OPTIONAL)
            .DataType({ge::DT_FLOAT16, ge::DT_BF16})
            .Format({ge::FORMAT_ND, ge::FORMAT_ND})
            .AutoContiguous();
        this->Input("key_rope")
            .ParamType(OPTIONAL)
            .DataType({ge::DT_FLOAT16, ge::DT_BF16})
            .Format({ge::FORMAT_ND, ge::FORMAT_ND})
            .AutoContiguous();
        this->Output("attention_out")
            .ParamType(REQUIRED)
            .DataType({ge::DT_FLOAT16, ge::DT_BF16})
            .Format({ge::FORMAT_ND, ge::FORMAT_ND});
        // Attributes with their defaults as registered here.
        this->Attr("scale_value").AttrType(REQUIRED).Float(1.0);
        this->Attr("sparse_block_size").AttrType(REQUIRED).Int(1);
        this->Attr("layout_query").AttrType(OPTIONAL).String("BSND");
        this->Attr("layout_kv").AttrType(OPTIONAL).String("BSND");
        this->Attr("sparse_mode").AttrType(OPTIONAL).Int(3);
        // Dynamic-shape aclnn-capable AI Core configuration, shared by both SoCs.
        OpAICoreConfig aicore_config;
        aicore_config.DynamicCompileStaticFlag(true)
            .DynamicFormatFlag(true)
            .DynamicRankSupportFlag(true)
            .DynamicShapeSupportFlag(true)
            .NeedCheckSupportFlag(false)
            .PrecisionReduceFlag(true)
            .ExtendCfgInfo("aclnnSupport.value", "support_aclnn");
        this->AICore().AddConfig("ascend910b", aicore_config);
        this->AICore().AddConfig("ascend910_93", aicore_config);
    }
};
|
||||
OP_ADD(SparseFlashAttention);
|
||||
} // namespace ops
|
||||
@@ -0,0 +1,48 @@
|
||||
/**
|
||||
* This program is free software, you can redistribute it and/or modify it.
|
||||
* Copyright (c) 2025 Huawei Technologies Co., Ltd.
|
||||
* This file is a part of the CANN Open Software.
|
||||
* Licensed under CANN Open Software License Agreement Version 2.0 (the "License").
|
||||
* Please refer to the License for details. You may not use this file except in compliance with the License.
|
||||
* THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE.
|
||||
* See LICENSE in the root of the software repository for the full text of the License.
|
||||
*/
|
||||
|
||||
/*!
|
||||
* \file sparse_flash_attention_proto.cpp
|
||||
* \brief
|
||||
*/
|
||||
|
||||
#include <graph/utils/type_utils.h>
|
||||
#include <register/op_impl_registry.h>
|
||||
#include "error/ops_error.h"
|
||||
|
||||
using namespace ge;
|
||||
|
||||
namespace ops {
|
||||
constexpr size_t QUERY_INPUT_INDEX = 0;
|
||||
|
||||
// Shape inference: the attention output has exactly the same shape as the
// query input.
ge::graphStatus InferShapeSparseFlashAttention(gert::InferShapeContext *context)
{
    OPS_ERR_IF(context == nullptr, OPS_LOG_E("SparseFlashAttention", "InferShapeContext is nullptr"),
               return ge::GRAPH_FAILED);
    const gert::Shape *queryShape = context->GetInputShape(QUERY_INPUT_INDEX);
    OPS_LOG_E_IF_NULL(context, queryShape, return ge::GRAPH_FAILED)
    gert::Shape *attentionOutShape = context->GetOutputShape(0);
    OPS_LOG_E_IF_NULL(context, attentionOutShape, return ge::GRAPH_FAILED)
    *attentionOutShape = *queryShape;
    return GRAPH_SUCCESS;
}
|
||||
|
||||
// Data-type inference: the attention output dtype mirrors the query input
// dtype.
ge::graphStatus InferDataTypeSparseFlashAttention(gert::InferDataTypeContext *context)
{
    // Fixed error text: this function receives an InferDataTypeContext, but
    // the original message reported "InferShapeContext is nullptr".
    OPS_ERR_IF(context == nullptr, OPS_LOG_E("SparseFlashAttention", "InferDataTypeContext is nullptr"),
               return ge::GRAPH_FAILED);
    const auto inputDataType = context->GetInputDataType(QUERY_INPUT_INDEX);
    context->SetOutputDataType(0, inputDataType);
    return ge::GRAPH_SUCCESS;
}
|
||||
|
||||
IMPL_OP(SparseFlashAttention).InferShape(InferShapeSparseFlashAttention).InferDataType(InferDataTypeSparseFlashAttention);
|
||||
} // namespace ops
|
||||
|
||||
File diff suppressed because it is too large
Load Diff
@@ -0,0 +1,583 @@
|
||||
/**
|
||||
* This program is free software, you can redistribute it and/or modify it.
|
||||
* Copyright (c) 2025 Huawei Technologies Co., Ltd.
|
||||
* This file is a part of the CANN Open Software.
|
||||
* Licensed under CANN Open Software License Agreement Version 2.0 (the "License").
|
||||
* Please refer to the License for details. You may not use this file except in compliance with the License.
|
||||
* THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE.
|
||||
* See LICENSE in the root of the software repository for the full text of the License.
|
||||
*/
|
||||
|
||||
/*!
|
||||
* \file sparse_flash_attention_tiling.h
|
||||
* \brief
|
||||
*/
|
||||
#ifndef SPARSE_FLASH_ATTENTION_TILING_H
|
||||
#define SPARSE_FLASH_ATTENTION_TILING_H
|
||||
|
||||
#include <sstream>
|
||||
#include <graph/utils/type_utils.h>
|
||||
#include <exe_graph/runtime/tiling_context.h>
|
||||
#include <tiling/platform/platform_ascendc.h>
|
||||
#include "register/tilingdata_base.h"
|
||||
#include "exe_graph/runtime/tiling_context.h"
|
||||
|
||||
namespace optiling {
|
||||
// Inputs Index
|
||||
constexpr uint32_t QUERY_INPUT_INDEX = 0;
|
||||
constexpr uint32_t KEY_INPUT_INDEX = 1;
|
||||
constexpr uint32_t VALUE_INPUT_INDEX = 2;
|
||||
constexpr uint32_t SPARSE_INDICES_INPUT_INDEX = 3;
|
||||
constexpr uint32_t BLOCK_TABLE_INPUT_INDEX = 4;
|
||||
constexpr uint32_t ACT_SEQ_LEN_Q_INPUT_INDEX = 5;
|
||||
constexpr uint32_t ACT_SEQ_LEN_KV_INPUT_INDEX = 6;
|
||||
constexpr uint32_t QUERY_ROPE_INPUT_INDEX = 7;
|
||||
constexpr uint32_t KEY_ROPE_INPUT_INDEX = 8;
|
||||
// Outputs Index
|
||||
constexpr uint32_t OUTPUT_INDEX = 0;
|
||||
// Attributes Index
|
||||
constexpr uint32_t SCALE_VALUE_ATTR_INDEX = 0;
|
||||
constexpr uint32_t SPARSE_BLOCK_SIZE_ATTR_INDEX = 1;
|
||||
constexpr uint32_t LAYOUT_QUERY_ATTR_INDEX = 2;
|
||||
constexpr uint32_t LAYOUT_KV_ATTR_INDEX = 3;
|
||||
constexpr uint32_t SPARSE_MODE_ATTR_INDEX = 4;
|
||||
// Dim Num
|
||||
constexpr size_t DIM_NUM_TWO = 2;
|
||||
constexpr size_t DIM_NUM_THREE = 3;
|
||||
constexpr size_t DIM_NUM_FOUR = 4;
|
||||
// Constant
|
||||
constexpr uint32_t MAX_BLOCK_SIZE = 1024;
|
||||
constexpr uint32_t COPYND2NZ_SRC_STRIDE_LIMITATION = 65535;
|
||||
constexpr uint32_t NUM_BYTES_FLOAT = 4;
|
||||
constexpr uint32_t NUM_BYTES_FLOAT16 = 2;
|
||||
constexpr uint32_t NUM_BYTES_BF16 = 2;
|
||||
constexpr uint32_t BYTE_BLOCK = 32;
|
||||
const uint32_t SFA_MAX_AIC_CORE_NUM = 26;
|
||||
|
||||
enum class SFALayout : uint32_t {
|
||||
BSND = 0,
|
||||
TND = 1,
|
||||
PA_BSND = 2
|
||||
};
|
||||
|
||||
struct SFATilingShapeCompareParam {
|
||||
int64_t B = 1;
|
||||
int64_t S = 1;
|
||||
int64_t N = 1;
|
||||
int64_t D = 1;
|
||||
int64_t T = 1;
|
||||
// PA
|
||||
int64_t Bs = 1;
|
||||
int64_t Bn = 1;
|
||||
};
|
||||
|
||||
enum class KvStorageMode : uint32_t {
|
||||
BATCH_CONTINUOUS = 0,
|
||||
PAGE_ATTENTION = 1
|
||||
};
|
||||
|
||||
enum class SFAPerfMode : uint32_t {
|
||||
C_TEMPLATE_MODE = 0,
|
||||
V_TEMPLATE_MODE
|
||||
};
|
||||
|
||||
enum class SFAAxis : uint32_t {
|
||||
B = 0,
|
||||
S = 1,
|
||||
N = 2,
|
||||
D = 3,
|
||||
K = 3,
|
||||
T = 5,
|
||||
Bn = 6, // block number
|
||||
Bs = 7, // block size
|
||||
};
|
||||
|
||||
struct SFARequiredParaInfo {
|
||||
const gert::CompileTimeTensorDesc *desc;
|
||||
const gert::StorageShape *shape;
|
||||
};
|
||||
|
||||
struct SFAOptionalParaInfo {
|
||||
const gert::CompileTimeTensorDesc *desc;
|
||||
const gert::Tensor *tensor;
|
||||
};
|
||||
|
||||
struct SFAParaInfo {
|
||||
SFARequiredParaInfo query = {nullptr, nullptr};
|
||||
SFARequiredParaInfo key = {nullptr, nullptr};
|
||||
SFARequiredParaInfo value = {nullptr, nullptr};
|
||||
SFARequiredParaInfo sparseIndices = {nullptr, nullptr};
|
||||
SFAOptionalParaInfo blockTable = {nullptr, nullptr};
|
||||
SFAOptionalParaInfo actualSeqLengthsQ = {nullptr, nullptr};
|
||||
SFAOptionalParaInfo actualSeqLengths = {nullptr, nullptr};
|
||||
SFAOptionalParaInfo queryRope = {nullptr, nullptr};
|
||||
SFAOptionalParaInfo keyRope = {nullptr, nullptr};
|
||||
SFARequiredParaInfo attenOut = {nullptr, nullptr};
|
||||
|
||||
const char *layoutQuery = nullptr;
|
||||
const char *layoutKV = nullptr;
|
||||
const int64_t *sparseBlockSize = nullptr;
|
||||
const float *scaleValue = nullptr;
|
||||
const int64_t *sparseMode = nullptr;
|
||||
};
|
||||
|
||||
struct InnerSplitParams {
|
||||
uint32_t s1GBaseSize = 1;
|
||||
uint32_t s2BaseSize = 1;
|
||||
};
|
||||
|
||||
BEGIN_TILING_DATA_DEF(SparseFlashAttentionBaseParamsMla)
|
||||
TILING_DATA_FIELD_DEF(uint32_t, batchSize)
|
||||
TILING_DATA_FIELD_DEF(uint32_t, seqSize)
|
||||
TILING_DATA_FIELD_DEF(uint32_t, qSeqSize)
|
||||
TILING_DATA_FIELD_DEF(int64_t, blockSize)
|
||||
TILING_DATA_FIELD_DEF(uint32_t, maxBlockNumPerBatch)
|
||||
TILING_DATA_FIELD_DEF(float, scaleValue)
|
||||
TILING_DATA_FIELD_DEF(uint32_t, nNumOfQInOneGroup)
|
||||
TILING_DATA_FIELD_DEF(uint32_t, actualLenDimsQ)
|
||||
TILING_DATA_FIELD_DEF(uint32_t, actualLenDimsKV)
|
||||
TILING_DATA_FIELD_DEF(uint32_t, outputLayout)
|
||||
TILING_DATA_FIELD_DEF(uint32_t, sparseMode)
|
||||
TILING_DATA_FIELD_DEF(int64_t, sparseBlockSize)
|
||||
TILING_DATA_FIELD_DEF(uint32_t, sparseBlockCount)
|
||||
END_TILING_DATA_DEF
|
||||
REGISTER_TILING_DATA_CLASS(SparseFlashAttentionBaseParamsMlaOp, SparseFlashAttentionBaseParamsMla)
|
||||
|
||||
BEGIN_TILING_DATA_DEF(SparseFlashAttentionSingleCoreParamsMla)
|
||||
TILING_DATA_FIELD_DEF(uint32_t, usedCoreNum);
|
||||
END_TILING_DATA_DEF
|
||||
REGISTER_TILING_DATA_CLASS(SparseFlashAttentionSingleCoreParamsMlaOp, SparseFlashAttentionSingleCoreParamsMla)
|
||||
|
||||
BEGIN_TILING_DATA_DEF(SparseFlashAttentionSingleCoreTensorSizeMla)
|
||||
TILING_DATA_FIELD_DEF(uint32_t, mmResUbSize);
|
||||
TILING_DATA_FIELD_DEF(uint32_t, bmm2ResUbSize);
|
||||
END_TILING_DATA_DEF
|
||||
REGISTER_TILING_DATA_CLASS(SparseFlashAttentionSingleCoreTensorSizeMlaOp, SparseFlashAttentionSingleCoreTensorSizeMla)
|
||||
|
||||
BEGIN_TILING_DATA_DEF(SparseFlashAttentionSplitKVParamsMla)
|
||||
TILING_DATA_FIELD_DEF(uint32_t, s2)
|
||||
TILING_DATA_FIELD_DEF(uint32_t, accumOutSize) // FD workspace
|
||||
TILING_DATA_FIELD_DEF(uint32_t, logSumExpSize) // FD workspace
|
||||
END_TILING_DATA_DEF
|
||||
REGISTER_TILING_DATA_CLASS(SparseFlashAttentionSplitKVParamsMlaOp, SparseFlashAttentionSplitKVParamsMla)
|
||||
|
||||
BEGIN_TILING_DATA_DEF(SparseFlashAttentionInnerSplitParams)
|
||||
TILING_DATA_FIELD_DEF(uint32_t, mBaseSize)
|
||||
TILING_DATA_FIELD_DEF(uint32_t, s2BaseSize)
|
||||
END_TILING_DATA_DEF
|
||||
REGISTER_TILING_DATA_CLASS(SparseFlashAttentionInnerSplitParamsOp, SparseFlashAttentionInnerSplitParams)
|
||||
|
||||
BEGIN_TILING_DATA_DEF(SparseFlashAttentionTilingDataMla)
|
||||
TILING_DATA_FIELD_DEF_STRUCT(SparseFlashAttentionBaseParamsMla, baseParams);
|
||||
TILING_DATA_FIELD_DEF_STRUCT(SparseFlashAttentionSplitKVParamsMla, splitKVParams);
|
||||
TILING_DATA_FIELD_DEF_STRUCT(SparseFlashAttentionSingleCoreParamsMla, singleCoreParams);
|
||||
TILING_DATA_FIELD_DEF_STRUCT(SparseFlashAttentionSingleCoreTensorSizeMla, singleCoreTensorSize);
|
||||
TILING_DATA_FIELD_DEF_STRUCT(SparseFlashAttentionInnerSplitParams, innerSplitParams);
|
||||
END_TILING_DATA_DEF
|
||||
REGISTER_TILING_DATA_CLASS(SparseFlashAttention, SparseFlashAttentionTilingDataMla)
|
||||
|
||||
/**
 * Round `num` up to the nearest multiple of `rnd`.
 * Returns 0 when `rnd` is 0 so callers never divide by zero.
 */
template <typename T> inline T Align(T num, T rnd)
{
    if (rnd == 0) {
        return 0;
    }
    return (num + rnd - 1) / rnd * rnd;
}
|
||||
|
||||
/**
 * Render a shape-like object as "[d0, d1, ..., dn]" for log messages.
 * T must provide GetDimNum() and GetDim(i); an empty shape yields "[]".
 */
template <typename T>
std::string SFAShape2String(const T &shape)
{
    std::ostringstream out;
    out << "[";
    const size_t dimNum = shape.GetDimNum();
    for (size_t idx = 0; idx < dimNum; ++idx) {
        if (idx != 0) {
            out << ", ";
        }
        out << shape.GetDim(idx);
    }
    out << "]";
    return out.str();
}
|
||||
|
||||
static std::string GetShapeStr(gert::Shape shape);
|
||||
static std::string SFADataTypeToSerialString(ge::DataType type);
|
||||
std::string SFATensorDesc2String(const gert::StorageShape *shape, const gert::CompileTimeTensorDesc *tensor);
|
||||
std::string SFADebugTilingContext(const gert::TilingContext *context);
|
||||
std::string SFALayoutToSerialString(SFALayout layout);
|
||||
|
||||
// Parsed, platform-annotated description of one SparseFlashAttention call.
// Filled by SFAInfoParser, then read by SFATilingCheck and SFAMlaTiling.
struct SFATilingInfo {
    const char *opName = nullptr;            // op instance name, used in log messages
    fe::PlatFormInfos *platformInfo = nullptr;
    SFAParaInfo opParamInfo;                 // raw shape/desc/attr pointers from the context

    // Base Param
    platform_ascendc::SocVersion socVersion = platform_ascendc::SocVersion::ASCEND910B;
    uint32_t bSize = 0;          // batch size (B)
    uint32_t n1Size = 0;         // query head count (N1)
    uint32_t n2Size = 0;         // kv head count (N2)
    uint32_t s1Size = 0;         // query sequence length (S1)
    int64_t s2Size = 0;          // kv sequence length (S2)
    uint32_t qkHeadDim = 0;      // per-head dim of Q/K
    uint32_t vHeadDim = 0;       // per-head dim of V
    uint32_t gSize = 0;          // query heads per kv head (group size)
    uint32_t ropeHeadDim = 0;    // per-head dim of the rope part
    uint32_t qTSize = 0;         // TND: total query tokens — confirm against GetQTSize
    uint32_t kvTSize = 0;        // TND: total kv tokens — confirm against GetKVTSize
    float scaleValue = 0;        // softmax scale attribute
    uint32_t innerPrecise = 0;
    uint32_t l2CacheOffFlag = 0;
    int64_t sparseBlockSize = 0;
    int64_t sparseBlockCount = 0;

    // Page-attention (block table) storage
    bool pageAttentionFlag = false;
    int64_t blockSize = 0;               // kv cache block size
    uint32_t blockTypeSize = 0;
    uint32_t maxBlockNumPerBatch = 0;
    uint32_t totalBlockNum = 0;

    uint32_t actualLenDimsQ = 0;         // length of actual_seq_lengths_query
    uint32_t maxActualseq = 0;           // max actual kv sequence length

    // Actual-sequence-length bookkeeping
    bool actualSeqLenFlag = false;
    bool isSameSeqAllKVTensor = true;    // all kv tensors share one seq length
    bool isSameActualseq = true;         // all batches share one actual seq length
    uint32_t actualLenDimsKV = 0;
    std::vector<int64_t> kvListSeqLens {};

    uint32_t sparseMode = 0;

    ge::DataType inputQType = ge::DT_FLOAT16;
    ge::DataType inputKvType = ge::DT_FLOAT16;
    ge::DataType outputType = ge::DT_FLOAT16;

    KvStorageMode kvStorageMode = KvStorageMode::BATCH_CONTINUOUS;

    // Layouts of the main tensors
    SFALayout qLayout = SFALayout::BSND;
    SFALayout topkLayout = SFALayout::BSND;
    SFALayout outLayout = SFALayout::BSND;
    SFALayout kvLayout = SFALayout::BSND;

    ge::DataType inputQRopeType = ge::DT_FLOAT16;
    ge::DataType inputKRopeType = ge::DT_FLOAT16;

    uint64_t l2CacheSize = 0;    // bytes of L2 cache on this SoC
};
|
||||
|
||||
// Computes the MLA sparse-flash-attention tiling (core split, flash-decode
// split, UB tensor sizes, workspace size, tiling key) from a validated
// SFATilingInfo and commits the results to the TilingContext.
class SFAMlaTiling {
public:
    explicit SFAMlaTiling(gert::TilingContext *context) : context_(context) {}
    // Entry point: fills tilingData_/blockDim_/tilingKey_/workspaceSize_ and
    // writes them to context_.
    ge::graphStatus DoOpTiling(SFATilingInfo *sfaInfo);

private:
    // Committers into context_
    ge::graphStatus SetBlockDim(uint32_t blockDim);
    ge::graphStatus SetTilingKey(uint64_t tilingKey);
    ge::graphStatus SetWorkspaceSize(uint64_t workspaceSize);
    ge::graphStatus SetTilingData(TilingDef &tilingData);
    gert::TilingContext *context_ = nullptr;
    ge::graphStatus GetPlatformInfo();
    void GenTilingKey();
    bool DealSameSeqEachBatch();

    void ZeroTensorProcess();    // handles zero-length (empty) sequence cases
    void InitParams();

    // Core-split strategies
    void Split();
    bool IsBalanceSplitCore();

    void SplitBalanced();
    void CalcInnerSize(uint32_t s2Size);

    bool IsFlashDecode(uint32_t coreNum);

    // tilingData_ fillers, one per sub-struct
    void FillTilingBaseParamsMla();
    void FillTilingSplitKVMla();

    void FillTilingSingleCoreParamsMla();
    void FillTilingSingleCoreTensorSizeMla();
    void FillTiling();

    // UB / workspace sizing
    void CalcUbBmm();
    void CheckUbSpace();
    void NormalCalcFDWorkSpace(const uint32_t actCoreNum);
    void CalcFDWorkSpace(const uint32_t actCoreNum);
    void GetWorkspaceSize();

    uint32_t CalcBalanceFDParamNums(const uint32_t actCoreNum);

    void CalcBlockDim();

    bool balanceModeFlag_ = false;   // balanced core-split mode selected
    bool splitKVFlag_ = false;       // flash-decode (split-KV) enabled

    uint32_t coreNum_ = 0;
    SFAPerfMode perfMode_ = SFAPerfMode::V_TEMPLATE_MODE;
    uint32_t kvSplitPart_ = 1;       // number of KV partitions for flash decode
    size_t mmResUbSize_ = 0;
    size_t bmm2ResUbSize_ = 0;
    size_t qPreSizeMla_ = 0;
    // S2 inner-loop split results
    uint32_t sInnerLoopTimes_ = 0;
    uint32_t sInnerSize_ = 0;
    uint32_t sInnerSizeTail_ = 0;
    uint32_t sInnerSizeAlign_ = 0;
    uint32_t kvSplit_ = 0;
    // Core-split results
    uint32_t usedCoreNum_ = 0;
    uint32_t formerCoreNum_ = 0;
    uint32_t blockSplitBn2Range_ = 0;
    uint32_t tailSplitedBatchRange_ = 0;

    uint32_t aicNum_ = 0;            // available cube cores
    uint32_t aivNum_ = 0;            // available vector cores
    size_t libapiSize_ = 0;          // workspace reserved for lib API calls

    SparseFlashAttentionTilingDataMla tilingData_;
    uint32_t blockDim_{0};
    uint64_t workspaceSize_{0};
    uint64_t tilingKey_{0};

    uint32_t headDimAlign_ = 0;
    uint32_t mBaseSize_ = 128;       // default M (rows) basic block
    uint32_t mFdBaseSize_ = 8;       // default M basic block in flash-decode mode

    SFATilingInfo *sfaInfo_ = nullptr;
};
|
||||
|
||||
class SFATilingCheck {
|
||||
public:
|
||||
explicit SFATilingCheck(const SFATilingInfo &sfaInfo) : sfaInfo_(sfaInfo) {};
|
||||
~SFATilingCheck() = default;
|
||||
virtual ge::graphStatus Process();
|
||||
private:
|
||||
void Init();
|
||||
void LogErrorDtypeSupport(const std::vector<ge::DataType> &expectDtypeList,
|
||||
const ge::DataType &actualDtype, const std::string &name) const;
|
||||
ge::graphStatus CheckDtypeSupport(const gert::CompileTimeTensorDesc *desc,
|
||||
const std::string &name) const;
|
||||
template <typename T> void LogErrorNumberSupport(const std::vector<T> &expectNumberList,
|
||||
const T &actualValue, const std::string &name, const std::string subName) const;
|
||||
template <typename T> void LogErrorDimNumSupport(const std::vector<T> &expectNumberList,
|
||||
const T &actualValue, const std::string &name) const;
|
||||
ge::graphStatus CheckDimNumSupport(const gert::StorageShape *shape,
|
||||
const std::vector<size_t> &expectDimNumList, const std::string &name) const;
|
||||
ge::graphStatus CheckDimNumInLayoutSupport(const SFALayout &layout,
|
||||
const gert::StorageShape *shape, const std::string &name) const;
|
||||
void LogErrorLayoutSupport(const std::vector<SFALayout> &expectLayoutList,
|
||||
const SFALayout &actualLayout, const std::string &name) const;
|
||||
ge::graphStatus GetExpectedShape(gert::Shape &shapeExpected,
|
||||
const SFATilingShapeCompareParam ¶m, const SFALayout &layout) const;
|
||||
ge::graphStatus CompareShape(SFATilingShapeCompareParam ¶m,
|
||||
const gert::Shape &shape, const SFALayout &layout, const std::string &name) const;
|
||||
ge::graphStatus CheckLayoutSupport(const SFALayout &actualLayout, const std::string &name) const;
|
||||
ge::graphStatus CheckSingleParaQuery() const;
|
||||
ge::graphStatus CheckSingleParaKey() const;
|
||||
ge::graphStatus CheckSingleParaValue() const;
|
||||
ge::graphStatus CheckSingleParaQueryRope() const;
|
||||
ge::graphStatus CheckSingleParaKeyRope() const;
|
||||
ge::graphStatus CheckSingleParaAttenOut() const;
|
||||
ge::graphStatus CheckSingleParaNumHeads() const;
|
||||
ge::graphStatus CheckSingleParaKvHeadNums() const;
|
||||
ge::graphStatus CheckSingleParaLayout() const;
|
||||
ge::graphStatus CheckSingleParaSparseMode() const;
|
||||
ge::graphStatus CheckSingleParaSparseBlockSize() const;
|
||||
ge::graphStatus CheckSingleParaSparseIndices() const;
|
||||
ge::graphStatus CheckSinglePara() const;
|
||||
ge::graphStatus CheckMultiParaConsistency() const;
|
||||
ge::graphStatus CheckRopeExistence();
|
||||
ge::graphStatus CheckExists(const void *pointer, const std::string &name) const;
|
||||
ge::graphStatus CheckNotExists(const void *pointer, const std::string &name) const;
|
||||
ge::graphStatus CheckExistsByMap(const std::map<std::string, const void *> ¶mMap) const;
|
||||
ge::graphStatus CheckNotExistsByMap(const std::map<std::string, const void *> ¶mMap) const;
|
||||
ge::graphStatus CheckExistenceByMap(std::map<std::string, const void *> &existMap,
|
||||
std::map<std::string, const void *> ¬ExistMap) const;
|
||||
template <typename T> ge::graphStatus CheckAttrValueByMap(
|
||||
std::map<std::string, std::pair<const T *, T>> &attrMap) const;
|
||||
ge::graphStatus CheckParaExistenceMlaNoquant() const;
|
||||
ge::graphStatus CheckParaExistenceGqaNoquant() const;
|
||||
ge::graphStatus CheckParaExistenceMla() const;
|
||||
ge::graphStatus CheckParaExistence();
|
||||
ge::graphStatus GetActualSeqLenSize(uint32_t &size, const gert::Tensor *tensor,
|
||||
const SFALayout &layout, const std::string &name);
|
||||
void SetSFAShapeCompare();
|
||||
ge::graphStatus CheckQRope();
|
||||
ge::graphStatus CheckQRopeShape();
|
||||
ge::graphStatus CheckVAndKRopeShapeForBatchContinuous();
|
||||
uint32_t GetTypeSize(ge::DataType dtype) const;
|
||||
ge::graphStatus CheckVAndKRopeShapeForPageAttention();
|
||||
ge::graphStatus CheckVAndKRopeShape();
|
||||
ge::graphStatus CheckVAndKRope();
|
||||
ge::graphStatus CheckTopK();
|
||||
ge::graphStatus CheckTopkShape();
|
||||
ge::graphStatus CheckBlockTable() const;
|
||||
ge::graphStatus CheckDTypeConsistency(const ge::DataType &actualDtype,
|
||||
const ge::DataType &expectDtype, const std::string &name) const;
|
||||
|
||||
ge::graphStatus CheckAttenOut();
|
||||
ge::graphStatus CheckAttenOutShape();
|
||||
ge::graphStatus CheckActualSeqLensQ();
|
||||
ge::graphStatus CheckActualSeqLensQShape();
|
||||
ge::graphStatus CheckActualSeqLensQDType();
|
||||
ge::graphStatus CheckActualSeqLens();
|
||||
ge::graphStatus CheckActualSeqLensDType();
|
||||
ge::graphStatus CheckActualSeqLensShape();
|
||||
ge::graphStatus CheckMultiParaConsistency();
|
||||
|
||||
ge::graphStatus CheckFeatureMlaNoQuantShape() const;
|
||||
ge::graphStatus CheckFeatureMlaNoQuantLayout() const;
|
||||
ge::graphStatus CheckFeatureMlaNoQuantDtype() const;
|
||||
ge::graphStatus CheckFeatureMlaNoquantPa() const;
|
||||
ge::graphStatus CheckFeatureMlaNoquant() const;
|
||||
ge::graphStatus CheckFeatureMla() const;
|
||||
ge::graphStatus CheckFeature() const;
|
||||
|
||||
private:
|
||||
const char *opName_;
|
||||
fe::PlatFormInfos *platformInfo_;
|
||||
SFAParaInfo opParamInfo_;
|
||||
const SFATilingInfo &sfaInfo_;
|
||||
|
||||
uint32_t bSize_ = 0;
|
||||
uint32_t n1Size_ = 0;
|
||||
uint32_t n2Size_ = 0;
|
||||
uint32_t gSize_ = 0;
|
||||
uint32_t s1Size_ = 0;
|
||||
int64_t s2Size_ = 0;
|
||||
uint32_t qkHeadDim_ = 0;
|
||||
uint32_t vHeadDim_ = 0;
|
||||
uint32_t ropeHeadDim_ = 0;
|
||||
uint32_t qTSize_ = 0;
|
||||
uint32_t kvTSize_ = 0;
|
||||
KvStorageMode kvStorageMode_ = KvStorageMode::BATCH_CONTINUOUS;
|
||||
uint32_t sparseBlockCount_ = 0;
|
||||
int64_t sparseBlockSize_ = 0;
|
||||
|
||||
SFALayout qLayout_ = SFALayout::BSND;
|
||||
SFALayout topkLayout_ = SFALayout::BSND;
|
||||
SFALayout outLayout_ = SFALayout::BSND;
|
||||
SFALayout kvLayout_ = SFALayout::BSND;
|
||||
|
||||
uint32_t maxBlockNumPerBatch_ = 0;
|
||||
int64_t blockSize_ = 0;
|
||||
|
||||
uint32_t aicNum_ = 0;
|
||||
uint32_t aivNum_ = 0;
|
||||
platform_ascendc::SocVersion socVersion_ = platform_ascendc::SocVersion::ASCEND910B;
|
||||
uint64_t l2CacheSize_ = 0;
|
||||
|
||||
ge::DataType inputQType_ = ge::DT_FLOAT16;
|
||||
ge::DataType inputKvType_ = ge::DT_FLOAT16;
|
||||
ge::DataType outputType_ = ge::DT_FLOAT16;
|
||||
ge::DataType inputQRopeType_ = ge::DT_FLOAT16;
|
||||
ge::DataType inputKRopeType_ = ge::DT_FLOAT16;
|
||||
|
||||
gert::Shape queryShapeCmp_{};
|
||||
gert::Shape keyShapeCmp_{};
|
||||
gert::Shape valueShapeCmp_{};
|
||||
gert::Shape topkShapeCmp_{};
|
||||
gert::Shape queryRopeShapeCmp_{};
|
||||
gert::Shape keyRopeShapeCmp_{};
|
||||
gert::Shape attenOutShapeCmp_{};
|
||||
};
|
||||
|
||||
class SFAInfoParser {
|
||||
public:
|
||||
explicit SFAInfoParser(const gert::TilingContext *context) : context_(context) {}
|
||||
~SFAInfoParser() = default;
|
||||
|
||||
ge::graphStatus CheckRequiredInOutExistence() const;
|
||||
ge::graphStatus CheckRequiredAttrExistence() const;
|
||||
ge::graphStatus CheckRequiredParaExistence() const;
|
||||
|
||||
ge::graphStatus GetActualSeqLenSize(uint32_t &size, const gert::Tensor *tensor,
|
||||
SFALayout &layout, const std::string &name);
|
||||
ge::graphStatus GetActualSeqLenQSize(uint32_t &size);
|
||||
ge::graphStatus GetOpName();
|
||||
ge::graphStatus GetNpuInfo();
|
||||
void GetOptionalInputParaInfo();
|
||||
void GetInputParaInfo();
|
||||
void GetOutputParaInfo();
|
||||
ge::graphStatus GetAttrParaInfo();
|
||||
ge::graphStatus GetKvCache();
|
||||
ge::graphStatus GetOpParaInfo();
|
||||
|
||||
ge::graphStatus GetInOutDataType();
|
||||
ge::graphStatus GetBatchSize();
|
||||
ge::graphStatus GetQTSize();
|
||||
ge::graphStatus GetKVTSize();
|
||||
ge::graphStatus GetQkHeadDim();
|
||||
ge::graphStatus GetS1Size();
|
||||
ge::graphStatus GetKvStorageMode();
|
||||
ge::graphStatus GetKvLayout();
|
||||
void SetSFAShape();
|
||||
ge::graphStatus GetS2SizeForBatchContinuous();
|
||||
ge::graphStatus GetMaxBlockNumPerBatch();
|
||||
ge::graphStatus GetBlockSize();
|
||||
ge::graphStatus GetS2SizeForPageAttention();
|
||||
ge::graphStatus GetS2Size();
|
||||
ge::graphStatus GetValueHeadDim();
|
||||
ge::graphStatus GetRopeHeadDim();
|
||||
ge::graphStatus GetQueryAndOutLayout();
|
||||
ge::graphStatus GetTopkLayout();
|
||||
ge::graphStatus GetN1Size();
|
||||
ge::graphStatus GetN2Size();
|
||||
ge::graphStatus GetGSize();
|
||||
ge::graphStatus GetSparseBlockCount();
|
||||
ge::graphStatus GetActualseqInfo();
|
||||
void GenerateInfo(SFATilingInfo &sfaInfo);
|
||||
ge::graphStatus Parse(SFATilingInfo &sfaInfo);
|
||||
|
||||
public:
|
||||
bool HasAxis(const SFAAxis &axis, const SFALayout &layout, const gert::Shape &shape) const;
|
||||
size_t GetAxisIdx(const SFAAxis &axis, const SFALayout &layout) const;
|
||||
uint32_t GetAxisNum(const gert::Shape &shape, const SFAAxis &axis,const SFALayout &layout) const;
|
||||
|
||||
const gert::TilingContext *context_ = nullptr;
|
||||
|
||||
const char *opName_;
|
||||
fe::PlatFormInfos *platformInfo_;
|
||||
SFAParaInfo opParamInfo_;
|
||||
static constexpr int64_t invalidDimValue_ = std::numeric_limits<int64_t>::min();
|
||||
|
||||
uint32_t bSize_ = 0;
|
||||
uint32_t n1Size_ = 0;
|
||||
uint32_t n2Size_ = 0;
|
||||
uint32_t gSize_ = 0;
|
||||
uint32_t s1Size_ = 0;
|
||||
int64_t s2Size_ = 0;
|
||||
uint32_t qkHeadDim_ = 0;
|
||||
uint32_t vHeadDim_ = 0;
|
||||
uint32_t ropeHeadDim_ = 0;
|
||||
uint32_t qTSize_ = 0;
|
||||
uint32_t kvTSize_ = 0;
|
||||
KvStorageMode kvStorageMode_ = KvStorageMode::BATCH_CONTINUOUS;
|
||||
uint32_t sparseBlockCount_ = 0;
|
||||
|
||||
SFALayout qLayout_ = SFALayout::BSND;
|
||||
SFALayout topkLayout_ = SFALayout::BSND;
|
||||
SFALayout outLayout_ = SFALayout::BSND;
|
||||
SFALayout kvLayout_ = SFALayout::BSND;
|
||||
|
||||
uint32_t maxBlockNumPerBatch_ = 0;
|
||||
uint32_t blockSize_ = 0;
|
||||
|
||||
platform_ascendc::SocVersion socVersion_ = platform_ascendc::SocVersion::ASCEND910B;
|
||||
|
||||
ge::DataType inputQType_ = ge::DT_FLOAT16;
|
||||
ge::DataType inputKvType_ = ge::DT_FLOAT16;
|
||||
ge::DataType outputType_ = ge::DT_FLOAT16;
|
||||
ge::DataType inputQRopeType_ = ge::DT_FLOAT16;
|
||||
ge::DataType inputKRopeType_ = ge::DT_FLOAT16;
|
||||
|
||||
uint64_t l2CacheSize_ = 0;
|
||||
|
||||
bool isSameSeqAllKVTensor_ = true;
|
||||
bool isSameActualseq_ = true;
|
||||
uint32_t maxActualseq_ = 0;
|
||||
|
||||
uint32_t actualLenDimsQ_ = 0;
|
||||
uint32_t actualLenDimsKV_ = 0;
|
||||
|
||||
gert::Shape queryShape_{};
|
||||
gert::Shape keyShape_{};
|
||||
gert::Shape valueShape_{};
|
||||
gert::Shape sparseIndicesShape_{};
|
||||
gert::Shape queryRopeShape_{};
|
||||
gert::Shape keyRopeShape_{};
|
||||
};
|
||||
} // namespace optiling
|
||||
#endif // SPARSE_FLASH_ATTENTION_TILING_H
|
||||
@@ -0,0 +1,53 @@
|
||||
/**
|
||||
* This program is free software, you can redistribute it and/or modify it.
|
||||
* Copyright (c) 2025 Huawei Technologies Co., Ltd.
|
||||
* This file is a part of the CANN Open Software.
|
||||
* Licensed under CANN Open Software License Agreement Version 2.0 (the "License").
|
||||
* Please refer to the License for details. You may not use this file except in compliance with the License.
|
||||
* THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE.
|
||||
* See LICENSE in the root of the software repository for the full text of the License.
|
||||
*/
|
||||
|
||||
/*!
|
||||
* \file sparse_flash_attention.cpp
|
||||
* \brief
|
||||
*/
|
||||
|
||||
#include "kernel_operator.h"
|
||||
#include "sparse_flash_attention_template_tiling_key.h"
|
||||
#include "sparse_flash_attention_kernel_mla.h"
|
||||
|
||||
using namespace AscendC;
|
||||
|
||||
// Instantiates one kernel configuration: builds templateClass<SFAType<...>>,
// loads the tiling struct from GM via GET_TILING_DATA_WITH_STRUCT, then runs
// Init + Process. Relies on the enclosing kernel's parameter names (query,
// key, ..., user, tiling, tPipe). Wrapped in do { } while (0) so it expands
// as a single statement.
#define SFA_OP_IMPL(templateClass, tilingdataClass, ...)                                        \
    do {                                                                                        \
        templateClass<SFAType<__VA_ARGS__>> op;                                                 \
        GET_TILING_DATA_WITH_STRUCT(tilingdataClass, tiling_data_in, tiling);                   \
        const tilingdataClass *__restrict tiling_data = &tiling_data_in;                        \
        op.Init(query, key, value, sparseIndices, actualSeqLengthsQuery, actualSeqLengthsKV,    \
                blocktable, queryRope, keyRope, attentionOut, user, tiling_data, tiling, &tPipe); \
        op.Process();                                                                           \
    } while (0)
|
||||
|
||||
// Device kernel entry point for SparseFlashAttention (MLA path).
// Template parameters select flash-decode, Q/output layout, KV layout and the
// cube/vector template mode; dtype dispatch is resolved at compile time via
// the codegen-provided ORIG_DTYPE_* macros.
template<int FLASH_DECODE, int LAYOUT_T, int KV_LAYOUT_T, int TEMPLATE_MODE>
__global__ __aicore__ void
sparse_flash_attention(__gm__ uint8_t *query, __gm__ uint8_t *key, __gm__ uint8_t *value,
                       __gm__ uint8_t *sparseIndices, __gm__ uint8_t *blocktable,
                       __gm__ uint8_t *actualSeqLengthsQuery, __gm__ uint8_t *actualSeqLengthsKV,
                       __gm__ uint8_t* queryRope, __gm__ uint8_t* keyRope,
                       __gm__ uint8_t *attentionOut, __gm__ uint8_t *workspace, __gm__ uint8_t *tiling)
{
    // Mixed cube/vector kernel (macro name indicates a 1 AIC : 2 AIV mix).
    KERNEL_TASK_TYPE_DEFAULT(KERNEL_TYPE_MIX_AIC_1_2);

    TPipe tPipe;
    // User-visible portion of the workspace (past any system-reserved prefix).
    __gm__ uint8_t *user = GetUserWorkspace(workspace);

    // fp16 Q/K/output picks the half instantiation; everything else falls
    // through to the bf16 instantiation.
    if constexpr (ORIG_DTYPE_QUERY == DT_FLOAT16 && ORIG_DTYPE_KEY == DT_FLOAT16 &&
                  ORIG_DTYPE_ATTENTION_OUT == DT_FLOAT16) {
        SFA_OP_IMPL(SparseFlashAttentionMla, SparseFlashAttentionTilingDataMla, half, half, half,
                    FLASH_DECODE, static_cast<SFA_LAYOUT>(LAYOUT_T), static_cast<SFA_LAYOUT>(KV_LAYOUT_T), TEMPLATE_MODE);
    } else { // bf16
        SFA_OP_IMPL(SparseFlashAttentionMla, SparseFlashAttentionTilingDataMla, bfloat16_t, bfloat16_t, bfloat16_t,
                    FLASH_DECODE, static_cast<SFA_LAYOUT>(LAYOUT_T), static_cast<SFA_LAYOUT>(KV_LAYOUT_T), TEMPLATE_MODE);
    }
}
|
||||
@@ -0,0 +1,192 @@
|
||||
/**
|
||||
* This program is free software, you can redistribute it and/or modify it.
|
||||
* Copyright (c) 2025 Huawei Technologies Co., Ltd.
|
||||
* This file is a part of the CANN Open Software.
|
||||
* Licensed under CANN Open Software License Agreement Version 2.0 (the "License").
|
||||
* Please refer to the License for details. You may not use this file except in compliance with the License.
|
||||
* THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE.
|
||||
* See LICENSE in the root of the software repository for the full text of the License.
|
||||
*/
|
||||
|
||||
/*!
|
||||
* \file sparse_flash_attention_common.h
|
||||
* \brief
|
||||
*/
|
||||
|
||||
#ifndef SPARSE_FLASH_ATTENTION_COMMON_H
|
||||
#define SPARSE_FLASH_ATTENTION_COMMON_H
|
||||
|
||||
#include "kernel_operator.h"
|
||||
#include "lib/matmul_intf.h"
|
||||
#include "lib/matrix/matmul/tiling.h"
|
||||
|
||||
using namespace AscendC;
|
||||
constexpr SoftmaxConfig SFA_SOFTMAX_FLASHV2_CFG_WITHOUT_BRC = {false, 0, 0, SoftmaxMode::SOFTMAX_OUTPUT_WITHOUT_BRC};
|
||||
|
||||
// Tensor layout tags used as compile-time kernel template parameters.
enum class SFA_LAYOUT
{
    BSND = 0,     // batch / sequence / heads / head-dim
    TND = 1,      // packed tokens: total-tokens / heads / head-dim
    PA_BSND = 2,  // page-attention kv cache (block-table addressed), BSND blocks
};
|
||||
|
||||
// Compile-time configuration bundle consumed by SparseFlashAttentionMla;
// one SFAType instantiation corresponds to one kernel variant.
template <typename Q_T, typename KV_T, typename OUT_T, const bool FLASH_DECODE = false,
          SFA_LAYOUT LAYOUT_T = SFA_LAYOUT::BSND, SFA_LAYOUT KV_LAYOUT_T = SFA_LAYOUT::BSND,
          const int TEMPLATE_MODE = C_TEMPLATE, typename... Args>
struct SFAType {
    using queryType = Q_T;     // query element type
    using kvType = KV_T;       // key/value element type
    using outputType = OUT_T;  // attention output element type
    static constexpr bool flashDecode = FLASH_DECODE;    // split-KV flash decode on/off
    static constexpr SFA_LAYOUT layout = LAYOUT_T;       // query/output layout
    static constexpr SFA_LAYOUT kvLayout = KV_LAYOUT_T;  // kv cache layout
    static constexpr int templateMode = TEMPLATE_MODE;   // cube/vector template selector
    // PA_BSND kv layout implies the page-attention (block table) code path.
    static constexpr bool pageAttention = (KV_LAYOUT_T == SFA_LAYOUT::PA_BSND);
};
|
||||
|
||||
// ================================Util functions==================================
|
||||
// Round `num` up to the nearest multiple of `rnd`; returns 0 when `rnd` is 0
// so callers never divide by zero.
template <typename T> __aicore__ inline T SFAAlign(T num, T rnd)
{
    if (rnd == 0) {
        return 0;
    }
    return (num + rnd - 1) / rnd * rnd;
}
|
||||
|
||||
// Return the smaller of `a` and `b`; the result is converted to T1, exactly
// as the original mixed-type ternary did implicitly.
template <typename T1, typename T2> __aicore__ inline T1 Min(T1 a, T2 b)
{
    if (b < a) {
        return static_cast<T1>(b);
    }
    return a;
}
|
||||
|
||||
// Align an element count `s` up to a 32-byte block of T elements; int4 packs
// two per byte, so it is aligned to 64-element granularity instead.
template <typename T> __aicore__ inline size_t BlockAlign(size_t s)
{
    if constexpr (IsSameType<T, int4b_t>::value) {
        constexpr size_t int4PerBlock = 64;
        return (s + int4PerBlock - 1) / int4PerBlock * int4PerBlock;
    }
    constexpr size_t elemPerBlock = 32 / sizeof(T);
    return (s + elemPerBlock - 1) / elemPerBlock * elemPerBlock;
}
|
||||
|
||||
// Per-task runtime state for one inner-loop step of the pipeline.
// NOTE(review): many fields carry no initializer — they appear to be fully
// rewritten for each task before use, with isValid as the guard flag; confirm
// against CalcParams before relying on any other field of a fresh instance.
struct RunInfo {
    uint32_t loop;                 // task index
    uint32_t bIdx;                 // batch index
    uint32_t gIdx;                 // group (query-heads-per-kv-head) index
    uint32_t s1Idx;                // query-sequence block index
    uint32_t s2Idx;                // kv-sequence block index
    uint32_t bn2IdxInCurCore;
    uint32_t curSInnerLoopTimes;
    uint64_t tndBIdxOffsetForQ;    // TND: batch base offset in Q (from name; confirm)
    uint64_t tndBIdxOffsetForKV;   // TND: batch base offset in KV (from name; confirm)
    // GM element offsets for the matmul operands and outputs
    uint64_t tensorAOffset;
    uint64_t tensorBOffset;
    uint64_t tensorARopeOffset;
    uint64_t tensorBRopeOffset;
    uint64_t attenOutOffset;
    uint64_t attenMaskOffset;
    uint64_t topKBaseOffset;
    uint32_t actualSingleProcessSInnerSize;       // valid S2 elements this step
    uint32_t actualSingleProcessSInnerSizeAlign;  // same, block-aligned
    bool isFirstSInnerLoop;
    bool isChangeBatch;
    uint32_t s2BatchOffset;
    uint32_t gSize;
    uint32_t s1Size;
    uint32_t s2Size;
    uint32_t mSize;
    uint32_t mSizeV;
    uint32_t mSizeVStart;
    uint32_t tndIsS2SplitCore;
    uint32_t tndCoreStartKVSplitPos;
    bool isBmm2Output;
    bool isValid = false;          // guard: only valid tasks are executed

    // MLA kernel uses a single kv head (kvHeadNum == 1 in the kernel class).
    static constexpr uint32_t n2Idx = 0;
    uint64_t actS1Size = 1;
    uint64_t curActualSeqLenOri = 0ULL;

    uint32_t gS1Idx;
    uint64_t actS2Size = 1;
    uint32_t actMBaseSize;
    bool isLastS2Loop;
    int32_t nextTokensPerBatch = 0;
    int64_t threshold;
    // Position within the sparse top-k index list
    uint32_t curTopKIdx = 0;
    uint64_t curOffsetInSparseBlock = 0;
};
|
||||
|
||||
struct ConstInfo {
|
||||
static constexpr uint32_t SFA_SYNC_MODE2 = 2;
|
||||
static constexpr uint32_t BUFFER_SIZE_BYTE_32B = 32;
|
||||
static constexpr uint32_t BUFFER_SIZE_BYTE_64B = 64;
|
||||
static constexpr uint32_t BUFFER_SIZE_BYTE_256B = 256;
|
||||
static constexpr uint32_t BUFFER_SIZE_BYTE_512B = 512;
|
||||
static constexpr uint32_t BUFFER_SIZE_BYTE_1K = 1024;
|
||||
static constexpr uint32_t BUFFER_SIZE_BYTE_2K = 2048;
|
||||
static constexpr uint32_t BUFFER_SIZE_BYTE_4K = 4096;
|
||||
static constexpr uint32_t BUFFER_SIZE_BYTE_8K = 8192;
|
||||
static constexpr uint32_t BUFFER_SIZE_BYTE_16K = 16384;
|
||||
static constexpr uint32_t BUFFER_SIZE_BYTE_32K = 32768;
|
||||
static constexpr float FLOAT_ZERO = 0;
|
||||
static constexpr float FLOAT_MAX = 3.402823466e+38F;
|
||||
|
||||
uint32_t preLoadNum = 0U;
|
||||
uint32_t nBufferMBaseSize = 0U;
|
||||
uint32_t syncV1NupdateC2 = 0U;
|
||||
uint32_t syncV0C1 = 0U;
|
||||
uint32_t syncC1V1 = 0U;
|
||||
uint32_t syncV1C2 = 0U;
|
||||
uint32_t syncC2V2 = 0U;
|
||||
uint32_t syncC2V1 = 0U;
|
||||
|
||||
uint32_t mmResUbSize = 0U;
|
||||
uint32_t vec1ResUbSize = 0U;
|
||||
uint32_t bmm2ResUbSize = 0U;
|
||||
uint64_t batchSize = 0ULL;
|
||||
uint64_t gSize = 0ULL;
|
||||
uint64_t qHeadNum = 0ULL;
|
||||
uint64_t kvHeadNum;
|
||||
uint64_t headDim;
|
||||
uint64_t headDimRope;
|
||||
uint64_t kvSeqSize = 0ULL;
|
||||
uint64_t qSeqSize = 1ULL;
|
||||
int64_t kvCacheBlockSize = 0;
|
||||
uint32_t maxBlockNumPerBatch = 0;
|
||||
uint32_t splitKVNum = 0U;
|
||||
SFA_LAYOUT outputLayout;
|
||||
uint32_t sparseMode = 0;
|
||||
bool needInit = false;
|
||||
|
||||
// FlashDecoding
|
||||
uint32_t actualCombineLoopSize = 0U;
|
||||
uint64_t combineLseOffset = 0ULL;
|
||||
uint64_t combineAccumOutOffset = 0ULL;
|
||||
|
||||
uint32_t actualLenDimsQ = 0U;
|
||||
uint32_t actualLenDimsKV = 0U;
|
||||
|
||||
// TND
|
||||
uint32_t s2Start = 0U;
|
||||
uint32_t s2End = 0U;
|
||||
|
||||
uint32_t bN2Start = 0U;
|
||||
uint32_t bN2End = 0U;
|
||||
uint32_t gS1Start = 0U;
|
||||
uint32_t gS1End = 0U;
|
||||
|
||||
uint32_t tndFDCoreArrLen = 0U;
|
||||
uint32_t coreStartKVSplitPos = 0U;
|
||||
|
||||
uint32_t mBaseSize = 1ULL;
|
||||
uint32_t s2BaseSize = 1ULL;
|
||||
|
||||
// sparse attr
|
||||
int64_t sparseBlockSize = 0;
|
||||
uint32_t sparseBlockCount = 0;
|
||||
};
|
||||
|
||||
// Describes how one M-axis (rows) block is divided between an n-buffer slot
// and the vector stage: each pair is a start row plus a row count.
struct MSplitInfo {
    uint32_t nBufferIdx = 0U;     // n-buffer slot index
    uint32_t nBufferStartM = 0U;  // first M row of the n-buffer slot
    uint32_t nBufferDealM = 0U;   // M rows handled by the n-buffer slot
    uint32_t vecStartM = 0U;      // first M row of the vector stage
    uint32_t vecDealM = 0U;       // M rows handled by the vector stage
};
|
||||
|
||||
#endif // SPARSE_FLASH_ATTENTION_COMMON_H
|
||||
@@ -0,0 +1,969 @@
|
||||
/**
|
||||
* This program is free software, you can redistribute it and/or modify it.
|
||||
* Copyright (c) 2025 Huawei Technologies Co., Ltd.
|
||||
* This file is a part of the CANN Open Software.
|
||||
* Licensed under CANN Open Software License Agreement Version 2.0 (the "License").
|
||||
* Please refer to the License for details. You may not use this file except in compliance with the License.
|
||||
* THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE.
|
||||
* See LICENSE in the root of the software repository for the full text of the License.
|
||||
*/
|
||||
|
||||
/*!
|
||||
* \file sparse_flash_attention_kernel_mla.h
|
||||
* \brief
|
||||
*/
|
||||
|
||||
#ifndef SPARSE_FLASH_ATTENTION_KERNEL_MLA_H
|
||||
#define SPARSE_FLASH_ATTENTION_KERNEL_MLA_H
|
||||
|
||||
#include "kernel_operator.h"
|
||||
#include "kernel_operator_list_tensor_intf.h"
|
||||
#include "kernel_tiling/kernel_tiling.h"
|
||||
#include "lib/matmul_intf.h"
|
||||
#include "lib/matrix/matmul/tiling.h"
|
||||
#include "sparse_flash_attention_common.h"
|
||||
#include "sparse_flash_attention_service_cube_mla.h"
|
||||
#include "sparse_flash_attention_service_vector_mla.h"
|
||||
|
||||
using namespace matmul;
|
||||
using AscendC::CacheMode;
|
||||
using AscendC::CrossCoreSetFlag;
|
||||
using AscendC::CrossCoreWaitFlag;
|
||||
|
||||
// Per-(batch, kv-head) loop bookkeeping kept between inner-loop tasks.
// Fix vs. the original: tndCoreStartKVSplitPos and tndIsS2SplitCore were the
// only members without initializers — reading them before assignment was
// indeterminate; they now default to 0 / false like every other member.
struct TempLoopInfo {
    uint32_t bn2IdxInCurCore = 0;
    uint32_t bIdx = 0U;               // batch index
    uint32_t n2Idx = 0U;              // kv head index
    uint64_t s2BasicSizeTail = 0U;    // tail size of the last S2 basic block
    uint32_t s2LoopTimes = 0U;        // number of S2 inner-loop iterations
    uint64_t curActualSeqLen = 0ULL;  // effective kv length of the current batch
    uint64_t curActualSeqLenOri = 0ULL;
    bool curActSeqLenIsZero = false;  // current batch has an empty kv sequence
    int32_t nextTokensPerBatch = 0;

    uint64_t actS1Size = 1ULL;
    uint32_t tndCoreStartKVSplitPos = 0U;  // was uninitialized
    bool tndIsS2SplitCore = false;         // was uninitialized

    uint32_t gS1Idx = 0U;
    uint64_t mBasicSizeTail = 0U;     // tail size of the last M basic block
};
|
||||
|
||||
template <typename SFAT> class SparseFlashAttentionMla {
|
||||
public:
|
||||
using T = float;
|
||||
using Q_T = typename SFAT::queryType;
|
||||
using KV_T = typename SFAT::kvType;
|
||||
using OUT_T = typename SFAT::outputType;
|
||||
using Q_ROPE_T = Q_T;
|
||||
using K_ROPE_T = KV_T;
|
||||
using UPDATE_T = T;
|
||||
using MM1_OUT_T = T;
|
||||
using MM2_OUT_T = T;
|
||||
|
||||
__aicore__ inline SparseFlashAttentionMla(){};
|
||||
__aicore__ inline void Init(__gm__ uint8_t *query, __gm__ uint8_t *key, __gm__ uint8_t *value,
|
||||
__gm__ uint8_t *sparseIndices, __gm__ uint8_t *actualSeqLengthsQ,
|
||||
__gm__ uint8_t *actualSeqLengths, __gm__ uint8_t *blockTable,
|
||||
__gm__ uint8_t *queryRope, __gm__ uint8_t *keyRope,
|
||||
__gm__ uint8_t *attentionOut, __gm__ uint8_t *workspace,
|
||||
const SparseFlashAttentionTilingDataMla *__restrict tiling,
|
||||
__gm__ uint8_t *gmTiling, TPipe *tPipe);
|
||||
|
||||
__aicore__ inline void Process();
|
||||
|
||||
private:
|
||||
static constexpr bool PAGE_ATTENTION = SFAT::pageAttention;
|
||||
static constexpr int TEMPLATE_MODE = SFAT::templateMode;
|
||||
static constexpr bool FLASH_DECODE = SFAT::flashDecode;
|
||||
static constexpr SFA_LAYOUT LAYOUT_T = SFAT::layout;
|
||||
static constexpr SFA_LAYOUT KV_LAYOUT_T = SFAT::kvLayout;
|
||||
|
||||
static constexpr uint32_t PRELOAD_NUM = 2;
|
||||
static constexpr uint32_t N_BUFFER_M_BASIC_SIZE = 256;
|
||||
static constexpr uint32_t SFA_PRELOAD_TASK_CACHE_SIZE = 3;
|
||||
|
||||
static constexpr uint32_t SYNC_V0_C1_FLAG = 6;
|
||||
static constexpr uint32_t SYNC_C1_V1_FLAG = 7;
|
||||
static constexpr uint32_t SYNC_V1_C2_FLAG = 8;
|
||||
static constexpr uint32_t SYNC_C2_V2_FLAG = 9;
|
||||
static constexpr uint32_t SYNC_C2_V1_FLAG = 4;
|
||||
static constexpr uint32_t SYNC_V1_NUPDATE_C2_FLAG = 5;
|
||||
|
||||
static constexpr uint64_t SYNC_MM2RES_BUF1_FLAG = 10;
|
||||
static constexpr uint64_t SYNC_MM2RES_BUF2_FLAG = 11;
|
||||
static constexpr uint64_t SYNC_FDOUTPUT_BUF_FLAG = 12;
|
||||
|
||||
static constexpr uint32_t BLOCK_ELEMENT_NUM = SFAVectorService<SFAT>::BYTE_BLOCK / sizeof(T);
|
||||
|
||||
static constexpr uint64_t kvHeadNum = 1ULL;
|
||||
static constexpr uint64_t headDim = 512ULL;
|
||||
static constexpr uint64_t headDimAlign = 512ULL;
|
||||
static constexpr uint64_t headDimRope = 64ULL;
|
||||
static constexpr uint32_t msdIterNum = 2U;
|
||||
|
||||
static constexpr uint32_t dbWorkspaceRatio = PRELOAD_NUM;
|
||||
|
||||
const SparseFlashAttentionTilingDataMla *__restrict tilingData = nullptr;
|
||||
|
||||
TPipe *pipe = nullptr;
|
||||
|
||||
uint64_t mSizeVStart = 0ULL;
|
||||
int64_t threshold = 0;
|
||||
uint64_t topKBaseOffset = 0ULL;
|
||||
uint64_t s2BatchBaseOffset = 0;
|
||||
uint64_t tensorACoreOffset = 0ULL;
|
||||
uint64_t tensorBCoreOffset = 0ULL;
|
||||
uint64_t tensorARopeCoreOffset = 0ULL;
|
||||
uint64_t tensorBRopeCoreOffset = 0ULL;
|
||||
uint64_t tensorBOffset = 0ULL;
|
||||
uint64_t attenOutOffset = 0ULL;
|
||||
|
||||
uint32_t tmpBlockIdx = 0U;
|
||||
uint32_t aiCoreIdx = 0U;
|
||||
uint32_t usedCoreNum = 0U;
|
||||
|
||||
__gm__ uint8_t *keyPtr = nullptr;
|
||||
__gm__ uint8_t *valuePtr = nullptr;
|
||||
|
||||
ConstInfo constInfo{};
|
||||
TempLoopInfo tempLoopInfo{};
|
||||
|
||||
SFAMatmulService<SFAT> matmulService;
|
||||
SFAVectorService<SFAT> vectorService;
|
||||
|
||||
GlobalTensor<Q_T> queryGm;
|
||||
GlobalTensor<KV_T> keyGm;
|
||||
GlobalTensor<KV_T> valueGm;
|
||||
GlobalTensor<Q_ROPE_T> qRopeGm;
|
||||
GlobalTensor<K_ROPE_T> kRopeGm;
|
||||
|
||||
GlobalTensor<OUT_T> attentionOutGm;
|
||||
GlobalTensor<int32_t> blockTableGm;
|
||||
GlobalTensor<int32_t> topKGm;
|
||||
|
||||
GlobalTensor<int32_t> actualSeqLengthsQGm;
|
||||
GlobalTensor<int32_t> actualSeqLengthsKVGm;
|
||||
|
||||
// workspace
|
||||
GlobalTensor<MM1_OUT_T> mm1ResGm;
|
||||
GlobalTensor<KV_T> vec1ResGm;
|
||||
GlobalTensor<MM2_OUT_T> mm2ResGm;
|
||||
GlobalTensor<KV_T> kvMergeGm_;
|
||||
GlobalTensor<int32_t> kvValidSizeGm_;
|
||||
|
||||
GlobalTensor<int32_t> mm2ResInt32Gm;
|
||||
GlobalTensor<UPDATE_T> vec2ResGm;
|
||||
|
||||
GlobalTensor<T> accumOutGm;
|
||||
GlobalTensor<T> lseSumFdGm;
|
||||
GlobalTensor<T> lseMaxFdGm;
|
||||
|
||||
// ================================Init functions===================================
|
||||
__aicore__ inline void InitTilingData();
|
||||
__aicore__ inline void InitCalcParamsEach();
|
||||
__aicore__ inline void InitBuffers();
|
||||
__aicore__ inline void InitActualSeqLen(__gm__ uint8_t *actualSeqLengthsQ, __gm__ uint8_t *actualSeqLengths);
|
||||
__aicore__ inline void InitOutputSingleCore();
|
||||
// ================================Process functions================================
|
||||
__aicore__ inline void ProcessBalance();
|
||||
__aicore__ inline void PreloadPipeline(uint32_t loop, uint64_t s2Start, uint64_t s2LoopIdx,
|
||||
RunInfo extraInfo[SFA_PRELOAD_TASK_CACHE_SIZE], uint32_t &curTopKIdx, uint64_t &curOffsetInSparseBlock);
|
||||
// ================================Offset Calc=====================================
|
||||
__aicore__ inline void GetActualSeqLen(uint32_t bIdx, uint32_t s1Idx = 0);
|
||||
__aicore__ inline void GetSparseActualSeqLen(uint32_t bIdx, uint32_t s1Idx, uint32_t n2Idx);
|
||||
__aicore__ inline void CalcSinnerTopKBegin(RunInfo &info, uint32_t &curTopKIdx, uint64_t &curOffsetInSparseBlock);
|
||||
__aicore__ inline void UpdateInnerLoopCond();
|
||||
__aicore__ inline void DealActSeqLenIsZero(uint32_t bIdx, uint32_t s1Idx, uint32_t n2Idx);
|
||||
__aicore__ inline void CalcParams(uint32_t loop, uint64_t s2Start, uint32_t s2LoopIdx, RunInfo &info);
|
||||
__aicore__ inline void GetAxisStartIdx(uint32_t bN2EndPrev, uint32_t gS1EndPrev, uint32_t s2EndPrev);
|
||||
__aicore__ inline uint64_t GetBalanceActualSeqLengths(GlobalTensor<int32_t> &actualSeqLengths, uint32_t bIdx);
|
||||
__aicore__ inline uint32_t GetActualSeqLenKV(uint32_t bIdx);
|
||||
__aicore__ inline void GetBN2Idx(uint32_t bN2Idx, uint32_t &bIdx, uint32_t &n2Idx);
|
||||
__aicore__ inline void UpdateInner(uint32_t &s2End, uint32_t &curS2End, uint32_t s1Idx, bool isEnd);
|
||||
__aicore__ inline void GetPreNextTokensLeftUp();
|
||||
// ================================Mm1==============================================
|
||||
__aicore__ inline void ComputeMm1(const RunInfo &info);
|
||||
// ================================Mm2==============================================
|
||||
__aicore__ inline void ComputeMm2(const RunInfo &info);
|
||||
__aicore__ inline void Bmm2DataCopyOut(uint64_t attenOutOffset, LocalTensor<OUT_T> &attenOutUb, uint32_t startRow,
|
||||
uint32_t dealRowCount, uint32_t columnCount, uint32_t actualColumnCount);
|
||||
__aicore__ inline void InitAllZeroOutput(uint32_t bIdx, uint32_t s1Idx, uint32_t n2Idx);
|
||||
};
|
||||
|
||||
template <typename SFAT> __aicore__ inline void SparseFlashAttentionMla<SFAT>::InitTilingData()
{
    // Copy per-launch tiling parameters from the host-prepared tiling struct into
    // constInfo, so hot-path code reads plain members instead of chasing the
    // tilingData pointer on every access.
    usedCoreNum = tilingData->singleCoreParams.usedCoreNum;
    constInfo.splitKVNum = tilingData->splitKVParams.s2;
    constInfo.mmResUbSize = tilingData->singleCoreTensorSize.mmResUbSize;
    constInfo.bmm2ResUbSize = tilingData->singleCoreTensorSize.bmm2ResUbSize;
    // vec1 workspace holds msdIterNum copies of the mm1 result tile.
    constInfo.vec1ResUbSize = constInfo.mmResUbSize * msdIterNum;

    constInfo.batchSize = tilingData->baseParams.batchSize;
    // gSize == query heads per KV head (GQA group size); also used as qHeadNum here.
    constInfo.qHeadNum = constInfo.gSize = tilingData->baseParams.nNumOfQInOneGroup;
    constInfo.kvSeqSize = tilingData->baseParams.seqSize;
    constInfo.qSeqSize = tilingData->baseParams.qSeqSize;
    constInfo.maxBlockNumPerBatch = tilingData->baseParams.maxBlockNumPerBatch;
    constInfo.kvCacheBlockSize = tilingData->baseParams.blockSize;
    constInfo.outputLayout = static_cast<SFA_LAYOUT>(tilingData->baseParams.outputLayout);
    constInfo.mBaseSize = tilingData->innerSplitParams.mBaseSize;
    constInfo.s2BaseSize = tilingData->innerSplitParams.s2BaseSize;
    // kvHeadNum / headDim / headDimRope come from class-level compile-time constants.
    constInfo.kvHeadNum = kvHeadNum;
    constInfo.headDim = headDim;
    constInfo.headDimRope = headDimRope;
    constInfo.sparseBlockSize = tilingData->baseParams.sparseBlockSize;
    constInfo.sparseBlockCount = tilingData->baseParams.sparseBlockCount;
    constInfo.sparseMode = tilingData->baseParams.sparseMode;

    // Pipeline depth and the cross-core sync flag IDs used between cube (C) and
    // vector (V) stages of the software pipeline.
    constInfo.preLoadNum = PRELOAD_NUM;
    constInfo.nBufferMBaseSize = N_BUFFER_M_BASIC_SIZE;
    constInfo.syncV0C1 = SYNC_V0_C1_FLAG;
    constInfo.syncC1V1 = SYNC_C1_V1_FLAG;
    constInfo.syncV1C2 = SYNC_V1_C2_FLAG;
    constInfo.syncC2V2 = SYNC_C2_V2_FLAG;
    constInfo.syncC2V1 = SYNC_C2_V1_FLAG;
    constInfo.syncV1NupdateC2 = SYNC_V1_NUPDATE_C2_FLAG;
}
|
||||
|
||||
template <typename SFAT> __aicore__ inline void SparseFlashAttentionMla<SFAT>::InitBuffers()
{
    // Local-memory buffer setup differs by physical core type: vector (AIV)
    // cores initialize the vector service's queues, cube (AIC) cores the
    // matmul service's.
    if ASCEND_IS_AIV {
        vectorService.InitBuffers(pipe);
    } else {
        matmulService.InitBuffers(pipe);
    }
}
|
||||
|
||||
template <typename SFAT>
|
||||
__aicore__ inline void
|
||||
SparseFlashAttentionMla<SFAT>::InitActualSeqLen(__gm__ uint8_t *actualSeqLengthsQ,
|
||||
__gm__ uint8_t *actualSeqLengths)
|
||||
{
|
||||
constInfo.actualLenDimsQ = tilingData->baseParams.actualLenDimsQ;
|
||||
constInfo.actualLenDimsKV = tilingData->baseParams.actualLenDimsKV;
|
||||
if (constInfo.actualLenDimsKV != 0) {
|
||||
actualSeqLengthsKVGm.SetGlobalBuffer((__gm__ int32_t *)actualSeqLengths, constInfo.actualLenDimsKV);
|
||||
}
|
||||
if (constInfo.actualLenDimsQ != 0) {
|
||||
actualSeqLengthsQGm.SetGlobalBuffer((__gm__ int32_t *)actualSeqLengthsQ, constInfo.actualLenDimsQ);
|
||||
}
|
||||
}
|
||||
|
||||
template <typename SFAT>
|
||||
__aicore__ inline void SparseFlashAttentionMla<SFAT>::InitAllZeroOutput(uint32_t bIdx, uint32_t s1Idx, uint32_t n2Idx)
|
||||
{
|
||||
if (constInfo.outputLayout == SFA_LAYOUT::TND) {
|
||||
uint32_t tBase = bIdx == 0 ? 0 : actualSeqLengthsQGm.GetValue(bIdx - 1);
|
||||
uint32_t s1Count = tempLoopInfo.actS1Size;
|
||||
|
||||
uint64_t attenOutOffset = (tBase + s1Idx) * kvHeadNum * constInfo.gSize * headDim +
|
||||
n2Idx * constInfo.gSize * headDim;
|
||||
matmul::InitOutput<OUT_T>(attentionOutGm[attenOutOffset], constInfo.gSize * headDim, 0);
|
||||
} else if (constInfo.outputLayout == SFA_LAYOUT::BSND) {
|
||||
uint64_t attenOutOffset = bIdx * constInfo.qSeqSize * kvHeadNum * constInfo.gSize * headDim +
|
||||
s1Idx * kvHeadNum * constInfo.gSize * headDim +
|
||||
n2Idx * constInfo.gSize * headDim;
|
||||
matmul::InitOutput<OUT_T>(attentionOutGm[attenOutOffset], constInfo.gSize * headDim, 0);
|
||||
}
|
||||
}
|
||||
|
||||
template <typename SFAT>
__aicore__ inline void SparseFlashAttentionMla<SFAT>::InitOutputSingleCore()
{
    // Zero-initialize the whole attention output, with the work striped across
    // cores; needed when some output rows will never be written (short actual
    // sequence lengths). Ends with a SyncAll so no core reads/writes output
    // before initialization completes.
    uint32_t coreNum = GetBlockNum();
    if (coreNum != 0) {
        uint64_t totalOutputSize = constInfo.batchSize * constInfo.qHeadNum * constInfo.qSeqSize * constInfo.headDim;
        uint64_t singleCoreSize = (totalOutputSize + (2 * coreNum) - 1) / (2 * coreNum); // 2 means c:v = 1:2
        // Last stripes may be short (or empty) — clamp to what remains.
        uint64_t tailSize = totalOutputSize - tmpBlockIdx * singleCoreSize;
        uint64_t singleInitOutputSize = tailSize < singleCoreSize ? tailSize : singleCoreSize;
        if (singleInitOutputSize > 0) {
            matmul::InitOutput<OUT_T>(attentionOutGm[tmpBlockIdx * singleCoreSize], singleInitOutputSize, 0);
        }
        SyncAll();
    }
}
|
||||
|
||||
template <typename SFAT>
__aicore__ inline void SparseFlashAttentionMla<SFAT>::GetActualSeqLen(uint32_t bIdx, uint32_t s1Idx)
{
    // Cache batch bIdx's actual KV length and query length into tempLoopInfo.
    // NOTE(review): s1Idx is currently unused here; it is kept for interface
    // compatibility with the declaration (default 0).
    tempLoopInfo.curActualSeqLenOri = GetActualSeqLenKV(bIdx);
    tempLoopInfo.actS1Size = GetBalanceActualSeqLengths(actualSeqLengthsQGm, bIdx);
}
|
||||
|
||||
template <typename SFAT>
|
||||
__aicore__ inline void SparseFlashAttentionMla<SFAT>::GetSparseActualSeqLen(uint32_t bIdx, uint32_t s1Idx,
|
||||
uint32_t n2Idx)
|
||||
{
|
||||
if (tempLoopInfo.nextTokensPerBatch < 0 && s1Idx < (-tempLoopInfo.nextTokensPerBatch)) {
|
||||
tempLoopInfo.curActualSeqLen = 0;
|
||||
return;
|
||||
}
|
||||
int64_t threshold = tempLoopInfo.curActualSeqLenOri;
|
||||
if (constInfo.sparseMode == 3) {
|
||||
threshold = static_cast<int64_t>(tempLoopInfo.nextTokensPerBatch) + s1Idx + 1;
|
||||
}
|
||||
|
||||
tempLoopInfo.curActualSeqLen = (constInfo.sparseBlockCount * constInfo.sparseBlockSize > threshold) ?
|
||||
threshold :
|
||||
constInfo.sparseBlockCount * constInfo.sparseBlockSize;
|
||||
}
|
||||
|
||||
template <typename SFAT>
|
||||
__aicore__ inline uint32_t SparseFlashAttentionMla<SFAT>::GetActualSeqLenKV(uint32_t bIdx)
|
||||
{
|
||||
if constexpr (KV_LAYOUT_T == SFA_LAYOUT::TND) {
|
||||
if (bIdx > 0) {
|
||||
return actualSeqLengthsKVGm.GetValue(bIdx) - actualSeqLengthsKVGm.GetValue(bIdx - 1);
|
||||
} else if (bIdx == 0) {
|
||||
return actualSeqLengthsKVGm.GetValue(0);
|
||||
} else {
|
||||
return 0;
|
||||
}
|
||||
} else {
|
||||
if (constInfo.actualLenDimsKV == 0) {
|
||||
return constInfo.kvSeqSize;
|
||||
} else if (constInfo.actualLenDimsKV == 1) {
|
||||
return actualSeqLengthsKVGm.GetValue(0);
|
||||
} else {
|
||||
return actualSeqLengthsKVGm.GetValue(bIdx);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
template <typename SFAT>
__aicore__ inline void SparseFlashAttentionMla<SFAT>::DealActSeqLenIsZero(uint32_t bIdx, uint32_t s1Idx, uint32_t n2Idx)
{
    // An empty effective sequence still must produce a defined (zero) output.
    // Only vector cores write it; cube cores have nothing to do.
    if ASCEND_IS_AIV {
        InitAllZeroOutput(bIdx, s1Idx, n2Idx);
    }
}
|
||||
|
||||
template <typename SFAT>
__aicore__ inline void SparseFlashAttentionMla<SFAT>::GetPreNextTokensLeftUp()
{
    // For causal sparse mode (3), derive the per-batch nextTokens bound from
    // the KV/Q length difference (left-up aligned causal mask). Can be negative
    // when the query is longer than the KV sequence. Other sparse modes leave
    // nextTokensPerBatch untouched.
    if (constInfo.sparseMode == 3) {
        tempLoopInfo.nextTokensPerBatch =
            static_cast<int32_t>(tempLoopInfo.curActualSeqLenOri) - static_cast<int32_t>(tempLoopInfo.actS1Size);
    }
}
|
||||
|
||||
template <typename SFAT> __aicore__ inline void SparseFlashAttentionMla<SFAT>::UpdateInnerLoopCond()
{
    // Refresh the inner-loop bookkeeping for the current (b, n2, gS1) step.
    // A zero KV length or zero query length means there is nothing to compute.
    const bool noWork = (tempLoopInfo.curActualSeqLen == 0) || (tempLoopInfo.actS1Size == 0);
    tempLoopInfo.curActSeqLenIsZero = noWork;
    if (noWork) {
        return;
    }
    // Tail size of the last m-base block along the fused (g, s1) axis;
    // a zero remainder means the tail is a full base block.
    uint32_t mTail = (tempLoopInfo.actS1Size * constInfo.gSize) % constInfo.mBaseSize;
    if (mTail == 0) {
        mTail = constInfo.mBaseSize;
    }
    tempLoopInfo.mBasicSizeTail = mTail;
    tempLoopInfo.s2LoopTimes = 0;
}
|
||||
|
||||
template <typename SFAT>
|
||||
__aicore__ inline void SparseFlashAttentionMla<SFAT>::UpdateInner(uint32_t &s2End, uint32_t &curS2End,
|
||||
uint32_t s1Idx, bool isEnd)
|
||||
{
|
||||
uint32_t s1BaseSize = 1;
|
||||
int64_t s1Offset = s1BaseSize * s1Idx;
|
||||
int64_t s2LastToken = Min(s1Offset + tempLoopInfo.nextTokensPerBatch + s1BaseSize,tempLoopInfo.curActualSeqLenOri);
|
||||
s2LastToken = Min(constInfo.sparseBlockSize * constInfo.sparseBlockCount, s2LastToken);
|
||||
curS2End = (s2LastToken + constInfo.s2BaseSize - 1) / constInfo.s2BaseSize;
|
||||
tempLoopInfo.s2LoopTimes = isEnd ? constInfo.s2End + 1 : curS2End;
|
||||
}
|
||||
|
||||
template <typename SFAT>
__aicore__ inline void SparseFlashAttentionMla<SFAT>::Init(__gm__ uint8_t *query,
                                                           __gm__ uint8_t *key, __gm__ uint8_t *value,
                                                           __gm__ uint8_t *sparseIndices, __gm__ uint8_t *actualSeqLengthsQ,
                                                           __gm__ uint8_t *actualSeqLengths, __gm__ uint8_t *blockTable,
                                                           __gm__ uint8_t *queryRope, __gm__ uint8_t *keyRope,
                                                           __gm__ uint8_t *attentionOut, __gm__ uint8_t *workspace,
                                                           const SparseFlashAttentionTilingDataMla *__restrict tiling,
                                                           __gm__ uint8_t *gmTiling, TPipe *tPipe)
{
    // One-time kernel setup: core indices, tiling, global-memory tensor views,
    // workspace partitioning, and per-service initialization.
    // NOTE(review): gmTiling is not read in this body — presumably reserved for
    // device-side tiling; confirm against the kernel entry.
    if ASCEND_IS_AIV {
        tmpBlockIdx = GetBlockIdx(); // vec:0-47
        // Two vector cores pair with one cube core; aiCoreIdx is the pair index.
        aiCoreIdx = tmpBlockIdx / 2;
    } else {
        tmpBlockIdx = GetBlockIdx(); // cube:0-23
        aiCoreIdx = tmpBlockIdx;
    }

    // init tiling data
    tilingData = tiling;

    InitTilingData();
    InitActualSeqLen(actualSeqLengthsQ, actualSeqLengths);

    // Determines this core's (bN2, gS1) work range; may also set constInfo.needInit.
    InitCalcParamsEach();
    pipe = tPipe;
    keyPtr = key;
    valuePtr = value;

    // init global buffer
    queryGm.SetGlobalBuffer((__gm__ Q_T *)query);
    keyGm.SetGlobalBuffer((__gm__ KV_T *)keyPtr);
    valueGm.SetGlobalBuffer((__gm__ KV_T *)valuePtr);
    qRopeGm.SetGlobalBuffer((__gm__ Q_ROPE_T *)queryRope);
    kRopeGm.SetGlobalBuffer((__gm__ K_ROPE_T *)keyRope);

    attentionOutGm.SetGlobalBuffer((__gm__ OUT_T *)attentionOut);

    // Non-TND layouts with short actual sequences need the output zeroed up
    // front (TND zeroes per-row in InitAllZeroOutput instead).
    if ASCEND_IS_AIV {
        if (constInfo.needInit && LAYOUT_T != SFA_LAYOUT::TND) {
            InitOutputSingleCore();
        }
    }

    if constexpr (PAGE_ATTENTION) {
        blockTableGm.SetGlobalBuffer((__gm__ int32_t *)blockTable);
    }
    topKGm.SetGlobalBuffer((__gm__ int32_t *)sparseIndices);

    // Partition the workspace: each region is sliced per-core (aiCoreIdx) with
    // dbWorkspaceRatio-deep double buffering; `offset` accumulates the total
    // region sizes in order. The order of these SetGlobalBuffer/offset pairs is
    // the layout contract — do not reorder.
    uint64_t offset = 0;
    mm1ResGm.SetGlobalBuffer(
        (__gm__ MM1_OUT_T *)(workspace + offset +
                             aiCoreIdx * dbWorkspaceRatio * constInfo.mmResUbSize * sizeof(MM1_OUT_T)));
    offset += GetBlockNum() * dbWorkspaceRatio * constInfo.mmResUbSize * sizeof(MM1_OUT_T);

    vec1ResGm.SetGlobalBuffer(
        (__gm__ KV_T *)(workspace + offset + aiCoreIdx * dbWorkspaceRatio * constInfo.mmResUbSize * sizeof(KV_T)));
    offset += GetBlockNum() * dbWorkspaceRatio * constInfo.mmResUbSize * sizeof(KV_T);

    mm2ResGm.SetGlobalBuffer(
        (__gm__ MM2_OUT_T *)(workspace + offset +
                             aiCoreIdx * dbWorkspaceRatio * constInfo.bmm2ResUbSize * sizeof(MM2_OUT_T)));
    offset += GetBlockNum() * dbWorkspaceRatio * constInfo.bmm2ResUbSize * sizeof(MM2_OUT_T);
    // int32 view aliasing the same mm2 result memory.
    mm2ResInt32Gm.SetGlobalBuffer(reinterpret_cast<__gm__ int32_t *>(mm2ResGm.GetPhyAddr(0)));

    if constexpr (TEMPLATE_MODE == V_TEMPLATE) {
        // s2 d+rope bufNum
        // Per-core merged-KV staging: 512 (s2) x 576 (headDim 512 + rope 64) x 4 buffers.
        kvMergeGm_.SetGlobalBuffer((__gm__ KV_T *)(workspace + offset + aiCoreIdx * 512 * 576 * 4 * sizeof(KV_T)));
        offset += GetBlockNum() * 512 * 576 * 4 * sizeof(KV_T);

        // Valid-size side-band: two vector cores per cube core, hence aiCoreIdx * 2.
        kvValidSizeGm_.SetGlobalBuffer(
            (__gm__ int32_t *)(workspace + offset + (aiCoreIdx * 2) * 128 * 4 * sizeof(int32_t)));
    }

    if constexpr (FLASH_DECODE) {
        // Flash-decode combine buffers: partial outputs plus log-sum-exp
        // (first half sum, second half max) for the final reduction.
        accumOutGm.SetGlobalBuffer((__gm__ float *)(workspace + offset));
        offset = offset + tilingData->splitKVParams.accumOutSize * sizeof(float);
        lseSumFdGm.SetGlobalBuffer((__gm__ float *)(workspace + offset));
        lseMaxFdGm.SetGlobalBuffer((__gm__ float *)(workspace + offset) + tilingData->splitKVParams.logSumExpSize / 2);
        offset = offset + tilingData->splitKVParams.logSumExpSize * sizeof(float);
    }

    // Hand the bound tensors to the vector-side service (AIV cores only).
    if ASCEND_IS_AIV {
        vectorService.InitParams(constInfo, tilingData);
        vectorService.InitMm2ResInt32GmGlobalTensor(mm2ResInt32Gm);
        if constexpr (TEMPLATE_MODE == V_TEMPLATE) {
            vectorService.InitVec0GlobalTensor(kvValidSizeGm_, kvMergeGm_, kRopeGm, keyGm, blockTableGm);
        }
        vectorService.InitVec1GlobalTensor(mm1ResGm, vec1ResGm, actualSeqLengthsQGm,
                                           actualSeqLengthsKVGm, lseMaxFdGm, lseSumFdGm, topKGm);
        vectorService.InitVec2GlobalTensor(accumOutGm, vec2ResGm, mm2ResGm, attentionOutGm);
    }

    // Hand the bound tensors to the cube-side service (AIC cores only).
    if ASCEND_IS_AIC {
        matmulService.InitParams(constInfo);
        matmulService.InitMm1GlobalTensor(queryGm, qRopeGm, keyGm, kRopeGm, mm1ResGm);
        matmulService.InitMm2GlobalTensor(vec1ResGm, valueGm, mm2ResGm, attentionOutGm);
        matmulService.InitPageAttentionInfo(kvMergeGm_, blockTableGm, topKGm,
                                            constInfo.kvCacheBlockSize, constInfo.maxBlockNumPerBatch);
    }
    if (pipe != nullptr) {
        InitBuffers();
    }
}
|
||||
|
||||
template <typename SFAT> __aicore__ inline void SparseFlashAttentionMla<SFAT>::InitCalcParamsEach()
{
    // Distribute the (batch x kvHead x s1-row) base blocks evenly across cores
    // and record this core's work range:
    //   [constInfo.bN2Start, constInfo.gS1Start] .. [constInfo.bN2End, constInfo.gS1End].
    //
    // Fixes vs. original: the first loop re-declared `actBatchS1`, shadowing
    // the outer local of the same name (now reuses it); removed the unused
    // local `s1GBaseSize`; normalized mangled spacing.
    uint32_t totalBaseNum = 0;
    uint32_t actBatchS2 = 1; // S2 is not split at this distribution level
    uint32_t coreNum = GetBlockNum();
    uint32_t currCoreIdx = aiCoreIdx;
    uint32_t actBatchS1 = 1;

    // Pass 1: count total base blocks; detect whether any batch is shorter than
    // qSeqSize (those outputs need zero-initialization).
    for (uint32_t bIdx = 0; bIdx < constInfo.batchSize; bIdx++) {
        actBatchS1 = GetBalanceActualSeqLengths(actualSeqLengthsQGm, bIdx);
        if (actBatchS1 < constInfo.qSeqSize) {
            constInfo.needInit = true;
        }
        totalBaseNum += actBatchS1 * actBatchS2;
    }

    // Blocks per core (ceil); with fewer blocks than cores, shrink usedCoreNum.
    uint32_t avgBaseNum = 1;
    if (totalBaseNum > coreNum) {
        avgBaseNum = (totalBaseNum + coreNum - 1) / coreNum;
    } else {
        usedCoreNum = totalBaseNum;
    }
    if (aiCoreIdx >= usedCoreNum) {
        return; // this core has no work
    }

    uint32_t accumBaseNum = 0;
    uint32_t targetBaseNum = 0;
    uint32_t lastValidBIdx = 0;
    uint32_t lastValidactBatchS1 = 0;
    bool setStart = false;
    // This core owns blocks (targetStartBaseNum, targetBaseNum].
    targetBaseNum = (currCoreIdx + 1) * avgBaseNum;
    uint32_t targetStartBaseNum = targetBaseNum - avgBaseNum;

    // Pass 2: walk blocks in (bN2, s1) order to locate this core's start/end.
    for (uint32_t bN2Idx = 0; bN2Idx < constInfo.batchSize * constInfo.kvHeadNum; bN2Idx++) {
        uint32_t bIdx = bN2Idx / constInfo.kvHeadNum;
        actBatchS1 = GetBalanceActualSeqLengths(actualSeqLengthsQGm, bIdx);
        for (uint32_t s1GIdx = 0; s1GIdx < actBatchS1; s1GIdx++) {
            accumBaseNum += 1;
            if (!setStart && accumBaseNum >= targetStartBaseNum) {
                constInfo.bN2Start = bN2Idx;
                constInfo.gS1Start = s1GIdx;
                setStart = true;
            }
            if (accumBaseNum >= targetBaseNum) {
                constInfo.bN2End = bN2Idx;
                constInfo.gS1End = s1GIdx;
                constInfo.s2End = 0;
                constInfo.coreStartKVSplitPos = 0;
                if (aiCoreIdx != 0) {
                    // Shift start one block past the previous core's end.
                    GetAxisStartIdx(constInfo.bN2Start, constInfo.gS1Start, 0);
                }
                return;
            }
        }
        if ((actBatchS1 > 0) && (actBatchS2 > 0)) {
            lastValidBIdx = bIdx;
            lastValidactBatchS1 = actBatchS1;
        }
    }

    // Fallback for the final core(s): clamp to the last non-empty batch.
    if (!setStart) {
        constInfo.bN2Start = lastValidBIdx;
        constInfo.gS1Start = lastValidactBatchS1 - 1;
    }
    if (accumBaseNum < targetBaseNum) {
        constInfo.bN2End = lastValidBIdx;
        constInfo.gS1End = lastValidactBatchS1 - 1;
        constInfo.s2End = 0;
        constInfo.coreStartKVSplitPos = 0;
        if (aiCoreIdx != 0) {
            GetAxisStartIdx(constInfo.bN2Start, constInfo.gS1Start, 0);
        }
        return;
    }
}
|
||||
|
||||
template <typename SFAT>
__aicore__ inline void
SparseFlashAttentionMla<SFAT>::Bmm2DataCopyOut(uint64_t attenOutOffset, LocalTensor<OUT_T> &attenOutUb,
                                               uint32_t startRow, uint32_t dealRowCount,
                                               uint32_t columnCount, uint32_t actualColumnCount)
{
    // Copy dealRowCount rows of the bmm2 result from unified buffer to the
    // attention output in global memory, dropping the per-row padding
    // (columnCount - actualColumnCount elements) the UB layout carries.
    DataCopyExtParams dataCopyParams;
    dataCopyParams.blockCount = dealRowCount;
    dataCopyParams.blockLen = actualColumnCount * sizeof(OUT_T); // payload bytes per row
    // Source stride skips the row padding, expressed in 32-byte UB blocks.
    dataCopyParams.srcStride = (columnCount - actualColumnCount) / (SFAVectorService<SFAT>::BYTE_BLOCK / sizeof(OUT_T));
    dataCopyParams.dstStride = 0; // destination rows are densely packed
    // mSizeVStart shifts to this vector core's half of the m rows.
    DataCopyPad(attentionOutGm[attenOutOffset + (mSizeVStart + startRow) * actualColumnCount], attenOutUb,
                dataCopyParams);
}
|
||||
|
||||
|
||||
template <typename SFAT>
__aicore__ inline void SparseFlashAttentionMla<SFAT>::CalcParams(uint32_t loop, uint64_t s2Start,
                                                                 uint32_t s2LoopIdx, RunInfo &info)
{
    // Fill one RunInfo slot of the preload pipeline for iteration
    // (bIdx, gS1Idx, s2LoopIdx): loop flags, m-dimension split between the two
    // vector cores, and all global-memory offsets. Offsets that are constant
    // across the inner s2 loop are cached in class members on the first s2
    // iteration and copied into info on every call.
    info.loop = loop;
    info.bIdx = tempLoopInfo.bIdx;
    info.gS1Idx = tempLoopInfo.gS1Idx;
    info.s2Idx = s2LoopIdx;
    info.curSInnerLoopTimes = tempLoopInfo.s2LoopTimes;

    info.tndIsS2SplitCore = tempLoopInfo.tndIsS2SplitCore;
    info.tndCoreStartKVSplitPos = tempLoopInfo.tndCoreStartKVSplitPos;
    info.isBmm2Output = false;

    info.actS1Size = tempLoopInfo.actS1Size;

    // Use the tail m size when this is the last (possibly partial) m block.
    info.actMBaseSize = constInfo.mBaseSize;
    uint32_t remainedGS1Size = tempLoopInfo.actS1Size * constInfo.gSize - tempLoopInfo.gS1Idx;
    if (remainedGS1Size <= constInfo.mBaseSize && remainedGS1Size > 0) {
        info.actMBaseSize = tempLoopInfo.mBasicSizeTail;
    }

    // Iterations beyond s2LoopTimes are drain-only pipeline steps.
    info.isValid = s2LoopIdx < tempLoopInfo.s2LoopTimes;

    if ASCEND_IS_AIV {
        // Split the m rows between the two vector cores of this pair: the even
        // core takes the first (16-aligned) half, the odd core the remainder.
        info.mSize = info.actMBaseSize;
        info.mSizeV = (info.mSize <= 16) ? info.mSize : (((info.mSize + 15) / 16 + 1) / 2 * 16);
        info.mSizeVStart = 0;
        if (tmpBlockIdx % 2 == 1) {
            info.mSizeVStart = info.mSizeV;
            info.mSizeV = info.mSize - info.mSizeV;
        }
    }

    info.isChangeBatch = false;

    info.isFirstSInnerLoop = s2LoopIdx == s2Start;
    if (info.isFirstSInnerLoop) {
        tempLoopInfo.bn2IdxInCurCore++;
    }
    info.isLastS2Loop = s2LoopIdx == tempLoopInfo.s2LoopTimes - 1;
    info.bn2IdxInCurCore = tempLoopInfo.bn2IdxInCurCore - 1;
    // Token-prefix sums give the batch base offsets; non-TND layouts use the
    // fixed per-batch sequence sizes instead.
    uint64_t actualSeqQPrefixSum;
    if constexpr (LAYOUT_T == SFA_LAYOUT::TND) {
        actualSeqQPrefixSum = (info.bIdx <= 0) ? 0 : actualSeqLengthsQGm.GetValue(info.bIdx - 1);
    } else {
        actualSeqQPrefixSum = (info.bIdx <= 0) ? 0 : info.bIdx * constInfo.qSeqSize;
    }
    info.tndBIdxOffsetForQ = actualSeqQPrefixSum * constInfo.qHeadNum * headDim;

    uint64_t actualSeqKVPrefixSum;
    if constexpr (KV_LAYOUT_T == SFA_LAYOUT::TND) {
        actualSeqKVPrefixSum = (info.bIdx <= 0) ? 0 : actualSeqLengthsKVGm.GetValue(info.bIdx - 1);
    } else {
        actualSeqKVPrefixSum = (info.bIdx <= 0) ? 0 : info.bIdx * constInfo.kvSeqSize;
    }
    info.tndBIdxOffsetForKV = actualSeqKVPrefixSum * constInfo.kvHeadNum * headDim;

    if (info.isFirstSInnerLoop) {
        // Cache the per-(b, n2, gS1) offsets once per inner loop.
        uint64_t tndBIdxRopeOffsetForQ = actualSeqQPrefixSum * constInfo.qHeadNum * headDimRope;
        tensorACoreOffset = info.tndBIdxOffsetForQ + info.gS1Idx * headDim;
        tensorARopeCoreOffset = tndBIdxRopeOffsetForQ + info.gS1Idx * headDimRope;

        uint64_t tndBIdxRopeOffsetForK = actualSeqKVPrefixSum * constInfo.kvHeadNum * headDimRope;
        tensorBCoreOffset = info.tndBIdxOffsetForKV + info.n2Idx * headDim;
        tensorBRopeCoreOffset = tndBIdxRopeOffsetForK + info.n2Idx * headDimRope;
        // Token threshold: causal bound in sparseMode 3, full KV length otherwise.
        if (constInfo.sparseMode == 3) {
            threshold = static_cast<int64_t>(tempLoopInfo.nextTokensPerBatch) + info.gS1Idx / constInfo.gSize + 1;
        } else {
            threshold = tempLoopInfo.curActualSeqLenOri;
        }
        // top-k indices base offset — layout-dependent index order.
        if constexpr(LAYOUT_T == SFA_LAYOUT::BSND) { // B,S1,N2 K
            topKBaseOffset = info.bIdx * constInfo.qSeqSize * constInfo.kvHeadNum * constInfo.sparseBlockCount +
                             info.gS1Idx / constInfo.gSize * constInfo.kvHeadNum * constInfo.sparseBlockCount +
                             info.n2Idx * constInfo.sparseBlockCount;
        } else if (LAYOUT_T == SFA_LAYOUT::TND) { // T N2 K
            topKBaseOffset = info.tndBIdxOffsetForQ / constInfo.gSize / constInfo.headDim * constInfo.kvHeadNum *
                             constInfo.sparseBlockCount + info.n2Idx * constInfo.sparseBlockCount +
                             info.gS1Idx / constInfo.gSize * constInfo.kvHeadNum * constInfo.sparseBlockCount;
        } else { // B N2 S1 K
            topKBaseOffset = info.bIdx * constInfo.kvHeadNum * constInfo.qSeqSize * constInfo.sparseBlockCount +
                             info.n2Idx * constInfo.qSeqSize * constInfo.sparseBlockCount +
                             info.gS1Idx / constInfo.gSize * constInfo.sparseBlockCount;
        }
    }
    info.topKBaseOffset = topKBaseOffset;
    info.threshold = threshold;
    info.tensorAOffset = tensorACoreOffset;
    info.tensorARopeOffset = tensorARopeCoreOffset;
    info.tensorBOffset = tensorBCoreOffset;
    info.tensorBRopeOffset = tensorBRopeCoreOffset;
    info.attenOutOffset = tensorACoreOffset;

    uint64_t sInnerOffsetDataSize = info.s2Idx * constInfo.s2BaseSize;
    info.s2BatchOffset = s2BatchBaseOffset + sInnerOffsetDataSize;

    info.curActualSeqLenOri = tempLoopInfo.curActualSeqLenOri;
    if constexpr (TEMPLATE_MODE == V_TEMPLATE) {
        // Remaining s2 tokens for this tile, capped at s2BaseSize and rounded
        // up to whole sparse blocks.
        if (tempLoopInfo.curActualSeqLen > sInnerOffsetDataSize) {
            info.actualSingleProcessSInnerSize = tempLoopInfo.curActualSeqLen - sInnerOffsetDataSize;
            info.actualSingleProcessSInnerSize = info.actualSingleProcessSInnerSize > constInfo.s2BaseSize ?
                                                 constInfo.s2BaseSize : info.actualSingleProcessSInnerSize;
            info.actualSingleProcessSInnerSize =
                SFAAlign((int64_t)info.actualSingleProcessSInnerSize, (int64_t)constInfo.sparseBlockSize);
        } else {
            info.actualSingleProcessSInnerSize = 0;
        }
        info.actualSingleProcessSInnerSizeAlign =
            SFAAlign((uint32_t)info.actualSingleProcessSInnerSize, (uint32_t)SFAVectorService<SFAT>::BYTE_BLOCK);
    }

}
|
||||
|
||||
template <typename SFAT>
__aicore__ inline void SparseFlashAttentionMla<SFAT>::ComputeMm1(const RunInfo &info)
{
    // Q*K^T for one pipeline step, split along m into nBufferMBaseSize chunks;
    // after each chunk, signal the vector cores (C1 -> V1) that its mm1 result
    // tile is ready.
    uint32_t nBufferLoopTimes = (info.actMBaseSize + constInfo.nBufferMBaseSize - 1) / constInfo.nBufferMBaseSize;
    uint32_t nBufferTail = info.actMBaseSize - (nBufferLoopTimes - 1) * constInfo.nBufferMBaseSize;
    for (uint32_t i = 0; i < nBufferLoopTimes; i++) {
        MSplitInfo mSplitInfo;
        mSplitInfo.nBufferStartM = i * constInfo.nBufferMBaseSize;
        // Last chunk may be partial.
        mSplitInfo.nBufferDealM = (i + 1 != nBufferLoopTimes) ? constInfo.nBufferMBaseSize : nBufferTail;
        matmulService.ComputeMm1(info, mSplitInfo);
        CrossCoreSetFlag<ConstInfo::SFA_SYNC_MODE2, PIPE_FIX>(constInfo.syncC1V1);
    }
}
|
||||
|
||||
template <typename SFAT>
__aicore__ inline void SparseFlashAttentionMla<SFAT>::ComputeMm2(const RunInfo &info)
{
    // P*V for one pipeline step, split along m like ComputeMm1. Per chunk:
    // wait for the vector cores' softmax result (V1 -> C2), run the matmul,
    // then signal both downstream (C2 -> V2) and back upstream (C2 -> V1).
    uint32_t nBufferLoopTimes = (info.actMBaseSize + constInfo.nBufferMBaseSize - 1) / constInfo.nBufferMBaseSize;
    uint32_t nBufferTail = info.actMBaseSize - (nBufferLoopTimes - 1) * constInfo.nBufferMBaseSize;
    for (uint32_t i = 0; i < nBufferLoopTimes; i++) {
        MSplitInfo mSplitInfo;
        mSplitInfo.nBufferStartM = i * constInfo.nBufferMBaseSize;
        mSplitInfo.nBufferDealM = (i + 1 != nBufferLoopTimes) ? constInfo.nBufferMBaseSize : nBufferTail;
        CrossCoreWaitFlag(constInfo.syncV1C2);
        matmulService.ComputeMm2(info, mSplitInfo);
        CrossCoreSetFlag<ConstInfo::SFA_SYNC_MODE2, PIPE_FIX>(constInfo.syncC2V2);
        CrossCoreSetFlag<ConstInfo::SFA_SYNC_MODE2, PIPE_FIX>(constInfo.syncC2V1);
    }
}
|
||||
|
||||
template <typename SFAT> __aicore__ inline void SparseFlashAttentionMla<SFAT>::Process()
{
    // Kernel main loop entry: cores beyond usedCoreNum idle. Event IDs are
    // allocated per core type around the balanced processing loop.
    if (aiCoreIdx < usedCoreNum) {
        if ASCEND_IS_AIV {
            vectorService.AllocEventID();
            vectorService.InitSoftmaxDefaultBuffer();
        } else {
            matmulService.AllocEventID();
        }
        ProcessBalance();

        if ASCEND_IS_AIV {
            vectorService.FreeEventID();
        } else {
            matmulService.FreeEventID();
        }
    }
}
|
||||
|
||||
template <typename SFAT>
|
||||
__aicore__ inline void SparseFlashAttentionMla<SFAT>::GetBN2Idx(uint32_t bN2Idx, uint32_t &bIdx,
|
||||
uint32_t &n2Idx)
|
||||
{
|
||||
bIdx = bN2Idx / kvHeadNum;
|
||||
n2Idx = bN2Idx % kvHeadNum;
|
||||
}
|
||||
|
||||
template <typename SFAT> __aicore__ inline void SparseFlashAttentionMla<SFAT>::ProcessBalance()
{
    // Walk this core's assigned (bN2, gS1, s2) range, feeding each step into
    // the preload pipeline. extraInfo is a ring of in-flight RunInfo slots.
    RunInfo extraInfo[SFA_PRELOAD_TASK_CACHE_SIZE];
    uint32_t gloop = 0;
    int gS1LoopEnd;
    bool globalLoopStart = true;
    // Prime the cross-core handshakes so the first waits do not deadlock.
    if ASCEND_IS_AIC {
        CrossCoreSetFlag<ConstInfo::SFA_SYNC_MODE2, PIPE_FIX>(constInfo.syncC2V1);
        if constexpr (TEMPLATE_MODE == V_TEMPLATE) {
            // Four primes of flag 3 — presumably matches the merge-KV buffer
            // depth (4 buffers); confirm against the vector-side MergeKv waits.
            CrossCoreSetFlag<ConstInfo::SFA_SYNC_MODE2, PIPE_MTE2>(3);
            CrossCoreSetFlag<ConstInfo::SFA_SYNC_MODE2, PIPE_MTE2>(3);
            CrossCoreSetFlag<ConstInfo::SFA_SYNC_MODE2, PIPE_MTE2>(3);
            CrossCoreSetFlag<ConstInfo::SFA_SYNC_MODE2, PIPE_MTE2>(3);
        }
    }
    for (uint32_t bN2LoopIdx = constInfo.bN2Start; bN2LoopIdx <= constInfo.bN2End; bN2LoopIdx++) {
        GetBN2Idx(bN2LoopIdx, tempLoopInfo.bIdx, tempLoopInfo.n2Idx);
        GetActualSeqLen(tempLoopInfo.bIdx);
        GetPreNextTokensLeftUp();
        if (tempLoopInfo.actS1Size == 0) {
            continue; // empty batch: no query rows
        }
        int gS1SplitNum = (tempLoopInfo.actS1Size * constInfo.gSize + constInfo.mBaseSize - 1) / constInfo.mBaseSize;
        // On the final bN2 the range is bounded by this core's gS1End.
        gS1LoopEnd = (bN2LoopIdx == constInfo.bN2End) ? constInfo.gS1End : gS1SplitNum - 1;
        for (uint32_t gS1LoopIdx = constInfo.gS1Start; gS1LoopIdx <= gS1LoopEnd; gS1LoopIdx++) {
            tempLoopInfo.gS1Idx = gS1LoopIdx * constInfo.mBaseSize;
            GetSparseActualSeqLen(tempLoopInfo.bIdx, gS1LoopIdx, tempLoopInfo.n2Idx);
            UpdateInnerLoopCond();

            if (tempLoopInfo.curActSeqLenIsZero) {
                // No KV to attend to: zero the output rows for this step.
                DealActSeqLenIsZero(tempLoopInfo.bIdx, gS1LoopIdx, tempLoopInfo.n2Idx);
            }
            int s2SplitNum =
                (tempLoopInfo.curActualSeqLen + constInfo.s2BaseSize - 1) / constInfo.s2BaseSize;
            bool isEnd = (bN2LoopIdx == constInfo.bN2End) && (gS1LoopIdx == constInfo.gS1End);
            tempLoopInfo.s2LoopTimes = s2SplitNum;
            tempLoopInfo.tndIsS2SplitCore =
                ((constInfo.s2Start == 0) && (tempLoopInfo.s2LoopTimes == s2SplitNum)) ? false : true;
            tempLoopInfo.tndCoreStartKVSplitPos = globalLoopStart ? constInfo.coreStartKVSplitPos : 0;
            // Two extra drain iterations at the very end flush the pipeline's
            // in-flight stages (matches the 3-deep pipeline in PreloadPipeline).
            uint32_t extraLoop = isEnd ? 2 : 0;

            uint32_t curTopKIdx = 0;
            uint64_t curOffsetInSparseBlock = 0;
            for (int s2LoopIdx = constInfo.s2Start; s2LoopIdx < (tempLoopInfo.s2LoopTimes + extraLoop); s2LoopIdx++) {
                PreloadPipeline(gloop, constInfo.s2Start, s2LoopIdx, extraInfo, curTopKIdx, curOffsetInSparseBlock);
                ++gloop;
            }
            globalLoopStart = false;
            constInfo.s2Start = 0; // only the first step can start mid-s2
        }
        constInfo.gS1Start = 0; // only the first bN2 can start mid-gS1
    }
    // Drain the primed handshakes so flags are balanced for the next launch.
    if ASCEND_IS_AIV {
        CrossCoreWaitFlag(constInfo.syncC2V1);
        if constexpr (TEMPLATE_MODE == V_TEMPLATE) {
            CrossCoreWaitFlag(3);
            CrossCoreWaitFlag(3);
            CrossCoreWaitFlag(3);
            CrossCoreWaitFlag(3);
        }
    }
}
|
||||
|
||||
template <typename SFAT>
__aicore__ inline void
SparseFlashAttentionMla<SFAT>::PreloadPipeline(uint32_t loop, uint64_t s2Start, uint64_t s2LoopIdx,
                                               RunInfo extraInfo[SFA_PRELOAD_TASK_CACHE_SIZE], uint32_t &curTopKIdx, uint64_t &curOffsetInSparseBlock)
{
    // One advance of the 3-stage software pipeline over the RunInfo ring:
    //   slot(loop)     — newest step: params + mm1 (cube) / KV merge (vector)
    //   slot(loop + 2) — mid step:    vec1 (softmax, vector) / mm2 (cube)
    //   slot(loop + 1) — oldest step: vec2 (output update, vector)
    RunInfo &extraInfo0 = extraInfo[loop % SFA_PRELOAD_TASK_CACHE_SIZE];
    RunInfo &extraInfo2 = extraInfo[(loop + 2) % SFA_PRELOAD_TASK_CACHE_SIZE];
    RunInfo &extraInfo1 = extraInfo[(loop + 1) % SFA_PRELOAD_TASK_CACHE_SIZE];

    CalcParams(loop, s2Start, s2LoopIdx, extraInfo0);
    CalcSinnerTopKBegin(extraInfo0, curTopKIdx, curOffsetInSparseBlock);

    if (extraInfo0.isValid) {
        if ASCEND_IS_AIC {
            if constexpr (TEMPLATE_MODE == V_TEMPLATE) {
                // Wait for the vector cores to finish staging the merged KV.
                CrossCoreWaitFlag(constInfo.syncV0C1);
            }
            ComputeMm1(extraInfo0);
        } else {
            if constexpr (TEMPLATE_MODE == V_TEMPLATE) {
                // Acquire a merge buffer (flag 3), gather selected KV, release to cube.
                CrossCoreWaitFlag(3);
                vectorService.MergeKv(extraInfo0);
                CrossCoreSetFlag<ConstInfo::SFA_SYNC_MODE2, PIPE_MTE3>(constInfo.syncV0C1);
            }
        }
    }
    if (extraInfo2.isValid) {
        if ASCEND_IS_AIV {
            vectorService.ProcessVec1L(extraInfo2);
        }
        if ASCEND_IS_AIC {
            ComputeMm2(extraInfo2);
            if constexpr (TEMPLATE_MODE == V_TEMPLATE) {
                // Return the merge buffer consumed by this step.
                CrossCoreSetFlag<ConstInfo::SFA_SYNC_MODE2, PIPE_MTE2>(3);
            }
        }
    }
    if (extraInfo1.isValid) {
        if ASCEND_IS_AIV {
            vectorService.ProcessVec2L(extraInfo1);
        }
        // Retire the slot so drain iterations skip it.
        extraInfo1.isValid = false;
    }
}
|
||||
|
||||
template <typename SFAT>
|
||||
__aicore__ inline uint64_t
|
||||
SparseFlashAttentionMla<SFAT>::GetBalanceActualSeqLengths(GlobalTensor<int32_t> &actualSeqLengths,
|
||||
uint32_t bIdx)
|
||||
{
|
||||
if constexpr (LAYOUT_T == SFA_LAYOUT::TND) {
|
||||
if (bIdx > 0) {
|
||||
return actualSeqLengths.GetValue(bIdx) - actualSeqLengths.GetValue(bIdx - 1);
|
||||
} else if (bIdx == 0) {
|
||||
return actualSeqLengths.GetValue(0);
|
||||
} else {
|
||||
return 0;
|
||||
}
|
||||
} else {
|
||||
if (constInfo.actualLenDimsQ == 0) {
|
||||
return constInfo.qSeqSize;
|
||||
} else if (constInfo.actualLenDimsQ == 1) {
|
||||
return actualSeqLengths.GetValue(0);
|
||||
} else {
|
||||
return actualSeqLengths.GetValue(bIdx);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
template <typename SFAT>
|
||||
__aicore__ inline void SparseFlashAttentionMla<SFAT>::GetAxisStartIdx(uint32_t bN2EndPrev,
|
||||
uint32_t s1GEndPrev,
|
||||
uint32_t s2EndPrev)
|
||||
{
|
||||
uint32_t bEndPrev = bN2EndPrev / kvHeadNum;
|
||||
uint32_t actualSeqQPrev = GetBalanceActualSeqLengths(actualSeqLengthsQGm, bEndPrev);
|
||||
uint32_t s1GPrevBaseNum = (actualSeqQPrev * constInfo.gSize + constInfo.mBaseSize - 1) / constInfo.mBaseSize;
|
||||
constInfo.bN2Start = bN2EndPrev;
|
||||
constInfo.gS1Start = s1GEndPrev;
|
||||
|
||||
constInfo.s2Start = 0;
|
||||
if (s1GEndPrev >= s1GPrevBaseNum - 1) {
|
||||
constInfo.gS1Start = 0;
|
||||
constInfo.bN2Start++;
|
||||
} else {
|
||||
constInfo.gS1Start++;
|
||||
}
|
||||
}
|
||||
|
||||
// Computes how many S-inner elements the next inner-loop step will process by
// walking the top-k sparse-block index list from the caller-maintained cursor
// (curTopKIdx, curOffsetInSparseBlock), and advances that cursor past the
// consumed blocks. Writes the resulting sizes into `info`
// (actualSingleProcessSInnerSize / ...Align, curTopKIdx, curOffsetInSparseBlock)
// and into tempLoopInfo.s2BasicSizeTail. Only the C template runs this; the
// V template returns immediately.
template <typename SFAT>
__aicore__ inline void SparseFlashAttentionMla<SFAT>::CalcSinnerTopKBegin(RunInfo &info, uint32_t &curTopKIdx, uint64_t &curOffsetInSparseBlock)

{
    if constexpr (TEMPLATE_MODE == V_TEMPLATE) {
        return;
    }

    // Number of sparse blocks needed to cover [0, threshold), rounded up;
    // never consume more top-k entries than that (or than sparseBlockCount).
    uint64_t thresholdSparseCount = (info.threshold + constInfo.sparseBlockSize - 1) / constInfo.sparseBlockSize;
    uint64_t validCount = (constInfo.sparseBlockCount > thresholdSparseCount) ? thresholdSparseCount : constInfo.sparseBlockCount;

    // A -1 entry marks the end of the valid top-k indices for this row.
    int32_t sparseIndices = topKGm.GetValue(info.topKBaseOffset + curTopKIdx);
    if (sparseIndices == -1 || curTopKIdx == validCount) {
        // Nothing left to process from this cursor position.
        info.actualSingleProcessSInnerSize = 0;
        info.actualSingleProcessSInnerSizeAlign = 0;
        tempLoopInfo.s2BasicSizeTail = 0;
        if (curTopKIdx == 0) {
            // Cursor never advanced: the whole row is empty -> emit zero output.
            DealActSeqLenIsZero(info.bIdx, info.gS1Idx / constInfo.gSize, tempLoopInfo.n2Idx);
        }
        return;
    }

    uint32_t sparseLen = 0;
    // First (possibly partially consumed) block: clamp its end to threshold
    // and count only the part past curOffsetInSparseBlock.
    uint64_t blockBegin = sparseIndices * constInfo.sparseBlockSize;
    uint64_t blockEnd = (blockBegin + constInfo.sparseBlockSize > info.threshold) ? info.threshold : blockBegin + constInfo.sparseBlockSize;
    int32_t blockLen = blockEnd - blockBegin;
    sparseLen += (blockLen > static_cast<int32_t>(curOffsetInSparseBlock)) ? blockLen - curOffsetInSparseBlock : 0;

    // Record in `info` where this step starts reading. firstVaildFlag [sic]
    // tracks whether a first valid block has been committed when starting
    // from index 0.
    bool firstVaildFlag = false;
    if (curTopKIdx > 0) {
        info.curTopKIdx = curTopKIdx;
        info.curOffsetInSparseBlock = curOffsetInSparseBlock;
    } else if (curTopKIdx == 0 && sparseLen > 0) {
        info.curTopKIdx = curTopKIdx;
        info.curOffsetInSparseBlock = 0;
        firstVaildFlag = true;
    }

    // Accumulate further whole blocks until s2BaseSize elements are gathered,
    // a -1 terminator is hit, or the valid entries are exhausted.
    for (uint64_t topkIdx = curTopKIdx + 1; topkIdx < validCount; topkIdx++) {
        // NOTE: these locals intentionally shadow the pre-loop ones.
        int32_t sparseIndices = topKGm.GetValue(info.topKBaseOffset + topkIdx);
        if (sparseIndices == -1) {
            // Terminator: park the cursor here so the next call bails out.
            curTopKIdx = topkIdx;
            curOffsetInSparseBlock = 0;
            break;
        }
        uint64_t blockBegin = sparseIndices * constInfo.sparseBlockSize;
        if (blockBegin >= info.threshold) {
            // Block lies entirely past the causal threshold: skip it.
            continue;
        }
        if (firstVaildFlag == false && curTopKIdx == 0) {
            // Starting from 0 and the first block contributed nothing:
            // this is the first valid block for `info`.
            info.curTopKIdx = topkIdx;
            info.curOffsetInSparseBlock = 0;
            firstVaildFlag = true;
        }
        uint64_t blockEnd = (blockBegin + constInfo.sparseBlockSize > info.threshold) ? info.threshold : blockBegin + constInfo.sparseBlockSize;
        uint64_t blockLen = blockEnd - blockBegin;
        sparseLen += blockLen;
        if (sparseLen >= constInfo.s2BaseSize) {
            // Basic tile filled: remember how much of this block was consumed
            // so the next call resumes mid-block, and cap the tile length.
            curTopKIdx = topkIdx;
            curOffsetInSparseBlock = blockLen - (sparseLen - constInfo.s2BaseSize);
            sparseLen = constInfo.s2BaseSize;
            break;
        }

        if (topkIdx == validCount - 1) {
            // Ran off the end of the valid entries: mark the cursor exhausted.
            curTopKIdx = validCount;
            curOffsetInSparseBlock = 0;
        }
    }

    info.actualSingleProcessSInnerSize = sparseLen;
    info.actualSingleProcessSInnerSizeAlign = SFAAlign((uint32_t)info.actualSingleProcessSInnerSize, (uint32_t)SFAVectorService<SFAT>::BYTE_BLOCK);
    // A full tile has no tail; a short tile's length is the tail size.
    tempLoopInfo.s2BasicSizeTail = (sparseLen == constInfo.s2BaseSize) ? 0 : sparseLen;
    if (curTopKIdx == 0 && sparseLen == 0) {
        DealActSeqLenIsZero(info.bIdx, info.gS1Idx / constInfo.gSize, tempLoopInfo.n2Idx);
    }
}
|
||||
#endif // SPARSE_FLASH_ATTENTION_KERNEL_MLA_H
|
||||
File diff suppressed because it is too large
Load Diff
File diff suppressed because it is too large
Load Diff
@@ -0,0 +1,54 @@
|
||||
/**
 * This program is free software, you can redistribute it and/or modify it.
 * Copyright (c) 2025 Huawei Technologies Co., Ltd.
 * This file is a part of the CANN Open Software.
 * Licensed under CANN Open Software License Agreement Version 2.0 (the "License").
 * Please refer to the License for details. You may not use this file except in compliance with the License.
 * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE.
 * See LICENSE in the root of the software repository for the full text of the License.
 */

/*!
 * \file sparse_flash_attention_template_tiling_key.h
 * \brief Tiling-key (template-argument) declarations for the
 *        SparseFlashAttention kernel: the tunable axes (FLASH_DECODE,
 *        query/KV layouts, cube-vs-vector template mode) and the selected
 *        combinations to instantiate.
 */

#ifndef SPARSE_FLASH_ATTENTION_TEMPLATE_TILING_KEY_H
#define SPARSE_FLASH_ATTENTION_TEMPLATE_TILING_KEY_H

#include "ascendc/host_api/tiling/template_argument.h"

// Tensor-layout encodings used as UINT template-argument values.
#define SFA_LAYOUT_BSND 0
#define SFA_LAYOUT_TND 1
// PA variant of BSND — presumably paged (block-table) KV layout; TODO confirm.
#define SFA_LAYOUT_PA_BSND 2

// Bit width allotted to each UINT axis in the packed tiling key.
#define ASCENDC_TPL_4_BW 4

// Kernel template modes: C = cube-core path, V = vector-core path.
#define C_TEMPLATE 0
#define V_TEMPLATE 1

// Declare the full template-argument space of the kernel.
ASCENDC_TPL_ARGS_DECL(SparseFlashAttention,
                      ASCENDC_TPL_BOOL_DECL(FLASH_DECODE, 0, 1),
                      ASCENDC_TPL_UINT_DECL(LAYOUT_T, ASCENDC_TPL_4_BW, ASCENDC_TPL_UI_LIST, SFA_LAYOUT_BSND, SFA_LAYOUT_TND),
                      ASCENDC_TPL_UINT_DECL(KV_LAYOUT_T, ASCENDC_TPL_4_BW, ASCENDC_TPL_UI_LIST, SFA_LAYOUT_BSND, SFA_LAYOUT_TND,
                                            SFA_LAYOUT_PA_BSND),
                      ASCENDC_TPL_UINT_DECL(TEMPLATE_MODE, ASCENDC_TPL_4_BW, ASCENDC_TPL_UI_LIST, C_TEMPLATE, V_TEMPLATE),
);

// Select the concrete combinations to compile.
ASCENDC_TPL_SEL(
    // C-template instantiations.
    ASCENDC_TPL_ARGS_SEL(
        ASCENDC_TPL_BOOL_SEL(FLASH_DECODE, 0),
        ASCENDC_TPL_UINT_SEL(LAYOUT_T, ASCENDC_TPL_UI_LIST, SFA_LAYOUT_BSND, SFA_LAYOUT_TND),
        ASCENDC_TPL_UINT_SEL(KV_LAYOUT_T, ASCENDC_TPL_UI_LIST, SFA_LAYOUT_BSND, SFA_LAYOUT_PA_BSND, SFA_LAYOUT_TND),
        ASCENDC_TPL_UINT_SEL(TEMPLATE_MODE, ASCENDC_TPL_UI_LIST, C_TEMPLATE),
    ),

    // V-template instantiations.
    ASCENDC_TPL_ARGS_SEL(
        ASCENDC_TPL_BOOL_SEL(FLASH_DECODE, 0),
        ASCENDC_TPL_UINT_SEL(LAYOUT_T, ASCENDC_TPL_UI_LIST, SFA_LAYOUT_BSND, SFA_LAYOUT_TND),
        ASCENDC_TPL_UINT_SEL(KV_LAYOUT_T, ASCENDC_TPL_UI_LIST, SFA_LAYOUT_PA_BSND, SFA_LAYOUT_BSND, SFA_LAYOUT_TND),
        // Original note: "V template does not support non-PA".
        // NOTE(review): non-PA KV layouts are nonetheless listed above — confirm intent.
        ASCENDC_TPL_UINT_SEL(TEMPLATE_MODE, ASCENDC_TPL_UI_LIST, V_TEMPLATE),
    ),
);

#endif // SPARSE_FLASH_ATTENTION_TEMPLATE_TILING_KEY_H
|
||||
@@ -620,6 +620,103 @@ void batch_matmul_transpose(const at::Tensor &tensor_a, const at::Tensor &tensor
|
||||
|
||||
}
|
||||
|
||||
// Runs the aclnnLightningIndexer custom op: selects, per query token, the
// `sparse_count` most relevant key positions and returns them as an int32
// index tensor.
//
// Output shape: [B, S, Nk, sparse_count] for BSND queries, otherwise
// [T, Nk, sparse_count], where Nk is taken from `key` at dim 1 (TND key
// layout) or dim 2 (other layouts).
at::Tensor npu_lightning_indexer(
    const at::Tensor &query, const at::Tensor &key, const at::Tensor &weights,
    const c10::optional<at::Tensor> &actual_seq_lengths_query,
    const c10::optional<at::Tensor> &actual_seq_lengths_key,
    const c10::optional<at::Tensor> &block_table, c10::string_view layout_query,
    c10::string_view layout_key, int64_t sparse_count, int64_t sparse_mode)
{
    // npu tensor max size
    constexpr int32_t SIZE = 8;
    constexpr int32_t DIM_0 = 0;
    constexpr int32_t DIM_1 = 1;
    constexpr int32_t DIM_2 = 2;

    TORCH_CHECK(query.numel() > 0, "Query is empty.");
    TORCH_CHECK(key.numel() > 0, "Key is empty.");
    TORCH_CHECK(weights.numel() > 0, "Weights is empty.");
    for (size_t i = 0; i < query.sizes().size(); i++) {
        TORCH_CHECK(query.size(i) > 0, "All values within query's shape should be greater "
                    "than 0, but shape[", i, "] is ", query.size(i));
    }
    TORCH_CHECK(sparse_count > 0, "sparse count should be greater than 0, but now is ", sparse_count);

    at::SmallVector<int64_t, SIZE> output_size;
    std::string query_layout_str = std::string(layout_query);
    std::string key_layout_str = std::string(layout_key);
    if (query_layout_str == "BSND") {
        output_size = {query.size(DIM_0), query.size(DIM_1), key.size(DIM_2), sparse_count};
    } else {
        // TND keys carry the head dim at index 1; other layouts at index 2.
        const int n_dim_index = (key_layout_str == "TND") ? DIM_1 : DIM_2;
        output_size = {query.size(DIM_0), key.size(n_dim_index), sparse_count};
    }
    at::Tensor lightning_indexer_output = at::empty(output_size, query.options().dtype(at::kInt));
    // EXEC_NPU_CMD requires mutable char*; the strings outlive the call.
    char *query_layout_ptr = const_cast<char *>(query_layout_str.c_str());
    char *key_layout_ptr = const_cast<char *>(key_layout_str.c_str());
    EXEC_NPU_CMD(
        aclnnLightningIndexer,
        query,
        key,
        weights,
        actual_seq_lengths_query,
        actual_seq_lengths_key,
        block_table,
        query_layout_ptr,
        key_layout_ptr,
        sparse_count,
        sparse_mode,
        lightning_indexer_output);
    return lightning_indexer_output;
}
|
||||
|
||||
// Runs the aclnnSparseFlashAttention custom op over the key/value positions
// chosen in `sparse_indices` and returns the attention output, which has the
// same shape and dtype as `query`.
at::Tensor npu_sparse_flash_attention(
    const at::Tensor &query, const at::Tensor &key, const at::Tensor &value,
    const at::Tensor &sparse_indices, double scale_value, int64_t sparse_block_size,
    const c10::optional<at::Tensor> &block_table,
    const c10::optional<at::Tensor> &actual_seq_lengths_query,
    const c10::optional<at::Tensor> &actual_seq_lengths_kv,
    const c10::optional<at::Tensor> &query_rope,
    const c10::optional<at::Tensor> &key_rope, c10::string_view layout_query,
    c10::string_view layout_kv,
    int64_t sparse_mode)
{
    std::string layout_query_str = std::string(layout_query);
    std::string layout_kv_str = std::string(layout_kv);

    for (size_t i = 0; i < query.sizes().size(); i++) {
        TORCH_CHECK(query.size(i) > 0, "All values within query's shape should be greater "
                    "than 0, but shape[", i, "] is ", query.size(i));
    }
    // Output mirrors the query; options() already carries query's dtype, so
    // no explicit .dtype(...) is needed.
    at::Tensor output = at::empty(query.sizes(), query.options());
    // EXEC_NPU_CMD requires mutable char*; the strings outlive the call.
    char *layout_query_ptr = const_cast<char *>(layout_query_str.c_str());
    char *layout_kv_ptr = const_cast<char *>(layout_kv_str.c_str());

    EXEC_NPU_CMD(
        aclnnSparseFlashAttention,
        query,
        key,
        value,
        sparse_indices,
        block_table,
        actual_seq_lengths_query,
        actual_seq_lengths_kv,
        query_rope,
        key_rope,
        scale_value,
        sparse_block_size,
        layout_query_ptr,
        layout_kv_ptr,
        sparse_mode,
        output);
    return output;
}
|
||||
|
||||
} // namespace vllm_ascend
|
||||
|
||||
TORCH_LIBRARY_EXPAND(CONCAT(_C, _ascend), ops)
|
||||
@@ -695,4 +792,22 @@ TORCH_LIBRARY_EXPAND(CONCAT(_C, _ascend), ops)
|
||||
" (Tensor output, Tensor output_scale, Tensor output_offset)"
|
||||
);
|
||||
ops.impl("grouped_matmul_swiglu_quant_weight_nz_tensor_list", torch::kPrivateUse1, &vllm_ascend::grouped_matmul_swiglu_quant_weight_nz_tensor_list);
|
||||
|
||||
ops.def(
|
||||
"npu_lightning_indexer(Tensor query, Tensor key, Tensor weights, *,"
|
||||
" Tensor? actual_seq_lengths_query=None, Tensor? actual_seq_lengths_key=None,"
|
||||
" Tensor? block_table=None, str layout_query='BSND', str layout_key='BSND',"
|
||||
" int sparse_count=2048, int sparse_mode=3) -> Tensor"
|
||||
);
|
||||
ops.impl("npu_lightning_indexer", torch::kPrivateUse1, &vllm_ascend::npu_lightning_indexer);
|
||||
|
||||
ops.def(
|
||||
"npu_sparse_flash_attention(Tensor query, Tensor key, Tensor value,"
|
||||
" Tensor sparse_indices, float scale_value, int sparse_block_size, *,"
|
||||
" Tensor? block_table=None, Tensor? actual_seq_lengths_query=None,"
|
||||
" Tensor? actual_seq_lengths_kv=None, Tensor? query_rope=None,"
|
||||
" Tensor? key_rope=None, str layout_query='BSND', str layout_kv='BSND',"
|
||||
" int sparse_mode=3) -> Tensor"
|
||||
);
|
||||
ops.impl("npu_sparse_flash_attention", torch::kPrivateUse1, &vllm_ascend::npu_sparse_flash_attention);
|
||||
}
|
||||
|
||||
@@ -159,6 +159,64 @@ void batch_matmul_transpose(const at::Tensor &tensor_a, const at::Tensor &tensor
|
||||
|
||||
}
|
||||
|
||||
// Meta (shape-inference) implementation of npu_lightning_indexer: performs
// the same validation and output-shape computation as the eager kernel but
// only allocates an empty int32 tensor — no device work.
at::Tensor npu_lightning_indexer_meta(
    const at::Tensor &query, const at::Tensor &key, const at::Tensor &weights,
    const c10::optional<at::Tensor> &actual_seq_lengths_query,
    const c10::optional<at::Tensor> &actual_seq_lengths_key,
    const c10::optional<at::Tensor> &block_table, c10::string_view layout_query,
    c10::string_view layout_key, int64_t sparse_count, int64_t sparse_mode)
{
    // npu tensor max size
    constexpr int32_t SIZE = 8;
    constexpr int32_t DIM_0 = 0;
    constexpr int32_t DIM_1 = 1;
    constexpr int32_t DIM_2 = 2;

    TORCH_CHECK(query.numel() > 0, "Query is empty.");
    TORCH_CHECK(key.numel() > 0, "Key is empty.");
    TORCH_CHECK(weights.numel() > 0, "Weights is empty.");
    for (size_t i = 0; i < query.sizes().size(); i++) {
        TORCH_CHECK(query.size(i) > 0, "All values within query's shape should be greater "
                    "than 0, but shape[", i, "] is ", query.size(i));
    }
    TORCH_CHECK(sparse_count > 0, "sparse count should be greater than 0, but now is ", sparse_count);

    std::string query_layout_str = std::string(layout_query);
    std::string key_layout_str = std::string(layout_key);
    at::SmallVector<int64_t, SIZE> output_size;
    if (query_layout_str == "BSND") {
        output_size = {query.size(DIM_0), query.size(DIM_1), key.size(DIM_2), sparse_count};
    } else {
        // TND keys carry the head dim at index 1; other layouts at index 2.
        // Must stay in sync with npu_lightning_indexer's shape logic.
        const int n_dim_index = (key_layout_str == "TND") ? DIM_1 : DIM_2;
        output_size = {query.size(DIM_0), key.size(n_dim_index), sparse_count};
    }
    // construct the output tensor
    at::Tensor lightning_indexer_output = at::empty(output_size, query.options().dtype(at::kInt));
    return lightning_indexer_output;
}
|
||||
|
||||
// Meta (shape-inference) implementation of npu_sparse_flash_attention:
// the output has exactly the query's shape and dtype, so only an empty
// tensor is allocated — no device work.
at::Tensor npu_sparse_flash_attention_meta(
    const at::Tensor &query, const at::Tensor &key, const at::Tensor &value,
    const at::Tensor &sparse_indices, double scale_value, int64_t sparse_block_size,
    const c10::optional<at::Tensor> &block_table,
    const c10::optional<at::Tensor> &actual_seq_lengths_query,
    const c10::optional<at::Tensor> &actual_seq_lengths_kv,
    const c10::optional<at::Tensor> &query_rope,
    const c10::optional<at::Tensor> &key_rope, c10::string_view layout_query,
    c10::string_view layout_kv,
    int64_t sparse_mode)
{
    for (size_t i = 0; i < query.sizes().size(); i++) {
        TORCH_CHECK(query.size(i) > 0, "All values within query's shape should be greater "
                    "than 0, but shape[", i, "] is ", query.size(i));
    }
    // options() already carries query's dtype; the redundant .dtype(...) and
    // the unused layout string from the original were dropped.
    return at::empty(query.sizes(), query.options());
}
|
||||
|
||||
} // namespace meta
|
||||
} // namespace vllm_ascend
|
||||
|
||||
@@ -182,5 +240,9 @@ TORCH_LIBRARY_IMPL_EXPAND(CONCAT(_C, _ascend), Meta, ops) {
|
||||
ops.impl("grouped_matmul_swiglu_quant_weight_nz_tensor_list", &vllm_ascend::meta::grouped_matmul_swiglu_quant_weight_nz_tensor_list_meta);
|
||||
// batch_matmul_transpose
|
||||
ops.impl("batch_matmul_transpose", &vllm_ascend::meta::batch_matmul_transpose);
|
||||
// Lightning indexer
|
||||
ops.impl("npu_lightning_indexer", &vllm_ascend::meta::npu_lightning_indexer_meta);
|
||||
// Sparse flash attention
|
||||
ops.impl("npu_sparse_flash_attention", &vllm_ascend::meta::npu_sparse_flash_attention_meta);
|
||||
}
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user