diff --git a/python/sglang/srt/mem_cache/allocator.py b/python/sglang/srt/mem_cache/allocator.py
index 4fefac941..28181c4ed 100644
--- a/python/sglang/srt/mem_cache/allocator.py
+++ b/python/sglang/srt/mem_cache/allocator.py
@@ -28,6 +28,12 @@ import triton.language as tl
 from sglang.srt.mem_cache.memory_pool import SWAKVPool
 from sglang.srt.utils import get_bool_env_var, get_num_new_pages, next_power_of_2
 
+try:
+    # Optional DCU (HIP) fused kernel; absent on builds without sgl-kernel.
+    from sgl_kernel.kvcacheio import dcu_alloc_decode_kernel
+except ImportError:
+    dcu_alloc_decode_kernel = None
+
 if TYPE_CHECKING:
     from sglang.srt.mem_cache.memory_pool import KVCache
@@ -430,6 +436,12 @@ class PagedTokenToKVPoolAllocator(BaseTokenToKVPoolAllocator):
         super().__init__(size, page_size, dtype, device, kvcache, need_sort)
         self.num_pages = size // page_size
         self.debug_mode = get_bool_env_var("SGLANG_DEBUG_MEMORY_POOL")
+        # Route decode-time slot allocation to the DCU kernel only when both
+        # the env flag is set and the kernel was actually importable.
+        self.use_dcu_decode_kernel = (
+            get_bool_env_var("USE_DCU_DECODE_KERNEL")
+            and dcu_alloc_decode_kernel is not None
+        )
         self.seen_max_num_extend_tokens_next_power_of_2 = 1
         self.clear()
@@ -525,14 +537,26 @@ class PagedTokenToKVPoolAllocator(BaseTokenToKVPoolAllocator):
             self.merge_and_sort_free()
 
         out_indices = torch.empty((bs,), dtype=torch.int64, device=self.device)
-        alloc_decode_kernel[(bs,)](
-            seq_lens,
-            last_loc,
-            self.free_pages,
-            out_indices,
-            next_power_of_2(bs),
-            self.page_size,
-        )
+
+        if self.use_dcu_decode_kernel:
+            dcu_alloc_decode_kernel(
+                seq_lens,
+                last_loc,
+                self.free_pages,
+                out_indices,
+                bs,
+                next_power_of_2(bs),
+                self.page_size,
+            )
+        else:
+            alloc_decode_kernel[(bs,)](
+                seq_lens,
+                last_loc,
+                self.free_pages,
+                out_indices,
+                next_power_of_2(bs),
+                self.page_size,
+            )
 
         if self.debug_mode:
             assert len(torch.unique(out_indices)) == len(out_indices)
