[Feature] Speculative decoding support lookahead (#9873)
Co-authored-by: a4zhangfei <a4zhangfei@qq.com> Co-authored-by: Qiaolin-Yu <liin1211@outlook.com>
This commit is contained in:
@@ -126,6 +126,7 @@ from sgl_kernel.sampling import (
|
||||
)
|
||||
from sgl_kernel.speculative import (
|
||||
build_tree_kernel_efficient,
|
||||
reconstruct_indices_from_tree_mask,
|
||||
segment_packbits,
|
||||
tree_speculative_sampling_target_only,
|
||||
verify_tree_greedy,
|
||||
|
||||
@@ -90,6 +90,28 @@ def build_tree_kernel_efficient(
|
||||
)
|
||||
|
||||
|
||||
def reconstruct_indices_from_tree_mask(
    tree_mask: torch.Tensor,
    verified_seq_len: torch.Tensor,
    positions: torch.Tensor,
    retrive_index: torch.Tensor,
    retrive_next_token: torch.Tensor,
    retrive_next_sibling: torch.Tensor,
    batch_size: int,
    draft_token_num: int,
) -> None:
    """Invoke the ``sgl_kernel`` op ``reconstruct_indices_from_tree_mask``.

    This is a thin Python wrapper: every argument is forwarded positionally,
    in declaration order, to the ``default`` overload of
    ``torch.ops.sgl_kernel.reconstruct_indices_from_tree_mask``. Nothing is
    returned from this function.

    NOTE(review): since the wrapper returns ``None``, the kernel presumably
    writes its results into some of the tensor arguments (likely
    ``positions`` and the ``retrive_*`` tensors) in place -- confirm against
    the kernel implementation. Expected shapes and dtypes are not visible
    here either.

    The ``retrive_*`` spelling is part of the public parameter names and is
    kept as-is for caller compatibility.
    """
    # Resolve the registered op overload once, then call it with the
    # arguments in exactly the order the kernel signature expects.
    kernel = torch.ops.sgl_kernel.reconstruct_indices_from_tree_mask.default
    kernel(
        tree_mask,
        verified_seq_len,
        positions,
        retrive_index,
        retrive_next_token,
        retrive_next_sibling,
        batch_size,
        draft_token_num,
    )
|
||||
|
||||
|
||||
def segment_packbits(
|
||||
x: torch.Tensor,
|
||||
input_indptr: torch.Tensor,
|
||||
|
||||
Reference in New Issue
Block a user