Fix sampling for speculative decoding & simplify kernels (#7207)

This commit is contained in:
Lianmin Zheng
2025-06-16 03:28:30 -07:00
committed by GitHub
parent b1286a116a
commit cfceb83d05
11 changed files with 124 additions and 79 deletions

View File

@@ -72,6 +72,7 @@ from sgl_kernel.speculative import (
tree_speculative_sampling_target_only,
verify_tree_greedy,
)
from sgl_kernel.top_k import fast_topk
from sgl_kernel.version import __version__
build_tree_kernel = (

View File

@@ -11,6 +11,7 @@ def tree_speculative_sampling_target_only(
retrive_next_token: torch.Tensor,
retrive_next_sibling: torch.Tensor,
uniform_samples: torch.Tensor,
uniform_samples_for_final_sampling: torch.Tensor,
target_probs: torch.Tensor,
draft_probs: torch.Tensor,
threshold_single: float = 1.0,
@@ -26,6 +27,7 @@ def tree_speculative_sampling_target_only(
retrive_next_token,
retrive_next_sibling,
uniform_samples,
uniform_samples_for_final_sampling,
target_probs,
draft_probs,
threshold_single,
@@ -91,11 +93,13 @@ def segment_packbits(
input_indptr: torch.Tensor,
output_indptr: torch.Tensor,
y: torch.Tensor,
batch_size: int,
) -> None:
torch.ops.sgl_kernel.segment_packbits.default(
x,
input_indptr,
output_indptr,
y,
batch_size,
torch.cuda.current_stream().cuda_stream,
)

View File

@@ -0,0 +1,11 @@
import torch
def fast_topk(values, topk, dim):
    """Return the top-``topk`` (values, indices) of ``values`` along ``dim``.

    For the common k == 1 case a single ``torch.max`` reduction is used
    instead of a full ``torch.topk`` call; ``keepdim=True`` preserves the
    reduced dimension so both branches produce tensors of the same rank.
    """
    if topk != 1:
        # General case: fall back to torch.topk.
        # TODO: implement faster cuda kernels for large vocab sizes
        return torch.topk(values, topk, dim=dim)
    # k == 1: max yields both the value and its index in one pass.
    return torch.max(values, dim=dim, keepdim=True)