diff --git a/.github/workflows/scripts/config.yaml b/.github/workflows/scripts/config.yaml index c649f8ac..7cd9396d 100644 --- a/.github/workflows/scripts/config.yaml +++ b/.github/workflows/scripts/config.yaml @@ -68,6 +68,8 @@ e2e-2card-light: estimated_time: 220 - name: tests/e2e/multicard/2-cards/test_offline_inference_distributed.py::test_deepseek3_2_w8a8_pruning_mtp_tp2_ep estimated_time: 90 + - name: tests/e2e/multicard/2-cards/test_offline_inference_distributed.py::test_deepseek3_2_w8a8c8_pruning_mtp_tp2_ep + estimated_time: 90 - name: tests/e2e/multicard/2-cards/test_offline_inference_distributed.py::test_gpt_oss_distributed_tp2 estimated_time: 180 @@ -118,6 +120,8 @@ e2e-multicard-2-cards: estimated_time: 71 - name: tests/e2e/multicard/2-cards/test_offline_inference_distributed.py::test_deepseek3_2_w8a8_pruning_mtp_tp2_ep estimated_time: 111 + - name: tests/e2e/multicard/2-cards/test_offline_inference_distributed.py::test_deepseek3_2_w8a8c8_pruning_mtp_tp2_ep + estimated_time: 111 - name: tests/e2e/multicard/2-cards/test_offline_inference_distributed.py::test_qwen3_w4a4_distributed_tp2 estimated_time: 180 - name: tests/e2e/multicard/2-cards/test_pipeline_parallel.py diff --git a/csrc/build_aclnn.sh b/csrc/build_aclnn.sh index 5b11cbe2..631ea4e3 100644 --- a/csrc/build_aclnn.sh +++ b/csrc/build_aclnn.sh @@ -25,7 +25,7 @@ elif [[ "$SOC_VERSION" =~ ^ascend910b ]]; then export CPATH=${ABSOLUTE_CATLASS_PATH}:${CPATH} - CUSTOM_OPS="moe_grouped_matmul;grouped_matmul_swiglu_quant_weight_nz_tensor_list;lightning_indexer_vllm;sparse_flash_attention;matmul_allreduce_add_rmsnorm;moe_init_routing_custom;moe_gating_top_k;add_rms_norm_bias;apply_top_k_top_p_custom;transpose_kv_cache_by_block;copy_and_expand_eagle_inputs;causal_conv1d;" + 
CUSTOM_OPS="moe_grouped_matmul;grouped_matmul_swiglu_quant_weight_nz_tensor_list;lightning_indexer_vllm;sparse_flash_attention;matmul_allreduce_add_rmsnorm;moe_init_routing_custom;moe_gating_top_k;add_rms_norm_bias;apply_top_k_top_p_custom;transpose_kv_cache_by_block;copy_and_expand_eagle_inputs;causal_conv1d;lightning_indexer_quant;" SOC_ARG="ascend910b" elif [[ "$SOC_VERSION" =~ ^ascend910_93 ]]; then # ASCEND910C (A3) series @@ -67,6 +67,7 @@ elif [[ "$SOC_VERSION" =~ ^ascend910_93 ]]; then "copy_and_expand_eagle_inputs" "causal_conv1d" "moe_grouped_matmul" + "lightning_indexer_quant" ) CUSTOM_OPS=$(IFS=';'; echo "${CUSTOM_OPS_ARRAY[*]}") SOC_ARG="ascend910_93" diff --git a/csrc/lightning_indexer_quant/lightning_indexer_quant_torch_adpt.h b/csrc/lightning_indexer_quant/lightning_indexer_quant_torch_adpt.h new file mode 100644 index 00000000..fbc8284c --- /dev/null +++ b/csrc/lightning_indexer_quant/lightning_indexer_quant_torch_adpt.h @@ -0,0 +1,81 @@ +/* + * Copyright (c) Huawei Technologies Co., Ltd. 2026. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +#ifndef LIGHTING_INDEXER_QUANT_VLLM_TORCH_ADPT_H +#define LIGHTING_INDEXER_QUANT_VLLM_TORCH_ADPT_H +namespace vllm_ascend { + +at::Tensor npu_lightning_indexer_quant( + const at::Tensor &query, const at::Tensor &key, const at::Tensor &weights, + const at::Tensor &query_dequant_scale, const at::Tensor &key_dequant_scale, + const c10::optional &actual_seq_lengths_query, + const c10::optional &actual_seq_lengths_key, + const c10::optional &block_table, int64_t query_quant_mode, int64_t key_quant_mode, + c10::string_view layout_query, c10::string_view layout_key, int64_t sparse_count, int64_t sparse_mode) +{ + std::string query_layout_str = std::string(layout_query); + std::string key_layout_str = std::string(layout_key); + + const int SIZE = 8; + const int DIM_0 = 0; + const int DIM_1 = 1; + const int DIM_2 = 2; + const int DIM_3 = 3; + + at::SmallVector output_size; + for (size_t i = 0; i < query.sizes().size(); i++) { + TORCH_CHECK(query.size(i) > 0, "All values within query's shape should be greater " + "than 0, but shape[", i, "] is ", query.size(i)); + } + for (size_t i = 0; i < key.sizes().size(); i++) { + TORCH_CHECK(key.size(i) > 0, "All values within key's shape should be greater " + "than 0, but shape[", i, "] is ", key.size(i)); + } + TORCH_CHECK(sparse_count > 0, "sparse count should be greater than 0, but now is ", sparse_count); + int64_t keyHeadNum = (key_layout_str == "TND")? 
key.size(DIM_1) : key.size(DIM_2); + if (query_layout_str == "BSND") { + output_size = {query.size(DIM_0), query.size(DIM_1), keyHeadNum, sparse_count}; + } else { + output_size = {query.size(DIM_0), keyHeadNum, sparse_count}; + } + at::Tensor lightning_indexer_quant_output = at::empty(output_size, query.options().dtype(at::kInt)); + + // convert str + char *query_layout_ptr = const_cast(query_layout_str.c_str()); + char *key_layout_ptr = const_cast(key_layout_str.c_str()); + + EXEC_NPU_CMD(aclnnLightningIndexerQuant, + query, + key, + weights, + query_dequant_scale, + key_dequant_scale, + actual_seq_lengths_query, + actual_seq_lengths_key, + block_table, + query_quant_mode, + key_quant_mode, + query_layout_ptr, + key_layout_ptr, + sparse_count, + sparse_mode, + lightning_indexer_quant_output + ); + + return lightning_indexer_quant_output; + +} +} +#endif \ No newline at end of file diff --git a/csrc/lightning_indexer_quant/op_host/CMakeLists.txt b/csrc/lightning_indexer_quant/op_host/CMakeLists.txt new file mode 100644 index 00000000..252cfd85 --- /dev/null +++ b/csrc/lightning_indexer_quant/op_host/CMakeLists.txt @@ -0,0 +1,41 @@ +# This program is free software, you can redistribute it and/or modify it. +# Copyright (c) 2025 Huawei Technologies Co., Ltd. +# This file is a part of the CANN Open Software. +# Licensed under CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. 
+# ====================================================================================================================== + +add_ops_compile_options( + OP_NAME LightningIndexerQuant + OPTIONS --cce-auto-sync=off + -Wno-deprecated-declarations + -Werror + -mllvm -cce-aicore-hoist-movemask=false + --op_relocatable_kernel_binary=true +) + +set(lightning_indexer_quant_depends transformer/attention/lightning_indexer_quant PARENT_SCOPE) + +target_sources(op_host_aclnn PRIVATE + lightning_indexer_quant_def.cpp +) + +target_sources(optiling PRIVATE + lightning_indexer_quant_tiling.cpp +) + +if (NOT BUILD_OPEN_PROJECT) + target_sources(opmaster_ct PRIVATE + lightning_indexer_quant_tiling.cpp + ) +endif () + +target_include_directories(optiling PRIVATE + ${CMAKE_CURRENT_SOURCE_DIR}/op_host +) + +target_sources(opsproto PRIVATE + lightning_indexer_quant_proto.cpp +) diff --git a/csrc/lightning_indexer_quant/op_host/lightning_indexer_quant_def.cpp b/csrc/lightning_indexer_quant/op_host/lightning_indexer_quant_def.cpp new file mode 100644 index 00000000..7049c2d0 --- /dev/null +++ b/csrc/lightning_indexer_quant/op_host/lightning_indexer_quant_def.cpp @@ -0,0 +1,85 @@ +/** + * This program is free software, you can redistribute it and/or modify it. + * Copyright (c) 2025 Huawei Technologies Co., Ltd. + * This file is a part of the CANN Open Software. + * Licensed under CANN Open Software License Agreement Version 2.0 (the "License"). + * Please refer to the License for details. You may not use this file except in compliance with the License. + * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. + * See LICENSE in the root of the software repository for the full text of the License. + */ + +/*! 
+ * \file lightning_indexer_quant_def.cpp + * \brief + */ +#include + +#include "register/op_def_registry.h" + +namespace ops { +class LightningIndexerQuant : public OpDef { +public: + explicit LightningIndexerQuant(const char *name) : OpDef(name) + { + this->Input("query") + .ParamType(REQUIRED) + .DataType({ge::DT_INT8}) + .Format({ge::FORMAT_ND}) + .AutoContiguous(); + this->Input("key") + .ParamType(REQUIRED) + .DataType({ge::DT_INT8}) + .Format({ge::FORMAT_ND}) + .AutoContiguous(); + this->Input("weights") + .ParamType(REQUIRED) + .DataType({ge::DT_FLOAT16}) + .Format({ge::FORMAT_ND}) + .AutoContiguous(); + this->Input("query_dequant_scale") + .ParamType(REQUIRED) + .DataType({ge::DT_FLOAT16}) + .Format({ge::FORMAT_ND}) + .AutoContiguous(); + this->Input("key_dequant_scale") + .ParamType(REQUIRED) + .DataType({ge::DT_FLOAT16}) + .Format({ge::FORMAT_ND}) + .AutoContiguous(); + this->Input("actual_seq_lengths_query") + .ParamType(OPTIONAL) + .DataType({ge::DT_INT32}) + .Format({ge::FORMAT_ND}) + .AutoContiguous(); + this->Input("actual_seq_lengths_key") + .ParamType(OPTIONAL) + .DataType({ge::DT_INT32}) + .Format({ge::FORMAT_ND}) + .AutoContiguous(); + this->Input("block_table") + .ParamType(OPTIONAL) + .DataType({ge::DT_INT32}) + .Format({ge::FORMAT_ND}) + .AutoContiguous(); + this->Output("sparse_indices").ParamType(REQUIRED).DataType({ge::DT_INT32}).Format({ge::FORMAT_ND}); + this->Attr("query_quant_mode").AttrType(REQUIRED).Int(0); // 0: 默认值,per-token-head + this->Attr("key_quant_mode").AttrType(REQUIRED).Int(0); // 0: 默认值,per-token-head + this->Attr("layout_query").AttrType(OPTIONAL).String("BSND"); + this->Attr("layout_key").AttrType(OPTIONAL).String("PA_BSND"); + this->Attr("sparse_count").AttrType(OPTIONAL).Int(2048); // 2048: 默认值,筛选前2048 + this->Attr("sparse_mode").AttrType(OPTIONAL).Int(3); // 3: 默认值,只计算下三角 + OpAICoreConfig aicore_config; + aicore_config.DynamicCompileStaticFlag(true) + .DynamicFormatFlag(true) + .DynamicRankSupportFlag(true) + 
.DynamicShapeSupportFlag(true) + .NeedCheckSupportFlag(false) + .PrecisionReduceFlag(true) + .ExtendCfgInfo("aclnnSupport.value", "support_aclnn") + .ExtendCfgInfo("jitCompile.flag", "static_false,dynamic_false"); + this->AICore().AddConfig("ascend910b", aicore_config); + this->AICore().AddConfig("ascend910_93", aicore_config); + } +}; +OP_ADD(LightningIndexerQuant); +} // namespace ops \ No newline at end of file diff --git a/csrc/lightning_indexer_quant/op_host/lightning_indexer_quant_proto.cpp b/csrc/lightning_indexer_quant/op_host/lightning_indexer_quant_proto.cpp new file mode 100644 index 00000000..fb4539d6 --- /dev/null +++ b/csrc/lightning_indexer_quant/op_host/lightning_indexer_quant_proto.cpp @@ -0,0 +1,91 @@ +/** + * This program is free software, you can redistribute it and/or modify it. + * Copyright (c) 2025 Huawei Technologies Co., Ltd. + * This file is a part of the CANN Open Software. + * Licensed under CANN Open Software License Agreement Version 2.0 (the "License"). + * Please refer to the License for details. You may not use this file except in compliance with the License. + * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. + * See LICENSE in the root of the software repository for the full text of the License. + */ + +/*! 
+ * \file lightning_indexer_quant_proto.cpp + * \brief + */ +#include +#include + +#include "error/ops_error.h" + +using namespace ge; + +namespace ops { +constexpr uint32_t QUERY_INDEX = 0; +constexpr uint32_t KEY_INDEX = 1; +constexpr uint32_t ATTR_QUERY_LAYOUT_INDEX = 2; +constexpr uint32_t ATTR_KV_LAYOUT_INDEX = 3; +constexpr uint32_t ATTR_SPARSE_COUNT_INDEX = 4; + +static ge::graphStatus InferShapeLightningIndexerQuant(gert::InferShapeContext *context) +{ + if (context == nullptr) { + OPS_LOG_E("LightningIndexerQuant", "context is nullptr!"); + return ge::GRAPH_FAILED; + } + const gert::Shape *queryShape = context->GetInputShape(QUERY_INDEX); + OPS_LOG_E_IF_NULL(context, queryShape, return ge::GRAPH_FAILED); + const gert::Shape *keyShape = context->GetInputShape(KEY_INDEX); + OPS_LOG_E_IF_NULL(context, keyShape, return ge::GRAPH_FAILED); + gert::Shape *outShape = context->GetOutputShape(0); + + auto attrs = context->GetAttrs(); + OPS_LOG_E_IF_NULL(context, attrs, return ge::GRAPH_FAILED); + const char *inputLayoutQueryPtr = attrs->GetAttrPointer(ATTR_QUERY_LAYOUT_INDEX); + OPS_LOG_E_IF_NULL(context, inputLayoutQueryPtr, return ge::GRAPH_FAILED); + const char *inputLayoutKeyPtr = attrs->GetAttrPointer(ATTR_KV_LAYOUT_INDEX); + OPS_LOG_E_IF_NULL(context, inputLayoutKeyPtr, return ge::GRAPH_FAILED); + const int64_t *sparse_count = attrs->GetInt(ATTR_SPARSE_COUNT_INDEX); + OPS_LOG_E_IF_NULL(context, sparse_count, return ge::GRAPH_FAILED); + + std::string inputLayoutQueryPtrStr = std::string(inputLayoutQueryPtr); + std::string inputLayoutKeyPtrStr = std::string(inputLayoutKeyPtr); + if (inputLayoutQueryPtrStr != "TND" && inputLayoutQueryPtrStr != "BSND") { + OPS_LOG_E(context, "The input layout query should be TND or BSND, but got %s.", inputLayoutQueryPtrStr.c_str()); + return GRAPH_FAILED; + } + + outShape->SetDimNum(queryShape->GetDimNum()); + int64_t keyHeadNum = (inputLayoutKeyPtrStr == "TND") ? 
keyShape->GetDim(1) : keyShape->GetDim(2); + if (inputLayoutQueryPtrStr == "BSND") { + outShape->SetDim(0, queryShape->GetDim(0)); // 0:Dim B + outShape->SetDim(1, queryShape->GetDim(1)); // 1:Dim S + outShape->SetDim(2, keyHeadNum); // 2:Dim N + outShape->SetDim(3, *sparse_count); // 3:Dim K + } else { + outShape->SetDim(0, queryShape->GetDim(0)); // 0:Dim T + outShape->SetDim(1, keyHeadNum); // 1:output shape's N Dim, 2: key shape's N Dim + outShape->SetDim(2, *sparse_count); // 2:Dim K + } + + OPS_LOG_D(context->GetNodeName(), "LightningIndexerQuant InferShape end."); + return ge::GRAPH_SUCCESS; +} + +static ge::graphStatus InferDataTypeLightningIndexerQuant(gert::InferDataTypeContext *context) +{ + if (context == nullptr) { + OPS_LOG_E("LightningIndexerQuant", "InferDataTypeContext context is nullptr!"); + return ge::GRAPH_FAILED; + } + OPS_LOG_D(context->GetNodeName(), "Enter LightningIndexerQuant InferDataType impl."); + // default index data type is int32 + ge::DataType outputType = ge::DT_INT32; + context->SetOutputDataType(0, outputType); + OPS_LOG_D(context->GetNodeName(), "LightningIndexerQuant InferDataType end."); + return GRAPH_SUCCESS; +} + +IMPL_OP_INFERSHAPE(LightningIndexerQuant) + .InferShape(InferShapeLightningIndexerQuant) + .InferDataType(InferDataTypeLightningIndexerQuant); +} // namespace ops diff --git a/csrc/lightning_indexer_quant/op_host/lightning_indexer_quant_tiling.cpp b/csrc/lightning_indexer_quant/op_host/lightning_indexer_quant_tiling.cpp new file mode 100644 index 00000000..042f0271 --- /dev/null +++ b/csrc/lightning_indexer_quant/op_host/lightning_indexer_quant_tiling.cpp @@ -0,0 +1,828 @@ +/** + * This program is free software, you can redistribute it and/or modify it. + * Copyright (c) 2025 Huawei Technologies Co., Ltd. + * This file is a part of the CANN Open Software. + * Licensed under CANN Open Software License Agreement Version 2.0 (the "License"). + * Please refer to the License for details. 
You may not use this file except in compliance with the License. + * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. + * See LICENSE in the root of the software repository for the full text of the License. + */ + +/*! + * \file lightning_indexer_quant_tiling.cpp + * \brief + */ + +#include "lightning_indexer_quant_tiling.h" + +#include "../op_kernel/lightning_indexer_quant_template_tiling_key.h" + +using namespace ge; +using namespace AscendC; +using std::map; +using std::string; +namespace optiling { +// --------------------------LIQInfoParser类成员函数定义------------------------------------- +ge::graphStatus LIQInfoParser::CheckRequiredInOutExistence() const +{ + OPS_ERR_IF(opParamInfo_.query.shape == nullptr, OPS_LOG_E(opName_, "Shape of tensor query is nullptr"), + return ge::GRAPH_FAILED); + OPS_ERR_IF(opParamInfo_.query.desc == nullptr, OPS_LOG_E(opName_, "Desc of tensor query is nullptr"), + return ge::GRAPH_FAILED); + OPS_ERR_IF(opParamInfo_.key.shape == nullptr, OPS_LOG_E(opName_, "Shape of tensor key is nullptr"), + return ge::GRAPH_FAILED); + OPS_ERR_IF(opParamInfo_.key.desc == nullptr, OPS_LOG_E(opName_, "Desc of tensor key is nullptr"), + return ge::GRAPH_FAILED); + OPS_ERR_IF(opParamInfo_.weights.shape == nullptr, OPS_LOG_E(opName_, "Shape of tensor weights is nullptr"), + return ge::GRAPH_FAILED); + OPS_ERR_IF(opParamInfo_.weights.desc == nullptr, OPS_LOG_E(opName_, "Desc of tensor weights is nullptr"), + return ge::GRAPH_FAILED); + OPS_ERR_IF(opParamInfo_.query_dequant_scale.shape == nullptr, + OPS_LOG_E(opName_, "Shape of tensor query_dequant_scale is nullptr"), return ge::GRAPH_FAILED); + OPS_ERR_IF(opParamInfo_.query_dequant_scale.desc == nullptr, + OPS_LOG_E(opName_, "Desc of tensor query_dequant_scale is nullptr"), return ge::GRAPH_FAILED); + OPS_ERR_IF(opParamInfo_.key_dequant_scale.shape == 
nullptr, + OPS_LOG_E(opName_, "Shape of tensor key_dequant_scale is nullptr"), return ge::GRAPH_FAILED); + OPS_ERR_IF(opParamInfo_.key_dequant_scale.desc == nullptr, + OPS_LOG_E(opName_, "Desc of tensor key_dequant_scale is nullptr"), return ge::GRAPH_FAILED); + OPS_ERR_IF(opParamInfo_.attenOut.shape == nullptr, OPS_LOG_E(opName_, "Shape of tensor output is nullptr"), + return ge::GRAPH_FAILED); + OPS_ERR_IF(opParamInfo_.attenOut.desc == nullptr, OPS_LOG_E(opName_, "Desc of tensor output is nullptr"), + return ge::GRAPH_FAILED); + return ge::GRAPH_SUCCESS; +} + +ge::graphStatus LIQInfoParser::CheckRequiredAttrExistence() const +{ + OPS_ERR_IF(opParamInfo_.layOutQuery == nullptr, OPS_LOG_E(opName_, "attr layout_query is nullptr"), + return ge::GRAPH_FAILED); + + OPS_ERR_IF(opParamInfo_.layOutKey == nullptr, OPS_LOG_E(opName_, "attr layout_key is nullptr"), + return ge::GRAPH_FAILED); + + OPS_ERR_IF(opParamInfo_.sparseCount == nullptr, OPS_LOG_E(opName_, "attr sparse_count is nullptr"), + return ge::GRAPH_FAILED); + + OPS_ERR_IF(opParamInfo_.sparseMode == nullptr, OPS_LOG_E(opName_, "attr sparse_mode is nullptr"), + return ge::GRAPH_FAILED); + OPS_ERR_IF(opParamInfo_.queryQuantMode == nullptr, OPS_LOG_E(opName_, "query_quant_mode is nullptr"), + return ge::GRAPH_FAILED); + OPS_ERR_IF(opParamInfo_.keyQuantMode == nullptr, OPS_LOG_E(opName_, "key_quant_mode is nullptr"), + return ge::GRAPH_FAILED); + + return ge::GRAPH_SUCCESS; +} + +ge::graphStatus LIQInfoParser::CheckRequiredParaExistence() const +{ + if (CheckRequiredInOutExistence() != ge::GRAPH_SUCCESS || CheckRequiredAttrExistence() != ge::GRAPH_SUCCESS) { + return ge::GRAPH_FAILED; + } + + return ge::GRAPH_SUCCESS; +} + +ge::graphStatus LIQInfoParser::GetOpName() +{ + if (context_->GetNodeName() == nullptr) { + OPS_LOG_E("LightningIndexerQuant", "opName got from TilingContext is nullptr"); + return ge::GRAPH_FAILED; + } + opName_ = context_->GetNodeName(); + return ge::GRAPH_SUCCESS; +} + +ge::graphStatus 
LIQInfoParser::GetNpuInfo() +{ + platformInfo_ = context_->GetPlatformInfo(); + OPS_ERR_IF(platformInfo_ == nullptr, OPS_LOG_E(opName_, "GetPlatformInfo is nullptr."), return ge::GRAPH_FAILED); + + auto ascendcPlatform = platform_ascendc::PlatformAscendC(platformInfo_); + uint32_t aivNum = ascendcPlatform.GetCoreNumAiv(); + uint32_t aicNum = ascendcPlatform.GetCoreNumAic(); + OPS_ERR_IF(aicNum == 0 || aivNum == 0, OPS_LOG_E(opName_, "num of core obtained is 0."), return GRAPH_FAILED); + + socVersion_ = ascendcPlatform.GetSocVersion(); + if ((socVersion_ != platform_ascendc::SocVersion::ASCEND910B) && + (socVersion_ != platform_ascendc::SocVersion::ASCEND910_93)) { + OPS_LOG_E(opName_, "SOC Version[%d] is not support.", (int32_t)socVersion_); + return GRAPH_FAILED; + } + OPS_ERR_IF(context_->GetWorkspaceSizes(1) == nullptr, OPS_LOG_E(opName_, "workSpaceSize got from ge is nullptr"), + return ge::GRAPH_FAILED); + OPS_ERR_IF(context_->GetRawTilingData() == nullptr, + OPS_LOG_E(context_->GetNodeName(), "RawTilingData got from GE context is nullptr."), + return ge::GRAPH_FAILED); + + return ge::GRAPH_SUCCESS; +} + +void LIQInfoParser::GetOptionalInputParaInfo() +{ + opParamInfo_.actualSeqLengthsQ.tensor = context_->GetOptionalInputTensor(ACTUAL_SEQ_Q_INDEX); + opParamInfo_.actualSeqLengthsQ.desc = context_->GetOptionalInputDesc(ACTUAL_SEQ_Q_INDEX); + opParamInfo_.actualSeqLengthsK.tensor = context_->GetOptionalInputTensor(ACTUAL_SEQ_K_INDEX); + opParamInfo_.actualSeqLengthsK.desc = context_->GetOptionalInputDesc(ACTUAL_SEQ_K_INDEX); + opParamInfo_.blockTable.tensor = context_->GetOptionalInputTensor(BLOCK_TABLE_INDEX); + opParamInfo_.blockTable.desc = context_->GetOptionalInputDesc(BLOCK_TABLE_INDEX); +} + +void LIQInfoParser::GetInputParaInfo() +{ + opParamInfo_.query.desc = context_->GetInputDesc(QUERY_INDEX); + opParamInfo_.query.shape = context_->GetInputShape(QUERY_INDEX); + opParamInfo_.key.desc = context_->GetInputDesc(KEY_INDEX); + opParamInfo_.key.shape = 
context_->GetInputShape(KEY_INDEX); + opParamInfo_.weights.desc = context_->GetInputDesc(WEIGTHS_INDEX); + opParamInfo_.weights.shape = context_->GetInputShape(WEIGTHS_INDEX); + opParamInfo_.query_dequant_scale.desc = context_->GetInputDesc(QUERY_DEQUANT_SCALE_INDEX); + opParamInfo_.query_dequant_scale.shape = context_->GetInputShape(QUERY_DEQUANT_SCALE_INDEX); + opParamInfo_.key_dequant_scale.desc = context_->GetInputDesc(KEY_DEQUANT_SCALE_INDEX); + opParamInfo_.key_dequant_scale.shape = context_->GetInputShape(KEY_DEQUANT_SCALE_INDEX); + GetOptionalInputParaInfo(); +} + +void LIQInfoParser::GetOutputParaInfo() +{ + opParamInfo_.attenOut.desc = context_->GetOutputDesc(LIGHTNING_INDEXER_QUANT); + opParamInfo_.attenOut.shape = context_->GetOutputShape(LIGHTNING_INDEXER_QUANT); +} + +ge::graphStatus LIQInfoParser::GetAttrParaInfo() +{ + auto attrs = context_->GetAttrs(); + OPS_ERR_IF(attrs == nullptr, OPS_REPORT_VECTOR_INNER_ERR(context_->GetNodeName(), "attrs got from ge is nullptr"), + return ge::GRAPH_FAILED); + + OPS_LOG_I(context_->GetNodeName(), "GetAttrParaInfo start"); + opParamInfo_.layOutQuery = attrs->GetStr(ATTR_QUERY_LAYOUT_INDEX); + opParamInfo_.layOutKey = attrs->GetStr(ATTR_KEY_LAYOUT_INDEX); + + opParamInfo_.queryQuantMode = attrs->GetAttrPointer(ATTR_QUERY_QUANT_MODE_INDEX); + opParamInfo_.keyQuantMode = attrs->GetAttrPointer(ATTR_KEY_QUANT_MODE_INDEX); + opParamInfo_.layOutQuery = attrs->GetStr(ATTR_QUERY_LAYOUT_INDEX); + opParamInfo_.layOutKey = attrs->GetStr(ATTR_KEY_LAYOUT_INDEX); + opParamInfo_.sparseCount = attrs->GetAttrPointer(ATTR_SPARSE_COUNT_INDEX); + opParamInfo_.sparseMode = attrs->GetAttrPointer(ATTR_SPARSE_MODE_INDEX); + + if (opParamInfo_.layOutQuery != nullptr) { + OPS_LOG_I(context_->GetNodeName(), "layout_query is:%s", opParamInfo_.layOutQuery); + } + if (opParamInfo_.layOutKey != nullptr) { + OPS_LOG_I(context_->GetNodeName(), "layout_key is:%s", opParamInfo_.layOutKey); + } + if (opParamInfo_.sparseCount != nullptr) { + 
OPS_LOG_I(context_->GetNodeName(), "selected count is:%d", *opParamInfo_.sparseCount); + } + if (opParamInfo_.sparseMode != nullptr) { + OPS_LOG_I(context_->GetNodeName(), "sparse mode is:%d", *opParamInfo_.sparseMode); + } + if (opParamInfo_.queryQuantMode != nullptr) { + OPS_LOG_I(context_->GetNodeName(), "query_quant_mode mode is:%d", *opParamInfo_.queryQuantMode); + } + if (opParamInfo_.keyQuantMode != nullptr) { + OPS_LOG_I(context_->GetNodeName(), "key_quant_mode mode is:%d", *opParamInfo_.keyQuantMode); + } + OPS_LOG_I(context_->GetNodeName(), "GetAttrParaInfo end"); + + return ge::GRAPH_SUCCESS; +} + +ge::graphStatus LIQInfoParser::CheckAttrParaInfo() +{ + std::string layout_key(opParamInfo_.layOutKey); + std::string layout_query(opParamInfo_.layOutQuery); + OPS_ERR_IF( + ((std::string(opParamInfo_.layOutKey) == "BNSD") || (std::string(opParamInfo_.layOutKey) == "PA_BBND")), + OPS_LOG_E(opName_, "input attr layout_key only supported PA_BSND, BSND or TND, " + "but now layout_key is %s.", layout_key.c_str()), + return ge::GRAPH_FAILED); + OPS_ERR_IF(((std::string(opParamInfo_.layOutQuery) != "BSND") && (std::string(opParamInfo_.layOutQuery) != "TND")), + OPS_LOG_E(opName_, "input attr layout_query only supported BSND or TND."), return ge::GRAPH_FAILED); + OPS_ERR_IF( + ((std::string(opParamInfo_.layOutKey) != "PA_BSND") && + (std::string(opParamInfo_.layOutQuery)) != (std::string(opParamInfo_.layOutKey))), + OPS_LOG_E(opName_, "outside of PA, input attr layout_query and input attr layout_key must be the same, but now layout_key is %s, layout_query is %s.", + layout_key.c_str(), layout_query.c_str()), return ge::GRAPH_FAILED); + OPS_ERR_IF(!((*opParamInfo_.sparseCount > 0) && (*opParamInfo_.sparseCount <= SPARSE_LIMIT)), + OPS_LOG_E(opName_, "input attr sparse_count must be > 0 and <= 2048."), return ge::GRAPH_FAILED); + OPS_ERR_IF(!((*opParamInfo_.sparseMode == 0) || (*opParamInfo_.sparseMode == SPARSE_MODE_LOWER)), + OPS_LOG_E(opName_, "input attr
sparse_mode only supported 0 or 3, but now is %u.", + *opParamInfo_.sparseMode), return ge::GRAPH_FAILED); + + OPS_ERR_IF(*opParamInfo_.queryQuantMode != 0, OPS_LOG_E(opName_, "input attr query_quant_mode only supported 0."), + return ge::GRAPH_FAILED); + + OPS_ERR_IF(*opParamInfo_.keyQuantMode != 0, OPS_LOG_E(opName_, "input attr key_quant_mode only supported 0."), + return ge::GRAPH_FAILED); + + return ge::GRAPH_SUCCESS; +} + +ge::graphStatus LIQInfoParser::GetOpParaInfo() +{ + GetInputParaInfo(); + GetOutputParaInfo(); + if (ge::GRAPH_SUCCESS != GetAttrParaInfo()) { + return ge::GRAPH_FAILED; + } + if (ge::GRAPH_SUCCESS != CheckAttrParaInfo()) { + return ge::GRAPH_FAILED; + } + return ge::GRAPH_SUCCESS; +} + +ge::graphStatus LIQInfoParser::GetAndCheckInOutDataType() +{ + inputQType_ = opParamInfo_.query.desc->GetDataType(); + inputKType_ = opParamInfo_.key.desc->GetDataType(); + weightsType_ = opParamInfo_.weights.desc->GetDataType(); + inputQueryScaleType_ = opParamInfo_.query_dequant_scale.desc->GetDataType(); + inputKeyScaleType_ = opParamInfo_.key_dequant_scale.desc->GetDataType(); + outputType_ = opParamInfo_.attenOut.desc->GetDataType(); + + OPS_ERR_IF(!(inputQType_ == inputKType_), + OPS_LOG_E(opName_, "The data types of the input query and key must be the same, but now is %d, %d respectively.", + inputQType_, inputKType_), + return ge::GRAPH_FAILED); + + OPS_ERR_IF( + !(inputQueryScaleType_ == inputKeyScaleType_), + OPS_LOG_E(opName_, "The data types of the input query_dequant_scale and key_dequant_scale must be the same."), + return ge::GRAPH_FAILED); + + OPS_ERR_IF(inputQType_ != ge::DT_INT8, + OPS_LOG_E(opName_, "The data types of the input query and key must be int8."), return ge::GRAPH_FAILED); + + OPS_ERR_IF(weightsType_ != ge::DT_FLOAT16, + OPS_LOG_E(opName_, "The data types of the input weights must be float16."), return ge::GRAPH_FAILED); + + OPS_ERR_IF( + inputQueryScaleType_ != ge::DT_FLOAT16, + OPS_LOG_E(opName_, "The data types of the input
query_dequant_scale and key_dequant_scale must be float16."), + return ge::GRAPH_FAILED); + + OPS_ERR_IF(outputType_ != ge::DT_INT32, + OPS_LOG_E(opName_, "The data types of the output sparse_indices must be int32."), + return ge::GRAPH_FAILED); + + return ge::GRAPH_SUCCESS; +} + +ge::graphStatus LIQInfoParser::GetQueryKeyAndOutLayout() +{ + // 获取query,key的Layout基准值 + const map layoutQueryMap = {{"BSND", DataLayout::BSND}, {"TND", DataLayout::TND}}; + + std::string layout_query(opParamInfo_.layOutQuery); + auto QLayout_ = layoutQueryMap.find(layout_query); + if (QLayout_ != layoutQueryMap.end()) { + qLayout_ = QLayout_->second; + } + + const map layoutKeyMap = { + {"BSND", DataLayout::BSND}, {"TND", DataLayout::TND}, + {"PA_BSND", DataLayout::PA_BSND}, {"PA_BBND", DataLayout::PA_BSND}}; + std::string layout_key(opParamInfo_.layOutKey); + auto KLayout = layoutKeyMap.find(layout_key); + if (KLayout != layoutKeyMap.end()) { + kLayout_ = KLayout->second; + } + + return ge::GRAPH_SUCCESS; +} + +ge::graphStatus LIQInfoParser::GetAndCheckOptionalInput() +{ + if (kLayout_ == DataLayout::PA_BSND) { + OPS_ERR_IF(opParamInfo_.blockTable.tensor == nullptr, + OPS_LOG_E(opName_, "key layout only supported PA_BSND, input block_table must not be null"), + return ge::GRAPH_FAILED); + OPS_ERR_IF( + opParamInfo_.actualSeqLengthsK.tensor == nullptr, + OPS_LOG_E(opName_, "key layout only supported PA_BSND, input actual_seq_lengths_key must not be null"), + return ge::GRAPH_FAILED); + OPS_ERR_IF(opParamInfo_.blockTable.desc->GetDataType() != ge::DT_INT32, + OPS_LOG_E(opName_, "input block_table data type only support int32"), return ge::GRAPH_FAILED); + } else { + OPS_ERR_IF(opParamInfo_.blockTable.tensor != nullptr, + OPS_LOG_E(opName_, "key layout is not PA_BSND, input block_table must be null"), + return ge::GRAPH_FAILED); + } + + if (kLayout_ == DataLayout::TND) { + OPS_ERR_IF(opParamInfo_.actualSeqLengthsK.tensor == nullptr, + OPS_LOG_E(opName_, "when layout_key is TND, input 
actual_seq_lengths_key must not be null"), + return ge::GRAPH_FAILED); + } + OPS_ERR_IF(opParamInfo_.actualSeqLengthsK.tensor != nullptr && + opParamInfo_.actualSeqLengthsK.desc->GetDataType() != ge::DT_INT32, + OPS_LOG_E(opName_, "input actual_seq_lengths_key data type only support int32"), + return ge::GRAPH_FAILED); + if (qLayout_ == DataLayout::TND) { + OPS_ERR_IF(opParamInfo_.actualSeqLengthsQ.tensor == nullptr, + OPS_LOG_E(opName_, "when layout_query is TND, input actual_seq_lengths_query must not be null"), + return ge::GRAPH_FAILED); + } + OPS_ERR_IF(opParamInfo_.actualSeqLengthsQ.tensor != nullptr && + opParamInfo_.actualSeqLengthsQ.desc->GetDataType() != ge::DT_INT32, + OPS_LOG_E(opName_, "input actual_seq_lengths_query data type only support int32"), + return ge::GRAPH_FAILED); + + return ge::GRAPH_SUCCESS; +} + +ge::graphStatus LIQInfoParser::CheckShapeDim() +{ + OPS_ERR_IF((opParamInfo_.blockTable.tensor != nullptr) && + (opParamInfo_.blockTable.tensor->GetStorageShape().GetDimNum() != DIM_NUM_TWO), + OPS_LOG_E(opName_, "the dim num of block_table's shape should be 2, but now is %u", + opParamInfo_.blockTable.tensor->GetStorageShape().GetDimNum()), return ge::GRAPH_FAILED); + OPS_ERR_IF( + (kLayout_ == DataLayout::PA_BSND) && (opParamInfo_.key.shape->GetStorageShape().GetDimNum() != DIM_NUM_FOUR), + OPS_LOG_E(opName_, "the dim num of key's shape should be 4, but now is %u", + opParamInfo_.key.shape->GetStorageShape().GetDimNum()), return ge::GRAPH_FAILED); + + uint32_t qShapeDim = opParamInfo_.query.shape->GetStorageShape().GetDimNum(); + uint32_t weightsShapeDim = opParamInfo_.weights.shape->GetStorageShape().GetDimNum(); + uint32_t outShapeDim = opParamInfo_.attenOut.shape->GetStorageShape().GetDimNum(); + uint32_t expectShapeDim = DIM_NUM_FOUR; + if (qLayout_ == DataLayout::TND) { + expectShapeDim = DIM_NUM_THREE; + } + OPS_ERR_IF( + qShapeDim != expectShapeDim, + OPS_LOG_E(opName_, "the dim num of query's shape should be %u, but now is %u", 
expectShapeDim, qShapeDim), + return ge::GRAPH_FAILED); + OPS_ERR_IF(outShapeDim != expectShapeDim, + OPS_LOG_E(opName_, "the dim num of sparse_indices's shape should be %u, but now is %u", expectShapeDim, + outShapeDim), + return ge::GRAPH_FAILED); + OPS_ERR_IF(!(weightsShapeDim == expectShapeDim - 1), + OPS_LOG_E(opName_, "the dim num of weights's shape should be %u, but now is %u", expectShapeDim - 1, + weightsShapeDim), + return ge::GRAPH_FAILED); + + return ge::GRAPH_SUCCESS; +} + +ge::graphStatus LIQInfoParser::GetN1Size() +{ + if (qLayout_ == DataLayout::BSND) { + n1Size_ = static_cast(opParamInfo_.query.shape->GetStorageShape().GetDim(DIM_IDX_TWO)); + } else { + // TND + n1Size_ = static_cast(opParamInfo_.query.shape->GetStorageShape().GetDim(DIM_IDX_ONE)); + } + OPS_LOG_I(context_->GetNodeName(), "n1Size is %d", n1Size_); + return ge::GRAPH_SUCCESS; +} + +ge::graphStatus LIQInfoParser::GetActualSeqLenSize(uint32_t &size, const gert::Tensor *tensor, + const std::string &actualSeqLenName) +{ + size = static_cast(tensor->GetShapeSize()); + if (size <= 0) { + OPS_LOG_E(opName_, "%s's shape size is %u, it should be greater than 0.", actualSeqLenName.c_str(), size); + return ge::GRAPH_FAILED; + } + return ge::GRAPH_SUCCESS; +} + +ge::graphStatus LIQInfoParser::GetAndCheckN2Size() +{ + // PA_BSND + if (kLayout_ == DataLayout::TND) { + n2Size_ = static_cast(opParamInfo_.key.shape->GetStorageShape().GetDim(DIM_IDX_ONE)); + } else { + n2Size_ = static_cast(opParamInfo_.key.shape->GetStorageShape().GetDim(DIM_IDX_TWO)); + } + OPS_LOG_I(context_->GetNodeName(), "N2 is %d", n2Size_); + OPS_ERR_IF(n2Size_ != 1, OPS_LOG_E(opName_, "key shape[2] is numhead, only support 1."), return ge::GRAPH_FAILED); + + return ge::GRAPH_SUCCESS; +} + +ge::graphStatus LIQInfoParser::GetGSize() +{ + if (n1Size_ % n2Size_ != 0) { + OPS_LOG_E(opName_, "input query's head_num %u is not a multiple of key's head_num %u.", n1Size_, n2Size_); + return ge::GRAPH_FAILED; + } + gSize_ = n1Size_
/ n2Size_; + + return ge::GRAPH_SUCCESS; +} + +ge::graphStatus LIQInfoParser::GetBatchSize() +{ + // 获取B基准值 + // 1、非TND/NTD时, 以query的batch_size维度为基准; + // 2、TND/NTD时, actual_seq_lens_q必须传入, 以actual_seq_lens_q数组的长度为B轴大小 + if (qLayout_ == DataLayout::TND) { + return GetActualSeqLenSize(bSize_, opParamInfo_.actualSeqLengthsQ.tensor, "input actual_seq_lengths_query"); + } else { // BSND + bSize_ = opParamInfo_.query.shape->GetStorageShape().GetDim(DIM_IDX_ZERO); + OPS_LOG_I(context_->GetNodeName(), "b: %d, s: %d, n: %d,d :%d", + opParamInfo_.query.shape->GetStorageShape().GetDim(DIM_IDX_ZERO), + opParamInfo_.query.shape->GetStorageShape().GetDim(DIM_IDX_ONE), + opParamInfo_.query.shape->GetStorageShape().GetDim(DIM_IDX_TWO), + opParamInfo_.query.shape->GetStorageShape().GetDim(DIM_IDX_THREE)); + return ge::GRAPH_SUCCESS; + } +} + +ge::graphStatus LIQInfoParser::GetHeadDim() +{ + // 以query的D维度为基准 + uint32_t dIndex = DIM_IDX_TWO; + // 根据layout确定D维度在shape中的位置 + switch (qLayout_) { + case DataLayout::TND: + // TND格式: [Total, N, D] -> D是第2维(索引2) + dIndex = DIM_IDX_TWO; + break; + case DataLayout::BSND: + // BSND格式: [Batch, SeqLen, N, D] -> D是第3维(索引3) + dIndex = DIM_IDX_THREE; + break; + default: + OPS_LOG_E(opName_, "unsupported layout for getting head dim."); + return ge::GRAPH_FAILED; + } + headDim_ = opParamInfo_.query.shape->GetStorageShape().GetDim(dIndex); + OPS_ERR_IF(headDim_ != HEAD_DIM_LIMIT, OPS_LOG_E(opName_, "input query's last dim head_dim only support 128, but now is %u.", headDim_), + return ge::GRAPH_FAILED); + + return ge::GRAPH_SUCCESS; +} + +ge::graphStatus LIQInfoParser::GetS1Size() +{ + if (qLayout_ == DataLayout::BSND) { + s1Size_ = opParamInfo_.query.shape->GetStorageShape().GetDim(1); + } + return ge::GRAPH_SUCCESS; +} + +ge::graphStatus LIQInfoParser::GetAndCheckBlockSize() +{ + blockSize_ = static_cast(opParamInfo_.key.shape->GetStorageShape().GetDim(1)); + OPS_LOG_I(context_->GetNodeName(), "blockSize_ is %d", blockSize_); + + OPS_ERR_IF( + 
((blockSize_ % BLOCK_SIZE_FACTOR != 0) || (blockSize_ == 0) || (blockSize_ > BLOCK_SIZE_LIMIT)), + OPS_LOG_E(opName_, "input key's block_size must be a multiple of 16 and belong to (0, 1024], but now is %u.", + blockSize_), + return ge::GRAPH_FAILED); + + return ge::GRAPH_SUCCESS; +} + +ge::graphStatus LIQInfoParser::GetS2SizeForPageAttention() +{ + if (GetAndCheckBlockSize() != ge::GRAPH_SUCCESS) { + return ge::GRAPH_FAILED; + } + + int32_t blockCount_ = static_cast(opParamInfo_.key.shape->GetStorageShape().GetDim(0)); + OPS_ERR_IF((blockCount_ == 0), OPS_LOG_E(opName_, "input key's block_count cannot be 0."), return ge::GRAPH_FAILED); + + maxBlockNumPerBatch_ = opParamInfo_.blockTable.tensor->GetStorageShape().GetDim(1); + s2Size_ = maxBlockNumPerBatch_ * blockSize_; + OPS_LOG_I(context_->GetNodeName(), "maxBlockNumPerBatch_ is %d, blockSize_ is %d, s2Size_ is %d", + maxBlockNumPerBatch_, blockSize_, s2Size_); + return ge::GRAPH_SUCCESS; +} + +ge::graphStatus LIQInfoParser::GetS2SizeForBatchContinuous() +{ + std::string layout_key(opParamInfo_.layOutKey); + if (kLayout_ == DataLayout::BSND) { + s2Size_ = opParamInfo_.key.shape->GetStorageShape().GetDim(DIM_IDX_ONE); + } else if (kLayout_ == DataLayout::TND) { + s2Size_ = opParamInfo_.key.shape->GetStorageShape().GetDim(DIM_IDX_ZERO); + } + OPS_ERR_IF((kLayout_ != DataLayout::BSND) && (kLayout_ != DataLayout::TND), + OPS_LOG_E(opName_, "the layout of key is %s, it is unsupported.", layout_key.c_str()), + return ge::GRAPH_FAILED); + return ge::GRAPH_SUCCESS; +} + +ge::graphStatus LIQInfoParser::GetS2Size() +{ + // 获取S2基准值 + // 1、BATCH_CONTINUOUS时, 从key的S轴获取 + // 3、PAGE_ATTENTION时, S2 = block_table.dim1 * block_size + if (kLayout_ == DataLayout::PA_BSND) { + return GetS2SizeForPageAttention(); + } + return GetS2SizeForBatchContinuous(); +} + +ge::graphStatus LIQInfoParser::ValidateInputShapesMatch() +{ + /* + TND: + query [T,N1,D], + key [BlockNum,BlockSize,N2,D], + weight [T,N1], + block_table [BatchSize, 
BatchMaxBlockNum], + act_seq_k [BatchSize] + act_seq_q [BatchSize], + out [T,N2,topk] + ---------------------- + BSND: + query [BatchSize,S1,N1,D], + key [BlockNum,BlockSize,N2,D], + weight [BatchSize,S1,N1], + block_table [BatchSize, BatchMaxBlockNum], + act_seq_k [BatchSize] + act_seq_q [BatchSize] 可选 + out [BatchSize,S1,N2,topk] + */ + uint32_t queryWeightsN1Dim = 1; + uint32_t outN2Dim = 1; + + if (qLayout_ == DataLayout::TND) { + // -----------------------check BatchSize------------------- + // bSize_ 来源于act_seq_q + OPS_ERR_IF((kLayout_ == DataLayout::PA_BSND) && + ((opParamInfo_.actualSeqLengthsK.tensor->GetShapeSize() != bSize_) || + (opParamInfo_.blockTable.tensor != nullptr && + opParamInfo_.blockTable.tensor->GetStorageShape().GetDim(0) != bSize_)), + OPS_LOG_E( + opName_, + "TND case input actual_seq_lengths_query, actual_seq_lengths_key, block_table dim 0 are %u, %u, %u respectively, they must be same.", + bSize_, opParamInfo_.actualSeqLengthsK.tensor->GetShapeSize(), + opParamInfo_.blockTable.tensor->GetStorageShape().GetDim(0)), + return ge::GRAPH_FAILED); + OPS_ERR_IF((kLayout_ != DataLayout::PA_BSND) && + (opParamInfo_.actualSeqLengthsK.tensor->GetShapeSize() != bSize_), + OPS_LOG_E( + opName_, + "TND case input actual_seq_lengths_query, actual_seq_lengths_key, are %u, %u respectively, they must be same.", + bSize_, opParamInfo_.actualSeqLengthsK.tensor->GetShapeSize()), + return ge::GRAPH_FAILED); + // -----------------------check T------------------- + uint32_t qTsize = opParamInfo_.query.shape->GetStorageShape().GetDim(0); + OPS_ERR_IF((opParamInfo_.weights.shape->GetStorageShape().GetDim(0) != qTsize) || + (opParamInfo_.attenOut.shape->GetStorageShape().GetDim(0) != qTsize), + OPS_LOG_E(opName_, + "TND case input query, weights, sparse_indices dim 0 are %u, %u, %u respectively, they must be same.", + qTsize, opParamInfo_.weights.shape->GetStorageShape().GetDim(0), + opParamInfo_.attenOut.shape->GetStorageShape().GetDim(0)), + return 
ge::GRAPH_FAILED); + } else { + // -----------------------check BatchSize------------------- + // bSize_ 来源于query + OPS_ERR_IF((kLayout_ == DataLayout::PA_BSND) && + ((opParamInfo_.weights.shape->GetStorageShape().GetDim(0) != bSize_) || + (opParamInfo_.blockTable.tensor != nullptr && + opParamInfo_.blockTable.tensor->GetStorageShape().GetDim(0) != bSize_) || + (opParamInfo_.actualSeqLengthsK.tensor->GetShapeSize() != bSize_) || + (opParamInfo_.attenOut.shape->GetStorageShape().GetDim(0) != bSize_)), + OPS_LOG_E(opName_, + "BSND case input query, weight, actual_seq_lengths_key, block_table, sparse_indices dim 0 are %u, %u, %u, %u, %u respectively, they must be same.", + bSize_, opParamInfo_.weights.shape->GetStorageShape().GetDim(0), + opParamInfo_.actualSeqLengthsK.tensor->GetShapeSize(), + opParamInfo_.blockTable.tensor->GetStorageShape().GetDim(0), + opParamInfo_.attenOut.shape->GetStorageShape().GetDim(0)), + return ge::GRAPH_FAILED); + OPS_ERR_IF((kLayout_ != DataLayout::PA_BSND) && + ((opParamInfo_.weights.shape->GetStorageShape().GetDim(0) != bSize_) || + (opParamInfo_.actualSeqLengthsK.tensor != nullptr && + opParamInfo_.actualSeqLengthsK.tensor->GetShapeSize() != bSize_) || + (opParamInfo_.attenOut.shape->GetStorageShape().GetDim(0) != bSize_)), + OPS_LOG_E(opName_, + "BSND case input query, weight, actual_seq_lengths_key, sparse_indices dim 0 are %u, %u, %u, %u respectively, they must be same.", + bSize_, opParamInfo_.weights.shape->GetStorageShape().GetDim(0), + opParamInfo_.actualSeqLengthsK.tensor->GetShapeSize(), + opParamInfo_.attenOut.shape->GetStorageShape().GetDim(0)), + return ge::GRAPH_FAILED); + OPS_ERR_IF( + (opParamInfo_.actualSeqLengthsQ.tensor != nullptr) && + (opParamInfo_.actualSeqLengthsQ.tensor->GetShapeSize() != bSize_), + OPS_LOG_E( + opName_, + "BSND case input query, actual_seq_lengths_query dim 0 are %u, %u respectively, they must be same", + bSize_, opParamInfo_.actualSeqLengthsQ.tensor->GetShapeSize()), + return 
ge::GRAPH_FAILED); + // -----------------------check S1------------------- + OPS_ERR_IF( + (opParamInfo_.weights.shape->GetStorageShape().GetDim(1) != s1Size_) || + (opParamInfo_.attenOut.shape->GetStorageShape().GetDim(1) != s1Size_), + OPS_LOG_E(opName_, "BSND case input query, weight, sparse_indices dim 1 are %u, %u, %u, they must be same.", + s1Size_, opParamInfo_.weights.shape->GetStorageShape().GetDim(1), + opParamInfo_.attenOut.shape->GetStorageShape().GetDim(1)), + return ge::GRAPH_FAILED); + queryWeightsN1Dim = DIM_IDX_TWO; + outN2Dim = DIM_IDX_TWO; + } + // -----------------------check N1------------------- + OPS_ERR_IF((opParamInfo_.weights.shape->GetStorageShape().GetDim(queryWeightsN1Dim) != n1Size_), + OPS_LOG_E(opName_, "input query, weight shape dim N1 must be same, but now are %u, %u respectively.", + opParamInfo_.weights.shape->GetStorageShape().GetDim(queryWeightsN1Dim), n1Size_), + return ge::GRAPH_FAILED); + // -----------------------check D------------------- + OPS_ERR_IF( + ((kLayout_ != DataLayout::TND && opParamInfo_.key.shape->GetStorageShape().GetDim(DIM_IDX_THREE) != headDim_) + || (kLayout_ == DataLayout::TND && opParamInfo_.key.shape->GetStorageShape().GetDim(DIM_IDX_TWO) != headDim_)), + OPS_LOG_E(opName_, "input query, key shape last dim must be same, now are %u, %u respectively.", + headDim_, opParamInfo_.key.shape->GetStorageShape().GetDim(DIM_IDX_THREE)), + return ge::GRAPH_FAILED); + // -----------------------check N2------------------- + OPS_ERR_IF((opParamInfo_.attenOut.shape->GetStorageShape().GetDim(outN2Dim) != n2Size_), + OPS_LOG_E(opName_, "input query and output sparse_indices shape n2 dim must be same."), + return ge::GRAPH_FAILED); + // -----------------------check sparse_count------------------- + OPS_ERR_IF((opParamInfo_.attenOut.shape->GetStorageShape().GetDim(outN2Dim + 1) != *opParamInfo_.sparseCount), + OPS_LOG_E(opName_, "output sparse_indices shape last dim must be same as attr sparse_count."), + return 
ge::GRAPH_FAILED); + + return ge::GRAPH_SUCCESS; +} + +ge::graphStatus LIQInfoParser::CheckScaleShape() +{ + uint32_t qShapeDim = opParamInfo_.query.shape->GetStorageShape().GetDimNum(); + uint32_t kShapeDim = opParamInfo_.key.shape->GetStorageShape().GetDimNum(); + uint32_t qDequantScaleShapeDim = opParamInfo_.query_dequant_scale.shape->GetStorageShape().GetDimNum(); + uint32_t kDequantScaleShapeDim = opParamInfo_.key_dequant_scale.shape->GetStorageShape().GetDimNum(); + OPS_ERR_IF(qDequantScaleShapeDim != (qShapeDim - 1), + OPS_LOG_E(opName_, "the dim num of query_dequant_scale's shape should be %u, but now is %u", + qShapeDim - 1, qDequantScaleShapeDim), + return ge::GRAPH_FAILED); + OPS_ERR_IF(kDequantScaleShapeDim != (kShapeDim - 1), + OPS_LOG_E(opName_, "the dim num of key_dequant_scale's shape should be %u, but now is %u", kShapeDim - 1, + kDequantScaleShapeDim), + return ge::GRAPH_FAILED); + // check q scale + for (uint32_t i = 0; i < (qShapeDim - 1); i++) { + uint32_t dimValueQueryScale = opParamInfo_.query_dequant_scale.shape->GetStorageShape().GetDim(i); + uint32_t dimValueQuery = opParamInfo_.query.shape->GetStorageShape().GetDim(i); + OPS_ERR_IF(dimValueQueryScale != dimValueQuery, + OPS_LOG_E(opName_, "query_dequant_scale's shape[%u] %u and query's shape[%u] %u is not same", i, + dimValueQueryScale, i, dimValueQuery), + return ge::GRAPH_FAILED); + } + // check k scale + for (uint32_t i = 0; i < (kShapeDim - 1); i++) { + uint32_t dimValueKeyScale = opParamInfo_.key_dequant_scale.shape->GetStorageShape().GetDim(i); + uint32_t dimValueKey = opParamInfo_.key.shape->GetStorageShape().GetDim(i); + OPS_ERR_IF(dimValueKeyScale != dimValueKey, + OPS_LOG_E(opName_, "key_dequant_scale's shape[%u] %u and key's shape[%u] %u is not same", i, + dimValueKeyScale, i, dimValueKey), + return ge::GRAPH_FAILED); + } + + return ge::GRAPH_SUCCESS; +} + +void LIQInfoParser::GenerateInfo(LIQTilingInfo &liqInfo) +{ + liqInfo.opName = opName_; + liqInfo.platformInfo = 
platformInfo_; + liqInfo.opParamInfo = opParamInfo_; + liqInfo.socVersion = socVersion_; + + liqInfo.bSize = bSize_; + liqInfo.n1Size = n1Size_; + liqInfo.n2Size = n2Size_; + liqInfo.s1Size = s1Size_; + liqInfo.s2Size = s2Size_; + liqInfo.gSize = gSize_; + + liqInfo.inputQType = inputQType_; + liqInfo.inputKType = inputKType_; + liqInfo.outputType = outputType_; + + liqInfo.blockSize = blockSize_; + liqInfo.maxBlockNumPerBatch = maxBlockNumPerBatch_; + + liqInfo.pageAttentionFlag = (kLayout_ == DataLayout::PA_BSND); + liqInfo.sparseMode = *opParamInfo_.sparseMode; + liqInfo.sparseCount = *opParamInfo_.sparseCount; + + liqInfo.inputQLayout = qLayout_; + liqInfo.inputKLayout = kLayout_; +} + +ge::graphStatus LIQInfoParser::ParseAndCheck(LIQTilingInfo &liqInfo) +{ + if (ge::GRAPH_SUCCESS != GetOpName() || ge::GRAPH_SUCCESS != GetNpuInfo() || ge::GRAPH_SUCCESS != GetOpParaInfo() || + ge::GRAPH_SUCCESS != CheckRequiredParaExistence()) { + return ge::GRAPH_FAILED; + } + + if (ge::GRAPH_SUCCESS != GetAndCheckInOutDataType() || ge::GRAPH_SUCCESS != GetQueryKeyAndOutLayout() || + ge::GRAPH_SUCCESS != GetAndCheckOptionalInput()) { + return ge::GRAPH_FAILED; + } + + if (ge::GRAPH_SUCCESS != CheckShapeDim() || ge::GRAPH_SUCCESS != GetN1Size() || + ge::GRAPH_SUCCESS != GetAndCheckN2Size() || ge::GRAPH_SUCCESS != GetGSize()) { + return ge::GRAPH_FAILED; + } + + if (ge::GRAPH_SUCCESS != GetBatchSize() || ge::GRAPH_SUCCESS != GetS1Size() || ge::GRAPH_SUCCESS != GetHeadDim() || + ge::GRAPH_SUCCESS != GetS2Size()) { + return ge::GRAPH_FAILED; + } + if (ge::GRAPH_SUCCESS != ValidateInputShapesMatch() || ge::GRAPH_SUCCESS != CheckScaleShape()) { + return ge::GRAPH_FAILED; + } + + GenerateInfo(liqInfo); + + return ge::GRAPH_SUCCESS; +} + +// --------------------------TilingPrepare函数定义------------------------------------- +static ge::graphStatus TilingPrepareForLightningIndexerQuant(gert::TilingParseContext * /* context */) +{ + return ge::GRAPH_SUCCESS; +} + +// 
--------------------------LightningIndexerQuantTiling类成员函数定义----------------------- +ge::graphStatus LightningIndexerQuantTiling::DoTiling(LIQTilingInfo *tilingInfo) +{ + // -------------set blockdim----------------- + auto ascendcPlatform = platform_ascendc::PlatformAscendC(tilingInfo->platformInfo); + uint32_t aivNum = ascendcPlatform.GetCoreNumAiv(); + uint32_t aicNum = ascendcPlatform.GetCoreNumAic(); + uint32_t blockDim = ascendcPlatform.CalcTschBlockDim(aivNum, aicNum, aivNum); + context_->SetBlockDim(blockDim); + + // -------------set workspacesize----------------- + constexpr uint32_t MM1_RES_ELEM_SIZE = 4; // 4: fp32 + constexpr uint32_t DOUBLE_BUFFER = 2; // 双Buffer + constexpr uint32_t M_BASE_SIZE = 512; // m轴基本块大小 + constexpr uint32_t S2_BASE_SIZE = 512; // S2轴基本块大小 + constexpr uint32_t V1_RES_ELEM_SIZE = 4; // 4: int32 + constexpr uint32_t V1_RES_ELEM_TYPE = 2; // 保留Index和Value 2种数据 + constexpr uint32_t V1_DECODE_PARAM_ELEM_SIZE = 8; // 8: int64 + constexpr uint32_t V1_DECODE_PARAM_NUM = 16; // Decode参数个数 + constexpr uint32_t V1_DECODE_DATA_NUM = 2; // Decode每个核需要存储头和尾部两块数据 + constexpr uint32_t S1_BASE_SIZE = 8; // S1轴基本块的大小 + constexpr uint32_t TOPK_MAX_SIZE = 2048; // TopK选取个数 + uint32_t workspaceSize = ascendcPlatform.GetLibApiWorkSpaceSize(); + // 主流程需Workspace大小 + uint32_t mm1ResSize = M_BASE_SIZE * S2_BASE_SIZE; + workspaceSize += mm1ResSize * MM1_RES_ELEM_SIZE * DOUBLE_BUFFER * aicNum; + // Decode流程(LD)需要Workspace大小 + // 临时存储Decode中间结果大小: 2(头/尾)*8(s1Base)*2(idx/value)*2048(K)*sizeof(int32)*24=6M + workspaceSize += V1_DECODE_DATA_NUM * S1_BASE_SIZE * V1_RES_ELEM_TYPE * TOPK_MAX_SIZE * V1_RES_ELEM_SIZE * aicNum; + // 临时存储Decode中间参数信息大小: 2(头/尾)*8(s1Base)*16(paramNum)*sizeof(int64_t)*24=48k + workspaceSize += V1_DECODE_DATA_NUM * S1_BASE_SIZE * V1_DECODE_PARAM_NUM * V1_DECODE_PARAM_ELEM_SIZE * aicNum; + size_t *workSpaces = context_->GetWorkspaceSizes(1); + workSpaces[0] = workspaceSize; + + // -------------set tilingdata----------------- + 
tilingData_.set_bSize(tilingInfo->bSize); + tilingData_.set_s2Size(tilingInfo->s2Size); + tilingData_.set_s1Size(tilingInfo->s1Size); + tilingData_.set_sparseCount(tilingInfo->sparseCount); + tilingData_.set_gSize(tilingInfo->gSize); + tilingData_.set_blockSize(tilingInfo->blockSize); + tilingData_.set_maxBlockNumPerBatch(tilingInfo->maxBlockNumPerBatch); + tilingData_.set_sparseMode(tilingInfo->sparseMode); + tilingData_.set_usedCoreNum(blockDim); + tilingData_.SaveToBuffer(context_->GetRawTilingData()->GetData(), context_->GetRawTilingData()->GetCapacity()); + context_->GetRawTilingData()->SetDataSize(tilingData_.GetDataSize()); + + // -------------set tilingkey----------------- + // DT_Q, DT_KV, DT_OUT, PAGE_ATTENTION, FLASH_DECODE, LAYOUT_T, KV_LAYOUT_T + uint32_t inputQType = static_cast(tilingInfo->inputQType); + uint32_t inputKType = static_cast(tilingInfo->inputKType); + uint32_t outputType = static_cast(tilingInfo->outputType); + uint32_t pageAttentionFlag = static_cast(tilingInfo->pageAttentionFlag); + uint32_t inputQLayout = static_cast(tilingInfo->inputQLayout); + uint32_t inputKLayout = static_cast(tilingInfo->inputKLayout); + uint32_t tilingKey = + GET_TPL_TILING_KEY(inputQType, inputKType, outputType, pageAttentionFlag, inputQLayout, inputKLayout); + context_->SetTilingKey(tilingKey); + + return ge::GRAPH_SUCCESS; +} + +// --------------------------Tiling函数定义--------------------------- +ge::graphStatus TilingForLightningIndexerQuant(gert::TilingContext *context) +{ + OPS_ERR_IF(context == nullptr, OPS_REPORT_VECTOR_INNER_ERR("LightningIndexerQuant", "Tiling context is null."), + return ge::GRAPH_FAILED); + LIQTilingInfo liqInfo; + LIQInfoParser LIQInfoParser(context); + if (LIQInfoParser.ParseAndCheck(liqInfo) != ge::GRAPH_SUCCESS) { + return ge::GRAPH_FAILED; + } + LightningIndexerQuantTiling liqTiling(context); + return liqTiling.DoTiling(&liqInfo); +} + +// --------------------------Tiling及函数TilingPrepare函数注册-------- 
+IMPL_OP_OPTILING(LightningIndexerQuant) + .Tiling(TilingForLightningIndexerQuant) + .TilingParse(TilingPrepareForLightningIndexerQuant); + +} // namespace optiling diff --git a/csrc/lightning_indexer_quant/op_host/lightning_indexer_quant_tiling.h b/csrc/lightning_indexer_quant/op_host/lightning_indexer_quant_tiling.h new file mode 100644 index 00000000..d51a51e7 --- /dev/null +++ b/csrc/lightning_indexer_quant/op_host/lightning_indexer_quant_tiling.h @@ -0,0 +1,234 @@ +/** + * This program is free software, you can redistribute it and/or modify it. + * Copyright (c) 2025 Huawei Technologies Co., Ltd. + * This file is a part of the CANN Open Software. + * Licensed under CANN Open Software License Agreement Version 2.0 (the "License"). + * Please refer to the License for details. You may not use this file except in compliance with the License. + * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. + * See LICENSE in the root of the software repository for the full text of the License. + */ + +/*! 
+ * \file lightning_indexer_quant_tiling.h + * \brief + */ + +#ifndef LIGHTNING_INDEXER_QUANT_TILING_H_ +#define LIGHTNING_INDEXER_QUANT_TILING_H_ + +#include "error/ops_error.h" +#include "exe_graph/runtime/tiling_context.h" +#include "platform/platform_info.h" +#include "register/op_def_registry.h" +#include "register/tilingdata_base.h" +#include "tiling/platform/platform_ascendc.h" +#include "tiling/tiling_api.h" + +namespace optiling { +// ------------------公共定义-------------------------- +struct TilingRequiredParaInfo { + const gert::CompileTimeTensorDesc *desc; + const gert::StorageShape *shape; +}; + +struct TilingOptionalParaInfo { + const gert::CompileTimeTensorDesc *desc; + const gert::Tensor *tensor; +}; + +enum class DataLayout : uint32_t { + BSND = 0, + TND = 1, + PA_BSND = 2 +}; + +// ------------------算子原型索引常量定义---------------- +// Inputs Index +constexpr uint32_t QUERY_INDEX = 0; +constexpr uint32_t KEY_INDEX = 1; +constexpr uint32_t WEIGTHS_INDEX = 2; +constexpr uint32_t QUERY_DEQUANT_SCALE_INDEX = 3; +constexpr uint32_t KEY_DEQUANT_SCALE_INDEX = 4; +constexpr uint32_t ACTUAL_SEQ_Q_INDEX = 5; +constexpr uint32_t ACTUAL_SEQ_K_INDEX = 6; +constexpr uint32_t BLOCK_TABLE_INDEX = 7; +constexpr uint32_t LIGHTNING_INDEXER_QUANT = 0; +// Attributes Index +constexpr uint32_t ATTR_QUERY_QUANT_MODE_INDEX = 0; +constexpr uint32_t ATTR_KEY_QUANT_MODE_INDEX = 1; +constexpr uint32_t ATTR_QUERY_LAYOUT_INDEX = 2; +constexpr uint32_t ATTR_KEY_LAYOUT_INDEX = 3; +constexpr uint32_t ATTR_SPARSE_COUNT_INDEX = 4; +constexpr uint32_t ATTR_SPARSE_MODE_INDEX = 5; +// Dim Index +constexpr uint32_t DIM_IDX_ZERO = 0; +constexpr uint32_t DIM_IDX_ONE = 1; +constexpr uint32_t DIM_IDX_TWO = 2; +constexpr uint32_t DIM_IDX_THREE = 3; +// Dim Num +constexpr uint32_t DIM_NUM_TWO = 2; +constexpr uint32_t DIM_NUM_THREE = 3; +constexpr uint32_t DIM_NUM_FOUR = 4; +// 入参限制常量 +constexpr uint32_t HEAD_DIM_LIMIT = 128; +constexpr uint32_t SPARSE_LIMIT = 2048; +constexpr uint32_t G_SIZE_LIMIT = 
64; +constexpr uint32_t BLOCK_SIZE_LIMIT = 1024; +constexpr uint32_t BLOCK_SIZE_FACTOR = 16; +constexpr uint32_t SPARSE_MODE_LOWER = 3; + +// -----------算子TilingData定义--------------- +BEGIN_TILING_DATA_DEF(LIQTilingData) +TILING_DATA_FIELD_DEF(uint32_t, bSize) +TILING_DATA_FIELD_DEF(uint32_t, n2Size) +TILING_DATA_FIELD_DEF(uint32_t, gSize) +TILING_DATA_FIELD_DEF(uint32_t, s1Size) +TILING_DATA_FIELD_DEF(uint32_t, s2Size) +TILING_DATA_FIELD_DEF(uint32_t, sparseCount) +TILING_DATA_FIELD_DEF(uint32_t, usedCoreNum) +TILING_DATA_FIELD_DEF(uint32_t, blockSize) +TILING_DATA_FIELD_DEF(uint32_t, maxBlockNumPerBatch) +TILING_DATA_FIELD_DEF(uint32_t, sparseMode) +END_TILING_DATA_DEF +REGISTER_TILING_DATA_CLASS(LightningIndexerQuant, LIQTilingData) + +// -----------算子CompileInfo定义------------------- +struct LIQCompileInfo {}; + +// -----------算子Tiling入参结构体定义--------------- +struct LIQParaInfo { + TilingRequiredParaInfo query = {nullptr, nullptr}; + TilingRequiredParaInfo key = {nullptr, nullptr}; + TilingRequiredParaInfo weights = {nullptr, nullptr}; + TilingRequiredParaInfo query_dequant_scale = {nullptr, nullptr}; + TilingRequiredParaInfo key_dequant_scale = {nullptr, nullptr}; + TilingOptionalParaInfo actualSeqLengthsQ = {nullptr, nullptr}; + TilingOptionalParaInfo actualSeqLengthsK = {nullptr, nullptr}; + TilingOptionalParaInfo blockTable = {nullptr, nullptr}; + TilingRequiredParaInfo attenOut = {nullptr, nullptr}; + + const int32_t *queryQuantMode = nullptr; + const int32_t *keyQuantMode = nullptr; + const char *layOutQuery = nullptr; + const char *layOutKey = nullptr; + const int32_t *blockSize = nullptr; + const int32_t *sparseMode = nullptr; + const int32_t *sparseCount = nullptr; +}; + +// -----------算子Tiling入参信息类--------------- +class LIQTilingInfo { +public: + const char *opName = nullptr; + fe::PlatFormInfos *platformInfo = nullptr; + LIQParaInfo opParamInfo; + // Base Param + platform_ascendc::SocVersion socVersion = platform_ascendc::SocVersion::ASCEND910B; + 
uint32_t bSize = 0; + uint32_t n1Size = 0; + uint32_t n2Size = 0; + uint32_t s1Size = 0; + int64_t s2Size = 0; + uint32_t qkHeadDim = 0; + uint32_t gSize = 0; + // PageAttention + bool pageAttentionFlag = false; + int32_t blockSize = 0; + uint32_t maxBlockNumPerBatch = 0; + // Mask + int32_t sparseMode = 0; + // Others Flag + uint32_t sparseCount = 0; + // DType + ge::DataType inputQType = ge::DT_FLOAT16; + ge::DataType inputKType = ge::DT_FLOAT16; + ge::DataType outputType = ge::DT_INT32; + // Layout + DataLayout inputQLayout = DataLayout::BSND; + DataLayout inputKLayout = DataLayout::PA_BSND; +}; + +// -----------算子Tiling入参信息解析及Check类--------------- +class LIQInfoParser { +public: + explicit LIQInfoParser(gert::TilingContext *context) : context_(context) {} + ~LIQInfoParser() = default; + + ge::graphStatus CheckRequiredInOutExistence() const; + ge::graphStatus CheckRequiredAttrExistence() const; + ge::graphStatus CheckRequiredParaExistence() const; + ge::graphStatus GetActualSeqLenSize(uint32_t &size, const gert::Tensor *tensor, + const std::string &actualSeqLenName); + ge::graphStatus GetOpName(); + ge::graphStatus GetNpuInfo(); + void GetOptionalInputParaInfo(); + void GetInputParaInfo(); + void GetOutputParaInfo(); + ge::graphStatus GetAttrParaInfo(); + ge::graphStatus CheckAttrParaInfo(); + ge::graphStatus GetOpParaInfo(); + ge::graphStatus ValidateInputShapesMatch(); + ge::graphStatus CheckScaleShape(); + ge::graphStatus GetAndCheckInOutDataType(); + ge::graphStatus GetBatchSize(); + ge::graphStatus GetHeadDim(); + ge::graphStatus GetS1Size(); + ge::graphStatus GetAndCheckOptionalInput(); + ge::graphStatus CheckShapeDim(); + ge::graphStatus GetAndCheckBlockSize(); + ge::graphStatus GetS2SizeForPageAttention(); + ge::graphStatus GetS2SizeForBatchContinuous(); + ge::graphStatus GetS2Size(); + ge::graphStatus GetQueryKeyAndOutLayout(); + ge::graphStatus GetN1Size(); + ge::graphStatus GetAndCheckN2Size(); + ge::graphStatus GetGSize(); + ge::graphStatus 
GetAttenMaskInfo(); + ge::graphStatus GetActualSeqInfo(); + void GenerateInfo(LIQTilingInfo &liqInfo); + ge::graphStatus ParseAndCheck(LIQTilingInfo &liqInfo); + +public: + gert::TilingContext *context_ = nullptr; + const char *opName_; + fe::PlatFormInfos *platformInfo_; + LIQParaInfo opParamInfo_; + + // BaseParams + uint32_t bSize_ = 0; + uint32_t n1Size_ = 0; + uint32_t n2Size_ = 0; + uint32_t gSize_ = 0; + uint32_t s1Size_ = 0; + int64_t s2Size_ = 0; + uint32_t headDim_ = 0; + // Layout + DataLayout qLayout_ = DataLayout::BSND; + DataLayout kLayout_ = DataLayout::PA_BSND; + // PageAttention + uint32_t maxBlockNumPerBatch_ = 0; + int32_t blockSize_ = 0; + platform_ascendc::SocVersion socVersion_ = platform_ascendc::SocVersion::ASCEND910B; + ge::DataType inputQType_ = ge::DT_FLOAT16; + ge::DataType inputKType_ = ge::DT_FLOAT16; + ge::DataType weightsType_ = ge::DT_FLOAT16; + ge::DataType inputQueryScaleType_ = ge::DT_FLOAT16; + ge::DataType inputKeyScaleType_ = ge::DT_FLOAT16; + ge::DataType blockTableType_ = ge::DT_FLOAT16; + ge::DataType inputKRopeType_ = ge::DT_FLOAT16; + ge::DataType outputType_ = ge::DT_FLOAT16; +}; + +// ---------------算子Tiling类--------------- +class LightningIndexerQuantTiling { +public: + explicit LightningIndexerQuantTiling(gert::TilingContext *context) : context_(context) {}; + ge::graphStatus DoTiling(LIQTilingInfo *tilingInfo); + +private: + gert::TilingContext *context_ = nullptr; + LIQTilingData tilingData_; +}; + +} // namespace optiling +#endif // LIGHTNING_INDEXER_QUANT_TILING_H_ \ No newline at end of file diff --git a/csrc/lightning_indexer_quant/op_kernel/lightning_indexer_quant.cpp b/csrc/lightning_indexer_quant/op_kernel/lightning_indexer_quant.cpp new file mode 100644 index 00000000..a9513daf --- /dev/null +++ b/csrc/lightning_indexer_quant/op_kernel/lightning_indexer_quant.cpp @@ -0,0 +1,50 @@ +/** + * This program is free software, you can redistribute it and/or modify it. 
+ * Copyright (c) 2025 Huawei Technologies Co., Ltd. + * This file is a part of the CANN Open Software. + * Licensed under CANN Open Software License Agreement Version 2.0 (the "License"). + * Please refer to the License for details. You may not use this file except in compliance with the License. + * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. + * See LICENSE in the root of the software repository for the full text of the License. + */ + +/*! + * \file lightning_indexer_quant.cpp + * \brief + */ + +#include "kernel_operator.h" +#include "lib/matmul_intf.h" +#include "lightning_indexer_quant_kernel.h" +#include "lightning_indexer_quant_template_tiling_key.h" + +using namespace LIQKernel; + +#define INVOKE_LI_NO_KFC_OP_IMPL(templateClass, ...) \ + do { \ + templateClass> op; \ + GET_TILING_DATA_WITH_STRUCT(LIQTilingData, tiling_data_in, tiling); \ + const LIQTilingData *__restrict tiling_data = &tiling_data_in; \ + op.Init(query, key, weights, queryScale, keyScale, actualSeqLengthsQ, actualSeqLengthsK, blocktable, \ + sparseIndices, user, tiling_data, &tPipe); \ + op.Process(); \ + } while (0) + +template +__global__ __aicore__ void lightning_indexer_quant(__gm__ uint8_t *query, __gm__ uint8_t *key, __gm__ uint8_t *weights, + __gm__ uint8_t *queryScale, __gm__ uint8_t *keyScale, + __gm__ uint8_t *actualSeqLengthsQ, __gm__ uint8_t *actualSeqLengthsK, + __gm__ uint8_t *blocktable, __gm__ uint8_t *sparseIndices, + __gm__ uint8_t *workspace, __gm__ uint8_t *tiling) +{ +#if (__CCE_AICORE__ == 310) || (defined __DAV_310R6__) || (__CCE_AICORE__ == 200) + +#else + TPipe tPipe; + __gm__ uint8_t *user = GetUserWorkspace(workspace); + KERNEL_TASK_TYPE_DEFAULT(KERNEL_TYPE_MIX_AIC_1_2); + + INVOKE_LI_NO_KFC_OP_IMPL(LIQPreload, int8_t, int8_t, int32_t, + PAGE_ATTENTION, LI_LAYOUT(Q_LAYOUT_T), LI_LAYOUT(K_LAYOUT_T)); 
+#endif +} diff --git a/csrc/lightning_indexer_quant/op_kernel/lightning_indexer_quant_common.h b/csrc/lightning_indexer_quant/op_kernel/lightning_indexer_quant_common.h new file mode 100644 index 00000000..a0f0eb7f --- /dev/null +++ b/csrc/lightning_indexer_quant/op_kernel/lightning_indexer_quant_common.h @@ -0,0 +1,146 @@ +/** + * This program is free software, you can redistribute it and/or modify it. + * Copyright (c) 2025 Huawei Technologies Co., Ltd. + * This file is a part of the CANN Open Software. + * Licensed under CANN Open Software License Agreement Version 2.0 (the "License"). + * Please refer to the License for details. You may not use this file except in compliance with the License. + * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. + * See LICENSE in the root of the software repository for the full text of the License. + */ + +/*! 
+ * \file lightning_indexer_quant_common.h + * \brief + */ +#ifndef LIGHTNING_INDEXER_QUANT_COMMON_H +#define LIGHTNING_INDEXER_QUANT_COMMON_H + +namespace LIQCommon { + +// 与tiling的layout保持一致 +enum class LI_LAYOUT : uint32_t { + BSND = 0, + TND = 1, + PA_BSND = 2 +}; + +template +struct LIQType { + using queryType = Q_T; + using keyType = K_T; + using outputType = OUT_T; + static constexpr bool pageAttention = PAGE_ATTENTION; + static constexpr LI_LAYOUT layout = Q_LAYOUT_T; + static constexpr LI_LAYOUT keyLayout = K_LAYOUT_T; +}; + +struct RunInfo { + uint32_t loop; + uint32_t bN2Idx; + uint32_t bIdx; + uint32_t n2Idx = 0; + uint32_t gS1Idx; + uint32_t s2Idx; + + uint32_t actS1Size = 1; + uint32_t actS2Size = 1; + uint32_t actMBaseSize; + uint32_t actualSingleProcessSInnerSize; + uint32_t actualSingleProcessSInnerSizeAlign; + + uint64_t tensorQueryOffset; + uint64_t tensorKeyOffset; + uint64_t tensorKeyScaleOffset; + uint64_t tensorWeightsOffset; + uint64_t indiceOutOffset; + + bool isFirstS2InnerLoop; + bool isLastS2InnerLoop; + bool isAllLoopEnd = false; + bool isValid = false; +}; + +struct ConstInfo { + // CUBE与VEC核间同步的模式 + static constexpr uint32_t FIA_SYNC_MODE2 = 2; + // BUFFER的字节数 + static constexpr uint32_t BUFFER_SIZE_BYTE_32B = 32; + static constexpr uint32_t BUFFER_SIZE_BYTE_64B = 64; + static constexpr uint32_t BUFFER_SIZE_BYTE_256B = 256; + static constexpr uint32_t BUFFER_SIZE_BYTE_512B = 512; + static constexpr uint32_t BUFFER_SIZE_BYTE_1K = 1024; + static constexpr uint32_t BUFFER_SIZE_BYTE_2K = 2048; + static constexpr uint32_t BUFFER_SIZE_BYTE_4K = 4096; + static constexpr uint32_t BUFFER_SIZE_BYTE_8K = 8192; + static constexpr uint32_t BUFFER_SIZE_BYTE_16K = 16384; + static constexpr uint32_t BUFFER_SIZE_BYTE_32K = 32768; + // 无效索引 + static constexpr int INVALID_IDX = -1; + + // CUBE和VEC的核间同步EventID + uint32_t syncC1V1 = 0U; + uint32_t syncC1V0 = 2U; + uint32_t syncV1C1 = 0U; + uint32_t syncV0C1 = 1U; + + // 基本块大小 + uint32_t mBaseSize = 1ULL; 
+ uint32_t s1BaseSize = 1ULL; + uint32_t s2BaseSize = 1ULL; + + uint64_t batchSize = 0ULL; + uint64_t gSize = 0ULL; + uint64_t qHeadNum = 0ULL; + uint64_t kHeadNum; + uint64_t headDim; + uint64_t sparseCount; // topK选取大小 + uint64_t kSeqSize = 0ULL; // kv最大S长度 + uint64_t qSeqSize = 1ULL; // q最大S长度 + uint32_t kCacheBlockSize = 0; // PA场景的block size + uint32_t maxBlockNumPerBatch = 0; // PA场景的最大单batch block number + LI_LAYOUT outputLayout; // 输出的格式 + bool attenMaskFlag = false; + + uint32_t actualLenQDims = 0U; // query的actualSeqLength 的维度 + uint32_t actualLenDims = 0U; // KV 的actualSeqLength 的维度 + bool isAccumSeqS1 = false; // 是否累加模式 + bool isAccumSeqS2 = false; // 是否累加模式 +}; + +struct SplitCoreInfo { + uint32_t s2Start = 0U; // S2的起始位置 + uint32_t s2End = 0U; // S2循环index上限 + uint32_t bN2Start = 0U; + uint32_t bN2End = 0U; + uint32_t gS1Start = 0U; + uint32_t gS1End = 0U; + bool isLD = false; // 当前核是否需要进行Decode归约任务 +}; + +template +__aicore__ inline T Align(T num, T rnd) +{ + return (((rnd) == 0) ? 0 : (((num) + (rnd)-1) / (rnd) * (rnd))); +} + +template +__aicore__ inline T1 Min(T1 a, T2 b) +{ + return (a > b) ? (b) : (a); +} + +template +__aicore__ inline T1 Max(T1 a, T2 b) +{ + return (a > b) ? (a) : (b); +} + +template +__aicore__ inline T CeilDiv(T num, T rnd) +{ + return (((rnd) == 0) ? 0 : (((num) + (rnd)-1) / (rnd))); +} +} // namespace LIQCommon + +#endif // LIGHTNING_INDEXER_QUANT_COMMON_H \ No newline at end of file diff --git a/csrc/lightning_indexer_quant/op_kernel/lightning_indexer_quant_kernel.h b/csrc/lightning_indexer_quant/op_kernel/lightning_indexer_quant_kernel.h new file mode 100644 index 00000000..723255d3 --- /dev/null +++ b/csrc/lightning_indexer_quant/op_kernel/lightning_indexer_quant_kernel.h @@ -0,0 +1,714 @@ +/** + * This program is free software, you can redistribute it and/or modify it. + * Copyright (c) 2025 Huawei Technologies Co., Ltd. + * This file is a part of the CANN Open Software. 
+ * Licensed under CANN Open Software License Agreement Version 2.0 (the "License"). + * Please refer to the License for details. You may not use this file except in compliance with the License. + * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. + * See LICENSE in the root of the software repository for the full text of the License. + */ + +/*! + * \file lightning_indexer_quant_kernel.h + * \brief + */ + +#ifndef LIGHTNING_INDEXER_QUANT_KERNEL_H +#define LIGHTNING_INDEXER_QUANT_KERNEL_H + +#include "kernel_operator.h" +#include "kernel_operator_list_tensor_intf.h" +#include "kernel_tiling/kernel_tiling.h" +#include "lib/matmul_intf.h" +#include "lib/matrix/matmul/tiling.h" +#include "lightning_indexer_quant_common.h" +#include "lightning_indexer_quant_service_vector.h" +#include "lightning_indexer_quant_service_cube.h" + +namespace LIQKernel { +using namespace LIQCommon; +using namespace LIQServiceVec; +using namespace matmul; +using AscendC::CacheMode; +using AscendC::CrossCoreSetFlag; +using AscendC::CrossCoreWaitFlag; + +// 由于S2循环前,RunInfo还没有赋值,使用TempLoopInfo临时存放B、N、S1轴相关的信息;同时减少重复计算 +struct TempLoopInfo { + uint32_t bN2Idx = 0; + uint32_t bIdx = 0U; + uint32_t n2Idx = 0U; + uint32_t gS1Idx = 0U; + uint32_t gS1LoopEnd = 0U; // gS1方向循环的结束Idx + uint32_t s2LoopEnd = 0U; // S2方向循环的结束Idx + uint32_t actS1Size = 1ULL; // 当前Batch循环处理的S1轴的实际大小 + uint32_t actS2Size = 0ULL; + bool curActSeqLenIsZero = false; + bool needDealActS1LessThanS1 = false; // S1的实际长度小于shape的S1长度时,是否需要清理输出 + uint32_t actMBaseSize = 0U; // m轴(gS1)方向实际大小 + uint32_t mBasicSizeTail = 0U; // gS1方向循环的尾基本块大小 + uint32_t s2BasicSizeTail = 0U; // S2方向循环的尾基本块大小 +}; + +template +class LIQPreload { +public: + __aicore__ inline LIQPreload(){}; + __aicore__ inline void Init(__gm__ uint8_t *query, __gm__ uint8_t *key, __gm__ uint8_t *weights, + __gm__ 
uint8_t *queryScale, __gm__ uint8_t *keyScale, __gm__ uint8_t *actualSeqLengthsQ, + __gm__ uint8_t *actualSeqLengthsK, __gm__ uint8_t *blockTable, + __gm__ uint8_t *sparseIndices, __gm__ uint8_t *workspace, + const LIQTilingData *__restrict tiling, TPipe *tPipe); + __aicore__ inline void Process(); + + // =================================类型定义区================================= + using Q_T = typename LIQT::queryType; + using K_T = typename LIQT::keyType; + using OUT_T = typename LIQT::outputType; + static constexpr bool PAGE_ATTENTION = LIQT::pageAttention; + static constexpr LI_LAYOUT Q_LAYOUT_T = LIQT::layout; + static constexpr LI_LAYOUT K_LAYOUT_T = LIQT::keyLayout; + + using MM1_OUT_T = float; + + LIQMatmul matmulService; + LIQVector vectorService; + + // =================================常量区================================= + static constexpr uint32_t SYNC_C1_V1_FLAG = 4; + static constexpr uint32_t SYNC_V1_C1_FLAG = 5; + + static constexpr uint32_t M_BASE_SIZE = 256; + static constexpr uint32_t S2_BASE_SIZE = 2048; + static constexpr uint32_t HEAD_DIM = 128; + static constexpr uint32_t K_HEAD_NUM = 1; + static constexpr uint32_t GM_ALIGN_BYTES = 512; + static constexpr uint32_t LI_QUANT_PRELOAD_TASK_CACHE_SIZE = 2; + + static constexpr int64_t LD_PREFETCH_LEN = 2; + // for workspace double + static constexpr uint32_t WS_DOBULE = 2; + +protected: + TPipe *pipe = nullptr; + + // offset + uint64_t queryCoreOffset = 0ULL; + uint64_t keyCoreOffset = 0ULL; + uint64_t keyScaleCoreOffset = 0ULL; + uint64_t weightsCoreOffset = 0ULL; + uint64_t indiceOutCoreOffset = 0ULL; + + // ================================Global Buffer区================================= + GlobalTensor queryGm; + GlobalTensor keyGm; + GlobalTensor weightsGm; + + GlobalTensor indiceOutGm; + GlobalTensor blockTableGm; + + GlobalTensor actualSeqLengthsGmQ; + GlobalTensor actualSeqLengthsGm; + + // ================================类成员变量==================================== + // aic、aiv核信息 + uint32_t 
tmpBlockIdx = 0U; + uint32_t aiCoreIdx = 0U; + uint32_t usedCoreNum = 0U; + + LIQCommon::ConstInfo constInfo{}; + TempLoopInfo tempLoopInfo{}; + LIQCommon::SplitCoreInfo splitCoreInfo{}; + + // ================================Init functions================================== + __aicore__ inline void InitTilingData(const LIQTilingData *__restrict tilingData); + __aicore__ inline void InitBuffers(); + __aicore__ inline void InitActualSeqLen(__gm__ uint8_t *actualSeqLengthsQ, __gm__ uint8_t *actualSeqLengthsK); + // ================================Split Core================================ + __aicore__ inline void SplitCore(uint32_t curCoreIdx, uint32_t &coreNum, LIQCommon::SplitCoreInfo &info); + __aicore__ inline uint32_t GetS2BaseBlockNumOnMask(uint32_t s1gIdx, uint32_t actS1Size, uint32_t actS2Size); + __aicore__ inline uint32_t GetTotalBaseBlockNum(); + // ================================Process functions================================ + __aicore__ inline void ProcessMain(); + __aicore__ inline void ProcessBaseBlock(uint32_t loop, uint64_t s2LoopIdx, + LIQCommon::RunInfo runInfo[LI_QUANT_PRELOAD_TASK_CACHE_SIZE]); + __aicore__ inline void ProcessDecode(); + __aicore__ inline void ProcessInvalid(); + // ================================Params Calc===================================== + __aicore__ inline void CalcGS1LoopParams(uint32_t bN2Idx); + __aicore__ inline void GetBN2Idx(uint32_t bN2Idx); + __aicore__ inline uint32_t GetActualSeqLen(uint32_t bIdx, uint32_t actualLenDims, bool isAccumSeq, + GlobalTensor &actualSeqLengthsGm, uint32_t defaultSeqLen); + __aicore__ inline void GetS1S2ActualSeqLen(uint32_t bIdx, uint32_t &actS1Size, uint32_t &actS2Size); + __aicore__ inline void CalcS2LoopParams(uint32_t bN2LoopIdx, uint32_t gS1LoopIdx); + __aicore__ inline void CalcRunInfo(uint32_t loop, uint32_t s2LoopIdx, LIQCommon::RunInfo &runInfo); + __aicore__ inline void DealActSeqLenIsZero(uint32_t bIdx, uint32_t n2Idx, uint32_t s1Start); +}; + +template +__aicore__ 
inline void LIQPreload::InitTilingData(const LIQTilingData *__restrict tilingData) +{ + usedCoreNum = tilingData->usedCoreNum; + constInfo.batchSize = tilingData->bSize; + constInfo.qHeadNum = constInfo.gSize = tilingData->gSize; + constInfo.kSeqSize = tilingData->s2Size; + constInfo.qSeqSize = tilingData->s1Size; + constInfo.attenMaskFlag = (tilingData->sparseMode == 3); + constInfo.kCacheBlockSize = tilingData->blockSize; + constInfo.maxBlockNumPerBatch = tilingData->maxBlockNumPerBatch; + constInfo.sparseCount = tilingData->sparseCount; + constInfo.outputLayout = Q_LAYOUT_T; // 输出和输入形状一致 + if (Q_LAYOUT_T == LI_LAYOUT::TND) { + constInfo.isAccumSeqS1 = true; + } + if (K_LAYOUT_T == LI_LAYOUT::TND) { + constInfo.isAccumSeqS2 = true; + } + + constInfo.kHeadNum = K_HEAD_NUM; + constInfo.headDim = HEAD_DIM; + + constInfo.mBaseSize = M_BASE_SIZE; + constInfo.s2BaseSize = S2_BASE_SIZE; + constInfo.s1BaseSize = (constInfo.mBaseSize + constInfo.gSize - 1) / constInfo.gSize; +} + +template +__aicore__ inline void LIQPreload::InitBuffers() +{ + if ASCEND_IS_AIV { + vectorService.InitBuffers(pipe); + } else { + matmulService.InitBuffers(pipe); + } +} + +template +__aicore__ inline void LIQPreload::InitActualSeqLen(__gm__ uint8_t *actualSeqLengthsQ, + __gm__ uint8_t *actualSeqLengthsK) +{ + if (actualSeqLengthsQ == nullptr) { + constInfo.actualLenQDims = 0; + } else { + constInfo.actualLenQDims = constInfo.batchSize; + actualSeqLengthsGmQ.SetGlobalBuffer((__gm__ uint32_t *)actualSeqLengthsQ, constInfo.actualLenQDims); + } + if (actualSeqLengthsK == nullptr) { + constInfo.actualLenDims = 0; + } else { + constInfo.actualLenDims = constInfo.batchSize; + actualSeqLengthsGm.SetGlobalBuffer((__gm__ uint32_t *)actualSeqLengthsK, constInfo.actualLenDims); + } +} + +template +__aicore__ inline uint32_t LIQPreload::GetActualSeqLen(uint32_t bIdx, uint32_t actualLenDims, bool isAccumSeq, + GlobalTensor &actualSeqLengthsGm, + uint32_t defaultSeqLen) +{ + if (actualLenDims == 0) { + 
return defaultSeqLen; + } else if (isAccumSeq && bIdx > 0) { + return actualSeqLengthsGm.GetValue(bIdx) - actualSeqLengthsGm.GetValue(bIdx - 1); + } else { + return actualSeqLengthsGm.GetValue(bIdx); + } +} + +template +__aicore__ inline void LIQPreload::GetS1S2ActualSeqLen(uint32_t bIdx, uint32_t &actS1Size, uint32_t &actS2Size) +{ + actS1Size = GetActualSeqLen(bIdx, constInfo.actualLenQDims, constInfo.isAccumSeqS1, actualSeqLengthsGmQ, + constInfo.qSeqSize); + actS2Size = + GetActualSeqLen(bIdx, constInfo.actualLenDims, constInfo.isAccumSeqS2, actualSeqLengthsGm, constInfo.kSeqSize); +} + +template +__aicore__ inline uint32_t LIQPreload::GetS2BaseBlockNumOnMask(uint32_t s1gIdx, uint32_t actS1Size, + uint32_t actS2Size) +{ + if (actS2Size == 0) { + return 0; + } + uint32_t s1Offset = constInfo.s1BaseSize * s1gIdx; + int32_t validS2LenBase = static_cast(actS2Size) - static_cast(actS1Size); + int32_t validS2Len = s1Offset + validS2LenBase + constInfo.s1BaseSize; + validS2Len = Min(validS2Len, static_cast(actS2Size)); + validS2Len = Max(validS2Len, 1); + return (validS2Len + constInfo.s2BaseSize - 1) / constInfo.s2BaseSize; +} + +template +__aicore__ inline uint32_t LIQPreload::GetTotalBaseBlockNum() +{ + uint32_t totalBlockNum = 0; + uint32_t actS1Size, actS2Size; + uint32_t s1GBaseNum, s2BaseNum; + for (uint32_t bIdx = 0; bIdx < constInfo.batchSize; bIdx++) { + GetS1S2ActualSeqLen(bIdx, actS1Size, actS2Size); + s1GBaseNum = CeilDiv(actS1Size, constInfo.s1BaseSize); + if (!constInfo.attenMaskFlag) { + s2BaseNum = CeilDiv(actS2Size, constInfo.s2BaseSize); + totalBlockNum += s1GBaseNum * s2BaseNum * constInfo.kHeadNum; + continue; + } + for (uint32_t s1gIdx = 0; s1gIdx < s1GBaseNum; s1gIdx++) { + s2BaseNum = GetS2BaseBlockNumOnMask(s1gIdx, actS1Size, actS2Size); + totalBlockNum += s2BaseNum * constInfo.kHeadNum; + } + } + return totalBlockNum; +} + +// 多核版本,双闭区间。基本原则:计算每个核最少处理的块数, 剩余的部分前面的核每个核多处理一块 +template +__aicore__ void inline LIQPreload::SplitCore(uint32_t 
curCoreIdx, uint32_t &coreNum, + LIQCommon::SplitCoreInfo &info) +{ + uint32_t totalBlockNum = GetTotalBaseBlockNum(); + uint32_t minBlockPerCore = totalBlockNum / coreNum; + uint32_t deal1MoreBlockCoreNum = totalBlockNum % coreNum; + uint32_t coreIdx = 0; + uint32_t lastGS1RemainBlockCnt = 0; + uint32_t coreDealBlockCnt = coreIdx < deal1MoreBlockCoreNum ? minBlockPerCore + 1 : minBlockPerCore; + coreNum = minBlockPerCore == 0 ? deal1MoreBlockCoreNum : coreNum; + + bool findLastCoreEnd = true; + uint32_t actS1Size, actS2Size; + uint32_t s1GBaseNum, s2BaseNum; + for (uint32_t bN2Idx = 0; bN2Idx < constInfo.batchSize * constInfo.kHeadNum; bN2Idx++) { + uint32_t bIdx = bN2Idx / constInfo.kHeadNum; + if (bN2Idx % constInfo.kHeadNum == 0) { + GetS1S2ActualSeqLen(bIdx, actS1Size, actS2Size); + s1GBaseNum = CeilDiv(actS1Size, constInfo.s1BaseSize); + s2BaseNum = CeilDiv(actS2Size, constInfo.s2BaseSize); + } + if constexpr (Q_LAYOUT_T == LI_LAYOUT::BSND) { + if (findLastCoreEnd && (s1GBaseNum == 0U || s2BaseNum == 0U)) { + info.bN2Start = bN2Idx; + info.gS1Start = 0; + info.s2Start = 0; + findLastCoreEnd = false; + } + } + for (uint32_t gS1Idx = 0; gS1Idx < s1GBaseNum; gS1Idx++) { + if (constInfo.attenMaskFlag) { + s2BaseNum = GetS2BaseBlockNumOnMask(gS1Idx, actS1Size, actS2Size); + } + if (findLastCoreEnd && s2BaseNum == 0U) { + info.bN2Start = bN2Idx; + info.gS1Start = gS1Idx; + info.s2Start = 0; + findLastCoreEnd = false; + } + for (uint32_t s2Idx = 0; s2Idx < s2BaseNum;) { + if (findLastCoreEnd) { + info.bN2Start = bN2Idx; + info.gS1Start = gS1Idx; + info.s2Start = s2Idx; + findLastCoreEnd = false; + } + uint32_t s2RemainBaseNum = s2BaseNum - s2Idx; + if (lastGS1RemainBlockCnt + s2RemainBaseNum >= coreDealBlockCnt) { + info.bN2End = bN2Idx; + info.gS1End = gS1Idx; + info.s2End = s2Idx + coreDealBlockCnt - lastGS1RemainBlockCnt - 1; + + if (coreIdx == curCoreIdx) { + // S2被切N核,那么只有第一个核需要处理LD,其他核不用 + if (s2Idx == 0 && info.s2End + 1 < s2BaseNum) { + info.isLD = true; + } 
+ // 最后一个核处理的不是最后一个Batch,表明后面的Batch为空块(S2=0), 调整终点坐标以便清理输出 + if (coreIdx == coreNum - 1 && info.bN2End != constInfo.batchSize - 1) { + info.bN2End = constInfo.batchSize - 1; + info.gS1End = 0; + info.s2End = 0; + } + return; + } + coreIdx++; + findLastCoreEnd = true; + s2Idx = info.s2End + 1; + lastGS1RemainBlockCnt = 0; + coreDealBlockCnt = coreIdx < deal1MoreBlockCoreNum ? minBlockPerCore + 1 : minBlockPerCore; + } else { + lastGS1RemainBlockCnt += s2RemainBaseNum; + break; + } + } + } + } +} + +template +__aicore__ inline void LIQPreload::DealActSeqLenIsZero(uint32_t bIdx, uint32_t n2Idx, uint32_t s1Start) +{ + if ASCEND_IS_AIV { + if (constInfo.outputLayout == LI_LAYOUT::TND) { + uint32_t tSize = actualSeqLengthsGmQ.GetValue(constInfo.batchSize - 1); + uint32_t tBase = bIdx == 0 ? 0 : actualSeqLengthsGmQ.GetValue(bIdx - 1); + uint32_t s1Count = tempLoopInfo.actS1Size; + + for (uint32_t s1Idx = s1Start; s1Idx < s1Count; s1Idx++) { + uint64_t indiceOutOffset = + (tBase + s1Idx) * constInfo.kHeadNum * constInfo.sparseCount + // T轴、s1轴偏移 + n2Idx * constInfo.sparseCount; // N2轴偏移 + vectorService.CleanInvalidOutput(indiceOutOffset); + } + } else if (constInfo.outputLayout == LI_LAYOUT::BSND) { + for (uint32_t s1Idx = s1Start; s1Idx < constInfo.qSeqSize; s1Idx++) { + // B,S1,N2,K + uint64_t indiceOutOffset = bIdx * constInfo.qSeqSize * constInfo.kHeadNum * constInfo.sparseCount + + s1Idx * constInfo.kHeadNum * constInfo.sparseCount + // B轴、S1轴偏移 + n2Idx * constInfo.sparseCount; // N2轴偏移 + vectorService.CleanInvalidOutput(indiceOutOffset); + } + } + } +} + +template +__aicore__ inline void LIQPreload::Init(__gm__ uint8_t *query, __gm__ uint8_t *key, __gm__ uint8_t *weights, + __gm__ uint8_t *queryScale, __gm__ uint8_t *keyScale, + __gm__ uint8_t *actualSeqLengthsQ, __gm__ uint8_t *actualSeqLengthsK, + __gm__ uint8_t *blockTable, __gm__ uint8_t *sparseIndices, + __gm__ uint8_t *workspace, const LIQTilingData *__restrict tiling, + TPipe *tPipe) +{ + if ASCEND_IS_AIV { + 
tmpBlockIdx = GetBlockIdx(); // vec:0-47 + aiCoreIdx = tmpBlockIdx / 2; + } else { + tmpBlockIdx = GetBlockIdx(); // cube:0-23 + aiCoreIdx = tmpBlockIdx; + } + + InitTilingData(tiling); + InitActualSeqLen(actualSeqLengthsQ, actualSeqLengthsK); + + // 计算分核 + SplitCore(aiCoreIdx, usedCoreNum, splitCoreInfo); + + pipe = tPipe; + // workspace 内存排布 + // |mm1ResGm(存S)|vec1ResGm(存LD中间结果)|vec1ParamGm(存LD参数) + // |Core0_mm1ResDB0-Core0_mm1ResDB1-Core1_mm1ResDB0....Core23_mm1ResDB0-Core23_mm1ResDB1|Core0_vec1Res... + uint64_t offset = 0; + + // mm1开DoubleBuffer + GlobalTensor mm1ResGm; // 存放S + uint64_t singleCoreMm1ResSize = WS_DOBULE * constInfo.s1BaseSize * constInfo.s2BaseSize * sizeof(MM1_OUT_T); + mm1ResGm.SetGlobalBuffer((__gm__ MM1_OUT_T *)(workspace + aiCoreIdx * singleCoreMm1ResSize)); + offset += GetBlockNum() * singleCoreMm1ResSize; + + // ld流程需要ws大小: [aicnum, 2, CeilDiv(constInfo.mBaseSize, constInfo.gSize), topkOut_*2] + // (aic, 8, 2, 2, 2048) + // (aic, s1_cube, 头尾, idx/value, K) + GlobalTensor vec1ResGm; // 存放TopK计算中间结果 + vec1ResGm.SetGlobalBuffer((__gm__ float *)(workspace + offset)); + offset += GetBlockNum() * constInfo.s1BaseSize * WS_DOBULE * WS_DOBULE * BASE_TOPK * sizeof(float); + + // (aic, 8, 2, 16) + // (aic, s1_cube, 头尾,16ele) + GlobalTensor vec1ParamGm; // 存放LD参数信息 + vec1ParamGm.SetGlobalBuffer((__gm__ int64_t *)(workspace + offset)); + offset += GetBlockNum() * constInfo.s1BaseSize * WS_DOBULE * LD_PARAM_NUM * sizeof(int64_t); + + GlobalTensor weightWorkspaceGm; // v1阶段处理w*scale后的结果 + uint64_t weightMemSize = BLOCK_CUBE * constInfo.mBaseSize * WS_DOBULE * sizeof(half); + weightWorkspaceGm.SetGlobalBuffer((__gm__ half *)(workspace + offset + aiCoreIdx * weightMemSize)); + offset += GetBlockNum() * weightMemSize; + + GlobalTensor qScaleGm; + GlobalTensor kScaleGm; + if ASCEND_IS_AIV { + vectorService.InitParams(constInfo, tiling); + indiceOutGm.SetGlobalBuffer((__gm__ int32_t *)sparseIndices); + weightsGm.SetGlobalBuffer((__gm__ half *)weights); + 
qScaleGm.SetGlobalBuffer((__gm__ half *)queryScale); + kScaleGm.SetGlobalBuffer((__gm__ half *)keyScale); + blockTableGm.SetGlobalBuffer((__gm__ int32_t *)blockTable); + vectorService.InitVecInputTensor(weightsGm, qScaleGm, kScaleGm, indiceOutGm, blockTableGm); + vectorService.InitVecWorkspaceTensor(weightWorkspaceGm, mm1ResGm, vec1ResGm, vec1ParamGm); + } else { + matmulService.InitParams(constInfo); + queryGm.SetGlobalBuffer((__gm__ Q_T *)query); + if constexpr (PAGE_ATTENTION) { + blockTableGm.SetGlobalBuffer((__gm__ int32_t *)blockTable); + } + keyGm.SetGlobalBuffer((__gm__ K_T *)key); + matmulService.InitMm1GlobalTensor(blockTableGm, keyGm, queryGm, mm1ResGm, weightWorkspaceGm); + } + InitBuffers(); +} + +template +__aicore__ inline void LIQPreload::GetBN2Idx(uint32_t bN2Idx) +{ + tempLoopInfo.bN2Idx = bN2Idx; + tempLoopInfo.bIdx = bN2Idx / constInfo.kHeadNum; + tempLoopInfo.n2Idx = bN2Idx % constInfo.kHeadNum; +} + +template +__aicore__ inline void LIQPreload::CalcS2LoopParams(uint32_t bN2LoopIdx, uint32_t gS1LoopIdx) +{ + tempLoopInfo.gS1Idx = gS1LoopIdx; + tempLoopInfo.actMBaseSize = constInfo.mBaseSize; + uint32_t remainedGS1Size = tempLoopInfo.actS1Size * constInfo.gSize - tempLoopInfo.gS1Idx * constInfo.mBaseSize; + if (remainedGS1Size <= constInfo.mBaseSize && remainedGS1Size > 0) { + tempLoopInfo.actMBaseSize = tempLoopInfo.mBasicSizeTail; + } + + bool isEnd = (bN2LoopIdx == splitCoreInfo.bN2End) && (gS1LoopIdx == splitCoreInfo.gS1End); + uint32_t s2BlockNum; + if (constInfo.attenMaskFlag) { + s2BlockNum = GetS2BaseBlockNumOnMask(gS1LoopIdx, tempLoopInfo.actS1Size, tempLoopInfo.actS2Size); + } else { + s2BlockNum = (tempLoopInfo.actS2Size + constInfo.s2BaseSize - 1) / constInfo.s2BaseSize; + } + tempLoopInfo.s2LoopEnd = isEnd ? 
splitCoreInfo.s2End : s2BlockNum - 1; +} + +template +__aicore__ inline void LIQPreload::CalcGS1LoopParams(uint32_t bN2LoopIdx) +{ + GetBN2Idx(bN2LoopIdx); + GetS1S2ActualSeqLen(tempLoopInfo.bIdx, tempLoopInfo.actS1Size, tempLoopInfo.actS2Size); + if ((tempLoopInfo.actS2Size == 0) || (tempLoopInfo.actS1Size == 0)) { + tempLoopInfo.curActSeqLenIsZero = true; + return; + } + tempLoopInfo.curActSeqLenIsZero = false; + tempLoopInfo.s2BasicSizeTail = tempLoopInfo.actS2Size % constInfo.s2BaseSize; + tempLoopInfo.s2BasicSizeTail = + (tempLoopInfo.s2BasicSizeTail == 0) ? constInfo.s2BaseSize : tempLoopInfo.s2BasicSizeTail; + tempLoopInfo.mBasicSizeTail = (tempLoopInfo.actS1Size * constInfo.gSize) % constInfo.mBaseSize; + tempLoopInfo.mBasicSizeTail = + (tempLoopInfo.mBasicSizeTail == 0) ? constInfo.mBaseSize : tempLoopInfo.mBasicSizeTail; + + uint32_t gS1SplitNum = (tempLoopInfo.actS1Size * constInfo.gSize + constInfo.mBaseSize - 1) / constInfo.mBaseSize; + tempLoopInfo.gS1LoopEnd = (bN2LoopIdx == splitCoreInfo.bN2End) ? 
splitCoreInfo.gS1End : gS1SplitNum - 1; + if constexpr (Q_LAYOUT_T == LI_LAYOUT::BSND) { + if (tempLoopInfo.gS1LoopEnd == gS1SplitNum - 1 && constInfo.qSeqSize > tempLoopInfo.actS1Size) { + tempLoopInfo.needDealActS1LessThanS1 = true; + } + } +} + +template +__aicore__ inline void LIQPreload::CalcRunInfo(uint32_t loop, uint32_t s2LoopIdx, LIQCommon::RunInfo &runInfo) +{ + runInfo.loop = loop; + runInfo.bIdx = tempLoopInfo.bIdx; + runInfo.gS1Idx = tempLoopInfo.gS1Idx; + runInfo.s2Idx = s2LoopIdx; + runInfo.bN2Idx = tempLoopInfo.bN2Idx; + runInfo.isValid = s2LoopIdx <= tempLoopInfo.s2LoopEnd; + + if (!runInfo.isValid) { + return; // 需要验证, v1 时候需要runInfo + } + + runInfo.actS1Size = tempLoopInfo.actS1Size; + runInfo.actS2Size = tempLoopInfo.actS2Size; + // 计算实际基本块size + runInfo.actMBaseSize = tempLoopInfo.actMBaseSize; + runInfo.actualSingleProcessSInnerSize = constInfo.s2BaseSize; + uint32_t s2SplitNum = (tempLoopInfo.actS2Size + constInfo.s2BaseSize - 1) / constInfo.s2BaseSize; + if (runInfo.s2Idx == s2SplitNum - 1) { + runInfo.actualSingleProcessSInnerSize = tempLoopInfo.s2BasicSizeTail; + } + runInfo.actualSingleProcessSInnerSizeAlign = + LIQCommon::Align((uint32_t)runInfo.actualSingleProcessSInnerSize, LIQCommon::ConstInfo::BUFFER_SIZE_BYTE_32B); + + runInfo.isFirstS2InnerLoop = s2LoopIdx == splitCoreInfo.s2Start; + runInfo.isLastS2InnerLoop = s2LoopIdx == tempLoopInfo.s2LoopEnd; + runInfo.isAllLoopEnd = (runInfo.bN2Idx == splitCoreInfo.bN2End) && (runInfo.gS1Idx == splitCoreInfo.gS1End) && + (runInfo.s2Idx == splitCoreInfo.s2End); + + if (runInfo.isFirstS2InnerLoop) { + uint64_t actualSeqQPrefixSum; + if constexpr (Q_LAYOUT_T == LI_LAYOUT::TND) { + actualSeqQPrefixSum = (runInfo.bIdx <= 0) ? 0 : actualSeqLengthsGmQ.GetValue(runInfo.bIdx - 1); + } else { // BSND + actualSeqQPrefixSum = (runInfo.bIdx <= 0) ? 
0 : runInfo.bIdx * constInfo.qSeqSize; + } + uint64_t tndBIdxOffset = actualSeqQPrefixSum * constInfo.qHeadNum * constInfo.headDim; + // B,S1,N1(N2,G),D + queryCoreOffset = tndBIdxOffset + runInfo.gS1Idx * constInfo.mBaseSize * constInfo.headDim; + // B,S1,N1(N2,G)/T,N1(N2,G) + weightsCoreOffset = actualSeqQPrefixSum * constInfo.qHeadNum + runInfo.n2Idx * constInfo.gSize; + // B,S1,N2,k/T,N2,k + indiceOutCoreOffset = + actualSeqQPrefixSum * constInfo.kHeadNum * constInfo.sparseCount + runInfo.n2Idx * constInfo.sparseCount; + } + uint64_t actualSeqKPrefixSum; + if constexpr (K_LAYOUT_T == LI_LAYOUT::TND) { // T N2 D + actualSeqKPrefixSum = (runInfo.bIdx <= 0) ? 0 : actualSeqLengthsGm.GetValue(runInfo.bIdx - 1); + } else { + actualSeqKPrefixSum = (runInfo.bIdx <= 0) ? 0 : runInfo.bIdx * constInfo.kSeqSize; + } + uint64_t tndBIdxOffsetForK = actualSeqKPrefixSum * constInfo.kHeadNum * constInfo.headDim; + keyCoreOffset = tndBIdxOffsetForK + runInfo.s2Idx * constInfo.s2BaseSize * constInfo.kHeadNum * constInfo.headDim; + keyScaleCoreOffset = (actualSeqKPrefixSum + runInfo.s2Idx * constInfo.s2BaseSize) * constInfo.kHeadNum; + runInfo.tensorQueryOffset = queryCoreOffset; + runInfo.tensorKeyOffset = keyCoreOffset; + runInfo.tensorKeyScaleOffset = keyScaleCoreOffset; + runInfo.tensorWeightsOffset = weightsCoreOffset; + runInfo.indiceOutOffset = indiceOutCoreOffset; +} + +template +__aicore__ inline void LIQPreload::Process() +{ + if (usedCoreNum == 0) { + // 没有计算任务,直接清理输出 + ProcessInvalid(); + return; + } + + ProcessMain(); + + ProcessDecode(); +} + +template +__aicore__ inline void LIQPreload::ProcessInvalid() +{ + if ASCEND_IS_AIV { + uint32_t aivCoreNum = GetBlockNum() * 2; // 2 means c:v = 1:2 + uint64_t totalOutputSize = + constInfo.batchSize * constInfo.qSeqSize * constInfo.kHeadNum * constInfo.sparseCount; + uint64_t singleCoreSize = + LIQCommon::Align((totalOutputSize + aivCoreNum - 1) / aivCoreNum, GM_ALIGN_BYTES / sizeof(OUT_T)); + uint64_t baseSize = tmpBlockIdx 
* singleCoreSize; + if (baseSize < totalOutputSize) { + uint64_t dealSize = + (baseSize + singleCoreSize <= totalOutputSize) ? singleCoreSize : totalOutputSize - baseSize; + GlobalTensor output = indiceOutGm[baseSize]; + AscendC::InitGlobalMemory(output, dealSize, constInfo.INVALID_IDX); + } + } +} + +template +__aicore__ inline void LIQPreload::ProcessMain() +{ + if (aiCoreIdx >= usedCoreNum) { + // 无任务核直接返回 + return; + } + + if ASCEND_IS_AIV { + vectorService.AllocEventID(); + CrossCoreSetFlag(constInfo.syncV1C1); + CrossCoreSetFlag(constInfo.syncV1C1); + } else { + matmulService.AllocEventID(); + CrossCoreSetFlag(constInfo.syncC1V0); + CrossCoreSetFlag(constInfo.syncC1V0); + } + + LIQCommon::RunInfo runInfo[LI_QUANT_PRELOAD_TASK_CACHE_SIZE]; + + uint32_t gloop = 0; + for (uint32_t bN2LoopIdx = splitCoreInfo.bN2Start; bN2LoopIdx <= splitCoreInfo.bN2End; bN2LoopIdx++) { + CalcGS1LoopParams(bN2LoopIdx); + if (tempLoopInfo.curActSeqLenIsZero) { + DealActSeqLenIsZero(tempLoopInfo.bIdx, tempLoopInfo.n2Idx, 0U); + + if ASCEND_IS_AIV { + if (bN2LoopIdx == splitCoreInfo.bN2End && gloop > 0) { + CrossCoreWaitFlag(constInfo.syncC1V1); + vectorService.ProcessVec1(runInfo[1 - gloop % LI_QUANT_PRELOAD_TASK_CACHE_SIZE]); + CrossCoreSetFlag( + constInfo.syncV1C1); // 反向同步 1 + } + } + continue; + } + for (uint32_t gS1LoopIdx = splitCoreInfo.gS1Start; gS1LoopIdx <= tempLoopInfo.gS1LoopEnd; gS1LoopIdx++) { + CalcS2LoopParams(bN2LoopIdx, gS1LoopIdx); + bool isEnd = (bN2LoopIdx == splitCoreInfo.bN2End) && (gS1LoopIdx == splitCoreInfo.gS1End); + uint32_t extraLoop = isEnd ? 
LI_QUANT_PRELOAD_TASK_CACHE_SIZE - 1 : 0; + for (int s2LoopIdx = splitCoreInfo.s2Start; s2LoopIdx <= (tempLoopInfo.s2LoopEnd + extraLoop); + s2LoopIdx++) { + ProcessBaseBlock(gloop, s2LoopIdx, runInfo); + ++gloop; + } + splitCoreInfo.s2Start = 0; + } + if (tempLoopInfo.needDealActS1LessThanS1) { + DealActSeqLenIsZero(tempLoopInfo.bIdx, tempLoopInfo.n2Idx, tempLoopInfo.actS1Size); + } + splitCoreInfo.gS1Start = 0; + } + + if ASCEND_IS_AIV { + vectorService.FreeEventID(); + CrossCoreWaitFlag(constInfo.syncC1V0); + CrossCoreWaitFlag(constInfo.syncC1V0); + } else { + matmulService.FreeEventID(); + CrossCoreWaitFlag(constInfo.syncV1C1); + CrossCoreWaitFlag(constInfo.syncV1C1); + } +} + +template +__aicore__ inline void LIQPreload::ProcessBaseBlock(uint32_t loop, uint64_t s2LoopIdx, + LIQCommon::RunInfo runInfo[LI_QUANT_PRELOAD_TASK_CACHE_SIZE]) +{ + int32_t curTaskId = loop % LI_QUANT_PRELOAD_TASK_CACHE_SIZE; + LIQCommon::RunInfo &curRunInfo = runInfo[curTaskId]; + LIQCommon::RunInfo &lastRunInfo = runInfo[1 - curTaskId]; + + CalcRunInfo(loop, s2LoopIdx, curRunInfo); + + if (curRunInfo.isValid) { + if ASCEND_IS_AIC { + if (curRunInfo.isFirstS2InnerLoop) { + CrossCoreWaitFlag(constInfo.syncV0C1); + } + CrossCoreWaitFlag(constInfo.syncV1C1); // 反向同步 1 + matmulService.ComputeMm1(curRunInfo); + CrossCoreSetFlag(constInfo.syncC1V1); + if (curRunInfo.isLastS2InnerLoop) { + CrossCoreSetFlag(constInfo.syncC1V0); // 反向同步 0 + } + } else { + if (curRunInfo.isFirstS2InnerLoop) { + CrossCoreWaitFlag(constInfo.syncC1V0); // 反向同步 0 + vectorService.ProcessVec0(curRunInfo); + CrossCoreSetFlag(constInfo.syncV0C1); + } + } + } + + if (lastRunInfo.isValid) { + if ASCEND_IS_AIV { + CrossCoreWaitFlag(constInfo.syncC1V1); + vectorService.ProcessVec1(lastRunInfo); + CrossCoreSetFlag(constInfo.syncV1C1); // 反向同步 1 + } + lastRunInfo.isValid = false; + } +} + +template +__aicore__ inline void LIQPreload::ProcessDecode() +{ + if ASCEND_IS_AIV { + vectorService.InitLDBuffers(pipe); + 
ICachePreLoad(LD_PREFETCH_LEN); + SyncAll(); + if (splitCoreInfo.isLD) { + vectorService.ProcessLD(); + } + } +} +} // namespace LIQKernel +#endif // LIGHTNING_INDEXER_QUANT_KERNEL_H \ No newline at end of file diff --git a/csrc/lightning_indexer_quant/op_kernel/lightning_indexer_quant_service_cube.h b/csrc/lightning_indexer_quant/op_kernel/lightning_indexer_quant_service_cube.h new file mode 100644 index 00000000..2f58a9e1 --- /dev/null +++ b/csrc/lightning_indexer_quant/op_kernel/lightning_indexer_quant_service_cube.h @@ -0,0 +1,613 @@ +/** + * This program is free software, you can redistribute it and/or modify it. + * Copyright (c) 2025 Huawei Technologies Co., Ltd. + * This file is a part of the CANN Open Software. + * Licensed under CANN Open Software License Agreement Version 2.0 (the "License"). + * Please refer to the License for details. You may not use this file except in compliance with the License. + * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. + * See LICENSE in the root of the software repository for the full text of the License. + */ + +/*! 
+ * \file lightning_indexer_quant_service_cube.h + * \brief use 5 buffer for matmul l1, better pipeline + */ +#ifndef LIGHTNING_INDEXER_QUANT_SERVICE_CUBE_H +#define LIGHTNING_INDEXER_QUANT_SERVICE_CUBE_H + +#include "kernel_operator.h" +#include "kernel_operator_list_tensor_intf.h" +#include "kernel_tiling/kernel_tiling.h" +#include "lib/matmul_intf.h" +#include "lib/matrix/matmul/tiling.h" +#include "lightning_indexer_quant_common.h" + +namespace LIQKernel { +using namespace LIQCommon; +struct MmInfo { + int64_t s2L0LoopId; + int64_t s1gL0LoopId; + int64_t s2L0RealSize; + int64_t s2GmOffset; +}; + +template +class LIQMatmul { +public: + using Q_T = typename LIQT::queryType; + using K_T = typename LIQT::keyType; + + __aicore__ inline LIQMatmul(){}; + __aicore__ inline void InitBuffers(TPipe *pipe); + __aicore__ inline void InitMm1GlobalTensor(const GlobalTensor &blkTableGm, const GlobalTensor &keyGm, + const GlobalTensor &queryGm, const GlobalTensor &mm1ResGm, + const GlobalTensor &weightWorkspaceGm); + __aicore__ inline void InitParams(const ConstInfo &constInfo); + __aicore__ inline void AllocEventID(); + __aicore__ inline void FreeEventID(); + __aicore__ inline void ComputeMm1(const LIQCommon::RunInfo &runInfo); + + static constexpr IsResetLoad3dConfig LOAD3DV2_CONFIG = {true, true}; // isSetFMatrix isSetPadding; + static constexpr uint64_t DOUBLE_BUF_NUM = 2; + static constexpr uint64_t L0AB_BUF_NUM = 4; + + static constexpr uint32_t KEY_MTE1_MTE2_EVENT = EVENT_ID2; + static constexpr uint32_t QW_MTE1_MTE2_EVENT = EVENT_ID5; // KEY_MTE1_MTE2_EVENT + DOUBLE_BUF_NUM; + static constexpr uint32_t M_MTE1_EVENT = EVENT_ID3; + static constexpr uint32_t M_FIX_EVENT = EVENT_ID0; + static constexpr uint32_t FIX_M_EVENT = EVENT_ID2; + static constexpr uint32_t FIX_MTE1_EVENT = EVENT_ID4; + + static constexpr uint64_t S8_BLOCK_CUBE = 32; + + static constexpr uint32_t MTE2_MTE1_EVENT = EVENT_ID2; + static constexpr uint32_t MTE1_M_EVENT = EVENT_ID2; + + static constexpr 
uint64_t D_BASIC_BLOCK = 128; + static constexpr uint64_t S1G_BASIC_BLOCK_L1 = 256; + + static constexpr uint64_t S1G_BASIC_BLOCK_L0 = 128; + static constexpr uint64_t S2_BASIC_BLOCK_L0 = 128; + + static constexpr uint64_t QUERY_BUFFER_OFFSET = S1G_BASIC_BLOCK_L1 * D_BASIC_BLOCK; + static constexpr uint64_t SL1_BUFFER_OFFSET = S1G_BASIC_BLOCK_L0 * S2_BASIC_BLOCK_L0; + static constexpr uint64_t KEY_BUFFER_OFFSET = S2_BASIC_BLOCK_L0 * D_BASIC_BLOCK; + static constexpr uint64_t WEIGHT_BUFFER_OFFSET = S1G_BASIC_BLOCK_L1 * BLOCK_CUBE; + static constexpr uint64_t L0AB_BUFFER_OFFSET_S8_16K = 16 * 1024; + static constexpr uint64_t L0AB_BUFFER_OFFSET_FP16_16K = 16 * 512; + static constexpr uint64_t L0C_BUFFER_OFFSET = 64 * 256; + +private: + __aicore__ inline void WeightDmaCopy(uint64_t s1gL1RealSize, const LIQCommon::RunInfo &runInfo); + __aicore__ inline void LoadKeyToL0b(uint64_t s2L0RealSize); + __aicore__ inline void LoadQueryToL0a(uint64_t s1gL1Offset, uint64_t s1gL1RealSize, uint64_t s1gL0RealSize); + __aicore__ inline void QueryNd2Nz(uint64_t s1gL1RealSize, const LIQCommon::RunInfo &runInfo); + __aicore__ inline void KeyNd2NzForPA(uint64_t s2L1RealSize, uint64_t s2GmOffset, const LIQCommon::RunInfo &runInfo); + __aicore__ inline void KeyNd2Nz(uint64_t s2L1RealSize, const MmInfo &mmInfo, const LIQCommon::RunInfo &runInfo); + __aicore__ inline void FixpSToL1(uint64_t s1gL0RealSize, uint64_t s2L0RealSize); + __aicore__ inline void LoadSToL0b(uint64_t s1gL1RealSize, uint64_t s2L0RealSize, uint64_t sL1BufIdx, + int64_t mStartPt); + __aicore__ inline void LoadWeightToL0a(uint64_t s1gL1Offset); + __aicore__ inline void ComputeWs(uint64_t s1gL0RealSize, uint64_t s2L0RealSize, int64_t s1gOffset); + __aicore__ inline void FixpResToGm(uint64_t s1L0RealCount, uint64_t s2L0RealSize, uint64_t s1GmOffset, + uint64_t s2GmOffset, const LIQCommon::RunInfo &runInfo); + __aicore__ inline void ComputeQk(uint64_t s1gL0RealSize, uint64_t s2L0RealSize); + __aicore__ inline void 
ProcessWs(uint64_t s1gL0RealSize, uint64_t s1gL1Offset, uint64_t sL1BufIdx, + const MmInfo &mmInfo, const LIQCommon::RunInfo &runInfo); + __aicore__ inline void ProcessQk(uint64_t s1gL0RealSize, uint64_t s1gL1Offset, uint64_t s1L0LoopCnt, + const MmInfo &mmInfo, const LIQCommon::RunInfo &runInfo); + __aicore__ inline void CalcMmInfo(MmInfo &mmInfo, uint64_t loopIdx, uint64_t s1L0LoopCnt, const MmInfo &lastMmInfo, + const LIQCommon::RunInfo &runInfo); + static constexpr LI_LAYOUT Q_LAYOUT_T = LIQT::layout; + static constexpr LI_LAYOUT K_LAYOUT_T = LIQT::keyLayout; + GlobalTensor blkTableGm_; + GlobalTensor keyGm_; + GlobalTensor queryGm_; + GlobalTensor weightGm_; + GlobalTensor mm1ResGm_; + + TBuf bufQL1_; + LocalTensor queryL1_; + TBuf bufKeyL1_; + LocalTensor keyL1_; + TBuf bufWeightL1_; + LocalTensor weightL1_; + TBuf bufSL1_; + LocalTensor sL1_; + + TBuf bufL0A_; + LocalTensor l0a_; + TBuf bufL0B_; + LocalTensor l0b_; + + TBuf bufL0C_; + LocalTensor cL0_; + + uint64_t keyL1BufIdx_ = 0; + uint64_t qwL1Mte2BufIdx_ = 0; + uint64_t sL1BufIdx_ = 0; + uint64_t l0BufIdx_ = 0; + uint64_t l0cBufIdx_ = 0; + + ConstInfo constInfo_; +}; + +template +__aicore__ inline void LIQMatmul::InitParams(const ConstInfo &constInfo) +{ + constInfo_ = constInfo; +} + +template +__aicore__ inline void LIQMatmul::InitBuffers(TPipe *pipe) +{ + pipe->InitBuffer(bufQL1_, DOUBLE_BUF_NUM * S1G_BASIC_BLOCK_L1 * D_BASIC_BLOCK * sizeof(Q_T)); + queryL1_ = bufQL1_.Get(); + pipe->InitBuffer(bufKeyL1_, DOUBLE_BUF_NUM * S2_BASIC_BLOCK_L0 * D_BASIC_BLOCK * sizeof(K_T)); + keyL1_ = bufKeyL1_.Get(); + + pipe->InitBuffer(bufWeightL1_, DOUBLE_BUF_NUM * S1G_BASIC_BLOCK_L1 * BLOCK_CUBE * sizeof(half)); + weightL1_ = bufWeightL1_.Get(); + pipe->InitBuffer(bufSL1_, DOUBLE_BUF_NUM * S2_BASIC_BLOCK_L0 * S1G_BASIC_BLOCK_L0 * sizeof(half)); + sL1_ = bufSL1_.Get(); + + pipe->InitBuffer(bufL0A_, 64 * 1024); + l0a_ = bufL0A_.Get(); + pipe->InitBuffer(bufL0B_, 64 * 1024); + l0b_ = bufL0B_.Get(); + + 
pipe->InitBuffer(bufL0C_, 128 * 1024); + cL0_ = bufL0C_.Get(); +} + +template +__aicore__ inline void LIQMatmul::InitMm1GlobalTensor(const GlobalTensor &blkTableGm, + const GlobalTensor &keyGm, + const GlobalTensor &queryGm, + const GlobalTensor &mm1ResGm, + const GlobalTensor &weightWorkspaceGm) +{ + blkTableGm_ = blkTableGm; + keyGm_ = keyGm; + queryGm_ = queryGm; + mm1ResGm_ = mm1ResGm; + weightGm_ = weightWorkspaceGm; +} + +template +__aicore__ inline void LIQMatmul::ProcessWs(uint64_t s1gL0RealSize, uint64_t s1gL1Offset, uint64_t sL1BufIdx, + const MmInfo &mmInfo, const LIQCommon::RunInfo &runInfo) +{ + WaitFlag(FIX_M_EVENT + l0cBufIdx_ % DOUBLE_BUF_NUM); + for (int64_t s1gOffset = 0; s1gOffset < s1gL0RealSize; s1gOffset += constInfo_.gSize) { + WaitFlag(M_MTE1_EVENT + l0BufIdx_ % L0AB_BUF_NUM); + LoadSToL0b(s1gL0RealSize, mmInfo.s2L0RealSize, sL1BufIdx, s1gOffset); + LoadWeightToL0a(s1gOffset + s1gL1Offset); + + ComputeWs(s1gL0RealSize, mmInfo.s2L0RealSize, s1gOffset); + + SetFlag(M_MTE1_EVENT + l0BufIdx_ % L0AB_BUF_NUM); + l0BufIdx_++; + } + + FixpResToGm(s1gL0RealSize / constInfo_.gSize, mmInfo.s2L0RealSize, s1gL1Offset / constInfo_.gSize, + mmInfo.s2L0LoopId * S2_BASIC_BLOCK_L0, runInfo); + SetFlag(FIX_M_EVENT + l0cBufIdx_ % DOUBLE_BUF_NUM); + l0cBufIdx_++; +} + +template +__aicore__ inline void LIQMatmul::ProcessQk(uint64_t s1gL0RealSize, uint64_t s1gL1Offset, uint64_t s1L0LoopCnt, + const MmInfo &mmInfo, const LIQCommon::RunInfo &runInfo) +{ + if (mmInfo.s1gL0LoopId == 0) { + WaitFlag(KEY_MTE1_MTE2_EVENT + keyL1BufIdx_ % DOUBLE_BUF_NUM); + if constexpr (K_LAYOUT_T == LI_LAYOUT::PA_BSND) { + KeyNd2NzForPA(mmInfo.s2L0RealSize, runInfo.s2Idx * constInfo_.s2BaseSize + mmInfo.s2GmOffset, runInfo); + } else { + KeyNd2Nz(mmInfo.s2L0RealSize, mmInfo, runInfo); + } + + SetFlag(MTE2_MTE1_EVENT); + WaitFlag(MTE2_MTE1_EVENT); + } + + WaitFlag(M_MTE1_EVENT + l0BufIdx_ % L0AB_BUF_NUM); + LoadQueryToL0a(s1gL1Offset, runInfo.actMBaseSize, s1gL0RealSize); + 
LoadKeyToL0b(mmInfo.s2L0RealSize); + + if (mmInfo.s1gL0LoopId + 1 >= s1L0LoopCnt) { + SetFlag(KEY_MTE1_MTE2_EVENT + keyL1BufIdx_ % DOUBLE_BUF_NUM); + keyL1BufIdx_++; + } + + WaitFlag(FIX_M_EVENT + l0cBufIdx_ % DOUBLE_BUF_NUM); + ComputeQk(s1gL0RealSize, mmInfo.s2L0RealSize); + SetFlag(M_MTE1_EVENT + l0BufIdx_ % L0AB_BUF_NUM); + + FixpSToL1(s1gL0RealSize, mmInfo.s2L0RealSize); + SetFlag(FIX_M_EVENT + l0cBufIdx_ % DOUBLE_BUF_NUM); + l0BufIdx_++; + l0cBufIdx_++; +} + +template +__aicore__ inline void LIQMatmul::CalcMmInfo(MmInfo &mmInfo, uint64_t loopIdx, uint64_t s1L0LoopCnt, + const MmInfo &lastMmInfo, const LIQCommon::RunInfo &runInfo) +{ + mmInfo.s2L0LoopId = loopIdx / s1L0LoopCnt; + mmInfo.s1gL0LoopId = loopIdx % s1L0LoopCnt; + + if (mmInfo.s1gL0LoopId == 0) { + mmInfo.s2GmOffset = mmInfo.s2L0LoopId * S2_BASIC_BLOCK_L0; + mmInfo.s2L0RealSize = mmInfo.s2GmOffset + S2_BASIC_BLOCK_L0 > runInfo.actualSingleProcessSInnerSize + ? runInfo.actualSingleProcessSInnerSize - mmInfo.s2GmOffset + : S2_BASIC_BLOCK_L0; + } else { + mmInfo.s2L0RealSize = lastMmInfo.s2L0RealSize; + } +} + +template +__aicore__ inline void LIQMatmul::ComputeMm1(const LIQCommon::RunInfo &runInfo) +{ + if (runInfo.isFirstS2InnerLoop) { + WaitFlag(QW_MTE1_MTE2_EVENT + qwL1Mte2BufIdx_ % DOUBLE_BUF_NUM); + QueryNd2Nz(runInfo.actMBaseSize, runInfo); // 256 * 128 // L1BasicBlock + WeightDmaCopy(runInfo.actMBaseSize, runInfo); + } + int64_t loopIdx = 0; + int64_t s2L0LoopCnt = CeilDiv(runInfo.actualSingleProcessSInnerSize, S2_BASIC_BLOCK_L0); // 2048取128 + int64_t s1L0LoopCnt = CeilDiv(runInfo.actMBaseSize, S1G_BASIC_BLOCK_L0); // 256取128 + int64_t s1gL1Offset[2] = {0, static_cast(S1G_BASIC_BLOCK_L0)}; + int64_t s1gL0RealSize[2] = {s1L0LoopCnt > 1 ? 
static_cast(S1G_BASIC_BLOCK_L0) : runInfo.actMBaseSize, + runInfo.actMBaseSize - s1gL1Offset[1]}; + MmInfo mmInfo[2]; + CalcMmInfo(mmInfo[loopIdx & 1], loopIdx, s1L0LoopCnt, mmInfo[(loopIdx + 1) & 1], runInfo); + + ProcessQk(s1gL0RealSize[mmInfo[loopIdx & 1].s1gL0LoopId % s1L0LoopCnt], + s1gL1Offset[mmInfo[loopIdx & 1].s1gL0LoopId % s1L0LoopCnt], s1L0LoopCnt, mmInfo[loopIdx & 1], + runInfo); + + SetFlag(FIX_MTE1_EVENT + sL1BufIdx_ % DOUBLE_BUF_NUM); + sL1BufIdx_++; + loopIdx++; + + while (loopIdx < s2L0LoopCnt * s1L0LoopCnt) { + CalcMmInfo(mmInfo[loopIdx & 1], loopIdx, s1L0LoopCnt, mmInfo[(loopIdx + 1) & 1], runInfo); + + ProcessQk(s1gL0RealSize[mmInfo[loopIdx & 1].s1gL0LoopId % s1L0LoopCnt], + s1gL1Offset[mmInfo[loopIdx & 1].s1gL0LoopId % s1L0LoopCnt], s1L0LoopCnt, mmInfo[loopIdx & 1], + runInfo); + + SetFlag(FIX_MTE1_EVENT + sL1BufIdx_ % DOUBLE_BUF_NUM); + sL1BufIdx_++; + + WaitFlag(FIX_MTE1_EVENT + sL1BufIdx_ % DOUBLE_BUF_NUM); + + ProcessWs(s1gL0RealSize[mmInfo[(loopIdx + 1) & 1].s1gL0LoopId % s1L0LoopCnt], + s1gL1Offset[mmInfo[(loopIdx + 1) & 1].s1gL0LoopId % s1L0LoopCnt], sL1BufIdx_, + mmInfo[(loopIdx + 1) & 1], runInfo); + loopIdx++; + } + + WaitFlag(FIX_MTE1_EVENT + (sL1BufIdx_ + 1) % DOUBLE_BUF_NUM); + + ProcessWs(s1gL0RealSize[mmInfo[(loopIdx + 1) & 1].s1gL0LoopId % s1L0LoopCnt], + s1gL1Offset[mmInfo[(loopIdx + 1) & 1].s1gL0LoopId % s1L0LoopCnt], sL1BufIdx_ - 1, + mmInfo[(loopIdx + 1) & 1], runInfo); + + if (runInfo.isLastS2InnerLoop) { + SetFlag(QW_MTE1_MTE2_EVENT + qwL1Mte2BufIdx_ % DOUBLE_BUF_NUM); + qwL1Mte2BufIdx_++; + } +} + +// blkNum, blkSize, N2, D +template +__aicore__ inline void LIQMatmul::KeyNd2NzForPA(uint64_t s2L1RealSize, uint64_t s2GmOffset, + const LIQCommon::RunInfo &runInfo) +{ + uint64_t s2L1Offset = 0; + while (s2L1Offset < s2L1RealSize) { + uint64_t s2BlkId = (s2L1Offset + s2GmOffset) / constInfo_.kCacheBlockSize; + uint64_t s2BlkOffset = (s2L1Offset + s2GmOffset) % constInfo_.kCacheBlockSize; + uint64_t keyGmOffset = 
blkTableGm_.GetValue(runInfo.bIdx * constInfo_.maxBlockNumPerBatch + s2BlkId) * + constInfo_.kCacheBlockSize * constInfo_.kHeadNum * constInfo_.headDim + + s2BlkOffset * constInfo_.headDim; + uint64_t s2Mte2Size = s2L1RealSize - s2L1Offset; + s2Mte2Size = s2BlkOffset + s2Mte2Size >= constInfo_.kCacheBlockSize ? constInfo_.kCacheBlockSize - s2BlkOffset + : s2Mte2Size; + Nd2NzParams nd2nzPara; + nd2nzPara.ndNum = 1; + nd2nzPara.nValue = s2Mte2Size; // 行数 + nd2nzPara.dValue = constInfo_.headDim; + nd2nzPara.srcDValue = constInfo_.headDim; + nd2nzPara.dstNzC0Stride = CeilAlign(s2L1RealSize, (uint64_t)BLOCK_CUBE); // 对齐到16 单位block + nd2nzPara.dstNzNStride = 1; + nd2nzPara.srcNdMatrixStride = 0; + nd2nzPara.dstNzMatrixStride = 0; + DataCopy(keyL1_[(keyL1BufIdx_ % DOUBLE_BUF_NUM) * KEY_BUFFER_OFFSET + s2L1Offset * S8_BLOCK_CUBE], + keyGm_[keyGmOffset], nd2nzPara); + + s2L1Offset += s2Mte2Size; + } +} + +template +__aicore__ inline void LIQMatmul::KeyNd2Nz(uint64_t s2L1RealSize, const MmInfo &mmInfo, + const LIQCommon::RunInfo &runInfo) +{ + uint64_t dStride = constInfo_.headDim; + if constexpr (K_LAYOUT_T == LI_LAYOUT::BSND || K_LAYOUT_T == LI_LAYOUT::TND) { + dStride = constInfo_.headDim * constInfo_.kHeadNum; // constInfo_.kHeadNum + } + Nd2NzParams nd2nzPara; + nd2nzPara.ndNum = 1; + nd2nzPara.nValue = s2L1RealSize; // 行数 + nd2nzPara.dValue = constInfo_.headDim; + nd2nzPara.srcDValue = dStride; + nd2nzPara.dstNzC0Stride = CeilAlign(s2L1RealSize, (uint64_t)BLOCK_CUBE); // 对齐到16 单位block + nd2nzPara.dstNzNStride = 1; + nd2nzPara.srcNdMatrixStride = 0; + nd2nzPara.dstNzMatrixStride = 0; + // 默认一块buf最多放两份 + DataCopy(keyL1_[(keyL1BufIdx_ % DOUBLE_BUF_NUM) * KEY_BUFFER_OFFSET], + keyGm_[runInfo.tensorKeyOffset + mmInfo.s2GmOffset * constInfo_.headDim], nd2nzPara); +} + +// batch, s1, g, 1 +template +__aicore__ inline void LIQMatmul::WeightDmaCopy(uint64_t s1gL1RealSize, const LIQCommon::RunInfo &runInfo) +{ + DataCopyParams copyInParams; + copyInParams.blockCount = 1; + 
copyInParams.blockLen = s1gL1RealSize; + copyInParams.srcStride = 0; + copyInParams.dstStride = 0; + DataCopy(weightL1_[(qwL1Mte2BufIdx_ % DOUBLE_BUF_NUM) * WEIGHT_BUFFER_OFFSET], + weightGm_[runInfo.loop % DOUBLE_BUF_NUM * BLOCK_CUBE * constInfo_.mBaseSize], copyInParams); +} + +// batch, s1, n2, g, d +template +__aicore__ inline void LIQMatmul::QueryNd2Nz(uint64_t s1gL1RealSize, const LIQCommon::RunInfo &runInfo) +{ + Nd2NzParams nd2nzPara; + nd2nzPara.ndNum = 1; + nd2nzPara.nValue = s1gL1RealSize; // 行数 + nd2nzPara.dValue = constInfo_.headDim; + nd2nzPara.srcDValue = constInfo_.headDim; + nd2nzPara.dstNzC0Stride = CeilAlign(s1gL1RealSize, (uint64_t)BLOCK_CUBE); // 对齐到16 单位block + nd2nzPara.dstNzNStride = 1; + nd2nzPara.srcNdMatrixStride = 0; + nd2nzPara.dstNzMatrixStride = 0; + // 默认一块buf最多放两份 + DataCopy(queryL1_[(qwL1Mte2BufIdx_ % DOUBLE_BUF_NUM) * QUERY_BUFFER_OFFSET], queryGm_[runInfo.tensorQueryOffset], + nd2nzPara); +} + +// s1g, d +template +__aicore__ inline void LIQMatmul::LoadQueryToL0a(uint64_t s1gL1Offset, uint64_t s1gL1RealSize, + uint64_t s1gL0RealSize) +{ + LoadData3DParamsV2 loadData3DParams; + // SetFmatrixParams + loadData3DParams.l1H = CeilDiv(s1gL1RealSize, BLOCK_CUBE); // Hin=M1=8 + loadData3DParams.l1W = BLOCK_CUBE; // Win=M0 + loadData3DParams.channelSize = constInfo_.headDim; // Cin=K + + loadData3DParams.padList[0] = 0; + loadData3DParams.padList[1] = 0; + loadData3DParams.padList[2] = 0; + loadData3DParams.padList[3] = 255; // 尾部数据不影响滑窗的结果 + + // SetLoadToA0Params + loadData3DParams.mExtension = CeilAlign(s1gL0RealSize, BLOCK_CUBE); // M height维度目的 + loadData3DParams.kExtension = constInfo_.headDim; // K width维度目的 + loadData3DParams.mStartPt = s1gL1Offset; + loadData3DParams.kStartPt = 0; + loadData3DParams.strideW = 1; + loadData3DParams.strideH = 1; + loadData3DParams.filterW = 1; + loadData3DParams.filterSizeW = (1 >> 8) & 255; + loadData3DParams.filterH = 1; + loadData3DParams.filterSizeH = (1 >> 8) & 255; + 
loadData3DParams.dilationFilterW = 1; + loadData3DParams.dilationFilterH = 1; + loadData3DParams.enTranspose = 0; + loadData3DParams.fMatrixCtrl = 0; + + LoadData(l0a_[(l0BufIdx_ % L0AB_BUF_NUM) * L0AB_BUFFER_OFFSET_S8_16K], + queryL1_[(qwL1Mte2BufIdx_ % DOUBLE_BUF_NUM) * QUERY_BUFFER_OFFSET], + loadData3DParams); +} + +// s1, g, s2 --> 2 * 64* 128 +template +__aicore__ inline void LIQMatmul::LoadSToL0b(uint64_t s1gL1RealSize, uint64_t s2L0RealSize, uint64_t sL1BufIdx, + int64_t mStartPt) +{ + LoadData3DParamsV2 loadData3DParams; + // SetFmatrixParams + loadData3DParams.l1H = S1G_BASIC_BLOCK_L0 / BLOCK_CUBE; // Hin=M1=8 + loadData3DParams.l1W = BLOCK_CUBE; // Win=M0 + loadData3DParams.channelSize = CeilAlign(s2L0RealSize, BLOCK_CUBE); // Cin=K + + loadData3DParams.padList[0] = 0; + loadData3DParams.padList[1] = 0; + loadData3DParams.padList[2] = 0; + loadData3DParams.padList[3] = 255; // 尾部数据不影响滑窗的结果 + + // SetLoadToA0Params + loadData3DParams.mExtension = constInfo_.gSize; // M height维度目的 + loadData3DParams.kExtension = CeilAlign(s2L0RealSize, BLOCK_CUBE); // K width维度目的 + loadData3DParams.kStartPt = 0; + loadData3DParams.strideW = 1; + loadData3DParams.strideH = 1; + loadData3DParams.filterW = 1; + loadData3DParams.filterSizeW = (1 >> 8) & 255; + loadData3DParams.filterH = 1; + loadData3DParams.filterSizeH = (1 >> 8) & 255; + loadData3DParams.dilationFilterW = 1; + loadData3DParams.dilationFilterH = 1; + loadData3DParams.enTranspose = 1; + loadData3DParams.fMatrixCtrl = 0; + + loadData3DParams.mStartPt = mStartPt; + LoadData( + l0b_.template ReinterpretCast()[(l0BufIdx_ % L0AB_BUF_NUM) * L0AB_BUFFER_OFFSET_FP16_16K], + sL1_[(sL1BufIdx % DOUBLE_BUF_NUM) * SL1_BUFFER_OFFSET], loadData3DParams); +} + +// s1,g,1(16), 2,64,16 +template +__aicore__ inline void LIQMatmul::LoadWeightToL0a(uint64_t s1gL1Offset) +{ + LoadData2DParams loadData2DParams; + loadData2DParams.startIndex = 0; + loadData2DParams.repeatTimes = CeilDiv(constInfo_.gSize, BLOCK_CUBE); + 
loadData2DParams.srcStride = 1; + loadData2DParams.dstGap = 0; + loadData2DParams.ifTranspose = true; + LoadData(l0a_.template ReinterpretCast()[(l0BufIdx_ % L0AB_BUF_NUM) * L0AB_BUFFER_OFFSET_FP16_16K], + weightL1_[(qwL1Mte2BufIdx_ % DOUBLE_BUF_NUM) * WEIGHT_BUFFER_OFFSET + s1gL1Offset* BLOCK_CUBE], + loadData2DParams); +} + +// s2, d -> 128,128 +template +__aicore__ inline void LIQMatmul::LoadKeyToL0b(uint64_t s2L0RealSize) +{ + LoadData2DParams loadData2DParams; + loadData2DParams.startIndex = 0; + loadData2DParams.repeatTimes = CeilDiv(s2L0RealSize, BLOCK_CUBE) * CeilDiv(constInfo_.headDim, S8_BLOCK_CUBE); + loadData2DParams.srcStride = 1; + loadData2DParams.dstGap = 0; + loadData2DParams.ifTranspose = false; + LoadData(l0b_[(l0BufIdx_ % L0AB_BUF_NUM) * L0AB_BUFFER_OFFSET_S8_16K], + keyL1_[(keyL1BufIdx_ % DOUBLE_BUF_NUM) * KEY_BUFFER_OFFSET], loadData2DParams); +} + +// A: s1,g,1(16) B: s1,g,s2 C: s1, 1(16), s2 +template +__aicore__ inline void LIQMatmul::ComputeWs(uint64_t s1gL0RealSize, uint64_t s2L0RealSize, int64_t s1gOffset) +{ + SetFlag(MTE1_M_EVENT); + WaitFlag(MTE1_M_EVENT); + MmadParams mmadParams; + mmadParams.m = BLOCK_CUBE; + mmadParams.n = s2L0RealSize; + mmadParams.k = constInfo_.gSize; + mmadParams.cmatrixInitVal = true; + mmadParams.cmatrixSource = false; + Mmad(cL0_.template ReinterpretCast()[(l0cBufIdx_ % DOUBLE_BUF_NUM) * L0C_BUFFER_OFFSET + + s1gOffset * S2_BASIC_BLOCK_L0], + l0a_.template ReinterpretCast()[(l0BufIdx_ % L0AB_BUF_NUM) * L0AB_BUFFER_OFFSET_FP16_16K], + l0b_.template ReinterpretCast()[(l0BufIdx_ % L0AB_BUF_NUM) * L0AB_BUFFER_OFFSET_FP16_16K], + mmadParams); +} + +template +__aicore__ inline void LIQMatmul::ComputeQk(uint64_t s1gL0RealSize, uint64_t s2L0RealSize) +{ + SetFlag(MTE1_M_EVENT); + WaitFlag(MTE1_M_EVENT); + + MmadParams mmadParams; + mmadParams.m = CeilAlign(s1gL0RealSize, BLOCK_CUBE); + mmadParams.n = s2L0RealSize; + mmadParams.k = constInfo_.headDim; + mmadParams.cmatrixInitVal = true; + mmadParams.cmatrixSource = 
false; + Mmad(cL0_[(l0cBufIdx_ % DOUBLE_BUF_NUM) * L0C_BUFFER_OFFSET], + l0a_[(l0BufIdx_ % L0AB_BUF_NUM) * L0AB_BUFFER_OFFSET_S8_16K], + l0b_[(l0BufIdx_ % L0AB_BUF_NUM) * L0AB_BUFFER_OFFSET_S8_16K], mmadParams); + if ((mmadParams.m / 16) * (mmadParams.n / 16) < 10) { + PipeBarrier(); + } +} + +template +__aicore__ inline void LIQMatmul::FixpSToL1(uint64_t s1gL0RealSize, uint64_t s2L0RealSize) +{ + SetFlag(M_FIX_EVENT); + WaitFlag(M_FIX_EVENT); + DataCopyCO12DstParams params; + params.mSize = CeilAlign(s1gL0RealSize, BLOCK_CUBE); + params.nSize = CeilAlign(s2L0RealSize, BLOCK_CUBE); + params.dstStride = S1G_BASIC_BLOCK_L0; + params.srcStride = params.mSize; + params.quantPre = QuantMode_t::DEQF16; + params.reluPre = 1; + params.channelSplit = 0; + params.nz2ndEn = 0; + SetFixpipePreQuantFlag(0x3a800000); + DataCopy(sL1_[(sL1BufIdx_ % DOUBLE_BUF_NUM) * SL1_BUFFER_OFFSET], + cL0_[(l0cBufIdx_ % DOUBLE_BUF_NUM) * L0C_BUFFER_OFFSET], params); +} + +template +__aicore__ inline void LIQMatmul::FixpResToGm(uint64_t s1L0RealCount, uint64_t s2L0RealSize, uint64_t s1GmOffset, + uint64_t s2GmOffset, const LIQCommon::RunInfo &runInfo) +{ + SetFlag(M_FIX_EVENT); + WaitFlag(M_FIX_EVENT); + + AscendC::DataCopyCO12DstParams intriParams; + intriParams.mSize = 1; + intriParams.nSize = s2L0RealSize; + intriParams.dstStride = constInfo_.s2BaseSize; + intriParams.srcStride = 16; + // set mode according to dtype + intriParams.quantPre = QuantMode_t::NoQuant; + intriParams.nz2ndEn = true; + intriParams.reluPre = 0; + AscendC::SetFixpipeNz2ndFlag(s1L0RealCount, CeilDiv(constInfo_.gSize, BLOCK_CUBE) * S2_BASIC_BLOCK_L0 / BLOCK_CUBE, + 2048); + AscendC::DataCopy(mm1ResGm_[(runInfo.loop % 2) * constInfo_.mBaseSize / constInfo_.gSize * constInfo_.s2BaseSize + + s1GmOffset * intriParams.dstStride + s2GmOffset], + cL0_.template ReinterpretCast()[(l0cBufIdx_ % DOUBLE_BUF_NUM) * L0C_BUFFER_OFFSET], + intriParams); +} + +template +__aicore__ inline void LIQMatmul::AllocEventID() +{ + 
SetFlag(KEY_MTE1_MTE2_EVENT + 0); + SetFlag(KEY_MTE1_MTE2_EVENT + 1); + SetFlag(KEY_MTE1_MTE2_EVENT + 2); + + SetFlag(QW_MTE1_MTE2_EVENT + 0); + SetFlag(QW_MTE1_MTE2_EVENT + 1); + + SetFlag(M_MTE1_EVENT + 0); + SetFlag(M_MTE1_EVENT + 1); + SetFlag(M_MTE1_EVENT + 2); + SetFlag(M_MTE1_EVENT + 3); + + SetFlag(FIX_M_EVENT + 0); + SetFlag(FIX_M_EVENT + 1); +} + +template +__aicore__ inline void LIQMatmul::FreeEventID() +{ + WaitFlag(KEY_MTE1_MTE2_EVENT + 0); + WaitFlag(KEY_MTE1_MTE2_EVENT + 1); + WaitFlag(KEY_MTE1_MTE2_EVENT + 2); + + WaitFlag(QW_MTE1_MTE2_EVENT + 0); + WaitFlag(QW_MTE1_MTE2_EVENT + 1); + + WaitFlag(M_MTE1_EVENT + 0); + WaitFlag(M_MTE1_EVENT + 1); + WaitFlag(M_MTE1_EVENT + 2); + WaitFlag(M_MTE1_EVENT + 3); + + WaitFlag(FIX_M_EVENT + 0); + WaitFlag(FIX_M_EVENT + 1); +} +} // namespace LIQKernel +#endif \ No newline at end of file diff --git a/csrc/lightning_indexer_quant/op_kernel/lightning_indexer_quant_service_vector.h b/csrc/lightning_indexer_quant/op_kernel/lightning_indexer_quant_service_vector.h new file mode 100644 index 00000000..2588998c --- /dev/null +++ b/csrc/lightning_indexer_quant/op_kernel/lightning_indexer_quant_service_vector.h @@ -0,0 +1,665 @@ +/** + * This program is free software, you can redistribute it and/or modify it. + * Copyright (c) 2025 Huawei Technologies Co., Ltd. + * This file is a part of the CANN Open Software. + * Licensed under CANN Open Software License Agreement Version 2.0 (the "License"). + * Please refer to the License for details. You may not use this file except in compliance with the License. + * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. + * See LICENSE in the root of the software repository for the full text of the License. + */ + +/*! 
+ * \file lightning_indexer_quant_service_vector.h + * \brief + */ +#ifndef LIGHTNING_INDEXER_QUANT_SERVICE_VECTOR_H +#define LIGHTNING_INDEXER_QUANT_SERVICE_VECTOR_H + +#include "kernel_operator.h" +#include "kernel_operator_list_tensor_intf.h" +#include "kernel_tiling/kernel_tiling.h" +#include "lib/matmul_intf.h" +#include "lib/matrix/matmul/tiling.h" +#include "lightning_indexer_quant_common.h" +#include "lightning_indexer_quant_vector.h" + +namespace LIQKernel { +using namespace LIQCommon; +using namespace LIQServiceVec; +constexpr uint32_t BASE_TOPK = 2048; +constexpr uint32_t BASE_TOPK_VALUE_IDX_SIZE = 4096; +constexpr uint32_t LD_PARAM_NUM = 16; + +template +class LIQVector { +public: + // =================================类型定义区================================= + static constexpr LI_LAYOUT Q_LAYOUT_T = LIQT::layout; + static constexpr LI_LAYOUT K_LAYOUT_T = LIQT::keyLayout; + static constexpr bool PAGE_ATTENTION = LIQT::pageAttention; + // MM输出数据类型, 当前只支持float + using MM1_OUT_T = float; + + __aicore__ inline LIQVector(){}; + __aicore__ inline void ProcessVec0(const LIQCommon::RunInfo &info); + __aicore__ inline void ProcessVec1(const LIQCommon::RunInfo &info); + __aicore__ inline void ProcessLD(); + __aicore__ inline void InitBuffers(TPipe *pipe); + __aicore__ inline void InitParams(const struct LIQCommon::ConstInfo &constInfo, + const LIQTilingData *__restrict tilingData); + __aicore__ inline void InitVecWorkspaceTensor(GlobalTensor vec0OutGm, GlobalTensor mm1ResGm, + GlobalTensor vec1ResGm, GlobalTensor vec1ParamGm); + __aicore__ inline void InitVecInputTensor(GlobalTensor weightsGm, GlobalTensor qScaleGm, + GlobalTensor kScaleGm, GlobalTensor indiceOutGm, + GlobalTensor blockTableGm); + __aicore__ inline void CleanInvalidOutput(int64_t invalidS1offset); + __aicore__ inline void AllocEventID(); + __aicore__ inline void FreeEventID(); + __aicore__ inline void InitLDBuffers(TPipe *pipe); + +protected: + GlobalTensor mm1ResGm; + GlobalTensor vec1ResGm; + 
GlobalTensor vec1ParamGm; + GlobalTensor weightsGm; + GlobalTensor qScaleGm; + GlobalTensor kScaleGm; + GlobalTensor vec0OutGm; + GlobalTensor indiceOutGm; + GlobalTensor blockTableGm; + // =================================常量区================================= + +private: + __aicore__ inline void GetKeyScale(const LIQCommon::RunInfo &runInfo, const LocalTensor &resUb, + int64_t batchId, int64_t startS2, int64_t getLen); + // ================================Local Buffer区==================================== + // queue + TQue inQueue_; + TQue outQueue_; + + // tmp buff for vector + TBuf sortOutBuf_; + TBuf indexBuf_; + TBuf paramBuf_; + TBuf tmpBuf_; + + // tmp buff for LD + TBuf<> ldToBeMrgBuf_; + TBuf<> ldTmpBuf_; + TBuf<> ldOutValueBuf_; + TBuf<> ldOutIdxBuf_; + + LocalTensor globalTopkIndice_; + LocalTensor globalTopkUb_; + + int32_t blockId_ = -1; + // para for vector + int32_t groupInner_ = 0; + int32_t globalTopkNum_ = 0; + int64_t blockS2StartIdx_ = 0; + int32_t gSize_ = 0; + int32_t kSeqSize_ = 0; + int32_t kHeadNum_ = 0; + int32_t qHeadNum_ = 0; + int32_t s1BaseSize_ = 0; + int32_t s2BaseSize_ = 0; + int32_t kCacheBlockSize_ = 0; + int32_t maxBlockNumPerBatch_ = 0; + + // para for LD + uint32_t mrgListNum_ = 4; + uint32_t paramNum_ = 16; + + struct LIQCommon::ConstInfo constInfo_; +}; + +template +__aicore__ inline void LIQVector::GetKeyScale(const LIQCommon::RunInfo &runInfo, const LocalTensor &resUb, + int64_t batchId, int64_t startS2, int64_t getLen) +{ + // startS2一定能整除kCacheBlockSize_ + AscendC::DataCopyPadExtParams padParams{false, 0, 0, 0}; + AscendC::DataCopyExtParams copyInParams; + if constexpr (PAGE_ATTENTION) { + int32_t startBlockTableIdx = startS2 / kCacheBlockSize_; + int32_t startBlockTableOffset = startS2 % kCacheBlockSize_; + int32_t blockTableBatchOffset = batchId * maxBlockNumPerBatch_; + copyInParams.blockCount = 1; + copyInParams.srcStride = 0; + copyInParams.dstStride = 0; + copyInParams.rsv = 0; + int32_t resUbBaseOffset = 0; + if 
(startBlockTableOffset > 0) { + int32_t firstPartLen = + kCacheBlockSize_ - startBlockTableOffset > getLen ? getLen : kCacheBlockSize_ - startBlockTableOffset; + copyInParams.blockLen = firstPartLen * sizeof(half); + int32_t blockId = blockTableGm.GetValue(blockTableBatchOffset + startBlockTableIdx); + SetWaitFlag(HardEvent::S_MTE2); + AscendC::DataCopyPad(resUb, kScaleGm[blockId * kCacheBlockSize_ + startBlockTableOffset], + copyInParams, padParams); + startBlockTableIdx++; + getLen = getLen - firstPartLen; + resUbBaseOffset = firstPartLen; + } + int32_t getLoopNum = CeilDiv(getLen, kCacheBlockSize_); + copyInParams.blockLen = kCacheBlockSize_ * sizeof(half); + for (int32_t i = 0; i < getLoopNum; i++) { + if (i == getLoopNum - 1) { + copyInParams.blockLen = (getLen - i * kCacheBlockSize_) * sizeof(half); + } + int32_t blockId = blockTableGm.GetValue(blockTableBatchOffset + startBlockTableIdx + i); + SetWaitFlag(HardEvent::S_MTE2); + AscendC::DataCopyPad(resUb[resUbBaseOffset + i * kCacheBlockSize_], kScaleGm[blockId * kCacheBlockSize_], + copyInParams, padParams); + } + } else { + copyInParams.blockCount = 1; + copyInParams.blockLen = getLen * sizeof(half); + copyInParams.srcStride = 0; + copyInParams.dstStride = 0; + copyInParams.rsv = 0; + AscendC::DataCopyPad(resUb, kScaleGm[runInfo.tensorKeyScaleOffset], copyInParams, padParams); + } +} + +template +__aicore__ inline void LIQVector::InitBuffers(TPipe *pipe) +{ + pipe->InitBuffer(paramBuf_, LD_PARAM_NUM * sizeof(int64_t)); // 1 KB + pipe->InitBuffer(inQueue_, 2, s2BaseSize_ * sizeof(float) * 2); // 32KB + pipe->InitBuffer(outQueue_, 1, BASE_TOPK * sizeof(float)); // 8 KB + pipe->InitBuffer(indexBuf_, s2BaseSize_ * sizeof(int32_t)); // 8 KB + pipe->InitBuffer(tmpBuf_, 64 * 1024); // 64KB + pipe->InitBuffer(sortOutBuf_, CeilDiv(s1BaseSize_, 2) * BASE_TOPK_VALUE_IDX_SIZE * sizeof(float)); // 32KB + + globalTopkIndice_ = indexBuf_.Get(); + globalTopkUb_ = sortOutBuf_.Get(); + globalTopkNum_ = 0; + + // 
Initialize UB and GM before basic-block processing starts
+    // step1. Build an ascending index sequence 0 .. s2BaseSize_
+    ArithProgression(globalTopkIndice_, 0, 1, s2BaseSize_);
+    // step2. Fill globalTopkUb_ [CeilDiv(s1BaseSize_, 2), BASE_TOPK, 2] with (-inf, -1)
+    InitSortOutBuf(globalTopkUb_, CeilDiv(s1BaseSize_, 2) * BASE_TOPK_VALUE_IDX_SIZE);
+
+    // step3. Initialize vec1ParamGm; the "do LD" flag is set to -1 (needFd=-1)
+    // vec1ResIn32Gm = [aic, 2, s1BaseSize_, 16] int32
+    // clear workspace: [needFd, s2AcSeq, s2Start, s2End, isS2End, bn2idx, s1Idx, ......]
+    LocalTensor tmpfBuff = outQueue_.AllocTensor();
+    Duplicate(tmpfBuff.template ReinterpretCast(), -1, 2 * (s1BaseSize_ / 2) * paramNum_ * 2);
+    SetWaitFlag(HardEvent::V_MTE3);
+    int64_t wsInfoOffset = (blockId_ / 2) * s1BaseSize_ * 2 * paramNum_ +  // offset shared by the 2 AIVs
+                           (blockId_ % 2) * (s1BaseSize_ / 2) * 2 * paramNum_;  // per-AIV offset along S1
+    DataCopyPad(vec1ParamGm[wsInfoOffset], tmpfBuff.template ReinterpretCast(),
+                {1, static_cast((s1BaseSize_ / 2) * 2 * paramNum_ * sizeof(int64_t)), 0, 0});
+    SetWaitFlag(HardEvent::MTE3_V);
+    outQueue_.FreeTensor(tmpfBuff);
+}
+
+template
+__aicore__ inline void LIQVector::InitLDBuffers(TPipe *pipe)
+{
+    pipe->Reset();
+    pipe->InitBuffer(ldToBeMrgBuf_, BASE_TOPK_VALUE_IDX_SIZE * mrgListNum_ * sizeof(float));
+    pipe->InitBuffer(ldTmpBuf_, BASE_TOPK_VALUE_IDX_SIZE * mrgListNum_ * sizeof(float));
+    pipe->InitBuffer(ldOutValueBuf_, BASE_TOPK * sizeof(float));
+    pipe->InitBuffer(ldOutIdxBuf_, BASE_TOPK * sizeof(int32_t));
+}
+
+template
+__aicore__ inline void LIQVector::InitParams(const struct LIQCommon::ConstInfo &constInfo,
+                                             const LIQTilingData *__restrict tilingData)
+{
+    this->constInfo_ = constInfo;
+    blockS2StartIdx_ = 0;
+    gSize_ = constInfo.gSize;
+    kSeqSize_ = constInfo.kSeqSize;
+    // define N2 para
+    kHeadNum_ = constInfo.kHeadNum;
+    qHeadNum_ = constInfo.qHeadNum;
+    // define MMBase para
+    s1BaseSize_ = constInfo.s1BaseSize; // 4
+    s2BaseSize_ = constInfo.s2BaseSize; // 2048
+    kCacheBlockSize_ = constInfo.kCacheBlockSize;
+    maxBlockNumPerBatch_ = constInfo.maxBlockNumPerBatch;
+    
blockId_ = GetBlockIdx(); +} + +template +__aicore__ inline void LIQVector::InitVecInputTensor(GlobalTensor weightsGm, GlobalTensor qScaleGm, + GlobalTensor kScaleGm, + GlobalTensor indiceOutGm, + GlobalTensor blockTableGm) +{ + this->weightsGm = weightsGm; + this->qScaleGm = qScaleGm; + this->kScaleGm = kScaleGm; + this->indiceOutGm = indiceOutGm; + this->blockTableGm = blockTableGm; +} + +template +__aicore__ inline void LIQVector::InitVecWorkspaceTensor(GlobalTensor vec0OutGm, + GlobalTensor mm1ResGm, + GlobalTensor vec1ResGm, + GlobalTensor vec1ParamGm) +{ + this->mm1ResGm = mm1ResGm; + this->vec1ResGm = vec1ResGm; + this->vec0OutGm = vec0OutGm; + this->vec1ParamGm = vec1ParamGm; +} + +template +__aicore__ inline void LIQVector::AllocEventID() +{ +} + +template +__aicore__ inline void LIQVector::FreeEventID() +{ +} + +template +__aicore__ inline void LIQVector::CleanInvalidOutput(int64_t invalidS1offset) +{ + // init -1 and copy to output + LocalTensor valueULocal = outQueue_.AllocTensor(); + LocalTensor idxULocal1 = valueULocal.template ReinterpretCast(); + Duplicate(idxULocal1, constInfo_.INVALID_IDX, constInfo_.sparseCount); + outQueue_.EnQue(valueULocal); + valueULocal = outQueue_.DeQue(); + LIQServiceVec::CopyOut(indiceOutGm[invalidS1offset], idxULocal1, constInfo_.sparseCount); + outQueue_.FreeTensor(valueULocal); +} + +template +__aicore__ inline void LIQVector::ProcessVec0(const LIQCommon::RunInfo &info) +{ + // 只需要一个v核做 + if (blockId_ % 2 != 0) { + return; + } + int32_t cuBaseS1Idx = info.gS1Idx * s1BaseSize_; + // 计算输出w基地址偏移 偶数循环 -> 0 + aic_offset 奇数循环 -> 4*64 + aic_offset + int64_t vec0OutGmOffset = (info.loop % 2) * ((s1BaseSize_ * gSize_ * BLOCK_CUBE)); + // 计算输入weight的地址偏移,qScale的地址偏移与weight相同 + int64_t weightGmOffset = info.tensorWeightsOffset + cuBaseS1Idx * qHeadNum_; + // 当前需要计算的S1行数,处理尾块场景 + int32_t cuS1ProcNum = cuBaseS1Idx + s1BaseSize_ > info.actS1Size ? 
info.actS1Size % s1BaseSize_ : s1BaseSize_; + int32_t cuProcEleNum = cuS1ProcNum * gSize_; + + LocalTensor inWeightsUb = inQueue_.AllocTensor(); + LocalTensor inQScaleUb = inWeightsUb[cuProcEleNum]; + AscendC::DataCopyPadExtParams padParams{false, 0, 0, 0}; + AscendC::DataCopyExtParams copyInParams; + copyInParams.blockCount = 1; + copyInParams.blockLen = cuProcEleNum * sizeof(half); + copyInParams.srcStride = 0; + copyInParams.dstStride = 0; + copyInParams.rsv = 0; + AscendC::DataCopyPad(inWeightsUb, weightsGm[weightGmOffset], copyInParams, padParams); + AscendC::DataCopyPad(inQScaleUb, qScaleGm[weightGmOffset], copyInParams, padParams); + + inQueue_.EnQue(inWeightsUb); + inWeightsUb = inQueue_.DeQue(); + AscendC::Mul(inWeightsUb, inWeightsUb, inQScaleUb, cuProcEleNum); + PipeBarrier(); + LocalTensor resUb = outQueue_.AllocTensor(); + AscendC::Brcb(resUb, inWeightsUb, static_cast(cuProcEleNum / 8), {1, 8}); + inQueue_.FreeTensor(inWeightsUb); + + outQueue_.EnQue(resUb); + resUb = outQueue_.DeQue(); + AscendC::DataCopyParams copyOutParams; + copyOutParams.blockCount = 1; + copyOutParams.blockLen = cuProcEleNum * BLOCK_CUBE * sizeof(half); + copyOutParams.srcStride = 0; + copyOutParams.dstStride = 0; + AscendC::DataCopyPad(vec0OutGm[vec0OutGmOffset], resUb, copyOutParams); + outQueue_.FreeTensor(resUb); +} + +template +__aicore__ inline void LIQVector::ProcessVec1(const LIQCommon::RunInfo &info) +{ + int32_t cuBaseS1Idx = info.gS1Idx * s1BaseSize_; + int32_t cuBaseS2Idx = info.s2Idx * s2BaseSize_; + + // 计算基本块基地址偏移 偶数循环 -> 0 + aic_offset 奇数循环 -> 4*2048 + aic_offset + int64_t mmGmOffset = (info.loop % 2) * (s1BaseSize_ * s2BaseSize_); + + // cuS1BeginIdxPerAiv: 每个AIV的S1起始偏移 + int32_t cuS1BeginIdxPerAiv = cuBaseS1Idx; + int32_t cuS1ProcNum = + cuS1BeginIdxPerAiv + s1BaseSize_ > info.actS1Size ? info.actS1Size % s1BaseSize_ : s1BaseSize_; + // cuS1ProcNumPerAiv: 每个AIv的S1计算量 + int32_t cuS1ProcNumPerAiv = blockId_ % 2 == 0 ? 
CeilDiv(cuS1ProcNum, 2) : (cuS1ProcNum / 2);
+    cuS1BeginIdxPerAiv += (blockId_ % 2) * CeilDiv(cuS1ProcNum, 2);
+    // odd-numbered core adds one S1 stride to the basic-block base address offset
+    mmGmOffset += (blockId_ % 2) * CeilDiv(cuS1ProcNum, 2) * s2BaseSize_;
+    // not the first basic block: the M(S1) axis switched, so re-initialization is needed
+    if (info.loop != 0 && info.s2Idx == 0) {
+        // globalTopkUb_ value,index=-inf,-1
+        InitSortOutBuf(globalTopkUb_, CeilDiv(s1BaseSize_, 2) * BASE_TOPK_VALUE_IDX_SIZE);
+        blockS2StartIdx_ = 0;
+    } else if (info.loop == 0) {
+        blockS2StartIdx_ = info.s2Idx;
+    }
+    // cuRealAcSeq: the AcSeq corresponding to this basic block's S1
+    int32_t cuRealAcSeq = info.actS2Size;
+    if (constInfo_.attenMaskFlag) {
+        // case where attenMask is true
+        cuRealAcSeq = info.actS2Size - (info.actS1Size - cuS1BeginIdxPerAiv);
+    }
+
+    // S1-direction offset of the LD output; keeps the two vector cores' outputs contiguous
+    uint32_t ldS1Offset = (blockId_ % 2 == 0) ? s1BaseSize_ / 2 - cuS1ProcNumPerAiv : 0;
+    for (int innerS1Idx = 0; innerS1Idx < cuS1ProcNumPerAiv; innerS1Idx++) {
+        if (constInfo_.attenMaskFlag) {
+            cuRealAcSeq += 1;
+        }
+        int32_t cuS2Len = cuBaseS2Idx + s2BaseSize_ >= cuRealAcSeq ? 
cuRealAcSeq - cuBaseS2Idx : s2BaseSize_; + int32_t cuS1Idx = cuS1BeginIdxPerAiv + innerS1Idx; + if (cuRealAcSeq > 0 && cuS2Len > 0) { + int32_t cuS2LenVecAlign = CeilDiv(cuS2Len, s2BaseSize_) * s2BaseSize_; + LocalTensor mmInUb = inQueue_.AllocTensor(); + LocalTensor kScaleUb = mmInUb[cuS2LenVecAlign]; + LocalTensor kScaleTUb = kScaleUb.template ReinterpretCast()[cuS2LenVecAlign]; + AscendC::DataCopyPadExtParams padParams{false, 0, 0, 0}; + AscendC::DataCopyPadExtParams padTParams{false, 0, 0, 0}; + AscendC::DataCopyExtParams copyInParams; + copyInParams.blockCount = 1; + copyInParams.blockLen = cuS2Len * sizeof(float); + copyInParams.srcStride = 0; + copyInParams.dstStride = 0; + copyInParams.rsv = 0; + AscendC::DataCopyPad(mmInUb, mm1ResGm[mmGmOffset + innerS1Idx * s2BaseSize_], copyInParams, padParams); + GetKeyScale(info, kScaleTUb, info.bIdx, cuBaseS2Idx, cuS2Len); + inQueue_.EnQue(mmInUb); + mmInUb = inQueue_.DeQue(); + AscendC::Cast(kScaleUb, kScaleTUb, RoundMode::CAST_NONE, cuS2Len); + PipeBarrier(); + AscendC::Mul(mmInUb, mmInUb, kScaleUb, cuS2Len); + PipeBarrier(); + LocalTensor sortBuff = tmpBuf_.Get(); + LocalTensor sortScoreUb = sortBuff; + LocalTensor sortIndiceUb = sortBuff[cuS2LenVecAlign]; + PipeBarrier(); + Duplicate(sortScoreUb.template ReinterpretCast(), LIQServiceVec::NEG_INF, cuS2LenVecAlign); + PipeBarrier(); + Adds(sortScoreUb, mmInUb, 0.0f, cuS2Len); + PipeBarrier(); + inQueue_.FreeTensor(mmInUb); + LocalTensor sortIndiceUbInt = sortIndiceUb.template ReinterpretCast(); + // 无效数据索引填充为-1 + if (cuS2LenVecAlign != cuS2Len) { + Duplicate(sortIndiceUbInt, -1, cuS2LenVecAlign); + PipeBarrier(); + } + Adds(sortIndiceUbInt, globalTopkIndice_, static_cast(cuBaseS2Idx), cuS2Len); + PipeBarrier(); + LocalTensor tmpSortBuf = sortBuff[2 * cuS2LenVecAlign]; + LIQServiceVec::SortAll(sortBuff, tmpSortBuf, cuS2LenVecAlign); + PipeBarrier(); + LIQServiceVec::MergeSort(globalTopkUb_[innerS1Idx * BASE_TOPK_VALUE_IDX_SIZE], BASE_TOPK, sortBuff, + 
cuS2LenVecAlign, tmpSortBuf); + PipeBarrier(); + bool isS2End = cuBaseS2Idx + s2BaseSize_ >= cuRealAcSeq; + bool needCopyOutGm = blockS2StartIdx_ == 0 && isS2End; + // 中间结果保存 + bool needCopyWsGm = info.isAllLoopEnd || isS2End; + if (needCopyOutGm) { + LocalTensor idxULocal = outQueue_.AllocTensor(); + ExtractIndex(idxULocal, + globalTopkUb_[innerS1Idx * BASE_TOPK_VALUE_IDX_SIZE].template ReinterpretCast(), + BASE_TOPK); + PipeBarrier(); + InitSortOutBuf(globalTopkUb_[innerS1Idx * BASE_TOPK_VALUE_IDX_SIZE], BASE_TOPK_VALUE_IDX_SIZE); + outQueue_.EnQue(idxULocal); + idxULocal = outQueue_.DeQue(); + LIQServiceVec::CopyOut(indiceOutGm[info.indiceOutOffset + cuS1Idx * constInfo_.sparseCount], + idxULocal.template ReinterpretCast(), constInfo_.sparseCount); + outQueue_.FreeTensor(idxULocal); + } else if (needCopyWsGm) { + // vec1Res Gm = [aic, s1BaseSize_, 2, 2, topkOut_] float32 + // vec1Param Gm = [aic, s1BaseSize_, 2, 16] int64 + // 16 = [needFd, s2AcSeq, s2Start, s2End, isS2End, bn2idx, s1Idx, S1ProcNum, ......] 
+ + int64_t wsOffset = + (blockId_ / 2) * s1BaseSize_ * 2 * BASE_TOPK_VALUE_IDX_SIZE + // 2个AIV共同地址偏移 + (blockId_ % 2) * (s1BaseSize_ / 2) * 2 * BASE_TOPK_VALUE_IDX_SIZE + // 每个AIV的地址偏移,S1方向 + (ldS1Offset + innerS1Idx) * 2 * BASE_TOPK_VALUE_IDX_SIZE; + int64_t wsInfoOffset = + (blockId_ / 2) * s1BaseSize_ * 2 * paramNum_ + // 2个AIV共同地址偏移 + (blockId_ % 2) * (s1BaseSize_ / 2) * 2 * paramNum_ + // 每个AIV的地址偏移,S1方向 + (ldS1Offset + innerS1Idx) * 2 * paramNum_; + + LocalTensor tmpiBuff = paramBuf_.Get(); + SetWaitFlag(HardEvent::MTE3_S); + tmpiBuff.SetValue(0, static_cast(1)); + tmpiBuff.SetValue(1, static_cast(cuRealAcSeq)); + tmpiBuff.SetValue(2, static_cast(blockS2StartIdx_)); + tmpiBuff.SetValue(3, static_cast(cuBaseS2Idx + cuS2Len)); + tmpiBuff.SetValue(4, static_cast(isS2End)); + tmpiBuff.SetValue(5, static_cast(info.bN2Idx)); + tmpiBuff.SetValue(6, static_cast(cuS1Idx)); + tmpiBuff.SetValue(7, static_cast(cuS1ProcNum)); + tmpiBuff.SetValue(8, static_cast(info.indiceOutOffset + cuS1Idx * constInfo_.sparseCount)); + // 写入头尾判断 + // [head, tail] + // head: 与前面规约,与前后规约 + // tail: 与后面规约 + bool isTailReduce = blockS2StartIdx_ == 0; // 一定是isLastTile + // WS偏移规则 blockS2StartIdx_ != 0 + // 跟前面块做规约 写到0偏移 不用做计算 blockS2StartIdx_ == 0 and !isS2End + // 跟后面块做规约 写到1偏移 需要 + s1BaseSize_, BASE_TOPK*2 + if (isTailReduce) { // S2不是最后结束的数据就需要往后做规约,放入第二块ws + wsInfoOffset += paramNum_; + wsOffset += BASE_TOPK_VALUE_IDX_SIZE; + } + SetWaitFlag(HardEvent::S_MTE3); + LIQServiceVec::CopyOut(vec1ParamGm[wsInfoOffset], tmpiBuff, 16); + SetWaitFlag(HardEvent::V_MTE3); + LIQServiceVec::CopyOut(vec1ResGm[wsOffset], globalTopkUb_[innerS1Idx * BASE_TOPK_VALUE_IDX_SIZE], + BASE_TOPK_VALUE_IDX_SIZE); + SetWaitFlag(HardEvent::MTE3_V); + } + } else if (cuRealAcSeq <= 0) { + CleanInvalidOutput(info.indiceOutOffset + cuS1Idx * constInfo_.sparseCount); + } + } + + // BNSD场景无效S1 输出-1 + if (Q_LAYOUT_T == LI_LAYOUT::BSND) { + // 最后一个S1的基本块, 需要 >= info.actS1Size + bool isS1LoopEnd = (cuBaseS1Idx + s1BaseSize_) 
>= info.actS1Size; + int32_t invalidS1Num = constInfo_.qSeqSize - info.actS1Size; + // blockS2StartIdx_ == 0 控制S2从开始的核去做冗余清理 + if (invalidS1Num > 0 && isS1LoopEnd && blockS2StartIdx_ == 0) { + int32_t s1NumPerAiv = blockId_ % 2 == 0 ? CeilDiv(invalidS1Num, 2) : (invalidS1Num / 2); + int32_t s1OffsetPerAiv = info.actS1Size + (blockId_ % 2) * CeilDiv(invalidS1Num, 2); + for (int innerS1Idx = 0; innerS1Idx < s1NumPerAiv; innerS1Idx++) { + CleanInvalidOutput(info.indiceOutOffset + (s1OffsetPerAiv + innerS1Idx) * constInfo_.sparseCount); + } + } + + int32_t invalidS1Num2 = info.actS1Size - info.actS2Size; + if (invalidS1Num2 > 0 && isS1LoopEnd && blockS2StartIdx_ == 0 && constInfo_.attenMaskFlag) { + int32_t s1NumPerAiv = blockId_ % 2 == 0 ? CeilDiv(invalidS1Num2, 2) : (invalidS1Num2 / 2); + int32_t s1OffsetPerAiv = (blockId_ % 2) * CeilDiv(invalidS1Num2, 2); + for (int innerS1Idx = 0; innerS1Idx < s1NumPerAiv; innerS1Idx++) { + CleanInvalidOutput((info.bN2Idx * constInfo_.qSeqSize + s1OffsetPerAiv + innerS1Idx) * + constInfo_.sparseCount); + } + } + } + + if (info.isLastS2InnerLoop) { + // S2最后一个Loop后, 下一个基本块初始从0开始 + blockS2StartIdx_ = 0; + } +} + +template +__aicore__ inline void LIQVector::ProcessLD() +{ + int32_t curCubeId = blockId_ / 2; + int32_t tmpCubeId = curCubeId; + + int64_t s2ActSeq; + int64_t s2Start; + int64_t s2End; + int64_t isS2End; + int64_t bn2Idx; + int64_t s1Idx; + uint32_t acc_list_num = 0; + int64_t bIdx = 0; + int64_t needFd; + int64_t wsOffset; + int64_t wsInfoOffset = 0; + int64_t nextneedFd; + int64_t valueOffset = 0; + int64_t outOffset = 0; + + LocalTensor curValueIdxUb = ldToBeMrgBuf_.Get(); + LocalTensor tmpUb = ldTmpBuf_.Get(); + + // S2开头信息 + // 开始必然没有头规约,因此从尾规约开始处理,while循环读取下一个核的头规约 + // 存满4个list或者遇到S2结尾,则做merge,直到做完S2 + // 每个核都忽略自己的头规约,因为必然由前面的核做完 + uint32_t s1LdStartIdx = 0; + uint32_t s1ProcNum = 0; + uint64_t paramGmCoreOffset = tmpCubeId * s1BaseSize_ * 2 * paramNum_; + for (uint32_t innerS1Idx = 0; innerS1Idx < s1BaseSize_; 
innerS1Idx++) { + needFd = vec1ParamGm.GetValue(paramGmCoreOffset + innerS1Idx * 2 * paramNum_ + paramNum_); + if (needFd == 1) { + s1LdStartIdx = (s1ProcNum == 0) ? innerS1Idx : s1LdStartIdx; + s1ProcNum++; + } + } + + if (s1ProcNum == 0) { + return; + } + + // S1逐行计算 + uint32_t s1VecNum = CeilDiv(s1ProcNum, 2); + if (blockId_ % 2 == 1) { + s1LdStartIdx = s1LdStartIdx + s1VecNum; + s1VecNum = s1ProcNum - s1VecNum; + } + for (uint32_t innerS1Idx = s1LdStartIdx; innerS1Idx < s1LdStartIdx + s1VecNum; innerS1Idx++) { + // 重置偏移 + tmpCubeId = curCubeId; + acc_list_num = 0; + valueOffset = 0; + + // 搬入数据 + wsOffset = tmpCubeId * s1BaseSize_ * 2 * BASE_TOPK_VALUE_IDX_SIZE + // 2个AIV共同地址偏移 + innerS1Idx * 2 * BASE_TOPK_VALUE_IDX_SIZE + BASE_TOPK_VALUE_IDX_SIZE; + SetWaitFlag(HardEvent::V_MTE2); + SetWaitFlag(HardEvent::S_MTE2); + DataCopyPad(curValueIdxUb, vec1ResGm[wsOffset], + {1, static_cast(BASE_TOPK_VALUE_IDX_SIZE * sizeof(int32_t)), 0, 0}, {true, 0, 0, 0}); + acc_list_num++; + valueOffset += BASE_TOPK_VALUE_IDX_SIZE; + + // 获取下一个核规约信息 + tmpCubeId++; + wsInfoOffset = tmpCubeId * s1BaseSize_ * 2 * paramNum_ + innerS1Idx * 2 * paramNum_; + needFd = vec1ParamGm.GetValue(wsInfoOffset); + isS2End = vec1ParamGm.GetValue(wsInfoOffset + 4); + s1Idx = vec1ParamGm.GetValue(wsInfoOffset + 6); + outOffset = vec1ParamGm.GetValue(wsInfoOffset + 8); + + while (needFd == 1) { + // 搬入头规约数据 + wsOffset = tmpCubeId * s1BaseSize_ * 2 * BASE_TOPK_VALUE_IDX_SIZE + // 2个AIV共同地址偏移 + innerS1Idx * 2 * BASE_TOPK_VALUE_IDX_SIZE; + SetWaitFlag(HardEvent::V_MTE2); + SetWaitFlag(HardEvent::S_MTE2); + DataCopyPad(curValueIdxUb[valueOffset], vec1ResGm[wsOffset], + {1, static_cast(BASE_TOPK_VALUE_IDX_SIZE * sizeof(int32_t)), 0, 0}, {true, 0, 0, 0}); + valueOffset += BASE_TOPK_VALUE_IDX_SIZE; + acc_list_num++; + + // 每满4个list,聚合 前2K为mrg结果 + if (acc_list_num == mrgListNum_) { + // MrgSort 四条2048的队列,Mrg成一条 + AscendC::MrgSort4Info params; + params.elementLengths[0] = BASE_TOPK; + params.elementLengths[1] = 
BASE_TOPK; + params.elementLengths[2] = BASE_TOPK; + params.elementLengths[3] = BASE_TOPK; + params.ifExhaustedSuspension = true; + params.validBit = 0b1111; + params.repeatTimes = 1; + + AscendC::MrgSortSrcList srcList; + srcList.src1 = curValueIdxUb[0]; + srcList.src2 = curValueIdxUb[BASE_TOPK_VALUE_IDX_SIZE]; + srcList.src3 = curValueIdxUb[2 * BASE_TOPK_VALUE_IDX_SIZE]; + srcList.src4 = curValueIdxUb[3 * BASE_TOPK_VALUE_IDX_SIZE]; + SetWaitFlag(HardEvent::MTE2_V); + MrgSort(tmpUb, srcList, params); + PipeBarrier(); + DataCopy(curValueIdxUb, tmpUb, BASE_TOPK_VALUE_IDX_SIZE); + PipeBarrier(); + acc_list_num = 1; + valueOffset = BASE_TOPK_VALUE_IDX_SIZE; + } + + // reduce到S2末尾,则跳出 + if (isS2End == 1) { + break; + } + + tmpCubeId++; + wsInfoOffset = tmpCubeId * s1BaseSize_ * 2 * paramNum_ + innerS1Idx * 2 * paramNum_; + needFd = vec1ParamGm.GetValue(wsInfoOffset); + isS2End = vec1ParamGm.GetValue(wsInfoOffset + 4); + } + + // mrg不足4个list的数据 + if (acc_list_num != 1) { + AscendC::MrgSort4Info params; + params.elementLengths[0] = BASE_TOPK; + params.elementLengths[1] = BASE_TOPK; + params.elementLengths[2] = BASE_TOPK; + params.elementLengths[3] = BASE_TOPK; + params.ifExhaustedSuspension = true; + if (acc_list_num == 2) { + params.validBit = 0b0011; + } else if (acc_list_num == 3) { + params.validBit = 0b0111; + } + params.repeatTimes = 1; + + AscendC::MrgSortSrcList srcList; + srcList.src1 = curValueIdxUb[0]; + srcList.src2 = curValueIdxUb[BASE_TOPK_VALUE_IDX_SIZE]; + srcList.src3 = curValueIdxUb[2 * BASE_TOPK_VALUE_IDX_SIZE]; + srcList.src4 = curValueIdxUb[3 * BASE_TOPK_VALUE_IDX_SIZE]; + SetWaitFlag(HardEvent::MTE2_V); + MrgSort(tmpUb, srcList, params); + PipeBarrier(); + DataCopy(curValueIdxUb, tmpUb, BASE_TOPK_VALUE_IDX_SIZE); + PipeBarrier(); + } + + // 搬出 + LocalTensor outValueUb = ldOutValueBuf_.Get(); + LocalTensor outIdxUb = ldOutIdxBuf_.Get(); + Extract(outValueUb, outIdxUb, curValueIdxUb, (BASE_TOPK / 32)); + LocalTensor idxULocal1 = outIdxUb.template 
ReinterpretCast(); + SetWaitFlag(HardEvent::V_MTE3); + SetWaitFlag(HardEvent::S_MTE3); + DataCopyPad(indiceOutGm[outOffset], idxULocal1, + {1, static_cast(constInfo_.sparseCount * sizeof(int32_t)), 0, 0}); + SetWaitFlag(HardEvent::MTE3_V); + } +} +} // namespace LIQKernel +#endif \ No newline at end of file diff --git a/csrc/lightning_indexer_quant/op_kernel/lightning_indexer_quant_template_tiling_key.h b/csrc/lightning_indexer_quant/op_kernel/lightning_indexer_quant_template_tiling_key.h new file mode 100644 index 00000000..165e6215 --- /dev/null +++ b/csrc/lightning_indexer_quant/op_kernel/lightning_indexer_quant_template_tiling_key.h @@ -0,0 +1,53 @@ +/** + * This program is free software, you can redistribute it and/or modify it. + * Copyright (c) 2025 Huawei Technologies Co., Ltd. + * This file is a part of the CANN Open Software. + * Licensed under CANN Open Software License Agreement Version 2.0 (the "License"). + * Please refer to the License for details. You may not use this file except in compliance with the License. + * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. + * See LICENSE in the root of the software repository for the full text of the License. + */ + +/*! 
+ * \file lightning_indexer_quant_template_tiling_key.h + * \brief + */ + +#ifndef TEMPLATE_TILING_KEY_LI_H_ +#define TEMPLATE_TILING_KEY_LI_H_ + +#include "ascendc/host_api/tiling/template_argument.h" + +#define LI_TPL_FP16 1 +#define LI_TPL_IN8 2 +#define LI_TPL_INT32 3 +#define LI_TPL_BF16 27 + +#define LIQ_LAYOUT_BSND 0 +#define LIQ_LAYOUT_TND 1 +#define LIQ_LAYOUT_PA_BSND 2 + +#define ASCENDC_TPL_4_BW 4 + +// 模板参数支持的范围定义 +ASCENDC_TPL_ARGS_DECL(LightningIndexerQuant, // 算子OpType + ASCENDC_TPL_DTYPE_DECL(DT_Q, LI_TPL_IN8), ASCENDC_TPL_DTYPE_DECL(DT_K, LI_TPL_IN8), + ASCENDC_TPL_DTYPE_DECL(DT_OUT, LI_TPL_INT32), ASCENDC_TPL_BOOL_DECL(PAGE_ATTENTION, 1, 0), + ASCENDC_TPL_UINT_DECL(Q_LAYOUT_T, ASCENDC_TPL_4_BW, ASCENDC_TPL_UI_LIST, LIQ_LAYOUT_BSND, + LIQ_LAYOUT_TND), + ASCENDC_TPL_UINT_DECL(K_LAYOUT_T, ASCENDC_TPL_4_BW, ASCENDC_TPL_UI_LIST, + LIQ_LAYOUT_PA_BSND, LIQ_LAYOUT_BSND, LIQ_LAYOUT_TND), ); + +// 支持的模板参数组合 +// 用于调用GET_TPL_TILING_KEY获取TilingKey时,接口内部校验TilingKey是否合法 +ASCENDC_TPL_SEL( + ASCENDC_TPL_ARGS_SEL(ASCENDC_TPL_DTYPE_SEL(DT_Q, LI_TPL_IN8), ASCENDC_TPL_DTYPE_SEL(DT_K, LI_TPL_IN8), + ASCENDC_TPL_DTYPE_SEL(DT_OUT, LI_TPL_INT32), ASCENDC_TPL_BOOL_SEL(PAGE_ATTENTION, 1), + ASCENDC_TPL_UINT_SEL(Q_LAYOUT_T, ASCENDC_TPL_UI_LIST, LIQ_LAYOUT_BSND, LIQ_LAYOUT_TND), + ASCENDC_TPL_UINT_SEL(K_LAYOUT_T, ASCENDC_TPL_UI_LIST, LIQ_LAYOUT_PA_BSND), ), + ASCENDC_TPL_ARGS_SEL(ASCENDC_TPL_DTYPE_SEL(DT_Q, LI_TPL_IN8), ASCENDC_TPL_DTYPE_SEL(DT_K, LI_TPL_IN8), + ASCENDC_TPL_DTYPE_SEL(DT_OUT, LI_TPL_INT32), ASCENDC_TPL_BOOL_SEL(PAGE_ATTENTION, 0), + ASCENDC_TPL_UINT_SEL(Q_LAYOUT_T, ASCENDC_TPL_UI_LIST, LIQ_LAYOUT_BSND, LIQ_LAYOUT_TND), + ASCENDC_TPL_UINT_SEL(K_LAYOUT_T, ASCENDC_TPL_UI_LIST, LIQ_LAYOUT_BSND, LIQ_LAYOUT_TND), ), ); + +#endif \ No newline at end of file diff --git a/csrc/lightning_indexer_quant/op_kernel/lightning_indexer_quant_vector.h b/csrc/lightning_indexer_quant/op_kernel/lightning_indexer_quant_vector.h new file mode 100644 index 00000000..d6a2e277 --- 
/dev/null +++ b/csrc/lightning_indexer_quant/op_kernel/lightning_indexer_quant_vector.h @@ -0,0 +1,193 @@ +/** + * This program is free software, you can redistribute it and/or modify it. + * Copyright (c) 2025 Huawei Technologies Co., Ltd. + * This file is a part of the CANN Open Software. + * Licensed under CANN Open Software License Agreement Version 2.0 (the "License"). + * Please refer to the License for details. You may not use this file except in compliance with the License. + * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. + * See LICENSE in the root of the software repository for the full text of the License. + */ + +/*! + * \file lightning_indexer_quant_vector.h + * \brief + */ +#ifndef LIGHTNING_INDEXER_QUANT_VECTOR_H +#define LIGHTNING_INDEXER_QUANT_VECTOR_H + +#include "kernel_operator.h" +#include "lightning_indexer_quant_vector.h" + +namespace LIQServiceVec { +using namespace AscendC; + +constexpr int32_t NEG_INF = 0xFF800000; +constexpr int32_t INVALID_INDEX = -1; +constexpr uint8_t VEC_REPEAT_MAX = 255; +constexpr uint8_t B32_VEC_ELM_NUM = 64; +constexpr uint8_t B32_BLOCK_ALIGN_NUM = 8; +constexpr uint8_t B32_VEC_REPEAT_STRIDE = 8; +constexpr uint64_t VEC_REPEAT_BYTES = 256; +constexpr int32_t CONST_TWO = 2; +constexpr int64_t VALUE_AND_INDEX_NUM = 2; +constexpr int64_t BLOCK_BYTES = 32; +constexpr int64_t MRG_QUE_0 = 0; +constexpr int64_t MRG_QUE_1 = 1; +constexpr int64_t MRG_QUE_2 = 2; +constexpr int64_t MRG_QUE_3 = 3; +constexpr int64_t MRG_BLOCK_2 = 2; +constexpr int64_t MRG_BLOCK_3 = 3; +constexpr int64_t MRG_BLOCK_4 = 4; + +template +__aicore__ inline void CopyOut(const GlobalTensor &dstGm, const LocalTensor &srcUb, int64_t copyCount) +{ + AscendC::DataCopyParams dataCopyOutyParams; + dataCopyOutyParams.blockCount = 1; + dataCopyOutyParams.blockLen = copyCount * sizeof(T); + 
dataCopyOutyParams.srcStride = 0; + dataCopyOutyParams.dstStride = 0; + AscendC::DataCopyPad(dstGm, srcUb, dataCopyOutyParams); +} + +/** + src: 传入的初始化空间 + eleNum: 需要初始化的元素个数需为64整数倍,元素将被初始化为交错排布的-inf,-1 + */ +__aicore__ inline void InitSortOutBuf(const LocalTensor &src, int64_t eleNum) +{ + uint64_t mask1[2] = {0x5555555555555555, 0}; + uint64_t mask0[2] = {0xaaaaaaaaaaaaaaaa, 0}; + int64_t repeatNum = eleNum / B32_VEC_ELM_NUM; + int64_t forLoop = repeatNum / VEC_REPEAT_MAX; + int64_t forRemain = repeatNum % VEC_REPEAT_MAX; + for (int i = 0; i < forLoop; i++) { + AscendC::Duplicate(src.template ReinterpretCast(), NEG_INF, mask1, VEC_REPEAT_MAX, 1, + B32_VEC_REPEAT_STRIDE); + AscendC::Duplicate(src.template ReinterpretCast(), INVALID_INDEX, mask0, VEC_REPEAT_MAX, 1, + B32_VEC_REPEAT_STRIDE); + } + if (forRemain > 0) { + AscendC::Duplicate(src.template ReinterpretCast()[forLoop * VEC_REPEAT_MAX * B32_VEC_ELM_NUM], NEG_INF, + mask1, forRemain, 1, B32_VEC_REPEAT_STRIDE); + AscendC::Duplicate(src.template ReinterpretCast()[forLoop * VEC_REPEAT_MAX * B32_VEC_ELM_NUM], + INVALID_INDEX, mask0, forRemain, 1, B32_VEC_REPEAT_STRIDE); + } + AscendC::PipeBarrier(); +} + +/** + src: logits和索引,前logitsNum为logits,后logitsNum为索引 + tmp: 计算使用到的临时空间,大小与src一致 + logitsNum: 排序的元素个数, 暂只支持[128,256,384,512,1024,2048] + */ +__aicore__ inline void SortAll(LocalTensor &src, LocalTensor &tmp, int64_t logitsNum) +{ + int64_t sort32Repeats = logitsNum / BLOCK_BYTES; + AscendC::Sort32(tmp, src, src[logitsNum].ReinterpretCast(), sort32Repeats); + AscendC::PipeBarrier(); + + int64_t mrgGroups = sort32Repeats; + int64_t mrgElements = BLOCK_BYTES; + int64_t i = 0; + AscendC::LocalTensor srcTensor; + AscendC::LocalTensor dstTensor; + while (true) { + if (i % CONST_TWO == 0) { + srcTensor = tmp; + dstTensor = src; + } else { + srcTensor = src; + dstTensor = tmp; + } + AscendC::MrgSort4Info params; + params.elementLengths[0] = mrgElements; + params.elementLengths[MRG_QUE_1] = mrgElements; + 
params.elementLengths[MRG_QUE_2] = mrgElements; + params.elementLengths[MRG_QUE_3] = mrgElements; + params.ifExhaustedSuspension = false; + params.validBit = 0b1111; + + AscendC::MrgSortSrcList srcList; + srcList.src1 = srcTensor[0]; + srcList.src2 = srcTensor[MRG_QUE_1 * VALUE_AND_INDEX_NUM * mrgElements]; + srcList.src3 = srcTensor[MRG_QUE_2 * VALUE_AND_INDEX_NUM * mrgElements]; + srcList.src4 = srcTensor[MRG_QUE_3 * VALUE_AND_INDEX_NUM * mrgElements]; + if (mrgGroups <= MRG_BLOCK_4) { + params.repeatTimes = 1; + if (mrgGroups == 1) { + break; + } else if (mrgGroups == MRG_BLOCK_2) { + params.validBit = 0b0011; + } else if (mrgGroups == MRG_BLOCK_3) { + params.validBit = 0b0111; + } else if (mrgGroups == MRG_BLOCK_4) { + params.validBit = 0b1111; + } + AscendC::MrgSort(dstTensor, srcList, params); + i += 1; + break; + } else { + params.repeatTimes = mrgGroups / MRG_BLOCK_4; + AscendC::MrgSort(dstTensor, srcList, params); + i += 1; + mrgElements = mrgElements * MRG_BLOCK_4; + mrgGroups = mrgGroups / MRG_BLOCK_4; + } + AscendC::PipeBarrier(); + } + if (i % CONST_TWO == 0) { + AscendC::DataCopy(src, tmp, logitsNum * VALUE_AND_INDEX_NUM); + AscendC::PipeBarrier(); + } +} + +/** + mrgDst: 合并进的Tensor + mrgSrc: 待合并的Tensor + tmpTensor:空间为mrgDst+mrgSrc + */ +__aicore__ inline void MergeSort(const LocalTensor &mrgDst, int32_t mrgDstNum, LocalTensor &mrgSrc, + int32_t mrgSrcNum, LocalTensor &tmpTensor) +{ + AscendC::MrgSort4Info params; + params.elementLengths[0] = mrgSrcNum; + params.elementLengths[1] = mrgDstNum; + params.ifExhaustedSuspension = false; + params.validBit = 0b0011; + params.repeatTimes = 1; + + AscendC::MrgSortSrcList srcList; + srcList.src1 = mrgSrc; + srcList.src2 = mrgDst; + + AscendC::MrgSort(tmpTensor, srcList, params); + AscendC::PipeBarrier(); + AscendC::DataCopy(mrgDst, tmpTensor, mrgDstNum * VALUE_AND_INDEX_NUM); + AscendC::PipeBarrier(); +} + +__aicore__ inline void ExtractIndex(const LocalTensor &idxULocal, const LocalTensor &sortLocal, + int64_t 
extractNum) +{ + AscendC::GatherMaskParams gatherMaskParams; + gatherMaskParams.repeatTimes = Ceil(extractNum * sizeof(float) * VALUE_AND_INDEX_NUM, VEC_REPEAT_BYTES); + gatherMaskParams.src0BlockStride = 1; + gatherMaskParams.src0RepeatStride = B32_VEC_REPEAT_STRIDE; + gatherMaskParams.src1RepeatStride = 0; + uint64_t rsvdCnt = 0; // 用于保存筛选后保留下来的元素个数 + uint8_t src1Pattern = 2; // 固定模式2,表示筛选出奇数索引的数 + AscendC::GatherMask(idxULocal, sortLocal, src1Pattern, false, static_cast(0), gatherMaskParams, rsvdCnt); + AscendC::PipeBarrier(); +} + +template +__aicore__ inline void SetWaitFlag(HardEvent evt) +{ + event_t eventId = static_cast(GetTPipePtr()->FetchEventID(evt)); + AscendC::SetFlag(eventId); + AscendC::WaitFlag(eventId); +} + +} // namespace LIQServiceVec +#endif // LIGHTNING_INDEXER_QUANT_VECTOR_H \ No newline at end of file diff --git a/csrc/torch_binding.cpp b/csrc/torch_binding.cpp index bc2bb72c..b22cb7b0 100644 --- a/csrc/torch_binding.cpp +++ b/csrc/torch_binding.cpp @@ -42,6 +42,7 @@ #include "moe_gating_top_k/moe_gating_top_k_torch_adpt.h" #include "moe_init_routing_custom/moe_init_routing_custom_torch_adpt.h" #include "sparse_flash_attention/sparse_flash_attention_torch_adpt.h" +#include "lightning_indexer_quant/lightning_indexer_quant_torch_adpt.h" #include #include #include @@ -918,4 +919,16 @@ TORCH_LIBRARY_EXPAND(CONCAT(_C, _ascend), ops) "-> Tensor[]" ); ops.impl("moe_grouped_matmul", torch::kPrivateUse1,&vllm_ascend::moe_grouped_matmul); + + // This operator is planned to be integrated into PTA in the near future. + // Once that happens, the implementation in csrc will be removed. + ops.def( + "npu_lightning_indexer_quant(Tensor query, Tensor key, Tensor weights, Tensor query_dequant_scale, " + " Tensor key_dequant_scale, *, Tensor? actual_seq_lengths_query=None, " + " Tensor? actual_seq_lengths_key=None, Tensor? 
block_table=None, " + " int query_quant_mode=0, int key_quant_mode=0, " + " str layout_query='BSND', str layout_key='BSND'," + " int sparse_count=2048, int sparse_mode=3) -> Tensor" + ); + ops.impl("npu_lightning_indexer_quant", torch::kPrivateUse1, &vllm_ascend::npu_lightning_indexer_quant); } diff --git a/csrc/torch_binding_meta.cpp b/csrc/torch_binding_meta.cpp index f5980a01..1f62ce2c 100644 --- a/csrc/torch_binding_meta.cpp +++ b/csrc/torch_binding_meta.cpp @@ -529,6 +529,44 @@ std::vector moe_grouped_matmul_meta( return y; } +at::Tensor npu_lightning_indexer_quant_meta( + const at::Tensor &query, const at::Tensor &key, const at::Tensor &weights, + const at::Tensor &query_dequant_scale, const at::Tensor &key_dequant_scale, + const c10::optional &actual_seq_lengths_query, + const c10::optional &actual_seq_lengths_key, + const c10::optional &block_table, int64_t query_quant_mode, int64_t key_quant_mode, + c10::string_view layout_query, c10::string_view layout_key, int64_t sparse_count, int64_t sparse_mode) +{ + std::string query_layout_str = std::string(layout_query); + std::string key_layout_str = std::string(layout_key); + + const int SIZE = 8; + const int DIM_0 = 0; + const int DIM_1 = 1; + const int DIM_2 = 2; + const int DIM_3 = 3; + + at::SmallVector output_size; + for (size_t i = 0; i < query.sizes().size(); i++) { + TORCH_CHECK(query.size(i) > 0, "All values within query's shape should be greater " + "than 0, but shape[", i, "] is ", query.size(i)); + } + for (size_t i = 0; i < key.sizes().size(); i++) { + TORCH_CHECK(key.size(i) > 0, "All values within key's shape should be greater " + "than 0, but shape[", i, "] is ", key.size(i)); + } + TORCH_CHECK(sparse_count > 0, "sparse count should be greater than 0, but now is ", sparse_count); + int64_t keyHeadNum = (key_layout_str == "TND")? 
key.size(DIM_1) : key.size(DIM_2); + if (query_layout_str == "BSND") { + output_size = {query.size(DIM_0), query.size(DIM_1), keyHeadNum, sparse_count}; + } else { + output_size = {query.size(DIM_0), keyHeadNum, sparse_count}; + } + at::Tensor lightning_indexer_quant_output = at::empty(output_size, query.options().dtype(at::kInt)); + + return lightning_indexer_quant_output; +} + } // namespace meta } // namespace vllm_ascend @@ -576,5 +614,7 @@ TORCH_LIBRARY_IMPL_EXPAND(CONCAT(_C, _ascend), Meta, ops) { ops.impl("causal_conv1d_fn", &vllm_ascend::meta::causal_conv1d_fn_meta); // moe_grouped_matmul ops.impl("moe_grouped_matmul", &vllm_ascend::meta::moe_grouped_matmul_meta); + // Lightning indexer quant + ops.impl("npu_lightning_indexer_quant", &vllm_ascend::meta::npu_lightning_indexer_quant_meta); } } diff --git a/tests/e2e/multicard/2-cards/test_offline_inference_distributed.py b/tests/e2e/multicard/2-cards/test_offline_inference_distributed.py index fae23fa1..de971834 100644 --- a/tests/e2e/multicard/2-cards/test_offline_inference_distributed.py +++ b/tests/e2e/multicard/2-cards/test_offline_inference_distributed.py @@ -266,6 +266,33 @@ def test_deepseek3_2_w8a8_pruning_mtp_tp2_ep(): vllm_model.generate_greedy(long_example_prompts, max_tokens) +@patch.dict(os.environ, {"HCCL_OP_EXPANSION_MODE": "AIV"}) +@patch.dict(os.environ, {"VLLM_ASCEND_ENABLE_FLASHCOMM1": "1"}) +@patch.dict(os.environ, {"ASCEND_AGGREGATE_ENABLE": "1"}) +@patch.dict(os.environ, {"HCCL_BUFFSIZE": "1024"}) +def test_deepseek3_2_w8a8c8_pruning_mtp_tp2_ep(): + short_example_prompts = [ + "Hello ", + ] + # "max_position_embeddings": 163840, + long_example_prompts = ["Hello " * (163839 - 500) + "Hello"] + max_tokens = 500 + with VllmRunner( + "vllm-ascend/DeepSeek-V3.2-W8A8-Pruning", + tensor_parallel_size=2, + quantization="ascend", + enable_expert_parallel=True, + max_model_len=163840, + compilation_config={"cudagraph_capture_sizes": [2, 4, 6, 8, 10, 12], "cudagraph_mode": "FULL_DECODE_ONLY"}, + 
speculative_config={"num_speculative_tokens": 1, "method": "deepseek_mtp"}, + additional_config={"layer_sharding": ["q_b_proj", "o_proj"], "enable_sparse_c8": True}, + reasoning_parser="deepseek_v3", + tokenizer_mode="deepseek_v32", + ) as vllm_model: + vllm_model.generate_greedy(short_example_prompts, max_tokens) + vllm_model.generate_greedy(long_example_prompts, max_tokens) + + @pytest.mark.parametrize("model", QWEN_W4A4_MODELS) def test_qwen3_w4a4_distributed_tp2(model): example_prompts = [ diff --git a/vllm_ascend/ascend_config.py b/vllm_ascend/ascend_config.py index 27095936..8dd63427 100644 --- a/vllm_ascend/ascend_config.py +++ b/vllm_ascend/ascend_config.py @@ -134,9 +134,12 @@ class AscendConfig: bool(additional_config.get("enable_async_exponential", False)) and not vllm_is_batch_invariant() ) + use_sparse = hasattr(vllm_config.model_config, "hf_text_config") and hasattr( + vllm_config.model_config.hf_text_config, "index_topk" + ) + self.enable_kv_nz = additional_config.get("enable_kv_nz", False) if self.enable_kv_nz: - use_sparse = hasattr(vllm_config.model_config.hf_text_config, "index_topk") if not vllm_config.model_config.is_deepseek_mla or use_sparse: raise RuntimeError("enable_kv_nz is only supported for mla currently.") if vllm_config.kv_transfer_config is None or not vllm_config.kv_transfer_config.is_kv_consumer: @@ -144,6 +147,17 @@ class AscendConfig: "enable_kv_nz is only supported in pd scenario and can only be used in D node." ) + from vllm_ascend.utils import AscendDeviceType, get_ascend_device_type + + # Disable Sparse C8 for A5 + # A5 has not been fully validated for this path and may carry hidden risks. + # TODO(rjg-lyh): Enable A5 support after sufficient validation. 
+ self.enable_sparse_c8 = ( + additional_config.get("enable_sparse_c8", False) + and use_sparse + and get_ascend_device_type() != AscendDeviceType.A5 + ) + def _construct_weight_prefetch_config(self, additional_config): weight_prefetch_config = additional_config.get("weight_prefetch_config", {}) self.weight_prefetch_config = WeightPrefetchConfig(weight_prefetch_config) diff --git a/vllm_ascend/attention/sfa_v1.py b/vllm_ascend/attention/sfa_v1.py index f7edb5fb..70bfcb33 100644 --- a/vllm_ascend/attention/sfa_v1.py +++ b/vllm_ascend/attention/sfa_v1.py @@ -1,6 +1,7 @@ from dataclasses import dataclass from typing import TYPE_CHECKING, TypeVar +import scipy # type: ignore import torch import torch_npu import vllm.envs as envs_vllm @@ -355,6 +356,9 @@ class AscendSFAImpl(MLAAttentionImpl): # Supports forward using the all-gather o_proj weight for decode requests when Sharded CP is enabled. o_proj_full_pool: torch.Tensor | None = None + # qk_hadamard tensor shared when dsa c8 enabled + qk_hadamard: torch.Tensor | None = None + def __init__( self, num_heads: int, @@ -425,6 +429,12 @@ class AscendSFAImpl(MLAAttentionImpl): self.is_rope_neox_style = False self.use_torch_npu_lightning_indexer = True + # dsa c8 + self.use_sparse_c8_indexer = ascend_config.enable_sparse_c8 + if self.use_sparse_c8_indexer: + self.c8_k_cache_dtype = torch.int8 + self.c8_k_scale_cache_dtype = torch.float16 + # Effective in SFA when FlashComm is enabled. 
self.enable_dsa_cp = enable_dsa_cp() @@ -515,6 +525,11 @@ class AscendSFAImpl(MLAAttentionImpl): # if mlapo, W_UK_T can't trans nz self.W_UK_T = maybe_trans_nz(self.W_UK_T) + if self.use_sparse_c8_indexer and AscendSFAImpl.qk_hadamard is None: + AscendSFAImpl.qk_hadamard = torch.tensor(scipy.linalg.hadamard(128), dtype=torch.bfloat16, device="npu") / ( + 128**0.5 + ) + # Processing the input parameters for MLAPO by reordering and transposing # QKV(and part of Q) weight, applying RoPE-related dimension transformations, # and handling quantization parameters. @@ -874,7 +889,15 @@ class AscendSFAImpl(MLAAttentionImpl): k_li = torch.cat([k_li_pe, k_li_nope], dim=-1) # [b*s,128] - return k_li + if self.use_sparse_c8_indexer: + k_li = k_li @ AscendSFAImpl.qk_hadamard + k_li, k_li_scale = torch_npu.npu_dynamic_quant(k_li.view(-1, self.head_dim), dst_type=self.c8_k_cache_dtype) + k_li_scale = k_li_scale.to(self.c8_k_scale_cache_dtype) # [b*s,] + k_li_scale = k_li_scale.unsqueeze(-1) # [b*s,1] + else: + k_li_scale = None + + return k_li, k_li_scale def indexer_select_post_process( self, @@ -905,10 +928,35 @@ class AscendSFAImpl(MLAAttentionImpl): q_li_pe = q_li_pe.squeeze(2) q_li = torch.cat([q_li_pe, q_li_nope], dim=-1) # [b*s,64,128] + if self.use_sparse_c8_indexer: + q_li_shape_ori = q_li.shape + q_li = q_li @ AscendSFAImpl.qk_hadamard + q_li, q_li_scale = torch_npu.npu_dynamic_quant(q_li.view(-1, self.head_dim), dst_type=self.c8_k_cache_dtype) + q_li_scale = q_li_scale.to(self.c8_k_scale_cache_dtype) + # DSV3.2 currently has graph compilation issues when using torch_npu.npu.lightning_indexer. # So two branches are maintained temporarily. # TODO: torch.ops._C_ascend.npu_lightning_indexer needs to be removed. 
- if self.use_torch_npu_lightning_indexer: + if self.use_sparse_c8_indexer: + assert len(kv_cache) == 4 + weights = weights.to(torch.float16) + topk_indices = torch.ops._C_ascend.npu_lightning_indexer_quant( + query=q_li.view(q_li_shape_ori), + key=kv_cache[2], + weights=weights, + query_dequant_scale=q_li_scale.view(q_li_shape_ori[:-1]), + key_dequant_scale=kv_cache[3].squeeze(2), # B S N D -> B S D + actual_seq_lengths_query=actual_seq_lengths_query, + actual_seq_lengths_key=actual_seq_lengths_key, + block_table=attn_metadata.block_table, + query_quant_mode=0, + key_quant_mode=0, + layout_query="TND", + layout_key="PA_BSND", + sparse_count=2048, + sparse_mode=3, + ) + elif self.use_torch_npu_lightning_indexer: topk_indices, _ = torch_npu.npu_lightning_indexer( query=q_li, key=kv_cache[2], @@ -1031,7 +1079,7 @@ class AscendSFAImpl(MLAAttentionImpl): assert self.q_a_layernorm is not None, "q_a_layernorm must be initialized" q_c = self.q_a_layernorm(q_c) - k_li = self.indexer_select_pre_process(x=hidden_states, cos=cos, sin=sin) + k_li, k_li_scale = self.indexer_select_pre_process(x=hidden_states, cos=cos, sin=sin) wait_for_kv_layer_from_connector(layer_name) @@ -1044,20 +1092,46 @@ class AscendSFAImpl(MLAAttentionImpl): if self.enable_dsa_cp: assert k_pe is not None assert k_nope is not None + assert k_li is not None async_op = self.enable_dsa_cp_with_layer_shard or full_gather_o_proj_enabled # support all_gather kv async for communication calculation overlap - fused_kv_no_split, kv_ag_handle = all_gather_async( - torch.cat( - [ - k_pe.view(-1, k_pe.shape[-1]), - k_nope.view(-1, k_nope.shape[-1]), - k_li.view(-1, k_li.shape[-1]), - ], - dim=1, - ), - get_tp_group(), - async_op=async_op, - ) + if not self.use_sparse_c8_indexer: + fused_kv_no_split, kv_ag_handle = all_gather_async( + torch.cat( + [ + k_pe.view(-1, k_pe.shape[-1]), + k_nope.view(-1, k_nope.shape[-1]), + k_li.view(-1, k_li.shape[-1]), + ], + dim=1, + ), + get_tp_group(), + async_op=async_op, + ) + 
else: + # due to different dtypes, we have to split commu pass + assert k_li_scale is not None + fused_kv_no_split, _ = all_gather_async( + torch.cat( + [ + k_pe.view(-1, k_pe.shape[-1]), + k_nope.view(-1, k_nope.shape[-1]), + ], + dim=1, + ), + get_tp_group(), + async_op=async_op, + ) + k_li, _ = all_gather_async( + k_li, + get_tp_group(), + async_op=async_op, + ) + k_li_scale, kv_ag_handle = all_gather_async( + k_li_scale, + get_tp_group(), + async_op=async_op, + ) ql_nope, q_pe = self._q_proj_and_k_up_proj(q_c) q_pe = self.rope_single(q_pe, cos, sin) @@ -1077,9 +1151,12 @@ class AscendSFAImpl(MLAAttentionImpl): if kv_cache is not None: assert fused_kv_no_split is not None - k_pe, k_nope, k_li = fused_kv_no_split.split( - [self.qk_rope_head_dim, self.kv_lora_rank, self.head_dim], dim=-1 - ) + if not self.use_sparse_c8_indexer: + k_pe, k_nope, k_li = fused_kv_no_split.split( + [self.qk_rope_head_dim, self.kv_lora_rank, self.head_dim], dim=-1 + ) + else: + k_pe, k_nope = fused_kv_no_split.split([self.qk_rope_head_dim, self.kv_lora_rank], dim=-1) k_nope = k_nope.view(k_nope.shape[0], 1, -1) k_pe = k_pe.view(k_pe.shape[0], 1, -1) DeviceOperator.reshape_and_cache( @@ -1098,6 +1175,13 @@ class AscendSFAImpl(MLAAttentionImpl): torch_npu.npu_scatter_nd_update_( kv_cache[2].view(-1, k_li.shape[-1]), slot_mapping.view(-1, 1), k_li.view(-1, k_li.shape[-1]) ) # b, s, n, d + if self.use_sparse_c8_indexer: + assert len(kv_cache) == 4 + torch_npu.npu_scatter_nd_update_( + kv_cache[3].view(-1, k_li_scale.shape[-1]), + slot_mapping.view(-1, 1), + k_li_scale.view(-1, k_li_scale.shape[-1]), + ) if self.is_kv_producer: attn_metadata.reshape_cache_event.record() diff --git a/vllm_ascend/patch/__init__.py b/vllm_ascend/patch/__init__.py index 4a1b0a16..4eb74013 100644 --- a/vllm_ascend/patch/__init__.py +++ b/vllm_ascend/patch/__init__.py @@ -137,6 +137,28 @@ # Remove this patch if upstream provides an official NPU graph-capture # guidance / auto-configuration path for HCCL. 
# +# ** 8. File: platform/patch_kv_cache_interface.py** +# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ +# 1. `vllm.v1.kv_cache_interface.MLAAttentionSpec` +# Why: +# The default `MLAAttentionSpec` is mainly built around `kv_lora_rank` +# and `qk_rope_head_dim`. On NPU, we also use this class to describe DSA +# models. Unlike the GPU path, where cache management is handled by an +# additional indexer module, extending this class directly simplifies the +# corresponding `model_runner` implementation on NPU. +# +# This patch also adds Sparse C8 support for DSA models on NPU. As part +# of that support, members such as `page_size_bytes` need to be adapted, +# so they are overridden here as well to preserve overall readability. +# How: +# This patch subclasses the original implementation, overrides selected +# methods, and adds DSA-specific attributes and helpers with default +# values where needed. +# Related PR (if no, explain why): +# https://github.com/vllm-project/vllm/pull/25896 +# Future Plan: +# Remove this patch after the upcoming KV cache spec refactor. 
+# # * Worker Patch: # =============== # diff --git a/vllm_ascend/patch/platform/__init__.py b/vllm_ascend/patch/platform/__init__.py index 4b8fc9d2..ba7a8f3d 100644 --- a/vllm_ascend/patch/platform/__init__.py +++ b/vllm_ascend/patch/platform/__init__.py @@ -18,6 +18,7 @@ import os import vllm_ascend.patch.platform.patch_distributed # noqa import vllm_ascend.patch.platform.patch_fusion_matcher_compat_ops # noqa +import vllm_ascend.patch.platform.patch_kv_cache_interface # noqa import vllm_ascend.patch.platform.patch_mamba_config # noqa import vllm_ascend.patch.platform.patch_minimax_m2_config # noqa import vllm_ascend.patch.platform.patch_sched_yield # noqa diff --git a/vllm_ascend/patch/platform/patch_kv_cache_interface.py b/vllm_ascend/patch/platform/patch_kv_cache_interface.py new file mode 100644 index 00000000..3719a3c5 --- /dev/null +++ b/vllm_ascend/patch/platform/patch_kv_cache_interface.py @@ -0,0 +1,138 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project + +from dataclasses import dataclass + +import torch +import vllm.v1.kv_cache_interface +from typing_extensions import Self +from vllm.utils.torch_utils import get_dtype_size +from vllm.v1.kv_cache_interface import MLAAttentionSpec + + +@dataclass(frozen=True) +class AscendMLAAttentionSpec(MLAAttentionSpec): + """MLAAttentionSpec extended to support DSA models, with optional Sparse C8 support. + + When Sparse C8 is enabled, the KV cache tuple changes from + (kv_cache[0]: bfloat16, kv_cache[1]: bfloat16, kv_cache[2]: bfloat16) + to + (kv_cache[0]: bfloat16, kv_cache[1]: bfloat16, kv_cache[2]: int8, kv_cache[3]: float16). + + The semantic meaning of each KV cache entry is as follows: + 1. kv_cache[0] stores kv_lora. + 2. kv_cache[1] stores k_rope. + 3. kv_cache[2] stores the key tensor from the indexer module. + 4. kv_cache[3] stores the key scale tensor from the indexer module, + and exists only when Sparse C8 is enabled. 
+ + The main changes are as follows: + 1. The key tensor from the indexer module stored in kv_cache[2] is + converted from bf16 to int8 to reduce memory usage. It is then + processed with int8 precision in Lightning_indexer computation + to improve computational efficiency. + 2. The quantization scale of the key tensor in the indexer module + must also be stored for the Lightning_indexer_quant operator, + and is therefore saved in kv_cache[3]. + """ + + sparse_head_dim: tuple[int, ...] | None = None + cache_sparse_c8: bool = False + c8_k_cache_dtype: torch.dtype = torch.int8 + c8_k_scale_cache_dtype: torch.dtype = torch.float16 + + @property + def page_size_bytes(self) -> int: + if self.cache_sparse_c8: + assert self.sparse_head_dim is not None + assert len(self.sparse_head_dim) == 3 + num_heads_per_page = self.block_size * self.num_kv_heads + # kv_cache[0]: bfloat16, kv_cache[1]: bfloat16 + kv_lora_rank, qk_rope_head_dim = self.sparse_head_dim[:2] + k_pe_nope_bytes = num_heads_per_page * (kv_lora_rank + qk_rope_head_dim) * get_dtype_size(self.dtype) + # kv_cache[2]: int8 + index_head_dim = self.sparse_head_dim[-1] + indexer_k_bytes = num_heads_per_page * index_head_dim * get_dtype_size(self.c8_k_cache_dtype) + # kv_cache[3]: float16 + # since the scale is stored per token, head_dim is set to 1. + index_scale_head_dim = 1 + indexer_k_scale_bytes = ( + num_heads_per_page * index_scale_head_dim * get_dtype_size(self.c8_k_scale_cache_dtype) + ) + return k_pe_nope_bytes + indexer_k_bytes + indexer_k_scale_bytes + + return self.block_size * self.num_kv_heads * self.head_size * get_dtype_size(self.dtype) + + @property + def sparse_kv_cache_ratio(self) -> tuple[float, float, float, float | None]: + """ + Compute the relative byte share of each KV cache entry. 
+ + Returns: + A tuple containing the ratios for: + - kv_cache[0] + - kv_cache[1] + - kv_cache[2] + - kv_cache[3] (None if Sparse C8 is disabled) + """ + + assert self.sparse_head_dim is not None + + def get_sparse_head_dim_virtual() -> tuple[int, int, int, int]: + assert self.sparse_head_dim is not None + assert self.cache_sparse_c8 is True + + kv_lora_rank, qk_rope_head_dim, index_k_head_dim = self.sparse_head_dim + + factor = get_dtype_size(self.dtype) // get_dtype_size(self.c8_k_cache_dtype) + index_k_head_dim_virtual = index_k_head_dim // factor + + assert get_dtype_size(self.dtype) == get_dtype_size(self.c8_k_scale_cache_dtype) + index_k_scale_head_dim_virtual = 1 + + return ( + kv_lora_rank, + qk_rope_head_dim, + index_k_head_dim_virtual, + index_k_scale_head_dim_virtual, + ) + + if self.cache_sparse_c8: + virtual_dims = get_sparse_head_dim_virtual() + total_virtual_head_dim = sum(virtual_dims) + + return ( + total_virtual_head_dim / virtual_dims[0], # kv_cache[0] + total_virtual_head_dim / virtual_dims[1], # kv_cache[1] + total_virtual_head_dim / virtual_dims[2], # kv_cache[2] + total_virtual_head_dim / virtual_dims[3], # kv_cache[3] + ) + + return ( + self.head_size / self.sparse_head_dim[0], # kv_cache[0] + self.head_size / self.sparse_head_dim[1], # kv_cache[1] + self.head_size / self.sparse_head_dim[2], # kv_cache[2] + None, # kv_cache[3] does not exist + ) + + @classmethod + def merge(cls, specs: list[Self]) -> Self: + assert all(isinstance(spec, MLAAttentionSpec) for spec in specs), ( + "All attention layers in the same KV cache group must be MLAAttentionSpec." + ) + cache_dtype_str_set = set(spec.cache_dtype_str for spec in specs) + assert len(cache_dtype_str_set) == 1, ( + "All attention layers in the same KV cache group must use the same quantization method." 
+ ) + return cls( + block_size=specs[0].block_size, + num_kv_heads=specs[0].num_kv_heads, + head_size=specs[0].head_size, + sparse_head_dim=specs[0].sparse_head_dim, + dtype=specs[0].dtype, + cache_dtype_str=cache_dtype_str_set.pop(), + cache_sparse_c8=specs[0].cache_sparse_c8, + ) + + +vllm.v1.kv_cache_interface.MLAAttentionSpec = AscendMLAAttentionSpec diff --git a/vllm_ascend/worker/model_runner_v1.py b/vllm_ascend/worker/model_runner_v1.py index 7631e2a9..139a99aa 100644 --- a/vllm_ascend/worker/model_runner_v1.py +++ b/vllm_ascend/worker/model_runner_v1.py @@ -84,6 +84,7 @@ from vllm.v1.worker.ubatch_utils import ( ) from vllm.v1.worker.utils import AttentionGroup +# yapf: enable from vllm_ascend.ascend_config import get_ascend_config from vllm_ascend.attention.attention_v1 import AscendAttentionState from vllm_ascend.attention.utils import AscendCommonAttentionMetadata, using_paged_attention @@ -96,8 +97,6 @@ from vllm_ascend.compilation.acl_graph import ( set_graph_params, update_full_graph_params, ) - -# yapf: enable from vllm_ascend.eplb.adaptor.vllm_adaptor import VllmEplbAdaptor from vllm_ascend.eplb.core.eplb_device_transfer_loader import D2DExpertWeightLoader from vllm_ascend.eplb.core.eplb_worker import EplbProcess @@ -274,7 +273,21 @@ class NPUModelRunner(GPUModelRunner): self.is_multimodal_model = self.model_config.is_multimodal_model self.block_size = vllm_config.cache_config.block_size # Set up Attention - self.use_sparse = hasattr(self.vllm_config.model_config.hf_text_config, "index_topk") + self.use_sparse = hasattr(vllm_config.model_config, "hf_text_config") and hasattr( + vllm_config.model_config.hf_text_config, "index_topk" + ) + if self.use_sparse: + self.sparse_head_dim = ( + self.model_config.hf_text_config.kv_lora_rank, + self.model_config.hf_text_config.qk_rope_head_dim, + self.model_config.hf_text_config.index_head_dim, + ) + # dsa c8 + self.use_sparse_c8_indexer = self.ascend_config.enable_sparse_c8 + if self.use_sparse_c8_indexer: + 
self.c8_k_cache_dtype = torch.int8 + self.c8_k_scale_cache_dtype = torch.float16 + self.attn_backend = get_attn_backend( 0, self.dtype, @@ -2623,7 +2636,7 @@ class NPUModelRunner(GPUModelRunner): to their corresponding memory buffer for K cache and V cache. """ # init kv cache tensors - kv_cache_raw_tensors: dict[str, torch.Tensor | torch.Tensor | None] = {} + kv_cache_raw_tensors: dict[str, torch.Tensor | torch.Tensor | None | None] = {} # prefill disaggregation need the addr of cache tensor be aligned with 2M alignment = 2 * 1024 * 1024 layer_kv_cache_spec: dict[str, KVCacheSpec] = {} @@ -2670,19 +2683,18 @@ class NPUModelRunner(GPUModelRunner): + self.model_config.hf_text_config.kv_lora_rank ) - dsa_k_cache_factor = None - dsa_k_cache_size = None if not self.model_config.use_mla: # for non-mla model, use FullAttentionSpec - k_tensor_split_factor = 2 - v_tensor_split_factor = 2 + k_tensor_split_factor = 2.0 + v_tensor_split_factor = 2.0 elif self.use_sparse: # for deepseek v3.2, we split the kv cache according to the corresponding ratio - sparse_sum_head_size = sum(self._get_sparse_kv_cache_ratio()) - k_tensor_split_factor, v_tensor_split_factor, dsa_k_cache_factor = [ # type: ignore - sparse_sum_head_size / ratio for ratio in self._get_sparse_kv_cache_ratio() - ] - dsa_k_cache_size = int(kv_cache_tensor.size // dsa_k_cache_factor) + kv_cache_spec = layer_kv_cache_spec[layer_name] + sparse_kv_cache_ratio = kv_cache_spec.sparse_kv_cache_ratio + k_tensor_split_factor = sparse_kv_cache_ratio[0] + v_tensor_split_factor = sparse_kv_cache_ratio[1] + dsa_k_tensor_split_factor = sparse_kv_cache_ratio[2] + dsa_k_scale_tensor_split_factor = sparse_kv_cache_ratio[3] else: # for other deepseek models, use MLAAttentionSpec k_tensor_split_factor = head_size / self.model_config.hf_text_config.kv_lora_rank @@ -2690,35 +2702,56 @@ class NPUModelRunner(GPUModelRunner): k_tensor_size = int(kv_cache_tensor.size // k_tensor_split_factor) v_tensor_size = int(kv_cache_tensor.size // 
v_tensor_split_factor) + dsa_k_tensor_size = None + dsa_k_scale_tensor_size = None + #### for deepseek sparse attention + if self.use_sparse: + dsa_k_tensor_size = int(kv_cache_tensor.size // dsa_k_tensor_split_factor) + if self.use_sparse_c8_indexer: + dsa_k_scale_tensor_size = int(kv_cache_tensor.size // dsa_k_scale_tensor_split_factor) # for other attentions, e.g., self_attn, sliding window attn if self.vllm_config.kv_transfer_config is None: k_tensor = torch.zeros(k_tensor_size, dtype=torch.int8, device=self.device) v_tensor = torch.zeros(v_tensor_size, dtype=torch.int8, device=self.device) - #### k cache: for deepseek sparse attention - if dsa_k_cache_factor is not None: - dsa_k_cache_tensor = torch.zeros(dsa_k_cache_size, dtype=torch.int8, device=self.device) + #### for deepseek sparse attention + if dsa_k_tensor_size is not None: + dsa_k_tensor = torch.zeros(dsa_k_tensor_size, dtype=torch.int8, device=self.device) + if dsa_k_scale_tensor_size is not None: + dsa_k_scale_tensor = torch.zeros( + dsa_k_scale_tensor_size, dtype=torch.int8, device=self.device + ) else: k_tensor = torch.zeros(k_tensor_size + alignment, dtype=torch.int8, device=self.device) v_tensor = torch.zeros(v_tensor_size + alignment, dtype=torch.int8, device=self.device) k_tensor = self._align_memory(k_tensor, alignment)[:k_tensor_size] v_tensor = self._align_memory(v_tensor, alignment)[:v_tensor_size] - #### k cache: for deepseek sparse attention - if dsa_k_cache_factor is not None and dsa_k_cache_size is not None: - dsa_k_cache_tensor = torch.zeros( - dsa_k_cache_size + alignment, dtype=torch.int8, device=self.device + #### for deepseek sparse attention + if dsa_k_tensor_size is not None: + dsa_k_tensor = torch.zeros( + dsa_k_tensor_size + alignment, dtype=torch.int8, device=self.device ) - dsa_k_cache_tensor = self._align_memory(dsa_k_cache_tensor, alignment)[:dsa_k_cache_size] + dsa_k_tensor = self._align_memory(dsa_k_tensor, alignment)[:dsa_k_tensor_size] + if dsa_k_scale_tensor_size is 
not None: + dsa_k_scale_tensor = torch.zeros( + dsa_k_scale_tensor_size + alignment, dtype=torch.int8, device=self.device + ) + dsa_k_scale_tensor = self._align_memory( + dsa_k_scale_tensor, alignment + )[:dsa_k_scale_tensor_size] for layer_name_inner in kv_cache_tensor.shared_by: # shared the attn kvcache for all shared layers if "attn" in layer_name_inner and "linear_attn" not in layer_name_inner: - kv_cache_raw_tensors[layer_name_inner] = ( - (k_tensor, v_tensor) - if not self.use_sparse - else (k_tensor, v_tensor, dsa_k_cache_tensor) - ) - + if self.use_sparse: + if self.use_sparse_c8_indexer: + kv_cache_raw_tensors[layer_name_inner] = ( + k_tensor, v_tensor, dsa_k_tensor, dsa_k_scale_tensor + ) + else: + kv_cache_raw_tensors[layer_name_inner] = (k_tensor, v_tensor, dsa_k_tensor) + else: + kv_cache_raw_tensors[layer_name_inner] = (k_tensor, v_tensor) layer_names = set() for group in kv_cache_config.kv_cache_groups: for layer_name in group.layer_names: @@ -2760,13 +2793,23 @@ class NPUModelRunner(GPUModelRunner): # TODO: remove this after the OOM issue is located and fixed, otherwise, some model may # encounter OOM issue if isinstance(kv_cache_spec, AttentionSpec): - raw_dsa_k_tensor = None if self.use_sparse: - raw_k_tensor, raw_v_tensor, raw_dsa_k_tensor = kv_cache_raw_tensors[ # type: ignore - layer_name - ] - assert raw_dsa_k_tensor is not None - sum_page_size_bytes = raw_k_tensor.numel() + raw_v_tensor.numel() + raw_dsa_k_tensor.numel() + if self.use_sparse_c8_indexer: + raw_k_tensor, raw_v_tensor, raw_dsa_k_tensor, raw_dsa_k_scale_tensor = kv_cache_raw_tensors[ # type: ignore + layer_name] + assert raw_dsa_k_tensor is not None + assert raw_dsa_k_scale_tensor is not None + sum_page_size_bytes = ( + raw_k_tensor.numel() + + raw_v_tensor.numel() + + raw_dsa_k_tensor.numel() + + raw_dsa_k_scale_tensor.numel() + ) + else: + raw_k_tensor, raw_v_tensor, raw_dsa_k_tensor = kv_cache_raw_tensors[ # type: ignore + layer_name] + assert raw_dsa_k_tensor is not None + 
sum_page_size_bytes = raw_k_tensor.numel() + raw_v_tensor.numel() + raw_dsa_k_tensor.numel() elif self.use_hybrid_blocks and self.hybrid_with_attn_and_mamba: # Currently, we ensure that the same kvcache format is used even if there # is no shared layer, such as the full attention mtp layer of qwen3.5, etc. @@ -2813,7 +2856,7 @@ class NPUModelRunner(GPUModelRunner): kv_cache_shape = self.attn_backend.get_kv_cache_shape( num_blocks, kv_cache_spec.block_size, kv_cache_spec.num_kv_heads, kv_cache_spec.head_size ) - dtype = kv_cache_spec.dtype + if not self.model_config.use_mla: k_shape = kv_cache_shape[1:] v_shape = k_shape @@ -2832,19 +2875,37 @@ class NPUModelRunner(GPUModelRunner): num_kv_heads, self.model_config.hf_text_config.qk_rope_head_dim, ] - k_cache = raw_k_tensor.view(dtype).view(k_shape) - v_cache = raw_v_tensor.view(dtype).view(v_shape) + k_cache = raw_k_tensor.view(kv_cache_spec.dtype).view(k_shape) + v_cache = raw_v_tensor.view(kv_cache_spec.dtype).view(v_shape) - if self.use_sparse and raw_dsa_k_tensor is not None: - index_head_dim = self._get_sparse_kv_cache_ratio()[-1] + if self.use_sparse: dsa_k_cache_shape = ( num_blocks, kv_cache_spec.block_size, kv_cache_spec.num_kv_heads, - index_head_dim, + self.model_config.hf_text_config.index_head_dim, ) - dsa_k_cache = raw_dsa_k_tensor.view(dtype).view(dsa_k_cache_shape) - kv_caches[layer_name] = (k_cache, v_cache, dsa_k_cache) + if self.use_sparse_c8_indexer: + # dsa_k + dsa_k_cache = raw_dsa_k_tensor.view(self.c8_k_cache_dtype).view(dsa_k_cache_shape) + # dsa_k_scale + dsa_k_scale_cache_shape = ( + num_blocks, + kv_cache_spec.block_size, + kv_cache_spec.num_kv_heads, + 1, + ) + assert raw_dsa_k_scale_tensor is not None + dsa_k_scale_cache = ( + raw_dsa_k_scale_tensor + .view(self.c8_k_scale_cache_dtype) + .view(dsa_k_scale_cache_shape) + ) + kv_caches[layer_name] = (k_cache, v_cache, dsa_k_cache, dsa_k_scale_cache) + else: + # dsa_k + dsa_k_cache = 
raw_dsa_k_tensor.view(kv_cache_spec.dtype).view(dsa_k_cache_shape) + kv_caches[layer_name] = (k_cache, v_cache, dsa_k_cache) else: kv_caches[layer_name] = (k_cache, v_cache) elif isinstance(kv_cache_spec, MambaSpec): @@ -3098,18 +3159,31 @@ class NPUModelRunner(GPUModelRunner): elif isinstance(attn_module, MLAAttention): if self.use_sparse: - # TODO(cmq): This is a hack way to fix deepseek kvcache when - # using DSA. Fix the spec in vLLM is the final way. - sparse_sum_head_size = sum(self._get_sparse_kv_cache_ratio()) - kv_cache_spec[layer_name] = MLAAttentionSpec( + # `MLAAttentionSpec` is temporarily patched to `AscendMLAAttentionSpec`. + # Re-importing it at runtime will therefore resolve to the patched class. + # Rename it here to make this behavior explicit. + from vllm.v1.kv_cache_interface import MLAAttentionSpec as AscendMLAAttentionSpec + # TODO(rjg-lyh): when kv_cache_spec's refactor is ready, + # implement it by creating a new kv_cache_spec class + kv_cache_spec[layer_name] = AscendMLAAttentionSpec( block_size=self.block_size, num_kv_heads=1, - head_size=sparse_sum_head_size, + head_size=sum(self.sparse_head_dim), + sparse_head_dim=self.sparse_head_dim, dtype=self.kv_cache_dtype, cache_dtype_str=self.vllm_config.cache_config.cache_dtype, + cache_sparse_c8=self.use_sparse_c8_indexer, ) elif spec := attn_module.get_kv_cache_spec(self.vllm_config): - kv_cache_spec[layer_name] = spec + assert isinstance(spec, MLAAttentionSpec) + from vllm.v1.kv_cache_interface import MLAAttentionSpec as AscendMLAAttentionSpec + kv_cache_spec[layer_name] = AscendMLAAttentionSpec( + block_size=spec.block_size, + num_kv_heads=spec.num_kv_heads, + head_size=spec.head_size, + dtype=spec.dtype, + cache_dtype_str=spec.cache_dtype_str, + ) elif isinstance(attn_module, MambaBase): mamba_layers[layer_name] = attn_module @@ -3129,16 +3203,6 @@ class NPUModelRunner(GPUModelRunner): return kv_cache_spec - def _get_sparse_kv_cache_ratio(self) -> list[int]: - # TODO:If C8 is supported, we 
need to consider the number of bytes occupied by different dtypes - # when calculating the ratio,for example: - # [kv_lora_rank * torch.int8.itemsize, qk_rope_head_dim * torch.bfloat16.itemsize, ...] - return [ - self.model_config.hf_text_config.kv_lora_rank, - self.model_config.hf_text_config.qk_rope_head_dim, - self.model_config.hf_text_config.index_head_dim, - ] - def _check_and_update_cudagraph_mode( self, attention_backends: list[set[type[AttentionBackend]]],