use default for torch.ops (#4835)

This commit is contained in:
Yineng Zhang
2025-03-27 19:09:58 -07:00
committed by GitHub
parent 10a9ab7b07
commit 31dfff7da7
7 changed files with 51 additions and 47 deletions

View File

@@ -12,7 +12,7 @@ def _top_k_renorm_probs_internal(
probs = probs.float()
maybe_top_k_arr = maybe_top_k_arr.int() if maybe_top_k_arr is not None else None
renorm_probs = torch.empty_like(probs)
torch.ops.sgl_kernel.top_k_renorm_probs(
torch.ops.sgl_kernel.top_k_renorm_probs.default(
probs,
renorm_probs,
maybe_top_k_arr,
@@ -40,7 +40,7 @@ def _top_p_renorm_probs_internal(
probs = probs.float()
maybe_top_p_arr = maybe_top_p_arr.float() if maybe_top_p_arr is not None else None
renorm_probs = torch.empty_like(probs)
torch.ops.sgl_kernel.top_p_renorm_probs(
torch.ops.sgl_kernel.top_p_renorm_probs.default(
probs,
renorm_probs,
maybe_top_p_arr,
@@ -75,7 +75,7 @@ def _top_p_sampling_from_probs_internal(
)
samples = torch.empty(probs.size(0), dtype=torch.int32, device=device)
success = torch.empty(probs.size(0), dtype=torch.bool, device=device)
torch.ops.sgl_kernel.top_p_sampling_from_probs(
torch.ops.sgl_kernel.top_p_sampling_from_probs.default(
probs,
uniform_samples,
samples,
@@ -121,7 +121,7 @@ def _top_k_top_p_sampling_from_probs_internal(
)
samples = torch.empty(probs.size(0), dtype=torch.int32, device=device)
success = torch.empty(probs.size(0), dtype=torch.bool, device=device)
torch.ops.sgl_kernel.top_k_top_p_sampling_from_probs(
torch.ops.sgl_kernel.top_k_top_p_sampling_from_probs.default(
probs,
uniform_samples,
samples,
@@ -179,7 +179,7 @@ def _min_p_sampling_from_probs_internal(
maybe_min_p_arr.float() if maybe_min_p_arr is not None else None
)
samples = torch.empty(probs.size(0), dtype=torch.int32, device=device)
torch.ops.sgl_kernel.min_p_sampling_from_probs(
torch.ops.sgl_kernel.min_p_sampling_from_probs.default(
probs,
uniform_samples,
samples,