diff --git a/sgl-kernel/CMakeLists.txt b/sgl-kernel/CMakeLists.txt index 4fa98e436..28509b60a 100644 --- a/sgl-kernel/CMakeLists.txt +++ b/sgl-kernel/CMakeLists.txt @@ -280,6 +280,7 @@ set(SOURCES "csrc/speculative/packbit.cu" "csrc/spatial/greenctx_stream.cu" "csrc/speculative/speculative_sampling.cu" + "csrc/memory/store.cu" "${repo-flashinfer_SOURCE_DIR}/csrc/norm.cu" "${repo-flashinfer_SOURCE_DIR}/csrc/renorm.cu" "${repo-flashinfer_SOURCE_DIR}/csrc/sampling.cu" diff --git a/sgl-kernel/csrc/common_extension.cc b/sgl-kernel/csrc/common_extension.cc index 86ef29f24..05473e436 100644 --- a/sgl-kernel/csrc/common_extension.cc +++ b/sgl-kernel/csrc/common_extension.cc @@ -413,6 +413,12 @@ TORCH_LIBRARY_FRAGMENT(sgl_kernel, m) { */ m.def("create_greenctx_stream_by_value(int smA, int smB, int device) -> int[]"); m.impl("create_greenctx_stream_by_value", &create_greenctx_stream_by_value); + + /* + * From csrc/memory + */ + m.def("store_kv_cache(Tensor k_cache, Tensor v_cache, Tensor out_loc, Tensor k, Tensor v) -> ()"); + m.impl("store_kv_cache", &store_kv_cache); } REGISTER_EXTENSION(common_ops) diff --git a/sgl-kernel/csrc/memory/store.cu b/sgl-kernel/csrc/memory/store.cu new file mode 100644 index 000000000..c6dd97ebd --- /dev/null +++ b/sgl-kernel/csrc/memory/store.cu @@ -0,0 +1,147 @@ +#include +#include +#include +#include + +#include +#include +#include + +namespace { + +using std::size_t; +using std::uint64_t; + +// Each warp will process 256 bytes per loop iteration +template +__global__ void store_kv_cache_256x1( + uint64_t* __restrict__ k_cache, + uint64_t* __restrict__ v_cache, + const T* __restrict__ out_loc, + const size_t length, + const uint64_t* __restrict__ k, + const uint64_t* __restrict__ v, + const size_t kv_cache_stride, + const size_t kv_input_stride, + const size_t num_items) { + const auto idx = blockIdx.x * blockDim.x + threadIdx.x; + const auto warp_id = idx / 32; + const auto lane_id = idx % 32; + if (warp_id >= length) return; + const auto offset = out_loc[warp_id]; + const auto k_dst = k_cache + offset * kv_cache_stride; + const auto v_dst = v_cache + offset * kv_cache_stride; + const auto k_src = k + warp_id * kv_input_stride; + const auto v_src = v + warp_id * kv_input_stride; + for (size_t i = 0; i < num_items; ++i) { + k_dst[lane_id + i * 32] = k_src[lane_id + i * 32]; + v_dst[lane_id + i * 32] = v_src[lane_id + i * 32]; + } +} + +// Each warp will process 128 bytes per loop iteration +template +__global__ void store_kv_cache_128x2( + uint64_t* __restrict__ k_cache, + uint64_t* __restrict__ v_cache, + const T* __restrict__ out_loc, + const size_t length, + const uint64_t* __restrict__ k, + const uint64_t* __restrict__ v, + const size_t kv_cache_stride, + const size_t kv_input_stride, + const size_t num_items) { + const auto idx = blockIdx.x * blockDim.x + threadIdx.x; + const auto warp_id = idx / 32; + const auto lane_id = idx % 32; + if (warp_id >= length) return; + const auto offset = out_loc[warp_id]; + const auto copy_k = lane_id < 16; + const auto copy_id = lane_id % 16; + const auto cache = copy_k ? k_cache : v_cache; + const auto input = copy_k ? k : v; + const auto dst = cache + offset * kv_cache_stride; + const auto src = input + warp_id * kv_input_stride; + for (size_t i = 0; i < num_items; ++i) { + dst[copy_id + i * 16] = src[copy_id + i * 16]; + } +} + +} // namespace + +auto store_kv_cache(at::Tensor k_cache, at::Tensor v_cache, at::Tensor out_loc, at::Tensor k, at::Tensor v) -> void { + const auto max_tokens = k_cache.size(0); + const auto num_tokens = out_loc.size(0); + k_cache = k_cache.view({max_tokens, -1}); + v_cache = v_cache.view({max_tokens, -1}); + k = k.view({num_tokens, -1}); + v = v.view({num_tokens, -1}); + + TORCH_CHECK( + k_cache.is_cuda() && v_cache.is_cuda() && out_loc.is_cuda() && k.is_cuda() && v.is_cuda(), + "All tensors must be CUDA tensors"); + TORCH_CHECK(k_cache.sizes() == v_cache.sizes(), "k_cache and v_cache must have the same size"); + TORCH_CHECK(k_cache.strides() == v_cache.strides(), "k_cache and v_cache must have the same strides"); + TORCH_CHECK(k.sizes() == v.sizes(), "k and v must have the same size"); + TORCH_CHECK(k.strides() == v.strides(), "k and v must have the same strides"); + TORCH_CHECK(k.stride(-1) == 1 && k_cache.stride(-1) == 1, "k and k_cache must be contiguous in head."); + TORCH_CHECK(k.size(-1) == k_cache.size(-1), "k and k_cache must have the same head size"); + TORCH_CHECK(out_loc.dim() == 1 && out_loc.is_contiguous(), "out_loc must be a 1D contiguous tensor"); + static_assert(sizeof(uint64_t) == 8, "uint64_t must be 8 bytes, our code assumes that"); + + const auto length = out_loc.size(0); + const auto elem_size = k.element_size(); + const auto size_bytes = elem_size * k.size(-1); + const auto kv_cache_stride_bytes = elem_size * k_cache.stride(-2); + const auto kv_input_stride_bytes = elem_size * k.stride(-2); + const auto kv_cache_stride = kv_cache_stride_bytes / 8; + const auto kv_input_stride = kv_input_stride_bytes / 8; + + const auto k_cache_ptr = static_cast(k_cache.data_ptr()); + const auto v_cache_ptr = static_cast(v_cache.data_ptr()); + const auto k_ptr = static_cast(k.data_ptr()); + const auto v_ptr = static_cast(v.data_ptr()); + const auto num_threads = 256; + const auto num_warps = num_threads / 32; + const auto num_blocks = (length + num_warps - 1) / num_warps; + const auto stream = at::cuda::getCurrentCUDAStream(); + + AT_DISPATCH_INTEGRAL_TYPES(out_loc.scalar_type(), "store_kv_cache", [&] { + if constexpr (!std::is_same_v && !std::is_same_v) { + // do not instantiate the kernel if out_loc is not int32 or int64 + TORCH_CHECK(false, "out_loc must be of type int32 or int64, got: ", out_loc.scalar_type()); + } else { + if (size_bytes % 256 == 0) { + const auto items_per_warp = size_bytes / 256; + store_kv_cache_256x1<<>>( + k_cache_ptr, + v_cache_ptr, + out_loc.data_ptr(), + length, + k_ptr, + v_ptr, + kv_cache_stride, + kv_input_stride, + items_per_warp); + } else if (size_bytes % 128 == 0) { + const auto items_per_warp = size_bytes / 128; + store_kv_cache_128x2<<>>( + k_cache_ptr, + v_cache_ptr, + out_loc.data_ptr(), + length, + k_ptr, + v_ptr, + kv_cache_stride, + kv_input_stride, + items_per_warp); + } else { + TORCH_CHECK( + false, + "The last dimension size bytes of k and v must be" + " divisible by 128 at least, got: ", + size_bytes); + } + } + }); +} diff --git a/sgl-kernel/include/sgl_kernel_ops.h b/sgl-kernel/include/sgl_kernel_ops.h index c007251cd..df7c24095 100644 --- a/sgl-kernel/include/sgl_kernel_ops.h +++ b/sgl-kernel/include/sgl_kernel_ops.h @@ -699,3 +699,8 @@ void qserve_w4a8_per_group_gemm( * From csrc/spatial */ std::vector create_greenctx_stream_by_value(int64_t smA, int64_t smB, int64_t device); + +/* + * From csrc/memory + */ +void store_kv_cache(at::Tensor k_cache, at::Tensor v_cache, at::Tensor out_loc, at::Tensor k, at::Tensor v); diff --git a/sgl-kernel/python/sgl_kernel/__init__.py b/sgl-kernel/python/sgl_kernel/__init__.py index faeff9240..2abf0bcf3 100755 --- a/sgl-kernel/python/sgl_kernel/__init__.py +++ b/sgl-kernel/python/sgl_kernel/__init__.py @@ -67,6 +67,7 @@ from sgl_kernel.marlin import ( awq_marlin_repack, gptq_marlin_repack, ) +from sgl_kernel.memory import set_kv_buffer_kernel from sgl_kernel.moe import ( apply_shuffle_mul_sum, cutlass_fp4_group_mm, diff --git a/sgl-kernel/python/sgl_kernel/memory.py b/sgl-kernel/python/sgl_kernel/memory.py new file mode 100644 index 000000000..eb997db0c --- /dev/null +++ b/sgl-kernel/python/sgl_kernel/memory.py @@ -0,0 +1,18 @@ +import torch + + +def set_kv_buffer_kernel( + k_cache: torch.Tensor, + v_cache: torch.Tensor, + loc: torch.Tensor, + k: torch.Tensor, + v: torch.Tensor, + fallback: bool = False, +): + try: + if fallback: + raise RuntimeError("Fallback to torch implementation") + torch.ops.sgl_kernel.store_kv_cache(k_cache, v_cache, loc, k, v) + except RuntimeError: # ok, fallback to torch implementation + k_cache[loc] = k + v_cache[loc] = v