[Perf][1/N] w8a8c8 support in dsv3.2/glm5 (#7029)

### What this PR does / why we need it?
This PR adds W8A8C8 support for dsv3.2/glm5 using the
lightning_indexer_quant op, mainly targeting the pd-mix scenario.

Because the code for the current PD-disaggregated scenario is still
under refactoring and cleanup, this PR prioritizes ensuring the C8
functionality in the pd-mix scenario.

The next steps are planned in three parts:
① Once the optimized scatter operator is updated, we will replace the
original operator to improve the performance of storing k_scale.
② Once the code logic for the PD-disaggregated scenario becomes stable,
we will carry out more comprehensive validation and make appropriate
adaptations.
③ Because enabling C8 currently introduces several new operators whose
performance still needs improvement, performance may regress in some
scenarios. Therefore, only after all the operators are fully ready can
we ensure that this feature does not cause any performance degradation.
At that point, we will enable this feature by default and remove the
switch in `additional_config`.


### Does this PR introduce _any_ user-facing change?
No.

### How was this patch tested?
CI passed with newly added and existing tests.

- vLLM version: v0.16.0
- vLLM main:
4034c3d32e

---------

Signed-off-by: rjg-lyh <1318825571@qq.com>
This commit is contained in:
rjg-lyh
2026-03-13 14:47:42 +08:00
committed by GitHub
parent df1ee8070d
commit 7ed9e9de69
24 changed files with 4279 additions and 77 deletions

View File

@@ -68,6 +68,8 @@ e2e-2card-light:
estimated_time: 220
- name: tests/e2e/multicard/2-cards/test_offline_inference_distributed.py::test_deepseek3_2_w8a8_pruning_mtp_tp2_ep
estimated_time: 90
- name: tests/e2e/multicard/2-cards/test_offline_inference_distributed.py::test_deepseek3_2_w8a8c8_pruning_mtp_tp2_ep
estimated_time: 90
- name: tests/e2e/multicard/2-cards/test_offline_inference_distributed.py::test_gpt_oss_distributed_tp2
estimated_time: 180
@@ -118,6 +120,8 @@ e2e-multicard-2-cards:
estimated_time: 71
- name: tests/e2e/multicard/2-cards/test_offline_inference_distributed.py::test_deepseek3_2_w8a8_pruning_mtp_tp2_ep
estimated_time: 111
- name: tests/e2e/multicard/2-cards/test_offline_inference_distributed.py::test_deepseek3_2_w8a8c8_pruning_mtp_tp2_ep
estimated_time: 111
- name: tests/e2e/multicard/2-cards/test_offline_inference_distributed.py::test_qwen3_w4a4_distributed_tp2
estimated_time: 180
- name: tests/e2e/multicard/2-cards/test_pipeline_parallel.py

View File

@@ -25,7 +25,7 @@ elif [[ "$SOC_VERSION" =~ ^ascend910b ]]; then
export CPATH=${ABSOLUTE_CATLASS_PATH}:${CPATH}
CUSTOM_OPS="moe_grouped_matmul;grouped_matmul_swiglu_quant_weight_nz_tensor_list;lightning_indexer_vllm;sparse_flash_attention;matmul_allreduce_add_rmsnorm;moe_init_routing_custom;moe_gating_top_k;add_rms_norm_bias;apply_top_k_top_p_custom;transpose_kv_cache_by_block;copy_and_expand_eagle_inputs;causal_conv1d;"
CUSTOM_OPS="moe_grouped_matmul;grouped_matmul_swiglu_quant_weight_nz_tensor_list;lightning_indexer_vllm;sparse_flash_attention;matmul_allreduce_add_rmsnorm;moe_init_routing_custom;moe_gating_top_k;add_rms_norm_bias;apply_top_k_top_p_custom;transpose_kv_cache_by_block;copy_and_expand_eagle_inputs;causal_conv1d;lightning_indexer_quant;"
SOC_ARG="ascend910b"
elif [[ "$SOC_VERSION" =~ ^ascend910_93 ]]; then
# ASCEND910C (A3) series
@@ -67,6 +67,7 @@ elif [[ "$SOC_VERSION" =~ ^ascend910_93 ]]; then
"copy_and_expand_eagle_inputs"
"causal_conv1d"
"moe_grouped_matmul"
"lightning_indexer_quant"
)
CUSTOM_OPS=$(IFS=';'; echo "${CUSTOM_OPS_ARRAY[*]}")
SOC_ARG="ascend910_93"

View File

@@ -0,0 +1,81 @@
/*
* Copyright (c) Huawei Technologies Co., Ltd. 2026. All rights reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef LIGHTING_INDEXER_QUANT_VLLM_TORCH_ADPT_H
#define LIGHTING_INDEXER_QUANT_VLLM_TORCH_ADPT_H
namespace vllm_ascend {
/**
 * Torch-facing adapter that launches the aclnnLightningIndexerQuant kernel.
 *
 * Checks that every dimension of `query` and `key` is positive, that
 * `sparse_count` is positive, and that both tensors have enough dimensions
 * for the axes read below. It then allocates an int32 output tensor using
 * query's tensor options and forwards all arguments through EXEC_NPU_CMD.
 *
 * Output shape:
 *   - layout_query == "BSND": {query.size(0), query.size(1), keyHeadNum, sparse_count}
 *   - any other layout_query: {query.size(0), keyHeadNum, sparse_count}
 * where keyHeadNum is key.size(1) when layout_key == "TND" and key.size(2)
 * otherwise.
 *
 * @return the int32 tensor filled by the kernel.
 */
at::Tensor npu_lightning_indexer_quant(
    const at::Tensor &query, const at::Tensor &key, const at::Tensor &weights,
    const at::Tensor &query_dequant_scale, const at::Tensor &key_dequant_scale,
    const c10::optional<at::Tensor> &actual_seq_lengths_query,
    const c10::optional<at::Tensor> &actual_seq_lengths_key,
    const c10::optional<at::Tensor> &block_table, int64_t query_quant_mode, int64_t key_quant_mode,
    c10::string_view layout_query, c10::string_view layout_key, int64_t sparse_count, int64_t sparse_mode)
{
    std::string query_layout_str = std::string(layout_query);
    std::string key_layout_str = std::string(layout_key);
    constexpr int SIZE = 8;
    constexpr int64_t DIM_0 = 0;
    constexpr int64_t DIM_1 = 1;
    constexpr int64_t DIM_2 = 2;

    for (size_t i = 0; i < query.sizes().size(); i++) {
        TORCH_CHECK(query.size(i) > 0, "All values within query's shape should be greater "
            "than 0, but shape[", i, "] is ", query.size(i));
    }
    for (size_t i = 0; i < key.sizes().size(); i++) {
        TORCH_CHECK(key.size(i) > 0, "All values within key's shape should be greater "
            "than 0, but shape[", i, "] is ", key.size(i));
    }
    TORCH_CHECK(sparse_count > 0, "sparse count should be greater than 0, but now is ", sparse_count);

    // Make sure the axes indexed below exist, so a wrong-rank tensor yields a
    // descriptive error instead of a generic out-of-range index failure.
    const bool query_is_bsnd = (query_layout_str == "BSND");
    const int64_t query_min_rank = query_is_bsnd ? (DIM_1 + 1) : (DIM_0 + 1);
    TORCH_CHECK(query.dim() >= query_min_rank, "query with layout ", query_layout_str,
        " should have at least ", query_min_rank, " dims, but got ", query.dim());
    const int64_t key_head_axis = (key_layout_str == "TND") ? DIM_1 : DIM_2;
    TORCH_CHECK(key.dim() > key_head_axis, "key with layout ", key_layout_str,
        " should have more than ", key_head_axis, " dims, but got ", key.dim());

    int64_t keyHeadNum = key.size(key_head_axis);
    at::SmallVector<int64_t, SIZE> output_size;
    if (query_is_bsnd) {
        output_size = {query.size(DIM_0), query.size(DIM_1), keyHeadNum, sparse_count};
    } else {
        output_size = {query.size(DIM_0), keyHeadNum, sparse_count};
    }
    at::Tensor lightning_indexer_quant_output = at::empty(output_size, query.options().dtype(at::kInt));

    // The layout strings are handed to the kernel launcher as mutable char pointers;
    // the std::string objects above keep the storage alive for the call.
    char *query_layout_ptr = const_cast<char *>(query_layout_str.c_str());
    char *key_layout_ptr = const_cast<char *>(key_layout_str.c_str());
    EXEC_NPU_CMD(aclnnLightningIndexerQuant,
        query,
        key,
        weights,
        query_dequant_scale,
        key_dequant_scale,
        actual_seq_lengths_query,
        actual_seq_lengths_key,
        block_table,
        query_quant_mode,
        key_quant_mode,
        query_layout_ptr,
        key_layout_ptr,
        sparse_count,
        sparse_mode,
        lightning_indexer_quant_output
    );
    return lightning_indexer_quant_output;
}
}
#endif

View File

@@ -0,0 +1,41 @@
# This program is free software, you can redistribute it and/or modify it.
# Copyright (c) 2025 Huawei Technologies Co., Ltd.
# This file is a part of the CANN Open Software.
# Licensed under CANN Open Software License Agreement Version 2.0 (the "License").
# Please refer to the License for details. You may not use this file except in compliance with the License.
# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE.
# See LICENSE in the root of the software repository for the full text of the License.
# ======================================================================================================================
# Kernel compile options applied to the LightningIndexerQuant op.
add_ops_compile_options(
OP_NAME LightningIndexerQuant
OPTIONS --cce-auto-sync=off
-Wno-deprecated-declarations
-Werror
-mllvm -cce-aicore-hoist-movemask=false
--op_relocatable_kernel_binary=true
)
# Export the op's dependency path to the parent scope.
set(lightning_indexer_quant_depends transformer/attention/lightning_indexer_quant PARENT_SCOPE)
# Op definition (inputs/outputs/attrs) goes into the aclnn host target.
target_sources(op_host_aclnn PRIVATE
lightning_indexer_quant_def.cpp
)
# Tiling implementation goes into the optiling target.
target_sources(optiling PRIVATE
lightning_indexer_quant_tiling.cpp
)
# The tiling source is additionally added to opmaster_ct when not building the open project.
if (NOT BUILD_OPEN_PROJECT)
target_sources(opmaster_ct PRIVATE
lightning_indexer_quant_tiling.cpp
)
endif ()
target_include_directories(optiling PRIVATE
${CMAKE_CURRENT_SOURCE_DIR}/op_host
)
# Shape/dtype inference goes into the opsproto target.
target_sources(opsproto PRIVATE
lightning_indexer_quant_proto.cpp
)

View File

@@ -0,0 +1,85 @@
/**
* This program is free software, you can redistribute it and/or modify it.
* Copyright (c) 2025 Huawei Technologies Co., Ltd.
* This file is a part of the CANN Open Software.
* Licensed under CANN Open Software License Agreement Version 2.0 (the "License").
* Please refer to the License for details. You may not use this file except in compliance with the License.
* THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE.
* See LICENSE in the root of the software repository for the full text of the License.
*/
/*!
* \file lightning_indexer_quant_def.cpp
* \brief
*/
#include <cstdint>
#include "register/op_def_registry.h"
namespace ops {
// Op definition for LightningIndexerQuant.
// Registers five required inputs (int8 query/key, float16 weights and dequant scales),
// three optional int32 inputs, one int32 output, six attributes, and an AI Core
// configuration for the ascend910b and ascend910_93 targets.
// NOTE(review): the infershape and tiling code address inputs and attributes by
// position; presumably the registration order below defines those positions, so
// keep the order stable — confirm before reordering or inserting entries.
class LightningIndexerQuant : public OpDef {
public:
explicit LightningIndexerQuant(const char *name) : OpDef(name)
{
// Required inputs.
this->Input("query")
.ParamType(REQUIRED)
.DataType({ge::DT_INT8})
.Format({ge::FORMAT_ND})
.AutoContiguous();
this->Input("key")
.ParamType(REQUIRED)
.DataType({ge::DT_INT8})
.Format({ge::FORMAT_ND})
.AutoContiguous();
this->Input("weights")
.ParamType(REQUIRED)
.DataType({ge::DT_FLOAT16})
.Format({ge::FORMAT_ND})
.AutoContiguous();
this->Input("query_dequant_scale")
.ParamType(REQUIRED)
.DataType({ge::DT_FLOAT16})
.Format({ge::FORMAT_ND})
.AutoContiguous();
this->Input("key_dequant_scale")
.ParamType(REQUIRED)
.DataType({ge::DT_FLOAT16})
.Format({ge::FORMAT_ND})
.AutoContiguous();
// Optional inputs.
this->Input("actual_seq_lengths_query")
.ParamType(OPTIONAL)
.DataType({ge::DT_INT32})
.Format({ge::FORMAT_ND})
.AutoContiguous();
this->Input("actual_seq_lengths_key")
.ParamType(OPTIONAL)
.DataType({ge::DT_INT32})
.Format({ge::FORMAT_ND})
.AutoContiguous();
this->Input("block_table")
.ParamType(OPTIONAL)
.DataType({ge::DT_INT32})
.Format({ge::FORMAT_ND})
.AutoContiguous();
// Single int32 output.
this->Output("sparse_indices").ParamType(REQUIRED).DataType({ge::DT_INT32}).Format({ge::FORMAT_ND});
// Attributes.
this->Attr("query_quant_mode").AttrType(REQUIRED).Int(0); // 0: default, per-token-head
this->Attr("key_quant_mode").AttrType(REQUIRED).Int(0); // 0: default, per-token-head
this->Attr("layout_query").AttrType(OPTIONAL).String("BSND");
this->Attr("layout_key").AttrType(OPTIONAL).String("PA_BSND");
this->Attr("sparse_count").AttrType(OPTIONAL).Int(2048); // 2048: default, select the top 2048
this->Attr("sparse_mode").AttrType(OPTIONAL).Int(3); // 3: default, compute the lower triangle only
// The same AI Core configuration is shared by both supported SoC targets.
OpAICoreConfig aicore_config;
aicore_config.DynamicCompileStaticFlag(true)
.DynamicFormatFlag(true)
.DynamicRankSupportFlag(true)
.DynamicShapeSupportFlag(true)
.NeedCheckSupportFlag(false)
.PrecisionReduceFlag(true)
.ExtendCfgInfo("aclnnSupport.value", "support_aclnn")
.ExtendCfgInfo("jitCompile.flag", "static_false,dynamic_false");
this->AICore().AddConfig("ascend910b", aicore_config);
this->AICore().AddConfig("ascend910_93", aicore_config);
}
};
// Register the op definition.
OP_ADD(LightningIndexerQuant);
} // namespace ops

View File

@@ -0,0 +1,91 @@
/**
* This program is free software, you can redistribute it and/or modify it.
* Copyright (c) 2025 Huawei Technologies Co., Ltd.
* This file is a part of the CANN Open Software.
* Licensed under CANN Open Software License Agreement Version 2.0 (the "License").
* Please refer to the License for details. You may not use this file except in compliance with the License.
* THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE.
* See LICENSE in the root of the software repository for the full text of the License.
*/
/*!
* \file lightning_indexer_quant_proto.cpp
* \brief
*/
#include <graph/utils/type_utils.h>
#include <register/op_impl_registry.h>
#include "error/ops_error.h"
using namespace ge;
namespace ops {
// Positional indices used by InferShapeLightningIndexerQuant when querying the context.
// NOTE(review): presumably these must track the input/attribute registration order of
// the op definition — confirm whenever inputs or attributes are added or reordered.
// Input positions (passed to GetInputShape).
constexpr uint32_t QUERY_INDEX = 0;
constexpr uint32_t KEY_INDEX = 1;
// Attribute positions (passed to GetAttrPointer / GetInt).
constexpr uint32_t ATTR_QUERY_LAYOUT_INDEX = 2;
constexpr uint32_t ATTR_KV_LAYOUT_INDEX = 3;
constexpr uint32_t ATTR_SPARSE_COUNT_INDEX = 4;
/**
 * Infers the output shape of LightningIndexerQuant.
 *
 * The output has the same number of dims as query:
 *   - layout_query == "BSND": {query[0], query[1], keyHeadNum, sparse_count}
 *   - layout_query == "TND":  {query[0], keyHeadNum, sparse_count}
 * where keyHeadNum is key dim 1 when layout_key == "TND" and key dim 2 otherwise.
 *
 * @return GRAPH_FAILED when the context, any required shape/attribute pointer,
 *         or the output shape is null, or when layout_query is neither "TND"
 *         nor "BSND"; GRAPH_SUCCESS otherwise.
 */
static ge::graphStatus InferShapeLightningIndexerQuant(gert::InferShapeContext *context)
{
    if (context == nullptr) {
        OPS_LOG_E("LightningIndexerQuant", "context is nullptr!");
        return ge::GRAPH_FAILED;
    }
    const gert::Shape *queryShape = context->GetInputShape(QUERY_INDEX);
    OPS_LOG_E_IF_NULL(context, queryShape, return ge::GRAPH_FAILED);
    const gert::Shape *keyShape = context->GetInputShape(KEY_INDEX);
    OPS_LOG_E_IF_NULL(context, keyShape, return ge::GRAPH_FAILED);
    gert::Shape *outShape = context->GetOutputShape(0);
    // The output shape is written below, so it needs the same null guard as the inputs.
    OPS_LOG_E_IF_NULL(context, outShape, return ge::GRAPH_FAILED);
    auto attrs = context->GetAttrs();
    OPS_LOG_E_IF_NULL(context, attrs, return ge::GRAPH_FAILED);
    const char *inputLayoutQueryPtr = attrs->GetAttrPointer<char>(ATTR_QUERY_LAYOUT_INDEX);
    OPS_LOG_E_IF_NULL(context, inputLayoutQueryPtr, return ge::GRAPH_FAILED);
    const char *inputLayoutKeyPtr = attrs->GetAttrPointer<char>(ATTR_KV_LAYOUT_INDEX);
    OPS_LOG_E_IF_NULL(context, inputLayoutKeyPtr, return ge::GRAPH_FAILED);
    const int64_t *sparse_count = attrs->GetInt(ATTR_SPARSE_COUNT_INDEX);
    OPS_LOG_E_IF_NULL(context, sparse_count, return ge::GRAPH_FAILED);

    std::string inputLayoutQueryPtrStr = std::string(inputLayoutQueryPtr);
    std::string inputLayoutKeyPtrStr = std::string(inputLayoutKeyPtr);
    if (inputLayoutQueryPtrStr != "TND" && inputLayoutQueryPtrStr != "BSND") {
        OPS_LOG_E(context, "The input layout query should be TND or BSND, but got %s.", inputLayoutQueryPtrStr.c_str());
        return GRAPH_FAILED;
    }

    outShape->SetDimNum(queryShape->GetDimNum());
    // The key head-count axis is dim 1 for a TND key and dim 2 for every other key layout.
    int64_t keyHeadNum = (inputLayoutKeyPtrStr == "TND") ? keyShape->GetDim(1) : keyShape->GetDim(2);
    if (inputLayoutQueryPtrStr == "BSND") {
        outShape->SetDim(0, queryShape->GetDim(0)); // 0:Dim B
        outShape->SetDim(1, queryShape->GetDim(1)); // 1:Dim S
        outShape->SetDim(2, keyHeadNum);            // 2:Dim N
        outShape->SetDim(3, *sparse_count);         // 3:Dim K
    } else {
        outShape->SetDim(0, queryShape->GetDim(0)); // 0:Dim T
        outShape->SetDim(1, keyHeadNum);            // 1:Dim N, taken from key
        outShape->SetDim(2, *sparse_count);         // 2:Dim K
    }
    OPS_LOG_D(context->GetNodeName(), "LightningIndexerQuant InferShape end.");
    return ge::GRAPH_SUCCESS;
}
/**
 * Sets the data type of output 0 to int32.
 *
 * @return GRAPH_FAILED when context is null, GRAPH_SUCCESS otherwise.
 */
static ge::graphStatus InferDataTypeLightningIndexerQuant(gert::InferDataTypeContext *context)
{
    if (context != nullptr) {
        OPS_LOG_D(context->GetNodeName(), "Enter LightningIndexerQuant InferDataType impl.");
        // The op emits indices, which are always int32.
        context->SetOutputDataType(0, ge::DT_INT32);
        OPS_LOG_D(context->GetNodeName(), "LightningIndexerQuant InferDataType end.");
        return ge::GRAPH_SUCCESS;
    }
    OPS_LOG_E("LightningIndexerQuant", "InferDataTypeContext context is nullptr!");
    return ge::GRAPH_FAILED;
}
// Attach the shape and data-type inference functions to the LightningIndexerQuant op.
IMPL_OP_INFERSHAPE(LightningIndexerQuant)
.InferShape(InferShapeLightningIndexerQuant)
.InferDataType(InferDataTypeLightningIndexerQuant);
} // namespace ops

View File

