[kernel] add AscendC op: lightning_indexer and sparse_flash_attention (#4625)
### What this PR does / why we need it? Provide high-performance AscendC operators lightning_indexer and sparse_flash_attention to boost the execution performance of the DeepSeek v3.2 model. Meanwhile, adapt the two AscendC operators to vllm-ascend framework. ### Does this PR introduce _any_ user-facing change? No (only underlying operator optimizations, with no user-facing changes) ### How was this patch tested? - vLLM version: v0.11.2 - vLLM main: https://github.com/vllm-project/vllm/commit/v0.11.2 Signed-off-by: MingYang119 <songmingyang@huawei.com>
This commit is contained in:
39
csrc/sparse_flash_attention/op_host/CMakeLists.txt
Normal file
39
csrc/sparse_flash_attention/op_host/CMakeLists.txt
Normal file
@@ -0,0 +1,39 @@
|
||||
# This program is free software, you can redistribute it and/or modify it.
|
||||
# Copyright (c) 2025 Huawei Technologies Co., Ltd.
|
||||
# This file is a part of the CANN Open Software.
|
||||
# Licensed under CANN Open Software License Agreement Version 2.0 (the "License").
|
||||
# Please refer to the License for details. You may not use this file except in compliance with the License.
|
||||
# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE.
|
||||
# See LICENSE in the root of the software repository for the full text of the License.
|
||||
# ======================================================================================================================
|
||||
|
||||
add_ops_compile_options(
|
||||
OP_NAME SparseFlashAttention
|
||||
OPTIONS --cce-auto-sync=off
|
||||
-Wno-deprecated-declarations
|
||||
-Werror
|
||||
-fpermissive
|
||||
)
|
||||
|
||||
set(sparse_flash_attention_depends transformer/attention/sparse_flash_attention PARENT_SCOPE)
|
||||
target_sources(op_host_aclnn PRIVATE
|
||||
sparse_flash_attention_def.cpp
|
||||
)
|
||||
|
||||
target_sources(optiling PRIVATE
|
||||
sparse_flash_attention_tiling.cpp
|
||||
)
|
||||
|
||||
if (NOT BUILD_OPEN_PROJECT)
|
||||
target_sources(opmaster_ct PRIVATE
|
||||
sparse_flash_attention_tiling.cpp
|
||||
)
|
||||
endif ()
|
||||
|
||||
target_sources(opsproto PRIVATE
|
||||
sparse_flash_attention_proto.cpp
|
||||
)
|
||||
|
||||
target_include_directories(optiling PRIVATE
|
||||
${CMAKE_CURRENT_SOURCE_DIR}
|
||||
)
|
||||
@@ -0,0 +1,90 @@
|
||||
/**
|
||||
* This program is free software, you can redistribute it and/or modify it.
|
||||
* Copyright (c) 2025 Huawei Technologies Co., Ltd.
|
||||
* This file is a part of the CANN Open Software.
|
||||
* Licensed under CANN Open Software License Agreement Version 2.0 (the "License").
|
||||
* Please refer to the License for details. You may not use this file except in compliance with the License.
|
||||
* THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE.
|
||||
* See LICENSE in the root of the software repository for the full text of the License.
|
||||
*/
|
||||
|
||||
/*!
|
||||
* \file sparse_flash_attention_def.cpp
|
||||
* \brief
|
||||
*/
|
||||
|
||||
#include "register/op_def_registry.h"
|
||||
|
||||
namespace ops {
|
||||
class SparseFlashAttention : public OpDef {
|
||||
public:
|
||||
explicit SparseFlashAttention(const char *name) : OpDef(name)
|
||||
{
|
||||
this->Input("query")
|
||||
.ParamType(REQUIRED)
|
||||
.DataType({ge::DT_FLOAT16, ge::DT_BF16})
|
||||
.Format({ge::FORMAT_ND, ge::FORMAT_ND})
|
||||
.AutoContiguous();
|
||||
this->Input("key")
|
||||
.ParamType(REQUIRED)
|
||||
.DataType({ge::DT_FLOAT16, ge::DT_BF16})
|
||||
.Format({ge::FORMAT_ND, ge::FORMAT_ND})
|
||||
.AutoContiguous();
|
||||
this->Input("value")
|
||||
.ParamType(REQUIRED)
|
||||
.DataType({ge::DT_FLOAT16, ge::DT_BF16})
|
||||
.Format({ge::FORMAT_ND, ge::FORMAT_ND})
|
||||
.AutoContiguous();
|
||||
this->Input("sparse_indices")
|
||||
.ParamType(REQUIRED)
|
||||
.DataType({ge::DT_INT32, ge::DT_INT32})
|
||||
.Format({ge::FORMAT_ND, ge::FORMAT_ND})
|
||||
.AutoContiguous();
|
||||
this->Input("block_table")
|
||||
.ParamType(OPTIONAL)
|
||||
.DataType({ge::DT_INT32, ge::DT_INT32})
|
||||
.Format({ge::FORMAT_ND, ge::FORMAT_ND})
|
||||
.AutoContiguous();
|
||||
this->Input("actual_seq_lengths_query")
|
||||
.ParamType(OPTIONAL)
|
||||
.DataType({ge::DT_INT32, ge::DT_INT32})
|
||||
.Format({ge::FORMAT_ND, ge::FORMAT_ND})
|
||||
.AutoContiguous();
|
||||
this->Input("actual_seq_lengths_kv")
|
||||
.ParamType(OPTIONAL)
|
||||
.DataType({ge::DT_INT32, ge::DT_INT32})
|
||||
.Format({ge::FORMAT_ND, ge::FORMAT_ND})
|
||||
.AutoContiguous();
|
||||
this->Input("query_rope")
|
||||
.ParamType(OPTIONAL)
|
||||
.DataType({ge::DT_FLOAT16, ge::DT_BF16})
|
||||
.Format({ge::FORMAT_ND, ge::FORMAT_ND})
|
||||
.AutoContiguous();
|
||||
this->Input("key_rope")
|
||||
.ParamType(OPTIONAL)
|
||||
.DataType({ge::DT_FLOAT16, ge::DT_BF16})
|
||||
.Format({ge::FORMAT_ND, ge::FORMAT_ND})
|
||||
.AutoContiguous();
|
||||
this->Output("attention_out")
|
||||
.ParamType(REQUIRED)
|
||||
.DataType({ge::DT_FLOAT16, ge::DT_BF16})
|
||||
.Format({ge::FORMAT_ND, ge::FORMAT_ND});
|
||||
this->Attr("scale_value").AttrType(REQUIRED).Float(1.0);
|
||||
this->Attr("sparse_block_size").AttrType(REQUIRED).Int(1);
|
||||
this->Attr("layout_query").AttrType(OPTIONAL).String("BSND");
|
||||
this->Attr("layout_kv").AttrType(OPTIONAL).String("BSND");
|
||||
this->Attr("sparse_mode").AttrType(OPTIONAL).Int(3);
|
||||
OpAICoreConfig aicore_config;
|
||||
aicore_config.DynamicCompileStaticFlag(true)
|
||||
.DynamicFormatFlag(true)
|
||||
.DynamicRankSupportFlag(true)
|
||||
.DynamicShapeSupportFlag(true)
|
||||
.NeedCheckSupportFlag(false)
|
||||
.PrecisionReduceFlag(true)
|
||||
.ExtendCfgInfo("aclnnSupport.value", "support_aclnn");
|
||||
this->AICore().AddConfig("ascend910b", aicore_config);
|
||||
this->AICore().AddConfig("ascend910_93", aicore_config);
|
||||
}
|
||||
};
|
||||
OP_ADD(SparseFlashAttention);
|
||||
} // namespace ops
|
||||
@@ -0,0 +1,48 @@
|
||||
/**
|
||||
* This program is free software, you can redistribute it and/or modify it.
|
||||
* Copyright (c) 2025 Huawei Technologies Co., Ltd.
|
||||
* This file is a part of the CANN Open Software.
|
||||
* Licensed under CANN Open Software License Agreement Version 2.0 (the "License").
|
||||
* Please refer to the License for details. You may not use this file except in compliance with the License.
|
||||
* THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE.
|
||||
* See LICENSE in the root of the software repository for the full text of the License.
|
||||
*/
|
||||
|
||||
/*!
|
||||
* \file sparse_flash_attention_proto.cpp
|
||||
* \brief
|
||||
*/
|
||||
|
||||
#include <graph/utils/type_utils.h>
|
||||
#include <register/op_impl_registry.h>
|
||||
#include "error/ops_error.h"
|
||||
|
||||
using namespace ge;
|
||||
|
||||
namespace ops {
|
||||
constexpr size_t QUERY_INPUT_INDEX = 0;
|
||||
|
||||
ge::graphStatus InferShapeSparseFlashAttention(gert::InferShapeContext *context)
|
||||
{
|
||||
OPS_ERR_IF(context == nullptr, OPS_LOG_E("SparseFlashAttention", "InferShapeContext is nullptr"),
|
||||
return ge::GRAPH_FAILED);
|
||||
const gert::Shape *queryShape = context->GetInputShape(QUERY_INPUT_INDEX);
|
||||
OPS_LOG_E_IF_NULL(context, queryShape, return ge::GRAPH_FAILED)
|
||||
gert::Shape *attentionOutShape = context->GetOutputShape(0);
|
||||
OPS_LOG_E_IF_NULL(context, attentionOutShape, return ge::GRAPH_FAILED)
|
||||
*attentionOutShape = *queryShape;
|
||||
return GRAPH_SUCCESS;
|
||||
}
|
||||
|
||||
ge::graphStatus InferDataTypeSparseFlashAttention(gert::InferDataTypeContext *context)
|
||||
{
|
||||
OPS_ERR_IF(context == nullptr, OPS_LOG_E("SparseFlashAttention", "InferShapeContext is nullptr"),
|
||||
return ge::GRAPH_FAILED);
|
||||
const auto inputDataType = context->GetInputDataType(QUERY_INPUT_INDEX);
|
||||
context->SetOutputDataType(0, inputDataType);
|
||||
return ge::GRAPH_SUCCESS;
|
||||
}
|
||||
|
||||
IMPL_OP(SparseFlashAttention).InferShape(InferShapeSparseFlashAttention).InferDataType(InferDataTypeSparseFlashAttention);
|
||||
} // namespace ops
|
||||
|
||||
File diff suppressed because it is too large
Load Diff
@@ -0,0 +1,583 @@
|
||||
/**
|
||||
* This program is free software, you can redistribute it and/or modify it.
|
||||
* Copyright (c) 2025 Huawei Technologies Co., Ltd.
|
||||
* This file is a part of the CANN Open Software.
|
||||
* Licensed under CANN Open Software License Agreement Version 2.0 (the "License").
|
||||
* Please refer to the License for details. You may not use this file except in compliance with the License.
|
||||
* THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE.
|
||||
* See LICENSE in the root of the software repository for the full text of the License.
|
||||
*/
|
||||
|
||||
/*!
|
||||
* \file sparse_flash_attention_tiling.h
|
||||
* \brief
|
||||
*/
|
||||
#ifndef SPARSE_FLASH_ATTENTION_TILING_H
|
||||
#define SPARSE_FLASH_ATTENTION_TILING_H
|
||||
|
||||
#include <sstream>
|
||||
#include <graph/utils/type_utils.h>
|
||||
#include <exe_graph/runtime/tiling_context.h>
|
||||
#include <tiling/platform/platform_ascendc.h>
|
||||
#include "register/tilingdata_base.h"
|
||||
#include "exe_graph/runtime/tiling_context.h"
|
||||
|
||||
namespace optiling {
|
||||
// Inputs Index
|
||||
constexpr uint32_t QUERY_INPUT_INDEX = 0;
|
||||
constexpr uint32_t KEY_INPUT_INDEX = 1;
|
||||
constexpr uint32_t VALUE_INPUT_INDEX = 2;
|
||||
constexpr uint32_t SPARSE_INDICES_INPUT_INDEX = 3;
|
||||
constexpr uint32_t BLOCK_TABLE_INPUT_INDEX = 4;
|
||||
constexpr uint32_t ACT_SEQ_LEN_Q_INPUT_INDEX = 5;
|
||||
constexpr uint32_t ACT_SEQ_LEN_KV_INPUT_INDEX = 6;
|
||||
constexpr uint32_t QUERY_ROPE_INPUT_INDEX = 7;
|
||||
constexpr uint32_t KEY_ROPE_INPUT_INDEX = 8;
|
||||
// Outputs Index
|
||||
constexpr uint32_t OUTPUT_INDEX = 0;
|
||||
// Attributes Index
|
||||
constexpr uint32_t SCALE_VALUE_ATTR_INDEX = 0;
|
||||
constexpr uint32_t SPARSE_BLOCK_SIZE_ATTR_INDEX = 1;
|
||||
constexpr uint32_t LAYOUT_QUERY_ATTR_INDEX = 2;
|
||||
constexpr uint32_t LAYOUT_KV_ATTR_INDEX = 3;
|
||||
constexpr uint32_t SPARSE_MODE_ATTR_INDEX = 4;
|
||||
// Dim Num
|
||||
constexpr size_t DIM_NUM_TWO = 2;
|
||||
constexpr size_t DIM_NUM_THREE = 3;
|
||||
constexpr size_t DIM_NUM_FOUR = 4;
|
||||
// Constant
|
||||
constexpr uint32_t MAX_BLOCK_SIZE = 1024;
|
||||
constexpr uint32_t COPYND2NZ_SRC_STRIDE_LIMITATION = 65535;
|
||||
constexpr uint32_t NUM_BYTES_FLOAT = 4;
|
||||
constexpr uint32_t NUM_BYTES_FLOAT16 = 2;
|
||||
constexpr uint32_t NUM_BYTES_BF16 = 2;
|
||||
constexpr uint32_t BYTE_BLOCK = 32;
|
||||
const uint32_t SFA_MAX_AIC_CORE_NUM = 26;
|
||||
|
||||
enum class SFALayout : uint32_t {
|
||||
BSND = 0,
|
||||
TND = 1,
|
||||
PA_BSND = 2
|
||||
};
|
||||
|
||||
struct SFATilingShapeCompareParam {
|
||||
int64_t B = 1;
|
||||
int64_t S = 1;
|
||||
int64_t N = 1;
|
||||
int64_t D = 1;
|
||||
int64_t T = 1;
|
||||
// PA
|
||||
int64_t Bs = 1;
|
||||
int64_t Bn = 1;
|
||||
};
|
||||
|
||||
enum class KvStorageMode : uint32_t {
|
||||
BATCH_CONTINUOUS = 0,
|
||||
PAGE_ATTENTION = 1
|
||||
};
|
||||
|
||||
enum class SFAPerfMode : uint32_t {
|
||||
C_TEMPLATE_MODE = 0,
|
||||
V_TEMPLATE_MODE
|
||||
};
|
||||
|
||||
enum class SFAAxis : uint32_t {
|
||||
B = 0,
|
||||
S = 1,
|
||||
N = 2,
|
||||
D = 3,
|
||||
K = 3,
|
||||
T = 5,
|
||||
Bn = 6, // block number
|
||||
Bs = 7, // block size
|
||||
};
|
||||
|
||||
struct SFARequiredParaInfo {
|
||||
const gert::CompileTimeTensorDesc *desc;
|
||||
const gert::StorageShape *shape;
|
||||
};
|
||||
|
||||
struct SFAOptionalParaInfo {
|
||||
const gert::CompileTimeTensorDesc *desc;
|
||||
const gert::Tensor *tensor;
|
||||
};
|
||||
|
||||
struct SFAParaInfo {
|
||||
SFARequiredParaInfo query = {nullptr, nullptr};
|
||||
SFARequiredParaInfo key = {nullptr, nullptr};
|
||||
SFARequiredParaInfo value = {nullptr, nullptr};
|
||||
SFARequiredParaInfo sparseIndices = {nullptr, nullptr};
|
||||
SFAOptionalParaInfo blockTable = {nullptr, nullptr};
|
||||
SFAOptionalParaInfo actualSeqLengthsQ = {nullptr, nullptr};
|
||||
SFAOptionalParaInfo actualSeqLengths = {nullptr, nullptr};
|
||||
SFAOptionalParaInfo queryRope = {nullptr, nullptr};
|
||||
SFAOptionalParaInfo keyRope = {nullptr, nullptr};
|
||||
SFARequiredParaInfo attenOut = {nullptr, nullptr};
|
||||
|
||||
const char *layoutQuery = nullptr;
|
||||
const char *layoutKV = nullptr;
|
||||
const int64_t *sparseBlockSize = nullptr;
|
||||
const float *scaleValue = nullptr;
|
||||
const int64_t *sparseMode = nullptr;
|
||||
};
|
||||
|
||||
struct InnerSplitParams {
|
||||
uint32_t s1GBaseSize = 1;
|
||||
uint32_t s2BaseSize = 1;
|
||||
};
|
||||
|
||||
BEGIN_TILING_DATA_DEF(SparseFlashAttentionBaseParamsMla)
|
||||
TILING_DATA_FIELD_DEF(uint32_t, batchSize)
|
||||
TILING_DATA_FIELD_DEF(uint32_t, seqSize)
|
||||
TILING_DATA_FIELD_DEF(uint32_t, qSeqSize)
|
||||
TILING_DATA_FIELD_DEF(int64_t, blockSize)
|
||||
TILING_DATA_FIELD_DEF(uint32_t, maxBlockNumPerBatch)
|
||||
TILING_DATA_FIELD_DEF(float, scaleValue)
|
||||
TILING_DATA_FIELD_DEF(uint32_t, nNumOfQInOneGroup)
|
||||
TILING_DATA_FIELD_DEF(uint32_t, actualLenDimsQ)
|
||||
TILING_DATA_FIELD_DEF(uint32_t, actualLenDimsKV)
|
||||
TILING_DATA_FIELD_DEF(uint32_t, outputLayout)
|
||||
TILING_DATA_FIELD_DEF(uint32_t, sparseMode)
|
||||
TILING_DATA_FIELD_DEF(int64_t, sparseBlockSize)
|
||||
TILING_DATA_FIELD_DEF(uint32_t, sparseBlockCount)
|
||||
END_TILING_DATA_DEF
|
||||
REGISTER_TILING_DATA_CLASS(SparseFlashAttentionBaseParamsMlaOp, SparseFlashAttentionBaseParamsMla)
|
||||
|
||||
BEGIN_TILING_DATA_DEF(SparseFlashAttentionSingleCoreParamsMla)
|
||||
TILING_DATA_FIELD_DEF(uint32_t, usedCoreNum);
|
||||
END_TILING_DATA_DEF
|
||||
REGISTER_TILING_DATA_CLASS(SparseFlashAttentionSingleCoreParamsMlaOp, SparseFlashAttentionSingleCoreParamsMla)
|
||||
|
||||
BEGIN_TILING_DATA_DEF(SparseFlashAttentionSingleCoreTensorSizeMla)
|
||||
TILING_DATA_FIELD_DEF(uint32_t, mmResUbSize);
|
||||
TILING_DATA_FIELD_DEF(uint32_t, bmm2ResUbSize);
|
||||
END_TILING_DATA_DEF
|
||||
REGISTER_TILING_DATA_CLASS(SparseFlashAttentionSingleCoreTensorSizeMlaOp, SparseFlashAttentionSingleCoreTensorSizeMla)
|
||||
|
||||
BEGIN_TILING_DATA_DEF(SparseFlashAttentionSplitKVParamsMla)
|
||||
TILING_DATA_FIELD_DEF(uint32_t, s2)
|
||||
TILING_DATA_FIELD_DEF(uint32_t, accumOutSize) // FD workspace
|
||||
TILING_DATA_FIELD_DEF(uint32_t, logSumExpSize) // FD workspace
|
||||
END_TILING_DATA_DEF
|
||||
REGISTER_TILING_DATA_CLASS(SparseFlashAttentionSplitKVParamsMlaOp, SparseFlashAttentionSplitKVParamsMla)
|
||||
|
||||
BEGIN_TILING_DATA_DEF(SparseFlashAttentionInnerSplitParams)
|
||||
TILING_DATA_FIELD_DEF(uint32_t, mBaseSize)
|
||||
TILING_DATA_FIELD_DEF(uint32_t, s2BaseSize)
|
||||
END_TILING_DATA_DEF
|
||||
REGISTER_TILING_DATA_CLASS(SparseFlashAttentionInnerSplitParamsOp, SparseFlashAttentionInnerSplitParams)
|
||||
|
||||
BEGIN_TILING_DATA_DEF(SparseFlashAttentionTilingDataMla)
|
||||
TILING_DATA_FIELD_DEF_STRUCT(SparseFlashAttentionBaseParamsMla, baseParams);
|
||||
TILING_DATA_FIELD_DEF_STRUCT(SparseFlashAttentionSplitKVParamsMla, splitKVParams);
|
||||
TILING_DATA_FIELD_DEF_STRUCT(SparseFlashAttentionSingleCoreParamsMla, singleCoreParams);
|
||||
TILING_DATA_FIELD_DEF_STRUCT(SparseFlashAttentionSingleCoreTensorSizeMla, singleCoreTensorSize);
|
||||
TILING_DATA_FIELD_DEF_STRUCT(SparseFlashAttentionInnerSplitParams, innerSplitParams);
|
||||
END_TILING_DATA_DEF
|
||||
REGISTER_TILING_DATA_CLASS(SparseFlashAttention, SparseFlashAttentionTilingDataMla)
|
||||
|
||||
template <typename T> inline T Align(T num, T rnd)
|
||||
{
|
||||
return (((rnd) == 0) ? 0 : (((num) + (rnd) - 1) / (rnd) * (rnd)));
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
std::string SFAShape2String(const T &shape)
|
||||
{
|
||||
std::ostringstream oss;
|
||||
oss << "[";
|
||||
if (shape.GetDimNum() > 0) {
|
||||
for (size_t i = 0; i < shape.GetDimNum() - 1; ++i) {
|
||||
oss << shape.GetDim(i) << ", ";
|
||||
}
|
||||
oss << shape.GetDim(shape.GetDimNum() - 1);
|
||||
}
|
||||
oss << "]";
|
||||
return oss.str();
|
||||
}
|
||||
|
||||
static std::string GetShapeStr(gert::Shape shape);
|
||||
static std::string SFADataTypeToSerialString(ge::DataType type);
|
||||
std::string SFATensorDesc2String(const gert::StorageShape *shape, const gert::CompileTimeTensorDesc *tensor);
|
||||
std::string SFADebugTilingContext(const gert::TilingContext *context);
|
||||
std::string SFALayoutToSerialString(SFALayout layout);
|
||||
|
||||
struct SFATilingInfo {
|
||||
const char *opName = nullptr;
|
||||
fe::PlatFormInfos *platformInfo = nullptr;
|
||||
SFAParaInfo opParamInfo;
|
||||
|
||||
// Base Param
|
||||
platform_ascendc::SocVersion socVersion = platform_ascendc::SocVersion::ASCEND910B;
|
||||
uint32_t bSize = 0;
|
||||
uint32_t n1Size = 0;
|
||||
uint32_t n2Size = 0;
|
||||
uint32_t s1Size = 0;
|
||||
int64_t s2Size = 0;
|
||||
uint32_t qkHeadDim = 0;
|
||||
uint32_t vHeadDim = 0;
|
||||
uint32_t gSize = 0;
|
||||
uint32_t ropeHeadDim = 0;
|
||||
uint32_t qTSize = 0;
|
||||
uint32_t kvTSize = 0;
|
||||
float scaleValue = 0;
|
||||
uint32_t innerPrecise = 0;
|
||||
uint32_t l2CacheOffFlag = 0;
|
||||
int64_t sparseBlockSize = 0;
|
||||
int64_t sparseBlockCount = 0;
|
||||
|
||||
bool pageAttentionFlag = false;
|
||||
int64_t blockSize = 0;
|
||||
uint32_t blockTypeSize = 0;
|
||||
uint32_t maxBlockNumPerBatch = 0;
|
||||
uint32_t totalBlockNum = 0;
|
||||
|
||||
uint32_t actualLenDimsQ = 0;
|
||||
uint32_t maxActualseq = 0;
|
||||
|
||||
bool actualSeqLenFlag = false;
|
||||
bool isSameSeqAllKVTensor = true;
|
||||
bool isSameActualseq = true;
|
||||
uint32_t actualLenDimsKV = 0;
|
||||
std::vector<int64_t> kvListSeqLens {};
|
||||
|
||||
uint32_t sparseMode = 0;
|
||||
|
||||
ge::DataType inputQType = ge::DT_FLOAT16;
|
||||
ge::DataType inputKvType = ge::DT_FLOAT16;
|
||||
ge::DataType outputType = ge::DT_FLOAT16;
|
||||
|
||||
KvStorageMode kvStorageMode = KvStorageMode::BATCH_CONTINUOUS;
|
||||
|
||||
SFALayout qLayout = SFALayout::BSND;
|
||||
SFALayout topkLayout = SFALayout::BSND;
|
||||
SFALayout outLayout = SFALayout::BSND;
|
||||
SFALayout kvLayout = SFALayout::BSND;
|
||||
|
||||
ge::DataType inputQRopeType = ge::DT_FLOAT16;
|
||||
ge::DataType inputKRopeType = ge::DT_FLOAT16;
|
||||
|
||||
uint64_t l2CacheSize = 0;
|
||||
};
|
||||
|
||||
class SFAMlaTiling {
|
||||
public:
|
||||
explicit SFAMlaTiling(gert::TilingContext *context) : context_(context) {}
|
||||
ge::graphStatus DoOpTiling(SFATilingInfo *sfaInfo);
|
||||
|
||||
private:
|
||||
ge::graphStatus SetBlockDim(uint32_t blockDim);
|
||||
ge::graphStatus SetTilingKey(uint64_t tilingKey);
|
||||
ge::graphStatus SetWorkspaceSize(uint64_t workspaceSize);
|
||||
ge::graphStatus SetTilingData(TilingDef &tilingData);
|
||||
gert::TilingContext *context_ = nullptr;
|
||||
ge::graphStatus GetPlatformInfo();
|
||||
void GenTilingKey();
|
||||
bool DealSameSeqEachBatch();
|
||||
|
||||
void ZeroTensorProcess();
|
||||
void InitParams();
|
||||
|
||||
void Split();
|
||||
bool IsBalanceSplitCore();
|
||||
|
||||
void SplitBalanced();
|
||||
void CalcInnerSize(uint32_t s2Size);
|
||||
|
||||
bool IsFlashDecode(uint32_t coreNum);
|
||||
|
||||
void FillTilingBaseParamsMla();
|
||||
void FillTilingSplitKVMla();
|
||||
|
||||
void FillTilingSingleCoreParamsMla();
|
||||
void FillTilingSingleCoreTensorSizeMla();
|
||||
void FillTiling();
|
||||
|
||||
void CalcUbBmm();
|
||||
void CheckUbSpace();
|
||||
void NormalCalcFDWorkSpace(const uint32_t actCoreNum);
|
||||
void CalcFDWorkSpace(const uint32_t actCoreNum);
|
||||
void GetWorkspaceSize();
|
||||
|
||||
uint32_t CalcBalanceFDParamNums(const uint32_t actCoreNum);
|
||||
|
||||
void CalcBlockDim();
|
||||
|
||||
bool balanceModeFlag_ = false;
|
||||
bool splitKVFlag_ = false;
|
||||
|
||||
uint32_t coreNum_ = 0;
|
||||
SFAPerfMode perfMode_ = SFAPerfMode::V_TEMPLATE_MODE;
|
||||
uint32_t kvSplitPart_ = 1;
|
||||
size_t mmResUbSize_ = 0;
|
||||
size_t bmm2ResUbSize_ = 0;
|
||||
size_t qPreSizeMla_= 0;
|
||||
uint32_t sInnerLoopTimes_ = 0;
|
||||
uint32_t sInnerSize_ = 0;
|
||||
uint32_t sInnerSizeTail_ = 0;
|
||||
uint32_t sInnerSizeAlign_ = 0;
|
||||
uint32_t kvSplit_ = 0;
|
||||
uint32_t usedCoreNum_ = 0;
|
||||
uint32_t formerCoreNum_ = 0;
|
||||
uint32_t blockSplitBn2Range_ = 0;
|
||||
uint32_t tailSplitedBatchRange_ = 0;
|
||||
|
||||
uint32_t aicNum_ = 0;
|
||||
uint32_t aivNum_ = 0;
|
||||
size_t libapiSize_ = 0;
|
||||
|
||||
SparseFlashAttentionTilingDataMla tilingData_;
|
||||
uint32_t blockDim_{0};
|
||||
uint64_t workspaceSize_{0};
|
||||
uint64_t tilingKey_{0};
|
||||
|
||||
uint32_t headDimAlign_ = 0;
|
||||
uint32_t mBaseSize_ = 128;
|
||||
uint32_t mFdBaseSize_ = 8;
|
||||
|
||||
SFATilingInfo *sfaInfo_ = nullptr;
|
||||
};
|
||||
|
||||
class SFATilingCheck {
|
||||
public:
|
||||
explicit SFATilingCheck(const SFATilingInfo &sfaInfo) : sfaInfo_(sfaInfo) {};
|
||||
~SFATilingCheck() = default;
|
||||
virtual ge::graphStatus Process();
|
||||
private:
|
||||
void Init();
|
||||
void LogErrorDtypeSupport(const std::vector<ge::DataType> &expectDtypeList,
|
||||
const ge::DataType &actualDtype, const std::string &name) const;
|
||||
ge::graphStatus CheckDtypeSupport(const gert::CompileTimeTensorDesc *desc,
|
||||
const std::string &name) const;
|
||||
template <typename T> void LogErrorNumberSupport(const std::vector<T> &expectNumberList,
|
||||
const T &actualValue, const std::string &name, const std::string subName) const;
|
||||
template <typename T> void LogErrorDimNumSupport(const std::vector<T> &expectNumberList,
|
||||
const T &actualValue, const std::string &name) const;
|
||||
ge::graphStatus CheckDimNumSupport(const gert::StorageShape *shape,
|
||||
const std::vector<size_t> &expectDimNumList, const std::string &name) const;
|
||||
ge::graphStatus CheckDimNumInLayoutSupport(const SFALayout &layout,
|
||||
const gert::StorageShape *shape, const std::string &name) const;
|
||||
void LogErrorLayoutSupport(const std::vector<SFALayout> &expectLayoutList,
|
||||
const SFALayout &actualLayout, const std::string &name) const;
|
||||
ge::graphStatus GetExpectedShape(gert::Shape &shapeExpected,
|
||||
const SFATilingShapeCompareParam ¶m, const SFALayout &layout) const;
|
||||
ge::graphStatus CompareShape(SFATilingShapeCompareParam ¶m,
|
||||
const gert::Shape &shape, const SFALayout &layout, const std::string &name) const;
|
||||
ge::graphStatus CheckLayoutSupport(const SFALayout &actualLayout, const std::string &name) const;
|
||||
ge::graphStatus CheckSingleParaQuery() const;
|
||||
ge::graphStatus CheckSingleParaKey() const;
|
||||
ge::graphStatus CheckSingleParaValue() const;
|
||||
ge::graphStatus CheckSingleParaQueryRope() const;
|
||||
ge::graphStatus CheckSingleParaKeyRope() const;
|
||||
ge::graphStatus CheckSingleParaAttenOut() const;
|
||||
ge::graphStatus CheckSingleParaNumHeads() const;
|
||||
ge::graphStatus CheckSingleParaKvHeadNums() const;
|
||||
ge::graphStatus CheckSingleParaLayout() const;
|
||||
ge::graphStatus CheckSingleParaSparseMode() const;
|
||||
ge::graphStatus CheckSingleParaSparseBlockSize() const;
|
||||
ge::graphStatus CheckSingleParaSparseIndices() const;
|
||||
ge::graphStatus CheckSinglePara() const;
|
||||
ge::graphStatus CheckMultiParaConsistency() const;
|
||||
ge::graphStatus CheckRopeExistence();
|
||||
ge::graphStatus CheckExists(const void *pointer, const std::string &name) const;
|
||||
ge::graphStatus CheckNotExists(const void *pointer, const std::string &name) const;
|
||||
ge::graphStatus CheckExistsByMap(const std::map<std::string, const void *> ¶mMap) const;
|
||||
ge::graphStatus CheckNotExistsByMap(const std::map<std::string, const void *> ¶mMap) const;
|
||||
ge::graphStatus CheckExistenceByMap(std::map<std::string, const void *> &existMap,
|
||||
std::map<std::string, const void *> ¬ExistMap) const;
|
||||
template <typename T> ge::graphStatus CheckAttrValueByMap(
|
||||
std::map<std::string, std::pair<const T *, T>> &attrMap) const;
|
||||
ge::graphStatus CheckParaExistenceMlaNoquant() const;
|
||||
ge::graphStatus CheckParaExistenceGqaNoquant() const;
|
||||
ge::graphStatus CheckParaExistenceMla() const;
|
||||
ge::graphStatus CheckParaExistence();
|
||||
ge::graphStatus GetActualSeqLenSize(uint32_t &size, const gert::Tensor *tensor,
|
||||
const SFALayout &layout, const std::string &name);
|
||||
void SetSFAShapeCompare();
|
||||
ge::graphStatus CheckQRope();
|
||||
ge::graphStatus CheckQRopeShape();
|
||||
ge::graphStatus CheckVAndKRopeShapeForBatchContinuous();
|
||||
uint32_t GetTypeSize(ge::DataType dtype) const;
|
||||
ge::graphStatus CheckVAndKRopeShapeForPageAttention();
|
||||
ge::graphStatus CheckVAndKRopeShape();
|
||||
ge::graphStatus CheckVAndKRope();
|
||||
ge::graphStatus CheckTopK();
|
||||
ge::graphStatus CheckTopkShape();
|
||||
ge::graphStatus CheckBlockTable() const;
|
||||
ge::graphStatus CheckDTypeConsistency(const ge::DataType &actualDtype,
|
||||
const ge::DataType &expectDtype, const std::string &name) const;
|
||||
|
||||
ge::graphStatus CheckAttenOut();
|
||||
ge::graphStatus CheckAttenOutShape();
|
||||
ge::graphStatus CheckActualSeqLensQ();
|
||||
ge::graphStatus CheckActualSeqLensQShape();
|
||||
ge::graphStatus CheckActualSeqLensQDType();
|
||||
ge::graphStatus CheckActualSeqLens();
|
||||
ge::graphStatus CheckActualSeqLensDType();
|
||||
ge::graphStatus CheckActualSeqLensShape();
|
||||
ge::graphStatus CheckMultiParaConsistency();
|
||||
|
||||
ge::graphStatus CheckFeatureMlaNoQuantShape() const;
|
||||
ge::graphStatus CheckFeatureMlaNoQuantLayout() const;
|
||||
ge::graphStatus CheckFeatureMlaNoQuantDtype() const;
|
||||
ge::graphStatus CheckFeatureMlaNoquantPa() const;
|
||||
ge::graphStatus CheckFeatureMlaNoquant() const;
|
||||
ge::graphStatus CheckFeatureMla() const;
|
||||
ge::graphStatus CheckFeature() const;
|
||||
|
||||
private:
|
||||
const char *opName_;
|
||||
fe::PlatFormInfos *platformInfo_;
|
||||
SFAParaInfo opParamInfo_;
|
||||
const SFATilingInfo &sfaInfo_;
|
||||
|
||||
uint32_t bSize_ = 0;
|
||||
uint32_t n1Size_ = 0;
|
||||
uint32_t n2Size_ = 0;
|
||||
uint32_t gSize_ = 0;
|
||||
uint32_t s1Size_ = 0;
|
||||
int64_t s2Size_ = 0;
|
||||
uint32_t qkHeadDim_ = 0;
|
||||
uint32_t vHeadDim_ = 0;
|
||||
uint32_t ropeHeadDim_ = 0;
|
||||
uint32_t qTSize_ = 0;
|
||||
uint32_t kvTSize_ = 0;
|
||||
KvStorageMode kvStorageMode_ = KvStorageMode::BATCH_CONTINUOUS;
|
||||
uint32_t sparseBlockCount_ = 0;
|
||||
int64_t sparseBlockSize_ = 0;
|
||||
|
||||
SFALayout qLayout_ = SFALayout::BSND;
|
||||
SFALayout topkLayout_ = SFALayout::BSND;
|
||||
SFALayout outLayout_ = SFALayout::BSND;
|
||||
SFALayout kvLayout_ = SFALayout::BSND;
|
||||
|
||||
uint32_t maxBlockNumPerBatch_ = 0;
|
||||
int64_t blockSize_ = 0;
|
||||
|
||||
uint32_t aicNum_ = 0;
|
||||
uint32_t aivNum_ = 0;
|
||||
platform_ascendc::SocVersion socVersion_ = platform_ascendc::SocVersion::ASCEND910B;
|
||||
uint64_t l2CacheSize_ = 0;
|
||||
|
||||
ge::DataType inputQType_ = ge::DT_FLOAT16;
|
||||
ge::DataType inputKvType_ = ge::DT_FLOAT16;
|
||||
ge::DataType outputType_ = ge::DT_FLOAT16;
|
||||
ge::DataType inputQRopeType_ = ge::DT_FLOAT16;
|
||||
ge::DataType inputKRopeType_ = ge::DT_FLOAT16;
|
||||
|
||||
gert::Shape queryShapeCmp_{};
|
||||
gert::Shape keyShapeCmp_{};
|
||||
gert::Shape valueShapeCmp_{};
|
||||
gert::Shape topkShapeCmp_{};
|
||||
gert::Shape queryRopeShapeCmp_{};
|
||||
gert::Shape keyRopeShapeCmp_{};
|
||||
gert::Shape attenOutShapeCmp_{};
|
||||
};
|
||||
|
||||
class SFAInfoParser {
|
||||
public:
|
||||
explicit SFAInfoParser(const gert::TilingContext *context) : context_(context) {}
|
||||
~SFAInfoParser() = default;
|
||||
|
||||
ge::graphStatus CheckRequiredInOutExistence() const;
|
||||
ge::graphStatus CheckRequiredAttrExistence() const;
|
||||
ge::graphStatus CheckRequiredParaExistence() const;
|
||||
|
||||
ge::graphStatus GetActualSeqLenSize(uint32_t &size, const gert::Tensor *tensor,
|
||||
SFALayout &layout, const std::string &name);
|
||||
ge::graphStatus GetActualSeqLenQSize(uint32_t &size);
|
||||
ge::graphStatus GetOpName();
|
||||
ge::graphStatus GetNpuInfo();
|
||||
void GetOptionalInputParaInfo();
|
||||
void GetInputParaInfo();
|
||||
void GetOutputParaInfo();
|
||||
ge::graphStatus GetAttrParaInfo();
|
||||
ge::graphStatus GetKvCache();
|
||||
ge::graphStatus GetOpParaInfo();
|
||||
|
||||
ge::graphStatus GetInOutDataType();
|
||||
ge::graphStatus GetBatchSize();
|
||||
ge::graphStatus GetQTSize();
|
||||
ge::graphStatus GetKVTSize();
|
||||
ge::graphStatus GetQkHeadDim();
|
||||
ge::graphStatus GetS1Size();
|
||||
ge::graphStatus GetKvStorageMode();
|
||||
ge::graphStatus GetKvLayout();
|
||||
void SetSFAShape();
|
||||
ge::graphStatus GetS2SizeForBatchContinuous();
|
||||
ge::graphStatus GetMaxBlockNumPerBatch();
|
||||
ge::graphStatus GetBlockSize();
|
||||
ge::graphStatus GetS2SizeForPageAttention();
|
||||
ge::graphStatus GetS2Size();
|
||||
ge::graphStatus GetValueHeadDim();
|
||||
ge::graphStatus GetRopeHeadDim();
|
||||
ge::graphStatus GetQueryAndOutLayout();
|
||||
ge::graphStatus GetTopkLayout();
|
||||
ge::graphStatus GetN1Size();
|
||||
ge::graphStatus GetN2Size();
|
||||
ge::graphStatus GetGSize();
|
||||
ge::graphStatus GetSparseBlockCount();
|
||||
ge::graphStatus GetActualseqInfo();
|
||||
void GenerateInfo(SFATilingInfo &sfaInfo);
|
||||
ge::graphStatus Parse(SFATilingInfo &sfaInfo);
|
||||
|
||||
public:
|
||||
bool HasAxis(const SFAAxis &axis, const SFALayout &layout, const gert::Shape &shape) const;
|
||||
size_t GetAxisIdx(const SFAAxis &axis, const SFALayout &layout) const;
|
||||
uint32_t GetAxisNum(const gert::Shape &shape, const SFAAxis &axis,const SFALayout &layout) const;
|
||||
|
||||
const gert::TilingContext *context_ = nullptr;
|
||||
|
||||
const char *opName_;
|
||||
fe::PlatFormInfos *platformInfo_;
|
||||
SFAParaInfo opParamInfo_;
|
||||
static constexpr int64_t invalidDimValue_ = std::numeric_limits<int64_t>::min();
|
||||
|
||||
uint32_t bSize_ = 0;
|
||||
uint32_t n1Size_ = 0;
|
||||
uint32_t n2Size_ = 0;
|
||||
uint32_t gSize_ = 0;
|
||||
uint32_t s1Size_ = 0;
|
||||
int64_t s2Size_ = 0;
|
||||
uint32_t qkHeadDim_ = 0;
|
||||
uint32_t vHeadDim_ = 0;
|
||||
uint32_t ropeHeadDim_ = 0;
|
||||
uint32_t qTSize_ = 0;
|
||||
uint32_t kvTSize_ = 0;
|
||||
KvStorageMode kvStorageMode_ = KvStorageMode::BATCH_CONTINUOUS;
|
||||
uint32_t sparseBlockCount_ = 0;
|
||||
|
||||
SFALayout qLayout_ = SFALayout::BSND;
|
||||
SFALayout topkLayout_ = SFALayout::BSND;
|
||||
SFALayout outLayout_ = SFALayout::BSND;
|
||||
SFALayout kvLayout_ = SFALayout::BSND;
|
||||
|
||||
uint32_t maxBlockNumPerBatch_ = 0;
|
||||
uint32_t blockSize_ = 0;
|
||||
|
||||
platform_ascendc::SocVersion socVersion_ = platform_ascendc::SocVersion::ASCEND910B;
|
||||
|
||||
ge::DataType inputQType_ = ge::DT_FLOAT16;
|
||||
ge::DataType inputKvType_ = ge::DT_FLOAT16;
|
||||
ge::DataType outputType_ = ge::DT_FLOAT16;
|
||||
ge::DataType inputQRopeType_ = ge::DT_FLOAT16;
|
||||
ge::DataType inputKRopeType_ = ge::DT_FLOAT16;
|
||||
|
||||
uint64_t l2CacheSize_ = 0;
|
||||
|
||||
bool isSameSeqAllKVTensor_ = true;
|
||||
bool isSameActualseq_ = true;
|
||||
uint32_t maxActualseq_ = 0;
|
||||
|
||||
uint32_t actualLenDimsQ_ = 0;
|
||||
uint32_t actualLenDimsKV_ = 0;
|
||||
|
||||
gert::Shape queryShape_{};
|
||||
gert::Shape keyShape_{};
|
||||
gert::Shape valueShape_{};
|
||||
gert::Shape sparseIndicesShape_{};
|
||||
gert::Shape queryRopeShape_{};
|
||||
gert::Shape keyRopeShape_{};
|
||||
};
|
||||
} // namespace optiling
|
||||
#endif // SPARSE_FLASH_ATTENTION_TILING_H
|
||||
Reference in New Issue
Block a user