diff --git a/sgl-kernel/CMakeLists.txt b/sgl-kernel/CMakeLists.txt
index 87c271e20..9c5842260 100644
--- a/sgl-kernel/CMakeLists.txt
+++ b/sgl-kernel/CMakeLists.txt
@@ -268,6 +268,7 @@ set(SOURCES
     "csrc/elementwise/concat_mla.cu"
     "csrc/elementwise/fused_add_rms_norm_kernel.cu"
     "csrc/elementwise/rope.cu"
+    "csrc/elementwise/topk.cu"
     "csrc/common_extension.cc"
     "csrc/gemm/awq_kernel.cu"
diff --git a/sgl-kernel/csrc/common_extension.cc b/sgl-kernel/csrc/common_extension.cc
index 4b99f7645..48968a64c 100644
--- a/sgl-kernel/csrc/common_extension.cc
+++ b/sgl-kernel/csrc/common_extension.cc
@@ -107,6 +107,13 @@ TORCH_LIBRARY_FRAGMENT(sgl_kernel, m) {
   m.def("concat_mla_absorb_q(Tensor a, Tensor b, Tensor! out) -> ()");
   m.impl("concat_mla_absorb_q", torch::kCUDA, &concat_mla_absorb_q);
 
+  // NOTE: `indices` / `dst_page_table` are written by the kernels, so they are
+  // declared `Tensor!` (mutable) — required for correct dispatcher aliasing.
+  m.def("fast_topk(Tensor score, Tensor! indices, Tensor lengths) -> ()");
+  m.impl("fast_topk", torch::kCUDA, &fast_topk_interface);
+  m.def(
+      "fast_topk_transform_fused(Tensor score, Tensor lengths, Tensor! dst_page_table, Tensor src_page_table, Tensor "
+      "cu_seqlens_q) -> ()");
+  m.impl("fast_topk_transform_fused", torch::kCUDA, &fast_topk_transform_interface);
+
   /*
    * From csrc/gemm
    */
diff --git a/sgl-kernel/csrc/elementwise/topk.cu b/sgl-kernel/csrc/elementwise/topk.cu
new file mode 100644
index 000000000..05fb9f08b
--- /dev/null
+++ b/sgl-kernel/csrc/elementwise/topk.cu
@@ -0,0 +1,422 @@
+/**
+ * @NOTE: This file is adapted from
+ * https://github.com/tile-ai/tilelang/blob/main/examples/deepseek_v32/topk_selector.py
+ * We:
+ * 1. adapt from tilelang to pure cuda
+ * 2. optimize the performance a little
+ * 3. fix the potential illegal memory access
+ */
+#include <ATen/cuda/CUDAContext.h>
+#include <c10/macros/Macros.h>
+#include <c10/util/Exception.h>
+#include <cuda.h>
+#include <cuda_fp16.h>
+#include <cuda_runtime.h>
+#include <torch/all.h>
+
+#include <cstddef>
+#include <cstdint>
+#include <optional>
+
+namespace {
+
+constexpr int TopK = 2048;
+constexpr int kThreadsPerBlock = 1024;
+constexpr size_t kSmem = 32 * 1024 * sizeof(uint32_t);  // 128KB
+
+struct FastTopKParams {
+  const float* __restrict__ input;  // [B, input_stride]
+  int32_t* __restrict__ indices;    // [B, TopK]
+  int32_t* __restrict__ lengths;    // [B]
+  int64_t input_stride;
+};
+
+// when length <= TopK, we can directly write the indices
+__device__ void naive_topk_cuda(const float* __restrict__ score, int32_t* __restrict__ indice, int32_t length) {
+  const auto tid = threadIdx.x;
+  for (int i = tid; i < TopK; i += kThreadsPerBlock) {
+    indice[i] = (i < length) ? i : -1;
+  }
+}
+
+// keep the first `length` entries, set others to -1
+__device__ void naive_topk_transform(
+    const float* __restrict__ score,
+    int32_t length,
+    int32_t* __restrict__ dst_page_table,
+    const int32_t* __restrict__ src_page_table) {
+  const auto tid = threadIdx.x;
+  for (auto i = tid; i < TopK; i += kThreadsPerBlock) {
+    dst_page_table[i] = (i < length) ? src_page_table[i] : -1;
+  }
+}
+
+// map float -> order-preserving 8-bit key (via fp16 high byte)
+__device__ __forceinline__ auto convert_to_uint8(float x) -> uint8_t {
+  __half h = __float2half_rn(x);
+  uint16_t bits = __half_as_ushort(h);
+  uint16_t key = (bits & 0x8000) ? static_cast<uint16_t>(~bits) : static_cast<uint16_t>(bits | 0x8000);
+  return static_cast<uint8_t>(key >> 8);
+}
+
+// map float -> order-preserving 32-bit key
+__device__ __forceinline__ auto convert_to_uint32(float x) -> uint32_t {
+  uint32_t bits = __float_as_uint(x);
+  return (bits & 0x80000000u) ? ~bits : (bits | 0x80000000u);
+}
+
+__device__ void fast_topk_cuda_tl(const float* __restrict__ input, int* __restrict__ index, int length) {
+  // An optimized topk kernel copied from tilelang kernel
+  // We assume length > TopK here, or it will crash
+  int topk = TopK;
+  constexpr auto BLOCK_SIZE = 1024;
+  constexpr auto RADIX = 256;
+  constexpr auto SMEM_INPUT_SIZE = kSmem / (2 * sizeof(int));
+
+  alignas(128) __shared__ int s_histogram_buf[2][RADIX + 128];
+  alignas(128) __shared__ int s_counter;
+  alignas(128) __shared__ int s_threshold_bin_id;
+  alignas(128) __shared__ int s_num_input[2];
+
+  auto& s_histogram = s_histogram_buf[0];
+  // allocate for two rounds
+  extern __shared__ int s_input_idx[][SMEM_INPUT_SIZE];
+
+  const int tx = threadIdx.x;
+
+  // stage 1: 8bit coarse histogram
+  if (tx < RADIX + 1) s_histogram[tx] = 0;
+  __syncthreads();
+
+  for (int idx = tx; idx < length; idx += BLOCK_SIZE) {
+    const auto bin = convert_to_uint8(input[idx]);
+    ::atomicAdd(&s_histogram[bin], 1);
+  }
+  __syncthreads();
+
+  // suffix-sum over the histogram (ping-pong between the two buffers)
+  const auto run_cumsum = [&] {
+#pragma unroll 8
+    for (int i = 0; i < 8; ++i) {
+      static_assert(1 << 8 == RADIX);
+      if (C10_LIKELY(tx < RADIX)) {
+        const auto j = 1 << i;
+        const auto k = i & 1;
+        auto value = s_histogram_buf[k][tx];
+        if (tx < RADIX - j) {
+          value += s_histogram_buf[k][tx + j];
+        }
+        s_histogram_buf[k ^ 1][tx] = value;
+      }
+      __syncthreads();
+    }
+  };
+
+  run_cumsum();
+  if (tx < RADIX && s_histogram[tx] > topk && s_histogram[tx + 1] <= topk) {
+    s_threshold_bin_id = tx;
+    s_num_input[0] = 0;
+    s_counter = 0;
+  }
+  __syncthreads();
+
+  const auto threshold_bin = s_threshold_bin_id;
+  topk -= s_histogram[threshold_bin + 1];
+
+  if (topk == 0) {
+    // everything above the threshold bin is exactly the top-k
+    for (int idx = tx; idx < length; idx += BLOCK_SIZE) {
+      const auto bin = static_cast<int>(convert_to_uint8(input[idx]));
+      if (bin > threshold_bin) {
+        const auto pos = ::atomicAdd(&s_counter, 1);
+        index[pos] = idx;
+      }
+    }
+    __syncthreads();
+    return;
+  } else {
+    __syncthreads();
+    if (tx < RADIX + 1) {
+      s_histogram[tx] = 0;
+    }
+    __syncthreads();
+
+    for (int idx = tx; idx < length; idx += BLOCK_SIZE) {
+      const auto raw_input = input[idx];
+      const auto bin = static_cast<int>(convert_to_uint8(raw_input));
+      if (bin > threshold_bin) {
+        const auto pos = ::atomicAdd(&s_counter, 1);
+        index[pos] = idx;
+      } else if (bin == threshold_bin) {
+        const auto pos = ::atomicAdd(&s_num_input[0], 1);
+        /// NOTE: (dark) fuse the histogram computation here
+        if (C10_LIKELY(pos < SMEM_INPUT_SIZE)) {
+          s_input_idx[0][pos] = idx;
+          const auto bin = convert_to_uint32(raw_input);
+          const auto sub_bin = (bin >> 24) & 0xFF;
+          ::atomicAdd(&s_histogram[sub_bin], 1);
+        }
+      }
+    }
+    __syncthreads();
+  }
+
+  // stage 2: refine with 8bit radix passes
+#pragma unroll 4
+  for (int round = 0; round < 4; ++round) {
+    __shared__ int s_last_remain;
+    const auto r_idx = round % 2;
+
+    // clip here to prevent overflow
+    const auto _raw_num_input = s_num_input[r_idx];
+    const auto num_input = (_raw_num_input < int(SMEM_INPUT_SIZE)) ? _raw_num_input : int(SMEM_INPUT_SIZE);
+
+    run_cumsum();
+    if (tx < RADIX && s_histogram[tx] > topk && s_histogram[tx + 1] <= topk) {
+      s_threshold_bin_id = tx;
+      s_num_input[r_idx ^ 1] = 0;
+      s_last_remain = topk - s_histogram[tx + 1];
+    }
+    __syncthreads();
+
+    const auto threshold_bin = s_threshold_bin_id;
+    topk -= s_histogram[threshold_bin + 1];
+
+    if (topk == 0) {
+      for (int i = tx; i < num_input; i += BLOCK_SIZE) {
+        const auto idx = s_input_idx[r_idx][i];
+        const auto offset = 24 - round * 8;
+        const auto bin = (convert_to_uint32(input[idx]) >> offset) & 0xFF;
+        if (bin > threshold_bin) {
+          const auto pos = ::atomicAdd(&s_counter, 1);
+          index[pos] = idx;
+        }
+      }
+      __syncthreads();
+      break;
+    } else {
+      __syncthreads();
+      if (tx < RADIX + 1) {
+        s_histogram[tx] = 0;
+      }
+      __syncthreads();
+      for (int i = tx; i < num_input; i += BLOCK_SIZE) {
+        const auto idx = s_input_idx[r_idx][i];
+        const auto raw_input = input[idx];
+        const auto offset = 24 - round * 8;
+        const auto bin = (convert_to_uint32(raw_input) >> offset) & 0xFF;
+        if (bin > threshold_bin) {
+          const auto pos = ::atomicAdd(&s_counter, 1);
+          index[pos] = idx;
+        } else if (bin == threshold_bin) {
+          if (round == 3) {
+            // last round: take arbitrary `s_last_remain` ties from this bin
+            const auto pos = ::atomicAdd(&s_last_remain, -1);
+            if (pos > 0) {
+              index[TopK - pos] = idx;
+            }
+          } else {
+            const auto pos = ::atomicAdd(&s_num_input[r_idx ^ 1], 1);
+            if (C10_LIKELY(pos < SMEM_INPUT_SIZE)) {
+              /// NOTE: (dark) fuse the histogram computation here
+              s_input_idx[r_idx ^ 1][pos] = idx;
+              const auto bin = convert_to_uint32(raw_input);
+              const auto sub_bin = (bin >> (offset - 8)) & 0xFF;
+              ::atomicAdd(&s_histogram[sub_bin], 1);
+            }
+          }
+        }
+      }
+      __syncthreads();
+    }
+  }
+}
+
+__global__ __launch_bounds__(kThreadsPerBlock)  // topk
+    void topk_kernel(const FastTopKParams params) {
+  const auto& [input, indices, lengths, input_stride] = params;
+  const auto bid = static_cast<int64_t>(blockIdx.x);
+  const auto length = lengths[bid];
+  const auto indice = indices + bid * TopK;
+  const auto score = input + bid * input_stride;
+  if (length <= TopK) {
+    return naive_topk_cuda(score, indice, length);
+  } else {
+    return fast_topk_cuda_tl(score, indice, length);
+  }
+}
+
+__global__ __launch_bounds__(kThreadsPerBlock)  // decode
+    void topk_transform_decode_kernel(
+        const FastTopKParams params,
+        int32_t* __restrict__ dst_page_table,
+        const int32_t* __restrict__ src_page_table,
+        const int64_t src_stride) {
+  const auto& [input, _, lengths, input_stride] = params;
+  const auto bid = static_cast<int64_t>(blockIdx.x);
+  const auto tid = threadIdx.x;
+  const auto length = lengths[bid];
+  const auto src_page_entry = src_page_table + bid * src_stride;
+  const auto dst_page_entry = dst_page_table + bid * TopK;
+  const auto score = input + bid * input_stride;
+  if (length <= TopK) {
+    return naive_topk_transform(score, length, dst_page_entry, src_page_entry);
+  } else {
+    __shared__ int s_indices[TopK];
+    fast_topk_cuda_tl(score, s_indices, length);
+    // copy src[s_indices] to dst, we manually unroll here
+    static_assert(TopK % kThreadsPerBlock == 0);
+    static_assert(TopK / kThreadsPerBlock == 2);
+    const auto idx_0 = tid;
+    const auto pos_0 = s_indices[idx_0];
+    dst_page_entry[idx_0] = src_page_entry[pos_0];
+    const auto idx_1 = tid + kThreadsPerBlock;
+    const auto pos_1 = s_indices[idx_1];
+    dst_page_entry[idx_1] = src_page_entry[pos_1];
+  }
+}
+
+__global__ __launch_bounds__(kThreadsPerBlock)  // prefill
+    void topk_transform_prefill_kernel(
+        const FastTopKParams params,
+        int32_t* __restrict__ dst_page_table,
+        const int32_t* __restrict__ src_page_table,
+        const int64_t src_stride,
+        const int32_t* __restrict__ cu_seqlens_q,
+        const int64_t prefill_bs) {
+  const auto& [input, _, lengths, input_stride] = params;
+  const auto bid = static_cast<int64_t>(blockIdx.x);
+  const auto tid = threadIdx.x;
+  const auto length = lengths[bid];
+  const auto dst_page_entry = dst_page_table + bid * TopK;
+  const auto score = input + bid * input_stride;
+
+  /// NOTE: prefill bs is usually small, we can just use a simple loop here
+  /// We ensure that last cu_seqlens is equal to number of blocks launched
+  __shared__ const int32_t* s_src_page_entry;
+  if (C10_LIKELY(prefill_bs <= kThreadsPerBlock)) {
+    if (tid < prefill_bs) {
+      if (bid >= cu_seqlens_q[tid] && bid < cu_seqlens_q[tid + 1]) {
+        s_src_page_entry = src_page_table + tid * src_stride;
+      }
+    }
+  } else {
+    for (int64_t i = tid; i < prefill_bs; i += kThreadsPerBlock) {
+      if (bid >= cu_seqlens_q[i] && bid < cu_seqlens_q[i + 1]) {
+        s_src_page_entry = src_page_table + i * src_stride;
+      }
+    }
+  }
+  __syncthreads();
+  const auto src_page_entry = s_src_page_entry;
+
+  if (length <= TopK) {
+    return naive_topk_transform(score, length, dst_page_entry, src_page_entry);
+  } else {
+    __shared__ int s_indices[TopK];
+    fast_topk_cuda_tl(score, s_indices, length);
+    // copy src[s_indices] to dst, we manually unroll here
+    static_assert(TopK % kThreadsPerBlock == 0);
+    static_assert(TopK / kThreadsPerBlock == 2);
+    const auto idx_0 = tid;
+    const auto pos_0 = s_indices[idx_0];
+    dst_page_entry[idx_0] = src_page_entry[pos_0];
+    const auto idx_1 = tid + kThreadsPerBlock;
+    const auto pos_1 = s_indices[idx_1];
+    dst_page_entry[idx_1] = src_page_entry[pos_1];
+  }
+}
+
+auto get_params(at::Tensor score, at::Tensor lengths, std::optional<at::Tensor> indices_opt = std::nullopt)
+    -> FastTopKParams {
+  const auto B = score.size(0);
+  TORCH_CHECK(score.dim() == 2 && score.stride(1) == 1);
+  TORCH_CHECK(lengths.dim() == 1 && lengths.is_contiguous());
+  TORCH_CHECK(lengths.size(0) == B);
+  int32_t* indices_data_ptr = nullptr;
+  if (indices_opt.has_value()) {
+    const auto& indices = indices_opt.value();
+    TORCH_CHECK(indices.dim() == 2 && indices.is_contiguous());
+    TORCH_CHECK(indices.size(0) == B);
+    TORCH_CHECK(indices.size(1) == TopK);
+    indices_data_ptr = indices.data_ptr<int32_t>();
+  }
+
+  return FastTopKParams{
+      .input = score.data_ptr<float>(),
+      .indices = indices_data_ptr,
+      .lengths = lengths.data_ptr<int32_t>(),
+      .input_stride = score.stride(0),
+  };
+}
+
+// raise the kernel's dynamic-smem limit once per process (cached in a static)
+template <auto f, size_t max_dynamic_smem>
+void setup_kernel_smem_once() {
+  [[maybe_unused]]
+  static const auto result =
+      [] { return ::cudaFuncSetAttribute(f, ::cudaFuncAttributeMaxDynamicSharedMemorySize, max_dynamic_smem); }();
+  TORCH_CHECK(result == cudaSuccess, "set_up_kernel_once failed:", ::cudaGetErrorString(result));
+}
+
+}  // namespace
+
+#define CHECK_CUDA(x) TORCH_CHECK(x.is_cuda(), #x " must be a CUDA tensor")
+
+void fast_topk_interface(at::Tensor score, at::Tensor indices, at::Tensor lengths) {
+  CHECK_CUDA(score);
+  CHECK_CUDA(indices);
+  CHECK_CUDA(lengths);
+  const auto params = get_params(score, lengths, indices);
+  const auto B = score.size(0);
+  const auto stream = at::cuda::getCurrentCUDAStream().stream();
+  const auto grid = dim3{static_cast<uint32_t>(B)};
+  const auto block = dim3{kThreadsPerBlock};
+  setup_kernel_smem_once<topk_kernel, kSmem>();
+  topk_kernel<<<grid, block, kSmem, stream>>>(params);
+  const auto result = cudaGetLastError();
+  TORCH_CHECK(result == cudaSuccess, "topk kernel failed:", ::cudaGetErrorString(result));
+}
+
+void fast_topk_transform_interface(
+    at::Tensor score,
+    at::Tensor lengths,
+    at::Tensor dst_page_table,
+    at::Tensor src_page_table,
+    at::Tensor cu_seqlens_q) {
+  CHECK_CUDA(score);
+  CHECK_CUDA(lengths);
+  CHECK_CUDA(dst_page_table);
+  CHECK_CUDA(src_page_table);
+  CHECK_CUDA(cu_seqlens_q);
+  const auto params = get_params(score, lengths);
+  const auto B = score.size(0);
+  TORCH_CHECK(dst_page_table.dim() == 2 && dst_page_table.is_contiguous());
+  TORCH_CHECK(src_page_table.dim() == 2 && src_page_table.stride(1) == 1);
+  TORCH_CHECK(cu_seqlens_q.dim() == 1 && cu_seqlens_q.is_contiguous());
+  const auto prefill_bs = cu_seqlens_q.size(0) - 1;
+  TORCH_CHECK(dst_page_table.size(0) == B);
+  TORCH_CHECK(dst_page_table.size(1) == TopK);
+  TORCH_CHECK(src_page_table.size(0) == prefill_bs);
+  TORCH_CHECK(prefill_bs <= B);  // prefill_bs should be smaller than expanded bs
+
+  // launch kernel
+  const auto stream = at::cuda::getCurrentCUDAStream().stream();
+  const auto grid = dim3{static_cast<uint32_t>(B)};
+  const auto block = dim3{kThreadsPerBlock};
+  const auto src_stride = src_page_table.stride(0);
+
+  // dispatch to decode or prefill
+  const auto is_decode = (prefill_bs == B);
+  if (is_decode) {
+    setup_kernel_smem_once<topk_transform_decode_kernel, kSmem>();
+    topk_transform_decode_kernel<<<grid, block, kSmem, stream>>>(
+        params, dst_page_table.data_ptr<int32_t>(), src_page_table.data_ptr<int32_t>(), src_stride);
+  } else {
+    setup_kernel_smem_once<topk_transform_prefill_kernel, kSmem>();
+    topk_transform_prefill_kernel<<<grid, block, kSmem, stream>>>(
+        params,
+        dst_page_table.data_ptr<int32_t>(),
+        src_page_table.data_ptr<int32_t>(),
+        src_stride,
+        cu_seqlens_q.data_ptr<int32_t>(),
+        prefill_bs);
+  }
+
+  const auto result = cudaGetLastError();
+  TORCH_CHECK(result == cudaSuccess, "topk kernel failed:", ::cudaGetErrorString(result));
+}
diff --git a/sgl-kernel/include/sgl_kernel_ops.h b/sgl-kernel/include/sgl_kernel_ops.h
index fdaba4c93..d316e4248 100644
--- a/sgl-kernel/include/sgl_kernel_ops.h
+++ b/sgl-kernel/include/sgl_kernel_ops.h
@@ -174,6 +174,14 @@ void copy_to_gpu_no_ce(const at::Tensor& input, at::Tensor& output);
 void concat_mla_k(torch::Tensor k, torch::Tensor k_nope, torch::Tensor k_rope);
 void concat_mla_absorb_q(at::Tensor a, at::Tensor b, at::Tensor out);
 
+void fast_topk_interface(at::Tensor score, at::Tensor indices, at::Tensor lengths);
+void fast_topk_transform_interface(
+    at::Tensor score,
+    at::Tensor lengths,
+    at::Tensor dst_page_table,
+    at::Tensor src_page_table,
+    at::Tensor cu_seqlens_q);
+
 #ifdef USE_ROCM
 void gelu_quick(at::Tensor& out, const at::Tensor& input);
 #endif
diff --git a/sgl-kernel/python/sgl_kernel/__init__.py b/sgl-kernel/python/sgl_kernel/__init__.py
index d077fc3fb..a53b02567 100644
--- a/sgl-kernel/python/sgl_kernel/__init__.py
+++ b/sgl-kernel/python/sgl_kernel/__init__.py
@@ -309,7 +309,7 @@ from sgl_kernel.speculative import (
     tree_speculative_sampling_target_only,
     verify_tree_greedy,
 )
-from sgl_kernel.top_k import fast_topk
+from sgl_kernel.top_k import fast_topk, fast_topk_transform_fused, fast_topk_v2
 from sgl_kernel.version import __version__
 
 if torch.version.hip is not None:
diff --git a/sgl-kernel/python/sgl_kernel/top_k.py b/sgl-kernel/python/sgl_kernel/top_k.py
index fc29a6db8..afc67b9d8 100644
--- a/sgl-kernel/python/sgl_kernel/top_k.py
+++ b/sgl-kernel/python/sgl_kernel/top_k.py
@@ -9,3 +9,32 @@ def fast_topk(values, topk, dim):
     # Use topk for efficiency with larger k values
     # TODO: implement faster cuda kernels for large vocab sizes
     return torch.topk(values, topk, dim=dim)
+
+
+def fast_topk_v2(score: torch.Tensor, lengths: torch.Tensor, topk: int) -> torch.Tensor:
+    assert (
+        topk == 2048
+    ), "fast_topk_v2 is only optimized for deepseek v3.2 model, where topk=2048"
+    assert score.dim() == 2
+    topk_indices = score.new_empty((score.size(0), topk), dtype=torch.int32)
+    torch.ops.sgl_kernel.fast_topk(score, topk_indices, lengths)
+    return topk_indices
+
+
+def fast_topk_transform_fused(
+    score: torch.Tensor,
+    lengths: torch.Tensor,
+    page_table_size_1: torch.Tensor,  # NOTE: page size should be 1
+    cu_seqlens_q: torch.Tensor,
+    topk: int,
+) -> torch.Tensor:
+    assert (
+        topk == 2048
+    ), "fast_topk_transform_fused is only optimized for deepseek v3.2 model, where topk=2048"
+    assert score.dim() == 2
+    src_page_table = page_table_size_1
+    dst_page_table = score.new_empty((score.size(0), topk), dtype=torch.int32)
+    torch.ops.sgl_kernel.fast_topk_transform_fused(
+        score, lengths, dst_page_table, src_page_table, cu_seqlens_q
+    )
+    return dst_page_table
diff --git a/sgl-kernel/tests/test_topk.py b/sgl-kernel/tests/test_topk.py
new file mode 100644
index 000000000..8bbea4c62
--- /dev/null
+++ b/sgl-kernel/tests/test_topk.py
@@ -0,0 +1,120 @@
+import pytest
+import torch
+from sgl_kernel import fast_topk_transform_fused, fast_topk_v2
+
+
+def _ref_torch_impl(score: torch.Tensor, seq_len: int, topk: int) -> torch.Tensor:
+    assert score.dim() == 2
+    return torch.topk(score[:, :seq_len], topk, dim=-1, sorted=False).indices
+
+
+def _ref_torch_transform_decode_impl(
+    score: torch.Tensor,
+    seq_len: int,
+    src_page_table: torch.Tensor,
+    topk: int,
+) -> torch.Tensor:
+    batch_size, _ = score.shape
+    assert score.shape[0] == src_page_table.shape[0]
+    assert seq_len >= topk
+    indices = _ref_torch_impl(score, seq_len, topk)
+    topk_indices = torch.empty(
+        (batch_size, topk), dtype=torch.int32, device=score.device
+    )
+    for i in range(batch_size):
+        topk_indices[i] = src_page_table[i, indices[i]]
+    return topk_indices
+
+
+MAX_SEQ_LEN = 131072
+MAX_PERMIT_ERROR = 0
+
+
+def assert_equal(
+    score: torch.Tensor,
+    indices_ref: torch.Tensor,
+    indices_our: torch.Tensor,
+    bs: int,
+    k: int,
+    seq_len: int,
+    max_err: int = MAX_PERMIT_ERROR,
+):
+    indices_our_cpu = indices_our.cpu().tolist()
+    indices_ref_cpu = indices_ref.cpu().tolist()
+    for i in range(bs):
+        indices_ref_set_i = set(indices_ref_cpu[i])
+        indices_our_set_i = set(indices_our_cpu[i])
+        more = indices_our_set_i - indices_ref_set_i
+        less = indices_ref_set_i - indices_our_set_i
+        if len(more) > max_err or len(less) > max_err:
+            # check whether more values are the same with less values
+            # if so, either one is acceptable, since their values are the same
+            more_values = sorted(score[i, idx].item() for idx in more)
+            less_values = sorted(score[i, idx].item() for idx in less)
+            assert (
+                more_values == less_values
+            ), f"{bs=}, {k=}, {seq_len=}, {i=}, {more=}, {less=} failed, with {more_values=}, {less_values=}"
+
+
+@pytest.mark.parametrize("bs", [1, 132, 256, 4096])
+@pytest.mark.parametrize("k", [2048])  # we only support 2048 now
+@pytest.mark.parametrize("seq_len", [2048, 4096, 16384, 65536])
+@torch.inference_mode()
+def test_topk_kernel(bs: int, k: int, seq_len: int) -> None:
+    torch.manual_seed(42)
+
+    stream = torch.cuda.Stream()
+    torch.cuda.set_stream(stream)
+    score = torch.randn(bs, MAX_SEQ_LEN, dtype=torch.float32, device="cuda")
+    lengths = torch.full((bs,), seq_len, dtype=torch.int32, device="cuda")
+
+    indices_ref = _ref_torch_impl(score, seq_len, k)
+    indices_our = fast_topk_v2(score, lengths, k)
+
+    # sort and compare
+    indices_ref = torch.sort(indices_ref, dim=-1).values
+    indices_our = torch.sort(indices_our, dim=-1).values
+
+    assert_equal(score, indices_ref, indices_our, bs, k, seq_len)
+
+
+@pytest.mark.parametrize("bs", [1, 132, 256, 4096])
+@pytest.mark.parametrize("k", [2048])  # we only support 2048 now
+@pytest.mark.parametrize("seq_len", [2048, 4096, 16384, 65536])
+@torch.inference_mode()
+def test_topk_transform_kernel(bs: int, k: int, seq_len: int) -> None:
+    # TODO(dark): test prefill kernel, though nothing special
+    torch.manual_seed(42)
+
+    stream = torch.cuda.Stream()
+    torch.cuda.set_stream(stream)
+    score = torch.randn(bs, MAX_SEQ_LEN, dtype=torch.float32, device="cuda")
+    lengths = torch.full((bs,), seq_len, dtype=torch.int32, device="cuda")
+    src_page_table = torch.arange(0, seq_len, dtype=torch.int32, device="cuda")
+    src_page_table = src_page_table.unsqueeze(0).expand(bs, -1)
+    # NOTE: for decode, cumulative seqlens_q is just 0..=bs
+    # NOTE: since page table is arange, they equal topk indices
+    cu_seqlens_q = torch.arange(0, bs + 1, dtype=torch.int32, device="cuda")
+    dst_page_table_ref = _ref_torch_transform_decode_impl(
+        score=score,
+        seq_len=seq_len,
+        src_page_table=src_page_table,
+        topk=k,
+    )
+    dst_page_table_our = fast_topk_transform_fused(
+        score=score,
+        lengths=lengths,
+        page_table_size_1=src_page_table,
+        cu_seqlens_q=cu_seqlens_q,
+        topk=k,
+    )
+
+    # sort and compare
+    dst_page_table_our = torch.sort(dst_page_table_our, dim=-1).values
+    dst_page_table_ref = torch.sort(dst_page_table_ref, dim=-1).values
+
+    # the fused transform kernel may legitimately differ by one tied entry
+    assert_equal(score, dst_page_table_ref, dst_page_table_our, bs, k, seq_len, max_err=1)
+
+
+if __name__ == "__main__":
+    pytest.main([__file__])