@@ -0,0 +1,828 @@
/**
* This program is free software, you can redistribute it and/or modify it.
* Copyright (c) 2025 Huawei Technologies Co., Ltd.
* This file is a part of the CANN Open Software.
* Licensed under CANN Open Software License Agreement Version 2.0 (the "License").
* Please refer to the License for details. You may not use this file except in compliance with the License.
* THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE.
* See LICENSE in the root of the software repository for the full text of the License.
*/
/*!
* \file lightning_indexer_quant_tiling.cpp
* \brief
*/
#include "lightning_indexer_quant_tiling.h"
#include "../op_kernel/lightning_indexer_quant_template_tiling_key.h"
using namespace ge;
using namespace AscendC;
using std::map;
using std::string;
namespace optiling {
// -------------------------- LIQInfoParser member function definitions -------------------------------------
// Verifies that the shape and desc pointers captured in opParamInfo_ are non-null for
// query, key, weights, query_dequant_scale, key_dequant_scale and the output.
// The checks run in that order; each failing check logs its own message and returns
// ge::GRAPH_FAILED (presumably OPS_ERR_IF runs the log and the return statement when
// its condition holds — confirm against the macro definition).
// Returns ge::GRAPH_SUCCESS when all pointers are set.
ge::graphStatus LIQInfoParser::CheckRequiredInOutExistence() const
{
OPS_ERR_IF(opParamInfo_.query.shape == nullptr, OPS_LOG_E(opName_, "Shape of tensor query is nullptr"),
return ge::GRAPH_FAILED);
OPS_ERR_IF(opParamInfo_.query.desc == nullptr, OPS_LOG_E(opName_, "Desc of tensor query is nullptr"),
return ge::GRAPH_FAILED);
OPS_ERR_IF(opParamInfo_.key.shape == nullptr, OPS_LOG_E(opName_, "Shape of tensor key is nullptr"),
return ge::GRAPH_FAILED);
OPS_ERR_IF(opParamInfo_.key.desc == nullptr, OPS_LOG_E(opName_, "Desc of tensor key is nullptr"),
return ge::GRAPH_FAILED);
OPS_ERR_IF(opParamInfo_.weights.shape == nullptr, OPS_LOG_E(opName_, "Shape of tensor weights is nullptr"),
return ge::GRAPH_FAILED);
OPS_ERR_IF(opParamInfo_.weights.desc == nullptr, OPS_LOG_E(opName_, "Desc of tensor weights is nullptr"),
return ge::GRAPH_FAILED);
OPS_ERR_IF(opParamInfo_.query_dequant_scale.shape == nullptr,
OPS_LOG_E(opName_, "Shape of tensor query_dequant_scale is nullptr"), return ge::GRAPH_FAILED);
OPS_ERR_IF(opParamInfo_.query_dequant_scale.desc == nullptr,
OPS_LOG_E(opName_, "Desc of tensor query_dequant_scale is nullptr"), return ge::GRAPH_FAILED);
OPS_ERR_IF(opParamInfo_.key_dequant_scale.shape == nullptr,
OPS_LOG_E(opName_, "Shape of tensor key_dequant_scale is nullptr"), return ge::GRAPH_FAILED);
OPS_ERR_IF(opParamInfo_.key_dequant_scale.desc == nullptr,
OPS_LOG_E(opName_, "Desc of tensor key_dequant_scale is nullptr"), return ge::GRAPH_FAILED);
// attenOut holds the op's output tensor info.
OPS_ERR_IF(opParamInfo_.attenOut.shape == nullptr, OPS_LOG_E(opName_, "Shape of tensor output is nullptr"),
return ge::GRAPH_FAILED);
OPS_ERR_IF(opParamInfo_.attenOut.desc == nullptr, OPS_LOG_E(opName_, "Desc of tensor output is nullptr"),
return ge::GRAPH_FAILED);
return ge::GRAPH_SUCCESS;
}
// Verifies that every attribute pointer captured in opParamInfo_ is non-null:
// layout_query, layout_key, sparse_count, sparse_mode, query_quant_mode and key_quant_mode.
// Each failing check logs its own message and returns ge::GRAPH_FAILED;
// returns ge::GRAPH_SUCCESS when all attributes are present.
ge::graphStatus LIQInfoParser::CheckRequiredAttrExistence() const
{
OPS_ERR_IF(opParamInfo_.layOutQuery == nullptr, OPS_LOG_E(opName_, "attr layout_query is nullptr"),
return ge::GRAPH_FAILED);
OPS_ERR_IF(opParamInfo_.layOutKey == nullptr, OPS_LOG_E(opName_, "attr layout_key is nullptr"),
return ge::GRAPH_FAILED);
OPS_ERR_IF(opParamInfo_.sparseCount == nullptr, OPS_LOG_E(opName_, "attr sparse_count is nullptr"),
return ge::GRAPH_FAILED);
OPS_ERR_IF(opParamInfo_.sparseMode == nullptr, OPS_LOG_E(opName_, "attr sparse_mode is nullptr"),
return ge::GRAPH_FAILED);
OPS_ERR_IF(opParamInfo_.queryQuantMode == nullptr, OPS_LOG_E(opName_, "query_quant_mode is nullptr"),
return ge::GRAPH_FAILED);
OPS_ERR_IF(opParamInfo_.keyQuantMode == nullptr, OPS_LOG_E(opName_, "key_quant_mode is nullptr"),
return ge::GRAPH_FAILED);
return ge::GRAPH_SUCCESS;
}
// Runs the required tensor check followed by the required attribute check.
// The attribute check is skipped when the tensor check already failed.
ge::graphStatus LIQInfoParser::CheckRequiredParaExistence() const
{
    if (CheckRequiredInOutExistence() != ge::GRAPH_SUCCESS) {
        return ge::GRAPH_FAILED;
    }
    const bool attrsPresent = (CheckRequiredAttrExistence() == ge::GRAPH_SUCCESS);
    return attrsPresent ? ge::GRAPH_SUCCESS : ge::GRAPH_FAILED;
}
// Stores the node name reported by the tiling context in opName_.
// Fails with ge::GRAPH_FAILED when the context reports no name.
ge::graphStatus LIQInfoParser::GetOpName()
{
    const auto nodeName = context_->GetNodeName();
    if (nodeName != nullptr) {
        opName_ = nodeName;
        return ge::GRAPH_SUCCESS;
    }
    OPS_LOG_E("LightningIndexerQuant", "opName got from TilingContext is nullptr");
    return ge::GRAPH_FAILED;
}
// Reads platform information from the tiling context and validates it:
//   - platform info must be available,
//   - both AIV and AIC core counts must be non-zero,
//   - the SoC version must be ASCEND910B or ASCEND910_93 (stored in socVersion_),
//   - workspace sizes and raw tiling data must be obtainable from the context.
// Returns ge::GRAPH_FAILED on the first failing check.
ge::graphStatus LIQInfoParser::GetNpuInfo()
{
    platformInfo_ = context_->GetPlatformInfo();
    OPS_ERR_IF(platformInfo_ == nullptr, OPS_LOG_E(opName_, "GetPlatformInfo is nullptr."), return ge::GRAPH_FAILED);

    auto platform = platform_ascendc::PlatformAscendC(platformInfo_);
    const uint32_t aivCoreCount = platform.GetCoreNumAiv();
    const uint32_t aicCoreCount = platform.GetCoreNumAic();
    OPS_ERR_IF(aicCoreCount == 0 || aivCoreCount == 0, OPS_LOG_E(opName_, "num of core obtained is 0."),
               return ge::GRAPH_FAILED);

    socVersion_ = platform.GetSocVersion();
    const bool socSupported = (socVersion_ == platform_ascendc::SocVersion::ASCEND910B) ||
                              (socVersion_ == platform_ascendc::SocVersion::ASCEND910_93);
    if (!socSupported) {
        OPS_LOG_E(opName_, "SOC Version[%d] is not support.", static_cast<int32_t>(socVersion_));
        return ge::GRAPH_FAILED;
    }

    OPS_ERR_IF(context_->GetWorkspaceSizes(1) == nullptr, OPS_LOG_E(opName_, "workSpaceSize got from ge is nullptr"),
               return ge::GRAPH_FAILED);
    OPS_ERR_IF(context_->GetRawTilingData() == nullptr,
               OPS_LOG_E(context_->GetNodeName(), "RawTilingData got from GE context is nullptr."),
               return ge::GRAPH_FAILED);
    return ge::GRAPH_SUCCESS;
}
// Captures the tensor and desc pointers of the three optional inputs
// (actual_seq_lengths_query, actual_seq_lengths_key, block_table) into opParamInfo_.
// No validation is done here.
void LIQInfoParser::GetOptionalInputParaInfo()
{
    auto &seqLenQuery = opParamInfo_.actualSeqLengthsQ;
    seqLenQuery.tensor = context_->GetOptionalInputTensor(ACTUAL_SEQ_Q_INDEX);
    seqLenQuery.desc = context_->GetOptionalInputDesc(ACTUAL_SEQ_Q_INDEX);

    auto &seqLenKey = opParamInfo_.actualSeqLengthsK;
    seqLenKey.tensor = context_->GetOptionalInputTensor(ACTUAL_SEQ_K_INDEX);
    seqLenKey.desc = context_->GetOptionalInputDesc(ACTUAL_SEQ_K_INDEX);

    auto &blockTableInfo = opParamInfo_.blockTable;
    blockTableInfo.tensor = context_->GetOptionalInputTensor(BLOCK_TABLE_INDEX);
    blockTableInfo.desc = context_->GetOptionalInputDesc(BLOCK_TABLE_INDEX);
}
// Captures the desc and shape pointers of the five required inputs into opParamInfo_,
// then captures the optional inputs. Null pointers are not rejected here.
void LIQInfoParser::GetInputParaInfo()
{
    auto &queryInfo = opParamInfo_.query;
    queryInfo.desc = context_->GetInputDesc(QUERY_INDEX);
    queryInfo.shape = context_->GetInputShape(QUERY_INDEX);

    auto &keyInfo = opParamInfo_.key;
    keyInfo.desc = context_->GetInputDesc(KEY_INDEX);
    keyInfo.shape = context_->GetInputShape(KEY_INDEX);

    auto &weightsInfo = opParamInfo_.weights;
    weightsInfo.desc = context_->GetInputDesc(WEIGTHS_INDEX);
    weightsInfo.shape = context_->GetInputShape(WEIGTHS_INDEX);

    auto &queryScaleInfo = opParamInfo_.query_dequant_scale;
    queryScaleInfo.desc = context_->GetInputDesc(QUERY_DEQUANT_SCALE_INDEX);
    queryScaleInfo.shape = context_->GetInputShape(QUERY_DEQUANT_SCALE_INDEX);

    auto &keyScaleInfo = opParamInfo_.key_dequant_scale;
    keyScaleInfo.desc = context_->GetInputDesc(KEY_DEQUANT_SCALE_INDEX);
    keyScaleInfo.shape = context_->GetInputShape(KEY_DEQUANT_SCALE_INDEX);

    GetOptionalInputParaInfo();
}
// Captures the desc and shape pointers of the op's output into opParamInfo_.attenOut.
void LIQInfoParser::GetOutputParaInfo()
{
    auto &outputInfo = opParamInfo_.attenOut;
    outputInfo.desc = context_->GetOutputDesc(LIGHTNING_INDEXER_QUANT);
    outputInfo.shape = context_->GetOutputShape(LIGHTNING_INDEXER_QUANT);
}
// Reads the op attributes from the tiling context into opParamInfo_ and logs every
// attribute that is present. Each attribute is fetched exactly once.
// Returns ge::GRAPH_FAILED only when the attribute container itself is null;
// individual attribute pointers may still be null afterwards and are validated elsewhere.
// NOTE(review): integer attributes are read through GetAttrPointer<int32_t>, while the
// infershape code reads sparse_count through GetInt (int64) — confirm the 32-bit view is intended.
ge::graphStatus LIQInfoParser::GetAttrParaInfo()
{
    auto attrs = context_->GetAttrs();
    OPS_ERR_IF(attrs == nullptr, OPS_REPORT_VECTOR_INNER_ERR(context_->GetNodeName(), "attrs got from ge is nullptr"),
               return ge::GRAPH_FAILED);
    OPS_LOG_I(context_->GetNodeName(), "GetAttrParaInfo start");

    opParamInfo_.queryQuantMode = attrs->GetAttrPointer<int32_t>(ATTR_QUERY_QUANT_MODE_INDEX);
    opParamInfo_.keyQuantMode = attrs->GetAttrPointer<int32_t>(ATTR_KEY_QUANT_MODE_INDEX);
    opParamInfo_.layOutQuery = attrs->GetStr(ATTR_QUERY_LAYOUT_INDEX);
    opParamInfo_.layOutKey = attrs->GetStr(ATTR_KEY_LAYOUT_INDEX);
    opParamInfo_.sparseCount = attrs->GetAttrPointer<int32_t>(ATTR_SPARSE_COUNT_INDEX);
    opParamInfo_.sparseMode = attrs->GetAttrPointer<int32_t>(ATTR_SPARSE_MODE_INDEX);

    if (opParamInfo_.layOutQuery != nullptr) {
        OPS_LOG_I(context_->GetNodeName(), "layout_query is:%s", opParamInfo_.layOutQuery);
    }
    if (opParamInfo_.layOutKey != nullptr) {
        OPS_LOG_I(context_->GetNodeName(), "layout_key is:%s", opParamInfo_.layOutKey);
    }
    if (opParamInfo_.sparseCount != nullptr) {
        OPS_LOG_I(context_->GetNodeName(), "selected count is:%d", *opParamInfo_.sparseCount);
    }
    if (opParamInfo_.sparseMode != nullptr) {
        OPS_LOG_I(context_->GetNodeName(), "sparse mode is:%d", *opParamInfo_.sparseMode);
    }
    if (opParamInfo_.queryQuantMode != nullptr) {
        OPS_LOG_I(context_->GetNodeName(), "query_quant_mode mode is:%d", *opParamInfo_.queryQuantMode);
    }
    if (opParamInfo_.keyQuantMode != nullptr) {
        OPS_LOG_I(context_->GetNodeName(), "key_quant_mode mode is:%d", *opParamInfo_.keyQuantMode);
    }
    OPS_LOG_I(context_->GetNodeName(), "GetAttrParaInfo end");
    return ge::GRAPH_SUCCESS;
}
// Validates the attribute values captured by GetAttrParaInfo:
//   - layout_key must not be "BNSD" or "PA_BBND",
//   - layout_query must be "BSND" or "TND",
//   - unless layout_key is "PA_BSND", layout_query and layout_key must be equal,
//   - sparse_count must be in (0, SPARSE_LIMIT],
//   - sparse_mode must be 0 or SPARSE_MODE_LOWER,
//   - query_quant_mode and key_quant_mode must be 0.
// Returns ge::GRAPH_FAILED on the first violated rule, or when an attribute is missing.
ge::graphStatus LIQInfoParser::CheckAttrParaInfo()
{
    // Every attribute pointer is dereferenced below, so reject missing attributes first
    // instead of constructing std::string from nullptr or reading through a null pointer.
    if (CheckRequiredAttrExistence() != ge::GRAPH_SUCCESS) {
        return ge::GRAPH_FAILED;
    }

    const std::string layout_key(opParamInfo_.layOutKey);
    const std::string layout_query(opParamInfo_.layOutQuery);

    OPS_ERR_IF(((layout_key == "BNSD") || (layout_key == "PA_BBND")),
               OPS_LOG_E(opName_, "input attr layout_key only supported PA_BSND, BSND or TND, "
                                  "but now layout_key is %s.", layout_key.c_str()),
               return ge::GRAPH_FAILED);
    OPS_ERR_IF(((layout_query != "BSND") && (layout_query != "TND")),
               OPS_LOG_E(opName_, "input attr layout_query only supported BSND or TND."), return ge::GRAPH_FAILED);
    OPS_ERR_IF(
        ((layout_key != "PA_BSND") && (layout_query != layout_key)),
        OPS_LOG_E(opName_, "outside of PA, input attr layout_query and input attr layout_key must be the same, but now layout_key is %s, layout_query is %s.",
                  layout_key.c_str(), layout_query.c_str()), return ge::GRAPH_FAILED);
    OPS_ERR_IF(!((*opParamInfo_.sparseCount > 0) && (*opParamInfo_.sparseCount <= SPARSE_LIMIT)),
               OPS_LOG_E(opName_, "input attr sparse_count must > 0 and <= 2048."), return ge::GRAPH_FAILED);
    // sparse_mode is read as a signed 32-bit value, so it is printed with %d.
    OPS_ERR_IF(!((*opParamInfo_.sparseMode == 0) || (*opParamInfo_.sparseMode == SPARSE_MODE_LOWER)),
               OPS_LOG_E(opName_, "input attr sparse_mode only supported 0 or 3, but now is %d.",
                         *opParamInfo_.sparseMode), return ge::GRAPH_FAILED);
    OPS_ERR_IF(*opParamInfo_.queryQuantMode != 0, OPS_LOG_E(opName_, "input attr query_quant_mode only supported 0."),
               return ge::GRAPH_FAILED);
    OPS_ERR_IF(*opParamInfo_.keyQuantMode != 0, OPS_LOG_E(opName_, "input attr key_quant_mode only supported 0."),
               return ge::GRAPH_FAILED);
    return ge::GRAPH_SUCCESS;
}
// Captures input, output and attribute information from the context, then validates
// the attribute values. Attribute validation only runs when attribute capture succeeded.
ge::graphStatus LIQInfoParser::GetOpParaInfo()
{
    GetInputParaInfo();
    GetOutputParaInfo();
    const bool attrsValid =
        (GetAttrParaInfo() == ge::GRAPH_SUCCESS) && (CheckAttrParaInfo() == ge::GRAPH_SUCCESS);
    return attrsValid ? ge::GRAPH_SUCCESS : ge::GRAPH_FAILED;
}
// Reads the data types of query, key, weights, both dequant scales and the output from
// their descs, stores them in the corresponding members, and enforces:
//   - query and key share one type, which must be int8,
//   - both dequant scales share one type, which must be float16,
//   - weights are float16,
//   - the output is int32.
// Returns ge::GRAPH_FAILED on the first violated rule.
// NOTE(review): the desc pointers are dereferenced without a null check here; this
// presumably relies on CheckRequiredParaExistence having run first — confirm.
ge::graphStatus LIQInfoParser::GetAndCheckInOutDataType()
{
    inputQType_ = opParamInfo_.query.desc->GetDataType();
    inputKType_ = opParamInfo_.key.desc->GetDataType();
    weightsType_ = opParamInfo_.weights.desc->GetDataType();
    inputQueryScaleType_ = opParamInfo_.query_dequant_scale.desc->GetDataType();
    inputKeyScaleType_ = opParamInfo_.key_dequant_scale.desc->GetDataType();
    outputType_ = opParamInfo_.attenOut.desc->GetDataType();

    // The data types are enum values, so they are logged numerically; passing them to a
    // %s conversion would be undefined behaviour.
    OPS_ERR_IF(!(inputQType_ == inputKType_),
               OPS_LOG_E(opName_, "The data types of the input query and key must be the same, but now is %d, %d respectively.",
                         static_cast<int32_t>(inputQType_), static_cast<int32_t>(inputKType_)),
               return ge::GRAPH_FAILED);
    OPS_ERR_IF(
        !(inputQueryScaleType_ == inputKeyScaleType_),
        OPS_LOG_E(opName_, "The data types of the input query_dequant_scale and key_dequant_scale must be the same."),
        return ge::GRAPH_FAILED);
    OPS_ERR_IF(inputQType_ != ge::DT_INT8,
               OPS_LOG_E(opName_, "The data types of the input query and key must be int8."), return ge::GRAPH_FAILED);
    OPS_ERR_IF(weightsType_ != ge::DT_FLOAT16,
               OPS_LOG_E(opName_, "The data types of the input weights must be float16."), return ge::GRAPH_FAILED);
    OPS_ERR_IF(
        inputQueryScaleType_ != ge::DT_FLOAT16,
        OPS_LOG_E(opName_, "The data types of the input query_dequant_scale and key_dequant_scale must be float16."),
        return ge::GRAPH_FAILED);
    OPS_ERR_IF(outputType_ != ge::DT_INT32,
               OPS_LOG_E(opName_, "The data types of the output sparse_indices must be int32."),
               return ge::GRAPH_FAILED);
    return ge::GRAPH_SUCCESS;
}
// Translates the layout_query / layout_key attribute strings into DataLayout values.
// Recognised query layouts: "BSND", "TND".
// Recognised key layouts: "BSND", "TND", "PA_BSND", and "PA_BBND" (stored as PA_BSND).
// An unrecognised string leaves the corresponding member untouched.
// Always returns ge::GRAPH_SUCCESS.
ge::graphStatus LIQInfoParser::GetQueryKeyAndOutLayout()
{
    // Resolve the query layout.
    const std::string queryLayoutName(opParamInfo_.layOutQuery);
    if (queryLayoutName == "BSND") {
        qLayout_ = DataLayout::BSND;
    } else if (queryLayoutName == "TND") {
        qLayout_ = DataLayout::TND;
    }

    // Resolve the key layout.
    const std::string keyLayoutName(opParamInfo_.layOutKey);
    if (keyLayoutName == "BSND") {
        kLayout_ = DataLayout::BSND;
    } else if (keyLayoutName == "TND") {
        kLayout_ = DataLayout::TND;
    } else if (keyLayoutName == "PA_BSND" || keyLayoutName == "PA_BBND") {
        kLayout_ = DataLayout::PA_BSND;
    }
    return ge::GRAPH_SUCCESS;
}
// Validates the optional inputs against the resolved layouts:
//   - key layout PA_BSND: block_table and actual_seq_lengths_key must be present and
//     block_table must be int32;
//   - any other key layout: block_table must be absent;
//   - key layout TND: actual_seq_lengths_key must be present;
//   - query layout TND: actual_seq_lengths_query must be present;
//   - whenever actual_seq_lengths_key / actual_seq_lengths_query are present they must be int32.
// Returns ge::GRAPH_FAILED on the first violated rule, ge::GRAPH_SUCCESS otherwise.
ge::graphStatus LIQInfoParser::GetAndCheckOptionalInput()
{
if (kLayout_ == DataLayout::PA_BSND) {
OPS_ERR_IF(opParamInfo_.blockTable.tensor == nullptr,
OPS_LOG_E(opName_, "key layout only supported PA_BSND, input block_table must not be null"),
return ge::GRAPH_FAILED);
OPS_ERR_IF(
opParamInfo_.actualSeqLengthsK.tensor == nullptr,
OPS_LOG_E(opName_, "key layout only supported PA_BSND, input actual_seq_lengths_key must not be null"),
return ge::GRAPH_FAILED);
// NOTE(review): only blockTable.tensor was null-checked above; this assumes desc is
// non-null whenever tensor is — confirm.
OPS_ERR_IF(opParamInfo_.blockTable.desc->GetDataType() != ge::DT_INT32,
OPS_LOG_E(opName_, "input block_table data type only support int32"), return ge::GRAPH_FAILED);
} else {
OPS_ERR_IF(opParamInfo_.blockTable.tensor != nullptr,
OPS_LOG_E(opName_, "key layout is not PA_BSND, input block_table must be null"),
return ge::GRAPH_FAILED);
}
if (kLayout_ == DataLayout::TND) {
OPS_ERR_IF(opParamInfo_.actualSeqLengthsK.tensor == nullptr,
OPS_LOG_E(opName_, "when layout_key is TND, input actual_seq_lengths_key must not be null"),
return ge::GRAPH_FAILED);
}
// The dtype check applies only when the optional tensor was supplied.
OPS_ERR_IF(opParamInfo_.actualSeqLengthsK.tensor != nullptr &&
opParamInfo_.actualSeqLengthsK.desc->GetDataType() != ge::DT_INT32,
OPS_LOG_E(opName_, "input actual_seq_lengths_key data type only support int32"),
return ge::GRAPH_FAILED);
if (qLayout_ == DataLayout::TND) {
OPS_ERR_IF(opParamInfo_.actualSeqLengthsQ.tensor == nullptr,
OPS_LOG_E(opName_, "when layout_query is TND, input actual_seq_lengths_query must not be null"),
return ge::GRAPH_FAILED);
}
// The dtype check applies only when the optional tensor was supplied.
OPS_ERR_IF(opParamInfo_.actualSeqLengthsQ.tensor != nullptr &&
opParamInfo_.actualSeqLengthsQ.desc->GetDataType() != ge::DT_INT32,
OPS_LOG_E(opName_, "input actual_seq_lengths_query data type only support int32"),
return ge::GRAPH_FAILED);
return ge::GRAPH_SUCCESS;
}
// Validates the number of dimensions of the op's tensors:
//   - block_table, when present, must have DIM_NUM_TWO dims;
//   - key must have DIM_NUM_FOUR dims when the key layout is PA_BSND;
//   - query and the output must have DIM_NUM_THREE dims for a TND query layout and
//     DIM_NUM_FOUR dims otherwise;
//   - weights must have one dim fewer than query.
// Returns ge::GRAPH_FAILED on the first violated rule.
ge::graphStatus LIQInfoParser::CheckShapeDim()
{
    // GetDimNum() is narrowed explicitly to uint32_t wherever it is logged, so the
    // argument type matches the %u conversion.
    OPS_ERR_IF((opParamInfo_.blockTable.tensor != nullptr) &&
                   (opParamInfo_.blockTable.tensor->GetStorageShape().GetDimNum() != DIM_NUM_TWO),
               OPS_LOG_E(opName_, "the dim num of block_table's shape should be 2, but now is %u",
                         static_cast<uint32_t>(opParamInfo_.blockTable.tensor->GetStorageShape().GetDimNum())),
               return ge::GRAPH_FAILED);
    OPS_ERR_IF(
        (kLayout_ == DataLayout::PA_BSND) && (opParamInfo_.key.shape->GetStorageShape().GetDimNum() != DIM_NUM_FOUR),
        OPS_LOG_E(opName_, "the dim num of key's shape should be 4, but now is %u",
                  static_cast<uint32_t>(opParamInfo_.key.shape->GetStorageShape().GetDimNum())),
        return ge::GRAPH_FAILED);

    const uint32_t qShapeDim = static_cast<uint32_t>(opParamInfo_.query.shape->GetStorageShape().GetDimNum());
    const uint32_t weightsShapeDim = static_cast<uint32_t>(opParamInfo_.weights.shape->GetStorageShape().GetDimNum());
    const uint32_t outShapeDim = static_cast<uint32_t>(opParamInfo_.attenOut.shape->GetStorageShape().GetDimNum());

    uint32_t expectShapeDim = DIM_NUM_FOUR;
    if (qLayout_ == DataLayout::TND) {
        expectShapeDim = DIM_NUM_THREE;
    }
    OPS_ERR_IF(
        qShapeDim != expectShapeDim,
        OPS_LOG_E(opName_, "the dim num of query's shape should be %u, but now is %u", expectShapeDim, qShapeDim),
        return ge::GRAPH_FAILED);
    OPS_ERR_IF(outShapeDim != expectShapeDim,
               OPS_LOG_E(opName_, "the dim num of sparse_indices's shape should be %u, but now is %u", expectShapeDim,
                         outShapeDim),
               return ge::GRAPH_FAILED);
    // weights has no trailing head-dim axis, hence one dim fewer than query.
    OPS_ERR_IF(!(weightsShapeDim == expectShapeDim - 1),
               OPS_LOG_E(opName_, "the dim num of weights's shape should be %u, but now is %u", expectShapeDim - 1,
                         weightsShapeDim),
               return ge::GRAPH_FAILED);
    return ge::GRAPH_SUCCESS;
}
// Stores the query head count in n1Size_: dim 2 of query for a BSND query layout,
// dim 1 for every other layout (TND). Always returns ge::GRAPH_SUCCESS.
ge::graphStatus LIQInfoParser::GetN1Size()
{
    const auto headAxis = (qLayout_ == DataLayout::BSND) ? DIM_IDX_TWO : DIM_IDX_ONE;
    n1Size_ = static_cast<uint32_t>(opParamInfo_.query.shape->GetStorageShape().GetDim(headAxis));
    OPS_LOG_I(context_->GetNodeName(), "n1Size is %d", n1Size_);
    return ge::GRAPH_SUCCESS;
}
// Writes the element count of an actual_seq_lengths tensor into `size`.
//
// @param size              receives the element count; set to 0 on failure.
// @param tensor            tensor to measure; a null tensor is reported as an error.
// @param actualSeqLenName  name used in error messages.
// @return ge::GRAPH_FAILED when the tensor is null or its shape size is not positive,
//         ge::GRAPH_SUCCESS otherwise.
ge::graphStatus LIQInfoParser::GetActualSeqLenSize(uint32_t &size, const gert::Tensor *tensor,
                                                   const std::string &actualSeqLenName)
{
    size = 0;
    if (tensor == nullptr) {
        OPS_LOG_E(opName_, "%s is nullptr.", actualSeqLenName.c_str());
        return ge::GRAPH_FAILED;
    }
    // Validate the signed value before narrowing: a negative shape size would otherwise
    // wrap to a large positive uint32_t and pass the check.
    const auto shapeSize = tensor->GetShapeSize();
    if (shapeSize <= 0) {
        OPS_LOG_E(opName_, "%s's shape size is %lld, it should be greater than 0.", actualSeqLenName.c_str(),
                  static_cast<long long>(shapeSize));
        return ge::GRAPH_FAILED;
    }
    size = static_cast<uint32_t>(shapeSize);
    return ge::GRAPH_SUCCESS;
}
// Stores the key head count in n2Size_ and requires it to be exactly 1.
// The head count is dim 1 of key for a TND key layout and dim 2 for every other layout.
ge::graphStatus LIQInfoParser::GetAndCheckN2Size()
{
    const auto headAxis = (kLayout_ == DataLayout::TND) ? DIM_IDX_ONE : DIM_IDX_TWO;
    n2Size_ = static_cast<uint32_t>(opParamInfo_.key.shape->GetStorageShape().GetDim(headAxis));
    OPS_LOG_I(context_->GetNodeName(), "N2 is %d", n2Size_);
    OPS_ERR_IF(n2Size_ != 1, OPS_LOG_E(opName_, "key shape[2] is numhead, only support 1."), return ge::GRAPH_FAILED);
    return ge::GRAPH_SUCCESS;
}
/**
 * Derives the group size G = N1 / N2 and stores it in gSize_.
 *
 * Fails when N1 is not evenly divisible by N2. Expects n1Size_ and n2Size_ to
 * have been populated already (n2Size_ must be non-zero).
 *
 * @return ge::GRAPH_SUCCESS on success, ge::GRAPH_FAILED when N1 % N2 != 0.
 */
ge::graphStatus LIQInfoParser::GetGSize()
{
    const uint32_t remainder = n1Size_ % n2Size_;
    if (remainder == 0) {
        gSize_ = n1Size_ / n2Size_;
        return ge::GRAPH_SUCCESS;
    }
    OPS_LOG_E(opName_, "input query's head_num %u can not be a multiple of key's head_num %u.", n1Size_, n2Size_);
    return ge::GRAPH_FAILED;
}
/**
 * Determines the reference batch size B and stores it in bSize_.
 *
 * - Query layout TND: B is the element count of actual_seq_lengths_query
 *   (delegated to GetActualSeqLenSize, which also validates it).
 * - Otherwise (BSND): B is dim 0 of the query storage shape.
 *
 * @return ge::GRAPH_SUCCESS, or the failure reported by GetActualSeqLenSize.
 */
ge::graphStatus LIQInfoParser::GetBatchSize()
{
    if (qLayout_ == DataLayout::TND) {
        return GetActualSeqLenSize(bSize_, opParamInfo_.actualSeqLengthsQ.tensor, "input actual_seq_lengths_query");
    }

    // BSND
    const auto &queryShape = opParamInfo_.query.shape->GetStorageShape();
    bSize_ = static_cast<uint32_t>(queryShape.GetDim(DIM_IDX_ZERO));
    // The dims are 64-bit; cast them explicitly so the arguments match the "%d"
    // conversions instead of passing mismatched varargs.
    OPS_LOG_I(context_->GetNodeName(), "b: %d, s: %d, n: %d,d :%d",
              static_cast<int32_t>(queryShape.GetDim(DIM_IDX_ZERO)),
              static_cast<int32_t>(queryShape.GetDim(DIM_IDX_ONE)),
              static_cast<int32_t>(queryShape.GetDim(DIM_IDX_TWO)),
              static_cast<int32_t>(queryShape.GetDim(DIM_IDX_THREE)));
    return ge::GRAPH_SUCCESS;
}
/**
 * Reads the head dimension D from the query shape into headDim_ and checks it
 * against HEAD_DIM_LIMIT.
 *
 * D is the last axis of the query: index 2 for TND ([Total, N, D]) and index 3
 * for BSND ([Batch, SeqLen, N, D]). Any other query layout is rejected.
 *
 * @return ge::GRAPH_SUCCESS when the layout is supported and D equals
 *         HEAD_DIM_LIMIT, ge::GRAPH_FAILED otherwise.
 */
ge::graphStatus LIQInfoParser::GetHeadDim()
{
    const bool isTnd = (qLayout_ == DataLayout::TND);
    const bool isBsnd = (qLayout_ == DataLayout::BSND);
    if (!isTnd && !isBsnd) {
        OPS_LOG_E(opName_, "unsupported layout for getting head dim.");
        return ge::GRAPH_FAILED;
    }

    // The query's D axis is the reference for the head dimension.
    const uint32_t headAxis = isBsnd ? DIM_IDX_THREE : DIM_IDX_TWO;
    headDim_ = opParamInfo_.query.shape->GetStorageShape().GetDim(headAxis);
    OPS_ERR_IF(headDim_ != HEAD_DIM_LIMIT, OPS_LOG_E(opName_, "input query's last dim head_dim only support 128, but now is %u.", headDim_),
               return ge::GRAPH_FAILED);
    return ge::GRAPH_SUCCESS;
}
/**
 * Reads the query sequence length S1 (dim 1 of the query shape) into s1Size_
 * when the query layout is BSND. For other layouts s1Size_ is left untouched.
 *
 * @return Always ge::GRAPH_SUCCESS.
 */
ge::graphStatus LIQInfoParser::GetS1Size()
{
    if (qLayout_ != DataLayout::BSND) {
        return ge::GRAPH_SUCCESS;
    }
    s1Size_ = opParamInfo_.query.shape->GetStorageShape().GetDim(DIM_IDX_ONE);
    return ge::GRAPH_SUCCESS;
}
/**
 * Reads the block size (dim 1 of the key shape) into blockSize_ and validates it.
 *
 * The value must be non-zero, a multiple of BLOCK_SIZE_FACTOR and no larger
 * than BLOCK_SIZE_LIMIT.
 *
 * @return ge::GRAPH_SUCCESS when the block size is valid, ge::GRAPH_FAILED otherwise.
 */
ge::graphStatus LIQInfoParser::GetAndCheckBlockSize()
{
    const auto &keyShape = opParamInfo_.key.shape->GetStorageShape();
    blockSize_ = static_cast<uint32_t>(keyShape.GetDim(DIM_IDX_ONE));
    OPS_LOG_I(context_->GetNodeName(), "blockSize_ is %d", blockSize_);

    const bool misaligned = (blockSize_ % BLOCK_SIZE_FACTOR != 0);
    const bool outOfRange = (blockSize_ == 0) || (blockSize_ > BLOCK_SIZE_LIMIT);
    OPS_ERR_IF(
        (misaligned || outOfRange),
        OPS_LOG_E(opName_, "input key's block_size must be a multiple of 16 and belong to (0, 1024], but now is %u.",
                  blockSize_),
        return ge::GRAPH_FAILED);
    return ge::GRAPH_SUCCESS;
}
/**
 * Computes the key sequence capacity S2 for the paged key layout:
 * S2 = block_table.dim1 * block_size.
 *
 * Also validates the block size (via GetAndCheckBlockSize) and that the key
 * holds at least one block. The product is evaluated in 64 bits so it cannot
 * wrap before being stored in the 64-bit s2Size_, and a missing block_table is
 * reported instead of dereferenced.
 *
 * @return ge::GRAPH_SUCCESS on success, ge::GRAPH_FAILED on any failed check.
 */
ge::graphStatus LIQInfoParser::GetS2SizeForPageAttention()
{
    if (GetAndCheckBlockSize() != ge::GRAPH_SUCCESS) {
        return ge::GRAPH_FAILED;
    }

    const int64_t blockCount = opParamInfo_.key.shape->GetStorageShape().GetDim(DIM_IDX_ZERO);
    OPS_ERR_IF((blockCount == 0), OPS_LOG_E(opName_, "input key's block_count cannot be 0."), return ge::GRAPH_FAILED);

    const gert::Tensor *blockTable = opParamInfo_.blockTable.tensor;
    OPS_ERR_IF(blockTable == nullptr,
               OPS_LOG_E(opName_, "input block_table is null, it is required to compute S2 for page attention."),
               return ge::GRAPH_FAILED);

    maxBlockNumPerBatch_ = static_cast<uint32_t>(blockTable->GetStorageShape().GetDim(DIM_IDX_ONE));
    // Widen before multiplying: a 32-bit product would wrap before the assignment.
    s2Size_ = static_cast<int64_t>(maxBlockNumPerBatch_) * static_cast<int64_t>(blockSize_);
    OPS_LOG_I(context_->GetNodeName(), "maxBlockNumPerBatch_ is %u, blockSize_ is %d, s2Size_ is %lld",
              maxBlockNumPerBatch_, blockSize_, static_cast<long long>(s2Size_));
    return ge::GRAPH_SUCCESS;
}
/**
 * Reads the key sequence length S2 for the non-paged key layouts:
 * dim 1 of the key shape for BSND, dim 0 for TND.
 *
 * Any other key layout is rejected. The layout string is only touched on the
 * error path (no std::string copy on success) and a null attribute pointer is
 * logged as "null" rather than dereferenced.
 *
 * @return ge::GRAPH_SUCCESS for BSND/TND, ge::GRAPH_FAILED otherwise.
 */
ge::graphStatus LIQInfoParser::GetS2SizeForBatchContinuous()
{
    const auto &keyShape = opParamInfo_.key.shape->GetStorageShape();
    if (kLayout_ == DataLayout::BSND) {
        s2Size_ = keyShape.GetDim(DIM_IDX_ONE);
        return ge::GRAPH_SUCCESS;
    }
    if (kLayout_ == DataLayout::TND) {
        s2Size_ = keyShape.GetDim(DIM_IDX_ZERO);
        return ge::GRAPH_SUCCESS;
    }

    const char *layoutKey = (opParamInfo_.layOutKey != nullptr) ? opParamInfo_.layOutKey : "null";
    OPS_LOG_E(opName_, "the layout of key is %s, it is unsupported.", layoutKey);
    return ge::GRAPH_FAILED;
}
/**
 * Determines the reference S2 value, dispatching on the key layout:
 *   - PA_BSND (page attention): S2 = block_table.dim1 * block_size
 *   - otherwise (batch continuous): S2 is read from the key's S axis
 *
 * @return The status of the selected helper.
 */
ge::graphStatus LIQInfoParser::GetS2Size()
{
    const bool isPageAttention = (kLayout_ == DataLayout::PA_BSND);
    return isPageAttention ? GetS2SizeForPageAttention() : GetS2SizeForBatchContinuous();
}
/**
 * Cross-checks that the batch, sequence, head, head-dim and top-k axes of all
 * inputs and of the output agree with the values parsed earlier
 * (bSize_, s1Size_, n1Size_, n2Size_, headDim_, sparse_count).
 *
 * Compared with the straightforward form of these checks:
 *   - optional tensors are never dereferenced while building an error message
 *     (an absent tensor is reported as -1);
 *   - where a tensor is compared unconditionally, a null tensor is treated as
 *     a mismatch instead of being dereferenced;
 *   - 64-bit dims are logged with a matching 64-bit conversion;
 *   - the N1/D/N2 messages report the values in the order and for the tensors
 *     they name, and the key's last dim is read at the index that exists for
 *     the key layout.
 *
 * @return ge::GRAPH_SUCCESS when every check passes, ge::GRAPH_FAILED on the
 *         first mismatch.
 */
ge::graphStatus LIQInfoParser::ValidateInputShapesMatch()
{
    /*
    TND:
    query [T,N1,D],
    key [BlockNum,BlockSize,N2,D],
    weight [T,N1],
    block_table [BatchSize, BatchMaxBlockNum],
    act_seq_k [BatchSize]
    act_seq_q [BatchSize],
    out [T,N2,topk]
    ----------------------
    BSND:
    query [BatchSize,S1,N1,D],
    key [BlockNum,BlockSize,N2,D],
    weight [BatchSize,S1,N1],
    block_table [BatchSize, BatchMaxBlockNum],
    act_seq_k [BatchSize]
    act_seq_q [BatchSize] (optional)
    out [BatchSize,S1,N2,topk]
    */
    const auto &queryShape = opParamInfo_.query.shape->GetStorageShape();
    const auto &keyShape = opParamInfo_.key.shape->GetStorageShape();
    const auto &weightsShape = opParamInfo_.weights.shape->GetStorageShape();
    const auto &outShape = opParamInfo_.attenOut.shape->GetStorageShape();
    const gert::Tensor *actSeqQ = opParamInfo_.actualSeqLengthsQ.tensor;
    const gert::Tensor *actSeqK = opParamInfo_.actualSeqLengthsK.tensor;
    const gert::Tensor *blockTable = opParamInfo_.blockTable.tensor;

    // Sizes of optional tensors; ABSENT_SIZE stands in for a null tensor so the
    // value can be compared and logged without touching the pointer again.
    constexpr long long ABSENT_SIZE = -1;
    const long long batch = static_cast<long long>(bSize_);
    const long long actSeqKSize =
        (actSeqK != nullptr) ? static_cast<long long>(actSeqK->GetShapeSize()) : ABSENT_SIZE;
    const long long blockTableDim0 =
        (blockTable != nullptr) ? static_cast<long long>(blockTable->GetStorageShape().GetDim(DIM_IDX_ZERO))
                                : ABSENT_SIZE;
    const bool isPageAttention = (kLayout_ == DataLayout::PA_BSND);
    // block_table only takes part in the batch check when it is present.
    const bool blockTableMismatch = (blockTable != nullptr) && (blockTableDim0 != batch);

    uint32_t n1Axis = DIM_IDX_ONE; // axis of N1 in weights
    uint32_t n2Axis = DIM_IDX_ONE; // axis of N2 in the output
    if (qLayout_ == DataLayout::TND) {
        // -----------------------check BatchSize-------------------
        // bSize_ comes from act_seq_q
        if (isPageAttention) {
            OPS_ERR_IF(
                (actSeqKSize != batch) || blockTableMismatch,
                OPS_LOG_E(
                    opName_,
                    "TND case input actual_seq_lengths_query, actual_seq_lengths_key, block_table dim 0 are %u, %lld, %lld respectively, they must be same.",
                    bSize_, actSeqKSize, blockTableDim0),
                return ge::GRAPH_FAILED);
        } else {
            OPS_ERR_IF(
                actSeqKSize != batch,
                OPS_LOG_E(
                    opName_,
                    "TND case input actual_seq_lengths_query, actual_seq_lengths_key, are %u, %lld respectively, they must be same.",
                    bSize_, actSeqKSize),
                return ge::GRAPH_FAILED);
        }
        // -----------------------check T-------------------
        const long long queryT = static_cast<long long>(queryShape.GetDim(DIM_IDX_ZERO));
        const long long weightsT = static_cast<long long>(weightsShape.GetDim(DIM_IDX_ZERO));
        const long long outT = static_cast<long long>(outShape.GetDim(DIM_IDX_ZERO));
        OPS_ERR_IF(
            (weightsT != queryT) || (outT != queryT),
            OPS_LOG_E(
                opName_,
                "TND case input query, weights, sparse_indices dim 0 are %lld, %lld, %lld respectively, they must be same.",
                queryT, weightsT, outT),
            return ge::GRAPH_FAILED);
    } else {
        // -----------------------check BatchSize-------------------
        // bSize_ comes from query
        const long long weightsB = static_cast<long long>(weightsShape.GetDim(DIM_IDX_ZERO));
        const long long outB = static_cast<long long>(outShape.GetDim(DIM_IDX_ZERO));
        if (isPageAttention) {
            OPS_ERR_IF(
                (weightsB != batch) || blockTableMismatch || (actSeqKSize != batch) || (outB != batch),
                OPS_LOG_E(
                    opName_,
                    "BSND case input query, weight, actual_seq_lengths_key, block_table, sparse_indices dim 0 are %u, %lld, %lld, %lld, %lld respectively, they must be same.",
                    bSize_, weightsB, actSeqKSize, blockTableDim0, outB),
                return ge::GRAPH_FAILED);
        } else {
            // actual_seq_lengths_key is optional here: only compared when present.
            const bool actSeqKMismatch = (actSeqK != nullptr) && (actSeqKSize != batch);
            OPS_ERR_IF(
                (weightsB != batch) || actSeqKMismatch || (outB != batch),
                OPS_LOG_E(
                    opName_,
                    "BSND case input query, weight, actual_seq_lengths_key, sparse_indices dim 0 are %u, %lld, %lld, %lld respectively, they must be same.",
                    bSize_, weightsB, actSeqKSize, outB),
                return ge::GRAPH_FAILED);
        }
        const long long actSeqQSize =
            (actSeqQ != nullptr) ? static_cast<long long>(actSeqQ->GetShapeSize()) : ABSENT_SIZE;
        OPS_ERR_IF(
            (actSeqQ != nullptr) && (actSeqQSize != batch),
            OPS_LOG_E(
                opName_,
                "BSND case input query, actual_seq_lengths_query dim 0 are %u, %lld respectively, they must be same",
                bSize_, actSeqQSize),
            return ge::GRAPH_FAILED);
        // -----------------------check S1-------------------
        const long long querySeq = static_cast<long long>(s1Size_);
        const long long weightsSeq = static_cast<long long>(weightsShape.GetDim(DIM_IDX_ONE));
        const long long outSeq = static_cast<long long>(outShape.GetDim(DIM_IDX_ONE));
        OPS_ERR_IF(
            (weightsSeq != querySeq) || (outSeq != querySeq),
            OPS_LOG_E(opName_,
                      "BSND case input query, weight, sparse_indices dim 1 are %u, %lld, %lld, they must be same.",
                      s1Size_, weightsSeq, outSeq),
            return ge::GRAPH_FAILED);
        n1Axis = DIM_IDX_TWO;
        n2Axis = DIM_IDX_TWO;
    }
    // -----------------------check N1-------------------
    const long long weightsN1 = static_cast<long long>(weightsShape.GetDim(n1Axis));
    OPS_ERR_IF(
        weightsN1 != static_cast<long long>(n1Size_),
        OPS_LOG_E(opName_, "input query, weight shape dim N1 must be same, but now are %u, %lld respectively.",
                  n1Size_, weightsN1),
        return ge::GRAPH_FAILED);
    // -----------------------check D-------------------
    // The key's last axis is index 2 for a TND key and index 3 otherwise.
    const uint32_t keyHeadDimAxis = (kLayout_ == DataLayout::TND) ? DIM_IDX_TWO : DIM_IDX_THREE;
    const long long keyHeadDim = static_cast<long long>(keyShape.GetDim(keyHeadDimAxis));
    OPS_ERR_IF(
        keyHeadDim != static_cast<long long>(headDim_),
        OPS_LOG_E(opName_, "input query, key shape last dim must be same, now are %u, %lld respectively.",
                  headDim_, keyHeadDim),
        return ge::GRAPH_FAILED);
    // -----------------------check N2-------------------
    const long long outN2 = static_cast<long long>(outShape.GetDim(n2Axis));
    OPS_ERR_IF(
        outN2 != static_cast<long long>(n2Size_),
        OPS_LOG_E(opName_,
                  "input key and output sparse_indices shape n2 dim must be same, but now are %u, %lld respectively.",
                  n2Size_, outN2),
        return ge::GRAPH_FAILED);
    // -----------------------check sparse_count-------------------
    const long long outTopK = static_cast<long long>(outShape.GetDim(n2Axis + 1));
    const long long sparseCount = static_cast<long long>(*opParamInfo_.sparseCount);
    OPS_ERR_IF(outTopK != sparseCount,
               OPS_LOG_E(opName_, "output sparse_indices shape last dim must be same as attr sparse_count."),
               return ge::GRAPH_FAILED);
    return ge::GRAPH_SUCCESS;
}
/**
 * Verifies that each dequant-scale tensor has the shape of its data tensor
 * with the last axis dropped:
 *   - rank(query_dequant_scale) == rank(query) - 1, and the leading dims match;
 *   - rank(key_dequant_scale)   == rank(key) - 1,   and the leading dims match.
 *
 * @return ge::GRAPH_SUCCESS when both scales match, ge::GRAPH_FAILED on the
 *         first mismatch.
 */
ge::graphStatus LIQInfoParser::CheckScaleShape()
{
    const auto &queryShape = opParamInfo_.query.shape->GetStorageShape();
    const auto &keyShape = opParamInfo_.key.shape->GetStorageShape();
    const auto &queryScaleShape = opParamInfo_.query_dequant_scale.shape->GetStorageShape();
    const auto &keyScaleShape = opParamInfo_.key_dequant_scale.shape->GetStorageShape();

    const uint32_t queryRank = queryShape.GetDimNum();
    const uint32_t keyRank = keyShape.GetDimNum();
    const uint32_t queryScaleRank = queryScaleShape.GetDimNum();
    const uint32_t keyScaleRank = keyScaleShape.GetDimNum();

    // Rank checks come first so the per-axis loops below stay in range.
    OPS_ERR_IF(queryScaleRank != (queryRank - 1),
               OPS_LOG_E(opName_, "the dim num of query_dequant_scale's shape should be %u, but now is %u",
                         queryRank - 1, queryScaleRank),
               return ge::GRAPH_FAILED);
    OPS_ERR_IF(keyScaleRank != (keyRank - 1),
               OPS_LOG_E(opName_, "the dim num of key_dequant_scale's shape should be %u, but now is %u", keyRank - 1,
                         keyScaleRank),
               return ge::GRAPH_FAILED);

    // Query scale: every leading axis must equal the query's.
    for (uint32_t axis = 0; axis < (queryRank - 1); ++axis) {
        const uint32_t scaleDim = queryScaleShape.GetDim(axis);
        const uint32_t dataDim = queryShape.GetDim(axis);
        OPS_ERR_IF(scaleDim != dataDim,
                   OPS_LOG_E(opName_, "query_dequant_scale's shape[%u] %u and query's shape[%u] %u is not same", axis,
                             scaleDim, axis, dataDim),
                   return ge::GRAPH_FAILED);
    }

    // Key scale: every leading axis must equal the key's.
    for (uint32_t axis = 0; axis < (keyRank - 1); ++axis) {
        const uint32_t scaleDim = keyScaleShape.GetDim(axis);
        const uint32_t dataDim = keyShape.GetDim(axis);
        OPS_ERR_IF(scaleDim != dataDim,
                   OPS_LOG_E(opName_, "key_dequant_scale's shape[%u] %u and key's shape[%u] %u is not same", axis,
                             scaleDim, axis, dataDim),
                   return ge::GRAPH_FAILED);
    }
    return ge::GRAPH_SUCCESS;
}
/**
 * Copies everything the parser collected into the tiling-info object consumed
 * by the tiling stage.
 *
 * qkHeadDim is filled from the parsed head dimension as well, so the field no
 * longer stays at its default of 0 after a successful parse.
 *
 * @param liqInfo [out] destination structure; every field listed below is overwritten.
 */
void LIQInfoParser::GenerateInfo(LIQTilingInfo &liqInfo)
{
    // Context
    liqInfo.opName = opName_;
    liqInfo.platformInfo = platformInfo_;
    liqInfo.opParamInfo = opParamInfo_;
    liqInfo.socVersion = socVersion_;

    // Base shape parameters
    liqInfo.bSize = bSize_;
    liqInfo.n1Size = n1Size_;
    liqInfo.n2Size = n2Size_;
    liqInfo.s1Size = s1Size_;
    liqInfo.s2Size = s2Size_;
    liqInfo.qkHeadDim = headDim_;
    liqInfo.gSize = gSize_;

    // Data types
    liqInfo.inputQType = inputQType_;
    liqInfo.inputKType = inputKType_;
    liqInfo.outputType = outputType_;

    // Page attention
    liqInfo.blockSize = blockSize_;
    liqInfo.maxBlockNumPerBatch = maxBlockNumPerBatch_;
    liqInfo.pageAttentionFlag = (kLayout_ == DataLayout::PA_BSND);

    // Attributes
    liqInfo.sparseMode = *opParamInfo_.sparseMode;
    liqInfo.sparseCount = *opParamInfo_.sparseCount;

    // Layouts
    liqInfo.inputQLayout = qLayout_;
    liqInfo.inputKLayout = kLayout_;
}
/**
 * Runs the complete parse-and-validate pipeline and, on success, exports the
 * result into `liqInfo`.
 *
 * The steps run strictly in the order listed and stop at the first failure
 * (short-circuit evaluation); later steps read members set by earlier ones.
 *
 * @param liqInfo [out] filled by GenerateInfo only when every step succeeds.
 * @return ge::GRAPH_SUCCESS when all steps pass, ge::GRAPH_FAILED otherwise.
 */
ge::graphStatus LIQInfoParser::ParseAndCheck(LIQTilingInfo &liqInfo)
{
    const auto passed = [](ge::graphStatus status) { return status == ge::GRAPH_SUCCESS; };

    const bool allStepsPassed =
        // operator identity, platform and raw parameters
        passed(GetOpName()) && passed(GetNpuInfo()) && passed(GetOpParaInfo()) &&
        passed(CheckRequiredParaExistence()) &&
        // data types, layouts and optional inputs
        passed(GetAndCheckInOutDataType()) && passed(GetQueryKeyAndOutLayout()) &&
        passed(GetAndCheckOptionalInput()) &&
        // ranks and head counts
        passed(CheckShapeDim()) && passed(GetN1Size()) && passed(GetAndCheckN2Size()) && passed(GetGSize()) &&
        // batch / sequence / head-dim sizes
        passed(GetBatchSize()) && passed(GetS1Size()) && passed(GetHeadDim()) && passed(GetS2Size()) &&
        // cross-tensor consistency
        passed(ValidateInputShapesMatch()) && passed(CheckScaleShape());
    if (!allStepsPassed) {
        return ge::GRAPH_FAILED;
    }

    GenerateInfo(liqInfo);
    return ge::GRAPH_SUCCESS;
}
// --------------------------TilingPrepare function definition-------------------------------------
/**
 * Tiling-parse hook registered for LightningIndexerQuant.
 * It performs no work: the context is ignored and success is always returned.
 */
static ge::graphStatus TilingPrepareForLightningIndexerQuant(gert::TilingParseContext *context)
{
    (void)context; // intentionally unused
    return ge::GRAPH_SUCCESS;
}
// --------------------------LightningIndexerQuantTiling member function definitions-----------------------
/**
 * Produces the launch configuration for the kernel from parsed tiling info:
 * block dim, workspace size, serialized tiling data and the tiling key.
 *
 * Pointers obtained from the caller and from the tiling context
 * (tilingInfo, workspace-size array, raw tiling data) are checked before use,
 * and the workspace size is accumulated in size_t so the sum cannot wrap in
 * 32 bits before it is stored.
 *
 * NOTE(review): the tiling data declares an n2Size field that is not written
 * here, and tilingInfo->s2Size is 64-bit while the tiling field is 32-bit.
 * Both are left as they were — confirm against the kernel's expectations.
 *
 * @param tilingInfo parsed and validated operator information.
 * @return ge::GRAPH_SUCCESS on success, ge::GRAPH_FAILED when a required
 *         pointer is null.
 */
ge::graphStatus LightningIndexerQuantTiling::DoTiling(LIQTilingInfo *tilingInfo)
{
    OPS_ERR_IF(tilingInfo == nullptr,
               OPS_REPORT_VECTOR_INNER_ERR("LightningIndexerQuant", "Tiling info is null."),
               return ge::GRAPH_FAILED);

    // -------------set blockdim-----------------
    auto ascendcPlatform = platform_ascendc::PlatformAscendC(tilingInfo->platformInfo);
    uint32_t aivNum = ascendcPlatform.GetCoreNumAiv();
    uint32_t aicNum = ascendcPlatform.GetCoreNumAic();
    uint32_t blockDim = ascendcPlatform.CalcTschBlockDim(aivNum, aicNum, aivNum);
    context_->SetBlockDim(blockDim);

    // -------------set workspacesize-----------------
    constexpr size_t MM1_RES_ELEM_SIZE = 4;         // 4: fp32
    constexpr size_t DOUBLE_BUFFER = 2;             // double buffering
    constexpr size_t M_BASE_SIZE = 512;             // base block size along the m axis
    constexpr size_t S2_BASE_SIZE = 512;            // base block size along the S2 axis
    constexpr size_t V1_RES_ELEM_SIZE = 4;          // 4: int32
    constexpr size_t V1_RES_ELEM_TYPE = 2;          // both index and value are kept
    constexpr size_t V1_DECODE_PARAM_ELEM_SIZE = 8; // 8: int64
    constexpr size_t V1_DECODE_PARAM_NUM = 16;      // number of decode parameters
    constexpr size_t V1_DECODE_DATA_NUM = 2;        // each core stores a head and a tail data block for decode
    constexpr size_t S1_BASE_SIZE = 8;              // base block size along the S1 axis
    constexpr size_t TOPK_MAX_SIZE = 2048;          // number of top-k entries selected

    size_t workspaceSize = ascendcPlatform.GetLibApiWorkSpaceSize();
    // Workspace needed by the main flow
    const size_t mm1ResSize = M_BASE_SIZE * S2_BASE_SIZE;
    workspaceSize += mm1ResSize * MM1_RES_ELEM_SIZE * DOUBLE_BUFFER * aicNum;
    // Workspace needed by the decode flow (LD)
    // Temporary decode results: 2(head/tail)*8(s1Base)*2(idx/value)*2048(K)*sizeof(int32)*24=6M
    workspaceSize += V1_DECODE_DATA_NUM * S1_BASE_SIZE * V1_RES_ELEM_TYPE * TOPK_MAX_SIZE * V1_RES_ELEM_SIZE * aicNum;
    // Temporary decode parameters: 2(head/tail)*8(s1Base)*16(paramNum)*sizeof(int64_t)*24=48k
    workspaceSize += V1_DECODE_DATA_NUM * S1_BASE_SIZE * V1_DECODE_PARAM_NUM * V1_DECODE_PARAM_ELEM_SIZE * aicNum;

    size_t *workSpaces = context_->GetWorkspaceSizes(1);
    OPS_ERR_IF(workSpaces == nullptr,
               OPS_LOG_E(tilingInfo->opName, "failed to get the workspace sizes from the tiling context."),
               return ge::GRAPH_FAILED);
    workSpaces[0] = workspaceSize;

    // -------------set tilingdata-----------------
    tilingData_.set_bSize(tilingInfo->bSize);
    tilingData_.set_s2Size(tilingInfo->s2Size);
    tilingData_.set_s1Size(tilingInfo->s1Size);
    tilingData_.set_sparseCount(tilingInfo->sparseCount);
    tilingData_.set_gSize(tilingInfo->gSize);
    tilingData_.set_blockSize(tilingInfo->blockSize);
    tilingData_.set_maxBlockNumPerBatch(tilingInfo->maxBlockNumPerBatch);
    tilingData_.set_sparseMode(tilingInfo->sparseMode);
    tilingData_.set_usedCoreNum(blockDim);

    auto *rawTilingData = context_->GetRawTilingData();
    OPS_ERR_IF(rawTilingData == nullptr,
               OPS_LOG_E(tilingInfo->opName, "failed to get the raw tiling data from the tiling context."),
               return ge::GRAPH_FAILED);
    tilingData_.SaveToBuffer(rawTilingData->GetData(), rawTilingData->GetCapacity());
    rawTilingData->SetDataSize(tilingData_.GetDataSize());

    // -------------set tilingkey-----------------
    // DT_Q, DT_KV, DT_OUT, PAGE_ATTENTION, FLASH_DECODE, LAYOUT_T, KV_LAYOUT_T
    uint32_t inputQType = static_cast<uint32_t>(tilingInfo->inputQType);
    uint32_t inputKType = static_cast<uint32_t>(tilingInfo->inputKType);
    uint32_t outputType = static_cast<uint32_t>(tilingInfo->outputType);
    uint32_t pageAttentionFlag = static_cast<uint32_t>(tilingInfo->pageAttentionFlag);
    uint32_t inputQLayout = static_cast<uint32_t>(tilingInfo->inputQLayout);
    uint32_t inputKLayout = static_cast<uint32_t>(tilingInfo->inputKLayout);
    uint32_t tilingKey =
        GET_TPL_TILING_KEY(inputQType, inputKType, outputType, pageAttentionFlag, inputQLayout, inputKLayout);
    context_->SetTilingKey(tilingKey);
    return ge::GRAPH_SUCCESS;
}
// --------------------------Tiling function definition---------------------------
/**
 * Tiling entry point for LightningIndexerQuant: parses and validates the
 * operator's inputs, then computes the tiling from the parsed information.
 *
 * @param context tiling context supplied by the framework; must not be null.
 * @return ge::GRAPH_FAILED when the context is null or parsing fails,
 *         otherwise the status returned by DoTiling.
 */
ge::graphStatus TilingForLightningIndexerQuant(gert::TilingContext *context)
{
    OPS_ERR_IF(context == nullptr, OPS_REPORT_VECTOR_INNER_ERR("LightningIndexerQuant", "Tiling context is null."),
               return ge::GRAPH_FAILED);

    LIQTilingInfo parsedInfo;
    LIQInfoParser parser(context);
    const ge::graphStatus parseStatus = parser.ParseAndCheck(parsedInfo);
    if (parseStatus != ge::GRAPH_SUCCESS) {
        return ge::GRAPH_FAILED;
    }

    LightningIndexerQuantTiling tiler(context);
    return tiler.DoTiling(&parsedInfo);
}
// --------------------------Registration of the Tiling and TilingPrepare functions--------
// Registers TilingForLightningIndexerQuant as the tiling function and
// TilingPrepareForLightningIndexerQuant (with LIQCompileInfo) as the tiling-parse
// function of the LightningIndexerQuant operator.
IMPL_OP_OPTILING(LightningIndexerQuant)
.Tiling(TilingForLightningIndexerQuant)
.TilingParse<LIQCompileInfo>(TilingPrepareForLightningIndexerQuant);
} // namespace optiling

