Fix sampling for speculative decoding & simplify kernels (#7207)

This commit is contained in:
Lianmin Zheng
2025-06-16 03:28:30 -07:00
committed by GitHub
parent b1286a116a
commit cfceb83d05
11 changed files with 124 additions and 79 deletions

View File

@@ -11,6 +11,7 @@ def tree_speculative_sampling_target_only(
retrive_next_token: torch.Tensor,
retrive_next_sibling: torch.Tensor,
uniform_samples: torch.Tensor,
uniform_samples_for_final_sampling: torch.Tensor,
target_probs: torch.Tensor,
draft_probs: torch.Tensor,
threshold_single: float = 1.0,
@@ -26,6 +27,7 @@ def tree_speculative_sampling_target_only(
retrive_next_token,
retrive_next_sibling,
uniform_samples,
uniform_samples_for_final_sampling,
target_probs,
draft_probs,
threshold_single,
@@ -91,11 +93,13 @@ def segment_packbits(
input_indptr: torch.Tensor,
output_indptr: torch.Tensor,
y: torch.Tensor,
batch_size: int,
) -> None:
torch.ops.sgl_kernel.segment_packbits.default(
x,
input_indptr,
output_indptr,
y,
batch_size,
torch.cuda.current_stream().cuda_stream,
)