Add greedy verification kernel (#4383)

This commit is contained in:
Ying Sheng
2025-03-16 00:58:26 -07:00
committed by GitHub
parent 06d12b39d3
commit 52a34d7448
11 changed files with 394 additions and 153 deletions

View File

@@ -42,8 +42,13 @@ from sgl_kernel.sampling import (
top_p_sampling_from_probs,
)
from sgl_kernel.speculative import (
build_tree_kernel,
build_tree_kernel_efficient,
segment_packbits,
tree_speculative_sampling_target_only,
verify_tree_greedy,
)
from sgl_kernel.version import __version__
# TODO(ying): remove this after updating the sglang python code.
# Placeholder binding: the name stays defined on the package but is set to
# None (presumably so existing references in sglang keep resolving until that
# code is updated -- confirm before removing).
build_tree_kernel = None

View File

@@ -13,6 +13,8 @@ def tree_speculative_sampling_target_only(
uniform_samples: torch.Tensor,
target_probs: torch.Tensor,
draft_probs: torch.Tensor,
threshold_single: float = 1.0,
threshold_acc: float = 1.0,
deterministic: bool = True,
) -> None:
torch.ops.sgl_kernel.tree_speculative_sampling_target_only(
@@ -26,11 +28,36 @@ def tree_speculative_sampling_target_only(
uniform_samples,
target_probs,
draft_probs,
threshold_single,
threshold_acc,
deterministic,
get_cuda_stream(),
)
def verify_tree_greedy(
    predicts: torch.Tensor,  # mutable
    accept_index: torch.Tensor,  # mutable
    accept_token_num: torch.Tensor,  # mutable
    candidates: torch.Tensor,
    retrive_index: torch.Tensor,
    retrive_next_token: torch.Tensor,
    retrive_next_sibling: torch.Tensor,
    target_predict: torch.Tensor,
) -> None:
    """Dispatch to the ``sgl_kernel.verify_tree_greedy`` custom op.

    This is a thin Python wrapper: every tensor argument is forwarded
    unchanged, in declaration order, followed by the stream handle obtained
    from ``get_cuda_stream()``. Nothing is returned.

    The first three arguments (``predicts``, ``accept_index``,
    ``accept_token_num``) are marked "mutable" in the signature, so they are
    presumably written in place by the kernel -- confirm against the kernel
    implementation. Expected shapes and dtypes are not visible from this
    wrapper and must be checked against the op's registration.
    """
    # Resolve the op first, then the stream, matching the order in which a
    # direct call expression would evaluate them.
    greedy_verify_op = torch.ops.sgl_kernel.verify_tree_greedy
    cuda_stream = get_cuda_stream()

    greedy_verify_op(
        # Output buffers (marked mutable in the signature).
        predicts,
        accept_index,
        accept_token_num,
        # Inputs forwarded as-is.
        candidates,
        retrive_index,
        retrive_next_token,
        retrive_next_sibling,
        target_predict,
        # Stream handle goes last.
        cuda_stream,
    )
def build_tree_kernel_efficient(
parent_list: torch.Tensor,
selected_index: torch.Tensor,
@@ -59,25 +86,16 @@ def build_tree_kernel_efficient(
)
def build_tree_kernel(
parent_list: torch.Tensor,
selected_index: torch.Tensor,
verified_seq_len: torch.Tensor,
tree_mask: torch.Tensor,
positions: torch.Tensor,
retrive_index: torch.Tensor,
topk: int,
depth: int,
draft_token_num: int,
def segment_packbits(
x: torch.Tensor,
input_indptr: torch.Tensor,
output_indptr: torch.Tensor,
y: torch.Tensor,
) -> None:
torch.ops.sgl_kernel.build_tree_kernel(
parent_list,
selected_index,
verified_seq_len,
tree_mask,
positions,
retrive_index,
topk,
depth,
draft_token_num,
torch.ops.sgl_kernel.segment_packbits(
x,
input_indptr,
output_indptr,
y,
torch.cuda.current_stream().cuda_stream,
)