[Kernel] Add AscendC fused op transpose_kv_cache_by_block to speed up GQA transfer (#6366)

### What this PR does / why we need it?
As #2947 describes, we need to transpose the KV cache layout after the GQA KV
transfer when the prefill and decode tensor parallel sizes are heterogeneous.
In the previous implementation, we used `npu_paged_cache_load` +
`transpose` + `_npu_reshape_and_cache` to do this work.

But this is not an efficient approach: the ops above must be called for each
layer, which introduces 3 * layer_num kernel launches and 6 * layer_num data
movements between the L1 cache and HBM for one request on the decode node.
The decode node usually runs in graph mode, so these kernels are launched
between decode forward passes by an async thread in the mooncake connector;
they can span several decode forward passes, and TTFT increases by roughly
3~4 decode forward times.

In this PR, we implement an AscendC fused op
`transpose_kv_cache_by_block` that does the same work with a single kernel
launch and moves the data between the L1 cache and HBM only once.
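
Per block, the fused op performs the same rearrangement as the `view` + `transpose` + `view` used in the unit test's golden path further below. A rough CPU-only sketch of that semantics (names are illustrative, not part of the op's API):

```python
import torch

def transpose_block_reference(block: torch.Tensor, split_num: int) -> torch.Tensor:
    """CPU reference for one cache block of shape [block_size, head_num, head_dim].

    After a heterogeneous-TP GQA transfer the block actually holds split_num
    contiguous segments of shape [block_size, head_num // split_num, head_dim];
    this rearranges them back into the canonical [block_size, head_num, head_dim]
    layout, which is what the fused op does in place for every listed block and layer.
    """
    block_size, head_num, head_dim = block.shape
    x = block.reshape(split_num, block_size, head_num // split_num, head_dim)
    x = x.transpose(0, 1).contiguous()  # -> [block_size, split_num, head_num // split_num, head_dim]
    return x.reshape(block_size, head_num, head_dim)
```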

With this fused op, the time spent transposing the KV cache layout drops from
7 ms to 0.24 ms in the UT on 910C, and in the PD disaggregation scenario TTFT
decreases by about 90 ~ 110 ms for qwen3-235B:

| request_num | TTFT (original) | TTFT (fused op) |
|:----------------------:|:---------------:|:-------------------:|
|           1            |      643 ms      |        578 ms        |
|          128           |     1480 ms      |       1368 ms        |

### Does this PR introduce _any_ user-facing change?
The fused op is used by default. In case it has a bug in some scenario, an
environment variable is provided as a fallback to disable it.

**Disable the fused op by setting the following env:**
`export VLLM_ASCEND_FUSION_OP_TRANSPOSE_KV_CACHE_BY_BLOCK=0`

### How was this patch tested?
A UT `TestTransposeKvCacheByBlock` is added, comparing the fused op against the
original `npu_paged_cache_load` + `transpose` + `_npu_reshape_and_cache` path
(see the test file below).

- vLLM version: v0.14.1
- vLLM main: dc917cceb8

---------

Signed-off-by: lidenghui <lidenghui1110@gmail.com>


@@ -24,7 +24,7 @@ elif [[ "$SOC_VERSION" =~ ^ascend910b ]]; then
ABSOLUTE_CATLASS_PATH=$(cd "${CATLASS_PATH}" && pwd)
export CPATH=${ABSOLUTE_CATLASS_PATH}:${CPATH}
CUSTOM_OPS="grouped_matmul_swiglu_quant_weight_nz_tensor_list;lightning_indexer;sparse_flash_attention;matmul_allreduce_add_rmsnorm;moe_init_routing_custom;moe_gating_top_k;add_rms_norm_bias;apply_top_k_top_p_custom;"
CUSTOM_OPS="grouped_matmul_swiglu_quant_weight_nz_tensor_list;lightning_indexer;sparse_flash_attention;matmul_allreduce_add_rmsnorm;moe_init_routing_custom;moe_gating_top_k;add_rms_norm_bias;apply_top_k_top_p_custom;transpose_kv_cache_by_block;"
SOC_ARG="ascend910b"
elif [[ "$SOC_VERSION" =~ ^ascend910_93 ]]; then
# ASCEND910C (A3) series
@@ -81,6 +81,7 @@ elif [[ "$SOC_VERSION" =~ ^ascend910_93 ]]; then
"moe_gating_top_k"
"add_rms_norm_bias"
"apply_top_k_top_p_custom"
"transpose_kv_cache_by_block"
)
CUSTOM_OPS=$(IFS=';'; echo "${CUSTOM_OPS_ARRAY[*]}")
SOC_ARG="ascend910_93"


@@ -1343,6 +1343,22 @@ std::tuple<at::Tensor,at::Tensor, at::Tensor> npu_add_rms_norm_bias(
return std::tuple<at::Tensor, at::Tensor, at::Tensor>(y, rstd, x);
}
void transpose_kv_cache_by_block(
const at::TensorList &kCache,
const at::TensorList &vCache,
const at::Tensor &blockIDs,
int64_t blockSize,
int64_t headNum,
int64_t headDim,
int64_t splitNum,
int64_t layerNum)
{
EXEC_NPU_CMD(aclnnTransposeKvCacheByBlock, kCache, vCache, blockIDs,
blockSize, headNum, headDim, splitNum, layerNum);
}
} // namespace vllm_ascend
TORCH_LIBRARY_EXPAND(CONCAT(_C, _ascend), ops)
@@ -1521,4 +1537,8 @@ TORCH_LIBRARY_EXPAND(CONCAT(_C, _ascend), ops)
ops.def("npu_apply_top_k_top_p(Tensor logits, Tensor? p=None, Tensor? k=None) -> Tensor");
ops.impl("npu_apply_top_k_top_p", torch::kPrivateUse1, &vllm_ascend::npu_apply_top_k_top_p);
ops.def(
"transpose_kv_cache_by_block(Tensor[] kCache, Tensor[] vCache, Tensor blockIDs, int blockSize, int headNum, int headDim, int splitNum, int layerNum) -> ()"
);
ops.impl("transpose_kv_cache_by_block", torch::kPrivateUse1, &vllm_ascend::transpose_kv_cache_by_block);
}
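
As registered above, the op mutates the caches in place and returns nothing. A minimal call sketch, mirroring the unit test further below (shapes and the custom-op setup are illustrative assumptions, and an Ascend NPU device plus the loaded `_C_ascend` library are required):

```python
import torch

# Assumed setup: vllm_ascend custom ops loaded, e.g. via vllm_ascend.utils.enable_custom_op().
layers, block_num, block_size, head_num, head_dim, split_num = 16, 128, 128, 8, 128, 4
k_caches = [torch.randn(block_num, block_size, head_num, head_dim,
                        dtype=torch.float16, device="npu") for _ in range(layers)]
v_caches = [torch.randn_like(k) for k in k_caches]
block_ids = torch.arange(33, dtype=torch.int64, device="npu")  # blocks to repack in place

torch.ops._C_ascend.transpose_kv_cache_by_block(
    k_caches, v_caches, block_ids, block_size, head_num, head_dim, split_num, layers)
```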


@@ -435,6 +435,20 @@ std::tuple<at::Tensor,at::Tensor, at::Tensor> npu_add_rms_norm_bias_meta(
at::Tensor x = at::empty_symint(x1.sym_sizes(), x1.options());
return std::tuple<at::Tensor, at::Tensor, at::Tensor>(y, rstd, x);
}
void transpose_kv_cache_by_block_meta(
const at::TensorList &k_cache,
const at::TensorList &v_cache,
const at::Tensor &block_ids,
int64_t block_size,
int64_t head_num,
int64_t head_dim,
int64_t split_num,
int64_t layer_num)
{
return;
}
} // namespace meta
} // namespace vllm_ascend
@@ -475,5 +489,7 @@ TORCH_LIBRARY_IMPL_EXPAND(CONCAT(_C, _ascend), Meta, ops) {
ops.impl("moe_gating_top_k", &vllm_ascend::meta::moe_gating_top_k_meta);
// Add_Rms_Norm_Bias
ops.impl("npu_add_rms_norm_bias", &vllm_ascend::meta::npu_add_rms_norm_bias_meta);
// transpose_kv_cache_by_block
ops.impl("transpose_kv_cache_by_block", &vllm_ascend::meta::transpose_kv_cache_by_block_meta);
}
}


@@ -0,0 +1,40 @@
# This program is free software, you can redistribute it and/or modify it.
# Copyright (c) 2025 Huawei Technologies Co., Ltd.
# This file is a part of the CANN Open Software.
# Licensed under CANN Open Software License Agreement Version 2.0 (the "License").
# Please refer to the License for details. You may not use this file except in compliance with the License.
# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE.
# See LICENSE in the root of the software repository for the full text of the License.
# ======================================================================================================================
add_ops_compile_options(
OP_NAME TransposeKvCacheByBlock
OPTIONS --cce-auto-sync=off
-Wno-deprecated-declarations
-Werror
-mllvm -cce-aicore-hoist-movemask=false
--op_relocatable_kernel_binary=true
)
target_sources(op_host_aclnn PRIVATE
transpose_kv_cache_by_block_def.cpp
)
target_sources(optiling PRIVATE
transpose_kv_cache_by_block_tiling.cpp
)
if (NOT BUILD_OPEN_PROJECT)
target_sources(opmaster_ct PRIVATE
transpose_kv_cache_by_block_tiling.cpp
)
endif ()
target_include_directories(optiling PRIVATE
${CMAKE_CURRENT_SOURCE_DIR}
)
target_sources(opsproto PRIVATE
transpose_kv_cache_by_block_proto.cpp
)


@@ -0,0 +1,36 @@
#include "register/op_def_registry.h"
namespace ops {
class TransposeKvCacheByBlock : public OpDef {
public:
explicit TransposeKvCacheByBlock(const char* name) : OpDef(name)
{
this->Input("KCache")
.ParamType(DYNAMIC)
.DataType({ge::DT_FLOAT16, ge::DT_BF16})
.Format({ge::FORMAT_ND, ge::FORMAT_ND})
.UnknownShapeFormat({ge::FORMAT_ND, ge::FORMAT_ND});
this->Input("VCache")
.ParamType(DYNAMIC)
.DataType({ge::DT_FLOAT16, ge::DT_BF16})
.Format({ge::FORMAT_ND, ge::FORMAT_ND})
.UnknownShapeFormat({ge::FORMAT_ND, ge::FORMAT_ND});
this->Input("blockIDs")
.ParamType(REQUIRED)
.DataTypeList({ge::DT_INT64})
.FormatList({ge::FORMAT_ND})
.UnknownShapeFormat({ge::FORMAT_ND});
this->Attr("blockSize").Int();
this->Attr("headNum").Int();
this->Attr("headDim").Int();
this->Attr("splitNum").Int();
this->Attr("layerNum").Int();
this->AICore().AddConfig("ascend910b");
this->AICore().AddConfig("ascend910_93");
}
};
OP_ADD(TransposeKvCacheByBlock);
}


@@ -0,0 +1,36 @@
/**
* This program is free software, you can redistribute it and/or modify it.
* Copyright (c) 2025 Huawei Technologies Co., Ltd.
* This file is a part of the CANN Open Software.
* Licensed under CANN Open Software License Agreement Version 2.0 (the "License").
* Please refer to the License for details. You may not use this file except in compliance with the License.
* THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE.
* See LICENSE in the root of the software repository for the full text of the License.
*/
/*!
* \file transpose_kv_cache_by_block_proto.cpp
* \brief
*/
#include <graph/utils/type_utils.h>
#include <register/op_impl_registry.h>
#include "error/ops_error.h"
using namespace ge;
namespace ops {
static ge::graphStatus InferShapeTransposeKvCacheByBlock(gert::InferShapeContext* context)
{
return ge::GRAPH_SUCCESS;
}
static ge::graphStatus InferDataTypeTransposeKvCacheByBlock(gert::InferDataTypeContext *context)
{
return ge::GRAPH_SUCCESS;
}
IMPL_OP_INFERSHAPE(TransposeKvCacheByBlock)
.InferShape(InferShapeTransposeKvCacheByBlock)
.InferDataType(InferDataTypeTransposeKvCacheByBlock);
} // namespace ops


@@ -0,0 +1,182 @@
#include "transpose_kv_cache_by_block_tiling.h"
#include "register/op_def_registry.h"
#include "tiling/platform/platform_ascendc.h"
#include "log/ops_log.h"
#include <algorithm>
namespace optiling {
constexpr uint64_t DATA_SIZE = 2;
constexpr uint64_t BLOCK_SIZE = 32;
constexpr uint64_t DB_ON = 2;
constexpr uint32_t FULL_LOAD = 0;
constexpr uint32_t SPLIT_BLOCK_SIZE_ALIGNED_AND_DB = 1;
constexpr uint32_t SPLIT_BLOCK_SIZE_UNALIGNED_AND_DB = 3;
constexpr uint32_t SPLIT_BLOCK_SIZE_ALIGNED_AND_NOT_DB = 2;
constexpr uint32_t SPLIT_BLOCK_SIZE_UNALIGNED_AND_NOT_DB = 4;
void findFactorsOptimized(std::vector<int64_t> &factors, int64_t n) {
for (int64_t i = 1; i * i <= n; i++) {
if (n % i == 0) {
factors.push_back(i);
if (i != n / i) {
factors.push_back(n / i);
}
}
}
sort(factors.begin(), factors.end());
}
ge::graphStatus CalTiling(gert::TilingContext* context, TransposeKvCacheByBlockTilingData &tiling)
{
fe::PlatFormInfos* platformInfoPtr = context->GetPlatformInfo();
OPS_LOG_E_IF_NULL(context, platformInfoPtr, return ge::GRAPH_FAILED);
auto ascendcPlatform = platform_ascendc::PlatformAscendC(platformInfoPtr);
int64_t useCoreNum = ascendcPlatform.GetCoreNumAiv();
uint64_t ubSize;
ascendcPlatform.GetCoreMemSize(platform_ascendc::CoreMemType::UB, ubSize);
auto attr = context->GetAttrs();
OPS_LOG_E_IF_NULL(context, attr, return ge::GRAPH_FAILED);
const int64_t* blockSizePtr = attr->GetAttrPointer<int64_t>(0);
const int64_t* headNumPtr = attr->GetAttrPointer<int64_t>(1);
const int64_t* headDimPtr = attr->GetAttrPointer<int64_t>(2);
const int64_t* splitNumPtr = attr->GetAttrPointer<int64_t>(3);
const int64_t* layerNumPtr = attr->GetAttrPointer<int64_t>(4);
OPS_CHECK(blockSizePtr == nullptr || headNumPtr == nullptr || headDimPtr == nullptr ||
splitNumPtr == nullptr || layerNumPtr == nullptr,
OPS_LOG_E(context->GetNodeName(), "Get attr failed."),
return ge::GRAPH_FAILED);
auto blockIDsTensor = context->GetDynamicInputTensor(2, 0);
OPS_LOG_E_IF_NULL(context, blockIDsTensor, return ge::GRAPH_FAILED);
gert::Shape blockIDsTensorShape = blockIDsTensor->GetStorageShape();
int64_t calBlockNum = static_cast<int64_t>(blockIDsTensorShape.GetDim(0));
tiling.set_calBlockNum(static_cast<uint32_t>(calBlockNum));
int64_t blockSize = *blockSizePtr;
int64_t headNum = *headNumPtr;
int64_t headDim = *headDimPtr;
int64_t splitNum = *splitNumPtr;
int64_t layerNum = *layerNumPtr;
uint32_t tilingKey = FULL_LOAD;
if (headDim * DATA_SIZE % BLOCK_SIZE != 0) {
OPS_LOG_E(context, "headDim * DATA_SIZE must be a multiple of 32 bytes.");
return ge::GRAPH_FAILED;
}
std::vector<int64_t> factors;
findFactorsOptimized(factors, useCoreNum);
uint32_t factorIndex = 0;
bool findSplitNum = true;
int64_t blockSizeSplitNum = factors[factorIndex];
uint64_t dataSizeloadOnce = blockSize * headNum * headDim * DATA_SIZE;
    // if the whole block fits in UB, do not split blockSize and do not use double buffering
if (dataSizeloadOnce > ubSize) {
tilingKey = SPLIT_BLOCK_SIZE_ALIGNED_AND_DB;
// split blockSize and db
while (dataSizeloadOnce > (ubSize / DB_ON)) {
factorIndex += 1;
if (factorIndex == factors.size()) {
tilingKey = FULL_LOAD;
findSplitNum = false;
break;
}
blockSizeSplitNum = factors[factorIndex];
dataSizeloadOnce = ((blockSize + blockSizeSplitNum - 1) / blockSizeSplitNum) * headNum * headDim * DATA_SIZE;
}
if (tilingKey == SPLIT_BLOCK_SIZE_ALIGNED_AND_DB && (blockSize % blockSizeSplitNum != 0)) {
tilingKey = SPLIT_BLOCK_SIZE_UNALIGNED_AND_DB;
}
}
if (!findSplitNum) {
tilingKey = SPLIT_BLOCK_SIZE_ALIGNED_AND_NOT_DB;
// split blockSize but not db
findSplitNum = true;
factorIndex = 0;
blockSizeSplitNum = factors[factorIndex];
dataSizeloadOnce = blockSize * headNum * headDim * DATA_SIZE;
while (dataSizeloadOnce > ubSize) {
factorIndex += 1;
if (factorIndex == factors.size()) {
tilingKey = FULL_LOAD;
findSplitNum = false;
break;
}
blockSizeSplitNum = factors[factorIndex];
dataSizeloadOnce = ((blockSize + blockSizeSplitNum - 1) / blockSizeSplitNum) * headNum * headDim * DATA_SIZE;
}
if (tilingKey == SPLIT_BLOCK_SIZE_ALIGNED_AND_NOT_DB && (blockSize % blockSizeSplitNum != 0)) {
tilingKey = SPLIT_BLOCK_SIZE_UNALIGNED_AND_NOT_DB;
}
}
// headNum * headDim too large
if (!findSplitNum) {
OPS_LOG_E(context, "headNum * headDim * sizeof(half) > ubSize "
"or blockSize * headNum * headDim * sizeof(half) > ubSize * vectorCoreNum. "
"Currently, splitting headNum or headDim is not supported.");
return ge::GRAPH_FAILED;
}
tiling.set_blockSizePerTime(static_cast<uint32_t>((blockSize + blockSizeSplitNum - 1) / blockSizeSplitNum));
tiling.set_blockSizePerTimeTail(static_cast<uint32_t>(blockSize % blockSizeSplitNum));
tiling.set_blockSizeSplitNum(static_cast<uint32_t>(blockSizeSplitNum));
tiling.set_blockSize(static_cast<uint32_t>(blockSize));
tiling.set_headNum(static_cast<uint32_t>(headNum));
tiling.set_headDim(static_cast<uint32_t>(headDim));
tiling.set_splitNum(static_cast<uint32_t>(splitNum));
tiling.set_layerNum(static_cast<uint32_t>(layerNum));
int64_t totalRound = layerNum * calBlockNum;
if ((totalRound * blockSizeSplitNum) < useCoreNum) {
useCoreNum = totalRound * blockSizeSplitNum;
}
int64_t blockPerCore = totalRound / (useCoreNum / blockSizeSplitNum);
int64_t tailCoreNum = totalRound % (useCoreNum / blockSizeSplitNum);
tiling.set_useCoreNum(static_cast<uint32_t>(useCoreNum));
tiling.set_blockPerCore(static_cast<uint32_t>(blockPerCore));
tiling.set_tailCoreNum(static_cast<uint32_t>(tailCoreNum));
context->SetBlockDim(useCoreNum);
context->SetTilingKey(tilingKey);
return ge::GRAPH_SUCCESS;
}
static ge::graphStatus TransposeKvCacheByBlockTilingFunc(gert::TilingContext* context)
{
TransposeKvCacheByBlockTilingData tiling;
auto status = CalTiling(context, tiling);
OP_CHECK(status != ge::GRAPH_SUCCESS, OPS_LOG_E(context->GetNodeName(), "Cal tiling failed."),
return ge::GRAPH_FAILED);
tiling.SaveToBuffer(context->GetRawTilingData()->GetData(), context->GetRawTilingData()->GetCapacity());
context->GetRawTilingData()->SetDataSize(tiling.GetDataSize());
return ge::GRAPH_SUCCESS;
}
struct TransposeKvCacheByBlockCompileInfo {};
ge::graphStatus TilingParseForTransposeKvCacheByBlock(gert::TilingParseContext *context)
{
(void)context;
return ge::GRAPH_SUCCESS;
}
IMPL_OP_OPTILING(TransposeKvCacheByBlock)
.Tiling(TransposeKvCacheByBlockTilingFunc)
.TilingParse<TransposeKvCacheByBlockCompileInfo>(TilingParseForTransposeKvCacheByBlock);
}
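
For reference, the tiling-key selection above can be paraphrased in a few lines of Python. This is a sketch under the same constants (2-byte elements, double-buffering factor 2); `ub_size` and `core_num` stand for the platform's UB capacity and AIV core count:

```python
FULL_LOAD = 0
SPLIT_ALIGNED_DB, SPLIT_ALIGNED_NO_DB = 1, 2
SPLIT_UNALIGNED_DB, SPLIT_UNALIGNED_NO_DB = 3, 4

def pick_tiling(block_size, head_num, head_dim, ub_size, core_num, data_size=2):
    # Candidate split factors are the divisors of the AIV core count.
    factors = sorted(f for f in range(1, core_num + 1) if core_num % f == 0)

    def bytes_per_load(split):
        return -(-block_size // split) * head_num * head_dim * data_size  # ceil-div

    if bytes_per_load(1) <= ub_size:
        return FULL_LOAD, 1                      # whole block fits in UB: no split, no DB
    for split in factors:                        # prefer keeping double buffering
        if bytes_per_load(split) <= ub_size // 2:
            key = SPLIT_ALIGNED_DB if block_size % split == 0 else SPLIT_UNALIGNED_DB
            return key, split
    for split in factors:                        # fall back: split without double buffering
        if bytes_per_load(split) <= ub_size:
            key = SPLIT_ALIGNED_NO_DB if block_size % split == 0 else SPLIT_UNALIGNED_NO_DB
            return key, split
    raise ValueError("headNum * headDim too large; splitting headNum/headDim is not supported")
```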


@@ -0,0 +1,23 @@
#include "register/tilingdata_base.h"
namespace optiling {
BEGIN_TILING_DATA_DEF(TransposeKvCacheByBlockTilingData)
// shape info
// TILING_DATA_FIELD_DEF(uint32_t, blockNum);
TILING_DATA_FIELD_DEF(uint32_t, blockSize);
TILING_DATA_FIELD_DEF(uint32_t, headNum);
TILING_DATA_FIELD_DEF(uint32_t, headDim);
TILING_DATA_FIELD_DEF(uint32_t, splitNum);
TILING_DATA_FIELD_DEF(uint32_t, layerNum);
// tiling info
TILING_DATA_FIELD_DEF(uint32_t, useCoreNum);
TILING_DATA_FIELD_DEF(uint32_t, blockPerCore);
TILING_DATA_FIELD_DEF(uint32_t, tailCoreNum);
TILING_DATA_FIELD_DEF(uint32_t, calBlockNum);
TILING_DATA_FIELD_DEF(uint32_t, blockSizePerTime);
TILING_DATA_FIELD_DEF(uint32_t, blockSizePerTimeTail);
TILING_DATA_FIELD_DEF(uint32_t, blockSizeSplitNum);
END_TILING_DATA_DEF;
REGISTER_TILING_DATA_CLASS(TransposeKvCacheByBlock, TransposeKvCacheByBlockTilingData)
}


@@ -0,0 +1,16 @@
#include "kernel_operator.h"
using namespace AscendC;
#ifndef __OP_KERNEL_KV_CACHE_TRANSPOSE_H__
#define __OP_KERNEL_KV_CACHE_TRANSPOSE_H__
template <typename T>
__aicore__ inline __gm__ T* GetTensorAddr(uint16_t index, GM_ADDR tensorPtr) {
__gm__ uint64_t* dataAddr = reinterpret_cast<__gm__ uint64_t*>(tensorPtr);
// The offset of the data address from the first address.
uint64_t tensorPtrOffset = *dataAddr;
    // Shifting right by 3 bits divides by sizeof(uint64_t).
__gm__ uint64_t* retPtr = dataAddr + (tensorPtrOffset >> 3);
return reinterpret_cast<__gm__ T*>(*(retPtr + index));
}
#endif


@@ -0,0 +1,141 @@
#include "common.h"
template <typename T>
class TransposeKvCacheByBlockKernelFullLoad {
protected:
TQueBind<TPosition::VECIN, TPosition::VECOUT, 1> vecInQueue_;
GlobalTensor<T> kCacheGm_;
GlobalTensor<T> vCacheGm_;
GlobalTensor<int64_t> blockIDsGm_;
GM_ADDR kCachePtr_;
GM_ADDR vCachePtr_;
// shape info
uint32_t blockNum_;
uint32_t blockSize_;
uint32_t headNum_;
uint32_t headDim_;
uint32_t splitNum_;
uint32_t layerNum_;
// tiling info
uint32_t useCoreNum_;
uint32_t blockPerCore_;
uint32_t tailCoreNum_;
uint32_t calBlockNum_;
uint32_t srcFactor_;
uint32_t dstFactor_;
uint32_t copyOutLength_;
uint32_t dataBlockSize_;
__aicore__ inline void CopyIn(GlobalTensor<T> &cacheGm, uint32_t offsetBlock, DataCopyParams &repeatParams) {
LocalTensor<T> cacheLocal = vecInQueue_.AllocTensor<T>();
for (uint32_t i = 0; i < splitNum_; ++i) {
DataCopy(cacheLocal[i * dstFactor_], cacheGm[i * srcFactor_ + offsetBlock], repeatParams);
}
vecInQueue_.EnQue(cacheLocal);
}
__aicore__ inline void CopyOut(GlobalTensor<T> &cacheGm, uint32_t offsetBlock) {
LocalTensor<T> cacheLocal = vecInQueue_.DeQue<T>();
DataCopy(cacheGm[offsetBlock], cacheLocal, copyOutLength_);
vecInQueue_.FreeTensor(cacheLocal);
}
__aicore__ inline void SetGlobalBuffers(uint32_t layerId) {
kCacheGm_.SetGlobalBuffer(GetTensorAddr<T>(layerId, kCachePtr_));
vCacheGm_.SetGlobalBuffer(GetTensorAddr<T>(layerId, vCachePtr_));
}
__aicore__ inline void Caloffset(uint32_t &startBlock, uint32_t &endBlock, uint32_t &startLayer, uint32_t &endLayer) {
uint32_t blockIdx = GetBlockIdx();
uint32_t curBlockStart;
uint32_t curBlocknum;
if (blockIdx < tailCoreNum_) {
curBlockStart = blockIdx * (blockPerCore_ + 1);
curBlocknum = blockPerCore_ + 1;
} else {
curBlockStart = blockIdx * blockPerCore_ + tailCoreNum_;
curBlocknum = blockPerCore_;
}
uint32_t curBlockEnd = curBlockStart + curBlocknum;
startBlock = curBlockStart / layerNum_;
startLayer = curBlockStart % layerNum_;
endBlock = (curBlockEnd + layerNum_ - 1) / layerNum_;
endLayer = curBlockEnd % layerNum_;
if (endLayer == 0) {
endLayer = layerNum_;
}
}
public:
__aicore__ inline void Init(GM_ADDR KCache, GM_ADDR VCache, GM_ADDR blockIDs,
TransposeKvCacheByBlockTilingData* tilingData, TPipe* tPipe) {
kCachePtr_ = KCache;
vCachePtr_ = VCache;
blockIDsGm_.SetGlobalBuffer((__gm__ int64_t*)blockIDs);
// shape info
blockSize_ = tilingData->blockSize;
headNum_ = tilingData->headNum;
headDim_ = tilingData->headDim;
splitNum_ = tilingData->splitNum;
layerNum_ = tilingData->layerNum;
// tiling info
useCoreNum_ = tilingData->useCoreNum;
blockPerCore_ = tilingData->blockPerCore;
tailCoreNum_ = tilingData->tailCoreNum;
calBlockNum_ = tilingData->calBlockNum;
tPipe->InitBuffer(vecInQueue_, 1, TOTAL_UB_SIZE);
srcFactor_ = blockSize_ * headNum_ / splitNum_ * headDim_;
dstFactor_ = headNum_ / splitNum_ * headDim_;
copyOutLength_ = blockSize_ * headNum_ * headDim_;
dataBlockSize_ = static_cast<uint32_t>(AscendC::GetDataBlockSizeInBytes());
}
__aicore__ inline void Process() {
DataCopyParams repeatParams;
repeatParams.blockCount = blockSize_;
repeatParams.blockLen = headNum_ / splitNum_ * headDim_ * sizeof(T) / dataBlockSize_;
repeatParams.srcStride = 0;
repeatParams.dstStride = (headNum_ * headDim_ - headNum_ / splitNum_ * headDim_) * sizeof(T) / dataBlockSize_;
uint32_t startBlock;
uint32_t endBlock;
uint32_t startLayer;
uint32_t endLayer;
Caloffset(startBlock, endBlock, startLayer, endLayer);
for (uint32_t i = startBlock; i < endBlock; ++i) {
int64_t blockId = blockIDsGm_.GetValue(i);
uint32_t offsetBlock = blockId * blockSize_ * headNum_ * headDim_;
uint32_t realStartLayer;
uint32_t realEndLayer;
if (i == startBlock) {
realStartLayer = startLayer;
} else {
realStartLayer = 0;
}
if (i == (endBlock - 1)) {
realEndLayer = endLayer;
} else {
realEndLayer = layerNum_;
}
for (uint32_t layerId = realStartLayer; layerId < realEndLayer; ++layerId) {
SetGlobalBuffers(layerId);
CopyIn(kCacheGm_, offsetBlock, repeatParams);
CopyOut(kCacheGm_, offsetBlock);
CopyIn(vCacheGm_, offsetBlock, repeatParams);
CopyOut(vCacheGm_, offsetBlock);
}
}
}
};
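
The per-core work split in `Caloffset` above can be modelled in plain Python. This is a paraphrase only: work items are (block, layer) pairs laid out block-major, each core gets a contiguous span, and the first `tail_core_num` cores take one extra item (core counts in the example are illustrative):

```python
def core_span(core_idx, block_per_core, tail_core_num, layer_num):
    if core_idx < tail_core_num:
        start = core_idx * (block_per_core + 1)
        count = block_per_core + 1
    else:
        start = core_idx * block_per_core + tail_core_num
        count = block_per_core
    end = start + count
    start_block, start_layer = divmod(start, layer_num)
    end_block = -(-end // layer_num)        # ceil-div: one past the last block touched
    end_layer = end % layer_num or layer_num
    return start_block, start_layer, end_block, end_layer

# e.g. 33 blocks * 16 layers = 528 items over 48 cores (illustrative) -> 11 per core, no tail
print(core_span(5, 11, 0, 16))  # (3, 7, 5, 2): block 3 from layer 7, then block 4 up to layer 2
```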


@@ -0,0 +1,190 @@
#include "common.h"
template <typename T, uint32_t DB, bool needHandleUnFactorSplit>
class TransposeKvCacheByBlockKernelGeneral {
protected:
TQueBind<TPosition::VECIN, TPosition::VECOUT, 1> queBind_;
GlobalTensor<T> kCacheGm_;
GlobalTensor<T> vCacheGm_;
GlobalTensor<int64_t> blockIDsGm_;
GM_ADDR kCachePtr_;
GM_ADDR vCachePtr_;
// shape info
uint32_t blockNum_;
uint32_t blockSize_;
uint32_t headNum_;
uint32_t headDim_;
uint32_t splitNum_;
uint32_t layerNum_;
uint32_t headNumSplited_;
uint32_t blockSizeSplitNum_;
// tiling info
uint32_t useCoreNum_;
uint32_t blockPerCore_;
uint32_t tailCoreNum_;
uint32_t calBlockNum_;
uint32_t srcFactor_;
uint32_t dstFactor_;
uint32_t copyOutLength_;
uint32_t blockSizePerTime_;
uint32_t blockSizePerTimeTail_;
uint32_t blockIdx_;
uint32_t dataBlockSize_;
bool needSync_;
__aicore__ inline void CopyIn(GlobalTensor<T> &cacheGm, uint32_t offsetBlock, DataCopyParams &repeatParams) {
LocalTensor<T> cacheLocal = queBind_.AllocTensor<T>();
for (uint32_t i = 0; i < splitNum_; ++i) {
DataCopy(cacheLocal[i * dstFactor_], cacheGm[i * srcFactor_ + offsetBlock], repeatParams);
}
queBind_.EnQue(cacheLocal);
}
__aicore__ inline void CopyOut(GlobalTensor<T> &cacheGm, uint32_t offsetBlock) {
LocalTensor<T> cacheLocal = queBind_.DeQue<T>();
AscendC::CrossCoreSetFlag<0x0, PIPE_MTE2>(0x8);
AscendC::CrossCoreWaitFlag(0x8);
DataCopy(cacheGm[offsetBlock], cacheLocal, copyOutLength_);
queBind_.FreeTensor(cacheLocal);
}
__aicore__ inline void SetGlobalBuffers(uint32_t layerId) {
kCacheGm_.SetGlobalBuffer(GetTensorAddr<T>(layerId, kCachePtr_));
vCacheGm_.SetGlobalBuffer(GetTensorAddr<T>(layerId, vCachePtr_));
}
__aicore__ inline void Caloffset(uint32_t &startBlock, uint32_t &endBlock, uint32_t &startLayer, uint32_t &endLayer) {
uint32_t curBlockStart;
uint32_t curBlocknum;
uint32_t groupBlockIdx = blockIdx_ / blockSizeSplitNum_;
if (groupBlockIdx < tailCoreNum_) {
needSync_ = false;
curBlockStart = groupBlockIdx * (blockPerCore_ + 1);
curBlocknum = blockPerCore_ + 1;
} else {
needSync_ = true;
curBlockStart = groupBlockIdx * blockPerCore_ + tailCoreNum_;
curBlocknum = blockPerCore_;
}
uint32_t curBlockEnd = curBlockStart + curBlocknum;
startBlock = curBlockStart / layerNum_;
startLayer = curBlockStart % layerNum_;
endBlock = (curBlockEnd + layerNum_ - 1) / layerNum_;
endLayer = curBlockEnd % layerNum_;
if (endLayer == 0) {
endLayer = layerNum_;
}
}
public:
__aicore__ inline void Init(GM_ADDR KCache, GM_ADDR VCache, GM_ADDR blockIDs,
TransposeKvCacheByBlockTilingData* tilingData, TPipe* tPipe) {
kCachePtr_ = KCache;
vCachePtr_ = VCache;
blockIDsGm_.SetGlobalBuffer((__gm__ int64_t*)blockIDs);
blockIdx_ = GetBlockIdx();
// shape info
blockSize_ = tilingData->blockSize;
headNum_ = tilingData->headNum;
headDim_ = tilingData->headDim;
splitNum_ = tilingData->splitNum;
layerNum_ = tilingData->layerNum;
// tiling info
useCoreNum_ = tilingData->useCoreNum;
blockPerCore_ = tilingData->blockPerCore;
tailCoreNum_ = tilingData->tailCoreNum;
calBlockNum_ = tilingData->calBlockNum;
blockSizeSplitNum_ = tilingData->blockSizeSplitNum;
blockSizePerTime_ = tilingData->blockSizePerTime;
blockSizePerTimeTail_ = tilingData->blockSizePerTimeTail;
headNumSplited_ = headNum_ / splitNum_;
tPipe->InitBuffer(queBind_, DB, TOTAL_UB_SIZE / DB);
srcFactor_ = blockSize_ * headNumSplited_ * headDim_;
dstFactor_ = headNumSplited_ * headDim_;
copyOutLength_ = blockSizePerTime_ * headNum_ * headDim_;
dataBlockSize_ = static_cast<uint32_t>(AscendC::GetDataBlockSizeInBytes());
}
__aicore__ inline void Process() {
DataCopyParams repeatParams;
repeatParams.blockCount = blockSizePerTime_;
repeatParams.blockLen = headNumSplited_ * headDim_ * sizeof(T) / dataBlockSize_;
repeatParams.srcStride = 0;
repeatParams.dstStride = (headNum_ * headDim_ - headNumSplited_ * headDim_) * sizeof(T) / dataBlockSize_;
uint32_t startBlock;
uint32_t endBlock;
uint32_t startLayer;
uint32_t endLayer;
Caloffset(startBlock, endBlock, startLayer, endLayer);
for (uint32_t i = startBlock; i < endBlock; ++i) {
int64_t blockId = blockIDsGm_.GetValue(i);
uint32_t offsetBlock = blockId * blockSize_ * headNum_ * headDim_;
uint32_t realStartLayer;
uint32_t realEndLayer;
if (i == startBlock) {
realStartLayer = startLayer;
} else {
realStartLayer = 0;
}
if (i == (endBlock - 1)) {
realEndLayer = endLayer;
} else {
realEndLayer = layerNum_;
}
for (uint32_t layerId = realStartLayer; layerId < realEndLayer; ++layerId) {
SetGlobalBuffers(layerId);
uint32_t blockSizeIndex = blockIdx_ % blockSizeSplitNum_;
uint32_t srcOffset;
uint32_t dstOffset;
if constexpr (needHandleUnFactorSplit) {
// handle tail
if (blockSizeIndex >= blockSizePerTimeTail_) {
repeatParams.blockCount = (blockSizePerTime_ - 1);
copyOutLength_ = (blockSizePerTime_ - 1) * headNum_ * headDim_;
srcOffset = (blockSizeIndex * blockSizePerTime_ - (blockSizeIndex - blockSizePerTimeTail_)) * headNumSplited_ * headDim_;
dstOffset = (blockSizeIndex * blockSizePerTime_ - (blockSizeIndex - blockSizePerTimeTail_)) * headNum_ * headDim_;
} else {
repeatParams.blockCount = blockSizePerTime_;
copyOutLength_ = blockSizePerTime_ * headNum_ * headDim_;
srcOffset = blockSizeIndex * blockSizePerTime_ * headNumSplited_ * headDim_;
dstOffset = blockSizeIndex * blockSizePerTime_ * headNum_ * headDim_;
}
} else {
repeatParams.blockCount = blockSizePerTime_;
copyOutLength_ = blockSizePerTime_ * headNum_ * headDim_;
srcOffset = blockSizeIndex * blockSizePerTime_ * headNumSplited_ * headDim_;
dstOffset = blockSizeIndex * blockSizePerTime_ * headNum_ * headDim_;
}
CopyIn(kCacheGm_, offsetBlock + srcOffset, repeatParams);
CopyOut(kCacheGm_, offsetBlock + dstOffset);
CopyIn(vCacheGm_, offsetBlock + srcOffset, repeatParams);
CopyOut(vCacheGm_, offsetBlock + dstOffset);
}
}
if (needSync_) {
AscendC::CrossCoreSetFlag<0x0, PIPE_MTE2>(0x8);
AscendC::CrossCoreWaitFlag(0x8);
AscendC::CrossCoreSetFlag<0x0, PIPE_MTE2>(0x8);
AscendC::CrossCoreWaitFlag(0x8);
}
}
};
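
The tail handling for an unaligned blockSize split (the `needHandleUnFactorSplit` branch above) can be paraphrased as follows. This sketch covers only the offset arithmetic, not the kernel itself: the first `blockSizePerTimeTail` subcores copy `blockSizePerTime` rows each, the remaining subcores copy one row fewer.

```python
def subcore_slice(block_size_index, per_time, per_time_tail, head_num, head_dim, split_num):
    head_num_split = head_num // split_num
    if per_time_tail and block_size_index >= per_time_tail:
        rows = per_time - 1
        start_row = block_size_index * per_time - (block_size_index - per_time_tail)
    else:
        rows = per_time
        start_row = block_size_index * per_time
    src_offset = start_row * head_num_split * head_dim  # within one head-group segment
    dst_offset = start_row * head_num * head_dim        # within the repacked block
    return rows, src_offset, dst_offset

# e.g. blockSize=10 split over 4 subcores: per_time=3, tail=2
# subcores 0,1 handle rows 0-2 and 3-5; subcores 2,3 handle rows 6-7 and 8-9
```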


@@ -0,0 +1,36 @@
#include "kernel_operator.h"
#include "full_load.h"
#include "general.h"
extern "C" __global__ __aicore__ void transpose_kv_cache_by_block(GM_ADDR KCache, GM_ADDR VCache, GM_ADDR blockIDs, GM_ADDR workspace, GM_ADDR tiling) {
GET_TILING_DATA(tiling_data, tiling);
TPipe tPipe;
KERNEL_TASK_TYPE_DEFAULT(KERNEL_TYPE_MIX_AIV_1_0);
if (TILING_KEY_IS(0)) {
        // full load, no double buffering
TransposeKvCacheByBlockKernelFullLoad<DTYPE_KCACHE> kernel;
kernel.Init(KCache, VCache, blockIDs, &tiling_data, &tPipe);
kernel.Process();
} else if (TILING_KEY_IS(1)) {
        // double buffering, blockSize split is aligned
TransposeKvCacheByBlockKernelGeneral<DTYPE_KCACHE, uint32_t(2), false> kernel;
kernel.Init(KCache, VCache, blockIDs, &tiling_data, &tPipe);
kernel.Process();
} else if (TILING_KEY_IS(2)) {
        // no double buffering, blockSize split is aligned
TransposeKvCacheByBlockKernelGeneral<DTYPE_KCACHE, uint32_t(1), false> kernel;
kernel.Init(KCache, VCache, blockIDs, &tiling_data, &tPipe);
kernel.Process();
} else if (TILING_KEY_IS(3)) {
        // double buffering, blockSize split is unaligned
TransposeKvCacheByBlockKernelGeneral<DTYPE_KCACHE, uint32_t(2), true> kernel;
kernel.Init(KCache, VCache, blockIDs, &tiling_data, &tPipe);
kernel.Process();
} else if (TILING_KEY_IS(4)) {
        // no double buffering, blockSize split is unaligned
TransposeKvCacheByBlockKernelGeneral<DTYPE_KCACHE, uint32_t(1), true> kernel;
kernel.Init(KCache, VCache, blockIDs, &tiling_data, &tPipe);
kernel.Process();
}
}


@@ -0,0 +1,137 @@
import random
import unittest
import torch
import torch_npu
from vllm_ascend.utils import enable_custom_op
enable_custom_op()
torch.set_printoptions(threshold=float("inf"))
def clone_kv_cache(k_caches, v_caches):
new_k_caches = [cache.clone() for cache in k_caches]
new_v_caches = [cache.clone() for cache in v_caches]
return new_k_caches, new_v_caches
class TestTransposeKvCacheByBlock(unittest.TestCase):
def compute_golden(self, k_caches, v_caches, block_ids_tensor, block_size, num_kv_head, head_dim, num_need_pulls, layers, dtype):
num_blocks = block_ids_tensor.shape[0]
block_ids_tensor = block_ids_tensor.to(dtype=torch.int32)
block_offsets = torch.arange(0, block_size, dtype=torch.int32).npu()
slot_mapping = block_offsets.reshape(
(1, block_size)) + block_ids_tensor.reshape(
(num_blocks, 1)) * block_size
slot_mapping = slot_mapping.flatten()
block_len = num_blocks * block_size
block_len_tensor = torch.tensor([block_len],dtype=torch.int32).npu()
block_table = block_ids_tensor.view(1, -1)
seq_start_tensor = torch.tensor([0],dtype=torch.int32).npu()
k = torch.empty(block_len, num_kv_head, head_dim, dtype=dtype).npu()
v = torch.empty(block_len, num_kv_head, head_dim, dtype=dtype).npu()
for layer in range(layers):
k_cache_layer = k_caches[layer]
v_cache_layer = v_caches[layer]
torch_npu.atb.npu_paged_cache_load(
k_cache_layer,
v_cache_layer,
block_table,
block_len_tensor,
seq_starts=seq_start_tensor,
key=k,
value=v,
)
k = k.view(num_blocks, num_need_pulls, block_size, -1)
k.transpose_(1, 2)
k = k.contiguous().view(block_len, num_kv_head, -1)
v = v.view(num_blocks, num_need_pulls, block_size, -1)
v.transpose_(1, 2)
v = v.contiguous().view(block_len, num_kv_head, -1)
torch_npu._npu_reshape_and_cache(
key=k,
value=v,
key_cache=k_cache_layer,
value_cache=v_cache_layer,
slot_indices=slot_mapping,
)
del k, v
def test_transpose_kv_cache_by_block(self):
# (layers, block_num, block_size, num_kv_head, head_dim, num_need_pulls)
test_cases = [
(16, 128, 128, 4, 128, 4),
(16, 128, 128, 4, 128, 2),
(16, 128, 128, 4, 128, 1),
(16, 128, 128, 8, 128, 8),
(16, 128, 128, 8, 128, 4),
(16, 128, 128, 8, 128, 2),
]
dtypes = [torch.float16, torch.bfloat16]
for dtype in dtypes:
for layers, block_num, block_size, num_kv_head, head_dim, num_need_pulls in test_cases:
with self.subTest(dtype=dtype, shape=f"({layers}, {block_num}, {block_size}, {num_kv_head}, {head_dim}, {num_need_pulls})"):
k_caches = []
v_caches = []
block_id_num = 33
block_ids_tensor = torch.randperm(block_num, dtype=torch.int64, device="npu")[:block_id_num]
for i in range(layers):
kcache = torch.randn(block_num, block_size, num_kv_head, head_dim, dtype=dtype, device="npu")
vcache = torch.randn(block_num, block_size, num_kv_head, head_dim, dtype=dtype, device="npu")
k_caches.append(kcache)
v_caches.append(vcache)
cloned_k_caches, cloned_v_caches = clone_kv_cache(k_caches, v_caches)
self.compute_golden(cloned_k_caches, cloned_v_caches, block_ids_tensor, block_size, num_kv_head, head_dim, num_need_pulls, layers, dtype)
torch.ops._C_ascend.transpose_kv_cache_by_block(k_caches, v_caches, block_ids_tensor, block_size, num_kv_head, head_dim, num_need_pulls, layers)
for i in range(layers):
self.assert_tensors_almost_equal(k_caches[i], cloned_k_caches[i], dtype)
self.assert_tensors_almost_equal(v_caches[i], cloned_v_caches[i], dtype)
def assert_tensors_almost_equal(self, actual, expected, dtype):
"""Check if two tensors are approximately equal (considering floating point errors)"""
self.assertEqual(actual.shape, expected.shape, "Shape mismatch")
# Check for NaN
self.assertFalse(
torch.isnan(actual).any(), "Actual result contains NaN")
self.assertFalse(
torch.isnan(expected).any(), "Expected result contains NaN")
# Check for Inf
self.assertFalse(
torch.isinf(actual).any(), "Actual result contains Inf")
self.assertFalse(
torch.isinf(expected).any(), "Expected result contains Inf")
# Set different tolerances based on data type
if dtype == torch.float16:
rtol, atol = 1e-5, 1e-5
else: # bfloat16
rtol, atol = 1.5e-5, 1.5e-5
# Compare values
diff = torch.abs(actual - expected)
max_diff = diff.max().item()
max_expected = torch.abs(expected).max().item()
# Check relative and absolute errors
if max_expected > 0:
relative_diff = max_diff / max_expected
self.assertLessEqual(
relative_diff,
rtol,
f"Relative error too large: {relative_diff} > {rtol}. Max difference: {max_diff}",
)
self.assertLessEqual(max_diff, atol,
f"Absolute error too large: {max_diff} > {atol}")


@@ -46,10 +46,11 @@ from vllm.v1.core.sched.output import SchedulerOutput
from vllm.v1.kv_cache_interface import KVCacheConfig
from vllm.v1.request import RequestStatus
from vllm_ascend import envs as ascend_envs
from vllm_ascend.ascend_config import get_ascend_config, init_ascend_config
from vllm_ascend.distributed.kv_transfer.utils.mooncake_transfer_engine import global_te
from vllm_ascend.distributed.kv_transfer.utils.utils import get_transfer_timeout_value
from vllm_ascend.utils import is_vl_model
from vllm_ascend.utils import enable_custom_op, is_vl_model
# isort: off
if TYPE_CHECKING:
@@ -570,8 +571,39 @@ class KVCacheRecvingThread(threading.Thread):
is_kv_transfer_end = global_offset == tp_num_need_pulls * self._prefill_pp_size - 1
need_cat_cache = tp_num_need_pulls > 1 and is_kv_transfer_end
need_nz_cache = get_ascend_config().enable_kv_nz and is_kv_transfer_end
use_fused_op = ascend_envs.VLLM_ASCEND_FUSION_OP_TRANSPOSE_KV_CACHE_BY_BLOCK
if need_nz_cache or need_cat_cache:
self.reformat_kv_cache(grouped_local_block_ids, tp_num_need_pulls, need_cat_cache, need_nz_cache)
# Use the fused op to reformat the KV cache; the original implementation is kept so the fused op can be disabled.
if use_fused_op and enable_custom_op():
if need_cat_cache:
# the fused op only supports concatenating GQA/MHA KV cache by head
self.reformat_kv_cache_with_fused_op(grouped_local_block_ids, tp_num_need_pulls)
if need_nz_cache:
# the fused op may also be used to reformat NZ KV cache in the future.
self.reformat_kv_cache(grouped_local_block_ids, tp_num_need_pulls, False, need_nz_cache)
else:
self.reformat_kv_cache(grouped_local_block_ids, tp_num_need_pulls, need_cat_cache, need_nz_cache)
def reformat_kv_cache_with_fused_op(self, block_ids: list[list[int]], tp_num_need_pulls: int):
# Get necessary parameters
k_cache = list(self.kv_caches.values())[0][0]
device = k_cache.device
head_dim = self.model_config.hf_text_config.head_dim
block_size = self.vllm_config.cache_config.block_size
num_kv_head = max(self.model_config.hf_text_config.num_key_value_heads // self.tp_size, 1)
layers = self.model_config.hf_text_config.num_hidden_layers
flat_block_ids = [item for sublist in block_ids for item in sublist]
block_ids_tensor = torch.tensor(flat_block_ids, dtype=torch.int64, device=device)
k_caches = []
v_caches = []
for _, (k_cache_layer, v_cache_layer) in self.kv_caches.items():
k_caches.append(k_cache_layer)
v_caches.append(v_cache_layer)
torch.ops._C_ascend.transpose_kv_cache_by_block(
k_caches, v_caches, block_ids_tensor, block_size, num_kv_head, head_dim, tp_num_need_pulls, layers
)
def reformat_kv_cache(
self,


@@ -111,6 +111,10 @@ env_variables: dict[str, Callable[[], Any]] = {
"VLLM_ASCEND_ENABLE_FUSED_MC2": lambda: int(os.getenv("VLLM_ASCEND_ENABLE_FUSED_MC2", "0")),
# Whether to enable balance scheduling
"VLLM_ASCEND_BALANCE_SCHEDULING": lambda: bool(int(os.getenv("VLLM_ASCEND_BALANCE_SCHEDULING", "0"))),
# Whether to use the fused op transpose_kv_cache_by_block; default is True
"VLLM_ASCEND_FUSION_OP_TRANSPOSE_KV_CACHE_BY_BLOCK": lambda: bool(
int(os.getenv("VLLM_ASCEND_FUSION_OP_TRANSPOSE_KV_CACHE_BY_BLOCK", "1"))
),
}
# end-env-vars-definition