### What this PR does / why we need it? Provide high-performance AscendC operators lightning_indexer and sparse_flash_attention to boost the execution performance of the DeepSeek v3.2 model. Meanwhile, adapt the two AscendC operators to vllm-ascend framework. ### Does this PR introduce _any_ user-facing change? No (only underlying operator optimizations, with no user-facing changes) ### How was this patch tested? - vLLM version: v0.11.2 - vLLM main: https://github.com/vllm-project/vllm/commit/v0.11.2 Signed-off-by: MingYang119 <songmingyang@huawei.com>
97 lines
4.5 KiB
C++
97 lines
4.5 KiB
C++
/**
|
|
* This program is free software, you can redistribute it and/or modify it.
|
|
* Copyright (c) 2025 Huawei Technologies Co., Ltd.
|
|
* This file is a part of the CANN Open Software.
|
|
* Licensed under CANN Open Software License Agreement Version 2.0 (the "License").
|
|
* Please refer to the License for details. You may not use this file except in compliance with the License.
|
|
* THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE.
|
|
* See LICENSE in the root of the software repository for the full text of the License.
|
|
*/
|
|
|
|
/*!
|
|
* \file lightning_indexer_proto.cpp
|
|
* \brief
|
|
*/
|
|
#include <graph/utils/type_utils.h>
|
|
#include <register/op_impl_registry.h>
|
|
#include "error/ops_error.h"
|
|
|
|
|
|
using namespace ge;
|
|
|
|
namespace ops {
|
|
constexpr uint32_t QUERY_INDEX = 0;
|
|
constexpr uint32_t KEY_INDEX = 1;
|
|
constexpr uint32_t ACTUAL_SEQ_K_INDEX = 4;
|
|
constexpr uint32_t ATTR_QUERY_LAYOUT_INDEX = 0;
|
|
constexpr uint32_t ATTR_KEY_LAYOUT_INDEX = 1;
|
|
constexpr uint32_t ATTR_SPARSE_COUNT_INDEX = 2;
|
|
|
|
static ge::graphStatus InferShapeLightningIndexer(gert::InferShapeContext *context)
|
|
{
|
|
OPS_ERR_IF(context == nullptr, OPS_LOG_E("LightningIndexer", "InferShapeContext is nullptr!"),
|
|
return ge::GRAPH_FAILED);
|
|
const gert::Shape *queryShape = context->GetInputShape(QUERY_INDEX);
|
|
OPS_LOG_E_IF_NULL(context, queryShape, return ge::GRAPH_FAILED);
|
|
const gert::Shape *keyShape = context->GetInputShape(KEY_INDEX);
|
|
OPS_LOG_E_IF_NULL(context, keyShape, return ge::GRAPH_FAILED);
|
|
gert::Shape *outShape = context->GetOutputShape(0);
|
|
|
|
auto attrs = context->GetAttrs();
|
|
OPS_LOG_E_IF_NULL(context, attrs, return ge::GRAPH_FAILED);
|
|
const char *inputLayoutQueryPtr = attrs->GetAttrPointer<char>(ATTR_QUERY_LAYOUT_INDEX);
|
|
OPS_LOG_E_IF_NULL(context, inputLayoutQueryPtr, return ge::GRAPH_FAILED);
|
|
const char *inputLayoutKeyPtr = attrs->GetAttrPointer<char>(ATTR_KEY_LAYOUT_INDEX);
|
|
OPS_LOG_E_IF_NULL(context, inputLayoutKeyPtr, return ge::GRAPH_FAILED);
|
|
const int64_t *seleced_count = attrs->GetInt(ATTR_SPARSE_COUNT_INDEX);
|
|
OPS_LOG_E_IF_NULL(context, seleced_count, return ge::GRAPH_FAILED);
|
|
std::string inputLayoutQueryPtrStr = std::string(inputLayoutQueryPtr);
|
|
std::string inputLayoutKeyPtrStr = std::string(inputLayoutKeyPtr);
|
|
OPS_ERR_IF(
|
|
inputLayoutQueryPtrStr != "TND" && inputLayoutQueryPtrStr != "BSND",
|
|
OPS_LOG_E(context, "The attr layout_query should be TND or BSND, but got %s.", inputLayoutQueryPtrStr.c_str()),
|
|
return ge::GRAPH_FAILED);
|
|
|
|
outShape->SetDimNum(queryShape->GetDimNum());
|
|
if (inputLayoutQueryPtrStr == "BSND") {
|
|
OPS_ERR_IF(
|
|
queryShape->GetDimNum() != 4,
|
|
OPS_LOG_E(context, "Layout BSND, queryDims (%zu) must be 4!", queryShape->GetDimNum()),
|
|
return ge::GRAPH_FAILED);
|
|
outShape->SetDim(0, queryShape->GetDim(0)); // 0:Dim B
|
|
outShape->SetDim(1, queryShape->GetDim(1)); // 1:Dim S
|
|
outShape->SetDim(2, keyShape->GetDim(2)); // 2:Dim N
|
|
outShape->SetDim(3, *seleced_count); // 3:Dim K
|
|
} else {
|
|
OPS_ERR_IF(
|
|
queryShape->GetDimNum() != 3,
|
|
OPS_LOG_E(context, "Layout TND, queryDims (%zu) must be 3!", queryShape->GetDimNum()),
|
|
return ge::GRAPH_FAILED);
|
|
outShape->SetDim(0, queryShape->GetDim(0)); // 0:Dim T
|
|
int32_t nDimIndex = (inputLayoutKeyPtrStr == "PA_BSND") ? 2 : 1; // 2:Key Dim N
|
|
outShape->SetDim(1, keyShape->GetDim(nDimIndex)); // 1:Dim N
|
|
outShape->SetDim(2, *seleced_count); // 2:Dim K
|
|
}
|
|
OPS_LOG_D(context->GetNodeName(), "LightningIndexer InferShape end.");
|
|
|
|
return ge::GRAPH_SUCCESS;
|
|
}
|
|
|
|
static ge::graphStatus InferDataTypeLightningIndexer(gert::InferDataTypeContext *context)
|
|
{
|
|
OPS_ERR_IF(context == nullptr, OPS_LOG_E("LightningIndexer", "InferDataTypeContext is nullptr!"),
|
|
return ge::GRAPH_FAILED);
|
|
OPS_LOG_D(context->GetNodeName(), "Enter LightningIndexer InferDataType impl.");
|
|
// default set q's dtype as fia's output type
|
|
ge::DataType outputType = ge::DT_INT32;
|
|
// attention_out, outidx:0
|
|
context->SetOutputDataType(0, outputType);
|
|
OPS_LOG_D(context->GetNodeName(), "LightningIndexer InferDataType end.");
|
|
return GRAPH_SUCCESS;
|
|
}
|
|
|
|
IMPL_OP_INFERSHAPE(LightningIndexer)
|
|
.InferShape(InferShapeLightningIndexer)
|
|
.InferDataType(InferDataTypeLightningIndexer);
|
|
} // namespace ops
|