From 79803932e23ef2a09fbc0bc3dfee9a29eb302044 Mon Sep 17 00:00:00 2001 From: lidenghui1110 <30521952+lidenghui1110@users.noreply.github.com> Date: Tue, 3 Feb 2026 14:10:01 +0800 Subject: [PATCH] [Kernel] Add AscendC fused op transpose_kv_cache_by_block to speed up GQA transfer (#6366) ### What this PR does / why we need it? As #2947 describe, we need to transpose kv cache layout after GQA kv transfer when prefill and decode tensor parallel size are heterogeneous, in the previous implementation, we use `npu_paged_cache_load ` + `tranpose` + `_npu_reshape_and_cache` to do this work. But obviously, it is not an efficient plan, the ops above need to be called for each layer, which introduces 3 * layer_num kernel launch, and 6 * layer_num data movement between L1 Cache and HBM for one request on decode node. Usually, decode node uses graph mode, so these op kernels will be called between decode forward launched by an async thread in mooncacke connector, this kernels maybe last for several decode forward and TTFT will increase by 3~4 decode forward time. In this PR, we implement an AscendC fused op `transpose_kv_cache_by_block` to do this with only once kernel launch and move data between L1 Cache and HBM only once. After using this fused op, the time cost in transpose kv cacke layout can be decreased to 0.24ms from 7ms in UT on 910C, and in PD disaggregation scenario, TTFT can decrease about 90 ~ 110 ms in qwen3-235B. | request_num | original | fused_op| |:----------------------:|:---------------:|:-------------------:| | 1 | 643 ms | 578 ms | | 128 | 1480 ms | 1368 ms | ### Does this PR introduce _any_ user-facing change? Use fused op by default, incase the op has bug in any scenario, provide fallback choice using env to disable it. **DISABLE fused op by add following env** `export VLLM_ASCEND_FUSION_OP_TRANSPOSE_KV_CACHE_BY_BLOCK=0` ### How was this patch tested? - vLLM version: v0.14.1 - vLLM main: https://github.com/vllm-project/vllm/commit/dc917cceb877dfd13f98c538c4c96158047d98bd --------- Signed-off-by: lidenghui --- csrc/build_aclnn.sh | 3 +- csrc/torch_binding.cpp | 20 ++ csrc/torch_binding_meta.cpp | 16 ++ .../op_host/CMakeLists.txt | 40 ++++ .../transpose_kv_cache_by_block_def.cpp | 36 ++++ .../transpose_kv_cache_by_block_proto.cpp | 36 ++++ .../transpose_kv_cache_by_block_tiling.cpp | 182 +++++++++++++++++ .../transpose_kv_cache_by_block_tiling.h | 23 +++ .../op_kernel/common.h | 16 ++ .../op_kernel/full_load.h | 141 +++++++++++++ .../op_kernel/general.h | 190 ++++++++++++++++++ .../op_kernel/transpose_kv_cache_by_block.cpp | 36 ++++ .../test_transpose_kv_cache_by_block.py | 137 +++++++++++++ .../kv_transfer/kv_p2p/mooncake_connector.py | 36 +++- vllm_ascend/envs.py | 4 + 15 files changed, 913 insertions(+), 3 deletions(-) create mode 100644 csrc/transpose_kv_cache_by_block/op_host/CMakeLists.txt create mode 100644 csrc/transpose_kv_cache_by_block/op_host/transpose_kv_cache_by_block_def.cpp create mode 100644 csrc/transpose_kv_cache_by_block/op_host/transpose_kv_cache_by_block_proto.cpp create mode 100644 csrc/transpose_kv_cache_by_block/op_host/transpose_kv_cache_by_block_tiling.cpp create mode 100644 csrc/transpose_kv_cache_by_block/op_host/transpose_kv_cache_by_block_tiling.h create mode 100644 csrc/transpose_kv_cache_by_block/op_kernel/common.h create mode 100644 csrc/transpose_kv_cache_by_block/op_kernel/full_load.h create mode 100644 csrc/transpose_kv_cache_by_block/op_kernel/general.h create mode 100644 csrc/transpose_kv_cache_by_block/op_kernel/transpose_kv_cache_by_block.cpp create mode 100644 tests/e2e/nightly/single_node/ops/singlecard_ops/test_transpose_kv_cache_by_block.py diff --git a/csrc/build_aclnn.sh b/csrc/build_aclnn.sh index 938d56a0..fc4c53d0 100644 --- a/csrc/build_aclnn.sh +++ b/csrc/build_aclnn.sh @@ -24,7 +24,7 @@ elif [[ "$SOC_VERSION" =~ ^ascend910b ]]; then ABSOLUTE_CATLASS_PATH=$(cd "${CATLASS_PATH}" && pwd) export CPATH=${ABSOLUTE_CATLASS_PATH}:${CPATH} - CUSTOM_OPS="grouped_matmul_swiglu_quant_weight_nz_tensor_list;lightning_indexer;sparse_flash_attention;matmul_allreduce_add_rmsnorm;moe_init_routing_custom;moe_gating_top_k;add_rms_norm_bias;apply_top_k_top_p_custom;" + CUSTOM_OPS="grouped_matmul_swiglu_quant_weight_nz_tensor_list;lightning_indexer;sparse_flash_attention;matmul_allreduce_add_rmsnorm;moe_init_routing_custom;moe_gating_top_k;add_rms_norm_bias;apply_top_k_top_p_custom;transpose_kv_cache_by_block;" SOC_ARG="ascend910b" elif [[ "$SOC_VERSION" =~ ^ascend910_93 ]]; then # ASCEND910C (A3) series @@ -81,6 +81,7 @@ elif [[ "$SOC_VERSION" =~ ^ascend910_93 ]]; then "moe_gating_top_k" "add_rms_norm_bias" "apply_top_k_top_p_custom" + "transpose_kv_cache_by_block" ) CUSTOM_OPS=$(IFS=';'; echo "${CUSTOM_OPS_ARRAY[*]}") SOC_ARG="ascend910_93" diff --git a/csrc/torch_binding.cpp b/csrc/torch_binding.cpp index 146123ec..963bf854 100644 --- a/csrc/torch_binding.cpp +++ b/csrc/torch_binding.cpp @@ -1343,6 +1343,22 @@ std::tuple npu_add_rms_norm_bias( return std::tuple(y, rstd, x); } +void transpose_kv_cache_by_block( + const at::TensorList &kCache, + const at::TensorList &vCache, + const at::Tensor &blockIDs, + int64_t blockSize, + int64_t headNum, + int64_t headDim, + int64_t splitNum, + int64_t layerNum) +{ + + EXEC_NPU_CMD(aclnnTransposeKvCacheByBlock, kCache, vCache, blockIDs, + blockSize, headNum, headDim, splitNum, layerNum); + +} + } // namespace vllm_ascend TORCH_LIBRARY_EXPAND(CONCAT(_C, _ascend), ops) @@ -1521,4 +1537,8 @@ TORCH_LIBRARY_EXPAND(CONCAT(_C, _ascend), ops) ops.def("npu_apply_top_k_top_p(Tensor logits, Tensor? p=None, Tensor? k=None) -> Tensor"); ops.impl("npu_apply_top_k_top_p", torch::kPrivateUse1, &vllm_ascend::npu_apply_top_k_top_p); + ops.def( + "transpose_kv_cache_by_block(Tensor[] kCache, Tensor[] vCache, Tensor blockIDs, int blockSize, int headNum, int headDim, int splitNum, int layerNum) -> ()" + ); + ops.impl("transpose_kv_cache_by_block", torch::kPrivateUse1, &vllm_ascend::transpose_kv_cache_by_block); } diff --git a/csrc/torch_binding_meta.cpp b/csrc/torch_binding_meta.cpp index af550134..378519bc 100644 --- a/csrc/torch_binding_meta.cpp +++ b/csrc/torch_binding_meta.cpp @@ -435,6 +435,20 @@ std::tuple npu_add_rms_norm_bias_meta( at::Tensor x = at::empty_symint(x1.sym_sizes(), x1.options()); return std::tuple(y, rstd, x); } + +void transpose_kv_cache_by_block_meta( + const at::TensorList &k_cache, + const at::TensorList &v_cache, + const at::Tensor &block_ids, + int64_t block_size, + int64_t head_num, + int64_t head_dim, + int64_t split_num, + int64_t layer_num) +{ + return; +} + } // namespace meta } // namespace vllm_ascend @@ -475,5 +489,7 @@ TORCH_LIBRARY_IMPL_EXPAND(CONCAT(_C, _ascend), Meta, ops) { ops.impl("moe_gating_top_k", &vllm_ascend::meta::moe_gating_top_k_meta); // Add_Rms_Norm_Bias ops.impl("npu_add_rms_norm_bias", &vllm_ascend::meta::npu_add_rms_norm_bias_meta); + // transpose_kv_cache_by_block + ops.impl("transpose_kv_cache_by_block", &vllm_ascend::meta::transpose_kv_cache_by_block_meta); } } diff --git a/csrc/transpose_kv_cache_by_block/op_host/CMakeLists.txt b/csrc/transpose_kv_cache_by_block/op_host/CMakeLists.txt new file mode 100644 index 00000000..dfd3d15d --- /dev/null +++ b/csrc/transpose_kv_cache_by_block/op_host/CMakeLists.txt @@ -0,0 +1,40 @@ +# This program is free software, you can redistribute it and/or modify it. +# Copyright (c) 2025 Huawei Technologies Co., Ltd. +# This file is a part of the CANN Open Software. +# Licensed under CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. +# ====================================================================================================================== + +add_ops_compile_options( + OP_NAME TransposeKvCacheByBlock + OPTIONS --cce-auto-sync=off + -Wno-deprecated-declarations + -Werror + -mllvm -cce-aicore-hoist-movemask=false + --op_relocatable_kernel_binary=true +) + +target_sources(op_host_aclnn PRIVATE + transpose_kv_cache_by_block_def.cpp +) + +target_sources(optiling PRIVATE + transpose_kv_cache_by_block_tiling.cpp +) + +if (NOT BUILD_OPEN_PROJECT) + target_sources(opmaster_ct PRIVATE + transpose_kv_cache_by_block_tiling.cpp + ) +endif () + +target_include_directories(optiling PRIVATE + ${CMAKE_CURRENT_SOURCE_DIR} +) + +target_sources(opsproto PRIVATE + transpose_kv_cache_by_block_proto.cpp +) + diff --git a/csrc/transpose_kv_cache_by_block/op_host/transpose_kv_cache_by_block_def.cpp b/csrc/transpose_kv_cache_by_block/op_host/transpose_kv_cache_by_block_def.cpp new file mode 100644 index 00000000..5e1c1144 --- /dev/null +++ b/csrc/transpose_kv_cache_by_block/op_host/transpose_kv_cache_by_block_def.cpp @@ -0,0 +1,36 @@ +#include "register/op_def_registry.h" + +namespace ops { +class TransposeKvCacheByBlock : public OpDef { +public: + explicit TransposeKvCacheByBlock(const char* name) : OpDef(name) + { + this->Input("KCache") + .ParamType(DYNAMIC) + .DataType({ge::DT_FLOAT16, ge::DT_BF16}) + .Format({ge::FORMAT_ND, ge::FORMAT_ND}) + .UnknownShapeFormat({ge::FORMAT_ND, ge::FORMAT_ND}); + this->Input("VCache") + .ParamType(DYNAMIC) + .DataType({ge::DT_FLOAT16, ge::DT_BF16}) + .Format({ge::FORMAT_ND, ge::FORMAT_ND}) + .UnknownShapeFormat({ge::FORMAT_ND, ge::FORMAT_ND}); + this->Input("blockIDs") + .ParamType(REQUIRED) + .DataTypeList({ge::DT_INT64}) + .FormatList({ge::FORMAT_ND}) + .UnknownShapeFormat({ge::FORMAT_ND}); + this->Attr("blockSize").Int(); + this->Attr("headNum").Int(); + this->Attr("headDim").Int(); + this->Attr("splitNum").Int(); + this->Attr("layerNum").Int(); + + this->AICore().AddConfig("ascend910b"); + this->AICore().AddConfig("ascend910_93"); + + } +}; + +OP_ADD(TransposeKvCacheByBlock); +} diff --git a/csrc/transpose_kv_cache_by_block/op_host/transpose_kv_cache_by_block_proto.cpp b/csrc/transpose_kv_cache_by_block/op_host/transpose_kv_cache_by_block_proto.cpp new file mode 100644 index 00000000..24f2d84d --- /dev/null +++ b/csrc/transpose_kv_cache_by_block/op_host/transpose_kv_cache_by_block_proto.cpp @@ -0,0 +1,36 @@ +/** + * This program is free software, you can redistribute it and/or modify it. + * Copyright (c) 2025 Huawei Technologies Co., Ltd. + * This file is a part of the CANN Open Software. + * Licensed under CANN Open Software License Agreement Version 2.0 (the "License"). + * Please refer to the License for details. You may not use this file except in compliance with the License. + * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. + * See LICENSE in the root of the software repository for the full text of the License. + */ + +/*! + * \file transpose_kv_cache_by_block_proto.cpp + * \brief + */ +#include +#include +#include "error/ops_error.h" + +using namespace ge; + +namespace ops { + +static ge::graphStatus InferShapeTransposeKvCacheByBlock(gert::InferShapeContext* context) +{ + return ge::GRAPH_SUCCESS; +} + +static ge::graphStatus InferDataTypeTransposeKvCacheByBlock(gert::InferDataTypeContext *context) +{ + return ge::GRAPH_SUCCESS; +} + +IMPL_OP_INFERSHAPE(TransposeKvCacheByBlock) + .InferShape(InferShapeTransposeKvCacheByBlock) + .InferDataType(InferDataTypeTransposeKvCacheByBlock); +} // namespace ops diff --git a/csrc/transpose_kv_cache_by_block/op_host/transpose_kv_cache_by_block_tiling.cpp b/csrc/transpose_kv_cache_by_block/op_host/transpose_kv_cache_by_block_tiling.cpp new file mode 100644 index 00000000..5e8ae289 --- /dev/null +++ b/csrc/transpose_kv_cache_by_block/op_host/transpose_kv_cache_by_block_tiling.cpp @@ -0,0 +1,182 @@ +#include "transpose_kv_cache_by_block_tiling.h" +#include "register/op_def_registry.h" +#include "tiling/platform/platform_ascendc.h" +#include "log/ops_log.h" +#include + +namespace optiling { + +constexpr uint64_t DATA_SIZE = 2; +constexpr uint64_t BLOCK_SIZE = 32; +constexpr uint64_t DB_ON = 2; + +constexpr uint32_t FULL_LOAD = 0; +constexpr uint32_t SPLIT_BLOCK_SIZE_ALIGNED_AND_DB = 1; +constexpr uint32_t SPLIT_BLOCK_SIZE_UNALIGNED_AND_DB = 3; +constexpr uint32_t SPLIT_BLOCK_SIZE_ALIGNED_AND_NOT_DB = 2; +constexpr uint32_t SPLIT_BLOCK_SIZE_UNALIGNED_AND_NOT_DB = 4; + +void findFactorsOptimized(std::vector &factors, int64_t n) { + + for (int64_t i = 1; i * i <= n; i++) { + if (n % i == 0) { + factors.push_back(i); + + if (i != n / i) { + factors.push_back(n / i); + } + } + } + + sort(factors.begin(), factors.end()); +} + +ge::graphStatus CalTiling(gert::TilingContext* context, TransposeKvCacheByBlockTilingData &tiling) +{ + fe::PlatFormInfos* platformInfoPtr = context->GetPlatformInfo(); + OPS_LOG_E_IF_NULL(context, platformInfoPtr, return ge::GRAPH_FAILED); + auto ascendcPlatform = platform_ascendc::PlatformAscendC(platformInfoPtr); + int64_t useCoreNum = ascendcPlatform.GetCoreNumAiv(); + uint64_t ubSize; + ascendcPlatform.GetCoreMemSize(platform_ascendc::CoreMemType::UB, ubSize); + + auto attr = context->GetAttrs(); + OPS_LOG_E_IF_NULL(context, attr, return ge::GRAPH_FAILED); + const int64_t* blockSizePtr = attr->GetAttrPointer(0); + const int64_t* headNumPtr = attr->GetAttrPointer(1); + const int64_t* headDimPtr = attr->GetAttrPointer(2); + const int64_t* splitNumPtr = attr->GetAttrPointer(3); + const int64_t* layerNumPtr = attr->GetAttrPointer(4); + OPS_CHECK(blockSizePtr == nullptr || headNumPtr == nullptr || headDimPtr == nullptr || + splitNumPtr == nullptr || layerNumPtr == nullptr, + OPS_LOG_E(context->GetNodeName(), "Get attr failed."), + return ge::GRAPH_FAILED); + + auto blockIDsTensor = context->GetDynamicInputTensor(2, 0); + OPS_LOG_E_IF_NULL(context, blockIDsTensor, return ge::GRAPH_FAILED); + + gert::Shape blockIDsTensorShape = blockIDsTensor->GetStorageShape(); + int64_t calBlockNum = static_cast(blockIDsTensorShape.GetDim(0)); + + tiling.set_calBlockNum(static_cast(calBlockNum)); + + int64_t blockSize = *blockSizePtr; + int64_t headNum = *headNumPtr; + int64_t headDim = *headDimPtr; + int64_t splitNum = *splitNumPtr; + int64_t layerNum = *layerNumPtr; + uint32_t tilingKey = FULL_LOAD; + + if (headDim * DATA_SIZE % BLOCK_SIZE != 0) { + OPS_LOG_E(context, "headDim * DATA_SIZE must be a multiple of 32 bytes."); + return ge::GRAPH_FAILED; + } + + std::vector factors; + findFactorsOptimized(factors, useCoreNum); + + uint32_t factorIndex = 0; + bool findSplitNum = true; + int64_t blockSizeSplitNum = factors[factorIndex]; + uint64_t dataSizeloadOnce = blockSize * headNum * headDim * DATA_SIZE; + // if can full load, not split blockSize and db + if (dataSizeloadOnce > ubSize) { + tilingKey = SPLIT_BLOCK_SIZE_ALIGNED_AND_DB; + // split blockSize and db + while (dataSizeloadOnce > (ubSize / DB_ON)) { + factorIndex += 1; + if (factorIndex == factors.size()) { + tilingKey = FULL_LOAD; + findSplitNum = false; + break; + } + blockSizeSplitNum = factors[factorIndex]; + dataSizeloadOnce = ((blockSize + blockSizeSplitNum - 1) / blockSizeSplitNum) * headNum * headDim * DATA_SIZE; + } + if (tilingKey == SPLIT_BLOCK_SIZE_ALIGNED_AND_DB && (blockSize % blockSizeSplitNum != 0)) { + tilingKey = SPLIT_BLOCK_SIZE_UNALIGNED_AND_DB; + } + } + + if (!findSplitNum) { + tilingKey = SPLIT_BLOCK_SIZE_ALIGNED_AND_NOT_DB; + // split blockSize but not db + findSplitNum = true; + factorIndex = 0; + blockSizeSplitNum = factors[factorIndex]; + dataSizeloadOnce = blockSize * headNum * headDim * DATA_SIZE; + while (dataSizeloadOnce > ubSize) { + factorIndex += 1; + if (factorIndex == factors.size()) { + tilingKey = FULL_LOAD; + findSplitNum = false; + break; + } + blockSizeSplitNum = factors[factorIndex]; + dataSizeloadOnce = ((blockSize + blockSizeSplitNum - 1) / blockSizeSplitNum) * headNum * headDim * DATA_SIZE; + } + if (tilingKey == SPLIT_BLOCK_SIZE_ALIGNED_AND_NOT_DB && (blockSize % blockSizeSplitNum != 0)) { + tilingKey = SPLIT_BLOCK_SIZE_UNALIGNED_AND_NOT_DB; + } + } + + // headNum * headDim too large + if (!findSplitNum) { + OPS_LOG_E(context, "headNum * headDim * sizeof(half) > ubSize " + "or blockSize * headNum * headDim * sizeof(half) > ubSize * vectorCoreNum. " + "Currently, splitting headNum or headDim is not supported."); + return ge::GRAPH_FAILED; + } + tiling.set_blockSizePerTime(static_cast((blockSize + blockSizeSplitNum - 1) / blockSizeSplitNum)); + tiling.set_blockSizePerTimeTail(static_cast(blockSize % blockSizeSplitNum)); + tiling.set_blockSizeSplitNum(static_cast(blockSizeSplitNum)); + + tiling.set_blockSize(static_cast(blockSize)); + tiling.set_headNum(static_cast(headNum)); + tiling.set_headDim(static_cast(headDim)); + tiling.set_splitNum(static_cast(splitNum)); + tiling.set_layerNum(static_cast(layerNum)); + + int64_t totalRound = layerNum * calBlockNum; + + if ((totalRound * blockSizeSplitNum) < useCoreNum) { + useCoreNum = totalRound * blockSizeSplitNum; + } + int64_t blockPerCore = totalRound / (useCoreNum / blockSizeSplitNum); + int64_t tailCoreNum = totalRound % (useCoreNum / blockSizeSplitNum); + + tiling.set_useCoreNum(static_cast(useCoreNum)); + tiling.set_blockPerCore(static_cast(blockPerCore)); + tiling.set_tailCoreNum(static_cast(tailCoreNum)); + context->SetBlockDim(useCoreNum); + context->SetTilingKey(tilingKey); + + return ge::GRAPH_SUCCESS; +} + + +static ge::graphStatus TransposeKvCacheByBlockTilingFunc(gert::TilingContext* context) +{ + + TransposeKvCacheByBlockTilingData tiling; + auto status = CalTiling(context, tiling); + OP_CHECK(status != ge::GRAPH_SUCCESS, OPS_LOG_E(context->GetNodeName(), "Cal tiling failed."), + return ge::GRAPH_FAILED); + + tiling.SaveToBuffer(context->GetRawTilingData()->GetData(), context->GetRawTilingData()->GetCapacity()); + context->GetRawTilingData()->SetDataSize(tiling.GetDataSize()); + + return ge::GRAPH_SUCCESS; +} + +struct TransposeKvCacheByBlockCompileInfo {}; +ge::graphStatus TilingParseForTransposeKvCacheByBlock(gert::TilingParseContext *context) +{ + (void)context; + return ge::GRAPH_SUCCESS; +} + +IMPL_OP_OPTILING(TransposeKvCacheByBlock) + .Tiling(TransposeKvCacheByBlockTilingFunc) + .TilingParse(TilingParseForTransposeKvCacheByBlock); +} \ No newline at end of file diff --git a/csrc/transpose_kv_cache_by_block/op_host/transpose_kv_cache_by_block_tiling.h b/csrc/transpose_kv_cache_by_block/op_host/transpose_kv_cache_by_block_tiling.h new file mode 100644 index 00000000..c3b0901b --- /dev/null +++ b/csrc/transpose_kv_cache_by_block/op_host/transpose_kv_cache_by_block_tiling.h @@ -0,0 +1,23 @@ +#include "register/tilingdata_base.h" + +namespace optiling { +BEGIN_TILING_DATA_DEF(TransposeKvCacheByBlockTilingData) + // shape info + // TILING_DATA_FIELD_DEF(uint32_t, blockNum); + TILING_DATA_FIELD_DEF(uint32_t, blockSize); + TILING_DATA_FIELD_DEF(uint32_t, headNum); + TILING_DATA_FIELD_DEF(uint32_t, headDim); + TILING_DATA_FIELD_DEF(uint32_t, splitNum); + TILING_DATA_FIELD_DEF(uint32_t, layerNum); + // tiling info + TILING_DATA_FIELD_DEF(uint32_t, useCoreNum); + TILING_DATA_FIELD_DEF(uint32_t, blockPerCore); + TILING_DATA_FIELD_DEF(uint32_t, tailCoreNum); + TILING_DATA_FIELD_DEF(uint32_t, calBlockNum); + TILING_DATA_FIELD_DEF(uint32_t, blockSizePerTime); + TILING_DATA_FIELD_DEF(uint32_t, blockSizePerTimeTail); + TILING_DATA_FIELD_DEF(uint32_t, blockSizeSplitNum); +END_TILING_DATA_DEF; + +REGISTER_TILING_DATA_CLASS(TransposeKvCacheByBlock, TransposeKvCacheByBlockTilingData) +} diff --git a/csrc/transpose_kv_cache_by_block/op_kernel/common.h b/csrc/transpose_kv_cache_by_block/op_kernel/common.h new file mode 100644 index 00000000..e74089fa --- /dev/null +++ b/csrc/transpose_kv_cache_by_block/op_kernel/common.h @@ -0,0 +1,16 @@ +#include "kernel_operator.h" +using namespace AscendC; + +#ifndef __OP_KERNEL_KV_CACHE_TRANSPOSE_H__ +#define __OP_KERNEL_KV_CACHE_TRANSPOSE_H__ + +template +__aicore__ inline __gm__ T* GetTensorAddr(uint16_t index, GM_ADDR tensorPtr) { + __gm__ uint64_t* dataAddr = reinterpret_cast<__gm__ uint64_t*>(tensorPtr); + // The offset of the data address from the first address. + uint64_t tensorPtrOffset = *dataAddr; + // Moving 3 bits to the right means dividing by sizeof(uint64 t). + __gm__ uint64_t* retPtr = dataAddr + (tensorPtrOffset >> 3); + return reinterpret_cast<__gm__ T*>(*(retPtr + index)); +} +#endif \ No newline at end of file diff --git a/csrc/transpose_kv_cache_by_block/op_kernel/full_load.h b/csrc/transpose_kv_cache_by_block/op_kernel/full_load.h new file mode 100644 index 00000000..86951f69 --- /dev/null +++ b/csrc/transpose_kv_cache_by_block/op_kernel/full_load.h @@ -0,0 +1,141 @@ +#include "common.h" + +template +class TransposeKvCacheByBlockKernelFullLoad { + protected: + TQueBind vecInQueue_; + GlobalTensor kCacheGm_; + GlobalTensor vCacheGm_; + GlobalTensor blockIDsGm_; + + GM_ADDR kCachePtr_; + GM_ADDR vCachePtr_; + + // shape info + uint32_t blockNum_; + uint32_t blockSize_; + uint32_t headNum_; + uint32_t headDim_; + uint32_t splitNum_; + uint32_t layerNum_; + // tiling info + uint32_t useCoreNum_; + uint32_t blockPerCore_; + uint32_t tailCoreNum_; + uint32_t calBlockNum_; + + uint32_t srcFactor_; + uint32_t dstFactor_; + uint32_t copyOutLength_; + uint32_t dataBlockSize_; + + __aicore__ inline void CopyIn(GlobalTensor &cacheGm, uint32_t offsetBlock, DataCopyParams &repeatParams) { + LocalTensor cacheLocal = vecInQueue_.AllocTensor(); + for (uint32_t i = 0; i < splitNum_; ++i) { + DataCopy(cacheLocal[i * dstFactor_], cacheGm[i * srcFactor_ + offsetBlock], repeatParams); + } + vecInQueue_.EnQue(cacheLocal); + } + + __aicore__ inline void CopyOut(GlobalTensor &cacheGm, uint32_t offsetBlock) { + LocalTensor cacheLocal = vecInQueue_.DeQue(); + DataCopy(cacheGm[offsetBlock], cacheLocal, copyOutLength_); + vecInQueue_.FreeTensor(cacheLocal); + } + + __aicore__ inline void SetGlobalBuffers(uint32_t layerId) { + kCacheGm_.SetGlobalBuffer(GetTensorAddr(layerId, kCachePtr_)); + vCacheGm_.SetGlobalBuffer(GetTensorAddr(layerId, vCachePtr_)); + } + + __aicore__ inline void Caloffset(uint32_t &startBlock, uint32_t &endBlock, uint32_t &startLayer, uint32_t &endLayer) { + uint32_t blockIdx = GetBlockIdx(); + uint32_t curBlockStart; + uint32_t curBlocknum; + + if (blockIdx < tailCoreNum_) { + curBlockStart = blockIdx * (blockPerCore_ + 1); + curBlocknum = blockPerCore_ + 1; + } else { + curBlockStart = blockIdx * blockPerCore_ + tailCoreNum_; + curBlocknum = blockPerCore_; + } + uint32_t curBlockEnd = curBlockStart + curBlocknum; + startBlock = curBlockStart / layerNum_; + startLayer = curBlockStart % layerNum_; + endBlock = (curBlockEnd + layerNum_ - 1) / layerNum_; + endLayer = curBlockEnd % layerNum_; + if (endLayer == 0) { + endLayer = layerNum_; + } + } + + + public: + __aicore__ inline void Init(GM_ADDR KCache, GM_ADDR VCache, GM_ADDR blockIDs, + TransposeKvCacheByBlockTilingData* tilingData, TPipe* tPipe) { + kCachePtr_ = KCache; + vCachePtr_ = VCache; + blockIDsGm_.SetGlobalBuffer((__gm__ int64_t*)blockIDs); + + // shape info + blockSize_ = tilingData->blockSize; + headNum_ = tilingData->headNum; + headDim_ = tilingData->headDim; + splitNum_ = tilingData->splitNum; + layerNum_ = tilingData->layerNum; + // tiling info + useCoreNum_ = tilingData->useCoreNum; + blockPerCore_ = tilingData->blockPerCore; + tailCoreNum_ = tilingData->tailCoreNum; + calBlockNum_ = tilingData->calBlockNum; + + tPipe->InitBuffer(vecInQueue_, 1, TOTAL_UB_SIZE); + srcFactor_ = blockSize_ * headNum_ / splitNum_ * headDim_; + dstFactor_ = headNum_ / splitNum_ * headDim_; + copyOutLength_ = blockSize_ * headNum_ * headDim_; + dataBlockSize_ = static_cast(AscendC::GetDataBlockSizeInBytes()); + } + + __aicore__ inline void Process() { + DataCopyParams repeatParams; + repeatParams.blockCount = blockSize_; + repeatParams.blockLen = headNum_ / splitNum_ * headDim_ * sizeof(T) / dataBlockSize_; + repeatParams.srcStride = 0; + repeatParams.dstStride = (headNum_ * headDim_ - headNum_ / splitNum_ * headDim_) * sizeof(T) / dataBlockSize_; + + uint32_t startBlock; + uint32_t endBlock; + uint32_t startLayer; + uint32_t endLayer; + + Caloffset(startBlock, endBlock, startLayer, endLayer); + for (uint32_t i = startBlock; i < endBlock; ++i) { + int64_t blockId = blockIDsGm_.GetValue(i); + uint32_t offsetBlock = blockId * blockSize_ * headNum_ * headDim_; + uint32_t realStartLayer; + uint32_t realEndLayer; + if (i == startBlock) { + realStartLayer = startLayer; + } else { + realStartLayer = 0; + } + + if (i == (endBlock - 1)) { + realEndLayer = endLayer; + } else { + realEndLayer = layerNum_; + } + for (uint32_t layerId = realStartLayer; layerId < realEndLayer; ++layerId) { + SetGlobalBuffers(layerId); + + CopyIn(kCacheGm_, offsetBlock, repeatParams); + CopyOut(kCacheGm_, offsetBlock); + + CopyIn(vCacheGm_, offsetBlock, repeatParams); + CopyOut(vCacheGm_, offsetBlock); + } + + } + } +}; \ No newline at end of file diff --git a/csrc/transpose_kv_cache_by_block/op_kernel/general.h b/csrc/transpose_kv_cache_by_block/op_kernel/general.h new file mode 100644 index 00000000..1526a4db --- /dev/null +++ b/csrc/transpose_kv_cache_by_block/op_kernel/general.h @@ -0,0 +1,190 @@ +#include "common.h" + +template +class TransposeKvCacheByBlockKernelGeneral { + protected: + TQueBind queBind_; + GlobalTensor kCacheGm_; + GlobalTensor vCacheGm_; + GlobalTensor blockIDsGm_; + + GM_ADDR kCachePtr_; + GM_ADDR vCachePtr_; + + // shape info + uint32_t blockNum_; + uint32_t blockSize_; + uint32_t headNum_; + uint32_t headDim_; + uint32_t splitNum_; + uint32_t layerNum_; + uint32_t headNumSplited_; + uint32_t blockSizeSplitNum_; + + // tiling info + uint32_t useCoreNum_; + uint32_t blockPerCore_; + uint32_t tailCoreNum_; + uint32_t calBlockNum_; + + uint32_t srcFactor_; + uint32_t dstFactor_; + uint32_t copyOutLength_; + + uint32_t blockSizePerTime_; + uint32_t blockSizePerTimeTail_; + + uint32_t blockIdx_; + uint32_t dataBlockSize_; + bool needSync_; + + __aicore__ inline void CopyIn(GlobalTensor &cacheGm, uint32_t offsetBlock, DataCopyParams &repeatParams) { + LocalTensor cacheLocal = queBind_.AllocTensor(); + for (uint32_t i = 0; i < splitNum_; ++i) { + DataCopy(cacheLocal[i * dstFactor_], cacheGm[i * srcFactor_ + offsetBlock], repeatParams); + } + queBind_.EnQue(cacheLocal); + } + + __aicore__ inline void CopyOut(GlobalTensor &cacheGm, uint32_t offsetBlock) { + LocalTensor cacheLocal = queBind_.DeQue(); + AscendC::CrossCoreSetFlag<0x0, PIPE_MTE2>(0x8); + AscendC::CrossCoreWaitFlag(0x8); + DataCopy(cacheGm[offsetBlock], cacheLocal, copyOutLength_); + queBind_.FreeTensor(cacheLocal); + } + + __aicore__ inline void SetGlobalBuffers(uint32_t layerId) { + kCacheGm_.SetGlobalBuffer(GetTensorAddr(layerId, kCachePtr_)); + vCacheGm_.SetGlobalBuffer(GetTensorAddr(layerId, vCachePtr_)); + } + + __aicore__ inline void Caloffset(uint32_t &startBlock, uint32_t &endBlock, uint32_t &startLayer, uint32_t &endLayer) { + + uint32_t curBlockStart; + uint32_t curBlocknum; + uint32_t groupBlockIdx = blockIdx_ / blockSizeSplitNum_; + if (groupBlockIdx < tailCoreNum_) { + needSync_ = false; + curBlockStart = groupBlockIdx * (blockPerCore_ + 1); + curBlocknum = blockPerCore_ + 1; + } else { + needSync_ = true; + curBlockStart = groupBlockIdx * blockPerCore_ + tailCoreNum_; + curBlocknum = blockPerCore_; + } + uint32_t curBlockEnd = curBlockStart + curBlocknum; + startBlock = curBlockStart / layerNum_; + startLayer = curBlockStart % layerNum_; + endBlock = (curBlockEnd + layerNum_ - 1) / layerNum_; + endLayer = curBlockEnd % layerNum_; + if (endLayer == 0) { + endLayer = layerNum_; + } + } + + + public: + __aicore__ inline void Init(GM_ADDR KCache, GM_ADDR VCache, GM_ADDR blockIDs, + TransposeKvCacheByBlockTilingData* tilingData, TPipe* tPipe) { + kCachePtr_ = KCache; + vCachePtr_ = VCache; + blockIDsGm_.SetGlobalBuffer((__gm__ int64_t*)blockIDs); + blockIdx_ = GetBlockIdx(); + // shape info + blockSize_ = tilingData->blockSize; + headNum_ = tilingData->headNum; + headDim_ = tilingData->headDim; + splitNum_ = tilingData->splitNum; + layerNum_ = tilingData->layerNum; + // tiling info + useCoreNum_ = tilingData->useCoreNum; + blockPerCore_ = tilingData->blockPerCore; + tailCoreNum_ = tilingData->tailCoreNum; + calBlockNum_ = tilingData->calBlockNum; + blockSizeSplitNum_ = tilingData->blockSizeSplitNum; + blockSizePerTime_ = tilingData->blockSizePerTime; + blockSizePerTimeTail_ = tilingData->blockSizePerTimeTail; + headNumSplited_ = headNum_ / splitNum_; + + tPipe->InitBuffer(queBind_, DB, TOTAL_UB_SIZE / DB); + srcFactor_ = blockSize_ * headNumSplited_ * headDim_; + dstFactor_ = headNumSplited_ * headDim_; + copyOutLength_ = blockSizePerTime_ * headNum_ * headDim_; + dataBlockSize_ = static_cast(AscendC::GetDataBlockSizeInBytes()); + } + + __aicore__ inline void Process() { + DataCopyParams repeatParams; + repeatParams.blockCount = blockSizePerTime_; + repeatParams.blockLen = headNumSplited_ * headDim_ * sizeof(T) / dataBlockSize_; + repeatParams.srcStride = 0; + repeatParams.dstStride = (headNum_ * headDim_ - headNumSplited_ * headDim_) * sizeof(T) / dataBlockSize_; + + uint32_t startBlock; + uint32_t endBlock; + uint32_t startLayer; + uint32_t endLayer; + + Caloffset(startBlock, endBlock, startLayer, endLayer); + + for (uint32_t i = startBlock; i < endBlock; ++i) { + int64_t blockId = blockIDsGm_.GetValue(i); + uint32_t offsetBlock = blockId * blockSize_ * headNum_ * headDim_; + uint32_t realStartLayer; + uint32_t realEndLayer; + if (i == startBlock) { + realStartLayer = startLayer; + } else { + realStartLayer = 0; + } + + if (i == (endBlock - 1)) { + realEndLayer = endLayer; + } else { + realEndLayer = layerNum_; + } + for (uint32_t layerId = realStartLayer; layerId < realEndLayer; ++layerId) { + SetGlobalBuffers(layerId); + uint32_t blockSizeIndex = blockIdx_ % blockSizeSplitNum_; + uint32_t srcOffset; + uint32_t dstOffset; + if constexpr (needHandleUnFactorSplit) { + // handle tail + if (blockSizeIndex >= blockSizePerTimeTail_) { + repeatParams.blockCount = (blockSizePerTime_ - 1); + copyOutLength_ = (blockSizePerTime_ - 1) * headNum_ * headDim_; + srcOffset = (blockSizeIndex * blockSizePerTime_ - (blockSizeIndex - blockSizePerTimeTail_)) * headNumSplited_ * headDim_; + dstOffset = (blockSizeIndex * blockSizePerTime_ - (blockSizeIndex - blockSizePerTimeTail_)) * headNum_ * headDim_; + } else { + repeatParams.blockCount = blockSizePerTime_; + copyOutLength_ = blockSizePerTime_ * headNum_ * headDim_; + srcOffset = blockSizeIndex * blockSizePerTime_ * headNumSplited_ * headDim_; + dstOffset = blockSizeIndex * blockSizePerTime_ * headNum_ * headDim_; + } + } else { + repeatParams.blockCount = blockSizePerTime_; + copyOutLength_ = blockSizePerTime_ * headNum_ * headDim_; + srcOffset = blockSizeIndex * blockSizePerTime_ * headNumSplited_ * headDim_; + dstOffset = blockSizeIndex * blockSizePerTime_ * headNum_ * headDim_; + } + + CopyIn(kCacheGm_, offsetBlock + srcOffset, repeatParams); + CopyOut(kCacheGm_, offsetBlock + dstOffset); + + CopyIn(vCacheGm_, offsetBlock + srcOffset, repeatParams); + CopyOut(vCacheGm_, offsetBlock + dstOffset); + } + + } + + if (needSync_) { + AscendC::CrossCoreSetFlag<0x0, PIPE_MTE2>(0x8); + AscendC::CrossCoreWaitFlag(0x8); + + AscendC::CrossCoreSetFlag<0x0, PIPE_MTE2>(0x8); + AscendC::CrossCoreWaitFlag(0x8); + } + + } +}; \ No newline at end of file diff --git a/csrc/transpose_kv_cache_by_block/op_kernel/transpose_kv_cache_by_block.cpp b/csrc/transpose_kv_cache_by_block/op_kernel/transpose_kv_cache_by_block.cpp new file mode 100644 index 00000000..39ced744 --- /dev/null +++ b/csrc/transpose_kv_cache_by_block/op_kernel/transpose_kv_cache_by_block.cpp @@ -0,0 +1,36 @@ +#include "kernel_operator.h" +#include "full_load.h" +#include "general.h" + + +extern "C" __global__ __aicore__ void transpose_kv_cache_by_block(GM_ADDR KCache, GM_ADDR VCache, GM_ADDR blockIDs, GM_ADDR workspace, GM_ADDR tiling) { + GET_TILING_DATA(tiling_data, tiling); + TPipe tPipe; + KERNEL_TASK_TYPE_DEFAULT(KERNEL_TYPE_MIX_AIV_1_0); + if (TILING_KEY_IS(0)) { + // full load not db + TransposeKvCacheByBlockKernelFullLoad kernel; + kernel.Init(KCache, VCache, blockIDs, &tiling_data, &tPipe); + kernel.Process(); + } else if (TILING_KEY_IS(1)) { + // db \ align split blockSize + TransposeKvCacheByBlockKernelGeneral kernel; + kernel.Init(KCache, VCache, blockIDs, &tiling_data, &tPipe); + kernel.Process(); + } else if (TILING_KEY_IS(2)) { + // not db \ align split blockSize + TransposeKvCacheByBlockKernelGeneral kernel; + kernel.Init(KCache, VCache, blockIDs, &tiling_data, &tPipe); + kernel.Process(); + } else if (TILING_KEY_IS(3)) { + // db \ unalign split blockSize + TransposeKvCacheByBlockKernelGeneral kernel; + kernel.Init(KCache, VCache, blockIDs, &tiling_data, &tPipe); + kernel.Process(); + } else if (TILING_KEY_IS(4)) { + // not db \ unalign split blockSize + TransposeKvCacheByBlockKernelGeneral kernel; + kernel.Init(KCache, VCache, blockIDs, &tiling_data, &tPipe); + kernel.Process(); + } +} \ No newline at end of file diff --git a/tests/e2e/nightly/single_node/ops/singlecard_ops/test_transpose_kv_cache_by_block.py b/tests/e2e/nightly/single_node/ops/singlecard_ops/test_transpose_kv_cache_by_block.py new file mode 100644 index 00000000..7527f1fb --- /dev/null +++ b/tests/e2e/nightly/single_node/ops/singlecard_ops/test_transpose_kv_cache_by_block.py @@ -0,0 +1,137 @@ +import random +import unittest + +import torch +import torch_npu + +from vllm_ascend.utils import enable_custom_op + +enable_custom_op() + +torch.set_printoptions(threshold=float("inf")) + +def clone_kv_cache(k_caches, v_caches): + new_k_caches = [cache.clone() for cache in k_caches] + new_v_caches = [cache.clone() for cache in v_caches] + return new_k_caches, new_v_caches + +class TestTransposeKvCacheByBlock(unittest.TestCase): + def compute_golden(self, k_caches, v_caches, block_ids_tensor, block_size, num_kv_head, head_dim, num_need_pulls, layers, dtype): + num_blocks = block_ids_tensor.shape[0] + + block_ids_tensor = block_ids_tensor.to(dtype=torch.int32) + block_offsets = torch.arange(0, block_size, dtype=torch.int32).npu() + slot_mapping = block_offsets.reshape( + (1, block_size)) + block_ids_tensor.reshape( + (num_blocks, 1)) * block_size + slot_mapping = slot_mapping.flatten() + block_len = num_blocks * block_size + block_len_tensor = torch.tensor([block_len],dtype=torch.int32).npu() + + block_table = block_ids_tensor.view(1, -1) + seq_start_tensor = torch.tensor([0],dtype=torch.int32).npu() + + k = torch.empty(block_len, num_kv_head, head_dim, dtype=dtype).npu() + v = torch.empty(block_len, num_kv_head, head_dim, dtype=dtype).npu() + + for layer in range(layers): + k_cache_layer = k_caches[layer] + v_cache_layer = v_caches[layer] + + torch_npu.atb.npu_paged_cache_load( + k_cache_layer, + v_cache_layer, + block_table, + block_len_tensor, + seq_starts=seq_start_tensor, + key=k, + value=v, + ) + + k = k.view(num_blocks, num_need_pulls, block_size, -1) + k.transpose_(1, 2) + k = k.contiguous().view(block_len, num_kv_head, -1) + + v = v.view(num_blocks, num_need_pulls, block_size, -1) + v.transpose_(1, 2) + v = v.contiguous().view(block_len, num_kv_head, -1) + + torch_npu._npu_reshape_and_cache( + key=k, + value=v, + key_cache=k_cache_layer, + value_cache=v_cache_layer, + slot_indices=slot_mapping, + ) + del k, v + + def test_transpose_kv_cache_by_block(self): + # (layers, block_num, block_size, num_kv_head, head_dim, num_need_pulls) + test_cases = [ + (16, 128, 128, 4, 128, 4), + (16, 128, 128, 4, 128, 2), + (16, 128, 128, 4, 128, 1), + (16, 128, 128, 8, 128, 8), + (16, 128, 128, 8, 128, 4), + (16, 128, 128, 8, 128, 2), + ] + dtypes = [torch.float16, torch.bfloat16] + for dtype in dtypes: + for layers, block_num, block_size, num_kv_head, head_dim, num_need_pulls in test_cases: + with self.subTest(dtype=dtype, shape=f"({layers}, {block_num}, {block_size}, {num_kv_head}, {head_dim}, {num_need_pulls})"): + k_caches = [] + v_caches = [] + block_id_num = 33 + block_ids_tensor = torch.randperm(block_num, dtype=torch.int64, device="npu")[:block_id_num] + for i in range(layers): + kcache = torch.randn(block_num, block_size, num_kv_head, head_dim, dtype=dtype, device="npu") + vcache = torch.randn(block_num, block_size, num_kv_head, head_dim, dtype=dtype, device="npu") + k_caches.append(kcache) + v_caches.append(vcache) + + cloned_k_caches, cloned_v_caches = clone_kv_cache(k_caches, v_caches) + self.compute_golden(cloned_k_caches, cloned_v_caches, block_ids_tensor, block_size, num_kv_head, head_dim, num_need_pulls, layers, dtype) + torch.ops._C_ascend.transpose_kv_cache_by_block(k_caches, v_caches, block_ids_tensor, block_size, num_kv_head, head_dim, num_need_pulls, layers) + + for i in range (layers): + self.assert_tensors_almost_equal(k_caches[i], cloned_k_caches[i], dtype) + self.assert_tensors_almost_equal(v_caches[i], cloned_v_caches[i], dtype) + + def assert_tensors_almost_equal(self, actual, expected, dtype): + """Check if two tensors are approximately equal (considering floating point errors)""" + self.assertEqual(actual.shape, expected.shape, "Shape mismatch") + + # Check for NaN + self.assertFalse( + torch.isnan(actual).any(), "Actual result contains NaN") + self.assertFalse( + torch.isnan(expected).any(), "Expected result contains NaN") + + # Check for Inf + self.assertFalse( + torch.isinf(actual).any(), "Actual result contains Inf") + self.assertFalse( + torch.isinf(expected).any(), "Expected result contains Inf") + + # Set different tolerances based on data type + if dtype == torch.float16: + rtol, atol = 1e-5, 1e-5 + else: # bfloat16 + rtol, atol = 1.5e-5, 1.5e-5 + + # Compare values + diff = torch.abs(actual - expected) + max_diff = diff.max().item() + max_expected = torch.abs(expected).max().item() + + # Check relative and absolute errors + if max_expected > 0: + relative_diff = max_diff / max_expected + self.assertLessEqual( + relative_diff, + rtol, + f"Relative error too large: {relative_diff} > {rtol}. Max difference: {max_diff}", + ) + + self.assertLessEqual(max_diff, atol, + f"Absolute error too large: {max_diff} > {atol}") diff --git a/vllm_ascend/distributed/kv_transfer/kv_p2p/mooncake_connector.py b/vllm_ascend/distributed/kv_transfer/kv_p2p/mooncake_connector.py index 8fd0d4d2..ca164ad8 100644 --- a/vllm_ascend/distributed/kv_transfer/kv_p2p/mooncake_connector.py +++ b/vllm_ascend/distributed/kv_transfer/kv_p2p/mooncake_connector.py @@ -46,10 +46,11 @@ from vllm.v1.core.sched.output import SchedulerOutput from vllm.v1.kv_cache_interface import KVCacheConfig from vllm.v1.request import RequestStatus +from vllm_ascend import envs as ascend_envs from vllm_ascend.ascend_config import get_ascend_config, init_ascend_config from vllm_ascend.distributed.kv_transfer.utils.mooncake_transfer_engine import global_te from vllm_ascend.distributed.kv_transfer.utils.utils import get_transfer_timeout_value -from vllm_ascend.utils import is_vl_model +from vllm_ascend.utils import enable_custom_op, is_vl_model # isort: off if TYPE_CHECKING: @@ -570,8 +571,39 @@ class KVCacheRecvingThread(threading.Thread): is_kv_transfer_end = global_offset == tp_num_need_pulls * self._prefill_pp_size - 1 need_cat_cache = tp_num_need_pulls > 1 and is_kv_transfer_end need_nz_cache = get_ascend_config().enable_kv_nz and is_kv_transfer_end + use_fused_op = ascend_envs.VLLM_ASCEND_FUSION_OP_TRANSPOSE_KV_CACHE_BY_BLOCK if need_nz_cache or need_cat_cache: - self.reformat_kv_cache(grouped_local_block_ids, tp_num_need_pulls, need_cat_cache, need_nz_cache) + # use fused op to reformat kv cache, we keep original implementation to provide ability to disable it. + if use_fused_op and enable_custom_op(): + if need_cat_cache: + # the fused op only support cat GQA/MHA kv cache by head + self.reformat_kv_cache_with_fused_op(grouped_local_block_ids, tp_num_need_pulls) + if need_nz_cache: + # maybe use fused op to reformat kv nz too in the future. + self.reformat_kv_cache(grouped_local_block_ids, tp_num_need_pulls, False, need_nz_cache) + else: + self.reformat_kv_cache(grouped_local_block_ids, tp_num_need_pulls, need_cat_cache, need_nz_cache) + + def reformat_kv_cache_with_fused_op(self, block_ids: list[list[int]], tp_num_need_pulls: int): + # Get necessary parameters + k_cache = list(self.kv_caches.values())[0][0] + device = k_cache.device + head_dim = self.model_config.hf_text_config.head_dim + block_size = self.vllm_config.cache_config.block_size + num_kv_head = max(self.model_config.hf_text_config.num_key_value_heads // self.tp_size, 1) + layers = self.model_config.hf_text_config.num_hidden_layers + flat_block_ids = [item for sublist in block_ids for item in sublist] + block_ids_tensor = torch.tensor(flat_block_ids, dtype=torch.int64, device=device) + + k_caches = [] + v_caches = [] + for _, (k_cache_layer, v_cache_layer) in self.kv_caches.items(): + k_caches.append(k_cache_layer) + v_caches.append(v_cache_layer) + + torch.ops._C_ascend.transpose_kv_cache_by_block( + k_caches, v_caches, block_ids_tensor, block_size, num_kv_head, head_dim, tp_num_need_pulls, layers + ) def reformat_kv_cache( self, diff --git a/vllm_ascend/envs.py b/vllm_ascend/envs.py index 3e5c4cc7..6fb90aea 100644 --- a/vllm_ascend/envs.py +++ b/vllm_ascend/envs.py @@ -111,6 +111,10 @@ env_variables: dict[str, Callable[[], Any]] = { "VLLM_ASCEND_ENABLE_FUSED_MC2": lambda: int(os.getenv("VLLM_ASCEND_ENABLE_FUSED_MC2", "0")), # Whether to anbale balance scheduling "VLLM_ASCEND_BALANCE_SCHEDULING": lambda: bool(int(os.getenv("VLLM_ASCEND_BALANCE_SCHEDULING", "0"))), + # use fused op transpose_kv_cache_by_block, default is True + "VLLM_ASCEND_FUSION_OP_TRANSPOSE_KV_CACHE_BY_BLOCK": lambda: bool( + int(os.getenv("VLLM_ASCEND_FUSION_OP_TRANSPOSE_KV_CACHE_BY_BLOCK", "1")) + ), } # end-env-vars-definition