### What this PR does / why we need it?
As #2947 describes, we need to transpose the kv cache layout after GQA kv
transfer when the prefill and decode tensor parallel sizes are heterogeneous.
In the previous implementation, we used `npu_paged_cache_load` +
`transpose` + `_npu_reshape_and_cache` to do this work.
But obviously, this is not an efficient plan: the ops above need to be
called for each layer, which introduces 3 * layer_num kernel launches and
6 * layer_num data movements between L1 Cache and HBM for one request on
the decode node. Usually, the decode node uses graph mode, so these op
kernels are called between decode forwards launched by an async thread in
the mooncake connector; these kernels may last for several decode forwards,
and TTFT will increase by 3~4 decode forward times.
In this PR, we implement an AscendC fused op
`transpose_kv_cache_by_block` to do this with only once kernel launch
and move data between L1 Cache and HBM only once.
After using this fused op, the time spent transposing the kv cache layout
decreases from 7 ms to 0.24 ms in UT on 910C, and in the PD
disaggregation scenario, TTFT decreases by about 90 ~ 110 ms for
qwen3-235B.
| request_num | original | fused_op|
|:----------------------:|:---------------:|:-------------------:|
| 1 | 643 ms | 578 ms |
| 128 | 1480 ms | 1368 ms |
### Does this PR introduce _any_ user-facing change?
The fused op is used by default; in case the op has a bug in any scenario, a
fallback is provided via an environment variable to disable it.
**DISABLE the fused op by adding the following env:**
`export VLLM_ASCEND_FUSION_OP_TRANSPOSE_KV_CACHE_BY_BLOCK=0`
### How was this patch tested?
- vLLM version: v0.14.1
- vLLM main:
dc917cceb8
---------
Signed-off-by: lidenghui <lidenghui1110@gmail.com>
190 lines
7.4 KiB
C++
190 lines
7.4 KiB
C++
#include "common.h"
|
|
|
|
// Fused AscendC kernel that transposes the layout of selected KV-cache blocks
// in place in GM with a single kernel launch (replacing the per-layer
// npu_paged_cache_load + transpose + _npu_reshape_and_cache sequence).
//
// Per cache block the source data is laid out slice-major
// ([splitNum_, blockSize_, headNumSplited_ * headDim_], slices srcFactor_
// elements apart) and the destination is token-major
// ([blockSize_, headNum_ * headDim_]); the interleave is performed by the
// strided DataCopy gather in CopyIn().
//
// Template parameters:
//   T  - element type of the KV cache.
//   DB - number of queue buffers passed to InitBuffer (2 => double buffering).
//   needHandleUnFactorSplit - compile-time switch: when true, Process()
//       handles a blockSize_ that does not divide evenly across the
//       blockSizeSplitNum_ cores cooperating on one block.
template <typename T, uint32_t DB, bool needHandleUnFactorSplit>
class TransposeKvCacheByBlockKernelGeneral {
protected:
    // One queue bound VECIN -> VECOUT: the same UB tensor is filled by
    // CopyIn() (GM -> UB) and drained by CopyOut() (UB -> GM); the kernel is
    // pure data movement, no vector compute.
    TQueBind<TPosition::VECIN, TPosition::VECOUT, 1> queBind_;
    GlobalTensor<T> kCacheGm_;          // key cache of the current layer (rebound per layer)
    GlobalTensor<T> vCacheGm_;          // value cache of the current layer (rebound per layer)
    GlobalTensor<int64_t> blockIDsGm_;  // ids of the cache blocks to transpose

    GM_ADDR kCachePtr_;  // base address GetTensorAddr uses to locate each layer's K cache
    GM_ADDR vCachePtr_;  // base address GetTensorAddr uses to locate each layer's V cache