View File

@@ -0,0 +1,234 @@
/**
* This program is free software, you can redistribute it and/or modify it.
* Copyright (c) 2025 Huawei Technologies Co., Ltd.
* This file is a part of the CANN Open Software.
* Licensed under CANN Open Software License Agreement Version 2.0 (the "License").
* Please refer to the License for details. You may not use this file except in compliance with the License.
* THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE.
* See LICENSE in the root of the software repository for the full text of the License.
*/
/*!
* \file lightning_indexer_quant_tiling.h
* \brief
*/
#ifndef LIGHTNING_INDEXER_QUANT_TILING_H_
#define LIGHTNING_INDEXER_QUANT_TILING_H_
#include "error/ops_error.h"
#include "exe_graph/runtime/tiling_context.h"
#include "platform/platform_info.h"
#include "register/op_def_registry.h"
#include "register/tilingdata_base.h"
#include "tiling/platform/platform_ascendc.h"
#include "tiling/tiling_api.h"
namespace optiling {
// ------------------Common definitions--------------------------
// Pairs a tensor descriptor pointer with a storage-shape pointer.
// Used by LIQParaInfo for query, key, weights, the dequant scales and the output.
struct TilingRequiredParaInfo {
const gert::CompileTimeTensorDesc *desc;
const gert::StorageShape *shape;
};
// Pairs a tensor descriptor pointer with a tensor pointer.
// Used by LIQParaInfo for actual_seq_lengths_query/key and block_table; the parser
// checks some of these tensor pointers against nullptr before use.
struct TilingOptionalParaInfo {
const gert::CompileTimeTensorDesc *desc;
const gert::Tensor *tensor;
};
// Tensor layouts recognised by the tiling code. The numeric values are cast to
// uint32_t and fed into the tiling key, so they must not be renumbered.
enum class DataLayout : uint32_t {
BSND = 0,
TND = 1,
PA_BSND = 2
};
// ------------------Operator prototype index constants----------------
// Inputs Index
constexpr uint32_t QUERY_INDEX = 0;
constexpr uint32_t KEY_INDEX = 1;
constexpr uint32_t WEIGTHS_INDEX = 2;
constexpr uint32_t QUERY_DEQUANT_SCALE_INDEX = 3;
constexpr uint32_t KEY_DEQUANT_SCALE_INDEX = 4;
constexpr uint32_t ACTUAL_SEQ_Q_INDEX = 5;
constexpr uint32_t ACTUAL_SEQ_K_INDEX = 6;
constexpr uint32_t BLOCK_TABLE_INDEX = 7;
// NOTE(review): presumably the index of the operator's single output — confirm at the use site.
constexpr uint32_t LIGHTNING_INDEXER_QUANT = 0;
// Attributes Index
constexpr uint32_t ATTR_QUERY_QUANT_MODE_INDEX = 0;
constexpr uint32_t ATTR_KEY_QUANT_MODE_INDEX = 1;
constexpr uint32_t ATTR_QUERY_LAYOUT_INDEX = 2;
constexpr uint32_t ATTR_KEY_LAYOUT_INDEX = 3;
constexpr uint32_t ATTR_SPARSE_COUNT_INDEX = 4;
constexpr uint32_t ATTR_SPARSE_MODE_INDEX = 5;
// Dim Index
constexpr uint32_t DIM_IDX_ZERO = 0;
constexpr uint32_t DIM_IDX_ONE = 1;
constexpr uint32_t DIM_IDX_TWO = 2;
constexpr uint32_t DIM_IDX_THREE = 3;
// Dim Num
constexpr uint32_t DIM_NUM_TWO = 2;
constexpr uint32_t DIM_NUM_THREE = 3;
constexpr uint32_t DIM_NUM_FOUR = 4;
// Input parameter limit constants
// HEAD_DIM_LIMIT is the only accepted head dim; BLOCK_SIZE_LIMIT / BLOCK_SIZE_FACTOR
// bound the paged-key block size (non-zero multiple of the factor, at most the limit).
constexpr uint32_t HEAD_DIM_LIMIT = 128;
constexpr uint32_t SPARSE_LIMIT = 2048;
constexpr uint32_t G_SIZE_LIMIT = 64;
constexpr uint32_t BLOCK_SIZE_LIMIT = 1024;
constexpr uint32_t BLOCK_SIZE_FACTOR = 16;
constexpr uint32_t SPARSE_MODE_LOWER = 3;
// -----------Operator TilingData definition---------------
// Tiling data passed from the host tiling function to the kernel.
// NOTE(review): field order presumably defines the serialized layout read by the
// kernel — do not reorder without checking the kernel side.
BEGIN_TILING_DATA_DEF(LIQTilingData)
TILING_DATA_FIELD_DEF(uint32_t, bSize)
TILING_DATA_FIELD_DEF(uint32_t, n2Size)
TILING_DATA_FIELD_DEF(uint32_t, gSize)
TILING_DATA_FIELD_DEF(uint32_t, s1Size)
TILING_DATA_FIELD_DEF(uint32_t, s2Size)
TILING_DATA_FIELD_DEF(uint32_t, sparseCount)
TILING_DATA_FIELD_DEF(uint32_t, usedCoreNum)
TILING_DATA_FIELD_DEF(uint32_t, blockSize)
TILING_DATA_FIELD_DEF(uint32_t, maxBlockNumPerBatch)
TILING_DATA_FIELD_DEF(uint32_t, sparseMode)
END_TILING_DATA_DEF
REGISTER_TILING_DATA_CLASS(LightningIndexerQuant, LIQTilingData)
// -----------Operator CompileInfo definition-------------------
// Empty compile-info type; it only serves as the template argument of TilingParse<>.
struct LIQCompileInfo {};
// -----------Operator tiling input-parameter struct definition---------------
// Raw views of the operator's inputs, output and attributes as fetched from the
// tiling context. Every pointer defaults to nullptr until it is populated.
struct LIQParaInfo {
// Inputs carried as descriptor + storage shape
TilingRequiredParaInfo query = {nullptr, nullptr};
TilingRequiredParaInfo key = {nullptr, nullptr};
TilingRequiredParaInfo weights = {nullptr, nullptr};
TilingRequiredParaInfo query_dequant_scale = {nullptr, nullptr};
TilingRequiredParaInfo key_dequant_scale = {nullptr, nullptr};
// Inputs carried as descriptor + tensor
TilingOptionalParaInfo actualSeqLengthsQ = {nullptr, nullptr};
TilingOptionalParaInfo actualSeqLengthsK = {nullptr, nullptr};
TilingOptionalParaInfo blockTable = {nullptr, nullptr};
// Output (sparse_indices)
TilingRequiredParaInfo attenOut = {nullptr, nullptr};
// Attributes
const int32_t *queryQuantMode = nullptr;
const int32_t *keyQuantMode = nullptr;
const char *layOutQuery = nullptr;
const char *layOutKey = nullptr;
const int32_t *blockSize = nullptr;
const int32_t *sparseMode = nullptr;
const int32_t *sparseCount = nullptr;
};
// -----------Operator tiling input information class---------------
// Parsed, validated operator information handed from LIQInfoParser::GenerateInfo
// to LightningIndexerQuantTiling::DoTiling.
class LIQTilingInfo {
public:
const char *opName = nullptr;
fe::PlatFormInfos *platformInfo = nullptr;
LIQParaInfo opParamInfo;
// Base Param
platform_ascendc::SocVersion socVersion = platform_ascendc::SocVersion::ASCEND910B;
uint32_t bSize = 0;
uint32_t n1Size = 0;
uint32_t n2Size = 0;
uint32_t s1Size = 0;
int64_t s2Size = 0;
uint32_t qkHeadDim = 0;
uint32_t gSize = 0;
// PageAttention
bool pageAttentionFlag = false;
int32_t blockSize = 0;
uint32_t maxBlockNumPerBatch = 0;
// Mask
int32_t sparseMode = 0;
// Others Flag
uint32_t sparseCount = 0;
// DType
ge::DataType inputQType = ge::DT_FLOAT16;
ge::DataType inputKType = ge::DT_FLOAT16;
ge::DataType outputType = ge::DT_INT32;
// Layout
DataLayout inputQLayout = DataLayout::BSND;
DataLayout inputKLayout = DataLayout::PA_BSND;
};
// -----------Operator tiling input parsing and checking class---------------
/**
 * Extracts the operator's inputs, output and attributes from a tiling context,
 * validates them, and exports the result as a LIQTilingInfo.
 *
 * ParseAndCheck() is the driver; the remaining methods are the individual
 * steps it runs. Every data member has a default initializer, so a parser
 * whose early steps fail never exposes an indeterminate pointer.
 */
class LIQInfoParser {
public:
    explicit LIQInfoParser(gert::TilingContext *context) : context_(context) {}
    ~LIQInfoParser() = default;

    ge::graphStatus CheckRequiredInOutExistence() const;
    ge::graphStatus CheckRequiredAttrExistence() const;
    ge::graphStatus CheckRequiredParaExistence() const;
    ge::graphStatus GetActualSeqLenSize(uint32_t &size, const gert::Tensor *tensor,
                                        const std::string &actualSeqLenName);
    ge::graphStatus GetOpName();
    ge::graphStatus GetNpuInfo();
    void GetOptionalInputParaInfo();
    void GetInputParaInfo();
    void GetOutputParaInfo();
    ge::graphStatus GetAttrParaInfo();
    ge::graphStatus CheckAttrParaInfo();
    ge::graphStatus GetOpParaInfo();
    ge::graphStatus ValidateInputShapesMatch();
    ge::graphStatus CheckScaleShape();
    ge::graphStatus GetAndCheckInOutDataType();
    ge::graphStatus GetBatchSize();
    ge::graphStatus GetHeadDim();
    ge::graphStatus GetS1Size();
    ge::graphStatus GetAndCheckOptionalInput();
    ge::graphStatus CheckShapeDim();
    ge::graphStatus GetAndCheckBlockSize();
    ge::graphStatus GetS2SizeForPageAttention();
    ge::graphStatus GetS2SizeForBatchContinuous();
    ge::graphStatus GetS2Size();
    ge::graphStatus GetQueryKeyAndOutLayout();
    ge::graphStatus GetN1Size();
    ge::graphStatus GetAndCheckN2Size();
    ge::graphStatus GetGSize();
    ge::graphStatus GetAttenMaskInfo();
    ge::graphStatus GetActualSeqInfo();
    void GenerateInfo(LIQTilingInfo &liqInfo);
    ge::graphStatus ParseAndCheck(LIQTilingInfo &liqInfo);

public:
    gert::TilingContext *context_ = nullptr;
    const char *opName_ = nullptr;              // initialized so it is never indeterminate
    fe::PlatFormInfos *platformInfo_ = nullptr; // initialized so it is never indeterminate
    LIQParaInfo opParamInfo_;
    // BaseParams
    uint32_t bSize_ = 0;
    uint32_t n1Size_ = 0;
    uint32_t n2Size_ = 0;
    uint32_t gSize_ = 0;
    uint32_t s1Size_ = 0;
    int64_t s2Size_ = 0;
    uint32_t headDim_ = 0;
    // Layout
    DataLayout qLayout_ = DataLayout::BSND;
    DataLayout kLayout_ = DataLayout::PA_BSND;
    // PageAttention
    uint32_t maxBlockNumPerBatch_ = 0;
    int32_t blockSize_ = 0;
    platform_ascendc::SocVersion socVersion_ = platform_ascendc::SocVersion::ASCEND910B;
    ge::DataType inputQType_ = ge::DT_FLOAT16;
    ge::DataType inputKType_ = ge::DT_FLOAT16;
    ge::DataType weightsType_ = ge::DT_FLOAT16;
    ge::DataType inputQueryScaleType_ = ge::DT_FLOAT16;
    ge::DataType inputKeyScaleType_ = ge::DT_FLOAT16;
    ge::DataType blockTableType_ = ge::DT_FLOAT16;
    ge::DataType inputKRopeType_ = ge::DT_FLOAT16;
    ge::DataType outputType_ = ge::DT_FLOAT16;
};
// ---------------Operator tiling class---------------
/**
 * Computes the tiling for LightningIndexerQuant from parsed operator
 * information and writes it into the tiling context given at construction.
 */
class LightningIndexerQuantTiling {
public:
    explicit LightningIndexerQuantTiling(gert::TilingContext *context) : context_(context) {}

    ge::graphStatus DoTiling(LIQTilingInfo *tilingInfo);

private:
    gert::TilingContext *context_ = nullptr; // target tiling context
    LIQTilingData tilingData_;               // tiling data serialized by DoTiling
};
} // namespace optiling
#endif // LIGHTNING_INDEXER_QUANT_TILING_H_

View File

@@ -0,0 +1,50 @@
/**
* This program is free software, you can redistribute it and/or modify it.
* Copyright (c) 2025 Huawei Technologies Co., Ltd.
* This file is a part of the CANN Open Software.
* Licensed under CANN Open Software License Agreement Version 2.0 (the "License").
* Please refer to the License for details. You may not use this file except in compliance with the License.
* THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE.
* See LICENSE in the root of the software repository for the full text of the License.
*/
/*!
* \file lightning_indexer_quant.cpp
* \brief
*/
#include "kernel_operator.h"
#include "lib/matmul_intf.h"
#include "lightning_indexer_quant_kernel.h"
#include "lightning_indexer_quant_template_tiling_key.h"
using namespace LIQKernel;
// Instantiates templateClass<LIQType<...>>, loads LIQTilingData from `tiling`,
// then calls Init(...) followed by Process() on the instance.
// The macro expands inside the kernel entry and relies on names from the
// enclosing scope: the kernel's GM pointer parameters, `user` and `tPipe`.
#define INVOKE_LI_NO_KFC_OP_IMPL(templateClass, ...) \
do { \
templateClass<LIQType<__VA_ARGS__>> op; \
GET_TILING_DATA_WITH_STRUCT(LIQTilingData, tiling_data_in, tiling); \
const LIQTilingData *__restrict tiling_data = &tiling_data_in; \
op.Init(query, key, weights, queryScale, keyScale, actualSeqLengthsQ, actualSeqLengthsK, blocktable, \
sparseIndices, user, tiling_data, &tPipe); \
op.Process(); \
} while (0)
// Kernel entry point of the LightningIndexerQuant operator.
//
// Template parameters select the kernel variant. DT_Q, DT_K and DT_OUT are not
// referenced in the body: the element types passed to LIQType are fixed to
// int8_t / int8_t / int32_t. PAGE_ATTENTION, Q_LAYOUT_T and K_LAYOUT_T are
// forwarded (the layouts converted to LI_LAYOUT).
//
// All pointer parameters are global-memory (__gm__) byte pointers; `workspace`
// is converted to the user workspace and `tiling` is decoded by the invoke macro.
// The body is compiled out entirely when __CCE_AICORE__ is 310 or 200, or when
// __DAV_310R6__ is defined.
template <int DT_Q, int DT_K, int DT_OUT, int PAGE_ATTENTION, int Q_LAYOUT_T, int K_LAYOUT_T>
__global__ __aicore__ void lightning_indexer_quant(__gm__ uint8_t *query, __gm__ uint8_t *key, __gm__ uint8_t *weights,
__gm__ uint8_t *queryScale, __gm__ uint8_t *keyScale,
__gm__ uint8_t *actualSeqLengthsQ, __gm__ uint8_t *actualSeqLengthsK,
__gm__ uint8_t *blocktable, __gm__ uint8_t *sparseIndices,
__gm__ uint8_t *workspace, __gm__ uint8_t *tiling)
{
#if (__CCE_AICORE__ == 310) || (defined __DAV_310R6__) || (__CCE_AICORE__ == 200)
#else
TPipe tPipe;
__gm__ uint8_t *user = GetUserWorkspace(workspace);
KERNEL_TASK_TYPE_DEFAULT(KERNEL_TYPE_MIX_AIC_1_2);
INVOKE_LI_NO_KFC_OP_IMPL(LIQPreload, int8_t, int8_t, int32_t,
PAGE_ATTENTION, LI_LAYOUT(Q_LAYOUT_T), LI_LAYOUT(K_LAYOUT_T));
#endif
}

View File

@@ -0,0 +1,146 @@
/**
* This program is free software, you can redistribute it and/or modify it.
* Copyright (c) 2025 Huawei Technologies Co., Ltd.
* This file is a part of the CANN Open Software.
* Licensed under CANN Open Software License Agreement Version 2.0 (the "License").
* Please refer to the License for details. You may not use this file except in compliance with the License.
* THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE.
* See LICENSE in the root of the software repository for the full text of the License.
*/
/*!
* \file lightning_indexer_quant_common.h
* \brief
*/
#ifndef LIGHTNING_INDEXER_QUANT_COMMON_H
#define LIGHTNING_INDEXER_QUANT_COMMON_H
namespace LIQCommon {
// Must stay consistent with the layout enumeration used by the tiling code.
enum class LI_LAYOUT : uint32_t {
BSND = 0,
TND = 1,
PA_BSND = 2
};
// Compile-time bundle of the kernel's type and layout choices.
// Exposes the query/key/output element types as aliases and the page-attention
// flag and the two layouts as static constants. The trailing `Args...` pack is
// accepted but not used inside the struct.
template <typename Q_T, typename K_T, typename OUT_T, const bool PAGE_ATTENTION = false,
LI_LAYOUT Q_LAYOUT_T = LI_LAYOUT::BSND, LI_LAYOUT K_LAYOUT_T = LI_LAYOUT::PA_BSND, typename... Args>
struct LIQType {
using queryType = Q_T;
using keyType = K_T;
using outputType = OUT_T;
static constexpr bool pageAttention = PAGE_ATTENTION;
static constexpr LI_LAYOUT layout = Q_LAYOUT_T;
static constexpr LI_LAYOUT keyLayout = K_LAYOUT_T;
};
// Per-iteration state: loop/batch/head/sequence indices, actual sizes, tensor
// offsets and loop-position flags.
// NOTE(review): only n2Idx, actS1Size, actS2Size, isAllLoopEnd and isValid have
// default initializers; the other members are indeterminate until assigned —
// confirm every user sets them before reading.
struct RunInfo {
uint32_t loop;
uint32_t bN2Idx;
uint32_t bIdx;
uint32_t n2Idx = 0;
uint32_t gS1Idx;
uint32_t s2Idx;
uint32_t actS1Size = 1;
uint32_t actS2Size = 1;
uint32_t actMBaseSize;
uint32_t actualSingleProcessSInnerSize;
uint32_t actualSingleProcessSInnerSizeAlign;
uint64_t tensorQueryOffset;
uint64_t tensorKeyOffset;
uint64_t tensorKeyScaleOffset;
uint64_t tensorWeightsOffset;
uint64_t indiceOutOffset;
bool isFirstS2InnerLoop;
bool isLastS2InnerLoop;
bool isAllLoopEnd = false;
bool isValid = false;
};
// Constants and per-launch values shared by the kernel stages.
// NOTE(review): kHeadNum, headDim, sparseCount and outputLayout have no default
// initializer — confirm they are always assigned before use.
struct ConstInfo {
// Synchronization mode between the CUBE and VEC cores
static constexpr uint32_t FIA_SYNC_MODE2 = 2;
// Buffer sizes in bytes
static constexpr uint32_t BUFFER_SIZE_BYTE_32B = 32;
static constexpr uint32_t BUFFER_SIZE_BYTE_64B = 64;
static constexpr uint32_t BUFFER_SIZE_BYTE_256B = 256;
static constexpr uint32_t BUFFER_SIZE_BYTE_512B = 512;
static constexpr uint32_t BUFFER_SIZE_BYTE_1K = 1024;
static constexpr uint32_t BUFFER_SIZE_BYTE_2K = 2048;
static constexpr uint32_t BUFFER_SIZE_BYTE_4K = 4096;
static constexpr uint32_t BUFFER_SIZE_BYTE_8K = 8192;
static constexpr uint32_t BUFFER_SIZE_BYTE_16K = 16384;
static constexpr uint32_t BUFFER_SIZE_BYTE_32K = 32768;
// Invalid index
static constexpr int INVALID_IDX = -1;
// Event IDs for cross-core synchronization between CUBE and VEC
uint32_t syncC1V1 = 0U;
uint32_t syncC1V0 = 2U;
uint32_t syncV1C1 = 0U;
uint32_t syncV0C1 = 1U;
// Base block sizes
uint32_t mBaseSize = 1ULL;
uint32_t s1BaseSize = 1ULL;
uint32_t s2BaseSize = 1ULL;
uint64_t batchSize = 0ULL;
uint64_t gSize = 0ULL;
uint64_t qHeadNum = 0ULL;
uint64_t kHeadNum;
uint64_t headDim;
uint64_t sparseCount; // top-K selection size
uint64_t kSeqSize = 0ULL; // maximum S length of kv
uint64_t qSeqSize = 1ULL; // maximum S length of q
uint32_t kCacheBlockSize = 0; // block size in the PA scenario
uint32_t maxBlockNumPerBatch = 0; // maximum block number per batch in the PA scenario
LI_LAYOUT outputLayout; // output layout
bool attenMaskFlag = false;
uint32_t actualLenQDims = 0U; // dimension of query's actualSeqLength
uint32_t actualLenDims = 0U; // dimension of KV's actualSeqLength
bool isAccumSeqS1 = false; // whether cumulative mode is used
bool isAccumSeqS2 = false; // whether cumulative mode is used
};
// Work range assigned to one core: start/end bounds along S2, B*N2 and G*S1,
// plus a flag for the decode reduction task. All members default to zero/false.
struct SplitCoreInfo {
uint32_t s2Start = 0U; // start position along S2
uint32_t s2End = 0U; // upper bound of the S2 loop index
uint32_t bN2Start = 0U;
uint32_t bN2End = 0U;
uint32_t gS1Start = 0U;
uint32_t gS1End = 0U;
bool isLD = false; // whether the current core needs to run the decode reduction task
};
// Rounds `num` up to the nearest multiple of `rnd`.
// Returns 0 when `rnd` is 0 (instead of dividing by zero).
template <typename T>
__aicore__ inline T Align(T num, T rnd)
{
    if (rnd == 0) {
        return 0;
    }
    return (num + rnd - 1) / rnd * rnd;
}
template <typename T1, typename T2>
__aicore__ inline T1 Min(T1 a, T2 b)
{
return (a > b) ? (b) : (a);
}
template <typename T1, typename T2>
__aicore__ inline T1 Max(T1 a, T2 b)
{
return (a > b) ? (a) : (b);
}
/**
 * Divides num by rnd, rounding the quotient up.
 * Returns 0 when rnd is 0 so that the division is never evaluated with a zero divisor.
 */
template <typename T>
__aicore__ inline T CeilDiv(T num, T rnd)
{
    if (rnd == 0) {
        return 0;
    }
    return (num + rnd - 1) / rnd;
}
} // namespace LIQCommon
#endif // LIGHTNING_INDEXER_QUANT_COMMON_H

View File

