2025-03-03 06:36:40 -08:00
|
|
|
import torch
|
2025-03-08 22:54:51 -08:00
|
|
|
from sgl_kernel.utils import get_cuda_stream
|
2025-03-03 06:36:40 -08:00
|
|
|
|
|
|
|
|
|
|
|
|
|
def tree_speculative_sampling_target_only(
    predicts: torch.Tensor,  # mutable
    accept_index: torch.Tensor,  # mutable
    accept_token_num: torch.Tensor,  # mutable
    candidates: torch.Tensor,
    retrive_index: torch.Tensor,
    retrive_next_token: torch.Tensor,
    retrive_next_sibling: torch.Tensor,
    uniform_samples: torch.Tensor,
    target_probs: torch.Tensor,
    draft_probs: torch.Tensor,
    deterministic: bool = True,
) -> None:
    """Dispatch to the ``sgl_kernel`` target-only tree speculative-sampling op.

    Thin wrapper around
    ``torch.ops.sgl_kernel.tree_speculative_sampling_target_only``. Every
    argument is forwarded positionally, in signature order. The handle
    returned by ``get_cuda_stream()`` is appended as the final argument.

    Args:
        predicts: Marked ``# mutable``; presumably written in place by the
            op — TODO confirm against the op schema.
        accept_index: Marked ``# mutable``; presumably written in place.
        accept_token_num: Marked ``# mutable``; presumably written in place.
        candidates: Forwarded unchanged.
        retrive_index: Forwarded unchanged.
        retrive_next_token: Forwarded unchanged.
        retrive_next_sibling: Forwarded unchanged.
        uniform_samples: Forwarded unchanged.
        target_probs: Forwarded unchanged.
        draft_probs: Forwarded unchanged.
        deterministic: Forwarded to the op as-is (defaults to ``True``).

    Returns:
        None. Whatever the op returns, if anything, is discarded.
    """
    # Resolve the registered op first, matching the original evaluation order.
    kernel = torch.ops.sgl_kernel.tree_speculative_sampling_target_only

    # Group the positional arguments purely for readability; order is preserved.
    mutable_outputs = (predicts, accept_index, accept_token_num)
    tree_inputs = (
        candidates,
        retrive_index,
        retrive_next_token,
        retrive_next_sibling,
    )
    sampling_inputs = (uniform_samples, target_probs, draft_probs)

    # The stream handle is fetched last, right before the call, as in the
    # original argument list.
    kernel(
        *mutable_outputs,
        *tree_inputs,
        *sampling_inputs,
        deterministic,
        get_cuda_stream(),
    )
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def build_tree_kernel_efficient(
    parent_list: torch.Tensor,
    selected_index: torch.Tensor,
    verified_seq_len: torch.Tensor,
    tree_mask: torch.Tensor,
    positions: torch.Tensor,
    retrive_index: torch.Tensor,
    retrive_next_token: torch.Tensor,
    retrive_next_sibling: torch.Tensor,
    topk: int,
    depth: int,
    draft_token_num: int,
) -> None:
    """Dispatch to the ``sgl_kernel`` ``build_tree_kernel_efficient`` op.

    Thin wrapper that forwards all eleven arguments positionally, in
    signature order, to ``torch.ops.sgl_kernel.build_tree_kernel_efficient``.
    No stream argument is passed here.

    NOTE(review): the function returns None, so ``tree_mask``, ``positions``
    and the ``retrive_*`` tensors are presumably output buffers the op fills
    in place. TODO confirm against the op schema.

    Args:
        parent_list: Forwarded unchanged.
        selected_index: Forwarded unchanged.
        verified_seq_len: Forwarded unchanged.
        tree_mask: Forwarded unchanged.
        positions: Forwarded unchanged.
        retrive_index: Forwarded unchanged.
        retrive_next_token: Forwarded unchanged.
        retrive_next_sibling: Forwarded unchanged.
        topk: Forwarded unchanged.
        depth: Forwarded unchanged.
        draft_token_num: Forwarded unchanged.

    Returns:
        None. Whatever the op returns, if anything, is discarded.
    """
    kernel = torch.ops.sgl_kernel.build_tree_kernel_efficient
    kernel(
        parent_list,
        selected_index,
        verified_seq_len,
        tree_mask,
        positions,
        retrive_index,
        retrive_next_token,
        retrive_next_sibling,
        topk,
        depth,
        draft_token_num,
    )
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def build_tree_kernel(
    parent_list: torch.Tensor,
    selected_index: torch.Tensor,
    verified_seq_len: torch.Tensor,
    tree_mask: torch.Tensor,
    positions: torch.Tensor,
    retrive_index: torch.Tensor,
    topk: int,
    depth: int,
    draft_token_num: int,
) -> None:
    """Dispatch to the ``sgl_kernel`` ``build_tree_kernel`` op.

    Thin wrapper that forwards all nine arguments positionally, in signature
    order, to ``torch.ops.sgl_kernel.build_tree_kernel``. No stream argument
    is passed here.

    NOTE(review): the function returns None, so ``tree_mask``, ``positions``
    and ``retrive_index`` are presumably output buffers the op fills in
    place. TODO confirm against the op schema.

    Args:
        parent_list: Forwarded unchanged.
        selected_index: Forwarded unchanged.
        verified_seq_len: Forwarded unchanged.
        tree_mask: Forwarded unchanged.
        positions: Forwarded unchanged.
        retrive_index: Forwarded unchanged.
        topk: Forwarded unchanged.
        depth: Forwarded unchanged.
        draft_token_num: Forwarded unchanged.

    Returns:
        None. Whatever the op returns, if anything, is discarded.
    """
    # Collect the positional arguments in the exact order the op expects.
    op_args = (
        parent_list,
        selected_index,
        verified_seq_len,
        tree_mask,
        positions,
        retrive_index,
        topk,
        depth,
        draft_token_num,
    )
    torch.ops.sgl_kernel.build_tree_kernel(*op_args)
|