Files
sglang/sgl-kernel/python/sgl_kernel/speculative.py
2025-03-08 22:54:51 -08:00

84 lines
2.0 KiB
Python

import torch
from sgl_kernel.utils import get_cuda_stream
def tree_speculative_sampling_target_only(
    predicts: torch.Tensor,  # mutable
    accept_index: torch.Tensor,  # mutable
    accept_token_num: torch.Tensor,  # mutable
    candidates: torch.Tensor,
    retrive_index: torch.Tensor,
    retrive_next_token: torch.Tensor,
    retrive_next_sibling: torch.Tensor,
    uniform_samples: torch.Tensor,
    target_probs: torch.Tensor,
    draft_probs: torch.Tensor,
    deterministic: bool = True,
) -> None:
    """Dispatch to ``torch.ops.sgl_kernel.tree_speculative_sampling_target_only``.

    Every argument is forwarded positionally, in declaration order. The
    stream handle returned by ``get_cuda_stream()`` is appended as the
    final argument.

    ``predicts``, ``accept_index`` and ``accept_token_num`` are marked
    mutable in the signature. Presumably the op writes its results into
    them in place (NOTE(review): confirm against the op's registered
    schema).

    Returns:
        None. This wrapper returns nothing itself.
    """
    sampling_op = torch.ops.sgl_kernel.tree_speculative_sampling_target_only
    # The op takes the current stream as its trailing positional argument.
    current_stream = get_cuda_stream()
    sampling_op(
        predicts,
        accept_index,
        accept_token_num,
        candidates,
        retrive_index,
        retrive_next_token,
        retrive_next_sibling,
        uniform_samples,
        target_probs,
        draft_probs,
        deterministic,
        current_stream,
    )
def build_tree_kernel_efficient(
    parent_list: torch.Tensor,
    selected_index: torch.Tensor,
    verified_seq_len: torch.Tensor,
    tree_mask: torch.Tensor,
    positions: torch.Tensor,
    retrive_index: torch.Tensor,
    retrive_next_token: torch.Tensor,
    retrive_next_sibling: torch.Tensor,
    topk: int,
    depth: int,
    draft_token_num: int,
) -> None:
    """Dispatch to ``torch.ops.sgl_kernel.build_tree_kernel_efficient``.

    The eight tensors and then the three integers are forwarded
    positionally, in declaration order. No stream argument is passed to
    this op.

    NOTE(review): the wrapper returns nothing, so the op presumably fills
    some of the tensors in place (likely ``tree_mask``, ``positions`` and
    the ``retrive_*`` buffers). Confirm against the op's registered schema.

    Returns:
        None.
    """
    # The op schema is positional, so the tensors must stay in this order.
    tensor_operands = (
        parent_list,
        selected_index,
        verified_seq_len,
        tree_mask,
        positions,
        retrive_index,
        retrive_next_token,
        retrive_next_sibling,
    )
    torch.ops.sgl_kernel.build_tree_kernel_efficient(
        *tensor_operands, topk, depth, draft_token_num
    )
def build_tree_kernel(
    parent_list: torch.Tensor,
    selected_index: torch.Tensor,
    verified_seq_len: torch.Tensor,
    tree_mask: torch.Tensor,
    positions: torch.Tensor,
    retrive_index: torch.Tensor,
    topk: int,
    depth: int,
    draft_token_num: int,
) -> None:
    """Dispatch to ``torch.ops.sgl_kernel.build_tree_kernel``.

    The six tensors and then the three integers are forwarded
    positionally, in declaration order. No stream argument is passed to
    this op.

    NOTE(review): the wrapper returns nothing, so the op presumably fills
    some of the tensors in place (likely ``tree_mask``, ``positions`` and
    ``retrive_index``). Confirm against the op's registered schema.

    Returns:
        None.
    """
    tensor_operands = (
        parent_list,
        selected_index,
        verified_seq_len,
        tree_mask,
        positions,
        retrive_index,
    )
    int_operands = (topk, depth, draft_token_num)
    # The argument order must match the op's positional schema exactly.
    torch.ops.sgl_kernel.build_tree_kernel(*tensor_operands, *int_operands)