From 18b90b501d6aad1d9426dcdee1ccfbe8139dd47d Mon Sep 17 00:00:00 2001 From: Song Mingyang <43877003+MingYang119@users.noreply.github.com> Date: Wed, 3 Dec 2025 09:53:10 +0800 Subject: [PATCH] [kernel] add AscendC op: lightning_indexer and sparse_flash_attention (#4625) ### What this PR does / why we need it? Provide high-performance AscendC operators lightning_indexer and sparse_flash_attention to boost the execution performance of the DeepSeek v3.2 model. Meanwhile, adapt the two AscendC operators to vllm-ascend framework. ### Does this PR introduce _any_ user-facing change? No (only underlying operator optimizations, with no user-facing changes) ### How was this patch tested? - vLLM version: v0.11.2 - vLLM main: https://github.com/vllm-project/vllm/commit/v0.11.2 Signed-off-by: MingYang119 --- csrc/build_aclnn.sh | 4 +- csrc/lightning_indexer/op_host/CMakeLists.txt | 42 + .../op_host/lightning_indexer_def.cpp | 72 + .../op_host/lightning_indexer_proto.cpp | 96 + .../op_host/lightning_indexer_tiling.cpp | 694 +++++++ .../op_host/lightning_indexer_tiling.h | 215 ++ .../op_kernel/lightning_indexer.cpp | 58 + .../op_kernel/lightning_indexer_common.h | 135 ++ .../op_kernel/lightning_indexer_kernel.h | 623 ++++++ .../lightning_indexer_service_cube.h | 415 ++++ .../lightning_indexer_service_vector.h | 559 +++++ .../lightning_indexer_template_tiling_key.h | 66 + .../op_kernel/lightning_indexer_vector.h | 335 +++ .../op_host/CMakeLists.txt | 39 + .../op_host/sparse_flash_attention_def.cpp | 90 + .../op_host/sparse_flash_attention_proto.cpp | 48 + .../op_host/sparse_flash_attention_tiling.cpp | 1845 +++++++++++++++++ .../op_host/sparse_flash_attention_tiling.h | 583 ++++++ .../op_kernel/sparse_flash_attention.cpp | 53 + .../op_kernel/sparse_flash_attention_common.h | 192 ++ .../sparse_flash_attention_kernel_mla.h | 969 +++++++++ .../sparse_flash_attention_service_cube_mla.h | 1079 ++++++++++ ...parse_flash_attention_service_vector_mla.h | 1329 ++++++++++++ 
...arse_flash_attention_template_tiling_key.h | 54 + csrc/torch_binding.cpp | 115 + csrc/torch_binding_meta.cpp | 62 + vllm_ascend/attention/sfa_v1.py | 4 +- vllm_ascend/worker/worker_v1.py | 15 - 28 files changed, 9772 insertions(+), 19 deletions(-) create mode 100644 csrc/lightning_indexer/op_host/CMakeLists.txt create mode 100644 csrc/lightning_indexer/op_host/lightning_indexer_def.cpp create mode 100644 csrc/lightning_indexer/op_host/lightning_indexer_proto.cpp create mode 100644 csrc/lightning_indexer/op_host/lightning_indexer_tiling.cpp create mode 100644 csrc/lightning_indexer/op_host/lightning_indexer_tiling.h create mode 100644 csrc/lightning_indexer/op_kernel/lightning_indexer.cpp create mode 100644 csrc/lightning_indexer/op_kernel/lightning_indexer_common.h create mode 100644 csrc/lightning_indexer/op_kernel/lightning_indexer_kernel.h create mode 100644 csrc/lightning_indexer/op_kernel/lightning_indexer_service_cube.h create mode 100644 csrc/lightning_indexer/op_kernel/lightning_indexer_service_vector.h create mode 100644 csrc/lightning_indexer/op_kernel/lightning_indexer_template_tiling_key.h create mode 100644 csrc/lightning_indexer/op_kernel/lightning_indexer_vector.h create mode 100644 csrc/sparse_flash_attention/op_host/CMakeLists.txt create mode 100644 csrc/sparse_flash_attention/op_host/sparse_flash_attention_def.cpp create mode 100644 csrc/sparse_flash_attention/op_host/sparse_flash_attention_proto.cpp create mode 100644 csrc/sparse_flash_attention/op_host/sparse_flash_attention_tiling.cpp create mode 100644 csrc/sparse_flash_attention/op_host/sparse_flash_attention_tiling.h create mode 100644 csrc/sparse_flash_attention/op_kernel/sparse_flash_attention.cpp create mode 100644 csrc/sparse_flash_attention/op_kernel/sparse_flash_attention_common.h create mode 100644 csrc/sparse_flash_attention/op_kernel/sparse_flash_attention_kernel_mla.h create mode 100644 csrc/sparse_flash_attention/op_kernel/sparse_flash_attention_service_cube_mla.h create mode 
100644 csrc/sparse_flash_attention/op_kernel/sparse_flash_attention_service_vector_mla.h create mode 100644 csrc/sparse_flash_attention/op_kernel/sparse_flash_attention_template_tiling_key.h diff --git a/csrc/build_aclnn.sh b/csrc/build_aclnn.sh index 8a282bef..f7896359 100644 --- a/csrc/build_aclnn.sh +++ b/csrc/build_aclnn.sh @@ -11,11 +11,11 @@ if [[ "$SOC_VERSION" =~ ^ascend310 ]]; then exit 0 elif [[ "$SOC_VERSION" =~ ^ascend910b ]]; then # ASCEND910B (A2) series - CUSTOM_OPS="grouped_matmul_swiglu_quant_weight_nz_tensor_list" + CUSTOM_OPS="grouped_matmul_swiglu_quant_weight_nz_tensor_list;lightning_indexer;sparse_flash_attention" SOC_ARG="ascend910b" elif [[ "$SOC_VERSION" =~ ^ascend910_93 ]]; then # ASCEND910C (A3) series - CUSTOM_OPS="grouped_matmul_swiglu_quant_weight_nz_tensor_list" + CUSTOM_OPS="grouped_matmul_swiglu_quant_weight_nz_tensor_list;lightning_indexer;sparse_flash_attention" SOC_ARG="ascend910_93" else # others diff --git a/csrc/lightning_indexer/op_host/CMakeLists.txt b/csrc/lightning_indexer/op_host/CMakeLists.txt new file mode 100644 index 00000000..7922ba8e --- /dev/null +++ b/csrc/lightning_indexer/op_host/CMakeLists.txt @@ -0,0 +1,42 @@ +# This program is free software, you can redistribute it and/or modify it. +# Copyright (c) 2025 Huawei Technologies Co., Ltd. +# This file is a part of the CANN Open Software. +# Licensed under CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. 
+# ====================================================================================================================== + +add_ops_compile_options( + OP_NAME LightningIndexer + OPTIONS --cce-auto-sync=off + -Wno-deprecated-declarations + -Werror + -mllvm -cce-aicore-hoist-movemask=false + --op_relocatable_kernel_binary=true +) + +set(lightning_indexer_depends transformer/attention/lightning_indexer PARENT_SCOPE) + +target_sources(op_host_aclnn PRIVATE + lightning_indexer_def.cpp +) + +target_sources(optiling PRIVATE + lightning_indexer_tiling.cpp +) + +if (NOT BUILD_OPEN_PROJECT) + target_sources(opmaster_ct PRIVATE + lightning_indexer_tiling.cpp + ) +endif () + +target_include_directories(optiling PRIVATE + ${CMAKE_CURRENT_SOURCE_DIR} +) + +target_sources(opsproto PRIVATE + lightning_indexer_proto.cpp +) + diff --git a/csrc/lightning_indexer/op_host/lightning_indexer_def.cpp b/csrc/lightning_indexer/op_host/lightning_indexer_def.cpp new file mode 100644 index 00000000..95f97a34 --- /dev/null +++ b/csrc/lightning_indexer/op_host/lightning_indexer_def.cpp @@ -0,0 +1,72 @@ +/** + * This program is free software, you can redistribute it and/or modify it. + * Copyright (c) 2025 Huawei Technologies Co., Ltd. + * This file is a part of the CANN Open Software. + * Licensed under CANN Open Software License Agreement Version 2.0 (the "License"). + * Please refer to the License for details. You may not use this file except in compliance with the License. + * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. + * See LICENSE in the root of the software repository for the full text of the License. + */ + +/*! 
+ * \file lightning_indexer_def.cpp + * \brief + */ +#include +#include "register/op_def_registry.h" + +namespace ops { +class LightningIndexer : public OpDef { +public: + explicit LightningIndexer(const char *name) : OpDef(name) + { + this->Input("query") + .ParamType(REQUIRED) + .DataType({ge::DT_BF16, ge::DT_FLOAT16}) + .FormatList({ge::FORMAT_ND}) + .AutoContiguous(); + this->Input("key") + .ParamType(REQUIRED) + .DataType({ge::DT_BF16, ge::DT_FLOAT16}) + .FormatList({ge::FORMAT_ND}) + .AutoContiguous(); + this->Input("weights") + .ParamType(REQUIRED) + .DataType({ge::DT_BF16, ge::DT_FLOAT16}) + .FormatList({ge::FORMAT_ND}) + .AutoContiguous(); + this->Input("actual_seq_lengths_query") + .ParamType(OPTIONAL) + .DataType({ge::DT_INT32, ge::DT_INT32}) + .FormatList({ge::FORMAT_ND}) + .AutoContiguous(); + this->Input("actual_seq_lengths_key") + .ParamType(OPTIONAL) + .DataType({ge::DT_INT32, ge::DT_INT32}) + .FormatList({ge::FORMAT_ND}) + .AutoContiguous(); + this->Input("block_table") + .ParamType(OPTIONAL) + .DataTypeList({ge::DT_INT32}) + .FormatList({ge::FORMAT_ND}) + .AutoContiguous(); + this->Output("sparse_indices").ParamType(REQUIRED).DataTypeList({ge::DT_INT32}).FormatList({ge::FORMAT_ND}); + this->Attr("layout_query").AttrType(OPTIONAL).String("BSND"); + this->Attr("layout_key").AttrType(OPTIONAL).String("PA_BSND"); + this->Attr("sparse_count").AttrType(OPTIONAL).Int(2048); // 2048: Default value, filter the top 2048 + this->Attr("sparse_mode").AttrType(OPTIONAL).Int(3); // 3: Default value, only calculate the lower triangular matrix + OpAICoreConfig aicore_config; + aicore_config.DynamicCompileStaticFlag(true) + .DynamicFormatFlag(true) + .DynamicRankSupportFlag(true) + .DynamicShapeSupportFlag(true) + .NeedCheckSupportFlag(false) + .PrecisionReduceFlag(true) + .ExtendCfgInfo("aclnnSupport.value", "support_aclnn") + .ExtendCfgInfo("jitCompile.flag", "static_false,dynamic_false"); + this->AICore().AddConfig("ascend910b", aicore_config); + 
this->AICore().AddConfig("ascend910_93", aicore_config); + } +}; +OP_ADD(LightningIndexer); +} // namespace ops \ No newline at end of file diff --git a/csrc/lightning_indexer/op_host/lightning_indexer_proto.cpp b/csrc/lightning_indexer/op_host/lightning_indexer_proto.cpp new file mode 100644 index 00000000..cc1a793e --- /dev/null +++ b/csrc/lightning_indexer/op_host/lightning_indexer_proto.cpp @@ -0,0 +1,96 @@ +/** + * This program is free software, you can redistribute it and/or modify it. + * Copyright (c) 2025 Huawei Technologies Co., Ltd. + * This file is a part of the CANN Open Software. + * Licensed under CANN Open Software License Agreement Version 2.0 (the "License"). + * Please refer to the License for details. You may not use this file except in compliance with the License. + * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. + * See LICENSE in the root of the software repository for the full text of the License. + */ + +/*! 
+ * \file lightning_indexer_proto.cpp + * \brief + */ +#include +#include +#include "error/ops_error.h" + + +using namespace ge; + +namespace ops { +constexpr uint32_t QUERY_INDEX = 0; +constexpr uint32_t KEY_INDEX = 1; +constexpr uint32_t ACTUAL_SEQ_K_INDEX = 4; +constexpr uint32_t ATTR_QUERY_LAYOUT_INDEX = 0; +constexpr uint32_t ATTR_KEY_LAYOUT_INDEX = 1; +constexpr uint32_t ATTR_SPARSE_COUNT_INDEX = 2; + +static ge::graphStatus InferShapeLightningIndexer(gert::InferShapeContext *context) +{ + OPS_ERR_IF(context == nullptr, OPS_LOG_E("LightningIndexer", "InferShapeContext is nullptr!"), + return ge::GRAPH_FAILED); + const gert::Shape *queryShape = context->GetInputShape(QUERY_INDEX); + OPS_LOG_E_IF_NULL(context, queryShape, return ge::GRAPH_FAILED); + const gert::Shape *keyShape = context->GetInputShape(KEY_INDEX); + OPS_LOG_E_IF_NULL(context, keyShape, return ge::GRAPH_FAILED); + gert::Shape *outShape = context->GetOutputShape(0); + + auto attrs = context->GetAttrs(); + OPS_LOG_E_IF_NULL(context, attrs, return ge::GRAPH_FAILED); + const char *inputLayoutQueryPtr = attrs->GetAttrPointer(ATTR_QUERY_LAYOUT_INDEX); + OPS_LOG_E_IF_NULL(context, inputLayoutQueryPtr, return ge::GRAPH_FAILED); + const char *inputLayoutKeyPtr = attrs->GetAttrPointer(ATTR_KEY_LAYOUT_INDEX); + OPS_LOG_E_IF_NULL(context, inputLayoutKeyPtr, return ge::GRAPH_FAILED); + const int64_t *seleced_count = attrs->GetInt(ATTR_SPARSE_COUNT_INDEX); + OPS_LOG_E_IF_NULL(context, seleced_count, return ge::GRAPH_FAILED); + std::string inputLayoutQueryPtrStr = std::string(inputLayoutQueryPtr); + std::string inputLayoutKeyPtrStr = std::string(inputLayoutKeyPtr); + OPS_ERR_IF( + inputLayoutQueryPtrStr != "TND" && inputLayoutQueryPtrStr != "BSND", + OPS_LOG_E(context, "The attr layout_query should be TND or BSND, but got %s.", inputLayoutQueryPtrStr.c_str()), + return ge::GRAPH_FAILED); + + outShape->SetDimNum(queryShape->GetDimNum()); + if (inputLayoutQueryPtrStr == "BSND") { + OPS_ERR_IF( + 
queryShape->GetDimNum() != 4, + OPS_LOG_E(context, "Layout BSND, queryDims (%zu) must be 4!", queryShape->GetDimNum()), + return ge::GRAPH_FAILED); + outShape->SetDim(0, queryShape->GetDim(0)); // 0:Dim B + outShape->SetDim(1, queryShape->GetDim(1)); // 1:Dim S + outShape->SetDim(2, keyShape->GetDim(2)); // 2:Dim N + outShape->SetDim(3, *seleced_count); // 3:Dim K + } else { + OPS_ERR_IF( + queryShape->GetDimNum() != 3, + OPS_LOG_E(context, "Layout TND, queryDims (%zu) must be 3!", queryShape->GetDimNum()), + return ge::GRAPH_FAILED); + outShape->SetDim(0, queryShape->GetDim(0)); // 0:Dim T + int32_t nDimIndex = (inputLayoutKeyPtrStr == "PA_BSND") ? 2 : 1; // 2:Key Dim N + outShape->SetDim(1, keyShape->GetDim(nDimIndex)); // 1:Dim N + outShape->SetDim(2, *seleced_count); // 2:Dim K + } + OPS_LOG_D(context->GetNodeName(), "LightningIndexer InferShape end."); + + return ge::GRAPH_SUCCESS; +} + +static ge::graphStatus InferDataTypeLightningIndexer(gert::InferDataTypeContext *context) +{ + OPS_ERR_IF(context == nullptr, OPS_LOG_E("LightningIndexer", "InferDataTypeContext is nullptr!"), + return ge::GRAPH_FAILED); + OPS_LOG_D(context->GetNodeName(), "Enter LightningIndexer InferDataType impl."); + // default set q's dtype as fia's output type + ge::DataType outputType = ge::DT_INT32; + // attention_out, outidx:0 + context->SetOutputDataType(0, outputType); + OPS_LOG_D(context->GetNodeName(), "LightningIndexer InferDataType end."); + return GRAPH_SUCCESS; +} + +IMPL_OP_INFERSHAPE(LightningIndexer) + .InferShape(InferShapeLightningIndexer) + .InferDataType(InferDataTypeLightningIndexer); +} // namespace ops diff --git a/csrc/lightning_indexer/op_host/lightning_indexer_tiling.cpp b/csrc/lightning_indexer/op_host/lightning_indexer_tiling.cpp new file mode 100644 index 00000000..ff7f7843 --- /dev/null +++ b/csrc/lightning_indexer/op_host/lightning_indexer_tiling.cpp @@ -0,0 +1,694 @@ +/** + * This program is free software, you can redistribute it and/or modify it. 
+ * Copyright (c) 2025 Huawei Technologies Co., Ltd. + * This file is a part of the CANN Open Software. + * Licensed under CANN Open Software License Agreement Version 2.0 (the "License"). + * Please refer to the License for details. You may not use this file except in compliance with the License. + * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. + * See LICENSE in the root of the software repository for the full text of the License. + */ + +/*! + * \file lightning_indexer_tiling.cpp + * \brief + */ + +#include "lightning_indexer_tiling.h" +#include "../op_kernel/lightning_indexer_template_tiling_key.h" + +using namespace ge; +using namespace AscendC; +using std::map; +using std::string; +namespace optiling { +ge::graphStatus LIInfoParser::CheckRequiredInOutExistence() const +{ + OPS_ERR_IF(opParamInfo_.query.shape == nullptr, OPS_LOG_E(opName_, "Shape of tensor query is nullptr"), + return ge::GRAPH_FAILED); + OPS_ERR_IF(opParamInfo_.query.desc == nullptr, OPS_LOG_E(opName_, "Desc of tensor query is nullptr"), + return ge::GRAPH_FAILED); + OPS_ERR_IF(opParamInfo_.key.shape == nullptr, OPS_LOG_E(opName_, "Shape of tensor k is nullptr"), + return ge::GRAPH_FAILED); + OPS_ERR_IF(opParamInfo_.key.desc == nullptr, OPS_LOG_E(opName_, "Desc of tensor k is nullptr"), + return ge::GRAPH_FAILED); + OPS_ERR_IF(opParamInfo_.weights.shape == nullptr, OPS_LOG_E(opName_, "Shape of tensor value is nullptr"), + return ge::GRAPH_FAILED); + OPS_ERR_IF(opParamInfo_.weights.desc == nullptr, OPS_LOG_E(opName_, "Desc of tensor value is nullptr"), + return ge::GRAPH_FAILED); + OPS_ERR_IF(opParamInfo_.attenOut.shape == nullptr, OPS_LOG_E(opName_, "Shape of tensor output is nullptr"), + return ge::GRAPH_FAILED); + OPS_ERR_IF(opParamInfo_.attenOut.desc == nullptr, OPS_LOG_E(opName_, "Desc of tensor output is nullptr"), + return 
ge::GRAPH_FAILED); + + return ge::GRAPH_SUCCESS; +} + +ge::graphStatus LIInfoParser::CheckRequiredAttrExistence() const +{ + OPS_ERR_IF(opParamInfo_.layOut == nullptr, OPS_LOG_E(opName_, "attr layout_query is nullptr"), + return ge::GRAPH_FAILED); + + OPS_ERR_IF(opParamInfo_.layOutKey == nullptr, OPS_LOG_E(opName_, "attr layout_key is nullptr"), + return ge::GRAPH_FAILED); + + OPS_ERR_IF(opParamInfo_.sparseCount == nullptr, OPS_LOG_E(opName_, "attr sparse_count is nullptr"), + return ge::GRAPH_FAILED); + + OPS_ERR_IF(opParamInfo_.sparseMode == nullptr, OPS_LOG_E(opName_, "attr sparse_mode is nullptr"), + return ge::GRAPH_FAILED); + + return ge::GRAPH_SUCCESS; +} + +ge::graphStatus LIInfoParser::CheckRequiredParaExistence() const +{ + if (CheckRequiredInOutExistence() != ge::GRAPH_SUCCESS || CheckRequiredAttrExistence() != ge::GRAPH_SUCCESS) { + return ge::GRAPH_FAILED; + } + + return ge::GRAPH_SUCCESS; +} + +ge::graphStatus LIInfoParser::GetOpName() +{ + if (context_->GetNodeName() == nullptr) { + OPS_LOG_E("LightningIndexer", "opName got from TilingContext is nullptr"); + return ge::GRAPH_FAILED; + } + opName_ = context_->GetNodeName(); + return ge::GRAPH_SUCCESS; +} + +ge::graphStatus LIInfoParser::GetNpuInfo() +{ + platformInfo_ = context_->GetPlatformInfo(); + OPS_ERR_IF(platformInfo_ == nullptr, OPS_LOG_E(opName_, "GetPlatformInfo is nullptr."), return ge::GRAPH_FAILED); + + auto ascendcPlatform = platform_ascendc::PlatformAscendC(platformInfo_); + uint32_t aivNum = ascendcPlatform.GetCoreNumAiv(); + uint32_t aicNum = ascendcPlatform.GetCoreNumAic(); + OPS_ERR_IF(aicNum == 0 || aivNum == 0, OPS_LOG_E(opName_, "num of core obtained is 0."), return GRAPH_FAILED); + + socVersion_ = ascendcPlatform.GetSocVersion(); + if ((socVersion_ != platform_ascendc::SocVersion::ASCEND910B) && + (socVersion_ != platform_ascendc::SocVersion::ASCEND910_93)) { + OPS_LOG_E(opName_, "SOC Version[%d] is not support.", (int32_t)socVersion_); + return GRAPH_FAILED; + } + 
OPS_ERR_IF(context_->GetWorkspaceSizes(1) == nullptr, OPS_LOG_E(opName_, "workSpaceSize got from ge is nullptr"), + return ge::GRAPH_FAILED); + OPS_ERR_IF(context_->GetRawTilingData() == nullptr, + OPS_LOG_E(context_->GetNodeName(), "RawTilingData got from GE context is nullptr."), + return ge::GRAPH_FAILED); + + return ge::GRAPH_SUCCESS; +} + +void LIInfoParser::GetOptionalInputParaInfo() +{ + opParamInfo_.actualSeqLengthsQ.tensor = context_->GetOptionalInputTensor(ACTUAL_SEQ_Q_INDEX); + opParamInfo_.actualSeqLengthsQ.desc = context_->GetOptionalInputDesc(ACTUAL_SEQ_Q_INDEX); + opParamInfo_.actualSeqLengths.tensor = context_->GetOptionalInputTensor(ACTUAL_SEQ_K_INDEX); + opParamInfo_.actualSeqLengths.desc = context_->GetOptionalInputDesc(ACTUAL_SEQ_K_INDEX); + opParamInfo_.blockTable.tensor = context_->GetOptionalInputTensor(BLOCK_TABLE_INDEX); + opParamInfo_.blockTable.desc = context_->GetOptionalInputDesc(BLOCK_TABLE_INDEX); +} + +void LIInfoParser::GetInputParaInfo() +{ + opParamInfo_.query.desc = context_->GetInputDesc(QUERY_INDEX); + opParamInfo_.query.shape = context_->GetInputShape(QUERY_INDEX); + opParamInfo_.key.desc = context_->GetInputDesc(KEY_INDEX); + opParamInfo_.key.shape = context_->GetInputShape(KEY_INDEX); + opParamInfo_.weights.desc = context_->GetInputDesc(WEIGTHS_INDEX); + opParamInfo_.weights.shape = context_->GetInputShape(WEIGTHS_INDEX); + GetOptionalInputParaInfo(); +} + +void LIInfoParser::GetOutputParaInfo() +{ + opParamInfo_.attenOut.desc = context_->GetOutputDesc(LIGHTNING_INDEXER); + opParamInfo_.attenOut.shape = context_->GetOutputShape(LIGHTNING_INDEXER); +} + +ge::graphStatus LIInfoParser::GetAndCheckAttrParaInfo() +{ + auto attrs = context_->GetAttrs(); + OPS_ERR_IF(attrs == nullptr, OPS_REPORT_VECTOR_INNER_ERR(context_->GetNodeName(), "attrs got from ge is nullptr"), + return ge::GRAPH_FAILED); + + OPS_LOG_I(context_->GetNodeName(), "GetAndCheckAttrParaInfo start"); + opParamInfo_.layOut = attrs->GetStr(ATTR_QUERY_LAYOUT_INDEX); 
+ opParamInfo_.layOutKey = attrs->GetStr(ATTR_KEY_LAYOUT_INDEX); + opParamInfo_.sparseCount = attrs->GetAttrPointer(ATTR_SPARSE_COUNT_INDEX); + opParamInfo_.sparseMode = attrs->GetAttrPointer(ATTR_SPARSE_MODE_INDEX); + + if (opParamInfo_.layOut != nullptr) { + OPS_LOG_I(context_->GetNodeName(), "layout_query is:%s", opParamInfo_.layOut); + } + if (opParamInfo_.layOutKey != nullptr) { + OPS_LOG_I(context_->GetNodeName(), "layout_key is:%s", opParamInfo_.layOutKey); + } + if (opParamInfo_.sparseCount != nullptr) { + OPS_LOG_I(context_->GetNodeName(), "selscted count is:%d", *opParamInfo_.sparseCount); + } + if (opParamInfo_.sparseMode != nullptr) { + OPS_LOG_I(context_->GetNodeName(), "sparse mode is:%d", *opParamInfo_.sparseMode); + } + OPS_LOG_I(context_->GetNodeName(), "GetAndCheckAttrParaInfo end"); + + OPS_ERR_IF( + ((std::string(opParamInfo_.layOutKey) != "PA_BSND") + && (std::string(opParamInfo_.layOut) != std::string(opParamInfo_.layOutKey))), + OPS_LOG_E(opName_, "under non-PA conditions, layout_query and layout_key should be equal."), + return ge::GRAPH_FAILED); + OPS_ERR_IF( + ((std::string(opParamInfo_.layOutKey) != "PA_BSND") && (std::string(opParamInfo_.layOutKey) != "BSND") + && (std::string(opParamInfo_.layOutKey) != "TND")), + OPS_LOG_E(opName_, "input attr layout_key only supported PA_BSND, BSND or TND"), return ge::GRAPH_FAILED); + OPS_ERR_IF(((std::string(opParamInfo_.layOut) != "BSND") && (std::string(opParamInfo_.layOut) != "TND")), + OPS_LOG_E(opName_, "input attr layout_query only supported BSND or TND."), return ge::GRAPH_FAILED); + OPS_ERR_IF(!((*opParamInfo_.sparseCount > 0) && (*opParamInfo_.sparseCount <= SPARSE_LIMIT)), + OPS_LOG_E(opName_, "input attr sparse_count must > 0 and <= 2048."), return ge::GRAPH_FAILED); + OPS_ERR_IF(!((*opParamInfo_.sparseMode == 0) || (*opParamInfo_.sparseMode == SPARSE_MODE_LOWER)), + OPS_LOG_E(opName_, "input attr sparse_mode only supported 0 or 3."), return ge::GRAPH_FAILED); + + return ge::GRAPH_SUCCESS; 
+} + +ge::graphStatus LIInfoParser::GetOpParaInfo() +{ + GetInputParaInfo(); + GetOutputParaInfo(); + if (ge::GRAPH_SUCCESS != GetAndCheckAttrParaInfo()) { + return ge::GRAPH_FAILED; + } + return ge::GRAPH_SUCCESS; +} + +ge::graphStatus LIInfoParser::GetAndCheckInOutDataType() +{ + inputQType_ = opParamInfo_.query.desc->GetDataType(); + inputKType_ = opParamInfo_.key.desc->GetDataType(); + weightsType_ = opParamInfo_.weights.desc->GetDataType(); + outputType_ = opParamInfo_.attenOut.desc->GetDataType(); + + bool inDTypeAllEqual = (inputQType_ == inputKType_) && (inputKType_ == weightsType_); + OPS_ERR_IF(!inDTypeAllEqual, + OPS_LOG_E(opName_, "The data types of the input query, key, and weights must be the same."), + return ge::GRAPH_FAILED); + + OPS_ERR_IF(((inputQType_ != ge::DT_FLOAT16) && (inputQType_ != ge::DT_BF16)), + OPS_LOG_E(opName_, "The data types of the input query, key, and weights must be float16 or bfloat16."), + return ge::GRAPH_FAILED); + + OPS_ERR_IF(outputType_ != ge::DT_INT32, + OPS_LOG_E(opName_, "The data types of the output sparse_indices must be int32."), + return ge::GRAPH_FAILED); + + return ge::GRAPH_SUCCESS; +} + +ge::graphStatus LIInfoParser::GetQueryKeyAndOutLayout() +{ + const map layoutMap = { + {"BSND", DataLayout::BSND}, + {"TND", DataLayout::TND}, + {"PA_BSND", DataLayout::BnBsND} + }; + + std::string layout(opParamInfo_.layOut); + auto it = layoutMap.find(layout); + if (it != layoutMap.end()) { + qLayout_ = it->second; + } + + std::string layoutKey(opParamInfo_.layOutKey); + auto itKey = layoutMap.find(layoutKey); + if (itKey != layoutMap.end()) { + kLayout_ = itKey->second; + } + + return ge::GRAPH_SUCCESS; +} + +ge::graphStatus LIInfoParser::GetAndCheckOptionalInput() +{ + if (kLayout_ == DataLayout::BnBsND) { + OPS_ERR_IF(opParamInfo_.blockTable.tensor == nullptr, + OPS_LOG_E(opName_, "key layout only supported PA_BSND, input block_table must not be null"), + return ge::GRAPH_FAILED); + OPS_ERR_IF( + 
opParamInfo_.actualSeqLengths.tensor == nullptr, + OPS_LOG_E(opName_, "key layout only supported PA_BSND, input actual_seq_lengths_key must not be null"), + return ge::GRAPH_FAILED); + OPS_ERR_IF(opParamInfo_.blockTable.desc->GetDataType() != ge::DT_INT32, + OPS_LOG_E(opName_, "input block_table data type only support int32"), return ge::GRAPH_FAILED); + } else if (kLayout_ == DataLayout::TND) { + OPS_ERR_IF(opParamInfo_.actualSeqLengths.tensor == nullptr, + OPS_LOG_E(opName_, "when layout_key is TND, input actual_seq_lengths_key must not be null"), + return ge::GRAPH_FAILED); + } + OPS_ERR_IF(opParamInfo_.actualSeqLengths.tensor != nullptr && + opParamInfo_.actualSeqLengths.desc->GetDataType() != ge::DT_INT32, + OPS_LOG_E(opName_, "input actual_seq_lengths_key data type only support int32"), + return ge::GRAPH_FAILED); + OPS_ERR_IF(opParamInfo_.actualSeqLengths.tensor != nullptr && + opParamInfo_.actualSeqLengths.desc->GetDataType() != ge::DT_INT32, + OPS_LOG_E(opName_, "input actual_seq_lengths_key data type only support int32"), + return ge::GRAPH_FAILED); + if (qLayout_ == DataLayout::TND) { + OPS_ERR_IF(opParamInfo_.actualSeqLengthsQ.tensor == nullptr, + OPS_LOG_E(opName_, "when layout_query is TND, input actual_seq_lengths_query must not be null"), + return ge::GRAPH_FAILED); + } + OPS_ERR_IF(opParamInfo_.actualSeqLengthsQ.tensor != nullptr && + opParamInfo_.actualSeqLengthsQ.desc->GetDataType() != ge::DT_INT32, + OPS_LOG_E(opName_, "input actual_seq_lengths_query data type only support int32"), + return ge::GRAPH_FAILED); + OPS_ERR_IF(kLayout_ != DataLayout::BnBsND && opParamInfo_.blockTable.tensor != nullptr, + OPS_LOG_E(opName_, "when key layout is not PA_BSND, input block_table must be null"), + return ge::GRAPH_FAILED); + return ge::GRAPH_SUCCESS; +} + +ge::graphStatus LIInfoParser::CheckShapeDim() +{ + OPS_ERR_IF((opParamInfo_.blockTable.tensor != nullptr) && + (opParamInfo_.blockTable.tensor->GetStorageShape().GetDimNum() != DIM_NUM_TWO), + 
OPS_LOG_E(opName_, "the dim num of block_table's shape should be 2"), return ge::GRAPH_FAILED); + + uint32_t kShapeDim = opParamInfo_.key.shape->GetStorageShape().GetDimNum(); + uint32_t qShapeDim = opParamInfo_.query.shape->GetStorageShape().GetDimNum(); + uint32_t weightsShapeDim = opParamInfo_.weights.shape->GetStorageShape().GetDimNum(); + uint32_t outShapeDim = opParamInfo_.attenOut.shape->GetStorageShape().GetDimNum(); + uint32_t qExpectShapeDim = DIM_NUM_FOUR; + uint32_t kExpectShapeDim = DIM_NUM_FOUR; + if (qLayout_ == DataLayout::TND) { + qExpectShapeDim = DIM_NUM_THREE; + } + if (kLayout_ == DataLayout::TND) { + kExpectShapeDim = DIM_NUM_THREE; + } + OPS_ERR_IF(kShapeDim != kExpectShapeDim, + OPS_LOG_E(opName_, "the dim num of key's shape should be %u, but now is %u", kExpectShapeDim, kShapeDim), + return ge::GRAPH_FAILED); + OPS_ERR_IF(qShapeDim != qExpectShapeDim, + OPS_LOG_E(opName_, "the dim num of query's shape should be %u, but now is %u", + qExpectShapeDim, qShapeDim), + return ge::GRAPH_FAILED); + OPS_ERR_IF(outShapeDim != qExpectShapeDim, + OPS_LOG_E(opName_, "the dim num of sparse_indices's shape should be %u, but now is %u", + qExpectShapeDim, outShapeDim), + return ge::GRAPH_FAILED); + OPS_ERR_IF(!(weightsShapeDim == qExpectShapeDim - 1), + OPS_LOG_E(opName_, "the dim num of weights's shape should be %u, but now is %u", qExpectShapeDim - 1, + weightsShapeDim), + return ge::GRAPH_FAILED); + + return ge::GRAPH_SUCCESS; +} + +ge::graphStatus LIInfoParser::GetN1Size() +{ + if (qLayout_ == DataLayout::BSND) { + n1Size_ = static_cast(opParamInfo_.query.shape->GetStorageShape().GetDim(DIM_IDX_TWO)); + } else { + // TND + n1Size_ = static_cast(opParamInfo_.query.shape->GetStorageShape().GetDim(1)); + } + OPS_LOG_I(context_->GetNodeName(), "n1Size is %d", n1Size_); + return ge::GRAPH_SUCCESS; +} + +ge::graphStatus LIInfoParser::GetActualSeqLenSize(uint32_t &size, const gert::Tensor *tensor, + const std::string &actualSeqLenName) +{ + size = 
static_cast(tensor->GetShapeSize()); + if (size <= 0) { + OPS_LOG_E(opName_, "%s's shape size is %u, it should be greater than 0.", actualSeqLenName.c_str(), size); + return ge::GRAPH_FAILED; + } + return ge::GRAPH_SUCCESS; +} + +ge::graphStatus LIInfoParser::GetAndCheckN2Size() +{ + uint32_t n2Index = (kLayout_ == DataLayout::TND) ? DIM_IDX_ONE : DIM_IDX_TWO; + n2Size_ = static_cast(opParamInfo_.key.shape->GetStorageShape().GetDim(n2Index)); + OPS_LOG_I(context_->GetNodeName(), "n2Size_ is %d", n2Size_); + OPS_ERR_IF(n2Size_ != 1, OPS_LOG_E(opName_, "key shape[%u] is numhead, only support 1.", n2Index), + return ge::GRAPH_FAILED); + return ge::GRAPH_SUCCESS; +} + +ge::graphStatus LIInfoParser::GetGSize() +{ + if (n1Size_ % n2Size_ != 0) { + OPS_LOG_E(opName_, "input query's head_num %u can not be a multiple of key's head_num %u.", n1Size_, n2Size_); + return ge::GRAPH_FAILED; + } + gSize_ = n1Size_ / n2Size_; + OPS_ERR_IF(gSize_ != 64, OPS_LOG_E(opName_, "N1 is %u, N2 is %u, N1 divided by N2 must equal 64.", + n1Size_, n2Size_), return ge::GRAPH_FAILED); + + return ge::GRAPH_SUCCESS; +} + +ge::graphStatus LIInfoParser::GetBatchSize() +{ + if ((qLayout_ == DataLayout::TND)) { + return GetActualSeqLenSize(bSize_, opParamInfo_.actualSeqLengthsQ.tensor, "input actual_seq_lengths_query"); + } else { // BSND + bSize_ = opParamInfo_.query.shape->GetStorageShape().GetDim(0); + return ge::GRAPH_SUCCESS; + } +} + +ge::graphStatus LIInfoParser::GetHeadDim() +{ + uint32_t dIndex = DIM_IDX_TWO; + switch (qLayout_) { + case DataLayout::TND: + // TND: [Total, N, D] -> D is the 2nd dimension + dIndex = DIM_IDX_TWO; + break; + case DataLayout::BSND: + // BSND: [Batch, SeqLen, N, D] -> D is the 3nd dimension + dIndex = DIM_IDX_THREE; + break; + default: + OPS_LOG_E(opName_, "unsupported layout for getting head dim."); + return ge::GRAPH_FAILED; + } + headDim_ = opParamInfo_.query.shape->GetStorageShape().GetDim(dIndex); + OPS_ERR_IF(headDim_ != HEAD_DIM_LIMIT, OPS_LOG_E(opName_, 
"input query's last dim head_dim only support 128."), + return ge::GRAPH_FAILED); + + return ge::GRAPH_SUCCESS; +} + +ge::graphStatus LIInfoParser::GetS1Size() +{ + if (qLayout_ == DataLayout::BSND) { + s1Size_ = opParamInfo_.query.shape->GetStorageShape().GetDim(1); + } + return ge::GRAPH_SUCCESS; +} + +ge::graphStatus LIInfoParser::GetAndCheckBlockSize() +{ + blockSize_ = static_cast(opParamInfo_.key.shape->GetStorageShape().GetDim(1)); + OPS_LOG_I(context_->GetNodeName(), "blockSize_ is %d", blockSize_); + + OPS_ERR_IF(((blockSize_ % 16 != 0) || (blockSize_ == 0) || (blockSize_ > 1024)), + OPS_LOG_E(opName_, "input key's block_size must be a multiple of 16 and belong to (0, 1024]."), + return ge::GRAPH_FAILED); + + return ge::GRAPH_SUCCESS; +} + +ge::graphStatus LIInfoParser::CheckBlockCount() +{ + int32_t blockCount_ = static_cast(opParamInfo_.key.shape->GetStorageShape().GetDim(0)); + OPS_ERR_IF((blockCount_ == 0), + OPS_LOG_E(opName_, "input key's block_count cannot be 0."), + return ge::GRAPH_FAILED); + return ge::GRAPH_SUCCESS; +} + +ge::graphStatus LIInfoParser::GetS2SizeForPageAttention() +{ + if (GetAndCheckBlockSize() != ge::GRAPH_SUCCESS) { + return ge::GRAPH_FAILED; + } + if (CheckBlockCount() != ge::GRAPH_SUCCESS) { + return ge::GRAPH_FAILED; + } + maxBlockNumPerBatch_ = opParamInfo_.blockTable.tensor->GetStorageShape().GetDim(1); + s2Size_ = maxBlockNumPerBatch_ * blockSize_; + OPS_LOG_I(context_->GetNodeName(), "maxBlockNumPerBatch_ is %d, blockSize_ is %d, s2Size_ is %d", + maxBlockNumPerBatch_, blockSize_, s2Size_); + return ge::GRAPH_SUCCESS; +} + +ge::graphStatus LIInfoParser::GetS2Size() +{ + if (kLayout_ == DataLayout::BnBsND) { + return GetS2SizeForPageAttention(); + } else if (kLayout_ == DataLayout::TND) { + s2Size_ = opParamInfo_.key.shape->GetStorageShape().GetDim(0); + } else if (kLayout_ == DataLayout::BSND) { + s2Size_ = opParamInfo_.key.shape->GetStorageShape().GetDim(1); + } + return ge::GRAPH_SUCCESS; +} + +ge::graphStatus 
LIInfoParser::ValidateInputShapesMatchQTnd() +{ + // -----------------------check BatchSize------------------- + if (kLayout_ == DataLayout::TND) { + OPS_ERR_IF( + (opParamInfo_.actualSeqLengths.tensor->GetShapeSize() != bSize_), + OPS_LOG_E(opName_, + "TND case input actual_seq_lengths_query, actual_seq_lengths_key are %u, %ld respectively, they must be same.", + bSize_, opParamInfo_.actualSeqLengths.tensor->GetShapeSize()), + return ge::GRAPH_FAILED); + } else { // kLayout_ PA_BSND + OPS_ERR_IF( + (opParamInfo_.actualSeqLengths.tensor->GetShapeSize() != bSize_) || + (opParamInfo_.blockTable.tensor->GetStorageShape().GetDim(0) != bSize_), + OPS_LOG_E( + opName_, + "TND case input actual_seq_lengths_query, actual_seq_lengths_key, block_table dim 0 are %u, %ld, %ld respectively, they must be same.", + bSize_, opParamInfo_.actualSeqLengths.tensor->GetShapeSize(), + opParamInfo_.blockTable.tensor->GetStorageShape().GetDim(0)), + return ge::GRAPH_FAILED); + } + // -----------------------check T------------------- + uint32_t qTsize = opParamInfo_.query.shape->GetStorageShape().GetDim(0); + OPS_ERR_IF((opParamInfo_.weights.shape->GetStorageShape().GetDim(0) != qTsize) || + (opParamInfo_.attenOut.shape->GetStorageShape().GetDim(0) != qTsize), + OPS_LOG_E(opName_, "TND case input query, weights, sparse_indices dim 0 are %u, %ld, %ld respectively, they must be same.", + qTsize, opParamInfo_.weights.shape->GetStorageShape().GetDim(0), + opParamInfo_.attenOut.shape->GetStorageShape().GetDim(0)), + return ge::GRAPH_FAILED); + return ge::GRAPH_SUCCESS; +} + +ge::graphStatus LIInfoParser::ValidateInputShapesMatchQBsnd() +{ + // -----------------------check BatchSize------------------- + if (kLayout_ == DataLayout::BnBsND) { + OPS_ERR_IF((opParamInfo_.blockTable.tensor->GetStorageShape().GetDim(0) != bSize_) || + (opParamInfo_.actualSeqLengths.tensor->GetShapeSize() != bSize_), + OPS_LOG_E(opName_, "BSND case input query, actual_seq_lengths_key, block_table dim 0 are %u, %ld, %ld 
respectively, they must be same.", + bSize_, opParamInfo_.actualSeqLengths.tensor->GetShapeSize(), + opParamInfo_.blockTable.tensor->GetStorageShape().GetDim(0)), + return ge::GRAPH_FAILED); + } else if (kLayout_ == DataLayout::BSND) { + OPS_ERR_IF(opParamInfo_.key.shape->GetStorageShape().GetDim(0) != bSize_, + OPS_LOG_E(opName_, "BSND case input query, key dim 0 are %u, %ld respectively, they must be same.", + bSize_, opParamInfo_.key.shape->GetStorageShape().GetDim(0)), + return ge::GRAPH_FAILED); + OPS_ERR_IF((opParamInfo_.actualSeqLengths.tensor != nullptr) && + (opParamInfo_.actualSeqLengths.tensor->GetShapeSize() != bSize_), + OPS_LOG_E(opName_, "BSND case input query, actual_seq_lengths_key dim 0 are %u, %ld respectively, they must be same.", + bSize_, opParamInfo_.actualSeqLengths.tensor->GetShapeSize()), + return ge::GRAPH_FAILED); + } + OPS_ERR_IF((opParamInfo_.weights.shape->GetStorageShape().GetDim(0) != bSize_) || + (opParamInfo_.attenOut.shape->GetStorageShape().GetDim(0) != bSize_), + OPS_LOG_E(opName_, "BSND case input query, weight, sparse_indices dim 0 are %u, %ld, %ld respectively, they must be same.", + bSize_, opParamInfo_.weights.shape->GetStorageShape().GetDim(0), + opParamInfo_.attenOut.shape->GetStorageShape().GetDim(0)), + return ge::GRAPH_FAILED); + OPS_ERR_IF((opParamInfo_.actualSeqLengthsQ.tensor != nullptr) && + (opParamInfo_.actualSeqLengthsQ.tensor->GetShapeSize() != bSize_), + OPS_LOG_E(opName_, "BSND case input query, actual_seq_lengths_query dim 0 are %u, %ld respectively, they must be same.", + bSize_, opParamInfo_.actualSeqLengthsQ.tensor->GetShapeSize()), + return ge::GRAPH_FAILED); + // -----------------------check S1------------------- + OPS_ERR_IF((opParamInfo_.weights.shape->GetStorageShape().GetDim(1) != s1Size_) || + (opParamInfo_.attenOut.shape->GetStorageShape().GetDim(1) != s1Size_), + OPS_LOG_E(opName_, "BSND case input query, weight, sparse_indices dim 1 are %u, %ld, %ld, they must be same.", + s1Size_, 
opParamInfo_.weights.shape->GetStorageShape().GetDim(1), + opParamInfo_.attenOut.shape->GetStorageShape().GetDim(1)), + return ge::GRAPH_FAILED); + return ge::GRAPH_SUCCESS; +} + +ge::graphStatus LIInfoParser::ValidateInputShapesMatch() +{ + uint32_t queryWeightsN1Dim = 1; + uint32_t outN2Dim = 1; + if (qLayout_ == DataLayout::TND) { + if (ValidateInputShapesMatchQTnd() != ge::GRAPH_SUCCESS) { + return ge::GRAPH_FAILED; + } + } else { + if (ValidateInputShapesMatchQBsnd() != ge::GRAPH_SUCCESS) { + return ge::GRAPH_FAILED; + } + queryWeightsN1Dim = DIM_IDX_TWO; + outN2Dim = DIM_IDX_TWO; + } + // -----------------------check N1------------------- + OPS_ERR_IF((opParamInfo_.weights.shape->GetStorageShape().GetDim(queryWeightsN1Dim) != n1Size_), + OPS_LOG_E(opName_, "input query, weight shape dim N1 must be same."), return ge::GRAPH_FAILED); + // -----------------------check D------------------- + uint32_t keyDDim = kLayout_ == DataLayout::TND ? DIM_IDX_TWO : DIM_IDX_THREE; + OPS_ERR_IF((opParamInfo_.key.shape->GetStorageShape().GetDim(keyDDim) != headDim_), + OPS_LOG_E(opName_, "input query, key shape last dim must be same."), return ge::GRAPH_FAILED); + // -----------------------check N2------------------- + OPS_ERR_IF((opParamInfo_.attenOut.shape->GetStorageShape().GetDim(outN2Dim) != n2Size_), + OPS_LOG_E(opName_, "input query and output sparse_indices shape n2 dim must be same."), + return ge::GRAPH_FAILED); + // -----------------------check sparse_count------------------- + OPS_ERR_IF((opParamInfo_.attenOut.shape->GetStorageShape().GetDim(outN2Dim + 1) != *opParamInfo_.sparseCount), + OPS_LOG_E(opName_, "output sparse_indices shape last dim must be same as attr sparse_count."), + return ge::GRAPH_FAILED); + + return ge::GRAPH_SUCCESS; +} + +void LIInfoParser::GenerateInfo(LITilingInfo &liInfo) +{ + liInfo.opName = opName_; + liInfo.platformInfo = platformInfo_; + liInfo.opParamInfo = opParamInfo_; + liInfo.socVersion = socVersion_; + + liInfo.bSize = bSize_; + 
liInfo.n1Size = n1Size_; + liInfo.n2Size = n2Size_; + liInfo.s1Size = s1Size_; + liInfo.s2Size = s2Size_; + liInfo.gSize = gSize_; + + liInfo.inputQType = inputQType_; + liInfo.inputKType = inputKType_; + liInfo.outputType = outputType_; + + liInfo.blockSize = blockSize_; + liInfo.maxBlockNumPerBatch = maxBlockNumPerBatch_; + + std::string layOutKeyStr(opParamInfo_.layOutKey); + liInfo.pageAttentionFlag = layOutKeyStr == "PA_BSND" ? true : false; + liInfo.sparseMode = *opParamInfo_.sparseMode; + liInfo.sparseCount = *opParamInfo_.sparseCount; + + liInfo.inputQLayout = qLayout_; + liInfo.inputKLayout = kLayout_; +} + +ge::graphStatus LIInfoParser::ParseAndCheck(LITilingInfo &liInfo) +{ + if (ge::GRAPH_SUCCESS != GetOpName() || ge::GRAPH_SUCCESS != GetNpuInfo() || ge::GRAPH_SUCCESS != GetOpParaInfo() || + ge::GRAPH_SUCCESS != CheckRequiredParaExistence()) { + return ge::GRAPH_FAILED; + } + + if (ge::GRAPH_SUCCESS != GetAndCheckInOutDataType() || ge::GRAPH_SUCCESS != GetQueryKeyAndOutLayout() || + ge::GRAPH_SUCCESS != GetAndCheckOptionalInput()) { + return ge::GRAPH_FAILED; + } + + if (ge::GRAPH_SUCCESS != CheckShapeDim() || ge::GRAPH_SUCCESS != GetN1Size() || + ge::GRAPH_SUCCESS != GetAndCheckN2Size() || ge::GRAPH_SUCCESS != GetGSize()) { + return ge::GRAPH_FAILED; + } + + if (ge::GRAPH_SUCCESS != GetBatchSize() || ge::GRAPH_SUCCESS != GetS1Size() || ge::GRAPH_SUCCESS != GetHeadDim() || + ge::GRAPH_SUCCESS != GetS2Size()) { + return ge::GRAPH_FAILED; + } + if (ge::GRAPH_SUCCESS != ValidateInputShapesMatch()) { + return ge::GRAPH_FAILED; + } + + GenerateInfo(liInfo); + + return ge::GRAPH_SUCCESS; +} + +static ge::graphStatus TilingPrepareForLightningIndexer(gert::TilingParseContext * /* context */) +{ + return ge::GRAPH_SUCCESS; +} + +ge::graphStatus LightningIndexerTiling::DoTiling(LITilingInfo *tilingInfo) +{ + // -------------set blockdim----------------- + auto ascendcPlatform = platform_ascendc::PlatformAscendC(tilingInfo->platformInfo); + uint32_t aivNum = 
ascendcPlatform.GetCoreNumAiv();
+    uint32_t aicNum = ascendcPlatform.GetCoreNumAic();
+    uint32_t blockDim = ascendcPlatform.CalcTschBlockDim(aivNum, aicNum, aivNum);
+    context_->SetBlockDim(blockDim);
+
+    // -------------set workspacesize-----------------
+    constexpr uint32_t MM1_RES_ELEM_SIZE = 4;
+    constexpr uint32_t DOUBLE_BUFFER = 2;
+    constexpr uint32_t M_BASE_SIZE = 512;
+    constexpr uint32_t S2_BASE_SIZE = 512;
+    constexpr uint32_t V1_RES_ELEM_SIZE = 4;
+    constexpr uint32_t V1_RES_ELEM_TYPE = 2;
+    constexpr uint32_t V1_DECODE_PARAM_ELEM_SIZE = 8;
+    constexpr uint32_t V1_DECODE_PARAM_NUM = 16;
+    constexpr uint32_t V1_DECODE_DATA_NUM = 2;
+    constexpr uint32_t S1_BASE_SIZE = 8;
+    constexpr uint32_t TOPK_MAX_SIZE = 2048;
+    uint32_t workspaceSize = ascendcPlatform.GetLibApiWorkSpaceSize();
+    uint32_t mm1ResSize = M_BASE_SIZE * S2_BASE_SIZE;
+    // mm1 result double-buffer per AIC, plus vector-stage result and decode-parameter regions.
+    workspaceSize += mm1ResSize * MM1_RES_ELEM_SIZE * DOUBLE_BUFFER * aicNum;
+    workspaceSize += V1_DECODE_DATA_NUM * S1_BASE_SIZE * V1_RES_ELEM_TYPE * TOPK_MAX_SIZE * V1_RES_ELEM_SIZE * aicNum;
+    workspaceSize += V1_DECODE_DATA_NUM * S1_BASE_SIZE * V1_DECODE_PARAM_NUM * V1_DECODE_PARAM_ELEM_SIZE * aicNum;
+    size_t *workSpaces = context_->GetWorkspaceSizes(1);
+    workSpaces[0] = workspaceSize;
+
+    // -------------set tilingdata-----------------
+    tilingData_.set_bSize(tilingInfo->bSize);
+    // Fix: LITilingData declares an n2Size field but it was never populated before SaveToBuffer,
+    // leaving an unset value in the serialized tiling buffer.
+    tilingData_.set_n2Size(tilingInfo->n2Size);
+    tilingData_.set_s2Size(tilingInfo->s2Size);
+    tilingData_.set_s1Size(tilingInfo->s1Size);
+    tilingData_.set_sparseCount(tilingInfo->sparseCount);
+    tilingData_.set_gSize(tilingInfo->gSize);
+    tilingData_.set_blockSize(tilingInfo->blockSize);
+    tilingData_.set_maxBlockNumPerBatch(tilingInfo->maxBlockNumPerBatch);
+    tilingData_.set_sparseMode(tilingInfo->sparseMode);
+    tilingData_.set_usedCoreNum(blockDim);
+    tilingData_.SaveToBuffer(context_->GetRawTilingData()->GetData(), context_->GetRawTilingData()->GetCapacity());
+    context_->GetRawTilingData()->SetDataSize(tilingData_.GetDataSize());
+
+    // -------------set tilingkey-----------------
+    //
DT_Q, DT_KV, DT_OUT, PAGE_ATTENTION, FLASH_DECODE, LAYOUT_T, KV_LAYOUT_T + uint32_t inputQType = static_cast(tilingInfo->inputQType); + uint32_t inputKType = static_cast(tilingInfo->inputKType); + uint32_t outputType = static_cast(tilingInfo->outputType); + uint32_t pageAttentionFlag = static_cast(tilingInfo->pageAttentionFlag); + uint32_t inputQLayout = static_cast(tilingInfo->inputQLayout); + uint32_t inputKLayout = static_cast(tilingInfo->inputKLayout); + uint32_t tilingKey = + GET_TPL_TILING_KEY(inputQType, inputKType, outputType, pageAttentionFlag, inputQLayout, inputKLayout); + context_->SetTilingKey(tilingKey); + + return ge::GRAPH_SUCCESS; +} + +ge::graphStatus TilingForLightningIndexer(gert::TilingContext *context) +{ + OPS_ERR_IF(context == nullptr, OPS_REPORT_VECTOR_INNER_ERR("LightningIndexer", "Tiling context is null."), + return ge::GRAPH_FAILED); + LITilingInfo liInfo; + LIInfoParser LIInfoParser(context); + if (LIInfoParser.ParseAndCheck(liInfo) != ge::GRAPH_SUCCESS) { + return ge::GRAPH_FAILED; + } + LightningIndexerTiling liTiling(context); + return liTiling.DoTiling(&liInfo); +} + +IMPL_OP_OPTILING(LightningIndexer) + .Tiling(TilingForLightningIndexer) + .TilingParse(TilingPrepareForLightningIndexer); + +} // namespace optiling diff --git a/csrc/lightning_indexer/op_host/lightning_indexer_tiling.h b/csrc/lightning_indexer/op_host/lightning_indexer_tiling.h new file mode 100644 index 00000000..fb7ce43d --- /dev/null +++ b/csrc/lightning_indexer/op_host/lightning_indexer_tiling.h @@ -0,0 +1,215 @@ +/** + * This program is free software, you can redistribute it and/or modify it. + * Copyright (c) 2025 Huawei Technologies Co., Ltd. + * This file is a part of the CANN Open Software. + * Licensed under CANN Open Software License Agreement Version 2.0 (the "License"). + * Please refer to the License for details. You may not use this file except in compliance with the License. 
+ * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. + * See LICENSE in the root of the software repository for the full text of the License. + */ + +/*! + * \file lightning_indexer_tiling.h + * \brief + */ + +#ifndef LIGHTNING_INDEXER_TILING_H_ +#define LIGHTNING_INDEXER_TILING_H_ + +#include "exe_graph/runtime/tiling_context.h" +#include "tiling/platform/platform_ascendc.h" +#include "register/op_def_registry.h" +#include "register/tilingdata_base.h" +#include "tiling/tiling_api.h" +#include "error/ops_error.h" +#include "platform/platform_info.h" + +namespace optiling { + +struct TilingRequiredParaInfo { + const gert::CompileTimeTensorDesc *desc; + const gert::StorageShape *shape; +}; + +struct TilingOptionalParaInfo { + const gert::CompileTimeTensorDesc *desc; + const gert::Tensor *tensor; +}; + +enum class DataLayout : uint32_t { + BSND = 0, + TND = 1, + BnBsND = 2 +}; + +// Inputs Index +constexpr uint32_t QUERY_INDEX = 0; +constexpr uint32_t KEY_INDEX = 1; +constexpr uint32_t WEIGTHS_INDEX = 2; +constexpr uint32_t ACTUAL_SEQ_Q_INDEX = 3; +constexpr uint32_t ACTUAL_SEQ_K_INDEX = 4; +constexpr uint32_t BLOCK_TABLE_INDEX = 5; +constexpr uint32_t LIGHTNING_INDEXER = 0; +// Attributes Index +constexpr uint32_t ATTR_QUERY_LAYOUT_INDEX = 0; +constexpr uint32_t ATTR_KEY_LAYOUT_INDEX = 1; +constexpr uint32_t ATTR_SPARSE_COUNT_INDEX = 2; +constexpr uint32_t ATTR_SPARSE_MODE_INDEX = 3; +// Dim Index +constexpr uint32_t DIM_IDX_ONE = 1; +constexpr uint32_t DIM_IDX_TWO = 2; +constexpr uint32_t DIM_IDX_THREE = 3; +// Dim Num +constexpr uint32_t DIM_NUM_TWO = 2; +constexpr uint32_t DIM_NUM_THREE = 3; +constexpr uint32_t DIM_NUM_FOUR = 4; +// Input Parameter Limit Constant +constexpr uint32_t HEAD_DIM_LIMIT = 128; +constexpr uint32_t SPARSE_LIMIT = 2048; +constexpr uint32_t SPARSE_MODE_LOWER = 3; + 
+BEGIN_TILING_DATA_DEF(LITilingData) +TILING_DATA_FIELD_DEF(uint32_t, bSize) +TILING_DATA_FIELD_DEF(uint32_t, n2Size) +TILING_DATA_FIELD_DEF(uint32_t, gSize) +TILING_DATA_FIELD_DEF(uint32_t, s1Size) +TILING_DATA_FIELD_DEF(uint32_t, s2Size) +TILING_DATA_FIELD_DEF(uint32_t, sparseCount) +TILING_DATA_FIELD_DEF(uint32_t, usedCoreNum) +TILING_DATA_FIELD_DEF(uint32_t, blockSize) +TILING_DATA_FIELD_DEF(uint32_t, maxBlockNumPerBatch) +TILING_DATA_FIELD_DEF(uint32_t, sparseMode) +END_TILING_DATA_DEF +REGISTER_TILING_DATA_CLASS(LightningIndexer, LITilingData) + +struct LICompileInfo {}; + +struct LiParaInfo { + TilingRequiredParaInfo query = {nullptr, nullptr}; + TilingRequiredParaInfo key = {nullptr, nullptr}; + TilingRequiredParaInfo weights = {nullptr, nullptr}; + TilingOptionalParaInfo actualSeqLengthsQ = {nullptr, nullptr}; + TilingOptionalParaInfo actualSeqLengths = {nullptr, nullptr}; + TilingOptionalParaInfo blockTable = {nullptr, nullptr}; + TilingRequiredParaInfo attenOut = {nullptr, nullptr}; + + const char *layOut = nullptr; + const char *layOutKey = nullptr; + const int32_t *blockSize = nullptr; + const int32_t *sparseMode = nullptr; + const int32_t *sparseCount = nullptr; +}; + +class LITilingInfo { +public: + const char *opName = nullptr; + fe::PlatFormInfos *platformInfo = nullptr; + LiParaInfo opParamInfo; + // Base Param + platform_ascendc::SocVersion socVersion = platform_ascendc::SocVersion::ASCEND910B; + uint32_t bSize = 0; + uint32_t n1Size = 0; + uint32_t n2Size = 0; + uint32_t s1Size = 0; + int64_t s2Size = 0; + uint32_t qkHeadDim = 0; + uint32_t gSize = 0; + // PageAttention + bool pageAttentionFlag = false; + int32_t blockSize = 0; + uint32_t maxBlockNumPerBatch = 0; + // Mask + int32_t sparseMode = 0; + // Others Flag + uint32_t sparseCount = 0; + // DType + ge::DataType inputQType = ge::DT_FLOAT16; + ge::DataType inputKType = ge::DT_FLOAT16; + ge::DataType outputType = ge::DT_INT32; + // Layout + DataLayout inputQLayout = DataLayout::BSND; + 
DataLayout inputKLayout = DataLayout::BnBsND; +}; + +class LIInfoParser { +public: + explicit LIInfoParser(gert::TilingContext *context) : context_(context) + { + } + ~LIInfoParser() = default; + + ge::graphStatus CheckRequiredInOutExistence() const; + ge::graphStatus CheckRequiredAttrExistence() const; + ge::graphStatus CheckRequiredParaExistence() const; + ge::graphStatus GetActualSeqLenSize(uint32_t &size, const gert::Tensor *tensor, + const std::string &actualSeqLenName); + ge::graphStatus GetOpName(); + ge::graphStatus GetNpuInfo(); + void GetOptionalInputParaInfo(); + void GetInputParaInfo(); + void GetOutputParaInfo(); + ge::graphStatus GetAndCheckAttrParaInfo(); + ge::graphStatus GetOpParaInfo(); + ge::graphStatus ValidateInputShapesMatchQBsnd(); + ge::graphStatus ValidateInputShapesMatchQTnd(); + ge::graphStatus ValidateInputShapesMatch(); + ge::graphStatus GetAndCheckInOutDataType(); + ge::graphStatus GetBatchSize(); + ge::graphStatus GetHeadDim(); + ge::graphStatus GetS1Size(); + ge::graphStatus GetAndCheckOptionalInput(); + ge::graphStatus CheckShapeDim(); + ge::graphStatus GetAndCheckBlockSize(); + ge::graphStatus CheckBlockCount(); + ge::graphStatus GetS2SizeForPageAttention(); + ge::graphStatus GetS2Size(); + ge::graphStatus GetQueryKeyAndOutLayout(); + ge::graphStatus GetN1Size(); + ge::graphStatus GetAndCheckN2Size(); + ge::graphStatus GetGSize(); + ge::graphStatus GetAttenMaskInfo(); + ge::graphStatus GetActualSeqInfo(); + void GenerateInfo(LITilingInfo &liInfo); + ge::graphStatus ParseAndCheck(LITilingInfo &liInfo); + +public: + gert::TilingContext *context_ = nullptr; + const char *opName_; + fe::PlatFormInfos *platformInfo_; + LiParaInfo opParamInfo_; + + // BaseParams + uint32_t bSize_ = 0; + uint32_t n1Size_ = 0; + uint32_t n2Size_ = 0; + uint32_t gSize_ = 0; + uint32_t s1Size_ = 0; + int64_t s2Size_ = 0; + uint32_t headDim_ = 0; + // Layout + DataLayout qLayout_ = DataLayout::BSND; + DataLayout kLayout_ = DataLayout::BnBsND; + // 
PageAttention + uint32_t maxBlockNumPerBatch_ = 0; + int32_t blockSize_ = 0; + platform_ascendc::SocVersion socVersion_ = platform_ascendc::SocVersion::ASCEND910B; + ge::DataType inputQType_ = ge::DT_FLOAT16; + ge::DataType inputKType_ = ge::DT_FLOAT16; + ge::DataType weightsType_ = ge::DT_FLOAT16; + ge::DataType blockTableType_ = ge::DT_FLOAT16; + ge::DataType inputKRopeType_ = ge::DT_FLOAT16; + ge::DataType outputType_ = ge::DT_FLOAT16; +}; + +class LightningIndexerTiling { +public: + explicit LightningIndexerTiling(gert::TilingContext *context) : context_(context){}; + ge::graphStatus DoTiling(LITilingInfo *tilingInfo); + +private: + gert::TilingContext *context_ = nullptr; + LITilingData tilingData_; +}; + +} // namespace optiling +#endif // LIGHTNING_INDEXER_TILING_H_ \ No newline at end of file diff --git a/csrc/lightning_indexer/op_kernel/lightning_indexer.cpp b/csrc/lightning_indexer/op_kernel/lightning_indexer.cpp new file mode 100644 index 00000000..fefa72e6 --- /dev/null +++ b/csrc/lightning_indexer/op_kernel/lightning_indexer.cpp @@ -0,0 +1,58 @@ +/** + * This program is free software, you can redistribute it and/or modify it. + * Copyright (c) 2025 Huawei Technologies Co., Ltd. + * This file is a part of the CANN Open Software. + * Licensed under CANN Open Software License Agreement Version 2.0 (the "License"). + * Please refer to the License for details. You may not use this file except in compliance with the License. + * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. + * See LICENSE in the root of the software repository for the full text of the License. + */ + +/*! 
+ * \file lightning_indexer.cpp + * \brief + */ + +#include "kernel_operator.h" +#include "lib/matmul_intf.h" +#include "lightning_indexer_template_tiling_key.h" +#include "lightning_indexer_kernel.h" + +using namespace LIKernel; + +#define INVOKE_LI_NO_KFC_OP_IMPL(templateClass, ...) \ + do { \ + templateClass> op; \ + LI_COPY_TILING_DATA(LITilingData, tiling); \ + op.Init(query, key, weights, actualSeqLengthsQ, actualSeqLengths, blocktable, sparseIndices, user, \ + tiling_data, &tPipe); \ + op.Process(); \ + } while (0) + +#define LI_COPY_TILING_DATA(tilingDataStruct, tiling) \ + GET_TILING_DATA_WITH_STRUCT(tilingDataStruct, tiling_data_in, tiling); \ + const tilingDataStruct *__restrict tiling_data = &tiling_data_in; + + +template +__global__ __aicore__ void lightning_indexer(__gm__ uint8_t *query, __gm__ uint8_t *key, __gm__ uint8_t *weights, + __gm__ uint8_t *actualSeqLengthsQ, __gm__ uint8_t *actualSeqLengths, + __gm__ uint8_t *blocktable, __gm__ uint8_t *sparseIndices, + __gm__ uint8_t *workspace, __gm__ uint8_t *tiling) +{ +#if (__CCE_AICORE__ == 310) || (defined __DAV_310R6__) || (__CCE_AICORE__ == 200) + +#else + TPipe tPipe; + __gm__ uint8_t *user = GetUserWorkspace(workspace); + KERNEL_TASK_TYPE_DEFAULT(KERNEL_TYPE_MIX_AIC_1_2); + + if constexpr (DT_Q == LI_TPL_FP16 && DT_K == LI_TPL_FP16 && DT_OUT == LI_TPL_INT32) { + INVOKE_LI_NO_KFC_OP_IMPL(LIPreload, half, half, int32_t, PAGE_ATTENTION, + LI_LAYOUT(LAYOUT_T), LI_LAYOUT(K_LAYOUT_T)); + } else { + INVOKE_LI_NO_KFC_OP_IMPL(LIPreload, bfloat16_t, bfloat16_t, int32_t, PAGE_ATTENTION, + LI_LAYOUT(LAYOUT_T), LI_LAYOUT(K_LAYOUT_T)); + } +#endif +} diff --git a/csrc/lightning_indexer/op_kernel/lightning_indexer_common.h b/csrc/lightning_indexer/op_kernel/lightning_indexer_common.h new file mode 100644 index 00000000..4c693140 --- /dev/null +++ b/csrc/lightning_indexer/op_kernel/lightning_indexer_common.h @@ -0,0 +1,135 @@ +/** + * This program is free software, you can redistribute it and/or modify it. 
+ * Copyright (c) 2025 Huawei Technologies Co., Ltd. + * This file is a part of the CANN Open Software. + * Licensed under CANN Open Software License Agreement Version 2.0 (the "License"). + * Please refer to the License for details. You may not use this file except in compliance with the License. + * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. + * See LICENSE in the root of the software repository for the full text of the License. + */ + +/*! + * \file lightning_indexer_common.h + * \brief + */ +#ifndef LIGHTNING_INDEXER_COMMON_H +#define LIGHTNING_INDEXER_COMMON_H + +namespace LICommon { +enum class LI_LAYOUT { + BSND = 0, + TND = 1, + PA_BSND = 2 +}; + +template +struct LIType { + using queryType = Q_T; + using keyType = K_T; + using outputType = OUT_T; + static constexpr bool pageAttention = PAGE_ATTENTION; + static constexpr LI_LAYOUT layout = LAYOUT_T; + static constexpr LI_LAYOUT keyLayout = K_LAYOUT_T; +}; + +struct RunInfo { + uint32_t loop; + uint32_t bN2Idx; + uint32_t bIdx; + uint32_t n2Idx = 0; + uint32_t gS1Idx; + uint32_t s2Idx; + + uint32_t actS1Size = 1; + uint32_t actS2Size = 1; + uint32_t actMBaseSize; + uint32_t actualSingleProcessSInnerSize; + uint32_t actualSingleProcessSInnerSizeAlign; + + uint64_t tensorQueryOffset; + uint64_t tensorKeyOffset; + uint64_t tensorWeightsOffset; + uint64_t indiceOutOffset; + + bool isFirstS2InnerLoop; + bool isLastS2InnerLoop; + bool isAllLoopEnd = false; +}; + +struct ConstInfo { + static constexpr uint32_t FIA_SYNC_MODE2 = 2; + static constexpr uint32_t BUFFER_SIZE_BYTE_32B = 32; + static constexpr uint32_t BUFFER_SIZE_BYTE_64B = 64; + static constexpr uint32_t BUFFER_SIZE_BYTE_256B = 256; + static constexpr uint32_t BUFFER_SIZE_BYTE_512B = 512; + static constexpr uint32_t BUFFER_SIZE_BYTE_1K = 1024; + static constexpr uint32_t BUFFER_SIZE_BYTE_2K = 
2048; + static constexpr uint32_t BUFFER_SIZE_BYTE_4K = 4096; + static constexpr uint32_t BUFFER_SIZE_BYTE_8K = 8192; + static constexpr uint32_t BUFFER_SIZE_BYTE_16K = 16384; + static constexpr uint32_t BUFFER_SIZE_BYTE_32K = 32768; + static constexpr int INVALID_IDX = -1; + + uint32_t syncC1V1 = 0U; + uint32_t syncV1C1 = 0U; + + uint32_t mBaseSize = 1ULL; + uint32_t s1BaseSize = 1ULL; + uint32_t s2BaseSize = 1ULL; + + uint64_t batchSize = 0ULL; + uint64_t gSize = 0ULL; + uint64_t qHeadNum = 0ULL; + uint64_t kHeadNum; + uint64_t headDim; + uint64_t sparseCount; + uint64_t kSeqSize = 0ULL; + uint64_t qSeqSize = 1ULL; + uint32_t kCacheBlockSize = 0; + uint32_t maxBlockNumPerBatch = 0; + LI_LAYOUT outputLayout; + bool attenMaskFlag = false; + + uint32_t actualLenQDims = 0U; + uint32_t actualLenDims = 0U; + bool isAccumSeqS1 = false; + bool isAccumSeqS2 = false; +}; + +struct SplitCoreInfo { + uint32_t s2Start = 0U; + uint32_t s2End = 0U; + uint32_t bN2Start = 0U; + uint32_t bN2End = 0U; + uint32_t gS1Start = 0U; + uint32_t gS1End = 0U; + bool isLD = false; +}; + +template +__aicore__ inline T Align(T num, T rnd) +{ + return (((rnd) == 0) ? 0 : (((num) + (rnd)-1) / (rnd) * (rnd))); +} + +template +__aicore__ inline T1 Min(T1 a, T2 b) +{ + return (a > b) ? (b) : (a); +} + +template +__aicore__ inline T1 Max(T1 a, T2 b) +{ + return (a > b) ? (a) : (b); +} + +template +__aicore__ inline T CeilDiv(T num, T rnd) +{ + return (((rnd) == 0) ? 0 : (((num) + (rnd)-1) / (rnd))); +} +} // namespace LICommon + +#endif // LIGHTNING_INDEXER_COMMON_H \ No newline at end of file diff --git a/csrc/lightning_indexer/op_kernel/lightning_indexer_kernel.h b/csrc/lightning_indexer/op_kernel/lightning_indexer_kernel.h new file mode 100644 index 00000000..14ef6978 --- /dev/null +++ b/csrc/lightning_indexer/op_kernel/lightning_indexer_kernel.h @@ -0,0 +1,623 @@ +/** + * This program is free software, you can redistribute it and/or modify it. + * Copyright (c) 2025 Huawei Technologies Co., Ltd. 
+ * This file is a part of the CANN Open Software. + * Licensed under CANN Open Software License Agreement Version 2.0 (the "License"). + * Please refer to the License for details. You may not use this file except in compliance with the License. + * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. + * See LICENSE in the root of the software repository for the full text of the License. + */ + +/*! + * \file lightning_indexer_kernel.h + * \brief + */ + +#ifndef LIGHTNING_INDEXER_KERNEL_H +#define LIGHTNING_INDEXER_KERNEL_H + +#include "kernel_operator.h" +#include "kernel_operator_list_tensor_intf.h" +#include "kernel_tiling/kernel_tiling.h" +#include "lib/matmul_intf.h" +#include "lib/matrix/matmul/tiling.h" +#include "lightning_indexer_common.h" +#include "lightning_indexer_service_vector.h" +#include "lightning_indexer_service_cube.h" + +namespace LIKernel { +using namespace LICommon; +using namespace LIServiceVec; +using namespace matmul; +using AscendC::CacheMode; +using AscendC::CrossCoreSetFlag; +using AscendC::CrossCoreWaitFlag; + +struct TempLoopInfo { + uint32_t bN2Idx = 0; + uint32_t bIdx = 0U; + uint32_t n2Idx = 0U; + uint32_t gS1Idx = 0U; + uint32_t gS1LoopEnd = 0U; + uint32_t s2LoopEnd = 0U; + uint32_t actS1Size = 1ULL; + uint32_t actS2Size = 0ULL; + bool curActSeqLenIsZero = false; + bool needDealActS1LessThanS1 = false; + uint32_t actMBaseSize = 0U; + uint32_t mBasicSizeTail = 0U; + uint32_t s2BasicSizeTail = 0U; +}; + +template +class LIPreload { +public: + __aicore__ inline LIPreload(){}; + __aicore__ inline void Init(__gm__ uint8_t *query, __gm__ uint8_t *key, __gm__ uint8_t *weights, + __gm__ uint8_t *actualSeqLengthsQ, __gm__ uint8_t *actualSeqLengths, + __gm__ uint8_t *blockTable, __gm__ uint8_t *sparseIndices, __gm__ uint8_t *workspace, + const LITilingData *__restrict tiling, TPipe 
*tPipe); + __aicore__ inline void Process(); + + using Q_T = typename LIT::queryType; + using K_T = typename LIT::keyType; + using OUT_T = typename LIT::outputType; + static constexpr bool PAGE_ATTENTION = LIT::pageAttention; + static constexpr LI_LAYOUT LAYOUT_T = LIT::layout; + static constexpr LI_LAYOUT K_LAYOUT_T = LIT::keyLayout; + + using MM1_OUT_T = float; + + LIMatmul matmulService; + LIVector vectorService; + + static constexpr uint32_t SYNC_C1_V1_FLAG = 4; + static constexpr uint32_t SYNC_V1_C1_FLAG = 5; + + static constexpr uint32_t M_BASE_SIZE = 512; + static constexpr uint32_t S2_BASE_SIZE = 512; + static constexpr uint32_t HEAD_DIM = 128; + static constexpr uint32_t K_HEAD_NUM = 1; + static constexpr uint32_t GM_ALIGN_BYTES = 512; + + static constexpr int64_t LD_PREFETCH_LEN = 2; + // for workspace double + static constexpr uint32_t WS_DOBULE = 2; + +protected: + TPipe *pipe = nullptr; + + // offset + uint64_t queryCoreOffset = 0ULL; + uint64_t keyCoreOffset = 0ULL; + uint64_t weightsCoreOffset = 0ULL; + uint64_t indiceOutCoreOffset = 0ULL; + + GlobalTensor queryGm; + GlobalTensor keyGm; + GlobalTensor weightsGm; + + GlobalTensor indiceOutGm; + GlobalTensor blockTableGm; + + GlobalTensor actualSeqLengthsGmQ; + GlobalTensor actualSeqLengthsGm; + // workspace + GlobalTensor mm1ResGm; + GlobalTensor vec1ResGm; + GlobalTensor vec1ParamGm; + + // aic、aiv kernel info + uint32_t tmpBlockIdx = 0U; + uint32_t aiCoreIdx = 0U; + uint32_t usedCoreNum = 0U; + + LICommon::ConstInfo constInfo{}; + TempLoopInfo tempLoopInfo{}; + LICommon::SplitCoreInfo splitCoreInfo{}; + + // ================================Init functions================================== + __aicore__ inline void InitTilingData(const LITilingData *__restrict tilingData); + __aicore__ inline void InitBuffers(); + __aicore__ inline void InitActualSeqLen(__gm__ uint8_t *actualSeqLengthsQ, __gm__ uint8_t *actualSeqLengths); + // ================================Split Core================================ + 
__aicore__ inline void SplitCore(uint32_t curCoreIdx, uint32_t &coreNum, LICommon::SplitCoreInfo &info); + __aicore__ inline uint32_t GetS2BaseBlockNumOnMask(uint32_t s1gIdx, uint32_t actS1Size, uint32_t actS2Size); + __aicore__ inline uint32_t GetTotalBaseBlockNum(); + // ================================Process functions================================ + __aicore__ inline void ProcessMain(); + __aicore__ inline void ProcessBaseBlock(uint32_t loop, uint64_t s2LoopIdx, LICommon::RunInfo &runInfo); + __aicore__ inline void ProcessDecode(); + __aicore__ inline void ProcessInvalid(); + // ================================Params Calc===================================== + __aicore__ inline void CalcGS1LoopParams(uint32_t bN2Idx); + __aicore__ inline void GetBN2Idx(uint32_t bN2Idx); + __aicore__ inline uint32_t GetActualSeqLen(uint32_t bIdx, uint32_t actualLenDims, bool isAccumSeq, + GlobalTensor &actualSeqLengthsGm, uint32_t defaultSeqLen); + __aicore__ inline void GetS1S2ActualSeqLen(uint32_t bIdx, uint32_t &actS1Size, uint32_t &actS2Size); + __aicore__ inline void CalcS2LoopParams(uint32_t bN2LoopIdx, uint32_t gS1LoopIdx); + __aicore__ inline void CalcRunInfo(uint32_t loop, uint32_t s2LoopIdx, LICommon::RunInfo &runInfo); + __aicore__ inline void DealActSeqLenIsZero(uint32_t bIdx, uint32_t n2Idx, uint32_t s1Start); +}; + +template +__aicore__ inline void LIPreload::InitTilingData(const LITilingData *__restrict tilingData) +{ + usedCoreNum = tilingData->usedCoreNum; + constInfo.batchSize = tilingData->bSize; + constInfo.qHeadNum = constInfo.gSize = tilingData->gSize; + constInfo.kSeqSize = tilingData->s2Size; + constInfo.qSeqSize = tilingData->s1Size; + constInfo.attenMaskFlag = (tilingData->sparseMode == 3); + constInfo.kCacheBlockSize = tilingData->blockSize; + constInfo.maxBlockNumPerBatch = tilingData->maxBlockNumPerBatch; + constInfo.sparseCount = tilingData->sparseCount; + constInfo.outputLayout = LAYOUT_T; + if (LAYOUT_T == LI_LAYOUT::TND) { + 
constInfo.isAccumSeqS1 = true; + } + if (K_LAYOUT_T == LI_LAYOUT::TND) { + constInfo.isAccumSeqS2 = true; + } + + constInfo.kHeadNum = K_HEAD_NUM; + constInfo.headDim = HEAD_DIM; + + constInfo.mBaseSize = M_BASE_SIZE; + constInfo.s2BaseSize = S2_BASE_SIZE; + constInfo.s1BaseSize = (constInfo.mBaseSize + constInfo.gSize - 1) / constInfo.gSize; +} + +template +__aicore__ inline void LIPreload::InitBuffers() +{ + if ASCEND_IS_AIV { + vectorService.InitBuffers(pipe); + } else { + matmulService.InitBuffers(pipe); + } +} + +template +__aicore__ inline void LIPreload::InitActualSeqLen(__gm__ uint8_t *actualSeqLengthsQ, + __gm__ uint8_t *actualSeqLengths) +{ + if (actualSeqLengthsQ == nullptr) { + constInfo.actualLenQDims = 0; + } else { + constInfo.actualLenQDims = constInfo.batchSize; + actualSeqLengthsGmQ.SetGlobalBuffer((__gm__ uint32_t *)actualSeqLengthsQ, constInfo.actualLenQDims); + } + if (actualSeqLengths == nullptr) { + constInfo.actualLenDims = 0; + } else { + constInfo.actualLenDims = constInfo.batchSize; + actualSeqLengthsGm.SetGlobalBuffer((__gm__ uint32_t *)actualSeqLengths, constInfo.actualLenDims); + } +} + +template +__aicore__ inline uint32_t LIPreload::GetActualSeqLen(uint32_t bIdx, uint32_t actualLenDims, bool isAccumSeq, + GlobalTensor &actualSeqLengthsGm, + uint32_t defaultSeqLen) +{ + if (actualLenDims == 0) { + return defaultSeqLen; + } else if (isAccumSeq && bIdx > 0) { + return actualSeqLengthsGm.GetValue(bIdx) - actualSeqLengthsGm.GetValue(bIdx - 1); + } else { + return actualSeqLengthsGm.GetValue(bIdx); + } +} + +template +__aicore__ inline void LIPreload::GetS1S2ActualSeqLen(uint32_t bIdx, uint32_t &actS1Size, uint32_t &actS2Size) +{ + actS1Size = GetActualSeqLen(bIdx, constInfo.actualLenQDims, constInfo.isAccumSeqS1, actualSeqLengthsGmQ, + constInfo.qSeqSize); + actS2Size = + GetActualSeqLen(bIdx, constInfo.actualLenDims, constInfo.isAccumSeqS2, actualSeqLengthsGm, constInfo.kSeqSize); +} + +template +__aicore__ inline uint32_t 
LIPreload::GetS2BaseBlockNumOnMask(uint32_t s1gIdx, uint32_t actS1Size, + uint32_t actS2Size) +{ + if (actS2Size == 0) { + return 0; + } + uint32_t s1Offset = constInfo.s1BaseSize * s1gIdx; + int32_t validS2LenBase = static_cast(actS2Size) - static_cast(actS1Size); + int32_t validS2Len = s1Offset + validS2LenBase + constInfo.s1BaseSize; + validS2Len = Min(validS2Len, static_cast(actS2Size)); + validS2Len = Max(validS2Len, 1); + return (validS2Len + constInfo.s2BaseSize - 1) / constInfo.s2BaseSize; +} + +template +__aicore__ inline uint32_t LIPreload::GetTotalBaseBlockNum() +{ + uint32_t totalBlockNum = 0; + uint32_t actS1Size, actS2Size; + uint32_t s1GBaseNum, s2BaseNum; + for (uint32_t bIdx = 0; bIdx < constInfo.batchSize; bIdx++) { + GetS1S2ActualSeqLen(bIdx, actS1Size, actS2Size); + s1GBaseNum = CeilDiv(actS1Size, constInfo.s1BaseSize); + if (!constInfo.attenMaskFlag) { + s2BaseNum = CeilDiv(actS2Size, constInfo.s2BaseSize); + totalBlockNum += s1GBaseNum * s2BaseNum * constInfo.kHeadNum; + continue; + } + for (uint32_t s1gIdx = 0; s1gIdx < s1GBaseNum; s1gIdx++) { + s2BaseNum = GetS2BaseBlockNumOnMask(s1gIdx, actS1Size, actS2Size); + totalBlockNum += s2BaseNum * constInfo.kHeadNum; + } + } + return totalBlockNum; +} + +template +__aicore__ void inline LIPreload::SplitCore(uint32_t curCoreIdx, uint32_t &coreNum, LICommon::SplitCoreInfo &info) +{ + uint32_t totalBlockNum = GetTotalBaseBlockNum(); + uint32_t minBlockPerCore = totalBlockNum / coreNum; + uint32_t deal1MoreBlockCoreNum = totalBlockNum % coreNum; + uint32_t coreIdx = 0; + uint32_t lastGS1RemainBlockCnt = 0; + uint32_t coreDealBlockCnt = coreIdx < deal1MoreBlockCoreNum ? minBlockPerCore + 1 : minBlockPerCore; + coreNum = minBlockPerCore == 0 ? 
deal1MoreBlockCoreNum : coreNum; + + bool findLastCoreEnd = true; + uint32_t actS1Size, actS2Size; + uint32_t s1GBaseNum, s2BaseNum; + for (uint32_t bN2Idx = 0; bN2Idx < constInfo.batchSize * constInfo.kHeadNum; bN2Idx++) { + uint32_t bIdx = bN2Idx / constInfo.kHeadNum; + if (bN2Idx % constInfo.kHeadNum == 0) { + GetS1S2ActualSeqLen(bIdx, actS1Size, actS2Size); + s1GBaseNum = CeilDiv(actS1Size, constInfo.s1BaseSize); + s2BaseNum = CeilDiv(actS2Size, constInfo.s2BaseSize); + } + if constexpr (LAYOUT_T == LI_LAYOUT::BSND) { + if (findLastCoreEnd && (s1GBaseNum == 0U || s2BaseNum == 0U)) { + info.bN2Start = bN2Idx; + info.gS1Start = 0; + info.s2Start = 0; + findLastCoreEnd = false; + } + } + for (uint32_t gS1Idx = 0; gS1Idx < s1GBaseNum; gS1Idx++) { + if (constInfo.attenMaskFlag) { + s2BaseNum = GetS2BaseBlockNumOnMask(gS1Idx, actS1Size, actS2Size); + } + if (findLastCoreEnd && s2BaseNum == 0U) { + info.bN2Start = bN2Idx; + info.gS1Start = gS1Idx; + info.s2Start = 0; + findLastCoreEnd = false; + } + for (uint32_t s2Idx = 0; s2Idx < s2BaseNum;) { + if (findLastCoreEnd) { + info.bN2Start = bN2Idx; + info.gS1Start = gS1Idx; + info.s2Start = s2Idx; + findLastCoreEnd = false; + } + uint32_t s2RemainBaseNum = s2BaseNum - s2Idx; + if (lastGS1RemainBlockCnt + s2RemainBaseNum >= coreDealBlockCnt) { + info.bN2End = bN2Idx; + info.gS1End = gS1Idx; + info.s2End = s2Idx + coreDealBlockCnt - lastGS1RemainBlockCnt - 1; + + if (coreIdx == curCoreIdx) { + if (s2Idx == 0 && info.s2End + 1 < s2BaseNum) { + info.isLD = true; + } + if (coreIdx == coreNum - 1 && info.bN2End != constInfo.batchSize -1) { + info.bN2End = constInfo.batchSize -1; + info.gS1End = 0; + info.s2End = 0; + } + return; + } + coreIdx++; + findLastCoreEnd = true; + s2Idx = info.s2End + 1; + lastGS1RemainBlockCnt = 0; + coreDealBlockCnt = coreIdx < deal1MoreBlockCoreNum ? 
minBlockPerCore + 1 : minBlockPerCore; + } else { + lastGS1RemainBlockCnt += s2RemainBaseNum; + break; + } + } + } + } +} + +template +__aicore__ inline void LIPreload::DealActSeqLenIsZero(uint32_t bIdx, uint32_t n2Idx, uint32_t s1Start) +{ + if ASCEND_IS_AIV { + if (constInfo.outputLayout == LI_LAYOUT::TND) { + uint32_t tSize = actualSeqLengthsGmQ.GetValue(constInfo.batchSize - 1); + uint32_t tBase = bIdx == 0 ? 0 : actualSeqLengthsGmQ.GetValue(bIdx - 1); + uint32_t s1Count = tempLoopInfo.actS1Size; + + for (uint32_t s1Idx = s1Start; s1Idx < s1Count; s1Idx++) { + uint64_t indiceOutOffset = + (tBase + s1Idx) * constInfo.kHeadNum * constInfo.sparseCount + + n2Idx * constInfo.sparseCount; + vectorService.CleanInvalidOutput(indiceOutOffset); + } + } else if (constInfo.outputLayout == LI_LAYOUT::BSND) { + for (uint32_t s1Idx = s1Start; s1Idx < constInfo.qSeqSize; s1Idx++) { + // B,S1,N2,K + uint64_t indiceOutOffset = bIdx * constInfo.qSeqSize * constInfo.kHeadNum * constInfo.sparseCount + + s1Idx * constInfo.kHeadNum * constInfo.sparseCount + + n2Idx * constInfo.sparseCount; + vectorService.CleanInvalidOutput(indiceOutOffset); + } + } + } +} + +template +__aicore__ inline void LIPreload::Init(__gm__ uint8_t *query, __gm__ uint8_t *key, __gm__ uint8_t *weights, + __gm__ uint8_t *actualSeqLengthsQ, __gm__ uint8_t *actualSeqLengths, + __gm__ uint8_t *blockTable, __gm__ uint8_t *sparseIndices, + __gm__ uint8_t *workspace, const LITilingData *__restrict tiling, + TPipe *tPipe) +{ + if ASCEND_IS_AIV { + tmpBlockIdx = GetBlockIdx(); // vec:0-47 + aiCoreIdx = tmpBlockIdx / 2; + } else { + tmpBlockIdx = GetBlockIdx(); // cube:0-23 + aiCoreIdx = tmpBlockIdx; + } + + InitTilingData(tiling); + InitActualSeqLen(actualSeqLengthsQ, actualSeqLengths); + + SplitCore(aiCoreIdx, usedCoreNum, splitCoreInfo); + + pipe = tPipe; + uint64_t offset = 0; + uint64_t singleCoreMm1ResSize = WS_DOBULE * constInfo.mBaseSize * constInfo.s2BaseSize * sizeof(MM1_OUT_T); + 
mm1ResGm.SetGlobalBuffer((__gm__ MM1_OUT_T *)(workspace + offset + aiCoreIdx * singleCoreMm1ResSize)); + offset += GetBlockNum() * singleCoreMm1ResSize; + + vec1ResGm.SetGlobalBuffer((__gm__ float *)(workspace + offset)); + offset += GetBlockNum() * constInfo.s1BaseSize * WS_DOBULE * WS_DOBULE * BASE_TOPK * sizeof(float); + + vec1ParamGm.SetGlobalBuffer((__gm__ int64_t *)(workspace + offset)); + offset += GetBlockNum() * constInfo.s1BaseSize * WS_DOBULE * LD_PARAM_NUM * sizeof(int64_t); + + if ASCEND_IS_AIV { + vectorService.InitParams(constInfo, tiling); + indiceOutGm.SetGlobalBuffer((__gm__ int32_t *)sparseIndices); + weightsGm.SetGlobalBuffer((__gm__ K_T *)weights); + vectorService.InitVec1GlobalTensor(mm1ResGm, vec1ResGm, vec1ParamGm, weightsGm, indiceOutGm); + } else { + matmulService.InitParams(constInfo); + queryGm.SetGlobalBuffer((__gm__ Q_T *)query); + if constexpr (PAGE_ATTENTION) { + blockTableGm.SetGlobalBuffer((__gm__ int32_t *)blockTable); + } + keyGm.SetGlobalBuffer((__gm__ K_T *)key); + matmulService.InitMm1GlobalTensor(blockTableGm, keyGm, queryGm, mm1ResGm); + } + InitBuffers(); +} + +template +__aicore__ inline void LIPreload::GetBN2Idx(uint32_t bN2Idx) +{ + tempLoopInfo.bN2Idx = bN2Idx; + tempLoopInfo.bIdx = bN2Idx / constInfo.kHeadNum; + tempLoopInfo.n2Idx = bN2Idx % constInfo.kHeadNum; +} + +template +__aicore__ inline void LIPreload::CalcS2LoopParams(uint32_t bN2LoopIdx, uint32_t gS1LoopIdx) +{ + tempLoopInfo.gS1Idx = gS1LoopIdx; + tempLoopInfo.actMBaseSize = constInfo.mBaseSize; + uint32_t remainedGS1Size = tempLoopInfo.actS1Size * constInfo.gSize - tempLoopInfo.gS1Idx * constInfo.mBaseSize; + if (remainedGS1Size <= constInfo.mBaseSize && remainedGS1Size > 0) { + tempLoopInfo.actMBaseSize = tempLoopInfo.mBasicSizeTail; + } + + bool isEnd = (bN2LoopIdx == splitCoreInfo.bN2End) && (gS1LoopIdx == splitCoreInfo.gS1End); + uint32_t s2BlockNum; + if (constInfo.attenMaskFlag) { + s2BlockNum = GetS2BaseBlockNumOnMask(gS1LoopIdx, 
tempLoopInfo.actS1Size, tempLoopInfo.actS2Size); + } else { + s2BlockNum = (tempLoopInfo.actS2Size + constInfo.s2BaseSize - 1) / constInfo.s2BaseSize; + } + tempLoopInfo.s2LoopEnd = isEnd ? splitCoreInfo.s2End : s2BlockNum - 1; +} + +template +__aicore__ inline void LIPreload::CalcGS1LoopParams(uint32_t bN2LoopIdx) +{ + GetBN2Idx(bN2LoopIdx); + GetS1S2ActualSeqLen(tempLoopInfo.bIdx, tempLoopInfo.actS1Size, tempLoopInfo.actS2Size); + if ((tempLoopInfo.actS2Size == 0) || (tempLoopInfo.actS1Size == 0)) { + tempLoopInfo.curActSeqLenIsZero = true; + return; + } + tempLoopInfo.curActSeqLenIsZero = false; + tempLoopInfo.s2BasicSizeTail = tempLoopInfo.actS2Size % constInfo.s2BaseSize; + tempLoopInfo.s2BasicSizeTail = + (tempLoopInfo.s2BasicSizeTail == 0) ? constInfo.s2BaseSize : tempLoopInfo.s2BasicSizeTail; + tempLoopInfo.mBasicSizeTail = (tempLoopInfo.actS1Size * constInfo.gSize) % constInfo.mBaseSize; + tempLoopInfo.mBasicSizeTail = + (tempLoopInfo.mBasicSizeTail == 0) ? constInfo.mBaseSize : tempLoopInfo.mBasicSizeTail; + + uint32_t gS1SplitNum = (tempLoopInfo.actS1Size * constInfo.gSize + constInfo.mBaseSize - 1) / constInfo.mBaseSize; + tempLoopInfo.gS1LoopEnd = (bN2LoopIdx == splitCoreInfo.bN2End) ? 
splitCoreInfo.gS1End : gS1SplitNum - 1; + if constexpr (LAYOUT_T == LI_LAYOUT::BSND) { + if (tempLoopInfo.gS1LoopEnd == gS1SplitNum - 1 && constInfo.qSeqSize > tempLoopInfo.actS1Size) { + tempLoopInfo.needDealActS1LessThanS1 = true; + } + } +} + +template +__aicore__ inline void LIPreload::CalcRunInfo(uint32_t loop, uint32_t s2LoopIdx, LICommon::RunInfo &runInfo) +{ + runInfo.loop = loop; + runInfo.bIdx = tempLoopInfo.bIdx; + runInfo.gS1Idx = tempLoopInfo.gS1Idx; + runInfo.s2Idx = s2LoopIdx; + runInfo.bN2Idx = tempLoopInfo.bN2Idx; + + runInfo.actS1Size = tempLoopInfo.actS1Size; + runInfo.actS2Size = tempLoopInfo.actS2Size; + runInfo.actMBaseSize = tempLoopInfo.actMBaseSize; + runInfo.actualSingleProcessSInnerSize = constInfo.s2BaseSize; + uint32_t s2SplitNum = (tempLoopInfo.actS2Size + constInfo.s2BaseSize - 1) / constInfo.s2BaseSize; + if (runInfo.s2Idx == s2SplitNum - 1) { + runInfo.actualSingleProcessSInnerSize = tempLoopInfo.s2BasicSizeTail; + } + runInfo.actualSingleProcessSInnerSizeAlign = + LICommon::Align((uint32_t)runInfo.actualSingleProcessSInnerSize, LICommon::ConstInfo::BUFFER_SIZE_BYTE_32B); + + runInfo.isFirstS2InnerLoop = s2LoopIdx == splitCoreInfo.s2Start; + runInfo.isLastS2InnerLoop = s2LoopIdx == tempLoopInfo.s2LoopEnd; + runInfo.isAllLoopEnd = (runInfo.bN2Idx == splitCoreInfo.bN2End) && (runInfo.gS1Idx == splitCoreInfo.gS1End) && + (runInfo.s2Idx == splitCoreInfo.s2End); + + if (runInfo.isFirstS2InnerLoop) { + uint64_t actualSeqQPrefixSum; + uint64_t actualSeqKPrefixSum; + if constexpr (LAYOUT_T == LI_LAYOUT::TND) { + actualSeqQPrefixSum = (runInfo.bIdx <= 0) ? 0 : actualSeqLengthsGmQ.GetValue(runInfo.bIdx - 1); + actualSeqKPrefixSum = (runInfo.bIdx <= 0) ? 0 : actualSeqLengthsGm.GetValue(runInfo.bIdx - 1); + } else { // BSND + actualSeqQPrefixSum = (runInfo.bIdx <= 0) ? 0 : runInfo.bIdx * constInfo.qSeqSize; + actualSeqKPrefixSum = (runInfo.bIdx <= 0) ? 
0 : runInfo.bIdx * constInfo.kSeqSize; + } + uint64_t tndBIdxOffset = actualSeqQPrefixSum * constInfo.qHeadNum * constInfo.headDim; + uint64_t tndKeyBIdxOffset = actualSeqKPrefixSum * constInfo.kHeadNum * constInfo.headDim; + // B,S1,N1(N2,G),D + queryCoreOffset = tndBIdxOffset + runInfo.gS1Idx * constInfo.mBaseSize * constInfo.headDim; + keyCoreOffset = tndKeyBIdxOffset + runInfo.n2Idx * constInfo.headDim; + // B,S1,N1(N2,G)/T,N1(N2,G) + weightsCoreOffset = actualSeqQPrefixSum * constInfo.qHeadNum + runInfo.n2Idx * constInfo.gSize; + // B,S1,N2,k/T,N2,k + indiceOutCoreOffset = actualSeqQPrefixSum * constInfo.kHeadNum * constInfo.sparseCount + + runInfo.n2Idx * constInfo.sparseCount; + } + runInfo.tensorQueryOffset = queryCoreOffset; + runInfo.tensorKeyOffset = keyCoreOffset + runInfo.s2Idx * constInfo.s2BaseSize * constInfo.kHeadNum + * constInfo.headDim; + runInfo.tensorWeightsOffset = weightsCoreOffset; + runInfo.indiceOutOffset = indiceOutCoreOffset; +} + +template +__aicore__ inline void LIPreload::Process() +{ + if (usedCoreNum == 0) { + ProcessInvalid(); + return; + } + ProcessMain(); + ProcessDecode(); +} + +template +__aicore__ inline void LIPreload::ProcessInvalid() +{ + if ASCEND_IS_AIV { + uint32_t aivCoreNum = GetBlockNum() * 2; // 2 means c:v = 1:2 + uint64_t totalOutputSize = + constInfo.batchSize * constInfo.qSeqSize * constInfo.kHeadNum * constInfo.sparseCount; + uint64_t singleCoreSize = + LICommon::Align((totalOutputSize + aivCoreNum - 1) / aivCoreNum, GM_ALIGN_BYTES / sizeof(OUT_T)); + uint64_t baseSize = tmpBlockIdx * singleCoreSize; + if (baseSize < totalOutputSize) { + uint64_t dealSize = + (baseSize + singleCoreSize > totalOutputSize) ? 
singleCoreSize : totalOutputSize - baseSize; + GlobalTensor output = indiceOutGm[baseSize]; + AscendC::InitGlobalMemory(output, dealSize, constInfo.INVALID_IDX); + } + } +} + +template +__aicore__ inline void LIPreload::ProcessMain() +{ + if (aiCoreIdx >= usedCoreNum) { + return; + } + + if ASCEND_IS_AIV { + vectorService.AllocEventID(); + CrossCoreSetFlag(constInfo.syncV1C1); + CrossCoreSetFlag(constInfo.syncV1C1); + } else { + matmulService.AllocEventID(); + } + + LICommon::RunInfo runInfo; + uint32_t gloop = 0; + for (uint32_t bN2LoopIdx = splitCoreInfo.bN2Start; bN2LoopIdx <= splitCoreInfo.bN2End; bN2LoopIdx++) { + CalcGS1LoopParams(bN2LoopIdx); + if (tempLoopInfo.curActSeqLenIsZero) { + DealActSeqLenIsZero(tempLoopInfo.bIdx, tempLoopInfo.n2Idx, 0U); + continue; + } + for (uint32_t gS1LoopIdx = splitCoreInfo.gS1Start; gS1LoopIdx <= tempLoopInfo.gS1LoopEnd; gS1LoopIdx++) { + CalcS2LoopParams(bN2LoopIdx, gS1LoopIdx); + for (int s2LoopIdx = splitCoreInfo.s2Start; s2LoopIdx <= tempLoopInfo.s2LoopEnd; s2LoopIdx++) { + ProcessBaseBlock(gloop, s2LoopIdx, runInfo); + ++gloop; + } + splitCoreInfo.s2Start = 0; + } + if (tempLoopInfo.needDealActS1LessThanS1) { + DealActSeqLenIsZero(tempLoopInfo.bIdx, tempLoopInfo.n2Idx, tempLoopInfo.actS1Size); + } + splitCoreInfo.gS1Start = 0; + } + + if ASCEND_IS_AIV { + vectorService.FreeEventID(); + } else { + matmulService.FreeEventID(); + CrossCoreWaitFlag(constInfo.syncV1C1); + CrossCoreWaitFlag(constInfo.syncV1C1); + } +} + +template +__aicore__ inline void LIPreload::ProcessBaseBlock(uint32_t loop, uint64_t s2LoopIdx, LICommon::RunInfo &runInfo) +{ + CalcRunInfo(loop, s2LoopIdx, runInfo); + if ASCEND_IS_AIC { + CrossCoreWaitFlag(constInfo.syncV1C1); + matmulService.ComputeMm1(runInfo); + CrossCoreSetFlag(constInfo.syncC1V1); + } else { + CrossCoreWaitFlag(constInfo.syncC1V1); + vectorService.ProcessVec(runInfo); + CrossCoreSetFlag(constInfo.syncV1C1); + } +} + +template +__aicore__ inline void LIPreload::ProcessDecode() +{ + if 
ASCEND_IS_AIV { + vectorService.InitLDBuffers(pipe); + ICachePreLoad(LD_PREFETCH_LEN); + SyncAll(); + if (splitCoreInfo.isLD) { + vectorService.ProcessLD(); + } + } +} +} // namespace LIKernel +#endif // LIGHTNING_INDEXER_KERNEL_H \ No newline at end of file diff --git a/csrc/lightning_indexer/op_kernel/lightning_indexer_service_cube.h b/csrc/lightning_indexer/op_kernel/lightning_indexer_service_cube.h new file mode 100644 index 00000000..aa188876 --- /dev/null +++ b/csrc/lightning_indexer/op_kernel/lightning_indexer_service_cube.h @@ -0,0 +1,415 @@ +/** + * This program is free software, you can redistribute it and/or modify it. + * Copyright (c) 2025 Huawei Technologies Co., Ltd. + * This file is a part of the CANN Open Software. + * Licensed under CANN Open Software License Agreement Version 2.0 (the "License"). + * Please refer to the License for details. You may not use this file except in compliance with the License. + * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. + * See LICENSE in the root of the software repository for the full text of the License. + */ + +/*! 
+ * \file lightning_indexer_service_cube.h + * \brief use 5 buffer for matmul l1, better pipeline + */ +#ifndef LIGHTNING_INDEXER_SERVICE_CUBE_H +#define LIGHTNING_INDEXER_SERVICE_CUBE_H + +#include "kernel_operator.h" +#include "kernel_operator_list_tensor_intf.h" +#include "kernel_tiling/kernel_tiling.h" +#include "lib/matmul_intf.h" +#include "lib/matrix/matmul/tiling.h" +#include "lightning_indexer_common.h" + +namespace LIKernel { +using namespace LICommon; +template +class LIMatmul { +public: + using Q_T = typename LIT::queryType; + using K_T = typename LIT::keyType; + + __aicore__ inline LIMatmul(){}; + __aicore__ inline void InitBuffers(TPipe *pipe); + __aicore__ inline void InitMm1GlobalTensor(const GlobalTensor &blkTableGm, const GlobalTensor &keyGm, + const GlobalTensor &queryGm, const GlobalTensor &mm1ResGm); + __aicore__ inline void InitParams(const ConstInfo &constInfo); + __aicore__ inline void AllocEventID(); + __aicore__ inline void FreeEventID(); + __aicore__ inline void ComputeMm1(const LICommon::RunInfo &runInfo); + + static constexpr IsResetLoad3dConfig LOAD3DV2_CONFIG = {true, true}; // isSetFMatrix isSetPadding; + static constexpr uint64_t KEY_BUF_NUM = 3; + static constexpr uint64_t QUERY_BUF_NUM = 2; + static constexpr uint64_t L0_BUF_NUM = 2; + + static constexpr uint32_t KEY_MTE1_MTE2_EVENT = EVENT_ID2; + static constexpr uint32_t QUERY_MTE1_MTE2_EVENT = EVENT_ID5; // KEY_MTE1_MTE2_EVENT + KEY_BUF_NUM; + static constexpr uint32_t M_MTE1_EVENT = EVENT_ID3; + + static constexpr uint32_t MTE2_MTE1_EVENT = EVENT_ID2; + static constexpr uint32_t MTE1_M_EVENT = EVENT_ID2; + + static constexpr uint64_t M_BASIC_BLOCK = 256; + static constexpr uint64_t D_BASIC_BLOCK = 128; + static constexpr uint64_t S2_BASIC_BLOCK = 256; + + static constexpr uint64_t M_BASIC_BLOCK_L0 = 128; + static constexpr uint64_t D_BASIC_BLOCK_L0 = 128; + static constexpr uint64_t S2_BASIC_BLOCK_L0 = 128; + + static constexpr uint64_t QUERY_BUFFER_OFFSET = M_BASIC_BLOCK * 
D_BASIC_BLOCK; + static constexpr uint64_t KEY_BUFFER_OFFSET = S2_BASIC_BLOCK * D_BASIC_BLOCK; + static constexpr uint64_t L0AB_BUFFER_OFFSET = M_BASIC_BLOCK_L0 * D_BASIC_BLOCK_L0; + static constexpr uint64_t L0C_BUFFER_OFFSET = M_BASIC_BLOCK_L0 * S2_BASIC_BLOCK_L0; + +protected: + __aicore__ inline void Fixp(uint64_t s1gGmOffset, uint64_t s2GmOffset, uint64_t s1gL0RealSize, + uint64_t s2L0RealSize, const LICommon::RunInfo &runInfo); + __aicore__ inline void ComuteL0c(uint64_t s1gL0RealSize, uint64_t s2L0RealSize, const LICommon::RunInfo &runInfo); + __aicore__ inline void LoadKeyToL0b(uint64_t s2L0Offset, uint64_t s2L1RealSize, uint64_t s2L0RealSize, + const LICommon::RunInfo &runInfo); + __aicore__ inline void LoadQueryToL0a(uint64_t s1gL1Offset, uint64_t s1gL0Offset, uint64_t s1gL1RealSize, + uint64_t s1gL0RealSize, const LICommon::RunInfo &runInfo); + __aicore__ inline void QueryNd2Nz(uint64_t s1gL1RealSize, uint64_t s1gL1Offset, const LICommon::RunInfo &runInfo); + __aicore__ inline void KeyNd2Nz(uint64_t s2L1RealSize, uint64_t s2GmOffset, const LICommon::RunInfo &runInfo); + __aicore__ inline void KeyNd2NzForPA(uint64_t s2L1RealSize, uint64_t s2GmOffset, const LICommon::RunInfo &runInfo); + GlobalTensor blkTableGm_; + GlobalTensor keyGm_; + GlobalTensor queryGm_; + GlobalTensor mm1ResGm_; + + TBuf bufQL1_; + LocalTensor queryL1_; + TBuf bufKeyL1_; + LocalTensor keyL1_; + + TBuf bufQL0_; + LocalTensor queryL0_; + TBuf bufKeyL0_; + LocalTensor keyL0_; + + TBuf bufL0C_; + LocalTensor cL0_; + + uint64_t keyL1BufIdx_ = 0; + uint64_t queryL1Mte2BufIdx_ = 0; + uint64_t queryL1Mte1BufIdx_ = 0; + uint64_t l0BufIdx_ = 0; + + ConstInfo constInfo_; + +private: + static constexpr bool PAGE_ATTENTION = LIT::pageAttention; +}; + +template +__aicore__ inline void LIMatmul::InitParams(const ConstInfo &constInfo) +{ + constInfo_ = constInfo; +} + +template +__aicore__ inline void LIMatmul::InitBuffers(TPipe *pipe) +{ + pipe->InitBuffer(bufQL1_, QUERY_BUF_NUM * M_BASIC_BLOCK * 
D_BASIC_BLOCK * sizeof(Q_T)); + queryL1_ = bufQL1_.Get(); + pipe->InitBuffer(bufKeyL1_, KEY_BUF_NUM * S2_BASIC_BLOCK * D_BASIC_BLOCK * sizeof(K_T)); + keyL1_ = bufKeyL1_.Get(); + + pipe->InitBuffer(bufQL0_, L0_BUF_NUM * M_BASIC_BLOCK_L0 * D_BASIC_BLOCK_L0 * sizeof(Q_T)); + queryL0_ = bufQL0_.Get(); + pipe->InitBuffer(bufKeyL0_, L0_BUF_NUM * D_BASIC_BLOCK_L0 * S2_BASIC_BLOCK_L0 * sizeof(K_T)); + keyL0_ = bufKeyL0_.Get(); + + pipe->InitBuffer(bufL0C_, L0_BUF_NUM * M_BASIC_BLOCK_L0 * S2_BASIC_BLOCK_L0 * sizeof(float)); + cL0_ = bufL0C_.Get(); +} + +template +__aicore__ inline void +LIMatmul::InitMm1GlobalTensor(const GlobalTensor &blkTableGm, const GlobalTensor &keyGm, + const GlobalTensor &queryGm, const GlobalTensor &mm1ResGm) +{ + blkTableGm_ = blkTableGm; + keyGm_ = keyGm; + queryGm_ = queryGm; + mm1ResGm_ = mm1ResGm; +} + +template +__aicore__ inline void LIMatmul::ComputeMm1(const LICommon::RunInfo &runInfo) +{ + uint64_t s2GmBaseOffset = runInfo.s2Idx * constInfo_.s2BaseSize; + uint64_t s1gProcessSize = runInfo.actMBaseSize; + uint64_t s2ProcessSize = runInfo.actualSingleProcessSInnerSize; + for (uint64_t s2GmOffset = 0; s2GmOffset < s2ProcessSize; s2GmOffset += S2_BASIC_BLOCK) { + WaitFlag(KEY_MTE1_MTE2_EVENT + keyL1BufIdx_ % KEY_BUF_NUM); + uint64_t s2L1RealSize = + s2GmOffset + S2_BASIC_BLOCK > s2ProcessSize ? s2ProcessSize - s2GmOffset : S2_BASIC_BLOCK; + if (PAGE_ATTENTION) { + KeyNd2NzForPA(s2L1RealSize, s2GmBaseOffset + s2GmOffset, runInfo); + }else { + KeyNd2Nz(s2L1RealSize, s2GmOffset, runInfo); + } + + SetFlag(MTE2_MTE1_EVENT); + WaitFlag(MTE2_MTE1_EVENT); + for (uint64_t s1gGmOffset = 0; s1gGmOffset < s1gProcessSize; s1gGmOffset += M_BASIC_BLOCK) { + uint64_t s1gL1RealSize = + s1gGmOffset + M_BASIC_BLOCK > s1gProcessSize ? 
s1gProcessSize - s1gGmOffset : M_BASIC_BLOCK; + if (runInfo.isFirstS2InnerLoop && s2GmOffset == 0) { + queryL1Mte2BufIdx_++; + queryL1Mte1BufIdx_ = queryL1Mte2BufIdx_; + WaitFlag(QUERY_MTE1_MTE2_EVENT + queryL1Mte2BufIdx_ % QUERY_BUF_NUM); + QueryNd2Nz(s1gL1RealSize, s1gGmOffset, runInfo); + SetFlag(MTE2_MTE1_EVENT); + WaitFlag(MTE2_MTE1_EVENT); + } else { + queryL1Mte1BufIdx_ = + queryL1Mte2BufIdx_ - (CeilDiv(s1gProcessSize, M_BASIC_BLOCK) - 1 - (s1gGmOffset > 0)); + } + for (uint64_t s2L1Offset = 0; s2L1Offset < s2L1RealSize; s2L1Offset += S2_BASIC_BLOCK_L0) { + uint64_t s2L0RealSize = + s2L1Offset + S2_BASIC_BLOCK_L0 > s2L1RealSize ? s2L1RealSize - s2L1Offset : S2_BASIC_BLOCK_L0; + for (uint64_t s1gL1Offset = 0; s1gL1Offset < s1gL1RealSize; s1gL1Offset += M_BASIC_BLOCK_L0) { + WaitFlag(M_MTE1_EVENT + l0BufIdx_ % L0_BUF_NUM); + uint64_t s1gL0RealSize = + s1gL1Offset + M_BASIC_BLOCK_L0 > s1gL1RealSize ? s1gL1RealSize - s1gL1Offset : M_BASIC_BLOCK_L0; + LoadQueryToL0a(s1gGmOffset, s1gL1Offset, s1gL1RealSize, s1gL0RealSize, runInfo); + LoadKeyToL0b(s2L1Offset, s2L1RealSize, s2L0RealSize, runInfo); + + SetFlag(MTE1_M_EVENT); + WaitFlag(MTE1_M_EVENT); + + ComuteL0c(s1gL0RealSize, s2L0RealSize, runInfo); + + SetFlag(M_MTE1_EVENT + l0BufIdx_ % L0_BUF_NUM); + + Fixp(s1gGmOffset + s1gL1Offset, s2GmOffset + s2L1Offset, s1gL0RealSize, s2L0RealSize, runInfo); + l0BufIdx_++; + } + } + if (s2GmOffset + S2_BASIC_BLOCK >= s2ProcessSize && runInfo.isLastS2InnerLoop) { + SetFlag(QUERY_MTE1_MTE2_EVENT + queryL1Mte1BufIdx_ % QUERY_BUF_NUM); + } + } + + SetFlag(KEY_MTE1_MTE2_EVENT + keyL1BufIdx_ % KEY_BUF_NUM); + keyL1BufIdx_++; + } +} + +template +__aicore__ inline void LIMatmul::KeyNd2Nz(uint64_t s2L1RealSize, uint64_t s2GmOffset, + const LICommon::RunInfo &runInfo) +{ + uint64_t s2L1Offset = 0; + while (s2L1Offset < s2L1RealSize) { + uint64_t keyGmOffset = runInfo.tensorKeyOffset + (s2GmOffset + s2L1Offset) * constInfo_.headDim; + uint64_t s2Mte2Size = (s2L1RealSize <= 
S2_BASIC_BLOCK_L0 || s2L1Offset >= S2_BASIC_BLOCK_L0) ? + s2L1RealSize - s2L1Offset : + S2_BASIC_BLOCK_L0 - s2L1Offset; + + Nd2NzParams nd2nzPara; + nd2nzPara.ndNum = 1; + nd2nzPara.nValue = s2Mte2Size; // 行数 + nd2nzPara.dValue = constInfo_.headDim; + nd2nzPara.srcDValue = constInfo_.headDim; + nd2nzPara.dstNzC0Stride = s2L1Offset >= S2_BASIC_BLOCK_L0 ? + CeilAlign(s2L1RealSize - S2_BASIC_BLOCK_L0, (uint64_t)BLOCK_CUBE) : + (s2L1RealSize > S2_BASIC_BLOCK_L0 ? + S2_BASIC_BLOCK_L0 : + CeilAlign(s2L1RealSize, (uint64_t)BLOCK_CUBE)); + nd2nzPara.dstNzNStride = 1; + nd2nzPara.srcNdMatrixStride = 0; + nd2nzPara.dstNzMatrixStride = 0; + DataCopy(keyL1_[(keyL1BufIdx_ % KEY_BUF_NUM) * KEY_BUFFER_OFFSET + + (s2L1Offset >= S2_BASIC_BLOCK_L0 ? + S2_BASIC_BLOCK_L0 * D_BASIC_BLOCK_L0 + (s2L1Offset - S2_BASIC_BLOCK_L0) * BLOCK_CUBE : + s2L1Offset * BLOCK_CUBE)], + keyGm_[keyGmOffset], nd2nzPara); + + s2L1Offset += s2Mte2Size; + } +} + +// blkNum, blkSize, N2, D +template +__aicore__ inline void LIMatmul::KeyNd2NzForPA(uint64_t s2L1RealSize, uint64_t s2GmOffset, + const LICommon::RunInfo &runInfo) +{ + uint64_t s2L1Offset = 0; + while (s2L1Offset < s2L1RealSize) { + uint64_t s2BlkId = (s2L1Offset + s2GmOffset) / constInfo_.kCacheBlockSize; + uint64_t s2BlkOffset = (s2L1Offset + s2GmOffset) % constInfo_.kCacheBlockSize; + uint64_t keyGmOffset = blkTableGm_.GetValue(runInfo.bIdx * constInfo_.maxBlockNumPerBatch + s2BlkId) * + constInfo_.kCacheBlockSize * constInfo_.kHeadNum * constInfo_.headDim + + s2BlkOffset * constInfo_.headDim; + uint64_t s2Mte2Size = (s2L1RealSize <= S2_BASIC_BLOCK_L0 || s2L1Offset >= S2_BASIC_BLOCK_L0) ? + s2L1RealSize - s2L1Offset : + S2_BASIC_BLOCK_L0 - s2L1Offset; + s2Mte2Size = s2BlkOffset + s2Mte2Size >= constInfo_.kCacheBlockSize ? 
constInfo_.kCacheBlockSize - s2BlkOffset : + s2Mte2Size; + Nd2NzParams nd2nzPara; + nd2nzPara.ndNum = 1; + nd2nzPara.nValue = s2Mte2Size; + nd2nzPara.dValue = constInfo_.headDim; + nd2nzPara.srcDValue = constInfo_.headDim; + nd2nzPara.dstNzC0Stride = s2L1Offset >= S2_BASIC_BLOCK_L0 ? + CeilAlign(s2L1RealSize - S2_BASIC_BLOCK_L0, (uint64_t)BLOCK_CUBE) : + (s2L1RealSize > S2_BASIC_BLOCK_L0 ? + S2_BASIC_BLOCK_L0 : + CeilAlign(s2L1RealSize, (uint64_t)BLOCK_CUBE)); + nd2nzPara.dstNzNStride = 1; + nd2nzPara.srcNdMatrixStride = 0; + nd2nzPara.dstNzMatrixStride = 0; + DataCopy(keyL1_[(keyL1BufIdx_ % KEY_BUF_NUM) * KEY_BUFFER_OFFSET + + (s2L1Offset >= S2_BASIC_BLOCK_L0 ? + S2_BASIC_BLOCK_L0 * D_BASIC_BLOCK_L0 + (s2L1Offset - S2_BASIC_BLOCK_L0) * BLOCK_CUBE : + s2L1Offset * BLOCK_CUBE)], + keyGm_[keyGmOffset], nd2nzPara); + + s2L1Offset += s2Mte2Size; + } +} + +// batch, s1, n2, g, d +template +__aicore__ inline void LIMatmul::QueryNd2Nz(uint64_t s1gL1RealSize, uint64_t s1gGmOffset, + const LICommon::RunInfo &runInfo) +{ + Nd2NzParams nd2nzPara; + nd2nzPara.ndNum = 1; + nd2nzPara.nValue = s1gL1RealSize; + nd2nzPara.dValue = constInfo_.headDim; + nd2nzPara.srcDValue = constInfo_.headDim; + nd2nzPara.dstNzC0Stride = CeilAlign(s1gL1RealSize, (uint64_t)BLOCK_CUBE); + nd2nzPara.dstNzNStride = 1; + nd2nzPara.srcNdMatrixStride = 0; + nd2nzPara.dstNzMatrixStride = 0; + DataCopy(queryL1_[(queryL1Mte2BufIdx_ % QUERY_BUF_NUM) * QUERY_BUFFER_OFFSET], + queryGm_[runInfo.tensorQueryOffset + s1gGmOffset * constInfo_.headDim], nd2nzPara); +} + +template +__aicore__ inline void LIMatmul::LoadQueryToL0a(uint64_t s1gGmOffset, uint64_t s1gL1Offset, uint64_t s1gL1RealSize, + uint64_t s1gL0RealSize, const LICommon::RunInfo &runInfo) +{ + LoadData3DParamsV2 loadData3DParams; + // SetFmatrixParams + loadData3DParams.l1H = CeilDiv(s1gL1RealSize, BLOCK_CUBE); // Hin=M1=8 + loadData3DParams.l1W = BLOCK_CUBE; // Win=M0 + loadData3DParams.channelSize = constInfo_.headDim; // Cin=K + + 
loadData3DParams.padList[0] = 0; + loadData3DParams.padList[1] = 0; + loadData3DParams.padList[2] = 0; + loadData3DParams.padList[3] = 255; + + // SetLoadToA0Params + loadData3DParams.mExtension = CeilAlign(s1gL0RealSize, BLOCK_CUBE); + loadData3DParams.kExtension = constInfo_.headDim; + loadData3DParams.mStartPt = s1gL1Offset; + loadData3DParams.kStartPt = 0; + loadData3DParams.strideW = 1; + loadData3DParams.strideH = 1; + loadData3DParams.filterW = 1; + loadData3DParams.filterSizeW = (1 >> 8) & 255; + loadData3DParams.filterH = 1; + loadData3DParams.filterSizeH = (1 >> 8) & 255; + loadData3DParams.dilationFilterW = 1; + loadData3DParams.dilationFilterH = 1; + loadData3DParams.enTranspose = 0; + loadData3DParams.fMatrixCtrl = 0; + + LoadData(queryL0_[(l0BufIdx_ % L0_BUF_NUM) * L0AB_BUFFER_OFFSET], + queryL1_[(queryL1Mte1BufIdx_ % QUERY_BUF_NUM) * QUERY_BUFFER_OFFSET], + loadData3DParams); +} + +template +__aicore__ inline void LIMatmul::LoadKeyToL0b(uint64_t s2L1Offset, uint64_t s2L1RealSize, uint64_t s2L0RealSize, + const LICommon::RunInfo &runInfo) +{ + uint64_t keyL1Offset = s2L1Offset >= S2_BASIC_BLOCK_L0 ? 
S2_BASIC_BLOCK_L0 * D_BASIC_BLOCK_L0 : 0; + LoadData2DParams loadData2DParams; + loadData2DParams.startIndex = 0; + loadData2DParams.repeatTimes = CeilDiv(s2L0RealSize, BLOCK_CUBE) * CeilDiv(constInfo_.headDim, BLOCK_CUBE); + loadData2DParams.srcStride = 1; + loadData2DParams.dstGap = 0; + loadData2DParams.ifTranspose = false; + LoadData(keyL0_[(l0BufIdx_ % L0_BUF_NUM) * L0AB_BUFFER_OFFSET], + keyL1_[(keyL1BufIdx_ % KEY_BUF_NUM) * KEY_BUFFER_OFFSET + keyL1Offset], loadData2DParams); +} + +template +__aicore__ inline void LIMatmul::ComuteL0c(uint64_t s1gL0RealSize, uint64_t s2L0RealSize, + const LICommon::RunInfo &runInfo) +{ + MmadParams mmadParams; + mmadParams.m = CeilAlign(s1gL0RealSize, BLOCK_CUBE); + mmadParams.n = s2L0RealSize; + mmadParams.k = constInfo_.headDim; + mmadParams.cmatrixInitVal = true; + mmadParams.cmatrixSource = false; + mmadParams.unitFlag = 0b11; + Mmad(cL0_[(l0BufIdx_ % L0_BUF_NUM) * L0C_BUFFER_OFFSET], queryL0_[(l0BufIdx_ % L0_BUF_NUM) * L0AB_BUFFER_OFFSET], + keyL0_[(l0BufIdx_ % L0_BUF_NUM) * L0AB_BUFFER_OFFSET], mmadParams); + if ((mmadParams.m / 16) * (mmadParams.n / 16) < 10) { + PipeBarrier(); + } +} + +template +__aicore__ inline void LIMatmul::Fixp(uint64_t s1gGmOffset, uint64_t s2GmOffset, uint64_t s1gL0RealSize, + uint64_t s2L0RealSize, const LICommon::RunInfo &runInfo) +{ + AscendC::DataCopyCO12DstParams intriParams; + intriParams.mSize = CeilAlign(s1gL0RealSize, BLOCK_CUBE); + intriParams.nSize = s2L0RealSize; + intriParams.dstStride = runInfo.actualSingleProcessSInnerSizeAlign; + intriParams.srcStride = CeilAlign(s1gL0RealSize, BLOCK_CUBE); + // set mode according to dtype + intriParams.quantPre = QuantMode_t::NoQuant; + intriParams.nz2ndEn = true; + intriParams.unitFlag = 0b11; // 3 unitflag + intriParams.reluPre = 1; + AscendC::SetFixpipeNz2ndFlag(1, 1, 1); + AscendC::DataCopy(mm1ResGm_[(runInfo.loop % 2) * constInfo_.mBaseSize * constInfo_.s2BaseSize + + s1gGmOffset * intriParams.dstStride + s2GmOffset], + cL0_[(l0BufIdx_ % 
L0_BUF_NUM) * L0C_BUFFER_OFFSET], intriParams); +} + +template +__aicore__ inline void LIMatmul::AllocEventID() +{ + SetMMLayoutTransform(true); + SetFlag(KEY_MTE1_MTE2_EVENT + 0); + SetFlag(KEY_MTE1_MTE2_EVENT + 1); + SetFlag(KEY_MTE1_MTE2_EVENT + 2); + + SetFlag(QUERY_MTE1_MTE2_EVENT + 0); + SetFlag(QUERY_MTE1_MTE2_EVENT + 1); + + SetFlag(M_MTE1_EVENT + 0); + SetFlag(M_MTE1_EVENT + 1); +} + +template +__aicore__ inline void LIMatmul::FreeEventID() +{ + SetMMLayoutTransform(false); + WaitFlag(KEY_MTE1_MTE2_EVENT + 0); + WaitFlag(KEY_MTE1_MTE2_EVENT + 1); + WaitFlag(KEY_MTE1_MTE2_EVENT + 2); + + WaitFlag(QUERY_MTE1_MTE2_EVENT + 0); + WaitFlag(QUERY_MTE1_MTE2_EVENT + 1); + + WaitFlag(M_MTE1_EVENT + 0); + WaitFlag(M_MTE1_EVENT + 1); +} +} // namespace LIKernel +#endif \ No newline at end of file diff --git a/csrc/lightning_indexer/op_kernel/lightning_indexer_service_vector.h b/csrc/lightning_indexer/op_kernel/lightning_indexer_service_vector.h new file mode 100644 index 00000000..1ed25b4c --- /dev/null +++ b/csrc/lightning_indexer/op_kernel/lightning_indexer_service_vector.h @@ -0,0 +1,559 @@ +/** + * This program is free software, you can redistribute it and/or modify it. + * Copyright (c) 2025 Huawei Technologies Co., Ltd. + * This file is a part of the CANN Open Software. + * Licensed under CANN Open Software License Agreement Version 2.0 (the "License"). + * Please refer to the License for details. You may not use this file except in compliance with the License. + * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. + * See LICENSE in the root of the software repository for the full text of the License. + */ + +/*! 
/*!
 * \file lightning_indexer_service_vector.h
 * \brief
 */
// NOTE(review): all template-argument lists (<...>) in this patch were stripped during
// extraction (e.g. bare "template", "GlobalTensor", "TQue", "ReinterpretCast()").
// Restore the generic arguments from the original source before compiling.
#ifndef LIGHTNING_INDEXER_SERVICE_VECTOR_H
#define LIGHTNING_INDEXER_SERVICE_VECTOR_H

#include "kernel_operator.h"
#include "kernel_operator_list_tensor_intf.h"
#include "kernel_tiling/kernel_tiling.h"
#include "lib/matmul_intf.h"
#include "lib/matrix/matmul/tiling.h"
#include "lightning_indexer_common.h"
#include "lightning_indexer_vector.h"

namespace LIKernel {
using namespace LICommon;
using namespace LIServiceVec;
constexpr uint32_t BASE_TOPK = 2048;   // top-k capacity of one merge-sorted result list
constexpr uint32_t LD_PARAM_NUM = 16;  // int64 slots per per-row LD parameter record

// Vector-core half of the LightningIndexer kernel: scales matmul scores by per-head
// weights, reduces over the group dimension, keeps a running top-k per query row, and
// (in the LD phase) merges partial top-k lists produced by different cube cores.
template
class LIVector {
public:
    using K_T = typename LIT::keyType;
    static constexpr LI_LAYOUT LAYOUT_T = LIT::layout;

    using MM1_OUT_T = float;

    __aicore__ inline LIVector(){};
    __aicore__ inline void ProcessVec(const LICommon::RunInfo &info);
    __aicore__ inline void ProcessLD();
    __aicore__ inline void InitBuffers(TPipe *pipe);
    __aicore__ inline void InitParams(const struct LICommon::ConstInfo &constInfo,
                                      const LITilingData *__restrict tilingData);
    __aicore__ inline void InitVec1GlobalTensor(GlobalTensor mm1ResGm, GlobalTensor vec1ResGm,
                                                GlobalTensor vec1ParamGm, GlobalTensor weightsGm,
                                                GlobalTensor indiceOutGm);
    __aicore__ inline void CleanInvalidOutput(int64_t invalidS1offset);
    __aicore__ inline void AllocEventID();
    __aicore__ inline void FreeEventID();
    __aicore__ inline void InitLDBuffers(TPipe *pipe);

protected:
    GlobalTensor mm1ResGm;     // matmul (q*k) score workspace, double-buffered by loop parity
    GlobalTensor vec1ResGm;    // per-core partial top-k (value,index) workspace
    GlobalTensor vec1ParamGm;  // per-row LD parameter records (LD_PARAM_NUM int64 each)
    GlobalTensor weightsGm;    // per-head weights used to scale scores
    GlobalTensor indiceOutGm;  // final selected indices output

private:
    // queue
    TQue inQueue_;
    TQue outQueue_;

    // tmp buff for vector
    TBuf sortOutBuf_;
    TBuf indexBuf_;
    TBuf reduceOutBuf_;
    TBuf brcBuf_;
    TBuf paramBuf_;

    // tmp buff for LD
    TBuf<> ldToBeMrgBuf_;
    TBuf<> ldTmpBuf_;
    TBuf<> ldOutValueBuf_;
    TBuf<> ldOutIdxBuf_;

    LocalTensor globalTopkIndice_;   // 0..s2BaseSize_-1 index ramp, reused every tile
    LocalTensor globalTopkUb_;       // running top-k (value,index) per query row
    LocalTensor SortedBasicBlock_;   // staging area for up to 4 sorted basic blocks

    int32_t blockId_ = -1;           // AIV block index; even/odd halves split each s1 tile

    // para for vector
    int32_t groupInner_ = 0;
    int32_t globalTopkNum_ = 0;
    int64_t blockS2StartIdx_ = 0;
    int32_t gSize_ = 0;
    int32_t kHeadNum_ = 0;
    int32_t s1BaseSize_ = 0;
    int32_t s2BaseSize_ = 0;

    // para for LD
    uint32_t mrgListNum_ = 4;   // lists merged per MrgSort pass in ProcessLD
    uint32_t paramNum_ = 16;    // must equal LD_PARAM_NUM

    // extra leading pad so the reduce cache never starts on the same bank as its source
    constexpr static uint32_t REDUCE_BANK_CONFLICT_OFFSETS = 256;
    constexpr static uint32_t REDUCE_BANK_CONFLICT_NUM = REDUCE_BANK_CONFLICT_OFFSETS / sizeof(float);

    struct LICommon::ConstInfo constInfo_;
};

// Allocate all UB queues/buffers for the vector phase and pre-initialise the
// running top-k buffer and the per-row LD parameter records (needFd = -1).
template
__aicore__ inline void LIVector::InitBuffers(TPipe *pipe)
{
    // out queue must hold both the extract scratch and the reduce cache, whichever is larger
    uint32_t outNeedBufSize = (BASE_TOPK * 2) * 2 * sizeof(float);
    uint32_t reduceCacheSize = REDUCE_BANK_CONFLICT_OFFSETS + groupInner_ * s2BaseSize_ * sizeof(float);
    outNeedBufSize = reduceCacheSize > outNeedBufSize ? reduceCacheSize : outNeedBufSize;

    pipe->InitBuffer(inQueue_, 2,
                     groupInner_ * s2BaseSize_ * sizeof(float) + s2BaseSize_ * sizeof(float)); // 69KB mm_out_ub
    pipe->InitBuffer(outQueue_, 1, outNeedBufSize);                                            // 32KB extract
    pipe->InitBuffer(sortOutBuf_, CeilDiv(s1BaseSize_, 2) * BASE_TOPK * 2 * sizeof(float));    // 64KB
    pipe->InitBuffer(indexBuf_, s2BaseSize_ * sizeof(int32_t));                                // 2KB
    pipe->InitBuffer(reduceOutBuf_, s2BaseSize_ * 2 * sizeof(float));                          // 4KB
    pipe->InitBuffer(brcBuf_, groupInner_ * 8 * sizeof(float));
    pipe->InitBuffer(paramBuf_, LD_PARAM_NUM * sizeof(int64_t));

    //
    globalTopkIndice_ = indexBuf_.Get();
    globalTopkUb_ = sortOutBuf_.Get();
    SortedBasicBlock_ = globalTopkUb_[BASE_TOPK * 2 * 2];
    globalTopkNum_ = 0;

    ArithProgression(globalTopkIndice_, 0, 1, s2BaseSize_);
    InitSortOutBuf(globalTopkUb_, CeilDiv(s1BaseSize_, 2) * BASE_TOPK * 2);
    // Write -1 into this core's slice of the LD parameter records so ProcessLD can
    // tell which rows were actually produced (needFd == 1).
    LocalTensor tmpfBuff = outQueue_.AllocTensor();
    Duplicate(tmpfBuff.template ReinterpretCast(), -1, 2 * (s1BaseSize_ / 2) * paramNum_ * 2);
    SetWaitFlag(HardEvent::V_MTE3);
    int64_t wsInfoOffset = (blockId_ / 2) * s1BaseSize_ * 2 * paramNum_ +
                           (blockId_ % 2) * (s1BaseSize_ / 2) * 2 * paramNum_;
    DataCopyPad(vec1ParamGm[wsInfoOffset], tmpfBuff.template ReinterpretCast(),
                {1, static_cast((s1BaseSize_ / 2) * 2 * paramNum_ * sizeof(int64_t)), 0, 0});
    SetWaitFlag(HardEvent::MTE3_V);
    outQueue_.FreeTensor(tmpfBuff);
}

// Re-initialise the pipe with the (smaller) buffer set used by the LD merge phase.
template
__aicore__ inline void LIVector::InitLDBuffers(TPipe *pipe)
{
    pipe->Reset();
    pipe->InitBuffer(ldToBeMrgBuf_, 2 * BASE_TOPK * mrgListNum_ * sizeof(float)); // 2:value + index
    pipe->InitBuffer(ldTmpBuf_, 2 * BASE_TOPK * mrgListNum_ * sizeof(float));     // 2:value + index
    pipe->InitBuffer(ldOutValueBuf_, BASE_TOPK * sizeof(float));
    pipe->InitBuffer(ldOutIdxBuf_, BASE_TOPK * sizeof(int32_t));
}

// Cache scalar shape/tiling parameters. tilingData is currently unused here;
// all values come from constInfo.
template
__aicore__ inline void LIVector::InitParams(const struct LICommon::ConstInfo &constInfo,
                                            const LITilingData *__restrict tilingData)
{
    this->constInfo_ = constInfo;
    blockS2StartIdx_ = 0;
    gSize_ = constInfo.gSize;
    // define N2 para
    kHeadNum_ = constInfo.kHeadNum;
    // define MMBase para
    s1BaseSize_ = constInfo.s1BaseSize;
    s2BaseSize_ = constInfo.s2BaseSize;

    groupInner_ = 16;
    blockId_ = GetBlockIdx();
}

// Bind the global-memory tensors used by the vector phase.
template
__aicore__ inline void
LIVector::InitVec1GlobalTensor(GlobalTensor mm1ResGm, GlobalTensor vec1ResGm,
                               GlobalTensor vec1ParamGm, GlobalTensor weightsGm,
                               GlobalTensor indiceOutGm)
{
    this->mm1ResGm = mm1ResGm;
    this->vec1ResGm = vec1ResGm;
    this->vec1ParamGm = vec1ParamGm;
    this->weightsGm = weightsGm;
    this->indiceOutGm = indiceOutGm;
}

// Event IDs are fetched on demand via SetWaitFlag; nothing to allocate.
template
__aicore__ inline void LIVector::AllocEventID()
{
}

template
__aicore__ inline void LIVector::FreeEventID()
{
}

// Fill one output row (sparseCount indices starting at invalidS1offset) with the
// INVALID_IDX sentinel; used for query rows that have no valid keys.
template
__aicore__ inline void LIVector::CleanInvalidOutput(int64_t invalidS1offset)
{
    // init -1 and copy to output
    LocalTensor valueULocal = outQueue_.AllocTensor();
    LocalTensor idxULocal1 = valueULocal.template ReinterpretCast();
    Duplicate(idxULocal1, constInfo_.INVALID_IDX, constInfo_.sparseCount);
    outQueue_.EnQue(valueULocal);
    valueULocal = outQueue_.DeQue();
    LIServiceVec::CopyOut(indiceOutGm[invalidS1offset], idxULocal1, constInfo_.sparseCount);
    outQueue_.FreeTensor(valueULocal);
}

// One vector-phase step: for the current (s1 tile, s2 tile) pair, scale the matmul
// scores by per-head weights, reduce over the group dimension, sort the reduced
// scores and merge them into the per-row running top-k.  Each AIV pair splits the
// s1 tile in half (even block takes the first CeilDiv half).
// NOTE(review): template-argument lists (<...>) were stripped during patch
// extraction throughout this function (PipeBarrier, LocalTensor, casts, Sort, ...).
template
__aicore__ inline void LIVector::ProcessVec(const LICommon::RunInfo &info)
{
    int32_t cuBaseS1Idx = info.gS1Idx * s1BaseSize_;
    int32_t cuBaseS2Idx = info.s2Idx * s2BaseSize_;

    // mm1 workspace is double-buffered by loop parity
    int64_t mmGmOffset = (info.loop % 2) * ((s1BaseSize_ * gSize_) * s2BaseSize_);
    int64_t weightGmOffset = info.tensorWeightsOffset + cuBaseS1Idx * kHeadNum_ * gSize_;

    PipeBarrier();
    int32_t cuS1BeginIdxPerAiv = cuBaseS1Idx;
    int32_t cuS1ProcNum =
        cuS1BeginIdxPerAiv + s1BaseSize_ > info.actS1Size ? info.actS1Size % s1BaseSize_ : s1BaseSize_;
    // even AIV takes the (possibly larger) first half of the tile, odd AIV the rest
    int32_t cuS1ProcNumPerAiv = blockId_ % 2 == 0 ? CeilDiv(cuS1ProcNum, 2) : (cuS1ProcNum / 2);
    cuS1BeginIdxPerAiv += (blockId_ % 2) * CeilDiv(cuS1ProcNum, 2);

    weightGmOffset += (blockId_ % 2) * CeilDiv(cuS1ProcNum, 2) * kHeadNum_ * gSize_;
    mmGmOffset += (blockId_ % 2) * CeilDiv(cuS1ProcNum, 2) * gSize_ * info.actualSingleProcessSInnerSizeAlign;

    // cut G
    int32_t outerG = CeilDiv(gSize_, groupInner_);

    if (info.loop != 0 && info.s2Idx == 0) {
        // globalTopkUb_ value,index=-inf,-1
        InitSortOutBuf(globalTopkUb_, CeilDiv(s1BaseSize_, 2) * BASE_TOPK * 2);
        blockS2StartIdx_ = 0;
    } else if (info.loop == 0) {
        blockS2StartIdx_ = info.s2Idx;
    }
    int32_t cuRealAcSeq = info.actS2Size;
    if (constInfo_.attenMaskFlag) {
        // causal mask: each query row sees one more key than the previous one
        cuRealAcSeq = info.actS2Size - (info.actS1Size - cuS1BeginIdxPerAiv);
    }
    LocalTensor reduceOutBuff = reduceOutBuf_.Get();
    LocalTensor brcBuf = brcBuf_.Get();
    uint32_t ldS1Offset = (blockId_ % 2 == 0) ? s1BaseSize_ / 2 - cuS1ProcNumPerAiv : 0;
    for (int innerS1Idx = 0; innerS1Idx < cuS1ProcNumPerAiv; innerS1Idx++) {
        if (constInfo_.attenMaskFlag) {
            cuRealAcSeq += 1;
        }
        int32_t cuS2Len = cuBaseS2Idx + s2BaseSize_ >= cuRealAcSeq ?
            cuRealAcSeq - cuBaseS2Idx : s2BaseSize_;
        int32_t cuS1Idx = cuS1BeginIdxPerAiv + innerS1Idx;
        if (cuRealAcSeq > 0 && cuS2Len > 0) {
            int32_t cuS2LenVecAlign = CeilDiv(cuS2Len, s2BaseSize_) * s2BaseSize_;
            int32_t mmUbStride = (cuS2LenVecAlign - info.actualSingleProcessSInnerSizeAlign) / B32_BLOCK_ALIGN_NUM;
            LocalTensor reduceOutInner = reduceOutBuff[s2BaseSize_];
            PipeBarrier();
            LocalTensor reduceCacheBuf = outQueue_.AllocTensor();
            // accumulate weighted scores over G in chunks of groupInner_
            for (int outerGidx = 0; outerGidx < outerG; outerGidx++) {
                int32_t procGnum = outerGidx != outerG - 1 ? groupInner_ : gSize_ - outerGidx * groupInner_;
                LocalTensor mmInUb = inQueue_.AllocTensor();
                LocalTensor weightsInUb = mmInUb[procGnum * s2BaseSize_];
                LocalTensor weightsInTUb = weightsInUb.template ReinterpretCast();
                if constexpr (!IsSameType::value) {
                    // leave room in front for the float-cast copy of the weights
                    weightsInTUb = weightsInTUb[groupInner_];
                }
                LIServiceVec::CopyIn(mmInUb, weightsInTUb, mm1ResGm, weightsGm,
                                     mmGmOffset + innerS1Idx * gSize_ * info.actualSingleProcessSInnerSizeAlign +
                                         outerGidx * groupInner_ * info.actualSingleProcessSInnerSizeAlign,
                                     weightGmOffset + innerS1Idx * gSize_ + outerGidx * groupInner_, procGnum,
                                     info.actualSingleProcessSInnerSizeAlign, mmUbStride);

                inQueue_.EnQue(mmInUb);
                mmInUb = inQueue_.DeQue();
                weightsInUb = mmInUb[procGnum * s2BaseSize_];
                LIServiceVec::DoScale(reduceCacheBuf[REDUCE_BANK_CONFLICT_NUM], mmInUb, weightsInUb, weightsInTUb,
                                      brcBuf, procGnum, s2BaseSize_, outerGidx);
                // confused reduceOp in DoScale
                // neednot use LIServiceVec::doReduce(mmInUb, reduceOutInner, procGnum, (s2BaseSize_+8));
                inQueue_.FreeTensor(mmInUb);
            }

            int32_t gRedCnt = groupInner_ > gSize_ ?
                gSize_ : groupInner_;
            bool isS2End = cuBaseS2Idx + s2BaseSize_ >= cuRealAcSeq;
            LIServiceVec::DoReduce(reduceCacheBuf[REDUCE_BANK_CONFLICT_NUM], reduceOutInner, gRedCnt, s2BaseSize_);
            outQueue_.FreeTensor(reduceCacheBuf);

            // build (score, global-key-index) pairs, padding the tail with (-inf, -1)
            LocalTensor sortScoreUb = reduceOutBuff;
            LocalTensor sortIndiceUb = reduceOutBuff[cuS2LenVecAlign];
            PipeBarrier();
            Duplicate(sortScoreUb.template ReinterpretCast(), LIServiceVec::NEG_INF, cuS2LenVecAlign);
            PipeBarrier();
            Adds(sortScoreUb, reduceOutInner, 0.0f, cuS2Len);
            PipeBarrier();
            LocalTensor sortIndiceUbInt = sortIndiceUb.template ReinterpretCast();
            if (cuS2LenVecAlign != cuS2Len) {
                Duplicate(sortIndiceUbInt, -1, cuS2LenVecAlign);
            }
            PipeBarrier();
            Adds(sortIndiceUbInt, globalTopkIndice_, static_cast(cuBaseS2Idx), cuS2Len);
            PipeBarrier();

            LocalTensor tmpSortBuf = outQueue_.AllocTensor();
            if (info.actS1Size > 4) {
                // large-S1 path: full sort then merge into the running top-k
                LIServiceVec::SortAll(reduceOutBuff, tmpSortBuf,
                                      cuS2LenVecAlign); // cuS2LenVecAlign <= s2BaseSize_, fill -inf
                PipeBarrier();
                LIServiceVec::MergeSort(globalTopkUb_[innerS1Idx * BASE_TOPK * 2], BASE_TOPK, reduceOutBuff,
                                        cuS2LenVecAlign, tmpSortBuf);
            } else {
                // small-S1 path: batch up to 4 sorted basic blocks, then merge at once
                int64_t globalTopkUbCacheIdx = (info.s2Idx - blockS2StartIdx_) % 4;
                Sort(
                    SortedBasicBlock_[innerS1Idx * BASE_TOPK * 2 + globalTopkUbCacheIdx * s2BaseSize_ * 2],
                    reduceOutBuff, sortIndiceUbInt.template ReinterpretCast(), tmpSortBuf,
                    cuS2LenVecAlign / 32);
                if (globalTopkUbCacheIdx == 3 || isS2End || info.isAllLoopEnd) {
                    LocalTensor tt = SortedBasicBlock_[innerS1Idx * BASE_TOPK * 2];
                    if (info.s2Idx - blockS2StartIdx_ < 4) {
                        // first batch: merged blocks become the running top-k directly
                        MrgBasicBlock(globalTopkUb_[innerS1Idx * BASE_TOPK * 2], tt,
                                      static_cast(globalTopkUbCacheIdx + 1), s2BaseSize_);
                    } else {
                        if (globalTopkUbCacheIdx > 0) {
                            MrgBasicBlock(tmpSortBuf, tt, static_cast(globalTopkUbCacheIdx + 1), s2BaseSize_);
                            PipeBarrier();
                            DataCopy(SortedBasicBlock_[innerS1Idx * BASE_TOPK * 2], tmpSortBuf,
                                     (globalTopkUbCacheIdx + 1) * s2BaseSize_ * 2);
                        }
                        PipeBarrier();
                        SparseTopK(globalTopkUb_[innerS1Idx * BASE_TOPK * 2],
                                   SortedBasicBlock_[innerS1Idx * BASE_TOPK * 2], tmpSortBuf, BASE_TOPK,
                                   s2BaseSize_ * (globalTopkUbCacheIdx + 1));
                    }
                }
            }

            PipeBarrier();
            outQueue_.FreeTensor(tmpSortBuf);

            // row finished entirely on this core -> write final indices;
            // otherwise persist the partial top-k for the LD merge phase
            bool needCopyOutGm = blockS2StartIdx_ == 0 && isS2End;
            bool needCopyWsGm = info.isAllLoopEnd || isS2End;

            if (needCopyOutGm) {
                LocalTensor valueULocal = outQueue_.AllocTensor();
                LocalTensor idxULocal = valueULocal.template ReinterpretCast()[BASE_TOPK];
                ExtractIndex(idxULocal, globalTopkUb_[innerS1Idx * BASE_TOPK * 2].template ReinterpretCast(),
                             BASE_TOPK);
                PipeBarrier();
                InitSortOutBuf(globalTopkUb_[innerS1Idx * BASE_TOPK * 2], BASE_TOPK * 2);
                outQueue_.EnQue(valueULocal);
                valueULocal = outQueue_.DeQue();
                LocalTensor idxULocal1 = valueULocal.template ReinterpretCast()[BASE_TOPK];
                LIServiceVec::CopyOut(indiceOutGm[info.indiceOutOffset + cuS1Idx * constInfo_.sparseCount],
                                      idxULocal1, constInfo_.sparseCount);
                outQueue_.FreeTensor(valueULocal);
            } else if (needCopyWsGm) {
                // vec1Res Gm = [aic, s1BaseSize_, 2, 2, topkOut_] float32
                // vec1Param Gm = [aic, s1BaseSize_, 2, 16] int64
                // 16 = [needFd, s2AcSeq, s2Start, s2End, isS2End, bn2idx, s1Idx, S1ProcNum, ......]
                int64_t wsOffset = (blockId_ / 2) * s1BaseSize_ * 2 * 2 * BASE_TOPK +
                                   (blockId_ % 2) * (s1BaseSize_ / 2) * 2 * 2 * BASE_TOPK +
                                   (ldS1Offset + innerS1Idx) * 2 * 2 * BASE_TOPK;
                int64_t wsInfoOffset = (blockId_ / 2) * s1BaseSize_ * 2 * paramNum_ +
                                       (blockId_ % 2) * (s1BaseSize_ / 2) * 2 * paramNum_ +
                                       (ldS1Offset + innerS1Idx) * 2 * paramNum_;

                LocalTensor tmpiBuff = paramBuf_.Get();
                SetWaitFlag(HardEvent::MTE3_S);
                tmpiBuff.SetValue(0, static_cast(1));
                tmpiBuff.SetValue(1, static_cast(cuRealAcSeq));
                tmpiBuff.SetValue(2, static_cast(blockS2StartIdx_));
                tmpiBuff.SetValue(3, static_cast(cuBaseS2Idx + cuS2Len));
                tmpiBuff.SetValue(4, static_cast(isS2End));
                tmpiBuff.SetValue(5, static_cast(info.bN2Idx));
                tmpiBuff.SetValue(6, static_cast(cuS1Idx));
                tmpiBuff.SetValue(7, static_cast(cuS1ProcNum));
                tmpiBuff.SetValue(8, static_cast(info.indiceOutOffset + cuS1Idx * constInfo_.sparseCount));
                bool isTailReduce = blockS2StartIdx_ == 0;
                if (isTailReduce) {
                    // head segment goes into the second slot of the [2] pair
                    wsInfoOffset += paramNum_;
                    wsOffset += 2 * BASE_TOPK;
                }
                SetWaitFlag(HardEvent::S_MTE3);
                LIServiceVec::CopyOut(vec1ParamGm[wsInfoOffset], tmpiBuff, 16);
                SetWaitFlag(HardEvent::V_MTE3);
                LIServiceVec::CopyOut(vec1ResGm[wsOffset], globalTopkUb_[innerS1Idx * BASE_TOPK * 2], 2 * BASE_TOPK);
                SetWaitFlag(HardEvent::MTE3_V);
            }
        } else if (cuRealAcSeq <= 0) {
            CleanInvalidOutput(info.indiceOutOffset + cuS1Idx * constInfo_.sparseCount);
        }
    }

    if (LAYOUT_T == LI_LAYOUT::BSND) {
        // pad rows beyond the actual sequence length with invalid indices
        bool isS1LoopEnd = (cuBaseS1Idx + s1BaseSize_) >= info.actS1Size;
        int32_t invalidS1Num = constInfo_.qSeqSize - info.actS1Size;
        if (invalidS1Num > 0 && isS1LoopEnd && blockS2StartIdx_ == 0) {
            int32_t s1NumPerAiv = blockId_ % 2 == 0 ?
CeilDiv(invalidS1Num, 2) : (invalidS1Num / 2); + int32_t s1OffsetPerAiv = info.actS1Size + (blockId_ % 2) * CeilDiv(invalidS1Num, 2); + for (int innerS1Idx = 0; innerS1Idx < s1NumPerAiv; innerS1Idx++) { + CleanInvalidOutput(info.indiceOutOffset + (s1OffsetPerAiv + innerS1Idx) * constInfo_.sparseCount); + } + } + + int32_t invalidS1Num2 = info.actS1Size - info.actS2Size; + if (invalidS1Num2 > 0 && isS1LoopEnd && blockS2StartIdx_ == 0 && constInfo_.attenMaskFlag) { + int32_t s1NumPerAiv = blockId_ % 2 == 0 ? CeilDiv(invalidS1Num2, 2) : (invalidS1Num2 / 2); + int32_t s1OffsetPerAiv = (blockId_ % 2) * CeilDiv(invalidS1Num2, 2); + for (int innerS1Idx = 0; innerS1Idx < s1NumPerAiv; innerS1Idx++) { + CleanInvalidOutput((info.bN2Idx * constInfo_.qSeqSize + s1OffsetPerAiv + innerS1Idx) * + constInfo_.sparseCount); + } + } + } + + if (info.isLastS2InnerLoop) { + blockS2StartIdx_ = 0; + } +} + +template +__aicore__ inline void LIVector::ProcessLD() +{ + int32_t curCubeId = blockId_ / 2; + int32_t tmpCubeId = curCubeId; + + int64_t s2ActSeq; + int64_t s2Start; + int64_t s2End; + int64_t isS2End; + int64_t bn2Idx; + int64_t s1Idx; + uint32_t acc_list_num = 0; + int64_t bIdx = 0; + int64_t needFd; + int64_t wsOffset; + int64_t wsInfoOffset = 0; + int64_t nextneedFd; + int64_t valueOffset = 0; + int64_t outOffset = 0; + + LocalTensor curValueIdxUb = ldToBeMrgBuf_.Get(); + LocalTensor tmpUb = ldTmpBuf_.Get(); + + uint32_t s1LdStartIdx = 0; + uint32_t s1ProcNum = 0; + uint64_t paramGmCoreOffset = tmpCubeId * s1BaseSize_ * 2 * paramNum_; + for (uint32_t innerS1Idx = 0; innerS1Idx < s1BaseSize_; innerS1Idx++) { + needFd = vec1ParamGm.GetValue(paramGmCoreOffset + innerS1Idx * 2 * paramNum_ + paramNum_); + if (needFd == 1) { + s1LdStartIdx = (s1ProcNum == 0) ? 
innerS1Idx : s1LdStartIdx; + s1ProcNum++; + } + } + + if (s1ProcNum == 0) { + return; + } + + uint32_t s1VecNum = CeilDiv(s1ProcNum, 2); + if (blockId_ % 2 == 1) { + s1LdStartIdx = s1LdStartIdx + s1VecNum; + s1VecNum = s1ProcNum - s1VecNum; + } + for (uint32_t innerS1Idx = s1LdStartIdx; innerS1Idx < s1LdStartIdx + s1VecNum; innerS1Idx++) { + tmpCubeId = curCubeId; + acc_list_num = 0; + valueOffset = 0; + + wsOffset = tmpCubeId * s1BaseSize_ * 2 * 2 * BASE_TOPK + + innerS1Idx * 2 * 2 * BASE_TOPK + 2 * BASE_TOPK; + SetWaitFlag(HardEvent::V_MTE2); + SetWaitFlag(HardEvent::S_MTE2); + DataCopyPad(curValueIdxUb, vec1ResGm[wsOffset], + {1, static_cast(2 * BASE_TOPK * sizeof(int32_t)), 0, 0}, {true, 0, 0, 0}); + acc_list_num++; + valueOffset += 2 * BASE_TOPK; + + tmpCubeId++; + wsInfoOffset = tmpCubeId * s1BaseSize_ * 2 * paramNum_ + innerS1Idx * 2 * paramNum_; + needFd = vec1ParamGm.GetValue(wsInfoOffset); + isS2End = vec1ParamGm.GetValue(wsInfoOffset + 4); + s1Idx = vec1ParamGm.GetValue(wsInfoOffset + 6); + outOffset = vec1ParamGm.GetValue(wsInfoOffset + 8); + + while (needFd == 1) { + wsOffset = tmpCubeId * s1BaseSize_ * 2 * 2 * BASE_TOPK + + innerS1Idx * 2 * 2 * BASE_TOPK; + SetWaitFlag(HardEvent::V_MTE2); + SetWaitFlag(HardEvent::S_MTE2); + DataCopyPad(curValueIdxUb[valueOffset], vec1ResGm[wsOffset], + {1, static_cast(2 * BASE_TOPK * sizeof(int32_t)), 0, 0}, {true, 0, 0, 0}); + valueOffset += 2 * BASE_TOPK; + acc_list_num++; + + if (acc_list_num == mrgListNum_) { + AscendC::MrgSort4Info params; + params.elementLengths[0] = BASE_TOPK; + params.elementLengths[1] = BASE_TOPK; + params.elementLengths[2] = BASE_TOPK; + params.elementLengths[3] = BASE_TOPK; + params.ifExhaustedSuspension = true; + params.validBit = 0b1111; + params.repeatTimes = 1; + + AscendC::MrgSortSrcList srcList; + srcList.src1 = curValueIdxUb[0]; + srcList.src2 = curValueIdxUb[2 * BASE_TOPK]; + srcList.src3 = curValueIdxUb[4 * BASE_TOPK]; + srcList.src4 = curValueIdxUb[6 * BASE_TOPK]; + 
SetWaitFlag(HardEvent::MTE2_V); + MrgSort(tmpUb, srcList, params); + PipeBarrier(); + DataCopy(curValueIdxUb, tmpUb, 2 * BASE_TOPK); + PipeBarrier(); + acc_list_num = 1; + valueOffset = 2 * BASE_TOPK; + } + + if (isS2End == 1) { + break; + } + + tmpCubeId++; + wsInfoOffset = tmpCubeId * s1BaseSize_ * 2 * paramNum_ + innerS1Idx * 2 * paramNum_; + needFd = vec1ParamGm.GetValue(wsInfoOffset); + isS2End = vec1ParamGm.GetValue(wsInfoOffset + 4); + } + + if (acc_list_num != 1) { + AscendC::MrgSort4Info params; + params.elementLengths[0] = BASE_TOPK; + params.elementLengths[1] = BASE_TOPK; + params.elementLengths[2] = BASE_TOPK; + params.elementLengths[3] = BASE_TOPK; + params.ifExhaustedSuspension = true; + if (acc_list_num == 2) { + params.validBit = 0b0011; + } else if (acc_list_num == 3) { + params.validBit = 0b0111; + } + params.repeatTimes = 1; + + AscendC::MrgSortSrcList srcList; + srcList.src1 = curValueIdxUb[0]; + srcList.src2 = curValueIdxUb[2 * BASE_TOPK]; + srcList.src3 = curValueIdxUb[4 * BASE_TOPK]; + srcList.src4 = curValueIdxUb[6 * BASE_TOPK]; + SetWaitFlag(HardEvent::MTE2_V); + MrgSort(tmpUb, srcList, params); + PipeBarrier(); + DataCopy(curValueIdxUb, tmpUb, 2 * BASE_TOPK); + PipeBarrier(); + } + + LocalTensor outValueUb = ldOutValueBuf_.Get(); + LocalTensor outIdxUb = ldOutIdxBuf_.Get(); + + Extract(outValueUb, outIdxUb, curValueIdxUb, (BASE_TOPK / 32)); + LocalTensor idxULocal1 = outIdxUb.template ReinterpretCast(); + SetWaitFlag(HardEvent::V_MTE3); + SetWaitFlag(HardEvent::S_MTE3); + DataCopyPad(indiceOutGm[outOffset], idxULocal1, + {1, static_cast(constInfo_.sparseCount * sizeof(int32_t)), 0, 0}); + SetWaitFlag(HardEvent::MTE3_V); + } +} +} // namespace LIKernel +#endif \ No newline at end of file diff --git a/csrc/lightning_indexer/op_kernel/lightning_indexer_template_tiling_key.h b/csrc/lightning_indexer/op_kernel/lightning_indexer_template_tiling_key.h new file mode 100644 index 00000000..a4ce580a --- /dev/null +++ 
b/csrc/lightning_indexer/op_kernel/lightning_indexer_template_tiling_key.h
@@ -0,0 +1,66 @@
/**
 * This program is free software, you can redistribute it and/or modify it.
 * Copyright (c) 2025 Huawei Technologies Co., Ltd.
 * This file is a part of the CANN Open Software.
 * Licensed under CANN Open Software License Agreement Version 2.0 (the "License").
 * Please refer to the License for details. You may not use this file except in compliance with the License.
 * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE.
 * See LICENSE in the root of the software repository for the full text of the License.
 */

/*!
 * \file lightning_indexer_template_tiling_key.h
 * \brief Declares the template/tiling-key combinations instantiated for the
 *        LightningIndexer kernel (dtype x page-attention x layout selections).
 */

#ifndef TEMPLATE_TILING_KEY_LI_H_
#define TEMPLATE_TILING_KEY_LI_H_

#include "ascendc/host_api/tiling/template_argument.h"

// dtype codes used in the tiling key (match ge::DataType enum values)
#define LI_TPL_FP16 1
#define LI_TPL_INT32 3
#define LI_TPL_BF16 27

// tensor layout codes
#define LI_LAYOUT_BSND 0
#define LI_LAYOUT_TND 1
#define LI_LAYOUT_PA_BSND 2

// bit width reserved for each UINT field in the tiling key
#define ASCENDC_TPL_4_BW 4

// Declare all template axes: query/key dtype, output dtype, page-attention flag,
// and query/key layouts.
ASCENDC_TPL_ARGS_DECL(LightningIndexer,
                      ASCENDC_TPL_DTYPE_DECL(DT_Q, LI_TPL_FP16, LI_TPL_BF16),
                      ASCENDC_TPL_DTYPE_DECL(DT_K, LI_TPL_FP16, LI_TPL_BF16),
                      ASCENDC_TPL_DTYPE_DECL(DT_OUT, LI_TPL_INT32), ASCENDC_TPL_BOOL_DECL(PAGE_ATTENTION, 0, 1),
                      ASCENDC_TPL_UINT_DECL(LAYOUT_T, ASCENDC_TPL_4_BW, ASCENDC_TPL_UI_LIST, LI_LAYOUT_BSND,
                                            LI_LAYOUT_TND),
                      ASCENDC_TPL_UINT_DECL(K_LAYOUT_T, ASCENDC_TPL_4_BW, ASCENDC_TPL_UI_LIST,
                                            LI_LAYOUT_PA_BSND, LI_LAYOUT_BSND, LI_LAYOUT_TND), );

// Instantiated combinations: with page attention the key layout is PA_BSND;
// without it the key layout follows the query layout (BSND/TND). fp16 and bf16
// are selected symmetrically for Q and K.
ASCENDC_TPL_SEL(
    ASCENDC_TPL_ARGS_SEL(ASCENDC_TPL_DTYPE_SEL(DT_Q, LI_TPL_FP16), ASCENDC_TPL_DTYPE_SEL(DT_K, LI_TPL_FP16),
                         ASCENDC_TPL_DTYPE_SEL(DT_OUT, LI_TPL_INT32),
                         ASCENDC_TPL_BOOL_SEL(PAGE_ATTENTION, 1),
                         ASCENDC_TPL_UINT_SEL(LAYOUT_T, ASCENDC_TPL_UI_LIST, LI_LAYOUT_BSND, LI_LAYOUT_TND),
                         ASCENDC_TPL_UINT_SEL(K_LAYOUT_T, ASCENDC_TPL_UI_LIST, LI_LAYOUT_PA_BSND), ),

    ASCENDC_TPL_ARGS_SEL(ASCENDC_TPL_DTYPE_SEL(DT_Q, LI_TPL_BF16), ASCENDC_TPL_DTYPE_SEL(DT_K, LI_TPL_BF16),
                         ASCENDC_TPL_DTYPE_SEL(DT_OUT, LI_TPL_INT32),
                         ASCENDC_TPL_BOOL_SEL(PAGE_ATTENTION, 1),
                         ASCENDC_TPL_UINT_SEL(LAYOUT_T, ASCENDC_TPL_UI_LIST, LI_LAYOUT_BSND, LI_LAYOUT_TND),
                         ASCENDC_TPL_UINT_SEL(K_LAYOUT_T, ASCENDC_TPL_UI_LIST, LI_LAYOUT_PA_BSND), ),

    ASCENDC_TPL_ARGS_SEL(ASCENDC_TPL_DTYPE_SEL(DT_Q, LI_TPL_FP16), ASCENDC_TPL_DTYPE_SEL(DT_K, LI_TPL_FP16),
                         ASCENDC_TPL_DTYPE_SEL(DT_OUT, LI_TPL_INT32),
                         ASCENDC_TPL_BOOL_SEL(PAGE_ATTENTION, 0),
                         ASCENDC_TPL_UINT_SEL(LAYOUT_T, ASCENDC_TPL_UI_LIST, LI_LAYOUT_BSND, LI_LAYOUT_TND),
                         ASCENDC_TPL_UINT_SEL(K_LAYOUT_T, ASCENDC_TPL_UI_LIST,
                                              LI_LAYOUT_BSND, LI_LAYOUT_TND), ),

    ASCENDC_TPL_ARGS_SEL(ASCENDC_TPL_DTYPE_SEL(DT_Q, LI_TPL_BF16), ASCENDC_TPL_DTYPE_SEL(DT_K, LI_TPL_BF16),
                         ASCENDC_TPL_DTYPE_SEL(DT_OUT, LI_TPL_INT32),
                         ASCENDC_TPL_BOOL_SEL(PAGE_ATTENTION, 0),
                         ASCENDC_TPL_UINT_SEL(LAYOUT_T, ASCENDC_TPL_UI_LIST, LI_LAYOUT_BSND, LI_LAYOUT_TND),
                         ASCENDC_TPL_UINT_SEL(K_LAYOUT_T, ASCENDC_TPL_UI_LIST, LI_LAYOUT_BSND, LI_LAYOUT_TND), ), );

#endif
\ No newline at end of file
diff --git a/csrc/lightning_indexer/op_kernel/lightning_indexer_vector.h b/csrc/lightning_indexer/op_kernel/lightning_indexer_vector.h
new file mode 100644
index 00000000..96290127
--- /dev/null
+++ b/csrc/lightning_indexer/op_kernel/lightning_indexer_vector.h
@@ -0,0 +1,335 @@
/**
 * This program is free software, you can redistribute it and/or modify it.
 * Copyright (c) 2025 Huawei Technologies Co., Ltd.
 * This file is a part of the CANN Open Software.
 * Licensed under CANN Open Software License Agreement Version 2.0 (the "License").
 * Please refer to the License for details. You may not use this file except in compliance with the License.
+ * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. + * See LICENSE in the root of the software repository for the full text of the License. + */ + +/*! + * \file lightning_indexer_vector.h + * \brief + */ +#ifndef LIGHTNING_INDEXER_VECTOR_H +#define LIGHTNING_INDEXER_VECTOR_H + +#include "lightning_indexer_vector.h" +#include "kernel_operator.h" + +namespace LIServiceVec { +using namespace AscendC; + +constexpr int32_t NEG_INF = 0xFF800000; +constexpr int32_t INVALID_INDEX = -1; +constexpr uint8_t VEC_REPEAT_MAX = 255; +constexpr uint8_t B32_VEC_ELM_NUM = 64; +constexpr uint8_t B32_BLOCK_ALIGN_NUM = 8; +constexpr uint8_t B32_VEC_REPEAT_STRIDE = 8; +constexpr uint64_t VEC_REPEAT_BYTES = 256; +constexpr int32_t CONST_TWO = 2; +constexpr int64_t VALUE_AND_INDEX_NUM = 2; +constexpr int64_t BLOCK_BYTES = 32; +constexpr int64_t MRG_QUE_0 = 0; +constexpr int64_t MRG_QUE_1 = 1; +constexpr int64_t MRG_QUE_2 = 2; +constexpr int64_t MRG_QUE_3 = 3; +constexpr int64_t MRG_BLOCK_2 = 2; +constexpr int64_t MRG_BLOCK_3 = 3; +constexpr int64_t MRG_BLOCK_4 = 4; + +template +__aicore__ inline void CopyIn(LocalTensor &mmOutUb, LocalTensor &weightsUb, GlobalTensor &mMoutGm, + GlobalTensor &weightScaleGm, int64_t MMout_gmoffset, int64_t weights_gmoffset, + int64_t groupInner, int64_t s2Inner, int64_t mmUbStride) +{ + AscendC::DataCopyPadExtParams padParams{false, 0, 0, 0}; + AscendC::DataCopyExtParams dataCopymMoutParams; + dataCopymMoutParams.blockCount = groupInner; + dataCopymMoutParams.blockLen = s2Inner * sizeof(float); + dataCopymMoutParams.srcStride = 0; + dataCopymMoutParams.dstStride = mmUbStride; + dataCopymMoutParams.rsv = 0; + AscendC::DataCopyPad(mmOutUb, mMoutGm[MMout_gmoffset], dataCopymMoutParams, padParams); + + AscendC::DataCopyPadExtParams padTParams{false, 0, 0, 0}; + AscendC::DataCopyExtParams 
dataCopyweightParams; + dataCopyweightParams.blockCount = 1; + dataCopyweightParams.blockLen = groupInner * sizeof(T); + dataCopyweightParams.srcStride = 0; + dataCopyweightParams.dstStride = 0; + dataCopyweightParams.rsv = 0; + AscendC::DataCopyPad(weightsUb, weightScaleGm[weights_gmoffset], dataCopyweightParams, padTParams); +} + + +template +__aicore__ inline void CopyOut(const GlobalTensor &dstGm, const LocalTensor &srcUb, int64_t copyCount) +{ + AscendC::DataCopyParams dataCopyOutyParams; + dataCopyOutyParams.blockCount = 1; + dataCopyOutyParams.blockLen = copyCount * sizeof(T); + dataCopyOutyParams.srcStride = 0; + dataCopyOutyParams.dstStride = 0; + AscendC::DataCopyPad(dstGm, srcUb, dataCopyOutyParams); +} + + +template +__aicore__ inline void DoScale(const LocalTensor &reduceCacheBuf, LocalTensor &mmOutUb, + LocalTensor &weightsUb, LocalTensor &weightsTUb, LocalTensor &tmpBuff, + int64_t groupInner, int64_t s2Inner, int32_t outerGidx) +{ + // cast bfloat16_t to float + if constexpr (!IsSameType::value) { + AscendC::Cast(weightsUb, weightsTUb, RoundMode::CAST_NONE, groupInner); + AscendC::PipeBarrier(); + } + + // weight broadcast: [groupInner, 1] -> [groupInner, 8] + AscendC::Brcb(tmpBuff, weightsUb, LICommon::CeilDiv(groupInner, static_cast(B32_BLOCK_ALIGN_NUM)), + {1, B32_VEC_REPEAT_STRIDE}); + AscendC::PipeBarrier(); + + // do scale: [groupInner, 8] * [groupInner, s2Inner] + uint64_t countPerRepeat = VEC_REPEAT_BYTES / sizeof(float); + uint64_t repeatTimes = s2Inner / countPerRepeat; + for (int32_t i = 0; i < groupInner; i++) { + if (outerGidx == 0) { + AscendC::Mul(reduceCacheBuf[i * s2Inner], mmOutUb[i * s2Inner], tmpBuff[i * B32_BLOCK_ALIGN_NUM], + countPerRepeat, repeatTimes, {1, 1, 0, B32_VEC_REPEAT_STRIDE, B32_VEC_REPEAT_STRIDE, 0}); + } else { + AscendC::Mul(mmOutUb[i * s2Inner], mmOutUb[i * s2Inner], tmpBuff[i * B32_BLOCK_ALIGN_NUM], countPerRepeat, + repeatTimes, {1, 1, 0, B32_VEC_REPEAT_STRIDE, B32_VEC_REPEAT_STRIDE, 0}); + } + } + + if 
(outerGidx != 0) { + AscendC::PipeBarrier(); + AscendC::Add(reduceCacheBuf, mmOutUb, reduceCacheBuf, groupInner * s2Inner); + } + AscendC::PipeBarrier(); +} + + +__aicore__ inline uint64_t FindNearestPower2(uint64_t value) +{ + if (value <= CONST_TWO) { + return value; + } else { + const uint64_t pow = 63 - clz(value); + return (1 << pow); + } +} + + +__aicore__ inline void DoReduce(const LocalTensor &srcTensor, LocalTensor &dstTensor, int32_t rNum, + int32_t aNum) +{ + if (rNum == 1) { + AscendC::Adds(dstTensor, srcTensor, 0, aNum); + AscendC::PipeBarrier(); + return; + } + + uint32_t dichotomizeAddPow = FindNearestPower2(rNum); + uint32_t dichotomizeAddDiffSize = rNum - dichotomizeAddPow; + if (dichotomizeAddDiffSize != 0) { + AscendC::Add(srcTensor, srcTensor, srcTensor[dichotomizeAddPow * aNum], dichotomizeAddDiffSize * aNum); + AscendC::PipeBarrier(); + } + int32_t nowRows = dichotomizeAddPow; + while (nowRows > CONST_TWO) { + nowRows = nowRows / CONST_TWO; + AscendC::Add(srcTensor, srcTensor, srcTensor[nowRows * aNum], nowRows * aNum); + AscendC::PipeBarrier(); + } + AscendC::Add(dstTensor, srcTensor, srcTensor[aNum], aNum); + AscendC::PipeBarrier(); +} + +__aicore__ inline void InitSortOutBuf(const LocalTensor &src, int64_t eleNum) +{ + uint64_t mask1[2] = {0x5555555555555555, 0}; + uint64_t mask0[2] = {0xaaaaaaaaaaaaaaaa, 0}; + int64_t repeatNum = eleNum / B32_VEC_ELM_NUM; + int64_t forLoop = repeatNum / VEC_REPEAT_MAX; + int64_t forRemain = repeatNum % VEC_REPEAT_MAX; + for (int i = 0; i < forLoop; i++) { + AscendC::Duplicate(src.template ReinterpretCast(), NEG_INF, mask1, VEC_REPEAT_MAX, 1, + B32_VEC_REPEAT_STRIDE); + AscendC::Duplicate(src.template ReinterpretCast(), INVALID_INDEX, mask0, VEC_REPEAT_MAX, 1, + B32_VEC_REPEAT_STRIDE); + } + if (forRemain > 0) { + AscendC::Duplicate(src.template ReinterpretCast()[forLoop * VEC_REPEAT_MAX * B32_VEC_ELM_NUM], NEG_INF, + mask1, forRemain, 1, B32_VEC_REPEAT_STRIDE); + AscendC::Duplicate(src.template 
                               ReinterpretCast()[forLoop * VEC_REPEAT_MAX * B32_VEC_ELM_NUM],
                           INVALID_INDEX, mask0, forRemain, 1, B32_VEC_REPEAT_STRIDE);
    }
    AscendC::PipeBarrier();
}

// Fully sort logitsNum (value, index) records in src: Sort32 produces sorted
// 32-element runs, then a 4-way MrgSort tree merges them, ping-ponging between
// src and tmp.  The result always ends up back in src.
// NOTE(review): template-argument lists (<...>) were stripped during patch extraction.
__aicore__ inline void SortAll(LocalTensor &src, LocalTensor &tmp, int64_t logitsNum)
{
    int64_t sort32Repeats = logitsNum / BLOCK_BYTES;
    AscendC::Sort32(tmp, src, src[logitsNum].ReinterpretCast(), sort32Repeats);
    AscendC::PipeBarrier();

    int64_t mrgGroups = sort32Repeats;
    int64_t mrgElements = BLOCK_BYTES;
    int64_t i = 0;  // pass counter; parity tells which buffer holds the latest data
    AscendC::LocalTensor srcTensor;
    AscendC::LocalTensor dstTensor;
    while (true) {
        if (i % CONST_TWO == 0) {
            srcTensor = tmp;
            dstTensor = src;
        } else {
            srcTensor = src;
            dstTensor = tmp;
        }
        AscendC::MrgSort4Info params;
        params.elementLengths[0] = mrgElements;
        params.elementLengths[MRG_QUE_1] = mrgElements;
        params.elementLengths[MRG_QUE_2] = mrgElements;
        params.elementLengths[MRG_QUE_3] = mrgElements;
        params.ifExhaustedSuspension = false;
        params.validBit = 0b1111;

        AscendC::MrgSortSrcList srcList;
        srcList.src1 = srcTensor[0];
        srcList.src2 = srcTensor[MRG_QUE_1 * VALUE_AND_INDEX_NUM * mrgElements];
        srcList.src3 = srcTensor[MRG_QUE_2 * VALUE_AND_INDEX_NUM * mrgElements];
        srcList.src4 = srcTensor[MRG_QUE_3 * VALUE_AND_INDEX_NUM * mrgElements];
        if (mrgGroups <= MRG_BLOCK_4) {
            // final pass: merge the remaining 1..4 runs with the matching validBit
            params.repeatTimes = 1;
            if (mrgGroups == 1) {
                break;  // already a single sorted run; no merge needed
            } else if (mrgGroups == MRG_BLOCK_2) {
                params.validBit = 0b0011;
            } else if (mrgGroups == MRG_BLOCK_3) {
                params.validBit = 0b0111;
            } else if (mrgGroups == MRG_BLOCK_4) {
                params.validBit = 0b1111;
            }
            AscendC::MrgSort(dstTensor, srcList, params);
            i += 1;
            break;
        } else {
            // full 4-way pass over all groups; runs quadruple in length
            params.repeatTimes = mrgGroups / MRG_BLOCK_4;
            AscendC::MrgSort(dstTensor, srcList, params);
            i += 1;
            mrgElements = mrgElements * MRG_BLOCK_4;
            mrgGroups = mrgGroups / MRG_BLOCK_4;
        }
        AscendC::PipeBarrier();
    }
    if (i % CONST_TWO == 0) {
        // result landed in tmp; copy it back to src
        AscendC::DataCopy(src, tmp, logitsNum * VALUE_AND_INDEX_NUM);
        AscendC::PipeBarrier();
    }
}

// Overload: sort separate value/index tensors into dst via the fused Sort API.
__aicore__ inline void SortAll(LocalTensor &dst, LocalTensor &srcValue, LocalTensor &srcIndex,
                               LocalTensor &tmpTensor, int64_t logitsNum)
{
    int64_t sort32Repeats = logitsNum / BLOCK_BYTES;
    AscendC::Sort(dst, srcValue, srcIndex, tmpTensor, sort32Repeats);
    AscendC::PipeBarrier();
}

// Two-way merge of the sorted lists mrgDst (length mrgDstNum) and mrgSrc
// (length mrgSrcNum); the best mrgDstNum records are copied back into mrgDst.
__aicore__ inline void MergeSort(const LocalTensor &mrgDst, int32_t mrgDstNum, LocalTensor &mrgSrc,
                                 int32_t mrgSrcNum, LocalTensor &tmpTensor)
{
    AscendC::MrgSort4Info params;
    params.elementLengths[0] = mrgDstNum;
    params.elementLengths[1] = mrgSrcNum;
    params.ifExhaustedSuspension = false;
    params.validBit = 0b0011;
    params.repeatTimes = 1;

    AscendC::MrgSortSrcList srcList;
    srcList.src1 = mrgDst;
    srcList.src2 = mrgSrc;

    AscendC::MrgSort(tmpTensor, srcList, params);
    AscendC::PipeBarrier();
    AscendC::DataCopy(mrgDst, tmpTensor, mrgDstNum * VALUE_AND_INDEX_NUM);
    AscendC::PipeBarrier();
}

// Merge blockNum (1..4) consecutive sorted basic blocks of basicBlockSize records
// from src into dst; a single block is just copied.
__aicore__ inline void MrgBasicBlock(const LocalTensor &dst, const LocalTensor &src, int64_t blockNum,
                                     int64_t basicBlockSize)
{
    AscendC::MrgSort4Info params;
    params.elementLengths[MRG_QUE_0] = basicBlockSize;
    params.elementLengths[MRG_QUE_1] = basicBlockSize;
    params.elementLengths[MRG_QUE_2] = basicBlockSize;
    params.elementLengths[MRG_QUE_3] = basicBlockSize;
    params.ifExhaustedSuspension = false;
    if (blockNum == MRG_BLOCK_2) {
        params.validBit = 0b0011;
    } else if (blockNum == MRG_BLOCK_3) {
        params.validBit = 0b0111;
    } else if (blockNum == MRG_BLOCK_4) {
        params.validBit = 0b1111;
    } else {
        AscendC::DataCopy(dst, src, basicBlockSize * VALUE_AND_INDEX_NUM);
        return;
    }
    AscendC::MrgSortSrcList srcList;
    srcList.src1 = src[0];
    srcList.src2 = src[basicBlockSize * VALUE_AND_INDEX_NUM * MRG_QUE_1];
    srcList.src3 = src[basicBlockSize * VALUE_AND_INDEX_NUM * MRG_QUE_2];
    srcList.src4 = src[basicBlockSize * VALUE_AND_INDEX_NUM * MRG_QUE_3];
    AscendC::MrgSort(dst, srcList, params);
}

// NOTE(review): the template parameter list of the following SparseTopK was
// stripped during extraction (likely "template <bool needMrg>").
template
+__aicore__ inline void SparseTopK(const LocalTensor &dst, const LocalTensor &needsMerging, + const LocalTensor &tmp, int64_t topk, int64_t mergSize) +{ + if (!needMrg) { + AscendC::DataCopy(dst, needsMerging, mergSize * VALUE_AND_INDEX_NUM); + return; + } + AscendC::MrgSort4Info params; + params.elementLengths[0] = topk; + params.elementLengths[1] = mergSize; + params.ifExhaustedSuspension = (topk == mergSize); + params.validBit = 0b0011; + AscendC::MrgSortSrcList srcList; + srcList.src1 = dst; + srcList.src2 = needsMerging; + AscendC::MrgSort(tmp, srcList, params); + AscendC::DataCopy(dst, tmp, topk * VALUE_AND_INDEX_NUM); +} + + +__aicore__ inline void ExtractIndex(const LocalTensor &idxULocal, const LocalTensor &sortLocal, + int64_t extractNum) +{ + AscendC::GatherMaskParams gatherMaskParams; + gatherMaskParams.repeatTimes = Ceil(extractNum * sizeof(float) * VALUE_AND_INDEX_NUM, VEC_REPEAT_BYTES); + gatherMaskParams.src0BlockStride = 1; + gatherMaskParams.src0RepeatStride = B32_VEC_REPEAT_STRIDE; + gatherMaskParams.src1RepeatStride = 0; + uint64_t rsvdCnt = 0; + uint8_t src1Pattern = 2; + AscendC::GatherMask(idxULocal, sortLocal, src1Pattern, false, static_cast(0), gatherMaskParams, rsvdCnt); + AscendC::PipeBarrier(); +} + + +template +__aicore__ inline void SetWaitFlag(HardEvent evt) +{ + event_t eventId = static_cast(GetTPipePtr()->FetchEventID(evt)); + AscendC::SetFlag(eventId); + AscendC::WaitFlag(eventId); +} + +} // namespace LIServiceVec +#endif // LIGHTNING_INDEXER_VECTOR_H \ No newline at end of file diff --git a/csrc/sparse_flash_attention/op_host/CMakeLists.txt b/csrc/sparse_flash_attention/op_host/CMakeLists.txt new file mode 100644 index 00000000..ad24f34e --- /dev/null +++ b/csrc/sparse_flash_attention/op_host/CMakeLists.txt @@ -0,0 +1,39 @@ +# This program is free software, you can redistribute it and/or modify it. +# Copyright (c) 2025 Huawei Technologies Co., Ltd. +# This file is a part of the CANN Open Software. 
+# Licensed under CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. +# ====================================================================================================================== + +add_ops_compile_options( + OP_NAME SparseFlashAttention + OPTIONS --cce-auto-sync=off + -Wno-deprecated-declarations + -Werror + -fpermissive +) + +set(sparse_flash_attention_depends transformer/attention/sparse_flash_attention PARENT_SCOPE) +target_sources(op_host_aclnn PRIVATE + sparse_flash_attention_def.cpp +) + +target_sources(optiling PRIVATE + sparse_flash_attention_tiling.cpp +) + +if (NOT BUILD_OPEN_PROJECT) + target_sources(opmaster_ct PRIVATE + sparse_flash_attention_tiling.cpp + ) +endif () + +target_sources(opsproto PRIVATE + sparse_flash_attention_proto.cpp +) + +target_include_directories(optiling PRIVATE + ${CMAKE_CURRENT_SOURCE_DIR} +) diff --git a/csrc/sparse_flash_attention/op_host/sparse_flash_attention_def.cpp b/csrc/sparse_flash_attention/op_host/sparse_flash_attention_def.cpp new file mode 100644 index 00000000..dbf6879e --- /dev/null +++ b/csrc/sparse_flash_attention/op_host/sparse_flash_attention_def.cpp @@ -0,0 +1,90 @@ +/** + * This program is free software, you can redistribute it and/or modify it. + * Copyright (c) 2025 Huawei Technologies Co., Ltd. + * This file is a part of the CANN Open Software. + * Licensed under CANN Open Software License Agreement Version 2.0 (the "License"). + * Please refer to the License for details. You may not use this file except in compliance with the License. 
+ * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. + * See LICENSE in the root of the software repository for the full text of the License. + */ + +/*! + * \file sparse_flash_attention_def.cpp + * \brief + */ + +#include "register/op_def_registry.h" + +namespace ops { +class SparseFlashAttention : public OpDef { +public: + explicit SparseFlashAttention(const char *name) : OpDef(name) + { + this->Input("query") + .ParamType(REQUIRED) + .DataType({ge::DT_FLOAT16, ge::DT_BF16}) + .Format({ge::FORMAT_ND, ge::FORMAT_ND}) + .AutoContiguous(); + this->Input("key") + .ParamType(REQUIRED) + .DataType({ge::DT_FLOAT16, ge::DT_BF16}) + .Format({ge::FORMAT_ND, ge::FORMAT_ND}) + .AutoContiguous(); + this->Input("value") + .ParamType(REQUIRED) + .DataType({ge::DT_FLOAT16, ge::DT_BF16}) + .Format({ge::FORMAT_ND, ge::FORMAT_ND}) + .AutoContiguous(); + this->Input("sparse_indices") + .ParamType(REQUIRED) + .DataType({ge::DT_INT32, ge::DT_INT32}) + .Format({ge::FORMAT_ND, ge::FORMAT_ND}) + .AutoContiguous(); + this->Input("block_table") + .ParamType(OPTIONAL) + .DataType({ge::DT_INT32, ge::DT_INT32}) + .Format({ge::FORMAT_ND, ge::FORMAT_ND}) + .AutoContiguous(); + this->Input("actual_seq_lengths_query") + .ParamType(OPTIONAL) + .DataType({ge::DT_INT32, ge::DT_INT32}) + .Format({ge::FORMAT_ND, ge::FORMAT_ND}) + .AutoContiguous(); + this->Input("actual_seq_lengths_kv") + .ParamType(OPTIONAL) + .DataType({ge::DT_INT32, ge::DT_INT32}) + .Format({ge::FORMAT_ND, ge::FORMAT_ND}) + .AutoContiguous(); + this->Input("query_rope") + .ParamType(OPTIONAL) + .DataType({ge::DT_FLOAT16, ge::DT_BF16}) + .Format({ge::FORMAT_ND, ge::FORMAT_ND}) + .AutoContiguous(); + this->Input("key_rope") + .ParamType(OPTIONAL) + .DataType({ge::DT_FLOAT16, ge::DT_BF16}) + .Format({ge::FORMAT_ND, ge::FORMAT_ND}) + .AutoContiguous(); + 
this->Output("attention_out") + .ParamType(REQUIRED) + .DataType({ge::DT_FLOAT16, ge::DT_BF16}) + .Format({ge::FORMAT_ND, ge::FORMAT_ND}); + this->Attr("scale_value").AttrType(REQUIRED).Float(1.0); + this->Attr("sparse_block_size").AttrType(REQUIRED).Int(1); + this->Attr("layout_query").AttrType(OPTIONAL).String("BSND"); + this->Attr("layout_kv").AttrType(OPTIONAL).String("BSND"); + this->Attr("sparse_mode").AttrType(OPTIONAL).Int(3); + OpAICoreConfig aicore_config; + aicore_config.DynamicCompileStaticFlag(true) + .DynamicFormatFlag(true) + .DynamicRankSupportFlag(true) + .DynamicShapeSupportFlag(true) + .NeedCheckSupportFlag(false) + .PrecisionReduceFlag(true) + .ExtendCfgInfo("aclnnSupport.value", "support_aclnn"); + this->AICore().AddConfig("ascend910b", aicore_config); + this->AICore().AddConfig("ascend910_93", aicore_config); + } +}; +OP_ADD(SparseFlashAttention); +} // namespace ops diff --git a/csrc/sparse_flash_attention/op_host/sparse_flash_attention_proto.cpp b/csrc/sparse_flash_attention/op_host/sparse_flash_attention_proto.cpp new file mode 100644 index 00000000..07d2091b --- /dev/null +++ b/csrc/sparse_flash_attention/op_host/sparse_flash_attention_proto.cpp @@ -0,0 +1,48 @@ +/** + * This program is free software, you can redistribute it and/or modify it. + * Copyright (c) 2025 Huawei Technologies Co., Ltd. + * This file is a part of the CANN Open Software. + * Licensed under CANN Open Software License Agreement Version 2.0 (the "License"). + * Please refer to the License for details. You may not use this file except in compliance with the License. + * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. + * See LICENSE in the root of the software repository for the full text of the License. + */ + +/*! 
+ * \file sparse_flash_attention_proto.cpp + * \brief + */ + +#include +#include +#include "error/ops_error.h" + +using namespace ge; + +namespace ops { +constexpr size_t QUERY_INPUT_INDEX = 0; + +ge::graphStatus InferShapeSparseFlashAttention(gert::InferShapeContext *context) +{ + OPS_ERR_IF(context == nullptr, OPS_LOG_E("SparseFlashAttention", "InferShapeContext is nullptr"), + return ge::GRAPH_FAILED); + const gert::Shape *queryShape = context->GetInputShape(QUERY_INPUT_INDEX); + OPS_LOG_E_IF_NULL(context, queryShape, return ge::GRAPH_FAILED) + gert::Shape *attentionOutShape = context->GetOutputShape(0); + OPS_LOG_E_IF_NULL(context, attentionOutShape, return ge::GRAPH_FAILED) + *attentionOutShape = *queryShape; + return GRAPH_SUCCESS; +} + +ge::graphStatus InferDataTypeSparseFlashAttention(gert::InferDataTypeContext *context) +{ + OPS_ERR_IF(context == nullptr, OPS_LOG_E("SparseFlashAttention", "InferShapeContext is nullptr"), + return ge::GRAPH_FAILED); + const auto inputDataType = context->GetInputDataType(QUERY_INPUT_INDEX); + context->SetOutputDataType(0, inputDataType); + return ge::GRAPH_SUCCESS; +} + +IMPL_OP(SparseFlashAttention).InferShape(InferShapeSparseFlashAttention).InferDataType(InferDataTypeSparseFlashAttention); +} // namespace ops + diff --git a/csrc/sparse_flash_attention/op_host/sparse_flash_attention_tiling.cpp b/csrc/sparse_flash_attention/op_host/sparse_flash_attention_tiling.cpp new file mode 100644 index 00000000..11a0ddee --- /dev/null +++ b/csrc/sparse_flash_attention/op_host/sparse_flash_attention_tiling.cpp @@ -0,0 +1,1845 @@ +/** + * This program is free software, you can redistribute it and/or modify it. + * Copyright (c) 2025 Huawei Technologies Co., Ltd. + * This file is a part of the CANN Open Software. + * Licensed under CANN Open Software License Agreement Version 2.0 (the "License"). + * Please refer to the License for details. You may not use this file except in compliance with the License. 
+ * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. + * See LICENSE in the root of the software repository for the full text of the License. + */ + +/*! + * \file sparse_flash_attention_tiling.cpp + * \brief + */ + +#include +#include +#include +#include +#include +#include "error/ops_error.h" +#include "register/op_def_registry.h" +#include "../op_kernel/sparse_flash_attention_template_tiling_key.h" +#include "sparse_flash_attention_tiling.h" + +using std::map; +using std::string; +using std::pair; + +using namespace ge; +using namespace AscendC; +namespace optiling { + +constexpr uint32_t PRE_LOAD_NUM = 2; +constexpr uint32_t BLOCK_TABLE_ELEM_BYTE = 4; +constexpr int32_t SPARSE_MODE_BAND = 4; + +static const std::string QUERY_NAME = "query"; +static const std::string KEY_NAME = "key"; +static const std::string VALUE_NAME = "value"; +static const std::string BLOCK_TABLE_NAME = "block_table"; +static const std::string SPARSE_INDICES_NAME = "sparse_indices"; +static const std::string QUERY_ROPE_NAME = "query_rope"; +static const std::string KEY_ROPE_NAME = "key_rope"; +static const std::string ATTEN_OUT_NAME = "attention_out"; + +const std::map> DTYPE_SUPPORT_MAP = { + {QUERY_NAME, {ge::DT_FLOAT16, ge::DT_BF16}}, + {KEY_NAME, {ge::DT_FLOAT16, ge::DT_BF16}}, + {VALUE_NAME, {ge::DT_FLOAT16, ge::DT_BF16}}, + {QUERY_ROPE_NAME, {ge::DT_FLOAT16, ge::DT_BF16}}, + {KEY_ROPE_NAME, {ge::DT_FLOAT16, ge::DT_BF16}}, + {ATTEN_OUT_NAME, {ge::DT_FLOAT16, ge::DT_BF16}}, + {SPARSE_INDICES_NAME, {ge::DT_INT32}} +}; + +const std::map> LAYOUT_SUPPORT_MAP = { + {QUERY_NAME, {SFALayout::BSND, SFALayout::TND}}, + {KEY_NAME, {SFALayout::BSND, SFALayout::TND, SFALayout::PA_BSND}}, + {VALUE_NAME, {SFALayout::BSND, SFALayout::TND, SFALayout::PA_BSND}}, + {ATTEN_OUT_NAME, {SFALayout::BSND, SFALayout::TND}}, +}; + +const std::map 
DATATYPE_TO_STRING_MAP = { + {ge::DT_UNDEFINED, "DT_UNDEFINED"}, // Used to indicate a DataType field has not been set. + {ge::DT_FLOAT, "DT_FLOAT"}, // float type + {ge::DT_FLOAT16, "DT_FLOAT16"}, // fp16 type + {ge::DT_INT8, "DT_INT8"}, // int8 type + {ge::DT_INT16, "DT_INT16"}, // int16 type + {ge::DT_UINT16, "DT_UINT16"}, // uint16 type + {ge::DT_UINT8, "DT_UINT8"}, // uint8 type + {ge::DT_INT32, "DT_INT32"}, // uint32 type + {ge::DT_INT64, "DT_INT64"}, // int64 type + {ge::DT_UINT32, "DT_UINT32"}, // unsigned int32 + {ge::DT_UINT64, "DT_UINT64"}, // unsigned int64 + {ge::DT_BOOL, "DT_BOOL"}, // bool type + {ge::DT_DOUBLE, "DT_DOUBLE"}, // double type + {ge::DT_DUAL, "DT_DUAL"}, // dual output type + {ge::DT_DUAL_SUB_INT8, "DT_DUAL_SUB_INT8"}, // dual output int8 type + {ge::DT_DUAL_SUB_UINT8, "DT_DUAL_SUB_UINT8"}, // dual output uint8 type + {ge::DT_COMPLEX32, "DT_COMPLEX32"}, // complex32 type + {ge::DT_COMPLEX64, "DT_COMPLEX64"}, // complex64 type + {ge::DT_COMPLEX128, "DT_COMPLEX128"}, // complex128 type + {ge::DT_QINT8, "DT_QINT8"}, // qint8 type + {ge::DT_QINT16, "DT_QINT16"}, // qint16 type + {ge::DT_QINT32, "DT_QINT32"}, // qint32 type + {ge::DT_QUINT8, "DT_QUINT8"}, // quint8 type + {ge::DT_QUINT16, "DT_QUINT16"}, // quint16 type + {ge::DT_RESOURCE, "DT_RESOURCE"}, // resource type + {ge::DT_STRING_REF, "DT_STRING_REF"}, // string ref type + {ge::DT_STRING, "DT_STRING"}, // string type + {ge::DT_VARIANT, "DT_VARIANT"}, // dt_variant type + {ge::DT_BF16, "DT_BFLOAT16"}, // dt_bfloat16 type + {ge::DT_INT4, "DT_INT4"}, // dt_variant type + {ge::DT_UINT1, "DT_UINT1"}, // dt_variant type + {ge::DT_INT2, "DT_INT2"}, // dt_variant type + {ge::DT_UINT2, "DT_UINT2"} // dt_variant type +}; + +struct SparseFlashAttentionCompileInfo { + int64_t core_num; +}; + +static const std::map> SFA_LAYOUT_AXIS_MAP = { + {SFALayout::BSND, {SFAAxis::B, SFAAxis::S, SFAAxis::N, SFAAxis::D}}, + {SFALayout::TND, {SFAAxis::T, SFAAxis::N, SFAAxis::D}}, + {SFALayout::PA_BSND, 
{SFAAxis::Bn, SFAAxis::Bs, SFAAxis::N, SFAAxis::D}}, +}; + +static const std::map SFA_LAYOUT_DIM_MAP = { + {SFALayout::BSND, DIM_NUM_FOUR}, + {SFALayout::TND, DIM_NUM_THREE}, + {SFALayout::PA_BSND, DIM_NUM_FOUR}, +}; + +static std::string GetShapeStr(gert::Shape shape) +{ + std::ostringstream oss; + oss << "["; + if (shape.GetDimNum() > 0) { + for (size_t i = 0; i < shape.GetDimNum() - 1; ++i) { + oss << shape.GetDim(i) << ", "; + } + oss << shape.GetDim(shape.GetDimNum() - 1); + } + oss << "]"; + return oss.str(); +} + +static std::string SFADataTypeToSerialString(ge::DataType type) +{ + const auto it = DATATYPE_TO_STRING_MAP.find(type); + if (it != DATATYPE_TO_STRING_MAP.end()) { + return it->second; + } else { + OPS_LOG_E("SparseFlashAttention", "datatype %d not support", type); + return "UNDEFINED"; + } +} + +string SFATensorDesc2String(const gert::StorageShape *shape, const gert::CompileTimeTensorDesc *tensor) +{ + if (shape == nullptr || tensor == nullptr) { + return "nil "; + } + + std::ostringstream oss; + oss << "(dtype: " << ge::TypeUtils::DataTypeToAscendString(tensor->GetDataType()).GetString() << "),"; + oss << "(shape:" << SFAShape2String(shape->GetStorageShape()) << "),"; + oss << "(ori_shape:" << SFAShape2String(shape->GetOriginShape()) << "),"; + oss << "(format: " + << ge::TypeUtils::FormatToAscendString( + static_cast(ge::GetPrimaryFormat(tensor->GetStorageFormat()))) + .GetString() + << "),"; + oss << "(ori_format: " << ge::TypeUtils::FormatToAscendString(tensor->GetOriginFormat()).GetString() << ") "; + + return oss.str(); +} + +string SFADebugTilingContext(const gert::TilingContext *context) +{ + std::ostringstream oss; + for (size_t i = 0; i < context->GetComputeNodeInfo()->GetInputsNum(); ++i) { + oss << "input" << i << ": "; + oss << SFATensorDesc2String(context->GetInputShape(i), context->GetInputDesc(i)); + } + + for (size_t i = 0; i < context->GetComputeNodeInfo()->GetOutputsNum(); ++i) { + oss << "output" << i << ": "; + oss << 
SFATensorDesc2String(context->GetOutputShape(i), context->GetOutputDesc(i)); + } + return oss.str(); +} + +std::string SFALayoutToSerialString(SFALayout layout) +{ + switch (layout) { + case SFALayout::BSND: return "BSND"; + case SFALayout::TND: return "TND"; + case SFALayout::PA_BSND: return "PA_BSND"; + default: return "UNKNOWN"; + } +} + +ge::graphStatus SFAMlaTiling::SetBlockDim(uint32_t blockDim) +{ + context_->SetBlockDim(blockDim); + return ge::GRAPH_SUCCESS; +} + +ge::graphStatus SFAMlaTiling::SetTilingKey(uint64_t tilingKey) +{ + context_->SetTilingKey(tilingKey); + return ge::GRAPH_SUCCESS; +} + +ge::graphStatus SFAMlaTiling::SetWorkspaceSize(uint64_t workspaceSize) +{ + OPS_ERR_IF(context_->GetWorkspaceSizes(1) == nullptr, + OPS_REPORT_VECTOR_INNER_ERR(context_->GetNodeName(), "workSpaceSize got from ge is nullptr"), + return ge::GRAPH_FAILED); + size_t *workSpaces = context_->GetWorkspaceSizes(1); + workSpaces[0] = workspaceSize; + return ge::GRAPH_SUCCESS; +} + +ge::graphStatus SFAMlaTiling::SetTilingData(TilingDef &tilingData) +{ + OPS_ERR_IF(context_->GetRawTilingData() == nullptr, + OPS_REPORT_VECTOR_INNER_ERR(context_->GetNodeName(), "RawTilingData got from GE context is nullptr."), + return ge::GRAPH_FAILED); + + tilingData.SaveToBuffer(context_->GetRawTilingData()->GetData(), context_->GetRawTilingData()->GetCapacity()); + context_->GetRawTilingData()->SetDataSize(tilingData.GetDataSize()); + + return ge::GRAPH_SUCCESS; +} + +ge::graphStatus SFAMlaTiling::GetPlatformInfo() +{ + OPS_ERR_IF(sfaInfo_->platformInfo == nullptr, + OPS_REPORT_VECTOR_INNER_ERR(sfaInfo_->opName, "GetPlatformInfo is nullptr."), return ge::GRAPH_FAILED); + + auto ascendcPlatform = platform_ascendc::PlatformAscendC(sfaInfo_->platformInfo); + libapiSize_ = ascendcPlatform.GetLibApiWorkSpaceSize(); + aivNum_ = ascendcPlatform.GetCoreNumAiv(); + aicNum_ = ascendcPlatform.GetCoreNumAic(); + + OPS_ERR_IF(aicNum_ == 0 || aivNum_ == 0, + 
OPS_REPORT_VECTOR_INNER_ERR(sfaInfo_->opName, "num of core obtained is 0."), return GRAPH_FAILED); + + return ge::GRAPH_SUCCESS; +} + +void SFAMlaTiling::GenTilingKey() +{ + uint32_t inputQType = static_cast(sfaInfo_->inputQType); + uint32_t inputKvType = static_cast(sfaInfo_->inputKvType); + uint32_t outputType = static_cast(sfaInfo_->outputType); + uint32_t layoutQuery = static_cast(sfaInfo_->qLayout); + uint32_t layoutKV = static_cast(sfaInfo_->kvLayout); + + tilingKey_ = GET_TPL_TILING_KEY(0U, layoutQuery, layoutKV, perfMode_ == SFAPerfMode::V_TEMPLATE_MODE); + + OPS_LOG_I(sfaInfo_->opName, "SFA tilingKey_: %lu.", tilingKey_); +} + +void SFAMlaTiling::ZeroTensorProcess() +{ + if (sfaInfo_->s2Size == 0) { + sfaInfo_->s2Size = 1024; + } +} + +void SFAMlaTiling::InitParams() +{ + if (sfaInfo_->s2Size != 0 && sfaInfo_->sparseBlockSize <= 4) { + perfMode_ = SFAPerfMode::V_TEMPLATE_MODE; + } else { + perfMode_ = SFAPerfMode::C_TEMPLATE_MODE; + } + + coreNum_ = aicNum_; + + headDimAlign_ = Align(sfaInfo_->qkHeadDim, BYTE_BLOCK); + ZeroTensorProcess(); +} + +void SFAMlaTiling::CalcUbBmm() +{ + uint32_t cubeMSize = sfaInfo_->gSize * sfaInfo_->s1Size; + uint32_t maxMSize = mBaseSize_; + if (cubeMSize > maxMSize) { + cubeMSize = maxMSize; + } + mmResUbSize_ = sInnerSizeAlign_ * Align(cubeMSize, 16U); + bmm2ResUbSize_ = headDimAlign_ * Align(cubeMSize, 16U); + + qPreSizeMla_ = sfaInfo_->gSize * (headDimAlign_ + 64U) * sfaInfo_->s1Size; +} + +void SFAMlaTiling::CheckUbSpace() +{ + CalcUbBmm(); +} + +void SFAMlaTiling::CalcInnerSize(uint32_t s2Size) +{ + sInnerSize_ = 512; + if (splitKVFlag_ && sfaInfo_->qLayout != SFALayout::TND) { + if (s2Size == 256) { + sInnerSize_ = 128; + } else if (s2Size > 256 && s2Size <= sInnerSize_) { + sInnerSize_ = (sInnerSize_ + 1) / 2; + } + } + + sInnerLoopTimes_ = (s2Size + sInnerSize_ - 1) / sInnerSize_; + sInnerSizeTail_ = s2Size - (sInnerLoopTimes_ - 1) * sInnerSize_; + if (sInnerSize_ > s2Size) { + sInnerSize_ = s2Size; + } + 
sInnerSizeAlign_ = Align(sInnerSize_, BYTE_BLOCK); + + CheckUbSpace(); +} + +void SFAMlaTiling::SplitBalanced() +{ + CalcInnerSize(sfaInfo_->s2Size); + + InnerSplitParams innerSplitParams; + innerSplitParams.s1GBaseSize = sfaInfo_->gSize; + innerSplitParams.s2BaseSize = sInnerSize_; + tilingData_.innerSplitParams.set_mBaseSize(innerSplitParams.s1GBaseSize); + tilingData_.innerSplitParams.set_s2BaseSize(innerSplitParams.s2BaseSize); + + usedCoreNum_ = aicNum_; +} + +void SFAMlaTiling::Split() +{ + SplitBalanced(); +} + +void SFAMlaTiling::FillTilingBaseParamsMla() +{ + tilingData_.baseParams.set_batchSize(sfaInfo_->bSize); + tilingData_.baseParams.set_seqSize(sfaInfo_->s2Size); + tilingData_.baseParams.set_qSeqSize(sfaInfo_->s1Size); + tilingData_.baseParams.set_blockSize(sfaInfo_->blockSize); + tilingData_.baseParams.set_maxBlockNumPerBatch(sfaInfo_->maxBlockNumPerBatch); + tilingData_.baseParams.set_scaleValue(sfaInfo_->scaleValue); + tilingData_.baseParams.set_nNumOfQInOneGroup(sfaInfo_->n1Size / sfaInfo_->n2Size); + tilingData_.baseParams.set_actualLenDimsQ(sfaInfo_->actualLenDimsQ); + tilingData_.baseParams.set_actualLenDimsKV(sfaInfo_->actualLenDimsKV); + tilingData_.baseParams.set_outputLayout(static_cast(sfaInfo_->outLayout)); + tilingData_.baseParams.set_sparseMode(sfaInfo_->sparseMode); + tilingData_.baseParams.set_sparseBlockSize(sfaInfo_->sparseBlockSize); + tilingData_.baseParams.set_sparseBlockCount(sfaInfo_->sparseBlockCount); +} + +// for flash decode +void SFAMlaTiling::FillTilingSplitKVMla() +{ + tilingData_.splitKVParams.set_s2(kvSplitPart_); + + tilingData_.splitKVParams.set_accumOutSize(aicNum_ * 2 * sfaInfo_->n2Size * mBaseSize_ * headDimAlign_); + tilingData_.splitKVParams.set_logSumExpSize(2 * aicNum_ * 2 * sfaInfo_->n2Size * mBaseSize_ * + (BYTE_BLOCK / BLOCK_TABLE_ELEM_BYTE)); + + if (!splitKVFlag_) { + tilingData_.splitKVParams.set_s2(0); + } +} + +void SFAMlaTiling::FillTilingSingleCoreParamsMla() +{ + 
tilingData_.singleCoreParams.set_usedCoreNum(usedCoreNum_); +} + +void SFAMlaTiling::FillTilingSingleCoreTensorSizeMla() +{ + tilingData_.singleCoreTensorSize.set_mmResUbSize(mmResUbSize_); + tilingData_.singleCoreTensorSize.set_bmm2ResUbSize(bmm2ResUbSize_); +} + +void SFAMlaTiling::FillTiling() +{ + FillTilingBaseParamsMla(); + FillTilingSplitKVMla(); + FillTilingSingleCoreParamsMla(); + FillTilingSingleCoreTensorSizeMla(); +} + +uint32_t SFAMlaTiling::CalcBalanceFDParamNums(const uint32_t actCoreNum) +{ + return actCoreNum * 2 * sfaInfo_->n2Size * mBaseSize_; +} + +void SFAMlaTiling::NormalCalcFDWorkSpace(const uint32_t actCoreNum) +{ + if (splitKVFlag_) { + uint32_t accumOutSize = 0; + uint32_t logSumExpSize = 0; + uint32_t FDParamNums = CalcBalanceFDParamNums(actCoreNum); + accumOutSize = FDParamNums * headDimAlign_; + logSumExpSize = 2 * FDParamNums * (BYTE_BLOCK / sfaInfo_->blockTypeSize); + workspaceSize_ += (accumOutSize + logSumExpSize) * sfaInfo_->blockTypeSize; + if (sfaInfo_->socVersion == platform_ascendc::SocVersion::ASCEND310P) { + workspaceSize_ += static_cast(actCoreNum) * 32; + } + } +} + +void SFAMlaTiling::CalcFDWorkSpace(const uint32_t actCoreNum) +{ + NormalCalcFDWorkSpace(actCoreNum); +} + +void SFAMlaTiling::GetWorkspaceSize() +{ + uint32_t mmResElemSize = 4; + uint32_t vec1ResElemSize = 2; + uint32_t bmm2ResElemSize = 4; + uint32_t qPreProcResElemSize = 0; + uint32_t nUpdateElemSize = 4; + uint32_t softmaxSumElemSize = 4; + float kvDtypeRatio = 1.0; + + workspaceSize_ = libapiSize_; + uint32_t preLoadNum = 1; + uint32_t actCoreNum = coreNum_; + preLoadNum = PRE_LOAD_NUM; + + workspaceSize_ += preLoadNum * (mmResUbSize_ * actCoreNum * mmResElemSize); + workspaceSize_ += preLoadNum * static_cast(static_cast(mmResUbSize_ * actCoreNum * vec1ResElemSize) * kvDtypeRatio); + workspaceSize_ += preLoadNum * bmm2ResUbSize_ * actCoreNum * bmm2ResElemSize; + workspaceSize_ += preLoadNum * static_cast(static_cast(qPreSizeMla_ * actCoreNum * 
qPreProcResElemSize) * kvDtypeRatio); + workspaceSize_ += preLoadNum * mBaseSize_ * actCoreNum * nUpdateElemSize; + workspaceSize_ += preLoadNum * mBaseSize_ * actCoreNum * softmaxSumElemSize; + workspaceSize_ += 4 * 512 * (512 + 64) * 2 * actCoreNum; + workspaceSize_ += 4 * 128 * 4 * (2 * actCoreNum); + + CalcFDWorkSpace(actCoreNum); +} + +void SFAMlaTiling::CalcBlockDim() +{ + auto ascendcPlatform = platform_ascendc::PlatformAscendC(sfaInfo_->platformInfo); + auto aicNum = usedCoreNum_; + auto aivNum = 2 * usedCoreNum_; + + blockDim_ = ascendcPlatform.CalcTschBlockDim(aivNum, aicNum, aivNum); + OPS_LOG_I(sfaInfo_->opName, "SFA block dim: %u aiv Num: %u aic Num: %u.", blockDim_, aivNum, aicNum); +} + +ge::graphStatus SFAMlaTiling::DoOpTiling(SFATilingInfo *sfaInfo) +{ + sfaInfo_ = sfaInfo; + if (GetPlatformInfo() != ge::GRAPH_SUCCESS) { + return ge::GRAPH_FAILED; + } + InitParams(); + Split(); + FillTiling(); + CalcBlockDim(); + GetWorkspaceSize(); + GenTilingKey(); + + if ((SetBlockDim(blockDim_) != ge::GRAPH_SUCCESS) || + (SetTilingKey(tilingKey_) != ge::GRAPH_SUCCESS) || + (SetWorkspaceSize(workspaceSize_) != ge::GRAPH_SUCCESS) || + (SetTilingData(tilingData_) != ge::GRAPH_SUCCESS)) { + return ge::GRAPH_FAILED; + } + + return ge::GRAPH_SUCCESS; +} + +ge::graphStatus TilingSparseFlashAttention(gert::TilingContext *context) +{ + SFATilingInfo sfaInfo; + SFAInfoParser sfaInfoParser(context); + if (sfaInfoParser.Parse(sfaInfo) != ge::GRAPH_SUCCESS) { + return ge::GRAPH_FAILED; + } + + SFATilingCheck tilingChecker(sfaInfo); + if (tilingChecker.Process() != ge::GRAPH_SUCCESS) { + return ge::GRAPH_FAILED; + } + + SFAMlaTiling tiling(context); + return tiling.DoOpTiling(&sfaInfo); +} + +ge::graphStatus TilingPrepareForSparseFlashAttention(gert::TilingParseContext *context) +{ + (void)context; + return ge::GRAPH_SUCCESS; +} + +ge::graphStatus SFATilingCheck::GetExpectedShape(gert::Shape &shapeExpected, + const SFATilingShapeCompareParam ¶m, const SFALayout &layout) 
const +{ + if (layout == SFALayout::BSND) { + shapeExpected = gert::Shape({param.B, param.S, param.N, param.D}); + } else if (layout == SFALayout::TND) { + shapeExpected = gert::Shape({param.T, param.N, param.D}); + } else if (layout == SFALayout::PA_BSND) { + shapeExpected = gert::Shape({param.Bn, param.Bs, param.N, param.D}); + } else { + OPS_LOG_E(opName_, "layout %s is unsupported", SFALayoutToSerialString(layout).c_str()); + return ge::GRAPH_FAILED; + } + if (shapeExpected.GetDim(0) == 0) { + OPS_LOG_E(opName_, "expected shape is %s, the first dim should not be 0.", GetShapeStr(shapeExpected).c_str()); + return ge::GRAPH_PARAM_INVALID; + } + return ge::GRAPH_SUCCESS; +} + +ge::graphStatus SFATilingCheck::CompareShape(SFATilingShapeCompareParam ¶m, + const gert::Shape &shape, const SFALayout &layout, const std::string &name) const +{ + gert::Shape shapeExpected; + if (GetExpectedShape(shapeExpected, param, layout) != ge::GRAPH_SUCCESS) { + return ge::GRAPH_FAILED; + } + + if (shape.GetDimNum() != shapeExpected.GetDimNum()) { + OPS_LOG_E(opName_, + "%s dimension is %zu, expected dimension is %zu.", + name.c_str(), shape.GetDimNum(), shapeExpected.GetDimNum()); + return ge::GRAPH_FAILED; + } + + for (size_t i = 0; i < shape.GetDimNum(); i++) { + if (shape.GetDim(i) != shapeExpected.GetDim(i)) { + OPS_LOG_E(opName_, "%s layout is %s, shape is %s, expected shape is %s.", + name.c_str(), SFALayoutToSerialString(layout).c_str(), + GetShapeStr(shape).c_str(), GetShapeStr(shapeExpected).c_str()); + return ge::GRAPH_FAILED; + } + } + + return ge::GRAPH_SUCCESS; +} + +void SFATilingCheck::LogErrorDtypeSupport(const std::vector &expectDtypeList, + const ge::DataType &actualDtype, const std::string &name) const +{ + std::ostringstream oss; + for (size_t i = 0; i < expectDtypeList.size(); ++i) { + oss << SFADataTypeToSerialString(expectDtypeList[i]); + if (i < expectDtypeList.size() - 1) { + oss << ", "; + } + } + OPS_LOG_E(opName_, "Tensor %s only supports dtype %s, but 
got %s", + name.c_str(), oss.str().c_str(), SFADataTypeToSerialString(actualDtype).c_str()); +} + +ge::graphStatus SFATilingCheck::CheckDtypeSupport(const gert::CompileTimeTensorDesc *desc, + const std::string &name) const +{ + if (desc != nullptr) { + const auto& it = DTYPE_SUPPORT_MAP.find(name); + OPS_ERR_IF(it == DTYPE_SUPPORT_MAP.end(), + OPS_LOG_E(opName_, "%s datatype support list should be specify in DTYPE_SUPPORT_MAP", name.c_str()), + return ge::GRAPH_FAILED); + auto &expectDtypeList = it->second; + OPS_ERR_IF(std::find( + expectDtypeList.begin(), expectDtypeList.end(), desc->GetDataType()) == expectDtypeList.end(), + LogErrorDtypeSupport(expectDtypeList, desc->GetDataType(), name), + return ge::GRAPH_FAILED); + } + return ge::GRAPH_SUCCESS; +} + +template +void SFATilingCheck::LogErrorNumberSupport(const std::vector &expectNumberList, + const T &actualValue, const std::string &name, const std::string subName) const +{ + std::ostringstream oss; + for (size_t i = 0; i < expectNumberList.size(); ++i) { + oss << std::to_string(expectNumberList[i]); + if (i < expectNumberList.size() - 1) { + oss << ", "; + } + } + + OPS_LOG_E(opName_, "%s %s only supports %s, but got %s", + name.c_str(), subName.c_str(), oss.str().c_str(), std::to_string(actualValue).c_str()); +} + +template +void SFATilingCheck::LogErrorDimNumSupport(const std::vector &expectNumberList, + const T &actualValue, const std::string &name) const +{ + LogErrorNumberSupport(expectNumberList, actualValue, name, "dimension"); +} + +ge::graphStatus SFATilingCheck::CheckDimNumInLayoutSupport(const SFALayout &layout, + const gert::StorageShape *shape, const std::string &name) const +{ + const auto& dimIt = SFA_LAYOUT_DIM_MAP.find(layout); + OPS_ERR_IF(shape->GetStorageShape().GetDimNum() != dimIt->second, + OPS_LOG_E(opName_, "When layout is %s, %s dimension should be %zu, but it's %zu", + SFALayoutToSerialString(layout).c_str(), name.c_str(), dimIt->second, + shape->GetStorageShape().GetDimNum()), + 
return ge::GRAPH_FAILED); + return ge::GRAPH_SUCCESS; +} + +ge::graphStatus SFATilingCheck::CheckDimNumSupport(const gert::StorageShape *shape, + const std::vector &expectDimNumList, const std::string &name) const +{ + if (shape == nullptr) { + return ge::GRAPH_SUCCESS; + } + + if (std::find(expectDimNumList.begin(), expectDimNumList.end(), + shape->GetStorageShape().GetDimNum()) == expectDimNumList.end()) { + LogErrorDimNumSupport(expectDimNumList, shape->GetStorageShape().GetDimNum(), name); + return ge::GRAPH_FAILED; + } + + return ge::GRAPH_SUCCESS; +} + + +void SFATilingCheck::LogErrorLayoutSupport(const std::vector &expectLayoutList, + const SFALayout &actualLayout, const std::string &name) const +{ + std::ostringstream oss; + for (size_t i = 0; i < expectLayoutList.size(); ++i) { + oss << SFALayoutToSerialString(expectLayoutList[i]); + if (i < expectLayoutList.size() - 1) { + oss << ", "; + } + } + OPS_LOG_E(opName_, "Tensor %s only supports layout %s, but got %s", + name.c_str(), oss.str().c_str(), SFALayoutToSerialString(actualLayout).c_str()); +} + +ge::graphStatus SFATilingCheck::CheckLayoutSupport(const SFALayout &actualLayout, const std::string &name) const +{ + const auto& it = LAYOUT_SUPPORT_MAP.find(name); + OPS_ERR_IF(it == LAYOUT_SUPPORT_MAP.end(), + OPS_LOG_E(opName_, "%s layout support list should be specify in LAYOUT_SUPPORT_MAP", name.c_str()), + return ge::GRAPH_FAILED); + auto &expectLayoutList = it->second; + OPS_ERR_IF(std::find( + expectLayoutList.begin(), expectLayoutList.end(), actualLayout) == expectLayoutList.end(), + LogErrorLayoutSupport(expectLayoutList, actualLayout, name), + return ge::GRAPH_FAILED); + + return ge::GRAPH_SUCCESS; +} + +ge::graphStatus SFATilingCheck::CheckSingleParaQuery() const +{ + const std::vector queryDimNumList = {DIM_NUM_THREE, DIM_NUM_FOUR}; + if (ge::GRAPH_SUCCESS != CheckDtypeSupport(opParamInfo_.query.desc, QUERY_NAME) || + ge::GRAPH_SUCCESS != CheckLayoutSupport(qLayout_, QUERY_NAME) || + 
ge::GRAPH_SUCCESS != CheckDimNumSupport(opParamInfo_.query.shape, queryDimNumList, QUERY_NAME) || + ge::GRAPH_SUCCESS != CheckDimNumInLayoutSupport(qLayout_, opParamInfo_.query.shape, QUERY_NAME)) { + return ge::GRAPH_FAILED; + } + return ge::GRAPH_SUCCESS; +} + +ge::graphStatus SFATilingCheck::CheckSingleParaKey() const +{ + const std::vector keyDimNumList = {DIM_NUM_FOUR, DIM_NUM_THREE}; + if (ge::GRAPH_SUCCESS != CheckDtypeSupport(opParamInfo_.key.desc, KEY_NAME) || + ge::GRAPH_SUCCESS != CheckLayoutSupport(kvLayout_, KEY_NAME) || + ge::GRAPH_SUCCESS != CheckDimNumSupport(opParamInfo_.key.shape, keyDimNumList, KEY_NAME) || + ge::GRAPH_SUCCESS != CheckDimNumInLayoutSupport(kvLayout_, opParamInfo_.key.shape, KEY_NAME)) { + return ge::GRAPH_FAILED; + } + return ge::GRAPH_SUCCESS; +} + +ge::graphStatus SFATilingCheck::CheckSingleParaNumHeads() const +{ + return ge::GRAPH_SUCCESS; +} + +ge::graphStatus SFATilingCheck::CheckSingleParaKvHeadNums() const +{ + return ge::GRAPH_SUCCESS; +} + +ge::graphStatus SFATilingCheck::CheckSingleParaSparseMode() const +{ + OPS_ERR_IF((*opParamInfo_.sparseMode != 3 && *opParamInfo_.sparseMode != 0), + OPS_LOG_E(opName_, "sparseMode must == 0/3, but got: %ld.", *opParamInfo_.sparseMode), + return ge::GRAPH_FAILED); + return ge::GRAPH_SUCCESS; +} + +ge::graphStatus SFATilingCheck::CheckSingleParaSparseBlockSize() const +{ + OPS_ERR_IF((*opParamInfo_.sparseBlockSize <= 0), + OPS_LOG_E(opName_, "sparseBlockSize should be greater than 0, but got: %ld.", *opParamInfo_.sparseBlockSize), + return ge::GRAPH_FAILED); + return ge::GRAPH_SUCCESS; +} + +ge::graphStatus SFATilingCheck::CheckSingleParaSparseIndices() const +{ + if (ge::GRAPH_SUCCESS != CheckDtypeSupport(opParamInfo_.sparseIndices.desc, SPARSE_INDICES_NAME)) { + return ge::GRAPH_FAILED; + } + return ge::GRAPH_SUCCESS; +} + +ge::graphStatus SFATilingCheck::CheckSinglePara() const +{ + if (ge::GRAPH_SUCCESS != CheckSingleParaQuery() || + ge::GRAPH_SUCCESS != CheckSingleParaKey() || + 
ge::GRAPH_SUCCESS != CheckSingleParaSparseIndices() || + ge::GRAPH_SUCCESS != CheckSingleParaNumHeads() || + ge::GRAPH_SUCCESS != CheckSingleParaKvHeadNums() || + ge::GRAPH_SUCCESS != CheckSingleParaSparseMode() || + ge::GRAPH_SUCCESS != CheckSingleParaSparseBlockSize()) { + return ge::GRAPH_FAILED; + } + + return ge::GRAPH_SUCCESS; +} + +ge::graphStatus SFATilingCheck::CheckRopeExistence() +{ + OPS_ERR_IF((opParamInfo_.queryRope.tensor != nullptr && opParamInfo_.keyRope.tensor == nullptr), + OPS_LOG_E(opName_, "KeyRope is null, but queryRope exists, they should be both null or exist."), + return ge::GRAPH_FAILED); + OPS_ERR_IF((opParamInfo_.queryRope.tensor == nullptr && opParamInfo_.keyRope.tensor != nullptr), + OPS_LOG_E(opName_, "QueryRope is null, but keyRope exists, they should be both null or exist."), + return ge::GRAPH_FAILED); + OPS_ERR_IF(opParamInfo_.keyRope.desc == nullptr || opParamInfo_.queryRope.desc == nullptr, + OPS_LOG_E(opName_, "In Mla situation, desc of keyRope and queryRope should not be null"), + return ge::GRAPH_FAILED); + return ge::GRAPH_SUCCESS; +} + +ge::graphStatus SFATilingCheck::CheckExists(const void *pointer, const std::string &name) const +{ + OPS_ERR_IF(pointer == nullptr, + OPS_LOG_E(opName_, "%s should not be null", name.c_str()), + return ge::GRAPH_FAILED); + return ge::GRAPH_SUCCESS; +} + +ge::graphStatus SFATilingCheck::CheckNotExists(const void *pointer, const std::string &name) const +{ + OPS_ERR_IF(pointer != nullptr, + OPS_LOG_E(opName_, "%s should be null", name.c_str()), + return ge::GRAPH_FAILED); + return ge::GRAPH_SUCCESS; +} + +ge::graphStatus SFATilingCheck::CheckExistsByMap(const std::map ¶mMap) const +{ + for (const auto& kv : paramMap) { + if (CheckExists(kv.second, kv.first) != ge::GRAPH_SUCCESS) { + return ge::GRAPH_FAILED; + } + } + return ge::GRAPH_SUCCESS; +} + +ge::graphStatus SFATilingCheck::CheckNotExistsByMap(const std::map ¶mMap) const +{ + for (const auto& kv : paramMap) { + if 
(CheckNotExists(kv.second, kv.first) != ge::GRAPH_SUCCESS) { + return ge::GRAPH_FAILED; + } + } + return ge::GRAPH_SUCCESS; +} + +ge::graphStatus SFATilingCheck::CheckExistenceByMap(std::map &existMap, + std::map ¬ExistMap) const +{ + if (CheckExistsByMap(existMap) != ge::GRAPH_SUCCESS) { + return ge::GRAPH_FAILED; + } + if (CheckNotExistsByMap(notExistMap) != ge::GRAPH_SUCCESS) { + return ge::GRAPH_FAILED; + } + return ge::GRAPH_SUCCESS; +} + +template +ge::graphStatus SFATilingCheck::CheckAttrValueByMap(std::map> &attrMap) const +{ + for (auto const &kv : attrMap) { + const std::string &name = kv.first; + const std::pair &pointerValuePair = kv.second; + if (pointerValuePair.first == nullptr) { + OPS_LOG_E(opName_, "Attr %s should not be nullptr", name.c_str()); + return ge::GRAPH_FAILED; + } + + if (*(pointerValuePair.first) != pointerValuePair.second) { + std::ostringstream ossExpect; + ossExpect << std::to_string(pointerValuePair.second); + std::ostringstream ossActual; + ossActual << std::to_string(*(pointerValuePair.first)); + OPS_LOG_E(opName_, + "%s value should be %s, but got %s", + name.c_str(), + ossExpect.str().c_str(), + ossActual.str().c_str()); + return ge::GRAPH_FAILED; + } + } + return ge::GRAPH_SUCCESS; +} + +ge::graphStatus SFATilingCheck::CheckParaExistenceMlaNoquant() const +{ + if (kvStorageMode_ != KvStorageMode::PAGE_ATTENTION) { + return ge::GRAPH_SUCCESS; + } + std::map mlaNoquantParamExistMap = { + {"actualSeqLengths", opParamInfo_.actualSeqLengths.tensor}, + {"blockTable", opParamInfo_.blockTable.tensor}, + }; + std::map mlaNoquantParamNotExistMap = {}; + if (CheckExistenceByMap(mlaNoquantParamExistMap, mlaNoquantParamNotExistMap) != ge::GRAPH_SUCCESS) { + return ge::GRAPH_FAILED; + } + return ge::GRAPH_SUCCESS; +} + +ge::graphStatus SFATilingCheck::CheckParaExistenceMla() const +{ + return CheckParaExistenceMlaNoquant(); +} + +ge::graphStatus SFATilingCheck::CheckParaExistence() +{ + if (ge::GRAPH_SUCCESS != CheckRopeExistence()) { + 
return ge::GRAPH_FAILED; + } + + return CheckParaExistenceMla(); +} + +ge::graphStatus SFATilingCheck::GetActualSeqLenSize(uint32_t &size, const gert::Tensor *tensor, + const SFALayout &layoutQuery, const std::string &name) +{ + if (tensor == nullptr) { + OPS_LOG_E(opName_, "when layout of query is %s, %s must be provided.", + SFALayoutToSerialString(layoutQuery).c_str(), name.c_str()); + return ge::GRAPH_FAILED; + } + int64_t shapeSize = tensor->GetShapeSize(); + if (shapeSize <= 0) { + OPS_LOG_E(opName_, "the shape size of %s is %ld, it should be greater than 0.", + name.c_str(), shapeSize); + return ge::GRAPH_FAILED; + } + size = static_cast(shapeSize); + return ge::GRAPH_SUCCESS; +} + +void SFATilingCheck::SetSFAShapeCompare() +{ + queryShapeCmp_ = opParamInfo_.query.shape->GetStorageShape(); + topkShapeCmp_ = opParamInfo_.sparseIndices.shape->GetStorageShape(); + keyShapeCmp_ = opParamInfo_.key.shape->GetStorageShape(); + valueShapeCmp_ = opParamInfo_.value.shape->GetStorageShape(); + attenOutShapeCmp_ = opParamInfo_.attenOut.shape->GetStorageShape(); + queryRopeShapeCmp_ = opParamInfo_.queryRope.tensor->GetStorageShape(); + keyRopeShapeCmp_ = opParamInfo_.keyRope.tensor->GetStorageShape(); +} + +ge::graphStatus SFATilingCheck::CheckBlockTable() const +{ + if (kvStorageMode_ != KvStorageMode::PAGE_ATTENTION) { + OPS_ERR_IF(opParamInfo_.blockTable.tensor != nullptr, + OPS_LOG_E(opName_, "when the layout_kv is %s, %s should be null", + SFALayoutToSerialString(kvLayout_).c_str(), BLOCK_TABLE_NAME.c_str()), + return ge::GRAPH_FAILED); + return ge::GRAPH_SUCCESS; + } + + uint32_t blockTableBatch = opParamInfo_.blockTable.tensor->GetStorageShape().GetDim(0); + OPS_ERR_IF(blockTableBatch != bSize_, + OPS_LOG_E(opName_, "%s's first dimension(%u) should be equal to batch size(%u)", + BLOCK_TABLE_NAME.c_str(), blockTableBatch, bSize_), + return ge::GRAPH_FAILED); + + return ge::GRAPH_SUCCESS; +} + +ge::graphStatus SFATilingCheck::CheckDTypeConsistency(const ge::DataType 
&actualDtype, + const ge::DataType &expectDtype, const std::string &name) const +{ + if (actualDtype != expectDtype) { + OPS_LOG_E(opName_, "%s dtype should be %s, but it's %s.", name.c_str(), + SFADataTypeToSerialString(expectDtype).c_str(), + SFADataTypeToSerialString(actualDtype).c_str()); + return ge::GRAPH_FAILED; + } + return ge::GRAPH_SUCCESS; +} + +ge::graphStatus SFATilingCheck::CheckQRopeShape() +{ + SFATilingShapeCompareParam shapeParams; + shapeParams.B = bSize_; + shapeParams.N = n1Size_; + shapeParams.S = s1Size_; + shapeParams.D = ropeHeadDim_; + shapeParams.T = qTSize_; + return CompareShape(shapeParams, queryRopeShapeCmp_, qLayout_, QUERY_ROPE_NAME); +} + +ge::graphStatus SFATilingCheck::CheckTopkShape() +{ + SFATilingShapeCompareParam shapeParams; + shapeParams.B = bSize_; + shapeParams.N = n2Size_; + shapeParams.S = s1Size_; + shapeParams.D = sparseBlockCount_; + shapeParams.T = qTSize_; + return CompareShape(shapeParams, topkShapeCmp_, topkLayout_, SPARSE_INDICES_NAME); +} + +ge::graphStatus SFATilingCheck::CheckAttenOutShape() +{ + SFATilingShapeCompareParam shapeParams; + shapeParams.B = bSize_; + shapeParams.N = n1Size_; + shapeParams.S = s1Size_; + shapeParams.D = vHeadDim_; + shapeParams.T = qTSize_; + if (CompareShape(shapeParams, attenOutShapeCmp_, outLayout_, ATTEN_OUT_NAME) != ge::GRAPH_SUCCESS) { + return ge::GRAPH_FAILED; + } + return ge::GRAPH_SUCCESS; +} + +ge::graphStatus SFATilingCheck::CheckAttenOut() +{ + if (ge::GRAPH_SUCCESS != CheckDTypeConsistency(opParamInfo_.attenOut.desc->GetDataType(), + inputQType_, ATTEN_OUT_NAME) || + ge::GRAPH_SUCCESS != CheckAttenOutShape()) { + return ge::GRAPH_FAILED; + } + return ge::GRAPH_SUCCESS; +} + +ge::graphStatus SFATilingCheck::CheckQRope() +{ + if (ge::GRAPH_SUCCESS != CheckDTypeConsistency(opParamInfo_.queryRope.desc->GetDataType(), + inputQType_, QUERY_ROPE_NAME) || + ge::GRAPH_SUCCESS != CheckQRopeShape()) { + return ge::GRAPH_FAILED; + } + return ge::GRAPH_SUCCESS; +} + 
+ge::graphStatus SFATilingCheck::CheckTopK() +{ + if (ge::GRAPH_SUCCESS != CheckTopkShape()) { + return ge::GRAPH_FAILED; + } + return ge::GRAPH_SUCCESS; +} + +ge::graphStatus SFATilingCheck::CheckVAndKRopeShapeForBatchContinuous() +{ + SFATilingShapeCompareParam shapeParams; + shapeParams.B = bSize_; + shapeParams.N = n2Size_; + shapeParams.S = s2Size_; + shapeParams.T = kvTSize_; + shapeParams.D = qkHeadDim_; + if (CompareShape(shapeParams, keyShapeCmp_, kvLayout_, KEY_NAME) != ge::GRAPH_SUCCESS) { + return ge::GRAPH_FAILED; + } + + shapeParams.D = vHeadDim_; + if (CompareShape(shapeParams, valueShapeCmp_, kvLayout_, VALUE_NAME) != ge::GRAPH_SUCCESS) { + return ge::GRAPH_FAILED; + } + + shapeParams.D = ropeHeadDim_; + if (CompareShape(shapeParams, keyRopeShapeCmp_, kvLayout_, KEY_ROPE_NAME) != ge::GRAPH_SUCCESS) { + return ge::GRAPH_FAILED; + } + return ge::GRAPH_SUCCESS; +} + +uint32_t SFATilingCheck::GetTypeSize(ge::DataType dtype) const +{ + uint32_t typeSize = NUM_BYTES_FLOAT16; + switch (dtype) { + case ge::DT_FLOAT16: + typeSize = NUM_BYTES_FLOAT16; + break; + case ge::DT_BF16: + typeSize = NUM_BYTES_BF16; + break; + default: + typeSize = NUM_BYTES_FLOAT16; + } + return typeSize; +} + +ge::graphStatus SFATilingCheck::CheckVAndKRopeShapeForPageAttention() +{ + int64_t blockNum = keyShapeCmp_.GetDim(0); + OPS_ERR_IF(blockNum <= 0, + OPS_LOG_E(opName_, "The first dim(%ld) of key should be greater than 0", blockNum), + return ge::GRAPH_FAILED); + SFATilingShapeCompareParam shapeParams; + shapeParams.Bn = blockNum; + shapeParams.N = n2Size_; + shapeParams.Bs = blockSize_; + shapeParams.D = vHeadDim_; + shapeParams.T = kvTSize_; + if (CompareShape(shapeParams, valueShapeCmp_, kvLayout_, VALUE_NAME) != ge::GRAPH_SUCCESS) { + return ge::GRAPH_FAILED; + } + + shapeParams.D = ropeHeadDim_; + if (CompareShape(shapeParams, keyRopeShapeCmp_, kvLayout_, KEY_ROPE_NAME) != ge::GRAPH_SUCCESS) { + return ge::GRAPH_FAILED; + } + + return ge::GRAPH_SUCCESS; +} + 
+ge::graphStatus SFATilingCheck::CheckVAndKRopeShape() +{ + if (kvStorageMode_ == KvStorageMode::BATCH_CONTINUOUS) { + return CheckVAndKRopeShapeForBatchContinuous(); + } + + if (kvStorageMode_ == KvStorageMode::PAGE_ATTENTION) { + return CheckVAndKRopeShapeForPageAttention(); + } + + OPS_LOG_E(opName_, "storage mode of key and value is %u, it is incorrect.", static_cast(kvStorageMode_)); + return ge::GRAPH_FAILED; +} + +ge::graphStatus SFATilingCheck::CheckVAndKRope() +{ + if (ge::GRAPH_SUCCESS != CheckDTypeConsistency(opParamInfo_.value.desc->GetDataType(), + inputKvType_, VALUE_NAME) || + ge::GRAPH_SUCCESS != CheckDTypeConsistency(opParamInfo_.keyRope.desc->GetDataType(), + inputKvType_, KEY_ROPE_NAME) || ge::GRAPH_SUCCESS != CheckVAndKRopeShape()) { + return ge::GRAPH_FAILED; + } + return ge::GRAPH_SUCCESS; +} + +ge::graphStatus SFATilingCheck::CheckActualSeqLensQ() +{ + if (ge::GRAPH_SUCCESS != CheckActualSeqLensQDType() || + ge::GRAPH_SUCCESS != CheckActualSeqLensQShape()) { + return ge::GRAPH_FAILED; + } + return ge::GRAPH_SUCCESS; +} + +ge::graphStatus SFATilingCheck::CheckActualSeqLensQDType() +{ + if (opParamInfo_.actualSeqLengthsQ.tensor == nullptr) { + return ge::GRAPH_SUCCESS; + } + if (opParamInfo_.actualSeqLengthsQ.desc == nullptr) { + OPS_LOG_E(opName_, "actualSeqLengthsQ is not empty," + "but actualSeqLengthsQ's dtype is nullptr."); + return ge::GRAPH_FAILED; + } + if (opParamInfo_.actualSeqLengthsQ.desc->GetDataType() != ge::DT_INT32) { + OPS_LOG_E(opName_, "actualSeqLengthsQ's dtype is %s, it should be DT_INT32.", + SFADataTypeToSerialString(opParamInfo_.actualSeqLengthsQ.desc->GetDataType()).c_str()); + return ge::GRAPH_FAILED; + } + return ge::GRAPH_SUCCESS; +} + +ge::graphStatus SFATilingCheck::CheckActualSeqLensQShape() +{ + if (opParamInfo_.actualSeqLengthsQ.tensor == nullptr) { + return ge::GRAPH_SUCCESS; + } + uint32_t shapeSize = 0; + if (GetActualSeqLenSize(shapeSize, opParamInfo_.actualSeqLengthsQ.tensor, qLayout_, "actualSeqLengthsQ") 
!= ge::GRAPH_SUCCESS) { + return ge::GRAPH_FAILED; + } + if (shapeSize != bSize_) { + OPS_LOG_E(opName_, "actualSeqLengthsQ shape size is %u, it should be equal to batch size[%u]", + shapeSize, bSize_); + return ge::GRAPH_FAILED; + } + return ge::GRAPH_SUCCESS; +} + +ge::graphStatus SFATilingCheck::CheckActualSeqLens() +{ + if (std::string(opParamInfo_.layoutKV) == "TND" && opParamInfo_.actualSeqLengths.tensor == nullptr) { + OPS_LOG_E(opName_, + "when the layout of key and value is TND, " + "the actualSeqLengths of key and value shoule not be empty."); + return ge::GRAPH_PARAM_INVALID; + } + if (ge::GRAPH_SUCCESS != CheckActualSeqLensDType() || + ge::GRAPH_SUCCESS != CheckActualSeqLensShape()) { + return ge::GRAPH_FAILED; + } + return ge::GRAPH_SUCCESS; +} + +ge::graphStatus SFATilingCheck::CheckActualSeqLensDType() +{ + if (opParamInfo_.actualSeqLengths.tensor == nullptr) { + return ge::GRAPH_SUCCESS; + } + if (opParamInfo_.actualSeqLengths.desc == nullptr) { + OPS_LOG_E(opName_, "actualSeqLengths is not empty," + "but actualSeqLengths's dtype is nullptr."); + return ge::GRAPH_FAILED; + } + if (opParamInfo_.actualSeqLengths.desc->GetDataType() != ge::DT_INT32) { + OPS_LOG_E(opName_, "actualSeqLengths's dtype is %s, it should be DT_INT32.", + SFADataTypeToSerialString(opParamInfo_.actualSeqLengths.desc->GetDataType()).c_str()); + return ge::GRAPH_FAILED; + } + return ge::GRAPH_SUCCESS; +} + +ge::graphStatus SFATilingCheck::CheckActualSeqLensShape() +{ + if (opParamInfo_.actualSeqLengths.tensor == nullptr) { + return ge::GRAPH_SUCCESS; + } + uint32_t shapeSize = 0; + if(GetActualSeqLenSize(shapeSize, opParamInfo_.actualSeqLengths.tensor, kvLayout_, "actualSeqLengths") != ge::GRAPH_SUCCESS) { + return ge::GRAPH_FAILED; + } + if (shapeSize != bSize_) { + OPS_LOG_E(opName_, "actualSeqLengths shape size is %u, it should be equal to batch size[%u].", + shapeSize, bSize_); + return ge::GRAPH_FAILED; + } + return ge::GRAPH_SUCCESS; +} + +ge::graphStatus 
SFATilingCheck::CheckMultiParaConsistency() +{ + SetSFAShapeCompare(); + if (ge::GRAPH_SUCCESS != CheckVAndKRope() || + ge::GRAPH_SUCCESS != CheckQRope() || + ge::GRAPH_SUCCESS != CheckTopK() || + ge::GRAPH_SUCCESS != CheckAttenOut() || + ge::GRAPH_SUCCESS != CheckActualSeqLensQ() || + ge::GRAPH_SUCCESS != CheckActualSeqLens() || + ge::GRAPH_SUCCESS != CheckBlockTable()) { + return ge::GRAPH_FAILED; + } + + return ge::GRAPH_SUCCESS; +} + +ge::graphStatus SFATilingCheck::CheckFeatureMlaNoQuantShape() const +{ + OPS_ERR_IF(bSize_ <= 0, + OPS_LOG_E(opName_, "batch_size should be greater than 0, but got %u", bSize_), + return ge::GRAPH_FAILED); + + OPS_ERR_IF(qTSize_ <= 0 && (qLayout_ == SFALayout::TND), + OPS_LOG_E(opName_, "T_size of query should be greater than 0, but got %u", qTSize_), + return ge::GRAPH_FAILED); + + OPS_ERR_IF(n1Size_ <= 0, + OPS_LOG_E(opName_, "q_head_num should be greater than 0, but got %u", n1Size_), + return ge::GRAPH_FAILED); + + OPS_ERR_IF(n2Size_ != 1, + OPS_LOG_E(opName_, "kv_head_num should be 1, but got %u", n2Size_), + return ge::GRAPH_FAILED); + + OPS_ERR_IF(n1Size_ % n2Size_ != 0, + OPS_LOG_E(opName_, "q_head_num(%u) must be divisible by kv_head_num(%u)", n1Size_, n2Size_), + return ge::GRAPH_FAILED); + + std::vector gSizeSupportList = {1, 2, 4, 8, 16, 32, 64, 128}; + OPS_ERR_IF(std::find(gSizeSupportList.begin(), gSizeSupportList.end(), gSize_) == gSizeSupportList.end(), + OPS_LOG_E(opName_, "group num should be in 1, 2, 4, 8, 16, 32, 64, 128, but got %u", gSize_), + return ge::GRAPH_FAILED); + + OPS_ERR_IF(qkHeadDim_ != 512, + OPS_LOG_E(opName_, "qk_head_dim only support 512, but got %u", qkHeadDim_), + return ge::GRAPH_FAILED); + + OPS_ERR_IF(qkHeadDim_ != vHeadDim_, + OPS_LOG_E(opName_, "qk_head_dim[%u] should be equal to v_head_dim[%u]", qkHeadDim_, vHeadDim_), + return ge::GRAPH_FAILED); + + OPS_ERR_IF(ropeHeadDim_ != 64, + OPS_LOG_E(opName_, "rope_head_dim should be 64, but got %u", ropeHeadDim_), + return ge::GRAPH_FAILED); + 
return ge::GRAPH_SUCCESS; +} + +ge::graphStatus SFATilingCheck::CheckFeatureMlaNoQuantLayout() const +{ + const std::vector layoutSupportList = { + "BSND", + "TND" + }; + std::string layoutQuery = opParamInfo_.layoutQuery; + OPS_ERR_IF(std::find(layoutSupportList.begin(), layoutSupportList.end(), layoutQuery) == layoutSupportList.end(), + OPS_LOG_E(opName_, "layoutQuery only supports BSND/TND, but got %s", layoutQuery.c_str()), + return ge::GRAPH_FAILED); + return ge::GRAPH_SUCCESS; +} + +ge::graphStatus SFATilingCheck::CheckFeatureMlaNoQuantDtype() const +{ + OPS_ERR_IF(inputQType_ != ge::DT_BF16 && inputQType_ != ge::DT_FLOAT16, + OPS_LOG_E(opName_, "query dtype only support %s and %s, but got %s", + SFADataTypeToSerialString(ge::DT_BF16).c_str(), SFADataTypeToSerialString(ge::DT_FLOAT16).c_str(), + SFADataTypeToSerialString(inputQType_).c_str()), + return ge::GRAPH_FAILED); + return ge::GRAPH_SUCCESS; +} + +ge::graphStatus SFATilingCheck::CheckFeatureMlaNoquantPa() const +{ + if (kvStorageMode_ != KvStorageMode::PAGE_ATTENTION) { + return ge::GRAPH_SUCCESS; + } + + OPS_ERR_IF(blockSize_ <= 0 || blockSize_ > static_cast(MAX_BLOCK_SIZE), + OPS_LOG_E(opName_, "when page attention is enabled, block_size(%d) should be in range (0, %u].", + blockSize_, MAX_BLOCK_SIZE), return ge::GRAPH_FAILED); + + OPS_ERR_IF(blockSize_ % 16 > 0, + OPS_LOG_E(opName_, "when page attention is enabled, block_size(%d) should be 16-aligned.", + blockSize_), return ge::GRAPH_FAILED); + + OPS_ERR_IF(blockSize_ % sparseBlockSize_ > 0, + OPS_LOG_E(opName_, "when page attention is enabled, block_size(%d) must be divided by sparse_block_size(%d), but now the remainder is %d.", + blockSize_, sparseBlockSize_, blockSize_ % sparseBlockSize_), return ge::GRAPH_FAILED); + + return ge::GRAPH_SUCCESS; +} + +ge::graphStatus SFATilingCheck::CheckFeatureMlaNoquant() const +{ + if (ge::GRAPH_SUCCESS != CheckFeatureMlaNoQuantShape() || + ge::GRAPH_SUCCESS != CheckFeatureMlaNoQuantLayout() || + 
ge::GRAPH_SUCCESS != CheckFeatureMlaNoQuantDtype() || + ge::GRAPH_SUCCESS != CheckFeatureMlaNoquantPa()) { + return ge::GRAPH_FAILED; + } + return ge::GRAPH_SUCCESS; +} + +ge::graphStatus SFATilingCheck::CheckFeatureMla() const +{ + return CheckFeatureMlaNoquant(); +} + +ge::graphStatus SFATilingCheck::CheckFeature() const +{ + return CheckFeatureMla(); +} + +void SFATilingCheck::Init() +{ + opName_ = sfaInfo_.opName; + platformInfo_ = sfaInfo_.platformInfo; + opParamInfo_ = sfaInfo_.opParamInfo; + socVersion_ = sfaInfo_.socVersion; + + bSize_ = sfaInfo_.bSize; + n1Size_ = sfaInfo_.n1Size; + n2Size_ = sfaInfo_.n2Size; + s1Size_ = sfaInfo_.s1Size; + s2Size_ = sfaInfo_.s2Size; + gSize_ = sfaInfo_.gSize; + qkHeadDim_ = sfaInfo_.qkHeadDim; + vHeadDim_ = sfaInfo_.vHeadDim; + ropeHeadDim_ = sfaInfo_.ropeHeadDim; + maxBlockNumPerBatch_ = sfaInfo_.maxBlockNumPerBatch; + qTSize_ = sfaInfo_.qTSize; + kvTSize_ = sfaInfo_.kvTSize; + blockSize_ = sfaInfo_.blockSize; + sparseBlockCount_ = sfaInfo_.sparseBlockCount; + sparseBlockSize_ = sfaInfo_.sparseBlockSize; + + inputQType_ = sfaInfo_.inputQType; + inputKvType_ = sfaInfo_.inputKvType; + inputQRopeType_ = sfaInfo_.inputQRopeType; + inputKRopeType_ = sfaInfo_.inputKRopeType; + outputType_ = sfaInfo_.outputType; + + qLayout_ = sfaInfo_.qLayout; + topkLayout_ = sfaInfo_.topkLayout; + kvLayout_ = sfaInfo_.kvLayout; + outLayout_ = sfaInfo_.outLayout; + + kvStorageMode_ = sfaInfo_.kvStorageMode; + l2CacheSize_ = sfaInfo_.l2CacheSize; +} + +ge::graphStatus SFATilingCheck::Process() +{ + Init(); + if (CheckSinglePara() != ge::GRAPH_SUCCESS || + CheckParaExistence() != ge::GRAPH_SUCCESS || + CheckFeature() != ge::GRAPH_SUCCESS || + CheckMultiParaConsistency() != ge::GRAPH_SUCCESS) { + return ge::GRAPH_FAILED; + } + return ge::GRAPH_SUCCESS; +} + +bool SFAInfoParser::HasAxis(const SFAAxis &axis, const SFALayout &layout, const gert::Shape &shape) const +{ + const auto& layoutIt = SFA_LAYOUT_AXIS_MAP.find(layout); + if (layoutIt == 
SFA_LAYOUT_AXIS_MAP.end()) { + return false; + } + + const std::vector& axes = layoutIt->second; + const auto& axisIt = std::find(axes.begin(), axes.end(), axis); + if (axisIt == axes.end()) { + return false; + } + const auto& dimIt = SFA_LAYOUT_DIM_MAP.find(layout); + if (dimIt == SFA_LAYOUT_DIM_MAP.end() || dimIt->second != shape.GetDimNum()) { + return false; + } + return true; +} + +size_t SFAInfoParser::GetAxisIdx(const SFAAxis &axis, const SFALayout &layout) const +{ + const std::vector& axes = SFA_LAYOUT_AXIS_MAP.find(layout)->second; + const auto& axisIt = std::find(axes.begin(), axes.end(), axis); + return std::distance(axes.begin(), axisIt); +} + +uint32_t SFAInfoParser::GetAxisNum(const gert::Shape &shape, const SFAAxis &axis,const SFALayout &layout) const +{ + return HasAxis(axis, layout, shape) ? shape.GetDim(GetAxisIdx(axis, layout)) : invalidDimValue_; +} + +ge::graphStatus SFAInfoParser::CheckRequiredInOutExistence() const +{ + OPS_ERR_IF(opParamInfo_.query.shape == nullptr, OPS_LOG_E(opName_, "Shape of tensor query is nullptr"), + return ge::GRAPH_FAILED); + OPS_ERR_IF(opParamInfo_.query.desc == nullptr, OPS_LOG_E(opName_, "Desc of tensor query is nullptr"), + return ge::GRAPH_FAILED); + OPS_ERR_IF(opParamInfo_.key.shape == nullptr, OPS_LOG_E(opName_, "Shape of tensor k is nullptr"), + return ge::GRAPH_FAILED); + OPS_ERR_IF(opParamInfo_.key.desc == nullptr, OPS_LOG_E(opName_, "Desc of tensor k is nullptr"), + return ge::GRAPH_FAILED); + OPS_ERR_IF(opParamInfo_.value.shape == nullptr, OPS_LOG_E(opName_, "Shape of tensor value is nullptr"), + return ge::GRAPH_FAILED); + OPS_ERR_IF(opParamInfo_.value.desc == nullptr, OPS_LOG_E(opName_, "Desc of tensor value is nullptr"), + return ge::GRAPH_FAILED); + OPS_ERR_IF(opParamInfo_.sparseIndices.shape == nullptr, OPS_LOG_E(opName_, "Shape of tensor sparseIndices is nullptr"), + return ge::GRAPH_FAILED); + OPS_ERR_IF(opParamInfo_.sparseIndices.desc == nullptr, OPS_LOG_E(opName_, "Desc of tensor sparseIndices 
is nullptr"), + return ge::GRAPH_FAILED); + OPS_ERR_IF(opParamInfo_.attenOut.shape == nullptr, OPS_LOG_E(opName_, "Shape of tensor output is nullptr"), + return ge::GRAPH_FAILED); + OPS_ERR_IF(opParamInfo_.attenOut.desc == nullptr, OPS_LOG_E(opName_, "Desc of tensor output is nullptr"), + return ge::GRAPH_FAILED); + OPS_ERR_IF(opParamInfo_.queryRope.tensor == nullptr, OPS_LOG_E(opName_, "Shape of queryRope is nullptr"), + return ge::GRAPH_FAILED); + OPS_ERR_IF(opParamInfo_.queryRope.desc == nullptr, OPS_LOG_E(opName_, "Desc of queryRope is nullptr"), + return ge::GRAPH_FAILED); + + return ge::GRAPH_SUCCESS; +} + +ge::graphStatus SFAInfoParser::CheckRequiredAttrExistence() const +{ + OPS_ERR_IF(opParamInfo_.layoutQuery == nullptr, OPS_LOG_E(opName_, "attr layoutQuery is nullptr"), + return ge::GRAPH_FAILED); + OPS_ERR_IF(opParamInfo_.layoutKV == nullptr, OPS_LOG_E(opName_, "attr layoutKV is nullptr"), + return ge::GRAPH_FAILED); + OPS_ERR_IF(opParamInfo_.sparseBlockSize == nullptr, OPS_LOG_E(opName_, "attr sparseBlockSize is nullptr"), + return ge::GRAPH_FAILED); + OPS_ERR_IF(opParamInfo_.scaleValue == nullptr, OPS_LOG_E(opName_, "attr scaleValue is nullptr"), + return ge::GRAPH_FAILED); + OPS_ERR_IF(opParamInfo_.sparseMode == nullptr, OPS_LOG_E(opName_, "attr sparseMode is nullptr"), + return ge::GRAPH_FAILED); + return ge::GRAPH_SUCCESS; +} + +ge::graphStatus SFAInfoParser::CheckRequiredParaExistence() const +{ + if (CheckRequiredInOutExistence() != ge::GRAPH_SUCCESS || + CheckRequiredAttrExistence() != ge::GRAPH_SUCCESS) { + return ge::GRAPH_FAILED; + } + + return ge::GRAPH_SUCCESS; +} + +ge::graphStatus SFAInfoParser::GetActualSeqLenSize(uint32_t &size, const gert::Tensor *tensor, + SFALayout &layout, const std::string &name) +{ + if ((tensor == nullptr)) { + OPS_LOG_E(opName_, "when layout of query is %s, %s must be provided.", + SFALayoutToSerialString(layout).c_str(), name.c_str()); + return ge::GRAPH_FAILED; + } + int64_t shapeSize = tensor->GetShapeSize(); 
+ if (shapeSize <= 0) { + OPS_LOG_E(opName_, "the shape size of %s is %ld, it should be greater than 0.", + name.c_str(), shapeSize); + return ge::GRAPH_FAILED; + } + size = static_cast(shapeSize); + return ge::GRAPH_SUCCESS; +} + +ge::graphStatus SFAInfoParser::GetActualSeqLenQSize(uint32_t &size) +{ + return GetActualSeqLenSize(size, opParamInfo_.actualSeqLengthsQ.tensor, qLayout_, "actualSeqLengthsQ"); +} + +ge::graphStatus SFAInfoParser::GetOpName() +{ + if (context_->GetNodeName() == nullptr) { + OPS_LOG_E("SparseFlashAttention", "opName got from TilingContext is nullptr"); + return ge::GRAPH_FAILED; + } + opName_ = context_->GetNodeName(); + return ge::GRAPH_SUCCESS; +} + +ge::graphStatus SFAInfoParser::GetNpuInfo() +{ + platformInfo_ = context_->GetPlatformInfo(); + OPS_ERR_IF(platformInfo_ == nullptr, + OPS_REPORT_VECTOR_INNER_ERR(opName_, "GetPlatformInfo is nullptr."), return ge::GRAPH_FAILED); + + auto ascendcPlatform = platform_ascendc::PlatformAscendC(platformInfo_); + uint32_t aivNum = ascendcPlatform.GetCoreNumAiv(); + uint32_t aicNum = ascendcPlatform.GetCoreNumAic(); + OPS_ERR_IF(aicNum == 0 || aivNum == 0, + OPS_REPORT_VECTOR_INNER_ERR(opName_, "num of core obtained is 0."), return GRAPH_FAILED); + + socVersion_ = ascendcPlatform.GetSocVersion(); + if (socVersion_ != platform_ascendc::SocVersion::ASCEND910B) { + OPS_REPORT_VECTOR_INNER_ERR(opName_, "SOC Version[%d] is not support.", (int32_t)socVersion_); + return GRAPH_FAILED; + } + + ascendcPlatform.GetCoreMemSize(platform_ascendc::CoreMemType::L2, l2CacheSize_); + + return ge::GRAPH_SUCCESS; +} + +void SFAInfoParser::GetOptionalInputParaInfo() +{ + opParamInfo_.blockTable.tensor = context_->GetOptionalInputTensor(BLOCK_TABLE_INPUT_INDEX); + opParamInfo_.actualSeqLengthsQ.tensor = context_->GetOptionalInputTensor(ACT_SEQ_LEN_Q_INPUT_INDEX); + opParamInfo_.actualSeqLengthsQ.desc = context_->GetOptionalInputDesc(ACT_SEQ_LEN_Q_INPUT_INDEX); + opParamInfo_.actualSeqLengths.tensor = 
context_->GetOptionalInputTensor(ACT_SEQ_LEN_KV_INPUT_INDEX); + opParamInfo_.actualSeqLengths.desc = context_->GetOptionalInputDesc(ACT_SEQ_LEN_KV_INPUT_INDEX); + opParamInfo_.queryRope.tensor = context_->GetOptionalInputTensor(QUERY_ROPE_INPUT_INDEX); + opParamInfo_.queryRope.desc = context_->GetOptionalInputDesc(QUERY_ROPE_INPUT_INDEX); + opParamInfo_.keyRope.tensor = context_->GetOptionalInputTensor(KEY_ROPE_INPUT_INDEX); + opParamInfo_.keyRope.desc = context_->GetOptionalInputDesc(KEY_ROPE_INPUT_INDEX); +} + +void SFAInfoParser::GetInputParaInfo() +{ + opParamInfo_.query.desc = context_->GetInputDesc(QUERY_INPUT_INDEX); + opParamInfo_.query.shape = context_->GetInputShape(QUERY_INPUT_INDEX); + opParamInfo_.key.desc = context_->GetInputDesc(KEY_INPUT_INDEX); + opParamInfo_.key.shape = context_->GetInputShape(KEY_INPUT_INDEX); + opParamInfo_.value.desc = context_->GetInputDesc(VALUE_INPUT_INDEX); + opParamInfo_.value.shape = context_->GetInputShape(VALUE_INPUT_INDEX); + opParamInfo_.sparseIndices.desc = context_->GetInputDesc(SPARSE_INDICES_INPUT_INDEX); + opParamInfo_.sparseIndices.shape = context_->GetInputShape(SPARSE_INDICES_INPUT_INDEX); + GetOptionalInputParaInfo(); +} + +void SFAInfoParser::GetOutputParaInfo() +{ + opParamInfo_.attenOut.desc = context_->GetOutputDesc(OUTPUT_INDEX); + opParamInfo_.attenOut.shape = context_->GetOutputShape(OUTPUT_INDEX); +} + +ge::graphStatus SFAInfoParser::GetAttrParaInfo() +{ + auto attrs = context_->GetAttrs(); + OPS_ERR_IF(attrs == nullptr, OPS_REPORT_VECTOR_INNER_ERR(context_->GetNodeName(), "attrs got from ge is nullptr"), + return ge::GRAPH_FAILED); + + opParamInfo_.layoutQuery = attrs->GetStr(LAYOUT_QUERY_ATTR_INDEX); + opParamInfo_.layoutKV = attrs->GetStr(LAYOUT_KV_ATTR_INDEX); + opParamInfo_.sparseBlockSize = attrs->GetAttrPointer(SPARSE_BLOCK_SIZE_ATTR_INDEX); + opParamInfo_.scaleValue = attrs->GetAttrPointer(SCALE_VALUE_ATTR_INDEX); + opParamInfo_.sparseMode = attrs->GetAttrPointer(SPARSE_MODE_ATTR_INDEX); + + 
return ge::GRAPH_SUCCESS; +} + +ge::graphStatus SFAInfoParser::GetOpParaInfo() +{ + GetInputParaInfo(); + GetOutputParaInfo(); + if (ge::GRAPH_SUCCESS != GetAttrParaInfo()) { + return ge::GRAPH_FAILED; + } + return ge::GRAPH_SUCCESS; +} + +ge::graphStatus SFAInfoParser::GetInOutDataType() +{ + inputQType_ = opParamInfo_.query.desc->GetDataType(); + inputKvType_ = opParamInfo_.key.desc->GetDataType(); + outputType_ = opParamInfo_.attenOut.desc->GetDataType(); + if (opParamInfo_.queryRope.desc != nullptr) { + inputQRopeType_ = opParamInfo_.queryRope.desc->GetDataType(); + } + if (opParamInfo_.keyRope.desc != nullptr) { + inputKRopeType_ = opParamInfo_.keyRope.desc->GetDataType(); + } + return ge::GRAPH_SUCCESS; +} + +ge::graphStatus SFAInfoParser::GetBatchSize() +{ + if (qLayout_ == SFALayout::TND) { + return GetActualSeqLenQSize(bSize_); + } else { // BSND + bSize_ = GetAxisNum(queryShape_, SFAAxis::B, qLayout_); + return ge::GRAPH_SUCCESS; + } +} + +ge::graphStatus SFAInfoParser::GetQTSize() +{ + qTSize_ = (qLayout_ == SFALayout::TND) ? GetAxisNum(queryShape_, SFAAxis::T, qLayout_) : 0; + return ge::GRAPH_SUCCESS; +} + +ge::graphStatus SFAInfoParser::GetKVTSize() +{ + kvTSize_ = (kvLayout_ == SFALayout::TND) ? 
GetAxisNum(keyShape_, SFAAxis::T, kvLayout_) : 0; + return ge::GRAPH_SUCCESS; +} + +ge::graphStatus SFAInfoParser::GetQkHeadDim() +{ + qkHeadDim_ = GetAxisNum(queryShape_, SFAAxis::D, qLayout_); + return ge::GRAPH_SUCCESS; +} + +ge::graphStatus SFAInfoParser::GetS1Size() +{ + if (qLayout_ == SFALayout::TND) { + s1Size_ = GetAxisNum(queryShape_, SFAAxis::T, qLayout_); + return ge::GRAPH_SUCCESS; + } else { // BSND + s1Size_ = GetAxisNum(queryShape_, SFAAxis::S, qLayout_); + } + return ge::GRAPH_SUCCESS; +} + +ge::graphStatus SFAInfoParser::GetKvStorageMode() +{ + if (kvLayout_ == SFALayout::PA_BSND) { + kvStorageMode_ = KvStorageMode::PAGE_ATTENTION; + } else { + kvStorageMode_ = KvStorageMode::BATCH_CONTINUOUS; + } + return ge::GRAPH_SUCCESS; +} + +ge::graphStatus SFAInfoParser::GetKvLayout() +{ + const map layoutKVMap = { + {"BSND", SFALayout::BSND}, + {"PA_BSND", SFALayout::PA_BSND}, + {"TND", SFALayout::TND} + }; + + std::string layout(opParamInfo_.layoutKV); + auto it = layoutKVMap.find(layout); + if (it != layoutKVMap.end()) { + kvLayout_ = it->second; + } else { + OPS_LOG_E(opName_, "layoutKV is %s, it is unsupported.", layout.c_str()); + return ge::GRAPH_FAILED; + } + if (kvLayout_ != SFALayout::PA_BSND && qLayout_ != kvLayout_) { + OPS_LOG_E(opName_, "When layoutKV is not PA_BSND, layoutKV must be the same as layoutQ."); + return ge::GRAPH_FAILED; + } + uint32_t keyDimNum = opParamInfo_.key.shape->GetStorageShape().GetDimNum(); + if (kvLayout_ == SFALayout::PA_BSND && keyDimNum != 4U) { + OPS_LOG_E(opName_, "When layoutKV is PA_BSND, kvDimNum must be 4, but now is %d.", keyDimNum); + return ge::GRAPH_FAILED; + } + return ge::GRAPH_SUCCESS; +} + +ge::graphStatus SFAInfoParser::GetS2SizeForBatchContinuous() +{ + if (kvLayout_ == SFALayout::BSND) { // BSND + s2Size_ = GetAxisNum(keyShape_, SFAAxis::S, kvLayout_); + } else if (kvLayout_ == SFALayout::TND) { + s2Size_ = GetAxisNum(keyShape_, SFAAxis::T, kvLayout_); + } + return ge::GRAPH_SUCCESS; +} + 
+ge::graphStatus SFAInfoParser::GetMaxBlockNumPerBatch() +{ + if (opParamInfo_.blockTable.tensor == nullptr) { + OPS_LOG_E(opName_, "the layout_kv is %s, blockTable must be provided.", SFALayoutToSerialString(kvLayout_).c_str()); + return ge::GRAPH_FAILED; + } + uint32_t dimNum = opParamInfo_.blockTable.tensor->GetStorageShape().GetDimNum(); + if (dimNum != DIM_NUM_TWO) { + OPS_LOG_E(opName_, "the dim num of block_table is %u, it should be %u.", dimNum, DIM_NUM_TWO); + return ge::GRAPH_FAILED; + } + if (opParamInfo_.blockTable.tensor->GetStorageShape().GetDim(1) <= 0) { + OPS_LOG_E(opName_, "%s's second dimension(%ld) should be greater than 0", + BLOCK_TABLE_NAME.c_str(), opParamInfo_.blockTable.tensor->GetStorageShape().GetDim(1)); + return ge::GRAPH_FAILED; + } + maxBlockNumPerBatch_ = opParamInfo_.blockTable.tensor->GetStorageShape().GetDim(1); + return ge::GRAPH_SUCCESS; +} + +ge::graphStatus SFAInfoParser::GetBlockSize() +{ + blockSize_ = GetAxisNum(keyShape_, SFAAxis::Bs, kvLayout_); + return ge::GRAPH_SUCCESS; +} + +ge::graphStatus SFAInfoParser::GetSparseBlockCount() +{ + sparseBlockCount_ = GetAxisNum(sparseIndicesShape_, SFAAxis::K, qLayout_); + + return ge::GRAPH_SUCCESS; +} + +ge::graphStatus SFAInfoParser::GetS2SizeForPageAttention() +{ + if (GetMaxBlockNumPerBatch() != ge::GRAPH_SUCCESS || GetBlockSize() != ge::GRAPH_SUCCESS) { + return ge::GRAPH_FAILED; + } + s2Size_ = maxBlockNumPerBatch_ * blockSize_; + return ge::GRAPH_SUCCESS; +} + +ge::graphStatus SFAInfoParser::GetS2Size() +{ + if (kvStorageMode_ == KvStorageMode::BATCH_CONTINUOUS) { + return GetS2SizeForBatchContinuous(); + } + return GetS2SizeForPageAttention(); +} + +ge::graphStatus SFAInfoParser::GetValueHeadDim() +{ + vHeadDim_ = GetAxisNum(valueShape_, SFAAxis::D, kvLayout_); + return ge::GRAPH_SUCCESS; +} + +ge::graphStatus SFAInfoParser::GetRopeHeadDim() +{ + ropeHeadDim_ = GetAxisNum(queryRopeShape_, SFAAxis::D, qLayout_); + return ge::GRAPH_SUCCESS; +} + +ge::graphStatus 
SFAInfoParser::GetQueryAndOutLayout() +{ + const map> layoutMap = { + {"BSND", {SFALayout::BSND, SFALayout::BSND}}, + {"TND", {SFALayout::TND, SFALayout::TND }}, + }; + + std::string layout(opParamInfo_.layoutQuery); + auto it = layoutMap.find(layout); + if (it != layoutMap.end()) { + qLayout_ = it->second.first; + outLayout_ = it->second.second; + } else { + OPS_LOG_E(opName_, "layoutQuery is %s, it is unsupported.", layout.c_str()); + return ge::GRAPH_FAILED; + } + return ge::GRAPH_SUCCESS; +} + +ge::graphStatus SFAInfoParser::GetTopkLayout() +{ + topkLayout_ = qLayout_; + return ge::GRAPH_SUCCESS; +} + +ge::graphStatus SFAInfoParser::GetN1Size() +{ + n1Size_ = GetAxisNum(queryShape_, SFAAxis::N, qLayout_); + return ge::GRAPH_SUCCESS; +} + +ge::graphStatus SFAInfoParser::GetN2Size() +{ + n2Size_ = GetAxisNum(keyShape_, SFAAxis::N, kvLayout_); + return ge::GRAPH_SUCCESS; +} + +void SFAInfoParser::SetSFAShape() +{ + queryShape_ = opParamInfo_.query.shape->GetStorageShape(); + keyShape_ = opParamInfo_.key.shape->GetStorageShape(); + valueShape_ = opParamInfo_.value.shape->GetStorageShape(); + sparseIndicesShape_ = opParamInfo_.sparseIndices.shape->GetStorageShape(); + queryRopeShape_ = opParamInfo_.queryRope.tensor->GetStorageShape(); +} + +ge::graphStatus SFAInfoParser::GetGSize() +{ + if (n2Size_ != 0) { + gSize_ = n1Size_ / n2Size_; + } + return ge::GRAPH_SUCCESS; +} + +ge::graphStatus SFAInfoParser::GetActualseqInfo() +{ + maxActualseq_ = static_cast(s2Size_); + if (opParamInfo_.actualSeqLengths.tensor != nullptr) { + actualLenDimsKV_ = opParamInfo_.actualSeqLengths.tensor->GetShapeSize(); + } + if (opParamInfo_.actualSeqLengthsQ.tensor != nullptr) { + actualLenDimsQ_ = opParamInfo_.actualSeqLengthsQ.tensor->GetShapeSize(); + } + return ge::GRAPH_SUCCESS; +} + +void SFAInfoParser::GenerateInfo(SFATilingInfo &sfaInfo) +{ + sfaInfo.opName = opName_; + sfaInfo.platformInfo = platformInfo_; + sfaInfo.opParamInfo = opParamInfo_; + sfaInfo.socVersion = socVersion_; + 
+ sfaInfo.bSize = bSize_; + sfaInfo.n1Size = n1Size_; + sfaInfo.n2Size = n2Size_; + sfaInfo.s1Size = s1Size_; + sfaInfo.s2Size = s2Size_; + sfaInfo.gSize = gSize_; + sfaInfo.qkHeadDim = qkHeadDim_; + sfaInfo.vHeadDim = vHeadDim_; + sfaInfo.ropeHeadDim = ropeHeadDim_; + sfaInfo.qTSize = qTSize_; + sfaInfo.kvTSize = kvTSize_; + sfaInfo.sparseBlockSize = *opParamInfo_.sparseBlockSize; + sfaInfo.sparseBlockCount = sparseBlockCount_; + + sfaInfo.inputQType = inputQType_; + sfaInfo.inputKvType = inputKvType_; + sfaInfo.inputQRopeType = inputQRopeType_; + sfaInfo.inputKRopeType = inputKRopeType_; + sfaInfo.outputType = outputType_; + + sfaInfo.kvStorageMode = kvStorageMode_; + sfaInfo.l2CacheSize = l2CacheSize_; + + sfaInfo.totalBlockNum = opParamInfo_.key.shape->GetStorageShape().GetDim(0); + sfaInfo.scaleValue = *opParamInfo_.scaleValue; + sfaInfo.pageAttentionFlag = (kvStorageMode_ == KvStorageMode::PAGE_ATTENTION); + sfaInfo.blockSize = blockSize_; + sfaInfo.blockTypeSize = sizeof(float); + sfaInfo.maxBlockNumPerBatch = maxBlockNumPerBatch_; + + sfaInfo.actualLenDimsQ = actualLenDimsQ_; + sfaInfo.actualLenDimsKV = actualLenDimsKV_; + sfaInfo.maxActualseq = maxActualseq_; + sfaInfo.actualSeqLenFlag = (opParamInfo_.actualSeqLengths.tensor != nullptr); + sfaInfo.isSameSeqAllKVTensor = isSameSeqAllKVTensor_; + sfaInfo.isSameActualseq = isSameActualseq_; + + sfaInfo.sparseMode = *opParamInfo_.sparseMode; + + sfaInfo.qLayout = qLayout_; + sfaInfo.topkLayout = topkLayout_; + sfaInfo.kvLayout = kvLayout_; + sfaInfo.outLayout = outLayout_; +} + +ge::graphStatus SFAInfoParser::Parse(SFATilingInfo &sfaInfo) +{ + if (context_ == nullptr) { + OPS_LOG_E("SparseFlashAttention", "tiling context is nullptr!"); + return ge::GRAPH_FAILED; + } + OPS_LOG_FULL(DLOG_INFO, "SparseFlashAttention", "TilingContext: %s", SFADebugTilingContext(context_).c_str()); + if (ge::GRAPH_SUCCESS != GetOpName() || + ge::GRAPH_SUCCESS != GetNpuInfo() || + ge::GRAPH_SUCCESS != GetOpParaInfo() || + 
ge::GRAPH_SUCCESS != CheckRequiredParaExistence()) { + return ge::GRAPH_FAILED; + } + + if (ge::GRAPH_SUCCESS != GetInOutDataType() || + ge::GRAPH_SUCCESS != GetQueryAndOutLayout() || + ge::GRAPH_SUCCESS != GetTopkLayout() || + ge::GRAPH_SUCCESS != GetKvLayout() || + ge::GRAPH_SUCCESS != GetKvStorageMode()) { + return ge::GRAPH_FAILED; + } + + SetSFAShape(); + if ( + ge::GRAPH_SUCCESS != GetN1Size() || + ge::GRAPH_SUCCESS != GetN2Size() || + ge::GRAPH_SUCCESS != GetGSize() || + ge::GRAPH_SUCCESS != GetBatchSize() || + ge::GRAPH_SUCCESS != GetQTSize() || + ge::GRAPH_SUCCESS != GetKVTSize() || + ge::GRAPH_SUCCESS != GetS1Size() || + ge::GRAPH_SUCCESS != GetQkHeadDim() || + ge::GRAPH_SUCCESS != GetS2Size() || + ge::GRAPH_SUCCESS != GetValueHeadDim() || + ge::GRAPH_SUCCESS != GetRopeHeadDim() || + ge::GRAPH_SUCCESS != GetSparseBlockCount()) { + return ge::GRAPH_FAILED; + } + + if (ge::GRAPH_SUCCESS != GetActualseqInfo()) { + return ge::GRAPH_FAILED; + } + + GenerateInfo(sfaInfo); + return ge::GRAPH_SUCCESS; +} + +IMPL_OP_OPTILING(SparseFlashAttention) + .Tiling(TilingSparseFlashAttention) + .TilingParse(TilingPrepareForSparseFlashAttention); +} // namespace optiling diff --git a/csrc/sparse_flash_attention/op_host/sparse_flash_attention_tiling.h b/csrc/sparse_flash_attention/op_host/sparse_flash_attention_tiling.h new file mode 100644 index 00000000..34e7b9c0 --- /dev/null +++ b/csrc/sparse_flash_attention/op_host/sparse_flash_attention_tiling.h @@ -0,0 +1,583 @@ +/** + * This program is free software, you can redistribute it and/or modify it. + * Copyright (c) 2025 Huawei Technologies Co., Ltd. + * This file is a part of the CANN Open Software. + * Licensed under CANN Open Software License Agreement Version 2.0 (the "License"). + * Please refer to the License for details. You may not use this file except in compliance with the License. 
+ * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. + * See LICENSE in the root of the software repository for the full text of the License. + */ + +/*! + * \file sparse_flash_attention_tiling.h + * \brief + */ +#ifndef SPARSE_FLASH_ATTENTION_TILING_H +#define SPARSE_FLASH_ATTENTION_TILING_H + +#include +#include +#include +#include +#include "register/tilingdata_base.h" +#include "exe_graph/runtime/tiling_context.h" + +namespace optiling { +// Inputs Index +constexpr uint32_t QUERY_INPUT_INDEX = 0; +constexpr uint32_t KEY_INPUT_INDEX = 1; +constexpr uint32_t VALUE_INPUT_INDEX = 2; +constexpr uint32_t SPARSE_INDICES_INPUT_INDEX = 3; +constexpr uint32_t BLOCK_TABLE_INPUT_INDEX = 4; +constexpr uint32_t ACT_SEQ_LEN_Q_INPUT_INDEX = 5; +constexpr uint32_t ACT_SEQ_LEN_KV_INPUT_INDEX = 6; +constexpr uint32_t QUERY_ROPE_INPUT_INDEX = 7; +constexpr uint32_t KEY_ROPE_INPUT_INDEX = 8; +// Outputs Index +constexpr uint32_t OUTPUT_INDEX = 0; +// Attributes Index +constexpr uint32_t SCALE_VALUE_ATTR_INDEX = 0; +constexpr uint32_t SPARSE_BLOCK_SIZE_ATTR_INDEX = 1; +constexpr uint32_t LAYOUT_QUERY_ATTR_INDEX = 2; +constexpr uint32_t LAYOUT_KV_ATTR_INDEX = 3; +constexpr uint32_t SPARSE_MODE_ATTR_INDEX = 4; +// Dim Num +constexpr size_t DIM_NUM_TWO = 2; +constexpr size_t DIM_NUM_THREE = 3; +constexpr size_t DIM_NUM_FOUR = 4; +// Constant +constexpr uint32_t MAX_BLOCK_SIZE = 1024; +constexpr uint32_t COPYND2NZ_SRC_STRIDE_LIMITATION = 65535; +constexpr uint32_t NUM_BYTES_FLOAT = 4; +constexpr uint32_t NUM_BYTES_FLOAT16 = 2; +constexpr uint32_t NUM_BYTES_BF16 = 2; +constexpr uint32_t BYTE_BLOCK = 32; +const uint32_t SFA_MAX_AIC_CORE_NUM = 26; + +enum class SFALayout : uint32_t { + BSND = 0, + TND = 1, + PA_BSND = 2 +}; + +struct SFATilingShapeCompareParam { + int64_t B = 1; + int64_t S = 1; + int64_t N = 1; + int64_t D = 1; + 
int64_t T = 1; + // PA + int64_t Bs = 1; + int64_t Bn = 1; +}; + +enum class KvStorageMode : uint32_t { + BATCH_CONTINUOUS = 0, + PAGE_ATTENTION = 1 +}; + +enum class SFAPerfMode : uint32_t { + C_TEMPLATE_MODE = 0, + V_TEMPLATE_MODE +}; + +enum class SFAAxis : uint32_t { + B = 0, + S = 1, + N = 2, + D = 3, + K = 3, + T = 5, + Bn = 6, // block number + Bs = 7, // block size +}; + +struct SFARequiredParaInfo { + const gert::CompileTimeTensorDesc *desc; + const gert::StorageShape *shape; +}; + +struct SFAOptionalParaInfo { + const gert::CompileTimeTensorDesc *desc; + const gert::Tensor *tensor; +}; + +struct SFAParaInfo { + SFARequiredParaInfo query = {nullptr, nullptr}; + SFARequiredParaInfo key = {nullptr, nullptr}; + SFARequiredParaInfo value = {nullptr, nullptr}; + SFARequiredParaInfo sparseIndices = {nullptr, nullptr}; + SFAOptionalParaInfo blockTable = {nullptr, nullptr}; + SFAOptionalParaInfo actualSeqLengthsQ = {nullptr, nullptr}; + SFAOptionalParaInfo actualSeqLengths = {nullptr, nullptr}; + SFAOptionalParaInfo queryRope = {nullptr, nullptr}; + SFAOptionalParaInfo keyRope = {nullptr, nullptr}; + SFARequiredParaInfo attenOut = {nullptr, nullptr}; + + const char *layoutQuery = nullptr; + const char *layoutKV = nullptr; + const int64_t *sparseBlockSize = nullptr; + const float *scaleValue = nullptr; + const int64_t *sparseMode = nullptr; +}; + +struct InnerSplitParams { + uint32_t s1GBaseSize = 1; + uint32_t s2BaseSize = 1; +}; + +BEGIN_TILING_DATA_DEF(SparseFlashAttentionBaseParamsMla) +TILING_DATA_FIELD_DEF(uint32_t, batchSize) +TILING_DATA_FIELD_DEF(uint32_t, seqSize) +TILING_DATA_FIELD_DEF(uint32_t, qSeqSize) +TILING_DATA_FIELD_DEF(int64_t, blockSize) +TILING_DATA_FIELD_DEF(uint32_t, maxBlockNumPerBatch) +TILING_DATA_FIELD_DEF(float, scaleValue) +TILING_DATA_FIELD_DEF(uint32_t, nNumOfQInOneGroup) +TILING_DATA_FIELD_DEF(uint32_t, actualLenDimsQ) +TILING_DATA_FIELD_DEF(uint32_t, actualLenDimsKV) +TILING_DATA_FIELD_DEF(uint32_t, outputLayout) 
+TILING_DATA_FIELD_DEF(uint32_t, sparseMode) +TILING_DATA_FIELD_DEF(int64_t, sparseBlockSize) +TILING_DATA_FIELD_DEF(uint32_t, sparseBlockCount) +END_TILING_DATA_DEF +REGISTER_TILING_DATA_CLASS(SparseFlashAttentionBaseParamsMlaOp, SparseFlashAttentionBaseParamsMla) + +BEGIN_TILING_DATA_DEF(SparseFlashAttentionSingleCoreParamsMla) +TILING_DATA_FIELD_DEF(uint32_t, usedCoreNum); +END_TILING_DATA_DEF +REGISTER_TILING_DATA_CLASS(SparseFlashAttentionSingleCoreParamsMlaOp, SparseFlashAttentionSingleCoreParamsMla) + +BEGIN_TILING_DATA_DEF(SparseFlashAttentionSingleCoreTensorSizeMla) +TILING_DATA_FIELD_DEF(uint32_t, mmResUbSize); +TILING_DATA_FIELD_DEF(uint32_t, bmm2ResUbSize); +END_TILING_DATA_DEF +REGISTER_TILING_DATA_CLASS(SparseFlashAttentionSingleCoreTensorSizeMlaOp, SparseFlashAttentionSingleCoreTensorSizeMla) + +BEGIN_TILING_DATA_DEF(SparseFlashAttentionSplitKVParamsMla) +TILING_DATA_FIELD_DEF(uint32_t, s2) +TILING_DATA_FIELD_DEF(uint32_t, accumOutSize) // FD workspace +TILING_DATA_FIELD_DEF(uint32_t, logSumExpSize) // FD workspace +END_TILING_DATA_DEF +REGISTER_TILING_DATA_CLASS(SparseFlashAttentionSplitKVParamsMlaOp, SparseFlashAttentionSplitKVParamsMla) + +BEGIN_TILING_DATA_DEF(SparseFlashAttentionInnerSplitParams) +TILING_DATA_FIELD_DEF(uint32_t, mBaseSize) +TILING_DATA_FIELD_DEF(uint32_t, s2BaseSize) +END_TILING_DATA_DEF +REGISTER_TILING_DATA_CLASS(SparseFlashAttentionInnerSplitParamsOp, SparseFlashAttentionInnerSplitParams) + +BEGIN_TILING_DATA_DEF(SparseFlashAttentionTilingDataMla) +TILING_DATA_FIELD_DEF_STRUCT(SparseFlashAttentionBaseParamsMla, baseParams); +TILING_DATA_FIELD_DEF_STRUCT(SparseFlashAttentionSplitKVParamsMla, splitKVParams); +TILING_DATA_FIELD_DEF_STRUCT(SparseFlashAttentionSingleCoreParamsMla, singleCoreParams); +TILING_DATA_FIELD_DEF_STRUCT(SparseFlashAttentionSingleCoreTensorSizeMla, singleCoreTensorSize); +TILING_DATA_FIELD_DEF_STRUCT(SparseFlashAttentionInnerSplitParams, innerSplitParams); +END_TILING_DATA_DEF 
+REGISTER_TILING_DATA_CLASS(SparseFlashAttention, SparseFlashAttentionTilingDataMla) + +template inline T Align(T num, T rnd) +{ + return (((rnd) == 0) ? 0 : (((num) + (rnd) - 1) / (rnd) * (rnd))); +} + +template +std::string SFAShape2String(const T &shape) +{ + std::ostringstream oss; + oss << "["; + if (shape.GetDimNum() > 0) { + for (size_t i = 0; i < shape.GetDimNum() - 1; ++i) { + oss << shape.GetDim(i) << ", "; + } + oss << shape.GetDim(shape.GetDimNum() - 1); + } + oss << "]"; + return oss.str(); +} + +static std::string GetShapeStr(gert::Shape shape); +static std::string SFADataTypeToSerialString(ge::DataType type); +std::string SFATensorDesc2String(const gert::StorageShape *shape, const gert::CompileTimeTensorDesc *tensor); +std::string SFADebugTilingContext(const gert::TilingContext *context); +std::string SFALayoutToSerialString(SFALayout layout); + +struct SFATilingInfo { + const char *opName = nullptr; + fe::PlatFormInfos *platformInfo = nullptr; + SFAParaInfo opParamInfo; + + // Base Param + platform_ascendc::SocVersion socVersion = platform_ascendc::SocVersion::ASCEND910B; + uint32_t bSize = 0; + uint32_t n1Size = 0; + uint32_t n2Size = 0; + uint32_t s1Size = 0; + int64_t s2Size = 0; + uint32_t qkHeadDim = 0; + uint32_t vHeadDim = 0; + uint32_t gSize = 0; + uint32_t ropeHeadDim = 0; + uint32_t qTSize = 0; + uint32_t kvTSize = 0; + float scaleValue = 0; + uint32_t innerPrecise = 0; + uint32_t l2CacheOffFlag = 0; + int64_t sparseBlockSize = 0; + int64_t sparseBlockCount = 0; + + bool pageAttentionFlag = false; + int64_t blockSize = 0; + uint32_t blockTypeSize = 0; + uint32_t maxBlockNumPerBatch = 0; + uint32_t totalBlockNum = 0; + + uint32_t actualLenDimsQ = 0; + uint32_t maxActualseq = 0; + + bool actualSeqLenFlag = false; + bool isSameSeqAllKVTensor = true; + bool isSameActualseq = true; + uint32_t actualLenDimsKV = 0; + std::vector kvListSeqLens {}; + + uint32_t sparseMode = 0; + + ge::DataType inputQType = ge::DT_FLOAT16; + ge::DataType inputKvType 
= ge::DT_FLOAT16; + ge::DataType outputType = ge::DT_FLOAT16; + + KvStorageMode kvStorageMode = KvStorageMode::BATCH_CONTINUOUS; + + SFALayout qLayout = SFALayout::BSND; + SFALayout topkLayout = SFALayout::BSND; + SFALayout outLayout = SFALayout::BSND; + SFALayout kvLayout = SFALayout::BSND; + + ge::DataType inputQRopeType = ge::DT_FLOAT16; + ge::DataType inputKRopeType = ge::DT_FLOAT16; + + uint64_t l2CacheSize = 0; +}; + +class SFAMlaTiling { +public: + explicit SFAMlaTiling(gert::TilingContext *context) : context_(context) {} + ge::graphStatus DoOpTiling(SFATilingInfo *sfaInfo); + +private: + ge::graphStatus SetBlockDim(uint32_t blockDim); + ge::graphStatus SetTilingKey(uint64_t tilingKey); + ge::graphStatus SetWorkspaceSize(uint64_t workspaceSize); + ge::graphStatus SetTilingData(TilingDef &tilingData); + gert::TilingContext *context_ = nullptr; + ge::graphStatus GetPlatformInfo(); + void GenTilingKey(); + bool DealSameSeqEachBatch(); + + void ZeroTensorProcess(); + void InitParams(); + + void Split(); + bool IsBalanceSplitCore(); + + void SplitBalanced(); + void CalcInnerSize(uint32_t s2Size); + + bool IsFlashDecode(uint32_t coreNum); + + void FillTilingBaseParamsMla(); + void FillTilingSplitKVMla(); + + void FillTilingSingleCoreParamsMla(); + void FillTilingSingleCoreTensorSizeMla(); + void FillTiling(); + + void CalcUbBmm(); + void CheckUbSpace(); + void NormalCalcFDWorkSpace(const uint32_t actCoreNum); + void CalcFDWorkSpace(const uint32_t actCoreNum); + void GetWorkspaceSize(); + + uint32_t CalcBalanceFDParamNums(const uint32_t actCoreNum); + + void CalcBlockDim(); + + bool balanceModeFlag_ = false; + bool splitKVFlag_ = false; + + uint32_t coreNum_ = 0; + SFAPerfMode perfMode_ = SFAPerfMode::V_TEMPLATE_MODE; + uint32_t kvSplitPart_ = 1; + size_t mmResUbSize_ = 0; + size_t bmm2ResUbSize_ = 0; + size_t qPreSizeMla_= 0; + uint32_t sInnerLoopTimes_ = 0; + uint32_t sInnerSize_ = 0; + uint32_t sInnerSizeTail_ = 0; + uint32_t sInnerSizeAlign_ = 0; + uint32_t 
kvSplit_ = 0; + uint32_t usedCoreNum_ = 0; + uint32_t formerCoreNum_ = 0; + uint32_t blockSplitBn2Range_ = 0; + uint32_t tailSplitedBatchRange_ = 0; + + uint32_t aicNum_ = 0; + uint32_t aivNum_ = 0; + size_t libapiSize_ = 0; + + SparseFlashAttentionTilingDataMla tilingData_; + uint32_t blockDim_{0}; + uint64_t workspaceSize_{0}; + uint64_t tilingKey_{0}; + + uint32_t headDimAlign_ = 0; + uint32_t mBaseSize_ = 128; + uint32_t mFdBaseSize_ = 8; + + SFATilingInfo *sfaInfo_ = nullptr; +}; + +class SFATilingCheck { +public: + explicit SFATilingCheck(const SFATilingInfo &sfaInfo) : sfaInfo_(sfaInfo) {}; + ~SFATilingCheck() = default; + virtual ge::graphStatus Process(); +private: + void Init(); + void LogErrorDtypeSupport(const std::vector &expectDtypeList, + const ge::DataType &actualDtype, const std::string &name) const; + ge::graphStatus CheckDtypeSupport(const gert::CompileTimeTensorDesc *desc, + const std::string &name) const; + template void LogErrorNumberSupport(const std::vector &expectNumberList, + const T &actualValue, const std::string &name, const std::string subName) const; + template void LogErrorDimNumSupport(const std::vector &expectNumberList, + const T &actualValue, const std::string &name) const; + ge::graphStatus CheckDimNumSupport(const gert::StorageShape *shape, + const std::vector &expectDimNumList, const std::string &name) const; + ge::graphStatus CheckDimNumInLayoutSupport(const SFALayout &layout, + const gert::StorageShape *shape, const std::string &name) const; + void LogErrorLayoutSupport(const std::vector &expectLayoutList, + const SFALayout &actualLayout, const std::string &name) const; + ge::graphStatus GetExpectedShape(gert::Shape &shapeExpected, + const SFATilingShapeCompareParam ¶m, const SFALayout &layout) const; + ge::graphStatus CompareShape(SFATilingShapeCompareParam ¶m, + const gert::Shape &shape, const SFALayout &layout, const std::string &name) const; + ge::graphStatus CheckLayoutSupport(const SFALayout &actualLayout, const 
std::string &name) const; + ge::graphStatus CheckSingleParaQuery() const; + ge::graphStatus CheckSingleParaKey() const; + ge::graphStatus CheckSingleParaValue() const; + ge::graphStatus CheckSingleParaQueryRope() const; + ge::graphStatus CheckSingleParaKeyRope() const; + ge::graphStatus CheckSingleParaAttenOut() const; + ge::graphStatus CheckSingleParaNumHeads() const; + ge::graphStatus CheckSingleParaKvHeadNums() const; + ge::graphStatus CheckSingleParaLayout() const; + ge::graphStatus CheckSingleParaSparseMode() const; + ge::graphStatus CheckSingleParaSparseBlockSize() const; + ge::graphStatus CheckSingleParaSparseIndices() const; + ge::graphStatus CheckSinglePara() const; + ge::graphStatus CheckMultiParaConsistency() const; + ge::graphStatus CheckRopeExistence(); + ge::graphStatus CheckExists(const void *pointer, const std::string &name) const; + ge::graphStatus CheckNotExists(const void *pointer, const std::string &name) const; + ge::graphStatus CheckExistsByMap(const std::map ¶mMap) const; + ge::graphStatus CheckNotExistsByMap(const std::map ¶mMap) const; + ge::graphStatus CheckExistenceByMap(std::map &existMap, + std::map ¬ExistMap) const; + template ge::graphStatus CheckAttrValueByMap( + std::map> &attrMap) const; + ge::graphStatus CheckParaExistenceMlaNoquant() const; + ge::graphStatus CheckParaExistenceGqaNoquant() const; + ge::graphStatus CheckParaExistenceMla() const; + ge::graphStatus CheckParaExistence(); + ge::graphStatus GetActualSeqLenSize(uint32_t &size, const gert::Tensor *tensor, + const SFALayout &layout, const std::string &name); + void SetSFAShapeCompare(); + ge::graphStatus CheckQRope(); + ge::graphStatus CheckQRopeShape(); + ge::graphStatus CheckVAndKRopeShapeForBatchContinuous(); + uint32_t GetTypeSize(ge::DataType dtype) const; + ge::graphStatus CheckVAndKRopeShapeForPageAttention(); + ge::graphStatus CheckVAndKRopeShape(); + ge::graphStatus CheckVAndKRope(); + ge::graphStatus CheckTopK(); + ge::graphStatus CheckTopkShape(); + 
ge::graphStatus CheckBlockTable() const; + ge::graphStatus CheckDTypeConsistency(const ge::DataType &actualDtype, + const ge::DataType &expectDtype, const std::string &name) const; + + ge::graphStatus CheckAttenOut(); + ge::graphStatus CheckAttenOutShape(); + ge::graphStatus CheckActualSeqLensQ(); + ge::graphStatus CheckActualSeqLensQShape(); + ge::graphStatus CheckActualSeqLensQDType(); + ge::graphStatus CheckActualSeqLens(); + ge::graphStatus CheckActualSeqLensDType(); + ge::graphStatus CheckActualSeqLensShape(); + ge::graphStatus CheckMultiParaConsistency(); + + ge::graphStatus CheckFeatureMlaNoQuantShape() const; + ge::graphStatus CheckFeatureMlaNoQuantLayout() const; + ge::graphStatus CheckFeatureMlaNoQuantDtype() const; + ge::graphStatus CheckFeatureMlaNoquantPa() const; + ge::graphStatus CheckFeatureMlaNoquant() const; + ge::graphStatus CheckFeatureMla() const; + ge::graphStatus CheckFeature() const; + +private: + const char *opName_; + fe::PlatFormInfos *platformInfo_; + SFAParaInfo opParamInfo_; + const SFATilingInfo &sfaInfo_; + + uint32_t bSize_ = 0; + uint32_t n1Size_ = 0; + uint32_t n2Size_ = 0; + uint32_t gSize_ = 0; + uint32_t s1Size_ = 0; + int64_t s2Size_ = 0; + uint32_t qkHeadDim_ = 0; + uint32_t vHeadDim_ = 0; + uint32_t ropeHeadDim_ = 0; + uint32_t qTSize_ = 0; + uint32_t kvTSize_ = 0; + KvStorageMode kvStorageMode_ = KvStorageMode::BATCH_CONTINUOUS; + uint32_t sparseBlockCount_ = 0; + int64_t sparseBlockSize_ = 0; + + SFALayout qLayout_ = SFALayout::BSND; + SFALayout topkLayout_ = SFALayout::BSND; + SFALayout outLayout_ = SFALayout::BSND; + SFALayout kvLayout_ = SFALayout::BSND; + + uint32_t maxBlockNumPerBatch_ = 0; + int64_t blockSize_ = 0; + + uint32_t aicNum_ = 0; + uint32_t aivNum_ = 0; + platform_ascendc::SocVersion socVersion_ = platform_ascendc::SocVersion::ASCEND910B; + uint64_t l2CacheSize_ = 0; + + ge::DataType inputQType_ = ge::DT_FLOAT16; + ge::DataType inputKvType_ = ge::DT_FLOAT16; + ge::DataType outputType_ = ge::DT_FLOAT16; + 
ge::DataType inputQRopeType_ = ge::DT_FLOAT16; + ge::DataType inputKRopeType_ = ge::DT_FLOAT16; + + gert::Shape queryShapeCmp_{}; + gert::Shape keyShapeCmp_{}; + gert::Shape valueShapeCmp_{}; + gert::Shape topkShapeCmp_{}; + gert::Shape queryRopeShapeCmp_{}; + gert::Shape keyRopeShapeCmp_{}; + gert::Shape attenOutShapeCmp_{}; +}; + +class SFAInfoParser { +public: + explicit SFAInfoParser(const gert::TilingContext *context) : context_(context) {} + ~SFAInfoParser() = default; + + ge::graphStatus CheckRequiredInOutExistence() const; + ge::graphStatus CheckRequiredAttrExistence() const; + ge::graphStatus CheckRequiredParaExistence() const; + + ge::graphStatus GetActualSeqLenSize(uint32_t &size, const gert::Tensor *tensor, + SFALayout &layout, const std::string &name); + ge::graphStatus GetActualSeqLenQSize(uint32_t &size); + ge::graphStatus GetOpName(); + ge::graphStatus GetNpuInfo(); + void GetOptionalInputParaInfo(); + void GetInputParaInfo(); + void GetOutputParaInfo(); + ge::graphStatus GetAttrParaInfo(); + ge::graphStatus GetKvCache(); + ge::graphStatus GetOpParaInfo(); + + ge::graphStatus GetInOutDataType(); + ge::graphStatus GetBatchSize(); + ge::graphStatus GetQTSize(); + ge::graphStatus GetKVTSize(); + ge::graphStatus GetQkHeadDim(); + ge::graphStatus GetS1Size(); + ge::graphStatus GetKvStorageMode(); + ge::graphStatus GetKvLayout(); + void SetSFAShape(); + ge::graphStatus GetS2SizeForBatchContinuous(); + ge::graphStatus GetMaxBlockNumPerBatch(); + ge::graphStatus GetBlockSize(); + ge::graphStatus GetS2SizeForPageAttention(); + ge::graphStatus GetS2Size(); + ge::graphStatus GetValueHeadDim(); + ge::graphStatus GetRopeHeadDim(); + ge::graphStatus GetQueryAndOutLayout(); + ge::graphStatus GetTopkLayout(); + ge::graphStatus GetN1Size(); + ge::graphStatus GetN2Size(); + ge::graphStatus GetGSize(); + ge::graphStatus GetSparseBlockCount(); + ge::graphStatus GetActualseqInfo(); + void GenerateInfo(SFATilingInfo &sfaInfo); + ge::graphStatus Parse(SFATilingInfo 
&sfaInfo); + +public: + bool HasAxis(const SFAAxis &axis, const SFALayout &layout, const gert::Shape &shape) const; + size_t GetAxisIdx(const SFAAxis &axis, const SFALayout &layout) const; + uint32_t GetAxisNum(const gert::Shape &shape, const SFAAxis &axis,const SFALayout &layout) const; + + const gert::TilingContext *context_ = nullptr; + + const char *opName_; + fe::PlatFormInfos *platformInfo_; + SFAParaInfo opParamInfo_; + static constexpr int64_t invalidDimValue_ = std::numeric_limits::min(); + + uint32_t bSize_ = 0; + uint32_t n1Size_ = 0; + uint32_t n2Size_ = 0; + uint32_t gSize_ = 0; + uint32_t s1Size_ = 0; + int64_t s2Size_ = 0; + uint32_t qkHeadDim_ = 0; + uint32_t vHeadDim_ = 0; + uint32_t ropeHeadDim_ = 0; + uint32_t qTSize_ = 0; + uint32_t kvTSize_ = 0; + KvStorageMode kvStorageMode_ = KvStorageMode::BATCH_CONTINUOUS; + uint32_t sparseBlockCount_ = 0; + + SFALayout qLayout_ = SFALayout::BSND; + SFALayout topkLayout_ = SFALayout::BSND; + SFALayout outLayout_ = SFALayout::BSND; + SFALayout kvLayout_ = SFALayout::BSND; + + uint32_t maxBlockNumPerBatch_ = 0; + uint32_t blockSize_ = 0; + + platform_ascendc::SocVersion socVersion_ = platform_ascendc::SocVersion::ASCEND910B; + + ge::DataType inputQType_ = ge::DT_FLOAT16; + ge::DataType inputKvType_ = ge::DT_FLOAT16; + ge::DataType outputType_ = ge::DT_FLOAT16; + ge::DataType inputQRopeType_ = ge::DT_FLOAT16; + ge::DataType inputKRopeType_ = ge::DT_FLOAT16; + + uint64_t l2CacheSize_ = 0; + + bool isSameSeqAllKVTensor_ = true; + bool isSameActualseq_ = true; + uint32_t maxActualseq_ = 0; + + uint32_t actualLenDimsQ_ = 0; + uint32_t actualLenDimsKV_ = 0; + + gert::Shape queryShape_{}; + gert::Shape keyShape_{}; + gert::Shape valueShape_{}; + gert::Shape sparseIndicesShape_{}; + gert::Shape queryRopeShape_{}; + gert::Shape keyRopeShape_{}; +}; +} // namespace optiling +#endif // SPARSE_FLASH_ATTENTION_TILING_H diff --git a/csrc/sparse_flash_attention/op_kernel/sparse_flash_attention.cpp 
b/csrc/sparse_flash_attention/op_kernel/sparse_flash_attention.cpp new file mode 100644 index 00000000..a71306b5 --- /dev/null +++ b/csrc/sparse_flash_attention/op_kernel/sparse_flash_attention.cpp @@ -0,0 +1,53 @@ +/** + * This program is free software, you can redistribute it and/or modify it. + * Copyright (c) 2025 Huawei Technologies Co., Ltd. + * This file is a part of the CANN Open Software. + * Licensed under CANN Open Software License Agreement Version 2.0 (the "License"). + * Please refer to the License for details. You may not use this file except in compliance with the License. + * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. + * See LICENSE in the root of the software repository for the full text of the License. + */ + + /*! + * \file sparse_flash_attention.cpp + * \brief + */ + +#include "kernel_operator.h" +#include "sparse_flash_attention_template_tiling_key.h" +#include "sparse_flash_attention_kernel_mla.h" + +using namespace AscendC; + +#define SFA_OP_IMPL(templateClass, tilingdataClass, ...) 
\ + do { \ + templateClass> op; \ + GET_TILING_DATA_WITH_STRUCT(tilingdataClass, tiling_data_in, tiling); \ + const tilingdataClass *__restrict tiling_data = &tiling_data_in; \ + op.Init(query, key, value, sparseIndices, actualSeqLengthsQuery, actualSeqLengthsKV, \ + blocktable, queryRope, keyRope, attentionOut, user, tiling_data, tiling, &tPipe); \ + op.Process(); \ + } while (0) + +template + __global__ __aicore__ void +sparse_flash_attention(__gm__ uint8_t *query, __gm__ uint8_t *key, __gm__ uint8_t *value, + __gm__ uint8_t *sparseIndices, __gm__ uint8_t *blocktable, + __gm__ uint8_t *actualSeqLengthsQuery, __gm__ uint8_t *actualSeqLengthsKV, + __gm__ uint8_t* queryRope, __gm__ uint8_t* keyRope, + __gm__ uint8_t *attentionOut, __gm__ uint8_t *workspace, __gm__ uint8_t *tiling) +{ + KERNEL_TASK_TYPE_DEFAULT(KERNEL_TYPE_MIX_AIC_1_2); + + TPipe tPipe; + __gm__ uint8_t *user = GetUserWorkspace(workspace); + + if constexpr (ORIG_DTYPE_QUERY == DT_FLOAT16 && ORIG_DTYPE_KEY == DT_FLOAT16 && + ORIG_DTYPE_ATTENTION_OUT == DT_FLOAT16) { + SFA_OP_IMPL(SparseFlashAttentionMla, SparseFlashAttentionTilingDataMla, half, half, half, + FLASH_DECODE, static_cast(LAYOUT_T), static_cast(KV_LAYOUT_T), TEMPLATE_MODE); + } else { // bf16 + SFA_OP_IMPL(SparseFlashAttentionMla, SparseFlashAttentionTilingDataMla, bfloat16_t, bfloat16_t, bfloat16_t, + FLASH_DECODE, static_cast(LAYOUT_T), static_cast(KV_LAYOUT_T), TEMPLATE_MODE); + } +} \ No newline at end of file diff --git a/csrc/sparse_flash_attention/op_kernel/sparse_flash_attention_common.h b/csrc/sparse_flash_attention/op_kernel/sparse_flash_attention_common.h new file mode 100644 index 00000000..58530c60 --- /dev/null +++ b/csrc/sparse_flash_attention/op_kernel/sparse_flash_attention_common.h @@ -0,0 +1,192 @@ +/** + * This program is free software, you can redistribute it and/or modify it. + * Copyright (c) 2025 Huawei Technologies Co., Ltd. + * This file is a part of the CANN Open Software. 
+ * Licensed under CANN Open Software License Agreement Version 2.0 (the "License"). + * Please refer to the License for details. You may not use this file except in compliance with the License. + * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. + * See LICENSE in the root of the software repository for the full text of the License. + */ + +/*! + * \file sparse_flash_attention_common.h + * \brief + */ + +#ifndef SPARSE_FLASH_ATTENTION_COMMON_H +#define SPARSE_FLASH_ATTENTION_COMMON_H + +#include "kernel_operator.h" +#include "lib/matmul_intf.h" +#include "lib/matrix/matmul/tiling.h" + +using namespace AscendC; +constexpr SoftmaxConfig SFA_SOFTMAX_FLASHV2_CFG_WITHOUT_BRC = {false, 0, 0, SoftmaxMode::SOFTMAX_OUTPUT_WITHOUT_BRC}; + +enum class SFA_LAYOUT +{ + BSND = 0, + TND = 1, + PA_BSND = 2, +}; + +template +struct SFAType { + using queryType = Q_T; + using kvType = KV_T; + using outputType = OUT_T; + static constexpr bool flashDecode = FLASH_DECODE; + static constexpr SFA_LAYOUT layout = LAYOUT_T; + static constexpr SFA_LAYOUT kvLayout = KV_LAYOUT_T; + static constexpr int templateMode = TEMPLATE_MODE; + static constexpr bool pageAttention = (KV_LAYOUT_T == SFA_LAYOUT::PA_BSND); +}; + +// ================================Util functions================================== +template __aicore__ inline T SFAAlign(T num, T rnd) +{ + return (((rnd) == 0) ? 0 : (((num) + (rnd) - 1) / (rnd) * (rnd))); +} + +template __aicore__ inline T1 Min(T1 a, T2 b) +{ + return (a > b) ? 
(b) : (a); +} + +template __aicore__ inline size_t BlockAlign(size_t s) +{ + if constexpr (IsSameType::value) { + return (s + 63) / 64 * 64; + } + size_t n = (32 / sizeof(T)); + return (s + n - 1) / n * n; +} + +struct RunInfo { + uint32_t loop; + uint32_t bIdx; + uint32_t gIdx; + uint32_t s1Idx; + uint32_t s2Idx; + uint32_t bn2IdxInCurCore; + uint32_t curSInnerLoopTimes; + uint64_t tndBIdxOffsetForQ; + uint64_t tndBIdxOffsetForKV; + uint64_t tensorAOffset; + uint64_t tensorBOffset; + uint64_t tensorARopeOffset; + uint64_t tensorBRopeOffset; + uint64_t attenOutOffset; + uint64_t attenMaskOffset; + uint64_t topKBaseOffset; + uint32_t actualSingleProcessSInnerSize; + uint32_t actualSingleProcessSInnerSizeAlign; + bool isFirstSInnerLoop; + bool isChangeBatch; + uint32_t s2BatchOffset; + uint32_t gSize; + uint32_t s1Size; + uint32_t s2Size; + uint32_t mSize; + uint32_t mSizeV; + uint32_t mSizeVStart; + uint32_t tndIsS2SplitCore; + uint32_t tndCoreStartKVSplitPos; + bool isBmm2Output; + bool isValid = false; + + static constexpr uint32_t n2Idx = 0; + uint64_t actS1Size = 1; + uint64_t curActualSeqLenOri = 0ULL; + + uint32_t gS1Idx; + uint64_t actS2Size = 1; + uint32_t actMBaseSize; + bool isLastS2Loop; + int32_t nextTokensPerBatch = 0; + int64_t threshold; + uint32_t curTopKIdx = 0; + uint64_t curOffsetInSparseBlock = 0; +}; + +struct ConstInfo { + static constexpr uint32_t SFA_SYNC_MODE2 = 2; + static constexpr uint32_t BUFFER_SIZE_BYTE_32B = 32; + static constexpr uint32_t BUFFER_SIZE_BYTE_64B = 64; + static constexpr uint32_t BUFFER_SIZE_BYTE_256B = 256; + static constexpr uint32_t BUFFER_SIZE_BYTE_512B = 512; + static constexpr uint32_t BUFFER_SIZE_BYTE_1K = 1024; + static constexpr uint32_t BUFFER_SIZE_BYTE_2K = 2048; + static constexpr uint32_t BUFFER_SIZE_BYTE_4K = 4096; + static constexpr uint32_t BUFFER_SIZE_BYTE_8K = 8192; + static constexpr uint32_t BUFFER_SIZE_BYTE_16K = 16384; + static constexpr uint32_t BUFFER_SIZE_BYTE_32K = 32768; + static constexpr 
float FLOAT_ZERO = 0; + static constexpr float FLOAT_MAX = 3.402823466e+38F; + + uint32_t preLoadNum = 0U; + uint32_t nBufferMBaseSize = 0U; + uint32_t syncV1NupdateC2 = 0U; + uint32_t syncV0C1 = 0U; + uint32_t syncC1V1 = 0U; + uint32_t syncV1C2 = 0U; + uint32_t syncC2V2 = 0U; + uint32_t syncC2V1 = 0U; + + uint32_t mmResUbSize = 0U; + uint32_t vec1ResUbSize = 0U; + uint32_t bmm2ResUbSize = 0U; + uint64_t batchSize = 0ULL; + uint64_t gSize = 0ULL; + uint64_t qHeadNum = 0ULL; + uint64_t kvHeadNum; + uint64_t headDim; + uint64_t headDimRope; + uint64_t kvSeqSize = 0ULL; + uint64_t qSeqSize = 1ULL; + int64_t kvCacheBlockSize = 0; + uint32_t maxBlockNumPerBatch = 0; + uint32_t splitKVNum = 0U; + SFA_LAYOUT outputLayout; + uint32_t sparseMode = 0; + bool needInit = false; + + // FlashDecoding + uint32_t actualCombineLoopSize = 0U; + uint64_t combineLseOffset = 0ULL; + uint64_t combineAccumOutOffset = 0ULL; + + uint32_t actualLenDimsQ = 0U; + uint32_t actualLenDimsKV = 0U; + + // TND + uint32_t s2Start = 0U; + uint32_t s2End = 0U; + + uint32_t bN2Start = 0U; + uint32_t bN2End = 0U; + uint32_t gS1Start = 0U; + uint32_t gS1End = 0U; + + uint32_t tndFDCoreArrLen = 0U; + uint32_t coreStartKVSplitPos = 0U; + + uint32_t mBaseSize = 1ULL; + uint32_t s2BaseSize = 1ULL; + + // sparse attr + int64_t sparseBlockSize = 0; + uint32_t sparseBlockCount = 0; +}; + +struct MSplitInfo { + uint32_t nBufferIdx = 0U; + uint32_t nBufferStartM = 0U; + uint32_t nBufferDealM = 0U; + uint32_t vecStartM = 0U; + uint32_t vecDealM = 0U; +}; + +#endif // SPARSE_FLASH_ATTENTION_COMMON_H \ No newline at end of file diff --git a/csrc/sparse_flash_attention/op_kernel/sparse_flash_attention_kernel_mla.h b/csrc/sparse_flash_attention/op_kernel/sparse_flash_attention_kernel_mla.h new file mode 100644 index 00000000..1aec5de6 --- /dev/null +++ b/csrc/sparse_flash_attention/op_kernel/sparse_flash_attention_kernel_mla.h @@ -0,0 +1,969 @@ +/** + * This program is free software, you can redistribute it and/or 
modify it. + * Copyright (c) 2025 Huawei Technologies Co., Ltd. + * This file is a part of the CANN Open Software. + * Licensed under CANN Open Software License Agreement Version 2.0 (the "License"). + * Please refer to the License for details. You may not use this file except in compliance with the License. + * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. + * See LICENSE in the root of the software repository for the full text of the License. + */ + +/*! + * \file sparse_flash_attention_kernel_mla.h + * \brief + */ + +#ifndef SPARSE_FLASH_ATTENTION_KERNEL_MLA_H +#define SPARSE_FLASH_ATTENTION_KERNEL_MLA_H + +#include "kernel_operator.h" +#include "kernel_operator_list_tensor_intf.h" +#include "kernel_tiling/kernel_tiling.h" +#include "lib/matmul_intf.h" +#include "lib/matrix/matmul/tiling.h" +#include "sparse_flash_attention_common.h" +#include "sparse_flash_attention_service_cube_mla.h" +#include "sparse_flash_attention_service_vector_mla.h" + +using namespace matmul; +using AscendC::CacheMode; +using AscendC::CrossCoreSetFlag; +using AscendC::CrossCoreWaitFlag; + +struct TempLoopInfo { + uint32_t bn2IdxInCurCore = 0; + uint32_t bIdx = 0U; + uint32_t n2Idx = 0U; + uint64_t s2BasicSizeTail = 0U; + uint32_t s2LoopTimes = 0U; + uint64_t curActualSeqLen = 0ULL; + uint64_t curActualSeqLenOri = 0ULL; + bool curActSeqLenIsZero = false; + int32_t nextTokensPerBatch = 0; + + uint64_t actS1Size = 1ULL; + uint32_t tndCoreStartKVSplitPos; + bool tndIsS2SplitCore; + + uint32_t gS1Idx = 0U; + uint64_t mBasicSizeTail = 0U; +}; + +template class SparseFlashAttentionMla { +public: + using T = float; + using Q_T = typename SFAT::queryType; + using KV_T = typename SFAT::kvType; + using OUT_T = typename SFAT::outputType; + using Q_ROPE_T = Q_T; + using K_ROPE_T = KV_T; + using UPDATE_T = T; + using MM1_OUT_T = T; + using 
MM2_OUT_T = T; + + __aicore__ inline SparseFlashAttentionMla(){}; + __aicore__ inline void Init(__gm__ uint8_t *query, __gm__ uint8_t *key, __gm__ uint8_t *value, + __gm__ uint8_t *sparseIndices, __gm__ uint8_t *actualSeqLengthsQ, + __gm__ uint8_t *actualSeqLengths, __gm__ uint8_t *blockTable, + __gm__ uint8_t *queryRope, __gm__ uint8_t *keyRope, + __gm__ uint8_t *attentionOut, __gm__ uint8_t *workspace, + const SparseFlashAttentionTilingDataMla *__restrict tiling, + __gm__ uint8_t *gmTiling, TPipe *tPipe); + + __aicore__ inline void Process(); + +private: + static constexpr bool PAGE_ATTENTION = SFAT::pageAttention; + static constexpr int TEMPLATE_MODE = SFAT::templateMode; + static constexpr bool FLASH_DECODE = SFAT::flashDecode; + static constexpr SFA_LAYOUT LAYOUT_T = SFAT::layout; + static constexpr SFA_LAYOUT KV_LAYOUT_T = SFAT::kvLayout; + + static constexpr uint32_t PRELOAD_NUM = 2; + static constexpr uint32_t N_BUFFER_M_BASIC_SIZE = 256; + static constexpr uint32_t SFA_PRELOAD_TASK_CACHE_SIZE = 3; + + static constexpr uint32_t SYNC_V0_C1_FLAG = 6; + static constexpr uint32_t SYNC_C1_V1_FLAG = 7; + static constexpr uint32_t SYNC_V1_C2_FLAG = 8; + static constexpr uint32_t SYNC_C2_V2_FLAG = 9; + static constexpr uint32_t SYNC_C2_V1_FLAG = 4; + static constexpr uint32_t SYNC_V1_NUPDATE_C2_FLAG = 5; + + static constexpr uint64_t SYNC_MM2RES_BUF1_FLAG = 10; + static constexpr uint64_t SYNC_MM2RES_BUF2_FLAG = 11; + static constexpr uint64_t SYNC_FDOUTPUT_BUF_FLAG = 12; + + static constexpr uint32_t BLOCK_ELEMENT_NUM = SFAVectorService::BYTE_BLOCK / sizeof(T); + + static constexpr uint64_t kvHeadNum = 1ULL; + static constexpr uint64_t headDim = 512ULL; + static constexpr uint64_t headDimAlign = 512ULL; + static constexpr uint64_t headDimRope = 64ULL; + static constexpr uint32_t msdIterNum = 2U; + + static constexpr uint32_t dbWorkspaceRatio = PRELOAD_NUM; + + const SparseFlashAttentionTilingDataMla *__restrict tilingData = nullptr; + + TPipe *pipe = nullptr; + + 
uint64_t mSizeVStart = 0ULL; + int64_t threshold = 0; + uint64_t topKBaseOffset = 0ULL; + uint64_t s2BatchBaseOffset = 0; + uint64_t tensorACoreOffset = 0ULL; + uint64_t tensorBCoreOffset = 0ULL; + uint64_t tensorARopeCoreOffset = 0ULL; + uint64_t tensorBRopeCoreOffset = 0ULL; + uint64_t tensorBOffset = 0ULL; + uint64_t attenOutOffset = 0ULL; + + uint32_t tmpBlockIdx = 0U; + uint32_t aiCoreIdx = 0U; + uint32_t usedCoreNum = 0U; + + __gm__ uint8_t *keyPtr = nullptr; + __gm__ uint8_t *valuePtr = nullptr; + + ConstInfo constInfo{}; + TempLoopInfo tempLoopInfo{}; + + SFAMatmulService matmulService; + SFAVectorService vectorService; + + GlobalTensor queryGm; + GlobalTensor keyGm; + GlobalTensor valueGm; + GlobalTensor qRopeGm; + GlobalTensor kRopeGm; + + GlobalTensor attentionOutGm; + GlobalTensor blockTableGm; + GlobalTensor topKGm; + + GlobalTensor actualSeqLengthsQGm; + GlobalTensor actualSeqLengthsKVGm; + + // workspace + GlobalTensor mm1ResGm; + GlobalTensor vec1ResGm; + GlobalTensor mm2ResGm; + GlobalTensor kvMergeGm_; + GlobalTensor kvValidSizeGm_; + + GlobalTensor mm2ResInt32Gm; + GlobalTensor vec2ResGm; + + GlobalTensor accumOutGm; + GlobalTensor lseSumFdGm; + GlobalTensor lseMaxFdGm; + + // ================================Init functions=================================== + __aicore__ inline void InitTilingData(); + __aicore__ inline void InitCalcParamsEach(); + __aicore__ inline void InitBuffers(); + __aicore__ inline void InitActualSeqLen(__gm__ uint8_t *actualSeqLengthsQ, __gm__ uint8_t *actualSeqLengths); + __aicore__ inline void InitOutputSingleCore(); + // ================================Process functions================================ + __aicore__ inline void ProcessBalance(); + __aicore__ inline void PreloadPipeline(uint32_t loop, uint64_t s2Start, uint64_t s2LoopIdx, + RunInfo extraInfo[SFA_PRELOAD_TASK_CACHE_SIZE], uint32_t &curTopKIdx, uint64_t &curOffsetInSparseBlock); + // ================================Offset 
Calc===================================== + __aicore__ inline void GetActualSeqLen(uint32_t bIdx, uint32_t s1Idx = 0); + __aicore__ inline void GetSparseActualSeqLen(uint32_t bIdx, uint32_t s1Idx, uint32_t n2Idx); + __aicore__ inline void CalcSinnerTopKBegin(RunInfo &info, uint32_t &curTopKIdx, uint64_t &curOffsetInSparseBlock); + __aicore__ inline void UpdateInnerLoopCond(); + __aicore__ inline void DealActSeqLenIsZero(uint32_t bIdx, uint32_t s1Idx, uint32_t n2Idx); + __aicore__ inline void CalcParams(uint32_t loop, uint64_t s2Start, uint32_t s2LoopIdx, RunInfo &info); + __aicore__ inline void GetAxisStartIdx(uint32_t bN2EndPrev, uint32_t gS1EndPrev, uint32_t s2EndPrev); + __aicore__ inline uint64_t GetBalanceActualSeqLengths(GlobalTensor &actualSeqLengths, uint32_t bIdx); + __aicore__ inline uint32_t GetActualSeqLenKV(uint32_t bIdx); + __aicore__ inline void GetBN2Idx(uint32_t bN2Idx, uint32_t &bIdx, uint32_t &n2Idx); + __aicore__ inline void UpdateInner(uint32_t &s2End, uint32_t &curS2End, uint32_t s1Idx, bool isEnd); + __aicore__ inline void GetPreNextTokensLeftUp(); + // ================================Mm1============================================== + __aicore__ inline void ComputeMm1(const RunInfo &info); + // ================================Mm2============================================== + __aicore__ inline void ComputeMm2(const RunInfo &info); + __aicore__ inline void Bmm2DataCopyOut(uint64_t attenOutOffset, LocalTensor &attenOutUb, uint32_t startRow, + uint32_t dealRowCount, uint32_t columnCount, uint32_t actualColumnCount); + __aicore__ inline void InitAllZeroOutput(uint32_t bIdx, uint32_t s1Idx, uint32_t n2Idx); +}; + +template __aicore__ inline void SparseFlashAttentionMla::InitTilingData() +{ + usedCoreNum = tilingData->singleCoreParams.usedCoreNum; + constInfo.splitKVNum = tilingData->splitKVParams.s2; + constInfo.mmResUbSize = tilingData->singleCoreTensorSize.mmResUbSize; + constInfo.bmm2ResUbSize = tilingData->singleCoreTensorSize.bmm2ResUbSize; 
+ constInfo.vec1ResUbSize = constInfo.mmResUbSize * msdIterNum; + + constInfo.batchSize = tilingData->baseParams.batchSize; + constInfo.qHeadNum = constInfo.gSize = tilingData->baseParams.nNumOfQInOneGroup; + constInfo.kvSeqSize = tilingData->baseParams.seqSize; + constInfo.qSeqSize = tilingData->baseParams.qSeqSize; + constInfo.maxBlockNumPerBatch = tilingData->baseParams.maxBlockNumPerBatch; + constInfo.kvCacheBlockSize = tilingData->baseParams.blockSize; + constInfo.outputLayout = static_cast(tilingData->baseParams.outputLayout); + constInfo.mBaseSize = tilingData->innerSplitParams.mBaseSize; + constInfo.s2BaseSize = tilingData->innerSplitParams.s2BaseSize; + constInfo.kvHeadNum = kvHeadNum; + constInfo.headDim = headDim; + constInfo.headDimRope = headDimRope; + constInfo.sparseBlockSize = tilingData->baseParams.sparseBlockSize; + constInfo.sparseBlockCount = tilingData->baseParams.sparseBlockCount; + constInfo.sparseMode = tilingData->baseParams.sparseMode; + + constInfo.preLoadNum = PRELOAD_NUM; + constInfo.nBufferMBaseSize = N_BUFFER_M_BASIC_SIZE; + constInfo.syncV0C1 = SYNC_V0_C1_FLAG; + constInfo.syncC1V1 = SYNC_C1_V1_FLAG; + constInfo.syncV1C2 = SYNC_V1_C2_FLAG; + constInfo.syncC2V2 = SYNC_C2_V2_FLAG; + constInfo.syncC2V1 = SYNC_C2_V1_FLAG; + constInfo.syncV1NupdateC2 = SYNC_V1_NUPDATE_C2_FLAG; +} + +template __aicore__ inline void SparseFlashAttentionMla::InitBuffers() +{ + if ASCEND_IS_AIV { + vectorService.InitBuffers(pipe); + } else { + matmulService.InitBuffers(pipe); + } +} + +template +__aicore__ inline void +SparseFlashAttentionMla::InitActualSeqLen(__gm__ uint8_t *actualSeqLengthsQ, + __gm__ uint8_t *actualSeqLengths) +{ + constInfo.actualLenDimsQ = tilingData->baseParams.actualLenDimsQ; + constInfo.actualLenDimsKV = tilingData->baseParams.actualLenDimsKV; + if (constInfo.actualLenDimsKV != 0) { + actualSeqLengthsKVGm.SetGlobalBuffer((__gm__ int32_t *)actualSeqLengths, constInfo.actualLenDimsKV); + } + if (constInfo.actualLenDimsQ != 0) { + 
actualSeqLengthsQGm.SetGlobalBuffer((__gm__ int32_t *)actualSeqLengthsQ, constInfo.actualLenDimsQ); + } +} + +template +__aicore__ inline void SparseFlashAttentionMla::InitAllZeroOutput(uint32_t bIdx, uint32_t s1Idx, uint32_t n2Idx) +{ + if (constInfo.outputLayout == SFA_LAYOUT::TND) { + uint32_t tBase = bIdx == 0 ? 0 : actualSeqLengthsQGm.GetValue(bIdx - 1); + uint32_t s1Count = tempLoopInfo.actS1Size; + + uint64_t attenOutOffset = (tBase + s1Idx) * kvHeadNum * constInfo.gSize * headDim + + n2Idx * constInfo.gSize * headDim; + matmul::InitOutput(attentionOutGm[attenOutOffset], constInfo.gSize * headDim, 0); + } else if (constInfo.outputLayout == SFA_LAYOUT::BSND) { + uint64_t attenOutOffset = bIdx * constInfo.qSeqSize * kvHeadNum * constInfo.gSize * headDim + + s1Idx * kvHeadNum * constInfo.gSize * headDim + + n2Idx * constInfo.gSize * headDim; + matmul::InitOutput(attentionOutGm[attenOutOffset], constInfo.gSize * headDim, 0); + } +} + +template +__aicore__ inline void SparseFlashAttentionMla::InitOutputSingleCore() +{ + uint32_t coreNum = GetBlockNum(); + if (coreNum != 0) { + uint64_t totalOutputSize = constInfo.batchSize * constInfo.qHeadNum * constInfo.qSeqSize * constInfo.headDim; + uint64_t singleCoreSize = (totalOutputSize + (2 * coreNum) - 1) / (2 * coreNum); // 2 means c:v = 1:2 + uint64_t tailSize = totalOutputSize - tmpBlockIdx * singleCoreSize; + uint64_t singleInitOutputSize = tailSize < singleCoreSize ? 
tailSize : singleCoreSize; + if (singleInitOutputSize > 0) { + matmul::InitOutput(attentionOutGm[tmpBlockIdx * singleCoreSize], singleInitOutputSize, 0); + } + SyncAll(); + } +} + +template +__aicore__ inline void SparseFlashAttentionMla::GetActualSeqLen(uint32_t bIdx, uint32_t s1Idx) +{ + tempLoopInfo.curActualSeqLenOri = GetActualSeqLenKV(bIdx); + tempLoopInfo.actS1Size = GetBalanceActualSeqLengths(actualSeqLengthsQGm, bIdx); +} + +template +__aicore__ inline void SparseFlashAttentionMla::GetSparseActualSeqLen(uint32_t bIdx, uint32_t s1Idx, + uint32_t n2Idx) +{ + if (tempLoopInfo.nextTokensPerBatch < 0 && s1Idx < (-tempLoopInfo.nextTokensPerBatch)) { + tempLoopInfo.curActualSeqLen = 0; + return; + } + int64_t threshold = tempLoopInfo.curActualSeqLenOri; + if (constInfo.sparseMode == 3) { + threshold = static_cast(tempLoopInfo.nextTokensPerBatch) + s1Idx + 1; + } + + tempLoopInfo.curActualSeqLen = (constInfo.sparseBlockCount * constInfo.sparseBlockSize > threshold) ? + threshold : + constInfo.sparseBlockCount * constInfo.sparseBlockSize; +} + +template +__aicore__ inline uint32_t SparseFlashAttentionMla::GetActualSeqLenKV(uint32_t bIdx) +{ + if constexpr (KV_LAYOUT_T == SFA_LAYOUT::TND) { + if (bIdx > 0) { + return actualSeqLengthsKVGm.GetValue(bIdx) - actualSeqLengthsKVGm.GetValue(bIdx - 1); + } else if (bIdx == 0) { + return actualSeqLengthsKVGm.GetValue(0); + } else { + return 0; + } + } else { + if (constInfo.actualLenDimsKV == 0) { + return constInfo.kvSeqSize; + } else if (constInfo.actualLenDimsKV == 1) { + return actualSeqLengthsKVGm.GetValue(0); + } else { + return actualSeqLengthsKVGm.GetValue(bIdx); + } + } +} + +template +__aicore__ inline void SparseFlashAttentionMla::DealActSeqLenIsZero(uint32_t bIdx, uint32_t s1Idx, uint32_t n2Idx) +{ + if ASCEND_IS_AIV { + InitAllZeroOutput(bIdx, s1Idx, n2Idx); + } +} + +template +__aicore__ inline void SparseFlashAttentionMla::GetPreNextTokensLeftUp() +{ + if (constInfo.sparseMode == 3) { + 
tempLoopInfo.nextTokensPerBatch = + static_cast(tempLoopInfo.curActualSeqLenOri) - static_cast(tempLoopInfo.actS1Size); + } +} + +template __aicore__ inline void SparseFlashAttentionMla::UpdateInnerLoopCond() +{ + if ((tempLoopInfo.curActualSeqLen == 0) || (tempLoopInfo.actS1Size == 0)) { + tempLoopInfo.curActSeqLenIsZero = true; + return; + } + tempLoopInfo.curActSeqLenIsZero = false; + tempLoopInfo.mBasicSizeTail = (tempLoopInfo.actS1Size * constInfo.gSize) % constInfo.mBaseSize; + tempLoopInfo.mBasicSizeTail = + (tempLoopInfo.mBasicSizeTail == 0) ? constInfo.mBaseSize : tempLoopInfo.mBasicSizeTail; + tempLoopInfo.s2LoopTimes = 0; +} + +template +__aicore__ inline void SparseFlashAttentionMla::UpdateInner(uint32_t &s2End, uint32_t &curS2End, + uint32_t s1Idx, bool isEnd) +{ + uint32_t s1BaseSize = 1; + int64_t s1Offset = s1BaseSize * s1Idx; + int64_t s2LastToken = Min(s1Offset + tempLoopInfo.nextTokensPerBatch + s1BaseSize,tempLoopInfo.curActualSeqLenOri); + s2LastToken = Min(constInfo.sparseBlockSize * constInfo.sparseBlockCount, s2LastToken); + curS2End = (s2LastToken + constInfo.s2BaseSize - 1) / constInfo.s2BaseSize; + tempLoopInfo.s2LoopTimes = isEnd ? 
constInfo.s2End + 1 : curS2End; +} + +template +__aicore__ inline void SparseFlashAttentionMla::Init(__gm__ uint8_t *query, + __gm__ uint8_t *key, __gm__ uint8_t *value, + __gm__ uint8_t *sparseIndices, __gm__ uint8_t *actualSeqLengthsQ, + __gm__ uint8_t *actualSeqLengths, __gm__ uint8_t *blockTable, + __gm__ uint8_t *queryRope, __gm__ uint8_t *keyRope, + __gm__ uint8_t *attentionOut, __gm__ uint8_t *workspace, + const SparseFlashAttentionTilingDataMla *__restrict tiling, + __gm__ uint8_t *gmTiling, TPipe *tPipe) +{ + if ASCEND_IS_AIV { + tmpBlockIdx = GetBlockIdx(); // vec:0-47 + aiCoreIdx = tmpBlockIdx / 2; + } else { + tmpBlockIdx = GetBlockIdx(); // cube:0-23 + aiCoreIdx = tmpBlockIdx; + } + + // init tiling data + tilingData = tiling; + + InitTilingData(); + InitActualSeqLen(actualSeqLengthsQ, actualSeqLengths); + + InitCalcParamsEach(); + pipe = tPipe; + keyPtr = key; + valuePtr = value; + + // init global buffer + queryGm.SetGlobalBuffer((__gm__ Q_T *)query); + keyGm.SetGlobalBuffer((__gm__ KV_T *)keyPtr); + valueGm.SetGlobalBuffer((__gm__ KV_T *)valuePtr); + qRopeGm.SetGlobalBuffer((__gm__ Q_ROPE_T *)queryRope); + kRopeGm.SetGlobalBuffer((__gm__ K_ROPE_T *)keyRope); + + attentionOutGm.SetGlobalBuffer((__gm__ OUT_T *)attentionOut); + + if ASCEND_IS_AIV { + if (constInfo.needInit && LAYOUT_T != SFA_LAYOUT::TND) { + InitOutputSingleCore(); + } + } + + if constexpr (PAGE_ATTENTION) { + blockTableGm.SetGlobalBuffer((__gm__ int32_t *)blockTable); + } + topKGm.SetGlobalBuffer((__gm__ int32_t *)sparseIndices); + + uint64_t offset = 0; + mm1ResGm.SetGlobalBuffer( + (__gm__ MM1_OUT_T *)(workspace + offset + + aiCoreIdx * dbWorkspaceRatio * constInfo.mmResUbSize * sizeof(MM1_OUT_T))); + offset += GetBlockNum() * dbWorkspaceRatio * constInfo.mmResUbSize * sizeof(MM1_OUT_T); + + vec1ResGm.SetGlobalBuffer( + (__gm__ KV_T *)(workspace + offset + aiCoreIdx * dbWorkspaceRatio * constInfo.mmResUbSize * sizeof(KV_T))); + offset += GetBlockNum() * dbWorkspaceRatio * 
constInfo.mmResUbSize * sizeof(KV_T); + + mm2ResGm.SetGlobalBuffer( + (__gm__ MM2_OUT_T *)(workspace + offset + + aiCoreIdx * dbWorkspaceRatio * constInfo.bmm2ResUbSize * sizeof(MM2_OUT_T))); + offset += GetBlockNum() * dbWorkspaceRatio * constInfo.bmm2ResUbSize * sizeof(MM2_OUT_T); + mm2ResInt32Gm.SetGlobalBuffer(reinterpret_cast<__gm__ int32_t *>(mm2ResGm.GetPhyAddr(0))); + + if constexpr (TEMPLATE_MODE == V_TEMPLATE) { + // s2 d+rope bufNum + kvMergeGm_.SetGlobalBuffer((__gm__ KV_T *)(workspace + offset + aiCoreIdx * 512 * 576 * 4 * sizeof(KV_T))); + offset += GetBlockNum() * 512 * 576 * 4 * sizeof(KV_T); + + kvValidSizeGm_.SetGlobalBuffer( + (__gm__ int32_t *)(workspace + offset + (aiCoreIdx * 2) * 128 * 4 * sizeof(int32_t))); + } + + if constexpr (FLASH_DECODE) { + accumOutGm.SetGlobalBuffer((__gm__ float *)(workspace + offset)); + offset = offset + tilingData->splitKVParams.accumOutSize * sizeof(float); + lseSumFdGm.SetGlobalBuffer((__gm__ float *)(workspace + offset)); + lseMaxFdGm.SetGlobalBuffer((__gm__ float *)(workspace + offset) + tilingData->splitKVParams.logSumExpSize / 2); + offset = offset + tilingData->splitKVParams.logSumExpSize * sizeof(float); + } + + if ASCEND_IS_AIV { + vectorService.InitParams(constInfo, tilingData); + vectorService.InitMm2ResInt32GmGlobalTensor(mm2ResInt32Gm); + if constexpr (TEMPLATE_MODE == V_TEMPLATE) { + vectorService.InitVec0GlobalTensor(kvValidSizeGm_, kvMergeGm_, kRopeGm, keyGm, blockTableGm); + } + vectorService.InitVec1GlobalTensor(mm1ResGm, vec1ResGm, actualSeqLengthsQGm, + actualSeqLengthsKVGm, lseMaxFdGm, lseSumFdGm, topKGm); + vectorService.InitVec2GlobalTensor(accumOutGm, vec2ResGm, mm2ResGm, attentionOutGm); + } + + if ASCEND_IS_AIC { + matmulService.InitParams(constInfo); + matmulService.InitMm1GlobalTensor(queryGm, qRopeGm, keyGm, kRopeGm, mm1ResGm); + matmulService.InitMm2GlobalTensor(vec1ResGm, valueGm, mm2ResGm, attentionOutGm); + matmulService.InitPageAttentionInfo(kvMergeGm_, blockTableGm, topKGm, + 
constInfo.kvCacheBlockSize, constInfo.maxBlockNumPerBatch); + } + if (pipe != nullptr) { + InitBuffers(); + } +} + +template __aicore__ inline void SparseFlashAttentionMla::InitCalcParamsEach() +{ + uint32_t totalBaseNum = 0; + uint32_t s1GBaseSize = constInfo.gSize; + uint32_t actBatchS2 = 1; + uint32_t coreNum = GetBlockNum(); + uint32_t currCoreIdx = aiCoreIdx; + uint32_t actBatchS1 = 1; + for (uint32_t bIdx = 0; bIdx < constInfo.batchSize; bIdx++) { + uint32_t actBatchS1 = GetBalanceActualSeqLengths(actualSeqLengthsQGm, bIdx); + if (actBatchS1 < constInfo.qSeqSize) { + constInfo.needInit = true; + } + totalBaseNum += actBatchS1*actBatchS2 ; + } + uint32_t avgBaseNum = 1; + if (totalBaseNum > coreNum) { + avgBaseNum = (totalBaseNum + coreNum - 1) / coreNum; + }else { + usedCoreNum = totalBaseNum; + } + if(aiCoreIdx>=usedCoreNum){ + return; + } + uint32_t accumBaseNum = 0; + uint32_t targetBaseNum = 0; + uint32_t lastValidBIdx = 0; + uint32_t lastValidactBatchS1=0; + bool setStart=false; + targetBaseNum = (currCoreIdx + 1) * avgBaseNum; + uint32_t targetStartBaseNum = targetBaseNum-avgBaseNum; + for (uint32_t bN2Idx = 0; bN2Idx < constInfo.batchSize * constInfo.kvHeadNum; bN2Idx++) { + uint32_t bIdx = bN2Idx / constInfo.kvHeadNum; + actBatchS1 = GetBalanceActualSeqLengths(actualSeqLengthsQGm, bIdx); + for (uint32_t s1GIdx = 0; s1GIdx < actBatchS1; s1GIdx++) { + accumBaseNum += 1; + if(!setStart && accumBaseNum >= targetStartBaseNum){ + constInfo.bN2Start = bN2Idx; + constInfo.gS1Start = s1GIdx; + setStart=true; + } + if (accumBaseNum >= targetBaseNum) { + constInfo.bN2End = bN2Idx; + constInfo.gS1End = s1GIdx; + constInfo.s2End = 0; + constInfo.coreStartKVSplitPos = 0; + if (aiCoreIdx != 0) { + GetAxisStartIdx(constInfo.bN2Start, constInfo.gS1Start, 0); + } + return; + } + } + if ((actBatchS1 > 0) && (actBatchS2 > 0)) { + lastValidBIdx = bIdx; + lastValidactBatchS1 = actBatchS1; + } + } + if (!setStart){ + constInfo.bN2Start = lastValidBIdx; + constInfo.gS1Start 
= lastValidactBatchS1-1; + } + if (accumBaseNum < targetBaseNum) { + constInfo.bN2End = lastValidBIdx; + constInfo.gS1End = lastValidactBatchS1-1; + constInfo.s2End = 0; + constInfo.coreStartKVSplitPos = 0; + if (aiCoreIdx != 0) { + GetAxisStartIdx(constInfo.bN2Start, constInfo.gS1Start, 0); + } + return; + } +} + +template +__aicore__ inline void +SparseFlashAttentionMla::Bmm2DataCopyOut(uint64_t attenOutOffset, LocalTensor &attenOutUb, + uint32_t startRow, uint32_t dealRowCount, + uint32_t columnCount, uint32_t actualColumnCount) +{ + DataCopyExtParams dataCopyParams; + dataCopyParams.blockCount = dealRowCount; + dataCopyParams.blockLen = actualColumnCount * sizeof(OUT_T); + dataCopyParams.srcStride = (columnCount - actualColumnCount) / (SFAVectorService::BYTE_BLOCK / sizeof(OUT_T)); + dataCopyParams.dstStride = 0; + DataCopyPad(attentionOutGm[attenOutOffset + (mSizeVStart + startRow) * actualColumnCount], attenOutUb, + dataCopyParams); +} + + +template +__aicore__ inline void SparseFlashAttentionMla::CalcParams(uint32_t loop, uint64_t s2Start, + uint32_t s2LoopIdx, RunInfo &info) +{ + info.loop = loop; + info.bIdx = tempLoopInfo.bIdx; + info.gS1Idx = tempLoopInfo.gS1Idx; + info.s2Idx = s2LoopIdx; + info.curSInnerLoopTimes = tempLoopInfo.s2LoopTimes; + + info.tndIsS2SplitCore = tempLoopInfo.tndIsS2SplitCore; + info.tndCoreStartKVSplitPos = tempLoopInfo.tndCoreStartKVSplitPos; + info.isBmm2Output = false; + + info.actS1Size = tempLoopInfo.actS1Size; + + + info.actMBaseSize = constInfo.mBaseSize; + uint32_t remainedGS1Size = tempLoopInfo.actS1Size * constInfo.gSize - tempLoopInfo.gS1Idx; + if (remainedGS1Size <= constInfo.mBaseSize && remainedGS1Size > 0) { + info.actMBaseSize = tempLoopInfo.mBasicSizeTail; + } + + info.isValid = s2LoopIdx < tempLoopInfo.s2LoopTimes; + + if ASCEND_IS_AIV { + info.mSize = info.actMBaseSize; + info.mSizeV = (info.mSize <= 16) ? 
info.mSize : (((info.mSize + 15) / 16 + 1) / 2 * 16); + info.mSizeVStart = 0; + if (tmpBlockIdx % 2 == 1) { + info.mSizeVStart = info.mSizeV; + info.mSizeV = info.mSize - info.mSizeV; + } + } + + info.isChangeBatch = false; + + info.isFirstSInnerLoop = s2LoopIdx == s2Start; + if (info.isFirstSInnerLoop) { + tempLoopInfo.bn2IdxInCurCore++; + } + info.isLastS2Loop = s2LoopIdx == tempLoopInfo.s2LoopTimes - 1; + info.bn2IdxInCurCore = tempLoopInfo.bn2IdxInCurCore - 1; + uint64_t actualSeqQPrefixSum; + if constexpr (LAYOUT_T == SFA_LAYOUT::TND) { + actualSeqQPrefixSum = (info.bIdx <= 0) ? 0 : actualSeqLengthsQGm.GetValue(info.bIdx - 1); + } else { + actualSeqQPrefixSum = (info.bIdx <= 0) ? 0 : info.bIdx * constInfo.qSeqSize; + } + info.tndBIdxOffsetForQ = actualSeqQPrefixSum * constInfo.qHeadNum * headDim; + + uint64_t actualSeqKVPrefixSum; + if constexpr (KV_LAYOUT_T == SFA_LAYOUT::TND) { + actualSeqKVPrefixSum = (info.bIdx <= 0) ? 0 : actualSeqLengthsKVGm.GetValue(info.bIdx - 1); + } else { + actualSeqKVPrefixSum = (info.bIdx <= 0) ? 
0 : info.bIdx * constInfo.kvSeqSize; + } + info.tndBIdxOffsetForKV = actualSeqKVPrefixSum * constInfo.kvHeadNum * headDim; + + if (info.isFirstSInnerLoop) { + uint64_t tndBIdxRopeOffsetForQ = actualSeqQPrefixSum * constInfo.qHeadNum * headDimRope; + tensorACoreOffset = info.tndBIdxOffsetForQ + info.gS1Idx * headDim; + tensorARopeCoreOffset = tndBIdxRopeOffsetForQ + info.gS1Idx * headDimRope; + + uint64_t tndBIdxRopeOffsetForK = actualSeqKVPrefixSum * constInfo.kvHeadNum * headDimRope; + tensorBCoreOffset = info.tndBIdxOffsetForKV + info.n2Idx * headDim; + tensorBRopeCoreOffset = tndBIdxRopeOffsetForK + info.n2Idx * headDimRope; + if (constInfo.sparseMode == 3) { + threshold = static_cast(tempLoopInfo.nextTokensPerBatch) + info.gS1Idx / constInfo.gSize + 1; + } else { + threshold = tempLoopInfo.curActualSeqLenOri; + } + if constexpr(LAYOUT_T == SFA_LAYOUT::BSND) { // B,S1,N2 K + topKBaseOffset = info.bIdx * constInfo.qSeqSize * constInfo.kvHeadNum * constInfo.sparseBlockCount + + info.gS1Idx / constInfo.gSize * constInfo.kvHeadNum * constInfo.sparseBlockCount + + info.n2Idx * constInfo.sparseBlockCount; + } else if (LAYOUT_T == SFA_LAYOUT::TND) { // T N2 K + topKBaseOffset = info.tndBIdxOffsetForQ / constInfo.gSize / constInfo.headDim * constInfo.kvHeadNum * + constInfo.sparseBlockCount + info.n2Idx * constInfo.sparseBlockCount + + info.gS1Idx / constInfo.gSize * constInfo.kvHeadNum * constInfo.sparseBlockCount; + } else { // B N2 S1 K + topKBaseOffset = info.bIdx * constInfo.kvHeadNum * constInfo.qSeqSize * constInfo.sparseBlockCount + + info.n2Idx * constInfo.qSeqSize * constInfo.sparseBlockCount + + info.gS1Idx / constInfo.gSize * constInfo.sparseBlockCount; + } + } + info.topKBaseOffset = topKBaseOffset; + info.threshold = threshold; + info.tensorAOffset = tensorACoreOffset; + info.tensorARopeOffset = tensorARopeCoreOffset; + info.tensorBOffset = tensorBCoreOffset; + info.tensorBRopeOffset = tensorBRopeCoreOffset; + info.attenOutOffset = tensorACoreOffset; + + 
uint64_t sInnerOffsetDataSize = info.s2Idx * constInfo.s2BaseSize; + info.s2BatchOffset = s2BatchBaseOffset + sInnerOffsetDataSize; + + info.curActualSeqLenOri = tempLoopInfo.curActualSeqLenOri; + if constexpr (TEMPLATE_MODE == V_TEMPLATE) { + if (tempLoopInfo.curActualSeqLen > sInnerOffsetDataSize) { + info.actualSingleProcessSInnerSize = tempLoopInfo.curActualSeqLen - sInnerOffsetDataSize; + info.actualSingleProcessSInnerSize = info.actualSingleProcessSInnerSize > constInfo.s2BaseSize ? + constInfo.s2BaseSize : info.actualSingleProcessSInnerSize; + info.actualSingleProcessSInnerSize = + SFAAlign((int64_t)info.actualSingleProcessSInnerSize, (int64_t)constInfo.sparseBlockSize); + } else { + info.actualSingleProcessSInnerSize = 0; + } + info.actualSingleProcessSInnerSizeAlign = + SFAAlign((uint32_t)info.actualSingleProcessSInnerSize, (uint32_t)SFAVectorService::BYTE_BLOCK); + } + +} + +template +__aicore__ inline void SparseFlashAttentionMla::ComputeMm1(const RunInfo &info) +{ + uint32_t nBufferLoopTimes = (info.actMBaseSize + constInfo.nBufferMBaseSize - 1) / constInfo.nBufferMBaseSize; + uint32_t nBufferTail = info.actMBaseSize - (nBufferLoopTimes - 1) * constInfo.nBufferMBaseSize; + for (uint32_t i = 0; i < nBufferLoopTimes; i++) { + MSplitInfo mSplitInfo; + mSplitInfo.nBufferStartM = i * constInfo.nBufferMBaseSize; + mSplitInfo.nBufferDealM = (i + 1 != nBufferLoopTimes) ? 
constInfo.nBufferMBaseSize : nBufferTail; + matmulService.ComputeMm1(info, mSplitInfo); + CrossCoreSetFlag(constInfo.syncC1V1); + } +} + +template +__aicore__ inline void SparseFlashAttentionMla::ComputeMm2(const RunInfo &info) +{ + uint32_t nBufferLoopTimes = (info.actMBaseSize + constInfo.nBufferMBaseSize - 1) / constInfo.nBufferMBaseSize; + uint32_t nBufferTail = info.actMBaseSize - (nBufferLoopTimes - 1) * constInfo.nBufferMBaseSize; + for (uint32_t i = 0; i < nBufferLoopTimes; i++) { + MSplitInfo mSplitInfo; + mSplitInfo.nBufferStartM = i * constInfo.nBufferMBaseSize; + mSplitInfo.nBufferDealM = (i + 1 != nBufferLoopTimes) ? constInfo.nBufferMBaseSize : nBufferTail; + CrossCoreWaitFlag(constInfo.syncV1C2); + matmulService.ComputeMm2(info, mSplitInfo); + CrossCoreSetFlag(constInfo.syncC2V2); + CrossCoreSetFlag(constInfo.syncC2V1); + } +} + +template __aicore__ inline void SparseFlashAttentionMla::Process() +{ + if (aiCoreIdx < usedCoreNum) { + if ASCEND_IS_AIV { + vectorService.AllocEventID(); + vectorService.InitSoftmaxDefaultBuffer(); + } else { + matmulService.AllocEventID(); + } + ProcessBalance(); + + if ASCEND_IS_AIV { + vectorService.FreeEventID(); + } else { + matmulService.FreeEventID(); + } + } +} + +template +__aicore__ inline void SparseFlashAttentionMla::GetBN2Idx(uint32_t bN2Idx, uint32_t &bIdx, + uint32_t &n2Idx) +{ + bIdx = bN2Idx / kvHeadNum; + n2Idx = bN2Idx % kvHeadNum; +} + +template __aicore__ inline void SparseFlashAttentionMla::ProcessBalance() +{ + RunInfo extraInfo[SFA_PRELOAD_TASK_CACHE_SIZE]; + uint32_t gloop = 0; + int gS1LoopEnd; + bool globalLoopStart = true; + if ASCEND_IS_AIC { + CrossCoreSetFlag(constInfo.syncC2V1); + if constexpr (TEMPLATE_MODE == V_TEMPLATE) { + CrossCoreSetFlag(3); + CrossCoreSetFlag(3); + CrossCoreSetFlag(3); + CrossCoreSetFlag(3); + } + } + for (uint32_t bN2LoopIdx = constInfo.bN2Start; bN2LoopIdx <= constInfo.bN2End; bN2LoopIdx++) { + GetBN2Idx(bN2LoopIdx, tempLoopInfo.bIdx, tempLoopInfo.n2Idx); + 
GetActualSeqLen(tempLoopInfo.bIdx); + GetPreNextTokensLeftUp(); + if (tempLoopInfo.actS1Size == 0) { + continue; + } + int gS1SplitNum = (tempLoopInfo.actS1Size * constInfo.gSize + constInfo.mBaseSize - 1) / constInfo.mBaseSize; + gS1LoopEnd = (bN2LoopIdx == constInfo.bN2End) ? constInfo.gS1End : gS1SplitNum - 1; + for (uint32_t gS1LoopIdx = constInfo.gS1Start; gS1LoopIdx <= gS1LoopEnd; gS1LoopIdx++) { + tempLoopInfo.gS1Idx = gS1LoopIdx * constInfo.mBaseSize; + GetSparseActualSeqLen(tempLoopInfo.bIdx, gS1LoopIdx, tempLoopInfo.n2Idx); + UpdateInnerLoopCond(); + + if (tempLoopInfo.curActSeqLenIsZero) { + DealActSeqLenIsZero(tempLoopInfo.bIdx, gS1LoopIdx, tempLoopInfo.n2Idx); + } + int s2SplitNum = + (tempLoopInfo.curActualSeqLen + constInfo.s2BaseSize - 1) / constInfo.s2BaseSize; + bool isEnd = (bN2LoopIdx == constInfo.bN2End) && (gS1LoopIdx == constInfo.gS1End); + tempLoopInfo.s2LoopTimes = s2SplitNum; + tempLoopInfo.tndIsS2SplitCore = + ((constInfo.s2Start == 0) && (tempLoopInfo.s2LoopTimes == s2SplitNum)) ? false : true; + tempLoopInfo.tndCoreStartKVSplitPos = globalLoopStart ? constInfo.coreStartKVSplitPos : 0; + uint32_t extraLoop = isEnd ? 
2 : 0; + + uint32_t curTopKIdx = 0; + uint64_t curOffsetInSparseBlock = 0; + for (int s2LoopIdx = constInfo.s2Start; s2LoopIdx < (tempLoopInfo.s2LoopTimes + extraLoop); s2LoopIdx++) { + PreloadPipeline(gloop, constInfo.s2Start, s2LoopIdx, extraInfo, curTopKIdx, curOffsetInSparseBlock); + ++gloop; + } + globalLoopStart = false; + constInfo.s2Start = 0; + } + constInfo.gS1Start = 0; + } + if ASCEND_IS_AIV { + CrossCoreWaitFlag(constInfo.syncC2V1); + if constexpr (TEMPLATE_MODE == V_TEMPLATE) { + CrossCoreWaitFlag(3); + CrossCoreWaitFlag(3); + CrossCoreWaitFlag(3); + CrossCoreWaitFlag(3); + } + } +} + +template +__aicore__ inline void +SparseFlashAttentionMla::PreloadPipeline(uint32_t loop, uint64_t s2Start, uint64_t s2LoopIdx, + RunInfo extraInfo[SFA_PRELOAD_TASK_CACHE_SIZE], uint32_t &curTopKIdx, uint64_t &curOffsetInSparseBlock) +{ + RunInfo &extraInfo0 = extraInfo[loop % SFA_PRELOAD_TASK_CACHE_SIZE]; + RunInfo &extraInfo2 = extraInfo[(loop + 2) % SFA_PRELOAD_TASK_CACHE_SIZE]; + RunInfo &extraInfo1 = extraInfo[(loop + 1) % SFA_PRELOAD_TASK_CACHE_SIZE]; + + CalcParams(loop, s2Start, s2LoopIdx, extraInfo0); + CalcSinnerTopKBegin(extraInfo0, curTopKIdx, curOffsetInSparseBlock); + + if (extraInfo0.isValid) { + if ASCEND_IS_AIC { + if constexpr (TEMPLATE_MODE == V_TEMPLATE) { + CrossCoreWaitFlag(constInfo.syncV0C1); + } + ComputeMm1(extraInfo0); + } else { + if constexpr (TEMPLATE_MODE == V_TEMPLATE) { + CrossCoreWaitFlag(3); + vectorService.MergeKv(extraInfo0); + CrossCoreSetFlag(constInfo.syncV0C1); + } + } + } + if (extraInfo2.isValid) { + if ASCEND_IS_AIV { + vectorService.ProcessVec1L(extraInfo2); + } + if ASCEND_IS_AIC { + ComputeMm2(extraInfo2); + if constexpr (TEMPLATE_MODE == V_TEMPLATE) { + CrossCoreSetFlag(3); + } + } + } + if (extraInfo1.isValid) { + if ASCEND_IS_AIV { + vectorService.ProcessVec2L(extraInfo1); + } + extraInfo1.isValid = false; + } +} + +template +__aicore__ inline uint64_t +SparseFlashAttentionMla::GetBalanceActualSeqLengths(GlobalTensor 
&actualSeqLengths, + uint32_t bIdx) +{ + if constexpr (LAYOUT_T == SFA_LAYOUT::TND) { + if (bIdx > 0) { + return actualSeqLengths.GetValue(bIdx) - actualSeqLengths.GetValue(bIdx - 1); + } else if (bIdx == 0) { + return actualSeqLengths.GetValue(0); + } else { + return 0; + } + } else { + if (constInfo.actualLenDimsQ == 0) { + return constInfo.qSeqSize; + } else if (constInfo.actualLenDimsQ == 1) { + return actualSeqLengths.GetValue(0); + } else { + return actualSeqLengths.GetValue(bIdx); + } + } +} + +template +__aicore__ inline void SparseFlashAttentionMla::GetAxisStartIdx(uint32_t bN2EndPrev, + uint32_t s1GEndPrev, + uint32_t s2EndPrev) +{ + uint32_t bEndPrev = bN2EndPrev / kvHeadNum; + uint32_t actualSeqQPrev = GetBalanceActualSeqLengths(actualSeqLengthsQGm, bEndPrev); + uint32_t s1GPrevBaseNum = (actualSeqQPrev * constInfo.gSize + constInfo.mBaseSize - 1) / constInfo.mBaseSize; + constInfo.bN2Start = bN2EndPrev; + constInfo.gS1Start = s1GEndPrev; + + constInfo.s2Start = 0; + if (s1GEndPrev >= s1GPrevBaseNum - 1) { + constInfo.gS1Start = 0; + constInfo.bN2Start++; + } else { + constInfo.gS1Start++; + } +} + +template +__aicore__ inline void SparseFlashAttentionMla::CalcSinnerTopKBegin(RunInfo &info, uint32_t &curTopKIdx, uint64_t &curOffsetInSparseBlock) + +{ + if constexpr (TEMPLATE_MODE == V_TEMPLATE) { + return; + } + + uint64_t thresholdSparseCount = (info.threshold + constInfo.sparseBlockSize - 1) / constInfo.sparseBlockSize; + uint64_t validCount = (constInfo.sparseBlockCount > thresholdSparseCount) ? 
thresholdSparseCount : constInfo.sparseBlockCount; + + int32_t sparseIndices = topKGm.GetValue(info.topKBaseOffset + curTopKIdx); + if (sparseIndices == -1 || curTopKIdx == validCount) { + info.actualSingleProcessSInnerSize = 0; + info.actualSingleProcessSInnerSizeAlign = 0; + tempLoopInfo.s2BasicSizeTail = 0; + if (curTopKIdx == 0) { + DealActSeqLenIsZero(info.bIdx, info.gS1Idx / constInfo.gSize, tempLoopInfo.n2Idx); + } + return; + } + + uint32_t sparseLen = 0; + uint64_t blockBegin = sparseIndices * constInfo.sparseBlockSize; + uint64_t blockEnd = (blockBegin + constInfo.sparseBlockSize > info.threshold) ? info.threshold : blockBegin + constInfo.sparseBlockSize; + int32_t blockLen = blockEnd - blockBegin; + sparseLen += (blockLen > static_cast(curOffsetInSparseBlock)) ? blockLen - curOffsetInSparseBlock : 0; + + bool firstVaildFlag = false; + if (curTopKIdx > 0) { + info.curTopKIdx = curTopKIdx; + info.curOffsetInSparseBlock = curOffsetInSparseBlock; + } else if (curTopKIdx == 0 && sparseLen > 0) { + info.curTopKIdx = curTopKIdx; + info.curOffsetInSparseBlock = 0; + firstVaildFlag = true; + } + + for (uint64_t topkIdx = curTopKIdx + 1; topkIdx < validCount; topkIdx++) { + int32_t sparseIndices = topKGm.GetValue(info.topKBaseOffset + topkIdx); + if (sparseIndices == -1) { + curTopKIdx = topkIdx; + curOffsetInSparseBlock = 0; + break; + } + uint64_t blockBegin = sparseIndices * constInfo.sparseBlockSize; + if (blockBegin >= info.threshold) { + continue; + } + if (firstVaildFlag == false && curTopKIdx == 0) { + info.curTopKIdx = topkIdx; + info.curOffsetInSparseBlock = 0; + firstVaildFlag = true; + } + uint64_t blockEnd = (blockBegin + constInfo.sparseBlockSize > info.threshold) ? 
info.threshold : blockBegin + constInfo.sparseBlockSize; + uint64_t blockLen = blockEnd - blockBegin; + sparseLen += blockLen; + if (sparseLen >= constInfo.s2BaseSize) { + curTopKIdx = topkIdx; + curOffsetInSparseBlock = blockLen - (sparseLen - constInfo.s2BaseSize); + sparseLen = constInfo.s2BaseSize; + break; + } + + if (topkIdx == validCount - 1) { + curTopKIdx = validCount; + curOffsetInSparseBlock = 0; + } + } + + info.actualSingleProcessSInnerSize = sparseLen; + info.actualSingleProcessSInnerSizeAlign = SFAAlign((uint32_t)info.actualSingleProcessSInnerSize, (uint32_t)SFAVectorService::BYTE_BLOCK); + tempLoopInfo.s2BasicSizeTail = (sparseLen == constInfo.s2BaseSize) ? 0 : sparseLen; + if (curTopKIdx == 0 && sparseLen == 0) { + DealActSeqLenIsZero(info.bIdx, info.gS1Idx / constInfo.gSize, tempLoopInfo.n2Idx); + } +} +#endif // SPARSE_FLASH_ATTENTION_KERNEL_MLA_H \ No newline at end of file diff --git a/csrc/sparse_flash_attention/op_kernel/sparse_flash_attention_service_cube_mla.h b/csrc/sparse_flash_attention/op_kernel/sparse_flash_attention_service_cube_mla.h new file mode 100644 index 00000000..60bb606c --- /dev/null +++ b/csrc/sparse_flash_attention/op_kernel/sparse_flash_attention_service_cube_mla.h @@ -0,0 +1,1079 @@ +/** + * This program is free software, you can redistribute it and/or modify it. + * Copyright (c) 2025 Huawei Technologies Co., Ltd. + * This file is a part of the CANN Open Software. + * Licensed under CANN Open Software License Agreement Version 2.0 (the "License"). + * Please refer to the License for details. You may not use this file except in compliance with the License. + * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. + * See LICENSE in the root of the software repository for the full text of the License. + */ + +/*! 
+ * \file sparse_flash_attention_service_cube_mla.h + * \brief use 7 buffer for matmul l1, better pipeline + */ +#ifndef SPARSE_FLASH_ATTENTION_SERVICE_CUBE_MLA_H +#define SPARSE_FLASH_ATTENTION_SERVICE_CUBE_MLA_H + +#include "kernel_operator.h" +#include "kernel_operator_list_tensor_intf.h" +#include "kernel_tiling/kernel_tiling.h" +#include "lib/matmul_intf.h" +#include "lib/matrix/matmul/tiling.h" +#include "sparse_flash_attention_common.h" + +struct PAShape { + uint32_t blockSize; + uint32_t headNum; + uint32_t headDim; + uint32_t maxblockNumPerBatch; + uint32_t actHeadDim; + uint32_t copyRowNum; + uint32_t copyRowNumAlign; +}; + +struct Position { + uint32_t bIdx; + uint32_t n2Idx; + uint32_t s2Idx; + uint32_t dIdx; +}; + +template +__aicore__ inline void DataCopyGmNDToL1(LocalTensor &l1Tensor, GlobalTensor &gmTensor, + uint32_t rowAct, + uint32_t rowAlign, + uint32_t col, // D + uint32_t colStride) // D or N*D +{ + Nd2NzParams nd2nzPara; + nd2nzPara.ndNum = 1; + nd2nzPara.nValue = rowAct; + nd2nzPara.dValue = col; + nd2nzPara.srcDValue = colStride; + nd2nzPara.dstNzC0Stride = rowAlign; + nd2nzPara.dstNzNStride = 1; + nd2nzPara.srcNdMatrixStride = 0; + nd2nzPara.dstNzMatrixStride = 0; + DataCopy(l1Tensor, gmTensor, nd2nzPara); +} + +template +__aicore__ inline void DataCopyPA(LocalTensor &dstTensor, //l1 + GlobalTensor &srcTensor, //gm + GlobalTensor &blockTableGm, + const PAShape &shape, // blockSize, headNum, headDim + const Position &startPos) // bacthIdx nIdx curSeqIdx +{ + uint32_t copyFinishRowCnt = 0; + uint64_t blockTableBaseOffset = startPos.bIdx * shape.maxblockNumPerBatch; + uint32_t curS2Idx = startPos.s2Idx; + uint32_t blockElementCnt = 32 / sizeof(T); + while (copyFinishRowCnt < shape.copyRowNum) { + uint64_t blockIdOffset = curS2Idx / shape.blockSize; + uint64_t reaminRowCnt = curS2Idx % shape.blockSize; + uint64_t idInBlockTable = blockTableGm.GetValue(blockTableBaseOffset + blockIdOffset); + uint32_t copyRowCnt = shape.blockSize - 
reaminRowCnt; + if (copyFinishRowCnt + copyRowCnt > shape.copyRowNum) { + copyRowCnt = shape.copyRowNum - copyFinishRowCnt; + } + uint64_t offset = idInBlockTable * shape.blockSize * shape.headNum * shape.headDim; + + uint64_t dStride = shape.headDim; + if constexpr (SRC_LAYOUT == SFA_LAYOUT::BSND || SRC_LAYOUT == SFA_LAYOUT::TND) { + offset += (uint64_t)(startPos.n2Idx * shape.headDim) + + reaminRowCnt * shape.headDim * shape.headNum + startPos.dIdx; + dStride = shape.headDim * shape.headNum; + } else { + offset += (uint64_t)(startPos.n2Idx * shape.headDim * shape.blockSize) + + reaminRowCnt * shape.headDim + startPos.dIdx; + } + + uint32_t dValue = shape.actHeadDim; + uint32_t srcDValue = dStride; + LocalTensor tmpDstTensor = dstTensor[copyFinishRowCnt * blockElementCnt]; + GlobalTensor tmpSrcTensor = srcTensor[offset]; + + DataCopyGmNDToL1(tmpDstTensor, tmpSrcTensor, copyRowCnt, shape.copyRowNumAlign, dValue, srcDValue); + copyFinishRowCnt += copyRowCnt; + curS2Idx += copyRowCnt; + } +} + +template class SFAMatmulService { +public: + using T = float; + using Q_T = typename SFAT::queryType; + using KV_T = typename SFAT::kvType; + using OUT_T = typename SFAT::outputType; + using MM_OUT_T = T; + + __aicore__ inline SFAMatmulService(){}; + __aicore__ inline void InitParams(const ConstInfo &constInfo); + __aicore__ inline void InitMm1GlobalTensor(GlobalTensor queryGm, GlobalTensor qRopeGm, + GlobalTensor keyGm, GlobalTensor kRopeGm, + GlobalTensor mm1ResGm); + __aicore__ inline void InitMm2GlobalTensor(GlobalTensor vec1ResGm, GlobalTensor valueGm, + GlobalTensor mm2ResGm, GlobalTensor attentionOutGm); + __aicore__ inline void InitPageAttentionInfo(const GlobalTensor& kvMergeGm, + GlobalTensor blockTableGm, GlobalTensor topKGm, + uint32_t blockSize, uint32_t maxBlockNumPerBatch); + __aicore__ inline void InitBuffers(TPipe *pipe); + __aicore__ inline void UpdateKey(GlobalTensor keyGm); + __aicore__ inline void UpdateValue(GlobalTensor valueGm); + + __aicore__ inline 
void AllocEventID(); + __aicore__ inline void FreeEventID(); + __aicore__ inline void CalcTopKBlockInfo(const RunInfo &info, uint32_t &curTopKIdx, + uint64_t &curOffsetInSparseBlock, uint32_t curSeqIdx, + uint32_t ©RowCnt, int64_t &idInTopK); + __aicore__ inline void ComputeMm1(const RunInfo &info, const MSplitInfo mSplitInfo); + __aicore__ inline void ComputeMm2(const RunInfo &info, const MSplitInfo mSplitInfo); + +private: + static constexpr bool PAGE_ATTENTION = SFAT::pageAttention; + static constexpr int TEMPLATE_MODE = SFAT::templateMode; + static constexpr bool FLASH_DECODE = SFAT::flashDecode; + static constexpr SFA_LAYOUT LAYOUT_T = SFAT::layout; + static constexpr SFA_LAYOUT KV_LAYOUT_T = SFAT::kvLayout; + + static constexpr uint32_t M_SPLIT_SIZE = 128; + static constexpr uint32_t N_SPLIT_SIZE = 128; + static constexpr uint32_t N_WORKSPACE_SIZE = 512; + + static constexpr uint32_t L1_BLOCK_SIZE = (64 * (512 + 64) * sizeof(Q_T)); + static constexpr uint32_t L1_BLOCK_OFFSET = 64 * (512 + 64); + + static constexpr uint32_t L0A_PP_SIZE = (32 * 1024); + static constexpr uint32_t L0B_PP_SIZE = (32 * 1024); + static constexpr uint32_t L0C_PP_SIZE = (64 * 1024); + + static constexpr uint32_t L1_EVENT0 = EVENT_ID2; + static constexpr uint32_t L1_EVENT1 = EVENT_ID3; + static constexpr uint32_t L1_EVENT2 = EVENT_ID4; + static constexpr uint32_t L1_EVENT3 = EVENT_ID5; + static constexpr uint32_t L1_EVENT4 = EVENT_ID6; + static constexpr uint32_t L1_EVENT5 = EVENT_ID7; + static constexpr uint32_t L1_EVENT6 = EVENT_ID1; + + // m <> mte1 EventID + static constexpr uint32_t L0AB_EVENT0 = EVENT_ID3; + static constexpr uint32_t L0AB_EVENT1 = EVENT_ID4; + + static constexpr IsResetLoad3dConfig LOAD3DV2_CONFIG = {true, true}; + static constexpr uint32_t mte21QPIds[4] = {L1_EVENT0, L1_EVENT1, L1_EVENT2, L1_EVENT3}; + static constexpr uint32_t mte21KVIds[3] = {L1_EVENT4, L1_EVENT5, L1_EVENT6}; + + uint32_t kvCacheBlockSize = 0; + uint32_t maxBlockNumPerBatch = 0; + ConstInfo 
constInfo{}; + + uint32_t qpL1BufIter = 0; + uint32_t kvL1BufIter = -1; + uint32_t abL0BufIter = 0; + uint32_t cL0BufIter = 0; + + // mm1 + GlobalTensor queryGm; + GlobalTensor qRopeGm; + GlobalTensor keyGm; + GlobalTensor kRopeGm; + GlobalTensor mm1ResGm; + GlobalTensor kvMergeGm_; + + // mm2 + GlobalTensor vec1ResGm; + GlobalTensor valueGm; + GlobalTensor mm2ResGm; + GlobalTensor attentionOutGm; + + // block_table + GlobalTensor blockTableGm; + GlobalTensor topKGm; + + TBuf bufQPL1; + TBuf bufKVL1; + TBuf tmpBufL0A; + TBuf tmpBufL0B; + TBuf tmpBufL0C; + + LocalTensor l1QPTensor; + LocalTensor l1KVTensor; + LocalTensor aL0TensorPingPong; + LocalTensor bL0TensorPingPong; + LocalTensor cL0TensorPingPong; + + // L0AB m <> mte1 EventID + __aicore__ inline uint32_t Mte1MmABEventId(uint32_t idx) + { + return (L0AB_EVENT0 + idx); + } + + __aicore__ inline uint32_t GetQPL1RealIdx(uint32_t mIdx, uint32_t k1Idx) + { + uint32_t idxMap[] = {0, 2}; + return idxMap[mIdx % 2] + k1Idx; + } + + __aicore__ inline void CopyGmToL1(LocalTensor &l1Tensor, GlobalTensor &gmSrcTensor, uint32_t srcN, + uint32_t srcD, uint32_t srcDstride); + __aicore__ inline void CopyInMm1AToL1(LocalTensor &aL1Tensor, const RunInfo &info, uint32_t mSeqIdx, + uint32_t mSizeAct, uint32_t headSize, uint32_t headOffset); + __aicore__ inline void CopyInMm1ARopeToL1(LocalTensor &aL1Tensor, const RunInfo &info, uint32_t mSeqIdx, + uint32_t mSizeAct); + __aicore__ inline void CopyInMm1BToL1(LocalTensor &bL1Tensor, const uint64_t keyGmBaseOffset, + uint32_t copyTotalRowCntAlign, uint32_t copyStartRowCnt, + uint32_t nActCopyRowCount, uint32_t headSize); + __aicore__ inline void CopyInMm1BRopeToL1(LocalTensor &bL1Tensor, const uint64_t keyGmBaseOffset, + uint32_t copyTotalRowCntAlign, uint32_t copyStartRowCnt, + uint32_t nActCopyRowCount, uint32_t headSize); + __aicore__ inline void CopyInMm2AToL1(LocalTensor &aL1Tensor, const RunInfo &info, uint32_t mSeqIdx, + uint32_t subMSizeAct, uint32_t nSize, uint32_t nOffset); 
+ __aicore__ inline void CopyInMm2BToL1(LocalTensor &bL1Tensor, const uint64_t valueGmBaseOffset, + uint32_t copyTotalRowCntAlign, uint32_t copyStartRowCnt, + uint32_t nActCopyRowCount, uint32_t copyStartColumnCount, + uint32_t copyColumnCount); + __aicore__ inline void LoadDataMm1A(LocalTensor &aL0Tensor, LocalTensor &aL1Tensor, uint32_t idx, + uint32_t kSplitSize, uint32_t mSize, uint32_t kSize); + __aicore__ inline void LoadDataMm1B(LocalTensor &bL0Tensor, LocalTensor &bL1Tensor, uint32_t idx, + uint32_t kSplitSize, uint32_t kSize, uint32_t nSize); +}; + +template __aicore__ inline void SFAMatmulService::InitParams(const ConstInfo &constInfo) +{ + this->constInfo = constInfo; +} + +template +__aicore__ inline void +SFAMatmulService::InitMm1GlobalTensor(GlobalTensor queryGm, GlobalTensor qRopeGm, + GlobalTensor keyGm, GlobalTensor kRopeGm, + GlobalTensor mm1ResGm) +{ + // mm1 + this->queryGm = queryGm; + this->qRopeGm = qRopeGm; + this->keyGm = keyGm; + this->kRopeGm = kRopeGm; + this->mm1ResGm = mm1ResGm; +} + +template +__aicore__ inline void +SFAMatmulService::InitMm2GlobalTensor(GlobalTensor vec1ResGm, GlobalTensor valueGm, + GlobalTensor mm2ResGm, GlobalTensor attentionOutGm) +{ + // mm2 + this->vec1ResGm = vec1ResGm; + this->valueGm = valueGm; + this->mm2ResGm = mm2ResGm; + this->attentionOutGm = attentionOutGm; +} + +template +__aicore__ inline void +SFAMatmulService::InitPageAttentionInfo(const GlobalTensor& kvMergeGm, GlobalTensor blockTableGm, + GlobalTensor topKGm, uint32_t blockSize, uint32_t maxBlockNumPerBatch) +{ + this->blockTableGm = blockTableGm; + this->topKGm = topKGm; + this->kvCacheBlockSize = blockSize; + this->maxBlockNumPerBatch = maxBlockNumPerBatch; + this->kvMergeGm_ = kvMergeGm; +} + +template __aicore__ inline void SFAMatmulService::InitBuffers(TPipe *pipe) +{ + pipe->InitBuffer(bufQPL1, L1_BLOCK_SIZE * 4); // (64K + 8K) * 4 + l1QPTensor = bufQPL1.Get(); + pipe->InitBuffer(bufKVL1, L1_BLOCK_SIZE * 3); // (64K + 8K) * 3 + l1KVTensor = 
bufKVL1.Get(); + + // L0A + pipe->InitBuffer(tmpBufL0A, L0A_PP_SIZE * 2); // 64K + aL0TensorPingPong = tmpBufL0A.Get(); + // L0B + pipe->InitBuffer(tmpBufL0B, L0B_PP_SIZE * 2); // 64K + bL0TensorPingPong = tmpBufL0B.Get(); + // L0C + pipe->InitBuffer(tmpBufL0C, L0C_PP_SIZE * 2); // 128K + cL0TensorPingPong = tmpBufL0C.Get(); +} + +template __aicore__ inline void SFAMatmulService::UpdateKey(GlobalTensor keyGm) +{ + this->keyGm = keyGm; +} + +template __aicore__ inline void SFAMatmulService::UpdateValue(GlobalTensor valueGm) +{ + this->valueGm = valueGm; +} + +template __aicore__ inline void SFAMatmulService::AllocEventID() +{ + SetFlag(L1_EVENT0); + SetFlag(L1_EVENT1); + SetFlag(L1_EVENT2); + SetFlag(L1_EVENT3); + SetFlag(L1_EVENT4); + SetFlag(L1_EVENT5); + SetFlag(L1_EVENT6); + SetFlag(L0AB_EVENT0); + SetFlag(L0AB_EVENT1); +} + +template __aicore__ inline void SFAMatmulService::FreeEventID() +{ + WaitFlag(L1_EVENT0); + WaitFlag(L1_EVENT1); + WaitFlag(L1_EVENT2); + WaitFlag(L1_EVENT3); + WaitFlag(L1_EVENT4); + WaitFlag(L1_EVENT5); + WaitFlag(L1_EVENT6); + WaitFlag(L0AB_EVENT0); + WaitFlag(L0AB_EVENT1); +} + +template +__aicore__ inline void SFAMatmulService::CopyGmToL1(LocalTensor &l1Tensor, + GlobalTensor &gmSrcTensor, uint32_t srcN, + uint32_t srcD, uint32_t srcDstride) +{ + Nd2NzParams nd2nzPara; + nd2nzPara.ndNum = 1; + nd2nzPara.nValue = srcN; + nd2nzPara.dValue = srcD; + nd2nzPara.srcDValue = srcDstride; + nd2nzPara.dstNzC0Stride = (srcN + 15) / 16 * 16; + nd2nzPara.dstNzNStride = 1; + nd2nzPara.srcNdMatrixStride = 0; + nd2nzPara.dstNzMatrixStride = 0; + DataCopy(l1Tensor, gmSrcTensor, nd2nzPara); +} + +template +__aicore__ inline void SFAMatmulService::CopyInMm1AToL1(LocalTensor &l1Tensor, const RunInfo &info, + uint32_t mSeqIdx, uint32_t mSizeAct, + uint32_t headSize, uint32_t headOffset) +{ + auto srcGm = queryGm[info.tensorAOffset + mSeqIdx * constInfo.headDim + headOffset]; + CopyGmToL1(l1Tensor, srcGm, mSizeAct, headSize, constInfo.headDim); +} + 
+template +__aicore__ inline void SFAMatmulService::CopyInMm1ARopeToL1(LocalTensor &l1Tensor, + const RunInfo &info, uint32_t mSeqIdx, + uint32_t mSizeAct) +{ + auto srcGm = qRopeGm[info.tensorARopeOffset + mSeqIdx * constInfo.headDimRope]; + CopyGmToL1(l1Tensor, srcGm, mSizeAct, constInfo.headDimRope, constInfo.headDimRope); +} + +template +__aicore__ inline void +SFAMatmulService::CopyInMm1BToL1(LocalTensor &bL1Tensor, const uint64_t keyGmBaseOffset, + uint32_t copyTotalRowCntAlign, uint32_t copyStartRowCnt, + uint32_t nActCopyRowCount, uint32_t headSize) +{ + uint64_t dStride = constInfo.headDim; + if constexpr (KV_LAYOUT_T == SFA_LAYOUT::BSND || KV_LAYOUT_T == SFA_LAYOUT::TND) { + dStride = constInfo.headDim * constInfo.kvHeadNum; + } + + uint32_t blockElementCnt = 32 / sizeof(KV_T); + + Nd2NzParams mm1Nd2NzParamsForB; + mm1Nd2NzParamsForB.ndNum = 1; + mm1Nd2NzParamsForB.nValue = nActCopyRowCount; + mm1Nd2NzParamsForB.dValue = headSize; + mm1Nd2NzParamsForB.srcDValue = dStride; + mm1Nd2NzParamsForB.dstNzC0Stride = copyTotalRowCntAlign; + mm1Nd2NzParamsForB.dstNzNStride = 1; + mm1Nd2NzParamsForB.srcNdMatrixStride = 0; + mm1Nd2NzParamsForB.dstNzMatrixStride = 0; + DataCopy(bL1Tensor[copyStartRowCnt * blockElementCnt], keyGm[keyGmBaseOffset], mm1Nd2NzParamsForB); +} + +template +__aicore__ inline void +SFAMatmulService::CopyInMm1BRopeToL1(LocalTensor &bL1Tensor, const uint64_t kRopeGmBaseOffset, + uint32_t copyTotalRowCntAlign, uint32_t copyStartRowCnt, + uint32_t nActCopyRowCount, uint32_t headSize) +{ + uint64_t dStride = constInfo.headDimRope; + if constexpr (KV_LAYOUT_T == SFA_LAYOUT::BSND || KV_LAYOUT_T == SFA_LAYOUT::TND) { + dStride = constInfo.headDimRope * constInfo.kvHeadNum; + } + + uint32_t blockElementCnt = 32 / sizeof(KV_T); + + Nd2NzParams mm1Nd2NzParamsForB; + mm1Nd2NzParamsForB.ndNum = 1; + mm1Nd2NzParamsForB.nValue = nActCopyRowCount; + mm1Nd2NzParamsForB.dValue = headSize; + mm1Nd2NzParamsForB.srcDValue = dStride; + 
mm1Nd2NzParamsForB.dstNzC0Stride = copyTotalRowCntAlign; + mm1Nd2NzParamsForB.dstNzNStride = 1; + mm1Nd2NzParamsForB.srcNdMatrixStride = 0; + mm1Nd2NzParamsForB.dstNzMatrixStride = 0; + DataCopy(bL1Tensor[copyStartRowCnt * blockElementCnt], kRopeGm[kRopeGmBaseOffset], mm1Nd2NzParamsForB); +} + +template +__aicore__ inline void SFAMatmulService::LoadDataMm1A(LocalTensor &aL0Tensor, + LocalTensor &aL1Tensor, uint32_t idx, + uint32_t kSplitSize, uint32_t mSize, uint32_t kSize) +{ + LocalTensor srcTensor = aL1Tensor[mSize * kSplitSize * idx]; + LoadData3DParamsV2 loadData3DParams; + // SetFmatrixParams + loadData3DParams.l1H = mSize / 16; // Hin=M1=8 + loadData3DParams.l1W = 16; // Win=M0 + loadData3DParams.padList[0] = 0; + loadData3DParams.padList[1] = 0; + loadData3DParams.padList[2] = 0; + loadData3DParams.padList[3] = 255; + + // SetLoadToA0Params + loadData3DParams.mExtension = mSize; // M + loadData3DParams.kExtension = kSize; // K + loadData3DParams.mStartPt = 0; + loadData3DParams.kStartPt = 0; + loadData3DParams.strideW = 1; + loadData3DParams.strideH = 1; + loadData3DParams.filterW = 1; + loadData3DParams.filterSizeW = (1 >> 8) & 255; + loadData3DParams.filterH = 1; + loadData3DParams.filterSizeH = (1 >> 8) & 255; + loadData3DParams.dilationFilterW = 1; + loadData3DParams.dilationFilterH = 1; + loadData3DParams.enTranspose = 0; + loadData3DParams.fMatrixCtrl = 0; + loadData3DParams.channelSize = kSize; // Cin=K + LoadData(aL0Tensor, srcTensor, loadData3DParams); +} + +template +__aicore__ inline void SFAMatmulService::LoadDataMm1B(LocalTensor &l0Tensor, + LocalTensor &l1Tensor, uint32_t idx, + uint32_t kSplitSize, uint32_t kSize, uint32_t nSize) +{ + LocalTensor srcTensor = l1Tensor[nSize * kSplitSize * idx]; + + LoadData2DParams loadData2DParams; + loadData2DParams.startIndex = 0; + loadData2DParams.repeatTimes = (nSize + 15) / 16 * kSize / (32 / sizeof(KV_T)); + loadData2DParams.srcStride = 1; + loadData2DParams.dstGap = 0; + loadData2DParams.ifTranspose = 
false; + LoadData(l0Tensor, srcTensor, loadData2DParams); +} + +template +__aicore__ inline void SFAMatmulService::CopyInMm2AToL1(LocalTensor &aL1Tensor, const RunInfo &info, + uint32_t mSeqIdx, uint32_t subMSizeAct, + uint32_t nSize, uint32_t nOffset) +{ + auto srcGm = vec1ResGm[(info.loop % constInfo.preLoadNum) * constInfo.mmResUbSize + + mSeqIdx * info.actualSingleProcessSInnerSizeAlign + nOffset]; + CopyGmToL1(aL1Tensor, srcGm, subMSizeAct, nSize, info.actualSingleProcessSInnerSizeAlign); +} + +template +__aicore__ inline void SFAMatmulService::CopyInMm2BToL1( + LocalTensor &bL1Tensor, const uint64_t valueGmBaseOffset, uint32_t copyTotalRowCntAlign, + uint32_t copyStartRowCnt, uint32_t nActCopyRowCount, uint32_t copyStartColumnCount, uint32_t copyColumnCount) +{ + uint64_t step = constInfo.headDim; + if constexpr (KV_LAYOUT_T == SFA_LAYOUT::BSND || KV_LAYOUT_T == SFA_LAYOUT::TND) { + step = constInfo.headDim * constInfo.kvHeadNum; + } + + uint32_t blockElementCnt = 32 / sizeof(KV_T); + + Nd2NzParams mm1Nd2NzParamsForB; + mm1Nd2NzParamsForB.ndNum = 1; + mm1Nd2NzParamsForB.nValue = nActCopyRowCount; + mm1Nd2NzParamsForB.dValue = copyColumnCount; + mm1Nd2NzParamsForB.srcDValue = step; + mm1Nd2NzParamsForB.dstNzC0Stride = copyTotalRowCntAlign; + mm1Nd2NzParamsForB.dstNzNStride = 1; + mm1Nd2NzParamsForB.srcNdMatrixStride = 0; + mm1Nd2NzParamsForB.dstNzMatrixStride = 0; + DataCopy(bL1Tensor[copyStartRowCnt * blockElementCnt], valueGm[valueGmBaseOffset + copyStartColumnCount], + mm1Nd2NzParamsForB); +} + +template +__aicore__ inline void SFAMatmulService::CalcTopKBlockInfo( + const RunInfo &info, uint32_t &curTopKIdx, uint64_t &curOffsetInSparseBlock, uint32_t curSeqIdx, uint32_t ©RowCnt, int64_t &idInTopK) +{ + uint64_t blockBegin = idInTopK * constInfo.sparseBlockSize; + uint64_t blockEnd = (blockBegin + constInfo.sparseBlockSize > info.threshold) ? 
+ info.threshold : blockBegin + constInfo.sparseBlockSize; + uint64_t blockLen = blockEnd - blockBegin; + if (curOffsetInSparseBlock + copyRowCnt < blockLen) { + curOffsetInSparseBlock += copyRowCnt; + copyRowCnt = blockLen - curOffsetInSparseBlock; + } else { + for (uint64_t topkidx = curTopKIdx + 1; topkidx < constInfo.sparseBlockCount; topkidx++) { + int64_t sparseIndices = topKGm.GetValue(info.topKBaseOffset + topkidx); + if (sparseIndices == -1) { + break; + } + + uint64_t blockBegin = sparseIndices * constInfo.sparseBlockSize; + if (blockBegin >= info.threshold) { + continue; + } + uint64_t blockEnd = (blockBegin + constInfo.sparseBlockSize > info.threshold) ? + info.threshold : blockBegin + constInfo.sparseBlockSize; + uint64_t blockLen = blockEnd - blockBegin; + curTopKIdx = topkidx; + idInTopK = sparseIndices; + curOffsetInSparseBlock = 0; + copyRowCnt = blockLen; + break; + } + } +} + +template +__aicore__ inline void SFAMatmulService::ComputeMm1(const RunInfo &info, const MSplitInfo mSplitInfo) +{ + uint32_t mSize = mSplitInfo.nBufferDealM; + uint32_t mL1Size = M_SPLIT_SIZE; + uint32_t mL1SizeAlign = SFAAlign(M_SPLIT_SIZE, 16U); + uint32_t mL1Loops = (mSize + M_SPLIT_SIZE - 1) / M_SPLIT_SIZE; + + uint32_t nSize = info.actualSingleProcessSInnerSize; + uint32_t nL1Size = N_SPLIT_SIZE; + uint32_t nL1SizeAlign = SFAAlign(N_SPLIT_SIZE, 16U); + uint32_t nL1Loops = (nSize + N_SPLIT_SIZE - 1) / N_SPLIT_SIZE; + + uint32_t kSize = 576; + uint32_t kL1Size = 288; + uint32_t kL1Loops = 2; + + uint32_t kL0Size = 96; + uint32_t kL0Loops = (kL1Size + kL0Size - 1) / kL0Size; + + LocalTensor bL1Tensor; + LocalTensor kRopeTensor; + LocalTensor kTensor; + uint32_t ka = 0, kb = 0; + + uint32_t curTopKIdx = info.curTopKIdx; + uint64_t curOffsetInSparseBlock = info.curOffsetInSparseBlock; + uint32_t copyRowCnt = 0; + int64_t idInTopK = topKGm.GetValue(info.topKBaseOffset + curTopKIdx); + + uint32_t curTopKIdxTmp = 0; + uint64_t curOffsetInSparseBlockTmp = 0; + uint32_t 
copyRowCntTmp = 0; + int64_t idInTopKTmp = 0; + + for (uint32_t nL1 = 0; nL1 < nL1Loops; nL1++) { + if (nL1 == (nL1Loops - 1)) { + nL1Size = nSize - (nL1Loops - 1) * N_SPLIT_SIZE; + nL1SizeAlign = SFAAlign(nL1Size, 16U); + } + curTopKIdxTmp = curTopKIdx; + curOffsetInSparseBlockTmp = curOffsetInSparseBlock; + copyRowCntTmp = copyRowCnt; + idInTopKTmp = idInTopK; + + for (uint32_t kL1 = 0; kL1 < kL1Loops; kL1++) { + kvL1BufIter++; + uint32_t kb = kvL1BufIter % 3; + WaitFlag(mte21KVIds[kb]); + bL1Tensor = l1KVTensor[kb * L1_BLOCK_OFFSET]; + uint32_t curSeqIdx = info.s2BatchOffset + nL1 * N_SPLIT_SIZE; + uint32_t copyFinishRowCnt = 0; + curTopKIdx = curTopKIdxTmp; + curOffsetInSparseBlock = curOffsetInSparseBlockTmp; + copyRowCnt = copyRowCntTmp; + idInTopK = idInTopKTmp; + if constexpr (TEMPLATE_MODE == V_TEMPLATE) { + if (kL1 == 0) { + Nd2NzParams nd2nzPara; + nd2nzPara.ndNum = 1; + nd2nzPara.nValue = nL1Size; + nd2nzPara.dValue = constInfo.headDim >> 1; + nd2nzPara.srcDValue = constInfo.headDim; + nd2nzPara.dstNzC0Stride = nL1SizeAlign; + nd2nzPara.dstNzNStride = 1; + nd2nzPara.srcNdMatrixStride = 0; + nd2nzPara.dstNzMatrixStride = 0; + DataCopy(bL1Tensor, + kvMergeGm_[info.loop % 4 * N_WORKSPACE_SIZE * kSize + + nL1 * N_SPLIT_SIZE * constInfo.headDim], + nd2nzPara); + nd2nzPara.dValue = constInfo.headDimRope >> 1; + nd2nzPara.srcDValue = constInfo.headDimRope; + DataCopy( + bL1Tensor[nL1SizeAlign * (constInfo.headDim >> 1)], + kvMergeGm_[info.loop % 4 * N_WORKSPACE_SIZE * kSize + N_WORKSPACE_SIZE * constInfo.headDim + + nL1 * N_SPLIT_SIZE * constInfo.headDimRope], + nd2nzPara); + } else { + LocalTensor kTmpTensor = bL1Tensor[(constInfo.headDimRope >> 1) * nL1SizeAlign]; + Nd2NzParams nd2nzPara; + nd2nzPara.ndNum = 1; + nd2nzPara.nValue = nL1Size; + nd2nzPara.dValue = constInfo.headDim >> 1; + nd2nzPara.srcDValue = constInfo.headDim; + nd2nzPara.dstNzC0Stride = nL1SizeAlign; + nd2nzPara.dstNzNStride = 1; + nd2nzPara.srcNdMatrixStride = 0; + 
nd2nzPara.dstNzMatrixStride = 0; + DataCopy(kTmpTensor, + kvMergeGm_[info.loop % 4 * N_WORKSPACE_SIZE * kSize + (constInfo.headDim >> 1) + + nL1 * N_SPLIT_SIZE * constInfo.headDim], + nd2nzPara); + nd2nzPara.dValue = constInfo.headDimRope >> 1; + nd2nzPara.srcDValue = constInfo.headDimRope; + DataCopy( + bL1Tensor, + kvMergeGm_[info.loop % 4 * N_WORKSPACE_SIZE * kSize + N_WORKSPACE_SIZE * constInfo.headDim + + (constInfo.headDimRope >> 1) + nL1 * N_SPLIT_SIZE * constInfo.headDimRope], + nd2nzPara); + } + } else { + while (copyFinishRowCnt < nL1Size) { + CalcTopKBlockInfo(info, curTopKIdx, curOffsetInSparseBlock, curSeqIdx, copyRowCnt, idInTopK); + if (copyFinishRowCnt + copyRowCnt > nL1Size) { + copyRowCnt = nL1Size - copyFinishRowCnt; + } + + if constexpr (PAGE_ATTENTION) { + Position startPos; + startPos.bIdx = info.bIdx; + startPos.n2Idx = info.n2Idx; + startPos.s2Idx = idInTopK * constInfo.sparseBlockSize + curOffsetInSparseBlock; + startPos.dIdx = kL1 * 256; + Position ropeStartPos = startPos; + ropeStartPos.dIdx = kL1 * 32; + PAShape shape; + shape.blockSize = kvCacheBlockSize; + shape.headNum = constInfo.kvHeadNum; + shape.headDim = constInfo.headDim; + shape.actHeadDim = 256; + shape.maxblockNumPerBatch = maxBlockNumPerBatch; + shape.copyRowNum = copyRowCnt; + shape.copyRowNumAlign = nL1SizeAlign; + PAShape ropeShape = shape; + ropeShape.headDim = constInfo.headDimRope; + ropeShape.actHeadDim = 32; + if (kL1 == 0) { + kTensor = bL1Tensor[copyFinishRowCnt * 16]; + DataCopyPA(kTensor, keyGm, blockTableGm, shape, startPos); + kRopeTensor = bL1Tensor[(nL1SizeAlign * (BlockAlign(constInfo.headDim) >> 1)) + + copyFinishRowCnt * 16]; + DataCopyPA(kRopeTensor, kRopeGm, blockTableGm, ropeShape, + ropeStartPos); + } else { + kRopeTensor = bL1Tensor[copyFinishRowCnt * 16]; + DataCopyPA(kRopeTensor, kRopeGm, blockTableGm, ropeShape, + ropeStartPos); + LocalTensor kTmpTensor = bL1Tensor[32 * nL1SizeAlign + copyFinishRowCnt * 16]; + DataCopyPA(kTmpTensor, keyGm, 
blockTableGm, shape, startPos); + } + } else { + uint64_t keyOffset = info.tensorBOffset; + uint64_t kRopeOffset = info.tensorBRopeOffset; + if constexpr (KV_LAYOUT_T == SFA_LAYOUT::BSND || KV_LAYOUT_T == SFA_LAYOUT::TND) { + keyOffset += (idInTopK * constInfo.sparseBlockSize + curOffsetInSparseBlock) * + constInfo.kvHeadNum * constInfo.headDim; + kRopeOffset += (idInTopK * constInfo.sparseBlockSize + curOffsetInSparseBlock) * + constInfo.kvHeadNum * constInfo.headDimRope; + } else { + keyOffset += (idInTopK * constInfo.sparseBlockSize + curOffsetInSparseBlock) * + constInfo.headDim; + kRopeOffset += (idInTopK * constInfo.sparseBlockSize + curOffsetInSparseBlock) * + constInfo.headDimRope; + } + + if (kL1 == 0) { + CopyInMm1BToL1(bL1Tensor, keyOffset, nL1SizeAlign, copyFinishRowCnt, copyRowCnt, 256); + kRopeTensor = bL1Tensor[nL1SizeAlign * (BlockAlign(constInfo.headDim) >> 1)]; + CopyInMm1BRopeToL1(kRopeTensor, kRopeOffset, nL1SizeAlign, copyFinishRowCnt, copyRowCnt, + 32); + } else { + kRopeTensor = bL1Tensor; + CopyInMm1BRopeToL1(kRopeTensor, kRopeOffset + 32, nL1SizeAlign, copyFinishRowCnt, + copyRowCnt, 32); + LocalTensor kTmpTensor = bL1Tensor[nL1SizeAlign * 32]; + CopyInMm1BToL1(kTmpTensor, keyOffset + 256, nL1SizeAlign, copyFinishRowCnt, copyRowCnt, + 256); + } + } + + copyFinishRowCnt += copyRowCnt; + curSeqIdx += copyRowCnt; + } + } + + SetFlag(mte21KVIds[kb]); + WaitFlag(mte21KVIds[kb]); + mL1Size = M_SPLIT_SIZE; + mL1SizeAlign = SFAAlign(M_SPLIT_SIZE, 16U); + for (uint32_t mL1 = 0; mL1 < mL1Loops; mL1++) { + uint32_t aL1PaddingSize = 0; + if (mL1 == (mL1Loops - 1)) { + mL1Size = mSize - (mL1Loops - 1) * M_SPLIT_SIZE; + mL1SizeAlign = SFAAlign(mL1Size, 16U); + aL1PaddingSize = (M_SPLIT_SIZE - mL1SizeAlign) * 288; + } + + uint32_t mIdx = qpL1BufIter + mL1; + ka = GetQPL1RealIdx(mIdx, kL1); + LocalTensor aL1Tensor = + l1QPTensor[ka * L1_BLOCK_OFFSET + (1 - kL1) * aL1PaddingSize]; + if (nL1 == 0) { + if (kL1 == 0) { + WaitFlag(mte21QPIds[ka]); + 
WaitFlag(mte21QPIds[ka + 1]); + CopyInMm1AToL1(aL1Tensor, info, mSplitInfo.nBufferStartM + mL1 * M_SPLIT_SIZE, mL1Size, 256, 0); + LocalTensor qRopeTensor = + aL1Tensor[mL1SizeAlign * + 256]; + CopyInMm1ARopeToL1(qRopeTensor, info, mSplitInfo.nBufferStartM + mL1 * M_SPLIT_SIZE, mL1Size); + } else { + LocalTensor qTmpTensor = aL1Tensor[mL1SizeAlign * 32]; + CopyInMm1AToL1(qTmpTensor, info, mSplitInfo.nBufferStartM + mL1 * M_SPLIT_SIZE, mL1Size, 256, + 256); + } + SetFlag(mte21QPIds[ka]); + WaitFlag(mte21QPIds[ka]); + } + + LocalTensor cL0Tensor = + cL0TensorPingPong[(cL0BufIter % 2) * + (L0C_PP_SIZE / sizeof(MM_OUT_T))]; + for (uint32_t kL0 = 0; kL0 < kL0Loops; kL0++) { + WaitFlag(Mte1MmABEventId(abL0BufIter % 2)); + LocalTensor aL0Tensor = aL0TensorPingPong[(abL0BufIter % 2) * (L0A_PP_SIZE / sizeof(KV_T))]; + LoadDataMm1A(aL0Tensor, aL1Tensor, kL0, kL0Size, mL1SizeAlign, kL0Size); + LocalTensor bL0Tensor = bL0TensorPingPong[(abL0BufIter % 2) * (L0B_PP_SIZE / sizeof(KV_T))]; + LoadDataMm1B(bL0Tensor, bL1Tensor, kL0, kL0Size, kL0Size, nL1SizeAlign); + SetFlag(Mte1MmABEventId(abL0BufIter % 2)); + WaitFlag(Mte1MmABEventId(abL0BufIter % 2)); + + MmadParams mmadParams; + mmadParams.m = mL1SizeAlign; + mmadParams.n = nL1SizeAlign; + mmadParams.k = kL0Size; + mmadParams.cmatrixInitVal = (kL1 == 0 && kL0 == 0); + mmadParams.cmatrixSource = false; + mmadParams.unitFlag = + (kL1 == 1 && kL0 == (kL0Loops - 1)) ? 
0b11 : 0b10; + Mmad(cL0Tensor, aL0Tensor, bL0Tensor, mmadParams); + + if ((mmadParams.m / 16) * (mmadParams.n / 16) < 10) { + PipeBarrier(); + } + SetFlag(Mte1MmABEventId(abL0BufIter % 2)); + abL0BufIter++; + } + + if (nL1 == (nL1Loops - 1)) { + SetFlag(mte21QPIds[ka]); + } + + if (kL1 == 1) { + FixpipeParamsV220 fixParams; + fixParams.nSize = nL1SizeAlign; + fixParams.mSize = mL1SizeAlign; + fixParams.srcStride = mL1SizeAlign; + fixParams.dstStride = info.actualSingleProcessSInnerSizeAlign; + fixParams.unitFlag = 0b11; + fixParams.ndNum = 1; + + Fixpipe(mm1ResGm[(info.loop % (constInfo.preLoadNum)) * constInfo.mmResUbSize + nL1 * N_SPLIT_SIZE + + (mSplitInfo.nBufferStartM + mL1 * M_SPLIT_SIZE) * + info.actualSingleProcessSInnerSizeAlign], + cL0Tensor, fixParams); + } + if (mL1Loops == 2) { + cL0BufIter++; + } + } + SetFlag(mte21KVIds[kb]); + } + if (mL1Loops == 1) { + cL0BufIter++; + } + } + qpL1BufIter += mL1Loops; +} + +template +__aicore__ inline void SFAMatmulService::ComputeMm2(const RunInfo &info, const MSplitInfo mSplitInfo) +{ + uint32_t mSize = mSplitInfo.nBufferDealM; + uint32_t mSizeAlign = (mSize + 16 - 1) / 16; + uint32_t mL1Loops = (mSize + M_SPLIT_SIZE - 1) / M_SPLIT_SIZE; + uint32_t mL1SizeAlign = M_SPLIT_SIZE; + uint32_t mL1Size = M_SPLIT_SIZE; + + uint32_t nSize = BlockAlign(constInfo.headDim); + uint32_t nL1Loops = (nSize + N_SPLIT_SIZE - 1) / N_SPLIT_SIZE; + uint32_t nL1SizeAlign = N_SPLIT_SIZE; + uint32_t nL1Size = N_SPLIT_SIZE; + + uint32_t kSize = info.actualSingleProcessSInnerSize; + uint32_t kL1Size = 256; + uint32_t kL1SizeAlign = SFAAlign(kL1Size, 16U); + uint32_t kL1Loops = (kSize + kL1Size - 1) / kL1Size; + uint32_t kL0Size = 128; + uint32_t kL0Loops = (kL1Size + kL0Size - 1) / kL0Size; + uint32_t kL0SizeAlign = kL0Size; + LocalTensor bL1Tensor; + LocalTensor subvTensor; + + uint32_t ka = 0, kb = 0; + uint32_t mBaseIdx = qpL1BufIter; + for (uint32_t nL1 = 0; nL1 < nL1Loops; nL1++) { + if (nL1 == (nL1Loops - 1)) { + nL1Size = nSize - 
(nL1Loops - 1) * N_SPLIT_SIZE; + nL1SizeAlign = SFAAlign(nL1Size, 16U); + } + + kL1Size = 256; + kL1SizeAlign = SFAAlign(kL1Size, 16U); + + uint32_t curTopKIdx = info.curTopKIdx; + uint64_t curOffsetInSparseBlock = info.curOffsetInSparseBlock; + uint32_t copyRowCnt = 0; + int64_t idInTopK = topKGm.GetValue(info.topKBaseOffset + curTopKIdx); + + for (uint32_t k1 = 0; k1 < kL1Loops; k1++) { + if (k1 == (kL1Loops - 1)) { + kL1Size = kSize - (kL1Loops - 1) * 256; + kL1SizeAlign = SFAAlign(kL1Size, 16U); + } + kvL1BufIter++; + uint32_t kb = kvL1BufIter % 3; + WaitFlag(mte21KVIds[kb]); + bL1Tensor = l1KVTensor[kb * L1_BLOCK_OFFSET]; + uint32_t kOffset = k1 * kL0Loops; + kL0Size = 128; + kL0Loops = (kL1Size + kL0Size - 1) / kL0Size; + kL0SizeAlign = kL0Size; + for (uint32_t kL1 = kOffset; kL1 < kL0Loops + kOffset; kL1++) { + if (kL1 == kOffset + kL0Loops - 1) { + kL0Size = kL1Size - (kL0Loops - 1) * kL0Size; + kL0SizeAlign = SFAAlign(kL0Size, 16U); + } + + uint32_t curSeqIdx = info.s2BatchOffset + (kL1 - kOffset) * 128 + k1 * 256; + uint32_t copyFinishRowCnt = 0; + if constexpr (TEMPLATE_MODE == V_TEMPLATE) { + Nd2NzParams nd2nzPara; + nd2nzPara.ndNum = 1; + nd2nzPara.nValue = kL0Size; + nd2nzPara.dValue = N_SPLIT_SIZE; // constInfo.headDim; + nd2nzPara.srcDValue = constInfo.headDim; + nd2nzPara.dstNzC0Stride = kL0SizeAlign; + nd2nzPara.dstNzNStride = 1; + nd2nzPara.srcNdMatrixStride = 0; + nd2nzPara.dstNzMatrixStride = 0; + DataCopy(bL1Tensor[(kL1 - kOffset) * 128 * N_SPLIT_SIZE], + kvMergeGm_[info.loop % 4 * N_WORKSPACE_SIZE * 576 + kL1 * 128 * constInfo.headDim + + nL1 * N_SPLIT_SIZE], + nd2nzPara); + } else { + while (copyFinishRowCnt < kL0Size) { + CalcTopKBlockInfo(info, curTopKIdx, curOffsetInSparseBlock, curSeqIdx, copyRowCnt, idInTopK); + + if (copyFinishRowCnt + copyRowCnt > kL0Size) { + copyRowCnt = kL0Size - copyFinishRowCnt; + } + + if constexpr (PAGE_ATTENTION) { + Position startPos; + startPos.bIdx = info.bIdx; + startPos.n2Idx = info.n2Idx; + 
startPos.s2Idx = idInTopK * constInfo.sparseBlockSize + curOffsetInSparseBlock; + startPos.dIdx = + nL1 * N_SPLIT_SIZE; + PAShape shape; + shape.blockSize = kvCacheBlockSize; + shape.headNum = constInfo.kvHeadNum; + shape.headDim = constInfo.headDim; + shape.actHeadDim = nL1Size; + shape.maxblockNumPerBatch = maxBlockNumPerBatch; + shape.copyRowNum = copyRowCnt; + shape.copyRowNumAlign = kL0SizeAlign; + subvTensor = bL1Tensor[(kL1 - kOffset) * 128 * N_SPLIT_SIZE + copyFinishRowCnt * 16]; + DataCopyPA(subvTensor, valueGm, blockTableGm, shape, startPos); + } else { + uint64_t valueOffset = info.tensorBOffset; + if constexpr (KV_LAYOUT_T == SFA_LAYOUT::BSND || KV_LAYOUT_T == SFA_LAYOUT::TND) { + valueOffset += (idInTopK * constInfo.sparseBlockSize + curOffsetInSparseBlock) * + constInfo.kvHeadNum * constInfo.headDim; + } else { + valueOffset += (idInTopK * constInfo.sparseBlockSize + curOffsetInSparseBlock) * + constInfo.headDim; + } + + subvTensor = bL1Tensor[(kL1 - kOffset) * 128 * N_SPLIT_SIZE]; + CopyInMm2BToL1(subvTensor, valueOffset, kL0SizeAlign, copyFinishRowCnt, copyRowCnt, + nL1 * N_SPLIT_SIZE, nL1Size); + } + copyFinishRowCnt += copyRowCnt; + curSeqIdx += copyRowCnt; + } + } + } + SetFlag(mte21KVIds[kb]); + WaitFlag(mte21KVIds[kb]); + mL1SizeAlign = M_SPLIT_SIZE; + mL1Size = M_SPLIT_SIZE; + for (uint32_t mL1 = 0; mL1 < mL1Loops; mL1++) { + if (mL1 == (mL1Loops - 1)) { + mL1Size = mSize - (mL1Loops - 1) * M_SPLIT_SIZE; + mL1SizeAlign = SFAAlign(mL1Size, 16U); + } + + uint32_t mIdx = mBaseIdx + mL1; + ka = GetQPL1RealIdx(mIdx, k1); + LocalTensor aL1Tensor = l1QPTensor[ka * L1_BLOCK_OFFSET]; + if (nL1 == 0) { + WaitFlag(mte21QPIds[ka]); + CopyInMm2AToL1(aL1Tensor, info, mSplitInfo.nBufferStartM + mL1 * M_SPLIT_SIZE, mL1Size, kL1Size, + 256 * k1); + SetFlag(mte21QPIds[ka]); + WaitFlag(mte21QPIds[ka]); + } + + LocalTensor cL0Tensor = + cL0TensorPingPong[(cL0BufIter % 2) * + (L0C_PP_SIZE / sizeof(MM_OUT_T))]; + uint32_t baseK = 128; + uint32_t baseN = 128; + 
kL0Size = 128; + kL0SizeAlign = kL0Size; + for (uint32_t kL0 = 0; kL0 < kL0Loops; kL0++) { + if (kL0 + 1 == kL0Loops) { + kL0Size = kL1Size - (kL0Loops - 1) * kL0Size; + kL0SizeAlign = SFAAlign(kL0Size, 16U); + } + WaitFlag(Mte1MmABEventId(abL0BufIter % 2)); + LocalTensor bL0Tensor = bL0TensorPingPong[(abL0BufIter % 2) * (L0B_PP_SIZE / sizeof(KV_T))]; + LoadData3DParamsV2 loadData3DParamsForB; + loadData3DParamsForB.l1H = kL0SizeAlign / 16; + loadData3DParamsForB.l1W = 16; + loadData3DParamsForB.padList[0] = 0; + loadData3DParamsForB.padList[1] = 0; + loadData3DParamsForB.padList[2] = 0; + loadData3DParamsForB.padList[3] = 255; + + loadData3DParamsForB.mExtension = kL0SizeAlign; + loadData3DParamsForB.kExtension = nL1SizeAlign; + loadData3DParamsForB.mStartPt = 0; + loadData3DParamsForB.kStartPt = 0; + loadData3DParamsForB.strideW = 1; + loadData3DParamsForB.strideH = 1; + loadData3DParamsForB.filterW = 1; + loadData3DParamsForB.filterSizeW = false; + loadData3DParamsForB.filterH = 1; + loadData3DParamsForB.filterSizeH = false; + loadData3DParamsForB.dilationFilterW = 1; + loadData3DParamsForB.dilationFilterH = 1; + loadData3DParamsForB.enTranspose = 1; + loadData3DParamsForB.fMatrixCtrl = 0; + loadData3DParamsForB.channelSize = nL1SizeAlign; + LoadData(bL0Tensor, bL1Tensor[kL0 * baseK * baseN], loadData3DParamsForB); + + LocalTensor aL0Tensor = aL0TensorPingPong[(abL0BufIter % 2) * (L0A_PP_SIZE / sizeof(KV_T))]; + LoadData3DParamsV2 loadData3DParamsForA; + loadData3DParamsForA.l1H = mL1SizeAlign / 16; + loadData3DParamsForA.l1W = 16; + loadData3DParamsForA.padList[0] = 0; + loadData3DParamsForA.padList[1] = 0; + loadData3DParamsForA.padList[2] = 0; + loadData3DParamsForA.padList[3] = 255; + + loadData3DParamsForA.mExtension = mL1SizeAlign; + loadData3DParamsForA.kExtension = kL0SizeAlign; + loadData3DParamsForA.mStartPt = 0; + loadData3DParamsForA.kStartPt = 0; + loadData3DParamsForA.strideW = 1; + loadData3DParamsForA.strideH = 1; + loadData3DParamsForA.filterW = 
1; + loadData3DParamsForA.filterSizeW = false; + loadData3DParamsForA.filterH = 1; + loadData3DParamsForA.filterSizeH = false; + loadData3DParamsForA.dilationFilterW = 1; + loadData3DParamsForA.dilationFilterH = 1; + loadData3DParamsForA.enTranspose = 0; + loadData3DParamsForA.fMatrixCtrl = 0; + loadData3DParamsForA.channelSize = kL0SizeAlign; + LoadData(aL0Tensor, aL1Tensor[kL0 * baseK * mL1SizeAlign], + loadData3DParamsForA); + SetFlag(Mte1MmABEventId(abL0BufIter % 2)); + WaitFlag(Mte1MmABEventId(abL0BufIter % 2)); + + MmadParams mmadParams; + mmadParams.m = mL1SizeAlign; + mmadParams.n = nL1SizeAlign; + mmadParams.k = kL0Size; + mmadParams.cmatrixInitVal = (kL0 == 0 && k1 == 0); + mmadParams.cmatrixSource = false; + mmadParams.unitFlag = ((k1 == (kL1Loops - 1)) && (kL0 == (kL0Loops - 1))) ? 0b11 : 0b10; + + Mmad(cL0Tensor, aL0Tensor, bL0Tensor, mmadParams); + if ((mmadParams.m / 16) * (mmadParams.n / 16) < 10) { + PipeBarrier(); + } + SetFlag(Mte1MmABEventId(abL0BufIter % 2)); + abL0BufIter++; + } + + if (nL1 == (nL1Loops - 1)) { + SetFlag(mte21QPIds[ka]); + } + + if (k1 == (kL1Loops - 1)) { + if (nL1 == 0 && mL1 == 0) { + CrossCoreWaitFlag(constInfo.syncV1NupdateC2); + } + + if (!info.isFirstSInnerLoop) { + SetAtomicAdd(); + } + // ND + FixpipeParamsV220 fixParams; + fixParams.nSize = nL1SizeAlign; + fixParams.mSize = mL1SizeAlign; + fixParams.srcStride = mL1SizeAlign; + fixParams.dstStride = nSize; + fixParams.ndNum = 1; + fixParams.unitFlag = 0b11; + + uint64_t mm2Offset = (mSplitInfo.nBufferStartM + mL1 * M_SPLIT_SIZE) * nSize + nL1 * N_SPLIT_SIZE; + Fixpipe(mm2ResGm[(info.bn2IdxInCurCore % (constInfo.preLoadNum)) * + constInfo.bmm2ResUbSize + mm2Offset], cL0Tensor, fixParams); + if (!info.isFirstSInnerLoop) { + SetAtomicNone(); + } + } + + if (mL1Loops == 2) { + cL0BufIter++; + } + } + SetFlag(mte21KVIds[kb]); + } + if (mL1Loops == 1) { + cL0BufIter++; + } + } + qpL1BufIter += mL1Loops; +} + +#endif // SPARSE_FLASH_ATTENTION_SERVICE_CUBE_MLA_H \ No newline 
at end of file diff --git a/csrc/sparse_flash_attention/op_kernel/sparse_flash_attention_service_vector_mla.h b/csrc/sparse_flash_attention/op_kernel/sparse_flash_attention_service_vector_mla.h new file mode 100644 index 00000000..79b49e0a --- /dev/null +++ b/csrc/sparse_flash_attention/op_kernel/sparse_flash_attention_service_vector_mla.h @@ -0,0 +1,1329 @@ +/** + * This program is free software, you can redistribute it and/or modify it. + * Copyright (c) 2025 Huawei Technologies Co., Ltd. + * This file is a part of the CANN Open Software. + * Licensed under CANN Open Software License Agreement Version 2.0 (the "License"). + * Please refer to the License for details. You may not use this file except in compliance with the License. + * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. + * See LICENSE in the root of the software repository for the full text of the License. + */ + +/*! 
+ * \file sparse_flash_attention_service_vector_mla.h + * \brief + */ +#ifndef SPARSE_FLASH_ATTENTION_SERVICE_VECTOR_MLA_H +#define SPARSE_FLASH_ATTENTION_SERVICE_VECTOR_MLA_H + +#include "kernel_operator.h" +#include "kernel_operator_list_tensor_intf.h" +#include "kernel_tiling/kernel_tiling.h" +#include "lib/matmul_intf.h" +#include "lib/matrix/matmul/tiling.h" +#include "sparse_flash_attention_common.h" + +using AscendC::CrossCoreSetFlag; +using AscendC::CrossCoreWaitFlag; + +template class SFAVectorService { +public: + using T = float; + using KV_T = typename SFAT::kvType; + using OUT_T = typename SFAT::outputType; + using UPDATE_T = T; + using MM1_OUT_T = float; + using MM2_OUT_T = float; + + __aicore__ inline SFAVectorService(){}; + __aicore__ inline void ProcessVec1L(const RunInfo &info); + __aicore__ inline void ProcessVec2L(const RunInfo &info); + __aicore__ inline void InitBuffers(TPipe *pipe); + __aicore__ inline void InitParams(const struct ConstInfo &constInfo, + const SparseFlashAttentionTilingDataMla *__restrict tilingData); + __aicore__ inline void InitMm2ResInt32GmGlobalTensor(GlobalTensor mm2ResInt32Gm); + __aicore__ inline void InitVec0GlobalTensor(const GlobalTensor &kvValidSizeGm, + const GlobalTensor &kvMergeGm, + const GlobalTensor &keyRopeGm, const GlobalTensor &keyGm, + const GlobalTensor &blkTableGm); + __aicore__ inline void InitVec1GlobalTensor(GlobalTensor mm1ResGm, GlobalTensor vec1ResGm, + GlobalTensor actualSeqLengthsQGm, + GlobalTensor actualSeqLengthsKVGm, GlobalTensor lseMaxFdGm, + GlobalTensor lseSumFdGm, GlobalTensor topKGm); + __aicore__ inline void InitVec2GlobalTensor(GlobalTensor accumOutGm, GlobalTensor vec2ResGm, + GlobalTensor mm2ResGm, GlobalTensor attentionOutGm); + __aicore__ inline void AllocEventID(); + __aicore__ inline void FreeEventID(); + __aicore__ inline void InitSoftmaxDefaultBuffer(); + // ================================Base Vector========================================== + __aicore__ inline void 
RowDivs(LocalTensor dstUb, LocalTensor src0Ub, LocalTensor src1Ub, + uint32_t dealRowCount, uint32_t columnCount, uint32_t actualColumnCount); + __aicore__ inline void RowMuls(LocalTensor dstUb, LocalTensor src0Ub, LocalTensor src1Ub, + uint32_t dealRowCount, uint32_t columnCount, uint32_t actualColumnCount); + // ================================Vector0========================================== + __aicore__ inline void MergeKv(const RunInfo &runInfo); + __aicore__ inline int64_t GetKeyGmOffset(int64_t realS2Idx, const RunInfo &runInfo, int64_t s2IdLimit); + __aicore__ inline int64_t GetKeyRopeGmOffset(int64_t realS2Idx, const RunInfo &runInfo, int64_t s2IdLimit); + __aicore__ inline void GetRealS2Idx(int64_t s2GmOffset, int64_t &realS2Idx, int64_t topkGmBaseOffset, + const RunInfo &runInfo); + __aicore__ inline void CopyInKv(int64_t &mte2Size, int64_t mte3Size, int64_t mergeMte3Idx, int64_t realS2Idx1, + int64_t realS2Idx2, const RunInfo &runInfo); + __aicore__ inline void CopyOutMrgeResult(int64_t mte2Size, int64_t mte3Size, int64_t s2StartGmOffset, + int64_t mergeMte3Idx, const RunInfo &runInfo); + __aicore__ inline void SetInfInBlk(const LocalTensor &mmResUb, uint32_t dealRowCount, uint32_t columnCount, + uint64_t startId, uint64_t endId); + __aicore__ inline void SetMidInf(const LocalTensor &mmResUb, uint32_t dealRowCount, uint32_t columnCount, + uint64_t startId, uint64_t endId); + __aicore__ inline void CopyInSingleKv(int64_t &mte2Size, int64_t mte3Size, int64_t mergeMte3Idx, int64_t realS2Idx, + int64_t keyBNBOffset,int64_t s2IdLimit, const RunInfo &runInfo); + // ================================Vector1========================================== + __aicore__ inline void ProcessVec1SingleBuf(const RunInfo &info, const MSplitInfo &mSplitInfo); + __aicore__ inline void DealBmm1ResBaseBlock(const RunInfo &info, const MSplitInfo &mSplitInfo, uint32_t startRow, + uint32_t dealRowCount, uint32_t columnCount, uint32_t loopId); + __aicore__ inline void 
SoftmaxFlashV2Compute(const RunInfo &info, const MSplitInfo &mSplitInfo, + LocalTensor &mmResUb, LocalTensor &softmaxTmpUb, + uint32_t startRow, uint32_t dealRowCount, uint32_t columnCount, + uint32_t actualColumnCount); + __aicore__ inline void AmlaVecCompute(const RunInfo &info, const MSplitInfo &mSplitInfo, LocalTensor &mmResUb, + LocalTensor &softmaxTmpUb, uint32_t startRow, uint32_t dealRowCount, + uint32_t columnCount, uint32_t actualColumnCount); + __aicore__ inline void ElewiseCompute(const RunInfo &info, const LocalTensor &mmResUb, uint32_t dealRowCount, + uint32_t columnCount); + __aicore__ inline void ProcessAmlaNupdate(const RunInfo &info, const MSplitInfo &mSplitInfo); + __aicore__ inline void ComputeLogSumExpAndCopyToGm(const RunInfo &info, const MSplitInfo &mSplitInfo, + LocalTensor &softmaxSumUb, LocalTensor &softmaxMaxUb); + // ================================Vecotr2========================================== + __aicore__ inline void ProcessVec2SingleBuf(const RunInfo &info, const MSplitInfo &mSplitInfo); + __aicore__ inline void DealBmm2ResBaseBlock(const RunInfo &info, const MSplitInfo &mSplitInfo, uint32_t startRow, + uint32_t dealRowCount, uint32_t columnCount, + uint32_t actualColumnCount); + __aicore__ inline void ProcessVec2Inner(const RunInfo &info, const MSplitInfo &mSplitInfo, uint32_t mStartRow, + uint32_t mDealSize); + __aicore__ inline void Bmm2DataCopyOutTrans(const RunInfo &info, LocalTensor &attenOutUb, uint32_t wsMStart, + uint32_t dealRowCount, uint32_t columnCount, + uint32_t actualColumnCount); + __aicore__ inline void Bmm2ResCopyOut(const RunInfo &info, LocalTensor &bmm2ResUb, uint32_t wsMStart, + uint32_t dealRowCount, uint32_t columnCount, uint32_t actualColumnCount); + __aicore__ inline void Bmm2CastAndCopyOut(const RunInfo &info, LocalTensor &bmm2ResUb, uint32_t wsMStart, + uint32_t dealRowCount, uint32_t columnCount, uint32_t actualColumnCount); + __aicore__ inline void Bmm2FDDataCopyOut(const RunInfo &info, LocalTensor 
&bmm2ResUb, uint32_t wsMStart, + uint32_t dealRowCount, uint32_t columnCount, uint32_t actualColumnCount); + __aicore__ inline uint64_t CalcAccumOffset(uint32_t bN2Idx, uint32_t gS1Idx); + __aicore__ inline void GetConfusionTransposeTiling(int64_t numR, int64_t numC, const uint32_t stackBufferSize, + const uint32_t typeSize, ConfusionTransposeTiling &tiling); + + static constexpr uint64_t BYTE_BLOCK = 32UL; + static constexpr uint32_t REPEAT_BLOCK_BYTE = 256U; + static constexpr uint32_t FP32_BLOCK_ELEMENT_NUM = BYTE_BLOCK / sizeof(float); + static constexpr uint32_t FP32_REPEAT_ELEMENT_NUM = REPEAT_BLOCK_BYTE / sizeof(float); + static constexpr uint32_t REPEATE_STRIDE_UP_BOUND = 256; + +private: + static constexpr bool PAGE_ATTENTION = SFAT::pageAttention; + static constexpr int TEMPLATE_MODE = SFAT::templateMode; + static constexpr bool FLASH_DECODE = SFAT::flashDecode; + static constexpr SFA_LAYOUT LAYOUT_T = SFAT::layout; + static constexpr SFA_LAYOUT KV_LAYOUT_T = SFAT::kvLayout; + + static constexpr uint64_t MERGE_CACHE_GM_BUF_NUM = 4; + static constexpr uint64_t SYNC_INPUT_BUF1_FLAG = 2; + static constexpr uint64_t SYNC_INPUT_BUF1_PONG_FLAG = 3; + static constexpr uint64_t SYNC_INPUT_BUF2_FLAG = 4; + static constexpr uint64_t SYNC_INPUT_BUF2_PONG_FLAG = 5; + static constexpr uint64_t SYNC_OUTPUT_BUF1_FLAG = 4; + static constexpr uint64_t SYNC_OUTPUT_BUF2_FLAG = 5; + static constexpr uint32_t INPUT1_BUFFER_OFFSET = ConstInfo::BUFFER_SIZE_BYTE_32K; + static constexpr uint32_t SOFTMAX_TMP_BUFFER_OFFSET = ConstInfo::BUFFER_SIZE_BYTE_1K; + static constexpr uint32_t BASE_BLOCK_MAX_ELEMENT_NUM = ConstInfo::BUFFER_SIZE_BYTE_32K / sizeof(T); // 32768/4=8096 + static constexpr uint32_t BLOCK_ELEMENT_NUM = BYTE_BLOCK / sizeof(T); // 32/4=8 + static constexpr T FLOAT_E_SCALAR = 8388608; + static constexpr T LN2 = 0.6931471805599453094172; + static constexpr T RECIP_OF_LN2 = 1 / LN2; + static constexpr T SOFTMAX_MIN_NUM = -2e38; + + const 
SparseFlashAttentionTilingDataMla *__restrict tilingData; + + uint32_t pingpongFlag = 0U; + ConstInfo constInfo = {}; + + GlobalTensor mm2ResInt32Gm; + GlobalTensor mm1ResGm; + GlobalTensor vec1ResGm; + GlobalTensor lseSumFdGm; + GlobalTensor lseMaxFdGm; + + GlobalTensor actualSeqLengthsQGm; + GlobalTensor actualSeqLengthsKVGm; + GlobalTensor vec2ResGm; + GlobalTensor mm2ResGm; + GlobalTensor accumOutGm; + GlobalTensor attentionOutGm; + GlobalTensor blkTableGm_; + + GlobalTensor kvMergeGm_; + GlobalTensor keyRopeGm_; + GlobalTensor keyGm_; + GlobalTensor topkGm_; + GlobalTensor kvValidSizeGm_; + + // ================================Local Buffer==================================== + TBuf<> inputBuff1; // 32K + TBuf<> inputBuff2; // 16K + TBuf<> outputBuff1; // 32K + TBuf<> outputBuff2; // 4K + + TBuf<> tmpBuff1; // 32K + TBuf<> v0ValidSizeBuff; // 8K + + TBuf<> nValueBuff; + TBuf<> cofValueBuff; + TBuf<> aMlaSumBuff; + TBuf<> softmaxMaxBuff; // PRE_LOAD_NUM * 2K + TBuf<> softmaxExpBuff; // PRE_LOAD_NUM * 2K + TBuf<> softmaxSumBuff; // PRE_LOAD_NUM * 2K + TBuf<> softmaxMaxDefaultBuff; // 2K + TBuf<> softmaxSumDefaultBuff; // 2K + + LocalTensor softmaxMaxDefaultUb; + LocalTensor softmaxSumDefaultUb; + + LocalTensor nValueUb; + LocalTensor cofValueUb; + LocalTensor aMlaSumUb; + LocalTensor softmaxMaxUb; + LocalTensor softmaxSumUb; + LocalTensor softmaxExpUb; + LocalTensor kvMergUb_; + LocalTensor ropeMergUb_; + LocalTensor v0ValidSizeUb_; +}; + +template __aicore__ inline void SFAVectorService::InitBuffers(TPipe *pipe) +{ + pipe->InitBuffer(inputBuff1, ConstInfo::BUFFER_SIZE_BYTE_32K * 2); + pipe->InitBuffer(inputBuff2, ConstInfo::BUFFER_SIZE_BYTE_8K * 2); + pipe->InitBuffer(outputBuff1, ConstInfo::BUFFER_SIZE_BYTE_32K); + pipe->InitBuffer(outputBuff2, ConstInfo::BUFFER_SIZE_BYTE_4K); + + pipe->InitBuffer(tmpBuff1, ConstInfo::BUFFER_SIZE_BYTE_32K); + pipe->InitBuffer(v0ValidSizeBuff, ConstInfo::BUFFER_SIZE_BYTE_8K); + + // M_MAX = 512/2vector = 256, 256 * sizeof(T) * 
N_Buffer + pipe->InitBuffer(nValueBuff, ConstInfo::BUFFER_SIZE_BYTE_1K * constInfo.preLoadNum); + pipe->InitBuffer(cofValueBuff, ConstInfo::BUFFER_SIZE_BYTE_1K * constInfo.preLoadNum); + pipe->InitBuffer(aMlaSumBuff, ConstInfo::BUFFER_SIZE_BYTE_1K * constInfo.preLoadNum); + + pipe->InitBuffer(softmaxMaxBuff, ConstInfo::BUFFER_SIZE_BYTE_1K * constInfo.preLoadNum); + pipe->InitBuffer(softmaxExpBuff, ConstInfo::BUFFER_SIZE_BYTE_1K * constInfo.preLoadNum); + pipe->InitBuffer(softmaxSumBuff, ConstInfo::BUFFER_SIZE_BYTE_1K * constInfo.preLoadNum); + + pipe->InitBuffer(softmaxMaxDefaultBuff, ConstInfo::BUFFER_SIZE_BYTE_1K); + pipe->InitBuffer(softmaxSumDefaultBuff, ConstInfo::BUFFER_SIZE_BYTE_1K); + + nValueUb = nValueBuff.Get(); + cofValueUb = cofValueBuff.Get(); + aMlaSumUb = aMlaSumBuff.Get(); + + softmaxMaxUb = softmaxMaxBuff.Get(); + softmaxSumUb = softmaxSumBuff.Get(); + softmaxExpUb = softmaxExpBuff.Get(); + + softmaxMaxDefaultUb = softmaxMaxDefaultBuff.Get(); + softmaxSumDefaultUb = softmaxSumDefaultBuff.Get(); + + kvMergUb_ = inputBuff1.Get(); + ropeMergUb_ = inputBuff2.Get(); + + v0ValidSizeUb_ = v0ValidSizeBuff.Get(); +} + +template +__aicore__ inline void +SFAVectorService::InitParams(const struct ConstInfo &constInfo, + const SparseFlashAttentionTilingDataMla *__restrict tilingData) +{ + this->constInfo = constInfo; + this->tilingData = tilingData; +} + +template +__aicore__ inline void +SFAVectorService::InitMm2ResInt32GmGlobalTensor(GlobalTensor mm2ResInt32Gm) +{ + this->mm2ResInt32Gm = mm2ResInt32Gm; +} + +template +__aicore__ inline void SFAVectorService::InitVec0GlobalTensor( + const GlobalTensor &kvValidSizeGm, const GlobalTensor &kvMergeGm, + const GlobalTensor &keyRopeGm, const GlobalTensor &keyGm, const GlobalTensor &blkTableGm) +{ + this->kvMergeGm_ = kvMergeGm; + this->keyRopeGm_ = keyRopeGm; + this->keyGm_ = keyGm; + this->blkTableGm_ = blkTableGm; + this->kvValidSizeGm_ = kvValidSizeGm; +} + +template +__aicore__ inline void 
SFAVectorService::InitVec1GlobalTensor( + GlobalTensor mm1ResGm, GlobalTensor vec1ResGm, + GlobalTensor actualSeqLengthsQGm, GlobalTensor actualSeqLengthsKVGm, GlobalTensor lseMaxFdGm, + GlobalTensor lseSumFdGm, GlobalTensor topKGm) +{ + this->mm1ResGm = mm1ResGm; + this->vec1ResGm = vec1ResGm; + this->actualSeqLengthsQGm = actualSeqLengthsQGm; + this->actualSeqLengthsKVGm = actualSeqLengthsKVGm; + this->lseMaxFdGm = lseMaxFdGm; + this->lseSumFdGm = lseSumFdGm; + this->topkGm_ = topKGm; +} + +template +__aicore__ inline void SFAVectorService::InitVec2GlobalTensor(GlobalTensor accumOutGm, + GlobalTensor vec2ResGm, + GlobalTensor mm2ResGm, + GlobalTensor attentionOutGm) +{ + this->accumOutGm = accumOutGm; + this->vec2ResGm = vec2ResGm; + this->mm2ResGm = mm2ResGm; + this->attentionOutGm = attentionOutGm; +} + +template __aicore__ inline void SFAVectorService::AllocEventID() +{ + SetFlag(SYNC_INPUT_BUF1_FLAG); + SetFlag(SYNC_INPUT_BUF1_PONG_FLAG); + SetFlag(SYNC_INPUT_BUF2_FLAG); + SetFlag(SYNC_INPUT_BUF2_PONG_FLAG); + SetFlag(SYNC_OUTPUT_BUF1_FLAG); + SetFlag(SYNC_OUTPUT_BUF2_FLAG); +} + +template __aicore__ inline void SFAVectorService::FreeEventID() +{ + WaitFlag(SYNC_INPUT_BUF1_FLAG); + WaitFlag(SYNC_INPUT_BUF1_PONG_FLAG); + WaitFlag(SYNC_INPUT_BUF2_FLAG); + WaitFlag(SYNC_INPUT_BUF2_PONG_FLAG); + WaitFlag(SYNC_OUTPUT_BUF1_FLAG); + WaitFlag(SYNC_OUTPUT_BUF2_FLAG); +} + +template __aicore__ inline void SFAVectorService::InitSoftmaxDefaultBuffer() +{ + Duplicate(softmaxMaxDefaultUb, SOFTMAX_MIN_NUM, SOFTMAX_TMP_BUFFER_OFFSET / sizeof(T)); + Duplicate(softmaxSumDefaultUb, ConstInfo::FLOAT_ZERO, SOFTMAX_TMP_BUFFER_OFFSET / sizeof(T)); +} + +template +__aicore__ inline void SFAVectorService::ComputeLogSumExpAndCopyToGm(const RunInfo &info, + const MSplitInfo &mSplitInfo, + LocalTensor &softmaxSumUb, + LocalTensor &softmaxMaxUb) +{ + if (mSplitInfo.vecDealM == 0) { + return; + } + uint64_t baseOffset = mSplitInfo.nBufferStartM / 2; + size_t size = mSplitInfo.vecDealM * 
FP32_BLOCK_ELEMENT_NUM; + uint64_t accumTmpOutNum = CalcAccumOffset(info.bIdx, info.gS1Idx); + uint64_t offset = (accumTmpOutNum * constInfo.kvHeadNum * constInfo.mBaseSize + + info.tndCoreStartKVSplitPos * constInfo.kvHeadNum * constInfo.mBaseSize + + mSplitInfo.nBufferStartM + mSplitInfo.vecStartM) * + FP32_BLOCK_ELEMENT_NUM; + if (info.actualSingleProcessSInnerSize != 0) { + LocalTensor tmp = outputBuff2.Get(); + WaitFlag(SYNC_OUTPUT_BUF2_FLAG); + Brcb(tmp, softmaxSumUb[baseOffset], (mSplitInfo.vecDealM + 7) / 8, {1, 8}); + SetFlag(SYNC_OUTPUT_BUF2_FLAG); + WaitFlag(SYNC_OUTPUT_BUF2_FLAG); + DataCopy(lseSumFdGm[offset], tmp, size); + SetFlag(SYNC_OUTPUT_BUF2_FLAG); + + tmp = outputBuff2.Get(); + WaitFlag(SYNC_OUTPUT_BUF2_FLAG); + Brcb(tmp, softmaxMaxUb[baseOffset], (mSplitInfo.vecDealM + 7) / 8, {1, 8}); + SetFlag(SYNC_OUTPUT_BUF2_FLAG); + WaitFlag(SYNC_OUTPUT_BUF2_FLAG); + DataCopy(lseMaxFdGm[offset], tmp, size); + SetFlag(SYNC_OUTPUT_BUF2_FLAG); + } else { + matmul::InitOutput(lseSumFdGm[offset], size, ConstInfo::FLOAT_ZERO); + matmul::InitOutput(lseMaxFdGm[offset], size, SOFTMAX_MIN_NUM); + } +} + +template +__aicore__ inline void SFAVectorService::ElewiseCompute(const RunInfo &info, + const LocalTensor &mmResUb, + uint32_t dealRowCount, uint32_t columnCount) +{ + Muls(mmResUb, mmResUb, static_cast(tilingData->baseParams.scaleValue), dealRowCount * columnCount); + if constexpr (TEMPLATE_MODE == V_TEMPLATE) { + uint64_t s2ValidSizeFirstPart = v0ValidSizeUb_.GetValue(128 + info.loop % MERGE_CACHE_GM_BUF_NUM); + uint64_t s2ValidSizeSecondPart = v0ValidSizeUb_.GetValue(256 + info.loop % MERGE_CACHE_GM_BUF_NUM); + + int64_t s2ProcessSize = info.actualSingleProcessSInnerSize; + int64_t s2Pair = CeilDiv(s2ProcessSize, 2L * constInfo.sparseBlockSize); + int64_t s2Mid = CeilDiv(s2Pair, 2L) * 2 * constInfo.sparseBlockSize; + if (s2Mid > s2ProcessSize) { + s2Mid = s2ProcessSize; + } + if (unlikely(s2ValidSizeFirstPart < s2Mid)) { + int64_t s2StartCeilAlign = 
CeilAlign(s2ValidSizeFirstPart, 8); + int64_t s2MidFloorAlign = s2Mid / 8 * 8; + SetInfInBlk(mmResUb, dealRowCount, columnCount, s2ValidSizeFirstPart, + s2StartCeilAlign >= s2Mid ? s2Mid : s2StartCeilAlign); + SetMidInf(mmResUb, dealRowCount, columnCount, s2StartCeilAlign, s2MidFloorAlign); + SetInfInBlk(mmResUb, dealRowCount, columnCount, + s2StartCeilAlign <= s2MidFloorAlign ? s2MidFloorAlign : s2StartCeilAlign, s2Mid); + } + if (unlikely(s2ValidSizeSecondPart < s2ProcessSize - s2Mid)) { + int64_t s2StartCeilAlign = CeilAlign(s2Mid + s2ValidSizeSecondPart, 8); + int64_t s2EndFloorAlign = s2ProcessSize / 8 * 8; + SetInfInBlk(mmResUb, dealRowCount, columnCount, s2Mid + s2ValidSizeSecondPart, + s2StartCeilAlign >= s2ProcessSize ? s2ProcessSize : s2StartCeilAlign); + SetMidInf(mmResUb, dealRowCount, columnCount, s2StartCeilAlign, s2EndFloorAlign); + SetInfInBlk(mmResUb, dealRowCount, columnCount, + s2StartCeilAlign <= s2EndFloorAlign ? s2EndFloorAlign : s2StartCeilAlign, s2ProcessSize); + } + } +} + +template +__aicore__ inline void SFAVectorService::SetInfInBlk(const LocalTensor &mmResUb, + uint32_t dealRowCount, uint32_t columnCount, + uint64_t startId, uint64_t endId) +{ + if (startId >= endId) { + return; + } + + uint64_t startFloorAlignSize = startId / BLOCK_ELEMENT_NUM * BLOCK_ELEMENT_NUM; + uint64_t notComputePreMaskOneBlk = (1 << (startId - startFloorAlignSize)) - 1; + uint64_t notComputePostMaskOneBlk = ~((1 << (endId - startFloorAlignSize)) - 1); + uint64_t notComputeMaskOneBlk = notComputePreMaskOneBlk ^ notComputePostMaskOneBlk; + + uint64_t maskOneBlk = ~notComputeMaskOneBlk; + uint64_t mask[1] = {maskOneBlk}; + for (int i = 1; i < 8; i++) { + mask[0] = mask[0] | (maskOneBlk << (i * 8)); + } + for (uint64_t rowId = 0; rowId < dealRowCount; rowId += 8) { + Duplicate(mmResUb[rowId * columnCount + startFloorAlignSize], SOFTMAX_MIN_NUM, mask, + 1, CeilDiv(columnCount, 8), 0); + } +} + +template +__aicore__ inline void SFAVectorService::SetMidInf(const 
LocalTensor &mmResUb, + uint32_t dealRowCount, uint32_t columnCount, + uint64_t startId, uint64_t endId) +{ + if (startId >= endId) { + return; + } + for (uint64_t rowId = 0; rowId < dealRowCount; rowId++) { + Duplicate(mmResUb[rowId * columnCount + startId], SOFTMAX_MIN_NUM, endId - startId); + } +} + +template +__aicore__ inline void SFAVectorService::SoftmaxFlashV2Compute( + const RunInfo &info, const MSplitInfo &mSplitInfo, LocalTensor &mmResUb, LocalTensor &softmaxTmpUb, + uint32_t startRow, uint32_t dealRowCount, uint32_t columnCount, uint32_t actualColumnCount) +{ + LocalTensor inSumTensor; + LocalTensor inMaxTensor; + uint32_t baseOffset = mSplitInfo.nBufferStartM / 2 + startRow; + uint32_t outIdx = info.loop % (constInfo.preLoadNum); + uint32_t softmaxOutOffset = outIdx * SOFTMAX_TMP_BUFFER_OFFSET / sizeof(T) + baseOffset; + if (info.isFirstSInnerLoop) { + inMaxTensor = softmaxMaxDefaultUb; + inSumTensor = softmaxSumDefaultUb; + } else { + uint32_t inIdx = (info.loop - 1) % (constInfo.preLoadNum); + inMaxTensor = softmaxMaxUb[inIdx * SOFTMAX_TMP_BUFFER_OFFSET / sizeof(T) + baseOffset]; + inSumTensor = softmaxSumUb[inIdx * SOFTMAX_TMP_BUFFER_OFFSET / sizeof(T) + baseOffset]; + } + if (actualColumnCount !=0) { + SoftMaxShapeInfo srcShape{dealRowCount, columnCount, dealRowCount, actualColumnCount}; + SoftMaxTiling newTiling = + SoftMaxFlashV2TilingFunc(srcShape, sizeof(T), sizeof(T), softmaxTmpUb.GetSize(), true, false); + SoftmaxFlashV2( + mmResUb, softmaxSumUb[softmaxOutOffset], softmaxMaxUb[softmaxOutOffset], mmResUb, + softmaxExpUb[softmaxOutOffset], inSumTensor, inMaxTensor, softmaxTmpUb, newTiling, srcShape); + } else { + uint32_t dealRowCountAlign = SFAAlign(dealRowCount, FP32_BLOCK_ELEMENT_NUM); + DataCopy(softmaxSumUb[softmaxOutOffset], inSumTensor, dealRowCountAlign); + pipe_barrier(PIPE_V); + DataCopy(softmaxMaxUb[softmaxOutOffset], inMaxTensor, dealRowCountAlign); + } +} + +template +__aicore__ inline void SFAVectorService::AmlaVecCompute( + const 
RunInfo &info, const MSplitInfo &mSplitInfo, LocalTensor &mmResUb, LocalTensor &softmaxTmpUb, + uint32_t startRow, uint32_t dealRowCount, uint32_t columnCount, uint32_t actualColumnCount) +{ + uint32_t baseOffset = mSplitInfo.nBufferStartM / 2 + startRow; + uint32_t calCount = dealRowCount; + uint32_t outIdx = info.loop % (constInfo.preLoadNum); + uint32_t softmaxOutOffset = outIdx * SOFTMAX_TMP_BUFFER_OFFSET / sizeof(T) + baseOffset; + // compute n(i) + LocalTensor nTmp = softmaxTmpUb.template ReinterpretCast(); + LocalTensor nUpdateTmp = nTmp[SOFTMAX_TMP_BUFFER_OFFSET / sizeof(T)]; + Muls(nTmp, softmaxMaxUb[softmaxOutOffset], ((T)(-1.0)) * RECIP_OF_LN2, calCount); + + pipe_barrier(PIPE_V); + Cast(nTmp, nTmp, RoundMode::CAST_ROUND, calCount); + pipe_barrier(PIPE_V); + + uint32_t prOutIdx = (info.loop - 1) % (constInfo.preLoadNum); + uint32_t PreSoftmaxOutOffset = prOutIdx * SOFTMAX_TMP_BUFFER_OFFSET / sizeof(T) + baseOffset; + // n(i) - n(i-1) + if (info.isFirstSInnerLoop) { + Duplicate(nUpdateTmp, ConstInfo::FLOAT_ZERO, calCount); // n1=n0 + } else { + Sub(nUpdateTmp, nTmp, nValueUb[PreSoftmaxOutOffset], calCount); + } + pipe_barrier(PIPE_V); + // update n(i), DataCopy not support when calCount is not align 32B, so use Adds + Adds(nValueUb[softmaxOutOffset], nTmp, ConstInfo::FLOAT_ZERO, calCount); + pipe_barrier(PIPE_V); + + // update softmax res + LocalTensor nUpdateTmp2 = nTmp[2 * SOFTMAX_TMP_BUFFER_OFFSET / sizeof(T)]; + LocalTensor nTmp_KvT = nTmp[3 * SOFTMAX_TMP_BUFFER_OFFSET / sizeof(T)].template ReinterpretCast(); + LocalTensor tmpCofUb = nTmp[4 * SOFTMAX_TMP_BUFFER_OFFSET / sizeof(T)]; + LocalTensor epsUb = nTmp[5 * SOFTMAX_TMP_BUFFER_OFFSET / sizeof(T)]; + Muls(nUpdateTmp2, softmaxMaxUb[softmaxOutOffset], RECIP_OF_LN2, calCount); + pipe_barrier(PIPE_V); + Add(nTmp, nUpdateTmp2, nTmp, calCount); + pipe_barrier(PIPE_V); + Muls(nTmp, nTmp, LN2, calCount); + pipe_barrier(PIPE_V); + Exp(nTmp, nTmp, calCount); + pipe_barrier(PIPE_V); + Cast(nTmp_KvT, nTmp, 
RoundMode::CAST_ROUND, calCount); // fp32->fp16/bf16 + pipe_barrier(PIPE_V); + Cast(nUpdateTmp2, nTmp_KvT, RoundMode::CAST_NONE, calCount); // fp16/bf16->fp32 + pipe_barrier(PIPE_V); + if (info.s2Idx + 1 == info.curSInnerLoopTimes) { + Mul(aMlaSumUb[softmaxOutOffset], softmaxSumUb[softmaxOutOffset], nUpdateTmp2, calCount); + } + if (actualColumnCount == 0) { + WaitFlag(SYNC_OUTPUT_BUF2_FLAG); + SetFlag(SYNC_OUTPUT_BUF2_FLAG); + return; + } + LocalTensor nTmp3 = nTmp[6 * SOFTMAX_TMP_BUFFER_OFFSET / sizeof(T)]; + Brcb(nTmp3, nUpdateTmp2, (dealRowCount + 7) / 8, {1, 8}); + pipe_barrier(PIPE_V); + RowMuls(mmResUb, mmResUb, nTmp3, dealRowCount, columnCount, actualColumnCount); + + Div(tmpCofUb, nTmp, nUpdateTmp2, calCount); // cof(i)=tmpS32/tmpS16 + if (info.isFirstSInnerLoop) { + Duplicate(cofValueUb[softmaxOutOffset], (T)1.0, calCount); // cof_0=1 + pipe_barrier(PIPE_V); + Div(epsUb, cofValueUb[softmaxOutOffset], tmpCofUb, calCount); // 1 / cof(i) + } else { + pipe_barrier(PIPE_V); + Div(epsUb, cofValueUb[PreSoftmaxOutOffset], tmpCofUb, calCount); // cof(i - 1) / cof(i) + } + pipe_barrier(PIPE_V); + + Adds(cofValueUb[softmaxOutOffset], tmpCofUb, ConstInfo::FLOAT_ZERO, calCount); // store cof(i) + Adds(epsUb, epsUb, (T)(-1.0), calCount); // cof(i - 1) / cof(i) - 1 + pipe_barrier(PIPE_V); + Muls(epsUb, epsUb, (T)1.5, calCount); // (cof(i - 1) - cof(i)) / cof(i) * 1.5 + + Maxs(nUpdateTmp, nUpdateTmp, (T)(-30.0), calCount); // N = max(n(i) - n(i-1), -30) + pipe_barrier(PIPE_V); + Adds(epsUb, epsUb, (T)(0.000001), calCount); + pipe_barrier(PIPE_V); + Add(nUpdateTmp, nUpdateTmp, epsUb, calCount); + pipe_barrier(PIPE_V); + Muls(nUpdateTmp, nUpdateTmp, FLOAT_E_SCALAR, calCount); // N = N * pow(2, 23) + pipe_barrier(PIPE_V); + + // nUpdate int32 out + LocalTensor tmQue = outputBuff2.Get(); + WaitFlag(SYNC_OUTPUT_BUF2_FLAG); + LocalTensor nInt32Out = tmQue[startRow]; + + Cast(nInt32Out, nUpdateTmp, RoundMode::CAST_ROUND, dealRowCount); + pipe_barrier(PIPE_V); + + 
SetFlag(SYNC_OUTPUT_BUF2_FLAG); +} + +template +__aicore__ inline void SFAVectorService::DealBmm1ResBaseBlock( + const RunInfo &info, const MSplitInfo &mSplitInfo, uint32_t startRow, uint32_t dealRowCount, + uint32_t columnCount, uint32_t loopId) +{ + uint32_t computeSize = dealRowCount * columnCount; + uint64_t inOutGmOffset = (info.loop % constInfo.preLoadNum) * constInfo.mmResUbSize + + (mSplitInfo.nBufferStartM + mSplitInfo.vecStartM + startRow) * columnCount; + LocalTensor mmResUb = inputBuff1.Get(); + mmResUb = mmResUb[pingpongFlag * INPUT1_BUFFER_OFFSET / sizeof(MM1_OUT_T)]; + WaitFlag(SYNC_INPUT_BUF1_FLAG + pingpongFlag); + + DataCopy(mmResUb, mm1ResGm[inOutGmOffset], computeSize); + if constexpr (TEMPLATE_MODE == V_TEMPLATE) { + if (loopId == 0) { + WaitFlag(0); + } + } + SetFlag(SYNC_INPUT_BUF1_FLAG); + WaitFlag(SYNC_INPUT_BUF1_FLAG); + + ElewiseCompute(info, mmResUb, dealRowCount, columnCount); + + pipe_barrier(PIPE_V); + LocalTensor tmpAFloorUb = tmpBuff1.Get(); + LocalTensor softmaxTmpUb = tmpAFloorUb.template ReinterpretCast(); + + SoftmaxFlashV2Compute(info, mSplitInfo, mmResUb, softmaxTmpUb, startRow, dealRowCount, columnCount, + info.actualSingleProcessSInnerSize); + + pipe_barrier(PIPE_V); + AmlaVecCompute(info, mSplitInfo, mmResUb, softmaxTmpUb, startRow, dealRowCount, columnCount, + info.actualSingleProcessSInnerSize); + + pipe_barrier(PIPE_V); + LocalTensor tmpMMResCastTensor = outputBuff1.Get(); + WaitFlag(SYNC_OUTPUT_BUF1_FLAG); + + Cast(tmpMMResCastTensor, mmResUb, AscendC::RoundMode::CAST_ROUND, computeSize); + SetFlag(SYNC_INPUT_BUF1_FLAG + pingpongFlag); + + SetFlag(SYNC_OUTPUT_BUF1_FLAG); + WaitFlag(SYNC_OUTPUT_BUF1_FLAG); + DataCopy(vec1ResGm[inOutGmOffset], tmpMMResCastTensor, computeSize); + SetFlag(SYNC_OUTPUT_BUF1_FLAG); +} + +template +__aicore__ inline void SFAVectorService::ProcessAmlaNupdate(const RunInfo &info, const MSplitInfo &mSplitInfo) +{ + if (mSplitInfo.vecDealM == 0) { + return; + } + if (info.isFirstSInnerLoop) { + 
return; + } + + LocalTensor nUpdateTensor = outputBuff2.Get(); // shape:1/2*s1*g + WaitFlag(SYNC_OUTPUT_BUF2_FLAG); + SetFlag(SYNC_OUTPUT_BUF2_FLAG); + WaitFlag(SYNC_OUTPUT_BUF2_FLAG); + + constexpr uint32_t dGroupSize = 128U; + constexpr uint32_t mSplitSize = 64U; + constexpr uint32_t ONE_BLOCK_SIZE = 32U; // 32B + + uint32_t subMSize = SFAAlign(mSplitInfo.vecDealM, 16U); + uint16_t elementPerBlock = ONE_BLOCK_SIZE / sizeof(int32_t); + uint32_t loopCount = (subMSize + mSplitSize - 1) / mSplitSize; + uint32_t tailSplitSize = subMSize - (loopCount - 1) * mSplitSize; + + for (uint32_t loop = 0, processMSize = mSplitSize; loop < loopCount; loop++) { + if (loop == (loopCount - 1)) { + processMSize = tailSplitSize; + } + LocalTensor tmpQue = outputBuff1.Get(); + + WaitFlag(SYNC_OUTPUT_BUF1_FLAG); + for (uint32_t i = 0; i < dGroupSize / elementPerBlock; i++) { + Brcb(tmpQue[i * elementPerBlock], + nUpdateTensor[loop * mSplitSize], + static_cast((processMSize + elementPerBlock - 1) / elementPerBlock), + {static_cast(dGroupSize / elementPerBlock), + static_cast(dGroupSize)}); + } + + SetFlag(SYNC_OUTPUT_BUF1_FLAG); + WaitFlag(SYNC_OUTPUT_BUF1_FLAG); + + uint64_t baseoffset = (info.bn2IdxInCurCore % constInfo.preLoadNum) * constInfo.bmm2ResUbSize + + (mSplitInfo.nBufferStartM + mSplitInfo.vecStartM + loop * mSplitSize) * constInfo.headDim; + + SetAtomicAdd(); + DataCopyParams dataCopyParams; + dataCopyParams.blockCount = static_cast(processMSize); + dataCopyParams.blockLen = dGroupSize * sizeof(int32_t) / ONE_BLOCK_SIZE; + dataCopyParams.srcStride = 0; + dataCopyParams.dstStride = static_cast((constInfo.headDim - dGroupSize) * + sizeof(int32_t) / ONE_BLOCK_SIZE); + for (uint32_t i = 0; i < constInfo.headDim / dGroupSize; i++) { + DataCopy(mm2ResInt32Gm[baseoffset + i * dGroupSize] ,tmpQue, dataCopyParams); + } + SetAtomicNone(); + SetFlag(SYNC_OUTPUT_BUF1_FLAG); + } + SetFlag(SYNC_OUTPUT_BUF2_FLAG); +} + +template +__aicore__ inline void 
SFAVectorService::ProcessVec1SingleBuf(const RunInfo &info, + const MSplitInfo &mSplitInfo) +{ + if (mSplitInfo.vecDealM == 0) { + return; + } + uint32_t mSplitSize = info.actualSingleProcessSInnerSize == 0 ? + 16 : BASE_BLOCK_MAX_ELEMENT_NUM / info.actualSingleProcessSInnerSizeAlign; + mSplitSize = mSplitSize / 8 * 8; + + if (mSplitSize > mSplitInfo.vecDealM) { + mSplitSize = mSplitInfo.vecDealM; + } + uint32_t loopCount = (mSplitInfo.vecDealM + mSplitSize - 1) / mSplitSize; + uint32_t tailSplitSize = mSplitInfo.vecDealM - (loopCount - 1) * mSplitSize; + + if constexpr (TEMPLATE_MODE == V_TEMPLATE) { + DataCopyExtParams dataCopyParams; + dataCopyParams.blockCount = 1; + dataCopyParams.blockLen = 256 * sizeof(int32_t); + dataCopyParams.srcStride = 0; + dataCopyParams.dstStride = 0; + DataCopyPadExtParams padParams; + DataCopyPad(v0ValidSizeUb_[128], kvValidSizeGm_[info.loop % MERGE_CACHE_GM_BUF_NUM * (128 * 2)], + dataCopyParams, padParams); + SetFlag(0); + if (unlikely(loopCount == 0)) { + WaitFlag(0); + } + } + for (uint32_t i = 0, dealSize = mSplitSize; i < loopCount; i++) { + if (i == (loopCount - 1)) { + dealSize = tailSplitSize; + } + DealBmm1ResBaseBlock(info, mSplitInfo, i * mSplitSize, dealSize, info.actualSingleProcessSInnerSizeAlign, i); + pingpongFlag ^= 1; + } +} + +template +__aicore__ inline void SFAVectorService::GetRealS2Idx(int64_t s2GmOffset, int64_t &realS2Idx, + int64_t topkGmBaseOffset, const RunInfo &runInfo) +{ + int64_t topkGmIdx = (s2GmOffset + runInfo.s2Idx * constInfo.s2BaseSize) / constInfo.sparseBlockSize; + if (unlikely(topkGmIdx >= constInfo.sparseBlockCount)) { + realS2Idx = -1; + return; + } + realS2Idx = topkGm_.GetValue(topkGmBaseOffset + topkGmIdx) * static_cast(constInfo.sparseBlockSize) + + static_cast((s2GmOffset + runInfo.s2Idx * constInfo.s2BaseSize) % constInfo.sparseBlockSize); +} + +template +__aicore__ inline int64_t SFAVectorService::GetKeyGmOffset(int64_t realS2Idx, + const RunInfo &runInfo, int64_t s2IdLimit) +{ + if 
(realS2Idx < 0 || realS2Idx >= s2IdLimit) { + return -1; + } + int64_t realKeyGmOffset = 0; + if constexpr (PAGE_ATTENTION) { + int64_t blkTableIdx = realS2Idx / constInfo.kvCacheBlockSize; + int64_t blkTableOffset = realS2Idx % constInfo.kvCacheBlockSize; + realKeyGmOffset = blkTableGm_.GetValue(runInfo.bIdx * constInfo.maxBlockNumPerBatch + blkTableIdx) * + static_cast(constInfo.kvCacheBlockSize) * + static_cast(constInfo.kvHeadNum) + + blkTableOffset; + } else { + realKeyGmOffset = (runInfo.tensorBOffset + + realS2Idx * constInfo.kvHeadNum * constInfo.headDim) / + constInfo.headDim; + } + return realKeyGmOffset; +} + +template +__aicore__ inline int64_t SFAVectorService::GetKeyRopeGmOffset(int64_t realS2Idx, + const RunInfo &runInfo, int64_t s2IdLimit) +{ + if (realS2Idx < 0 || realS2Idx >= s2IdLimit) { + return -1; + } + int64_t realKeyRopeGmOffset = 0; + realKeyRopeGmOffset = (runInfo.tensorBRopeOffset + + realS2Idx * constInfo.kvHeadNum * constInfo.headDimRope) / + constInfo.headDimRope; + return realKeyRopeGmOffset; +} + +template +__aicore__ inline void +SFAVectorService::CopyInSingleKv(int64_t &mte2Size, int64_t mte3Size, int64_t mergeMte3Idx, int64_t realS2Idx, + int64_t keyBNBOffset,int64_t s2IdLimit, const RunInfo &runInfo) +{ + if (keyBNBOffset < 0) { + return; + } + int64_t validS2Count = + (realS2Idx + constInfo.sparseBlockSize > s2IdLimit ? 
s2IdLimit - realS2Idx : constInfo.sparseBlockSize); + DataCopyExtParams intriParams; + intriParams.blockLen = validS2Count * constInfo.headDim * sizeof(KV_T); + intriParams.blockCount = 1; + intriParams.dstStride = 0; + intriParams.srcStride = 0; + DataCopyPadExtParams padParams; + DataCopyPad(kvMergUb_[mergeMte3Idx % 2 * 32 * 512 + (mte2Size - mte3Size) * constInfo.headDim], + keyGm_[keyBNBOffset * constInfo.headDim], intriParams, padParams); + intriParams.blockLen = validS2Count * constInfo.headDimRope * sizeof(KV_T); + + DataCopyPad(ropeMergUb_[mergeMte3Idx % 2 * 32 * 64 + (mte2Size - mte3Size) * constInfo.headDimRope], + keyRopeGm_[keyBNBOffset * constInfo.headDimRope], intriParams, padParams); + mte2Size += validS2Count; +} + +template +__aicore__ inline void SFAVectorService::CopyInKv(int64_t &mte2Size, int64_t mte3Size, int64_t mergeMte3Idx, + int64_t realS2Idx1, int64_t realS2Idx2, const RunInfo &runInfo) +{ + int64_t s2IdLimit = runInfo.curActualSeqLenOri; + if (constInfo.sparseMode == 3) { + s2IdLimit = runInfo.curActualSeqLenOri - runInfo.actS1Size + runInfo.gS1Idx / constInfo.gSize + 1; + } + + int64_t keyOffset1 = GetKeyGmOffset(realS2Idx1, runInfo, s2IdLimit); + int64_t keyOffset2 = GetKeyGmOffset(realS2Idx2, runInfo, s2IdLimit); + if (unlikely(keyOffset1 < 0 && keyOffset2 < 0)) { + return; + } + + int64_t keySrcStride = 0; + int64_t keyRopeSrcStride = 0; + if constexpr (PAGE_ATTENTION) { + int64_t blkTableSrcStride = + ((keyOffset1 > keyOffset2 ? (keyOffset1 - keyOffset2) : + (keyOffset2 - keyOffset1)) - constInfo.sparseBlockSize); + keySrcStride = blkTableSrcStride * constInfo.headDim * sizeof(KV_T); + keyRopeSrcStride = blkTableSrcStride * constInfo.headDimRope * sizeof(KV_T); + } else { + int64_t keyRopeOffset1 = GetKeyRopeGmOffset(realS2Idx1, runInfo, s2IdLimit); + int64_t keyRopeOffset2 = GetKeyRopeGmOffset(realS2Idx2, runInfo, s2IdLimit); + keySrcStride = ((keyOffset1 > keyOffset2 ? 
(keyOffset1 - keyOffset2) : + (keyOffset2 - keyOffset1)) - constInfo.sparseBlockSize) * constInfo.headDim * sizeof(KV_T); + keyRopeSrcStride = ((keyRopeOffset1 > keyRopeOffset2 ? (keyRopeOffset1 - keyRopeOffset2) : + (keyRopeOffset2 - keyRopeOffset1)) - constInfo.sparseBlockSize) * + constInfo.headDimRope * sizeof(KV_T); + } + + if (unlikely(keySrcStride >= INT32_MAX || keySrcStride < 0 || + (!PAGE_ATTENTION && (keyRopeSrcStride >= INT32_MAX || keyRopeSrcStride < 0)) || + realS2Idx1 + constInfo.sparseBlockSize >= s2IdLimit || + realS2Idx2 + constInfo.sparseBlockSize >= s2IdLimit)) { + CopyInSingleKv(mte2Size, mte3Size, mergeMte3Idx, realS2Idx1, keyOffset1, s2IdLimit, runInfo); + CopyInSingleKv(mte2Size, mte3Size, mergeMte3Idx, realS2Idx2, keyOffset2, s2IdLimit, runInfo); + } else { + DataCopyExtParams intriParams; + intriParams.blockLen = constInfo.sparseBlockSize * constInfo.headDim * sizeof(KV_T); + intriParams.blockCount = (keyOffset1 >= 0) + (keyOffset2 >= 0); + intriParams.dstStride = 0; + intriParams.srcStride = keySrcStride; + DataCopyPadExtParams padParams; + + int64_t startGmOffset = keyOffset1 > -1 ? 
keyOffset1 : keyOffset2; + if (keyOffset2 > -1 && keyOffset2 < keyOffset1) { + startGmOffset = keyOffset2; + } + DataCopyPad(kvMergUb_[mergeMte3Idx % 2 * 32 * 512 + (mte2Size - mte3Size) * constInfo.headDim], + keyGm_[startGmOffset * constInfo.headDim], intriParams, padParams); + + intriParams.blockLen = constInfo.sparseBlockSize * constInfo.headDimRope * sizeof(KV_T); + intriParams.dstStride = 0; + intriParams.srcStride = keyRopeSrcStride; + DataCopyPad(ropeMergUb_[mergeMte3Idx % 2 * 32 * 64 + (mte2Size - mte3Size) * constInfo.headDimRope], + keyRopeGm_[startGmOffset * constInfo.headDimRope], intriParams, padParams); + mte2Size += ((keyOffset1 > -1) + (keyOffset2 > -1)) * constInfo.sparseBlockSize; + } +} + +template +__aicore__ inline void SFAVectorService::CopyOutMrgeResult(int64_t mte2Size, int64_t mte3Size, + int64_t s2GmStartOffset, int64_t mergeMte3Idx, + const RunInfo &runInfo) +{ + if (mte2Size <= mte3Size) { + return; + } + SetFlag(0); + WaitFlag(0); + + DataCopyExtParams dataCopyParams; + dataCopyParams.blockCount = mte2Size - mte3Size; + dataCopyParams.blockLen = constInfo.headDim * sizeof(KV_T); + dataCopyParams.srcStride = 0; + dataCopyParams.dstStride = 0; + + DataCopyPad(kvMergeGm_[runInfo.loop % 4 * 512 * 576 + (s2GmStartOffset + mte3Size)*constInfo.headDim], + kvMergUb_[mergeMte3Idx % 2 * 32 * 512], dataCopyParams); + + dataCopyParams.blockLen = constInfo.headDimRope * sizeof(KV_T); + DataCopyPad(kvMergeGm_[runInfo.loop % 4 * 512 * 576 + 512 * 512 + (s2GmStartOffset + mte3Size) * + constInfo.headDimRope], ropeMergUb_[mergeMte3Idx % 2 * 32 * 64], dataCopyParams); +} + +// b s1 k +template +__aicore__ inline void SFAVectorService::MergeKv(const RunInfo &runInfo) +{ + int64_t s2ProcessSize = runInfo.actualSingleProcessSInnerSize; + int64_t s2Pair = CeilDiv(s2ProcessSize, 2L * constInfo.sparseBlockSize); + int64_t topkGmBaseOffset = 0; + + if constexpr (LAYOUT_T == SFA_LAYOUT::TND) { + uint64_t actualSeqQPrefixSum = (runInfo.bIdx <= 0) ? 
0 : actualSeqLengthsQGm.GetValue(runInfo.bIdx - 1); + topkGmBaseOffset += (actualSeqQPrefixSum + runInfo.gS1Idx / constInfo.gSize) * constInfo.kvHeadNum * + constInfo.sparseBlockCount + runInfo.n2Idx * constInfo.sparseBlockCount; + } else { + topkGmBaseOffset += runInfo.bIdx * constInfo.qSeqSize * constInfo.sparseBlockCount + + runInfo.gS1Idx / constInfo.gSize * constInfo.sparseBlockCount; + } + int64_t mergeMte3Idx = 0; + int64_t mte2Size = 0; + int64_t mte3Size = 0; + int64_t s2IdxArray0 = -1; + int64_t s2IdxArray1 = -1; + bool needWaitMte3ToMte2 = true; + SetFlag(0); + SetFlag(1); + int64_t s2GmStartOffset = GetSubBlockIdx() == 0 ? 0 : CeilDiv(s2Pair, 2L) * 2 * constInfo.sparseBlockSize; + int64_t s2GmLimit = GetSubBlockIdx() == 0 ? CeilDiv(s2Pair, 2L) * 2 * constInfo.sparseBlockSize: s2ProcessSize; + if (s2GmLimit > s2ProcessSize) { + s2GmLimit = s2ProcessSize; + } + for (int64_t s2GmOffsetArray = s2GmStartOffset; s2GmOffsetArray < s2GmLimit; s2GmOffsetArray += 2 * constInfo.sparseBlockSize) { + if (needWaitMte3ToMte2) { + WaitFlag(mergeMte3Idx % 2); + needWaitMte3ToMte2 = false; + } + GetRealS2Idx(s2GmOffsetArray, s2IdxArray0, topkGmBaseOffset, runInfo); + if (unlikely(s2IdxArray0 < 0)) { + CopyOutMrgeResult(mte2Size, mte3Size, s2GmStartOffset, mergeMte3Idx, runInfo); + SetFlag(mergeMte3Idx % 2); + mergeMte3Idx++; + break; + } + GetRealS2Idx(s2GmOffsetArray + constInfo.sparseBlockSize, s2IdxArray1, topkGmBaseOffset, runInfo); + CopyInKv(mte2Size, mte3Size, mergeMte3Idx, s2IdxArray0, s2IdxArray1, runInfo); + if ((mte2Size - mte3Size + 2 * constInfo.sparseBlockSize > 32) || + s2GmOffsetArray + 2 * constInfo.sparseBlockSize >= s2GmLimit) { + CopyOutMrgeResult(mte2Size, mte3Size, s2GmStartOffset, mergeMte3Idx, runInfo); + mte3Size = mte2Size; + SetFlag(mergeMte3Idx % 2); + mergeMte3Idx++; + needWaitMte3ToMte2 = true; + } + } + + if (unlikely(s2GmStartOffset + mte2Size < s2GmLimit)) { + SetFlag(0); + WaitFlag(0); + WaitFlag(mergeMte3Idx & 1); + Duplicate(kvMergUb_, 
static_cast(0.0), constInfo.headDim); + SetFlag(0); + WaitFlag(0); + + DataCopyExtParams dataCopyParams; + dataCopyParams.blockCount = 1; + dataCopyParams.blockLen = constInfo.headDim * sizeof(KV_T); + dataCopyParams.srcStride = 0; + dataCopyParams.dstStride = 0; + for (int64_t s2GmOffset = s2GmStartOffset + mte2Size; s2GmOffset < s2GmLimit; s2GmOffset++) { + DataCopyPad(kvMergeGm_[runInfo.loop % MERGE_CACHE_GM_BUF_NUM * 512 * 576 + s2GmOffset * constInfo.headDim], + kvMergUb_, dataCopyParams); + } + dataCopyParams.blockLen = constInfo.headDimRope * sizeof(KV_T); + for (int64_t s2GmOffset = s2GmStartOffset + mte2Size; s2GmOffset < s2GmLimit; s2GmOffset++) { + DataCopyPad(kvMergeGm_[runInfo.loop % MERGE_CACHE_GM_BUF_NUM * 512 * 576 + 512 * constInfo.headDim + + s2GmOffset * constInfo.headDimRope], + kvMergUb_, dataCopyParams); + } + SetFlag(mergeMte3Idx & 1); + mergeMte3Idx++; + } + WaitFlag(0); + WaitFlag(1); + v0ValidSizeUb_.SetValue(runInfo.loop % MERGE_CACHE_GM_BUF_NUM, mte2Size); + SetFlag(1); + WaitFlag(1); + DataCopyExtParams dataCopyParams; + dataCopyParams.blockCount = 1; + dataCopyParams.blockLen = 128 * sizeof(int32_t); + dataCopyParams.srcStride = 0; + dataCopyParams.dstStride = 0; + DataCopyPad(kvValidSizeGm_[runInfo.loop % MERGE_CACHE_GM_BUF_NUM * (128 * 2) + GetSubBlockIdx() * 128], + v0ValidSizeUb_, dataCopyParams); + return; +} + +template +__aicore__ inline void SFAVectorService::ProcessVec1L(const RunInfo &info) +{ + uint32_t nBufferLoopTimes = (info.actMBaseSize + constInfo.nBufferMBaseSize - 1) / constInfo.nBufferMBaseSize; + uint32_t nBufferTail = info.actMBaseSize - (nBufferLoopTimes - 1) * constInfo.nBufferMBaseSize; + for (uint32_t i = 0; i < nBufferLoopTimes; i++) { + MSplitInfo mSplitInfo; + mSplitInfo.nBufferIdx = i; + mSplitInfo.nBufferStartM = i * constInfo.nBufferMBaseSize; + mSplitInfo.nBufferDealM = (i + 1 != nBufferLoopTimes) ? constInfo.nBufferMBaseSize : nBufferTail; + + mSplitInfo.vecDealM = (mSplitInfo.nBufferDealM <= 16) ? 
mSplitInfo.nBufferDealM : + (((mSplitInfo.nBufferDealM + 15) / 16 + 1) / 2 * 16); + mSplitInfo.vecStartM = 0; + if (GetBlockIdx() % 2 == 1) { + mSplitInfo.vecStartM = mSplitInfo.vecDealM; + mSplitInfo.vecDealM = mSplitInfo.nBufferDealM - mSplitInfo.vecDealM; + } + + CrossCoreWaitFlag(constInfo.syncC1V1); + // vec1 compute + ProcessVec1SingleBuf(info, mSplitInfo); + CrossCoreSetFlag(constInfo.syncV1C2); + CrossCoreWaitFlag(constInfo.syncC2V1); + // add nUpdate to mm2ResGm + if (info.actualSingleProcessSInnerSize != 0) { + ProcessAmlaNupdate(info, mSplitInfo); + CrossCoreSetFlag(constInfo.syncV1NupdateC2); + } + // move lse for flash decode + if (info.s2Idx == info.curSInnerLoopTimes - 1) { + if (info.tndIsS2SplitCore) { + if constexpr (FLASH_DECODE) { + uint32_t outIdx = info.loop % (constInfo.preLoadNum); + auto sumTensor = softmaxSumUb[outIdx * SOFTMAX_TMP_BUFFER_OFFSET / sizeof(T)]; + auto maxTensor = softmaxMaxUb[outIdx * SOFTMAX_TMP_BUFFER_OFFSET / sizeof(T)]; + ComputeLogSumExpAndCopyToGm(info, mSplitInfo, sumTensor, maxTensor); + } + } + } + } +} + +template +__aicore__ inline uint64_t SFAVectorService::CalcAccumOffset(uint32_t bN2Idx, uint32_t gS1Idx) +{ + return 0; +} + +template +__aicore__ inline void SFAVectorService::ProcessVec2SingleBuf(const RunInfo &info, + const MSplitInfo &mSplitInfo) +{ + if (info.s2Idx + 1 != info.curSInnerLoopTimes) { + return; + } + if (mSplitInfo.vecDealM == 0) { + return; + } + + ProcessVec2Inner(info, mSplitInfo, 0, mSplitInfo.vecDealM); +} + +template __aicore__ inline void SFAVectorService::ProcessVec2L(const RunInfo &info) +{ + uint32_t nBufferLoopTimes = (info.actMBaseSize + constInfo.nBufferMBaseSize - 1) / constInfo.nBufferMBaseSize; + uint32_t nBufferTail = info.actMBaseSize - (nBufferLoopTimes - 1) * constInfo.nBufferMBaseSize; + for (uint32_t i = 0; i < nBufferLoopTimes; i++) { + MSplitInfo mSplitInfo; + mSplitInfo.nBufferIdx = i; + mSplitInfo.nBufferStartM = i * constInfo.nBufferMBaseSize; + mSplitInfo.nBufferDealM 
= (i + 1 != nBufferLoopTimes) ? constInfo.nBufferMBaseSize : nBufferTail; + + mSplitInfo.vecDealM = (mSplitInfo.nBufferDealM <= 16) ? mSplitInfo.nBufferDealM : + (((mSplitInfo.nBufferDealM + 15) / 16 + 1) / 2 * 16); + mSplitInfo.vecStartM = 0; + if (GetBlockIdx() % 2 == 1) { + mSplitInfo.vecStartM = mSplitInfo.vecDealM; + mSplitInfo.vecDealM = mSplitInfo.nBufferDealM - mSplitInfo.vecDealM; + } + CrossCoreWaitFlag(constInfo.syncC2V2); + ProcessVec2SingleBuf(info, mSplitInfo); + } +} + +template +__aicore__ inline void SFAVectorService::ProcessVec2Inner(const RunInfo &info, + const MSplitInfo &mSplitInfo, + uint32_t mStartRow, uint32_t mDealSize) +{ + uint32_t mSplitSize = BASE_BLOCK_MAX_ELEMENT_NUM / constInfo.headDim; + if (mSplitSize > mDealSize) { + mSplitSize = mDealSize; + } + + uint32_t loopCount = (mDealSize + mSplitSize - 1) / mSplitSize; + uint32_t tailSplitSize = mDealSize - (loopCount - 1) * mSplitSize; + for (uint32_t i = 0, dealSize = mSplitSize; i < loopCount; i++) { + if (i == (loopCount - 1)) { + dealSize = tailSplitSize; + } + DealBmm2ResBaseBlock(info, mSplitInfo, i * mSplitSize + mStartRow, dealSize, + constInfo.headDim, constInfo.headDim); + pingpongFlag ^= 1; + } +} + + +template +__aicore__ inline void SFAVectorService::GetConfusionTransposeTiling( + int64_t numR, int64_t numC, const uint32_t stackBufferSize, const uint32_t typeSize, + ConfusionTransposeTiling &tiling) +{ + (void)stackBufferSize; + uint32_t blockSize = ONE_BLK_SIZE / typeSize; + uint32_t height = numC; + uint32_t width = numR; + uint32_t highBlock = height / BLOCK_CUBE; + uint32_t stride = height * blockSize * typeSize / ONE_BLK_SIZE; + uint32_t repeat = width / blockSize; + + tiling.param0 = blockSize; + tiling.param1 = height; + tiling.param2 = width; + tiling.param3 = highBlock; + tiling.param4 = stride; + tiling.param5 = repeat; +} + +template +__aicore__ inline void +SFAVectorService::Bmm2FDDataCopyOut(const RunInfo &info, LocalTensor &bmm2ResUb, + uint32_t wsMStart, 
uint32_t dealRowCount, uint32_t columnCount, + uint32_t actualColumnCount) +{ + LocalTensor tmp = outputBuff1.Get(); + WaitFlag(SYNC_OUTPUT_BUF1_FLAG); + DataCopy(tmp, bmm2ResUb, columnCount * dealRowCount); + SetFlag(SYNC_OUTPUT_BUF1_FLAG); + WaitFlag(SYNC_OUTPUT_BUF1_FLAG); + uint64_t accumTmpOutNum = CalcAccumOffset(info.bIdx, info.gS1Idx); + uint64_t offset = accumTmpOutNum * constInfo.kvHeadNum * constInfo.mBaseSize * constInfo.headDim + + info.tndCoreStartKVSplitPos * constInfo.kvHeadNum * constInfo.mBaseSize * constInfo.headDim + + wsMStart * actualColumnCount; + GlobalTensor dst = accumOutGm[offset]; + if (info.actualSingleProcessSInnerSize== 0) { + DataCopyExtParams dataCopyParams; + dataCopyParams.blockCount = dealRowCount; + dataCopyParams.blockLen = actualColumnCount * sizeof(T); + dataCopyParams.srcStride = (columnCount - actualColumnCount) / (BYTE_BLOCK / sizeof(T)); + dataCopyParams.dstStride = 0; + DataCopyPad(dst, tmp, dataCopyParams); + } else { + matmul::InitOutput(dst, dealRowCount * actualColumnCount, ConstInfo::FLOAT_ZERO); + } + SetFlag(SYNC_OUTPUT_BUF1_FLAG); +} + +template +__aicore__ inline void +SFAVectorService::Bmm2DataCopyOutTrans(const RunInfo &info, LocalTensor &attenOutUb, + uint32_t wsMStart, uint32_t dealRowCount, + uint32_t columnCount, uint32_t actualColumnCount) +{ + DataCopyExtParams dataCopyParams; + dataCopyParams.blockCount = dealRowCount; + dataCopyParams.blockLen = actualColumnCount * sizeof(OUT_T); + dataCopyParams.srcStride = (columnCount - actualColumnCount) / (BYTE_BLOCK / sizeof(OUT_T)); + dataCopyParams.dstStride = 0; + DataCopyPad(attentionOutGm[info.attenOutOffset + wsMStart * actualColumnCount], attenOutUb, dataCopyParams); + return; +} + +template +__aicore__ inline void +SFAVectorService::Bmm2CastAndCopyOut(const RunInfo &info, LocalTensor &bmm2ResUb, + uint32_t wsMStart, uint32_t dealRowCount, uint32_t columnCount, + uint32_t actualColumnCount) +{ + LocalTensor tmpBmm2ResCastTensor = outputBuff1.Get(); + 
WaitFlag(SYNC_OUTPUT_BUF1_FLAG); + if constexpr (IsSameType::value) { + Cast(tmpBmm2ResCastTensor, bmm2ResUb, AscendC::RoundMode::CAST_RINT, dealRowCount * columnCount); + } else { + Cast(tmpBmm2ResCastTensor, bmm2ResUb, AscendC::RoundMode::CAST_ROUND, dealRowCount * columnCount); + } + + SetFlag(SYNC_OUTPUT_BUF1_FLAG); + WaitFlag(SYNC_OUTPUT_BUF1_FLAG); + Bmm2DataCopyOutTrans(info, tmpBmm2ResCastTensor, wsMStart, dealRowCount, columnCount, actualColumnCount); + SetFlag(SYNC_OUTPUT_BUF1_FLAG); +} + +template +__aicore__ inline void +SFAVectorService::Bmm2ResCopyOut(const RunInfo &info, LocalTensor &bmm2ResUb, uint32_t wsMStart, + uint32_t dealRowCount, uint32_t columnCount, + uint32_t actualColumnCount) +{ + if constexpr (FLASH_DECODE) { + if (info.tndIsS2SplitCore) { + Bmm2FDDataCopyOut(info, bmm2ResUb, wsMStart, dealRowCount, columnCount, actualColumnCount); + } else { + Bmm2CastAndCopyOut(info, bmm2ResUb, wsMStart, dealRowCount, columnCount, actualColumnCount); + } + } else { + Bmm2CastAndCopyOut(info, bmm2ResUb, wsMStart, dealRowCount, columnCount, actualColumnCount); + } +} + +template +__aicore__ inline void +SFAVectorService::DealBmm2ResBaseBlock(const RunInfo &info, const MSplitInfo &mSplitInfo, + uint32_t startRow, uint32_t dealRowCount, + uint32_t columnCount, uint32_t actualColumnCount) +{ + uint32_t vec2ComputeSize = dealRowCount * columnCount; + uint32_t mStart = mSplitInfo.nBufferStartM + mSplitInfo.vecStartM + startRow; + uint64_t srcGmOffset = (info.bn2IdxInCurCore % constInfo.preLoadNum) * constInfo.bmm2ResUbSize + + mStart * columnCount; + LocalTensor tmpBmm2ResUb = inputBuff1.Get(); + tmpBmm2ResUb = tmpBmm2ResUb[pingpongFlag * INPUT1_BUFFER_OFFSET / sizeof(MM2_OUT_T)]; + WaitFlag(SYNC_INPUT_BUF1_FLAG + pingpongFlag); + DataCopy(tmpBmm2ResUb, mm2ResGm[srcGmOffset], vec2ComputeSize); + + SetFlag(SYNC_INPUT_BUF1_FLAG); + WaitFlag(SYNC_INPUT_BUF1_FLAG); + + LocalTensor bmm2ResUb = tmpBuff1.Get(); + bmm2ResUb.SetSize(vec2ComputeSize); + LocalTensor 
absBmm2ResUb = bmm2ResUb.template ReinterpretCast(); + Abs(absBmm2ResUb, tmpBmm2ResUb, vec2ComputeSize); + pipe_barrier(PIPE_V); + LocalTensor cmpMaskUb = absBmm2ResUb.template ReinterpretCast(); + CompareScalar(cmpMaskUb, absBmm2ResUb, (T)1e10, CMPMODE::LE, vec2ComputeSize); + pipe_barrier(PIPE_V); + Select(tmpBmm2ResUb, cmpMaskUb, tmpBmm2ResUb, ConstInfo::FLOAT_ZERO, + SELMODE::VSEL_TENSOR_SCALAR_MODE, vec2ComputeSize); + pipe_barrier(PIPE_V); + uint32_t baseOffset = mSplitInfo.nBufferStartM / 2 + startRow; + uint32_t idx = info.loop % (constInfo.preLoadNum); + LocalTensor tmpSumUb = v0ValidSizeBuff.Get()[384]; + Brcb(tmpSumUb, aMlaSumUb[idx * SOFTMAX_TMP_BUFFER_OFFSET / sizeof(T) + baseOffset], (dealRowCount + 7) / 8, {1, 8}); + pipe_barrier(PIPE_V); + RowDivs(bmm2ResUb, tmpBmm2ResUb, tmpSumUb, dealRowCount, columnCount, actualColumnCount); + pipe_barrier(PIPE_V); + SetFlag(SYNC_INPUT_BUF1_FLAG + pingpongFlag); + Bmm2ResCopyOut(info, bmm2ResUb, mStart, dealRowCount, columnCount, actualColumnCount); +} + +template +__aicore__ inline void +SFAVectorService::RowDivs(LocalTensor dstUb, LocalTensor src0Ub, LocalTensor src1Ub, + uint32_t dealRowCount, uint32_t columnCount, uint32_t actualColumnCount) +{ + uint32_t dtypeMask = FP32_REPEAT_ELEMENT_NUM; + uint32_t dLoop = actualColumnCount / dtypeMask; + uint32_t dRemain = actualColumnCount % dtypeMask; + + BinaryRepeatParams repeatParamsDiv; + repeatParamsDiv.src0BlkStride = 1; + repeatParamsDiv.src1BlkStride = 0; + repeatParamsDiv.dstBlkStride = 1; + repeatParamsDiv.src0RepStride = columnCount / FP32_BLOCK_ELEMENT_NUM; + repeatParamsDiv.src1RepStride = 1; + repeatParamsDiv.dstRepStride = columnCount / FP32_BLOCK_ELEMENT_NUM; + uint32_t columnRepeatCount = dLoop; + if (columnRepeatCount <= dealRowCount) { + uint32_t offset = 0; + for (uint32_t i = 0; i < dLoop; i++) { + Div(dstUb[offset], src0Ub[offset], src1Ub, dtypeMask, dealRowCount, repeatParamsDiv); + offset += dtypeMask; + } + } else { + BinaryRepeatParams 
columnRepeatParams; + columnRepeatParams.src0BlkStride = 1; + columnRepeatParams.src1BlkStride = 0; + columnRepeatParams.dstBlkStride = 1; + columnRepeatParams.src0RepStride = 8; + columnRepeatParams.src1RepStride = 0; + columnRepeatParams.dstRepStride = 8; + uint32_t offset = 0; + for (uint32_t i = 0; i < dealRowCount; i++) { + Div(dstUb[offset], src0Ub[offset], src1Ub[i * FP32_BLOCK_ELEMENT_NUM], dtypeMask, columnRepeatCount, + columnRepeatParams); + offset += columnCount; + } + } + if (dRemain > 0) { + Div(dstUb[dLoop * dtypeMask], src0Ub[dLoop * dtypeMask], src1Ub, dRemain, dealRowCount, repeatParamsDiv); + } +} + +template +__aicore__ inline void +SFAVectorService::RowMuls(LocalTensor dstUb, LocalTensor src0Ub, LocalTensor src1Ub, + uint32_t dealRowCount, uint32_t columnCount, uint32_t actualColumnCount) +{ + uint32_t repeatElementNum = FP32_REPEAT_ELEMENT_NUM; + uint32_t blockElementNum = FP32_BLOCK_ELEMENT_NUM; + + if constexpr (std::is_same::value) { + repeatElementNum = FP32_REPEAT_ELEMENT_NUM * 2; // 256/4 * 2=128 + blockElementNum = FP32_BLOCK_ELEMENT_NUM * 2; // 32/4 * 2 = 16 + } + + uint32_t dLoop = actualColumnCount / repeatElementNum; + uint32_t dRemain = actualColumnCount % repeatElementNum; + if (columnCount < REPEATE_STRIDE_UP_BOUND * blockElementNum) { + BinaryRepeatParams repeatParams; + repeatParams.src0BlkStride = 1; + repeatParams.src1BlkStride = 0; + repeatParams.dstBlkStride = 1; + repeatParams.src0RepStride = columnCount / blockElementNum; + repeatParams.src1RepStride = 1; + repeatParams.dstRepStride = columnCount / blockElementNum; + + if (dLoop <= dealRowCount) { + uint32_t offset = 0; + for (uint32_t i = 0; i < dLoop; i++) { + Mul(dstUb[offset], src0Ub[offset], src1Ub, repeatElementNum, dealRowCount, repeatParams); + offset += repeatElementNum; + } + } else { + BinaryRepeatParams columnRepeatParams; + columnRepeatParams.src0BlkStride = 1; + columnRepeatParams.src1BlkStride = 0; + columnRepeatParams.dstBlkStride = 1; + 
columnRepeatParams.src0RepStride = 8; + columnRepeatParams.src1RepStride = 0; + columnRepeatParams.dstRepStride = 8; + for (uint32_t i = 0; i < dealRowCount; i++) { + Mul(dstUb[i * columnCount], src0Ub[i * columnCount], src1Ub[i * blockElementNum], repeatElementNum, + dLoop, columnRepeatParams); + } + } + + if (dRemain > 0) { + Mul(dstUb[dLoop * repeatElementNum], src0Ub[dLoop * repeatElementNum], src1Ub, dRemain, dealRowCount, + repeatParams); + } + } else { + BinaryRepeatParams repeatParams; + repeatParams.src0RepStride = 8; + repeatParams.src0BlkStride = 1; + repeatParams.src1RepStride = 0; + repeatParams.src1BlkStride = 0; + repeatParams.dstRepStride = 8; + repeatParams.dstBlkStride = 1; + for (uint32_t i = 0; i < dealRowCount; i++) { + Mul(dstUb[i * columnCount], src0Ub[i * columnCount], src1Ub[i * blockElementNum], repeatElementNum, dLoop, + repeatParams); + if (dRemain > 0) { + Mul(dstUb[i * columnCount + dLoop * repeatElementNum], + src0Ub[i * columnCount + dLoop * repeatElementNum], src1Ub[i * blockElementNum], dRemain, 1, + repeatParams); + } + } + } +} + +#endif // SPARSE_FLASH_ATTENTION_SERVICE_VECTOR_MLA_H diff --git a/csrc/sparse_flash_attention/op_kernel/sparse_flash_attention_template_tiling_key.h b/csrc/sparse_flash_attention/op_kernel/sparse_flash_attention_template_tiling_key.h new file mode 100644 index 00000000..ced883ca --- /dev/null +++ b/csrc/sparse_flash_attention/op_kernel/sparse_flash_attention_template_tiling_key.h @@ -0,0 +1,54 @@ +/** + * This program is free software, you can redistribute it and/or modify it. + * Copyright (c) 2025 Huawei Technologies Co., Ltd. + * This file is a part of the CANN Open Software. + * Licensed under CANN Open Software License Agreement Version 2.0 (the "License"). + * Please refer to the License for details. You may not use this file except in compliance with the License. 
+ * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. + * See LICENSE in the root of the software repository for the full text of the License. + */ + +/*! + * \file sparse_flash_attention_template_tiling_key.h + * \brief + */ + +#ifndef SPARSE_FLASH_ATTENTION_TEMPLATE_TILING_KEY_H +#define SPARSE_FLASH_ATTENTION_TEMPLATE_TILING_KEY_H + +#include "ascendc/host_api/tiling/template_argument.h" + +#define SFA_LAYOUT_BSND 0 +#define SFA_LAYOUT_TND 1 +#define SFA_LAYOUT_PA_BSND 2 + +#define ASCENDC_TPL_4_BW 4 + +#define C_TEMPLATE 0 +#define V_TEMPLATE 1 + +ASCENDC_TPL_ARGS_DECL(SparseFlashAttention, +ASCENDC_TPL_BOOL_DECL(FLASH_DECODE, 0, 1), +ASCENDC_TPL_UINT_DECL(LAYOUT_T, ASCENDC_TPL_4_BW, ASCENDC_TPL_UI_LIST, SFA_LAYOUT_BSND, SFA_LAYOUT_TND), +ASCENDC_TPL_UINT_DECL(KV_LAYOUT_T, ASCENDC_TPL_4_BW, ASCENDC_TPL_UI_LIST, SFA_LAYOUT_BSND, SFA_LAYOUT_TND, + SFA_LAYOUT_PA_BSND), +ASCENDC_TPL_UINT_DECL(TEMPLATE_MODE, ASCENDC_TPL_4_BW, ASCENDC_TPL_UI_LIST, C_TEMPLATE, V_TEMPLATE), +); + +ASCENDC_TPL_SEL( + ASCENDC_TPL_ARGS_SEL( + ASCENDC_TPL_BOOL_SEL(FLASH_DECODE, 0), + ASCENDC_TPL_UINT_SEL(LAYOUT_T, ASCENDC_TPL_UI_LIST, SFA_LAYOUT_BSND, SFA_LAYOUT_TND), + ASCENDC_TPL_UINT_SEL(KV_LAYOUT_T, ASCENDC_TPL_UI_LIST, SFA_LAYOUT_BSND, SFA_LAYOUT_PA_BSND, SFA_LAYOUT_TND), + ASCENDC_TPL_UINT_SEL(TEMPLATE_MODE, ASCENDC_TPL_UI_LIST, C_TEMPLATE), + ), + + ASCENDC_TPL_ARGS_SEL( + ASCENDC_TPL_BOOL_SEL(FLASH_DECODE, 0), + ASCENDC_TPL_UINT_SEL(LAYOUT_T, ASCENDC_TPL_UI_LIST, SFA_LAYOUT_BSND, SFA_LAYOUT_TND), + ASCENDC_TPL_UINT_SEL(KV_LAYOUT_T, ASCENDC_TPL_UI_LIST, SFA_LAYOUT_PA_BSND, SFA_LAYOUT_BSND, SFA_LAYOUT_TND), + ASCENDC_TPL_UINT_SEL(TEMPLATE_MODE, ASCENDC_TPL_UI_LIST, V_TEMPLATE), // V模板不支持非PA + ), +); + +#endif // TEMPLATE_TILING_KEY \ No newline at end of file diff --git a/csrc/torch_binding.cpp b/csrc/torch_binding.cpp index 
9ef0cfbb..68cefc15 100644 --- a/csrc/torch_binding.cpp +++ b/csrc/torch_binding.cpp @@ -620,6 +620,103 @@ void batch_matmul_transpose(const at::Tensor &tensor_a, const at::Tensor &tensor } +at::Tensor npu_lightning_indexer( + const at::Tensor &query, const at::Tensor &key, const at::Tensor &weights, + const c10::optional &actual_seq_lengths_query, + const c10::optional &actual_seq_lengths_key, + const c10::optional &block_table, c10::string_view layout_query, + c10::string_view layout_key, int64_t sparse_count, int64_t sparse_mode) +{ + // npu tensor max size + constexpr int32_t SIZE = 8; + constexpr int32_t DIM_0 = 0; + constexpr int32_t DIM_1 = 1; + constexpr int32_t DIM_2 = 2; + constexpr int32_t DIM_3 = 3; + + TORCH_CHECK(query.numel() > 0, "Query is empty."); + TORCH_CHECK(key.numel() > 0, "Key is empty."); + TORCH_CHECK(weights.numel() > 0, "Weights is empty."); + for (size_t i = 0; i < query.sizes().size(); i++) { + TORCH_CHECK(query.size(i) > 0, "All values within query's shape should be greater " + "than 0, but shape[", i, "] is ", query.size(i)); + } + TORCH_CHECK(sparse_count > 0, "sparse count should be greater than 0, but now is ", sparse_count); + + at::SmallVector output_size; + std::string query_layout_str = std::string(layout_query); + std::string key_layout_str = std::string(layout_key); + if (query_layout_str == "BSND") { + output_size = {query.size(DIM_0), query.size(DIM_1), key.size(DIM_2), sparse_count}; + } else { + int n_dim_index = 0; + n_dim_index = (key_layout_str == "TND") ? 
DIM_1 : DIM_2; + output_size = {query.size(DIM_0), key.size(n_dim_index), sparse_count}; + } + at::Tensor lightning_indexer_output = at::empty(output_size, query.options().dtype(at::kInt)); + // convert str + char *query_layout_ptr = const_cast(query_layout_str.c_str()); + char *key_layout_ptr = const_cast(key_layout_str.c_str()); + EXEC_NPU_CMD( + aclnnLightningIndexer, + query, + key, + weights, + actual_seq_lengths_query, + actual_seq_lengths_key, + block_table, + query_layout_ptr, + key_layout_ptr, + sparse_count, + sparse_mode, + lightning_indexer_output); + return lightning_indexer_output; +} + +at::Tensor npu_sparse_flash_attention( + const at::Tensor &query, const at::Tensor &key, const at::Tensor &value, + const at::Tensor &sparse_indices, double scale_value, int64_t sparse_block_size, + const c10::optional &block_table, + const c10::optional &actual_seq_lengths_query, + const c10::optional &actual_seq_lengths_kv, + const c10::optional &query_rope, + const c10::optional &key_rope, c10::string_view layout_query, + c10::string_view layout_kv, + int64_t sparse_mode) +{ + std::string layout_query_str = std::string(layout_query); + std::string layout_kv_str = std::string(layout_kv); + + for (size_t i = 0; i < query.sizes().size(); i++) { + TORCH_CHECK(query.size(i) > 0, "All values within query's shape should be greater " + "than 0, but shape[", i, "] is ", query.size(i)); + } + // construct the output tensor + at::Tensor output = at::empty(query.sizes(), query.options().dtype(query.dtype())); + // convert str + char *layout_query_ptr = const_cast(layout_query_str.c_str()); + char *layout_kv_ptr = const_cast(layout_kv_str.c_str()); + + EXEC_NPU_CMD( + aclnnSparseFlashAttention, + query, + key, + value, + sparse_indices, + block_table, + actual_seq_lengths_query, + actual_seq_lengths_kv, + query_rope, + key_rope, + scale_value, + sparse_block_size, + layout_query_ptr, + layout_kv_ptr, + sparse_mode, + output); + return output; +} + } // namespace vllm_ascend 
TORCH_LIBRARY_EXPAND(CONCAT(_C, _ascend), ops) @@ -695,4 +792,22 @@ TORCH_LIBRARY_EXPAND(CONCAT(_C, _ascend), ops) " (Tensor output, Tensor output_scale, Tensor output_offset)" ); ops.impl("grouped_matmul_swiglu_quant_weight_nz_tensor_list", torch::kPrivateUse1, &vllm_ascend::grouped_matmul_swiglu_quant_weight_nz_tensor_list); + + ops.def( + "npu_lightning_indexer(Tensor query, Tensor key, Tensor weights, *," + " Tensor? actual_seq_lengths_query=None, Tensor? actual_seq_lengths_key=None," + " Tensor? block_table=None, str layout_query='BSND', str layout_key='BSND'," + " int sparse_count=2048, int sparse_mode=3) -> Tensor" + ); + ops.impl("npu_lightning_indexer", torch::kPrivateUse1, &vllm_ascend::npu_lightning_indexer); + + ops.def( + "npu_sparse_flash_attention(Tensor query, Tensor key, Tensor value," + " Tensor sparse_indices, float scale_value, int sparse_block_size, *," + " Tensor? block_table=None, Tensor? actual_seq_lengths_query=None," + " Tensor? actual_seq_lengths_kv=None, Tensor? query_rope=None," + " Tensor? 
key_rope=None, str layout_query='BSND', str layout_kv='BSND'," + " int sparse_mode=3) -> Tensor" + ); + ops.impl("npu_sparse_flash_attention", torch::kPrivateUse1, &vllm_ascend::npu_sparse_flash_attention); } diff --git a/csrc/torch_binding_meta.cpp b/csrc/torch_binding_meta.cpp index c5811998..b84779e2 100644 --- a/csrc/torch_binding_meta.cpp +++ b/csrc/torch_binding_meta.cpp @@ -159,6 +159,64 @@ void batch_matmul_transpose(const at::Tensor &tensor_a, const at::Tensor &tensor } +at::Tensor npu_lightning_indexer_meta( + const at::Tensor &query, const at::Tensor &key, const at::Tensor &weights, + const c10::optional &actual_seq_lengths_query, + const c10::optional &actual_seq_lengths_key, + const c10::optional &block_table, c10::string_view layout_query, + c10::string_view layout_key, int64_t sparse_count, int64_t sparse_mode) +{ + // npu tensor max size + constexpr int32_t SIZE = 8; + constexpr int32_t DIM_0 = 0; + constexpr int32_t DIM_1 = 1; + constexpr int32_t DIM_2 = 2; + constexpr int32_t DIM_3 = 3; + + TORCH_CHECK(query.numel() > 0, "Query is empty."); + TORCH_CHECK(key.numel() > 0, "Key is empty."); + TORCH_CHECK(weights.numel() > 0, "Weights is empty."); + for (size_t i = 0; i < query.sizes().size(); i++) { + TORCH_CHECK(query.size(i) > 0, "All values within query's shape should be greater " + "than 0, but shape[", i, "] is ", query.size(i)); + } + TORCH_CHECK(sparse_count > 0, "sparse count should be greater than 0, but now is ", sparse_count); + + std::string query_layout_str = std::string(layout_query); + std::string key_layout_str = std::string(layout_key); + at::SmallVector output_size; + if (query_layout_str == "BSND") { + output_size = {query.size(DIM_0), query.size(DIM_1), key.size(DIM_2), sparse_count}; + } else { + int n_dim_index = 0; + n_dim_index = (key_layout_str == "TND") ? 
DIM_1 : DIM_2; + output_size = {query.size(DIM_0), key.size(n_dim_index), sparse_count}; + } + // construct the output tensor + at::Tensor lightning_indexer_output = at::empty(output_size, query.options().dtype(at::kInt)); + return lightning_indexer_output; +} + +at::Tensor npu_sparse_flash_attention_meta( + const at::Tensor &query, const at::Tensor &key, const at::Tensor &value, + const at::Tensor &sparse_indices, double scale_value, int64_t sparse_block_size, + const c10::optional &block_table, + const c10::optional &actual_seq_lengths_query, + const c10::optional &actual_seq_lengths_kv, + const c10::optional &query_rope, + const c10::optional &key_rope, c10::string_view layout_query, + c10::string_view layout_kv, + int64_t sparse_mode) +{ + std::string layout_query_str = std::string(layout_query); + for (size_t i = 0; i < query.sizes().size(); i++) { + TORCH_CHECK(query.size(i) > 0, "All values within query's shape should be greater " + "than 0, but shape[", i, "] is ", query.size(i)); + } + at::Tensor output = at::empty(query.sizes(), query.options().dtype(query.dtype())); + return output; +} + } // namespace meta } // namespace vllm_ascend @@ -182,5 +240,9 @@ TORCH_LIBRARY_IMPL_EXPAND(CONCAT(_C, _ascend), Meta, ops) { ops.impl("grouped_matmul_swiglu_quant_weight_nz_tensor_list", &vllm_ascend::meta::grouped_matmul_swiglu_quant_weight_nz_tensor_list_meta); // batch_matmul_transpose ops.impl("batch_matmul_transpose", &vllm_ascend::meta::batch_matmul_transpose); + // Lightning indexer + ops.impl("npu_lightning_indexer", &vllm_ascend::meta::npu_lightning_indexer_meta); + // Sparse flash attention + ops.impl("npu_sparse_flash_attention", &vllm_ascend::meta::npu_sparse_flash_attention_meta); } } diff --git a/vllm_ascend/attention/sfa_v1.py b/vllm_ascend/attention/sfa_v1.py index bec90b32..24306a01 100644 --- a/vllm_ascend/attention/sfa_v1.py +++ b/vllm_ascend/attention/sfa_v1.py @@ -459,7 +459,7 @@ class AscendSFAImpl(MLAAttentionImpl): kv_cache=kv_cache, 
attn_metadata=attn_metadata, need_gather_q_kv=need_gather_q_kv) - attn_output = torch.ops.custom.npu_sparse_flash_attention( + attn_output = torch.ops._C_ascend.npu_sparse_flash_attention( query=ql_nope, key=k_nope, value=k_nope, @@ -554,7 +554,7 @@ class AscendSFAImpl(MLAAttentionImpl): seq_lens = attn_metadata.seq_lens cum_query_lens = attn_metadata.cum_query_lens - topk_indices = torch.ops.custom.npu_lightning_indexer( + topk_indices = torch.ops._C_ascend.npu_lightning_indexer( query=q, key=kv_cache[2], weights=weights, diff --git a/vllm_ascend/worker/worker_v1.py b/vllm_ascend/worker/worker_v1.py index 9e064d25..849e2654 100644 --- a/vllm_ascend/worker/worker_v1.py +++ b/vllm_ascend/worker/worker_v1.py @@ -93,21 +93,6 @@ class NPUWorker(WorkerBase): # init ascend config and soc version init_ascend_config(vllm_config) check_ascend_device_type() - use_sparse = False - if vllm_config.model_config is not None: - use_sparse = hasattr(vllm_config.model_config.hf_config, - "index_topk") - if use_sparse: - # Direct import instead of using try_register_lib to ensure proper error handling when - # custom_ops is necessary but not available (e.g., in DeepSeek v3.2 deployments) - # yapf: disable - import custom_ops # type: ignore # noqa - - # yapf: enable - logger.info( - "custom_ops module loaded successfully. Custom operators like " - "torch.ops.custom.npu_sparse_flash_attention are now available." - ) super().__init__(vllm_config=vllm_config, local_rank=local_rank,