diff --git a/sgl-kernel/CMakeLists.txt b/sgl-kernel/CMakeLists.txt
index ebba94982..f8481491a 100644
--- a/sgl-kernel/CMakeLists.txt
+++ b/sgl-kernel/CMakeLists.txt
@@ -250,6 +250,7 @@ set(SOURCES
     "csrc/speculative/packbit.cu"
     "csrc/speculative/speculative_sampling.cu"
     "csrc/grammar/apply_token_bitmask_inplace_cuda.cu"
+    "csrc/kvcacheio/transfer.cu"
     "csrc/common_extension.cc"
     "${repo-flashinfer_SOURCE_DIR}/csrc/norm.cu"
     "${repo-flashinfer_SOURCE_DIR}/csrc/renorm.cu"
diff --git a/sgl-kernel/csrc/common_extension.cc b/sgl-kernel/csrc/common_extension.cc
index ed9f406e6..11a9adbb4 100644
--- a/sgl-kernel/csrc/common_extension.cc
+++ b/sgl-kernel/csrc/common_extension.cc
@@ -230,6 +230,43 @@ TORCH_LIBRARY_FRAGMENT(sgl_kernel, m) {
       "int cuda_stream) -> ()");
   m.impl("segment_packbits", torch::kCUDA, &segment_packbits);
 
+  /*
+   * From csrc/kvcacheio
+   */
+  m.def(
+      "transfer_kv_per_layer(Tensor src_k, Tensor dst_k, Tensor src_v, Tensor dst_v, Tensor src_indices, Tensor "
+      "dst_indices, int item_size, int block_quota, int num_warps_per_block) -> ()");
+  m.impl("transfer_kv_per_layer", torch::kCUDA, &transfer_kv_per_layer);
+  m.def(
+      "transfer_kv_per_layer_direct(Tensor src_k, Tensor dst_k, Tensor src_v, Tensor dst_v, Tensor src_indices, Tensor "
+      "dst_indices, int page_size) -> ()");
+  m.impl("transfer_kv_per_layer_direct", torch::kCUDA, &transfer_kv_per_layer_direct);
+  m.def(
+      "transfer_kv_all_layer(Tensor src_k, Tensor dst_k, Tensor src_v, Tensor dst_v, Tensor src_indices, Tensor "
+      "dst_indices, int item_size, int num_layers, int src_layer_offset, int dst_layer_offset, int block_quota, int "
+      "num_warps_per_block) -> ()");
+  m.impl("transfer_kv_all_layer", torch::kCUDA, &transfer_kv_all_layer);
+  m.def(
+      "transfer_kv_all_layer_direct(Tensor src_k, Tensor dst_k, Tensor src_v, Tensor dst_v, Tensor src_indices, Tensor "
+      "dst_indices, int page_size, int num_layers) -> ()");
+  m.impl("transfer_kv_all_layer_direct", torch::kCUDA, &transfer_kv_all_layer_direct);
+  m.def(
+      "transfer_kv_per_layer_mla(Tensor src, Tensor dst, Tensor src_indices, Tensor dst_indices, int item_size, int "
+      "block_quota, int num_warps_per_block) -> ()");
+  m.impl("transfer_kv_per_layer_mla", torch::kCUDA, &transfer_kv_per_layer_mla);
+  m.def(
+      "transfer_kv_per_layer_mla_direct(Tensor src, Tensor dst, Tensor src_indices, Tensor dst_indices, int page_size) "
+      "-> ()");
+  m.impl("transfer_kv_per_layer_mla_direct", torch::kCUDA, &transfer_kv_per_layer_mla_direct);
+  m.def(
+      "transfer_kv_all_layer_mla(Tensor src, Tensor dst, Tensor src_indices, Tensor dst_indices, int item_size, int "
+      "num_layers, int src_layer_offset, int dst_layer_offset, int block_quota, int num_warps_per_block) -> ()");
+  m.impl("transfer_kv_all_layer_mla", torch::kCUDA, &transfer_kv_all_layer_mla);
+  m.def(
+      "transfer_kv_all_layer_mla_direct(Tensor src, Tensor dst, Tensor src_indices, Tensor dst_indices, int page_size, "
+      "int num_layers) -> ()");
+  m.impl("transfer_kv_all_layer_mla_direct", torch::kCUDA, &transfer_kv_all_layer_mla_direct);
+
   /*
    * From FlashInfer
    */
diff --git a/sgl-kernel/csrc/kvcacheio/transfer.cu b/sgl-kernel/csrc/kvcacheio/transfer.cu
new file mode 100644
index 000000000..6c939dd55
--- /dev/null
+++ b/sgl-kernel/csrc/kvcacheio/transfer.cu
@@ -0,0 +1,342 @@
+#include <ATen/ATen.h>
+#include <ATen/cuda/CUDAContext.h>
+#include <c10/cuda/CUDAException.h>
+
+#include <c10/util/irange.h>
+
+#include "pytorch_extension_utils.h"
+
+__device__ __forceinline__ void
+transfer_item_warp(int32_t lane_id, const void* src_addr, void* dst_addr, int64_t item_size_bytes) {
+  // todo, different chunk size
+  int total_chunks = item_size_bytes / 8;
+  const int64_t* src_8 = reinterpret_cast<const int64_t*>(src_addr);
+  int64_t* dst_8 = reinterpret_cast<int64_t*>(dst_addr);
+#pragma unroll
+  for (int j = lane_id; j < total_chunks; j += 32) {
+    const int64_t* src_addr_lane = &src_8[j];
+    int64_t* dst_addr_lane = &dst_8[j];
+    int64_t temp_val;
+    asm volatile("ld.global.nc.b64 %0, [%1];" : "=l"(temp_val) : "l"(src_addr_lane) : "memory");
+    asm volatile("st.global.cg.b64 [%0], %1;" ::"l"(dst_addr_lane), "l"(temp_val) : "memory");
+  }
+}
+
+// todo, structs for different memory layout
+__device__ __forceinline__ int64_t
+get_global_offset_lf(int64_t layer_id, int64_t layer_dim, int64_t page_id, int64_t item_size_bytes) {
+  // layer first
+  return layer_id * layer_dim + page_id * item_size_bytes;
+}
+
+__device__ __forceinline__ int64_t
+get_global_offset_pf(int64_t layer_id, int64_t page_dim, int64_t page_id, int64_t item_size_bytes) {
+  // page first
+  return page_id * page_dim + layer_id * item_size_bytes;
+}
+
+// Offset functions are passed as compile-time function pointers.
+using OffsetFn = int64_t (*)(int64_t, int64_t, int64_t, int64_t);
+
+template <bool IsMLA, OffsetFn SrcOffsetFn, OffsetFn DstOffsetFn>
+__global__ void transfer_kernel_impl(
+    const void* __restrict__ src_k,
+    void* __restrict__ dst_k,
+    const void* __restrict__ src_v,
+    void* __restrict__ dst_v,
+    const int64_t* __restrict__ src_indices,
+    const int64_t* __restrict__ dst_indices,
+    int64_t start_layer_id,
+    int64_t num_layers_to_process,
+    int64_t num_items,
+    int64_t items_per_warp,
+    int64_t item_size_bytes,
+    int64_t src_layout_dim,
+    int64_t dst_layout_dim) {
+  int32_t tid = blockIdx.x * blockDim.x + threadIdx.x;
+  int32_t lane_id = tid % 32;
+  int32_t warp_id = tid / 32;
+
+  for (int i = 0; i < items_per_warp; ++i) {
+    int32_t item_id = warp_id * items_per_warp + i;
+    if (item_id >= num_items) {
+      return;
+    }
+    const int64_t src_page_id = src_indices[item_id];
+    const int64_t dst_page_id = dst_indices[item_id];
+
+    // Loop over layers if necessary
+    for (int64_t layer_id = start_layer_id; layer_id < start_layer_id + num_layers_to_process; ++layer_id) {
+      // Calculate offsets using the provided function pointers
+      const int64_t src_offset = SrcOffsetFn(layer_id, src_layout_dim, src_page_id, item_size_bytes);
+      const int64_t dst_offset = DstOffsetFn(layer_id, dst_layout_dim, dst_page_id, item_size_bytes);
+
+      if constexpr (IsMLA) {
+        transfer_item_warp(
+            lane_id,
+            static_cast<const char*>(src_k) + src_offset,
+            static_cast<char*>(dst_k) + dst_offset,
+            item_size_bytes);
+      } else {
+        transfer_item_warp(
+            lane_id,
+            static_cast<const char*>(src_k) + src_offset,
+            static_cast<char*>(dst_k) + dst_offset,
+            item_size_bytes);
+        transfer_item_warp(
+            lane_id,
+            static_cast<const char*>(src_v) + src_offset,
+            static_cast<char*>(dst_v) + dst_offset,
+            item_size_bytes);
+      }
+    }
+  }
+}
+
+template <bool IsMLA, OffsetFn SrcOffsetFn, OffsetFn DstOffsetFn>
+void transfer_kv_launcher(
+    const at::Tensor& src_k,
+    at::Tensor& dst_k,
+    const at::Tensor& src_v,
+    at::Tensor& dst_v,
+    const at::Tensor& src_indices,
+    const at::Tensor& dst_indices,
+    int64_t start_layer_id,
+    int64_t num_layers_to_process,
+    int64_t item_size,
+    int64_t src_layout_dim,
+    int64_t dst_layout_dim,
+    int64_t block_quota,
+    int64_t num_warps_per_block) {
+  TORCH_CHECK(src_k.scalar_type() == dst_k.scalar_type(), "Source and destination keys must have the same type");
+  TORCH_CHECK(src_indices.is_cuda(), "Source indices must be a CUDA tensor");
+  TORCH_CHECK(dst_indices.is_cuda(), "Destination indices must be a CUDA tensor");
+  TORCH_CHECK(src_indices.scalar_type() == at::kLong, "Source indices must be of type long");
+  TORCH_CHECK(dst_indices.scalar_type() == at::kLong, "Destination indices must be of type long");
+  TORCH_CHECK(src_indices.numel() == dst_indices.numel(), "Source and destination indices must have the same length");
+
+  if (!IsMLA) {
+    TORCH_CHECK(src_v.scalar_type() == dst_v.scalar_type(), "Source and destination values must have the same type");
+  }
+
+  int dtype_size = src_k.element_size();
+  TORCH_CHECK((item_size * dtype_size) % 8 == 0, "Item byte size must be divisible by 8");
+
+  auto div_up = [](int32_t x, int32_t y) { return (x + y - 1) / y; };
+  const int64_t num_items = src_indices.numel();
+  const int64_t items_per_warp = div_up(num_items, block_quota * num_warps_per_block);
+  const int32_t num_blocks = div_up(num_items, items_per_warp * num_warps_per_block);
+  dim3 grid_dim(num_blocks, 1, 1);
+  const int32_t threads_per_block = num_warps_per_block * 32;
+
+  cudaStream_t torch_current_stream = at::cuda::getCurrentCUDAStream();
+  transfer_kernel_impl<IsMLA, SrcOffsetFn, DstOffsetFn>
+      <<<grid_dim, threads_per_block, 0, torch_current_stream>>>(
+          src_k.data_ptr(),
+          dst_k.data_ptr(),
+          (IsMLA ? nullptr : src_v.data_ptr()),
+          (IsMLA ? nullptr : dst_v.data_ptr()),
+          src_indices.data_ptr<int64_t>(),
+          dst_indices.data_ptr<int64_t>(),
+          start_layer_id,
+          num_layers_to_process,
+          num_items,
+          items_per_warp,
+          item_size * dtype_size,
+          src_layout_dim * dtype_size,
+          dst_layout_dim * dtype_size);
+  C10_CUDA_KERNEL_LAUNCH_CHECK();
+}
+
+void transfer_kv_per_layer(
+    const at::Tensor src_k,
+    at::Tensor dst_k,
+    const at::Tensor src_v,
+    at::Tensor dst_v,
+    const at::Tensor src_indices,
+    const at::Tensor dst_indices,
+    int64_t item_size,
+    int64_t block_quota,
+    int64_t num_warps_per_block) {
+  transfer_kv_launcher<false, get_global_offset_lf, get_global_offset_lf>(
+      src_k, dst_k, src_v, dst_v, src_indices, dst_indices, 0, 1, item_size, 0, 0, block_quota, num_warps_per_block);
+}
+
+void transfer_kv_all_layer(
+    const at::Tensor src_k,
+    at::Tensor dst_k,
+    const at::Tensor src_v,
+    at::Tensor dst_v,
+    const at::Tensor src_indices,
+    const at::Tensor dst_indices,
+    int64_t item_size,
+    int64_t num_layers,
+    int64_t src_layer_offset,
+    int64_t dst_layer_offset,
+    int64_t block_quota,
+    int64_t num_warps_per_block) {
+  transfer_kv_launcher<false, get_global_offset_lf, get_global_offset_lf>(
+      src_k,
+      dst_k,
+      src_v,
+      dst_v,
+      src_indices,
+      dst_indices,
+      0,
+      num_layers,
+      item_size,
+      src_layer_offset,
+      dst_layer_offset,
+      block_quota,
+      num_warps_per_block);
+}
+
+void transfer_kv_per_layer_mla(
+    const at::Tensor src,
+    at::Tensor dst,
+    const at::Tensor src_indices,
+    const at::Tensor dst_indices,
+    int64_t item_size,
+    int64_t block_quota,
+    int64_t num_warps_per_block) {
+  at::Tensor empty_tensor = at::Tensor();
+  transfer_kv_launcher<true, get_global_offset_lf, get_global_offset_lf>(
+      src,
+      dst,
+      empty_tensor,
+      empty_tensor,
+      src_indices,
+      dst_indices,
+      0,
+      1,
+      item_size,
+      0,
+      0,
+      block_quota,
+      num_warps_per_block);
+}
+
+void transfer_kv_all_layer_mla(
+    const at::Tensor src,
+    at::Tensor dst,
+    const at::Tensor src_indices,
+    const at::Tensor dst_indices,
+    int64_t item_size,
+    int64_t num_layers,
+    int64_t src_layer_offset,
+    int64_t dst_layer_offset,
+    int64_t block_quota,
+    int64_t num_warps_per_block) {
+  at::Tensor empty_tensor = at::Tensor();
+  transfer_kv_launcher<true, get_global_offset_lf, get_global_offset_lf>(
+      src,
+      dst,
+      empty_tensor,
+      empty_tensor,
+      src_indices,
+      dst_indices,
+      0,
+      num_layers,
+      item_size,
+      src_layer_offset,
+      dst_layer_offset,
+      block_quota,
+      num_warps_per_block);
+}
+
+inline void transfer_page_direct(
+    const at::Tensor src_buffer,
+    at::Tensor dst_buffer,
+    int64_t src_page_index,
+    int64_t dst_page_index,
+    int64_t page_size) {
+  dst_buffer.slice(0, dst_page_index, dst_page_index + page_size)
+      .copy_(
+          src_buffer.slice(0, src_page_index, src_page_index + page_size),
+          /* non_blocking= */ true);
+}
+
+template <bool IsMLA, bool AllLayers>
+inline void transfer_kv_direct_impl(
+    const at::Tensor& src_k,
+    at::Tensor& dst_k,
+    const at::Tensor& src_v_opt,  // Only used when IsMLA is false (for src_v)
+    at::Tensor& dst_v_opt,        // Only used when IsMLA is false (for dst_v)
+    const at::Tensor& src_indices,
+    const at::Tensor& dst_indices,
+    int64_t page_size,
+    int64_t num_layers = 1) {
+  TORCH_CHECK(src_indices.numel() == dst_indices.numel(), "Source and destination indices must have the same length");
+  TORCH_CHECK(page_size > 0, "Page size must be positive");
+  TORCH_CHECK(src_indices.numel() % page_size == 0, "Source indices size must be divisible by page size");
+
+  auto src_indices_cpu = src_indices.cpu();
+  auto dst_indices_cpu = dst_indices.cpu();
+
+  const int64_t num_pages = src_indices_cpu.size(0) / page_size;
+
+  for (const auto i : c10::irange(num_pages)) {
+    auto s_index = src_indices_cpu[i * page_size].item<int64_t>();
+    auto d_index = dst_indices_cpu[i * page_size].item<int64_t>();
+
+    if constexpr (AllLayers) {
+      for (const auto j : c10::irange(num_layers)) {
+        if constexpr (IsMLA) {
+          transfer_page_direct(src_k.select(0, j), dst_k.select(0, j), s_index, d_index, page_size);
+        } else {
+          transfer_page_direct(src_k.select(0, j), dst_k.select(0, j), s_index, d_index, page_size);
+          transfer_page_direct(src_v_opt.select(0, j), dst_v_opt.select(0, j), s_index, d_index, page_size);
+        }
+      }
+    } else {  // Per-layer
+      if constexpr (IsMLA) {
+        transfer_page_direct(src_k, dst_k, s_index, d_index, page_size);
+      } else {
+        transfer_page_direct(src_k, dst_k, s_index, d_index, page_size);
+        transfer_page_direct(src_v_opt, dst_v_opt, s_index, d_index, page_size);
+      }
+    }
+  }
+}
+
+void transfer_kv_per_layer_direct(
+    const at::Tensor src_k,
+    at::Tensor dst_k,
+    const at::Tensor src_v,
+    at::Tensor dst_v,
+    const at::Tensor src_indices,
+    const at::Tensor dst_indices,
+    int64_t page_size) {
+  transfer_kv_direct_impl<false, false>(src_k, dst_k, src_v, dst_v, src_indices, dst_indices, page_size);
+}
+
+void transfer_kv_all_layer_direct(
+    const at::Tensor src_k,
+    at::Tensor dst_k,
+    const at::Tensor src_v,
+    at::Tensor dst_v,
+    const at::Tensor src_indices,
+    const at::Tensor dst_indices,
+    int64_t page_size,
+    int64_t num_layers) {
+  transfer_kv_direct_impl<false, true>(src_k, dst_k, src_v, dst_v, src_indices, dst_indices, page_size, num_layers);
+}
+
+void transfer_kv_per_layer_mla_direct(
+    const at::Tensor src,
+    at::Tensor dst,
+    const at::Tensor src_indices,
+    const at::Tensor dst_indices,
+    int64_t page_size) {
+  at::Tensor empty_tensor = at::Tensor();
+
+  transfer_kv_direct_impl<true, false>(src, dst, empty_tensor, empty_tensor, src_indices, dst_indices, page_size);
+}
+
+void transfer_kv_all_layer_mla_direct(
+    const at::Tensor src,
+    at::Tensor dst,
+    const at::Tensor src_indices,
+    const at::Tensor dst_indices,
+    int64_t page_size,
+    int64_t num_layers) {
+  at::Tensor empty_tensor = at::Tensor();
+  transfer_kv_direct_impl<true, true>(
+      src, dst, empty_tensor, empty_tensor, src_indices, dst_indices, page_size, num_layers);
+}
diff --git a/sgl-kernel/include/sgl_kernel_ops.h b/sgl-kernel/include/sgl_kernel_ops.h
index 9588bc736..c90800f76 100644
--- a/sgl-kernel/include/sgl_kernel_ops.h
+++ b/sgl-kernel/include/sgl_kernel_ops.h
@@ -371,6 +371,89 @@ void segment_packbits(
     int64_t batch_size,
     int64_t cuda_stream = 0);
 
+/*
+ * From csrc/kvcacheio
+ */
+void transfer_kv_per_layer(
+    const at::Tensor src_k,
+    at::Tensor dst_k,
+    const at::Tensor src_v,
+    at::Tensor dst_v,
+    const at::Tensor src_indices,
+    const at::Tensor dst_indices,
+    int64_t item_size,
+    int64_t block_quota,
+    int64_t num_warps_per_block);
+
+void transfer_kv_per_layer_direct(
+    const at::Tensor src_k,
+    at::Tensor
dst_k, + const at::Tensor src_v, + at::Tensor dst_v, + const at::Tensor src_indices, + const at::Tensor dst_indices, + int64_t page_size); + +void transfer_kv_all_layer( + const at::Tensor src_k, + at::Tensor dst_k, + const at::Tensor src_v, + at::Tensor dst_v, + const at::Tensor src_indices, + const at::Tensor dst_indices, + int64_t item_size, + int64_t num_layers, + int64_t src_layer_offset, + int64_t dst_layer_offset, + int64_t block_quota, + int64_t num_warps_per_block); + +void transfer_kv_all_layer_direct( + const at::Tensor src_k, + at::Tensor dst_k, + const at::Tensor src_v, + at::Tensor dst_v, + const at::Tensor src_indices, + const at::Tensor dst_indices, + int64_t page_size, + int64_t num_layers); + +void transfer_kv_per_layer_mla( + const at::Tensor src, + at::Tensor dst, + const at::Tensor src_indices, + const at::Tensor dst_indices, + int64_t item_size, + int64_t block_quota, + int64_t num_warps_per_block); + +void transfer_kv_per_layer_mla_direct( + const at::Tensor src, + at::Tensor dst, + const at::Tensor src_indices, + const at::Tensor dst_indices, + int64_t page_size); + +void transfer_kv_all_layer_mla( + const at::Tensor src, + at::Tensor dst, + const at::Tensor src_indices, + const at::Tensor dst_indices, + int64_t item_size, + int64_t num_layers, + int64_t src_layer_offset, + int64_t dst_layer_offset, + int64_t block_quota, + int64_t num_warps_per_block); + +void transfer_kv_all_layer_mla_direct( + const at::Tensor src, + at::Tensor dst, + const at::Tensor src_indices, + const at::Tensor dst_indices, + int64_t page_size, + int64_t num_layers); + /* * From FlashInfer */ diff --git a/sgl-kernel/python/sgl_kernel/__init__.py b/sgl-kernel/python/sgl_kernel/__init__.py index d9ce1ff5a..52643f364 100755 --- a/sgl-kernel/python/sgl_kernel/__init__.py +++ b/sgl-kernel/python/sgl_kernel/__init__.py @@ -47,6 +47,12 @@ from sgl_kernel.gemm import ( shuffle_rows, ) from sgl_kernel.grammar import apply_token_bitmask_inplace_cuda +from sgl_kernel.kvcacheio import ( + transfer_kv_all_layer, + transfer_kv_all_layer_mla, + transfer_kv_per_layer, + transfer_kv_per_layer_mla, +) from sgl_kernel.moe import ( apply_shuffle_mul_sum, cutlass_fp4_group_mm, diff --git a/sgl-kernel/python/sgl_kernel/kvcacheio.py b/sgl-kernel/python/sgl_kernel/kvcacheio.py new file mode 100644 index 000000000..5350e49dd --- /dev/null +++ b/sgl-kernel/python/sgl_kernel/kvcacheio.py @@ -0,0 +1,137 @@ +import torch + + +def transfer_kv_per_layer( + src_k: torch.Tensor, + dst_k: torch.Tensor, + src_v: torch.Tensor, + dst_v: torch.Tensor, + src_indices: torch.Tensor, + dst_indices: torch.Tensor, + io_backend: str, + page_size: int, + item_size: int, + block_quota: int = 2, + num_warps_per_block: int = 32, +): + if io_backend == "kernel": + torch.ops.sgl_kernel.transfer_kv_per_layer( + src_k, + dst_k, + src_v, + dst_v, + src_indices, + dst_indices, + item_size, + block_quota, + num_warps_per_block, + ) + elif io_backend == "direct": + torch.ops.sgl_kernel.transfer_kv_per_layer_direct( + src_k, dst_k, src_v, dst_v, src_indices, dst_indices, page_size + ) + else: + raise ValueError(f"Unsupported io backend") + + +def transfer_kv_all_layer( + src_k: torch.Tensor, + dst_k: torch.Tensor, + src_v: torch.Tensor, + dst_v: torch.Tensor, + src_indices: torch.Tensor, + dst_indices: torch.Tensor, + io_backend: str, + page_size: int, + item_size: int, + num_layers: int, + src_layer_offset: int, + dst_layer_offset: int, + block_quota: int = 2, + num_warps_per_block: int = 32, +): + if io_backend == "kernel": + 
torch.ops.sgl_kernel.transfer_kv_all_layer( + src_k, + dst_k, + src_v, + dst_v, + src_indices, + dst_indices, + item_size, + num_layers, + src_layer_offset, + dst_layer_offset, + block_quota, + num_warps_per_block, + ) + elif io_backend == "direct": + torch.ops.sgl_kernel.transfer_kv_all_layer_direct( + src_k, dst_k, src_v, dst_v, src_indices, dst_indices, page_size, num_layers + ) + else: + raise ValueError(f"Unsupported io backend") + + +def transfer_kv_per_layer_mla( + src: torch.Tensor, + dst: torch.Tensor, + src_indices: torch.Tensor, + dst_indices: torch.Tensor, + io_backend: str, + page_size: int, + item_size: int, + block_quota: int = 2, + num_warps_per_block: int = 32, +): + if io_backend == "kernel": + torch.ops.sgl_kernel.transfer_kv_per_layer_mla( + src, + dst, + src_indices, + dst_indices, + item_size, + block_quota, + num_warps_per_block, + ) + elif io_backend == "direct": + torch.ops.sgl_kernel.transfer_kv_per_layer_mla_direct( + src, dst, src_indices, dst_indices, page_size + ) + else: + raise ValueError(f"Unsupported io backend") + + +def transfer_kv_all_layer_mla( + src: torch.Tensor, + dst: torch.Tensor, + src_indices: torch.Tensor, + dst_indices: torch.Tensor, + io_backend: str, + page_size: int, + item_size: int, + num_layers: int, + src_layer_offset: int, + dst_layer_offset: int, + block_quota: int = 2, + num_warps_per_block: int = 32, +): + if io_backend == "kernel": + torch.ops.sgl_kernel.transfer_kv_all_layer_mla( + src, + dst, + src_indices, + dst_indices, + item_size, + num_layers, + src_layer_offset, + dst_layer_offset, + block_quota, + num_warps_per_block, + ) + elif io_backend == "direct": + torch.ops.sgl_kernel.transfer_kv_all_layer_mla_direct( + src, dst, src_indices, dst_indices, page_size, num_layers + ) + else: + raise ValueError(f"Unsupported io backend") diff --git a/sgl-kernel/tests/test_kvcacheio.py b/sgl-kernel/tests/test_kvcacheio.py new file mode 100644 index 000000000..635b5ba50 --- /dev/null +++ b/sgl-kernel/tests/test_kvcacheio.py @@ -0,0 +1,239 @@ +import pytest +import torch +from sgl_kernel.kvcacheio import ( + transfer_kv_all_layer, + transfer_kv_all_layer_mla, + transfer_kv_per_layer, + transfer_kv_per_layer_mla, +) + + +def ref_copy_with_indices(src_pool, dst_pool, src_indices, dst_indices): + dst_pool[dst_indices] = src_pool[src_indices].to(dst_pool.device) + + +@pytest.mark.parametrize("dtype", [torch.bfloat16, torch.float16]) +@pytest.mark.parametrize("num_items_to_transfer", [1, 128, 1024]) +@pytest.mark.parametrize("page_size", [1, 16, 64]) +@pytest.mark.parametrize("item_size", [256]) +@pytest.mark.parametrize("total_items_in_pool", [10240]) +@pytest.mark.parametrize("is_mla", [False, True]) +@pytest.mark.parametrize("all_layers", [False, True]) +def test_transfer_kv( + dtype: torch.dtype, + num_items_to_transfer: int, + item_size: int, + page_size: int, + total_items_in_pool: int, + is_mla: bool, + all_layers: bool, +): + """ + Tests the per-layer transfer functions, treating tensors as memory pools. 
+ """ + + original_dtype = torch.get_default_dtype() + torch.set_default_dtype(dtype) + device = "cuda" + torch.cuda.manual_seed(42) + + num_layers = 4 # A small number of layers for pool creation + + total_pages_in_pool = total_items_in_pool // page_size + num_pages_to_transfer = num_items_to_transfer // page_size + if num_pages_to_transfer == 0: + torch.set_default_dtype(original_dtype) + return + page_indices = torch.randperm(total_pages_in_pool, dtype=torch.int64) + src_indices_host = torch.cat( + [ + torch.arange(p * page_size, (p + 1) * page_size) + for p in page_indices[:num_pages_to_transfer] + ] + ) + src_indices_device = src_indices_host.to(device) + dst_indices_host = torch.cat( + [ + torch.arange(p * page_size, (p + 1) * page_size) + for p in page_indices[num_pages_to_transfer : 2 * num_pages_to_transfer] + ] + ) + dst_indices_device = dst_indices_host.to(device) + + # Prepare memory pools based on whether it's an MLA case. + if is_mla: + src_pool_host = torch.randn( + num_layers, total_items_in_pool, item_size + ).pin_memory() + dst_pool_ref = torch.zeros_like(src_pool_host).to(device) + dst_pool_kernel = torch.zeros_like(dst_pool_ref) + dst_pool_direct = torch.zeros_like(dst_pool_ref) + else: + src_k_pool = torch.randn( + num_layers, total_items_in_pool, item_size + ).pin_memory() + src_v_pool = torch.randn( + num_layers, total_items_in_pool, item_size + ).pin_memory() + dst_k_pool_ref = torch.zeros_like(src_k_pool).to(device) + dst_v_pool_ref = torch.zeros_like(src_v_pool).to(device) + dst_k_pool_kernel = torch.zeros_like(dst_k_pool_ref) + dst_v_pool_kernel = torch.zeros_like(dst_v_pool_ref) + dst_k_pool_direct = torch.zeros_like(dst_k_pool_ref) + dst_v_pool_direct = torch.zeros_like(dst_v_pool_ref) + + torch.cuda.synchronize() + + # We will test the per-layer function on the first layer (index 0) of the pool. 
+ layer_idx_to_test = 0 + + if is_mla: + if not all_layers: + ref_copy_with_indices( + src_pool_host[layer_idx_to_test], + dst_pool_ref[layer_idx_to_test], + src_indices_host, + dst_indices_device, + ) + transfer_kv_per_layer_mla( + src_pool_host[layer_idx_to_test], + dst_pool_kernel[layer_idx_to_test], + src_indices_device, + dst_indices_device, + io_backend="kernel", + page_size=page_size, + item_size=item_size, + ) + transfer_kv_per_layer_mla( + src_pool_host[layer_idx_to_test], + dst_pool_direct[layer_idx_to_test], + src_indices_host, + dst_indices_device, + io_backend="direct", + page_size=page_size, + item_size=item_size, + ) + else: + for layer_id in range(num_layers): + ref_copy_with_indices( + src_pool_host[layer_id], + dst_pool_ref[layer_id], + src_indices_host, + dst_indices_device, + ) + transfer_kv_all_layer_mla( + src_pool_host, + dst_pool_kernel, + src_indices_device, + dst_indices_device, + io_backend="kernel", + page_size=page_size, + item_size=item_size, + num_layers=num_layers, + src_layer_offset=total_items_in_pool * item_size, + dst_layer_offset=total_items_in_pool * item_size, + ) + transfer_kv_all_layer_mla( + src_pool_host, + dst_pool_direct, + src_indices_host, + dst_indices_device, + io_backend="direct", + page_size=page_size, + item_size=item_size, + num_layers=num_layers, + src_layer_offset=total_items_in_pool * item_size, + dst_layer_offset=total_items_in_pool * item_size, + ) + torch.cuda.synchronize() + torch.testing.assert_close(dst_pool_kernel, dst_pool_ref) + torch.testing.assert_close(dst_pool_direct, dst_pool_ref) + else: + if not all_layers: + ref_copy_with_indices( + src_k_pool[layer_idx_to_test], + dst_k_pool_ref[layer_idx_to_test], + src_indices_host, + dst_indices_device, + ) + ref_copy_with_indices( + src_v_pool[layer_idx_to_test], + dst_v_pool_ref[layer_idx_to_test], + src_indices_host, + dst_indices_device, + ) + transfer_kv_per_layer( + src_k_pool[layer_idx_to_test], + dst_k_pool_kernel[layer_idx_to_test], + src_v_pool[layer_idx_to_test], + dst_v_pool_kernel[layer_idx_to_test], + src_indices_device, + dst_indices_device, + io_backend="kernel", + page_size=page_size, + item_size=item_size, + ) + transfer_kv_per_layer( + src_k_pool[layer_idx_to_test], + dst_k_pool_direct[layer_idx_to_test], + src_v_pool[layer_idx_to_test], + dst_v_pool_direct[layer_idx_to_test], + src_indices_host, + dst_indices_device, + io_backend="direct", + page_size=page_size, + item_size=item_size, + ) + else: + for layer_id in range(num_layers): + ref_copy_with_indices( + src_k_pool[layer_id], + dst_k_pool_ref[layer_id], + src_indices_host, + dst_indices_device, + ) + ref_copy_with_indices( + src_v_pool[layer_id], + dst_v_pool_ref[layer_id], + src_indices_host, + dst_indices_device, + ) + transfer_kv_all_layer( + src_k_pool, + dst_k_pool_kernel, + src_v_pool, + dst_v_pool_kernel, + src_indices_device, + dst_indices_device, + io_backend="kernel", + page_size=page_size, + item_size=item_size, + num_layers=num_layers, + src_layer_offset=total_items_in_pool * item_size, + dst_layer_offset=total_items_in_pool * item_size, + ) + transfer_kv_all_layer( + src_k_pool, + dst_k_pool_direct, + src_v_pool, + dst_v_pool_direct, + src_indices_host, + dst_indices_device, + io_backend="direct", + page_size=page_size, + item_size=item_size, + num_layers=num_layers, + src_layer_offset=total_items_in_pool * item_size, + dst_layer_offset=total_items_in_pool * item_size, + ) + torch.cuda.synchronize() + torch.testing.assert_close(dst_k_pool_kernel, dst_k_pool_ref) + 
torch.testing.assert_close(dst_v_pool_kernel, dst_v_pool_ref) + torch.testing.assert_close(dst_k_pool_direct, dst_k_pool_ref) + torch.testing.assert_close(dst_v_pool_direct, dst_v_pool_ref) + + torch.set_default_dtype(original_dtype) + + +if __name__ == "__main__": + pytest.main([__file__])
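A minimal usage sketch of the new transfer_kv_per_layer wrapper with the fused "kernel" backend, mirroring the test setup above; the pool shapes, dtypes, and index choices below are illustrative assumptions rather than values taken from this diff.

    import torch
    from sgl_kernel.kvcacheio import transfer_kv_per_layer

    num_items, item_size = 1024, 256
    # Source pools live in pinned host memory; destination pools live on the GPU.
    src_k = torch.randn(num_items, item_size, dtype=torch.bfloat16).pin_memory()
    src_v = torch.randn(num_items, item_size, dtype=torch.bfloat16).pin_memory()
    dst_k = torch.zeros(num_items, item_size, dtype=torch.bfloat16, device="cuda")
    dst_v = torch.zeros_like(dst_k)

    # Move items 0..63 of the source pool into slots 64..127 of the destination pool.
    src_indices = torch.arange(0, 64, dtype=torch.int64, device="cuda")
    dst_indices = torch.arange(64, 128, dtype=torch.int64, device="cuda")

    transfer_kv_per_layer(
        src_k, dst_k, src_v, dst_v,
        src_indices, dst_indices,
        io_backend="kernel",  # "direct" would instead issue paged torch copies
        page_size=1,
        item_size=item_size,
    )
    torch.cuda.synchronize()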