@@ -0,0 +1,714 @@
/**
* This program is free software, you can redistribute it and/or modify it.
* Copyright (c) 2025 Huawei Technologies Co., Ltd.
* This file is a part of the CANN Open Software.
* Licensed under CANN Open Software License Agreement Version 2.0 (the "License").
* Please refer to the License for details. You may not use this file except in compliance with the License.
* THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE.
* See LICENSE in the root of the software repository for the full text of the License.
*/
/*!
* \file lightning_indexer_quant_kernel.h
* \brief
*/
#ifndef LIGHTNING_INDEXER_QUANT_KERNEL_H
#define LIGHTNING_INDEXER_QUANT_KERNEL_H
#include "kernel_operator.h"
#include "kernel_operator_list_tensor_intf.h"
#include "kernel_tiling/kernel_tiling.h"
#include "lib/matmul_intf.h"
#include "lib/matrix/matmul/tiling.h"
#include "lightning_indexer_quant_common.h"
#include "lightning_indexer_quant_service_vector.h"
#include "lightning_indexer_quant_service_cube.h"
namespace LIQKernel {
using namespace LIQCommon;
using namespace LIQServiceVec;
using namespace matmul;
using AscendC::CacheMode;
using AscendC::CrossCoreSetFlag;
using AscendC::CrossCoreWaitFlag;
// RunInfo is not populated before the S2 loop starts, so TempLoopInfo temporarily holds the
// B / N / S1 axis information; keeping it here also avoids recomputing it per S2 step.
struct TempLoopInfo {
    uint32_t bN2Idx{0U};
    uint32_t bIdx{0U};
    uint32_t n2Idx{0U};
    uint32_t gS1Idx{0U};
    uint32_t gS1LoopEnd{0U};                // last index of the loop along gS1
    uint32_t s2LoopEnd{0U};                 // last index of the loop along S2
    uint32_t actS1Size{1U};                 // actual S1 size of the batch currently being processed
    uint32_t actS2Size{0U};
    bool curActSeqLenIsZero{false};
    bool needDealActS1LessThanS1{false};    // whether output must be cleared when the actual S1 is shorter than the shape's S1
    uint32_t actMBaseSize{0U};              // actual size along the m axis (gS1)
    uint32_t mBasicSizeTail{0U};            // size of the tail base block of the gS1 loop
    uint32_t s2BasicSizeTail{0U};           // size of the tail base block of the S2 loop
};
// Kernel driver for the quantized lightning indexer. It owns one matmul service and one vector
// service, computes the range of base blocks handled by the current core, and drives the
// per-base-block pipeline. LIQT supplies the query/key/output types and the layout flags.
template <typename LIQT>
class LIQPreload {
public:
__aicore__ inline LIQPreload(){};
// Reads the tiling data, computes this core's work range and binds the global-memory tensors
// and workspace regions to the vector or matmul service.
__aicore__ inline void Init(__gm__ uint8_t *query, __gm__ uint8_t *key, __gm__ uint8_t *weights,
__gm__ uint8_t *queryScale, __gm__ uint8_t *keyScale, __gm__ uint8_t *actualSeqLengthsQ,
__gm__ uint8_t *actualSeqLengthsK, __gm__ uint8_t *blockTable,
__gm__ uint8_t *sparseIndices, __gm__ uint8_t *workspace,
const LIQTilingData *__restrict tiling, TPipe *tPipe);
// Entry point: runs ProcessMain followed by ProcessDecode, or only ProcessInvalid when
// usedCoreNum is 0.
__aicore__ inline void Process();
// ================================= Type definitions =================================
using Q_T = typename LIQT::queryType;
using K_T = typename LIQT::keyType;
using OUT_T = typename LIQT::outputType;
static constexpr bool PAGE_ATTENTION = LIQT::pageAttention;
static constexpr LI_LAYOUT Q_LAYOUT_T = LIQT::layout;
static constexpr LI_LAYOUT K_LAYOUT_T = LIQT::keyLayout;
using MM1_OUT_T = float;
LIQMatmul<LIQT> matmulService;
LIQVector<LIQT> vectorService;
// ================================= Constants =================================
static constexpr uint32_t SYNC_C1_V1_FLAG = 4;
static constexpr uint32_t SYNC_V1_C1_FLAG = 5;
// Base block sizes and fixed head geometry copied into constInfo by InitTilingData.
static constexpr uint32_t M_BASE_SIZE = 256;
static constexpr uint32_t S2_BASE_SIZE = 2048;
static constexpr uint32_t HEAD_DIM = 128;
static constexpr uint32_t K_HEAD_NUM = 1;
static constexpr uint32_t GM_ALIGN_BYTES = 512;
// Number of RunInfo slots kept alive by ProcessMain / ProcessBaseBlock.
static constexpr uint32_t LI_QUANT_PRELOAD_TASK_CACHE_SIZE = 2;
static constexpr int64_t LD_PREFETCH_LEN = 2;
// Double-buffer factor used when sizing workspace regions.
// NOTE(review): the name looks like a typo of "WS_DOUBLE"; kept as-is because it is referenced by name.
static constexpr uint32_t WS_DOBULE = 2;
protected:
TPipe *pipe = nullptr;
// Offsets cached across S2 iterations by CalcRunInfo.
uint64_t queryCoreOffset = 0ULL;
uint64_t keyCoreOffset = 0ULL;
uint64_t keyScaleCoreOffset = 0ULL;
uint64_t weightsCoreOffset = 0ULL;
uint64_t indiceOutCoreOffset = 0ULL;
// ================================ Global buffers =================================
GlobalTensor<Q_T> queryGm;
GlobalTensor<K_T> keyGm;
GlobalTensor<half> weightsGm;
GlobalTensor<int32_t> indiceOutGm;
GlobalTensor<int32_t> blockTableGm;
GlobalTensor<uint32_t> actualSeqLengthsGmQ;
GlobalTensor<uint32_t> actualSeqLengthsGm;
// ================================ Member variables ====================================
// AIC / AIV core information
uint32_t tmpBlockIdx = 0U;
uint32_t aiCoreIdx = 0U;
uint32_t usedCoreNum = 0U;
LIQCommon::ConstInfo constInfo{};
TempLoopInfo tempLoopInfo{};
LIQCommon::SplitCoreInfo splitCoreInfo{};
// ================================Init functions==================================
__aicore__ inline void InitTilingData(const LIQTilingData *__restrict tilingData);
__aicore__ inline void InitBuffers();
__aicore__ inline void InitActualSeqLen(__gm__ uint8_t *actualSeqLengthsQ, __gm__ uint8_t *actualSeqLengthsK);
// ================================Split Core================================
__aicore__ inline void SplitCore(uint32_t curCoreIdx, uint32_t &coreNum, LIQCommon::SplitCoreInfo &info);
__aicore__ inline uint32_t GetS2BaseBlockNumOnMask(uint32_t s1gIdx, uint32_t actS1Size, uint32_t actS2Size);
__aicore__ inline uint32_t GetTotalBaseBlockNum();
// ================================Process functions================================
__aicore__ inline void ProcessMain();
__aicore__ inline void ProcessBaseBlock(uint32_t loop, uint64_t s2LoopIdx,
LIQCommon::RunInfo runInfo[LI_QUANT_PRELOAD_TASK_CACHE_SIZE]);
__aicore__ inline void ProcessDecode();
__aicore__ inline void ProcessInvalid();
// ================================Params Calc=====================================
__aicore__ inline void CalcGS1LoopParams(uint32_t bN2Idx);
__aicore__ inline void GetBN2Idx(uint32_t bN2Idx);
__aicore__ inline uint32_t GetActualSeqLen(uint32_t bIdx, uint32_t actualLenDims, bool isAccumSeq,
GlobalTensor<uint32_t> &actualSeqLengthsGm, uint32_t defaultSeqLen);
__aicore__ inline void GetS1S2ActualSeqLen(uint32_t bIdx, uint32_t &actS1Size, uint32_t &actS2Size);
__aicore__ inline void CalcS2LoopParams(uint32_t bN2LoopIdx, uint32_t gS1LoopIdx);
__aicore__ inline void CalcRunInfo(uint32_t loop, uint32_t s2LoopIdx, LIQCommon::RunInfo &runInfo);
__aicore__ inline void DealActSeqLenIsZero(uint32_t bIdx, uint32_t n2Idx, uint32_t s1Start);
};
/**
 * Copies the host-side tiling parameters and the compile-time block geometry into constInfo.
 * s1BaseSize is derived as the rounded-up quotient of mBaseSize by gSize.
 */
template <typename LIQT>
__aicore__ inline void LIQPreload<LIQT>::InitTilingData(const LIQTilingData *__restrict tilingData)
{
    usedCoreNum = tilingData->usedCoreNum;

    // Shape information supplied by the tiling.
    constInfo.batchSize = tilingData->bSize;
    constInfo.gSize = tilingData->gSize;
    constInfo.qHeadNum = constInfo.gSize;
    constInfo.kSeqSize = tilingData->s2Size;
    constInfo.qSeqSize = tilingData->s1Size;
    constInfo.sparseCount = tilingData->sparseCount;
    constInfo.kCacheBlockSize = tilingData->blockSize;
    constInfo.maxBlockNumPerBatch = tilingData->maxBlockNumPerBatch;
    constInfo.attenMaskFlag = (tilingData->sparseMode == 3);

    // The output uses the same layout as the query input.
    constInfo.outputLayout = Q_LAYOUT_T;
    if constexpr (Q_LAYOUT_T == LI_LAYOUT::TND) {
        constInfo.isAccumSeqS1 = true;
    }
    if constexpr (K_LAYOUT_T == LI_LAYOUT::TND) {
        constInfo.isAccumSeqS2 = true;
    }

    // Compile-time constants of this kernel.
    constInfo.kHeadNum = K_HEAD_NUM;
    constInfo.headDim = HEAD_DIM;
    constInfo.mBaseSize = M_BASE_SIZE;
    constInfo.s2BaseSize = S2_BASE_SIZE;
    constInfo.s1BaseSize = (constInfo.mBaseSize + constInfo.gSize - 1) / constInfo.gSize;
}
// Initializes the on-chip buffers of whichever service runs on this core type:
// the vector service on AIV cores, the matmul service otherwise.
template <typename LIQT>
__aicore__ inline void LIQPreload<LIQT>::InitBuffers()
{
if ASCEND_IS_AIV {
vectorService.InitBuffers(pipe);
} else {
matmulService.InitBuffers(pipe);
}
}
/**
 * Binds the optional actual-sequence-length tensors for Q and K.
 * A null pointer leaves the corresponding dimension count at 0, which makes GetActualSeqLen
 * fall back to its default length; otherwise the count is set to batchSize.
 */
template <typename LIQT>
__aicore__ inline void LIQPreload<LIQT>::InitActualSeqLen(__gm__ uint8_t *actualSeqLengthsQ,
                                                          __gm__ uint8_t *actualSeqLengthsK)
{
    const bool hasQLengths = (actualSeqLengthsQ != nullptr);
    constInfo.actualLenQDims = hasQLengths ? constInfo.batchSize : 0;
    if (hasQLengths) {
        actualSeqLengthsGmQ.SetGlobalBuffer((__gm__ uint32_t *)actualSeqLengthsQ, constInfo.actualLenQDims);
    }

    const bool hasKLengths = (actualSeqLengthsK != nullptr);
    constInfo.actualLenDims = hasKLengths ? constInfo.batchSize : 0;
    if (hasKLengths) {
        actualSeqLengthsGm.SetGlobalBuffer((__gm__ uint32_t *)actualSeqLengthsK, constInfo.actualLenDims);
    }
}
/**
 * Returns the sequence length of batch bIdx.
 *  - actualLenDims == 0: no length tensor was bound, defaultSeqLen is returned.
 *  - isAccumSeq and bIdx > 0: the tensor holds running totals, so the length is the
 *    difference between entry bIdx and entry bIdx - 1.
 *  - otherwise: entry bIdx is the length itself.
 */
template <typename LIQT>
__aicore__ inline uint32_t LIQPreload<LIQT>::GetActualSeqLen(uint32_t bIdx, uint32_t actualLenDims, bool isAccumSeq,
                                                             GlobalTensor<uint32_t> &actualSeqLengthsGm,
                                                             uint32_t defaultSeqLen)
{
    if (actualLenDims == 0) {
        return defaultSeqLen;
    }
    const uint32_t current = actualSeqLengthsGm.GetValue(bIdx);
    if (isAccumSeq && bIdx > 0) {
        return current - actualSeqLengthsGm.GetValue(bIdx - 1);
    }
    return current;
}
/**
 * Fetches the actual Q (S1) and K (S2) sequence lengths of batch bIdx into the two out-parameters,
 * falling back to qSeqSize / kSeqSize when no length tensor is bound.
 */
template <typename LIQT>
__aicore__ inline void LIQPreload<LIQT>::GetS1S2ActualSeqLen(uint32_t bIdx, uint32_t &actS1Size, uint32_t &actS2Size)
{
    // Query side.
    actS1Size = GetActualSeqLen(bIdx, constInfo.actualLenQDims, constInfo.isAccumSeqS1,
                                actualSeqLengthsGmQ, constInfo.qSeqSize);
    // Key side.
    actS2Size = GetActualSeqLen(bIdx, constInfo.actualLenDims, constInfo.isAccumSeqS2,
                                actualSeqLengthsGm, constInfo.kSeqSize);
}
/**
 * Number of S2 base blocks that the S1 base block s1gIdx has to visit when the attention mask is on.
 * The visible S2 length is (actS2Size - actS1Size) plus the end row of the S1 block,
 * clamped to the range [1, actS2Size]. Returns 0 when actS2Size is 0.
 */
template <typename LIQT>
__aicore__ inline uint32_t LIQPreload<LIQT>::GetS2BaseBlockNumOnMask(uint32_t s1gIdx, uint32_t actS1Size,
                                                                     uint32_t actS2Size)
{
    if (actS2Size == 0) {
        return 0;
    }
    const int32_t s2Limit = static_cast<int32_t>(actS2Size);
    const int32_t lenDelta = s2Limit - static_cast<int32_t>(actS1Size);
    const uint32_t s1Begin = constInfo.s1BaseSize * s1gIdx;
    int32_t visibleLen = s1Begin + lenDelta + constInfo.s1BaseSize;
    visibleLen = Max(Min(visibleLen, s2Limit), 1);
    return (visibleLen + constInfo.s2BaseSize - 1) / constInfo.s2BaseSize;
}
/**
 * Counts the base blocks of the whole problem, summed over all batches.
 * Without the attention mask every S1 block visits every S2 block; with the mask the
 * per-S1-block count comes from GetS2BaseBlockNumOnMask.
 */
template <typename LIQT>
__aicore__ inline uint32_t LIQPreload<LIQT>::GetTotalBaseBlockNum()
{
    uint32_t blockTotal = 0;
    for (uint32_t bIdx = 0; bIdx < constInfo.batchSize; bIdx++) {
        uint32_t actS1Size = 0;
        uint32_t actS2Size = 0;
        GetS1S2ActualSeqLen(bIdx, actS1Size, actS2Size);
        const uint32_t s1GBaseNum = CeilDiv(actS1Size, constInfo.s1BaseSize);
        if (constInfo.attenMaskFlag) {
            for (uint32_t s1gIdx = 0; s1gIdx < s1GBaseNum; s1gIdx++) {
                const uint32_t maskedS2Num = GetS2BaseBlockNumOnMask(s1gIdx, actS1Size, actS2Size);
                blockTotal += maskedS2Num * constInfo.kHeadNum;
            }
        } else {
            const uint32_t fullS2Num = CeilDiv(actS2Size, constInfo.s2BaseSize);
            blockTotal += s1GBaseNum * fullS2Num * constInfo.kHeadNum;
        }
    }
    return blockTotal;
}
/**
 * Multi-core split; every range it produces is closed on both ends.
 * Principle: compute the minimum number of base blocks per core, then let the leading cores
 * take one extra block each until the remainder is used up.
 *
 * @param curCoreIdx index of the core whose range is requested
 * @param coreNum    in: number of cores available; out: reduced to the number of cores that
 *                   actually receive work when there are fewer blocks than cores (0 if none)
 * @param info       out: start/end coordinates (bN2, gS1, s2) and the isLD flag of curCoreIdx
 */
template <typename LIQT>
__aicore__ inline void LIQPreload<LIQT>::SplitCore(uint32_t curCoreIdx, uint32_t &coreNum,
                                                   LIQCommon::SplitCoreInfo &info)
{
    if (coreNum == 0U) {
        // No core available: avoid dividing by zero below. coreNum stays 0, which the caller
        // treats as "no compute task".
        return;
    }
    uint32_t totalBlockNum = GetTotalBaseBlockNum();
    uint32_t minBlockPerCore = totalBlockNum / coreNum;
    uint32_t deal1MoreBlockCoreNum = totalBlockNum % coreNum;
    uint32_t coreIdx = 0;
    uint32_t lastGS1RemainBlockCnt = 0;
    uint32_t coreDealBlockCnt = coreIdx < deal1MoreBlockCoreNum ? minBlockPerCore + 1 : minBlockPerCore;
    coreNum = minBlockPerCore == 0 ? deal1MoreBlockCoreNum : coreNum;
    bool findLastCoreEnd = true;
    uint32_t actS1Size = 0U;
    uint32_t actS2Size = 0U;
    uint32_t s1GBaseNum = 0U;
    uint32_t s2BaseNum = 0U;
    for (uint32_t bN2Idx = 0; bN2Idx < constInfo.batchSize * constInfo.kHeadNum; bN2Idx++) {
        uint32_t bIdx = bN2Idx / constInfo.kHeadNum;
        if (bN2Idx % constInfo.kHeadNum == 0) {
            // The sequence lengths only change when moving on to a new batch.
            GetS1S2ActualSeqLen(bIdx, actS1Size, actS2Size);
            s1GBaseNum = CeilDiv(actS1Size, constInfo.s1BaseSize);
            s2BaseNum = CeilDiv(actS2Size, constInfo.s2BaseSize);
        }
        if constexpr (Q_LAYOUT_T == LI_LAYOUT::BSND) {
            if (findLastCoreEnd && (s1GBaseNum == 0U || s2BaseNum == 0U)) {
                info.bN2Start = bN2Idx;
                info.gS1Start = 0;
                info.s2Start = 0;
                findLastCoreEnd = false;
            }
        }
        for (uint32_t gS1Idx = 0; gS1Idx < s1GBaseNum; gS1Idx++) {
            if (constInfo.attenMaskFlag) {
                s2BaseNum = GetS2BaseBlockNumOnMask(gS1Idx, actS1Size, actS2Size);
            }
            if (findLastCoreEnd && s2BaseNum == 0U) {
                info.bN2Start = bN2Idx;
                info.gS1Start = gS1Idx;
                info.s2Start = 0;
                findLastCoreEnd = false;
            }
            for (uint32_t s2Idx = 0; s2Idx < s2BaseNum;) {
                if (findLastCoreEnd) {
                    info.bN2Start = bN2Idx;
                    info.gS1Start = gS1Idx;
                    info.s2Start = s2Idx;
                    findLastCoreEnd = false;
                }
                uint32_t s2RemainBaseNum = s2BaseNum - s2Idx;
                if (lastGS1RemainBlockCnt + s2RemainBaseNum >= coreDealBlockCnt) {
                    info.bN2End = bN2Idx;
                    info.gS1End = gS1Idx;
                    info.s2End = s2Idx + coreDealBlockCnt - lastGS1RemainBlockCnt - 1;
                    if (coreIdx == curCoreIdx) {
                        // When S2 is split across N cores, only the first of them handles LD.
                        if (s2Idx == 0 && info.s2End + 1 < s2BaseNum) {
                            info.isLD = true;
                        }
                        // The last core does not end on the last batch: the remaining batches are
                        // empty (S2 = 0), so move the end coordinates to allow clearing their output.
                        if (coreIdx == coreNum - 1 && info.bN2End != constInfo.batchSize - 1) {
                            info.bN2End = constInfo.batchSize - 1;
                            info.gS1End = 0;
                            info.s2End = 0;
                        }
                        return;
                    }
                    // Move on to the next core and continue right after the block where this one ended.
                    coreIdx++;
                    findLastCoreEnd = true;
                    s2Idx = info.s2End + 1;
                    lastGS1RemainBlockCnt = 0;
                    coreDealBlockCnt = coreIdx < deal1MoreBlockCoreNum ? minBlockPerCore + 1 : minBlockPerCore;
                } else {
                    // The rest of this S2 row fits into the current core's quota; carry it over.
                    lastGS1RemainBlockCnt += s2RemainBaseNum;
                    break;
                }
            }
        }
    }
}
/**
 * Clears the output rows of head n2Idx in batch bIdx, starting at row s1Start (AIV cores only).
 *  - TND output:  rows [s1Start, tempLoopInfo.actS1Size) relative to the batch's T offset.
 *  - BSND output: rows [s1Start, constInfo.qSeqSize).
 * Each row is handed to vectorService.CleanInvalidOutput by its element offset.
 */
template <typename LIQT>
__aicore__ inline void LIQPreload<LIQT>::DealActSeqLenIsZero(uint32_t bIdx, uint32_t n2Idx, uint32_t s1Start)
{
    if ASCEND_IS_AIV {
        // Elements per S1 row (all N2 heads) and the offset of this head inside a row.
        const uint64_t rowStride = constInfo.kHeadNum * constInfo.sparseCount;
        const uint64_t headOffset = n2Idx * constInfo.sparseCount;
        if (constInfo.outputLayout == LI_LAYOUT::TND) {
            // Rows before this batch on the T axis (the length tensor is read as a running total).
            const uint32_t tBase = bIdx == 0 ? 0 : actualSeqLengthsGmQ.GetValue(bIdx - 1);
            const uint32_t s1Count = tempLoopInfo.actS1Size;
            for (uint32_t s1Idx = s1Start; s1Idx < s1Count; s1Idx++) {
                // T axis / S1 axis offset + N2 axis offset
                const uint64_t indiceOutOffset = (tBase + s1Idx) * rowStride + headOffset;
                vectorService.CleanInvalidOutput(indiceOutOffset);
            }
        } else if (constInfo.outputLayout == LI_LAYOUT::BSND) {
            for (uint32_t s1Idx = s1Start; s1Idx < constInfo.qSeqSize; s1Idx++) {
                // B,S1,N2,K: B axis / S1 axis offset + N2 axis offset
                const uint64_t indiceOutOffset =
                    bIdx * constInfo.qSeqSize * rowStride + s1Idx * rowStride + headOffset;
                vectorService.CleanInvalidOutput(indiceOutOffset);
            }
        }
    }
}
// Initializes the kernel object: determines the core index, loads the tiling data, computes this
// core's work range, carves the workspace into its regions and hands the global tensors to the
// vector service (AIV cores) or the matmul service (other cores), then initializes the on-chip buffers.
template <typename LIQT>
__aicore__ inline void LIQPreload<LIQT>::Init(__gm__ uint8_t *query, __gm__ uint8_t *key, __gm__ uint8_t *weights,
__gm__ uint8_t *queryScale, __gm__ uint8_t *keyScale,
__gm__ uint8_t *actualSeqLengthsQ, __gm__ uint8_t *actualSeqLengthsK,
__gm__ uint8_t *blockTable, __gm__ uint8_t *sparseIndices,
__gm__ uint8_t *workspace, const LIQTilingData *__restrict tiling,
TPipe *tPipe)
{
if ASCEND_IS_AIV {
tmpBlockIdx = GetBlockIdx(); // vec:0-47
// Two vector block indices map onto one aiCoreIdx.
aiCoreIdx = tmpBlockIdx / 2;
} else {
tmpBlockIdx = GetBlockIdx(); // cube:0-23
aiCoreIdx = tmpBlockIdx;
}
InitTilingData(tiling);
InitActualSeqLen(actualSeqLengthsQ, actualSeqLengthsK);
// Compute the per-core split (may lower usedCoreNum).
SplitCore(aiCoreIdx, usedCoreNum, splitCoreInfo);
pipe = tPipe;
// Workspace memory layout:
// |mm1ResGm (stores S)|vec1ResGm (stores LD intermediate results)|vec1ParamGm (stores LD parameters)
// |Core0_mm1ResDB0-Core0_mm1ResDB1-Core1_mm1ResDB0....Core23_mm1ResDB0-Core23_mm1ResDB1|Core0_vec1Res...
uint64_t offset = 0;
// mm1 uses double buffering
GlobalTensor<MM1_OUT_T> mm1ResGm; // holds S
uint64_t singleCoreMm1ResSize = WS_DOBULE * constInfo.s1BaseSize * constInfo.s2BaseSize * sizeof(MM1_OUT_T);
mm1ResGm.SetGlobalBuffer((__gm__ MM1_OUT_T *)(workspace + aiCoreIdx * singleCoreMm1ResSize));
offset += GetBlockNum() * singleCoreMm1ResSize;
// Workspace size needed by the LD flow: [aicnum, 2, CeilDiv(constInfo.mBaseSize, constInfo.gSize), topkOut_*2]
// (aic, 8, 2, 2, 2048)
// (aic, s1_cube, head/tail, idx/value, K)
GlobalTensor<float> vec1ResGm; // holds intermediate TopK results
vec1ResGm.SetGlobalBuffer((__gm__ float *)(workspace + offset));
offset += GetBlockNum() * constInfo.s1BaseSize * WS_DOBULE * WS_DOBULE * BASE_TOPK * sizeof(float);
// (aic, 8, 2, 16)
// (aic, s1_cube, head/tail 16 elements)
GlobalTensor<int64_t> vec1ParamGm; // holds the LD parameter information
vec1ParamGm.SetGlobalBuffer((__gm__ int64_t *)(workspace + offset));
offset += GetBlockNum() * constInfo.s1BaseSize * WS_DOBULE * LD_PARAM_NUM * sizeof(int64_t);
GlobalTensor<half> weightWorkspaceGm; // result of w*scale produced in the v1 stage
uint64_t weightMemSize = BLOCK_CUBE * constInfo.mBaseSize * WS_DOBULE * sizeof(half);
weightWorkspaceGm.SetGlobalBuffer((__gm__ half *)(workspace + offset + aiCoreIdx * weightMemSize));
// Not read again below; keeps the running offset consistent with the layout described above.
offset += GetBlockNum() * weightMemSize;
GlobalTensor<half> qScaleGm;
GlobalTensor<half> kScaleGm;
if ASCEND_IS_AIV {
// Vector side: binds weights, scales, output indices and block table.
vectorService.InitParams(constInfo, tiling);
indiceOutGm.SetGlobalBuffer((__gm__ int32_t *)sparseIndices);
weightsGm.SetGlobalBuffer((__gm__ half *)weights);
qScaleGm.SetGlobalBuffer((__gm__ half *)queryScale);
kScaleGm.SetGlobalBuffer((__gm__ half *)keyScale);
blockTableGm.SetGlobalBuffer((__gm__ int32_t *)blockTable);
vectorService.InitVecInputTensor(weightsGm, qScaleGm, kScaleGm, indiceOutGm, blockTableGm);
vectorService.InitVecWorkspaceTensor(weightWorkspaceGm, mm1ResGm, vec1ResGm, vec1ParamGm);
} else {
// Matmul side: binds query, key and (only with paged attention) the block table.
matmulService.InitParams(constInfo);
queryGm.SetGlobalBuffer((__gm__ Q_T *)query);
if constexpr (PAGE_ATTENTION) {
blockTableGm.SetGlobalBuffer((__gm__ int32_t *)blockTable);
}
keyGm.SetGlobalBuffer((__gm__ K_T *)key);
matmulService.InitMm1GlobalTensor(blockTableGm, keyGm, queryGm, mm1ResGm, weightWorkspaceGm);
}
InitBuffers();
}
/**
 * Splits the fused batch/head index bN2Idx into its batch index and N2 (key head) index
 * and stores all three in tempLoopInfo.
 */
template <typename LIQT>
__aicore__ inline void LIQPreload<LIQT>::GetBN2Idx(uint32_t bN2Idx)
{
    const auto headCount = constInfo.kHeadNum;
    tempLoopInfo.bN2Idx = bN2Idx;
    tempLoopInfo.n2Idx = bN2Idx % headCount;
    tempLoopInfo.bIdx = bN2Idx / headCount;
}
/**
 * Prepares tempLoopInfo for the S2 loop of base block (bN2LoopIdx, gS1LoopIdx):
 * records the gS1 index, selects the actual M size (tail size for the last gS1 block,
 * mBaseSize otherwise) and the last S2 index this core has to process.
 */
template <typename LIQT>
__aicore__ inline void LIQPreload<LIQT>::CalcS2LoopParams(uint32_t bN2LoopIdx, uint32_t gS1LoopIdx)
{
    tempLoopInfo.gS1Idx = gS1LoopIdx;

    // Rows of G*S1 that are still left starting from this gS1 block.
    const uint32_t rowsLeft = tempLoopInfo.actS1Size * constInfo.gSize - gS1LoopIdx * constInfo.mBaseSize;
    const bool isTailBlock = (rowsLeft > 0) && (rowsLeft <= constInfo.mBaseSize);
    tempLoopInfo.actMBaseSize = isTailBlock ? tempLoopInfo.mBasicSizeTail : constInfo.mBaseSize;

    uint32_t s2BlockNum = 0;
    if (!constInfo.attenMaskFlag) {
        s2BlockNum = (tempLoopInfo.actS2Size + constInfo.s2BaseSize - 1) / constInfo.s2BaseSize;
    } else {
        s2BlockNum = GetS2BaseBlockNumOnMask(gS1LoopIdx, tempLoopInfo.actS1Size, tempLoopInfo.actS2Size);
    }

    // On the very last (bN2, gS1) pair of this core the S2 loop stops at the split point.
    const bool onCoreLastRow = (bN2LoopIdx == splitCoreInfo.bN2End) && (gS1LoopIdx == splitCoreInfo.gS1End);
    tempLoopInfo.s2LoopEnd = onCoreLastRow ? splitCoreInfo.s2End : s2BlockNum - 1;
}
/**
 * Prepares tempLoopInfo for the gS1 loop of the fused batch/head index bN2LoopIdx:
 * batch/head indices, actual sequence lengths, tail block sizes, the last gS1 index this core
 * processes, and whether the rows beyond the actual S1 length must be cleared afterwards.
 *
 * When either actual length is 0, only curActSeqLenIsZero is set and the remaining fields
 * keep their previous values (apart from the cleanup flag, which is always reset).
 */
template <typename LIQT>
__aicore__ inline void LIQPreload<LIQT>::CalcGS1LoopParams(uint32_t bN2LoopIdx)
{
    GetBN2Idx(bN2LoopIdx);
    // The cleanup flag is per batch: reset it so a value set for an earlier batch cannot leak
    // into this one.
    tempLoopInfo.needDealActS1LessThanS1 = false;
    GetS1S2ActualSeqLen(tempLoopInfo.bIdx, tempLoopInfo.actS1Size, tempLoopInfo.actS2Size);
    if ((tempLoopInfo.actS2Size == 0) || (tempLoopInfo.actS1Size == 0)) {
        tempLoopInfo.curActSeqLenIsZero = true;
        return;
    }
    tempLoopInfo.curActSeqLenIsZero = false;

    // Tail sizes: a remainder of 0 means the last block is a full base block.
    tempLoopInfo.s2BasicSizeTail = tempLoopInfo.actS2Size % constInfo.s2BaseSize;
    tempLoopInfo.s2BasicSizeTail =
        (tempLoopInfo.s2BasicSizeTail == 0) ? constInfo.s2BaseSize : tempLoopInfo.s2BasicSizeTail;
    tempLoopInfo.mBasicSizeTail = (tempLoopInfo.actS1Size * constInfo.gSize) % constInfo.mBaseSize;
    tempLoopInfo.mBasicSizeTail =
        (tempLoopInfo.mBasicSizeTail == 0) ? constInfo.mBaseSize : tempLoopInfo.mBasicSizeTail;

    uint32_t gS1SplitNum = (tempLoopInfo.actS1Size * constInfo.gSize + constInfo.mBaseSize - 1) / constInfo.mBaseSize;
    // On this core's last bN2 index the gS1 loop stops at the split point.
    tempLoopInfo.gS1LoopEnd = (bN2LoopIdx == splitCoreInfo.bN2End) ? splitCoreInfo.gS1End : gS1SplitNum - 1;
    if constexpr (Q_LAYOUT_T == LI_LAYOUT::BSND) {
        // Only the core that processes the final gS1 block clears the rows past the actual S1 length.
        if (tempLoopInfo.gS1LoopEnd == gS1SplitNum - 1 && constInfo.qSeqSize > tempLoopInfo.actS1Size) {
            tempLoopInfo.needDealActS1LessThanS1 = true;
        }
    }
}
/**
 * Fills runInfo for the base block at S2 index s2LoopIdx of the current (bN2, gS1) position.
 *
 * The block is marked invalid (and nothing else is computed) when s2LoopIdx lies past
 * tempLoopInfo.s2LoopEnd. The query / weights / output offsets are only recomputed on the first
 * S2 iteration and are otherwise reused from the cached member offsets; the key offsets are
 * recomputed for every block.
 *
 * @param loop      global iteration counter of the caller
 * @param s2LoopIdx S2 base-block index
 * @param runInfo   out: description of the block
 */
template <typename LIQT>
__aicore__ inline void LIQPreload<LIQT>::CalcRunInfo(uint32_t loop, uint32_t s2LoopIdx, LIQCommon::RunInfo &runInfo)
{
    runInfo.loop = loop;
    runInfo.bIdx = tempLoopInfo.bIdx;
    // n2Idx is read below for the weights / output offsets, so it has to be set here.
    runInfo.n2Idx = tempLoopInfo.n2Idx;
    runInfo.gS1Idx = tempLoopInfo.gS1Idx;
    runInfo.s2Idx = s2LoopIdx;
    runInfo.bN2Idx = tempLoopInfo.bN2Idx;
    runInfo.isValid = s2LoopIdx <= tempLoopInfo.s2LoopEnd;
    if (!runInfo.isValid) {
        return; // TODO: to be verified, the v1 stage needs runInfo
    }
    runInfo.actS1Size = tempLoopInfo.actS1Size;
    runInfo.actS2Size = tempLoopInfo.actS2Size;
    // Actual base block sizes
    runInfo.actMBaseSize = tempLoopInfo.actMBaseSize;
    runInfo.actualSingleProcessSInnerSize = constInfo.s2BaseSize;
    uint32_t s2SplitNum = (tempLoopInfo.actS2Size + constInfo.s2BaseSize - 1) / constInfo.s2BaseSize;
    if (runInfo.s2Idx == s2SplitNum - 1) {
        runInfo.actualSingleProcessSInnerSize = tempLoopInfo.s2BasicSizeTail;
    }
    runInfo.actualSingleProcessSInnerSizeAlign =
        LIQCommon::Align((uint32_t)runInfo.actualSingleProcessSInnerSize, LIQCommon::ConstInfo::BUFFER_SIZE_BYTE_32B);
    runInfo.isFirstS2InnerLoop = s2LoopIdx == splitCoreInfo.s2Start;
    runInfo.isLastS2InnerLoop = s2LoopIdx == tempLoopInfo.s2LoopEnd;
    runInfo.isAllLoopEnd = (runInfo.bN2Idx == splitCoreInfo.bN2End) && (runInfo.gS1Idx == splitCoreInfo.gS1End) &&
                           (runInfo.s2Idx == splitCoreInfo.s2End);
    if (runInfo.isFirstS2InnerLoop) {
        // Number of Q rows that precede this batch.
        uint64_t actualSeqQPrefixSum;
        if constexpr (Q_LAYOUT_T == LI_LAYOUT::TND) {
            actualSeqQPrefixSum = (runInfo.bIdx <= 0) ? 0 : actualSeqLengthsGmQ.GetValue(runInfo.bIdx - 1);
        } else { // BSND
            actualSeqQPrefixSum = (runInfo.bIdx <= 0) ? 0 : runInfo.bIdx * constInfo.qSeqSize;
        }
        uint64_t tndBIdxOffset = actualSeqQPrefixSum * constInfo.qHeadNum * constInfo.headDim;
        // B,S1,N1(N2,G),D
        queryCoreOffset = tndBIdxOffset + runInfo.gS1Idx * constInfo.mBaseSize * constInfo.headDim;
        // B,S1,N1(N2,G)/T,N1(N2,G)
        weightsCoreOffset = actualSeqQPrefixSum * constInfo.qHeadNum + runInfo.n2Idx * constInfo.gSize;
        // B,S1,N2,k/T,N2,k
        indiceOutCoreOffset =
            actualSeqQPrefixSum * constInfo.kHeadNum * constInfo.sparseCount + runInfo.n2Idx * constInfo.sparseCount;
    }
    // Number of K rows that precede this batch.
    uint64_t actualSeqKPrefixSum;
    if constexpr (K_LAYOUT_T == LI_LAYOUT::TND) { // T N2 D
        actualSeqKPrefixSum = (runInfo.bIdx <= 0) ? 0 : actualSeqLengthsGm.GetValue(runInfo.bIdx - 1);
    } else {
        actualSeqKPrefixSum = (runInfo.bIdx <= 0) ? 0 : runInfo.bIdx * constInfo.kSeqSize;
    }
    uint64_t tndBIdxOffsetForK = actualSeqKPrefixSum * constInfo.kHeadNum * constInfo.headDim;
    keyCoreOffset = tndBIdxOffsetForK + runInfo.s2Idx * constInfo.s2BaseSize * constInfo.kHeadNum * constInfo.headDim;
    keyScaleCoreOffset = (actualSeqKPrefixSum + runInfo.s2Idx * constInfo.s2BaseSize) * constInfo.kHeadNum;
    runInfo.tensorQueryOffset = queryCoreOffset;
    runInfo.tensorKeyOffset = keyCoreOffset;
    runInfo.tensorKeyScaleOffset = keyScaleCoreOffset;
    runInfo.tensorWeightsOffset = weightsCoreOffset;
    runInfo.indiceOutOffset = indiceOutCoreOffset;
}
/**
 * Kernel entry point. With at least one working core the main pipeline runs, followed by the
 * decode stage; with no compute task at all the output is only cleared.
 */
template <typename LIQT>
__aicore__ inline void LIQPreload<LIQT>::Process()
{
    if (usedCoreNum != 0) {
        ProcessMain();
        ProcessDecode();
    } else {
        // No compute task: just clear the output.
        ProcessInvalid();
    }
}
/**
 * Fills the whole output with INVALID_IDX (AIV cores only). The output is divided into
 * aligned per-core slices indexed by tmpBlockIdx; cores whose slice starts past the end of
 * the output do nothing.
 */
template <typename LIQT>
__aicore__ inline void LIQPreload<LIQT>::ProcessInvalid()
{
    if ASCEND_IS_AIV {
        uint32_t vecCoreCnt = GetBlockNum() * 2; // 2 means c:v = 1:2
        uint64_t outputElems = constInfo.batchSize * constInfo.qSeqSize * constInfo.kHeadNum * constInfo.sparseCount;
        uint64_t sliceElems =
            LIQCommon::Align((outputElems + vecCoreCnt - 1) / vecCoreCnt, GM_ALIGN_BYTES / sizeof(OUT_T));
        uint64_t sliceBegin = tmpBlockIdx * sliceElems;
        if (sliceBegin >= outputElems) {
            return;
        }
        // The last slice that has work may be shorter than a full slice.
        uint64_t fillElems = outputElems - sliceBegin;
        if (sliceBegin + sliceElems <= outputElems) {
            fillElems = sliceElems;
        }
        GlobalTensor<OUT_T> output = indiceOutGm[sliceBegin];
        AscendC::InitGlobalMemory(output, fillElems, constInfo.INVALID_IDX);
    }
}
// Main pipeline of one core: walks the (bN2, gS1, s2) range assigned in splitCoreInfo and calls
// ProcessBaseBlock once per S2 base block. The last (bN2, gS1) pair gets
// LI_QUANT_PRELOAD_TASK_CACHE_SIZE - 1 extra iterations so that the block still pending in the
// RunInfo cache is processed. Batches with a zero actual length only have their output cleared.
template <typename LIQT>
__aicore__ inline void LIQPreload<LIQT>::ProcessMain()
{
if (aiCoreIdx >= usedCoreNum) {
// Cores without work return immediately
return;
}
// Each side sets its "reverse" flag twice up front and waits on the peer's flag twice at the end.
// NOTE(review): presumably one set/wait per double-buffer slot — confirm against the services.
if ASCEND_IS_AIV {
vectorService.AllocEventID();
CrossCoreSetFlag<LIQCommon::ConstInfo::FIA_SYNC_MODE2, PIPE_MTE2>(constInfo.syncV1C1);
CrossCoreSetFlag<LIQCommon::ConstInfo::FIA_SYNC_MODE2, PIPE_MTE2>(constInfo.syncV1C1);
} else {
matmulService.AllocEventID();
CrossCoreSetFlag<LIQCommon::ConstInfo::FIA_SYNC_MODE2, PIPE_FIX>(constInfo.syncC1V0);
CrossCoreSetFlag<LIQCommon::ConstInfo::FIA_SYNC_MODE2, PIPE_FIX>(constInfo.syncC1V0);
}
LIQCommon::RunInfo runInfo[LI_QUANT_PRELOAD_TASK_CACHE_SIZE];
// Global iteration counter; its parity selects the RunInfo slot.
uint32_t gloop = 0;
for (uint32_t bN2LoopIdx = splitCoreInfo.bN2Start; bN2LoopIdx <= splitCoreInfo.bN2End; bN2LoopIdx++) {
CalcGS1LoopParams(bN2LoopIdx);
if (tempLoopInfo.curActSeqLenIsZero) {
// Empty batch: clear its output rows.
DealActSeqLenIsZero(tempLoopInfo.bIdx, tempLoopInfo.n2Idx, 0U);
if ASCEND_IS_AIV {
// If this empty batch is the core's last one, run the vector stage for the RunInfo slot
// written by the previous iteration, since no extra loop iteration will do it.
if (bN2LoopIdx == splitCoreInfo.bN2End && gloop > 0) {
CrossCoreWaitFlag(constInfo.syncC1V1);
vectorService.ProcessVec1(runInfo[1 - gloop % LI_QUANT_PRELOAD_TASK_CACHE_SIZE]);
CrossCoreSetFlag<LIQCommon::ConstInfo::FIA_SYNC_MODE2, PIPE_MTE3>(
constInfo.syncV1C1); // reverse sync 1
}
}
continue;
}
for (uint32_t gS1LoopIdx = splitCoreInfo.gS1Start; gS1LoopIdx <= tempLoopInfo.gS1LoopEnd; gS1LoopIdx++) {
CalcS2LoopParams(bN2LoopIdx, gS1LoopIdx);
bool isEnd = (bN2LoopIdx == splitCoreInfo.bN2End) && (gS1LoopIdx == splitCoreInfo.gS1End);
// Extra iteration(s) on the final row to process the block still held in the RunInfo cache.
uint32_t extraLoop = isEnd ? LI_QUANT_PRELOAD_TASK_CACHE_SIZE - 1 : 0;
for (int s2LoopIdx = splitCoreInfo.s2Start; s2LoopIdx <= (tempLoopInfo.s2LoopEnd + extraLoop);
s2LoopIdx++) {
ProcessBaseBlock(gloop, s2LoopIdx, runInfo);
++gloop;
}
// Only the first row of the core starts at a non-zero S2 index.
splitCoreInfo.s2Start = 0;
}
if (tempLoopInfo.needDealActS1LessThanS1) {
// Clear the output rows beyond the actual S1 length.
DealActSeqLenIsZero(tempLoopInfo.bIdx, tempLoopInfo.n2Idx, tempLoopInfo.actS1Size);
}
// Only the first bN2 index of the core starts at a non-zero gS1 index.
splitCoreInfo.gS1Start = 0;
}
if ASCEND_IS_AIV {
vectorService.FreeEventID();
CrossCoreWaitFlag(constInfo.syncC1V0);
CrossCoreWaitFlag(constInfo.syncC1V0);
} else {
matmulService.FreeEventID();
CrossCoreWaitFlag(constInfo.syncV1C1);
CrossCoreWaitFlag(constInfo.syncV1C1);
}
}
// Processes one pipeline step. The RunInfo slot selected by the parity of loop is filled for the
// current block; if it is valid, the AIC side runs the matmul and the AIV side runs ProcessVec0 on
// the first S2 iteration. Then, if the other slot (previous step) is valid, the AIV side runs
// ProcessVec1 on it and the slot is invalidated.
template <typename LIQT>
__aicore__ inline void LIQPreload<LIQT>::ProcessBaseBlock(uint32_t loop, uint64_t s2LoopIdx,
LIQCommon::RunInfo runInfo[LI_QUANT_PRELOAD_TASK_CACHE_SIZE])
{
int32_t curTaskId = loop % LI_QUANT_PRELOAD_TASK_CACHE_SIZE;
LIQCommon::RunInfo &curRunInfo = runInfo[curTaskId];
LIQCommon::RunInfo &lastRunInfo = runInfo[1 - curTaskId];
CalcRunInfo(loop, s2LoopIdx, curRunInfo);
if (curRunInfo.isValid) {
if ASCEND_IS_AIC {
if (curRunInfo.isFirstS2InnerLoop) {
// Wait for the vector side's ProcessVec0 of this row.
CrossCoreWaitFlag(constInfo.syncV0C1);
}
CrossCoreWaitFlag(constInfo.syncV1C1); // reverse sync 1
matmulService.ComputeMm1(curRunInfo);
CrossCoreSetFlag<LIQCommon::ConstInfo::FIA_SYNC_MODE2, PIPE_FIX>(constInfo.syncC1V1);
if (curRunInfo.isLastS2InnerLoop) {
CrossCoreSetFlag<LIQCommon::ConstInfo::FIA_SYNC_MODE2, PIPE_FIX>(constInfo.syncC1V0); // reverse sync 0
}
} else {
if (curRunInfo.isFirstS2InnerLoop) {
CrossCoreWaitFlag(constInfo.syncC1V0); // reverse sync 0
vectorService.ProcessVec0(curRunInfo);
CrossCoreSetFlag<LIQCommon::ConstInfo::FIA_SYNC_MODE2, PIPE_MTE3>(constInfo.syncV0C1);
}
}
}
if (lastRunInfo.isValid) {
if ASCEND_IS_AIV {
// Consume the matmul result of the previous step.
CrossCoreWaitFlag(constInfo.syncC1V1);
vectorService.ProcessVec1(lastRunInfo);
CrossCoreSetFlag<LIQCommon::ConstInfo::FIA_SYNC_MODE2, PIPE_MTE3>(constInfo.syncV1C1); // reverse sync 1
}
// Mark the slot as consumed on both core types.
lastRunInfo.isValid = false;
}
}
// Decode (LD) stage, AIV cores only: initializes the LD buffers, preloads the instruction cache,
// synchronizes via SyncAll, and then runs ProcessLD on the cores flagged isLD by SplitCore.
template <typename LIQT>
__aicore__ inline void LIQPreload<LIQT>::ProcessDecode()
{
if ASCEND_IS_AIV {
vectorService.InitLDBuffers(pipe);
ICachePreLoad(LD_PREFETCH_LEN);
// Reached by every AIV core that gets here, whether or not it runs ProcessLD afterwards.
SyncAll();
if (splitCoreInfo.isLD) {
vectorService.ProcessLD();
}
}
}
} // namespace LIQKernel
#endif // LIGHTNING_INDEXER_QUANT_KERNEL_H

View File

