diff --git a/.github/workflows/scripts/config.yaml b/.github/workflows/scripts/config.yaml index 25c49465..e3e80e4a 100644 --- a/.github/workflows/scripts/config.yaml +++ b/.github/workflows/scripts/config.yaml @@ -70,8 +70,6 @@ e2e-2card-light: estimated_time: 220 - name: tests/e2e/multicard/2-cards/test_offline_inference_distributed.py::test_deepseek3_2_w8a8_pruning_mtp_tp2_ep estimated_time: 90 - - name: tests/e2e/multicard/2-cards/test_offline_inference_distributed.py::test_deepseek3_2_w8a8c8_pruning_mtp_tp2_ep - estimated_time: 90 - name: tests/e2e/multicard/2-cards/test_offline_inference_distributed.py::test_gpt_oss_distributed_tp2 estimated_time: 180 @@ -122,8 +120,6 @@ e2e-multicard-2-cards: estimated_time: 71 - name: tests/e2e/multicard/2-cards/test_offline_inference_distributed.py::test_deepseek3_2_w8a8_pruning_mtp_tp2_ep estimated_time: 111 - - name: tests/e2e/multicard/2-cards/test_offline_inference_distributed.py::test_deepseek3_2_w8a8c8_pruning_mtp_tp2_ep - estimated_time: 111 - name: tests/e2e/multicard/2-cards/test_offline_inference_distributed.py::test_qwen3_w4a4_distributed_tp2 estimated_time: 180 - name: tests/e2e/multicard/2-cards/test_pipeline_parallel.py diff --git a/csrc/build_aclnn.sh b/csrc/build_aclnn.sh index 631ea4e3..5b11cbe2 100644 --- a/csrc/build_aclnn.sh +++ b/csrc/build_aclnn.sh @@ -25,7 +25,7 @@ elif [[ "$SOC_VERSION" =~ ^ascend910b ]]; then export CPATH=${ABSOLUTE_CATLASS_PATH}:${CPATH} - CUSTOM_OPS="moe_grouped_matmul;grouped_matmul_swiglu_quant_weight_nz_tensor_list;lightning_indexer_vllm;sparse_flash_attention;matmul_allreduce_add_rmsnorm;moe_init_routing_custom;moe_gating_top_k;add_rms_norm_bias;apply_top_k_top_p_custom;transpose_kv_cache_by_block;copy_and_expand_eagle_inputs;causal_conv1d;lightning_indexer_quant;" + CUSTOM_OPS="moe_grouped_matmul;grouped_matmul_swiglu_quant_weight_nz_tensor_list;lightning_indexer_vllm;sparse_flash_attention;matmul_allreduce_add_rmsnorm;moe_init_routing_custom;moe_gating_top_k;add_rms_norm_bias;apply_top_k_top_p_custom;transpose_kv_cache_by_block;copy_and_expand_eagle_inputs;causal_conv1d;" SOC_ARG="ascend910b" elif [[ "$SOC_VERSION" =~ ^ascend910_93 ]]; then # ASCEND910C (A3) series @@ -67,7 +67,6 @@ elif [[ "$SOC_VERSION" =~ ^ascend910_93 ]]; then "copy_and_expand_eagle_inputs" "causal_conv1d" "moe_grouped_matmul" - "lightning_indexer_quant" ) CUSTOM_OPS=$(IFS=';'; echo "${CUSTOM_OPS_ARRAY[*]}") SOC_ARG="ascend910_93" diff --git a/csrc/lightning_indexer_quant/lightning_indexer_quant_torch_adpt.h b/csrc/lightning_indexer_quant/lightning_indexer_quant_torch_adpt.h deleted file mode 100644 index fbc8284c..00000000 --- a/csrc/lightning_indexer_quant/lightning_indexer_quant_torch_adpt.h +++ /dev/null @@ -1,81 +0,0 @@ -/* - * Copyright (c) Huawei Technologies Co., Ltd. 2026. All rights reserved. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -#ifndef LIGHTING_INDEXER_QUANT_VLLM_TORCH_ADPT_H -#define LIGHTING_INDEXER_QUANT_VLLM_TORCH_ADPT_H -namespace vllm_ascend { - -at::Tensor npu_lightning_indexer_quant( - const at::Tensor &query, const at::Tensor &key, const at::Tensor &weights, - const at::Tensor &query_dequant_scale, const at::Tensor &key_dequant_scale, - const c10::optional &actual_seq_lengths_query, - const c10::optional &actual_seq_lengths_key, - const c10::optional &block_table, int64_t query_quant_mode, int64_t key_quant_mode, - c10::string_view layout_query, c10::string_view layout_key, int64_t sparse_count, int64_t sparse_mode) -{ - std::string query_layout_str = std::string(layout_query); - std::string key_layout_str = std::string(layout_key); - - const int SIZE = 8; - const int DIM_0 = 0; - const int DIM_1 = 1; - const int DIM_2 = 2; - const int DIM_3 = 3; - - at::SmallVector output_size; - for (size_t i = 0; i < query.sizes().size(); i++) { - TORCH_CHECK(query.size(i) > 0, "All values within query's shape should be greater " - "than 0, but shape[", i, "] is ", query.size(i)); - } - for (size_t i = 0; i < key.sizes().size(); i++) { - TORCH_CHECK(key.size(i) > 0, "All values within key's shape should be greater " - "than 0, but shape[", i, "] is ", key.size(i)); - } - TORCH_CHECK(sparse_count > 0, "sparse count should be greater than 0, but now is ", sparse_count); - int64_t keyHeadNum = (key_layout_str == "TND")? key.size(DIM_1) : key.size(DIM_2); - if (query_layout_str == "BSND") { - output_size = {query.size(DIM_0), query.size(DIM_1), keyHeadNum, sparse_count}; - } else { - output_size = {query.size(DIM_0), keyHeadNum, sparse_count}; - } - at::Tensor lightning_indexer_quant_output = at::empty(output_size, query.options().dtype(at::kInt)); - - // convert str - char *query_layout_ptr = const_cast(query_layout_str.c_str()); - char *key_layout_ptr = const_cast(key_layout_str.c_str()); - - EXEC_NPU_CMD(aclnnLightningIndexerQuant, - query, - key, - weights, - query_dequant_scale, - key_dequant_scale, - actual_seq_lengths_query, - actual_seq_lengths_key, - block_table, - query_quant_mode, - key_quant_mode, - query_layout_ptr, - key_layout_ptr, - sparse_count, - sparse_mode, - lightning_indexer_quant_output - ); - - return lightning_indexer_quant_output; - -} -} -#endif \ No newline at end of file diff --git a/csrc/lightning_indexer_quant/op_host/CMakeLists.txt b/csrc/lightning_indexer_quant/op_host/CMakeLists.txt deleted file mode 100644 index 252cfd85..00000000 --- a/csrc/lightning_indexer_quant/op_host/CMakeLists.txt +++ /dev/null @@ -1,41 +0,0 @@ -# This program is free software, you can redistribute it and/or modify it. -# Copyright (c) 2025 Huawei Technologies Co., Ltd. -# This file is a part of the CANN Open Software. -# Licensed under CANN Open Software License Agreement Version 2.0 (the "License"). -# Please refer to the License for details. You may not use this file except in compliance with the License. -# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. -# See LICENSE in the root of the software repository for the full text of the License. -# ====================================================================================================================== - -add_ops_compile_options( - OP_NAME LightningIndexerQuant - OPTIONS --cce-auto-sync=off - -Wno-deprecated-declarations - -Werror - -mllvm -cce-aicore-hoist-movemask=false - --op_relocatable_kernel_binary=true -) - -set(lightning_indexer_quant_depends transformer/attention/lightning_indexer_quant PARENT_SCOPE) - -target_sources(op_host_aclnn PRIVATE - lightning_indexer_quant_def.cpp -) - -target_sources(optiling PRIVATE - lightning_indexer_quant_tiling.cpp -) - -if (NOT BUILD_OPEN_PROJECT) - target_sources(opmaster_ct PRIVATE - lightning_indexer_quant_tiling.cpp - ) -endif () - -target_include_directories(optiling PRIVATE - ${CMAKE_CURRENT_SOURCE_DIR}/op_host -) - -target_sources(opsproto PRIVATE - lightning_indexer_quant_proto.cpp -) diff --git a/csrc/lightning_indexer_quant/op_host/lightning_indexer_quant_def.cpp b/csrc/lightning_indexer_quant/op_host/lightning_indexer_quant_def.cpp deleted file mode 100644 index 7049c2d0..00000000 --- a/csrc/lightning_indexer_quant/op_host/lightning_indexer_quant_def.cpp +++ /dev/null @@ -1,85 +0,0 @@ -/** - * This program is free software, you can redistribute it and/or modify it. - * Copyright (c) 2025 Huawei Technologies Co., Ltd. - * This file is a part of the CANN Open Software. - * Licensed under CANN Open Software License Agreement Version 2.0 (the "License"). - * Please refer to the License for details. You may not use this file except in compliance with the License. - * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. - * See LICENSE in the root of the software repository for the full text of the License. - */ - -/*! - * \file lightning_indexer_quant_def.cpp - * \brief - */ -#include - -#include "register/op_def_registry.h" - -namespace ops { -class LightningIndexerQuant : public OpDef { -public: - explicit LightningIndexerQuant(const char *name) : OpDef(name) - { - this->Input("query") - .ParamType(REQUIRED) - .DataType({ge::DT_INT8}) - .Format({ge::FORMAT_ND}) - .AutoContiguous(); - this->Input("key") - .ParamType(REQUIRED) - .DataType({ge::DT_INT8}) - .Format({ge::FORMAT_ND}) - .AutoContiguous(); - this->Input("weights") - .ParamType(REQUIRED) - .DataType({ge::DT_FLOAT16}) - .Format({ge::FORMAT_ND}) - .AutoContiguous(); - this->Input("query_dequant_scale") - .ParamType(REQUIRED) - .DataType({ge::DT_FLOAT16}) - .Format({ge::FORMAT_ND}) - .AutoContiguous(); - this->Input("key_dequant_scale") - .ParamType(REQUIRED) - .DataType({ge::DT_FLOAT16}) - .Format({ge::FORMAT_ND}) - .AutoContiguous(); - this->Input("actual_seq_lengths_query") - .ParamType(OPTIONAL) - .DataType({ge::DT_INT32}) - .Format({ge::FORMAT_ND}) - .AutoContiguous(); - this->Input("actual_seq_lengths_key") - .ParamType(OPTIONAL) - .DataType({ge::DT_INT32}) - .Format({ge::FORMAT_ND}) - .AutoContiguous(); - this->Input("block_table") - .ParamType(OPTIONAL) - .DataType({ge::DT_INT32}) - .Format({ge::FORMAT_ND}) - .AutoContiguous(); - this->Output("sparse_indices").ParamType(REQUIRED).DataType({ge::DT_INT32}).Format({ge::FORMAT_ND}); - this->Attr("query_quant_mode").AttrType(REQUIRED).Int(0); // 0: 默认值,per-token-head - this->Attr("key_quant_mode").AttrType(REQUIRED).Int(0); // 0: 默认值,per-token-head - this->Attr("layout_query").AttrType(OPTIONAL).String("BSND"); - this->Attr("layout_key").AttrType(OPTIONAL).String("PA_BSND"); - this->Attr("sparse_count").AttrType(OPTIONAL).Int(2048); // 2048: 默认值,筛选前2048 - this->Attr("sparse_mode").AttrType(OPTIONAL).Int(3); // 3: 默认值,只计算下三角 - OpAICoreConfig aicore_config; - aicore_config.DynamicCompileStaticFlag(true) - .DynamicFormatFlag(true) - .DynamicRankSupportFlag(true) - .DynamicShapeSupportFlag(true) - .NeedCheckSupportFlag(false) - .PrecisionReduceFlag(true) - .ExtendCfgInfo("aclnnSupport.value", "support_aclnn") - .ExtendCfgInfo("jitCompile.flag", "static_false,dynamic_false"); - this->AICore().AddConfig("ascend910b", aicore_config); - this->AICore().AddConfig("ascend910_93", aicore_config); - } -}; -OP_ADD(LightningIndexerQuant); -} // namespace ops \ No newline at end of file diff --git a/csrc/lightning_indexer_quant/op_host/lightning_indexer_quant_proto.cpp b/csrc/lightning_indexer_quant/op_host/lightning_indexer_quant_proto.cpp deleted file mode 100644 index fb4539d6..00000000 --- a/csrc/lightning_indexer_quant/op_host/lightning_indexer_quant_proto.cpp +++ /dev/null @@ -1,91 +0,0 @@ -/** - * This program is free software, you can redistribute it and/or modify it. - * Copyright (c) 2025 Huawei Technologies Co., Ltd. - * This file is a part of the CANN Open Software. - * Licensed under CANN Open Software License Agreement Version 2.0 (the "License"). - * Please refer to the License for details. You may not use this file except in compliance with the License. - * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. - * See LICENSE in the root of the software repository for the full text of the License. - */ - -/*! - * \file lightning_indexer_quant_proto.cpp - * \brief - */ -#include -#include - -#include "error/ops_error.h" - -using namespace ge; - -namespace ops { -constexpr uint32_t QUERY_INDEX = 0; -constexpr uint32_t KEY_INDEX = 1; -constexpr uint32_t ATTR_QUERY_LAYOUT_INDEX = 2; -constexpr uint32_t ATTR_KV_LAYOUT_INDEX = 3; -constexpr uint32_t ATTR_SPARSE_COUNT_INDEX = 4; - -static ge::graphStatus InferShapeLightningIndexerQuant(gert::InferShapeContext *context) -{ - if (context == nullptr) { - OPS_LOG_E("LightningIndexerQuant", "context is nullptr!"); - return ge::GRAPH_FAILED; - } - const gert::Shape *queryShape = context->GetInputShape(QUERY_INDEX); - OPS_LOG_E_IF_NULL(context, queryShape, return ge::GRAPH_FAILED); - const gert::Shape *keyShape = context->GetInputShape(KEY_INDEX); - OPS_LOG_E_IF_NULL(context, keyShape, return ge::GRAPH_FAILED); - gert::Shape *outShape = context->GetOutputShape(0); - - auto attrs = context->GetAttrs(); - OPS_LOG_E_IF_NULL(context, attrs, return ge::GRAPH_FAILED); - const char *inputLayoutQueryPtr = attrs->GetAttrPointer(ATTR_QUERY_LAYOUT_INDEX); - OPS_LOG_E_IF_NULL(context, inputLayoutQueryPtr, return ge::GRAPH_FAILED); - const char *inputLayoutKeyPtr = attrs->GetAttrPointer(ATTR_KV_LAYOUT_INDEX); - OPS_LOG_E_IF_NULL(context, inputLayoutKeyPtr, return ge::GRAPH_FAILED); - const int64_t *sparse_count = attrs->GetInt(ATTR_SPARSE_COUNT_INDEX); - OPS_LOG_E_IF_NULL(context, sparse_count, return ge::GRAPH_FAILED); - - std::string inputLayoutQueryPtrStr = std::string(inputLayoutQueryPtr); - std::string inputLayoutKeyPtrStr = std::string(inputLayoutKeyPtr); - if (inputLayoutQueryPtrStr != "TND" && inputLayoutQueryPtrStr != "BSND") { - OPS_LOG_E(context, "The input layout query should be TND or BSND, but got %s.", inputLayoutQueryPtrStr.c_str()); - return GRAPH_FAILED; - } - - outShape->SetDimNum(queryShape->GetDimNum()); - int64_t keyHeadNum = (inputLayoutKeyPtrStr == "TND") ? keyShape->GetDim(1) : keyShape->GetDim(2); - if (inputLayoutQueryPtrStr == "BSND") { - outShape->SetDim(0, queryShape->GetDim(0)); // 0:Dim B - outShape->SetDim(1, queryShape->GetDim(1)); // 1:Dim S - outShape->SetDim(2, keyHeadNum); // 2:Dim N - outShape->SetDim(3, *sparse_count); // 3:Dim K - } else { - outShape->SetDim(0, queryShape->GetDim(0)); // 0:Dim T - outShape->SetDim(1, keyHeadNum); // 1:output shape's N Dim, 2: key shape's N Dim - outShape->SetDim(2, *sparse_count); // 2:Dim K - } - - OPS_LOG_D(context->GetNodeName(), "LightningIndexerQuant InferShape end."); - return ge::GRAPH_SUCCESS; -} - -static ge::graphStatus InferDataTypeLightningIndexerQuant(gert::InferDataTypeContext *context) -{ - if (context == nullptr) { - OPS_LOG_E("LightningIndexerQuant", "InferDataTypeContext context is nullptr!"); - return ge::GRAPH_FAILED; - } - OPS_LOG_D(context->GetNodeName(), "Enter LightningIndexerQuant InferDataType impl."); - // default index data type is int32 - ge::DataType outputType = ge::DT_INT32; - context->SetOutputDataType(0, outputType); - OPS_LOG_D(context->GetNodeName(), "LightningIndexerQuant InferDataType end."); - return GRAPH_SUCCESS; -} - -IMPL_OP_INFERSHAPE(LightningIndexerQuant) - .InferShape(InferShapeLightningIndexerQuant) - .InferDataType(InferDataTypeLightningIndexerQuant); -} // namespace ops diff --git a/csrc/lightning_indexer_quant/op_host/lightning_indexer_quant_tiling.cpp b/csrc/lightning_indexer_quant/op_host/lightning_indexer_quant_tiling.cpp deleted file mode 100644 index 042f0271..00000000 --- a/csrc/lightning_indexer_quant/op_host/lightning_indexer_quant_tiling.cpp +++ /dev/null @@ -1,828 +0,0 @@ -/** - * This program is free software, you can redistribute it and/or modify it. - * Copyright (c) 2025 Huawei Technologies Co., Ltd. - * This file is a part of the CANN Open Software. - * Licensed under CANN Open Software License Agreement Version 2.0 (the "License"). - * Please refer to the License for details. You may not use this file except in compliance with the License. - * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. - * See LICENSE in the root of the software repository for the full text of the License. - */ - -/*! - * \file lightning_indexer_quant_tiling.cpp - * \brief - */ - -#include "lightning_indexer_quant_tiling.h" - -#include "../op_kernel/lightning_indexer_quant_template_tiling_key.h" - -using namespace ge; -using namespace AscendC; -using std::map; -using std::string; -namespace optiling { -// --------------------------LIQInfoParser类成员函数定义------------------------------------- -ge::graphStatus LIQInfoParser::CheckRequiredInOutExistence() const -{ - OPS_ERR_IF(opParamInfo_.query.shape == nullptr, OPS_LOG_E(opName_, "Shape of tensor query is nullptr"), - return ge::GRAPH_FAILED); - OPS_ERR_IF(opParamInfo_.query.desc == nullptr, OPS_LOG_E(opName_, "Desc of tensor query is nullptr"), - return ge::GRAPH_FAILED); - OPS_ERR_IF(opParamInfo_.key.shape == nullptr, OPS_LOG_E(opName_, "Shape of tensor key is nullptr"), - return ge::GRAPH_FAILED); - OPS_ERR_IF(opParamInfo_.key.desc == nullptr, OPS_LOG_E(opName_, "Desc of tensor key is nullptr"), - return ge::GRAPH_FAILED); - OPS_ERR_IF(opParamInfo_.weights.shape == nullptr, OPS_LOG_E(opName_, "Shape of tensor weights is nullptr"), - return ge::GRAPH_FAILED); - OPS_ERR_IF(opParamInfo_.weights.desc == nullptr, OPS_LOG_E(opName_, "Desc of tensor weights is nullptr"), - return ge::GRAPH_FAILED); - OPS_ERR_IF(opParamInfo_.query_dequant_scale.shape == nullptr, - OPS_LOG_E(opName_, "Shape of tensor query_dequant_scale is nullptr"), return ge::GRAPH_FAILED); - OPS_ERR_IF(opParamInfo_.query_dequant_scale.desc == nullptr, - OPS_LOG_E(opName_, "Desc of tensor query_dequant_scale is nullptr"), return ge::GRAPH_FAILED); - OPS_ERR_IF(opParamInfo_.key_dequant_scale.shape == nullptr, - OPS_LOG_E(opName_, "Shape of tensor key_dequant_scale is nullptr"), return ge::GRAPH_FAILED); - OPS_ERR_IF(opParamInfo_.key_dequant_scale.desc == nullptr, - OPS_LOG_E(opName_, "Desc of tensor key_dequant_scale is nullptr"), return ge::GRAPH_FAILED); - OPS_ERR_IF(opParamInfo_.attenOut.shape == nullptr, OPS_LOG_E(opName_, "Shape of tensor output is nullptr"), - return ge::GRAPH_FAILED); - OPS_ERR_IF(opParamInfo_.attenOut.desc == nullptr, OPS_LOG_E(opName_, "Desc of tensor output is nullptr"), - return ge::GRAPH_FAILED); - return ge::GRAPH_SUCCESS; -} - -ge::graphStatus LIQInfoParser::CheckRequiredAttrExistence() const -{ - OPS_ERR_IF(opParamInfo_.layOutQuery == nullptr, OPS_LOG_E(opName_, "attr layout_query is nullptr"), - return ge::GRAPH_FAILED); - - OPS_ERR_IF(opParamInfo_.layOutKey == nullptr, OPS_LOG_E(opName_, "attr layout_key is nullptr"), - return ge::GRAPH_FAILED); - - OPS_ERR_IF(opParamInfo_.sparseCount == nullptr, OPS_LOG_E(opName_, "attr sparse_count is nullptr"), - return ge::GRAPH_FAILED); - - OPS_ERR_IF(opParamInfo_.sparseMode == nullptr, OPS_LOG_E(opName_, "attr sparse_mode is nullptr"), - return ge::GRAPH_FAILED); - OPS_ERR_IF(opParamInfo_.queryQuantMode == nullptr, OPS_LOG_E(opName_, "query_quant_mode is nullptr"), - return ge::GRAPH_FAILED); - OPS_ERR_IF(opParamInfo_.keyQuantMode == nullptr, OPS_LOG_E(opName_, "key_quant_mode is nullptr"), - return ge::GRAPH_FAILED); - - return ge::GRAPH_SUCCESS; -} - -ge::graphStatus LIQInfoParser::CheckRequiredParaExistence() const -{ - if (CheckRequiredInOutExistence() != ge::GRAPH_SUCCESS || CheckRequiredAttrExistence() != ge::GRAPH_SUCCESS) { - return ge::GRAPH_FAILED; - } - - return ge::GRAPH_SUCCESS; -} - -ge::graphStatus LIQInfoParser::GetOpName() -{ - if (context_->GetNodeName() == nullptr) { - OPS_LOG_E("LightningIndexerQuant", "opName got from TilingContext is nullptr"); - return ge::GRAPH_FAILED; - } - opName_ = context_->GetNodeName(); - return ge::GRAPH_SUCCESS; -} - -ge::graphStatus LIQInfoParser::GetNpuInfo() -{ - platformInfo_ = context_->GetPlatformInfo(); - OPS_ERR_IF(platformInfo_ == nullptr, OPS_LOG_E(opName_, "GetPlatformInfo is nullptr."), return ge::GRAPH_FAILED); - - auto ascendcPlatform = platform_ascendc::PlatformAscendC(platformInfo_); - uint32_t aivNum = ascendcPlatform.GetCoreNumAiv(); - uint32_t aicNum = ascendcPlatform.GetCoreNumAic(); - OPS_ERR_IF(aicNum == 0 || aivNum == 0, OPS_LOG_E(opName_, "num of core obtained is 0."), return GRAPH_FAILED); - - socVersion_ = ascendcPlatform.GetSocVersion(); - if ((socVersion_ != platform_ascendc::SocVersion::ASCEND910B) && - (socVersion_ != platform_ascendc::SocVersion::ASCEND910_93)) { - OPS_LOG_E(opName_, "SOC Version[%d] is not support.", (int32_t)socVersion_); - return GRAPH_FAILED; - } - OPS_ERR_IF(context_->GetWorkspaceSizes(1) == nullptr, OPS_LOG_E(opName_, "workSpaceSize got from ge is nullptr"), - return ge::GRAPH_FAILED); - OPS_ERR_IF(context_->GetRawTilingData() == nullptr, - OPS_LOG_E(context_->GetNodeName(), "RawTilingData got from GE context is nullptr."), - return ge::GRAPH_FAILED); - - return ge::GRAPH_SUCCESS; -} - -void LIQInfoParser::GetOptionalInputParaInfo() -{ - opParamInfo_.actualSeqLengthsQ.tensor = context_->GetOptionalInputTensor(ACTUAL_SEQ_Q_INDEX); - opParamInfo_.actualSeqLengthsQ.desc = context_->GetOptionalInputDesc(ACTUAL_SEQ_Q_INDEX); - opParamInfo_.actualSeqLengthsK.tensor = context_->GetOptionalInputTensor(ACTUAL_SEQ_K_INDEX); - opParamInfo_.actualSeqLengthsK.desc = context_->GetOptionalInputDesc(ACTUAL_SEQ_K_INDEX); - opParamInfo_.blockTable.tensor = context_->GetOptionalInputTensor(BLOCK_TABLE_INDEX); - opParamInfo_.blockTable.desc = context_->GetOptionalInputDesc(BLOCK_TABLE_INDEX); -} - -void LIQInfoParser::GetInputParaInfo() -{ - opParamInfo_.query.desc = context_->GetInputDesc(QUERY_INDEX); - opParamInfo_.query.shape = context_->GetInputShape(QUERY_INDEX); - opParamInfo_.key.desc = context_->GetInputDesc(KEY_INDEX); - opParamInfo_.key.shape = context_->GetInputShape(KEY_INDEX); - opParamInfo_.weights.desc = context_->GetInputDesc(WEIGTHS_INDEX); - opParamInfo_.weights.shape = context_->GetInputShape(WEIGTHS_INDEX); - opParamInfo_.query_dequant_scale.desc = context_->GetInputDesc(QUERY_DEQUANT_SCALE_INDEX); - opParamInfo_.query_dequant_scale.shape = context_->GetInputShape(QUERY_DEQUANT_SCALE_INDEX); - opParamInfo_.key_dequant_scale.desc = context_->GetInputDesc(KEY_DEQUANT_SCALE_INDEX); - opParamInfo_.key_dequant_scale.shape = context_->GetInputShape(KEY_DEQUANT_SCALE_INDEX); - GetOptionalInputParaInfo(); -} - -void LIQInfoParser::GetOutputParaInfo() -{ - opParamInfo_.attenOut.desc = context_->GetOutputDesc(LIGHTNING_INDEXER_QUANT); - opParamInfo_.attenOut.shape = context_->GetOutputShape(LIGHTNING_INDEXER_QUANT); -} - -ge::graphStatus LIQInfoParser::GetAttrParaInfo() -{ - auto attrs = context_->GetAttrs(); - OPS_ERR_IF(attrs == nullptr, OPS_REPORT_VECTOR_INNER_ERR(context_->GetNodeName(), "attrs got from ge is nullptr"), - return ge::GRAPH_FAILED); - - OPS_LOG_I(context_->GetNodeName(), "GetAttrParaInfo start"); - opParamInfo_.layOutQuery = attrs->GetStr(ATTR_QUERY_LAYOUT_INDEX); - opParamInfo_.layOutKey = attrs->GetStr(ATTR_KEY_LAYOUT_INDEX); - - opParamInfo_.queryQuantMode = attrs->GetAttrPointer(ATTR_QUERY_QUANT_MODE_INDEX); - opParamInfo_.keyQuantMode = attrs->GetAttrPointer(ATTR_KEY_QUANT_MODE_INDEX); - opParamInfo_.layOutQuery = attrs->GetStr(ATTR_QUERY_LAYOUT_INDEX); - opParamInfo_.layOutKey = attrs->GetStr(ATTR_KEY_LAYOUT_INDEX); - opParamInfo_.sparseCount = attrs->GetAttrPointer(ATTR_SPARSE_COUNT_INDEX); - opParamInfo_.sparseMode = attrs->GetAttrPointer(ATTR_SPARSE_MODE_INDEX); - - if (opParamInfo_.layOutQuery != nullptr) { - OPS_LOG_I(context_->GetNodeName(), "layout_query is:%s", opParamInfo_.layOutQuery); - } - if (opParamInfo_.layOutKey != nullptr) { - OPS_LOG_I(context_->GetNodeName(), "layout_key is:%s", opParamInfo_.layOutKey); - } - if (opParamInfo_.sparseCount != nullptr) { - OPS_LOG_I(context_->GetNodeName(), "selscted count is:%d", *opParamInfo_.sparseCount); - } - if (opParamInfo_.sparseMode != nullptr) { - OPS_LOG_I(context_->GetNodeName(), "sparse mode is:%d", *opParamInfo_.sparseMode); - } - if (opParamInfo_.queryQuantMode != nullptr) { - OPS_LOG_I(context_->GetNodeName(), "query_quant_mode mode is:%d", *opParamInfo_.queryQuantMode); - } - if (opParamInfo_.keyQuantMode != nullptr) { - OPS_LOG_I(context_->GetNodeName(), "key_quant_mode mode is:%d", *opParamInfo_.keyQuantMode); - } - OPS_LOG_I(context_->GetNodeName(), "GetAttrParaInfo end"); - - return ge::GRAPH_SUCCESS; -} - -ge::graphStatus LIQInfoParser::CheckAttrParaInfo() -{ - std::string layout_key(opParamInfo_.layOutKey); - std::string layout_query(opParamInfo_.layOutQuery); - OPS_ERR_IF( - ((std::string(opParamInfo_.layOutKey) == "BNSD") || (std::string(opParamInfo_.layOutKey) == "PA_BBND")), - OPS_LOG_E(opName_, "input attr layout_key only supported PA_BSND, PA_BBND, BSND or TND" - "but now layout_key is %s.", layout_key.c_str()), - return ge::GRAPH_FAILED); - OPS_ERR_IF(((std::string(opParamInfo_.layOutQuery) != "BSND") && (std::string(opParamInfo_.layOutQuery) != "TND")), - OPS_LOG_E(opName_, "input attr layout_query only supported BSND or TND."), return ge::GRAPH_FAILED); - OPS_ERR_IF( - ((std::string(opParamInfo_.layOutKey) != "PA_BSND") && - (std::string(opParamInfo_.layOutQuery)) != (std::string(opParamInfo_.layOutKey))), - OPS_LOG_E(opName_, "outside of PA, input attr layout_query and input attr layout_key must be the same, but now layout_key is %s, layout_query is %s.", - layout_key.c_str(), layout_query.c_str()), return ge::GRAPH_FAILED); - OPS_ERR_IF(!((*opParamInfo_.sparseCount > 0) && (*opParamInfo_.sparseCount <= SPARSE_LIMIT)), - OPS_LOG_E(opName_, "input attr sparse_count must > 0 and <= 2048."), return ge::GRAPH_FAILED); - OPS_ERR_IF(!((*opParamInfo_.sparseMode == 0) || (*opParamInfo_.sparseMode == SPARSE_MODE_LOWER)), - OPS_LOG_E(opName_, "input attr sparse_mode only supported 0 or 3, but now is %u.", - *opParamInfo_.sparseMode), return ge::GRAPH_FAILED); - - OPS_ERR_IF(*opParamInfo_.queryQuantMode != 0, OPS_LOG_E(opName_, "input attr query_quant_mode only supported 0."), - return ge::GRAPH_FAILED); - - OPS_ERR_IF(*opParamInfo_.keyQuantMode != 0, OPS_LOG_E(opName_, "input attr key_quant_mode only supported 0."), - return ge::GRAPH_FAILED); - - return ge::GRAPH_SUCCESS; -} - -ge::graphStatus LIQInfoParser::GetOpParaInfo() -{ - GetInputParaInfo(); - GetOutputParaInfo(); - if (ge::GRAPH_SUCCESS != GetAttrParaInfo()) { - return ge::GRAPH_FAILED; - } - if (ge::GRAPH_SUCCESS != CheckAttrParaInfo()) { - return ge::GRAPH_FAILED; - } - return ge::GRAPH_SUCCESS; -} - -ge::graphStatus LIQInfoParser::GetAndCheckInOutDataType() -{ - inputQType_ = opParamInfo_.query.desc->GetDataType(); - inputKType_ = opParamInfo_.key.desc->GetDataType(); - weightsType_ = opParamInfo_.weights.desc->GetDataType(); - inputQueryScaleType_ = opParamInfo_.query_dequant_scale.desc->GetDataType(); - inputKeyScaleType_ = opParamInfo_.key_dequant_scale.desc->GetDataType(); - outputType_ = opParamInfo_.attenOut.desc->GetDataType(); - - OPS_ERR_IF(!(inputQType_ == inputKType_), - OPS_LOG_E(opName_, "The data types of the input query and key must be the same, but now is %s, %s respectively.", - inputQType_, inputKType_), - return ge::GRAPH_FAILED); - - OPS_ERR_IF( - !(inputQueryScaleType_ == inputKeyScaleType_), - OPS_LOG_E(opName_, "The data types of the input query_dequant_scale and key_dequant_scale must be the same."), - return ge::GRAPH_FAILED); - - OPS_ERR_IF(inputQType_ != ge::DT_INT8, - OPS_LOG_E(opName_, "The data types of the input query and key must be int8."), return ge::GRAPH_FAILED); - - OPS_ERR_IF(weightsType_ != ge::DT_FLOAT16, - OPS_LOG_E(opName_, "The data types of the input weights must be float16."), return ge::GRAPH_FAILED); - - OPS_ERR_IF( - inputQueryScaleType_ != ge::DT_FLOAT16, - OPS_LOG_E(opName_, "The data types of the input query_dequant_scale and key_dequant_scale must be float16."), - return ge::GRAPH_FAILED); - - OPS_ERR_IF(outputType_ != ge::DT_INT32, - OPS_LOG_E(opName_, "The data types of the output sparse_indices must be int32."), - return ge::GRAPH_FAILED); - - return ge::GRAPH_SUCCESS; -} - -ge::graphStatus LIQInfoParser::GetQueryKeyAndOutLayout() -{ - // 获取query,key的Layout基准值 - const map layoutQueryMap = {{"BSND", DataLayout::BSND}, {"TND", DataLayout::TND}}; - - std::string layout_query(opParamInfo_.layOutQuery); - auto QLayout_ = layoutQueryMap.find(layout_query); - if (QLayout_ != layoutQueryMap.end()) { - qLayout_ = QLayout_->second; - } - - const map layoutKeyMap = { - {"BSND", DataLayout::BSND}, {"TND", DataLayout::TND}, - {"PA_BSND", DataLayout::PA_BSND}, {"PA_BBND", DataLayout::PA_BSND}}; - std::string layout_key(opParamInfo_.layOutKey); - auto KLayout = layoutKeyMap.find(layout_key); - if (KLayout != layoutKeyMap.end()) { - kLayout_ = KLayout->second; - } - - return ge::GRAPH_SUCCESS; -} - -ge::graphStatus LIQInfoParser::GetAndCheckOptionalInput() -{ - if (kLayout_ == DataLayout::PA_BSND) { - OPS_ERR_IF(opParamInfo_.blockTable.tensor == nullptr, - OPS_LOG_E(opName_, "key layout only supported PA_BSND, input block_table must not be null"), - return ge::GRAPH_FAILED); - OPS_ERR_IF( - opParamInfo_.actualSeqLengthsK.tensor == nullptr, - OPS_LOG_E(opName_, "key layout only supported PA_BSND, input actual_seq_lengths_key must not be null"), - return ge::GRAPH_FAILED); - OPS_ERR_IF(opParamInfo_.blockTable.desc->GetDataType() != ge::DT_INT32, - OPS_LOG_E(opName_, "input block_table data type only support int32"), return ge::GRAPH_FAILED); - } else { - OPS_ERR_IF(opParamInfo_.blockTable.tensor != nullptr, - OPS_LOG_E(opName_, "key layout is not PA_BSND, input block_table must be null"), - return ge::GRAPH_FAILED); - } - - if (kLayout_ == DataLayout::TND) { - OPS_ERR_IF(opParamInfo_.actualSeqLengthsK.tensor == nullptr, - OPS_LOG_E(opName_, "when layout_key is TND, input actual_seq_lengths_key must not be null"), - return ge::GRAPH_FAILED); - } - OPS_ERR_IF(opParamInfo_.actualSeqLengthsK.tensor != nullptr && - opParamInfo_.actualSeqLengthsK.desc->GetDataType() != ge::DT_INT32, - OPS_LOG_E(opName_, "input actual_seq_lengths_key data type only support int32"), - return ge::GRAPH_FAILED); - if (qLayout_ == DataLayout::TND) { - OPS_ERR_IF(opParamInfo_.actualSeqLengthsQ.tensor == nullptr, - OPS_LOG_E(opName_, "when layout_query is TND, input actual_seq_lengths_query must not be null"), - return ge::GRAPH_FAILED); - } - OPS_ERR_IF(opParamInfo_.actualSeqLengthsQ.tensor != nullptr && - opParamInfo_.actualSeqLengthsQ.desc->GetDataType() != ge::DT_INT32, - OPS_LOG_E(opName_, "input actual_seq_lengths_query data type only support int32"), - return ge::GRAPH_FAILED); - - return ge::GRAPH_SUCCESS; -} - -ge::graphStatus LIQInfoParser::CheckShapeDim() -{ - OPS_ERR_IF((opParamInfo_.blockTable.tensor != nullptr) && - (opParamInfo_.blockTable.tensor->GetStorageShape().GetDimNum() != DIM_NUM_TWO), - OPS_LOG_E(opName_, "the dim num of block_table's shape should be 2, but now is %u", - opParamInfo_.blockTable.tensor->GetStorageShape().GetDimNum()), return ge::GRAPH_FAILED); - OPS_ERR_IF( - (kLayout_ == DataLayout::PA_BSND) && (opParamInfo_.key.shape->GetStorageShape().GetDimNum() != DIM_NUM_FOUR), - OPS_LOG_E(opName_, "the dim num of key's shape should be 4, but now is %u", - opParamInfo_.key.shape->GetStorageShape().GetDimNum()), return ge::GRAPH_FAILED); - - uint32_t qShapeDim = opParamInfo_.query.shape->GetStorageShape().GetDimNum(); - uint32_t weightsShapeDim = opParamInfo_.weights.shape->GetStorageShape().GetDimNum(); - uint32_t outShapeDim = opParamInfo_.attenOut.shape->GetStorageShape().GetDimNum(); - uint32_t expectShapeDim = DIM_NUM_FOUR; - if (qLayout_ == DataLayout::TND) { - expectShapeDim = DIM_NUM_THREE; - } - OPS_ERR_IF( - qShapeDim != expectShapeDim, - OPS_LOG_E(opName_, "the dim num of query's shape should be %u, but now is %u", expectShapeDim, qShapeDim), - return ge::GRAPH_FAILED); - OPS_ERR_IF(outShapeDim != expectShapeDim, - OPS_LOG_E(opName_, "the dim num of sparse_indices's shape should be %u, but now is %u", expectShapeDim, - outShapeDim), - return ge::GRAPH_FAILED); - OPS_ERR_IF(!(weightsShapeDim == expectShapeDim - 1), - OPS_LOG_E(opName_, "the dim num of weights's shape should be %u, but now is %u", expectShapeDim - 1, - weightsShapeDim), - return ge::GRAPH_FAILED); - - return ge::GRAPH_SUCCESS; -} - -ge::graphStatus LIQInfoParser::GetN1Size() -{ - if (qLayout_ == DataLayout::BSND) { - n1Size_ = static_cast(opParamInfo_.query.shape->GetStorageShape().GetDim(DIM_IDX_TWO)); - } else { - // TND - n1Size_ = static_cast(opParamInfo_.query.shape->GetStorageShape().GetDim(DIM_IDX_ONE)); - } - OPS_LOG_I(context_->GetNodeName(), "n1Size is %d", n1Size_); - return ge::GRAPH_SUCCESS; -} - -ge::graphStatus LIQInfoParser::GetActualSeqLenSize(uint32_t &size, const gert::Tensor *tensor, - const std::string &actualSeqLenName) -{ - size = static_cast(tensor->GetShapeSize()); - if (size <= 0) { - OPS_LOG_E(opName_, "%s's shape size is %u, it should be greater than 0.", actualSeqLenName.c_str(), size); - return ge::GRAPH_FAILED; - } - return ge::GRAPH_SUCCESS; -} - -ge::graphStatus LIQInfoParser::GetAndCheckN2Size() -{ - // PA_BSND - if (kLayout_ == DataLayout::TND) { - n2Size_ = static_cast(opParamInfo_.key.shape->GetStorageShape().GetDim(DIM_IDX_ONE)); - } else { - n2Size_ = static_cast(opParamInfo_.key.shape->GetStorageShape().GetDim(DIM_IDX_TWO)); - } - OPS_LOG_I(context_->GetNodeName(), "N2 is %d", n2Size_); - OPS_ERR_IF(n2Size_ != 1, OPS_LOG_E(opName_, "key shape[2] is numhead, only support 1."), return ge::GRAPH_FAILED); - - return ge::GRAPH_SUCCESS; -} - -ge::graphStatus LIQInfoParser::GetGSize() -{ - if (n1Size_ % n2Size_ != 0) { - OPS_LOG_E(opName_, "input query's head_num %u can not be a multiple of key's head_num %u.", n1Size_, n2Size_); - return ge::GRAPH_FAILED; - } - gSize_ = n1Size_ / n2Size_; - - return ge::GRAPH_SUCCESS; -} - -ge::graphStatus LIQInfoParser::GetBatchSize() -{ - // 获取B基准值 - // 1、非TND/NTD时, 以query的batch_size维度为基准; - // 2、TND/NTD时, actual_seq_lens_q必须传入, 以actual_seq_lens_q数组的长度为B轴大小 - if (qLayout_ == DataLayout::TND) { - return GetActualSeqLenSize(bSize_, opParamInfo_.actualSeqLengthsQ.tensor, "input actual_seq_lengths_query"); - } else { // BSND - bSize_ = opParamInfo_.query.shape->GetStorageShape().GetDim(DIM_IDX_ZERO); - OPS_LOG_I(context_->GetNodeName(), "b: %d, s: %d, n: %d,d :%d", - opParamInfo_.query.shape->GetStorageShape().GetDim(DIM_IDX_ZERO), - opParamInfo_.query.shape->GetStorageShape().GetDim(DIM_IDX_ONE), - opParamInfo_.query.shape->GetStorageShape().GetDim(DIM_IDX_TWO), - opParamInfo_.query.shape->GetStorageShape().GetDim(DIM_IDX_THREE)); - return ge::GRAPH_SUCCESS; - } -} - -ge::graphStatus LIQInfoParser::GetHeadDim() -{ - // 以query的D维度为基准 - uint32_t dIndex = DIM_IDX_TWO; - // 根据layout确定D维度在shape中的位置 - switch (qLayout_) { - case DataLayout::TND: - // TND格式: [Total, N, D] -> D是第2维(索引2) - dIndex = DIM_IDX_TWO; - break; - case DataLayout::BSND: - // BSND格式: [Batch, SeqLen, N, D] -> D是第3维(索引3) - dIndex = DIM_IDX_THREE; - break; - default: - OPS_LOG_E(opName_, "unsupported layout for getting head dim."); - return ge::GRAPH_FAILED; - } - headDim_ = opParamInfo_.query.shape->GetStorageShape().GetDim(dIndex); - OPS_ERR_IF(headDim_ != HEAD_DIM_LIMIT, OPS_LOG_E(opName_, "input query's last dim head_dim only support 128, but now is %u.", headDim_), - return ge::GRAPH_FAILED); - - return ge::GRAPH_SUCCESS; -} - -ge::graphStatus LIQInfoParser::GetS1Size() -{ - if (qLayout_ == DataLayout::BSND) { - s1Size_ = opParamInfo_.query.shape->GetStorageShape().GetDim(1); - } - return ge::GRAPH_SUCCESS; -} - -ge::graphStatus LIQInfoParser::GetAndCheckBlockSize() -{ - blockSize_ = static_cast(opParamInfo_.key.shape->GetStorageShape().GetDim(1)); - OPS_LOG_I(context_->GetNodeName(), "blockSize_ is %d", blockSize_); - - OPS_ERR_IF( - ((blockSize_ % BLOCK_SIZE_FACTOR != 0) || (blockSize_ == 0) || (blockSize_ > BLOCK_SIZE_LIMIT)), - OPS_LOG_E(opName_, "input key's block_size must be a multiple of 16 and belong to (0, 1024], but now is %u.", - blockSize_), - return ge::GRAPH_FAILED); - - return ge::GRAPH_SUCCESS; -} - -ge::graphStatus LIQInfoParser::GetS2SizeForPageAttention() -{ - if (GetAndCheckBlockSize() != ge::GRAPH_SUCCESS) { - return ge::GRAPH_FAILED; - } - - int32_t blockCount_ = static_cast(opParamInfo_.key.shape->GetStorageShape().GetDim(0)); - OPS_ERR_IF((blockCount_ == 0), OPS_LOG_E(opName_, "input key's block_count cannot be 0."), return ge::GRAPH_FAILED); - - maxBlockNumPerBatch_ = opParamInfo_.blockTable.tensor->GetStorageShape().GetDim(1); - s2Size_ = maxBlockNumPerBatch_ * blockSize_; - OPS_LOG_I(context_->GetNodeName(), "maxBlockNumPerBatch_ is %d, blockSize_ is %d, s2Size_ is %d", - maxBlockNumPerBatch_, blockSize_, s2Size_); - return ge::GRAPH_SUCCESS; -} - -ge::graphStatus LIQInfoParser::GetS2SizeForBatchContinuous() -{ - std::string layout_key(opParamInfo_.layOutKey); - if (kLayout_ == DataLayout::BSND) { - s2Size_ = opParamInfo_.key.shape->GetStorageShape().GetDim(DIM_IDX_ONE); - } else if (kLayout_ == DataLayout::TND) { - s2Size_ = opParamInfo_.key.shape->GetStorageShape().GetDim(DIM_IDX_ZERO); - } - OPS_ERR_IF((kLayout_ != DataLayout::BSND) && (kLayout_ != DataLayout::TND), - OPS_LOG_E(opName_, "the layout of key is %s, it is unsupported.", layout_key.c_str()), - return ge::GRAPH_FAILED); - return ge::GRAPH_SUCCESS; -} - -ge::graphStatus LIQInfoParser::GetS2Size() -{ - // 获取S2基准值 - // 1、BATCH_CONTINUOUS时, 从key的S轴获取 - // 3、PAGE_ATTENTION时, S2 = block_table.dim1 * block_size - if (kLayout_ == DataLayout::PA_BSND) { - return GetS2SizeForPageAttention(); - } - return GetS2SizeForBatchContinuous(); -} - -ge::graphStatus LIQInfoParser::ValidateInputShapesMatch() -{ - /* - TND: - query [T,N1,D], - key [BlockNum,BlockSize,N2,D], - weight [T,N1], - block_table [BatchSize, BatchMaxBlockNum], - act_seq_k [BatchSize] - act_seq_q [BatchSize], - out [T,N2,topk] - ---------------------- - BSND: - query [BatchSize,S1,N1,D], - key [BlockNum,BlockSize,N2,D], - weight [BatchSize,S1,N1], - block_table [BatchSize, BatchMaxBlockNum], - act_seq_k [BatchSize] - act_seq_q [BatchSize] 可选 - out [BatchSize,S1,N2,topk] - */ - uint32_t queryWeightsN1Dim = 1; - uint32_t outN2Dim = 1; - - if (qLayout_ == DataLayout::TND) { - // -----------------------check BatchSize------------------- - // bSize_ 来源于act_seq_q - OPS_ERR_IF((kLayout_ == DataLayout::PA_BSND) && - ((opParamInfo_.actualSeqLengthsK.tensor->GetShapeSize() != bSize_) || - (opParamInfo_.blockTable.tensor != nullptr && - opParamInfo_.blockTable.tensor->GetStorageShape().GetDim(0) != bSize_)), - OPS_LOG_E( - opName_, - "TND case input actual_seq_lengths_query, actual_seq_lengths_key, block_table dim 0 are %u, %u, %u respectively, they must be same.", - bSize_, opParamInfo_.actualSeqLengthsK.tensor->GetShapeSize(), - opParamInfo_.blockTable.tensor->GetStorageShape().GetDim(0)), - return ge::GRAPH_FAILED); - OPS_ERR_IF((kLayout_ != DataLayout::PA_BSND) && - (opParamInfo_.actualSeqLengthsK.tensor->GetShapeSize() != bSize_), - OPS_LOG_E( - opName_, - "TND case input actual_seq_lengths_query, actual_seq_lengths_key, are %u, %u respectively, they must be same.", - bSize_, opParamInfo_.actualSeqLengthsK.tensor->GetShapeSize()), - return ge::GRAPH_FAILED); - // -----------------------check T------------------- - uint32_t qTsize = opParamInfo_.query.shape->GetStorageShape().GetDim(0); - OPS_ERR_IF((opParamInfo_.weights.shape->GetStorageShape().GetDim(0) != qTsize) || - (opParamInfo_.attenOut.shape->GetStorageShape().GetDim(0) != qTsize), - OPS_LOG_E(opName_, - "TND case input query, weights, sparse_indices dim 0 are %u, %u, %u respectively, they must be same.", - qTsize, opParamInfo_.weights.shape->GetStorageShape().GetDim(0), - opParamInfo_.attenOut.shape->GetStorageShape().GetDim(0)), - return ge::GRAPH_FAILED); - } else { - // -----------------------check BatchSize------------------- - // bSize_ 来源于query - OPS_ERR_IF((kLayout_ == DataLayout::PA_BSND) && - ((opParamInfo_.weights.shape->GetStorageShape().GetDim(0) != bSize_) || - (opParamInfo_.blockTable.tensor != nullptr && - opParamInfo_.blockTable.tensor->GetStorageShape().GetDim(0) != bSize_) || - (opParamInfo_.actualSeqLengthsK.tensor->GetShapeSize() != bSize_) || - (opParamInfo_.attenOut.shape->GetStorageShape().GetDim(0) != bSize_)), - OPS_LOG_E(opName_, - "BSND case input query, weight, actual_seq_lengths_key, block_table, sparse_indices dim 0 are %u, %u, %u, %u, %u respectively, they must be same.", - bSize_, opParamInfo_.weights.shape->GetStorageShape().GetDim(0), - opParamInfo_.actualSeqLengthsK.tensor->GetShapeSize(), - opParamInfo_.blockTable.tensor->GetStorageShape().GetDim(0), - opParamInfo_.attenOut.shape->GetStorageShape().GetDim(0)), - return ge::GRAPH_FAILED); - OPS_ERR_IF((kLayout_ != DataLayout::PA_BSND) && - ((opParamInfo_.weights.shape->GetStorageShape().GetDim(0) != bSize_) || - (opParamInfo_.actualSeqLengthsK.tensor != nullptr && - opParamInfo_.actualSeqLengthsK.tensor->GetShapeSize() != bSize_) || - (opParamInfo_.attenOut.shape->GetStorageShape().GetDim(0) != bSize_)), - OPS_LOG_E(opName_, - "BSND case input query, weight, actual_seq_lengths_key, sparse_indices dim 0 are %u, %u, %u, %u respectively, they must be same.", - bSize_, opParamInfo_.weights.shape->GetStorageShape().GetDim(0), - opParamInfo_.actualSeqLengthsK.tensor->GetShapeSize(), - opParamInfo_.attenOut.shape->GetStorageShape().GetDim(0)), - return ge::GRAPH_FAILED); - OPS_ERR_IF( - (opParamInfo_.actualSeqLengthsQ.tensor != nullptr) && - (opParamInfo_.actualSeqLengthsQ.tensor->GetShapeSize() != bSize_), - OPS_LOG_E( - opName_, - "BSND case input query, actual_seq_lengths_query dim 0 are %u, %u respectively, they must be same", - bSize_, opParamInfo_.actualSeqLengthsQ.tensor->GetShapeSize()), - return ge::GRAPH_FAILED); - // -----------------------check S1------------------- - OPS_ERR_IF( - (opParamInfo_.weights.shape->GetStorageShape().GetDim(1) != s1Size_) || - (opParamInfo_.attenOut.shape->GetStorageShape().GetDim(1) != s1Size_), - OPS_LOG_E(opName_, "BSND case input query, weight, sparse_indices dim 1 are %u, %u, %u, they must be same.", - s1Size_, opParamInfo_.weights.shape->GetStorageShape().GetDim(1), - opParamInfo_.attenOut.shape->GetStorageShape().GetDim(1)), - return ge::GRAPH_FAILED); - queryWeightsN1Dim = DIM_IDX_TWO; - outN2Dim = DIM_IDX_TWO; - } - // -----------------------check N1------------------- - OPS_ERR_IF((opParamInfo_.weights.shape->GetStorageShape().GetDim(queryWeightsN1Dim) != n1Size_), - OPS_LOG_E(opName_, "input query, weight shape dim N1 must be same, but now are %u, %u respectively.", - opParamInfo_.weights.shape->GetStorageShape().GetDim(queryWeightsN1Dim), n1Size_), - return ge::GRAPH_FAILED); - // -----------------------check D------------------- - OPS_ERR_IF( - ((kLayout_ != DataLayout::TND && opParamInfo_.key.shape->GetStorageShape().GetDim(DIM_IDX_THREE) != headDim_) - || (kLayout_ == DataLayout::TND && opParamInfo_.key.shape->GetStorageShape().GetDim(DIM_IDX_TWO) != headDim_)), - OPS_LOG_E(opName_, "input query, key shape last dim must be same, now are %u, %u respectively.", - headDim_, opParamInfo_.key.shape->GetStorageShape().GetDim(DIM_IDX_THREE)), - return ge::GRAPH_FAILED); - // -----------------------check N2------------------- - OPS_ERR_IF((opParamInfo_.attenOut.shape->GetStorageShape().GetDim(outN2Dim) != n2Size_), - OPS_LOG_E(opName_, "input query and output sparse_indices shape n2 dim must be same."), - return ge::GRAPH_FAILED); - // -----------------------check sparse_count------------------- - OPS_ERR_IF((opParamInfo_.attenOut.shape->GetStorageShape().GetDim(outN2Dim + 1) != *opParamInfo_.sparseCount), - OPS_LOG_E(opName_, "output sparse_indices shape last dim must be same as attr sparse_count."), - return ge::GRAPH_FAILED); - - return ge::GRAPH_SUCCESS; -} - -ge::graphStatus LIQInfoParser::CheckScaleShape() -{ - uint32_t qShapeDim = opParamInfo_.query.shape->GetStorageShape().GetDimNum(); - uint32_t kShapeDim = opParamInfo_.key.shape->GetStorageShape().GetDimNum(); - uint32_t qDequantScaleShapeDim = opParamInfo_.query_dequant_scale.shape->GetStorageShape().GetDimNum(); - uint32_t kDequantScaleShapeDim = opParamInfo_.key_dequant_scale.shape->GetStorageShape().GetDimNum(); - OPS_ERR_IF(qDequantScaleShapeDim != (qShapeDim - 1), - OPS_LOG_E(opName_, "the dim num of query_dequant_scale's shape should be %u, but now is %u", - qShapeDim - 1, qDequantScaleShapeDim), - return ge::GRAPH_FAILED); - OPS_ERR_IF(kDequantScaleShapeDim != (kShapeDim - 1), - OPS_LOG_E(opName_, "the dim num of key_dequant_scale's shape should be %u, but now is %u", kShapeDim - 1, - kDequantScaleShapeDim), - return ge::GRAPH_FAILED); - // check q scale - for (uint32_t i = 0; i < (qShapeDim - 1); i++) { - uint32_t dimValueQueryScale = opParamInfo_.query_dequant_scale.shape->GetStorageShape().GetDim(i); - uint32_t dimValueQuery = opParamInfo_.query.shape->GetStorageShape().GetDim(i); - OPS_ERR_IF(dimValueQueryScale != dimValueQuery, - OPS_LOG_E(opName_, "query_dequant_scale's shape[%u] %u and query's shape[%u] %u is not same", i, - dimValueQueryScale, i, dimValueQuery), - return ge::GRAPH_FAILED); - } - // check k scale - for (uint32_t i = 0; i < (kShapeDim - 1); i++) { - uint32_t dimValueKeyScale = opParamInfo_.key_dequant_scale.shape->GetStorageShape().GetDim(i); - uint32_t dimValueKey = opParamInfo_.key.shape->GetStorageShape().GetDim(i); - OPS_ERR_IF(dimValueKeyScale != dimValueKey, - OPS_LOG_E(opName_, "key_dequant_scale's shape[%u] %u and key's shape[%u] %u is not same", i, - dimValueKeyScale, i, dimValueKey), - return ge::GRAPH_FAILED); - } - - return ge::GRAPH_SUCCESS; -} - -void LIQInfoParser::GenerateInfo(LIQTilingInfo &liqInfo) -{ - liqInfo.opName = opName_; - liqInfo.platformInfo = platformInfo_; - liqInfo.opParamInfo = opParamInfo_; - liqInfo.socVersion = socVersion_; - - liqInfo.bSize = bSize_; - liqInfo.n1Size = n1Size_; - liqInfo.n2Size = n2Size_; - liqInfo.s1Size = s1Size_; - liqInfo.s2Size = s2Size_; - liqInfo.gSize = gSize_; - - liqInfo.inputQType = inputQType_; - liqInfo.inputKType = inputKType_; - liqInfo.outputType = outputType_; - - liqInfo.blockSize = blockSize_; - liqInfo.maxBlockNumPerBatch = maxBlockNumPerBatch_; - - liqInfo.pageAttentionFlag = (kLayout_ == DataLayout::PA_BSND); - liqInfo.sparseMode = *opParamInfo_.sparseMode; - liqInfo.sparseCount = *opParamInfo_.sparseCount; - - liqInfo.inputQLayout = qLayout_; - liqInfo.inputKLayout = kLayout_; -} - -ge::graphStatus LIQInfoParser::ParseAndCheck(LIQTilingInfo &liqInfo) -{ - if (ge::GRAPH_SUCCESS != GetOpName() || ge::GRAPH_SUCCESS != GetNpuInfo() || ge::GRAPH_SUCCESS != GetOpParaInfo() || - ge::GRAPH_SUCCESS != CheckRequiredParaExistence()) { - return ge::GRAPH_FAILED; - } - - if (ge::GRAPH_SUCCESS != GetAndCheckInOutDataType() || ge::GRAPH_SUCCESS != GetQueryKeyAndOutLayout() || - ge::GRAPH_SUCCESS != GetAndCheckOptionalInput()) { - return ge::GRAPH_FAILED; - } - - if (ge::GRAPH_SUCCESS != CheckShapeDim() || ge::GRAPH_SUCCESS != GetN1Size() || - ge::GRAPH_SUCCESS != GetAndCheckN2Size() || ge::GRAPH_SUCCESS != GetGSize()) { - return ge::GRAPH_FAILED; - } - - if (ge::GRAPH_SUCCESS != GetBatchSize() || ge::GRAPH_SUCCESS != GetS1Size() || ge::GRAPH_SUCCESS != GetHeadDim() || - ge::GRAPH_SUCCESS != GetS2Size()) { - return ge::GRAPH_FAILED; - } - if (ge::GRAPH_SUCCESS != ValidateInputShapesMatch() || ge::GRAPH_SUCCESS != CheckScaleShape()) { - return ge::GRAPH_FAILED; - } - - GenerateInfo(liqInfo); - - return ge::GRAPH_SUCCESS; -} - -// --------------------------TilingPrepare函数定义------------------------------------- -static ge::graphStatus TilingPrepareForLightningIndexerQuant(gert::TilingParseContext * /* context */) -{ - return ge::GRAPH_SUCCESS; -} - -// --------------------------LightningIndexerQuantTiling类成员函数定义----------------------- -ge::graphStatus LightningIndexerQuantTiling::DoTiling(LIQTilingInfo *tilingInfo) -{ - // -------------set blockdim----------------- - auto ascendcPlatform = platform_ascendc::PlatformAscendC(tilingInfo->platformInfo); - uint32_t aivNum = ascendcPlatform.GetCoreNumAiv(); - uint32_t aicNum = ascendcPlatform.GetCoreNumAic(); - uint32_t blockDim = ascendcPlatform.CalcTschBlockDim(aivNum, aicNum, aivNum); - context_->SetBlockDim(blockDim); - - // -------------set workspacesize----------------- - constexpr uint32_t MM1_RES_ELEM_SIZE = 4; // 4: fp32 - constexpr uint32_t DOUBLE_BUFFER = 2; // 双Buffer - constexpr uint32_t M_BASE_SIZE = 512; // m轴基本块大小 - constexpr uint32_t S2_BASE_SIZE = 512; // S2轴基本块大小 - constexpr uint32_t V1_RES_ELEM_SIZE = 4; // 4: int32 - constexpr uint32_t V1_RES_ELEM_TYPE = 2; // 保留Index和Value 2种数据 - constexpr uint32_t V1_DECODE_PARAM_ELEM_SIZE = 8; // 8: int64 - constexpr uint32_t V1_DECODE_PARAM_NUM = 16; // Decode参数个数 - constexpr uint32_t V1_DECODE_DATA_NUM = 2; // Decode每个核需要存储头和尾部两块数据 - constexpr uint32_t S1_BASE_SIZE = 8; // S1轴基本块的大小 - constexpr uint32_t TOPK_MAX_SIZE = 2048; // TopK选取个数 - uint32_t workspaceSize = ascendcPlatform.GetLibApiWorkSpaceSize(); - // 主流程需Workspace大小 - uint32_t mm1ResSize = M_BASE_SIZE * S2_BASE_SIZE; - workspaceSize += mm1ResSize * MM1_RES_ELEM_SIZE * DOUBLE_BUFFER * aicNum; - // Decode流程(LD)需要Workspace大小 - // 临时存储Decode中间结果大小: 2(头/尾)*8(s1Base)*2(idx/value)*2048(K)*sizeof(int32)*24=6M - workspaceSize += V1_DECODE_DATA_NUM * S1_BASE_SIZE * V1_RES_ELEM_TYPE * TOPK_MAX_SIZE * V1_RES_ELEM_SIZE * aicNum; - // 临时存储Decode中间参数信息大小: 2(头/尾)*8(s1Base)*16(paramNum)*sizeof(int64_t)*24=48k - workspaceSize += V1_DECODE_DATA_NUM * S1_BASE_SIZE * V1_DECODE_PARAM_NUM * V1_DECODE_PARAM_ELEM_SIZE * aicNum; - size_t *workSpaces = context_->GetWorkspaceSizes(1); - workSpaces[0] = workspaceSize; - - // -------------set tilingdata----------------- - tilingData_.set_bSize(tilingInfo->bSize); - tilingData_.set_s2Size(tilingInfo->s2Size); - tilingData_.set_s1Size(tilingInfo->s1Size); - tilingData_.set_sparseCount(tilingInfo->sparseCount); - tilingData_.set_gSize(tilingInfo->gSize); - tilingData_.set_blockSize(tilingInfo->blockSize); - tilingData_.set_maxBlockNumPerBatch(tilingInfo->maxBlockNumPerBatch); - tilingData_.set_sparseMode(tilingInfo->sparseMode); - tilingData_.set_usedCoreNum(blockDim); - tilingData_.SaveToBuffer(context_->GetRawTilingData()->GetData(), context_->GetRawTilingData()->GetCapacity()); - context_->GetRawTilingData()->SetDataSize(tilingData_.GetDataSize()); - - // -------------set tilingkey----------------- - // DT_Q, DT_KV, DT_OUT, PAGE_ATTENTION, FLASH_DECODE, LAYOUT_T, KV_LAYOUT_T - uint32_t inputQType = static_cast(tilingInfo->inputQType); - uint32_t inputKType = static_cast(tilingInfo->inputKType); - uint32_t outputType = static_cast(tilingInfo->outputType); - uint32_t pageAttentionFlag = static_cast(tilingInfo->pageAttentionFlag); - uint32_t inputQLayout = static_cast(tilingInfo->inputQLayout); - uint32_t inputKLayout = static_cast(tilingInfo->inputKLayout); - uint32_t tilingKey = - GET_TPL_TILING_KEY(inputQType, inputKType, outputType, pageAttentionFlag, inputQLayout, inputKLayout); - context_->SetTilingKey(tilingKey); - - return ge::GRAPH_SUCCESS; -} - -// --------------------------Tiling函数定义--------------------------- -ge::graphStatus TilingForLightningIndexerQuant(gert::TilingContext *context) -{ - OPS_ERR_IF(context == nullptr, OPS_REPORT_VECTOR_INNER_ERR("LightningIndexerQuant", "Tiling context is null."), - return ge::GRAPH_FAILED); - LIQTilingInfo liqInfo; - LIQInfoParser LIQInfoParser(context); - if (LIQInfoParser.ParseAndCheck(liqInfo) != ge::GRAPH_SUCCESS) { - return ge::GRAPH_FAILED; - } - LightningIndexerQuantTiling liqTiling(context); - return liqTiling.DoTiling(&liqInfo); -} - -// --------------------------Tiling及函数TilingPrepare函数注册-------- -IMPL_OP_OPTILING(LightningIndexerQuant) - .Tiling(TilingForLightningIndexerQuant) - .TilingParse(TilingPrepareForLightningIndexerQuant); - -} // namespace optiling diff --git a/csrc/lightning_indexer_quant/op_host/lightning_indexer_quant_tiling.h b/csrc/lightning_indexer_quant/op_host/lightning_indexer_quant_tiling.h deleted file mode 100644 index d51a51e7..00000000 --- a/csrc/lightning_indexer_quant/op_host/lightning_indexer_quant_tiling.h +++ /dev/null @@ -1,234 +0,0 @@ -/** - * This program is free software, you can redistribute it and/or modify it. - * Copyright (c) 2025 Huawei Technologies Co., Ltd. - * This file is a part of the CANN Open Software. - * Licensed under CANN Open Software License Agreement Version 2.0 (the "License"). - * Please refer to the License for details. You may not use this file except in compliance with the License. - * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. - * See LICENSE in the root of the software repository for the full text of the License. - */ - -/*! - * \file lightning_indexer_quant_tiling.h - * \brief - */ - -#ifndef LIGHTNING_INDEXER_QUANT_TILING_H_ -#define LIGHTNING_INDEXER_QUANT_TILING_H_ - -#include "error/ops_error.h" -#include "exe_graph/runtime/tiling_context.h" -#include "platform/platform_info.h" -#include "register/op_def_registry.h" -#include "register/tilingdata_base.h" -#include "tiling/platform/platform_ascendc.h" -#include "tiling/tiling_api.h" - -namespace optiling { -// ------------------公共定义-------------------------- -struct TilingRequiredParaInfo { - const gert::CompileTimeTensorDesc *desc; - const gert::StorageShape *shape; -}; - -struct TilingOptionalParaInfo { - const gert::CompileTimeTensorDesc *desc; - const gert::Tensor *tensor; -}; - -enum class DataLayout : uint32_t { - BSND = 0, - TND = 1, - PA_BSND = 2 -}; - -// ------------------算子原型索引常量定义---------------- -// Inputs Index -constexpr uint32_t QUERY_INDEX = 0; -constexpr uint32_t KEY_INDEX = 1; -constexpr uint32_t WEIGTHS_INDEX = 2; -constexpr uint32_t QUERY_DEQUANT_SCALE_INDEX = 3; -constexpr uint32_t KEY_DEQUANT_SCALE_INDEX = 4; -constexpr uint32_t ACTUAL_SEQ_Q_INDEX = 5; -constexpr uint32_t ACTUAL_SEQ_K_INDEX = 6; -constexpr uint32_t BLOCK_TABLE_INDEX = 7; -constexpr uint32_t LIGHTNING_INDEXER_QUANT = 0; -// Attributes Index -constexpr uint32_t ATTR_QUERY_QUANT_MODE_INDEX = 0; -constexpr uint32_t ATTR_KEY_QUANT_MODE_INDEX = 1; -constexpr uint32_t ATTR_QUERY_LAYOUT_INDEX = 2; -constexpr uint32_t ATTR_KEY_LAYOUT_INDEX = 3; -constexpr uint32_t ATTR_SPARSE_COUNT_INDEX = 4; -constexpr uint32_t ATTR_SPARSE_MODE_INDEX = 5; -// Dim Index -constexpr uint32_t DIM_IDX_ZERO = 0; -constexpr uint32_t DIM_IDX_ONE = 1; -constexpr uint32_t DIM_IDX_TWO = 2; -constexpr uint32_t DIM_IDX_THREE = 3; -// Dim Num -constexpr uint32_t DIM_NUM_TWO = 2; -constexpr uint32_t DIM_NUM_THREE = 3; -constexpr uint32_t DIM_NUM_FOUR = 4; -// 入参限制常量 -constexpr uint32_t HEAD_DIM_LIMIT = 128; -constexpr uint32_t SPARSE_LIMIT = 2048; -constexpr uint32_t G_SIZE_LIMIT = 64; -constexpr uint32_t BLOCK_SIZE_LIMIT = 1024; -constexpr uint32_t BLOCK_SIZE_FACTOR = 16; -constexpr uint32_t SPARSE_MODE_LOWER = 3; - -// -----------算子TilingData定义--------------- -BEGIN_TILING_DATA_DEF(LIQTilingData) -TILING_DATA_FIELD_DEF(uint32_t, bSize) -TILING_DATA_FIELD_DEF(uint32_t, n2Size) -TILING_DATA_FIELD_DEF(uint32_t, gSize) -TILING_DATA_FIELD_DEF(uint32_t, s1Size) -TILING_DATA_FIELD_DEF(uint32_t, s2Size) -TILING_DATA_FIELD_DEF(uint32_t, sparseCount) -TILING_DATA_FIELD_DEF(uint32_t, usedCoreNum) -TILING_DATA_FIELD_DEF(uint32_t, blockSize) -TILING_DATA_FIELD_DEF(uint32_t, maxBlockNumPerBatch) -TILING_DATA_FIELD_DEF(uint32_t, sparseMode) -END_TILING_DATA_DEF -REGISTER_TILING_DATA_CLASS(LightningIndexerQuant, LIQTilingData) - -// -----------算子CompileInfo定义------------------- -struct LIQCompileInfo {}; - -// -----------算子Tiling入参结构体定义--------------- -struct LIQParaInfo { - TilingRequiredParaInfo query = {nullptr, nullptr}; - TilingRequiredParaInfo key = {nullptr, nullptr}; - TilingRequiredParaInfo weights = {nullptr, nullptr}; - TilingRequiredParaInfo query_dequant_scale = {nullptr, nullptr}; - TilingRequiredParaInfo key_dequant_scale = {nullptr, nullptr}; - TilingOptionalParaInfo actualSeqLengthsQ = {nullptr, nullptr}; - TilingOptionalParaInfo actualSeqLengthsK = {nullptr, nullptr}; - TilingOptionalParaInfo blockTable = {nullptr, nullptr}; - TilingRequiredParaInfo attenOut = {nullptr, nullptr}; - - const int32_t *queryQuantMode = nullptr; - const int32_t *keyQuantMode = nullptr; - const char *layOutQuery = nullptr; - const char *layOutKey = nullptr; - const int32_t *blockSize = nullptr; - const int32_t *sparseMode = nullptr; - const int32_t *sparseCount = nullptr; -}; - -// -----------算子Tiling入参信息类--------------- -class LIQTilingInfo { -public: - const char *opName = nullptr; - fe::PlatFormInfos *platformInfo = nullptr; - LIQParaInfo opParamInfo; - // Base Param - platform_ascendc::SocVersion socVersion = platform_ascendc::SocVersion::ASCEND910B; - uint32_t bSize = 0; - uint32_t n1Size = 0; - uint32_t n2Size = 0; - uint32_t s1Size = 0; - int64_t s2Size = 0; - uint32_t qkHeadDim = 0; - uint32_t gSize = 0; - // PageAttention - bool pageAttentionFlag = false; - int32_t blockSize = 0; - uint32_t maxBlockNumPerBatch = 0; - // Mask - int32_t sparseMode = 0; - // Others Flag - uint32_t sparseCount = 0; - // DType - ge::DataType inputQType = ge::DT_FLOAT16; - ge::DataType inputKType = ge::DT_FLOAT16; - ge::DataType outputType = ge::DT_INT32; - // Layout - DataLayout inputQLayout = DataLayout::BSND; - DataLayout inputKLayout = DataLayout::PA_BSND; -}; - -// -----------算子Tiling入参信息解析及Check类--------------- -class LIQInfoParser { -public: - explicit LIQInfoParser(gert::TilingContext *context) : context_(context) {} - ~LIQInfoParser() = default; - - ge::graphStatus CheckRequiredInOutExistence() const; - ge::graphStatus CheckRequiredAttrExistence() const; - ge::graphStatus CheckRequiredParaExistence() const; - ge::graphStatus GetActualSeqLenSize(uint32_t &size, const gert::Tensor *tensor, - const std::string &actualSeqLenName); - ge::graphStatus GetOpName(); - ge::graphStatus GetNpuInfo(); - void GetOptionalInputParaInfo(); - void GetInputParaInfo(); - void GetOutputParaInfo(); - ge::graphStatus GetAttrParaInfo(); - ge::graphStatus CheckAttrParaInfo(); - ge::graphStatus GetOpParaInfo(); - ge::graphStatus ValidateInputShapesMatch(); - ge::graphStatus CheckScaleShape(); - ge::graphStatus GetAndCheckInOutDataType(); - ge::graphStatus GetBatchSize(); - ge::graphStatus GetHeadDim(); - ge::graphStatus GetS1Size(); - ge::graphStatus GetAndCheckOptionalInput(); - ge::graphStatus CheckShapeDim(); - ge::graphStatus GetAndCheckBlockSize(); - ge::graphStatus GetS2SizeForPageAttention(); - ge::graphStatus GetS2SizeForBatchContinuous(); - ge::graphStatus GetS2Size(); - ge::graphStatus GetQueryKeyAndOutLayout(); - ge::graphStatus GetN1Size(); - ge::graphStatus GetAndCheckN2Size(); - ge::graphStatus GetGSize(); - ge::graphStatus GetAttenMaskInfo(); - ge::graphStatus GetActualSeqInfo(); - void GenerateInfo(LIQTilingInfo &liqInfo); - ge::graphStatus ParseAndCheck(LIQTilingInfo &liqInfo); - -public: - gert::TilingContext *context_ = nullptr; - const char *opName_; - fe::PlatFormInfos *platformInfo_; - LIQParaInfo opParamInfo_; - - // BaseParams - uint32_t bSize_ = 0; - uint32_t n1Size_ = 0; - uint32_t n2Size_ = 0; - uint32_t gSize_ = 0; - uint32_t s1Size_ = 0; - int64_t s2Size_ = 0; - uint32_t headDim_ = 0; - // Layout - DataLayout qLayout_ = DataLayout::BSND; - DataLayout kLayout_ = DataLayout::PA_BSND; - // PageAttention - uint32_t maxBlockNumPerBatch_ = 0; - int32_t blockSize_ = 0; - platform_ascendc::SocVersion socVersion_ = platform_ascendc::SocVersion::ASCEND910B; - ge::DataType inputQType_ = ge::DT_FLOAT16; - ge::DataType inputKType_ = ge::DT_FLOAT16; - ge::DataType weightsType_ = ge::DT_FLOAT16; - ge::DataType inputQueryScaleType_ = ge::DT_FLOAT16; - ge::DataType inputKeyScaleType_ = ge::DT_FLOAT16; - ge::DataType blockTableType_ = ge::DT_FLOAT16; - ge::DataType inputKRopeType_ = ge::DT_FLOAT16; - ge::DataType outputType_ = ge::DT_FLOAT16; -}; - -// ---------------算子Tiling类--------------- -class LightningIndexerQuantTiling { -public: - explicit LightningIndexerQuantTiling(gert::TilingContext *context) : context_(context) {}; - ge::graphStatus DoTiling(LIQTilingInfo *tilingInfo); - -private: - gert::TilingContext *context_ = nullptr; - LIQTilingData tilingData_; -}; - -} // namespace optiling -#endif // LIGHTNING_INDEXER_QUANT_TILING_H_ \ No newline at end of file diff --git a/csrc/lightning_indexer_quant/op_kernel/lightning_indexer_quant.cpp b/csrc/lightning_indexer_quant/op_kernel/lightning_indexer_quant.cpp deleted file mode 100644 index a9513daf..00000000 --- a/csrc/lightning_indexer_quant/op_kernel/lightning_indexer_quant.cpp +++ /dev/null @@ -1,50 +0,0 @@ -/** - * This program is free software, you can redistribute it and/or modify it. - * Copyright (c) 2025 Huawei Technologies Co., Ltd. - * This file is a part of the CANN Open Software. - * Licensed under CANN Open Software License Agreement Version 2.0 (the "License"). - * Please refer to the License for details. You may not use this file except in compliance with the License. - * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. - * See LICENSE in the root of the software repository for the full text of the License. - */ - -/*! - * \file lightning_indexer_quant.cpp - * \brief - */ - -#include "kernel_operator.h" -#include "lib/matmul_intf.h" -#include "lightning_indexer_quant_kernel.h" -#include "lightning_indexer_quant_template_tiling_key.h" - -using namespace LIQKernel; - -#define INVOKE_LI_NO_KFC_OP_IMPL(templateClass, ...) \ - do { \ - templateClass> op; \ - GET_TILING_DATA_WITH_STRUCT(LIQTilingData, tiling_data_in, tiling); \ - const LIQTilingData *__restrict tiling_data = &tiling_data_in; \ - op.Init(query, key, weights, queryScale, keyScale, actualSeqLengthsQ, actualSeqLengthsK, blocktable, \ - sparseIndices, user, tiling_data, &tPipe); \ - op.Process(); \ - } while (0) - -template -__global__ __aicore__ void lightning_indexer_quant(__gm__ uint8_t *query, __gm__ uint8_t *key, __gm__ uint8_t *weights, - __gm__ uint8_t *queryScale, __gm__ uint8_t *keyScale, - __gm__ uint8_t *actualSeqLengthsQ, __gm__ uint8_t *actualSeqLengthsK, - __gm__ uint8_t *blocktable, __gm__ uint8_t *sparseIndices, - __gm__ uint8_t *workspace, __gm__ uint8_t *tiling) -{ -#if (__CCE_AICORE__ == 310) || (defined __DAV_310R6__) || (__CCE_AICORE__ == 200) - -#else - TPipe tPipe; - __gm__ uint8_t *user = GetUserWorkspace(workspace); - KERNEL_TASK_TYPE_DEFAULT(KERNEL_TYPE_MIX_AIC_1_2); - - INVOKE_LI_NO_KFC_OP_IMPL(LIQPreload, int8_t, int8_t, int32_t, - PAGE_ATTENTION, LI_LAYOUT(Q_LAYOUT_T), LI_LAYOUT(K_LAYOUT_T)); -#endif -} diff --git a/csrc/lightning_indexer_quant/op_kernel/lightning_indexer_quant_common.h b/csrc/lightning_indexer_quant/op_kernel/lightning_indexer_quant_common.h deleted file mode 100644 index a0f0eb7f..00000000 --- a/csrc/lightning_indexer_quant/op_kernel/lightning_indexer_quant_common.h +++ /dev/null @@ -1,146 +0,0 @@ -/** - * This program is free software, you can redistribute it and/or modify it. - * Copyright (c) 2025 Huawei Technologies Co., Ltd. - * This file is a part of the CANN Open Software. - * Licensed under CANN Open Software License Agreement Version 2.0 (the "License"). - * Please refer to the License for details. You may not use this file except in compliance with the License. - * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. - * See LICENSE in the root of the software repository for the full text of the License. - */ - -/*! - * \file lightning_indexer_quant_common.h - * \brief - */ -#ifndef LIGHTNING_INDEXER_QUANT_COMMON_H -#define LIGHTNING_INDEXER_QUANT_COMMON_H - -namespace LIQCommon { - -// 与tiling的layout保持一致 -enum class LI_LAYOUT : uint32_t { - BSND = 0, - TND = 1, - PA_BSND = 2 -}; - -template -struct LIQType { - using queryType = Q_T; - using keyType = K_T; - using outputType = OUT_T; - static constexpr bool pageAttention = PAGE_ATTENTION; - static constexpr LI_LAYOUT layout = Q_LAYOUT_T; - static constexpr LI_LAYOUT keyLayout = K_LAYOUT_T; -}; - -struct RunInfo { - uint32_t loop; - uint32_t bN2Idx; - uint32_t bIdx; - uint32_t n2Idx = 0; - uint32_t gS1Idx; - uint32_t s2Idx; - - uint32_t actS1Size = 1; - uint32_t actS2Size = 1; - uint32_t actMBaseSize; - uint32_t actualSingleProcessSInnerSize; - uint32_t actualSingleProcessSInnerSizeAlign; - - uint64_t tensorQueryOffset; - uint64_t tensorKeyOffset; - uint64_t tensorKeyScaleOffset; - uint64_t tensorWeightsOffset; - uint64_t indiceOutOffset; - - bool isFirstS2InnerLoop; - bool isLastS2InnerLoop; - bool isAllLoopEnd = false; - bool isValid = false; -}; - -struct ConstInfo { - // CUBE与VEC核间同步的模式 - static constexpr uint32_t FIA_SYNC_MODE2 = 2; - // BUFFER的字节数 - static constexpr uint32_t BUFFER_SIZE_BYTE_32B = 32; - static constexpr uint32_t BUFFER_SIZE_BYTE_64B = 64; - static constexpr uint32_t BUFFER_SIZE_BYTE_256B = 256; - static constexpr uint32_t BUFFER_SIZE_BYTE_512B = 512; - static constexpr uint32_t BUFFER_SIZE_BYTE_1K = 1024; - static constexpr uint32_t BUFFER_SIZE_BYTE_2K = 2048; - static constexpr uint32_t BUFFER_SIZE_BYTE_4K = 4096; - static constexpr uint32_t BUFFER_SIZE_BYTE_8K = 8192; - static constexpr uint32_t BUFFER_SIZE_BYTE_16K = 16384; - static constexpr uint32_t BUFFER_SIZE_BYTE_32K = 32768; - // 无效索引 - static constexpr int INVALID_IDX = -1; - - // CUBE和VEC的核间同步EventID - uint32_t syncC1V1 = 0U; - uint32_t syncC1V0 = 2U; - uint32_t syncV1C1 = 0U; - uint32_t syncV0C1 = 1U; - - // 基本块大小 - uint32_t mBaseSize = 1ULL; - uint32_t s1BaseSize = 1ULL; - uint32_t s2BaseSize = 1ULL; - - uint64_t batchSize = 0ULL; - uint64_t gSize = 0ULL; - uint64_t qHeadNum = 0ULL; - uint64_t kHeadNum; - uint64_t headDim; - uint64_t sparseCount; // topK选取大小 - uint64_t kSeqSize = 0ULL; // kv最大S长度 - uint64_t qSeqSize = 1ULL; // q最大S长度 - uint32_t kCacheBlockSize = 0; // PA场景的block size - uint32_t maxBlockNumPerBatch = 0; // PA场景的最大单batch block number - LI_LAYOUT outputLayout; // 输出的格式 - bool attenMaskFlag = false; - - uint32_t actualLenQDims = 0U; // query的actualSeqLength 的维度 - uint32_t actualLenDims = 0U; // KV 的actualSeqLength 的维度 - bool isAccumSeqS1 = false; // 是否累加模式 - bool isAccumSeqS2 = false; // 是否累加模式 -}; - -struct SplitCoreInfo { - uint32_t s2Start = 0U; // S2的起始位置 - uint32_t s2End = 0U; // S2循环index上限 - uint32_t bN2Start = 0U; - uint32_t bN2End = 0U; - uint32_t gS1Start = 0U; - uint32_t gS1End = 0U; - bool isLD = false; // 当前核是否需要进行Decode归约任务 -}; - -template -__aicore__ inline T Align(T num, T rnd) -{ - return (((rnd) == 0) ? 0 : (((num) + (rnd)-1) / (rnd) * (rnd))); -} - -template -__aicore__ inline T1 Min(T1 a, T2 b) -{ - return (a > b) ? (b) : (a); -} - -template -__aicore__ inline T1 Max(T1 a, T2 b) -{ - return (a > b) ? (a) : (b); -} - -template -__aicore__ inline T CeilDiv(T num, T rnd) -{ - return (((rnd) == 0) ? 0 : (((num) + (rnd)-1) / (rnd))); -} -} // namespace LIQCommon - -#endif // LIGHTNING_INDEXER_QUANT_COMMON_H \ No newline at end of file diff --git a/csrc/lightning_indexer_quant/op_kernel/lightning_indexer_quant_kernel.h b/csrc/lightning_indexer_quant/op_kernel/lightning_indexer_quant_kernel.h deleted file mode 100644 index 723255d3..00000000 --- a/csrc/lightning_indexer_quant/op_kernel/lightning_indexer_quant_kernel.h +++ /dev/null @@ -1,714 +0,0 @@ -/** - * This program is free software, you can redistribute it and/or modify it. - * Copyright (c) 2025 Huawei Technologies Co., Ltd. - * This file is a part of the CANN Open Software. - * Licensed under CANN Open Software License Agreement Version 2.0 (the "License"). - * Please refer to the License for details. You may not use this file except in compliance with the License. - * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. - * See LICENSE in the root of the software repository for the full text of the License. - */ - -/*! - * \file lightning_indexer_quant_kernel.h - * \brief - */ - -#ifndef LIGHTNING_INDEXER_QUANT_KERNEL_H -#define LIGHTNING_INDEXER_QUANT_KERNEL_H - -#include "kernel_operator.h" -#include "kernel_operator_list_tensor_intf.h" -#include "kernel_tiling/kernel_tiling.h" -#include "lib/matmul_intf.h" -#include "lib/matrix/matmul/tiling.h" -#include "lightning_indexer_quant_common.h" -#include "lightning_indexer_quant_service_vector.h" -#include "lightning_indexer_quant_service_cube.h" - -namespace LIQKernel { -using namespace LIQCommon; -using namespace LIQServiceVec; -using namespace matmul; -using AscendC::CacheMode; -using AscendC::CrossCoreSetFlag; -using AscendC::CrossCoreWaitFlag; - -// 由于S2循环前,RunInfo还没有赋值,使用TempLoopInfo临时存放B、N、S1轴相关的信息;同时减少重复计算 -struct TempLoopInfo { - uint32_t bN2Idx = 0; - uint32_t bIdx = 0U; - uint32_t n2Idx = 0U; - uint32_t gS1Idx = 0U; - uint32_t gS1LoopEnd = 0U; // gS1方向循环的结束Idx - uint32_t s2LoopEnd = 0U; // S2方向循环的结束Idx - uint32_t actS1Size = 1ULL; // 当前Batch循环处理的S1轴的实际大小 - uint32_t actS2Size = 0ULL; - bool curActSeqLenIsZero = false; - bool needDealActS1LessThanS1 = false; // S1的实际长度小于shape的S1长度时,是否需要清理输出 - uint32_t actMBaseSize = 0U; // m轴(gS1)方向实际大小 - uint32_t mBasicSizeTail = 0U; // gS1方向循环的尾基本块大小 - uint32_t s2BasicSizeTail = 0U; // S2方向循环的尾基本块大小 -}; - -template -class LIQPreload { -public: - __aicore__ inline LIQPreload(){}; - __aicore__ inline void Init(__gm__ uint8_t *query, __gm__ uint8_t *key, __gm__ uint8_t *weights, - __gm__ uint8_t *queryScale, __gm__ uint8_t *keyScale, __gm__ uint8_t *actualSeqLengthsQ, - __gm__ uint8_t *actualSeqLengthsK, __gm__ uint8_t *blockTable, - __gm__ uint8_t *sparseIndices, __gm__ uint8_t *workspace, - const LIQTilingData *__restrict tiling, TPipe *tPipe); - __aicore__ inline void Process(); - - // =================================类型定义区================================= - using Q_T = typename LIQT::queryType; - using K_T = typename LIQT::keyType; - using OUT_T = typename LIQT::outputType; - static constexpr bool PAGE_ATTENTION = LIQT::pageAttention; - static constexpr LI_LAYOUT Q_LAYOUT_T = LIQT::layout; - static constexpr LI_LAYOUT K_LAYOUT_T = LIQT::keyLayout; - - using MM1_OUT_T = float; - - LIQMatmul matmulService; - LIQVector vectorService; - - // =================================常量区================================= - static constexpr uint32_t SYNC_C1_V1_FLAG = 4; - static constexpr uint32_t SYNC_V1_C1_FLAG = 5; - - static constexpr uint32_t M_BASE_SIZE = 256; - static constexpr uint32_t S2_BASE_SIZE = 2048; - static constexpr uint32_t HEAD_DIM = 128; - static constexpr uint32_t K_HEAD_NUM = 1; - static constexpr uint32_t GM_ALIGN_BYTES = 512; - static constexpr uint32_t LI_QUANT_PRELOAD_TASK_CACHE_SIZE = 2; - - static constexpr int64_t LD_PREFETCH_LEN = 2; - // for workspace double - static constexpr uint32_t WS_DOBULE = 2; - -protected: - TPipe *pipe = nullptr; - - // offset - uint64_t queryCoreOffset = 0ULL; - uint64_t keyCoreOffset = 0ULL; - uint64_t keyScaleCoreOffset = 0ULL; - uint64_t weightsCoreOffset = 0ULL; - uint64_t indiceOutCoreOffset = 0ULL; - - // ================================Global Buffer区================================= - GlobalTensor queryGm; - GlobalTensor keyGm; - GlobalTensor weightsGm; - - GlobalTensor indiceOutGm; - GlobalTensor blockTableGm; - - GlobalTensor actualSeqLengthsGmQ; - GlobalTensor actualSeqLengthsGm; - - // ================================类成员变量==================================== - // aic、aiv核信息 - uint32_t tmpBlockIdx = 0U; - uint32_t aiCoreIdx = 0U; - uint32_t usedCoreNum = 0U; - - LIQCommon::ConstInfo constInfo{}; - TempLoopInfo tempLoopInfo{}; - LIQCommon::SplitCoreInfo splitCoreInfo{}; - - // ================================Init functions================================== - __aicore__ inline void InitTilingData(const LIQTilingData *__restrict tilingData); - __aicore__ inline void InitBuffers(); - __aicore__ inline void InitActualSeqLen(__gm__ uint8_t *actualSeqLengthsQ, __gm__ uint8_t *actualSeqLengthsK); - // ================================Split Core================================ - __aicore__ inline void SplitCore(uint32_t curCoreIdx, uint32_t &coreNum, LIQCommon::SplitCoreInfo &info); - __aicore__ inline uint32_t GetS2BaseBlockNumOnMask(uint32_t s1gIdx, uint32_t actS1Size, uint32_t actS2Size); - __aicore__ inline uint32_t GetTotalBaseBlockNum(); - // ================================Process functions================================ - __aicore__ inline void ProcessMain(); - __aicore__ inline void ProcessBaseBlock(uint32_t loop, uint64_t s2LoopIdx, - LIQCommon::RunInfo runInfo[LI_QUANT_PRELOAD_TASK_CACHE_SIZE]); - __aicore__ inline void ProcessDecode(); - __aicore__ inline void ProcessInvalid(); - // ================================Params Calc===================================== - __aicore__ inline void CalcGS1LoopParams(uint32_t bN2Idx); - __aicore__ inline void GetBN2Idx(uint32_t bN2Idx); - __aicore__ inline uint32_t GetActualSeqLen(uint32_t bIdx, uint32_t actualLenDims, bool isAccumSeq, - GlobalTensor &actualSeqLengthsGm, uint32_t defaultSeqLen); - __aicore__ inline void GetS1S2ActualSeqLen(uint32_t bIdx, uint32_t &actS1Size, uint32_t &actS2Size); - __aicore__ inline void CalcS2LoopParams(uint32_t bN2LoopIdx, uint32_t gS1LoopIdx); - __aicore__ inline void CalcRunInfo(uint32_t loop, uint32_t s2LoopIdx, LIQCommon::RunInfo &runInfo); - __aicore__ inline void DealActSeqLenIsZero(uint32_t bIdx, uint32_t n2Idx, uint32_t s1Start); -}; - -template -__aicore__ inline void LIQPreload::InitTilingData(const LIQTilingData *__restrict tilingData) -{ - usedCoreNum = tilingData->usedCoreNum; - constInfo.batchSize = tilingData->bSize; - constInfo.qHeadNum = constInfo.gSize = tilingData->gSize; - constInfo.kSeqSize = tilingData->s2Size; - constInfo.qSeqSize = tilingData->s1Size; - constInfo.attenMaskFlag = (tilingData->sparseMode == 3); - constInfo.kCacheBlockSize = tilingData->blockSize; - constInfo.maxBlockNumPerBatch = tilingData->maxBlockNumPerBatch; - constInfo.sparseCount = tilingData->sparseCount; - constInfo.outputLayout = Q_LAYOUT_T; // 输出和输入形状一致 - if (Q_LAYOUT_T == LI_LAYOUT::TND) { - constInfo.isAccumSeqS1 = true; - } - if (K_LAYOUT_T == LI_LAYOUT::TND) { - constInfo.isAccumSeqS2 = true; - } - - constInfo.kHeadNum = K_HEAD_NUM; - constInfo.headDim = HEAD_DIM; - - constInfo.mBaseSize = M_BASE_SIZE; - constInfo.s2BaseSize = S2_BASE_SIZE; - constInfo.s1BaseSize = (constInfo.mBaseSize + constInfo.gSize - 1) / constInfo.gSize; -} - -template -__aicore__ inline void LIQPreload::InitBuffers() -{ - if ASCEND_IS_AIV { - vectorService.InitBuffers(pipe); - } else { - matmulService.InitBuffers(pipe); - } -} - -template -__aicore__ inline void LIQPreload::InitActualSeqLen(__gm__ uint8_t *actualSeqLengthsQ, - __gm__ uint8_t *actualSeqLengthsK) -{ - if (actualSeqLengthsQ == nullptr) { - constInfo.actualLenQDims = 0; - } else { - constInfo.actualLenQDims = constInfo.batchSize; - actualSeqLengthsGmQ.SetGlobalBuffer((__gm__ uint32_t *)actualSeqLengthsQ, constInfo.actualLenQDims); - } - if (actualSeqLengthsK == nullptr) { - constInfo.actualLenDims = 0; - } else { - constInfo.actualLenDims = constInfo.batchSize; - actualSeqLengthsGm.SetGlobalBuffer((__gm__ uint32_t *)actualSeqLengthsK, constInfo.actualLenDims); - } -} - -template -__aicore__ inline uint32_t LIQPreload::GetActualSeqLen(uint32_t bIdx, uint32_t actualLenDims, bool isAccumSeq, - GlobalTensor &actualSeqLengthsGm, - uint32_t defaultSeqLen) -{ - if (actualLenDims == 0) { - return defaultSeqLen; - } else if (isAccumSeq && bIdx > 0) { - return actualSeqLengthsGm.GetValue(bIdx) - actualSeqLengthsGm.GetValue(bIdx - 1); - } else { - return actualSeqLengthsGm.GetValue(bIdx); - } -} - -template -__aicore__ inline void LIQPreload::GetS1S2ActualSeqLen(uint32_t bIdx, uint32_t &actS1Size, uint32_t &actS2Size) -{ - actS1Size = GetActualSeqLen(bIdx, constInfo.actualLenQDims, constInfo.isAccumSeqS1, actualSeqLengthsGmQ, - constInfo.qSeqSize); - actS2Size = - GetActualSeqLen(bIdx, constInfo.actualLenDims, constInfo.isAccumSeqS2, actualSeqLengthsGm, constInfo.kSeqSize); -} - -template -__aicore__ inline uint32_t LIQPreload::GetS2BaseBlockNumOnMask(uint32_t s1gIdx, uint32_t actS1Size, - uint32_t actS2Size) -{ - if (actS2Size == 0) { - return 0; - } - uint32_t s1Offset = constInfo.s1BaseSize * s1gIdx; - int32_t validS2LenBase = static_cast(actS2Size) - static_cast(actS1Size); - int32_t validS2Len = s1Offset + validS2LenBase + constInfo.s1BaseSize; - validS2Len = Min(validS2Len, static_cast(actS2Size)); - validS2Len = Max(validS2Len, 1); - return (validS2Len + constInfo.s2BaseSize - 1) / constInfo.s2BaseSize; -} - -template -__aicore__ inline uint32_t LIQPreload::GetTotalBaseBlockNum() -{ - uint32_t totalBlockNum = 0; - uint32_t actS1Size, actS2Size; - uint32_t s1GBaseNum, s2BaseNum; - for (uint32_t bIdx = 0; bIdx < constInfo.batchSize; bIdx++) { - GetS1S2ActualSeqLen(bIdx, actS1Size, actS2Size); - s1GBaseNum = CeilDiv(actS1Size, constInfo.s1BaseSize); - if (!constInfo.attenMaskFlag) { - s2BaseNum = CeilDiv(actS2Size, constInfo.s2BaseSize); - totalBlockNum += s1GBaseNum * s2BaseNum * constInfo.kHeadNum; - continue; - } - for (uint32_t s1gIdx = 0; s1gIdx < s1GBaseNum; s1gIdx++) { - s2BaseNum = GetS2BaseBlockNumOnMask(s1gIdx, actS1Size, actS2Size); - totalBlockNum += s2BaseNum * constInfo.kHeadNum; - } - } - return totalBlockNum; -} - -// 多核版本,双闭区间。基本原则:计算每个核最少处理的块数, 剩余的部分前面的核每个核多处理一块 -template -__aicore__ void inline LIQPreload::SplitCore(uint32_t curCoreIdx, uint32_t &coreNum, - LIQCommon::SplitCoreInfo &info) -{ - uint32_t totalBlockNum = GetTotalBaseBlockNum(); - uint32_t minBlockPerCore = totalBlockNum / coreNum; - uint32_t deal1MoreBlockCoreNum = totalBlockNum % coreNum; - uint32_t coreIdx = 0; - uint32_t lastGS1RemainBlockCnt = 0; - uint32_t coreDealBlockCnt = coreIdx < deal1MoreBlockCoreNum ? minBlockPerCore + 1 : minBlockPerCore; - coreNum = minBlockPerCore == 0 ? deal1MoreBlockCoreNum : coreNum; - - bool findLastCoreEnd = true; - uint32_t actS1Size, actS2Size; - uint32_t s1GBaseNum, s2BaseNum; - for (uint32_t bN2Idx = 0; bN2Idx < constInfo.batchSize * constInfo.kHeadNum; bN2Idx++) { - uint32_t bIdx = bN2Idx / constInfo.kHeadNum; - if (bN2Idx % constInfo.kHeadNum == 0) { - GetS1S2ActualSeqLen(bIdx, actS1Size, actS2Size); - s1GBaseNum = CeilDiv(actS1Size, constInfo.s1BaseSize); - s2BaseNum = CeilDiv(actS2Size, constInfo.s2BaseSize); - } - if constexpr (Q_LAYOUT_T == LI_LAYOUT::BSND) { - if (findLastCoreEnd && (s1GBaseNum == 0U || s2BaseNum == 0U)) { - info.bN2Start = bN2Idx; - info.gS1Start = 0; - info.s2Start = 0; - findLastCoreEnd = false; - } - } - for (uint32_t gS1Idx = 0; gS1Idx < s1GBaseNum; gS1Idx++) { - if (constInfo.attenMaskFlag) { - s2BaseNum = GetS2BaseBlockNumOnMask(gS1Idx, actS1Size, actS2Size); - } - if (findLastCoreEnd && s2BaseNum == 0U) { - info.bN2Start = bN2Idx; - info.gS1Start = gS1Idx; - info.s2Start = 0; - findLastCoreEnd = false; - } - for (uint32_t s2Idx = 0; s2Idx < s2BaseNum;) { - if (findLastCoreEnd) { - info.bN2Start = bN2Idx; - info.gS1Start = gS1Idx; - info.s2Start = s2Idx; - findLastCoreEnd = false; - } - uint32_t s2RemainBaseNum = s2BaseNum - s2Idx; - if (lastGS1RemainBlockCnt + s2RemainBaseNum >= coreDealBlockCnt) { - info.bN2End = bN2Idx; - info.gS1End = gS1Idx; - info.s2End = s2Idx + coreDealBlockCnt - lastGS1RemainBlockCnt - 1; - - if (coreIdx == curCoreIdx) { - // S2被切N核,那么只有第一个核需要处理LD,其他核不用 - if (s2Idx == 0 && info.s2End + 1 < s2BaseNum) { - info.isLD = true; - } - // 最后一个核处理的不是最后一个Batch,表明后面的Batch为空块(S2=0), 调整终点坐标以便清理输出 - if (coreIdx == coreNum - 1 && info.bN2End != constInfo.batchSize - 1) { - info.bN2End = constInfo.batchSize - 1; - info.gS1End = 0; - info.s2End = 0; - } - return; - } - coreIdx++; - findLastCoreEnd = true; - s2Idx = info.s2End + 1; - lastGS1RemainBlockCnt = 0; - coreDealBlockCnt = coreIdx < deal1MoreBlockCoreNum ? minBlockPerCore + 1 : minBlockPerCore; - } else { - lastGS1RemainBlockCnt += s2RemainBaseNum; - break; - } - } - } - } -} - -template -__aicore__ inline void LIQPreload::DealActSeqLenIsZero(uint32_t bIdx, uint32_t n2Idx, uint32_t s1Start) -{ - if ASCEND_IS_AIV { - if (constInfo.outputLayout == LI_LAYOUT::TND) { - uint32_t tSize = actualSeqLengthsGmQ.GetValue(constInfo.batchSize - 1); - uint32_t tBase = bIdx == 0 ? 0 : actualSeqLengthsGmQ.GetValue(bIdx - 1); - uint32_t s1Count = tempLoopInfo.actS1Size; - - for (uint32_t s1Idx = s1Start; s1Idx < s1Count; s1Idx++) { - uint64_t indiceOutOffset = - (tBase + s1Idx) * constInfo.kHeadNum * constInfo.sparseCount + // T轴、s1轴偏移 - n2Idx * constInfo.sparseCount; // N2轴偏移 - vectorService.CleanInvalidOutput(indiceOutOffset); - } - } else if (constInfo.outputLayout == LI_LAYOUT::BSND) { - for (uint32_t s1Idx = s1Start; s1Idx < constInfo.qSeqSize; s1Idx++) { - // B,S1,N2,K - uint64_t indiceOutOffset = bIdx * constInfo.qSeqSize * constInfo.kHeadNum * constInfo.sparseCount + - s1Idx * constInfo.kHeadNum * constInfo.sparseCount + // B轴、S1轴偏移 - n2Idx * constInfo.sparseCount; // N2轴偏移 - vectorService.CleanInvalidOutput(indiceOutOffset); - } - } - } -} - -template -__aicore__ inline void LIQPreload::Init(__gm__ uint8_t *query, __gm__ uint8_t *key, __gm__ uint8_t *weights, - __gm__ uint8_t *queryScale, __gm__ uint8_t *keyScale, - __gm__ uint8_t *actualSeqLengthsQ, __gm__ uint8_t *actualSeqLengthsK, - __gm__ uint8_t *blockTable, __gm__ uint8_t *sparseIndices, - __gm__ uint8_t *workspace, const LIQTilingData *__restrict tiling, - TPipe *tPipe) -{ - if ASCEND_IS_AIV { - tmpBlockIdx = GetBlockIdx(); // vec:0-47 - aiCoreIdx = tmpBlockIdx / 2; - } else { - tmpBlockIdx = GetBlockIdx(); // cube:0-23 - aiCoreIdx = tmpBlockIdx; - } - - InitTilingData(tiling); - InitActualSeqLen(actualSeqLengthsQ, actualSeqLengthsK); - - // 计算分核 - SplitCore(aiCoreIdx, usedCoreNum, splitCoreInfo); - - pipe = tPipe; - // workspace 内存排布 - // |mm1ResGm(存S)|vec1ResGm(存LD中间结果)|vec1ParamGm(存LD参数) - // |Core0_mm1ResDB0-Core0_mm1ResDB1-Core1_mm1ResDB0....Core23_mm1ResDB0-Core23_mm1ResDB1|Core0_vec1Res... - uint64_t offset = 0; - - // mm1开DoubleBuffer - GlobalTensor mm1ResGm; // 存放S - uint64_t singleCoreMm1ResSize = WS_DOBULE * constInfo.s1BaseSize * constInfo.s2BaseSize * sizeof(MM1_OUT_T); - mm1ResGm.SetGlobalBuffer((__gm__ MM1_OUT_T *)(workspace + aiCoreIdx * singleCoreMm1ResSize)); - offset += GetBlockNum() * singleCoreMm1ResSize; - - // ld流程需要ws大小: [aicnum, 2, CeilDiv(constInfo.mBaseSize, constInfo.gSize), topkOut_*2] - // (aic, 8, 2, 2, 2048) - // (aic, s1_cube, 头尾, idx/value, K) - GlobalTensor vec1ResGm; // 存放TopK计算中间结果 - vec1ResGm.SetGlobalBuffer((__gm__ float *)(workspace + offset)); - offset += GetBlockNum() * constInfo.s1BaseSize * WS_DOBULE * WS_DOBULE * BASE_TOPK * sizeof(float); - - // (aic, 8, 2, 16) - // (aic, s1_cube, 头尾,16ele) - GlobalTensor vec1ParamGm; // 存放LD参数信息 - vec1ParamGm.SetGlobalBuffer((__gm__ int64_t *)(workspace + offset)); - offset += GetBlockNum() * constInfo.s1BaseSize * WS_DOBULE * LD_PARAM_NUM * sizeof(int64_t); - - GlobalTensor weightWorkspaceGm; // v1阶段处理w*scale后的结果 - uint64_t weightMemSize = BLOCK_CUBE * constInfo.mBaseSize * WS_DOBULE * sizeof(half); - weightWorkspaceGm.SetGlobalBuffer((__gm__ half *)(workspace + offset + aiCoreIdx * weightMemSize)); - offset += GetBlockNum() * weightMemSize; - - GlobalTensor qScaleGm; - GlobalTensor kScaleGm; - if ASCEND_IS_AIV { - vectorService.InitParams(constInfo, tiling); - indiceOutGm.SetGlobalBuffer((__gm__ int32_t *)sparseIndices); - weightsGm.SetGlobalBuffer((__gm__ half *)weights); - qScaleGm.SetGlobalBuffer((__gm__ half *)queryScale); - kScaleGm.SetGlobalBuffer((__gm__ half *)keyScale); - blockTableGm.SetGlobalBuffer((__gm__ int32_t *)blockTable); - vectorService.InitVecInputTensor(weightsGm, qScaleGm, kScaleGm, indiceOutGm, blockTableGm); - vectorService.InitVecWorkspaceTensor(weightWorkspaceGm, mm1ResGm, vec1ResGm, vec1ParamGm); - } else { - matmulService.InitParams(constInfo); - queryGm.SetGlobalBuffer((__gm__ Q_T *)query); - if constexpr (PAGE_ATTENTION) { - blockTableGm.SetGlobalBuffer((__gm__ int32_t *)blockTable); - } - keyGm.SetGlobalBuffer((__gm__ K_T *)key); - matmulService.InitMm1GlobalTensor(blockTableGm, keyGm, queryGm, mm1ResGm, weightWorkspaceGm); - } - InitBuffers(); -} - -template -__aicore__ inline void LIQPreload::GetBN2Idx(uint32_t bN2Idx) -{ - tempLoopInfo.bN2Idx = bN2Idx; - tempLoopInfo.bIdx = bN2Idx / constInfo.kHeadNum; - tempLoopInfo.n2Idx = bN2Idx % constInfo.kHeadNum; -} - -template -__aicore__ inline void LIQPreload::CalcS2LoopParams(uint32_t bN2LoopIdx, uint32_t gS1LoopIdx) -{ - tempLoopInfo.gS1Idx = gS1LoopIdx; - tempLoopInfo.actMBaseSize = constInfo.mBaseSize; - uint32_t remainedGS1Size = tempLoopInfo.actS1Size * constInfo.gSize - tempLoopInfo.gS1Idx * constInfo.mBaseSize; - if (remainedGS1Size <= constInfo.mBaseSize && remainedGS1Size > 0) { - tempLoopInfo.actMBaseSize = tempLoopInfo.mBasicSizeTail; - } - - bool isEnd = (bN2LoopIdx == splitCoreInfo.bN2End) && (gS1LoopIdx == splitCoreInfo.gS1End); - uint32_t s2BlockNum; - if (constInfo.attenMaskFlag) { - s2BlockNum = GetS2BaseBlockNumOnMask(gS1LoopIdx, tempLoopInfo.actS1Size, tempLoopInfo.actS2Size); - } else { - s2BlockNum = (tempLoopInfo.actS2Size + constInfo.s2BaseSize - 1) / constInfo.s2BaseSize; - } - tempLoopInfo.s2LoopEnd = isEnd ? splitCoreInfo.s2End : s2BlockNum - 1; -} - -template -__aicore__ inline void LIQPreload::CalcGS1LoopParams(uint32_t bN2LoopIdx) -{ - GetBN2Idx(bN2LoopIdx); - GetS1S2ActualSeqLen(tempLoopInfo.bIdx, tempLoopInfo.actS1Size, tempLoopInfo.actS2Size); - if ((tempLoopInfo.actS2Size == 0) || (tempLoopInfo.actS1Size == 0)) { - tempLoopInfo.curActSeqLenIsZero = true; - return; - } - tempLoopInfo.curActSeqLenIsZero = false; - tempLoopInfo.s2BasicSizeTail = tempLoopInfo.actS2Size % constInfo.s2BaseSize; - tempLoopInfo.s2BasicSizeTail = - (tempLoopInfo.s2BasicSizeTail == 0) ? constInfo.s2BaseSize : tempLoopInfo.s2BasicSizeTail; - tempLoopInfo.mBasicSizeTail = (tempLoopInfo.actS1Size * constInfo.gSize) % constInfo.mBaseSize; - tempLoopInfo.mBasicSizeTail = - (tempLoopInfo.mBasicSizeTail == 0) ? constInfo.mBaseSize : tempLoopInfo.mBasicSizeTail; - - uint32_t gS1SplitNum = (tempLoopInfo.actS1Size * constInfo.gSize + constInfo.mBaseSize - 1) / constInfo.mBaseSize; - tempLoopInfo.gS1LoopEnd = (bN2LoopIdx == splitCoreInfo.bN2End) ? splitCoreInfo.gS1End : gS1SplitNum - 1; - if constexpr (Q_LAYOUT_T == LI_LAYOUT::BSND) { - if (tempLoopInfo.gS1LoopEnd == gS1SplitNum - 1 && constInfo.qSeqSize > tempLoopInfo.actS1Size) { - tempLoopInfo.needDealActS1LessThanS1 = true; - } - } -} - -template -__aicore__ inline void LIQPreload::CalcRunInfo(uint32_t loop, uint32_t s2LoopIdx, LIQCommon::RunInfo &runInfo) -{ - runInfo.loop = loop; - runInfo.bIdx = tempLoopInfo.bIdx; - runInfo.gS1Idx = tempLoopInfo.gS1Idx; - runInfo.s2Idx = s2LoopIdx; - runInfo.bN2Idx = tempLoopInfo.bN2Idx; - runInfo.isValid = s2LoopIdx <= tempLoopInfo.s2LoopEnd; - - if (!runInfo.isValid) { - return; // 需要验证, v1 时候需要runInfo - } - - runInfo.actS1Size = tempLoopInfo.actS1Size; - runInfo.actS2Size = tempLoopInfo.actS2Size; - // 计算实际基本块size - runInfo.actMBaseSize = tempLoopInfo.actMBaseSize; - runInfo.actualSingleProcessSInnerSize = constInfo.s2BaseSize; - uint32_t s2SplitNum = (tempLoopInfo.actS2Size + constInfo.s2BaseSize - 1) / constInfo.s2BaseSize; - if (runInfo.s2Idx == s2SplitNum - 1) { - runInfo.actualSingleProcessSInnerSize = tempLoopInfo.s2BasicSizeTail; - } - runInfo.actualSingleProcessSInnerSizeAlign = - LIQCommon::Align((uint32_t)runInfo.actualSingleProcessSInnerSize, LIQCommon::ConstInfo::BUFFER_SIZE_BYTE_32B); - - runInfo.isFirstS2InnerLoop = s2LoopIdx == splitCoreInfo.s2Start; - runInfo.isLastS2InnerLoop = s2LoopIdx == tempLoopInfo.s2LoopEnd; - runInfo.isAllLoopEnd = (runInfo.bN2Idx == splitCoreInfo.bN2End) && (runInfo.gS1Idx == splitCoreInfo.gS1End) && - (runInfo.s2Idx == splitCoreInfo.s2End); - - if (runInfo.isFirstS2InnerLoop) { - uint64_t actualSeqQPrefixSum; - if constexpr (Q_LAYOUT_T == LI_LAYOUT::TND) { - actualSeqQPrefixSum = (runInfo.bIdx <= 0) ? 0 : actualSeqLengthsGmQ.GetValue(runInfo.bIdx - 1); - } else { // BSND - actualSeqQPrefixSum = (runInfo.bIdx <= 0) ? 0 : runInfo.bIdx * constInfo.qSeqSize; - } - uint64_t tndBIdxOffset = actualSeqQPrefixSum * constInfo.qHeadNum * constInfo.headDim; - // B,S1,N1(N2,G),D - queryCoreOffset = tndBIdxOffset + runInfo.gS1Idx * constInfo.mBaseSize * constInfo.headDim; - // B,S1,N1(N2,G)/T,N1(N2,G) - weightsCoreOffset = actualSeqQPrefixSum * constInfo.qHeadNum + runInfo.n2Idx * constInfo.gSize; - // B,S1,N2,k/T,N2,k - indiceOutCoreOffset = - actualSeqQPrefixSum * constInfo.kHeadNum * constInfo.sparseCount + runInfo.n2Idx * constInfo.sparseCount; - } - uint64_t actualSeqKPrefixSum; - if constexpr (K_LAYOUT_T == LI_LAYOUT::TND) { // T N2 D - actualSeqKPrefixSum = (runInfo.bIdx <= 0) ? 0 : actualSeqLengthsGm.GetValue(runInfo.bIdx - 1); - } else { - actualSeqKPrefixSum = (runInfo.bIdx <= 0) ? 0 : runInfo.bIdx * constInfo.kSeqSize; - } - uint64_t tndBIdxOffsetForK = actualSeqKPrefixSum * constInfo.kHeadNum * constInfo.headDim; - keyCoreOffset = tndBIdxOffsetForK + runInfo.s2Idx * constInfo.s2BaseSize * constInfo.kHeadNum * constInfo.headDim; - keyScaleCoreOffset = (actualSeqKPrefixSum + runInfo.s2Idx * constInfo.s2BaseSize) * constInfo.kHeadNum; - runInfo.tensorQueryOffset = queryCoreOffset; - runInfo.tensorKeyOffset = keyCoreOffset; - runInfo.tensorKeyScaleOffset = keyScaleCoreOffset; - runInfo.tensorWeightsOffset = weightsCoreOffset; - runInfo.indiceOutOffset = indiceOutCoreOffset; -} - -template -__aicore__ inline void LIQPreload::Process() -{ - if (usedCoreNum == 0) { - // 没有计算任务,直接清理输出 - ProcessInvalid(); - return; - } - - ProcessMain(); - - ProcessDecode(); -} - -template -__aicore__ inline void LIQPreload::ProcessInvalid() -{ - if ASCEND_IS_AIV { - uint32_t aivCoreNum = GetBlockNum() * 2; // 2 means c:v = 1:2 - uint64_t totalOutputSize = - constInfo.batchSize * constInfo.qSeqSize * constInfo.kHeadNum * constInfo.sparseCount; - uint64_t singleCoreSize = - LIQCommon::Align((totalOutputSize + aivCoreNum - 1) / aivCoreNum, GM_ALIGN_BYTES / sizeof(OUT_T)); - uint64_t baseSize = tmpBlockIdx * singleCoreSize; - if (baseSize < totalOutputSize) { - uint64_t dealSize = - (baseSize + singleCoreSize <= totalOutputSize) ? singleCoreSize : totalOutputSize - baseSize; - GlobalTensor output = indiceOutGm[baseSize]; - AscendC::InitGlobalMemory(output, dealSize, constInfo.INVALID_IDX); - } - } -} - -template -__aicore__ inline void LIQPreload::ProcessMain() -{ - if (aiCoreIdx >= usedCoreNum) { - // 无任务核直接返回 - return; - } - - if ASCEND_IS_AIV { - vectorService.AllocEventID(); - CrossCoreSetFlag(constInfo.syncV1C1); - CrossCoreSetFlag(constInfo.syncV1C1); - } else { - matmulService.AllocEventID(); - CrossCoreSetFlag(constInfo.syncC1V0); - CrossCoreSetFlag(constInfo.syncC1V0); - } - - LIQCommon::RunInfo runInfo[LI_QUANT_PRELOAD_TASK_CACHE_SIZE]; - - uint32_t gloop = 0; - for (uint32_t bN2LoopIdx = splitCoreInfo.bN2Start; bN2LoopIdx <= splitCoreInfo.bN2End; bN2LoopIdx++) { - CalcGS1LoopParams(bN2LoopIdx); - if (tempLoopInfo.curActSeqLenIsZero) { - DealActSeqLenIsZero(tempLoopInfo.bIdx, tempLoopInfo.n2Idx, 0U); - - if ASCEND_IS_AIV { - if (bN2LoopIdx == splitCoreInfo.bN2End && gloop > 0) { - CrossCoreWaitFlag(constInfo.syncC1V1); - vectorService.ProcessVec1(runInfo[1 - gloop % LI_QUANT_PRELOAD_TASK_CACHE_SIZE]); - CrossCoreSetFlag( - constInfo.syncV1C1); // 反向同步 1 - } - } - continue; - } - for (uint32_t gS1LoopIdx = splitCoreInfo.gS1Start; gS1LoopIdx <= tempLoopInfo.gS1LoopEnd; gS1LoopIdx++) { - CalcS2LoopParams(bN2LoopIdx, gS1LoopIdx); - bool isEnd = (bN2LoopIdx == splitCoreInfo.bN2End) && (gS1LoopIdx == splitCoreInfo.gS1End); - uint32_t extraLoop = isEnd ? LI_QUANT_PRELOAD_TASK_CACHE_SIZE - 1 : 0; - for (int s2LoopIdx = splitCoreInfo.s2Start; s2LoopIdx <= (tempLoopInfo.s2LoopEnd + extraLoop); - s2LoopIdx++) { - ProcessBaseBlock(gloop, s2LoopIdx, runInfo); - ++gloop; - } - splitCoreInfo.s2Start = 0; - } - if (tempLoopInfo.needDealActS1LessThanS1) { - DealActSeqLenIsZero(tempLoopInfo.bIdx, tempLoopInfo.n2Idx, tempLoopInfo.actS1Size); - } - splitCoreInfo.gS1Start = 0; - } - - if ASCEND_IS_AIV { - vectorService.FreeEventID(); - CrossCoreWaitFlag(constInfo.syncC1V0); - CrossCoreWaitFlag(constInfo.syncC1V0); - } else { - matmulService.FreeEventID(); - CrossCoreWaitFlag(constInfo.syncV1C1); - CrossCoreWaitFlag(constInfo.syncV1C1); - } -} - -template -__aicore__ inline void LIQPreload::ProcessBaseBlock(uint32_t loop, uint64_t s2LoopIdx, - LIQCommon::RunInfo runInfo[LI_QUANT_PRELOAD_TASK_CACHE_SIZE]) -{ - int32_t curTaskId = loop % LI_QUANT_PRELOAD_TASK_CACHE_SIZE; - LIQCommon::RunInfo &curRunInfo = runInfo[curTaskId]; - LIQCommon::RunInfo &lastRunInfo = runInfo[1 - curTaskId]; - - CalcRunInfo(loop, s2LoopIdx, curRunInfo); - - if (curRunInfo.isValid) { - if ASCEND_IS_AIC { - if (curRunInfo.isFirstS2InnerLoop) { - CrossCoreWaitFlag(constInfo.syncV0C1); - } - CrossCoreWaitFlag(constInfo.syncV1C1); // 反向同步 1 - matmulService.ComputeMm1(curRunInfo); - CrossCoreSetFlag(constInfo.syncC1V1); - if (curRunInfo.isLastS2InnerLoop) { - CrossCoreSetFlag(constInfo.syncC1V0); // 反向同步 0 - } - } else { - if (curRunInfo.isFirstS2InnerLoop) { - CrossCoreWaitFlag(constInfo.syncC1V0); // 反向同步 0 - vectorService.ProcessVec0(curRunInfo); - CrossCoreSetFlag(constInfo.syncV0C1); - } - } - } - - if (lastRunInfo.isValid) { - if ASCEND_IS_AIV { - CrossCoreWaitFlag(constInfo.syncC1V1); - vectorService.ProcessVec1(lastRunInfo); - CrossCoreSetFlag(constInfo.syncV1C1); // 反向同步 1 - } - lastRunInfo.isValid = false; - } -} - -template -__aicore__ inline void LIQPreload::ProcessDecode() -{ - if ASCEND_IS_AIV { - vectorService.InitLDBuffers(pipe); - ICachePreLoad(LD_PREFETCH_LEN); - SyncAll(); - if (splitCoreInfo.isLD) { - vectorService.ProcessLD(); - } - } -} -} // namespace LIQKernel -#endif // LIGHTNING_INDEXER_QUANT_KERNEL_H \ No newline at end of file diff --git a/csrc/lightning_indexer_quant/op_kernel/lightning_indexer_quant_service_cube.h b/csrc/lightning_indexer_quant/op_kernel/lightning_indexer_quant_service_cube.h deleted file mode 100644 index 2f58a9e1..00000000 --- a/csrc/lightning_indexer_quant/op_kernel/lightning_indexer_quant_service_cube.h +++ /dev/null @@ -1,613 +0,0 @@ -/** - * This program is free software, you can redistribute it and/or modify it. - * Copyright (c) 2025 Huawei Technologies Co., Ltd. - * This file is a part of the CANN Open Software. - * Licensed under CANN Open Software License Agreement Version 2.0 (the "License"). - * Please refer to the License for details. You may not use this file except in compliance with the License. - * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. - * See LICENSE in the root of the software repository for the full text of the License. - */ - -/*! - * \file lightning_indexer_quant_service_cube.h - * \brief use 5 buffer for matmul l1, better pipeline - */ -#ifndef LIGHTNING_INDEXER_QUANT_SERVICE_CUBE_H -#define LIGHTNING_INDEXER_QUANT_SERVICE_CUBE_H - -#include "kernel_operator.h" -#include "kernel_operator_list_tensor_intf.h" -#include "kernel_tiling/kernel_tiling.h" -#include "lib/matmul_intf.h" -#include "lib/matrix/matmul/tiling.h" -#include "lightning_indexer_quant_common.h" - -namespace LIQKernel { -using namespace LIQCommon; -struct MmInfo { - int64_t s2L0LoopId; - int64_t s1gL0LoopId; - int64_t s2L0RealSize; - int64_t s2GmOffset; -}; - -template -class LIQMatmul { -public: - using Q_T = typename LIQT::queryType; - using K_T = typename LIQT::keyType; - - __aicore__ inline LIQMatmul(){}; - __aicore__ inline void InitBuffers(TPipe *pipe); - __aicore__ inline void InitMm1GlobalTensor(const GlobalTensor &blkTableGm, const GlobalTensor &keyGm, - const GlobalTensor &queryGm, const GlobalTensor &mm1ResGm, - const GlobalTensor &weightWorkspaceGm); - __aicore__ inline void InitParams(const ConstInfo &constInfo); - __aicore__ inline void AllocEventID(); - __aicore__ inline void FreeEventID(); - __aicore__ inline void ComputeMm1(const LIQCommon::RunInfo &runInfo); - - static constexpr IsResetLoad3dConfig LOAD3DV2_CONFIG = {true, true}; // isSetFMatrix isSetPadding; - static constexpr uint64_t DOUBLE_BUF_NUM = 2; - static constexpr uint64_t L0AB_BUF_NUM = 4; - - static constexpr uint32_t KEY_MTE1_MTE2_EVENT = EVENT_ID2; - static constexpr uint32_t QW_MTE1_MTE2_EVENT = EVENT_ID5; // KEY_MTE1_MTE2_EVENT + DOUBLE_BUF_NUM; - static constexpr uint32_t M_MTE1_EVENT = EVENT_ID3; - static constexpr uint32_t M_FIX_EVENT = EVENT_ID0; - static constexpr uint32_t FIX_M_EVENT = EVENT_ID2; - static constexpr uint32_t FIX_MTE1_EVENT = EVENT_ID4; - - static constexpr uint64_t S8_BLOCK_CUBE = 32; - - static constexpr uint32_t MTE2_MTE1_EVENT = EVENT_ID2; - static constexpr uint32_t MTE1_M_EVENT = EVENT_ID2; - - static constexpr uint64_t D_BASIC_BLOCK = 128; - static constexpr uint64_t S1G_BASIC_BLOCK_L1 = 256; - - static constexpr uint64_t S1G_BASIC_BLOCK_L0 = 128; - static constexpr uint64_t S2_BASIC_BLOCK_L0 = 128; - - static constexpr uint64_t QUERY_BUFFER_OFFSET = S1G_BASIC_BLOCK_L1 * D_BASIC_BLOCK; - static constexpr uint64_t SL1_BUFFER_OFFSET = S1G_BASIC_BLOCK_L0 * S2_BASIC_BLOCK_L0; - static constexpr uint64_t KEY_BUFFER_OFFSET = S2_BASIC_BLOCK_L0 * D_BASIC_BLOCK; - static constexpr uint64_t WEIGHT_BUFFER_OFFSET = S1G_BASIC_BLOCK_L1 * BLOCK_CUBE; - static constexpr uint64_t L0AB_BUFFER_OFFSET_S8_16K = 16 * 1024; - static constexpr uint64_t L0AB_BUFFER_OFFSET_FP16_16K = 16 * 512; - static constexpr uint64_t L0C_BUFFER_OFFSET = 64 * 256; - -private: - __aicore__ inline void WeightDmaCopy(uint64_t s1gL1RealSize, const LIQCommon::RunInfo &runInfo); - __aicore__ inline void LoadKeyToL0b(uint64_t s2L0RealSize); - __aicore__ inline void LoadQueryToL0a(uint64_t s1gL1Offset, uint64_t s1gL1RealSize, uint64_t s1gL0RealSize); - __aicore__ inline void QueryNd2Nz(uint64_t s1gL1RealSize, const LIQCommon::RunInfo &runInfo); - __aicore__ inline void KeyNd2NzForPA(uint64_t s2L1RealSize, uint64_t s2GmOffset, const LIQCommon::RunInfo &runInfo); - __aicore__ inline void KeyNd2Nz(uint64_t s2L1RealSize, const MmInfo &mmInfo, const LIQCommon::RunInfo &runInfo); - __aicore__ inline void FixpSToL1(uint64_t s1gL0RealSize, uint64_t s2L0RealSize); - __aicore__ inline void LoadSToL0b(uint64_t s1gL1RealSize, uint64_t s2L0RealSize, uint64_t sL1BufIdx, - int64_t mStartPt); - __aicore__ inline void LoadWeightToL0a(uint64_t s1gL1Offset); - __aicore__ inline void ComputeWs(uint64_t s1gL0RealSize, uint64_t s2L0RealSize, int64_t s1gOffset); - __aicore__ inline void FixpResToGm(uint64_t s1L0RealCount, uint64_t s2L0RealSize, uint64_t s1GmOffset, - uint64_t s2GmOffset, const LIQCommon::RunInfo &runInfo); - __aicore__ inline void ComputeQk(uint64_t s1gL0RealSize, uint64_t s2L0RealSize); - __aicore__ inline void ProcessWs(uint64_t s1gL0RealSize, uint64_t s1gL1Offset, uint64_t sL1BufIdx, - const MmInfo &mmInfo, const LIQCommon::RunInfo &runInfo); - __aicore__ inline void ProcessQk(uint64_t s1gL0RealSize, uint64_t s1gL1Offset, uint64_t s1L0LoopCnt, - const MmInfo &mmInfo, const LIQCommon::RunInfo &runInfo); - __aicore__ inline void CalcMmInfo(MmInfo &mmInfo, uint64_t loopIdx, uint64_t s1L0LoopCnt, const MmInfo &lastMmInfo, - const LIQCommon::RunInfo &runInfo); - static constexpr LI_LAYOUT Q_LAYOUT_T = LIQT::layout; - static constexpr LI_LAYOUT K_LAYOUT_T = LIQT::keyLayout; - GlobalTensor blkTableGm_; - GlobalTensor keyGm_; - GlobalTensor queryGm_; - GlobalTensor weightGm_; - GlobalTensor mm1ResGm_; - - TBuf bufQL1_; - LocalTensor queryL1_; - TBuf bufKeyL1_; - LocalTensor keyL1_; - TBuf bufWeightL1_; - LocalTensor weightL1_; - TBuf bufSL1_; - LocalTensor sL1_; - - TBuf bufL0A_; - LocalTensor l0a_; - TBuf bufL0B_; - LocalTensor l0b_; - - TBuf bufL0C_; - LocalTensor cL0_; - - uint64_t keyL1BufIdx_ = 0; - uint64_t qwL1Mte2BufIdx_ = 0; - uint64_t sL1BufIdx_ = 0; - uint64_t l0BufIdx_ = 0; - uint64_t l0cBufIdx_ = 0; - - ConstInfo constInfo_; -}; - -template -__aicore__ inline void LIQMatmul::InitParams(const ConstInfo &constInfo) -{ - constInfo_ = constInfo; -} - -template -__aicore__ inline void LIQMatmul::InitBuffers(TPipe *pipe) -{ - pipe->InitBuffer(bufQL1_, DOUBLE_BUF_NUM * S1G_BASIC_BLOCK_L1 * D_BASIC_BLOCK * sizeof(Q_T)); - queryL1_ = bufQL1_.Get(); - pipe->InitBuffer(bufKeyL1_, DOUBLE_BUF_NUM * S2_BASIC_BLOCK_L0 * D_BASIC_BLOCK * sizeof(K_T)); - keyL1_ = bufKeyL1_.Get(); - - pipe->InitBuffer(bufWeightL1_, DOUBLE_BUF_NUM * S1G_BASIC_BLOCK_L1 * BLOCK_CUBE * sizeof(half)); - weightL1_ = bufWeightL1_.Get(); - pipe->InitBuffer(bufSL1_, DOUBLE_BUF_NUM * S2_BASIC_BLOCK_L0 * S1G_BASIC_BLOCK_L0 * sizeof(half)); - sL1_ = bufSL1_.Get(); - - pipe->InitBuffer(bufL0A_, 64 * 1024); - l0a_ = bufL0A_.Get(); - pipe->InitBuffer(bufL0B_, 64 * 1024); - l0b_ = bufL0B_.Get(); - - pipe->InitBuffer(bufL0C_, 128 * 1024); - cL0_ = bufL0C_.Get(); -} - -template -__aicore__ inline void LIQMatmul::InitMm1GlobalTensor(const GlobalTensor &blkTableGm, - const GlobalTensor &keyGm, - const GlobalTensor &queryGm, - const GlobalTensor &mm1ResGm, - const GlobalTensor &weightWorkspaceGm) -{ - blkTableGm_ = blkTableGm; - keyGm_ = keyGm; - queryGm_ = queryGm; - mm1ResGm_ = mm1ResGm; - weightGm_ = weightWorkspaceGm; -} - -template -__aicore__ inline void LIQMatmul::ProcessWs(uint64_t s1gL0RealSize, uint64_t s1gL1Offset, uint64_t sL1BufIdx, - const MmInfo &mmInfo, const LIQCommon::RunInfo &runInfo) -{ - WaitFlag(FIX_M_EVENT + l0cBufIdx_ % DOUBLE_BUF_NUM); - for (int64_t s1gOffset = 0; s1gOffset < s1gL0RealSize; s1gOffset += constInfo_.gSize) { - WaitFlag(M_MTE1_EVENT + l0BufIdx_ % L0AB_BUF_NUM); - LoadSToL0b(s1gL0RealSize, mmInfo.s2L0RealSize, sL1BufIdx, s1gOffset); - LoadWeightToL0a(s1gOffset + s1gL1Offset); - - ComputeWs(s1gL0RealSize, mmInfo.s2L0RealSize, s1gOffset); - - SetFlag(M_MTE1_EVENT + l0BufIdx_ % L0AB_BUF_NUM); - l0BufIdx_++; - } - - FixpResToGm(s1gL0RealSize / constInfo_.gSize, mmInfo.s2L0RealSize, s1gL1Offset / constInfo_.gSize, - mmInfo.s2L0LoopId * S2_BASIC_BLOCK_L0, runInfo); - SetFlag(FIX_M_EVENT + l0cBufIdx_ % DOUBLE_BUF_NUM); - l0cBufIdx_++; -} - -template -__aicore__ inline void LIQMatmul::ProcessQk(uint64_t s1gL0RealSize, uint64_t s1gL1Offset, uint64_t s1L0LoopCnt, - const MmInfo &mmInfo, const LIQCommon::RunInfo &runInfo) -{ - if (mmInfo.s1gL0LoopId == 0) { - WaitFlag(KEY_MTE1_MTE2_EVENT + keyL1BufIdx_ % DOUBLE_BUF_NUM); - if constexpr (K_LAYOUT_T == LI_LAYOUT::PA_BSND) { - KeyNd2NzForPA(mmInfo.s2L0RealSize, runInfo.s2Idx * constInfo_.s2BaseSize + mmInfo.s2GmOffset, runInfo); - } else { - KeyNd2Nz(mmInfo.s2L0RealSize, mmInfo, runInfo); - } - - SetFlag(MTE2_MTE1_EVENT); - WaitFlag(MTE2_MTE1_EVENT); - } - - WaitFlag(M_MTE1_EVENT + l0BufIdx_ % L0AB_BUF_NUM); - LoadQueryToL0a(s1gL1Offset, runInfo.actMBaseSize, s1gL0RealSize); - LoadKeyToL0b(mmInfo.s2L0RealSize); - - if (mmInfo.s1gL0LoopId + 1 >= s1L0LoopCnt) { - SetFlag(KEY_MTE1_MTE2_EVENT + keyL1BufIdx_ % DOUBLE_BUF_NUM); - keyL1BufIdx_++; - } - - WaitFlag(FIX_M_EVENT + l0cBufIdx_ % DOUBLE_BUF_NUM); - ComputeQk(s1gL0RealSize, mmInfo.s2L0RealSize); - SetFlag(M_MTE1_EVENT + l0BufIdx_ % L0AB_BUF_NUM); - - FixpSToL1(s1gL0RealSize, mmInfo.s2L0RealSize); - SetFlag(FIX_M_EVENT + l0cBufIdx_ % DOUBLE_BUF_NUM); - l0BufIdx_++; - l0cBufIdx_++; -} - -template -__aicore__ inline void LIQMatmul::CalcMmInfo(MmInfo &mmInfo, uint64_t loopIdx, uint64_t s1L0LoopCnt, - const MmInfo &lastMmInfo, const LIQCommon::RunInfo &runInfo) -{ - mmInfo.s2L0LoopId = loopIdx / s1L0LoopCnt; - mmInfo.s1gL0LoopId = loopIdx % s1L0LoopCnt; - - if (mmInfo.s1gL0LoopId == 0) { - mmInfo.s2GmOffset = mmInfo.s2L0LoopId * S2_BASIC_BLOCK_L0; - mmInfo.s2L0RealSize = mmInfo.s2GmOffset + S2_BASIC_BLOCK_L0 > runInfo.actualSingleProcessSInnerSize - ? runInfo.actualSingleProcessSInnerSize - mmInfo.s2GmOffset - : S2_BASIC_BLOCK_L0; - } else { - mmInfo.s2L0RealSize = lastMmInfo.s2L0RealSize; - } -} - -template -__aicore__ inline void LIQMatmul::ComputeMm1(const LIQCommon::RunInfo &runInfo) -{ - if (runInfo.isFirstS2InnerLoop) { - WaitFlag(QW_MTE1_MTE2_EVENT + qwL1Mte2BufIdx_ % DOUBLE_BUF_NUM); - QueryNd2Nz(runInfo.actMBaseSize, runInfo); // 256 * 128 // L1BasicBlock - WeightDmaCopy(runInfo.actMBaseSize, runInfo); - } - int64_t loopIdx = 0; - int64_t s2L0LoopCnt = CeilDiv(runInfo.actualSingleProcessSInnerSize, S2_BASIC_BLOCK_L0); // 2048取128 - int64_t s1L0LoopCnt = CeilDiv(runInfo.actMBaseSize, S1G_BASIC_BLOCK_L0); // 256取128 - int64_t s1gL1Offset[2] = {0, static_cast(S1G_BASIC_BLOCK_L0)}; - int64_t s1gL0RealSize[2] = {s1L0LoopCnt > 1 ? static_cast(S1G_BASIC_BLOCK_L0) : runInfo.actMBaseSize, - runInfo.actMBaseSize - s1gL1Offset[1]}; - MmInfo mmInfo[2]; - CalcMmInfo(mmInfo[loopIdx & 1], loopIdx, s1L0LoopCnt, mmInfo[(loopIdx + 1) & 1], runInfo); - - ProcessQk(s1gL0RealSize[mmInfo[loopIdx & 1].s1gL0LoopId % s1L0LoopCnt], - s1gL1Offset[mmInfo[loopIdx & 1].s1gL0LoopId % s1L0LoopCnt], s1L0LoopCnt, mmInfo[loopIdx & 1], - runInfo); - - SetFlag(FIX_MTE1_EVENT + sL1BufIdx_ % DOUBLE_BUF_NUM); - sL1BufIdx_++; - loopIdx++; - - while (loopIdx < s2L0LoopCnt * s1L0LoopCnt) { - CalcMmInfo(mmInfo[loopIdx & 1], loopIdx, s1L0LoopCnt, mmInfo[(loopIdx + 1) & 1], runInfo); - - ProcessQk(s1gL0RealSize[mmInfo[loopIdx & 1].s1gL0LoopId % s1L0LoopCnt], - s1gL1Offset[mmInfo[loopIdx & 1].s1gL0LoopId % s1L0LoopCnt], s1L0LoopCnt, mmInfo[loopIdx & 1], - runInfo); - - SetFlag(FIX_MTE1_EVENT + sL1BufIdx_ % DOUBLE_BUF_NUM); - sL1BufIdx_++; - - WaitFlag(FIX_MTE1_EVENT + sL1BufIdx_ % DOUBLE_BUF_NUM); - - ProcessWs(s1gL0RealSize[mmInfo[(loopIdx + 1) & 1].s1gL0LoopId % s1L0LoopCnt], - s1gL1Offset[mmInfo[(loopIdx + 1) & 1].s1gL0LoopId % s1L0LoopCnt], sL1BufIdx_, - mmInfo[(loopIdx + 1) & 1], runInfo); - loopIdx++; - } - - WaitFlag(FIX_MTE1_EVENT + (sL1BufIdx_ + 1) % DOUBLE_BUF_NUM); - - ProcessWs(s1gL0RealSize[mmInfo[(loopIdx + 1) & 1].s1gL0LoopId % s1L0LoopCnt], - s1gL1Offset[mmInfo[(loopIdx + 1) & 1].s1gL0LoopId % s1L0LoopCnt], sL1BufIdx_ - 1, - mmInfo[(loopIdx + 1) & 1], runInfo); - - if (runInfo.isLastS2InnerLoop) { - SetFlag(QW_MTE1_MTE2_EVENT + qwL1Mte2BufIdx_ % DOUBLE_BUF_NUM); - qwL1Mte2BufIdx_++; - } -} - -// blkNum, blkSize, N2, D -template -__aicore__ inline void LIQMatmul::KeyNd2NzForPA(uint64_t s2L1RealSize, uint64_t s2GmOffset, - const LIQCommon::RunInfo &runInfo) -{ - uint64_t s2L1Offset = 0; - while (s2L1Offset < s2L1RealSize) { - uint64_t s2BlkId = (s2L1Offset + s2GmOffset) / constInfo_.kCacheBlockSize; - uint64_t s2BlkOffset = (s2L1Offset + s2GmOffset) % constInfo_.kCacheBlockSize; - uint64_t keyGmOffset = blkTableGm_.GetValue(runInfo.bIdx * constInfo_.maxBlockNumPerBatch + s2BlkId) * - constInfo_.kCacheBlockSize * constInfo_.kHeadNum * constInfo_.headDim + - s2BlkOffset * constInfo_.headDim; - uint64_t s2Mte2Size = s2L1RealSize - s2L1Offset; - s2Mte2Size = s2BlkOffset + s2Mte2Size >= constInfo_.kCacheBlockSize ? constInfo_.kCacheBlockSize - s2BlkOffset - : s2Mte2Size; - Nd2NzParams nd2nzPara; - nd2nzPara.ndNum = 1; - nd2nzPara.nValue = s2Mte2Size; // 行数 - nd2nzPara.dValue = constInfo_.headDim; - nd2nzPara.srcDValue = constInfo_.headDim; - nd2nzPara.dstNzC0Stride = CeilAlign(s2L1RealSize, (uint64_t)BLOCK_CUBE); // 对齐到16 单位block - nd2nzPara.dstNzNStride = 1; - nd2nzPara.srcNdMatrixStride = 0; - nd2nzPara.dstNzMatrixStride = 0; - DataCopy(keyL1_[(keyL1BufIdx_ % DOUBLE_BUF_NUM) * KEY_BUFFER_OFFSET + s2L1Offset * S8_BLOCK_CUBE], - keyGm_[keyGmOffset], nd2nzPara); - - s2L1Offset += s2Mte2Size; - } -} - -template -__aicore__ inline void LIQMatmul::KeyNd2Nz(uint64_t s2L1RealSize, const MmInfo &mmInfo, - const LIQCommon::RunInfo &runInfo) -{ - uint64_t dStride = constInfo_.headDim; - if constexpr (K_LAYOUT_T == LI_LAYOUT::BSND || K_LAYOUT_T == LI_LAYOUT::TND) { - dStride = constInfo_.headDim * constInfo_.kHeadNum; // constInfo_.kHeadNum - } - Nd2NzParams nd2nzPara; - nd2nzPara.ndNum = 1; - nd2nzPara.nValue = s2L1RealSize; // 行数 - nd2nzPara.dValue = constInfo_.headDim; - nd2nzPara.srcDValue = dStride; - nd2nzPara.dstNzC0Stride = CeilAlign(s2L1RealSize, (uint64_t)BLOCK_CUBE); // 对齐到16 单位block - nd2nzPara.dstNzNStride = 1; - nd2nzPara.srcNdMatrixStride = 0; - nd2nzPara.dstNzMatrixStride = 0; - // 默认一块buf最多放两份 - DataCopy(keyL1_[(keyL1BufIdx_ % DOUBLE_BUF_NUM) * KEY_BUFFER_OFFSET], - keyGm_[runInfo.tensorKeyOffset + mmInfo.s2GmOffset * constInfo_.headDim], nd2nzPara); -} - -// batch, s1, g, 1 -template -__aicore__ inline void LIQMatmul::WeightDmaCopy(uint64_t s1gL1RealSize, const LIQCommon::RunInfo &runInfo) -{ - DataCopyParams copyInParams; - copyInParams.blockCount = 1; - copyInParams.blockLen = s1gL1RealSize; - copyInParams.srcStride = 0; - copyInParams.dstStride = 0; - DataCopy(weightL1_[(qwL1Mte2BufIdx_ % DOUBLE_BUF_NUM) * WEIGHT_BUFFER_OFFSET], - weightGm_[runInfo.loop % DOUBLE_BUF_NUM * BLOCK_CUBE * constInfo_.mBaseSize], copyInParams); -} - -// batch, s1, n2, g, d -template -__aicore__ inline void LIQMatmul::QueryNd2Nz(uint64_t s1gL1RealSize, const LIQCommon::RunInfo &runInfo) -{ - Nd2NzParams nd2nzPara; - nd2nzPara.ndNum = 1; - nd2nzPara.nValue = s1gL1RealSize; // 行数 - nd2nzPara.dValue = constInfo_.headDim; - nd2nzPara.srcDValue = constInfo_.headDim; - nd2nzPara.dstNzC0Stride = CeilAlign(s1gL1RealSize, (uint64_t)BLOCK_CUBE); // 对齐到16 单位block - nd2nzPara.dstNzNStride = 1; - nd2nzPara.srcNdMatrixStride = 0; - nd2nzPara.dstNzMatrixStride = 0; - // 默认一块buf最多放两份 - DataCopy(queryL1_[(qwL1Mte2BufIdx_ % DOUBLE_BUF_NUM) * QUERY_BUFFER_OFFSET], queryGm_[runInfo.tensorQueryOffset], - nd2nzPara); -} - -// s1g, d -template -__aicore__ inline void LIQMatmul::LoadQueryToL0a(uint64_t s1gL1Offset, uint64_t s1gL1RealSize, - uint64_t s1gL0RealSize) -{ - LoadData3DParamsV2 loadData3DParams; - // SetFmatrixParams - loadData3DParams.l1H = CeilDiv(s1gL1RealSize, BLOCK_CUBE); // Hin=M1=8 - loadData3DParams.l1W = BLOCK_CUBE; // Win=M0 - loadData3DParams.channelSize = constInfo_.headDim; // Cin=K - - loadData3DParams.padList[0] = 0; - loadData3DParams.padList[1] = 0; - loadData3DParams.padList[2] = 0; - loadData3DParams.padList[3] = 255; // 尾部数据不影响滑窗的结果 - - // SetLoadToA0Params - loadData3DParams.mExtension = CeilAlign(s1gL0RealSize, BLOCK_CUBE); // M height维度目的 - loadData3DParams.kExtension = constInfo_.headDim; // K width维度目的 - loadData3DParams.mStartPt = s1gL1Offset; - loadData3DParams.kStartPt = 0; - loadData3DParams.strideW = 1; - loadData3DParams.strideH = 1; - loadData3DParams.filterW = 1; - loadData3DParams.filterSizeW = (1 >> 8) & 255; - loadData3DParams.filterH = 1; - loadData3DParams.filterSizeH = (1 >> 8) & 255; - loadData3DParams.dilationFilterW = 1; - loadData3DParams.dilationFilterH = 1; - loadData3DParams.enTranspose = 0; - loadData3DParams.fMatrixCtrl = 0; - - LoadData(l0a_[(l0BufIdx_ % L0AB_BUF_NUM) * L0AB_BUFFER_OFFSET_S8_16K], - queryL1_[(qwL1Mte2BufIdx_ % DOUBLE_BUF_NUM) * QUERY_BUFFER_OFFSET], - loadData3DParams); -} - -// s1, g, s2 --> 2 * 64* 128 -template -__aicore__ inline void LIQMatmul::LoadSToL0b(uint64_t s1gL1RealSize, uint64_t s2L0RealSize, uint64_t sL1BufIdx, - int64_t mStartPt) -{ - LoadData3DParamsV2 loadData3DParams; - // SetFmatrixParams - loadData3DParams.l1H = S1G_BASIC_BLOCK_L0 / BLOCK_CUBE; // Hin=M1=8 - loadData3DParams.l1W = BLOCK_CUBE; // Win=M0 - loadData3DParams.channelSize = CeilAlign(s2L0RealSize, BLOCK_CUBE); // Cin=K - - loadData3DParams.padList[0] = 0; - loadData3DParams.padList[1] = 0; - loadData3DParams.padList[2] = 0; - loadData3DParams.padList[3] = 255; // 尾部数据不影响滑窗的结果 - - // SetLoadToA0Params - loadData3DParams.mExtension = constInfo_.gSize; // M height维度目的 - loadData3DParams.kExtension = CeilAlign(s2L0RealSize, BLOCK_CUBE); // K width维度目的 - loadData3DParams.kStartPt = 0; - loadData3DParams.strideW = 1; - loadData3DParams.strideH = 1; - loadData3DParams.filterW = 1; - loadData3DParams.filterSizeW = (1 >> 8) & 255; - loadData3DParams.filterH = 1; - loadData3DParams.filterSizeH = (1 >> 8) & 255; - loadData3DParams.dilationFilterW = 1; - loadData3DParams.dilationFilterH = 1; - loadData3DParams.enTranspose = 1; - loadData3DParams.fMatrixCtrl = 0; - - loadData3DParams.mStartPt = mStartPt; - LoadData( - l0b_.template ReinterpretCast()[(l0BufIdx_ % L0AB_BUF_NUM) * L0AB_BUFFER_OFFSET_FP16_16K], - sL1_[(sL1BufIdx % DOUBLE_BUF_NUM) * SL1_BUFFER_OFFSET], loadData3DParams); -} - -// s1,g,1(16), 2,64,16 -template -__aicore__ inline void LIQMatmul::LoadWeightToL0a(uint64_t s1gL1Offset) -{ - LoadData2DParams loadData2DParams; - loadData2DParams.startIndex = 0; - loadData2DParams.repeatTimes = CeilDiv(constInfo_.gSize, BLOCK_CUBE); - loadData2DParams.srcStride = 1; - loadData2DParams.dstGap = 0; - loadData2DParams.ifTranspose = true; - LoadData(l0a_.template ReinterpretCast()[(l0BufIdx_ % L0AB_BUF_NUM) * L0AB_BUFFER_OFFSET_FP16_16K], - weightL1_[(qwL1Mte2BufIdx_ % DOUBLE_BUF_NUM) * WEIGHT_BUFFER_OFFSET + s1gL1Offset* BLOCK_CUBE], - loadData2DParams); -} - -// s2, d -> 128,128 -template -__aicore__ inline void LIQMatmul::LoadKeyToL0b(uint64_t s2L0RealSize) -{ - LoadData2DParams loadData2DParams; - loadData2DParams.startIndex = 0; - loadData2DParams.repeatTimes = CeilDiv(s2L0RealSize, BLOCK_CUBE) * CeilDiv(constInfo_.headDim, S8_BLOCK_CUBE); - loadData2DParams.srcStride = 1; - loadData2DParams.dstGap = 0; - loadData2DParams.ifTranspose = false; - LoadData(l0b_[(l0BufIdx_ % L0AB_BUF_NUM) * L0AB_BUFFER_OFFSET_S8_16K], - keyL1_[(keyL1BufIdx_ % DOUBLE_BUF_NUM) * KEY_BUFFER_OFFSET], loadData2DParams); -} - -// A: s1,g,1(16) B: s1,g,s2 C: s1, 1(16), s2 -template -__aicore__ inline void LIQMatmul::ComputeWs(uint64_t s1gL0RealSize, uint64_t s2L0RealSize, int64_t s1gOffset) -{ - SetFlag(MTE1_M_EVENT); - WaitFlag(MTE1_M_EVENT); - MmadParams mmadParams; - mmadParams.m = BLOCK_CUBE; - mmadParams.n = s2L0RealSize; - mmadParams.k = constInfo_.gSize; - mmadParams.cmatrixInitVal = true; - mmadParams.cmatrixSource = false; - Mmad(cL0_.template ReinterpretCast()[(l0cBufIdx_ % DOUBLE_BUF_NUM) * L0C_BUFFER_OFFSET + - s1gOffset * S2_BASIC_BLOCK_L0], - l0a_.template ReinterpretCast()[(l0BufIdx_ % L0AB_BUF_NUM) * L0AB_BUFFER_OFFSET_FP16_16K], - l0b_.template ReinterpretCast()[(l0BufIdx_ % L0AB_BUF_NUM) * L0AB_BUFFER_OFFSET_FP16_16K], - mmadParams); -} - -template -__aicore__ inline void LIQMatmul::ComputeQk(uint64_t s1gL0RealSize, uint64_t s2L0RealSize) -{ - SetFlag(MTE1_M_EVENT); - WaitFlag(MTE1_M_EVENT); - - MmadParams mmadParams; - mmadParams.m = CeilAlign(s1gL0RealSize, BLOCK_CUBE); - mmadParams.n = s2L0RealSize; - mmadParams.k = constInfo_.headDim; - mmadParams.cmatrixInitVal = true; - mmadParams.cmatrixSource = false; - Mmad(cL0_[(l0cBufIdx_ % DOUBLE_BUF_NUM) * L0C_BUFFER_OFFSET], - l0a_[(l0BufIdx_ % L0AB_BUF_NUM) * L0AB_BUFFER_OFFSET_S8_16K], - l0b_[(l0BufIdx_ % L0AB_BUF_NUM) * L0AB_BUFFER_OFFSET_S8_16K], mmadParams); - if ((mmadParams.m / 16) * (mmadParams.n / 16) < 10) { - PipeBarrier(); - } -} - -template -__aicore__ inline void LIQMatmul::FixpSToL1(uint64_t s1gL0RealSize, uint64_t s2L0RealSize) -{ - SetFlag(M_FIX_EVENT); - WaitFlag(M_FIX_EVENT); - DataCopyCO12DstParams params; - params.mSize = CeilAlign(s1gL0RealSize, BLOCK_CUBE); - params.nSize = CeilAlign(s2L0RealSize, BLOCK_CUBE); - params.dstStride = S1G_BASIC_BLOCK_L0; - params.srcStride = params.mSize; - params.quantPre = QuantMode_t::DEQF16; - params.reluPre = 1; - params.channelSplit = 0; - params.nz2ndEn = 0; - SetFixpipePreQuantFlag(0x3a800000); - DataCopy(sL1_[(sL1BufIdx_ % DOUBLE_BUF_NUM) * SL1_BUFFER_OFFSET], - cL0_[(l0cBufIdx_ % DOUBLE_BUF_NUM) * L0C_BUFFER_OFFSET], params); -} - -template -__aicore__ inline void LIQMatmul::FixpResToGm(uint64_t s1L0RealCount, uint64_t s2L0RealSize, uint64_t s1GmOffset, - uint64_t s2GmOffset, const LIQCommon::RunInfo &runInfo) -{ - SetFlag(M_FIX_EVENT); - WaitFlag(M_FIX_EVENT); - - AscendC::DataCopyCO12DstParams intriParams; - intriParams.mSize = 1; - intriParams.nSize = s2L0RealSize; - intriParams.dstStride = constInfo_.s2BaseSize; - intriParams.srcStride = 16; - // set mode according to dtype - intriParams.quantPre = QuantMode_t::NoQuant; - intriParams.nz2ndEn = true; - intriParams.reluPre = 0; - AscendC::SetFixpipeNz2ndFlag(s1L0RealCount, CeilDiv(constInfo_.gSize, BLOCK_CUBE) * S2_BASIC_BLOCK_L0 / BLOCK_CUBE, - 2048); - AscendC::DataCopy(mm1ResGm_[(runInfo.loop % 2) * constInfo_.mBaseSize / constInfo_.gSize * constInfo_.s2BaseSize + - s1GmOffset * intriParams.dstStride + s2GmOffset], - cL0_.template ReinterpretCast()[(l0cBufIdx_ % DOUBLE_BUF_NUM) * L0C_BUFFER_OFFSET], - intriParams); -} - -template -__aicore__ inline void LIQMatmul::AllocEventID() -{ - SetFlag(KEY_MTE1_MTE2_EVENT + 0); - SetFlag(KEY_MTE1_MTE2_EVENT + 1); - SetFlag(KEY_MTE1_MTE2_EVENT + 2); - - SetFlag(QW_MTE1_MTE2_EVENT + 0); - SetFlag(QW_MTE1_MTE2_EVENT + 1); - - SetFlag(M_MTE1_EVENT + 0); - SetFlag(M_MTE1_EVENT + 1); - SetFlag(M_MTE1_EVENT + 2); - SetFlag(M_MTE1_EVENT + 3); - - SetFlag(FIX_M_EVENT + 0); - SetFlag(FIX_M_EVENT + 1); -} - -template -__aicore__ inline void LIQMatmul::FreeEventID() -{ - WaitFlag(KEY_MTE1_MTE2_EVENT + 0); - WaitFlag(KEY_MTE1_MTE2_EVENT + 1); - WaitFlag(KEY_MTE1_MTE2_EVENT + 2); - - WaitFlag(QW_MTE1_MTE2_EVENT + 0); - WaitFlag(QW_MTE1_MTE2_EVENT + 1); - - WaitFlag(M_MTE1_EVENT + 0); - WaitFlag(M_MTE1_EVENT + 1); - WaitFlag(M_MTE1_EVENT + 2); - WaitFlag(M_MTE1_EVENT + 3); - - WaitFlag(FIX_M_EVENT + 0); - WaitFlag(FIX_M_EVENT + 1); -} -} // namespace LIQKernel -#endif \ No newline at end of file diff --git a/csrc/lightning_indexer_quant/op_kernel/lightning_indexer_quant_service_vector.h b/csrc/lightning_indexer_quant/op_kernel/lightning_indexer_quant_service_vector.h deleted file mode 100644 index 2588998c..00000000 --- a/csrc/lightning_indexer_quant/op_kernel/lightning_indexer_quant_service_vector.h +++ /dev/null @@ -1,665 +0,0 @@ -/** - * This program is free software, you can redistribute it and/or modify it. - * Copyright (c) 2025 Huawei Technologies Co., Ltd. - * This file is a part of the CANN Open Software. - * Licensed under CANN Open Software License Agreement Version 2.0 (the "License"). - * Please refer to the License for details. You may not use this file except in compliance with the License. - * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. - * See LICENSE in the root of the software repository for the full text of the License. - */ - -/*! - * \file lightning_indexer_quant_service_vector.h - * \brief - */ -#ifndef LIGHTNING_INDEXER_QUANT_SERVICE_VECTOR_H -#define LIGHTNING_INDEXER_QUANT_SERVICE_VECTOR_H - -#include "kernel_operator.h" -#include "kernel_operator_list_tensor_intf.h" -#include "kernel_tiling/kernel_tiling.h" -#include "lib/matmul_intf.h" -#include "lib/matrix/matmul/tiling.h" -#include "lightning_indexer_quant_common.h" -#include "lightning_indexer_quant_vector.h" - -namespace LIQKernel { -using namespace LIQCommon; -using namespace LIQServiceVec; -constexpr uint32_t BASE_TOPK = 2048; -constexpr uint32_t BASE_TOPK_VALUE_IDX_SIZE = 4096; -constexpr uint32_t LD_PARAM_NUM = 16; - -template -class LIQVector { -public: - // =================================类型定义区================================= - static constexpr LI_LAYOUT Q_LAYOUT_T = LIQT::layout; - static constexpr LI_LAYOUT K_LAYOUT_T = LIQT::keyLayout; - static constexpr bool PAGE_ATTENTION = LIQT::pageAttention; - // MM输出数据类型, 当前只支持float - using MM1_OUT_T = float; - - __aicore__ inline LIQVector(){}; - __aicore__ inline void ProcessVec0(const LIQCommon::RunInfo &info); - __aicore__ inline void ProcessVec1(const LIQCommon::RunInfo &info); - __aicore__ inline void ProcessLD(); - __aicore__ inline void InitBuffers(TPipe *pipe); - __aicore__ inline void InitParams(const struct LIQCommon::ConstInfo &constInfo, - const LIQTilingData *__restrict tilingData); - __aicore__ inline void InitVecWorkspaceTensor(GlobalTensor vec0OutGm, GlobalTensor mm1ResGm, - GlobalTensor vec1ResGm, GlobalTensor vec1ParamGm); - __aicore__ inline void InitVecInputTensor(GlobalTensor weightsGm, GlobalTensor qScaleGm, - GlobalTensor kScaleGm, GlobalTensor indiceOutGm, - GlobalTensor blockTableGm); - __aicore__ inline void CleanInvalidOutput(int64_t invalidS1offset); - __aicore__ inline void AllocEventID(); - __aicore__ inline void FreeEventID(); - __aicore__ inline void InitLDBuffers(TPipe *pipe); - -protected: - GlobalTensor mm1ResGm; - GlobalTensor vec1ResGm; - GlobalTensor vec1ParamGm; - GlobalTensor weightsGm; - GlobalTensor qScaleGm; - GlobalTensor kScaleGm; - GlobalTensor vec0OutGm; - GlobalTensor indiceOutGm; - GlobalTensor blockTableGm; - // =================================常量区================================= - -private: - __aicore__ inline void GetKeyScale(const LIQCommon::RunInfo &runInfo, const LocalTensor &resUb, - int64_t batchId, int64_t startS2, int64_t getLen); - // ================================Local Buffer区==================================== - // queue - TQue inQueue_; - TQue outQueue_; - - // tmp buff for vector - TBuf sortOutBuf_; - TBuf indexBuf_; - TBuf paramBuf_; - TBuf tmpBuf_; - - // tmp buff for LD - TBuf<> ldToBeMrgBuf_; - TBuf<> ldTmpBuf_; - TBuf<> ldOutValueBuf_; - TBuf<> ldOutIdxBuf_; - - LocalTensor globalTopkIndice_; - LocalTensor globalTopkUb_; - - int32_t blockId_ = -1; - // para for vector - int32_t groupInner_ = 0; - int32_t globalTopkNum_ = 0; - int64_t blockS2StartIdx_ = 0; - int32_t gSize_ = 0; - int32_t kSeqSize_ = 0; - int32_t kHeadNum_ = 0; - int32_t qHeadNum_ = 0; - int32_t s1BaseSize_ = 0; - int32_t s2BaseSize_ = 0; - int32_t kCacheBlockSize_ = 0; - int32_t maxBlockNumPerBatch_ = 0; - - // para for LD - uint32_t mrgListNum_ = 4; - uint32_t paramNum_ = 16; - - struct LIQCommon::ConstInfo constInfo_; -}; - -template -__aicore__ inline void LIQVector::GetKeyScale(const LIQCommon::RunInfo &runInfo, const LocalTensor &resUb, - int64_t batchId, int64_t startS2, int64_t getLen) -{ - // startS2一定能整除kCacheBlockSize_ - AscendC::DataCopyPadExtParams padParams{false, 0, 0, 0}; - AscendC::DataCopyExtParams copyInParams; - if constexpr (PAGE_ATTENTION) { - int32_t startBlockTableIdx = startS2 / kCacheBlockSize_; - int32_t startBlockTableOffset = startS2 % kCacheBlockSize_; - int32_t blockTableBatchOffset = batchId * maxBlockNumPerBatch_; - copyInParams.blockCount = 1; - copyInParams.srcStride = 0; - copyInParams.dstStride = 0; - copyInParams.rsv = 0; - int32_t resUbBaseOffset = 0; - if (startBlockTableOffset > 0) { - int32_t firstPartLen = - kCacheBlockSize_ - startBlockTableOffset > getLen ? getLen : kCacheBlockSize_ - startBlockTableOffset; - copyInParams.blockLen = firstPartLen * sizeof(half); - int32_t blockId = blockTableGm.GetValue(blockTableBatchOffset + startBlockTableIdx); - SetWaitFlag(HardEvent::S_MTE2); - AscendC::DataCopyPad(resUb, kScaleGm[blockId * kCacheBlockSize_ + startBlockTableOffset], - copyInParams, padParams); - startBlockTableIdx++; - getLen = getLen - firstPartLen; - resUbBaseOffset = firstPartLen; - } - int32_t getLoopNum = CeilDiv(getLen, kCacheBlockSize_); - copyInParams.blockLen = kCacheBlockSize_ * sizeof(half); - for (int32_t i = 0; i < getLoopNum; i++) { - if (i == getLoopNum - 1) { - copyInParams.blockLen = (getLen - i * kCacheBlockSize_) * sizeof(half); - } - int32_t blockId = blockTableGm.GetValue(blockTableBatchOffset + startBlockTableIdx + i); - SetWaitFlag(HardEvent::S_MTE2); - AscendC::DataCopyPad(resUb[resUbBaseOffset + i * kCacheBlockSize_], kScaleGm[blockId * kCacheBlockSize_], - copyInParams, padParams); - } - } else { - copyInParams.blockCount = 1; - copyInParams.blockLen = getLen * sizeof(half); - copyInParams.srcStride = 0; - copyInParams.dstStride = 0; - copyInParams.rsv = 0; - AscendC::DataCopyPad(resUb, kScaleGm[runInfo.tensorKeyScaleOffset], copyInParams, padParams); - } -} - -template -__aicore__ inline void LIQVector::InitBuffers(TPipe *pipe) -{ - pipe->InitBuffer(paramBuf_, LD_PARAM_NUM * sizeof(int64_t)); // 1 KB - pipe->InitBuffer(inQueue_, 2, s2BaseSize_ * sizeof(float) * 2); // 32KB - pipe->InitBuffer(outQueue_, 1, BASE_TOPK * sizeof(float)); // 8 KB - pipe->InitBuffer(indexBuf_, s2BaseSize_ * sizeof(int32_t)); // 8 KB - pipe->InitBuffer(tmpBuf_, 64 * 1024); // 64KB - pipe->InitBuffer(sortOutBuf_, CeilDiv(s1BaseSize_, 2) * BASE_TOPK_VALUE_IDX_SIZE * sizeof(float)); // 32KB - - globalTopkIndice_ = indexBuf_.Get(); - globalTopkUb_ = sortOutBuf_.Get(); - globalTopkNum_ = 0; - - // 基本块执行前初始化UB和GM - // step1. 初始化一个有序索引 0 - s2BaseSize_ - ArithProgression(globalTopkIndice_, 0, 1, s2BaseSize_); - // step2. globalTopkUb_ [CeilDiv(s1BaseSize_, 2), BASE_TOPK, 2] -inf,-1 - InitSortOutBuf(globalTopkUb_, CeilDiv(s1BaseSize_, 2) * BASE_TOPK_VALUE_IDX_SIZE); - - // step3. 初始化vec1ParamGm,是否进行LD的标志位设为-1(needFd=-1) - // vec1ResIn32Gm = [aic, 2, s1BaseSize_, 16] int32 - // ws清零 [needFd, s2AcSeq, s2Start, s2End, isS2End, bn2idx, s1Idx, ......] - LocalTensor tmpfBuff = outQueue_.AllocTensor(); - Duplicate(tmpfBuff.template ReinterpretCast(), -1, 2 * (s1BaseSize_ / 2) * paramNum_ * 2); - SetWaitFlag(HardEvent::V_MTE3); - int64_t wsInfoOffset = (blockId_ / 2) * s1BaseSize_ * 2 * paramNum_ + // 2个AIV共同地址偏移 - (blockId_ % 2) * (s1BaseSize_ / 2) * 2 * paramNum_; // 每个AIV的地址偏移,S1方向 - DataCopyPad(vec1ParamGm[wsInfoOffset], tmpfBuff.template ReinterpretCast(), - {1, static_cast((s1BaseSize_ / 2) * 2 * paramNum_ * sizeof(int64_t)), 0, 0}); - SetWaitFlag(HardEvent::MTE3_V); - outQueue_.FreeTensor(tmpfBuff); -} - -template -__aicore__ inline void LIQVector::InitLDBuffers(TPipe *pipe) -{ - pipe->Reset(); - pipe->InitBuffer(ldToBeMrgBuf_, BASE_TOPK_VALUE_IDX_SIZE * mrgListNum_ * sizeof(float)); - pipe->InitBuffer(ldTmpBuf_, BASE_TOPK_VALUE_IDX_SIZE * mrgListNum_ * sizeof(float)); - pipe->InitBuffer(ldOutValueBuf_, BASE_TOPK * sizeof(float)); - pipe->InitBuffer(ldOutIdxBuf_, BASE_TOPK * sizeof(int32_t)); -} - -template -__aicore__ inline void LIQVector::InitParams(const struct LIQCommon::ConstInfo &constInfo, - const LIQTilingData *__restrict tilingData) -{ - this->constInfo_ = constInfo; - blockS2StartIdx_ = 0; - gSize_ = constInfo.gSize; - kSeqSize_ = constInfo.kSeqSize; - // define N2 para - kHeadNum_ = constInfo.kHeadNum; - qHeadNum_ = constInfo.qHeadNum; - // define MMBase para - s1BaseSize_ = constInfo.s1BaseSize; // 4 - s2BaseSize_ = constInfo.s2BaseSize; // 2048 - kCacheBlockSize_ = constInfo.kCacheBlockSize; - maxBlockNumPerBatch_ = constInfo.maxBlockNumPerBatch; - blockId_ = GetBlockIdx(); -} - -template -__aicore__ inline void LIQVector::InitVecInputTensor(GlobalTensor weightsGm, GlobalTensor qScaleGm, - GlobalTensor kScaleGm, - GlobalTensor indiceOutGm, - GlobalTensor blockTableGm) -{ - this->weightsGm = weightsGm; - this->qScaleGm = qScaleGm; - this->kScaleGm = kScaleGm; - this->indiceOutGm = indiceOutGm; - this->blockTableGm = blockTableGm; -} - -template -__aicore__ inline void LIQVector::InitVecWorkspaceTensor(GlobalTensor vec0OutGm, - GlobalTensor mm1ResGm, - GlobalTensor vec1ResGm, - GlobalTensor vec1ParamGm) -{ - this->mm1ResGm = mm1ResGm; - this->vec1ResGm = vec1ResGm; - this->vec0OutGm = vec0OutGm; - this->vec1ParamGm = vec1ParamGm; -} - -template -__aicore__ inline void LIQVector::AllocEventID() -{ -} - -template -__aicore__ inline void LIQVector::FreeEventID() -{ -} - -template -__aicore__ inline void LIQVector::CleanInvalidOutput(int64_t invalidS1offset) -{ - // init -1 and copy to output - LocalTensor valueULocal = outQueue_.AllocTensor(); - LocalTensor idxULocal1 = valueULocal.template ReinterpretCast(); - Duplicate(idxULocal1, constInfo_.INVALID_IDX, constInfo_.sparseCount); - outQueue_.EnQue(valueULocal); - valueULocal = outQueue_.DeQue(); - LIQServiceVec::CopyOut(indiceOutGm[invalidS1offset], idxULocal1, constInfo_.sparseCount); - outQueue_.FreeTensor(valueULocal); -} - -template -__aicore__ inline void LIQVector::ProcessVec0(const LIQCommon::RunInfo &info) -{ - // 只需要一个v核做 - if (blockId_ % 2 != 0) { - return; - } - int32_t cuBaseS1Idx = info.gS1Idx * s1BaseSize_; - // 计算输出w基地址偏移 偶数循环 -> 0 + aic_offset 奇数循环 -> 4*64 + aic_offset - int64_t vec0OutGmOffset = (info.loop % 2) * ((s1BaseSize_ * gSize_ * BLOCK_CUBE)); - // 计算输入weight的地址偏移,qScale的地址偏移与weight相同 - int64_t weightGmOffset = info.tensorWeightsOffset + cuBaseS1Idx * qHeadNum_; - // 当前需要计算的S1行数,处理尾块场景 - int32_t cuS1ProcNum = cuBaseS1Idx + s1BaseSize_ > info.actS1Size ? info.actS1Size % s1BaseSize_ : s1BaseSize_; - int32_t cuProcEleNum = cuS1ProcNum * gSize_; - - LocalTensor inWeightsUb = inQueue_.AllocTensor(); - LocalTensor inQScaleUb = inWeightsUb[cuProcEleNum]; - AscendC::DataCopyPadExtParams padParams{false, 0, 0, 0}; - AscendC::DataCopyExtParams copyInParams; - copyInParams.blockCount = 1; - copyInParams.blockLen = cuProcEleNum * sizeof(half); - copyInParams.srcStride = 0; - copyInParams.dstStride = 0; - copyInParams.rsv = 0; - AscendC::DataCopyPad(inWeightsUb, weightsGm[weightGmOffset], copyInParams, padParams); - AscendC::DataCopyPad(inQScaleUb, qScaleGm[weightGmOffset], copyInParams, padParams); - - inQueue_.EnQue(inWeightsUb); - inWeightsUb = inQueue_.DeQue(); - AscendC::Mul(inWeightsUb, inWeightsUb, inQScaleUb, cuProcEleNum); - PipeBarrier(); - LocalTensor resUb = outQueue_.AllocTensor(); - AscendC::Brcb(resUb, inWeightsUb, static_cast(cuProcEleNum / 8), {1, 8}); - inQueue_.FreeTensor(inWeightsUb); - - outQueue_.EnQue(resUb); - resUb = outQueue_.DeQue(); - AscendC::DataCopyParams copyOutParams; - copyOutParams.blockCount = 1; - copyOutParams.blockLen = cuProcEleNum * BLOCK_CUBE * sizeof(half); - copyOutParams.srcStride = 0; - copyOutParams.dstStride = 0; - AscendC::DataCopyPad(vec0OutGm[vec0OutGmOffset], resUb, copyOutParams); - outQueue_.FreeTensor(resUb); -} - -template -__aicore__ inline void LIQVector::ProcessVec1(const LIQCommon::RunInfo &info) -{ - int32_t cuBaseS1Idx = info.gS1Idx * s1BaseSize_; - int32_t cuBaseS2Idx = info.s2Idx * s2BaseSize_; - - // 计算基本块基地址偏移 偶数循环 -> 0 + aic_offset 奇数循环 -> 4*2048 + aic_offset - int64_t mmGmOffset = (info.loop % 2) * (s1BaseSize_ * s2BaseSize_); - - // cuS1BeginIdxPerAiv: 每个AIV的S1起始偏移 - int32_t cuS1BeginIdxPerAiv = cuBaseS1Idx; - int32_t cuS1ProcNum = - cuS1BeginIdxPerAiv + s1BaseSize_ > info.actS1Size ? info.actS1Size % s1BaseSize_ : s1BaseSize_; - // cuS1ProcNumPerAiv: 每个AIv的S1计算量 - int32_t cuS1ProcNumPerAiv = blockId_ % 2 == 0 ? CeilDiv(cuS1ProcNum, 2) : (cuS1ProcNum / 2); - cuS1BeginIdxPerAiv += (blockId_ % 2) * CeilDiv(cuS1ProcNum, 2); - // 基本块基地址偏移奇数核加一个S1地址偏移 - mmGmOffset += (blockId_ % 2) * CeilDiv(cuS1ProcNum, 2) * s2BaseSize_; - // 非首个基本块, M(S1)轴发生切换需要初始化 - if (info.loop != 0 && info.s2Idx == 0) { - // globalTopkUb_ value,index=-inf,-1 - InitSortOutBuf(globalTopkUb_, CeilDiv(s1BaseSize_, 2) * BASE_TOPK_VALUE_IDX_SIZE); - blockS2StartIdx_ = 0; - } else if (info.loop == 0) { - blockS2StartIdx_ = info.s2Idx; - } - // cuRealAcSeq: 当前基本块S1对应的AcSeq - int32_t cuRealAcSeq = info.actS2Size; - if (constInfo_.attenMaskFlag) { - // attenMask true场景 - cuRealAcSeq = info.actS2Size - (info.actS1Size - cuS1BeginIdxPerAiv); - } - - // LD输出S1方向偏移,保证2个Vector输出的内容连续 - uint32_t ldS1Offset = (blockId_ % 2 == 0) ? s1BaseSize_ / 2 - cuS1ProcNumPerAiv : 0; - for (int innerS1Idx = 0; innerS1Idx < cuS1ProcNumPerAiv; innerS1Idx++) { - if (constInfo_.attenMaskFlag) { - cuRealAcSeq += 1; - } - int32_t cuS2Len = cuBaseS2Idx + s2BaseSize_ >= cuRealAcSeq ? cuRealAcSeq - cuBaseS2Idx : s2BaseSize_; - int32_t cuS1Idx = cuS1BeginIdxPerAiv + innerS1Idx; - if (cuRealAcSeq > 0 && cuS2Len > 0) { - int32_t cuS2LenVecAlign = CeilDiv(cuS2Len, s2BaseSize_) * s2BaseSize_; - LocalTensor mmInUb = inQueue_.AllocTensor(); - LocalTensor kScaleUb = mmInUb[cuS2LenVecAlign]; - LocalTensor kScaleTUb = kScaleUb.template ReinterpretCast()[cuS2LenVecAlign]; - AscendC::DataCopyPadExtParams padParams{false, 0, 0, 0}; - AscendC::DataCopyPadExtParams padTParams{false, 0, 0, 0}; - AscendC::DataCopyExtParams copyInParams; - copyInParams.blockCount = 1; - copyInParams.blockLen = cuS2Len * sizeof(float); - copyInParams.srcStride = 0; - copyInParams.dstStride = 0; - copyInParams.rsv = 0; - AscendC::DataCopyPad(mmInUb, mm1ResGm[mmGmOffset + innerS1Idx * s2BaseSize_], copyInParams, padParams); - GetKeyScale(info, kScaleTUb, info.bIdx, cuBaseS2Idx, cuS2Len); - inQueue_.EnQue(mmInUb); - mmInUb = inQueue_.DeQue(); - AscendC::Cast(kScaleUb, kScaleTUb, RoundMode::CAST_NONE, cuS2Len); - PipeBarrier(); - AscendC::Mul(mmInUb, mmInUb, kScaleUb, cuS2Len); - PipeBarrier(); - LocalTensor sortBuff = tmpBuf_.Get(); - LocalTensor sortScoreUb = sortBuff; - LocalTensor sortIndiceUb = sortBuff[cuS2LenVecAlign]; - PipeBarrier(); - Duplicate(sortScoreUb.template ReinterpretCast(), LIQServiceVec::NEG_INF, cuS2LenVecAlign); - PipeBarrier(); - Adds(sortScoreUb, mmInUb, 0.0f, cuS2Len); - PipeBarrier(); - inQueue_.FreeTensor(mmInUb); - LocalTensor sortIndiceUbInt = sortIndiceUb.template ReinterpretCast(); - // 无效数据索引填充为-1 - if (cuS2LenVecAlign != cuS2Len) { - Duplicate(sortIndiceUbInt, -1, cuS2LenVecAlign); - PipeBarrier(); - } - Adds(sortIndiceUbInt, globalTopkIndice_, static_cast(cuBaseS2Idx), cuS2Len); - PipeBarrier(); - LocalTensor tmpSortBuf = sortBuff[2 * cuS2LenVecAlign]; - LIQServiceVec::SortAll(sortBuff, tmpSortBuf, cuS2LenVecAlign); - PipeBarrier(); - LIQServiceVec::MergeSort(globalTopkUb_[innerS1Idx * BASE_TOPK_VALUE_IDX_SIZE], BASE_TOPK, sortBuff, - cuS2LenVecAlign, tmpSortBuf); - PipeBarrier(); - bool isS2End = cuBaseS2Idx + s2BaseSize_ >= cuRealAcSeq; - bool needCopyOutGm = blockS2StartIdx_ == 0 && isS2End; - // 中间结果保存 - bool needCopyWsGm = info.isAllLoopEnd || isS2End; - if (needCopyOutGm) { - LocalTensor idxULocal = outQueue_.AllocTensor(); - ExtractIndex(idxULocal, - globalTopkUb_[innerS1Idx * BASE_TOPK_VALUE_IDX_SIZE].template ReinterpretCast(), - BASE_TOPK); - PipeBarrier(); - InitSortOutBuf(globalTopkUb_[innerS1Idx * BASE_TOPK_VALUE_IDX_SIZE], BASE_TOPK_VALUE_IDX_SIZE); - outQueue_.EnQue(idxULocal); - idxULocal = outQueue_.DeQue(); - LIQServiceVec::CopyOut(indiceOutGm[info.indiceOutOffset + cuS1Idx * constInfo_.sparseCount], - idxULocal.template ReinterpretCast(), constInfo_.sparseCount); - outQueue_.FreeTensor(idxULocal); - } else if (needCopyWsGm) { - // vec1Res Gm = [aic, s1BaseSize_, 2, 2, topkOut_] float32 - // vec1Param Gm = [aic, s1BaseSize_, 2, 16] int64 - // 16 = [needFd, s2AcSeq, s2Start, s2End, isS2End, bn2idx, s1Idx, S1ProcNum, ......] - - int64_t wsOffset = - (blockId_ / 2) * s1BaseSize_ * 2 * BASE_TOPK_VALUE_IDX_SIZE + // 2个AIV共同地址偏移 - (blockId_ % 2) * (s1BaseSize_ / 2) * 2 * BASE_TOPK_VALUE_IDX_SIZE + // 每个AIV的地址偏移,S1方向 - (ldS1Offset + innerS1Idx) * 2 * BASE_TOPK_VALUE_IDX_SIZE; - int64_t wsInfoOffset = - (blockId_ / 2) * s1BaseSize_ * 2 * paramNum_ + // 2个AIV共同地址偏移 - (blockId_ % 2) * (s1BaseSize_ / 2) * 2 * paramNum_ + // 每个AIV的地址偏移,S1方向 - (ldS1Offset + innerS1Idx) * 2 * paramNum_; - - LocalTensor tmpiBuff = paramBuf_.Get(); - SetWaitFlag(HardEvent::MTE3_S); - tmpiBuff.SetValue(0, static_cast(1)); - tmpiBuff.SetValue(1, static_cast(cuRealAcSeq)); - tmpiBuff.SetValue(2, static_cast(blockS2StartIdx_)); - tmpiBuff.SetValue(3, static_cast(cuBaseS2Idx + cuS2Len)); - tmpiBuff.SetValue(4, static_cast(isS2End)); - tmpiBuff.SetValue(5, static_cast(info.bN2Idx)); - tmpiBuff.SetValue(6, static_cast(cuS1Idx)); - tmpiBuff.SetValue(7, static_cast(cuS1ProcNum)); - tmpiBuff.SetValue(8, static_cast(info.indiceOutOffset + cuS1Idx * constInfo_.sparseCount)); - // 写入头尾判断 - // [head, tail] - // head: 与前面规约,与前后规约 - // tail: 与后面规约 - bool isTailReduce = blockS2StartIdx_ == 0; // 一定是isLastTile - // WS偏移规则 blockS2StartIdx_ != 0 - // 跟前面块做规约 写到0偏移 不用做计算 blockS2StartIdx_ == 0 and !isS2End - // 跟后面块做规约 写到1偏移 需要 + s1BaseSize_, BASE_TOPK*2 - if (isTailReduce) { // S2不是最后结束的数据就需要往后做规约,放入第二块ws - wsInfoOffset += paramNum_; - wsOffset += BASE_TOPK_VALUE_IDX_SIZE; - } - SetWaitFlag(HardEvent::S_MTE3); - LIQServiceVec::CopyOut(vec1ParamGm[wsInfoOffset], tmpiBuff, 16); - SetWaitFlag(HardEvent::V_MTE3); - LIQServiceVec::CopyOut(vec1ResGm[wsOffset], globalTopkUb_[innerS1Idx * BASE_TOPK_VALUE_IDX_SIZE], - BASE_TOPK_VALUE_IDX_SIZE); - SetWaitFlag(HardEvent::MTE3_V); - } - } else if (cuRealAcSeq <= 0) { - CleanInvalidOutput(info.indiceOutOffset + cuS1Idx * constInfo_.sparseCount); - } - } - - // BNSD场景无效S1 输出-1 - if (Q_LAYOUT_T == LI_LAYOUT::BSND) { - // 最后一个S1的基本块, 需要 >= info.actS1Size - bool isS1LoopEnd = (cuBaseS1Idx + s1BaseSize_) >= info.actS1Size; - int32_t invalidS1Num = constInfo_.qSeqSize - info.actS1Size; - // blockS2StartIdx_ == 0 控制S2从开始的核去做冗余清理 - if (invalidS1Num > 0 && isS1LoopEnd && blockS2StartIdx_ == 0) { - int32_t s1NumPerAiv = blockId_ % 2 == 0 ? CeilDiv(invalidS1Num, 2) : (invalidS1Num / 2); - int32_t s1OffsetPerAiv = info.actS1Size + (blockId_ % 2) * CeilDiv(invalidS1Num, 2); - for (int innerS1Idx = 0; innerS1Idx < s1NumPerAiv; innerS1Idx++) { - CleanInvalidOutput(info.indiceOutOffset + (s1OffsetPerAiv + innerS1Idx) * constInfo_.sparseCount); - } - } - - int32_t invalidS1Num2 = info.actS1Size - info.actS2Size; - if (invalidS1Num2 > 0 && isS1LoopEnd && blockS2StartIdx_ == 0 && constInfo_.attenMaskFlag) { - int32_t s1NumPerAiv = blockId_ % 2 == 0 ? CeilDiv(invalidS1Num2, 2) : (invalidS1Num2 / 2); - int32_t s1OffsetPerAiv = (blockId_ % 2) * CeilDiv(invalidS1Num2, 2); - for (int innerS1Idx = 0; innerS1Idx < s1NumPerAiv; innerS1Idx++) { - CleanInvalidOutput((info.bN2Idx * constInfo_.qSeqSize + s1OffsetPerAiv + innerS1Idx) * - constInfo_.sparseCount); - } - } - } - - if (info.isLastS2InnerLoop) { - // S2最后一个Loop后, 下一个基本块初始从0开始 - blockS2StartIdx_ = 0; - } -} - -template -__aicore__ inline void LIQVector::ProcessLD() -{ - int32_t curCubeId = blockId_ / 2; - int32_t tmpCubeId = curCubeId; - - int64_t s2ActSeq; - int64_t s2Start; - int64_t s2End; - int64_t isS2End; - int64_t bn2Idx; - int64_t s1Idx; - uint32_t acc_list_num = 0; - int64_t bIdx = 0; - int64_t needFd; - int64_t wsOffset; - int64_t wsInfoOffset = 0; - int64_t nextneedFd; - int64_t valueOffset = 0; - int64_t outOffset = 0; - - LocalTensor curValueIdxUb = ldToBeMrgBuf_.Get(); - LocalTensor tmpUb = ldTmpBuf_.Get(); - - // S2开头信息 - // 开始必然没有头规约,因此从尾规约开始处理,while循环读取下一个核的头规约 - // 存满4个list或者遇到S2结尾,则做merge,直到做完S2 - // 每个核都忽略自己的头规约,因为必然由前面的核做完 - uint32_t s1LdStartIdx = 0; - uint32_t s1ProcNum = 0; - uint64_t paramGmCoreOffset = tmpCubeId * s1BaseSize_ * 2 * paramNum_; - for (uint32_t innerS1Idx = 0; innerS1Idx < s1BaseSize_; innerS1Idx++) { - needFd = vec1ParamGm.GetValue(paramGmCoreOffset + innerS1Idx * 2 * paramNum_ + paramNum_); - if (needFd == 1) { - s1LdStartIdx = (s1ProcNum == 0) ? innerS1Idx : s1LdStartIdx; - s1ProcNum++; - } - } - - if (s1ProcNum == 0) { - return; - } - - // S1逐行计算 - uint32_t s1VecNum = CeilDiv(s1ProcNum, 2); - if (blockId_ % 2 == 1) { - s1LdStartIdx = s1LdStartIdx + s1VecNum; - s1VecNum = s1ProcNum - s1VecNum; - } - for (uint32_t innerS1Idx = s1LdStartIdx; innerS1Idx < s1LdStartIdx + s1VecNum; innerS1Idx++) { - // 重置偏移 - tmpCubeId = curCubeId; - acc_list_num = 0; - valueOffset = 0; - - // 搬入数据 - wsOffset = tmpCubeId * s1BaseSize_ * 2 * BASE_TOPK_VALUE_IDX_SIZE + // 2个AIV共同地址偏移 - innerS1Idx * 2 * BASE_TOPK_VALUE_IDX_SIZE + BASE_TOPK_VALUE_IDX_SIZE; - SetWaitFlag(HardEvent::V_MTE2); - SetWaitFlag(HardEvent::S_MTE2); - DataCopyPad(curValueIdxUb, vec1ResGm[wsOffset], - {1, static_cast(BASE_TOPK_VALUE_IDX_SIZE * sizeof(int32_t)), 0, 0}, {true, 0, 0, 0}); - acc_list_num++; - valueOffset += BASE_TOPK_VALUE_IDX_SIZE; - - // 获取下一个核规约信息 - tmpCubeId++; - wsInfoOffset = tmpCubeId * s1BaseSize_ * 2 * paramNum_ + innerS1Idx * 2 * paramNum_; - needFd = vec1ParamGm.GetValue(wsInfoOffset); - isS2End = vec1ParamGm.GetValue(wsInfoOffset + 4); - s1Idx = vec1ParamGm.GetValue(wsInfoOffset + 6); - outOffset = vec1ParamGm.GetValue(wsInfoOffset + 8); - - while (needFd == 1) { - // 搬入头规约数据 - wsOffset = tmpCubeId * s1BaseSize_ * 2 * BASE_TOPK_VALUE_IDX_SIZE + // 2个AIV共同地址偏移 - innerS1Idx * 2 * BASE_TOPK_VALUE_IDX_SIZE; - SetWaitFlag(HardEvent::V_MTE2); - SetWaitFlag(HardEvent::S_MTE2); - DataCopyPad(curValueIdxUb[valueOffset], vec1ResGm[wsOffset], - {1, static_cast(BASE_TOPK_VALUE_IDX_SIZE * sizeof(int32_t)), 0, 0}, {true, 0, 0, 0}); - valueOffset += BASE_TOPK_VALUE_IDX_SIZE; - acc_list_num++; - - // 每满4个list,聚合 前2K为mrg结果 - if (acc_list_num == mrgListNum_) { - // MrgSort 四条2048的队列,Mrg成一条 - AscendC::MrgSort4Info params; - params.elementLengths[0] = BASE_TOPK; - params.elementLengths[1] = BASE_TOPK; - params.elementLengths[2] = BASE_TOPK; - params.elementLengths[3] = BASE_TOPK; - params.ifExhaustedSuspension = true; - params.validBit = 0b1111; - params.repeatTimes = 1; - - AscendC::MrgSortSrcList srcList; - srcList.src1 = curValueIdxUb[0]; - srcList.src2 = curValueIdxUb[BASE_TOPK_VALUE_IDX_SIZE]; - srcList.src3 = curValueIdxUb[2 * BASE_TOPK_VALUE_IDX_SIZE]; - srcList.src4 = curValueIdxUb[3 * BASE_TOPK_VALUE_IDX_SIZE]; - SetWaitFlag(HardEvent::MTE2_V); - MrgSort(tmpUb, srcList, params); - PipeBarrier(); - DataCopy(curValueIdxUb, tmpUb, BASE_TOPK_VALUE_IDX_SIZE); - PipeBarrier(); - acc_list_num = 1; - valueOffset = BASE_TOPK_VALUE_IDX_SIZE; - } - - // reduce到S2末尾,则跳出 - if (isS2End == 1) { - break; - } - - tmpCubeId++; - wsInfoOffset = tmpCubeId * s1BaseSize_ * 2 * paramNum_ + innerS1Idx * 2 * paramNum_; - needFd = vec1ParamGm.GetValue(wsInfoOffset); - isS2End = vec1ParamGm.GetValue(wsInfoOffset + 4); - } - - // mrg不足4个list的数据 - if (acc_list_num != 1) { - AscendC::MrgSort4Info params; - params.elementLengths[0] = BASE_TOPK; - params.elementLengths[1] = BASE_TOPK; - params.elementLengths[2] = BASE_TOPK; - params.elementLengths[3] = BASE_TOPK; - params.ifExhaustedSuspension = true; - if (acc_list_num == 2) { - params.validBit = 0b0011; - } else if (acc_list_num == 3) { - params.validBit = 0b0111; - } - params.repeatTimes = 1; - - AscendC::MrgSortSrcList srcList; - srcList.src1 = curValueIdxUb[0]; - srcList.src2 = curValueIdxUb[BASE_TOPK_VALUE_IDX_SIZE]; - srcList.src3 = curValueIdxUb[2 * BASE_TOPK_VALUE_IDX_SIZE]; - srcList.src4 = curValueIdxUb[3 * BASE_TOPK_VALUE_IDX_SIZE]; - SetWaitFlag(HardEvent::MTE2_V); - MrgSort(tmpUb, srcList, params); - PipeBarrier(); - DataCopy(curValueIdxUb, tmpUb, BASE_TOPK_VALUE_IDX_SIZE); - PipeBarrier(); - } - - // 搬出 - LocalTensor outValueUb = ldOutValueBuf_.Get(); - LocalTensor outIdxUb = ldOutIdxBuf_.Get(); - Extract(outValueUb, outIdxUb, curValueIdxUb, (BASE_TOPK / 32)); - LocalTensor idxULocal1 = outIdxUb.template ReinterpretCast(); - SetWaitFlag(HardEvent::V_MTE3); - SetWaitFlag(HardEvent::S_MTE3); - DataCopyPad(indiceOutGm[outOffset], idxULocal1, - {1, static_cast(constInfo_.sparseCount * sizeof(int32_t)), 0, 0}); - SetWaitFlag(HardEvent::MTE3_V); - } -} -} // namespace LIQKernel -#endif \ No newline at end of file diff --git a/csrc/lightning_indexer_quant/op_kernel/lightning_indexer_quant_template_tiling_key.h b/csrc/lightning_indexer_quant/op_kernel/lightning_indexer_quant_template_tiling_key.h deleted file mode 100644 index 165e6215..00000000 --- a/csrc/lightning_indexer_quant/op_kernel/lightning_indexer_quant_template_tiling_key.h +++ /dev/null @@ -1,53 +0,0 @@ -/** - * This program is free software, you can redistribute it and/or modify it. - * Copyright (c) 2025 Huawei Technologies Co., Ltd. - * This file is a part of the CANN Open Software. - * Licensed under CANN Open Software License Agreement Version 2.0 (the "License"). - * Please refer to the License for details. You may not use this file except in compliance with the License. - * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. - * See LICENSE in the root of the software repository for the full text of the License. - */ - -/*! - * \file lightning_indexer_quant_template_tiling_key.h - * \brief - */ - -#ifndef TEMPLATE_TILING_KEY_LI_H_ -#define TEMPLATE_TILING_KEY_LI_H_ - -#include "ascendc/host_api/tiling/template_argument.h" - -#define LI_TPL_FP16 1 -#define LI_TPL_IN8 2 -#define LI_TPL_INT32 3 -#define LI_TPL_BF16 27 - -#define LIQ_LAYOUT_BSND 0 -#define LIQ_LAYOUT_TND 1 -#define LIQ_LAYOUT_PA_BSND 2 - -#define ASCENDC_TPL_4_BW 4 - -// 模板参数支持的范围定义 -ASCENDC_TPL_ARGS_DECL(LightningIndexerQuant, // 算子OpType - ASCENDC_TPL_DTYPE_DECL(DT_Q, LI_TPL_IN8), ASCENDC_TPL_DTYPE_DECL(DT_K, LI_TPL_IN8), - ASCENDC_TPL_DTYPE_DECL(DT_OUT, LI_TPL_INT32), ASCENDC_TPL_BOOL_DECL(PAGE_ATTENTION, 1, 0), - ASCENDC_TPL_UINT_DECL(Q_LAYOUT_T, ASCENDC_TPL_4_BW, ASCENDC_TPL_UI_LIST, LIQ_LAYOUT_BSND, - LIQ_LAYOUT_TND), - ASCENDC_TPL_UINT_DECL(K_LAYOUT_T, ASCENDC_TPL_4_BW, ASCENDC_TPL_UI_LIST, - LIQ_LAYOUT_PA_BSND, LIQ_LAYOUT_BSND, LIQ_LAYOUT_TND), ); - -// 支持的模板参数组合 -// 用于调用GET_TPL_TILING_KEY获取TilingKey时,接口内部校验TilingKey是否合法 -ASCENDC_TPL_SEL( - ASCENDC_TPL_ARGS_SEL(ASCENDC_TPL_DTYPE_SEL(DT_Q, LI_TPL_IN8), ASCENDC_TPL_DTYPE_SEL(DT_K, LI_TPL_IN8), - ASCENDC_TPL_DTYPE_SEL(DT_OUT, LI_TPL_INT32), ASCENDC_TPL_BOOL_SEL(PAGE_ATTENTION, 1), - ASCENDC_TPL_UINT_SEL(Q_LAYOUT_T, ASCENDC_TPL_UI_LIST, LIQ_LAYOUT_BSND, LIQ_LAYOUT_TND), - ASCENDC_TPL_UINT_SEL(K_LAYOUT_T, ASCENDC_TPL_UI_LIST, LIQ_LAYOUT_PA_BSND), ), - ASCENDC_TPL_ARGS_SEL(ASCENDC_TPL_DTYPE_SEL(DT_Q, LI_TPL_IN8), ASCENDC_TPL_DTYPE_SEL(DT_K, LI_TPL_IN8), - ASCENDC_TPL_DTYPE_SEL(DT_OUT, LI_TPL_INT32), ASCENDC_TPL_BOOL_SEL(PAGE_ATTENTION, 0), - ASCENDC_TPL_UINT_SEL(Q_LAYOUT_T, ASCENDC_TPL_UI_LIST, LIQ_LAYOUT_BSND, LIQ_LAYOUT_TND), - ASCENDC_TPL_UINT_SEL(K_LAYOUT_T, ASCENDC_TPL_UI_LIST, LIQ_LAYOUT_BSND, LIQ_LAYOUT_TND), ), ); - -#endif \ No newline at end of file diff --git a/csrc/lightning_indexer_quant/op_kernel/lightning_indexer_quant_vector.h b/csrc/lightning_indexer_quant/op_kernel/lightning_indexer_quant_vector.h deleted file mode 100644 index d6a2e277..00000000 --- a/csrc/lightning_indexer_quant/op_kernel/lightning_indexer_quant_vector.h +++ /dev/null @@ -1,193 +0,0 @@ -/** - * This program is free software, you can redistribute it and/or modify it. - * Copyright (c) 2025 Huawei Technologies Co., Ltd. - * This file is a part of the CANN Open Software. - * Licensed under CANN Open Software License Agreement Version 2.0 (the "License"). - * Please refer to the License for details. You may not use this file except in compliance with the License. - * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. - * See LICENSE in the root of the software repository for the full text of the License. - */ - -/*! - * \file lightning_indexer_quant_vector.h - * \brief - */ -#ifndef LIGHTNING_INDEXER_QUANT_VECTOR_H -#define LIGHTNING_INDEXER_QUANT_VECTOR_H - -#include "kernel_operator.h" -#include "lightning_indexer_quant_vector.h" - -namespace LIQServiceVec { -using namespace AscendC; - -constexpr int32_t NEG_INF = 0xFF800000; -constexpr int32_t INVALID_INDEX = -1; -constexpr uint8_t VEC_REPEAT_MAX = 255; -constexpr uint8_t B32_VEC_ELM_NUM = 64; -constexpr uint8_t B32_BLOCK_ALIGN_NUM = 8; -constexpr uint8_t B32_VEC_REPEAT_STRIDE = 8; -constexpr uint64_t VEC_REPEAT_BYTES = 256; -constexpr int32_t CONST_TWO = 2; -constexpr int64_t VALUE_AND_INDEX_NUM = 2; -constexpr int64_t BLOCK_BYTES = 32; -constexpr int64_t MRG_QUE_0 = 0; -constexpr int64_t MRG_QUE_1 = 1; -constexpr int64_t MRG_QUE_2 = 2; -constexpr int64_t MRG_QUE_3 = 3; -constexpr int64_t MRG_BLOCK_2 = 2; -constexpr int64_t MRG_BLOCK_3 = 3; -constexpr int64_t MRG_BLOCK_4 = 4; - -template -__aicore__ inline void CopyOut(const GlobalTensor &dstGm, const LocalTensor &srcUb, int64_t copyCount) -{ - AscendC::DataCopyParams dataCopyOutyParams; - dataCopyOutyParams.blockCount = 1; - dataCopyOutyParams.blockLen = copyCount * sizeof(T); - dataCopyOutyParams.srcStride = 0; - dataCopyOutyParams.dstStride = 0; - AscendC::DataCopyPad(dstGm, srcUb, dataCopyOutyParams); -} - -/** - src: 传入的初始化空间 - eleNum: 需要初始化的元素个数需为64整数倍,元素将被初始化为交错排布的-inf,-1 - */ -__aicore__ inline void InitSortOutBuf(const LocalTensor &src, int64_t eleNum) -{ - uint64_t mask1[2] = {0x5555555555555555, 0}; - uint64_t mask0[2] = {0xaaaaaaaaaaaaaaaa, 0}; - int64_t repeatNum = eleNum / B32_VEC_ELM_NUM; - int64_t forLoop = repeatNum / VEC_REPEAT_MAX; - int64_t forRemain = repeatNum % VEC_REPEAT_MAX; - for (int i = 0; i < forLoop; i++) { - AscendC::Duplicate(src.template ReinterpretCast(), NEG_INF, mask1, VEC_REPEAT_MAX, 1, - B32_VEC_REPEAT_STRIDE); - AscendC::Duplicate(src.template ReinterpretCast(), INVALID_INDEX, mask0, VEC_REPEAT_MAX, 1, - B32_VEC_REPEAT_STRIDE); - } - if (forRemain > 0) { - AscendC::Duplicate(src.template ReinterpretCast()[forLoop * VEC_REPEAT_MAX * B32_VEC_ELM_NUM], NEG_INF, - mask1, forRemain, 1, B32_VEC_REPEAT_STRIDE); - AscendC::Duplicate(src.template ReinterpretCast()[forLoop * VEC_REPEAT_MAX * B32_VEC_ELM_NUM], - INVALID_INDEX, mask0, forRemain, 1, B32_VEC_REPEAT_STRIDE); - } - AscendC::PipeBarrier(); -} - -/** - src: logits和索引,前logitsNum为logits,后logitsNum为索引 - tmp: 计算使用到的临时空间,大小与src一致 - logitsNum: 排序的元素个数, 暂只支持[128,256,384,512,1024,2048] - */ -__aicore__ inline void SortAll(LocalTensor &src, LocalTensor &tmp, int64_t logitsNum) -{ - int64_t sort32Repeats = logitsNum / BLOCK_BYTES; - AscendC::Sort32(tmp, src, src[logitsNum].ReinterpretCast(), sort32Repeats); - AscendC::PipeBarrier(); - - int64_t mrgGroups = sort32Repeats; - int64_t mrgElements = BLOCK_BYTES; - int64_t i = 0; - AscendC::LocalTensor srcTensor; - AscendC::LocalTensor dstTensor; - while (true) { - if (i % CONST_TWO == 0) { - srcTensor = tmp; - dstTensor = src; - } else { - srcTensor = src; - dstTensor = tmp; - } - AscendC::MrgSort4Info params; - params.elementLengths[0] = mrgElements; - params.elementLengths[MRG_QUE_1] = mrgElements; - params.elementLengths[MRG_QUE_2] = mrgElements; - params.elementLengths[MRG_QUE_3] = mrgElements; - params.ifExhaustedSuspension = false; - params.validBit = 0b1111; - - AscendC::MrgSortSrcList srcList; - srcList.src1 = srcTensor[0]; - srcList.src2 = srcTensor[MRG_QUE_1 * VALUE_AND_INDEX_NUM * mrgElements]; - srcList.src3 = srcTensor[MRG_QUE_2 * VALUE_AND_INDEX_NUM * mrgElements]; - srcList.src4 = srcTensor[MRG_QUE_3 * VALUE_AND_INDEX_NUM * mrgElements]; - if (mrgGroups <= MRG_BLOCK_4) { - params.repeatTimes = 1; - if (mrgGroups == 1) { - break; - } else if (mrgGroups == MRG_BLOCK_2) { - params.validBit = 0b0011; - } else if (mrgGroups == MRG_BLOCK_3) { - params.validBit = 0b0111; - } else if (mrgGroups == MRG_BLOCK_4) { - params.validBit = 0b1111; - } - AscendC::MrgSort(dstTensor, srcList, params); - i += 1; - break; - } else { - params.repeatTimes = mrgGroups / MRG_BLOCK_4; - AscendC::MrgSort(dstTensor, srcList, params); - i += 1; - mrgElements = mrgElements * MRG_BLOCK_4; - mrgGroups = mrgGroups / MRG_BLOCK_4; - } - AscendC::PipeBarrier(); - } - if (i % CONST_TWO == 0) { - AscendC::DataCopy(src, tmp, logitsNum * VALUE_AND_INDEX_NUM); - AscendC::PipeBarrier(); - } -} - -/** - mrgDst: 合并进的Tensor - mrgSrc: 待合并的Tensor - tmpTensor:空间为mrgDst+mrgSrc - */ -__aicore__ inline void MergeSort(const LocalTensor &mrgDst, int32_t mrgDstNum, LocalTensor &mrgSrc, - int32_t mrgSrcNum, LocalTensor &tmpTensor) -{ - AscendC::MrgSort4Info params; - params.elementLengths[0] = mrgSrcNum; - params.elementLengths[1] = mrgDstNum; - params.ifExhaustedSuspension = false; - params.validBit = 0b0011; - params.repeatTimes = 1; - - AscendC::MrgSortSrcList srcList; - srcList.src1 = mrgSrc; - srcList.src2 = mrgDst; - - AscendC::MrgSort(tmpTensor, srcList, params); - AscendC::PipeBarrier(); - AscendC::DataCopy(mrgDst, tmpTensor, mrgDstNum * VALUE_AND_INDEX_NUM); - AscendC::PipeBarrier(); -} - -__aicore__ inline void ExtractIndex(const LocalTensor &idxULocal, const LocalTensor &sortLocal, - int64_t extractNum) -{ - AscendC::GatherMaskParams gatherMaskParams; - gatherMaskParams.repeatTimes = Ceil(extractNum * sizeof(float) * VALUE_AND_INDEX_NUM, VEC_REPEAT_BYTES); - gatherMaskParams.src0BlockStride = 1; - gatherMaskParams.src0RepeatStride = B32_VEC_REPEAT_STRIDE; - gatherMaskParams.src1RepeatStride = 0; - uint64_t rsvdCnt = 0; // 用于保存筛选后保留下来的元素个数 - uint8_t src1Pattern = 2; // 固定模式2,表示筛选出奇数索引的数 - AscendC::GatherMask(idxULocal, sortLocal, src1Pattern, false, static_cast(0), gatherMaskParams, rsvdCnt); - AscendC::PipeBarrier(); -} - -template -__aicore__ inline void SetWaitFlag(HardEvent evt) -{ - event_t eventId = static_cast(GetTPipePtr()->FetchEventID(evt)); - AscendC::SetFlag(eventId); - AscendC::WaitFlag(eventId); -} - -} // namespace LIQServiceVec -#endif // LIGHTNING_INDEXER_QUANT_VECTOR_H \ No newline at end of file diff --git a/csrc/torch_binding.cpp b/csrc/torch_binding.cpp index b22cb7b0..bc2bb72c 100644 --- a/csrc/torch_binding.cpp +++ b/csrc/torch_binding.cpp @@ -42,7 +42,6 @@ #include "moe_gating_top_k/moe_gating_top_k_torch_adpt.h" #include "moe_init_routing_custom/moe_init_routing_custom_torch_adpt.h" #include "sparse_flash_attention/sparse_flash_attention_torch_adpt.h" -#include "lightning_indexer_quant/lightning_indexer_quant_torch_adpt.h" #include #include #include @@ -919,16 +918,4 @@ TORCH_LIBRARY_EXPAND(CONCAT(_C, _ascend), ops) "-> Tensor[]" ); ops.impl("moe_grouped_matmul", torch::kPrivateUse1,&vllm_ascend::moe_grouped_matmul); - - // This operator is planned to be integrated into PTA in the near future. - // Once that happens, the implementation in csrc will be removed. - ops.def( - "npu_lightning_indexer_quant(Tensor query, Tensor key, Tensor weights, Tensor query_dequant_scale, " - " Tensor key_dequant_scale, *, Tensor? actual_seq_lengths_query=None, " - " Tensor? actual_seq_lengths_key=None, Tensor? block_table=None, " - " int query_quant_mode=0, int key_quant_mode=0, " - " str layout_query='BSND', str layout_key='BSND'," - " int sparse_count=2048, int sparse_mode=3) -> Tensor" - ); - ops.impl("npu_lightning_indexer_quant", torch::kPrivateUse1, &vllm_ascend::npu_lightning_indexer_quant); } diff --git a/csrc/torch_binding_meta.cpp b/csrc/torch_binding_meta.cpp index 1f62ce2c..f5980a01 100644 --- a/csrc/torch_binding_meta.cpp +++ b/csrc/torch_binding_meta.cpp @@ -529,44 +529,6 @@ std::vector moe_grouped_matmul_meta( return y; } -at::Tensor npu_lightning_indexer_quant_meta( - const at::Tensor &query, const at::Tensor &key, const at::Tensor &weights, - const at::Tensor &query_dequant_scale, const at::Tensor &key_dequant_scale, - const c10::optional &actual_seq_lengths_query, - const c10::optional &actual_seq_lengths_key, - const c10::optional &block_table, int64_t query_quant_mode, int64_t key_quant_mode, - c10::string_view layout_query, c10::string_view layout_key, int64_t sparse_count, int64_t sparse_mode) -{ - std::string query_layout_str = std::string(layout_query); - std::string key_layout_str = std::string(layout_key); - - const int SIZE = 8; - const int DIM_0 = 0; - const int DIM_1 = 1; - const int DIM_2 = 2; - const int DIM_3 = 3; - - at::SmallVector output_size; - for (size_t i = 0; i < query.sizes().size(); i++) { - TORCH_CHECK(query.size(i) > 0, "All values within query's shape should be greater " - "than 0, but shape[", i, "] is ", query.size(i)); - } - for (size_t i = 0; i < key.sizes().size(); i++) { - TORCH_CHECK(key.size(i) > 0, "All values within key's shape should be greater " - "than 0, but shape[", i, "] is ", key.size(i)); - } - TORCH_CHECK(sparse_count > 0, "sparse count should be greater than 0, but now is ", sparse_count); - int64_t keyHeadNum = (key_layout_str == "TND")? key.size(DIM_1) : key.size(DIM_2); - if (query_layout_str == "BSND") { - output_size = {query.size(DIM_0), query.size(DIM_1), keyHeadNum, sparse_count}; - } else { - output_size = {query.size(DIM_0), keyHeadNum, sparse_count}; - } - at::Tensor lightning_indexer_quant_output = at::empty(output_size, query.options().dtype(at::kInt)); - - return lightning_indexer_quant_output; -} - } // namespace meta } // namespace vllm_ascend @@ -614,7 +576,5 @@ TORCH_LIBRARY_IMPL_EXPAND(CONCAT(_C, _ascend), Meta, ops) { ops.impl("causal_conv1d_fn", &vllm_ascend::meta::causal_conv1d_fn_meta); // moe_grouped_matmul ops.impl("moe_grouped_matmul", &vllm_ascend::meta::moe_grouped_matmul_meta); - // Lightning indexer quant - ops.impl("npu_lightning_indexer_quant", &vllm_ascend::meta::npu_lightning_indexer_quant_meta); } } diff --git a/tests/e2e/multicard/2-cards/test_offline_inference_distributed.py b/tests/e2e/multicard/2-cards/test_offline_inference_distributed.py index de971834..fae23fa1 100644 --- a/tests/e2e/multicard/2-cards/test_offline_inference_distributed.py +++ b/tests/e2e/multicard/2-cards/test_offline_inference_distributed.py @@ -266,33 +266,6 @@ def test_deepseek3_2_w8a8_pruning_mtp_tp2_ep(): vllm_model.generate_greedy(long_example_prompts, max_tokens) -@patch.dict(os.environ, {"HCCL_OP_EXPANSION_MODE": "AIV"}) -@patch.dict(os.environ, {"VLLM_ASCEND_ENABLE_FLASHCOMM1": "1"}) -@patch.dict(os.environ, {"ASCEND_AGGREGATE_ENABLE": "1"}) -@patch.dict(os.environ, {"HCCL_BUFFSIZE": "1024"}) -def test_deepseek3_2_w8a8c8_pruning_mtp_tp2_ep(): - short_example_prompts = [ - "Hello ", - ] - # "max_position_embeddings": 163840, - long_example_prompts = ["Hello " * (163839 - 500) + "Hello"] - max_tokens = 500 - with VllmRunner( - "vllm-ascend/DeepSeek-V3.2-W8A8-Pruning", - tensor_parallel_size=2, - quantization="ascend", - enable_expert_parallel=True, - max_model_len=163840, - compilation_config={"cudagraph_capture_sizes": [2, 4, 6, 8, 10, 12], "cudagraph_mode": "FULL_DECODE_ONLY"}, - speculative_config={"num_speculative_tokens": 1, "method": "deepseek_mtp"}, - additional_config={"layer_sharding": ["q_b_proj", "o_proj"], "enable_sparse_c8": True}, - reasoning_parser="deepseek_v3", - tokenizer_mode="deepseek_v32", - ) as vllm_model: - vllm_model.generate_greedy(short_example_prompts, max_tokens) - vllm_model.generate_greedy(long_example_prompts, max_tokens) - - @pytest.mark.parametrize("model", QWEN_W4A4_MODELS) def test_qwen3_w4a4_distributed_tp2(model): example_prompts = [ diff --git a/vllm_ascend/ascend_config.py b/vllm_ascend/ascend_config.py index 8dd63427..27095936 100644 --- a/vllm_ascend/ascend_config.py +++ b/vllm_ascend/ascend_config.py @@ -134,12 +134,9 @@ class AscendConfig: bool(additional_config.get("enable_async_exponential", False)) and not vllm_is_batch_invariant() ) - use_sparse = hasattr(vllm_config.model_config, "hf_text_config") and hasattr( - vllm_config.model_config.hf_text_config, "index_topk" - ) - self.enable_kv_nz = additional_config.get("enable_kv_nz", False) if self.enable_kv_nz: + use_sparse = hasattr(vllm_config.model_config.hf_text_config, "index_topk") if not vllm_config.model_config.is_deepseek_mla or use_sparse: raise RuntimeError("enable_kv_nz is only supported for mla currently.") if vllm_config.kv_transfer_config is None or not vllm_config.kv_transfer_config.is_kv_consumer: @@ -147,17 +144,6 @@ class AscendConfig: "enable_kv_nz is only supported in pd scenario and can only be used in D node." ) - from vllm_ascend.utils import AscendDeviceType, get_ascend_device_type - - # Disable Sparse C8 for A5 - # A5 has not been fully validated for this path and may carry hidden risks. - # TODO(rjg-lyh): Enable A5 support after sufficient validation. - self.enable_sparse_c8 = ( - additional_config.get("enable_sparse_c8", False) - and use_sparse - and get_ascend_device_type() != AscendDeviceType.A5 - ) - def _construct_weight_prefetch_config(self, additional_config): weight_prefetch_config = additional_config.get("weight_prefetch_config", {}) self.weight_prefetch_config = WeightPrefetchConfig(weight_prefetch_config) diff --git a/vllm_ascend/attention/sfa_v1.py b/vllm_ascend/attention/sfa_v1.py index 70bfcb33..f7edb5fb 100644 --- a/vllm_ascend/attention/sfa_v1.py +++ b/vllm_ascend/attention/sfa_v1.py @@ -1,7 +1,6 @@ from dataclasses import dataclass from typing import TYPE_CHECKING, TypeVar -import scipy # type: ignore import torch import torch_npu import vllm.envs as envs_vllm @@ -356,9 +355,6 @@ class AscendSFAImpl(MLAAttentionImpl): # Supports forward using the all-gather o_proj weight for decode requests when Sharded CP is enabled. o_proj_full_pool: torch.Tensor | None = None - # qk_hadamard tensor shared when dsa c8 enabled - qk_hadamard: torch.Tensor | None = None - def __init__( self, num_heads: int, @@ -429,12 +425,6 @@ class AscendSFAImpl(MLAAttentionImpl): self.is_rope_neox_style = False self.use_torch_npu_lightning_indexer = True - # dsa c8 - self.use_sparse_c8_indexer = ascend_config.enable_sparse_c8 - if self.use_sparse_c8_indexer: - self.c8_k_cache_dtype = torch.int8 - self.c8_k_scale_cache_dtype = torch.float16 - # Effective in SFA when FlashComm is enabled. self.enable_dsa_cp = enable_dsa_cp() @@ -525,11 +515,6 @@ class AscendSFAImpl(MLAAttentionImpl): # if mlapo, W_UK_T can't trans nz self.W_UK_T = maybe_trans_nz(self.W_UK_T) - if self.use_sparse_c8_indexer and AscendSFAImpl.qk_hadamard is None: - AscendSFAImpl.qk_hadamard = torch.tensor(scipy.linalg.hadamard(128), dtype=torch.bfloat16, device="npu") / ( - 128**0.5 - ) - # Processing the input parameters for MLAPO by reordering and transposing # QKV(and part of Q) weight, applying RoPE-related dimension transformations, # and handling quantization parameters. @@ -889,15 +874,7 @@ class AscendSFAImpl(MLAAttentionImpl): k_li = torch.cat([k_li_pe, k_li_nope], dim=-1) # [b*s,128] - if self.use_sparse_c8_indexer: - k_li = k_li @ AscendSFAImpl.qk_hadamard - k_li, k_li_scale = torch_npu.npu_dynamic_quant(k_li.view(-1, self.head_dim), dst_type=self.c8_k_cache_dtype) - k_li_scale = k_li_scale.to(self.c8_k_scale_cache_dtype) # [b*s,] - k_li_scale = k_li_scale.unsqueeze(-1) # [b*s,1] - else: - k_li_scale = None - - return k_li, k_li_scale + return k_li def indexer_select_post_process( self, @@ -928,35 +905,10 @@ class AscendSFAImpl(MLAAttentionImpl): q_li_pe = q_li_pe.squeeze(2) q_li = torch.cat([q_li_pe, q_li_nope], dim=-1) # [b*s,64,128] - if self.use_sparse_c8_indexer: - q_li_shape_ori = q_li.shape - q_li = q_li @ AscendSFAImpl.qk_hadamard - q_li, q_li_scale = torch_npu.npu_dynamic_quant(q_li.view(-1, self.head_dim), dst_type=self.c8_k_cache_dtype) - q_li_scale = q_li_scale.to(self.c8_k_scale_cache_dtype) - # DSV3.2 currently has graph compilation issues when using torch_npu.npu.lightning_indexer. # So two branches are maintained temporarily. # TODO: torch.ops._C_ascend.npu_lightning_indexer needs to be removed. - if self.use_sparse_c8_indexer: - assert len(kv_cache) == 4 - weights = weights.to(torch.float16) - topk_indices = torch.ops._C_ascend.npu_lightning_indexer_quant( - query=q_li.view(q_li_shape_ori), - key=kv_cache[2], - weights=weights, - query_dequant_scale=q_li_scale.view(q_li_shape_ori[:-1]), - key_dequant_scale=kv_cache[3].squeeze(2), # B S N D -> B S D - actual_seq_lengths_query=actual_seq_lengths_query, - actual_seq_lengths_key=actual_seq_lengths_key, - block_table=attn_metadata.block_table, - query_quant_mode=0, - key_quant_mode=0, - layout_query="TND", - layout_key="PA_BSND", - sparse_count=2048, - sparse_mode=3, - ) - elif self.use_torch_npu_lightning_indexer: + if self.use_torch_npu_lightning_indexer: topk_indices, _ = torch_npu.npu_lightning_indexer( query=q_li, key=kv_cache[2], @@ -1079,7 +1031,7 @@ class AscendSFAImpl(MLAAttentionImpl): assert self.q_a_layernorm is not None, "q_a_layernorm must be initialized" q_c = self.q_a_layernorm(q_c) - k_li, k_li_scale = self.indexer_select_pre_process(x=hidden_states, cos=cos, sin=sin) + k_li = self.indexer_select_pre_process(x=hidden_states, cos=cos, sin=sin) wait_for_kv_layer_from_connector(layer_name) @@ -1092,46 +1044,20 @@ class AscendSFAImpl(MLAAttentionImpl): if self.enable_dsa_cp: assert k_pe is not None assert k_nope is not None - assert k_li is not None async_op = self.enable_dsa_cp_with_layer_shard or full_gather_o_proj_enabled # support all_gather kv async for communication calculation overlap - if not self.use_sparse_c8_indexer: - fused_kv_no_split, kv_ag_handle = all_gather_async( - torch.cat( - [ - k_pe.view(-1, k_pe.shape[-1]), - k_nope.view(-1, k_nope.shape[-1]), - k_li.view(-1, k_li.shape[-1]), - ], - dim=1, - ), - get_tp_group(), - async_op=async_op, - ) - else: - # due to different dtypes, we have to split commu pass - assert k_li_scale is not None - fused_kv_no_split, _ = all_gather_async( - torch.cat( - [ - k_pe.view(-1, k_pe.shape[-1]), - k_nope.view(-1, k_nope.shape[-1]), - ], - dim=1, - ), - get_tp_group(), - async_op=async_op, - ) - k_li, _ = all_gather_async( - k_li, - get_tp_group(), - async_op=async_op, - ) - k_li_scale, kv_ag_handle = all_gather_async( - k_li_scale, - get_tp_group(), - async_op=async_op, - ) + fused_kv_no_split, kv_ag_handle = all_gather_async( + torch.cat( + [ + k_pe.view(-1, k_pe.shape[-1]), + k_nope.view(-1, k_nope.shape[-1]), + k_li.view(-1, k_li.shape[-1]), + ], + dim=1, + ), + get_tp_group(), + async_op=async_op, + ) ql_nope, q_pe = self._q_proj_and_k_up_proj(q_c) q_pe = self.rope_single(q_pe, cos, sin) @@ -1151,12 +1077,9 @@ class AscendSFAImpl(MLAAttentionImpl): if kv_cache is not None: assert fused_kv_no_split is not None - if not self.use_sparse_c8_indexer: - k_pe, k_nope, k_li = fused_kv_no_split.split( - [self.qk_rope_head_dim, self.kv_lora_rank, self.head_dim], dim=-1 - ) - else: - k_pe, k_nope = fused_kv_no_split.split([self.qk_rope_head_dim, self.kv_lora_rank], dim=-1) + k_pe, k_nope, k_li = fused_kv_no_split.split( + [self.qk_rope_head_dim, self.kv_lora_rank, self.head_dim], dim=-1 + ) k_nope = k_nope.view(k_nope.shape[0], 1, -1) k_pe = k_pe.view(k_pe.shape[0], 1, -1) DeviceOperator.reshape_and_cache( @@ -1175,13 +1098,6 @@ class AscendSFAImpl(MLAAttentionImpl): torch_npu.npu_scatter_nd_update_( kv_cache[2].view(-1, k_li.shape[-1]), slot_mapping.view(-1, 1), k_li.view(-1, k_li.shape[-1]) ) # b, s, n, d - if self.use_sparse_c8_indexer: - assert len(kv_cache) == 4 - torch_npu.npu_scatter_nd_update_( - kv_cache[3].view(-1, k_li_scale.shape[-1]), - slot_mapping.view(-1, 1), - k_li_scale.view(-1, k_li_scale.shape[-1]), - ) if self.is_kv_producer: attn_metadata.reshape_cache_event.record() diff --git a/vllm_ascend/patch/__init__.py b/vllm_ascend/patch/__init__.py index 41463c7b..de8efd1e 100644 --- a/vllm_ascend/patch/__init__.py +++ b/vllm_ascend/patch/__init__.py @@ -137,28 +137,6 @@ # Remove this patch if upstream provides an official NPU graph-capture # guidance / auto-configuration path for HCCL. # -# ** 8. File: platform/patch_kv_cache_interface.py** -# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ -# 1. `vllm.v1.kv_cache_interface.MLAAttentionSpec` -# Why: -# The default `MLAAttentionSpec` is mainly built around `kv_lora_rank` -# and `qk_rope_head_dim`. On NPU, we also use this class to describe DSA -# models. Unlike the GPU path, where cache management is handled by an -# additional indexer module, extending this class directly simplifies the -# corresponding `model_runner` implementation on NPU. -# -# This patch also adds Sparse C8 support for DSA models on NPU. As part -# of that support, members such as `page_size_bytes` need to be adapted, -# so they are overridden here as well to preserve overall readability. -# How: -# This patch subclasses the original implementation, overrides selected -# methods, and adds DSA-specific attributes and helpers with default -# values where needed. -# Related PR (if no, explain why): -# https://github.com/vllm-project/vllm/pull/25896 -# Future Plan: -# Remove this patch after the upcoming KV cache spec refactor. -# # * Worker Patch: # =============== # diff --git a/vllm_ascend/patch/platform/__init__.py b/vllm_ascend/patch/platform/__init__.py index ba7a8f3d..4b8fc9d2 100644 --- a/vllm_ascend/patch/platform/__init__.py +++ b/vllm_ascend/patch/platform/__init__.py @@ -18,7 +18,6 @@ import os import vllm_ascend.patch.platform.patch_distributed # noqa import vllm_ascend.patch.platform.patch_fusion_matcher_compat_ops # noqa -import vllm_ascend.patch.platform.patch_kv_cache_interface # noqa import vllm_ascend.patch.platform.patch_mamba_config # noqa import vllm_ascend.patch.platform.patch_minimax_m2_config # noqa import vllm_ascend.patch.platform.patch_sched_yield # noqa diff --git a/vllm_ascend/patch/platform/patch_kv_cache_interface.py b/vllm_ascend/patch/platform/patch_kv_cache_interface.py deleted file mode 100644 index 3719a3c5..00000000 --- a/vllm_ascend/patch/platform/patch_kv_cache_interface.py +++ /dev/null @@ -1,138 +0,0 @@ -# SPDX-License-Identifier: Apache-2.0 -# SPDX-FileCopyrightText: Copyright contributors to the vLLM project - -from dataclasses import dataclass - -import torch -import vllm.v1.kv_cache_interface -from typing_extensions import Self -from vllm.utils.torch_utils import get_dtype_size -from vllm.v1.kv_cache_interface import MLAAttentionSpec - - -@dataclass(frozen=True) -class AscendMLAAttentionSpec(MLAAttentionSpec): - """MLAAttentionSpec extended to support DSA models, with optional Sparse C8 support. - - When Sparse C8 is enabled, the KV cache tuple changes from - (kv_cache[0]: bfloat16, kv_cache[1]: bfloat16, kv_cache[2]: bfloat16) - to - (kv_cache[0]: bfloat16, kv_cache[1]: bfloat16, kv_cache[2]: int8, kv_cache[3]: float16). - - The semantic meaning of each KV cache entry is as follows: - 1. kv_cache[0] stores kv_lora. - 2. kv_cache[1] stores k_rope. - 3. kv_cache[2] stores the key tensor from the indexer module. - 4. kv_cache[3] stores the key scale tensor from the indexer module, - and exists only when Sparse C8 is enabled. - - The main changes are as follows: - 1. The key tensor from the indexer module stored in kv_cache[2] is - converted from bf16 to int8 to reduce memory usage. It is then - processed with int8 precision in Lightning_indexer computation - to improve computational efficiency. - 2. The quantization scale of the key tensor in the indexer module - must also be stored for the Lightning_indexer_quant operator, - and is therefore saved in kv_cache[3]. - """ - - sparse_head_dim: tuple[int, ...] | None = None - cache_sparse_c8: bool = False - c8_k_cache_dtype: torch.dtype = torch.int8 - c8_k_scale_cache_dtype: torch.dtype = torch.float16 - - @property - def page_size_bytes(self) -> int: - if self.cache_sparse_c8: - assert self.sparse_head_dim is not None - assert len(self.sparse_head_dim) == 3 - num_heads_per_page = self.block_size * self.num_kv_heads - # kv_cache[0]: bfloat16, kv_cache[1]: bfloat16 - kv_lora_rank, qk_rope_head_dim = self.sparse_head_dim[:2] - k_pe_nope_bytes = num_heads_per_page * (kv_lora_rank + qk_rope_head_dim) * get_dtype_size(self.dtype) - # kv_cache[2]: int8 - index_head_dim = self.sparse_head_dim[-1] - indexer_k_bytes = num_heads_per_page * index_head_dim * get_dtype_size(self.c8_k_cache_dtype) - # kv_cache[3]: float16 - # since the scale is stored per token, head_dim is set to 1. - index_scale_head_dim = 1 - indexer_k_scale_bytes = ( - num_heads_per_page * index_scale_head_dim * get_dtype_size(self.c8_k_scale_cache_dtype) - ) - return k_pe_nope_bytes + indexer_k_bytes + indexer_k_scale_bytes - - return self.block_size * self.num_kv_heads * self.head_size * get_dtype_size(self.dtype) - - @property - def sparse_kv_cache_ratio(self) -> tuple[float, float, float, float | None]: - """ - Compute the relative byte share of each KV cache entry. - - Returns: - A tuple containing the ratios for: - - kv_cache[0] - - kv_cache[1] - - kv_cache[2] - - kv_cache[3] (None if Sparse C8 is disabled) - """ - - assert self.sparse_head_dim is not None - - def get_sparse_head_dim_virtual() -> tuple[int, int, int, int]: - assert self.sparse_head_dim is not None - assert self.cache_sparse_c8 is True - - kv_lora_rank, qk_rope_head_dim, index_k_head_dim = self.sparse_head_dim - - factor = get_dtype_size(self.dtype) // get_dtype_size(self.c8_k_cache_dtype) - index_k_head_dim_virtual = index_k_head_dim // factor - - assert get_dtype_size(self.dtype) == get_dtype_size(self.c8_k_scale_cache_dtype) - index_k_scale_head_dim_virtual = 1 - - return ( - kv_lora_rank, - qk_rope_head_dim, - index_k_head_dim_virtual, - index_k_scale_head_dim_virtual, - ) - - if self.cache_sparse_c8: - virtual_dims = get_sparse_head_dim_virtual() - total_virtual_head_dim = sum(virtual_dims) - - return ( - total_virtual_head_dim / virtual_dims[0], # kv_cache[0] - total_virtual_head_dim / virtual_dims[1], # kv_cache[1] - total_virtual_head_dim / virtual_dims[2], # kv_cache[2] - total_virtual_head_dim / virtual_dims[3], # kv_cache[3] - ) - - return ( - self.head_size / self.sparse_head_dim[0], # kv_cache[0] - self.head_size / self.sparse_head_dim[1], # kv_cache[1] - self.head_size / self.sparse_head_dim[2], # kv_cache[2] - None, # kv_cache[3] does not exist - ) - - @classmethod - def merge(cls, specs: list[Self]) -> Self: - assert all(isinstance(spec, MLAAttentionSpec) for spec in specs), ( - "All attention layers in the same KV cache group must be MLAAttentionSpec." - ) - cache_dtype_str_set = set(spec.cache_dtype_str for spec in specs) - assert len(cache_dtype_str_set) == 1, ( - "All attention layers in the same KV cache group must use the same quantization method." - ) - return cls( - block_size=specs[0].block_size, - num_kv_heads=specs[0].num_kv_heads, - head_size=specs[0].head_size, - sparse_head_dim=specs[0].sparse_head_dim, - dtype=specs[0].dtype, - cache_dtype_str=cache_dtype_str_set.pop(), - cache_sparse_c8=specs[0].cache_sparse_c8, - ) - - -vllm.v1.kv_cache_interface.MLAAttentionSpec = AscendMLAAttentionSpec diff --git a/vllm_ascend/worker/model_runner_v1.py b/vllm_ascend/worker/model_runner_v1.py index 3c6d48bc..017ca2d3 100644 --- a/vllm_ascend/worker/model_runner_v1.py +++ b/vllm_ascend/worker/model_runner_v1.py @@ -88,7 +88,6 @@ from vllm.v1.worker.ubatch_utils import ( ) from vllm.v1.worker.utils import AttentionGroup -# yapf: enable from vllm_ascend.ascend_config import get_ascend_config from vllm_ascend.attention.attention_v1 import AscendAttentionState from vllm_ascend.attention.utils import AscendCommonAttentionMetadata, using_paged_attention @@ -101,6 +100,8 @@ from vllm_ascend.compilation.acl_graph import ( set_graph_params, update_full_graph_params, ) + +# yapf: enable from vllm_ascend.eplb.adaptor.vllm_adaptor import VllmEplbAdaptor from vllm_ascend.eplb.core.eplb_device_transfer_loader import D2DExpertWeightLoader from vllm_ascend.eplb.core.eplb_worker import EplbProcess @@ -277,21 +278,7 @@ class NPUModelRunner(GPUModelRunner): self.is_multimodal_model = self.model_config.is_multimodal_model self.block_size = vllm_config.cache_config.block_size # Set up Attention - self.use_sparse = hasattr(vllm_config.model_config, "hf_text_config") and hasattr( - vllm_config.model_config.hf_text_config, "index_topk" - ) - if self.use_sparse: - self.sparse_head_dim = ( - self.model_config.hf_text_config.kv_lora_rank, - self.model_config.hf_text_config.qk_rope_head_dim, - self.model_config.hf_text_config.index_head_dim, - ) - # dsa c8 - self.use_sparse_c8_indexer = self.ascend_config.enable_sparse_c8 - if self.use_sparse_c8_indexer: - self.c8_k_cache_dtype = torch.int8 - self.c8_k_scale_cache_dtype = torch.float16 - + self.use_sparse = hasattr(self.vllm_config.model_config.hf_text_config, "index_topk") self.attn_backend = get_attn_backend( 0, self.dtype, @@ -2642,7 +2629,7 @@ class NPUModelRunner(GPUModelRunner): to their corresponding memory buffer for K cache and V cache. """ # init kv cache tensors - kv_cache_raw_tensors: dict[str, torch.Tensor | torch.Tensor | None | None] = {} + kv_cache_raw_tensors: dict[str, torch.Tensor | torch.Tensor | None] = {} # prefill disaggregation need the addr of cache tensor be aligned with 2M alignment = 2 * 1024 * 1024 layer_kv_cache_spec: dict[str, KVCacheSpec] = {} @@ -2689,18 +2676,19 @@ class NPUModelRunner(GPUModelRunner): + self.model_config.hf_text_config.kv_lora_rank ) + dsa_k_cache_factor = None + dsa_k_cache_size = None if not self.model_config.use_mla: # for non-mla model, use FullAttentionSpec - k_tensor_split_factor = 2.0 - v_tensor_split_factor = 2.0 + k_tensor_split_factor = 2 + v_tensor_split_factor = 2 elif self.use_sparse: # for deepseek v3.2, we split the kv cache according to the corresponding ratio - kv_cache_spec = layer_kv_cache_spec[layer_name] - sparse_kv_cache_ratio = kv_cache_spec.sparse_kv_cache_ratio - k_tensor_split_factor = sparse_kv_cache_ratio[0] - v_tensor_split_factor = sparse_kv_cache_ratio[1] - dsa_k_tensor_split_factor = sparse_kv_cache_ratio[2] - dsa_k_scale_tensor_split_factor = sparse_kv_cache_ratio[3] + sparse_sum_head_size = sum(self._get_sparse_kv_cache_ratio()) + k_tensor_split_factor, v_tensor_split_factor, dsa_k_cache_factor = [ # type: ignore + sparse_sum_head_size / ratio for ratio in self._get_sparse_kv_cache_ratio() + ] + dsa_k_cache_size = int(kv_cache_tensor.size // dsa_k_cache_factor) else: # for other deepseek models, use MLAAttentionSpec k_tensor_split_factor = head_size / self.model_config.hf_text_config.kv_lora_rank @@ -2708,56 +2696,35 @@ class NPUModelRunner(GPUModelRunner): k_tensor_size = int(kv_cache_tensor.size // k_tensor_split_factor) v_tensor_size = int(kv_cache_tensor.size // v_tensor_split_factor) - dsa_k_tensor_size = None - dsa_k_scale_tensor_size = None - #### for deepseek sparse attention - if self.use_sparse: - dsa_k_tensor_size = int(kv_cache_tensor.size // dsa_k_tensor_split_factor) - if self.use_sparse_c8_indexer: - dsa_k_scale_tensor_size = int(kv_cache_tensor.size // dsa_k_scale_tensor_split_factor) # for other attentions, e.g., self_attn, sliding window attn if self.vllm_config.kv_transfer_config is None: k_tensor = torch.zeros(k_tensor_size, dtype=torch.int8, device=self.device) v_tensor = torch.zeros(v_tensor_size, dtype=torch.int8, device=self.device) - #### for deepseek sparse attention - if dsa_k_tensor_size is not None: - dsa_k_tensor = torch.zeros(dsa_k_tensor_size, dtype=torch.int8, device=self.device) - if dsa_k_scale_tensor_size is not None: - dsa_k_scale_tensor = torch.zeros( - dsa_k_scale_tensor_size, dtype=torch.int8, device=self.device - ) + #### k cache: for deepseek sparse attention + if dsa_k_cache_factor is not None: + dsa_k_cache_tensor = torch.zeros(dsa_k_cache_size, dtype=torch.int8, device=self.device) else: k_tensor = torch.zeros(k_tensor_size + alignment, dtype=torch.int8, device=self.device) v_tensor = torch.zeros(v_tensor_size + alignment, dtype=torch.int8, device=self.device) k_tensor = self._align_memory(k_tensor, alignment)[:k_tensor_size] v_tensor = self._align_memory(v_tensor, alignment)[:v_tensor_size] - #### for deepseek sparse attention - if dsa_k_tensor_size is not None: - dsa_k_tensor = torch.zeros( - dsa_k_tensor_size + alignment, dtype=torch.int8, device=self.device + #### k cache: for deepseek sparse attention + if dsa_k_cache_factor is not None and dsa_k_cache_size is not None: + dsa_k_cache_tensor = torch.zeros( + dsa_k_cache_size + alignment, dtype=torch.int8, device=self.device ) - dsa_k_tensor = self._align_memory(dsa_k_tensor, alignment)[:dsa_k_tensor_size] - if dsa_k_scale_tensor_size is not None: - dsa_k_scale_tensor = torch.zeros( - dsa_k_scale_tensor_size + alignment, dtype=torch.int8, device=self.device - ) - dsa_k_scale_tensor = self._align_memory( - dsa_k_scale_tensor, alignment - )[:dsa_k_scale_tensor_size] + dsa_k_cache_tensor = self._align_memory(dsa_k_cache_tensor, alignment)[:dsa_k_cache_size] for layer_name_inner in kv_cache_tensor.shared_by: # shared the attn kvcache for all shared layers if "attn" in layer_name_inner and "linear_attn" not in layer_name_inner: - if self.use_sparse: - if self.use_sparse_c8_indexer: - kv_cache_raw_tensors[layer_name_inner] = ( - k_tensor, v_tensor, dsa_k_tensor, dsa_k_scale_tensor - ) - else: - kv_cache_raw_tensors[layer_name_inner] = (k_tensor, v_tensor, dsa_k_tensor) - else: - kv_cache_raw_tensors[layer_name_inner] = (k_tensor, v_tensor) + kv_cache_raw_tensors[layer_name_inner] = ( + (k_tensor, v_tensor) + if not self.use_sparse + else (k_tensor, v_tensor, dsa_k_cache_tensor) + ) + layer_names = set() for group in kv_cache_config.kv_cache_groups: for layer_name in group.layer_names: @@ -2799,23 +2766,13 @@ class NPUModelRunner(GPUModelRunner): # TODO: remove this after the OOM issue is located and fixed, otherwise, some model may # encounter OOM issue if isinstance(kv_cache_spec, AttentionSpec): + raw_dsa_k_tensor = None if self.use_sparse: - if self.use_sparse_c8_indexer: - raw_k_tensor, raw_v_tensor, raw_dsa_k_tensor, raw_dsa_k_scale_tensor = kv_cache_raw_tensors[ # type: ignore - layer_name] - assert raw_dsa_k_tensor is not None - assert raw_dsa_k_scale_tensor is not None - sum_page_size_bytes = ( - raw_k_tensor.numel() - + raw_v_tensor.numel() - + raw_dsa_k_tensor.numel() - + raw_dsa_k_scale_tensor.numel() - ) - else: - raw_k_tensor, raw_v_tensor, raw_dsa_k_tensor = kv_cache_raw_tensors[ # type: ignore - layer_name] - assert raw_dsa_k_tensor is not None - sum_page_size_bytes = raw_k_tensor.numel() + raw_v_tensor.numel() + raw_dsa_k_tensor.numel() + raw_k_tensor, raw_v_tensor, raw_dsa_k_tensor = kv_cache_raw_tensors[ # type: ignore + layer_name + ] + assert raw_dsa_k_tensor is not None + sum_page_size_bytes = raw_k_tensor.numel() + raw_v_tensor.numel() + raw_dsa_k_tensor.numel() elif self.use_hybrid_blocks and self.hybrid_with_attn_and_mamba: # Currently, we ensure that the same kvcache format is used even if there # is no shared layer, such as the full attention mtp layer of qwen3.5, etc. @@ -2862,7 +2819,7 @@ class NPUModelRunner(GPUModelRunner): kv_cache_shape = self.attn_backend.get_kv_cache_shape( num_blocks, kv_cache_spec.block_size, kv_cache_spec.num_kv_heads, kv_cache_spec.head_size ) - + dtype = kv_cache_spec.dtype if not self.model_config.use_mla: k_shape = kv_cache_shape[1:] v_shape = k_shape @@ -2881,37 +2838,19 @@ class NPUModelRunner(GPUModelRunner): num_kv_heads, self.model_config.hf_text_config.qk_rope_head_dim, ] - k_cache = raw_k_tensor.view(kv_cache_spec.dtype).view(k_shape) - v_cache = raw_v_tensor.view(kv_cache_spec.dtype).view(v_shape) + k_cache = raw_k_tensor.view(dtype).view(k_shape) + v_cache = raw_v_tensor.view(dtype).view(v_shape) - if self.use_sparse: + if self.use_sparse and raw_dsa_k_tensor is not None: + index_head_dim = self._get_sparse_kv_cache_ratio()[-1] dsa_k_cache_shape = ( num_blocks, kv_cache_spec.block_size, kv_cache_spec.num_kv_heads, - self.model_config.hf_text_config.index_head_dim, + index_head_dim, ) - if self.use_sparse_c8_indexer: - # dsa_k - dsa_k_cache = raw_dsa_k_tensor.view(self.c8_k_cache_dtype).view(dsa_k_cache_shape) - # dsa_k_scale - dsa_k_scale_cache_shape = ( - num_blocks, - kv_cache_spec.block_size, - kv_cache_spec.num_kv_heads, - 1, - ) - assert raw_dsa_k_scale_tensor is not None - dsa_k_scale_cache = ( - raw_dsa_k_scale_tensor - .view(self.c8_k_scale_cache_dtype) - .view(dsa_k_scale_cache_shape) - ) - kv_caches[layer_name] = (k_cache, v_cache, dsa_k_cache, dsa_k_scale_cache) - else: - # dsa_k - dsa_k_cache = raw_dsa_k_tensor.view(kv_cache_spec.dtype).view(dsa_k_cache_shape) - kv_caches[layer_name] = (k_cache, v_cache, dsa_k_cache) + dsa_k_cache = raw_dsa_k_tensor.view(dtype).view(dsa_k_cache_shape) + kv_caches[layer_name] = (k_cache, v_cache, dsa_k_cache) else: kv_caches[layer_name] = (k_cache, v_cache) elif isinstance(kv_cache_spec, MambaSpec): @@ -3007,7 +2946,7 @@ class NPUModelRunner(GPUModelRunner): # of mamba block. In this case, BlockTable.block_size will never equal # to kernel_block_sizes[0] self.kernel_block_sizes.append([0]) - + max_num_blocks = [] max_model_len = max(self.max_model_len, self.max_encoder_len) for i, kv_cache_group in enumerate(kv_cache_config.kv_cache_groups): @@ -3021,7 +2960,7 @@ class NPUModelRunner(GPUModelRunner): max_num_blocks_per_req = max(max_num_blocks_per_req, mamba_blocks_per_req) max_num_blocks.append(max_num_blocks_per_req) - + if block_sizes != [self.cache_config.block_size] or self.kernel_block_sizes != [[self.cache_config.block_size]]: assert self.cache_config.cpu_offload_gb == 0, ( "Cannot re-initialize the input batch when CPU weight " @@ -3181,31 +3120,18 @@ class NPUModelRunner(GPUModelRunner): elif isinstance(attn_module, MLAAttention): if self.use_sparse: - # `MLAAttentionSpec` is temporarily patched to `AscendMLAAttentionSpec`. - # Re-importing it at runtime will therefore resolve to the patched class. - # Rename it here to make this behavior explicit. - from vllm.v1.kv_cache_interface import MLAAttentionSpec as AscendMLAAttentionSpec - # TODO(rjg-lyh): when kv_cache_spec's refactor is ready, - # implement it by creating a new kv_cache_spec class - kv_cache_spec[layer_name] = AscendMLAAttentionSpec( + # TODO(cmq): This is a hack way to fix deepseek kvcache when + # using DSA. Fix the spec in vLLM is the final way. + sparse_sum_head_size = sum(self._get_sparse_kv_cache_ratio()) + kv_cache_spec[layer_name] = MLAAttentionSpec( block_size=self.block_size, num_kv_heads=1, - head_size=sum(self.sparse_head_dim), - sparse_head_dim=self.sparse_head_dim, + head_size=sparse_sum_head_size, dtype=self.kv_cache_dtype, cache_dtype_str=self.vllm_config.cache_config.cache_dtype, - cache_sparse_c8=self.use_sparse_c8_indexer, ) elif spec := attn_module.get_kv_cache_spec(self.vllm_config): - assert isinstance(spec, MLAAttentionSpec) - from vllm.v1.kv_cache_interface import MLAAttentionSpec as AscendMLAAttentionSpec - kv_cache_spec[layer_name] = AscendMLAAttentionSpec( - block_size=spec.block_size, - num_kv_heads=spec.num_kv_heads, - head_size=spec.head_size, - dtype=spec.dtype, - cache_dtype_str=spec.cache_dtype_str, - ) + kv_cache_spec[layer_name] = spec elif isinstance(attn_module, MambaBase): mamba_layers[layer_name] = attn_module @@ -3223,6 +3149,16 @@ class NPUModelRunner(GPUModelRunner): return kv_cache_spec + def _get_sparse_kv_cache_ratio(self) -> list[int]: + # TODO:If C8 is supported, we need to consider the number of bytes occupied by different dtypes + # when calculating the ratio,for example: + # [kv_lora_rank * torch.int8.itemsize, qk_rope_head_dim * torch.bfloat16.itemsize, ...] + return [ + self.model_config.hf_text_config.kv_lora_rank, + self.model_config.hf_text_config.qk_rope_head_dim, + self.model_config.hf_text_config.index_head_dim, + ] + def _check_and_update_cudagraph_mode( self, attention_backends: list[set[type[AttentionBackend]]],