[DeepseekV32] Add fast_topk_transform_ragged_fused kernel (#11815)
Signed-off-by: Hao Lu <14827759+hlu1@users.noreply.github.com>
This commit is contained in:
@@ -113,6 +113,10 @@ TORCH_LIBRARY_FRAGMENT(sgl_kernel, m) {
|
|||||||
"fast_topk_transform_fused(Tensor score, Tensor lengths, Tensor dst_page_table, Tensor src_page_table, Tensor "
|
"fast_topk_transform_fused(Tensor score, Tensor lengths, Tensor dst_page_table, Tensor src_page_table, Tensor "
|
||||||
"cu_seqlens_q) -> ()");
|
"cu_seqlens_q) -> ()");
|
||||||
m.impl("fast_topk_transform_fused", torch::kCUDA, &fast_topk_transform_interface);
|
m.impl("fast_topk_transform_fused", torch::kCUDA, &fast_topk_transform_interface);
|
||||||
|
m.def(
|
||||||
|
"fast_topk_transform_ragged_fused(Tensor score, Tensor lengths, Tensor topk_indices_ragged, Tensor "
|
||||||
|
"topk_indices_offset) -> ()");
|
||||||
|
m.impl("fast_topk_transform_ragged_fused", torch::kCUDA, &fast_topk_transform_ragged_interface);
|
||||||
|
|
||||||
/*
|
/*
|
||||||
* From gguf quantization
|
* From gguf quantization
|
||||||
|
|||||||
@@ -51,6 +51,15 @@ __device__ void naive_topk_transform(
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Trivial path used when a row has no more than TopK candidates: slot `i` of the
// output receives `i + offset` for the first `length` slots and -1 for the rest.
//
// score:               not read here; every in-range position is kept as-is.
// length:              number of leading slots that receive a real index.
// topk_indices_ragged: output row; TopK entries are written.
// offset:              value added to every kept position.
__device__ void naive_topk_transform_ragged(
    const float* __restrict__ score, int32_t length, int32_t* __restrict__ topk_indices_ragged, int32_t offset) {
  // Threads of the block walk the TopK output slots in steps of kThreadsPerBlock.
  for (auto slot = threadIdx.x; slot < TopK; slot += kThreadsPerBlock) {
    int32_t value = -1;  // padding marker for slots at or beyond `length`
    if (slot < length) {
      value = static_cast<int32_t>(slot) + offset;
    }
    topk_indices_ragged[slot] = value;
  }
}
|
||||||
|
|
||||||
__device__ __forceinline__ auto convert_to_uint8(float x) -> uint8_t {
|
__device__ __forceinline__ auto convert_to_uint8(float x) -> uint8_t {
|
||||||
__half h = __float2half_rn(x);
|
__half h = __float2half_rn(x);
|
||||||
uint16_t bits = __half_as_ushort(h);
|
uint16_t bits = __half_as_ushort(h);
|
||||||
@@ -322,8 +331,40 @@ __global__ __launch_bounds__(kThreadsPerBlock) // prefill
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
auto get_params(at::Tensor score, at::Tensor lengths, std::optional<at::Tensor> indices_opt = std::nullopt)
|
__global__ __launch_bounds__(kThreadsPerBlock) // prefill, ragged kv
|
||||||
-> FastTopKParams {
|
void topk_transform_prefill_ragged_kernel(
    const FastTopKParams params,
    int32_t* __restrict__ topk_indices_ragged,
    const int32_t* __restrict__ topk_indices_offset) {
  // Each block processes the row selected by blockIdx.x: it reads that row's
  // length and offset, and writes TopK int32 entries into the matching row of
  // `topk_indices_ragged`. Every written position is shifted by the row offset.
  const auto& [input, _, lengths, input_stride] = params;
  const auto row = static_cast<uint64_t>(blockIdx.x);
  const auto row_length = lengths[row];
  const auto row_offset = topk_indices_offset[row];
  const auto row_score = input + row * input_stride;
  const auto row_out = topk_indices_ragged + row * TopK;

  // Short rows: no selection is needed, the helper emits 0..length-1 (+offset)
  // followed by -1 padding.
  if (row_length <= TopK) {
    naive_topk_transform_ragged(row_score, row_length, row_out, row_offset);
    return;
  }

  // Long rows: fast_topk_cuda_tl fills `s_indices`
  // (presumably with the TopK selected positions of this row; see its definition).
  __shared__ int s_indices[TopK];
  fast_topk_cuda_tl(row_score, s_indices, row_length);

  // Each thread copies exactly two entries; the asserts pin that ratio.
  static_assert(TopK % kThreadsPerBlock == 0);
  static_assert(TopK / kThreadsPerBlock == 2);
  const auto first = threadIdx.x;
  const auto second = first + kThreadsPerBlock;
  row_out[first] = s_indices[first] + row_offset;
  row_out[second] = s_indices[second] + row_offset;
}
|
||||||
|
|
||||||
|
auto get_params(
|
||||||
|
const at::Tensor& score,
|
||||||
|
const at::Tensor& lengths,
|
||||||
|
std::optional<at::Tensor> indices_opt = std::nullopt) -> FastTopKParams {
|
||||||
const auto B = score.size(0);
|
const auto B = score.size(0);
|
||||||
TORCH_CHECK(score.dim() == 2 && score.stride(1) == 1);
|
TORCH_CHECK(score.dim() == 2 && score.stride(1) == 1);
|
||||||
TORCH_CHECK(lengths.dim() == 1 && lengths.is_contiguous());
|
TORCH_CHECK(lengths.dim() == 1 && lengths.is_contiguous());
|
||||||
@@ -357,7 +398,7 @@ void setup_kernel_smem_once() {
|
|||||||
|
|
||||||
#define CHECK_CUDA(x) TORCH_CHECK(x.is_cuda(), #x " must be a CUDA tensor")
|
#define CHECK_CUDA(x) TORCH_CHECK(x.is_cuda(), #x " must be a CUDA tensor")
|
||||||
|
|
||||||
void fast_topk_interface(at::Tensor score, at::Tensor indices, at::Tensor lengths) {
|
void fast_topk_interface(const at::Tensor& score, at::Tensor& indices, const at::Tensor& lengths) {
|
||||||
CHECK_CUDA(score);
|
CHECK_CUDA(score);
|
||||||
CHECK_CUDA(indices);
|
CHECK_CUDA(indices);
|
||||||
CHECK_CUDA(lengths);
|
CHECK_CUDA(lengths);
|
||||||
@@ -373,11 +414,11 @@ void fast_topk_interface(at::Tensor score, at::Tensor indices, at::Tensor length
|
|||||||
}
|
}
|
||||||
|
|
||||||
void fast_topk_transform_interface(
|
void fast_topk_transform_interface(
|
||||||
at::Tensor score,
|
const at::Tensor& score,
|
||||||
at::Tensor lengths,
|
const at::Tensor& lengths,
|
||||||
at::Tensor dst_page_table,
|
at::Tensor& dst_page_table,
|
||||||
at::Tensor src_page_table,
|
const at::Tensor& src_page_table,
|
||||||
at::Tensor cu_seqlens_q) {
|
const at::Tensor& cu_seqlens_q) {
|
||||||
CHECK_CUDA(score);
|
CHECK_CUDA(score);
|
||||||
CHECK_CUDA(lengths);
|
CHECK_CUDA(lengths);
|
||||||
CHECK_CUDA(dst_page_table);
|
CHECK_CUDA(dst_page_table);
|
||||||
@@ -420,3 +461,35 @@ void fast_topk_transform_interface(
|
|||||||
const auto result = cudaGetLastError();
|
const auto result = cudaGetLastError();
|
||||||
TORCH_CHECK(result == cudaSuccess, "topk kernel failed:", ::cudaGetErrorString(result));
|
TORCH_CHECK(result == cudaSuccess, "topk kernel failed:", ::cudaGetErrorString(result));
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Host entry point for the ragged top-k transform.
//
// Launches `topk_transform_prefill_ragged_kernel` with one block per row of
// `score` on the current CUDA stream.
//
// score:               CUDA tensor; validated by get_params (2-D, unit stride
//                      in the last dimension).
// lengths:             CUDA tensor; validated by get_params (1-D, contiguous).
// topk_indices_ragged: CUDA int32 output of shape (B, TopK), contiguous;
//                      written in place.
// topk_indices_offset: CUDA int32 tensor of shape (B,), contiguous; entry `b`
//                      is added to every index emitted for row `b`.
//
// Raises (via TORCH_CHECK) on any shape/layout violation or launch failure.
void fast_topk_transform_ragged_interface(
    const at::Tensor& score,
    const at::Tensor& lengths,
    at::Tensor& topk_indices_ragged,
    const at::Tensor& topk_indices_offset) {
  CHECK_CUDA(score);
  CHECK_CUDA(lengths);
  CHECK_CUDA(topk_indices_ragged);
  CHECK_CUDA(topk_indices_offset);

  const auto params = get_params(score, lengths);
  const auto B = score.size(0);
  TORCH_CHECK(topk_indices_ragged.dim() == 2 && topk_indices_ragged.is_contiguous());
  // The kernel indexes the offsets as a dense array (`topk_indices_offset[bid]`),
  // so a strided 1-D view would be read at the wrong addresses. Require
  // contiguity, the same way `lengths` is validated.
  TORCH_CHECK(
      topk_indices_offset.dim() == 1 && topk_indices_offset.is_contiguous(),
      "topk_indices_offset must be a contiguous 1-D tensor");

  TORCH_CHECK(topk_indices_ragged.size(0) == B);
  TORCH_CHECK(topk_indices_ragged.size(1) == TopK);
  TORCH_CHECK(topk_indices_offset.size(0) == B);

  // launch kernel
  const auto stream = at::cuda::getCurrentCUDAStream().stream();
  const auto grid = dim3{static_cast<uint32_t>(B)};
  const auto block = dim3{kThreadsPerBlock};

  setup_kernel_smem_once<topk_transform_prefill_ragged_kernel, kSmem>();
  // data_ptr<int32_t>() also rejects tensors whose dtype is not int32.
  topk_transform_prefill_ragged_kernel<<<grid, block, kSmem, stream>>>(
      params, topk_indices_ragged.data_ptr<int32_t>(), topk_indices_offset.data_ptr<int32_t>());

  // Only launch-time errors are visible here; the kernel itself runs asynchronously.
  const auto result = cudaGetLastError();
  TORCH_CHECK(result == cudaSuccess, "topk kernel failed:", ::cudaGetErrorString(result));
}
|
||||||
|
|||||||
@@ -174,13 +174,18 @@ void copy_to_gpu_no_ce(const at::Tensor& input, at::Tensor& output);
|
|||||||
void concat_mla_k(torch::Tensor k, torch::Tensor k_nope, torch::Tensor k_rope);
|
void concat_mla_k(torch::Tensor k, torch::Tensor k_nope, torch::Tensor k_rope);
|
||||||
void concat_mla_absorb_q(at::Tensor a, at::Tensor b, at::Tensor out);
|
void concat_mla_absorb_q(at::Tensor a, at::Tensor b, at::Tensor out);
|
||||||
|
|
||||||
void fast_topk_interface(at::Tensor score, at::Tensor indices, at::Tensor lengths);
|
void fast_topk_interface(const at::Tensor& score, at::Tensor& indices, const at::Tensor& lengths);
|
||||||
void fast_topk_transform_interface(
|
void fast_topk_transform_interface(
|
||||||
at::Tensor score,
|
const at::Tensor& score,
|
||||||
at::Tensor lengths,
|
const at::Tensor& lengths,
|
||||||
at::Tensor dst_page_table,
|
at::Tensor& dst_page_table,
|
||||||
at::Tensor src_page_table,
|
const at::Tensor& src_page_table,
|
||||||
at::Tensor cu_seqlens_q);
|
const at::Tensor& cu_seqlens_q);
|
||||||
|
void fast_topk_transform_ragged_interface(
|
||||||
|
const at::Tensor& score,
|
||||||
|
const at::Tensor& lengths,
|
||||||
|
at::Tensor& topk_indices_ragged,
|
||||||
|
const at::Tensor& topk_indices_offset);
|
||||||
|
|
||||||
#ifdef USE_ROCM
|
#ifdef USE_ROCM
|
||||||
void gelu_quick(at::Tensor& out, const at::Tensor& input);
|
void gelu_quick(at::Tensor& out, const at::Tensor& input);
|
||||||
|
|||||||
@@ -327,7 +327,12 @@ from sgl_kernel.speculative import (
|
|||||||
tree_speculative_sampling_target_only,
|
tree_speculative_sampling_target_only,
|
||||||
verify_tree_greedy,
|
verify_tree_greedy,
|
||||||
)
|
)
|
||||||
from sgl_kernel.top_k import fast_topk, fast_topk_transform_fused, fast_topk_v2
|
from sgl_kernel.top_k import (
|
||||||
|
fast_topk,
|
||||||
|
fast_topk_transform_fused,
|
||||||
|
fast_topk_transform_ragged_fused,
|
||||||
|
fast_topk_v2,
|
||||||
|
)
|
||||||
from sgl_kernel.version import __version__
|
from sgl_kernel.version import __version__
|
||||||
|
|
||||||
if torch.version.hip is not None:
|
if torch.version.hip is not None:
|
||||||
|
|||||||
@@ -28,13 +28,36 @@ def fast_topk_transform_fused(
|
|||||||
cu_seqlens_q: torch.Tensor,
|
cu_seqlens_q: torch.Tensor,
|
||||||
topk: int,
|
topk: int,
|
||||||
) -> torch.Tensor:
|
) -> torch.Tensor:
|
||||||
|
"""
|
||||||
|
Transform topk indices to indices to the page table (page_size = 1)
|
||||||
|
"""
|
||||||
assert (
|
assert (
|
||||||
topk == 2048
|
topk == 2048
|
||||||
), "fast_topk_transform_fused is only optimized for deepseek v3.2 model, where topk=2048"
|
), "fast_topk_transform_fused is only optimized for deepseek v3.2 model, where topk=2048"
|
||||||
assert score.dim() == 2
|
assert score.dim() == 2
|
||||||
src_page_table = page_table_size_1
|
src_page_table = page_table_size_1
|
||||||
dst_page_table = score.new_empty((score.size(0), topk), dtype=torch.int32)
|
dst_page_table = score.new_empty((score.shape[0], topk), dtype=torch.int32)
|
||||||
torch.ops.sgl_kernel.fast_topk_transform_fused(
|
torch.ops.sgl_kernel.fast_topk_transform_fused(
|
||||||
score, lengths, dst_page_table, src_page_table, cu_seqlens_q
|
score, lengths, dst_page_table, src_page_table, cu_seqlens_q
|
||||||
)
|
)
|
||||||
return dst_page_table
|
return dst_page_table
|
||||||
|
|
||||||
|
|
||||||
|
def fast_topk_transform_ragged_fused(
    score: torch.Tensor,
    lengths: torch.Tensor,
    topk_indices_offset: torch.Tensor,  # ragged kv
    topk: int,
) -> torch.Tensor:
    """
    Transform topk indices to indices to ragged kv (non-paged)

    Args:
        score: 2-D score tensor; one output row is produced per row of ``score``.
        lengths: per-row lengths forwarded to the kernel.
        topk_indices_offset: per-row offsets forwarded to the kernel.
        topk: number of indices per row; must be 2048.

    Returns:
        A new int32 tensor of shape ``(score.shape[0], topk)`` on the same
        device as ``score``, filled in place by the fused kernel.

    Raises:
        AssertionError: if ``topk != 2048`` or ``score`` is not 2-D.
    """
    # The message names this function (the original text referred to a
    # non-existent "fast_topk_transform_fused_ragged").
    assert (
        topk == 2048
    ), "fast_topk_transform_ragged_fused is only optimized for deepseek v3.2 model, where topk=2048"
    assert score.dim() == 2
    topk_indices_ragged = score.new_empty((score.shape[0], topk), dtype=torch.int32)
    torch.ops.sgl_kernel.fast_topk_transform_ragged_fused(
        score, lengths, topk_indices_ragged, topk_indices_offset
    )
    return topk_indices_ragged
|
||||||
|
|||||||
@@ -1,6 +1,12 @@
|
|||||||
|
from typing import Optional
|
||||||
|
|
||||||
import pytest
|
import pytest
|
||||||
import torch
|
import torch
|
||||||
from sgl_kernel import fast_topk_transform_fused, fast_topk_v2
|
from sgl_kernel import (
|
||||||
|
fast_topk_transform_fused,
|
||||||
|
fast_topk_transform_ragged_fused,
|
||||||
|
fast_topk_v2,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
def _ref_torch_impl(score: torch.Tensor, seq_len: int, topk: int) -> torch.Tensor:
|
def _ref_torch_impl(score: torch.Tensor, seq_len: int, topk: int) -> torch.Tensor:
|
||||||
@@ -26,6 +32,21 @@ def _ref_torch_transform_decode_impl(
|
|||||||
return topk_indices
|
return topk_indices
|
||||||
|
|
||||||
|
|
||||||
|
def _ref_torch_transform_ragged_impl(
    score: torch.Tensor,
    seq_len: int,
    topk_indices_offset: torch.Tensor,
    topk: int,
) -> torch.Tensor:
    """
    Reference implementation for the ragged transform.

    Takes the indices produced by ``_ref_torch_impl`` and adds each row's
    offset to them; entries equal to -1 are left as -1.
    """
    assert score.shape[0] == topk_indices_offset.shape[0]
    assert seq_len >= topk
    raw = _ref_torch_impl(score, seq_len, topk)

    # Broadcast one offset per row across that row's indices.
    shifted = raw + topk_indices_offset[:, None]
    # -1 marks an empty slot and must not be shifted.
    return torch.where(raw == -1, raw, shifted)
|
||||||
|
|
||||||
|
|
||||||
MAX_SEQ_LEN = 131072
|
MAX_SEQ_LEN = 131072
|
||||||
MAX_PERMIT_ERROR = 0
|
MAX_PERMIT_ERROR = 0
|
||||||
|
|
||||||
@@ -37,6 +58,7 @@ def assert_equal(
|
|||||||
bs: int,
|
bs: int,
|
||||||
k: int,
|
k: int,
|
||||||
seq_len: int,
|
seq_len: int,
|
||||||
|
topk_indices_offset: Optional[torch.Tensor] = None,
|
||||||
):
|
):
|
||||||
indices_our_cpu = indices_our.cpu().tolist()
|
indices_our_cpu = indices_our.cpu().tolist()
|
||||||
indices_ref_cpu = indices_ref.cpu().tolist()
|
indices_ref_cpu = indices_ref.cpu().tolist()
|
||||||
@@ -45,11 +67,13 @@ def assert_equal(
|
|||||||
indices_our_set_i = set(indices_our_cpu[i])
|
indices_our_set_i = set(indices_our_cpu[i])
|
||||||
more = indices_our_set_i - indices_ref_set_i
|
more = indices_our_set_i - indices_ref_set_i
|
||||||
less = indices_ref_set_i - indices_our_set_i
|
less = indices_ref_set_i - indices_our_set_i
|
||||||
if len(more) > MAX_PERMIT_ERROR or len(less) > MAX_PERMIT_ERROR:
|
offset = topk_indices_offset[i].item() if topk_indices_offset is not None else 0
|
||||||
|
if len(more) > 0 or len(less) > 0:
|
||||||
|
print(f"{bs=}, {k=}, {seq_len=}, {i=}, {more=}, {less=}")
|
||||||
# check whether more values are the same with less values
|
# check whether more values are the same with less values
|
||||||
# if so, either one is acceptable, since their values are the same
|
# if so, either one is acceptable, since their values are the same
|
||||||
more_values = sorted(score[i, idx].item() for idx in more)
|
more_values = sorted(score[i, idx - offset].item() for idx in more)
|
||||||
less_values = sorted(score[i, idx].item() for idx in less)
|
less_values = sorted(score[i, idx - offset].item() for idx in less)
|
||||||
assert (
|
assert (
|
||||||
more_values == less_values
|
more_values == less_values
|
||||||
), f"{bs=}, {k=}, {seq_len=}, {i=}, {more=}, {less=} failed, with {more_values=}, {less_values=}"
|
), f"{bs=}, {k=}, {seq_len=}, {i=}, {more=}, {less=} failed, with {more_values=}, {less_values=}"
|
||||||
@@ -116,5 +140,52 @@ def test_topk_transform_kernel(bs: int, k: int, seq_len: int) -> None:
|
|||||||
assert_equal(score, dst_page_table_ref, dst_page_table_our, bs, k, seq_len)
|
assert_equal(score, dst_page_table_ref, dst_page_table_our, bs, k, seq_len)
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.parametrize("bs", [1, 132, 256, 4096])
@pytest.mark.parametrize("k", [2048])  # we only support 2048 now
@pytest.mark.parametrize("seq_len", [2048, 4096, 16384, 65536])
@torch.inference_mode()
def test_topk_transform_ragged_kernel(bs: int, k: int, seq_len: int) -> None:
    """
    Compare fast_topk_transform_ragged_fused against the torch reference.

    Every row uses the same kv length ``seq_len`` and a random per-row offset.
    Both results are sorted before comparison, and ``assert_equal`` accepts
    differing indices only when the scores they point to are equal.
    """
    # TODO(dark): test prefill kernel, though nothing special
    # Note: no local MAX_PERMIT_ERROR here; assert_equal reads the module-level
    # constant, so a function-local assignment would have no effect.
    torch.manual_seed(42)

    stream = torch.cuda.Stream()
    torch.cuda.set_stream(stream)
    # bs: # of q tokens
    score = torch.randn(bs, MAX_SEQ_LEN, dtype=torch.float32, device="cuda")
    # kv_len
    lengths = torch.full((bs,), seq_len, dtype=torch.int32, device="cuda")
    topk_indices_offset = torch.randint(
        0, 1024, (bs,), dtype=torch.int32, device="cuda"
    )

    topk_indices_ref = _ref_torch_transform_ragged_impl(
        score=score,
        seq_len=seq_len,
        topk_indices_offset=topk_indices_offset,
        topk=k,
    )
    topk_indices_our = fast_topk_transform_ragged_fused(
        score=score,
        lengths=lengths,
        topk_indices_offset=topk_indices_offset,
        topk=k,
    )

    # sort and compare
    topk_indices_our = torch.sort(topk_indices_our, dim=-1).values
    topk_indices_ref = torch.sort(topk_indices_ref, dim=-1).values

    assert_equal(
        score,
        topk_indices_ref,
        topk_indices_our,
        bs,
        k,
        seq_len,
        topk_indices_offset,
    )
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
pytest.main([__file__])
|
pytest.main([__file__])
|
||||||
|
|||||||
Reference in New Issue
Block a user