    // shape info
    uint32_t blockNum_;           // NOTE(review): never assigned nor read in this class — confirm removable
    uint32_t blockSize_;          // tokens per cache block
    uint32_t headNum_;            // total KV heads per token
    uint32_t headDim_;            // elements per head
    uint32_t splitNum_;           // number of head-group slices in the source layout
    uint32_t layerNum_;           // model layers; one K/V cache tensor pair per layer
    uint32_t headNumSplited_;     // headNum_ / splitNum_: heads per slice
    uint32_t blockSizeSplitNum_;  // cores cooperating on one (block, layer) work unit

    // tiling info
    uint32_t useCoreNum_;
    uint32_t blockPerCore_;  // (block, layer) work units per core group
    uint32_t tailCoreNum_;   // the first tailCoreNum_ core groups take one extra unit
    uint32_t calBlockNum_;

    uint32_t srcFactor_;      // elements between consecutive source slices: blockSize_ * headNumSplited_ * headDim_
    uint32_t dstFactor_;      // elements between consecutive dest slices in UB: headNumSplited_ * headDim_
    uint32_t copyOutLength_;  // elements written back per CopyOut (shrunk for tail cores in Process)

    uint32_t blockSizePerTime_;      // token rows a core handles per round
    uint32_t blockSizePerTimeTail_;  // cores with blockSizeIndex below this keep the full row count;
                                     // the rest drop one row (only when needHandleUnFactorSplit)

    uint32_t blockIdx_;       // this core's physical index (GetBlockIdx)
    uint32_t dataBlockSize_;  // bytes per hardware data block (GetDataBlockSizeInBytes)
    bool needSync_;           // set by Caloffset(): core must issue compensating flag pairs

    // Gather one tile from GM into UB, interleaving the splitNum_ slices:
    // source slice i starts srcFactor_ * i past offsetBlock, destination
    // slice i starts at i * dstFactor_; repeatParams' dstStride leaves room so
    // that, after all slices land, the UB tile is token-major with every head
    // group of a token contiguous.
    __aicore__ inline void CopyIn(GlobalTensor<T> &cacheGm, uint32_t offsetBlock, DataCopyParams &repeatParams) {
        LocalTensor<T> cacheLocal = queBind_.AllocTensor<T>();
        for (uint32_t i = 0; i < splitNum_; ++i) {
            DataCopy(cacheLocal[i * dstFactor_], cacheGm[i * srcFactor_ + offsetBlock], repeatParams);
        }
        queBind_.EnQue(cacheLocal);
    }

    // Write the transposed tile back to GM at offsetBlock. The transpose is
    // in place and the src/dst element ranges of the cooperating cores
    // overlap, so a set+wait of cross-core flag 0x8 on PIPE_MTE2 is issued
    // before the write — NOTE(review): presumably this ensures all cores have
    // finished their GM reads for this round before any core overwrites the
    // shared region; confirm against the AscendC cross-core flag semantics.
    __aicore__ inline void CopyOut(GlobalTensor<T> &cacheGm, uint32_t offsetBlock) {
        LocalTensor<T> cacheLocal = queBind_.DeQue<T>();
        AscendC::CrossCoreSetFlag<0x0, PIPE_MTE2>(0x8);
        AscendC::CrossCoreWaitFlag(0x8);
        DataCopy(cacheGm[offsetBlock], cacheLocal, copyOutLength_);
        queBind_.FreeTensor(cacheLocal);
    }

    // Rebind the K/V global tensors to the cache of the given layer.
    __aicore__ inline void SetGlobalBuffers(uint32_t layerId) {
        kCacheGm_.SetGlobalBuffer(GetTensorAddr<T>(layerId, kCachePtr_));
        vCacheGm_.SetGlobalBuffer(GetTensorAddr<T>(layerId, vCachePtr_));
    }

