Rename files in sgl kernel to avoid nested folder structure (#4213)
Co-authored-by: zhyncs <me@zhyncs.com>
This commit is contained in:
83
sgl-kernel/python/sgl_kernel/speculative.py
Normal file
83
sgl-kernel/python/sgl_kernel/speculative.py
Normal file
@@ -0,0 +1,83 @@
|
||||
import torch
|
||||
from sgl_kernel.utils import get_cuda_stream
|
||||
|
||||
|
||||
def tree_speculative_sampling_target_only(
|
||||
predicts: torch.Tensor, # mutable
|
||||
accept_index: torch.Tensor, # mutable
|
||||
accept_token_num: torch.Tensor, # mutable
|
||||
candidates: torch.Tensor,
|
||||
retrive_index: torch.Tensor,
|
||||
retrive_next_token: torch.Tensor,
|
||||
retrive_next_sibling: torch.Tensor,
|
||||
uniform_samples: torch.Tensor,
|
||||
target_probs: torch.Tensor,
|
||||
draft_probs: torch.Tensor,
|
||||
deterministic: bool = True,
|
||||
) -> None:
|
||||
torch.ops.sgl_kernel.tree_speculative_sampling_target_only(
|
||||
predicts,
|
||||
accept_index,
|
||||
accept_token_num,
|
||||
candidates,
|
||||
retrive_index,
|
||||
retrive_next_token,
|
||||
retrive_next_sibling,
|
||||
uniform_samples,
|
||||
target_probs,
|
||||
draft_probs,
|
||||
deterministic,
|
||||
get_cuda_stream(),
|
||||
)
|
||||
|
||||
|
||||
def build_tree_kernel_efficient(
|
||||
parent_list: torch.Tensor,
|
||||
selected_index: torch.Tensor,
|
||||
verified_seq_len: torch.Tensor,
|
||||
tree_mask: torch.Tensor,
|
||||
positions: torch.Tensor,
|
||||
retrive_index: torch.Tensor,
|
||||
retrive_next_token: torch.Tensor,
|
||||
retrive_next_sibling: torch.Tensor,
|
||||
topk: int,
|
||||
depth: int,
|
||||
draft_token_num: int,
|
||||
) -> None:
|
||||
torch.ops.sgl_kernel.build_tree_kernel_efficient(
|
||||
parent_list,
|
||||
selected_index,
|
||||
verified_seq_len,
|
||||
tree_mask,
|
||||
positions,
|
||||
retrive_index,
|
||||
retrive_next_token,
|
||||
retrive_next_sibling,
|
||||
topk,
|
||||
depth,
|
||||
draft_token_num,
|
||||
)
|
||||
|
||||
|
||||
def build_tree_kernel(
|
||||
parent_list: torch.Tensor,
|
||||
selected_index: torch.Tensor,
|
||||
verified_seq_len: torch.Tensor,
|
||||
tree_mask: torch.Tensor,
|
||||
positions: torch.Tensor,
|
||||
retrive_index: torch.Tensor,
|
||||
topk: int,
|
||||
depth: int,
|
||||
draft_token_num: int,
|
||||
) -> None:
|
||||
torch.ops.sgl_kernel.build_tree_kernel(
|
||||
parent_list,
|
||||
selected_index,
|
||||
verified_seq_len,
|
||||
tree_mask,
|
||||
positions,
|
||||
retrive_index,
|
||||
topk,
|
||||
depth,
|
||||
draft_token_num,
|
||||
)
|
||||
Reference in New Issue
Block a user