[Kernel] Add AscendC fused op transpose_kv_cache_by_block to speed up GQA transfer (#6366)
### What this PR does / why we need it?
As #2947 describe, we need to transpose kv cache layout after GQA kv
transfer when prefill and decode tensor parallel size are heterogeneous,
in the previous implementation, we use `npu_paged_cache_load ` +
`tranpose` + `_npu_reshape_and_cache` to do this work.
But obviously, it is not an efficient plan, the ops above need to be
called for each layer, which introduces 3 * layer_num kernel launch, and
6 * layer_num data movement between L1 Cache and HBM for one request on
decode node. Usually, decode node uses graph mode, so these op kernels
will be called between decode forward launched by an async thread in
mooncacke connector, this kernels maybe last for several decode forward
and TTFT will increase by 3~4 decode forward time.
In this PR, we implement an AscendC fused op
`transpose_kv_cache_by_block` to do this with only once kernel launch
and move data between L1 Cache and HBM only once.
After using this fused op, the time cost in transpose kv cacke layout
can be decreased to 0.24ms from 7ms in UT on 910C, and in PD
disaggregation scenario, TTFT can decrease about 90 ~ 110 ms in
qwen3-235B.
| request_num | original | fused_op|
|:----------------------:|:---------------:|:-------------------:|
| 1 | 643 ms | 578 ms |
| 128 | 1480 ms | 1368 ms |
### Does this PR introduce _any_ user-facing change?
Use fused op by default, incase the op has bug in any scenario, provide
fallback choice using env to disable it.
**DISABLE fused op by add following env**
`export VLLM_ASCEND_FUSION_OP_TRANSPOSE_KV_CACHE_BY_BLOCK=0`
### How was this patch tested?
- vLLM version: v0.14.1
- vLLM main:
dc917cceb8
---------
Signed-off-by: lidenghui <lidenghui1110@gmail.com>
This commit is contained in:
40
csrc/transpose_kv_cache_by_block/op_host/CMakeLists.txt
Normal file
40
csrc/transpose_kv_cache_by_block/op_host/CMakeLists.txt
Normal file
@@ -0,0 +1,40 @@
|
||||
# This program is free software, you can redistribute it and/or modify it.
|
||||
# Copyright (c) 2025 Huawei Technologies Co., Ltd.
|
||||
# This file is a part of the CANN Open Software.
|
||||
# Licensed under CANN Open Software License Agreement Version 2.0 (the "License").
|
||||
# Please refer to the License for details. You may not use this file except in compliance with the License.
|
||||
# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE.
|
||||
# See LICENSE in the root of the software repository for the full text of the License.
|
||||
# ======================================================================================================================
|
||||
|
||||
add_ops_compile_options(
|
||||
OP_NAME TransposeKvCacheByBlock
|
||||
OPTIONS --cce-auto-sync=off
|
||||
-Wno-deprecated-declarations
|
||||
-Werror
|
||||
-mllvm -cce-aicore-hoist-movemask=false
|
||||
--op_relocatable_kernel_binary=true
|
||||
)
|
||||
|
||||
target_sources(op_host_aclnn PRIVATE
|
||||
transpose_kv_cache_by_block_def.cpp
|
||||
)
|
||||
|
||||
target_sources(optiling PRIVATE
|
||||
transpose_kv_cache_by_block_tiling.cpp
|
||||
)
|
||||
|
||||
if (NOT BUILD_OPEN_PROJECT)
|
||||
target_sources(opmaster_ct PRIVATE
|
||||
transpose_kv_cache_by_block_tiling.cpp
|
||||
)
|
||||
endif ()
|
||||
|
||||
target_include_directories(optiling PRIVATE
|
||||
${CMAKE_CURRENT_SOURCE_DIR}
|
||||
)
|
||||
|
||||
target_sources(opsproto PRIVATE
|
||||
transpose_kv_cache_by_block_proto.cpp
|
||||
)
|
||||
|
||||
@@ -0,0 +1,36 @@
|
||||
#include "register/op_def_registry.h"
|
||||
|
||||
namespace ops {
|
||||
class TransposeKvCacheByBlock : public OpDef {
|
||||
public:
|
||||
explicit TransposeKvCacheByBlock(const char* name) : OpDef(name)
|
||||
{
|
||||
this->Input("KCache")
|
||||
.ParamType(DYNAMIC)
|
||||
.DataType({ge::DT_FLOAT16, ge::DT_BF16})
|
||||
.Format({ge::FORMAT_ND, ge::FORMAT_ND})
|
||||
.UnknownShapeFormat({ge::FORMAT_ND, ge::FORMAT_ND});
|
||||
this->Input("VCache")
|
||||
.ParamType(DYNAMIC)
|
||||
.DataType({ge::DT_FLOAT16, ge::DT_BF16})
|
||||
.Format({ge::FORMAT_ND, ge::FORMAT_ND})
|
||||
.UnknownShapeFormat({ge::FORMAT_ND, ge::FORMAT_ND});
|
||||
this->Input("blockIDs")
|
||||
.ParamType(REQUIRED)
|
||||
.DataTypeList({ge::DT_INT64})
|
||||
.FormatList({ge::FORMAT_ND})
|
||||
.UnknownShapeFormat({ge::FORMAT_ND});
|
||||
this->Attr("blockSize").Int();
|
||||
this->Attr("headNum").Int();
|
||||
this->Attr("headDim").Int();
|
||||
this->Attr("splitNum").Int();
|
||||
this->Attr("layerNum").Int();
|
||||
|
||||
this->AICore().AddConfig("ascend910b");
|
||||
this->AICore().AddConfig("ascend910_93");
|
||||
|
||||
}
|
||||
};
|
||||
|
||||
OP_ADD(TransposeKvCacheByBlock);
|
||||
}
|
||||
@@ -0,0 +1,36 @@
|
||||
/**
|
||||
* This program is free software, you can redistribute it and/or modify it.
|
||||
* Copyright (c) 2025 Huawei Technologies Co., Ltd.
|
||||
* This file is a part of the CANN Open Software.
|
||||
* Licensed under CANN Open Software License Agreement Version 2.0 (the "License").
|
||||
* Please refer to the License for details. You may not use this file except in compliance with the License.
|
||||
* THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE.
|
||||
* See LICENSE in the root of the software repository for the full text of the License.
|
||||
*/
|
||||
|
||||
/*!
|
||||
* \file transpose_kv_cache_by_block_proto.cpp
|
||||
* \brief
|
||||
*/
|
||||
#include <graph/utils/type_utils.h>
|
||||
#include <register/op_impl_registry.h>
|
||||
#include "error/ops_error.h"
|
||||
|
||||
using namespace ge;
|
||||
|
||||
namespace ops {
|
||||
|
||||
static ge::graphStatus InferShapeTransposeKvCacheByBlock(gert::InferShapeContext* context)
|
||||
{
|
||||
return ge::GRAPH_SUCCESS;
|
||||
}
|
||||
|
||||
static ge::graphStatus InferDataTypeTransposeKvCacheByBlock(gert::InferDataTypeContext *context)
|
||||
{
|
||||
return ge::GRAPH_SUCCESS;
|
||||
}
|
||||
|
||||
IMPL_OP_INFERSHAPE(TransposeKvCacheByBlock)
|
||||
.InferShape(InferShapeTransposeKvCacheByBlock)
|
||||
.InferDataType(InferDataTypeTransposeKvCacheByBlock);
|
||||
} // namespace ops
|
||||
@@ -0,0 +1,182 @@
|
||||
#include "transpose_kv_cache_by_block_tiling.h"
|
||||
#include "register/op_def_registry.h"
|
||||
#include "tiling/platform/platform_ascendc.h"
|
||||
#include "log/ops_log.h"
|
||||
#include <algorithm>
|
||||
|
||||
namespace optiling {
|
||||
|
||||
constexpr uint64_t DATA_SIZE = 2;
|
||||
constexpr uint64_t BLOCK_SIZE = 32;
|
||||
constexpr uint64_t DB_ON = 2;
|
||||
|
||||
constexpr uint32_t FULL_LOAD = 0;
|
||||
constexpr uint32_t SPLIT_BLOCK_SIZE_ALIGNED_AND_DB = 1;
|
||||
constexpr uint32_t SPLIT_BLOCK_SIZE_UNALIGNED_AND_DB = 3;
|
||||
constexpr uint32_t SPLIT_BLOCK_SIZE_ALIGNED_AND_NOT_DB = 2;
|
||||
constexpr uint32_t SPLIT_BLOCK_SIZE_UNALIGNED_AND_NOT_DB = 4;
|
||||
|
||||
void findFactorsOptimized(std::vector<int64_t> &factors, int64_t n) {
|
||||
|
||||
for (int64_t i = 1; i * i <= n; i++) {
|
||||
if (n % i == 0) {
|
||||
factors.push_back(i);
|
||||
|
||||
if (i != n / i) {
|
||||
factors.push_back(n / i);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
sort(factors.begin(), factors.end());
|
||||
}
|
||||
|
||||
ge::graphStatus CalTiling(gert::TilingContext* context, TransposeKvCacheByBlockTilingData &tiling)
|
||||
{
|
||||
fe::PlatFormInfos* platformInfoPtr = context->GetPlatformInfo();
|
||||
OPS_LOG_E_IF_NULL(context, platformInfoPtr, return ge::GRAPH_FAILED);
|
||||
auto ascendcPlatform = platform_ascendc::PlatformAscendC(platformInfoPtr);
|
||||
int64_t useCoreNum = ascendcPlatform.GetCoreNumAiv();
|
||||
uint64_t ubSize;
|
||||
ascendcPlatform.GetCoreMemSize(platform_ascendc::CoreMemType::UB, ubSize);
|
||||
|
||||
auto attr = context->GetAttrs();
|
||||
OPS_LOG_E_IF_NULL(context, attr, return ge::GRAPH_FAILED);
|
||||
const int64_t* blockSizePtr = attr->GetAttrPointer<int64_t>(0);
|
||||
const int64_t* headNumPtr = attr->GetAttrPointer<int64_t>(1);
|
||||
const int64_t* headDimPtr = attr->GetAttrPointer<int64_t>(2);
|
||||
const int64_t* splitNumPtr = attr->GetAttrPointer<int64_t>(3);
|
||||
const int64_t* layerNumPtr = attr->GetAttrPointer<int64_t>(4);
|
||||
OPS_CHECK(blockSizePtr == nullptr || headNumPtr == nullptr || headDimPtr == nullptr ||
|
||||
splitNumPtr == nullptr || layerNumPtr == nullptr,
|
||||
OPS_LOG_E(context->GetNodeName(), "Get attr failed."),
|
||||
return ge::GRAPH_FAILED);
|
||||
|
||||
auto blockIDsTensor = context->GetDynamicInputTensor(2, 0);
|
||||
OPS_LOG_E_IF_NULL(context, blockIDsTensor, return ge::GRAPH_FAILED);
|
||||
|
||||
gert::Shape blockIDsTensorShape = blockIDsTensor->GetStorageShape();
|
||||
int64_t calBlockNum = static_cast<int64_t>(blockIDsTensorShape.GetDim(0));
|
||||
|
||||
tiling.set_calBlockNum(static_cast<uint32_t>(calBlockNum));
|
||||
|
||||
int64_t blockSize = *blockSizePtr;
|
||||
int64_t headNum = *headNumPtr;
|
||||
int64_t headDim = *headDimPtr;
|
||||
int64_t splitNum = *splitNumPtr;
|
||||
int64_t layerNum = *layerNumPtr;
|
||||
uint32_t tilingKey = FULL_LOAD;
|
||||
|
||||
if (headDim * DATA_SIZE % BLOCK_SIZE != 0) {
|
||||
OPS_LOG_E(context, "headDim * DATA_SIZE must be a multiple of 32 bytes.");
|
||||
return ge::GRAPH_FAILED;
|
||||
}
|
||||
|
||||
std::vector<int64_t> factors;
|
||||
findFactorsOptimized(factors, useCoreNum);
|
||||
|
||||
uint32_t factorIndex = 0;
|
||||
bool findSplitNum = true;
|
||||
int64_t blockSizeSplitNum = factors[factorIndex];
|
||||
uint64_t dataSizeloadOnce = blockSize * headNum * headDim * DATA_SIZE;
|
||||
// if can full load, not split blockSize and db
|
||||
if (dataSizeloadOnce > ubSize) {
|
||||
tilingKey = SPLIT_BLOCK_SIZE_ALIGNED_AND_DB;
|
||||
// split blockSize and db
|
||||
while (dataSizeloadOnce > (ubSize / DB_ON)) {
|
||||
factorIndex += 1;
|
||||
if (factorIndex == factors.size()) {
|
||||
tilingKey = FULL_LOAD;
|
||||
findSplitNum = false;
|
||||
break;
|
||||
}
|
||||
blockSizeSplitNum = factors[factorIndex];
|
||||
dataSizeloadOnce = ((blockSize + blockSizeSplitNum - 1) / blockSizeSplitNum) * headNum * headDim * DATA_SIZE;
|
||||
}
|
||||
if (tilingKey == SPLIT_BLOCK_SIZE_ALIGNED_AND_DB && (blockSize % blockSizeSplitNum != 0)) {
|
||||
tilingKey = SPLIT_BLOCK_SIZE_UNALIGNED_AND_DB;
|
||||
}
|
||||
}
|
||||
|
||||
if (!findSplitNum) {
|
||||
tilingKey = SPLIT_BLOCK_SIZE_ALIGNED_AND_NOT_DB;
|
||||
// split blockSize but not db
|
||||
findSplitNum = true;
|
||||
factorIndex = 0;
|
||||
blockSizeSplitNum = factors[factorIndex];
|
||||
dataSizeloadOnce = blockSize * headNum * headDim * DATA_SIZE;
|
||||
while (dataSizeloadOnce > ubSize) {
|
||||
factorIndex += 1;
|
||||
if (factorIndex == factors.size()) {
|
||||
tilingKey = FULL_LOAD;
|
||||
findSplitNum = false;
|
||||
break;
|
||||
}
|
||||
blockSizeSplitNum = factors[factorIndex];
|
||||
dataSizeloadOnce = ((blockSize + blockSizeSplitNum - 1) / blockSizeSplitNum) * headNum * headDim * DATA_SIZE;
|
||||
}
|
||||
if (tilingKey == SPLIT_BLOCK_SIZE_ALIGNED_AND_NOT_DB && (blockSize % blockSizeSplitNum != 0)) {
|
||||
tilingKey = SPLIT_BLOCK_SIZE_UNALIGNED_AND_NOT_DB;
|
||||
}
|
||||
}
|
||||
|
||||
// headNum * headDim too large
|
||||
if (!findSplitNum) {
|
||||
OPS_LOG_E(context, "headNum * headDim * sizeof(half) > ubSize "
|
||||
"or blockSize * headNum * headDim * sizeof(half) > ubSize * vectorCoreNum. "
|
||||
"Currently, splitting headNum or headDim is not supported.");
|
||||
return ge::GRAPH_FAILED;
|
||||
}
|
||||
tiling.set_blockSizePerTime(static_cast<uint32_t>((blockSize + blockSizeSplitNum - 1) / blockSizeSplitNum));
|
||||
tiling.set_blockSizePerTimeTail(static_cast<uint32_t>(blockSize % blockSizeSplitNum));
|
||||
tiling.set_blockSizeSplitNum(static_cast<uint32_t>(blockSizeSplitNum));
|
||||
|
||||
tiling.set_blockSize(static_cast<uint32_t>(blockSize));
|
||||
tiling.set_headNum(static_cast<uint32_t>(headNum));
|
||||
tiling.set_headDim(static_cast<uint32_t>(headDim));
|
||||
tiling.set_splitNum(static_cast<uint32_t>(splitNum));
|
||||
tiling.set_layerNum(static_cast<uint32_t>(layerNum));
|
||||
|
||||
int64_t totalRound = layerNum * calBlockNum;
|
||||
|
||||
if ((totalRound * blockSizeSplitNum) < useCoreNum) {
|
||||
useCoreNum = totalRound * blockSizeSplitNum;
|
||||
}
|
||||
int64_t blockPerCore = totalRound / (useCoreNum / blockSizeSplitNum);
|
||||
int64_t tailCoreNum = totalRound % (useCoreNum / blockSizeSplitNum);
|
||||
|
||||
tiling.set_useCoreNum(static_cast<uint32_t>(useCoreNum));
|
||||
tiling.set_blockPerCore(static_cast<uint32_t>(blockPerCore));
|
||||
tiling.set_tailCoreNum(static_cast<uint32_t>(tailCoreNum));
|
||||
context->SetBlockDim(useCoreNum);
|
||||
context->SetTilingKey(tilingKey);
|
||||
|
||||
return ge::GRAPH_SUCCESS;
|
||||
}
|
||||
|
||||
|
||||
static ge::graphStatus TransposeKvCacheByBlockTilingFunc(gert::TilingContext* context)
|
||||
{
|
||||
|
||||
TransposeKvCacheByBlockTilingData tiling;
|
||||
auto status = CalTiling(context, tiling);
|
||||
OP_CHECK(status != ge::GRAPH_SUCCESS, OPS_LOG_E(context->GetNodeName(), "Cal tiling failed."),
|
||||
return ge::GRAPH_FAILED);
|
||||
|
||||
tiling.SaveToBuffer(context->GetRawTilingData()->GetData(), context->GetRawTilingData()->GetCapacity());
|
||||
context->GetRawTilingData()->SetDataSize(tiling.GetDataSize());
|
||||
|
||||
return ge::GRAPH_SUCCESS;
|
||||
}
|
||||
|
||||
struct TransposeKvCacheByBlockCompileInfo {};
|
||||
ge::graphStatus TilingParseForTransposeKvCacheByBlock(gert::TilingParseContext *context)
|
||||
{
|
||||
(void)context;
|
||||
return ge::GRAPH_SUCCESS;
|
||||
}
|
||||
|
||||
IMPL_OP_OPTILING(TransposeKvCacheByBlock)
|
||||
.Tiling(TransposeKvCacheByBlockTilingFunc)
|
||||
.TilingParse<TransposeKvCacheByBlockCompileInfo>(TilingParseForTransposeKvCacheByBlock);
|
||||
}
|
||||
@@ -0,0 +1,23 @@
|
||||
#include "register/tilingdata_base.h"
|
||||
|
||||
namespace optiling {
|
||||
BEGIN_TILING_DATA_DEF(TransposeKvCacheByBlockTilingData)
|
||||
// shape info
|
||||
// TILING_DATA_FIELD_DEF(uint32_t, blockNum);
|
||||
TILING_DATA_FIELD_DEF(uint32_t, blockSize);
|
||||
TILING_DATA_FIELD_DEF(uint32_t, headNum);
|
||||
TILING_DATA_FIELD_DEF(uint32_t, headDim);
|
||||
TILING_DATA_FIELD_DEF(uint32_t, splitNum);
|
||||
TILING_DATA_FIELD_DEF(uint32_t, layerNum);
|
||||
// tiling info
|
||||
TILING_DATA_FIELD_DEF(uint32_t, useCoreNum);
|
||||
TILING_DATA_FIELD_DEF(uint32_t, blockPerCore);
|
||||
TILING_DATA_FIELD_DEF(uint32_t, tailCoreNum);
|
||||
TILING_DATA_FIELD_DEF(uint32_t, calBlockNum);
|
||||
TILING_DATA_FIELD_DEF(uint32_t, blockSizePerTime);
|
||||
TILING_DATA_FIELD_DEF(uint32_t, blockSizePerTimeTail);
|
||||
TILING_DATA_FIELD_DEF(uint32_t, blockSizeSplitNum);
|
||||
END_TILING_DATA_DEF;
|
||||
|
||||
REGISTER_TILING_DATA_CLASS(TransposeKvCacheByBlock, TransposeKvCacheByBlockTilingData)
|
||||
}
|
||||
Reference in New Issue
Block a user