[bugfix] restore pr-7029 and fix patch error (#7294)

### What this PR does / why we need it?
This PR restores #7029, which adds W8A8C8 support for dsv3.2/glm5 via the
`lightning_indexer_quant` op in the pd-mix stage. The original PR was
reverted by #7288 because its patch did not work with the recompute
scheduler; this PR additionally fixes that patching issue so the op now
works correctly with the recompute scheduler.

### Does this PR introduce _any_ user-facing change?
Yes. To enable LI C8, users need to set the `enable_sparse_c8` option to
`"true"` in `additional_config`.
- vLLM version: v0.17.0
- vLLM main: 4034c3d32e

---------

Signed-off-by: rjg-lyh <1318825571@qq.com>
.github/workflows/scripts/config.yaml
@@ -70,6 +70,8 @@ e2e-2card-light:
     estimated_time: 220
   - name: tests/e2e/multicard/2-cards/test_offline_inference_distributed.py::test_deepseek3_2_w8a8_pruning_mtp_tp2_ep
     estimated_time: 90
+  - name: tests/e2e/multicard/2-cards/test_offline_inference_distributed.py::test_deepseek3_2_w8a8c8_pruning_mtp_tp2_ep
+    estimated_time: 90
   - name: tests/e2e/multicard/2-cards/test_offline_inference_distributed.py::test_gpt_oss_distributed_tp2
     estimated_time: 180
@@ -122,6 +124,8 @@ e2e-multicard-2-cards:
     estimated_time: 71
   - name: tests/e2e/multicard/2-cards/test_offline_inference_distributed.py::test_deepseek3_2_w8a8_pruning_mtp_tp2_ep
     estimated_time: 111
+  - name: tests/e2e/multicard/2-cards/test_offline_inference_distributed.py::test_deepseek3_2_w8a8c8_pruning_mtp_tp2_ep
+    estimated_time: 111
   - name: tests/e2e/multicard/2-cards/test_offline_inference_distributed.py::test_qwen3_w4a4_distributed_tp2
     estimated_time: 180
   - name: tests/e2e/multicard/2-cards/test_pipeline_parallel.py
@@ -25,7 +25,7 @@ elif [[ "$SOC_VERSION" =~ ^ascend910b ]]; then
     export CPATH=${ABSOLUTE_CATLASS_PATH}:${CPATH}

-    CUSTOM_OPS="moe_grouped_matmul;grouped_matmul_swiglu_quant_weight_nz_tensor_list;lightning_indexer_vllm;sparse_flash_attention;matmul_allreduce_add_rmsnorm;moe_init_routing_custom;moe_gating_top_k;add_rms_norm_bias;apply_top_k_top_p_custom;transpose_kv_cache_by_block;copy_and_expand_eagle_inputs;causal_conv1d;"
+    CUSTOM_OPS="moe_grouped_matmul;grouped_matmul_swiglu_quant_weight_nz_tensor_list;lightning_indexer_vllm;sparse_flash_attention;matmul_allreduce_add_rmsnorm;moe_init_routing_custom;moe_gating_top_k;add_rms_norm_bias;apply_top_k_top_p_custom;transpose_kv_cache_by_block;copy_and_expand_eagle_inputs;causal_conv1d;lightning_indexer_quant;"
     SOC_ARG="ascend910b"
 elif [[ "$SOC_VERSION" =~ ^ascend910_93 ]]; then
     # ASCEND910C (A3) series
@@ -67,6 +67,7 @@ elif [[ "$SOC_VERSION" =~ ^ascend910_93 ]]; then
         "copy_and_expand_eagle_inputs"
         "causal_conv1d"
         "moe_grouped_matmul"
+        "lightning_indexer_quant"
     )
     CUSTOM_OPS=$(IFS=';'; echo "${CUSTOM_OPS_ARRAY[*]}")
     SOC_ARG="ascend910_93"
@@ -0,0 +1,81 @@
/*
 * Copyright (c) Huawei Technologies Co., Ltd. 2026. All rights reserved.
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
#ifndef LIGHTING_INDEXER_QUANT_VLLM_TORCH_ADPT_H
#define LIGHTING_INDEXER_QUANT_VLLM_TORCH_ADPT_H

namespace vllm_ascend {

at::Tensor npu_lightning_indexer_quant(
    const at::Tensor &query, const at::Tensor &key, const at::Tensor &weights,
    const at::Tensor &query_dequant_scale, const at::Tensor &key_dequant_scale,
    const c10::optional<at::Tensor> &actual_seq_lengths_query,
    const c10::optional<at::Tensor> &actual_seq_lengths_key,
    const c10::optional<at::Tensor> &block_table, int64_t query_quant_mode, int64_t key_quant_mode,
    c10::string_view layout_query, c10::string_view layout_key, int64_t sparse_count, int64_t sparse_mode)
{
    std::string query_layout_str = std::string(layout_query);
    std::string key_layout_str = std::string(layout_key);

    const int SIZE = 8;
    const int DIM_0 = 0;
    const int DIM_1 = 1;
    const int DIM_2 = 2;
    const int DIM_3 = 3;

    at::SmallVector<int64_t, SIZE> output_size;
    for (size_t i = 0; i < query.sizes().size(); i++) {
        TORCH_CHECK(query.size(i) > 0, "All values within query's shape should be greater "
                    "than 0, but shape[", i, "] is ", query.size(i));
    }
    for (size_t i = 0; i < key.sizes().size(); i++) {
        TORCH_CHECK(key.size(i) > 0, "All values within key's shape should be greater "
                    "than 0, but shape[", i, "] is ", key.size(i));
    }
    TORCH_CHECK(sparse_count > 0, "sparse count should be greater than 0, but now is ", sparse_count);
    int64_t keyHeadNum = (key_layout_str == "TND") ? key.size(DIM_1) : key.size(DIM_2);
    if (query_layout_str == "BSND") {
        output_size = {query.size(DIM_0), query.size(DIM_1), keyHeadNum, sparse_count};
    } else {
        output_size = {query.size(DIM_0), keyHeadNum, sparse_count};
    }
    at::Tensor lightning_indexer_quant_output = at::empty(output_size, query.options().dtype(at::kInt));

    // convert str
    char *query_layout_ptr = const_cast<char *>(query_layout_str.c_str());
    char *key_layout_ptr = const_cast<char *>(key_layout_str.c_str());

    EXEC_NPU_CMD(aclnnLightningIndexerQuant,
                 query,
                 key,
                 weights,
                 query_dequant_scale,
                 key_dequant_scale,
                 actual_seq_lengths_query,
                 actual_seq_lengths_key,
                 block_table,
                 query_quant_mode,
                 key_quant_mode,
                 query_layout_ptr,
                 key_layout_ptr,
                 sparse_count,
                 sparse_mode,
                 lightning_indexer_quant_output
    );

    return lightning_indexer_quant_output;
}
} // namespace vllm_ascend
#endif
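
For reference, a short Python sketch (pure shape arithmetic, no NPU calls;
`liq_output_shape` is a hypothetical helper name) of the output-shape rule
that the adapter above implements:

```python
def liq_output_shape(query_shape, key_shape, layout_query, layout_key, sparse_count):
    # N2 comes from key: dim 1 for TND, dim 2 for BSND / PA_BSND.
    key_head_num = key_shape[1] if layout_key == "TND" else key_shape[2]
    if layout_query == "BSND":
        return (query_shape[0], query_shape[1], key_head_num, sparse_count)  # [B, S, N2, K]
    return (query_shape[0], key_head_num, sparse_count)                      # TND: [T, N2, K]

# BSND query over a paged (PA_BSND) key cache: output is [B, S, N2, sparse_count], int32.
assert liq_output_shape((2, 4, 8, 128), (16, 128, 1, 128), "BSND", "PA_BSND", 2048) == (2, 4, 1, 2048)
```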
csrc/lightning_indexer_quant/op_host/CMakeLists.txt (new file, 41 lines)
@@ -0,0 +1,41 @@
# This program is free software, you can redistribute it and/or modify it.
# Copyright (c) 2025 Huawei Technologies Co., Ltd.
# This file is a part of the CANN Open Software.
# Licensed under CANN Open Software License Agreement Version 2.0 (the "License").
# Please refer to the License for details. You may not use this file except in compliance with the License.
# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE.
# See LICENSE in the root of the software repository for the full text of the License.
# ======================================================================================================================

add_ops_compile_options(
    OP_NAME LightningIndexerQuant
    OPTIONS --cce-auto-sync=off
            -Wno-deprecated-declarations
            -Werror
            -mllvm -cce-aicore-hoist-movemask=false
            --op_relocatable_kernel_binary=true
)

set(lightning_indexer_quant_depends transformer/attention/lightning_indexer_quant PARENT_SCOPE)

target_sources(op_host_aclnn PRIVATE
    lightning_indexer_quant_def.cpp
)

target_sources(optiling PRIVATE
    lightning_indexer_quant_tiling.cpp
)

if (NOT BUILD_OPEN_PROJECT)
    target_sources(opmaster_ct PRIVATE
        lightning_indexer_quant_tiling.cpp
    )
endif ()

target_include_directories(optiling PRIVATE
    ${CMAKE_CURRENT_SOURCE_DIR}/op_host
)

target_sources(opsproto PRIVATE
    lightning_indexer_quant_proto.cpp
)
@@ -0,0 +1,85 @@
/**
 * This program is free software, you can redistribute it and/or modify it.
 * Copyright (c) 2025 Huawei Technologies Co., Ltd.
 * This file is a part of the CANN Open Software.
 * Licensed under CANN Open Software License Agreement Version 2.0 (the "License").
 * Please refer to the License for details. You may not use this file except in compliance with the License.
 * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE.
 * See LICENSE in the root of the software repository for the full text of the License.
 */

/*!
 * \file lightning_indexer_quant_def.cpp
 * \brief
 */
#include <cstdint>

#include "register/op_def_registry.h"

namespace ops {
class LightningIndexerQuant : public OpDef {
public:
    explicit LightningIndexerQuant(const char *name) : OpDef(name)
    {
        this->Input("query")
            .ParamType(REQUIRED)
            .DataType({ge::DT_INT8})
            .Format({ge::FORMAT_ND})
            .AutoContiguous();
        this->Input("key")
            .ParamType(REQUIRED)
            .DataType({ge::DT_INT8})
            .Format({ge::FORMAT_ND})
            .AutoContiguous();
        this->Input("weights")
            .ParamType(REQUIRED)
            .DataType({ge::DT_FLOAT16})
            .Format({ge::FORMAT_ND})
            .AutoContiguous();
        this->Input("query_dequant_scale")
            .ParamType(REQUIRED)
            .DataType({ge::DT_FLOAT16})
            .Format({ge::FORMAT_ND})
            .AutoContiguous();
        this->Input("key_dequant_scale")
            .ParamType(REQUIRED)
            .DataType({ge::DT_FLOAT16})
            .Format({ge::FORMAT_ND})
            .AutoContiguous();
        this->Input("actual_seq_lengths_query")
            .ParamType(OPTIONAL)
            .DataType({ge::DT_INT32})
            .Format({ge::FORMAT_ND})
            .AutoContiguous();
        this->Input("actual_seq_lengths_key")
            .ParamType(OPTIONAL)
            .DataType({ge::DT_INT32})
            .Format({ge::FORMAT_ND})
            .AutoContiguous();
        this->Input("block_table")
            .ParamType(OPTIONAL)
            .DataType({ge::DT_INT32})
            .Format({ge::FORMAT_ND})
            .AutoContiguous();
        this->Output("sparse_indices").ParamType(REQUIRED).DataType({ge::DT_INT32}).Format({ge::FORMAT_ND});
        this->Attr("query_quant_mode").AttrType(REQUIRED).Int(0); // 0: default, per-token-head
        this->Attr("key_quant_mode").AttrType(REQUIRED).Int(0);   // 0: default, per-token-head
        this->Attr("layout_query").AttrType(OPTIONAL).String("BSND");
        this->Attr("layout_key").AttrType(OPTIONAL).String("PA_BSND");
        this->Attr("sparse_count").AttrType(OPTIONAL).Int(2048); // 2048: default, select the top 2048
        this->Attr("sparse_mode").AttrType(OPTIONAL).Int(3);     // 3: default, compute the lower triangle only
        OpAICoreConfig aicore_config;
        aicore_config.DynamicCompileStaticFlag(true)
            .DynamicFormatFlag(true)
            .DynamicRankSupportFlag(true)
            .DynamicShapeSupportFlag(true)
            .NeedCheckSupportFlag(false)
            .PrecisionReduceFlag(true)
            .ExtendCfgInfo("aclnnSupport.value", "support_aclnn")
            .ExtendCfgInfo("jitCompile.flag", "static_false,dynamic_false");
        this->AICore().AddConfig("ascend910b", aicore_config);
        this->AICore().AddConfig("ascend910_93", aicore_config);
    }
};
OP_ADD(LightningIndexerQuant);
} // namespace ops
@@ -0,0 +1,91 @@
/**
 * This program is free software, you can redistribute it and/or modify it.
 * Copyright (c) 2025 Huawei Technologies Co., Ltd.
 * This file is a part of the CANN Open Software.
 * Licensed under CANN Open Software License Agreement Version 2.0 (the "License").
 * Please refer to the License for details. You may not use this file except in compliance with the License.
 * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE.
 * See LICENSE in the root of the software repository for the full text of the License.
 */

/*!
 * \file lightning_indexer_quant_proto.cpp
 * \brief
 */
#include <graph/utils/type_utils.h>
#include <register/op_impl_registry.h>

#include "error/ops_error.h"

using namespace ge;

namespace ops {
constexpr uint32_t QUERY_INDEX = 0;
constexpr uint32_t KEY_INDEX = 1;
constexpr uint32_t ATTR_QUERY_LAYOUT_INDEX = 2;
constexpr uint32_t ATTR_KV_LAYOUT_INDEX = 3;
constexpr uint32_t ATTR_SPARSE_COUNT_INDEX = 4;

static ge::graphStatus InferShapeLightningIndexerQuant(gert::InferShapeContext *context)
{
    if (context == nullptr) {
        OPS_LOG_E("LightningIndexerQuant", "context is nullptr!");
        return ge::GRAPH_FAILED;
    }
    const gert::Shape *queryShape = context->GetInputShape(QUERY_INDEX);
    OPS_LOG_E_IF_NULL(context, queryShape, return ge::GRAPH_FAILED);
    const gert::Shape *keyShape = context->GetInputShape(KEY_INDEX);
    OPS_LOG_E_IF_NULL(context, keyShape, return ge::GRAPH_FAILED);
    gert::Shape *outShape = context->GetOutputShape(0);

    auto attrs = context->GetAttrs();
    OPS_LOG_E_IF_NULL(context, attrs, return ge::GRAPH_FAILED);
    const char *inputLayoutQueryPtr = attrs->GetAttrPointer<char>(ATTR_QUERY_LAYOUT_INDEX);
    OPS_LOG_E_IF_NULL(context, inputLayoutQueryPtr, return ge::GRAPH_FAILED);
    const char *inputLayoutKeyPtr = attrs->GetAttrPointer<char>(ATTR_KV_LAYOUT_INDEX);
    OPS_LOG_E_IF_NULL(context, inputLayoutKeyPtr, return ge::GRAPH_FAILED);
    const int64_t *sparse_count = attrs->GetInt(ATTR_SPARSE_COUNT_INDEX);
    OPS_LOG_E_IF_NULL(context, sparse_count, return ge::GRAPH_FAILED);

    std::string inputLayoutQueryPtrStr = std::string(inputLayoutQueryPtr);
    std::string inputLayoutKeyPtrStr = std::string(inputLayoutKeyPtr);
    if (inputLayoutQueryPtrStr != "TND" && inputLayoutQueryPtrStr != "BSND") {
        OPS_LOG_E(context, "The input layout query should be TND or BSND, but got %s.", inputLayoutQueryPtrStr.c_str());
        return GRAPH_FAILED;
    }

    outShape->SetDimNum(queryShape->GetDimNum());
    int64_t keyHeadNum = (inputLayoutKeyPtrStr == "TND") ? keyShape->GetDim(1) : keyShape->GetDim(2);
    if (inputLayoutQueryPtrStr == "BSND") {
        outShape->SetDim(0, queryShape->GetDim(0)); // 0: dim B
        outShape->SetDim(1, queryShape->GetDim(1)); // 1: dim S
        outShape->SetDim(2, keyHeadNum);            // 2: dim N
        outShape->SetDim(3, *sparse_count);         // 3: dim K
    } else {
        outShape->SetDim(0, queryShape->GetDim(0)); // 0: dim T
        outShape->SetDim(1, keyHeadNum);            // 1: output shape's N dim (key shape's N dim)
        outShape->SetDim(2, *sparse_count);         // 2: dim K
    }

    OPS_LOG_D(context->GetNodeName(), "LightningIndexerQuant InferShape end.");
    return ge::GRAPH_SUCCESS;
}

static ge::graphStatus InferDataTypeLightningIndexerQuant(gert::InferDataTypeContext *context)
{
    if (context == nullptr) {
        OPS_LOG_E("LightningIndexerQuant", "InferDataTypeContext context is nullptr!");
        return ge::GRAPH_FAILED;
    }
    OPS_LOG_D(context->GetNodeName(), "Enter LightningIndexerQuant InferDataType impl.");
    // default index data type is int32
    ge::DataType outputType = ge::DT_INT32;
    context->SetOutputDataType(0, outputType);
    OPS_LOG_D(context->GetNodeName(), "LightningIndexerQuant InferDataType end.");
    return GRAPH_SUCCESS;
}

IMPL_OP_INFERSHAPE(LightningIndexerQuant)
    .InferShape(InferShapeLightningIndexerQuant)
    .InferDataType(InferDataTypeLightningIndexerQuant);
} // namespace ops
@@ -0,0 +1,828 @@
/**
 * This program is free software, you can redistribute it and/or modify it.
 * Copyright (c) 2025 Huawei Technologies Co., Ltd.
 * This file is a part of the CANN Open Software.
 * Licensed under CANN Open Software License Agreement Version 2.0 (the "License").
 * Please refer to the License for details. You may not use this file except in compliance with the License.
 * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE.
 * See LICENSE in the root of the software repository for the full text of the License.
 */

/*!
 * \file lightning_indexer_quant_tiling.cpp
 * \brief
 */

#include "lightning_indexer_quant_tiling.h"

#include "../op_kernel/lightning_indexer_quant_template_tiling_key.h"

using namespace ge;
using namespace AscendC;
using std::map;
using std::string;
namespace optiling {
// -------------------------- LIQInfoParser member function definitions --------------------------
ge::graphStatus LIQInfoParser::CheckRequiredInOutExistence() const
{
    OPS_ERR_IF(opParamInfo_.query.shape == nullptr, OPS_LOG_E(opName_, "Shape of tensor query is nullptr"),
               return ge::GRAPH_FAILED);
    OPS_ERR_IF(opParamInfo_.query.desc == nullptr, OPS_LOG_E(opName_, "Desc of tensor query is nullptr"),
               return ge::GRAPH_FAILED);
    OPS_ERR_IF(opParamInfo_.key.shape == nullptr, OPS_LOG_E(opName_, "Shape of tensor key is nullptr"),
               return ge::GRAPH_FAILED);
    OPS_ERR_IF(opParamInfo_.key.desc == nullptr, OPS_LOG_E(opName_, "Desc of tensor key is nullptr"),
               return ge::GRAPH_FAILED);
    OPS_ERR_IF(opParamInfo_.weights.shape == nullptr, OPS_LOG_E(opName_, "Shape of tensor weights is nullptr"),
               return ge::GRAPH_FAILED);
    OPS_ERR_IF(opParamInfo_.weights.desc == nullptr, OPS_LOG_E(opName_, "Desc of tensor weights is nullptr"),
               return ge::GRAPH_FAILED);
    OPS_ERR_IF(opParamInfo_.query_dequant_scale.shape == nullptr,
               OPS_LOG_E(opName_, "Shape of tensor query_dequant_scale is nullptr"), return ge::GRAPH_FAILED);
    OPS_ERR_IF(opParamInfo_.query_dequant_scale.desc == nullptr,
               OPS_LOG_E(opName_, "Desc of tensor query_dequant_scale is nullptr"), return ge::GRAPH_FAILED);
    OPS_ERR_IF(opParamInfo_.key_dequant_scale.shape == nullptr,
               OPS_LOG_E(opName_, "Shape of tensor key_dequant_scale is nullptr"), return ge::GRAPH_FAILED);
    OPS_ERR_IF(opParamInfo_.key_dequant_scale.desc == nullptr,
               OPS_LOG_E(opName_, "Desc of tensor key_dequant_scale is nullptr"), return ge::GRAPH_FAILED);
    OPS_ERR_IF(opParamInfo_.attenOut.shape == nullptr, OPS_LOG_E(opName_, "Shape of tensor output is nullptr"),
               return ge::GRAPH_FAILED);
    OPS_ERR_IF(opParamInfo_.attenOut.desc == nullptr, OPS_LOG_E(opName_, "Desc of tensor output is nullptr"),
               return ge::GRAPH_FAILED);
    return ge::GRAPH_SUCCESS;
}

ge::graphStatus LIQInfoParser::CheckRequiredAttrExistence() const
{
    OPS_ERR_IF(opParamInfo_.layOutQuery == nullptr, OPS_LOG_E(opName_, "attr layout_query is nullptr"),
               return ge::GRAPH_FAILED);

    OPS_ERR_IF(opParamInfo_.layOutKey == nullptr, OPS_LOG_E(opName_, "attr layout_key is nullptr"),
               return ge::GRAPH_FAILED);

    OPS_ERR_IF(opParamInfo_.sparseCount == nullptr, OPS_LOG_E(opName_, "attr sparse_count is nullptr"),
               return ge::GRAPH_FAILED);

    OPS_ERR_IF(opParamInfo_.sparseMode == nullptr, OPS_LOG_E(opName_, "attr sparse_mode is nullptr"),
               return ge::GRAPH_FAILED);
    OPS_ERR_IF(opParamInfo_.queryQuantMode == nullptr, OPS_LOG_E(opName_, "query_quant_mode is nullptr"),
               return ge::GRAPH_FAILED);
    OPS_ERR_IF(opParamInfo_.keyQuantMode == nullptr, OPS_LOG_E(opName_, "key_quant_mode is nullptr"),
               return ge::GRAPH_FAILED);

    return ge::GRAPH_SUCCESS;
}

ge::graphStatus LIQInfoParser::CheckRequiredParaExistence() const
{
    if (CheckRequiredInOutExistence() != ge::GRAPH_SUCCESS || CheckRequiredAttrExistence() != ge::GRAPH_SUCCESS) {
        return ge::GRAPH_FAILED;
    }

    return ge::GRAPH_SUCCESS;
}

ge::graphStatus LIQInfoParser::GetOpName()
{
    if (context_->GetNodeName() == nullptr) {
        OPS_LOG_E("LightningIndexerQuant", "opName got from TilingContext is nullptr");
        return ge::GRAPH_FAILED;
    }
    opName_ = context_->GetNodeName();
    return ge::GRAPH_SUCCESS;
}

ge::graphStatus LIQInfoParser::GetNpuInfo()
{
    platformInfo_ = context_->GetPlatformInfo();
    OPS_ERR_IF(platformInfo_ == nullptr, OPS_LOG_E(opName_, "GetPlatformInfo is nullptr."), return ge::GRAPH_FAILED);

    auto ascendcPlatform = platform_ascendc::PlatformAscendC(platformInfo_);
    uint32_t aivNum = ascendcPlatform.GetCoreNumAiv();
    uint32_t aicNum = ascendcPlatform.GetCoreNumAic();
    OPS_ERR_IF(aicNum == 0 || aivNum == 0, OPS_LOG_E(opName_, "num of core obtained is 0."), return GRAPH_FAILED);

    socVersion_ = ascendcPlatform.GetSocVersion();
    if ((socVersion_ != platform_ascendc::SocVersion::ASCEND910B) &&
        (socVersion_ != platform_ascendc::SocVersion::ASCEND910_93)) {
        OPS_LOG_E(opName_, "SOC Version[%d] is not supported.", (int32_t)socVersion_);
        return GRAPH_FAILED;
    }
    OPS_ERR_IF(context_->GetWorkspaceSizes(1) == nullptr, OPS_LOG_E(opName_, "workSpaceSize got from ge is nullptr"),
               return ge::GRAPH_FAILED);
    OPS_ERR_IF(context_->GetRawTilingData() == nullptr,
               OPS_LOG_E(context_->GetNodeName(), "RawTilingData got from GE context is nullptr."),
               return ge::GRAPH_FAILED);

    return ge::GRAPH_SUCCESS;
}
void LIQInfoParser::GetOptionalInputParaInfo()
{
    opParamInfo_.actualSeqLengthsQ.tensor = context_->GetOptionalInputTensor(ACTUAL_SEQ_Q_INDEX);
    opParamInfo_.actualSeqLengthsQ.desc = context_->GetOptionalInputDesc(ACTUAL_SEQ_Q_INDEX);
    opParamInfo_.actualSeqLengthsK.tensor = context_->GetOptionalInputTensor(ACTUAL_SEQ_K_INDEX);
    opParamInfo_.actualSeqLengthsK.desc = context_->GetOptionalInputDesc(ACTUAL_SEQ_K_INDEX);
    opParamInfo_.blockTable.tensor = context_->GetOptionalInputTensor(BLOCK_TABLE_INDEX);
    opParamInfo_.blockTable.desc = context_->GetOptionalInputDesc(BLOCK_TABLE_INDEX);
}

void LIQInfoParser::GetInputParaInfo()
{
    opParamInfo_.query.desc = context_->GetInputDesc(QUERY_INDEX);
    opParamInfo_.query.shape = context_->GetInputShape(QUERY_INDEX);
    opParamInfo_.key.desc = context_->GetInputDesc(KEY_INDEX);
    opParamInfo_.key.shape = context_->GetInputShape(KEY_INDEX);
    opParamInfo_.weights.desc = context_->GetInputDesc(WEIGTHS_INDEX);
    opParamInfo_.weights.shape = context_->GetInputShape(WEIGTHS_INDEX);
    opParamInfo_.query_dequant_scale.desc = context_->GetInputDesc(QUERY_DEQUANT_SCALE_INDEX);
    opParamInfo_.query_dequant_scale.shape = context_->GetInputShape(QUERY_DEQUANT_SCALE_INDEX);
    opParamInfo_.key_dequant_scale.desc = context_->GetInputDesc(KEY_DEQUANT_SCALE_INDEX);
    opParamInfo_.key_dequant_scale.shape = context_->GetInputShape(KEY_DEQUANT_SCALE_INDEX);
    GetOptionalInputParaInfo();
}

void LIQInfoParser::GetOutputParaInfo()
{
    opParamInfo_.attenOut.desc = context_->GetOutputDesc(LIGHTNING_INDEXER_QUANT);
    opParamInfo_.attenOut.shape = context_->GetOutputShape(LIGHTNING_INDEXER_QUANT);
}

ge::graphStatus LIQInfoParser::GetAttrParaInfo()
{
    auto attrs = context_->GetAttrs();
    OPS_ERR_IF(attrs == nullptr, OPS_REPORT_VECTOR_INNER_ERR(context_->GetNodeName(), "attrs got from ge is nullptr"),
               return ge::GRAPH_FAILED);

    OPS_LOG_I(context_->GetNodeName(), "GetAttrParaInfo start");
    opParamInfo_.layOutQuery = attrs->GetStr(ATTR_QUERY_LAYOUT_INDEX);
    opParamInfo_.layOutKey = attrs->GetStr(ATTR_KEY_LAYOUT_INDEX);

    opParamInfo_.queryQuantMode = attrs->GetAttrPointer<int32_t>(ATTR_QUERY_QUANT_MODE_INDEX);
    opParamInfo_.keyQuantMode = attrs->GetAttrPointer<int32_t>(ATTR_KEY_QUANT_MODE_INDEX);
    opParamInfo_.layOutQuery = attrs->GetStr(ATTR_QUERY_LAYOUT_INDEX);
    opParamInfo_.layOutKey = attrs->GetStr(ATTR_KEY_LAYOUT_INDEX);
    opParamInfo_.sparseCount = attrs->GetAttrPointer<int32_t>(ATTR_SPARSE_COUNT_INDEX);
    opParamInfo_.sparseMode = attrs->GetAttrPointer<int32_t>(ATTR_SPARSE_MODE_INDEX);

    if (opParamInfo_.layOutQuery != nullptr) {
        OPS_LOG_I(context_->GetNodeName(), "layout_query is:%s", opParamInfo_.layOutQuery);
    }
    if (opParamInfo_.layOutKey != nullptr) {
        OPS_LOG_I(context_->GetNodeName(), "layout_key is:%s", opParamInfo_.layOutKey);
    }
    if (opParamInfo_.sparseCount != nullptr) {
        OPS_LOG_I(context_->GetNodeName(), "selected count is:%d", *opParamInfo_.sparseCount);
    }
    if (opParamInfo_.sparseMode != nullptr) {
        OPS_LOG_I(context_->GetNodeName(), "sparse mode is:%d", *opParamInfo_.sparseMode);
    }
    if (opParamInfo_.queryQuantMode != nullptr) {
        OPS_LOG_I(context_->GetNodeName(), "query_quant_mode mode is:%d", *opParamInfo_.queryQuantMode);
    }
    if (opParamInfo_.keyQuantMode != nullptr) {
        OPS_LOG_I(context_->GetNodeName(), "key_quant_mode mode is:%d", *opParamInfo_.keyQuantMode);
    }
    OPS_LOG_I(context_->GetNodeName(), "GetAttrParaInfo end");

    return ge::GRAPH_SUCCESS;
}
ge::graphStatus LIQInfoParser::CheckAttrParaInfo()
{
    std::string layout_key(opParamInfo_.layOutKey);
    std::string layout_query(opParamInfo_.layOutQuery);
    OPS_ERR_IF(
        ((std::string(opParamInfo_.layOutKey) == "BNSD") || (std::string(opParamInfo_.layOutKey) == "PA_BBND")),
        OPS_LOG_E(opName_, "input attr layout_key only supports PA_BSND, PA_BBND, BSND or TND, "
                  "but now layout_key is %s.", layout_key.c_str()),
        return ge::GRAPH_FAILED);
    OPS_ERR_IF(((std::string(opParamInfo_.layOutQuery) != "BSND") && (std::string(opParamInfo_.layOutQuery) != "TND")),
               OPS_LOG_E(opName_, "input attr layout_query only supports BSND or TND."), return ge::GRAPH_FAILED);
    OPS_ERR_IF(
        ((std::string(opParamInfo_.layOutKey) != "PA_BSND") &&
         (std::string(opParamInfo_.layOutQuery)) != (std::string(opParamInfo_.layOutKey))),
        OPS_LOG_E(opName_, "outside of PA, input attr layout_query and input attr layout_key must be the same, but now layout_key is %s, layout_query is %s.",
                  layout_key.c_str(), layout_query.c_str()), return ge::GRAPH_FAILED);
    OPS_ERR_IF(!((*opParamInfo_.sparseCount > 0) && (*opParamInfo_.sparseCount <= SPARSE_LIMIT)),
               OPS_LOG_E(opName_, "input attr sparse_count must be > 0 and <= 2048."), return ge::GRAPH_FAILED);
    OPS_ERR_IF(!((*opParamInfo_.sparseMode == 0) || (*opParamInfo_.sparseMode == SPARSE_MODE_LOWER)),
               OPS_LOG_E(opName_, "input attr sparse_mode only supports 0 or 3, but now is %u.",
                         *opParamInfo_.sparseMode), return ge::GRAPH_FAILED);

    OPS_ERR_IF(*opParamInfo_.queryQuantMode != 0, OPS_LOG_E(opName_, "input attr query_quant_mode only supports 0."),
               return ge::GRAPH_FAILED);

    OPS_ERR_IF(*opParamInfo_.keyQuantMode != 0, OPS_LOG_E(opName_, "input attr key_quant_mode only supports 0."),
               return ge::GRAPH_FAILED);

    return ge::GRAPH_SUCCESS;
}

ge::graphStatus LIQInfoParser::GetOpParaInfo()
{
    GetInputParaInfo();
    GetOutputParaInfo();
    if (ge::GRAPH_SUCCESS != GetAttrParaInfo()) {
        return ge::GRAPH_FAILED;
    }
    if (ge::GRAPH_SUCCESS != CheckAttrParaInfo()) {
        return ge::GRAPH_FAILED;
    }
    return ge::GRAPH_SUCCESS;
}

ge::graphStatus LIQInfoParser::GetAndCheckInOutDataType()
{
    inputQType_ = opParamInfo_.query.desc->GetDataType();
    inputKType_ = opParamInfo_.key.desc->GetDataType();
    weightsType_ = opParamInfo_.weights.desc->GetDataType();
    inputQueryScaleType_ = opParamInfo_.query_dequant_scale.desc->GetDataType();
    inputKeyScaleType_ = opParamInfo_.key_dequant_scale.desc->GetDataType();
    outputType_ = opParamInfo_.attenOut.desc->GetDataType();

    OPS_ERR_IF(!(inputQType_ == inputKType_),
               OPS_LOG_E(opName_, "The data types of the input query and key must be the same, but now are %d, %d respectively.",
                         inputQType_, inputKType_),
               return ge::GRAPH_FAILED);

    OPS_ERR_IF(
        !(inputQueryScaleType_ == inputKeyScaleType_),
        OPS_LOG_E(opName_, "The data types of the input query_dequant_scale and key_dequant_scale must be the same."),
        return ge::GRAPH_FAILED);

    OPS_ERR_IF(inputQType_ != ge::DT_INT8,
               OPS_LOG_E(opName_, "The data types of the input query and key must be int8."), return ge::GRAPH_FAILED);

    OPS_ERR_IF(weightsType_ != ge::DT_FLOAT16,
               OPS_LOG_E(opName_, "The data type of the input weights must be float16."), return ge::GRAPH_FAILED);

    OPS_ERR_IF(
        inputQueryScaleType_ != ge::DT_FLOAT16,
        OPS_LOG_E(opName_, "The data types of the input query_dequant_scale and key_dequant_scale must be float16."),
        return ge::GRAPH_FAILED);

    OPS_ERR_IF(outputType_ != ge::DT_INT32,
               OPS_LOG_E(opName_, "The data type of the output sparse_indices must be int32."),
               return ge::GRAPH_FAILED);

    return ge::GRAPH_SUCCESS;
}
ge::graphStatus LIQInfoParser::GetQueryKeyAndOutLayout()
{
    // Get the reference layouts of query and key
    const map<string, DataLayout> layoutQueryMap = {{"BSND", DataLayout::BSND}, {"TND", DataLayout::TND}};

    std::string layout_query(opParamInfo_.layOutQuery);
    auto QLayout_ = layoutQueryMap.find(layout_query);
    if (QLayout_ != layoutQueryMap.end()) {
        qLayout_ = QLayout_->second;
    }

    const map<string, DataLayout> layoutKeyMap = {
        {"BSND", DataLayout::BSND}, {"TND", DataLayout::TND},
        {"PA_BSND", DataLayout::PA_BSND}, {"PA_BBND", DataLayout::PA_BSND}};
    std::string layout_key(opParamInfo_.layOutKey);
    auto KLayout = layoutKeyMap.find(layout_key);
    if (KLayout != layoutKeyMap.end()) {
        kLayout_ = KLayout->second;
    }

    return ge::GRAPH_SUCCESS;
}

ge::graphStatus LIQInfoParser::GetAndCheckOptionalInput()
{
    if (kLayout_ == DataLayout::PA_BSND) {
        OPS_ERR_IF(opParamInfo_.blockTable.tensor == nullptr,
                   OPS_LOG_E(opName_, "when layout_key is PA_BSND, input block_table must not be null"),
                   return ge::GRAPH_FAILED);
        OPS_ERR_IF(
            opParamInfo_.actualSeqLengthsK.tensor == nullptr,
            OPS_LOG_E(opName_, "when layout_key is PA_BSND, input actual_seq_lengths_key must not be null"),
            return ge::GRAPH_FAILED);
        OPS_ERR_IF(opParamInfo_.blockTable.desc->GetDataType() != ge::DT_INT32,
                   OPS_LOG_E(opName_, "input block_table data type only supports int32"), return ge::GRAPH_FAILED);
    } else {
        OPS_ERR_IF(opParamInfo_.blockTable.tensor != nullptr,
                   OPS_LOG_E(opName_, "when layout_key is not PA_BSND, input block_table must be null"),
                   return ge::GRAPH_FAILED);
    }

    if (kLayout_ == DataLayout::TND) {
        OPS_ERR_IF(opParamInfo_.actualSeqLengthsK.tensor == nullptr,
                   OPS_LOG_E(opName_, "when layout_key is TND, input actual_seq_lengths_key must not be null"),
                   return ge::GRAPH_FAILED);
    }
    OPS_ERR_IF(opParamInfo_.actualSeqLengthsK.tensor != nullptr &&
                   opParamInfo_.actualSeqLengthsK.desc->GetDataType() != ge::DT_INT32,
               OPS_LOG_E(opName_, "input actual_seq_lengths_key data type only supports int32"),
               return ge::GRAPH_FAILED);
    if (qLayout_ == DataLayout::TND) {
        OPS_ERR_IF(opParamInfo_.actualSeqLengthsQ.tensor == nullptr,
                   OPS_LOG_E(opName_, "when layout_query is TND, input actual_seq_lengths_query must not be null"),
                   return ge::GRAPH_FAILED);
    }
    OPS_ERR_IF(opParamInfo_.actualSeqLengthsQ.tensor != nullptr &&
                   opParamInfo_.actualSeqLengthsQ.desc->GetDataType() != ge::DT_INT32,
               OPS_LOG_E(opName_, "input actual_seq_lengths_query data type only supports int32"),
               return ge::GRAPH_FAILED);

    return ge::GRAPH_SUCCESS;
}

ge::graphStatus LIQInfoParser::CheckShapeDim()
{
    OPS_ERR_IF((opParamInfo_.blockTable.tensor != nullptr) &&
                   (opParamInfo_.blockTable.tensor->GetStorageShape().GetDimNum() != DIM_NUM_TWO),
               OPS_LOG_E(opName_, "the dim num of block_table's shape should be 2, but now is %u",
                         opParamInfo_.blockTable.tensor->GetStorageShape().GetDimNum()), return ge::GRAPH_FAILED);
    OPS_ERR_IF(
        (kLayout_ == DataLayout::PA_BSND) && (opParamInfo_.key.shape->GetStorageShape().GetDimNum() != DIM_NUM_FOUR),
        OPS_LOG_E(opName_, "the dim num of key's shape should be 4, but now is %u",
                  opParamInfo_.key.shape->GetStorageShape().GetDimNum()), return ge::GRAPH_FAILED);

    uint32_t qShapeDim = opParamInfo_.query.shape->GetStorageShape().GetDimNum();
    uint32_t weightsShapeDim = opParamInfo_.weights.shape->GetStorageShape().GetDimNum();
    uint32_t outShapeDim = opParamInfo_.attenOut.shape->GetStorageShape().GetDimNum();
    uint32_t expectShapeDim = DIM_NUM_FOUR;
    if (qLayout_ == DataLayout::TND) {
        expectShapeDim = DIM_NUM_THREE;
    }
    OPS_ERR_IF(
        qShapeDim != expectShapeDim,
        OPS_LOG_E(opName_, "the dim num of query's shape should be %u, but now is %u", expectShapeDim, qShapeDim),
        return ge::GRAPH_FAILED);
    OPS_ERR_IF(outShapeDim != expectShapeDim,
               OPS_LOG_E(opName_, "the dim num of sparse_indices's shape should be %u, but now is %u", expectShapeDim,
                         outShapeDim),
               return ge::GRAPH_FAILED);
    OPS_ERR_IF(!(weightsShapeDim == expectShapeDim - 1),
               OPS_LOG_E(opName_, "the dim num of weights's shape should be %u, but now is %u", expectShapeDim - 1,
                         weightsShapeDim),
               return ge::GRAPH_FAILED);

    return ge::GRAPH_SUCCESS;
}
ge::graphStatus LIQInfoParser::GetN1Size()
{
    if (qLayout_ == DataLayout::BSND) {
        n1Size_ = static_cast<uint32_t>(opParamInfo_.query.shape->GetStorageShape().GetDim(DIM_IDX_TWO));
    } else {
        // TND
        n1Size_ = static_cast<uint32_t>(opParamInfo_.query.shape->GetStorageShape().GetDim(DIM_IDX_ONE));
    }
    OPS_LOG_I(context_->GetNodeName(), "n1Size is %d", n1Size_);
    return ge::GRAPH_SUCCESS;
}

ge::graphStatus LIQInfoParser::GetActualSeqLenSize(uint32_t &size, const gert::Tensor *tensor,
                                                   const std::string &actualSeqLenName)
{
    size = static_cast<uint32_t>(tensor->GetShapeSize());
    if (size <= 0) {
        OPS_LOG_E(opName_, "%s's shape size is %u, it should be greater than 0.", actualSeqLenName.c_str(), size);
        return ge::GRAPH_FAILED;
    }
    return ge::GRAPH_SUCCESS;
}

ge::graphStatus LIQInfoParser::GetAndCheckN2Size()
{
    // PA_BSND
    if (kLayout_ == DataLayout::TND) {
        n2Size_ = static_cast<uint32_t>(opParamInfo_.key.shape->GetStorageShape().GetDim(DIM_IDX_ONE));
    } else {
        n2Size_ = static_cast<uint32_t>(opParamInfo_.key.shape->GetStorageShape().GetDim(DIM_IDX_TWO));
    }
    OPS_LOG_I(context_->GetNodeName(), "N2 is %d", n2Size_);
    OPS_ERR_IF(n2Size_ != 1, OPS_LOG_E(opName_, "key shape[2] is numhead, only support 1."), return ge::GRAPH_FAILED);

    return ge::GRAPH_SUCCESS;
}

ge::graphStatus LIQInfoParser::GetGSize()
{
    if (n1Size_ % n2Size_ != 0) {
        OPS_LOG_E(opName_, "input query's head_num %u must be a multiple of key's head_num %u.", n1Size_, n2Size_);
        return ge::GRAPH_FAILED;
    }
    gSize_ = n1Size_ / n2Size_;

    return ge::GRAPH_SUCCESS;
}

ge::graphStatus LIQInfoParser::GetBatchSize()
{
    // Get the reference batch size (B):
    // 1. For non-TND/NTD layouts, use query's batch_size dimension;
    // 2. For TND/NTD, actual_seq_lens_q must be provided, and its length is taken as the B-axis size
    if (qLayout_ == DataLayout::TND) {
        return GetActualSeqLenSize(bSize_, opParamInfo_.actualSeqLengthsQ.tensor, "input actual_seq_lengths_query");
    } else { // BSND
        bSize_ = opParamInfo_.query.shape->GetStorageShape().GetDim(DIM_IDX_ZERO);
        OPS_LOG_I(context_->GetNodeName(), "b: %d, s: %d, n: %d, d: %d",
                  opParamInfo_.query.shape->GetStorageShape().GetDim(DIM_IDX_ZERO),
                  opParamInfo_.query.shape->GetStorageShape().GetDim(DIM_IDX_ONE),
                  opParamInfo_.query.shape->GetStorageShape().GetDim(DIM_IDX_TWO),
                  opParamInfo_.query.shape->GetStorageShape().GetDim(DIM_IDX_THREE));
        return ge::GRAPH_SUCCESS;
    }
}

ge::graphStatus LIQInfoParser::GetHeadDim()
{
    // Use query's D dimension as the reference
    uint32_t dIndex = DIM_IDX_TWO;
    // Determine where D sits in the shape according to the layout
    switch (qLayout_) {
        case DataLayout::TND:
            // TND layout: [Total, N, D] -> D is dim 2 (index 2)
            dIndex = DIM_IDX_TWO;
            break;
        case DataLayout::BSND:
            // BSND layout: [Batch, SeqLen, N, D] -> D is dim 3 (index 3)
            dIndex = DIM_IDX_THREE;
            break;
        default:
            OPS_LOG_E(opName_, "unsupported layout for getting head dim.");
            return ge::GRAPH_FAILED;
    }
    headDim_ = opParamInfo_.query.shape->GetStorageShape().GetDim(dIndex);
    OPS_ERR_IF(headDim_ != HEAD_DIM_LIMIT, OPS_LOG_E(opName_, "input query's last dim head_dim only supports 128, but now is %u.", headDim_),
               return ge::GRAPH_FAILED);

    return ge::GRAPH_SUCCESS;
}

ge::graphStatus LIQInfoParser::GetS1Size()
{
    if (qLayout_ == DataLayout::BSND) {
        s1Size_ = opParamInfo_.query.shape->GetStorageShape().GetDim(1);
    }
    return ge::GRAPH_SUCCESS;
}

ge::graphStatus LIQInfoParser::GetAndCheckBlockSize()
{
    blockSize_ = static_cast<uint32_t>(opParamInfo_.key.shape->GetStorageShape().GetDim(1));
    OPS_LOG_I(context_->GetNodeName(), "blockSize_ is %d", blockSize_);

    OPS_ERR_IF(
        ((blockSize_ % BLOCK_SIZE_FACTOR != 0) || (blockSize_ == 0) || (blockSize_ > BLOCK_SIZE_LIMIT)),
        OPS_LOG_E(opName_, "input key's block_size must be a multiple of 16 and belong to (0, 1024], but now is %u.",
                  blockSize_),
        return ge::GRAPH_FAILED);

    return ge::GRAPH_SUCCESS;
}

ge::graphStatus LIQInfoParser::GetS2SizeForPageAttention()
{
    if (GetAndCheckBlockSize() != ge::GRAPH_SUCCESS) {
        return ge::GRAPH_FAILED;
    }

    int32_t blockCount_ = static_cast<uint32_t>(opParamInfo_.key.shape->GetStorageShape().GetDim(0));
    OPS_ERR_IF((blockCount_ == 0), OPS_LOG_E(opName_, "input key's block_count cannot be 0."), return ge::GRAPH_FAILED);

    maxBlockNumPerBatch_ = opParamInfo_.blockTable.tensor->GetStorageShape().GetDim(1);
    s2Size_ = maxBlockNumPerBatch_ * blockSize_;
    OPS_LOG_I(context_->GetNodeName(), "maxBlockNumPerBatch_ is %d, blockSize_ is %d, s2Size_ is %d",
              maxBlockNumPerBatch_, blockSize_, s2Size_);
    return ge::GRAPH_SUCCESS;
}

ge::graphStatus LIQInfoParser::GetS2SizeForBatchContinuous()
{
    std::string layout_key(opParamInfo_.layOutKey);
    if (kLayout_ == DataLayout::BSND) {
        s2Size_ = opParamInfo_.key.shape->GetStorageShape().GetDim(DIM_IDX_ONE);
    } else if (kLayout_ == DataLayout::TND) {
        s2Size_ = opParamInfo_.key.shape->GetStorageShape().GetDim(DIM_IDX_ZERO);
    }
    OPS_ERR_IF((kLayout_ != DataLayout::BSND) && (kLayout_ != DataLayout::TND),
               OPS_LOG_E(opName_, "the layout of key is %s, it is unsupported.", layout_key.c_str()),
               return ge::GRAPH_FAILED);
    return ge::GRAPH_SUCCESS;
}

ge::graphStatus LIQInfoParser::GetS2Size()
{
    // Get the reference S2 size:
    // 1. For BATCH_CONTINUOUS, take it from key's S axis
    // 2. For PAGE_ATTENTION, S2 = block_table.dim1 * block_size
    if (kLayout_ == DataLayout::PA_BSND) {
        return GetS2SizeForPageAttention();
    }
    return GetS2SizeForBatchContinuous();
}
ge::graphStatus LIQInfoParser::ValidateInputShapesMatch()
{
    /*
    TND:
        query [T,N1,D],
        key [BlockNum,BlockSize,N2,D],
        weight [T,N1],
        block_table [BatchSize, BatchMaxBlockNum],
        act_seq_k [BatchSize]
        act_seq_q [BatchSize],
        out [T,N2,topk]
    ----------------------
    BSND:
        query [BatchSize,S1,N1,D],
        key [BlockNum,BlockSize,N2,D],
        weight [BatchSize,S1,N1],
        block_table [BatchSize, BatchMaxBlockNum],
        act_seq_k [BatchSize]
        act_seq_q [BatchSize] (optional)
        out [BatchSize,S1,N2,topk]
    */
    uint32_t queryWeightsN1Dim = 1;
    uint32_t outN2Dim = 1;

    if (qLayout_ == DataLayout::TND) {
        // ----------------------- check BatchSize -----------------------
        // bSize_ comes from act_seq_q
        OPS_ERR_IF((kLayout_ == DataLayout::PA_BSND) &&
                       ((opParamInfo_.actualSeqLengthsK.tensor->GetShapeSize() != bSize_) ||
                        (opParamInfo_.blockTable.tensor != nullptr &&
                         opParamInfo_.blockTable.tensor->GetStorageShape().GetDim(0) != bSize_)),
                   OPS_LOG_E(
                       opName_,
                       "TND case input actual_seq_lengths_query, actual_seq_lengths_key, block_table dim 0 are %u, %u, %u respectively, they must be the same.",
                       bSize_, opParamInfo_.actualSeqLengthsK.tensor->GetShapeSize(),
                       opParamInfo_.blockTable.tensor->GetStorageShape().GetDim(0)),
                   return ge::GRAPH_FAILED);
        OPS_ERR_IF((kLayout_ != DataLayout::PA_BSND) &&
                       (opParamInfo_.actualSeqLengthsK.tensor->GetShapeSize() != bSize_),
                   OPS_LOG_E(
                       opName_,
                       "TND case input actual_seq_lengths_query, actual_seq_lengths_key are %u, %u respectively, they must be the same.",
                       bSize_, opParamInfo_.actualSeqLengthsK.tensor->GetShapeSize()),
                   return ge::GRAPH_FAILED);
        // ----------------------- check T -----------------------
        uint32_t qTsize = opParamInfo_.query.shape->GetStorageShape().GetDim(0);
        OPS_ERR_IF((opParamInfo_.weights.shape->GetStorageShape().GetDim(0) != qTsize) ||
                       (opParamInfo_.attenOut.shape->GetStorageShape().GetDim(0) != qTsize),
                   OPS_LOG_E(opName_,
                             "TND case input query, weights, sparse_indices dim 0 are %u, %u, %u respectively, they must be the same.",
                             qTsize, opParamInfo_.weights.shape->GetStorageShape().GetDim(0),
                             opParamInfo_.attenOut.shape->GetStorageShape().GetDim(0)),
                   return ge::GRAPH_FAILED);
    } else {
        // ----------------------- check BatchSize -----------------------
        // bSize_ comes from query
        OPS_ERR_IF((kLayout_ == DataLayout::PA_BSND) &&
                       ((opParamInfo_.weights.shape->GetStorageShape().GetDim(0) != bSize_) ||
                        (opParamInfo_.blockTable.tensor != nullptr &&
                         opParamInfo_.blockTable.tensor->GetStorageShape().GetDim(0) != bSize_) ||
                        (opParamInfo_.actualSeqLengthsK.tensor->GetShapeSize() != bSize_) ||
                        (opParamInfo_.attenOut.shape->GetStorageShape().GetDim(0) != bSize_)),
                   OPS_LOG_E(opName_,
                             "BSND case input query, weight, actual_seq_lengths_key, block_table, sparse_indices dim 0 are %u, %u, %u, %u, %u respectively, they must be the same.",
                             bSize_, opParamInfo_.weights.shape->GetStorageShape().GetDim(0),
                             opParamInfo_.actualSeqLengthsK.tensor->GetShapeSize(),
                             opParamInfo_.blockTable.tensor->GetStorageShape().GetDim(0),
                             opParamInfo_.attenOut.shape->GetStorageShape().GetDim(0)),
                   return ge::GRAPH_FAILED);
        OPS_ERR_IF((kLayout_ != DataLayout::PA_BSND) &&
                       ((opParamInfo_.weights.shape->GetStorageShape().GetDim(0) != bSize_) ||
                        (opParamInfo_.actualSeqLengthsK.tensor != nullptr &&
                         opParamInfo_.actualSeqLengthsK.tensor->GetShapeSize() != bSize_) ||
                        (opParamInfo_.attenOut.shape->GetStorageShape().GetDim(0) != bSize_)),
                   OPS_LOG_E(opName_,
                             "BSND case input query, weight, actual_seq_lengths_key, sparse_indices dim 0 are %u, %u, %u, %u respectively, they must be the same.",
                             bSize_, opParamInfo_.weights.shape->GetStorageShape().GetDim(0),
                             opParamInfo_.actualSeqLengthsK.tensor->GetShapeSize(),
                             opParamInfo_.attenOut.shape->GetStorageShape().GetDim(0)),
                   return ge::GRAPH_FAILED);
        OPS_ERR_IF(
            (opParamInfo_.actualSeqLengthsQ.tensor != nullptr) &&
                (opParamInfo_.actualSeqLengthsQ.tensor->GetShapeSize() != bSize_),
            OPS_LOG_E(
                opName_,
                "BSND case input query, actual_seq_lengths_query dim 0 are %u, %u respectively, they must be the same.",
                bSize_, opParamInfo_.actualSeqLengthsQ.tensor->GetShapeSize()),
            return ge::GRAPH_FAILED);
        // ----------------------- check S1 -----------------------
        OPS_ERR_IF(
            (opParamInfo_.weights.shape->GetStorageShape().GetDim(1) != s1Size_) ||
                (opParamInfo_.attenOut.shape->GetStorageShape().GetDim(1) != s1Size_),
            OPS_LOG_E(opName_, "BSND case input query, weight, sparse_indices dim 1 are %u, %u, %u, they must be the same.",
                      s1Size_, opParamInfo_.weights.shape->GetStorageShape().GetDim(1),
                      opParamInfo_.attenOut.shape->GetStorageShape().GetDim(1)),
            return ge::GRAPH_FAILED);
        queryWeightsN1Dim = DIM_IDX_TWO;
        outN2Dim = DIM_IDX_TWO;
    }
    // ----------------------- check N1 -----------------------
    OPS_ERR_IF((opParamInfo_.weights.shape->GetStorageShape().GetDim(queryWeightsN1Dim) != n1Size_),
               OPS_LOG_E(opName_, "input query, weight shape dim N1 must be the same, but now are %u, %u respectively.",
                         opParamInfo_.weights.shape->GetStorageShape().GetDim(queryWeightsN1Dim), n1Size_),
               return ge::GRAPH_FAILED);
    // ----------------------- check D -----------------------
    OPS_ERR_IF(
        ((kLayout_ != DataLayout::TND && opParamInfo_.key.shape->GetStorageShape().GetDim(DIM_IDX_THREE) != headDim_)
         || (kLayout_ == DataLayout::TND && opParamInfo_.key.shape->GetStorageShape().GetDim(DIM_IDX_TWO) != headDim_)),
        OPS_LOG_E(opName_, "input query, key shape last dim must be the same, now are %u, %u respectively.",
                  headDim_, opParamInfo_.key.shape->GetStorageShape().GetDim(DIM_IDX_THREE)),
        return ge::GRAPH_FAILED);
    // ----------------------- check N2 -----------------------
    OPS_ERR_IF((opParamInfo_.attenOut.shape->GetStorageShape().GetDim(outN2Dim) != n2Size_),
               OPS_LOG_E(opName_, "input query and output sparse_indices shape n2 dim must be the same."),
               return ge::GRAPH_FAILED);
    // ----------------------- check sparse_count -----------------------
    OPS_ERR_IF((opParamInfo_.attenOut.shape->GetStorageShape().GetDim(outN2Dim + 1) != *opParamInfo_.sparseCount),
               OPS_LOG_E(opName_, "output sparse_indices shape last dim must be the same as attr sparse_count."),
               return ge::GRAPH_FAILED);

    return ge::GRAPH_SUCCESS;
}
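
The shape contract in the comment block above can be restated as a small
host-side checker. This is an illustrative Python sketch of the TND case
only (hypothetical helper; the authoritative checks are the OPS_ERR_IF
guards in ValidateInputShapesMatch):

```python
def check_liq_shapes_tnd(query, key, weights, out, act_seq_q, act_seq_k, sparse_count):
    # query [T, N1, D], key [BlockNum, BlockSize, N2, D], weights [T, N1],
    # act_seq_q / act_seq_k [BatchSize], out [T, N2, topk] -- per the table above.
    t, n1, d = query
    block_num, block_size, n2, dk = key
    assert dk == d == 128                    # HEAD_DIM_LIMIT
    assert n1 % n2 == 0                      # query head_num must be a multiple of key head_num
    assert weights == (t, n1)
    assert len(act_seq_q) == len(act_seq_k)  # one entry per batch element
    assert out == (t, n2, sparse_count)

check_liq_shapes_tnd((1024, 64, 128), (64, 128, 1, 128), (1024, 64),
                     (1024, 1, 2048), [7, 5, 9, 3], [70, 50, 90, 30], 2048)
```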
ge::graphStatus LIQInfoParser::CheckScaleShape()
{
    uint32_t qShapeDim = opParamInfo_.query.shape->GetStorageShape().GetDimNum();
    uint32_t kShapeDim = opParamInfo_.key.shape->GetStorageShape().GetDimNum();
    uint32_t qDequantScaleShapeDim = opParamInfo_.query_dequant_scale.shape->GetStorageShape().GetDimNum();
    uint32_t kDequantScaleShapeDim = opParamInfo_.key_dequant_scale.shape->GetStorageShape().GetDimNum();
    OPS_ERR_IF(qDequantScaleShapeDim != (qShapeDim - 1),
               OPS_LOG_E(opName_, "the dim num of query_dequant_scale's shape should be %u, but now is %u",
                         qShapeDim - 1, qDequantScaleShapeDim),
               return ge::GRAPH_FAILED);
    OPS_ERR_IF(kDequantScaleShapeDim != (kShapeDim - 1),
               OPS_LOG_E(opName_, "the dim num of key_dequant_scale's shape should be %u, but now is %u", kShapeDim - 1,
                         kDequantScaleShapeDim),
               return ge::GRAPH_FAILED);
    // check q scale
    for (uint32_t i = 0; i < (qShapeDim - 1); i++) {
        uint32_t dimValueQueryScale = opParamInfo_.query_dequant_scale.shape->GetStorageShape().GetDim(i);
        uint32_t dimValueQuery = opParamInfo_.query.shape->GetStorageShape().GetDim(i);
        OPS_ERR_IF(dimValueQueryScale != dimValueQuery,
                   OPS_LOG_E(opName_, "query_dequant_scale's shape[%u] %u and query's shape[%u] %u are not the same", i,
                             dimValueQueryScale, i, dimValueQuery),
                   return ge::GRAPH_FAILED);
    }
    // check k scale
    for (uint32_t i = 0; i < (kShapeDim - 1); i++) {
        uint32_t dimValueKeyScale = opParamInfo_.key_dequant_scale.shape->GetStorageShape().GetDim(i);
        uint32_t dimValueKey = opParamInfo_.key.shape->GetStorageShape().GetDim(i);
        OPS_ERR_IF(dimValueKeyScale != dimValueKey,
                   OPS_LOG_E(opName_, "key_dequant_scale's shape[%u] %u and key's shape[%u] %u are not the same", i,
                             dimValueKeyScale, i, dimValueKey),
                   return ge::GRAPH_FAILED);
    }

    return ge::GRAPH_SUCCESS;
}

void LIQInfoParser::GenerateInfo(LIQTilingInfo &liqInfo)
{
    liqInfo.opName = opName_;
    liqInfo.platformInfo = platformInfo_;
    liqInfo.opParamInfo = opParamInfo_;
    liqInfo.socVersion = socVersion_;

    liqInfo.bSize = bSize_;
    liqInfo.n1Size = n1Size_;
    liqInfo.n2Size = n2Size_;
    liqInfo.s1Size = s1Size_;
    liqInfo.s2Size = s2Size_;
    liqInfo.gSize = gSize_;

    liqInfo.inputQType = inputQType_;
    liqInfo.inputKType = inputKType_;
    liqInfo.outputType = outputType_;

    liqInfo.blockSize = blockSize_;
    liqInfo.maxBlockNumPerBatch = maxBlockNumPerBatch_;

    liqInfo.pageAttentionFlag = (kLayout_ == DataLayout::PA_BSND);
    liqInfo.sparseMode = *opParamInfo_.sparseMode;
    liqInfo.sparseCount = *opParamInfo_.sparseCount;

    liqInfo.inputQLayout = qLayout_;
    liqInfo.inputKLayout = kLayout_;
}
ge::graphStatus LIQInfoParser::ParseAndCheck(LIQTilingInfo &liqInfo)
{
    if (ge::GRAPH_SUCCESS != GetOpName() || ge::GRAPH_SUCCESS != GetNpuInfo() || ge::GRAPH_SUCCESS != GetOpParaInfo() ||
        ge::GRAPH_SUCCESS != CheckRequiredParaExistence()) {
        return ge::GRAPH_FAILED;
    }

    if (ge::GRAPH_SUCCESS != GetAndCheckInOutDataType() || ge::GRAPH_SUCCESS != GetQueryKeyAndOutLayout() ||
        ge::GRAPH_SUCCESS != GetAndCheckOptionalInput()) {
        return ge::GRAPH_FAILED;
    }

    if (ge::GRAPH_SUCCESS != CheckShapeDim() || ge::GRAPH_SUCCESS != GetN1Size() ||
        ge::GRAPH_SUCCESS != GetAndCheckN2Size() || ge::GRAPH_SUCCESS != GetGSize()) {
        return ge::GRAPH_FAILED;
    }

    if (ge::GRAPH_SUCCESS != GetBatchSize() || ge::GRAPH_SUCCESS != GetS1Size() || ge::GRAPH_SUCCESS != GetHeadDim() ||
        ge::GRAPH_SUCCESS != GetS2Size()) {
        return ge::GRAPH_FAILED;
    }
    if (ge::GRAPH_SUCCESS != ValidateInputShapesMatch() || ge::GRAPH_SUCCESS != CheckScaleShape()) {
        return ge::GRAPH_FAILED;
    }

    GenerateInfo(liqInfo);

    return ge::GRAPH_SUCCESS;
}

// -------------------------- TilingPrepare definition --------------------------
static ge::graphStatus TilingPrepareForLightningIndexerQuant(gert::TilingParseContext * /* context */)
{
    return ge::GRAPH_SUCCESS;
}
// --------------------------LightningIndexerQuantTiling类成员函数定义-----------------------
|
||||
ge::graphStatus LightningIndexerQuantTiling::DoTiling(LIQTilingInfo *tilingInfo)
|
||||
{
|
||||
// -------------set blockdim-----------------
|
||||
auto ascendcPlatform = platform_ascendc::PlatformAscendC(tilingInfo->platformInfo);
|
||||
uint32_t aivNum = ascendcPlatform.GetCoreNumAiv();
|
||||
uint32_t aicNum = ascendcPlatform.GetCoreNumAic();
|
||||
uint32_t blockDim = ascendcPlatform.CalcTschBlockDim(aivNum, aicNum, aivNum);
|
||||
context_->SetBlockDim(blockDim);
|
||||
|
||||
// -------------set workspacesize-----------------
|
||||
constexpr uint32_t MM1_RES_ELEM_SIZE = 4; // 4: fp32
|
||||
constexpr uint32_t DOUBLE_BUFFER = 2; // 双Buffer
|
||||
constexpr uint32_t M_BASE_SIZE = 512; // m轴基本块大小
|
||||
constexpr uint32_t S2_BASE_SIZE = 512; // S2轴基本块大小
|
||||
constexpr uint32_t V1_RES_ELEM_SIZE = 4; // 4: int32
|
||||
constexpr uint32_t V1_RES_ELEM_TYPE = 2; // 保留Index和Value 2种数据
|
||||
constexpr uint32_t V1_DECODE_PARAM_ELEM_SIZE = 8; // 8: int64
|
||||
constexpr uint32_t V1_DECODE_PARAM_NUM = 16; // Decode参数个数
|
||||
constexpr uint32_t V1_DECODE_DATA_NUM = 2; // Decode每个核需要存储头和尾部两块数据
|
||||
constexpr uint32_t S1_BASE_SIZE = 8; // S1轴基本块的大小
|
||||
constexpr uint32_t TOPK_MAX_SIZE = 2048; // TopK选取个数
|
||||
uint32_t workspaceSize = ascendcPlatform.GetLibApiWorkSpaceSize();
|
||||
// 主流程需Workspace大小
|
||||
uint32_t mm1ResSize = M_BASE_SIZE * S2_BASE_SIZE;
|
||||
workspaceSize += mm1ResSize * MM1_RES_ELEM_SIZE * DOUBLE_BUFFER * aicNum;
|
||||
// Decode流程(LD)需要Workspace大小
|
||||
// 临时存储Decode中间结果大小: 2(头/尾)*8(s1Base)*2(idx/value)*2048(K)*sizeof(int32)*24=6M
|
||||
workspaceSize += V1_DECODE_DATA_NUM * S1_BASE_SIZE * V1_RES_ELEM_TYPE * TOPK_MAX_SIZE * V1_RES_ELEM_SIZE * aicNum;
|
||||
// 临时存储Decode中间参数信息大小: 2(头/尾)*8(s1Base)*16(paramNum)*sizeof(int64_t)*24=48k
|
||||
workspaceSize += V1_DECODE_DATA_NUM * S1_BASE_SIZE * V1_DECODE_PARAM_NUM * V1_DECODE_PARAM_ELEM_SIZE * aicNum;
|
||||
size_t *workSpaces = context_->GetWorkspaceSizes(1);
|
||||
workSpaces[0] = workspaceSize;
|
||||
|
||||
// -------------set tilingdata-----------------
|
||||
tilingData_.set_bSize(tilingInfo->bSize);
|
||||
tilingData_.set_s2Size(tilingInfo->s2Size);
|
||||
tilingData_.set_s1Size(tilingInfo->s1Size);
|
||||
tilingData_.set_sparseCount(tilingInfo->sparseCount);
|
||||
tilingData_.set_gSize(tilingInfo->gSize);
|
||||
tilingData_.set_blockSize(tilingInfo->blockSize);
|
||||
tilingData_.set_maxBlockNumPerBatch(tilingInfo->maxBlockNumPerBatch);
|
||||
tilingData_.set_sparseMode(tilingInfo->sparseMode);
|
||||
tilingData_.set_usedCoreNum(blockDim);
|
||||
tilingData_.SaveToBuffer(context_->GetRawTilingData()->GetData(), context_->GetRawTilingData()->GetCapacity());
|
||||
context_->GetRawTilingData()->SetDataSize(tilingData_.GetDataSize());
|
||||
|
||||
// -------------set tilingkey-----------------
|
||||
// DT_Q, DT_KV, DT_OUT, PAGE_ATTENTION, FLASH_DECODE, LAYOUT_T, KV_LAYOUT_T
|
||||
uint32_t inputQType = static_cast<uint32_t>(tilingInfo->inputQType);
|
||||
uint32_t inputKType = static_cast<uint32_t>(tilingInfo->inputKType);
|
||||
uint32_t outputType = static_cast<uint32_t>(tilingInfo->outputType);
|
||||
uint32_t pageAttentionFlag = static_cast<uint32_t>(tilingInfo->pageAttentionFlag);
|
||||
uint32_t inputQLayout = static_cast<uint32_t>(tilingInfo->inputQLayout);
|
||||
uint32_t inputKLayout = static_cast<uint32_t>(tilingInfo->inputKLayout);
|
||||
uint32_t tilingKey =
|
||||
GET_TPL_TILING_KEY(inputQType, inputKType, outputType, pageAttentionFlag, inputQLayout, inputKLayout);
|
||||
context_->SetTilingKey(tilingKey);
|
||||
|
||||
return ge::GRAPH_SUCCESS;
|
||||
}

// --------------------------Tiling function definition---------------------------
ge::graphStatus TilingForLightningIndexerQuant(gert::TilingContext *context)
{
    OPS_ERR_IF(context == nullptr, OPS_REPORT_VECTOR_INNER_ERR("LightningIndexerQuant", "Tiling context is null."),
               return ge::GRAPH_FAILED);
    LIQTilingInfo liqInfo;
    LIQInfoParser liqInfoParser(context);
    if (liqInfoParser.ParseAndCheck(liqInfo) != ge::GRAPH_SUCCESS) {
        return ge::GRAPH_FAILED;
    }
    LightningIndexerQuantTiling liqTiling(context);
    return liqTiling.DoTiling(&liqInfo);
}

// --------------------------Registration of the Tiling and TilingPrepare functions--------
IMPL_OP_OPTILING(LightningIndexerQuant)
    .Tiling(TilingForLightningIndexerQuant)
    .TilingParse<LIQCompileInfo>(TilingPrepareForLightningIndexerQuant);

} // namespace optiling
@@ -0,0 +1,234 @@
/**
 * This program is free software, you can redistribute it and/or modify it.
 * Copyright (c) 2025 Huawei Technologies Co., Ltd.
 * This file is a part of the CANN Open Software.
 * Licensed under CANN Open Software License Agreement Version 2.0 (the "License").
 * Please refer to the License for details. You may not use this file except in compliance with the License.
 * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE.
 * See LICENSE in the root of the software repository for the full text of the License.
 */

/*!
 * \file lightning_indexer_quant_tiling.h
 * \brief
 */

#ifndef LIGHTNING_INDEXER_QUANT_TILING_H_
#define LIGHTNING_INDEXER_QUANT_TILING_H_

#include "error/ops_error.h"
#include "exe_graph/runtime/tiling_context.h"
#include "platform/platform_info.h"
#include "register/op_def_registry.h"
#include "register/tilingdata_base.h"
#include "tiling/platform/platform_ascendc.h"
#include "tiling/tiling_api.h"

namespace optiling {
// ------------------Common definitions--------------------------
struct TilingRequiredParaInfo {
    const gert::CompileTimeTensorDesc *desc;
    const gert::StorageShape *shape;
};

struct TilingOptionalParaInfo {
    const gert::CompileTimeTensorDesc *desc;
    const gert::Tensor *tensor;
};

enum class DataLayout : uint32_t {
    BSND = 0,
    TND = 1,
    PA_BSND = 2
};

// ------------------Operator prototype index constants----------------
// Inputs Index
constexpr uint32_t QUERY_INDEX = 0;
constexpr uint32_t KEY_INDEX = 1;
constexpr uint32_t WEIGTHS_INDEX = 2;
constexpr uint32_t QUERY_DEQUANT_SCALE_INDEX = 3;
constexpr uint32_t KEY_DEQUANT_SCALE_INDEX = 4;
constexpr uint32_t ACTUAL_SEQ_Q_INDEX = 5;
constexpr uint32_t ACTUAL_SEQ_K_INDEX = 6;
constexpr uint32_t BLOCK_TABLE_INDEX = 7;
// Outputs Index
constexpr uint32_t LIGHTNING_INDEXER_QUANT = 0;
// Attributes Index
constexpr uint32_t ATTR_QUERY_QUANT_MODE_INDEX = 0;
constexpr uint32_t ATTR_KEY_QUANT_MODE_INDEX = 1;
constexpr uint32_t ATTR_QUERY_LAYOUT_INDEX = 2;
constexpr uint32_t ATTR_KEY_LAYOUT_INDEX = 3;
constexpr uint32_t ATTR_SPARSE_COUNT_INDEX = 4;
constexpr uint32_t ATTR_SPARSE_MODE_INDEX = 5;
// Dim Index
constexpr uint32_t DIM_IDX_ZERO = 0;
constexpr uint32_t DIM_IDX_ONE = 1;
constexpr uint32_t DIM_IDX_TWO = 2;
constexpr uint32_t DIM_IDX_THREE = 3;
// Dim Num
constexpr uint32_t DIM_NUM_TWO = 2;
constexpr uint32_t DIM_NUM_THREE = 3;
constexpr uint32_t DIM_NUM_FOUR = 4;
// Input constraint constants
constexpr uint32_t HEAD_DIM_LIMIT = 128;
constexpr uint32_t SPARSE_LIMIT = 2048;
constexpr uint32_t G_SIZE_LIMIT = 64;
constexpr uint32_t BLOCK_SIZE_LIMIT = 1024;
constexpr uint32_t BLOCK_SIZE_FACTOR = 16;
constexpr uint32_t SPARSE_MODE_LOWER = 3;

// -----------Operator TilingData definition---------------
BEGIN_TILING_DATA_DEF(LIQTilingData)
TILING_DATA_FIELD_DEF(uint32_t, bSize)
TILING_DATA_FIELD_DEF(uint32_t, n2Size)
TILING_DATA_FIELD_DEF(uint32_t, gSize)
TILING_DATA_FIELD_DEF(uint32_t, s1Size)
TILING_DATA_FIELD_DEF(uint32_t, s2Size)
TILING_DATA_FIELD_DEF(uint32_t, sparseCount)
TILING_DATA_FIELD_DEF(uint32_t, usedCoreNum)
TILING_DATA_FIELD_DEF(uint32_t, blockSize)
TILING_DATA_FIELD_DEF(uint32_t, maxBlockNumPerBatch)
TILING_DATA_FIELD_DEF(uint32_t, sparseMode)
END_TILING_DATA_DEF
REGISTER_TILING_DATA_CLASS(LightningIndexerQuant, LIQTilingData)
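
// Note: the field order and widths defined above must match the kernel-side read done via
// GET_TILING_DATA_WITH_STRUCT(LIQTilingData, ...) in lightning_indexer_quant.cpp.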

// -----------Operator CompileInfo definition-------------------
struct LIQCompileInfo {};

// -----------Operator tiling input parameter struct---------------
struct LIQParaInfo {
    TilingRequiredParaInfo query = {nullptr, nullptr};
    TilingRequiredParaInfo key = {nullptr, nullptr};
    TilingRequiredParaInfo weights = {nullptr, nullptr};
    TilingRequiredParaInfo query_dequant_scale = {nullptr, nullptr};
    TilingRequiredParaInfo key_dequant_scale = {nullptr, nullptr};
    TilingOptionalParaInfo actualSeqLengthsQ = {nullptr, nullptr};
    TilingOptionalParaInfo actualSeqLengthsK = {nullptr, nullptr};
    TilingOptionalParaInfo blockTable = {nullptr, nullptr};
    TilingRequiredParaInfo attenOut = {nullptr, nullptr};

    const int32_t *queryQuantMode = nullptr;
    const int32_t *keyQuantMode = nullptr;
    const char *layOutQuery = nullptr;
    const char *layOutKey = nullptr;
    const int32_t *blockSize = nullptr;
    const int32_t *sparseMode = nullptr;
    const int32_t *sparseCount = nullptr;
};

// -----------Operator tiling input info class---------------
class LIQTilingInfo {
public:
    const char *opName = nullptr;
    fe::PlatFormInfos *platformInfo = nullptr;
    LIQParaInfo opParamInfo;
    // Base Param
    platform_ascendc::SocVersion socVersion = platform_ascendc::SocVersion::ASCEND910B;
    uint32_t bSize = 0;
    uint32_t n1Size = 0;
    uint32_t n2Size = 0;
    uint32_t s1Size = 0;
    int64_t s2Size = 0;
    uint32_t qkHeadDim = 0;
    uint32_t gSize = 0;
    // PageAttention
    bool pageAttentionFlag = false;
    int32_t blockSize = 0;
    uint32_t maxBlockNumPerBatch = 0;
    // Mask
    int32_t sparseMode = 0;
    // Others Flag
    uint32_t sparseCount = 0;
    // DType
    ge::DataType inputQType = ge::DT_FLOAT16;
    ge::DataType inputKType = ge::DT_FLOAT16;
    ge::DataType outputType = ge::DT_INT32;
    // Layout
    DataLayout inputQLayout = DataLayout::BSND;
    DataLayout inputKLayout = DataLayout::PA_BSND;
};

// -----------Operator tiling input parsing and checking class---------------
class LIQInfoParser {
public:
    explicit LIQInfoParser(gert::TilingContext *context) : context_(context) {}
    ~LIQInfoParser() = default;

    ge::graphStatus CheckRequiredInOutExistence() const;
    ge::graphStatus CheckRequiredAttrExistence() const;
    ge::graphStatus CheckRequiredParaExistence() const;
    ge::graphStatus GetActualSeqLenSize(uint32_t &size, const gert::Tensor *tensor,
                                        const std::string &actualSeqLenName);
    ge::graphStatus GetOpName();
    ge::graphStatus GetNpuInfo();
    void GetOptionalInputParaInfo();
    void GetInputParaInfo();
    void GetOutputParaInfo();
    ge::graphStatus GetAttrParaInfo();
    ge::graphStatus CheckAttrParaInfo();
    ge::graphStatus GetOpParaInfo();
    ge::graphStatus ValidateInputShapesMatch();
    ge::graphStatus CheckScaleShape();
    ge::graphStatus GetAndCheckInOutDataType();
    ge::graphStatus GetBatchSize();
    ge::graphStatus GetHeadDim();
    ge::graphStatus GetS1Size();
    ge::graphStatus GetAndCheckOptionalInput();
    ge::graphStatus CheckShapeDim();
    ge::graphStatus GetAndCheckBlockSize();
    ge::graphStatus GetS2SizeForPageAttention();
    ge::graphStatus GetS2SizeForBatchContinuous();
    ge::graphStatus GetS2Size();
    ge::graphStatus GetQueryKeyAndOutLayout();
    ge::graphStatus GetN1Size();
    ge::graphStatus GetAndCheckN2Size();
    ge::graphStatus GetGSize();
    ge::graphStatus GetAttenMaskInfo();
    ge::graphStatus GetActualSeqInfo();
    void GenerateInfo(LIQTilingInfo &liqInfo);
    ge::graphStatus ParseAndCheck(LIQTilingInfo &liqInfo);

public:
    gert::TilingContext *context_ = nullptr;
    const char *opName_ = nullptr;
    fe::PlatFormInfos *platformInfo_ = nullptr;
    LIQParaInfo opParamInfo_;

    // BaseParams
    uint32_t bSize_ = 0;
    uint32_t n1Size_ = 0;
    uint32_t n2Size_ = 0;
    uint32_t gSize_ = 0;
    uint32_t s1Size_ = 0;
    int64_t s2Size_ = 0;
    uint32_t headDim_ = 0;
    // Layout
    DataLayout qLayout_ = DataLayout::BSND;
    DataLayout kLayout_ = DataLayout::PA_BSND;
    // PageAttention
    uint32_t maxBlockNumPerBatch_ = 0;
    int32_t blockSize_ = 0;
    platform_ascendc::SocVersion socVersion_ = platform_ascendc::SocVersion::ASCEND910B;
    ge::DataType inputQType_ = ge::DT_FLOAT16;
    ge::DataType inputKType_ = ge::DT_FLOAT16;
    ge::DataType weightsType_ = ge::DT_FLOAT16;
    ge::DataType inputQueryScaleType_ = ge::DT_FLOAT16;
    ge::DataType inputKeyScaleType_ = ge::DT_FLOAT16;
    ge::DataType blockTableType_ = ge::DT_FLOAT16;
    ge::DataType inputKRopeType_ = ge::DT_FLOAT16;
    ge::DataType outputType_ = ge::DT_FLOAT16;
};

// ---------------Operator tiling class---------------
class LightningIndexerQuantTiling {
public:
    explicit LightningIndexerQuantTiling(gert::TilingContext *context) : context_(context) {}
    ge::graphStatus DoTiling(LIQTilingInfo *tilingInfo);

private:
    gert::TilingContext *context_ = nullptr;
    LIQTilingData tilingData_;
};

} // namespace optiling
#endif // LIGHTNING_INDEXER_QUANT_TILING_H_
@@ -0,0 +1,50 @@
/**
 * This program is free software, you can redistribute it and/or modify it.
 * Copyright (c) 2025 Huawei Technologies Co., Ltd.
 * This file is a part of the CANN Open Software.
 * Licensed under CANN Open Software License Agreement Version 2.0 (the "License").
 * Please refer to the License for details. You may not use this file except in compliance with the License.
 * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE.
 * See LICENSE in the root of the software repository for the full text of the License.
 */

/*!
 * \file lightning_indexer_quant.cpp
 * \brief
 */

#include "kernel_operator.h"
#include "lib/matmul_intf.h"
#include "lightning_indexer_quant_kernel.h"
#include "lightning_indexer_quant_template_tiling_key.h"

using namespace LIQKernel;

#define INVOKE_LI_NO_KFC_OP_IMPL(templateClass, ...)                                                         \
    do {                                                                                                     \
        templateClass<LIQType<__VA_ARGS__>> op;                                                              \
        GET_TILING_DATA_WITH_STRUCT(LIQTilingData, tiling_data_in, tiling);                                  \
        const LIQTilingData *__restrict tiling_data = &tiling_data_in;                                       \
        op.Init(query, key, weights, queryScale, keyScale, actualSeqLengthsQ, actualSeqLengthsK, blocktable, \
                sparseIndices, user, tiling_data, &tPipe);                                                   \
        op.Process();                                                                                        \
    } while (0)
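
// The macro above instantiates the templated op with an LIQType trait pack, materializes
// the raw tiling blob into an LIQTilingData struct via GET_TILING_DATA_WITH_STRUCT, and
// then runs Init()/Process(); see the instantiation in lightning_indexer_quant() below
// for the concrete arguments.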

template <int DT_Q, int DT_K, int DT_OUT, int PAGE_ATTENTION, int Q_LAYOUT_T, int K_LAYOUT_T>
__global__ __aicore__ void lightning_indexer_quant(__gm__ uint8_t *query, __gm__ uint8_t *key, __gm__ uint8_t *weights,
                                                   __gm__ uint8_t *queryScale, __gm__ uint8_t *keyScale,
                                                   __gm__ uint8_t *actualSeqLengthsQ, __gm__ uint8_t *actualSeqLengthsK,
                                                   __gm__ uint8_t *blocktable, __gm__ uint8_t *sparseIndices,
                                                   __gm__ uint8_t *workspace, __gm__ uint8_t *tiling)
{
#if (__CCE_AICORE__ == 310) || (defined __DAV_310R6__) || (__CCE_AICORE__ == 200)

#else
    TPipe tPipe;
    __gm__ uint8_t *user = GetUserWorkspace(workspace);
    KERNEL_TASK_TYPE_DEFAULT(KERNEL_TYPE_MIX_AIC_1_2);

    INVOKE_LI_NO_KFC_OP_IMPL(LIQPreload, int8_t, int8_t, int32_t,
                             PAGE_ATTENTION, LI_LAYOUT(Q_LAYOUT_T), LI_LAYOUT(K_LAYOUT_T));
#endif
}
@@ -0,0 +1,146 @@
/**
 * This program is free software, you can redistribute it and/or modify it.
 * Copyright (c) 2025 Huawei Technologies Co., Ltd.
 * This file is a part of the CANN Open Software.
 * Licensed under CANN Open Software License Agreement Version 2.0 (the "License").
 * Please refer to the License for details. You may not use this file except in compliance with the License.
 * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE.
 * See LICENSE in the root of the software repository for the full text of the License.
 */

/*!
 * \file lightning_indexer_quant_common.h
 * \brief
 */
#ifndef LIGHTNING_INDEXER_QUANT_COMMON_H
#define LIGHTNING_INDEXER_QUANT_COMMON_H

namespace LIQCommon {

// must stay consistent with the layout enum on the tiling side
enum class LI_LAYOUT : uint32_t {
    BSND = 0,
    TND = 1,
    PA_BSND = 2
};

template <typename Q_T, typename K_T, typename OUT_T, const bool PAGE_ATTENTION = false,
          LI_LAYOUT Q_LAYOUT_T = LI_LAYOUT::BSND, LI_LAYOUT K_LAYOUT_T = LI_LAYOUT::PA_BSND, typename... Args>
struct LIQType {
    using queryType = Q_T;
    using keyType = K_T;
    using outputType = OUT_T;
    static constexpr bool pageAttention = PAGE_ATTENTION;
    static constexpr LI_LAYOUT layout = Q_LAYOUT_T;
    static constexpr LI_LAYOUT keyLayout = K_LAYOUT_T;
};

struct RunInfo {
    uint32_t loop;
    uint32_t bN2Idx;
    uint32_t bIdx;
    uint32_t n2Idx = 0;
    uint32_t gS1Idx;
    uint32_t s2Idx;

    uint32_t actS1Size = 1;
    uint32_t actS2Size = 1;
    uint32_t actMBaseSize;
    uint32_t actualSingleProcessSInnerSize;
    uint32_t actualSingleProcessSInnerSizeAlign;

    uint64_t tensorQueryOffset;
    uint64_t tensorKeyOffset;
    uint64_t tensorKeyScaleOffset;
    uint64_t tensorWeightsOffset;
    uint64_t indiceOutOffset;

    bool isFirstS2InnerLoop;
    bool isLastS2InnerLoop;
    bool isAllLoopEnd = false;
    bool isValid = false;
};

struct ConstInfo {
    // cross-core sync mode between CUBE and VEC
    static constexpr uint32_t FIA_SYNC_MODE2 = 2;
    // buffer sizes in bytes
    static constexpr uint32_t BUFFER_SIZE_BYTE_32B = 32;
    static constexpr uint32_t BUFFER_SIZE_BYTE_64B = 64;
    static constexpr uint32_t BUFFER_SIZE_BYTE_256B = 256;
    static constexpr uint32_t BUFFER_SIZE_BYTE_512B = 512;
    static constexpr uint32_t BUFFER_SIZE_BYTE_1K = 1024;
    static constexpr uint32_t BUFFER_SIZE_BYTE_2K = 2048;
    static constexpr uint32_t BUFFER_SIZE_BYTE_4K = 4096;
    static constexpr uint32_t BUFFER_SIZE_BYTE_8K = 8192;
    static constexpr uint32_t BUFFER_SIZE_BYTE_16K = 16384;
    static constexpr uint32_t BUFFER_SIZE_BYTE_32K = 32768;
    // invalid index
    static constexpr int INVALID_IDX = -1;

    // cross-core sync event IDs between CUBE and VEC
    uint32_t syncC1V1 = 0U;
    uint32_t syncC1V0 = 2U;
    uint32_t syncV1C1 = 0U;
    uint32_t syncV0C1 = 1U;

    // base block sizes
    uint32_t mBaseSize = 1U;
    uint32_t s1BaseSize = 1U;
    uint32_t s2BaseSize = 1U;

    uint64_t batchSize = 0ULL;
    uint64_t gSize = 0ULL;
    uint64_t qHeadNum = 0ULL;
    uint64_t kHeadNum = 0ULL;
    uint64_t headDim = 0ULL;
    uint64_t sparseCount = 0ULL;      // top-k selection size
    uint64_t kSeqSize = 0ULL;         // maximum S length of KV
    uint64_t qSeqSize = 1ULL;         // maximum S length of Q
    uint32_t kCacheBlockSize = 0;     // block size in the page-attention (PA) scenario
    uint32_t maxBlockNumPerBatch = 0; // maximum number of blocks per batch in the PA scenario
    LI_LAYOUT outputLayout = LI_LAYOUT::BSND; // output layout
    bool attenMaskFlag = false;

    uint32_t actualLenQDims = 0U; // number of dims of actualSeqLength for query
    uint32_t actualLenDims = 0U;  // number of dims of actualSeqLength for KV
    bool isAccumSeqS1 = false;    // whether S1 lengths are in accumulated (prefix-sum) mode
    bool isAccumSeqS2 = false;    // whether S2 lengths are in accumulated (prefix-sum) mode
};

struct SplitCoreInfo {
    uint32_t s2Start = 0U; // start position along S2
    uint32_t s2End = 0U;   // inclusive upper bound of the S2 loop index
    uint32_t bN2Start = 0U;
    uint32_t bN2End = 0U;
    uint32_t gS1Start = 0U;
    uint32_t gS1End = 0U;
    bool isLD = false; // whether this core performs the decode (LD) reduction task
};

template <typename T>
__aicore__ inline T Align(T num, T rnd)
{
    return (((rnd) == 0) ? 0 : (((num) + (rnd)-1) / (rnd) * (rnd)));
}
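
// For illustration: Align(100u, 32u) == 128u, i.e. num rounded up to the next multiple
// of rnd; rnd == 0 is defined to yield 0 so the helper never divides by zero.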

template <typename T1, typename T2>
__aicore__ inline T1 Min(T1 a, T2 b)
{
    return (a > b) ? (b) : (a);
}

template <typename T1, typename T2>
__aicore__ inline T1 Max(T1 a, T2 b)
{
    return (a > b) ? (a) : (b);
}

template <typename T>
__aicore__ inline T CeilDiv(T num, T rnd)
{
    return (((rnd) == 0) ? 0 : (((num) + (rnd)-1) / (rnd)));
}
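
// For illustration: CeilDiv(5000u, 2048u) == 3u; the kernel uses this to count base
// blocks per axis, e.g. the number of S2 blocks of size s2BaseSize.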

} // namespace LIQCommon

#endif // LIGHTNING_INDEXER_QUANT_COMMON_H
@@ -0,0 +1,714 @@
/**
 * This program is free software, you can redistribute it and/or modify it.
 * Copyright (c) 2025 Huawei Technologies Co., Ltd.
 * This file is a part of the CANN Open Software.
 * Licensed under CANN Open Software License Agreement Version 2.0 (the "License").
 * Please refer to the License for details. You may not use this file except in compliance with the License.
 * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE.
 * See LICENSE in the root of the software repository for the full text of the License.
 */

/*!
 * \file lightning_indexer_quant_kernel.h
 * \brief
 */

#ifndef LIGHTNING_INDEXER_QUANT_KERNEL_H
#define LIGHTNING_INDEXER_QUANT_KERNEL_H

#include "kernel_operator.h"
#include "kernel_operator_list_tensor_intf.h"
#include "kernel_tiling/kernel_tiling.h"
#include "lib/matmul_intf.h"
#include "lib/matrix/matmul/tiling.h"
#include "lightning_indexer_quant_common.h"
#include "lightning_indexer_quant_service_vector.h"
#include "lightning_indexer_quant_service_cube.h"

namespace LIQKernel {
using namespace LIQCommon;
using namespace LIQServiceVec;
using namespace matmul;
using AscendC::CacheMode;
using AscendC::CrossCoreSetFlag;
using AscendC::CrossCoreWaitFlag;

// Before the S2 loop runs, RunInfo has not been populated yet, so TempLoopInfo temporarily
// holds the B/N/S1-axis information; it also avoids recomputing these values.
struct TempLoopInfo {
    uint32_t bN2Idx = 0;
    uint32_t bIdx = 0U;
    uint32_t n2Idx = 0U;
    uint32_t gS1Idx = 0U;
    uint32_t gS1LoopEnd = 0U; // end index of the loop along gS1
    uint32_t s2LoopEnd = 0U;  // end index of the loop along S2
    uint32_t actS1Size = 1U;  // actual S1 size handled in the current batch iteration
    uint32_t actS2Size = 0U;
    bool curActSeqLenIsZero = false;
    bool needDealActS1LessThanS1 = false; // whether the output must be cleaned when the actual S1 length is smaller than the shape S1 length
    uint32_t actMBaseSize = 0U;    // actual size along the m (gS1) axis
    uint32_t mBasicSizeTail = 0U;  // tail base-block size of the gS1 loop
    uint32_t s2BasicSizeTail = 0U; // tail base-block size of the S2 loop
};

template <typename LIQT>
class LIQPreload {
public:
    __aicore__ inline LIQPreload() {}
    __aicore__ inline void Init(__gm__ uint8_t *query, __gm__ uint8_t *key, __gm__ uint8_t *weights,
                                __gm__ uint8_t *queryScale, __gm__ uint8_t *keyScale, __gm__ uint8_t *actualSeqLengthsQ,
                                __gm__ uint8_t *actualSeqLengthsK, __gm__ uint8_t *blockTable,
                                __gm__ uint8_t *sparseIndices, __gm__ uint8_t *workspace,
                                const LIQTilingData *__restrict tiling, TPipe *tPipe);
    __aicore__ inline void Process();

    // ================================= Type definitions =================================
    using Q_T = typename LIQT::queryType;
    using K_T = typename LIQT::keyType;
    using OUT_T = typename LIQT::outputType;
    static constexpr bool PAGE_ATTENTION = LIQT::pageAttention;
    static constexpr LI_LAYOUT Q_LAYOUT_T = LIQT::layout;
    static constexpr LI_LAYOUT K_LAYOUT_T = LIQT::keyLayout;

    using MM1_OUT_T = float;

    LIQMatmul<LIQT> matmulService;
    LIQVector<LIQT> vectorService;

    // ================================= Constants =================================
    static constexpr uint32_t SYNC_C1_V1_FLAG = 4;
    static constexpr uint32_t SYNC_V1_C1_FLAG = 5;

    static constexpr uint32_t M_BASE_SIZE = 256;
    static constexpr uint32_t S2_BASE_SIZE = 2048;
    static constexpr uint32_t HEAD_DIM = 128;
    static constexpr uint32_t K_HEAD_NUM = 1;
    static constexpr uint32_t GM_ALIGN_BYTES = 512;
    static constexpr uint32_t LI_QUANT_PRELOAD_TASK_CACHE_SIZE = 2;

    static constexpr int64_t LD_PREFETCH_LEN = 2;
    // for workspace double buffering
    static constexpr uint32_t WS_DOUBLE = 2;

protected:
    TPipe *pipe = nullptr;

    // offset
    uint64_t queryCoreOffset = 0ULL;
    uint64_t keyCoreOffset = 0ULL;
    uint64_t keyScaleCoreOffset = 0ULL;
    uint64_t weightsCoreOffset = 0ULL;
    uint64_t indiceOutCoreOffset = 0ULL;

    // ================================ Global buffers =================================
    GlobalTensor<Q_T> queryGm;
    GlobalTensor<K_T> keyGm;
    GlobalTensor<half> weightsGm;

    GlobalTensor<int32_t> indiceOutGm;
    GlobalTensor<int32_t> blockTableGm;

    GlobalTensor<uint32_t> actualSeqLengthsGmQ;
    GlobalTensor<uint32_t> actualSeqLengthsGm;

    // ================================ Member variables ====================================
    // AIC/AIV core info
    uint32_t tmpBlockIdx = 0U;
    uint32_t aiCoreIdx = 0U;
    uint32_t usedCoreNum = 0U;

    LIQCommon::ConstInfo constInfo{};
    TempLoopInfo tempLoopInfo{};
    LIQCommon::SplitCoreInfo splitCoreInfo{};

    // ================================ Init functions ==================================
    __aicore__ inline void InitTilingData(const LIQTilingData *__restrict tilingData);
    __aicore__ inline void InitBuffers();
    __aicore__ inline void InitActualSeqLen(__gm__ uint8_t *actualSeqLengthsQ, __gm__ uint8_t *actualSeqLengthsK);
    // ================================ Split core ================================
    __aicore__ inline void SplitCore(uint32_t curCoreIdx, uint32_t &coreNum, LIQCommon::SplitCoreInfo &info);
    __aicore__ inline uint32_t GetS2BaseBlockNumOnMask(uint32_t s1gIdx, uint32_t actS1Size, uint32_t actS2Size);
    __aicore__ inline uint32_t GetTotalBaseBlockNum();
    // ================================ Process functions ================================
    __aicore__ inline void ProcessMain();
    __aicore__ inline void ProcessBaseBlock(uint32_t loop, uint64_t s2LoopIdx,
                                            LIQCommon::RunInfo runInfo[LI_QUANT_PRELOAD_TASK_CACHE_SIZE]);
    __aicore__ inline void ProcessDecode();
    __aicore__ inline void ProcessInvalid();
    // ================================ Parameter calculation =====================================
    __aicore__ inline void CalcGS1LoopParams(uint32_t bN2Idx);
    __aicore__ inline void GetBN2Idx(uint32_t bN2Idx);
    __aicore__ inline uint32_t GetActualSeqLen(uint32_t bIdx, uint32_t actualLenDims, bool isAccumSeq,
                                               GlobalTensor<uint32_t> &actualSeqLengthsGm, uint32_t defaultSeqLen);
    __aicore__ inline void GetS1S2ActualSeqLen(uint32_t bIdx, uint32_t &actS1Size, uint32_t &actS2Size);
    __aicore__ inline void CalcS2LoopParams(uint32_t bN2LoopIdx, uint32_t gS1LoopIdx);
    __aicore__ inline void CalcRunInfo(uint32_t loop, uint32_t s2LoopIdx, LIQCommon::RunInfo &runInfo);
    __aicore__ inline void DealActSeqLenIsZero(uint32_t bIdx, uint32_t n2Idx, uint32_t s1Start);
};

template <typename LIQT>
__aicore__ inline void LIQPreload<LIQT>::InitTilingData(const LIQTilingData *__restrict tilingData)
{
    usedCoreNum = tilingData->usedCoreNum;
    constInfo.batchSize = tilingData->bSize;
    constInfo.qHeadNum = constInfo.gSize = tilingData->gSize;
    constInfo.kSeqSize = tilingData->s2Size;
    constInfo.qSeqSize = tilingData->s1Size;
    constInfo.attenMaskFlag = (tilingData->sparseMode == 3); // 3: lower-triangular (causal) sparse mode
    constInfo.kCacheBlockSize = tilingData->blockSize;
    constInfo.maxBlockNumPerBatch = tilingData->maxBlockNumPerBatch;
    constInfo.sparseCount = tilingData->sparseCount;
    constInfo.outputLayout = Q_LAYOUT_T; // the output layout matches the input layout
    if (Q_LAYOUT_T == LI_LAYOUT::TND) {
        constInfo.isAccumSeqS1 = true;
    }
    if (K_LAYOUT_T == LI_LAYOUT::TND) {
        constInfo.isAccumSeqS2 = true;
    }

    constInfo.kHeadNum = K_HEAD_NUM;
    constInfo.headDim = HEAD_DIM;

    constInfo.mBaseSize = M_BASE_SIZE;
    constInfo.s2BaseSize = S2_BASE_SIZE;
    constInfo.s1BaseSize = (constInfo.mBaseSize + constInfo.gSize - 1) / constInfo.gSize;
}
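
// For illustration: with mBaseSize = 256 and gSize = 64 (the G_SIZE_LIMIT from the tiling
// header), s1BaseSize = ceil(256 / 64) = 4, i.e. one m-axis base block covers four S1
// rows, each row expanded by the g (head-group) factor.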

template <typename LIQT>
__aicore__ inline void LIQPreload<LIQT>::InitBuffers()
{
    if ASCEND_IS_AIV {
        vectorService.InitBuffers(pipe);
    } else {
        matmulService.InitBuffers(pipe);
    }
}

template <typename LIQT>
__aicore__ inline void LIQPreload<LIQT>::InitActualSeqLen(__gm__ uint8_t *actualSeqLengthsQ,
                                                          __gm__ uint8_t *actualSeqLengthsK)
{
    if (actualSeqLengthsQ == nullptr) {
        constInfo.actualLenQDims = 0;
    } else {
        constInfo.actualLenQDims = constInfo.batchSize;
        actualSeqLengthsGmQ.SetGlobalBuffer((__gm__ uint32_t *)actualSeqLengthsQ, constInfo.actualLenQDims);
    }
    if (actualSeqLengthsK == nullptr) {
        constInfo.actualLenDims = 0;
    } else {
        constInfo.actualLenDims = constInfo.batchSize;
        actualSeqLengthsGm.SetGlobalBuffer((__gm__ uint32_t *)actualSeqLengthsK, constInfo.actualLenDims);
    }
}

template <typename LIQT>
__aicore__ inline uint32_t LIQPreload<LIQT>::GetActualSeqLen(uint32_t bIdx, uint32_t actualLenDims, bool isAccumSeq,
                                                             GlobalTensor<uint32_t> &actualSeqLengthsGm,
                                                             uint32_t defaultSeqLen)
{
    if (actualLenDims == 0) {
        return defaultSeqLen;
    } else if (isAccumSeq && bIdx > 0) {
        return actualSeqLengthsGm.GetValue(bIdx) - actualSeqLengthsGm.GetValue(bIdx - 1);
    } else {
        return actualSeqLengthsGm.GetValue(bIdx);
    }
}
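
// For illustration of the accumulated (TND prefix-sum) mode: if actualSeqLengths holds
// {3, 7, 12}, the per-batch lengths are 3, 7 - 3 = 4 and 12 - 7 = 5; in non-accumulated
// mode each entry is already the per-batch length itself.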

template <typename LIQT>
__aicore__ inline void LIQPreload<LIQT>::GetS1S2ActualSeqLen(uint32_t bIdx, uint32_t &actS1Size, uint32_t &actS2Size)
{
    actS1Size = GetActualSeqLen(bIdx, constInfo.actualLenQDims, constInfo.isAccumSeqS1, actualSeqLengthsGmQ,
                                constInfo.qSeqSize);
    actS2Size =
        GetActualSeqLen(bIdx, constInfo.actualLenDims, constInfo.isAccumSeqS2, actualSeqLengthsGm, constInfo.kSeqSize);
}

template <typename LIQT>
__aicore__ inline uint32_t LIQPreload<LIQT>::GetS2BaseBlockNumOnMask(uint32_t s1gIdx, uint32_t actS1Size,
                                                                     uint32_t actS2Size)
{
    if (actS2Size == 0) {
        return 0;
    }
    uint32_t s1Offset = constInfo.s1BaseSize * s1gIdx;
    int32_t validS2LenBase = static_cast<int32_t>(actS2Size) - static_cast<int32_t>(actS1Size);
    int32_t validS2Len = s1Offset + validS2LenBase + constInfo.s1BaseSize;
    validS2Len = Min(validS2Len, static_cast<int32_t>(actS2Size));
    validS2Len = Max(validS2Len, 1);
    return (validS2Len + constInfo.s2BaseSize - 1) / constInfo.s2BaseSize;
}
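
// For illustration under the causal mask (assuming s1BaseSize = 4, s2BaseSize = 2048):
// with actS1Size == actS2Size == 4096, the first S1 block (s1gIdx = 0) only sees
// validS2Len = 0 + (4096 - 4096) + 4 = 4 keys, i.e. a single S2 base block, while the
// last S1 block sees all ceil(4096 / 2048) = 2 blocks.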

template <typename LIQT>
__aicore__ inline uint32_t LIQPreload<LIQT>::GetTotalBaseBlockNum()
{
    uint32_t totalBlockNum = 0;
    uint32_t actS1Size, actS2Size;
    uint32_t s1GBaseNum, s2BaseNum;
    for (uint32_t bIdx = 0; bIdx < constInfo.batchSize; bIdx++) {
        GetS1S2ActualSeqLen(bIdx, actS1Size, actS2Size);
        s1GBaseNum = CeilDiv(actS1Size, constInfo.s1BaseSize);
        if (!constInfo.attenMaskFlag) {
            s2BaseNum = CeilDiv(actS2Size, constInfo.s2BaseSize);
            totalBlockNum += s1GBaseNum * s2BaseNum * constInfo.kHeadNum;
            continue;
        }
        for (uint32_t s1gIdx = 0; s1gIdx < s1GBaseNum; s1gIdx++) {
            s2BaseNum = GetS2BaseBlockNumOnMask(s1gIdx, actS1Size, actS2Size);
            totalBlockNum += s2BaseNum * constInfo.kHeadNum;
        }
    }
    return totalBlockNum;
}

// Multi-core version; both interval ends are inclusive. Basic principle: compute the minimum
// number of blocks each core handles, then give the leading cores one extra block each until
// the remainder is used up.
template <typename LIQT>
__aicore__ void inline LIQPreload<LIQT>::SplitCore(uint32_t curCoreIdx, uint32_t &coreNum,
                                                   LIQCommon::SplitCoreInfo &info)
{
    uint32_t totalBlockNum = GetTotalBaseBlockNum();
    uint32_t minBlockPerCore = totalBlockNum / coreNum;
    uint32_t deal1MoreBlockCoreNum = totalBlockNum % coreNum;
    uint32_t coreIdx = 0;
    uint32_t lastGS1RemainBlockCnt = 0;
    uint32_t coreDealBlockCnt = coreIdx < deal1MoreBlockCoreNum ? minBlockPerCore + 1 : minBlockPerCore;
    coreNum = minBlockPerCore == 0 ? deal1MoreBlockCoreNum : coreNum;

    bool findLastCoreEnd = true;
    uint32_t actS1Size, actS2Size;
    uint32_t s1GBaseNum, s2BaseNum;
    for (uint32_t bN2Idx = 0; bN2Idx < constInfo.batchSize * constInfo.kHeadNum; bN2Idx++) {
        uint32_t bIdx = bN2Idx / constInfo.kHeadNum;
        if (bN2Idx % constInfo.kHeadNum == 0) {
            GetS1S2ActualSeqLen(bIdx, actS1Size, actS2Size);
            s1GBaseNum = CeilDiv(actS1Size, constInfo.s1BaseSize);
            s2BaseNum = CeilDiv(actS2Size, constInfo.s2BaseSize);
        }
        if constexpr (Q_LAYOUT_T == LI_LAYOUT::BSND) {
            if (findLastCoreEnd && (s1GBaseNum == 0U || s2BaseNum == 0U)) {
                info.bN2Start = bN2Idx;
                info.gS1Start = 0;
                info.s2Start = 0;
                findLastCoreEnd = false;
            }
        }
        for (uint32_t gS1Idx = 0; gS1Idx < s1GBaseNum; gS1Idx++) {
            if (constInfo.attenMaskFlag) {
                s2BaseNum = GetS2BaseBlockNumOnMask(gS1Idx, actS1Size, actS2Size);
            }
            if (findLastCoreEnd && s2BaseNum == 0U) {
                info.bN2Start = bN2Idx;
                info.gS1Start = gS1Idx;
                info.s2Start = 0;
                findLastCoreEnd = false;
            }
            for (uint32_t s2Idx = 0; s2Idx < s2BaseNum;) {
                if (findLastCoreEnd) {
                    info.bN2Start = bN2Idx;
                    info.gS1Start = gS1Idx;
                    info.s2Start = s2Idx;
                    findLastCoreEnd = false;
                }
                uint32_t s2RemainBaseNum = s2BaseNum - s2Idx;
                if (lastGS1RemainBlockCnt + s2RemainBaseNum >= coreDealBlockCnt) {
                    info.bN2End = bN2Idx;
                    info.gS1End = gS1Idx;
                    info.s2End = s2Idx + coreDealBlockCnt - lastGS1RemainBlockCnt - 1;

                    if (coreIdx == curCoreIdx) {
                        // When S2 is split across several cores, only the first of them needs
                        // to run the LD (decode reduction) pass; the others do not.
                        if (s2Idx == 0 && info.s2End + 1 < s2BaseNum) {
                            info.isLD = true;
                        }
                        // If the last core does not end on the last batch, the remaining batches
                        // are empty blocks (S2 = 0); extend the end coordinates so their outputs
                        // still get cleaned up.
                        if (coreIdx == coreNum - 1 && info.bN2End != constInfo.batchSize - 1) {
                            info.bN2End = constInfo.batchSize - 1;
                            info.gS1End = 0;
                            info.s2End = 0;
                        }
                        return;
                    }
                    coreIdx++;
                    findLastCoreEnd = true;
                    s2Idx = info.s2End + 1;
                    lastGS1RemainBlockCnt = 0;
                    coreDealBlockCnt = coreIdx < deal1MoreBlockCoreNum ? minBlockPerCore + 1 : minBlockPerCore;
                } else {
                    lastGS1RemainBlockCnt += s2RemainBaseNum;
                    break;
                }
            }
        }
    }
}
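
// For illustration: with totalBlockNum = 50 and coreNum = 24, minBlockPerCore = 2 and
// deal1MoreBlockCoreNum = 2, so cores 0-1 each take 3 base blocks and cores 2-23 each
// take 2, walking the (bN2, gS1, s2) space in order.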

template <typename LIQT>
__aicore__ inline void LIQPreload<LIQT>::DealActSeqLenIsZero(uint32_t bIdx, uint32_t n2Idx, uint32_t s1Start)
{
    if ASCEND_IS_AIV {
        if (constInfo.outputLayout == LI_LAYOUT::TND) {
            uint32_t tSize = actualSeqLengthsGmQ.GetValue(constInfo.batchSize - 1);
            uint32_t tBase = bIdx == 0 ? 0 : actualSeqLengthsGmQ.GetValue(bIdx - 1);
            uint32_t s1Count = tempLoopInfo.actS1Size;

            for (uint32_t s1Idx = s1Start; s1Idx < s1Count; s1Idx++) {
                uint64_t indiceOutOffset =
                    (tBase + s1Idx) * constInfo.kHeadNum * constInfo.sparseCount + // offset along the T/S1 axes
                    n2Idx * constInfo.sparseCount;                                 // offset along the N2 axis
                vectorService.CleanInvalidOutput(indiceOutOffset);
            }
        } else if (constInfo.outputLayout == LI_LAYOUT::BSND) {
            for (uint32_t s1Idx = s1Start; s1Idx < constInfo.qSeqSize; s1Idx++) {
                // B,S1,N2,K
                uint64_t indiceOutOffset = bIdx * constInfo.qSeqSize * constInfo.kHeadNum * constInfo.sparseCount +
                                           s1Idx * constInfo.kHeadNum * constInfo.sparseCount + // offset along B/S1
                                           n2Idx * constInfo.sparseCount;                       // offset along N2
                vectorService.CleanInvalidOutput(indiceOutOffset);
            }
        }
    }
}

template <typename LIQT>
__aicore__ inline void LIQPreload<LIQT>::Init(__gm__ uint8_t *query, __gm__ uint8_t *key, __gm__ uint8_t *weights,
                                              __gm__ uint8_t *queryScale, __gm__ uint8_t *keyScale,
                                              __gm__ uint8_t *actualSeqLengthsQ, __gm__ uint8_t *actualSeqLengthsK,
                                              __gm__ uint8_t *blockTable, __gm__ uint8_t *sparseIndices,
                                              __gm__ uint8_t *workspace, const LIQTilingData *__restrict tiling,
                                              TPipe *tPipe)
{
    if ASCEND_IS_AIV {
        tmpBlockIdx = GetBlockIdx(); // vec: 0-47
        aiCoreIdx = tmpBlockIdx / 2;
    } else {
        tmpBlockIdx = GetBlockIdx(); // cube: 0-23
        aiCoreIdx = tmpBlockIdx;
    }

    InitTilingData(tiling);
    InitActualSeqLen(actualSeqLengthsQ, actualSeqLengthsK);

    // compute the core split
    SplitCore(aiCoreIdx, usedCoreNum, splitCoreInfo);

    pipe = tPipe;
    // workspace memory layout:
    // |mm1ResGm (holds S)|vec1ResGm (holds LD intermediate results)|vec1ParamGm (holds LD params)
    // |Core0_mm1ResDB0-Core0_mm1ResDB1-Core1_mm1ResDB0....Core23_mm1ResDB0-Core23_mm1ResDB1|Core0_vec1Res...
    uint64_t offset = 0;

    // double buffering for mm1
    GlobalTensor<MM1_OUT_T> mm1ResGm; // holds S
    uint64_t singleCoreMm1ResSize = WS_DOUBLE * constInfo.s1BaseSize * constInfo.s2BaseSize * sizeof(MM1_OUT_T);
    mm1ResGm.SetGlobalBuffer((__gm__ MM1_OUT_T *)(workspace + aiCoreIdx * singleCoreMm1ResSize));
    offset += GetBlockNum() * singleCoreMm1ResSize;

    // workspace size needed by the LD pass: [aicNum, 2, CeilDiv(constInfo.mBaseSize, constInfo.gSize), topkOut_*2]
    // (aic, 8, 2, 2, 2048)
    // (aic, s1_cube, head/tail, idx/value, K)
    GlobalTensor<float> vec1ResGm; // holds top-k intermediate results
    vec1ResGm.SetGlobalBuffer((__gm__ float *)(workspace + offset));
    offset += GetBlockNum() * constInfo.s1BaseSize * WS_DOUBLE * WS_DOUBLE * BASE_TOPK * sizeof(float);

    // (aic, 8, 2, 16)
    // (aic, s1_cube, head/tail, 16 elements)
    GlobalTensor<int64_t> vec1ParamGm; // holds LD parameter info
    vec1ParamGm.SetGlobalBuffer((__gm__ int64_t *)(workspace + offset));
    offset += GetBlockNum() * constInfo.s1BaseSize * WS_DOUBLE * LD_PARAM_NUM * sizeof(int64_t);

    GlobalTensor<half> weightWorkspaceGm; // result of w * scale produced in the v1 stage
    uint64_t weightMemSize = BLOCK_CUBE * constInfo.mBaseSize * WS_DOUBLE * sizeof(half);
    weightWorkspaceGm.SetGlobalBuffer((__gm__ half *)(workspace + offset + aiCoreIdx * weightMemSize));
    offset += GetBlockNum() * weightMemSize;

    GlobalTensor<half> qScaleGm;
    GlobalTensor<half> kScaleGm;
    if ASCEND_IS_AIV {
        vectorService.InitParams(constInfo, tiling);
        indiceOutGm.SetGlobalBuffer((__gm__ int32_t *)sparseIndices);
        weightsGm.SetGlobalBuffer((__gm__ half *)weights);
        qScaleGm.SetGlobalBuffer((__gm__ half *)queryScale);
        kScaleGm.SetGlobalBuffer((__gm__ half *)keyScale);
        blockTableGm.SetGlobalBuffer((__gm__ int32_t *)blockTable);
        vectorService.InitVecInputTensor(weightsGm, qScaleGm, kScaleGm, indiceOutGm, blockTableGm);
        vectorService.InitVecWorkspaceTensor(weightWorkspaceGm, mm1ResGm, vec1ResGm, vec1ParamGm);
    } else {
        matmulService.InitParams(constInfo);
        queryGm.SetGlobalBuffer((__gm__ Q_T *)query);
        if constexpr (PAGE_ATTENTION) {
            blockTableGm.SetGlobalBuffer((__gm__ int32_t *)blockTable);
        }
        keyGm.SetGlobalBuffer((__gm__ K_T *)key);
        matmulService.InitMm1GlobalTensor(blockTableGm, keyGm, queryGm, mm1ResGm, weightWorkspaceGm);
    }
    InitBuffers();
}

template <typename LIQT>
__aicore__ inline void LIQPreload<LIQT>::GetBN2Idx(uint32_t bN2Idx)
{
    tempLoopInfo.bN2Idx = bN2Idx;
    tempLoopInfo.bIdx = bN2Idx / constInfo.kHeadNum;
    tempLoopInfo.n2Idx = bN2Idx % constInfo.kHeadNum;
}

template <typename LIQT>
__aicore__ inline void LIQPreload<LIQT>::CalcS2LoopParams(uint32_t bN2LoopIdx, uint32_t gS1LoopIdx)
{
    tempLoopInfo.gS1Idx = gS1LoopIdx;
    tempLoopInfo.actMBaseSize = constInfo.mBaseSize;
    uint32_t remainedGS1Size = tempLoopInfo.actS1Size * constInfo.gSize - tempLoopInfo.gS1Idx * constInfo.mBaseSize;
    if (remainedGS1Size <= constInfo.mBaseSize && remainedGS1Size > 0) {
        tempLoopInfo.actMBaseSize = tempLoopInfo.mBasicSizeTail;
    }

    bool isEnd = (bN2LoopIdx == splitCoreInfo.bN2End) && (gS1LoopIdx == splitCoreInfo.gS1End);
    uint32_t s2BlockNum;
    if (constInfo.attenMaskFlag) {
        s2BlockNum = GetS2BaseBlockNumOnMask(gS1LoopIdx, tempLoopInfo.actS1Size, tempLoopInfo.actS2Size);
    } else {
        s2BlockNum = (tempLoopInfo.actS2Size + constInfo.s2BaseSize - 1) / constInfo.s2BaseSize;
    }
    tempLoopInfo.s2LoopEnd = isEnd ? splitCoreInfo.s2End : s2BlockNum - 1;
}

template <typename LIQT>
__aicore__ inline void LIQPreload<LIQT>::CalcGS1LoopParams(uint32_t bN2LoopIdx)
{
    GetBN2Idx(bN2LoopIdx);
    GetS1S2ActualSeqLen(tempLoopInfo.bIdx, tempLoopInfo.actS1Size, tempLoopInfo.actS2Size);
    if ((tempLoopInfo.actS2Size == 0) || (tempLoopInfo.actS1Size == 0)) {
        tempLoopInfo.curActSeqLenIsZero = true;
        return;
    }
    tempLoopInfo.curActSeqLenIsZero = false;
    tempLoopInfo.s2BasicSizeTail = tempLoopInfo.actS2Size % constInfo.s2BaseSize;
    tempLoopInfo.s2BasicSizeTail =
        (tempLoopInfo.s2BasicSizeTail == 0) ? constInfo.s2BaseSize : tempLoopInfo.s2BasicSizeTail;
    tempLoopInfo.mBasicSizeTail = (tempLoopInfo.actS1Size * constInfo.gSize) % constInfo.mBaseSize;
    tempLoopInfo.mBasicSizeTail =
        (tempLoopInfo.mBasicSizeTail == 0) ? constInfo.mBaseSize : tempLoopInfo.mBasicSizeTail;

    uint32_t gS1SplitNum = (tempLoopInfo.actS1Size * constInfo.gSize + constInfo.mBaseSize - 1) / constInfo.mBaseSize;
    tempLoopInfo.gS1LoopEnd = (bN2LoopIdx == splitCoreInfo.bN2End) ? splitCoreInfo.gS1End : gS1SplitNum - 1;
    if constexpr (Q_LAYOUT_T == LI_LAYOUT::BSND) {
        if (tempLoopInfo.gS1LoopEnd == gS1SplitNum - 1 && constInfo.qSeqSize > tempLoopInfo.actS1Size) {
            tempLoopInfo.needDealActS1LessThanS1 = true;
        }
    }
}
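
// Tail-block sizing, for illustration: with actS2Size = 5000 and s2BaseSize = 2048,
// s2BasicSizeTail = 5000 % 2048 = 904, so the last of the ceil(5000 / 2048) = 3 S2
// blocks processes 904 keys; an exact multiple keeps the full s2BaseSize instead.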

template <typename LIQT>
__aicore__ inline void LIQPreload<LIQT>::CalcRunInfo(uint32_t loop, uint32_t s2LoopIdx, LIQCommon::RunInfo &runInfo)
{
    runInfo.loop = loop;
    runInfo.bIdx = tempLoopInfo.bIdx;
    runInfo.gS1Idx = tempLoopInfo.gS1Idx;
    runInfo.s2Idx = s2LoopIdx;
    runInfo.bN2Idx = tempLoopInfo.bN2Idx;
    runInfo.isValid = s2LoopIdx <= tempLoopInfo.s2LoopEnd;

    if (!runInfo.isValid) {
        return; // still required downstream: v1 needs the runInfo for validation
    }

    runInfo.actS1Size = tempLoopInfo.actS1Size;
    runInfo.actS2Size = tempLoopInfo.actS2Size;
    // compute the actual base-block size
    runInfo.actMBaseSize = tempLoopInfo.actMBaseSize;
    runInfo.actualSingleProcessSInnerSize = constInfo.s2BaseSize;
    uint32_t s2SplitNum = (tempLoopInfo.actS2Size + constInfo.s2BaseSize - 1) / constInfo.s2BaseSize;
    if (runInfo.s2Idx == s2SplitNum - 1) {
        runInfo.actualSingleProcessSInnerSize = tempLoopInfo.s2BasicSizeTail;
    }
    runInfo.actualSingleProcessSInnerSizeAlign =
        LIQCommon::Align((uint32_t)runInfo.actualSingleProcessSInnerSize, LIQCommon::ConstInfo::BUFFER_SIZE_BYTE_32B);

    runInfo.isFirstS2InnerLoop = s2LoopIdx == splitCoreInfo.s2Start;
    runInfo.isLastS2InnerLoop = s2LoopIdx == tempLoopInfo.s2LoopEnd;
    runInfo.isAllLoopEnd = (runInfo.bN2Idx == splitCoreInfo.bN2End) && (runInfo.gS1Idx == splitCoreInfo.gS1End) &&
                           (runInfo.s2Idx == splitCoreInfo.s2End);

    if (runInfo.isFirstS2InnerLoop) {
        uint64_t actualSeqQPrefixSum;
        if constexpr (Q_LAYOUT_T == LI_LAYOUT::TND) {
            actualSeqQPrefixSum = (runInfo.bIdx <= 0) ? 0 : actualSeqLengthsGmQ.GetValue(runInfo.bIdx - 1);
        } else { // BSND
            actualSeqQPrefixSum = (runInfo.bIdx <= 0) ? 0 : runInfo.bIdx * constInfo.qSeqSize;
        }
        uint64_t tndBIdxOffset = actualSeqQPrefixSum * constInfo.qHeadNum * constInfo.headDim;
        // B,S1,N1(N2,G),D
        queryCoreOffset = tndBIdxOffset + runInfo.gS1Idx * constInfo.mBaseSize * constInfo.headDim;
        // B,S1,N1(N2,G)/T,N1(N2,G)
        weightsCoreOffset = actualSeqQPrefixSum * constInfo.qHeadNum + runInfo.n2Idx * constInfo.gSize;
        // B,S1,N2,k/T,N2,k
        indiceOutCoreOffset =
            actualSeqQPrefixSum * constInfo.kHeadNum * constInfo.sparseCount + runInfo.n2Idx * constInfo.sparseCount;
    }
    uint64_t actualSeqKPrefixSum;
    if constexpr (K_LAYOUT_T == LI_LAYOUT::TND) { // T N2 D
        actualSeqKPrefixSum = (runInfo.bIdx <= 0) ? 0 : actualSeqLengthsGm.GetValue(runInfo.bIdx - 1);
    } else {
        actualSeqKPrefixSum = (runInfo.bIdx <= 0) ? 0 : runInfo.bIdx * constInfo.kSeqSize;
    }
    uint64_t tndBIdxOffsetForK = actualSeqKPrefixSum * constInfo.kHeadNum * constInfo.headDim;
    keyCoreOffset = tndBIdxOffsetForK + runInfo.s2Idx * constInfo.s2BaseSize * constInfo.kHeadNum * constInfo.headDim;
    keyScaleCoreOffset = (actualSeqKPrefixSum + runInfo.s2Idx * constInfo.s2BaseSize) * constInfo.kHeadNum;
    runInfo.tensorQueryOffset = queryCoreOffset;
    runInfo.tensorKeyOffset = keyCoreOffset;
    runInfo.tensorKeyScaleOffset = keyScaleCoreOffset;
    runInfo.tensorWeightsOffset = weightsCoreOffset;
    runInfo.indiceOutOffset = indiceOutCoreOffset;
}

template <typename LIQT>
__aicore__ inline void LIQPreload<LIQT>::Process()
{
    if (usedCoreNum == 0) {
        // no compute task at all; just clean up the output
        ProcessInvalid();
        return;
    }

    ProcessMain();

    ProcessDecode();
}

template <typename LIQT>
__aicore__ inline void LIQPreload<LIQT>::ProcessInvalid()
{
    if ASCEND_IS_AIV {
        uint32_t aivCoreNum = GetBlockNum() * 2; // 2 means c:v = 1:2
        uint64_t totalOutputSize =
            constInfo.batchSize * constInfo.qSeqSize * constInfo.kHeadNum * constInfo.sparseCount;
        uint64_t singleCoreSize =
            LIQCommon::Align((totalOutputSize + aivCoreNum - 1) / aivCoreNum, GM_ALIGN_BYTES / sizeof(OUT_T));
        uint64_t baseSize = tmpBlockIdx * singleCoreSize;
        if (baseSize < totalOutputSize) {
            uint64_t dealSize =
                (baseSize + singleCoreSize <= totalOutputSize) ? singleCoreSize : totalOutputSize - baseSize;
            GlobalTensor<OUT_T> output = indiceOutGm[baseSize];
            AscendC::InitGlobalMemory(output, dealSize, constInfo.INVALID_IDX);
        }
    }
}

template <typename LIQT>
__aicore__ inline void LIQPreload<LIQT>::ProcessMain()
{
    if (aiCoreIdx >= usedCoreNum) {
        // cores without a task return immediately
        return;
    }

    if ASCEND_IS_AIV {
        vectorService.AllocEventID();
        CrossCoreSetFlag<LIQCommon::ConstInfo::FIA_SYNC_MODE2, PIPE_MTE2>(constInfo.syncV1C1);
        CrossCoreSetFlag<LIQCommon::ConstInfo::FIA_SYNC_MODE2, PIPE_MTE2>(constInfo.syncV1C1);
    } else {
        matmulService.AllocEventID();
        CrossCoreSetFlag<LIQCommon::ConstInfo::FIA_SYNC_MODE2, PIPE_FIX>(constInfo.syncC1V0);
        CrossCoreSetFlag<LIQCommon::ConstInfo::FIA_SYNC_MODE2, PIPE_FIX>(constInfo.syncC1V0);
    }

    LIQCommon::RunInfo runInfo[LI_QUANT_PRELOAD_TASK_CACHE_SIZE];

    uint32_t gloop = 0;
    for (uint32_t bN2LoopIdx = splitCoreInfo.bN2Start; bN2LoopIdx <= splitCoreInfo.bN2End; bN2LoopIdx++) {
        CalcGS1LoopParams(bN2LoopIdx);
        if (tempLoopInfo.curActSeqLenIsZero) {
            DealActSeqLenIsZero(tempLoopInfo.bIdx, tempLoopInfo.n2Idx, 0U);

            if ASCEND_IS_AIV {
                if (bN2LoopIdx == splitCoreInfo.bN2End && gloop > 0) {
                    CrossCoreWaitFlag(constInfo.syncC1V1);
                    vectorService.ProcessVec1(runInfo[1 - gloop % LI_QUANT_PRELOAD_TASK_CACHE_SIZE]);
                    CrossCoreSetFlag<LIQCommon::ConstInfo::FIA_SYNC_MODE2, PIPE_MTE3>(
                        constInfo.syncV1C1); // reverse sync 1
                }
            }
            continue;
        }
        for (uint32_t gS1LoopIdx = splitCoreInfo.gS1Start; gS1LoopIdx <= tempLoopInfo.gS1LoopEnd; gS1LoopIdx++) {
            CalcS2LoopParams(bN2LoopIdx, gS1LoopIdx);
            bool isEnd = (bN2LoopIdx == splitCoreInfo.bN2End) && (gS1LoopIdx == splitCoreInfo.gS1End);
            uint32_t extraLoop = isEnd ? LI_QUANT_PRELOAD_TASK_CACHE_SIZE - 1 : 0;
            for (uint32_t s2LoopIdx = splitCoreInfo.s2Start; s2LoopIdx <= (tempLoopInfo.s2LoopEnd + extraLoop);
                 s2LoopIdx++) {
                ProcessBaseBlock(gloop, s2LoopIdx, runInfo);
                ++gloop;
            }
            splitCoreInfo.s2Start = 0;
        }
        if (tempLoopInfo.needDealActS1LessThanS1) {
            DealActSeqLenIsZero(tempLoopInfo.bIdx, tempLoopInfo.n2Idx, tempLoopInfo.actS1Size);
        }
        splitCoreInfo.gS1Start = 0;
    }

    if ASCEND_IS_AIV {
        vectorService.FreeEventID();
        CrossCoreWaitFlag(constInfo.syncC1V0);
        CrossCoreWaitFlag(constInfo.syncC1V0);
    } else {
        matmulService.FreeEventID();
        CrossCoreWaitFlag(constInfo.syncV1C1);
        CrossCoreWaitFlag(constInfo.syncV1C1);
    }
}

template <typename LIQT>
__aicore__ inline void LIQPreload<LIQT>::ProcessBaseBlock(uint32_t loop, uint64_t s2LoopIdx,
                                                          LIQCommon::RunInfo runInfo[LI_QUANT_PRELOAD_TASK_CACHE_SIZE])
{
    int32_t curTaskId = loop % LI_QUANT_PRELOAD_TASK_CACHE_SIZE;
    LIQCommon::RunInfo &curRunInfo = runInfo[curTaskId];
    LIQCommon::RunInfo &lastRunInfo = runInfo[1 - curTaskId];

    CalcRunInfo(loop, s2LoopIdx, curRunInfo);

    if (curRunInfo.isValid) {
        if ASCEND_IS_AIC {
            if (curRunInfo.isFirstS2InnerLoop) {
                CrossCoreWaitFlag(constInfo.syncV0C1);
            }
            CrossCoreWaitFlag(constInfo.syncV1C1); // reverse sync 1
            matmulService.ComputeMm1(curRunInfo);
            CrossCoreSetFlag<LIQCommon::ConstInfo::FIA_SYNC_MODE2, PIPE_FIX>(constInfo.syncC1V1);
            if (curRunInfo.isLastS2InnerLoop) {
                CrossCoreSetFlag<LIQCommon::ConstInfo::FIA_SYNC_MODE2, PIPE_FIX>(constInfo.syncC1V0); // reverse sync 0
            }
        } else {
            if (curRunInfo.isFirstS2InnerLoop) {
                CrossCoreWaitFlag(constInfo.syncC1V0); // reverse sync 0
                vectorService.ProcessVec0(curRunInfo);
                CrossCoreSetFlag<LIQCommon::ConstInfo::FIA_SYNC_MODE2, PIPE_MTE3>(constInfo.syncV0C1);
            }
        }
    }

    if (lastRunInfo.isValid) {
        if ASCEND_IS_AIV {
            CrossCoreWaitFlag(constInfo.syncC1V1);
            vectorService.ProcessVec1(lastRunInfo);
            CrossCoreSetFlag<LIQCommon::ConstInfo::FIA_SYNC_MODE2, PIPE_MTE3>(constInfo.syncV1C1); // reverse sync 1
        }
        lastRunInfo.isValid = false;
    }
}
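
// The two-entry runInfo array gives a ping-pong pipeline: while the cube core runs mm1
// for block i (curRunInfo), the vector cores consume the mm1 result of block i - 1
// (lastRunInfo) in ProcessVec1; the extraLoop iterations added in ProcessMain drain the
// final in-flight block.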

template <typename LIQT>
__aicore__ inline void LIQPreload<LIQT>::ProcessDecode()
{
    if ASCEND_IS_AIV {
        vectorService.InitLDBuffers(pipe);
        ICachePreLoad(LD_PREFETCH_LEN);
        SyncAll();
        if (splitCoreInfo.isLD) {
            vectorService.ProcessLD();
        }
    }
}
} // namespace LIQKernel
#endif // LIGHTNING_INDEXER_QUANT_KERNEL_H
@@ -0,0 +1,613 @@
/**
 * This program is free software, you can redistribute it and/or modify it.
 * Copyright (c) 2025 Huawei Technologies Co., Ltd.
 * This file is a part of the CANN Open Software.
 * Licensed under CANN Open Software License Agreement Version 2.0 (the "License").
 * Please refer to the License for details. You may not use this file except in compliance with the License.
 * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE.
 * See LICENSE in the root of the software repository for the full text of the License.
 */

/*!
 * \file lightning_indexer_quant_service_cube.h
 * \brief use 5 buffers for matmul L1 for a better pipeline
 */
#ifndef LIGHTNING_INDEXER_QUANT_SERVICE_CUBE_H
#define LIGHTNING_INDEXER_QUANT_SERVICE_CUBE_H

#include "kernel_operator.h"
#include "kernel_operator_list_tensor_intf.h"
#include "kernel_tiling/kernel_tiling.h"
#include "lib/matmul_intf.h"
#include "lib/matrix/matmul/tiling.h"
#include "lightning_indexer_quant_common.h"

namespace LIQKernel {
using namespace LIQCommon;
struct MmInfo {
    int64_t s2L0LoopId;
    int64_t s1gL0LoopId;
    int64_t s2L0RealSize;
    int64_t s2GmOffset;
};

template <typename LIQT>
class LIQMatmul {
public:
    using Q_T = typename LIQT::queryType;
    using K_T = typename LIQT::keyType;

    __aicore__ inline LIQMatmul() {}
    __aicore__ inline void InitBuffers(TPipe *pipe);
    __aicore__ inline void InitMm1GlobalTensor(const GlobalTensor<int32_t> &blkTableGm, const GlobalTensor<K_T> &keyGm,
                                               const GlobalTensor<Q_T> &queryGm, const GlobalTensor<float> &mm1ResGm,
                                               const GlobalTensor<half> &weightWorkspaceGm);
    __aicore__ inline void InitParams(const ConstInfo &constInfo);
    __aicore__ inline void AllocEventID();
    __aicore__ inline void FreeEventID();
    __aicore__ inline void ComputeMm1(const LIQCommon::RunInfo &runInfo);

    static constexpr IsResetLoad3dConfig LOAD3DV2_CONFIG = {true, true}; // isSetFMatrix isSetPadding
    static constexpr uint64_t DOUBLE_BUF_NUM = 2;
    static constexpr uint64_t L0AB_BUF_NUM = 4;

    static constexpr uint32_t KEY_MTE1_MTE2_EVENT = EVENT_ID2;
    static constexpr uint32_t QW_MTE1_MTE2_EVENT = EVENT_ID5; // KEY_MTE1_MTE2_EVENT + DOUBLE_BUF_NUM
    static constexpr uint32_t M_MTE1_EVENT = EVENT_ID3;
    static constexpr uint32_t M_FIX_EVENT = EVENT_ID0;
    static constexpr uint32_t FIX_M_EVENT = EVENT_ID2;
    static constexpr uint32_t FIX_MTE1_EVENT = EVENT_ID4;

    static constexpr uint64_t S8_BLOCK_CUBE = 32;

    static constexpr uint32_t MTE2_MTE1_EVENT = EVENT_ID2;
    static constexpr uint32_t MTE1_M_EVENT = EVENT_ID2;

    static constexpr uint64_t D_BASIC_BLOCK = 128;
    static constexpr uint64_t S1G_BASIC_BLOCK_L1 = 256;

    static constexpr uint64_t S1G_BASIC_BLOCK_L0 = 128;
    static constexpr uint64_t S2_BASIC_BLOCK_L0 = 128;

    static constexpr uint64_t QUERY_BUFFER_OFFSET = S1G_BASIC_BLOCK_L1 * D_BASIC_BLOCK;
    static constexpr uint64_t SL1_BUFFER_OFFSET = S1G_BASIC_BLOCK_L0 * S2_BASIC_BLOCK_L0;
    static constexpr uint64_t KEY_BUFFER_OFFSET = S2_BASIC_BLOCK_L0 * D_BASIC_BLOCK;
    static constexpr uint64_t WEIGHT_BUFFER_OFFSET = S1G_BASIC_BLOCK_L1 * BLOCK_CUBE;
    static constexpr uint64_t L0AB_BUFFER_OFFSET_S8_16K = 16 * 1024;
    static constexpr uint64_t L0AB_BUFFER_OFFSET_FP16_16K = 16 * 512;
    static constexpr uint64_t L0C_BUFFER_OFFSET = 64 * 256;

private:
    __aicore__ inline void WeightDmaCopy(uint64_t s1gL1RealSize, const LIQCommon::RunInfo &runInfo);
    __aicore__ inline void LoadKeyToL0b(uint64_t s2L0RealSize);
    __aicore__ inline void LoadQueryToL0a(uint64_t s1gL1Offset, uint64_t s1gL1RealSize, uint64_t s1gL0RealSize);
    __aicore__ inline void QueryNd2Nz(uint64_t s1gL1RealSize, const LIQCommon::RunInfo &runInfo);
    __aicore__ inline void KeyNd2NzForPA(uint64_t s2L1RealSize, uint64_t s2GmOffset, const LIQCommon::RunInfo &runInfo);
    __aicore__ inline void KeyNd2Nz(uint64_t s2L1RealSize, const MmInfo &mmInfo, const LIQCommon::RunInfo &runInfo);
    __aicore__ inline void FixpSToL1(uint64_t s1gL0RealSize, uint64_t s2L0RealSize);
    __aicore__ inline void LoadSToL0b(uint64_t s1gL1RealSize, uint64_t s2L0RealSize, uint64_t sL1BufIdx,
                                      int64_t mStartPt);
    __aicore__ inline void LoadWeightToL0a(uint64_t s1gL1Offset);
    __aicore__ inline void ComputeWs(uint64_t s1gL0RealSize, uint64_t s2L0RealSize, int64_t s1gOffset);
    __aicore__ inline void FixpResToGm(uint64_t s1L0RealCount, uint64_t s2L0RealSize, uint64_t s1GmOffset,
                                       uint64_t s2GmOffset, const LIQCommon::RunInfo &runInfo);
    __aicore__ inline void ComputeQk(uint64_t s1gL0RealSize, uint64_t s2L0RealSize);
    __aicore__ inline void ProcessWs(uint64_t s1gL0RealSize, uint64_t s1gL1Offset, uint64_t sL1BufIdx,
                                     const MmInfo &mmInfo, const LIQCommon::RunInfo &runInfo);
    __aicore__ inline void ProcessQk(uint64_t s1gL0RealSize, uint64_t s1gL1Offset, uint64_t s1L0LoopCnt,
                                     const MmInfo &mmInfo, const LIQCommon::RunInfo &runInfo);
    __aicore__ inline void CalcMmInfo(MmInfo &mmInfo, uint64_t loopIdx, uint64_t s1L0LoopCnt, const MmInfo &lastMmInfo,
                                      const LIQCommon::RunInfo &runInfo);
    static constexpr LI_LAYOUT Q_LAYOUT_T = LIQT::layout;
    static constexpr LI_LAYOUT K_LAYOUT_T = LIQT::keyLayout;
    GlobalTensor<int32_t> blkTableGm_;
    GlobalTensor<K_T> keyGm_;
    GlobalTensor<Q_T> queryGm_;
    GlobalTensor<half> weightGm_;
    GlobalTensor<float> mm1ResGm_;

    TBuf<TPosition::A1> bufQL1_;
    LocalTensor<Q_T> queryL1_;
    TBuf<TPosition::B1> bufKeyL1_;
    LocalTensor<K_T> keyL1_;
    TBuf<TPosition::A1> bufWeightL1_;
    LocalTensor<half> weightL1_;
    TBuf<TPosition::B1> bufSL1_;
    LocalTensor<half> sL1_;

    TBuf<TPosition::A2> bufL0A_;
    LocalTensor<Q_T> l0a_;
    TBuf<TPosition::B2> bufL0B_;
    LocalTensor<K_T> l0b_;

    TBuf<TPosition::CO1> bufL0C_;
    LocalTensor<int32_t> cL0_;

    uint64_t keyL1BufIdx_ = 0;
    uint64_t qwL1Mte2BufIdx_ = 0;
    uint64_t sL1BufIdx_ = 0;
    uint64_t l0BufIdx_ = 0;
    uint64_t l0cBufIdx_ = 0;

    ConstInfo constInfo_;
};

template <typename LIQT>
__aicore__ inline void LIQMatmul<LIQT>::InitParams(const ConstInfo &constInfo)
{
    constInfo_ = constInfo;
}

template <typename LIQT>
__aicore__ inline void LIQMatmul<LIQT>::InitBuffers(TPipe *pipe)
{
    pipe->InitBuffer(bufQL1_, DOUBLE_BUF_NUM * S1G_BASIC_BLOCK_L1 * D_BASIC_BLOCK * sizeof(Q_T));
    queryL1_ = bufQL1_.Get<Q_T>();
    pipe->InitBuffer(bufKeyL1_, DOUBLE_BUF_NUM * S2_BASIC_BLOCK_L0 * D_BASIC_BLOCK * sizeof(K_T));
    keyL1_ = bufKeyL1_.Get<K_T>();

    pipe->InitBuffer(bufWeightL1_, DOUBLE_BUF_NUM * S1G_BASIC_BLOCK_L1 * BLOCK_CUBE * sizeof(half));
    weightL1_ = bufWeightL1_.Get<half>();
    pipe->InitBuffer(bufSL1_, DOUBLE_BUF_NUM * S2_BASIC_BLOCK_L0 * S1G_BASIC_BLOCK_L0 * sizeof(half));
    sL1_ = bufSL1_.Get<half>();

    pipe->InitBuffer(bufL0A_, 64 * 1024);
    l0a_ = bufL0A_.Get<Q_T>();
    pipe->InitBuffer(bufL0B_, 64 * 1024);
    l0b_ = bufL0B_.Get<K_T>();

    pipe->InitBuffer(bufL0C_, 128 * 1024);
    cL0_ = bufL0C_.Get<int32_t>();
}

template <typename LIQT>
__aicore__ inline void LIQMatmul<LIQT>::InitMm1GlobalTensor(const GlobalTensor<int32_t> &blkTableGm,
                                                            const GlobalTensor<K_T> &keyGm,
                                                            const GlobalTensor<Q_T> &queryGm,
                                                            const GlobalTensor<float> &mm1ResGm,
                                                            const GlobalTensor<half> &weightWorkspaceGm)
{
    blkTableGm_ = blkTableGm;
    keyGm_ = keyGm;
    queryGm_ = queryGm;
    mm1ResGm_ = mm1ResGm;
    weightGm_ = weightWorkspaceGm;
}
|
||||
|
||||
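// ProcessWs folds the per-head weights into the fp16 QK scores: for each token it runs a
// [BLOCK_CUBE x gSize] x [gSize x s2] Mmad that sums the gSize head rows into a single
// score row, then FixpResToGm writes one row per token (mSize = 1) back to mm1ResGm_ in
// ND layout.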
template <typename LIQT>
__aicore__ inline void LIQMatmul<LIQT>::ProcessWs(uint64_t s1gL0RealSize, uint64_t s1gL1Offset, uint64_t sL1BufIdx,
                                                  const MmInfo &mmInfo, const LIQCommon::RunInfo &runInfo)
{
    WaitFlag<HardEvent::FIX_M>(FIX_M_EVENT + l0cBufIdx_ % DOUBLE_BUF_NUM);
    for (int64_t s1gOffset = 0; s1gOffset < s1gL0RealSize; s1gOffset += constInfo_.gSize) {
        WaitFlag<HardEvent::M_MTE1>(M_MTE1_EVENT + l0BufIdx_ % L0AB_BUF_NUM);
        LoadSToL0b(s1gL0RealSize, mmInfo.s2L0RealSize, sL1BufIdx, s1gOffset);
        LoadWeightToL0a(s1gOffset + s1gL1Offset);

        ComputeWs(s1gL0RealSize, mmInfo.s2L0RealSize, s1gOffset);

        SetFlag<HardEvent::M_MTE1>(M_MTE1_EVENT + l0BufIdx_ % L0AB_BUF_NUM);
        l0BufIdx_++;
    }

    FixpResToGm(s1gL0RealSize / constInfo_.gSize, mmInfo.s2L0RealSize, s1gL1Offset / constInfo_.gSize,
                mmInfo.s2L0LoopId * S2_BASIC_BLOCK_L0, runInfo);
    SetFlag<HardEvent::FIX_M>(FIX_M_EVENT + l0cBufIdx_ % DOUBLE_BUF_NUM);
    l0cBufIdx_++;
}
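// ProcessQk computes one int8 Q*K^T tile: on the first s1g iteration of an s2 block it
// stages the key tile into L1 (paged or contiguous path), then loads Q/K into L0A/L0B,
// accumulates int32 partial sums in L0C, and FixpSToL1 converts them to fp16 score tiles
// in sL1_. The key L1 slice is released after the last s1g iteration of the block.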
template <typename LIQT>
__aicore__ inline void LIQMatmul<LIQT>::ProcessQk(uint64_t s1gL0RealSize, uint64_t s1gL1Offset, uint64_t s1L0LoopCnt,
                                                  const MmInfo &mmInfo, const LIQCommon::RunInfo &runInfo)
{
    if (mmInfo.s1gL0LoopId == 0) {
        WaitFlag<HardEvent::MTE1_MTE2>(KEY_MTE1_MTE2_EVENT + keyL1BufIdx_ % DOUBLE_BUF_NUM);
        if constexpr (K_LAYOUT_T == LI_LAYOUT::PA_BSND) {
            KeyNd2NzForPA(mmInfo.s2L0RealSize, runInfo.s2Idx * constInfo_.s2BaseSize + mmInfo.s2GmOffset, runInfo);
        } else {
            KeyNd2Nz(mmInfo.s2L0RealSize, mmInfo, runInfo);
        }

        SetFlag<HardEvent::MTE2_MTE1>(MTE2_MTE1_EVENT);
        WaitFlag<HardEvent::MTE2_MTE1>(MTE2_MTE1_EVENT);
    }

    WaitFlag<HardEvent::M_MTE1>(M_MTE1_EVENT + l0BufIdx_ % L0AB_BUF_NUM);
    LoadQueryToL0a(s1gL1Offset, runInfo.actMBaseSize, s1gL0RealSize);
    LoadKeyToL0b(mmInfo.s2L0RealSize);

    if (mmInfo.s1gL0LoopId + 1 >= s1L0LoopCnt) {
        SetFlag<HardEvent::MTE1_MTE2>(KEY_MTE1_MTE2_EVENT + keyL1BufIdx_ % DOUBLE_BUF_NUM);
        keyL1BufIdx_++;
    }

    WaitFlag<HardEvent::FIX_M>(FIX_M_EVENT + l0cBufIdx_ % DOUBLE_BUF_NUM);
    ComputeQk(s1gL0RealSize, mmInfo.s2L0RealSize);
    SetFlag<HardEvent::M_MTE1>(M_MTE1_EVENT + l0BufIdx_ % L0AB_BUF_NUM);

    FixpSToL1(s1gL0RealSize, mmInfo.s2L0RealSize);
    SetFlag<HardEvent::FIX_M>(FIX_M_EVENT + l0cBufIdx_ % DOUBLE_BUF_NUM);
    l0BufIdx_++;
    l0cBufIdx_++;
}
template <typename LIQT>
__aicore__ inline void LIQMatmul<LIQT>::CalcMmInfo(MmInfo &mmInfo, uint64_t loopIdx, uint64_t s1L0LoopCnt,
                                                   const MmInfo &lastMmInfo, const LIQCommon::RunInfo &runInfo)
{
    mmInfo.s2L0LoopId = loopIdx / s1L0LoopCnt;
    mmInfo.s1gL0LoopId = loopIdx % s1L0LoopCnt;

    if (mmInfo.s1gL0LoopId == 0) {
        mmInfo.s2GmOffset = mmInfo.s2L0LoopId * S2_BASIC_BLOCK_L0;
        mmInfo.s2L0RealSize = mmInfo.s2GmOffset + S2_BASIC_BLOCK_L0 > runInfo.actualSingleProcessSInnerSize
                                  ? runInfo.actualSingleProcessSInnerSize - mmInfo.s2GmOffset
                                  : S2_BASIC_BLOCK_L0;
    } else {
        mmInfo.s2L0RealSize = lastMmInfo.s2L0RealSize;
    }
}
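// ComputeMm1 software-pipelines the two matmuls: the prologue issues ProcessQk for tile 0;
// in steady state each iteration issues ProcessQk for tile i and ProcessWs for tile i - 1,
// handing fp16 score tiles between them through the double-buffered sL1_ slices guarded by
// FIX_MTE1 events; the epilogue drains the final ProcessWs. Q and the weights are staged
// once per M block (first S2 inner loop) and released on the last one.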
template <typename LIQT>
__aicore__ inline void LIQMatmul<LIQT>::ComputeMm1(const LIQCommon::RunInfo &runInfo)
{
    if (runInfo.isFirstS2InnerLoop) {
        WaitFlag<HardEvent::MTE1_MTE2>(QW_MTE1_MTE2_EVENT + qwL1Mte2BufIdx_ % DOUBLE_BUF_NUM);
        QueryNd2Nz(runInfo.actMBaseSize, runInfo); // 256 * 128 // L1BasicBlock
        WeightDmaCopy(runInfo.actMBaseSize, runInfo);
    }
    int64_t loopIdx = 0;
    int64_t s2L0LoopCnt = CeilDiv(runInfo.actualSingleProcessSInnerSize, S2_BASIC_BLOCK_L0); // e.g. 2048 in tiles of 128
    int64_t s1L0LoopCnt = CeilDiv(runInfo.actMBaseSize, S1G_BASIC_BLOCK_L0); // e.g. 256 in tiles of 128
    int64_t s1gL1Offset[2] = {0, static_cast<int64_t>(S1G_BASIC_BLOCK_L0)};
    int64_t s1gL0RealSize[2] = {s1L0LoopCnt > 1 ? static_cast<int64_t>(S1G_BASIC_BLOCK_L0) : runInfo.actMBaseSize,
                                runInfo.actMBaseSize - s1gL1Offset[1]};
    MmInfo mmInfo[2];
    CalcMmInfo(mmInfo[loopIdx & 1], loopIdx, s1L0LoopCnt, mmInfo[(loopIdx + 1) & 1], runInfo);

    ProcessQk(s1gL0RealSize[mmInfo[loopIdx & 1].s1gL0LoopId % s1L0LoopCnt],
              s1gL1Offset[mmInfo[loopIdx & 1].s1gL0LoopId % s1L0LoopCnt], s1L0LoopCnt, mmInfo[loopIdx & 1],
              runInfo);

    SetFlag<HardEvent::FIX_MTE1>(FIX_MTE1_EVENT + sL1BufIdx_ % DOUBLE_BUF_NUM);
    sL1BufIdx_++;
    loopIdx++;

    while (loopIdx < s2L0LoopCnt * s1L0LoopCnt) {
        CalcMmInfo(mmInfo[loopIdx & 1], loopIdx, s1L0LoopCnt, mmInfo[(loopIdx + 1) & 1], runInfo);

        ProcessQk(s1gL0RealSize[mmInfo[loopIdx & 1].s1gL0LoopId % s1L0LoopCnt],
                  s1gL1Offset[mmInfo[loopIdx & 1].s1gL0LoopId % s1L0LoopCnt], s1L0LoopCnt, mmInfo[loopIdx & 1],
                  runInfo);

        SetFlag<HardEvent::FIX_MTE1>(FIX_MTE1_EVENT + sL1BufIdx_ % DOUBLE_BUF_NUM);
        sL1BufIdx_++;

        WaitFlag<HardEvent::FIX_MTE1>(FIX_MTE1_EVENT + sL1BufIdx_ % DOUBLE_BUF_NUM);

        ProcessWs(s1gL0RealSize[mmInfo[(loopIdx + 1) & 1].s1gL0LoopId % s1L0LoopCnt],
                  s1gL1Offset[mmInfo[(loopIdx + 1) & 1].s1gL0LoopId % s1L0LoopCnt], sL1BufIdx_,
                  mmInfo[(loopIdx + 1) & 1], runInfo);
        loopIdx++;
    }

    WaitFlag<HardEvent::FIX_MTE1>(FIX_MTE1_EVENT + (sL1BufIdx_ + 1) % DOUBLE_BUF_NUM);

    ProcessWs(s1gL0RealSize[mmInfo[(loopIdx + 1) & 1].s1gL0LoopId % s1L0LoopCnt],
              s1gL1Offset[mmInfo[(loopIdx + 1) & 1].s1gL0LoopId % s1L0LoopCnt], sL1BufIdx_ - 1,
              mmInfo[(loopIdx + 1) & 1], runInfo);

    if (runInfo.isLastS2InnerLoop) {
        SetFlag<HardEvent::MTE1_MTE2>(QW_MTE1_MTE2_EVENT + qwL1Mte2BufIdx_ % DOUBLE_BUF_NUM);
        qwL1Mte2BufIdx_++;
    }
}
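// Paged-attention path: the key tile may straddle several KV-cache blocks, so the loop
// below resolves each physical block id through blkTableGm_, clamps the copy length to
// the remainder of that cache block, and ND->NZ converts each piece into the key L1
// buffer at the running s2L1Offset.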
// blkNum, blkSize, N2, D
template <typename LIQT>
__aicore__ inline void LIQMatmul<LIQT>::KeyNd2NzForPA(uint64_t s2L1RealSize, uint64_t s2GmOffset,
                                                      const LIQCommon::RunInfo &runInfo)
{
    uint64_t s2L1Offset = 0;
    while (s2L1Offset < s2L1RealSize) {
        uint64_t s2BlkId = (s2L1Offset + s2GmOffset) / constInfo_.kCacheBlockSize;
        uint64_t s2BlkOffset = (s2L1Offset + s2GmOffset) % constInfo_.kCacheBlockSize;
        uint64_t keyGmOffset = blkTableGm_.GetValue(runInfo.bIdx * constInfo_.maxBlockNumPerBatch + s2BlkId) *
                                   constInfo_.kCacheBlockSize * constInfo_.kHeadNum * constInfo_.headDim +
                               s2BlkOffset * constInfo_.headDim;
        uint64_t s2Mte2Size = s2L1RealSize - s2L1Offset;
        s2Mte2Size = s2BlkOffset + s2Mte2Size >= constInfo_.kCacheBlockSize ? constInfo_.kCacheBlockSize - s2BlkOffset
                                                                            : s2Mte2Size;
        Nd2NzParams nd2nzPara;
        nd2nzPara.ndNum = 1;
        nd2nzPara.nValue = s2Mte2Size; // number of rows
        nd2nzPara.dValue = constInfo_.headDim;
        nd2nzPara.srcDValue = constInfo_.headDim;
        nd2nzPara.dstNzC0Stride = CeilAlign(s2L1RealSize, (uint64_t)BLOCK_CUBE); // align up to 16, in units of blocks
        nd2nzPara.dstNzNStride = 1;
        nd2nzPara.srcNdMatrixStride = 0;
        nd2nzPara.dstNzMatrixStride = 0;
        DataCopy(keyL1_[(keyL1BufIdx_ % DOUBLE_BUF_NUM) * KEY_BUFFER_OFFSET + s2L1Offset * S8_BLOCK_CUBE],
                 keyGm_[keyGmOffset], nd2nzPara);

        s2L1Offset += s2Mte2Size;
    }
}
template <typename LIQT>
__aicore__ inline void LIQMatmul<LIQT>::KeyNd2Nz(uint64_t s2L1RealSize, const MmInfo &mmInfo,
                                                 const LIQCommon::RunInfo &runInfo)
{
    uint64_t dStride = constInfo_.headDim;
    if constexpr (K_LAYOUT_T == LI_LAYOUT::BSND || K_LAYOUT_T == LI_LAYOUT::TND) {
        dStride = constInfo_.headDim * constInfo_.kHeadNum;
    }
    Nd2NzParams nd2nzPara;
    nd2nzPara.ndNum = 1;
    nd2nzPara.nValue = s2L1RealSize; // number of rows
    nd2nzPara.dValue = constInfo_.headDim;
    nd2nzPara.srcDValue = dStride;
    nd2nzPara.dstNzC0Stride = CeilAlign(s2L1RealSize, (uint64_t)BLOCK_CUBE); // align up to 16, in units of blocks
    nd2nzPara.dstNzNStride = 1;
    nd2nzPara.srcNdMatrixStride = 0;
    nd2nzPara.dstNzMatrixStride = 0;
    // by default one buffer holds at most two copies
    DataCopy(keyL1_[(keyL1BufIdx_ % DOUBLE_BUF_NUM) * KEY_BUFFER_OFFSET],
             keyGm_[runInfo.tensorKeyOffset + mmInfo.s2GmOffset * constInfo_.headDim], nd2nzPara);
}
// batch, s1, g, 1
template <typename LIQT>
__aicore__ inline void LIQMatmul<LIQT>::WeightDmaCopy(uint64_t s1gL1RealSize, const LIQCommon::RunInfo &runInfo)
{
    DataCopyParams copyInParams;
    copyInParams.blockCount = 1;
    copyInParams.blockLen = s1gL1RealSize;
    copyInParams.srcStride = 0;
    copyInParams.dstStride = 0;
    DataCopy(weightL1_[(qwL1Mte2BufIdx_ % DOUBLE_BUF_NUM) * WEIGHT_BUFFER_OFFSET],
             weightGm_[runInfo.loop % DOUBLE_BUF_NUM * BLOCK_CUBE * constInfo_.mBaseSize], copyInParams);
}
// batch, s1, n2, g, d
template <typename LIQT>
__aicore__ inline void LIQMatmul<LIQT>::QueryNd2Nz(uint64_t s1gL1RealSize, const LIQCommon::RunInfo &runInfo)
{
    Nd2NzParams nd2nzPara;
    nd2nzPara.ndNum = 1;
    nd2nzPara.nValue = s1gL1RealSize; // number of rows
    nd2nzPara.dValue = constInfo_.headDim;
    nd2nzPara.srcDValue = constInfo_.headDim;
    nd2nzPara.dstNzC0Stride = CeilAlign(s1gL1RealSize, (uint64_t)BLOCK_CUBE); // align up to 16, in units of blocks
    nd2nzPara.dstNzNStride = 1;
    nd2nzPara.srcNdMatrixStride = 0;
    nd2nzPara.dstNzMatrixStride = 0;
    // by default one buffer holds at most two copies
    DataCopy(queryL1_[(qwL1Mte2BufIdx_ % DOUBLE_BUF_NUM) * QUERY_BUFFER_OFFSET], queryGm_[runInfo.tensorQueryOffset],
             nd2nzPara);
}
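// The L1 -> L0A transfer below uses the load3dv2 (img2col) path with a 1x1 filter, which
// degenerates into a tiled matrix load: l1H/l1W describe the source as a fractal feature
// map, mExtension/kExtension select the destination tile, and padList[3] = 255 pads the
// tail rows. Note that `(1 >> 8) & 255` evaluates to 0, i.e. the high byte of a filter
// size of 1.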
// s1g, d
template <typename LIQT>
__aicore__ inline void LIQMatmul<LIQT>::LoadQueryToL0a(uint64_t s1gL1Offset, uint64_t s1gL1RealSize,
                                                       uint64_t s1gL0RealSize)
{
    LoadData3DParamsV2<Q_T> loadData3DParams;
    // SetFmatrixParams
    loadData3DParams.l1H = CeilDiv(s1gL1RealSize, BLOCK_CUBE); // Hin=M1=8
    loadData3DParams.l1W = BLOCK_CUBE;                         // Win=M0
    loadData3DParams.channelSize = constInfo_.headDim;         // Cin=K

    loadData3DParams.padList[0] = 0;
    loadData3DParams.padList[1] = 0;
    loadData3DParams.padList[2] = 0;
    loadData3DParams.padList[3] = 255; // trailing data does not affect the sliding-window result

    // SetLoadToA0Params
    loadData3DParams.mExtension = CeilAlign(s1gL0RealSize, BLOCK_CUBE); // M: destination extent along height
    loadData3DParams.kExtension = constInfo_.headDim;                   // K: destination extent along width
    loadData3DParams.mStartPt = s1gL1Offset;
    loadData3DParams.kStartPt = 0;
    loadData3DParams.strideW = 1;
    loadData3DParams.strideH = 1;
    loadData3DParams.filterW = 1;
    loadData3DParams.filterSizeW = (1 >> 8) & 255;
    loadData3DParams.filterH = 1;
    loadData3DParams.filterSizeH = (1 >> 8) & 255;
    loadData3DParams.dilationFilterW = 1;
    loadData3DParams.dilationFilterH = 1;
    loadData3DParams.enTranspose = 0;
    loadData3DParams.fMatrixCtrl = 0;

    LoadData<Q_T, LOAD3DV2_CONFIG>(l0a_[(l0BufIdx_ % L0AB_BUF_NUM) * L0AB_BUFFER_OFFSET_S8_16K],
                                   queryL1_[(qwL1Mte2BufIdx_ % DOUBLE_BUF_NUM) * QUERY_BUFFER_OFFSET],
                                   loadData3DParams);
}
// s1, g, s2 --> 2 * 64 * 128
template <typename LIQT>
__aicore__ inline void LIQMatmul<LIQT>::LoadSToL0b(uint64_t s1gL1RealSize, uint64_t s2L0RealSize, uint64_t sL1BufIdx,
                                                   int64_t mStartPt)
{
    LoadData3DParamsV2<half> loadData3DParams;
    // SetFmatrixParams
    loadData3DParams.l1H = S1G_BASIC_BLOCK_L0 / BLOCK_CUBE;             // Hin=M1=8
    loadData3DParams.l1W = BLOCK_CUBE;                                  // Win=M0
    loadData3DParams.channelSize = CeilAlign(s2L0RealSize, BLOCK_CUBE); // Cin=K

    loadData3DParams.padList[0] = 0;
    loadData3DParams.padList[1] = 0;
    loadData3DParams.padList[2] = 0;
    loadData3DParams.padList[3] = 255; // trailing data does not affect the sliding-window result

    // SetLoadToA0Params
    loadData3DParams.mExtension = constInfo_.gSize;                     // M: destination extent along height
    loadData3DParams.kExtension = CeilAlign(s2L0RealSize, BLOCK_CUBE);  // K: destination extent along width
    loadData3DParams.kStartPt = 0;
    loadData3DParams.strideW = 1;
    loadData3DParams.strideH = 1;
    loadData3DParams.filterW = 1;
    loadData3DParams.filterSizeW = (1 >> 8) & 255;
    loadData3DParams.filterH = 1;
    loadData3DParams.filterSizeH = (1 >> 8) & 255;
    loadData3DParams.dilationFilterW = 1;
    loadData3DParams.dilationFilterH = 1;
    loadData3DParams.enTranspose = 1;
    loadData3DParams.fMatrixCtrl = 0;

    loadData3DParams.mStartPt = mStartPt;
    LoadData<half, LOAD3DV2_CONFIG>(
        l0b_.template ReinterpretCast<half>()[(l0BufIdx_ % L0AB_BUF_NUM) * L0AB_BUFFER_OFFSET_FP16_16K],
        sL1_[(sL1BufIdx % DOUBLE_BUF_NUM) * SL1_BUFFER_OFFSET], loadData3DParams);
}
// s1,g,1(16), 2,64,16
template <typename LIQT>
__aicore__ inline void LIQMatmul<LIQT>::LoadWeightToL0a(uint64_t s1gL1Offset)
{
    LoadData2DParams loadData2DParams;
    loadData2DParams.startIndex = 0;
    loadData2DParams.repeatTimes = CeilDiv(constInfo_.gSize, BLOCK_CUBE);
    loadData2DParams.srcStride = 1;
    loadData2DParams.dstGap = 0;
    loadData2DParams.ifTranspose = true;
    LoadData(l0a_.template ReinterpretCast<half>()[(l0BufIdx_ % L0AB_BUF_NUM) * L0AB_BUFFER_OFFSET_FP16_16K],
             weightL1_[(qwL1Mte2BufIdx_ % DOUBLE_BUF_NUM) * WEIGHT_BUFFER_OFFSET + s1gL1Offset * BLOCK_CUBE],
             loadData2DParams);
}
// s2, d -> 128,128
template <typename LIQT>
__aicore__ inline void LIQMatmul<LIQT>::LoadKeyToL0b(uint64_t s2L0RealSize)
{
    LoadData2DParams loadData2DParams;
    loadData2DParams.startIndex = 0;
    loadData2DParams.repeatTimes = CeilDiv(s2L0RealSize, BLOCK_CUBE) * CeilDiv(constInfo_.headDim, S8_BLOCK_CUBE);
    loadData2DParams.srcStride = 1;
    loadData2DParams.dstGap = 0;
    loadData2DParams.ifTranspose = false;
    LoadData(l0b_[(l0BufIdx_ % L0AB_BUF_NUM) * L0AB_BUFFER_OFFSET_S8_16K],
             keyL1_[(keyL1BufIdx_ % DOUBLE_BUF_NUM) * KEY_BUFFER_OFFSET], loadData2DParams);
}
// A: s1,g,1(16) B: s1,g,s2 C: s1, 1(16), s2
template <typename LIQT>
__aicore__ inline void LIQMatmul<LIQT>::ComputeWs(uint64_t s1gL0RealSize, uint64_t s2L0RealSize, int64_t s1gOffset)
{
    SetFlag<HardEvent::MTE1_M>(MTE1_M_EVENT);
    WaitFlag<HardEvent::MTE1_M>(MTE1_M_EVENT);
    MmadParams mmadParams;
    mmadParams.m = BLOCK_CUBE;
    mmadParams.n = s2L0RealSize;
    mmadParams.k = constInfo_.gSize;
    mmadParams.cmatrixInitVal = true;
    mmadParams.cmatrixSource = false;
    Mmad(cL0_.template ReinterpretCast<float>()[(l0cBufIdx_ % DOUBLE_BUF_NUM) * L0C_BUFFER_OFFSET +
                                                s1gOffset * S2_BASIC_BLOCK_L0],
         l0a_.template ReinterpretCast<half>()[(l0BufIdx_ % L0AB_BUF_NUM) * L0AB_BUFFER_OFFSET_FP16_16K],
         l0b_.template ReinterpretCast<half>()[(l0BufIdx_ % L0AB_BUF_NUM) * L0AB_BUFFER_OFFSET_FP16_16K],
         mmadParams);
}
template <typename LIQT>
__aicore__ inline void LIQMatmul<LIQT>::ComputeQk(uint64_t s1gL0RealSize, uint64_t s2L0RealSize)
{
    SetFlag<HardEvent::MTE1_M>(MTE1_M_EVENT);
    WaitFlag<HardEvent::MTE1_M>(MTE1_M_EVENT);

    MmadParams mmadParams;
    mmadParams.m = CeilAlign(s1gL0RealSize, BLOCK_CUBE);
    mmadParams.n = s2L0RealSize;
    mmadParams.k = constInfo_.headDim;
    mmadParams.cmatrixInitVal = true;
    mmadParams.cmatrixSource = false;
    Mmad(cL0_[(l0cBufIdx_ % DOUBLE_BUF_NUM) * L0C_BUFFER_OFFSET],
         l0a_[(l0BufIdx_ % L0AB_BUF_NUM) * L0AB_BUFFER_OFFSET_S8_16K],
         l0b_[(l0BufIdx_ % L0AB_BUF_NUM) * L0AB_BUFFER_OFFSET_S8_16K], mmadParams);
    if ((mmadParams.m / 16) * (mmadParams.n / 16) < 10) {
        PipeBarrier<PIPE_M>();
    }
}
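// FixpSToL1 moves the int32 QK accumulators from L0C into L1 as fp16 through the fixpipe
// DEQF16 path. 0x3a800000 is the IEEE-754 bit pattern of 1.0f / 1024, so the pre-quant
// flag presumably applies a 2^-10 dequant scale during the int32 -> fp16 conversion.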
template <typename LIQT>
__aicore__ inline void LIQMatmul<LIQT>::FixpSToL1(uint64_t s1gL0RealSize, uint64_t s2L0RealSize)
{
    SetFlag<HardEvent::M_FIX>(M_FIX_EVENT);
    WaitFlag<HardEvent::M_FIX>(M_FIX_EVENT);
    DataCopyCO12DstParams params;
    params.mSize = CeilAlign(s1gL0RealSize, BLOCK_CUBE);
    params.nSize = CeilAlign(s2L0RealSize, BLOCK_CUBE);
    params.dstStride = S1G_BASIC_BLOCK_L0;
    params.srcStride = params.mSize;
    params.quantPre = QuantMode_t::DEQF16;
    params.reluPre = 1;
    params.channelSplit = 0;
    params.nz2ndEn = 0;
    SetFixpipePreQuantFlag(0x3a800000);
    DataCopy(sL1_[(sL1BufIdx_ % DOUBLE_BUF_NUM) * SL1_BUFFER_OFFSET],
             cL0_[(l0cBufIdx_ % DOUBLE_BUF_NUM) * L0C_BUFFER_OFFSET], params);
}
template <typename LIQT>
__aicore__ inline void LIQMatmul<LIQT>::FixpResToGm(uint64_t s1L0RealCount, uint64_t s2L0RealSize, uint64_t s1GmOffset,
                                                    uint64_t s2GmOffset, const LIQCommon::RunInfo &runInfo)
{
    SetFlag<HardEvent::M_FIX>(M_FIX_EVENT);
    WaitFlag<HardEvent::M_FIX>(M_FIX_EVENT);

    AscendC::DataCopyCO12DstParams intriParams;
    intriParams.mSize = 1;
    intriParams.nSize = s2L0RealSize;
    intriParams.dstStride = constInfo_.s2BaseSize;
    intriParams.srcStride = 16;
    // set mode according to dtype
    intriParams.quantPre = QuantMode_t::NoQuant;
    intriParams.nz2ndEn = true;
    intriParams.reluPre = 0;
    AscendC::SetFixpipeNz2ndFlag(s1L0RealCount, CeilDiv(constInfo_.gSize, BLOCK_CUBE) * S2_BASIC_BLOCK_L0 / BLOCK_CUBE,
                                 2048);
    AscendC::DataCopy(mm1ResGm_[(runInfo.loop % 2) * constInfo_.mBaseSize / constInfo_.gSize * constInfo_.s2BaseSize +
                                s1GmOffset * intriParams.dstStride + s2GmOffset],
                      cL0_.template ReinterpretCast<float>()[(l0cBufIdx_ % DOUBLE_BUF_NUM) * L0C_BUFFER_OFFSET],
                      intriParams);
}
template <typename LIQT>
__aicore__ inline void LIQMatmul<LIQT>::AllocEventID()
{
    SetFlag<HardEvent::MTE1_MTE2>(KEY_MTE1_MTE2_EVENT + 0);
    SetFlag<HardEvent::MTE1_MTE2>(KEY_MTE1_MTE2_EVENT + 1);
    SetFlag<HardEvent::MTE1_MTE2>(KEY_MTE1_MTE2_EVENT + 2);

    SetFlag<HardEvent::MTE1_MTE2>(QW_MTE1_MTE2_EVENT + 0);
    SetFlag<HardEvent::MTE1_MTE2>(QW_MTE1_MTE2_EVENT + 1);

    SetFlag<HardEvent::M_MTE1>(M_MTE1_EVENT + 0);
    SetFlag<HardEvent::M_MTE1>(M_MTE1_EVENT + 1);
    SetFlag<HardEvent::M_MTE1>(M_MTE1_EVENT + 2);
    SetFlag<HardEvent::M_MTE1>(M_MTE1_EVENT + 3);

    SetFlag<HardEvent::FIX_M>(FIX_M_EVENT + 0);
    SetFlag<HardEvent::FIX_M>(FIX_M_EVENT + 1);
}

template <typename LIQT>
__aicore__ inline void LIQMatmul<LIQT>::FreeEventID()
{
    WaitFlag<HardEvent::MTE1_MTE2>(KEY_MTE1_MTE2_EVENT + 0);
    WaitFlag<HardEvent::MTE1_MTE2>(KEY_MTE1_MTE2_EVENT + 1);
    WaitFlag<HardEvent::MTE1_MTE2>(KEY_MTE1_MTE2_EVENT + 2);

    WaitFlag<HardEvent::MTE1_MTE2>(QW_MTE1_MTE2_EVENT + 0);
    WaitFlag<HardEvent::MTE1_MTE2>(QW_MTE1_MTE2_EVENT + 1);

    WaitFlag<HardEvent::M_MTE1>(M_MTE1_EVENT + 0);
    WaitFlag<HardEvent::M_MTE1>(M_MTE1_EVENT + 1);
    WaitFlag<HardEvent::M_MTE1>(M_MTE1_EVENT + 2);
    WaitFlag<HardEvent::M_MTE1>(M_MTE1_EVENT + 3);

    WaitFlag<HardEvent::FIX_M>(FIX_M_EVENT + 0);
    WaitFlag<HardEvent::FIX_M>(FIX_M_EVENT + 1);
}
} // namespace LIQKernel

#endif

@@ -0,0 +1,665 @@
/**
 * This program is free software, you can redistribute it and/or modify it.
 * Copyright (c) 2025 Huawei Technologies Co., Ltd.
 * This file is a part of the CANN Open Software.
 * Licensed under CANN Open Software License Agreement Version 2.0 (the "License").
 * Please refer to the License for details. You may not use this file except in compliance with the License.
 * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE.
 * See LICENSE in the root of the software repository for the full text of the License.
 */

/*!
 * \file lightning_indexer_quant_service_vector.h
 * \brief
 */
#ifndef LIGHTNING_INDEXER_QUANT_SERVICE_VECTOR_H
#define LIGHTNING_INDEXER_QUANT_SERVICE_VECTOR_H

#include "kernel_operator.h"
#include "kernel_operator_list_tensor_intf.h"
#include "kernel_tiling/kernel_tiling.h"
#include "lib/matmul_intf.h"
#include "lib/matrix/matmul/tiling.h"
#include "lightning_indexer_quant_common.h"
#include "lightning_indexer_quant_vector.h"

namespace LIQKernel {
using namespace LIQCommon;
using namespace LIQServiceVec;
constexpr uint32_t BASE_TOPK = 2048;
constexpr uint32_t BASE_TOPK_VALUE_IDX_SIZE = 4096;
constexpr uint32_t LD_PARAM_NUM = 16;
template <typename LIQT>
class LIQVector {
public:
    // ================================= Type definitions =================================
    static constexpr LI_LAYOUT Q_LAYOUT_T = LIQT::layout;
    static constexpr LI_LAYOUT K_LAYOUT_T = LIQT::keyLayout;
    static constexpr bool PAGE_ATTENTION = LIQT::pageAttention;
    // Matmul output data type; currently only float is supported
    using MM1_OUT_T = float;

    __aicore__ inline LIQVector() {}
    __aicore__ inline void ProcessVec0(const LIQCommon::RunInfo &info);
    __aicore__ inline void ProcessVec1(const LIQCommon::RunInfo &info);
    __aicore__ inline void ProcessLD();
    __aicore__ inline void InitBuffers(TPipe *pipe);
    __aicore__ inline void InitParams(const struct LIQCommon::ConstInfo &constInfo,
                                      const LIQTilingData *__restrict tilingData);
    __aicore__ inline void InitVecWorkspaceTensor(GlobalTensor<half> vec0OutGm, GlobalTensor<MM1_OUT_T> mm1ResGm,
                                                  GlobalTensor<float> vec1ResGm, GlobalTensor<int64_t> vec1ParamGm);
    __aicore__ inline void InitVecInputTensor(GlobalTensor<half> weightsGm, GlobalTensor<half> qScaleGm,
                                              GlobalTensor<half> kScaleGm, GlobalTensor<int32_t> indiceOutGm,
                                              GlobalTensor<int32_t> blockTableGm);
    __aicore__ inline void CleanInvalidOutput(int64_t invalidS1offset);
    __aicore__ inline void AllocEventID();
    __aicore__ inline void FreeEventID();
    __aicore__ inline void InitLDBuffers(TPipe *pipe);

protected:
    GlobalTensor<MM1_OUT_T> mm1ResGm;
    GlobalTensor<float> vec1ResGm;
    GlobalTensor<int64_t> vec1ParamGm;
    GlobalTensor<half> weightsGm;
    GlobalTensor<half> qScaleGm;
    GlobalTensor<half> kScaleGm;
    GlobalTensor<half> vec0OutGm;
    GlobalTensor<int32_t> indiceOutGm;
    GlobalTensor<int32_t> blockTableGm;
    // ================================= Constants =================================

private:
    __aicore__ inline void GetKeyScale(const LIQCommon::RunInfo &runInfo, const LocalTensor<half> &resUb,
                                       int64_t batchId, int64_t startS2, int64_t getLen);
    // ================================ Local buffers ====================================
    // queues
    TQue<QuePosition::VECIN, 1> inQueue_;
    TQue<QuePosition::VECOUT, 1> outQueue_;

    // temporary buffers for the vector stage
    TBuf<TPosition::VECCALC> sortOutBuf_;
    TBuf<TPosition::VECCALC> indexBuf_;
    TBuf<TPosition::VECCALC> paramBuf_;
    TBuf<TPosition::VECCALC> tmpBuf_;

    // temporary buffers for LD
    TBuf<> ldToBeMrgBuf_;
    TBuf<> ldTmpBuf_;
    TBuf<> ldOutValueBuf_;
    TBuf<> ldOutIdxBuf_;

    LocalTensor<int32_t> globalTopkIndice_;
    LocalTensor<float> globalTopkUb_;

    int32_t blockId_ = -1;
    // parameters for the vector stage
    int32_t groupInner_ = 0;
    int32_t globalTopkNum_ = 0;
    int64_t blockS2StartIdx_ = 0;
    int32_t gSize_ = 0;
    int32_t kSeqSize_ = 0;
    int32_t kHeadNum_ = 0;
    int32_t qHeadNum_ = 0;
    int32_t s1BaseSize_ = 0;
    int32_t s2BaseSize_ = 0;
    int32_t kCacheBlockSize_ = 0;
    int32_t maxBlockNumPerBatch_ = 0;

    // parameters for LD
    uint32_t mrgListNum_ = 4;
    uint32_t paramNum_ = 16;

    struct LIQCommon::ConstInfo constInfo_;
};
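// GetKeyScale gathers the per-token key dequant scales for one S2 range. With paged
// attention the range may straddle KV-cache blocks: a leading partial block is copied
// first, then whole blocks are walked through the block table; without paging a single
// contiguous DataCopyPad suffices.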
template <typename LIQT>
__aicore__ inline void LIQVector<LIQT>::GetKeyScale(const LIQCommon::RunInfo &runInfo, const LocalTensor<half> &resUb,
                                                    int64_t batchId, int64_t startS2, int64_t getLen)
{
    // startS2 is always evenly divisible by kCacheBlockSize_
    AscendC::DataCopyPadExtParams<half> padParams{false, 0, 0, 0};
    AscendC::DataCopyExtParams copyInParams;
    if constexpr (PAGE_ATTENTION) {
        int32_t startBlockTableIdx = startS2 / kCacheBlockSize_;
        int32_t startBlockTableOffset = startS2 % kCacheBlockSize_;
        int32_t blockTableBatchOffset = batchId * maxBlockNumPerBatch_;
        copyInParams.blockCount = 1;
        copyInParams.srcStride = 0;
        copyInParams.dstStride = 0;
        copyInParams.rsv = 0;
        int32_t resUbBaseOffset = 0;
        if (startBlockTableOffset > 0) {
            int32_t firstPartLen =
                kCacheBlockSize_ - startBlockTableOffset > getLen ? getLen : kCacheBlockSize_ - startBlockTableOffset;
            copyInParams.blockLen = firstPartLen * sizeof(half);
            int32_t blockId = blockTableGm.GetValue(blockTableBatchOffset + startBlockTableIdx);
            SetWaitFlag<HardEvent::S_MTE2>(HardEvent::S_MTE2);
            AscendC::DataCopyPad(resUb, kScaleGm[blockId * kCacheBlockSize_ + startBlockTableOffset],
                                 copyInParams, padParams);
            startBlockTableIdx++;
            getLen = getLen - firstPartLen;
            resUbBaseOffset = firstPartLen;
        }
        int32_t getLoopNum = CeilDiv(getLen, kCacheBlockSize_);
        copyInParams.blockLen = kCacheBlockSize_ * sizeof(half);
        for (int32_t i = 0; i < getLoopNum; i++) {
            if (i == getLoopNum - 1) {
                copyInParams.blockLen = (getLen - i * kCacheBlockSize_) * sizeof(half);
            }
            int32_t blockId = blockTableGm.GetValue(blockTableBatchOffset + startBlockTableIdx + i);
            SetWaitFlag<HardEvent::S_MTE2>(HardEvent::S_MTE2);
            AscendC::DataCopyPad(resUb[resUbBaseOffset + i * kCacheBlockSize_], kScaleGm[blockId * kCacheBlockSize_],
                                 copyInParams, padParams);
        }
    } else {
        copyInParams.blockCount = 1;
        copyInParams.blockLen = getLen * sizeof(half);
        copyInParams.srcStride = 0;
        copyInParams.dstStride = 0;
        copyInParams.rsv = 0;
        AscendC::DataCopyPad(resUb, kScaleGm[runInfo.tensorKeyScaleOffset], copyInParams, padParams);
    }
}
template <typename LIQT>
__aicore__ inline void LIQVector<LIQT>::InitBuffers(TPipe *pipe)
{
    pipe->InitBuffer(paramBuf_, LD_PARAM_NUM * sizeof(int64_t));                                       // 128 B
    pipe->InitBuffer(inQueue_, 2, s2BaseSize_ * sizeof(float) * 2);                                    // 32 KB
    pipe->InitBuffer(outQueue_, 1, BASE_TOPK * sizeof(float));                                         // 8 KB
    pipe->InitBuffer(indexBuf_, s2BaseSize_ * sizeof(int32_t));                                        // 8 KB
    pipe->InitBuffer(tmpBuf_, 64 * 1024);                                                              // 64 KB
    pipe->InitBuffer(sortOutBuf_, CeilDiv(s1BaseSize_, 2) * BASE_TOPK_VALUE_IDX_SIZE * sizeof(float)); // 32 KB

    globalTopkIndice_ = indexBuf_.Get<int32_t>();
    globalTopkUb_ = sortOutBuf_.Get<float>();
    globalTopkNum_ = 0;

    // Initialize UB and GM before the basic blocks execute.
    // step 1. initialize an ascending index sequence 0 .. s2BaseSize_
    ArithProgression<int32_t>(globalTopkIndice_, 0, 1, s2BaseSize_);
    // step 2. globalTopkUb_ [CeilDiv(s1BaseSize_, 2), BASE_TOPK, 2] = -inf, -1
    InitSortOutBuf(globalTopkUb_, CeilDiv(s1BaseSize_, 2) * BASE_TOPK_VALUE_IDX_SIZE);

    // step 3. initialize vec1ParamGm: set the "needs LD" flag to -1 (needFd = -1)
    // vec1ResIn32Gm = [aic, 2, s1BaseSize_, 16] int32
    // clear the workspace: [needFd, s2AcSeq, s2Start, s2End, isS2End, bn2idx, s1Idx, ......]
    LocalTensor<float> tmpfBuff = outQueue_.AllocTensor<float>();
    Duplicate(tmpfBuff.template ReinterpretCast<int32_t>(), -1, 2 * (s1BaseSize_ / 2) * paramNum_ * 2);
    SetWaitFlag<HardEvent::V_MTE3>(HardEvent::V_MTE3);
    int64_t wsInfoOffset = (blockId_ / 2) * s1BaseSize_ * 2 * paramNum_ +      // offset shared by the two AIVs
                           (blockId_ % 2) * (s1BaseSize_ / 2) * 2 * paramNum_; // per-AIV offset along S1
    DataCopyPad(vec1ParamGm[wsInfoOffset], tmpfBuff.template ReinterpretCast<int64_t>(),
                {1, static_cast<uint16_t>((s1BaseSize_ / 2) * 2 * paramNum_ * sizeof(int64_t)), 0, 0});
    SetWaitFlag<HardEvent::MTE3_V>(HardEvent::MTE3_V);
    outQueue_.FreeTensor(tmpfBuff);
}
template <typename LIQT>
__aicore__ inline void LIQVector<LIQT>::InitLDBuffers(TPipe *pipe)
{
    pipe->Reset();
    pipe->InitBuffer(ldToBeMrgBuf_, BASE_TOPK_VALUE_IDX_SIZE * mrgListNum_ * sizeof(float));
    pipe->InitBuffer(ldTmpBuf_, BASE_TOPK_VALUE_IDX_SIZE * mrgListNum_ * sizeof(float));
    pipe->InitBuffer(ldOutValueBuf_, BASE_TOPK * sizeof(float));
    pipe->InitBuffer(ldOutIdxBuf_, BASE_TOPK * sizeof(int32_t));
}
template <typename LIQT>
__aicore__ inline void LIQVector<LIQT>::InitParams(const struct LIQCommon::ConstInfo &constInfo,
                                                   const LIQTilingData *__restrict tilingData)
{
    this->constInfo_ = constInfo;
    blockS2StartIdx_ = 0;
    gSize_ = constInfo.gSize;
    kSeqSize_ = constInfo.kSeqSize;
    // N2 parameters
    kHeadNum_ = constInfo.kHeadNum;
    qHeadNum_ = constInfo.qHeadNum;
    // matmul base-block parameters
    s1BaseSize_ = constInfo.s1BaseSize; // 4
    s2BaseSize_ = constInfo.s2BaseSize; // 2048
    kCacheBlockSize_ = constInfo.kCacheBlockSize;
    maxBlockNumPerBatch_ = constInfo.maxBlockNumPerBatch;
    blockId_ = GetBlockIdx();
}
template <typename LIQT>
__aicore__ inline void LIQVector<LIQT>::InitVecInputTensor(GlobalTensor<half> weightsGm, GlobalTensor<half> qScaleGm,
                                                           GlobalTensor<half> kScaleGm,
                                                           GlobalTensor<int32_t> indiceOutGm,
                                                           GlobalTensor<int32_t> blockTableGm)
{
    this->weightsGm = weightsGm;
    this->qScaleGm = qScaleGm;
    this->kScaleGm = kScaleGm;
    this->indiceOutGm = indiceOutGm;
    this->blockTableGm = blockTableGm;
}

template <typename LIQT>
__aicore__ inline void LIQVector<LIQT>::InitVecWorkspaceTensor(GlobalTensor<half> vec0OutGm,
                                                               GlobalTensor<MM1_OUT_T> mm1ResGm,
                                                               GlobalTensor<float> vec1ResGm,
                                                               GlobalTensor<int64_t> vec1ParamGm)
{
    this->mm1ResGm = mm1ResGm;
    this->vec1ResGm = vec1ResGm;
    this->vec0OutGm = vec0OutGm;
    this->vec1ParamGm = vec1ParamGm;
}
template <typename LIQT>
__aicore__ inline void LIQVector<LIQT>::AllocEventID()
{
}

template <typename LIQT>
__aicore__ inline void LIQVector<LIQT>::FreeEventID()
{
}

template <typename LIQT>
__aicore__ inline void LIQVector<LIQT>::CleanInvalidOutput(int64_t invalidS1offset)
{
    // fill with -1 and copy to the output
    LocalTensor<float> valueULocal = outQueue_.AllocTensor<float>();
    LocalTensor<int32_t> idxULocal1 = valueULocal.template ReinterpretCast<int32_t>();
    Duplicate(idxULocal1, constInfo_.INVALID_IDX, constInfo_.sparseCount);
    outQueue_.EnQue<float>(valueULocal);
    valueULocal = outQueue_.DeQue<float>();
    LIQServiceVec::CopyOut(indiceOutGm[invalidS1offset], idxULocal1, constInfo_.sparseCount);
    outQueue_.FreeTensor(valueULocal);
}
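// ProcessVec0 (even-numbered AIVs only) fuses the per-head weights with the query quant
// scale (elementwise Mul) and then Brcb-broadcasts each fused scalar into a full
// 16-element block, producing the [s1 * g, BLOCK_CUBE] fp16 operand that the cube side
// consumes in ProcessWs.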
template <typename LIQT>
__aicore__ inline void LIQVector<LIQT>::ProcessVec0(const LIQCommon::RunInfo &info)
{
    // only one of the two vector cores needs to do this
    if (blockId_ % 2 != 0) {
        return;
    }
    int32_t cuBaseS1Idx = info.gS1Idx * s1BaseSize_;
    // Output w base offset: even loop -> 0 + aic_offset, odd loop -> 4*64 + aic_offset
    int64_t vec0OutGmOffset = (info.loop % 2) * ((s1BaseSize_ * gSize_ * BLOCK_CUBE));
    // Input weight offset; qScale uses the same offset as weight
    int64_t weightGmOffset = info.tensorWeightsOffset + cuBaseS1Idx * qHeadNum_;
    // Number of S1 rows to compute now, handling the tail-block case
    int32_t cuS1ProcNum = cuBaseS1Idx + s1BaseSize_ > info.actS1Size ? info.actS1Size % s1BaseSize_ : s1BaseSize_;
    int32_t cuProcEleNum = cuS1ProcNum * gSize_;

    LocalTensor<half> inWeightsUb = inQueue_.AllocTensor<half>();
    LocalTensor<half> inQScaleUb = inWeightsUb[cuProcEleNum];
    AscendC::DataCopyPadExtParams<half> padParams{false, 0, 0, 0};
    AscendC::DataCopyExtParams copyInParams;
    copyInParams.blockCount = 1;
    copyInParams.blockLen = cuProcEleNum * sizeof(half);
    copyInParams.srcStride = 0;
    copyInParams.dstStride = 0;
    copyInParams.rsv = 0;
    AscendC::DataCopyPad(inWeightsUb, weightsGm[weightGmOffset], copyInParams, padParams);
    AscendC::DataCopyPad(inQScaleUb, qScaleGm[weightGmOffset], copyInParams, padParams);

    inQueue_.EnQue<half>(inWeightsUb);
    inWeightsUb = inQueue_.DeQue<half>();
    AscendC::Mul(inWeightsUb, inWeightsUb, inQScaleUb, cuProcEleNum);
    PipeBarrier<PIPE_V>();
    LocalTensor<half> resUb = outQueue_.AllocTensor<half>();
    AscendC::Brcb(resUb, inWeightsUb, static_cast<uint8_t>(cuProcEleNum / 8), {1, 8});
    inQueue_.FreeTensor(inWeightsUb);

    outQueue_.EnQue<half>(resUb);
    resUb = outQueue_.DeQue<half>();
    AscendC::DataCopyParams copyOutParams;
    copyOutParams.blockCount = 1;
    copyOutParams.blockLen = cuProcEleNum * BLOCK_CUBE * sizeof(half);
    copyOutParams.srcStride = 0;
    copyOutParams.dstStride = 0;
    AscendC::DataCopyPad(vec0OutGm[vec0OutGmOffset], resUb, copyOutParams);
    outQueue_.FreeTensor(resUb);
}
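// ProcessVec1 post-processes one mm1 tile per S1 row: it rescales the float scores by
// the key dequant scales, packs (score, index) pairs, sorts the tile (SortAll) and
// merges it into the running per-row top-2048 (MergeSort). If this core saw the whole
// S2 range, the indices are extracted and written to indiceOutGm directly; otherwise
// the partial top-k list plus a 16-entry parameter record is staged to vec1ResGm /
// vec1ParamGm for the cross-core LD reduction.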
template <typename LIQT>
__aicore__ inline void LIQVector<LIQT>::ProcessVec1(const LIQCommon::RunInfo &info)
{
    int32_t cuBaseS1Idx = info.gS1Idx * s1BaseSize_;
    int32_t cuBaseS2Idx = info.s2Idx * s2BaseSize_;

    // Basic-block base offset: even loop -> 0 + aic_offset, odd loop -> 4*2048 + aic_offset
    int64_t mmGmOffset = (info.loop % 2) * (s1BaseSize_ * s2BaseSize_);

    // cuS1BeginIdxPerAiv: S1 start offset of each AIV
    int32_t cuS1BeginIdxPerAiv = cuBaseS1Idx;
    int32_t cuS1ProcNum =
        cuS1BeginIdxPerAiv + s1BaseSize_ > info.actS1Size ? info.actS1Size % s1BaseSize_ : s1BaseSize_;
    // cuS1ProcNumPerAiv: S1 workload of each AIV
    int32_t cuS1ProcNumPerAiv = blockId_ % 2 == 0 ? CeilDiv(cuS1ProcNum, 2) : (cuS1ProcNum / 2);
    cuS1BeginIdxPerAiv += (blockId_ % 2) * CeilDiv(cuS1ProcNum, 2);
    // odd cores add one S1 offset to the basic-block base offset
    mmGmOffset += (blockId_ % 2) * CeilDiv(cuS1ProcNum, 2) * s2BaseSize_;
    // not the first basic block: re-initialize when the M (S1) axis switches
    if (info.loop != 0 && info.s2Idx == 0) {
        // globalTopkUb_ value,index = -inf, -1
        InitSortOutBuf(globalTopkUb_, CeilDiv(s1BaseSize_, 2) * BASE_TOPK_VALUE_IDX_SIZE);
        blockS2StartIdx_ = 0;
    } else if (info.loop == 0) {
        blockS2StartIdx_ = info.s2Idx;
    }
    // cuRealAcSeq: actual sequence length corresponding to this basic block's S1
    int32_t cuRealAcSeq = info.actS2Size;
    if (constInfo_.attenMaskFlag) {
        // attenMask == true case
        cuRealAcSeq = info.actS2Size - (info.actS1Size - cuS1BeginIdxPerAiv);
    }

    // LD output offset along S1, keeping the outputs of the two vector cores contiguous
    uint32_t ldS1Offset = (blockId_ % 2 == 0) ? s1BaseSize_ / 2 - cuS1ProcNumPerAiv : 0;
    for (int innerS1Idx = 0; innerS1Idx < cuS1ProcNumPerAiv; innerS1Idx++) {
        if (constInfo_.attenMaskFlag) {
            cuRealAcSeq += 1;
        }
        int32_t cuS2Len = cuBaseS2Idx + s2BaseSize_ >= cuRealAcSeq ? cuRealAcSeq - cuBaseS2Idx : s2BaseSize_;
        int32_t cuS1Idx = cuS1BeginIdxPerAiv + innerS1Idx;
        if (cuRealAcSeq > 0 && cuS2Len > 0) {
            int32_t cuS2LenVecAlign = CeilDiv(cuS2Len, s2BaseSize_) * s2BaseSize_;
            LocalTensor<float> mmInUb = inQueue_.AllocTensor<float>();
            LocalTensor<float> kScaleUb = mmInUb[cuS2LenVecAlign];
            LocalTensor<half> kScaleTUb = kScaleUb.template ReinterpretCast<half>()[cuS2LenVecAlign];
            AscendC::DataCopyPadExtParams<float> padParams{false, 0, 0, 0};
            AscendC::DataCopyPadExtParams<half> padTParams{false, 0, 0, 0};
            AscendC::DataCopyExtParams copyInParams;
            copyInParams.blockCount = 1;
            copyInParams.blockLen = cuS2Len * sizeof(float);
            copyInParams.srcStride = 0;
            copyInParams.dstStride = 0;
            copyInParams.rsv = 0;
            AscendC::DataCopyPad(mmInUb, mm1ResGm[mmGmOffset + innerS1Idx * s2BaseSize_], copyInParams, padParams);
            GetKeyScale(info, kScaleTUb, info.bIdx, cuBaseS2Idx, cuS2Len);
            inQueue_.EnQue<float>(mmInUb);
            mmInUb = inQueue_.DeQue<float>();
            AscendC::Cast(kScaleUb, kScaleTUb, RoundMode::CAST_NONE, cuS2Len);
            PipeBarrier<PIPE_V>();
            AscendC::Mul(mmInUb, mmInUb, kScaleUb, cuS2Len);
            PipeBarrier<PIPE_V>();
            LocalTensor<float> sortBuff = tmpBuf_.Get<float>();
            LocalTensor<float> sortScoreUb = sortBuff;
            LocalTensor<float> sortIndiceUb = sortBuff[cuS2LenVecAlign];
            PipeBarrier<PIPE_V>();
            Duplicate(sortScoreUb.template ReinterpretCast<int32_t>(), LIQServiceVec::NEG_INF, cuS2LenVecAlign);
            PipeBarrier<PIPE_V>();
            Adds(sortScoreUb, mmInUb, 0.0f, cuS2Len);
            PipeBarrier<PIPE_V>();
            inQueue_.FreeTensor(mmInUb);
            LocalTensor<int32_t> sortIndiceUbInt = sortIndiceUb.template ReinterpretCast<int32_t>();
            // fill the indices of invalid data with -1
            if (cuS2LenVecAlign != cuS2Len) {
                Duplicate(sortIndiceUbInt, -1, cuS2LenVecAlign);
                PipeBarrier<PIPE_V>();
            }
            Adds(sortIndiceUbInt, globalTopkIndice_, static_cast<int32_t>(cuBaseS2Idx), cuS2Len);
            PipeBarrier<PIPE_V>();
            LocalTensor<float> tmpSortBuf = sortBuff[2 * cuS2LenVecAlign];
            LIQServiceVec::SortAll(sortBuff, tmpSortBuf, cuS2LenVecAlign);
            PipeBarrier<PIPE_V>();
            LIQServiceVec::MergeSort(globalTopkUb_[innerS1Idx * BASE_TOPK_VALUE_IDX_SIZE], BASE_TOPK, sortBuff,
                                     cuS2LenVecAlign, tmpSortBuf);
            PipeBarrier<PIPE_V>();
            bool isS2End = cuBaseS2Idx + s2BaseSize_ >= cuRealAcSeq;
            bool needCopyOutGm = blockS2StartIdx_ == 0 && isS2End;
            // save intermediate results
            bool needCopyWsGm = info.isAllLoopEnd || isS2End;
            if (needCopyOutGm) {
                LocalTensor<uint32_t> idxULocal = outQueue_.AllocTensor<uint32_t>();
                ExtractIndex(idxULocal,
                             globalTopkUb_[innerS1Idx * BASE_TOPK_VALUE_IDX_SIZE].template ReinterpretCast<uint32_t>(),
                             BASE_TOPK);
                PipeBarrier<PIPE_V>();
                InitSortOutBuf(globalTopkUb_[innerS1Idx * BASE_TOPK_VALUE_IDX_SIZE], BASE_TOPK_VALUE_IDX_SIZE);
                outQueue_.EnQue<uint32_t>(idxULocal);
                idxULocal = outQueue_.DeQue<uint32_t>();
                LIQServiceVec::CopyOut(indiceOutGm[info.indiceOutOffset + cuS1Idx * constInfo_.sparseCount],
                                       idxULocal.template ReinterpretCast<int32_t>(), constInfo_.sparseCount);
                outQueue_.FreeTensor(idxULocal);
            } else if (needCopyWsGm) {
                // vec1Res Gm = [aic, s1BaseSize_, 2, 2, topkOut_] float32
                // vec1Param Gm = [aic, s1BaseSize_, 2, 16] int64
                // 16 = [needFd, s2AcSeq, s2Start, s2End, isS2End, bn2idx, s1Idx, S1ProcNum, ......]

                int64_t wsOffset =
                    (blockId_ / 2) * s1BaseSize_ * 2 * BASE_TOPK_VALUE_IDX_SIZE +           // shared by the two AIVs
                    (blockId_ % 2) * (s1BaseSize_ / 2) * 2 * BASE_TOPK_VALUE_IDX_SIZE +     // per-AIV offset along S1
                    (ldS1Offset + innerS1Idx) * 2 * BASE_TOPK_VALUE_IDX_SIZE;
                int64_t wsInfoOffset =
                    (blockId_ / 2) * s1BaseSize_ * 2 * paramNum_ +                          // shared by the two AIVs
                    (blockId_ % 2) * (s1BaseSize_ / 2) * 2 * paramNum_ +                    // per-AIV offset along S1
                    (ldS1Offset + innerS1Idx) * 2 * paramNum_;

                LocalTensor<int64_t> tmpiBuff = paramBuf_.Get<int64_t>();
                SetWaitFlag<HardEvent::MTE3_S>(HardEvent::MTE3_S);
                tmpiBuff.SetValue(0, static_cast<int64_t>(1));
                tmpiBuff.SetValue(1, static_cast<int64_t>(cuRealAcSeq));
                tmpiBuff.SetValue(2, static_cast<int64_t>(blockS2StartIdx_));
                tmpiBuff.SetValue(3, static_cast<int64_t>(cuBaseS2Idx + cuS2Len));
                tmpiBuff.SetValue(4, static_cast<int64_t>(isS2End));
                tmpiBuff.SetValue(5, static_cast<int64_t>(info.bN2Idx));
                tmpiBuff.SetValue(6, static_cast<int64_t>(cuS1Idx));
                tmpiBuff.SetValue(7, static_cast<int64_t>(cuS1ProcNum));
                tmpiBuff.SetValue(8, static_cast<int64_t>(info.indiceOutOffset + cuS1Idx * constInfo_.sparseCount));
                // Record the head/tail reduction flags
                // [head, tail]
                // head: reduce with the preceding block (or with both sides)
                // tail: reduce with the following block
                bool isTailReduce = blockS2StartIdx_ == 0; // must be the last tile
                // Workspace offset rule:
                //   blockS2StartIdx_ != 0: reduce with the preceding block -> write to slot 0, no extra compute
                //   blockS2StartIdx_ == 0 and !isS2End: reduce with the following block -> write to slot 1,
                //   i.e. + s1BaseSize_, BASE_TOPK*2
                if (isTailReduce) { // data that does not end S2 is reduced with later blocks: second ws slot
                    wsInfoOffset += paramNum_;
                    wsOffset += BASE_TOPK_VALUE_IDX_SIZE;
                }
                SetWaitFlag<HardEvent::S_MTE3>(HardEvent::S_MTE3);
                LIQServiceVec::CopyOut(vec1ParamGm[wsInfoOffset], tmpiBuff, 16);
                SetWaitFlag<HardEvent::V_MTE3>(HardEvent::V_MTE3);
                LIQServiceVec::CopyOut(vec1ResGm[wsOffset], globalTopkUb_[innerS1Idx * BASE_TOPK_VALUE_IDX_SIZE],
                                       BASE_TOPK_VALUE_IDX_SIZE);
                SetWaitFlag<HardEvent::MTE3_V>(HardEvent::MTE3_V);
            }
        } else if (cuRealAcSeq <= 0) {
            CleanInvalidOutput(info.indiceOutOffset + cuS1Idx * constInfo_.sparseCount);
        }
    }

    // BSND case: output -1 for invalid S1 rows
    if (Q_LAYOUT_T == LI_LAYOUT::BSND) {
        // last S1 basic block: needs >= info.actS1Size
        bool isS1LoopEnd = (cuBaseS1Idx + s1BaseSize_) >= info.actS1Size;
        int32_t invalidS1Num = constInfo_.qSeqSize - info.actS1Size;
        // blockS2StartIdx_ == 0 makes the core that starts S2 do the redundant cleanup
        if (invalidS1Num > 0 && isS1LoopEnd && blockS2StartIdx_ == 0) {
            int32_t s1NumPerAiv = blockId_ % 2 == 0 ? CeilDiv(invalidS1Num, 2) : (invalidS1Num / 2);
            int32_t s1OffsetPerAiv = info.actS1Size + (blockId_ % 2) * CeilDiv(invalidS1Num, 2);
            for (int innerS1Idx = 0; innerS1Idx < s1NumPerAiv; innerS1Idx++) {
                CleanInvalidOutput(info.indiceOutOffset + (s1OffsetPerAiv + innerS1Idx) * constInfo_.sparseCount);
            }
        }

        int32_t invalidS1Num2 = info.actS1Size - info.actS2Size;
        if (invalidS1Num2 > 0 && isS1LoopEnd && blockS2StartIdx_ == 0 && constInfo_.attenMaskFlag) {
            int32_t s1NumPerAiv = blockId_ % 2 == 0 ? CeilDiv(invalidS1Num2, 2) : (invalidS1Num2 / 2);
            int32_t s1OffsetPerAiv = (blockId_ % 2) * CeilDiv(invalidS1Num2, 2);
            for (int innerS1Idx = 0; innerS1Idx < s1NumPerAiv; innerS1Idx++) {
                CleanInvalidOutput((info.bN2Idx * constInfo_.qSeqSize + s1OffsetPerAiv + innerS1Idx) *
                                   constInfo_.sparseCount);
            }
        }
    }

    if (info.isLastS2InnerLoop) {
        // after the last S2 loop, the next basic block starts from S2 index 0
        blockS2StartIdx_ = 0;
    }
}
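// ProcessLD performs the cross-core top-k reduction: each AIV scans its cube core's
// parameter records for rows flagged needFd == 1, starts from its own tail list, and
// walks the head lists staged by the following cube cores, merging up to four
// 2048-entry sorted lists at a time with MrgSort until isS2End, then extracts the final
// sparseCount indices and writes them to indiceOutGm.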
template <typename LIQT>
__aicore__ inline void LIQVector<LIQT>::ProcessLD()
{
    int32_t curCubeId = blockId_ / 2;
    int32_t tmpCubeId = curCubeId;

    int64_t s2ActSeq;
    int64_t s2Start;
    int64_t s2End;
    int64_t isS2End;
    int64_t bn2Idx;
    int64_t s1Idx;
    uint32_t acc_list_num = 0;
    int64_t bIdx = 0;
    int64_t needFd;
    int64_t wsOffset;
    int64_t wsInfoOffset = 0;
    int64_t nextneedFd;
    int64_t valueOffset = 0;
    int64_t outOffset = 0;

    LocalTensor<float> curValueIdxUb = ldToBeMrgBuf_.Get<float>();
    LocalTensor<float> tmpUb = ldTmpBuf_.Get<float>();

    // Info at the start of S2.
    // The first block never has a head reduction, so processing starts from the tail
    // reduction; the while loop then reads the next core's head reduction.
    // Merge whenever 4 lists have accumulated or the end of S2 is reached, until S2 is done.
    // Each core skips its own head reduction, since a preceding core always completes it.
    uint32_t s1LdStartIdx = 0;
    uint32_t s1ProcNum = 0;
    uint64_t paramGmCoreOffset = tmpCubeId * s1BaseSize_ * 2 * paramNum_;
    for (uint32_t innerS1Idx = 0; innerS1Idx < s1BaseSize_; innerS1Idx++) {
        needFd = vec1ParamGm.GetValue(paramGmCoreOffset + innerS1Idx * 2 * paramNum_ + paramNum_);
        if (needFd == 1) {
            s1LdStartIdx = (s1ProcNum == 0) ? innerS1Idx : s1LdStartIdx;
            s1ProcNum++;
        }
    }

    if (s1ProcNum == 0) {
        return;
    }

    // process S1 row by row
    uint32_t s1VecNum = CeilDiv(s1ProcNum, 2);
    if (blockId_ % 2 == 1) {
        s1LdStartIdx = s1LdStartIdx + s1VecNum;
        s1VecNum = s1ProcNum - s1VecNum;
    }
    for (uint32_t innerS1Idx = s1LdStartIdx; innerS1Idx < s1LdStartIdx + s1VecNum; innerS1Idx++) {
        // reset offsets
        tmpCubeId = curCubeId;
        acc_list_num = 0;
        valueOffset = 0;

        // copy in data
        wsOffset = tmpCubeId * s1BaseSize_ * 2 * BASE_TOPK_VALUE_IDX_SIZE + // offset shared by the two AIVs
                   innerS1Idx * 2 * BASE_TOPK_VALUE_IDX_SIZE + BASE_TOPK_VALUE_IDX_SIZE;
        SetWaitFlag<HardEvent::V_MTE2>(HardEvent::V_MTE2);
        SetWaitFlag<HardEvent::S_MTE2>(HardEvent::S_MTE2);
        DataCopyPad(curValueIdxUb, vec1ResGm[wsOffset],
                    {1, static_cast<uint16_t>(BASE_TOPK_VALUE_IDX_SIZE * sizeof(int32_t)), 0, 0}, {true, 0, 0, 0});
        acc_list_num++;
        valueOffset += BASE_TOPK_VALUE_IDX_SIZE;

        // fetch the next core's reduction info
        tmpCubeId++;
        wsInfoOffset = tmpCubeId * s1BaseSize_ * 2 * paramNum_ + innerS1Idx * 2 * paramNum_;
        needFd = vec1ParamGm.GetValue(wsInfoOffset);
        isS2End = vec1ParamGm.GetValue(wsInfoOffset + 4);
        s1Idx = vec1ParamGm.GetValue(wsInfoOffset + 6);
        outOffset = vec1ParamGm.GetValue(wsInfoOffset + 8);

        while (needFd == 1) {
            // copy in head-reduction data
            wsOffset = tmpCubeId * s1BaseSize_ * 2 * BASE_TOPK_VALUE_IDX_SIZE + // offset shared by the two AIVs
                       innerS1Idx * 2 * BASE_TOPK_VALUE_IDX_SIZE;
            SetWaitFlag<HardEvent::V_MTE2>(HardEvent::V_MTE2);
            SetWaitFlag<HardEvent::S_MTE2>(HardEvent::S_MTE2);
            DataCopyPad(curValueIdxUb[valueOffset], vec1ResGm[wsOffset],
                        {1, static_cast<uint16_t>(BASE_TOPK_VALUE_IDX_SIZE * sizeof(int32_t)), 0, 0}, {true, 0, 0, 0});
            valueOffset += BASE_TOPK_VALUE_IDX_SIZE;
            acc_list_num++;

            // merge every time 4 lists have accumulated; the first 2K entries are the merge result
            if (acc_list_num == mrgListNum_) {
                // MrgSort: merge four 2048-entry queues into one
                AscendC::MrgSort4Info params;
                params.elementLengths[0] = BASE_TOPK;
                params.elementLengths[1] = BASE_TOPK;
                params.elementLengths[2] = BASE_TOPK;
                params.elementLengths[3] = BASE_TOPK;
                params.ifExhaustedSuspension = true;
                params.validBit = 0b1111;
                params.repeatTimes = 1;

                AscendC::MrgSortSrcList<float> srcList;
                srcList.src1 = curValueIdxUb[0];
                srcList.src2 = curValueIdxUb[BASE_TOPK_VALUE_IDX_SIZE];
                srcList.src3 = curValueIdxUb[2 * BASE_TOPK_VALUE_IDX_SIZE];
                srcList.src4 = curValueIdxUb[3 * BASE_TOPK_VALUE_IDX_SIZE];
                SetWaitFlag<HardEvent::MTE2_V>(HardEvent::MTE2_V);
                MrgSort(tmpUb, srcList, params);
                PipeBarrier<PIPE_V>();
                DataCopy(curValueIdxUb, tmpUb, BASE_TOPK_VALUE_IDX_SIZE);
                PipeBarrier<PIPE_V>();
                acc_list_num = 1;
                valueOffset = BASE_TOPK_VALUE_IDX_SIZE;
            }

            // reached the end of S2: stop reducing
            if (isS2End == 1) {
                break;
            }

            tmpCubeId++;
            wsInfoOffset = tmpCubeId * s1BaseSize_ * 2 * paramNum_ + innerS1Idx * 2 * paramNum_;
            needFd = vec1ParamGm.GetValue(wsInfoOffset);
            isS2End = vec1ParamGm.GetValue(wsInfoOffset + 4);
        }

        // merge the remaining (fewer than 4) lists
        if (acc_list_num != 1) {
            AscendC::MrgSort4Info params;
            params.elementLengths[0] = BASE_TOPK;
            params.elementLengths[1] = BASE_TOPK;
            params.elementLengths[2] = BASE_TOPK;
            params.elementLengths[3] = BASE_TOPK;
            params.ifExhaustedSuspension = true;
            if (acc_list_num == 2) {
                params.validBit = 0b0011;
            } else if (acc_list_num == 3) {
                params.validBit = 0b0111;
            }
            params.repeatTimes = 1;

            AscendC::MrgSortSrcList<float> srcList;
            srcList.src1 = curValueIdxUb[0];
            srcList.src2 = curValueIdxUb[BASE_TOPK_VALUE_IDX_SIZE];
            srcList.src3 = curValueIdxUb[2 * BASE_TOPK_VALUE_IDX_SIZE];
            srcList.src4 = curValueIdxUb[3 * BASE_TOPK_VALUE_IDX_SIZE];
            SetWaitFlag<HardEvent::MTE2_V>(HardEvent::MTE2_V);
            MrgSort(tmpUb, srcList, params);
            PipeBarrier<PIPE_V>();
            DataCopy(curValueIdxUb, tmpUb, BASE_TOPK_VALUE_IDX_SIZE);
            PipeBarrier<PIPE_V>();
        }

        // copy out
        LocalTensor<float> outValueUb = ldOutValueBuf_.Get<float>();
        LocalTensor<uint32_t> outIdxUb = ldOutIdxBuf_.Get<uint32_t>();
        Extract(outValueUb, outIdxUb, curValueIdxUb, (BASE_TOPK / 32));
        LocalTensor<int32_t> idxULocal1 = outIdxUb.template ReinterpretCast<int32_t>();
        SetWaitFlag<HardEvent::V_MTE3>(HardEvent::V_MTE3);
        SetWaitFlag<HardEvent::S_MTE3>(HardEvent::S_MTE3);
        DataCopyPad(indiceOutGm[outOffset], idxULocal1,
                    {1, static_cast<uint16_t>(constInfo_.sparseCount * sizeof(int32_t)), 0, 0});
        SetWaitFlag<HardEvent::MTE3_V>(HardEvent::MTE3_V);
    }
}
} // namespace LIQKernel

#endif

@@ -0,0 +1,53 @@
/**
 * This program is free software, you can redistribute it and/or modify it.
 * Copyright (c) 2025 Huawei Technologies Co., Ltd.
 * This file is a part of the CANN Open Software.
 * Licensed under CANN Open Software License Agreement Version 2.0 (the "License").
 * Please refer to the License for details. You may not use this file except in compliance with the License.
 * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE.
 * See LICENSE in the root of the software repository for the full text of the License.
 */

/*!
 * \file lightning_indexer_quant_template_tiling_key.h
 * \brief
 */

#ifndef TEMPLATE_TILING_KEY_LI_H_
#define TEMPLATE_TILING_KEY_LI_H_

#include "ascendc/host_api/tiling/template_argument.h"

#define LI_TPL_FP16 1
#define LI_TPL_IN8 2
#define LI_TPL_INT32 3
#define LI_TPL_BF16 27

#define LIQ_LAYOUT_BSND 0
#define LIQ_LAYOUT_TND 1
#define LIQ_LAYOUT_PA_BSND 2

#define ASCENDC_TPL_4_BW 4

// Supported ranges of the template parameters
ASCENDC_TPL_ARGS_DECL(LightningIndexerQuant, // operator OpType
                      ASCENDC_TPL_DTYPE_DECL(DT_Q, LI_TPL_IN8), ASCENDC_TPL_DTYPE_DECL(DT_K, LI_TPL_IN8),
                      ASCENDC_TPL_DTYPE_DECL(DT_OUT, LI_TPL_INT32), ASCENDC_TPL_BOOL_DECL(PAGE_ATTENTION, 1, 0),
                      ASCENDC_TPL_UINT_DECL(Q_LAYOUT_T, ASCENDC_TPL_4_BW, ASCENDC_TPL_UI_LIST, LIQ_LAYOUT_BSND,
                                            LIQ_LAYOUT_TND),
                      ASCENDC_TPL_UINT_DECL(K_LAYOUT_T, ASCENDC_TPL_4_BW, ASCENDC_TPL_UI_LIST,
                                            LIQ_LAYOUT_PA_BSND, LIQ_LAYOUT_BSND, LIQ_LAYOUT_TND), );

// Supported template-parameter combinations.
// When GET_TPL_TILING_KEY is called to obtain a TilingKey, the interface validates the key against this list.
ASCENDC_TPL_SEL(
    ASCENDC_TPL_ARGS_SEL(ASCENDC_TPL_DTYPE_SEL(DT_Q, LI_TPL_IN8), ASCENDC_TPL_DTYPE_SEL(DT_K, LI_TPL_IN8),
                         ASCENDC_TPL_DTYPE_SEL(DT_OUT, LI_TPL_INT32), ASCENDC_TPL_BOOL_SEL(PAGE_ATTENTION, 1),
                         ASCENDC_TPL_UINT_SEL(Q_LAYOUT_T, ASCENDC_TPL_UI_LIST, LIQ_LAYOUT_BSND, LIQ_LAYOUT_TND),
                         ASCENDC_TPL_UINT_SEL(K_LAYOUT_T, ASCENDC_TPL_UI_LIST, LIQ_LAYOUT_PA_BSND), ),
    ASCENDC_TPL_ARGS_SEL(ASCENDC_TPL_DTYPE_SEL(DT_Q, LI_TPL_IN8), ASCENDC_TPL_DTYPE_SEL(DT_K, LI_TPL_IN8),
                         ASCENDC_TPL_DTYPE_SEL(DT_OUT, LI_TPL_INT32), ASCENDC_TPL_BOOL_SEL(PAGE_ATTENTION, 0),
                         ASCENDC_TPL_UINT_SEL(Q_LAYOUT_T, ASCENDC_TPL_UI_LIST, LIQ_LAYOUT_BSND, LIQ_LAYOUT_TND),
                         ASCENDC_TPL_UINT_SEL(K_LAYOUT_T, ASCENDC_TPL_UI_LIST, LIQ_LAYOUT_BSND, LIQ_LAYOUT_TND), ), );

#endif

@@ -0,0 +1,193 @@
/**
 * This program is free software, you can redistribute it and/or modify it.
 * Copyright (c) 2025 Huawei Technologies Co., Ltd.
 * This file is a part of the CANN Open Software.
 * Licensed under CANN Open Software License Agreement Version 2.0 (the "License").
 * Please refer to the License for details. You may not use this file except in compliance with the License.
 * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE.
 * See LICENSE in the root of the software repository for the full text of the License.
 */

/*!
 * \file lightning_indexer_quant_vector.h
 * \brief
 */
#ifndef LIGHTNING_INDEXER_QUANT_VECTOR_H
#define LIGHTNING_INDEXER_QUANT_VECTOR_H

#include "kernel_operator.h"

namespace LIQServiceVec {
using namespace AscendC;

constexpr int32_t NEG_INF = 0xFF800000;
constexpr int32_t INVALID_INDEX = -1;
constexpr uint8_t VEC_REPEAT_MAX = 255;
constexpr uint8_t B32_VEC_ELM_NUM = 64;
constexpr uint8_t B32_BLOCK_ALIGN_NUM = 8;
constexpr uint8_t B32_VEC_REPEAT_STRIDE = 8;
constexpr uint64_t VEC_REPEAT_BYTES = 256;
constexpr int32_t CONST_TWO = 2;
constexpr int64_t VALUE_AND_INDEX_NUM = 2;
constexpr int64_t BLOCK_BYTES = 32;
constexpr int64_t MRG_QUE_0 = 0;
constexpr int64_t MRG_QUE_1 = 1;
constexpr int64_t MRG_QUE_2 = 2;
constexpr int64_t MRG_QUE_3 = 3;
constexpr int64_t MRG_BLOCK_2 = 2;
constexpr int64_t MRG_BLOCK_3 = 3;
constexpr int64_t MRG_BLOCK_4 = 4;
template <typename T>
__aicore__ inline void CopyOut(const GlobalTensor<T> &dstGm, const LocalTensor<T> &srcUb, int64_t copyCount)
{
    AscendC::DataCopyParams dataCopyOutParams;
    dataCopyOutParams.blockCount = 1;
    dataCopyOutParams.blockLen = copyCount * sizeof(T);
    dataCopyOutParams.srcStride = 0;
    dataCopyOutParams.dstStride = 0;
    AscendC::DataCopyPad(dstGm, srcUb, dataCopyOutParams);
}
/**
 * src: buffer to be initialized
 * eleNum: number of elements to initialize; must be a multiple of 64. The elements are
 *         initialized to interleaved -inf, -1 (value, index) pairs.
 */
__aicore__ inline void InitSortOutBuf(const LocalTensor<float> &src, int64_t eleNum)
{
    uint64_t mask1[2] = {0x5555555555555555, 0}; // even lanes: values
    uint64_t mask0[2] = {0xaaaaaaaaaaaaaaaa, 0}; // odd lanes: indices
    int64_t repeatNum = eleNum / B32_VEC_ELM_NUM;
    int64_t forLoop = repeatNum / VEC_REPEAT_MAX;
    int64_t forRemain = repeatNum % VEC_REPEAT_MAX;
    for (int i = 0; i < forLoop; i++) {
        AscendC::Duplicate(src.template ReinterpretCast<int32_t>()[i * VEC_REPEAT_MAX * B32_VEC_ELM_NUM], NEG_INF,
                           mask1, VEC_REPEAT_MAX, 1, B32_VEC_REPEAT_STRIDE);
        AscendC::Duplicate(src.template ReinterpretCast<int32_t>()[i * VEC_REPEAT_MAX * B32_VEC_ELM_NUM],
                           INVALID_INDEX, mask0, VEC_REPEAT_MAX, 1, B32_VEC_REPEAT_STRIDE);
    }
    if (forRemain > 0) {
        AscendC::Duplicate(src.template ReinterpretCast<int32_t>()[forLoop * VEC_REPEAT_MAX * B32_VEC_ELM_NUM], NEG_INF,
                           mask1, forRemain, 1, B32_VEC_REPEAT_STRIDE);
        AscendC::Duplicate(src.template ReinterpretCast<int32_t>()[forLoop * VEC_REPEAT_MAX * B32_VEC_ELM_NUM],
                           INVALID_INDEX, mask0, forRemain, 1, B32_VEC_REPEAT_STRIDE);
    }
    AscendC::PipeBarrier<PIPE_V>();
}
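// SortAll below is a merge ladder: Sort32 first produces sorted runs of 32 (value, index)
// pairs, then each pass MrgSort-merges four runs into one, ping-ponging between src and
// tmp, until a single sorted run of logitsNum pairs remains; a final DataCopy moves the
// result back to src when it ends up in tmp.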
/**
 * src: logits and indices; the first logitsNum elements are logits, the next logitsNum are indices
 * tmp: temporary buffer used by the computation, same size as src
 * logitsNum: number of elements to sort; currently only [128, 256, 384, 512, 1024, 2048] are supported
 */
__aicore__ inline void SortAll(LocalTensor<float> &src, LocalTensor<float> &tmp, int64_t logitsNum)
{
    int64_t sort32Repeats = logitsNum / BLOCK_BYTES;
    AscendC::Sort32(tmp, src, src[logitsNum].ReinterpretCast<uint32_t>(), sort32Repeats);
    AscendC::PipeBarrier<PIPE_V>();

    int64_t mrgGroups = sort32Repeats;
    int64_t mrgElements = BLOCK_BYTES;
    int64_t i = 0;
    AscendC::LocalTensor<float> srcTensor;
    AscendC::LocalTensor<float> dstTensor;
    while (true) {
        if (i % CONST_TWO == 0) {
            srcTensor = tmp;
            dstTensor = src;
        } else {
            srcTensor = src;
            dstTensor = tmp;
        }
        AscendC::MrgSort4Info params;
        params.elementLengths[0] = mrgElements;
        params.elementLengths[MRG_QUE_1] = mrgElements;
        params.elementLengths[MRG_QUE_2] = mrgElements;
        params.elementLengths[MRG_QUE_3] = mrgElements;
        params.ifExhaustedSuspension = false;
        params.validBit = 0b1111;

        AscendC::MrgSortSrcList<float> srcList;
        srcList.src1 = srcTensor[0];
        srcList.src2 = srcTensor[MRG_QUE_1 * VALUE_AND_INDEX_NUM * mrgElements];
        srcList.src3 = srcTensor[MRG_QUE_2 * VALUE_AND_INDEX_NUM * mrgElements];
        srcList.src4 = srcTensor[MRG_QUE_3 * VALUE_AND_INDEX_NUM * mrgElements];
        if (mrgGroups <= MRG_BLOCK_4) {
            params.repeatTimes = 1;
            if (mrgGroups == 1) {
                break;
            } else if (mrgGroups == MRG_BLOCK_2) {
                params.validBit = 0b0011;
            } else if (mrgGroups == MRG_BLOCK_3) {
                params.validBit = 0b0111;
            } else if (mrgGroups == MRG_BLOCK_4) {
                params.validBit = 0b1111;
            }
            AscendC::MrgSort<float>(dstTensor, srcList, params);
            i += 1;
            break;
        } else {
            params.repeatTimes = mrgGroups / MRG_BLOCK_4;
            AscendC::MrgSort<float>(dstTensor, srcList, params);
            i += 1;
            mrgElements = mrgElements * MRG_BLOCK_4;
            mrgGroups = mrgGroups / MRG_BLOCK_4;
        }
        AscendC::PipeBarrier<PIPE_V>();
    }
    if (i % CONST_TWO == 0) {
        AscendC::DataCopy(src, tmp, logitsNum * VALUE_AND_INDEX_NUM);
        AscendC::PipeBarrier<PIPE_V>();
    }
}

/**
 * mrgDst: tensor that is merged into (holds the running result)
 * mrgSrc: tensor to be merged in
 * tmpTensor: scratch space of size mrgDst + mrgSrc
 */
__aicore__ inline void MergeSort(const LocalTensor<float> &mrgDst, int32_t mrgDstNum, LocalTensor<float> &mrgSrc,
                                 int32_t mrgSrcNum, LocalTensor<float> &tmpTensor)
{
    AscendC::MrgSort4Info params;
    params.elementLengths[0] = mrgSrcNum;
    params.elementLengths[1] = mrgDstNum;
    params.ifExhaustedSuspension = false;
    params.validBit = 0b0011;
    params.repeatTimes = 1;

    AscendC::MrgSortSrcList<float> srcList;
    srcList.src1 = mrgSrc;
    srcList.src2 = mrgDst;

    AscendC::MrgSort<float>(tmpTensor, srcList, params);
    AscendC::PipeBarrier<PIPE_V>();
    AscendC::DataCopy(mrgDst, tmpTensor, mrgDstNum * VALUE_AND_INDEX_NUM);
    AscendC::PipeBarrier<PIPE_V>();
}
__aicore__ inline void ExtractIndex(const LocalTensor<uint32_t> &idxULocal, const LocalTensor<uint32_t> &sortLocal,
|
||||
int64_t extractNum)
|
||||
{
|
||||
AscendC::GatherMaskParams gatherMaskParams;
|
||||
gatherMaskParams.repeatTimes = Ceil(extractNum * sizeof(float) * VALUE_AND_INDEX_NUM, VEC_REPEAT_BYTES);
|
||||
gatherMaskParams.src0BlockStride = 1;
|
||||
gatherMaskParams.src0RepeatStride = B32_VEC_REPEAT_STRIDE;
|
||||
gatherMaskParams.src1RepeatStride = 0;
|
||||
uint64_t rsvdCnt = 0; // 用于保存筛选后保留下来的元素个数
|
||||
uint8_t src1Pattern = 2; // 固定模式2,表示筛选出奇数索引的数
|
||||
AscendC::GatherMask(idxULocal, sortLocal, src1Pattern, false, static_cast<uint32_t>(0), gatherMaskParams, rsvdCnt);
|
||||
AscendC::PipeBarrier<PIPE_V>();
|
||||
}
|
||||
|
||||
template <HardEvent event>
|
||||
__aicore__ inline void SetWaitFlag(HardEvent evt)
|
||||
{
|
||||
event_t eventId = static_cast<event_t>(GetTPipePtr()->FetchEventID(evt));
|
||||
AscendC::SetFlag<event>(eventId);
|
||||
AscendC::WaitFlag<event>(eventId);
|
||||
}
|
||||
|
||||
} // namespace LIQServiceVec
|
||||
#endif // LIGHTNING_INDEXER_QUANT_VECTOR_H
|
||||
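For orientation, the helpers above keep values and their source indices as interleaved pairs: `InitSortOutBuf` fills padding slots with (-inf, -1) so they sink to the tail of a descending sort, `Sort32` sorts 32-element blocks, and `MrgSort` merges up to four sorted queues per pass until one remains. A minimal NumPy reference model of the end-to-end semantics (not the AscendC kernel itself; names are illustrative):

```python
import numpy as np

def sort_all_reference(logits: np.ndarray, pad_to: int) -> tuple[np.ndarray, np.ndarray]:
    """Model of InitSortOutBuf + SortAll: descending sort of (value, index) pairs,
    with (-inf, -1) padding that ends up at the tail."""
    values = np.full(pad_to, -np.inf, dtype=np.float32)  # InitSortOutBuf: -inf value slots
    indices = np.full(pad_to, -1, dtype=np.int32)        # InitSortOutBuf: -1 index slots
    values[: logits.size] = logits
    indices[: logits.size] = np.arange(logits.size, dtype=np.int32)
    order = np.argsort(-values, kind="stable")           # what the Sort32 + MrgSort passes achieve
    return values[order], indices[order]

vals, idx = sort_all_reference(np.array([0.1, 2.5, -1.0], dtype=np.float32), pad_to=4)
# vals -> [ 2.5  0.1 -1.  -inf], idx -> [ 1  0  2 -1]
```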
@@ -42,6 +42,7 @@
#include "moe_gating_top_k/moe_gating_top_k_torch_adpt.h"
#include "moe_init_routing_custom/moe_init_routing_custom_torch_adpt.h"
#include "sparse_flash_attention/sparse_flash_attention_torch_adpt.h"
#include "lightning_indexer_quant/lightning_indexer_quant_torch_adpt.h"
#include <c10/core/Device.h>
#include <c10/util/Exception.h>
#include <c10/util/Logging.h>
@@ -918,4 +919,16 @@ TORCH_LIBRARY_EXPAND(CONCAT(_C, _ascend), ops)
        "-> Tensor[]"
    );
    ops.impl("moe_grouped_matmul", torch::kPrivateUse1, &vllm_ascend::moe_grouped_matmul);

    // This operator is planned to be integrated into PTA in the near future.
    // Once that happens, the implementation in csrc will be removed.
    ops.def(
        "npu_lightning_indexer_quant(Tensor query, Tensor key, Tensor weights, Tensor query_dequant_scale, "
        " Tensor key_dequant_scale, *, Tensor? actual_seq_lengths_query=None, "
        " Tensor? actual_seq_lengths_key=None, Tensor? block_table=None, "
        " int query_quant_mode=0, int key_quant_mode=0, "
        " str layout_query='BSND', str layout_key='BSND',"
        " int sparse_count=2048, int sparse_mode=3) -> Tensor"
    );
    ops.impl("npu_lightning_indexer_quant", torch::kPrivateUse1, &vllm_ascend::npu_lightning_indexer_quant);
}

@@ -529,6 +529,44 @@ std::vector<at::Tensor> moe_grouped_matmul_meta(
    return y;
}

at::Tensor npu_lightning_indexer_quant_meta(
    const at::Tensor &query, const at::Tensor &key, const at::Tensor &weights,
    const at::Tensor &query_dequant_scale, const at::Tensor &key_dequant_scale,
    const c10::optional<at::Tensor> &actual_seq_lengths_query,
    const c10::optional<at::Tensor> &actual_seq_lengths_key,
    const c10::optional<at::Tensor> &block_table, int64_t query_quant_mode, int64_t key_quant_mode,
    c10::string_view layout_query, c10::string_view layout_key, int64_t sparse_count, int64_t sparse_mode)
{
    std::string query_layout_str = std::string(layout_query);
    std::string key_layout_str = std::string(layout_key);

    const int SIZE = 8;
    const int DIM_0 = 0;
    const int DIM_1 = 1;
    const int DIM_2 = 2;
    const int DIM_3 = 3;

    at::SmallVector<int64_t, SIZE> output_size;
    for (size_t i = 0; i < query.sizes().size(); i++) {
        TORCH_CHECK(query.size(i) > 0, "All values within query's shape should be greater "
                    "than 0, but shape[", i, "] is ", query.size(i));
    }
    for (size_t i = 0; i < key.sizes().size(); i++) {
        TORCH_CHECK(key.size(i) > 0, "All values within key's shape should be greater "
                    "than 0, but shape[", i, "] is ", key.size(i));
    }
    TORCH_CHECK(sparse_count > 0, "sparse count should be greater than 0, but now is ", sparse_count);
    int64_t keyHeadNum = (key_layout_str == "TND") ? key.size(DIM_1) : key.size(DIM_2);
    if (query_layout_str == "BSND") {
        output_size = {query.size(DIM_0), query.size(DIM_1), keyHeadNum, sparse_count};
    } else {
        output_size = {query.size(DIM_0), keyHeadNum, sparse_count};
    }
    at::Tensor lightning_indexer_quant_output = at::empty(output_size, query.options().dtype(at::kInt));

    return lightning_indexer_quant_output;
}

} // namespace meta
} // namespace vllm_ascend

@@ -576,5 +614,7 @@ TORCH_LIBRARY_IMPL_EXPAND(CONCAT(_C, _ascend), Meta, ops) {
    ops.impl("causal_conv1d_fn", &vllm_ascend::meta::causal_conv1d_fn_meta);
    // moe_grouped_matmul
    ops.impl("moe_grouped_matmul", &vllm_ascend::meta::moe_grouped_matmul_meta);
    // Lightning indexer quant
    ops.impl("npu_lightning_indexer_quant", &vllm_ascend::meta::npu_lightning_indexer_quant_meta);
}
}

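A quick sanity check of the shape rule above, sketched with made-up dimensions (no NPU or extension load needed, since only the documented shape inference is exercised):

```python
import torch

# Assumed dims: B=2, S=4, 64 query heads, head_dim=128, one key head.
query = torch.empty(2, 4, 64, 128, dtype=torch.int8, device="meta")  # layout "BSND"
key = torch.empty(2, 4096, 1, 128, dtype=torch.int8, device="meta")  # layout "BSND"

# The meta function infers key_head_num = key.size(2) for BSND keys and returns
# an int32 tensor of shape (B, S, key_head_num, sparse_count) for BSND queries.
sparse_count = 2048
expected = (query.size(0), query.size(1), key.size(2), sparse_count)
print(expected)  # (2, 4, 1, 2048); for "TND" queries the S dimension is dropped
```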
@@ -266,6 +266,33 @@ def test_deepseek3_2_w8a8_pruning_mtp_tp2_ep():
        vllm_model.generate_greedy(long_example_prompts, max_tokens)


@patch.dict(os.environ, {"HCCL_OP_EXPANSION_MODE": "AIV"})
@patch.dict(os.environ, {"VLLM_ASCEND_ENABLE_FLASHCOMM1": "1"})
@patch.dict(os.environ, {"ASCEND_AGGREGATE_ENABLE": "1"})
@patch.dict(os.environ, {"HCCL_BUFFSIZE": "1024"})
def test_deepseek3_2_w8a8c8_pruning_mtp_tp2_ep():
    short_example_prompts = [
        "Hello ",
    ]
    # "max_position_embeddings": 163840,
    long_example_prompts = ["Hello " * (163839 - 500) + "Hello"]
    max_tokens = 500
    with VllmRunner(
        "vllm-ascend/DeepSeek-V3.2-W8A8-Pruning",
        tensor_parallel_size=2,
        quantization="ascend",
        enable_expert_parallel=True,
        max_model_len=163840,
        compilation_config={"cudagraph_capture_sizes": [2, 4, 6, 8, 10, 12], "cudagraph_mode": "FULL_DECODE_ONLY"},
        speculative_config={"num_speculative_tokens": 1, "method": "deepseek_mtp"},
        additional_config={"layer_sharding": ["q_b_proj", "o_proj"], "enable_sparse_c8": True},
        reasoning_parser="deepseek_v3",
        tokenizer_mode="deepseek_v32",
    ) as vllm_model:
        vllm_model.generate_greedy(short_example_prompts, max_tokens)
        vllm_model.generate_greedy(long_example_prompts, max_tokens)


@pytest.mark.parametrize("model", QWEN_W4A4_MODELS)
def test_qwen3_w4a4_distributed_tp2(model):
    example_prompts = [

@@ -134,9 +134,12 @@ class AscendConfig:
            bool(additional_config.get("enable_async_exponential", False)) and not vllm_is_batch_invariant()
        )

        use_sparse = hasattr(vllm_config.model_config, "hf_text_config") and hasattr(
            vllm_config.model_config.hf_text_config, "index_topk"
        )

        self.enable_kv_nz = additional_config.get("enable_kv_nz", False)
        if self.enable_kv_nz:
            use_sparse = hasattr(vllm_config.model_config.hf_text_config, "index_topk")
            if not vllm_config.model_config.is_deepseek_mla or use_sparse:
                raise RuntimeError("enable_kv_nz is only supported for mla currently.")
            if vllm_config.kv_transfer_config is None or not vllm_config.kv_transfer_config.is_kv_consumer:
@@ -144,6 +147,17 @@ class AscendConfig:
                    "enable_kv_nz is only supported in pd scenario and can only be used in D node."
                )

        from vllm_ascend.utils import AscendDeviceType, get_ascend_device_type

        # Disable Sparse C8 for A5:
        # A5 has not been fully validated for this path and may carry hidden risks.
        # TODO(rjg-lyh): Enable A5 support after sufficient validation.
        self.enable_sparse_c8 = (
            additional_config.get("enable_sparse_c8", False)
            and use_sparse
            and get_ascend_device_type() != AscendDeviceType.A5
        )

    def _construct_weight_prefetch_config(self, additional_config):
        weight_prefetch_config = additional_config.get("weight_prefetch_config", {})
        self.weight_prefetch_config = WeightPrefetchConfig(weight_prefetch_config)

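For reference, enabling the new option from user code mirrors the e2e test above; a sketch only (model name and parallel settings are illustrative, and `enable_sparse_c8` is silently forced off on non-DSA models or A5 devices per the config logic above):

```python
from vllm import LLM

# Sketch: enable LI C8 via additional_config, as the e2e test does.
llm = LLM(
    model="vllm-ascend/DeepSeek-V3.2-W8A8-Pruning",  # illustrative checkpoint
    tensor_parallel_size=2,
    quantization="ascend",
    additional_config={"enable_sparse_c8": True},
)
```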
@@ -1,6 +1,7 @@
from dataclasses import dataclass
from typing import TYPE_CHECKING, TypeVar

import scipy  # type: ignore
import torch
import torch_npu
import vllm.envs as envs_vllm
@@ -355,6 +356,9 @@ class AscendSFAImpl(MLAAttentionImpl):
    # Supports forward using the all-gather o_proj weight for decode requests when Sharded CP is enabled.
    o_proj_full_pool: torch.Tensor | None = None

    # qk_hadamard tensor shared across instances when DSA C8 is enabled
    qk_hadamard: torch.Tensor | None = None

    def __init__(
        self,
        num_heads: int,
@@ -425,6 +429,12 @@ class AscendSFAImpl(MLAAttentionImpl):
        self.is_rope_neox_style = False
        self.use_torch_npu_lightning_indexer = True

        # DSA C8
        self.use_sparse_c8_indexer = ascend_config.enable_sparse_c8
        if self.use_sparse_c8_indexer:
            self.c8_k_cache_dtype = torch.int8
            self.c8_k_scale_cache_dtype = torch.float16

        # Effective in SFA when FlashComm is enabled.
        self.enable_dsa_cp = enable_dsa_cp()

@@ -515,6 +525,11 @@ class AscendSFAImpl(MLAAttentionImpl):
        # if mlapo, W_UK_T can't trans nz
        self.W_UK_T = maybe_trans_nz(self.W_UK_T)

        if self.use_sparse_c8_indexer and AscendSFAImpl.qk_hadamard is None:
            AscendSFAImpl.qk_hadamard = torch.tensor(scipy.linalg.hadamard(128), dtype=torch.bfloat16, device="npu") / (
                128**0.5
            )

        # Processing the input parameters for MLAPO by reordering and transposing
        # QKV(and part of Q) weight, applying RoPE-related dimension transformations,
        # and handling quantization parameters.
@@ -874,7 +889,15 @@ class AscendSFAImpl(MLAAttentionImpl):

        k_li = torch.cat([k_li_pe, k_li_nope], dim=-1)  # [b*s,128]

        return k_li
        if self.use_sparse_c8_indexer:
            k_li = k_li @ AscendSFAImpl.qk_hadamard
            k_li, k_li_scale = torch_npu.npu_dynamic_quant(k_li.view(-1, self.head_dim), dst_type=self.c8_k_cache_dtype)
            k_li_scale = k_li_scale.to(self.c8_k_scale_cache_dtype)  # [b*s,]
            k_li_scale = k_li_scale.unsqueeze(-1)  # [b*s,1]
        else:
            k_li_scale = None

        return k_li, k_li_scale

    def indexer_select_post_process(
        self,
@@ -905,10 +928,35 @@ class AscendSFAImpl(MLAAttentionImpl):
        q_li_pe = q_li_pe.squeeze(2)
        q_li = torch.cat([q_li_pe, q_li_nope], dim=-1)  # [b*s,64,128]

        if self.use_sparse_c8_indexer:
            q_li_shape_ori = q_li.shape
            q_li = q_li @ AscendSFAImpl.qk_hadamard
            q_li, q_li_scale = torch_npu.npu_dynamic_quant(q_li.view(-1, self.head_dim), dst_type=self.c8_k_cache_dtype)
            q_li_scale = q_li_scale.to(self.c8_k_scale_cache_dtype)

        # DSV3.2 currently has graph compilation issues when using torch_npu.npu_lightning_indexer,
        # so two branches are maintained temporarily.
        # TODO: torch.ops._C_ascend.npu_lightning_indexer_quant needs to be removed.
        if self.use_torch_npu_lightning_indexer:
        if self.use_sparse_c8_indexer:
            assert len(kv_cache) == 4
            weights = weights.to(torch.float16)
            topk_indices = torch.ops._C_ascend.npu_lightning_indexer_quant(
                query=q_li.view(q_li_shape_ori),
                key=kv_cache[2],
                weights=weights,
                query_dequant_scale=q_li_scale.view(q_li_shape_ori[:-1]),
                key_dequant_scale=kv_cache[3].squeeze(2),  # B S N D -> B S D
                actual_seq_lengths_query=actual_seq_lengths_query,
                actual_seq_lengths_key=actual_seq_lengths_key,
                block_table=attn_metadata.block_table,
                query_quant_mode=0,
                key_quant_mode=0,
                layout_query="TND",
                layout_key="PA_BSND",
                sparse_count=2048,
                sparse_mode=3,
            )
        elif self.use_torch_npu_lightning_indexer:
            topk_indices, _ = torch_npu.npu_lightning_indexer(
                query=q_li,
                key=kv_cache[2],
@@ -1015,7 +1063,7 @@ class AscendSFAImpl(MLAAttentionImpl):
                slot_mapping=slot_mapping,
                num_input_tokens=num_input_tokens,
            )
            k_li = self.indexer_select_pre_process(x=hidden_states, cos=cos, sin=sin)
            k_li, k_li_scale = self.indexer_select_pre_process(x=hidden_states, cos=cos, sin=sin)
        # native
        else:
            assert self.fused_qkv_a_proj is not None, "q lora is required for DSA."
@@ -1031,7 +1079,7 @@ class AscendSFAImpl(MLAAttentionImpl):
            assert self.q_a_layernorm is not None, "q_a_layernorm must be initialized"
            q_c = self.q_a_layernorm(q_c)

            k_li = self.indexer_select_pre_process(x=hidden_states, cos=cos, sin=sin)
            k_li, k_li_scale = self.indexer_select_pre_process(x=hidden_states, cos=cos, sin=sin)

            wait_for_kv_layer_from_connector(layer_name)

@@ -1044,20 +1092,46 @@ class AscendSFAImpl(MLAAttentionImpl):
        if self.enable_dsa_cp:
            assert k_pe is not None
            assert k_nope is not None
            assert k_li is not None
            async_op = self.enable_dsa_cp_with_layer_shard or full_gather_o_proj_enabled
            # support all_gather kv async for communication calculation overlap
            fused_kv_no_split, kv_ag_handle = all_gather_async(
                torch.cat(
                    [
                        k_pe.view(-1, k_pe.shape[-1]),
                        k_nope.view(-1, k_nope.shape[-1]),
                        k_li.view(-1, k_li.shape[-1]),
                    ],
                    dim=1,
                ),
                get_tp_group(),
                async_op=async_op,
            )
            if not self.use_sparse_c8_indexer:
                fused_kv_no_split, kv_ag_handle = all_gather_async(
                    torch.cat(
                        [
                            k_pe.view(-1, k_pe.shape[-1]),
                            k_nope.view(-1, k_nope.shape[-1]),
                            k_li.view(-1, k_li.shape[-1]),
                        ],
                        dim=1,
                    ),
                    get_tp_group(),
                    async_op=async_op,
                )
            else:
                # Due to different dtypes, the communication has to be split into separate passes.
                assert k_li_scale is not None
                fused_kv_no_split, _ = all_gather_async(
                    torch.cat(
                        [
                            k_pe.view(-1, k_pe.shape[-1]),
                            k_nope.view(-1, k_nope.shape[-1]),
                        ],
                        dim=1,
                    ),
                    get_tp_group(),
                    async_op=async_op,
                )
                k_li, _ = all_gather_async(
                    k_li,
                    get_tp_group(),
                    async_op=async_op,
                )
                k_li_scale, kv_ag_handle = all_gather_async(
                    k_li_scale,
                    get_tp_group(),
                    async_op=async_op,
                )

        ql_nope, q_pe = self._q_proj_and_k_up_proj(q_c)
        q_pe = self.rope_single(q_pe, cos, sin)
@@ -1077,9 +1151,12 @@ class AscendSFAImpl(MLAAttentionImpl):

        if kv_cache is not None:
            assert fused_kv_no_split is not None
            k_pe, k_nope, k_li = fused_kv_no_split.split(
                [self.qk_rope_head_dim, self.kv_lora_rank, self.head_dim], dim=-1
            )
            if not self.use_sparse_c8_indexer:
                k_pe, k_nope, k_li = fused_kv_no_split.split(
                    [self.qk_rope_head_dim, self.kv_lora_rank, self.head_dim], dim=-1
                )
            else:
                k_pe, k_nope = fused_kv_no_split.split([self.qk_rope_head_dim, self.kv_lora_rank], dim=-1)
            k_nope = k_nope.view(k_nope.shape[0], 1, -1)
            k_pe = k_pe.view(k_pe.shape[0], 1, -1)
            DeviceOperator.reshape_and_cache(
@@ -1098,6 +1175,14 @@ class AscendSFAImpl(MLAAttentionImpl):
            torch_npu.npu_scatter_nd_update_(
                kv_cache[2].view(-1, k_li.shape[-1]), slot_mapping.view(-1, 1), k_li.view(-1, k_li.shape[-1])
            )  # b, s, n, d
            if self.use_sparse_c8_indexer:
                assert len(kv_cache) == 4
                assert k_li_scale is not None
                torch_npu.npu_scatter_nd_update_(
                    kv_cache[3].view(-1, k_li_scale.shape[-1]),
                    slot_mapping.view(-1, 1),
                    k_li_scale.view(-1, k_li_scale.shape[-1]),
                )
            if self.is_kv_producer:
                attn_metadata.reshape_cache_event.record()

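An aside on the C8 math above: the pre/post-process paths rotate the 128-dim indexer q/k vectors with a normalized Hadamard matrix before per-token int8 quantization. The rotation is orthogonal, so dot products are preserved while outliers get spread across dimensions, which keeps the per-token quantization error small. A minimal CPU sketch of that round trip (the quantization is written out by hand here as a stand-in for `torch_npu.npu_dynamic_quant`, which is NPU-only):

```python
import scipy.linalg
import torch

head_dim = 128
hadamard = torch.tensor(scipy.linalg.hadamard(head_dim), dtype=torch.float32) / (head_dim**0.5)

x = torch.randn(4, head_dim)   # 4 tokens of indexer key/query
x_rot = x @ hadamard           # orthogonal rotation: norms preserved, outliers flattened

# Per-token symmetric int8 quantization (what npu_dynamic_quant does on NPU).
scale = x_rot.abs().amax(dim=-1, keepdim=True) / 127.0
x_int8 = torch.clamp((x_rot / scale).round(), -128, 127).to(torch.int8)

# Dequantize and undo the rotation; the error stays small because H is orthogonal.
x_rec = (x_int8.float() * scale) @ hadamard.T
print((x - x_rec).abs().max())  # small quantization error
```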
@@ -45,6 +45,30 @@ from vllm.v1.utils import ConstantList, record_function_or_nullcontext

logger = init_logger(__name__)


# `spec_manager_map` in single_type_kv_cache_manager is a module-level dict
# whose keys are class objects bound at import time. When the async
# recompute scheduler is enabled, `recompute_scheduler.py` is imported by
# `check_and_update_config()` (via AsyncScheduler → scheduler.py →
# kv_cache_coordinator → single_type_kv_cache_manager) *before*
# this patch file is executed a second time (e.g. triggered by
# unpickling an AscendMLAAttentionSpec in the EngineCoreProc subprocess).
# In that case the dict already contains the original MLAAttentionSpec
# class as a key, so a subsequent lookup with type(AscendMLAAttentionSpec
# instance) raises KeyError.
#
# Fix: whenever this patch is applied, register AscendMLAAttentionSpec as
# an additional key in spec_manager_map (if the module is already loaded).
def register_ascend_mla_spec_in_manager():
    import sys as _sys

    from vllm.v1.core.single_type_kv_cache_manager import FullAttentionManager
    from vllm.v1.kv_cache_interface import MLAAttentionSpec as AscendMLAAttentionSpec

    _stm = _sys.modules.get("vllm.v1.core.single_type_kv_cache_manager")
    if _stm is not None and AscendMLAAttentionSpec not in _stm.spec_manager_map:
        _stm.spec_manager_map[AscendMLAAttentionSpec] = FullAttentionManager


@dataclass
class RecomputeSchedulerConfig(SchedulerConfig):
    scheduler_cls: str | type[object] = "vllm_ascend.core.recompute_scheduler.RecomputeScheduler"
@@ -82,6 +106,8 @@ class RecomputeScheduler(Scheduler):
    running: list[Request]

    def __init__(self, *args, **kwargs):
        register_ascend_mla_spec_in_manager()

        super().__init__(*args, **kwargs)
        # When is_mtp_kv_consumer is true, we will fill request.spec_token_ids
        # with placeholder tokens to enable full graph when decode nodes pull
@@ -993,4 +1019,6 @@ class RecomputeScheduler(Scheduler):

class AsyncRecomputeScheduler(AsyncScheduler, RecomputeScheduler):
    def __init__(self, *args, **kwargs):
        register_ascend_mla_spec_in_manager()

        super().__init__(*args, **kwargs)

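The failure mode the comment describes can be reproduced in a few lines; this is a toy model of `spec_manager_map`, not vLLM code:

```python
# A dict keyed by class objects does not match a later-defined replacement class.
class MLASpec:  # stand-in for the original MLAAttentionSpec
    pass

spec_manager_map = {MLASpec: "FullAttentionManager"}

class PatchedMLASpec(MLASpec):  # stand-in for the Ascend patch's subclass
    pass

obj = PatchedMLASpec()
print(type(obj) in spec_manager_map)  # False -> KeyError on lookup in the real code
spec_manager_map[PatchedMLASpec] = spec_manager_map[MLASpec]  # the fix: add the new key
print(type(obj) in spec_manager_map)  # True
```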
@@ -137,6 +137,28 @@
# Remove this patch if upstream provides an official NPU graph-capture
# guidance / auto-configuration path for HCCL.
#
# ** 8. File: platform/patch_kv_cache_interface.py**
# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
#   1. `vllm.v1.kv_cache_interface.MLAAttentionSpec`
#    Why:
#      The default `MLAAttentionSpec` is mainly built around `kv_lora_rank`
#      and `qk_rope_head_dim`. On NPU, we also use this class to describe DSA
#      models. Unlike the GPU path, where cache management is handled by an
#      additional indexer module, extending this class directly simplifies the
#      corresponding `model_runner` implementation on NPU.
#
#      This patch also adds Sparse C8 support for DSA models on NPU. As part
#      of that support, members such as `page_size_bytes` need to be adapted,
#      so they are overridden here as well to preserve overall readability.
#    How:
#      This patch subclasses the original implementation, overrides selected
#      methods, and adds DSA-specific attributes and helpers with default
#      values where needed.
#    Related PR (if no, explain why):
#      https://github.com/vllm-project/vllm/pull/25896
#    Future Plan:
#      Remove this patch after the upcoming KV cache spec refactor.
#
# * Worker Patch:
# ===============
#

@@ -18,6 +18,7 @@ import os

import vllm_ascend.patch.platform.patch_distributed  # noqa
import vllm_ascend.patch.platform.patch_fusion_matcher_compat_ops  # noqa
import vllm_ascend.patch.platform.patch_kv_cache_interface  # noqa
import vllm_ascend.patch.platform.patch_mamba_config  # noqa
import vllm_ascend.patch.platform.patch_minimax_m2_config  # noqa
import vllm_ascend.patch.platform.patch_sched_yield  # noqa

138  vllm_ascend/patch/platform/patch_kv_cache_interface.py  (new file)
@@ -0,0 +1,138 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project

from dataclasses import dataclass

import torch
import vllm.v1.kv_cache_interface
from typing_extensions import Self
from vllm.utils.torch_utils import get_dtype_size
from vllm.v1.kv_cache_interface import MLAAttentionSpec


@dataclass(frozen=True)
class AscendMLAAttentionSpec(MLAAttentionSpec):
    """MLAAttentionSpec extended to support DSA models, with optional Sparse C8 support.

    When Sparse C8 is enabled, the KV cache tuple changes from
        (kv_cache[0]: bfloat16, kv_cache[1]: bfloat16, kv_cache[2]: bfloat16)
    to
        (kv_cache[0]: bfloat16, kv_cache[1]: bfloat16, kv_cache[2]: int8, kv_cache[3]: float16).

    The semantic meaning of each KV cache entry is as follows:
    1. kv_cache[0] stores kv_lora.
    2. kv_cache[1] stores k_rope.
    3. kv_cache[2] stores the key tensor from the indexer module.
    4. kv_cache[3] stores the key scale tensor from the indexer module,
       and exists only when Sparse C8 is enabled.

    The main changes are as follows:
    1. The key tensor from the indexer module stored in kv_cache[2] is
       converted from bf16 to int8 to reduce memory usage. It is then
       processed with int8 precision in the Lightning_indexer computation
       to improve computational efficiency.
    2. The quantization scale of the key tensor in the indexer module
       must also be stored for the Lightning_indexer_quant operator,
       and is therefore saved in kv_cache[3].
    """

    sparse_head_dim: tuple[int, ...] | None = None
    cache_sparse_c8: bool = False
    c8_k_cache_dtype: torch.dtype = torch.int8
    c8_k_scale_cache_dtype: torch.dtype = torch.float16

    @property
    def page_size_bytes(self) -> int:
        if self.cache_sparse_c8:
            assert self.sparse_head_dim is not None
            assert len(self.sparse_head_dim) == 3
            num_heads_per_page = self.block_size * self.num_kv_heads
            # kv_cache[0]: bfloat16, kv_cache[1]: bfloat16
            kv_lora_rank, qk_rope_head_dim = self.sparse_head_dim[:2]
            k_pe_nope_bytes = num_heads_per_page * (kv_lora_rank + qk_rope_head_dim) * get_dtype_size(self.dtype)
            # kv_cache[2]: int8
            index_head_dim = self.sparse_head_dim[-1]
            indexer_k_bytes = num_heads_per_page * index_head_dim * get_dtype_size(self.c8_k_cache_dtype)
            # kv_cache[3]: float16
            # Since the scale is stored per token, head_dim is set to 1.
            index_scale_head_dim = 1
            indexer_k_scale_bytes = (
                num_heads_per_page * index_scale_head_dim * get_dtype_size(self.c8_k_scale_cache_dtype)
            )
            return k_pe_nope_bytes + indexer_k_bytes + indexer_k_scale_bytes

        return self.block_size * self.num_kv_heads * self.head_size * get_dtype_size(self.dtype)

    @property
    def sparse_kv_cache_ratio(self) -> tuple[float, float, float, float | None]:
        """
        Compute the relative byte share of each KV cache entry.

        Returns:
            A tuple containing the ratios for:
            - kv_cache[0]
            - kv_cache[1]
            - kv_cache[2]
            - kv_cache[3] (None if Sparse C8 is disabled)
        """

        assert self.sparse_head_dim is not None

        def get_sparse_head_dim_virtual() -> tuple[int, int, int, int]:
            assert self.sparse_head_dim is not None
            assert self.cache_sparse_c8 is True

            kv_lora_rank, qk_rope_head_dim, index_k_head_dim = self.sparse_head_dim

            factor = get_dtype_size(self.dtype) // get_dtype_size(self.c8_k_cache_dtype)
            index_k_head_dim_virtual = index_k_head_dim // factor

            assert get_dtype_size(self.dtype) == get_dtype_size(self.c8_k_scale_cache_dtype)
            index_k_scale_head_dim_virtual = 1

            return (
                kv_lora_rank,
                qk_rope_head_dim,
                index_k_head_dim_virtual,
                index_k_scale_head_dim_virtual,
            )

        if self.cache_sparse_c8:
            virtual_dims = get_sparse_head_dim_virtual()
            total_virtual_head_dim = sum(virtual_dims)

            return (
                total_virtual_head_dim / virtual_dims[0],  # kv_cache[0]
                total_virtual_head_dim / virtual_dims[1],  # kv_cache[1]
                total_virtual_head_dim / virtual_dims[2],  # kv_cache[2]
                total_virtual_head_dim / virtual_dims[3],  # kv_cache[3]
            )

        return (
            self.head_size / self.sparse_head_dim[0],  # kv_cache[0]
            self.head_size / self.sparse_head_dim[1],  # kv_cache[1]
            self.head_size / self.sparse_head_dim[2],  # kv_cache[2]
            None,  # kv_cache[3] does not exist
        )

    @classmethod
    def merge(cls, specs: list[Self]) -> Self:
        assert all(isinstance(spec, MLAAttentionSpec) for spec in specs), (
            "All attention layers in the same KV cache group must be MLAAttentionSpec."
        )
        cache_dtype_str_set = set(spec.cache_dtype_str for spec in specs)
        assert len(cache_dtype_str_set) == 1, (
            "All attention layers in the same KV cache group must use the same quantization method."
        )
        return cls(
            block_size=specs[0].block_size,
            num_kv_heads=specs[0].num_kv_heads,
            head_size=specs[0].head_size,
            sparse_head_dim=specs[0].sparse_head_dim,
            dtype=specs[0].dtype,
            cache_dtype_str=cache_dtype_str_set.pop(),
            cache_sparse_c8=specs[0].cache_sparse_c8,
        )


vllm.v1.kv_cache_interface.MLAAttentionSpec = AscendMLAAttentionSpec
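To make the byte accounting in `page_size_bytes` and `sparse_kv_cache_ratio` concrete, here is a worked example. The dimensions are assumptions for illustration (DeepSeek V3.2-style values: kv_lora_rank=512, qk_rope_head_dim=64, index_head_dim=128), with block_size=128, one KV head, and a bf16 base dtype:

```python
# Assumed DSA dims: sparse_head_dim = (512, 64, 128), block_size = 128, num_kv_heads = 1.
block, heads = 128, 1
kv_lora_rank, qk_rope_head_dim, index_head_dim = 512, 64, 128
bf16, int8, fp16 = 2, 1, 2  # bytes per element

# page_size_bytes with Sparse C8 enabled:
k_pe_nope = block * heads * (kv_lora_rank + qk_rope_head_dim) * bf16  # 147456
indexer_k = block * heads * index_head_dim * int8                     # 16384
indexer_k_scale = block * heads * 1 * fp16                            # 256
print(k_pe_nope + indexer_k + indexer_k_scale)                        # 164096

# sparse_kv_cache_ratio: int8 entries count at half their bf16-equivalent width,
# so the "virtual" dims are (512, 64, 128 // 2, 1) and their sum is 641.
total = 512 + 64 + 64 + 1
print(total / 512, total / 64, total / 64, total / 1)  # split factor per entry
```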
@@ -88,6 +88,7 @@ from vllm.v1.worker.ubatch_utils import (
)
from vllm.v1.worker.utils import AttentionGroup

# yapf: enable
from vllm_ascend.ascend_config import get_ascend_config
from vllm_ascend.attention.attention_v1 import AscendAttentionState
from vllm_ascend.attention.utils import AscendCommonAttentionMetadata, using_paged_attention
@@ -100,8 +101,6 @@ from vllm_ascend.compilation.acl_graph import (
    set_graph_params,
    update_full_graph_params,
)

# yapf: enable
from vllm_ascend.eplb.adaptor.vllm_adaptor import VllmEplbAdaptor
from vllm_ascend.eplb.core.eplb_device_transfer_loader import D2DExpertWeightLoader
from vllm_ascend.eplb.core.eplb_worker import EplbProcess
@@ -278,7 +277,21 @@ class NPUModelRunner(GPUModelRunner):
        self.is_multimodal_model = self.model_config.is_multimodal_model
        self.block_size = vllm_config.cache_config.block_size
        # Set up Attention
        self.use_sparse = hasattr(self.vllm_config.model_config.hf_text_config, "index_topk")
        self.use_sparse = hasattr(vllm_config.model_config, "hf_text_config") and hasattr(
            vllm_config.model_config.hf_text_config, "index_topk"
        )
        if self.use_sparse:
            self.sparse_head_dim = (
                self.model_config.hf_text_config.kv_lora_rank,
                self.model_config.hf_text_config.qk_rope_head_dim,
                self.model_config.hf_text_config.index_head_dim,
            )
        # DSA C8
        self.use_sparse_c8_indexer = self.ascend_config.enable_sparse_c8
        if self.use_sparse_c8_indexer:
            self.c8_k_cache_dtype = torch.int8
            self.c8_k_scale_cache_dtype = torch.float16

        self.attn_backend = get_attn_backend(
            0,
            self.dtype,
@@ -2629,7 +2642,7 @@ class NPUModelRunner(GPUModelRunner):
        to their corresponding memory buffer for K cache and V cache.
        """
        # init kv cache tensors
        kv_cache_raw_tensors: dict[str, torch.Tensor | torch.Tensor | None] = {}
        kv_cache_raw_tensors: dict[str, torch.Tensor | torch.Tensor | None | None] = {}
        # prefill disaggregation needs the cache tensor address to be aligned to 2M
        alignment = 2 * 1024 * 1024
        layer_kv_cache_spec: dict[str, KVCacheSpec] = {}
@@ -2676,19 +2689,18 @@ class NPUModelRunner(GPUModelRunner):
                + self.model_config.hf_text_config.kv_lora_rank
            )

            dsa_k_cache_factor = None
            dsa_k_cache_size = None
            if not self.model_config.use_mla:
                # for non-mla model, use FullAttentionSpec
                k_tensor_split_factor = 2
                v_tensor_split_factor = 2
                k_tensor_split_factor = 2.0
                v_tensor_split_factor = 2.0
            elif self.use_sparse:
                # for deepseek v3.2, we split the kv cache according to the corresponding ratio
                sparse_sum_head_size = sum(self._get_sparse_kv_cache_ratio())
                k_tensor_split_factor, v_tensor_split_factor, dsa_k_cache_factor = [  # type: ignore
                    sparse_sum_head_size / ratio for ratio in self._get_sparse_kv_cache_ratio()
                ]
                dsa_k_cache_size = int(kv_cache_tensor.size // dsa_k_cache_factor)
                kv_cache_spec = layer_kv_cache_spec[layer_name]
                sparse_kv_cache_ratio = kv_cache_spec.sparse_kv_cache_ratio
                k_tensor_split_factor = sparse_kv_cache_ratio[0]
                v_tensor_split_factor = sparse_kv_cache_ratio[1]
                dsa_k_tensor_split_factor = sparse_kv_cache_ratio[2]
                dsa_k_scale_tensor_split_factor = sparse_kv_cache_ratio[3]
            else:
                # for other deepseek models, use MLAAttentionSpec
                k_tensor_split_factor = head_size / self.model_config.hf_text_config.kv_lora_rank
@@ -2696,35 +2708,56 @@ class NPUModelRunner(GPUModelRunner):

            k_tensor_size = int(kv_cache_tensor.size // k_tensor_split_factor)
            v_tensor_size = int(kv_cache_tensor.size // v_tensor_split_factor)
            dsa_k_tensor_size = None
            dsa_k_scale_tensor_size = None
            #### for deepseek sparse attention
            if self.use_sparse:
                dsa_k_tensor_size = int(kv_cache_tensor.size // dsa_k_tensor_split_factor)
                if self.use_sparse_c8_indexer:
                    dsa_k_scale_tensor_size = int(kv_cache_tensor.size // dsa_k_scale_tensor_split_factor)

            # for other attentions, e.g., self_attn, sliding window attn
            if self.vllm_config.kv_transfer_config is None:
                k_tensor = torch.zeros(k_tensor_size, dtype=torch.int8, device=self.device)
                v_tensor = torch.zeros(v_tensor_size, dtype=torch.int8, device=self.device)
                #### k cache: for deepseek sparse attention
                if dsa_k_cache_factor is not None:
                    dsa_k_cache_tensor = torch.zeros(dsa_k_cache_size, dtype=torch.int8, device=self.device)
                #### for deepseek sparse attention
                if dsa_k_tensor_size is not None:
                    dsa_k_tensor = torch.zeros(dsa_k_tensor_size, dtype=torch.int8, device=self.device)
                if dsa_k_scale_tensor_size is not None:
                    dsa_k_scale_tensor = torch.zeros(
                        dsa_k_scale_tensor_size, dtype=torch.int8, device=self.device
                    )
            else:
                k_tensor = torch.zeros(k_tensor_size + alignment, dtype=torch.int8, device=self.device)
                v_tensor = torch.zeros(v_tensor_size + alignment, dtype=torch.int8, device=self.device)
                k_tensor = self._align_memory(k_tensor, alignment)[:k_tensor_size]
                v_tensor = self._align_memory(v_tensor, alignment)[:v_tensor_size]
                #### k cache: for deepseek sparse attention
                if dsa_k_cache_factor is not None and dsa_k_cache_size is not None:
                    dsa_k_cache_tensor = torch.zeros(
                        dsa_k_cache_size + alignment, dtype=torch.int8, device=self.device
                #### for deepseek sparse attention
                if dsa_k_tensor_size is not None:
                    dsa_k_tensor = torch.zeros(
                        dsa_k_tensor_size + alignment, dtype=torch.int8, device=self.device
                    )
                    dsa_k_cache_tensor = self._align_memory(dsa_k_cache_tensor, alignment)[:dsa_k_cache_size]
                    dsa_k_tensor = self._align_memory(dsa_k_tensor, alignment)[:dsa_k_tensor_size]
                if dsa_k_scale_tensor_size is not None:
                    dsa_k_scale_tensor = torch.zeros(
                        dsa_k_scale_tensor_size + alignment, dtype=torch.int8, device=self.device
                    )
                    dsa_k_scale_tensor = self._align_memory(
                        dsa_k_scale_tensor, alignment
                    )[:dsa_k_scale_tensor_size]

            for layer_name_inner in kv_cache_tensor.shared_by:
                # share the attn kvcache across all shared layers
                if "attn" in layer_name_inner and "linear_attn" not in layer_name_inner:
                    kv_cache_raw_tensors[layer_name_inner] = (
                        (k_tensor, v_tensor)
                        if not self.use_sparse
                        else (k_tensor, v_tensor, dsa_k_cache_tensor)
                    )
                    if self.use_sparse:
                        if self.use_sparse_c8_indexer:
                            kv_cache_raw_tensors[layer_name_inner] = (
                                k_tensor, v_tensor, dsa_k_tensor, dsa_k_scale_tensor
                            )
                        else:
                            kv_cache_raw_tensors[layer_name_inner] = (k_tensor, v_tensor, dsa_k_tensor)
                    else:
                        kv_cache_raw_tensors[layer_name_inner] = (k_tensor, v_tensor)
        layer_names = set()
        for group in kv_cache_config.kv_cache_groups:
            for layer_name in group.layer_names:
@@ -2766,13 +2799,23 @@ class NPUModelRunner(GPUModelRunner):
            # TODO: remove this after the OOM issue is located and fixed, otherwise, some model may
            # encounter OOM issue
            if isinstance(kv_cache_spec, AttentionSpec):
                raw_dsa_k_tensor = None
                if self.use_sparse:
                    raw_k_tensor, raw_v_tensor, raw_dsa_k_tensor = kv_cache_raw_tensors[  # type: ignore
                        layer_name
                    ]
                    assert raw_dsa_k_tensor is not None
                    sum_page_size_bytes = raw_k_tensor.numel() + raw_v_tensor.numel() + raw_dsa_k_tensor.numel()
                    if self.use_sparse_c8_indexer:
                        raw_k_tensor, raw_v_tensor, raw_dsa_k_tensor, raw_dsa_k_scale_tensor = kv_cache_raw_tensors[  # type: ignore
                            layer_name]
                        assert raw_dsa_k_tensor is not None
                        assert raw_dsa_k_scale_tensor is not None
                        sum_page_size_bytes = (
                            raw_k_tensor.numel()
                            + raw_v_tensor.numel()
                            + raw_dsa_k_tensor.numel()
                            + raw_dsa_k_scale_tensor.numel()
                        )
                    else:
                        raw_k_tensor, raw_v_tensor, raw_dsa_k_tensor = kv_cache_raw_tensors[  # type: ignore
                            layer_name]
                        assert raw_dsa_k_tensor is not None
                        sum_page_size_bytes = raw_k_tensor.numel() + raw_v_tensor.numel() + raw_dsa_k_tensor.numel()
            elif self.use_hybrid_blocks and self.hybrid_with_attn_and_mamba:
                # Currently, we ensure that the same kvcache format is used even if there
                # is no shared layer, such as the full attention mtp layer of qwen3.5, etc.
@@ -2819,7 +2862,7 @@ class NPUModelRunner(GPUModelRunner):
            kv_cache_shape = self.attn_backend.get_kv_cache_shape(
                num_blocks, kv_cache_spec.block_size, kv_cache_spec.num_kv_heads, kv_cache_spec.head_size
            )
            dtype = kv_cache_spec.dtype

            if not self.model_config.use_mla:
                k_shape = kv_cache_shape[1:]
                v_shape = k_shape
@@ -2838,19 +2881,37 @@ class NPUModelRunner(GPUModelRunner):
                    num_kv_heads,
                    self.model_config.hf_text_config.qk_rope_head_dim,
                ]
                k_cache = raw_k_tensor.view(dtype).view(k_shape)
                v_cache = raw_v_tensor.view(dtype).view(v_shape)
                k_cache = raw_k_tensor.view(kv_cache_spec.dtype).view(k_shape)
                v_cache = raw_v_tensor.view(kv_cache_spec.dtype).view(v_shape)

                if self.use_sparse and raw_dsa_k_tensor is not None:
                    index_head_dim = self._get_sparse_kv_cache_ratio()[-1]
                if self.use_sparse:
                    dsa_k_cache_shape = (
                        num_blocks,
                        kv_cache_spec.block_size,
                        kv_cache_spec.num_kv_heads,
                        index_head_dim,
                        self.model_config.hf_text_config.index_head_dim,
                    )
                    dsa_k_cache = raw_dsa_k_tensor.view(dtype).view(dsa_k_cache_shape)
                    kv_caches[layer_name] = (k_cache, v_cache, dsa_k_cache)
                    if self.use_sparse_c8_indexer:
                        # dsa_k
                        dsa_k_cache = raw_dsa_k_tensor.view(self.c8_k_cache_dtype).view(dsa_k_cache_shape)
                        # dsa_k_scale
                        dsa_k_scale_cache_shape = (
                            num_blocks,
                            kv_cache_spec.block_size,
                            kv_cache_spec.num_kv_heads,
                            1,
                        )
                        assert raw_dsa_k_scale_tensor is not None
                        dsa_k_scale_cache = (
                            raw_dsa_k_scale_tensor
                            .view(self.c8_k_scale_cache_dtype)
                            .view(dsa_k_scale_cache_shape)
                        )
                        kv_caches[layer_name] = (k_cache, v_cache, dsa_k_cache, dsa_k_scale_cache)
                    else:
                        # dsa_k
                        dsa_k_cache = raw_dsa_k_tensor.view(kv_cache_spec.dtype).view(dsa_k_cache_shape)
                        kv_caches[layer_name] = (k_cache, v_cache, dsa_k_cache)
                else:
                    kv_caches[layer_name] = (k_cache, v_cache)
            elif isinstance(kv_cache_spec, MambaSpec):
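The allocation code above follows one pattern throughout: each cache entry lives in a slice of a flat int8 byte buffer that is reinterpreted to its real dtype and then reshaped. A standalone sketch of that trick (sizes are arbitrary):

```python
import torch

# One flat int8 byte allocation, as in the model runner above.
raw = torch.zeros(4096, dtype=torch.int8)

# Reinterpret the same bytes at different dtypes, then shape into cache layouts.
bf16_cache = raw[:2048].view(torch.bfloat16).view(8, 128)    # 2048 bytes -> 1024 bf16 elems
int8_cache = raw[2048:4064].view(torch.int8).view(-1, 16)    # int8 entries view 1:1
fp16_scale = raw[4064:4096].view(torch.float16).view(-1, 1)  # 32 bytes -> 16 fp16 scales
```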
@@ -3120,18 +3181,31 @@ class NPUModelRunner(GPUModelRunner):

            elif isinstance(attn_module, MLAAttention):
                if self.use_sparse:
                    # TODO(cmq): This is a hack way to fix deepseek kvcache when
                    # using DSA. Fixing the spec in vLLM is the final way.
                    sparse_sum_head_size = sum(self._get_sparse_kv_cache_ratio())
                    kv_cache_spec[layer_name] = MLAAttentionSpec(
                    # `MLAAttentionSpec` is temporarily patched to `AscendMLAAttentionSpec`.
                    # Re-importing it at runtime will therefore resolve to the patched class.
                    # Rename it here to make this behavior explicit.
                    from vllm.v1.kv_cache_interface import MLAAttentionSpec as AscendMLAAttentionSpec

                    # TODO(rjg-lyh): when kv_cache_spec's refactor is ready,
                    # implement it by creating a new kv_cache_spec class
                    kv_cache_spec[layer_name] = AscendMLAAttentionSpec(
                        block_size=self.block_size,
                        num_kv_heads=1,
                        head_size=sparse_sum_head_size,
                        head_size=sum(self.sparse_head_dim),
                        sparse_head_dim=self.sparse_head_dim,
                        dtype=self.kv_cache_dtype,
                        cache_dtype_str=self.vllm_config.cache_config.cache_dtype,
                        cache_sparse_c8=self.use_sparse_c8_indexer,
                    )
                elif spec := attn_module.get_kv_cache_spec(self.vllm_config):
                    kv_cache_spec[layer_name] = spec
                    assert isinstance(spec, MLAAttentionSpec)
                    from vllm.v1.kv_cache_interface import MLAAttentionSpec as AscendMLAAttentionSpec

                    kv_cache_spec[layer_name] = AscendMLAAttentionSpec(
                        block_size=spec.block_size,
                        num_kv_heads=spec.num_kv_heads,
                        head_size=spec.head_size,
                        dtype=spec.dtype,
                        cache_dtype_str=spec.cache_dtype_str,
                    )

            elif isinstance(attn_module, MambaBase):
                mamba_layers[layer_name] = attn_module
@@ -3149,16 +3223,6 @@ class NPUModelRunner(GPUModelRunner):

        return kv_cache_spec

    def _get_sparse_kv_cache_ratio(self) -> list[int]:
        # TODO: If C8 is supported, we need to consider the number of bytes occupied by different dtypes
        # when calculating the ratio, for example:
        # [kv_lora_rank * torch.int8.itemsize, qk_rope_head_dim * torch.bfloat16.itemsize, ...]
        return [
            self.model_config.hf_text_config.kv_lora_rank,
            self.model_config.hf_text_config.qk_rope_head_dim,
            self.model_config.hf_text_config.index_head_dim,
        ]

    def _check_and_update_cudagraph_mode(
        self,
        attention_backends: list[set[type[AttentionBackend]]],