@@ -0,0 +1,613 @@
/**
* This program is free software, you can redistribute it and/or modify it.
* Copyright (c) 2025 Huawei Technologies Co., Ltd.
* This file is a part of the CANN Open Software.
* Licensed under CANN Open Software License Agreement Version 2.0 (the "License").
* Please refer to the License for details. You may not use this file except in compliance with the License.
* THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE.
* See LICENSE in the root of the software repository for the full text of the License.
*/
/*!
* \file lightning_indexer_quant_service_cube.h
* \brief use 5 buffer for matmul l1, better pipeline
*/
#ifndef LIGHTNING_INDEXER_QUANT_SERVICE_CUBE_H
#define LIGHTNING_INDEXER_QUANT_SERVICE_CUBE_H
#include "kernel_operator.h"
#include "kernel_operator_list_tensor_intf.h"
#include "kernel_tiling/kernel_tiling.h"
#include "lib/matmul_intf.h"
#include "lib/matrix/matmul/tiling.h"
#include "lightning_indexer_quant_common.h"
namespace LIQKernel {
using namespace LIQCommon;
// Per-iteration matmul bookkeeping passed between the LIQMatmul helper routines.
// NOTE(review): the field meanings below are inferred from the names only — confirm against CalcMmInfo.
// The members have no default initializers, so an instance must be filled before it is read.
struct MmInfo {
int64_t s2L0LoopId;   // presumably the loop index along S2 at L0 granularity
int64_t s1gL0LoopId;  // presumably the loop index along S1*G at L0 granularity
int64_t s2L0RealSize; // S2 size passed to LoadSToL0b / ComputeWs in ProcessWs
int64_t s2GmOffset;   // presumably the S2 offset in global memory
};
// Matmul (cube-side) service of the quantized lightning indexer. It keeps the global tensors,
// the L1 / L0 buffers and the rotating buffer indices used by ComputeMm1 and its helpers.
// LIQT supplies the query/key element types and the layouts.
template <typename LIQT>
class LIQMatmul {
public:
using Q_T = typename LIQT::queryType;
using K_T = typename LIQT::keyType;
__aicore__ inline LIQMatmul(){};
// Allocates the L1 / L0A / L0B / L0C buffers from the pipe.
__aicore__ inline void InitBuffers(TPipe *pipe);
// Stores the global tensors used by the matmul stage.
__aicore__ inline void InitMm1GlobalTensor(const GlobalTensor<int32_t> &blkTableGm, const GlobalTensor<K_T> &keyGm,
const GlobalTensor<Q_T> &queryGm, const GlobalTensor<float> &mm1ResGm,
const GlobalTensor<half> &weightWorkspaceGm);
// Stores a copy of the kernel-wide constants.
__aicore__ inline void InitParams(const ConstInfo &constInfo);
__aicore__ inline void AllocEventID();
__aicore__ inline void FreeEventID();
// Runs the matmul stage for the base block described by runInfo.
__aicore__ inline void ComputeMm1(const LIQCommon::RunInfo &runInfo);
static constexpr IsResetLoad3dConfig LOAD3DV2_CONFIG = {true, true}; // isSetFMatrix isSetPadding;
static constexpr uint64_t DOUBLE_BUF_NUM = 2;
static constexpr uint64_t L0AB_BUF_NUM = 4;
// Event IDs used for synchronization between hardware pipes.
static constexpr uint32_t KEY_MTE1_MTE2_EVENT = EVENT_ID2;
// NOTE(review): the trailing comment suggests KEY_MTE1_MTE2_EVENT + DOUBLE_BUF_NUM, which would be
// EVENT_ID4 if the IDs are consecutive, while the value is EVENT_ID5 — confirm which is intended.
static constexpr uint32_t QW_MTE1_MTE2_EVENT = EVENT_ID5; // KEY_MTE1_MTE2_EVENT + DOUBLE_BUF_NUM;
static constexpr uint32_t M_MTE1_EVENT = EVENT_ID3;
static constexpr uint32_t M_FIX_EVENT = EVENT_ID0;
static constexpr uint32_t FIX_M_EVENT = EVENT_ID2;
static constexpr uint32_t FIX_MTE1_EVENT = EVENT_ID4;
static constexpr uint64_t S8_BLOCK_CUBE = 32;
static constexpr uint32_t MTE2_MTE1_EVENT = EVENT_ID2;
static constexpr uint32_t MTE1_M_EVENT = EVENT_ID2;
// Basic block sizes (in elements) for the D, S1*G and S2 axes.
static constexpr uint64_t D_BASIC_BLOCK = 128;
static constexpr uint64_t S1G_BASIC_BLOCK_L1 = 256;
static constexpr uint64_t S1G_BASIC_BLOCK_L0 = 128;
static constexpr uint64_t S2_BASIC_BLOCK_L0 = 128;
// Element strides derived from the basic block sizes above.
static constexpr uint64_t QUERY_BUFFER_OFFSET = S1G_BASIC_BLOCK_L1 * D_BASIC_BLOCK;
static constexpr uint64_t SL1_BUFFER_OFFSET = S1G_BASIC_BLOCK_L0 * S2_BASIC_BLOCK_L0;
static constexpr uint64_t KEY_BUFFER_OFFSET = S2_BASIC_BLOCK_L0 * D_BASIC_BLOCK;
static constexpr uint64_t WEIGHT_BUFFER_OFFSET = S1G_BASIC_BLOCK_L1 * BLOCK_CUBE;
static constexpr uint64_t L0AB_BUFFER_OFFSET_S8_16K = 16 * 1024;
static constexpr uint64_t L0AB_BUFFER_OFFSET_FP16_16K = 16 * 512;
static constexpr uint64_t L0C_BUFFER_OFFSET = 64 * 256;
private:
__aicore__ inline void WeightDmaCopy(uint64_t s1gL1RealSize, const LIQCommon::RunInfo &runInfo);
__aicore__ inline void LoadKeyToL0b(uint64_t s2L0RealSize);
__aicore__ inline void LoadQueryToL0a(uint64_t s1gL1Offset, uint64_t s1gL1RealSize, uint64_t s1gL0RealSize);
__aicore__ inline void QueryNd2Nz(uint64_t s1gL1RealSize, const LIQCommon::RunInfo &runInfo);
__aicore__ inline void KeyNd2NzForPA(uint64_t s2L1RealSize, uint64_t s2GmOffset, const LIQCommon::RunInfo &runInfo);
__aicore__ inline void KeyNd2Nz(uint64_t s2L1RealSize, const MmInfo &mmInfo, const LIQCommon::RunInfo &runInfo);
__aicore__ inline void FixpSToL1(uint64_t s1gL0RealSize, uint64_t s2L0RealSize);
__aicore__ inline void LoadSToL0b(uint64_t s1gL1RealSize, uint64_t s2L0RealSize, uint64_t sL1BufIdx,
int64_t mStartPt);
__aicore__ inline void LoadWeightToL0a(uint64_t s1gL1Offset);
__aicore__ inline void ComputeWs(uint64_t s1gL0RealSize, uint64_t s2L0RealSize, int64_t s1gOffset);
__aicore__ inline void FixpResToGm(uint64_t s1L0RealCount, uint64_t s2L0RealSize, uint64_t s1GmOffset,
uint64_t s2GmOffset, const LIQCommon::RunInfo &runInfo);
__aicore__ inline void ComputeQk(uint64_t s1gL0RealSize, uint64_t s2L0RealSize);
__aicore__ inline void ProcessWs(uint64_t s1gL0RealSize, uint64_t s1gL1Offset, uint64_t sL1BufIdx,
const MmInfo &mmInfo, const LIQCommon::RunInfo &runInfo);
__aicore__ inline void ProcessQk(uint64_t s1gL0RealSize, uint64_t s1gL1Offset, uint64_t s1L0LoopCnt,
const MmInfo &mmInfo, const LIQCommon::RunInfo &runInfo);
__aicore__ inline void CalcMmInfo(MmInfo &mmInfo, uint64_t loopIdx, uint64_t s1L0LoopCnt, const MmInfo &lastMmInfo,
const LIQCommon::RunInfo &runInfo);
static constexpr LI_LAYOUT Q_LAYOUT_T = LIQT::layout;
static constexpr LI_LAYOUT K_LAYOUT_T = LIQT::keyLayout;
// Global tensors stored by InitMm1GlobalTensor.
GlobalTensor<int32_t> blkTableGm_;
GlobalTensor<K_T> keyGm_;
GlobalTensor<Q_T> queryGm_;
GlobalTensor<half> weightGm_;
GlobalTensor<float> mm1ResGm_;
// On-chip buffers and the local tensors obtained from them in InitBuffers.
TBuf<TPosition::A1> bufQL1_;
LocalTensor<Q_T> queryL1_;
TBuf<TPosition::B1> bufKeyL1_;
LocalTensor<K_T> keyL1_;
TBuf<TPosition::A1> bufWeightL1_;
LocalTensor<half> weightL1_;
TBuf<TPosition::B1> bufSL1_;
LocalTensor<half> sL1_;
TBuf<TPosition::A2> bufL0A_;
LocalTensor<Q_T> l0a_;
TBuf<TPosition::B2> bufL0B_;
LocalTensor<K_T> l0b_;
TBuf<TPosition::CO1> bufL0C_;
LocalTensor<int32_t> cL0_;
// Running counters; the buffer slot in use is selected as counter % buffer count.
uint64_t keyL1BufIdx_ = 0;
uint64_t qwL1Mte2BufIdx_ = 0;
uint64_t sL1BufIdx_ = 0;
uint64_t l0BufIdx_ = 0;
uint64_t l0cBufIdx_ = 0;
ConstInfo constInfo_;
};
// Stores a private copy of the kernel-wide constants that every later stage reads.
template <typename LIQT>
__aicore__ inline void LIQMatmul<LIQT>::InitParams(const ConstInfo &constInfo)
{
    // constInfo_ is held by value, so this is a full copy of the caller's object.
    this->constInfo_ = constInfo;
}
// Reserves the L1 / L0A / L0B / L0C working buffers from the pipe and caches a typed tensor
// view on each of them. The InitBuffer calls are issued in a fixed order.
template <typename LIQT>
__aicore__ inline void LIQMatmul<LIQT>::InitBuffers(TPipe *pipe)
{
    // L1 staging buffers: each holds DOUBLE_BUF_NUM tiles.
    const auto queryL1Bytes = DOUBLE_BUF_NUM * S1G_BASIC_BLOCK_L1 * D_BASIC_BLOCK * sizeof(Q_T);
    const auto keyL1Bytes = DOUBLE_BUF_NUM * S2_BASIC_BLOCK_L0 * D_BASIC_BLOCK * sizeof(K_T);
    const auto weightL1Bytes = DOUBLE_BUF_NUM * S1G_BASIC_BLOCK_L1 * BLOCK_CUBE * sizeof(half);
    const auto scoreL1Bytes = DOUBLE_BUF_NUM * S2_BASIC_BLOCK_L0 * S1G_BASIC_BLOCK_L0 * sizeof(half);
    // L0 buffers use fixed byte sizes.
    constexpr int l0aBytes = 64 * 1024;
    constexpr int l0bBytes = 64 * 1024;
    constexpr int l0cBytes = 128 * 1024;

    pipe->InitBuffer(bufQL1_, queryL1Bytes);
    queryL1_ = bufQL1_.Get<Q_T>();

    pipe->InitBuffer(bufKeyL1_, keyL1Bytes);
    keyL1_ = bufKeyL1_.Get<K_T>();

    pipe->InitBuffer(bufWeightL1_, weightL1Bytes);
    weightL1_ = bufWeightL1_.Get<half>();

    pipe->InitBuffer(bufSL1_, scoreL1Bytes);
    sL1_ = bufSL1_.Get<half>();

    pipe->InitBuffer(bufL0A_, l0aBytes);
    l0a_ = bufL0A_.Get<Q_T>();

    pipe->InitBuffer(bufL0B_, l0bBytes);
    l0b_ = bufL0B_.Get<K_T>();

    pipe->InitBuffer(bufL0C_, l0cBytes);
    cL0_ = bufL0C_.Get<int32_t>();
}
// Binds the global-memory tensors used by the matmul stages.
//   blkTableGm        - block table (read by the paged key loader)
//   keyGm / queryGm   - key and query inputs
//   mm1ResGm          - float workspace that receives the matmul result
//   weightWorkspaceGm - half workspace the weights are read from
template <typename LIQT>
__aicore__ inline void LIQMatmul<LIQT>::InitMm1GlobalTensor(const GlobalTensor<int32_t> &blkTableGm,
                                                            const GlobalTensor<K_T> &keyGm,
                                                            const GlobalTensor<Q_T> &queryGm,
                                                            const GlobalTensor<float> &mm1ResGm,
                                                            const GlobalTensor<half> &weightWorkspaceGm)
{
    // Inputs.
    this->queryGm_ = queryGm;
    this->keyGm_ = keyGm;
    this->blkTableGm_ = blkTableGm;
    // Workspaces.
    this->weightGm_ = weightWorkspaceGm;
    this->mm1ResGm_ = mm1ResGm;
}
// Second matmul stage for one s1g tile. For every slice of constInfo_.gSize rows it loads the
// scores from L1 into L0B and the weights from L1 into L0A, multiplies them (ComputeWs) into the
// current L0C buffer, and finally writes the L0C content to GM through FixpResToGm.
//   s1gL0RealSize - number of s1*g rows covered by this tile
//   s1gL1Offset   - row offset of this tile; added to the slice offset for the weight load and
//                   divided by gSize to obtain the s1 offset in GM
//   sL1BufIdx     - index of the L1 score buffer to read (LoadSToL0b takes it modulo DOUBLE_BUF_NUM)
//   mmInfo        - tile descriptor; s2L0RealSize and s2L0LoopId are read here
//   runInfo       - forwarded to FixpResToGm
template <typename LIQT>
__aicore__ inline void LIQMatmul<LIQT>::ProcessWs(uint64_t s1gL0RealSize, uint64_t s1gL1Offset, uint64_t sL1BufIdx,
const MmInfo &mmInfo, const LIQCommon::RunInfo &runInfo)
{
// Wait until the current L0C buffer has been released (flag is re-armed at the end of this function).
WaitFlag<HardEvent::FIX_M>(FIX_M_EVENT + l0cBufIdx_ % DOUBLE_BUF_NUM);
// NOTE(review): s1gOffset is signed while s1gL0RealSize is unsigned; the comparison is done unsigned.
for (int64_t s1gOffset = 0; s1gOffset < s1gL0RealSize; s1gOffset += constInfo_.gSize) {
// Wait until the L0A/L0B slot selected by l0BufIdx_ is free.
WaitFlag<HardEvent::M_MTE1>(M_MTE1_EVENT + l0BufIdx_ % L0AB_BUF_NUM);
LoadSToL0b(s1gL0RealSize, mmInfo.s2L0RealSize, sL1BufIdx, s1gOffset);
LoadWeightToL0a(s1gOffset + s1gL1Offset);
ComputeWs(s1gL0RealSize, mmInfo.s2L0RealSize, s1gOffset);
// Release the slot and advance to the next one.
SetFlag<HardEvent::M_MTE1>(M_MTE1_EVENT + l0BufIdx_ % L0AB_BUF_NUM);
l0BufIdx_++;
}
// One output row per group: row count and row offset are both expressed in units of gSize.
FixpResToGm(s1gL0RealSize / constInfo_.gSize, mmInfo.s2L0RealSize, s1gL1Offset / constInfo_.gSize,
mmInfo.s2L0LoopId * S2_BASIC_BLOCK_L0, runInfo);
// Release the L0C buffer and switch to the other one.
SetFlag<HardEvent::FIX_M>(FIX_M_EVENT + l0cBufIdx_ % DOUBLE_BUF_NUM);
l0cBufIdx_++;
}
// First matmul stage (query x key) for one (s2 tile, s1g tile) pair; the result is moved from
// L0C into the L1 score buffer by FixpSToL1.
//   s1gL0RealSize - number of s1*g rows in this tile
//   s1gL1Offset   - row offset of this tile inside the L1 query buffer
//   s1L0LoopCnt   - number of s1g tiles per s2 tile; used to detect the last s1g tile
//   mmInfo        - tile descriptor (s1gL0LoopId, s2L0RealSize, s2GmOffset are read)
//   runInfo       - per-iteration run information (s2Idx, actMBaseSize, ... are read)
template <typename LIQT>
__aicore__ inline void LIQMatmul<LIQT>::ProcessQk(uint64_t s1gL0RealSize, uint64_t s1gL1Offset, uint64_t s1L0LoopCnt,
const MmInfo &mmInfo, const LIQCommon::RunInfo &runInfo)
{
// The key tile is fetched from GM only for the first s1g tile of each s2 tile.
if (mmInfo.s1gL0LoopId == 0) {
// Wait until the key L1 buffer selected by keyL1BufIdx_ has been released.
WaitFlag<HardEvent::MTE1_MTE2>(KEY_MTE1_MTE2_EVENT + keyL1BufIdx_ % DOUBLE_BUF_NUM);
if constexpr (K_LAYOUT_T == LI_LAYOUT::PA_BSND) {
// Paged layout: the start position is the offset of this s2 base block plus the tile offset.
KeyNd2NzForPA(mmInfo.s2L0RealSize, runInfo.s2Idx * constInfo_.s2BaseSize + mmInfo.s2GmOffset, runInfo);
} else {
KeyNd2Nz(mmInfo.s2L0RealSize, mmInfo, runInfo);
}
// Set-then-wait pair: blocks until the GM -> L1 copy issued above is visible to the L1 -> L0 load.
SetFlag<HardEvent::MTE2_MTE1>(MTE2_MTE1_EVENT);
WaitFlag<HardEvent::MTE2_MTE1>(MTE2_MTE1_EVENT);
}
// Wait until the L0A/L0B slot selected by l0BufIdx_ is free.
WaitFlag<HardEvent::M_MTE1>(M_MTE1_EVENT + l0BufIdx_ % L0AB_BUF_NUM);
LoadQueryToL0a(s1gL1Offset, runInfo.actMBaseSize, s1gL0RealSize);
LoadKeyToL0b(mmInfo.s2L0RealSize);
// After the last s1g tile the key L1 buffer is released and the next buffer is selected.
if (mmInfo.s1gL0LoopId + 1 >= s1L0LoopCnt) {
SetFlag<HardEvent::MTE1_MTE2>(KEY_MTE1_MTE2_EVENT + keyL1BufIdx_ % DOUBLE_BUF_NUM);
keyL1BufIdx_++;
}
// Wait until the current L0C buffer is free before accumulating into it.
WaitFlag<HardEvent::FIX_M>(FIX_M_EVENT + l0cBufIdx_ % DOUBLE_BUF_NUM);
ComputeQk(s1gL0RealSize, mmInfo.s2L0RealSize);
// The L0A/L0B slot can be reused once the matmul has been issued.
SetFlag<HardEvent::M_MTE1>(M_MTE1_EVENT + l0BufIdx_ % L0AB_BUF_NUM);
FixpSToL1(s1gL0RealSize, mmInfo.s2L0RealSize);
// Release the L0C buffer after its content has been moved to L1.
SetFlag<HardEvent::FIX_M>(FIX_M_EVENT + l0cBufIdx_ % DOUBLE_BUF_NUM);
l0BufIdx_++;
l0cBufIdx_++;
}
// Fills the tile descriptor for flat loop index loopIdx.
// The flat index is split into an s2 tile id (quotient) and an s1g tile id (remainder).
// The s2 offset and size are computed only for the first s1g tile of an s2 tile; for the other
// s1g tiles only s2L0RealSize is taken over from lastMmInfo (s2GmOffset is left untouched).
template <typename LIQT>
__aicore__ inline void LIQMatmul<LIQT>::CalcMmInfo(MmInfo &mmInfo, uint64_t loopIdx, uint64_t s1L0LoopCnt,
                                                   const MmInfo &lastMmInfo, const LIQCommon::RunInfo &runInfo)
{
    mmInfo.s2L0LoopId = loopIdx / s1L0LoopCnt;
    mmInfo.s1gL0LoopId = loopIdx % s1L0LoopCnt;

    // Same s2 tile as the previous step: reuse its size.
    if (mmInfo.s1gL0LoopId != 0) {
        mmInfo.s2L0RealSize = lastMmInfo.s2L0RealSize;
        return;
    }

    // New s2 tile: compute its offset and clip the size at the end of the valid s2 range.
    mmInfo.s2GmOffset = mmInfo.s2L0LoopId * S2_BASIC_BLOCK_L0;
    const bool isTailTile = mmInfo.s2GmOffset + S2_BASIC_BLOCK_L0 > runInfo.actualSingleProcessSInnerSize;
    mmInfo.s2L0RealSize =
        isTailTile ? runInfo.actualSingleProcessSInnerSize - mmInfo.s2GmOffset : S2_BASIC_BLOCK_L0;
}
// Runs both matmul stages over all (s2 tile, s1g tile) pairs of one basic block.
// The two stages are software-pipelined: ProcessQk for tile i is issued before ProcessWs for
// tile i-1, and the last ProcessWs is issued after the loop. mmInfo[2] is a two-entry ring that
// holds the descriptors of the current and the previous tile.
template <typename LIQT>
__aicore__ inline void LIQMatmul<LIQT>::ComputeMm1(const LIQCommon::RunInfo &runInfo)
{
// Query and weight are loaded into L1 only on the first s2 inner loop and then reused.
if (runInfo.isFirstS2InnerLoop) {
WaitFlag<HardEvent::MTE1_MTE2>(QW_MTE1_MTE2_EVENT + qwL1Mte2BufIdx_ % DOUBLE_BUF_NUM);
QueryNd2Nz(runInfo.actMBaseSize, runInfo); // 256 * 128 // L1BasicBlock
WeightDmaCopy(runInfo.actMBaseSize, runInfo);
}
int64_t loopIdx = 0;
int64_t s2L0LoopCnt = CeilDiv(runInfo.actualSingleProcessSInnerSize, S2_BASIC_BLOCK_L0); // e.g. 2048 tiled by 128
int64_t s1L0LoopCnt = CeilDiv(runInfo.actMBaseSize, S1G_BASIC_BLOCK_L0); // e.g. 256 tiled by 128
// NOTE(review): both lookup tables have two entries, so this presumably relies on
// s1L0LoopCnt <= 2 - confirm against the tiling.
int64_t s1gL1Offset[2] = {0, static_cast<int64_t>(S1G_BASIC_BLOCK_L0)};
int64_t s1gL0RealSize[2] = {s1L0LoopCnt > 1 ? static_cast<int64_t>(S1G_BASIC_BLOCK_L0) : runInfo.actMBaseSize,
runInfo.actMBaseSize - s1gL1Offset[1]};
MmInfo mmInfo[2];
// Pipeline prologue: first Q*K tile, no W*S yet.
CalcMmInfo(mmInfo[loopIdx & 1], loopIdx, s1L0LoopCnt, mmInfo[(loopIdx + 1) & 1], runInfo);
ProcessQk(s1gL0RealSize[mmInfo[loopIdx & 1].s1gL0LoopId % s1L0LoopCnt],
s1gL1Offset[mmInfo[loopIdx & 1].s1gL0LoopId % s1L0LoopCnt], s1L0LoopCnt, mmInfo[loopIdx & 1],
runInfo);
// Signal that the score buffer just written is ready, then switch to the other one.
SetFlag<HardEvent::FIX_MTE1>(FIX_MTE1_EVENT + sL1BufIdx_ % DOUBLE_BUF_NUM);
sL1BufIdx_++;
loopIdx++;
// Steady state: Q*K for tile loopIdx, then W*S for tile loopIdx - 1.
while (loopIdx < s2L0LoopCnt * s1L0LoopCnt) {
CalcMmInfo(mmInfo[loopIdx & 1], loopIdx, s1L0LoopCnt, mmInfo[(loopIdx + 1) & 1], runInfo);
ProcessQk(s1gL0RealSize[mmInfo[loopIdx & 1].s1gL0LoopId % s1L0LoopCnt],
s1gL1Offset[mmInfo[loopIdx & 1].s1gL0LoopId % s1L0LoopCnt], s1L0LoopCnt, mmInfo[loopIdx & 1],
runInfo);
SetFlag<HardEvent::FIX_MTE1>(FIX_MTE1_EVENT + sL1BufIdx_ % DOUBLE_BUF_NUM);
sL1BufIdx_++;
// After the increment sL1BufIdx_ has the same parity as the buffer written one tile earlier.
WaitFlag<HardEvent::FIX_MTE1>(FIX_MTE1_EVENT + sL1BufIdx_ % DOUBLE_BUF_NUM);
ProcessWs(s1gL0RealSize[mmInfo[(loopIdx + 1) & 1].s1gL0LoopId % s1L0LoopCnt],
s1gL1Offset[mmInfo[(loopIdx + 1) & 1].s1gL0LoopId % s1L0LoopCnt], sL1BufIdx_,
mmInfo[(loopIdx + 1) & 1], runInfo);
loopIdx++;
}
// Pipeline epilogue: W*S for the last tile. loopIdx is one past the last tile here, so
// (loopIdx + 1) & 1 selects that tile's descriptor and sL1BufIdx_ - 1 its score buffer.
WaitFlag<HardEvent::FIX_MTE1>(FIX_MTE1_EVENT + (sL1BufIdx_ + 1) % DOUBLE_BUF_NUM);
ProcessWs(s1gL0RealSize[mmInfo[(loopIdx + 1) & 1].s1gL0LoopId % s1L0LoopCnt],
s1gL1Offset[mmInfo[(loopIdx + 1) & 1].s1gL0LoopId % s1L0LoopCnt], sL1BufIdx_ - 1,
mmInfo[(loopIdx + 1) & 1], runInfo);
// Release the query/weight L1 buffer after the last s2 inner loop and switch to the other one.
if (runInfo.isLastS2InnerLoop) {
SetFlag<HardEvent::MTE1_MTE2>(QW_MTE1_MTE2_EVENT + qwL1Mte2BufIdx_ % DOUBLE_BUF_NUM);
qwL1Mte2BufIdx_++;
}
}
// blkNum, blkSize, N2, D
// Loads s2L1RealSize key rows, starting at logical position s2GmOffset of batch runInfo.bIdx, from
// a paged key cache into the current key L1 buffer. The rows are fetched block by block: each
// iteration looks up the physical block id in the block table and copies at most the rest of that block.
template <typename LIQT>
__aicore__ inline void LIQMatmul<LIQT>::KeyNd2NzForPA(uint64_t s2L1RealSize, uint64_t s2GmOffset,
const LIQCommon::RunInfo &runInfo)
{
uint64_t s2L1Offset = 0;
while (s2L1Offset < s2L1RealSize) {
// Logical block index and row offset inside that block for the next row to copy.
uint64_t s2BlkId = (s2L1Offset + s2GmOffset) / constInfo_.kCacheBlockSize;
uint64_t s2BlkOffset = (s2L1Offset + s2GmOffset) % constInfo_.kCacheBlockSize;
// Element offset in keyGm_: start of the physical block plus the row offset inside it.
// NOTE(review): the row offset is scaled by headDim only (not headDim * kHeadNum), which
// presumably assumes kHeadNum == 1 - confirm.
uint64_t keyGmOffset = blkTableGm_.GetValue(runInfo.bIdx * constInfo_.maxBlockNumPerBatch + s2BlkId) *
constInfo_.kCacheBlockSize * constInfo_.kHeadNum * constInfo_.headDim +
s2BlkOffset * constInfo_.headDim;
// Rows to copy in this iteration: the remaining rows, clipped at the end of the block.
uint64_t s2Mte2Size = s2L1RealSize - s2L1Offset;
s2Mte2Size = s2BlkOffset + s2Mte2Size >= constInfo_.kCacheBlockSize ? constInfo_.kCacheBlockSize - s2BlkOffset
: s2Mte2Size;
Nd2NzParams nd2nzPara;
nd2nzPara.ndNum = 1;
nd2nzPara.nValue = s2Mte2Size; // number of rows
nd2nzPara.dValue = constInfo_.headDim;
nd2nzPara.srcDValue = constInfo_.headDim;
nd2nzPara.dstNzC0Stride = CeilAlign(s2L1RealSize, (uint64_t)BLOCK_CUBE); // aligned to 16; unit: block
nd2nzPara.dstNzNStride = 1;
nd2nzPara.srcNdMatrixStride = 0;
nd2nzPara.dstNzMatrixStride = 0;
// Destination: current key L1 buffer, advanced by the rows already copied.
DataCopy(keyL1_[(keyL1BufIdx_ % DOUBLE_BUF_NUM) * KEY_BUFFER_OFFSET + s2L1Offset * S8_BLOCK_CUBE],
keyGm_[keyGmOffset], nd2nzPara);
s2L1Offset += s2Mte2Size;
}
}
// Loads s2L1RealSize key rows of a non-paged key tensor from GM into the current key L1 buffer
// with an ND -> NZ copy. The source starts at runInfo.tensorKeyOffset plus the tile's s2 offset.
template <typename LIQT>
__aicore__ inline void LIQMatmul<LIQT>::KeyNd2Nz(uint64_t s2L1RealSize, const MmInfo &mmInfo,
                                                 const LIQCommon::RunInfo &runInfo)
{
    // For BSND / TND the distance between consecutive rows spans all key heads; otherwise it
    // is a single head dimension.
    constexpr bool rowSpansAllHeads = (K_LAYOUT_T == LI_LAYOUT::BSND || K_LAYOUT_T == LI_LAYOUT::TND);
    const uint64_t srcRowPitch = rowSpansAllHeads ? constInfo_.headDim * constInfo_.kHeadNum : constInfo_.headDim;

    Nd2NzParams copyCfg;
    // Source description.
    copyCfg.ndNum = 1;
    copyCfg.nValue = s2L1RealSize;  // number of rows
    copyCfg.dValue = constInfo_.headDim;
    copyCfg.srcDValue = srcRowPitch;
    copyCfg.srcNdMatrixStride = 0;
    // Destination description.
    copyCfg.dstNzC0Stride = CeilAlign(s2L1RealSize, (uint64_t)BLOCK_CUBE);  // aligned to 16; unit: block
    copyCfg.dstNzNStride = 1;
    copyCfg.dstNzMatrixStride = 0;

    // By default one buffer holds at most two copies (double buffering).
    const auto l1Offset = (keyL1BufIdx_ % DOUBLE_BUF_NUM) * KEY_BUFFER_OFFSET;
    const auto gmOffset = runInfo.tensorKeyOffset + mmInfo.s2GmOffset * constInfo_.headDim;
    DataCopy(keyL1_[l1Offset], keyGm_[gmOffset], copyCfg);
}
// batch, s1, g, 1
// Copies the weights of the current basic block from the GM workspace into the current weight
// L1 buffer as one contiguous burst. The workspace half is selected by runInfo.loop.
template <typename LIQT>
__aicore__ inline void LIQMatmul<LIQT>::WeightDmaCopy(uint64_t s1gL1RealSize, const LIQCommon::RunInfo &runInfo)
{
    DataCopyParams burstCfg;
    burstCfg.srcStride = 0;
    burstCfg.dstStride = 0;
    burstCfg.blockLen = s1gL1RealSize;
    burstCfg.blockCount = 1;

    const auto l1Offset = (qwL1Mte2BufIdx_ % DOUBLE_BUF_NUM) * WEIGHT_BUFFER_OFFSET;
    const auto gmOffset = (runInfo.loop % DOUBLE_BUF_NUM) * BLOCK_CUBE * constInfo_.mBaseSize;
    DataCopy(weightL1_[l1Offset], weightGm_[gmOffset], burstCfg);
}
// batch, s1, n2, g, d
// Loads s1gL1RealSize query rows from GM (starting at runInfo.tensorQueryOffset) into the
// current query L1 buffer with an ND -> NZ copy.
template <typename LIQT>
__aicore__ inline void LIQMatmul<LIQT>::QueryNd2Nz(uint64_t s1gL1RealSize, const LIQCommon::RunInfo &runInfo)
{
    Nd2NzParams copyCfg;
    // Source description: rows are densely packed, headDim elements each.
    copyCfg.ndNum = 1;
    copyCfg.nValue = s1gL1RealSize;  // number of rows
    copyCfg.dValue = constInfo_.headDim;
    copyCfg.srcDValue = constInfo_.headDim;
    copyCfg.srcNdMatrixStride = 0;
    // Destination description.
    copyCfg.dstNzC0Stride = CeilAlign(s1gL1RealSize, (uint64_t)BLOCK_CUBE);  // aligned to 16; unit: block
    copyCfg.dstNzNStride = 1;
    copyCfg.dstNzMatrixStride = 0;

    // By default one buffer holds at most two copies (double buffering).
    const auto l1Offset = (qwL1Mte2BufIdx_ % DOUBLE_BUF_NUM) * QUERY_BUFFER_OFFSET;
    DataCopy(queryL1_[l1Offset], queryGm_[runInfo.tensorQueryOffset], copyCfg);
}
// s1g, d
// Moves one query tile from the current query L1 buffer into the current L0A slot using a 3D load.
//   s1gL1Offset   - first row of the tile inside the L1 buffer
//   s1gL1RealSize - number of rows held in the L1 buffer
//   s1gL0RealSize - number of rows to load (rounded up to BLOCK_CUBE)
template <typename LIQT>
__aicore__ inline void LIQMatmul<LIQT>::LoadQueryToL0a(uint64_t s1gL1Offset, uint64_t s1gL1RealSize,
                                                       uint64_t s1gL0RealSize)
{
    LoadData3DParamsV2<Q_T> cfg;

    // Source feature-map description (SetFmatrixParams).
    cfg.l1H = CeilDiv(s1gL1RealSize, BLOCK_CUBE);  // Hin = M1
    cfg.l1W = BLOCK_CUBE;                          // Win = M0
    cfg.channelSize = constInfo_.headDim;          // Cin = K
    for (int side = 0; side < 3; ++side) {
        cfg.padList[side] = 0;
    }
    cfg.padList[3] = 255;  // tail data does not affect the sliding-window result

    // 1x1 filter with unit stride and dilation.
    cfg.filterW = 1;
    cfg.filterH = 1;
    cfg.filterSizeW = (1 >> 8) & 255;
    cfg.filterSizeH = (1 >> 8) & 255;
    cfg.strideW = 1;
    cfg.strideH = 1;
    cfg.dilationFilterW = 1;
    cfg.dilationFilterH = 1;

    // Window to load (SetLoadToA0Params).
    cfg.mStartPt = s1gL1Offset;
    cfg.kStartPt = 0;
    cfg.mExtension = CeilAlign(s1gL0RealSize, BLOCK_CUBE);  // extent along M
    cfg.kExtension = constInfo_.headDim;                    // extent along K
    cfg.enTranspose = 0;
    cfg.fMatrixCtrl = 0;

    const auto l0Offset = (l0BufIdx_ % L0AB_BUF_NUM) * L0AB_BUFFER_OFFSET_S8_16K;
    const auto l1Offset = (qwL1Mte2BufIdx_ % DOUBLE_BUF_NUM) * QUERY_BUFFER_OFFSET;
    LoadData<Q_T, LOAD3DV2_CONFIG>(l0a_[l0Offset], queryL1_[l1Offset], cfg);
}
// s1, g, s2 --> 2 * 64* 128
// Loads one group (constInfo_.gSize rows starting at mStartPt) of the half-precision scores from
// the selected L1 score buffer into the current L0B slot, with the transpose option enabled.
// s1gL1RealSize is accepted for interface symmetry and is not used.
template <typename LIQT>
__aicore__ inline void LIQMatmul<LIQT>::LoadSToL0b(uint64_t s1gL1RealSize, uint64_t s2L0RealSize, uint64_t sL1BufIdx,
                                                   int64_t mStartPt)
{
    // The s2 extent is rounded up to a multiple of BLOCK_CUBE for both channel size and K extent.
    const auto alignedS2 = CeilAlign(s2L0RealSize, BLOCK_CUBE);

    LoadData3DParamsV2<half> cfg;

    // Source feature-map description (SetFmatrixParams).
    cfg.l1H = S1G_BASIC_BLOCK_L0 / BLOCK_CUBE;  // Hin = M1
    cfg.l1W = BLOCK_CUBE;                       // Win = M0
    cfg.channelSize = alignedS2;                // Cin = K
    for (int side = 0; side < 3; ++side) {
        cfg.padList[side] = 0;
    }
    cfg.padList[3] = 255;  // tail data does not affect the sliding-window result

    // 1x1 filter with unit stride and dilation.
    cfg.filterW = 1;
    cfg.filterH = 1;
    cfg.filterSizeW = (1 >> 8) & 255;
    cfg.filterSizeH = (1 >> 8) & 255;
    cfg.strideW = 1;
    cfg.strideH = 1;
    cfg.dilationFilterW = 1;
    cfg.dilationFilterH = 1;

    // Window to load (SetLoadToA0Params).
    cfg.mStartPt = mStartPt;
    cfg.kStartPt = 0;
    cfg.mExtension = constInfo_.gSize;  // extent along M
    cfg.kExtension = alignedS2;         // extent along K
    cfg.enTranspose = 1;
    cfg.fMatrixCtrl = 0;

    const auto l0Offset = (l0BufIdx_ % L0AB_BUF_NUM) * L0AB_BUFFER_OFFSET_FP16_16K;
    const auto l1Offset = (sL1BufIdx % DOUBLE_BUF_NUM) * SL1_BUFFER_OFFSET;
    LoadData<half, LOAD3DV2_CONFIG>(l0b_.template ReinterpretCast<half>()[l0Offset], sL1_[l1Offset], cfg);
}
// s1,g,1(16), 2,64,16
// Loads the weights of one group (starting at row s1gL1Offset of the current weight L1 buffer)
// into the current L0A slot with a transposing 2D load.
template <typename LIQT>
__aicore__ inline void LIQMatmul<LIQT>::LoadWeightToL0a(uint64_t s1gL1Offset)
{
    LoadData2DParams cfg;
    cfg.ifTranspose = true;
    cfg.startIndex = 0;
    cfg.srcStride = 1;
    cfg.dstGap = 0;
    // One repeat per BLOCK_CUBE rows of the group.
    cfg.repeatTimes = CeilDiv(constInfo_.gSize, BLOCK_CUBE);

    const auto l0Offset = (l0BufIdx_ % L0AB_BUF_NUM) * L0AB_BUFFER_OFFSET_FP16_16K;
    // Each weight row occupies BLOCK_CUBE elements in L1.
    const auto l1Offset = (qwL1Mte2BufIdx_ % DOUBLE_BUF_NUM) * WEIGHT_BUFFER_OFFSET + s1gL1Offset * BLOCK_CUBE;
    LoadData(l0a_.template ReinterpretCast<half>()[l0Offset], weightL1_[l1Offset], cfg);
}
// s2, d -> 128,128
// Loads the key tile held in the current key L1 buffer into the current L0B slot with a
// non-transposing 2D load.
template <typename LIQT>
__aicore__ inline void LIQMatmul<LIQT>::LoadKeyToL0b(uint64_t s2L0RealSize)
{
    // Repeat count = (row blocks of BLOCK_CUBE) x (depth blocks of S8_BLOCK_CUBE).
    const auto rowBlocks = CeilDiv(s2L0RealSize, BLOCK_CUBE);
    const auto depthBlocks = CeilDiv(constInfo_.headDim, S8_BLOCK_CUBE);

    LoadData2DParams cfg;
    cfg.ifTranspose = false;
    cfg.startIndex = 0;
    cfg.srcStride = 1;
    cfg.dstGap = 0;
    cfg.repeatTimes = rowBlocks * depthBlocks;

    const auto l0Offset = (l0BufIdx_ % L0AB_BUF_NUM) * L0AB_BUFFER_OFFSET_S8_16K;
    const auto l1Offset = (keyL1BufIdx_ % DOUBLE_BUF_NUM) * KEY_BUFFER_OFFSET;
    LoadData(l0b_[l0Offset], keyL1_[l1Offset], cfg);
}
// A: s1,g,1(16) B: s1,g,s2 C: s1, 1(16), s2
// Multiplies the weights in the current L0A slot with the scores in the current L0B slot
// (both viewed as half) and writes the float result into the current L0C buffer at the
// position of the group that starts at s1gOffset.
// s1gL0RealSize is not used by this function.
template <typename LIQT>
__aicore__ inline void LIQMatmul<LIQT>::ComputeWs(uint64_t s1gL0RealSize, uint64_t s2L0RealSize, int64_t s1gOffset)
{
// Set-then-wait pair: blocks until the preceding L1 -> L0 loads are visible to the matmul.
SetFlag<HardEvent::MTE1_M>(MTE1_M_EVENT);
WaitFlag<HardEvent::MTE1_M>(MTE1_M_EVENT);
MmadParams mmadParams;
mmadParams.m = BLOCK_CUBE;
mmadParams.n = s2L0RealSize;
mmadParams.k = constInfo_.gSize;
mmadParams.cmatrixInitVal = true;
mmadParams.cmatrixSource = false;
// The L0C int32 buffer is reinterpreted as float for this stage.
Mmad(cL0_.template ReinterpretCast<float>()[(l0cBufIdx_ % DOUBLE_BUF_NUM) * L0C_BUFFER_OFFSET +
s1gOffset * S2_BASIC_BLOCK_L0],
l0a_.template ReinterpretCast<half>()[(l0BufIdx_ % L0AB_BUF_NUM) * L0AB_BUFFER_OFFSET_FP16_16K],
l0b_.template ReinterpretCast<half>()[(l0BufIdx_ % L0AB_BUF_NUM) * L0AB_BUFFER_OFFSET_FP16_16K],
mmadParams);
}
// Multiplies the query tile in the current L0A slot with the key tile in the current L0B slot
// and writes the int32 result to the start of the current L0C buffer.
//   s1gL0RealSize - number of query rows (rounded up to BLOCK_CUBE for the M dimension)
//   s2L0RealSize  - number of key rows (N dimension)
template <typename LIQT>
__aicore__ inline void LIQMatmul<LIQT>::ComputeQk(uint64_t s1gL0RealSize, uint64_t s2L0RealSize)
{
// Set-then-wait pair: blocks until the preceding L1 -> L0 loads are visible to the matmul.
SetFlag<HardEvent::MTE1_M>(MTE1_M_EVENT);
WaitFlag<HardEvent::MTE1_M>(MTE1_M_EVENT);
MmadParams mmadParams;
mmadParams.m = CeilAlign(s1gL0RealSize, BLOCK_CUBE);
mmadParams.n = s2L0RealSize;
mmadParams.k = constInfo_.headDim;
mmadParams.cmatrixInitVal = true;
mmadParams.cmatrixSource = false;
Mmad(cL0_[(l0cBufIdx_ % DOUBLE_BUF_NUM) * L0C_BUFFER_OFFSET],
l0a_[(l0BufIdx_ % L0AB_BUF_NUM) * L0AB_BUFFER_OFFSET_S8_16K],
l0b_[(l0BufIdx_ % L0AB_BUF_NUM) * L0AB_BUFFER_OFFSET_S8_16K], mmadParams);
// For small outputs (fewer than 10 blocks of 16x16) an explicit barrier on the M pipe is inserted.
// NOTE(review): presumably a hardware-specific workaround for short matmuls - confirm the reason.
if ((mmadParams.m / 16) * (mmadParams.n / 16) < 10) {
PipeBarrier<PIPE_M>();
}
}
// Moves the first-stage result from the current L0C buffer into the current L1 score buffer,
// converting it to half (DEQF16) with the ReLU option enabled.
//   s1gL0RealSize / s2L0RealSize - tile extents; both are rounded up to BLOCK_CUBE for the copy
template <typename LIQT>
__aicore__ inline void LIQMatmul<LIQT>::FixpSToL1(uint64_t s1gL0RealSize, uint64_t s2L0RealSize)
{
// Set-then-wait pair: blocks until the preceding matmul result is visible to the fixpipe copy.
SetFlag<HardEvent::M_FIX>(M_FIX_EVENT);
WaitFlag<HardEvent::M_FIX>(M_FIX_EVENT);
DataCopyCO12DstParams params;
params.mSize = CeilAlign(s1gL0RealSize, BLOCK_CUBE);
params.nSize = CeilAlign(s2L0RealSize, BLOCK_CUBE);
params.dstStride = S1G_BASIC_BLOCK_L0;
params.srcStride = params.mSize;
params.quantPre = QuantMode_t::DEQF16;
params.reluPre = 1;
params.channelSplit = 0;
params.nz2ndEn = 0;
// 0x3a800000 is the IEEE-754 single-precision bit pattern of 2^-10 (1/1024).
// NOTE(review): presumably used as the dequantisation scale - confirm against the API docs.
SetFixpipePreQuantFlag(0x3a800000);
DataCopy(sL1_[(sL1BufIdx_ % DOUBLE_BUF_NUM) * SL1_BUFFER_OFFSET],
cL0_[(l0cBufIdx_ % DOUBLE_BUF_NUM) * L0C_BUFFER_OFFSET], params);
}
// Writes the second-stage result from the current L0C buffer (viewed as float) to the mm1ResGm
// workspace in ND layout, one output row per s1 position.
//   s1L0RealCount - number of s1 rows to write (passed as the ND matrix count)
//   s2L0RealSize  - number of valid columns per row
//   s1GmOffset    - first destination row; multiplied by the row stride (constInfo_.s2BaseSize)
//   s2GmOffset    - first destination column
//   runInfo       - runInfo.loop selects which half of the workspace is written
template <typename LIQT>
__aicore__ inline void LIQMatmul<LIQT>::FixpResToGm(uint64_t s1L0RealCount, uint64_t s2L0RealSize, uint64_t s1GmOffset,
uint64_t s2GmOffset, const LIQCommon::RunInfo &runInfo)
{
// Set-then-wait pair: blocks until the preceding matmul result is visible to the fixpipe copy.
SetFlag<HardEvent::M_FIX>(M_FIX_EVENT);
WaitFlag<HardEvent::M_FIX>(M_FIX_EVENT);
AscendC::DataCopyCO12DstParams intriParams;
intriParams.mSize = 1;
intriParams.nSize = s2L0RealSize;
intriParams.dstStride = constInfo_.s2BaseSize;
intriParams.srcStride = 16;
// set mode according to dtype
intriParams.quantPre = QuantMode_t::NoQuant;
intriParams.nz2ndEn = true;
intriParams.reluPre = 0;
// NOTE(review): the last argument is hard-coded to 2048 while the row stride above uses
// constInfo_.s2BaseSize; the two agree only if s2BaseSize == 2048 - confirm.
AscendC::SetFixpipeNz2ndFlag(s1L0RealCount, CeilDiv(constInfo_.gSize, BLOCK_CUBE) * S2_BASIC_BLOCK_L0 / BLOCK_CUBE,
2048);
AscendC::DataCopy(mm1ResGm_[(runInfo.loop % 2) * constInfo_.mBaseSize / constInfo_.gSize * constInfo_.s2BaseSize +
s1GmOffset * intriParams.dstStride + s2GmOffset],
cL0_.template ReinterpretCast<float>()[(l0cBufIdx_ % DOUBLE_BUF_NUM) * L0C_BUFFER_OFFSET],
intriParams);
}
// Pre-sets every "buffer is free" flag so that the first WaitFlag on each buffer slot passes.
// Must be balanced by FreeEventID, which waits on the same set of flags.
template <typename LIQT>
__aicore__ inline void LIQMatmul<LIQT>::AllocEventID()
{
    // Key L1 buffers: three event ids.
    for (int slot = 0; slot < 3; ++slot) {
        SetFlag<HardEvent::MTE1_MTE2>(KEY_MTE1_MTE2_EVENT + slot);
    }
    // Query/weight L1 buffers: two event ids.
    for (int slot = 0; slot < 2; ++slot) {
        SetFlag<HardEvent::MTE1_MTE2>(QW_MTE1_MTE2_EVENT + slot);
    }
    // L0A/L0B slots: four event ids.
    for (int slot = 0; slot < 4; ++slot) {
        SetFlag<HardEvent::M_MTE1>(M_MTE1_EVENT + slot);
    }
    // L0C buffers: two event ids.
    for (int slot = 0; slot < 2; ++slot) {
        SetFlag<HardEvent::FIX_M>(FIX_M_EVENT + slot);
    }
}
// Consumes every flag that AllocEventID pre-set (and the processing stages re-armed), so that
// no flag is left pending when the kernel finishes.
template <typename LIQT>
__aicore__ inline void LIQMatmul<LIQT>::FreeEventID()
{
    // Key L1 buffers: three event ids.
    for (int slot = 0; slot < 3; ++slot) {
        WaitFlag<HardEvent::MTE1_MTE2>(KEY_MTE1_MTE2_EVENT + slot);
    }
    // Query/weight L1 buffers: two event ids.
    for (int slot = 0; slot < 2; ++slot) {
        WaitFlag<HardEvent::MTE1_MTE2>(QW_MTE1_MTE2_EVENT + slot);
    }
    // L0A/L0B slots: four event ids.
    for (int slot = 0; slot < 4; ++slot) {
        WaitFlag<HardEvent::M_MTE1>(M_MTE1_EVENT + slot);
    }
    // L0C buffers: two event ids.
    for (int slot = 0; slot < 2; ++slot) {
        WaitFlag<HardEvent::FIX_M>(FIX_M_EVENT + slot);
    }
}
} // namespace LIQKernel
#endif

