[Kernel] Add AscendC fused op transpose_kv_cache_by_block to speed up GQA transfer (#6366)

### What this PR does / why we need it?
As #2947 describe, we need to transpose kv cache layout after GQA kv
transfer when prefill and decode tensor parallel size are heterogeneous,
in the previous implementation, we use `npu_paged_cache_load ` +
`tranpose` + `_npu_reshape_and_cache` to do this work.

But obviously, it is not an efficient plan, the ops above need to be
called for each layer, which introduces 3 * layer_num kernel launch, and
6 * layer_num data movement between L1 Cache and HBM for one request on
decode node. Usually, decode node uses graph mode, so these op kernels
will be called between decode forward launched by an async thread in
mooncacke connector, this kernels maybe last for several decode forward
and TTFT will increase by 3~4 decode forward time.

In this PR, we implement an AscendC fused op
`transpose_kv_cache_by_block` to do this with only once kernel launch
and move data between L1 Cache and HBM only once.

After using this fused op, the time cost in transpose kv cacke layout
can be decreased to 0.24ms from 7ms in UT on 910C, and in PD
disaggregation scenario, TTFT can decrease about 90 ~ 110 ms in
qwen3-235B.

| request_num | original | fused_op|
|:----------------------:|:---------------:|:-------------------:|
|           1            |      643 ms      |        578 ms        |
|          128           |     1480 ms      |       1368 ms        |

### Does this PR introduce _any_ user-facing change?
Use fused op by default, incase the op has bug in any scenario, provide
fallback choice using env to disable it.

**DISABLE fused op by add following env**
`export VLLM_ASCEND_FUSION_OP_TRANSPOSE_KV_CACHE_BY_BLOCK=0`

### How was this patch tested?

- vLLM version: v0.14.1
- vLLM main:
dc917cceb8

---------

Signed-off-by: lidenghui <lidenghui1110@gmail.com>
This commit is contained in:
lidenghui1110
2026-02-03 14:10:01 +08:00
committed by GitHub
parent f4a72f0d16
commit 79803932e2
15 changed files with 913 additions and 3 deletions

View File

@@ -0,0 +1,16 @@
#include "kernel_operator.h"
using namespace AscendC;
#ifndef __OP_KERNEL_KV_CACHE_TRANSPOSE_H__
#define __OP_KERNEL_KV_CACHE_TRANSPOSE_H__
template <typename T>
__aicore__ inline __gm__ T* GetTensorAddr(uint16_t index, GM_ADDR tensorPtr) {
__gm__ uint64_t* dataAddr = reinterpret_cast<__gm__ uint64_t*>(tensorPtr);
// The offset of the data address from the first address.
uint64_t tensorPtrOffset = *dataAddr;
// Moving 3 bits to the right means dividing by sizeof(uint64 t).
__gm__ uint64_t* retPtr = dataAddr + (tensorPtrOffset >> 3);
return reinterpret_cast<__gm__ T*>(*(retPtr + index));
}
#endif

View File

@@ -0,0 +1,141 @@
#include "common.h"
template <typename T>
class TransposeKvCacheByBlockKernelFullLoad {
protected:
TQueBind<TPosition::VECIN, TPosition::VECOUT, 1> vecInQueue_;
GlobalTensor<T> kCacheGm_;
GlobalTensor<T> vCacheGm_;
GlobalTensor<int64_t> blockIDsGm_;
GM_ADDR kCachePtr_;
GM_ADDR vCachePtr_;
// shape info
uint32_t blockNum_;
uint32_t blockSize_;
uint32_t headNum_;
uint32_t headDim_;
uint32_t splitNum_;
uint32_t layerNum_;
// tiling info
uint32_t useCoreNum_;
uint32_t blockPerCore_;
uint32_t tailCoreNum_;
uint32_t calBlockNum_;
uint32_t srcFactor_;
uint32_t dstFactor_;
uint32_t copyOutLength_;
uint32_t dataBlockSize_;
__aicore__ inline void CopyIn(GlobalTensor<T> &cacheGm, uint32_t offsetBlock, DataCopyParams &repeatParams) {
LocalTensor<T> cacheLocal = vecInQueue_.AllocTensor<T>();
for (uint32_t i = 0; i < splitNum_; ++i) {
DataCopy(cacheLocal[i * dstFactor_], cacheGm[i * srcFactor_ + offsetBlock], repeatParams);
}
vecInQueue_.EnQue(cacheLocal);
}
__aicore__ inline void CopyOut(GlobalTensor<T> &cacheGm, uint32_t offsetBlock) {
LocalTensor<T> cacheLocal = vecInQueue_.DeQue<T>();
DataCopy(cacheGm[offsetBlock], cacheLocal, copyOutLength_);
vecInQueue_.FreeTensor(cacheLocal);
}
__aicore__ inline void SetGlobalBuffers(uint32_t layerId) {
kCacheGm_.SetGlobalBuffer(GetTensorAddr<T>(layerId, kCachePtr_));
vCacheGm_.SetGlobalBuffer(GetTensorAddr<T>(layerId, vCachePtr_));
}
__aicore__ inline void Caloffset(uint32_t &startBlock, uint32_t &endBlock, uint32_t &startLayer, uint32_t &endLayer) {
uint32_t blockIdx = GetBlockIdx();
uint32_t curBlockStart;
uint32_t curBlocknum;
if (blockIdx < tailCoreNum_) {
curBlockStart = blockIdx * (blockPerCore_ + 1);
curBlocknum = blockPerCore_ + 1;
} else {
curBlockStart = blockIdx * blockPerCore_ + tailCoreNum_;
curBlocknum = blockPerCore_;
}
uint32_t curBlockEnd = curBlockStart + curBlocknum;
startBlock = curBlockStart / layerNum_;
startLayer = curBlockStart % layerNum_;
endBlock = (curBlockEnd + layerNum_ - 1) / layerNum_;
endLayer = curBlockEnd % layerNum_;
if (endLayer == 0) {
endLayer = layerNum_;
}
}
public:
__aicore__ inline void Init(GM_ADDR KCache, GM_ADDR VCache, GM_ADDR blockIDs,
TransposeKvCacheByBlockTilingData* tilingData, TPipe* tPipe) {
kCachePtr_ = KCache;
vCachePtr_ = VCache;
blockIDsGm_.SetGlobalBuffer((__gm__ int64_t*)blockIDs);
// shape info
blockSize_ = tilingData->blockSize;
headNum_ = tilingData->headNum;
headDim_ = tilingData->headDim;
splitNum_ = tilingData->splitNum;
layerNum_ = tilingData->layerNum;
// tiling info
useCoreNum_ = tilingData->useCoreNum;
blockPerCore_ = tilingData->blockPerCore;
tailCoreNum_ = tilingData->tailCoreNum;
calBlockNum_ = tilingData->calBlockNum;
tPipe->InitBuffer(vecInQueue_, 1, TOTAL_UB_SIZE);
srcFactor_ = blockSize_ * headNum_ / splitNum_ * headDim_;
dstFactor_ = headNum_ / splitNum_ * headDim_;
copyOutLength_ = blockSize_ * headNum_ * headDim_;
dataBlockSize_ = static_cast<uint32_t>(AscendC::GetDataBlockSizeInBytes());
}
__aicore__ inline void Process() {
DataCopyParams repeatParams;
repeatParams.blockCount = blockSize_;
repeatParams.blockLen = headNum_ / splitNum_ * headDim_ * sizeof(T) / dataBlockSize_;
repeatParams.srcStride = 0;
repeatParams.dstStride = (headNum_ * headDim_ - headNum_ / splitNum_ * headDim_) * sizeof(T) / dataBlockSize_;
uint32_t startBlock;
uint32_t endBlock;
uint32_t startLayer;
uint32_t endLayer;
Caloffset(startBlock, endBlock, startLayer, endLayer);
for (uint32_t i = startBlock; i < endBlock; ++i) {
int64_t blockId = blockIDsGm_.GetValue(i);
uint32_t offsetBlock = blockId * blockSize_ * headNum_ * headDim_;
uint32_t realStartLayer;
uint32_t realEndLayer;
if (i == startBlock) {
realStartLayer = startLayer;
} else {
realStartLayer = 0;
}
if (i == (endBlock - 1)) {
realEndLayer = endLayer;
} else {
realEndLayer = layerNum_;
}
for (uint32_t layerId = realStartLayer; layerId < realEndLayer; ++layerId) {
SetGlobalBuffers(layerId);
CopyIn(kCacheGm_, offsetBlock, repeatParams);
CopyOut(kCacheGm_, offsetBlock);
CopyIn(vCacheGm_, offsetBlock, repeatParams);
CopyOut(vCacheGm_, offsetBlock);
}
}
}
};

View File

@@ -0,0 +1,190 @@
#include "common.h"
template <typename T, uint32_t DB, bool needHandleUnFactorSplit>
class TransposeKvCacheByBlockKernelGeneral {
protected:
TQueBind<TPosition::VECIN, TPosition::VECOUT, 1> queBind_;
GlobalTensor<T> kCacheGm_;
GlobalTensor<T> vCacheGm_;
GlobalTensor<int64_t> blockIDsGm_;
GM_ADDR kCachePtr_;
GM_ADDR vCachePtr_;
// shape info
uint32_t blockNum_;
uint32_t blockSize_;
uint32_t headNum_;
uint32_t headDim_;
uint32_t splitNum_;
uint32_t layerNum_;
uint32_t headNumSplited_;
uint32_t blockSizeSplitNum_;
// tiling info
uint32_t useCoreNum_;
uint32_t blockPerCore_;
uint32_t tailCoreNum_;
uint32_t calBlockNum_;
uint32_t srcFactor_;
uint32_t dstFactor_;
uint32_t copyOutLength_;
uint32_t blockSizePerTime_;
uint32_t blockSizePerTimeTail_;
uint32_t blockIdx_;
uint32_t dataBlockSize_;
bool needSync_;
__aicore__ inline void CopyIn(GlobalTensor<T> &cacheGm, uint32_t offsetBlock, DataCopyParams &repeatParams) {
LocalTensor<T> cacheLocal = queBind_.AllocTensor<T>();
for (uint32_t i = 0; i < splitNum_; ++i) {
DataCopy(cacheLocal[i * dstFactor_], cacheGm[i * srcFactor_ + offsetBlock], repeatParams);
}
queBind_.EnQue(cacheLocal);
}
__aicore__ inline void CopyOut(GlobalTensor<T> &cacheGm, uint32_t offsetBlock) {
LocalTensor<T> cacheLocal = queBind_.DeQue<T>();
AscendC::CrossCoreSetFlag<0x0, PIPE_MTE2>(0x8);
AscendC::CrossCoreWaitFlag(0x8);
DataCopy(cacheGm[offsetBlock], cacheLocal, copyOutLength_);
queBind_.FreeTensor(cacheLocal);
}
__aicore__ inline void SetGlobalBuffers(uint32_t layerId) {
kCacheGm_.SetGlobalBuffer(GetTensorAddr<T>(layerId, kCachePtr_));
vCacheGm_.SetGlobalBuffer(GetTensorAddr<T>(layerId, vCachePtr_));
}
__aicore__ inline void Caloffset(uint32_t &startBlock, uint32_t &endBlock, uint32_t &startLayer, uint32_t &endLayer) {
uint32_t curBlockStart;
uint32_t curBlocknum;
uint32_t groupBlockIdx = blockIdx_ / blockSizeSplitNum_;
if (groupBlockIdx < tailCoreNum_) {
needSync_ = false;
curBlockStart = groupBlockIdx * (blockPerCore_ + 1);
curBlocknum = blockPerCore_ + 1;
} else {
needSync_ = true;
curBlockStart = groupBlockIdx * blockPerCore_ + tailCoreNum_;
curBlocknum = blockPerCore_;
}
uint32_t curBlockEnd = curBlockStart + curBlocknum;
startBlock = curBlockStart / layerNum_;
startLayer = curBlockStart % layerNum_;
endBlock = (curBlockEnd + layerNum_ - 1) / layerNum_;
endLayer = curBlockEnd % layerNum_;
if (endLayer == 0) {
endLayer = layerNum_;
}
}
public:
__aicore__ inline void Init(GM_ADDR KCache, GM_ADDR VCache, GM_ADDR blockIDs,
TransposeKvCacheByBlockTilingData* tilingData, TPipe* tPipe) {
kCachePtr_ = KCache;
vCachePtr_ = VCache;
blockIDsGm_.SetGlobalBuffer((__gm__ int64_t*)blockIDs);
blockIdx_ = GetBlockIdx();
// shape info
blockSize_ = tilingData->blockSize;
headNum_ = tilingData->headNum;
headDim_ = tilingData->headDim;
splitNum_ = tilingData->splitNum;
layerNum_ = tilingData->layerNum;
// tiling info
useCoreNum_ = tilingData->useCoreNum;
blockPerCore_ = tilingData->blockPerCore;
tailCoreNum_ = tilingData->tailCoreNum;
calBlockNum_ = tilingData->calBlockNum;
blockSizeSplitNum_ = tilingData->blockSizeSplitNum;
blockSizePerTime_ = tilingData->blockSizePerTime;
blockSizePerTimeTail_ = tilingData->blockSizePerTimeTail;
headNumSplited_ = headNum_ / splitNum_;
tPipe->InitBuffer(queBind_, DB, TOTAL_UB_SIZE / DB);
srcFactor_ = blockSize_ * headNumSplited_ * headDim_;
dstFactor_ = headNumSplited_ * headDim_;
copyOutLength_ = blockSizePerTime_ * headNum_ * headDim_;
dataBlockSize_ = static_cast<uint32_t>(AscendC::GetDataBlockSizeInBytes());
}
__aicore__ inline void Process() {
DataCopyParams repeatParams;
repeatParams.blockCount = blockSizePerTime_;
repeatParams.blockLen = headNumSplited_ * headDim_ * sizeof(T) / dataBlockSize_;
repeatParams.srcStride = 0;
repeatParams.dstStride = (headNum_ * headDim_ - headNumSplited_ * headDim_) * sizeof(T) / dataBlockSize_;
uint32_t startBlock;
uint32_t endBlock;
uint32_t startLayer;
uint32_t endLayer;
Caloffset(startBlock, endBlock, startLayer, endLayer);
for (uint32_t i = startBlock; i < endBlock; ++i) {
int64_t blockId = blockIDsGm_.GetValue(i);
uint32_t offsetBlock = blockId * blockSize_ * headNum_ * headDim_;
uint32_t realStartLayer;
uint32_t realEndLayer;
if (i == startBlock) {
realStartLayer = startLayer;
} else {
realStartLayer = 0;
}
if (i == (endBlock - 1)) {
realEndLayer = endLayer;
} else {
realEndLayer = layerNum_;
}
for (uint32_t layerId = realStartLayer; layerId < realEndLayer; ++layerId) {
SetGlobalBuffers(layerId);
uint32_t blockSizeIndex = blockIdx_ % blockSizeSplitNum_;
uint32_t srcOffset;
uint32_t dstOffset;
if constexpr (needHandleUnFactorSplit) {
// handle tail
if (blockSizeIndex >= blockSizePerTimeTail_) {
repeatParams.blockCount = (blockSizePerTime_ - 1);
copyOutLength_ = (blockSizePerTime_ - 1) * headNum_ * headDim_;
srcOffset = (blockSizeIndex * blockSizePerTime_ - (blockSizeIndex - blockSizePerTimeTail_)) * headNumSplited_ * headDim_;
dstOffset = (blockSizeIndex * blockSizePerTime_ - (blockSizeIndex - blockSizePerTimeTail_)) * headNum_ * headDim_;
} else {
repeatParams.blockCount = blockSizePerTime_;
copyOutLength_ = blockSizePerTime_ * headNum_ * headDim_;
srcOffset = blockSizeIndex * blockSizePerTime_ * headNumSplited_ * headDim_;
dstOffset = blockSizeIndex * blockSizePerTime_ * headNum_ * headDim_;
}
} else {
repeatParams.blockCount = blockSizePerTime_;
copyOutLength_ = blockSizePerTime_ * headNum_ * headDim_;
srcOffset = blockSizeIndex * blockSizePerTime_ * headNumSplited_ * headDim_;
dstOffset = blockSizeIndex * blockSizePerTime_ * headNum_ * headDim_;
}
CopyIn(kCacheGm_, offsetBlock + srcOffset, repeatParams);
CopyOut(kCacheGm_, offsetBlock + dstOffset);
CopyIn(vCacheGm_, offsetBlock + srcOffset, repeatParams);
CopyOut(vCacheGm_, offsetBlock + dstOffset);
}
}
if (needSync_) {
AscendC::CrossCoreSetFlag<0x0, PIPE_MTE2>(0x8);
AscendC::CrossCoreWaitFlag(0x8);
AscendC::CrossCoreSetFlag<0x0, PIPE_MTE2>(0x8);
AscendC::CrossCoreWaitFlag(0x8);
}
}
};

View File

@@ -0,0 +1,36 @@
#include "kernel_operator.h"
#include "full_load.h"
#include "general.h"
extern "C" __global__ __aicore__ void transpose_kv_cache_by_block(GM_ADDR KCache, GM_ADDR VCache, GM_ADDR blockIDs, GM_ADDR workspace, GM_ADDR tiling) {
GET_TILING_DATA(tiling_data, tiling);
TPipe tPipe;
KERNEL_TASK_TYPE_DEFAULT(KERNEL_TYPE_MIX_AIV_1_0);
if (TILING_KEY_IS(0)) {
// full load not db
TransposeKvCacheByBlockKernelFullLoad<DTYPE_KCACHE> kernel;
kernel.Init(KCache, VCache, blockIDs, &tiling_data, &tPipe);
kernel.Process();
} else if (TILING_KEY_IS(1)) {
// db \ align split blockSize
TransposeKvCacheByBlockKernelGeneral<DTYPE_KCACHE, uint32_t(2), false> kernel;
kernel.Init(KCache, VCache, blockIDs, &tiling_data, &tPipe);
kernel.Process();
} else if (TILING_KEY_IS(2)) {
// not db \ align split blockSize
TransposeKvCacheByBlockKernelGeneral<DTYPE_KCACHE, uint32_t(1), false> kernel;
kernel.Init(KCache, VCache, blockIDs, &tiling_data, &tPipe);
kernel.Process();
} else if (TILING_KEY_IS(3)) {
// db \ unalign split blockSize
TransposeKvCacheByBlockKernelGeneral<DTYPE_KCACHE, uint32_t(2), true> kernel;
kernel.Init(KCache, VCache, blockIDs, &tiling_data, &tPipe);
kernel.Process();
} else if (TILING_KEY_IS(4)) {
// not db \ unalign split blockSize
TransposeKvCacheByBlockKernelGeneral<DTYPE_KCACHE, uint32_t(1), true> kernel;
kernel.Init(KCache, VCache, blockIDs, &tiling_data, &tPipe);
kernel.Process();
}
}