Fix sampling for speculative decoding & simplify kernels (#7207)
This commit is contained in:
@@ -72,6 +72,7 @@ from sgl_kernel.speculative import (
|
||||
tree_speculative_sampling_target_only,
|
||||
verify_tree_greedy,
|
||||
)
|
||||
from sgl_kernel.top_k import fast_topk
|
||||
from sgl_kernel.version import __version__
|
||||
|
||||
build_tree_kernel = (
|
||||
|
||||
@@ -11,6 +11,7 @@ def tree_speculative_sampling_target_only(
|
||||
retrive_next_token: torch.Tensor,
|
||||
retrive_next_sibling: torch.Tensor,
|
||||
uniform_samples: torch.Tensor,
|
||||
uniform_samples_for_final_sampling: torch.Tensor,
|
||||
target_probs: torch.Tensor,
|
||||
draft_probs: torch.Tensor,
|
||||
threshold_single: float = 1.0,
|
||||
@@ -26,6 +27,7 @@ def tree_speculative_sampling_target_only(
|
||||
retrive_next_token,
|
||||
retrive_next_sibling,
|
||||
uniform_samples,
|
||||
uniform_samples_for_final_sampling,
|
||||
target_probs,
|
||||
draft_probs,
|
||||
threshold_single,
|
||||
@@ -91,11 +93,13 @@ def segment_packbits(
|
||||
input_indptr: torch.Tensor,
|
||||
output_indptr: torch.Tensor,
|
||||
y: torch.Tensor,
|
||||
batch_size: int,
|
||||
) -> None:
|
||||
torch.ops.sgl_kernel.segment_packbits.default(
|
||||
x,
|
||||
input_indptr,
|
||||
output_indptr,
|
||||
y,
|
||||
batch_size,
|
||||
torch.cuda.current_stream().cuda_stream,
|
||||
)
|
||||
|
||||
11
sgl-kernel/python/sgl_kernel/top_k.py
Normal file
11
sgl-kernel/python/sgl_kernel/top_k.py
Normal file
@@ -0,0 +1,11 @@
|
||||
import torch
|
||||
|
||||
|
||||
def fast_topk(values, topk, dim):
    """Return the ``topk`` largest elements of ``values`` along ``dim``.

    Returns a ``(values, indices)`` pair, mirroring ``torch.topk``'s output
    in both branches.
    """
    if topk == 1:
        # torch.max with keepdim=True yields (values, indices) with the
        # reduced dim kept at size 1 — same shape contract as torch.topk.
        return torch.max(values, dim=dim, keepdim=True)
    # General case delegates to torch.topk.
    # TODO: implement faster cuda kernels for large vocab sizes
    return torch.topk(values, topk, dim=dim)
|
||||
Reference in New Issue
Block a user