[Kernel] Add AscendC fused op transpose_kv_cache_by_block to speed up GQA transfer (#6366)
### What this PR does / why we need it?
As #2947 describe, we need to transpose kv cache layout after GQA kv
transfer when prefill and decode tensor parallel size are heterogeneous,
in the previous implementation, we use `npu_paged_cache_load ` +
`tranpose` + `_npu_reshape_and_cache` to do this work.
But obviously, it is not an efficient plan, the ops above need to be
called for each layer, which introduces 3 * layer_num kernel launch, and
6 * layer_num data movement between L1 Cache and HBM for one request on
decode node. Usually, decode node uses graph mode, so these op kernels
will be called between decode forward launched by an async thread in
mooncacke connector, this kernels maybe last for several decode forward
and TTFT will increase by 3~4 decode forward time.
In this PR, we implement an AscendC fused op
`transpose_kv_cache_by_block` to do this with only once kernel launch
and move data between L1 Cache and HBM only once.
After using this fused op, the time cost in transpose kv cacke layout
can be decreased to 0.24ms from 7ms in UT on 910C, and in PD
disaggregation scenario, TTFT can decrease about 90 ~ 110 ms in
qwen3-235B.
| request_num | original | fused_op|
|:----------------------:|:---------------:|:-------------------:|
| 1 | 643 ms | 578 ms |
| 128 | 1480 ms | 1368 ms |
### Does this PR introduce _any_ user-facing change?
Use fused op by default, incase the op has bug in any scenario, provide
fallback choice using env to disable it.
**DISABLE fused op by add following env**
`export VLLM_ASCEND_FUSION_OP_TRANSPOSE_KV_CACHE_BY_BLOCK=0`
### How was this patch tested?
- vLLM version: v0.14.1
- vLLM main:
dc917cceb8
---------
Signed-off-by: lidenghui <lidenghui1110@gmail.com>
This commit is contained in:
@@ -24,7 +24,7 @@ elif [[ "$SOC_VERSION" =~ ^ascend910b ]]; then
|
||||
ABSOLUTE_CATLASS_PATH=$(cd "${CATLASS_PATH}" && pwd)
|
||||
export CPATH=${ABSOLUTE_CATLASS_PATH}:${CPATH}
|
||||
|
||||
CUSTOM_OPS="grouped_matmul_swiglu_quant_weight_nz_tensor_list;lightning_indexer;sparse_flash_attention;matmul_allreduce_add_rmsnorm;moe_init_routing_custom;moe_gating_top_k;add_rms_norm_bias;apply_top_k_top_p_custom;"
|
||||
CUSTOM_OPS="grouped_matmul_swiglu_quant_weight_nz_tensor_list;lightning_indexer;sparse_flash_attention;matmul_allreduce_add_rmsnorm;moe_init_routing_custom;moe_gating_top_k;add_rms_norm_bias;apply_top_k_top_p_custom;transpose_kv_cache_by_block;"
|
||||
SOC_ARG="ascend910b"
|
||||
elif [[ "$SOC_VERSION" =~ ^ascend910_93 ]]; then
|
||||
# ASCEND910C (A3) series
|
||||
@@ -81,6 +81,7 @@ elif [[ "$SOC_VERSION" =~ ^ascend910_93 ]]; then
|
||||
"moe_gating_top_k"
|
||||
"add_rms_norm_bias"
|
||||
"apply_top_k_top_p_custom"
|
||||
"transpose_kv_cache_by_block"
|
||||
)
|
||||
CUSTOM_OPS=$(IFS=';'; echo "${CUSTOM_OPS_ARRAY[*]}")
|
||||
SOC_ARG="ascend910_93"
|
||||
|
||||
@@ -1343,6 +1343,22 @@ std::tuple<at::Tensor,at::Tensor, at::Tensor> npu_add_rms_norm_bias(
|
||||
return std::tuple<at::Tensor, at::Tensor, at::Tensor>(y, rstd, x);
|
||||
}
|
||||
|
||||
void transpose_kv_cache_by_block(
|
||||
const at::TensorList &kCache,
|
||||
const at::TensorList &vCache,
|
||||
const at::Tensor &blockIDs,
|
||||
int64_t blockSize,
|
||||
int64_t headNum,
|
||||
int64_t headDim,
|
||||
int64_t splitNum,
|
||||
int64_t layerNum)
|
||||
{
|
||||
|
||||
EXEC_NPU_CMD(aclnnTransposeKvCacheByBlock, kCache, vCache, blockIDs,
|
||||
blockSize, headNum, headDim, splitNum, layerNum);
|
||||
|
||||
}
|
||||
|
||||
} // namespace vllm_ascend
|
||||
|
||||
TORCH_LIBRARY_EXPAND(CONCAT(_C, _ascend), ops)
|
||||
@@ -1521,4 +1537,8 @@ TORCH_LIBRARY_EXPAND(CONCAT(_C, _ascend), ops)
|
||||
|
||||
ops.def("npu_apply_top_k_top_p(Tensor logits, Tensor? p=None, Tensor? k=None) -> Tensor");
|
||||
ops.impl("npu_apply_top_k_top_p", torch::kPrivateUse1, &vllm_ascend::npu_apply_top_k_top_p);
|
||||
ops.def(
|
||||
"transpose_kv_cache_by_block(Tensor[] kCache, Tensor[] vCache, Tensor blockIDs, int blockSize, int headNum, int headDim, int splitNum, int layerNum) -> ()"
|
||||
);
|
||||
ops.impl("transpose_kv_cache_by_block", torch::kPrivateUse1, &vllm_ascend::transpose_kv_cache_by_block);
|
||||
}
|
||||
|
||||
@@ -435,6 +435,20 @@ std::tuple<at::Tensor,at::Tensor, at::Tensor> npu_add_rms_norm_bias_meta(
|
||||
at::Tensor x = at::empty_symint(x1.sym_sizes(), x1.options());
|
||||
return std::tuple<at::Tensor, at::Tensor, at::Tensor>(y, rstd, x);
|
||||
}
|
||||
|
||||
void transpose_kv_cache_by_block_meta(
|
||||
const at::TensorList &k_cache,
|
||||
const at::TensorList &v_cache,
|
||||
const at::Tensor &block_ids,
|
||||
int64_t block_size,
|
||||
int64_t head_num,
|
||||
int64_t head_dim,
|
||||
int64_t split_num,
|
||||
int64_t layer_num)
|
||||
{
|
||||
return;
|
||||
}
|
||||
|
||||
} // namespace meta
|
||||
} // namespace vllm_ascend
|
||||
|
||||
@@ -475,5 +489,7 @@ TORCH_LIBRARY_IMPL_EXPAND(CONCAT(_C, _ascend), Meta, ops) {
|
||||
ops.impl("moe_gating_top_k", &vllm_ascend::meta::moe_gating_top_k_meta);
|
||||
// Add_Rms_Norm_Bias
|
||||
ops.impl("npu_add_rms_norm_bias", &vllm_ascend::meta::npu_add_rms_norm_bias_meta);
|
||||
// transpose_kv_cache_by_block
|
||||
ops.impl("transpose_kv_cache_by_block", &vllm_ascend::meta::transpose_kv_cache_by_block_meta);
|
||||
}
|
||||
}
|
||||
|
||||
40
csrc/transpose_kv_cache_by_block/op_host/CMakeLists.txt
Normal file
40
csrc/transpose_kv_cache_by_block/op_host/CMakeLists.txt
Normal file
@@ -0,0 +1,40 @@
|
||||
# This program is free software, you can redistribute it and/or modify it.
|
||||
# Copyright (c) 2025 Huawei Technologies Co., Ltd.
|
||||
# This file is a part of the CANN Open Software.
|
||||
# Licensed under CANN Open Software License Agreement Version 2.0 (the "License").
|
||||
# Please refer to the License for details. You may not use this file except in compliance with the License.
|
||||
# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE.
|
||||
# See LICENSE in the root of the software repository for the full text of the License.
|
||||
# ======================================================================================================================
|
||||
|
||||
add_ops_compile_options(
|
||||
OP_NAME TransposeKvCacheByBlock
|
||||
OPTIONS --cce-auto-sync=off
|
||||
-Wno-deprecated-declarations
|
||||
-Werror
|
||||
-mllvm -cce-aicore-hoist-movemask=false
|
||||
--op_relocatable_kernel_binary=true
|
||||
)
|
||||
|
||||
target_sources(op_host_aclnn PRIVATE
|
||||
transpose_kv_cache_by_block_def.cpp
|
||||
)
|
||||
|
||||
target_sources(optiling PRIVATE
|
||||
transpose_kv_cache_by_block_tiling.cpp
|
||||
)
|
||||
|
||||
if (NOT BUILD_OPEN_PROJECT)
|
||||
target_sources(opmaster_ct PRIVATE
|
||||
transpose_kv_cache_by_block_tiling.cpp
|
||||
)
|
||||
endif ()
|
||||
|
||||
target_include_directories(optiling PRIVATE
|
||||
${CMAKE_CURRENT_SOURCE_DIR}
|
||||
)
|
||||
|
||||
target_sources(opsproto PRIVATE
|
||||
transpose_kv_cache_by_block_proto.cpp
|
||||
)
|
||||
|
||||
@@ -0,0 +1,36 @@
|
||||
#include "register/op_def_registry.h"
|
||||
|
||||
namespace ops {
|
||||
class TransposeKvCacheByBlock : public OpDef {
|
||||
public:
|
||||
explicit TransposeKvCacheByBlock(const char* name) : OpDef(name)
|
||||
{
|
||||
this->Input("KCache")
|
||||
.ParamType(DYNAMIC)
|
||||
.DataType({ge::DT_FLOAT16, ge::DT_BF16})
|
||||
.Format({ge::FORMAT_ND, ge::FORMAT_ND})
|
||||
.UnknownShapeFormat({ge::FORMAT_ND, ge::FORMAT_ND});
|
||||
this->Input("VCache")
|
||||
.ParamType(DYNAMIC)
|
||||
.DataType({ge::DT_FLOAT16, ge::DT_BF16})
|
||||
.Format({ge::FORMAT_ND, ge::FORMAT_ND})
|
||||
.UnknownShapeFormat({ge::FORMAT_ND, ge::FORMAT_ND});
|
||||
this->Input("blockIDs")
|
||||
.ParamType(REQUIRED)
|
||||
.DataTypeList({ge::DT_INT64})
|
||||
.FormatList({ge::FORMAT_ND})
|
||||
.UnknownShapeFormat({ge::FORMAT_ND});
|
||||
this->Attr("blockSize").Int();
|
||||
this->Attr("headNum").Int();
|
||||
this->Attr("headDim").Int();
|
||||
this->Attr("splitNum").Int();
|
||||
this->Attr("layerNum").Int();
|
||||
|
||||
this->AICore().AddConfig("ascend910b");
|
||||
this->AICore().AddConfig("ascend910_93");
|
||||
|
||||
}
|
||||
};
|
||||
|
||||
OP_ADD(TransposeKvCacheByBlock);
|
||||
}
|
||||
@@ -0,0 +1,36 @@
|
||||
/**
|
||||
* This program is free software, you can redistribute it and/or modify it.
|
||||
* Copyright (c) 2025 Huawei Technologies Co., Ltd.
|
||||
* This file is a part of the CANN Open Software.
|
||||
* Licensed under CANN Open Software License Agreement Version 2.0 (the "License").
|
||||
* Please refer to the License for details. You may not use this file except in compliance with the License.
|
||||
* THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE.
|
||||
* See LICENSE in the root of the software repository for the full text of the License.
|
||||
*/
|
||||
|
||||
/*!
|
||||
* \file transpose_kv_cache_by_block_proto.cpp
|
||||
* \brief
|
||||
*/
|
||||
#include <graph/utils/type_utils.h>
|
||||
#include <register/op_impl_registry.h>
|
||||
#include "error/ops_error.h"
|
||||
|
||||
using namespace ge;
|
||||
|
||||
namespace ops {
|
||||
|
||||
static ge::graphStatus InferShapeTransposeKvCacheByBlock(gert::InferShapeContext* context)
|
||||
{
|
||||
return ge::GRAPH_SUCCESS;
|
||||
}
|
||||
|
||||
static ge::graphStatus InferDataTypeTransposeKvCacheByBlock(gert::InferDataTypeContext *context)
|
||||
{
|
||||
return ge::GRAPH_SUCCESS;
|
||||
}
|
||||
|
||||
IMPL_OP_INFERSHAPE(TransposeKvCacheByBlock)
|
||||
.InferShape(InferShapeTransposeKvCacheByBlock)
|
||||
.InferDataType(InferDataTypeTransposeKvCacheByBlock);
|
||||
} // namespace ops
|
||||
@@ -0,0 +1,182 @@
|
||||
#include "transpose_kv_cache_by_block_tiling.h"
|
||||
#include "register/op_def_registry.h"
|
||||
#include "tiling/platform/platform_ascendc.h"
|
||||
#include "log/ops_log.h"
|
||||
#include <algorithm>
|
||||
|
||||
namespace optiling {
|
||||
|
||||
constexpr uint64_t DATA_SIZE = 2;
|
||||
constexpr uint64_t BLOCK_SIZE = 32;
|
||||
constexpr uint64_t DB_ON = 2;
|
||||
|
||||
constexpr uint32_t FULL_LOAD = 0;
|
||||
constexpr uint32_t SPLIT_BLOCK_SIZE_ALIGNED_AND_DB = 1;
|
||||
constexpr uint32_t SPLIT_BLOCK_SIZE_UNALIGNED_AND_DB = 3;
|
||||
constexpr uint32_t SPLIT_BLOCK_SIZE_ALIGNED_AND_NOT_DB = 2;
|
||||
constexpr uint32_t SPLIT_BLOCK_SIZE_UNALIGNED_AND_NOT_DB = 4;
|
||||
|
||||
void findFactorsOptimized(std::vector<int64_t> &factors, int64_t n) {
|
||||
|
||||
for (int64_t i = 1; i * i <= n; i++) {
|
||||
if (n % i == 0) {
|
||||
factors.push_back(i);
|
||||
|
||||
if (i != n / i) {
|
||||
factors.push_back(n / i);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
sort(factors.begin(), factors.end());
|
||||
}
|
||||
|
||||
ge::graphStatus CalTiling(gert::TilingContext* context, TransposeKvCacheByBlockTilingData &tiling)
|
||||
{
|
||||
fe::PlatFormInfos* platformInfoPtr = context->GetPlatformInfo();
|
||||
OPS_LOG_E_IF_NULL(context, platformInfoPtr, return ge::GRAPH_FAILED);
|
||||
auto ascendcPlatform = platform_ascendc::PlatformAscendC(platformInfoPtr);
|
||||
int64_t useCoreNum = ascendcPlatform.GetCoreNumAiv();
|
||||
uint64_t ubSize;
|
||||
ascendcPlatform.GetCoreMemSize(platform_ascendc::CoreMemType::UB, ubSize);
|
||||
|
||||
auto attr = context->GetAttrs();
|
||||
OPS_LOG_E_IF_NULL(context, attr, return ge::GRAPH_FAILED);
|
||||
const int64_t* blockSizePtr = attr->GetAttrPointer<int64_t>(0);
|
||||
const int64_t* headNumPtr = attr->GetAttrPointer<int64_t>(1);
|
||||
const int64_t* headDimPtr = attr->GetAttrPointer<int64_t>(2);
|
||||
const int64_t* splitNumPtr = attr->GetAttrPointer<int64_t>(3);
|
||||
const int64_t* layerNumPtr = attr->GetAttrPointer<int64_t>(4);
|
||||
OPS_CHECK(blockSizePtr == nullptr || headNumPtr == nullptr || headDimPtr == nullptr ||
|
||||
splitNumPtr == nullptr || layerNumPtr == nullptr,
|
||||
OPS_LOG_E(context->GetNodeName(), "Get attr failed."),
|
||||
return ge::GRAPH_FAILED);
|
||||
|
||||
auto blockIDsTensor = context->GetDynamicInputTensor(2, 0);
|
||||
OPS_LOG_E_IF_NULL(context, blockIDsTensor, return ge::GRAPH_FAILED);
|
||||
|
||||
gert::Shape blockIDsTensorShape = blockIDsTensor->GetStorageShape();
|
||||
int64_t calBlockNum = static_cast<int64_t>(blockIDsTensorShape.GetDim(0));
|
||||
|
||||
tiling.set_calBlockNum(static_cast<uint32_t>(calBlockNum));
|
||||
|
||||
int64_t blockSize = *blockSizePtr;
|
||||
int64_t headNum = *headNumPtr;
|
||||
int64_t headDim = *headDimPtr;
|
||||
int64_t splitNum = *splitNumPtr;
|
||||
int64_t layerNum = *layerNumPtr;
|
||||
uint32_t tilingKey = FULL_LOAD;
|
||||
|
||||
if (headDim * DATA_SIZE % BLOCK_SIZE != 0) {
|
||||
OPS_LOG_E(context, "headDim * DATA_SIZE must be a multiple of 32 bytes.");
|
||||
return ge::GRAPH_FAILED;
|
||||
}
|
||||
|
||||
std::vector<int64_t> factors;
|
||||
findFactorsOptimized(factors, useCoreNum);
|
||||
|
||||
uint32_t factorIndex = 0;
|
||||
bool findSplitNum = true;
|
||||
int64_t blockSizeSplitNum = factors[factorIndex];
|
||||
uint64_t dataSizeloadOnce = blockSize * headNum * headDim * DATA_SIZE;
|
||||
// if can full load, not split blockSize and db
|
||||
if (dataSizeloadOnce > ubSize) {
|
||||
tilingKey = SPLIT_BLOCK_SIZE_ALIGNED_AND_DB;
|
||||
// split blockSize and db
|
||||
while (dataSizeloadOnce > (ubSize / DB_ON)) {
|
||||
factorIndex += 1;
|
||||
if (factorIndex == factors.size()) {
|
||||
tilingKey = FULL_LOAD;
|
||||
findSplitNum = false;
|
||||
break;
|
||||
}
|
||||
blockSizeSplitNum = factors[factorIndex];
|
||||
dataSizeloadOnce = ((blockSize + blockSizeSplitNum - 1) / blockSizeSplitNum) * headNum * headDim * DATA_SIZE;
|
||||
}
|
||||
if (tilingKey == SPLIT_BLOCK_SIZE_ALIGNED_AND_DB && (blockSize % blockSizeSplitNum != 0)) {
|
||||
tilingKey = SPLIT_BLOCK_SIZE_UNALIGNED_AND_DB;
|
||||
}
|
||||
}
|
||||
|
||||
if (!findSplitNum) {
|
||||
tilingKey = SPLIT_BLOCK_SIZE_ALIGNED_AND_NOT_DB;
|
||||
// split blockSize but not db
|
||||
findSplitNum = true;
|
||||
factorIndex = 0;
|
||||
blockSizeSplitNum = factors[factorIndex];
|
||||
dataSizeloadOnce = blockSize * headNum * headDim * DATA_SIZE;
|
||||
while (dataSizeloadOnce > ubSize) {
|
||||
factorIndex += 1;
|
||||
if (factorIndex == factors.size()) {
|
||||
tilingKey = FULL_LOAD;
|
||||
findSplitNum = false;
|
||||
break;
|
||||
}
|
||||
blockSizeSplitNum = factors[factorIndex];
|
||||
dataSizeloadOnce = ((blockSize + blockSizeSplitNum - 1) / blockSizeSplitNum) * headNum * headDim * DATA_SIZE;
|
||||
}
|
||||
if (tilingKey == SPLIT_BLOCK_SIZE_ALIGNED_AND_NOT_DB && (blockSize % blockSizeSplitNum != 0)) {
|
||||
tilingKey = SPLIT_BLOCK_SIZE_UNALIGNED_AND_NOT_DB;
|
||||
}
|
||||
}
|
||||
|
||||
// headNum * headDim too large
|
||||
if (!findSplitNum) {
|
||||
OPS_LOG_E(context, "headNum * headDim * sizeof(half) > ubSize "
|
||||
"or blockSize * headNum * headDim * sizeof(half) > ubSize * vectorCoreNum. "
|
||||
"Currently, splitting headNum or headDim is not supported.");
|
||||
return ge::GRAPH_FAILED;
|
||||
}
|
||||
tiling.set_blockSizePerTime(static_cast<uint32_t>((blockSize + blockSizeSplitNum - 1) / blockSizeSplitNum));
|
||||
tiling.set_blockSizePerTimeTail(static_cast<uint32_t>(blockSize % blockSizeSplitNum));
|
||||
tiling.set_blockSizeSplitNum(static_cast<uint32_t>(blockSizeSplitNum));
|
||||
|
||||
tiling.set_blockSize(static_cast<uint32_t>(blockSize));
|
||||
tiling.set_headNum(static_cast<uint32_t>(headNum));
|
||||
tiling.set_headDim(static_cast<uint32_t>(headDim));
|
||||
tiling.set_splitNum(static_cast<uint32_t>(splitNum));
|
||||
tiling.set_layerNum(static_cast<uint32_t>(layerNum));
|
||||
|
||||
int64_t totalRound = layerNum * calBlockNum;
|
||||
|
||||
if ((totalRound * blockSizeSplitNum) < useCoreNum) {
|
||||
useCoreNum = totalRound * blockSizeSplitNum;
|
||||
}
|
||||
int64_t blockPerCore = totalRound / (useCoreNum / blockSizeSplitNum);
|
||||
int64_t tailCoreNum = totalRound % (useCoreNum / blockSizeSplitNum);
|
||||
|
||||
tiling.set_useCoreNum(static_cast<uint32_t>(useCoreNum));
|
||||
tiling.set_blockPerCore(static_cast<uint32_t>(blockPerCore));
|
||||
tiling.set_tailCoreNum(static_cast<uint32_t>(tailCoreNum));
|
||||
context->SetBlockDim(useCoreNum);
|
||||
context->SetTilingKey(tilingKey);
|
||||
|
||||
return ge::GRAPH_SUCCESS;
|
||||
}
|
||||
|
||||
|
||||
static ge::graphStatus TransposeKvCacheByBlockTilingFunc(gert::TilingContext* context)
|
||||
{
|
||||
|
||||
TransposeKvCacheByBlockTilingData tiling;
|
||||
auto status = CalTiling(context, tiling);
|
||||
OP_CHECK(status != ge::GRAPH_SUCCESS, OPS_LOG_E(context->GetNodeName(), "Cal tiling failed."),
|
||||
return ge::GRAPH_FAILED);
|
||||
|
||||
tiling.SaveToBuffer(context->GetRawTilingData()->GetData(), context->GetRawTilingData()->GetCapacity());
|
||||
context->GetRawTilingData()->SetDataSize(tiling.GetDataSize());
|
||||
|
||||
return ge::GRAPH_SUCCESS;
|
||||
}
|
||||
|
||||
struct TransposeKvCacheByBlockCompileInfo {};
|
||||
ge::graphStatus TilingParseForTransposeKvCacheByBlock(gert::TilingParseContext *context)
|
||||
{
|
||||
(void)context;
|
||||
return ge::GRAPH_SUCCESS;
|
||||
}
|
||||
|
||||
IMPL_OP_OPTILING(TransposeKvCacheByBlock)
|
||||
.Tiling(TransposeKvCacheByBlockTilingFunc)
|
||||
.TilingParse<TransposeKvCacheByBlockCompileInfo>(TilingParseForTransposeKvCacheByBlock);
|
||||
}
|
||||
@@ -0,0 +1,23 @@
|
||||
#include "register/tilingdata_base.h"
|
||||
|
||||
namespace optiling {
|
||||
BEGIN_TILING_DATA_DEF(TransposeKvCacheByBlockTilingData)
|
||||
// shape info
|
||||
// TILING_DATA_FIELD_DEF(uint32_t, blockNum);
|
||||
TILING_DATA_FIELD_DEF(uint32_t, blockSize);
|
||||
TILING_DATA_FIELD_DEF(uint32_t, headNum);
|
||||
TILING_DATA_FIELD_DEF(uint32_t, headDim);
|
||||
TILING_DATA_FIELD_DEF(uint32_t, splitNum);
|
||||
TILING_DATA_FIELD_DEF(uint32_t, layerNum);
|
||||
// tiling info
|
||||
TILING_DATA_FIELD_DEF(uint32_t, useCoreNum);
|
||||
TILING_DATA_FIELD_DEF(uint32_t, blockPerCore);
|
||||
TILING_DATA_FIELD_DEF(uint32_t, tailCoreNum);
|
||||
TILING_DATA_FIELD_DEF(uint32_t, calBlockNum);
|
||||
TILING_DATA_FIELD_DEF(uint32_t, blockSizePerTime);
|
||||
TILING_DATA_FIELD_DEF(uint32_t, blockSizePerTimeTail);
|
||||
TILING_DATA_FIELD_DEF(uint32_t, blockSizeSplitNum);
|
||||
END_TILING_DATA_DEF;
|
||||
|
||||
REGISTER_TILING_DATA_CLASS(TransposeKvCacheByBlock, TransposeKvCacheByBlockTilingData)
|
||||
}
|
||||
16
csrc/transpose_kv_cache_by_block/op_kernel/common.h
Normal file
16
csrc/transpose_kv_cache_by_block/op_kernel/common.h
Normal file
@@ -0,0 +1,16 @@
|
||||
#include "kernel_operator.h"
|
||||
using namespace AscendC;
|
||||
|
||||
#ifndef __OP_KERNEL_KV_CACHE_TRANSPOSE_H__
|
||||
#define __OP_KERNEL_KV_CACHE_TRANSPOSE_H__
|
||||
|
||||
template <typename T>
|
||||
__aicore__ inline __gm__ T* GetTensorAddr(uint16_t index, GM_ADDR tensorPtr) {
|
||||
__gm__ uint64_t* dataAddr = reinterpret_cast<__gm__ uint64_t*>(tensorPtr);
|
||||
// The offset of the data address from the first address.
|
||||
uint64_t tensorPtrOffset = *dataAddr;
|
||||
// Moving 3 bits to the right means dividing by sizeof(uint64 t).
|
||||
__gm__ uint64_t* retPtr = dataAddr + (tensorPtrOffset >> 3);
|
||||
return reinterpret_cast<__gm__ T*>(*(retPtr + index));
|
||||
}
|
||||
#endif
|
||||
141
csrc/transpose_kv_cache_by_block/op_kernel/full_load.h
Normal file
141
csrc/transpose_kv_cache_by_block/op_kernel/full_load.h
Normal file
@@ -0,0 +1,141 @@
|
||||
#include "common.h"
|
||||
|
||||
template <typename T>
|
||||
class TransposeKvCacheByBlockKernelFullLoad {
|
||||
protected:
|
||||
TQueBind<TPosition::VECIN, TPosition::VECOUT, 1> vecInQueue_;
|
||||
GlobalTensor<T> kCacheGm_;
|
||||
GlobalTensor<T> vCacheGm_;
|
||||
GlobalTensor<int64_t> blockIDsGm_;
|
||||
|
||||
GM_ADDR kCachePtr_;
|
||||
GM_ADDR vCachePtr_;
|
||||
|
||||
// shape info
|
||||
uint32_t blockNum_;
|
||||
uint32_t blockSize_;
|
||||
uint32_t headNum_;
|
||||
uint32_t headDim_;
|
||||
uint32_t splitNum_;
|
||||
uint32_t layerNum_;
|
||||
// tiling info
|
||||
uint32_t useCoreNum_;
|
||||
uint32_t blockPerCore_;
|
||||
uint32_t tailCoreNum_;
|
||||
uint32_t calBlockNum_;
|
||||
|
||||
uint32_t srcFactor_;
|
||||
uint32_t dstFactor_;
|
||||
uint32_t copyOutLength_;
|
||||
uint32_t dataBlockSize_;
|
||||
|
||||
__aicore__ inline void CopyIn(GlobalTensor<T> &cacheGm, uint32_t offsetBlock, DataCopyParams &repeatParams) {
|
||||
LocalTensor<T> cacheLocal = vecInQueue_.AllocTensor<T>();
|
||||
for (uint32_t i = 0; i < splitNum_; ++i) {
|
||||
DataCopy(cacheLocal[i * dstFactor_], cacheGm[i * srcFactor_ + offsetBlock], repeatParams);
|
||||
}
|
||||
vecInQueue_.EnQue(cacheLocal);
|
||||
}
|
||||
|
||||
__aicore__ inline void CopyOut(GlobalTensor<T> &cacheGm, uint32_t offsetBlock) {
|
||||
LocalTensor<T> cacheLocal = vecInQueue_.DeQue<T>();
|
||||
DataCopy(cacheGm[offsetBlock], cacheLocal, copyOutLength_);
|
||||
vecInQueue_.FreeTensor(cacheLocal);
|
||||
}
|
||||
|
||||
__aicore__ inline void SetGlobalBuffers(uint32_t layerId) {
|
||||
kCacheGm_.SetGlobalBuffer(GetTensorAddr<T>(layerId, kCachePtr_));
|
||||
vCacheGm_.SetGlobalBuffer(GetTensorAddr<T>(layerId, vCachePtr_));
|
||||
}
|
||||
|
||||
__aicore__ inline void Caloffset(uint32_t &startBlock, uint32_t &endBlock, uint32_t &startLayer, uint32_t &endLayer) {
|
||||
uint32_t blockIdx = GetBlockIdx();
|
||||
uint32_t curBlockStart;
|
||||
uint32_t curBlocknum;
|
||||
|
||||
if (blockIdx < tailCoreNum_) {
|
||||
curBlockStart = blockIdx * (blockPerCore_ + 1);
|
||||
curBlocknum = blockPerCore_ + 1;
|
||||
} else {
|
||||
curBlockStart = blockIdx * blockPerCore_ + tailCoreNum_;
|
||||
curBlocknum = blockPerCore_;
|
||||
}
|
||||
uint32_t curBlockEnd = curBlockStart + curBlocknum;
|
||||
startBlock = curBlockStart / layerNum_;
|
||||
startLayer = curBlockStart % layerNum_;
|
||||
endBlock = (curBlockEnd + layerNum_ - 1) / layerNum_;
|
||||
endLayer = curBlockEnd % layerNum_;
|
||||
if (endLayer == 0) {
|
||||
endLayer = layerNum_;
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
public:
|
||||
__aicore__ inline void Init(GM_ADDR KCache, GM_ADDR VCache, GM_ADDR blockIDs,
|
||||
TransposeKvCacheByBlockTilingData* tilingData, TPipe* tPipe) {
|
||||
kCachePtr_ = KCache;
|
||||
vCachePtr_ = VCache;
|
||||
blockIDsGm_.SetGlobalBuffer((__gm__ int64_t*)blockIDs);
|
||||
|
||||
// shape info
|
||||
blockSize_ = tilingData->blockSize;
|
||||
headNum_ = tilingData->headNum;
|
||||
headDim_ = tilingData->headDim;
|
||||
splitNum_ = tilingData->splitNum;
|
||||
layerNum_ = tilingData->layerNum;
|
||||
// tiling info
|
||||
useCoreNum_ = tilingData->useCoreNum;
|
||||
blockPerCore_ = tilingData->blockPerCore;
|
||||
tailCoreNum_ = tilingData->tailCoreNum;
|
||||
calBlockNum_ = tilingData->calBlockNum;
|
||||
|
||||
tPipe->InitBuffer(vecInQueue_, 1, TOTAL_UB_SIZE);
|
||||
srcFactor_ = blockSize_ * headNum_ / splitNum_ * headDim_;
|
||||
dstFactor_ = headNum_ / splitNum_ * headDim_;
|
||||
copyOutLength_ = blockSize_ * headNum_ * headDim_;
|
||||
dataBlockSize_ = static_cast<uint32_t>(AscendC::GetDataBlockSizeInBytes());
|
||||
}
|
||||
|
||||
__aicore__ inline void Process() {
|
||||
DataCopyParams repeatParams;
|
||||
repeatParams.blockCount = blockSize_;
|
||||
repeatParams.blockLen = headNum_ / splitNum_ * headDim_ * sizeof(T) / dataBlockSize_;
|
||||
repeatParams.srcStride = 0;
|
||||
repeatParams.dstStride = (headNum_ * headDim_ - headNum_ / splitNum_ * headDim_) * sizeof(T) / dataBlockSize_;
|
||||
|
||||
uint32_t startBlock;
|
||||
uint32_t endBlock;
|
||||
uint32_t startLayer;
|
||||
uint32_t endLayer;
|
||||
|
||||
Caloffset(startBlock, endBlock, startLayer, endLayer);
|
||||
for (uint32_t i = startBlock; i < endBlock; ++i) {
|
||||
int64_t blockId = blockIDsGm_.GetValue(i);
|
||||
uint32_t offsetBlock = blockId * blockSize_ * headNum_ * headDim_;
|
||||
uint32_t realStartLayer;
|
||||
uint32_t realEndLayer;
|
||||
if (i == startBlock) {
|
||||
realStartLayer = startLayer;
|
||||
} else {
|
||||
realStartLayer = 0;
|
||||
}
|
||||
|
||||
if (i == (endBlock - 1)) {
|
||||
realEndLayer = endLayer;
|
||||
} else {
|
||||
realEndLayer = layerNum_;
|
||||
}
|
||||
for (uint32_t layerId = realStartLayer; layerId < realEndLayer; ++layerId) {
|
||||
SetGlobalBuffers(layerId);
|
||||
|
||||
CopyIn(kCacheGm_, offsetBlock, repeatParams);
|
||||
CopyOut(kCacheGm_, offsetBlock);
|
||||
|
||||
CopyIn(vCacheGm_, offsetBlock, repeatParams);
|
||||
CopyOut(vCacheGm_, offsetBlock);
|
||||
}
|
||||
|
||||
}
|
||||
}
|
||||
};
|
||||
190
csrc/transpose_kv_cache_by_block/op_kernel/general.h
Normal file
190
csrc/transpose_kv_cache_by_block/op_kernel/general.h
Normal file
@@ -0,0 +1,190 @@
|
||||
#include "common.h"
|
||||
|
||||
template <typename T, uint32_t DB, bool needHandleUnFactorSplit>
|
||||
class TransposeKvCacheByBlockKernelGeneral {
|
||||
protected:
|
||||
TQueBind<TPosition::VECIN, TPosition::VECOUT, 1> queBind_;
|
||||
GlobalTensor<T> kCacheGm_;
|
||||
GlobalTensor<T> vCacheGm_;
|
||||
GlobalTensor<int64_t> blockIDsGm_;
|
||||
|
||||
GM_ADDR kCachePtr_;
|
||||
GM_ADDR vCachePtr_;
|
||||
|
||||
// shape info
|
||||
uint32_t blockNum_;
|
||||
uint32_t blockSize_;
|
||||
uint32_t headNum_;
|
||||
uint32_t headDim_;
|
||||
uint32_t splitNum_;
|
||||
uint32_t layerNum_;
|
||||
uint32_t headNumSplited_;
|
||||
uint32_t blockSizeSplitNum_;
|
||||
|
||||
// tiling info
|
||||
uint32_t useCoreNum_;
|
||||
uint32_t blockPerCore_;
|
||||
uint32_t tailCoreNum_;
|
||||
uint32_t calBlockNum_;
|
||||
|
||||
uint32_t srcFactor_;
|
||||
uint32_t dstFactor_;
|
||||
uint32_t copyOutLength_;
|
||||
|
||||
uint32_t blockSizePerTime_;
|
||||
uint32_t blockSizePerTimeTail_;
|
||||
|
||||
uint32_t blockIdx_;
|
||||
uint32_t dataBlockSize_;
|
||||
bool needSync_;
|
||||
|
||||
__aicore__ inline void CopyIn(GlobalTensor<T> &cacheGm, uint32_t offsetBlock, DataCopyParams &repeatParams) {
|
||||
LocalTensor<T> cacheLocal = queBind_.AllocTensor<T>();
|
||||
for (uint32_t i = 0; i < splitNum_; ++i) {
|
||||
DataCopy(cacheLocal[i * dstFactor_], cacheGm[i * srcFactor_ + offsetBlock], repeatParams);
|
||||
}
|
||||
queBind_.EnQue(cacheLocal);
|
||||
}
|
||||
|
||||
__aicore__ inline void CopyOut(GlobalTensor<T> &cacheGm, uint32_t offsetBlock) {
|
||||
LocalTensor<T> cacheLocal = queBind_.DeQue<T>();
|
||||
AscendC::CrossCoreSetFlag<0x0, PIPE_MTE2>(0x8);
|
||||
AscendC::CrossCoreWaitFlag(0x8);
|
||||
DataCopy(cacheGm[offsetBlock], cacheLocal, copyOutLength_);
|
||||
queBind_.FreeTensor(cacheLocal);
|
||||
}
|
||||
|
||||
__aicore__ inline void SetGlobalBuffers(uint32_t layerId) {
|
||||
kCacheGm_.SetGlobalBuffer(GetTensorAddr<T>(layerId, kCachePtr_));
|
||||
vCacheGm_.SetGlobalBuffer(GetTensorAddr<T>(layerId, vCachePtr_));
|
||||
}
|
||||
|
||||
__aicore__ inline void Caloffset(uint32_t &startBlock, uint32_t &endBlock, uint32_t &startLayer, uint32_t &endLayer) {
|
||||
|
||||
uint32_t curBlockStart;
|
||||
uint32_t curBlocknum;
|
||||
uint32_t groupBlockIdx = blockIdx_ / blockSizeSplitNum_;
|
||||
if (groupBlockIdx < tailCoreNum_) {
|
||||
needSync_ = false;
|
||||
curBlockStart = groupBlockIdx * (blockPerCore_ + 1);
|
||||
curBlocknum = blockPerCore_ + 1;
|
||||
} else {
|
||||
needSync_ = true;
|
||||
curBlockStart = groupBlockIdx * blockPerCore_ + tailCoreNum_;
|
||||
curBlocknum = blockPerCore_;
|
||||
}
|
||||
uint32_t curBlockEnd = curBlockStart + curBlocknum;
|
||||
startBlock = curBlockStart / layerNum_;
|
||||
startLayer = curBlockStart % layerNum_;
|
||||
endBlock = (curBlockEnd + layerNum_ - 1) / layerNum_;
|
||||
endLayer = curBlockEnd % layerNum_;
|
||||
if (endLayer == 0) {
|
||||
endLayer = layerNum_;
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
public:
|
||||
__aicore__ inline void Init(GM_ADDR KCache, GM_ADDR VCache, GM_ADDR blockIDs,
|
||||
TransposeKvCacheByBlockTilingData* tilingData, TPipe* tPipe) {
|
||||
kCachePtr_ = KCache;
|
||||
vCachePtr_ = VCache;
|
||||
blockIDsGm_.SetGlobalBuffer((__gm__ int64_t*)blockIDs);
|
||||
blockIdx_ = GetBlockIdx();
|
||||
// shape info
|
||||
blockSize_ = tilingData->blockSize;
|
||||
headNum_ = tilingData->headNum;
|
||||
headDim_ = tilingData->headDim;
|
||||
splitNum_ = tilingData->splitNum;
|
||||
layerNum_ = tilingData->layerNum;
|
||||
// tiling info
|
||||
useCoreNum_ = tilingData->useCoreNum;
|
||||
blockPerCore_ = tilingData->blockPerCore;
|
||||
tailCoreNum_ = tilingData->tailCoreNum;
|
||||
calBlockNum_ = tilingData->calBlockNum;
|
||||
blockSizeSplitNum_ = tilingData->blockSizeSplitNum;
|
||||
blockSizePerTime_ = tilingData->blockSizePerTime;
|
||||
blockSizePerTimeTail_ = tilingData->blockSizePerTimeTail;
|
||||
headNumSplited_ = headNum_ / splitNum_;
|
||||
|
||||
tPipe->InitBuffer(queBind_, DB, TOTAL_UB_SIZE / DB);
|
||||
srcFactor_ = blockSize_ * headNumSplited_ * headDim_;
|
||||
dstFactor_ = headNumSplited_ * headDim_;
|
||||
copyOutLength_ = blockSizePerTime_ * headNum_ * headDim_;
|
||||
dataBlockSize_ = static_cast<uint32_t>(AscendC::GetDataBlockSizeInBytes());
|
||||
}
|
||||
|
||||
__aicore__ inline void Process() {
|
||||
DataCopyParams repeatParams;
|
||||
repeatParams.blockCount = blockSizePerTime_;
|
||||
repeatParams.blockLen = headNumSplited_ * headDim_ * sizeof(T) / dataBlockSize_;
|
||||
repeatParams.srcStride = 0;
|
||||
repeatParams.dstStride = (headNum_ * headDim_ - headNumSplited_ * headDim_) * sizeof(T) / dataBlockSize_;
|
||||
|
||||
uint32_t startBlock;
|
||||
uint32_t endBlock;
|
||||
uint32_t startLayer;
|
||||
uint32_t endLayer;
|
||||
|
||||
Caloffset(startBlock, endBlock, startLayer, endLayer);
|
||||
|
||||
for (uint32_t i = startBlock; i < endBlock; ++i) {
|
||||
int64_t blockId = blockIDsGm_.GetValue(i);
|
||||
uint32_t offsetBlock = blockId * blockSize_ * headNum_ * headDim_;
|
||||
uint32_t realStartLayer;
|
||||
uint32_t realEndLayer;
|
||||
if (i == startBlock) {
|
||||
realStartLayer = startLayer;
|
||||
} else {
|
||||
realStartLayer = 0;
|
||||
}
|
||||
|
||||
if (i == (endBlock - 1)) {
|
||||
realEndLayer = endLayer;
|
||||
} else {
|
||||
realEndLayer = layerNum_;
|
||||
}
|
||||
for (uint32_t layerId = realStartLayer; layerId < realEndLayer; ++layerId) {
|
||||
SetGlobalBuffers(layerId);
|
||||
uint32_t blockSizeIndex = blockIdx_ % blockSizeSplitNum_;
|
||||
uint32_t srcOffset;
|
||||
uint32_t dstOffset;
|
||||
if constexpr (needHandleUnFactorSplit) {
|
||||
// handle tail
|
||||
if (blockSizeIndex >= blockSizePerTimeTail_) {
|
||||
repeatParams.blockCount = (blockSizePerTime_ - 1);
|
||||
copyOutLength_ = (blockSizePerTime_ - 1) * headNum_ * headDim_;
|
||||
srcOffset = (blockSizeIndex * blockSizePerTime_ - (blockSizeIndex - blockSizePerTimeTail_)) * headNumSplited_ * headDim_;
|
||||
dstOffset = (blockSizeIndex * blockSizePerTime_ - (blockSizeIndex - blockSizePerTimeTail_)) * headNum_ * headDim_;
|
||||
} else {
|
||||
repeatParams.blockCount = blockSizePerTime_;
|
||||
copyOutLength_ = blockSizePerTime_ * headNum_ * headDim_;
|
||||
srcOffset = blockSizeIndex * blockSizePerTime_ * headNumSplited_ * headDim_;
|
||||
dstOffset = blockSizeIndex * blockSizePerTime_ * headNum_ * headDim_;
|
||||
}
|
||||
} else {
|
||||
repeatParams.blockCount = blockSizePerTime_;
|
||||
copyOutLength_ = blockSizePerTime_ * headNum_ * headDim_;
|
||||
srcOffset = blockSizeIndex * blockSizePerTime_ * headNumSplited_ * headDim_;
|
||||
dstOffset = blockSizeIndex * blockSizePerTime_ * headNum_ * headDim_;
|
||||
}
|
||||
|
||||
CopyIn(kCacheGm_, offsetBlock + srcOffset, repeatParams);
|
||||
CopyOut(kCacheGm_, offsetBlock + dstOffset);
|
||||
|
||||
CopyIn(vCacheGm_, offsetBlock + srcOffset, repeatParams);
|
||||
CopyOut(vCacheGm_, offsetBlock + dstOffset);
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
if (needSync_) {
|
||||
AscendC::CrossCoreSetFlag<0x0, PIPE_MTE2>(0x8);
|
||||
AscendC::CrossCoreWaitFlag(0x8);
|
||||
|
||||
AscendC::CrossCoreSetFlag<0x0, PIPE_MTE2>(0x8);
|
||||
AscendC::CrossCoreWaitFlag(0x8);
|
||||
}
|
||||
|
||||
}
|
||||
};
|
||||
@@ -0,0 +1,36 @@
|
||||
#include "kernel_operator.h"
|
||||
#include "full_load.h"
|
||||
#include "general.h"
|
||||
|
||||
|
||||
extern "C" __global__ __aicore__ void transpose_kv_cache_by_block(GM_ADDR KCache, GM_ADDR VCache, GM_ADDR blockIDs, GM_ADDR workspace, GM_ADDR tiling) {
|
||||
GET_TILING_DATA(tiling_data, tiling);
|
||||
TPipe tPipe;
|
||||
KERNEL_TASK_TYPE_DEFAULT(KERNEL_TYPE_MIX_AIV_1_0);
|
||||
if (TILING_KEY_IS(0)) {
|
||||
// full load not db
|
||||
TransposeKvCacheByBlockKernelFullLoad<DTYPE_KCACHE> kernel;
|
||||
kernel.Init(KCache, VCache, blockIDs, &tiling_data, &tPipe);
|
||||
kernel.Process();
|
||||
} else if (TILING_KEY_IS(1)) {
|
||||
// db \ align split blockSize
|
||||
TransposeKvCacheByBlockKernelGeneral<DTYPE_KCACHE, uint32_t(2), false> kernel;
|
||||
kernel.Init(KCache, VCache, blockIDs, &tiling_data, &tPipe);
|
||||
kernel.Process();
|
||||
} else if (TILING_KEY_IS(2)) {
|
||||
// not db \ align split blockSize
|
||||
TransposeKvCacheByBlockKernelGeneral<DTYPE_KCACHE, uint32_t(1), false> kernel;
|
||||
kernel.Init(KCache, VCache, blockIDs, &tiling_data, &tPipe);
|
||||
kernel.Process();
|
||||
} else if (TILING_KEY_IS(3)) {
|
||||
// db \ unalign split blockSize
|
||||
TransposeKvCacheByBlockKernelGeneral<DTYPE_KCACHE, uint32_t(2), true> kernel;
|
||||
kernel.Init(KCache, VCache, blockIDs, &tiling_data, &tPipe);
|
||||
kernel.Process();
|
||||
} else if (TILING_KEY_IS(4)) {
|
||||
// not db \ unalign split blockSize
|
||||
TransposeKvCacheByBlockKernelGeneral<DTYPE_KCACHE, uint32_t(1), true> kernel;
|
||||
kernel.Init(KCache, VCache, blockIDs, &tiling_data, &tPipe);
|
||||
kernel.Process();
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,137 @@
|
||||
import random
|
||||
import unittest
|
||||
|
||||
import torch
|
||||
import torch_npu
|
||||
|
||||
from vllm_ascend.utils import enable_custom_op
|
||||
|
||||
enable_custom_op()
|
||||
|
||||
torch.set_printoptions(threshold=float("inf"))
|
||||
|
||||
def clone_kv_cache(k_caches, v_caches):
|
||||
new_k_caches = [cache.clone() for cache in k_caches]
|
||||
new_v_caches = [cache.clone() for cache in v_caches]
|
||||
return new_k_caches, new_v_caches
|
||||
|
||||
class TestTransposeKvCacheByBlock(unittest.TestCase):
|
||||
def compute_golden(self, k_caches, v_caches, block_ids_tensor, block_size, num_kv_head, head_dim, num_need_pulls, layers, dtype):
|
||||
num_blocks = block_ids_tensor.shape[0]
|
||||
|
||||
block_ids_tensor = block_ids_tensor.to(dtype=torch.int32)
|
||||
block_offsets = torch.arange(0, block_size, dtype=torch.int32).npu()
|
||||
slot_mapping = block_offsets.reshape(
|
||||
(1, block_size)) + block_ids_tensor.reshape(
|
||||
(num_blocks, 1)) * block_size
|
||||
slot_mapping = slot_mapping.flatten()
|
||||
block_len = num_blocks * block_size
|
||||
block_len_tensor = torch.tensor([block_len],dtype=torch.int32).npu()
|
||||
|
||||
block_table = block_ids_tensor.view(1, -1)
|
||||
seq_start_tensor = torch.tensor([0],dtype=torch.int32).npu()
|
||||
|
||||
k = torch.empty(block_len, num_kv_head, head_dim, dtype=dtype).npu()
|
||||
v = torch.empty(block_len, num_kv_head, head_dim, dtype=dtype).npu()
|
||||
|
||||
for layer in range(layers):
|
||||
k_cache_layer = k_caches[layer]
|
||||
v_cache_layer = v_caches[layer]
|
||||
|
||||
torch_npu.atb.npu_paged_cache_load(
|
||||
k_cache_layer,
|
||||
v_cache_layer,
|
||||
block_table,
|
||||
block_len_tensor,
|
||||
seq_starts=seq_start_tensor,
|
||||
key=k,
|
||||
value=v,
|
||||
)
|
||||
|
||||
k = k.view(num_blocks, num_need_pulls, block_size, -1)
|
||||
k.transpose_(1, 2)
|
||||
k = k.contiguous().view(block_len, num_kv_head, -1)
|
||||
|
||||
v = v.view(num_blocks, num_need_pulls, block_size, -1)
|
||||
v.transpose_(1, 2)
|
||||
v = v.contiguous().view(block_len, num_kv_head, -1)
|
||||
|
||||
torch_npu._npu_reshape_and_cache(
|
||||
key=k,
|
||||
value=v,
|
||||
key_cache=k_cache_layer,
|
||||
value_cache=v_cache_layer,
|
||||
slot_indices=slot_mapping,
|
||||
)
|
||||
del k, v
|
||||
|
||||
def test_transpose_kv_cache_by_block(self):
|
||||
# (layers, block_num, block_size, num_kv_head, head_dim, num_need_pulls)
|
||||
test_cases = [
|
||||
(16, 128, 128, 4, 128, 4),
|
||||
(16, 128, 128, 4, 128, 2),
|
||||
(16, 128, 128, 4, 128, 1),
|
||||
(16, 128, 128, 8, 128, 8),
|
||||
(16, 128, 128, 8, 128, 4),
|
||||
(16, 128, 128, 8, 128, 2),
|
||||
]
|
||||
dtypes = [torch.float16, torch.bfloat16]
|
||||
for dtype in dtypes:
|
||||
for layers, block_num, block_size, num_kv_head, head_dim, num_need_pulls in test_cases:
|
||||
with self.subTest(dtype=dtype, shape=f"({layers}, {block_num}, {block_size}, {num_kv_head}, {head_dim}, {num_need_pulls})"):
|
||||
k_caches = []
|
||||
v_caches = []
|
||||
block_id_num = 33
|
||||
block_ids_tensor = torch.randperm(block_num, dtype=torch.int64, device="npu")[:block_id_num]
|
||||
for i in range(layers):
|
||||
kcache = torch.randn(block_num, block_size, num_kv_head, head_dim, dtype=dtype, device="npu")
|
||||
vcache = torch.randn(block_num, block_size, num_kv_head, head_dim, dtype=dtype, device="npu")
|
||||
k_caches.append(kcache)
|
||||
v_caches.append(vcache)
|
||||
|
||||
cloned_k_caches, cloned_v_caches = clone_kv_cache(k_caches, v_caches)
|
||||
self.compute_golden(cloned_k_caches, cloned_v_caches, block_ids_tensor, block_size, num_kv_head, head_dim, num_need_pulls, layers, dtype)
|
||||
torch.ops._C_ascend.transpose_kv_cache_by_block(k_caches, v_caches, block_ids_tensor, block_size, num_kv_head, head_dim, num_need_pulls, layers)
|
||||
|
||||
for i in range (layers):
|
||||
self.assert_tensors_almost_equal(k_caches[i], cloned_k_caches[i], dtype)
|
||||
self.assert_tensors_almost_equal(v_caches[i], cloned_v_caches[i], dtype)
|
||||
|
||||
def assert_tensors_almost_equal(self, actual, expected, dtype):
|
||||
"""Check if two tensors are approximately equal (considering floating point errors)"""
|
||||
self.assertEqual(actual.shape, expected.shape, "Shape mismatch")
|
||||
|
||||
# Check for NaN
|
||||
self.assertFalse(
|
||||
torch.isnan(actual).any(), "Actual result contains NaN")
|
||||
self.assertFalse(
|
||||
torch.isnan(expected).any(), "Expected result contains NaN")
|
||||
|
||||
# Check for Inf
|
||||
self.assertFalse(
|
||||
torch.isinf(actual).any(), "Actual result contains Inf")
|
||||
self.assertFalse(
|
||||
torch.isinf(expected).any(), "Expected result contains Inf")
|
||||
|
||||
# Set different tolerances based on data type
|
||||
if dtype == torch.float16:
|
||||
rtol, atol = 1e-5, 1e-5
|
||||
else: # bfloat16
|
||||
rtol, atol = 1.5e-5, 1.5e-5
|
||||
|
||||
# Compare values
|
||||
diff = torch.abs(actual - expected)
|
||||
max_diff = diff.max().item()
|
||||
max_expected = torch.abs(expected).max().item()
|
||||
|
||||
# Check relative and absolute errors
|
||||
if max_expected > 0:
|
||||
relative_diff = max_diff / max_expected
|
||||
self.assertLessEqual(
|
||||
relative_diff,
|
||||
rtol,
|
||||
f"Relative error too large: {relative_diff} > {rtol}. Max difference: {max_diff}",
|
||||
)
|
||||
|
||||
self.assertLessEqual(max_diff, atol,
|
||||
f"Absolute error too large: {max_diff} > {atol}")
|
||||
@@ -46,10 +46,11 @@ from vllm.v1.core.sched.output import SchedulerOutput
|
||||
from vllm.v1.kv_cache_interface import KVCacheConfig
|
||||
from vllm.v1.request import RequestStatus
|
||||
|
||||
from vllm_ascend import envs as ascend_envs
|
||||
from vllm_ascend.ascend_config import get_ascend_config, init_ascend_config
|
||||
from vllm_ascend.distributed.kv_transfer.utils.mooncake_transfer_engine import global_te
|
||||
from vllm_ascend.distributed.kv_transfer.utils.utils import get_transfer_timeout_value
|
||||
from vllm_ascend.utils import is_vl_model
|
||||
from vllm_ascend.utils import enable_custom_op, is_vl_model
|
||||
|
||||
# isort: off
|
||||
if TYPE_CHECKING:
|
||||
@@ -570,8 +571,39 @@ class KVCacheRecvingThread(threading.Thread):
|
||||
is_kv_transfer_end = global_offset == tp_num_need_pulls * self._prefill_pp_size - 1
|
||||
need_cat_cache = tp_num_need_pulls > 1 and is_kv_transfer_end
|
||||
need_nz_cache = get_ascend_config().enable_kv_nz and is_kv_transfer_end
|
||||
use_fused_op = ascend_envs.VLLM_ASCEND_FUSION_OP_TRANSPOSE_KV_CACHE_BY_BLOCK
|
||||
if need_nz_cache or need_cat_cache:
|
||||
self.reformat_kv_cache(grouped_local_block_ids, tp_num_need_pulls, need_cat_cache, need_nz_cache)
|
||||
# use fused op to reformat kv cache, we keep original implementation to provide ability to disable it.
|
||||
if use_fused_op and enable_custom_op():
|
||||
if need_cat_cache:
|
||||
# the fused op only support cat GQA/MHA kv cache by head
|
||||
self.reformat_kv_cache_with_fused_op(grouped_local_block_ids, tp_num_need_pulls)
|
||||
if need_nz_cache:
|
||||
# maybe use fused op to reformat kv nz too in the future.
|
||||
self.reformat_kv_cache(grouped_local_block_ids, tp_num_need_pulls, False, need_nz_cache)
|
||||
else:
|
||||
self.reformat_kv_cache(grouped_local_block_ids, tp_num_need_pulls, need_cat_cache, need_nz_cache)
|
||||
|
||||
def reformat_kv_cache_with_fused_op(self, block_ids: list[list[int]], tp_num_need_pulls: int):
|
||||
# Get necessary parameters
|
||||
k_cache = list(self.kv_caches.values())[0][0]
|
||||
device = k_cache.device
|
||||
head_dim = self.model_config.hf_text_config.head_dim
|
||||
block_size = self.vllm_config.cache_config.block_size
|
||||
num_kv_head = max(self.model_config.hf_text_config.num_key_value_heads // self.tp_size, 1)
|
||||
layers = self.model_config.hf_text_config.num_hidden_layers
|
||||
flat_block_ids = [item for sublist in block_ids for item in sublist]
|
||||
block_ids_tensor = torch.tensor(flat_block_ids, dtype=torch.int64, device=device)
|
||||
|
||||
k_caches = []
|
||||
v_caches = []
|
||||
for _, (k_cache_layer, v_cache_layer) in self.kv_caches.items():
|
||||
k_caches.append(k_cache_layer)
|
||||
v_caches.append(v_cache_layer)
|
||||
|
||||
torch.ops._C_ascend.transpose_kv_cache_by_block(
|
||||
k_caches, v_caches, block_ids_tensor, block_size, num_kv_head, head_dim, tp_num_need_pulls, layers
|
||||
)
|
||||
|
||||
def reformat_kv_cache(
|
||||
self,
|
||||
|
||||
@@ -111,6 +111,10 @@ env_variables: dict[str, Callable[[], Any]] = {
|
||||
"VLLM_ASCEND_ENABLE_FUSED_MC2": lambda: int(os.getenv("VLLM_ASCEND_ENABLE_FUSED_MC2", "0")),
|
||||
# Whether to anbale balance scheduling
|
||||
"VLLM_ASCEND_BALANCE_SCHEDULING": lambda: bool(int(os.getenv("VLLM_ASCEND_BALANCE_SCHEDULING", "0"))),
|
||||
# use fused op transpose_kv_cache_by_block, default is True
|
||||
"VLLM_ASCEND_FUSION_OP_TRANSPOSE_KV_CACHE_BY_BLOCK": lambda: bool(
|
||||
int(os.getenv("VLLM_ASCEND_FUSION_OP_TRANSPOSE_KV_CACHE_BY_BLOCK", "1"))
|
||||
),
|
||||
}
|
||||
|
||||
# end-env-vars-definition
|
||||
|
||||
Reference in New Issue
Block a user