From b43263307f40a206f1371e4064d410a136d4e004 Mon Sep 17 00:00:00 2001 From: Zhiqiang Xie Date: Wed, 23 Jul 2025 01:49:03 -0700 Subject: [PATCH] Hicache IO kernel refactoring (#8264) --- sgl-kernel/csrc/common_extension.cc | 37 +- sgl-kernel/csrc/kvcacheio/transfer.cu | 437 ++++++++++++++-------- sgl-kernel/include/sgl_kernel_ops.h | 81 ++-- sgl-kernel/python/sgl_kernel/kvcacheio.py | 160 ++++++-- sgl-kernel/tests/test_kvcacheio.py | 110 +++--- 5 files changed, 545 insertions(+), 280 deletions(-) diff --git a/sgl-kernel/csrc/common_extension.cc b/sgl-kernel/csrc/common_extension.cc index 070fe4bd2..20b9a8048 100644 --- a/sgl-kernel/csrc/common_extension.cc +++ b/sgl-kernel/csrc/common_extension.cc @@ -249,34 +249,39 @@ TORCH_LIBRARY_FRAGMENT(sgl_kernel, m) { "dst_indices, int item_size, int block_quota, int num_warps_per_block) -> ()"); m.impl("transfer_kv_per_layer", torch::kCUDA, &transfer_kv_per_layer); m.def( - "transfer_kv_per_layer_direct(Tensor src_k, Tensor dst_k, Tensor src_v, Tensor dst_v, Tensor src_indices, Tensor " - "dst_indices, int page_size) -> ()"); - m.impl("transfer_kv_per_layer_direct", torch::kCUDA, &transfer_kv_per_layer_direct); + "transfer_kv_per_layer_pf_lf(Tensor src_k, Tensor dst_k, Tensor src_v, Tensor dst_v, Tensor src_indices, Tensor " + "dst_indices, int item_size, int src_layout_dim, int block_quota, int num_warps_per_block) -> ()"); + m.impl("transfer_kv_per_layer_pf_lf", torch::kCUDA, &transfer_kv_per_layer_pf_lf); m.def( - "transfer_kv_all_layer(Tensor src_k, Tensor dst_k, Tensor src_v, Tensor dst_v, Tensor src_indices, Tensor " - "dst_indices, int item_size, int num_layers, int src_layer_offset, int dst_layer_offset, int block_quota, int " + "transfer_kv_all_layer(Tensor src_k_layers, Tensor dst_k_layers, Tensor src_v_layers, Tensor dst_v_layers, " + "Tensor src_indices, Tensor dst_indices, int item_size, int num_layers, int block_quota, int " "num_warps_per_block) -> ()"); m.impl("transfer_kv_all_layer", torch::kCUDA, &transfer_kv_all_layer); m.def( - "transfer_kv_all_layer_direct(Tensor src_k, Tensor dst_k, Tensor src_v, Tensor dst_v, Tensor src_indices, Tensor " - "dst_indices, int page_size, int num_layers) -> ()"); - m.impl("transfer_kv_all_layer_direct", torch::kCUDA, &transfer_kv_all_layer_direct); + "transfer_kv_all_layer_lf_pf(Tensor src_k_layers, Tensor dst_k, Tensor src_v_layers, Tensor dst_v, " + "Tensor src_indices, Tensor dst_indices, int item_size, int dst_layout_dim, int num_layers, int block_quota, int " + "num_warps_per_block) -> ()"); + m.impl("transfer_kv_all_layer_lf_pf", torch::kCUDA, &transfer_kv_all_layer_lf_pf); m.def( "transfer_kv_per_layer_mla(Tensor src, Tensor dst, Tensor src_indices, Tensor dst_indices, int item_size, int " "block_quota, int num_warps_per_block) -> ()"); m.impl("transfer_kv_per_layer_mla", torch::kCUDA, &transfer_kv_per_layer_mla); m.def( - "transfer_kv_per_layer_mla_direct(Tensor src, Tensor dst, Tensor src_indices, Tensor dst_indices, int page_size) " - "-> ()"); - m.impl("transfer_kv_per_layer_mla_direct", torch::kCUDA, &transfer_kv_per_layer_mla_direct); + "transfer_kv_per_layer_mla_pf_lf(Tensor src, Tensor dst, Tensor src_indices, Tensor dst_indices, int item_size, " + "int src_layout_dim, int block_quota, int num_warps_per_block) -> ()"); + m.impl("transfer_kv_per_layer_mla_pf_lf", torch::kCUDA, &transfer_kv_per_layer_mla_pf_lf); m.def( - "transfer_kv_all_layer_mla(Tensor src, Tensor dst, Tensor src_indices, Tensor dst_indices, int item_size, int " - "num_layers, int src_layer_offset, int 
dst_layer_offset, int block_quota, int num_warps_per_block) -> ()");
+      "transfer_kv_all_layer_mla(Tensor src_layers, Tensor dst_layers, Tensor src_indices, Tensor dst_indices, int "
+      "item_size, int num_layers, int block_quota, int num_warps_per_block) -> ()");
   m.impl("transfer_kv_all_layer_mla", torch::kCUDA, &transfer_kv_all_layer_mla);
   m.def(
-      "transfer_kv_all_layer_mla_direct(Tensor src, Tensor dst, Tensor src_indices, Tensor dst_indices, int page_size, "
-      "int num_layers) -> ()");
-  m.impl("transfer_kv_all_layer_mla_direct", torch::kCUDA, &transfer_kv_all_layer_mla_direct);
+      "transfer_kv_all_layer_mla_lf_pf(Tensor src_layers, Tensor dst, Tensor src_indices, Tensor dst_indices, "
+      "int item_size, int dst_layout_dim, int num_layers, int block_quota, int num_warps_per_block) -> ()");
+  m.impl("transfer_kv_all_layer_mla_lf_pf", torch::kCUDA, &transfer_kv_all_layer_mla_lf_pf);
+  m.def(
+      "transfer_kv_direct(Tensor[] src_layers, Tensor[] dst_layers, Tensor src_indices, Tensor dst_indices, int "
+      "page_size) -> ()");
+  m.impl("transfer_kv_direct", torch::kCUDA, &transfer_kv_direct);
 
   /*
    * From csrc/moe/cutlass_moe/w4a8
diff --git a/sgl-kernel/csrc/kvcacheio/transfer.cu b/sgl-kernel/csrc/kvcacheio/transfer.cu
index 6c939dd55..cc6942e67 100644
--- a/sgl-kernel/csrc/kvcacheio/transfer.cu
+++ b/sgl-kernel/csrc/kvcacheio/transfer.cu
@@ -22,17 +22,40 @@ transfer_item_warp(int32_t lane_id, const void* src_addr, void* dst_addr, int64_
   }
 }
 
-// todo, structs for different memory layout
-__device__ __forceinline__ int64_t
-get_global_offset_lf(int64_t layer_id, int64_t layer_dim, int64_t page_id, int64_t item_size_bytes) {
+template <typename T>
+__device__ __forceinline__ T* get_global_offset_lf(
+    T* base,
+    const uintptr_t* __restrict__ /*unused*/,
+    int64_t layer_id,
+    int64_t layer_dim,
+    int64_t page_id,
+    int64_t item_size_bytes) {
   // layer first
-  return layer_id * layer_dim + page_id * item_size_bytes;
+  return base + layer_id * layer_dim + page_id * item_size_bytes;
 }
 
-__device__ __forceinline__ int64_t
-get_global_offset_pf(int64_t layer_id, int64_t page_dim, int64_t page_id, int64_t item_size_bytes) {
+template <typename T>
+__device__ __forceinline__ T* get_global_offset_pf(
+    T* base,
+    const uintptr_t* __restrict__ /*unused*/,
+    int64_t layer_id,
+    int64_t page_dim,
+    int64_t page_id,
+    int64_t item_size_bytes) {
   // page first
-  return page_id * page_dim + layer_id * item_size_bytes;
+  return base + page_id * page_dim + layer_id * item_size_bytes;
+}
+
+// get offset from layer base table when layers are not contiguous
+template <typename T>
+__device__ __forceinline__ T* get_global_offset_lf_tbl(
+    T* /*unused*/,
+    const uintptr_t* __restrict__ layer_base_tbl,
+    int64_t layer_id,
+    int64_t /*unused*/,
+    int64_t page_id,
+    int64_t item_size_bytes) {
+  return reinterpret_cast<T*>(layer_base_tbl[layer_id]) + page_id * item_size_bytes;
 }
 
 template <auto SrcOffsetFn, auto DstOffsetFn, bool IsMLA>
@@ -49,42 +72,37 @@ __global__ void transfer_kernel_impl(
     int64_t items_per_warp,
     int64_t item_size_bytes,
     int64_t src_layout_dim,
-    int64_t dst_layout_dim) {
+    int64_t dst_layout_dim,
+    const uintptr_t* __restrict__ src_k_layer_tbl,
+    const uintptr_t* __restrict__ dst_k_layer_tbl,
+    const uintptr_t* __restrict__ src_v_layer_tbl,
+    const uintptr_t* __restrict__ dst_v_layer_tbl) {
   int32_t tid = blockIdx.x * blockDim.x + threadIdx.x;
   int32_t lane_id = tid % 32;
   int32_t warp_id = tid / 32;
 
   for (int i = 0; i < items_per_warp; ++i) {
-    int32_t item_id = warp_id * items_per_warp + i;
+    int64_t item_id = warp_id * items_per_warp + i;
     if (item_id >= num_items) {
-      return;
+      break;
     }
     const int64_t src_page_id = src_indices[item_id];
     const int64_t dst_page_id = dst_indices[item_id];
     // Loop over layers if necessary
     for (int64_t layer_id = start_layer_id; layer_id < start_layer_id + num_layers_to_process; ++layer_id) {
-      // Calculate offsets using the provided function pointers
-      const int64_t src_offset = SrcOffsetFn(layer_id, src_layout_dim, src_page_id, item_size_bytes);
-      const int64_t dst_offset = DstOffsetFn(layer_id, dst_layout_dim, dst_page_id, item_size_bytes);
+      const char* src_ptr = SrcOffsetFn(
+          static_cast<const char*>(src_k), src_k_layer_tbl, layer_id, src_layout_dim, src_page_id, item_size_bytes);
+      char* dst_ptr = DstOffsetFn(
+          static_cast<char*>(dst_k), dst_k_layer_tbl, layer_id, dst_layout_dim, dst_page_id, item_size_bytes);
+      transfer_item_warp(lane_id, src_ptr, dst_ptr, item_size_bytes);
 
-      if constexpr (IsMLA) {
-        transfer_item_warp(
-            lane_id,
-            static_cast<const char*>(src_k) + src_offset,
-            static_cast<char*>(dst_k) + dst_offset,
-            item_size_bytes);
-      } else {
-        transfer_item_warp(
-            lane_id,
-            static_cast<const char*>(src_k) + src_offset,
-            static_cast<char*>(dst_k) + dst_offset,
-            item_size_bytes);
-        transfer_item_warp(
-            lane_id,
-            static_cast<const char*>(src_v) + src_offset,
-            static_cast<char*>(dst_v) + dst_offset,
-            item_size_bytes);
+      if constexpr (!IsMLA) {
+        const char* src_v_ptr = SrcOffsetFn(
+            static_cast<const char*>(src_v), src_v_layer_tbl, layer_id, src_layout_dim, src_page_id, item_size_bytes);
+        char* dst_v_ptr = DstOffsetFn(
+            static_cast<char*>(dst_v), dst_v_layer_tbl, layer_id, dst_layout_dim, dst_page_id, item_size_bytes);
+        transfer_item_warp(lane_id, src_v_ptr, dst_v_ptr, item_size_bytes);
       }
     }
   }
@@ -103,44 +121,54 @@ void transfer_kv_launcher(
     int64_t item_size,
     int64_t src_layout_dim,
     int64_t dst_layout_dim,
+    const at::Tensor& src_k_layers,
+    const at::Tensor& dst_k_layers,
+    const at::Tensor& src_v_layers,
+    const at::Tensor& dst_v_layers,
     int64_t block_quota,
    int64_t num_warps_per_block) {
-  TORCH_CHECK(src_k.scalar_type() == dst_k.scalar_type(), "Source and destination keys must have the same type");
   TORCH_CHECK(src_indices.is_cuda(), "Source indices must be a CUDA tensor");
   TORCH_CHECK(dst_indices.is_cuda(), "Destination indices must be a CUDA tensor");
   TORCH_CHECK(src_indices.scalar_type() == at::kLong, "Source indices must be of type long");
   TORCH_CHECK(dst_indices.scalar_type() == at::kLong, "Destination indices must be of type long");
   TORCH_CHECK(src_indices.numel() == dst_indices.numel(), "Source and destination indices must have the same length");
+  TORCH_CHECK(item_size % 8 == 0, "Item byte size must be divisible by 8");
 
-  if (!IsMLA) {
-    TORCH_CHECK(src_v.scalar_type() == dst_v.scalar_type(), "Source and destination values must have the same type");
-  }
-
-  int dtype_size = src_k.element_size();
-  TORCH_CHECK((item_size * dtype_size) % 8 == 0, "Item byte size must be divisible by 8");
-
-  auto div_up = [](int32_t x, int32_t y) { return (x + y - 1) / y; };
+  auto div_up = [](int64_t x, int64_t y) { return (x + y - 1) / y; };
   const int64_t num_items = src_indices.numel();
   const int64_t items_per_warp = div_up(num_items, block_quota * num_warps_per_block);
   const int32_t num_blocks = div_up(num_items, items_per_warp * num_warps_per_block);
   dim3 grid_dim(num_blocks, 1, 1);
   const int32_t threads_per_block = num_warps_per_block * 32;
 
+  const void* src_k_ptr = src_k.defined() ? src_k.data_ptr() : nullptr;
+  void* dst_k_ptr = dst_k.defined() ? dst_k.data_ptr() : nullptr;
+  const void* src_v_ptr = IsMLA || !src_v.defined() ? nullptr : src_v.data_ptr();
+  void* dst_v_ptr = IsMLA || !dst_v.defined() ? nullptr : dst_v.data_ptr();
+  const uintptr_t* src_k_tbl_ptr = src_k_layers.defined() ? src_k_layers.data_ptr() : nullptr;
+  const uintptr_t* dst_k_tbl_ptr = dst_k_layers.defined() ? dst_k_layers.data_ptr() : nullptr;
+  const uintptr_t* src_v_tbl_ptr = IsMLA || !src_v_layers.defined() ? nullptr : src_v_layers.data_ptr();
+  const uintptr_t* dst_v_tbl_ptr = IsMLA || !dst_v_layers.defined() ? nullptr : dst_v_layers.data_ptr();
+
   cudaStream_t torch_current_stream = at::cuda::getCurrentCUDAStream();
   transfer_kernel_impl<SrcOffsetFn, DstOffsetFn, IsMLA><<<grid_dim, threads_per_block, 0, torch_current_stream>>>(
-      src_k.data_ptr(),
-      dst_k.data_ptr(),
-      (IsMLA ? nullptr : src_v.data_ptr()),
-      (IsMLA ? nullptr : dst_v.data_ptr()),
+      src_k_ptr,
+      dst_k_ptr,
+      src_v_ptr,
+      dst_v_ptr,
       src_indices.data_ptr<int64_t>(),
       dst_indices.data_ptr<int64_t>(),
       start_layer_id,
       num_layers_to_process,
       num_items,
       items_per_warp,
-      item_size * dtype_size,
-      src_layout_dim * dtype_size,
-      dst_layout_dim * dtype_size);
+      item_size,
+      src_layout_dim,
+      dst_layout_dim,
+      src_k_tbl_ptr,
+      dst_k_tbl_ptr,
+      src_v_tbl_ptr,
+      dst_v_tbl_ptr);
   C10_CUDA_KERNEL_LAUNCH_CHECK();
 }
@@ -154,24 +182,8 @@ void transfer_kv_per_layer(
     int64_t item_size,
     int64_t block_quota,
     int64_t num_warps_per_block) {
-  transfer_kv_launcher<get_global_offset_lf, get_global_offset_lf, false>(
-      src_k, dst_k, src_v, dst_v, src_indices, dst_indices, 0, 1, item_size, 0, 0, block_quota, num_warps_per_block);
-}
-
-void transfer_kv_all_layer(
-    const at::Tensor src_k,
-    at::Tensor dst_k,
-    const at::Tensor src_v,
-    at::Tensor dst_v,
-    const at::Tensor src_indices,
-    const at::Tensor dst_indices,
-    int64_t item_size,
-    int64_t num_layers,
-    int64_t src_layer_offset,
-    int64_t dst_layer_offset,
-    int64_t block_quota,
-    int64_t num_warps_per_block) {
-  transfer_kv_launcher<get_global_offset_lf, get_global_offset_lf, false>(
+  at::Tensor empty;
+  transfer_kv_launcher<get_global_offset_lf<const char>, get_global_offset_lf<char>, false>(
       src_k,
       dst_k,
       src_v,
@@ -179,10 +191,113 @@ void transfer_kv_all_layer(
       src_indices,
       dst_indices,
       0,
+      1,
+      item_size,
+      0,
+      0,
+      empty,
+      empty,
+      empty,
+      empty,
+      block_quota,
+      num_warps_per_block);
+}
+
+void transfer_kv_per_layer_pf_lf(
+    const at::Tensor src_k,
+    at::Tensor dst_k,
+    const at::Tensor src_v,
+    at::Tensor dst_v,
+    const at::Tensor src_indices,
+    const at::Tensor dst_indices,
+    int64_t item_size,
+    int64_t src_layout_dim,
+    int64_t block_quota,
+    int64_t num_warps_per_block) {
+  at::Tensor empty;
+  transfer_kv_launcher<get_global_offset_pf<const char>, get_global_offset_lf<char>, false>(
+      src_k,
+      dst_k,
+      src_v,
+      dst_v,
+      src_indices,
+      dst_indices,
+      0,
+      1,
+      item_size,
+      src_layout_dim,
+      0,
+      empty,
+      empty,
+      empty,
+      empty,
+      block_quota,
+      num_warps_per_block);
+}
+
+void transfer_kv_all_layer(
+    const at::Tensor src_k_layers,
+    const at::Tensor dst_k_layers,
+    const at::Tensor src_v_layers,
+    const at::Tensor dst_v_layers,
+    const at::Tensor src_indices,
+    const at::Tensor dst_indices,
+    int64_t item_size,
+    int64_t num_layers,
+    int64_t block_quota,
+    int64_t num_warps_per_block) {
+  TORCH_CHECK(num_layers == src_k_layers.size(0), "Number of layers in source k tensor does not match num_layers");
+  at::Tensor empty;
+  transfer_kv_launcher<get_global_offset_lf_tbl<const char>, get_global_offset_lf_tbl<char>, false>(
+      empty,
+      empty,
+      empty,
+      empty,
+      src_indices,
+      dst_indices,
+      0,
       num_layers,
       item_size,
-      src_layer_offset,
-      dst_layer_offset,
+      0,
+      0,
+      src_k_layers,
+      dst_k_layers,
+      src_v_layers,
+      dst_v_layers,
+      block_quota,
+      num_warps_per_block);
+}
+
+void transfer_kv_all_layer_lf_pf(
+    const at::Tensor src_k_layers,
+    at::Tensor dst_k,
+    const at::Tensor src_v_layers,
+    at::Tensor dst_v,
+    const at::Tensor src_indices,
+    const at::Tensor dst_indices,
+    int64_t item_size,
+    int64_t dst_layout_dim,
+    int64_t num_layers,
+    int64_t block_quota,
+    int64_t num_warps_per_block) {
+  TORCH_CHECK(num_layers == src_k_layers.size(0), "Number of layers in source k tensor does not match num_layers");
+  at::Tensor empty;
+  transfer_kv_launcher<get_global_offset_lf_tbl<const char>, get_global_offset_pf<char>, false>(
+      empty,
+      dst_k,
+      empty,
+      dst_v,
+      src_indices,
+      dst_indices,
+      0,
+      num_layers,
+      item_size,
+      0,
+      dst_layout_dim,
+      src_k_layers,
+      empty,
+      src_v_layers,
+      empty,
       block_quota,
       num_warps_per_block);
 }
@@ -195,12 +310,12 @@ void transfer_kv_per_layer_mla(
     int64_t item_size,
     int64_t block_quota,
     int64_t num_warps_per_block) {
-  at::Tensor empty_tensor = at::Tensor();
-  transfer_kv_launcher<get_global_offset_lf, get_global_offset_lf, true>(
+  at::Tensor empty;
+  transfer_kv_launcher<get_global_offset_lf<const char>, get_global_offset_lf<char>, true>(
       src,
       dst,
-      empty_tensor,
-      empty_tensor,
+      empty,
+      empty,
       src_indices,
       dst_indices,
       0,
@@ -208,41 +323,110 @@
       item_size,
       0,
       0,
+      empty,
+      empty,
+      empty,
+      empty,
       block_quota,
       num_warps_per_block);
 }
 
-void transfer_kv_all_layer_mla(
+void transfer_kv_per_layer_mla_pf_lf(
     const at::Tensor src,
     at::Tensor dst,
     const at::Tensor src_indices,
     const at::Tensor dst_indices,
     int64_t item_size,
-    int64_t num_layers,
-    int64_t src_layer_offset,
-    int64_t dst_layer_offset,
+    int64_t src_layout_dim,
     int64_t block_quota,
     int64_t num_warps_per_block) {
-  at::Tensor empty_tensor = at::Tensor();
-  transfer_kv_launcher<get_global_offset_lf, get_global_offset_lf, true>(
+  at::Tensor empty;
+  transfer_kv_launcher<get_global_offset_pf<const char>, get_global_offset_lf<char>, true>(
       src,
       dst,
-      empty_tensor,
-      empty_tensor,
+      empty,
+      empty,
+      src_indices,
+      dst_indices,
+      0,
+      1,
+      item_size,
+      src_layout_dim,
+      0,
+      empty,
+      empty,
+      empty,
+      empty,
+      block_quota,
+      num_warps_per_block);
+}
+
+void transfer_kv_all_layer_mla(
+    const at::Tensor src_layers,
+    const at::Tensor dst_layers,
+    const at::Tensor src_indices,
+    const at::Tensor dst_indices,
+    int64_t item_size,
+    int64_t num_layers,
+    int64_t block_quota,
+    int64_t num_warps_per_block) {
+  TORCH_CHECK(num_layers == src_layers.size(0), "Number of layers in source tensor does not match num_layers");
+  at::Tensor empty;
+  transfer_kv_launcher<get_global_offset_lf_tbl<const char>, get_global_offset_lf_tbl<char>, true>(
+      empty,
+      empty,
+      empty,
+      empty,
       src_indices,
       dst_indices,
       0,
       num_layers,
       item_size,
-      src_layer_offset,
-      dst_layer_offset,
+      0,
+      0,
+      src_layers,
+      dst_layers,
+      empty,
+      empty,
+      block_quota,
+      num_warps_per_block);
+}
+
+void transfer_kv_all_layer_mla_lf_pf(
+    const at::Tensor src_layers,
+    at::Tensor dst,
+    const at::Tensor src_indices,
+    const at::Tensor dst_indices,
+    int64_t item_size,
+    int64_t dst_layout_dim,
+    int64_t num_layers,
+    int64_t block_quota,
+    int64_t num_warps_per_block) {
+  TORCH_CHECK(num_layers == src_layers.size(0), "Number of layers in source tensor does not match num_layers");
+  at::Tensor empty;
+  transfer_kv_launcher<get_global_offset_lf_tbl<const char>, get_global_offset_pf<char>, true>(
+      empty,
+      dst,
+      empty,
+      empty,
+      src_indices,
+      dst_indices,
+      0,
+      num_layers,
+      item_size,
+      0,
+      dst_layout_dim,
+      src_layers,
+      empty,
+      empty,
+      empty,
       block_quota,
       num_warps_per_block);
 }
 
 inline void transfer_page_direct(
-    const at::Tensor src_buffer,
-    at::Tensor dst_buffer,
+    const at::Tensor& src_buffer,
+    at::Tensor& dst_buffer,
     int64_t src_page_index,
     int64_t dst_page_index,
     int64_t page_size) {
@@ -252,16 +436,14 @@ inline void transfer_page_direct(
       /* non_blocking= */ true);
 }
 
-template <bool AllLayers, bool IsMLA>
-inline void transfer_kv_direct_impl(
-    const at::Tensor& src_k,
-    at::Tensor& dst_k,
-    const at::Tensor& src_v_opt,  // Only used when IsMLA is false (for src_v)
-    at::Tensor& dst_v_opt,        // Only used when IsMLA is false (for dst_v)
-    const at::Tensor& src_indices,
-    const at::Tensor& dst_indices,
-    int64_t page_size,
-    int64_t num_layers = 1) {
+void transfer_kv_direct(
+    const std::vector<at::Tensor>& src_layers,
+    std::vector<at::Tensor> dst_layers,
+    const at::Tensor src_indices,
+    const at::Tensor dst_indices,
+    int64_t page_size) {
+  TORCH_CHECK(
+      src_layers.size() == dst_layers.size(), "Source and destination layers must have the same number of layers");
   TORCH_CHECK(src_indices.numel() == dst_indices.numel(), "Source and destination indices must have the same length");
   TORCH_CHECK(page_size > 0, "Page size must be positive");
   TORCH_CHECK(src_indices.numel() % page_size == 0, "Source indices size must be divisible by page size");
@@ -270,73 +452,14 @@ inline void transfer_kv_direct_impl(
   auto dst_indices_cpu = dst_indices.cpu();
 
   const int64_t num_pages = src_indices_cpu.size(0) / page_size;
+  const int64_t num_layers = src_layers.size();
 
-  for (const auto i : c10::irange(num_pages)) {
-    auto s_index = src_indices_cpu[i * page_size].item<int64_t>();
-    auto d_index = dst_indices_cpu[i * page_size].item<int64_t>();
+  for (int64_t i = 0; i < num_pages; ++i) {
+    auto src_index = src_indices_cpu[i * page_size].item<int64_t>();
+    auto dst_index = dst_indices_cpu[i * page_size].item<int64_t>();
 
-    if constexpr (AllLayers) {
-      for (const auto j : c10::irange(num_layers)) {
-        if constexpr (IsMLA) {
-          transfer_page_direct(src_k.select(0, j), dst_k.select(0, j), s_index, d_index, page_size);
-        } else {
-          transfer_page_direct(src_k.select(0, j), dst_k.select(0, j), s_index, d_index, page_size);
-          transfer_page_direct(src_v_opt.select(0, j), dst_v_opt.select(0, j), s_index, d_index, page_size);
-        }
-      }
-    } else {  // Per-layer
-      if constexpr (IsMLA) {
-        transfer_page_direct(src_k, dst_k, s_index, d_index, page_size);
-      } else {
-        transfer_page_direct(src_k, dst_k, s_index, d_index, page_size);
-        transfer_page_direct(src_v_opt, dst_v_opt, s_index, d_index, page_size);
-      }
+    for (int64_t j = 0; j < num_layers; ++j) {
+      transfer_page_direct(src_layers[j], dst_layers[j], src_index, dst_index, page_size);
     }
   }
 }
-
-void transfer_kv_per_layer_direct(
-    const at::Tensor src_k,
-    at::Tensor dst_k,
-    const at::Tensor src_v,
-    at::Tensor dst_v,
-    const at::Tensor src_indices,
-    const at::Tensor dst_indices,
-    int64_t page_size) {
-  transfer_kv_direct_impl<false, false>(src_k, dst_k, src_v, dst_v, src_indices, dst_indices, page_size);
-}
-
-void transfer_kv_all_layer_direct(
-    const at::Tensor src_k,
-    at::Tensor dst_k,
-    const at::Tensor src_v,
-    at::Tensor dst_v,
-    const at::Tensor src_indices,
-    const at::Tensor dst_indices,
-    int64_t page_size,
-    int64_t num_layers) {
-  transfer_kv_direct_impl<true, false>(src_k, dst_k, src_v, dst_v, src_indices, dst_indices, page_size, num_layers);
-}
-
-void transfer_kv_per_layer_mla_direct(
-    const at::Tensor src,
-    at::Tensor dst,
-    const at::Tensor src_indices,
-    const at::Tensor dst_indices,
-    int64_t page_size) {
-  at::Tensor empty_tensor = at::Tensor();
-
-  transfer_kv_direct_impl<false, true>(src, dst, empty_tensor, empty_tensor, src_indices, dst_indices, page_size);
-}
-
-void transfer_kv_all_layer_mla_direct(
-    const at::Tensor src,
-    at::Tensor dst,
-    const at::Tensor src_indices,
-    const at::Tensor dst_indices,
-    int64_t page_size,
-    int64_t num_layers) {
-  at::Tensor empty_tensor = at::Tensor();
-  transfer_kv_direct_impl<true, true>(
-      src, dst, empty_tensor, empty_tensor, src_indices, dst_indices, page_size, num_layers);
-}
diff --git a/sgl-kernel/include/sgl_kernel_ops.h b/sgl-kernel/include/sgl_kernel_ops.h
index df06bd3cd..6b589101f 100644
--- a/sgl-kernel/include/sgl_kernel_ops.h
+++ b/sgl-kernel/include/sgl_kernel_ops.h
@@ -399,16 +399,7 @@ void transfer_kv_per_layer(
     int64_t block_quota,
     int64_t num_warps_per_block);
 
-void transfer_kv_per_layer_direct(
-    const at::Tensor src_k,
-    at::Tensor dst_k,
-    const at::Tensor src_v,
-    at::Tensor dst_v,
-    const at::Tensor src_indices,
-    const at::Tensor dst_indices,
-    int64_t page_size);
-
-void transfer_kv_all_layer(
+void transfer_kv_per_layer_pf_lf(
     const at::Tensor src_k,
     at::Tensor dst_k,
     const at::Tensor src_v,
@@ -416,21 +407,34 @@ void transfer_kv_all_layer(
     const at::Tensor src_indices,
     const at::Tensor dst_indices,
     int64_t item_size,
-    int64_t num_layers,
-    int64_t src_layer_offset,
-    int64_t dst_layer_offset,
+    int64_t src_layout_dim,
     int64_t block_quota,
     int64_t num_warps_per_block);
 
-void transfer_kv_all_layer_direct(
-    const at::Tensor src_k,
+void transfer_kv_all_layer(
+    const at::Tensor src_k_layers,
+    const at::Tensor dst_k_layers,
+    const at::Tensor src_v_layers,
+    const at::Tensor dst_v_layers,
+    const at::Tensor src_indices,
+    const at::Tensor dst_indices,
+    int64_t item_size,
+    int64_t num_layers,
+    int64_t block_quota,
+    int64_t num_warps_per_block);
+
+void transfer_kv_all_layer_lf_pf(
+    const at::Tensor src_k_layers,
     at::Tensor dst_k,
-    const at::Tensor src_v,
+    const at::Tensor src_v_layers,
     at::Tensor dst_v,
     const at::Tensor src_indices,
     const at::Tensor dst_indices,
-    int64_t page_size,
-    int64_t num_layers);
+    int64_t item_size,
+    int64_t dst_layout_dim,
+    int64_t num_layers,
+    int64_t block_quota,
+    int64_t num_warps_per_block);
 
 void transfer_kv_per_layer_mla(
     const at::Tensor src,
@@ -441,32 +445,43 @@ void transfer_kv_per_layer_mla(
     int64_t block_quota,
     int64_t num_warps_per_block);
 
-void transfer_kv_per_layer_mla_direct(
-    const at::Tensor src,
-    at::Tensor dst,
-    const at::Tensor src_indices,
-    const at::Tensor dst_indices,
-    int64_t page_size);
-
-void transfer_kv_all_layer_mla(
+void transfer_kv_per_layer_mla_pf_lf(
     const at::Tensor src,
     at::Tensor dst,
     const at::Tensor src_indices,
     const at::Tensor dst_indices,
     int64_t item_size,
-    int64_t num_layers,
-    int64_t src_layer_offset,
-    int64_t dst_layer_offset,
+    int64_t src_layout_dim,
     int64_t block_quota,
     int64_t num_warps_per_block);
 
-void transfer_kv_all_layer_mla_direct(
-    const at::Tensor src,
+void transfer_kv_all_layer_mla(
+    const at::Tensor src_layers,
+    const at::Tensor dst_layers,
+    const at::Tensor src_indices,
+    const at::Tensor dst_indices,
+    int64_t item_size,
+    int64_t num_layers,
+    int64_t block_quota,
+    int64_t num_warps_per_block);
+
+void transfer_kv_all_layer_mla_lf_pf(
+    const at::Tensor src_layers,
     at::Tensor dst,
     const at::Tensor src_indices,
     const at::Tensor dst_indices,
-    int64_t page_size,
-    int64_t num_layers);
+    int64_t item_size,
+    int64_t dst_layout_dim,
+    int64_t num_layers,
+    int64_t block_quota,
+    int64_t num_warps_per_block);
+
+void transfer_kv_direct(
+    const std::vector<at::Tensor>& src_layers,
+    std::vector<at::Tensor> dst_layers,
+    const at::Tensor src_indices,
+    const at::Tensor dst_indices,
+    int64_t page_size);
 
 /*
  * From csrc/moe/cutlass_moe/w4a8
diff --git a/sgl-kernel/python/sgl_kernel/kvcacheio.py b/sgl-kernel/python/sgl_kernel/kvcacheio.py
index 5350e49dd..1440c2ca3 100644
--- a/sgl-kernel/python/sgl_kernel/kvcacheio.py
+++ b/sgl-kernel/python/sgl_kernel/kvcacheio.py
@@ -1,3 +1,5 @@
+from typing import List
+
 import torch
 
 
@@ -22,57 +24,116 @@ def transfer_kv_per_layer(
         dst_v,
         src_indices,
         dst_indices,
-        item_size,
+        item_size * src_k.element_size(),  # todo, hot fix for compatibility
         block_quota,
         num_warps_per_block,
     )
     elif
io_backend == "direct": - torch.ops.sgl_kernel.transfer_kv_per_layer_direct( - src_k, dst_k, src_v, dst_v, src_indices, dst_indices, page_size + torch.ops.sgl_kernel.transfer_kv_direct( + [src_k, src_v], [dst_k, dst_v], src_indices, dst_indices, page_size ) else: raise ValueError(f"Unsupported io backend") -def transfer_kv_all_layer( +def transfer_kv_per_layer_pf_lf( src_k: torch.Tensor, dst_k: torch.Tensor, src_v: torch.Tensor, dst_v: torch.Tensor, src_indices: torch.Tensor, dst_indices: torch.Tensor, + item_size: int, + src_layout_dim: int, + block_quota: int = 2, + num_warps_per_block: int = 32, +): + torch.ops.sgl_kernel.transfer_kv_per_layer_pf_lf( + src_k, + dst_k, + src_v, + dst_v, + src_indices, + dst_indices, + item_size, + src_layout_dim, + block_quota, + num_warps_per_block, + ) + + +def transfer_kv_all_layer( + src_k_layers: torch.Tensor, + dst_k_layers: torch.Tensor, + src_v_layers: torch.Tensor, + dst_v_layers: torch.Tensor, + src_indices: torch.Tensor, + dst_indices: torch.Tensor, io_backend: str, - page_size: int, item_size: int, num_layers: int, - src_layer_offset: int, - dst_layer_offset: int, block_quota: int = 2, num_warps_per_block: int = 32, ): if io_backend == "kernel": torch.ops.sgl_kernel.transfer_kv_all_layer( - src_k, - dst_k, - src_v, - dst_v, + src_k_layers, + dst_k_layers, + src_v_layers, + dst_v_layers, src_indices, dst_indices, item_size, num_layers, - src_layer_offset, - dst_layer_offset, block_quota, num_warps_per_block, ) elif io_backend == "direct": - torch.ops.sgl_kernel.transfer_kv_all_layer_direct( - src_k, dst_k, src_v, dst_v, src_indices, dst_indices, page_size, num_layers - ) + raise NotImplementedError("Deprecated interface") else: raise ValueError(f"Unsupported io backend") +def transfer_kv_all_layer_lf_pf( + src_k_layers: torch.Tensor, + dst_k: torch.Tensor, + src_v_layers: torch.Tensor, + dst_v: torch.Tensor, + src_indices: torch.Tensor, + dst_indices: torch.Tensor, + item_size: int, + dst_layout_dim: int, + num_layers: int, + block_quota: int = 2, + num_warps_per_block: int = 32, +): + torch.ops.sgl_kernel.transfer_kv_all_layer_lf_pf( + src_k_layers, + dst_k, + src_v_layers, + dst_v, + src_indices, + dst_indices, + item_size, + dst_layout_dim, + num_layers, + block_quota, + num_warps_per_block, + ) + + +def transfer_kv_direct( + src_layers: List[torch.Tensor], + dst_layers: List[torch.Tensor], + src_indices: torch.Tensor, + dst_indices: torch.Tensor, + page_size: int, +): + torch.ops.sgl_kernel.transfer_kv_direct( + src_layers, dst_layers, src_indices, dst_indices, page_size + ) + + def transfer_kv_per_layer_mla( src: torch.Tensor, dst: torch.Tensor, @@ -90,48 +151,87 @@ def transfer_kv_per_layer_mla( dst, src_indices, dst_indices, - item_size, + item_size * src.element_size(), # todo, hot fix for compatibility block_quota, num_warps_per_block, ) elif io_backend == "direct": - torch.ops.sgl_kernel.transfer_kv_per_layer_mla_direct( - src, dst, src_indices, dst_indices, page_size + torch.ops.sgl_kernel.transfer_kv_direct( + [src], [dst], src_indices, dst_indices, page_size ) else: raise ValueError(f"Unsupported io backend") -def transfer_kv_all_layer_mla( +def transfer_kv_per_layer_mla_pf_lf( src: torch.Tensor, dst: torch.Tensor, src_indices: torch.Tensor, dst_indices: torch.Tensor, + item_size: int, + src_layout_dim: int, + block_quota: int = 2, + num_warps_per_block: int = 32, +): + torch.ops.sgl_kernel.transfer_kv_per_layer_mla_pf_lf( + src, + dst, + src_indices, + dst_indices, + item_size, + src_layout_dim, + block_quota, + 
num_warps_per_block, + ) + + +def transfer_kv_all_layer_mla( + src_layers: torch.Tensor, + dst_layers: torch.Tensor, + src_indices: torch.Tensor, + dst_indices: torch.Tensor, io_backend: str, - page_size: int, item_size: int, num_layers: int, - src_layer_offset: int, - dst_layer_offset: int, block_quota: int = 2, num_warps_per_block: int = 32, ): if io_backend == "kernel": torch.ops.sgl_kernel.transfer_kv_all_layer_mla( - src, - dst, + src_layers, + dst_layers, src_indices, dst_indices, item_size, num_layers, - src_layer_offset, - dst_layer_offset, block_quota, num_warps_per_block, ) elif io_backend == "direct": - torch.ops.sgl_kernel.transfer_kv_all_layer_mla_direct( - src, dst, src_indices, dst_indices, page_size, num_layers - ) + raise NotImplementedError("Deprecated interface") else: raise ValueError(f"Unsupported io backend") + + +def transfer_kv_all_layer_mla_lf_pf( + src_layers: torch.Tensor, + dst: torch.Tensor, + src_indices: torch.Tensor, + dst_indices: torch.Tensor, + item_size: int, + dst_layout_dim: int, + num_layers: int, + block_quota: int = 2, + num_warps_per_block: int = 32, +): + torch.ops.sgl_kernel.transfer_kv_all_layer_mla_lf_pf( + src_layers, + dst, + src_indices, + dst_indices, + item_size, + dst_layout_dim, + num_layers, + block_quota, + num_warps_per_block, + ) diff --git a/sgl-kernel/tests/test_kvcacheio.py b/sgl-kernel/tests/test_kvcacheio.py index 635b5ba50..171fc4ca4 100644 --- a/sgl-kernel/tests/test_kvcacheio.py +++ b/sgl-kernel/tests/test_kvcacheio.py @@ -3,6 +3,7 @@ import torch from sgl_kernel.kvcacheio import ( transfer_kv_all_layer, transfer_kv_all_layer_mla, + transfer_kv_direct, transfer_kv_per_layer, transfer_kv_per_layer_mla, ) @@ -104,14 +105,12 @@ def test_transfer_kv( page_size=page_size, item_size=item_size, ) - transfer_kv_per_layer_mla( - src_pool_host[layer_idx_to_test], - dst_pool_direct[layer_idx_to_test], + transfer_kv_direct( + [src_pool_host[layer_idx_to_test]], + [dst_pool_direct[layer_idx_to_test]], src_indices_host, dst_indices_device, - io_backend="direct", page_size=page_size, - item_size=item_size, ) else: for layer_id in range(num_layers): @@ -121,29 +120,34 @@ def test_transfer_kv( src_indices_host, dst_indices_device, ) + src_layers_device = torch.tensor( + [src_pool_host[layer_id].data_ptr() for layer_id in range(num_layers)], + dtype=torch.uint64, + device=device, + ) + dst_layers_device = torch.tensor( + [ + dst_pool_kernel[layer_id].data_ptr() + for layer_id in range(num_layers) + ], + dtype=torch.uint64, + device=device, + ) transfer_kv_all_layer_mla( - src_pool_host, - dst_pool_kernel, + src_layers_device, + dst_layers_device, src_indices_device, dst_indices_device, io_backend="kernel", - page_size=page_size, - item_size=item_size, + item_size=item_size * dtype.itemsize, num_layers=num_layers, - src_layer_offset=total_items_in_pool * item_size, - dst_layer_offset=total_items_in_pool * item_size, ) - transfer_kv_all_layer_mla( - src_pool_host, - dst_pool_direct, + transfer_kv_direct( + [src_pool_host[layer_id] for layer_id in range(num_layers)], + [dst_pool_direct[layer_id] for layer_id in range(num_layers)], src_indices_host, dst_indices_device, - io_backend="direct", page_size=page_size, - item_size=item_size, - num_layers=num_layers, - src_layer_offset=total_items_in_pool * item_size, - dst_layer_offset=total_items_in_pool * item_size, ) torch.cuda.synchronize() torch.testing.assert_close(dst_pool_kernel, dst_pool_ref) @@ -173,16 +177,15 @@ def test_transfer_kv( page_size=page_size, item_size=item_size, ) - 
transfer_kv_per_layer( - src_k_pool[layer_idx_to_test], - dst_k_pool_direct[layer_idx_to_test], - src_v_pool[layer_idx_to_test], - dst_v_pool_direct[layer_idx_to_test], + transfer_kv_direct( + [src_k_pool[layer_idx_to_test], src_v_pool[layer_idx_to_test]], + [ + dst_k_pool_direct[layer_idx_to_test], + dst_v_pool_direct[layer_idx_to_test], + ], src_indices_host, dst_indices_device, - io_backend="direct", page_size=page_size, - item_size=item_size, ) else: for layer_id in range(num_layers): @@ -198,33 +201,52 @@ def test_transfer_kv( src_indices_host, dst_indices_device, ) + + src_k_layers_device = torch.tensor( + [src_k_pool[layer_id].data_ptr() for layer_id in range(num_layers)], + dtype=torch.uint64, + device=device, + ) + src_v_layers_device = torch.tensor( + [src_v_pool[layer_id].data_ptr() for layer_id in range(num_layers)], + dtype=torch.uint64, + device=device, + ) + dst_k_layers_device = torch.tensor( + [ + dst_k_pool_kernel[layer_id].data_ptr() + for layer_id in range(num_layers) + ], + dtype=torch.uint64, + device=device, + ) + dst_v_layers_device = torch.tensor( + [ + dst_v_pool_kernel[layer_id].data_ptr() + for layer_id in range(num_layers) + ], + dtype=torch.uint64, + device=device, + ) transfer_kv_all_layer( - src_k_pool, - dst_k_pool_kernel, - src_v_pool, - dst_v_pool_kernel, + src_k_layers_device, + dst_k_layers_device, + src_v_layers_device, + dst_v_layers_device, src_indices_device, dst_indices_device, io_backend="kernel", - page_size=page_size, - item_size=item_size, + item_size=item_size * dtype.itemsize, num_layers=num_layers, - src_layer_offset=total_items_in_pool * item_size, - dst_layer_offset=total_items_in_pool * item_size, ) - transfer_kv_all_layer( - src_k_pool, - dst_k_pool_direct, - src_v_pool, - dst_v_pool_direct, + transfer_kv_direct( + [src_k_pool[layer_id] for layer_id in range(num_layers)] + + [src_v_pool[layer_id] for layer_id in range(num_layers)], + [dst_k_pool_direct[layer_id] for layer_id in range(num_layers)] + + [dst_v_pool_direct[layer_id] for layer_id in range(num_layers)], src_indices_host, dst_indices_device, - io_backend="direct", page_size=page_size, - item_size=item_size, - num_layers=num_layers, - src_layer_offset=total_items_in_pool * item_size, - dst_layer_offset=total_items_in_pool * item_size, ) torch.cuda.synchronize() torch.testing.assert_close(dst_k_pool_kernel, dst_k_pool_ref)
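
Usage sketch (reviewer note, not part of the patch): after this refactor the all-layer kernels take device-resident uint64 tables of per-layer base pointers instead of one contiguous pool plus layer offsets, item_size for the all-layer paths is expressed in bytes, and the old *_direct ops collapse into a single transfer_kv_direct that walks Python lists of layer tensors. The snippet below mirrors the updated test; the pool shapes, sizes, index ranges, and page_size=1 are illustrative assumptions only.

    import torch
    from sgl_kernel.kvcacheio import transfer_kv_all_layer, transfer_kv_direct

    # Hypothetical layer-first pools: one [tokens, heads, head_dim] tensor per layer.
    num_layers, tokens, heads, head_dim = 4, 1024, 8, 128
    dtype = torch.bfloat16
    src_k = [torch.randn(tokens, heads, head_dim, dtype=dtype).pin_memory() for _ in range(num_layers)]
    src_v = [torch.randn(tokens, heads, head_dim, dtype=dtype).pin_memory() for _ in range(num_layers)]
    dst_k = [torch.zeros(tokens, heads, head_dim, dtype=dtype, device="cuda") for _ in range(num_layers)]
    dst_v = [torch.zeros(tokens, heads, head_dim, dtype=dtype, device="cuda") for _ in range(num_layers)]

    # Kernel backend: build uint64 pointer tables on the GPU, one entry per layer.
    def ptr_tbl(pool):
        return torch.tensor([t.data_ptr() for t in pool], dtype=torch.uint64, device="cuda")

    src_idx = torch.arange(0, 64, dtype=torch.int64, device="cuda")
    dst_idx = torch.arange(128, 192, dtype=torch.int64, device="cuda")

    transfer_kv_all_layer(
        ptr_tbl(src_k), ptr_tbl(dst_k), ptr_tbl(src_v), ptr_tbl(dst_v),
        src_idx, dst_idx,
        io_backend="kernel",
        item_size=heads * head_dim * dtype.itemsize,  # bytes moved per token per layer
        num_layers=num_layers,
    )

    # Direct backend: a single generic op over lists of layer tensors, copied page by page.
    transfer_kv_direct(src_k + src_v, dst_k + dst_v, src_idx.cpu(), dst_idx.cpu(), page_size=1)
    torch.cuda.synchronize()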