diff --git a/sgl-kernel/csrc/common_extension_rocm.cc b/sgl-kernel/csrc/common_extension_rocm.cc
index f4e14d0d5..156310d03 100644
--- a/sgl-kernel/csrc/common_extension_rocm.cc
+++ b/sgl-kernel/csrc/common_extension_rocm.cc
@@ -125,6 +125,10 @@ TORCH_LIBRARY_EXPAND(sgl_kernel, m) {
   /*
    * From csrc/kvcacheio
    */
+  m.def(
+      "dcu_alloc_decode_kernel(Tensor seq_lens_ptr, Tensor last_loc_ptr, Tensor free_page_ptr, Tensor out_indices, "
+      "int bs, int bs_upper, int page_size) -> ()");
+  m.impl("dcu_alloc_decode_kernel", torch::kCUDA, &dcu_alloc_decode_kernel);
   m.def(
       "transfer_kv_per_layer(Tensor src_k, Tensor dst_k, Tensor src_v, Tensor dst_v, Tensor src_indices, Tensor "
       "dst_indices, int item_size, int block_quota, int num_warps_per_block) -> ()");
diff --git a/sgl-kernel/csrc/kvcacheio/transfer.cu b/sgl-kernel/csrc/kvcacheio/transfer.cu
index 8898d4cfa..7764ab98d 100644
--- a/sgl-kernel/csrc/kvcacheio/transfer.cu
+++ b/sgl-kernel/csrc/kvcacheio/transfer.cu
@@ -571,3 +571,75 @@ void transfer_kv_all_layer_direct_lf_pf(
     int64_t page_size) {
   transfer_kv_page_first_direct_impl(src_ptrs, dst_ptrs, src_indices, dst_indices, 0, page_size);
 }
+
+// Integer ceiling division; assumes b > 0.
+__device__ __forceinline__ int64_t ceil_div(int64_t a, int64_t b) {
+  return (a + b - 1) / b;
+}
+
+// CUDA/HIP port of the Triton alloc_decode_kernel: one thread per request.
+// Request `pid` grows from seq_len - 1 to seq_len tokens. If the new token
+// still fits in the request's current last page its slot is last_loc + 1;
+// otherwise a fresh page is taken from free_page_ptr, indexed by the
+// exclusive prefix sum of new-page counts over requests 0..pid-1 so that
+// every request picks a distinct free page without synchronization.
+__global__ void launch_alloc_decode_kernel(
+    const int64_t* seq_lens_ptr,
+    const int32_t* last_loc_ptr,
+    const int64_t* free_page_ptr,
+    int64_t* out_indices,
+    int64_t bs,
+    int64_t page_size) {
+  int64_t pid = blockIdx.x * blockDim.x + threadIdx.x;
+  // Guard on the real batch size: threads with pid in [bs, bs_upper) would
+  // read garbage seq_lens and write out of bounds in out_indices.
+  if (pid >= bs) return;
+
+  int64_t seq_len = seq_lens_ptr[pid];
+  int64_t pre_len = seq_len - 1;
+  // 1 when the new token starts a fresh page for this request, else 0.
+  int64_t num_new_pages_self = ceil_div(seq_len, page_size) - ceil_div(pre_len, page_size);
+
+  if (num_new_pages_self == 0) {
+    // The new token stays on the current page.
+    out_indices[pid] = static_cast<int64_t>(last_loc_ptr[pid]) + 1;
+    return;
+  }
+
+  // Exclusive prefix sum of new pages needed by earlier requests. (The
+  // Triton kernel sums over i <= pid and subtracts the self contribution;
+  // summing over i < pid is the same quantity, with no extra subtraction.)
+  int64_t new_page_start_loc = 0;
+  for (int64_t i = 0; i < pid; ++i) {
+    int64_t other_seq_len = seq_lens_ptr[i];
+    new_page_start_loc += ceil_div(other_seq_len, page_size) - ceil_div(other_seq_len - 1, page_size);
+  }
+  out_indices[pid] = free_page_ptr[new_page_start_loc] * page_size;
+}
+
+// Host wrapper registered as sgl_kernel::dcu_alloc_decode_kernel.
+// seq_lens_ptr: [bs] sequence lengths including the new token;
+// last_loc_ptr: [bs] slot of the previous last token (int64/int32 dtypes
+// assumed from the casts below -- confirm against the Python caller);
+// out_indices: int64 [bs], written in place. bs_upper is accepted for
+// schema compatibility with the Triton kernel but is not needed here.
+void dcu_alloc_decode_kernel(
+    const at::Tensor seq_lens_ptr,
+    const at::Tensor last_loc_ptr,
+    const at::Tensor free_page_ptr,
+    at::Tensor out_indices,
+    int64_t bs,
+    int64_t bs_upper,
+    int64_t page_size) {
+  const int64_t* seq_lens = static_cast<const int64_t*>(seq_lens_ptr.data_ptr());
+  const int32_t* last_loc = static_cast<const int32_t*>(last_loc_ptr.data_ptr());
+  const int64_t* free_pages = static_cast<const int64_t*>(free_page_ptr.data_ptr());
+  int64_t* out = static_cast<int64_t*>(out_indices.data_ptr());
+
+  constexpr int64_t block_size = 64;
+  const int64_t grid_size = (bs + block_size - 1) / block_size;
+  cudaStream_t stream = at::cuda::getCurrentCUDAStream();
+  launch_alloc_decode_kernel<<<grid_size, block_size, 0, stream>>>(
+      seq_lens, last_loc, free_pages, out, bs, page_size);
+  C10_CUDA_KERNEL_LAUNCH_CHECK();
+}
diff --git a/sgl-kernel/include/sgl_kernel_ops.h b/sgl-kernel/include/sgl_kernel_ops.h
index 6be8af703..31ffc3192 100644
--- a/sgl-kernel/include/sgl_kernel_ops.h
+++ b/sgl-kernel/include/sgl_kernel_ops.h
@@ -538,6 +538,17 @@ void segment_packbits(
 /*
  * From csrc/kvcacheio
  */
+// DCU decode-slot allocator: writes one KV slot index per request into
+// out_indices. Defined in csrc/kvcacheio/transfer.cu.
+void dcu_alloc_decode_kernel(
+    const at::Tensor seq_lens_ptr,
+    const at::Tensor last_loc_ptr,
+    const at::Tensor free_page_ptr,
+    at::Tensor out_indices,
+    int64_t bs,
+    int64_t bs_upper,
+    int64_t page_size);
+
 void transfer_kv_per_layer(
     const at::Tensor src_k,
     at::Tensor dst_k,
diff --git a/sgl-kernel/python/sgl_kernel/kvcacheio.py b/sgl-kernel/python/sgl_kernel/kvcacheio.py
index 5714b6a0d..ec3c9b8a6 100644
--- a/sgl-kernel/python/sgl_kernel/kvcacheio.py
+++ b/sgl-kernel/python/sgl_kernel/kvcacheio.py
@@ -10,6 +10,30 @@ def is_hip() -> bool:
 _is_hip = is_hip()
 
 
+def dcu_alloc_decode_kernel(
+    seq_lens_ptr: torch.Tensor,
+    last_loc_ptr: torch.Tensor,
+    free_page_ptr: torch.Tensor,
+    out_indices: torch.Tensor,
+    bs: int,
+    bs_upper: int,
+    page_size: int,
+) -> None:
+    """Allocate one decode-token KV slot per request (DCU fused kernel).
+
+    Writes the allocated slot indices into ``out_indices`` in place.
+    """
+    torch.ops.sgl_kernel.dcu_alloc_decode_kernel(
+        seq_lens_ptr,
+        last_loc_ptr,
+        free_page_ptr,
+        out_indices,
+        bs,
+        bs_upper,
+        page_size,
+    )
+
+
 def transfer_kv_per_layer(
     src_k: torch.Tensor,
     dst_k: torch.Tensor,