diff --git a/sgl-kernel/csrc/common_extension.cc b/sgl-kernel/csrc/common_extension.cc
index 125ed29dc..e597e3111 100644
--- a/sgl-kernel/csrc/common_extension.cc
+++ b/sgl-kernel/csrc/common_extension.cc
@@ -113,6 +113,10 @@ TORCH_LIBRARY_FRAGMENT(sgl_kernel, m) {
       "fast_topk_transform_fused(Tensor score, Tensor lengths, Tensor dst_page_table, Tensor src_page_table, Tensor "
       "cu_seqlens_q) -> ()");
   m.impl("fast_topk_transform_fused", torch::kCUDA, &fast_topk_transform_interface);
+  m.def(
+      "fast_topk_transform_ragged_fused(Tensor score, Tensor lengths, Tensor topk_indices_ragged, Tensor "
+      "topk_indices_offset) -> ()");
+  m.impl("fast_topk_transform_ragged_fused", torch::kCUDA, &fast_topk_transform_ragged_interface);
 
   /*
    * From gguf quantiztion
diff --git a/sgl-kernel/csrc/elementwise/topk.cu b/sgl-kernel/csrc/elementwise/topk.cu
index 05fb9f08b..b2515ca28 100644
--- a/sgl-kernel/csrc/elementwise/topk.cu
+++ b/sgl-kernel/csrc/elementwise/topk.cu
@@ -51,6 +51,15 @@
   }
 }
 
+// keep the first `length` entries, set others to -1
+__device__ void naive_topk_transform_ragged(
+    const float* __restrict__ score, int32_t length, int32_t* __restrict__ topk_indices_ragged, int32_t offset) {
+  const auto tid = threadIdx.x;
+  for (auto i = tid; i < TopK; i += kThreadsPerBlock) {
+    topk_indices_ragged[i] = (i < length) ? static_cast<int32_t>(i) + offset : -1;
+  }
+}
+
 __device__ __forceinline__ auto convert_to_uint8(float x) -> uint8_t {
   __half h = __float2half_rn(x);
   uint16_t bits = __half_as_ushort(h);
@@ -322,8 +331,40 @@ __global__ __launch_bounds__(kThreadsPerBlock)  // prefill
   }
 }
 
-auto get_params(at::Tensor score, at::Tensor lengths, std::optional<at::Tensor> indices_opt = std::nullopt)
-    -> FastTopKParams {
+__global__ __launch_bounds__(kThreadsPerBlock)  // prefill, ragged kv
+    void topk_transform_prefill_ragged_kernel(
+        const FastTopKParams params,
+        int32_t* __restrict__ topk_indices_ragged,
+        const int32_t* __restrict__ topk_indices_offset) {
+  const auto& [input, _, lengths, input_stride] = params;
+  const auto bid = static_cast<int32_t>(blockIdx.x);
+  const auto tid = threadIdx.x;
+  const auto length = lengths[bid];
+  const auto dst_indices_entry = topk_indices_ragged + bid * TopK;
+  const auto score = input + bid * input_stride;
+  const auto offset = topk_indices_offset[bid];
+
+  if (length <= TopK) {
+    return naive_topk_transform_ragged(score, length, dst_indices_entry, offset);
+  } else {
+    __shared__ int s_indices[TopK];
+    fast_topk_cuda_tl(score, s_indices, length);
+    // copy src[s_indices] to dst, we manually unroll here
+    static_assert(TopK % kThreadsPerBlock == 0);
+    static_assert(TopK / kThreadsPerBlock == 2);
+    const auto idx_0 = tid;
+    const auto pos_0 = s_indices[idx_0];
+    dst_indices_entry[idx_0] = pos_0 + offset;
+    const auto idx_1 = tid + kThreadsPerBlock;
+    const auto pos_1 = s_indices[idx_1];
+    dst_indices_entry[idx_1] = pos_1 + offset;
+  }
+}
+
+auto get_params(
+    const at::Tensor& score,
+    const at::Tensor& lengths,
+    std::optional<at::Tensor> indices_opt = std::nullopt) -> FastTopKParams {
   const auto B = score.size(0);
   TORCH_CHECK(score.dim() == 2 && score.stride(1) == 1);
   TORCH_CHECK(lengths.dim() == 1 && lengths.is_contiguous());
@@ -357,7 +398,7 @@ void setup_kernel_smem_once() {
 
 #define CHECK_CUDA(x) TORCH_CHECK(x.is_cuda(), #x " must be a CUDA tensor")
 
-void fast_topk_interface(at::Tensor score, at::Tensor indices, at::Tensor lengths) {
+void fast_topk_interface(const at::Tensor& score, at::Tensor& indices, const at::Tensor& lengths) {
   CHECK_CUDA(score);
   CHECK_CUDA(indices);
   CHECK_CUDA(lengths);
@@ -373,11 +414,11 @@ void fast_topk_interface(at::Tensor score, at::Tensor indices, at::Tensor length
 }
 
 void fast_topk_transform_interface(
-    at::Tensor score,
-    at::Tensor lengths,
-    at::Tensor dst_page_table,
-    at::Tensor src_page_table,
-    at::Tensor cu_seqlens_q) {
+    const at::Tensor& score,
+    const at::Tensor& lengths,
+    at::Tensor& dst_page_table,
+    const at::Tensor& src_page_table,
+    const at::Tensor& cu_seqlens_q) {
   CHECK_CUDA(score);
   CHECK_CUDA(lengths);
   CHECK_CUDA(dst_page_table);
@@ -420,3 +461,35 @@ void fast_topk_transform_interface(
   const auto result = cudaGetLastError();
   TORCH_CHECK(result == cudaSuccess, "topk kernel failed:", ::cudaGetErrorString(result));
 }
+
+void fast_topk_transform_ragged_interface(
+    const at::Tensor& score,
+    const at::Tensor& lengths,
+    at::Tensor& topk_indices_ragged,
+    const at::Tensor& topk_indices_offset) {
+  CHECK_CUDA(score);
+  CHECK_CUDA(lengths);
+  CHECK_CUDA(topk_indices_ragged);
+  CHECK_CUDA(topk_indices_offset);
+
+  const auto params = get_params(score, lengths);
+  const auto B = score.size(0);
+  TORCH_CHECK(topk_indices_ragged.dim() == 2 && topk_indices_ragged.is_contiguous());
+  TORCH_CHECK(topk_indices_offset.dim() == 1);
+
+  TORCH_CHECK(topk_indices_ragged.size(0) == B);
+  TORCH_CHECK(topk_indices_ragged.size(1) == TopK);
+  TORCH_CHECK(topk_indices_offset.size(0) == B);
+
+  // launch kernel
+  const auto stream = at::cuda::getCurrentCUDAStream().stream();
+  const auto grid = dim3{static_cast<uint32_t>(B)};
+  const auto block = dim3{kThreadsPerBlock};
+
+  setup_kernel_smem_once();
+  topk_transform_prefill_ragged_kernel<<<grid, block, 0, stream>>>(
+      params, topk_indices_ragged.data_ptr<int32_t>(), topk_indices_offset.data_ptr<int32_t>());
+
+  const auto result = cudaGetLastError();
+  TORCH_CHECK(result == cudaSuccess, "topk kernel failed:", ::cudaGetErrorString(result));
+}
diff --git a/sgl-kernel/include/sgl_kernel_ops.h b/sgl-kernel/include/sgl_kernel_ops.h
index 1b4b5c91e..6b095069c 100644
--- a/sgl-kernel/include/sgl_kernel_ops.h
+++ b/sgl-kernel/include/sgl_kernel_ops.h
@@ -174,13 +174,18 @@ void copy_to_gpu_no_ce(const at::Tensor& input, at::Tensor& output);
 
 void concat_mla_k(torch::Tensor k, torch::Tensor k_nope, torch::Tensor k_rope);
 void concat_mla_absorb_q(at::Tensor a, at::Tensor b, at::Tensor out);
-void fast_topk_interface(at::Tensor score, at::Tensor indices, at::Tensor lengths);
+void fast_topk_interface(const at::Tensor& score, at::Tensor& indices, const at::Tensor& lengths);
 void fast_topk_transform_interface(
-    at::Tensor score,
-    at::Tensor lengths,
-    at::Tensor dst_page_table,
-    at::Tensor src_page_table,
-    at::Tensor cu_seqlens_q);
+    const at::Tensor& score,
+    const at::Tensor& lengths,
+    at::Tensor& dst_page_table,
+    const at::Tensor& src_page_table,
+    const at::Tensor& cu_seqlens_q);
+void fast_topk_transform_ragged_interface(
+    const at::Tensor& score,
+    const at::Tensor& lengths,
+    at::Tensor& topk_indices_ragged,
+    const at::Tensor& topk_indices_offset);
 
 #ifdef USE_ROCM
 void gelu_quick(at::Tensor& out, const at::Tensor& input);
diff --git a/sgl-kernel/python/sgl_kernel/__init__.py b/sgl-kernel/python/sgl_kernel/__init__.py
index 4c99fe702..dc72ae30b 100644
--- a/sgl-kernel/python/sgl_kernel/__init__.py
+++ b/sgl-kernel/python/sgl_kernel/__init__.py
@@ -327,7 +327,12 @@ from sgl_kernel.speculative import (
     tree_speculative_sampling_target_only,
     verify_tree_greedy,
 )
-from sgl_kernel.top_k import fast_topk, fast_topk_transform_fused, fast_topk_v2
+from sgl_kernel.top_k import (
+    fast_topk,
+    fast_topk_transform_fused,
+    fast_topk_transform_ragged_fused,
+    fast_topk_v2,
+)
 from sgl_kernel.version import __version__
 
 if torch.version.hip is not None:
diff --git a/sgl-kernel/python/sgl_kernel/top_k.py b/sgl-kernel/python/sgl_kernel/top_k.py
index afc67b9d8..effcca4ba 100644
--- a/sgl-kernel/python/sgl_kernel/top_k.py
+++ b/sgl-kernel/python/sgl_kernel/top_k.py
@@ -28,13 +28,36 @@ def fast_topk_transform_fused(
     cu_seqlens_q: torch.Tensor,
     topk: int,
 ) -> torch.Tensor:
+    """
+    Transform topk indices to indices to the page table (page_size = 1)
+    """
     assert (
         topk == 2048
     ), "fast_topk_transform_fused is only optimized for deepseek v3.2 model, where topk=2048"
     assert score.dim() == 2
     src_page_table = page_table_size_1
-    dst_page_table = score.new_empty((score.size(0), topk), dtype=torch.int32)
+    dst_page_table = score.new_empty((score.shape[0], topk), dtype=torch.int32)
     torch.ops.sgl_kernel.fast_topk_transform_fused(
         score, lengths, dst_page_table, src_page_table, cu_seqlens_q
     )
     return dst_page_table
+
+
+def fast_topk_transform_ragged_fused(
+    score: torch.Tensor,
+    lengths: torch.Tensor,
+    topk_indices_offset: torch.Tensor,  # ragged kv
+    topk: int,
+) -> torch.Tensor:
+    """
+    Transform topk indices to indices to ragged kv (non-paged)
+    """
+    assert (
+        topk == 2048
+    ), "fast_topk_transform_fused_ragged is only optimized for deepseek v3.2 model, where topk=2048"
+    assert score.dim() == 2
+    topk_indices_ragged = score.new_empty((score.shape[0], topk), dtype=torch.int32)
+    torch.ops.sgl_kernel.fast_topk_transform_ragged_fused(
+        score, lengths, topk_indices_ragged, topk_indices_offset
+    )
+    return topk_indices_ragged
diff --git a/sgl-kernel/tests/test_topk.py b/sgl-kernel/tests/test_topk.py
index 8bbea4c62..f3296fa15 100644
--- a/sgl-kernel/tests/test_topk.py
+++ b/sgl-kernel/tests/test_topk.py
@@ -1,6 +1,12 @@
+from typing import Optional
+
 import pytest
 import torch
-from sgl_kernel import fast_topk_transform_fused, fast_topk_v2
+from sgl_kernel import (
+    fast_topk_transform_fused,
+    fast_topk_transform_ragged_fused,
+    fast_topk_v2,
+)
 
 
 def _ref_torch_impl(score: torch.Tensor, seq_len: int, topk: int) -> torch.Tensor:
@@ -26,6 +32,21 @@ def _ref_torch_transform_decode_impl(
     return topk_indices
 
 
+def _ref_torch_transform_ragged_impl(
+    score: torch.Tensor,
+    seq_len: int,
+    topk_indices_offset: torch.Tensor,
+    topk: int,
+) -> torch.Tensor:
+    assert score.shape[0] == topk_indices_offset.shape[0]
+    assert seq_len >= topk
+    indices = _ref_torch_impl(score, seq_len, topk)
+
+    mask = indices != -1
+    topk_indices_offset = topk_indices_offset.unsqueeze(1)
+    return torch.where(mask, indices + topk_indices_offset, indices)
+
+
 MAX_SEQ_LEN = 131072
 MAX_PERMIT_ERROR = 0
 
@@ -37,6 +58,7 @@ def assert_equal(
     bs: int,
     k: int,
     seq_len: int,
+    topk_indices_offset: Optional[torch.Tensor] = None,
 ):
     indices_our_cpu = indices_our.cpu().tolist()
     indices_ref_cpu = indices_ref.cpu().tolist()
@@ -45,11 +67,13 @@ def assert_equal(
         indices_our_set_i = set(indices_our_cpu[i])
         more = indices_our_set_i - indices_ref_set_i
         less = indices_ref_set_i - indices_our_set_i
-        if len(more) > MAX_PERMIT_ERROR or len(less) > MAX_PERMIT_ERROR:
+        offset = topk_indices_offset[i].item() if topk_indices_offset is not None else 0
+        if len(more) > 0 or len(less) > 0:
+            print(f"{bs=}, {k=}, {seq_len=}, {i=}, {more=}, {less=}")
             # check whether more values are the same with less values
             # if so, either one is acceptable, since their values are the same
-            more_values = sorted(score[i, idx].item() for idx in more)
-            less_values = sorted(score[i, idx].item() for idx in less)
+            more_values = sorted(score[i, idx - offset].item() for idx in more)
+            less_values = sorted(score[i, idx - offset].item() for idx in less)
             assert (
                 more_values == less_values
             ), f"{bs=}, {k=}, {seq_len=}, {i=}, {more=}, {less=} failed, with {more_values=}, {less_values=}"
@@ -116,5 +140,52 @@ def test_topk_transform_kernel(bs: int, k: int, seq_len: int) -> None:
     assert_equal(score, dst_page_table_ref, dst_page_table_our, bs, k, seq_len)
 
 
+@pytest.mark.parametrize("bs", [1, 132, 256, 4096])
+@pytest.mark.parametrize("k", [2048])  # we only support 2048 now
+@pytest.mark.parametrize("seq_len", [2048, 4096, 16384, 65536])
+@torch.inference_mode()
+def test_topk_transform_ragged_kernel(bs: int, k: int, seq_len: int) -> None:
+    # TODO(dark): test prefill kernel, though nothing special
+    MAX_PERMIT_ERROR = 1
+    torch.manual_seed(42)
+
+    stream = torch.cuda.Stream()
+    torch.cuda.set_stream(stream)
+    # bs: # of q tokens
+    score = torch.randn(bs, MAX_SEQ_LEN, dtype=torch.float32, device="cuda")
+    # kv_len
+    lengths = torch.full((bs,), seq_len, dtype=torch.int32, device="cuda")
+    topk_indices_offset = torch.randint(
+        0, 1024, (bs,), dtype=torch.int32, device="cuda"
+    )
+
+    dst_page_table_ref = _ref_torch_transform_ragged_impl(
+        score=score,
+        seq_len=seq_len,
+        topk_indices_offset=topk_indices_offset,
+        topk=k,
+    )
+    dst_page_table_our = fast_topk_transform_ragged_fused(
+        score=score,
+        lengths=lengths,
+        topk_indices_offset=topk_indices_offset,
+        topk=k,
+    )
+
+    # sort and compare
+    dst_page_table_our = torch.sort(dst_page_table_our, dim=-1).values
+    dst_page_table_ref = torch.sort(dst_page_table_ref, dim=-1).values
+
+    assert_equal(
+        score,
+        dst_page_table_ref,
+        dst_page_table_our,
+        bs,
+        k,
+        seq_len,
+        topk_indices_offset,
+    )
+
+
 if __name__ == "__main__":
    pytest.main([__file__])