    // Compute this core's work range in flat (block, layer) units, where unit
    // index = block * layerNum_ + layer. All blockSizeSplitNum_ cores sharing
    // one groupBlockIdx process the same units (each taking a different token
    // slice of the block). The first tailCoreNum_ groups take one extra unit;
    // the remaining groups set needSync_ so Process() issues two compensating
    // cross-core flag pairs at the end — exactly matching the two CopyOut sync
    // pairs (K and V) the extra unit costs — keeping every core's total count
    // of CrossCoreSetFlag/WaitFlag calls identical.
    __aicore__ inline void Caloffset(uint32_t &startBlock, uint32_t &endBlock, uint32_t &startLayer, uint32_t &endLayer) {
        uint32_t curBlockStart;
        uint32_t curBlocknum;
        uint32_t groupBlockIdx = blockIdx_ / blockSizeSplitNum_;
        if (groupBlockIdx < tailCoreNum_) {
            needSync_ = false;
            curBlockStart = groupBlockIdx * (blockPerCore_ + 1);
            curBlocknum = blockPerCore_ + 1;
        } else {
            needSync_ = true;
            curBlockStart = groupBlockIdx * blockPerCore_ + tailCoreNum_;
            curBlocknum = blockPerCore_;
        }
        uint32_t curBlockEnd = curBlockStart + curBlocknum;
        // Convert the flat unit range back to (block, layer) coordinates.
        startBlock = curBlockStart / layerNum_;
        startLayer = curBlockStart % layerNum_;
        endBlock = (curBlockEnd + layerNum_ - 1) / layerNum_;
        endLayer = curBlockEnd % layerNum_;
        if (endLayer == 0) {
            // Range ends exactly on a block boundary: the last block is
            // processed through its final layer.
            endLayer = layerNum_;
        }
    }

public:
    // Cache the GM pointers, copy shape/tiling parameters from the host-side
    // tiling data, allocate the DB-way UB queue, and derive the copy strides.
    __aicore__ inline void Init(GM_ADDR KCache, GM_ADDR VCache, GM_ADDR blockIDs,
                                TransposeKvCacheByBlockTilingData* tilingData, TPipe* tPipe) {
        kCachePtr_ = KCache;
        vCachePtr_ = VCache;
        blockIDsGm_.SetGlobalBuffer((__gm__ int64_t*)blockIDs);
        blockIdx_ = GetBlockIdx();
        // shape info
        blockSize_ = tilingData->blockSize;
        headNum_ = tilingData->headNum;
        headDim_ = tilingData->headDim;
        splitNum_ = tilingData->splitNum;
        layerNum_ = tilingData->layerNum;
        // tiling info
        useCoreNum_ = tilingData->useCoreNum;
        blockPerCore_ = tilingData->blockPerCore;
        tailCoreNum_ = tilingData->tailCoreNum;
        calBlockNum_ = tilingData->calBlockNum;
        blockSizeSplitNum_ = tilingData->blockSizeSplitNum;
        blockSizePerTime_ = tilingData->blockSizePerTime;
        blockSizePerTimeTail_ = tilingData->blockSizePerTimeTail;
        headNumSplited_ = headNum_ / splitNum_;

        // Split the whole UB budget across DB queue buffers.
        tPipe->InitBuffer(queBind_, DB, TOTAL_UB_SIZE / DB);
        srcFactor_ = blockSize_ * headNumSplited_ * headDim_;
        dstFactor_ = headNumSplited_ * headDim_;
        copyOutLength_ = blockSizePerTime_ * headNum_ * headDim_;
        dataBlockSize_ = static_cast<uint32_t>(AscendC::GetDataBlockSizeInBytes());
    }

