From 5be8c2f7f75d9b64362cca87f517c1df55abe157 Mon Sep 17 00:00:00 2001
From: huangtingwei <141888744+huangtingwei9988@users.noreply.github.com>
Date: Wed, 10 Sep 2025 13:35:34 +0800
Subject: [PATCH] Page first direct IO kernel (#10060)

Co-authored-by: Zhiqiang Xie
---
 sgl-kernel/csrc/common_extension.cc       |   8 +
 sgl-kernel/csrc/kvcacheio/transfer.cu     |  82 +++++++-
 sgl-kernel/include/sgl_kernel_ops.h       |  15 ++
 sgl-kernel/python/sgl_kernel/kvcacheio.py |  25 +++
 sgl-kernel/tests/test_kvcacheio.py        | 230 ++++++++++++++++++++++
 5 files changed, 358 insertions(+), 2 deletions(-)

diff --git a/sgl-kernel/csrc/common_extension.cc b/sgl-kernel/csrc/common_extension.cc
index 282be77ad..599bcf591 100644
--- a/sgl-kernel/csrc/common_extension.cc
+++ b/sgl-kernel/csrc/common_extension.cc
@@ -331,6 +331,14 @@ TORCH_LIBRARY_FRAGMENT(sgl_kernel, m) {
       "transfer_kv_direct(Tensor[] src_layers, Tensor[] dst_layers, Tensor src_indices, Tensor dst_indices, int "
       "page_size) -> ()");
   m.impl("transfer_kv_direct", torch::kCUDA, &transfer_kv_direct);
+  m.def(
+      "transfer_kv_per_layer_direct_pf_lf(Tensor[] src_ptrs, Tensor[] dst_ptrs, Tensor src_indices, "
+      "Tensor dst_indices, int layer_id, int page_size) -> ()");
+  m.impl("transfer_kv_per_layer_direct_pf_lf", torch::kCUDA, &transfer_kv_per_layer_direct_pf_lf);
+  m.def(
+      "transfer_kv_all_layer_direct_lf_pf(Tensor[] src_ptrs, Tensor[] dst_ptrs, Tensor src_indices, "
+      "Tensor dst_indices, int page_size) -> ()");
+  m.impl("transfer_kv_all_layer_direct_lf_pf", torch::kCUDA, &transfer_kv_all_layer_direct_lf_pf);
 
   /*
    * From csrc/memory
diff --git a/sgl-kernel/csrc/kvcacheio/transfer.cu b/sgl-kernel/csrc/kvcacheio/transfer.cu
index fab0d3bb8..bca9f326c 100644
--- a/sgl-kernel/csrc/kvcacheio/transfer.cu
+++ b/sgl-kernel/csrc/kvcacheio/transfer.cu
@@ -437,8 +437,8 @@ void transfer_kv_all_layer_mla_lf_pf(
 }
 
 inline void transfer_page_direct(
-    const at::Tensor& src_buffer,
-    at::Tensor& dst_buffer,
+    const at::Tensor src_buffer,
+    at::Tensor dst_buffer,
     int64_t src_page_index,
     int64_t dst_page_index,
     int64_t page_size) {
@@ -493,3 +493,81 @@ void transfer_kv_direct(
     start_index = end_index;
   }
 }
+
+template <bool IsLf2Pf>
+inline void transfer_kv_page_first_direct_impl(
+    const std::vector<at::Tensor>& src_ptrs,
+    std::vector<at::Tensor> dst_ptrs,
+    const at::Tensor& src_indices,
+    const at::Tensor& dst_indices,
+    int64_t start_layer_id,
+    int64_t page_size) {
+  TORCH_CHECK(src_indices.numel() == dst_indices.numel(), "Source and destination indices must have the same length");
+  TORCH_CHECK(page_size > 0, "Page size must be positive");
+  TORCH_CHECK(src_indices.numel() % page_size == 0, "Source indices size must be divisible by page size");
+
+  auto src_indices_cpu = src_indices.cpu();
+  auto dst_indices_cpu = dst_indices.cpu();
+  const int64_t num_pages = src_indices_cpu.size(0) / page_size;
+
+  if constexpr (IsLf2Pf) {
+    const bool is_mla = dst_ptrs.size() == 1;
+    const int64_t num_layers = is_mla ? src_ptrs.size() : src_ptrs.size() / 2;
+
+    for (const auto i : c10::irange(num_pages)) {
+      auto s_index = src_indices_cpu[i * page_size].item<int64_t>();
+      auto d_index = dst_indices_cpu[i * page_size].item<int64_t>() / page_size;
+      for (int64_t j = 0; j < num_layers; ++j) {
+        transfer_page_direct(
+            src_ptrs[j], dst_ptrs[0].select(0, d_index).select(0, start_layer_id + j), s_index, 0, page_size);
+        if (!is_mla) {
+          transfer_page_direct(
+              src_ptrs[j + num_layers],
+              dst_ptrs[1].select(0, d_index).select(0, start_layer_id + j),
+              s_index,
+              0,
+              page_size);
+        }
+      }
+    }
+  } else {
+    const bool is_mla = src_ptrs.size() == 1;
+    const int64_t num_layers = is_mla ? dst_ptrs.size() : dst_ptrs.size() / 2;
+
+    for (const auto i : c10::irange(num_pages)) {
+      auto s_index = src_indices_cpu[i * page_size].item<int64_t>() / page_size;
+      auto d_index = dst_indices_cpu[i * page_size].item<int64_t>();
+      for (int64_t j = 0; j < num_layers; ++j) {
+        transfer_page_direct(
+            src_ptrs[0].select(0, s_index).select(0, start_layer_id + j), dst_ptrs[j], 0, d_index, page_size);
+        if (!is_mla) {
+          transfer_page_direct(
+              src_ptrs[1].select(0, s_index).select(0, start_layer_id + j),
+              dst_ptrs[j + num_layers],
+              0,
+              d_index,
+              page_size);
+        }
+      }
+    }
+  }
+}
+
+void transfer_kv_per_layer_direct_pf_lf(
+    const std::vector<at::Tensor>& src_ptrs,
+    std::vector<at::Tensor> dst_ptrs,
+    const at::Tensor& src_indices,
+    const at::Tensor& dst_indices,
+    int64_t layer_id,
+    int64_t page_size) {
+  transfer_kv_page_first_direct_impl<false>(src_ptrs, dst_ptrs, src_indices, dst_indices, layer_id, page_size);
+}
+
+void transfer_kv_all_layer_direct_lf_pf(
+    const std::vector<at::Tensor>& src_ptrs,
+    std::vector<at::Tensor> dst_ptrs,
+    const at::Tensor& src_indices,
+    const at::Tensor& dst_indices,
+    int64_t page_size) {
+  transfer_kv_page_first_direct_impl<true>(src_ptrs, dst_ptrs, src_indices, dst_indices, 0, page_size);
+}
diff --git a/sgl-kernel/include/sgl_kernel_ops.h b/sgl-kernel/include/sgl_kernel_ops.h
index 3c3160a48..1cd85c911 100644
--- a/sgl-kernel/include/sgl_kernel_ops.h
+++ b/sgl-kernel/include/sgl_kernel_ops.h
@@ -569,6 +569,21 @@ void transfer_kv_direct(
     const at::Tensor dst_indices,
     int64_t page_size);
 
+void transfer_kv_per_layer_direct_pf_lf(
+    const std::vector<at::Tensor>& src_ptrs,
+    std::vector<at::Tensor> dst_ptrs,
+    const at::Tensor& src_indices,
+    const at::Tensor& dst_indices,
+    int64_t layer_id,
+    int64_t page_size);
+
+void transfer_kv_all_layer_direct_lf_pf(
+    const std::vector<at::Tensor>& src_ptrs,
+    std::vector<at::Tensor> dst_ptrs,
+    const at::Tensor& src_indices,
+    const at::Tensor& dst_indices,
+    int64_t page_size);
+
 /*
  * From FlashInfer
  */
diff --git a/sgl-kernel/python/sgl_kernel/kvcacheio.py b/sgl-kernel/python/sgl_kernel/kvcacheio.py
index 913cbc5e3..5714b6a0d 100644
--- a/sgl-kernel/python/sgl_kernel/kvcacheio.py
+++ b/sgl-kernel/python/sgl_kernel/kvcacheio.py
@@ -128,6 +128,31 @@ def transfer_kv_direct(
     )
 
 
+def transfer_kv_per_layer_direct_pf_lf(
+    src_ptrs: List[torch.Tensor],
+    dst_ptrs: List[torch.Tensor],
+    src_indices: torch.Tensor,
+    dst_indices: torch.Tensor,
+    layer_id: int,
+    page_size: int,
+):
+    torch.ops.sgl_kernel.transfer_kv_per_layer_direct_pf_lf(
+        src_ptrs, dst_ptrs, src_indices, dst_indices, layer_id, page_size
+    )
+
+
+def transfer_kv_all_layer_direct_lf_pf(
+    src_ptrs: List[torch.Tensor],
+    dst_ptrs: List[torch.Tensor],
+    src_indices: torch.Tensor,
+    dst_indices: torch.Tensor,
+    page_size: int,
+):
+    torch.ops.sgl_kernel.transfer_kv_all_layer_direct_lf_pf(
+        src_ptrs, dst_ptrs, src_indices, dst_indices, page_size
+    )
+
+
 def transfer_kv_per_layer_mla(
     src: torch.Tensor,
     dst: torch.Tensor,
diff --git a/sgl-kernel/tests/test_kvcacheio.py b/sgl-kernel/tests/test_kvcacheio.py
index d2b5be111..07fcc2413 100644
--- a/sgl-kernel/tests/test_kvcacheio.py
+++ b/sgl-kernel/tests/test_kvcacheio.py
@@ -2,9 +2,11 @@
 import pytest
 import torch
 from sgl_kernel.kvcacheio import (
     transfer_kv_all_layer,
+    transfer_kv_all_layer_direct_lf_pf,
     transfer_kv_all_layer_mla,
     transfer_kv_direct,
     transfer_kv_per_layer,
+    transfer_kv_per_layer_direct_pf_lf,
     transfer_kv_per_layer_mla,
 )
 
@@ -13,6 +15,21 @@ def ref_copy_with_indices(src_pool, dst_pool, src_indices, dst_indices):
     dst_pool[dst_indices] = src_pool[src_indices].to(dst_pool.device)
 
 
+def ref_copy_with_indices_pf_direct(
+    src_pool, dst_pool, src_indices, dst_indices, page_size, layer_id, lf_to_pf=False
+):
+    if lf_to_pf:
+        for i in range(0, len(src_indices), page_size):
+            dst_pool[dst_indices[i] // page_size][layer_id] = src_pool[layer_id][
+                src_indices[i : i + page_size]
+            ].to(dst_pool.device)
+    else:
+        for i in range(0, len(src_indices), page_size):
+            dst_pool[layer_id][dst_indices[i : i + page_size]] = src_pool[
+                src_indices[i] // page_size
+            ][layer_id].to(dst_pool.device)
+
+
 @pytest.mark.parametrize("dtype", [torch.bfloat16, torch.float16])
 @pytest.mark.parametrize("num_items_to_transfer", [1, 128, 1024])
 @pytest.mark.parametrize("page_size", [1, 16, 64])
@@ -251,5 +268,218 @@ def test_transfer_kv(
     torch.set_default_dtype(original_dtype)
 
 
+@pytest.mark.parametrize("dtype", [torch.bfloat16, torch.float16])
+@pytest.mark.parametrize("num_items_to_transfer", [128, 1024, 8192])
+@pytest.mark.parametrize("page_size", [16, 64, 128])
+@pytest.mark.parametrize("item_size", [256])
+@pytest.mark.parametrize("total_items_in_pool", [20480])
+@pytest.mark.parametrize("is_mla", [False, True])
+@pytest.mark.parametrize("lf_to_pf", [False, True])
+def test_transfer_kv_pf_direct(
+    dtype: torch.dtype,
+    num_items_to_transfer: int,
+    item_size: int,
+    page_size: int,
+    total_items_in_pool: int,
+    is_mla: bool,
+    lf_to_pf: bool,
+):
+    original_dtype = torch.get_default_dtype()
+    torch.set_default_dtype(dtype)
+    device = "cuda"
+    torch.cuda.manual_seed(42)
+
+    num_layers = 4
+
+    total_pages_in_pool = total_items_in_pool // page_size
+    num_pages_to_transfer = num_items_to_transfer // page_size
+    if num_pages_to_transfer == 0:
+        torch.set_default_dtype(original_dtype)
+        return
+    page_indices = torch.randperm(total_pages_in_pool, dtype=torch.int64)
+    src_indices_host = torch.cat(
+        [
+            torch.arange(p * page_size, (p + 1) * page_size)
+            for p in page_indices[:num_pages_to_transfer]
+        ]
+    )
+    src_indices_device = src_indices_host.to(device)
+    dst_indices_host = torch.cat(
+        [
+            torch.arange(p * page_size, (p + 1) * page_size)
+            for p in page_indices[num_pages_to_transfer : 2 * num_pages_to_transfer]
+        ]
+    )
+    dst_indices_device = dst_indices_host.to(device)
+
+    # We will test the per-layer function on the first layer (index 0) of the pool.
+    layer_idx_to_test = 0
+
+    if lf_to_pf:
+        if is_mla:
+            src_pool = torch.randn(num_layers, total_items_in_pool, item_size).to(
+                device
+            )
+            src_pool_ptrs = [src_pool[i] for i in range(num_layers)]
+            dst_pool_ref = torch.zeros(
+                total_pages_in_pool, num_layers, page_size, item_size
+            ).pin_memory()
+            dst_pool_direct = torch.zeros_like(dst_pool_ref)
+            torch.cuda.synchronize()
+
+            transfer_kv_all_layer_direct_lf_pf(
+                src_pool_ptrs,
+                [dst_pool_direct],
+                src_indices_host,
+                dst_indices_host,
+                page_size,
+            )
+            for i in range(num_layers):
+                ref_copy_with_indices_pf_direct(
+                    src_pool,
+                    dst_pool_ref,
+                    src_indices_device,
+                    dst_indices_host,
+                    page_size,
+                    i,
+                    lf_to_pf=True,
+                )
+            torch.cuda.synchronize()
+            torch.testing.assert_close(dst_pool_direct, dst_pool_ref)
+
+        else:
+            src_k_pool = torch.randn(num_layers, total_items_in_pool, item_size).to(
+                device
+            )
+            src_k_pool_ptrs = [src_k_pool[i] for i in range(num_layers)]
+            src_v_pool = torch.randn(num_layers, total_items_in_pool, item_size).to(
+                device
+            )
+            src_v_pool_ptrs = [src_v_pool[i] for i in range(num_layers)]
+            dst_k_pool_ref = torch.zeros(
+                total_pages_in_pool, num_layers, page_size, item_size
+            ).pin_memory()
+            dst_v_pool_ref = torch.zeros_like(dst_k_pool_ref)
+            dst_k_pool_direct = torch.zeros_like(dst_k_pool_ref)
+            dst_v_pool_direct = torch.zeros_like(dst_v_pool_ref)
+            torch.cuda.synchronize()
+
+            transfer_kv_all_layer_direct_lf_pf(
+                src_k_pool_ptrs + src_v_pool_ptrs,
+                [dst_k_pool_direct, dst_v_pool_direct],
+                src_indices_host,
+                dst_indices_host,
+                page_size,
+            )
+            for i in range(num_layers):
+                ref_copy_with_indices_pf_direct(
+                    src_k_pool,
+                    dst_k_pool_ref,
+                    src_indices_device,
+                    dst_indices_host,
+                    page_size,
+                    i,
+                    lf_to_pf=True,
+                )
+                ref_copy_with_indices_pf_direct(
+                    src_v_pool,
+                    dst_v_pool_ref,
+                    src_indices_device,
+                    dst_indices_host,
+                    page_size,
+                    i,
+                    lf_to_pf=True,
+                )
+            torch.cuda.synchronize()
+            torch.testing.assert_close(dst_k_pool_direct, dst_k_pool_ref)
+            torch.testing.assert_close(dst_v_pool_direct, dst_v_pool_ref)
+    else:
+        if is_mla:
+            src_pool = torch.randn(
+                total_pages_in_pool, num_layers, page_size, item_size
+            ).pin_memory()
+
+            dst_pool_ref = torch.zeros(num_layers, total_items_in_pool, item_size).to(
+                device
+            )
+            dst_pool_direct = torch.zeros_like(dst_pool_ref)
+            dst_pool_direct_ptrs = [dst_pool_direct[i] for i in range(num_layers)]
+            torch.cuda.synchronize()
+
+            transfer_kv_per_layer_direct_pf_lf(
+                [src_pool],
+                [dst_pool_direct_ptrs[layer_idx_to_test]],
+                src_indices_host,
+                dst_indices_host,
+                layer_idx_to_test,
+                page_size,
+            )
+            ref_copy_with_indices_pf_direct(
+                src_pool,
+                dst_pool_ref,
+                src_indices_host,
+                dst_indices_device,
+                page_size,
+                layer_idx_to_test,
+                lf_to_pf=False,
+            )
+            torch.cuda.synchronize()
+            torch.testing.assert_close(dst_pool_direct, dst_pool_ref)
+        else:
+            src_k_pool = torch.randn(
+                total_pages_in_pool, num_layers, page_size, item_size
+            ).pin_memory()
+            src_v_pool = torch.randn(
+                total_pages_in_pool, num_layers, page_size, item_size
+            ).pin_memory()
+
+            dst_k_pool_ref = torch.zeros(num_layers, total_items_in_pool, item_size).to(
+                device
+            )
+            dst_k_pool_direct = torch.zeros_like(dst_k_pool_ref)
+            dst_k_pool_direct_ptrs = [dst_k_pool_direct[i] for i in range(num_layers)]
+
+            dst_v_pool_ref = torch.zeros_like(dst_k_pool_ref)
+            dst_v_pool_direct = torch.zeros_like(dst_v_pool_ref)
+            dst_v_pool_direct_ptrs = [dst_v_pool_direct[i] for i in range(num_layers)]
+            torch.cuda.synchronize()
+
+            transfer_kv_per_layer_direct_pf_lf(
+                [src_k_pool, src_v_pool],
+                [
+                    dst_k_pool_direct_ptrs[layer_idx_to_test],
+                    dst_v_pool_direct_ptrs[layer_idx_to_test],
+                ],
+                src_indices_host,
+                dst_indices_host,
+                layer_idx_to_test,
+                page_size,
+            )
+
+            ref_copy_with_indices_pf_direct(
+                src_k_pool,
+                dst_k_pool_ref,
+                src_indices_host,
+                dst_indices_device,
+                page_size,
+                layer_idx_to_test,
+                lf_to_pf=False,
+            )
+            ref_copy_with_indices_pf_direct(
+                src_v_pool,
+                dst_v_pool_ref,
+                src_indices_host,
+                dst_indices_device,
+                page_size,
+                layer_idx_to_test,
+                lf_to_pf=False,
+            )
+
+            torch.cuda.synchronize()
+            torch.testing.assert_close(dst_k_pool_direct, dst_k_pool_ref)
+            torch.testing.assert_close(dst_v_pool_direct, dst_v_pool_ref)
+
+    torch.set_default_dtype(original_dtype)
+
+
 if __name__ == "__main__":
     pytest.main([__file__])
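
For reference, a minimal sketch of driving the new layer-first to page-first wrapper from Python, mirroring the MLA path exercised by the test above. The pool shapes and variable names here are illustrative only; a CUDA device and an installed sgl_kernel build are assumed:

    import torch
    from sgl_kernel.kvcacheio import transfer_kv_all_layer_direct_lf_pf

    num_layers, page_size, items, dim = 4, 16, 1024, 256
    pages = items // page_size

    # Layer-first device pool: one contiguous (items, dim) buffer per layer.
    src_pool = torch.randn(num_layers, items, dim, device="cuda")
    # Page-first pinned host pool: (pages, num_layers, page_size, dim).
    dst_pool = torch.zeros(pages, num_layers, page_size, dim).pin_memory()

    # Per-item indices covering whole pages: device pages 0 and 1 land in
    # host pages 2 and 3 (the kernel divides dst indices by page_size).
    src_indices = torch.arange(0, 2 * page_size)
    dst_indices = torch.arange(2 * page_size, 4 * page_size)

    transfer_kv_all_layer_direct_lf_pf(
        [src_pool[i] for i in range(num_layers)],  # one tensor per layer
        [dst_pool],  # a single destination pool selects the MLA (K-only) path
        src_indices,
        dst_indices,
        page_size,
    )
    torch.cuda.synchronize()  # as in the test: sync before reading the host buffer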