[DeepseekV32] Add fast_topk_transform_ragged_fused kernel (#11815)
Signed-off-by: Hao Lu <14827759+hlu1@users.noreply.github.com>
This commit is contained in:
@@ -327,7 +327,12 @@ from sgl_kernel.speculative import (
|
||||
tree_speculative_sampling_target_only,
|
||||
verify_tree_greedy,
|
||||
)
|
||||
from sgl_kernel.top_k import fast_topk, fast_topk_transform_fused, fast_topk_v2
|
||||
from sgl_kernel.top_k import (
|
||||
fast_topk,
|
||||
fast_topk_transform_fused,
|
||||
fast_topk_transform_ragged_fused,
|
||||
fast_topk_v2,
|
||||
)
|
||||
from sgl_kernel.version import __version__
|
||||
|
||||
if torch.version.hip is not None:
|
||||
|
||||
@@ -28,13 +28,36 @@ def fast_topk_transform_fused(
|
||||
cu_seqlens_q: torch.Tensor,
|
||||
topk: int,
|
||||
) -> torch.Tensor:
|
||||
"""
|
||||
Transform topk indices to indices to the page table (page_size = 1)
|
||||
"""
|
||||
assert (
|
||||
topk == 2048
|
||||
), "fast_topk_transform_fused is only optimized for deepseek v3.2 model, where topk=2048"
|
||||
assert score.dim() == 2
|
||||
src_page_table = page_table_size_1
|
||||
dst_page_table = score.new_empty((score.size(0), topk), dtype=torch.int32)
|
||||
dst_page_table = score.new_empty((score.shape[0], topk), dtype=torch.int32)
|
||||
torch.ops.sgl_kernel.fast_topk_transform_fused(
|
||||
score, lengths, dst_page_table, src_page_table, cu_seqlens_q
|
||||
)
|
||||
return dst_page_table
|
||||
|
||||
|
||||
def fast_topk_transform_ragged_fused(
|
||||
score: torch.Tensor,
|
||||
lengths: torch.Tensor,
|
||||
topk_indices_offset: torch.Tensor, # ragged kv
|
||||
topk: int,
|
||||
) -> torch.Tensor:
|
||||
"""
|
||||
Transform topk indices to indices to ragged kv (non-paged)
|
||||
"""
|
||||
assert (
|
||||
topk == 2048
|
||||
), "fast_topk_transform_fused_ragged is only optimized for deepseek v3.2 model, where topk=2048"
|
||||
assert score.dim() == 2
|
||||
topk_indices_ragged = score.new_empty((score.shape[0], topk), dtype=torch.int32)
|
||||
torch.ops.sgl_kernel.fast_topk_transform_ragged_fused(
|
||||
score, lengths, topk_indices_ragged, topk_indices_offset
|
||||
)
|
||||
return topk_indices_ragged
|
||||
|
||||
Reference in New Issue
Block a user