[Feat] add native kvcache offload (#3433)

### What this PR does / why we need it?
This pr is for https://github.com/vllm-project/vllm-ascend/issues/3241 ,
which is in-house solution for offloading KV cache data from the GPU
memory to other medium (in particular, CPU memory)。Previous solutions
required reliance on third-party components, which had issues with
compatibility between different versions.

### How was this patch tested?
use the following script for testing:

export CUDA_VISIBLE_DEVICES=0
export TP=1
export MODEL_PATH=/model/Qwen3-14B
export MODEL_NAME=Qwen3-14B
export PORT=10000
#export ASCEND_LAUNCH_BLOCKING=1
#export ASCEND_SLOG_PRINT_TO_STDOUT=1

python3 -m vllm.entrypoints.openai.api_server --host 0.0.0.0 --port
${PORT} --dtype bfloat16 --model ${MODEL_PATH} --served-model-name
${MODEL_NAME} --tensor-parallel-size ${TP} --gpu-memory-utilization 0.7
--max-model-len 32768 --trust-remote-code --disable-log-requests \
    --block-size 128 \
--kv-transfer-config
'{"kv_connector":"OffloadingConnector","kv_role":"kv_both","kv_connector_extra_config":{"block_size":
128, "num_cpu_blocks": 1000, "spec_name":"NPUOffloadingSpec",
"spec_module_path": "vllm_ascend.kv_offload.npu"}}'


- vLLM version: v0.11.0rc3
- vLLM main: https://github.com/vllm-project/vllm/commit/v0.11.0

---------

Signed-off-by: HF-001 <1670186653@qq.com>
This commit is contained in:
kx
2025-10-22 14:15:49 +08:00
committed by GitHub
parent 60e2be1b36
commit bc30874f8b
4 changed files with 304 additions and 0 deletions

View File

@@ -17,15 +17,77 @@
#include <torch/extension.h>
#include <torch/library.h>
#include <torch/version.h>
#include <torch/torch.h>
#include <torch_npu/csrc/core/npu/NPUStream.h>
#include <torch_npu/csrc/framework/OpCommand.h>
#include "torch_npu/csrc/core/npu/NPUGuard.h"
#include <torch_npu/csrc/npu/Module.h>
#include "acl/acl.h"
#include "acl/acl_rt.h"
#include "ops.h"
#include "utils.h"
#include "mla_preprocess/op_host/mla_preprocess.h"
#include <c10/core/Device.h>
#include <c10/util/Exception.h>
#include <c10/util/Logging.h>
namespace vllm_ascend {
void swap_blocks_impl(torch::Tensor& src, torch::Tensor& dst,
const torch::Tensor& block_mapping, aclrtStream stream) {
torch::Device src_device = src.device();
torch::Device dst_device = dst.device();
aclrtMemcpyKind memcpy_type;
if ((!src_device.is_cpu()) && (!dst_device.is_cpu())) {
TORCH_CHECK(src_device.index() == dst_device.index(),
"src and dst must be on the same npu");
memcpy_type = ACL_MEMCPY_DEVICE_TO_DEVICE;
} else if ((!src_device.is_cpu()) && dst_device.is_cpu()) {
memcpy_type = ACL_MEMCPY_DEVICE_TO_HOST;
} else if (src_device.is_cpu() && (!dst_device.is_cpu())) {
memcpy_type = ACL_MEMCPY_HOST_TO_DEVICE;
} else {
TORCH_CHECK(false, "Invalid device combination, src tensor device: ", src_device, ", dst tensor device: ", dst_device);
}
TORCH_CHECK(block_mapping.device().is_cpu(), "block_mapping must be on CPU");
char* src_ptr = static_cast<char*>(src.data_ptr());
char* dst_ptr = static_cast<char*>(dst.data_ptr());
const int64_t block_size_in_bytes = src.element_size() * src.stride(0);
const int64_t num_blocks = block_mapping.size(0);
const int64_t max_src_block = src.size(0);
const int64_t max_dst_block = dst.size(0);
for (size_t i = 0; i < num_blocks; i++) {
int64_t src_block_number = block_mapping[i][0].item<int64_t>();
int64_t dst_block_number = block_mapping[i][1].item<int64_t>();
TORCH_CHECK(src_block_number >= 0 && src_block_number <= max_src_block,
"src block index ", src_block_number, " out of range (max: ", max_src_block, ")");
TORCH_CHECK(dst_block_number >= 0 && dst_block_number <= max_dst_block,
"dst block index ", dst_block_number, " out of range (max: ", max_dst_block, ")");
int64_t src_offset = src_block_number * block_size_in_bytes;
int64_t dst_offset = dst_block_number * block_size_in_bytes;
aclrtMemcpyAsync(dst_ptr + dst_offset, block_size_in_bytes,
src_ptr + src_offset, block_size_in_bytes,
memcpy_type, stream);
}
}
void swap_blocks(torch::Tensor &x, torch::Tensor &y, const torch::Tensor &z)
{
const c10_npu::OptionalNPUGuard npuGuard(
(!x.device().is_cpu()) ? x.device() : y.device()
);
aclrtStream stream = c10_npu::getCurrentNPUStream().stream();
swap_blocks_impl(x, y, z, stream);
return;
}
AscendType get_dtype_from_torch(at::ScalarType scalarType)
{
@@ -511,4 +573,7 @@ TORCH_LIBRARY_EXPAND(CONCAT(_C, _ascend), ops)
" Tensor q_out1, Tensor kv_cache_out1)"
);
ops.impl("mla_preprocess", torch::kPrivateUse1, &vllm_ascend::mla_preprocess);
ops.def("swap_blocks(Tensor! x, Tensor! y, Tensor z) -> ()");
ops.impl("swap_blocks", torch::kPrivateUse1, &vllm_ascend::swap_blocks);
}