[Feature] Add a fast-topk to sgl-kernel for DeepSeek v3.2 (#11194)

Co-authored-by: Baizhou Zhang <sobereddiezhang@gmail.com>
This commit is contained in:
DarkSharpness
2025-10-06 01:19:03 +08:00
committed by GitHub
parent 4cb5a5235e
commit e0b2d3eebe
7 changed files with 588 additions and 1 deletion

View File

@@ -309,7 +309,7 @@ from sgl_kernel.speculative import (
tree_speculative_sampling_target_only,
verify_tree_greedy,
)
from sgl_kernel.top_k import fast_topk
from sgl_kernel.top_k import fast_topk, fast_topk_transform_fused, fast_topk_v2
from sgl_kernel.version import __version__
if torch.version.hip is not None:

View File

@@ -9,3 +9,32 @@ def fast_topk(values, topk, dim):
# Use topk for efficiency with larger k values
# TODO: implement faster cuda kernels for large vocab sizes
return torch.topk(values, topk, dim=dim)
def fast_topk_v2(score: torch.Tensor, lengths: torch.Tensor, topk: int) -> torch.Tensor:
    """Select the top-k column indices per row of `score` via the fused sgl-kernel op.

    Args:
        score: 2-D score tensor, one row per query.
        lengths: per-row valid lengths consumed by the kernel
            (exact semantics defined by `sgl_kernel.fast_topk` — TODO confirm).
        topk: number of indices to select; the kernel is specialized for 2048.

    Returns:
        int32 tensor of shape ``(score.size(0), topk)`` holding the selected
        column indices for each row.

    Raises:
        ValueError: if `topk` is not 2048 or `score` is not 2-D.
    """
    # Explicit raises instead of asserts: validation must survive `python -O`,
    # which strips assert statements.
    if topk != 2048:
        raise ValueError(
            "fast_topk_v2 is only optimized for deepseek v3.2 model, where topk=2048"
        )
    if score.dim() != 2:
        raise ValueError(f"score must be 2-D, got {score.dim()}-D")
    # Preallocate the int32 output buffer; the kernel fills it in place.
    topk_indices = score.new_empty((score.size(0), topk), dtype=torch.int32)
    torch.ops.sgl_kernel.fast_topk(score, topk_indices, lengths)
    return topk_indices
def fast_topk_transform_fused(
    score: torch.Tensor,
    lengths: torch.Tensor,
    page_table_size_1: torch.Tensor,  # NOTE: page size should be 1
    cu_seqlens_q: torch.Tensor,
    topk: int,
) -> torch.Tensor:
    """Fused top-k selection plus page-table gather via the sgl-kernel op.

    Args:
        score: 2-D score tensor, one row per query.
        lengths: per-row valid lengths consumed by the kernel
            (exact semantics defined by the fused kernel — TODO confirm).
        page_table_size_1: source page table; page size must be 1.
        cu_seqlens_q: cumulative query sequence lengths
            (presumably standard varlen-attention layout — verify against caller).
        topk: number of entries to select; the kernel is specialized for 2048.

    Returns:
        int32 tensor of shape ``(score.size(0), topk)`` — the destination page
        table populated by the fused kernel.

    Raises:
        ValueError: if `topk` is not 2048 or `score` is not 2-D.
    """
    # Explicit raises instead of asserts: validation must survive `python -O`,
    # which strips assert statements.
    if topk != 2048:
        raise ValueError(
            "fast_topk_transform_fused is only optimized for deepseek v3.2 model, where topk=2048"
        )
    if score.dim() != 2:
        raise ValueError(f"score must be 2-D, got {score.dim()}-D")
    src_page_table = page_table_size_1
    # Preallocate the int32 destination page table; the kernel fills it in place.
    dst_page_table = score.new_empty((score.size(0), topk), dtype=torch.int32)
    torch.ops.sgl_kernel.fast_topk_transform_fused(
        score, lengths, dst_page_table, src_page_table, cu_seqlens_q
    )
    return dst_page_table