View File

@@ -0,0 +1,665 @@
/**
* This program is free software, you can redistribute it and/or modify it.
* Copyright (c) 2025 Huawei Technologies Co., Ltd.
* This file is a part of the CANN Open Software.
* Licensed under CANN Open Software License Agreement Version 2.0 (the "License").
* Please refer to the License for details. You may not use this file except in compliance with the License.
* THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE.
* See LICENSE in the root of the software repository for the full text of the License.
*/
/*!
* \file lightning_indexer_quant_service_vector.h
* \brief
*/
#ifndef LIGHTNING_INDEXER_QUANT_SERVICE_VECTOR_H
#define LIGHTNING_INDEXER_QUANT_SERVICE_VECTOR_H
#include "kernel_operator.h"
#include "kernel_operator_list_tensor_intf.h"
#include "kernel_tiling/kernel_tiling.h"
#include "lib/matmul_intf.h"
#include "lib/matrix/matmul/tiling.h"
#include "lightning_indexer_quant_common.h"
#include "lightning_indexer_quant_vector.h"
namespace LIQKernel {
using namespace LIQCommon;
using namespace LIQServiceVec;
// Number of top-k entries kept per row.
constexpr uint32_t BASE_TOPK = 2048;
// Element count of one top-k list stored as (value, index) pairs, i.e. 2 * BASE_TOPK.
constexpr uint32_t BASE_TOPK_VALUE_IDX_SIZE = 4096;
// Number of int64 parameters in the LD parameter buffer.
constexpr uint32_t LD_PARAM_NUM = 16;
// Vector-side service of the lightning-indexer-quant kernel.
// ProcessVec0 scales the weights by the query scale and writes them to a GM workspace;
// ProcessVec1 consumes the matmul result from the mm1ResGm workspace.
// The layout and paging options are taken from the LIQT traits type.
template <typename LIQT>
class LIQVector {
public:
// ================================= Type definitions =================================
static constexpr LI_LAYOUT Q_LAYOUT_T = LIQT::layout;
static constexpr LI_LAYOUT K_LAYOUT_T = LIQT::keyLayout;
static constexpr bool PAGE_ATTENTION = LIQT::pageAttention;
// Matmul output data type; only float is supported at the moment.
using MM1_OUT_T = float;
__aicore__ inline LIQVector(){};
__aicore__ inline void ProcessVec0(const LIQCommon::RunInfo &info);
__aicore__ inline void ProcessVec1(const LIQCommon::RunInfo &info);
__aicore__ inline void ProcessLD();
__aicore__ inline void InitBuffers(TPipe *pipe);
__aicore__ inline void InitParams(const struct LIQCommon::ConstInfo &constInfo,
const LIQTilingData *__restrict tilingData);
__aicore__ inline void InitVecWorkspaceTensor(GlobalTensor<half> vec0OutGm, GlobalTensor<MM1_OUT_T> mm1ResGm,
GlobalTensor<float> vec1ResGm, GlobalTensor<int64_t> vec1ParamGm);
__aicore__ inline void InitVecInputTensor(GlobalTensor<half> weightsGm, GlobalTensor<half> qScaleGm,
GlobalTensor<half> kScaleGm, GlobalTensor<int32_t> indiceOutGm,
GlobalTensor<int32_t> blockTableGm);
__aicore__ inline void CleanInvalidOutput(int64_t invalidS1offset);
__aicore__ inline void AllocEventID();
__aicore__ inline void FreeEventID();
__aicore__ inline void InitLDBuffers(TPipe *pipe);
protected:
// Global-memory tensors: workspaces first, then kernel inputs/outputs.
GlobalTensor<MM1_OUT_T> mm1ResGm;
GlobalTensor<float> vec1ResGm;
GlobalTensor<int64_t> vec1ParamGm;
GlobalTensor<half> weightsGm;
GlobalTensor<half> qScaleGm;
GlobalTensor<half> kScaleGm;
GlobalTensor<half> vec0OutGm;
GlobalTensor<int32_t> indiceOutGm;
GlobalTensor<int32_t> blockTableGm;
// ================================= Constants =================================
private:
// Copies getLen key-scale values starting at position startS2 of batch batchId into resUb.
__aicore__ inline void GetKeyScale(const LIQCommon::RunInfo &runInfo, const LocalTensor<half> &resUb,
int64_t batchId, int64_t startS2, int64_t getLen);
// ================================ Local buffers ====================================
// queue
TQue<QuePosition::VECIN, 1> inQueue_;
TQue<QuePosition::VECOUT, 1> outQueue_;
// tmp buff for vector
TBuf<TPosition::VECCALC> sortOutBuf_;
TBuf<TPosition::VECCALC> indexBuf_;
TBuf<TPosition::VECCALC> paramBuf_;
TBuf<TPosition::VECCALC> tmpBuf_;
// tmp buff for LD
TBuf<> ldToBeMrgBuf_;
TBuf<> ldTmpBuf_;
TBuf<> ldOutValueBuf_;
TBuf<> ldOutIdxBuf_;
LocalTensor<int32_t> globalTopkIndice_;
LocalTensor<float> globalTopkUb_;
// Block index of this core; set from GetBlockIdx() in InitParams.
int32_t blockId_ = -1;
// para for vector
int32_t groupInner_ = 0;
int32_t globalTopkNum_ = 0;
int64_t blockS2StartIdx_ = 0;
int32_t gSize_ = 0;
int32_t kSeqSize_ = 0;
int32_t kHeadNum_ = 0;
int32_t qHeadNum_ = 0;
int32_t s1BaseSize_ = 0;
int32_t s2BaseSize_ = 0;
int32_t kCacheBlockSize_ = 0;
int32_t maxBlockNumPerBatch_ = 0;
// para for LD
uint32_t mrgListNum_ = 4;
// NOTE(review): same value as the file-level LD_PARAM_NUM; presumably meant to stay in sync.
uint32_t paramNum_ = 16;
struct LIQCommon::ConstInfo constInfo_;
};
// Copies getLen half-precision key-scale values into resUb.
// With PAGE_ATTENTION the values are gathered block by block through the block table, starting
// at logical position startS2 of batch batchId; otherwise a single contiguous copy starting at
// runInfo.tensorKeyScaleOffset is issued.
// NOTE(review): the non-paged branch does not use startS2 or batchId; presumably
// runInfo.tensorKeyScaleOffset already includes them - confirm against the caller.
template <typename LIQT>
__aicore__ inline void LIQVector<LIQT>::GetKeyScale(const LIQCommon::RunInfo &runInfo, const LocalTensor<half> &resUb,
int64_t batchId, int64_t startS2, int64_t getLen)
{
// Original note: startS2 is always a multiple of kCacheBlockSize_. The code below nevertheless
// handles a non-zero offset inside the first block.
AscendC::DataCopyPadExtParams<half> padParams{false, 0, 0, 0};
AscendC::DataCopyExtParams copyInParams;
if constexpr (PAGE_ATTENTION) {
// Logical block index and offset inside that block for the first value.
int32_t startBlockTableIdx = startS2 / kCacheBlockSize_;
int32_t startBlockTableOffset = startS2 % kCacheBlockSize_;
int32_t blockTableBatchOffset = batchId * maxBlockNumPerBatch_;
copyInParams.blockCount = 1;
copyInParams.srcStride = 0;
copyInParams.dstStride = 0;
copyInParams.rsv = 0;
int32_t resUbBaseOffset = 0;
// Partial first block: copy from the offset up to the end of the block (or getLen, if smaller).
if (startBlockTableOffset > 0) {
int32_t firstPartLen =
kCacheBlockSize_ - startBlockTableOffset > getLen ? getLen : kCacheBlockSize_ - startBlockTableOffset;
copyInParams.blockLen = firstPartLen * sizeof(half);
int32_t blockId = blockTableGm.GetValue(blockTableBatchOffset + startBlockTableIdx);
// Synchronise the scalar block-table read with the copy that uses its result.
SetWaitFlag<HardEvent::S_MTE2>(HardEvent::S_MTE2);
AscendC::DataCopyPad(resUb, kScaleGm[blockId * kCacheBlockSize_ + startBlockTableOffset],
copyInParams, padParams);
startBlockTableIdx++;
getLen = getLen - firstPartLen;
resUbBaseOffset = firstPartLen;
}
// Remaining values: whole blocks, with a shorter copy for the last one.
int32_t getLoopNum = CeilDiv(getLen, kCacheBlockSize_);
copyInParams.blockLen = kCacheBlockSize_ * sizeof(half);
for (int32_t i = 0; i < getLoopNum; i++) {
if (i == getLoopNum - 1) {
copyInParams.blockLen = (getLen - i * kCacheBlockSize_) * sizeof(half);
}
int32_t blockId = blockTableGm.GetValue(blockTableBatchOffset + startBlockTableIdx + i);
SetWaitFlag<HardEvent::S_MTE2>(HardEvent::S_MTE2);
AscendC::DataCopyPad(resUb[resUbBaseOffset + i * kCacheBlockSize_], kScaleGm[blockId * kCacheBlockSize_],
copyInParams, padParams);
}
} else {
// Contiguous key scale: one copy of getLen values.
copyInParams.blockCount = 1;
copyInParams.blockLen = getLen * sizeof(half);
copyInParams.srcStride = 0;
copyInParams.dstStride = 0;
copyInParams.rsv = 0;
AscendC::DataCopyPad(resUb, kScaleGm[runInfo.tensorKeyScaleOffset], copyInParams, padParams);
}
}
// Allocates the vector-stage buffers and initialises the persistent state: the index ramp, the
// running top-k buffer, and this core's slice of the vec1ParamGm workspace (filled with -1).
// Reads s1BaseSize_, s2BaseSize_ and blockId_, so InitParams must have been called before.
// Size comments below assume s1BaseSize_ = 4 and s2BaseSize_ = 2048 (values noted in InitParams).
template <typename LIQT>
__aicore__ inline void LIQVector<LIQT>::InitBuffers(TPipe *pipe)
{
pipe->InitBuffer(paramBuf_, LD_PARAM_NUM * sizeof(int64_t)); // 16 * 8 B = 128 B
pipe->InitBuffer(inQueue_, 2, s2BaseSize_ * sizeof(float) * 2); // 2 x 16 KB = 32 KB
pipe->InitBuffer(outQueue_, 1, BASE_TOPK * sizeof(float)); // 8 KB
pipe->InitBuffer(indexBuf_, s2BaseSize_ * sizeof(int32_t)); // 8 KB
pipe->InitBuffer(tmpBuf_, 64 * 1024); // 64 KB
pipe->InitBuffer(sortOutBuf_, CeilDiv(s1BaseSize_, 2) * BASE_TOPK_VALUE_IDX_SIZE * sizeof(float)); // 32 KB
globalTopkIndice_ = indexBuf_.Get<int32_t>();
globalTopkUb_ = sortOutBuf_.Get<float>();
globalTopkNum_ = 0;
// Initialise UB and GM before the basic blocks are processed.
// step1. Build an ordered index ramp 0 .. s2BaseSize_ - 1.
ArithProgression<int32_t>(globalTopkIndice_, 0, 1, s2BaseSize_);
// step2. globalTopkUb_ [CeilDiv(s1BaseSize_, 2), BASE_TOPK, 2] is initialised to (-inf, -1).
InitSortOutBuf(globalTopkUb_, CeilDiv(s1BaseSize_, 2) * BASE_TOPK_VALUE_IDX_SIZE);
// step3. Initialise vec1ParamGm: the "LD needed" flag is set to -1 (needFd = -1).
// vec1ResIn32Gm = [aic, 2, s1BaseSize_, 16] int32
// Workspace record layout: [needFd, s2AcSeq, s2Start, s2End, isS2End, bn2idx, s1Idx, ......]
LocalTensor<float> tmpfBuff = outQueue_.AllocTensor<float>();
// Number of int32 elements filled; matches the byte count of the int64 copy below.
Duplicate(tmpfBuff.template ReinterpretCast<int32_t>(), -1, 2 * (s1BaseSize_ / 2) * paramNum_ * 2);
SetWaitFlag<HardEvent::V_MTE3>(HardEvent::V_MTE3);
int64_t wsInfoOffset = (blockId_ / 2) * s1BaseSize_ * 2 * paramNum_ + // offset shared by the two vector cores of a pair
(blockId_ % 2) * (s1BaseSize_ / 2) * 2 * paramNum_; // per-core offset along S1
DataCopyPad(vec1ParamGm[wsInfoOffset], tmpfBuff.template ReinterpretCast<int64_t>(),
{1, static_cast<uint16_t>((s1BaseSize_ / 2) * 2 * paramNum_ * sizeof(int64_t)), 0, 0});
// Wait for the copy-out to finish before the staging tensor is released.
SetWaitFlag<HardEvent::MTE3_V>(HardEvent::MTE3_V);
outQueue_.FreeTensor(tmpfBuff);
}
// Resets the pipe and allocates the buffers used by the LD stage: a merge input buffer, a
// scratch buffer of the same size, and one value / one index output buffer of BASE_TOPK entries.
template <typename LIQT>
__aicore__ inline void LIQVector<LIQT>::InitLDBuffers(TPipe *pipe)
{
    // Drops all buffers allocated so far before the LD buffers are carved out.
    pipe->Reset();

    // mrgListNum_ lists of (value, index) pairs.
    const auto mergeBytes = BASE_TOPK_VALUE_IDX_SIZE * mrgListNum_ * sizeof(float);
    const auto outValueBytes = BASE_TOPK * sizeof(float);
    const auto outIndexBytes = BASE_TOPK * sizeof(int32_t);

    pipe->InitBuffer(ldToBeMrgBuf_, mergeBytes);
    pipe->InitBuffer(ldTmpBuf_, mergeBytes);
    pipe->InitBuffer(ldOutValueBuf_, outValueBytes);
    pipe->InitBuffer(ldOutIdxBuf_, outIndexBytes);
}
// Caches the constants the vector stage needs and records this core's block index.
// tilingData is accepted but not read here.
template <typename LIQT>
__aicore__ inline void LIQVector<LIQT>::InitParams(const struct LIQCommon::ConstInfo &constInfo,
                                                   const LIQTilingData *__restrict tilingData)
{
    constInfo_ = constInfo;

    // Core identity and running state.
    blockId_ = GetBlockIdx();
    blockS2StartIdx_ = 0;

    // Head configuration (N2 parameters).
    qHeadNum_ = constInfo.qHeadNum;
    kHeadNum_ = constInfo.kHeadNum;
    gSize_ = constInfo.gSize;
    kSeqSize_ = constInfo.kSeqSize;

    // Matmul base-block sizes.
    s1BaseSize_ = constInfo.s1BaseSize;  // 4
    s2BaseSize_ = constInfo.s2BaseSize;  // 2048

    // Paged key-cache geometry.
    kCacheBlockSize_ = constInfo.kCacheBlockSize;
    maxBlockNumPerBatch_ = constInfo.maxBlockNumPerBatch;
}
// Binds the kernel input/output tensors used by the vector stage.
// The parameters shadow the members of the same name, hence the explicit this->.
template <typename LIQT>
__aicore__ inline void LIQVector<LIQT>::InitVecInputTensor(GlobalTensor<half> weightsGm, GlobalTensor<half> qScaleGm,
                                                           GlobalTensor<half> kScaleGm,
                                                           GlobalTensor<int32_t> indiceOutGm,
                                                           GlobalTensor<int32_t> blockTableGm)
{
    // Scales and weights.
    this->kScaleGm = kScaleGm;
    this->qScaleGm = qScaleGm;
    this->weightsGm = weightsGm;
    // Block table and index output.
    this->blockTableGm = blockTableGm;
    this->indiceOutGm = indiceOutGm;
}
// Binds the GM workspace tensors used by the vector stage.
// The parameters shadow the members of the same name, hence the explicit this->.
template <typename LIQT>
__aicore__ inline void LIQVector<LIQT>::InitVecWorkspaceTensor(GlobalTensor<half> vec0OutGm,
                                                               GlobalTensor<MM1_OUT_T> mm1ResGm,
                                                               GlobalTensor<float> vec1ResGm,
                                                               GlobalTensor<int64_t> vec1ParamGm)
{
    // Written by ProcessVec0.
    this->vec0OutGm = vec0OutGm;
    // Matmul result workspace.
    this->mm1ResGm = mm1ResGm;
    // Vec1 result and parameter workspaces.
    this->vec1ResGm = vec1ResGm;
    this->vec1ParamGm = vec1ParamGm;
}
// Intentionally empty: the vector stage pre-sets no event flags here.
template <typename LIQT>
__aicore__ inline void LIQVector<LIQT>::AllocEventID()
{
}
// Intentionally empty: counterpart of AllocEventID, which sets no flags.
template <typename LIQT>
__aicore__ inline void LIQVector<LIQT>::FreeEventID()
{
}
// Fills constInfo_.sparseCount entries of the index output, starting at element offset
// invalidS1offset, with constInfo_.INVALID_IDX.
template <typename LIQT>
__aicore__ inline void LIQVector<LIQT>::CleanInvalidOutput(int64_t invalidS1offset)
{
// Fill a staging tensor with the invalid index and copy it to the output.
LocalTensor<float> valueULocal = outQueue_.AllocTensor<float>();
// The float tensor from the output queue is reinterpreted as int32 to hold indices.
LocalTensor<int32_t> idxULocal1 = valueULocal.template ReinterpretCast<int32_t>();
Duplicate(idxULocal1, constInfo_.INVALID_IDX, constInfo_.sparseCount);
// EnQue/DeQue round trip: orders the fill above before the copy-out below.
outQueue_.EnQue<float>(valueULocal);
valueULocal = outQueue_.DeQue<float>();
LIQServiceVec::CopyOut(indiceOutGm[invalidS1offset], idxULocal1, constInfo_.sparseCount);
outQueue_.FreeTensor(valueULocal);
}
// Pre-scales the weights of one s1 base block: multiplies weights by the query scale
// element-wise, expands the products with Brcb, and writes the result to the vec0OutGm workspace.
// Only cores with an even block id do this work.
template <typename LIQT>
__aicore__ inline void LIQVector<LIQT>::ProcessVec0(const LIQCommon::RunInfo &info)
{
// Only one of the two vector cores of a pair needs to do this.
if (blockId_ % 2 != 0) {
return;
}
int32_t cuBaseS1Idx = info.gS1Idx * s1BaseSize_;
// Output base offset: even loop -> 0, odd loop -> s1BaseSize_ * gSize_ * BLOCK_CUBE (e.g. 4 * 64 rows).
int64_t vec0OutGmOffset = (info.loop % 2) * ((s1BaseSize_ * gSize_ * BLOCK_CUBE));
// Input offset of the weights; the query scale uses the same offset.
int64_t weightGmOffset = info.tensorWeightsOffset + cuBaseS1Idx * qHeadNum_;
// Number of S1 rows to process in this block; the tail block may be shorter.
int32_t cuS1ProcNum = cuBaseS1Idx + s1BaseSize_ > info.actS1Size ? info.actS1Size % s1BaseSize_ : s1BaseSize_;
int32_t cuProcEleNum = cuS1ProcNum * gSize_;
// One input tensor holds both operands: weights first, query scale directly behind them.
LocalTensor<half> inWeightsUb = inQueue_.AllocTensor<half>();
LocalTensor<half> inQScaleUb = inWeightsUb[cuProcEleNum];
AscendC::DataCopyPadExtParams<half> padParams{false, 0, 0, 0};
AscendC::DataCopyExtParams copyInParams;
copyInParams.blockCount = 1;
copyInParams.blockLen = cuProcEleNum * sizeof(half);
copyInParams.srcStride = 0;
copyInParams.dstStride = 0;
copyInParams.rsv = 0;
AscendC::DataCopyPad(inWeightsUb, weightsGm[weightGmOffset], copyInParams, padParams);
AscendC::DataCopyPad(inQScaleUb, qScaleGm[weightGmOffset], copyInParams, padParams);
// EnQue/DeQue round trip: orders the copies above before the vector computation below.
inQueue_.EnQue<half>(inWeightsUb);
inWeightsUb = inQueue_.DeQue<half>();
// weights *= qScale (in place).
AscendC::Mul(inWeightsUb, inWeightsUb, inQScaleUb, cuProcEleNum);
PipeBarrier<PIPE_V>();
LocalTensor<half> resUb = outQueue_.AllocTensor<half>();
// NOTE(review): the repeat count is cuProcEleNum / 8 (integer division), which presumably
// assumes cuProcEleNum is a multiple of 8 - confirm.
AscendC::Brcb(resUb, inWeightsUb, static_cast<uint8_t>(cuProcEleNum / 8), {1, 8});
inQueue_.FreeTensor(inWeightsUb);
// EnQue/DeQue round trip: orders the computation above before the copy-out below.
outQueue_.EnQue<half>(resUb);
resUb = outQueue_.DeQue<half>();
AscendC::DataCopyParams copyOutParams;
copyOutParams.blockCount = 1;
// BLOCK_CUBE half values are written per input element.
copyOutParams.blockLen = cuProcEleNum * BLOCK_CUBE * sizeof(half);
copyOutParams.srcStride = 0;
copyOutParams.dstStride = 0;
AscendC::DataCopyPad(vec0OutGm[vec0OutGmOffset], resUb, copyOutParams);
outQueue_.FreeTensor(resUb);
}
/**
 * Vector stage for one (S1, S2) base block described by `info`.
 *
 * For every S1 row assigned to this AIV the function:
 *   1. copies cuS2Len floats of that row from mm1ResGm and fetches the key scales via GetKeyScale,
 *   2. casts the key scales from half to float and multiplies the row by them,
 *   3. sorts the scaled scores together with their S2 indices (SortAll) and merges them into the
 *      per-row running buffer globalTopkUb_ (MergeSort),
 *   4. either extracts the indices and writes them to indiceOutGm (needCopyOutGm), or writes the
 *      running buffer plus a 16-entry int64 parameter record to vec1ResGm / vec1ParamGm
 *      (needCopyWsGm) so that ProcessLD can merge it later.
 * Rows whose cuRealAcSeq is <= 0 are handed to CleanInvalidOutput instead.
 *
 * The S1 rows of the base block are split by blockId_ % 2: the even block takes the ceil half,
 * the odd block takes the floor half.
 *
 * info: run descriptor; this function reads gS1Idx, s2Idx, loop, actS1Size, actS2Size, bIdx, bN2Idx,
 *       indiceOutOffset, isAllLoopEnd and isLastS2InnerLoop.
 */
template <typename LIQT>
__aicore__ inline void LIQVector<LIQT>::ProcessVec1(const LIQCommon::RunInfo &info)
{
int32_t cuBaseS1Idx = info.gS1Idx * s1BaseSize_;
int32_t cuBaseS2Idx = info.s2Idx * s2BaseSize_;
// Base-address offset of this base block: even loop -> 0 + aic_offset, odd loop -> 4*2048 + aic_offset
int64_t mmGmOffset = (info.loop % 2) * (s1BaseSize_ * s2BaseSize_);
// cuS1BeginIdxPerAiv: S1 start offset of each AIV
int32_t cuS1BeginIdxPerAiv = cuBaseS1Idx;
// A base block that would run past actS1Size only processes the remainder rows.
int32_t cuS1ProcNum =
cuS1BeginIdxPerAiv + s1BaseSize_ > info.actS1Size ? info.actS1Size % s1BaseSize_ : s1BaseSize_;
// cuS1ProcNumPerAiv: number of S1 rows processed by each AIV (even block: ceil half, odd block: floor half)
int32_t cuS1ProcNumPerAiv = blockId_ % 2 == 0 ? CeilDiv(cuS1ProcNum, 2) : (cuS1ProcNum / 2);
cuS1BeginIdxPerAiv += (blockId_ % 2) * CeilDiv(cuS1ProcNum, 2);
// Odd cores add an S1 offset to the base-block address
mmGmOffset += (blockId_ % 2) * CeilDiv(cuS1ProcNum, 2) * s2BaseSize_;
// Not the first base block and the M (S1) axis has switched: re-initialize
if (info.loop != 0 && info.s2Idx == 0) {
// globalTopkUb_ value,index = -inf,-1
InitSortOutBuf(globalTopkUb_, CeilDiv(s1BaseSize_, 2) * BASE_TOPK_VALUE_IDX_SIZE);
blockS2StartIdx_ = 0;
} else if (info.loop == 0) {
blockS2StartIdx_ = info.s2Idx;
}
// cuRealAcSeq: the AcSeq (actual sequence length) that applies to the S1 rows of the current base block
int32_t cuRealAcSeq = info.actS2Size;
if (constInfo_.attenMaskFlag) {
// attenMask enabled
cuRealAcSeq = info.actS2Size - (info.actS1Size - cuS1BeginIdxPerAiv);
}
// S1-direction offset of the LD output, so that the output of the 2 Vector cores is contiguous
uint32_t ldS1Offset = (blockId_ % 2 == 0) ? s1BaseSize_ / 2 - cuS1ProcNumPerAiv : 0;
for (int innerS1Idx = 0; innerS1Idx < cuS1ProcNumPerAiv; innerS1Idx++) {
if (constInfo_.attenMaskFlag) {
// With attenMask, cuRealAcSeq grows by one for each successive S1 row.
cuRealAcSeq += 1;
}
// Number of S2 positions of this base block that lie inside cuRealAcSeq (may be <= 0).
int32_t cuS2Len = cuBaseS2Idx + s2BaseSize_ >= cuRealAcSeq ? cuRealAcSeq - cuBaseS2Idx : s2BaseSize_;
int32_t cuS1Idx = cuS1BeginIdxPerAiv + innerS1Idx;
if (cuRealAcSeq > 0 && cuS2Len > 0) {
// cuS2Len rounded up to a multiple of s2BaseSize_.
int32_t cuS2LenVecAlign = CeilDiv(cuS2Len, s2BaseSize_) * s2BaseSize_;
LocalTensor<float> mmInUb = inQueue_.AllocTensor<float>();
// kScaleUb is a view into the same allocation, cuS2LenVecAlign floats after mmInUb;
// kScaleTUb is a half view starting cuS2LenVecAlign half elements after kScaleUb.
LocalTensor<float> kScaleUb = mmInUb[cuS2LenVecAlign];
LocalTensor<half> kScaleTUb = kScaleUb.template ReinterpretCast<half>()[cuS2LenVecAlign];
AscendC::DataCopyPadExtParams<float> padParams{false, 0, 0, 0};
AscendC::DataCopyPadExtParams<half> padTParams{false, 0, 0, 0};
AscendC::DataCopyExtParams copyInParams;
copyInParams.blockCount = 1;
copyInParams.blockLen = cuS2Len * sizeof(float);
copyInParams.srcStride = 0;
copyInParams.dstStride = 0;
copyInParams.rsv = 0;
AscendC::DataCopyPad(mmInUb, mm1ResGm[mmGmOffset + innerS1Idx * s2BaseSize_], copyInParams, padParams);
GetKeyScale(info, kScaleTUb, info.bIdx, cuBaseS2Idx, cuS2Len);
inQueue_.EnQue<float>(mmInUb);
mmInUb = inQueue_.DeQue<float>();
// Key scales: half -> float, then scale the row in place.
AscendC::Cast(kScaleUb, kScaleTUb, RoundMode::CAST_NONE, cuS2Len);
PipeBarrier<PIPE_V>();
AscendC::Mul(mmInUb, mmInUb, kScaleUb, cuS2Len);
PipeBarrier<PIPE_V>();
// Sort workspace layout: [scores | indices | SortAll/MergeSort scratch], each part starting at a multiple of cuS2LenVecAlign.
LocalTensor<float> sortBuff = tmpBuf_.Get<float>();
LocalTensor<float> sortScoreUb = sortBuff;
LocalTensor<float> sortIndiceUb = sortBuff[cuS2LenVecAlign];
PipeBarrier<PIPE_V>();
// Pre-fill all cuS2LenVecAlign scores with NEG_INF, then overwrite the first cuS2Len with the scaled scores (Adds with 0.0f acts as a copy).
Duplicate(sortScoreUb.template ReinterpretCast<int32_t>(), LIQServiceVec::NEG_INF, cuS2LenVecAlign);
PipeBarrier<PIPE_V>();
Adds(sortScoreUb, mmInUb, 0.0f, cuS2Len);
PipeBarrier<PIPE_V>();
inQueue_.FreeTensor(mmInUb);
LocalTensor<int32_t> sortIndiceUbInt = sortIndiceUb.template ReinterpretCast<int32_t>();
// Fill the indices of invalid data with -1
if (cuS2LenVecAlign != cuS2Len) {
Duplicate(sortIndiceUbInt, -1, cuS2LenVecAlign);
PipeBarrier<PIPE_V>();
}
// Indices of the valid entries: globalTopkIndice_ shifted by the S2 start of this base block.
// NOTE(review): presumably globalTopkIndice_ holds 0..N-1 — confirm where it is initialized.
Adds(sortIndiceUbInt, globalTopkIndice_, static_cast<int32_t>(cuBaseS2Idx), cuS2Len);
PipeBarrier<PIPE_V>();
LocalTensor<float> tmpSortBuf = sortBuff[2 * cuS2LenVecAlign];
LIQServiceVec::SortAll(sortBuff, tmpSortBuf, cuS2LenVecAlign);
PipeBarrier<PIPE_V>();
// Merge the sorted block into this row's running list in globalTopkUb_.
LIQServiceVec::MergeSort(globalTopkUb_[innerS1Idx * BASE_TOPK_VALUE_IDX_SIZE], BASE_TOPK, sortBuff,
cuS2LenVecAlign, tmpSortBuf);
PipeBarrier<PIPE_V>();
bool isS2End = cuBaseS2Idx + s2BaseSize_ >= cuRealAcSeq;
// Final output is written directly only when this core started at S2 index 0 and reached the end of S2.
bool needCopyOutGm = blockS2StartIdx_ == 0 && isS2End;
// Save intermediate results
bool needCopyWsGm = info.isAllLoopEnd || isS2End;
if (needCopyOutGm) {
LocalTensor<uint32_t> idxULocal = outQueue_.AllocTensor<uint32_t>();
ExtractIndex(idxULocal,
globalTopkUb_[innerS1Idx * BASE_TOPK_VALUE_IDX_SIZE].template ReinterpretCast<uint32_t>(),
BASE_TOPK);
PipeBarrier<PIPE_V>();
// The row's running list has been consumed; reset it for the next use.
InitSortOutBuf(globalTopkUb_[innerS1Idx * BASE_TOPK_VALUE_IDX_SIZE], BASE_TOPK_VALUE_IDX_SIZE);
outQueue_.EnQue<uint32_t>(idxULocal);
idxULocal = outQueue_.DeQue<uint32_t>();
LIQServiceVec::CopyOut(indiceOutGm[info.indiceOutOffset + cuS1Idx * constInfo_.sparseCount],
idxULocal.template ReinterpretCast<int32_t>(), constInfo_.sparseCount);
outQueue_.FreeTensor(idxULocal);
} else if (needCopyWsGm) {
// vec1Res Gm = [aic, s1BaseSize_, 2, 2, topkOut_] float32
// vec1Param Gm = [aic, s1BaseSize_, 2, 16] int64
// 16 = [needFd, s2AcSeq, s2Start, s2End, isS2End, bn2idx, s1Idx, S1ProcNum, ......]
// (slot 8, written below, is the output offset into indiceOutGm.)
int64_t wsOffset =
(blockId_ / 2) * s1BaseSize_ * 2 * BASE_TOPK_VALUE_IDX_SIZE + // offset shared by the 2 AIVs
(blockId_ % 2) * (s1BaseSize_ / 2) * 2 * BASE_TOPK_VALUE_IDX_SIZE + // per-AIV offset along S1
(ldS1Offset + innerS1Idx) * 2 * BASE_TOPK_VALUE_IDX_SIZE;
int64_t wsInfoOffset =
(blockId_ / 2) * s1BaseSize_ * 2 * paramNum_ + // offset shared by the 2 AIVs
(blockId_ % 2) * (s1BaseSize_ / 2) * 2 * paramNum_ + // per-AIV offset along S1
(ldS1Offset + innerS1Idx) * 2 * paramNum_;
LocalTensor<int64_t> tmpiBuff = paramBuf_.Get<int64_t>();
// MTE3 -> scalar sync before the scalar SetValue writes into tmpiBuff.
SetWaitFlag<HardEvent::MTE3_S>(HardEvent::MTE3_S);
tmpiBuff.SetValue(0, static_cast<int64_t>(1));
tmpiBuff.SetValue(1, static_cast<int64_t>(cuRealAcSeq));
tmpiBuff.SetValue(2, static_cast<int64_t>(blockS2StartIdx_));
tmpiBuff.SetValue(3, static_cast<int64_t>(cuBaseS2Idx + cuS2Len));
tmpiBuff.SetValue(4, static_cast<int64_t>(isS2End));
tmpiBuff.SetValue(5, static_cast<int64_t>(info.bN2Idx));
tmpiBuff.SetValue(6, static_cast<int64_t>(cuS1Idx));
tmpiBuff.SetValue(7, static_cast<int64_t>(cuS1ProcNum));
tmpiBuff.SetValue(8, static_cast<int64_t>(info.indiceOutOffset + cuS1Idx * constInfo_.sparseCount));
// Decide whether this record is a head or a tail entry
// [head, tail]
// head: reduced with the preceding block, or with both the preceding and the following blocks
// tail: reduced with the following block
bool isTailReduce = blockS2StartIdx_ == 0; // necessarily isLastTile (per the original comment)
// Workspace offset rules:
//   blockS2StartIdx_ != 0: reduce with the preceding block, write to slot 0, no offset adjustment needed
//   blockS2StartIdx_ == 0 and !isS2End: reduce with the following block, write to slot 1,
//   which needs + s1BaseSize_, BASE_TOPK*2
if (isTailReduce) { // data that does not finish S2 must be reduced with the following block, so it goes into the second ws slot
wsInfoOffset += paramNum_;
wsOffset += BASE_TOPK_VALUE_IDX_SIZE;
}
// Scalar -> MTE3 sync so the SetValue writes are visible to the copy-out.
SetWaitFlag<HardEvent::S_MTE3>(HardEvent::S_MTE3);
LIQServiceVec::CopyOut(vec1ParamGm[wsInfoOffset], tmpiBuff, 16);
// Vector -> MTE3 sync before copying out the merged list.
SetWaitFlag<HardEvent::V_MTE3>(HardEvent::V_MTE3);
LIQServiceVec::CopyOut(vec1ResGm[wsOffset], globalTopkUb_[innerS1Idx * BASE_TOPK_VALUE_IDX_SIZE],
BASE_TOPK_VALUE_IDX_SIZE);
SetWaitFlag<HardEvent::MTE3_V>(HardEvent::MTE3_V);
}
} else if (cuRealAcSeq <= 0) {
CleanInvalidOutput(info.indiceOutOffset + cuS1Idx * constInfo_.sparseCount);
}
}
// BSND layout: invalid S1 rows output -1 (the original comment said "BNSD"; the code checks BSND)
if (Q_LAYOUT_T == LI_LAYOUT::BSND) {
// Last base block along S1: requires (cuBaseS1Idx + s1BaseSize_) >= info.actS1Size
bool isS1LoopEnd = (cuBaseS1Idx + s1BaseSize_) >= info.actS1Size;
int32_t invalidS1Num = constInfo_.qSeqSize - info.actS1Size;
// blockS2StartIdx_ == 0 makes the core that starts S2 from the beginning do the redundant-row cleanup
if (invalidS1Num > 0 && isS1LoopEnd && blockS2StartIdx_ == 0) {
// Rows [actS1Size, qSeqSize) are split between the two AIVs (ceil half / floor half).
int32_t s1NumPerAiv = blockId_ % 2 == 0 ? CeilDiv(invalidS1Num, 2) : (invalidS1Num / 2);
int32_t s1OffsetPerAiv = info.actS1Size + (blockId_ % 2) * CeilDiv(invalidS1Num, 2);
for (int innerS1Idx = 0; innerS1Idx < s1NumPerAiv; innerS1Idx++) {
CleanInvalidOutput(info.indiceOutOffset + (s1OffsetPerAiv + innerS1Idx) * constInfo_.sparseCount);
}
}
// With attenMask and actS1Size > actS2Size, the first (actS1Size - actS2Size) rows are cleaned as well.
int32_t invalidS1Num2 = info.actS1Size - info.actS2Size;
if (invalidS1Num2 > 0 && isS1LoopEnd && blockS2StartIdx_ == 0 && constInfo_.attenMaskFlag) {
int32_t s1NumPerAiv = blockId_ % 2 == 0 ? CeilDiv(invalidS1Num2, 2) : (invalidS1Num2 / 2);
int32_t s1OffsetPerAiv = (blockId_ % 2) * CeilDiv(invalidS1Num2, 2);
for (int innerS1Idx = 0; innerS1Idx < s1NumPerAiv; innerS1Idx++) {
CleanInvalidOutput((info.bN2Idx * constInfo_.qSeqSize + s1OffsetPerAiv + innerS1Idx) *
constInfo_.sparseCount);
}
}
}
if (info.isLastS2InnerLoop) {
// After the last S2 loop, the next base block starts from 0
blockS2StartIdx_ = 0;
}
}
/**
 * Cross-core merge step over the partial results written by ProcessVec1.
 *
 * For this core pair (blockId_ / 2) the function first scans the tail slot of every S1 row in
 * vec1ParamGm and counts the rows flagged with needFd == 1. Those rows are split between the two
 * AIVs (blockId_ % 2). For each assigned row it:
 *   1. loads the row's tail partial list from vec1ResGm,
 *   2. walks the following cores (tmpCubeId++), appending their head partial lists while their
 *      needFd flag is 1, merging with MrgSort whenever mrgListNum_ lists are buffered, and
 *      stopping after the entry whose isS2End flag is 1,
 *   3. merges any remaining 2 or 3 lists, splits the result with Extract and writes
 *      constInfo_.sparseCount indices to indiceOutGm at the outOffset read from the next core's record.
 *
 * NOTE(review): s2ActSeq, s2Start, s2End, bn2Idx, bIdx and nextneedFd are declared but not used in
 * this function, and s1Idx is assigned but never read.
 */
template <typename LIQT>
__aicore__ inline void LIQVector<LIQT>::ProcessLD()
{
int32_t curCubeId = blockId_ / 2;
int32_t tmpCubeId = curCubeId;
int64_t s2ActSeq;
int64_t s2Start;
int64_t s2End;
int64_t isS2End;
int64_t bn2Idx;
int64_t s1Idx;
uint32_t acc_list_num = 0;
int64_t bIdx = 0;
int64_t needFd;
int64_t wsOffset;
int64_t wsInfoOffset = 0;
int64_t nextneedFd;
int64_t valueOffset = 0;
int64_t outOffset = 0;
LocalTensor<float> curValueIdxUb = ldToBeMrgBuf_.Get<float>();
LocalTensor<float> tmpUb = ldTmpBuf_.Get<float>();
// S2 head information
// The start never has a head reduce, so processing begins with the tail reduce; the while loop reads the next core's head reduce
// Merge once 4 lists are buffered or the end of S2 is reached, until S2 is finished
// Every core ignores its own head reduce, because an earlier core must already have handled it
uint32_t s1LdStartIdx = 0;
uint32_t s1ProcNum = 0;
uint64_t paramGmCoreOffset = tmpCubeId * s1BaseSize_ * 2 * paramNum_;
// Scan the tail slot (+ paramNum_) of every S1 row: count flagged rows and remember the first one.
for (uint32_t innerS1Idx = 0; innerS1Idx < s1BaseSize_; innerS1Idx++) {
needFd = vec1ParamGm.GetValue(paramGmCoreOffset + innerS1Idx * 2 * paramNum_ + paramNum_);
if (needFd == 1) {
s1LdStartIdx = (s1ProcNum == 0) ? innerS1Idx : s1LdStartIdx;
s1ProcNum++;
}
}
if (s1ProcNum == 0) {
return;
}
// Process S1 row by row
// The even AIV takes the first ceil(s1ProcNum / 2) rows, the odd AIV takes the rest.
uint32_t s1VecNum = CeilDiv(s1ProcNum, 2);
if (blockId_ % 2 == 1) {
s1LdStartIdx = s1LdStartIdx + s1VecNum;
s1VecNum = s1ProcNum - s1VecNum;
}
for (uint32_t innerS1Idx = s1LdStartIdx; innerS1Idx < s1LdStartIdx + s1VecNum; innerS1Idx++) {
// Reset offsets
tmpCubeId = curCubeId;
acc_list_num = 0;
valueOffset = 0;
// Load data (this core's tail slot for the row)
wsOffset = tmpCubeId * s1BaseSize_ * 2 * BASE_TOPK_VALUE_IDX_SIZE + // offset shared by the 2 AIVs
innerS1Idx * 2 * BASE_TOPK_VALUE_IDX_SIZE + BASE_TOPK_VALUE_IDX_SIZE;
SetWaitFlag<HardEvent::V_MTE2>(HardEvent::V_MTE2);
SetWaitFlag<HardEvent::S_MTE2>(HardEvent::S_MTE2);
DataCopyPad(curValueIdxUb, vec1ResGm[wsOffset],
{1, static_cast<uint16_t>(BASE_TOPK_VALUE_IDX_SIZE * sizeof(int32_t)), 0, 0}, {true, 0, 0, 0});
acc_list_num++;
valueOffset += BASE_TOPK_VALUE_IDX_SIZE;
// Fetch the next core's reduce info (head slot of the same row)
tmpCubeId++;
wsInfoOffset = tmpCubeId * s1BaseSize_ * 2 * paramNum_ + innerS1Idx * 2 * paramNum_;
needFd = vec1ParamGm.GetValue(wsInfoOffset);
isS2End = vec1ParamGm.GetValue(wsInfoOffset + 4);
s1Idx = vec1ParamGm.GetValue(wsInfoOffset + 6);
outOffset = vec1ParamGm.GetValue(wsInfoOffset + 8);
while (needFd == 1) {
// Load the head-reduce data
wsOffset = tmpCubeId * s1BaseSize_ * 2 * BASE_TOPK_VALUE_IDX_SIZE + // offset shared by the 2 AIVs
innerS1Idx * 2 * BASE_TOPK_VALUE_IDX_SIZE;
SetWaitFlag<HardEvent::V_MTE2>(HardEvent::V_MTE2);
SetWaitFlag<HardEvent::S_MTE2>(HardEvent::S_MTE2);
DataCopyPad(curValueIdxUb[valueOffset], vec1ResGm[wsOffset],
{1, static_cast<uint16_t>(BASE_TOPK_VALUE_IDX_SIZE * sizeof(int32_t)), 0, 0}, {true, 0, 0, 0});
valueOffset += BASE_TOPK_VALUE_IDX_SIZE;
acc_list_num++;
// Merge every time 4 lists are buffered; the first 2K then holds the merge result
// NOTE(review): the merge below always uses four source lists, so this assumes mrgListNum_ == 4 — confirm.
if (acc_list_num == mrgListNum_) {
// MrgSort: merge four queues of 2048 into one
AscendC::MrgSort4Info params;
params.elementLengths[0] = BASE_TOPK;
params.elementLengths[1] = BASE_TOPK;
params.elementLengths[2] = BASE_TOPK;
params.elementLengths[3] = BASE_TOPK;
params.ifExhaustedSuspension = true;
params.validBit = 0b1111;
params.repeatTimes = 1;
AscendC::MrgSortSrcList<float> srcList;
srcList.src1 = curValueIdxUb[0];
srcList.src2 = curValueIdxUb[BASE_TOPK_VALUE_IDX_SIZE];
srcList.src3 = curValueIdxUb[2 * BASE_TOPK_VALUE_IDX_SIZE];
srcList.src4 = curValueIdxUb[3 * BASE_TOPK_VALUE_IDX_SIZE];
SetWaitFlag<HardEvent::MTE2_V>(HardEvent::MTE2_V);
MrgSort(tmpUb, srcList, params);
PipeBarrier<PIPE_V>();
// Keep the merged list as list 0 and continue accumulating after it.
DataCopy(curValueIdxUb, tmpUb, BASE_TOPK_VALUE_IDX_SIZE);
PipeBarrier<PIPE_V>();
acc_list_num = 1;
valueOffset = BASE_TOPK_VALUE_IDX_SIZE;
}
// Break once the reduce has reached the end of S2
if (isS2End == 1) {
break;
}
tmpCubeId++;
wsInfoOffset = tmpCubeId * s1BaseSize_ * 2 * paramNum_ + innerS1Idx * 2 * paramNum_;
needFd = vec1ParamGm.GetValue(wsInfoOffset);
isS2End = vec1ParamGm.GetValue(wsInfoOffset + 4);
}
// Merge the remaining data when fewer than 4 lists are buffered
if (acc_list_num != 1) {
AscendC::MrgSort4Info params;
params.elementLengths[0] = BASE_TOPK;
params.elementLengths[1] = BASE_TOPK;
params.elementLengths[2] = BASE_TOPK;
params.elementLengths[3] = BASE_TOPK;
params.ifExhaustedSuspension = true;
// validBit selects which of the four source lists take part (2 or 3 lists here).
if (acc_list_num == 2) {
params.validBit = 0b0011;
} else if (acc_list_num == 3) {
params.validBit = 0b0111;
}
params.repeatTimes = 1;
AscendC::MrgSortSrcList<float> srcList;
srcList.src1 = curValueIdxUb[0];
srcList.src2 = curValueIdxUb[BASE_TOPK_VALUE_IDX_SIZE];
srcList.src3 = curValueIdxUb[2 * BASE_TOPK_VALUE_IDX_SIZE];
srcList.src4 = curValueIdxUb[3 * BASE_TOPK_VALUE_IDX_SIZE];
SetWaitFlag<HardEvent::MTE2_V>(HardEvent::MTE2_V);
MrgSort(tmpUb, srcList, params);
PipeBarrier<PIPE_V>();
DataCopy(curValueIdxUb, tmpUb, BASE_TOPK_VALUE_IDX_SIZE);
PipeBarrier<PIPE_V>();
}
// Copy out
LocalTensor<float> outValueUb = ldOutValueBuf_.Get<float>();
LocalTensor<uint32_t> outIdxUb = ldOutIdxBuf_.Get<uint32_t>();
// Split the merged (value, index) list into separate value and index tensors.
Extract(outValueUb, outIdxUb, curValueIdxUb, (BASE_TOPK / 32));
LocalTensor<int32_t> idxULocal1 = outIdxUb.template ReinterpretCast<int32_t>();
SetWaitFlag<HardEvent::V_MTE3>(HardEvent::V_MTE3);
SetWaitFlag<HardEvent::S_MTE3>(HardEvent::S_MTE3);
DataCopyPad(indiceOutGm[outOffset], idxULocal1,
{1, static_cast<uint16_t>(constInfo_.sparseCount * sizeof(int32_t)), 0, 0});
SetWaitFlag<HardEvent::MTE3_V>(HardEvent::MTE3_V);
}
}
} // namespace LIQKernel
#endif

