[Feature] Support lookahead in speculative decoding (#9873)

Co-authored-by: a4zhangfei <a4zhangfei@qq.com>
Co-authored-by: Qiaolin-Yu <liin1211@outlook.com>
This commit is contained in:
Zhihao Zhang
2025-09-19 07:42:41 +08:00
committed by GitHub
parent 2a2ff9a840
commit e7bc600304
30 changed files with 2058 additions and 32 deletions

View File

@@ -126,6 +126,7 @@ from sgl_kernel.sampling import (
)
from sgl_kernel.speculative import (
build_tree_kernel_efficient,
reconstruct_indices_from_tree_mask,
segment_packbits,
tree_speculative_sampling_target_only,
verify_tree_greedy,

View File

@@ -90,6 +90,28 @@ def build_tree_kernel_efficient(
)
def reconstruct_indices_from_tree_mask(
    tree_mask: torch.Tensor,
    verified_seq_len: torch.Tensor,
    positions: torch.Tensor,
    retrive_index: torch.Tensor,
    retrive_next_token: torch.Tensor,
    retrive_next_sibling: torch.Tensor,
    batch_size: int,
    draft_token_num: int,
) -> None:
    """Call the ``sgl_kernel`` custom op ``reconstruct_indices_from_tree_mask``.

    This is a thin Python wrapper: every argument is forwarded positionally,
    unchanged and in declaration order, to
    ``torch.ops.sgl_kernel.reconstruct_indices_from_tree_mask.default``.

    Args:
        tree_mask: Tensor passed as the op's first argument.
        verified_seq_len: Tensor passed as the op's second argument.
        positions: Tensor passed as the op's third argument.
        retrive_index: Tensor passed as the op's fourth argument (the
            ``retrive`` spelling is part of the existing interface).
        retrive_next_token: Tensor passed as the op's fifth argument.
        retrive_next_sibling: Tensor passed as the op's sixth argument.
        batch_size: Integer passed as the op's seventh argument.
        draft_token_num: Integer passed as the op's eighth argument.

    Returns:
        None. The op's own return value, if any, is discarded.

    NOTE(review): since nothing is returned, the kernel presumably writes its
    results in place into ``positions`` and the ``retrive_*`` tensors --
    confirm against the op's schema / kernel source.
    """
    # Resolve the registered overload once, then dispatch with the arguments
    # in the exact positional order the op expects.
    kernel = torch.ops.sgl_kernel.reconstruct_indices_from_tree_mask.default
    kernel(
        tree_mask,
        verified_seq_len,
        positions,
        retrive_index,
        retrive_next_token,
        retrive_next_sibling,
        batch_size,
        draft_token_num,
    )
def segment_packbits(
x: torch.Tensor,
input_indptr: torch.Tensor,