    // Main loop: for each (block, layer) unit assigned to this core, gather
    // the K then V tile into UB with the strided copy and write it back
    // token-major, transposing the block in place.
    __aicore__ inline void Process() {
        // One DataCopy moves blockSizePerTime_ rows of one slice: each row is
        // headNumSplited_ * headDim_ elements, read contiguously (srcStride 0)
        // and written with a gap of the other slices' rows (dstStride) so the
        // slices interleave in UB. Lengths are in hardware data-block units.
        DataCopyParams repeatParams;
        repeatParams.blockCount = blockSizePerTime_;
        repeatParams.blockLen = headNumSplited_ * headDim_ * sizeof(T) / dataBlockSize_;
        repeatParams.srcStride = 0;
        repeatParams.dstStride = (headNum_ * headDim_ - headNumSplited_ * headDim_) * sizeof(T) / dataBlockSize_;

        uint32_t startBlock;
        uint32_t endBlock;
        uint32_t startLayer;
        uint32_t endLayer;

        Caloffset(startBlock, endBlock, startLayer, endLayer);

        for (uint32_t i = startBlock; i < endBlock; ++i) {
            int64_t blockId = blockIDsGm_.GetValue(i);
            // NOTE(review): element offset is truncated to uint32_t — this
            // overflows for caches beyond 2^32 elements per layer; confirm
            // the tiling guarantees the bound.
            uint32_t offsetBlock = blockId * blockSize_ * headNum_ * headDim_;
            uint32_t realStartLayer;
            uint32_t realEndLayer;
            // Only the first/last block of the range uses the partial layer
            // bounds; every block in between covers all layers.
            if (i == startBlock) {
                realStartLayer = startLayer;
            } else {
                realStartLayer = 0;
            }

            if (i == (endBlock - 1)) {
                realEndLayer = endLayer;
            } else {
                realEndLayer = layerNum_;
            }
            for (uint32_t layerId = realStartLayer; layerId < realEndLayer; ++layerId) {
                SetGlobalBuffers(layerId);
                // Which token slice of the block this core owns.
                uint32_t blockSizeIndex = blockIdx_ % blockSizeSplitNum_;
                uint32_t srcOffset;
                uint32_t dstOffset;
                if constexpr (needHandleUnFactorSplit) {
                    // handle tail
                    if (blockSizeIndex >= blockSizePerTimeTail_) {
                        // Cores past the tail boundary process one row fewer;
                        // offsets shift back by the rows already absorbed by
                        // the preceding shortened cores.
                        repeatParams.blockCount = (blockSizePerTime_ - 1);
                        copyOutLength_ = (blockSizePerTime_ - 1) * headNum_ * headDim_;
                        srcOffset = (blockSizeIndex * blockSizePerTime_ - (blockSizeIndex - blockSizePerTimeTail_)) * headNumSplited_ * headDim_;
                        dstOffset = (blockSizeIndex * blockSizePerTime_ - (blockSizeIndex - blockSizePerTimeTail_)) * headNum_ * headDim_;
                    } else {
                        repeatParams.blockCount = blockSizePerTime_;
                        copyOutLength_ = blockSizePerTime_ * headNum_ * headDim_;
                        srcOffset = blockSizeIndex * blockSizePerTime_ * headNumSplited_ * headDim_;
                        dstOffset = blockSizeIndex * blockSizePerTime_ * headNum_ * headDim_;
                    }
                } else {
                    // Even split: every core takes a full blockSizePerTime_ rows.
                    repeatParams.blockCount = blockSizePerTime_;
                    copyOutLength_ = blockSizePerTime_ * headNum_ * headDim_;
                    srcOffset = blockSizeIndex * blockSizePerTime_ * headNumSplited_ * headDim_;
                    dstOffset = blockSizeIndex * blockSizePerTime_ * headNum_ * headDim_;
                }

                // Transpose K, then V, for this (block, layer).
                CopyIn(kCacheGm_, offsetBlock + srcOffset, repeatParams);
                CopyOut(kCacheGm_, offsetBlock + dstOffset);

                CopyIn(vCacheGm_, offsetBlock + srcOffset, repeatParams);
                CopyOut(vCacheGm_, offsetBlock + dstOffset);
            }

        }

        // Cores that processed one unit fewer issue the two flag pairs the
        // missing unit would have produced, so every core reaches the same
        // CrossCoreSetFlag/WaitFlag count and no core deadlocks waiting.
        if (needSync_) {
            AscendC::CrossCoreSetFlag<0x0, PIPE_MTE2>(0x8);
            AscendC::CrossCoreWaitFlag(0x8);

            AscendC::CrossCoreSetFlag<0x0, PIPE_MTE2>(0x8);
            AscendC::CrossCoreWaitFlag(0x8);
        }

    }
};