View File

@@ -0,0 +1,53 @@
/**
* This program is free software, you can redistribute it and/or modify it.
* Copyright (c) 2025 Huawei Technologies Co., Ltd.
* This file is a part of the CANN Open Software.
* Licensed under CANN Open Software License Agreement Version 2.0 (the "License").
* Please refer to the License for details. You may not use this file except in compliance with the License.
* THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE.
* See LICENSE in the root of the software repository for the full text of the License.
*/
/*!
* \file lightning_indexer_quant_template_tiling_key.h
* \brief
*/
#ifndef TEMPLATE_TILING_KEY_LI_H_
#define TEMPLATE_TILING_KEY_LI_H_
#include "ascendc/host_api/tiling/template_argument.h"
// Data-type codes used in the DTYPE declarations below.
// NOTE(review): the values presumably mirror the framework's data-type enum — confirm.
#define LI_TPL_FP16 1
#define LI_TPL_IN8 2
#define LI_TPL_INT32 3
#define LI_TPL_BF16 27
// Layout codes used for Q_LAYOUT_T / K_LAYOUT_T.
#define LIQ_LAYOUT_BSND 0
#define LIQ_LAYOUT_TND 1
#define LIQ_LAYOUT_PA_BSND 2
// Bit width passed to the UINT declarations below.
#define ASCENDC_TPL_4_BW 4
// Declares the supported value range of each template argument
ASCENDC_TPL_ARGS_DECL(LightningIndexerQuant, // operator OpType
ASCENDC_TPL_DTYPE_DECL(DT_Q, LI_TPL_IN8), ASCENDC_TPL_DTYPE_DECL(DT_K, LI_TPL_IN8),
ASCENDC_TPL_DTYPE_DECL(DT_OUT, LI_TPL_INT32), ASCENDC_TPL_BOOL_DECL(PAGE_ATTENTION, 1, 0),
ASCENDC_TPL_UINT_DECL(Q_LAYOUT_T, ASCENDC_TPL_4_BW, ASCENDC_TPL_UI_LIST, LIQ_LAYOUT_BSND,
LIQ_LAYOUT_TND),
ASCENDC_TPL_UINT_DECL(K_LAYOUT_T, ASCENDC_TPL_4_BW, ASCENDC_TPL_UI_LIST,
LIQ_LAYOUT_PA_BSND, LIQ_LAYOUT_BSND, LIQ_LAYOUT_TND), );
// Supported combinations of template arguments
// Used by the interface to check that the TilingKey is valid when GET_TPL_TILING_KEY is called to obtain it
// First selection: PAGE_ATTENTION = 1 with key layout PA_BSND; second: PAGE_ATTENTION = 0 with key layout BSND or TND.
ASCENDC_TPL_SEL(
ASCENDC_TPL_ARGS_SEL(ASCENDC_TPL_DTYPE_SEL(DT_Q, LI_TPL_IN8), ASCENDC_TPL_DTYPE_SEL(DT_K, LI_TPL_IN8),
ASCENDC_TPL_DTYPE_SEL(DT_OUT, LI_TPL_INT32), ASCENDC_TPL_BOOL_SEL(PAGE_ATTENTION, 1),
ASCENDC_TPL_UINT_SEL(Q_LAYOUT_T, ASCENDC_TPL_UI_LIST, LIQ_LAYOUT_BSND, LIQ_LAYOUT_TND),
ASCENDC_TPL_UINT_SEL(K_LAYOUT_T, ASCENDC_TPL_UI_LIST, LIQ_LAYOUT_PA_BSND), ),
ASCENDC_TPL_ARGS_SEL(ASCENDC_TPL_DTYPE_SEL(DT_Q, LI_TPL_IN8), ASCENDC_TPL_DTYPE_SEL(DT_K, LI_TPL_IN8),
ASCENDC_TPL_DTYPE_SEL(DT_OUT, LI_TPL_INT32), ASCENDC_TPL_BOOL_SEL(PAGE_ATTENTION, 0),
ASCENDC_TPL_UINT_SEL(Q_LAYOUT_T, ASCENDC_TPL_UI_LIST, LIQ_LAYOUT_BSND, LIQ_LAYOUT_TND),
ASCENDC_TPL_UINT_SEL(K_LAYOUT_T, ASCENDC_TPL_UI_LIST, LIQ_LAYOUT_BSND, LIQ_LAYOUT_TND), ), );
#endif

View File

@@ -0,0 +1,193 @@
/**
* This program is free software, you can redistribute it and/or modify it.
* Copyright (c) 2025 Huawei Technologies Co., Ltd.
* This file is a part of the CANN Open Software.
* Licensed under CANN Open Software License Agreement Version 2.0 (the "License").
* Please refer to the License for details. You may not use this file except in compliance with the License.
* THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE.
* See LICENSE in the root of the software repository for the full text of the License.
*/
/*!
* \file lightning_indexer_quant_vector.h
* \brief
*/
#ifndef LIGHTNING_INDEXER_QUANT_VECTOR_H
#define LIGHTNING_INDEXER_QUANT_VECTOR_H
#include "kernel_operator.h"
#include "lightning_indexer_quant_vector.h"
namespace LIQServiceVec {
using namespace AscendC;
// Bit pattern of float32 negative infinity; written through int32 views of float buffers.
constexpr int32_t NEG_INF = 0xFF800000;
// Index value used to mark an invalid entry.
constexpr int32_t INVALID_INDEX = -1;
// Largest repeat count issued per vector instruction in InitSortOutBuf.
constexpr uint8_t VEC_REPEAT_MAX = 255;
// Number of 32-bit elements covered by one repeat (InitSortOutBuf divides eleNum by this).
constexpr uint8_t B32_VEC_ELM_NUM = 64;
constexpr uint8_t B32_BLOCK_ALIGN_NUM = 8;
// Repeat stride passed to Duplicate / GatherMask for 32-bit data.
constexpr uint8_t B32_VEC_REPEAT_STRIDE = 8;
// Bytes per repeat; ExtractIndex uses it to derive the GatherMask repeat count.
constexpr uint64_t VEC_REPEAT_BYTES = 256;
constexpr int32_t CONST_TWO = 2;
// Each sorted entry is a (value, index) pair, i.e. two 32-bit elements.
constexpr int64_t VALUE_AND_INDEX_NUM = 2;
// NOTE: SortAll uses this value (32) as the number of elements per Sort32 group, despite the name.
constexpr int64_t BLOCK_BYTES = 32;
// Positions of the four MrgSort source queues.
constexpr int64_t MRG_QUE_0 = 0;
constexpr int64_t MRG_QUE_1 = 1;
constexpr int64_t MRG_QUE_2 = 2;
constexpr int64_t MRG_QUE_3 = 3;
// Number of groups taking part in a final merge pass in SortAll.
constexpr int64_t MRG_BLOCK_2 = 2;
constexpr int64_t MRG_BLOCK_3 = 3;
constexpr int64_t MRG_BLOCK_4 = 4;
/**
 * Copies copyCount elements of type T from a local tensor to global memory as one block.
 * dstGm: destination global tensor
 * srcUb: source local tensor
 * copyCount: number of elements to copy; the block length is copyCount * sizeof(T)
 */
template <typename T>
__aicore__ inline void CopyOut(const GlobalTensor<T> &dstGm, const LocalTensor<T> &srcUb, int64_t copyCount)
{
    AscendC::DataCopyParams outParams;
    // No gap between blocks on either side; there is only a single block anyway.
    outParams.dstStride = 0;
    outParams.srcStride = 0;
    outParams.blockLen = copyCount * sizeof(T);
    outParams.blockCount = 1;
    AscendC::DataCopyPad(dstGm, srcUb, outParams);
}
/**
 * Initializes a sort-output buffer with interleaved (-inf, -1) pairs.
 * src: buffer to initialize
 * eleNum: number of 32-bit elements to initialize; must be a multiple of 64. Even element
 *         positions receive NEG_INF (the value slot) and odd positions receive INVALID_INDEX
 *         (the index slot).
 *
 * The work is issued in batches of at most VEC_REPEAT_MAX repeats. Every batch writes to its own
 * region of the buffer: the destination offset advances by VEC_REPEAT_MAX * B32_VEC_ELM_NUM
 * elements per full batch, so buffers needing two or more full batches are covered completely.
 */
__aicore__ inline void InitSortOutBuf(const LocalTensor<float> &src, int64_t eleNum)
{
    // 0x55.. selects the even elements of a repeat, 0xaa.. selects the odd ones.
    uint64_t mask1[2] = {0x5555555555555555, 0};
    uint64_t mask0[2] = {0xaaaaaaaaaaaaaaaa, 0};
    LocalTensor<int32_t> dst = src.template ReinterpretCast<int32_t>();
    int64_t repeatNum = eleNum / B32_VEC_ELM_NUM;
    int64_t forLoop = repeatNum / VEC_REPEAT_MAX;
    int64_t forRemain = repeatNum % VEC_REPEAT_MAX;
    // Element offset of the batch currently being written.
    int64_t batchOffset = 0;
    for (int64_t i = 0; i < forLoop; i++) {
        AscendC::Duplicate(dst[batchOffset], NEG_INF, mask1, VEC_REPEAT_MAX, 1, B32_VEC_REPEAT_STRIDE);
        AscendC::Duplicate(dst[batchOffset], INVALID_INDEX, mask0, VEC_REPEAT_MAX, 1, B32_VEC_REPEAT_STRIDE);
        batchOffset += VEC_REPEAT_MAX * B32_VEC_ELM_NUM;
    }
    if (forRemain > 0) {
        AscendC::Duplicate(dst[batchOffset], NEG_INF, mask1, forRemain, 1, B32_VEC_REPEAT_STRIDE);
        AscendC::Duplicate(dst[batchOffset], INVALID_INDEX, mask0, forRemain, 1, B32_VEC_REPEAT_STRIDE);
    }
    AscendC::PipeBarrier<PIPE_V>();
}
/**
 * Sorts logitsNum (value, index) entries.
 * src: on entry, the first logitsNum elements are the logits and the next logitsNum elements are
 *      their indices; on return it holds the sorted interleaved (value, index) list.
 * tmp: scratch space used during the computation, same size as src.
 * logitsNum: number of elements to sort; currently only [128, 256, 384, 512, 1024, 2048] are supported.
 *
 * Sort32 first sorts groups of 32 entries from src into tmp. MrgSort passes then merge up to four
 * groups at a time, ping-ponging between tmp and src. If the final result ends up in tmp it is
 * copied back to src. A vector barrier is issued after the last merge so that this copy never
 * reads tmp before the merge has written it.
 */
__aicore__ inline void SortAll(LocalTensor<float> &src, LocalTensor<float> &tmp, int64_t logitsNum)
{
    // BLOCK_BYTES (32) doubles as the Sort32 group size in elements.
    int64_t sort32Repeats = logitsNum / BLOCK_BYTES;
    AscendC::Sort32(tmp, src, src[logitsNum].ReinterpretCast<uint32_t>(), sort32Repeats);
    AscendC::PipeBarrier<PIPE_V>();
    int64_t mrgGroups = sort32Repeats;
    int64_t mrgElements = BLOCK_BYTES;
    // Number of merge passes issued so far; its parity tells which buffer holds the latest result.
    int64_t i = 0;
    AscendC::LocalTensor<float> srcTensor;
    AscendC::LocalTensor<float> dstTensor;
    while (true) {
        if (i % CONST_TWO == 0) {
            srcTensor = tmp;
            dstTensor = src;
        } else {
            srcTensor = src;
            dstTensor = tmp;
        }
        AscendC::MrgSort4Info params;
        params.elementLengths[0] = mrgElements;
        params.elementLengths[MRG_QUE_1] = mrgElements;
        params.elementLengths[MRG_QUE_2] = mrgElements;
        params.elementLengths[MRG_QUE_3] = mrgElements;
        params.ifExhaustedSuspension = false;
        params.validBit = 0b1111;
        AscendC::MrgSortSrcList<float> srcList;
        srcList.src1 = srcTensor[0];
        srcList.src2 = srcTensor[MRG_QUE_1 * VALUE_AND_INDEX_NUM * mrgElements];
        srcList.src3 = srcTensor[MRG_QUE_2 * VALUE_AND_INDEX_NUM * mrgElements];
        srcList.src4 = srcTensor[MRG_QUE_3 * VALUE_AND_INDEX_NUM * mrgElements];
        if (mrgGroups <= MRG_BLOCK_4) {
            // Final pass: at most four groups remain.
            params.repeatTimes = 1;
            if (mrgGroups == 1) {
                break;
            } else if (mrgGroups == MRG_BLOCK_2) {
                params.validBit = 0b0011;
            } else if (mrgGroups == MRG_BLOCK_3) {
                params.validBit = 0b0111;
            } else if (mrgGroups == MRG_BLOCK_4) {
                params.validBit = 0b1111;
            }
            AscendC::MrgSort<float>(dstTensor, srcList, params);
            // The copy-back below may read what this merge wrote; order the two vector operations.
            AscendC::PipeBarrier<PIPE_V>();
            i += 1;
            break;
        } else {
            // Merge every four neighbouring groups into one group that is four times as long.
            params.repeatTimes = mrgGroups / MRG_BLOCK_4;
            AscendC::MrgSort<float>(dstTensor, srcList, params);
            i += 1;
            mrgElements = mrgElements * MRG_BLOCK_4;
            mrgGroups = mrgGroups / MRG_BLOCK_4;
        }
        AscendC::PipeBarrier<PIPE_V>();
    }
    // An even pass count means the latest result lives in tmp; move it back into src.
    if (i % CONST_TWO == 0) {
        AscendC::DataCopy(src, tmp, logitsNum * VALUE_AND_INDEX_NUM);
        AscendC::PipeBarrier<PIPE_V>();
    }
}
/**
 * Merges one sorted list into another and keeps the leading mrgDstNum entries.
 * mrgDst: tensor that is merged into; receives the first mrgDstNum (value, index) pairs of the merge
 * mrgDstNum: number of entries in mrgDst
 * mrgSrc: tensor to be merged in
 * mrgSrcNum: number of entries in mrgSrc
 * tmpTensor: scratch space, sized for mrgDst + mrgSrc
 */
__aicore__ inline void MergeSort(const LocalTensor<float> &mrgDst, int32_t mrgDstNum, LocalTensor<float> &mrgSrc,
                                 int32_t mrgSrcNum, LocalTensor<float> &tmpTensor)
{
    // Queue 1 is the incoming list, queue 2 is the running list.
    AscendC::MrgSortSrcList<float> queues;
    queues.src1 = mrgSrc;
    queues.src2 = mrgDst;

    AscendC::MrgSort4Info mergeInfo;
    mergeInfo.repeatTimes = 1;
    mergeInfo.validBit = 0b0011;  // only the first two queues take part
    mergeInfo.ifExhaustedSuspension = false;
    mergeInfo.elementLengths[0] = mrgSrcNum;
    mergeInfo.elementLengths[1] = mrgDstNum;

    AscendC::MrgSort<float>(tmpTensor, queues, mergeInfo);
    AscendC::PipeBarrier<PIPE_V>();
    // Write the leading mrgDstNum pairs of the merged result back into the running list.
    AscendC::DataCopy(mrgDst, tmpTensor, mrgDstNum * VALUE_AND_INDEX_NUM);
    AscendC::PipeBarrier<PIPE_V>();
}
/**
 * Gathers the index half of extractNum interleaved (value, index) pairs.
 * idxULocal: destination tensor for the gathered elements
 * sortLocal: source tensor holding the interleaved pairs, viewed as uint32
 * extractNum: number of pairs to process
 */
__aicore__ inline void ExtractIndex(const LocalTensor<uint32_t> &idxULocal, const LocalTensor<uint32_t> &sortLocal,
                                    int64_t extractNum)
{
    uint64_t keptCount = 0;        // receives the number of elements kept by the gather
    uint8_t oddSlotPattern = 2;    // fixed pattern 2: selects the elements at odd positions

    AscendC::GatherMaskParams gatherCfg;
    gatherCfg.src1RepeatStride = 0;
    gatherCfg.src0RepeatStride = B32_VEC_REPEAT_STRIDE;
    gatherCfg.src0BlockStride = 1;
    // One repeat covers VEC_REPEAT_BYTES bytes of the interleaved source.
    gatherCfg.repeatTimes = Ceil(extractNum * sizeof(float) * VALUE_AND_INDEX_NUM, VEC_REPEAT_BYTES);

    AscendC::GatherMask(idxULocal, sortLocal, oddSlotPattern, false, static_cast<uint32_t>(0), gatherCfg, keptCount);
    AscendC::PipeBarrier<PIPE_V>();
}
/**
 * Sets the hardware event `event` and immediately waits on it, using an event id fetched from
 * the current TPipe for `evt`.
 * evt: the hardware event used to fetch the event id (callers pass the same value as `event`)
 */
template <HardEvent event>
__aicore__ inline void SetWaitFlag(HardEvent evt)
{
    auto *pipe = GetTPipePtr();
    event_t syncId = static_cast<event_t>(pipe->FetchEventID(evt));
    AscendC::SetFlag<event>(syncId);
    AscendC::WaitFlag<event>(syncId);
}
} // namespace LIQServiceVec
#endif // LIGHTNING_INDEXER_QUANT_VECTOR_H

View File

@@ -42,6 +42,7 @@
#include "moe_gating_top_k/moe_gating_top_k_torch_adpt.h"
#include "moe_init_routing_custom/moe_init_routing_custom_torch_adpt.h"
#include "sparse_flash_attention/sparse_flash_attention_torch_adpt.h"
#include "lightning_indexer_quant/lightning_indexer_quant_torch_adpt.h"
#include <c10/core/Device.h>
#include <c10/util/Exception.h>
#include <c10/util/Logging.h>
@@ -918,4 +919,16 @@ TORCH_LIBRARY_EXPAND(CONCAT(_C, _ascend), ops)
"-> Tensor[]"
);
ops.impl("moe_grouped_matmul", torch::kPrivateUse1,&vllm_ascend::moe_grouped_matmul);
// This operator is planned to be integrated into PTA in the near future.
// Once that happens, the implementation in csrc will be removed.
ops.def(
"npu_lightning_indexer_quant(Tensor query, Tensor key, Tensor weights, Tensor query_dequant_scale, "
" Tensor key_dequant_scale, *, Tensor? actual_seq_lengths_query=None, "
" Tensor? actual_seq_lengths_key=None, Tensor? block_table=None, "
" int query_quant_mode=0, int key_quant_mode=0, "
" str layout_query='BSND', str layout_key='BSND',"
" int sparse_count=2048, int sparse_mode=3) -> Tensor"
);
ops.impl("npu_lightning_indexer_quant", torch::kPrivateUse1, &vllm_ascend::npu_lightning_indexer_quant);
}

View File

