[DeepseekV32] Add fast_topk_transform_ragged_fused kernel (#11815)

Signed-off-by: Hao Lu <14827759+hlu1@users.noreply.github.com>
This commit is contained in:
hlu1
2025-10-19 17:13:39 -07:00
committed by GitHub
parent 252dc4e112
commit 3b80232d06
6 changed files with 201 additions and 20 deletions

View File

@@ -327,7 +327,12 @@ from sgl_kernel.speculative import (
tree_speculative_sampling_target_only,
verify_tree_greedy,
)
from sgl_kernel.top_k import fast_topk, fast_topk_transform_fused, fast_topk_v2
from sgl_kernel.top_k import (
fast_topk,
fast_topk_transform_fused,
fast_topk_transform_ragged_fused,
fast_topk_v2,
)
from sgl_kernel.version import __version__
if torch.version.hip is not None:

View File

@@ -28,13 +28,36 @@ def fast_topk_transform_fused(
cu_seqlens_q: torch.Tensor,
topk: int,
) -> torch.Tensor:
"""
Transform topk indices to indices to the page table (page_size = 1)
"""
assert (
topk == 2048
), "fast_topk_transform_fused is only optimized for deepseek v3.2 model, where topk=2048"
assert score.dim() == 2
src_page_table = page_table_size_1
dst_page_table = score.new_empty((score.size(0), topk), dtype=torch.int32)
dst_page_table = score.new_empty((score.shape[0], topk), dtype=torch.int32)
torch.ops.sgl_kernel.fast_topk_transform_fused(
score, lengths, dst_page_table, src_page_table, cu_seqlens_q
)
return dst_page_table
def fast_topk_transform_ragged_fused(
score: torch.Tensor,
lengths: torch.Tensor,
topk_indices_offset: torch.Tensor, # ragged kv
topk: int,
) -> torch.Tensor:
"""
Transform topk indices to indices to ragged kv (non-paged)
"""
assert (
topk == 2048
), "fast_topk_transform_fused_ragged is only optimized for deepseek v3.2 model, where topk=2048"
assert score.dim() == 2
topk_indices_ragged = score.new_empty((score.shape[0], topk), dtype=torch.int32)
torch.ops.sgl_kernel.fast_topk_transform_ragged_fused(
score, lengths, topk_indices_ragged, topk_indices_offset
)
return topk_indices_ragged