Fix sampling for speculative decoding & simplify kernels (#7207)
This commit is contained in:
@@ -11,6 +11,7 @@ def tree_speculative_sampling_target_only(
|
||||
retrive_next_token: torch.Tensor,
|
||||
retrive_next_sibling: torch.Tensor,
|
||||
uniform_samples: torch.Tensor,
|
||||
uniform_samples_for_final_sampling: torch.Tensor,
|
||||
target_probs: torch.Tensor,
|
||||
draft_probs: torch.Tensor,
|
||||
threshold_single: float = 1.0,
|
||||
@@ -26,6 +27,7 @@ def tree_speculative_sampling_target_only(
|
||||
retrive_next_token,
|
||||
retrive_next_sibling,
|
||||
uniform_samples,
|
||||
uniform_samples_for_final_sampling,
|
||||
target_probs,
|
||||
draft_probs,
|
||||
threshold_single,
|
||||
@@ -91,11 +93,13 @@ def segment_packbits(
|
||||
input_indptr: torch.Tensor,
|
||||
output_indptr: torch.Tensor,
|
||||
y: torch.Tensor,
|
||||
batch_size: int,
|
||||
) -> None:
|
||||
torch.ops.sgl_kernel.segment_packbits.default(
|
||||
x,
|
||||
input_indptr,
|
||||
output_indptr,
|
||||
y,
|
||||
batch_size,
|
||||
torch.cuda.current_stream().cuda_stream,
|
||||
)
|
||||
|
||||
Reference in New Issue
Block a user