@@ -529,6 +529,44 @@ std::vector<at::Tensor> moe_grouped_matmul_meta(
return y;
}
// Meta (shape-only) implementation of npu_lightning_indexer_quant.
// Validates that every dimension of query and key is positive and that sparse_count > 0, then
// returns an uninitialized int32 tensor shaped
//   [query.size(0), query.size(1), key_heads, sparse_count]  when layout_query is "BSND",
//   [query.size(0), key_heads, sparse_count]                 otherwise,
// where key_heads is key.size(1) for layout_key "TND" and key.size(2) for any other key layout.
// The remaining arguments are accepted for signature parity and do not affect the output shape.
at::Tensor npu_lightning_indexer_quant_meta(
    const at::Tensor &query, const at::Tensor &key, const at::Tensor &weights,
    const at::Tensor &query_dequant_scale, const at::Tensor &key_dequant_scale,
    const c10::optional<at::Tensor> &actual_seq_lengths_query,
    const c10::optional<at::Tensor> &actual_seq_lengths_key,
    const c10::optional<at::Tensor> &block_table, int64_t query_quant_mode, int64_t key_quant_mode,
    c10::string_view layout_query, c10::string_view layout_key, int64_t sparse_count, int64_t sparse_mode)
{
    constexpr int kShapeCapacity = 8;
    const bool query_is_bsnd = (std::string(layout_query) == "BSND");
    const bool key_is_tnd = (std::string(layout_key) == "TND");

    for (int64_t dim = 0; dim < query.dim(); ++dim) {
        TORCH_CHECK(query.size(dim) > 0, "All values within query's shape should be greater "
            "than 0, but shape[", dim, "] is ", query.size(dim));
    }
    for (int64_t dim = 0; dim < key.dim(); ++dim) {
        TORCH_CHECK(key.size(dim) > 0, "All values within key's shape should be greater "
            "than 0, but shape[", dim, "] is ", key.size(dim));
    }
    TORCH_CHECK(sparse_count > 0, "sparse count should be greater than 0, but now is ", sparse_count);

    // The head axis of key is dim 1 for TND and dim 2 for the other key layouts.
    const int64_t key_heads = key_is_tnd ? key.size(1) : key.size(2);

    at::SmallVector<int64_t, kShapeCapacity> out_shape;
    out_shape.push_back(query.size(0));
    if (query_is_bsnd) {
        out_shape.push_back(query.size(1));
    }
    out_shape.push_back(key_heads);
    out_shape.push_back(sparse_count);

    return at::empty(out_shape, query.options().dtype(at::kInt));
}
} // namespace meta
} // namespace vllm_ascend
@@ -576,5 +614,7 @@ TORCH_LIBRARY_IMPL_EXPAND(CONCAT(_C, _ascend), Meta, ops) {
ops.impl("causal_conv1d_fn", &vllm_ascend::meta::causal_conv1d_fn_meta);
// moe_grouped_matmul
ops.impl("moe_grouped_matmul", &vllm_ascend::meta::moe_grouped_matmul_meta);
// Lightning indexer quant
ops.impl("npu_lightning_indexer_quant", &vllm_ascend::meta::npu_lightning_indexer_quant_meta);
}
}

View File

@@ -266,6 +266,33 @@ def test_deepseek3_2_w8a8_pruning_mtp_tp2_ep():
vllm_model.generate_greedy(long_example_prompts, max_tokens)
@patch.dict(os.environ, {"HCCL_OP_EXPANSION_MODE": "AIV"})
@patch.dict(os.environ, {"VLLM_ASCEND_ENABLE_FLASHCOMM1": "1"})
@patch.dict(os.environ, {"ASCEND_AGGREGATE_ENABLE": "1"})
@patch.dict(os.environ, {"HCCL_BUFFSIZE": "1024"})
def test_deepseek3_2_w8a8c8_pruning_mtp_tp2_ep():
    """Run greedy generation with `enable_sparse_c8` on, for a short and a near-limit prompt.

    The long prompt repeats "Hello " so that, together with the 500 generated tokens, it
    approaches the configured max_model_len of 163840.
    """
    decode_budget = 500
    # "max_position_embeddings": 163840,
    context_limit = 163840
    brief_prompts = ["Hello "]
    lengthy_prompts = ["Hello " * (context_limit - 1 - decode_budget) + "Hello"]

    runner_options = dict(
        tensor_parallel_size=2,
        quantization="ascend",
        enable_expert_parallel=True,
        max_model_len=context_limit,
        compilation_config={"cudagraph_capture_sizes": [2, 4, 6, 8, 10, 12], "cudagraph_mode": "FULL_DECODE_ONLY"},
        speculative_config={"num_speculative_tokens": 1, "method": "deepseek_mtp"},
        additional_config={"layer_sharding": ["q_b_proj", "o_proj"], "enable_sparse_c8": True},
        reasoning_parser="deepseek_v3",
        tokenizer_mode="deepseek_v32",
    )
    with VllmRunner("vllm-ascend/DeepSeek-V3.2-W8A8-Pruning", **runner_options) as vllm_model:
        # Short prompt first, then the near-limit one, as separate generate calls.
        for prompts in (brief_prompts, lengthy_prompts):
            vllm_model.generate_greedy(prompts, decode_budget)
@pytest.mark.parametrize("model", QWEN_W4A4_MODELS)
def test_qwen3_w4a4_distributed_tp2(model):
example_prompts = [

View File

@@ -134,9 +134,12 @@ class AscendConfig:
bool(additional_config.get("enable_async_exponential", False)) and not vllm_is_batch_invariant()
)
use_sparse = hasattr(vllm_config.model_config, "hf_text_config") and hasattr(
vllm_config.model_config.hf_text_config, "index_topk"
)
self.enable_kv_nz = additional_config.get("enable_kv_nz", False)
if self.enable_kv_nz:
use_sparse = hasattr(vllm_config.model_config.hf_text_config, "index_topk")
if not vllm_config.model_config.is_deepseek_mla or use_sparse:
raise RuntimeError("enable_kv_nz is only supported for mla currently.")
if vllm_config.kv_transfer_config is None or not vllm_config.kv_transfer_config.is_kv_consumer:
@@ -144,6 +147,17 @@ class AscendConfig:
"enable_kv_nz is only supported in pd scenario and can only be used in D node."
)
from vllm_ascend.utils import AscendDeviceType, get_ascend_device_type
# Disable Sparse C8 for A5
# A5 has not been fully validated for this path and may carry hidden risks.
# TODO(rjg-lyh): Enable A5 support after sufficient validation.
self.enable_sparse_c8 = (
additional_config.get("enable_sparse_c8", False)
and use_sparse
and get_ascend_device_type() != AscendDeviceType.A5
)
def _construct_weight_prefetch_config(self, additional_config):
weight_prefetch_config = additional_config.get("weight_prefetch_config", {})
self.weight_prefetch_config = WeightPrefetchConfig(weight_prefetch_config)

View File

@@ -1,6 +1,7 @@
from dataclasses import dataclass
from typing import TYPE_CHECKING, TypeVar
import scipy # type: ignore
import torch
import torch_npu
import vllm.envs as envs_vllm
@@ -355,6 +356,9 @@ class AscendSFAImpl(MLAAttentionImpl):
# Supports forward using the all-gather o_proj weight for decode requests when Sharded CP is enabled.
o_proj_full_pool: torch.Tensor | None = None
# qk_hadamard tensor shared when dsa c8 enabled
qk_hadamard: torch.Tensor | None = None
def __init__(
self,
num_heads: int,
@@ -425,6 +429,12 @@ class AscendSFAImpl(MLAAttentionImpl):
self.is_rope_neox_style = False
self.use_torch_npu_lightning_indexer = True
# dsa c8
self.use_sparse_c8_indexer = ascend_config.enable_sparse_c8
if self.use_sparse_c8_indexer:
self.c8_k_cache_dtype = torch.int8
self.c8_k_scale_cache_dtype = torch.float16
# Effective in SFA when FlashComm is enabled.
self.enable_dsa_cp = enable_dsa_cp()
@@ -515,6 +525,11 @@ class AscendSFAImpl(MLAAttentionImpl):
# if mlapo, W_UK_T can't trans nz
self.W_UK_T = maybe_trans_nz(self.W_UK_T)
if self.use_sparse_c8_indexer and AscendSFAImpl.qk_hadamard is None:
AscendSFAImpl.qk_hadamard = torch.tensor(scipy.linalg.hadamard(128), dtype=torch.bfloat16, device="npu") / (
128**0.5
)
# Processing the input parameters for MLAPO by reordering and transposing
# QKV(and part of Q) weight, applying RoPE-related dimension transformations,
# and handling quantization parameters.
@@ -874,7 +889,15 @@ class AscendSFAImpl(MLAAttentionImpl):
k_li = torch.cat([k_li_pe, k_li_nope], dim=-1) # [b*s,128]
return k_li
if self.use_sparse_c8_indexer:
k_li = k_li @ AscendSFAImpl.qk_hadamard
k_li, k_li_scale = torch_npu.npu_dynamic_quant(k_li.view(-1, self.head_dim), dst_type=self.c8_k_cache_dtype)
k_li_scale = k_li_scale.to(self.c8_k_scale_cache_dtype) # [b*s,]
k_li_scale = k_li_scale.unsqueeze(-1) # [b*s,1]
else:
k_li_scale = None
return k_li, k_li_scale
def indexer_select_post_process(
self,
@@ -905,10 +928,35 @@ class AscendSFAImpl(MLAAttentionImpl):
q_li_pe = q_li_pe.squeeze(2)
q_li = torch.cat([q_li_pe, q_li_nope], dim=-1) # [b*s,64,128]
if self.use_sparse_c8_indexer:
q_li_shape_ori = q_li.shape
q_li = q_li @ AscendSFAImpl.qk_hadamard
q_li, q_li_scale = torch_npu.npu_dynamic_quant(q_li.view(-1, self.head_dim), dst_type=self.c8_k_cache_dtype)
q_li_scale = q_li_scale.to(self.c8_k_scale_cache_dtype)
# DSV3.2 currently has graph compilation issues when using torch_npu.npu.lightning_indexer.
# So two branches are maintained temporarily.
# TODO: torch.ops._C_ascend.npu_lightning_indexer needs to be removed.
if self.use_torch_npu_lightning_indexer:
if self.use_sparse_c8_indexer:
assert len(kv_cache) == 4
weights = weights.to(torch.float16)
topk_indices = torch.ops._C_ascend.npu_lightning_indexer_quant(
query=q_li.view(q_li_shape_ori),
key=kv_cache[2],
weights=weights,
query_dequant_scale=q_li_scale.view(q_li_shape_ori[:-1]),
key_dequant_scale=kv_cache[3].squeeze(2), # B S N D -> B S D
actual_seq_lengths_query=actual_seq_lengths_query,
actual_seq_lengths_key=actual_seq_lengths_key,
block_table=attn_metadata.block_table,
query_quant_mode=0,
key_quant_mode=0,
layout_query="TND",
layout_key="PA_BSND",
sparse_count=2048,
sparse_mode=3,
)
elif self.use_torch_npu_lightning_indexer:
topk_indices, _ = torch_npu.npu_lightning_indexer(
query=q_li,
key=kv_cache[2],
@@ -1031,7 +1079,7 @@ class AscendSFAImpl(MLAAttentionImpl):
assert self.q_a_layernorm is not None, "q_a_layernorm must be initialized"
q_c = self.q_a_layernorm(q_c)
k_li = self.indexer_select_pre_process(x=hidden_states, cos=cos, sin=sin)
k_li, k_li_scale = self.indexer_select_pre_process(x=hidden_states, cos=cos, sin=sin)
wait_for_kv_layer_from_connector(layer_name)
@@ -1044,20 +1092,46 @@ class AscendSFAImpl(MLAAttentionImpl):
if self.enable_dsa_cp:
assert k_pe is not None
assert k_nope is not None
assert k_li is not None
async_op = self.enable_dsa_cp_with_layer_shard or full_gather_o_proj_enabled
# support all_gather kv async for communication calculation overlap
fused_kv_no_split, kv_ag_handle = all_gather_async(
torch.cat(
[
k_pe.view(-1, k_pe.shape[-1]),
k_nope.view(-1, k_nope.shape[-1]),
k_li.view(-1, k_li.shape[-1]),
],
dim=1,
),
get_tp_group(),
async_op=async_op,
)
if not self.use_sparse_c8_indexer:
fused_kv_no_split, kv_ag_handle = all_gather_async(
torch.cat(
[
k_pe.view(-1, k_pe.shape[-1]),
k_nope.view(-1, k_nope.shape[-1]),
k_li.view(-1, k_li.shape[-1]),
],
dim=1,
),
get_tp_group(),
async_op=async_op,
)
else:
# due to different dtypes, we have to split commu pass
assert k_li_scale is not None
fused_kv_no_split, _ = all_gather_async(
torch.cat(
[
k_pe.view(-1, k_pe.shape[-1]),
k_nope.view(-1, k_nope.shape[-1]),
],
dim=1,
),
get_tp_group(),
async_op=async_op,
)
k_li, _ = all_gather_async(
k_li,
get_tp_group(),
async_op=async_op,
)
k_li_scale, kv_ag_handle = all_gather_async(
k_li_scale,
get_tp_group(),
async_op=async_op,
)
ql_nope, q_pe = self._q_proj_and_k_up_proj(q_c)
q_pe = self.rope_single(q_pe, cos, sin)
@@ -1077,9 +1151,12 @@ class AscendSFAImpl(MLAAttentionImpl):
if kv_cache is not None:
assert fused_kv_no_split is not None
k_pe, k_nope, k_li = fused_kv_no_split.split(
[self.qk_rope_head_dim, self.kv_lora_rank, self.head_dim], dim=-1
)
if not self.use_sparse_c8_indexer:
k_pe, k_nope, k_li = fused_kv_no_split.split(
[self.qk_rope_head_dim, self.kv_lora_rank, self.head_dim], dim=-1
)
else:
k_pe, k_nope = fused_kv_no_split.split([self.qk_rope_head_dim, self.kv_lora_rank], dim=-1)
k_nope = k_nope.view(k_nope.shape[0], 1, -1)
k_pe = k_pe.view(k_pe.shape[0], 1, -1)
DeviceOperator.reshape_and_cache(
@@ -1098,6 +1175,13 @@ class AscendSFAImpl(MLAAttentionImpl):
torch_npu.npu_scatter_nd_update_(
kv_cache[2].view(-1, k_li.shape[-1]), slot_mapping.view(-1, 1), k_li.view(-1, k_li.shape[-1])
) # b, s, n, d
if self.use_sparse_c8_indexer:
assert len(kv_cache) == 4
torch_npu.npu_scatter_nd_update_(
kv_cache[3].view(-1, k_li_scale.shape[-1]),
slot_mapping.view(-1, 1),
k_li_scale.view(-1, k_li_scale.shape[-1]),
)
if self.is_kv_producer:
attn_metadata.reshape_cache_event.record()

View File

@@ -137,6 +137,28 @@
# Remove this patch if upstream provides an official NPU graph-capture
# guidance / auto-configuration path for HCCL.
#
# ** 8. File: platform/patch_kv_cache_interface.py**
# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
# 1. `vllm.v1.kv_cache_interface.MLAAttentionSpec`
# Why:
# The default `MLAAttentionSpec` is mainly built around `kv_lora_rank`
# and `qk_rope_head_dim`. On NPU, we also use this class to describe DSA
# models. Unlike the GPU path, where cache management is handled by an
# additional indexer module, extending this class directly simplifies the
# corresponding `model_runner` implementation on NPU.
#
# This patch also adds Sparse C8 support for DSA models on NPU. As part
# of that support, members such as `page_size_bytes` need to be adapted,
# so they are overridden here as well to preserve overall readability.
# How:
# This patch subclasses the original implementation, overrides selected
# methods, and adds DSA-specific attributes and helpers with default
# values where needed.
# Related PR (if no, explain why):
# https://github.com/vllm-project/vllm/pull/25896
# Future Plan:
# Remove this patch after the upcoming KV cache spec refactor.
#
# * Worker Patch:
# ===============
#

View File

@@ -18,6 +18,7 @@ import os
import vllm_ascend.patch.platform.patch_distributed # noqa
import vllm_ascend.patch.platform.patch_fusion_matcher_compat_ops # noqa
import vllm_ascend.patch.platform.patch_kv_cache_interface # noqa
import vllm_ascend.patch.platform.patch_mamba_config # noqa
import vllm_ascend.patch.platform.patch_minimax_m2_config # noqa
import vllm_ascend.patch.platform.patch_sched_yield # noqa

View File

@@ -0,0 +1,138 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from dataclasses import dataclass
import torch
import vllm.v1.kv_cache_interface
from typing_extensions import Self
from vllm.utils.torch_utils import get_dtype_size
from vllm.v1.kv_cache_interface import MLAAttentionSpec
@dataclass(frozen=True)
class AscendMLAAttentionSpec(MLAAttentionSpec):
    """MLAAttentionSpec extended to support DSA models, with optional Sparse C8 support.

    When Sparse C8 is enabled, the KV cache tuple changes from
        (kv_cache[0]: bfloat16, kv_cache[1]: bfloat16, kv_cache[2]: bfloat16)
    to
        (kv_cache[0]: bfloat16, kv_cache[1]: bfloat16, kv_cache[2]: int8, kv_cache[3]: float16).

    The semantic meaning of each KV cache entry is as follows:
        1. kv_cache[0] stores kv_lora.
        2. kv_cache[1] stores k_rope.
        3. kv_cache[2] stores the key tensor from the indexer module.
        4. kv_cache[3] stores the key scale tensor from the indexer module,
           and exists only when Sparse C8 is enabled.

    The main changes are as follows:
        1. The key tensor from the indexer module stored in kv_cache[2] is
           converted from bf16 to int8 to reduce memory usage. It is then
           processed with int8 precision in Lightning_indexer computation
           to improve computational efficiency.
        2. The quantization scale of the key tensor in the indexer module
           must also be stored for the Lightning_indexer_quant operator,
           and is therefore saved in kv_cache[3].
    """

    # (kv_lora_rank, qk_rope_head_dim, index_head_dim); None when the spec
    # does not describe a DSA model.
    sparse_head_dim: tuple[int, ...] | None = None
    # True when the indexer key cache is stored quantized (Sparse C8).
    cache_sparse_c8: bool = False
    # Storage dtype of the quantized indexer key cache (kv_cache[2]).
    c8_k_cache_dtype: torch.dtype = torch.int8
    # Storage dtype of the per-token indexer key scale (kv_cache[3]).
    c8_k_scale_cache_dtype: torch.dtype = torch.float16

    @property
    def page_size_bytes(self) -> int:
        """Return the number of bytes one cache page (block) occupies.

        With Sparse C8 the page is the sum of three differently typed parts;
        otherwise it is ``block_size * num_kv_heads * head_size`` elements of
        ``self.dtype``.
        """
        if self.cache_sparse_c8:
            assert self.sparse_head_dim is not None
            assert len(self.sparse_head_dim) == 3

            num_heads_per_page = self.block_size * self.num_kv_heads

            # kv_cache[0]: bfloat16, kv_cache[1]: bfloat16
            kv_lora_rank, qk_rope_head_dim = self.sparse_head_dim[:2]
            k_pe_nope_bytes = num_heads_per_page * (kv_lora_rank + qk_rope_head_dim) * get_dtype_size(self.dtype)

            # kv_cache[2]: int8
            index_head_dim = self.sparse_head_dim[-1]
            indexer_k_bytes = num_heads_per_page * index_head_dim * get_dtype_size(self.c8_k_cache_dtype)

            # kv_cache[3]: float16
            # since the scale is stored per token, head_dim is set to 1.
            index_scale_head_dim = 1
            indexer_k_scale_bytes = (
                num_heads_per_page * index_scale_head_dim * get_dtype_size(self.c8_k_scale_cache_dtype)
            )

            return k_pe_nope_bytes + indexer_k_bytes + indexer_k_scale_bytes

        return self.block_size * self.num_kv_heads * self.head_size * get_dtype_size(self.dtype)

    @property
    def sparse_kv_cache_ratio(self) -> tuple[float, float, float, float | None]:
        """Return the split factor of each KV cache entry.

        Each value is ``total / part``, i.e. the reciprocal of that entry's
        share of one page, so a total byte size divided by the factor gives
        the byte size of the entry.

        With Sparse C8 the entries have different dtypes, so every head dim is
        first expressed in units of ``self.dtype`` elements ("virtual" dims)
        to make the factors byte-accurate.

        Returns:
            A tuple containing the factors for:
                - kv_cache[0]
                - kv_cache[1]
                - kv_cache[2]
                - kv_cache[3] (None if Sparse C8 is disabled)
        """
        assert self.sparse_head_dim is not None

        if not self.cache_sparse_c8:
            return (
                self.head_size / self.sparse_head_dim[0],  # kv_cache[0]
                self.head_size / self.sparse_head_dim[1],  # kv_cache[1]
                self.head_size / self.sparse_head_dim[2],  # kv_cache[2]
                None,  # kv_cache[3] does not exist
            )

        kv_lora_rank, qk_rope_head_dim, index_k_head_dim = self.sparse_head_dim

        # Number of quantized key elements that fit in one `self.dtype` element.
        factor = get_dtype_size(self.dtype) // get_dtype_size(self.c8_k_cache_dtype)
        # Floor division below must be exact, otherwise the split sizes would
        # no longer add up to `page_size_bytes`.
        assert factor >= 1 and index_k_head_dim % factor == 0, (
            "index_head_dim must be divisible by the size ratio between the KV cache dtype "
            "and the Sparse C8 key cache dtype."
        )
        index_k_head_dim_virtual = index_k_head_dim // factor

        # The scale is stored per token; one scale element counts as one
        # virtual element only if it is as wide as `self.dtype`.
        assert get_dtype_size(self.dtype) == get_dtype_size(self.c8_k_scale_cache_dtype)
        index_k_scale_head_dim_virtual = 1

        virtual_dims = (
            kv_lora_rank,
            qk_rope_head_dim,
            index_k_head_dim_virtual,
            index_k_scale_head_dim_virtual,
        )
        total_virtual_head_dim = sum(virtual_dims)
        return (
            total_virtual_head_dim / virtual_dims[0],  # kv_cache[0]
            total_virtual_head_dim / virtual_dims[1],  # kv_cache[1]
            total_virtual_head_dim / virtual_dims[2],  # kv_cache[2]
            total_virtual_head_dim / virtual_dims[3],  # kv_cache[3]
        )

    @classmethod
    def merge(cls, specs: list[Self]) -> Self:
        """Merge the specs of all layers in one KV cache group into a single spec.

        All specs must be MLA specs that agree on the quantization method and
        on the sparse (DSA / Sparse C8) cache layout; the merged spec takes
        its values from the first one.
        """
        assert all(isinstance(spec, MLAAttentionSpec) for spec in specs), (
            "All attention layers in the same KV cache group must be MLAAttentionSpec."
        )
        cache_dtype_str_set = set(spec.cache_dtype_str for spec in specs)
        assert len(cache_dtype_str_set) == 1, (
            "All attention layers in the same KV cache group must use the same quantization method."
        )

        # getattr with the field defaults keeps plain MLAAttentionSpec
        # instances (which lack the sparse fields) comparable.
        sparse_layout_set = set(
            (
                getattr(spec, "sparse_head_dim", None),
                getattr(spec, "cache_sparse_c8", False),
                getattr(spec, "c8_k_cache_dtype", torch.int8),
                getattr(spec, "c8_k_scale_cache_dtype", torch.float16),
            )
            for spec in specs
        )
        assert len(sparse_layout_set) == 1, (
            "All attention layers in the same KV cache group must use the same sparse KV cache layout."
        )

        first = specs[0]
        return cls(
            block_size=first.block_size,
            num_kv_heads=first.num_kv_heads,
            head_size=first.head_size,
            sparse_head_dim=first.sparse_head_dim,
            dtype=first.dtype,
            cache_dtype_str=cache_dtype_str_set.pop(),
            cache_sparse_c8=first.cache_sparse_c8,
            # Propagate the C8 dtypes instead of silently resetting them to
            # the field defaults.
            c8_k_cache_dtype=first.c8_k_cache_dtype,
            c8_k_scale_cache_dtype=first.c8_k_scale_cache_dtype,
        )
# Monkey-patch: rebind the module attribute so that later lookups of
# `vllm.v1.kv_cache_interface.MLAAttentionSpec` resolve to the Ascend subclass.
# Names that were bound with `from vllm.v1.kv_cache_interface import MLAAttentionSpec`
# before this statement runs keep referring to the original class.
vllm.v1.kv_cache_interface.MLAAttentionSpec = AscendMLAAttentionSpec

View File

@@ -84,6 +84,7 @@ from vllm.v1.worker.ubatch_utils import (
)
from vllm.v1.worker.utils import AttentionGroup
# yapf: enable
from vllm_ascend.ascend_config import get_ascend_config
from vllm_ascend.attention.attention_v1 import AscendAttentionState
from vllm_ascend.attention.utils import AscendCommonAttentionMetadata, using_paged_attention
@@ -96,8 +97,6 @@ from vllm_ascend.compilation.acl_graph import (
set_graph_params,
update_full_graph_params,
)
# yapf: enable
from vllm_ascend.eplb.adaptor.vllm_adaptor import VllmEplbAdaptor
from vllm_ascend.eplb.core.eplb_device_transfer_loader import D2DExpertWeightLoader
from vllm_ascend.eplb.core.eplb_worker import EplbProcess
@@ -274,7 +273,21 @@ class NPUModelRunner(GPUModelRunner):
self.is_multimodal_model = self.model_config.is_multimodal_model
self.block_size = vllm_config.cache_config.block_size
# Set up Attention
self.use_sparse = hasattr(self.vllm_config.model_config.hf_text_config, "index_topk")
self.use_sparse = hasattr(vllm_config.model_config, "hf_text_config") and hasattr(
vllm_config.model_config.hf_text_config, "index_topk"
)
if self.use_sparse:
self.sparse_head_dim = (
self.model_config.hf_text_config.kv_lora_rank,
self.model_config.hf_text_config.qk_rope_head_dim,
self.model_config.hf_text_config.index_head_dim,
)
# dsa c8
self.use_sparse_c8_indexer = self.ascend_config.enable_sparse_c8
if self.use_sparse_c8_indexer:
self.c8_k_cache_dtype = torch.int8
self.c8_k_scale_cache_dtype = torch.float16
self.attn_backend = get_attn_backend(
0,
self.dtype,
@@ -2623,7 +2636,7 @@ class NPUModelRunner(GPUModelRunner):
to their corresponding memory buffer for K cache and V cache.
"""
# init kv cache tensors
kv_cache_raw_tensors: dict[str, torch.Tensor | torch.Tensor | None] = {}
kv_cache_raw_tensors: dict[str, torch.Tensor | torch.Tensor | None | None] = {}
# prefill disaggregation need the addr of cache tensor be aligned with 2M
alignment = 2 * 1024 * 1024
layer_kv_cache_spec: dict[str, KVCacheSpec] = {}
@@ -2670,19 +2683,18 @@ class NPUModelRunner(GPUModelRunner):
+ self.model_config.hf_text_config.kv_lora_rank
)
dsa_k_cache_factor = None
dsa_k_cache_size = None
if not self.model_config.use_mla:
# for non-mla model, use FullAttentionSpec
k_tensor_split_factor = 2
v_tensor_split_factor = 2
k_tensor_split_factor = 2.0
v_tensor_split_factor = 2.0
elif self.use_sparse:
# for deepseek v3.2, we split the kv cache according to the corresponding ratio
sparse_sum_head_size = sum(self._get_sparse_kv_cache_ratio())
k_tensor_split_factor, v_tensor_split_factor, dsa_k_cache_factor = [ # type: ignore
sparse_sum_head_size / ratio for ratio in self._get_sparse_kv_cache_ratio()
]
dsa_k_cache_size = int(kv_cache_tensor.size // dsa_k_cache_factor)
kv_cache_spec = layer_kv_cache_spec[layer_name]
sparse_kv_cache_ratio = kv_cache_spec.sparse_kv_cache_ratio
k_tensor_split_factor = sparse_kv_cache_ratio[0]
v_tensor_split_factor = sparse_kv_cache_ratio[1]
dsa_k_tensor_split_factor = sparse_kv_cache_ratio[2]
dsa_k_scale_tensor_split_factor = sparse_kv_cache_ratio[3]
else:
# for other deepseek models, use MLAAttentionSpec
k_tensor_split_factor = head_size / self.model_config.hf_text_config.kv_lora_rank
@@ -2690,35 +2702,56 @@ class NPUModelRunner(GPUModelRunner):
k_tensor_size = int(kv_cache_tensor.size // k_tensor_split_factor)
v_tensor_size = int(kv_cache_tensor.size // v_tensor_split_factor)
dsa_k_tensor_size = None
dsa_k_scale_tensor_size = None
#### for deepseek sparse attention
if self.use_sparse:
dsa_k_tensor_size = int(kv_cache_tensor.size // dsa_k_tensor_split_factor)
if self.use_sparse_c8_indexer:
dsa_k_scale_tensor_size = int(kv_cache_tensor.size // dsa_k_scale_tensor_split_factor)
# for other attentions, e.g., self_attn, sliding window attn
if self.vllm_config.kv_transfer_config is None:
k_tensor = torch.zeros(k_tensor_size, dtype=torch.int8, device=self.device)
v_tensor = torch.zeros(v_tensor_size, dtype=torch.int8, device=self.device)
#### k cache: for deepseek sparse attention
if dsa_k_cache_factor is not None:
dsa_k_cache_tensor = torch.zeros(dsa_k_cache_size, dtype=torch.int8, device=self.device)
#### for deepseek sparse attention
if dsa_k_tensor_size is not None:
dsa_k_tensor = torch.zeros(dsa_k_tensor_size, dtype=torch.int8, device=self.device)
if dsa_k_scale_tensor_size is not None:
dsa_k_scale_tensor = torch.zeros(
dsa_k_scale_tensor_size, dtype=torch.int8, device=self.device
)
else:
k_tensor = torch.zeros(k_tensor_size + alignment, dtype=torch.int8, device=self.device)
v_tensor = torch.zeros(v_tensor_size + alignment, dtype=torch.int8, device=self.device)
k_tensor = self._align_memory(k_tensor, alignment)[:k_tensor_size]
v_tensor = self._align_memory(v_tensor, alignment)[:v_tensor_size]
#### k cache: for deepseek sparse attention
if dsa_k_cache_factor is not None and dsa_k_cache_size is not None:
dsa_k_cache_tensor = torch.zeros(
dsa_k_cache_size + alignment, dtype=torch.int8, device=self.device
#### for deepseek sparse attention
if dsa_k_tensor_size is not None:
dsa_k_tensor = torch.zeros(
dsa_k_tensor_size + alignment, dtype=torch.int8, device=self.device
)
dsa_k_cache_tensor = self._align_memory(dsa_k_cache_tensor, alignment)[:dsa_k_cache_size]
dsa_k_tensor = self._align_memory(dsa_k_tensor, alignment)[:dsa_k_tensor_size]
if dsa_k_scale_tensor_size is not None:
dsa_k_scale_tensor = torch.zeros(
dsa_k_scale_tensor_size + alignment, dtype=torch.int8, device=self.device
)
dsa_k_scale_tensor = self._align_memory(
dsa_k_scale_tensor, alignment
)[:dsa_k_scale_tensor_size]
for layer_name_inner in kv_cache_tensor.shared_by:
# shared the attn kvcache for all shared layers
if "attn" in layer_name_inner and "linear_attn" not in layer_name_inner:
kv_cache_raw_tensors[layer_name_inner] = (
(k_tensor, v_tensor)
if not self.use_sparse
else (k_tensor, v_tensor, dsa_k_cache_tensor)
)
if self.use_sparse:
if self.use_sparse_c8_indexer:
kv_cache_raw_tensors[layer_name_inner] = (
k_tensor, v_tensor, dsa_k_tensor, dsa_k_scale_tensor
)
else:
kv_cache_raw_tensors[layer_name_inner] = (k_tensor, v_tensor, dsa_k_tensor)
else:
kv_cache_raw_tensors[layer_name_inner] = (k_tensor, v_tensor)
layer_names = set()
for group in kv_cache_config.kv_cache_groups:
for layer_name in group.layer_names:
@@ -2760,13 +2793,23 @@ class NPUModelRunner(GPUModelRunner):
# TODO: remove this after the OOM issue is located and fixed, otherwise, some model may
# encounter OOM issue
if isinstance(kv_cache_spec, AttentionSpec):
raw_dsa_k_tensor = None
if self.use_sparse:
raw_k_tensor, raw_v_tensor, raw_dsa_k_tensor = kv_cache_raw_tensors[ # type: ignore
layer_name
]
assert raw_dsa_k_tensor is not None
sum_page_size_bytes = raw_k_tensor.numel() + raw_v_tensor.numel() + raw_dsa_k_tensor.numel()
if self.use_sparse_c8_indexer:
raw_k_tensor, raw_v_tensor, raw_dsa_k_tensor, raw_dsa_k_scale_tensor = kv_cache_raw_tensors[ # type: ignore
layer_name]
assert raw_dsa_k_tensor is not None
assert raw_dsa_k_scale_tensor is not None
sum_page_size_bytes = (
raw_k_tensor.numel()
+ raw_v_tensor.numel()
+ raw_dsa_k_tensor.numel()
+ raw_dsa_k_scale_tensor.numel()
)
else:
raw_k_tensor, raw_v_tensor, raw_dsa_k_tensor = kv_cache_raw_tensors[ # type: ignore
layer_name]
assert raw_dsa_k_tensor is not None
sum_page_size_bytes = raw_k_tensor.numel() + raw_v_tensor.numel() + raw_dsa_k_tensor.numel()
elif self.use_hybrid_blocks and self.hybrid_with_attn_and_mamba:
# Currently, we ensure that the same kvcache format is used even if there
# is no shared layer, such as the full attention mtp layer of qwen3.5, etc.
@@ -2813,7 +2856,7 @@ class NPUModelRunner(GPUModelRunner):
kv_cache_shape = self.attn_backend.get_kv_cache_shape(
num_blocks, kv_cache_spec.block_size, kv_cache_spec.num_kv_heads, kv_cache_spec.head_size
)
dtype = kv_cache_spec.dtype
if not self.model_config.use_mla:
k_shape = kv_cache_shape[1:]
v_shape = k_shape
@@ -2832,19 +2875,37 @@ class NPUModelRunner(GPUModelRunner):
num_kv_heads,
self.model_config.hf_text_config.qk_rope_head_dim,
]
k_cache = raw_k_tensor.view(dtype).view(k_shape)
v_cache = raw_v_tensor.view(dtype).view(v_shape)
k_cache = raw_k_tensor.view(kv_cache_spec.dtype).view(k_shape)
v_cache = raw_v_tensor.view(kv_cache_spec.dtype).view(v_shape)
if self.use_sparse and raw_dsa_k_tensor is not None:
index_head_dim = self._get_sparse_kv_cache_ratio()[-1]
if self.use_sparse:
dsa_k_cache_shape = (
num_blocks,
kv_cache_spec.block_size,
kv_cache_spec.num_kv_heads,
index_head_dim,
self.model_config.hf_text_config.index_head_dim,
)
dsa_k_cache = raw_dsa_k_tensor.view(dtype).view(dsa_k_cache_shape)
kv_caches[layer_name] = (k_cache, v_cache, dsa_k_cache)
if self.use_sparse_c8_indexer:
# dsa_k
dsa_k_cache = raw_dsa_k_tensor.view(self.c8_k_cache_dtype).view(dsa_k_cache_shape)
# dsa_k_scale
dsa_k_scale_cache_shape = (
num_blocks,
kv_cache_spec.block_size,
kv_cache_spec.num_kv_heads,
1,
)
assert raw_dsa_k_scale_tensor is not None
dsa_k_scale_cache = (
raw_dsa_k_scale_tensor
.view(self.c8_k_scale_cache_dtype)
.view(dsa_k_scale_cache_shape)
)
kv_caches[layer_name] = (k_cache, v_cache, dsa_k_cache, dsa_k_scale_cache)
else:
# dsa_k
dsa_k_cache = raw_dsa_k_tensor.view(kv_cache_spec.dtype).view(dsa_k_cache_shape)
kv_caches[layer_name] = (k_cache, v_cache, dsa_k_cache)
else:
kv_caches[layer_name] = (k_cache, v_cache)
elif isinstance(kv_cache_spec, MambaSpec):
@@ -3098,18 +3159,31 @@ class NPUModelRunner(GPUModelRunner):
elif isinstance(attn_module, MLAAttention):
if self.use_sparse:
# TODO(cmq): This is a hack way to fix deepseek kvcache when
# using DSA. Fix the spec in vLLM is the final way.
sparse_sum_head_size = sum(self._get_sparse_kv_cache_ratio())
kv_cache_spec[layer_name] = MLAAttentionSpec(
# `MLAAttentionSpec` is temporarily patched to `AscendMLAAttentionSpec`.
# Re-importing it at runtime will therefore resolve to the patched class.
# Rename it here to make this behavior explicit.
from vllm.v1.kv_cache_interface import MLAAttentionSpec as AscendMLAAttentionSpec
# TODO(rjg-lyh): when kv_cache_spec's refactor is ready,
# implement it by creating a new kv_cache_spec class
kv_cache_spec[layer_name] = AscendMLAAttentionSpec(
block_size=self.block_size,
num_kv_heads=1,
head_size=sparse_sum_head_size,
head_size=sum(self.sparse_head_dim),
sparse_head_dim=self.sparse_head_dim,
dtype=self.kv_cache_dtype,
cache_dtype_str=self.vllm_config.cache_config.cache_dtype,
cache_sparse_c8=self.use_sparse_c8_indexer,
)
elif spec := attn_module.get_kv_cache_spec(self.vllm_config):
kv_cache_spec[layer_name] = spec
assert isinstance(spec, MLAAttentionSpec)
from vllm.v1.kv_cache_interface import MLAAttentionSpec as AscendMLAAttentionSpec
kv_cache_spec[layer_name] = AscendMLAAttentionSpec(
block_size=spec.block_size,
num_kv_heads=spec.num_kv_heads,
head_size=spec.head_size,
dtype=spec.dtype,
cache_dtype_str=spec.cache_dtype_str,
)
elif isinstance(attn_module, MambaBase):
mamba_layers[layer_name] = attn_module
@@ -3129,16 +3203,6 @@ class NPUModelRunner(GPUModelRunner):
return kv_cache_spec
def _get_sparse_kv_cache_ratio(self) -> list[int]:
# TODO:If C8 is supported, we need to consider the number of bytes occupied by different dtypes
# when calculating the ratiofor example:
# [kv_lora_rank * torch.int8.itemsize, qk_rope_head_dim * torch.bfloat16.itemsize, ...]
return [
self.model_config.hf_text_config.kv_lora_rank,
self.model_config.hf_text_config.qk_rope_head_dim,
self.model_config.hf_text_config.index_head_dim,
]
def _check_and_update_cudagraph_mode(
self,
attention_backends: list[set[type[AttentionBackend]]],