diff --git a/csrc/build_aclnn.sh b/csrc/build_aclnn.sh index 938d56a0..fc4c53d0 100644 --- a/csrc/build_aclnn.sh +++ b/csrc/build_aclnn.sh @@ -24,7 +24,7 @@ elif [[ "$SOC_VERSION" =~ ^ascend910b ]]; then ABSOLUTE_CATLASS_PATH=$(cd "${CATLASS_PATH}" && pwd) export CPATH=${ABSOLUTE_CATLASS_PATH}:${CPATH} - CUSTOM_OPS="grouped_matmul_swiglu_quant_weight_nz_tensor_list;lightning_indexer;sparse_flash_attention;matmul_allreduce_add_rmsnorm;moe_init_routing_custom;moe_gating_top_k;add_rms_norm_bias;apply_top_k_top_p_custom;" + CUSTOM_OPS="grouped_matmul_swiglu_quant_weight_nz_tensor_list;lightning_indexer;sparse_flash_attention;matmul_allreduce_add_rmsnorm;moe_init_routing_custom;moe_gating_top_k;add_rms_norm_bias;apply_top_k_top_p_custom;transpose_kv_cache_by_block;" SOC_ARG="ascend910b" elif [[ "$SOC_VERSION" =~ ^ascend910_93 ]]; then # ASCEND910C (A3) series @@ -81,6 +81,7 @@ elif [[ "$SOC_VERSION" =~ ^ascend910_93 ]]; then "moe_gating_top_k" "add_rms_norm_bias" "apply_top_k_top_p_custom" + "transpose_kv_cache_by_block" ) CUSTOM_OPS=$(IFS=';'; echo "${CUSTOM_OPS_ARRAY[*]}") SOC_ARG="ascend910_93" diff --git a/csrc/torch_binding.cpp b/csrc/torch_binding.cpp index 146123ec..963bf854 100644 --- a/csrc/torch_binding.cpp +++ b/csrc/torch_binding.cpp @@ -1343,6 +1343,22 @@ std::tuple npu_add_rms_norm_bias( return std::tuple(y, rstd, x); } +void transpose_kv_cache_by_block( + const at::TensorList &kCache, + const at::TensorList &vCache, + const at::Tensor &blockIDs, + int64_t blockSize, + int64_t headNum, + int64_t headDim, + int64_t splitNum, + int64_t layerNum) +{ + + EXEC_NPU_CMD(aclnnTransposeKvCacheByBlock, kCache, vCache, blockIDs, + blockSize, headNum, headDim, splitNum, layerNum); + +} + } // namespace vllm_ascend TORCH_LIBRARY_EXPAND(CONCAT(_C, _ascend), ops) @@ -1521,4 +1537,8 @@ TORCH_LIBRARY_EXPAND(CONCAT(_C, _ascend), ops) ops.def("npu_apply_top_k_top_p(Tensor logits, Tensor? p=None, Tensor? 
k=None) -> Tensor"); ops.impl("npu_apply_top_k_top_p", torch::kPrivateUse1, &vllm_ascend::npu_apply_top_k_top_p); + ops.def( + "transpose_kv_cache_by_block(Tensor[] kCache, Tensor[] vCache, Tensor blockIDs, int blockSize, int headNum, int headDim, int splitNum, int layerNum) -> ()" + ); + ops.impl("transpose_kv_cache_by_block", torch::kPrivateUse1, &vllm_ascend::transpose_kv_cache_by_block); } diff --git a/csrc/torch_binding_meta.cpp b/csrc/torch_binding_meta.cpp index af550134..378519bc 100644 --- a/csrc/torch_binding_meta.cpp +++ b/csrc/torch_binding_meta.cpp @@ -435,6 +435,20 @@ std::tuple npu_add_rms_norm_bias_meta( at::Tensor x = at::empty_symint(x1.sym_sizes(), x1.options()); return std::tuple(y, rstd, x); } + +void transpose_kv_cache_by_block_meta( + const at::TensorList &k_cache, + const at::TensorList &v_cache, + const at::Tensor &block_ids, + int64_t block_size, + int64_t head_num, + int64_t head_dim, + int64_t split_num, + int64_t layer_num) +{ + return; +} + } // namespace meta } // namespace vllm_ascend @@ -475,5 +489,7 @@ TORCH_LIBRARY_IMPL_EXPAND(CONCAT(_C, _ascend), Meta, ops) { ops.impl("moe_gating_top_k", &vllm_ascend::meta::moe_gating_top_k_meta); // Add_Rms_Norm_Bias ops.impl("npu_add_rms_norm_bias", &vllm_ascend::meta::npu_add_rms_norm_bias_meta); + // transpose_kv_cache_by_block + ops.impl("transpose_kv_cache_by_block", &vllm_ascend::meta::transpose_kv_cache_by_block_meta); } } diff --git a/csrc/transpose_kv_cache_by_block/op_host/CMakeLists.txt b/csrc/transpose_kv_cache_by_block/op_host/CMakeLists.txt new file mode 100644 index 00000000..dfd3d15d --- /dev/null +++ b/csrc/transpose_kv_cache_by_block/op_host/CMakeLists.txt @@ -0,0 +1,40 @@ +# This program is free software, you can redistribute it and/or modify it. +# Copyright (c) 2025 Huawei Technologies Co., Ltd. +# This file is a part of the CANN Open Software. +# Licensed under CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. 
+# ====================================================================================================================== + +add_ops_compile_options( + OP_NAME TransposeKvCacheByBlock + OPTIONS --cce-auto-sync=off + -Wno-deprecated-declarations + -Werror + -mllvm -cce-aicore-hoist-movemask=false + --op_relocatable_kernel_binary=true +) + +target_sources(op_host_aclnn PRIVATE + transpose_kv_cache_by_block_def.cpp +) + +target_sources(optiling PRIVATE + transpose_kv_cache_by_block_tiling.cpp +) + +if (NOT BUILD_OPEN_PROJECT) + target_sources(opmaster_ct PRIVATE + transpose_kv_cache_by_block_tiling.cpp + ) +endif () + +target_include_directories(optiling PRIVATE + ${CMAKE_CURRENT_SOURCE_DIR} +) + +target_sources(opsproto PRIVATE + transpose_kv_cache_by_block_proto.cpp +) + diff --git a/csrc/transpose_kv_cache_by_block/op_host/transpose_kv_cache_by_block_def.cpp b/csrc/transpose_kv_cache_by_block/op_host/transpose_kv_cache_by_block_def.cpp new file mode 100644 index 00000000..5e1c1144 --- /dev/null +++ b/csrc/transpose_kv_cache_by_block/op_host/transpose_kv_cache_by_block_def.cpp @@ -0,0 +1,36 @@ +#include "register/op_def_registry.h" + +namespace ops { +class TransposeKvCacheByBlock : public OpDef { +public: + explicit TransposeKvCacheByBlock(const char* name) : OpDef(name) + { + this->Input("KCache") + .ParamType(DYNAMIC) + .DataType({ge::DT_FLOAT16, ge::DT_BF16}) + .Format({ge::FORMAT_ND, ge::FORMAT_ND}) + .UnknownShapeFormat({ge::FORMAT_ND, ge::FORMAT_ND}); + this->Input("VCache") + .ParamType(DYNAMIC) + .DataType({ge::DT_FLOAT16, ge::DT_BF16}) + .Format({ge::FORMAT_ND, ge::FORMAT_ND}) + .UnknownShapeFormat({ge::FORMAT_ND, ge::FORMAT_ND}); + this->Input("blockIDs") + .ParamType(REQUIRED) + .DataTypeList({ge::DT_INT64}) + .FormatList({ge::FORMAT_ND}) + .UnknownShapeFormat({ge::FORMAT_ND}); + this->Attr("blockSize").Int(); + this->Attr("headNum").Int(); + this->Attr("headDim").Int(); + this->Attr("splitNum").Int(); + this->Attr("layerNum").Int(); + + this->AICore().AddConfig("ascend910b"); + this->AICore().AddConfig("ascend910_93"); + + } +}; + +OP_ADD(TransposeKvCacheByBlock); +} diff --git a/csrc/transpose_kv_cache_by_block/op_host/transpose_kv_cache_by_block_proto.cpp b/csrc/transpose_kv_cache_by_block/op_host/transpose_kv_cache_by_block_proto.cpp new file mode 100644 index 00000000..24f2d84d --- /dev/null +++ b/csrc/transpose_kv_cache_by_block/op_host/transpose_kv_cache_by_block_proto.cpp @@ -0,0 +1,36 @@ +/** + * This program is free software, you can redistribute it and/or modify it. + * Copyright (c) 2025 Huawei Technologies Co., Ltd. + * This file is a part of the CANN Open Software. + * Licensed under CANN Open Software License Agreement Version 2.0 (the "License"). + * Please refer to the License for details. You may not use this file except in compliance with the License. + * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. + * See LICENSE in the root of the software repository for the full text of the License. + */ + +/*! 
+ * \file transpose_kv_cache_by_block_proto.cpp + * \brief + */ +#include +#include +#include "error/ops_error.h" + +using namespace ge; + +namespace ops { + +static ge::graphStatus InferShapeTransposeKvCacheByBlock(gert::InferShapeContext* context) +{ + return ge::GRAPH_SUCCESS; +} + +static ge::graphStatus InferDataTypeTransposeKvCacheByBlock(gert::InferDataTypeContext *context) +{ + return ge::GRAPH_SUCCESS; +} + +IMPL_OP_INFERSHAPE(TransposeKvCacheByBlock) + .InferShape(InferShapeTransposeKvCacheByBlock) + .InferDataType(InferDataTypeTransposeKvCacheByBlock); +} // namespace ops
diff --git a/csrc/transpose_kv_cache_by_block/op_host/transpose_kv_cache_by_block_tiling.cpp b/csrc/transpose_kv_cache_by_block/op_host/transpose_kv_cache_by_block_tiling.cpp new file mode 100644 index 00000000..5e8ae289 --- /dev/null +++ b/csrc/transpose_kv_cache_by_block/op_host/transpose_kv_cache_by_block_tiling.cpp @@ -0,0 +1,182 @@ +#include "transpose_kv_cache_by_block_tiling.h" +#include "register/op_def_registry.h" +#include "tiling/platform/platform_ascendc.h" +#include "log/ops_log.h" +#include + +namespace optiling { + +constexpr uint64_t DATA_SIZE = 2; +constexpr uint64_t BLOCK_SIZE = 32; +constexpr uint64_t DB_ON = 2; + +constexpr uint32_t FULL_LOAD = 0; +constexpr uint32_t SPLIT_BLOCK_SIZE_ALIGNED_AND_DB = 1; +constexpr uint32_t SPLIT_BLOCK_SIZE_UNALIGNED_AND_DB = 3; +constexpr uint32_t SPLIT_BLOCK_SIZE_ALIGNED_AND_NOT_DB = 2; +constexpr uint32_t SPLIT_BLOCK_SIZE_UNALIGNED_AND_NOT_DB = 4; + +void findFactorsOptimized(std::vector<int64_t> &factors, int64_t n) { + + for (int64_t i = 1; i * i <= n; i++) { + if (n % i == 0) { + factors.push_back(i); + + if (i != n / i) { + factors.push_back(n / i); + } + } + } + + sort(factors.begin(), factors.end()); +} + +ge::graphStatus CalTiling(gert::TilingContext* context, TransposeKvCacheByBlockTilingData &tiling) +{ + fe::PlatFormInfos* platformInfoPtr = context->GetPlatformInfo(); + OPS_LOG_E_IF_NULL(context, platformInfoPtr, return ge::GRAPH_FAILED); + auto ascendcPlatform = platform_ascendc::PlatformAscendC(platformInfoPtr); + int64_t useCoreNum = ascendcPlatform.GetCoreNumAiv(); + uint64_t ubSize; + ascendcPlatform.GetCoreMemSize(platform_ascendc::CoreMemType::UB, ubSize); + + auto attr = context->GetAttrs(); + OPS_LOG_E_IF_NULL(context, attr, return ge::GRAPH_FAILED); + const int64_t* blockSizePtr = attr->GetAttrPointer<int64_t>(0); + const int64_t* headNumPtr = attr->GetAttrPointer<int64_t>(1); + const int64_t* headDimPtr = attr->GetAttrPointer<int64_t>(2); + const int64_t* splitNumPtr = attr->GetAttrPointer<int64_t>(3); + const int64_t* layerNumPtr = attr->GetAttrPointer<int64_t>(4); + OPS_CHECK(blockSizePtr == nullptr || headNumPtr == nullptr || headDimPtr == nullptr || + splitNumPtr == nullptr || layerNumPtr == nullptr, + OPS_LOG_E(context->GetNodeName(), "Get attr failed."), + return ge::GRAPH_FAILED); + + auto blockIDsTensor = context->GetDynamicInputTensor(2, 0); + OPS_LOG_E_IF_NULL(context, blockIDsTensor, return ge::GRAPH_FAILED); + + gert::Shape blockIDsTensorShape = blockIDsTensor->GetStorageShape(); + int64_t calBlockNum = static_cast<int64_t>(blockIDsTensorShape.GetDim(0)); + + tiling.set_calBlockNum(static_cast<uint32_t>(calBlockNum)); + + int64_t blockSize = *blockSizePtr; + int64_t headNum = *headNumPtr; + int64_t headDim = *headDimPtr; + int64_t splitNum = *splitNumPtr; + int64_t layerNum = *layerNumPtr; + uint32_t tilingKey = FULL_LOAD; + + if (headDim * DATA_SIZE % BLOCK_SIZE != 0) { + OPS_LOG_E(context, "headDim * DATA_SIZE must be a multiple of 32 bytes."); + return ge::GRAPH_FAILED; + } + + std::vector<int64_t> factors; + findFactorsOptimized(factors, useCoreNum); + + uint32_t factorIndex = 0; + bool findSplitNum = true; + int64_t blockSizeSplitNum = factors[factorIndex]; + uint64_t dataSizeloadOnce = blockSize * headNum * headDim * DATA_SIZE; + // if the whole block can be loaded at once, do not split blockSize or use double buffering (db) + if (dataSizeloadOnce > ubSize) { + tilingKey = SPLIT_BLOCK_SIZE_ALIGNED_AND_DB; + // split blockSize and use double buffering + while (dataSizeloadOnce > (ubSize / DB_ON)) { + factorIndex += 1; + if (factorIndex == factors.size()) { + tilingKey = FULL_LOAD; + findSplitNum = false; + break; + } + blockSizeSplitNum = factors[factorIndex]; + dataSizeloadOnce = ((blockSize + blockSizeSplitNum - 1) / blockSizeSplitNum) * headNum * headDim * DATA_SIZE; + } + if (tilingKey == SPLIT_BLOCK_SIZE_ALIGNED_AND_DB && (blockSize % blockSizeSplitNum != 0)) { + tilingKey = SPLIT_BLOCK_SIZE_UNALIGNED_AND_DB; + } + } + + if (!findSplitNum) { + tilingKey = SPLIT_BLOCK_SIZE_ALIGNED_AND_NOT_DB; + // split blockSize without double buffering + findSplitNum = true; + factorIndex = 0; + blockSizeSplitNum = factors[factorIndex]; + dataSizeloadOnce = blockSize * headNum * headDim * DATA_SIZE; + while (dataSizeloadOnce > ubSize) { + factorIndex += 1; + if (factorIndex == factors.size()) { + tilingKey = FULL_LOAD; + findSplitNum = false; + break; + } + blockSizeSplitNum = factors[factorIndex]; + dataSizeloadOnce = ((blockSize + blockSizeSplitNum - 1) / blockSizeSplitNum) * headNum * headDim * DATA_SIZE; + } + if (tilingKey == SPLIT_BLOCK_SIZE_ALIGNED_AND_NOT_DB && (blockSize % blockSizeSplitNum != 0)) { + tilingKey = SPLIT_BLOCK_SIZE_UNALIGNED_AND_NOT_DB; + } + } + + // headNum * headDim too large + if (!findSplitNum) { + OPS_LOG_E(context, "headNum * headDim * sizeof(half) > ubSize " "or blockSize * headNum * headDim * sizeof(half) > ubSize * vectorCoreNum. " "Currently, splitting headNum or headDim is not supported."); + return ge::GRAPH_FAILED; + } + tiling.set_blockSizePerTime(static_cast<uint32_t>((blockSize + blockSizeSplitNum - 1) / blockSizeSplitNum)); + tiling.set_blockSizePerTimeTail(static_cast<uint32_t>(blockSize % blockSizeSplitNum)); + tiling.set_blockSizeSplitNum(static_cast<uint32_t>(blockSizeSplitNum)); + + tiling.set_blockSize(static_cast<uint32_t>(blockSize)); + tiling.set_headNum(static_cast<uint32_t>(headNum)); + tiling.set_headDim(static_cast<uint32_t>(headDim)); + tiling.set_splitNum(static_cast<uint32_t>(splitNum)); + tiling.set_layerNum(static_cast<uint32_t>(layerNum)); + + int64_t totalRound = layerNum * calBlockNum; + + if ((totalRound * blockSizeSplitNum) < useCoreNum) { + useCoreNum = totalRound * blockSizeSplitNum; + } + int64_t blockPerCore = totalRound / (useCoreNum / blockSizeSplitNum); + int64_t tailCoreNum = totalRound % (useCoreNum / blockSizeSplitNum); + + tiling.set_useCoreNum(static_cast<uint32_t>(useCoreNum)); + tiling.set_blockPerCore(static_cast<uint32_t>(blockPerCore)); + tiling.set_tailCoreNum(static_cast<uint32_t>(tailCoreNum)); + context->SetBlockDim(useCoreNum); + context->SetTilingKey(tilingKey); + + return ge::GRAPH_SUCCESS; +} + + +static ge::graphStatus TransposeKvCacheByBlockTilingFunc(gert::TilingContext* context) +{ + + TransposeKvCacheByBlockTilingData tiling; + auto status = CalTiling(context, tiling); + OP_CHECK(status != ge::GRAPH_SUCCESS, OPS_LOG_E(context->GetNodeName(), "Cal tiling failed."), + return ge::GRAPH_FAILED); + + tiling.SaveToBuffer(context->GetRawTilingData()->GetData(), context->GetRawTilingData()->GetCapacity()); + context->GetRawTilingData()->SetDataSize(tiling.GetDataSize()); + + return ge::GRAPH_SUCCESS; +} + +struct TransposeKvCacheByBlockCompileInfo {}; +ge::graphStatus TilingParseForTransposeKvCacheByBlock(gert::TilingParseContext *context) +{ + (void)context; + return ge::GRAPH_SUCCESS; +} + +IMPL_OP_OPTILING(TransposeKvCacheByBlock) + .Tiling(TransposeKvCacheByBlockTilingFunc) + .TilingParse(TilingParseForTransposeKvCacheByBlock); +} \ No newline at end of file
diff --git a/csrc/transpose_kv_cache_by_block/op_host/transpose_kv_cache_by_block_tiling.h b/csrc/transpose_kv_cache_by_block/op_host/transpose_kv_cache_by_block_tiling.h new file mode 100644 index 00000000..c3b0901b --- /dev/null +++ b/csrc/transpose_kv_cache_by_block/op_host/transpose_kv_cache_by_block_tiling.h @@ -0,0 +1,23 @@ +#include "register/tilingdata_base.h" + +namespace optiling { +BEGIN_TILING_DATA_DEF(TransposeKvCacheByBlockTilingData) + // shape info + // TILING_DATA_FIELD_DEF(uint32_t, blockNum); + TILING_DATA_FIELD_DEF(uint32_t, blockSize); + TILING_DATA_FIELD_DEF(uint32_t, headNum); + TILING_DATA_FIELD_DEF(uint32_t, headDim); + TILING_DATA_FIELD_DEF(uint32_t, splitNum); + TILING_DATA_FIELD_DEF(uint32_t, layerNum); + // tiling info + TILING_DATA_FIELD_DEF(uint32_t, useCoreNum); + TILING_DATA_FIELD_DEF(uint32_t, blockPerCore); + TILING_DATA_FIELD_DEF(uint32_t, tailCoreNum); + TILING_DATA_FIELD_DEF(uint32_t, calBlockNum); + TILING_DATA_FIELD_DEF(uint32_t, blockSizePerTime); + TILING_DATA_FIELD_DEF(uint32_t, blockSizePerTimeTail); + TILING_DATA_FIELD_DEF(uint32_t, blockSizeSplitNum); +END_TILING_DATA_DEF; + +REGISTER_TILING_DATA_CLASS(TransposeKvCacheByBlock, TransposeKvCacheByBlockTilingData) +} diff --git a/csrc/transpose_kv_cache_by_block/op_kernel/common.h b/csrc/transpose_kv_cache_by_block/op_kernel/common.h new file mode 100644 index 00000000..e74089fa --- /dev/null +++ b/csrc/transpose_kv_cache_by_block/op_kernel/common.h @@ -0,0 +1,16 @@ +#include "kernel_operator.h" +using namespace 
AscendC; + +#ifndef __OP_KERNEL_KV_CACHE_TRANSPOSE_H__ +#define __OP_KERNEL_KV_CACHE_TRANSPOSE_H__ + +template +__aicore__ inline __gm__ T* GetTensorAddr(uint16_t index, GM_ADDR tensorPtr) { + __gm__ uint64_t* dataAddr = reinterpret_cast<__gm__ uint64_t*>(tensorPtr); + // The offset of the data address from the first address. + uint64_t tensorPtrOffset = *dataAddr; + // Moving 3 bits to the right means dividing by sizeof(uint64 t). + __gm__ uint64_t* retPtr = dataAddr + (tensorPtrOffset >> 3); + return reinterpret_cast<__gm__ T*>(*(retPtr + index)); +} +#endif \ No newline at end of file diff --git a/csrc/transpose_kv_cache_by_block/op_kernel/full_load.h b/csrc/transpose_kv_cache_by_block/op_kernel/full_load.h new file mode 100644 index 00000000..86951f69 --- /dev/null +++ b/csrc/transpose_kv_cache_by_block/op_kernel/full_load.h @@ -0,0 +1,141 @@ +#include "common.h" + +template +class TransposeKvCacheByBlockKernelFullLoad { + protected: + TQueBind vecInQueue_; + GlobalTensor kCacheGm_; + GlobalTensor vCacheGm_; + GlobalTensor blockIDsGm_; + + GM_ADDR kCachePtr_; + GM_ADDR vCachePtr_; + + // shape info + uint32_t blockNum_; + uint32_t blockSize_; + uint32_t headNum_; + uint32_t headDim_; + uint32_t splitNum_; + uint32_t layerNum_; + // tiling info + uint32_t useCoreNum_; + uint32_t blockPerCore_; + uint32_t tailCoreNum_; + uint32_t calBlockNum_; + + uint32_t srcFactor_; + uint32_t dstFactor_; + uint32_t copyOutLength_; + uint32_t dataBlockSize_; + + __aicore__ inline void CopyIn(GlobalTensor &cacheGm, uint32_t offsetBlock, DataCopyParams &repeatParams) { + LocalTensor cacheLocal = vecInQueue_.AllocTensor(); + for (uint32_t i = 0; i < splitNum_; ++i) { + DataCopy(cacheLocal[i * dstFactor_], cacheGm[i * srcFactor_ + offsetBlock], repeatParams); + } + vecInQueue_.EnQue(cacheLocal); + } + + __aicore__ inline void CopyOut(GlobalTensor &cacheGm, uint32_t offsetBlock) { + LocalTensor cacheLocal = vecInQueue_.DeQue(); + DataCopy(cacheGm[offsetBlock], cacheLocal, copyOutLength_); + vecInQueue_.FreeTensor(cacheLocal); + } + + __aicore__ inline void SetGlobalBuffers(uint32_t layerId) { + kCacheGm_.SetGlobalBuffer(GetTensorAddr(layerId, kCachePtr_)); + vCacheGm_.SetGlobalBuffer(GetTensorAddr(layerId, vCachePtr_)); + } + + __aicore__ inline void Caloffset(uint32_t &startBlock, uint32_t &endBlock, uint32_t &startLayer, uint32_t &endLayer) { + uint32_t blockIdx = GetBlockIdx(); + uint32_t curBlockStart; + uint32_t curBlocknum; + + if (blockIdx < tailCoreNum_) { + curBlockStart = blockIdx * (blockPerCore_ + 1); + curBlocknum = blockPerCore_ + 1; + } else { + curBlockStart = blockIdx * blockPerCore_ + tailCoreNum_; + curBlocknum = blockPerCore_; + } + uint32_t curBlockEnd = curBlockStart + curBlocknum; + startBlock = curBlockStart / layerNum_; + startLayer = curBlockStart % layerNum_; + endBlock = (curBlockEnd + layerNum_ - 1) / layerNum_; + endLayer = curBlockEnd % layerNum_; + if (endLayer == 0) { + endLayer = layerNum_; + } + } + + + public: + __aicore__ inline void Init(GM_ADDR KCache, GM_ADDR VCache, GM_ADDR blockIDs, + TransposeKvCacheByBlockTilingData* tilingData, TPipe* tPipe) { + kCachePtr_ = KCache; + vCachePtr_ = VCache; + blockIDsGm_.SetGlobalBuffer((__gm__ int64_t*)blockIDs); + + // shape info + blockSize_ = tilingData->blockSize; + headNum_ = tilingData->headNum; + headDim_ = tilingData->headDim; + splitNum_ = tilingData->splitNum; + layerNum_ = tilingData->layerNum; + // tiling info + useCoreNum_ = tilingData->useCoreNum; + blockPerCore_ = tilingData->blockPerCore; + tailCoreNum_ = 
tilingData->tailCoreNum; + calBlockNum_ = tilingData->calBlockNum; + + tPipe->InitBuffer(vecInQueue_, 1, TOTAL_UB_SIZE); + srcFactor_ = blockSize_ * headNum_ / splitNum_ * headDim_; + dstFactor_ = headNum_ / splitNum_ * headDim_; + copyOutLength_ = blockSize_ * headNum_ * headDim_; + dataBlockSize_ = static_cast(AscendC::GetDataBlockSizeInBytes()); + } + + __aicore__ inline void Process() { + DataCopyParams repeatParams; + repeatParams.blockCount = blockSize_; + repeatParams.blockLen = headNum_ / splitNum_ * headDim_ * sizeof(T) / dataBlockSize_; + repeatParams.srcStride = 0; + repeatParams.dstStride = (headNum_ * headDim_ - headNum_ / splitNum_ * headDim_) * sizeof(T) / dataBlockSize_; + + uint32_t startBlock; + uint32_t endBlock; + uint32_t startLayer; + uint32_t endLayer; + + Caloffset(startBlock, endBlock, startLayer, endLayer); + for (uint32_t i = startBlock; i < endBlock; ++i) { + int64_t blockId = blockIDsGm_.GetValue(i); + uint32_t offsetBlock = blockId * blockSize_ * headNum_ * headDim_; + uint32_t realStartLayer; + uint32_t realEndLayer; + if (i == startBlock) { + realStartLayer = startLayer; + } else { + realStartLayer = 0; + } + + if (i == (endBlock - 1)) { + realEndLayer = endLayer; + } else { + realEndLayer = layerNum_; + } + for (uint32_t layerId = realStartLayer; layerId < realEndLayer; ++layerId) { + SetGlobalBuffers(layerId); + + CopyIn(kCacheGm_, offsetBlock, repeatParams); + CopyOut(kCacheGm_, offsetBlock); + + CopyIn(vCacheGm_, offsetBlock, repeatParams); + CopyOut(vCacheGm_, offsetBlock); + } + + } + } +}; \ No newline at end of file diff --git a/csrc/transpose_kv_cache_by_block/op_kernel/general.h b/csrc/transpose_kv_cache_by_block/op_kernel/general.h new file mode 100644 index 00000000..1526a4db --- /dev/null +++ b/csrc/transpose_kv_cache_by_block/op_kernel/general.h @@ -0,0 +1,190 @@ +#include "common.h" + +template +class TransposeKvCacheByBlockKernelGeneral { + protected: + TQueBind queBind_; + GlobalTensor kCacheGm_; + GlobalTensor vCacheGm_; + GlobalTensor blockIDsGm_; + + GM_ADDR kCachePtr_; + GM_ADDR vCachePtr_; + + // shape info + uint32_t blockNum_; + uint32_t blockSize_; + uint32_t headNum_; + uint32_t headDim_; + uint32_t splitNum_; + uint32_t layerNum_; + uint32_t headNumSplited_; + uint32_t blockSizeSplitNum_; + + // tiling info + uint32_t useCoreNum_; + uint32_t blockPerCore_; + uint32_t tailCoreNum_; + uint32_t calBlockNum_; + + uint32_t srcFactor_; + uint32_t dstFactor_; + uint32_t copyOutLength_; + + uint32_t blockSizePerTime_; + uint32_t blockSizePerTimeTail_; + + uint32_t blockIdx_; + uint32_t dataBlockSize_; + bool needSync_; + + __aicore__ inline void CopyIn(GlobalTensor &cacheGm, uint32_t offsetBlock, DataCopyParams &repeatParams) { + LocalTensor cacheLocal = queBind_.AllocTensor(); + for (uint32_t i = 0; i < splitNum_; ++i) { + DataCopy(cacheLocal[i * dstFactor_], cacheGm[i * srcFactor_ + offsetBlock], repeatParams); + } + queBind_.EnQue(cacheLocal); + } + + __aicore__ inline void CopyOut(GlobalTensor &cacheGm, uint32_t offsetBlock) { + LocalTensor cacheLocal = queBind_.DeQue(); + AscendC::CrossCoreSetFlag<0x0, PIPE_MTE2>(0x8); + AscendC::CrossCoreWaitFlag(0x8); + DataCopy(cacheGm[offsetBlock], cacheLocal, copyOutLength_); + queBind_.FreeTensor(cacheLocal); + } + + __aicore__ inline void SetGlobalBuffers(uint32_t layerId) { + kCacheGm_.SetGlobalBuffer(GetTensorAddr(layerId, kCachePtr_)); + vCacheGm_.SetGlobalBuffer(GetTensorAddr(layerId, vCachePtr_)); + } + + __aicore__ inline void Caloffset(uint32_t &startBlock, uint32_t &endBlock, uint32_t 
&startLayer, uint32_t &endLayer) { + + uint32_t curBlockStart; + uint32_t curBlocknum; + uint32_t groupBlockIdx = blockIdx_ / blockSizeSplitNum_; + if (groupBlockIdx < tailCoreNum_) { + needSync_ = false; + curBlockStart = groupBlockIdx * (blockPerCore_ + 1); + curBlocknum = blockPerCore_ + 1; + } else { + needSync_ = true; + curBlockStart = groupBlockIdx * blockPerCore_ + tailCoreNum_; + curBlocknum = blockPerCore_; + } + uint32_t curBlockEnd = curBlockStart + curBlocknum; + startBlock = curBlockStart / layerNum_; + startLayer = curBlockStart % layerNum_; + endBlock = (curBlockEnd + layerNum_ - 1) / layerNum_; + endLayer = curBlockEnd % layerNum_; + if (endLayer == 0) { + endLayer = layerNum_; + } + } + + + public: + __aicore__ inline void Init(GM_ADDR KCache, GM_ADDR VCache, GM_ADDR blockIDs, + TransposeKvCacheByBlockTilingData* tilingData, TPipe* tPipe) { + kCachePtr_ = KCache; + vCachePtr_ = VCache; + blockIDsGm_.SetGlobalBuffer((__gm__ int64_t*)blockIDs); + blockIdx_ = GetBlockIdx(); + // shape info + blockSize_ = tilingData->blockSize; + headNum_ = tilingData->headNum; + headDim_ = tilingData->headDim; + splitNum_ = tilingData->splitNum; + layerNum_ = tilingData->layerNum; + // tiling info + useCoreNum_ = tilingData->useCoreNum; + blockPerCore_ = tilingData->blockPerCore; + tailCoreNum_ = tilingData->tailCoreNum; + calBlockNum_ = tilingData->calBlockNum; + blockSizeSplitNum_ = tilingData->blockSizeSplitNum; + blockSizePerTime_ = tilingData->blockSizePerTime; + blockSizePerTimeTail_ = tilingData->blockSizePerTimeTail; + headNumSplited_ = headNum_ / splitNum_; + + tPipe->InitBuffer(queBind_, DB, TOTAL_UB_SIZE / DB); + srcFactor_ = blockSize_ * headNumSplited_ * headDim_; + dstFactor_ = headNumSplited_ * headDim_; + copyOutLength_ = blockSizePerTime_ * headNum_ * headDim_; + dataBlockSize_ = static_cast(AscendC::GetDataBlockSizeInBytes()); + } + + __aicore__ inline void Process() { + DataCopyParams repeatParams; + repeatParams.blockCount = blockSizePerTime_; + repeatParams.blockLen = headNumSplited_ * headDim_ * sizeof(T) / dataBlockSize_; + repeatParams.srcStride = 0; + repeatParams.dstStride = (headNum_ * headDim_ - headNumSplited_ * headDim_) * sizeof(T) / dataBlockSize_; + + uint32_t startBlock; + uint32_t endBlock; + uint32_t startLayer; + uint32_t endLayer; + + Caloffset(startBlock, endBlock, startLayer, endLayer); + + for (uint32_t i = startBlock; i < endBlock; ++i) { + int64_t blockId = blockIDsGm_.GetValue(i); + uint32_t offsetBlock = blockId * blockSize_ * headNum_ * headDim_; + uint32_t realStartLayer; + uint32_t realEndLayer; + if (i == startBlock) { + realStartLayer = startLayer; + } else { + realStartLayer = 0; + } + + if (i == (endBlock - 1)) { + realEndLayer = endLayer; + } else { + realEndLayer = layerNum_; + } + for (uint32_t layerId = realStartLayer; layerId < realEndLayer; ++layerId) { + SetGlobalBuffers(layerId); + uint32_t blockSizeIndex = blockIdx_ % blockSizeSplitNum_; + uint32_t srcOffset; + uint32_t dstOffset; + if constexpr (needHandleUnFactorSplit) { + // handle tail + if (blockSizeIndex >= blockSizePerTimeTail_) { + repeatParams.blockCount = (blockSizePerTime_ - 1); + copyOutLength_ = (blockSizePerTime_ - 1) * headNum_ * headDim_; + srcOffset = (blockSizeIndex * blockSizePerTime_ - (blockSizeIndex - blockSizePerTimeTail_)) * headNumSplited_ * headDim_; + dstOffset = (blockSizeIndex * blockSizePerTime_ - (blockSizeIndex - blockSizePerTimeTail_)) * headNum_ * headDim_; + } else { + repeatParams.blockCount = blockSizePerTime_; + copyOutLength_ = 
blockSizePerTime_ * headNum_ * headDim_; + srcOffset = blockSizeIndex * blockSizePerTime_ * headNumSplited_ * headDim_; + dstOffset = blockSizeIndex * blockSizePerTime_ * headNum_ * headDim_; + } + } else { + repeatParams.blockCount = blockSizePerTime_; + copyOutLength_ = blockSizePerTime_ * headNum_ * headDim_; + srcOffset = blockSizeIndex * blockSizePerTime_ * headNumSplited_ * headDim_; + dstOffset = blockSizeIndex * blockSizePerTime_ * headNum_ * headDim_; + } + + CopyIn(kCacheGm_, offsetBlock + srcOffset, repeatParams); + CopyOut(kCacheGm_, offsetBlock + dstOffset); + + CopyIn(vCacheGm_, offsetBlock + srcOffset, repeatParams); + CopyOut(vCacheGm_, offsetBlock + dstOffset); + } + + } + + if (needSync_) { + AscendC::CrossCoreSetFlag<0x0, PIPE_MTE2>(0x8); + AscendC::CrossCoreWaitFlag(0x8); + + AscendC::CrossCoreSetFlag<0x0, PIPE_MTE2>(0x8); + AscendC::CrossCoreWaitFlag(0x8); + } + + } +}; \ No newline at end of file diff --git a/csrc/transpose_kv_cache_by_block/op_kernel/transpose_kv_cache_by_block.cpp b/csrc/transpose_kv_cache_by_block/op_kernel/transpose_kv_cache_by_block.cpp new file mode 100644 index 00000000..39ced744 --- /dev/null +++ b/csrc/transpose_kv_cache_by_block/op_kernel/transpose_kv_cache_by_block.cpp @@ -0,0 +1,36 @@ +#include "kernel_operator.h" +#include "full_load.h" +#include "general.h" + + +extern "C" __global__ __aicore__ void transpose_kv_cache_by_block(GM_ADDR KCache, GM_ADDR VCache, GM_ADDR blockIDs, GM_ADDR workspace, GM_ADDR tiling) { + GET_TILING_DATA(tiling_data, tiling); + TPipe tPipe; + KERNEL_TASK_TYPE_DEFAULT(KERNEL_TYPE_MIX_AIV_1_0); + if (TILING_KEY_IS(0)) { + // full load not db + TransposeKvCacheByBlockKernelFullLoad kernel; + kernel.Init(KCache, VCache, blockIDs, &tiling_data, &tPipe); + kernel.Process(); + } else if (TILING_KEY_IS(1)) { + // db \ align split blockSize + TransposeKvCacheByBlockKernelGeneral kernel; + kernel.Init(KCache, VCache, blockIDs, &tiling_data, &tPipe); + kernel.Process(); + } else if (TILING_KEY_IS(2)) { + // not db \ align split blockSize + TransposeKvCacheByBlockKernelGeneral kernel; + kernel.Init(KCache, VCache, blockIDs, &tiling_data, &tPipe); + kernel.Process(); + } else if (TILING_KEY_IS(3)) { + // db \ unalign split blockSize + TransposeKvCacheByBlockKernelGeneral kernel; + kernel.Init(KCache, VCache, blockIDs, &tiling_data, &tPipe); + kernel.Process(); + } else if (TILING_KEY_IS(4)) { + // not db \ unalign split blockSize + TransposeKvCacheByBlockKernelGeneral kernel; + kernel.Init(KCache, VCache, blockIDs, &tiling_data, &tPipe); + kernel.Process(); + } +} \ No newline at end of file diff --git a/tests/e2e/nightly/single_node/ops/singlecard_ops/test_transpose_kv_cache_by_block.py b/tests/e2e/nightly/single_node/ops/singlecard_ops/test_transpose_kv_cache_by_block.py new file mode 100644 index 00000000..7527f1fb --- /dev/null +++ b/tests/e2e/nightly/single_node/ops/singlecard_ops/test_transpose_kv_cache_by_block.py @@ -0,0 +1,137 @@ +import random +import unittest + +import torch +import torch_npu + +from vllm_ascend.utils import enable_custom_op + +enable_custom_op() + +torch.set_printoptions(threshold=float("inf")) + +def clone_kv_cache(k_caches, v_caches): + new_k_caches = [cache.clone() for cache in k_caches] + new_v_caches = [cache.clone() for cache in v_caches] + return new_k_caches, new_v_caches + +class TestTransposeKvCacheByBlock(unittest.TestCase): + def compute_golden(self, k_caches, v_caches, block_ids_tensor, block_size, num_kv_head, head_dim, num_need_pulls, layers, dtype): + num_blocks = 
block_ids_tensor.shape[0] + + block_ids_tensor = block_ids_tensor.to(dtype=torch.int32) + block_offsets = torch.arange(0, block_size, dtype=torch.int32).npu() + slot_mapping = block_offsets.reshape( + (1, block_size)) + block_ids_tensor.reshape( + (num_blocks, 1)) * block_size + slot_mapping = slot_mapping.flatten() + block_len = num_blocks * block_size + block_len_tensor = torch.tensor([block_len],dtype=torch.int32).npu() + + block_table = block_ids_tensor.view(1, -1) + seq_start_tensor = torch.tensor([0],dtype=torch.int32).npu() + + k = torch.empty(block_len, num_kv_head, head_dim, dtype=dtype).npu() + v = torch.empty(block_len, num_kv_head, head_dim, dtype=dtype).npu() + + for layer in range(layers): + k_cache_layer = k_caches[layer] + v_cache_layer = v_caches[layer] + + torch_npu.atb.npu_paged_cache_load( + k_cache_layer, + v_cache_layer, + block_table, + block_len_tensor, + seq_starts=seq_start_tensor, + key=k, + value=v, + ) + + k = k.view(num_blocks, num_need_pulls, block_size, -1) + k.transpose_(1, 2) + k = k.contiguous().view(block_len, num_kv_head, -1) + + v = v.view(num_blocks, num_need_pulls, block_size, -1) + v.transpose_(1, 2) + v = v.contiguous().view(block_len, num_kv_head, -1) + + torch_npu._npu_reshape_and_cache( + key=k, + value=v, + key_cache=k_cache_layer, + value_cache=v_cache_layer, + slot_indices=slot_mapping, + ) + del k, v + + def test_transpose_kv_cache_by_block(self): + # (layers, block_num, block_size, num_kv_head, head_dim, num_need_pulls) + test_cases = [ + (16, 128, 128, 4, 128, 4), + (16, 128, 128, 4, 128, 2), + (16, 128, 128, 4, 128, 1), + (16, 128, 128, 8, 128, 8), + (16, 128, 128, 8, 128, 4), + (16, 128, 128, 8, 128, 2), + ] + dtypes = [torch.float16, torch.bfloat16] + for dtype in dtypes: + for layers, block_num, block_size, num_kv_head, head_dim, num_need_pulls in test_cases: + with self.subTest(dtype=dtype, shape=f"({layers}, {block_num}, {block_size}, {num_kv_head}, {head_dim}, {num_need_pulls})"): + k_caches = [] + v_caches = [] + block_id_num = 33 + block_ids_tensor = torch.randperm(block_num, dtype=torch.int64, device="npu")[:block_id_num] + for i in range(layers): + kcache = torch.randn(block_num, block_size, num_kv_head, head_dim, dtype=dtype, device="npu") + vcache = torch.randn(block_num, block_size, num_kv_head, head_dim, dtype=dtype, device="npu") + k_caches.append(kcache) + v_caches.append(vcache) + + cloned_k_caches, cloned_v_caches = clone_kv_cache(k_caches, v_caches) + self.compute_golden(cloned_k_caches, cloned_v_caches, block_ids_tensor, block_size, num_kv_head, head_dim, num_need_pulls, layers, dtype) + torch.ops._C_ascend.transpose_kv_cache_by_block(k_caches, v_caches, block_ids_tensor, block_size, num_kv_head, head_dim, num_need_pulls, layers) + + for i in range (layers): + self.assert_tensors_almost_equal(k_caches[i], cloned_k_caches[i], dtype) + self.assert_tensors_almost_equal(v_caches[i], cloned_v_caches[i], dtype) + + def assert_tensors_almost_equal(self, actual, expected, dtype): + """Check if two tensors are approximately equal (considering floating point errors)""" + self.assertEqual(actual.shape, expected.shape, "Shape mismatch") + + # Check for NaN + self.assertFalse( + torch.isnan(actual).any(), "Actual result contains NaN") + self.assertFalse( + torch.isnan(expected).any(), "Expected result contains NaN") + + # Check for Inf + self.assertFalse( + torch.isinf(actual).any(), "Actual result contains Inf") + self.assertFalse( + torch.isinf(expected).any(), "Expected result contains Inf") + + # Set different tolerances based 
on data type + if dtype == torch.float16: + rtol, atol = 1e-5, 1e-5 + else: # bfloat16 + rtol, atol = 1.5e-5, 1.5e-5 + + # Compare values + diff = torch.abs(actual - expected) + max_diff = diff.max().item() + max_expected = torch.abs(expected).max().item() + + # Check relative and absolute errors + if max_expected > 0: + relative_diff = max_diff / max_expected + self.assertLessEqual( + relative_diff, + rtol, + f"Relative error too large: {relative_diff} > {rtol}. Max difference: {max_diff}", + ) + + self.assertLessEqual(max_diff, atol, + f"Absolute error too large: {max_diff} > {atol}")
diff --git a/vllm_ascend/distributed/kv_transfer/kv_p2p/mooncake_connector.py b/vllm_ascend/distributed/kv_transfer/kv_p2p/mooncake_connector.py index 8fd0d4d2..ca164ad8 100644 --- a/vllm_ascend/distributed/kv_transfer/kv_p2p/mooncake_connector.py +++ b/vllm_ascend/distributed/kv_transfer/kv_p2p/mooncake_connector.py @@ -46,10 +46,11 @@ from vllm.v1.core.sched.output import SchedulerOutput from vllm.v1.kv_cache_interface import KVCacheConfig from vllm.v1.request import RequestStatus +from vllm_ascend import envs as ascend_envs from vllm_ascend.ascend_config import get_ascend_config, init_ascend_config from vllm_ascend.distributed.kv_transfer.utils.mooncake_transfer_engine import global_te from vllm_ascend.distributed.kv_transfer.utils.utils import get_transfer_timeout_value -from vllm_ascend.utils import is_vl_model +from vllm_ascend.utils import enable_custom_op, is_vl_model # isort: off if TYPE_CHECKING: @@ -570,8 +571,39 @@ class KVCacheRecvingThread(threading.Thread): is_kv_transfer_end = global_offset == tp_num_need_pulls * self._prefill_pp_size - 1 need_cat_cache = tp_num_need_pulls > 1 and is_kv_transfer_end need_nz_cache = get_ascend_config().enable_kv_nz and is_kv_transfer_end + use_fused_op = ascend_envs.VLLM_ASCEND_FUSION_OP_TRANSPOSE_KV_CACHE_BY_BLOCK if need_nz_cache or need_cat_cache: - self.reformat_kv_cache(grouped_local_block_ids, tp_num_need_pulls, need_cat_cache, need_nz_cache) + # Use the fused op to reformat the kv cache; the original implementation is kept so the fused path can be disabled. + if use_fused_op and enable_custom_op(): + if need_cat_cache: + # the fused op only supports concatenating GQA/MHA kv caches by head + self.reformat_kv_cache_with_fused_op(grouped_local_block_ids, tp_num_need_pulls) + if need_nz_cache: + # the kv nz reformat may also be moved to a fused op in the future. 
+ self.reformat_kv_cache(grouped_local_block_ids, tp_num_need_pulls, False, need_nz_cache) + else: + self.reformat_kv_cache(grouped_local_block_ids, tp_num_need_pulls, need_cat_cache, need_nz_cache) + + def reformat_kv_cache_with_fused_op(self, block_ids: list[list[int]], tp_num_need_pulls: int): + # Get necessary parameters + k_cache = list(self.kv_caches.values())[0][0] + device = k_cache.device + head_dim = self.model_config.hf_text_config.head_dim + block_size = self.vllm_config.cache_config.block_size + num_kv_head = max(self.model_config.hf_text_config.num_key_value_heads // self.tp_size, 1) + layers = self.model_config.hf_text_config.num_hidden_layers + flat_block_ids = [item for sublist in block_ids for item in sublist] + block_ids_tensor = torch.tensor(flat_block_ids, dtype=torch.int64, device=device) + + k_caches = [] + v_caches = [] + for _, (k_cache_layer, v_cache_layer) in self.kv_caches.items(): + k_caches.append(k_cache_layer) + v_caches.append(v_cache_layer) + + torch.ops._C_ascend.transpose_kv_cache_by_block( + k_caches, v_caches, block_ids_tensor, block_size, num_kv_head, head_dim, tp_num_need_pulls, layers + ) def reformat_kv_cache( self, diff --git a/vllm_ascend/envs.py b/vllm_ascend/envs.py index 3e5c4cc7..6fb90aea 100644 --- a/vllm_ascend/envs.py +++ b/vllm_ascend/envs.py @@ -111,6 +111,10 @@ env_variables: dict[str, Callable[[], Any]] = { "VLLM_ASCEND_ENABLE_FUSED_MC2": lambda: int(os.getenv("VLLM_ASCEND_ENABLE_FUSED_MC2", "0")), # Whether to anbale balance scheduling "VLLM_ASCEND_BALANCE_SCHEDULING": lambda: bool(int(os.getenv("VLLM_ASCEND_BALANCE_SCHEDULING", "0"))), + # use fused op transpose_kv_cache_by_block, default is True + "VLLM_ASCEND_FUSION_OP_TRANSPOSE_KV_CACHE_BY_BLOCK": lambda: bool( + int(os.getenv("VLLM_ASCEND_FUSION_OP_TRANSPOSE_KV_CACHE_BY_BLOCK", "1")) + ), } # end-env-vars-definition
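A minimal usage sketch of the new fused op, distilled from the e2e test above. The shapes, block ids, and variable names here are illustrative only; the call assumes an Ascend NPU environment where the vllm-ascend custom ops have been built (build_aclnn.sh) and registered via enable_custom_op().

```python
# Illustrative sketch; shapes are made up, the real caller is
# KVCacheRecvingThread.reformat_kv_cache_with_fused_op in mooncake_connector.py.
import torch
import torch_npu  # noqa: F401  # registers the "npu" device
from vllm_ascend.utils import enable_custom_op

enable_custom_op()

layers, block_num, block_size, num_kv_head, head_dim = 2, 16, 128, 8, 128
split_num = 4  # number of prefill TP pulls whose head groups must be re-interleaved

# One (block_num, block_size, num_kv_head, head_dim) cache tensor per layer.
k_caches = [torch.randn(block_num, block_size, num_kv_head, head_dim,
                        dtype=torch.float16, device="npu") for _ in range(layers)]
v_caches = [torch.randn(block_num, block_size, num_kv_head, head_dim,
                        dtype=torch.float16, device="npu") for _ in range(layers)]

# Blocks that have just finished receiving KV from the prefill side.
block_ids = torch.tensor([0, 3, 7], dtype=torch.int64, device="npu")

# In place, for every listed block and every layer: regroup the head-contiguous
# segments written by the split_num pulls back into the interleaved head layout.
torch.ops._C_ascend.transpose_kv_cache_by_block(
    k_caches, v_caches, block_ids, block_size, num_kv_head, head_dim, split_num, layers)
```

Setting VLLM_ASCEND_FUSION_OP_TRANSPOSE_KV_CACHE_BY_BLOCK=0 (it defaults to 1) keeps the original reformat_kv_cache path, which is useful for comparison or debugging.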