kvcache io kernels and test case (#7382)
This commit is contained in:
@@ -250,6 +250,7 @@ set(SOURCES
|
||||
"csrc/speculative/packbit.cu"
|
||||
"csrc/speculative/speculative_sampling.cu"
|
||||
"csrc/grammar/apply_token_bitmask_inplace_cuda.cu"
|
||||
"csrc/kvcacheio/transfer.cu"
|
||||
"csrc/common_extension.cc"
|
||||
"${repo-flashinfer_SOURCE_DIR}/csrc/norm.cu"
|
||||
"${repo-flashinfer_SOURCE_DIR}/csrc/renorm.cu"
|
||||
|
||||
@@ -230,6 +230,43 @@ TORCH_LIBRARY_FRAGMENT(sgl_kernel, m) {
|
||||
"int cuda_stream) -> ()");
|
||||
m.impl("segment_packbits", torch::kCUDA, &segment_packbits);
|
||||
|
||||
/*
|
||||
* From csrc/kvcacheio
|
||||
*/
|
||||
m.def(
|
||||
"transfer_kv_per_layer(Tensor src_k, Tensor dst_k, Tensor src_v, Tensor dst_v, Tensor src_indices, Tensor "
|
||||
"dst_indices, int item_size, int block_quota, int num_warps_per_block) -> ()");
|
||||
m.impl("transfer_kv_per_layer", torch::kCUDA, &transfer_kv_per_layer);
|
||||
m.def(
|
||||
"transfer_kv_per_layer_direct(Tensor src_k, Tensor dst_k, Tensor src_v, Tensor dst_v, Tensor src_indices, Tensor "
|
||||
"dst_indices, int page_size) -> ()");
|
||||
m.impl("transfer_kv_per_layer_direct", torch::kCUDA, &transfer_kv_per_layer_direct);
|
||||
m.def(
|
||||
"transfer_kv_all_layer(Tensor src_k, Tensor dst_k, Tensor src_v, Tensor dst_v, Tensor src_indices, Tensor "
|
||||
"dst_indices, int item_size, int num_layers, int src_layer_offset, int dst_layer_offset, int block_quota, int "
|
||||
"num_warps_per_block) -> ()");
|
||||
m.impl("transfer_kv_all_layer", torch::kCUDA, &transfer_kv_all_layer);
|
||||
m.def(
|
||||
"transfer_kv_all_layer_direct(Tensor src_k, Tensor dst_k, Tensor src_v, Tensor dst_v, Tensor src_indices, Tensor "
|
||||
"dst_indices, int page_size, int num_layers) -> ()");
|
||||
m.impl("transfer_kv_all_layer_direct", torch::kCUDA, &transfer_kv_all_layer_direct);
|
||||
m.def(
|
||||
"transfer_kv_per_layer_mla(Tensor src, Tensor dst, Tensor src_indices, Tensor dst_indices, int item_size, int "
|
||||
"block_quota, int num_warps_per_block) -> ()");
|
||||
m.impl("transfer_kv_per_layer_mla", torch::kCUDA, &transfer_kv_per_layer_mla);
|
||||
m.def(
|
||||
"transfer_kv_per_layer_mla_direct(Tensor src, Tensor dst, Tensor src_indices, Tensor dst_indices, int page_size) "
|
||||
"-> ()");
|
||||
m.impl("transfer_kv_per_layer_mla_direct", torch::kCUDA, &transfer_kv_per_layer_mla_direct);
|
||||
m.def(
|
||||
"transfer_kv_all_layer_mla(Tensor src, Tensor dst, Tensor src_indices, Tensor dst_indices, int item_size, int "
|
||||
"num_layers, int src_layer_offset, int dst_layer_offset, int block_quota, int num_warps_per_block) -> ()");
|
||||
m.impl("transfer_kv_all_layer_mla", torch::kCUDA, &transfer_kv_all_layer_mla);
|
||||
m.def(
|
||||
"transfer_kv_all_layer_mla_direct(Tensor src, Tensor dst, Tensor src_indices, Tensor dst_indices, int page_size, "
|
||||
"int num_layers) -> ()");
|
||||
m.impl("transfer_kv_all_layer_mla_direct", torch::kCUDA, &transfer_kv_all_layer_mla_direct);
|
||||
|
||||
/*
|
||||
* From FlashInfer
|
||||
*/
|
||||
|
||||
342
sgl-kernel/csrc/kvcacheio/transfer.cu
Normal file
342
sgl-kernel/csrc/kvcacheio/transfer.cu
Normal file
@@ -0,0 +1,342 @@
|
||||
#include <ATen/cuda/CUDAContext.h>
|
||||
#include <c10/cuda/CUDAException.h>
|
||||
#include <c10/util/irange.h>
|
||||
|
||||
#include <cstdint>
|
||||
|
||||
#include "pytorch_extension_utils.h"
|
||||
|
||||
__device__ __forceinline__ void
|
||||
transfer_item_warp(int32_t lane_id, const void* src_addr, void* dst_addr, int64_t item_size_bytes) {
|
||||
// todo, different chunk size
|
||||
int total_chunks = item_size_bytes / 8;
|
||||
const int64_t* src_8 = reinterpret_cast<const int64_t*>(src_addr);
|
||||
int64_t* dst_8 = reinterpret_cast<int64_t*>(dst_addr);
|
||||
#pragma unroll
|
||||
for (int j = lane_id; j < total_chunks; j += 32) {
|
||||
const int64_t* src_addr_lane = &src_8[j];
|
||||
int64_t* dst_addr_lane = &dst_8[j];
|
||||
int64_t temp_val;
|
||||
asm volatile("ld.global.nc.b64 %0, [%1];" : "=l"(temp_val) : "l"(src_addr_lane) : "memory");
|
||||
asm volatile("st.global.cg.b64 [%0], %1;" ::"l"(dst_addr_lane), "l"(temp_val) : "memory");
|
||||
}
|
||||
}
|
||||
|
||||
// todo, structs for different memory layout
|
||||
__device__ __forceinline__ int64_t
|
||||
get_global_offset_lf(int64_t layer_id, int64_t layer_dim, int64_t page_id, int64_t item_size_bytes) {
|
||||
// layer first
|
||||
return layer_id * layer_dim + page_id * item_size_bytes;
|
||||
}
|
||||
|
||||
__device__ __forceinline__ int64_t
|
||||
get_global_offset_pf(int64_t layer_id, int64_t page_dim, int64_t page_id, int64_t item_size_bytes) {
|
||||
// page first
|
||||
return page_id * page_dim + layer_id * item_size_bytes;
|
||||
}
|
||||
|
||||
template <auto SrcOffsetFn, auto DstOffsetFn, bool IsMLA>
|
||||
__global__ void transfer_kernel_impl(
|
||||
const void* __restrict__ src_k,
|
||||
void* __restrict__ dst_k,
|
||||
const void* __restrict__ src_v,
|
||||
void* __restrict__ dst_v,
|
||||
const int64_t* __restrict__ src_indices,
|
||||
const int64_t* __restrict__ dst_indices,
|
||||
int64_t start_layer_id,
|
||||
int64_t num_layers_to_process,
|
||||
int64_t num_items,
|
||||
int64_t items_per_warp,
|
||||
int64_t item_size_bytes,
|
||||
int64_t src_layout_dim,
|
||||
int64_t dst_layout_dim) {
|
||||
int32_t tid = blockIdx.x * blockDim.x + threadIdx.x;
|
||||
int32_t lane_id = tid % 32;
|
||||
int32_t warp_id = tid / 32;
|
||||
|
||||
for (int i = 0; i < items_per_warp; ++i) {
|
||||
int32_t item_id = warp_id * items_per_warp + i;
|
||||
if (item_id >= num_items) {
|
||||
return;
|
||||
}
|
||||
const int64_t src_page_id = src_indices[item_id];
|
||||
const int64_t dst_page_id = dst_indices[item_id];
|
||||
|
||||
// Loop over layers if necessary
|
||||
for (int64_t layer_id = start_layer_id; layer_id < start_layer_id + num_layers_to_process; ++layer_id) {
|
||||
// Calculate offsets using the provided function pointers
|
||||
const int64_t src_offset = SrcOffsetFn(layer_id, src_layout_dim, src_page_id, item_size_bytes);
|
||||
const int64_t dst_offset = DstOffsetFn(layer_id, dst_layout_dim, dst_page_id, item_size_bytes);
|
||||
|
||||
if constexpr (IsMLA) {
|
||||
transfer_item_warp(
|
||||
lane_id,
|
||||
static_cast<const char*>(src_k) + src_offset,
|
||||
static_cast<char*>(dst_k) + dst_offset,
|
||||
item_size_bytes);
|
||||
} else {
|
||||
transfer_item_warp(
|
||||
lane_id,
|
||||
static_cast<const char*>(src_k) + src_offset,
|
||||
static_cast<char*>(dst_k) + dst_offset,
|
||||
item_size_bytes);
|
||||
transfer_item_warp(
|
||||
lane_id,
|
||||
static_cast<const char*>(src_v) + src_offset,
|
||||
static_cast<char*>(dst_v) + dst_offset,
|
||||
item_size_bytes);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
template <auto SrcOffsetFn, auto DstOffsetFn, bool IsMLA>
|
||||
void transfer_kv_launcher(
|
||||
const at::Tensor& src_k,
|
||||
at::Tensor& dst_k,
|
||||
const at::Tensor& src_v,
|
||||
at::Tensor& dst_v,
|
||||
const at::Tensor& src_indices,
|
||||
const at::Tensor& dst_indices,
|
||||
int64_t start_layer_id,
|
||||
int64_t num_layers_to_process,
|
||||
int64_t item_size,
|
||||
int64_t src_layout_dim,
|
||||
int64_t dst_layout_dim,
|
||||
int64_t block_quota,
|
||||
int64_t num_warps_per_block) {
|
||||
TORCH_CHECK(src_k.scalar_type() == dst_k.scalar_type(), "Source and destination keys must have the same type");
|
||||
TORCH_CHECK(src_indices.is_cuda(), "Source indices must be a CUDA tensor");
|
||||
TORCH_CHECK(dst_indices.is_cuda(), "Destination indices must be a CUDA tensor");
|
||||
TORCH_CHECK(src_indices.scalar_type() == at::kLong, "Source indices must be of type long");
|
||||
TORCH_CHECK(dst_indices.scalar_type() == at::kLong, "Destination indices must be of type long");
|
||||
TORCH_CHECK(src_indices.numel() == dst_indices.numel(), "Source and destination indices must have the same length");
|
||||
|
||||
if (!IsMLA) {
|
||||
TORCH_CHECK(src_v.scalar_type() == dst_v.scalar_type(), "Source and destination values must have the same type");
|
||||
}
|
||||
|
||||
int dtype_size = src_k.element_size();
|
||||
TORCH_CHECK((item_size * dtype_size) % 8 == 0, "Item byte size must be divisible by 8");
|
||||
|
||||
auto div_up = [](int32_t x, int32_t y) { return (x + y - 1) / y; };
|
||||
const int64_t num_items = src_indices.numel();
|
||||
const int64_t items_per_warp = div_up(num_items, block_quota * num_warps_per_block);
|
||||
const int32_t num_blocks = div_up(num_items, items_per_warp * num_warps_per_block);
|
||||
dim3 grid_dim(num_blocks, 1, 1);
|
||||
const int32_t threads_per_block = num_warps_per_block * 32;
|
||||
|
||||
cudaStream_t torch_current_stream = at::cuda::getCurrentCUDAStream();
|
||||
transfer_kernel_impl<SrcOffsetFn, DstOffsetFn, IsMLA><<<grid_dim, threads_per_block, 0, torch_current_stream>>>(
|
||||
src_k.data_ptr(),
|
||||
dst_k.data_ptr(),
|
||||
(IsMLA ? nullptr : src_v.data_ptr()),
|
||||
(IsMLA ? nullptr : dst_v.data_ptr()),
|
||||
src_indices.data_ptr<int64_t>(),
|
||||
dst_indices.data_ptr<int64_t>(),
|
||||
start_layer_id,
|
||||
num_layers_to_process,
|
||||
num_items,
|
||||
items_per_warp,
|
||||
item_size * dtype_size,
|
||||
src_layout_dim * dtype_size,
|
||||
dst_layout_dim * dtype_size);
|
||||
C10_CUDA_KERNEL_LAUNCH_CHECK();
|
||||
}
|
||||
|
||||
void transfer_kv_per_layer(
|
||||
const at::Tensor src_k,
|
||||
at::Tensor dst_k,
|
||||
const at::Tensor src_v,
|
||||
at::Tensor dst_v,
|
||||
const at::Tensor src_indices,
|
||||
const at::Tensor dst_indices,
|
||||
int64_t item_size,
|
||||
int64_t block_quota,
|
||||
int64_t num_warps_per_block) {
|
||||
transfer_kv_launcher<get_global_offset_lf, get_global_offset_lf, false>(
|
||||
src_k, dst_k, src_v, dst_v, src_indices, dst_indices, 0, 1, item_size, 0, 0, block_quota, num_warps_per_block);
|
||||
}
|
||||
|
||||
void transfer_kv_all_layer(
|
||||
const at::Tensor src_k,
|
||||
at::Tensor dst_k,
|
||||
const at::Tensor src_v,
|
||||
at::Tensor dst_v,
|
||||
const at::Tensor src_indices,
|
||||
const at::Tensor dst_indices,
|
||||
int64_t item_size,
|
||||
int64_t num_layers,
|
||||
int64_t src_layer_offset,
|
||||
int64_t dst_layer_offset,
|
||||
int64_t block_quota,
|
||||
int64_t num_warps_per_block) {
|
||||
transfer_kv_launcher<get_global_offset_lf, get_global_offset_lf, false>(
|
||||
src_k,
|
||||
dst_k,
|
||||
src_v,
|
||||
dst_v,
|
||||
src_indices,
|
||||
dst_indices,
|
||||
0,
|
||||
num_layers,
|
||||
item_size,
|
||||
src_layer_offset,
|
||||
dst_layer_offset,
|
||||
block_quota,
|
||||
num_warps_per_block);
|
||||
}
|
||||
|
||||
void transfer_kv_per_layer_mla(
|
||||
const at::Tensor src,
|
||||
at::Tensor dst,
|
||||
const at::Tensor src_indices,
|
||||
const at::Tensor dst_indices,
|
||||
int64_t item_size,
|
||||
int64_t block_quota,
|
||||
int64_t num_warps_per_block) {
|
||||
at::Tensor empty_tensor = at::Tensor();
|
||||
transfer_kv_launcher<get_global_offset_lf, get_global_offset_lf, true>(
|
||||
src,
|
||||
dst,
|
||||
empty_tensor,
|
||||
empty_tensor,
|
||||
src_indices,
|
||||
dst_indices,
|
||||
0,
|
||||
1,
|
||||
item_size,
|
||||
0,
|
||||
0,
|
||||
block_quota,
|
||||
num_warps_per_block);
|
||||
}
|
||||
|
||||
void transfer_kv_all_layer_mla(
|
||||
const at::Tensor src,
|
||||
at::Tensor dst,
|
||||
const at::Tensor src_indices,
|
||||
const at::Tensor dst_indices,
|
||||
int64_t item_size,
|
||||
int64_t num_layers,
|
||||
int64_t src_layer_offset,
|
||||
int64_t dst_layer_offset,
|
||||
int64_t block_quota,
|
||||
int64_t num_warps_per_block) {
|
||||
at::Tensor empty_tensor = at::Tensor();
|
||||
transfer_kv_launcher<get_global_offset_lf, get_global_offset_lf, true>(
|
||||
src,
|
||||
dst,
|
||||
empty_tensor,
|
||||
empty_tensor,
|
||||
src_indices,
|
||||
dst_indices,
|
||||
0,
|
||||
num_layers,
|
||||
item_size,
|
||||
src_layer_offset,
|
||||
dst_layer_offset,
|
||||
block_quota,
|
||||
num_warps_per_block);
|
||||
}
|
||||
|
||||
inline void transfer_page_direct(
|
||||
const at::Tensor src_buffer,
|
||||
at::Tensor dst_buffer,
|
||||
int64_t src_page_index,
|
||||
int64_t dst_page_index,
|
||||
int64_t page_size) {
|
||||
dst_buffer.slice(0, dst_page_index, dst_page_index + page_size)
|
||||
.copy_(
|
||||
src_buffer.slice(0, src_page_index, src_page_index + page_size),
|
||||
/* non_blocking= */ true);
|
||||
}
|
||||
|
||||
template <bool IsMLA, bool AllLayers>
|
||||
inline void transfer_kv_direct_impl(
|
||||
const at::Tensor& src_k,
|
||||
at::Tensor& dst_k,
|
||||
const at::Tensor& src_v_opt, // Only used when IsMLA is false (for src_v)
|
||||
at::Tensor& dst_v_opt, // Only used when IsMLA is false (for dst_v)
|
||||
const at::Tensor& src_indices,
|
||||
const at::Tensor& dst_indices,
|
||||
int64_t page_size,
|
||||
int64_t num_layers = 1) {
|
||||
TORCH_CHECK(src_indices.numel() == dst_indices.numel(), "Source and destination indices must have the same length");
|
||||
TORCH_CHECK(page_size > 0, "Page size must be positive");
|
||||
TORCH_CHECK(src_indices.numel() % page_size == 0, "Source indices size must be divisible by page size");
|
||||
|
||||
auto src_indices_cpu = src_indices.cpu();
|
||||
auto dst_indices_cpu = dst_indices.cpu();
|
||||
|
||||
const int64_t num_pages = src_indices_cpu.size(0) / page_size;
|
||||
|
||||
for (const auto i : c10::irange(num_pages)) {
|
||||
auto s_index = src_indices_cpu[i * page_size].item<int64_t>();
|
||||
auto d_index = dst_indices_cpu[i * page_size].item<int64_t>();
|
||||
|
||||
if constexpr (AllLayers) {
|
||||
for (const auto j : c10::irange(num_layers)) {
|
||||
if constexpr (IsMLA) {
|
||||
transfer_page_direct(src_k.select(0, j), dst_k.select(0, j), s_index, d_index, page_size);
|
||||
} else {
|
||||
transfer_page_direct(src_k.select(0, j), dst_k.select(0, j), s_index, d_index, page_size);
|
||||
transfer_page_direct(src_v_opt.select(0, j), dst_v_opt.select(0, j), s_index, d_index, page_size);
|
||||
}
|
||||
}
|
||||
} else { // Per-layer
|
||||
if constexpr (IsMLA) {
|
||||
transfer_page_direct(src_k, dst_k, s_index, d_index, page_size);
|
||||
} else {
|
||||
transfer_page_direct(src_k, dst_k, s_index, d_index, page_size);
|
||||
transfer_page_direct(src_v_opt, dst_v_opt, s_index, d_index, page_size);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
void transfer_kv_per_layer_direct(
|
||||
const at::Tensor src_k,
|
||||
at::Tensor dst_k,
|
||||
const at::Tensor src_v,
|
||||
at::Tensor dst_v,
|
||||
const at::Tensor src_indices,
|
||||
const at::Tensor dst_indices,
|
||||
int64_t page_size) {
|
||||
transfer_kv_direct_impl<false, false>(src_k, dst_k, src_v, dst_v, src_indices, dst_indices, page_size);
|
||||
}
|
||||
|
||||
void transfer_kv_all_layer_direct(
|
||||
const at::Tensor src_k,
|
||||
at::Tensor dst_k,
|
||||
const at::Tensor src_v,
|
||||
at::Tensor dst_v,
|
||||
const at::Tensor src_indices,
|
||||
const at::Tensor dst_indices,
|
||||
int64_t page_size,
|
||||
int64_t num_layers) {
|
||||
transfer_kv_direct_impl<false, true>(src_k, dst_k, src_v, dst_v, src_indices, dst_indices, page_size, num_layers);
|
||||
}
|
||||
|
||||
void transfer_kv_per_layer_mla_direct(
|
||||
const at::Tensor src,
|
||||
at::Tensor dst,
|
||||
const at::Tensor src_indices,
|
||||
const at::Tensor dst_indices,
|
||||
int64_t page_size) {
|
||||
at::Tensor empty_tensor = at::Tensor();
|
||||
|
||||
transfer_kv_direct_impl<true, false>(src, dst, empty_tensor, empty_tensor, src_indices, dst_indices, page_size);
|
||||
}
|
||||
|
||||
void transfer_kv_all_layer_mla_direct(
|
||||
const at::Tensor src,
|
||||
at::Tensor dst,
|
||||
const at::Tensor src_indices,
|
||||
const at::Tensor dst_indices,
|
||||
int64_t page_size,
|
||||
int64_t num_layers) {
|
||||
at::Tensor empty_tensor = at::Tensor();
|
||||
transfer_kv_direct_impl<true, true>(
|
||||
src, dst, empty_tensor, empty_tensor, src_indices, dst_indices, page_size, num_layers);
|
||||
}
|
||||
@@ -371,6 +371,89 @@ void segment_packbits(
|
||||
int64_t batch_size,
|
||||
int64_t cuda_stream = 0);
|
||||
|
||||
/*
|
||||
* From csrc/kvcacheio
|
||||
*/
|
||||
void transfer_kv_per_layer(
|
||||
const at::Tensor src_k,
|
||||
at::Tensor dst_k,
|
||||
const at::Tensor src_v,
|
||||
at::Tensor dst_v,
|
||||
const at::Tensor src_indices,
|
||||
const at::Tensor dst_indices,
|
||||
int64_t item_size,
|
||||
int64_t block_quota,
|
||||
int64_t num_warps_per_block);
|
||||
|
||||
void transfer_kv_per_layer_direct(
|
||||
const at::Tensor src_k,
|
||||
at::Tensor dst_k,
|
||||
const at::Tensor src_v,
|
||||
at::Tensor dst_v,
|
||||
const at::Tensor src_indices,
|
||||
const at::Tensor dst_indices,
|
||||
int64_t page_size);
|
||||
|
||||
void transfer_kv_all_layer(
|
||||
const at::Tensor src_k,
|
||||
at::Tensor dst_k,
|
||||
const at::Tensor src_v,
|
||||
at::Tensor dst_v,
|
||||
const at::Tensor src_indices,
|
||||
const at::Tensor dst_indices,
|
||||
int64_t item_size,
|
||||
int64_t num_layers,
|
||||
int64_t src_layer_offset,
|
||||
int64_t dst_layer_offset,
|
||||
int64_t block_quota,
|
||||
int64_t num_warps_per_block);
|
||||
|
||||
void transfer_kv_all_layer_direct(
|
||||
const at::Tensor src_k,
|
||||
at::Tensor dst_k,
|
||||
const at::Tensor src_v,
|
||||
at::Tensor dst_v,
|
||||
const at::Tensor src_indices,
|
||||
const at::Tensor dst_indices,
|
||||
int64_t page_size,
|
||||
int64_t num_layers);
|
||||
|
||||
void transfer_kv_per_layer_mla(
|
||||
const at::Tensor src,
|
||||
at::Tensor dst,
|
||||
const at::Tensor src_indices,
|
||||
const at::Tensor dst_indices,
|
||||
int64_t item_size,
|
||||
int64_t block_quota,
|
||||
int64_t num_warps_per_block);
|
||||
|
||||
void transfer_kv_per_layer_mla_direct(
|
||||
const at::Tensor src,
|
||||
at::Tensor dst,
|
||||
const at::Tensor src_indices,
|
||||
const at::Tensor dst_indices,
|
||||
int64_t page_size);
|
||||
|
||||
void transfer_kv_all_layer_mla(
|
||||
const at::Tensor src,
|
||||
at::Tensor dst,
|
||||
const at::Tensor src_indices,
|
||||
const at::Tensor dst_indices,
|
||||
int64_t item_size,
|
||||
int64_t num_layers,
|
||||
int64_t src_layer_offset,
|
||||
int64_t dst_layer_offset,
|
||||
int64_t block_quota,
|
||||
int64_t num_warps_per_block);
|
||||
|
||||
void transfer_kv_all_layer_mla_direct(
|
||||
const at::Tensor src,
|
||||
at::Tensor dst,
|
||||
const at::Tensor src_indices,
|
||||
const at::Tensor dst_indices,
|
||||
int64_t page_size,
|
||||
int64_t num_layers);
|
||||
|
||||
/*
|
||||
* From FlashInfer
|
||||
*/
|
||||
|
||||
@@ -47,6 +47,12 @@ from sgl_kernel.gemm import (
|
||||
shuffle_rows,
|
||||
)
|
||||
from sgl_kernel.grammar import apply_token_bitmask_inplace_cuda
|
||||
from sgl_kernel.kvcacheio import (
|
||||
transfer_kv_all_layer,
|
||||
transfer_kv_all_layer_mla,
|
||||
transfer_kv_per_layer,
|
||||
transfer_kv_per_layer_mla,
|
||||
)
|
||||
from sgl_kernel.moe import (
|
||||
apply_shuffle_mul_sum,
|
||||
cutlass_fp4_group_mm,
|
||||
|
||||
137
sgl-kernel/python/sgl_kernel/kvcacheio.py
Normal file
137
sgl-kernel/python/sgl_kernel/kvcacheio.py
Normal file
@@ -0,0 +1,137 @@
|
||||
import torch
|
||||
|
||||
|
||||
def transfer_kv_per_layer(
|
||||
src_k: torch.Tensor,
|
||||
dst_k: torch.Tensor,
|
||||
src_v: torch.Tensor,
|
||||
dst_v: torch.Tensor,
|
||||
src_indices: torch.Tensor,
|
||||
dst_indices: torch.Tensor,
|
||||
io_backend: str,
|
||||
page_size: int,
|
||||
item_size: int,
|
||||
block_quota: int = 2,
|
||||
num_warps_per_block: int = 32,
|
||||
):
|
||||
if io_backend == "kernel":
|
||||
torch.ops.sgl_kernel.transfer_kv_per_layer(
|
||||
src_k,
|
||||
dst_k,
|
||||
src_v,
|
||||
dst_v,
|
||||
src_indices,
|
||||
dst_indices,
|
||||
item_size,
|
||||
block_quota,
|
||||
num_warps_per_block,
|
||||
)
|
||||
elif io_backend == "direct":
|
||||
torch.ops.sgl_kernel.transfer_kv_per_layer_direct(
|
||||
src_k, dst_k, src_v, dst_v, src_indices, dst_indices, page_size
|
||||
)
|
||||
else:
|
||||
raise ValueError(f"Unsupported io backend")
|
||||
|
||||
|
||||
def transfer_kv_all_layer(
|
||||
src_k: torch.Tensor,
|
||||
dst_k: torch.Tensor,
|
||||
src_v: torch.Tensor,
|
||||
dst_v: torch.Tensor,
|
||||
src_indices: torch.Tensor,
|
||||
dst_indices: torch.Tensor,
|
||||
io_backend: str,
|
||||
page_size: int,
|
||||
item_size: int,
|
||||
num_layers: int,
|
||||
src_layer_offset: int,
|
||||
dst_layer_offset: int,
|
||||
block_quota: int = 2,
|
||||
num_warps_per_block: int = 32,
|
||||
):
|
||||
if io_backend == "kernel":
|
||||
torch.ops.sgl_kernel.transfer_kv_all_layer(
|
||||
src_k,
|
||||
dst_k,
|
||||
src_v,
|
||||
dst_v,
|
||||
src_indices,
|
||||
dst_indices,
|
||||
item_size,
|
||||
num_layers,
|
||||
src_layer_offset,
|
||||
dst_layer_offset,
|
||||
block_quota,
|
||||
num_warps_per_block,
|
||||
)
|
||||
elif io_backend == "direct":
|
||||
torch.ops.sgl_kernel.transfer_kv_all_layer_direct(
|
||||
src_k, dst_k, src_v, dst_v, src_indices, dst_indices, page_size, num_layers
|
||||
)
|
||||
else:
|
||||
raise ValueError(f"Unsupported io backend")
|
||||
|
||||
|
||||
def transfer_kv_per_layer_mla(
|
||||
src: torch.Tensor,
|
||||
dst: torch.Tensor,
|
||||
src_indices: torch.Tensor,
|
||||
dst_indices: torch.Tensor,
|
||||
io_backend: str,
|
||||
page_size: int,
|
||||
item_size: int,
|
||||
block_quota: int = 2,
|
||||
num_warps_per_block: int = 32,
|
||||
):
|
||||
if io_backend == "kernel":
|
||||
torch.ops.sgl_kernel.transfer_kv_per_layer_mla(
|
||||
src,
|
||||
dst,
|
||||
src_indices,
|
||||
dst_indices,
|
||||
item_size,
|
||||
block_quota,
|
||||
num_warps_per_block,
|
||||
)
|
||||
elif io_backend == "direct":
|
||||
torch.ops.sgl_kernel.transfer_kv_per_layer_mla_direct(
|
||||
src, dst, src_indices, dst_indices, page_size
|
||||
)
|
||||
else:
|
||||
raise ValueError(f"Unsupported io backend")
|
||||
|
||||
|
||||
def transfer_kv_all_layer_mla(
|
||||
src: torch.Tensor,
|
||||
dst: torch.Tensor,
|
||||
src_indices: torch.Tensor,
|
||||
dst_indices: torch.Tensor,
|
||||
io_backend: str,
|
||||
page_size: int,
|
||||
item_size: int,
|
||||
num_layers: int,
|
||||
src_layer_offset: int,
|
||||
dst_layer_offset: int,
|
||||
block_quota: int = 2,
|
||||
num_warps_per_block: int = 32,
|
||||
):
|
||||
if io_backend == "kernel":
|
||||
torch.ops.sgl_kernel.transfer_kv_all_layer_mla(
|
||||
src,
|
||||
dst,
|
||||
src_indices,
|
||||
dst_indices,
|
||||
item_size,
|
||||
num_layers,
|
||||
src_layer_offset,
|
||||
dst_layer_offset,
|
||||
block_quota,
|
||||
num_warps_per_block,
|
||||
)
|
||||
elif io_backend == "direct":
|
||||
torch.ops.sgl_kernel.transfer_kv_all_layer_mla_direct(
|
||||
src, dst, src_indices, dst_indices, page_size, num_layers
|
||||
)
|
||||
else:
|
||||
raise ValueError(f"Unsupported io backend")
|
||||
239
sgl-kernel/tests/test_kvcacheio.py
Normal file
239
sgl-kernel/tests/test_kvcacheio.py
Normal file
@@ -0,0 +1,239 @@
|
||||
import pytest
|
||||
import torch
|
||||
from sgl_kernel.kvcacheio import (
|
||||
transfer_kv_all_layer,
|
||||
transfer_kv_all_layer_mla,
|
||||
transfer_kv_per_layer,
|
||||
transfer_kv_per_layer_mla,
|
||||
)
|
||||
|
||||
|
||||
def ref_copy_with_indices(src_pool, dst_pool, src_indices, dst_indices):
|
||||
dst_pool[dst_indices] = src_pool[src_indices].to(dst_pool.device)
|
||||
|
||||
|
||||
@pytest.mark.parametrize("dtype", [torch.bfloat16, torch.float16])
|
||||
@pytest.mark.parametrize("num_items_to_transfer", [1, 128, 1024])
|
||||
@pytest.mark.parametrize("page_size", [1, 16, 64])
|
||||
@pytest.mark.parametrize("item_size", [256])
|
||||
@pytest.mark.parametrize("total_items_in_pool", [10240])
|
||||
@pytest.mark.parametrize("is_mla", [False, True])
|
||||
@pytest.mark.parametrize("all_layers", [False, True])
|
||||
def test_transfer_kv(
|
||||
dtype: torch.dtype,
|
||||
num_items_to_transfer: int,
|
||||
item_size: int,
|
||||
page_size: int,
|
||||
total_items_in_pool: int,
|
||||
is_mla: bool,
|
||||
all_layers: bool,
|
||||
):
|
||||
"""
|
||||
Tests the per-layer transfer functions, treating tensors as memory pools.
|
||||
"""
|
||||
|
||||
original_dtype = torch.get_default_dtype()
|
||||
torch.set_default_dtype(dtype)
|
||||
device = "cuda"
|
||||
torch.cuda.manual_seed(42)
|
||||
|
||||
num_layers = 4 # A small number of layers for pool creation
|
||||
|
||||
total_pages_in_pool = total_items_in_pool // page_size
|
||||
num_pages_to_transfer = num_items_to_transfer // page_size
|
||||
if num_pages_to_transfer == 0:
|
||||
torch.set_default_dtype(original_dtype)
|
||||
return
|
||||
page_indices = torch.randperm(total_pages_in_pool, dtype=torch.int64)
|
||||
src_indices_host = torch.cat(
|
||||
[
|
||||
torch.arange(p * page_size, (p + 1) * page_size)
|
||||
for p in page_indices[:num_pages_to_transfer]
|
||||
]
|
||||
)
|
||||
src_indices_device = src_indices_host.to(device)
|
||||
dst_indices_host = torch.cat(
|
||||
[
|
||||
torch.arange(p * page_size, (p + 1) * page_size)
|
||||
for p in page_indices[num_pages_to_transfer : 2 * num_pages_to_transfer]
|
||||
]
|
||||
)
|
||||
dst_indices_device = dst_indices_host.to(device)
|
||||
|
||||
# Prepare memory pools based on whether it's an MLA case.
|
||||
if is_mla:
|
||||
src_pool_host = torch.randn(
|
||||
num_layers, total_items_in_pool, item_size
|
||||
).pin_memory()
|
||||
dst_pool_ref = torch.zeros_like(src_pool_host).to(device)
|
||||
dst_pool_kernel = torch.zeros_like(dst_pool_ref)
|
||||
dst_pool_direct = torch.zeros_like(dst_pool_ref)
|
||||
else:
|
||||
src_k_pool = torch.randn(
|
||||
num_layers, total_items_in_pool, item_size
|
||||
).pin_memory()
|
||||
src_v_pool = torch.randn(
|
||||
num_layers, total_items_in_pool, item_size
|
||||
).pin_memory()
|
||||
dst_k_pool_ref = torch.zeros_like(src_k_pool).to(device)
|
||||
dst_v_pool_ref = torch.zeros_like(src_v_pool).to(device)
|
||||
dst_k_pool_kernel = torch.zeros_like(dst_k_pool_ref)
|
||||
dst_v_pool_kernel = torch.zeros_like(dst_v_pool_ref)
|
||||
dst_k_pool_direct = torch.zeros_like(dst_k_pool_ref)
|
||||
dst_v_pool_direct = torch.zeros_like(dst_v_pool_ref)
|
||||
|
||||
torch.cuda.synchronize()
|
||||
|
||||
# We will test the per-layer function on the first layer (index 0) of the pool.
|
||||
layer_idx_to_test = 0
|
||||
|
||||
if is_mla:
|
||||
if not all_layers:
|
||||
ref_copy_with_indices(
|
||||
src_pool_host[layer_idx_to_test],
|
||||
dst_pool_ref[layer_idx_to_test],
|
||||
src_indices_host,
|
||||
dst_indices_device,
|
||||
)
|
||||
transfer_kv_per_layer_mla(
|
||||
src_pool_host[layer_idx_to_test],
|
||||
dst_pool_kernel[layer_idx_to_test],
|
||||
src_indices_device,
|
||||
dst_indices_device,
|
||||
io_backend="kernel",
|
||||
page_size=page_size,
|
||||
item_size=item_size,
|
||||
)
|
||||
transfer_kv_per_layer_mla(
|
||||
src_pool_host[layer_idx_to_test],
|
||||
dst_pool_direct[layer_idx_to_test],
|
||||
src_indices_host,
|
||||
dst_indices_device,
|
||||
io_backend="direct",
|
||||
page_size=page_size,
|
||||
item_size=item_size,
|
||||
)
|
||||
else:
|
||||
for layer_id in range(num_layers):
|
||||
ref_copy_with_indices(
|
||||
src_pool_host[layer_id],
|
||||
dst_pool_ref[layer_id],
|
||||
src_indices_host,
|
||||
dst_indices_device,
|
||||
)
|
||||
transfer_kv_all_layer_mla(
|
||||
src_pool_host,
|
||||
dst_pool_kernel,
|
||||
src_indices_device,
|
||||
dst_indices_device,
|
||||
io_backend="kernel",
|
||||
page_size=page_size,
|
||||
item_size=item_size,
|
||||
num_layers=num_layers,
|
||||
src_layer_offset=total_items_in_pool * item_size,
|
||||
dst_layer_offset=total_items_in_pool * item_size,
|
||||
)
|
||||
transfer_kv_all_layer_mla(
|
||||
src_pool_host,
|
||||
dst_pool_direct,
|
||||
src_indices_host,
|
||||
dst_indices_device,
|
||||
io_backend="direct",
|
||||
page_size=page_size,
|
||||
item_size=item_size,
|
||||
num_layers=num_layers,
|
||||
src_layer_offset=total_items_in_pool * item_size,
|
||||
dst_layer_offset=total_items_in_pool * item_size,
|
||||
)
|
||||
torch.cuda.synchronize()
|
||||
torch.testing.assert_close(dst_pool_kernel, dst_pool_ref)
|
||||
torch.testing.assert_close(dst_pool_direct, dst_pool_ref)
|
||||
else:
|
||||
if not all_layers:
|
||||
ref_copy_with_indices(
|
||||
src_k_pool[layer_idx_to_test],
|
||||
dst_k_pool_ref[layer_idx_to_test],
|
||||
src_indices_host,
|
||||
dst_indices_device,
|
||||
)
|
||||
ref_copy_with_indices(
|
||||
src_v_pool[layer_idx_to_test],
|
||||
dst_v_pool_ref[layer_idx_to_test],
|
||||
src_indices_host,
|
||||
dst_indices_device,
|
||||
)
|
||||
transfer_kv_per_layer(
|
||||
src_k_pool[layer_idx_to_test],
|
||||
dst_k_pool_kernel[layer_idx_to_test],
|
||||
src_v_pool[layer_idx_to_test],
|
||||
dst_v_pool_kernel[layer_idx_to_test],
|
||||
src_indices_device,
|
||||
dst_indices_device,
|
||||
io_backend="kernel",
|
||||
page_size=page_size,
|
||||
item_size=item_size,
|
||||
)
|
||||
transfer_kv_per_layer(
|
||||
src_k_pool[layer_idx_to_test],
|
||||
dst_k_pool_direct[layer_idx_to_test],
|
||||
src_v_pool[layer_idx_to_test],
|
||||
dst_v_pool_direct[layer_idx_to_test],
|
||||
src_indices_host,
|
||||
dst_indices_device,
|
||||
io_backend="direct",
|
||||
page_size=page_size,
|
||||
item_size=item_size,
|
||||
)
|
||||
else:
|
||||
for layer_id in range(num_layers):
|
||||
ref_copy_with_indices(
|
||||
src_k_pool[layer_id],
|
||||
dst_k_pool_ref[layer_id],
|
||||
src_indices_host,
|
||||
dst_indices_device,
|
||||
)
|
||||
ref_copy_with_indices(
|
||||
src_v_pool[layer_id],
|
||||
dst_v_pool_ref[layer_id],
|
||||
src_indices_host,
|
||||
dst_indices_device,
|
||||
)
|
||||
transfer_kv_all_layer(
|
||||
src_k_pool,
|
||||
dst_k_pool_kernel,
|
||||
src_v_pool,
|
||||
dst_v_pool_kernel,
|
||||
src_indices_device,
|
||||
dst_indices_device,
|
||||
io_backend="kernel",
|
||||
page_size=page_size,
|
||||
item_size=item_size,
|
||||
num_layers=num_layers,
|
||||
src_layer_offset=total_items_in_pool * item_size,
|
||||
dst_layer_offset=total_items_in_pool * item_size,
|
||||
)
|
||||
transfer_kv_all_layer(
|
||||
src_k_pool,
|
||||
dst_k_pool_direct,
|
||||
src_v_pool,
|
||||
dst_v_pool_direct,
|
||||
src_indices_host,
|
||||
dst_indices_device,
|
||||
io_backend="direct",
|
||||
page_size=page_size,
|
||||
item_size=item_size,
|
||||
num_layers=num_layers,
|
||||
src_layer_offset=total_items_in_pool * item_size,
|
||||
dst_layer_offset=total_items_in_pool * item_size,
|
||||
)
|
||||
torch.cuda.synchronize()
|
||||
torch.testing.assert_close(dst_k_pool_kernel, dst_k_pool_ref)
|
||||
torch.testing.assert_close(dst_v_pool_kernel, dst_v_pool_ref)
|
||||
torch.testing.assert_close(dst_k_pool_direct, dst_k_pool_ref)
|
||||
torch.testing.assert_close(dst_v_pool_direct, dst_v_pool_ref)
|
||||
|
||||
torch.set_default_dtype(original_dtype)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
pytest.main([__file__])
|
||||
Reference in New Issue
Block a user